Refactor Case

This commit is contained in:
Sebastien Bourdeauducq 2012-11-29 01:11:15 +01:00
parent 070652cc39
commit 6eebfce44a
10 changed files with 54 additions and 60 deletions

View file

@ -50,10 +50,10 @@ class Unpack(Actor):
)
)
]
cases = [(i if i else Default(),
Cat(*self.token("source").flatten()).eq(Cat(*self.token("sink").subrecord("chunk{0}".format(i)).flatten())))
for i in range(self.n)]
comb.append(Case(mux, *cases))
cases = {}
for i in range(self.n):
cases[i] = [Cat(*self.token("source").flatten()).eq(Cat(*self.token("sink").subrecord("chunk{0}".format(i)).flatten()))]
comb.append(Case(mux, cases).makedefault())
return Fragment(comb, sync)
class Pack(Actor):
@ -69,9 +69,9 @@ class Pack(Actor):
load_part = Signal()
strobe_all = Signal()
cases = [(i,
Cat(*self.token("source").subrecord("chunk{0}".format(i)).flatten()).eq(*self.token("sink").flatten()))
for i in range(self.n)]
cases = {}
for i in range(self.n):
cases[i] = [Cat(*self.token("source").subrecord("chunk{0}".format(i)).flatten()).eq(*self.token("sink").flatten())]
comb = [
self.busy.eq(strobe_all),
self.endpoints["sink"].ack.eq(~strobe_all | self.endpoints["source"].ack),
@ -83,7 +83,7 @@ class Pack(Actor):
strobe_all.eq(0)
),
If(load_part,
Case(demux, *cases),
Case(demux, cases),
If(demux == (self.n - 1),
demux.eq(0),
strobe_all.eq(1)

View file

@ -21,7 +21,7 @@ class Bank:
nbits = bits_for(len(desc_exp)-1)
# Bus writes
bwcases = []
bwcases = {}
for i, reg in enumerate(desc_exp):
if isinstance(reg, RegisterRaw):
comb.append(reg.r.eq(self.interface.dat_w[:reg.size]))
@ -29,14 +29,14 @@ class Bank:
self.interface.we & \
(self.interface.adr[:nbits] == i)))
elif isinstance(reg, RegisterFields):
bwra = [i]
bwra = []
offset = 0
for field in reg.fields:
if field.access_bus == WRITE_ONLY or field.access_bus == READ_WRITE:
bwra.append(field.storage.eq(self.interface.dat_w[offset:offset+field.size]))
offset += field.size
if len(bwra) > 1:
bwcases.append(bwra)
if bwra:
bwcases[i] = bwra
# commit atomic writes
for field in reg.fields:
if isinstance(field, FieldAlias) and field.commit_list:
@ -45,13 +45,13 @@ class Bank:
else:
raise TypeError
if bwcases:
sync.append(If(sel & self.interface.we, Case(self.interface.adr[:nbits], *bwcases)))
sync.append(If(sel & self.interface.we, Case(self.interface.adr[:nbits], bwcases)))
# Bus reads
brcases = []
brcases = {}
for i, reg in enumerate(desc_exp):
if isinstance(reg, RegisterRaw):
brcases.append([i, self.interface.dat_r.eq(reg.w)])
brcases[i] = [self.interface.dat_r.eq(reg.w)]
elif isinstance(reg, RegisterFields):
brs = []
reg_readable = False
@ -62,12 +62,12 @@ class Bank:
else:
brs.append(Replicate(0, field.size))
if reg_readable:
brcases.append([i, self.interface.dat_r.eq(Cat(*brs))])
brcases[i] = [self.interface.dat_r.eq(Cat(*brs))]
else:
raise TypeError
if brcases:
sync.append(self.interface.dat_r.eq(0))
sync.append(If(sel, Case(self.interface.adr[:nbits], *brcases)))
sync.append(If(sel, Case(self.interface.adr[:nbits], brcases)))
else:
comb.append(self.interface.dat_r.eq(0))

View file

@ -33,10 +33,10 @@ class FSM:
self.actions[state] += statements
def get_fragment(self):
cases = [[s] + a for s, a in enumerate(self.actions) if a]
cases = dict((s, a) for s, a in enumerate(self.actions) if a)
comb = [
self._next_state.eq(self._state),
Case(self._state, *cases)
Case(self._state, cases)
]
sync = [self._state.eq(self._next_state)]
return Fragment(comb, sync)

View file

@ -43,15 +43,14 @@ def chooser(signal, shift, output, n=None, reverse=False):
if n is None:
n = 2**len(shift)
w = len(output)
cases = []
cases = {}
for i in range(n):
if reverse:
s = n - i - 1
else:
s = i
cases.append([i, output.eq(signal[s*w:(s+1)*w])])
cases[n-1][0] = Default()
return Case(shift, *cases)
cases[i] = [output.eq(signal[s*w:(s+1)*w])]
return Case(shift, cases).makedefault()
def timeline(trigger, events):
lastevent = max([e[0] for e in events])

View file

@ -14,7 +14,7 @@ class RoundRobin:
def get_fragment(self):
if self.n > 1:
cases = []
cases = {}
for i in range(self.n):
switch = []
for j in reversed(range(i+1,i+self.n)):
@ -30,8 +30,8 @@ class RoundRobin:
case = [If(~self.request[i], *switch)]
else:
case = switch
cases.append([i] + case)
statement = Case(self.grant, *cases)
cases[i] = case
statement = Case(self.grant, cases)
if self.switch_policy == SP_CE:
statement = If(self.ce, statement)
return Fragment(sync=[statement])

View file

@ -189,21 +189,19 @@ def _insert_else(obj, clause):
o = o.f[0]
o.f = clause
class Default:
pass
class Case:
def __init__(self, test, *cases):
def __init__(self, test, cases):
self.test = test
self.cases = [(c[0], list(c[1:])) for c in cases if not isinstance(c[0], Default)]
self.default = None
for c in cases:
if isinstance(c[0], Default):
if self.default is not None:
raise ValueError
self.default = list(c[1:])
if self.default is None:
self.default = []
self.cases = cases
def makedefault(self, key=None):
if key is None:
for choice in self.cases.keys():
if key is None or choice > key:
key = choice
self.cases["default"] = self.cases[key]
del self.cases[key]
return self
# arrays

View file

@ -151,21 +151,19 @@ class _ArrayLowerer(NodeTransformer):
def visit_Assign(self, node):
if isinstance(node.l, _ArrayProxy):
k = self.visit(node.l.key)
cases = []
cases = {}
for n, choice in enumerate(node.l.choices):
assign = self.visit_Assign(_Assign(choice, node.r))
cases.append([n, assign])
cases[-1][0] = Default()
return Case(k, *cases)
cases[n] = [assign]
return Case(k, cases).makedefault()
else:
return super().visit_Assign(node)
def visit_ArrayProxy(self, node):
array_muxed = Signal(value_bv(node))
cases = [[n, _Assign(array_muxed, self.visit(choice))]
for n, choice in enumerate(node.choices)]
cases[-1][0] = Default()
self.comb.append(Case(self.visit(node.key), *cases))
cases = dict((n, _Assign(array_muxed, self.visit(choice)))
for n, choice in enumerate(node.choices))
self.comb.append(Case(self.visit(node.key), cases).makedefault())
return array_muxed
def lower_arrays(f):

View file

@ -92,15 +92,16 @@ def _printnode(ns, at, level, node):
r += "\t"*level + "end\n"
return r
elif isinstance(node, Case):
if node.cases or node.default:
if node.cases:
r = "\t"*level + "case (" + _printexpr(ns, node.test) + ")\n"
for case in node.cases:
r += "\t"*(level + 1) + _printexpr(ns, case[0]) + ": begin\n"
r += _printnode(ns, at, level + 2, case[1])
css = sorted([(k, v) for (k, v) in node.cases.items() if k != "default"], key=itemgetter(0))
for choice, statements in css:
r += "\t"*(level + 1) + _printexpr(ns, choice) + ": begin\n"
r += _printnode(ns, at, level + 2, statements)
r += "\t"*(level + 1) + "end\n"
if node.default:
if "default" in node.cases:
r += "\t"*(level + 1) + "default: begin\n"
r += _printnode(ns, at, level + 2, node.default)
r += _printnode(ns, at, level + 2, node.cases["default"])
r += "\t"*(level + 1) + "end\n"
r += "\t"*level + "endcase\n"
return r

View file

@ -65,9 +65,8 @@ class NodeVisitor:
def visit_Case(self, node):
self.visit(node.test)
for v, statements in node.cases:
for v, statements in node.cases.items():
self.visit(statements)
self.visit(node.default)
def visit_Fragment(self, node):
self.visit(node.comb)
@ -155,9 +154,8 @@ class NodeTransformer:
return r
def visit_Case(self, node):
r = Case(self.visit(node.test))
r.cases = [(v, self.visit(statements)) for v, statements in node.cases]
r.default = self.visit(node.default)
cases = dict((v, self.visit(statements)) for v, statements in node.cases.items())
r = Case(self.visit(node.test), cases)
return r
def visit_Fragment(self, node):

View file

@ -48,6 +48,6 @@ class ImplRegister:
raise FinalizeError
# do nothing when sel == 0
items = sorted(self.source_encoding.items(), key=itemgetter(1))
cases = [(v, self.storage.eq(self.id_to_source[k])) for k, v in items]
sync = [Case(self.sel, *cases)]
cases = dict((v, self.storage.eq(self.id_to_source[k])) for k, v in items)
sync = [Case(self.sel, cases)]
return Fragment(sync=sync)