clear
clc

% ---- User input ----
sub_code = 2931; % Test subject code (selects the Mat/, Json/ and Final/AB<code> folders)

%% Curve and metric generation settings
% Normative reference curves; this file is expected to define the Norm struct
% that the plotting section reads
load("Normative Dataset/normative.mat");
% Norm field for each plotted quantity, in the same order as the Kin list of
% the processed-data section
Values = ["HipFlex", "HipAdd", "KneeFlex", "AnkDors", "AnkInv", ...
    "HipFlexM", "HipAddM", "KneeFlexM", "AnkDorsM"];

% EMG envelope filters (applied to the raw EMG of each collected gait cycle)
fs = 1259;              % EMG sampling frequency in Hz assumed by these designs
lowCutoff = 20;         % Band-pass lower edge (Hz)
highCutoff = 450;       % Band-pass upper edge (Hz)
cutoffFrequency = 5;    % Envelope low-pass cutoff (Hz)
nyquist = fs/2;
[b, a] = butter(4, [lowCutoff, highCutoff]/nyquist, 'bandpass'); % 4th-order Butterworth band-pass
[d, c] = butter(4, cutoffFrequency/nyquist, 'low');              % 4th-order Butterworth low-pass

%% Raw postprocessing
% Convert every Qualisys .mat export of this subject into raw, processed and
% windowed-feature CSV tables.
subRoot = "Final/AB" + num2str(sub_code);
% writetable does not create folders, so create any missing output folder first
for outSub = ["Raw Data", "Processed Data", "Features"]
    if ~isfolder(subRoot + "/" + outSub)
        mkdir(subRoot + "/" + outSub);
    end
end

matFiles = dir(fullfile("Mat/"+num2str(sub_code), '*.mat')); matFiles = string({matFiles.name});
for file_idx = 1:length(matFiles)
    % Load a trial file into a struct (so its variables cannot overwrite the
    % script's own variables) and pick out the exported data structure
    fileName = matFiles(file_idx);
    trialData = load("Mat/"+num2str(sub_code)+"/"+fileName);
    trial_num = regexp(fileName, '\d+', 'match'); % Trial number
    % NOTE(review): assumes the export variable is named Gait_LB___IOR_<digits of the file name>
    dataStruct = trialData.("Gait_LB___IOR_"+trial_num);
    trialTag = "AB" + num2str(sub_code) + "_Trial_" + num2str(trial_num,'%03d'); % Output file prefix

    % Make a table for raw analog data (IMU + EMG) and convert units
    T_analog = array2table([(0:1:(size(dataStruct.Analog.Data.',1)-1)).'*(1/(dataStruct.Analog.Frequency)), ...
        dataStruct.Analog.Data.'], 'VariableNames', ['Time' ,dataStruct.Analog.Labels]);
    varNames = T_analog.Properties.VariableNames;
    accCols = contains(varNames, 'ACC'); % Accelerometer columns
    T_analog{:, accCols} = T_analog{:, accCols} / 9.81; % accelerometer unit: g.
    gyroCols = contains(varNames, 'GYRO'); % Gyro columns
    T_analog{:, gyroCols} = T_analog{:, gyroCols} * (pi/180); % gyro unit: rad/sec.
    emgCols = contains(varNames, 'EMG'); % EMG columns
    T_analog{:, emgCols} = T_analog{:, emgCols} * (1/1000.0); % EMG unit: millivolts

    % Initial mode and phase labels (phase is overwritten later from Visual3D events)
    T_analog.Mode = ones(height(T_analog),1);
    T_analog.phase = zeros(height(T_analog),1);

    % Processed IMU and EMG data (same time stamps as the raw analog table).
    % Every filter is designed in zero-pole-gain form and applied as
    % second-order sections: in [b,a] form these 6th-order designs (12th order
    % for the band-stops) are numerically ill-conditioned at such low
    % normalised frequencies.
    nyqA = round(dataStruct.Analog.Frequency)/2;
    T_analog_processed = table;
    T_analog_processed.Time = T_analog.Time;
    analogHeaders = T_analog.Properties.VariableNames; % Analog signal headers

    % IMU: 6th-order low-pass Butterworth, 20 Hz cutoff, zero-phase
    [zf, pf, kf] = butter(6, 20/nyqA, 'low');
    [sosImu, gImu] = zp2sos(zf, pf, kf);
    for i = 1:numel(analogHeaders)
        head = analogHeaders{i};
        if contains(head,"ACC") || contains(head,"GYRO")
            T_analog_processed.(head) = filtfilt(sosImu, gImu, T_analog.(head));
        end
    end

    % EMG: 6th-order Butterworth stages applied in this order, each zero-phase:
    % 20 Hz high-pass, 350 Hz low-pass, then +/-3 Hz band-stops at 60, 180 and 300 Hz
    emgDesign = {20, 'high'; 350, 'low'; 60+[-3 3], 'stop'; 180+[-3 3], 'stop'; 300+[-3 3], 'stop'};
    emgSos = cell(size(emgDesign,1), 1);
    emgGain = zeros(size(emgDesign,1), 1);
    for s = 1:size(emgDesign,1)
        [zf, pf, kf] = butter(6, emgDesign{s,1}/nyqA, emgDesign{s,2});
        [emgSos{s}, emgGain(s)] = zp2sos(zf, pf, kf);
    end
    for i = 1:numel(analogHeaders)
        head = analogHeaders{i};
        if contains(head,"EMG")
            x = T_analog.(head); % always start from the raw signal
            for s = 1:numel(emgSos)
                x = filtfilt(emgSos{s}, emgGain(s), x);
            end
            T_analog_processed.(head) = x;
        end
    end
    T_analog_processed.Mode = T_analog.Mode; T_analog_processed.phase = T_analog.phase;

    % Windowed feature tables: one per start delay (0 up to one full window,
    % in steps of a third of a window); each is written to its own CSV
    ms_delay = 50; % Window length in ms
    win_sz = ceil(ms_delay*(10^-3)*dataStruct.Analog.Frequency); % 50 ms window size in samples
    delay_sz = round(win_sz/3); % delay step is 1/3 of window size
    imuFeats = ["mean", "median", "std_dev", "initial", "final", "max", "min"];
    emgFeats = ["MAV", "WL", "ZC", "SS"];
    for delay = 0 : delay_sz : win_sz
        T_feature = table; % Feature table associated with the current delay
        for j = 1:numel(analogHeaders)
            head = analogHeaders{j};
            if contains(head,"ACC") || contains(head,"GYRO")
                x = T_analog.(head)(delay+1:end); % IMU features use the raw signal
                for feat = imuFeats
                    T_feature.(string(head) + " " + feat) = feat_extract(x, feat, win_sz);
                end
            elseif contains(head,"EMG")
                x = T_analog_processed.(head)(delay+1:end); % EMG features use the filtered signal
                for feat = emgFeats
                    T_feature.(string(head) + " " + feat) = feat_extract(x, feat, win_sz);
                end
                AR_coeff = AR_extract(x, win_sz);
                for k = 1 : 6
                    T_feature.(string(head) + " AR coeff" + num2str(k)) = AR_coeff(:,k);
                end
            end
        end
        T_feature.Mode = ones(height(T_feature),1); T_feature.phase = zeros(height(T_feature),1);
        % Save inside the loop; writing after it kept only the last delay's table
        filename = subRoot + "/Features/" + trialTag + "_feat_" + num2str(round((delay/dataStruct.Analog.Frequency)*1000)) + "ms.csv";
        writetable(T_feature, filename);
    end

    % Force plate data extraction (We have 4 force plates in the lab):
    % per plate, force, moment and centre-of-pressure vectors (9 columns)
    num_plates = 4;
    plateQty = ["Fx", "Fy", "Fz", "Mx", "My", "Mz", "COPx", "COPy", "COPz"];
    plate_array = zeros(dataStruct.Force(1).NrOfSamples, num_plates*9);
    plate_header = strings(1,1+num_plates*9); plate_header(1) = "Time";
    for i = 1 : num_plates
        plate_array(:,(i-1)*9+1:i*9) = [dataStruct.Force(i).Force.', dataStruct.Force(i).Moment.', dataStruct.Force(i).COP.'];
        plate_header((i-1)*9+2:i*9+1) = "plate " + num2str(i) + " " + plateQty;
    end
    T_force = array2table([(0:1:(size(plate_array,1)-1)).'*(1/(dataStruct.Force(1).Frequency)), plate_array], ...
        'VariableNames', plate_header);

    % Marker location data extraction: the first 20 labelled markers, each as
    % x, y, z and residual columns
    num_samples = size(dataStruct.Trajectories.Labeled.Type, 2);
    coordSuffix = [" x", " y", " z", " residual"];
    marker_loc = zeros(num_samples,4*20); marker_label = strings(1,4*20);
    for i = 1 : 20
        for j = 1 : 4
            marker_loc(:,(i-1)*4+j) = reshape(dataStruct.Trajectories.Labeled.Data(i,j,:), [], 1);
            marker_label((i-1)*4+j) = dataStruct.Trajectories.Labeled.Labels(i)+coordSuffix(j);
        end
    end
    T_marker = array2table(marker_loc, 'VariableNames', marker_label);
    T_force_plus_marker = [T_force, T_marker]; % Force and marker data share the same frame rate

    % Save tables inside CSV files
    writetable(T_analog, subRoot + "/Raw Data/" + trialTag + "_Analog_raw.csv");
    writetable(T_force_plus_marker, subRoot + "/Raw Data/" + trialTag + "_Force_Motion_raw.csv");
    writetable(T_analog_processed, subRoot + "/Processed Data/" + trialTag + "_Analog_processed.csv");
end

%% Processed postprocessing
% Consistency data records for the subject
EMG = struct(); % EMG envelopes: one field per muscle, one row per collected cycle
KIN = struct(); % Kinematic/kinetic curves: one field per side+quantity, one row per cycle
kin_headers = ["Hip Angles X", "Hip Angles Y", "Knee Angles X", "Ankle Angles X", "Ankle Angles Y", ...
    "Hip Moment X", "Hip Moment Y", "Knee Moment X", "Ankle Moment X"]; % Kinematic/kinetic headers of interest
Kin = ["Hip_Flexion", "Hip_Adduction", "Knee_Flexion", "Ankle_Dorsiflexion", "Ankle_Inversion", ...
    "Internal_Hip_Extensor_Moment", "Int_Hip_Valgus_Moment", "Int_Knee_Extensor_Moment", "Int_Ankle_Plantarflexor_Moment"];

% Read back CSVs written earlier without readtable's default renaming of
% headers to valid identifiers (which would be saved back into the files)
readKeep = @(f) readtable(f, 'VariableNamingRule', 'preserve');
subRoot = "Final/AB" + num2str(sub_code);

% Load the original .json file for processed data exported from Visual3D
jsonFiles = dir(fullfile("Json/"+num2str(sub_code), '*.json')); jsonFiles = string({jsonFiles.name});
for file_idx = 1:length(jsonFiles)
    % Load a trial file
    fileName = jsonFiles(file_idx);
    trial_num = regexp(fileName, '\d+', 'match'); % Trial number
    trialTag = "AB" + num2str(sub_code) + "_Trial_" + num2str(trial_num,'%03d'); % Output file prefix
    data = jsondecode(fileread("Json/"+num2str(sub_code)+"/"+fileName));

    % Examine Visual3D output quantities one by one
    T_body_processed = table;
    % Force/marker time stamps (same refresh rate as the Visual3D output).
    % NOTE(review): dataStruct is left over from the last raw trial, so this
    % assumes every trial used the same force sampling rate — confirm.
    T_body_processed.Time = ((1:str2double(string(data.Visual3D(1).frames))).'-1)*...
        (1/dataStruct.Force(1).Frequency);
    quantityKeys = ["Ang_Acc", "Ang_Vel", "Moment", "Power", "COP_rt_", "CG_rt_LFT", "CG_rt_RFT"];
    for i = 1 : size(data.Visual3D,1)
        % Quantities of interest in the LINK_MODEL_BASED property (plain
        % "...Angles" names only, not "Angles_..." variants)
        if data.Visual3D(i).type == "LINK_MODEL_BASED"
            qName = string(data.Visual3D(i).name);
            if contains(qName, quantityKeys) || (contains(qName, "Angles") && ~contains(qName, "Angles_"))
                for comp = 1:3 % one column per vector component
                    sig = data.Visual3D(i).signal(comp);
                    T_body_processed.(qName + " " + string(sig.component)) = sig.data;
                end
            end
        end
    end

    % Initialize mode and phase annotations using arbitrary values
    T_body_processed.Mode = ones(height(T_body_processed),1);
    T_body_processed.phase = zeros(height(T_body_processed),1);

    % Add gait event labels from Visual3D (event times, NaN-padded to table height)
    ignoredEvents = ["start", "end", "RANGESTART", "temp1", "temp2"];
    for i = 1 : size(data.Visual3D,1)
        if data.Visual3D(i).type == "EVENT_LABEL" && ~any(string(data.Visual3D(i).name) == ignoredEvents)
            evTimes = data.Visual3D(i).signal.data;
            column_padded = [evTimes; NaN(height(T_body_processed) - length(evTimes), 1)];
            T_body_processed.(string(data.Visual3D(i).name) + " " + string(data.Visual3D(i).signal.component)) = ...
                column_padded;
        end
    end

    % Gait phases from Visual3D events: 1 <- RHS, 2 <- LTO, 3 <- RTS, 4 <- LHS, 5 <- RTO
    LR = T_body_processed.("RHS X"); LR = LR(~isnan(LR)).';
    MST = T_body_processed.("LTO X"); MST = MST(~isnan(MST)).';
    TS = T_body_processed.("RTS X"); TS = TS(~isnan(TS)).';
    PSW = T_body_processed.("LHS X"); PSW = PSW(~isnan(PSW)).';
    ISW = T_body_processed.("RTO X"); ISW = ISW(~isnan(ISW)).';
    Trigs = [0, ones(size(LR)), 2*ones(size(MST)), 3*ones(size(TS)), 4*ones(size(PSW)), 5*ones(size(ISW)), 0;
        0, LR, MST, TS, PSW, ISW, T_body_processed.Time(end)]; % Row 1: phase label, row 2: time (+ terminal points)
    [~, sorted_indices] = sort(Trigs(2,:)); % Sort the triggers temporally
    Trigs = Trigs(:, sorted_indices);
    % Before the first event the subject is in the phase preceding it (cyclic 1..5)
    if Trigs(1,2)==1
        Trigs(1,1) = 5;
    else
        Trigs(1,1) = Trigs(1,2)-1;
    end
    % Exactly one frame index per trigger (nearest time stamp). An exact-match
    % search drops any event that falls between frames, shifting every later label.
    [~, Trig_idx] = min(abs(T_body_processed.Time - Trigs(2,:)), [], 1);
    Trig_idx(end) = Trig_idx(end)+1; % let the last segment run through the final frame
    for loc = 1:length(Trig_idx)-1
        T_body_processed.phase(Trig_idx(loc):Trig_idx(loc+1)-1) = Trigs(1,loc);
    end

    % Save tables into csv files
    writetable(T_body_processed, subRoot + "/Processed Data/" + trialTag + "_Body_Motion_processed.csv");

    % Propagate the phases to the tables written by the raw section
    filename = subRoot + "/Raw Data/" + trialTag + "_Analog_raw.csv";
    T_analog = readKeep(filename);
    T_analog.phase = rescale(T_body_processed.phase, height(T_analog));
    writetable(T_analog, filename);
    filename = subRoot + "/Raw Data/" + trialTag + "_Force_Motion_raw.csv";
    T_force_plus_marker = readKeep(filename);
    T_force_plus_marker.phase = T_body_processed.phase; % same frame rate, no rescaling needed
    writetable(T_force_plus_marker, filename);
    filename = subRoot + "/Processed Data/" + trialTag + "_Analog_processed.csv";
    T_analog_processed = readKeep(filename);
    T_analog_processed.phase = rescale(T_body_processed.phase, height(T_analog_processed));
    writetable(T_analog_processed, filename);
    % Every feature file of this trial (one per delay), not just one of them
    featFiles = dir(char(subRoot + "/Features/" + trialTag + "_feat_*ms.csv"));
    for f = 1:numel(featFiles)
        filename = fullfile(featFiles(f).folder, featFiles(f).name);
        T_feature = readKeep(filename);
        T_feature.phase = rescale(T_body_processed.phase, height(T_feature));
        writetable(T_feature, filename);
    end

    % Collect EMG envelope data from the first gait cycle of each side
    Rrange = find([0; diff(T_analog.phase)] & T_analog.phase ==1); % Right cycle starts (phase 1 onsets)
    Lrange = find([0; diff(T_analog.phase)] & T_analog.phase ==4); % Left cycle starts (phase 4 onsets)
    emgNames = T_analog.Properties.VariableNames(contains(T_analog.Properties.VariableNames, 'EMG'));
    % Struct field names must be valid identifiers; the preserved CSV headers need not be
    emgHeaders = matlab.lang.makeValidName(emgNames);
    for m = 1:numel(emgNames)
        if emgHeaders{m}(1) == 'R'
            cycle_range = Rrange;
        else
            cycle_range = Lrange;
        end
        if length(cycle_range) > 1
            signal = T_analog.(emgNames{m})(cycle_range(1):cycle_range(2)).';
            filteredEMG = filtfilt(b, a, signal); % Bandpass filter
            rectifiedEMG = abs(filteredEMG); % Rectification
            envelopeEMG = filtfilt(d, c, rectifiedEMG); % EMG envelope
            envelopeEMG = interp1(((0:length(envelopeEMG)-1)/(length(envelopeEMG)-1))*100, ...
                envelopeEMG, 0:0.5:100, 'linear'); % Temporal normalization to 0..100 %
            if ~isfield(EMG, emgHeaders{m})
                EMG.(emgHeaders{m}) = envelopeEMG;
            else
                EMG.(emgHeaders{m}) = [EMG.(emgHeaders{m}); envelopeEMG];
            end
        end
    end

    % Collect kinematic/kinetic curves: first right cycle (phase 1 to phase 1)
    % and first left cycle (phase 4 to phase 4)
    sideNames = ["Right", "Left"];
    sideStartPhase = [1, 4];
    for s = 1:2
        starts = find([0; diff(T_body_processed.phase)] & T_body_processed.phase == sideStartPhase(s));
        if length(starts) > 1
            kin_cut = T_body_processed(starts(1):starts(2),:);
            for h = 1:length(kin_headers)
                signal = kin_cut.(sideNames(s) + " " + kin_headers(h)).';
                signal = interp1(((0:length(signal)-1)/(length(signal)-1))*100, ...
                    signal, 0:0.5:100, 'linear'); % Temporal normalization to 0..100 %
                anatomy_field = sideNames(s) + "_" + Kin(h); % Field translation
                if ~isfield(KIN, anatomy_field)
                    KIN.(anatomy_field) = signal;
                else
                    KIN.(anatomy_field) = [KIN.(anatomy_field); signal];
                end
            end
        end
    end
end

% Plot and save consistency data
gaitPct = 0:0.5:100; % Normalised gait-cycle axis shared by all curves
plotDir = "Final/AB"+num2str(sub_code)+"/Gait Analysis";
if ~isfolder(plotDir)
    mkdir(plotDir); % saveas does not create missing folders
end

% Kinematic/kinetic: for each quantity, one figure with every collected cycle
% ("consistency") and one with the per-side mean ("mean"), both drawn over the
% normative mean +/- 1 SD band
sides = ["Left", "Right"];
sideColors = ['b', 'r'];
for head = 1:length(Kin)
    meanData = mean(Norm.(Values(head)), 1);
    stdData = std(Norm.(Values(head)), 0, 1);
    if endsWith(Kin(head), "Moment")
        unit = "Nm/kg";
    else
        unit = "degrees";
    end
    for plotKind = ["consistency", "mean"]
        fig = figure;
        fill([gaitPct fliplr(gaitPct)], [meanData + stdData fliplr(meanData - stdData)], 'b', 'FaceAlpha', 0.2, 'EdgeColor', 'none');
        hold on
        for s = 1:2
            field = sides(s) + "_" + Kin(head);
            if ~isfield(KIN, field)
                % No cycle was collected for this side in any trial
                warning('No %s gait cycle collected for %s; curve skipped.', lower(sides(s)), Kin(head));
                continue
            end
            curves = KIN.(field);
            if plotKind == "mean"
                curves = mean(curves, 1);
            end
            plot(gaitPct, curves, sideColors(s), 'LineWidth', 1.5);
        end
        set(fig, 'Color', 'w');
        xlabel('Gait %', 'FontWeight', 'bold', 'FontSize', 24);
        ylabel("("+unit+")", 'FontWeight', 'bold', 'FontSize', 24);
        ax = gca;
        ax.XAxis.FontSize = 24;
        ax.YAxis.FontSize = 24;
        ax.XAxis.FontWeight = 'bold';
        ax.YAxis.FontWeight = 'bold';
        titleText = strrep(Kin(head), '_', ' ');
        if plotKind == "mean"
            titleText = titleText + " mean";
        end
        title(titleText, 'FontSize', 24);
        % Invisible proxy lines give one legend entry per side
        h1 = plot(nan, nan, 'r', 'LineWidth', 1);
        h2 = plot(nan, nan, 'b', 'LineWidth', 1);
        legend([h1, h2], {'Right', 'Left'}, 'FontWeight', 'bold', 'FontSize', 24, 'Location', 'best');
        saveas(fig, plotDir + "/" + Kin(head) + " " + plotKind + ".png");
        close(fig);
    end
end

% EMG: every collected envelope of a muscle (red = right, blue = left) plus their mean in green
for head = 1:length(emgHeaders)
    muscle = emgHeaders{head};
    if ~isfield(EMG, muscle)
        warning('No gait cycle collected for %s; plot skipped.', muscle);
        continue
    end
    if muscle(1) == 'R'
        color = 'r';
    else
        color = 'b';
    end
    fig = figure;
    plot(gaitPct, EMG.(muscle), color, 'LineWidth', 1.5);
    hold on
    plot(gaitPct, mean(EMG.(muscle),1), 'g', 'LineWidth', 1.5);
    set(fig, 'Color', 'w');
    xlabel('Gait %', 'FontWeight', 'bold', 'FontSize', 24);
    unit = "millivolts";
    ylabel(" (" + unit + ")", 'FontWeight', 'bold', 'FontSize', 24);
    ax = gca;
    ax.XAxis.FontSize = 24;
    ax.YAxis.FontSize = 24;
    ax.XAxis.FontWeight = 'bold';
    ax.YAxis.FontWeight = 'bold';
    title(strrep(string(muscle), '_', ' '),'FontSize', 24);
    h1 = plot(nan, nan, color, 'LineWidth', 1);
    h2 = plot(nan, nan, 'g', 'LineWidth', 1);
    legend([h1, h2], {'All', 'Mean'}, 'FontWeight', 'bold', 'FontSize', 18, 'Location', 'best');
    saveas(fig, plotDir + "/" + string(muscle) + ".png");
    close(fig);
end

close all

%% Supporting functions

function feat_vector = feat_extract(col, feat, w)
% FEAT_EXTRACT Compute one feature per non-overlapping window of a signal.
%   feat_vector = FEAT_EXTRACT(col, feat, w) divides the vector COL into
%   consecutive, non-overlapping windows of W samples (a trailing partial
%   window is discarded) and returns a column vector with one value of the
%   feature FEAT per window.
%
%   Supported FEAT values:
%     "mean", "median", "std_dev", "max", "min"  window statistics
%     "initial", "phase"  first sample of the window
%     "final"             last sample of the window
%     "MAV"  mean absolute value
%     "WL"   waveform length (sum of absolute successive differences)
%     "ZC"   number of transitions of (col > 0) inside the window
%     "SS"   number of slope sign changes inside the window
%   An unsupported FEAT raises an error instead of silently returning zeros.

numWin = floor(numel(col)/w);
% One window per column: X(:, k) holds samples (k-1)*w+1 ... k*w
X = reshape(col(1:numWin*w), w, numWin);
switch feat
    case "mean"
        vals = mean(X, 1);
    case "median"
        vals = median(X, 1);
    case "std_dev"
        vals = std(X, 0, 1);
    case {"initial", "phase"}
        vals = X(1, :);
    case "final"
        vals = X(end, :);
    case "max"
        vals = max(X, [], 1);
    case "min"
        vals = min(X, [], 1);
    case "MAV"
        vals = mean(abs(X), 1);
    case "WL"
        vals = sum(abs(diff(X, 1, 1)), 1);
    case "ZC"
        vals = sum(diff(X > 0, 1, 1) ~= 0, 1);
    case "SS"
        vals = sum(diff(sign(diff(X, 1, 1)), 1, 1) ~= 0, 1);
    otherwise
        error('feat_extract:unknownFeature', 'Unknown feature "%s".', feat);
end
feat_vector = reshape(double(vals), [], 1);
end

function AR_matrix = AR_extract(col, w, order)
% AR_EXTRACT Autoregressive coefficients for each non-overlapping window.
%   AR_matrix = AR_EXTRACT(col, w) divides the vector COL into consecutive,
%   non-overlapping windows of W samples (a trailing partial window is
%   discarded), fits a 6th-order AR model to each with aryule (Yule-Walker)
%   and returns one row of coefficients [a1 ... a6] per window.
%   AR_EXTRACT(col, w, order) uses the given model order instead of 6.
%
%   aryule returns [1, a1, ..., ap]; the leading 1 is the fixed normalised
%   a0 term, so it is skipped and column j holds a_j.
if nargin < 3
    order = 6;
end
numWin = floor(numel(col)/w);
AR_matrix = zeros(numWin, order);
for k = 1:numWin
    coeff = aryule(col((k-1)*w+1 : k*w), order);
    AR_matrix(k, :) = coeff(2:end);
end
end

function new_arr = rescale(arr, new_size)
% RESCALE Resample a piecewise-constant label vector to a new length.
%   new_arr = RESCALE(arr, new_size) returns a new_size-by-1 column where
%   every run of equal values in ARR keeps its value and is stretched by the
%   factor new_size/numel(arr). It is used to carry gait-phase labels
%   between tables of different lengths.
%
%   A run starting at old sample k begins at relative time (k-1)/numel(arr),
%   so it starts at new index round((k-1)*ratio)+1 (always >= 1).
%
%   NOTE: this local function shadows MATLAB's built-in RESCALE within
%   this file.
new_arr = zeros(new_size, 1);
arr = arr(:); % accept row or column input
n = numel(arr);
if n == 0 || new_size == 0
    return
end
run_start_old = [1; find(diff(arr)) + 1]; % first sample of each run
run_start_new = [round((run_start_old - 1) * (new_size / n)) + 1; new_size + 1];
for r = 1:numel(run_start_old)
    new_arr(run_start_new(r):run_start_new(r+1)-1) = arr(run_start_old(r));
end
end
