From 18758d87f69212a5f81f014114502b022f6064b8 Mon Sep 17 00:00:00 2001 From: Sebastien Bourdeauducq Date: Tue, 6 Nov 2012 13:52:19 +0100 Subject: [PATCH] pytholite: do not use ast.NodeVisitor --- migen/pytholite/compiler.py | 114 ++++++++++++++++++++++++++++++------ 1 file changed, 97 insertions(+), 17 deletions(-) diff --git a/migen/pytholite/compiler.py b/migen/pytholite/compiler.py index 5a9e0978b..e3f5a7609 100644 --- a/migen/pytholite/compiler.py +++ b/migen/pytholite/compiler.py @@ -37,31 +37,76 @@ class _Register: sync = [Case(self.sel, *cases)] return Fragment(sync=sync) -class _AnonymousRegister: - def __init__(self, nbits): - self.nbits = nbits - -class _CompileVisitor(ast.NodeVisitor): +class _Compiler: def __init__(self, symdict, registers): self.symdict = symdict self.registers = registers + self.targetname = "" - def visit_Assign(self, node): - value = self.visit(node.value) - if isinstance(value, _AnonymousRegister): - if isinstance(node.targets[0], ast.Name): - name = node.targets[0].id + def visit_top(self, node): + if isinstance(node, ast.Module) \ + and len(node.body) == 1 \ + and isinstance(node.body[0], ast.FunctionDef): + return self.visit_block(node.body[0].body) + else: + raise NotImplementedError + + # blocks and statements + def visit_block(self, statements): + r = [] + for statement in statements: + if isinstance(statement, ast.Assign): + r += self.visit_assign(statement) else: raise NotImplementedError - value = _Register(name, value.nbits) + return r + + def visit_assign(self, node): + if isinstance(node.targets[0], ast.Name): + self.targetname = node.targets[0].id + value = self.visit_expr(node.value, True) + self.targetname = "" + + if isinstance(value, _Register): self.registers.append(value) for target in node.targets: if isinstance(target, ast.Name): self.symdict[target.id] = value else: raise NotImplementedError + return [] + elif isinstance(value, Value): + r = [] + for target in node.targets: + if isinstance(target, ast.Attribute) and target.attr == "store": + treg = target.value + if isinstance(treg, ast.Name): + r.append(self.symdict[treg.id].load(value)) + else: + raise NotImplementedError + else: + raise NotImplementedError + return r + else: + raise NotImplementedError - def visit_Call(self, node): + # expressions + def visit_expr(self, node, allow_call=False): + if isinstance(node, ast.Call): + if allow_call: + return self.visit_expr_call(node) + else: + raise NotImplementedError + elif isinstance(node, ast.BinOp): + return self.visit_expr_binop(node) + elif isinstance(node, ast.Name): + return self.visit_expr_name(node) + elif isinstance(node, ast.Num): + return self.visit_expr_num(node) + else: + raise NotImplementedError + + def visit_expr_call(self, node): if isinstance(node.func, ast.Name): callee = self.symdict[node.func.id] else: @@ -70,19 +115,54 @@ class _CompileVisitor(ast.NodeVisitor): if len(node.args) != 1: raise TypeError("Register() takes exactly 1 argument") nbits = ast.literal_eval(node.args[0]) - return _AnonymousRegister(nbits) + return _Register(self.targetname, nbits) else: raise NotImplementedError + def visit_expr_binop(self, node): + left = self.visit_expr(node.left) + right = self.visit_expr(node.right) + if isinstance(node.op, ast.Add): + return left + right + elif isinstance(node.op, ast.Sub): + return left - right + elif isinstance(node.op, ast.Mult): + return left * right + elif isinstance(node.op, ast.LShift): + return left << right + elif isinstance(node.op, ast.RShift): + return left >> right + elif isinstance(node.op, ast.BitOr): + return left | right + elif isinstance(node.op, ast.BitXor): + return left ^ right + elif isinstance(node.op, ast.BitAnd): + return left & right + else: + raise NotImplementedError + + def visit_expr_name(self, node): + r = self.symdict[node.id] + if isinstance(r, _Register): + r = r.storage + return r + + def visit_expr_num(self, node): + return node.n + def make_pytholite(func): tree = ast.parse(inspect.getsource(func)) symdict = func.__globals__.copy() registers = [] - cv = _CompileVisitor(symdict, registers) - cv.visit(tree) + c = _Compiler(symdict, registers) + print("compilation result:") + print(c.visit_top(tree)) + print("registers:") print(registers) - print(symdict) + #print("symdict:") + #print(symdict) - #print(ast.dump(tree)) + print("ast:") + print(ast.dump(tree))