import iris
import numpy as np
import plot_functions as pf
import matplotlib.pyplot as plt
import datetime as dt
from arrays import wrap
import pylab
import cartopy.crs as ccrs
import regions
import cube_magic as cm
import iris.coord_categorisation as icc
import copy
from matplotlib import font_manager
import sys
sys.path.insert(0,'/home/h01/hadax/Python/PhD/ISF/')
import get_ma_cube

# Text sizes: font_large is used for axis/tick labels and annotations,
# font_medium for the legend entries.
font_large, font_medium = (font_manager.FontProperties(size=16),
                           font_manager.FontProperties(size=13))

# Output locations for the finished figure: one stem, two formats.
_figure_stem = ('/home/h01/hadax/graphics/PhD/Papers/'
                'Dec2013_HadGEM2_energy_budget/'
                'v2.3_figs/Figures/F6_plot_isf.v2')
outfile_eps = _figure_stem + '.eps'
outfile_png = _figure_stem + '.png'

# Model ensemble, analysis period, data version and region of the inputs.
model = 'HadGEM2_ES'
experiments = ['Hist%d' % member for member in range(1, 5)]
startyear, endyear = 1980, 1999
version = 'v4_06_09_2018'
region_name = 'Arctic_Ocean.0'
obs_names = ['CERES', 'ISCCP', 'ERAI']

# Processes read from disk: four single fields plus five ice-thickness
# categories (summed further down into 'Ice_thickness'), and the processes
# that are actually plotted as stacked bars.
_thickness_categories = ['Ice_thickness.complex.cat.%d' % n
                         for n in range(1, 6)]
processes_read = (['SW_down', 'LW_down', 'iimelt', 'Ice_area'] +
                  _thickness_categories)
processes_plot = ['SW_down', 'LW_down', 'iimelt', 'Ice_area', 'Ice_thickness']

# Legend names and dash patterns for each observational dataset.
obs_titles = {'CERES': 'CERES', 'ERAI': 'ERAI', 'ISCCP': 'ISCCP-FD'}
linestyles = {'ISCCP': (4, 4), 'ERAI': (6, 2), 'CERES': (5, 3, 1, 3)}

# Filled by the loading loop that follows: obs_name -> {process: field}.
all_cubes_contour = {}
# Process -> calendar month (1-12) shown in that process's polar inset map.
processes_contour = {'LW_down': 2, 'iimelt': 6, 'Ice_area': 8,
                     'Ice_thickness': 11}

# For every observational scheme, fetch each contour process's field via
# get_ma_cube.get_cube and keep the slice for that process's month.
# NOTE(review): assumes get_cube returns a (month, y, x) seasonal-cycle cube
# so that index month-1 selects the calendar month -- TODO confirm.
for obs_name in obs_names:
    obs_scheme = 'std_' + obs_name
    cubes_contour = {}
    for process, month in processes_contour.items():
        request = dict(model=model, experiments=experiments,
                       startyear=startyear, endyear=endyear,
                       version=version, obs_scheme=obs_scheme,
                       process=process)
        seasonal_cycle = get_ma_cube.get_cube(request)
        cubes_contour[process] = seasonal_cycle[month - 1, :, :]
    all_cubes_contour[obs_name] = cubes_contour

# Y-axis limits.  `ylim` is the range applied to the bar chart and used to
# draw the monthly grid lines and month labels.  `ylims` holds per-dataset
# limits that are not referenced in the code visible here.
# (A former `obs_scheme = 'std_' + obs_name` line here was dead -- it relied
# on the loop variable leaked from the loading loop above and was always
# overwritten before use -- and has been removed.)
ylims = {'CERES': [-23, 29], 'ERAI': [-16, 33]}
ylim = [-45, 35]

# Bar fill colour and legend label for each plotted process, kept side by
# side so the two mappings cannot drift apart.
_process_styles = [
    ('SW_down', '#ffff80', 'Downwelling SW'),
    ('LW_down', '#c0c0ff', 'Downwelling LW'),
    ('iimelt', '#ffa0a0', 'Surface melt onset'),
    ('Ice_area', '#bfffa0', 'Ice fraction'),
    ('Ice_thickness', '#c0c0c0', 'Ice thickness'),
]
colors = {name: colour for name, colour, _ in _process_styles}
titles = {name: label for name, _, label in _process_styles}

# Area-weighted regional mean of the flux_anomaly field from the Hist1
# member; plotted later with the label
# 'Sea ice latent heat uptake bias relative to PIOMAS'.
compare_file = ('/data/local/hadax/PhD/Attribution/HadGEM2_ES/Hist1/monthly/'
                'hadgom/aux/flux_anomaly.nc')
compare_field = iris.load_cube(compare_file)
regions.reduce_to_region(compare_field, region_name, 'hadgom')
# Presumably adds lat/lon bounds, which area_weights needs -- TODO confirm.
cm.llbounds(compare_field)
ts_flux_anomaly = compare_field.collapsed(
    ['latitude', 'longitude'], iris.analysis.MEAN,
    weights=iris.analysis.cartography.area_weights(compare_field))

# Pre-computed net radiative anomaly series for the same region, one file
# per observational dataset.
_rad_dir = '/data/users/hadax/PhD/Papers/Dec2013_HadGEM2_energy_budget/v2.3/'
rad_files = {}
ts_rad = {}
for obs_name in obs_names:
    rad_files[obs_name] = (_rad_dir + 'radiative_anom.' + obs_name + '.' +
                           region_name + '.nc')
    ts_rad[obs_name] = iris.load_cube(rad_files[obs_name])

basedir = '/data/local/hadax/PhD/Attribution/'


def _monthly_dirs(tree):
    """Return each experiment's monthly/hadgom directory under output_<tree>."""
    prefix = (basedir + 'output_' + version + '/output_' + tree + '/' +
              model + '/')
    return [prefix + expt + '/monthly/hadgom/' for expt in experiments]


datadirs1 = _monthly_dirs('1D')
# NOTE(review): datadirs3 is not referenced in the code visible here.
datadirs3 = _monthly_dirs('3D')

# For every observational scheme and every process in processes_read:
# load the member files, merge them along a new 'realization' coordinate,
# average over members, then average within each calendar month.
# Result: cubedicdic[obs_name][process] -> seasonal-cycle cube.
cubedicdic = {}
for obs_name in obs_names:
    obs_scheme = 'std_' + obs_name
    cubedic = {}
    for process in processes_read:
        # Only the directory differs between members, so the file name is
        # built once per process.
        filename = process + '.' + obs_scheme + '.' + \
            str(endyear) + '.' + region_name + '.nc'
        ensemble_cubelist = iris.cube.CubeList([])
        for (iexpt, datadir) in enumerate(datadirs1):
            ccube = iris.load_cube(datadir + filename)
            # Tag each member (1-based) so the cubes merge into one cube.
            rcoord = iris.coords.AuxCoord(iexpt + 1, 'realization')
            ccube.add_aux_coord(rcoord)
            ensemble_cubelist.append(ccube)

        ensemble_cubes = ensemble_cubelist.merge()
        # Exactly one cube is required: more means the members disagree in
        # metadata, none means no member files were found.
        if len(ensemble_cubes) != 1:
            raise ValueError('Could not merge ensemble cubelist for ' +
                             filename + ' (' + str(len(ensemble_cubes)) +
                             ' cubes after merge)')

        all_cube = ensemble_cubes[0]
        icc.add_month(all_cube, 'time')
        mean_cube = all_cube.collapsed(['realization'], iris.analysis.MEAN)
        sc_cube = mean_cube.aggregated_by(['month'], iris.analysis.MEAN)
        cubedic[process] = sc_cube

    # Total ice thickness: sum of categories 1-5, accumulated onto a zeroed
    # copy of category 1 (presumably to inherit its coordinates/metadata).
    hice_cube = copy.deepcopy(cubedic['Ice_thickness.complex.cat.1'])
    hice_cube.data[:] = 0.
    for cat in range(1, 6):
        hice_cube = hice_cube + cubedic['Ice_thickness.complex.cat.' + str(cat)]
    cubedic['Ice_thickness'] = hice_cube
    cubedicdic[obs_name] = cubedic

# Stacked-bar geometry for every observational scheme.  Each plotted
# process's seasonal cycle is split into a positive part (non-positive
# months zeroed) and a negative part (non-negative months zeroed).  Each
# part's bar starts at the running sum of the same-signed parts of the
# processes before it, so positive bars stack upward and negative bars
# downward.  total_dic holds the net sum over all plotted processes.
pda_dic = {}     # obs_name -> {process: positive part}
nda_dic = {}     # obs_name -> {process: negative part}
pbase_dic = {}   # obs_name -> {process: bottom of positive bar}
nbase_dic = {}   # obs_name -> {process: bottom of negative bar}
total_dic = {}   # obs_name -> net sum over processes_plot

for obs_name in obs_names:
    seasonal = cubedicdic[obs_name]
    pos_parts, neg_parts = {}, {}
    pos_bottoms, neg_bottoms = {}, {}
    # Running stack heights, 12 entries (presumably one per month).
    pos_stack, neg_stack, net = np.zeros(12), np.zeros(12), np.zeros(12)
    for process in processes_plot:
        values = seasonal[process].data
        pos_bottoms[process] = copy.copy(pos_stack)
        neg_bottoms[process] = copy.copy(neg_stack)
        pos_parts[process] = values * (values > 0.).astype('int')
        neg_parts[process] = values * (values < 0.).astype('int')
        pos_stack = pos_stack + pos_parts[process]
        neg_stack = neg_stack + neg_parts[process]
        net = net + values

    pda_dic[obs_name] = pos_parts
    nda_dic[obs_name] = neg_parts
    pbase_dic[obs_name] = pos_bottoms
    nbase_dic[obs_name] = neg_bottoms
    total_dic[obs_name] = net
    

# Main bar-chart axes rectangle in figure coordinates: left, bottom, width,
# height.
xl, yb, xw, yw = .1, .3, .82, .7

# x positions: xx_plot is the start of each of the 12 month slots;
# xx_plot14 holds 14 points from -0.3 for the curves drawn via wrap().
xx_plot = np.arange(0., 12, 1)
xx_plot14 = np.arange(-.3, 13, 1)

# Bar geometry in month units: component/total bar widths, their offsets
# inside a dataset's slot, and the shift between consecutive datasets.
width, width_t = .08, .06
offset_cpts, offset_total = .1, .2
xgap = .28

figure = plt.figure(figsize=(12, 12))
axis = figure.add_axes([xl, yb, xw, yw])

# For each observational dataset, inside its own x-slot of every month, draw:
#   - the stacked positive/negative process contributions,
#   - a thin black bar with the net total, and
#   - the dataset name, rotated, in the first and last month slots.
# Only the first dataset supplies legend labels.
for (iobs, obs_name) in enumerate(obs_names):
    first = (iobs == 0)
    slot = xgap * iobs
    x_cpts = xx_plot + offset_cpts + slot
    for process in processes_plot:
        legend_label = titles[process] if first else ''
        axis.bar(x_cpts, pda_dic[obs_name][process], width=width,
                 bottom=pbase_dic[obs_name][process],
                 color=colors[process], label=legend_label)
        axis.bar(x_cpts, nda_dic[obs_name][process], width=width,
                 bottom=nbase_dic[obs_name][process],
                 color=colors[process])

    axis.bar(xx_plot + offset_total + slot, total_dic[obs_name],
             width=width_t, color='k', label='TOTAL' if first else '')

    # Name placed just above the positive Ice_area value of that month.
    positive_ice_area = pda_dic[obs_name]['Ice_area']
    for idx in (0, -1):
        axis.text(x_cpts[idx] + .04, positive_ice_area[idx] + .4,
                  obs_titles[obs_name], ha='center', va='bottom',
                  rotation=-90)

# Dashed zero line across all twelve month slots.
axis.plot([0, 12], [0, 0], linestyle='--', color='k')

# Ticks on month boundaries with their labels hidden (month numbers are
# written as text elsewhere); enlarge the y tick labels.
axis.set_xticks(np.arange(0, 13))
plt.setp(axis.get_xticklabels(), visible=False)
plt.setp(axis.get_yticklabels(), fontproperties=font_large)

axis.set_ylabel('Induced surface flux bias ($Wm^{-2}$)',
                fontproperties=font_large)
axis.set_xlabel('Month', fontproperties=font_large, labelpad=22)

# Blue curve: the regional-mean flux anomaly series, passed through wrap().
flux_anomaly_curve = wrap(ts_flux_anomaly.data)
axis.plot(xx_plot14, flux_anomaly_curve, color='#0000c0', linewidth=3,
          marker='*', markersize=10,
          label='Sea ice latent heat uptake\nbias relative to PIOMAS')

# Polar-stereographic inset maps of the CERES-based field, one per process
# in processes_contour (for that process's month), plus one shared
# horizontal colour bar.
xcs = .16                  # inset width, figure fraction
ycs = .19                  # inset height, figure fraction
yms = [.8, .1, .1, .8]     # inset vertical position, fraction of main-axes height
xoff = [0., -.4, .4, 0.]   # horizontal nudge of each inset, in months
nps = ccrs.NorthPolarStereo()
levels = np.arange(-30, 31, 5)
cmap = plt.get_cmap('RdBu_r')

# yms/xoff are matched to processes by position, so iterate in a fixed order
# (ascending month).  Enumerating processes_contour.keys() directly is
# order-dependent on dict iteration, which is arbitrary under Python 2 and
# could put insets in the wrong slots.
contour_order = sorted(processes_contour, key=processes_contour.get)
for (iproc, process) in enumerate(contour_order):
    month = processes_contour[process]
    month_name = dt.date(1980, month, 1).strftime('%b')
    title = titles[process] + ', ' + month_name
    plot_cube = all_cubes_contour['CERES'][process]
    regions.reduce_to_region(plot_cube, region_name, 'hadgom')
    # NOTE(review): xcs (a figure fraction) is subtracted from a month
    # coordinate before scaling; the -1.5 month shift appears hand-tuned to
    # compensate.  Left unchanged to keep the published layout.
    xm = month - 1.5
    xcl = (xm - xcs / 2. + xoff[iproc]) * xw / 12. + xl
    ycb = (yms[iproc] - ycs / 2.) * yw + yb
    ax_c = figure.add_axes([xcl, ycb, xcs, ycs], projection=nps)
    clev = pf.polarplot(plot_cube, figure=figure, axis=ax_c, levels=levels,
                        cmap=cmap, colorbar=False, show=False, title=title,
                        rrange=3.e6)

# Shared colour bar, built from the last polarplot result (all insets use
# the same levels and colour map).  Attach it to `figure` explicitly rather
# than pyplot's current figure.
ax_cl = figure.add_axes([.55, .2, .38, .02])
figure.colorbar(clev, cax=ax_cl, orientation='horizontal')
ax_cl.text(.5, -2., 'Flux bias ($Wm^{-2}$)', ha='center',
           fontproperties=font_large)

for label in ax_cl.get_xticklabels():
    label.set_fontproperties(font_large)

# Red curves: net radiative anomalies, one dash pattern per observational
# dataset.  The white single-point line only provides a legend heading.
axis.plot([0], [0], color='#ffffff',
          label='Net radiative anomalies\nevaluated with respect to')
for obs_name in obs_names:
    line, = axis.plot(xx_plot14, wrap(ts_rad[obs_name].data),
                      color='#c00000', linewidth=3, marker='*',
                      markersize=10, label=obs_titles[obs_name])
    line.set_dashes(linestyles[obs_name])

# Grey vertical line at each month boundary, with the month number centred
# in the slot to its left, just below the plotted y-range.
for month in range(1, 13):
    axis.plot([month, month], ylim, color='#c0c0c0')
    axis.text(month - .5, ylim[0] - .03 * (ylim[1] - ylim[0]), str(month),
              ha='center', fontproperties=font_large)

axis.legend(ncol=2, bbox_to_anchor=[.5, -.09, 0., 0.], prop=font_medium)
axis.set_xlim(0, 12)
axis.set_ylim(ylim)

# Save the figure object built above explicitly.  pylab.savefig() writes
# pyplot's *current* figure, which would be the wrong one if a plotting
# helper (e.g. pf.polarplot) had opened another figure in the meantime.
figure.savefig(outfile_eps)
figure.savefig(outfile_png)
