#!/usr/bin/env python3
#
# Copyright 2021 NXP
# SPDX-License-Identifier: BSD-3-Clause
#
# TEST CODE of NXP USBSIO Library - I2C tests
#

from abc import abstractmethod
import unittest
import functools
import logging
import sys
import os

from test import *

# global I2C test parameters
I2C_PORT = 0
I2C_DEV_ADDR = 20
# bus speed passed to I2C_Open (presumably in Hz - confirm against the library docs)
I2C_BUS_SPEED = 100000
# pattern repeated to build the 'echo' command payloads; a shorter numeric
# alternative that can be used instead is b"1234567890"
I2C_ECHO_DATA = b"ABCDEFGHIJKLMNOPQRSTUVWXYZ"

# SIOTEST protocol description:
#
#  RECEIVE:   >[XXXX][cmd]$
#     >      - start of frame character which resets our receiver state machine
#     [XXXX] - four ASCII hexadecimal digits giving the length of [cmd] in bytes
#     [cmd]  - command and arguments as a text, multiple commands may be separated by ;
#     $      - command terminator to validate the [cmd] length matches XXXX
#  TRANSMIT:  <[YYYY][resp]$
#     <      - start of frame character which resets the remote receiver's state machine
#     [YYYY] - four ASCII hexadecimal digits giving the length of [resp] in bytes
#     [resp] - response text (generated by commands processing)
#     $      - response terminator to validate the [resp] length matches YYYY

def use_i2c(portNum):
    '''Decorator factory: open the default SIO port and I2C port `portNum` before the test runs.

    On success the opened I2C object is stored in `self.i2c` and the wrapped test is
    called. If either the SIO port or the I2C port cannot be opened, an Exception
    naming the port number is raised instead.
    '''
    def decorator(func):
        @functools.wraps(func)
        def wrapper(self, *args, **kwargs):
            sioOpened = self.OpenDefaultPort()
            if sioOpened:
                self.i2c = self.sio.I2C_Open(I2C_BUS_SPEED, portNum=portNum)
            # short-circuit keeps self.i2c untouched when the SIO port did not open
            if not (sioOpened and self.i2c):
                raise Exception("The 'use_i2c(%d)' decorator has failed to open I2C port" % portNum)
            return func(self, *args, **kwargs)
        return wrapper
    return decorator

class TestI2C(TestBase):
    '''Basic I2C API tests: transfer-option encoding and port open/reset/close.'''

    def test_I2C_NormalTransferOptions(self):
        '''Check the option bit-mask produced by LIBUSBSIO._I2C_NormalXferOptions.'''
        lib = LIBUSBSIO
        breakOnNack = lib.I2C_TRANSFER_OPTIONS_BREAK_ON_NACK
        # (start, stop, ignoreNAK, nackLastByte, noAddress) -> expected option mask
        cases = (
            ((0, 0, 1, 0, 0), 0),
            ((0, 0, 0, 0, 0), breakOnNack),
            ((1, 0, 0, 0, 0), lib.I2C_TRANSFER_OPTIONS_START_BIT | breakOnNack),
            ((1, 1, 0, 0, 0), lib.I2C_TRANSFER_OPTIONS_START_BIT | lib.I2C_TRANSFER_OPTIONS_STOP_BIT | breakOnNack),
            ((0, 0, 0, 1, 0), lib.I2C_TRANSFER_OPTIONS_NACK_LAST_BYTE | breakOnNack),
            ((0, 0, 0, 0, 1), lib.I2C_TRANSFER_OPTIONS_NO_ADDRESS | breakOnNack),
        )
        for (start, stop, ignoreNAK, nackLastByte, noAddress), expected in cases:
            actual = lib._I2C_NormalXferOptions(0, start=start, stop=stop, ignoreNAK=ignoreNAK,
                                                nackLastByte=nackLastByte, noAddress=noAddress)
            self.assertEqual(actual, expected)

    @use_i2c(I2C_PORT)
    def test_I2C_OpenClose(self):
        '''An opened port is an I2C object with a live handle which Close() clears.'''
        port = self.i2c
        self.assertTrue(isinstance(port, LIBUSBSIO.I2C))
        self.assertTrue(port._h)
        self.assertEqual(port.Close(), 0)
        self.assertFalse(port._h)

    @use_i2c(I2C_PORT)
    def test_I2C_Reset(self):
        '''Reset() on an opened port reports success (0).'''
        result = self.i2c.Reset()
        self.assertEqual(result, 0)

    @use_i2c(I2C_PORT)
    def test_I2C_IsAnyOpen(self):
        '''IsAnyPortOpen() tracks the open/closed state of the single I2C port.'''
        port = self.i2c
        self.assertTrue(isinstance(port, LIBUSBSIO.I2C))
        self.assertTrue(self.sio.IsAnyPortOpen())
        self.assertEqual(port.Close(), 0)
        self.assertFalse(self.sio.IsAnyPortOpen())

# Multi-part transfer flags (bit mask):
#   USE_MULTI_PART - set when a transfer is split into several partial transactions at all
#   PART_FIRST / PART_LAST - identify which part of the split transfer is being processed
NO_MULTI_PART = 0x0
USE_MULTI_PART = 0x01
PART_FIRST = 0x02
PART_LAST = 0x04

class TestI2C_BaseComm(TestBase):
    '''Common SIOTEST-protocol I2C echo tests.

    The SIOTEST command/response framing is implemented here on top of the
    i2c_write / i2c_read / i2c_write_read methods, which subclasses override to
    use a particular library transfer API (Write+Read calls or FastXfer). The
    test_I2C_Transfer_Echo_* methods are inherited by those subclasses; when run
    on this base class itself they only log a message and return.
    '''
    def setUp(self):
        super().setUp()
        # by default, we use isolated write and read transactions, each with start+stop
        self.MultiPartWrite = NO_MULTI_PART
        self.MultiPartRead  = NO_MULTI_PART

    # Abstract I2C comm methods used to implement the SIOTEST protocol. The functions are
    # overridden in subclasses to use Write/Read transactions or a FastXfer transaction.
    # NOTE(review): @abstractmethod is only enforced with an ABCMeta metaclass (not visible
    # here); the raised exception is what guards against calling the base versions.
    @abstractmethod
    def i2c_write(self, devAddr:int,  txData:bytes, multipartWrite:int=NO_MULTI_PART, **kwargs):
        '''Write txData to devAddr; overridden by the WriteRead or XFER test classes.'''
        raise Exception("this function must be overridden")
    @abstractmethod
    def i2c_read(self, devAddr:int, rxSize:int, multipartRead:int=NO_MULTI_PART, **kwargs):
        '''Read rxSize bytes from devAddr; overridden by the WriteRead or XFER test classes.'''
        raise Exception("this function must be overridden")
    @abstractmethod
    def i2c_write_read(self, devAddr:int, txData:bytes, rxSize:int, multipartRead:int=NO_MULTI_PART, **kwargs):
        '''Optionally write txData, then read rxSize bytes; returns (rx, ret).
        Overridden by the WriteRead or XFER test classes.'''
        raise Exception("this function must be overridden")

    # SIOTEST protocol implementation, send Command and receive Response, return tuple [OK, ResponseData]
    def i2c_siotest_send_cmd(self, data:bytes):
        '''Send one SIOTEST command frame to the MCU siotest application and read its response.

        The command is framed as ">XXXX[cmd]$" and sent in the same transfer as a
        5-byte read of the response header "<YYYY". The YYYY payload bytes plus the
        closing '$' are then read in a second transfer. Both transfers go through the
        overridable i2c_write_read, so the test class decides whether Write+Read or
        FastXfer (and single or multi-part reads) is used.

        :param data: command text without framing, e.g. b"echo ABC"
        :return: (True, payload without the trailing '$') on success, or (False, None)
                 when a transfer reports an unexpected length without failing an
                 assertion. Malformed frames fail the test.
        '''
        # command frame is >XXXX[cmd]$
        size = ">%04x" % len(data)
        tx = size.encode() + data + b'$'

        # send the command and receive the 5-byte response header back
        rx, ret = self.i2c_write_read(I2C_DEV_ADDR, tx, rxSize=5, multipartRead=self.MultiPartRead | PART_FIRST)
        # a failed transfer may hand back no buffer at all; report it as a length mismatch
        rxLen = len(rx) if rx is not None else 0
        self.assertEqual(5, rxLen, "Expected to write[%d]:%s and read[%d], but got %d" % (len(tx), tx, 5, ret))

        # parse received <YYYY response length
        if ret != 5:
            return False, None
        if rx[0:1] != b'<':
            self.fail("Expected SOF byte '<' but received %s" % rx[0:1])
        try:
            sz = int(rx[1:], 16)
        except (ValueError, TypeError):
            self.fail("Expected response SIZE could not be parsed %s" % rx[1:])

        # read the data payload and EOF($)
        rx, ret = self.i2c_write_read(I2C_DEV_ADDR, None, rxSize=sz+1, multipartRead=self.MultiPartRead | PART_LAST)
        self.assertEqual(ret, sz+1, "Expected to read data[%d]+EOF, but received [%d]:%s" % (sz, ret, rx))

        if ret != sz + 1:
            return False, None
        if rx[-1:] != b'$':
            self.fail("Response not terminated with EOF %s" % rx)

        # success!
        return True, rx[:-1]

    def siotest_i2c_echo(self, minlen, maxlen):
        '''Send 'echo' commands with payload sizes minlen..maxlen-1 and check each echoed reply.'''
        # repeat the base pattern enough times to cover the largest payload
        pattern = I2C_ECHO_DATA * (maxlen // len(I2C_ECHO_DATA) + 1)
        for sz in range(minlen, maxlen):
            cmd = b"echo " + pattern[:sz]
            ok, resp = self.i2c_siotest_send_cmd(cmd)
            self.assertTrue(ok, "Command 'echo' with test pattern sz=%d was send properly" % sz)
            if ok:
                self.assertEqual(sz, len(resp), "I2C test echo data length:%d matching expected:%d" % (len(resp), sz))
                self.assertEqual(resp, pattern[:len(resp)], "I2C test echo data content matching. received:%s" % resp)

    def skip_in_base(func):
        '''Decorator (used at class-definition time): run the wrapped test only in subclasses.

        When the test instance is exactly TestI2C_BaseComm, the abstract transfer methods
        are not implemented, so the test only logs a message and returns.'''
        @functools.wraps(func)
        def wrapper(self, *args, **kwargs):
            if(type(self).__name__ == "TestI2C_BaseComm"):
                self.logger.info("This test is skipped in the base class, always passing")
            else:
                return func(self, *args, **kwargs)
        return wrapper

    ## base tests inherited in WriteRead, XFER and other test classes which override the i2c access methods
    @skip_in_base
    @use_i2c(I2C_PORT)
    def test_I2C_Transfer_Echo_single(self):
        self.siotest_i2c_echo(1, 2)

    @skip_in_base
    @use_i2c(I2C_PORT)
    def test_I2C_Transfer_Echo_short(self):
        self.siotest_i2c_echo(20, 60)

    @skip_in_base
    @use_i2c(I2C_PORT)
    def test_I2C_Transfer_Echo_long(self):
        self.siotest_i2c_echo(250, 300)

    @unittest.skipUnless(RUN_SLOW_TESTS, "Takes long time")
    @skip_in_base
    @use_i2c(I2C_PORT)
    def test_I2C_Transfer_Echo_full(self):
        self.siotest_i2c_echo(1, self.sio.GetMaxDataSize() - 10)

class TestI2C_XFER(TestI2C_BaseComm):
    '''Runs the inherited SIOTEST I2C tests using the library FastXfer call for every transfer.

    FastXfer performs a complete transfer in one call, so multi-part requests are rejected.'''
    def i2c_write(self, devAddr:int,  txData:bytes, multipartWrite:int=NO_MULTI_PART, **kwargs):
        '''Write txData to devAddr with FastXfer; returns the FastXfer result code.'''
        self.assertFalse(multipartWrite & USE_MULTI_PART, "Cannot handle multi-part write in FastXfer method")
        # FastXfer yields an (rx, ret) pair (it is unpacked that way in the methods below);
        # return only ret, so this matches the int returned by the other i2c_write variants
        _rx, ret = self.i2c.FastXfer(devAddr, txData, rxSize=0, **kwargs)
        return ret
    def i2c_read(self, devAddr:int, rxSize:int, multipartRead:int=NO_MULTI_PART, **kwargs):
        '''Read rxSize bytes from devAddr with FastXfer; returns (rx, ret).'''
        self.assertFalse(multipartRead & USE_MULTI_PART, "Cannot handle multi-part reads in FastXfer method")
        rx, ret = self.i2c.FastXfer(devAddr, None, rxSize=rxSize, **kwargs)
        return rx, ret
    def i2c_write_read(self, devAddr:int, txData:bytes, rxSize:int, multipartRead:int=NO_MULTI_PART, **kwargs):
        '''Write txData and read rxSize bytes in a single FastXfer call; returns (rx, ret).'''
        self.assertFalse(multipartRead & USE_MULTI_PART, "Cannot handle multi-part reads in FastXfer method")
        rx, ret = self.i2c.FastXfer(devAddr, txData, rxSize=rxSize, **kwargs)
        return rx, ret

class TestI2C_WriteRead(TestI2C_BaseComm):
    '''Runs the inherited SIOTEST I2C tests using the I2C_DeviceWrite and I2C_DeviceRead library calls.

    This class already supports multi-part reads and writes: the multipartXXX flags passed
    to i2c_read / i2c_write are translated into START/STOP/address/NACK options.'''
    def get_transaction_options(self, multipart:int, reading:bool, kwargs):
        '''Fill `kwargs` in place with the partial-transaction options implied by `multipart`.

        Nothing is changed unless USE_MULTI_PART is set. Otherwise START and the address
        are generated only for the first part, STOP only for the last part, and for reads
        the last byte is NACKed only in the last part. Returns the same `kwargs` dict.'''
        if not (multipart & USE_MULTI_PART):
            return kwargs
        isFirst = bool(multipart & PART_FIRST)
        isLast = bool(multipart & PART_LAST)
        kwargs["start"] = isFirst
        kwargs["stop"] = isLast
        kwargs["noAddress"] = not isFirst
        if reading:
            kwargs["nackLastByte"] = isLast
        return kwargs

    def i2c_write(self, devAddr:int,  txData:bytes, multipartWrite:int=NO_MULTI_PART, **kwargs):
        '''Write txData with DeviceWrite; returns the library result code.'''
        options = self.get_transaction_options(multipartWrite, False, kwargs)
        return self.i2c.DeviceWrite(devAddr, txData, **options)

    def i2c_read(self, devAddr:int, rxSize:int, multipartRead:int=NO_MULTI_PART, **kwargs):
        '''Read rxSize bytes with DeviceRead; returns (rx, ret).'''
        options = self.get_transaction_options(multipartRead, True, kwargs)
        rx, ret = self.i2c.DeviceRead(devAddr, rxSize, **options)
        return rx, ret

    def i2c_write_read(self, devAddr:int, txData:bytes, rxSize:int, multipartRead:int=NO_MULTI_PART, **kwargs):
        '''Write txData (when given) and then read rxSize bytes; returns (rx, ret).

        A negative write result skips the read and is returned as (b'', result).'''
        written = self.i2c_write(devAddr, txData, **kwargs) if txData else 0
        if written < 0:
            return b'', written
        readOptions = self.get_transaction_options(multipartRead, True, kwargs)
        rx, ret = self.i2c_read(devAddr, rxSize, **readOptions)
        return rx, ret


class TestI2C_WriteRead_MultipartRead(TestI2C_WriteRead):
    '''MultipartRead test instructs the base I2C test class to split the incoming I2C read requests in two
    and process them using two I2C_DeviceRead library calls with partial options (no STOP and NACK generated in
    the first half and no START+ADDRESS generated in the second half)'''
    def setUp(self):
        super().setUp()
        # use split-read transactions (see the base class above)
        self.MultiPartRead = USE_MULTI_PART

    # The overrides below only re-decorate the inherited tests with @known_issue.
    # super() must be called: the bare `super` type has no test_* attributes.
    @known_issue
    def test_I2C_Transfer_Echo_single(self):
        super().test_I2C_Transfer_Echo_single()

    @known_issue
    def test_I2C_Transfer_Echo_short(self):
        super().test_I2C_Transfer_Echo_short()

    @known_issue
    def test_I2C_Transfer_Echo_long(self):
        super().test_I2C_Transfer_Echo_long()

    @known_issue
    def test_I2C_Transfer_Echo_full(self):
        super().test_I2C_Transfer_Echo_full()

class TestI2C_WriteRead_MultipartWrite(TestI2C_WriteRead):
    '''MultipartWrite test splits the outgoing I2C write requests into two halves and tries to send
    them using two I2C_DeviceWrite library calls with partial options (no STOP generated in
    the first half and no START+ADDRESS generated in the second half)'''
    def setUp(self):
        super().setUp()
        # use split-write transactions in normal operations.
        # note that this class also overrides the "normal" i2c_write and tries to split them
        # into partial writes
        self.MultiPartWrite = USE_MULTI_PART

    def i2c_write(self, devAddr:int,  txData:bytes, multipartWrite:int=NO_MULTI_PART, **kwargs):
        '''Artificially split a plain write request into two partial writes.

        Requests that are already multi-part, or carry fewer than 2 bytes, are forwarded
        to the parent implementation, which applies any multi-part options they carry.
        Returns the total number of bytes written by both halves, or the parent's result.'''
        if multipartWrite == NO_MULTI_PART and txData and len(txData) >= 2:
            txLen1 = len(txData) // 2
            txData1 = txData[:txLen1]
            txData2 = txData[txLen1:]
            self.logger.debug("Writing 1st half of the I2C write request[%d]: %s" % (len(txData1), txData1))
            ret1 = super().i2c_write(devAddr, txData1, USE_MULTI_PART | PART_FIRST, **kwargs)
            self.assertEqual(ret1, len(txData1), "Split I2C write, 1st part sent all %d bytes (result=%d)" % (len(txData1), ret1))
            self.logger.debug("Writing 2nd half of the I2C write request[%d]: %s" % (len(txData2), txData2))
            ret2 = super().i2c_write(devAddr, txData2, USE_MULTI_PART | PART_LAST, **kwargs)
            self.assertEqual(ret2, len(txData2), "Split I2C write, 2nd part sent all %d bytes (result=%d)" % (len(txData2), ret2))
            ret = ret1+ret2
        else:
            self.logger.debug("Not splitting the I2C write request")
            # go through the parent so multi-part flags passed by the caller are honoured
            ret = super().i2c_write(devAddr, txData, multipartWrite, **kwargs)
        return ret

    # The overrides below only re-decorate the inherited tests with @known_issue.
    # super() must be called: the bare `super` type has no test_* attributes.
    @known_issue
    def test_I2C_Transfer_Echo_single(self):
        super().test_I2C_Transfer_Echo_single()

    @known_issue
    def test_I2C_Transfer_Echo_short(self):
        super().test_I2C_Transfer_Echo_short()

    @known_issue
    def test_I2C_Transfer_Echo_long(self):
        super().test_I2C_Transfer_Echo_long()

    @known_issue
    def test_I2C_Transfer_Echo_full(self):
        super().test_I2C_Transfer_Echo_full()
