diff --git a/migen/pytholite/compiler.py b/migen/pytholite/compiler.py index 4763bd6af..660719c06 100644 --- a/migen/pytholite/compiler.py +++ b/migen/pytholite/compiler.py @@ -62,7 +62,6 @@ class _Compiler: self.ioo = ioo self.symdict = symdict self.registers = registers - self.targetname = "" def visit_top(self, node): if isinstance(node, ast.Module) \ @@ -77,38 +76,40 @@ class _Compiler: def visit_block(self, statements): sa = StateAssembler() statements = iter(statements) + statement = None while True: - try: - statement = next(statements) - except StopIteration: - return sa.ret() + if statement is None: + try: + statement = next(statements) + except StopIteration: + return sa.ret() if isinstance(statement, ast.Assign): - self.visit_assign(sa, statement) - elif isinstance(statement, ast.If): - self.visit_if(sa, statement) - elif isinstance(statement, ast.While): - self.visit_while(sa, statement) - elif isinstance(statement, ast.For): - self.visit_for(sa, statement) - elif isinstance(statement, ast.Expr): - self.visit_expr_statement(sa, statement) + # visit_assign can recognize a I/O pattern, consume several + # statements from the iterator and return the first statement + # that is not part of the I/O pattern anymore. + statement = self.visit_assign(sa, statement, statements) else: - raise NotImplementedError - - def visit_assign(self, sa, node): - if isinstance(node.targets[0], ast.Name): - self.targetname = node.targets[0].id - value = self.visit_expr(node.value, True) - self.targetname = "" - - if isinstance(value, _Register): - self.registers.append(value) - for target in node.targets: - if isinstance(target, ast.Name): - self.symdict[target.id] = value + if isinstance(statement, ast.If): + self.visit_if(sa, statement) + elif isinstance(statement, ast.While): + self.visit_while(sa, statement) + elif isinstance(statement, ast.For): + self.visit_for(sa, statement) + elif isinstance(statement, ast.Expr): + self.visit_expr_statement(sa, statement) else: raise NotImplementedError - elif isinstance(value, Value): + statement = None + + def visit_assign(self, sa, node, statements): + if isinstance(node.value, ast.Call): + try: + value = self.visit_expr_call(node.value) + except NotImplementedError: + return self.visit_assign_special(sa, node, statements) + else: + value = self.visit_expr(node.value) + if isinstance(value, Value): r = [] for target in node.targets: if isinstance(target, ast.Attribute) and target.attr == "store": @@ -122,7 +123,33 @@ class _Compiler: sa.assemble([r], [r]) else: raise NotImplementedError - + + def visit_assign_special(self, sa, node, statements): + value = node.value + assert(isinstance(value, ast.Call)) + if isinstance(value.func, ast.Name): + callee = self.symdict[value.func.id] + else: + raise NotImplementedError + + if callee == transel.Register: + if len(value.args) != 1: + raise TypeError("Register() takes exactly 1 argument") + nbits = ast.literal_eval(value.args[0]) + if isinstance(node.targets[0], ast.Name): + targetname = node.targets[0].id + else: + targetname = "unk" + reg = _Register(targetname, nbits) + self.registers.append(reg) + for target in node.targets: + if isinstance(target, ast.Name): + self.symdict[target.id] = reg + else: + raise NotImplementedError + else: + raise NotImplementedError + def visit_if(self, sa, node): test = self.visit_expr(node.test) states_t, exit_states_t = self.visit_block(node.body) @@ -193,12 +220,9 @@ class _Compiler: raise NotImplementedError # expressions - def visit_expr(self, node, allow_registers=False): + def visit_expr(self, node): if isinstance(node, ast.Call): - r = self.visit_expr_call(node) - if not allow_registers and isinstance(r, _Register): - raise NotImplementedError - return r + return self.visit_expr_call(node) elif isinstance(node, ast.BinOp): return self.visit_expr_binop(node) elif isinstance(node, ast.Compare): @@ -215,12 +239,7 @@ class _Compiler: callee = self.symdict[node.func.id] else: raise NotImplementedError - if callee == transel.Register: - if len(node.args) != 1: - raise TypeError("Register() takes exactly 1 argument") - nbits = ast.literal_eval(node.args[0]) - return _Register(self.targetname, nbits) - elif callee == transel.bitslice: + if callee == transel.bitslice: if len(node.args) != 2 and len(node.args) != 3: raise TypeError("bitslice() takes 2 or 3 arguments") val = self.visit_expr(node.args[0])