import iris.exceptions as ie
import iris

def llbounds(cube):
    """Give the latitude and longitude coordinates of *cube* bounds if missing.

    Both coordinates are looked up first (so a missing one raises before
    anything is modified); then ``guess_bounds()`` is called on each one
    that reports ``has_bounds()`` false.  *cube* is modified in place.
    """
    coords = [cube.coord(name) for name in ('latitude', 'longitude')]
    for coord in coords:
        if not coord.has_bounds():
            coord.guess_bounds()


def dimcoordnames(cube, long_name=False):
    """List the names of *cube*'s dimension coordinates.

    Args:
        cube: an iris cube; its ``dim_coords`` are read in order.
        long_name: when true, report each coordinate's ``long_name``;
            otherwise report its ``standard_name``.

    Returns:
        list of names, one per entry of ``cube.dim_coords``.
    """
    attribute = 'long_name' if long_name else 'standard_name'
    return [getattr(coord, attribute) for coord in cube.dim_coords]


def auxcoordnames(cube, long_name=False):
    """List the names of *cube*'s auxiliary coordinates.

    Args:
        cube: an iris cube; its ``aux_coords`` are read in order.
        long_name: when true, report each coordinate's ``long_name``;
            otherwise report its ``standard_name``.

    Returns:
        list of names, one per entry of ``cube.aux_coords``.
    """
    attribute = 'long_name' if long_name else 'standard_name'
    return [getattr(coord, attribute) for coord in cube.aux_coords]


def dimension(cube, coordname):
    """Return the data dimension spanned by the DimCoord named *coordname*.

    Coordinates are matched on ``standard_name`` using the cube's
    ``_dim_coords_and_dims`` (coord, dim) pairs; if two share a name, the
    later pair wins.

    Returns:
        The dimension index, or None (after printing a message) when no
        DimCoord has that standard_name.
    """
    lookup = {pair[0].standard_name: pair[1] for pair in cube._dim_coords_and_dims}
    if coordname in lookup:
        return lookup[coordname]
    print(coordname + ' is not a DimCoord on this cube')
    return None
    

def common_coords(cubelist):
    """Return the CF standard names used by a DimCoord on every cube.

    Args:
        cubelist: iterable of iris cubes.

    Returns:
        set of names from ``iris.std_names.STD_NAMES`` that appear in
        ``dimcoordnames(cube)`` for every cube (all of STD_NAMES when
        *cubelist* is empty).
    """
    import iris.std_names

    shared = set(iris.std_names.STD_NAMES)
    for cube in cubelist:
        shared = shared.intersection(dimcoordnames(cube))
    return shared


def dimension_pair_sort(dimsets):
    """Renumber groups of (coord, dim) pairs onto consecutive dimensions.

    Each element of *dimsets* is a sequence of ``(coord, dim, ...)`` tuples,
    e.g. a cube's ``_dim_coords_and_dims``.  The groups are laid out one
    after another: within a group each pair keeps its relative dimension
    order, and the group starts right after the previous group's last
    dimension.  Only the ``dim`` entry of each tuple is replaced.

    Example: ``([(a, 0)], [(b, 1), (c, 0)])`` -> ``[(a, 0), (b, 2), (c, 1)]``

    Args:
        dimsets: iterable of sequences (lists or tuples) of (coord, dim)
            tuples.  The inputs are NOT modified.

    Returns:
        list of renumbered tuples, groups concatenated in the given order;
        within a group the tuples stay in their input order.
    """
    dim_list = []
    for dimset in dimsets:
        dimset = list(dimset)
        offset = len(dim_list)
        # The new dim of each pair is its RANK within the group.  A single
        # argsort gives the inverse permutation instead, which is wrong for
        # three or more dims that are not already in ascending order.
        order = sorted(range(len(dimset)), key=lambda j: dimset[j][1])
        rank = [0] * len(dimset)
        for r, j in enumerate(order):
            rank[j] = r
        for j, pair in enumerate(dimset):
            dim_list.append(tuple(pair[:1]) + (rank[j] + offset,) + tuple(pair[2:]))
    return dim_list


def grid_latlon(cube):
    """Rename *cube*'s 'latitude'/'longitude' DimCoords to grid_latitude/grid_longitude.

    Only names found among the cube's dimension coordinates (by
    standard_name) are changed; the cube is modified in place and None is
    returned.
    """
    names = dimcoordnames(cube)
    for plain, grid in (('latitude', 'grid_latitude'), ('longitude', 'grid_longitude')):
        if plain in names:
            cube.coord(plain).standard_name = grid


def plain_latlon(cube):
    """Rename *cube*'s grid_latitude/grid_longitude DimCoords to 'latitude'/'longitude'.

    Only names found among the cube's dimension coordinates (by
    standard_name) are changed; the cube is modified in place and None is
    returned.
    """
    names = dimcoordnames(cube)
    for grid, plain in (('grid_latitude', 'latitude'), ('grid_longitude', 'longitude')):
        if grid in names:
            cube.coord(grid).standard_name = plain
        

def _nemo_cell_area(mesh_file):
    """Return ``e1t * e2t`` read from *mesh_file*, closing the dataset afterwards."""
    import netCDF4 as nc4

    ncid = nc4.Dataset(mesh_file)
    try:
        # [:] reads both variables into memory, so the product outlives close().
        return ncid.variables['e1t'][:] * ncid.variables['e2t'][:]
    finally:
        ncid.close()


def area_weights(cube, gridname):
    """Return an array of cell-area weights for *cube* on the named grid.

    Args:
        cube: iris cube the weights should match.
        gridname: 'orca025' / 'orca025_ext' read e1t*e2t from a fixed
            mesh-mask file and tile it over the cube's non lat/lon
            dimensions; 'n96', 'n216' and 'hadgom' use
            ``iris.analysis.cartography.area_weights`` (guessing lat/lon
            bounds first via llbounds, which modifies *cube*).

    Returns:
        numpy array of weights, or None (after printing a message) for an
        unrecognised *gridname*.
    """
    import numpy as np

    normal_grids = ['n96', 'n216', 'hadgom']

    if gridname == 'orca025':
        area = _nemo_cell_area('/project/ujcc/CDFTOOLS/mesh_ORCA025L75/mesh_mask_GO5.nc')
        # NOTE(review): unlike orca025_ext, no leading index is dropped here;
        # confirm e1t*e2t is 2-D in the GO5 mesh file.
        weights = area[0:1020, 1:1441]
    elif gridname == 'orca025_ext':
        area = _nemo_cell_area('/project/ujcc/CDFTOOLS/mesh_ORCA025extL75/mesh_mask.nc')
        weights = area[0, 0:1206, 1:1441]
    elif gridname in normal_grids:
        # `import iris` alone does not guarantee the cartography submodule is loaded.
        import iris.analysis.cartography
        llbounds(cube)
        return iris.analysis.cartography.area_weights(cube)
    else:
        print("We don't know how to deal with this grid yet")
        return None

    # Repeat the 2-D mesh weights along every dimension except lat/lon.
    # np.tile prepends the repeated axes, so this assumes latitude and
    # longitude are the cube's trailing dimensions -- TODO confirm.
    tile_reps = list(cube.shape)
    tile_reps[dimension(cube, 'latitude')] = 1
    tile_reps[dimension(cube, 'longitude')] = 1
    return np.tile(weights, tile_reps)


def cubelist_anomaly(cubelist1, cubelist2):
    """Return a copy of *cubelist1* whose data are ``cube1.data - cube2.data``.

    Cubes are paired by position.  Both data arrays are wrapped as masked
    arrays before subtracting, so existing masks propagate to the result.
    The input cubes are not modified (the result is a deep copy of
    *cubelist1*).

    Returns:
        The new list, or None (after printing a message) if the lists have
        different lengths, or any pair differs in long_name or data shape.
    """
    import copy
    import numpy.ma as ma

    if len(cubelist1) != len(cubelist2):
        # Previously this only printed and returned a list whose trailing
        # cubes were left undifferenced.
        print('Cubelists are of different lengths')
        return None

    # Validate every pair before paying for the deep copy.
    for cube1, cube2 in zip(cubelist1, cubelist2):
        if cube1.long_name != cube2.long_name:
            print('Cubes appear to represent different phenomena')
            return None
        if cube1.data.shape != cube2.data.shape:
            print('Cubes appear to have differently-shaped data')
            return None

    new_cubelist = copy.deepcopy(cubelist1)
    for i, (cube1, cube2) in enumerate(zip(cubelist1, cubelist2)):
        new_cubelist[i].data = ma.masked_array(cube1.data) - ma.masked_array(cube2.data)

    return new_cubelist


def tile_cube(cube_to_tile, template_cube):
    """Extend *cube_to_tile* along the dimensions of *template_cube* it lacks.

    Template DimCoords whose standard_name is not shared by both cubes (as
    reported by common_coords) become new leading dimensions, ordered as in
    the template; cube_to_tile's own dimensions follow.  The data are
    repeated along the new dimensions with ``np.tile``.

    Returns:
        A new iris Cube holding only the tiled data and (copied) dim coords,
        or None (after printing a message) if check_coords finds a
        dimension without a DimCoord on either cube.
    """
    import numpy as np

    try:
        check_coords([cube_to_tile, template_cube])
    except ie.IrisError:
        # Previously this was reported but execution carried on.
        print('All dimensions of cube_to_tile and template_cube must have a DimCoord')
        return None

    # Work on copies: these are the cubes' internal lists and must not be
    # modified by the renumbering below.
    tile_dc = list(cube_to_tile._dim_coords_and_dims)
    template_dc = list(template_cube._dim_coords_and_dims)
    shared = common_coords([cube_to_tile, template_cube])

    new_dims_to_add_dc = [dc for dc in template_dc if dc[0].standard_name not in shared]
    # New dims first (0..k-1), then cube_to_tile's dims.  np.tile prepends
    # axes, so this ordering is what makes the tiled data line up.
    new_dims_dc = dimension_pair_sort((new_dims_to_add_dc, tile_dc))

    existing = set(dimcoordnames(cube_to_tile))
    # Integer repetitions: np.tile rejects a float reps array.
    tile_shape = np.ones(len(new_dims_dc), dtype=int)
    for pair in new_dims_dc:
        coord, dim = pair[0], pair[1]
        if coord.standard_name not in existing:
            tile_shape[dim] = len(coord.points)

    new_cube = iris.cube.Cube(np.tile(cube_to_tile.data, tile_shape))
    for pair in new_dims_dc:
        # Copy so the new cube does not share coord objects with the inputs.
        new_cube.add_dim_coord(pair[0].copy(), pair[1])

    return new_cube


def insert_slice_into_cube(cubeslice, cube):
    """Write the data of *cubeslice* into the matching position of *cube*.

    Requirements (checked by standard_name):
      * every DimCoord of *cubeslice* is also a DimCoord of *cube* and the
        two coordinates compare equal;
      * every other DimCoord of *cube* has a same-named coordinate on
        *cubeslice* whose first point equals one of *cube*'s points; the
        first such index is where the slice is written.

    ``cube.data`` is modified in place.  Only raw data values are copied:
    any mask on ``cubeslice.data`` is ignored and the mask of a masked
    ``cube.data`` is left unchanged.

    Returns:
        None on success, or 0 (after printing a message) on failure.
    """
    import numpy as np

    slice_dims = dimcoordnames(cubeslice)
    cube_dims = dimcoordnames(cube)

    if not set(slice_dims) <= set(cube_dims):
        print('Error: slice dimensions must be a subset of cube dimensions')
        return 0
    for dim in (set(slice_dims) & set(cube_dims)):
        if cubeslice.coord(dim) != cube.coord(dim):
            print('Error: dimensions common to cube and slice must be identical')
            return 0

    cubedim_dict = dict((pair[0].standard_name, pair[1]) for pair in cube._dim_coords_and_dims)
    slicedim_dict = dict((pair[0].standard_name, pair[1]) for pair in cubeslice._dim_coords_and_dims)

    # Index into cube.data: an integer for each dimension the slice lacks,
    # a full slice for each dimension the slice spans.
    position = [0] * cube.ndim
    for dim in set(cube_dims) - set(slice_dims):
        dim_value = cubeslice.coord(dim).points[0]
        matches = np.where(cube.coord(dim).points == dim_value)[0]
        # np.where returns a 1-tuple, so the old len(index)==0 test never fired.
        if matches.size == 0:
            print('Error: slice scalar coordinates do not match up to cube dim coordinates')
            return 0
        position[cubedim_dict[dim]] = int(matches[0])

    slice_to_cube = np.zeros(cubeslice.ndim, dtype='int64')
    for dim, slice_axis in slicedim_dict.items():
        cube_axis = cubedim_dict[dim]
        slice_to_cube[slice_axis] = cube_axis
        position[cube_axis] = slice(None)

    # The indexed view has the slice's axes in ascending cube-axis order, so
    # permute the slice data to match.  np.asarray strips masks while still
    # sharing memory with cube.data, so the write lands in place.
    values = np.transpose(np.asarray(cubeslice.data), np.argsort(slice_to_cube))
    target = np.asarray(cube.data)
    target[tuple(position)] = values


def check_coords(cube_list):
    """Raise iris.exceptions.IrisError unless every dimension of every cube has a DimCoord.

    A cube passes when its ``ndim`` equals the number of entries in its
    ``_dim_coords_and_dims``.  Returns None when all cubes pass.
    """
    if any(cube.ndim != len(cube._dim_coords_and_dims) for cube in cube_list):
        raise ie.IrisError


def fix_time_from_month(cube):
    """Move each time point to 00:00 on the 15th of the month it is labelled with.

    Intended for cubes aggregated by the 'month' coordinate, where the time
    points no longer reflect the month.  Each ``month`` point (an
    abbreviated month name parsed with ``strptime('%b')``) sets the new
    date; the year is always 1979 and carries no meaning.  Time units whose
    name starts with 'hour' are shifted in hours, all others in days.  The
    time coordinate's points are modified in place (bounds are untouched).
    """
    from time import strptime
    import datetime as dt
    import numpy as np

    dummy_year = 1979
    dummy_day = 15

    time_coord = cube.coord('time')
    cube_months = [strptime(month_name, '%b').tm_mon for month_name in cube.coord('month').points]

    old_ddt_reps = [time_coord.units.num2date(pt) for pt in time_coord.points]
    new_datetimes = [dt.datetime(dummy_year, month, dummy_day, 0, 0) for month in cube_months]
    old_datetimes = [dt.datetime(a.year, a.month, a.day, a.hour, a.minute, a.second)
                     for a in old_ddt_reps]

    # Use the exact offset: timedelta.days drops the time-of-day part
    # (flooring negative offsets), which left points up to a day away from
    # the 15th whenever the old time was not at midnight.
    if time_coord.units.name[0:4] == 'hour':
        seconds_per_unit = 3600.
    else:
        seconds_per_unit = 86400.
    diffs = [(new - old).total_seconds() / seconds_per_unit
             for (new, old) in zip(new_datetimes, old_datetimes)]

    time_coord.points = time_coord.points + np.array(diffs)


def easy_concatenate(cubelist):
    """Join cubes end to end along their first (index 0) dimension.

    The DimCoord on dimension 0 of ``cubelist[0]`` names the concatenation
    coordinate (looked up by standard_name on every cube).  Cubes are
    ordered by its first point and their data stacked into a new masked
    array.  Dim coords on the other dimensions and aux coords are copied
    from ``cubelist[0]``.  Aux coords spanning dimension 0 are dropped when
    the result is longer than ``cubelist[0]`` along it, since their length
    would not fit.  A fresh DimCoord holding the concatenated points is
    added on dimension 0.

    Args:
        cubelist: sequence of iris cubes of equal rank whose lengths agree
            on every dimension except the first.

    Returns:
        The new cube, or None (after printing a message) if the list is
        empty, ranks or trailing lengths differ, the result would exceed
        2e8 elements, or ``cubelist[0]`` has no DimCoord on dimension 0.
    """
    import numpy as np
    import numpy.ma as ma

    if len(cubelist) == 0:
        print('Cubelist is empty')
        return None

    cube_shapes = [tuple(cc.shape) for cc in cubelist]
    if len(set(len(shape) for shape in cube_shapes)) != 1:
        print('Cubes in cubelist are not of equal rank')
        return None
    rank = len(cube_shapes[0])
    if rank == 0:
        print('Cubes in cubelist have no dimension to concatenate along')
        return None

    # Build an integer shape: ma.zeros / np.zeros reject float shapes.
    new_dims = [sum(shape[0] for shape in cube_shapes)]
    for dim in range(1, rank):
        lengths = set(shape[dim] for shape in cube_shapes)
        if len(lengths) > 1:
            print('Lengths of dim ' + str(dim) + ' are not all equal.')
            return None
        new_dims.append(lengths.pop())
    new_dims = tuple(int(n) for n in new_dims)

    # np.prod replaces np.product (removed in NumPy 2.0); float avoids overflow.
    if np.prod(new_dims, dtype=float) > 2.e8:
        print('New cube will be very large and may cause memory error')
        return None

    first = cubelist[0]
    new_cube = iris.cube.Cube(ma.zeros(new_dims))

    dcname = None
    found_lead = False
    for coordtpl in first._dim_coords_and_dims:
        coord, dim = coordtpl[0], coordtpl[1]
        if dim == 0:
            dcname = coord.standard_name
            found_lead = True
        else:
            new_cube.add_dim_coord(coord.copy(), dim)
    if not found_lead:
        print('First cube has no DimCoord on dim 0')
        return None

    for coordtpl in first._aux_coords_and_dims:
        coord, dims = coordtpl[0], coordtpl[1]
        if 0 in dims and new_dims[0] != first.shape[0]:
            continue
        new_cube.add_aux_coord(coord.copy(), dims)

    dcunits = first.coord(dcname).units
    start_points = np.array([cc.coord(dcname).points[0] for cc in cubelist])
    cube_order = np.argsort(start_points)

    concat_points = np.zeros(new_dims[0])
    counter = 0
    for item in cube_order:
        cube = cubelist[item]
        n = cube.shape[0]
        concat_points[counter:counter + n] = cube.coord(dcname).points
        block = cube.data
        if not hasattr(block, 'mask'):
            # Wrap plain arrays so this span of the result is explicitly unmasked.
            block = ma.masked_array(block, mask=False)
        new_cube.data[counter:counter + n, ...] = block
        counter += n

    new_cube.add_dim_coord(iris.coords.DimCoord(concat_points, dcname, units=dcunits), 0)
    return new_cube


def record_collapse(cube, coords, method, **kwargs):
    """Collapse each dimension-0 slice of a rank-3 cube, then merge the results.

    Each ``cube[i, :, :]`` is collapsed over *coords* with *method* (extra
    keyword arguments are forwarded to ``Cube.collapsed``).  The index of
    each step is printed as it runs.

    Returns:
        The single merged cube, or None (after printing a message) if *cube*
        is not rank 3 or the collapsed slices do not merge into one cube.
    """
    import iris

    if len(cube.shape) != 3:
        print('This can only be called for a cube of rank 3')
        return None

    collapsed = iris.cube.CubeList([])
    for step in range(cube.shape[0]):
        print(step)
        collapsed.append(cube[step, :, :].collapsed(coords, method, **kwargs))

    merged = collapsed.merge()
    if len(merged) != 1:
        print('Could not merge resulting cube')
        return None
    return merged[0]


def find_field(files, varname, allow_multiple=False):
    """Load *files* with iris and return the cubes that match *varname*.

    Matching is delegated to field_index (see its criteria).

    Returns:
        An iris CubeList of the matching cubes, or None when field_index
        finds no match (or several matches and *allow_multiple* is false).
    """
    import iris

    cubes = iris.load(files)
    index = field_index(cubes, varname, allow_multiple=allow_multiple)
    if index is None:
        return None
    return iris.cube.CubeList([cubes[ii] for ii in index[0]])


def field_index(cubes, varname, allow_multiple=False):
    """Locate the cubes in *cubes* that represent the variable *varname*.

    With ``uvar = variables.UniqueVariable(varname)``, a cube matches when
    all of the following hold:
      * its standard_name equals ``uvar.iris_standard_name`` or is None;
      * its 'STASH' attribute equals ``uvar.stash_code`` or is absent;
      * its long_name is ``uvar.CICE_long_name`` or ``uvar.UM_long_name``,
        or is None.

    Returns:
        The ``np.where`` tuple of matching indices when exactly one cube
        matches, or when several match and *allow_multiple* is true;
        otherwise prints a message and returns None.
    """
    import variables
    import numpy as np

    uvar = variables.UniqueVariable(varname)

    standard_ok = np.array([(cube.standard_name == uvar.iris_standard_name)
                            or cube.standard_name is None for cube in cubes])
    long_ok = np.array([cube.long_name in [uvar.CICE_long_name, uvar.UM_long_name]
                        or cube.long_name is None for cube in cubes])
    stash_ok = np.array([cube.attributes['STASH'] == uvar.stash_code
                         if 'STASH' in cube.attributes else True
                         for cube in cubes])

    selected = np.ones(len(cubes), dtype='bool')
    for criterion in (standard_ok, stash_ok, long_ok):
        selected = np.logical_and(selected, criterion)

    n_selected = np.sum(selected)
    if n_selected == 1 or (n_selected > 1 and allow_multiple):
        return np.where(selected)
    if n_selected > 1:
        print('Unable to uniquely identify variable')
    else:
        print('Could not find variable in this cube list')
    return None
        
        
def fields_present(ffile, varnames):
    """Report whether every name in *varnames* can be identified in *ffile*.

    Loads *ffile* with ``iris.load`` and runs field_index (single-match
    mode) for each name; field_index may print a message for names it
    cannot resolve.

    Returns:
        numpy bool: True when every name resolves to exactly one cube
        (also True for an empty *varnames*).
    """
    import iris
    import variables
    import numpy as np

    cubes = iris.load(ffile)
    found = [field_index(cubes, name) is not None for name in varnames]
    return np.all(found)
        
        
def fields_present_multi(files, varnames):
    """Report whether every name in *varnames* can be identified across *files*.

    The cubes from every file (each loaded with ``iris.load``) are pooled
    into one list before running field_index (single-match mode) for each
    name.

    Returns:
        numpy bool: True when every name resolves to exactly one cube in the
        pooled list (also True for an empty *varnames*).
    """
    import iris
    import variables
    import numpy as np

    all_cubes = iris.cube.CubeList([])
    for ffile in files:
        all_cubes.extend(iris.load(ffile))

    found = [field_index(all_cubes, name) is not None for name in varnames]
    return np.all(found)


def sc_translate(cube, start_month=1):
    """Reorder a 12-month climatology so it starts at *start_month* and retime it.

    Assumes *cube* is 3-D with the months along dimension 0 (the data are
    indexed ``[i, :, :]``) and has a 'month' coordinate of abbreviated
    month names, matched against ``strftime('%b')``.  Position i of the
    result holds month ``start_month + i`` (wrapping past December).

    The 'time' (if any) and 'month' coordinates of the copy are replaced by
    a new time DimCoord on dimension 0 with the 1st of each month, in days
    since 1978-09-01.  Months from *start_month* to December fall in 1980
    and those after the wrap in 1981, so time stays increasing.  A 'month'
    coordinate is then re-derived with ``add_month``.  *cube* itself is not
    modified.

    Returns:
        The new cube, or None (after printing a message) if a month is
        missing from, or repeated in, the 'month' coordinate.
    """
    import datetime as dt
    import numpy as np
    import copy
    import iris.coord_categorisation as icc

    first = (start_month - 1) % 12
    enumerate_months = ((np.arange(12) + first) % 12) + 1

    new_cube = copy.deepcopy(cube)
    try:
        new_cube.remove_coord('time')
    except iris.exceptions.CoordinateNotFoundError:
        pass

    new_cube.remove_coord('month')

    ref_date = dt.date(1978, 9, 1)
    ref_units = ref_date.strftime('days since %Y-%m-%d')
    month_labels = cube.coord('month').points

    tpoints = []
    for position, month in enumerate(enumerate_months):
        month = int(month)
        month_name = dt.date(1980, month, 1).strftime('%b')
        matches = np.where(month_labels == month_name)[0]
        if matches.size != 1:
            print('Expected exactly one ' + month_name + ' field, found ' + str(matches.size))
            return None
        # Previously data went to index month-1, which silently ignored start_month.
        new_cube.data[position, :, :] = cube.data[matches[0], :, :]
        year = 1980 + (first + position) // 12
        tpoints.append(dt.date(year, month, 1).toordinal() - ref_date.toordinal())

    tcoord = iris.coords.DimCoord(tpoints, 'time', units=ref_units)
    new_cube.add_dim_coord(tcoord, 0)
    icc.add_month(new_cube, 'time')
    return new_cube
