Merge pull request #175 from antmicro/jboc/unit-tests-bandwidth

Add core.bandwidth tests
This commit is contained in:
enjoy-digital 2020-04-07 13:13:54 +02:00 committed by GitHub
commit 492b9fa209
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 254 additions and 7 deletions

View File

@ -14,8 +14,8 @@ from litex.soc.interconnect.csr import *
class Bandwidth(Module, AutoCSR): class Bandwidth(Module, AutoCSR):
def __init__(self, cmd, data_width, period_bits=24): def __init__(self, cmd, data_width, period_bits=24):
self.update = CSR() self.update = CSR()
self.nreads = CSRStatus(period_bits) self.nreads = CSRStatus(period_bits + 1)
self.nwrites = CSRStatus(period_bits) self.nwrites = CSRStatus(period_bits + 1)
self.data_width = CSRStatus(bits_for(data_width), reset=data_width) self.data_width = CSRStatus(bits_for(data_width), reset=data_width)
# # # # # #
@ -33,17 +33,22 @@ class Bandwidth(Module, AutoCSR):
counter = Signal(period_bits) counter = Signal(period_bits)
period = Signal() period = Signal()
nreads = Signal(period_bits) nreads = Signal(period_bits + 1)
nwrites = Signal(period_bits) nwrites = Signal(period_bits + 1)
nreads_r = Signal(period_bits) nreads_r = Signal(period_bits + 1)
nwrites_r = Signal(period_bits) nwrites_r = Signal(period_bits + 1)
self.sync += [ self.sync += [
Cat(counter, period).eq(counter + 1), Cat(counter, period).eq(counter + 1),
If(period, If(period,
nreads_r.eq(nreads), nreads_r.eq(nreads),
nwrites_r.eq(nwrites), nwrites_r.eq(nwrites),
nreads.eq(0), nreads.eq(0),
nwrites.eq(0) nwrites.eq(0),
# don't miss command if there is one on period boundary
If(cmd_valid & cmd_ready,
If(cmd_is_read, nreads.eq(1)),
If(cmd_is_write, nwrites.eq(1)),
)
).Elif(cmd_valid & cmd_ready, ).Elif(cmd_valid & cmd_ready,
If(cmd_is_read, nreads.eq(nreads + 1)), If(cmd_is_read, nreads.eq(nreads + 1)),
If(cmd_is_write, nwrites.eq(nwrites + 1)), If(cmd_is_write, nwrites.eq(nwrites + 1)),

242
test/test_bandwidth.py Normal file
View File

@ -0,0 +1,242 @@
# This file is Copyright (c) 2020 Antmicro <www.antmicro.com>
# License: BSD
import random
import unittest
import itertools
import collections
from migen import *
from litex.soc.interconnect import stream
from litedram.common import *
from litedram.core.bandwidth import Bandwidth
from test.common import timeout_generator, CmdRequestRWDriver
class BandwidthDUT(Module):
def __init__(self, data_width=8, **kwargs):
a, ba = 13, 3
self.cmd = stream.Endpoint(cmd_request_rw_layout(a, ba))
self.submodules.bandwidth = Bandwidth(self.cmd, data_width, **kwargs)
class CommandDriver:
def __init__(self, cmd, cmd_options=None):
self.cmd = cmd
self.driver = CmdRequestRWDriver(cmd)
self.cmd_counts = collections.defaultdict(int)
@passive
def random_generator(self, random_ready_max=20, commands=None):
commands = commands or ["read", "write"]
prng = random.Random(42)
while True:
# generate random command
command = prng.choice(commands)
yield from getattr(self.driver, command)()
yield
# wait some times before it becomes ready
for _ in range(prng.randint(0, random_ready_max)):
yield
yield self.cmd.ready.eq(1)
yield
self.cmd_counts[command] += 1
yield self.cmd.ready.eq(0)
# disable command
yield from self.driver.nop()
yield
@passive
def timeline_generator(self, timeline):
# timeline: an iterator of tuples (cycle, command)
sim_cycle = 0
for cycle, command in timeline:
assert cycle >= sim_cycle
while sim_cycle != cycle:
sim_cycle += 1
yield
# set the command
yield from getattr(self.driver, command)()
yield self.cmd.ready.eq(1)
self.cmd_counts[command] += 1
# advance 1 cycle
yield
sim_cycle += 1
# clear state
yield self.cmd.ready.eq(0)
yield from self.driver.nop()
class TestBandwidth(unittest.TestCase):
def test_can_read_status_data_width(self):
# Verify that data width can be read from a CSR
def test(data_width):
def main_generator(dut):
yield
self.assertEqual((yield dut.bandwidth.data_width.status), data_width)
dut = BandwidthDUT(data_width=data_width)
run_simulation(dut, main_generator(dut))
for data_width in [8, 16, 32, 64]:
with self.subTest(data_width=data_width):
test(data_width)
def test_requires_update_to_copy_the_data(self):
# Verify that command counts are copied to CSRs only after `update`
def main_generator(dut):
nreads = (yield from dut.bandwidth.nreads.read())
nwrites = (yield from dut.bandwidth.nwrites.read())
self.assertEqual(nreads, 0)
self.assertEqual(nwrites, 0)
# wait enough for the period to end
for _ in range(2**6):
yield
nreads = (yield from dut.bandwidth.nreads.read())
nwrites = (yield from dut.bandwidth.nwrites.read())
self.assertEqual(nreads, 0)
self.assertEqual(nwrites, 0)
# update register values
yield from dut.bandwidth.update.write(1)
nreads = (yield from dut.bandwidth.nreads.read())
nwrites = (yield from dut.bandwidth.nwrites.read())
self.assertNotEqual((nreads, nwrites), (0, 0))
dut = BandwidthDUT(period_bits=6)
cmd_driver = CommandDriver(dut.cmd)
generators = [
main_generator(dut),
cmd_driver.random_generator(),
]
run_simulation(dut, generators)
def test_correct_read_write_counts(self):
# Verify that the number of registered READ/WRITE commands is correct
results = {}
def main_generator(dut):
# wait for the first period to end
for _ in range(2**8):
yield
yield from dut.bandwidth.update.write(1)
yield
results["nreads"] = (yield from dut.bandwidth.nreads.read())
results["nwrites"] = (yield from dut.bandwidth.nwrites.read())
dut = BandwidthDUT(period_bits=8)
cmd_driver = CommandDriver(dut.cmd)
generators = [
main_generator(dut),
cmd_driver.random_generator(),
]
run_simulation(dut, generators)
self.assertEqual(results["nreads"], cmd_driver.cmd_counts["read"])
def test_counts_read_write_only(self):
# Verify that only READ and WRITE commands are registered
results = {}
def main_generator(dut):
# wait for the first period to end
for _ in range(2**8):
yield
yield from dut.bandwidth.update.write(1)
yield
results["nreads"] = (yield from dut.bandwidth.nreads.read())
results["nwrites"] = (yield from dut.bandwidth.nwrites.read())
dut = BandwidthDUT(period_bits=8)
cmd_driver = CommandDriver(dut.cmd)
commands = ["read", "write", "activate", "precharge", "refresh"]
generators = [
main_generator(dut),
cmd_driver.random_generator(commands=commands),
]
run_simulation(dut, generators)
self.assertEqual(results["nreads"], cmd_driver.cmd_counts["read"])
def test_correct_period_length(self):
# Verify that period length is correct by measuring time betwee CSR changes
period_bits = 5
period = 2**period_bits
n_per_period = {0: 3, 1: 6, 2: 9}
timeline = {}
for p, n in n_per_period.items():
for i in range(n):
margin = 10
timeline[period*p + margin + i] = "write"
def main_generator(dut):
# keep the values always up to date
yield dut.bandwidth.update.re.eq(1)
# wait until we have the data from 1st period
while (yield dut.bandwidth.nwrites.status) != 3:
yield
# count time to next period
cycles = 0
while (yield dut.bandwidth.nwrites.status) != 6:
cycles += 1
yield
self.assertEqual(cycles, period)
dut = BandwidthDUT(period_bits=period_bits)
cmd_driver = CommandDriver(dut.cmd)
generators = [
main_generator(dut),
cmd_driver.timeline_generator(timeline.items()),
timeout_generator(period * 3),
]
run_simulation(dut, generators)
def test_not_missing_commands_on_period_boundary(self):
# Verify that no data is lost in the cycle when new period starts
period_bits = 5
period = 2**period_bits
# start 10 cycles before period ends, end 10 cycles after it ends
base = period - 10
nwrites = 20
timeline = {base + i: "write" for i in range(nwrites)}
def main_generator(dut):
# wait until 1st period ends (+ some margin)
for _ in range(period + 10):
yield
# read the count from 1st period
yield from dut.bandwidth.update.write(1)
yield
nwrites_registered = (yield from dut.bandwidth.nwrites.read())
# wait until 2nd period ends
for _ in range(period):
yield
# read the count from 1st period
yield from dut.bandwidth.update.write(1)
yield
nwrites_registered += (yield from dut.bandwidth.nwrites.read())
self.assertEqual(nwrites_registered, nwrites)
dut = BandwidthDUT(period_bits=period_bits)
cmd_driver = CommandDriver(dut.cmd)
generators = [
main_generator(dut),
cmd_driver.timeline_generator(timeline.items()),
]
run_simulation(dut, generators)