pytholite: do not use ast.NodeVisitor

This commit is contained in:
Sebastien Bourdeauducq 2012-11-06 13:52:19 +01:00
parent 56d4cdeb48
commit 18758d87f6
1 changed files with 97 additions and 17 deletions

View File

@ -37,31 +37,76 @@ class _Register:
sync = [Case(self.sel, *cases)] sync = [Case(self.sel, *cases)]
return Fragment(sync=sync) return Fragment(sync=sync)
class _AnonymousRegister: class _Compiler:
def __init__(self, nbits):
self.nbits = nbits
class _CompileVisitor(ast.NodeVisitor):
def __init__(self, symdict, registers): def __init__(self, symdict, registers):
self.symdict = symdict self.symdict = symdict
self.registers = registers self.registers = registers
self.targetname = ""
def visit_Assign(self, node): def visit_top(self, node):
value = self.visit(node.value) if isinstance(node, ast.Module) \
if isinstance(value, _AnonymousRegister): and len(node.body) == 1 \
if isinstance(node.targets[0], ast.Name): and isinstance(node.body[0], ast.FunctionDef):
name = node.targets[0].id return self.visit_block(node.body[0].body)
else: else:
raise NotImplementedError raise NotImplementedError
value = _Register(name, value.nbits)
# blocks and statements
def visit_block(self, statements):
r = []
for statement in statements:
if isinstance(statement, ast.Assign):
r += self.visit_assign(statement)
else:
raise NotImplementedError
return r
def visit_assign(self, node):
if isinstance(node.targets[0], ast.Name):
self.targetname = node.targets[0].id
value = self.visit_expr(node.value, True)
self.targetname = ""
if isinstance(value, _Register):
self.registers.append(value) self.registers.append(value)
for target in node.targets: for target in node.targets:
if isinstance(target, ast.Name): if isinstance(target, ast.Name):
self.symdict[target.id] = value self.symdict[target.id] = value
else: else:
raise NotImplementedError raise NotImplementedError
return []
elif isinstance(value, Value):
r = []
for target in node.targets:
if isinstance(target, ast.Attribute) and target.attr == "store":
treg = target.value
if isinstance(treg, ast.Name):
r.append(self.symdict[treg.id].load(value))
else:
raise NotImplementedError
else:
raise NotImplementedError
return r
else:
raise NotImplementedError
def visit_Call(self, node): # expressions
def visit_expr(self, node, allow_call=False):
if isinstance(node, ast.Call):
if allow_call:
return self.visit_expr_call(node)
else:
raise NotImplementedError
elif isinstance(node, ast.BinOp):
return self.visit_expr_binop(node)
elif isinstance(node, ast.Name):
return self.visit_expr_name(node)
elif isinstance(node, ast.Num):
return self.visit_expr_num(node)
else:
raise NotImplementedError
def visit_expr_call(self, node):
if isinstance(node.func, ast.Name): if isinstance(node.func, ast.Name):
callee = self.symdict[node.func.id] callee = self.symdict[node.func.id]
else: else:
@ -70,19 +115,54 @@ class _CompileVisitor(ast.NodeVisitor):
if len(node.args) != 1: if len(node.args) != 1:
raise TypeError("Register() takes exactly 1 argument") raise TypeError("Register() takes exactly 1 argument")
nbits = ast.literal_eval(node.args[0]) nbits = ast.literal_eval(node.args[0])
return _AnonymousRegister(nbits) return _Register(self.targetname, nbits)
else: else:
raise NotImplementedError raise NotImplementedError
def visit_expr_binop(self, node):
left = self.visit_expr(node.left)
right = self.visit_expr(node.right)
if isinstance(node.op, ast.Add):
return left + right
elif isinstance(node.op, ast.Sub):
return left - right
elif isinstance(node.op, ast.Mult):
return left * right
elif isinstance(node.op, ast.LShift):
return left << right
elif isinstance(node.op, ast.RShift):
return left >> right
elif isinstance(node.op, ast.BitOr):
return left | right
elif isinstance(node.op, ast.BitXor):
return left ^ right
elif isinstance(node.op, ast.BitAnd):
return left & right
else:
raise NotImplementedError
def visit_expr_name(self, node):
r = self.symdict[node.id]
if isinstance(r, _Register):
r = r.storage
return r
def visit_expr_num(self, node):
return node.n
def make_pytholite(func): def make_pytholite(func):
tree = ast.parse(inspect.getsource(func)) tree = ast.parse(inspect.getsource(func))
symdict = func.__globals__.copy() symdict = func.__globals__.copy()
registers = [] registers = []
cv = _CompileVisitor(symdict, registers) c = _Compiler(symdict, registers)
cv.visit(tree) print("compilation result:")
print(c.visit_top(tree))
print("registers:")
print(registers) print(registers)
print(symdict) #print("symdict:")
#print(symdict)
#print(ast.dump(tree)) print("ast:")
print(ast.dump(tree))