mirror of
https://github.com/enjoy-digital/litex.git
synced 2025-01-04 09:52:26 -05:00
Refactor Case
This commit is contained in:
parent
070652cc39
commit
6eebfce44a
10 changed files with 54 additions and 60 deletions
|
@ -50,10 +50,10 @@ class Unpack(Actor):
|
|||
)
|
||||
)
|
||||
]
|
||||
cases = [(i if i else Default(),
|
||||
Cat(*self.token("source").flatten()).eq(Cat(*self.token("sink").subrecord("chunk{0}".format(i)).flatten())))
|
||||
for i in range(self.n)]
|
||||
comb.append(Case(mux, *cases))
|
||||
cases = {}
|
||||
for i in range(self.n):
|
||||
cases[i] = [Cat(*self.token("source").flatten()).eq(Cat(*self.token("sink").subrecord("chunk{0}".format(i)).flatten()))]
|
||||
comb.append(Case(mux, cases).makedefault())
|
||||
return Fragment(comb, sync)
|
||||
|
||||
class Pack(Actor):
|
||||
|
@ -69,9 +69,9 @@ class Pack(Actor):
|
|||
|
||||
load_part = Signal()
|
||||
strobe_all = Signal()
|
||||
cases = [(i,
|
||||
Cat(*self.token("source").subrecord("chunk{0}".format(i)).flatten()).eq(*self.token("sink").flatten()))
|
||||
for i in range(self.n)]
|
||||
cases = {}
|
||||
for i in range(self.n):
|
||||
cases[i] = [Cat(*self.token("source").subrecord("chunk{0}".format(i)).flatten()).eq(*self.token("sink").flatten())]
|
||||
comb = [
|
||||
self.busy.eq(strobe_all),
|
||||
self.endpoints["sink"].ack.eq(~strobe_all | self.endpoints["source"].ack),
|
||||
|
@ -83,7 +83,7 @@ class Pack(Actor):
|
|||
strobe_all.eq(0)
|
||||
),
|
||||
If(load_part,
|
||||
Case(demux, *cases),
|
||||
Case(demux, cases),
|
||||
If(demux == (self.n - 1),
|
||||
demux.eq(0),
|
||||
strobe_all.eq(1)
|
||||
|
|
|
@ -21,7 +21,7 @@ class Bank:
|
|||
nbits = bits_for(len(desc_exp)-1)
|
||||
|
||||
# Bus writes
|
||||
bwcases = []
|
||||
bwcases = {}
|
||||
for i, reg in enumerate(desc_exp):
|
||||
if isinstance(reg, RegisterRaw):
|
||||
comb.append(reg.r.eq(self.interface.dat_w[:reg.size]))
|
||||
|
@ -29,14 +29,14 @@ class Bank:
|
|||
self.interface.we & \
|
||||
(self.interface.adr[:nbits] == i)))
|
||||
elif isinstance(reg, RegisterFields):
|
||||
bwra = [i]
|
||||
bwra = []
|
||||
offset = 0
|
||||
for field in reg.fields:
|
||||
if field.access_bus == WRITE_ONLY or field.access_bus == READ_WRITE:
|
||||
bwra.append(field.storage.eq(self.interface.dat_w[offset:offset+field.size]))
|
||||
offset += field.size
|
||||
if len(bwra) > 1:
|
||||
bwcases.append(bwra)
|
||||
if bwra:
|
||||
bwcases[i] = bwra
|
||||
# commit atomic writes
|
||||
for field in reg.fields:
|
||||
if isinstance(field, FieldAlias) and field.commit_list:
|
||||
|
@ -45,13 +45,13 @@ class Bank:
|
|||
else:
|
||||
raise TypeError
|
||||
if bwcases:
|
||||
sync.append(If(sel & self.interface.we, Case(self.interface.adr[:nbits], *bwcases)))
|
||||
sync.append(If(sel & self.interface.we, Case(self.interface.adr[:nbits], bwcases)))
|
||||
|
||||
# Bus reads
|
||||
brcases = []
|
||||
brcases = {}
|
||||
for i, reg in enumerate(desc_exp):
|
||||
if isinstance(reg, RegisterRaw):
|
||||
brcases.append([i, self.interface.dat_r.eq(reg.w)])
|
||||
brcases[i] = [self.interface.dat_r.eq(reg.w)]
|
||||
elif isinstance(reg, RegisterFields):
|
||||
brs = []
|
||||
reg_readable = False
|
||||
|
@ -62,12 +62,12 @@ class Bank:
|
|||
else:
|
||||
brs.append(Replicate(0, field.size))
|
||||
if reg_readable:
|
||||
brcases.append([i, self.interface.dat_r.eq(Cat(*brs))])
|
||||
brcases[i] = [self.interface.dat_r.eq(Cat(*brs))]
|
||||
else:
|
||||
raise TypeError
|
||||
if brcases:
|
||||
sync.append(self.interface.dat_r.eq(0))
|
||||
sync.append(If(sel, Case(self.interface.adr[:nbits], *brcases)))
|
||||
sync.append(If(sel, Case(self.interface.adr[:nbits], brcases)))
|
||||
else:
|
||||
comb.append(self.interface.dat_r.eq(0))
|
||||
|
||||
|
|
|
@ -33,10 +33,10 @@ class FSM:
|
|||
self.actions[state] += statements
|
||||
|
||||
def get_fragment(self):
|
||||
cases = [[s] + a for s, a in enumerate(self.actions) if a]
|
||||
cases = dict((s, a) for s, a in enumerate(self.actions) if a)
|
||||
comb = [
|
||||
self._next_state.eq(self._state),
|
||||
Case(self._state, *cases)
|
||||
Case(self._state, cases)
|
||||
]
|
||||
sync = [self._state.eq(self._next_state)]
|
||||
return Fragment(comb, sync)
|
||||
|
|
|
@ -43,15 +43,14 @@ def chooser(signal, shift, output, n=None, reverse=False):
|
|||
if n is None:
|
||||
n = 2**len(shift)
|
||||
w = len(output)
|
||||
cases = []
|
||||
cases = {}
|
||||
for i in range(n):
|
||||
if reverse:
|
||||
s = n - i - 1
|
||||
else:
|
||||
s = i
|
||||
cases.append([i, output.eq(signal[s*w:(s+1)*w])])
|
||||
cases[n-1][0] = Default()
|
||||
return Case(shift, *cases)
|
||||
cases[i] = [output.eq(signal[s*w:(s+1)*w])]
|
||||
return Case(shift, cases).makedefault()
|
||||
|
||||
def timeline(trigger, events):
|
||||
lastevent = max([e[0] for e in events])
|
||||
|
|
|
@ -14,7 +14,7 @@ class RoundRobin:
|
|||
|
||||
def get_fragment(self):
|
||||
if self.n > 1:
|
||||
cases = []
|
||||
cases = {}
|
||||
for i in range(self.n):
|
||||
switch = []
|
||||
for j in reversed(range(i+1,i+self.n)):
|
||||
|
@ -30,8 +30,8 @@ class RoundRobin:
|
|||
case = [If(~self.request[i], *switch)]
|
||||
else:
|
||||
case = switch
|
||||
cases.append([i] + case)
|
||||
statement = Case(self.grant, *cases)
|
||||
cases[i] = case
|
||||
statement = Case(self.grant, cases)
|
||||
if self.switch_policy == SP_CE:
|
||||
statement = If(self.ce, statement)
|
||||
return Fragment(sync=[statement])
|
||||
|
|
|
@ -189,21 +189,19 @@ def _insert_else(obj, clause):
|
|||
o = o.f[0]
|
||||
o.f = clause
|
||||
|
||||
class Default:
|
||||
pass
|
||||
|
||||
class Case:
|
||||
def __init__(self, test, *cases):
|
||||
def __init__(self, test, cases):
|
||||
self.test = test
|
||||
self.cases = [(c[0], list(c[1:])) for c in cases if not isinstance(c[0], Default)]
|
||||
self.default = None
|
||||
for c in cases:
|
||||
if isinstance(c[0], Default):
|
||||
if self.default is not None:
|
||||
raise ValueError
|
||||
self.default = list(c[1:])
|
||||
if self.default is None:
|
||||
self.default = []
|
||||
self.cases = cases
|
||||
|
||||
def makedefault(self, key=None):
|
||||
if key is None:
|
||||
for choice in self.cases.keys():
|
||||
if key is None or choice > key:
|
||||
key = choice
|
||||
self.cases["default"] = self.cases[key]
|
||||
del self.cases[key]
|
||||
return self
|
||||
|
||||
# arrays
|
||||
|
||||
|
|
|
@ -151,21 +151,19 @@ class _ArrayLowerer(NodeTransformer):
|
|||
def visit_Assign(self, node):
|
||||
if isinstance(node.l, _ArrayProxy):
|
||||
k = self.visit(node.l.key)
|
||||
cases = []
|
||||
cases = {}
|
||||
for n, choice in enumerate(node.l.choices):
|
||||
assign = self.visit_Assign(_Assign(choice, node.r))
|
||||
cases.append([n, assign])
|
||||
cases[-1][0] = Default()
|
||||
return Case(k, *cases)
|
||||
cases[n] = [assign]
|
||||
return Case(k, cases).makedefault()
|
||||
else:
|
||||
return super().visit_Assign(node)
|
||||
|
||||
def visit_ArrayProxy(self, node):
|
||||
array_muxed = Signal(value_bv(node))
|
||||
cases = [[n, _Assign(array_muxed, self.visit(choice))]
|
||||
for n, choice in enumerate(node.choices)]
|
||||
cases[-1][0] = Default()
|
||||
self.comb.append(Case(self.visit(node.key), *cases))
|
||||
cases = dict((n, _Assign(array_muxed, self.visit(choice)))
|
||||
for n, choice in enumerate(node.choices))
|
||||
self.comb.append(Case(self.visit(node.key), cases).makedefault())
|
||||
return array_muxed
|
||||
|
||||
def lower_arrays(f):
|
||||
|
|
|
@ -92,15 +92,16 @@ def _printnode(ns, at, level, node):
|
|||
r += "\t"*level + "end\n"
|
||||
return r
|
||||
elif isinstance(node, Case):
|
||||
if node.cases or node.default:
|
||||
if node.cases:
|
||||
r = "\t"*level + "case (" + _printexpr(ns, node.test) + ")\n"
|
||||
for case in node.cases:
|
||||
r += "\t"*(level + 1) + _printexpr(ns, case[0]) + ": begin\n"
|
||||
r += _printnode(ns, at, level + 2, case[1])
|
||||
css = sorted([(k, v) for (k, v) in node.cases.items() if k != "default"], key=itemgetter(0))
|
||||
for choice, statements in css:
|
||||
r += "\t"*(level + 1) + _printexpr(ns, choice) + ": begin\n"
|
||||
r += _printnode(ns, at, level + 2, statements)
|
||||
r += "\t"*(level + 1) + "end\n"
|
||||
if node.default:
|
||||
if "default" in node.cases:
|
||||
r += "\t"*(level + 1) + "default: begin\n"
|
||||
r += _printnode(ns, at, level + 2, node.default)
|
||||
r += _printnode(ns, at, level + 2, node.cases["default"])
|
||||
r += "\t"*(level + 1) + "end\n"
|
||||
r += "\t"*level + "endcase\n"
|
||||
return r
|
||||
|
|
|
@ -65,9 +65,8 @@ class NodeVisitor:
|
|||
|
||||
def visit_Case(self, node):
|
||||
self.visit(node.test)
|
||||
for v, statements in node.cases:
|
||||
for v, statements in node.cases.items():
|
||||
self.visit(statements)
|
||||
self.visit(node.default)
|
||||
|
||||
def visit_Fragment(self, node):
|
||||
self.visit(node.comb)
|
||||
|
@ -155,9 +154,8 @@ class NodeTransformer:
|
|||
return r
|
||||
|
||||
def visit_Case(self, node):
|
||||
r = Case(self.visit(node.test))
|
||||
r.cases = [(v, self.visit(statements)) for v, statements in node.cases]
|
||||
r.default = self.visit(node.default)
|
||||
cases = dict((v, self.visit(statements)) for v, statements in node.cases.items())
|
||||
r = Case(self.visit(node.test), cases)
|
||||
return r
|
||||
|
||||
def visit_Fragment(self, node):
|
||||
|
|
|
@ -48,6 +48,6 @@ class ImplRegister:
|
|||
raise FinalizeError
|
||||
# do nothing when sel == 0
|
||||
items = sorted(self.source_encoding.items(), key=itemgetter(1))
|
||||
cases = [(v, self.storage.eq(self.id_to_source[k])) for k, v in items]
|
||||
sync = [Case(self.sel, *cases)]
|
||||
cases = dict((v, self.storage.eq(self.id_to_source[k])) for k, v in items)
|
||||
sync = [Case(self.sel, cases)]
|
||||
return Fragment(sync=sync)
|
||||
|
|
Loading…
Reference in a new issue