clear
clc
close all

%% Input
% Ask the user for the table with the experimental data. uigetfile returns
% 0 for the file name when the dialog is cancelled; stop with a clear
% message in that case instead of handing an invalid path to readmatrix.
[expFile,expPath] = uigetfile({'*.csv;*.xlsx'},'Select table containing EXPERIMENTAL data.','DataBase.csv');
if isequal(expFile,0)
    error('MCMBoosting:noExperimentalFile', ...
          'No table with experimental data was selected.');
end
% The first 3 lines of the experimental table are skipped as header lines.
rawExp = readmatrix(fullfile(expPath,expFile),...
                    'NumHeaderLines',3,...
                    'DecimalSeparator','.');

% Ask the user for the table with the hybridization data (same cancel
% handling as above).
[hybFile,hybPath] = uigetfile({'*.csv;*.xlsx'},'Select table containing HYBRIDIZATION data.','SEGWE.csv');
if isequal(hybFile,0)
    error('MCMBoosting:noHybridizationFile', ...
          'No table with hybridization data was selected.');
end
% The first 2 lines of the hybridization table are skipped as header lines.
rawHyb = readmatrix(fullfile(hybPath,hybFile),...
                    'NumHeaderLines',2,...
                    'DecimalSeparator','.');

% get numerical data
% Experimental table: column 1 = solute ID, column 3 = solvent ID,
% column 5 = measured value (stored as its natural logarithm below).
rawExp = rawExp(~isnan(rawExp(:,5)),:); % remove non-numeric entries
soluteID_exp  = rawExp(:,1);
solventID_exp = rawExp(:,3);
Dinf_data = log(rawExp(:,5));
% Hybridization table: column 1 = solute ID, column 2 = solvent ID,
% column 3 = value (stored as its natural logarithm below).
soluteID_hyb  = rawHyb(:,1);
solventID_hyb = rawHyb(:,2);
SEGWE_data = log(rawHyb(:,3));

% transform experimental data from list to matrix
% Rows are indexed by keySolute, columns by keySolvent.
keySolute = unique(soluteID_exp);
keySolvent = unique(solventID_exp);
nSolute = numel(keySolute);
nSolvent = numel(keySolvent);
[~,j] = ismember(solventID_exp,keySolvent);
[~,i] = ismember(soluteID_exp,keySolute);
expDataAvailable = accumarray([i,j],1,[nSolute,nSolvent]) > 0;
% Repeated (solute, solvent) rows are averaged in log space; summing them
% (the default accumulation) would produce a meaningless value. Entries
% without any data are filled with NaN.
Dinf_data = accumarray([i,j],Dinf_data,[nSolute,nSolvent],@mean,NaN);

%transform hybridization data from list to matrix
[~,j] = ismember(solventID_hyb,keySolvent);
[~,i] = ismember(soluteID_hyb,keySolute);
ind = (i~=0 & j~=0); % keep only pairs whose IDs occur in the experimental keys
% The size is fixed to nSolute-by-nSolvent so that the result always matches
% Dinf_data, even when the last solute/solvent has no hybridization entry.
hybDataAvailable = accumarray([i(ind),j(ind)],1,[nSolute,nSolvent]) > 0;
SEGWE_data = accumarray([i(ind),j(ind)],SEGWE_data(ind),[nSolute,nSolvent],@mean,NaN);

%% Fit with Stan
% Problem dimensions taken from the experimental matrix:
% I = number of solutes (rows), J = number of solvents (columns).
[I,J] = size(Dinf_data);
K = 2;         % number of latent dimensions
sigma_0 = 1;   % Prior standard deviation
lambda = 0.2;  % likelihood scale

% residuals between hybridization and experimental data; entries for which
% no residual can be formed (NaN) are replaced by the placeholder -99
Res_Dinf = SEGWE_data - Dinf_data;
Res_Dinf(isnan(Res_Dinf)) = -99;

% collect all inputs of the Stan model in one struct
data = struct();
data.I = I;
data.J = J;
data.K = K;
data.R = Res_Dinf;
data.sigma_0 = sigma_0;
data.lambda = lambda;

% run the variational fit, wait for it to finish and fetch the samples
fit = stan('file','model_boosting.stan', ...
           'data',data, ...
           'method','variational', ...
           'iter',1000, ...
           'verbose',true, ...
           'seed',208776);
fit.block();
results = fit.extract;

% residuals for each of the sample points returned by the fit
% (assumes u_samples is nSamples-by-I-by-K and v_samples is
% nSamples-by-J-by-K, as implied by the matrix product below - confirm
% against model_boosting.stan)
u_samples = results.u;
v_samples = results.v;
nSamples = size(u_samples,1);

% prediction of residuals
% The buffer is sized from the actual number of samples; a hard-coded count
% would leave all-zero rows (or grow the array) if the fit returns a
% different number, which would distort the mean and standard deviation.
Res_Dinf_pred = zeros(nSamples,I,J);
for k = 1:nSamples
    % reshape (instead of squeeze) keeps the I-by-K / J-by-K orientation
    % even when one of the dimensions has length 1
    U_k = reshape(u_samples(k,:,:),I,[]);
    V_k = reshape(v_samples(k,:,:),J,[]);
    Res_Dinf_pred(k,:,:) = U_k * V_k';
end
Std_Dinf_pred = reshape(std(Res_Dinf_pred,0,1),I,J);

% average over the samples and set entries without spread to NaN
Res_Dinf_pred = reshape(mean(Res_Dinf_pred,1),I,J);
Std_Dinf_pred(Std_Dinf_pred==0) = NaN;

% add Residuals
% Res = SEGWE - Dinf, hence Dinf = SEGWE - Res (all in log space)
Dinf_pred = SEGWE_data - Res_Dinf_pred; % predicted D

%% Output
% Build a list with one row per available prediction:
% [solute ID, solvent ID, exp(prediction), exp(standard deviation)].
% The transposed matrices are used so that the list is ordered by solute
% first and by solvent second. IDs and values are selected with the SAME
% mask on the SAME (transposed) matrix, so every value stays paired with
% its own solute/solvent and NaN predictions are skipped consistently.
predT = Dinf_pred.';
stdT  = Std_Dinf_pred.';
valid = ~isnan(predT);
[jj,ii] = find(valid);
% reshape(...,[],1) forces column vectors even if a matrix has a single row
% NOTE(review): exp() of the log-space standard deviation is a
% multiplicative factor, not a standard deviation of Dij_inf itself -
% confirm that this matches the intended meaning of the last column.
outData = [reshape(keySolute(ii),[],1), ...
           reshape(keySolvent(jj),[],1), ...
           reshape(exp(predT(valid)),[],1), ...
           reshape(exp(stdT(valid)),[],1)];

% write the two header lines first (this overwrites an existing file) ...
writematrix(["Prediction of Dij_inf by MCM-Boosting","","",""; ...
             "Solutes i", "Solvents j", "Dij_inf", "STD(Dij_inf)"],...
            [expPath,'Boosting_Predictions.csv'],...
            'Delimiter',';');
% ... then append the numeric table below the header
writematrix(outData, ...
            [expPath,'Boosting_Predictions.csv'],...
            'Delimiter',';','WriteMode','append');