#%% toboooo | Grayson Group
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
import matplotlib.patches as mpatch
from numpy import log
from scipy import constants

def concA(t, k1, k1p, A0):
	"""Closed-form concentration of A, consumed by two parallel first-order steps.

	Args:
		t: time point(s).
		k1: rate constant of the first channel consuming A.
		k1p: rate constant of the second channel consuming A.
		A0: concentration of A at t = 0.

	Returns:
		A0 * exp(-(k1 + k1p) * t), evaluated with TensorFlow ops.
	"""
	total_decay = k1 + k1p
	return A0 * tf.math.exp(-total_decay * t)

def concB(t, k1, k1p, k2, A0):
	"""Closed-form concentration of the intermediate B.

	B is formed from A with rate constant k1 (A is also drained by k1p) and is
	consumed with rate constant k2.

	Args:
		t: time point(s).
		k1, k1p: rate constants of the two channels consuming A.
		k2: rate constant of the step consuming B.
		A0: concentration of A at t = 0.

	Returns:
		A0*k1/(k2 - k1 - k1p) * (exp(-(k1 + k1p)*t) - exp(-k2*t)).
		The expression is singular when k2 == k1 + k1p.
	"""
	decay_of_a = tf.math.exp(-(k1 + k1p) * t)
	decay_of_b = tf.math.exp(-k2 * t)
	prefactor = (A0 * k1) / (k2 - k1 - k1p)
	return  prefactor * (decay_of_a - decay_of_b)

def concC(t, k1, k1p, k2, A0):
	"""Closed-form concentration of the product C formed from B.

	Args:
		t: time point(s).
		k1, k1p: rate constants of the two channels consuming A.
		k2: rate constant of the step converting B into C.
		A0: concentration of A at t = 0.

	Returns:
		A0*k1/K * (1 - k2*exp(-K*t)/(k2 - K) + K*exp(-k2*t)/(k2 - K))
		with K = k1 + k1p. The expression is singular when k2 == K.
	"""
	total_decay = k1p + k1
	rate_gap = k2 - k1 - k1p
	a_term = (k2 * tf.math.exp(-total_decay * t)) / rate_gap
	b_term = (total_decay * tf.math.exp(-k2 * t)) / rate_gap
	# Long-time limit of the bracket is 1, so the plateau is A0 * k1 / K.
	plateau = (A0 * k1) / total_decay
	return plateau * (1 - a_term + b_term)

def delta_G(k, temperature=293.15):
	"""Convert a rate constant into an Eyring free-energy barrier.

	Inverts the Eyring equation k = (k_B*T/h) * exp(-dG/(R*T)), which gives
	dG = -R*T * (ln k - ln(k_B*T/h)).

	Args:
		k: rate constant(s), in the units of k_B*T/h (s^-1). Scalars and
			array-likes accepted by numpy.log both work.
		temperature: absolute temperature in kelvin. Defaults to 293.15 K,
			the value that used to be hard-coded here.

	Returns:
		The barrier in kcal/mol (J/mol multiplied by 0.000239).
	"""
	eyring_prefactor = (constants.Boltzmann * temperature) / constants.h
	joules_per_mol = -(log(k) - log(eyring_prefactor)) * constants.R * temperature
	# 0.000239 kcal per joule; kept at the original precision so results do not shift.
	return joules_per_mol * 0.000239

# Forward step A -> B plus a reversible second step B <=> C
def reverse_solve_equations(t, k1, km2, k2, A0, B0, C0):
	"""Integrate A -> B <=> C numerically with the explicit Euler method.

	Rate equations:
		dA/dt = -k1*A
		dB/dt =  k1*A - k2*B + km2*C
		dC/dt =  k2*B - km2*C

	Args:
		t: 1-D sequence of time points. The spacing may be non-uniform; each
			step uses its own width t[i+1] - t[i].
		k1: rate constant for A -> B.
		km2: rate constant for C -> B.
		k2: rate constant for B -> C.
		A0, B0, C0: concentrations at the first time point.

	Returns:
		Three lists (A, B, C). Each starts with the initial value and gains
		one entry per step, so for two or more time points there is one value
		per entry of t. With fewer than two time points only the initial
		values are returned.
	"""
	# Per-step widths: the previous version reused t[1] - t[0] for every step,
	# which is wrong on non-uniform grids and fails for a single time point.
	step_widths = np.diff(np.asarray(t))

	A, B, C = A0, B0, C0
	trace_A, trace_B, trace_C = [A], [B], [C]
	for dt in step_widths:
		# Evaluate all three updates from the old state before overwriting it.
		A, B, C = (
			A + dt * (-k1 * A),
			B + dt * (k1 * A - k2 * B + km2 * C),
			C + dt * (k2 * B - km2 * C),
		)
		trace_A.append(A)
		trace_B.append(B)
		trace_C.append(C)
	return trace_A, trace_B, trace_C

def plot(calcA, calcB, calcC, name='plot', times=None, expA=None, expB=None, expC=None):
	"""Overlay fitted concentration curves on the experimental ones.

	Only points with time below 20000 are drawn.

	Args:
		calcA, calcB, calcC: fitted concentrations; each must provide
			.numpy() (e.g. TensorFlow tensors).
		name: label used in the figure title.
		times, expA, expB, expC: experimental time axis and concentrations.
			When omitted, the module-level names t, A, B and C are used, as
			earlier versions of this function did.
	"""
	# Earlier versions read t, A, B and C straight from module globals, which
	# only exist when this file runs as a script. Explicit arguments are
	# preferred; the globals remain as a fallback for existing callers.
	if times is None:
		times = t
	if expA is None:
		expA = A
	if expB is None:
		expB = B
	if expC is None:
		expC = C
	times = np.asarray(times)

	inds = np.where(times < 20000)
	for measured in (expA, expB, expC):
		plt.plot(times[inds], np.asarray(measured)[inds], color='royalblue')
	for fitted in (calcA, calcB, calcC):
		plt.plot(times[inds], fitted.numpy()[inds], color='orange')

	plt.xlabel('Time (s)')
	plt.ylabel('Concentration [M]')
	E_box  = mpatch.Patch(color='royalblue', label='Experimental')
	F_box = mpatch.Patch(color='orange', label='Fitted')
	plt.title(f'{name} kinetics')
	plt.legend(handles=[E_box,F_box])
	plt.show()

def rc_optimiser(t, A, A0, B, C, name):
	"""Fit k1, k1p and k2 of the concA/concB/concC model to measured data.

	The fit runs in three gradient-descent stages of 1000 iterations each:
	k1 and k1p against A, then k2 against B, then all three against the sum
	of the A, B and C mean-squared errors. Progress is printed at every
	iteration, the final curves are plotted, and the barriers derived from
	k1 and k2 are printed.

	Args:
		t: time points of the measurements.
		A, B, C: measured concentrations at those time points.
		A0: initial concentration of A used by the model.
		name: label passed on to the plot title.
	"""

	opt = tf.keras.optimizers.SGD(learning_rate=0.001)
	loss_fn = tf.keras.losses.MeanSquaredError()

	# First try optimising k1 from A data, then k2 from B data separately
	k1 = tf.Variable(0.001)
	k1p = tf.Variable(0.0001, constraint=tf.keras.constraints.NonNeg())
	k2 = tf.Variable(0.001)

	# Stage 1: k1 and k1p from the decay of A alone.
	for i in range(1000):
		with tf.GradientTape() as tape:
			calcA = concA(t, k1, k1p, A0)
			loss = loss_fn(A, calcA)
		print("Iter: %d, k1 = %.7f, k1p = %.7f, loss = %.6f" % (i + 1, k1, k1p, loss))
		grad = tape.gradient(loss, [k1, k1p])
		opt.apply_gradients(zip(grad, [k1, k1p]))
	print("Optimised k1 = %f" % k1, '\nOptimised k1p = %f' % k1p)

	# Stage 2: k2 from B, with k1 and k1p held at their stage-1 values.
	opt2 = tf.keras.optimizers.SGD(learning_rate=0.0001)
	for i in range(1000):
		with tf.GradientTape() as tape:
			calcB = concB(t, k1, k1p, k2, A0)
			loss = loss_fn(B, calcB)
		print("Iter: %d, k2 = %.7f, loss = %.6f" % (i + 1, k2, loss))
		grad = tape.gradient(loss, [k2])
		opt2.apply_gradients(zip(grad, [k2]))
	print("Optimised k2 = %f" % k2) 

	# Stage 3: refine all three constants against A, B and C together.
	opt3 = tf.keras.optimizers.SGD(learning_rate=0.0001)
	for i in range(1000):
		with tf.GradientTape() as tape:
			calcA = concA(t, k1, k1p, A0)
			calcB = concB(t, k1, k1p, k2, A0)
			calcC = concC(t, k1, k1p, k2, A0)
			lossA = loss_fn(A, calcA)
			lossB = loss_fn(B, calcB)
			lossC = loss_fn(C, calcC)
			loss = lossA + lossB + lossC
		print("Iter: %d, k1 = %.7f, k1p = %.7f, k2 = %.7f,  loss = %.6f" % (i + 1, k1, k1p, k2, loss))
		grad = tape.gradient(loss, [k1, k1p, k2])
		opt3.apply_gradients(zip(grad, [k1, k1p, k2]))
	print("Optimised k1 = %f, k1p = %f, k2 = %f " % (k1, k1p, k2))

	# Recompute the curves with the final constants (the loop values predate
	# the last gradient step).
	calcA = concA(t, k1, k1p, A0)
	calcB = concB(t, k1, k1p, k2, A0)
	calcC = concC(t, k1, k1p, k2, A0)
	# Hand the experimental data over explicitly so plotting does not depend
	# on module-level globals being set by the script section.
	plot(calcA, calcB, calcC, name=name, times=t, expA=A, expB=B, expC=C)
	b1 = delta_G(k1)
	b2 = delta_G(k2)
	print(f'Barrier for [2,3] step: {b1} \nBarrier for [1,3] step: {b2}\n')

#%%
if __name__ == "__main__":
	# One row per data set: (data file, plot label, header rows to skip).
	datasets = [
		('NMe.dat', 'N-Me', 1),
		('NBn.dat', 'N-Bn', 1),
		('NCPh3.dat', 'N-Trityl', 0),
	]
	for dat, name, header_rows in datasets:
		# t, A, B and C are deliberately module-level names: plot() may look
		# them up as globals.
		t, A, B, C = np.loadtxt(dat, skiprows=header_rows, unpack=True)
		t = t.astype(np.float32)
		A0 = float(A[0])
		rc_optimiser(t, A, A0, B, C, name)

# %%
