import sys
sys.path.insert(0,'/home/h01/hadax/Python/PhD/Obs/')
sys.path.insert(0,'/home/h01/hadax/Python/PhD/Heat_budget/v5_May_2016/')
import timemod
import pylab
import sheba
import regions
import datetime as dt
import iris
import iris.coord_categorisation
from arrays import wrap
import budgets
import os
import numpy.ma as ma
import numpy as np
import matplotlib.pyplot as plt
import matplotlib
from matplotlib.font_manager import FontProperties
import copy
import cartopy.crs as ccrs
import field_functions as ff
import cartopy
import monty
import plot_functions as pf

# Font used for titles, axis labels and tick labels throughout the figure.
font_medium = FontProperties(size=14)

# Output files: the same figure saved as EPS and PNG side by side.
outfile_eps = ('/home/h01/hadax/graphics/PhD/Papers/'
               'Dec2013_HadGEM2_energy_budget/'
               'v2.2_figs/F4_sfcrad_evaluation.eps')
outfile_png = ('/home/h01/hadax/graphics/PhD/Papers/'
               'Dec2013_HadGEM2_energy_budget/'
               'v2.2_figs/F4_sfcrad_evaluation.png')

# Ensemble-mean model cubes keyed by flux name (filled by the loading loop).
mean_model_cubedic = dict()
# Further result containers; NOTE(review): not used in the part of the file
# visible here -- confirm whether they are still needed.
(mean_model_cubelist_sheba,
 std_model_cubelist_ao,
 std_model_cubelist_sheba,
 obs_sheba_cubelist,
 sheba_anomaly_cubelist) = (iris.cube.CubeList([]) for _ in range(5))

# Region code and model averaging period, both embedded in input filenames.
rcode, model_tcode = 'Arctic_Ocean.1', '1980.1999'

# Dataset shown in the inset polar anomaly maps, and the calendar month(s)
# (1 = January) at which an inset is drawn on each flux panel.
map_dataset = 'CERES'
map_months = dict(SW_down=[5], SW_up=[6], SW_net=[6],
                  LW_down=[2], LW_up=[], LW_net=[2])

# Surface radiative fluxes plotted, in panel order (a)-(f).
vars_plot_anom = ['SW_down', 'SW_up', 'SW_net', 'LW_down', 'LW_up', 'LW_net']

vars_plot_anom_over_ice = [v+'_over_ice' for v in vars_plot_anom]
anom_place = len(vars_plot_anom) - 1

# Human-readable labels for each flux name.
label_dictionary = {
    'SW_down': 'SW down',
    'SW_up': 'SW up',
    'SW_net': 'Net SW',
    'LW_down': 'LW down',
    'LW_up': 'LW up',
    'LW_net': 'Net LW',
    'Sensible_heat_flux': 'Sensible\nheat flux',
    'Latent_heat_flux': 'Latent\nheat flux',
    'Total_snowfall': 'Heat flux due\nto snowfall',
}
# The "_over_ice" variants reuse the plain labels (there is no snowfall one).
label_dictionary.update(
    (name + '_over_ice', label_dictionary[name])
    for name in ('SW_down', 'SW_up', 'SW_net', 'LW_down', 'LW_up', 'LW_net',
                 'Sensible_heat_flux', 'Latent_heat_flux'))

# Observational datasets: dash pattern and averaging period, plus empty
# per-flux dictionaries for the data and the model-minus-obs anomaly.
obs_datasets = dict(
    (label, {'data': {}, 'anomaly': {}, 'linestyle': dashes, 'tcode': period})
    for label, dashes, period in (
        ('ISCCP', [8, 3], '1980.1999'),
        ('CERES', [6, 3, 2, 3], '2000.2013'),
        ('ERAI', [2, 3], '1980.1999'),
    ))

obs_colours = dict(ISCCP='#0000ff', CERES='#a04000', ERAI='#808080',
                   AGR='#ff00ff')

model_linestyle, sheba_linestyle = '-', '--'

model_attributes = dict(name='HadGEM2_ES', startyear=1980, endyear=1999,
                        rcode='SHEBA.0', linestyle='-')
obs_attributes = dict(name='SHEBA', startyear='Nov1997', endyear='Sep1998',
                      rcode='SHEBA.actual_track', linestyle='.')

for fluxname in vars_plot_anom:
    # --- Model: Arctic Ocean seasonal cycle, averaged over ensemble members.
    model_cubelist_ao = iris.cube.CubeList([])
    for ensemble_member in range(1, 5):
        # Each ensemble member is read from its own HistN directory.  The
        # path used to hard-code 'Hist1', so the same file was loaded four
        # times: the "ensemble mean" was member 1 and the std was zero.
        model_file_ao = '/data/cr1/hadax/HadGEM2_ES/Hist' + \
            str(ensemble_member) + '/SFC/monthly/' + \
            fluxname + '.' + model_tcode + '.' + rcode + '.nc'
        model_cube_ao = iris.load_cube(model_file_ao)
        model_cube_ao = timemod.sc_1D(model_cube_ao)
        # Scalar 'realization' coordinate so merge() stacks the members
        # along a new realization dimension.
        ensemble_coord = iris.coords.AuxCoord(ensemble_member, 'realization')
        model_cube_ao.add_aux_coord(ensemble_coord)
        model_cubelist_ao.append(model_cube_ao)

    model_cubelist_ao_merge = model_cubelist_ao.merge()
    if len(model_cubelist_ao_merge) != 1:
        raise ValueError('Unable to merge model ao cubelist for ' + fluxname)

    model_cube_ao_all = model_cubelist_ao_merge[0]
    model_cube_ao_mean = model_cube_ao_all.collapsed(['realization'],
                                                     iris.analysis.MEAN)
    # NOTE(review): the spread is computed but not used in the visible code.
    model_cube_ao_std = model_cube_ao_all.collapsed(['realization'],
                                                    iris.analysis.STD_DEV)
    mean_model_cubedic[fluxname] = model_cube_ao_mean

    # --- Observations: same region, each dataset over its own period.
    for obs_label, obs_info in obs_datasets.items():
        obs_file = '/data/cr1/hadax/' + obs_label + '/SFC/' + fluxname + \
            '.' + obs_info['tcode'] + '.' + rcode + '.nc'
        if os.path.exists(obs_file):
            obs_cube = iris.load_cube(obs_file)
            if obs_info['tcode'] != 'CLIMO':
                obs_cube = timemod.sc_1D(obs_cube)
        else:
            # No file for this flux: use a fully masked 12-month placeholder
            # so the dataset shows up as a gap rather than aborting the run.
            obs_cube = iris.cube.Cube(ma.masked_array(np.zeros(12), mask=True),
                                      long_name=None)
            obs_cube.units = 'W/m^2'

        # Put the obs values on a copy of the model cube so the subtraction
        # uses identical coordinates (presumably to avoid iris coordinate
        # mismatch errors between model and obs time axes).
        obs_cube_for_anom = copy.deepcopy(model_cube_ao_mean)
        obs_cube_for_anom.data = obs_cube.data
        anom_cube = model_cube_ao_mean - obs_cube_for_anom
        anom_cube.long_name = model_cube_ao_mean.long_name
        obs_info['data'][fluxname] = obs_cube
        obs_info['anomaly'][fluxname] = anom_cube

# Gridded multi-annual monthly fields for the inset maps, keyed by flux name:
# model fields from HadGEM2-ES Hist1, observations from the map_dataset n96
# directory over that dataset's own averaging period.
mod_sc_cubes = {}
obs_sc_cubes = {}
for var in vars_plot_anom:
    mod_file = ('/data/local/hadax/PhD/multiannual_fields/HadGEM2_ES/Hist1/'
                'SFC/monthly/%s.%s.nc' % (var, model_tcode))
    obs_file = '/data/local/hadax/PhD/multiannual_fields/%s/n96/%s.%s.nc' % (
        map_dataset, var, obs_datasets[map_dataset]['tcode'])
    mod_cube = iris.load_cube(mod_file)
    obs_cube = iris.load_cube(obs_file)
    mod_sc_cubes[var], obs_sc_cubes[var] = mod_cube, obs_cube

# Panel geometry in figure-fraction units: left edge and top row position,
# panel width, heights of the absolute and anomaly axes, column/row pitch,
# and the gap between each absolute axis and the anomaly axis beneath it.
xl, yt = .1, .71
xw, yw_abs, yw_anom = .26, .26, .115
xgap, ygap, aa_gap = .315, .46, 0.02

# Per-panel inset offsets, as fractions of panel width / absolute-axis height.
xoffsets = [0.07, -.04, 0.04, 0.08, 0., 0.08]
yoffsets = [.48, 1.03, .48, 1.03, 0., .48]

# Absolute-flux y limits and tick positions, one entry per panel.
ylims = [[-50, 280], [-200, 65], [-50, 140],
         [160, 340], [-340, -200], [-85, -15]]
yticks = list(np.arange(start, stop, step)
              for (start, stop, step) in ((-50, 280, 50), (-200, 65, 50),
                                          (-50, 140, 50), (150, 340, 50),
                                          (-340, -200, 20), (-80, -15, 20)))

# Size of each polar inset axis.
xw_c, yw_c = .09, .12

# Projection shared by every polar inset map.
nps = ccrs.NorthPolarStereo()

# 14 half-integer x positions, -0.5 .. 12.5, for the wrapped monthly series
# (presumably wrap() pads the 12 months with one value at each end -- confirm).
xx_plot = np.arange(-.5, 13)

figure = plt.figure(figsize=(12, 10))
letters = list('(%s)' % c for c in 'abcdef')

# Contour levels (W/m^2) and colour table for the inset anomaly maps.
levels = np.array([-100, -50, -20, -10, -5,
                   0,
                   5, 10, 20, 50, 100])
cmap = monty.clr_cmap('/home/h01/hadax/IDL/colour_tables/anomalies.clr')

for (ivar, var) in enumerate(vars_plot_anom):
    # Panels fill a 3-column grid row by row.  Floor division keeps the row
    # index an int under Python 3 as well as Python 2.
    xi = ivar % 3
    yi = ivar // 3

    # Upper axes: absolute monthly fluxes.  Lower axes: model-minus-obs
    # anomalies, placed aa_gap below the upper axes.
    axis_abs = figure.add_axes([xl + xi*xgap, yt - yi*ygap, xw, yw_abs])
    axis_anom = figure.add_axes([xl + xi*xgap, yt - yi*ygap - yw_anom - aa_gap,
                                 xw, yw_anom])

    axis_abs.plot(xx_plot, wrap(mean_model_cubedic[var].data),
                  label='HadGEM2-ES', color='k')

    for obs_label in obs_datasets.keys():
        axis_abs.plot(xx_plot, wrap(obs_datasets[obs_label]['data'][var].data),
                      label=obs_label, color=obs_colours[obs_label])
        axis_anom.plot(xx_plot,
                       wrap(obs_datasets[obs_label]['anomaly'][var].data),
                       label=obs_label, color=obs_colours[obs_label])

    # A single legend for the whole figure, attached to the first panel and
    # shifted by a hand-tuned bbox offset.
    if ivar == 0:
        axis_abs.legend(bbox_to_anchor=[0.45, -3.4, 1, 1], ncol=2)

    axis_abs.set_title(letters[ivar], fontproperties=font_medium)
    axis_abs.set_xlim(0, 12)
    axis_anom.set_xlim(0, 12)
    axis_abs.set_xticks(np.arange(0, 13))

    # Units are W/m^2 (the previous "W/m^{-2}" was dimensionally wrong and
    # disagreed with the colour-bar label).
    if xi == 0:
        axis_abs.set_ylabel('Flux ($W/m^2$)', fontproperties=font_medium)
        axis_anom.set_ylabel('Flux anomaly\n($W/m^2$)',
                             fontproperties=font_medium)

    axis_abs.plot([0, 12], [0, 0], linestyle='--', color='k')
    axis_abs.set_ylim(ylims[ivar])
    axis_anom.plot([0, 12], [0, 0], linestyle='--', color='k')
    axis_anom.set_xticks(np.arange(0, 13))
    axis_anom.set_ylim([-50, 50])
    axis_anom.set_yticks(np.arange(-50, 51, 25))
    # Numeric month ticks are hidden; single-letter month names are drawn
    # below the anomaly axes instead.
    for xlabel in axis_abs.get_xticklabels() + axis_anom.get_xticklabels():
        xlabel.set_visible(False)

    # Remember the limits so they can be restored after the shading and
    # insets below are added.
    ylim_abs = axis_abs.get_ylim()
    ylim_anom = axis_anom.get_ylim()
    for month in range(1, 13):
        month_name = dt.date(1980, month, 1).strftime('%b')[0]
        axis_anom.text(month - .5,
                       ylim_anom[0] - .18 * (ylim_anom[1] - ylim_anom[0]),
                       month_name, ha='center')

    for month in map_months[var]:
        # Select this month from the gridded model and obs climatologies.
        month_name = dt.date(1980, month, 1).strftime('%b')
        obs_index = np.where(obs_sc_cubes[var].coord('month').points ==
                             month_name)
        mod_index = np.where(mod_sc_cubes[var].coord('month').points ==
                             month_name)
        obs_cube = obs_sc_cubes[var][obs_index[0], :, :]
        mod_cube = mod_sc_cubes[var][mod_index[0], :, :]
        # Obs values on a copy of the model cube so the difference uses the
        # model's coordinates.
        obs_cube_cfg = copy.deepcopy(mod_cube)
        obs_cube_cfg.data = obs_cube.data
        anom_cube = mod_cube - obs_cube_cfg
        regions.reduce_to_region(anom_cube, 'Arctic_Ocean.1', 'n96')

        # Shade the mapped month on both line plots.
        axis_abs.fill_between([month - 1, month], [axis_abs.get_ylim()[0]]*2,
                              [axis_abs.get_ylim()[1]]*2, color='#f4e8d0')
        axis_anom.fill_between([month - 1, month], [axis_anom.get_ylim()[0]]*2,
                               [axis_anom.get_ylim()[1]]*2, color='#f4e8d0')

        # Inset map centred horizontally on the month, shifted by the
        # per-panel offsets.
        xl_c = xl + xi*xgap + (month - .5)/12. * xw - xw_c * .5 + \
            xw*xoffsets[ivar]
        yb_c = yt - yi*ygap - yw_abs + yw_abs*yoffsets[ivar] + aa_gap + yw_anom
        axis_c = figure.add_axes([xl_c, yb_c, xw_c, yw_c], projection=nps)
        clev = pf.polarplot(anom_cube[0, :, :], figure=figure, axis=axis_c,
                            show=False, colorbar=False, levels=levels,
                            cmap=cmap, rrange=3.5e6, title=' ')

    # set_yticks expands the view limits to include every requested tick
    # (e.g. the 150 tick for LW_down, whose limits start at 160), so set the
    # ticks first and then restore the saved limits.
    axis_abs.set_yticks(yticks[ivar])
    axis_abs.set_ylim(ylim_abs)
    axis_anom.set_ylim(ylim_anom)

    for label in axis_abs.get_yticklabels() + axis_anom.get_yticklabels():
        label.set_fontproperties(font_medium)

# Shared horizontal colour bar for the inset maps.  ``clev`` is whatever the
# last pf.polarplot call in the panel loop returned (presumably a contour
# set -- confirm); every inset uses the same ``levels`` and ``cmap``.
ax_cl = figure.add_axes([.54, .05, .42, .02])
cb = plt.colorbar(clev, cax=ax_cl, orientation='horizontal')
# Set the ticks before styling the labels: changing ticks afterwards can
# regenerate the label objects and drop the custom font.
cb.set_ticks(levels[1:-1])
for label in ax_cl.get_xticklabels():
    label.set_fontproperties(font_medium)

# Name the map dataset from ``map_dataset`` instead of hard-coding 'CERES',
# so the label stays correct if the map dataset is changed.
ax_cl.text(.5, -2.2, 'HadGEM2-ES - ' + map_dataset + ' flux anomaly ($W/m^2$)',
           ha='center', fontproperties=font_medium)

pylab.savefig(outfile_eps)
pylab.savefig(outfile_png)
