diff --git a/litex/soc/interconnect/csr.py b/litex/soc/interconnect/csr.py index 5e04e314d..ed69e516f 100644 --- a/litex/soc/interconnect/csr.py +++ b/litex/soc/interconnect/csr.py @@ -43,13 +43,14 @@ from migen.fhdl.tracer import get_obj_var_name # CSRBase ------------------------------------------------------------------------------------------ class _CSRBase(DUID): - def __init__(self, size, name): + def __init__(self, size, name, n=None): DUID.__init__(self) + self.n = n + self.fixed = n is not None + self.size = size self.name = get_obj_var_name(name) if self.name is None: raise ValueError("Cannot extract CSR name from code, need to specify.") - self.size = size - # CSRConstant -------------------------------------------------------------------------------------- class CSRConstant(DUID): @@ -59,8 +60,10 @@ class CSRConstant(DUID): running on the device. """ - def __init__(self, value, bits_sign=None, name=None): + def __init__(self, value, bits_sign=None, name=None, n=None): DUID.__init__(self) + self.n = n + self.fixed = n is not None self.value = Constant(value, bits_sign) self.name = get_obj_var_name(name) self.constant = value @@ -105,8 +108,8 @@ class CSR(_CSRBase): It is active for one cycle, after or during a read from the bus. """ - def __init__(self, size=1, name=None): - _CSRBase.__init__(self, size, name) + def __init__(self, size=1, name=None, n=None): + _CSRBase.__init__(self, size, name, n) self.re = Signal(name=self.name + "_re") self.r = Signal(self.size, name=self.name + "_r") self.we = Signal(name=self.name + "_we") @@ -129,8 +132,8 @@ class CSR(_CSRBase): class _CompoundCSR(_CSRBase, Module): - def __init__(self, size, name): - _CSRBase.__init__(self, size, name) + def __init__(self, size, name, n=None): + _CSRBase.__init__(self, size, name, n) self.simple_csrs = [] def get_simple_csrs(self): @@ -288,12 +291,12 @@ class CSRStatus(_CompoundCSR): The value of the CSRStatus register. """ - def __init__(self, size=1, reset=0, fields=[], name=None, description=None, read_only=True): + def __init__(self, size=1, reset=0, fields=[], name=None, description=None, read_only=True, n=None): if fields != []: self.fields = CSRFieldAggregate(fields, CSRAccess.ReadOnly) size = self.fields.get_size() reset = self.fields.get_reset() - _CompoundCSR.__init__(self, size, name) + _CompoundCSR.__init__(self, size, name, n) self.description = description self.read_only = read_only self.status = Signal(self.size, reset=reset) @@ -377,12 +380,12 @@ class CSRStorage(_CompoundCSR): ``write_from_dev == True`` """ - def __init__(self, size=1, reset=0, reset_less=False, fields=[], atomic_write=False, write_from_dev=False, name=None, description=None): + def __init__(self, size=1, reset=0, reset_less=False, fields=[], atomic_write=False, write_from_dev=False, name=None, description=None, n=None): if fields != []: self.fields = CSRFieldAggregate(fields, CSRAccess.ReadWrite) size = self.fields.get_size() reset = self.fields.get_reset() - _CompoundCSR.__init__(self, size, name) + _CompoundCSR.__init__(self, size, name, n) self.description = description self.storage = Signal(self.size, reset=reset, reset_less=reset_less) self.atomic_write = atomic_write @@ -460,6 +463,65 @@ def memprefix(prefix, memories, done): memory.name_override = prefix + memory.name_override done.add(memory.duid) +def _sort_gathered_items(items): + + # Create list of variable items and sort it by DUID. + # -------------------------------------------------- + variable_items = [] + for item in items: + if not item.fixed: + variable_items.append(item) + variable_items = sorted(variable_items, key=lambda x: x.duid) + + # Create list of fixed items: + # --------------------------- + fixed_items = [] + for item in items: + if item.fixed: + fixed_items.append(item) + + # Determine items length. + # ----------------------- + # Set to length of provided items. + items_length = len(items) + + # Eventually extend with fixed items: + for item in fixed_items: + if item.n > items_length: + items_length = (item.n + 1) + + # Create list of sorted items: + # ---------------------------- + + # Create empty list. + sorted_items = [None for _ in range(items_length)] + + # Fill fixed items. + for item in fixed_items: + if sorted_items[item.n] is not None: + csr0 = item.name + csr1 = sorted_items[item.n].name + raise ValueError(f"CSR conflict on location {item.n} between {csr0} and {csr1}.") + sorted_items[item.n] = item + + # Fill variable items in empty locations. + while len(variable_items): + item = variable_items.pop(0) + for i in range(items_length): + if sorted_items[i] is None: + sorted_items[i] = item + break + + # Fill remaining location with reserved CSR. + for i in range(items_length): + if sorted_items[i] is None: + sorted_items[i] = CSR(name=f"reserved{i}") + + # Verify all locations are filled. + assert None not in sorted_items + + # Return. + return sorted_items def _make_gatherer(method, cls, prefix_cb): def gatherer(self): @@ -480,7 +542,7 @@ def _make_gatherer(method, cls, prefix_cb): items = getattr(v, method)() prefix_cb(k + "_", items, prefixed) r += items - return sorted(r, key=lambda x: x.duid) + return _sort_gathered_items(r) return gatherer @@ -494,9 +556,9 @@ class AutoCSR: they will be called by the``AutoCSR`` methods and their CSR and memories added to the lists returned, with the child objects' names as prefixes. """ - get_memories = _make_gatherer("get_memories", Memory, memprefix) - get_csrs = _make_gatherer("get_csrs", _CSRBase, csrprefix) - get_constants = _make_gatherer("get_constants", CSRConstant, csrprefix) + get_memories = _make_gatherer(method="get_memories", cls=Memory, prefix_cb=memprefix) + get_csrs = _make_gatherer(method="get_csrs", cls=_CSRBase, prefix_cb=csrprefix) + get_constants = _make_gatherer(method="get_constants", cls=CSRConstant, prefix_cb=csrprefix) class GenericBank(Module): @@ -508,7 +570,7 @@ class GenericBank(Module): if isinstance(c, CSR): assert c.size <= busword self.simple_csrs.append(c) - else: + elif hasattr(c, "finalize"): c.finalize(busword, ordering) self.simple_csrs += c.get_simple_csrs() self.submodules += c diff --git a/litex/soc/interconnect/csr_bus.py b/litex/soc/interconnect/csr_bus.py index c748124c9..a1c8e5740 100644 --- a/litex/soc/interconnect/csr_bus.py +++ b/litex/soc/interconnect/csr_bus.py @@ -215,14 +215,21 @@ class CSRBankArray(Module): self.scan(ifargs, ifkwargs) def scan(self, ifargs, ifkwargs): - self.banks = [] - self.srams = [] + + self.banks = [] + self.srams = [] self.constants = [] + for name, obj in xdir(self.source, True): + + # Collect CSR Registers. + # --------------------- + csrs = [] if hasattr(obj, "get_csrs"): csrs = obj.get_csrs() - else: - csrs = [] + + # Collect CSR Memories. + # --------------------- if hasattr(obj, "get_memories"): memories = obj.get_memories() for memory in memories: @@ -241,9 +248,15 @@ class CSRBankArray(Module): self.submodules += mmap csrs += mmap.get_csrs() self.srams.append((name, memory, mapaddr, mmap)) + + # Collect CSR Constants. + # ---------------------- if hasattr(obj, "get_constants"): for constant in obj.get_constants(): self.constants.append((name, constant)) + + # Create CSRBank with CSRs found. + # ------------------------------- if csrs: mapaddr = self.address_map(name, None) if mapaddr is None: diff --git a/test/test_csr.py b/test/test_csr.py index e426a1a7d..92b08ef35 100644 --- a/test/test_csr.py +++ b/test/test_csr.py @@ -34,8 +34,8 @@ class CSRModule(Module, csr.AutoCSR): # # # # When csr is written: - # - set storage to 0xdeadbeef - # - set status to storage value + # - Set storage to 0xdeadbeef. + # - Set status to storage value. self.comb += [ If(self._csr.re, self._storage.we.eq(1), @@ -55,11 +55,15 @@ class CSRDUT(Module): def __init__(self): self.csr = csr_bus.Interface() - self.submodules.csrmodule = CSRModule() + self.submodules.csrmodule = CSRModule() self.submodules.csrbankarray = csr_bus.CSRBankArray( - self, self.address_map) + source = self, + address_map = self.address_map, + ) self.submodules.csrcon = csr_bus.Interconnect( - self.csr, self.csrbankarray.get_buses()) + master = self.csr, + slaves = self.csrbankarray.get_buses() + ) class TestCSR(unittest.TestCase): def test_csr_constant(self):