genlib/fsm: add NextValue to replace reg/reg_next/ce pattern

This commit is contained in:
Sebastien Bourdeauducq 2014-11-25 17:16:21 +08:00
parent 5801e5746b
commit 4542de2c11
2 changed files with 37 additions and 8 deletions

View File

@ -1,18 +1,29 @@
from migen.fhdl.std import * from migen.fhdl.std import *
from migen.fhdl import verilog from migen.fhdl import verilog
from migen.genlib.fsm import FSM, NextState from migen.genlib.fsm import FSM, NextState, NextValue
class Example(Module): class Example(Module):
def __init__(self): def __init__(self):
self.s = Signal() self.s = Signal()
self.counter = Signal(8)
myfsm = FSM() myfsm = FSM()
self.submodules += myfsm self.submodules += myfsm
myfsm.act("FOO", self.s.eq(1), NextState("BAR"))
myfsm.act("BAR", self.s.eq(0), NextState("FOO")) myfsm.act("FOO",
self.s.eq(1),
NextState("BAR")
)
myfsm.act("BAR",
self.s.eq(0),
NextValue(self.counter, self.counter + 1),
NextState("FOO")
)
self.be = myfsm.before_entering("FOO") self.be = myfsm.before_entering("FOO")
self.ae = myfsm.after_entering("FOO") self.ae = myfsm.after_entering("FOO")
self.bl = myfsm.before_leaving("FOO") self.bl = myfsm.before_leaving("FOO")
self.al = myfsm.after_leaving("FOO") self.al = myfsm.after_leaving("FOO")
example = Example() example = Example()
print(verilog.convert(example, {example.s, example.be, example.ae, example.bl, example.al})) print(verilog.convert(example, {example.s, example.counter, example.be, example.ae, example.bl, example.al}))

View File

@ -3,6 +3,7 @@ from collections import OrderedDict
from migen.fhdl.std import * from migen.fhdl.std import *
from migen.fhdl.module import FinalizeError from migen.fhdl.module import FinalizeError
from migen.fhdl.visit import NodeTransformer from migen.fhdl.visit import NodeTransformer
from migen.fhdl.bitcontainer import value_bits_sign
class AnonymousState: class AnonymousState:
pass pass
@ -13,11 +14,18 @@ class NextState:
def __init__(self, state): def __init__(self, state):
self.state = state self.state = state
class _LowerNextState(NodeTransformer): class NextValue:
def __init__(self, register, value):
self.register = register
self.value = value
class _LowerNext(NodeTransformer):
def __init__(self, next_state_signal, encoding, aliases): def __init__(self, next_state_signal, encoding, aliases):
self.next_state_signal = next_state_signal self.next_state_signal = next_state_signal
self.encoding = encoding self.encoding = encoding
self.aliases = aliases self.aliases = aliases
# register -> next_value_ce, next_value
self.registers = OrderedDict()
def visit_unknown(self, node): def visit_unknown(self, node):
if isinstance(node, NextState): if isinstance(node, NextState):
@ -26,6 +34,15 @@ class _LowerNextState(NodeTransformer):
except KeyError: except KeyError:
actual_state = node.state actual_state = node.state
return self.next_state_signal.eq(self.encoding[actual_state]) return self.next_state_signal.eq(self.encoding[actual_state])
elif isinstance(node, NextValue):
try:
next_value_ce, next_value = self.registers[node.register]
except KeyError:
related = node.register if isinstance(node.register, Signal) else None
next_value = Signal(bits_sign=value_bits_sign(node.register), related=related)
next_value_ce = Signal(related=related)
self.registers[node.register] = next_value_ce, next_value
return next_value.eq(node.value), next_value_ce.eq(1)
else: else:
return node return node
@ -97,18 +114,19 @@ class FSM(Module):
def do_finalize(self): def do_finalize(self):
nstates = len(self.actions) nstates = len(self.actions)
self.encoding = dict((s, n) for n, s in enumerate(self.actions.keys())) self.encoding = dict((s, n) for n, s in enumerate(self.actions.keys()))
self.state = Signal(max=nstates, reset=self.encoding[self.reset_state]) self.state = Signal(max=nstates, reset=self.encoding[self.reset_state])
self.next_state = Signal(max=nstates) self.next_state = Signal(max=nstates)
lns = _LowerNextState(self.next_state, self.encoding, self.state_aliases) ln = _LowerNext(self.next_state, self.encoding, self.state_aliases)
cases = dict((self.encoding[k], lns.visit(v)) for k, v in self.actions.items() if v) cases = dict((self.encoding[k], ln.visit(v)) for k, v in self.actions.items() if v)
self.comb += [ self.comb += [
self.next_state.eq(self.state), self.next_state.eq(self.state),
Case(self.state, cases).makedefault(self.encoding[self.reset_state]) Case(self.state, cases).makedefault(self.encoding[self.reset_state])
] ]
self.sync += self.state.eq(self.next_state) self.sync += self.state.eq(self.next_state)
for register, (next_value_ce, next_value) in ln.registers.items():
self.sync += If(next_value_ce, register.eq(next_value))
# drive entering/leaving signals # drive entering/leaving signals
for state, signal in self.before_leaving_signals.items(): for state, signal in self.before_leaving_signals.items():