# -*- coding: utf-8 -*-
"""
Created on Thu Jun 21 23:38:07 2018

@author: Ben
"""

import tensorflow as tf


def bead_model_fn(features, labels, mode):
  """Model function for a 3-class CNN, for use with `tf.estimator.Estimator`.

  Architecture: three stages of [15x15 conv (ReLU, "same" padding) ->
  2x2 max-pool, stride 2] with 2, 4 and 8 filters, followed by a 128-unit
  ReLU dense layer, dropout (active only in TRAIN mode) and a 3-unit logits
  layer.

  Args:
    features: dict whose "x" entry is reshaped to [-1, 478, 478, 1], so each
      example must contain 478 * 478 values.
    labels: class indices for TRAIN/EVAL; `sparse_softmax_cross_entropy`
      requires integer values in [0, 3). Not used in PREDICT mode.
    mode: a `tf.estimator.ModeKeys` value.

  Returns:
    A `tf.estimator.EstimatorSpec`:
      * PREDICT: `predictions` with "classes" (argmax of the logits) and
        "probabilities" (softmax of the logits, op name "softmax_tensor").
      * TRAIN: `loss` and a plain gradient-descent `train_op`, plus a hook
        that logs the per-batch accuracy every 100 iterations.
      * EVAL: `loss` and an "accuracy" metric.
  """
  # Input Layer: single-channel 478x478 images.
  input_layer = tf.reshape(features["x"], [-1, 478, 478, 1], name='tf_reshape1')

  # Convolutional Layer #1 -> [batch, 478, 478, 2]
  conv1 = tf.layers.conv2d(
      inputs=input_layer,
      filters=2,
      kernel_size=[15, 15],
      padding="same",
      activation=tf.nn.relu)

  # Pooling Layer #1 -> [batch, 239, 239, 2]
  pool1 = tf.layers.max_pooling2d(inputs=conv1, pool_size=[2, 2], strides=2)

  # Convolutional Layer #2 -> [batch, 239, 239, 4]
  conv2 = tf.layers.conv2d(
      inputs=pool1,
      filters=4,
      kernel_size=[15, 15],
      padding="same",
      activation=tf.nn.relu)

  # Pooling Layer #2 -> [batch, 119, 119, 4] ("valid" pooling floors odd sizes)
  pool2 = tf.layers.max_pooling2d(inputs=conv2, pool_size=[2, 2], strides=2)

  # Convolutional Layer #3 -> [batch, 119, 119, 8]
  conv3 = tf.layers.conv2d(
      inputs=pool2,
      filters=8,
      kernel_size=[15, 15],
      padding="same",
      activation=tf.nn.relu)

  # Pooling Layer #3 -> [batch, 59, 59, 8]
  pool3 = tf.layers.max_pooling2d(inputs=conv3, pool_size=[2, 2], strides=2)

  # Flatten tensor into a batch of vectors. The per-example size is read from
  # the static shape (59 * 59 * 8 for 478x478 inputs) instead of being
  # hard-coded, so it stays correct if the input size or filter counts change.
  flat_size = pool3.get_shape()[1:].num_elements()
  pool3_flat = tf.reshape(pool3, [-1, flat_size])

  # Dense Layer
  dense = tf.layers.dense(inputs=pool3_flat, units=128, activation=tf.nn.relu)

  # Dropout, applied only in TRAIN mode. `rate` is the fraction of units that
  # are DROPPED, so rate=0.6 keeps each element with probability 0.4.
  # NOTE(review): if a keep probability of 0.6 was intended, this should be
  # rate=0.4 — confirm.
  dropout = tf.layers.dropout(
      inputs=dense, rate=0.6, training=mode == tf.estimator.ModeKeys.TRAIN)

  # Logits layer: one unit per class (3 classes).
  logits = tf.layers.dense(inputs=dropout, units=3)

  predictions = {
      # Generate predictions (for PREDICT and EVAL mode)
      "classes": tf.argmax(input=logits, axis=1),
      # Add `softmax_tensor` to the graph so it can be fetched by name
      # (e.g. for PREDICT or a logging hook).
      "probabilities": tf.nn.softmax(logits, name="softmax_tensor")
  }

  if mode == tf.estimator.ModeKeys.PREDICT:
    return tf.estimator.EstimatorSpec(mode=mode, predictions=predictions)

  # Calculate Loss (for both TRAIN and EVAL modes)
  loss = tf.losses.sparse_softmax_cross_entropy(labels=labels, logits=logits)

  # Configure the Training Op (for TRAIN mode)
  if mode == tf.estimator.ModeKeys.TRAIN:
    optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.001)  # 0.001 = Matlab
    train_op = optimizer.minimize(
        loss=loss,
        global_step=tf.train.get_global_step())
    # Per-batch accuracy for logging only. It is built here because only the
    # TRAIN hook uses it; labels are cast to int64 to match argmax's dtype.
    my_acc = tf.reduce_mean(tf.cast(
        tf.equal(tf.cast(labels, tf.int64), predictions['classes']),
        tf.float32))
    logging_hook = tf.train.LoggingTensorHook({"My accuracy": my_acc}, every_n_iter=100)
    return tf.estimator.EstimatorSpec(mode=mode, loss=loss, train_op=train_op, training_hooks=[logging_hook])

  # Add evaluation metrics (for EVAL mode)
  eval_metric_ops = {
      "accuracy": tf.metrics.accuracy(
          labels=labels, predictions=predictions["classes"])}
  return tf.estimator.EstimatorSpec(
      mode=mode, loss=loss, eval_metric_ops=eval_metric_ops)
  