mirror of
https://github.com/enjoy-digital/litex.git
synced 2025-01-04 09:52:26 -05:00
Refactor Case
This commit is contained in:
parent
070652cc39
commit
6eebfce44a
10 changed files with 54 additions and 60 deletions
|
@ -50,10 +50,10 @@ class Unpack(Actor):
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
]
|
]
|
||||||
cases = [(i if i else Default(),
|
cases = {}
|
||||||
Cat(*self.token("source").flatten()).eq(Cat(*self.token("sink").subrecord("chunk{0}".format(i)).flatten())))
|
for i in range(self.n):
|
||||||
for i in range(self.n)]
|
cases[i] = [Cat(*self.token("source").flatten()).eq(Cat(*self.token("sink").subrecord("chunk{0}".format(i)).flatten()))]
|
||||||
comb.append(Case(mux, *cases))
|
comb.append(Case(mux, cases).makedefault())
|
||||||
return Fragment(comb, sync)
|
return Fragment(comb, sync)
|
||||||
|
|
||||||
class Pack(Actor):
|
class Pack(Actor):
|
||||||
|
@ -69,9 +69,9 @@ class Pack(Actor):
|
||||||
|
|
||||||
load_part = Signal()
|
load_part = Signal()
|
||||||
strobe_all = Signal()
|
strobe_all = Signal()
|
||||||
cases = [(i,
|
cases = {}
|
||||||
Cat(*self.token("source").subrecord("chunk{0}".format(i)).flatten()).eq(*self.token("sink").flatten()))
|
for i in range(self.n):
|
||||||
for i in range(self.n)]
|
cases[i] = [Cat(*self.token("source").subrecord("chunk{0}".format(i)).flatten()).eq(*self.token("sink").flatten())]
|
||||||
comb = [
|
comb = [
|
||||||
self.busy.eq(strobe_all),
|
self.busy.eq(strobe_all),
|
||||||
self.endpoints["sink"].ack.eq(~strobe_all | self.endpoints["source"].ack),
|
self.endpoints["sink"].ack.eq(~strobe_all | self.endpoints["source"].ack),
|
||||||
|
@ -83,7 +83,7 @@ class Pack(Actor):
|
||||||
strobe_all.eq(0)
|
strobe_all.eq(0)
|
||||||
),
|
),
|
||||||
If(load_part,
|
If(load_part,
|
||||||
Case(demux, *cases),
|
Case(demux, cases),
|
||||||
If(demux == (self.n - 1),
|
If(demux == (self.n - 1),
|
||||||
demux.eq(0),
|
demux.eq(0),
|
||||||
strobe_all.eq(1)
|
strobe_all.eq(1)
|
||||||
|
|
|
@ -21,7 +21,7 @@ class Bank:
|
||||||
nbits = bits_for(len(desc_exp)-1)
|
nbits = bits_for(len(desc_exp)-1)
|
||||||
|
|
||||||
# Bus writes
|
# Bus writes
|
||||||
bwcases = []
|
bwcases = {}
|
||||||
for i, reg in enumerate(desc_exp):
|
for i, reg in enumerate(desc_exp):
|
||||||
if isinstance(reg, RegisterRaw):
|
if isinstance(reg, RegisterRaw):
|
||||||
comb.append(reg.r.eq(self.interface.dat_w[:reg.size]))
|
comb.append(reg.r.eq(self.interface.dat_w[:reg.size]))
|
||||||
|
@ -29,14 +29,14 @@ class Bank:
|
||||||
self.interface.we & \
|
self.interface.we & \
|
||||||
(self.interface.adr[:nbits] == i)))
|
(self.interface.adr[:nbits] == i)))
|
||||||
elif isinstance(reg, RegisterFields):
|
elif isinstance(reg, RegisterFields):
|
||||||
bwra = [i]
|
bwra = []
|
||||||
offset = 0
|
offset = 0
|
||||||
for field in reg.fields:
|
for field in reg.fields:
|
||||||
if field.access_bus == WRITE_ONLY or field.access_bus == READ_WRITE:
|
if field.access_bus == WRITE_ONLY or field.access_bus == READ_WRITE:
|
||||||
bwra.append(field.storage.eq(self.interface.dat_w[offset:offset+field.size]))
|
bwra.append(field.storage.eq(self.interface.dat_w[offset:offset+field.size]))
|
||||||
offset += field.size
|
offset += field.size
|
||||||
if len(bwra) > 1:
|
if bwra:
|
||||||
bwcases.append(bwra)
|
bwcases[i] = bwra
|
||||||
# commit atomic writes
|
# commit atomic writes
|
||||||
for field in reg.fields:
|
for field in reg.fields:
|
||||||
if isinstance(field, FieldAlias) and field.commit_list:
|
if isinstance(field, FieldAlias) and field.commit_list:
|
||||||
|
@ -45,13 +45,13 @@ class Bank:
|
||||||
else:
|
else:
|
||||||
raise TypeError
|
raise TypeError
|
||||||
if bwcases:
|
if bwcases:
|
||||||
sync.append(If(sel & self.interface.we, Case(self.interface.adr[:nbits], *bwcases)))
|
sync.append(If(sel & self.interface.we, Case(self.interface.adr[:nbits], bwcases)))
|
||||||
|
|
||||||
# Bus reads
|
# Bus reads
|
||||||
brcases = []
|
brcases = {}
|
||||||
for i, reg in enumerate(desc_exp):
|
for i, reg in enumerate(desc_exp):
|
||||||
if isinstance(reg, RegisterRaw):
|
if isinstance(reg, RegisterRaw):
|
||||||
brcases.append([i, self.interface.dat_r.eq(reg.w)])
|
brcases[i] = [self.interface.dat_r.eq(reg.w)]
|
||||||
elif isinstance(reg, RegisterFields):
|
elif isinstance(reg, RegisterFields):
|
||||||
brs = []
|
brs = []
|
||||||
reg_readable = False
|
reg_readable = False
|
||||||
|
@ -62,12 +62,12 @@ class Bank:
|
||||||
else:
|
else:
|
||||||
brs.append(Replicate(0, field.size))
|
brs.append(Replicate(0, field.size))
|
||||||
if reg_readable:
|
if reg_readable:
|
||||||
brcases.append([i, self.interface.dat_r.eq(Cat(*brs))])
|
brcases[i] = [self.interface.dat_r.eq(Cat(*brs))]
|
||||||
else:
|
else:
|
||||||
raise TypeError
|
raise TypeError
|
||||||
if brcases:
|
if brcases:
|
||||||
sync.append(self.interface.dat_r.eq(0))
|
sync.append(self.interface.dat_r.eq(0))
|
||||||
sync.append(If(sel, Case(self.interface.adr[:nbits], *brcases)))
|
sync.append(If(sel, Case(self.interface.adr[:nbits], brcases)))
|
||||||
else:
|
else:
|
||||||
comb.append(self.interface.dat_r.eq(0))
|
comb.append(self.interface.dat_r.eq(0))
|
||||||
|
|
||||||
|
|
|
@ -33,10 +33,10 @@ class FSM:
|
||||||
self.actions[state] += statements
|
self.actions[state] += statements
|
||||||
|
|
||||||
def get_fragment(self):
|
def get_fragment(self):
|
||||||
cases = [[s] + a for s, a in enumerate(self.actions) if a]
|
cases = dict((s, a) for s, a in enumerate(self.actions) if a)
|
||||||
comb = [
|
comb = [
|
||||||
self._next_state.eq(self._state),
|
self._next_state.eq(self._state),
|
||||||
Case(self._state, *cases)
|
Case(self._state, cases)
|
||||||
]
|
]
|
||||||
sync = [self._state.eq(self._next_state)]
|
sync = [self._state.eq(self._next_state)]
|
||||||
return Fragment(comb, sync)
|
return Fragment(comb, sync)
|
||||||
|
|
|
@ -43,15 +43,14 @@ def chooser(signal, shift, output, n=None, reverse=False):
|
||||||
if n is None:
|
if n is None:
|
||||||
n = 2**len(shift)
|
n = 2**len(shift)
|
||||||
w = len(output)
|
w = len(output)
|
||||||
cases = []
|
cases = {}
|
||||||
for i in range(n):
|
for i in range(n):
|
||||||
if reverse:
|
if reverse:
|
||||||
s = n - i - 1
|
s = n - i - 1
|
||||||
else:
|
else:
|
||||||
s = i
|
s = i
|
||||||
cases.append([i, output.eq(signal[s*w:(s+1)*w])])
|
cases[i] = [output.eq(signal[s*w:(s+1)*w])]
|
||||||
cases[n-1][0] = Default()
|
return Case(shift, cases).makedefault()
|
||||||
return Case(shift, *cases)
|
|
||||||
|
|
||||||
def timeline(trigger, events):
|
def timeline(trigger, events):
|
||||||
lastevent = max([e[0] for e in events])
|
lastevent = max([e[0] for e in events])
|
||||||
|
|
|
@ -14,7 +14,7 @@ class RoundRobin:
|
||||||
|
|
||||||
def get_fragment(self):
|
def get_fragment(self):
|
||||||
if self.n > 1:
|
if self.n > 1:
|
||||||
cases = []
|
cases = {}
|
||||||
for i in range(self.n):
|
for i in range(self.n):
|
||||||
switch = []
|
switch = []
|
||||||
for j in reversed(range(i+1,i+self.n)):
|
for j in reversed(range(i+1,i+self.n)):
|
||||||
|
@ -30,8 +30,8 @@ class RoundRobin:
|
||||||
case = [If(~self.request[i], *switch)]
|
case = [If(~self.request[i], *switch)]
|
||||||
else:
|
else:
|
||||||
case = switch
|
case = switch
|
||||||
cases.append([i] + case)
|
cases[i] = case
|
||||||
statement = Case(self.grant, *cases)
|
statement = Case(self.grant, cases)
|
||||||
if self.switch_policy == SP_CE:
|
if self.switch_policy == SP_CE:
|
||||||
statement = If(self.ce, statement)
|
statement = If(self.ce, statement)
|
||||||
return Fragment(sync=[statement])
|
return Fragment(sync=[statement])
|
||||||
|
|
|
@ -189,21 +189,19 @@ def _insert_else(obj, clause):
|
||||||
o = o.f[0]
|
o = o.f[0]
|
||||||
o.f = clause
|
o.f = clause
|
||||||
|
|
||||||
class Default:
|
|
||||||
pass
|
|
||||||
|
|
||||||
class Case:
|
class Case:
|
||||||
def __init__(self, test, *cases):
|
def __init__(self, test, cases):
|
||||||
self.test = test
|
self.test = test
|
||||||
self.cases = [(c[0], list(c[1:])) for c in cases if not isinstance(c[0], Default)]
|
self.cases = cases
|
||||||
self.default = None
|
|
||||||
for c in cases:
|
def makedefault(self, key=None):
|
||||||
if isinstance(c[0], Default):
|
if key is None:
|
||||||
if self.default is not None:
|
for choice in self.cases.keys():
|
||||||
raise ValueError
|
if key is None or choice > key:
|
||||||
self.default = list(c[1:])
|
key = choice
|
||||||
if self.default is None:
|
self.cases["default"] = self.cases[key]
|
||||||
self.default = []
|
del self.cases[key]
|
||||||
|
return self
|
||||||
|
|
||||||
# arrays
|
# arrays
|
||||||
|
|
||||||
|
|
|
@ -151,21 +151,19 @@ class _ArrayLowerer(NodeTransformer):
|
||||||
def visit_Assign(self, node):
|
def visit_Assign(self, node):
|
||||||
if isinstance(node.l, _ArrayProxy):
|
if isinstance(node.l, _ArrayProxy):
|
||||||
k = self.visit(node.l.key)
|
k = self.visit(node.l.key)
|
||||||
cases = []
|
cases = {}
|
||||||
for n, choice in enumerate(node.l.choices):
|
for n, choice in enumerate(node.l.choices):
|
||||||
assign = self.visit_Assign(_Assign(choice, node.r))
|
assign = self.visit_Assign(_Assign(choice, node.r))
|
||||||
cases.append([n, assign])
|
cases[n] = [assign]
|
||||||
cases[-1][0] = Default()
|
return Case(k, cases).makedefault()
|
||||||
return Case(k, *cases)
|
|
||||||
else:
|
else:
|
||||||
return super().visit_Assign(node)
|
return super().visit_Assign(node)
|
||||||
|
|
||||||
def visit_ArrayProxy(self, node):
|
def visit_ArrayProxy(self, node):
|
||||||
array_muxed = Signal(value_bv(node))
|
array_muxed = Signal(value_bv(node))
|
||||||
cases = [[n, _Assign(array_muxed, self.visit(choice))]
|
cases = dict((n, _Assign(array_muxed, self.visit(choice)))
|
||||||
for n, choice in enumerate(node.choices)]
|
for n, choice in enumerate(node.choices))
|
||||||
cases[-1][0] = Default()
|
self.comb.append(Case(self.visit(node.key), cases).makedefault())
|
||||||
self.comb.append(Case(self.visit(node.key), *cases))
|
|
||||||
return array_muxed
|
return array_muxed
|
||||||
|
|
||||||
def lower_arrays(f):
|
def lower_arrays(f):
|
||||||
|
|
|
@ -92,15 +92,16 @@ def _printnode(ns, at, level, node):
|
||||||
r += "\t"*level + "end\n"
|
r += "\t"*level + "end\n"
|
||||||
return r
|
return r
|
||||||
elif isinstance(node, Case):
|
elif isinstance(node, Case):
|
||||||
if node.cases or node.default:
|
if node.cases:
|
||||||
r = "\t"*level + "case (" + _printexpr(ns, node.test) + ")\n"
|
r = "\t"*level + "case (" + _printexpr(ns, node.test) + ")\n"
|
||||||
for case in node.cases:
|
css = sorted([(k, v) for (k, v) in node.cases.items() if k != "default"], key=itemgetter(0))
|
||||||
r += "\t"*(level + 1) + _printexpr(ns, case[0]) + ": begin\n"
|
for choice, statements in css:
|
||||||
r += _printnode(ns, at, level + 2, case[1])
|
r += "\t"*(level + 1) + _printexpr(ns, choice) + ": begin\n"
|
||||||
|
r += _printnode(ns, at, level + 2, statements)
|
||||||
r += "\t"*(level + 1) + "end\n"
|
r += "\t"*(level + 1) + "end\n"
|
||||||
if node.default:
|
if "default" in node.cases:
|
||||||
r += "\t"*(level + 1) + "default: begin\n"
|
r += "\t"*(level + 1) + "default: begin\n"
|
||||||
r += _printnode(ns, at, level + 2, node.default)
|
r += _printnode(ns, at, level + 2, node.cases["default"])
|
||||||
r += "\t"*(level + 1) + "end\n"
|
r += "\t"*(level + 1) + "end\n"
|
||||||
r += "\t"*level + "endcase\n"
|
r += "\t"*level + "endcase\n"
|
||||||
return r
|
return r
|
||||||
|
|
|
@ -65,9 +65,8 @@ class NodeVisitor:
|
||||||
|
|
||||||
def visit_Case(self, node):
|
def visit_Case(self, node):
|
||||||
self.visit(node.test)
|
self.visit(node.test)
|
||||||
for v, statements in node.cases:
|
for v, statements in node.cases.items():
|
||||||
self.visit(statements)
|
self.visit(statements)
|
||||||
self.visit(node.default)
|
|
||||||
|
|
||||||
def visit_Fragment(self, node):
|
def visit_Fragment(self, node):
|
||||||
self.visit(node.comb)
|
self.visit(node.comb)
|
||||||
|
@ -155,9 +154,8 @@ class NodeTransformer:
|
||||||
return r
|
return r
|
||||||
|
|
||||||
def visit_Case(self, node):
|
def visit_Case(self, node):
|
||||||
r = Case(self.visit(node.test))
|
cases = dict((v, self.visit(statements)) for v, statements in node.cases.items())
|
||||||
r.cases = [(v, self.visit(statements)) for v, statements in node.cases]
|
r = Case(self.visit(node.test), cases)
|
||||||
r.default = self.visit(node.default)
|
|
||||||
return r
|
return r
|
||||||
|
|
||||||
def visit_Fragment(self, node):
|
def visit_Fragment(self, node):
|
||||||
|
|
|
@ -48,6 +48,6 @@ class ImplRegister:
|
||||||
raise FinalizeError
|
raise FinalizeError
|
||||||
# do nothing when sel == 0
|
# do nothing when sel == 0
|
||||||
items = sorted(self.source_encoding.items(), key=itemgetter(1))
|
items = sorted(self.source_encoding.items(), key=itemgetter(1))
|
||||||
cases = [(v, self.storage.eq(self.id_to_source[k])) for k, v in items]
|
cases = dict((v, self.storage.eq(self.id_to_source[k])) for k, v in items)
|
||||||
sync = [Case(self.sel, *cases)]
|
sync = [Case(self.sel, cases)]
|
||||||
return Fragment(sync=sync)
|
return Fragment(sync=sync)
|
||||||
|
|
Loading…
Reference in a new issue