# -*- coding: utf-8 -*-
"""
Created on Tue May 29 17:23:36 2018

@author: ck336
"""

import glob
import os

import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np

from detect_peaks import detect_peaks


mpl.rc('pdf', fonttype=42)  # PDF font type 42 (TrueType); presumably so text in exported PDFs stays editable
#%%
# For each intensity peak found in a spectrum, the CD is calculated and stored as a peak for easy access later.
# TODO: incorporate CD calculation into class def.

# set global parameters
THRESHOLD = 1570 # noise level in spectra; passed to detect_peaks as mph in separate_peaks
ANGLE_ACCEPTANCE = 1 # tolerance (deg) between measured and predicted diffraction angle in Peaks.addpeak; if data is misaligned/noisy or only partial diffracted lines are visible, increase it; otherwise decrease it to reduce outliers
LATTICE_CONSTANT = 2.4E-6 # lattice constant in metres (wavelengths in nm are scaled by 1e-9 before use) - check sample geometries!
ANGLE_SCATTER_CORRECT = 1 # asymmetric correction (deg) for the first-order beam: first-order peaks at or below (predicted angle - this value) are rejected, since the molecule layer scatters strongly at low angles
MINPEAKDIS = 100 # peak spacing passed to detect_peaks as mpd (presumably in samples)
WAVE_IND_DIF = 10 # max. index difference when pairing M1/M2 and P1/P2 peaks of spectra with equal peak counts (build_peak_array), to tolerate noise in the spectral measurement
WAVELENGTH_DIFF = 1 # default tolerance of find_equal; in build_peak_array it is compared against sample indices, not nm
WAVE_WINDOW = 1 # half-width (in samples) of the centroid window around each detected peak (separate_peaks)


# measurement folder; CSV files are read from data/<MY_FOLDER>/
MY_FOLDER = '20181203_B3_L1_3'

class Peaks:
    """Diffraction peaks stored as four parallel lists.

    Index ``n`` of ``angle``, ``wavelength``, ``order`` and ``CD`` together
    describe one peak: measurement angle (deg), peak wavelength (nm),
    assigned diffraction order and CD value (fraction).
    """

    def __init__(self, inpeaks=None, precalculated=False):
        """Create an empty container, optionally pre-filled from *inpeaks*.

        inpeaks -- iterable of tuples. With ``precalculated=False`` each tuple
            is ``(angle, wavelength, cd)`` and goes through ``addpeak``
            (order assignment and acceptance filtering). With
            ``precalculated=True`` each tuple is
            ``(angle, wavelength, cd, order)`` and is stored unchanged.
        """
        self.angle = []
        self.wavelength = []
        self.order = []
        self.CD = []

        if inpeaks is not None:
            add = self.addprecalculatedpeak if precalculated else self.addpeak
            for peak in inpeaks:
                add(*peak)

    def addpeak(self, a, wl, cd, ang_range = ANGLE_ACCEPTANCE, ang_scat_correct = ANGLE_SCATTER_CORRECT):
        """Assign a diffraction order to a peak and store it if it is plausible.

        The order starts at 1 and is increased while the current order
        already holds a peak at a longer wavelength and the order is still
        below the grating-equation prediction ``predicted_order(wl, a)``.
        The peak is then accepted only if:

        * order 1: *a* lies within ``ang_range`` of ANY predicted angle and
          is above the first-order prediction minus ``ang_scat_correct``;
        * higher orders: *a* lies within ``ang_range`` of that order's
          predicted angle.

        Rejected peaks are silently dropped.
        """
        pred_angls = predicted_angle(wl)
        # continuous order from the grating equation; loop-invariant
        pred_order = predicted_order(wl, a)

        # Bump the order while this order already contains a peak at a longer
        # wavelength and the prediction still allows a higher order.
        ordernum = 1
        while True:
            order_wls = [w for w, o in zip(self.wavelength, self.order) if o == ordernum]
            if order_wls and wl < max(order_wls) and ordernum < pred_order:
                ordernum += 1
            else:
                break

        if ordernum > len(pred_angls):
            # predicted_angle(wl) only lists the first few orders; extend it so
            # the angle check below can index this order (avoids IndexError)
            pred_angls = predicted_angle(wl, orders=ordernum + 1)

        lo, hi = a - ang_range, a + ang_range
        if ordernum == 1:
            # first order: reject peaks too far below the first-order prediction
            # (strong low-angle scattering from the molecule layer)
            near_prediction = any(lo <= p <= hi for p in pred_angls)
            if near_prediction and not a <= pred_angls[0] - ang_scat_correct:
                self.addprecalculatedpeak(a, wl, cd, ordernum)
        elif lo <= pred_angls[ordernum - 1] <= hi:
            self.addprecalculatedpeak(a, wl, cd, ordernum)

    def addprecalculatedpeak(self, a, wl, cd, ordernum):
        """Append a peak whose order is already known, without any checks."""
        self.angle.append(a)
        self.wavelength.append(wl)
        self.order.append(ordernum)
        self.CD.append(cd)

    def showpeak(self, n):
        """Return peak *n* as ``(angle, wavelength, CD, order)``."""
        return self.angle[n], self.wavelength[n], self.CD[n], self.order[n]

    def withorder(self, ordernum):
        """Return a new Peaks containing only the peaks of order *ordernum*."""
        matching = [self.showpeak(i) for i, o in enumerate(self.order) if o == ordernum]
        return Peaks(matching, precalculated=True)
    
    
#%%
# Collects the spectra of all files into per-polarisation intensity arrays (M1, M2, P1, P2). Two measurements are taken per handedness because the waveplate significantly changes the intensity upon rotation; they are averaged later (see ave_spec).
def fold_spectrum(M1_file_list, M2_file_list, P1_file_list, P2_file_list):
    """Stack the per-angle spectra of each polarisation into 2-D arrays.

    Each *_file_list is a list of spectra as produced by
    ``np.loadtxt(..., unpack=True)``: row 0 holds the wavelengths and row 1
    the intensities. Entry i of every list is expected to belong to the same
    measurement angle (the caller sorts them by angle).

    Returns ``(wavelengths, M1, M2, P1, P2)``. ``wavelengths`` is row 0 of
    the first M1 spectrum (all files are assumed to share one wavelength
    axis). Each intensity array has one row per entry of its file list.
    """
    wavelengths = M1_file_list[0][0]

    # Iterate over the lists themselves rather than over a global angle list,
    # so the function only depends on its arguments.
    intensities = tuple(
        np.array([spectrum[1] for spectrum in file_list])
        for file_list in (M1_file_list, M2_file_list, P1_file_list, P2_file_list)
    )
    return (wavelengths,) + intensities

def find_equal(ls, target, limit = WAVELENGTH_DIFF):
    """Return the first element of *ls* within +/- *limit* of *target*.

    Returns None when no element is close enough.
    """
    lower, upper = target - limit, target + limit
    return next((candidate for candidate in ls if lower <= candidate <= upper), None)

def unique_peaks(ls1, ls2):
    """Concatenate *ls1* and *ls2* and drop repeated values.

    The first occurrence of each value is kept, and the original order is preserved.
    """
    combined = np.concatenate((ls1, ls2))
    first_seen = np.unique(combined, return_index=True)[1]
    first_seen.sort()  # back to first-occurrence order
    return combined[first_seen]

def separate_peaks(y):
    """Detect peaks in *y* and return an index window ``[start, end]`` around each.

    Peaks come from detect_peaks with mph=THRESHOLD and mpd=MINPEAKDIS. Each
    window extends WAVE_WINDOW samples to either side of the peak index. The
    windows are NOT clipped to the bounds of *y*.
    """
    peak_positions = detect_peaks(y, mph = THRESHOLD, mpd = MINPEAKDIS)
    return [[pos - WAVE_WINDOW, pos + WAVE_WINDOW] for pos in peak_positions]

def centroid(x,y, wind_start, wind_end):
    """Return the y-weighted mean of x over indices wind_start..wind_end (inclusive).

    The division is unguarded: a window whose weights sum to zero raises or
    yields inf/nan, depending on the input types.
    """
    numerator = denominator = 0
    for k in range(wind_start, wind_end + 1):
        weight = y[k]
        numerator += x[k] * weight
        denominator += weight
    return numerator / denominator

def near(lst, val):
    """Return the index of the element of *lst* closest to *val*.

    On ties, the first such index is returned.
    """
    best_index, _ = min(enumerate(lst), key=lambda pair: abs(pair[1] - val))
    return best_index

def find_centroid(x,y):
    """Locate peaks in spectrum *y* and refine each one with a centroid.

    For every window returned by separate_peaks, the y-weighted mean of x
    over the window is computed and snapped to the nearest sample of x.
    Returns the list of those sample indices, one per detected peak.
    """
    # last valid index for both arrays (replaces the hard-coded 1043, which
    # assumed exactly 1044 samples per spectrum)
    last = min(len(x), len(y)) - 1
    ind_peaks = []
    for start, end in separate_peaks(y):
        # Clip the window to the array: a peak within WAVE_WINDOW of either
        # edge would otherwise run past the end, or wrap around to the other
        # end of the spectrum through a negative start index.
        start = max(start, 0)
        end = min(end, last)
        wave = centroid(x, y, start, end)
        ind_peaks.append(near(x, wave))
    return ind_peaks

def _peak_values(M1, M2, P1, P2, wave, quads):
    """Look up the five output columns for each (m1, m2, p1, p2) index quadruple."""
    m1_vals = [M1[a] for a, _, _, _ in quads]
    m2_vals = [M2[b] for _, b, _, _ in quads]
    p1_vals = [P1[c] for _, _, c, _ in quads]
    p2_vals = [P2[d] for _, _, _, d in quads]
    waves = [(wave[a] + wave[b] + wave[c] + wave[d]) / 4 for a, b, c, d in quads]
    return m1_vals, m2_vals, p1_vals, p2_vals, waves


def build_peak_array(M1, M2, P1, P2, wave, threshold = THRESHOLD, diff = WAVE_IND_DIF):
    """Find peaks present in all four spectra and return their intensities.

    Peaks are located in M1, M2, P1 and P2 with find_centroid, which returns
    sample indices into *wave*. They are then matched across the spectra:

    * if any spectrum has no peaks, five empty lists are returned;
    * if the peak counts differ, every M1 peak and every P1 peak is matched
      to the other three spectra with find_equal. A peak found from both
      sides is kept once. Columns are returned as numpy arrays;
    * if all counts are equal and > 1, peaks are paired by rank and kept when
      the M1/M2 and P1/P2 indices each differ by at most *diff*. Columns are
      returned as lists;
    * if every spectrum has exactly one peak, the four are paired unchecked.

    Returns ``(m1, m2, p1, p2, wavelengths)``: five equal-length columns
    where entry k of each describes the k-th matched peak. The wavelength is
    the mean of the four spectra's peak wavelengths.

    *threshold* is unused and kept only for backward compatibility.
    """
    M1_inds = find_centroid(wave, M1)
    M2_inds = find_centroid(wave, M2)
    P1_inds = find_centroid(wave, P1)
    P2_inds = find_centroid(wave, P2)
    counts = [len(M1_inds), len(M2_inds), len(P1_inds), len(P2_inds)]

    if min(counts) == 0:
        return [], [], [], [], []

    if len(set(counts)) > 1:
        quads = []
        # matches anchored on each M1 peak
        for i_m1 in M1_inds:
            i_m2 = find_equal(M2_inds, i_m1)
            i_p1 = find_equal(P1_inds, i_m1)
            i_p2 = find_equal(P2_inds, i_m1)
            if i_m2 is not None and i_p1 is not None and i_p2 is not None:
                quads.append((i_m1, i_m2, i_p1, i_p2))
        # matches anchored on each P1 peak
        for i_p1 in P1_inds:
            i_m1 = find_equal(M1_inds, i_p1)
            i_m2 = find_equal(M2_inds, i_p1)
            i_p2 = find_equal(P2_inds, i_p1)
            if i_m1 is not None and i_m2 is not None and i_p2 is not None:
                quads.append((i_m1, i_m2, i_p1, i_p2))
        # Deduplicate whole index quadruples (keeping first-seen order).
        # Deduplicating each column's values separately could drop an entry
        # from one column only and misalign the five outputs.
        quads = list(dict.fromkeys(quads))
        return tuple(np.array(column) for column in _peak_values(M1, M2, P1, P2, wave, quads))

    if counts[0] > 1:
        quads = [(a, b, c, d) for a, b, c, d in zip(M1_inds, M2_inds, P1_inds, P2_inds)
                 if abs(a - b) <= diff and abs(c - d) <= diff]
        return _peak_values(M1, M2, P1, P2, wave, quads)

    # exactly one peak in every spectrum: list-indexing keeps 1-element arrays
    wave_ls = (wave[M1_inds] + wave[M2_inds] + wave[P1_inds] + wave[P2_inds]) / 4
    return M1[M1_inds], M2[M2_inds], P1[P1_inds], P2[P2_inds], wave_ls

def ave_spec(spec_array):
    """Average the two measurements of each handedness, peak by peak.

    *spec_array* holds at least four indexable sequences (m1, m2, p1, p2);
    anything after the fourth is ignored. Returns two lists ``(M, P)`` with
    ``M[k] = (m1[k] + m2[k]) / 2`` and ``P[k] = (p1[k] + p2[k]) / 2`` for
    every index k of m1.
    """
    m1, m2, p1, p2 = spec_array[0], spec_array[1], spec_array[2], spec_array[3]
    M = [(first + m2[k]) / 2 for k, first in enumerate(m1)]
    P = [(p1[k] + p2[k]) / 2 for k, _ in enumerate(m1)]
    return M, P

def calc_CID(P,M):
    """Return the circular intensity difference ``(P - M) / (P + M)`` per peak.

    One value is produced for each entry of *M*. *P* must be at least as long as *M*.
    """
    return [(P[k] - m_val) / (P[k] + m_val) for k, m_val in enumerate(M)]

def predicted_order(wavelength, angle, lattice_const = LATTICE_CONSTANT):
    """Continuous diffraction order ``d * sin(angle) / wavelength``.

    *wavelength* is in nm, *angle* in degrees, and *lattice_const* (d) in metres.
    """
    wavelength_m = wavelength*10**-9
    return (lattice_const * np.sin(np.deg2rad(angle))) / wavelength_m

def predicted_angle(wavelength, orders = 5, lattice_const = LATTICE_CONSTANT):
    """Predicted diffraction angles (deg) for orders ``1 .. orders-1``.

    Note that *orders* is an exclusive upper bound: the default of 5 yields
    the angles of orders 1-4. Each angle is rounded to the nearest 0.5 deg.
    An order whose arcsin is undefined (NaN) at this wavelength is reported
    as 0.0.
    """
    result = []
    for m in range(1, orders):
        theta = np.rad2deg(np.arcsin(m*(wavelength*10**-9)/lattice_const))
        result.append(0. if np.isnan(theta) else round(theta*2)/2)
    return result
    
def getKey(item):
    """Sort key: the first element of *item* (the angle of an (angle, data) tuple)."""
    first = item[0]
    return first
    
#%%


# LOAD DATA FILES INTO ARRAYS
cd_p = Peaks()

M1_data = []
M2_data = []
P1_data = []
P2_data = []
angle_ls = []
# polarisation tag at the end of each file name -> list of (angle, data) tuples
data_keys = {
    "M1": M1_data,
    "M2": M2_data,
    "P1": P1_data,
    "P2": P2_data
}

# os.path.join instead of hard-coded Windows separators, so the glob also
# matches on Linux/macOS
data_dir = os.path.join("data", MY_FOLDER)
files = glob.glob(os.path.join(data_dir, "*.csv"))
if not files:
    raise FileNotFoundError("no .csv files found in " + data_dir)
for data_file in files:
    # skip the first 4 (header) lines; with unpack=True, data[0] serves as the
    # wavelength axis and data[1] as the intensity (see fold_spectrum)
    data = np.loadtxt(data_file, delimiter = ',', skiprows = 4, unpack = True)
    # file names end in "_<angle><polarisation>.csv", e.g. "..._35M1.csv"
    file_info = os.path.splitext(os.path.basename(data_file))[0].split('_')[-1]
    polarisation = file_info[-2:]
    angle = float(file_info[:-2])

    data_keys[polarisation].append((angle, data))
    angle_ls.append(angle)

angle_list = sorted(angle_ls)
# unique measurement angles in ascending order
nano_angles = sorted(set(angle_list))

# Row i of every intensity array below must belong to nano_angles[i], so each
# polarisation needs exactly one spectrum per angle.
for key, entries in data_keys.items():
    if sorted(entry[0] for entry in entries) != nano_angles:
        raise ValueError("{}: expected exactly one spectrum per angle {}, got angles {}".format(
            key, nano_angles, sorted(entry[0] for entry in entries)))

M1_files = [entry[1] for entry in sorted(M1_data, key=getKey)]
M2_files = [entry[1] for entry in sorted(M2_data, key=getKey)]
P1_files = [entry[1] for entry in sorted(P1_data, key=getKey)]
P2_files = [entry[1] for entry in sorted(P2_data, key=getKey)]


# Pull folded data from files

wavelengths_ls, M1_intensity, M2_intensity, P1_intensity, P2_intensity = fold_spectrum(M1_files, M2_files, P1_files, P2_files)

M_intensity = []
P_intensity = []
for angle_idx, angle in enumerate(nano_angles):
    peaks_array = build_peak_array(M1_intensity[angle_idx], M2_intensity[angle_idx],
                                   P1_intensity[angle_idx], P2_intensity[angle_idx],
                                   wavelengths_ls)
    if len(peaks_array[0]) > 0:
        # average the two measurements per handedness, then the CD per peak
        M, P = ave_spec(peaks_array[:-1])
        M_intensity.append(M)
        P_intensity.append(P)
        CID = calc_CID(P, M)
        waves = peaks_array[4]
        for peak_wl, peak_cid in zip(waves, CID):
            cd_p.addpeak(angle, peak_wl, peak_cid)

# Remove the entries around the laser's fundamental wavelength (1064 nm), since
# it introduces weird effects. Not sure this is entirely legitimate, but Ventsi
# requested it. Note the window is asymmetric (-2.5 % / +3 %).
# TODO: Check if this is really necessary - feels uncomfortable to delete stuff
fund = 1064
waves = cd_p.wavelength
# Build the keep-mask first and filter afterwards: deleting from a list while
# enumerating it skips the element that follows each deletion.
keep = [not (fund*0.975 < wave < fund*1.03) for wave in waves]
# Slice assignment filters in place, so existing references to these lists
# (e.g. `waves`) see the filtered data.
for column in (cd_p.wavelength, cd_p.angle, cd_p.CD, cd_p.order):
    column[:] = [value for value, kept in zip(column, keep) if kept]

#%%
# Scatter plot of diffraction angle vs wavelength, with CD colour-coded
if not cd_p.CD:
    raise ValueError('No CD peaks found for ' + MY_FOLDER + '; nothing to plot')
plt.figure(1)
# symmetric colour range so that CD = 0 sits at the centre of the diverging colormap
extreme = abs(max(cd_p.CD, key = abs))
#extreme  = 0.2
# one scatter call for all peaks instead of one artist per point
plt.scatter(cd_p.wavelength, cd_p.angle, c=np.asarray(cd_p.CD, dtype=float),
            norm=mpl.colors.Normalize(vmin=-extreme, vmax=extreme),
            cmap='RdBu', s=20)
plt_name = 'Diffraction CD spectra for '+ MY_FOLDER
plt.title(plt_name)
# fixed axis limits: 450-1000 nm, 0-90 deg
plt.axis([450, 1000, 0, 90])
cbar=plt.colorbar(ticks=[-extreme, -extreme / 2, 0, extreme / 2, extreme], format='%.2g')
cbar.set_label("CD (frac)")
plt.ylabel('Diffraction angle (deg.)')
plt.xlabel('Wavelength (nm)')

#%%


# Plot CD vs wavelength separately for each diffraction order
max_order = max(cd_p.order, default=0)

# `order_num` instead of `ord`, which would shadow the builtin at module level
for order_num in range(1, max_order + 1):
    # (wavelength, CD) pairs of this order, sorted by wavelength so the
    # connecting line runs left to right instead of zig-zagging
    order_points = sorted(
        ((wl, cd_val) for wl, cd_val, o in zip(cd_p.wavelength, cd_p.CD, cd_p.order) if o == order_num),
        key=lambda point: point[0])
    wls = [point[0] for point in order_points]
    cd = [point[1] for point in order_points]
    plt.figure()
    plt.plot(wls, cd, marker = 'x') # Use scatter if too few data points, i.e. if lineplot looks unphysical
    plt_name = MY_FOLDER + ' CD for m = %i' % order_num
    plt.title(plt_name)
    plt.axhline(y=0, linestyle= 'dashed' , color = 'k')
    plt.xlabel('Wavelength (nm)')
    plt.ylabel('CD (frac.)')

# NOTE(review): the triple-quoted string below is disabled code (an unused
# string expression) kept for reference. It contains a detector-saturation
# check and a normalised M1-intensity map overlaid with the predicted
# diffraction angles.
# Before re-enabling it: the saturation test is broken. M_intensity and
# P_intensity are lists of lists, so max() returns a list, and
# `(a or b) == 200000.` only compares the first truthy operand. The check
# therefore always reports "passed". (200000. is presumably the detector's
# saturation count - confirm.)
'''
if (max(M_intensity) or max(P_intensity)) == 200000.:
    print("YOU HAVE SATURATED THE DETECTOR, DATA INVALID!!!")
else:
    print("Saturation test passed")

#%%
#normalize the intensities by max values
intensities = np.divide(M1_intensity, M1_intensity.max())

# Plot intensities diffraction orders with predicted angles overlay
plt.figure()
plt.pcolormesh(wavelengths_ls, nano_angles, intensities, cmap='gnuplot')
plt.title(MY_FOLDER + ' M_intensity spectra')
plt.xlabel('Wavelength (nm)')
plt.ylabel('Diffraction angle (deg.)')
plt.axis([450, 1000, 10, 90])
cbar = plt.colorbar()
cbar.set_label("Intensity (norm)")
for wave in wavelengths_ls:
    angles = predicted_angle(wave, orders = 5)
    for angle in angles:
        plt.scatter(wave, angle+1, marker = 'o', c = 'w', s=1, alpha = 0.1)
'''
# display all figures created above
plt.show()
