# -*- coding: utf-8 -*-
"""
Created on Wed Jun 20 22:24:22 2018

@author: Ben+Fab
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function



import sys
sys.path.append(r"C:\Users\fg302-admin\Desktop\2019-06-28-double_Poisson")

import os.path
from pymba import Vimba
import PyDAQmx
from PyDAQmx import Task
import numpy as np
import tensorflow as tf
import time

import imageio
# Directory holding the saved TensorFlow checkpoint (meta graph + weights).
model_dir = (r'C:\Users\fg302-admin\Desktop\2019-06-28-double_Poisson'
             r'\new-model_18_07_2019')

# Report the TF version, then start from an empty default graph before the
# saved meta graph is imported into it.
print(tf.__version__)
tf.reset_default_graph()
# Number of triggered frames processed before the script exits.
NFRAMES = 25000

# Camera region of interest: size and offset on the sensor, in pixels.
HEIGHT, WIDTH = 480, 480
OffsetY, OffsetX = 0, 150

# Centre of the droplet inside the region of interest, in pixels.
x_droplet, y_droplet = 240, 240

# Values written to camera0.Gain and camera0.ExposureTime.
Gain = 15
Exposure_Time = 100

# Load the trained model: rebuild its graph from the .meta file, then restore
# the weights into a fresh session.
# os.path.join avoids the doubled backslash the old r'\\...' literal produced.
new_saver = tf.train.import_meta_graph(
    os.path.join(model_dir, 'model.ckpt-40000.meta'))

sess = tf.Session()

# NOTE(review): the graph is pinned to step 40000, but the weights are taken
# from the newest checkpoint in model_dir -- confirm the two always match.
checkpoint_path = tf.train.latest_checkpoint(model_dir)
if checkpoint_path is None:
    # Fail early with a clear message rather than passing None to restore().
    raise OSError('No TensorFlow checkpoint found in ' + model_dir)
new_saver.restore(sess, checkpoint_path)

graph = tf.get_default_graph()

# Output tensor of the trained model (per-class probabilities).
softmax_tensor = sess.graph.get_tensor_by_name('softmax_tensor:0')

# NOTE(review): these appear unused in this script.
prediction = []
classifier = []
# Circular mask for the square droplet crop: True inside a circle of radius r
# centred at row a / column b of an n x n grid, False in the corners.
a = b = 239  # centre row / column
n = 478      # crop side length in pixels
r = 239      # circle radius in pixels
y, x = np.ogrid[-a:n - a, -b:n - b]
mask = (x ** 2 + y ** 2) <= r ** 2
def _preprocess_frame(frame_f32):
    """Crop the droplet region, apply the circular mask and standardise it.

    Cuts a square the size of the module-level ``mask`` out of the float32
    frame, centred on (``x_droplet``, ``y_droplet``). Pixels outside the
    circular mask are set to 0 in place -- the crop is a view, so this also
    modifies *frame_f32*.

    Returns ``(crop_img, net_input)``: the masked crop (saved to disk for
    sorted droplets) and its zero-mean / unit-variance copy reshaped to
    ``(1, h, w, 1)`` for the network's ``tf_reshape1:0`` input.
    """
    # Tie the crop size to the mask so the two can never disagree.
    h, w = mask.shape
    top = y_droplet - h // 2
    left = x_droplet - w // 2
    crop_img = frame_f32[top:top + h, left:left + w]
    crop_img[~mask] = 0

    net_input = crop_img - np.mean(crop_img)
    std = np.std(net_input)
    if std > 0:  # a uniform crop would otherwise give 0/0 = NaN inputs
        net_input = net_input / std
    return crop_img, net_input.reshape(1, h, w, 1)


def _sort_droplets(frame0, task, value):
    """Classify NFRAMES triggered frames and pulse the analog output for hits.

    For each frame: copy it out of the Vimba buffer, re-queue the buffer,
    run the classifier and, if processing took < 15 ms and class 1 was
    predicted with probability > 0.6, drive the analog output to *value*
    volts for 15 ms and save the masked crop as a PNG.
    """
    for count in range(NFRAMES):
        # 2**31 ms (~24.8 days): effectively block until the next trigger.
        frame0.waitFrameCapture(timeout=2147483648)
        # dt now also includes the buffer copy and re-queue below.
        t0 = time.time()

        # Copy the pixels BEFORE re-queueing: once queued, the driver may refill
        # this buffer on the next trigger while it is still being read.
        # np.ndarray(buffer=...) wraps the buffer without copying; astype copies.
        raw = np.ndarray(buffer=frame0.getBufferByteData(),
                         dtype=np.uint8,
                         shape=(frame0.height, frame0.width))
        frame_f32 = raw.astype(np.float32)
        frame0.queueFrameCapture()

        crop_img, net_input = _preprocess_frame(frame_f32)

        # Class probabilities; keep the most probable class and its probability.
        predictions = sess.run(softmax_tensor, {'tf_reshape1:0': net_input})
        pred_result = np.argmax(predictions[0])
        pred_probs = np.max(predictions[0])

        dt = time.time() - t0
        print(str(np.round(dt, 5)))

        # Only send a pulse if the timing is right and the prediction is confident.
        if dt < 0.015 and pred_result == 1 and pred_probs > 0.6:
            try:
                # Args: autoStart, timeout (s), value (V), reserved.
                # Takes roughly 1 ms to execute.
                task.WriteAnalogScalarF64(1, 5.0, value, None)
                time.sleep(0.015)  # 15 ms pulse
            finally:
                # Always return the line to 0 V, even if interrupted mid-pulse.
                task.WriteAnalogScalarF64(1, 5.0, 0.0, None)
            # Save the sorted droplet image.
            imageio.imsave(
                os.path.join(r'Image Collection\Live',
                             'sorted{} '.format(count)
                             + str(round(pred_probs, 2)) + '.png'),
                crop_img.astype('uint8'))


# Configure the camera and DAQ, run the sorting loop, and always release the
# DAQ task and camera afterwards (even on error or Ctrl+C).
with Vimba() as vimba:
    system = vimba.getSystem()

    # GigE cameras are only listed after an explicit discovery pass.
    if system.GeVTLIsPresent:
        system.runFeatureCommand("GeVDiscoveryAllOnce")
        time.sleep(0.2)
    # Listed unconditionally: cameraIds used to be bound only on the GigE
    # path, so any other transport failed with a NameError below.
    cameraIds = vimba.getCameraIds()
    for cameraId in cameraIds:
        print('Camera ID:', cameraId)
    if not cameraIds:
        raise RuntimeError('No Vimba camera found')

    camera0 = vimba.getCamera(cameraIds[0])
    camera0.openCamera()
    try:
        # Region of interest, gain and exposure.
        camera0.Height = HEIGHT
        camera0.Width = WIDTH
        camera0.OffsetY = OffsetY
        camera0.OffsetX = OffsetX
        camera0.Gain = Gain
        camera0.ExposureTime = Exposure_Time

        # Triggered acquisition on a rising edge.
        camera0.TriggerMode = 'On'
        camera0.TriggerActivation = 'RisingEdge'
        camera0.PixelFormat = 'Mono8'
        camera0.AcquisitionMode = 'Continuous'
        camera0.ExposureMode = 'Timed'
        camera0.LineSelector = 'Line0'

        print('TriggerMode:', camera0.TriggerMode)
        print('TriggerSource:', camera0.TriggerSource)
        print('TriggerActivation:', camera0.TriggerActivation)
        print('TriggerSelector:', camera0.TriggerSelector)

        # Raw register writes that set the trigger pins (see camera manual).
        camera0.writeRegister('0xF0F00614', '00000000')
        camera0.writeRegister('0xF0F00830', '82000000')
        camera0.writeRegister('0xF0F0061C', '80000001')

        print('TriggerSource:', camera0.TriggerSource)

        # A single frame buffer, re-queued after every capture.
        frame0 = camera0.getFrame()
        frame0.announceFrame()
        camera0.startCapture()
        try:
            camera0.runFeatureCommand('AcquisitionStart')
            frame0.queueFrameCapture()
            print("now entering the loop!")

            # Analog output /Dev1/ao0, 0-5 V range, used for the sorting pulse.
            task = Task()
            try:
                task.CreateAOVoltageChan("/Dev1/ao0", "", 0.0, 5.0,
                                         PyDAQmx.DAQmx_Val_Volts, None)
                task.StartTask()
                value = 5.0  # pulse amplitude in volts
                _sort_droplets(frame0, task, value)
            finally:
                # ClearTask also aborts the task if it is still running.
                task.ClearTask()
        finally:
            camera0.runFeatureCommand('AcquisitionStop')
            camera0.endCapture()
            camera0.flushCaptureQueue()
            camera0.revokeAllFrames()
    finally:
        camera0.closeCamera()