import iris
import numpy as np
import matplotlib.pyplot as plt
import pylab
from arrays import wrap
import cube_magic as cm
import iris.coord_categorisation as icc
import timemod
import regions
import datetime as dt
import copy
from matplotlib.font_manager import FontProperties
import datetime as dt
import plot_functions as pf
import cartopy.crs as ccrs
import monty

# Font sizes: large for the y label, y tick labels and month names; medium
# for the inset-map titles.
font_large = FontProperties(size=20)
font_medium = FontProperties(size=13)

# Both output files share one directory and file stem; only the extension
# differs.
_figure_stem = ('/home/h01/hadax/graphics/PhD/Papers/'
                'Dec2013_HadGEM2_energy_budget/'
                'v2.2_figs/F5_flux_contribution_barplot')
outfile_eps = _figure_stem + '.eps'
outfile_png = _figure_stem + '.png'

# Run configuration: the observational radiation datasets compared (one bar
# group each), the calendar month the seasonal x axis starts on, the model
# run and year range averaged over, and the analysis region.
radiation_datasets = ['CERES', 'ISCCP-FD', 'ERAI']
start_month = 5
model, experiment = 'HadGEM2_ES', 'Hist1'
model_year_start, model_year_end = 1980, 1999
region_name = 'Arctic_Ocean.1'

# Array sizes: flux contributions per dataset (second axis of data_array)
# and number of radiation datasets (first axis).
nc = 5
nrd = len(radiation_datasets)

# Main bar-chart axes rectangle in figure coordinates:
# left, bottom, width, height.
xl, yb, xw, yw = .1, .45, .82, .53

# Map projection, contour levels and colour map for the inset maps.
nps = ccrs.NorthPolarStereo()
levels = np.array([-50, -20, -10, -5, -2, 0, 2, 5, 10, 20, 50])
cmap = monty.clr_cmap('/home/h01/hadax/IDL/colour_tables/anomalies.clr')

figure = plt.figure(figsize=(12, 12))
axis = figure.add_axes([xl, yb, xw, yw])

# Results table indexed [dataset, contribution, plot month].
data_array = np.zeros((nrd, nc, 12))

# Bar geometry: x positions of the first dataset's bars (one per month
# slot), bar width, and the x shift between consecutive datasets' bars.
xx_plot = np.arange(.26, 12)
width = .1
xgap = .22

# Inset maps: size in figure coordinates, how far their bottom sits below
# the bar chart's bottom edge, and the padding of the coloured panel behind
# each map.
xw_c, yw_c = .1, .13
yoffset = .24
xm = ym = .017


# Comparison series: a gridded flux-anomaly field for the configured model
# run, reduced to the analysis region and area-averaged into a time series.
# The path is built from `model` and `experiment` (it used to hard-code
# HadGEM2_ES/Hist1), so changing the run configuration cannot silently pair
# one run's contributions with another run's comparison line.
compare_file = ('/data/local/hadax/PhD/Attribution/' + model + '/' +
                experiment + '/monthly/hadgom/aux/flux_anomaly.nc')
compare_field = iris.load_cube(compare_file)
regions.reduce_to_region(compare_field, region_name, 'hadgom')
cm.llbounds(compare_field)
aw = iris.analysis.cartography.area_weights(compare_field)
ts_flux_anomaly = compare_field.collapsed(['latitude', 'longitude'],
                                          iris.analysis.MEAN, weights=aw)

n_model_years = model_year_end - model_year_start + 1
# Plot position p shows calendar-month index translate_indices[p], so the x
# axis starts at start_month (index 0 is presumably January -- confirm
# against the input files' time ordering).
translate_indices = (np.arange(12) + start_month - 1) % 12

def _contributions_for(radiation_dataset):
    """Return the flux-contribution specifications for one radiation dataset.

    Each entry is a dict with:
      'label'       -- legend text and inset-map title,
      'grid'        -- grid name passed to regions.reduce_to_region,
      'datafile'    -- file-name stem; str(year) + '.nc' is appended,
      'plot_color'  -- bar colour and inset background colour,
      'months_plot' -- calendar months (1-12) that get an inset map.
    List order fixes the second axis of data_array and the bar stacking
    order.
    """
    run_dir = ('/data/local/hadax/PhD/Attribution/' + model + '/' +
               experiment + '/monthly/')
    n96_dir = run_dir + 'n96/output/'
    hadgom_dir = run_dir + 'hadgom/output/'
    return [
        {'label': 'Ice fraction',
         'grid': 'n96',
         'datafile': (n96_dir + 'iifrac.SW_net.NSIDCC-' +
                      radiation_dataset + '.'),
         'plot_color': '#bfffa0',
         'months_plot': [8]},
        {'label': 'Melt onset',
         'grid': 'n96',
         'datafile': n96_dir + 'iimelt.SW_net.melt_onset.',
         'plot_color': '#ffa0a0',
         'months_plot': [6]},
        {'label': 'Ice thickness',
         'grid': 'hadgom',
         'datafile': (hadgom_dir + 'Ice_volume.LW_net.' +
                      radiation_dataset + '-PIOMAS.'),
         'plot_color': '#c0c0c0',
         'months_plot': [11]},
        {'label': 'Downwelling SW',
         'grid': 'n96',
         'datafile': (n96_dir + 'SW_down.SW_net.NSIDCC-melt_onset-PIOMAS-' +
                      radiation_dataset + '.'),
         'plot_color': '#ffff80',
         'months_plot': []},
        {'label': 'Downwelling LW',
         'grid': 'hadgom',
         'datafile': (hadgom_dir + 'LW_down.LW_net.' +
                      radiation_dataset + '-PIOMAS.'),
         'plot_color': '#c0c0ff',
         'months_plot': [2]},
    ]


def _mean_seasonal_cycle(datafile_stem, grid):
    """Average one contribution's yearly files over the model years.

    Loads datafile_stem + str(year) + '.nc' for every year from
    model_year_start to model_year_end inclusive and restricts each cube to
    region_name on `grid`.

    Returns (mean_cube, mean_series):
      mean_cube   -- multi-year mean of the gridded cubes, with a 'month'
                     coordinate added from 'time' if it has none yet;
      mean_series -- multi-year mean of the area-weighted regional means,
                     in the files' own month order. Masked regional means
                     count as 0. Assumed to hold 12 monthly values (it is
                     accumulated into np.zeros(12)) -- confirm.
    """
    series_sum = np.zeros(12)
    cube_sum = None
    for year in range(model_year_start, model_year_end + 1):
        datacube = iris.load_cube(datafile_stem + str(year) + '.nc')
        regions.reduce_to_region(datacube, region_name, grid)
        cm.llbounds(datacube)
        weights = iris.analysis.cartography.area_weights(datacube)
        tscube = datacube.collapsed(['latitude', 'longitude'],
                                    iris.analysis.MEAN, weights=weights)
        # Masked months contribute zero. np.ma.filled also copes with
        # unmasked results (a plain ndarray, or a mask that is the scalar
        # nomask), where assigning into tscube.data.mask would raise.
        series_sum = series_sum + np.ma.filled(tscube.data, 0.)

        if cube_sum is None:
            cube_sum = datacube
        else:
            # Add this year's data on a copy of the running total's
            # metadata; presumably each year's cube carries its own time
            # coordinate, which would stop the raw cubes from combining.
            cube_sum = cube_sum + cube_sum.copy(data=datacube.data)

    mean_cube = cube_sum / n_model_years
    if not mean_cube.coords('month'):
        icc.add_month(mean_cube, 'time')
    return mean_cube, series_sum / n_model_years


def _draw_month_inset(contribution, month_plot):
    """Draw one contribution's mean field for one calendar month as an inset.

    The polar map sits below the bar chart, horizontally centred on
    month_plot's bar group, over a background panel filled with the
    contribution's colour. Returns the value of pf.polarplot, which is later
    passed to plt.colorbar.
    """
    month_name = dt.date(1980, month_plot, 1).strftime('%b')
    # 1-based x slot of the month, derived from start_month exactly like the
    # month-name labels along the bottom axis. The former hard-coded
    # (month_plot + 8) % 12 matched only for start_month == 5, and even then
    # sent the final month off the left edge of the chart.
    month_tr = (month_plot - start_month) % 12 + 1
    # Convert slot centre to figure coordinates; assumes the bar axis spans
    # x = 0..12.
    xl_c = xl + (month_tr - .5) / 12. * xw - xw_c / 2.
    yb_c = yb - yoffset

    ax_c_bg = figure.add_axes([xl_c - xm, yb_c - ym,
                               xw_c + xm * 2, yw_c + ym * 4])
    ax_c_bg.set_xlim([0, 1])
    ax_c_bg.set_ylim([0, 1])
    ax_c_bg.fill_between([0, 1], [0, 0], [1, 1],
                         color=contribution['plot_color'])
    for tick_label in (ax_c_bg.get_xticklabels() +
                       ax_c_bg.get_yticklabels()):
        tick_label.set_visible(False)
    ax_c_bg.set_xticks([0, 1])
    ax_c_bg.set_yticks([0, 1])
    ax_c_bg.get_xaxis().set_visible(False)
    ax_c_bg.get_yaxis().set_visible(False)

    ax_c = figure.add_axes([xl_c, yb_c, xw_c, yw_c], projection=nps)
    sf_cube = contribution['sc_cube']
    index = np.where(sf_cube.coord('month').points == month_name)
    month_cube = sf_cube[index[0][0], :, :]
    clev = pf.polarplot(month_cube, axis=ax_c, figure=figure,
                        levels=levels, cmap=cmap, show=False,
                        colorbar=False, title=' ', rrange=3.5e6)
    ax_c.set_title(contribution['label'] + '\n' + month_name,
                   fontproperties=font_medium)
    return clev


# One dashed zero line under all bar groups; it used to be redrawn once per
# dataset.
axis.plot([0, 12], [0, 0], linewidth=2, linestyle='--', color='#c0c0c0')

for (ird, radiation_dataset) in enumerate(radiation_datasets):
    contributions = _contributions_for(radiation_dataset)

    for (ic, contribution) in enumerate(contributions):
        print('Calculating contribution ' + contribution['label'])
        sc_cube, mean_series = _mean_seasonal_cycle(contribution['datafile'],
                                                    contribution['grid'])
        contribution['sc_cube'] = sc_cube
        # Reorder months so the x axis starts at start_month.
        data_array[ird, ic, :] = mean_series[translate_indices]

    # Stack positive parts upwards and negative parts downwards from zero so
    # segments of opposite sign never overlap.
    positive = np.where(data_array[ird] > 0, data_array[ird], 0.)
    negative = np.where(data_array[ird] < 0, data_array[ird], 0.)
    pbase = np.zeros(12)
    nbase = np.zeros(12)
    x_bars = xx_plot + xgap * ird
    for (ic, contribution) in enumerate(contributions):
        # Legend entries come from the first dataset only, so each
        # contribution is listed once.
        label = contribution['label'] if ird == 0 else None
        axis.bar(x_bars, positive[ic], width=width, bottom=pbase,
                 color=contribution['plot_color'], label=label)
        axis.bar(x_bars, negative[ic], width=width, bottom=nbase,
                 color=contribution['plot_color'])
        pbase = pbase + positive[ic]
        nbase = nbase + negative[ic]

    # Name this dataset's bars: below the negative stack in plot months 0
    # and 3, just above zero in plot months 6 and 9.
    for month_index in (0, 3):
        axis.text(x_bars[month_index] - .06, nbase[month_index] - 0.5,
                  radiation_dataset, rotation=-90., va='top')
    for month_index in (6, 9):
        axis.text(x_bars[month_index] - .06, 0.5, radiation_dataset,
                  rotation=-90., va='bottom')

# Inset maps are drawn once, from the final dataset's contributions. They
# used to be redrawn for every dataset at identical positions, with the
# final dataset's maps layered on top.
for contribution in contributions:
    for month_plot in contribution['months_plot']:
        clev = _draw_month_inset(contribution, month_plot)

# Sample bar positions and heights (positive, negative) for a schematic key
# explaining the bar layout. The key itself is disabled (its code is
# commented out below), so nothing else in this script reads these values.
xsample = np.array([0.3, 0.5, 0.7])
data_sample_p, data_sample_n = [0.6, 0.5, 0.2], [-0.2, -0.5, -0.4]

# Disabled key drawing, kept for reference:
#ax_key = figure.add_axes([.85,yb_c-ym,xw_c+xm*2,\
#                                       yw_c+ym*4])
#ax_key.set_ylim([-2,1])
#ax_key.set_xlim([0,1])
#ax_key.plot([0,1],[0,0],linewidth = 2, linestyle='--', color = '#c0c0c0')
#ax_key.bar(xsample-width/2.,data_sample_p,color = '#c0c0c0',width=width)
#ax_key.bar(xsample-width/2.,data_sample_n,color = '#c0c0c0',width=width)

#for label in ax_key.get_xticklabels() + ax_key.get_yticklabels():
#    label.set_visible(False)

#ax_key.set_yticks([-2,1])
#ax_key.set_xticks([0,1])
#ax_key.plot([.1,.1],[-2,1],linestyle='--',color = 'k')
#ax_key.plot([.9,.9],[-2,1],linestyle='--',color = 'k')

#for (ix,xval) in enumerate(xsample):
#    ax_key.text(xval-width*.7,data_sample_n[ix]-.28,radiation_datasets[ix],\
#         rotation = -90)

# Horizontal colour bar for the inset maps along the bottom of the figure,
# built from the last value pf.polarplot returned while drawing the insets.
colorbar_rect = [.3, .045, .4, .015]
ax_cl = figure.add_axes(colorbar_rect)
# NOTE(review): y ticks on a horizontal colour bar are probably overridden by
# plt.colorbar; set_xticks / ticks=levels may have been intended -- confirm.
ax_cl.set_yticks(levels)
plt.colorbar(clev, cax=ax_cl, orientation='horizontal')
colorbar_caption = 'Induced surface flux anomaly ($W/m^2$)'
ax_cl.text(.5, -2.8, colorbar_caption, ha='center')

# Comparison line: the regional-mean flux anomaly from the compare file,
# reordered to start at start_month and placed at the second dataset's bar
# offset. Then add the legend and the y-axis label.
anomaly_label = ('Latent heat flux anomaly\n'
                 'implied by ice volume anomalies\n'
                 'with respect to PIOMAS')
anomaly_x = xx_plot + xgap
anomaly_y = ts_flux_anomaly.data[translate_indices]
axis.plot(anomaly_x, anomaly_y, marker='o', linewidth=3, markersize=12,
          color='#202020', label=anomaly_label)

axis.legend(ncol=2, bbox_to_anchor=[.5, -.5, 0., 0.])
axis.set_ylabel('Induced surface flux anomaly ($W/m^2$)',
                fontproperties=font_large)

# Grey stars: net flux (sum over all contributions) for each dataset and
# month, half a bar width to the right of that dataset's bar x positions.
sum_array = data_array.sum(axis=1)
for ird, dataset_total in enumerate(sum_array):
    star_x = xx_plot + width * .5 + ird * xgap
    axis.plot(star_x, dataset_total, linewidth=2, marker='*',
              color='#808080', markersize=20, linestyle=' ')

# Replace the numeric x tick labels with month names centred on each
# month slot, and separate the slots with dashed vertical lines.
axis.set_xticks(np.arange(0, 13))
for tick_label in axis.get_xticklabels():
    tick_label.set_visible(False)

for tick_label in axis.get_yticklabels():
    tick_label.set_fontproperties(font_large)

# Month names go just below the current y range, so capture it once.
ylim = axis.get_ylim()
for month in range(1, 13):
    month_name = dt.date(1980, month, 1).strftime('%b')
    position = (month + 12 - start_month) % 12 + .5
    axis.text(position, ylim[0] - .05 * (ylim[1] - ylim[0]), month_name,
              ha='center', fontproperties=font_large)
    # axvline spans the full axes height without feeding the y autoscaler.
    # Plotting a line across axis.get_ylim() can widen the limits on every
    # iteration when axes margins are non-zero, leaving the month names
    # misplaced relative to the captured ylim.
    axis.axvline(position + .5, linestyle='--', color='k')

# Write the finished figure as EPS first, then PNG.
for outfile in (outfile_eps, outfile_png):
    pylab.savefig(outfile)
