import iris
import field_functions as ff
import cube_magic as cm
import matplotlib.pyplot as plt
import plot_functions as pf
import numpy as np
import monty
import cartopy.crs as ccrs
import iris.coord_categorisation as icc
import copy
import numpy.ma as ma
import pylab
import datetime as dt
import cartopy.crs as ccrs
import regions
from matplotlib.font_manager import FontProperties

# Font for the large month labels drawn beside each figure row.
font_large = FontProperties(size=24)

# Calendar months (1 = Jan) plotted as the figure's rows, top to bottom.
months_plot = [2, 5, 7]

# Sample field on the 'hadgom' grid, plus its horizontal coordinates.
sf = ff.sample_field('hadgom')
lat_coord, lon_coord = sf.coord('latitude'), sf.coord('longitude')
ratio = iris.load_cube('/data/local/hadax/PhD/cat_ratio_cube.nc')

# Destination files for the finished figure (vector and raster copies).
outfile_eps = '/home/h01/hadax/graphics/PhD/Papers/Dec2013_HadGEM2_energy_budget/v3.1/Figures/FB1_evaluate_formulae.eps'
outfile_png = '/home/h01/hadax/graphics/PhD/Papers/Dec2013_HadGEM2_energy_budget/v3.1/Figures/FB1_evaluate_formulae.png'

# Which model run to analyse and over which years (both ends inclusive).
model, experiment = 'HadGEM2_ES', 'Hist1'
startyear, endyear = 1980, 1999
month = 1  # NOTE(review): rebound by the plotting loop; appears unused otherwise.
region_name = 'Arctic_Ocean.0'

years = range(startyear, endyear + 1)

# Surface-temperature threshold embedded in the name of the season-flag
# variable read from disk (e.g. 'l_Tsfc_cold.271.15').
threshold = 271.15
season_log_varname = 'l_Tsfc_cold.' + str(threshold)

# Physical parameters.  kice/ksno and the alpha_c/alpha_m/alpha_i albedos are
# not referenced elsewhere in this script; alpha_o is used as the open-water
# albedo and emis_ice/sb_const in the emitted-longwave term.
kice, ksno = 2.03, 0.33
Tfreeze = 273.15
sb_const = 5.67e-8  # Stefan-Boltzmann constant
emis_ice_old, emis_ice = 0.98, 1.00
alpha_c, alpha_m, alpha_i, alpha_o = 0.8, 0.65, 0.61, 0.06
Tbot = -1.8  # added to Tfreeze when the ice-base temperature term is built

# Longwave emission linearised about Tfreeze using emis_ice_old: minus the
# derivative d(e*s*T**4)/dT, and the reference emission e*s*T**4.  Neither is
# referenced elsewhere in this script.
B_lw_dep_const = -(emis_ice_old * sb_const * 4. * Tfreeze ** 3.)
C_lw_ref_const = emis_ice_old * sb_const * Tfreeze ** 4.

basedir = '/data/local/hadax/PhD/Attribution/'
datadir = basedir + model + '/' + experiment + '/' + 'monthly/hadgom/'

# Diagnostics read for every year; each one lives in datadir/<var>/<year>.nc.
# Each name is listed exactly once so no file is loaded twice per year.
vars_read = ['LW_down', 'SW_down', 'iimelt', 'LW_net', 'Total_surface_flux',
             'Sensible_heat_flux_over_ice', 'complex_scale_factor', 'Ice_thickness',
             'Snow_volume', 'TTsfc', 'B_lw_dep', 'C_lw_ref', 'Ice_area',
             season_log_varname, 'SW_net', 'Sensible_heat_flux_over_ocean',
             'f_atmos_ice', 'f_atmos_ocean', 'ice_albedo', 'Latent_heat_flux',
             'Sensible_heat_flux']

# One CubeList per variable, collecting one cube per year.  The derived
# 'Fsfc_synth' and 'anomaly' entries are added on the first pass of the loop.
cubelist_list = {var: iris.cube.CubeList([]) for var in vars_read}
for year in years:
    print(year)
    ffiles = {var: datadir + var + '/' + str(year) + '.nc' for var in vars_read}
    cubes = {var: iris.load_cube(ffiles[var]) for var in vars_read}

    # Re-wrap every variable's data in a copy of the LW_down cube so that all
    # operands share identical metadata and coordinates in the arithmetic
    # below.  Cube.copy(data=...) avoids deep-copying LW_down's own data only
    # to throw it away.
    cube_ref = cubes['LW_down']
    cubes_cfg = {var: cube_ref.copy(data=cubes[var].data) for var in vars_read}

    # (Tbot + Tfreeze) - TTsfc.  Expressions of the form `cube * -1. + c` stand
    # in for `c - cube` throughout -- presumably because reflected subtraction
    # was not supported by the iris version in use; verify before simplifying.
    Tbot_cube = cubes_cfg['TTsfc'] * -1. + Tbot + Tfreeze

    # Cold-season synthetic flux: the ice-covered fraction contributes
    # (f_atmos_ice + B_lw_dep * Tbot_cube) scaled by complex_scale_factor; the
    # open-water fraction contributes f_atmos_ocean.
    winter_Fsfc_synth = cubes_cfg['Ice_area'] * (cubes_cfg['f_atmos_ice'] + cubes_cfg['B_lw_dep'] * Tbot_cube) * \
        cubes_cfg['complex_scale_factor'] + cubes_cfg['f_atmos_ocean'] * (cubes_cfg['Ice_area'] * -1. + 1.)

    # Warm-season synthetic flux: downward LW + sensible + latent + absorbed SW
    # (ice_albedo over the ice fraction, alpha_o over open water) minus the
    # emitted LW emis_ice * sb_const * TTsfc**4.
    summer_Fsfc_synth = cubes_cfg['LW_down'] + cubes_cfg['Sensible_heat_flux'] + cubes_cfg['Latent_heat_flux'] + \
        cubes_cfg['SW_down'] * (cubes_cfg['Ice_area'] * (cubes_cfg['ice_albedo'] * -1. + 1.) +
                                (cubes_cfg['Ice_area'] * -1. + 1.) * (1 - alpha_o)) - \
        cubes_cfg['TTsfc'] ** 4. * emis_ice * sb_const

    # Blend the two formulae with the season-flag cube: weight flag on the
    # winter formula and (1 - flag) on the summer one.  The flag is presumably
    # a 0/1 cold-season indicator -- confirm against how it was generated.
    total_Fsfc_synth = winter_Fsfc_synth * cubes_cfg[season_log_varname] + \
        summer_Fsfc_synth * (cubes_cfg[season_log_varname] * -1. + 1.)

    cubes_cfg['Fsfc_synth'] = total_Fsfc_synth
    # Error of the synthetic flux relative to the model's total surface flux.
    cubes_cfg['anomaly'] = cubes_cfg['Fsfc_synth'] - cubes_cfg['Total_surface_flux']

    for var, cube in cubes_cfg.items():
        cubelist_list.setdefault(var, iris.cube.CubeList([])).append(cube)

# Seasonal cycle of every variable: mean and standard deviation over all years
# for each calendar month.
sc_mean_cubes = {}
sc_std_cubes = {}
for var, cube_list in cubelist_list.items():
    # The yearly cubes must join into a single cube along time.
    mcube = cube_list.concatenate()
    if len(mcube) > 1:
        raise ValueError('Unable to merge cubes for variable ' + var)

    # Adds a 'month' auxiliary coordinate derived from 'time' to group on.
    # NOTE(review): the plotting code indexes the result with [month - 1], which
    # assumes the aggregated groups come out ordered Jan..Dec (presumably true
    # because each year's data starts in January) -- confirm.
    icc.add_month(mcube[0], 'time')
    sc_mean_cubes[var] = mcube[0].aggregated_by('month', iris.analysis.MEAN)
    sc_std_cubes[var] = mcube[0].aggregated_by('month', iris.analysis.STD_DEV)

    # NOTE(review): only the mean cubes are reduced to the region, and the
    # return value is discarded, so reduce_to_region presumably modifies its
    # argument in place -- confirm.  The std-dev cubes stay unreduced.
    regions.reduce_to_region(sc_mean_cubes[var], region_name, 'hadgom')

# Figure layout, all in figure-fraction units:
#   xl: left edge of the first column, xs: column spacing, xw: panel width;
#   yb: bottom of the top-row map panels, ys: row spacing, yw: map height;
#   yb_cl / yw_cl: bottom and height of the top-row colour-bar axes.
xl = .1
xs = .3
xw = .24

yb = .72
yb_cl = .7
yw_cl = .015
ys = .33
yw = .25
# Passed straight to pf.polarplot as its rrange argument.
rrange = 3.e6

# Contour levels for the absolute-flux panels (one set per plotted month, in
# months_plot order) and for the synthetic-minus-model error panels.
levels_abs = [np.arange(-50, -5, 5), np.arange(-20, 25, 5), np.arange(60, 145, 5)]
levels_anom = np.arange(-40, 41, 5)
cmap_abs = plt.get_cmap('YlGnBu_r')
cmap_anom = plt.get_cmap('RdBu_r')
letters = [['(a)', '(b)', '(c)'], ['(d)', '(e)', '(f)'], ['(g)', '(h)', '(i)']]

nps = ccrs.NorthPolarStereo()
figure = plt.figure(figsize=(13, 13))

# Colour-bar captions; flux density is in W m^-2.
flux_label = 'Flux (W m$^{-2}$)'
error_label = 'Flux error (W m$^{-2}$)'

for im, month in enumerate(months_plot):
    month_name = dt.date(1980, month, 1).strftime('%b')

    # Columns of this row: synthetic flux, model flux, synthetic minus model.
    panels = [('Fsfc_synth', levels_abs[im], cmap_abs, flux_label),
              ('Total_surface_flux', levels_abs[im], cmap_abs, flux_label),
              ('anomaly', levels_anom, cmap_anom, error_label)]

    map_axes = [figure.add_axes([xl + xs * ic, yb - ys * im, xw, yw], projection=nps)
                for ic in range(len(panels))]

    # Each panel keeps its own contour set so its colour bar is built from it.
    contour_sets = [pf.polarplot(sc_mean_cubes[pvar][month - 1, :, :], rrange=rrange,
                                 levels=plevels, cmap=pcmap, figure=figure, axis=pax,
                                 colorbar=False, show=False, title=' ')
                    for (pvar, plevels, pcmap, _plabel), pax in zip(panels, map_axes)]

    cbar_axes = [figure.add_axes([xl + xs * ic, yb_cl - ys * im, xw, yw_cl])
                 for ic in range(len(panels))]

    for pax, letter in zip(map_axes, letters[im]):
        pax.set_title(letter)

    for contours, cax, panel in zip(contour_sets, cbar_axes, panels):
        plt.colorbar(contours, cax=cax, orientation='horizontal')
        # transAxes pins the caption to the colour-bar box regardless of the
        # data limits matplotlib assigns to colour-bar axes.
        cax.text(.5, -2., panel[3], ha='center', transform=cax.transAxes)

    # Row label, placed to the left of the first column.
    cbar_axes[0].text(-.08, 8., month_name, ha='right', fontproperties=font_large,
                      transform=cbar_axes[0].transAxes)

figure.savefig(outfile_eps)
figure.savefig(outfile_png)
