"""
Take a single image with a nearly horizontal or vertical edge, and return the point spread function of the system
"""
import sys
import yaml
import matplotlib
from matplotlib import pyplot as plt
import matplotlib.patches
import numpy as np
import os
from skimage.io import imread
import cv2
import skimage
from skimage.feature import peak_local_max
import scipy
from extract_raw_image import load_raw_image
from matplotlib.backends.backend_pdf import PdfPages
from ast import literal_eval
import json
import piexif
import logging
import time
# matplotlib.use('TKAgg')

def analyse_files(fnames, output_dir="."):
    """Analyse a number of edge images and collect their spread functions.

    Every file is passed to analyse_edge_image.  The figure it returns is
    appended to "edge_analysis.pdf" in output_dir, and the edge / point spread
    functions of all files are dumped to "spread_functions.yaml" in the same
    folder.

    Arguments:
        fnames: iterable of image file names to analyse
        output_dir: folder receiving the combined PDF and the YAML file

    Returns:
        dict mapping each file name to {"psfs": ..., "esfs": ...}
    """
    spread_functions = {}
    with PdfPages(os.path.join(output_dir, "edge_analysis.pdf")) as pdf:
        for fname in fnames:
            fig, psfs, esfs = analyse_edge_image(fname)
            pdf.savefig(fig)
            plt.close(fig)
            spread_functions[fname] = {"psfs": psfs, "esfs": esfs}
            print('done {0}'.format(fname))

    # Look the entries up by name rather than unpacking the keys of the inner
    # dict positionally, which only works while it has exactly two keys in a
    # known insertion order.
    yaml_data = {
        fname: {"esfs": functions["esfs"], "psfs": functions["psfs"]}
        for fname, functions in spread_functions.items()
    }
    with open(os.path.join(output_dir, "spread_functions.yaml"), 'w') as outfile:
        yaml.dump(yaml_data, outfile)
    print('spread functions length is {0}, type is {1}'.format(len(spread_functions), type(spread_functions)))

    return spread_functions


def analyse_edge_image(fname, blocks = 9, save_plot = True):
    '''Analyse a single image, split into `blocks` equally spaced blocks along the edge.

    The edge positions found by find_esf are written to a JSON file next to
    the image (same path, extension replaced by ".json").

    Arguments:
        fname: path of the image to analyse
        blocks: number of blocks the edge is divided into
        save_plot: if True, plot the result with plot_to_pdf and return the figure

    Returns:
        (fig, psfs, esfs) if save_plot is True, otherwise (psfs, psfs, esfs)

    Raises:
        ValueError: if arrange_imgs cannot bring the edge into a usable orientation
    '''

    print("Processing {}...".format(fname))

    # Load the images, align as needed, find the edge spread and point spread function
    raw_image, rgb_image = load_images(fname)
    arranged = arrange_imgs(rgb_image, raw_image)
    if isinstance(arranged, str):
        # arrange_imgs signals failure with a plain string; unpacking it
        # blindly would raise a confusing "too many values to unpack" error.
        raise ValueError("Could not align the edge in {}: {}".format(fname, arranged))
    rgb_image, raw_image = arranged
    esfs, m, y_inc, tx_all = find_esf(rgb_image, raw_image, blocks)

    # splitext only strips the extension, so dots elsewhere in the path
    # (e.g. "./folder/img.jpg") no longer truncate the output name.
    with open(os.path.splitext(fname)[0] + '.json', 'w') as fp:
        json.dump(tx_all, fp)
    psfs = psf_from_esfs(esfs)

    if save_plot:
        fig = plot_to_pdf(rgb_image, y_inc, m, esfs, psfs, blocks, fname)
        return fig, psfs, esfs

    # NOTE(review): the first element is psfs rather than a figure; kept as-is
    # so existing callers see the same 3-tuple - confirm whether None was intended.
    return psfs, psfs, esfs


def pull_usercomment_dict(filepath):
    """
    Load the Exif UserComment field of a file and parse it as JSON.

    Args:
        filepath: Path to the Exif-containing file

    Returns:
        The decoded JSON content of the UserComment tag, or None when the file
        holds invalid image data, has no UserComment tag, or the tag is not
        valid JSON (a message is logged in the first and last case).
    """
    try:
        exif_data = piexif.load(filepath)
    except piexif._exceptions.InvalidImageDataError:
        logging.warning("Invalid data at {}. Skipping.".format(filepath))
        return None

    # Nothing to decode unless the Exif IFD exists and carries a UserComment
    comment_tag = piexif.ExifIFD.UserComment
    if "Exif" not in exif_data or comment_tag not in exif_data["Exif"]:
        return None

    raw_comment = exif_data["Exif"][comment_tag]
    try:
        return json.loads(raw_comment.decode())
    except json.decoder.JSONDecodeError:
        logging.error(
            f"Capture {filepath} has old, corrupt, or missing OpenFlexure metadata. Unable to reload to server."
        )
    return None

def csm_from_tags(fname):
    '''Pull the tags from the image being worked on and return the camera
    stage mapping stored under
    instrument / settings / extensions / org.openflexure.camera_stage_mapping / image_to_stage_displacement.

    If the metadata cannot be read or does not contain that entry,
    return [0, 0], indicating that the CSM isn't known.'''

    try:
        metadata = pull_usercomment_dict(fname)
        return metadata['instrument']['settings']['extensions']['org.openflexure.camera_stage_mapping']['image_to_stage_displacement']
    except Exception:
        # Best effort: any failure to read or navigate the metadata means
        # "unknown".  Unlike a bare except, this lets KeyboardInterrupt and
        # SystemExit propagate.
        return [0, 0]

def pixels_to_steps(pixels, csm):
    '''Take an 'xy' pixel move and the camera stage mapping, and turn it into a number of steps to move.

    Arguments:
        pixels: either a length-2 [x, y] move, or a single scalar which is
            treated as the move [scalar, 0]
        csm: the camera stage mapping matrix the move is multiplied by

    Returns:
        the matrix product of the pixel move and csm
    '''

    # Any zero-dimensional value counts as a scalar move.  np.ndim also covers
    # numpy scalars such as np.int64, which are not instances of int.
    if np.ndim(pixels) == 0:
        pixels = np.array([pixels, 0])

    return np.matmul(pixels, csm)

def xy_pixels_to_um(pixels, csm):
    '''Convert an 'xy' distance in pixels into microns using a fixed scale factor.

    The csm argument is accepted but currently not used: the conversion via
    pixels_to_steps and an assumed 72 nm per step is disabled.
    '''

    # Scale factors that have been used with this function, by objective:
    #   100x: 0.0435    60x: 0.0712    40x: 0.1073    20x: 0.2193    10x: 0.4404
    #TODO: add metadata for this conversion from micat
    um_per_pixel = 0.1073 # for 40x
    return pixels * um_per_pixel

def z_steps_to_um(steps):
    '''Convert a z move given in motor steps into an estimated distance in microns'''

    # Each step is taken to be 50 nm; dividing by 1000 converts nm to microns
    #TODO: add metadata for this conversion from micat
    nm_per_step = 50
    return steps * nm_per_step / 1000

def detect_edge(rgb_img, boundary):
    '''Take an RGB image, and a brightness value between 0 and 255 to be the cutoff between light and dark.

    Returns a tuple (details, status):
        ((m, x1, y1), 'success') if the edge is aligned as desired, where m is
            the gradient of the edge and [x1, y1] the point on it nearest the
            image centre (see line_from_xy)
        (0, "Horizontal edge") if the angle is within 81-99, i.e. the image needs rotating
        (0, "Poor edge angle") if the angle is in neither accepted range
        (0, "No angle on edge") if no peak is found in the sinogram
        (0, 'flipped') if the left side is brighter than the right side
    '''

    imgray = cv2.cvtColor(rgb_img, cv2.COLOR_RGB2GRAY)

    # The strongest sinogram peak describes the edge: its row index is the
    # offset of the line, its column index the angle of the line
    sino, _ = find_sino(imgray, boundary)
    peaks = peak_local_max(sino, threshold_abs=8, num_peaks=1, exclude_border=False)
    if len(peaks) == 0:
        # No edge strong enough to register; report it through an existing
        # status instead of failing with an IndexError.
        return 0, "No angle on edge"
    xy = peaks[0]
    angle = xy[1]

    # Only angles within 9 of 0 (<= 9 or >= 171) are accepted. Angles within
    # 9 of 90 can be fixed by rotating the image; anything else is rejected.
    # (The former check for an angle of exactly 0 or 90 could never be
    # reached, because both values are caught by these ranges first.)
    if 81 <= angle <= 99:
        return 0, "Horizontal edge"
    if 9 < angle < 171:
        return 0, "Poor edge angle"

    # Crop out an area from either end of the image. We expect the left side to be dark, right side to be light.
    # Use at least one column so narrow images don't produce an empty slice.
    edge_region = max(1, int(rgb_img.shape[1] / 70))
    dark = np.mean(imgray[:, :edge_region])
    light = np.mean(imgray[:, -edge_region:])

    # If not, return this info
    if dark > light:
        return 0, 'flipped'

    # Find the equation of the line from the angle and distance
    # m is the gradient, [x1, y1] is the coords of the nearest point to the centre
    m, x1, y1 = line_from_xy(xy, rgb_img, sino)

    return (m, x1, y1), 'success'

def rotate_img(img):
    '''Swap the first two axes of a 3D image array (rows <-> columns), leaving the channel axis last.
    This moves an edge that runs along the first array index to the second.'''
    return np.transpose(img, (1, 0, 2))

def flip(img):
    '''Mirror the image by reversing its second axis (column order),
    so that falling edges become rising edges'''
    return np.flip(img, axis=1)

def find_sino(imgray, boundary):
    '''Compute the sinogram of the edges in a grayscale image.

    boundary is the brightness value (between 0 and 255) used to threshold the
    image into light and dark before the edges are extracted.

    Returns (sinogram, edges), where edges is the Canny edge image of the
    downscaled, thresholded input.'''

    # Work on a much smaller copy: faster, and gives one overall gradient
    # rather than many local, smaller gradients
    scale = 0.04
    small = cv2.resize(imgray, (0, 0), fx = scale, fy = scale)

    # Binarise at the light / dark boundary
    _, binary = cv2.threshold(small, boundary, 255, cv2.THRESH_BINARY)

    # Extract the lines from the binary image, then take the radon transform of them
    edges = cv2.Canny(binary, 100, 200)
    sinogram = skimage.transform.radon(edges, theta=None, circle=False, preserve_range=False)

    # The edge was found on the low-res image, so stretch the sinogram rows
    # back by the same factor. The columns are angles and are left unscaled.
    sinogram = cv2.resize(sinogram, (0, 0), fx=1, fy=1 / scale)

    return sinogram, edges

def line_from_xy(xy, img, sino):
    '''Return the gradient of a line and the coordinates of its point nearest the image centre.

    Arguments:
        xy: [row, angle] position of the peak in the sinogram; the angle is in degrees
        img: the image the line lies in (only its shape is used)
        sino: the sinogram the peak was found in (only its shape is used)

    Returns:
        m, x1, y1: gradient of the line and the coordinates of its nearest point

    xy is not modified.
    '''

    # r needs to be based on centre of image, not corner.
    # Compute it as a separate float: writing it back into xy would truncate
    # the half-pixel offset when xy is an integer array, and would alter the
    # caller's array.
    r = xy[0] - sino.shape[0] / 2
    angle = xy[1]
    angle_rad = np.deg2rad(angle)

    m = np.tan(np.deg2rad(90 - angle))
    x1 = r * np.cos(angle_rad) + img.shape[1] / 2
    y1 = -r * np.sin(angle_rad) + img.shape[0] / 2

    return m, x1, y1

def arrange_imgs(rgb_img, raw_img):
    '''Take the RGB and raw Bayer image, and rotate / flip them as needed.

    The output should be the images aligned so that the edge is vertical(ish)
    and the transition goes from dark to light (low to high brightness).
    If raw_img is None, the RGB image is used in its place.

    Returns:
        (rgb_img, raw_img) once detect_edge reports success, or the string
        "failed" if detect_edge can't find an angle or the angle is unsuitable

    Raises:
        RuntimeError: if detect_edge reports a status this function does not know
    '''

    if raw_img is None:
        raw_img = rgb_img

    # brightness cutoff between the dark and light side of the edge
    boundary = 90

    # run until 'detect_edge' returns success
    result = 0
    while result != 'success':
        _, result = detect_edge(rgb_img, boundary)

        if result == "Horizontal edge":
            rgb_img = rotate_img(rgb_img)
            raw_img = rotate_img(raw_img)
        elif result == "flipped":
            rgb_img = flip(rgb_img)
            raw_img = flip(raw_img)
        elif result in ("Poor edge angle", "No angle on edge"):
            # the angle is missing or unsuitable for the ESF finder
            print(result)
            return "failed"
        elif result != 'success':
            # An unknown status would otherwise repeat the same call forever
            raise RuntimeError("Unexpected result from detect_edge: {}".format(result))

    return (rgb_img, raw_img)

def load_images(fname):
    '''Load an image as a Bayer image plus an RGB copy if possible,
    otherwise just load the image as RGB.

    Returns (raw_image, rgb_image); raw_image is None when the raw data
    could not be loaded.'''

    def scale_channels(img):
        # Divide each of the three channels by its own maximum, then map to ints 0-255
        peaks = [np.amax(img[:, :, c]) for c in range(3)]
        return (np.divide(img, peaks)*255).astype('uint8')

    try:
        # Try to load it as a bayer image and demosaic a copy to rgb
        bayer_array = load_raw_image(fname)
        rgb_image = scale_channels((bayer_array.demosaic()//4).astype(np.uint8))
        raw_image = scale_channels(bayer_array.array)
    except Exception:
        logging.warning("Can't load raw data, falling back to JPEG data")
        raw_image = None
        rgb_image = imread(fname)

    return raw_image, rgb_image

def psf_from_esfs(esfs):
    '''From a list of edge spread functions, differentiate to find the point spread functions.

    esfs holds, per block, one (x, I, tx) entry per channel.  Each entry is
    differentiated with numerical_diff; an entry whose x is a plain int is
    replaced by a pair of zero arrays of length 799 instead.'''

    #TODO: timing - this could probably be an array, but takes 6ms anyway
    def differentiate(x, I):
        if isinstance(x, int):
            return [np.zeros(799),np.zeros(799)]
        return numerical_diff(x, I, sqrt=False, crop=100)

    return [[differentiate(x, I) for x, I, _ in esf] for esf in esfs]

def plot_to_pdf(rgb_image, y_inc, m, esfs, psfs, blocks, fname):
    '''Take the rgb image, information about the edge, esfs / psfs, block and filename,
    and plot a series of graphs showing the edge response.

    Arguments:
        rgb_image: the aligned RGB image
        y_inc, m: intercept and gradient of the fitted edge, y = m * x + y_inc
        esfs, psfs: edge / point spread functions, per block and channel
        blocks: number of blocks the edge was split into
        fname: path of the original image; also determines the output name

    Returns:
        the matplotlib figure, which is also saved as "<fname without extension>_edge_analysis.pdf"
    '''

    #TODO: timing - 1.5 seconds

    with matplotlib.rc_context(rc={"font.size":6}):
        print(" plotting...",end="")

        # get the camera stage mapping matrix from the metadata of the images
        csm = csm_from_tags(fname)

        # set up the plots
        fig, ax = plt.subplots(2, 3, figsize=(6,9), gridspec_kw={'width_ratios': [2, 3, 8], 'height_ratios': [1, 7]})
        fig.suptitle(fname)

        # show the image without it rotated or flipped
        #TODO: is there anything else to show at the top of the page?
        ax[0][0].imshow(cv2.cvtColor(cv2.imread(fname, -1), cv2.COLOR_BGR2RGB))
        for top_ax in ax[0]:
            top_ax.set_axis_off()

        # Plot the aligned, cropped image on the left, and overlay the fitted line
        image_ax = ax[1][0]
        ys = np.arange(rgb_image.shape[0])
        xs = (ys - y_inc)/m
        margin = 5*10
        # Clamp at zero: a negative start would make the slice wrap around
        # to the far side of the image
        crop_start = max(0, int(np.min(xs) - margin))
        crop_stop = int(np.max(xs) + margin)
        image_ax.imshow(rgb_image[:, crop_start:crop_stop, ...])
        image_ax.xaxis.set_visible(False)
        image_ax.yaxis.set_visible(False)
        # The overlay is drawn in the coordinates of the cropped image
        image_ax.plot(xs - crop_start, ys, color="red", dashes=(2,8))

        # Display the ESF for each block of the image, each channel normalised to its maximum over all blocks
        esf_ax = ax[1][1]
        esf_max = [0, 0, 0]
        for esf in esfs:
            for channel, (x, I, Tx) in enumerate(esf):
                esf_max[channel] = max(esf_max[channel], np.max(I))

        for i, esf in enumerate(esfs):
            # stack the blocks vertically, first block at the top
            offset = (len(esfs) - i)*0.75
            for channel, col, (x, I, Tx) in zip(range(3), ['red', 'green', 'blue'], esf):
                normI = I/esf_max[channel] + offset
                esf_ax.plot(x, normI, color=col)

        # Display the PSF for each block of the image
        psf_ax = ax[1][2]
        psf_max = [0, 0, 0]
        for psf in psfs:
            for channel, (x, I) in enumerate(psf):
                psf_max[channel] = max(psf_max[channel], np.max(I))

        for i, psf in enumerate(psfs):
            offset = (len(psfs) - i)*0.75
            for channel, col, (x, I) in zip(range(3), ['red', 'green', 'blue'], psf):
                normI = I/psf_max[channel] + offset
                psf_ax.plot(x, normI, color=col)
                # called for its annotations; the returned width is not needed here
                find_fwhm(x, normI, zerolevel=offset, annotate_ax=psf_ax, color="black", text_y=offset+(channel+2)*0.1, channel = col, csm = csm)

            # Label the block where the edge crosses the block centre: the
            # edge is y = m * x + y_inc, so x = (y - y_inc) / m, shifted into
            # the coordinates of the cropped image.
            centre_y = rgb_image.shape[0]*(i+0.5)/blocks
            centre_x = (centre_y - y_inc)/m - crop_start
            image_ax.annotate(str(i), xy = (centre_x, centre_y), fontsize = 'smaller')

    # save the figure as a pdf for that specific image, and return the figure if they are combined into a single file with other image results
    # splitext only removes the extension, so dots elsewhere in the path are kept
    fig.savefig(os.path.splitext(fname)[0] + "_edge_analysis.pdf")

    return fig

def subpixel_averaging(x, I, dx = 0.25):
    """Average x and I over consecutive groups of samples delimited by a grid of spacing dx.

    The grid runs from x.min() towards x.max() in steps of dx.  Samples are
    visited in the order given and collected while they lie below the current
    grid value.  The first sample at or above it closes the group (if it is
    not empty), advances the grid by one step and is itself discarded.
    Samples still being collected when the loop ends are not returned.

    Returns:
        x_aved, I_aved: lists of the group means, at most 398 entries each
    """
    import statistics

    grid = np.arange(x.min(), x.max(), dx)
    grid_idx = 1

    x_groups, i_groups = [], []
    current_x, current_i = [], []

    for k, x_val in enumerate(x):
        if grid_idx == len(grid):
            break
        if x_val < grid[grid_idx]:
            current_x.append(x_val)
            current_i.append(I[k])
            continue

        # The sample is past the current grid value: close the group
        if current_x:
            x_groups.append(current_x)
            i_groups.append(current_i)
        grid_idx += 1
        current_x, current_i = [], []

    x_aved = [statistics.mean(group) for group in x_groups]
    I_aved = [statistics.mean(group) for group in i_groups]

    return x_aved[:398], I_aved[:398]

def local_linear_regression(x, I, dx=0.1, kernel_width=1):
    """Estimate a smooth curve through scattered points by local linear regression.

    Arguments:
        x, I: coordinates of the data points (numpy arrays, any order)
        dx: spacing of the points at which the regression curve is estimated
        kernel_width: width (standard deviation) of the Gaussian smoothing kernel

    Return Value:
    xs, Is: X and I coordinates of the regression estimator.  The estimate is
    computed at exactly 1000 positions spaced dx apart, starting from the
    minimum of x: a longer grid is trimmed symmetrically, a shorter one is
    padded by repeating its last position.  The first and last 10 positions
    are dropped from the result.

    Details:
    This is essentially a dumb implementation of the formula given on page 50
    of Bowman & Azzalini (1997), "Applied smoothing techniques for data analysis".
    """

    #TODO: timings - 0 to 3 seconds - would be quicker if all weights were pre-calculated

    n_estimates = 1000
    grid = np.arange(x.min(), x.max(), dx)

    # Too many positions: drop the surplus from both ends (the larger half
    # from the end) - shouldn't be anything interesting there
    surplus = len(grid) - n_estimates
    if surplus > 0:
        drop_start = surplus // 2
        drop_end = surplus - drop_start
        grid = grid[drop_start:-drop_end]

    # Too few positions: repeat the last one to make it up to the full count
    if len(grid) != n_estimates:
        padding = np.full(n_estimates - len(grid), grid[-1])
        grid = np.append(grid, padding)

    # 2D array of the difference in X between every data point (rows)
    # and every estimator position (columns)
    offsets = x[:, np.newaxis] - grid[np.newaxis, :]

    #TODO: timings - this is slow
    weights = np.exp(-offsets**2 / (kernel_width**2 * 2))

    # Weighted moments of the offsets for each estimator position
    s0 = np.mean(weights, axis=0)[np.newaxis, :]
    s1 = np.mean(weights * offsets, axis=0)[np.newaxis, :]
    s2 = np.mean(weights * offsets**2, axis=0)[np.newaxis, :]

    #TODO timings - this is slow
    numerator = (s2 - s1*offsets)*weights*I[:,np.newaxis]
    denominator = s2*s0 - s1**2
    estimate = np.mean(numerator / denominator, axis=0)

    return grid[10:-10], estimate[10:-10]


def meanest(x, I, dx=0.1, kernel_width=1):
    """Average the intensities of all points that fall into the same bin of width dx.

    Arguments:
        x, I: coordinates of the data points (numpy arrays, any order)
        dx: width of the bins, counted from the minimum of x
        kernel_width: unused, kept so the signature matches local_linear_regression

    Return Value:
    bins, I_est: for every bin that received at least one point, the bin
    number ((x - x.min()) // dx) and the mean intensity of its points, in
    order of first occurrence.  Note that the first array holds bin numbers,
    not x coordinates.

    Only points below the largest position of the 1000-point regression grid
    (built exactly as in local_linear_regression) are used.
    """

    regression_x = np.arange(x.min(), x.max(), dx)

    # We would expect 1000 points. If we're over, remove the edges - shouldn't be anything interesting there
    if len(regression_x) > 1000:
        remove_end = int(np.ceil((len(regression_x) - 1000) / 2))
        remove_start = int(np.floor((len(regression_x) - 1000) / 2))

        regression_x = regression_x[remove_start:-remove_end]

    # If we're under, repeat the last point to make it up to 1000 points
    if len(regression_x) != 1000:
        regression_x = np.append(regression_x, np.full(1000-len(regression_x), regression_x[-1]))

    # Both limits are the same for every point, so work them out once
    upper_limit = regression_x.max()
    x_min = x.min()

    # Collect the intensities of each bin
    binned = {}
    for x_point, I_point in zip(x, I):
        if x_point < upper_limit:
            position = (x_point - x_min) // dx
            binned.setdefault(position, []).append(I_point)

    bins = np.array(list(binned.keys()))
    I_est = np.array([np.mean(values) for values in binned.values()])
    return bins, I_est

def find_fwhm(x, I, zerolevel=0, annotate_ax=None, color="black", text_y=None, channel = 'unknown', csm = [[0, 0], [0, 0]]):
    '''Find the full width at half maximum of a point spread function, optionally annotating a plot with it.

    The half maximum is measured relative to zerolevel.  The width runs from
    the first to the last sample above the half maximum.  If the curve crosses
    the half maximum again between those two samples, the PSF is not a clean
    single peak and a FWHM of 50 is returned instead - chosen to be higher
    than any we'd expect.

    When annotate_ax is given, the width is written onto that axis in pixels
    and, converted with xy_pixels_to_um, in microns; csm is passed on to that
    conversion.'''

    # Half maximum, and the outermost samples lying above it
    threshold = (np.max(I) - zerolevel) / 2 + zerolevel
    above = I > threshold
    ileft = np.argmax(above)
    iright = len(I) - 1 - np.argmax(above[::-1])
    fwhm = x[iright] - x[ileft]

    # Between those samples, count the sign changes of (I - threshold):
    # each one is an additional crossing of the half maximum
    centred = I[ileft : iright + 1] - threshold
    crosses = int(np.count_nonzero(centred[:-1] * centred[1:] < 0))

    label_y = threshold if text_y is None else text_y

    if crosses >= 1:
        if annotate_ax is not None:
            annotate_ax.text(annotate_ax.get_xlim()[0] + 1, label_y,
                             "{0} fwhm failed".format(channel),
                             color=color,)
        return 50

    if annotate_ax is not None:
        annotate_ax.annotate("{0} fwhm {1:.2f}".format(channel, fwhm),
                             xy=(x[ileft], threshold),
                             xytext=(annotate_ax.get_xlim()[0] + 1, label_y),
                             arrowprops = dict(color=color, shrink=0, width = 0.2, headwidth = 1),
                             color=color,)
        um_fwhm = xy_pixels_to_um(np.array([fwhm, 0]), csm)
        annotate_ax.text(annotate_ax.get_xlim()[1]-20, label_y,
                         "{0:.3f} μm".format(np.max(np.abs(um_fwhm))),
                         color=color,)

    return fwhm

def numerical_diff(x, y, crop=100, sqrt=False):
    """Numerically differentiate a vector of equally spaced values

    Arguments:
        x: the independent variable, assumed to be evenly spaced
        y: the dependent variable to be differentiated
        crop: return a shorter array than the input by removing
            this many points from the start and end (default: 100).
            With a value of 0 or less nothing is removed.
        sqrt: instead of calculating d/dx y, calculate (d/dx y**0.5)**2
            to recover the point spread function from an edge (because
            it's the underlying complex function, and not the intensity,
            that we really want to differentiate).
    
    Returns: x_midpoints, diff_y
        The two 1D arrays are the mid-points of the x coordinates (i.e. 
        the X coordinates of the differentiated y values), and the 
        numerically differentiated y values (scaled by the difference 
        between x points to approximate the true derivative).
    """

    # Cropping removes a fixed number of points (not pixels), which keeps the
    # shape of the output consistent
    if crop > 0:
        cropped_x = x[crop:-crop]
    else:
        print('Warning, numerical diff may struggle at the start of an image, safer to crop')
        # Without this, cropped_x would be undefined below
        cropped_x = x

    mid_x = (cropped_x[1:] + cropped_x[:-1])/2.0

    # The step size is taken as the mean spacing of the uncropped x values
    step = np.mean(np.diff(x))
    shifted_y = y - np.min(y)
    if sqrt:
        diff_y = (np.diff(shifted_y**0.5)/step)**2
    else:
        diff_y = np.diff(shifted_y)/step
    if crop > 0:
        diff_y = diff_y[crop:-crop]

    return mid_x, diff_y

def find_esf(rgb_image, raw_image, blocks):
    '''Take the input images and number of blocks, make sure the edge is suitable,
    then find the edge spread function of every block and colour channel.

    Arguments:
        rgb_image: aligned RGB image, used to locate the edge
        raw_image: aligned raw image the ESFs are taken from; if None, rgb_image is used
        blocks: number of blocks the image is split into along the edge

    Returns:
        esfs: array holding, per block and channel, the (x, I, tx) arrays of the ESF
        m, y_inc: gradient and intercept of the edge, y = m * x + y_inc
        tx_all: dict mapping the block number (as a string) to a list of edge
            positions, one per channel that could be aligned

    Raises:
        ValueError: if detect_edge does not report success for rgb_image
    '''

    # if we're working on a JPEG without bayer data, treat it as a raw image anyway
    if raw_image is None:
        raw_image = rgb_image

    # seems a good bright / dark intensity value
    boundary = 90

    # find the angle of the edge - should already be lined up
    edge, result = detect_edge(rgb_image, boundary)
    if result != 'success':
        # detect_edge returns 0 instead of the edge parameters in this case
        raise ValueError("Cannot find the edge spread function: detect_edge reported '{}'".format(result))
    m, x1, y1 = edge
    y_inc = y1 - m * x1

    # make sure the edge is slanted enough that it's not just effectively vertical, limiting the use of the esf alignment.
    # This is an explicit check rather than an assert, so it still runs under "python -O".
    shift_per_block = raw_image.shape[0]/(m * blocks)
    if not np.abs(shift_per_block) > 1:
        print("Warning: {}".format("The edge is insufficiently slanted."))
        print("The edge only moves {} of a pixel over the image blocks".format(round(shift_per_block, 3)))

    h = raw_image.shape[0]
    block_height = h // blocks

    esfs = []
    tx_all = {}

    for i in range(blocks):

        # for each block, crop an equal sized region out around the edge
        height_slice = slice(i * block_height, (i+1) * block_height) # the section of the image we're analysing
        height_centre = (height_slice.start + height_slice.stop)/2.0
        width_centre = (height_centre - y_inc) / m
        width_slice = slice(int(width_centre - 300), int(width_centre + 300))

        edges = []

        tx_all[f'{i}'] = []

        for channel in range(3):
            #TODO: timings - each one takes 1-2 seconds

            # go through the cropped image in each channel, aligning every row on its transition
            aligned_rows = align_edges(raw_image[height_slice, width_slice, :], channel)

            if len(aligned_rows) != 0:

                sorted_x, sorted_I, txs = sorted_x_and_I(np.asarray(aligned_rows))
                tx_all[f'{i}'].append(txs[0][0])
                llr_x, llr_I = local_linear_regression(sorted_x, sorted_I, dx = 0.25, kernel_width=0.8)
                txs = np.full(llr_x.shape, np.mean(txs))

                edges.append((llr_x, llr_I, txs))

            else:
                # No usable rows: fill in zeros of the same length as the first channel.
                # NOTE(review): this raises IndexError if the first channel of a block has no rows.
                length = edges[0][0].shape[0]
                edges.append((np.zeros(length),np.zeros(length),np.zeros(length)))

        esfs.append(edges)

    esfs = np.array(esfs)

    #TODO: timing - overall this takes 45 seconds per image

    return esfs, m, y_inc, tx_all

def deduce_bayer_pattern(image, unit_cell=2):
    """Work out which positions of the Bayer unit cell hold data for each colour channel.

    image: a 3D numpy array corresponding to a colour image.
    unit_cell: either a 2-element tuple or a scalar (for square unit cells).

    Returns: 3D numpy array of booleans with the dimensions of the unit cell
        plus one plane per colour channel.  An element is True if the pixels
        at that position of the unit cell sum to more than zero in that channel.

    The image should have one plane per colour channel (any number of channels
    is ok), and only pixels with the appropriate colour filter should be
    nonzero.  I.e. you'd expect for a standard RGGB pattern that 3/4 of the
    red and blue pixels are zero, and 2/4 of the green pixels are zero; the
    result for such an image is 2x2x3.
    """
    cell = (unit_cell, unit_cell) if np.isscalar(unit_cell) else unit_cell
    n_channels = image.shape[2]
    pattern = np.zeros(tuple(cell) + (n_channels,), dtype=bool)

    w, h = cell
    for row_offset in range(w):
        for col_offset in range(h):
            # All pixels sharing this position within the unit cell
            sub_image = image[row_offset::w, col_offset::h, :]
            channel_totals = sub_image.sum(axis=0).sum(axis=0)
            pattern[row_offset, col_offset, :] = channel_totals > 0
    return pattern

def align_edges(raw_image, channel):
    '''Take a cropped area of an image around the edge, and a channel.

    For every row that contains pixels of this channel (in case of Bayer
    patterns), extract those pixels, locate the transition between the dark
    and the light level, and keep the samples around it with their x
    coordinates measured relative to the transition.

    The transition is found adaptively from the noise on either side of the
    edge; if that fails, fixed 20% / 80% thresholds are used instead. Rows for
    which neither works, or which cannot supply the full window, are left out.

    Returns:
        numpy array built from one (x - transition, intensity, tx) tuple per
        accepted row, where tx holds the transition position repeated for
        every sample
    '''

    #TODO: timing - allegedly around 1 second per block-channel

    # Boolean unit cell showing where this channel has data
    bayer = deduce_bayer_pattern(raw_image)[:, :, channel]
    aligned_rows = []
    for i in range(raw_image.shape[0]):
        # The row of the unit cell that image row i corresponds to
        row_bayer = bayer[i % bayer.shape[0], :]
        if np.any(row_bayer): # if there are active pixels of this channel in this row
            x = np.argmax(row_bayer) # find the position of the first relevant pixel in the row
            active_pixel_slice = slice(x, None, bayer.shape[1]) # This slice object takes every other pixel, starting at x (x==0 or 1)

            # Extract the position and intensity of all of the pixels that are active
            x = np.arange(raw_image.shape[1])[active_pixel_slice]
            I = raw_image[i, active_pixel_slice, channel]
            
            # Rescale the row to integers from 0 to 255
            normI = I - np.min(I)
            normI = (np.divide(normI, np.max(normI))*255).astype(int)

            # Remove outliers one sample at a time, rescaling before every pass,
            # until a pass finds nothing more to remove.
            # NOTE(review): if fewer than two differences remain, the for loop
            # below never runs and escape is never set - this looks like it
            # could loop forever; confirm.
            escape = False
            while escape == False:
                normI = normI - np.min(normI)
                normI = (np.divide(normI, np.max(normI))*255).astype(int)
                
                # Drop the first sample if it is more than 10 above the second
                if normI[0] - normI[1] > 10:
                    x = np.delete(x, 0)
                    normI = np.delete(normI, 0)
                # Drop the last sample if it differs from its neighbour by more than 30
                elif np.abs(normI[-2] - normI[-1]) > 30:
                    x = np.delete(x, -1)
                    normI = np.delete(normI, -1)
                else:
                    # Drop a single-sample spike: a jump of more than 25
                    # directly followed by a jump of more than 25 the other way
                    diffs = np.diff(normI)
                    for j in range(1, len(diffs)):
                        if abs(diffs[j-1]) > 25 and abs(diffs[j]) > 25 and np.diff(normI)[j-1] * np.diff(normI)[j] < 0:
                            x = np.delete(x, j)
                            normI = np.delete(normI, j)
                            break
                        # Reached the end without finding a spike - finished
                        if j == len(diffs)-1:
                            escape = True
            
            # Index of the first sample above 80% of the intensity range
            cross_8 = np.argmax(normI > np.min(normI) + 0.8 * (np.max(normI) - np.min(normI)))

            # One past the index of the last sample below 20% of the intensity range
            b = normI[::-1]
            cross_2 = len(b) - np.argmax(b < np.min(b) + 0.2 * (np.max(b) - np.min(b)))

            failure = 0

            try:
                # Adaptive thresholds: two standard deviations above the mean
                # of the samples before cross_2, and two standard deviations
                # below the mean of the samples from cross_8 onwards
                var_thresholds = [np.mean(normI[:cross_2]) + 2*np.std(normI[:cross_2]), np.mean(normI[cross_8:]) - 2*np.std(normI[cross_8:])]

                # The fitted range runs from one past the last sample below the
                # lower threshold to the first sample above the upper threshold
                b = normI[::-1]
                start = len(b) - np.argmax(b < var_thresholds[0])
                stop = np.argmax(normI > var_thresholds[1])

                x_t = x[start:stop]
                I_t = normI[start:stop]

                # Fit a straight line to the transition region
                slope, intercept, r, p, se = scipy.stats.linregress(x_t, I_t)

                # The transition is where the fitted line reaches half of the intensity range
                transition = ((min(normI) + 0.5 * (max(normI) - min(normI))) - intercept)/slope # 0.5 = intercept + gradient*xt

                # A flat or falling fit, or an undefined transition, is rejected
                if slope < 0.001 or np.isnan(transition) or np.isinf(transition):
                    print('problem with transition')
                    raise Exception
                width = 100

                # But half(-ish) of them are empty - so we need to crop to just the points we have
                start = np.argmax(x > transition - width/2.0)
                stop = start + width//row_bayer.shape[0] # Should do the same as above, but guarantees length.

                # The window must be complete, otherwise fall back to the fixed thresholds
                if len(x[start:stop]) != width//row_bayer.shape[0]:
                    print('problem')
                    raise Exception

            except:
                # Fallback: fit the samples between the 20% and 80% crossings instead
                x_t = x[cross_2:cross_8]
                I_t = normI[cross_2:cross_8]

                try:
                    slope, intercept, r, p, se = scipy.stats.linregress(x_t, I_t)
                    transition = ((min(normI) + 0.5 * (max(normI) - min(normI))) - intercept)/slope # 0.5 = intercept + gradient*xt
                    if slope < 0.001 or np.isnan(transition) or np.isinf(transition):
                        # plt.plot(x, normI, '.')
                        # plt.plot(x_t, I_t, '.')
                        # plt.plot(x_t, x_t*slope + intercept)
                        # plt.axhline(normI[cross_2])
                        # plt.axhline(normI[cross_8])
                        # plt.show()
                        raise Exception
                except:
                    # Neither method gave a usable transition for this row
                    failure = 1

            if failure == 1:
                # print(f"WARNING: fit failed, ignoring line {i}")
                pass
            else:
                # Now, crop out a region centered on the transition
                # This width is in pixels - we'll take the hundred pixels around the transition
                width = 100

                # But half(-ish) of them are empty - so we need to crop to just the points we have
                start = np.argmax(x > transition - width/2.0)
                stop = start + width//row_bayer.shape[0] # Should do the same as above, but guarantees length.
                
                # Only rows that supply the full window around the transition are kept
                if len(x[start:stop]) != width//row_bayer.shape[0]:
                    print(f"WARNING: fit failed, ignoring line {i}")
                elif np.isnan(transition):
                    print(f"WARNING: fit failed, ignoring line {i}")
                elif x[start] - transition < -width:
                    print("This line is off centre, ignoring")
                else:
                    # Store x relative to the transition, plus the transition itself for every sample
                    tx = np.full(x[start:stop].shape, transition)
                    aligned_rows.append((x[start:stop] - transition, normI[start:stop], tx))

    aligned_rows = np.array(aligned_rows)

    return aligned_rows

def sorted_x_and_I(aligned_rows):
    """Extract all x and I coordinates from a set of rows, in ascending x order
    
    Arguments: 
        aligned_rows: the output from extract_aligned_edge - an iterable of
            (x, I, tx) triples, one per row
        
    Returns:
        sorted_x, sorted_I, txs

    sorted_x and sorted_I are 1D numpy arrays of x and I, in strictly ascending x order.
    txs is a tuple holding the third element (tx) of every row, in the original row order
    (the mean of txs tells you where the edge was before alignment)

    Raises:
        ValueError: if aligned_rows is empty, as there is nothing to unpack
    """

    # First, extract the x and I coordinates into separate arrays and flatten them.
    xs, Is, txs = zip(*aligned_rows)
    all_x = np.array(xs).flatten()
    all_I = np.array(Is).flatten()

    # The de-duplication below nudges x by a fraction of a pixel in place; that
    # only works on a floating point array, so promote integer input first.
    if not np.issubdtype(all_x.dtype, np.floating):
        all_x = all_x.astype(float)

    # Points must be sorted in ascending order in x
    order = np.argsort(all_x)
    sorted_x = all_x[order]
    sorted_I = all_I[order]

    # If any points are the same, spline fitting fails - so add a little noise.
    # Each pass bumps the point after the smallest gap, until every gap is positive.
    while True:
        gaps = np.diff(sorted_x)
        if not np.any(gaps <= 0):
            break
        i = np.argmin(gaps)
        sorted_x[i + 1] += 0.0001  # 0.0001 is in units of pixels, i.e. insignificant.

    return sorted_x, sorted_I, txs

def psf_stats(x, I):
    """Return two summary numbers for a spread function sampled at x with values I

    Returns:
        [av, spread]

    av is the I-weighted mean of x.  spread is the sum over all points of
    sqrt(I * S**2), where S is the plain (unweighted) sum of (x - av).
    """
    weighted_mean = np.sum(x * I) / np.sum(I)

    # NOTE(review): S is squared *after* summing the deviations, so for
    # non-negative I this evaluates to sum(sqrt(I)) * |S| rather than an
    # intensity-weighted RMS width - confirm that this is what is wanted.
    summed_offset_sq = np.sum(x - weighted_mean) ** 2
    spread = np.sum(np.sqrt(I * summed_offset_sq))

    return [weighted_mean, spread]

def edge_image_fnames(folder):
    """Find all the images in a folder that look like edge images

    A file counts as an edge image when its name ends in ".jpg" and does not
    contain "raw".  Returns a list of paths (folder joined with each matching
    name), in the order os.listdir reports them.
    """
    return [
        os.path.join(folder, name)
        for name in os.listdir(folder)
        if name.endswith(".jpg") and "raw" not in name
    ]

if __name__ == '__main__':
    # Command-line entry point.
    #   * no arguments          -> print usage and exit with a non-zero status
    #   * exactly one file      -> analyse just that image
    #   * anything else         -> all files given are analysed together, and
    #                              every folder given is analysed on its own
    print('\n')
    if len(sys.argv) == 1:
        print("Usage: {} <file_or_folder> [<file2> ...]".format(sys.argv[0]))
        print("If a file is specified, we produce <file>_analysis.pdf")
        print("If a folder is specified, we produce a single PDF in that folder, analysing all its JPEG contents")
        print("Multiple files may be specified, using wildcards if your OS supports it - e.g. myfolder/calib*.jpg")
        print("In that case, one PDF (./edge_analysis.pdf) is generated for all; the files")
        print("if multiple folders are specified, or a mix of folders and files, each folder is handled separately as above")
        print("and all the files are processed together.")
        # NOTE(review): analyse_files writes spread_functions.yaml, so the
        # ".npz" mentioned below may be out of date - confirm before editing.
        print("In the case of multiple files or a folder, we will also save the extracted PSFs in edge_analysis.npz")
        # sys.exit rather than the interactive-only `exit` helper from `site`
        sys.exit(-1)
    if len(sys.argv) == 2 and os.path.isfile(sys.argv[1]):
        # A single file produces a pdf with a name specific to that file, so it's different.
        analyse_edge_image(sys.argv[1], 9, save_plot=True)
    else:
        fnames = []
        folders = []
        for fname in sys.argv[1:]:
            if os.path.isfile(fname):
                fnames.append(fname)
            elif os.path.isdir(fname):
                folders.append(fname)
            else:
                # Don't silently drop mistyped paths
                print("WARNING: {} is neither a file nor a folder, skipping".format(fname))
        if len(fnames) > 0:
            print("Analysing files...")
            # analyse_edge_image only accepts one file name; a list of files
            # goes through analyse_files, which writes ./edge_analysis.pdf
            analyse_files(fnames)
        for folder in folders:
            print("\nAnalysing folder: {}".format(folder))
            # analyse_files (not analyse_edge_image) is the one that takes output_dir
            analyse_files(edge_image_fnames(folder), output_dir=folder)