diff --git a/litex/soc/interconnect/axi.py b/litex/soc/interconnect/axi.py index 40fae8be5..e76abe301 100644 --- a/litex/soc/interconnect/axi.py +++ b/litex/soc/interconnect/axi.py @@ -192,6 +192,7 @@ class AXILiteInterface: def write(self, addr, data, strb=None): if strb is None: strb = 2**len(self.w.strb) - 1 + # aw + w yield self.aw.valid.eq(1) yield self.aw.addr.eq(addr) yield self.w.data.eq(data) @@ -201,9 +202,12 @@ class AXILiteInterface: while not (yield self.aw.ready): yield yield self.aw.valid.eq(0) + yield self.aw.addr.eq(0) while not (yield self.w.ready): yield yield self.w.valid.eq(0) + yield self.w.strb.eq(0) + # b yield self.b.ready.eq(1) while not (yield self.b.valid): yield @@ -212,12 +216,14 @@ class AXILiteInterface: return resp def read(self, addr): + # ar yield self.ar.valid.eq(1) yield self.ar.addr.eq(addr) yield while not (yield self.ar.ready): yield yield self.ar.valid.eq(0) + # r yield self.r.ready.eq(1) while not (yield self.r.valid): yield @@ -943,6 +949,7 @@ class AXILiteRequestCounter(Module): self.full = full = Signal() self.empty = empty = Signal() self.stall = stall = Signal() + self.ready = self.empty self.comb += [ full.eq(counter == max_requests - 1), @@ -994,15 +1001,15 @@ class AXILiteArbiter(Module): self.comb += dest.eq(source) # allow to change rr.grant only after all requests from a master have been responded to - self.submodules.wr_counter = wr_counter = AXILiteRequestCounter( + self.submodules.wr_lock = wr_lock = AXILiteRequestCounter( request=target.aw.valid & target.aw.ready, response=target.b.valid & target.b.ready) - self.submodules.rd_counter = rd_counter = AXILiteRequestCounter( + self.submodules.rd_lock = rd_lock = AXILiteRequestCounter( request=target.ar.valid & target.ar.ready, response=target.r.valid & target.r.ready) # switch to next request only if there are no responses pending self.comb += [ - self.rr_write.ce.eq(~(target.aw.valid | target.w.valid | target.b.valid) & wr_counter.empty), - self.rr_read.ce.eq(~(target.ar.valid | target.r.valid) & rd_counter.empty), + self.rr_write.ce.eq(~(target.aw.valid | target.w.valid | target.b.valid) & wr_lock.ready), + self.rr_read.ce.eq(~(target.ar.valid | target.r.valid) & rd_lock.ready), ] # connect bus requests to round-robin selectors @@ -1012,54 +1019,103 @@ class AXILiteArbiter(Module): ] class AXILiteDecoder(Module): - # slaves is a list of pairs: - # 0) function that takes the address signal and returns a FHDL expression - # that evaluates to 1 when the slave is selected and 0 otherwise. - # 1) wishbone.Slave reference. - # register adds flip-flops after the address comparators. Improves timing, - # but breaks Wishbone combinatorial feedback. - def __init__(self, master, slaves, register=False): + _doc_slaves = """ + slaves: [(decoder, slave), ...] + List of slaves with address decoders, where `decoder` is a function: + decoder(Signal(address_width - log2(data_width//8))) -> Signal(1) + that returns 1 when the slave is selected and 0 otherwise. + """.strip() + + __doc__ = """AXI Lite decoder + + Decode master access to particular slave based on its decoder function. + + {slaves} + """.format(slaves=_doc_slaves) + + def __init__(self, master, slaves): addr_shift = log2_int(master.data_width//8) - ns = len(slaves) - slave_sel = Signal(ns) - slave_sel_r = Signal(ns) + + channels = { + "write": {"aw", "w", "b"}, + "read": {"ar", "r"}, + } + # reverse mapping: directions[channel] -> "write"/"read" + directions = {ch: d for d, chs in channels.items() for ch in chs} + + def new_slave_sel(): + return {"write": Signal(len(slaves)), "read": Signal(len(slaves))} + + slave_sel_dec = new_slave_sel() + slave_sel_reg = new_slave_sel() + slave_sel = new_slave_sel() + + # we need to hold the slave selected until all responses come back + # TODO: check if this will break Timeout if a slave does not respond? + # should probably work correctly as it uses master signals + # TODO: we could reuse arbiter counters + locks = { + "write": AXILiteRequestCounter( + request=master.aw.valid & master.aw.ready, + response=master.b.valid & master.b.ready), + "read": AXILiteRequestCounter( + request=master.ar.valid & master.ar.ready, + response=master.r.valid & master.r.ready), + } + self.submodules += locks.values() def get_sig(interface, channel, name): return getattr(getattr(interface, channel), name) + # # # + # decode slave addresses - self.comb += [slave_sel[i].eq(fun(master.aw.addr[addr_shift:]) | fun(master.aw.addr[addr_shift:])) - for i, (fun, bus) in enumerate(slaves)] - if register: - self.sync += slave_sel_r.eq(slave_sel) - else: - self.comb += slave_sel_r.eq(slave_sel) + for i, (decoder, bus) in enumerate(slaves): + self.comb += [ + slave_sel_dec["write"][i].eq(decoder(master.aw.addr[addr_shift:])), + slave_sel_dec["read"][i].eq(decoder(master.ar.addr[addr_shift:])), + ] - # connect master->slaves signals except valid - for fun, slave in slaves: + # change the current selection only when we've got all responses + for channel in locks.keys(): + self.sync += If(locks[channel].ready, slave_sel_reg[channel].eq(slave_sel_dec[channel])) + # we have to cut the delaying select + for ch, final in slave_sel.items(): + self.comb += If(locks[ch].ready, + final.eq(slave_sel_dec[ch]) + ).Else( + final.eq(slave_sel_reg[ch]) + ) + + # connect master->slaves signals except valid/ready + for i, (_, slave) in enumerate(slaves): for channel, name, direction in master.layout_flat(): - if direction == DIR_M_TO_S and name != "valid": - self.comb += get_sig(slave, channel, name).eq(get_sig(master, channel, name)) + if direction == DIR_M_TO_S: + src = get_sig(master, channel, name) + dst = get_sig(slave, channel, name) + # mask master control signals depending on slave selection + if name in ["valid", "ready"]: + src = src & slave_sel[directions[channel]][i] + self.comb += dst.eq(src) - # combine cyc with slave selection signals - for i, (fun, slave) in enumerate(slaves): - for ch in ["aw", "w", "ar"]: - slave_valid = get_sig(slave, ch, "valid") - master_valid = get_sig(master, ch, "valid") - self.comb += slave_valid.eq(master_valid & slave_sel[i]) - - # generate master ready by ORing all slave readys - self.comb += [ - master.aw.ready.eq(reduce(or_, [slave.aw.ready for fun, slave in slaves])), - master.w.ready.eq(reduce(or_, [slave.w.ready for fun, slave in slaves])), - master.ar.ready.eq(reduce(or_, [slave.ar.ready for fun, slave in slaves])), - ] - - # mux (1-hot) slave data return - masked = [Replicate(slave_sel_r[i], len(master.r.data)) & slaves[i][1].r.data for i in range(ns)] - self.comb += master.r.data.eq(reduce(or_, masked)) + # connect slave->master signals masking not selected slaves + for channel, name, direction in master.layout_flat(): + if direction == DIR_S_TO_M: + dst = get_sig(master, channel, name) + masked = [] + for i, (_, slave) in enumerate(slaves): + src = get_sig(slave, channel, name) + # mask depending on channel + mask = Replicate(slave_sel[directions[channel]][i], len(dst)) + masked.append(src & mask) + self.comb += dst.eq(reduce(or_, masked)) class AXILiteInterconnectShared(Module): + __doc__ = """AXI Lite shared interconnect + + {slaves} + """.format(slaves=AXILiteDecoder._doc_slaves) + def __init__(self, masters, slaves, register=False, timeout_cycles=1e6): # TODO: data width shared = AXILiteInterface() @@ -1069,13 +1125,21 @@ class AXILiteInterconnectShared(Module): self.submodules.timeout = AXILiteTimeout(shared, timeout_cycles) class AXILiteCrossbar(Module): + __doc__ = """AXI Lite crossbar + + MxN crossbar for M masters and N slaves. + + {slaves} + """.format(slaves=AXILiteDecoder._doc_slaves) + def __init__(self, masters, slaves, register=False): matches, busses = zip(*slaves) - access = [[AXILiteInterface() for j in slaves] for i in masters] + access_m_s = [[AXILiteInterface() for j in slaves] for i in masters] # a[master][slave] + access_s_m = list(zip(*access_m_s)) # a[slave][master] # decode each master into its access row - for row, master in zip(access, masters): - row = list(zip(matches, row)) - self.submodules += AXILiteDecoder(master, row, register) + for slaves, master in zip(access_m_s, masters): + slaves = list(zip(matches, slaves)) + self.submodules += AXILiteDecoder(master, slaves, register) # arbitrate each access column onto its slave - for column, bus in zip(zip(*access), busses): - self.submodules += AXILiteArbiter(column, bus) + for masters, bus in zip(access_s_m, busses): + self.submodules += AXILiteArbiter(masters, bus) diff --git a/test/test_axi.py b/test/test_axi.py index ee5d4107d..1c8345569 100644 --- a/test/test_axi.py +++ b/test/test_axi.py @@ -342,6 +342,7 @@ class AXILiteChecker: yield def handle_write(self, axi_lite): + # aw while not (yield axi_lite.aw.valid): yield yield from self.delay(self.ready_latency) @@ -352,12 +353,14 @@ class AXILiteChecker: while not (yield axi_lite.w.valid): yield yield from self.delay(self.ready_latency) + # w data = (yield axi_lite.w.data) strb = (yield axi_lite.w.strb) yield axi_lite.w.ready.eq(1) yield yield axi_lite.w.ready.eq(0) yield from self.delay(self.response_latency) + # b yield axi_lite.b.valid.eq(1) yield axi_lite.b.resp.eq(RESP_OKAY) yield @@ -367,6 +370,7 @@ class AXILiteChecker: self.writes.append((addr, data, strb)) def handle_read(self, axi_lite): + # ar while not (yield axi_lite.ar.valid): yield yield from self.delay(self.ready_latency) @@ -375,6 +379,7 @@ class AXILiteChecker: yield yield axi_lite.ar.ready.eq(0) yield from self.delay(self.response_latency) + # r data = self.rdata_generator(addr) yield axi_lite.r.valid.eq(1) yield axi_lite.r.resp.eq(RESP_OKAY) @@ -383,6 +388,7 @@ class AXILiteChecker: while not (yield axi_lite.r.ready): yield yield axi_lite.r.valid.eq(0) + yield axi_lite.r.data.eq(0) self.reads.append((addr, data)) @passive @@ -650,7 +656,7 @@ class TestAXILite(unittest.TestCase): # TestAXILiteInterconnet --------------------------------------------------------------------------- class TestAXILiteInterconnect(unittest.TestCase): - def axilite_pattern_generator(self, axi_lite, pattern): + def axilite_pattern_generator(self, axi_lite, pattern, delay=0): for rw, addr, data in pattern: assert rw in ["w", "r"] if rw == "w": @@ -660,6 +666,8 @@ class TestAXILiteInterconnect(unittest.TestCase): rdata, resp = (yield from axi_lite.read(addr)) self.assertEqual(resp, RESP_OKAY) self.assertEqual(rdata, data) + for _ in range(delay): + yield for _ in range(16): yield @@ -776,7 +784,7 @@ class TestAXILiteInterconnect(unittest.TestCase): checker = AXILiteChecker() generators = [generator(i, master, delay=1) for i, master in enumerate(dut.masters)] generators += [timeout(300), checker.handler(dut.slave)] - run_simulation(dut, generators, vcd_name='sim.vcd') + run_simulation(dut, generators) order = [0, 100, 200, 1, 101, 201, 2, 102, 202, 3, 103, 203] self.assertEqual([addr for addr, data, strb in checker.writes], order) self.assertEqual([addr for addr, data in checker.reads], order) @@ -805,7 +813,6 @@ class TestAXILiteInterconnect(unittest.TestCase): for _ in range(8): yield - n_masters = 3 # with no delay each master will do all transfers at once @@ -825,7 +832,117 @@ class TestAXILiteInterconnect(unittest.TestCase): checker = AXILiteChecker(response_latency=lambda: 3) generators = [generator(i, master, delay=1) for i, master in enumerate(dut.masters)] generators += [timeout(300), checker.handler(dut.slave)] - run_simulation(dut, generators, vcd_name='sim.vcd') + run_simulation(dut, generators) order = [0, 100, 200, 1, 101, 201, 2, 102, 202, 3, 103, 203] self.assertEqual([addr for addr, data, strb in checker.writes], order) self.assertEqual([addr for addr, data in checker.reads], order) + + def decoder_test(self, n_slaves, pattern, generator_delay=0): + class DUT(Module): + def __init__(self, decoders): + self.master = AXILiteInterface() + self.slaves = [AXILiteInterface() for _ in range(len(decoders))] + slaves = list(zip(decoders, self.slaves)) + self.submodules.decoder = AXILiteDecoder(self.master, slaves) + + def decoder(i): + # bytes to 32-bit words aligned + size = (0x100) >> 2 + origin = (0x100 * i) >> 2 + return lambda a: (a[log2_int(size):] == (origin >> log2_int(size))) + + def rdata_generator(adr): + for rw, a, v in pattern: + if rw == "r" and a == adr: + return v + return 0xbaadc0de + + dut = DUT([decoder(i) for i in range(n_slaves)]) + checkers = [AXILiteChecker(rdata_generator=rdata_generator) for _ in dut.slaves] + + generators = [self.axilite_pattern_generator(dut.master, pattern, delay=generator_delay)] + generators += [checker.handler(slave) for (slave, checker) in zip(dut.slaves, checkers)] + generators += [timeout(300)] + run_simulation(dut, generators, vcd_name='sim.vcd') + + return checkers + + def test_decoder_write(self): + for delay in [0, 1, 0]: + with self.subTest(delay=delay): + slaves = self.decoder_test(n_slaves=3, pattern=[ + ("w", 0x010, 1), + ("w", 0x110, 2), + ("w", 0x210, 3), + ("w", 0x011, 1), + ("w", 0x012, 1), + ("w", 0x111, 2), + ("w", 0x112, 2), + ("w", 0x211, 3), + ("w", 0x212, 3), + ], generator_delay=delay) + + def addr(checker_list): + return [entry[0] for entry in checker_list] + + self.assertEqual(addr(slaves[0].writes), [0x010, 0x011, 0x012]) + self.assertEqual(addr(slaves[1].writes), [0x110, 0x111, 0x112]) + self.assertEqual(addr(slaves[2].writes), [0x210, 0x211, 0x212]) + for slave in slaves: + self.assertEqual(slave.reads, []) + + def test_decoder_read(self): + for delay in [0, 1]: + with self.subTest(delay=delay): + slaves = self.decoder_test(n_slaves=3, pattern=[ + ("r", 0x010, 1), + ("r", 0x110, 2), + ("r", 0x210, 3), + ("r", 0x011, 1), + ("r", 0x012, 1), + ("r", 0x111, 2), + ("r", 0x112, 2), + ("r", 0x211, 3), + ("r", 0x212, 3), + ], generator_delay=delay) + + def addr(checker_list): + return [entry[0] for entry in checker_list] + + self.assertEqual(addr(slaves[0].reads), [0x010, 0x011, 0x012]) + self.assertEqual(addr(slaves[1].reads), [0x110, 0x111, 0x112]) + self.assertEqual(addr(slaves[2].reads), [0x210, 0x211, 0x212]) + for slave in slaves: + self.assertEqual(slave.writes, []) + + def test_decoder_read_write(self): + for delay in [0, 1]: + with self.subTest(delay=delay): + slaves = self.decoder_test(n_slaves=3, pattern=[ + ("w", 0x010, 1), + ("w", 0x110, 2), + ("r", 0x111, 2), + ("r", 0x011, 1), + ("r", 0x211, 3), + ("w", 0x210, 3), + ], generator_delay=delay) + + def addr(checker_list): + return [entry[0] for entry in checker_list] + + self.assertEqual(addr(slaves[0].writes), [0x010]) + self.assertEqual(addr(slaves[0].reads), [0x011]) + self.assertEqual(addr(slaves[1].writes), [0x110]) + self.assertEqual(addr(slaves[1].reads), [0x111]) + self.assertEqual(addr(slaves[2].writes), [0x210]) + self.assertEqual(addr(slaves[2].reads), [0x211]) + + def test_decoder_stall(self): + with self.assertRaises(TimeoutError): + self.decoder_test(n_slaves=3, pattern=[ + ("w", 0x300, 1), + ]) + with self.assertRaises(TimeoutError): + self.decoder_test(n_slaves=3, pattern=[ + ("r", 0x300, 1), + ])