visit/NodeTransformer: copy most nodes

This commit is contained in:
Sebastien Bourdeauducq 2012-11-28 17:50:55 +01:00
parent a2bcbfdf8f
commit 11b1e53224
1 changed files with 27 additions and 25 deletions

View File

@ -1,3 +1,5 @@
from copy import copy
from migen.fhdl.structure import * from migen.fhdl.structure import *
from migen.fhdl.structure import _Operator, _Slice, _Assign, _ArrayProxy from migen.fhdl.structure import _Operator, _Slice, _Assign, _ArrayProxy
@ -87,6 +89,12 @@ class NodeVisitor:
def visit_unknown(self, node): def visit_unknown(self, node):
pass pass
# Default methods always copy the node, except for:
# - Constants
# - Signals
# - Unknown objects
# - All fragment fields except comb and sync
# In those cases, the original node is returned unchanged.
class NodeTransformer: class NodeTransformer:
def visit(self, node): def visit(self, node):
if isinstance(node, Constant): if isinstance(node, Constant):
@ -127,42 +135,37 @@ class NodeTransformer:
return node return node
def visit_Operator(self, node): def visit_Operator(self, node):
node.operands = [self.visit(o) for o in node.operands] return _Operator(node.op, [self.visit(o) for o in node.operands])
return node
def visit_Slice(self, node): def visit_Slice(self, node):
node.value = self.visit(node.value) return _Slice(self.visit(node.value), node.start, node.stop)
return node
def visit_Cat(self, node): def visit_Cat(self, node):
node.l = [self.visit(e) for e in node.l] return Cat(*[self.visit(e) for e in node.l])
return node
def visit_Replicate(self, node): def visit_Replicate(self, node):
node.v = self.visit(node.v) return Replicate(self.visit(node.v), node.n)
return node
def visit_Assign(self, node): def visit_Assign(self, node):
node.l = self.visit(node.l) return _Assign(self.visit(node.l), self.visit(node.r))
node.r = self.visit(node.r)
return node
def visit_If(self, node): def visit_If(self, node):
node.cond = self.visit(node.cond) r = If(self.visit(node.cond))
node.t = self.visit(node.t) r.t = self.visit(node.t)
node.f = self.visit(node.f) r.f = self.visit(node.f)
return node return r
def visit_Case(self, node): def visit_Case(self, node):
node.test = self.visit(node.test) r = Case(self.visit(node.test))
node.cases = [(v, self.visit(statements)) for v, statements in node.cases] r.cases = [(v, self.visit(statements)) for v, statements in node.cases]
node.default = self.visit(node.default) r.default = self.visit(node.default)
return node return r
def visit_Fragment(self, node): def visit_Fragment(self, node):
node.comb = self.visit(node.comb) r = copy(node)
node.sync = self.visit(node.sync) r.comb = self.visit(node.comb)
return node r.sync = self.visit(node.sync)
return r
def visit_statements(self, node): def visit_statements(self, node):
return [self.visit(statement) for statement in node] return [self.visit(statement) for statement in node]
@ -171,9 +174,8 @@ class NodeTransformer:
return dict((clockname, self.visit(statements)) for clockname, statements in node.items()) return dict((clockname, self.visit(statements)) for clockname, statements in node.items())
def visit_ArrayProxy(self, node): def visit_ArrayProxy(self, node):
node.choices = [self.visit(choice) for choice in node.choices] return _ArrayProxy([self.visit(choice) for choice in node.choices],
node.key = self.visit(node.key) self.visit(node.key))
return node
def visit_unknown(self, node): def visit_unknown(self, node):
return node return node