diff --git a/test/test_packet.py b/test/test_packet.py index 9349134ee..63d07d8db 100644 --- a/test/test_packet.py +++ b/test/test_packet.py @@ -14,35 +14,51 @@ from litex.soc.interconnect.packet import * from .test_stream import StreamPacket, stream_inserter, stream_collector, compare_packets -packet_header_length = 31 -packet_header_fields = { - "field_8b" : HeaderField(0, 0, 8), - "field_16b" : HeaderField(1, 0, 16), - "field_32b" : HeaderField(3, 0, 32), - "field_64b" : HeaderField(7, 0, 64), - "field_128b": HeaderField(15, 0, 128), -} -packet_header = Header( - fields = packet_header_fields, - length = packet_header_length, - swap_field_bytes = True) - -def packet_description(dw): - param_layout = packet_header.get_layout() - payload_layout = [("data", dw)] - return EndpointDescription(payload_layout, param_layout) - -def raw_description(dw): - payload_layout = [("data", dw)] - return EndpointDescription(payload_layout) +def mask_last_be(dw, data, last_be): + masked_data = 0 + for byte in range(dw // 8): + if 2**byte > last_be: + break + masked_data |= data & (0xFF << (byte * 8)) + return masked_data class TestPacket(unittest.TestCase): def loopback_test(self, dw, seed=42, with_last_be=False, debug_print=False): + packet_header_length = 31 + packet_header_fields = { + "field_8b" : HeaderField(0, 0, 8), + "field_16b" : HeaderField(1, 0, 16), + "field_32b" : HeaderField(3, 0, 32), + "field_64b" : HeaderField(7, 0, 64), + "field_128b": HeaderField(15, 0, 128), + } + packet_header = Header( + fields = packet_header_fields, + length = packet_header_length, + swap_field_bytes = True) + + def packet_description(dw): + param_layout = packet_header.get_layout() + payload_layout = [("data", dw)] + + if with_last_be: + payload_layout += [("last_be", dw // 8)] + + return EndpointDescription(payload_layout, param_layout) + + def raw_description(dw): + payload_layout = [("data", dw)] + + if with_last_be: + payload_layout += [("last_be", dw // 8)] + + return EndpointDescription(payload_layout) + prng = random.Random(seed) # Prepare packets - npackets = 8 + npackets = 64 packets = [] for n in range(npackets): header = {} @@ -56,10 +72,16 @@ class TestPacket(unittest.TestCase): class DUT(Module): def __init__(self): - self.submodules.packetizer = \ - Packetizer(packet_description(dw), raw_description(dw), packet_header) - self.submodules.depacketizer = \ - Depacketizer(raw_description(dw), packet_description(dw), packet_header) + self.submodules.packetizer = Packetizer( + packet_description(dw), + raw_description(dw), + packet_header, + ) + self.submodules.depacketizer = Depacketizer( + raw_description(dw), + packet_description(dw), + packet_header, + ) self.comb += self.packetizer.source.connect(self.depacketizer.sink) self.sink, self.source = self.packetizer.sink, self.depacketizer.source @@ -102,11 +124,23 @@ class TestPacket(unittest.TestCase): def test_8bit_loopback(self): self.loopback_test(dw=8) + def test_8bit_loopback_last_be(self): + self.loopback_test(dw=8, with_last_be=True) + def test_32bit_loopback(self): self.loopback_test(dw=32) + def test_32bit_loopback_last_be(self): + self.loopback_test(dw=32, with_last_be=True) + def test_64bit_loopback(self): self.loopback_test(dw=64) + def test_64bit_loopback_last_be(self): + self.loopback_test(dw=64, with_last_be=True) + def test_128bit_loopback(self): self.loopback_test(dw=128) + + def test_128bit_loopback_last_be(self): + self.loopback_test(dw=128, with_last_be=True)