from copy import copy

from migen.fhdl.structure import *
from migen.fhdl.structure import _Operator, _Slice, _Assign, _ArrayProxy

class NodeVisitor:
	def visit(self, node):
		if isinstance(node, (int, bool)):
			self.visit_constant(node)
		elif isinstance(node, Signal):
			self.visit_Signal(node)
		elif isinstance(node, _Operator):
			self.visit_Operator(node)
		elif isinstance(node, _Slice):
			self.visit_Slice(node)
		elif isinstance(node, Cat):
			self.visit_Cat(node)
		elif isinstance(node, Replicate):
			self.visit_Replicate(node)
		elif isinstance(node, _Assign):
			self.visit_Assign(node)
		elif isinstance(node, If):
			self.visit_If(node)
		elif isinstance(node, Case):
			self.visit_Case(node)
		elif isinstance(node, Fragment):
			self.visit_Fragment(node)
		elif isinstance(node, (list, tuple)):
			self.visit_statements(node)
		elif isinstance(node, dict):
			self.visit_clock_domains(node)
		elif isinstance(node, _ArrayProxy):
			self.visit_ArrayProxy(node)
		elif node is not None:
			self.visit_unknown(node)
	
	def visit_constant(self, node):
		pass
	
	def visit_Signal(self, node):
		pass
	
	def visit_Operator(self, node):
		for o in node.operands:
			self.visit(o)
	
	def visit_Slice(self, node):
		self.visit(node.value)
	
	def visit_Cat(self, node):
		for e in node.l:
			self.visit(e)
	
	def visit_Replicate(self, node):
		self.visit(node.v)
	
	def visit_Assign(self, node):
		self.visit(node.l)
		self.visit(node.r)
	
	def visit_If(self, node):
		self.visit(node.cond)
		self.visit(node.t)
		self.visit(node.f)
	
	def visit_Case(self, node):
		self.visit(node.test)
		for v, statements in node.cases.items():
			self.visit(statements)
	
	def visit_Fragment(self, node):
		self.visit(node.comb)
		self.visit(node.sync)
	
	def visit_statements(self, node):
		for statement in node:
			self.visit(statement)
	
	def visit_clock_domains(self, node):
		for clockname, statements in node.items():
			self.visit(statements)
	
	def visit_ArrayProxy(self, node):
		for choice in node.choices:
			self.visit(choice)
		self.visit(node.key)
	
	def visit_unknown(self, node):
		pass

# Default methods always copy the node, except for:
# - Signals
# - Unknown objects
# - All fragment fields except comb and sync
# In those cases, the original node is returned unchanged.
class NodeTransformer:
	def visit(self, node):
		if isinstance(node, (int, bool)):
			return self.visit_constant(node)
		elif isinstance(node, Signal):
			return self.visit_Signal(node)
		elif isinstance(node, _Operator):
			return self.visit_Operator(node)
		elif isinstance(node, _Slice):
			return self.visit_Slice(node)
		elif isinstance(node, Cat):
			return self.visit_Cat(node)
		elif isinstance(node, Replicate):
			return self.visit_Replicate(node)
		elif isinstance(node, _Assign):
			return self.visit_Assign(node)
		elif isinstance(node, If):
			return self.visit_If(node)
		elif isinstance(node, Case):
			return self.visit_Case(node)
		elif isinstance(node, Fragment):
			return self.visit_Fragment(node)
		elif isinstance(node, (list, tuple)):
			return self.visit_statements(node)
		elif isinstance(node, dict):
			return self.visit_clock_domains(node)
		elif isinstance(node, _ArrayProxy):
			return self.visit_ArrayProxy(node)
		elif node is not None:
			return self.visit_unknown(node)
		else:
			return None
	
	def visit_constant(self, node):
		return node
	
	def visit_Signal(self, node):
		return node
	
	def visit_Operator(self, node):
		return _Operator(node.op, [self.visit(o) for o in node.operands])
	
	def visit_Slice(self, node):
		return _Slice(self.visit(node.value), node.start, node.stop)
	
	def visit_Cat(self, node):
		return Cat(*[self.visit(e) for e in node.l])
	
	def visit_Replicate(self, node):
		return Replicate(self.visit(node.v), node.n)
	
	def visit_Assign(self, node):
		return _Assign(self.visit(node.l), self.visit(node.r))
	
	def visit_If(self, node):
		r = If(self.visit(node.cond))
		r.t = self.visit(node.t)
		r.f = self.visit(node.f)
		return r
	
	def visit_Case(self, node):
		cases = dict((v, self.visit(statements)) for v, statements in node.cases.items())
		r = Case(self.visit(node.test), cases)
		return r
	
	def visit_Fragment(self, node):
		r = copy(node)
		r.comb = self.visit(node.comb)
		r.sync = self.visit(node.sync)
		return r
	
	# NOTE: this will always return a list, even if node is a tuple
	def visit_statements(self, node):
		return [self.visit(statement) for statement in node]
	
	def visit_clock_domains(self, node):
		return dict((clockname, self.visit(statements)) for clockname, statements in node.items())
	
	def visit_ArrayProxy(self, node):
		return _ArrayProxy([self.visit(choice) for choice in node.choices],
			self.visit(node.key))
	
	def visit_unknown(self, node):
		return node