import numpy as np #version 1.26.3
from scipy.signal import bessel, filtfilt #version 1.11.4
import pandas as pd #version 3.8.2
#using python 3.11

def main():
    #number of events to generate
    numevents = 500 

    #basic event parameters
    samplerate = 1/0.24e-6 #close to chimera samplerate
    cutoff = 1000000.0 #Hz
    poles = 4
    Wn = 2 * cutoff / samplerate
    b, a = bessel(poles, Wn, norm='mag') #define bessel filter to more closely approximate nanopore noise
    
    baseline_mean = 1000 #pA
    baseline_std = 50 #pA
    rc_mean = 1 #microseconds


    blockage_mean = [6 * baseline_std] #pA
    duration_mean = np.arange(2,11,1,dtype=np.float64) #microseconds

    print('Generating simple spikes...')
    num = 0
    for blockage in blockage_mean:
        for duration in duration_mean:
            for i in range(numevents):
                event = generate_simple_spike(baseline_mean, baseline_std, rc_mean, blockage, duration, b, a, samplerate)
                time = np.arange(0, len(event)/samplerate*1e6, 1e6/samplerate)
                df = pd.DataFrame({'Time': time, 'Current': event})
                df.to_csv('Simple Spike/event_{0}_{1}_{2}.csv'.format(num,blockage,duration), index=False)
                num += 1

    print('Generating end spikes')
    num = 0
    for blockage in blockage_mean:
        for duration in duration_mean:
            for i in range(numevents):
                event = generate_spike(baseline_mean, baseline_std, rc_mean, blockage, duration, b, a, samplerate, loc='end')
                time = np.arange(0, len(event)/samplerate*1e6, 1e6/samplerate)
                df = pd.DataFrame({'Time': time, 'Current': event})
                df.to_csv('End Spike/event_{0}_{1}_{2}.csv'.format(num,blockage,duration), index=False)
                num += 1

    print('Generating mid spikes')
    num = 0
    for blockage in blockage_mean:
        for duration in duration_mean:
            for i in range(numevents):
                event = generate_spike(baseline_mean, baseline_std, rc_mean, blockage, duration, b, a, samplerate, loc='middle')
                time = np.arange(0, len(event)/samplerate*1e6, 1e6/samplerate)
                df = pd.DataFrame({'Time': time, 'Current': event})
                df.to_csv('Mid Spike/event_{0}_{1}_{2}.csv'.format(num,blockage,duration), index=False)
                num += 1

    print('Generating start spikes')
    num = 0
    for blockage in blockage_mean:
        for duration in duration_mean:
            for i in range(numevents):
                event = generate_spike(baseline_mean, baseline_std, rc_mean, blockage, duration, b, a, samplerate, loc='start')
                time = np.arange(0, len(event)/samplerate*1e6, 1e6/samplerate)
                df = pd.DataFrame({'Time': time, 'Current': event})
                df.to_csv('Start Spike/event_{0}_{1}_{2}.csv'.format(num,blockage,duration), index=False)
                num += 1

    print('Generating rising edges')
    num = 0
    for blockage in blockage_mean:
        for duration in duration_mean:
            for i in range(numevents):
                event = generate_sloped_sublevel(baseline_mean, baseline_std, rc_mean, blockage, duration, b, a, samplerate, loc='end')
                time = np.arange(0, len(event)/samplerate*1e6, 1e6/samplerate)
                df = pd.DataFrame({'Time': time, 'Current': event})
                df.to_csv('Rising Edge/event_{0}_{1}_{2}.csv'.format(num,blockage,duration), index=False)
                num += 1

    print('Generating falling edges')
    num = 0
    for blockage in blockage_mean:
        for duration in duration_mean:
            for i in range(numevents):
                event = generate_sloped_sublevel(baseline_mean, baseline_std, rc_mean, blockage, duration, b, a, samplerate, loc='start')
                time = np.arange(0, len(event)/samplerate*1e6, 1e6/samplerate)
                df = pd.DataFrame({'Time': time, 'Current': event})
                df.to_csv('Falling Edge/event_{0}_{1}_{2}.csv'.format(num,blockage,duration), index=False)
                num += 1

    print('Generating double spikes')
    num = 0
    for blockage in blockage_mean:
        for duration in duration_mean:
            for i in range(numevents):
                event = generate_double_spike(baseline_mean, baseline_std, rc_mean, blockage, duration, b, a, samplerate, loc='start')
                time = np.arange(0, len(event), 1) * 1e6/samplerate
                df = pd.DataFrame({'Time': time, 'Current': event})
                df.to_csv('Double Spike/event_{0}_{1}_{2}.csv'.format(num,blockage,duration), index=False)
                num += 1


def generate_simple_spike(baseline, baseline_std, rc, blockage, duration, b, a, samplerate):
    """Simulate one isolated rectangular blockage on an open-pore baseline.

    The trace sits at ``baseline``, falls by ``blockage`` for ``duration``, and
    then recovers. Both edges relax exponentially with time constant ``rc``.
    Gaussian noise with standard deviation ``baseline_std`` is added, and the
    result is filtered with ``scipy.signal.filtfilt(b, a, ...)``.

    Parameters
    ----------
    baseline, baseline_std : float
        Open-pore level and noise standard deviation (pA when called from ``main``).
    rc : float
        Edge time constant in microseconds.
    blockage : float
        Blockage depth, in the same units as ``baseline``.
    duration : float
        Blockage length in microseconds.
    b, a : array_like
        Filter coefficients passed to ``filtfilt``.
    samplerate : float
        Samples per second; converts microseconds to data points.

    Returns
    -------
    numpy.ndarray
        ``int(width) + 2*margin`` filtered samples. ``width`` is the duration
        in data points and ``margin = max(int(2*width), 100)`` baseline
        samples sit on each side of the blockage.
    """
    points_per_us = samplerate * 1e-6
    width = duration * points_per_us   # blockage length in data points
    tau = rc * points_per_us           # edge time constant in data points
    margin = max(int(width * 2), 100)  # baseline samples before and after
    n_points = int(width) + 2 * margin
    fall_at = margin
    rise_at = fall_at + int(width)

    # Each edge is 0 up to its onset, then grows as 1 - exp(-(i - onset)/tau).
    idx = np.arange(n_points)
    fall = blockage * (1 - np.exp(-np.maximum(idx - fall_at, 0) / tau))
    rise = blockage * (1 - np.exp(-np.maximum(idx - rise_at, 0) / tau))
    trace = np.ones(n_points) * baseline - fall + rise

    trace = trace + np.random.normal(0, baseline_std, size=n_points)
    # Edge-pad by `margin` samples before filtering and trim afterwards,
    # presumably to keep filtfilt's end transients out of the returned trace.
    extended = np.pad(trace, margin, mode='edge')
    return filtfilt(b, a, extended)[margin:-margin]


def generate_spike(baseline, baseline_std, rc, blockage, duration, b, a, samplerate, loc='start'):
    """Simulate a short spike riding on a longer "carrier" blockage.

    The trace moves through these levels:
    baseline -> -blockage (carrier) -> -2*blockage (spike, ``duration`` long)
    -> -blockage -> baseline.
    Every transition relaxes exponentially with time constant ``rc``.
    Gaussian noise (std ``baseline_std``) is added and the result is filtered
    with ``scipy.signal.filtfilt(b, a, ...)``.

    Parameters
    ----------
    baseline, baseline_std : float
        Open-pore level and noise standard deviation (pA when called from ``main``).
    rc : float
        Transition time constant in microseconds.
    blockage : float
        Depth of each blockage step, in the same units as ``baseline``.
    duration : float
        Spike length in microseconds.
    b, a : array_like
        Filter coefficients passed to ``filtfilt``.
    samplerate : float
        Samples per second; converts microseconds to data points.
    loc : {'start', 'middle', 'end'}
        Where the spike sits within the carrier. With 'start' the spike begins
        together with the carrier; with 'end' it finishes together with it.

    Returns
    -------
    numpy.ndarray
        ``int(d) + 2*padding + 200`` filtered samples. ``d`` is the duration in
        data points and ``padding = max(int(2*d), 100)``.

    Raises
    ------
    ValueError
        If ``loc`` is not 'start', 'middle' or 'end'.
    """
    # NOTE(review): the original comment called this a "200 us" carrier, but
    # the value is used directly as a sample count and is never scaled by
    # samplerate like duration and rc are. The carrier therefore spans 200
    # samples (48 us at the 1/0.24e-6 Hz rate used in main). Confirm which
    # unit was intended.
    carrier = 200
    duration *= samplerate * 1e-6 #change to units of data points
    rc *= samplerate * 1e-6 #change to units of data points
    padding = max(int(duration * 2), 100)
    length = int(duration) + 2*padding + int(carrier)

    # Offset of the spike onset from the carrier onset, per location.
    # An unknown loc used to fall back to 'start' silently.
    spike_offsets = {
        'start': 0,
        'middle': int(carrier/2) - int(duration/2),
        'end': int(carrier) - int(duration),
    }
    if loc not in spike_offsets:
        raise ValueError(f"loc must be 'start', 'middle' or 'end', got {loc!r}")
    start = padding + spike_offsets[loc]
    end = start + int(duration)

    # (onset index, signed change in current). Each step is 0 up to and
    # including its onset, then approaches its full size as
    # 1 - exp(-(i - onset)/rc). Order matches the original per-sample loop.
    edges = (
        (padding, -blockage),           # carrier begins
        (start, -blockage),             # spike begins
        (end, blockage),                # spike ends
        (padding + carrier, blockage),  # carrier ends
    )
    idx = np.arange(length)
    event = np.ones(length) * baseline
    for onset, delta in edges:
        event += delta * (1 - np.exp(-np.maximum(idx - onset, 0) / rc))

    event += np.random.normal(0, baseline_std, size=length)
    # Edge-pad before filtering and trim afterwards, presumably to keep
    # filtfilt's end transients out of the returned trace.
    padded = np.pad(event, padding, mode='edge')
    event = filtfilt(b, a, padded)[padding:-padding]
    return event

def generate_sloped_sublevel(baseline, baseline_std, rc, blockage, duration, b, a, samplerate, loc='start'):
    """Simulate a two-level blockage: a short sublevel next to a long carrier level.

    ``loc='start'``: baseline -> -blockage for ``duration`` -> -2*blockage for
    the carrier -> baseline (``main`` stores these under "Falling Edge").

    ``loc='end'``: baseline -> -2*blockage for the carrier -> -blockage for
    ``duration`` -> baseline (``main`` stores these under "Rising Edge").

    Every transition relaxes exponentially with time constant ``rc``.
    Gaussian noise (std ``baseline_std``) is added and the result is filtered
    with ``scipy.signal.filtfilt(b, a, ...)``.

    Parameters
    ----------
    baseline, baseline_std : float
        Open-pore level and noise standard deviation (pA when called from ``main``).
    rc : float
        Transition time constant in microseconds.
    blockage : float
        Depth of one blockage step, in the same units as ``baseline``.
    duration : float
        Sublevel length in microseconds.
    b, a : array_like
        Filter coefficients passed to ``filtfilt``.
    samplerate : float
        Samples per second; converts microseconds to data points.
    loc : {'start', 'end'}
        Whether the short sublevel comes before or after the carrier level.

    Returns
    -------
    numpy.ndarray
        ``int(d) + 2*padding + 200`` filtered samples. ``d`` is the duration in
        data points and ``padding = max(int(2*d), 100)``.

    Raises
    ------
    ValueError
        If ``loc`` is not 'start' or 'end'. Previously such a value skipped
        every transition and silently returned a baseline-only trace.
    """
    # NOTE(review): the original comment called this a "200 us" carrier, but
    # the value is used directly as a sample count and is never scaled by
    # samplerate like duration and rc are. The carrier therefore spans 200
    # samples (48 us at the 1/0.24e-6 Hz rate used in main). Confirm which
    # unit was intended.
    carrier = 200
    duration *= samplerate * 1e-6 #change to units of data points
    rc *= samplerate * 1e-6 #change to units of data points
    padding = max(int(duration * 2), 100)
    length = int(duration) + 2*padding + int(carrier)

    # (onset index, signed change in current), in the original update order
    start = padding
    if loc == 'start':
        step = start + int(duration)
        end = step + int(carrier)
        edges = ((start, -blockage), (step, -blockage), (end, 2*blockage))
    elif loc == 'end':
        step = start + int(carrier)
        end = step + int(duration)
        edges = ((start, -2*blockage), (step, blockage), (end, blockage))
    else:
        raise ValueError(f"loc must be 'start' or 'end', got {loc!r}")

    # Each step is 0 up to and including its onset, then approaches its full
    # size as 1 - exp(-(i - onset)/rc).
    idx = np.arange(length)
    event = np.ones(length) * baseline
    for onset, delta in edges:
        event += delta * (1 - np.exp(-np.maximum(idx - onset, 0) / rc))

    event += np.random.normal(0, baseline_std, size=length)
    # Edge-pad before filtering and trim afterwards, presumably to keep
    # filtfilt's end transients out of the returned trace.
    padded = np.pad(event, padding, mode='edge')
    event = filtfilt(b, a, padded)[padding:-padding]
    return event


def generate_double_spike(baseline, baseline_std, rc, blockage, duration, b, a, samplerate, loc='start'):
    """Simulate two short spikes, ``duration`` apart, on top of a carrier blockage.

    The trace moves through these levels:
    baseline -> -blockage (carrier) -> -2*blockage (first spike, 4 us) ->
    -blockage for ``duration`` -> -2*blockage (second spike, 4 us) ->
    -blockage -> baseline.
    Spike widths are ``int(4 * samplerate * 1e-6)`` samples. Every transition
    relaxes exponentially with time constant ``rc``. Gaussian noise (std
    ``baseline_std``) is added and the result is filtered with
    ``scipy.signal.filtfilt(b, a, ...)``.

    Parameters
    ----------
    baseline, baseline_std : float
        Open-pore level and noise standard deviation (pA when called from ``main``).
    rc : float
        Transition time constant in microseconds.
    blockage : float
        Depth of each blockage step, in the same units as ``baseline``.
    duration : float
        Gap between the two spikes in microseconds.
    b, a : array_like
        Filter coefficients passed to ``filtfilt``.
    samplerate : float
        Samples per second; converts microseconds to data points.
    loc : str
        Accepted for call-compatibility with the other generators. It is not
        used by this function.

    Returns
    -------
    numpy.ndarray
        ``int(d) + 2*margin + 200 + int(8 * samplerate * 1e-6)`` filtered
        samples. ``d`` is the gap in data points and
        ``margin = max(int(2*d), 100)``.
    """
    # NOTE(review): the original comment called this a "200 us" carrier, but
    # the value is used directly as a sample count and is never scaled by
    # samplerate. Confirm which unit was intended.
    carrier = 200
    points_per_us = samplerate * 1e-6
    gap = duration * points_per_us  # spacing between the spikes, in data points
    tau = rc * points_per_us        # transition time constant, in data points
    spike_width = int(4 * samplerate * 1e-6)
    margin = max(int(gap * 2), 100)
    n_points = int(gap) + 2 * margin + int(carrier) + int(8 * samplerate * 1e-6)

    first_down = margin + int(carrier / 2) - int(gap / 2) + spike_width
    first_up = first_down + spike_width
    second_down = first_up + int(gap)
    second_up = second_down + spike_width
    carrier_end = second_up + int(carrier / 2)

    # (onset, sign): -1 lowers the current by `blockage`, +1 raises it back.
    # Applied in the same order as the per-sample updates they replace.
    edges = (
        (margin, -1),
        (first_down, -1),
        (first_up, +1),
        (second_down, -1),
        (second_up, +1),
        (carrier_end, +1),
    )
    idx = np.arange(n_points)
    trace = np.ones(n_points) * baseline
    for onset, sign in edges:
        relax = blockage * (1 - np.exp(-np.maximum(idx - onset, 0) / tau))
        trace += sign * relax

    trace += np.random.normal(0, baseline_std, size=n_points)
    # Edge-pad before filtering and trim afterwards, presumably to keep
    # filtfilt's end transients out of the returned trace.
    extended = np.pad(trace, margin, mode='edge')
    return filtfilt(b, a, extended)[margin:-margin]


# Generate the dataset only when run as a script, not when imported.
if __name__ == "__main__":
    main()
