import numpy as np
from scipy.optimize import leastsq, fsolve


def shiftguestglobal(arg, data):
    """Return the concatenated residuals of one dataset over all wavelengths.

    Parameters
    ----------
    arg : sequence of floats
        One value per wavelength followed by the three shared constants,
        i.e. ``(a_1, ..., a_n, k1, k2, k3)``.
    data : object
        Must provide the attributes ``chost`` and ``cguest`` plus one
        attribute ``g1`` ... ``gn`` for each of the ``n`` leading entries
        of ``arg``.

    Returns
    -------
    numpy.ndarray
        The residual vectors returned by ``shiftguest`` for ``g1`` ... ``gn``,
        joined end to end.  Empty when ``arg`` holds only the three constants.
    """
    # shiftguest expects host and guest concentrations packed in one vector.
    X = np.append(data.chost, data.cguest)
    shared = (arg[-3], arg[-2], arg[-1])

    pieces = []
    for ki in range(len(arg) - 3):
        arg0 = (arg[ki],) + shared
        column = getattr(data, 'g' + str(ki + 1))
        pieces.append(np.ravel(shiftguest(arg0, X, column)))

    # np.concatenate rejects an empty list, so handle "no wavelengths" here.
    if not pieces:
        return np.array([])
    return np.concatenate(pieces)

def shiftguest(arg, x, y):
    """Return fit residuals for one measured column of a 1:3 titration.

    The calculated value at each titration point is the concentration-weighted
    average over the guest-containing species (HG, HG2, HG3 and free G),
    divided by the total guest concentration.

    Parameters
    ----------
    arg : sequence of 4 floats
        ``(dhg, k1, k2, k3)``: the value assigned to a bound guest and the
        three constants handed to ``equations``.
    x : sequence
        Host concentrations followed by guest concentrations.  The first half
        is read as host and the second half as guest, so the length should be
        even.
    y : sequence
        Measured values, one per titration point.  ``y[0]`` is used as the
        fixed value of the free guest.

    Returns
    -------
    numpy.ndarray
        ``y - calculated``.  Every element is shifted by 1e10 for each of the
        two bound checks on ``k1``, ``k2``, ``k3`` that is violated.
    """
    dhg, k1, k2, k3 = arg
    # All three complexes share the same per-guest value in this model.
    dhg2, dhg3 = dhg, dhg
    dg = y[0]  # fixed value taken from the first titration point

    # Floor division: a float slice index raises TypeError on Python 3.
    half = len(x) // 2
    chost = x[:half]
    cguest = x[half:]

    # Columns: free host, HG, HG2, HG3, free guest.
    conc = np.zeros((len(chost), 5))
    conc[0, :] = np.array([chost[0], 0, 0, 0, cguest[0]])

    guess = None
    for i in range(1, len(chost)):
        if guess is None:
            # First solve starts from "nothing bound"; later solves are
            # warm-started from the previous titration point.
            guess = np.array([chost[1], 0, 0, 0, cguest[1]])
        guess = fsolve(equations, guess,
                       args=(k1, k2, k3, chost[i], cguest[i]))
        conc[i, :] = np.array(guess)

    # Total guest concentration over all guest-containing species.
    norm = np.dot(conc[:, 1:], np.array([1, 2, 3, 1]))
    out = y - np.dot(conc[:, 1:], np.array((dhg, 2*dhg2, 3*dhg3, dg)))/norm
    # The first point defines dg, so its residual is zero by construction.
    out[0] = y[0] - dg

    # introduce reasonable constraints for Ki
    if k1 <= 100 or k2 <= 100 or k3 <= 1:
        out += 1e10

    if k1 >= 10000 or k2 >= 8000 or k3 >= 1000:
        out += 1e10

    return out

def uvglobal(arg, data1, data2, data3, data4):
    """Return the residuals of all four datasets joined into one vector.

    ``arg`` is laid out as twelve per-wavelength values followed by the three
    shared constants: entries 0-2 belong to ``data1``, 3-7 to ``data2``,
    8-9 to ``data3``, 10-11 to ``data4`` and 12-14 are K1, K2, K3.
    The slice table below has to be edited for a different number of datasets
    or wavelengths.
    """
    first_k, second_k, third_k = arg[12:]
    shared = [first_k, second_k, third_k]

    # (start, stop) of each dataset's slice in arg, paired with the dataset
    layout = (
        ((0, 3), data1),
        ((3, 8), data2),
        ((8, 10), data3),
        ((10, 12), data4),
    )
    residuals = [
        shiftguestglobal(list(arg[start:stop]) + shared, dataset)
        for (start, stop), dataset in layout
    ]
    return np.concatenate(residuals)

def equations(p, *data):
    """Return the five residuals of the equilibrium and mass-balance system.

    ``p`` holds the trial concentrations (free host, HG, HG2, HG3, free guest)
    and ``data`` holds ``(k1, k2, k3, total host, total guest)``.  If any trial
    concentration is negative, 1e25 is added to every residual.
    """
    free_h, hg, hg2, hg3, free_g = p
    k1, k2, k3, total_h, total_g = data

    residuals = (
        k1*free_h*free_g - hg,                  # H + G <-> HG
        k2*hg*free_g - hg2,                     # HG + G <-> HG2
        k3*hg2*free_g - hg3,                    # HG2 + G <-> HG3
        total_h - free_h - hg - hg2 - hg3,      # host balance
        total_g - free_g - hg - 2*hg2 - 3*hg3,  # guest balance
    )

    if not any(c < 0 for c in (free_h, hg, hg2, hg3, free_g)):
        return residuals
    # Negative concentrations are penalised on all five residuals.
    return tuple(r + 1e25 for r in residuals)

class out:
    """Plain attribute container used to hold one titration dataset."""

# import titration data
# ------------------------------------------------------------------------
# Each dataset is an instance of class "out" with the following attributes:
# cguest is an array with the guest concentrations
# chost is an array with the host concentrations
# g1, g2, g3, ... are the measured absorbances at the selected wavelengths
# The number of gi columns is not limited, but they must be numbered
# consecutively starting from g1.
# For example:
# data = out()
# data.cguest = np.array([cg1, cg2, ...., cgn])
# data.chost = np.array([ch1, ch2, ...., chn])
# data.g1 = np.array([a1, a2, ....])
# data.g2 = np.array([b1, b2, ....])
# ................................
# The fit below expects four such datasets named data1, data2, data3, data4.

# Initial guess values for K1, K2, K3
K1, K2, K3 = 1800.0,  1000.0, 200.0
# In this example four datasets are fitted simultaneously,
# assuming that the absorbances of the HG, HG2 and HG3 complexes
# are approximately the same, but different from the absorbance of G.
# The function uvglobal must be modified
# for a different number of datasets or considered wavelengths.

# Initial guess values for the absorbances at the different wavelengths.
# NOTE: the four assignments below are placeholders and must be completed
# with a tuple of numbers before the script can run.
# for dataset data1 (three wavelengths)
A0_1 = # TODO: fill in (a1, a2, a3)
# for dataset data2 (five wavelengths)
A0_2 = # TODO: fill in (a1, a2, a3, a4, a5)
# for dataset data3 (two wavelengths)
A0_3 = # TODO: fill in (a1, a2)
# for dataset data4 (two wavelengths)
A0_4 = # TODO: fill in (a1, a2)

# Parameter vector: all absorbance guesses in dataset order (3 + 5 + 2 + 2
# values, matching the slices taken in uvglobal), then the shared K1, K2, K3.
arg0 = (*A0_1, *A0_2, *A0_3, *A0_4,
        K1, K2, K3)

# Run the global least-squares fit.  With full_output=1 leastsq also returns
# the unscaled covariance estimate, a diagnostics dict, a message and a
# status flag.
pfit, pcov, infodict, errmsg, success = leastsq(
    uvglobal, arg0,
    args=(data1, data2, data3, data4),
    full_output=1)

# leastsq signals that a solution was found only with status flags 1-4.
if success not in (1, 2, 3, 4):
    print('warning: leastsq did not converge:', errmsg)

print(pfit)

testout = np.asarray(uvglobal(pfit, data1, data2, data3, data4))

# Residual standard deviation about zero, using N - p degrees of freedom.
# The covariance returned by leastsq is unscaled and has to be multiplied by
# the residual variance to give parameter variances.
dof = testout.size - len(pfit)
if dof > 0:
    res = np.sqrt(np.sum(testout**2) / dof)
else:
    res = np.nan

if pcov is None:
    # leastsq returns None when the curvature matrix is singular, so no
    # error estimate is available.
    perr = np.array([np.nan] * len(pfit))
else:
    perr = np.sqrt(np.diag(pcov)) * res

print('errors:', perr)
print('xisqrd = ', res)
print('---------------------')
for value, error in zip(pfit, perr):
    print(value, ' +/- ', error)
        

