Remove Constant

This commit is contained in:
Sebastien Bourdeauducq 2012-11-28 23:18:43 +01:00
parent 2a3ef28041
commit fee22a4631
16 changed files with 83 additions and 118 deletions

View file

@ -50,7 +50,7 @@ class Unpack(Actor):
)
)
]
cases = [(Constant(i, BV(muxbits)) if i else Default(),
cases = [(i if i else Default(),
Cat(*self.token("source").flatten()).eq(Cat(*self.token("sink").subrecord("chunk{0}".format(i)).flatten())))
for i in range(self.n)]
comb.append(Case(mux, *cases))
@ -69,7 +69,7 @@ class Pack(Actor):
load_part = Signal()
strobe_all = Signal()
cases = [(Constant(i, BV(demuxbits)),
cases = [(i,
Cat(*self.token("source").subrecord("chunk{0}".format(i)).flatten()).eq(*self.token("sink").flatten()))
for i in range(self.n)]
comb = [

View file

@ -15,7 +15,7 @@ class Bank:
sync = []
sel = Signal()
comb.append(sel.eq(self.interface.adr[9:] == Constant(self.address, BV(5))))
comb.append(sel.eq(self.interface.adr[9:] == self.address))
desc_exp = expand_description(self.description, csr.data_width)
nbits = bits_for(len(desc_exp)-1)
@ -27,9 +27,9 @@ class Bank:
comb.append(reg.r.eq(self.interface.dat_w[:reg.size]))
comb.append(reg.re.eq(sel & \
self.interface.we & \
(self.interface.adr[:nbits] == Constant(i, BV(nbits)))))
(self.interface.adr[:nbits] == i)))
elif isinstance(reg, RegisterFields):
bwra = [Constant(i, BV(nbits))]
bwra = [i]
offset = 0
for field in reg.fields:
if field.access_bus == WRITE_ONLY or field.access_bus == READ_WRITE:
@ -51,7 +51,7 @@ class Bank:
brcases = []
for i, reg in enumerate(desc_exp):
if isinstance(reg, RegisterRaw):
brcases.append([Constant(i, BV(nbits)), self.interface.dat_r.eq(reg.w)])
brcases.append([i, self.interface.dat_r.eq(reg.w)])
elif isinstance(reg, RegisterFields):
brs = []
reg_readable = False
@ -60,9 +60,9 @@ class Bank:
brs.append(field.storage)
reg_readable = True
else:
brs.append(Constant(0, BV(field.size)))
brs.append(Replicate(0, field.size))
if reg_readable:
brcases.append([Constant(i, BV(nbits)), self.interface.dat_r.eq(Cat(*brs))])
brcases.append([i, self.interface.dat_r.eq(Cat(*brs))])
else:
raise TypeError
if brcases:

View file

@ -83,7 +83,7 @@ class Port:
if not self.finalized:
raise FinalizeError
return self.call \
& (self.tag_call == Constant(self.base + slotn, BV(self.tagbits)))
& (self.tag_call == (self.base + slotn))
def get_fragment(self):
if not self.finalized:

View file

@ -25,10 +25,10 @@ class Interface:
self.pdesc = phase_description(a, ba, d)
self.phases = [SimpleInterface(self.pdesc) for i in range(nphases)]
for p in self.phases:
p.cas_n.reset = Constant(1)
p.cs_n.reset = Constant(1)
p.ras_n.reset = Constant(1)
p.we_n.reset = Constant(1)
p.cas_n.reset = 1
p.cs_n.reset = 1
p.ras_n.reset = 1
p.we_n.reset = 1
# Returns pairs (DFI-mandated signal name, Migen signal object)
def get_standard_names(self, m2s=True, s2m=True):

View file

@ -47,7 +47,7 @@ class Arbiter:
for i, m in enumerate(self.masters):
dest = getattr(m, name)
if name == "ack" or name == "err":
comb.append(dest.eq(source & (self.rr.grant == Constant(i, self.rr.grant.bv))))
comb.append(dest.eq(source & (self.rr.grant == i)))
else:
comb.append(dest.eq(source))
@ -59,27 +59,15 @@ class Arbiter:
class Decoder:
# slaves is a list of pairs:
# 0) structure.Constant defining address (always decoded on the upper bits)
# Slaves can have differing numbers of address bits, but addresses
# must not conflict.
# 1) wishbone.Slave reference
# Addresses are decoded from bit 31-offset and downwards.
# 0) function that takes the address signal and returns a FHDL expression
# that evaluates to 1 when the slave is selected and 0 otherwise.
# 1) wishbone.Slave reference.
# register adds flip-flops after the address comparators. Improves timing,
# but breaks Wishbone combinatorial feedback.
def __init__(self, master, slaves, offset=0, register=False):
def __init__(self, master, slaves, register=False):
self.master = master
self.slaves = slaves
self.offset = offset
self.register = register
addresses = [slave[0] for slave in self.slaves]
maxbits = max([bits_for(addr) for addr in addresses])
def mkconst(x):
if isinstance(x, int):
return Constant(x, BV(maxbits))
else:
return x
self.addresses = list(map(mkconst, addresses))
def get_fragment(self):
comb = []
@ -90,9 +78,8 @@ class Decoder:
slave_sel_r = Signal(BV(ns))
# decode slave addresses
hi = len(self.master.adr) - self.offset
comb += [slave_sel[i].eq(self.master.adr[hi-len(addr):hi] == addr)
for i, addr in enumerate(self.addresses)]
comb += [slave_sel[i].eq(fun(self.master.adr))
for i, (fun, bus) in enumerate(self.slaves)]
if self.register:
sync.append(slave_sel_r.eq(slave_sel))
else:
@ -120,11 +107,10 @@ class Decoder:
return Fragment(comb, sync)
class InterconnectShared:
def __init__(self, masters, slaves, offset=0, register=False):
def __init__(self, masters, slaves, register=False):
self._shared = Interface()
self._arbiter = Arbiter(masters, self._shared)
self._decoder = Decoder(self._shared, slaves, offset, register)
self.addresses = self._decoder.addresses
self._decoder = Decoder(self._shared, slaves, register)
def get_fragment(self):
return self._arbiter.get_fragment() + self._decoder.get_fragment()

View file

@ -22,7 +22,7 @@ class Divider:
comb = [
self.quotient_o.eq(qr[:w]),
self.remainder_o.eq(qr[w:]),
self.ready_o.eq(counter == Constant(0, counter.bv)),
self.ready_o.eq(counter == 0),
diff.eq(self.remainder_o - divisor_r)
]
sync = [
@ -36,7 +36,7 @@ class Divider:
).Else(
qr.eq(Cat(1, qr[:w-1], diff[:w]))
),
counter.eq(counter - Constant(1, counter.bv))
counter.eq(counter - 1)
)
]
return Fragment(comb, sync)

View file

@ -8,16 +8,16 @@ class FSM:
self._state = Signal(self._state_bv)
self._next_state = Signal(self._state_bv)
for n, state in enumerate(states):
setattr(self, state, Constant(n, self._state_bv))
setattr(self, state, n)
self.actions = [[] for i in range(len(states))]
for name, target, delay in delayed_enters:
target_state = getattr(self, target)
if delay:
name_state = len(self.actions)
setattr(self, name, Constant(name_state, self._state_bv))
setattr(self, name, name_state)
for i in range(delay-1):
self.actions.append([self.next_state(Constant(name_state+i+1, self._state_bv))])
self.actions.append([self.next_state(name_state+i+1)])
self.actions.append([self.next_state(target_state)])
else:
# alias
@ -30,11 +30,10 @@ class FSM:
return self._next_state.eq(state)
def act(self, state, *statements):
self.actions[state.n] += statements
self.actions[state] += statements
def get_fragment(self):
cases = [[Constant(s, self._state_bv)] + a
for s, a in enumerate(self.actions) if a]
cases = [[s] + a for s, a in enumerate(self.actions) if a]
comb = [
self._next_state.eq(self._state),
Case(self._state, *cases)

View file

@ -49,7 +49,7 @@ def chooser(signal, shift, output, n=None, reverse=False):
s = n - i - 1
else:
s = i
cases.append([Constant(i, shift.bv), output.eq(signal[s*w:(s+1)*w])])
cases.append([i, output.eq(signal[s*w:(s+1)*w])])
cases[n-1][0] = Default()
return Case(shift, *cases)
@ -57,25 +57,25 @@ def timeline(trigger, events):
lastevent = max([e[0] for e in events])
counter = Signal(BV(bits_for(lastevent)))
counterlogic = If(counter != Constant(0, counter.bv),
counter.eq(counter + Constant(1, counter.bv))
counterlogic = If(counter != 0,
counter.eq(counter + 1)
).Elif(trigger,
counter.eq(Constant(1, counter.bv))
counter.eq(1)
)
# insert counter reset if it doesn't naturally overflow
# (test if lastevent+1 is a power of 2)
if (lastevent & (lastevent + 1)) != 0:
counterlogic = If(counter == lastevent,
counter.eq(Constant(0, counter.bv))
counter.eq(0)
).Else(
counterlogic
)
def get_cond(e):
if e[0] == 0:
return trigger & (counter == Constant(0, counter.bv))
return trigger & (counter == 0)
else:
return counter == Constant(e[0], counter.bv)
return counter == e[0]
sync = [If(get_cond(e), *e[1]) for e in events]
sync.append(counterlogic)
return sync

View file

@ -1,4 +1,5 @@
from migen.fhdl.structure import *
from migen.fhdl.tools import value_bv
class Record:
def __init__(self, layout, name=""):
@ -76,7 +77,7 @@ class Record:
if align:
pad_size = alignment - (offset % alignment)
if pad_size < alignment:
l.append(Constant(0, BV(pad_size)))
l.append(Replicate(0, pad_size))
offset += pad_size
e = self.__dict__[key]
@ -87,7 +88,7 @@ class Record:
else:
raise TypeError
for x in added:
offset += len(x)
offset += value_bv(x).width
l += added
if return_offset:
return (l, offset)

View file

@ -21,7 +21,7 @@ class RoundRobin:
t = j % self.n
switch = [
If(self.request[t],
self.grant.eq(Constant(t, BV(self.bn)))
self.grant.eq(t)
).Else(
*switch
)
@ -30,7 +30,7 @@ class RoundRobin:
case = [If(~self.request[i], *switch)]
else:
case = switch
cases.append([Constant(i, BV(self.bn))] + case)
cases.append([i] + case)
statement = Case(self.grant, *cases)
if self.switch_policy == SP_CE:
statement = If(self.ce, statement)

View file

@ -15,8 +15,6 @@ def log2_int(n, need_pow2=True):
return r
def bits_for(n, require_sign_bit=False):
if isinstance(n, Constant):
return len(n)
if n > 0:
r = log2_int(n + 1, False)
else:
@ -126,7 +124,7 @@ class _Operator(Value):
def __init__(self, op, operands):
super().__init__()
self.op = op
self.operands = list(map(_cst, operands))
self.operands = operands
class _Slice(Value):
def __init__(self, value, start, stop):
@ -138,49 +136,21 @@ class _Slice(Value):
class Cat(Value):
def __init__(self, *args):
super().__init__()
self.l = list(map(_cst, args))
self.l = args
class Replicate(Value):
def __init__(self, v, n):
super().__init__()
self.v = _cst(v)
self.v = v
self.n = n
class Constant(Value):
def __init__(self, n, bv=None):
super().__init__()
self.bv = bv or BV(bits_for(n), n < 0)
self.n = n
def __len__(self):
return self.bv.width
def __repr__(self):
return str(self.bv) + str(self.n)
def __eq__(self, other):
return self.bv == other.bv and self.n == other.n
def __hash__(self):
return super().__hash__()
def binc(x, signed=False):
return Constant(int(x, 2), BV(len(x), signed))
def _cst(x):
if isinstance(x, int):
return Constant(x)
else:
return x
class Signal(Value):
def __init__(self, bv=BV(), name=None, variable=False, reset=0, name_override=None):
super().__init__()
assert(isinstance(bv, BV))
self.bv = bv
self.variable = variable
self.reset = Constant(reset, bv)
self.reset = reset
self.name_override = name_override
self.backtrace = tracer.trace_back(name)
@ -195,7 +165,7 @@ class Signal(Value):
class _Assign:
def __init__(self, l, r):
self.l = l
self.r = _cst(r)
self.r = r
class If:
def __init__(self, cond, *t):
@ -274,8 +244,6 @@ class Instance(HUID):
self.name = name
if isinstance(expr, BV):
self.expr = Signal(expr, name)
elif isinstance(expr, int):
self.expr = Constant(expr)
else:
self.expr = expr
class Input(_IO):

View file

@ -105,8 +105,10 @@ def insert_reset(rst, sl):
return If(rst, *resetcode).Else(*sl)
def value_bv(v):
if isinstance(v, Constant):
return v.bv
if isinstance(v, bool):
return BV(1, False)
elif isinstance(v, int):
return BV(bits_for(v), v < 0)
elif isinstance(v, Signal):
return v.bv
elif isinstance(v, _Operator):
@ -152,7 +154,7 @@ class _ArrayLowerer(NodeTransformer):
cases = []
for n, choice in enumerate(node.l.choices):
assign = self.visit_Assign(_Assign(choice, node.r))
cases.append([Constant(n), assign])
cases.append([n, assign])
cases[-1][0] = Default()
return Case(k, *cases)
else:
@ -160,7 +162,7 @@ class _ArrayLowerer(NodeTransformer):
def visit_ArrayProxy(self, node):
array_muxed = Signal(value_bv(node))
cases = [[Constant(n), _Assign(array_muxed, self.visit(choice))]
cases = [[n, _Assign(array_muxed, self.visit(choice))]
for n, choice in enumerate(node.choices)]
cases[-1][0] = Default()
self.comb.append(Case(self.visit(node.key), *cases))

View file

@ -17,12 +17,23 @@ def _printsig(ns, s):
n += ns.get_name(s)
return n
def _printexpr(ns, node):
if isinstance(node, Constant):
if node.n >= 0:
return str(node.bv) + str(node.n)
def _printintbool(node):
if isinstance(node, bool):
if node:
return "1'd1"
else:
return "-" + str(node.bv) + str(-node.n)
return "1'd0"
elif isinstance(node, int):
if node >= 0:
return str(bits_for(node)) + "'d" + str(node)
else:
return "-" + str(bits_for(node)) + "'sd" + str(-node)
else:
raise TypeError
def _printexpr(ns, node):
if isinstance(node, (int, bool)):
return _printintbool(node)
elif isinstance(node, Signal):
return ns.get_name(node)
elif isinstance(node, _Operator):
@ -146,7 +157,7 @@ def _printcomb(f, ns, display_run):
dummy_s = Signal(name_override="dummy_s")
r += syn_off
r += "reg " + _printsig(ns, dummy_s) + ";\n"
r += "initial " + ns.get_name(dummy_s) + " <= 1'b0;\n"
r += "initial " + ns.get_name(dummy_s) + " <= 1'd0;\n"
r += syn_on
groups = group_by_targets(f.comb)
@ -164,7 +175,7 @@ def _printcomb(f, ns, display_run):
if display_run:
r += "\t$display(\"Running comb block #" + str(n) + "\");\n"
for t in g[0]:
r += "\t" + ns.get_name(t) + " <= " + str(t.reset) + ";\n"
r += "\t" + ns.get_name(t) + " <= " + _printexpr(ns, t.reset) + ";\n"
r += _printnode(ns, _AT_NONBLOCKING, 1, g[1])
r += syn_off
r += "\t" + ns.get_name(dummy_d) + " <= " + ns.get_name(dummy_s) + ";\n"
@ -194,7 +205,9 @@ def _printinstances(f, ns, clock_domains):
r += ",\n"
firstp = False
r += "\t." + p.name + "("
if isinstance(p.value, int) or isinstance(p.value, float) or isinstance(p.value, Constant):
if isinstance(p.value, (int, bool)):
r += _printintbool(p.value)
elif isinstance(p.value, float):
r += str(p.value)
elif isinstance(p.value, str):
r += "\"" + p.value + "\""

View file

@ -5,8 +5,8 @@ from migen.fhdl.structure import _Operator, _Slice, _Assign, _ArrayProxy
class NodeVisitor:
def visit(self, node):
if isinstance(node, Constant):
self.visit_Constant(node)
if isinstance(node, (int, bool)):
self.visit_constant(node)
elif isinstance(node, Signal):
self.visit_Signal(node)
elif isinstance(node, _Operator):
@ -34,7 +34,7 @@ class NodeVisitor:
elif node is not None:
self.visit_unknown(node)
def visit_Constant(self, node):
def visit_constant(self, node):
pass
def visit_Signal(self, node):
@ -90,15 +90,14 @@ class NodeVisitor:
pass
# Default methods always copy the node, except for:
# - Constants
# - Signals
# - Unknown objects
# - All fragment fields except comb and sync
# In those cases, the original node is returned unchanged.
class NodeTransformer:
def visit(self, node):
if isinstance(node, Constant):
return self.visit_Constant(node)
if isinstance(node, (int, bool)):
return self.visit_constant(node)
elif isinstance(node, Signal):
return self.visit_Signal(node)
elif isinstance(node, _Operator):
@ -128,7 +127,7 @@ class NodeTransformer:
else:
return None
def visit_Constant(self, node):
def visit_constant(self, node):
return node
def visit_Signal(self, node):

View file

@ -95,18 +95,16 @@ class ExprCompiler:
def visit_expr_name(self, node):
if node.id == "True":
return Constant(1)
return 1
if node.id == "False":
return Constant(0)
return 0
r = self.symdict[node.id]
if isinstance(r, ImplRegister):
r = r.storage
if isinstance(r, int):
r = Constant(r)
return r
def visit_expr_num(self, node):
return Constant(node.n)
return node.n
def visit_expr_attribute(self, node):
raise NotImplementedError

View file

@ -46,7 +46,6 @@ class ImplRegister:
raise FinalizeError
# do nothing when sel == 0
items = sorted(self.source_encoding.items(), key=itemgetter(1))
cases = [(Constant(v, self.sel.bv),
self.storage.eq(k)) for k, v in items]
cases = [(v, self.storage.eq(k)) for k, v in items]
sync = [Case(self.sel, *cases)]
return Fragment(sync=sync)