soc/interconnect/axi: implement AXI Lite decoder

This commit is contained in:
Jędrzej Boczar 2020-07-21 14:25:24 +02:00
parent 214cfdcaeb
commit 3a08b21d44
2 changed files with 232 additions and 51 deletions

View File

@ -192,6 +192,7 @@ class AXILiteInterface:
def write(self, addr, data, strb=None):
if strb is None:
strb = 2**len(self.w.strb) - 1
# aw + w
yield self.aw.valid.eq(1)
yield self.aw.addr.eq(addr)
yield self.w.data.eq(data)
@ -201,9 +202,12 @@ class AXILiteInterface:
while not (yield self.aw.ready):
yield
yield self.aw.valid.eq(0)
yield self.aw.addr.eq(0)
while not (yield self.w.ready):
yield
yield self.w.valid.eq(0)
yield self.w.strb.eq(0)
# b
yield self.b.ready.eq(1)
while not (yield self.b.valid):
yield
@ -212,12 +216,14 @@ class AXILiteInterface:
return resp
def read(self, addr):
# ar
yield self.ar.valid.eq(1)
yield self.ar.addr.eq(addr)
yield
while not (yield self.ar.ready):
yield
yield self.ar.valid.eq(0)
# r
yield self.r.ready.eq(1)
while not (yield self.r.valid):
yield
@ -943,6 +949,7 @@ class AXILiteRequestCounter(Module):
self.full = full = Signal()
self.empty = empty = Signal()
self.stall = stall = Signal()
self.ready = self.empty
self.comb += [
full.eq(counter == max_requests - 1),
@ -994,15 +1001,15 @@ class AXILiteArbiter(Module):
self.comb += dest.eq(source)
# allow to change rr.grant only after all requests from a master have been responded to
self.submodules.wr_counter = wr_counter = AXILiteRequestCounter(
self.submodules.wr_lock = wr_lock = AXILiteRequestCounter(
request=target.aw.valid & target.aw.ready, response=target.b.valid & target.b.ready)
self.submodules.rd_counter = rd_counter = AXILiteRequestCounter(
self.submodules.rd_lock = rd_lock = AXILiteRequestCounter(
request=target.ar.valid & target.ar.ready, response=target.r.valid & target.r.ready)
# switch to next request only if there are no responses pending
self.comb += [
self.rr_write.ce.eq(~(target.aw.valid | target.w.valid | target.b.valid) & wr_counter.empty),
self.rr_read.ce.eq(~(target.ar.valid | target.r.valid) & rd_counter.empty),
self.rr_write.ce.eq(~(target.aw.valid | target.w.valid | target.b.valid) & wr_lock.ready),
self.rr_read.ce.eq(~(target.ar.valid | target.r.valid) & rd_lock.ready),
]
# connect bus requests to round-robin selectors
@ -1012,54 +1019,103 @@ class AXILiteArbiter(Module):
]
class AXILiteDecoder(Module):
    _doc_slaves = """
    slaves: [(decoder, slave), ...]
        List of slaves with address decoders, where `decoder` is a function:
            decoder(Signal(address_width - log2(data_width//8))) -> Signal(1)
        that returns 1 when the slave is selected and 0 otherwise.
    """.strip()

    __doc__ = """AXI Lite decoder

    Decode master access to particular slave based on its decoder function.

    {slaves}
    """.format(slaves=_doc_slaves)

    def __init__(self, master, slaves, register=False):
        # NOTE: `register` is currently unused. It is accepted so that callers written
        # against the previous `(master, slaves, register)` signature keep working.
        # TODO: either implement registered address decoding or deprecate the argument.

        # decoders operate on word addresses, so the byte-offset bits are dropped
        addr_shift = log2_int(master.data_width//8)
        n_slaves = len(slaves)

        channels = {
            "write": {"aw", "w", "b"},
            "read":  {"ar", "r"},
        }
        # reverse mapping: directions[channel] -> "write"/"read"
        directions = {ch: d for d, chs in channels.items() for ch in chs}

        def new_slave_sel():
            # one select bit per slave, tracked independently for each direction
            return {"write": Signal(n_slaves), "read": Signal(n_slaves)}

        slave_sel_dec = new_slave_sel()  # combinational output of the address decoders
        slave_sel_reg = new_slave_sel()  # selection held while responses are outstanding
        slave_sel     = new_slave_sel()  # selection actually used for routing

        # we need to hold the slave selected until all responses come back
        # TODO: check if this will break Timeout if a slave does not respond?
        #       should probably work correctly as it uses master signals
        # TODO: we could reuse arbiter counters
        locks = {
            "write": AXILiteRequestCounter(
                request=master.aw.valid & master.aw.ready,
                response=master.b.valid & master.b.ready),
            "read": AXILiteRequestCounter(
                request=master.ar.valid & master.ar.ready,
                response=master.r.valid & master.r.ready),
        }
        self.submodules += locks.values()

        def get_sig(interface, channel, name):
            return getattr(getattr(interface, channel), name)

        # # #

        # decode slave addresses (write selection from aw.addr, read selection from ar.addr)
        for i, (decoder, _) in enumerate(slaves):
            self.comb += [
                slave_sel_dec["write"][i].eq(decoder(master.aw.addr[addr_shift:])),
                slave_sel_dec["read"][i].eq(decoder(master.ar.addr[addr_shift:])),
            ]

        for rw, lock in locks.items():
            # change the current selection only when we've got all responses
            # (`lock.ready` is the request counter's `empty` flag)
            self.sync += If(lock.ready, slave_sel_reg[rw].eq(slave_sel_dec[rw]))
            # while no responses are pending, bypass the register and use the decoded
            # selection directly so a new request is not delayed by the register;
            # otherwise keep routing to the slave that was selected for the pending request
            self.comb += If(lock.ready,
                slave_sel[rw].eq(slave_sel_dec[rw])
            ).Else(
                slave_sel[rw].eq(slave_sel_reg[rw])
            )

        # connect master->slaves signals: payload signals are broadcast to every slave,
        # handshake signals (valid/ready) reach only the currently selected slave
        for i, (_, slave) in enumerate(slaves):
            for channel, name, direction in master.layout_flat():
                if direction != DIR_M_TO_S:
                    continue
                src = get_sig(master, channel, name)
                dst = get_sig(slave, channel, name)
                # mask master control signals depending on slave selection
                if name in ["valid", "ready"]:
                    src = src & slave_sel[directions[channel]][i]
                self.comb += dst.eq(src)

        # connect slave->master signals masking not selected slaves
        for channel, name, direction in master.layout_flat():
            if direction != DIR_S_TO_M:
                continue
            dst = get_sig(master, channel, name)
            masked = []
            for i, (_, slave) in enumerate(slaves):
                src = get_sig(slave, channel, name)
                # mask depending on channel: replicate the select bit over the signal width
                mask = Replicate(slave_sel[directions[channel]][i], len(dst))
                masked.append(src & mask)
            # OR-ing the masked sources yields the selected slave's value
            # (assumes at most one decoder matches a given address -- TODO confirm)
            self.comb += dst.eq(reduce(or_, masked))
class AXILiteInterconnectShared(Module):
__doc__ = """AXI Lite shared interconnect
{slaves}
""".format(slaves=AXILiteDecoder._doc_slaves)
def __init__(self, masters, slaves, register=False, timeout_cycles=1e6):
# TODO: data width
shared = AXILiteInterface()
@ -1069,13 +1125,21 @@ class AXILiteInterconnectShared(Module):
self.submodules.timeout = AXILiteTimeout(shared, timeout_cycles)
class AXILiteCrossbar(Module):
    __doc__ = """AXI Lite crossbar

    MxN crossbar for M masters and N slaves.

    {slaves}
    """.format(slaves=AXILiteDecoder._doc_slaves)

    def __init__(self, masters, slaves, register=False):
        # NOTE: `register` is kept for interface compatibility but is not forwarded:
        # the lock-based AXILiteDecoder does not register its address decoding.
        matches, busses = zip(*slaves)
        # one intermediate bus per (master, slave) pair
        access_m_s = [[AXILiteInterface() for _ in slaves] for _ in masters]  # a[master][slave]
        access_s_m = list(zip(*access_m_s))                                   # a[slave][master]

        # decode each master into its access row
        for row, master in zip(access_m_s, masters):
            self.submodules += AXILiteDecoder(master, list(zip(matches, row)))

        # arbitrate each access column onto its slave
        for column, bus in zip(access_s_m, busses):
            self.submodules += AXILiteArbiter(column, bus)

View File

@ -342,6 +342,7 @@ class AXILiteChecker:
yield
def handle_write(self, axi_lite):
# aw
while not (yield axi_lite.aw.valid):
yield
yield from self.delay(self.ready_latency)
@ -352,12 +353,14 @@ class AXILiteChecker:
while not (yield axi_lite.w.valid):
yield
yield from self.delay(self.ready_latency)
# w
data = (yield axi_lite.w.data)
strb = (yield axi_lite.w.strb)
yield axi_lite.w.ready.eq(1)
yield
yield axi_lite.w.ready.eq(0)
yield from self.delay(self.response_latency)
# b
yield axi_lite.b.valid.eq(1)
yield axi_lite.b.resp.eq(RESP_OKAY)
yield
@ -367,6 +370,7 @@ class AXILiteChecker:
self.writes.append((addr, data, strb))
def handle_read(self, axi_lite):
# ar
while not (yield axi_lite.ar.valid):
yield
yield from self.delay(self.ready_latency)
@ -375,6 +379,7 @@ class AXILiteChecker:
yield
yield axi_lite.ar.ready.eq(0)
yield from self.delay(self.response_latency)
# r
data = self.rdata_generator(addr)
yield axi_lite.r.valid.eq(1)
yield axi_lite.r.resp.eq(RESP_OKAY)
@ -383,6 +388,7 @@ class AXILiteChecker:
while not (yield axi_lite.r.ready):
yield
yield axi_lite.r.valid.eq(0)
yield axi_lite.r.data.eq(0)
self.reads.append((addr, data))
@passive
@ -650,7 +656,7 @@ class TestAXILite(unittest.TestCase):
# TestAXILiteInterconnect --------------------------------------------------------------------------
class TestAXILiteInterconnect(unittest.TestCase):
def axilite_pattern_generator(self, axi_lite, pattern):
def axilite_pattern_generator(self, axi_lite, pattern, delay=0):
for rw, addr, data in pattern:
assert rw in ["w", "r"]
if rw == "w":
@ -660,6 +666,8 @@ class TestAXILiteInterconnect(unittest.TestCase):
rdata, resp = (yield from axi_lite.read(addr))
self.assertEqual(resp, RESP_OKAY)
self.assertEqual(rdata, data)
for _ in range(delay):
yield
for _ in range(16):
yield
@ -776,7 +784,7 @@ class TestAXILiteInterconnect(unittest.TestCase):
checker = AXILiteChecker()
generators = [generator(i, master, delay=1) for i, master in enumerate(dut.masters)]
generators += [timeout(300), checker.handler(dut.slave)]
run_simulation(dut, generators, vcd_name='sim.vcd')
run_simulation(dut, generators)
order = [0, 100, 200, 1, 101, 201, 2, 102, 202, 3, 103, 203]
self.assertEqual([addr for addr, data, strb in checker.writes], order)
self.assertEqual([addr for addr, data in checker.reads], order)
@ -805,7 +813,6 @@ class TestAXILiteInterconnect(unittest.TestCase):
for _ in range(8):
yield
n_masters = 3
# with no delay each master will do all transfers at once
@ -825,7 +832,117 @@ class TestAXILiteInterconnect(unittest.TestCase):
checker = AXILiteChecker(response_latency=lambda: 3)
generators = [generator(i, master, delay=1) for i, master in enumerate(dut.masters)]
generators += [timeout(300), checker.handler(dut.slave)]
run_simulation(dut, generators, vcd_name='sim.vcd')
run_simulation(dut, generators)
order = [0, 100, 200, 1, 101, 201, 2, 102, 202, 3, 103, 203]
self.assertEqual([addr for addr, data, strb in checker.writes], order)
self.assertEqual([addr for addr, data in checker.reads], order)
def decoder_test(self, n_slaves, pattern, generator_delay=0):
    """Simulate one master driving `pattern` through an AXILiteDecoder with `n_slaves` slaves.

    Slave `i` is mapped to the 0x100-byte region starting at byte address `0x100 * i`.
    `pattern` is a list of `(rw, addr, data)` tuples consumed by `axilite_pattern_generator`;
    `generator_delay` is the number of idle cycles inserted after each transfer.

    Returns the list of AXILiteChecker instances (one per slave, in slave order) so that the
    caller can inspect their recorded `writes` and `reads`.
    """
    class DUT(Module):
        def __init__(self, decoders):
            self.master = AXILiteInterface()
            self.slaves = [AXILiteInterface() for _ in decoders]
            slaves = list(zip(decoders, self.slaves))
            self.submodules.decoder = AXILiteDecoder(self.master, slaves)

    def decoder(i):
        # the decoder receives word addresses: convert byte sizes/origins to 32-bit words
        size = (0x100) >> 2
        origin = (0x100 * i) >> 2
        size_bits = log2_int(size)
        # compare only the address bits above the region offset
        return lambda a: (a[size_bits:] == (origin >> size_bits))

    def rdata_generator(adr):
        # reply with the value the pattern expects for this address
        for rw, a, v in pattern:
            if rw == "r" and a == adr:
                return v
        # marker value for reads that are not part of the pattern
        return 0xbaadc0de

    dut = DUT([decoder(i) for i in range(n_slaves)])
    checkers = [AXILiteChecker(rdata_generator=rdata_generator) for _ in dut.slaves]

    generators = [self.axilite_pattern_generator(dut.master, pattern, delay=generator_delay)]
    generators += [checker.handler(slave) for (slave, checker) in zip(dut.slaves, checkers)]
    # presumably raises TimeoutError if the simulation is still running after 300 cycles
    # (test_decoder_stall relies on that) -- verify against the `timeout` helper
    generators += [timeout(300)]
    # no vcd_name: do not leave a waveform dump behind on every test run
    run_simulation(dut, generators)

    return checkers
def test_decoder_write(self):
    """Writes must reach only the slave whose region contains the address, in issue order."""
    pattern = [
        ("w", 0x010, 1),
        ("w", 0x110, 2),
        ("w", 0x210, 3),
        ("w", 0x011, 1),
        ("w", 0x012, 1),
        ("w", 0x111, 2),
        ("w", 0x112, 2),
        ("w", 0x211, 3),
        ("w", 0x212, 3),
    ]
    # expected write addresses per slave index
    expected = [
        [0x010, 0x011, 0x012],
        [0x110, 0x111, 0x112],
        [0x210, 0x211, 0x212],
    ]

    def addr(checker_list):
        return [entry[0] for entry in checker_list]

    # each delay is exercised once; a repeated value would only duplicate a sub-test
    for delay in [0, 1]:
        with self.subTest(delay=delay):
            slaves = self.decoder_test(n_slaves=3, pattern=pattern, generator_delay=delay)
            for i, addresses in enumerate(expected):
                self.assertEqual(addr(slaves[i].writes), addresses)
            # a write-only pattern must not trigger any read
            for slave in slaves:
                self.assertEqual(slave.reads, [])
def test_decoder_read(self):
    """Reads must reach only the slave whose region contains the address, in issue order."""
    pattern = [
        ("r", 0x010, 1),
        ("r", 0x110, 2),
        ("r", 0x210, 3),
        ("r", 0x011, 1),
        ("r", 0x012, 1),
        ("r", 0x111, 2),
        ("r", 0x112, 2),
        ("r", 0x211, 3),
        ("r", 0x212, 3),
    ]
    # expected read addresses per slave index
    expected = [
        [0x010, 0x011, 0x012],
        [0x110, 0x111, 0x112],
        [0x210, 0x211, 0x212],
    ]
    for delay in [0, 1]:
        with self.subTest(delay=delay):
            checkers = self.decoder_test(n_slaves=3, pattern=pattern, generator_delay=delay)
            for index, addresses in enumerate(expected):
                seen = [entry[0] for entry in checkers[index].reads]
                self.assertEqual(seen, addresses)
            # a read-only pattern must not trigger any write
            for checker in checkers:
                self.assertEqual(checker.writes, [])
def test_decoder_read_write(self):
    """Mixed reads and writes must each be routed to the slave owning the address."""
    pattern = [
        ("w", 0x010, 1),
        ("w", 0x110, 2),
        ("r", 0x111, 2),
        ("r", 0x011, 1),
        ("r", 0x211, 3),
        ("w", 0x210, 3),
    ]
    # per slave index: (expected write addresses, expected read addresses)
    expected = [
        ([0x010], [0x011]),
        ([0x110], [0x111]),
        ([0x210], [0x211]),
    ]
    for delay in [0, 1]:
        with self.subTest(delay=delay):
            checkers = self.decoder_test(n_slaves=3, pattern=pattern, generator_delay=delay)
            for index, (write_addrs, read_addrs) in enumerate(expected):
                checker = checkers[index]
                self.assertEqual([entry[0] for entry in checker.writes], write_addrs)
                self.assertEqual([entry[0] for entry in checker.reads], read_addrs)
def test_decoder_stall(self):
    """An access to an address outside every slave region must end in a TimeoutError."""
    # 0x300 lies beyond the regions of the three slaves (0x000, 0x100, 0x200);
    # the write case is checked first, then the read case
    for rw in ("w", "r"):
        unmapped_access = [(rw, 0x300, 1)]
        with self.assertRaises(TimeoutError):
            self.decoder_test(n_slaves=3, pattern=unmapped_access)