'''Computes the scale factor (1-BR_{ice}^{cat})**-1., as described in Section 4'''

import iris
import numpy as np
import numpy.ma as ma
import iris.coord_categorisation as icc
import os
import cube_magic as cm
import datetime as dt
import glob
import copy
import sys
sys.path.insert(0,'/home/h01/hadax/Python/PhD/Heat_budget/v6_May_2017/')
import cice_concat
import field_functions as ff

# --- Run configuration -------------------------------------------------------
# `model` selects which ice-file naming scheme is globbed for input files.
model = 'HadGEM3_GC3.1'
experiment = 'Hist1'
# Inclusive range of years processed; one output file per year and variable.
startyear, endyear = 1980, 1999
tprof = 'monthly'
gridname = 'eorca1'

# Reference field for the grid; on eorca1 its mask is copied onto every
# saved cube.
sf = ff.sample_field(gridname)

# --- Physical constants ------------------------------------------------------
emis_ice = 1.0      # emissivity of the ice surface
kice = 2.03         # presumably thermal conductivity of sea ice (W m-1 K-1)
ksno = 0.33         # presumably thermal conductivity of snow (W m-1 K-1)
sb_const = 5.67e-8  # Stefan-Boltzmann constant (W m-2 K-4)

# Maps year // 10 to the decade letter used in HadGEM2_ES PP file names.
ddic = {198: 'i', 199: 'j'}

# --- Directories -------------------------------------------------------------
_attribution_dir = '/data/local/hadax/PhD/Attribution/' + \
    '/'.join([model, experiment, tprof, gridname]) + '/'
_raw_dir = '/data/local/hadax/PhD/' + '/'.join([model, experiment]) + '/'

# Inputs: surface temperature, grid-box-mean ice state, per-category ice.
datadir_tsfc = _attribution_dir + 'TTsfc' + '/'
datadir_ice_gbm = _raw_dir + 'ICE_STATE' + '/' + 'monthly' + '/'
datadir_ice_cat = _raw_dir + 'ICE_CAT' + '/' + 'monthly' + '/'

# Outputs: grid-box-mean scale factors (powers -1, -2, -3), the longwave
# terms, then per-category scale factor, ice area and thickness (cats 1-5).
vars_save = ['complex_scale_factor', 'complex_scale_factor2',
             'complex_scale_factor3', 'B_lw_dep', 'C_lw_ref']
for _prefix in ('complex_scale_factor.cat.', 'Ice_area.cat.',
                'Ice_thickness.cat.'):
    vars_save.extend(_prefix + str(n) for n in range(1, 6))

# One output directory per saved variable, under the attribution tree.
savedirs = dict((name, _attribution_dir + name + '/') for name in vars_save)

for name in vars_save:
    if not os.path.exists(savedirs[name]):
        os.makedirs(savedirs[name])


def _ice_file_patterns(year, month, month_name):
    """Return the (grid-box-mean, per-category) ice glob patterns for a month.

    Both directories share the same file-name pattern. Raises ValueError for a
    model whose naming scheme is not defined here (previously the patterns
    were silently left unset and a NameError followed).
    """
    if model == 'HadGEM2_ES':
        # PP names encode the decade as a letter (ddic) plus the year digit;
        # only this scheme needs ddic, so other models are not restricted to
        # the years it covers.
        ystring = ddic[year // 10] + str(year % 10)
        fname = '??????.p?' + ystring + month_name.lower() + '.pp'
    elif model == 'HadGEM3_GC3.1':
        fname = 'cice_?????i_1m_' + str(year) + str(month).zfill(2) + \
            '01-??????01.nc'
    else:
        raise ValueError('No ice file pattern defined for model ' + model)
    return datadir_ice_gbm + fname, datadir_ice_cat + fname


def _load_cubes(pattern):
    """Load all cubes from the files matching *pattern* (IOError if none)."""
    files = glob.glob(pattern)
    if not files:
        raise IOError('No files match ' + pattern)
    return iris.load(files)


def _pick(cubes, name):
    """Return the first cube in *cubes* whose field matches *name*."""
    return cubes[cm.field_index(cubes, name)[0][0]]


def _category_cube(template, field2d):
    """Cube shaped like *template* holding *field2d* in every category slice.

    The dimension and auxiliary coordinates of *template* are attached.
    """
    new = iris.cube.Cube(np.zeros(template.shape))
    for coord_and_dim in template._dim_coords_and_dims:
        new.add_dim_coord(*coord_and_dim)
    for coord_and_dims in template._aux_coords_and_dims:
        new.add_aux_coord(*coord_and_dims)
    for cat in range(template.shape[0]):
        new.data[cat, :, :] = field2d
    return new


def _mask_nan(cube):
    """Mask the NaN points of *cube* in place (existing masks are kept)."""
    cube.data = ma.masked_array(cube.data, mask=np.isnan(cube.data))


def _apply_grid_mask(cube, mask):
    """OR *mask* into *cube*'s mask, broadcasting over leading dimensions.

    Replaces the np.where indexing used before, which hit the wrong axes
    when a 2-D mask was applied to an already-masked (1, ny, nx) cube.
    """
    cube.data = ma.masked_array(
        cube.data, mask=ma.getmaskarray(cube.data) | mask)


def _combine_months(cubelist):
    """Join one year's monthly cubes into a single cube.

    Tries merge, then concatenate. If both still leave several cubes, the
    'month' coordinate is stripped, the project join helpers are used, and
    'month' is re-added on 'time'.
    """
    cubes = cubelist.merge()
    if len(cubes) == 1:
        return cubes[0]
    cubes = cubelist.concatenate()
    if len(cubes) == 1:
        return cubes[0]
    for c in cubelist:
        # Per cube, so one cube without 'month' doesn't stop the others.
        try:
            c.remove_coord('month')
        except (iris.exceptions.CoordinateNotFoundError, ValueError):
            pass
    if len(cubelist[0].shape) == 3:
        joined = cm.easy_concatenate(cubelist)
    else:
        joined = cice_concat.concat(cubelist)
    icc.add_month(joined, 'time')
    return joined


# Number of ice categories written out (matches the .cat.1 - .cat.5 outputs).
n_cat = 5
# Mask of the reference field; OR-ed into every saved cube on eorca1.
grid_mask = ma.getmaskarray(sf.data)

# For each year: compute per-category and grid-box-mean scale factors
# (1 - B R)**-p for p = 1, 2, 3 month by month, then save one file per
# variable and year.
for year in range(startyear, endyear + 1):
    cubelist_dic = {var: iris.cube.CubeList([]) for var in vars_save}
    cube_dic = {}
    tsfc_cube = iris.load_cube(datadir_tsfc + str(year) + '.nc')
    try:
        icc.add_month(tsfc_cube, 'time')
    except ValueError:
        # Presumably the file already carries a 'month' coordinate.
        pass

    for month in range(1, 13):
        print('{:5d} {:3d}'.format(year, month))
        month_name = dt.date(1980, month, 1).strftime('%b')

        pattern_gbm, pattern_cat = _ice_file_patterns(year, month,
                                                      month_name)
        cubes_ice_gbm = _load_cubes(pattern_gbm)
        cubes_ice_cat = _load_cubes(pattern_cat)

        snow_cube_gbm = _pick(cubes_ice_gbm, 'thickness_of_snow_on_sea_ice')
        ice_area_cube = _pick(cubes_ice_gbm, 'sea_ice_area_fraction')
        ice_area_cat_cube = _pick(cubes_ice_cat, 'category_ice_area')
        ice_volume_cat_cube = _pick(cubes_ice_cat, 'category_ice_volume')

        # Keep only the first entry of an extra leading dimension
        # (presumably time) so fields are (ny, nx) / (cat, ny, nx).
        if len(snow_cube_gbm.shape) == 3:
            snow_cube_gbm = snow_cube_gbm[0, :, :]
        if len(ice_area_cube.shape) == 3:
            ice_area_cube = ice_area_cube[0, :, :]
        if len(ice_area_cat_cube.shape) == 4:
            ice_area_cat_cube = ice_area_cat_cube[0, :, :, :]
        if len(ice_volume_cat_cube.shape) == 4:
            ice_volume_cat_cube = ice_volume_cat_cube[0, :, :, :]

        # Longwave terms from this month's surface temperature:
        # B = d(emitted LW)/dT (negative), C = emitted LW.
        tsfc_index = np.where(
            tsfc_cube.coord('month').points == month_name)[0]
        tsfc_subcube = tsfc_cube[tsfc_index, :, :]
        B_lw_dep = 4. * emis_ice * sb_const * tsfc_subcube ** 3. * -1.
        C_lw_ref = emis_ice * sb_const * tsfc_subcube ** 4.

        # Grid-box-mean snow depth and B repeated across every category.
        snow_cube_cat = _category_cube(ice_area_cat_cube, snow_cube_gbm.data)
        snow_cube_cat.units = 'm'
        B_lw_dep_cat = _category_cube(ice_area_cat_cube,
                                      B_lw_dep[0, :, :].data)

        local_ice_thickness = ice_volume_cat_cube / ice_area_cat_cube
        local_ice_thickness.units = 'm'

        pl_coord = iris.coords.DimCoord(range(1, n_cat + 1),
                                        long_name='pseudo_level')
        try:
            local_ice_thickness.add_dim_coord(pl_coord, 0)
            snow_cube_cat.add_dim_coord(pl_coord, 0)
        except ValueError:
            # Dimension 0 already has a coordinate.
            pass

        # 1 - B R, with R = h_ice / k_ice + h_snow / k_snow per category.
        one_minus_BR = (local_ice_thickness / kice + snow_cube_cat / ksno) * \
            B_lw_dep_cat * -1. + 1.
        scf_cat = one_minus_BR ** -1.
        scf2_cat = one_minus_BR ** -2.
        scf3_cat = one_minus_BR ** -3.

        # Area-weighted mean over categories, normalised by the total ice
        # area. Each power is masked from its OWN data (the old code masked
        # power 3 with power 2's data).
        gbm_cubes = []
        for scf in (scf_cat, scf2_cat, scf3_cat):
            _mask_nan(scf)
            summed = (scf * ice_area_cat_cube).collapsed(['pseudo_level'],
                                                         iris.analysis.SUM)
            gbm = summed / ice_area_cube
            _mask_nan(gbm)
            gbm_cubes.append(gbm)
        gbm_scf_cat, gbm_scf2_cat, gbm_scf3_cat = gbm_cubes

        for cat in range(1, n_cat + 1):
            cubelist_dic['complex_scale_factor.cat.' + str(cat)].append(
                scf_cat[cat - 1, :, :])
            cubelist_dic['Ice_area.cat.' + str(cat)].append(
                ice_area_cat_cube[cat - 1, :, :])
            cubelist_dic['Ice_thickness.cat.' + str(cat)].append(
                local_ice_thickness[cat - 1, :, :])

        cubelist_dic['complex_scale_factor3'].append(gbm_scf3_cat)
        cubelist_dic['complex_scale_factor2'].append(gbm_scf2_cat)
        cubelist_dic['complex_scale_factor'].append(gbm_scf_cat)
        cubelist_dic['B_lw_dep'].append(B_lw_dep)
        cubelist_dic['C_lw_ref'].append(C_lw_ref)

    for var in vars_save:
        if gridname == 'eorca1':
            for c in cubelist_dic[var]:
                _apply_grid_mask(c, grid_mask)
        cube_dic[var] = _combine_months(cubelist_dic[var])
        iris.save(cube_dic[var], savedirs[var] + str(year) + '.nc')
