#################################
# Prasenjit Sen
#
# This script does RF classification of stable, magnetic (FM & AFM)
# compounds in the c2db database.
####################################

import matplotlib.pyplot as plt
import seaborn as sb
import numpy as np
from scipy.stats import pearsonr
#from sklearn.tree import DecisionTreeRegressor
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import accuracy_score, confusion_matrix, precision_score, recall_score, f1_score
from sklearn.model_selection import cross_val_score, GridSearchCV, train_test_split
from sklearn import preprocessing
from sklearn.feature_selection import SelectFromModel
from sklearn.inspection import plot_partial_dependence
import statistics
import pandas as pd
from math import sqrt

# Read the full classification table and report its (rows, columns) shape.
Mag_data = pd.read_csv('../../datasets/magnetic/cmr-mag-classification.csv', sep=',')
print(Mag_data.shape)

# Columns kept for modelling, in the order they appear in the feature matrix.
# For each elemental property there are three statistics (mean_, del_, mode_).
elemental_props = ('Z', 'group', 'period', 'val', 'electroneg')
elemental_stats = ('mean', 'del', 'mode')
selected_columns = ['d_elect']
for prop in elemental_props:
    selected_columns.extend(stat + '_' + prop for stat in elemental_stats)
selected_columns.append('cell_area')
selected_columns.extend('smeig' + str(k) for k in range(1, 13))
# 'mag_state' is the label column; it is separated from the features later on.
selected_columns.append('mag_state')

Mag_data = Mag_data[selected_columns]

#'acsf1','acsf2','acsf3','acsf4','acsf5','acsf6','acsf7','acsf8','acsf9','acsf10','acsf11','acsf12','acsf13','acsf14','acsf15','acsf16',\
#'acsf17','acsf18','acsf19','acsf20','acsf21','acsf22','acsf23','acsf24','acsf25','acsf26','acsf27','acsf28','acsf29','acsf30','acsf31','acsf32',\
#'entropy', 'L2_norm', 'L3_norm',\ 

#Mag_data.drop(labels=['id','n_metal'], axis=1, inplace=True)
# Shape after column selection.
print(Mag_data.shape)

Nsamples = Mag_data.shape[0]
print('No. of samples = ', Nsamples)

# Shuffle the rows and renumber the index from 0.
# The shuffle is seeded: the later train/test split uses a fixed random_state,
# but an unseeded shuffle here would still hand it a different row order on
# every run, making the split (and every reported score) non-reproducible.
Mag_data = Mag_data.sample(frac=1, axis=0, random_state=1).reset_index(drop=True)


# Import function to create training and test set splits
from sklearn.model_selection import train_test_split

# Pull the label column out as an (n_samples, 1) array and remove it from the
# table, so that Mag_data holds only feature columns from here on.
ismagnetic = Mag_data[['mag_state']].values
Mag_data.drop(labels=['mag_state'], axis=1, inplace=True)

# Binary target as strings: '0' for non-magnetic ('NM') rows, '1' otherwise.
# Labels are whitespace-stripped before the comparison: the CSV is read
# without skipinitialspace, so a value can arrive as ' NM'; matching only that
# exact spelling would silently mark every row magnetic if the file were ever
# written without the padding.
y = ['0' if str(state).strip() == 'NM' else '1' for state in ismagnetic.ravel()]



# Feature matrix: every column still present in Mag_data, as a NumPy array.
X = Mag_data.to_numpy()

# Hold out 30% of the samples for testing, with a fixed seed for the split.
split = train_test_split(X, y, test_size=0.3, random_state=1)
X_train, X_test, y_train, y_test = split
print('Shape ', X_train.shape)

# Number of training examples.
Ntrain = len(y_train)

# Scale the features (the class labels y are not transformed)
from sklearn.preprocessing import QuantileTransformer, RobustScaler, MinMaxScaler
# Quantile-transform each feature to a normal output distribution, using one
# quantile per training sample (original note: handles outliers better).
# The transformer is fitted on the training split only and then applied to
# both splits.
Xscaler = QuantileTransformer(n_quantiles=Ntrain, output_distribution='normal')
Xscaler.fit(X_train)
# Alternative scaler kept for reference:
#Xscaler = MinMaxScaler(feature_range=(-1,1)).fit(X_train)
X_train, X_test = Xscaler.transform(X_train), Xscaler.transform(X_test)

#######################

# Grid-search a random forest *classifier* over the hyper-parameters below.
estimators = np.linspace(100,250,7, dtype=int, endpoint=True)   # 100, 125, ..., 250
depth = np.linspace(11,25,15, dtype=int, endpoint=True)         # 11, 12, ..., 25
min_split = np.linspace(5,18,14, dtype=int, endpoint=True)      # 5, 6, ..., 18
# Alternative grid for min_samples_leaf:
#min_leaf = np.linspace(1,10,9, dtype=int, endpoint=True)
min_leaf = [1]
max_sample = np.linspace(0.6,0.6,1, endpoint=True)  # single value: 0.6
max_features = 'sqrt'
# Alternative grids (max_leaf_nodes / ccp_alpha):
#max_leaf = np.linspace(50,600,10, dtype=int, endpoint=True)
#alphas = np.linspace(0.0,0.3,4, endpoint=True)
alphas = [0.0]

parameters = {'n_estimators':estimators, 'max_depth':depth, 'max_samples':max_sample,
              'min_samples_leaf':min_leaf, 'min_samples_split':min_split, 'ccp_alpha':alphas}

# max_features is passed from the variable above rather than repeating the
# literal. The forest is seeded so that the selected parameters and the
# cross-validation score are reproducible from run to run.
forest = RandomForestClassifier(criterion='gini', max_features=max_features, random_state=1)
rf = GridSearchCV(forest, parameters, n_jobs=-1, cv=5, verbose=0)

# fit() returns the search object itself, so `opt` and `rf` are the same object.
opt = rf.fit(X_train,y_train)
print('Best parameters : ', opt.best_params_)
# No `scoring` argument is given, so best_score_ is the estimator's default
# score averaged over the 5 folds; 'gini' in the label is the split criterion.
print('Scorer gini :\t', 'best score : ', opt.best_score_)

# Predict the held-out test split with the best estimator found by the search.
y_pred = opt.best_estimator_.predict(X_test)

print()
print(confusion_matrix(y_test, y_pred))
# Precision, recall and F1 are all computed with '1' as the positive class.
for tag, score_fn in (('Precision=', precision_score),
                      ('Recall=', recall_score),
                      ('f1=', f1_score)):
    print(tag, score_fn(y_test, y_pred, pos_label='1'))
print()

# Per-feature importances of the best forest, indexed like Mag_data.columns.
importance = opt.best_estimator_.feature_importances_
print("Feature ranking:")

# Accumulate under a dedicated name instead of rebinding the builtin `sum`.
total_importance = 0.0
# Features are listed in column order (not sorted); only those with an
# importance above 0.01 are printed, but all contribute to the total.
for c, col in enumerate(Mag_data.columns):
    weight = importance[c]
    total_importance += weight
    if weight > 0.01:
        print(c,'\t',col,'\t',weight)

print('Sum of importances ',total_importance)
