diff --git a/litedram/frontend/axi.py b/litedram/frontend/axi.py index 8624820..3af364c 100644 --- a/litedram/frontend/axi.py +++ b/litedram/frontend/axi.py @@ -6,11 +6,20 @@ from migen.genlib.roundrobin import * from litex.soc.interconnect import stream +burst_types = { + "fixed": 0b00, + "incr": 0b01, + "wrap": 0b10, # FIXME: Not implemented + "reserved": 0b11 +} -def aw_description(address_width, id_width): +def ax_description(address_width, id_width): return [ - ("addr", address_width), - ("id", id_width) + ("addr", address_width), + ("burst", 2), # burst type + ("len", 8), # number of data transfers (up to 256) + ("size", 4), # number of bytes of each data transfer (up to 1024 bits) + ("id", id_width) ] def w_description(data_width): @@ -24,12 +33,6 @@ def b_description(id_width): ("id", id_width) ] -def ar_description(address_width, id_width): - return [ - ("addr", address_width), - ("id", id_width) - ] - def r_description(data_width, id_width): return [ ("data", data_width), @@ -44,13 +47,64 @@ class LiteDRAMAXIPort(Record): self.id_width = id_width self.clock_domain = clock_domain - self.aw = stream.Endpoint(aw_description(address_width, id_width)) + self.aw = stream.Endpoint(ax_description(address_width, id_width)) self.w = stream.Endpoint(w_description(data_width)) self.b = stream.Endpoint(b_description(id_width)) - self.ar = stream.Endpoint(ar_description(address_width, id_width)) + self.ar = stream.Endpoint(ax_description(address_width, id_width)) self.r = stream.Endpoint(r_description(data_width, id_width)) +class LiteDRAMAXIBurst2Beat(Module): + def __init__(self, ax_burst, ax_beat): + + # # # + + count = Signal(8) + size = Signal(8+4) + offset = Signal(8+4) + + # convert burst size to bytes + cases = {} + for i in range(11): + cases[i] = size.eq(2**i) + self.comb += Case(ax_burst.size, cases) + + # fsm + self.submodules.fsm = fsm = FSM(reset_state="IDLE") + fsm.act("IDLE", + ax_beat.valid.eq(ax_burst.valid), + ax_beat.addr.eq(ax_burst.addr), + ax_beat.id.eq(ax_burst.id), + If(ax_beat.valid & ax_beat.ready, + If(ax_burst.len != 0, + NextValue(count, 0), + NextValue(offset, size), + NextState("BURST2BEAT") + ).Else( + ax_burst.ready.eq(1) + ) + ) + ) + fsm.act("BURST2BEAT", + ax_beat.valid.eq(1), + If(ax_burst.burst == burst_types["incr"], + ax_beat.addr.eq(ax_burst.addr + offset) + ).Else( + ax_beat.addr.eq(ax_burst.addr) + ), + ax_beat.id.eq(ax_burst.id), + If(ax_beat.valid & ax_beat.ready, + If(count == (ax_burst.len - 1), + ax_burst.ready.eq(1), + NextState("IDLE") + ).Else( + NextValue(count, count + 1), + NextValue(offset, offset + size) + ) + ) + ) + + class LiteDRAMAXI2Native(Module): def __init__(self, axi, port, w_buffer_depth=8, r_buffer_depth=8): @@ -61,6 +115,13 @@ class LiteDRAMAXI2Native(Module): can_write = Signal() can_read = Signal() + # Burst to beat + aw = stream.Endpoint(ax_description(axi.address_width, axi.id_width)) + ar = stream.Endpoint(ax_description(axi.address_width, axi.id_width)) + aw_burst2beat = LiteDRAMAXIBurst2Beat(axi.aw, aw) + ar_burst2beat = LiteDRAMAXIBurst2Beat(axi.ar, ar) + self.submodules += aw_burst2beat, ar_burst2beat + # Write / Read buffers w_buffer = stream.SyncFIFO(w_description(axi.data_width), w_buffer_depth) r_buffer = stream.SyncFIFO(r_description(axi.data_width, axi.id_width), r_buffer_depth) @@ -73,8 +134,8 @@ class LiteDRAMAXI2Native(Module): w_buffer_id = stream.SyncFIFO([("id", axi.id_width)], w_buffer_depth) self.submodules += w_buffer_id self.comb += [ - w_buffer_id.sink.valid.eq(axi.aw.valid & axi.aw.ready), - w_buffer_id.sink.id.eq(axi.aw.id), + w_buffer_id.sink.valid.eq(aw.valid & aw.ready), + w_buffer_id.sink.id.eq(aw.id), axi.b.valid.eq(axi.w.valid & axi.w.ready), # FIXME: axi.b always supposed to be ready axi.b.id.eq(w_buffer_id.source.id), w_buffer_id.source.ready.eq(axi.b.valid & axi.b.ready) @@ -103,8 +164,8 @@ class LiteDRAMAXI2Native(Module): r_buffer_id = stream.SyncFIFO([("id", axi.id_width)], r_buffer_depth) self.submodules += r_buffer_id self.comb += [ - r_buffer_id.sink.valid.eq(axi.ar.valid & axi.ar.ready), - r_buffer_id.sink.id.eq(axi.ar.id), + r_buffer_id.sink.valid.eq(ar.valid & ar.ready), + r_buffer_id.sink.id.eq(ar.id), axi.r.id.eq(r_buffer_id.source.id), r_buffer_id.source.ready.eq(axi.r.valid & axi.r.ready) ] @@ -113,22 +174,22 @@ class LiteDRAMAXI2Native(Module): arbiter = RoundRobin(2, SP_CE) self.submodules += arbiter self.comb += [ - arbiter.request[0].eq(axi.aw.valid & can_write), - arbiter.request[1].eq(axi.ar.valid & can_read), + arbiter.request[0].eq(aw.valid & can_write), + arbiter.request[1].eq(ar.valid & can_read), arbiter.ce.eq(~port.cmd.valid | port.cmd.ready) ] self.comb += [ If(arbiter.grant, - port.cmd.valid.eq(axi.ar.valid & can_read), - axi.ar.ready.eq(port.cmd.ready & can_read), + port.cmd.valid.eq(ar.valid & can_read), + ar.ready.eq(port.cmd.ready & can_read), port.cmd.we.eq(0), - port.cmd.adr.eq(axi.ar.addr >> ashift) + port.cmd.adr.eq(ar.addr >> ashift) ).Else( - port.cmd.valid.eq(axi.aw.valid & can_write), - axi.aw.ready.eq(port.cmd.ready & can_write), + port.cmd.valid.eq(aw.valid & can_write), + aw.ready.eq(port.cmd.ready & can_write), port.cmd.we.eq(1), - port.cmd.adr.eq(axi.aw.addr >> ashift) + port.cmd.adr.eq(aw.addr >> ashift) ) ] diff --git a/test/test_axi.py b/test/test_axi.py index 1795260..3893ff8 100755 --- a/test/test_axi.py +++ b/test/test_axi.py @@ -9,64 +9,143 @@ from litedram.frontend.axi import * from litex.gen.sim import * -def main_generator(axi_port, dram_port, dut): - prng = random.Random(42) - # axi_port always accepting wresps/rdatas - yield axi_port.b.ready.eq(1) - yield axi_port.r.ready.eq(1) - yield - # test writes - for i in range(16): - # write command - yield axi_port.aw.valid.eq(1) - yield axi_port.aw.addr.eq(i) - while (yield dram_port.cmd.ready) == 0: - if prng.randrange(100) < 20: - yield dram_port.cmd.ready.eq(1) - yield - yield axi_port.aw.valid.eq(0) - yield dram_port.cmd.ready.eq(0) - yield - # write data - yield axi_port.w.valid.eq(1) - yield axi_port.w.data.eq(i) - while (yield dram_port.wdata.ready) == 0: - if prng.randrange(100) < 20: - yield dram_port.wdata.ready.eq(1) - yield - if (yield axi_port.w.ready) == 1: - yield axi_port.w.valid.eq(0) - yield axi_port.aw.valid.eq(0) - yield dram_port.wdata.ready.eq(0) - yield - # test reads - for i in range(16): - # read command - yield axi_port.ar.valid.eq(1) - yield axi_port.ar.addr.eq(i) - while (yield dram_port.cmd.ready) == 0: - if prng.randrange(100) < 20: - yield dram_port.cmd.ready.eq(1) - yield - yield axi_port.ar.valid.eq(0) - yield dram_port.cmd.ready.eq(0) - yield - # read data - yield dram_port.rdata.valid.eq(1) - yield dram_port.rdata.data.eq(i) - while (yield dram_port.rdata.valid) == 0: - if prng.randrange(100) < 20: - yield dram_port.rdata.valid.eq(1) - yield - yield axi_port.ar.valid.eq(0) - yield dram_port.rdata.valid.eq(0) - yield - for i in range(128): - yield - class TestAXI(unittest.TestCase): - def test(self): - axi_port = LiteDRAMAXIPort(32, 24, 32) - dram_port = LiteDRAMNativePort("both", 24, 32) + def test_axi2native(self): + def main_generator(axi_port, dram_port, dut): + prng = random.Random(42) + # axi_port always accepting wresps/rdatas + yield axi_port.b.ready.eq(1) + yield axi_port.r.ready.eq(1) + yield + # test writes + for i in range(16): + # write command + yield axi_port.aw.valid.eq(1) + yield axi_port.aw.addr.eq(i) + while (yield dram_port.cmd.ready) == 0: + if prng.randrange(100) < 20: + yield dram_port.cmd.ready.eq(1) + yield + yield axi_port.aw.valid.eq(0) + yield dram_port.cmd.ready.eq(0) + yield + # write data + yield axi_port.w.valid.eq(1) + yield axi_port.w.data.eq(i) + while (yield dram_port.wdata.ready) == 0: + if prng.randrange(100) < 20: + yield dram_port.wdata.ready.eq(1) + yield + if (yield axi_port.w.ready) == 1: + yield axi_port.w.valid.eq(0) + yield axi_port.aw.valid.eq(0) + yield dram_port.wdata.ready.eq(0) + yield + # test reads + for i in range(16): + # read command + yield axi_port.ar.valid.eq(1) + yield axi_port.ar.addr.eq(i) + while (yield dram_port.cmd.ready) == 0: + if prng.randrange(100) < 20: + yield dram_port.cmd.ready.eq(1) + yield + yield axi_port.ar.valid.eq(0) + yield dram_port.cmd.ready.eq(0) + yield + # read data + yield dram_port.rdata.valid.eq(1) + yield dram_port.rdata.data.eq(i) + while (yield dram_port.rdata.valid) == 0: + if prng.randrange(100) < 20: + yield dram_port.rdata.valid.eq(1) + yield + yield axi_port.ar.valid.eq(0) + yield dram_port.rdata.valid.eq(0) + yield + for i in range(128): + yield + + axi_port = LiteDRAMAXIPort(32, 32, 32) + dram_port = LiteDRAMNativePort("both", 32, 32) dut = LiteDRAMAXI2Native(axi_port, dram_port) - run_simulation(dut, main_generator(axi_port, dram_port, dut), vcd_name="axi.vcd") + run_simulation(dut, main_generator(axi_port, dram_port, dut), vcd_name="axi2native.vcd") + + + def test_burst2beat(self): + class Beat: + def __init__(self, addr): + self.addr = addr + + class Burst: + def __init__(self, type, addr, len, size): + self.type = type + self.addr = addr + self.len = len + self.size = size + + def to_beats(self): + r = [] + for i in range(self.len + 1): + if self.type == burst_types["incr"]: + r += [Beat(self.addr + i*2**(self.size))] + else: + r += [Beat(self.addr)] + return r + + def bursts_generator(ax, bursts, valid_rand=50): + prng = random.Random(42) + for burst in bursts: + yield ax.valid.eq(1) + yield ax.addr.eq(burst.addr) + yield ax.burst.eq(burst.type) + yield ax.len.eq(burst.len) + yield ax.size.eq(burst.size) + while (yield ax.ready) == 0: + yield + yield ax.valid.eq(0) + while prng.randrange(100) < valid_rand: + yield + yield + + @passive + def beats_checker(ax, beats, ready_rand=50): + self.errors = 0 + yield ax.ready.eq(0) + prng = random.Random(42) + for beat in beats: + while ((yield ax.valid) and (yield ax.ready)) == 0: + if prng.randrange(100) > ready_rand: + yield ax.ready.eq(1) + else: + yield ax.ready.eq(0) + yield + ax_addr = (yield ax.addr) + if ax_addr != beat.addr: + self.errors += 1 + yield + + # dut + ax_burst = stream.Endpoint(ax_description(32, 32)) + ax_beat = stream.Endpoint(ax_description(32, 32)) + dut = LiteDRAMAXIBurst2Beat(ax_burst, ax_beat) + + # generate dut input (bursts) + prng = random.Random(42) + bursts = [] + for i in range(32): + bursts.append(Burst(burst_types["fixed"], prng.randrange(2**32), prng.randrange(256), log2_int(32//8))) + bursts.append(Burst(burst_types["incr"], prng.randrange(2**32), prng.randrange(256), log2_int(32//8))) + + # generate expexted dut output (beats for reference) + beats = [] + for burst in bursts: + beats += burst.to_beats() + + # simulation + generators = [ + bursts_generator(ax_burst, bursts), + beats_checker(ax_beat, beats) + ] + run_simulation(dut, generators, vcd_name="burst2beat.vcd") + self.assertEqual(self.errors, 0)