import inspect import ast from operator import itemgetter from migen.fhdl.structure import * from migen.fhdl import visit as fhdl from migen.corelogic.fsm import FSM from migen.pytholite import transel class FinalizeError(Exception): pass class _AbstractLoad: def __init__(self, target, source): self.target = target self.source = source def lower(self): if not self.target.finalized: raise FinalizeError return self.target.sel.eq(self.target.source_encoding[self.source]) class _LowerAbstractLoad(fhdl.NodeTransformer): def visit_unknown(self, node): if isinstance(node, _AbstractLoad): return node.lower() else: return node class _Register: def __init__(self, name, nbits): self.name = name self.storage = Signal(BV(nbits), name=self.name) self.source_encoding = {} self.finalized = False def load(self, source): if source not in self.source_encoding: self.source_encoding[source] = len(self.source_encoding) + 1 return _AbstractLoad(self, source) def finalize(self): if self.finalized: raise FinalizeError self.sel = Signal(BV(bits_for(len(self.source_encoding) + 1)), name="pl_regsel_"+self.name) self.finalized = True def get_fragment(self): if not self.finalized: raise FinalizeError # do nothing when sel == 0 items = sorted(self.source_encoding.items(), key=itemgetter(1)) cases = [(Constant(v, self.sel.bv), self.storage.eq(k)) for k, v in items] sync = [Case(self.sel, *cases)] return Fragment(sync=sync) class _AbstractNextState: def __init__(self, target_state): self.target_state = target_state class _Compiler: def __init__(self, symdict, registers): self.symdict = symdict self.registers = registers self.targetname = "" def visit_top(self, node): if isinstance(node, ast.Module) \ and len(node.body) == 1 \ and isinstance(node.body[0], ast.FunctionDef): states, exit_states = self.visit_block(node.body[0].body) return states else: raise NotImplementedError # blocks and statements def visit_block(self, statements): states = [] exit_states = [] for statement in statements: n_states, n_exit_states = self.visit_statement(statement) if n_states: states += n_states for exit_state in exit_states: exit_state.insert(0, _AbstractNextState(n_states[0])) exit_states = n_exit_states return states, exit_states # entry state is first state returned def visit_statement(self, statement): states = [] exit_states = [] if isinstance(statement, ast.Assign): op = self.visit_assign(statement) if op: states.append(op) exit_states.append(op) elif isinstance(statement, ast.If): test = self.visit_expr(statement.test) states_t, exit_states_t = self.visit_block(statement.body) states_f, exit_states_f = self.visit_block(statement.orelse) test_state_stmt = If(test, _AbstractNextState(states_t[0])) test_state = [test_state_stmt] if states_f: test_state_stmt.Else(_AbstractNextState(states_f[0])) else: exit_states.append(test_state) states.append(test_state) states += states_t + states_f exit_states += exit_states_t + exit_states_f elif isinstance(statement, ast.While): test = self.visit_expr(statement.test) states_b, exit_states_b = self.visit_block(statement.body) test_state = [If(test, _AbstractNextState(states_b[0]))] for exit_state in exit_states_b: exit_state.insert(0, _AbstractNextState(test_state)) exit_states.append(test_state) states += states_b states.append(test_state) elif isinstance(statement, ast.For): if not isinstance(statement.target, ast.Name): raise NotImplementedError target = statement.target.id if target in self.symdict: raise NotImplementedError("For loop target must use an available name") it = ast.literal_eval(statement.iter) last_exit_states = [] for iteration in it: self.symdict[target] = iteration states_b, exit_states_b = self.visit_block(statement.body) for exit_state in last_exit_states: exit_state.insert(0, _AbstractNextState(states_b[0])) last_exit_states = exit_states_b states += states_b exit_states += last_exit_states del self.symdict[target] else: raise NotImplementedError return states, exit_states def visit_assign(self, node): if isinstance(node.targets[0], ast.Name): self.targetname = node.targets[0].id value = self.visit_expr(node.value, True) self.targetname = "" if isinstance(value, _Register): self.registers.append(value) for target in node.targets: if isinstance(target, ast.Name): self.symdict[target.id] = value else: raise NotImplementedError return [] elif isinstance(value, Value): r = [] for target in node.targets: if isinstance(target, ast.Attribute) and target.attr == "store": treg = target.value if isinstance(treg, ast.Name): r.append(self.symdict[treg.id].load(value)) else: raise NotImplementedError else: raise NotImplementedError return r else: raise NotImplementedError # expressions def visit_expr(self, node, allow_call=False): if isinstance(node, ast.Call): if allow_call: return self.visit_expr_call(node) else: raise NotImplementedError elif isinstance(node, ast.BinOp): return self.visit_expr_binop(node) elif isinstance(node, ast.Compare): return self.visit_expr_compare(node) elif isinstance(node, ast.Name): return self.visit_expr_name(node) elif isinstance(node, ast.Num): return self.visit_expr_num(node) else: raise NotImplementedError def visit_expr_call(self, node): if isinstance(node.func, ast.Name): callee = self.symdict[node.func.id] else: raise NotImplementedError if callee == transel.Register: if len(node.args) != 1: raise TypeError("Register() takes exactly 1 argument") nbits = ast.literal_eval(node.args[0]) return _Register(self.targetname, nbits) else: raise NotImplementedError def visit_expr_binop(self, node): left = self.visit_expr(node.left) right = self.visit_expr(node.right) if isinstance(node.op, ast.Add): return left + right elif isinstance(node.op, ast.Sub): return left - right elif isinstance(node.op, ast.Mult): return left * right elif isinstance(node.op, ast.LShift): return left << right elif isinstance(node.op, ast.RShift): return left >> right elif isinstance(node.op, ast.BitOr): return left | right elif isinstance(node.op, ast.BitXor): return left ^ right elif isinstance(node.op, ast.BitAnd): return left & right else: raise NotImplementedError def visit_expr_compare(self, node): test = self.visit_expr(node.left) r = None for op, rcomparator in zip(node.ops, node.comparators): comparator = self.visit_expr(rcomparator) if isinstance(op, ast.Eq): comparison = test == comparator elif isinstance(op, ast.NotEq): comparison = test != comparator elif isinstance(op, ast.Lt): comparison = test < comparator elif isinstance(op, ast.LtE): comparison = test <= comparator elif isinstance(op, ast.Gt): comparison = test > comparator elif isinstance(op, ast.GtE): comparison = test >= comparator else: raise NotImplementedError if r is None: r = comparison else: r = r & comparison test = comparator return r def visit_expr_name(self, node): if node.id == "True": return Constant(1) if node.id == "False": return Constant(0) r = self.symdict[node.id] if isinstance(r, _Register): r = r.storage if isinstance(r, int): r = Constant(r) return r def visit_expr_num(self, node): return Constant(node.n) # like list.index, but using "is" instead of comparison def _index_is(l, x): for i, e in enumerate(l): if e is x: return i class _LowerAbstractNextState(fhdl.NodeTransformer): def __init__(self, fsm, states, stnames): self.fsm = fsm self.states = states self.stnames = stnames def visit_unknown(self, node): if isinstance(node, _AbstractNextState): index = _index_is(self.states, node.target_state) estate = getattr(self.fsm, self.stnames[index]) return self.fsm.next_state(estate) else: return node def _create_fsm(states): stnames = ["S" + str(i) for i in range(len(states))] fsm = FSM(*stnames) lans = _LowerAbstractNextState(fsm, states, stnames) for i, state in enumerate(states): actions = lans.visit(state) fsm.act(getattr(fsm, stnames[i]), *actions) return fsm def make_pytholite(func): tree = ast.parse(inspect.getsource(func)) symdict = func.__globals__.copy() registers = [] print("ast:") print(ast.dump(tree)) states = _Compiler(symdict, registers).visit_top(tree) print("compilation result:") print(states) regf = Fragment() for register in registers: register.finalize() regf += register.get_fragment() fsm = _create_fsm(states) fsmf = _LowerAbstractLoad().visit(fsm.get_fragment()) return regf + fsmf