import numpy as np
import matplotlib as ml
from scipy import interpolate
import matplotlib.pyplot as plt
from matplotlib import gridspec
from netCDF4 import Dataset
import pickle
import numpy.ma as ma
import os
import f90nml

# ===========================================
def stack_iso_srf(iso,srf,mmask=None):
  """Stack a 2-D field on top of a stack of isochrone fields.

  Parameters
  ----------
  iso : array of shape (n, rows, cols)
    Isochrone layers; stored in layers 1..n of the result.
  srf : array of shape (rows, cols)
    Stored as layer 0 of the result.
  mmask : array of shape (rows, cols), optional
    Mask (nonzero/True = masked) applied to every layer. ``None`` or an
    empty array means no mask.

  Returns
  -------
  ndarray of shape (n+1, rows, cols), or a masked array when ``mmask`` is given.
  """
  array           = np.zeros([iso.shape[0]+1,iso.shape[1],iso.shape[2]])
  array[0,:,:]    = srf
  # Fill every remaining layer, however many isochrones there are.
  array[1:,:,:]   = iso
  if mmask is not None and np.size(mmask) > 0:
    array = ma.masked_array(array,mask=np.tile(mmask,[array.shape[0],1,1]))

  return array

# ===========================================
def find_name_LHS(index):
  """Return the run name of Latin-hypercube ensemble member ``index``.

  The index is zero-padded to at least four characters,
  e.g. 7 -> 'GRL32_LHS_0007'.
  """
  padded = str(index).zfill(4)
  return ''.join(['GRL32_LHS_', padded])

# ===========================================
def find_name(obs_f,f_to,f_hol,f_glac,calv_tau,calv_method):
  """Build a GRL-32KM run name from its parameters.

  Runs with ``f_glac != 1.0`` or ``f_to == 0.6`` use the 'GRL-32KM_2' prefix
  and include ``f_glac`` in the name; all other runs omit ``f_glac``.
  Fields are joined with underscores.
  """
  if f_glac != 1.0 or f_to == 0.6:
    parts = ['GRL-32KM_2', obs_f, f_to, f_hol, f_glac, calv_tau, calv_method]
  else:
    parts = ['GRL-32KM', obs_f, f_to, f_hol, calv_tau, calv_method]
  return '_'.join(str(p) for p in parts)

# ===========================================
def find_name_noiso(obs_f,f_hol,f_glac,f_p,calv_method,abl_method,mm_snow,mm_ice,alb_snow_dry,alb_snow_wet):
  """Build the name of a GRL-32KM run without isochrones.

  With ``f_p == 0.05`` the 'GRL-32KM_noiso' scheme is used, which includes
  the snow albedos but not ``f_p``. Otherwise the 'GRL-32KM_noiso2' scheme
  includes ``f_p`` and ignores ``alb_snow_dry``/``alb_snow_wet``.
  """
  if f_p != 0.05:
    parts = ['GRL-32KM_noiso2', obs_f, f_hol, f_glac, f_p, calv_method,
             abl_method, mm_snow, mm_ice]
  else:
    parts = ['GRL-32KM_noiso', obs_f, f_hol, f_glac, calv_method,
             abl_method, mm_snow, mm_ice, alb_snow_dry, alb_snow_wet]
  return '_'.join(str(p) for p in parts)

# ===========================================
def plot_var(var,ax,my_cmap,cmin,cmax):
  """Draw a 2-D field into ``ax`` with imshow and hide the axis.

  The image uses origin='lower', no interpolation and automatic aspect.
  Colours run from ``cmin`` to ``cmax`` using ``my_cmap``.
  Returns the image handle (e.g. for a colour bar).
  """
  plt.sca(ax)
  image = plt.imshow(var, cmap=my_cmap, vmin=cmin, vmax=cmax,
                     origin='lower', aspect='auto', interpolation='none')
  plt.axis('off')
  return image

# ===========================================
def diso_compare(fig,var1,var2,coastline):
  """Compare two 5-row layer stacks panel by panel.

  For each row r in 0..4 the figure shows ``var1[r]`` (grid column 0),
  ``var2[r]`` (column 1) and ``var1[r] - var2[r]`` (column 3). Each panel
  has the zero contour of ``coastline`` drawn over it. Two colour bars are
  added per row, one for the values and one for the difference. Row 0 is
  labelled as thickness, rows 1-4 as depth.

  Parameters
  ----------
  fig : Figure
    Receives the colour-bar axes. Note the current figure is cleared with
    plt.clf(), so ``fig`` is presumably the current figure.
  var1, var2 : arrays indexable as [row, :, :] with at least 5 rows.
  coastline : 2-D array contoured at level 0 in every panel.

  Returns
  -------
  The GridSpec (5 rows x 4 columns) used for the panels.
  """
  plt.clf()
  gs = gridspec.GridSpec(5, 4, width_ratios=[1,1,1,1],hspace=0,wspace=0)

  # Shared upper colour limit for all value panels: max of var1 rounded up
  # to the next multiple of 100. It does not depend on the row.
  vmax = np.ceil(var1.max()/100.)*100.

  for row in np.arange(0,5):
    # Column 0: var1
    ax_val1 = plt.subplot(gs[(row*4)+0])
    plt.contour(coastline,[0],colors='gray')
    p1 = plot_var(var1[row,:,:],ax_val1,plt.get_cmap('RdPu'),0,vmax)

    # Column 1: var2 on the same colour scale as var1
    ax_val2 = plt.subplot(gs[(row*4)+1])
    plt.contour(coastline,[0],colors='gray')
    plot_var(var2[row,:,:],ax_val2,plt.get_cmap('RdPu'),0,vmax)

    # Column 3: difference, scaled to +-10% of vmax. Column 2 is not drawn
    # into; the colour bars below are placed at x = 0.60-0.62 (figure units).
    ax_diff = plt.subplot(gs[(row*4)+3])
    plt.contour(coastline,[0],colors='gray')
    p2 = plot_var(var1[row,:,:]-var2[row,:,:],ax_diff,plt.get_cmap('bwr'),-.1*vmax,.1*vmax)

    # Colour bars span this row's height, inset by 0.02 top and bottom.
    # Reuse the column-0 axes instead of re-calling plt.subplot on its spec.
    _, row_bottom, _, row_height = ax_val1.get_position().bounds
    y_pos    = row_bottom+.02
    y_height = row_height-.04

    ax1 = fig.add_axes([0.60, y_pos, 0.01, y_height])
    cbar = plt.colorbar(p1,cax=ax1)
    if (row == 0):
      cbar.set_label('thickness (m)')
    else:
      cbar.set_label('depth (m)')
    # Ticks/label on the left so they do not collide with the difference
    # colour bar placed immediately to the right.
    ax1.yaxis.set_label_position('left')
    ax1.yaxis.set_ticks_position('left')

    ax2 = fig.add_axes([0.61, y_pos, 0.01, y_height])
    cbar = plt.colorbar(p2,cax=ax2)
    # Raw strings: '\D' is not a valid escape sequence in a normal string.
    if (row == 0):
      cbar.set_label(r'$\Delta$thickness (m)')
    else:
      cbar.set_label(r'$\Delta$depth (m)')

  return gs

# ===========================================
# ===========================================
class yelmo_run():
  """A single Yelmo simulation: namelist parameters, grid and output readers.

  The run folder is ``output_dir + run_name + '/'``. Outputs are read from
  ``<run_name>.2D.nc`` (netCDF) and ``<run_name>.<year:06d>.<var>.bin``
  (raw float32) inside that folder.
  """

# ===========================================
  def __init__(self,output_dir,run_name):
    """Read the run's namelist and grid description.

    Parameters
    ----------
    output_dir : str
      Folder containing the run folders. It is concatenated directly with
      ``run_name``, so it must end with a path separator.
    run_name : str
      Name of the run folder and of the ``.nml`` namelist inside it.

    If ``<output_dir><run_name>/<run_name>.nml`` does not exist, a message is
    printed and only ``run_name`` and ``path`` are set.
    """
    self.run_name = run_name
    self.path = output_dir + self.run_name + '/'

    config_file = self.path + self.run_name + '.nml'
    if not os.path.isfile(config_file):
      print(self.path + self.run_name + '.nml does not exist.')
      return

    self.config         = f90nml.read(config_file)
    self.grid_name      = self.config['yelmo','grid_name']
    self.time_init      = self.config['ctrl','time_init_final']
    self.time_end       = self.config['ctrl','time_end']
    self.runtime        = self.time_end-self.time_init
    self.dpr_hol        = self.config['ctrl','dpr_hol']
    self.obs_f          = self.config['marine_shelf','obs_f']
    self.f_to           = self.config['snap','f_to']
    self.f_stddev       = self.config['snap','f_stdev']
    self.f_hol          = self.config['snap_hybrid','f_hol']
    self.f_glac         = self.config['snap_hybrid','f_glac']
    self.calv_tau       = self.config['ytopo','calv_tau']
    self.calv_method    = self.config['ytopo','calv_method']
    self.abl_method     = self.config['smbpal','abl_method']
    self.mm_snow        = self.config['smbpal','mm_snow']
    self.mm_ice         = self.config['smbpal','mm_ice']
    self.enh_shear      = self.config['ymat','enh_shear']
    self.enh_glac       = self.config['ice_enh','enh_glac']
    self.enh_int        = self.config['ice_enh','enh_int']
    self.ghf_const      = self.config['geothermal','ghf_const']

    # Isochrone settings exist only for runs with a 'yiso' namelist group.
    # kmax is the number of layers load_bin() expects in 3-D .bin files.
    if "yiso" in self.config:
      self.dt_iso       = self.config['yiso','smb_modulo']
      self.init_layers  = self.config['yiso','n_layers_init']
      self.yiso_out     = self.config['yiso','yiso_out']
      self.kmax         = int(self.runtime/self.dt_iso+self.init_layers)

    if self.grid_name == 'GRL-16KM':
      self.dx           = 16       # horizontal grid spacing in km
      self.imax         = 106
      self.jmax         = 181
    elif self.grid_name == 'GRL-32KM':
      self.dx           = 32       # horizontal grid spacing in km
      self.imax         = 54
      self.jmax         = 91

    gridinfo_filename = self.path + 'ice_data/Greenland/' + self.grid_name + '/' + self.grid_name + '_REGIONS.nc'
    # Context manager: the file is closed even if a variable is missing.
    with Dataset(gridinfo_filename, 'r') as f:
      # +360 presumably maps Greenland's negative longitudes to the 0-360
      # convention that indx_latlon() converts query longitudes to.
      self.lon  = f.variables['lon2D'][:].copy()+360
      self.lat  = f.variables['lat2D'][:].copy()
      self.area = f.variables['area'][:].copy()

    return


# ===========================================
  def find_name(self):
    """Rebuild ``self.run_name`` from the run's parameters.

    Same naming scheme as the module-level find_name() function.
    """
    if self.f_glac == 1.0 and self.f_to != 0.6:
      parts = ['GRL-32KM', self.obs_f, self.f_to, self.f_hol, self.calv_tau, self.calv_method]
    else:
      parts = ['GRL-32KM_2', self.obs_f, self.f_to, self.f_hol, self.f_glac, self.calv_tau, self.calv_method]
    self.run_name = '_'.join(str(p) for p in parts)

    return


# ===========================================
  def does_run_exist(self):
    """Return True if the run folder exists."""
    return os.path.isdir(self.path)


# ===========================================
  def plot_bin(self,year,var,k,ax,my_cmap,cmin,cmax):
    """Plot layer ``k`` of a .bin variable for ``year`` into ``ax``.

    Ice-free cells (``ice_mask`` from load_ice()) are masked out.
    Returns the image handle.
    """
    var_all = self.load_bin(var,year)
    self.load_ice()

    plot_handle = plot_var(ma.masked_array(var_all[k,:],mask=self.ice_mask),ax,my_cmap,cmin,cmax)

    return plot_handle


# ===========================================
  def add_PD_contour(self,lev):
    """Overlay present-day observed contours on the current axes.

    Draws:
    - the 0 surface-elevation contour (black, width 3);
    - the surface-elevation contours ``lev`` (black);
    - the 50 ice-thickness contour (colour '#FF934F', width 3).
    All contours come from yelmo_obs at this run's resolution.
    """
    # Load the observations once instead of once per contour.
    obs = yelmo_obs(self.dx)

    plt.contour(obs.srf,[0],colors='k',linewidths=3)
    plt.contour(obs.srf,levels=lev,colors='k')
    plt.contour(obs.ice,[50],colors='#FF934F',linewidths=3)


# ===========================================
  def plot_nc(self,var,timestep,k,ax,cmin,cmax,cincr,cmap):
    """Plot a field from ``<run_name>.2D.nc`` with discrete colour levels.

    3-D variables are sliced at ``timestep``; 4-D ones at ``timestep`` and
    level ``k``. Colours are binned on boundaries cmin, cmin+cincr, ..., cmax.
    Returns the image handle.
    """
    plotvar = self.load_nc(var)

    if (plotvar.ndim == 3):
      plotvar = plotvar[timestep,:,:]
    elif (plotvar.ndim == 4):
      plotvar = plotvar[timestep,k,:,:]

    # Colour boundaries; '+.5*cincr' keeps cmax as the last boundary despite
    # floating-point round-off in np.arange.
    cnorm = ml.colors.BoundaryNorm(np.arange(cmin,cmax+.5*cincr,cincr),cmap.N)

    plt.sca(ax)
    # vmin/vmax are not passed: the BoundaryNorm already fixes the limits,
    # and Matplotlib >= 3.5 raises if vmin/vmax are combined with norm.
    plot_handle = plt.imshow(plotvar,interpolation='none',aspect='auto',origin='lower',cmap=cmap,norm=cnorm)
    plt.axis('off')

    return plot_handle


# ===========================================
  def load_nc(self,varname):
    """Return a copy of variable ``varname`` from ``<run_name>.2D.nc``."""
    with Dataset(self.path + self.run_name + '.2D.nc', 'r') as ncfile:
      var = ncfile.variables[varname][:].copy()

    return var


# ===========================================
  def load_bin(self,varname,year,is_2D=False):
    """Read the raw float32 file ``<run_name>.<year:06d>.<varname>.bin``.

    Returns an array of shape (jmax, imax) if ``is_2D``, otherwise
    (kmax, jmax, imax). ``kmax`` only exists for runs with a 'yiso' group.
    """
    filename = self.path + self.run_name + '.' + ('%06d' % year) + '.' + varname + '.bin'

    shape = (self.jmax,self.imax) if is_2D else (self.kmax,self.jmax,self.imax)
    return np.fromfile(file=filename, dtype=np.float32).reshape(shape)


# ===========================================
  def load_ice(self):
    """Load the final bed/surface elevations and derive thickness and masks.

    Sets ``bed``, ``srf`` (cells with srf == 0 are set to bed, giving zero
    thickness) and ``ice = srf - bed``. Also sets integer masks that are 1
    where ``ice`` is below a threshold (presumably in metres):
    ``ice_mask`` (< 1), ``mask_1km`` (< 1000), ``mask_1p5km`` (< 1500)
    and ``mask_2km`` (< 2000).
    """
    self.bed = self.load_nc('z_bed')[-1,:]
    self.srf = self.load_nc('z_srf')[-1,:]

    no_srf                 = self.srf==0.
    self.srf[no_srf]       = self.bed[no_srf]
    self.ice               = self.srf-self.bed
    self.ice_mask          = np.where(self.ice<1.   ,1,0)
    self.mask_1km          = np.where(self.ice<1000.,1,0)
    self.mask_1p5km        = np.where(self.ice<1500.,1,0)
    self.mask_2km          = np.where(self.ice<2000.,1,0)

    return


# ===========================================
  def load_age(self,j,i,year=160000):
    """Return a 2-row array describing the yiso layer profile at cell (j, i).

    Row 0 is ``dsum_iso[:, j, i]`` divided by its value in the last layer.
    Row 1 is kmax, kmax-1, ..., 1 times ``dt_iso * 1e-3`` (presumably kyr if
    dt_iso is in years). ``year`` selects the output file; the default is
    the value previously hard-coded.
    """
    dsum = self.load_bin('dsum_iso',year)

    age_depth = np.array((dsum[:,j,i]/dsum[-1,j,i], np.arange(dsum.shape[0],0,-1) * self.dt_iso * 1e-3))

    return age_depth


# ===========================================
  def load_age_nc(self,j,i):
    """Return ``(zeta, -dep_time[-1, :, j, i])`` from the netCDF output."""
    zeta  = self.load_nc('zeta')
    age   = self.load_nc('dep_time')[-1,:,j,i]

    age_depth = np.array((zeta, -1.*age))

    return age_depth


# ===========================================
  def load_vel_masks(self):
    """Build masks separating slow- and fast-flowing ice.

    Uses the final surface speed sqrt(ux_s**2 + uy_s**2). ``slowice_mask``
    is 1 where the speed > 50 or the cell is ice-free (``ice_mask``), so
    only slow ice stays unmasked. ``fastice_mask`` is 1 where the speed
    < 50 or the cell is ice-free. Calls load_ice() first.
    """
    self.load_ice()

    ux_s = self.load_nc('ux_s')[-1,:,:]
    uy_s = self.load_nc('uy_s')[-1,:,:]
    u_s  = (ux_s**2+uy_s**2)**.5

    self.slowice_mask      = np.maximum(np.where(u_s>50,1,0),self.ice_mask)
    self.fastice_mask      = np.maximum(np.where(u_s<50,1,0),self.ice_mask)

# ===========================================
  def indx_latlon(self,coords):
    """Return the (j, i) index of the grid node nearest to ``coords``.

    ``coords`` is ``(lat, lon)``; negative longitudes are shifted by +360.
    Distance is Euclidean in degrees. If the nearest node is 1 degree or
    more away, a message is printed and ``np.nan`` is returned instead.
    """
    lat,lon = coords

    if (lon < 0.):
      lon = lon + 360.

    # find closest node in yelmo:
    dist  = ((self.lat-lat)**2+(self.lon-lon)**2)**.5
    if (np.min(dist) < 1.):
      node = np.unravel_index(np.argmin(dist), dist.shape)
    else:
      print('yelmo_utils::indx_latlon: Did not find nearest node.')
      node = np.nan

    return node

# ===========================================
  def get_data_on_isochrones(self,var,isochrones,year=160000):
    """Extract the yiso layers that correspond to the given isochrone ages.

    The layer index for each age is ``kmax - round(age / dt_iso)``.
    For ``var == 'depth_below_surface'`` the data is ice thickness minus
    ``dsum_iso``; otherwise it is the .bin variable ``var``. ``year``
    selects the output file (default: the previously hard-coded value).
    """
    iso_indx = self.kmax - (np.array(isochrones)/self.dt_iso).round().astype('int')
    if (var == 'depth_below_surface'):
      self.load_ice()
      var_full = self.ice - self.load_bin('dsum_iso',year)
    else:
      var_full = self.load_bin(var,year)

    return var_full[iso_indx,:,:]

# ===========================================
  def get_diso(self,ages):
    """Store the simulated depth below surface of each isochrone in ``self.diso``."""
    self.diso = self.get_data_on_isochrones('depth_below_surface',ages)
    return

# ===========================================
  def get_obsage_at_simdiso(self,obs):
    """Find the observed age at the depth of each simulated isochrone.

    Requires ``self.diso`` and the ``ice``/``ice_mask`` fields; get_diso()
    provides all three via load_ice().

    For every cell with ``ice_mask == 0``, the observed profile
    ``obs.age_norm[:, j, i]`` is linearly interpolated at ``diso / ice``.
    The profile has 25 levels at 0, 0.04, ..., 0.96, and zeros in it are
    treated as missing. Values outside [0, 0.96] and all ice-free cells
    become NaN. The result is stored in ``self.obsage_diso``.
    """
    self.obsage_diso = np.zeros_like(self.diso)

    age_norm_nan = obs.age_norm.copy()
    age_norm_nan[age_norm_nan == 0.] = np.nan
    # Normalised depth levels of the observed age profiles.
    zeta_obs = np.arange(0.,1.,0.04)

    for j in np.arange(0,self.obsage_diso.shape[1]):
      for i in np.arange(0,self.obsage_diso.shape[2]):
        if (self.ice_mask[j,i] != 0):
          self.obsage_diso[:,j,i] = np.nan
          continue

        # Only build the interpolant for ice-covered cells.
        f = interpolate.interp1d(zeta_obs,age_norm_nan[:,j,i])
        for isochrone in np.arange(0,self.diso.shape[0]):
          zeta_sim = self.diso[isochrone,j,i]/self.ice[j,i]
          if (zeta_sim < 0.) or (zeta_sim > .96):
            self.obsage_diso[isochrone,j,i] = np.nan
          else:
            self.obsage_diso[isochrone,j,i] = f(zeta_sim)

    return


# ===========================================
# ===========================================
class yelmo_obs():
  """Present-day Greenland observations on a 16 km or 32 km Yelmo grid.

  Loads:
  - observed surface/bed topography and derived thickness and masks;
  - a coastline field for contouring;
  - NASA IceBridge isochrone ages;
  - isochrone depths interpolated to the grid.
  Also stores a few ice-core coordinates.
  """

# ===========================================
  def __init__(self,resolution,
               ice_data_dir='../ice_data/Greenland/',
               nasa_dir='/work2/aborn/backup/backup.0/ism/nasa_icebridge_data/'):
    """Load the observations for grid spacing ``resolution`` (16 or 32 km).

    ``ice_data_dir`` and ``nasa_dir`` are the data roots; their defaults are
    the previously hard-coded locations, and both must end with '/'. For any
    other resolution a message is printed and no attributes are set.
    """
    if resolution   == 16:
      self.dx = 16
    elif resolution == 32:
      self.dx = 32
    else:
      print('Observations do not exist at resolution ' + str(resolution) + '.')
      return

    # Load surface and bed topographies:
    grid_name = 'GRL-' + str(self.dx) + 'KM'
    path = ice_data_dir + grid_name + '/' + grid_name + '_TOPO-B13.nc'
    with Dataset(path, 'r') as ncfile_obs:
      self.srf = ncfile_obs.variables['zs'][:].copy().astype(float)
      self.bed = ncfile_obs.variables['zb'][:].copy().astype(float)

    # Bed values of exactly 0 become NaN; zero-surface cells take the bed
    # value, so their thickness is 0 (or NaN where the bed is NaN).
    self.bed[self.bed==0.0]     = np.nan
    self.srf[self.srf==0.0]     = self.bed[self.srf==0.0]
    self.ice                    = self.srf - self.bed
    # Integer masks: 1 where the thickness is below the threshold.
    self.ice_mask               = np.where(self.ice<1,1,0)
    self.mask_1km               = np.where(self.ice<1000.,1,0)
    self.mask_1p5km             = np.where(self.ice<1500.,1,0)
    self.mask_2km               = np.where(self.ice<2000.,1,0)

    # Field contoured at 0 to draw the coastline; the hand-picked index
    # ranges below are blanked (NaN) so no contour is drawn there.
    self.coastline              = self.srf.copy()
    self.coastline[:,50:53]     = np.nan
    self.coastline[:,0:3]       = np.nan
    self.coastline[75:-1,0:10]  = np.nan
    self.coastline[82:-1,10:16] = np.nan
    self.coastline[80:82,10:13] = np.nan

    # Load age of isochrones:
    with Dataset(nasa_dir + 'RRRAG4_Greenland_1993_2013_01_age_grid.nc', 'r') as ncfile_nasa:
      self.age_iso_nasa = ncfile_nasa.variables['age_iso'][:].copy()

    # Load isochrone depths interpolated to this grid.
    # NOTE: pickle.load can execute arbitrary code embedded in the file;
    # only point nasa_dir at trusted data.
    nasa_filename = nasa_dir + 'python/interp_icebridge_yelmo_' + str(self.dx) + 'km.bin'
    with open(nasa_filename, "rb") as nasa_file:
      data = pickle.load(nasa_file, encoding="latin1")

    self.thickness                                  = data[2]
    self.age_norm                                   = data[3]
    self.age_norm_uncert                            = data[4]
    self.depth_iso                                  = data[5]
    self.depth_iso_uncert                           = data[6]

    self.depth_iso[self.depth_iso==0]               = np.nan
    self.depth_iso_uncert[self.depth_iso_uncert==0] = np.nan
    # Thickness minus depth for every isochrone layer (presumably the height
    # above the bed, analogous to yiso's dsum_iso).
    self.dsum_iso = np.tile(self.thickness,[self.depth_iso.shape[0],1,1]) - self.depth_iso

    # Ice core coordinates (lat,lon):
    self.summit_coords   = [72.58,-38.46]
    self.neem_coords     = [79   ,-50   ]
    self.dye3_coords     = [65   ,-43.75]
    self.ngrip_coords    = [75.16,-42.5 ]
    self.egrip_coords    = [75.5 ,-36   ]

# ===========================================
  def add_PD_contours(self,show_srf=False,show_ice=False,lev_srf=None,lev_ice=None,
                     col_coast='gray',col_srf='k',col_ice='m',width_coast=1,width_srf=1,width_ice=1):
    """Draw present-day contours on the current axes.

    The coastline (``coastline`` at level 0) is always drawn.
    Optionally drawn:
    - surface-elevation contours ``lev_srf`` (default 0, 1000, 2000, 3000);
    - ice-thickness contours ``lev_ice`` (default [50]).
    """
    # None defaults avoid sharing one mutable default object between calls.
    if lev_srf is None:
      lev_srf = np.arange(0,4e3,1e3)
    if lev_ice is None:
      lev_ice = [50]

    plt.contour(self.coastline,[0],colors=col_coast,linewidths=width_coast)
    if (show_srf):
      plt.contour(self.srf,levels=lev_srf,colors=col_srf,linewidths=width_srf)
    if (show_ice):
      plt.contour(self.ice,levels=lev_ice,colors=col_ice,linewidths=width_ice)

    return

## ===========================================
## ===========================================
class yelmo_ensemble():
  """Latin-hypercube ensemble of Yelmo runs named GRL32_LHS_0001, 0002, ...

  Holds:
  - ``run_list``: the run objects, with ``None`` for missing members;
  - per-member parameter arrays;
  - ``ignore_run``: flags that the analysis methods use to skip members.
  """

# ===========================================
  def __init__(self,size_ens,sample_run):
    """Read ``size_ens`` ensemble members from the folder of ``sample_run``.

    Parameters
    ----------
    size_ens : int
      Number of members to look for (names from find_name_LHS(1..size_ens)).
    sample_run : yelmo_run
      Any run of the ensemble; its folder and grid size are reused.

    Members whose folder or namelist is missing stay ``None`` in
    ``run_list`` and get NaN in every parameter array.
    """
    self.run_list         = np.empty(size_ens,dtype=object)

    self.size_ens         = size_ens
    # sample_run.path is "<output_dir><run_name>/": strip the run folder.
    self.path             = sample_run.path[0:-len(sample_run.run_name)-1]
    self.imax             = sample_run.imax
    self.jmax             = sample_run.jmax
    self.dx               = sample_run.dx

    self.dpr_hol_list     = np.zeros(self.size_ens)
    self.f_stdev_list     = np.zeros(self.size_ens)
    self.mm_snow_list     = np.zeros(self.size_ens)
    self.enh_shear_list   = np.zeros(self.size_ens)
    self.enh_glac_list    = np.zeros(self.size_ens)
    self.ghf_list         = np.zeros(self.size_ens)
    self.ignore_run       = np.full(self.size_ens,False,dtype='bool')

    for i in np.arange(0,self.size_ens):
      run = yelmo_run(output_dir=self.path,run_name=find_name_LHS(i+1))
      # A folder without a namelist yields a run without parameters
      # (yelmo_run returns early); treat it as missing instead of crashing.
      if run.does_run_exist() and hasattr(run, 'config'):
        #append run to ensemble list:
        self.run_list[i]        = run

        self.dpr_hol_list[i]    = run.dpr_hol
        self.f_stdev_list[i]    = run.f_stddev
        self.mm_snow_list[i]    = run.mm_snow
        self.enh_shear_list[i]  = run.enh_shear
        self.enh_glac_list[i]   = run.enh_glac
        self.ghf_list[i]        = run.ghf_const
      else:
        self.dpr_hol_list[i]    = np.nan
        self.f_stdev_list[i]    = np.nan
        self.mm_snow_list[i]    = np.nan
        self.enh_shear_list[i]  = np.nan
        self.enh_glac_list[i]   = np.nan
        self.ghf_list[i]        = np.nan

    print('Ensemble read.')

    return

# ===========================================
  def _first_run(self):
    """Return the first existing run in ``run_list``, or None if there is none."""
    return next((run for run in self.run_list if run is not None), None)

# ===========================================
  def ignore_missing_runs(self):
    """Flag every member whose run object is missing (None) as ignored."""
    for i in np.arange(0,self.size_ens):
      if self.run_list[i] is None:
        self.ignore_run[i]        = True

# ===========================================
  def find_issues(self):
    """Flag members with missing or implausible isochrone output at Summit.

    A member is ignored if reading its 'd_iso' output (year 160000) at the
    Summit node fails for any reason, or if any value there exceeds 100.
    Requires get_obs() to have been called (for the Summit coordinates).
    """
    ref_run = self._first_run()
    if ref_run is None:
      print('yelmo_ensemble::find_issues: No existing runs.')
      return
    summit_node = ref_run.indx_latlon(self.obs.summit_coords)

    for i in np.arange(0,self.size_ens):
      try:
        d = self.run_list[i].load_bin('d_iso',160000)[:,summit_node[0],summit_node[1]]
      except Exception:
        # Best effort: a missing run, a missing/short file or an invalid
        # Summit node all mean this member cannot be used.
        self.ignore_run[i] = True
      else:
        d[d<1e-3] = 0
        if ((d > 100.).any()):
          self.ignore_run[i] = True
          print('Found issues in ', self.run_list[i].run_name)

    return

# ===========================================
  def make_composite(self,param,param_range):
    """Return the non-ignored runs whose attribute ``param`` lies strictly
    between ``param_range[0]`` and ``param_range[1]``.

    Prints a message and returns None if the first existing run has no
    attribute ``param``.
    """
    ref_run = self._first_run()
    if (ref_run is None) or (not hasattr(ref_run, param)):
      print('yelmo_ensemble::make_composite: Parameter does not exist.')
      return

    lower, upper = param_range[0], param_range[1]
    candidates = self.run_list[~self.ignore_run.astype('bool')]
    # getattr replaces the previous exec()-based attribute lookup.
    return [run for run in candidates if lower < getattr(run, param) < upper]

# ===========================================
  def get_obs(self):
    """Load the observations for the ensemble's resolution.

    Sets ``obs``, ``test_isochrones`` and ``num_iso``, plus ``all_obs``
    (observed thickness stacked on top of the observed isochrone depths).
    """
    self.obs                          = yelmo_obs(self.dx)

    self.test_isochrones              = self.obs.age_iso_nasa
    self.num_iso                      = self.test_isochrones.size

    depth_iso                         = self.obs.depth_iso
    self.all_obs                      = np.zeros([depth_iso.shape[0]+1,depth_iso.shape[1],depth_iso.shape[2]])
    self.all_obs[0,:,:]               = self.obs.ice
    self.all_obs[1:,:,:]              = depth_iso

    return

  class RMSE():
    """Root-mean-square errors of every ensemble member against observations.

    The constructor computes simulated isochrone depths (get_diso) for every
    non-ignored run that lacks them. get_RMSE() fills the error arrays and
    selects the best-fitting runs.
    """
  # ===========================================
    def __init__(self,outer_self):
      """Allocate the error arrays and make sure each used run has ``diso``."""
      self.rmse_srf_array     = np.zeros(outer_self.size_ens)
      self.rmse_ice_array     = np.zeros(outer_self.size_ens)
      self.rmse_all_array     = np.zeros(outer_self.size_ens)
      self.rmse_iso_array     = np.zeros([outer_self.size_ens,outer_self.num_iso])
      self.rmse_isodiff_array = np.zeros([outer_self.size_ens,3])

      for i in np.arange(0,outer_self.size_ens):
        if ((not hasattr(outer_self.run_list[i], 'diso')) and (not outer_self.ignore_run[i])):
          outer_self.run_list[i].get_diso(outer_self.test_isochrones)

      return

  # ===========================================
    @staticmethod
    def _layer_rms(diff,mmask):
      """Return the RMS of each layer of ``diff`` over the cells where ``mmask`` is 0.

      Every unmasked cell of a layer has equal weight. Non-finite squared
      values are masked by numpy.ma's power operation.
      """
      nlayers = diff.shape[0]
      masked  = ma.masked_array(diff,mask=np.tile(mmask,[nlayers,1,1]))
      return (masked**2).reshape(nlayers,-1).mean(axis=1)**.5

  # ===========================================
    @staticmethod
    def _select_mask(run,obs,mask_name):
      """Return the 2-D mask (1 = excluded) for a validated ``mask_name``."""
      obs_masks = {'1km': 'mask_1km', '1.5km': 'mask_1p5km', '2km': 'mask_2km'}
      if mask_name in obs_masks:
        return getattr(obs, obs_masks[mask_name])
      # Velocity-based masks depend on the individual run.
      run.load_vel_masks()
      if mask_name == 'slowice':
        return run.slowice_mask
      return run.fastice_mask

  # ===========================================
    def get_RMSE(self,outer_self,obs,source,mask_name):
      """Compute each member's RMSE against ``obs`` and select the best runs.

      Parameters
      ----------
      outer_self : yelmo_ensemble
      obs : yelmo_obs
      source : str
        'yiso' (use run.diso) or 'deptime' (netCDF 'depth_iso').
      mask_name : str
        '1km', '1.5km', '2km' (observation masks), or 'slowice'/'fastice'
        (velocity masks of each run).

      Ignored members get NaN. The ``best_*`` attributes hold indices of the
      lowest errors and ``bestrun_*`` the corresponding runs. An unknown
      source or mask prints a message and returns without computing anything.
      """
      print('Get RMSE for source', source, 'and mask', mask_name)

      if source not in ('yiso', 'deptime'):
        print('Source does not exist')
        return
      if mask_name not in ('1km', '1.5km', '2km', 'slowice', 'fastice'):
        print('Mask does not exist.')
        return

      iso_diff_array             = np.zeros([3,outer_self.jmax,outer_self.imax])
      all_array                  = np.zeros([outer_self.num_iso+1,outer_self.jmax,outer_self.imax])

      # Observed spacing between successive isochrones; the same for all runs.
      obs_isodiff        = np.zeros_like(iso_diff_array)
      obs_isodiff[0,:,:] = obs.depth_iso[3,:,:]-obs.depth_iso[2,:,:]
      obs_isodiff[1,:,:] = obs.depth_iso[2,:,:]-obs.depth_iso[1,:,:]
      obs_isodiff[2,:,:] = obs.depth_iso[1,:,:]-obs.depth_iso[0,:,:]

      for i in np.arange(0,outer_self.size_ens):
        if outer_self.ignore_run[i]:
          self.rmse_srf_array[i]       = np.nan
          self.rmse_ice_array[i]       = np.nan
          self.rmse_iso_array[i,:]     = np.nan
          self.rmse_all_array[i]       = np.nan
          self.rmse_isodiff_array[i,:] = [np.nan, np.nan, np.nan]
          continue

        run = outer_self.run_list[i]
        run.load_ice()
        if (source == 'yiso'):
          iso_array           = run.diso
        else:
          iso_array           = run.load_nc('depth_iso')[-1,:,:,:]

        iso_diff_array[0,:,:] = iso_array[3,:,:]-iso_array[2,:,:]# MIS 4 - MIS 5d
        iso_diff_array[1,:,:] = iso_array[2,:,:]-iso_array[1,:,:]# MIS 3
        iso_diff_array[2,:,:] = iso_array[1,:,:]-iso_array[0,:,:]# MIS 2

        all_array[0,:,:]      = run.ice
        all_array[1:,:,:]     = iso_array

        # Mask to be used for RMSE:
        mmask = self._select_mask(run,obs,mask_name)

        self.rmse_srf_array[i]       = (ma.masked_array(run.srf-obs.srf,mask=mmask)**2).mean()**.5
        self.rmse_ice_array[i]       = (ma.masked_array(run.ice-obs.ice,mask=mmask)**2).mean()**.5
        self.rmse_iso_array[i,:]     = self._layer_rms(iso_array-obs.depth_iso,mmask)
        self.rmse_all_array[i]       = (ma.masked_array(all_array-outer_self.all_obs,
                                                        mask=np.tile(mmask,[all_array.shape[0],1,1]))**2).mean()**.5
        self.rmse_isodiff_array[i,:] = self._layer_rms(iso_diff_array-obs_isodiff,mmask)

      self.best_srf        = np.nanargmin(self.rmse_srf_array)
      self.best_ice        = np.nanargmin(self.rmse_ice_array)
      self.best_iso        = np.nanargmin(self.rmse_iso_array,     axis=0)
      self.best_isodiff    = np.nanargmin(self.rmse_isodiff_array, axis=0)
      self.best_all        = np.nanargmin(self.rmse_all_array)

      self.bestrun_srf     = outer_self.run_list[self.best_srf]
      self.bestrun_ice     = outer_self.run_list[self.best_ice]
      self.bestrun_iso     = outer_self.run_list[self.best_iso]
      self.bestrun_isodiff = outer_self.run_list[self.best_isodiff]
      self.bestrun_all     = outer_self.run_list[self.best_all]

      print('Done.')
      return
