Merge pull request #1467 from enjoy-digital/csr_mapping

Add optional support for fixed CSR mapping.
This commit is contained in:
enjoy-digital 2022-10-21 18:36:26 +02:00 committed by GitHub
commit 14b2829a5f
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 111 additions and 28 deletions

View file

@ -43,13 +43,14 @@ from migen.fhdl.tracer import get_obj_var_name
# CSRBase ------------------------------------------------------------------------------------------
class _CSRBase(DUID):
def __init__(self, size, name):
def __init__(self, size, name, n=None):
DUID.__init__(self)
self.n = n
self.fixed = n is not None
self.size = size
self.name = get_obj_var_name(name)
if self.name is None:
raise ValueError("Cannot extract CSR name from code, need to specify.")
self.size = size
# CSRConstant --------------------------------------------------------------------------------------
class CSRConstant(DUID):
@ -59,8 +60,10 @@ class CSRConstant(DUID):
running on the device.
"""
def __init__(self, value, bits_sign=None, name=None):
def __init__(self, value, bits_sign=None, name=None, n=None):
DUID.__init__(self)
self.n = n
self.fixed = n is not None
self.value = Constant(value, bits_sign)
self.name = get_obj_var_name(name)
self.constant = value
@ -105,8 +108,8 @@ class CSR(_CSRBase):
It is active for one cycle, after or during a read from the bus.
"""
def __init__(self, size=1, name=None):
_CSRBase.__init__(self, size, name)
def __init__(self, size=1, name=None, n=None):
_CSRBase.__init__(self, size, name, n)
self.re = Signal(name=self.name + "_re")
self.r = Signal(self.size, name=self.name + "_r")
self.we = Signal(name=self.name + "_we")
@ -129,8 +132,8 @@ class CSR(_CSRBase):
class _CompoundCSR(_CSRBase, Module):
def __init__(self, size, name):
_CSRBase.__init__(self, size, name)
def __init__(self, size, name, n=None):
_CSRBase.__init__(self, size, name, n)
self.simple_csrs = []
def get_simple_csrs(self):
@ -288,12 +291,12 @@ class CSRStatus(_CompoundCSR):
The value of the CSRStatus register.
"""
def __init__(self, size=1, reset=0, fields=[], name=None, description=None, read_only=True):
def __init__(self, size=1, reset=0, fields=[], name=None, description=None, read_only=True, n=None):
if fields != []:
self.fields = CSRFieldAggregate(fields, CSRAccess.ReadOnly)
size = self.fields.get_size()
reset = self.fields.get_reset()
_CompoundCSR.__init__(self, size, name)
_CompoundCSR.__init__(self, size, name, n)
self.description = description
self.read_only = read_only
self.status = Signal(self.size, reset=reset)
@ -377,12 +380,12 @@ class CSRStorage(_CompoundCSR):
``write_from_dev == True``
"""
def __init__(self, size=1, reset=0, reset_less=False, fields=[], atomic_write=False, write_from_dev=False, name=None, description=None):
def __init__(self, size=1, reset=0, reset_less=False, fields=[], atomic_write=False, write_from_dev=False, name=None, description=None, n=None):
if fields != []:
self.fields = CSRFieldAggregate(fields, CSRAccess.ReadWrite)
size = self.fields.get_size()
reset = self.fields.get_reset()
_CompoundCSR.__init__(self, size, name)
_CompoundCSR.__init__(self, size, name, n)
self.description = description
self.storage = Signal(self.size, reset=reset, reset_less=reset_less)
self.atomic_write = atomic_write
@ -460,9 +463,69 @@ def memprefix(prefix, memories, done):
memory.name_override = prefix + memory.name_override
done.add(memory.duid)
def _sort_gathered_items(items):
# Create list of variable items and sort it by DUID.
# --------------------------------------------------
variable_items = []
for item in items:
if not item.fixed:
variable_items.append(item)
variable_items = sorted(variable_items, key=lambda x: x.duid)
# Create list of fixed items:
# ---------------------------
fixed_items = []
for item in items:
if item.fixed:
fixed_items.append(item)
# Determine items length.
# -----------------------
# Set to length of provided items.
items_length = len(items)
# Eventually extend with fixed items:
for item in fixed_items:
if item.n > items_length:
items_length = (item.n + 1)
# Create list of sorted items:
# ----------------------------
# Create empty list.
sorted_items = [None for _ in range(items_length)]
# Fill fixed items.
for item in fixed_items:
if sorted_items[item.n] is not None:
csr0 = item.name
csr1 = sorted_items[item.n].name
raise ValueError(f"CSR conflict on location {item.n} between {csr0} and {csr1}.")
sorted_items[item.n] = item
# Fill variable items in empty locations.
while len(variable_items):
item = variable_items.pop(0)
for i in range(items_length):
if sorted_items[i] is None:
sorted_items[i] = item
break
# Fill remaining location with reserved CSR.
for i in range(items_length):
if sorted_items[i] is None:
sorted_items[i] = CSR(name=f"reserved{i}")
# Verify all locations are filled.
assert None not in sorted_items
# Return.
return sorted_items
def _make_gatherer(method, cls, prefix_cb):
def gatherer(self):
def gatherer(self, level=0):
try:
exclude = self.autocsr_exclude
except AttributeError:
@ -477,10 +540,13 @@ def _make_gatherer(method, cls, prefix_cb):
if isinstance(v, cls):
r.append(v)
elif hasattr(v, method) and callable(getattr(v, method)):
items = getattr(v, method)()
items = getattr(v, method)(level=level+1)
prefix_cb(k + "_", items, prefixed)
r += items
return sorted(r, key=lambda x: x.duid)
if level == 0:
return _sort_gathered_items(r)
else:
return r
return gatherer
@ -494,9 +560,9 @@ class AutoCSR:
they will be called by the``AutoCSR`` methods and their CSR and memories added to the lists returned,
with the child objects' names as prefixes.
"""
get_memories = _make_gatherer("get_memories", Memory, memprefix)
get_csrs = _make_gatherer("get_csrs", _CSRBase, csrprefix)
get_constants = _make_gatherer("get_constants", CSRConstant, csrprefix)
get_memories = _make_gatherer(method="get_memories", cls=Memory, prefix_cb=memprefix)
get_csrs = _make_gatherer(method="get_csrs", cls=_CSRBase, prefix_cb=csrprefix)
get_constants = _make_gatherer(method="get_constants", cls=CSRConstant, prefix_cb=csrprefix)
class GenericBank(Module):
@ -508,7 +574,7 @@ class GenericBank(Module):
if isinstance(c, CSR):
assert c.size <= busword
self.simple_csrs.append(c)
else:
elif hasattr(c, "finalize"):
c.finalize(busword, ordering)
self.simple_csrs += c.get_simple_csrs()
self.submodules += c

View file

@ -215,14 +215,21 @@ class CSRBankArray(Module):
self.scan(ifargs, ifkwargs)
def scan(self, ifargs, ifkwargs):
self.banks = []
self.srams = []
self.banks = []
self.srams = []
self.constants = []
for name, obj in xdir(self.source, True):
# Collect CSR Registers.
# ---------------------
csrs = []
if hasattr(obj, "get_csrs"):
csrs = obj.get_csrs()
else:
csrs = []
# Collect CSR Memories.
# ---------------------
if hasattr(obj, "get_memories"):
memories = obj.get_memories()
for memory in memories:
@ -241,9 +248,15 @@ class CSRBankArray(Module):
self.submodules += mmap
csrs += mmap.get_csrs()
self.srams.append((name, memory, mapaddr, mmap))
# Collect CSR Constants.
# ----------------------
if hasattr(obj, "get_constants"):
for constant in obj.get_constants():
self.constants.append((name, constant))
# Create CSRBank with CSRs found.
# -------------------------------
if csrs:
mapaddr = self.address_map(name, None)
if mapaddr is None:

View file

@ -34,8 +34,8 @@ class CSRModule(Module, csr.AutoCSR):
# # #
# When csr is written:
# - set storage to 0xdeadbeef
# - set status to storage value
# - Set storage to 0xdeadbeef.
# - Set status to storage value.
self.comb += [
If(self._csr.re,
self._storage.we.eq(1),
@ -55,11 +55,15 @@ class CSRDUT(Module):
def __init__(self):
self.csr = csr_bus.Interface()
self.submodules.csrmodule = CSRModule()
self.submodules.csrmodule = CSRModule()
self.submodules.csrbankarray = csr_bus.CSRBankArray(
self, self.address_map)
source = self,
address_map = self.address_map,
)
self.submodules.csrcon = csr_bus.Interconnect(
self.csr, self.csrbankarray.get_buses())
master = self.csr,
slaves = self.csrbankarray.get_buses()
)
class TestCSR(unittest.TestCase):
def test_csr_constant(self):