import numpy as np

def get_tcoord(cube):
    """Return the time coordinate of *cube*.

    The coordinate names 'time' and then 't' are tried in that order via
    ``cube.coord``; the first one found is returned.

    Args:
        cube: the cube to search (presumably an ``iris.cube.Cube``).

    Returns:
        The matching coordinate, or None (after printing a message) when
        neither name is present on the cube.
    """
    import iris
    for time_name in ('time', 't'):
        try:
            return cube.coord(time_name)
        except iris.exceptions.CoordinateNotFoundError:
            # Not under this name; try the next candidate.
            continue

    # Single-argument print(...) behaves identically on Python 2 and 3.
    print('Could not identify time coordinate')
    return None


def tday(cube):
    """Return the day of month of every time point of *cube* as a numpy array.

    Each point of the coordinate returned by ``get_tcoord`` is converted
    with that coordinate's ``units.num2date`` and its ``.day`` is taken.
    """
    import numpy as np
    time_coord = get_tcoord(cube)
    days = []
    for point in time_coord.points:
        days.append(time_coord.units.num2date(point).day)
    return np.array(days)

def tmonth(cube):
    """Return the month number of every time point of *cube* as a numpy array.

    Each point of the coordinate returned by ``get_tcoord`` is converted
    with that coordinate's ``units.num2date`` and its ``.month`` is taken.
    """
    import numpy as np
    time_coord = get_tcoord(cube)

    def month_of(point):
        return time_coord.units.num2date(point).month

    return np.array([month_of(point) for point in time_coord.points])

def tyear(cube):
    """Return the year of every time point of *cube* as a numpy array.

    Each point of the coordinate returned by ``get_tcoord`` is converted
    with that coordinate's ``units.num2date`` and its ``.year`` is taken.
    """
    import numpy as np
    time_coord = get_tcoord(cube)
    years = [time_coord.units.num2date(point).year
             for point in time_coord.points]
    return np.array(years)

def iindex(cube,month=None,startyear=None,endyear=None):
    """Build a boolean selection mask over the leading axis of *cube*.

    Args:
        cube: cube whose first dimension is indexed by the returned mask
            (assumed to be the time axis -- TODO confirm for callers).
        month: keep only points whose month equals this value; a falsy
            value (None/0) disables the month filter.
        startyear: keep only points with year >= startyear; None disables.
        endyear: keep only points with year <= endyear; None disables.

    Returns:
        numpy bool array that is True where every requested filter passes
        (all True when no filter is requested).
    """
    import numpy as np
    n_leading = cube.shape[0]
    keep = np.ones(n_leading, dtype=bool)

    if month:
        keep = keep & (tmonth(cube) == month)

    # Compare year bounds against None rather than by truthiness so that a
    # bound of year 0 (valid in some model calendars) is not silently dropped.
    if startyear is not None or endyear is not None:
        years = tyear(cube)  # compute once for both bounds
        if startyear is not None:
            keep = keep & (years >= startyear)
        if endyear is not None:
            keep = keep & (years <= endyear)

    return keep
     

def reduce(cube,month=None,startyear=None,endyear=None):
    """Return a copy of *cube* restricted to a month and/or year range.

    Filters are applied with ``iris.Constraint`` on the 'time' coordinate,
    one after another, on a deep copy of the input (the input is never
    modified).

    Args:
        cube: the cube to filter.
        month: keep only cells whose month equals this value; a falsy
            value (None/0) disables the month filter.
        startyear: keep cells with year >= startyear; None disables.
        endyear: keep cells with year <= endyear; None disables.

    Returns:
        The filtered cube, or None as soon as a filter leaves no time
        points (``cube.extract`` returned None).
    """
    import iris
    import copy

    # One predicate per requested filter, applied in month/start/end order.
    # Year bounds are compared to None so that year 0 is honoured.
    tests = []
    if month:
        tests.append(lambda cell: cell.point.month == month)
    if startyear is not None:
        tests.append(lambda cell: cell.point.year >= startyear)
    if endyear is not None:
        tests.append(lambda cell: cell.point.year <= endyear)

    tempcube = copy.deepcopy(cube)
    for test in tests:
        # NOTE(review): FUTURE.cell_datetime_objects exists only in older
        # iris releases; kept as-is to match the rest of this module.
        with iris.FUTURE.context(cell_datetime_objects = True):
            tempcube = tempcube.extract(iris.Constraint(time = test))
        if tempcube is None:
            return None

    return tempcube


def simple_reduce(cube,month=None,startyear=None,endyear=None):
    """Index *cube* along its first axis with the boolean mask from ``iindex``.

    Unlike ``reduce`` this does not use iris constraints; it indexes the cube
    directly with the mask that ``iindex`` returns for the same arguments.
    """
    selection = iindex(cube, month=month, startyear=startyear, endyear=endyear)
    return cube[selection]
    
    
def sc_1D(in_cube,required_years=(),method='mean'):
    """Compute a 12-point seasonal cycle from a 1-D time-series cube.

    For each calendar month the data points falling in that month are
    reduced with ``method``.

    Args:
        in_cube: 1-D cube with a time coordinate found by ``get_tcoord``.
        required_years: iterable of years; a month is masked unless every one
            of these years contributes at least one point to it. Empty (the
            default) or None disables the check.
        method: 'mean' (``numpy.ma.mean``) or 'stddev' (``numpy.ma.std``).

    Returns:
        A 12-element masked cube with the input's long_name and units and a
        'time'/'month' dimension coordinate of 1..12. A month with no data
        at all is masked. Returns None (after printing a message) if
        ``in_cube`` has more than one dimension.

    Raises:
        KeyError: if ``method`` is not 'mean' or 'stddev'.
    """
    import numpy as np
    import numpy.ma as ma
    import iris

    if len(in_cube.shape) > 1:
        print('This should only be used for 1-dimensional cubes')
        return None

    reducers = {'mean': ma.mean, 'stddev': ma.std}
    reducer = reducers[method]

    times = get_tcoord(in_cube)
    months = np.array([times.units.num2date(pt).month for pt in times.points])
    years = np.array([times.units.num2date(pt).year for pt in times.points])
    required = set(required_years) if required_years is not None else set()

    in_array = in_cube.data

    sc_ma_array = ma.masked_array(np.zeros(12), mask=np.zeros(12, dtype=bool))

    for mno in range(12):
        in_month = (months == mno + 1)
        if not in_month.any():
            # Reducing an empty selection would store NaN (and emit a
            # RuntimeWarning); mask the month instead.
            sc_ma_array[mno] = ma.masked
            continue

        sc_ma_array[mno] = reducer(in_array[in_month])

        if required:
            present = set(years[in_month].tolist())
            if not required.issubset(present):
                sc_ma_array.mask[mno] = True

    sc_ma_cube = iris.cube.Cube(sc_ma_array)
    sc_ma_cube.long_name = in_cube.long_name
    sc_ma_cube.units = in_cube.units

    month_coord = iris.coords.DimCoord(range(1,13),'time',long_name='month')
    sc_ma_cube.add_dim_coord(month_coord,0)

    return sc_ma_cube
    
    
def annual_mean_1D(in_cube,required_months=tuple(range(1, 13))):
    """Compute the mean of a 1-D time-series cube for each distinct year.

    Args:
        in_cube: 1-D cube with a time coordinate found by ``get_tcoord``.
        required_months: iterable of month numbers; a year's mean is masked
            unless every one of these months has at least one point in that
            year. Empty or None disables the check. Defaults to all twelve
            months.

    Returns:
        A masked cube with one value per distinct year (years in ascending
        order, as produced by ``np.unique``), carrying the input's long_name
        and units.
    """
    import numpy as np
    import numpy.ma as ma
    import iris

    times = get_tcoord(in_cube)
    months = np.array([times.units.num2date(pt).month for pt in times.points])
    years = np.array([times.units.num2date(pt).year for pt in times.points])
    # np.unique already returns its values sorted ascending.
    unique_years = np.unique(years)
    n_years = len(unique_years)
    required = set(required_months) if required_months is not None else set()

    in_array = in_cube.data

    ts_am_ma_array = ma.masked_array(np.zeros(n_years),
                                     mask=np.zeros(n_years, dtype=bool))

    for yno, year in enumerate(unique_years):
        in_year = (years == year)
        ts_am_ma_array[yno] = ma.mean(in_array[in_year])

        # Honour the caller's required_months (previously ignored in favour
        # of a hard-coded 1..12).
        if required:
            present = set(months[in_year].tolist())
            if not required.issubset(present):
                ts_am_ma_array.mask[yno] = True

    ts_am_ma_cube = iris.cube.Cube(ts_am_ma_array)
    ts_am_ma_cube.long_name = in_cube.long_name
    # A mean has the same units as its input (consistent with sc_1D).
    ts_am_ma_cube.units = in_cube.units

    return ts_am_ma_cube


def equalise_time_and_fp(cube,ref_cube):
    """Return a deep copy of *cube* whose time metadata is taken from *ref_cube*.

    The 'time' coordinate's points, bounds and units (in that order) are
    copied from ``ref_cube``. The 'forecast_period' coordinate is handled
    the same way when both cubes have one; if only ``cube`` has it, it is
    removed from the copy. ``cube`` itself is not modified.
    """
    import copy

    def standard_names(some_cube):
        return [crd.standard_name for crd in some_cube.coords()]

    def copy_values(target_cube, name):
        target = target_cube.coord(name)
        source = ref_cube.coord(name)
        target.points = source.points
        target.bounds = source.bounds
        target.units = source.units

    result = copy.deepcopy(cube)
    copy_values(result, 'time')

    if 'forecast_period' not in standard_names(cube):
        return result

    if 'forecast_period' in standard_names(ref_cube):
        copy_values(result, 'forecast_period')
    else:
        result.remove_coord('forecast_period')

    return result


def mdic(mode='short'):
    """Map month names to month numbers 1..12.

    Names come from ``strftime`` ('%b' for mode 'short', '%B' for mode
    'long'), so they follow the process's LC_TIME locale.

    Args:
        mode: 'short' (e.g. 'Jan') or 'long' (e.g. 'January').

    Returns:
        dict of name -> month number, or 0 (after printing a message listing
        the valid modes) when ``mode`` is not recognised.
    """
    import datetime as dt
    formats = {'short': '%b', 'long': '%B'}

    try:
        fmt = formats[mode]
    except KeyError:
        print('Invalid mode argument')
        print('Try ' + ', '.join(formats.keys()))
        return 0

    # Only the requested mapping is built; year 1980 is an arbitrary anchor.
    return {dt.date(1980, ii, 1).strftime(fmt): ii for ii in range(1, 13)}
    

def mtrans(mname_array,mode='short'):
    """Translate a sequence of month names into a numpy array of month numbers.

    Args:
        mname_array: iterable of month names as produced by ``mdic(mode)``.
        mode: 'short' or 'long'; forwarded to ``mdic`` (previously ignored,
            so long names could never be translated).

    Returns:
        numpy array of month numbers, one per input name.

    Raises:
        ValueError: if ``mode`` is not recognised by ``mdic``.
        KeyError: if a name is not a month name for the selected mode.
    """
    import numpy as np
    lookup = mdic(mode=mode)  # build the mapping once, not per element
    if not isinstance(lookup, dict):
        # mdic signals an invalid mode by returning 0.
        raise ValueError('Invalid mode argument: %r' % (mode,))
    return np.array([lookup[mname] for mname in mname_array])


def sc_for_plotting(sc_cube, start_month = 1):
    """Return the seasonal-cycle data rotated to begin at *start_month*.

    The cube's 'month' coordinate points are translated to numbers with
    ``mtrans``. The data is then reordered to follow the twelve months
    start_month, start_month+1, ... (wrapping modulo 12). An IndexError is
    raised if one of those months is absent from the coordinate.
    """
    month_numbers = mtrans(sc_cube.coord('month').points)

    positions = []
    for offset in range(12):
        wanted = (start_month - 1 + offset) % 12 + 1
        positions.append(np.flatnonzero(month_numbers == wanted)[0])

    return sc_cube.data[positions]


def time2D(cube):
    """Reshape a monthly time-series cube into a (year, month) cube.

    The input must hold exactly twelve points per distinct year, in
    chronological order, with months 1..12 in calendar order inside each
    year.

    Returns:
        ``(mcoordname, ycoordname, reshaped_cube)`` on success, where the
        cube has a year dimension coordinate (named by ``ycoordname``) on
        axis 0 and a month dimension coordinate 1..12 (named by
        ``mcoordname``) on axis 1. Returns None, after printing a message,
        when the series is not such a perfect seasonal cycle.
    """
    import copy
    import iris

    failure_message = "Can't make time 2D - don't have a perfect seasonal cycle"

    years = tyear(cube)
    months = tmonth(cube)

    unique_years = np.unique(years)  # sorted ascending
    n_years = unique_years.shape[0]

    # Check the size first so a ragged series reports the message instead
    # of crashing with a ValueError inside np.reshape.
    if years.size != n_years * 12:
        print(failure_message)
        return None

    rs_years = np.reshape(years, (n_years, 12))
    rs_months = np.reshape(months, (n_years, 12))

    # Each row must be a single year, rows must be in chronological order
    # (so the sorted unique years label them correctly), and each row must
    # run through months 1..12 in order (the month coordinate built below
    # assumes exactly that order).
    one_year_per_row = np.all(rs_years == rs_years[:, :1])
    rows_chronological = np.array_equal(rs_years[:, 0], unique_years)
    calendar_order = np.all(rs_months == np.arange(1, 13))
    if not (one_year_per_row and rows_chronological and calendar_order):
        print(failure_message)
        return None

    new_data = copy.copy(cube.data)
    reshaped_data = np.reshape(new_data,(n_years,12))
    reshaped_cube = iris.cube.Cube(reshaped_data)
    mcoordname = 'time'
    ycoordname = 'forecast_reference_time'
    month_coord = iris.coords.DimCoord(range(1,13),mcoordname,long_name='month')
    year_coord = iris.coords.DimCoord(unique_years,ycoordname,long_name='year')
    reshaped_cube.add_dim_coord(month_coord,1)
    reshaped_cube.add_dim_coord(year_coord,0)

    reshaped_cube.long_name = cube.long_name
    reshaped_cube.units = cube.units

    return mcoordname, ycoordname, reshaped_cube
    
    
def year_in(cube,period):
    """Report whether any time point of *cube* has a year inside *period*.

    ``period`` is a (first_year, last_year) pair; both ends are included.
    The years ``np.arange(first, last + 1)`` are tested for membership.
    Returns a numpy bool.
    """
    import numpy as np
    candidate_years = np.arange(period[0], period[1] + 1)
    matches = np.in1d(tyear(cube), candidate_years)
    return np.any(matches)
