diff --git a/litedram/core/bandwidth.py b/litedram/core/bandwidth.py index 0ac3f22..a39070d 100644 --- a/litedram/core/bandwidth.py +++ b/litedram/core/bandwidth.py @@ -14,8 +14,8 @@ from litex.soc.interconnect.csr import * class Bandwidth(Module, AutoCSR): def __init__(self, cmd, data_width, period_bits=24): self.update = CSR() - self.nreads = CSRStatus(period_bits) - self.nwrites = CSRStatus(period_bits) + self.nreads = CSRStatus(period_bits + 1) + self.nwrites = CSRStatus(period_bits + 1) self.data_width = CSRStatus(bits_for(data_width), reset=data_width) # # # @@ -33,17 +33,22 @@ class Bandwidth(Module, AutoCSR): counter = Signal(period_bits) period = Signal() - nreads = Signal(period_bits) - nwrites = Signal(period_bits) - nreads_r = Signal(period_bits) - nwrites_r = Signal(period_bits) + nreads = Signal(period_bits + 1) + nwrites = Signal(period_bits + 1) + nreads_r = Signal(period_bits + 1) + nwrites_r = Signal(period_bits + 1) self.sync += [ Cat(counter, period).eq(counter + 1), If(period, nreads_r.eq(nreads), nwrites_r.eq(nwrites), nreads.eq(0), - nwrites.eq(0) + nwrites.eq(0), + # don't miss command if there is one on period boundary + If(cmd_valid & cmd_ready, + If(cmd_is_read, nreads.eq(1)), + If(cmd_is_write, nwrites.eq(1)), + ) ).Elif(cmd_valid & cmd_ready, If(cmd_is_read, nreads.eq(nreads + 1)), If(cmd_is_write, nwrites.eq(nwrites + 1)), diff --git a/test/test_bandwidth.py b/test/test_bandwidth.py new file mode 100644 index 0000000..ef5b2ba --- /dev/null +++ b/test/test_bandwidth.py @@ -0,0 +1,242 @@ +# This file is Copyright (c) 2020 Antmicro +# License: BSD + +import random +import unittest +import itertools +import collections + +from migen import * + +from litex.soc.interconnect import stream + +from litedram.common import * +from litedram.core.bandwidth import Bandwidth + +from test.common import timeout_generator, CmdRequestRWDriver + + +class BandwidthDUT(Module): + def __init__(self, data_width=8, **kwargs): + a, ba = 13, 3 + self.cmd = stream.Endpoint(cmd_request_rw_layout(a, ba)) + self.submodules.bandwidth = Bandwidth(self.cmd, data_width, **kwargs) + + +class CommandDriver: + def __init__(self, cmd, cmd_options=None): + self.cmd = cmd + self.driver = CmdRequestRWDriver(cmd) + self.cmd_counts = collections.defaultdict(int) + + @passive + def random_generator(self, random_ready_max=20, commands=None): + commands = commands or ["read", "write"] + prng = random.Random(42) + + while True: + # generate random command + command = prng.choice(commands) + yield from getattr(self.driver, command)() + yield + # wait some times before it becomes ready + for _ in range(prng.randint(0, random_ready_max)): + yield + yield self.cmd.ready.eq(1) + yield + self.cmd_counts[command] += 1 + yield self.cmd.ready.eq(0) + # disable command + yield from self.driver.nop() + yield + + @passive + def timeline_generator(self, timeline): + # timeline: an iterator of tuples (cycle, command) + sim_cycle = 0 + for cycle, command in timeline: + assert cycle >= sim_cycle + while sim_cycle != cycle: + sim_cycle += 1 + yield + # set the command + yield from getattr(self.driver, command)() + yield self.cmd.ready.eq(1) + self.cmd_counts[command] += 1 + # advance 1 cycle + yield + sim_cycle += 1 + # clear state + yield self.cmd.ready.eq(0) + yield from self.driver.nop() + + +class TestBandwidth(unittest.TestCase): + def test_can_read_status_data_width(self): + # Verify that data width can be read from a CSR + def test(data_width): + def main_generator(dut): + yield + self.assertEqual((yield dut.bandwidth.data_width.status), data_width) + + dut = BandwidthDUT(data_width=data_width) + run_simulation(dut, main_generator(dut)) + + for data_width in [8, 16, 32, 64]: + with self.subTest(data_width=data_width): + test(data_width) + + def test_requires_update_to_copy_the_data(self): + # Verify that command counts are copied to CSRs only after `update` + def main_generator(dut): + nreads = (yield from dut.bandwidth.nreads.read()) + nwrites = (yield from dut.bandwidth.nwrites.read()) + self.assertEqual(nreads, 0) + self.assertEqual(nwrites, 0) + + # wait enough for the period to end + for _ in range(2**6): + yield + + nreads = (yield from dut.bandwidth.nreads.read()) + nwrites = (yield from dut.bandwidth.nwrites.read()) + self.assertEqual(nreads, 0) + self.assertEqual(nwrites, 0) + + # update register values + yield from dut.bandwidth.update.write(1) + + nreads = (yield from dut.bandwidth.nreads.read()) + nwrites = (yield from dut.bandwidth.nwrites.read()) + self.assertNotEqual((nreads, nwrites), (0, 0)) + + dut = BandwidthDUT(period_bits=6) + cmd_driver = CommandDriver(dut.cmd) + generators = [ + main_generator(dut), + cmd_driver.random_generator(), + ] + run_simulation(dut, generators) + + def test_correct_read_write_counts(self): + # Verify that the number of registered READ/WRITE commands is correct + results = {} + + def main_generator(dut): + # wait for the first period to end + for _ in range(2**8): + yield + yield from dut.bandwidth.update.write(1) + yield + results["nreads"] = (yield from dut.bandwidth.nreads.read()) + results["nwrites"] = (yield from dut.bandwidth.nwrites.read()) + + dut = BandwidthDUT(period_bits=8) + cmd_driver = CommandDriver(dut.cmd) + generators = [ + main_generator(dut), + cmd_driver.random_generator(), + ] + run_simulation(dut, generators) + + self.assertEqual(results["nreads"], cmd_driver.cmd_counts["read"]) + + def test_counts_read_write_only(self): + # Verify that only READ and WRITE commands are registered + results = {} + + def main_generator(dut): + # wait for the first period to end + for _ in range(2**8): + yield + yield from dut.bandwidth.update.write(1) + yield + results["nreads"] = (yield from dut.bandwidth.nreads.read()) + results["nwrites"] = (yield from dut.bandwidth.nwrites.read()) + + dut = BandwidthDUT(period_bits=8) + cmd_driver = CommandDriver(dut.cmd) + commands = ["read", "write", "activate", "precharge", "refresh"] + generators = [ + main_generator(dut), + cmd_driver.random_generator(commands=commands), + ] + run_simulation(dut, generators) + + self.assertEqual(results["nreads"], cmd_driver.cmd_counts["read"]) + + def test_correct_period_length(self): + # Verify that period length is correct by measuring time betwee CSR changes + period_bits = 5 + period = 2**period_bits + + n_per_period = {0: 3, 1: 6, 2: 9} + timeline = {} + for p, n in n_per_period.items(): + for i in range(n): + margin = 10 + timeline[period*p + margin + i] = "write" + + def main_generator(dut): + # keep the values always up to date + yield dut.bandwidth.update.re.eq(1) + + # wait until we have the data from 1st period + while (yield dut.bandwidth.nwrites.status) != 3: + yield + + # count time to next period + cycles = 0 + while (yield dut.bandwidth.nwrites.status) != 6: + cycles += 1 + yield + + self.assertEqual(cycles, period) + + dut = BandwidthDUT(period_bits=period_bits) + cmd_driver = CommandDriver(dut.cmd) + generators = [ + main_generator(dut), + cmd_driver.timeline_generator(timeline.items()), + timeout_generator(period * 3), + ] + run_simulation(dut, generators) + + def test_not_missing_commands_on_period_boundary(self): + # Verify that no data is lost in the cycle when new period starts + period_bits = 5 + period = 2**period_bits + + # start 10 cycles before period ends, end 10 cycles after it ends + base = period - 10 + nwrites = 20 + timeline = {base + i: "write" for i in range(nwrites)} + + def main_generator(dut): + # wait until 1st period ends (+ some margin) + for _ in range(period + 10): + yield + + # read the count from 1st period + yield from dut.bandwidth.update.write(1) + yield + nwrites_registered = (yield from dut.bandwidth.nwrites.read()) + + # wait until 2nd period ends + for _ in range(period): + yield + + # read the count from 1st period + yield from dut.bandwidth.update.write(1) + yield + nwrites_registered += (yield from dut.bandwidth.nwrites.read()) + + self.assertEqual(nwrites_registered, nwrites) + + dut = BandwidthDUT(period_bits=period_bits) + cmd_driver = CommandDriver(dut.cmd) + generators = [ + main_generator(dut), + cmd_driver.timeline_generator(timeline.items()), + ] + run_simulation(dut, generators)