fsm: support complex targets in NextValue. Closes #27.

This commit is contained in:
Sebastien Bourdeauducq 2015-09-22 16:55:24 +08:00
parent 1857ec6c32
commit 31ffa8c18f
2 changed files with 40 additions and 9 deletions

View File

@ -5,6 +5,7 @@ class Example(Module):
def __init__(self): def __init__(self):
self.s = Signal() self.s = Signal()
self.counter = Signal(8) self.counter = Signal(8)
x = Array(Signal(name="a") for i in range(7))
myfsm = FSM() myfsm = FSM()
self.submodules += myfsm self.submodules += myfsm
@ -16,6 +17,7 @@ class Example(Module):
myfsm.act("BAR", myfsm.act("BAR",
self.s.eq(0), self.s.eq(0),
NextValue(self.counter, self.counter + 1), NextValue(self.counter, self.counter + 1),
NextValue(x[self.counter], 89),
NextState("FOO") NextState("FOO")
) )

View File

@ -1,6 +1,7 @@
from collections import OrderedDict from collections import OrderedDict
from migen.fhdl.structure import * from migen.fhdl.structure import *
from migen.fhdl.structure import _Slice, _ArrayProxy
from migen.fhdl.module import Module, FinalizeError from migen.fhdl.module import Module, FinalizeError
from migen.fhdl.visit import NodeTransformer from migen.fhdl.visit import NodeTransformer
from migen.fhdl.bitcontainer import value_bits_sign from migen.fhdl.bitcontainer import value_bits_sign
@ -21,18 +22,46 @@ class NextState:
class NextValue: class NextValue:
def __init__(self, register, value): def __init__(self, target, value):
self.register = register self.target = target
self.value = value self.value = value
def _target_eq(a, b):
if type(a) != type(b):
return False
ty = type(a)
if ty == Constant:
return a.value == b.value
elif ty == Signal:
return a is b
elif ty == Cat:
return all(_target_eq(x, y) for x, y in zip(a.l, b.l))
elif ty == _Slice:
return (_target_eq(a.value, b.value)
and a.start == b.start
and a.end == b.end)
elif ty == _ArrayProxy:
return (all(_target_eq(x, y) for x, y in zip(a.choices, b.choices))
and _target_eq(a.key, b.key))
else:
raise ValueError("NextValue cannot be used with target type '{}'"
.format(ty))
class _LowerNext(NodeTransformer): class _LowerNext(NodeTransformer):
def __init__(self, next_state_signal, encoding, aliases): def __init__(self, next_state_signal, encoding, aliases):
self.next_state_signal = next_state_signal self.next_state_signal = next_state_signal
self.encoding = encoding self.encoding = encoding
self.aliases = aliases self.aliases = aliases
# register -> next_value_ce, next_value # (target, next_value_ce, next_value)
self.registers = OrderedDict() self.registers = []
def _get_register_control(self, target):
for x in self.registers:
if _target_eq(target, x[0]):
return x
raise KeyError
def visit_unknown(self, node): def visit_unknown(self, node):
if isinstance(node, NextState): if isinstance(node, NextState):
@ -43,12 +72,12 @@ class _LowerNext(NodeTransformer):
return self.next_state_signal.eq(self.encoding[actual_state]) return self.next_state_signal.eq(self.encoding[actual_state])
elif isinstance(node, NextValue): elif isinstance(node, NextValue):
try: try:
next_value_ce, next_value = self.registers[node.register] next_value_ce, next_value = self._get_register_control(node.target)
except KeyError: except KeyError:
related = node.register if isinstance(node.register, Signal) else None related = node.target if isinstance(node.target, Signal) else None
next_value = Signal(bits_sign=value_bits_sign(node.register), related=related) next_value = Signal(bits_sign=value_bits_sign(node.target), related=related)
next_value_ce = Signal(related=related) next_value_ce = Signal(related=related)
self.registers[node.register] = next_value_ce, next_value self.registers.append((node.target, next_value_ce, next_value))
return next_value.eq(node.value), next_value_ce.eq(1) return next_value.eq(node.value), next_value_ce.eq(1)
else: else:
return node return node
@ -133,7 +162,7 @@ class FSM(Module):
Case(self.state, cases).makedefault(self.encoding[self.reset_state]) Case(self.state, cases).makedefault(self.encoding[self.reset_state])
] ]
self.sync += self.state.eq(self.next_state) self.sync += self.state.eq(self.next_state)
for register, (next_value_ce, next_value) in ln.registers.items(): for register, next_value_ce, next_value in ln.registers:
self.sync += If(next_value_ce, register.eq(next_value)) self.sync += If(next_value_ce, register.eq(next_value))
# drive entering/leaving signals # drive entering/leaving signals