From 32bb2554bc7410b05c88cf12fcb3de1ffc011098 Mon Sep 17 00:00:00 2001 From: Florent Kermarrec Date: Sat, 23 Oct 2021 17:40:41 +0200 Subject: [PATCH] test: Rename new test_packet/stream to test_packet2/stream2 and revert old tests. Old and new tests are complementary and would need to be merged. --- test/test_packet.py | 245 ++++++++++++------------------ test/test_packet2.py | 188 +++++++++++++++++++++++ test/test_stream.py | 349 ++++--------------------------------------- test/test_stream2.py | 339 +++++++++++++++++++++++++++++++++++++++++ 4 files changed, 653 insertions(+), 468 deletions(-) create mode 100644 test/test_packet2.py create mode 100644 test/test_stream2.py diff --git a/test/test_packet.py b/test/test_packet.py index 83fc8eb25..5d5c3bfd4 100644 --- a/test/test_packet.py +++ b/test/test_packet.py @@ -12,177 +12,118 @@ from migen import * from litex.soc.interconnect.stream import * from litex.soc.interconnect.packet import * -from .test_stream import StreamPacket, stream_inserter, stream_collector, compare_packets +packet_header_length = 31 +packet_header_fields = { + "field_8b" : HeaderField(0, 0, 8), + "field_16b" : HeaderField(1, 0, 16), + "field_32b" : HeaderField(3, 0, 32), + "field_64b" : HeaderField(7, 0, 64), + "field_128b": HeaderField(15, 0, 128), +} +packet_header = Header( + fields = packet_header_fields, + length = packet_header_length, + swap_field_bytes = True) -def mask_last_be(dw, data, last_be): - masked_data = 0 +def packet_description(dw): + param_layout = packet_header.get_layout() + payload_layout = [("data", dw)] + return EndpointDescription(payload_layout, param_layout) - for byte in range(dw // 8): - if 2**byte > last_be: - break - masked_data |= data & (0xFF << (byte * 8)) +def raw_description(dw): + payload_layout = [("data", dw)] + return EndpointDescription(payload_layout) + +class Packet: + def __init__(self, header, datas): + self.header = header + self.datas = datas - return masked_data class TestPacket(unittest.TestCase): - def loopback_test(self, dw, seed=42, with_last_be=False, debug_print=False): - # Independent random number generator to ensure we're the - # stream_inserter and stream_collectors still have - # reproducible behavior independent of the headers - prng = random.Random(seed + 5) - - # Generate a random number of differently sized header fields - nheader_fields = prng.randrange(16) - i = 0 - packet_header_length = 0 - packet_header_fields = {} - while packet_header_length < dw // 8 or i < nheader_fields: - # Header field size can be 1, 2, 4, 8, 16 bytes - field_length = 2**prng.randrange(5) - packet_header_fields["field{}_{}b".format(i, field_length * 8)] = \ - HeaderField(packet_header_length, 0, field_length * 8) - packet_header_length += field_length - i += 1 - - packet_header = Header( - fields = packet_header_fields, - length = packet_header_length, - swap_field_bytes = bool(prng.getrandbits(1))) - - def packet_description(dw): - param_layout = packet_header.get_layout() - payload_layout = [("data", dw)] - - if with_last_be: - payload_layout += [("last_be", dw // 8)] - - return EndpointDescription(payload_layout, param_layout) - - def raw_description(dw): - payload_layout = [("data", dw)] - - if with_last_be: - payload_layout += [("last_be", dw // 8)] - - return EndpointDescription(payload_layout) - + def loopback_test(self, dw): + prng = random.Random(42) # Prepare packets - npackets = 32 + npackets = 8 packets = [] for n in range(npackets): header = {} - for name, headerfield in packet_header_fields.items(): - header[name] = prng.randrange(2**headerfield.width) - datas = [prng.randrange(2**8) for _ in range(prng.randrange(dw - 1) + 1)] - packets.append(StreamPacket(datas, header)) + header["field_8b"] = prng.randrange(2**8) + header["field_16b"] = prng.randrange(2**16) + header["field_32b"] = prng.randrange(2**32) + header["field_64b"] = prng.randrange(2**64) + header["field_128b"] = prng.randrange(2**128) + datas = [prng.randrange(2**dw) for _ in range(prng.randrange(2**7))] + packets.append(Packet(header, datas)) + + def generator(dut, valid_rand=50): + # Send packets + for packet in packets: + yield dut.sink.field_8b.eq(packet.header["field_8b"]) + yield dut.sink.field_16b.eq(packet.header["field_16b"]) + yield dut.sink.field_32b.eq(packet.header["field_32b"]) + yield dut.sink.field_64b.eq(packet.header["field_64b"]) + yield dut.sink.field_128b.eq(packet.header["field_128b"]) + yield + for n, data in enumerate(packet.datas): + yield dut.sink.valid.eq(1) + yield dut.sink.last.eq(n == (len(packet.datas) - 1)) + yield dut.sink.data.eq(data) + yield + while (yield dut.sink.ready) == 0: + yield + yield dut.sink.valid.eq(0) + yield dut.sink.last.eq(0) + while prng.randrange(100) < valid_rand: + yield + + def checker(dut, ready_rand=50): + dut.header_errors = 0 + dut.data_errors = 0 + dut.last_errors = 0 + # Receive and check packets + for packet in packets: + for n, data in enumerate(packet.datas): + yield dut.source.ready.eq(0) + yield + while (yield dut.source.valid) == 0: + yield + while prng.randrange(100) < ready_rand: + yield + yield dut.source.ready.eq(1) + yield + for field in ["field_8b", "field_16b", "field_32b", "field_64b", "field_128b"]: + if (yield getattr(dut.source, field)) != packet.header[field]: + dut.header_errors += 1 + #print("{:x} vs {:x}".format((yield dut.source.data), data)) + if ((yield dut.source.data) != data): + dut.data_errors += 1 + if ((yield dut.source.last) != (n == (len(packet.datas) - 1))): + dut.last_errors += 1 + yield class DUT(Module): def __init__(self): - self.submodules.packetizer = Packetizer( - packet_description(dw), - raw_description(dw), - packet_header, - ) - self.submodules.depacketizer = Depacketizer( - raw_description(dw), - packet_description(dw), - packet_header, - ) - self.comb += self.packetizer.source.connect(self.depacketizer.sink) - self.sink, self.source = self.packetizer.sink, self.depacketizer.source + packetizer = Packetizer(packet_description(dw), raw_description(dw), packet_header) + depacketizer = Depacketizer(raw_description(dw), packet_description(dw), packet_header) + self.submodules += packetizer, depacketizer + self.comb += packetizer.source.connect(depacketizer.sink) + self.sink, self.source = packetizer.sink, depacketizer.source dut = DUT() - recvd_packets = [] - run_simulation( - dut, - [ - stream_inserter( - dut.sink, - src=packets, - seed=seed, - debug_print=debug_print, - valid_rand=50, - ), - stream_collector( - dut.source, - dest=recvd_packets, - expect_npackets=npackets, - seed=seed, - debug_print=debug_print, - ready_rand=50, - ), - ], - ) - - # When we don't have a last_be signal, the Packetizer will simply throw - # away the partial bus word. The Depacketizer will then fill up these - # values with garbage again. Thus we also have to remove the proper - # amount of bytes from the sent packets so the comparson will work. - if not with_last_be and dw != 8: - # Modulo operation which returns the divisor instead of zero. - def upmod(a, b): - return b if a % b == 0 else a % b - - for (packet, recvd_packet) in zip(packets, recvd_packets): - # How many bytes of the header have to be interleaved with the - # first data word on the bus. - header_leftover = packet_header_length % (dw // 8) - - # If the last word of our data would fit together with the - # header_leftover bytes in a single bus word, all data (plus - # some trailing garbage) will arrive. Otherwise, some data bytes - # will be missing. - if header_leftover != 0 and \ - header_leftover + upmod(len(packet.data), dw // 8) <= (dw // 8): - # The entire data will arrive, plus some trailing - # garbage. Remove that. - garbage_bytes = -len(packet.data) % (dw // 8) - recvd_packet.data = recvd_packet.data[:-garbage_bytes] - else: - # header_leftover bytes in received data have been replaced - # with garbage. Remove these bytes from the received and - # sent data. - recvd_packet.data = recvd_packet.data[:-header_leftover] - packet.data = packet.data[:len(recvd_packet.data)] - - self.assertTrue(compare_packets(packets, recvd_packets)) + run_simulation(dut, [generator(dut), checker(dut)]) + self.assertEqual(dut.header_errors, 0) + self.assertEqual(dut.data_errors, 0) + self.assertEqual(dut.last_errors, 0) def test_8bit_loopback(self): - for seed in range(42, 48): - with self.subTest(seed=seed): - self.loopback_test(dw=8, seed=seed) - - def test_8bit_loopback_last_be(self): - for seed in range(42, 48): - with self.subTest(seed=seed): - self.loopback_test(dw=8, seed=seed, with_last_be=True) + self.loopback_test(dw=8) def test_32bit_loopback(self): - for seed in range(42, 48): - with self.subTest(seed=seed): - self.loopback_test(dw=32, seed=seed) - - def test_32bit_loopback_last_be(self): - for seed in range(42, 48): - with self.subTest(seed=seed): - self.loopback_test(dw=32, seed=seed, with_last_be=True) + self.loopback_test(dw=32) def test_64bit_loopback(self): - for seed in range(42, 48): - with self.subTest(seed=seed): - self.loopback_test(dw=64, seed=seed) - - def test_64bit_loopback_last_be(self): - for seed in range(42, 48): - with self.subTest(seed=seed): - self.loopback_test(dw=64, seed=seed, with_last_be=True) + self.loopback_test(dw=64) def test_128bit_loopback(self): - for seed in range(42, 48): - with self.subTest(seed=seed): - self.loopback_test(dw=128, seed=seed) - - def test_128bit_loopback_last_be(self): - for seed in range(42, 48): - with self.subTest(seed=seed): - self.loopback_test(dw=128, seed=seed, with_last_be=True) + self.loopback_test(dw=128) diff --git a/test/test_packet2.py b/test/test_packet2.py new file mode 100644 index 000000000..bd315d4d7 --- /dev/null +++ b/test/test_packet2.py @@ -0,0 +1,188 @@ +# +# This file is part of LiteX. +# +# Copyright (c) 2021 Leon Schuermann +# SPDX-License-Identifier: BSD-2-Clause + +import unittest +import random + +from migen import * + +from litex.soc.interconnect.stream import * +from litex.soc.interconnect.packet import * + +from .test_stream import StreamPacket, stream_inserter, stream_collector, compare_packets + +def mask_last_be(dw, data, last_be): + masked_data = 0 + + for byte in range(dw // 8): + if 2**byte > last_be: + break + masked_data |= data & (0xFF << (byte * 8)) + + return masked_data + +class TestPacket(unittest.TestCase): + def loopback_test(self, dw, seed=42, with_last_be=False, debug_print=False): + # Independent random number generator to ensure we're the + # stream_inserter and stream_collectors still have + # reproducible behavior independent of the headers + prng = random.Random(seed + 5) + + # Generate a random number of differently sized header fields + nheader_fields = prng.randrange(16) + i = 0 + packet_header_length = 0 + packet_header_fields = {} + while packet_header_length < dw // 8 or i < nheader_fields: + # Header field size can be 1, 2, 4, 8, 16 bytes + field_length = 2**prng.randrange(5) + packet_header_fields["field{}_{}b".format(i, field_length * 8)] = \ + HeaderField(packet_header_length, 0, field_length * 8) + packet_header_length += field_length + i += 1 + + packet_header = Header( + fields = packet_header_fields, + length = packet_header_length, + swap_field_bytes = bool(prng.getrandbits(1))) + + def packet_description(dw): + param_layout = packet_header.get_layout() + payload_layout = [("data", dw)] + + if with_last_be: + payload_layout += [("last_be", dw // 8)] + + return EndpointDescription(payload_layout, param_layout) + + def raw_description(dw): + payload_layout = [("data", dw)] + + if with_last_be: + payload_layout += [("last_be", dw // 8)] + + return EndpointDescription(payload_layout) + + # Prepare packets + npackets = 32 + packets = [] + for n in range(npackets): + header = {} + for name, headerfield in packet_header_fields.items(): + header[name] = prng.randrange(2**headerfield.width) + datas = [prng.randrange(2**8) for _ in range(prng.randrange(dw - 1) + 1)] + packets.append(StreamPacket(datas, header)) + + class DUT(Module): + def __init__(self): + self.submodules.packetizer = Packetizer( + packet_description(dw), + raw_description(dw), + packet_header, + ) + self.submodules.depacketizer = Depacketizer( + raw_description(dw), + packet_description(dw), + packet_header, + ) + self.comb += self.packetizer.source.connect(self.depacketizer.sink) + self.sink, self.source = self.packetizer.sink, self.depacketizer.source + + dut = DUT() + recvd_packets = [] + run_simulation( + dut, + [ + stream_inserter( + dut.sink, + src=packets, + seed=seed, + debug_print=debug_print, + valid_rand=50, + ), + stream_collector( + dut.source, + dest=recvd_packets, + expect_npackets=npackets, + seed=seed, + debug_print=debug_print, + ready_rand=50, + ), + ], + ) + + # When we don't have a last_be signal, the Packetizer will simply throw + # away the partial bus word. The Depacketizer will then fill up these + # values with garbage again. Thus we also have to remove the proper + # amount of bytes from the sent packets so the comparson will work. + if not with_last_be and dw != 8: + # Modulo operation which returns the divisor instead of zero. + def upmod(a, b): + return b if a % b == 0 else a % b + + for (packet, recvd_packet) in zip(packets, recvd_packets): + # How many bytes of the header have to be interleaved with the + # first data word on the bus. + header_leftover = packet_header_length % (dw // 8) + + # If the last word of our data would fit together with the + # header_leftover bytes in a single bus word, all data (plus + # some trailing garbage) will arrive. Otherwise, some data bytes + # will be missing. + if header_leftover != 0 and \ + header_leftover + upmod(len(packet.data), dw // 8) <= (dw // 8): + # The entire data will arrive, plus some trailing + # garbage. Remove that. + garbage_bytes = -len(packet.data) % (dw // 8) + recvd_packet.data = recvd_packet.data[:-garbage_bytes] + else: + # header_leftover bytes in received data have been replaced + # with garbage. Remove these bytes from the received and + # sent data. + recvd_packet.data = recvd_packet.data[:-header_leftover] + packet.data = packet.data[:len(recvd_packet.data)] + + self.assertTrue(compare_packets(packets, recvd_packets)) + + def test_8bit_loopback(self): + for seed in range(42, 48): + with self.subTest(seed=seed): + self.loopback_test(dw=8, seed=seed) + + def test_8bit_loopback_last_be(self): + for seed in range(42, 48): + with self.subTest(seed=seed): + self.loopback_test(dw=8, seed=seed, with_last_be=True) + + def test_32bit_loopback(self): + for seed in range(42, 48): + with self.subTest(seed=seed): + self.loopback_test(dw=32, seed=seed) + + def test_32bit_loopback_last_be(self): + for seed in range(42, 48): + with self.subTest(seed=seed): + self.loopback_test(dw=32, seed=seed, with_last_be=True) + + def test_64bit_loopback(self): + for seed in range(42, 48): + with self.subTest(seed=seed): + self.loopback_test(dw=64, seed=seed) + + def test_64bit_loopback_last_be(self): + for seed in range(42, 48): + with self.subTest(seed=seed): + self.loopback_test(dw=64, seed=seed, with_last_be=True) + + def test_128bit_loopback(self): + for seed in range(42, 48): + with self.subTest(seed=seed): + self.loopback_test(dw=128, seed=seed) + + def test_128bit_loopback_last_be(self): + for seed in range(42, 48): + with self.subTest(seed=seed): + self.loopback_test(dw=128, seed=seed, with_last_be=True) diff --git a/test/test_stream.py b/test/test_stream.py index 3ef4f4703..7da893850 100644 --- a/test/test_stream.py +++ b/test/test_stream.py @@ -1,335 +1,52 @@ +# +# This file is part of LiteX. +# # Copyright (c) 2020 Florent Kermarrec # SPDX-License-Identifier: BSD-2-Clause import unittest import random -import itertools -import sys from migen import * from litex.soc.interconnect.stream import * -# Function to iterate over chunks of data, from -# https://docs.python.org/3/library/itertools.html#itertools-recipes -def grouper(iterable, n, fillvalue=None): - "Collect data into fixed-length chunks or blocks" - # grouper('ABCDEFG', 3, 'x') --> ABC DEF Gxx" - args = [iter(iterable)] * n - return itertools.zip_longest(*args, fillvalue=fillvalue) - -class StreamPacket: - def __init__(self, data, params={}): - # Data must be a list of bytes - assert type(data) == list - for b in data: - assert type(b) == int and b >= 0 and b < 256 - - # Params must be a dictionary of strings mapping to integers - assert type(params) == dict - for param_key, param_value in params.items(): - assert type(param_key) == str - assert type(param_value) == int - - self.data = data - self.params = params - - def compare(self, other, quiet=True, output_target=sys.stdout): - if len(self.data) != len(other.data): - if not quiet: - print("Length mismatch in number of received bytes of packet:" \ - " {} {}".format(len(self.data), len(other.data)), - file=sys.stdout) - return False - - - for nbyte, (byte_a, byte_b) in enumerate(zip(self.data, other.data)): - if byte_a != byte_b: - if not quiet: - print("Mismatch between sent and received bytes {}: " \ - "0x{:02x} 0x{:02x}".format(nbyte, byte_a, byte_b), - file=sys.stdout) - return False - - if set(self.params.keys()) != set(other.params.keys()): - if not quiet: - print("Sent and received packets have different param fields:" \ - " {} {}".format(self.params.keys(), other.params.keys()), - file=sys.stdout) - return False - - for param_name, self_param_value in self.params.items(): - other_param_value = other.params[param_name] - if self_param_value != other_param_value: - if not quiet: - print("Sent and received packets have different value for" \ - " param signal \"{}\": 0x{:x} 0x{:x}".format( - param_name, - self_param_value, - other_param_value), - file=sys.stdout) - return False - - return True - -def stream_inserter( - sink, - src, - seed=42, - valid_rand=50, - debug_print=False, - broken_8bit_last_be=True): - """Insert a list of packets of bytes on to the stream interface `sink`. If - `sink` has a `last_be` signal, that is set accordingly. - - """ - - prng = random.Random(seed) - - # Extract the data width from the provided sink Endpoint - dw = len(sink.data) - - # Make sure dw is evenly divisible by 8 as the logic below relies on - # that. Also, last_be wouldn't make much sense otherwise. - assert dw % 8 == 0 - - # If a last_be signal is provided, it must contain one bit per byte of data, - # i.e. be dw // 8 long. - if hasattr(sink, "last_be"): - assert dw // 8 == len(sink.last_be) - - # src is a list of lists. Each list represents a packet of bytes. Send each - # packet over the bus. - for pi, packet in enumerate(src): - assert len(packet.data) > 0, "Packets of length 0 are not compatible " \ - "with the ready/valid stream interface" - - # Each packet is a list. We must send dw // 8 bytes at a time. Use the - # grouper method to get a chunked iterator over the packet bytes and - # shift them to their correct position. Use a random filler byte to - # complete a bus word. - words = [] - for chunk in grouper(packet.data, dw // 8, prng.randrange(256)): - word = 0 - for i, b in enumerate(chunk): - assert b >= 0 and b < 256 - word |= b << (i * 8) - words += [word] - - if hasattr(sink, "last_be"): - encoded_last_be = Constant( - 1 << ((len(packet.data) - 1) % (dw // 8)), - bits_sign=len(sink.last_be) - ) - - # In legacy code for 8bit data paths last_be might not be set - # properly: while last_be should always be equal to last for 8bit - # data paths, if new code interacts with old code which is not yet - # last_be aware, it might always be deasserted. If - # broken_8bit_last_be is set and we have an 8bit data path, randomly - # set last_be to either one or zero to check whether the DUT handles - # these cases properly. - if broken_8bit_last_be and dw == 8: - encoded_last_be = Constant(prng.randrange(2), bits_sign=1) - - # At the very beginning of the packet transmission, set the param - # signals - for param_signal, param_value in packet.params.items(): - yield getattr(sink, param_signal).eq(param_value) - - for i, word in enumerate(words): - last = i == len(words) - 1 - - # Place the word on the bus, if its the last word set last and - # last_be accordingly and finally set sink to valid - yield sink.data.eq(word) - yield sink.last.eq(last) - if hasattr(sink, "last_be"): - if last: - yield sink.last_be.eq(encoded_last_be) - else: - yield sink.last_be.eq(0) - yield sink.valid.eq(1) - yield - - # Wait until the sink has become ready for one clock cycle - while not (yield sink.ready): - yield - - # Set sink to not valid for a random amount of time - yield sink.valid.eq(0) - while prng.randrange(100) < valid_rand: - yield - - # Okay, we've transmitted a packet. We must set sink.valid to false, for - # good measure clear all other signals as well. We don't explicitly - # yield, given a there might be a new packet waiting already. - yield sink.data.eq(0) - yield sink.last.eq(0) - if hasattr(sink, "last_be"): - yield sink.last_be.eq(0) - for param_signal in packet.params.keys(): - yield getattr(sink, param_signal).eq(0) - yield sink.valid.eq(0) - - if debug_print: - print("Sent packet {}.".format(pi), file=sys.stderr) - - # All packets have been transmitted. sink.valid has already been - # deasserted, yield once to properly apply that value. - yield - -def stream_collector( - source, - dest=[], - expect_npackets=None, - seed=42, - ready_rand=50, - debug_print=False): - """Consume some packets of bytes from the stream interface - `source`. If `source` has a `last_be` signal, that is respected - properly. - - """ - - prng = random.Random(seed) - - # Extract the data width from the provided source endpoint - dw = len(source.data) - - # Make sure dw is evenly divisible by 8 as the logic below relies on - # that. Also, last_be wouldn't make much sense otherwise. - assert dw % 8 == 0 - - # If a last_be signal is provided, it must contain one bit per byte of data, - # i.e. be dw // 8 long. - if hasattr(source, "last_be"): - assert dw // 8 == len(source.last_be) - - # Extract "param_signals" from the source Endpoint. They are extracted on - # the first valid word of a packet. If dest will be a list of tuples with - # data and param signals if there are any, otherwise just a list of lists. - param_signals = [ - signal_name for signal_name, _, _ in source.param.layout - ] if hasattr(source, "param") else [] - - # Loop for collecting individual packets, separated by source.last - while expect_npackets == None or len(dest) < expect_npackets: - # Buffer for the current packet - collected_bytes = [] - param_signal_states = {} - - # Iterate until "last" has been seen. That concludes the end of a bus - # transaction / packet. - read_last = False - first_word = True - while not read_last: - # We are ready to accept another bus word - yield source.ready.eq(1) - yield - - # Wait for data to become valid - while (yield source.valid) == 0: - yield - - # Data is now valid, read it byte by byte - data = yield source.data - for byte in range(dw // 8): - if (yield source.last) == 1: - read_last = True - if hasattr(source, "last_be") and \ - 2**byte > (yield source.last_be): - break - collected_bytes += [((data >> (byte * 8)) & 0xFF)] - - # Also, if this is the first loop iteration, latch all param signals - for param_signal in param_signals: - param_signal_states[param_signal] = \ - yield getattr(source, param_signal) - - # Set source to not valid for a random amount of time - yield source.ready.eq(0) - while prng.randrange(100) < ready_rand: - yield - - # This is no longer the first loop iteration - first_word = False - - # A full packet has been read. Append it to dest. - dest += [StreamPacket(collected_bytes, param_signal_states)] - if debug_print: - print("Received packet {}.".format(len(dest) - 1), file=sys.stderr) - -def generate_test_packets(npackets, seed=42): - # Generate a number of last-terminated bus transaction byte contents (dubbed - # packets) - prng = random.Random(42) - - packets = [] - for _ in range(npackets): - # With a random number of bytes from [1, 1024) - values = [] - for _ in range(prng.randrange(1023) + 1): - # With random values from [0, 256). - values += [prng.randrange(256)] - packets += [StreamPacket(values)] - - return packets - -def compare_packets(packets_a, packets_b): - if len(packets_a) != len(packets_b): - print("Length mismatch in number of received packets: {} {}" - .format(len(packets_a), len(packets_b)), file=sys.stderr) - return False - - for npacket, (packet_a, packet_b) in enumerate(zip(packets_a, packets_b)): - if not packet_a.compare(packet_b): - print("Error in packet", npacket) - packet_a.compare(packet_b, quiet=False) - return False - - return True class TestStream(unittest.TestCase): - def pipe_test(self, dut, seed=42, npackets=64, debug_print=False): - # Get some data to test with - packets = generate_test_packets(npackets, seed=seed) + def pipe_test(self, dut): + prng = random.Random(42) + def generator(dut, valid_rand=90): + for data in range(128): + yield dut.sink.valid.eq(1) + yield dut.sink.data.eq(data) + yield + while (yield dut.sink.ready) == 0: + yield + yield dut.sink.valid.eq(0) + while prng.randrange(100) < valid_rand: + yield - # Buffer for received packets (filled by collector) - recvd_packets = [] - - run_simulation( - dut, - [ - stream_inserter( - dut.sink, - src=packets, - debug_print=debug_print, - seed=seed, - ), - stream_collector( - dut.source, - dest=recvd_packets, - expect_npackets=npackets, - debug_print=debug_print, - seed=seed, - ), - ], - ) - self.assertTrue(compare_packets(packets, recvd_packets)) + def checker(dut, ready_rand=90): + dut.errors = 0 + for data in range(128): + yield dut.source.ready.eq(0) + yield + while (yield dut.source.valid) == 0: + yield + while prng.randrange(100) < ready_rand: + yield + yield dut.source.ready.eq(1) + yield + if ((yield dut.source.data) != data): + dut.errors += 1 + yield + run_simulation(dut, [generator(dut), checker(dut)]) + self.assertEqual(dut.errors, 0) def test_pipe_valid(self): - # PipeValid either connects the entire payload or not. Thus we don't - # need to test for 8bit support or a missing last_be signal - # specifically. This test does however ensure that last_be will continue - # to be respected in the future. - dut = PipeValid([("data", 32), ("last_be", 4)]) + dut = PipeValid([("data", 8)]) self.pipe_test(dut) def test_pipe_ready(self): - # PipeReady either connects the entire stream Endpoint or not. Thus we - # don't need to test for 8bit support or a missing last_be signal - # specifically. This test does however ensure that last_be will continue - # to be respected in the future. - dut = PipeReady([("data", 64), ("last_be", 8)]) + dut = PipeReady([("data", 8)]) self.pipe_test(dut) diff --git a/test/test_stream2.py b/test/test_stream2.py new file mode 100644 index 000000000..609a4c976 --- /dev/null +++ b/test/test_stream2.py @@ -0,0 +1,339 @@ +# +# This file is part of LiteX. +# +# Copyright (c) 2021 Leon Schuermann +# SPDX-License-Identifier: BSD-2-Clause + +import unittest +import unittest +import random +import itertools +import sys + +from migen import * + +from litex.soc.interconnect.stream import * + +# Function to iterate over chunks of data, from +# https://docs.python.org/3/library/itertools.html#itertools-recipes +def grouper(iterable, n, fillvalue=None): + "Collect data into fixed-length chunks or blocks" + # grouper('ABCDEFG', 3, 'x') --> ABC DEF Gxx" + args = [iter(iterable)] * n + return itertools.zip_longest(*args, fillvalue=fillvalue) + +class StreamPacket: + def __init__(self, data, params={}): + # Data must be a list of bytes + assert type(data) == list + for b in data: + assert type(b) == int and b >= 0 and b < 256 + + # Params must be a dictionary of strings mapping to integers + assert type(params) == dict + for param_key, param_value in params.items(): + assert type(param_key) == str + assert type(param_value) == int + + self.data = data + self.params = params + + def compare(self, other, quiet=True, output_target=sys.stdout): + if len(self.data) != len(other.data): + if not quiet: + print("Length mismatch in number of received bytes of packet:" \ + " {} {}".format(len(self.data), len(other.data)), + file=sys.stdout) + return False + + + for nbyte, (byte_a, byte_b) in enumerate(zip(self.data, other.data)): + if byte_a != byte_b: + if not quiet: + print("Mismatch between sent and received bytes {}: " \ + "0x{:02x} 0x{:02x}".format(nbyte, byte_a, byte_b), + file=sys.stdout) + return False + + if set(self.params.keys()) != set(other.params.keys()): + if not quiet: + print("Sent and received packets have different param fields:" \ + " {} {}".format(self.params.keys(), other.params.keys()), + file=sys.stdout) + return False + + for param_name, self_param_value in self.params.items(): + other_param_value = other.params[param_name] + if self_param_value != other_param_value: + if not quiet: + print("Sent and received packets have different value for" \ + " param signal \"{}\": 0x{:x} 0x{:x}".format( + param_name, + self_param_value, + other_param_value), + file=sys.stdout) + return False + + return True + +def stream_inserter( + sink, + src, + seed=42, + valid_rand=50, + debug_print=False, + broken_8bit_last_be=True): + """Insert a list of packets of bytes on to the stream interface `sink`. If + `sink` has a `last_be` signal, that is set accordingly. + + """ + + prng = random.Random(seed) + + # Extract the data width from the provided sink Endpoint + dw = len(sink.data) + + # Make sure dw is evenly divisible by 8 as the logic below relies on + # that. Also, last_be wouldn't make much sense otherwise. + assert dw % 8 == 0 + + # If a last_be signal is provided, it must contain one bit per byte of data, + # i.e. be dw // 8 long. + if hasattr(sink, "last_be"): + assert dw // 8 == len(sink.last_be) + + # src is a list of lists. Each list represents a packet of bytes. Send each + # packet over the bus. + for pi, packet in enumerate(src): + assert len(packet.data) > 0, "Packets of length 0 are not compatible " \ + "with the ready/valid stream interface" + + # Each packet is a list. We must send dw // 8 bytes at a time. Use the + # grouper method to get a chunked iterator over the packet bytes and + # shift them to their correct position. Use a random filler byte to + # complete a bus word. + words = [] + for chunk in grouper(packet.data, dw // 8, prng.randrange(256)): + word = 0 + for i, b in enumerate(chunk): + assert b >= 0 and b < 256 + word |= b << (i * 8) + words += [word] + + if hasattr(sink, "last_be"): + encoded_last_be = Constant( + 1 << ((len(packet.data) - 1) % (dw // 8)), + bits_sign=len(sink.last_be) + ) + + # In legacy code for 8bit data paths last_be might not be set + # properly: while last_be should always be equal to last for 8bit + # data paths, if new code interacts with old code which is not yet + # last_be aware, it might always be deasserted. If + # broken_8bit_last_be is set and we have an 8bit data path, randomly + # set last_be to either one or zero to check whether the DUT handles + # these cases properly. + if broken_8bit_last_be and dw == 8: + encoded_last_be = Constant(prng.randrange(2), bits_sign=1) + + # At the very beginning of the packet transmission, set the param + # signals + for param_signal, param_value in packet.params.items(): + yield getattr(sink, param_signal).eq(param_value) + + for i, word in enumerate(words): + last = i == len(words) - 1 + + # Place the word on the bus, if its the last word set last and + # last_be accordingly and finally set sink to valid + yield sink.data.eq(word) + yield sink.last.eq(last) + if hasattr(sink, "last_be"): + if last: + yield sink.last_be.eq(encoded_last_be) + else: + yield sink.last_be.eq(0) + yield sink.valid.eq(1) + yield + + # Wait until the sink has become ready for one clock cycle + while not (yield sink.ready): + yield + + # Set sink to not valid for a random amount of time + yield sink.valid.eq(0) + while prng.randrange(100) < valid_rand: + yield + + # Okay, we've transmitted a packet. We must set sink.valid to false, for + # good measure clear all other signals as well. We don't explicitly + # yield, given a there might be a new packet waiting already. + yield sink.data.eq(0) + yield sink.last.eq(0) + if hasattr(sink, "last_be"): + yield sink.last_be.eq(0) + for param_signal in packet.params.keys(): + yield getattr(sink, param_signal).eq(0) + yield sink.valid.eq(0) + + if debug_print: + print("Sent packet {}.".format(pi), file=sys.stderr) + + # All packets have been transmitted. sink.valid has already been + # deasserted, yield once to properly apply that value. + yield + +def stream_collector( + source, + dest=[], + expect_npackets=None, + seed=42, + ready_rand=50, + debug_print=False): + """Consume some packets of bytes from the stream interface + `source`. If `source` has a `last_be` signal, that is respected + properly. + + """ + + prng = random.Random(seed) + + # Extract the data width from the provided source endpoint + dw = len(source.data) + + # Make sure dw is evenly divisible by 8 as the logic below relies on + # that. Also, last_be wouldn't make much sense otherwise. + assert dw % 8 == 0 + + # If a last_be signal is provided, it must contain one bit per byte of data, + # i.e. be dw // 8 long. + if hasattr(source, "last_be"): + assert dw // 8 == len(source.last_be) + + # Extract "param_signals" from the source Endpoint. They are extracted on + # the first valid word of a packet. If dest will be a list of tuples with + # data and param signals if there are any, otherwise just a list of lists. + param_signals = [ + signal_name for signal_name, _, _ in source.param.layout + ] if hasattr(source, "param") else [] + + # Loop for collecting individual packets, separated by source.last + while expect_npackets == None or len(dest) < expect_npackets: + # Buffer for the current packet + collected_bytes = [] + param_signal_states = {} + + # Iterate until "last" has been seen. That concludes the end of a bus + # transaction / packet. + read_last = False + first_word = True + while not read_last: + # We are ready to accept another bus word + yield source.ready.eq(1) + yield + + # Wait for data to become valid + while (yield source.valid) == 0: + yield + + # Data is now valid, read it byte by byte + data = yield source.data + for byte in range(dw // 8): + if (yield source.last) == 1: + read_last = True + if hasattr(source, "last_be") and \ + 2**byte > (yield source.last_be): + break + collected_bytes += [((data >> (byte * 8)) & 0xFF)] + + # Also, if this is the first loop iteration, latch all param signals + for param_signal in param_signals: + param_signal_states[param_signal] = \ + yield getattr(source, param_signal) + + # Set source to not valid for a random amount of time + yield source.ready.eq(0) + while prng.randrange(100) < ready_rand: + yield + + # This is no longer the first loop iteration + first_word = False + + # A full packet has been read. Append it to dest. + dest += [StreamPacket(collected_bytes, param_signal_states)] + if debug_print: + print("Received packet {}.".format(len(dest) - 1), file=sys.stderr) + +def generate_test_packets(npackets, seed=42): + # Generate a number of last-terminated bus transaction byte contents (dubbed + # packets) + prng = random.Random(42) + + packets = [] + for _ in range(npackets): + # With a random number of bytes from [1, 1024) + values = [] + for _ in range(prng.randrange(1023) + 1): + # With random values from [0, 256). + values += [prng.randrange(256)] + packets += [StreamPacket(values)] + + return packets + +def compare_packets(packets_a, packets_b): + if len(packets_a) != len(packets_b): + print("Length mismatch in number of received packets: {} {}" + .format(len(packets_a), len(packets_b)), file=sys.stderr) + return False + + for npacket, (packet_a, packet_b) in enumerate(zip(packets_a, packets_b)): + if not packet_a.compare(packet_b): + print("Error in packet", npacket) + packet_a.compare(packet_b, quiet=False) + return False + + return True + +class TestStream(unittest.TestCase): + def pipe_test(self, dut, seed=42, npackets=64, debug_print=False): + # Get some data to test with + packets = generate_test_packets(npackets, seed=seed) + + # Buffer for received packets (filled by collector) + recvd_packets = [] + + run_simulation( + dut, + [ + stream_inserter( + dut.sink, + src=packets, + debug_print=debug_print, + seed=seed, + ), + stream_collector( + dut.source, + dest=recvd_packets, + expect_npackets=npackets, + debug_print=debug_print, + seed=seed, + ), + ], + ) + self.assertTrue(compare_packets(packets, recvd_packets)) + + def test_pipe_valid(self): + # PipeValid either connects the entire payload or not. Thus we don't + # need to test for 8bit support or a missing last_be signal + # specifically. This test does however ensure that last_be will continue + # to be respected in the future. + dut = PipeValid([("data", 32), ("last_be", 4)]) + self.pipe_test(dut) + + def test_pipe_ready(self): + # PipeReady either connects the entire stream Endpoint or not. Thus we + # don't need to test for 8bit support or a missing last_be signal + # specifically. This test does however ensure that last_be will continue + # to be respected in the future. + dut = PipeReady([("data", 64), ("last_be", 8)]) + self.pipe_test(dut)