'''Constructs a reference dataset of ice thickness by category as described in Section 4'''

import iris
import numpy as np
import iris.coord_categorisation as icc
import copy
import timemod
import numpy.ma as ma
import cube_magic as cm

# Run configuration: which model run to read and which years to process.
model = 'HadGEM3_GC3.1'
experiment = 'Hist1'
tprof = 'monthly'
gridname = 'eorca1'
startyear, endyear = 1980, 1999

# Directory holding this run's per-variable subdirectories. The trailing slash
# is kept because file paths below are built by plain string concatenation.
model_dir = '/data/local/hadax/PhD/Attribution/{}/{}/{}/{}/'.format(
    model, experiment, tprof, gridname)

# Upper limit applied to the area ratios (ncat1_prop, proportion) computed in
# the main loop.
prop_cutoff = 5.

# PIOMAS reference ice thickness on the model grid. Only its data array is
# used: it is copied into a model-shaped cube in the main loop, so it
# presumably has the same shape as one year of model output (e.g. a 12-month
# climatology) -- TODO confirm.
piomas_file = '/data/local/hadax/PhD/Attribution/PIOMAS/monthly/' + \
    gridname + '/Ice_thickness.nc'
piomas_cube = iris.load_cube(piomas_file)

years = range(startyear, endyear + 1)
# One CubeList per thickness category (1-5); each receives one cube per year.
all_cubelist = {cat: iris.cube.CubeList([]) for cat in range(1, 6)}
# Base weights for spreading the grid-box thickness increment over the five
# categories. Category 1 uses half weight; categories 2-5 are rescaled by an
# area ratio in the main loop.
cat_base_weights = np.array([.5, 1., 1., 1., 1.])

def _mask_nonfinite_and_cap(values, cap):
    """Return ``values`` as a masked array capped from above at ``cap``.

    Non-finite entries (NaN/inf, e.g. from a 0/0 or x/0 area ratio) are masked.
    Entries already masked in ``values`` stay masked. Every unmasked entry
    greater than ``cap`` is set to ``cap``.
    """
    capped = ma.masked_invalid(values)
    # np.where on a masked array yields only unmasked positions, so masked
    # entries are left untouched.
    capped[np.where(capped > cap)] = cap
    return capped


for year in years:
    print(' ')
    print('---------------------------')
    print('Year = ' + str(year))
    year_file = '/' + str(year) + '.nc'

    # Grid-box (all-category) ice area and the per-category ice areas.
    gbm_area_cube = iris.load_cube(model_dir + 'Ice_area' + year_file)
    cat_area_cubelist = iris.cube.CubeList([])
    for cat in range(1, 6):
        cat_area_cubelist.append(iris.load_cube(
            model_dir + 'Ice_area.cat.' + str(cat) + year_file))

    # Weights that spread the grid-box thickness increment over categories.
    # Category 1 keeps its base weight. Categories 2-5 get their base weight
    # times ncat1_prop = (A - a1/2) / (A - a1), where A is the grid-box area
    # and a1 the category-1 area. If A equals the sum of the category areas,
    # this makes sum_k a_k * w_k == A (before the prop_cutoff cap).
    weights_array = np.zeros((5,) + gbm_area_cube.shape)
    weights_array[0] = cat_base_weights[0]
    print('       Calculating initial weights')
    ncat1_prop = (gbm_area_cube.data - cat_area_cubelist[0].data / 2.) / \
        (gbm_area_cube.data - cat_area_cubelist[0].data)
    ncat1_prop = _mask_nonfinite_and_cap(ncat1_prop, prop_cutoff)
    # Where the ratio is undefined (A == a1, including ice-free points, or
    # masked input), fall back to ratio 1, i.e. the category's base weight.
    # Filling explicitly keeps weights_array finite instead of relying on how
    # numpy.ma treats masked entries when a masked product is stored into a
    # plain ndarray; that behaviour has varied between numpy versions.
    ncat1_prop = ncat1_prop.filled(1.)
    for cat in range(2, 6):
        weights_array[cat - 1] = cat_base_weights[cat - 1] * ncat1_prop

    # cumulative_downwards_area[c-1] = area of categories c..5, accumulated
    # from category 5 downwards.
    print('       Calculating cumulative downwards area')
    cumulative_downwards_area = np.zeros((5,) + gbm_area_cube.shape)
    running_area = np.zeros(gbm_area_cube.shape)
    for cat in range(5, 0, -1):
        running_area = running_area + cat_area_cubelist[cat - 1].data
        cumulative_downwards_area[cat - 1] = running_area

    # The grid-box thickness and the PIOMAS field do not depend on the
    # category, so load/build them once per year rather than once per category.
    gbm_thickness_cube = iris.load_cube(model_dir + 'Ice_thickness' + year_file)
    piomas_cube_cfg = copy.deepcopy(gbm_thickness_cube)
    piomas_cube_cfg.data = piomas_cube.data
    thickness_gap = piomas_cube_cfg - gbm_thickness_cube

    correctional_increment_anti_ice = np.zeros(gbm_area_cube.shape)
    print('       Starting main calculation')
    for cat in range(1, 6):
        print('            cat = ' + str(cat))
        cat_thickness_cube = iris.load_cube(
            model_dir + 'Ice_thickness.cat.' + str(cat) + year_file)
        cat_thickness_cube.units = 'm'

        print('              Calculating initial increment')
        increment = thickness_gap * weights_array[cat - 1]
        piomas_cat_thickness_year = copy.deepcopy(cat_thickness_cube)
        piomas_cat_thickness_year.data = cat_thickness_cube.data + \
            increment.data
        # Apply the corrections accumulated from lower categories' negative
        # thicknesses.
        piomas_cat_thickness_year = piomas_cat_thickness_year + \
            correctional_increment_anti_ice

        print('              Calculating anti-ice increment')
        # Equal to the thickness where it is negative and zero where it is
        # non-negative; subtracting it clips this category's thickness at zero.
        increment_anti_ice = piomas_cat_thickness_year.data * \
            (piomas_cat_thickness_year.data < 0.)
        piomas_cat_thickness_year.data = \
            piomas_cat_thickness_year.data - increment_anti_ice

        if cat != 5:
            # NOTE(review): the denominator a_c + ... + a_5 includes this
            # category's own area. The correction below is added to every
            # higher category, so (if volume = area * thickness) it offsets
            # only (a_{c+1}+...+a_5)/(a_c+...+a_5) of the volume added here by
            # the clipping. It also means proportion <= 1 for non-negative
            # areas, so the prop_cutoff cap cannot bind. Confirm against the
            # method description whether areas of categories c+1..5 were meant.
            proportion = _mask_nonfinite_and_cap(
                cat_area_cubelist[cat - 1].data /
                cumulative_downwards_area[cat - 1],
                prop_cutoff)
            correctional_increment_anti_ice = \
                correctional_increment_anti_ice + \
                increment_anti_ice * proportion

        all_cubelist[cat].append(piomas_cat_thickness_year)

# For each category: join the yearly cubes into one time series, average each
# calendar month across the years, and write the result to its own file.
for cat in range(1, 6):
    joined = all_cubelist[cat].concatenate()
    if len(joined) > 1:
        # concatenate() left more than one cube; fall back to the project
        # helper to force a single cube.
        print('Unable to merge cubelist')
        series = cm.easy_concatenate(all_cubelist[cat])
    else:
        series = joined[0]

    icc.add_month(series, 'time')
    monthly_mean = series.aggregated_by('month', iris.analysis.MEAN)

    out_path = '/data/local/hadax/PhD/Attribution/PIOMAS/monthly/{}/' \
        'Ice_thickness.synth_v2.cat.{}.nc'.format(gridname, cat)
    iris.save(monthly_mean, out_path)
