'''Calculates mixed partial derivatives as described in Appendix A'''

import iris
import copy
import plot_functions as pf
import matplotlib.pyplot as plt
import numpy as np
import field_functions as ff
import regions
import cube_magic as cm

# ---------------------------------------------------------------------------
# Run configuration
# ---------------------------------------------------------------------------
cmap = plt.get_cmap('RdBu_r')
model = 'HadGEM2_ES'
experiment = 'Hist1'
year = 1981
# Rebound by the `for category in range(1, 6)` loop further down before it
# is ever read; kept so the name exists from the start of the script.
category = 1

# Coefficients used by the SW (xa*_s) terms below.
# NOTE(review): presumably surface albedos (cold ice / melting ice / open
# ocean) -- confirm against Appendix A.
alpha_ci = 0.61
alpha_i = 0.535
alpha_o = 0.06


# Sample field supplying the latitude/longitude coordinates that are
# attached to the template cube below.
sf = ff.sample_field('hadgom')

# Variables read from the model output.  The comprehension variable is
# deliberately NOT called `category`, so it cannot clobber the module-level
# name (list-comprehension variables leak into the enclosing scope under
# Python 2).
varnames_read = (
    ['LW_down', 'Ice_area', 'SW_down', 'iimelt']
    + ['Ice_thickness.cat.' + str(icat) for icat in range(1, 6)]
    + ['complex_scale_factor.cat.' + str(icat) for icat in range(1, 6)]
    + ['Ice_area.cat.' + str(icat) for icat in range(1, 6)]
)

# Reference dataset directory name for each variable that has one.
# Variables absent from this mapping are read from the model only.
datasets = {
    'LW_down': 'CERES',
    'SW_down': 'CERES',
    'iimelt': 'melt_onset',
    'Ice_thickness.cat.1': 'PIOMAS',
    'Ice_thickness.cat.2': 'PIOMAS',
    'Ice_thickness.cat.3': 'PIOMAS',
    'Ice_thickness.cat.4': 'PIOMAS',
    'Ice_thickness.cat.5': 'PIOMAS',
    'Ice_area': 'NSIDCC',
}

basedir = '/data/local/hadax/PhD/Attribution/'
sdatadir_m = basedir + model + '/' + experiment + '/' + \
    'monthly/hadgom/'

# Cubes keyed by variable name: reference data (o), model data (m) and
# their difference, model minus reference (a).
cubedic_o = {}
cubedic_m = {}
cubedic_a = {}

# Template cube whose metadata every loaded field is copied onto.  It is
# read from the configured `year` (previously hard-coded to 1981.nc) so the
# template stays in step with the per-variable files when `year` changes.
template_cube = iris.load_cube(sdatadir_m + 'SW_down/' + str(year) + '.nc')
template_cube.add_dim_coord(sf.coord('latitude'), 1)
template_cube.add_dim_coord(sf.coord('longitude'), 2)

def _load_on_template(path):
    """Load the cube stored at *path* and graft its data onto the template.

    Returns a ``(raw, templated)`` pair: the cube exactly as loaded, and a
    deep copy of ``template_cube`` whose data has been replaced by the
    loaded cube's data.
    """
    raw = iris.load_cube(path)
    templated = copy.deepcopy(template_cube)
    templated.data = raw.data
    return raw, templated


for varname in varnames_read:
    # Model field for the configured year.
    datadir_m = '%s%s/' % (sdatadir_m, varname)
    file_m = '%s%s.nc' % (datadir_m, year)
    cube_m, cube_m_cfg = _load_on_template(file_m)
    cubedic_m[varname] = cube_m_cfg

    # Only variables listed in `datasets` have a reference product.
    if varname not in datasets:
        continue

    dataset = datasets[varname]
    datadir_o = '%s%s/monthly/hadgom/' % (basedir, dataset)
    file_o = '%s%s.nc' % (datadir_o, varname)
    cube_o, cube_o_cfg = _load_on_template(file_o)
    cubedic_o[varname] = cube_o_cfg

    # Difference: model minus reference.
    cubedic_a[varname] = cube_m_cfg - cube_o_cfg

# Running sums of the area-mean series: overall, per thickness category
# and per term (xa1/xa2/xa3).
# NOTE(review): the leading length of 12 is presumably months of the
# year -- confirm against the input files.
total_all = np.zeros(12)
total_cat = np.zeros((12, 5))
total_process = np.zeros((12, 3))

# Fields shared by every category.
lw_m = cubedic_m['LW_down']
lw_a = cubedic_a['LW_down']
area_m = cubedic_m['Ice_area']
area_a = cubedic_a['Ice_area']

cubess = []
for category in range(1, 6):
    cat_tag = '.cat.' + str(category)
    area_cat_m = cubedic_m['Ice_area' + cat_tag]
    thick_cat_a = cubedic_a['Ice_thickness' + cat_tag]
    csf_cat_m = cubedic_m['complex_scale_factor' + cat_tag]

    # The three mixed terms for this category.  The operand order is kept
    # exactly as written so that floating-point results and cube metadata
    # are unchanged.
    xa1 = (area_cat_m * lw_a * thick_cat_a
           * 4.5 / 2.03 * csf_cat_m ** 2.)
    xa2 = (lw_m * area_a * area_cat_m / area_m * thick_cat_a
           * 4.5 / 2.03 * csf_cat_m ** 2.)
    xa3 = lw_a * area_a * area_cat_m / area_m * csf_cat_m

    cat_column = category - 1
    cubelist = iris.cube.CubeList([])
    for (iproc, cube) in enumerate((xa1, xa2, xa3)):
        # The return value is unused, so reduce_to_region presumably
        # modifies the cube in place -- TODO confirm.
        regions.reduce_to_region(cube, 'Arctic_Ocean.0', 'hadgom')
        # presumably adds the lat/lon bounds that area_weights needs
        cm.llbounds(cube)
        aw = iris.analysis.cartography.area_weights(cube)
        ts_cube = cube.collapsed(
            ['latitude', 'longitude'], iris.analysis.MEAN, weights=aw)
        cubelist.append(ts_cube)

        # Plain `x = x + y` is kept on purpose: an in-place `+=` could
        # treat masked data differently.
        total_cat[:, cat_column] = total_cat[:, cat_column] + ts_cube.data
        total_all = total_all + ts_cube.data
        total_process[:, iproc] = total_process[:, iproc] + ts_cube.data

    cubess.append(cubelist)
    

# Mixed terms built from SW_down, Ice_area and iimelt, followed by their
# area-weighted means over the 'Arctic_Ocean.0' region.
sw_m = cubedic_m['SW_down']
sw_a = cubedic_a['SW_down']
ice_m = cubedic_m['Ice_area']
ice_a = cubedic_a['Ice_area']
melt_m = cubedic_m['iimelt']
melt_a = cubedic_a['iimelt']

# NOTE(review): xa1_s uses (alpha_i - alpha_ci) as the iimelt coefficient
# while xa2_s and xa3_s use (alpha_ci - alpha_i).  If all three derive from
# the same flux expression these signs look inconsistent -- confirm against
# Appendix A.  Behaviour is left exactly as it was.
sw_area_coeff = (alpha_i - alpha_ci) * melt_m + alpha_o - alpha_i
melt_coeff = alpha_ci - alpha_i

xa1_s = sw_a * ice_a * sw_area_coeff
xa2_s = sw_a * ice_m * melt_a * melt_coeff
xa3_s = sw_m * ice_a * melt_a * melt_coeff

summer_cubes = iris.cube.CubeList([])
for cube in (xa1_s, xa2_s, xa3_s):
    # The return value is unused, so reduce_to_region presumably modifies
    # the cube in place -- TODO confirm.
    regions.reduce_to_region(cube, 'Arctic_Ocean.0', 'hadgom')
    # presumably adds the lat/lon bounds that area_weights needs
    cm.llbounds(cube)
    aw = iris.analysis.cartography.area_weights(cube)
    ts_cube = cube.collapsed(
        ['latitude', 'longitude'], iris.analysis.MEAN, weights=aw)
    summer_cubes.append(ts_cube)

    
