'''Estimates the effect of covariance of variables during the summer, as described in Appendix A'''

import datetime as dt
import iris
import numpy as np
import numpy.ma as ma
import field_functions as ff
import plot_functions as pf
import matplotlib.pyplot as plt
import numpy as np
from arrays import wrap
import copy

# --- Configuration -------------------------------------------------------
# Variables whose mean and trend fields are combined further down.
component_variables = ['SW_down', 'Ice_area', 'ice_albedo']
model = 'HadGEM2_ES'
experiment = 'Hist1'
gridname = 'hadgom'
# Year whose monthly fields are analysed; the year either side is also read.
year_examine = 1981
seconds_in_month = 8.64e4 * 30.   # 86400 s per day * 30 days
# Fixed albedo used in the 'ocn cpt' term, presumably open water.
alpha_o = 0.06

# Template field for the chosen grid; it supplies latitude/longitude coords.
sf = ff.sample_field(gridname)

# Reference epoch, also formatted as a 'days since YYYY-MM-DD' units string.
ref_date = dt.date(1978, 9, 1)
ref_units = ref_date.strftime('days since %Y-%m-%d')

# Input layout: <basedir><model>/<experiment>/monthly/<gridname>/<var>/<year>.nc
# Each path component is joined by exactly one '/'.
basedir = '/data/local/hadax/PhD/Attribution/'
datadir = basedir + model + '/' + experiment + '/monthly/' + gridname + '/'

# For each component variable, keep:
#   mean_cubedic[var]  - the cube loaded for year_examine (cubes[1] below)
#   trend_cubedic[var] - a (12, ny, nx) cube of centred month-to-month
#                        differences for each calendar month of year_examine
mean_cubedic = {}
trend_cubedic = {}

# Horizontal coordinates for the trend cubes, looked up once. Each cube gets
# its own copy so no two cubes (or sf) share a mutable coord object.
lat_coord = sf.coord('latitude')
lon_coord = sf.coord('longitude')

for cvvar in component_variables:
    # Read the previous, current and following year. December of the previous
    # year and January of the next year are the outer neighbours needed for
    # the centred difference at January and December of year_examine.
    filenames = [datadir + cvvar + '/' + str(year) + '.nc'
                 for year in range(year_examine - 1, year_examine + 2)]
    cubes = iris.cube.CubeList([iris.load_cube(ffile) for ffile in filenames])
    for ffile, cube in zip(filenames, cubes):
        # The month arithmetic below assumes exactly 12 steps per yearly file.
        if cube.shape[0] != 12:
            raise ValueError('expected 12 monthly fields in %s, found %d'
                             % (ffile, cube.shape[0]))

    # Stack the 36 months into one float64 masked array. Every input mask is
    # carried through; unmasked inputs contribute an all-False mask.
    full_data = ma.masked_array(
        np.concatenate([np.asarray(ma.getdata(cube.data), dtype=float)
                        for cube in cubes]),
        mask=np.concatenate([ma.getmaskarray(cube.data) for cube in cubes]))

    # Month m (0-based) of year_examine sits at stacked index 12 + m, so its
    # neighbours are at 11 + m and 13 + m. Trend = (x[m+1] - x[m-1]) / 2.
    # Masked-array arithmetic masks a point if EITHER neighbour is masked.
    # (The old per-month loop only checked the earlier neighbour and could
    # drop the later neighbour's mask.)
    trend = (full_data[13:25] - full_data[11:23]) / 2.
    trend_mdata = ma.masked_array(ma.getdata(trend),
                                  mask=ma.getmaskarray(trend))

    # No time coordinate is attached. Dim 0 is the calendar-month index
    # (0 = first month of each yearly file, presumably January).
    trend_cube = iris.cube.Cube(trend_mdata)
    trend_cube.add_dim_coord(lat_coord.copy(), 1)
    trend_cube.add_dim_coord(lon_coord.copy(), 2)

    trend_cubedic[cvvar] = trend_cube
    mean_cubedic[cvvar] = cubes[1]

cmap = plt.get_cmap('RdBu_r')
levels = np.arange(-30, 31, 5)

# Each cross term multiplies the trends of two variables by the mean of the
# third (or by a fixed factor for the open-water term). The plot labels name
# them Cov(...).
sw_mean = mean_cubedic['SW_down']
area_mean = mean_cubedic['Ice_area']
albedo_mean = mean_cubedic['ice_albedo']
d_sw = trend_cubedic['SW_down']
d_area = trend_cubedic['Ice_area']
d_albedo = trend_cubedic['ice_albedo']

xt1 = sw_mean * d_area * d_albedo * -1.
xt2 = area_mean * d_sw * d_albedo * -1.
xt3 = (albedo_mean * -1. + 1.) * d_sw * d_area
xt4 = d_sw * d_area * (alpha_o - 1.)

# The total reuses xt1's metadata. Its data is summed left to right:
# ((xt1 + xt2) + xt3) + xt4.
total = copy.deepcopy(xt1)
total.data = sum((term.data for term in (xt2, xt3, xt4)), xt1.data)

# Plot each term's series at one grid point (row 210, column 100). Every
# series goes through wrap() and is drawn at x = -0.5, 0.5, ..., 12.5.
xx_plot = np.arange(-.5, 13, 1)
iy, ix = 210, 100

# (cube, legend label, extra plot kwargs), drawn in this order.
series = [
    (xt1, 'Cov(ice area, ice albedo)', {}),
    (xt2, 'Cov(SW down, ice albedo)', {}),
    (xt3, 'Cov(SW down, ice area), ice cpt', {}),
    (xt4, 'Cov(SW_down, ice area), ocn cpt', {}),
    (total, 'TOTAL', {'linewidth': 2, 'color': 'k'}),
]
for term, label, style in series:
    plt.plot(xx_plot, wrap(term[:, iy, ix].data), label=label, **style)

plt.gca().legend()
plt.show()
