#############################################
# P. Sen 2021
#
# This script first trains a RF classifier to classify materials as 
# magnetic or non-magnetic. This model is then used to classify the set 
# of 278 combined 3d5d TM compounds
# into magnetic and non-magnetic classes.
#
###########################################
#import required packages
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from matplotlib import rc
#rc('text', usetex=True) # If latex causes problems, comment this line 
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import mean_absolute_error, mean_squared_error, r2_score, explained_variance_score

# Print arrays in full (no summarisation) with floats shown 9 wide, 5 decimals.
np.set_printoptions(
    threshold=np.inf,
    formatter={'float': "{0:9.5f}".format},
)

# Load the labelled table that the classifier is trained on.
Mag_data = pd.read_csv('./data-cmr-mag-classification.csv', sep=',')

# Report the table dimensions before any column is removed.
print('Shape= ', Mag_data.shape)

# Plot the hform distributions before and after scaling
#import seaborn as sb
#fig, ax = plt.subplots(2,2,figsize=(6,6))
#sb.distplot(hform_data['hform'])
#sb.distplot(MAE_data['log_MAE'], ax=ax[0,1])
#plt.xlabel('Heat of formation per atom', fontsize=18)
#ax[0,1].set_xlabel('$log(MAE)$', fontsize=18)
#ax[0,0].tick_params(direction='in', labelsize=18)
#ax[0,1].tick_params(direction='in', labelsize=18)

# Remove the columns that are not used as features.
Mag_data.drop(labels=['id', 'n_metal'], axis=1, inplace=True)
print('Shape= ', Mag_data.shape)

Nsamples = Mag_data.shape[0]
print('No. of samples = ', Nsamples)

# Separate the class label from the feature table; every column left in
# Mag_data afterwards is treated as a feature.
ismagnetic = Mag_data[['mag_state']].values  # shape (Nsamples, 1)
Mag_data.drop(labels=['mag_state'], axis=1, inplace=True)

# Encode the target as string class labels: '0' for non-magnetic ('NM'),
# '1' for every other state. Surrounding whitespace is stripped so that a
# field written as ' NM' (space after the CSV comma) and one written as 'NM'
# are both recognised, instead of the latter silently counting as magnetic.
y_train = ['0' if str(state).strip() == 'NM' else '1'
           for state in ismagnetic[:, 0]]

X_train = Mag_data.to_numpy()
print('Shape= ', Mag_data.shape)
print()
print()


Ntrain = len(y_train)  # number of training examples

# Only the features are scaled; the class labels are used as they are.
from sklearn.preprocessing import QuantileTransformer, RobustScaler, MinMaxScaler

# Quantile-map each feature onto a normal distribution, using one quantile
# per training sample (chosen because it handles outliers better; the
# alternative tried here was MinMaxScaler(feature_range=(-1, 1))).
Xscaler = QuantileTransformer(n_quantiles=Ntrain, output_distribution='normal')
X_train = Xscaler.fit_transform(X_train)

# Scaling target data
#y_train = y_train.reshape(-1,1)
#y_test = y_test.reshape(-1,1)

#Yscaler = QuantileTransformer(n_quantiles=Ntrain, output_distribution='normal').fit(y_train)
#Yscaler = MinMaxScaler(feature_range=(-1, 1)).fit(y_train)
#y_train = Yscaler.transform(y_train)
#y_test = Yscaler.transform(y_test)

# Plot scaled target distributions
#fig, ax = plt.subplots(1,2,figsize=(6,6))
#sb.distplot(y_train, ax=ax[0])
#sb.distplot(y_test, ax=ax[1])
#ax[0].set_xlabel('$scaled\,y_{train}$', fontsize=18)
#ax[1].set_xlabel('$scaled\, y_{test}$', fontsize=18)
#ax[0].tick_params(direction='in', labelsize=18)
#ax[1].tick_params(direction='in', labelsize=18)
#fig.tight_layout()
#plt.savefig('target_dist.png')
#plt.show()

from sklearn.model_selection import GridSearchCV

N_CV = 5  # number of cross-validation folds

# Hyper-parameter grid for the random forest classifier.
estimators = [100, 125, 150, 175, 200, 225]
depth = [15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25]
min_split = [2, 3, 4, 5, 6, 7]
min_leaf = [1]
max_sample = [0.6]      # fraction of the samples drawn for each tree
max_features = 'sqrt'   # fixed, not part of the search grid
alphas = [0.0]

parameters = {
    'n_estimators': estimators,
    'max_depth': depth,
    'max_samples': max_sample,
    'min_samples_leaf': min_leaf,
    'min_samples_split': min_split,
    'ccp_alpha': alphas,
}

# Exhaustive search over the grid. No `scoring` is given, so candidates are
# ranked with the estimator's own score method.
rf = GridSearchCV(
    RandomForestClassifier(criterion='gini', max_features=max_features),
    parameters,
    n_jobs=-1,
    cv=N_CV,
    verbose=0,
)

# fit() returns the fitted search object itself, so `opt` and `rf` are the
# same object.
opt = rf.fit(X_train, y_train)

# Report the mean cross-validated score of the best parameter set.
# (`opt.score` is a bound method; printing it only shows its repr.)
print('Scorer gini :\t', 'best score : ', opt.best_score_)

# Predictions of the refitted best model on the training set itself.
y_pred = opt.best_estimator_.predict(X_train)

print()

# Perform the inverse scaling transformations
#X_test_inv = Xscaler.inverse_transform(X_test)
#y_test_inv = Yscaler.inverse_transform(y_test.reshape(-1,1))
#y_train_inv = Yscaler.inverse_transform(y_train.reshape(-1,1))
#y_pred_inv = Yscaler.inverse_transform(y_pred.reshape(-1,1))
#y_pred_train_inv = Yscaler.inverse_transform(y_pred_train.reshape(-1,1))

#Mean absolute error
#print('MAE for train hform: %.4f' % mean_absolute_error(y_train_inv, y_pred_train_inv))
#print('MAE for test hform: %.4f' % mean_absolute_error(y_test_inv, y_pred_inv))
#print()


# Print the feature ranking

#importance = opt.best_estimator_.feature_importances_
#print("Feature ranking:")

#for x in range(len(importance)):
#	print(x,'\t',importance[x])

#ax.bar([x for x in range(len(importance))], importance)
#plt.show()

#for f in range(X.shape[1]):
#    print("%d. feature %d (%f)" % (f + 1, indices[f], importances[indices[f]]))

# Now read the data for the new compounds that are to be classified.
new_Mag_data = pd.read_csv('./3d5d-data-mag.csv', sep=',')
print('Shape= ', new_Mag_data.shape)

# Drop the same non-feature columns that were removed from the training table.
new_Mag_data.drop(labels=['id', 'n_metal'], axis=1, inplace=True)
print('Shape= ', new_Mag_data.shape)

print('features selected')

# Apply the scaler that was fitted on the training features.
# NOTE(review): assumes this CSV holds the same feature columns, in the same
# order, as the training table -- confirm against the data files.
X_test = new_Mag_data.to_numpy()
X_test = Xscaler.transform(X_test)

y_pred = opt.predict(X_test)

# Collect the row indices predicted as class '1' (magnetic).
index_mag = []
for i, label in enumerate(y_pred):
    if label == '1':
        print(i, ' is magnetic')
        index_mag.append(i)
cnt_high = len(index_mag)
print()
print('Number of predicted magnetic compounds ', cnt_high)

