# -*- coding: utf-8 -*-
"""
Created on Wed Jan 25 18:12:27 2023

@author: Philippe
"""

import os
import scipy as sp
from scipy import optimize

#trackpy
# [1] D. B. Allan, T. Caswell, N. C. Keim, C. M. van der Wel, and R. W. Verweij, “soft-matter/trackpy: Trackpy v0.5.0,” Apr. 2021, doi: 10.5281/ZENODO.4682814.
import matplotlib.pyplot as plt
from matplotlib.ticker import (MultipleLocator, FormatStrFormatter, AutoMinorLocator)
import numpy as np
import pandas as pd
from pandas import DataFrame, Series  # for convenience
import trackpy as tp

#pims
import ffmpeg 
import av 
import pims


# %% S0 - Curve Definitions
def linear(x, a, b):
    """Evaluate the straight line ``a*x + b`` (slope ``a``, intercept ``b``)."""
    slope_term = a * x
    return slope_term + b

def parabola(x, a):
    """Evaluate the origin-centred quadratic ``a*x**2``."""
    x_squared = x ** 2
    return a * x_squared


# %% S1 - trackpy Velocity Definition
# See http://soft-matter.github.io/trackpy/v0.5.0/tutorial/walkthrough.html for reference
# [1] D. B. Allan, T. Caswell, N. C. Keim, C. M. van der Wel, and R. W. Verweij, “soft-matter/trackpy: Trackpy v0.5.0,” Apr. 2021, doi: 10.5281/ZENODO.4682814.

def velocity(path, file, mass, mpp, fps, x1, x2, y1, y2, s1, df, vf, last, speed=25, memory=15, minmass=50, threshold=15, r=0.99,
             annotate_frame=110):
    """Estimate particle drift velocities from an image sequence with trackpy.

    Frames matching ``path + file`` are opened with pims, cropped to
    ``[y1:y2, x1:x2]`` and reduced to the green channel. Features are located
    with ``tp.batch`` from frame ``s1`` on, linked into trajectories, filtered,
    and turned into per-particle displacement curves (square root of the
    individual MSD from ``tp.imsd``). A straight line is fitted to every
    curve; its slope is taken as that particle's velocity. Diagnostic figures
    (mass/size, annotated frame, trajectories, fitted curves) are drawn along
    the way.

    Parameters
    ----------
    path, file : str
        Concatenated and passed to ``pims.open`` (``file`` may be a glob such
        as ``'04V_13.377/*.png'``).
    mass : float
        Keep only features whose ``mass`` is strictly greater than this.
    mpp : float
        Microns per pixel, passed to ``tp.imsd``.
    fps : float
        Frames per second, passed to ``tp.imsd``.
    x1, x2, y1, y2 : int
        Crop window applied to every frame.
    s1 : int
        Index of the first frame used for feature detection.
    df : float
        Trajectories whose maximum MSD is below ``df**2`` are discarded.
    vf : float
        Minimum slope; fits with slope ``<= vf`` (or an undefined slope) are
        rejected.
    last : int
        Threshold passed to ``tp.filter_stubs`` (minimum trajectory length).
    speed : float, optional
        Search range passed to ``tp.link``.
    memory : int, optional
        ``memory`` argument of ``tp.link``.
    minmass, threshold : float, optional
        Passed to ``tp.batch``.
    r : float, optional
        Fits whose correlation coefficient is below ``r`` are rejected.
    annotate_frame : int, optional
        Frame number shown in the ``tp.annotate`` diagnostic figure.

    Returns
    -------
    stats : list
        ``[mean slope, RMS of the slope standard errors, max slope,
        standard error of the max slope, standard deviation of the slopes]``.
    im : pandas.DataFrame
        Displacement (sqrt of MSD) of every particle that survived the MSD
        filters, indexed by lag time.

    Raises
    ------
    ValueError
        If no trajectory passes all filters.
    """
    # Explicit import: on older SciPy versions ``import scipy`` alone does not
    # load the ``scipy.stats`` subpackage.
    from scipy import stats

    @pims.pipeline
    def gray(image):
        return image[y1:y2, x1:x2, 1]  # Take just the green channel

    frames = gray(pims.open(path + file))
    fb = tp.batch(frames[s1:], 13, invert=True, maxsize=10, minmass=minmass, threshold=threshold)
    tp.quiet()  # Turn off progress reports for the remaining trackpy calls
    t = tp.link(fb, speed, memory=memory)
    t1 = tp.filter_stubs(t, last)

    plt.figure()
    tp.mass_size(t1.groupby('particle').mean())  # plots size vs. mass

    # Keep bright, roughly round features of plausible size.
    t2 = t1[(t1['mass'] > mass) & (t1['size'] < 4.0) & (t1['size'] > 2.5) & (t1['ecc'] < 0.8)]
    plt.figure()
    tp.annotate(t2[t2['frame'] == annotate_frame], frames[annotate_frame])
    plt.figure()
    tp.plot_traj(t2)

    # Mean squared displacement of each individual particle.
    im = tp.imsd(t2, mpp, fps, max_lagtime=max(t2['frame']))

    # Collect rejected columns first and drop them in one go, instead of
    # mutating the DataFrame while iterating over its columns.
    rejected = []
    for col in im.columns:
        curve = im[col]
        if np.max(curve) < df**2:                  # Displacement too short
            print(col, np.max(curve))
            rejected.append(col)
        elif len(curve[curve == 0]) > 2:           # Too many 0 displacements
            print(col, len(curve[curve == 0]))
            rejected.append(col)
        elif len(curve[curve.notna()]) < 3:        # Too few valid points
            print(col)
            rejected.append(col)
    im = im.drop(columns=rejected)
    im = im.replace(0, np.nan)   # Treat occasional 0 displacement as missing
    im = np.sqrt(im)             # MSD -> displacement

    slopes = []
    slope_errs = []
    _, ax = plt.subplots(figsize=(6, 6), dpi=900)
    for col in im.columns:
        valid = im[col].notna()
        lag = im.index[valid]
        disp = im[col][valid]
        m, b, r_value, _, std = stats.linregress(lag, disp)
        if r_value < r:              # Poor linear fit
            print(m, r_value)
            continue
        if not m > vf:               # Too slow, or undefined (NaN) slope
            print(m, '<', vf, r_value)
            continue
        # Only accepted fits are counted AND plotted, so figure and statistics agree.
        slopes.append(m)
        slope_errs.append(std)
        ax.scatter(im.index, im[col], s=1, marker='*', color='k')
        ax.plot(lag, linear(lag, m, b), color='red', alpha=0.1)  # red lines, semitransparent
    ax.set(ylabel=r'$\Delta r$ [$\mu$m]', xlabel='lag time $t$')

    if not slopes:
        raise ValueError(
            f"velocity: no trajectory in {path + file!r} passed the filters "
            f"(r >= {r}, slope > {vf})")

    n = len(slopes)
    std_mean = np.std(slopes, axis=0)                          # spread of the slopes
    std_dev = np.sqrt(np.sum(np.array(slope_errs)**2) / n)     # RMS of the standard errors
    return [np.mean(slopes, axis=0), std_dev, np.max(slopes), slope_errs[np.argmax(slopes)], std_mean], im

# %% S2 - Velocity Extraction Voltage
path = "Cropped Movie Frames Voltage/"
#     ʌʌʌʌʌ Insert full path here ʌʌʌʌʌ 

#     vvvvv   Run line by line    vvvvv
#                     velocity(path, file,     mass,    mpp,    fps, x1, x2,  y1,  y2,  s1, df, vf, last)
V4, im4 = velocity(path, '04V_13.377/*.png',    800, 100/72, 13.377,  8, 48, 150, 550, 290,  1, 10, 15)
V5, im5 = velocity(path, '05V_13.377/*.png',   1500, 100/72, 13.377, 18, 63,  90, 530, 320, 50, 25, 15) #0-84
V6, im6 = velocity(path, '06V_13.366/*.png',    800, 100/72, 13.366, 25, 65, 150, 580, 100, 20, 35, 6) #0-88
V7, im7 = velocity(path, '07V_13.376/*.png',   2000, 100/72, 13.376, 15, 65,  80, 400,  90, 27, 40, 7) #0-74
V8, im8 = velocity(path, '08V_13.367/*.png',   1300, 100/72, 13.367, 15, 55, 310, 740,  65, 20, 40, 7, r=0.97,  minmass=200) #0-77
V9, im9 = velocity(path, '09V_13.361/*.png',   1000, 100/72, 13.361, 18, 55, 110,  500,  0,  0, 80, 7, r=0.96, speed=35, threshold=12, memory=5) #0-77
V10, im10 = velocity(path, '10V_13.378/*.png',  200, 100/72, 13.378, 10, 55, 400,  -1,   0,  0, 50, 15,  r=0.95, speed=30, threshold=12, memory=5) #0-70

# %% S3 Velocity Plot

# One row per voltage; columns follow the ``velocity`` summary layout:
# [mean slope, RMS slope std-error, max slope, std-error of max, std of slopes].
VY = np.vstack((V4, V5, V6, V7, V8, V9, V10))
VX = np.array([4, 5, 6, 7, 8, 9, 10])            # applied voltage [V]
VXP = np.array([3, 4, 5, 6, 7, 8, 9, 10, 11])   # x values used to draw the fit

# Tick-label sizes are global rc settings; set them before the figure is
# created so that its axes are guaranteed to use them.
plt.rc('xtick', labelsize=24)
plt.rc('ytick', labelsize=24)
fig, ax0 = plt.subplots(num=None, figsize=(6, 7), facecolor='w', edgecolor='k', dpi=900)
for axis in ['top', 'bottom', 'left', 'right']:
    ax0.spines[axis].set_linewidth(2)
ax0.tick_params(axis='both', width=2, which='major', length=10)
ax0.tick_params(axis='both', width=1, which='minor', length=5)
ax0.xaxis.set_major_locator(MultipleLocator(1))
ax0.yaxis.set_major_locator(MultipleLocator(25))
ax0.yaxis.set_minor_locator(MultipleLocator(5))

# Mean velocity per voltage; error bars show the spread (std) of the slopes.
ax0.errorbar(VX, VY[:, 0], yerr=VY[:, 4], fmt='o', color='black', barsabove=True,
             capsize=10, ms=5, mfc='b', mec='k', label='Average Velocity', zorder=2)

# Fit v = a * U**2, weighting each point by its RMS slope standard error.
# p0 is a one-element sequence: parabola has a single free parameter.
popt, pcov = optimize.curve_fit(parabola, VX, VY[:, 0], p0=[1.0], sigma=VY[:, 1], method='lm')
ax0.plot(VXP, parabola(VXP, *popt), linestyle='--', linewidth=1.5, color='red', label='Parabolic Fit', zorder=1)
ax0.scatter(VX, VY[:, 0], s=40, marker='o', color='blue', edgecolors='black', zorder=3)
print(popt)

# NOTE(review): the label is placed at fixed data coordinates (6.7, 52) and
# shows the unit as V^2, although the fitted coefficient has units of
# um/s/V^2 given the axis labels -- confirm.
ax0.text(6.7, 52, r'c$_p$ = %1.2F$\,$V$^2$' %(popt[0]), color='black', fontsize='x-large')
ax0.grid(True)
ax0.set_xlim(3.5, 10.5)
ax0.set_xlabel("Voltage [V]", fontsize=28)
ax0.set_ylabel("Velocity c$_p$ [µm/s]", fontsize=28)
ax0.legend(loc="best", fontsize='xx-large')
