'''Code for calculating fields of induced surface flux biases.

Provides the InducedSurfaceFlux class, which combines model and
observational cubes into a process anomaly, evaluates a partial derivative
from the formulae in isf_input.txt and multiplies the two to give the flux
contribution; plus helpers for reading the isf_input.txt / isf_obs.txt
configuration files and for computing area-weighted regional-mean time
series (calc_ts).
'''

# Version tag embedded in the output directory names
# (.../output_<version>/output_3D/... and .../output_1D/...).
version = 'v4_06_09_2018'

class InducedSurfaceFlux:
    '''Induced surface flux bias for one model run, process and obs scheme.

    The methods depend on each other's stored results (mostly dictionaries
    keyed by year), so for each year they are meant to be called roughly in
    this order: get_model_data, get_obs_data (once),
    generate_missing_data_mask, make_reference_states, make_process_anomaly,
    evaluate_pd, evaluate_flux_contribution; then calc_ts and save.
    '''

    def __init__(self, model, experiment, freq, process, obs_scheme,
                 threshold, gridname):
        '''Store the run settings, derive data/output paths, create output dirs.

        model, experiment, freq and gridname select the directory tree the
        model data are read from and the outputs are written to.
        process names a section of isf_input.txt (see read_process) and
        obs_scheme a section of isf_obs.txt (see read_obs).
        threshold is appended to the name of the winter logical field,
        'l_Tsfc_cold.<threshold>', which is read with the control variables.
        '''
        import os
        self.model = model
        self.experiment = experiment
        self.freq = freq
        self.process = process
        self.obs_scheme = obs_scheme
        self.gridname = gridname
        self.threshold = threshold
        self.version = version

        self.winter_logical_name = 'l_Tsfc_cold.'+str(self.threshold)
        # Parse each config file once instead of once per lookup.
        process_info = read_process(self.process)
        obs_info = read_obs()[self.obs_scheme]
        self.cvnames = process_info['cvnames'].split(',') + \
            [self.winter_logical_name]
        self.obs_cvnames = process_info['obs_cvnames'].split(',')
        self.obsnames = [obs_info[key] for key in obs_info.keys()]

        root = '/data/local/hadax/PhD/Attribution/'
        run_path = self.model + '/' + self.experiment + '/' + self.freq + \
            '/' + self.gridname + '/'
        self.model_datadirs = {cvname: root + run_path + cvname + '/'
                               for cvname in self.cvnames}
        self.obs_datadirs = {obsname: root + obsname + '/' + self.freq +
                             '/' + self.gridname + '/'
                             for obsname in self.obsnames}
        output_root = root + 'output_' + self.version + '/'
        self.savedir_3D = output_root + 'output_3D/' + run_path
        self.savedir_1D = output_root + 'output_1D/' + run_path

        # Per-control-variable containers, each keyed by year.
        self.model_cubes = {cvname: {} for cvname in self.cvnames}
        self.reference_cubes = {cvname: {} for cvname in self.cvnames}
        self.anomaly_cubes = {cvname: {} for cvname in self.cvnames}
        # Per-year results.
        self.process_anomaly = {}
        self.partial_derivative = {}
        self.flux_contribution = {}
        self.masks = {}
        # Obs cubes keyed by variable name; time series keyed by region name.
        self.obs_cubes = {}
        self.ts_fc = {}

        for ddir in [self.savedir_3D, self.savedir_1D]:
            if not os.path.exists(ddir):
                os.makedirs(ddir)

    def get_model_data(self,year):
        '''Load <model_datadirs[cvname]><year>.nc into self.model_cubes.

        One cube is read for every name in self.cvnames, including the
        winter logical field.
        '''
        import iris
        print('Reading model data')
        for cvname in self.cvnames:
            model_file = self.model_datadirs[cvname] + str(year) + '.nc'
            self.model_cubes[cvname][year] = iris.load_cube(model_file)

    def get_obs_data(self):
        '''Load the observational cube for each name in self.obs_cvnames.

        The dataset supplying each variable is looked up in the obs_scheme
        section of isf_obs.txt. A 'month' coordinate is added to the time
        axis, and the cube is passed through cube_magic.sc_translate before
        being stored in self.obs_cubes.
        '''
        import iris
        import cube_magic as cm
        import iris.coord_categorisation as icc
        print('Reading obs data')
        obs_info = read_obs()[self.obs_scheme]
        for cvname in self.obs_cvnames:
            obsname = obs_info[cvname]
            obs_file = self.obs_datadirs[obsname] + cvname + '.nc'
            obs_cube = iris.load_cube(obs_file)
            try:
                icc.add_month(obs_cube,'time')
            except ValueError:
                # Presumably raised when a 'month' coord already exists.
                pass
            self.obs_cubes[cvname] = cm.sc_translate(obs_cube)

    def make_reference_states(self,year):
        '''Build 12-month reference cubes on the sample grid for *year*.

        For every control variable, a cube with a monthly time coordinate
        (first day of each month, in days since 1978-09-01) and the lat/lon
        coordinates of field_functions.sample_field(self.gridname) is filled
        with the model data for *year* and masked with self.masks[year], so
        generate_missing_data_mask must have been run for *year* first.
        '''
        print('Creating reference cubes')
        import field_functions as ff
        import datetime as dt
        import iris
        import numpy as np
        import numpy.ma as ma
        sf = ff.sample_field(self.gridname)
        lat_coord = sf.coord('latitude')
        lon_coord = sf.coord('longitude')
        ref_time = dt.date(1978,9,1)
        ref_units = ref_time.strftime('days since %Y-%m-%d')
        time_points = [(dt.date(year, month, 1) - ref_time).days
                       for month in range(1, 13)]

        time_coord = iris.coords.DimCoord(time_points,'time',units=ref_units)
        shape = (12,sf.shape[0],sf.shape[1])

        for cvname in self.cvnames:
            # Zero-filled template that fixes the expected (12, ny, nx) shape;
            # its values are replaced by the model data below.
            data = ma.masked_array(np.zeros(shape),
                                   mask=np.zeros(shape,dtype='bool'))
            cube = iris.cube.Cube(data)
            cube.add_dim_coord(time_coord,0)
            try:
                cube.add_dim_coord(lat_coord,1)
                cube.add_dim_coord(lon_coord,2)
            except ValueError:
                # Presumably 2-D (curvilinear) lat/lon, which cannot be
                # dimension coords; attach them as aux coords over dims 1-2.
                cube.add_aux_coord(lat_coord,[1,2])
                cube.add_aux_coord(lon_coord,[1,2])
            cube.data = self.model_cubes[cvname][year].data
            self.reference_cubes[cvname][year] = cube
            # Mask every reference cube, not only the last control variable.
            self.apply_mask(cube,year)

    def make_process_anomaly(self,year):
        '''Compute model minus obs for self.process, mask it and save it.

        The obs data are placed into a deep copy of the model cube so that
        both operands of the subtraction share the model cube's metadata.
        The result is stored in self.process_anomaly[year] and written to
        <savedir_3D>/process_anomalies/<filename(year)>.nc.
        '''
        print('Calculating process anomaly')
        import iris
        import copy
        import os
        obs_cube_cfg = copy.deepcopy(self.model_cubes[self.process][year])
        obs_cube_cfg.data = self.obs_cubes[self.process].data
        self.process_anomaly[year] = self.model_cubes[self.process][year] - \
            obs_cube_cfg

        self.apply_mask(self.process_anomaly[year],year)
        pa_dir = self.savedir_3D + 'process_anomalies/'
        if not os.path.exists(pa_dir):
            os.makedirs(pa_dir)
        iris.save(self.process_anomaly[year],
                  pa_dir + self.filename(year) + '.nc')

    def pd_expression(self,year,order=1):
        '''Return a Python expression string for the order-*order* derivative.

        'winter_formula<order>' and 'summer_formula<order>' are read from the
        process section of isf_input.txt and blended with the winter logical
        field L (self.winter_logical_name):

            (winter) * L + (summer) * (L * -1. + 1.)

        Each control-variable name in the result is then rewritten to
        self.reference_cubes['<name>'][<year>], so the string is meant to be
        eval'd inside a method of this class.

        Raises KeyError (after printing a message) if the process or the
        formulae for *order* are missing from the config.
        '''
        import re
        try:
            process_info = read_process(self.process)
            winter_formula = process_info['winter_formula'+str(order)]
            summer_formula = process_info['summer_formula'+str(order)]
        except KeyError:
            print('Could not find formulae for this order of derivative')
            raise

        full_formula = '(' + winter_formula + ') * ' + \
            self.winter_logical_name + ' + (' + summer_formula + ') * (' + \
            self.winter_logical_name + ' * -1. + 1.)'

        # Substitute all names in ONE pass, longest first, so that inserted
        # text is never re-scanned and a name that is a substring of another
        # (e.g. 'Tsfc' inside 'l_Tsfc_cold.<t>') is not replaced inside it.
        # The lookarounds stop matches inside longer identifiers or after a
        # '.' (attribute access such as ac.<name>).
        names = sorted(set(name for name in self.cvnames if name),
                       key=len, reverse=True)
        pattern = re.compile(r'(?<![\w.])(' +
                             '|'.join(re.escape(name) for name in names) +
                             r')(?!\w)')

        def _to_reference(match):
            return 'self.reference_cubes[\'' + match.group(1) + '\'][' + \
                str(year) + ']'

        return pattern.sub(_to_reference, full_formula)

    def evaluate_pd(self,year):
        '''Evaluate the first-order partial derivative for *year*; mask and save.

        The result is stored in self.partial_derivative[year] and written to
        <savedir_3D>/partial_derivatives/<filename(year)>.nc.
        SECURITY NOTE: eval() runs formulae read from isf_input.txt; only use
        trusted config files.
        '''
        print('Evaluating partial derivative')
        # Not referenced directly: kept in scope for the eval'd formulae,
        # which presumably refer to attribution_constants as 'ac'.
        import attribution_constants as ac
        import os
        import iris
        self.partial_derivative[year] = eval(self.pd_expression(year))
        self.apply_mask(self.partial_derivative[year],year)
        pd_dir = self.savedir_3D + 'partial_derivatives/'
        if not os.path.exists(pd_dir):
            os.makedirs(pd_dir)
        iris.save(self.partial_derivative[year],
                  pd_dir + self.filename(year) + '.nc')

    def evaluate_expr(self,expr,in_cube_list):
        '''eval() *expr* with self, ac and in_cube_list in scope.

        SECURITY NOTE: only pass trusted expressions.
        '''
        # In scope for the eval'd expression (presumably used as 'ac.<name>').
        import attribution_constants as ac
        return eval(expr)

    def evaluate_pd2(self,year):
        '''Evaluate the second-order derivative for *year*; mask and save.

        Stored in self.pd2[year] and written to
        <savedir_3D>/partial_derivatives2/<filename(year)>.pd2.nc.
        '''
        # In scope for the eval'd formulae (presumably used as 'ac.<name>').
        import attribution_constants as ac
        import os
        import iris
        if not hasattr(self,'pd2'):
            self.pd2 = {}
        self.pd2[year] = eval(self.pd_expression(year,order=2))
        self.apply_mask(self.pd2[year],year)
        pd_dir = self.savedir_3D + 'partial_derivatives2/'
        if not os.path.exists(pd_dir):
            os.makedirs(pd_dir)
        iris.save(self.pd2[year],pd_dir + self.filename(year) + '.pd2.nc')

    def evaluate_flux_contribution(self,year):
        '''Store partial_derivative * process_anomaly for *year*, masked.'''
        print('Evaluating flux contribution')
        self.flux_contribution[year] = self.partial_derivative[year] * \
            self.process_anomaly[year]
        self.apply_mask(self.flux_contribution[year],year)

    def evaluate_taylor2(self,year):
        '''Store 2 * pd2 * process_anomaly**2 for *year* in self.taylor2.'''
        # NOTE(review): a second-order Taylor term is usually 0.5 * f'' * dx**2;
        # confirm whether the order-2 formulae already account for the 2.
        if not hasattr(self,'taylor2'):
            self.taylor2 = {}
        self.taylor2[year] = 2. * self.pd2[year] * \
            self.process_anomaly[year] ** 2.

    def generate_missing_data_mask(self,year):
        '''Build self.masks[year]: True wherever any input field is masked.

        Combines the masks of every model cube for *year* and, for control
        variables listed in self.obs_cvnames, of the matching obs cube. The
        result has the shape of the first control variable's model cube.
        '''
        import numpy as np
        import numpy.ma as ma

        mdmask = np.zeros(self.model_cubes[self.cvnames[0]][year].shape,
                          dtype='bool')
        for cvname in self.cvnames:
            source_cubes = [self.model_cubes[cvname][year]]
            if cvname in self.obs_cvnames:
                source_cubes.append(self.obs_cubes[cvname])
            for source_cube in source_cubes:
                data = source_cube.data
                if ma.is_masked(data):
                    mdmask[np.nonzero(ma.getmaskarray(data))] = True

        self.masks[year] = mdmask

    def apply_mask(self,ccube,year):
        '''Add self.masks[year] to the mask of *ccube*'s data.

        Any mask the data already carry is kept: numpy.ma ORs it with the new
        mask (keep_mask=True is the default). A shape mismatch is reported;
        numpy.ma then raises MaskError unless the sizes are compatible.
        '''
        import numpy.ma as ma

        if ccube.shape != self.masks[year].shape:
            print('Unable to apply mask; cube.shape =  %s mask.shape =  %s'
                  % (ccube.shape, self.masks[year].shape))
        # Build a new masked array rather than editing ccube.data.mask in
        # place: the data may be shared with another cube (see
        # make_reference_states), and an in-place edit would leak into it.
        ccube.data = ma.masked_array(ccube.data, mask=self.masks[year])

    def calc_ts(self,region_list):
        '''Store regional-mean series of the flux contribution in self.ts_fc.

        Years are processed in ascending order. self.ts_fc becomes None if
        the yearly series cannot be concatenated (see module-level calc_ts).
        '''
        self.ts_fc = calc_ts(self.flux_contribution,
                             sorted(self.flux_contribution),
                             region_list, self.gridname)

    def filename(self,year):
        '''Return the output file stem '<process>.<obs_scheme>.<year>'.'''
        return self.process + '.' + self.obs_scheme + '.' + str(year)

    def save(self):
        '''Write flux contributions and regional time series to disk.

        Each year's 3-D field goes to <savedir_3D><filename(year)>.nc; each
        series in self.ts_fc goes to
        <savedir_1D><filename(last year)>.<region>.nc.
        '''
        import iris
        years = sorted(self.flux_contribution)
        for year in years:
            iris.save(self.flux_contribution[year],
                      self.savedir_3D + self.filename(year) + '.nc')

        if not years:
            # Nothing computed, and no year to name the 1-D files after.
            return

        # The 1-D series cover all years but are named after the last one.
        last_year = years[-1]
        for region_name in self.ts_fc:
            iris.save(self.ts_fc[region_name],
                      self.savedir_1D + self.filename(last_year) + '.' +
                      region_name + '.nc')


def read_process(process_name):
    '''Return the section called *process_name* from isf_input.txt.

    Raises KeyError if the config has no such section.
    '''
    config = readfile('isf_input.txt')
    section = config[process_name]
    return section

def read_obs():
    '''Parse isf_obs.txt and return the whole parser.

    Sections are obs schemes; within a scheme, each key maps a control
    variable name to the observational dataset that supplies it.
    '''
    return readfile('isf_obs.txt')

def readfile(ffile):
    '''Parse the config file *ffile* and return the ConfigParser.

    Leading whitespace is stripped from every line before parsing, so the
    file may be indented freely; as a consequence, indented continuation
    lines are not treated as multi-line values. The source path is also
    stored on the returned parser as ``conf.file``.
    '''
    import configparser
    # ConfigParser is what SafeConfigParser has been an alias of since
    # Python 3.2 (and in the configparser backport); the alias is removed
    # in Python 3.12.
    conf = configparser.ConfigParser()
    with open(ffile) as fid:
        # Drop lines that become empty, exactly as joining the stripped
        # lines into one string and re-reading it used to.
        lines = [line for line in (raw.lstrip() for raw in fid) if line]
    # read_file accepts any iterable of lines; source= gives parse errors
    # the real file name.
    conf.read_file(lines, source=ffile)
    conf.file = ffile
    return conf

def processes():
    '''Return the process (section) names defined in isf_input.txt.

    ConfigParser.sections() already omits the DEFAULT section and, unlike
    calling list.remove on keys(), also works on Python 3, where keys() is a view.
    '''
    return readfile('isf_input.txt').sections()

def obs_schemes():
    '''Return the obs scheme (section) names defined in isf_obs.txt.

    ConfigParser.sections() already omits the DEFAULT section and, unlike
    calling list.remove on keys(), also works on Python 3, where keys() is a view.
    '''
    return readfile('isf_obs.txt').sections()

def calc_ts(cubedic,years,region_list,gridname):
    '''Compute area-weighted regional-mean time series from yearly cubes.

    Parameters
    ----------
    cubedic : dict
        Maps year -> iris cube (e.g. InducedSurfaceFlux.flux_contribution).
    years : sequence
        Years to use, in the order their series are collected.
    region_list : sequence of str
        Region names passed to regions.reduce_to_region.
    gridname : str
        Grid name passed to regions.reduce_to_region and, for 2-D lat/lon
        coordinates, to mesh.area_weights_2D.

    Returns
    -------
    dict or None
        Maps region name -> concatenated time-series cube, or None (after
        printing a message) if a region's yearly series do not concatenate
        into a single cube.

    Raises
    ------
    ValueError
        If the latitude coordinate is neither 1-D nor 2-D.
    '''
    import copy
    import iris
    import iris.analysis.cartography
    import iris.util
    import regions
    import cube_magic as cm

    ts_cubedic = {region_name: iris.cube.CubeList([])
                  for region_name in region_list}
    for year in years:
        # Work on copies so the helpers below that may modify a cube
        # (cm.llbounds, regions.reduce_to_region) leave cubedic untouched.
        cube = copy.deepcopy(cubedic[year])
        lat_rank = len(cube.coord('latitude').points.shape)
        if lat_rank == 1:
            cm.llbounds(cube)
            aw = iris.analysis.cartography.area_weights(cube)
        elif lat_rank == 2:
            import mesh
            import numpy as np
            aw2 = mesh.area_weights_2D(gridname)
            aw = np.tile(aw2,(cube.shape[0],1,1))
        else:
            # Without this, 'aw' would be undefined or silently reused
            # from a previous year.
            raise ValueError('Unsupported latitude coordinate rank: %d'
                             % lat_rank)

        for region_name in region_list:
            rcube = copy.deepcopy(cube)
            regions.reduce_to_region(rcube,region_name,gridname)
            tscube = rcube.collapsed(['latitude','longitude'],
                                     iris.analysis.MEAN, weights=aw)
            ts_cubedic[region_name].append(tscube)

    return_value = {}
    for region_name in region_list:
        iris.util.unify_time_units(ts_cubedic[region_name])
        ts_cubelc = ts_cubedic[region_name].concatenate()
        if len(ts_cubelc) > 1:
            print('Unable to concatenate into single cube')
            return None

        return_value[region_name] = ts_cubelc[0]

    return return_value
