import iris
import timemod
import copy
import sys
sys.path.insert(0,'/home/h01/hadax/Python/PhD/Attribution/')
import attribution_constants as ac
import cube_magic as cm
import os
import numpy.ma as ma
import numpy as np

# ---------------------------------------------------------------------------
# Run configuration
# ---------------------------------------------------------------------------
# `process` is the variable whose model-minus-observation anomaly is computed;
# `flux_name` is only used as a label in the output file names.
process = 'iimelt'
flux_name = 'SW_net'
model = 'HadGEM2_ES'
experiment = 'Hist1'
freq = 'daily'
gridname = 'n96'
startyear = 1980
endyear = 1999  # inclusive: the year loop runs to endyear + 1

# ---------------------------------------------------------------------------
# Partial-derivative configuration. Each set defines:
#   PD_formula                   - Python expression; its bare names are the
#                                  component variables and `ac.` refers to
#                                  attribution_constants.
#   component_variable_names     - variables read from the yearly model files.
#   obs_component_variable_names - the components that also have an
#                                  observation file.
# Only one set should be active; the commented-out sets are kept for reference.
# ---------------------------------------------------------------------------

#PD_formula = 'l_Tsfc_cold * (-1. * ac.B_lw_dep * (Ice_volume / ac.kice + Snow_volume / ac.ksno) + 1.)**-2. * (LW_down - ac.LW_ref + ac.B_lw_dep * ac.Tbot) * ac.B_lw_dep / ac.kice'
#component_variable_names = ['LW_down','Ice_volume','Snow_volume','l_Tsfc_cold']
#obs_component_variable_names = ['LW_down','Ice_volume']

#PD_formula = '(-1.*(ac.alpha_i - ac.alpha_o) * iifrac + (ac.alpha_m - ac.alpha_c) * (-1.*iimelt+1) + 1 - ac.alpha_o)*(l_Tsfc_cold * (-1. * ac.B_lw_dep * (Ice_volume / ac.kice + Snow_volume / ac.ksno) + 1.)**-1. + (-1.*l_Tsfc_cold + 1.))'
#component_variable_names = ['SW_down','iifrac','iimelt','l_Tsfc_cold','Ice_volume','Snow_volume']
#obs_component_variable_names  = ['SW_down','iifrac','iimelt','Ice_volume']

#PD_formula = 'l_Tsfc_cold * (-1. * ac.B_lw_dep * (Ice_volume / ac.kice + Snow_volume / ac.ksno) + 1.)**-1. + (-1.*l_Tsfc_cold + 1.)'
#component_variable_names = ['LW_down','Ice_volume','Snow_volume','l_Tsfc_cold']
#obs_component_variable_names = ['LW_down','Ice_volume']

#PD_formula = 'SW_down * (ac.alpha_o - ac.alpha_i)'
#component_variable_names = ['SW_down','iifrac']
#obs_component_variable_names = ['iifrac']

# Active configuration.
PD_formula = 'SW_down * (ac.alpha_c - ac.alpha_m)'
component_variable_names = [
    'SW_down',
    'iimelt',
]
obs_component_variable_names = ['iimelt']

# Observation dataset name per variable; used in the observation file path and
# (in this key order) in the output file name.
obs_names = {
    'LW_down': 'ISCCP-FD',
    'Ice_volume': 'PIOMAS',
    'SW_down': 'ISCCP-FD',
    'iifrac': 'NSIDCC',
    'iimelt': 'melt_onset',
}

# Weight given to the observations when blending them with the model to form
# each reference field; a weight of 0. means the model field is used as-is.
obs_weights = {
    'LW_down': 0.5,
    'Ice_volume': 0.5,
    'Snow_volume': 0.,
    'l_Tsfc_cold': 0.,
    'iifrac': 0.5,
    'iimelt': 0.5,
    'SW_down': 0.,
}

# One observation file per observed component (the path has no year in it).
obs_files = {
    cvname: '/data/local/hadax/PhD/Attribution/' + '/'.join(
        [obs_names[cvname], freq, gridname, cvname + '.nc'])
    for cvname in obs_component_variable_names
}

# Number of time steps assumed per yearly file at each output frequency.
ntime_dic = dict(monthly=12, daily=365)

# The output directory, the observation file paths and the obs label in the
# output file name do not depend on `year`, so set them up once rather than on
# every pass through the year loop.
savedir = '/data/local/hadax/PhD/Attribution/' + model + '/' + experiment + '/' + freq + '/' + gridname + '/output/'
if not os.path.exists(savedir):
    os.makedirs(savedir)
obs_label = '-'.join([obs_names[key] for key in obs_names.keys()
                      if key in obs_component_variable_names])

obs_cubes = {cvname: iris.load_cube(obs_files[cvname])
             for cvname in obs_component_variable_names}
if freq == 'monthly':
    # Monthly observations are passed through cube_magic.sc_translate first
    # (see that module for exactly what it reorders).
    obs_cubes_ordered = {cvname: cm.sc_translate(obs_cubes[cvname])
                         for cvname in obs_component_variable_names}
else:
    # Nothing below modifies these cubes, so no defensive deep copy is needed.
    obs_cubes_ordered = obs_cubes

for year in range(startyear, endyear + 1):
    model_files = {cvname: '/data/local/hadax/PhD/Attribution/' + model + '/' + experiment + '/' + freq + '/' + gridname + '/' +
                   cvname + '/' + str(year) + '.nc' for cvname in component_variable_names}
    savefile = savedir + process + '.' + flux_name + '.' + obs_label + '.' + str(year) + '.nc'

    model_cubes = {cvname: iris.load_cube(model_files[cvname])
                   for cvname in component_variable_names}

    # Put each observed field onto a copy of the matching model cube so the
    # two carry the same coordinates/metadata and can be combined directly.
    obs_cubes_configured = {}
    for cvname in obs_component_variable_names:
        obs_cubes_configured[cvname] = copy.deepcopy(model_cubes[cvname])
        obs_cubes_configured[cvname].data = obs_cubes_ordered[cvname].data

    process_anomaly = model_cubes[process] - obs_cubes_configured[process]

    # Reference field for each component: the model field itself, or a
    # weighted blend of observations and model.
    reference_cubes = {}
    for cvname in component_variable_names:
        weight = obs_weights[cvname]
        if weight == 0.:
            reference = copy.deepcopy(model_cubes[cvname])
        else:
            reference = weight * obs_cubes_configured[cvname] + (1 - weight) * model_cubes[cvname]
            # Keep every point the model marks as missing masked in the blend.
            # masked_array ORs this mask with any mask already on the data
            # (keep_mask=True), and getmaskarray returns a full boolean array
            # even when the model mask is the scalar `nomask` (the old
            # `.mask[index] = True` crashed in that case).
            reference.data = ma.masked_array(
                reference.data,
                mask=ma.getmaskarray(model_cubes[cvname].data).copy())
        reference.units = '1'
        reference_cubes[cvname] = reference

    # A point is missing if the process anomaly or any reference field is
    # missing there. The anomaly's own mask is part of the union: the previous
    # overwrite could unmask its gaps. The mask is sized from the data, not a
    # fixed steps-per-year count, so files whose time length differs from
    # ntime_dic (e.g. leap years or a 360-day calendar) still broadcast.
    missing_data_mask = ma.getmaskarray(process_anomaly.data)
    for cvname in component_variable_names:
        missing_data_mask = np.logical_or(
            missing_data_mask, ma.getmaskarray(reference_cubes[cvname].data))

    for cube in [process_anomaly] + [reference_cubes[cvname] for cvname in component_variable_names]:
        # Give each cube its own copy so no two arrays share one mask object.
        cube.data = ma.masked_array(cube.data, mask=missing_data_mask.copy())

    # Evaluate PD_formula with each component variable name bound to its
    # reference cube. Binding names through eval's locals mapping replaces the
    # old textual substitution, which corrupted any variable name occurring
    # inside another name, an `ac.` attribute, or the substituted text.
    # NOTE: PD_formula is trusted, hand-written configuration; never pass
    # external input to eval.
    partial_derivative = eval(PD_formula, globals(), dict(reference_cubes))
    flux_contribution = process_anomaly * partial_derivative

    iris.save(flux_contribution, savefile)
