#!/usr/bin/env python3

# Copyright 2023-2024 NXP
# SPDX-License-Identifier: Apache-2.0

import os
import csv
import time
import atexit
import argparse
import json
import sys

from pathlib import Path
from ctypes import c_char_p
from multiprocessing import Process, Queue, Manager

import requests
import numpy as np
import cv2

sys.path.insert(0, "base/runtimes/face/")
sys.path.insert(0, "base/runtimes/hand/")

from face_database import FaceDatabase
from face_detection import YoloFace
from face_recognition import Facenet
from hand_tracker import HandTracker

### Arguments ###

# Command-line interface. Model paths, the delegate and (unless -c is given)
# the capture device are filled in later from the selected platform's
# media/platforms/<platform>/platform.json.
PARSE = argparse.ArgumentParser(
    prog="i.MX Computer Vision Benchmark",
    description="Test the performance of compatible i.MX devices on simultaneous vision ML tasks.",
    epilog="NXP",
)

PARSE.add_argument(
    "-p",
    "--platform",
    type=str,
    default=None,
    help="""Target platform configuration to load.
    To see a list of available platforms, choose `-l`.""",
)

# Restricted to the values the rest of the script branches on; any other
# string would start no worker and leave the camera loop stalled.
PARSE.add_argument(
    "-d",
    "--demo",
    type=str,
    default="both",
    choices=("face", "hand", "both"),
    help="Choose which demo to run. Options: face, hand, both.",
)

PARSE.add_argument(
    "-s",
    "--save",
    type=Path,
    default=None,
    help="""The .csv file path must be specified for the
    benchmark results, otherwise the results will not be saved.""",
)

PARSE.add_argument(
    "-v",
    "--verbose",
    action="store_true",
    default=False,
    help="See detailed information about the runtime process.",
)

PARSE.add_argument(
    "-l",
    "--list-platforms",
    action="store_true",
    default=False,
    help="List all available platforms.",
)

# Kept as a string: the value may also come from platform.json and is
# converted with int() when the camera is opened.
PARSE.add_argument(
    "-c",
    "--capture_device",
    type=str,
    default=None,
    help="""Choose the camera device. To see the capture devices
    run v4l2-ctl --list-devices. Ex: For /dev/video2 we will have 2.""",
)

# Pixels added on each side of a detected face box (clamped to the frame)
# before the face is cropped for recognition.
PARSE.add_argument(
    "--face_model_padding", default=10, type=int, help="Used for face tracking."
)


def list_platforms():
    """Print the configuration of every available platform, then exit.

    Every entry of ``media/platforms`` that contains a ``platform.json`` is
    treated as a platform. Its JSON is pretty-printed under a header naming
    the platform. Platforms are listed in sorted order. Entries without a
    ``platform.json`` (stray files, incomplete directories) are skipped
    instead of aborting the listing. Empty JSON files are skipped.

    Raises:
        SystemExit: always, once the listing is printed.
    """
    print("The list of available platforms: ")
    for platform_dir in sorted(Path("media/platforms").iterdir()):
        config_path = platform_dir / "platform.json"
        if not config_path.is_file():
            continue
        content = config_path.read_text(encoding="utf-8")
        if not content:
            continue
        pretty = json.dumps(json.loads(content), indent=5)
        print(
            "-------------------------- Platform: "
            + platform_dir.name
            + " -----------------------------------"
        )
        print(pretty)
    raise SystemExit


def download_file(name, url, path, retry=3, timeout=60):
    """Download ``url`` into ``path``, retrying failed requests.

    Any existing file at ``path`` is removed first. A response is only
    accepted if its HTTP status is a success. Otherwise an error page
    (e.g. a 404) would be written out as the model file.

    Args:
        name: Model name, used only in the progress message.
        url: Address to fetch.
        path: Destination file; written in binary mode.
        retry: Total number of attempts before giving up.
        timeout: Seconds passed to ``requests.get`` so a stalled server
            cannot hang start-up forever.

    Raises:
        RuntimeError: if every attempt fails. The last request error is
            chained as the cause.
    """
    if os.path.exists(path):
        os.unlink(path)

    print("Downloading ", name, " model(s) file(s) from", url)
    last_error = None
    for attempt in range(retry):
        try:
            req = requests.get(url, timeout=timeout)
            req.raise_for_status()
            break
        except requests.RequestException as error:
            last_error = error
            if attempt + 1 < retry:
                print("Failed to download file from", url, "Retrying")
    else:
        # Every attempt failed: report it instead of falling through to an
        # unbound response object.
        raise RuntimeError(
            f"Failed to download {name} model from {url} after {retry} attempt(s)"
        ) from last_error

    with open(path, "wb") as f:
        f.write(req.content)


def download_model(path, url, platform, model_name):
    """Download a model file and, on imx93, replace it with its vela output.

    Args:
        path: Model path from platform.json. Stripping ``model_name`` from it
            and re-joining yields the download destination (the same path
            when it already ends with ``model_name``).
        url: Download address.
        platform: Platform name; only ``"imx93"`` triggers the vela step.
        model_name: File name of the model being fetched.

    Raises:
        subprocess.CalledProcessError: if ``vela`` exits with an error.
        FileNotFoundError: if ``vela`` is not installed or produced no output.
    """
    import shutil
    import subprocess

    path = os.path.join(path.replace(model_name, ""), model_name)
    download_file(model_name, url, path)
    if platform == "imx93":
        # Argument list instead of a shell string: paths with spaces or shell
        # metacharacters are passed through verbatim. check=True surfaces a
        # vela failure instead of silently keeping the unconverted model.
        subprocess.run(["vela", path], check=True)
        # The converted model is expected at ./output/<name>_vela.tflite; it
        # overwrites the downloaded file.
        vela_model_name = model_name.replace(".tflite", "_vela.tflite")
        shutil.move(os.path.join("output", vela_model_name), path)


def _download_if_missing(model_path, section, url_key, model_name):
    """Download ``model_name`` for the current platform if ``model_path`` is absent.

    Args:
        model_path: Location platform.json expects the model at.
        section: The ``"hand"`` or ``"face"`` dictionary of platform.json. It
            is only read (for ``url_key``) when a download is needed.
        url_key: Key in ``section`` holding the download URL.
        model_name: File name of the model at that URL.
    """
    if os.path.exists(model_path):
        return
    print("For the first use " + model_name + " will be downloaded.")
    download_model(model_path, section[url_key], args.platform, model_name)


def load_parameters_json():
    """Fill ``args`` with the settings of the selected platform.

    Reads ``media/platforms/<args.platform>/platform.json`` and sets:

    * ``args.delegate_path`` from ``"delegate"``.
    * For the hand demo: ``path_palm_model``, ``path_landmark_model`` and
      ``path_anchors`` from the ``"hand"`` section.
    * For the face demo: ``path_yoloface_model`` and ``path_facenet_model``
      from the ``"face"`` section. When no delegate is configured, it also
      sets ``face_model_padding``, which then overrides the command-line value.
    * ``args.capture_device`` from ``"capture_device"``, unless ``-c`` was given.

    Model files that do not exist yet are downloaded from the matching
    ``url_*`` entry.

    Raises:
        SystemExit: if platform.json is empty, since every later step needs
            the values it provides.
    """
    config_path = "media/platforms/" + str(args.platform) + "/platform.json"
    with open(config_path) as f:
        content = f.read()
    if not content:
        raise SystemExit("Platform configuration " + config_path + " is empty.")
    config = json.loads(content)

    args.delegate_path = config["delegate"]

    if args.demo == "hand" or args.demo == "both":
        hand = config["hand"]
        args.path_palm_model = hand["palm"]
        args.path_landmark_model = hand["landmark"]
        args.path_anchors = hand["anchors"]
        _download_if_missing(
            args.path_palm_model,
            hand,
            "url_palm",
            "palm_detection_builtin_256_integer_quant.tflite",
        )
        _download_if_missing(
            args.path_landmark_model,
            hand,
            "url_landmark",
            "hand_landmark_3d_256_integer_quant.tflite",
        )

    if args.demo == "face" or args.demo == "both":
        face = config["face"]
        args.path_yoloface_model = face["yoloface"]
        args.path_facenet_model = face["facenet"]
        _download_if_missing(
            args.path_yoloface_model, face, "url_yoloface", "yoloface_int8.tflite"
        )
        _download_if_missing(
            args.path_facenet_model,
            face,
            "url_facenet",
            "facenet_512_int_quantized.tflite",
        )
        if args.delegate_path is None:
            args.face_model_padding = face["face_model_padding"]

    if args.capture_device is None:
        args.capture_device = config["capture_device"]


args = PARSE.parse_args()

if args.list_platforms:
    list_platforms()  # prints every platform.json, then raises SystemExit

# The model paths and the delegate are only ever set from the platform's
# platform.json, so the detectors below cannot be built without one. Fail
# with a usage message rather than an AttributeError further down.
if args.platform is None:
    PARSE.error("a target platform is required: pass -p/--platform (see -l)")

load_parameters_json()

if args.demo == "hand" or args.demo == "both":
    DETECTOR_HAND = HandTracker(
        args.path_palm_model,
        args.path_landmark_model,
        args.path_anchors,
        args.delegate_path,
        box_shift=0.2,
        box_enlarge=1.3,
    )

if args.demo == "face" or args.demo == "both":
    DETECTOR_FACE = YoloFace(args.path_yoloface_model, args.delegate_path)
    RECOGNIZER = Facenet(args.path_facenet_model, args.delegate_path)
    DATABASE = FaceDatabase()


### Essential Methods
def hand_draw_landmarks(points, frame):
    """Paint hand landmarks and the skeleton joining them onto ``frame`` in place.

    Args:
        points: Indexable sequence of (x, y) coordinates, one per landmark id.
            Ids 0-20 are referenced by the skeleton. ``None`` draws nothing.
        frame: Image passed to ``cv2.circle`` / ``cv2.line``; modified in place.
    """
    if points is None:
        return

    # Landmark-id pairs joined by a segment. The order is kept stable so
    # overlapping strokes are painted the same way every frame.
    skeleton = (
        (5, 6), (6, 7), (7, 8),
        (9, 10), (10, 11), (11, 12),
        (13, 14), (14, 15), (15, 16),
        (0, 5), (5, 9), (9, 13), (13, 17), (0, 9), (0, 13),
        (0, 17), (17, 18), (18, 19), (19, 20),
        (0, 1), (1, 2), (2, 3), (3, 4),
    )

    # Green circles on every keypoint first ...
    for x, y in points:
        cv2.circle(frame, (int(x), int(y)), 4, (0, 255, 0), 2)

    # ... then blue segments along the skeleton.
    for start, end in skeleton:
        x_from, y_from = points[start]
        x_to, y_to = points[end]
        cv2.line(
            frame, (int(x_from), int(y_from)), (int(x_to), int(y_to)), (255, 0, 0), 2
        )


def general_camera_runtime(frames_m1, points_m1, frames_m2, name_m2, coordinates_m2):
    """Capture camera frames, feed them to the workers and show the overlays.

    Each frame is sent as ``[index, frame]`` to ``frames_m1`` (hand worker)
    and/or ``frames_m2`` (face worker), depending on ``args.demo``. The next
    frame is only shown once every active worker has emptied its queue.
    Results coming back are drawn onto the frame before it is shown: hand
    landmarks from ``points_m1``, a face box from ``coordinates_m2`` and the
    name in ``name_m2.value``.

    The loop stops when the camera cannot be opened, a read fails, or ESC is
    pressed. ESC sends SIGINT to the parent process, whose KeyboardInterrupt
    handler stops the workers.
    """
    import signal

    flag = 0
    capture = None
    try:
        cv2.namedWindow("preview")
        index = 0
        capture = cv2.VideoCapture(int(args.capture_device))
        capture.set(cv2.CAP_PROP_FRAME_WIDTH, 640)
        capture.set(cv2.CAP_PROP_FRAME_HEIGHT, 480)

        if capture.isOpened():  # try to get the first frame
            rval, frame = capture.read()
        else:
            rval = False

        while rval:
            if flag == 0:
                cv2.imshow("preview", frame)
                rval, frame = capture.read()
                key = cv2.waitKey(20)
                if key == 27:  # ESC
                    # Signal only our own parent. A blanket `pkill python`
                    # would also kill unrelated Python processes on the board.
                    os.kill(os.getppid(), signal.SIGINT)
                    break
                if not rval:
                    # The camera stopped delivering frames; never ship None
                    # to the workers.
                    break

                # Queue.put pickles in a background thread, so hand the
                # workers a copy: the overlays drawn on `frame` below must not
                # leak into their input.
                index_frame = [index, frame.copy()]
                if args.demo == "hand" or args.demo == "both":
                    frames_m1.put(index_frame)
                if args.demo == "face" or args.demo == "both":
                    frames_m2.put(index_frame)

                flag = 1

                if not coordinates_m2.empty():
                    coordinates = coordinates_m2.get()
                    # coordinates is [x1, x2, y1, y2]
                    cv2.rectangle(
                        frame,
                        (coordinates[0], coordinates[2]),
                        (coordinates[1], coordinates[3]),
                        (0, 0, 255),
                        2,
                    )

                    cv2.putText(
                        frame,
                        name_m2.value,
                        (coordinates[0], coordinates[2] + 13),
                        cv2.FONT_HERSHEY_SIMPLEX,
                        1,
                        (0, 0, 255),
                        2,
                    )

                if not points_m1.empty():
                    hand_draw_landmarks(points_m1.get(), frame)
                index += 1

            # Show the next frame only after the active workers consumed theirs.
            if args.demo == "hand" and frames_m1.empty():
                flag = 0
            if args.demo == "face" and frames_m2.empty():
                flag = 0
            if args.demo == "both" and frames_m1.empty() and frames_m2.empty():
                flag = 0

        cv2.destroyWindow("preview")
    except KeyboardInterrupt as exception_general_camera:
        print(exception_general_camera)
    finally:
        # Release the device on every exit path, not only the normal one.
        if capture is not None:
            capture.release()


def general_application_end(results_hand, results_face):
    """Write the collected per-frame timings to the ``--save`` file.

    Drains ``results_hand`` and/or ``results_face`` (selected by ``args.demo``)
    into a tab-delimited file with a header row and a running index. In
    ``both`` mode, rows are written only while both queues still have
    entries.

    Args:
        results_hand: Queue of hand-tracking timings (the "Hand (ms)" column).
        results_face: Queue of face-detection timings (the "Face (ms)" column).
    """
    print("\n**** Writing benchmark results...")
    # The csv module requires files it writes to be opened with newline="".
    # Otherwise its "\r\n" row terminator is translated again on platforms
    # whose line separator is not "\n", producing blank rows.
    with open(args.save, "w", newline="") as file:
        writer = csv.writer(file, delimiter="\t")
        index = 0
        if args.demo == "both":
            writer.writerow(["Index", "Hand (ms)", "Face (ms)"])
            while (not results_hand.empty()) and (not results_face.empty()):
                writer.writerow([index, results_hand.get(), results_face.get()])
                index += 1
        if args.demo == "hand":
            writer.writerow(["Index", "Hand (ms)"])
            while not results_hand.empty():
                writer.writerow([index, results_hand.get()])
                index += 1
        if args.demo == "face":
            writer.writerow(["Index", "Face (ms)"])
            while not results_face.empty():
                writer.writerow([index, results_face.get()])
                index += 1

    print("** Total entries:", index)
    print("**** Benchmark", args.save, "complete.")


### Runtimes
def hand_runtime(frames, points_queue, results):
    """Hand-tracking worker loop, run in its own process.

    Takes ``[index, frame]`` items from ``frames``, converts each BGR frame to
    RGB and runs ``DETECTOR_HAND`` on it. The landmark points go to
    ``points_queue`` and the tracker's timing value goes to ``results``. The
    loop only ends on KeyboardInterrupt.
    """
    print("\n**** Hand detection started ****")
    try:
        while True:
            item = frames.get()
            if args.verbose:
                print("**(Model 1) Hand Frame: ", item[0])
            rgb = cv2.cvtColor(item[1], cv2.COLOR_BGR2RGB)

            landmarks, _unused, elapsed = DETECTOR_HAND(rgb, verbose=args.verbose)
            points_queue.put(landmarks)
            results.put(elapsed)
    except KeyboardInterrupt as interrupt:
        print(interrupt)


def face_runtime(frames, name, coordinates, results):
    """Face-recognition worker loop, run in its own process.

    Takes ``[index, frame]`` items from ``frames`` and runs
    ``DETECTOR_FACE.detect`` on each frame. For every detected box it:

    * multiplies columns 0/2 by the frame width and 1/3 by the frame height
      (in place), then casts to int32,
    * widens the box by ``args.face_model_padding`` pixels, clamped to the
      frame,
    * pushes ``[x1, x2, y1, y2]`` onto ``coordinates`` for the camera overlay,
    * embeds the cropped face with ``RECOGNIZER`` and stores the
      ``DATABASE.find_name`` result in ``name.value``.

    The detector's timing value is put on ``results`` once per frame. The
    loop only ends on KeyboardInterrupt.
    """
    print("\n**** Face detection started ****")
    try:
        while True:
            item = frames.get()
            if args.verbose:
                print("**(Face) Frame:", item[0])
            frame = item[1]

            detections, elapsed = DETECTOR_FACE.detect(frame, verbose=args.verbose)
            for box in detections:
                box[[0, 2]] *= frame.shape[1]
                box[[1, 3]] *= frame.shape[0]
                left, top, right, bottom = box.astype(np.int32)

                frame_h, frame_w, _channels = frame.shape
                pad = args.face_model_padding
                left, right = max(left - pad, 0), min(right + pad, frame_w)
                top, bottom = max(top - pad, 0), min(bottom + pad, frame_h)

                coordinates.put([left, right, top, bottom])
                crop = frame[top:bottom, left:right]
                name.value = DATABASE.find_name(RECOGNIZER.get_embeddings(crop))

            results.put(elapsed)
    except KeyboardInterrupt as interrupt:
        print(interrupt)


if __name__ == "__main__":
    # Processes that were actually started. Only these are cleaned up, so an
    # interrupt during start-up never calls terminate() on an unstarted
    # Process (which raises).
    STARTED_PROCESSES = []
    try:
        HAND_FRAMES = Queue()
        HAND_POINTS = Queue()
        HAND_RESULTS = Queue()

        FACE_FRAMES = Queue()
        FACE_COORDS = Queue()
        FACE_RESULTS = Queue()

        MANAGER = Manager()
        # Shared string: the face worker writes the recognised name into it
        # and the camera process draws it next to the face box.
        FACE_LOAD = MANAGER.Value(c_char_p, "**(Face) Starting...")

        if args.save is not None:
            atexit.register(
                general_application_end,
                results_hand=HAND_RESULTS,
                results_face=FACE_RESULTS,
            )

        GENERAL_CAMERA_PROCESS = Process(
            target=general_camera_runtime,
            args=(
                HAND_FRAMES,
                HAND_POINTS,
                FACE_FRAMES,
                FACE_LOAD,
                FACE_COORDS,
            ),
        )

        WORKER_PROCESSES = []
        if args.demo == "hand" or args.demo == "both":
            HAND_PROCESS = Process(
                target=hand_runtime,
                args=(
                    HAND_FRAMES,
                    HAND_POINTS,
                    HAND_RESULTS,
                ),
            )
            WORKER_PROCESSES.append(HAND_PROCESS)
        if args.demo == "face" or args.demo == "both":
            FACE_PROCESS = Process(
                target=face_runtime,
                args=(
                    FACE_FRAMES,
                    FACE_LOAD,
                    FACE_COORDS,
                    FACE_RESULTS,
                ),
            )
            WORKER_PROCESSES.append(FACE_PROCESS)

        GENERAL_CAMERA_PROCESS.start()
        STARTED_PROCESSES.append(GENERAL_CAMERA_PROCESS)
        # Presumably gives the camera process a head start before the
        # workers begin waiting on their queues.
        time.sleep(5)

        for worker in WORKER_PROCESSES:
            worker.start()
            STARTED_PROCESSES.append(worker)

        # The workers loop forever on frames.get(). Once the camera process
        # exits, nothing feeds them, so stop waiting here and shut them down
        # below instead of blocking forever.
        GENERAL_CAMERA_PROCESS.join()
    except KeyboardInterrupt as excepetion_stop:
        print(excepetion_stop)
    finally:
        for process in STARTED_PROCESSES:
            if process.is_alive():
                process.terminate()
        for process in STARTED_PROCESSES:
            process.join()
