import matplotlib.pyplot as plt
import numpy as np
from scipy import signal
import os
import cv2
import time
from PIL import Image, ImageStat
import sys
sys.path.insert(0, r'C:\Users\Administrator\Source\openflexure-microscope-pyclient')
import openflexure_microscope_client

def normalize(arr, t_min=0, t_max=1):
    """Linearly rescale the values of *arr* onto the range [t_min, t_max].

    Used so that several curves with different magnitudes can be plotted on
    the same axis.

    Args:
        arr: Non-empty sequence (or 1-D array) of numbers.
        t_min: Value that the minimum of *arr* is mapped to.
        t_max: Value that the maximum of *arr* is mapped to.

    Returns:
        A list with one rescaled value per element of *arr*. When every
        element of *arr* is equal the input range is zero; instead of
        dividing by zero, every element is mapped to ``t_min``.
    """
    # Compute the extrema once rather than re-scanning arr for every element.
    arr_min = min(arr)
    arr_range = max(arr) - arr_min
    target_range = t_max - t_min
    if arr_range == 0:
        return [t_min for _ in arr]
    return [((value - arr_min) * target_range) / arr_range + t_min for value in arr]

def _band_gradient_scores(gray):
    """Return (x, y) peak squared-gradient energies for a 2-D greyscale image.

    x: within the middle fifth of the rows, each row is differentiated
    horizontally, its squared differences are summed, and the largest row sum
    is returned. y: the same for the middle fifth of the columns,
    differentiated vertically. A score is 0 when its band is empty.
    """
    # int64 so squared 8-bit differences (up to 255**2 = 65025) cannot overflow;
    # an int16 working type wraps these to negative values.
    pixels = np.asarray(gray, dtype=np.int64)
    height, width = pixels.shape[:2]

    row_band = pixels[int(2 * height / 5):int(3 * height / 5), :]
    col_band = pixels[:, int(2 * width / 5):int(3 * width / 5)]

    # Sum the squared first differences along each row / down each column.
    row_energy = np.sum(np.diff(row_band, axis=1) ** 2, axis=1)
    col_energy = np.sum(np.diff(col_band, axis=0) ** 2, axis=0)

    rate_of_change_x = np.max(row_energy) if row_energy.size else 0
    rate_of_change_y = np.max(col_energy) if col_energy.size else 0
    return rate_of_change_x, rate_of_change_y


def measure_derivs(img):
    """Return (x, y) sharpness scores for a colour image.

    The image is downscaled to 80 % and converted to greyscale, then scored by
    ``_band_gradient_scores``:

    * x is the largest per-row sum of squared horizontal differences within
      the middle fifth of the rows.
    * y is the largest per-column sum of squared vertical differences within
      the middle fifth of the columns.

    Args:
        img: Image in OpenCV BGR channel order (it is passed to
            ``cv2.cvtColor(..., cv2.COLOR_BGR2GRAY)``).

    Returns:
        Tuple ``(rate_of_change_x, rate_of_change_y)``. A score is 0 when its
        band is empty.
    """
    img = cv2.resize(img, dsize=(0, 0), fx=0.8, fy=0.8)
    imgray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
    return _band_gradient_scores(imgray)

def looping_usaf_autofocus(microscope, dz=2000, steps=20, backlash=100, plot=False,
                           undershoot=0, max_iterations=None):
    """Repeat ``usaf_autofocus`` until the sharpest frame is away from the scan edges.

    Each pass calls ``usaf_autofocus``, which leaves the stage at the best
    height minus ``undershoot``. Sometimes the height with the largest image
    standard deviation lies within ``dz / 10`` of either end of the scanned
    heights, so the true focus may be outside the scan. In that case the
    stage steps up by ``undershoot`` (back onto the peak) and the scan
    repeats, centred there.

    Args:
        microscope: OpenFlexure client object, forwarded to ``usaf_autofocus``.
        dz, steps, backlash, plot, undershoot: Forwarded to ``usaf_autofocus``.
        max_iterations: Optional cap on the number of scans. ``None`` (the
            default) retries until the peak is centred.

    Returns:
        The z height of the best frame from the final scan.
    """
    iteration = 0
    while True:
        all_stats = usaf_autofocus(microscope, dz, steps, backlash, plot, undershoot)
        iteration += 1
        heights = all_stats[:, 0]
        best_height = all_stats[np.argmax(all_stats[:, 1])][0]
        near_bottom = best_height - min(heights) < dz / 10
        near_top = max(heights) - best_height < dz / 10
        if not (near_bottom or near_top):
            print(f'Focused at {best_height}')
            return best_height
        if max_iterations is not None and iteration >= max_iterations:
            print(f'Peak still at edge of scan after {iteration} scans; stopping at {best_height}')
            return best_height
        microscope.move_rel([0, 0, undershoot])

def usaf_autofocus(microscope, dz, steps, backlash=0, plot=False, undershoot=0):
    """Scan z through ``dz`` around the current height and move to the sharpest frame.

    The scan proceeds as follows:

    1. The stage drops ``dz/2 + backlash`` and rises ``backlash``, so the
       scan starts with an upward move.
    2. ``steps`` frames are captured ``int(dz/steps)`` apart.
    3. Each frame's sharpness is the standard deviation of its greyscale
       pixels.
    4. The stage is then driven below the scan range and moved (absolute)
       to the best height minus ``undershoot``.

    Args:
        microscope: OpenFlexure client object.
        dz: Total z range to scan, in stage steps.
        steps: Number of frames to capture.
        backlash: Extra overshoot used when approaching positions.
        plot: If true, plot the focus metrics and the best frames.
        undershoot: Final position is this far below the best height.

    Returns:
        ``np.ndarray`` with one row per frame: z position, greyscale standard
        deviation, and ``microscope.grab_image_size()``. That last value is
        presumably a JPEG size, given the plot label; confirm. The values are
        the raw measurements whether or not ``plot`` is set.
    """
    microscope.move_rel([0, 0, -int(dz / 2 + backlash)])
    microscope.move_rel([0, 0, backlash])
    all_stats = []
    imgs = []
    for _ in range(steps):
        im = microscope.capture_image().convert('L')
        stats = ImageStat.Stat(im)
        all_stats.append([microscope.position['z'], stats.stddev[0], microscope.grab_image_size()])
        imgs.append(im)
        microscope.move_rel([0, 0, int(dz / steps)])
        time.sleep(0.5)
    # Go below the scan range; presumably so the absolute move below
    # approaches the target from underneath (backlash compensation).
    microscope.move_rel([0, 0, -(dz + backlash + undershoot)])
    all_stats = np.array(all_stats)
    best_height = all_stats[np.argmax(all_stats[:, 1])][0]
    microscope.move([microscope.position['x'], microscope.position['y'], best_height - undershoot])
    if plot:
        _plot_usaf_scan(all_stats, imgs)
    return all_stats


def _plot_usaf_scan(all_stats, imgs):
    """Plot both focus metrics against z and show the best frame for each metric."""
    heights = all_stats[:, 0]
    # Shift copies of the metric columns to start at zero so the caller's
    # array keeps its raw measurements.
    std_dev = all_stats[:, 1] - np.min(all_stats[:, 1])
    jpeg_size = all_stats[:, 2] - np.min(all_stats[:, 2])

    def _to_unit_peak(values):
        # A flat metric has peak 0; leave it at zero rather than producing NaN.
        peak = np.max(values)
        return values / peak if peak > 0 else values

    plt.subplot(2, 1, 1)
    plt.plot(heights, _to_unit_peak(std_dev), color='red', label='Standard Dev')
    plt.plot(heights, _to_unit_peak(jpeg_size), color='green', label='JPEG size')
    plt.legend()
    plt.subplot(2, 2, 3)
    plt.imshow(imgs[np.argmax(std_dev)], cmap='gray')
    plt.subplot(2, 2, 4)
    plt.imshow(imgs[np.argmax(jpeg_size)], cmap='gray')
    plt.show()

def usafocus(dz, steps, undershoot, microscope, backlash=500, plot=True):
    """Autofocus by scanning z and maximising the ``measure_derivs`` sharpness score.

    Each pass runs as follows:

    1. Drop ``dz/2 + backlash`` and rise ``backlash``.
    2. Capture ``steps`` frames ``int(dz/steps)`` apart, scoring each with
       ``measure_derivs``.
    3. The peak is the height of the largest x score or the largest y score,
       whichever axis reaches the higher maximum.
    4. If the peak lies within 30 % of the scanned range of either end, move
       to it and rescan.

    Once the peak is central, the stage moves to ``min(heights) - backlash``
    and then to ``peak - (undershoot + 100)``.

    Args:
        dz: Total z range per scan, in stage steps.
        steps: Frames per scan.
        undershoot: Final position is ``undershoot + 100`` below the peak.
        microscope: OpenFlexure client object.
        backlash: Overshoot used to approach positions from below
            (default 500, the previously hard-coded value).
        plot: Show the final scan's x/y scores (``plt.show()`` blocks).
    """
    start_t = time.time()
    step = int(dz / steps)
    while True:
        # Go below the scan start, then come back up to take up backlash.
        microscope.move_rel((0, 0, -int(dz / 2 + backlash)))
        microscope.move_rel((0, 0, backlash))

        heights = np.zeros(steps)
        deriv_score_x = np.zeros(steps)
        deriv_score_y = np.zeros(steps)

        for i in range(steps):
            im = microscope.grab_image_array()
            deriv_x, deriv_y = measure_derivs(im)
            heights[i] = microscope.position['z']
            deriv_score_x[i] = deriv_x
            deriv_score_y[i] = deriv_y
            microscope.move_rel((0, 0, step))
            time.sleep(1)

        if np.max(deriv_score_x) > np.max(deriv_score_y):
            peak = heights[np.argmax(deriv_score_x)]
        else:
            peak = heights[np.argmax(deriv_score_y)]

        lowest, highest = min(heights), max(heights)
        margin = 0.3 * abs(highest - lowest)
        if abs(peak - lowest) > margin and abs(peak - highest) > margin:
            break
        microscope.move((microscope.position['x'], microscope.position['y'], peak))

    if plot:
        plt.plot(heights, deriv_score_x)
        plt.plot(heights, deriv_score_y)
        plt.show()
    microscope.move((microscope.position['x'], microscope.position['y'], min(heights) - backlash))
    microscope.move((microscope.position['x'], microscope.position['y'], peak - (undershoot + 100)))
    print(peak, 'time taken to autofocus is {0}'.format(int(time.time() - start_t)))

def unpack(scan_data, start_index=0):
    """Pair each JPEG frame's size with the stage z height at its capture time.

    Stage heights are linearly interpolated with ``np.interp`` from the logged
    stage z positions at the JPEG timestamps.

    Args:
        scan_data: Mapping with keys ``'jpeg_times'``, ``'jpeg_sizes'``,
            ``'stage_times'`` and ``'stage_positions'``. Each position is
            indexable, with z at index 2. ``np.interp`` presumably expects
            ``stage_times`` to be increasing; confirm.
        start_index: Number of leading frames to drop from both outputs.

    Returns:
        ``(heights, sizes)``: a numpy array of interpolated z heights and a
        list of JPEG sizes divided by 1000.
    """
    frame_times = scan_data['jpeg_times']
    # NOTE(review): the divisor is 10**3. If 'jpeg_sizes' is in bytes this is
    # kB, not MB as the original variable name suggested; confirm the units.
    scaled_sizes = [size / 1000 for size in scan_data['jpeg_sizes']]
    log_times = scan_data['stage_times']
    z_track = [position[2] for position in scan_data['stage_positions']]

    frame_heights = np.interp(frame_times, log_times, z_track)
    return frame_heights[start_index:], scaled_sizes[start_index:]


def z_stack(dz, number_of_images, folder, microscope, backlash=100):
    """Capture a z-stack of Bayer images, download each, then return to the start height.

    For each image, the current z height is used in the capture filename
    (``bayer_<folder>/<height>``). The capture is tagged with the first path
    component of *folder*, downloaded via ``download_image``, and the stage
    then moves up by *dz*. At the end the stage moves back down past the start
    by *backlash* and up again by *backlash*, which for positive *dz* and
    *backlash* returns it to the start height from below.

    Args:
        dz: z step between images, in stage steps.
        number_of_images: Number of images in the stack.
        folder: Folder name used in the filename and (first component) as tag.
        microscope: OpenFlexure client object.
        backlash: Overshoot used when returning to the start height.
    """
    start = time.time()
    for _ in range(number_of_images):
        height = microscope.position['z']
        capture = microscope.capture_image_to_disk({
            "temporary": False,
            "use_video_port": False,
            "bayer": True,
            "filename": 'bayer_{0}/{1}'.format(folder, height),
            "tags": [
                folder.split('/')[0],
            ],
        })
        time.sleep(0.2)
        download_image(microscope, capture['id'])
        microscope.move_rel((0, 0, dz))
        time.sleep(0.9)

    microscope.move_rel((0, 0, -(dz * number_of_images + backlash)))
    microscope.move_rel((0, 0, backlash))
    print('z stack takes {0}'.format(int(time.time() - start)))

def download_image(microscope, id):
    data = microscope.get_capture_metadata(id)

    file_path = os.path.join('scans', data['path'].removeprefix("/var/openflexure/data/micrographs/"))
    file_path = os.path.split(file_path)[0]
    os.makedirs(file_path, exist_ok=True)
    im = microscope.download_from_id(id, file_path)