diff --git a/test/test_packet.py b/test/test_packet.py new file mode 100644 index 000000000..0fe6bf228 --- /dev/null +++ b/test/test_packet.py @@ -0,0 +1,107 @@ +# This file is Copyright (c) 2019 Florent Kermarrec +# License: BSD + +import unittest +import random + +from migen import * + +from litex.soc.interconnect.stream import * +from litex.soc.interconnect.packet import * + +packet_header_length = 31 +packet_header_fields = { + "field_8b" : HeaderField(0, 0, 8), + "field_16b" : HeaderField(1, 0, 16), + "field_32b" : HeaderField(3, 0, 32), + "field_64b" : HeaderField(7, 0, 64), + "field_128b": HeaderField(15, 0, 128), +} +packet_header = Header( + fields = packet_header_fields, + length = packet_header_length, + swap_field_bytes = True) + +def packet_description(dw): + param_layout = packet_header.get_layout() + payload_layout = [("data", dw)] + return EndpointDescription(payload_layout, param_layout) + +def raw_description(dw): + payload_layout = [("data", dw)] + return EndpointDescription(payload_layout) + +class Packet: + def __init__(self, header, datas): + self.header = header + self.datas = datas + + +class TestPacket(unittest.TestCase): + def test_loopback(self): + prng = random.Random(42) + # Prepare packets + npackets = 8 + packets = [] + for n in range(npackets): + header = {} + header["field_8b"] = prng.randrange(2**8) + header["field_16b"] = prng.randrange(2**16) + header["field_32b"] = prng.randrange(2**32) + header["field_64b"] = prng.randrange(2**64) + header["field_128b"] = prng.randrange(2**128) + datas = [prng.randrange(2**8) for _ in range(prng.randrange(2**7))] + packets.append(Packet(header, datas)) + + def generator(dut): + # Send packets + for packet in packets: + yield dut.sink.field_8b.eq(packet.header["field_8b"]) + yield dut.sink.field_16b.eq(packet.header["field_16b"]) + yield dut.sink.field_32b.eq(packet.header["field_32b"]) + yield dut.sink.field_64b.eq(packet.header["field_64b"]) + yield dut.sink.field_128b.eq(packet.header["field_128b"]) + yield + for n, data in enumerate(packet.datas): + yield dut.sink.valid.eq(1) + yield dut.sink.last.eq(n == (len(packet.datas) - 1)) + yield dut.sink.data.eq(data) + yield + while (yield dut.sink.ready) == 0: + yield + dut.sink.valid.eq(0) + + def checker(dut): + dut.header_errors = 0 + dut.data_errors = 0 + dut.last_errors = 0 + # Receive and check packets + yield dut.source.ready.eq(1) + for packet in packets: + for n, data in enumerate(packet.datas): + while (yield dut.source.valid) == 0: + yield + for field in ["field_8b", "field_16b", "field_32b", "field_64b", "field_128b"]: + if (yield getattr(dut.source, field)) != packet.header[field]: + dut.header_errors += 1 + #print("{:02x} vs {:02x}".format((yield dut.source.data), data)) + if ((yield dut.source.data) != data): + dut.data_errors += 1 + if ((yield dut.source.last) != (n == (len(packet.datas) - 1))): + dut.last_errors += 1 + yield + yield + + class DUT(Module): + def __init__(self): + packetizer = Packetizer(packet_description(8), raw_description(8), packet_header) + depacketizer = Depacketizer(raw_description(8), packet_description(8), packet_header) + self.submodules += packetizer, depacketizer + self.comb += packetizer.source.connect(depacketizer.sink) + self.sink, self.source = packetizer.sink, depacketizer.source + + dut = DUT() + run_simulation(dut, [generator(dut), checker(dut)]) + self.assertEqual(dut.header_errors, 0) + self.assertEqual(dut.data_errors, 0) + self.assertEqual(dut.last_errors, 0)