pytholite: move expression and register handling to separate modules

This commit is contained in:
Sebastien Bourdeauducq 2012-11-11 23:48:23 +01:00
parent f59fd69e34
commit bf5ce8dc20
3 changed files with 169 additions and 152 deletions

View File

@ -1,62 +1,14 @@
import inspect
import ast
from operator import itemgetter
from migen.fhdl.structure import *
from migen.fhdl.structure import _Slice
from migen.fhdl import visit as fhdl
from migen.pytholite.reg import *
from migen.pytholite.expr import *
from migen.pytholite import transel
from migen.pytholite.io import make_io_object, gen_io
from migen.pytholite.fsm import *
class FinalizeError(Exception):
pass
class _AbstractLoad:
def __init__(self, target, source):
self.target = target
self.source = source
def lower(self):
if not self.target.finalized:
raise FinalizeError
return self.target.sel.eq(self.target.source_encoding[self.source])
class _LowerAbstractLoad(fhdl.NodeTransformer):
def visit_unknown(self, node):
if isinstance(node, _AbstractLoad):
return node.lower()
else:
return node
class _Register:
def __init__(self, name, nbits):
self.name = name
self.storage = Signal(BV(nbits), name=self.name)
self.source_encoding = {}
self.finalized = False
def load(self, source):
if source not in self.source_encoding:
self.source_encoding[source] = len(self.source_encoding) + 1
return _AbstractLoad(self, source)
def finalize(self):
if self.finalized:
raise FinalizeError
self.sel = Signal(BV(bits_for(len(self.source_encoding) + 1)), name="pl_regsel_"+self.name)
self.finalized = True
def get_fragment(self):
if not self.finalized:
raise FinalizeError
# do nothing when sel == 0
items = sorted(self.source_encoding.items(), key=itemgetter(1))
cases = [(Constant(v, self.sel.bv),
self.storage.eq(k)) for k, v in items]
sync = [Case(self.sel, *cases)]
return Fragment(sync=sync)
def _is_name_used(node, name):
for n in ast.walk(node):
if isinstance(n, ast.Name) and n.id == name:
@ -68,6 +20,7 @@ class _Compiler:
self.ioo = ioo
self.symdict = symdict
self.registers = registers
self.ec = ExprCompiler(self.symdict)
def visit_top(self, node):
if isinstance(node, ast.Module) \
@ -109,12 +62,15 @@ class _Compiler:
def visit_assign(self, sa, node, statements):
if isinstance(node.value, ast.Call):
is_special = False
try:
value = self.visit_expr_call(node.value)
value = self.ec.visit_expr_call(node.value)
except NotImplementedError:
is_special = True
if is_special:
return self.visit_assign_special(sa, node, statements)
else:
value = self.visit_expr(node.value)
value = self.ec.visit_expr(node.value)
if isinstance(value, Value):
r = []
for target in node.targets:
@ -146,7 +102,7 @@ class _Compiler:
targetname = node.targets[0].id
else:
targetname = "unk"
reg = _Register(targetname, nbits)
reg = ImplRegister(targetname, nbits)
self.registers.append(reg)
for target in node.targets:
if isinstance(target, ast.Name):
@ -173,6 +129,7 @@ class _Compiler:
or not isinstance(ystatement.value, ast.Yield) \
or not isinstance(ystatement.value.value, ast.Name) \
or ystatement.value.value.id != modelname:
print(ast.dump(ystatement))
raise NotImplementedError("Unrecognized I/O pattern")
# following optional statements are assignments to registers
@ -202,7 +159,7 @@ class _Compiler:
return fstatement
def visit_if(self, sa, node):
test = self.visit_expr(node.test)
test = self.ec.visit_expr(node.test)
states_t, exit_states_t = self.visit_block(node.body)
states_f, exit_states_f = self.visit_block(node.orelse)
exit_states = exit_states_t + exit_states_f
@ -218,7 +175,7 @@ class _Compiler:
exit_states)
def visit_while(self, sa, node):
test = self.visit_expr(node.test)
test = self.ec.visit_expr(node.test)
states_b, exit_states_b = self.visit_block(node.body)
test_state = [If(test, AbstractNextState(states_b[0]))]
@ -269,102 +226,6 @@ class _Compiler:
sa.assemble(states, exit_states)
else:
raise NotImplementedError
# expressions
def visit_expr(self, node):
if isinstance(node, ast.Call):
return self.visit_expr_call(node)
elif isinstance(node, ast.BinOp):
return self.visit_expr_binop(node)
elif isinstance(node, ast.Compare):
return self.visit_expr_compare(node)
elif isinstance(node, ast.Name):
return self.visit_expr_name(node)
elif isinstance(node, ast.Num):
return self.visit_expr_num(node)
else:
raise NotImplementedError
def visit_expr_call(self, node):
if isinstance(node.func, ast.Name):
callee = self.symdict[node.func.id]
else:
raise NotImplementedError
if callee == transel.bitslice:
if len(node.args) != 2 and len(node.args) != 3:
raise TypeError("bitslice() takes 2 or 3 arguments")
val = self.visit_expr(node.args[0])
low = ast.literal_eval(node.args[1])
if len(node.args) == 3:
up = ast.literal_eval(node.args[2])
else:
up = low + 1
return _Slice(val, low, up)
else:
raise NotImplementedError
def visit_expr_binop(self, node):
left = self.visit_expr(node.left)
right = self.visit_expr(node.right)
if isinstance(node.op, ast.Add):
return left + right
elif isinstance(node.op, ast.Sub):
return left - right
elif isinstance(node.op, ast.Mult):
return left * right
elif isinstance(node.op, ast.LShift):
return left << right
elif isinstance(node.op, ast.RShift):
return left >> right
elif isinstance(node.op, ast.BitOr):
return left | right
elif isinstance(node.op, ast.BitXor):
return left ^ right
elif isinstance(node.op, ast.BitAnd):
return left & right
else:
raise NotImplementedError
def visit_expr_compare(self, node):
test = self.visit_expr(node.left)
r = None
for op, rcomparator in zip(node.ops, node.comparators):
comparator = self.visit_expr(rcomparator)
if isinstance(op, ast.Eq):
comparison = test == comparator
elif isinstance(op, ast.NotEq):
comparison = test != comparator
elif isinstance(op, ast.Lt):
comparison = test < comparator
elif isinstance(op, ast.LtE):
comparison = test <= comparator
elif isinstance(op, ast.Gt):
comparison = test > comparator
elif isinstance(op, ast.GtE):
comparison = test >= comparator
else:
raise NotImplementedError
if r is None:
r = comparison
else:
r = r & comparison
test = comparator
return r
def visit_expr_name(self, node):
if node.id == "True":
return Constant(1)
if node.id == "False":
return Constant(0)
r = self.symdict[node.id]
if isinstance(r, _Register):
r = r.storage
if isinstance(r, int):
r = Constant(r)
return r
def visit_expr_num(self, node):
return Constant(node.n)
def make_pytholite(func, **ioresources):
ioo = make_io_object(**ioresources)
@ -381,7 +242,7 @@ def make_pytholite(func, **ioresources):
regf += register.get_fragment()
fsm = implement_fsm(states)
fsmf = _LowerAbstractLoad().visit(fsm.get_fragment())
fsmf = LowerAbstractLoad().visit(fsm.get_fragment())
ioo.fragment = regf + fsmf
return ioo

104
migen/pytholite/expr.py Normal file
View File

@ -0,0 +1,104 @@
import ast
from migen.fhdl.structure import *
from migen.pytholite import transel
from migen.pytholite.reg import *
class ExprCompiler:
def __init__(self, symdict):
self.symdict = symdict
def visit_expr(self, node):
if isinstance(node, ast.Call):
return self.visit_expr_call(node)
elif isinstance(node, ast.BinOp):
return self.visit_expr_binop(node)
elif isinstance(node, ast.Compare):
return self.visit_expr_compare(node)
elif isinstance(node, ast.Name):
return self.visit_expr_name(node)
elif isinstance(node, ast.Num):
return self.visit_expr_num(node)
else:
raise NotImplementedError
def visit_expr_call(self, node):
if isinstance(node.func, ast.Name):
callee = self.symdict[node.func.id]
else:
raise NotImplementedError
if callee == transel.bitslice:
if len(node.args) != 2 and len(node.args) != 3:
raise TypeError("bitslice() takes 2 or 3 arguments")
val = self.visit_expr(node.args[0])
low = ast.literal_eval(node.args[1])
if len(node.args) == 3:
up = ast.literal_eval(node.args[2])
else:
up = low + 1
return _Slice(val, low, up)
else:
raise NotImplementedError
def visit_expr_binop(self, node):
left = self.visit_expr(node.left)
right = self.visit_expr(node.right)
if isinstance(node.op, ast.Add):
return left + right
elif isinstance(node.op, ast.Sub):
return left - right
elif isinstance(node.op, ast.Mult):
return left * right
elif isinstance(node.op, ast.LShift):
return left << right
elif isinstance(node.op, ast.RShift):
return left >> right
elif isinstance(node.op, ast.BitOr):
return left | right
elif isinstance(node.op, ast.BitXor):
return left ^ right
elif isinstance(node.op, ast.BitAnd):
return left & right
else:
raise NotImplementedError
def visit_expr_compare(self, node):
test = self.visit_expr(node.left)
r = None
for op, rcomparator in zip(node.ops, node.comparators):
comparator = self.visit_expr(rcomparator)
if isinstance(op, ast.Eq):
comparison = test == comparator
elif isinstance(op, ast.NotEq):
comparison = test != comparator
elif isinstance(op, ast.Lt):
comparison = test < comparator
elif isinstance(op, ast.LtE):
comparison = test <= comparator
elif isinstance(op, ast.Gt):
comparison = test > comparator
elif isinstance(op, ast.GtE):
comparison = test >= comparator
else:
raise NotImplementedError
if r is None:
r = comparison
else:
r = r & comparison
test = comparator
return r
def visit_expr_name(self, node):
if node.id == "True":
return Constant(1)
if node.id == "False":
return Constant(0)
r = self.symdict[node.id]
if isinstance(r, ImplRegister):
r = r.storage
if isinstance(r, int):
r = Constant(r)
return r
def visit_expr_num(self, node):
return Constant(node.n)

52
migen/pytholite/reg.py Normal file
View File

@ -0,0 +1,52 @@
from operator import itemgetter
from migen.fhdl.structure import *
from migen.fhdl import visit as fhdl
class FinalizeError(Exception):
pass
class AbstractLoad:
def __init__(self, target, source):
self.target = target
self.source = source
def lower(self):
if not self.target.finalized:
raise FinalizeError
return self.target.sel.eq(self.target.source_encoding[self.source])
class LowerAbstractLoad(fhdl.NodeTransformer):
def visit_unknown(self, node):
if isinstance(node, AbstractLoad):
return node.lower()
else:
return node
class ImplRegister:
def __init__(self, name, nbits):
self.name = name
self.storage = Signal(BV(nbits), name=self.name)
self.source_encoding = {}
self.finalized = False
def load(self, source):
if source not in self.source_encoding:
self.source_encoding[source] = len(self.source_encoding) + 1
return AbstractLoad(self, source)
def finalize(self):
if self.finalized:
raise FinalizeError
self.sel = Signal(BV(bits_for(len(self.source_encoding) + 1)), name="pl_regsel_"+self.name)
self.finalized = True
def get_fragment(self):
if not self.finalized:
raise FinalizeError
# do nothing when sel == 0
items = sorted(self.source_encoding.items(), key=itemgetter(1))
cases = [(Constant(v, self.sel.bv),
self.storage.eq(k)) for k, v in items]
sync = [Case(self.sel, *cases)]
return Fragment(sync=sync)