/*
 * Copyright 2019 NXP
 * All rights reserved.
 *
 * SPDX-License-Identifier: BSD-3-Clause
 */

#include <mnist_lock.h>

#include "board.h"
#include "pin_mux.h"
#include "clock_config.h"

#include <cstdlib>
#include <iostream>
#include <string>
#include <vector>

#include "image.h"

#include "ewmain.h"
#include "ewrte.h"
#include "ew_bsp_system.h"
#include "ew_bsp_serial.h"

/* LOG(x) ignores its argument and always streams to std::cout; the
   LOG(INFO) call sites presumably mirror the TensorFlow Lite logging style
   this demo was ported from. */
#define LOG(x) std::cout
/* When true, processImage() dumps the input bitmap at each preprocessing stage. */
#define PRINT_INPUT false
/* When true, run_nn() prints the predicted label and its confidence. */
#define PRINT_CONFIDENCE false
#include "arm_math.h"
#include "parameter.h"
#include "weights.h"
#include "CMSIS/NN/Include/arm_nnfunctions.h"
/* Per-element mean that mean_subtract() removes from the input image
   before it is fed to the network. */
static const uint8_t mean[DATA_OUT_CH*DATA_OUT_DIM*DATA_OUT_DIM] = MEAN_DATA;

/* Quantized (q7) weights and biases of the three convolution layers and the
   fully connected layer. The initializer macros come from the included model
   headers (presumably weights.h / parameter.h). */
static const q7_t conv1_wt[CONV1_IN_CH*CONV1_KER_DIM*CONV1_KER_DIM*CONV1_OUT_CH] = CONV1_WT;
static const q7_t conv1_bias[CONV1_OUT_CH] = CONV1_BIAS;

static const q7_t conv2_wt[CONV2_IN_CH*CONV2_KER_DIM*CONV2_KER_DIM*CONV2_OUT_CH] = CONV2_WT;
static const q7_t conv2_bias[CONV2_OUT_CH] = CONV2_BIAS;

static const q7_t conv3_wt[CONV3_IN_CH*CONV3_KER_DIM*CONV3_KER_DIM*CONV3_OUT_CH] = CONV3_WT;
static const q7_t conv3_bias[CONV3_OUT_CH] = CONV3_BIAS;

static const q7_t ip1_wt[IP1_IN_DIM*IP1_OUT_DIM] = IP1_WT;
static const q7_t ip1_bias[IP1_OUT_DIM] = IP1_BIAS;

/* The following globals are not static (external linkage); other translation
   units may reference them - NOTE(review): confirm before changing types. */
/* Human-readable class names, indexed by the predicted class in run_nn(). */
const char* labels[] = {"0", "1", "2", "3", "4", "5", "6", "7", "8", "9"};
//uint8_t image_data[DATA_OUT_CH*DATA_OUT_DIM*DATA_OUT_DIM]=DIGIT_IMG_DATA;
/* Output of the fully connected layer; holds the softmax scores after run_nn(). */
q7_t output_data[IP1_OUT_DIM];

/* Shared im2col / scratch area handed to every CMSIS-NN call in run_nn().
   NOTE(review): its required size depends on the layer shapes; confirm it is
   large enough when the model changes. */
q7_t col_buffer[5000];
/* Activation memory; run_nn() splits it into two ping-pong buffers. */
q7_t scratch_buffer[27200];

void mean_subtract(uint8_t* image_data) {
  for(int i=0; i<DATA_OUT_CH*DATA_OUT_DIM*DATA_OUT_DIM; i++) {
    image_data[i] = (q7_t)__SSAT( ((int)(image_data[i] - mean[i]) >> DATA_RSHIFT), 8);
  }
}

/*!
 * @brief Classify one image with the quantized CMSIS-NN CNN.
 *
 * Layers: conv1 -> relu -> maxpool -> conv2 -> relu -> maxpool -> conv3 ->
 * relu -> maxpool -> fully connected -> softmax.
 *
 * @param input_data DATA_OUT_CH*DATA_OUT_DIM*DATA_OUT_DIM input bytes. The
 *                   buffer is mean-subtracted IN PLACE and then reinterpreted
 *                   as q7_t.
 *
 * @return int[2] allocated with malloc() (the caller owns it and must free it):
 *         [0] confidence in percent, [1] index of the winning class (see
 *         labels); NULL if the allocation fails.
 */
int *run_nn(uint8_t* input_data) {
  /* Two ping-pong activation buffers carved out of scratch_buffer.
     NOTE(review): 15680 presumably equals the largest layer output
     (e.g. 28*28*20); confirm against parameter.h if the model changes. */
  q7_t* buffer1 = scratch_buffer;
  q7_t* buffer2 = buffer1 + 15680;
  mean_subtract(input_data);
  arm_convolve_HWC_q7_basic((q7_t*)input_data, CONV1_IN_DIM, CONV1_IN_CH, conv1_wt, CONV1_OUT_CH, CONV1_KER_DIM, CONV1_PAD, CONV1_STRIDE, conv1_bias, CONV1_BIAS_LSHIFT, CONV1_OUT_RSHIFT, buffer1, CONV1_OUT_DIM, (q15_t*)col_buffer, NULL);
  arm_relu_q7(buffer1, RELU1_OUT_DIM*RELU1_OUT_DIM*RELU1_OUT_CH);
  arm_maxpool_q7_HWC(buffer1, POOL1_IN_DIM, POOL1_IN_CH, POOL1_KER_DIM, POOL1_PAD, POOL1_STRIDE, POOL1_OUT_DIM, col_buffer, buffer2);
  arm_convolve_HWC_q7_fast(buffer2, CONV2_IN_DIM, CONV2_IN_CH, conv2_wt, CONV2_OUT_CH, CONV2_KER_DIM, CONV2_PAD, CONV2_STRIDE, conv2_bias, CONV2_BIAS_LSHIFT, CONV2_OUT_RSHIFT, buffer1, CONV2_OUT_DIM, (q15_t*)col_buffer, NULL);
  arm_relu_q7(buffer1, RELU2_OUT_DIM*RELU2_OUT_DIM*RELU2_OUT_CH);
  arm_maxpool_q7_HWC(buffer1, POOL2_IN_DIM, POOL2_IN_CH, POOL2_KER_DIM, POOL2_PAD, POOL2_STRIDE, POOL2_OUT_DIM, col_buffer, buffer2);
  arm_convolve_HWC_q7_fast(buffer2, CONV3_IN_DIM, CONV3_IN_CH, conv3_wt, CONV3_OUT_CH, CONV3_KER_DIM, CONV3_PAD, CONV3_STRIDE, conv3_bias, CONV3_BIAS_LSHIFT, CONV3_OUT_RSHIFT, buffer1, CONV3_OUT_DIM, (q15_t*)col_buffer, NULL);
  arm_relu_q7(buffer1, RELU3_OUT_DIM*RELU3_OUT_DIM*RELU3_OUT_CH);
  arm_maxpool_q7_HWC(buffer1, POOL3_IN_DIM, POOL3_IN_CH, POOL3_KER_DIM, POOL3_PAD, POOL3_STRIDE, POOL3_OUT_DIM, col_buffer, buffer2);
  arm_fully_connected_q7_opt(buffer2, ip1_wt, IP1_IN_DIM, IP1_OUT_DIM, IP1_BIAS_LSHIFT, IP1_OUT_RSHIFT, ip1_bias, output_data, (q15_t*)col_buffer);
  /* Use the layer width instead of a literal so softmax/argmax always cover
     exactly the elements of output_data. */
  arm_softmax_q7(output_data, IP1_OUT_DIM, output_data);

  /* Get the object class with the highest confidence value */
  q7_t max_value;
  uint32_t max_index;
  arm_max_q7(output_data, IP1_OUT_DIM, &max_value, &max_index);

  int *result_array = (int*) malloc(sizeof(int) * 2);
  if (result_array == NULL)
    return NULL;  /* out of memory: report failure instead of writing through NULL */

  /* Map the q7 score linearly onto 0..100 %.
     NOTE(review): this assumes the softmax output spans the full q7 range
     [-128, 127]; arm_softmax_q7 appears to produce only non-negative values,
     in which case the reported confidence never drops below 50 % - confirm. */
  result_array[0] = (int)((((int)max_value + 128) * 100) / 255);
  result_array[1] = (int)max_index;

  if (PRINT_CONFIDENCE)
    LOG(INFO) << "  " << labels[max_index] << " (" << result_array[0] << "% confidence)\r\n";
  return result_array;
}

/*!
 * @brief Application entry point: set up the board and the debug serial
 *        port, then drive the Embedded Wizard GUI until it stops.
 *
 * @return 0 in every case, including when EwInit() reports failure.
 */
int main(void)
{
	EwBspConfigSystem();	/* system initialization */
	EwBspConfigSerial();	/* serial interface used for debug messages */

	if ( EwInit() != 0 )
	{
		EwPrintSystemInfo();

		/* Embedded Wizard main loop: keep processing while EwProcess() is non-zero. */
		while ( EwProcess() )
		{
		}

		/* tear down the Embedded Wizard application */
		EwDone();
	}

	return 0;
}

/*!
 * @brief Print a single-channel image as a 0/1 bitmap (0 for pixels equal to
 *        0, 1 for anything else), followed by a blank-line separator.
 *
 * Only used for debugging when PRINT_INPUT is enabled.
 */
static void dumpBitmap(const char *title, const uint8_t *pixels, int width, int height)
{
	LOG(INFO) << title;
	for (int y = 0; y < height; y++)
	{
		for (int x = 0; x < width; x++)
			LOG(INFO) << ((pixels[x + y * width] == 0) ? "0" : "1");
		LOG(INFO) << "\n\r";
	}
	LOG(INFO) << "\n\n\n\r";
	std::flush(std::cout);
}

/*!
 * @brief Find the bounding box of all pixels equal to 0xFF in a row-major image.
 *
 * @return false if no such pixel exists (outputs are then meaningless);
 *         otherwise true with *minX / *minY inclusive and *maxX / *maxY
 *         exclusive (one past the last drawn column / row).
 */
static bool findDrawingBounds(const uint8_t *src, int width, int height,
                              int *minX, int *minY, int *maxX, int *maxY)
{
	bool found = false;
	*minX = width;
	*minY = height;
	*maxX = 0;
	*maxY = 0;

	for (int h = 0; h < height; h++)
	{
		for (int w = 0; w < width; w++)
		{
			if (src[w + h * width] != 0xFF)
				continue;
			found = true;
			if (w < *minX) *minX = w;
			if (h < *minY) *minY = h;
			if (w + 1 > *maxX) *maxX = w + 1;
			if (h + 1 > *maxY) *maxY = h + 1;
		}
	}
	return found;
}

/*!
 * @brief Copy the [minX, maxX) x [minY, maxY) region of a row-major image into
 *        a new zero-filled buffer, surrounded by an empty border of `border`
 *        pixels on every side.
 *
 * @return buffer allocated with calloc() (caller frees it), or NULL if out of
 *         memory. Its dimensions are written to *out_width / *out_height.
 */
static uint8_t *cropWithBorder(const uint8_t *src, int image_width,
                               int minX, int minY, int maxX, int maxY,
                               int border, int *out_width, int *out_height)
{
	const int width = maxX - minX + 2 * border;
	const int height = maxY - minY + 2 * border;
	uint8_t *dst = (uint8_t*) calloc((size_t)width * (size_t)height, sizeof(uint8_t));
	if (dst == NULL)
		return NULL;

	for (int h = minY; h < maxY; h++)
	{
		const uint8_t *src_row = src + h * image_width;
		uint8_t *dst_row = dst + (h - minY + border) * width + border;
		for (int w = minX; w < maxX; w++)
			dst_row[w - minX] = src_row[w];
	}

	*out_width = width;
	*out_height = height;
	return dst;
}

/*!
 * @brief Thicken the strokes of a side x side image to improve recognition.
 *
 * `scaled` is copied into `out`; then for every pixel equal to 0xFF in
 * `scaled`, its upper, left and right neighbours (when inside the image) are
 * also set to 0xFF in `out`:
 *
 *              x
 *            x 1 x
 *
 * Neighbours are decided from `scaled` only, so the thickening does not
 * cascade, and the copy is finished before any neighbour is written so no
 * added pixel is overwritten afterwards.
 */
static void thickenStrokes(const uint8_t *scaled, uint8_t *out, int side)
{
	for (int i = 0; i < side * side; i++)
		out[i] = scaled[i];

	for (int h = 0; h < side; h++)
	{
		for (int w = 0; w < side; w++)
		{
			if (scaled[w + h * side] != 0xFF)
				continue;
			if (h > 0)
				out[w + (h - 1) * side] = 0xFF;
			if (w > 0)
				out[(w - 1) + h * side] = 0xFF;
			if (w + 1 < side)
				out[(w + 1) + h * side] = 0xFF;
		}
	}
}

/*!
 * @brief Preprocess a hand-drawn digit and classify it with run_nn().
 *
 * Steps: locate the pixels equal to 0xFF, crop to their bounding box plus an
 * empty border, scale to 28x28 with ImCreate()/ImScale(), thicken the strokes
 * and run the network.
 *
 * @param src          image_width * image_height single-channel pixels, row-major
 * @param image_width  image width in pixels
 * @param image_height image height in pixels
 *
 * @return the int[2] produced by run_nn() ([0] confidence in percent,
 *         [1] predicted class), allocated with malloc() and owned by the
 *         caller; NULL on invalid arguments, an empty drawing, or allocation
 *         failure.
 */
int *processImage( uint8_t *src, int image_width, int image_height )
{
	/* Side length of the square network input (MNIST pictures are 28x28). */
	const int kInputSide = 28;
	/* Empty border added around the drawing so that its layout is more
	   similar to the MNIST dataset pictures. */
	const int kBorderThickness = 12; // in pixels

	if (src == NULL || image_width <= 0 || image_height <= 0)
		return NULL;

	int minX = 0, minY = 0, maxX = 0, maxY = 0;
	if (!findDrawingBounds(src, image_width, image_height, &minX, &minY, &maxX, &maxY))
	{
		/* Nothing to crop; without this check the crop size would be negative. */
		LOG(INFO) << "Nothing drawn!\n\r";
		return NULL;
	}

	int cropped_width = 0, cropped_height = 0;
	uint8_t *cropped = cropWithBorder(src, image_width, minX, minY, maxX, maxY,
	                                  kBorderThickness, &cropped_width, &cropped_height);
	if (cropped == NULL)
	{
		LOG(INFO) << "Out of memory!\n\r";
		return NULL;
	}

	/* Fully initialize the descriptor before handing it to ImCreate(). */
	Image pre_sca = {};
	pre_sca.channels = 1;
	pre_sca.width = cropped_width;
	pre_sca.height = cropped_height;
	pre_sca.imageData = cropped;

	const double dx = 1.0 * kInputSide / pre_sca.width;
	const double dy = 1.0 * kInputSide / pre_sca.height;

	Image *img = ImCreate(&pre_sca, dx, dy);
	if (img == NULL)
	{
		free(cropped);
		LOG(INFO) << "Out of memory!\n\r";
		return NULL;
	}

	if (PRINT_INPUT)
	{
		dumpBitmap("Original:\n\r", src, image_width, image_height);
		dumpBitmap("Cropped:\n\r", cropped, cropped_width, cropped_height);
	}

	/* NOTE(review): assumes ImScale() fills and returns the image created by
	   ImCreate(); confirm ownership if it can return a different object. */
	img = ImScale(&pre_sca, img, dx, dy);
	free(cropped);	/* the cropped source is no longer needed after scaling */
	if (img == NULL)
	{
		LOG(INFO) << "Out of memory!\n\r";
		return NULL;
	}

	if (PRINT_INPUT)
		dumpBitmap("Scaled:\n\r", img->imageData, kInputSide, kInputSide);

	uint8_t inputImage[kInputSide * kInputSide];
	thickenStrokes(img->imageData, inputImage, kInputSide);

	free(img->imageData);
	free(img);

	return run_nn(inputImage);
}
