# Copyright 2021 by NXP.  All Rights Reserved.
#
# This code is provided as part of DeepView Creator and eIQ Portal for
# demonstration purposes.  Redistribution outside of DeepView Creator, DeepView
# ModelPacks, or eIQ Portal is strictly prohibited.
from deepview.trainer.extensions import interfaces
import tensorflow as tf
import os
import tensorflow.keras.backend as backend
from tensorflow.python.keras.applications import imagenet_utils

def get_plugin():
    """
    Entry point the trainer calls to discover this plugin.

    Returns the class that implements the plugin (the class object itself,
    not an instance). In this sample that class is ``cifar10``.
    """
    plugin_class = cifar10
    return plugin_class

def preprocess_input(x, data_format=None):
    """
    Preprocess a batch of input images before they reach the model.

    Delegates to Keras' ``imagenet_utils.preprocess_input`` using the ``'tf'``
    mode (per the Keras documentation, this mode scales pixels to [-1, 1]).
    The function is meant to be handed to a tf.data pipeline so that images
    are preprocessed and prefetched internally.

    - x: the image data to preprocess
    - data_format: forwarded unchanged to Keras; ``None`` lets Keras pick its
      configured default
    """
    scaling_mode = 'tf'
    return imagenet_utils.preprocess_input(
        x,
        data_format=data_format,
        mode=scaling_mode,
    )

class cifar10(interfaces.ImageClassificationInterface):
    """
    Sample image-classification plugin that builds a small CIFAR-10 style
    convolutional network and exposes it to the eIQ Portal.
    """

    def get_name(self):
        """
        Return the model name shown in the eIQ Portal.
        """
        return "cifar10"

    def is_base(self):
        """
        Return whether the model belongs to the eIQ Portal core plugins.
        This sample reports True.
        """
        return True

    def get_model(self, input_shape, num_classes, weights, named_params=None):
        """
        Build the classification model and optionally load weights into it.

        - input_shape: shape passed to the Keras ``Input`` layer. The original
          sample describes it as a (W, H, 3) tuple.
          NOTE(review): Keras' default layout is (height, width, channels);
          confirm which order the trainer actually passes.
        - num_classes: number of classes of the problem; it is the number of
          units of the final Dense layer.
        - weights: weights are loaded only when this value is the path of an
          existing file. Any other value (``None``, an empty string, or a
          label that is not a file on disk) leaves the model with its freshly
          initialised weights.
        - named_params: dictionary of the values chosen in the GUI for the
          entries returned by ``get_exposed_parameters``. This architecture
          has fixed layer widths, so no entry is applied here; in particular
          the exposed 'alpha' value does not change the network.

        Returns a ``tf.keras.Sequential`` model made of three
        Conv2D -> BatchNormalization -> ReLU -> pooling stages followed by
        Flatten -> Dropout -> Dense(num_classes) -> BatchNormalization ->
        softmax.
        """
        drop = 0.25

        # One entry per convolutional stage: (filters, kernel size, pooling layer).
        stages = [
            (32, (5, 5), tf.keras.layers.MaxPooling2D),
            (32, (5, 5), tf.keras.layers.MaxPooling2D),
            (64, (3, 3), tf.keras.layers.AveragePooling2D),
        ]

        model = tf.keras.Sequential()
        model.add(tf.keras.layers.Input(shape=input_shape))

        # Layers are created in the same order as the hand-unrolled original so
        # the auto-generated Keras layer names (conv2d, conv2d_1, ...) stay the
        # same; load_weights(by_name=True) below matches layers by name.
        for index, (filters, kernel_size, pooling) in enumerate(stages):
            # Dropout sits between stages, never in front of the first one.
            if index > 0 and drop > 0:
                model.add(tf.keras.layers.Dropout(drop))
            model.add(tf.keras.layers.Conv2D(
                filters=filters,
                kernel_size=kernel_size,
                padding='same',
                strides=(1, 1),
            ))
            model.add(tf.keras.layers.BatchNormalization())
            model.add(tf.keras.layers.ReLU())
            model.add(pooling(pool_size=(3, 3), strides=(2, 2), padding='same'))

        # Classification head.
        model.add(tf.keras.layers.Flatten())
        model.add(tf.keras.layers.Dropout(drop))
        model.add(tf.keras.layers.Dense(num_classes))
        model.add(tf.keras.layers.BatchNormalization())
        model.add(tf.keras.layers.Activation(activation='softmax'))

        # os.path.exists raises TypeError on None, so check for a usable value
        # first; non-path labels simply fail the existence test and are skipped.
        if weights and os.path.exists(weights):
            model.load_weights(weights, by_name=True, skip_mismatch=True)

        return model

    def get_task(self):
        """
        Return the task this plugin solves ("classification"); it is used to
        organise the plugin system by task.
        """
        return "classification"

    def get_metadata(self, base_path):
        """
        Return additional metadata to be used as constants while quantizing
        or converting the model.

        - base_path: the model path (checkpoint). It is not read by this
          sample.

        Only the normalization is reported, as "signed". This tells the
        converter that the model needs a signed normalization of the samples
        during calibration in the quantization process.
        """
        return {
            "constants": {},
            "params": {
                "normalization": "signed",
            },
        }

    def get_exposed_parameters(self):
        """
        Return the list of parameters the GUI uses to configure itself
        dynamically.

        Every object has the same keys and only the values change. This sample
        exposes two parameters, alpha and optimizer, but you can include
        anything you need to configure the model; the GUI reads these
        properties and configures itself accordingly.

        Remember that the chosen values reach ``get_model`` through its
        ``named_params`` argument and must be parsed there. In this sample
        ``get_model`` builds a fixed-width network and does not apply 'alpha'.
        """
        return [
            {
                "name": "Alpha",
                "key": "alpha",
                "default": "0.35",
                "values": ["0.35", "0.50", "0.75", "1.00", "1.30"],
                "description": "Controls the width of the network",
            },
            {
                "name": "Optimizer",
                "key": "optimizer",
                "default": "Adam",
                "values": ["SGD", "Adam", "RMSprop", "Nadam", "Adadelta", "Adagrad", "Adamax"],
                "description": "Model optimizer",
            },
        ]

    def get_preprocess_function(self):
        """
        Return the function used to preprocess input images.

        The returned function is passed to a tf.data pipeline, which uses it
        to preprocess and prefetch images internally.
        """
        return preprocess_input

    def get_losses(self):
        """
        Return the names of the loss functions offered for this model.

        The dataset iterators return classes in a categorical way, so the loss
        to use is "CategoricalCrossentropy". New loss functions can also be
        created and exposed; that is covered in further guides.
        To see all the supported losses visit:
        GET http://127.0.0.1:10814/v2/losses
        """
        return ["CategoricalCrossentropy"]

    def get_optimizers(self):
        """
        Helper method that returns the default optimizer.

        To see all the supported optimizers visit:
        GET http://127.0.0.1:10814/v2/optimizers
        """
        return ["Adam"]

    def get_allowed_dimensions(self):
        """
        Return the input dimensions the GUI may offer for this model.

        NOTE(review): the original sample states this tells the GUI that the
        model supports any input dimension larger than or equal to 32; how the
        returned pair is interpreted is defined by the trainer interface and
        should be confirmed there.
        """
        return ["32", "32"]

    def get_pretrained_dimensions(self):
        """
        Return the dimensions that have pretrained weights, together with the
        source of those weights.

        This sample declares 'cifar10' weights for dimension 32, but anything
        can be set here. Whatever is set here can arrive as the ``weights``
        argument of ``get_model``, so it must be handled properly there.
        """
        return [["32", "cifar10"]]

    def get_qat_support(self):
        """
        Tell the GUI whether the model supports Quantization Aware Training.

        When QAT is supported, the input/output types and the framework where
        the per-channel or per-tensor quantization is performed are provided
        too. TensorFlow only provides per-channel quantization, but the
        converter can add per-tensor quantization and create very accurate
        approximations of it.

        Two objects are returned. The first one describes per-channel
        quantization (more accurate, but unsupported by some edge devices) and
        the second one per-tensor quantization (less accurate, but faster).

        Properties of each object:
        - supported: "true" or "false".
        - types: list of the supported types
        - frameworks: list of the frameworks where the per-channel or
                      per-tensor quantization is performed. For the moment
                      only two are supported: TensorFlow and Converter

        NOTE: Per-tensor quantization is only supported by the Converter, for
        the moment
        """
        return [{
            # Per-Channel Quantization
            "supported": "true",
            "types": ["uint8", "int8", "float32"],
            "frameworks": ["Tensorflow", "Converter"],
        }, {
            # Per-Tensor Quantization
            "supported": "true",
            "types": ["uint8", "int8", "float32"],
            "frameworks": ["Converter"],
        }]

    def get_ptq_support(self):
        """
        Tell the GUI whether the model supports Post Training Quantization.

        When PTQ is supported, the input/output types and the framework where
        the per-channel or per-tensor quantization is performed are provided
        too. TensorFlow only provides per-channel quantization, but the
        converter can add per-tensor quantization and create very accurate
        approximations of it.

        Two objects are returned. The first one describes per-channel
        quantization (more accurate, but unsupported by some edge devices) and
        the second one per-tensor quantization (less accurate, but faster).

        Properties of each object:
        - supported: "true" or "false".
        - types: list of the supported types
        - frameworks: list of the frameworks where the per-channel or
                      per-tensor quantization is performed. For the moment
                      only two are supported: TensorFlow and Converter

        NOTE: Per-tensor quantization is only supported by the Converter, for
        the moment
        """
        return [{
            # Per-Channel Quantization
            "supported": "true",
            "types": ["uint8", "int8", "float32"],
            "frameworks": ["Tensorflow", "Converter"],
        }, {
            # Per-Tensor Quantization
            "supported": "true",
            "types": ["uint8", "int8", "float32"],
            "frameworks": ["Converter"],
        }]

