soc/interconnect/csr: Add optional support fixed CSR mapping.

By default, location is still automatically determined but it's now possible to
specific locations:

The following module:

class MyModule(Module, AutoCSR):
    def __init__(self):
        self.csr0 = CSRStorage()
        self.csr1 = CSRStorage(n=0)
        self.csr2 = CSRStorage(n=2)

built on a SoC with 32-bit CSR data-width will have the following CSR mapping:
- 0x00 : csr1
- 0x04 : csr0
- 0x08 : reserved
- 0x0c : csr2
This commit is contained in:
Florent Kermarrec 2022-10-21 14:45:46 +02:00
parent 804a1a5b26
commit a57f0640cc
3 changed files with 105 additions and 26 deletions

View File

@ -43,13 +43,14 @@ from migen.fhdl.tracer import get_obj_var_name
# CSRBase ------------------------------------------------------------------------------------------ # CSRBase ------------------------------------------------------------------------------------------
class _CSRBase(DUID): class _CSRBase(DUID):
def __init__(self, size, name): def __init__(self, size, name, n=None):
DUID.__init__(self) DUID.__init__(self)
self.n = n
self.fixed = n is not None
self.size = size
self.name = get_obj_var_name(name) self.name = get_obj_var_name(name)
if self.name is None: if self.name is None:
raise ValueError("Cannot extract CSR name from code, need to specify.") raise ValueError("Cannot extract CSR name from code, need to specify.")
self.size = size
# CSRConstant -------------------------------------------------------------------------------------- # CSRConstant --------------------------------------------------------------------------------------
class CSRConstant(DUID): class CSRConstant(DUID):
@ -59,8 +60,10 @@ class CSRConstant(DUID):
running on the device. running on the device.
""" """
def __init__(self, value, bits_sign=None, name=None): def __init__(self, value, bits_sign=None, name=None, n=None):
DUID.__init__(self) DUID.__init__(self)
self.n = n
self.fixed = n is not None
self.value = Constant(value, bits_sign) self.value = Constant(value, bits_sign)
self.name = get_obj_var_name(name) self.name = get_obj_var_name(name)
self.constant = value self.constant = value
@ -105,8 +108,8 @@ class CSR(_CSRBase):
It is active for one cycle, after or during a read from the bus. It is active for one cycle, after or during a read from the bus.
""" """
def __init__(self, size=1, name=None): def __init__(self, size=1, name=None, n=None):
_CSRBase.__init__(self, size, name) _CSRBase.__init__(self, size, name, n)
self.re = Signal(name=self.name + "_re") self.re = Signal(name=self.name + "_re")
self.r = Signal(self.size, name=self.name + "_r") self.r = Signal(self.size, name=self.name + "_r")
self.we = Signal(name=self.name + "_we") self.we = Signal(name=self.name + "_we")
@ -129,8 +132,8 @@ class CSR(_CSRBase):
class _CompoundCSR(_CSRBase, Module): class _CompoundCSR(_CSRBase, Module):
def __init__(self, size, name): def __init__(self, size, name, n=None):
_CSRBase.__init__(self, size, name) _CSRBase.__init__(self, size, name, n)
self.simple_csrs = [] self.simple_csrs = []
def get_simple_csrs(self): def get_simple_csrs(self):
@ -288,12 +291,12 @@ class CSRStatus(_CompoundCSR):
The value of the CSRStatus register. The value of the CSRStatus register.
""" """
def __init__(self, size=1, reset=0, fields=[], name=None, description=None, read_only=True): def __init__(self, size=1, reset=0, fields=[], name=None, description=None, read_only=True, n=None):
if fields != []: if fields != []:
self.fields = CSRFieldAggregate(fields, CSRAccess.ReadOnly) self.fields = CSRFieldAggregate(fields, CSRAccess.ReadOnly)
size = self.fields.get_size() size = self.fields.get_size()
reset = self.fields.get_reset() reset = self.fields.get_reset()
_CompoundCSR.__init__(self, size, name) _CompoundCSR.__init__(self, size, name, n)
self.description = description self.description = description
self.read_only = read_only self.read_only = read_only
self.status = Signal(self.size, reset=reset) self.status = Signal(self.size, reset=reset)
@ -377,12 +380,12 @@ class CSRStorage(_CompoundCSR):
``write_from_dev == True`` ``write_from_dev == True``
""" """
def __init__(self, size=1, reset=0, reset_less=False, fields=[], atomic_write=False, write_from_dev=False, name=None, description=None): def __init__(self, size=1, reset=0, reset_less=False, fields=[], atomic_write=False, write_from_dev=False, name=None, description=None, n=None):
if fields != []: if fields != []:
self.fields = CSRFieldAggregate(fields, CSRAccess.ReadWrite) self.fields = CSRFieldAggregate(fields, CSRAccess.ReadWrite)
size = self.fields.get_size() size = self.fields.get_size()
reset = self.fields.get_reset() reset = self.fields.get_reset()
_CompoundCSR.__init__(self, size, name) _CompoundCSR.__init__(self, size, name, n)
self.description = description self.description = description
self.storage = Signal(self.size, reset=reset, reset_less=reset_less) self.storage = Signal(self.size, reset=reset, reset_less=reset_less)
self.atomic_write = atomic_write self.atomic_write = atomic_write
@ -460,6 +463,65 @@ def memprefix(prefix, memories, done):
memory.name_override = prefix + memory.name_override memory.name_override = prefix + memory.name_override
done.add(memory.duid) done.add(memory.duid)
def _sort_gathered_items(items):
# Create list of variable items and sort it by DUID.
# --------------------------------------------------
variable_items = []
for item in items:
if not item.fixed:
variable_items.append(item)
variable_items = sorted(variable_items, key=lambda x: x.duid)
# Create list of fixed items:
# ---------------------------
fixed_items = []
for item in items:
if item.fixed:
fixed_items.append(item)
# Determine items length.
# -----------------------
# Set to length of provided items.
items_length = len(items)
# Eventually extend with fixed items:
for item in fixed_items:
if item.n > items_length:
items_length = (item.n + 1)
# Create list of sorted items:
# ----------------------------
# Create empty list.
sorted_items = [None for _ in range(items_length)]
# Fill fixed items.
for item in fixed_items:
if sorted_items[item.n] is not None:
csr0 = item.name
csr1 = sorted_items[item.n].name
raise ValueError(f"CSR conflict on location {item.n} between {csr0} and {csr1}.")
sorted_items[item.n] = item
# Fill variable items in empty locations.
while len(variable_items):
item = variable_items.pop(0)
for i in range(items_length):
if sorted_items[i] is None:
sorted_items[i] = item
break
# Fill remaining location with reserved CSR.
for i in range(items_length):
if sorted_items[i] is None:
sorted_items[i] = CSR(name=f"reserved{i}")
# Verify all locations are filled.
assert None not in sorted_items
# Return.
return sorted_items
def _make_gatherer(method, cls, prefix_cb): def _make_gatherer(method, cls, prefix_cb):
def gatherer(self): def gatherer(self):
@ -480,7 +542,7 @@ def _make_gatherer(method, cls, prefix_cb):
items = getattr(v, method)() items = getattr(v, method)()
prefix_cb(k + "_", items, prefixed) prefix_cb(k + "_", items, prefixed)
r += items r += items
return sorted(r, key=lambda x: x.duid) return _sort_gathered_items(r)
return gatherer return gatherer
@ -494,9 +556,9 @@ class AutoCSR:
they will be called by the``AutoCSR`` methods and their CSR and memories added to the lists returned, they will be called by the``AutoCSR`` methods and their CSR and memories added to the lists returned,
with the child objects' names as prefixes. with the child objects' names as prefixes.
""" """
get_memories = _make_gatherer("get_memories", Memory, memprefix) get_memories = _make_gatherer(method="get_memories", cls=Memory, prefix_cb=memprefix)
get_csrs = _make_gatherer("get_csrs", _CSRBase, csrprefix) get_csrs = _make_gatherer(method="get_csrs", cls=_CSRBase, prefix_cb=csrprefix)
get_constants = _make_gatherer("get_constants", CSRConstant, csrprefix) get_constants = _make_gatherer(method="get_constants", cls=CSRConstant, prefix_cb=csrprefix)
class GenericBank(Module): class GenericBank(Module):
@ -508,7 +570,7 @@ class GenericBank(Module):
if isinstance(c, CSR): if isinstance(c, CSR):
assert c.size <= busword assert c.size <= busword
self.simple_csrs.append(c) self.simple_csrs.append(c)
else: elif hasattr(c, "finalize"):
c.finalize(busword, ordering) c.finalize(busword, ordering)
self.simple_csrs += c.get_simple_csrs() self.simple_csrs += c.get_simple_csrs()
self.submodules += c self.submodules += c

View File

@ -215,14 +215,21 @@ class CSRBankArray(Module):
self.scan(ifargs, ifkwargs) self.scan(ifargs, ifkwargs)
def scan(self, ifargs, ifkwargs): def scan(self, ifargs, ifkwargs):
self.banks = [] self.banks = []
self.srams = [] self.srams = []
self.constants = [] self.constants = []
for name, obj in xdir(self.source, True): for name, obj in xdir(self.source, True):
# Collect CSR Registers.
# ---------------------
csrs = []
if hasattr(obj, "get_csrs"): if hasattr(obj, "get_csrs"):
csrs = obj.get_csrs() csrs = obj.get_csrs()
else:
csrs = [] # Collect CSR Memories.
# ---------------------
if hasattr(obj, "get_memories"): if hasattr(obj, "get_memories"):
memories = obj.get_memories() memories = obj.get_memories()
for memory in memories: for memory in memories:
@ -241,9 +248,15 @@ class CSRBankArray(Module):
self.submodules += mmap self.submodules += mmap
csrs += mmap.get_csrs() csrs += mmap.get_csrs()
self.srams.append((name, memory, mapaddr, mmap)) self.srams.append((name, memory, mapaddr, mmap))
# Collect CSR Constants.
# ----------------------
if hasattr(obj, "get_constants"): if hasattr(obj, "get_constants"):
for constant in obj.get_constants(): for constant in obj.get_constants():
self.constants.append((name, constant)) self.constants.append((name, constant))
# Create CSRBank with CSRs found.
# -------------------------------
if csrs: if csrs:
mapaddr = self.address_map(name, None) mapaddr = self.address_map(name, None)
if mapaddr is None: if mapaddr is None:

View File

@ -34,8 +34,8 @@ class CSRModule(Module, csr.AutoCSR):
# # # # # #
# When csr is written: # When csr is written:
# - set storage to 0xdeadbeef # - Set storage to 0xdeadbeef.
# - set status to storage value # - Set status to storage value.
self.comb += [ self.comb += [
If(self._csr.re, If(self._csr.re,
self._storage.we.eq(1), self._storage.we.eq(1),
@ -57,9 +57,13 @@ class CSRDUT(Module):
self.csr = csr_bus.Interface() self.csr = csr_bus.Interface()
self.submodules.csrmodule = CSRModule() self.submodules.csrmodule = CSRModule()
self.submodules.csrbankarray = csr_bus.CSRBankArray( self.submodules.csrbankarray = csr_bus.CSRBankArray(
self, self.address_map) source = self,
address_map = self.address_map,
)
self.submodules.csrcon = csr_bus.Interconnect( self.submodules.csrcon = csr_bus.Interconnect(
self.csr, self.csrbankarray.get_buses()) master = self.csr,
slaves = self.csrbankarray.get_buses()
)
class TestCSR(unittest.TestCase): class TestCSR(unittest.TestCase):
def test_csr_constant(self): def test_csr_constant(self):