Merge pull request #593 from antmicro/jboc/axi-lite

Add AXILite components: AXILiteSRAM and AXILite2CSR
This commit is contained in:
enjoy-digital 2020-07-16 11:56:57 +02:00 committed by GitHub
commit 21c48eed76
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 315 additions and 10 deletions

View File

@ -280,21 +280,46 @@ class SoCBusHandler(Module):
# Add Master/Slave ----------------------------------------------------------------------------- # Add Master/Slave -----------------------------------------------------------------------------
def add_adapter(self, name, interface, direction="m2s"): def add_adapter(self, name, interface, direction="m2s"):
assert direction in ["m2s", "s2m"] assert direction in ["m2s", "s2m"]
if interface.data_width != self.data_width:
self.logger.info("{} Bus {} from {}-bit to {}-bit.".format( if isinstance(interface, wishbone.Interface):
colorer(name),
colorer("converted", color="cyan"),
colorer(interface.data_width),
colorer(self.data_width)))
new_interface = wishbone.Interface(data_width=self.data_width) new_interface = wishbone.Interface(data_width=self.data_width)
if direction == "m2s": if direction == "m2s":
converter = wishbone.Converter(master=interface, slave=new_interface) converter = wishbone.Converter(master=interface, slave=new_interface)
if direction == "s2m": if direction == "s2m":
converter = wishbone.Converter(master=new_interface, slave=interface) converter = wishbone.Converter(master=new_interface, slave=interface)
self.submodules += converter self.submodules += converter
return new_interface elif isinstance(interface, axi.AXILiteInterface):
# Data width conversion
intermediate = axi.AXILiteInterface(data_width=self.data_width)
if direction == "m2s":
converter = axi.AXILiteConverter(master=interface, slave=intermediate)
if direction == "s2m":
converter = axi.AXILiteConverter(master=intermediate, slave=interface)
self.submodules += converter
# Bus type conversion
new_interface = wishbone.Interface(data_width=self.data_width)
if direction == "m2s":
converter = axi.AXILite2Wishbone(axi_lite=intermediate, wishbone=new_interface)
elif direction == "s2m":
converter = axi.Wishbone2AXILite(wishbone=new_interface, axi_lite=intermediate)
self.submodules += converter
else: else:
return interface raise TypeError(interface)
fmt = "{name} Bus {converted} from {frombus} {frombits}-bit to {tobus} {tobits}-bit."
frombus = "Wishbone" if isinstance(interface, wishbone.Interface) else "AXILite"
tobus = "Wishbone" if isinstance(new_interface, wishbone.Interface) else "AXILite"
frombits = interface.data_width
tobits = new_interface.data_width
if frombus != tobus or frombits != tobits:
self.logger.info(fmt.format(
name = colorer(name),
converted = colorer("converted", color="cyan"),
frombus = colorer("Wishbone" if isinstance(interface, wishbone.Interface) else "AXILite"),
frombits = colorer(interface.data_width),
tobus = colorer("Wishbone" if isinstance(new_interface, wishbone.Interface) else "AXILite"),
tobits = colorer(new_interface.data_width)))
return new_interface
def add_master(self, name=None, master=None): def add_master(self, name=None, master=None):
if name is None: if name is None:

View File

@ -55,6 +55,23 @@ def r_description(data_width, id_width):
("id", id_width) ("id", id_width)
] ]
def _connect_axi(master, slave):
channel_modes = {
"aw": "master",
"w" : "master",
"b" : "slave",
"ar": "master",
"r" : "slave",
}
r = []
for channel, mode in channel_modes.items():
if mode == "master":
m, s = getattr(master, channel), getattr(slave, channel)
else:
s, m = getattr(master, channel), getattr(slave, channel)
r.extend(m.connect(s))
return r
class AXIInterface: class AXIInterface:
def __init__(self, data_width=32, address_width=32, id_width=1, clock_domain="sys"): def __init__(self, data_width=32, address_width=32, id_width=1, clock_domain="sys"):
self.data_width = data_width self.data_width = data_width
@ -68,6 +85,9 @@ class AXIInterface:
self.ar = stream.Endpoint(ax_description(address_width, id_width)) self.ar = stream.Endpoint(ax_description(address_width, id_width))
self.r = stream.Endpoint(r_description(data_width, id_width)) self.r = stream.Endpoint(r_description(data_width, id_width))
def connect(self, slave):
return _connect_axi(self, slave)
# AXI Lite Definition ------------------------------------------------------------------------------ # AXI Lite Definition ------------------------------------------------------------------------------
def ax_lite_description(address_width): def ax_lite_description(address_width):
@ -138,6 +158,45 @@ class AXILiteInterface:
r.append(pad.eq(sig)) r.append(pad.eq(sig))
return r return r
def connect(self, slave):
return _connect_axi(self, slave)
def write(self, addr, data, strb=None):
if strb is None:
strb = 2**len(self.w.strb) - 1
yield self.aw.valid.eq(1)
yield self.aw.addr.eq(addr)
yield self.w.data.eq(data)
yield self.w.valid.eq(1)
yield self.w.strb.eq(strb)
yield
while not (yield self.aw.ready):
yield
while not (yield self.w.ready):
yield
while not (yield self.b.valid):
yield
yield self.b.ready.eq(1)
resp = (yield self.b.resp)
yield
yield self.b.ready.eq(0)
return resp
def read(self, addr):
yield self.ar.valid.eq(1)
yield self.ar.addr.eq(addr)
yield
while not (yield self.ar.ready):
yield
while not (yield self.r.valid):
yield
yield self.r.ready.eq(1)
data = (yield self.r.data)
resp = (yield self.r.resp)
yield
yield self.r.ready.eq(0)
return (data, resp)
# AXI Stream Definition ---------------------------------------------------------------------------- # AXI Stream Definition ----------------------------------------------------------------------------
class AXIStreamInterface(stream.Endpoint): class AXIStreamInterface(stream.Endpoint):
@ -494,3 +553,145 @@ class Wishbone2AXILite(Module):
wishbone.err.eq(1), wishbone.err.eq(1),
NextState("IDLE") NextState("IDLE")
) )
# AXILite to CSR -----------------------------------------------------------------------------------
def axi_lite_to_simple(axi_lite, port_adr, port_dat_r, port_dat_w=None, port_we=None):
"""Connection of AXILite to simple bus with 1-cycle latency, such as CSR bus or Memory port"""
bus_data_width = axi_lite.data_width
adr_shift = log2_int(bus_data_width//8)
do_read = Signal()
do_write = Signal()
last_was_read = Signal()
comb = []
if port_dat_w is not None:
comb.append(port_dat_w.eq(axi_lite.w.data))
if port_we is not None:
if len(port_we) > 1:
for i in range(bus_data_width//8):
comb.append(port_we[i].eq(axi_lite.w.valid & axi_lite.w.ready & axi_lite.w.strb[i]))
else:
comb.append(port_we.eq(axi_lite.w.valid & axi_lite.w.ready & (axi_lite.w.strb != 0)))
fsm = FSM()
fsm.act("START-TRANSACTION",
# If the last access was a read, do a write, and vice versa
If(axi_lite.aw.valid & axi_lite.ar.valid,
do_write.eq(last_was_read),
do_read.eq(~last_was_read),
).Else(
do_write.eq(axi_lite.aw.valid),
do_read.eq(axi_lite.ar.valid),
),
# Start reading/writing immediately not to waste a cycle
If(do_write,
port_adr.eq(axi_lite.aw.addr[adr_shift:]),
If(axi_lite.w.valid,
axi_lite.aw.ready.eq(1),
axi_lite.w.ready.eq(1),
NextState("SEND-WRITE-RESPONSE")
)
).Elif(do_read,
port_adr.eq(axi_lite.ar.addr[adr_shift:]),
axi_lite.ar.ready.eq(1),
NextState("SEND-READ-RESPONSE"),
)
)
fsm.act("SEND-READ-RESPONSE",
NextValue(last_was_read, 1),
# As long as we have correct address port.dat_r will be valid
port_adr.eq(axi_lite.ar.addr[adr_shift:]),
axi_lite.r.data.eq(port_dat_r),
axi_lite.r.resp.eq(RESP_OKAY),
axi_lite.r.valid.eq(1),
If(axi_lite.r.ready,
NextState("START-TRANSACTION")
)
)
fsm.act("SEND-WRITE-RESPONSE",
NextValue(last_was_read, 0),
axi_lite.b.valid.eq(1),
axi_lite.b.resp.eq(RESP_OKAY),
If(axi_lite.b.ready,
NextState("START-TRANSACTION")
)
)
return fsm, comb
class AXILite2CSR(Module):
def __init__(self, axi_lite=None, csr=None):
if axi_lite is None:
axi_lite = AXILiteInterface()
if csr is None:
csr = csr.bus.Interface()
self.axi_lite = axi_lite
self.csr = csr
fsm, comb = axi_lite_to_simple(self.axi_lite,
port_adr=self.csr.adr, port_dat_r=self.csr.dat_r,
port_dat_w=self.csr.dat_w, port_we=self.csr.we)
self.submodules.fsm = fsm
self.comb += comb
# AXILite SRAM -------------------------------------------------------------------------------------
class AXILiteSRAM(Module):
def __init__(self, mem_or_size, read_only=None, init=None, bus=None):
if bus is None:
bus = AXILiteInterface()
self.bus = bus
bus_data_width = len(self.bus.r.data)
if isinstance(mem_or_size, Memory):
assert(mem_or_size.width <= bus_data_width)
self.mem = mem_or_size
else:
self.mem = Memory(bus_data_width, mem_or_size//(bus_data_width//8), init=init)
if read_only is None:
if hasattr(self.mem, "bus_read_only"):
read_only = self.mem.bus_read_only
else:
read_only = False
###
# Create memory port
port = self.mem.get_port(write_capable=not read_only, we_granularity=8,
mode=READ_FIRST if read_only else WRITE_FIRST)
self.specials += self.mem, port
# Generate write enable signal
if not read_only:
self.comb += port.dat_w.eq(self.bus.w.data),
self.comb += [port.we[i].eq(self.bus.w.valid & self.bus.w.ready & self.bus.w.strb[i])
for i in range(bus_data_width//8)]
# Transaction logic
fsm, comb = axi_lite_to_simple(self.bus,
port_adr=port.adr, port_dat_r=port.dat_r,
port_dat_w=port.dat_w if not read_only else None,
port_we=port.we if not read_only else None)
self.submodules.fsm = fsm
self.comb += comb
# AXILite Data Width Converter ---------------------------------------------------------------------
class AXILiteConverter(Module):
"""AXILite data width converter"""
def __init__(self, master, slave):
self.master = master
self.slave = slave
# # #
dw_from = len(master.r.data)
dw_to = len(slave.r.data)
if dw_from > dw_to:
raise NotImplementedError
elif dw_from < dw_to:
raise NotImplementedError
else:
self.comb += master.connect(slave)

View File

@ -7,7 +7,7 @@ import random
from migen import * from migen import *
from litex.soc.interconnect.axi import * from litex.soc.interconnect.axi import *
from litex.soc.interconnect import wishbone from litex.soc.interconnect import wishbone, csr_bus
# Software Models ---------------------------------------------------------------------------------- # Software Models ----------------------------------------------------------------------------------
@ -358,5 +358,84 @@ class TestAXI(unittest.TestCase):
dut.errors += 1 dut.errors += 1
dut = DUT() dut = DUT()
run_simulation(dut, [generator(dut)], vcd_name="toto.vcd") run_simulation(dut, [generator(dut)])
self.assertEqual(dut.errors, 0)
def test_axilite2csr(self):
@passive
def csr_mem_handler(csr, mem):
while True:
adr = (yield csr.adr)
yield csr.dat_r.eq(mem[adr])
if (yield csr.we):
mem[adr] = (yield csr.dat_w)
yield
class DUT(Module):
def __init__(self):
self.axi_lite = AXILiteInterface()
self.csr = csr_bus.Interface()
self.submodules.axilite2csr = AXILite2CSR(self.axi_lite, self.csr)
self.errors = 0
prng = random.Random(42)
mem_ref = [prng.randrange(255) for i in range(100)]
def generator(dut):
dut.errors = 0
for adr, ref in enumerate(mem_ref):
adr = adr << 2
data, resp = (yield from dut.axi_lite.read(adr))
self.assertEqual(resp, 0b00)
if data != ref:
dut.errors += 1
write_data = [prng.randrange(255) for _ in mem_ref]
for adr, wdata in enumerate(write_data):
adr = adr << 2
resp = (yield from dut.axi_lite.write(adr, wdata))
self.assertEqual(resp, 0b00)
rdata, resp = (yield from dut.axi_lite.read(adr))
self.assertEqual(resp, 0b00)
if rdata != wdata:
dut.errors += 1
dut = DUT()
mem = [v for v in mem_ref]
run_simulation(dut, [generator(dut), csr_mem_handler(dut.csr, mem)])
self.assertEqual(dut.errors, 0)
def test_axilite_sram(self):
class DUT(Module):
def __init__(self, size, init):
self.axi_lite = AXILiteInterface()
self.submodules.sram = AXILiteSRAM(size, init=init, bus=self.axi_lite)
self.errors = 0
def generator(dut, ref_init):
for adr, ref in enumerate(ref_init):
adr = adr << 2
data, resp = (yield from dut.axi_lite.read(adr))
self.assertEqual(resp, 0b00)
if data != ref:
dut.errors += 1
write_data = [prng.randrange(255) for _ in ref_init]
for adr, wdata in enumerate(write_data):
adr = adr << 2
resp = (yield from dut.axi_lite.write(adr, wdata))
self.assertEqual(resp, 0b00)
rdata, resp = (yield from dut.axi_lite.read(adr))
self.assertEqual(resp, 0b00)
if rdata != wdata:
dut.errors += 1
prng = random.Random(42)
init = [prng.randrange(2**32) for i in range(100)]
dut = DUT(size=len(init)*4, init=[v for v in init])
run_simulation(dut, [generator(dut, init)])
self.assertEqual(dut.errors, 0) self.assertEqual(dut.errors, 0)