from collections import OrderedDict

from migen.fhdl.std import *
from migen.fhdl.module import FinalizeError
from migen.fhdl.visit import NodeTransformer
from migen.fhdl.bitcontainer import value_bits_sign

class AnonymousState:
	"""Marker object for a state that has no user-visible name.

	FSM.delayed_enter() instantiates this class for the intermediate
	states of a delay chain. Each instance is a distinct state.
	"""

# Do not use namedtuple for NextState/NextValue: it inherits from tuple,
# and tuple is used elsewhere in FHDL.
class NextState:
	"""Placeholder statement requesting a transition to another state.

	state: the target state (or an alias of it). The placeholder is
	replaced by an assignment to the next-state signal when the
	statements containing it are lowered by _LowerNext.
	"""
	def __init__(self, state):
		# Kept as a plain attribute; resolution to an encoding happens later.
		self.state = state

class NextValue:
	"""Placeholder statement requesting that a register take a new value.

	register: the object to be updated.
	value: the value it should receive.

	The placeholder is replaced by assignments to helper signals when
	the statements containing it are lowered by _LowerNext.
	"""
	def __init__(self, register, value):
		self.register, self.value = register, value

class _LowerNext(NodeTransformer):
	"""Node transformer that lowers NextState and NextValue placeholders.

	A NextState node becomes an assignment of the encoded target state
	to next_state_signal. A NextValue node becomes a pair of assignments
	to helper signals (the new value and an enable flag); the helpers
	are recorded in self.registers so the caller can hook them up.
	Every other unknown node is returned unchanged.
	"""
	def __init__(self, next_state_signal, encoding, aliases):
		# register -> (next_value_ce, next_value), in order of first use
		self.registers = OrderedDict()
		# alias state -> actual state
		self.aliases = aliases
		# state -> encoded value
		self.encoding = encoding
		self.next_state_signal = next_state_signal

	def visit_unknown(self, node):
		if isinstance(node, NextState):
			# Resolve an aliased state before looking up its encoding.
			try:
				target = self.aliases[node.state]
			except KeyError:
				target = node.state
			return self.next_state_signal.eq(self.encoding[target])

		if not isinstance(node, NextValue):
			return node

		# NOTE(review): the helper-signal variable names and their creation
		# order are deliberately left untouched; Migen may derive generated
		# signal names from them - confirm before renaming or moving.
		try:
			next_value_ce, next_value = self.registers[node.register]
		except KeyError:
			# First NextValue for this register: allocate its helper signals.
			related = node.register if isinstance(node.register, Signal) else None
			next_value = Signal(bits_sign=value_bits_sign(node.register), related=related)
			next_value_ce = Signal(related=related)
			self.registers[node.register] = next_value_ce, next_value
		return next_value.eq(node.value), next_value_ce.eq(1)

class FSM(Module):
	"""Finite state machine builder.

	States are arbitrary objects used as dictionary keys. Statements are
	attached to states with act(); NextState and NextValue placeholders
	found in those statements are lowered to real assignments in
	do_finalize().

	reset_state: state whose encoding becomes the reset value of the
	state register. When left as None, the first state passed to act()
	(directly, or through ongoing()/delayed_enter()) is used.
	"""
	def __init__(self, reset_state=None):
		# state -> list of statements; insertion order defines the encoding
		self.actions = OrderedDict()
		# alias name -> target state (filled by zero-delay delayed_enter)
		self.state_aliases = dict()
		self.reset_state = reset_state

		# state -> Signal, one dictionary per kind of entering/leaving strobe
		self.before_entering_signals = OrderedDict()
		self.before_leaving_signals = OrderedDict()
		self.after_entering_signals = OrderedDict()
		self.after_leaving_signals = OrderedDict()

	def act(self, state, *statements):
		"""Append statements to the list associated with state.

		The state is created if it does not exist yet. If no reset state
		has been chosen so far, state becomes the reset state.

		Raises FinalizeError if the module has already been finalized.
		"""
		if self.finalized:
			raise FinalizeError
		if self.reset_state is None:
			self.reset_state = state
		if state not in self.actions:
			self.actions[state] = []
		self.actions[state] += statements

	def delayed_enter(self, name, target, delay):
		"""Make state name lead to target after delay transitions.

		For a non-zero delay, a chain of states is created: name, then
		delay - 1 AnonymousState instances, each one unconditionally
		requesting the next, the last one requesting target.
		For a zero delay, name is recorded as an alias of target.

		Raises FinalizeError if the module has already been finalized.
		"""
		if self.finalized:
			raise FinalizeError
		if delay:
			state = name
			for i in range(delay):
				if i == delay - 1:
					next_state = target
				else:
					next_state = AnonymousState()
				self.act(state, NextState(next_state))
				state = next_state
		else:
			self.state_aliases[name] = target

	def ongoing(self, state):
		"""Return a new Signal that is assigned 1 by a statement added to state."""
		is_ongoing = Signal()
		self.act(state, is_ongoing.eq(1))
		return is_ongoing

	def _get_signal(self, d, state):
		"""Return the Signal stored for state in d, creating it if needed.

		The state is also registered in self.actions (with no statements)
		so that it receives an encoding in do_finalize().
		"""
		if state not in self.actions:
			self.actions[state] = []
		try:
			return d[state]
		except KeyError:
			is_el = Signal()
			d[state] = is_el
			return is_el

	def before_entering(self, state):
		"""Return the Signal that do_finalize() drives with
		"current state is not state and next state is state"."""
		return self._get_signal(self.before_entering_signals, state)

	def before_leaving(self, state):
		"""Return the Signal that do_finalize() drives with
		"current state is state and next state is not state"."""
		return self._get_signal(self.before_leaving_signals, state)

	def after_entering(self, state):
		"""Return the Signal that follows before_entering(state) through
		a synchronous assignment.

		For the reset state, do_finalize() sets the reset value of this
		signal to 1.
		"""
		# Reuse the existing signal so that repeated calls do not append a
		# duplicate synchronous assignment for the same state.
		try:
			return self.after_entering_signals[state]
		except KeyError:
			pass
		signal = self._get_signal(self.after_entering_signals, state)
		self.sync += signal.eq(self.before_entering(state))
		return signal

	def after_leaving(self, state):
		"""Return the Signal that follows before_leaving(state) through
		a synchronous assignment."""
		# Reuse the existing signal so that repeated calls do not append a
		# duplicate synchronous assignment for the same state.
		try:
			return self.after_leaving_signals[state]
		except KeyError:
			pass
		signal = self._get_signal(self.after_leaving_signals, state)
		self.sync += signal.eq(self.before_leaving(state))
		return signal

	def do_finalize(self):
		"""Encode the states and emit the state machine logic.

		Creates self.encoding, self.state and self.next_state, lowers the
		NextState/NextValue placeholders of every state, and drives the
		entering/leaving signals that were requested.
		"""
		# NOTE(review): statements below are kept in their original form and
		# order; generated signal names may depend on it - confirm before
		# restructuring.
		nstates = len(self.actions)
		# States are numbered in the order they were first registered.
		self.encoding = dict((s, n) for n, s in enumerate(self.actions.keys()))
		self.state = Signal(max=nstates, reset=self.encoding[self.reset_state])
		self.next_state = Signal(max=nstates)

		ln = _LowerNext(self.next_state, self.encoding, self.state_aliases)
		# States without statements get no case of their own.
		cases = dict((self.encoding[k], ln.visit(v)) for k, v in self.actions.items() if v)
		self.comb += [
			# Stay in the current state unless a case requests otherwise.
			self.next_state.eq(self.state),
			Case(self.state, cases).makedefault(self.encoding[self.reset_state])
		]
		self.sync += self.state.eq(self.next_state)
		# Registers targeted by NextValue are updated when their enable is set.
		for register, (next_value_ce, next_value) in ln.registers.items():
			self.sync += If(next_value_ce, register.eq(next_value))

		# drive entering/leaving signals
		for state, signal in self.before_leaving_signals.items():
			encoded = self.encoding[state]
			self.comb += signal.eq((self.state == encoded) & ~(self.next_state == encoded))
		if self.reset_state in self.after_entering_signals:
			self.after_entering_signals[self.reset_state].reset = 1
		for state, signal in self.before_entering_signals.items():
			encoded = self.encoding[state]
			self.comb += signal.eq(~(self.state == encoded) & (self.next_state == encoded))