Refactor Case

This commit is contained in:
Sebastien Bourdeauducq 2012-11-29 01:11:15 +01:00
parent 070652cc39
commit 6eebfce44a
10 changed files with 54 additions and 60 deletions

View file

@ -50,10 +50,10 @@ class Unpack(Actor):
) )
) )
] ]
cases = [(i if i else Default(), cases = {}
Cat(*self.token("source").flatten()).eq(Cat(*self.token("sink").subrecord("chunk{0}".format(i)).flatten()))) for i in range(self.n):
for i in range(self.n)] cases[i] = [Cat(*self.token("source").flatten()).eq(Cat(*self.token("sink").subrecord("chunk{0}".format(i)).flatten()))]
comb.append(Case(mux, *cases)) comb.append(Case(mux, cases).makedefault())
return Fragment(comb, sync) return Fragment(comb, sync)
class Pack(Actor): class Pack(Actor):
@ -69,9 +69,9 @@ class Pack(Actor):
load_part = Signal() load_part = Signal()
strobe_all = Signal() strobe_all = Signal()
cases = [(i, cases = {}
Cat(*self.token("source").subrecord("chunk{0}".format(i)).flatten()).eq(*self.token("sink").flatten())) for i in range(self.n):
for i in range(self.n)] cases[i] = [Cat(*self.token("source").subrecord("chunk{0}".format(i)).flatten()).eq(*self.token("sink").flatten())]
comb = [ comb = [
self.busy.eq(strobe_all), self.busy.eq(strobe_all),
self.endpoints["sink"].ack.eq(~strobe_all | self.endpoints["source"].ack), self.endpoints["sink"].ack.eq(~strobe_all | self.endpoints["source"].ack),
@ -83,7 +83,7 @@ class Pack(Actor):
strobe_all.eq(0) strobe_all.eq(0)
), ),
If(load_part, If(load_part,
Case(demux, *cases), Case(demux, cases),
If(demux == (self.n - 1), If(demux == (self.n - 1),
demux.eq(0), demux.eq(0),
strobe_all.eq(1) strobe_all.eq(1)

View file

@ -21,7 +21,7 @@ class Bank:
nbits = bits_for(len(desc_exp)-1) nbits = bits_for(len(desc_exp)-1)
# Bus writes # Bus writes
bwcases = [] bwcases = {}
for i, reg in enumerate(desc_exp): for i, reg in enumerate(desc_exp):
if isinstance(reg, RegisterRaw): if isinstance(reg, RegisterRaw):
comb.append(reg.r.eq(self.interface.dat_w[:reg.size])) comb.append(reg.r.eq(self.interface.dat_w[:reg.size]))
@ -29,14 +29,14 @@ class Bank:
self.interface.we & \ self.interface.we & \
(self.interface.adr[:nbits] == i))) (self.interface.adr[:nbits] == i)))
elif isinstance(reg, RegisterFields): elif isinstance(reg, RegisterFields):
bwra = [i] bwra = []
offset = 0 offset = 0
for field in reg.fields: for field in reg.fields:
if field.access_bus == WRITE_ONLY or field.access_bus == READ_WRITE: if field.access_bus == WRITE_ONLY or field.access_bus == READ_WRITE:
bwra.append(field.storage.eq(self.interface.dat_w[offset:offset+field.size])) bwra.append(field.storage.eq(self.interface.dat_w[offset:offset+field.size]))
offset += field.size offset += field.size
if len(bwra) > 1: if bwra:
bwcases.append(bwra) bwcases[i] = bwra
# commit atomic writes # commit atomic writes
for field in reg.fields: for field in reg.fields:
if isinstance(field, FieldAlias) and field.commit_list: if isinstance(field, FieldAlias) and field.commit_list:
@ -45,13 +45,13 @@ class Bank:
else: else:
raise TypeError raise TypeError
if bwcases: if bwcases:
sync.append(If(sel & self.interface.we, Case(self.interface.adr[:nbits], *bwcases))) sync.append(If(sel & self.interface.we, Case(self.interface.adr[:nbits], bwcases)))
# Bus reads # Bus reads
brcases = [] brcases = {}
for i, reg in enumerate(desc_exp): for i, reg in enumerate(desc_exp):
if isinstance(reg, RegisterRaw): if isinstance(reg, RegisterRaw):
brcases.append([i, self.interface.dat_r.eq(reg.w)]) brcases[i] = [self.interface.dat_r.eq(reg.w)]
elif isinstance(reg, RegisterFields): elif isinstance(reg, RegisterFields):
brs = [] brs = []
reg_readable = False reg_readable = False
@ -62,12 +62,12 @@ class Bank:
else: else:
brs.append(Replicate(0, field.size)) brs.append(Replicate(0, field.size))
if reg_readable: if reg_readable:
brcases.append([i, self.interface.dat_r.eq(Cat(*brs))]) brcases[i] = [self.interface.dat_r.eq(Cat(*brs))]
else: else:
raise TypeError raise TypeError
if brcases: if brcases:
sync.append(self.interface.dat_r.eq(0)) sync.append(self.interface.dat_r.eq(0))
sync.append(If(sel, Case(self.interface.adr[:nbits], *brcases))) sync.append(If(sel, Case(self.interface.adr[:nbits], brcases)))
else: else:
comb.append(self.interface.dat_r.eq(0)) comb.append(self.interface.dat_r.eq(0))

View file

@ -33,10 +33,10 @@ class FSM:
self.actions[state] += statements self.actions[state] += statements
def get_fragment(self): def get_fragment(self):
cases = [[s] + a for s, a in enumerate(self.actions) if a] cases = dict((s, a) for s, a in enumerate(self.actions) if a)
comb = [ comb = [
self._next_state.eq(self._state), self._next_state.eq(self._state),
Case(self._state, *cases) Case(self._state, cases)
] ]
sync = [self._state.eq(self._next_state)] sync = [self._state.eq(self._next_state)]
return Fragment(comb, sync) return Fragment(comb, sync)

View file

@ -43,15 +43,14 @@ def chooser(signal, shift, output, n=None, reverse=False):
if n is None: if n is None:
n = 2**len(shift) n = 2**len(shift)
w = len(output) w = len(output)
cases = [] cases = {}
for i in range(n): for i in range(n):
if reverse: if reverse:
s = n - i - 1 s = n - i - 1
else: else:
s = i s = i
cases.append([i, output.eq(signal[s*w:(s+1)*w])]) cases[i] = [output.eq(signal[s*w:(s+1)*w])]
cases[n-1][0] = Default() return Case(shift, cases).makedefault()
return Case(shift, *cases)
def timeline(trigger, events): def timeline(trigger, events):
lastevent = max([e[0] for e in events]) lastevent = max([e[0] for e in events])

View file

@ -14,7 +14,7 @@ class RoundRobin:
def get_fragment(self): def get_fragment(self):
if self.n > 1: if self.n > 1:
cases = [] cases = {}
for i in range(self.n): for i in range(self.n):
switch = [] switch = []
for j in reversed(range(i+1,i+self.n)): for j in reversed(range(i+1,i+self.n)):
@ -30,8 +30,8 @@ class RoundRobin:
case = [If(~self.request[i], *switch)] case = [If(~self.request[i], *switch)]
else: else:
case = switch case = switch
cases.append([i] + case) cases[i] = case
statement = Case(self.grant, *cases) statement = Case(self.grant, cases)
if self.switch_policy == SP_CE: if self.switch_policy == SP_CE:
statement = If(self.ce, statement) statement = If(self.ce, statement)
return Fragment(sync=[statement]) return Fragment(sync=[statement])

View file

@ -189,21 +189,19 @@ def _insert_else(obj, clause):
o = o.f[0] o = o.f[0]
o.f = clause o.f = clause
class Default:
pass
class Case: class Case:
def __init__(self, test, *cases): def __init__(self, test, cases):
self.test = test self.test = test
self.cases = [(c[0], list(c[1:])) for c in cases if not isinstance(c[0], Default)] self.cases = cases
self.default = None
for c in cases: def makedefault(self, key=None):
if isinstance(c[0], Default): if key is None:
if self.default is not None: for choice in self.cases.keys():
raise ValueError if key is None or choice > key:
self.default = list(c[1:]) key = choice
if self.default is None: self.cases["default"] = self.cases[key]
self.default = [] del self.cases[key]
return self
# arrays # arrays

View file

@ -151,21 +151,19 @@ class _ArrayLowerer(NodeTransformer):
def visit_Assign(self, node): def visit_Assign(self, node):
if isinstance(node.l, _ArrayProxy): if isinstance(node.l, _ArrayProxy):
k = self.visit(node.l.key) k = self.visit(node.l.key)
cases = [] cases = {}
for n, choice in enumerate(node.l.choices): for n, choice in enumerate(node.l.choices):
assign = self.visit_Assign(_Assign(choice, node.r)) assign = self.visit_Assign(_Assign(choice, node.r))
cases.append([n, assign]) cases[n] = [assign]
cases[-1][0] = Default() return Case(k, cases).makedefault()
return Case(k, *cases)
else: else:
return super().visit_Assign(node) return super().visit_Assign(node)
def visit_ArrayProxy(self, node): def visit_ArrayProxy(self, node):
array_muxed = Signal(value_bv(node)) array_muxed = Signal(value_bv(node))
cases = [[n, _Assign(array_muxed, self.visit(choice))] cases = dict((n, _Assign(array_muxed, self.visit(choice)))
for n, choice in enumerate(node.choices)] for n, choice in enumerate(node.choices))
cases[-1][0] = Default() self.comb.append(Case(self.visit(node.key), cases).makedefault())
self.comb.append(Case(self.visit(node.key), *cases))
return array_muxed return array_muxed
def lower_arrays(f): def lower_arrays(f):

View file

@ -92,15 +92,16 @@ def _printnode(ns, at, level, node):
r += "\t"*level + "end\n" r += "\t"*level + "end\n"
return r return r
elif isinstance(node, Case): elif isinstance(node, Case):
if node.cases or node.default: if node.cases:
r = "\t"*level + "case (" + _printexpr(ns, node.test) + ")\n" r = "\t"*level + "case (" + _printexpr(ns, node.test) + ")\n"
for case in node.cases: css = sorted([(k, v) for (k, v) in node.cases.items() if k != "default"], key=itemgetter(0))
r += "\t"*(level + 1) + _printexpr(ns, case[0]) + ": begin\n" for choice, statements in css:
r += _printnode(ns, at, level + 2, case[1]) r += "\t"*(level + 1) + _printexpr(ns, choice) + ": begin\n"
r += _printnode(ns, at, level + 2, statements)
r += "\t"*(level + 1) + "end\n" r += "\t"*(level + 1) + "end\n"
if node.default: if "default" in node.cases:
r += "\t"*(level + 1) + "default: begin\n" r += "\t"*(level + 1) + "default: begin\n"
r += _printnode(ns, at, level + 2, node.default) r += _printnode(ns, at, level + 2, node.cases["default"])
r += "\t"*(level + 1) + "end\n" r += "\t"*(level + 1) + "end\n"
r += "\t"*level + "endcase\n" r += "\t"*level + "endcase\n"
return r return r

View file

@ -65,9 +65,8 @@ class NodeVisitor:
def visit_Case(self, node): def visit_Case(self, node):
self.visit(node.test) self.visit(node.test)
for v, statements in node.cases: for v, statements in node.cases.items():
self.visit(statements) self.visit(statements)
self.visit(node.default)
def visit_Fragment(self, node): def visit_Fragment(self, node):
self.visit(node.comb) self.visit(node.comb)
@ -155,9 +154,8 @@ class NodeTransformer:
return r return r
def visit_Case(self, node): def visit_Case(self, node):
r = Case(self.visit(node.test)) cases = dict((v, self.visit(statements)) for v, statements in node.cases.items())
r.cases = [(v, self.visit(statements)) for v, statements in node.cases] r = Case(self.visit(node.test), cases)
r.default = self.visit(node.default)
return r return r
def visit_Fragment(self, node): def visit_Fragment(self, node):

View file

@ -48,6 +48,6 @@ class ImplRegister:
raise FinalizeError raise FinalizeError
# do nothing when sel == 0 # do nothing when sel == 0
items = sorted(self.source_encoding.items(), key=itemgetter(1)) items = sorted(self.source_encoding.items(), key=itemgetter(1))
cases = [(v, self.storage.eq(self.id_to_source[k])) for k, v in items] cases = dict((v, self.storage.eq(self.id_to_source[k])) for k, v in items)
sync = [Case(self.sel, *cases)] sync = [Case(self.sel, cases)]
return Fragment(sync=sync) return Fragment(sync=sync)