clear
clc
close all

%% Input
% Ask the user for the two input tables. uigetfile returns 0 for both file
% and path when the dialog is cancelled, so stop with a clear message instead
% of failing later inside readmatrix.
[expFile,expPath] = uigetfile({'*.csv;*.xlsx'},'Select table containing EXPERIMENTAL data.','DataBase.csv');
if isequal(expFile,0)
    error('No EXPERIMENTAL data file selected.');
end
rawExp = readmatrix(fullfile(expPath,expFile),...
                    'NumHeaderLines',3,...
                    'DecimalSeparator','.');

[hybFile,hybPath] = uigetfile({'*.csv;*.xlsx'},'Select table containing HYBRIDIZATION data.','SEGWE.csv');
if isequal(hybFile,0)
    error('No HYBRIDIZATION data file selected.');
end
rawHyb = readmatrix(fullfile(hybPath,hybFile),...
                    'NumHeaderLines',2,...
                    'DecimalSeparator','.');

% get numerical data
% Experimental table: column 1 = solute ID, column 3 = solvent ID,
% column 5 = Dij_inf. Keep only rows where all three are numeric; a NaN ID
% would not be found by ismember when the matrices are built.
rawExp = rawExp(all(~isnan(rawExp(:,[1 3 5])),2),:);
soluteID_exp  = rawExp(:,1);
solventID_exp = rawExp(:,3);
Dinf_data = log(rawExp(:,5)); % the models work on ln(Dij_inf)
% Hybridization (SEGWE) table: column 1 = solute ID, column 2 = solvent ID,
% column 3 = Dij_inf estimate.
soluteID_hyb  = rawHyb(:,1);
solventID_hyb = rawHyb(:,2);
SEGWE_data = log(rawHyb(:,3));

% transform experimental data from list to matrix
% Rows are indexed by keySolute, columns by keySolvent; both matrices below
% share this layout and size.
keySolute = unique(soluteID_exp);
keySolvent = unique(solventID_exp);
matSize = [numel(keySolute), numel(keySolvent)];
[~,jSolvent] = ismember(solventID_exp,keySolvent);
[~,iSolute] = ismember(soluteID_exp,keySolute);
expDataAvailable = accumarray([iSolute jSolvent],1,matSize) > 0;
% accumarray averages repeated (solute, solvent) measurements in ln-space
% (sparse() would sum them) and fills pairs without data with NaN.
Dinf_data = accumarray([iSolute jSolvent],Dinf_data,matSize,@mean,NaN);

%transform hybridization data from list to matrix
% Only SEGWE entries whose solute AND solvent occur in the experimental keys
% are kept (ismember returns 0 for the others).
[~,jSolvent] = ismember(solventID_hyb,keySolvent);
[~,iSolute] = ismember(soluteID_hyb,keySolute);
isKnownPair = (iSolute~=0 & jSolvent~=0);
% Passing matSize explicitly guarantees SEGWE_data has the same size as
% Dinf_data even when the last solutes/solvents have no SEGWE value.
hybDataAvailable = accumarray([iSolute(isKnownPair) jSolvent(isKnownPair)],1,matSize) > 0;
SEGWE_data = accumarray([iSolute(isKnownPair) jSolvent(isKnownPair)],SEGWE_data(isKnownPair),matSize,@mean,NaN);

%% Fit with Stan
% Model dimensions come from the experimental matrix (rows = solutes,
% columns = solvents); the remaining values are fixed hyper-parameters.
I = size(Dinf_data,1);   % solutes
J = size(Dinf_data,2);   % solvents
K = 2;                   % latent dimensions per solute / solvent
sigma_0 = 1;             % standard deviation of the prior
lambda = 0.2;            % scale of the likelihood

% Distillation step (pre-training): fit the model to the SEGWE matrix only.
data = struct();
data.I = I;
data.J = J;
data.K = K;
data.ln_D = SEGWE_data;
data.sigma_0 = sigma_0;
data.lambda = lambda;
fit = stan('file','model_whisky_distillation.stan', ...
           'data',data, ...
           'method','variational', ...
           'iter',1000, ...
           'verbose',true, ...
           'seed',208776);
fit.block();              % presumably waits for the Stan run to complete
results = fit.extract();  % struct of posterior draws (fields u, v used below)

% results
% Mean and standard deviation of the distilled latent factors over the draws
% stored along the first dimension of results.u / results.v (assumed
% draws x I x K and draws x J x K). reshape (not squeeze) keeps the
% I x K / J x K orientation even when I or J equals 1.
u_Segwe = reshape(mean(results.u,1),I,K);
v_Segwe = reshape(mean(results.v,1),J,K);
u_std_Segwe = reshape(std(results.u,0,1),I,K);
v_std_Segwe = reshape(std(results.v,0,1),J,K);

% A solute (row) / solvent (column) is "uninformed" when it has no SEGWE
% value at all. Missing entries of SEGWE_data are NaN; the old -99 sentinel
% is also treated as missing so that both conventions are recognized.
hasSegwe = ~isnan(SEGWE_data) & SEGWE_data ~= -99;
uninf_Solutes = ~any(hasSegwe,2);
uninf_Solvents = ~any(hasSegwe,1).';

% Rescale the posterior standard deviations so that their mean over all
% informed solute/solvent factors equals 0.5*sigma_0.
informedStd = [reshape(u_std_Segwe(~uninf_Solutes,:),[],1); ...
               reshape(v_std_Segwe(~uninf_Solvents,:),[],1)];
std_fac = 0.5.*sigma_0 / mean(informedStd);
u_std_Segwe = u_std_Segwe.*std_fac;
v_std_Segwe = v_std_Segwe.*std_fac;

% Product of Experts (PoE): multiply the distilled Gaussian N(mu, s^2) with
% the prior N(0, sigma_0^2). The product is Gaussian with
%   mean = mu * sigma_0^2 / (s^2 + sigma_0^2)
%   var  = s^2 * sigma_0^2 / (s^2 + sigma_0^2)
% The means must be updated before the standard deviations are overwritten.
u_Segwe = u_Segwe .* (sigma_0.^2) ./ (u_std_Segwe.^2 + sigma_0.^2);
v_Segwe = v_Segwe .* (sigma_0.^2) ./ (v_std_Segwe.^2 + sigma_0.^2);
u_std_Segwe = sqrt((u_std_Segwe.^2 .* sigma_0.^2) ./ (u_std_Segwe.^2 + sigma_0.^2));
v_std_Segwe = sqrt((v_std_Segwe.^2 .* sigma_0.^2) ./ (v_std_Segwe.^2 + sigma_0.^2));

% Uninformed solutes/solvents fall back to the plain prior N(0, sigma_0^2)
u_Segwe(uninf_Solutes,:) = 0;
v_Segwe(uninf_Solvents,:) = 0;
u_std_Segwe(uninf_Solutes,:) = sigma_0;
v_std_Segwe(uninf_Solvents,:) = sigma_0;

%% Maturation Step
% Fit the experimental ln(D) matrix, using the PoE-combined SEGWE factors as
% prior means / standard deviations for the latent factors u and v.
data = struct('I',I,'J',J,'K',K,'ln_D',Dinf_data, ...
              'mu_0_u',u_Segwe,'mu_0_v',v_Segwe, ...
              'sigma_0_u',u_std_Segwe,'sigma_0_v',v_std_Segwe, ...
              'lambda',lambda);
fit = stan('file','model_whisky_maturation.stan','data',data,'method','variational','iter',1000,'verbose',true,'seed',208776);
fit.block();
results = fit.extract;

% ln(D) prediction for every draw: ln_D(k) = u(k) * v(k)'
% The number of draws is taken from the result instead of being hard-coded.
% NOTE(review): CmdStan's variational output usually stores the mean of the
% approximation as its first row; check whether fit.extract includes it.
u_samples = results.u;
v_samples = results.v;
nSamples = size(u_samples,1);
Dinf_pred = zeros(nSamples,I,J);
for k = 1:nSamples
    % reshape (not squeeze) keeps I x K / J x K even when I or J equals 1
    u_k = reshape(u_samples(k,:,:),I,K);
    v_k = reshape(v_samples(k,:,:),J,K);
    Dinf_pred(k,:,:) = u_k * v_k.';
end
Std_Dinf_pred = reshape(std(Dinf_pred,0,1),I,J);

% Posterior mean of ln(D); a zero spread across draws is marked as unknown (NaN)
Dinf_pred = reshape(mean(Dinf_pred,1),I,J);
Std_Dinf_pred(Std_Dinf_pred==0) = NaN;

%% Output
% One row per (solute, solvent) pair with a non-NaN prediction, ordered by
% solute and then solvent. find on the transpose yields that order; the
% values are then read with the SAME subscripts so IDs and values match.
[jj,ii] = find(~isnan(Dinf_pred.'));
idx = sub2ind(size(Dinf_pred),ii,jj);
outData = [keySolute(ii), keySolvent(jj), ...
           exp(Dinf_pred(idx)), ...
           exp(Std_Dinf_pred(idx))]; % exp of the std of ln(D), i.e. a multiplicative spread factor
% NOTE(review): column 4 is exp(std(ln D)), while the header calls it STD(Dij_inf) — confirm intent.

outFile = fullfile(expPath,'Whisky_Predictions.csv');
writematrix(["Prediction of Dij_inf by MCM-Whisky","","",""; ...
             "Solutes i", "Solvents j", "Dij_inf", "STD(Dij_inf)"],...
            outFile,...
            'Delimiter',';');
writematrix(outData,...
            outFile,...
            'Delimiter',';','WriteMode','append');