import iris
import regions
import matplotlib.pyplot as plt
import pylab
import iris
from matplotlib.font_manager import FontProperties
import numpy as np
import matplotlib
from arrays import wrap
import timemod
import datetime as dt
import numpy.ma as ma
import cartopy.crs as ccrs
import plot_functions as pf
import copy
import sys
sys.path.insert(0,'/home/h01/hadax/Python/PhD/Obs/')
import envisat
import iris.coord_categorisation as icc
import cube_magic as cm

# Output figure paths: one basename, written once as EPS and once as PNG.
_figure_dir = ('/home/h01/hadax/graphics/PhD/Papers/'
               'Dec2013_HadGEM2_energy_budget/'
               'v3.1/Figures/')
outfile_eps = _figure_dir + 'F2_sea_ice_evaluation.eps'
outfile_png = _figure_dir + 'F2_sea_ice_evaluation.png'

# Model and its ensemble members Hist1..Hist4 (each is tagged with a
# 'realization' coordinate 1..4 when loaded).
model = 'HadGEM2_ES'
experiments = ['Hist%d' % member for member in range(1, 5)]
# Averaging period [first_year, last_year]; encoded as 'YYYY.YYYY' in file names.
period = [1980, 1999]
# Calendar months (1-12) for which anomaly maps are drawn.
hice_months_show = [4, 10]
aice_months_show = [9]

basedir_ma = '/data/local/hadax/PhD/multiannual_fields/'
basedir_ts = '/data/local/hadax/PhD/timeseries/'

# Parallel lists, one entry per thickness observation product:
# region code, product directory name, and period code of its files.
rcodes_hice = ['Arctic_Ocean.0', 'VAR.Envisat', 'SCICEX']
obs_names = ['PIOMAS', 'Envisat', 'subs']
obs_tcodes = ['1980.1999', '1993.1999', '1980.1999']
mod_tcode = '%d.%d' % (period[0], period[1])

# Load the Ice_area '.MEAN' field of every ensemble member, tag each with a
# scalar 'realization' coordinate so merge() can stack them along a new
# dimension, then average over the ensemble. Also load the HadISST field
# for the same period on the 'hadgom' grid for comparison.
hg2_aice_cubelist = iris.cube.CubeList([])
for (iexpt, experiment) in enumerate(experiments):
    scube = iris.load_cube(basedir_ma + model + '/' + experiment + '/' +
                           'ICE_STATE/' + 'monthly/' +
                           'Ice_area.' + mod_tcode + '.MEAN.nc')
    rcoord = iris.coords.AuxCoord(iexpt + 1, 'realization')
    scube.add_aux_coord(rcoord)
    hg2_aice_cubelist.append(scube)

hg2_aice_cubes = hg2_aice_cubelist.merge()
if len(hg2_aice_cubes) > 1:
    # Carry the diagnostic in the exception instead of printing it separately.
    raise ValueError('Could not merge ice area cubes')

hg2_aice_cube = hg2_aice_cubes[0]
hg2_aice_cube_mean = hg2_aice_cube.collapsed(['realization'],
                                             iris.analysis.MEAN)
hadisst_aice_cube = iris.load_cube(basedir_ma + 'HadISST/hadgom/' +
                                   'Ice_area.' + mod_tcode + '.' +
                                   rcodes_hice[0] + '.nc')

# For each observation product's period code, build the ensemble-mean model
# Ice_volume field over the same years, so each product is compared with a
# model field averaged over matching years.
hg2_time_cubes = iris.cube.CubeList([])
for tcode in obs_tcodes:
    hg2_hice_cubelist = iris.cube.CubeList([])
    for (iexpt, experiment) in enumerate(experiments):
        scube = iris.load_cube(basedir_ma + model + '/' + experiment + '/' +
                               'ICE_STATE/' + 'monthly/' +
                               'Ice_volume.' + tcode + '.MEAN.nc')
        rcoord = iris.coords.AuxCoord(iexpt + 1, 'realization')
        scube.add_aux_coord(rcoord)
        hg2_hice_cubelist.append(scube)

    hg2_hice_cubes = hg2_hice_cubelist.merge()
    if len(hg2_hice_cubes) > 1:
        # Put the diagnostic (including the offending period) in the exception.
        raise ValueError('Could not merge ice volume cubes for period %s'
                         % tcode)

    hg2_hice_cube = hg2_hice_cubes[0]
    hg2_hice_cube_mean = hg2_hice_cube.collapsed(['realization'],
                                                 iris.analysis.MEAN)
    hg2_time_cubes.append(hg2_hice_cube_mean)

# Matching observed fields, one per product (parallel to hg2_time_cubes).
obs_hice_cubes = [iris.load_cube(basedir_ma + obs_name + '/hadgom/' +
                                 'Ice_volume.' + tcode + '.' + rcode + '.nc')
                  for (obs_name, rcode, tcode) in
                  zip(obs_names, rcodes_hice, obs_tcodes)]

# Ice-extent time series ('.SUM' files for region rcodes_hice[0]) for each
# ensemble member, merged along 'realization' and averaged over members,
# plus the HadISST series for the same period and region.
hg2_extent_ts_cubelist = iris.cube.CubeList([])
for (iexpt, experiment) in enumerate(experiments):
    scube = iris.load_cube(basedir_ts + model + '/' + experiment + '/' +
                           'ICE_STATE/' + 'monthly/' + 'Ice_extent.' +
                           mod_tcode + '.' + rcodes_hice[0] + '.SUM.nc')
    rcoord = iris.coords.AuxCoord(iexpt + 1, 'realization')
    scube.add_aux_coord(rcoord)
    hg2_extent_ts_cubelist.append(scube)

hg2_extent_ts_cubes = hg2_extent_ts_cubelist.merge()
if len(hg2_extent_ts_cubes) > 1:
    # Carry the diagnostic in the exception instead of printing it separately.
    raise ValueError('Could not merge ice ts extent cubes')

hg2_extent_ts = hg2_extent_ts_cubes[0]
hg2_extent_ts_mean = hg2_extent_ts.collapsed(['realization'],
                                             iris.analysis.MEAN)
hadisst_extent_ts = iris.load_cube(basedir_ts + 'HadISST' + '/' +
                                   'Ice_extent.' + mod_tcode + '.' +
                                   rcodes_hice[0] + '.SUM.nc')

# Label every time point with its calendar month, then average all points
# sharing a month to get the mean seasonal cycle.
icc.add_month(hg2_extent_ts_mean, 'time')
icc.add_month(hadisst_extent_ts, 'time')

hg2_extent_sc = hg2_extent_ts_mean.aggregated_by('month', iris.analysis.MEAN)
hadisst_extent_sc = hadisst_extent_ts.aggregated_by('month',
                                                    iris.analysis.MEAN)

# Restrict each model field to the region sampled by the matching
# observation product (rcodes_hice is parallel to hg2_time_cubes).
hg2_hice_cubes = iris.cube.CubeList([])
for (region, time_cube) in zip(rcodes_hice, hg2_time_cubes):
    # Deep copy so the regional masking below does not alter hg2_time_cubes.
    ccube = copy.deepcopy(time_cube)
    if region == 'VAR.Envisat':
        # Envisat uses a monthly 3-D mask built from its climatology. The
        # return value is not used; presumably monthly_mask_3D masks ccube
        # in place -- TODO confirm.
        envisat_cube_sc = envisat.gbm_climo()
        envisat.monthly_mask_3D(ccube, envisat_cube_sc)
    else:
        regions.reduce_to_region(ccube, region, 'hadgom')

    hg2_hice_cubes.append(ccube)


def _area_mean_ts(cube):
    """Return the area-weighted latitude/longitude mean of *cube*.

    cm.llbounds is applied to *cube* first (mutating it); presumably it adds
    the lat/lon bounds that iris area_weights requires -- TODO confirm.
    """
    cm.llbounds(cube)
    weights = iris.analysis.cartography.area_weights(cube)
    return cube.collapsed(['latitude', 'longitude'], iris.analysis.MEAN,
                          weights=weights)


# Area-mean series for model and observations, in the same product order.
hg2_ts_hice_cubes = iris.cube.CubeList(
    [_area_mean_ts(cube) for cube in hg2_hice_cubes])
obs_ts_hice_cubes = iris.cube.CubeList(
    [_area_mean_ts(cube) for cube in obs_hice_cubes])

# --- Panel geometry, in figure-fraction units --------------------------------
# (xl, yb): origin of the bottom-left panel; (xgap, ygap): spacing between
# panel origins; (xw, yw): panel width and height.
xl, yb = .1, .13
xgap, ygap = .5, .455
xw, yw = .38, .38
# x positions handed to plot() together with wrap()-ed seasonal cycles.
xx_plot = np.arange(-.5, 13)

# Vertical placement of the anomaly insets within each panel (fraction of
# yw), and horizontal nudges (fraction of xw); one [month0, month1] pair per
# thickness product.
y_offset_aice = .08
y_offsets_hice = [[.24, .8], [.24, .8], [.47, .2]]
x_offsets_hice = [[0., 0.], [0., 0.], [0.015, 0.]]
# y-axis limits of the three thickness panels.
ylims = [[0, 3.9] for _ in range(3)]
# Divisor applied to the extent series before plotting.
aice_scale = 1.e12

# Contour levels and colour map for the anomaly maps.
levels_aice = np.arange(-1., 1.01, 0.1)
levels_hice = np.arange(-2, 2.1, .5)
cmap = plt.get_cmap('PRGn')

# Width and height of each inset map.
xw_c, yw_c = .195, .135

nps = ccrs.NorthPolarStereo()

matplotlib.rcParams['font.size'] = 20
legend_fontP = FontProperties(size=13)

model_colour = '#000000'
obs_colour = '#ff0000'


def _panel_rect(col, row):
    """Return [left, bottom, width, height] of grid panel (col, row)."""
    return [xl + xgap * col, yb + ygap * row, xw, yw]


fig = plt.figure(figsize=(12, 12))
# Creation order fixes fig.axes order: extent, PIOMAS, Envisat, submarine.
ax_ts_extent = fig.add_axes(_panel_rect(0., 1.))
ax_ts_piomas = fig.add_axes(_panel_rect(1., 1.))
ax_ts_ers = fig.add_axes(_panel_rect(0., 0.))
ax_ts_subs = fig.add_axes(_panel_rect(1., 0.))

# Panel (a): seasonal cycle of ice extent, model first then HadISST, each
# divided by aice_scale.
extent_curves = [(hg2_extent_sc, model_colour, 'HadGEM2-ES'),
                 (hadisst_extent_sc, obs_colour, 'HadISST1.2')]
for (sc_cube, colour, label) in extent_curves:
    ax_ts_extent.plot(xx_plot,
                      wrap(timemod.sc_for_plotting(sc_cube)) / aice_scale,
                      color=colour, label=label)
ax_ts_extent.set_ylim(0, 10)


def _as_masked_sc(sc_cube):
    """Return sc_for_plotting(sc_cube) as a masked array with an explicit
    12-element all-False mask, so individual entries can be masked later."""
    values = timemod.sc_for_plotting(sc_cube)
    return ma.masked_array(values, mask=np.zeros(12, dtype='bool'))


# Panels (b)-(d): area-mean thickness cycles, model vs each product.
thickness_obs_labels = ['PIOMAS', 'Envisat', 'Submarine\nregression\nanalysis']
bbox_to_anchors = [[.34, .75, .2, .2], [.34, .75, .2, .2], [.34, .1, .2, .2]]
for (iax, ax) in enumerate((ax_ts_piomas, ax_ts_ers, ax_ts_subs)):
    obs_plot = _as_masked_sc(obs_ts_hice_cubes[iax])
    mod_plot = _as_masked_sc(hg2_ts_hice_cubes[iax])
    if iax == 1:
        # Envisat panel: hide entries 4-8 of both cycles.
        for series in (obs_plot, mod_plot):
            series.mask[4:9] = True
    ax.plot(xx_plot, wrap(mod_plot), color=model_colour, label='HadGEM2-ES')
    ax.plot(xx_plot, wrap(obs_plot), color=obs_colour,
            label=thickness_obs_labels[iax])
    ax.set_ylim(ylims[iax])
    ax.legend(bbox_to_anchor=bbox_to_anchors[iax], prop=legend_fontP)

def _month_anomaly(mod_sc_cube, obs_sc_cube, month):
    """Return the 2-D model-minus-observation field for calendar *month*.

    Both cubes are selected on their 'month' coordinate by the abbreviated
    month name (strftime('%b')) and indexed as 3-D with that dimension
    first. The observed data are copied onto a copy of the model cube
    before subtracting, so only the data shapes need to agree.

    Raises ValueError if either cube has no entry for *month*.
    """
    month_name = dt.date(1980, month, 1).strftime('%b')
    obs_index = np.where(obs_sc_cube.coord('month').points == month_name)[0]
    mod_index = np.where(mod_sc_cube.coord('month').points == month_name)[0]
    if obs_index.size == 0 or mod_index.size == 0:
        raise ValueError('No %s entry in model or observation cube'
                         % month_name)
    obs_cube = obs_sc_cube[obs_index, :, :]
    mod_cube = mod_sc_cube[mod_index, :, :]

    obs_cube_cfg = copy.deepcopy(mod_cube)
    obs_cube_cfg.data = obs_cube.data
    anom_cube = mod_cube - obs_cube_cfg
    return anom_cube[0, :, :]


# Ice-fraction anomaly maps inset into the extent panel (a), each centred
# horizontally on its month's slot along the x axis.
for month in aice_months_show:
    anom = _month_anomaly(hg2_aice_cube_mean, hadisst_aice_cube, month)

    xl_c = xl + xw * ((month - .5) / 12.) - xw_c / 2.
    yb_c = yb + ygap + y_offset_aice - yw_c / 2.

    ax_contour = fig.add_axes([xl_c, yb_c, xw_c, yw_c], projection=nps)
    clev_aice = pf.polarplot(anom, axis=ax_contour, figure=fig,
                             levels=levels_aice, title=' ',
                             cmap=cmap, show=False, colorbar=False,
                             rrange=3.5e6)

# Ice-thickness anomaly maps inset into panels (b)-(d).
for (imonth, month) in enumerate(hice_months_show):
    for (icube, (mod_sc_cube, obs_sc_cube)) in enumerate(
            zip(hg2_hice_cubes, obs_hice_cubes)):
        anom = _month_anomaly(mod_sc_cube, obs_sc_cube, month)

        # Product index -> panel grid cell: 0 -> (1, 1), 1 -> (0, 0),
        # 2 -> (1, 0). Floor division keeps yi an integer on both
        # Python 2 and Python 3.
        xi = (icube + 1) % 2
        yi = 1 - (icube + 1) // 2
        xl_c = (xl + xw * ((month - .5) / 12.) - xw_c / 2. + xi * xgap +
                x_offsets_hice[icube][imonth] * xw)
        yb_c = yb + y_offsets_hice[icube][imonth] * yw - yw_c / 2. + yi * ygap

        ax_contour = fig.add_axes([xl_c, yb_c, xw_c, yw_c], projection=nps)
        clev_hice = pf.polarplot(anom, axis=ax_contour, figure=fig,
                                 levels=levels_hice, title=' ',
                                 cmap=cmap, show=False, colorbar=False,
                                 rrange=3.5e6)

# Shared decoration of the four time-series panels: x range, ticks,
# y labels, grey bands behind the months that have anomaly maps, titles.
letters = ['(a)', '(b)', '(c)', '(d)']
for (ii, ax) in enumerate((ax_ts_extent, ax_ts_piomas, ax_ts_ers, ax_ts_subs)):
    ax.set_xlim([0, 12])
    ylim = ax.get_ylim()
    # Major x ticks only at the two ends of the 12-month axis; per-month
    # marks are drawn separately by hand.
    ax.set_xticks([0, 12])
    for xlabel in ax.get_xticklabels():
        xlabel.set_visible(False)

    if ii == 0:
        ax.set_ylabel('Ice extent (x $10^6 km^2$)')
        shaded_months = aice_months_show
    else:
        ax.set_ylabel('Ice thickness (m)')
        shaded_months = hice_months_show

    # Grey band spanning the full y range behind each shaded month.
    for month in shaded_months:
        ax.fill_between([month - 1, month], [ylim[0]] * 2, [ylim[1]] * 2,
                        color='#e0e0e0')
    # fill_between may autoscale; restore the limits read above.
    ax.set_ylim(ylim)
    ax.set_title(letters[ii])


# Hand-drawn month ticks on every time-series panel: a short mark (5% of the
# y range) at each boundary 0..12, and each month's initial letter centred
# in its slot just below the axis. The y limits are restored afterwards so
# the marks do not change them.
month_initials = [dt.date(1980, m, 1).strftime('%b')[0] for m in range(1, 13)]
for ax in (ax_ts_extent, ax_ts_piomas, ax_ts_subs, ax_ts_ers):
    y_lo, y_hi = ax.get_ylim()
    y_span = y_hi - y_lo
    for im in range(13):
        ax.plot([im, im], [y_lo, y_lo + y_span * .05], color='k')
        if im < 12:
            ax.text(im + .5, y_lo - y_span * .07, month_initials[im],
                    ha='center')
    ax.set_ylim((y_lo, y_hi))


# Legend for panel (a), then one horizontal colour bar per anomaly type along
# the bottom of the figure.
ax_ts_extent.legend(bbox_to_anchor=[.32, .43, .2, .2], prop=legend_fontP)
ax_cl_aice = fig.add_axes([.11, .055, .35, .02])
ax_cl_hice = fig.add_axes([.61, .055, .35, .02])
# clev_aice / clev_hice are what pf.polarplot returned for the last map of
# each type. All maps of a type use the same levels and cmap, so a single
# bar serves them all.
cb_aice = fig.colorbar(clev_aice, cax=ax_cl_aice, orientation='horizontal')
cb_hice = fig.colorbar(clev_hice, cax=ax_cl_hice, orientation='horizontal')
cb_aice.set_ticks(np.arange(-1, 1.01, .5))
cb_hice.set_ticks(np.arange(-2., 2.01, 1.))
ax_cl_aice.text(.5, -2.5, 'Ice fraction anomaly', ha='center')
ax_cl_hice.text(.5, -2.5, 'Ice thickness anomaly (m)', ha='center')

# Save from the explicit figure object rather than pylab's implicit
# "current figure".
fig.savefig(outfile_eps)
fig.savefig(outfile_png)
