mirror of
https://github.com/enjoy-digital/litex.git
synced 2025-01-04 09:52:26 -05:00
248 lines
6.4 KiB
Python
248 lines
6.4 KiB
Python
from migen.fhdl.structure import *
|
|
from migen.fhdl.structure import _Slice, _Assign
|
|
from migen.fhdl.visit import NodeVisitor, NodeTransformer
|
|
from migen.fhdl.bitcontainer import value_bits_sign
|
|
from migen.util.misc import flat_iteration
|
|
|
|
class _SignalLister(NodeVisitor):
    """Visitor that records every ``Signal`` node it is dispatched on.

    The result is exposed through ``output_list`` which, despite its name,
    is a ``set``: each signal is recorded at most once.
    """

    def __init__(self):
        self.output_list = set()

    def visit_Signal(self, node):
        """Remember *node* as one of the signals encountered."""
        self.output_list.add(node)
|
|
|
|
class _TargetLister(NodeVisitor):
    """Visitor that records the signals found on the left of assignments.

    Signals are only recorded while ``target_context`` is set, which happens
    exclusively while the left-hand side of an assignment is being visited.
    The result is exposed through ``output_list`` (a ``set``).
    """

    def __init__(self):
        self.output_list = set()
        # True only while the left-hand side of an assignment is being walked.
        self.target_context = False

    def visit_Signal(self, node):
        """Record *node*, but only when it is reached as an assignment target."""
        if not self.target_context:
            return
        self.output_list.add(node)

    def visit_Assign(self, node):
        """Visit the left-hand side of *node*; the right-hand side is skipped."""
        self.target_context = True
        self.visit(node.l)
        self.target_context = False

    def visit_ArrayProxy(self, node):
        """Visit every choice of the proxy; ``node.key`` is not visited."""
        for candidate in node.choices:
            self.visit(candidate)
|
|
|
|
def list_signals(node):
    """Return the set of signals collected by running ``_SignalLister`` on *node*."""
    collector = _SignalLister()
    collector.visit(node)
    return collector.output_list
|
|
|
|
def list_targets(node):
    """Return the set of signals collected by running ``_TargetLister`` on *node*.

    These are the signals reached through the left-hand side of the
    assignments found under *node*.
    """
    collector = _TargetLister()
    collector.visit(node)
    return collector.output_list
|
|
|
|
def _resort_statements(ol):
    """Return the statements of *ol* sorted by their recorded position.

    *ol* is an iterable of ``(order, statement)`` pairs. The pairs are sorted
    on ``order`` only (ties keep their relative input order) and the bare
    statements are returned as a new list.
    """
    ordered_pairs = list(ol)
    ordered_pairs.sort(key=lambda pair: pair[0])
    return [stmt for _, stmt in ordered_pairs]
|
|
|
|
def group_by_targets(sl):
    """Group the statements of *sl* by the signals they assign to.

    The statements yielded by ``flat_iteration(sl)`` are examined in order.
    Two statements end up in the same group when their target sets (as given
    by ``list_targets``) overlap, directly or through other statements.

    Returns a list of ``(targets, statements)`` tuples, where ``targets`` is
    the set of signals assigned by the group and ``statements`` lists the
    group's statements in their original iteration order.
    """
    merged = []
    all_targets = set()
    for index, statement in enumerate(flat_iteration(sl)):
        stmt_targets = set(list_targets(statement))
        stmt_group = [(index, statement)]
        overlaps = not stmt_targets.isdisjoint(all_targets)
        all_targets |= stmt_targets
        if overlaps:
            # Rebuild the group list, absorbing every existing group that
            # shares a target with this statement. stmt_targets grows while
            # absorbing, so later groups are tested against the enlarged set.
            untouched = []
            for prev_targets, prev_group in merged:
                if stmt_targets.isdisjoint(prev_targets):
                    untouched.append((prev_targets, prev_group))
                else:
                    stmt_targets |= prev_targets
                    stmt_group += prev_group
            merged = untouched
        merged.append((stmt_targets, stmt_group))
    return [(group_targets, _resort_statements(members))
            for group_targets, members in merged]
|
|
|
|
def list_special_ios(f, ins, outs, inouts):
    """Return the union of ``list_ios(ins, outs, inouts)`` over ``f.specials``.

    *ins*, *outs* and *inouts* are forwarded unchanged to each special. An
    empty set is returned when *f* has no specials.
    """
    collected = set()
    per_special = (special.list_ios(ins, outs, inouts)
                   for special in f.specials)
    for ios in per_special:
        collected |= ios
    return collected
|
|
|
|
class _ClockDomainLister(NodeVisitor):
    """Visitor that gathers the clock domain names referenced by a node tree.

    Names come from the ``cd`` attribute of ``ClockSignal`` / ``ResetSignal``
    nodes and from the keys of the clock-domain-to-statements mappings
    handed to ``visit_clock_domains``. They accumulate in the
    ``clock_domains`` set.
    """

    def __init__(self):
        self.clock_domains = set()

    def visit_ClockSignal(self, node):
        """Record the domain whose clock *node* refers to."""
        self.clock_domains.add(node.cd)

    def visit_ResetSignal(self, node):
        """Record the domain whose reset *node* refers to."""
        self.clock_domains.add(node.cd)

    def visit_clock_domains(self, node):
        """Record each key of the mapping *node*, then visit its statements."""
        for domain_name, domain_statements in node.items():
            self.clock_domains.add(domain_name)
            self.visit(domain_statements)
|
|
|
|
def list_clock_domains_expr(f):
    """Return the set of clock domain names found by ``_ClockDomainLister`` in *f*."""
    domain_lister = _ClockDomainLister()
    domain_lister.visit(f)
    return domain_lister.clock_domains
|
|
|
|
def list_clock_domains(f):
    """Return the set of all clock domain names related to fragment *f*.

    Combines three sources: the names found by ``list_clock_domains_expr``,
    the names reported by ``list_clock_domains()`` of every special in
    ``f.specials``, and the ``name`` of every entry of ``f.clock_domains``.
    """
    domains = list_clock_domains_expr(f)
    for special in f.specials:
        domains |= special.list_clock_domains()
    domains.update(cd.name for cd in f.clock_domains)
    return domains
|
|
|
|
def is_variable(node):
    """Tell whether the assignment target *node* is a variable.

    * For a ``Signal``, its ``variable`` attribute is returned.
    * For a ``_Slice``, the answer is that of the sliced value.
    * For a ``Cat``, all operands must agree and the common answer is
      returned.

    Raises:
        TypeError: if *node* is of any other type, if a ``Cat`` has no
            operands (there is nothing to derive the answer from), or if a
            ``Cat`` mixes variable and non-variable operands.
    """
    if isinstance(node, Signal):
        return node.variable
    if isinstance(node, _Slice):
        return is_variable(node.value)
    if isinstance(node, Cat):
        flags = [is_variable(operand) for operand in node.l]
        if not flags:
            # Previously surfaced as an IndexError from flags[0]; report it
            # with the same exception type as every other unsupported input.
            raise TypeError("cannot tell whether an empty Cat is a variable")
        first = flags[0]
        if any(flag != first for flag in flags):
            raise TypeError(
                "Cat mixes variable and non-variable operands")
        return first
    raise TypeError(
        "cannot tell whether a {} is a variable".format(type(node).__name__))
|
|
|
|
def generate_reset(rst, sl):
    """Return statements assigning each target of *sl* its reset value.

    The targets are obtained with ``list_targets`` and processed in
    increasing ``huid`` order; each produces ``target.eq(target.reset)``.
    *rst* is accepted but not used here.
    """
    ordered_targets = sorted(list_targets(sl), key=lambda sig: sig.huid)
    return [sig.eq(sig.reset) for sig in ordered_targets]
|
|
|
|
def insert_reset(rst, sl):
    """Wrap the statements *sl* in a reset conditional.

    Returns a one-element list holding ``If(rst, <reset statements>)`` whose
    ``Else`` branch contains the original statements; the reset statements
    come from ``generate_reset``.
    """
    reset_statements = generate_reset(rst, sl)
    guarded = If(rst, *reset_statements).Else(*sl)
    return [guarded]
|
|
|
|
def insert_resets(f):
    """Add reset logic to the synchronous statements of fragment *f*.

    For every domain in ``f.sync`` whose entry in ``f.clock_domains`` has a
    ``rst`` that is not ``None``, the statements are replaced by the result
    of ``insert_reset``; other domains keep their statements untouched.
    ``f.sync`` is rebound to a newly built dictionary.
    """
    def with_reset(domain, statements):
        # Domains without a reset signal are passed through as they are.
        if f.clock_domains[domain].rst is None:
            return statements
        return insert_reset(ResetSignal(domain), statements)

    f.sync = {domain: with_reset(domain, statements)
              for domain, statements in f.sync.items()}
|
|
|
|
class _Lowerer(NodeTransformer):
    """Base transformer shared by the lowering passes in this file.

    Attributes:
        target_context: True while the left-hand side of an assignment is
            being visited.
        extra_stmts: statements that must accompany the assignment currently
            being rewritten.
        comb: statements collected during the visit; ``_apply_lowerer`` adds
            them to the fragment's combinatorial statements.
    """

    def __init__(self):
        self.target_context = False
        self.extra_stmts = []
        self.comb = []

    def visit_Assign(self, node):
        """Rewrite an assignment, visiting its target in target context.

        Returns the new ``_Assign``, or a list made of it followed by the
        extra statements that were produced while visiting its operands.
        The previous ``target_context`` and ``extra_stmts`` are restored
        before returning.
        """
        saved_context = self.target_context
        saved_extra = self.extra_stmts
        self.extra_stmts = []

        self.target_context = True
        new_lhs = self.visit(node.l)
        self.target_context = False
        new_rhs = self.visit(node.r)

        lowered = _Assign(new_lhs, new_rhs)
        if self.extra_stmts:
            lowered = [lowered] + self.extra_stmts

        self.target_context = saved_context
        self.extra_stmts = saved_extra
        return lowered
|
|
|
|
# Basics are FHDL structure elements that back-ends are not required to support
# but can be expressed in terms of other elements (lowered) before conversion.
class _BasicLowerer(_Lowerer):
    """Lowerer for ``ArrayProxy``, ``ClockSignal`` and ``ResetSignal`` nodes.

    *clock_domains* is indexed by domain name to resolve clock and reset
    signal references.
    """

    def __init__(self, clock_domains):
        self.clock_domains = clock_domains
        _Lowerer.__init__(self)

    def visit_ArrayProxy(self, node):
        """Replace an array proxy by an intermediate variable signal.

        As a target, a ``Case`` on the key copying the intermediate signal
        into the selected choice is queued in ``extra_stmts``. As a source,
        a ``Case`` on the key loading the selected choice into the
        intermediate signal is queued in ``comb``.
        """
        # TODO: rewrite without variables
        array_muxed = Signal(value_bits_sign(node), variable=True)
        if self.target_context:
            # The key is visited before the choices.
            key = self.visit(node.key)
            branches = {
                index: [self.visit_Assign(_Assign(choice, array_muxed))]
                for index, choice in enumerate(node.choices)
            }
            self.extra_stmts.append(Case(key, branches).makedefault())
        else:
            # The choices are visited before the key.
            branches = {
                index: _Assign(array_muxed, self.visit(choice))
                for index, choice in enumerate(node.choices)
            }
            key = self.visit(node.key)
            self.comb.append(Case(key, branches).makedefault())
        return array_muxed

    def visit_ClockSignal(self, node):
        """Return the ``clk`` of the clock domain named by ``node.cd``."""
        domain = self.clock_domains[node.cd]
        return domain.clk

    def visit_ResetSignal(self, node):
        """Return the ``rst`` of the clock domain named by ``node.cd``."""
        domain = self.clock_domains[node.cd]
        return domain.rst
|
|
|
|
class _ComplexSliceLowerer(_Lowerer):
    """Lowerer that makes every slice operate directly on a ``Signal``."""

    def visit_Slice(self, node):
        """Route slices of non-signal values through a proxy signal.

        When the sliced value is not a ``Signal``, a proxy signal of the
        same size/sign is created, an assignment linking it with the
        original value (direction depending on ``target_context``) is
        queued in ``comb``, and the slice is re-created on the proxy.
        """
        if isinstance(node.value, Signal):
            return NodeTransformer.visit_Slice(self, node)

        slice_proxy = Signal(value_bits_sign(node.value))
        if self.target_context:
            # The slice is written: the proxy drives the original value.
            link = _Assign(node.value, slice_proxy)
        else:
            # The slice is read: the original value drives the proxy.
            link = _Assign(slice_proxy, node.value)
        self.comb.append(self.visit_Assign(link))
        proxied = _Slice(slice_proxy, node.start, node.stop)
        return NodeTransformer.visit_Slice(self, proxied)
|
|
|
|
def _apply_lowerer(l, f):
    """Run lowerer *l* over fragment *f* and over its specials' expressions.

    The fragment returned by ``l.visit(f)`` receives the statements the
    lowerer collected in ``l.comb``. Then every non-inout expression of every
    special is lowered individually and written back to its owner, the
    statements produced along the way being added to the fragment's ``comb``.

    Returns the transformed fragment.
    """
    f = l.visit(f)
    f.comb += l.comb

    for special in f.specials:
        for obj, attr, direction in special.iter_expressions():
            if direction == SPECIAL_INOUT:
                # inouts are only supported by Migen when connected directly to top-level
                # in this case, they are Signal and never need lowering
                continue
            # Start from a clean lowerer state for each expression.
            l.comb = []
            l.target_context = direction != SPECIAL_INPUT
            l.extra_stmts = []
            lowered = l.visit(getattr(obj, attr))
            setattr(obj, attr, lowered)
            f.comb += l.comb + l.extra_stmts

    return f
|
|
|
|
def lower_basics(f):
    """Return fragment *f* lowered by a ``_BasicLowerer`` using its clock domains."""
    lowerer = _BasicLowerer(f.clock_domains)
    return _apply_lowerer(lowerer, f)
|
|
|
|
def lower_complex_slices(f):
    """Return fragment *f* lowered by a ``_ComplexSliceLowerer``."""
    lowerer = _ComplexSliceLowerer()
    return _apply_lowerer(lowerer, f)
|
|
|
|
class _ClockDomainRenamer(NodeVisitor):
    """Visitor that retargets clock/reset signal references in place.

    Every ``ClockSignal`` / ``ResetSignal`` node whose ``cd`` equals *old*
    has it overwritten with *new*; other nodes are left alone.
    """

    def __init__(self, old, new):
        self.old, self.new = old, new

    def visit_ClockSignal(self, node):
        """Point *node* at the new domain if it refers to the old one."""
        if node.cd != self.old:
            return
        node.cd = self.new

    def visit_ResetSignal(self, node):
        """Point *node* at the new domain if it refers to the old one."""
        if node.cd != self.old:
            return
        node.cd = self.new
|
|
|
|
def rename_clock_domain_expr(f, old, new):
    """Rename domain *old* to *new* in the clock/reset references under *f*.

    The nodes are modified in place by a ``_ClockDomainRenamer``; nothing is
    returned.
    """
    renamer = _ClockDomainRenamer(old, new)
    renamer.visit(f)
|
|
|
|
def rename_clock_domain(f, old, new):
    """Rename clock domain *old* to *new* throughout fragment *f*.

    This updates, in place:

    * the clock/reset signal references (via ``rename_clock_domain_expr``);
    * ``f.sync``: the statements of *old* are moved under *new*, appended to
      the statements *new* already has, if any;
    * every special in ``f.specials`` (via its ``rename_clock_domain``);
    * the clock domain object registered under *old* in ``f.clock_domains``,
      if there is one (via its ``rename``).

    Renaming a domain to itself is a no-op, and *old* does not need to have
    synchronous statements.
    """
    if new == old:
        # Nothing to do. Without this guard the statement list would be
        # extended with itself and then deleted, losing the whole domain.
        return

    rename_clock_domain_expr(f, old, new)

    # A domain may be used (e.g. by specials only) without owning any
    # synchronous statement, in which case there is nothing to move.
    if old in f.sync:
        moved = f.sync.pop(old)
        if new in f.sync:
            f.sync[new].extend(moved)
        else:
            f.sync[new] = moved

    for special in f.specials:
        special.rename_clock_domain(old, new)

    try:
        cd = f.clock_domains[old]
    except KeyError:
        # No clock domain object is registered under the old name.
        pass
    else:
        cd.rename(new)
|