# -*- coding: utf-8 -*-
"""
Analyse z-stacks of edge images (one z-stack per folder) to recover resolution and field curvature.

"""
from __future__ import print_function
import matplotlib
matplotlib.use('TkAgg')
from matplotlib import pyplot as plt

import numpy as np
import cv2

import os
import sys
import os.path
from matplotlib.backends.backend_pdf import PdfPages
import psf_image
import analyse_distortion

def retrieve_cached_psfs(folder, fnames):
    """Attempt to retrieve the previously-analysed PSF/ESF from a file.

    Reads ``edge_analysis.npz`` from ``folder``. That file stores the arrays
    ``filenames``, ``psfs`` and ``esfs``. ``psfs``/``esfs`` carry a leading
    length-1 axis, so element 0 is taken.

    Parameters
    ----------
    folder : str
        Directory containing ``edge_analysis.npz``.
    fnames : list of str
        Filenames currently being analysed. The returned PSF/ESF arrays have
        ``len(fnames)`` rows along their first axis. If the cache holds fewer
        entries, the remaining rows are zero. Surplus cached entries are
        dropped.

    Returns
    -------
    psfs, esfs : numpy.ndarray
        Cached PSFs and ESFs, one row per entry of ``fnames``.
    cached_fnames : list of str
        The filenames recorded in the cache, so the caller can check that
        they still match ``fnames``.

    Raises
    ------
    FileNotFoundError
        If there is no cache file in ``folder``.
    KeyError
        If the cache file lacks one of the expected arrays.
    """
    cache_path = os.path.join(folder, "edge_analysis.npz")
    # allow_pickle is kept in case the cached arrays have object dtype.
    # Only load caches written by this script: pickled data can execute code.
    with np.load(cache_path, allow_pickle=True) as npz:
        cached_fnames = [str(f) for f in npz['filenames']]
        cached_psfs = np.asarray(npz['psfs'][0])
        cached_esfs = np.asarray(npz['esfs'][0])
    print(cached_fnames)

    psfs = np.zeros((len(fnames), ) + cached_psfs.shape[1:], dtype=cached_psfs.dtype)
    esfs = np.zeros((len(fnames), ) + cached_esfs.shape[1:], dtype=cached_esfs.dtype)
    # Copy only the rows that exist on both sides. A length mismatch is left
    # for the caller to detect through the returned filename list.
    n_rows = min(len(fnames), len(cached_fnames), len(cached_psfs), len(cached_esfs))
    psfs[:n_rows] = cached_psfs[:n_rows]
    esfs[:n_rows] = cached_esfs[:n_rows]
    return psfs, esfs, cached_fnames

def cache_psfs(folder, fnames, esfs, psfs):
    """Save analysed PSFs/ESFs to ``edge_analysis.npz`` in ``folder``.

    Any existing cache file is overwritten.

    Parameters
    ----------
    folder : str
        Directory to write the cache into.
    fnames : list of str
        Filenames the PSFs/ESFs were computed from, stored as ``filenames``.
    esfs, psfs : array_like
        Edge- and point-spread functions, stored as ``esfs`` and ``psfs``.
        ``retrieve_cached_psfs`` reads element 0 of each, so they are
        expected to carry a leading length-1 axis.
    """
    cache_path = os.path.join(folder, "edge_analysis.npz")
    np.savez(cache_path, filenames=fnames, psfs=psfs, esfs=esfs)

def _load_or_analyse_spread_functions(folder, fnames):
    """Return ``(psfs, esfs)`` for ``fnames``, reusing the on-disk cache if still valid.

    The cache is trusted only if two checks pass:

    - re-analysing the first image gives the same PSF as the cached one;
    - the cached filename list equals ``fnames``.

    Otherwise every image is analysed again and the cache is overwritten.
    Both returned arrays are indexed by image first, in ``fnames`` order.
    """
    try:
        print("Checking for an existing cache of results in this folder")
        cached_psfs, cached_esfs, cached_fnames = retrieve_cached_psfs(folder, fnames)
        cached_psfs = np.asarray(cached_psfs)
        cached_esfs = np.asarray(cached_esfs)
        print("Shape of all cached psfs is {0}".format(cached_psfs.shape))

        first_psfs, _, _ = psf_image.analyse_edge_image(os.path.join(folder, fnames[0]), save_plot=False)
        first_psfs = np.asarray(first_psfs)
        print("Shape of the psf of the first image is {0}".format(first_psfs.shape))

        # Re-analysing only the first image is a cheap check that the cache
        # came from the current analysis code. It is likely, not guaranteed,
        # that the other images would match too.
        if not np.array_equal(cached_psfs[0], first_psfs):
            print('cached psfs changed')
            raise Exception('Cache and recent code don\'t match, redoing and overwriting cache')
        # Also make sure the cache covers exactly the images in this folder.
        if list(cached_fnames) != list(fnames):
            print('filenames changed from cache')
            raise Exception('Cache and recent code don\'t match, redoing and overwriting cache')
    except Exception as err:  # any cache problem just means "recompute"
        print("Cached PSFs failed, changed or missing, analysing images (may take some time)")
        print("(reason: {})".format(err))
    else:
        print("The cache actually worked!")
        return cached_psfs, cached_esfs

    spread_functions = psf_image.analyse_files([os.path.join(folder, f) for f in fnames], output_dir=folder)
    # NOTE(review): assumes analyse_files returns a dict whose iteration order
    # follows the input file order, so rows line up with fnames. Confirm in psf_image.
    keys = list(spread_functions.keys())
    psfs = np.array([spread_functions[k]['psfs'] for k in keys])
    esfs = np.array([spread_functions[k]['esfs'] for k in keys])
    print(esfs.shape)

    # The on-disk format has a leading length-1 axis. retrieve_cached_psfs
    # strips it again with [0].
    cache_psfs(folder, fnames, esfs[np.newaxis], psfs[np.newaxis])
    return psfs, esfs

def _psf_width_variance(brightnesses):
    """Return the intensity-weighted variance of every PSF profile along the last axis.

    Each profile ``brightnesses[..., :]`` is weighted by ``|brightness|``. The
    variance of the pixel index around its weighted mean is returned, with
    the last axis removed. Smaller values mean a narrower, sharper PSF.
    Profiles with zero total weight get ``inf``, so ``argmin`` never picks them.
    """
    weights = np.abs(np.asarray(brightnesses, dtype=float))
    ys = np.arange(weights.shape[-1])
    total = np.sum(weights, axis=-1)
    safe_total = np.where(total > 0, total, 1.0)
    mean_ys = np.sum(weights * ys, axis=-1) / safe_total
    variance = np.sum(weights * (ys - mean_ys[..., np.newaxis]) ** 2, axis=-1) / safe_total
    return np.where(total > 0, variance, np.inf)

def analyse_zstack(folder):
    """Find the point spread function of each image in a Z stack series.

    ``folder`` should hold one z-stack of ``.jpg``/``.jpeg`` edge images.
    The stage position of each image is taken from its filename. Files whose
    names contain "raw" are used if any exist.

    Side effects:

    - writes or overwrites ``edge_analysis.npz`` in ``folder`` when the cache
      is missing or stale;
    - ``psf_image.analyse_files`` is given ``output_dir=folder`` and may
      write its own output there;
    - writes ``edge_zstack_summary.pdf`` to ``folder``. It contains the first
      image, YZ slices of the PSF for each block and colour channel, and the
      z index of best focus for each block and channel.

    Raises
    ------
    AssertionError
        If the folder contains no jpeg files.
    """
    # The folder usually contains raw and non-raw versions - use the raw ones if possible
    fnames = [f for f in os.listdir(folder) if f.endswith((".jpeg", ".jpg"))]
    assert len(fnames) > 0, "There were no files in the folder '{}'".format(folder)
    raw_fnames = [f for f in fnames if "raw" in f]
    if raw_fnames:
        fnames = raw_fnames

    # Sort the images by the stage position encoded in each filename.
    # Ties are broken by filename. Each position is reduced to a single int,
    # so it is the z coordinate.
    positions = [int(analyse_distortion.position_from_filename(f, folder)) for f in fnames]
    ordered = sorted(zip(positions, fnames))
    fnames = [f for _, f in ordered]
    zs = np.array([z for z, _ in ordered])

    psfs, esfs = _load_or_analyse_spread_functions(folder, fnames)

    # Overview of the first image
    img = cv2.imread(os.path.join(folder, fnames[0]), -1)
    if img is not None and img.ndim == 3 and img.shape[2] == 3:
        # OpenCV loads colour images as BGR; matplotlib expects RGB.
        img = img[:, :, ::-1]
    fig_img, ax_img = plt.subplots(1, 1)
    ax_img.imshow(img)

    # plot the PSFs as YZ slices
    blocks = psfs.shape[1]
    channels = 3
    rows = ['Red', 'Green', 'Blue']
    cols = ['Block {}'.format(block) for block in range(1, blocks + 1)]

    # pcolormesh wants edge coords not centre coords
    z_bounds = np.concatenate(([zs[0]], (zs[1:] + zs[:-1]) / 2., [zs[-1]]))
    y_bounds = np.arange(psfs.shape[-1] + 1) - psfs.shape[-1] / 2.0

    # squeeze=False keeps ``axes`` 2-D even when there is a single block.
    fig, axes = plt.subplots(channels, blocks, figsize=(2 * blocks, 2 * channels),
                             sharey=True, squeeze=False)

    pad = 5  # in points
    # Annotate the plots with block numbers (columns) and colours (rows)
    for ax, col in zip(axes[0], cols):
        ax.annotate(col, xy=(0.5, 1), xytext=(0, pad),
                    xycoords='axes fraction', textcoords='offset points',
                    size='large', ha='center', va='baseline')
    for ax, row in zip(axes[:, 0], rows):
        ax.annotate(row, xy=(0, 0.5), xytext=(-ax.yaxis.labelpad - pad, 0),
                    xycoords=ax.yaxis.label, textcoords='offset points',
                    size='large', ha='right', va='center')

    fig.tight_layout()
    # tight_layout ignores these annotations, so make room manually.
    # The numbers were tweaked by hand.
    fig.subplots_adjust(left=0.15, top=0.95)

    print(psfs.shape)

    # Extract the brightnesses (index 1 of axis 3) and normalise only those,
    # not the pixel positions. Negative values are clipped to 0 before the
    # uint8 cast, which would otherwise be undefined for them.
    brightnesses = psfs[:, :, :, 1, :]
    peak = np.max(brightnesses)
    if peak > 0:
        brightnesses8 = (np.clip(brightnesses, 0, None) / peak * 255).astype(np.uint8)
    else:
        brightnesses8 = np.zeros(brightnesses.shape, dtype=np.uint8)

    for block in range(blocks):
        for channel in range(channels):
            axes[channel, block].pcolormesh(y_bounds, z_bounds, brightnesses8[:, block, channel],
                                            shading='auto')

    # Calculate a sharpness metric (PSF width per z, block and channel).
    # Its minimum over z gives the in-focus image index, i.e. the field curvature.
    widths = _psf_width_variance(brightnesses)
    field_curvature = np.argmin(widths, axis=0)  # todo: use the thresholded centre of mass code from USAF
    fig2, ax2 = plt.subplots(1, 1)
    for i, color in enumerate(['red', 'green', 'blue']):
        ax2.plot(field_curvature[:, i], color=color, label=color)
    ax2.legend()
    ax2.set_xlabel("Block across edge")
    ax2.set_ylabel("Focused position in z stack (images)")

    with PdfPages(os.path.join(folder, "edge_zstack_summary.pdf")) as pdf:
        for f in [fig_img, fig, fig2]:
            pdf.savefig(f)
            plt.close(f)

    print('folder done')
    

if __name__ == "__main__":
    print("\n")
    folders = sys.argv[1:]
    # Validate every argument before doing any work, so a bad path gives the
    # usage message instead of an os.listdir crash part-way through a run.
    if not folders or not all(os.path.isdir(folder) for folder in folders):
        print("Usage: {} <folder> [<folder> ...]".format(sys.argv[0]))
        print("This script expects arguments that are folders, containing jpeg files of edge images.")
        print("This scripts creates a number of PDF plots based on said images, per folder.")
        sys.exit(1)

    for folder in folders:
        analyse_zstack(folder)