# -*- coding: utf-8 -*-
import matplotlib as plt
import numpy as np

from osgeo import gdal 
from scipy.fftpack import fft2, ifft2, fftshift
from skimage.feature import match_template

def pattern_match_example(I1, I2, ij_list, temp_radius=1, sear_radius=1):
    """
    pattern_match_example
    match a template taken from I1 around every location in ij_list within
    a search window of I2, and describe the resulting correlation peak.

    input:
        I1 - image from which the templates are taken
        I2 - image in which the templates are searched
        ij_list - array (k x 2) with row and column locations
        temp_radius - half-width of the template, in pixels
        sear_radius - half-width of the search window, in pixels
    output:
        ij_result - array (k x 9), one row per location with:
            intI, intJ - integer displacement of the best match
            dI, dJ - sub-pixel correction of the top (gauss_top)
            val - highest correlation score
            snr - highest score divided by the mean score
            qi, qj, rho - peak shape parameters (gauss_spread)
    """
    # include additional space along image to include areas at the border
    I1 = np.pad(I1, sear_radius, 'constant', constant_values=0)
    I2 = np.pad(I2, sear_radius, 'constant', constant_values=0)
    i = np.round(ij_list[:, 0]).astype(np.int64) + sear_radius
    j = np.round(ij_list[:, 1]).astype(np.int64) + sear_radius

    # one result row per requested location
    ij_result = np.zeros((ij_list.shape[0], 9))

    for counter, (ic, jc) in enumerate(zip(i, j)):
        # construct image sub-sets: template from I1, search window from I2
        I1sub = I1[ic-temp_radius:ic+temp_radius+1,
                   jc-temp_radius:jc+temp_radius+1]
        I2sub = I2[ic-sear_radius:ic+sear_radius+1,
                   jc-sear_radius:jc+sear_radius+1]

        # cross correlation calculation
        pcf = match_template(I2sub, I1sub)
        val = np.amax(pcf)
        snr = val / np.mean(pcf)
        intI, intJ = np.unravel_index(np.argmax(pcf), pcf.shape)

        # gaussian sub-pixel estimation
        dI, dJ = gauss_top(pcf, intI, intJ)

        # gaussian parameter estimation
        qi, qj, rho, _, _ = gauss_spread(pcf, intI, intJ, dI, dJ)

        # express the top relative to the centre of the correlation surface;
        # assumes match_template returns the 'valid' correlation of size
        # 2*(sear_radius - temp_radius) + 1 along each axis
        intI = intI - (sear_radius - temp_radius)
        intJ = intJ - (sear_radius - temp_radius)

        # collect the results for this location
        ij_result[counter, :] = [intI, intJ, dI, dJ, val, snr, qi, qj, rho]

    return ij_result

def gauss_top(pcf, intI, intJ):
    """
    gauss_top
    simple localization of the top of the correlation, by fitting a
    one-dimensional Gaussian through the highest value and its two
    neighbours, separately along each axis.

    input:
        pcf - array with correlation values (two-dimensional)
        intI - integer location of highest value in search space
        intJ - integer location of highest value in search space
    output:
        dI - sub-pixel bias of top location
        dJ - sub-pixel bias of top location

    pcf itself is not modified. When the top lies on the border of pcf
    there is no neighbour on one side and (0., 0.) is returned.
    """
    (m, n) = pcf.shape
    # the three-point fit needs a neighbour on each side of the top
    if intI < 1 or intJ < 1 or intI > m - 2 or intJ > n - 2:
        return 0., 0.

    # float copy of the 3x3 neighbourhood, so the caller's array is untouched
    sub = np.array(pcf[intI-1:intI+2, intJ-1:intJ+2], dtype=float)

    # avoid minus correlations: the logarithm needs strictly positive values.
    # Only shift when needed, since an offset biases the Gaussian fit.
    mn = np.min(sub)
    if mn <= 0:
        sub += 2*abs(mn) + np.finfo(float).eps
    lp = np.log(sub)

    # three-point Gaussian fit along each axis:
    #   d = (ln f(+1) - ln f(-1)) / (2 * (2 ln f(0) - ln f(+1) - ln f(-1)))
    dI = (lp[2, 1] - lp[0, 1]) / (2*(2*lp[1, 1] - lp[2, 1] - lp[0, 1]))
    dJ = (lp[1, 2] - lp[1, 0]) / (2*(2*lp[1, 1] - lp[1, 2] - lp[1, 0]))
    return dI, dJ

def gauss_spread(pcf, intI, intJ, dI, dJ, est='dist'):
    """
    gauss_spread
    least squares fit of a quadratic surface to the positive, mean-removed
    correlation values around the top, and derivation of peak parameters.

    input:
        pcf - array with correlation values (two-dimensional)
        intI - integer location of highest value in search space
        intJ - integer location of highest value in search space
        dI - sub-pixel bias of top location
        dJ - sub-pixel bias of top location
        est - 'distance' weights the fit by distance from the top; any
              other value, including the default 'dist', gives an
              unweighted fit
    output:
        qi, qj - parameters derived from the fitted coefficients
                 (presumably peak widths along each axis)
        rho - coupling term derived from the fitted coefficients
        hess - fitted coefficients of [I**2, 2*I*J, J**2, 1]
        frac - sum of the positive mean-removed values in the neighbourhood

    When no fit is possible (top on the border, too few positive samples,
    or a surface that is not peak-shaped) qi, qj and rho are returned as 0.
    """
    (m, n) = pcf.shape
    pcf = pcf - np.mean(pcf)

    # the neighbourhood needs at least one sample on every side of the top
    if intI < 1 or intJ < 1 or intI > m - 2 or intJ > n - 2:
        return 0, 0, 0, np.zeros(4), 0.

    # 5x5 neighbourhood when it lies fully inside pcf, otherwise 3x3
    if intI < 2 or intJ < 2 or intI > m - 3 or intJ > n - 3:
        dub = 1
    else:
        dub = 2

    # coordinates relative to the sub-pixel top
    I, J = np.mgrid[-dub:dub+1, -dub:dub+1].astype(float)
    I = I - dI
    J = J - dJ

    P_sub = pcf[intI-dub:intI+dub+1, intJ-dub:intJ+dub+1]
    IN = P_sub > 0
    frac = np.sum(P_sub[IN])
    n_in = int(np.sum(IN))
    if n_in <= 4:
        # too few positive samples to determine the four coefficients
        return 0, 0, 0, np.zeros(4), frac

    # normalize correlation score to probability function
    P_sub = P_sub / frac
    A = np.column_stack((I[IN]**2,
                         2*I[IN]*J[IN],
                         J[IN]**2,
                         np.ones(n_in)))
    y = P_sub[IN]

    # least squares estimation
    if est == 'distance':
        # weight decreases linearly with distance from the top; clipped at
        # zero so far samples cannot get a negative weight (whose square
        # root would be nan)
        W = (dub**2 - np.sqrt(A[:, 0] + A[:, 2])) / float(dub**2)
        W = np.clip(W, 0, None)
        Aw = A * np.sqrt(W[:, np.newaxis])
        yw = y * np.sqrt(W)
        hess = np.linalg.lstsq(Aw, yw, rcond=None)[0]
    else:
        hess = np.linalg.lstsq(A, y, rcond=None)[0]

    # transform to parameters; a surface that is not peak-shaped can give a
    # negative value under the square root or a zero denominator
    with np.errstate(invalid='ignore', divide='ignore'):
        rho = (0.5*hess[1]) / np.sqrt(hess[0]*hess[2])
        qi = 1 / (-2*(1 - rho)*hess[0])
        qj = 1 / (-2*(1 - rho)*hess[2])

    # the square root of a negative real yields nan (not a complex value),
    # so invalid fits are detected through finiteness.
    # NOTE(review): finite but negative qi/qj are still passed through.
    if not np.all(np.isfinite([rho, qi, qj])):
        qi, qj, rho = 0, 0, 0
    return qi, qj, rho, hess, frac

def high_pass_im(Im, radius=10):
    """
    high_pass_im
    suppress the low frequencies of an image with a Gaussian-shaped
    high-pass filter applied in the Fourier domain.

    input:
        Im - two-dimensional image array
        radius - width parameter of the Gaussian, in frequency-index units
    output:
        Inew - real part of the filtered image, same shape as Im
    """
    (m, n) = Im.shape
    If = fft2(Im)

    # signed frequency indices in FFT order (0, 1, ..., -2, -1); building the
    # filter directly in this order puts the zero frequency exactly at the
    # origin for both odd and even image sizes
    fi = np.rint(np.fft.fftfreq(m) * m)
    fj = np.rint(np.fft.fftfreq(n) * n)
    D2 = fi[:, np.newaxis]**2 + fj[np.newaxis, :]**2
    Hi = 1 - np.exp(-D2 / (2*radius)**2)

    Inew = np.real(ifft2(If*Hi))
    return Inew

# read imagery
def _read_band(fname):
    """Return band 1 of raster `fname` as an array, plus its geo-transform."""
    img = gdal.Open(fname)
    # gdal.Open signals failure by returning None (when GDAL exceptions are
    # not enabled); fail here with a clear message instead of later
    if img is None:
        raise IOError('could not open ' + fname)
    return img.GetRasterBand(1).ReadAsArray(), img.GetGeoTransform()


I1, R = _read_band('T22WEB_20200730T151921_B04_clip.jp2')
I2, _ = _read_band('T22WEB_20200720T151911_B04_clip.jp2')

# convert to floating point and scale down by 15e3
I1 = I1.astype(np.float64) / 15e3
I2 = I2.astype(np.float64) / 15e3

# locations in the imagery of interest (row, column)
ij_list = np.array([[1095, 2427], [1579, 2447], [1749, 724]])

ij_result = pattern_match_example(I1, I2, ij_list,
                                  temp_radius=10, sear_radius=40)







