/*
 * Copyright 2020-2021 NXP
 * All rights reserved.
 *
 * SPDX-License-Identifier: BSD-3-Clause
 */

#include "input_proc.h"

#include <stdlib.h>

#include "model.h"
#include "image_utils.h"
#include "output_postproc.h"
#include "ew_bsp_clock.h"
#include "ewrte.h"

/* Debug switch: when set to true, MODEL_PrepareInput dumps every intermediate
 * image (original, cropped, resized, rotated, converted tensor) through EwPrint
 * as a grid of '0' / '1' characters. */
#define PRINT_INPUT false

/* Dimensions and element type of the model input tensor. They are filled in by
 * MODEL_GetInputTensorData the first time MODEL_ClassifyInput runs. */
tensor_dims_t inputDims;
tensor_type_t inputType;
/* Dimensions and element type of the model output tensor. They are filled in by
 * MODEL_GetOutputTensorData the first time MODEL_ClassifyInput runs. */
tensor_dims_t outputDims;
tensor_type_t outputType;

/* Buffers of the model input and output tensors. Both stay nullptr until
 * MODEL_ClassifyInput obtains them; nullptr is used there as the
 * "not fetched yet" marker. */
uint8_t* inputData = nullptr;
uint8_t* outputData = nullptr;

/* This function walks the matrix in "layers", from the outermost layer inwards, and rotates each layer by 90 degrees.
 * To visualize what a layer means, the following matrix has each element filled with the index of its layer:
 *
 * 1 1 1 1 1
 * 1 2 2 2 1
 * 1 2 3 2 1
 * 1 2 2 2 1
 * 1 1 1 1 1
 *
 * Number of layers in a matrix by size:
 *
 * 2x2 = 1 layer
 * 3x3 = 2 layers (innermost has only one element)
 * 4x4 = 2 layers
 * 5x5 = 3 layers (innermost has only one element)
 * 6x6 = 3 layers
 * 7x7 = 4 layers (innermost has only one element)
 * ...
 *
 * If the innermost layer is composed of only one element, it is skipped (hence the W / 2 condition).
 *
 * For a given layer, the function first rotates the four corner elements, then moves on to the next element on each
 * side in a clockwise direction and rotates those, and so on until the whole layer has been rotated by 90 degrees.
 *
 * */
void MODEL_RotateMatrixClockwise90(uint8_t *mat, int W) {
	/* mat is a row-major W x W byte matrix that is rotated in place. */
	const int lastIdx = W - 1;

	for (int layer = 0; layer < W / 2; ++layer) {
		/* Row/column index of the far edge of the current layer. */
		const int farEdge = lastIdx - layer;

		for (int pos = layer; pos < farEdge; ++pos) {
			const int mirrored = lastIdx - pos;

			/* The four cells that trade places in one 4-way cycle. */
			uint8_t &topCell    = mat[(layer * W) + pos];
			uint8_t &leftCell   = mat[(mirrored * W) + layer];
			uint8_t &bottomCell = mat[(farEdge * W) + mirrored];
			uint8_t &rightCell  = mat[(pos * W) + farEdge];

			/* left -> top, bottom -> left, right -> bottom, top -> right. */
			const uint8_t savedTop = topCell;
			topCell = leftCell;
			leftCell = bottomCell;
			bottomCell = rightCell;
			rightCell = savedTop;
		}
	}
}

/* Debug helper: prints a width x height image with one byte per pixel as rows
 * of '0' (pixel value is zero) and '1' (anything else) characters. */
static void PrintBinaryGrid(const uint8_t *pixels, int width, int height) {
	int idx = 0;
	for (int h = 0; h < height; h++) {
		for (int w = 0; w < width; w++, idx++) {
			if (pixels[idx] == 0)
				EwPrint("0");
			else
				EwPrint("1");
		}
		EwPrint("\n\r");
	}
	EwPrint("\n\n\n\r");
}

/* Debug helper: prints the first width x height elements of the input tensor
 * (inputData, interpreted according to inputType) as rows of '0' / '1'. */
static void PrintInputTensorGrid(int width, int height) {
	int idx = 0;
	for (int h = 0; h < height; h++) {
		for (int w = 0; w < width; w++, idx++) {
			switch (inputType)
			{
				case kTensorType_UINT8:
					if (inputData[idx] == 0)
						EwPrint("0");
					else
						EwPrint("1");
					break;
				case kTensorType_INT8:
					if (reinterpret_cast<int8_t*>(inputData)[idx] == 0)
						EwPrint("0");
					else
						EwPrint("1");
					break;
				case kTensorType_FLOAT32:
					if (reinterpret_cast<float*>(inputData)[idx] == 0)
						EwPrint("0");
					else
						EwPrint("1");
					break;
				default:
					/* A bare string literal is always true, so it has to be
					 combined with `false` for the assertion to actually fire. */
					assert(false && "Unknown input tensor data type");
					break;
			}
		}
		EwPrint("\n\r");
	}
	EwPrint("\n\n\n\r");
}

/* Turns a drawing into the model input tensor.
 *
 * Steps: find the bounding box of all 0xFF pixels, crop to it, surround the crop
 * with an empty border, resize the result into inputData, rotate it clockwise by
 * 90 degrees and finally convert it with MODEL_ConvertInput.
 *
 * inputImage - image with one byte per pixel, srcWidth * srcHeight bytes, row-major.
 *              Only pixels equal to 0xFF count as part of the drawing.
 * srcWidth   - width of inputImage in pixels.
 * srcHeight  - height of inputImage in pixels.
 *
 * Returns 0 on success,
 *         1 if there is nothing to classify (null image, non-positive size or
 *           no 0xFF pixel),
 *         2 if the temporary crop buffer could not be allocated,
 *         3 if the input tensor buffer (inputData) has not been obtained yet.
 */
uint8_t MODEL_PrepareInput( uint8_t *inputImage, int srcWidth, int srcHeight ) {
	/* Side length used for the rotation and for the debug dumps.
	 NOTE(review): assumes the model input is a square 28x28 single-channel
	 image - confirm against inputDims. */
	const int modelInputSide = 28;

	/* Adding an empty border around the drawing to make the resulting picture more
	 similar in layout to the MNIST dataset pictures. */
	const int borderThickness = 12; // in pixels

	/* Without this guard a zero-sized image passes the "empty" test below
	 (minX == maxX == 0) and the crop loop then reads inputImage[0]. */
	if (inputImage == nullptr || srcWidth <= 0 || srcHeight <= 0) {
		EwPrint("Empty input.\n\r");
		return 1;
	}

	if (PRINT_INPUT) {
		EwPrint("Original:\n\r");
		PrintBinaryGrid(inputImage, srcWidth, srcHeight);
	}

	/* Bounding box of the drawing; starts out inverted so that it stays
	 inverted when no pixel is set. */
	int minY = srcHeight, minX = srcWidth;
	int maxY = 0, maxX = 0;

	for (int h = 0; h < srcHeight; h++) {
		const uint8_t *srcRow = inputImage + (h * srcWidth);
		for (int w = 0; w < srcWidth; w++) {
			if (srcRow[w] == 0xFF) {
				maxY = h; /* rows are visited in ascending order, so the last hit is the maximum */
				if (maxX < w) maxX = w;

				if (minY > h) minY = h;
				if (minX > w) minX = w;
			}
		}
	}

	if (minX > maxX) {
		/* Empty input. */
		EwPrint("Empty input.\n\r");
		return 1;
	}

	if (inputData == nullptr) {
		/* The resize below writes into inputData; refuse to write through null. */
		EwPrint("Input tensor not initialized.\n\r");
		return 3;
	}

	/* By incrementing the max values one more time, it is ensured that even the
	 last pixel is always included after the cropping. Otherwise it would be
	 necessary to use them as (maxX + 1) everywhere.*/
	maxX++;
	maxY++;

	const int croppedCenteredWidth = maxX - minX + (borderThickness * 2);
	const int croppedCenteredHeight = maxY - minY + (borderThickness * 2);

	/* calloc hands back zeroed memory, so the border is already empty. */
	uint8_t *croppedCenteredSrc = static_cast<uint8_t*>(
			calloc(static_cast<size_t>(croppedCenteredWidth) * croppedCenteredHeight, sizeof(uint8_t)));

	if (croppedCenteredSrc == nullptr) {
		EwPrint("Out of memory!\n\r");
		return 2;
	}

	/* Copy the bounding box into the middle of the buffer, leaving
	 borderThickness empty pixels on every side. */
	for (int h = minY; h < maxY; h++) {
		const uint8_t *srcRow = inputImage + (h * srcWidth);
		uint8_t *dstRow = croppedCenteredSrc
				+ ((h - minY + borderThickness) * croppedCenteredWidth)
				+ borderThickness;
		for (int w = minX; w < maxX; w++) {
			dstRow[w - minX] = srcRow[w];
		}
	}

	if (PRINT_INPUT) {
		EwPrint("Cropped:\n\r");
		PrintBinaryGrid(croppedCenteredSrc, croppedCenteredWidth, croppedCenteredHeight);
	}

	/* NOTE(review): inputDims.data[2] / [1] / [3] are presumably the tensor's
	 width / height / channel count - confirm against IMAGE_Resize. */
	IMAGE_Resize(croppedCenteredSrc, croppedCenteredWidth, croppedCenteredHeight,
	                 inputData, inputDims.data[2], inputDims.data[1], inputDims.data[3]);

	free(croppedCenteredSrc);

	if (PRINT_INPUT) {
		EwPrint("Resized:\n\r");
		PrintBinaryGrid(inputData, modelInputSide, modelInputSide);
	}

	MODEL_RotateMatrixClockwise90(inputData, modelInputSide);

	if (PRINT_INPUT) {
		EwPrint("Rotated:\n\r");
		PrintBinaryGrid(inputData, modelInputSide, modelInputSide);
	}

	MODEL_ConvertInput(inputData, &inputDims, inputType);

	if (PRINT_INPUT) {
		EwPrint("Input Tensor:\n\r");
		PrintInputTensorGrid(modelInputSide, modelInputSide);
	}

	return 0;
}

/* Classifies a drawing.
 *
 * image  - image with one byte per pixel, passed on to MODEL_PrepareInput.
 * width  - width of the image in pixels.
 * height - height of the image in pixels.
 *
 * Returns the value of MODEL_ProcessOutput for the inference result, or -1 if
 * the model tensors could not be obtained or the input could not be prepared.
 */
int MODEL_ClassifyInput( uint8_t* image, int width, int height) {
    /* The tensor buffers are looked up lazily on the first call. */
    if (inputData == nullptr || outputData == nullptr) {
        inputData = MODEL_GetInputTensorData(&inputDims, &inputType);
        outputData = MODEL_GetOutputTensorData(&outputDims, &outputType);

        /* If the lookup failed, stop here instead of letting the
           pre-processing and inference run on null buffers. */
        if (inputData == nullptr || outputData == nullptr) {
            return -1;
        }
    }

    /* Any non-zero status means there is nothing valid to run inference on. */
    const uint8_t status = MODEL_PrepareInput(image, width, height);
    if (status != 0) {
        return -1;
    }

    /* Only the inference itself is timed; the duration is passed on to
       MODEL_ProcessOutput. */
    const auto startTime = TIMER_GetTimeInUS();
    MODEL_RunInference();
    const auto endTime = TIMER_GetTimeInUS();

    return MODEL_ProcessOutput(outputData, &outputDims, outputType, endTime - startTime);
}



