mirror of
https://github.com/enjoy-digital/litex.git
synced 2025-01-04 09:52:26 -05:00
pytholite: move expression and register handling to separate modules
This commit is contained in:
parent
f59fd69e34
commit
bf5ce8dc20
3 changed files with 169 additions and 152 deletions
migen/pytholite
|
@ -1,62 +1,14 @@
|
||||||
import inspect
|
import inspect
|
||||||
import ast
|
import ast
|
||||||
from operator import itemgetter
|
|
||||||
|
|
||||||
from migen.fhdl.structure import *
|
from migen.fhdl.structure import *
|
||||||
from migen.fhdl.structure import _Slice
|
from migen.fhdl.structure import _Slice
|
||||||
from migen.fhdl import visit as fhdl
|
from migen.pytholite.reg import *
|
||||||
|
from migen.pytholite.expr import *
|
||||||
from migen.pytholite import transel
|
from migen.pytholite import transel
|
||||||
from migen.pytholite.io import make_io_object, gen_io
|
from migen.pytholite.io import make_io_object, gen_io
|
||||||
from migen.pytholite.fsm import *
|
from migen.pytholite.fsm import *
|
||||||
|
|
||||||
class FinalizeError(Exception):
|
|
||||||
pass
|
|
||||||
|
|
||||||
class _AbstractLoad:
|
|
||||||
def __init__(self, target, source):
|
|
||||||
self.target = target
|
|
||||||
self.source = source
|
|
||||||
|
|
||||||
def lower(self):
|
|
||||||
if not self.target.finalized:
|
|
||||||
raise FinalizeError
|
|
||||||
return self.target.sel.eq(self.target.source_encoding[self.source])
|
|
||||||
|
|
||||||
class _LowerAbstractLoad(fhdl.NodeTransformer):
|
|
||||||
def visit_unknown(self, node):
|
|
||||||
if isinstance(node, _AbstractLoad):
|
|
||||||
return node.lower()
|
|
||||||
else:
|
|
||||||
return node
|
|
||||||
|
|
||||||
class _Register:
|
|
||||||
def __init__(self, name, nbits):
|
|
||||||
self.name = name
|
|
||||||
self.storage = Signal(BV(nbits), name=self.name)
|
|
||||||
self.source_encoding = {}
|
|
||||||
self.finalized = False
|
|
||||||
|
|
||||||
def load(self, source):
|
|
||||||
if source not in self.source_encoding:
|
|
||||||
self.source_encoding[source] = len(self.source_encoding) + 1
|
|
||||||
return _AbstractLoad(self, source)
|
|
||||||
|
|
||||||
def finalize(self):
|
|
||||||
if self.finalized:
|
|
||||||
raise FinalizeError
|
|
||||||
self.sel = Signal(BV(bits_for(len(self.source_encoding) + 1)), name="pl_regsel_"+self.name)
|
|
||||||
self.finalized = True
|
|
||||||
|
|
||||||
def get_fragment(self):
|
|
||||||
if not self.finalized:
|
|
||||||
raise FinalizeError
|
|
||||||
# do nothing when sel == 0
|
|
||||||
items = sorted(self.source_encoding.items(), key=itemgetter(1))
|
|
||||||
cases = [(Constant(v, self.sel.bv),
|
|
||||||
self.storage.eq(k)) for k, v in items]
|
|
||||||
sync = [Case(self.sel, *cases)]
|
|
||||||
return Fragment(sync=sync)
|
|
||||||
|
|
||||||
def _is_name_used(node, name):
|
def _is_name_used(node, name):
|
||||||
for n in ast.walk(node):
|
for n in ast.walk(node):
|
||||||
if isinstance(n, ast.Name) and n.id == name:
|
if isinstance(n, ast.Name) and n.id == name:
|
||||||
|
@ -68,6 +20,7 @@ class _Compiler:
|
||||||
self.ioo = ioo
|
self.ioo = ioo
|
||||||
self.symdict = symdict
|
self.symdict = symdict
|
||||||
self.registers = registers
|
self.registers = registers
|
||||||
|
self.ec = ExprCompiler(self.symdict)
|
||||||
|
|
||||||
def visit_top(self, node):
|
def visit_top(self, node):
|
||||||
if isinstance(node, ast.Module) \
|
if isinstance(node, ast.Module) \
|
||||||
|
@ -109,12 +62,15 @@ class _Compiler:
|
||||||
|
|
||||||
def visit_assign(self, sa, node, statements):
|
def visit_assign(self, sa, node, statements):
|
||||||
if isinstance(node.value, ast.Call):
|
if isinstance(node.value, ast.Call):
|
||||||
|
is_special = False
|
||||||
try:
|
try:
|
||||||
value = self.visit_expr_call(node.value)
|
value = self.ec.visit_expr_call(node.value)
|
||||||
except NotImplementedError:
|
except NotImplementedError:
|
||||||
|
is_special = True
|
||||||
|
if is_special:
|
||||||
return self.visit_assign_special(sa, node, statements)
|
return self.visit_assign_special(sa, node, statements)
|
||||||
else:
|
else:
|
||||||
value = self.visit_expr(node.value)
|
value = self.ec.visit_expr(node.value)
|
||||||
if isinstance(value, Value):
|
if isinstance(value, Value):
|
||||||
r = []
|
r = []
|
||||||
for target in node.targets:
|
for target in node.targets:
|
||||||
|
@ -146,7 +102,7 @@ class _Compiler:
|
||||||
targetname = node.targets[0].id
|
targetname = node.targets[0].id
|
||||||
else:
|
else:
|
||||||
targetname = "unk"
|
targetname = "unk"
|
||||||
reg = _Register(targetname, nbits)
|
reg = ImplRegister(targetname, nbits)
|
||||||
self.registers.append(reg)
|
self.registers.append(reg)
|
||||||
for target in node.targets:
|
for target in node.targets:
|
||||||
if isinstance(target, ast.Name):
|
if isinstance(target, ast.Name):
|
||||||
|
@ -173,6 +129,7 @@ class _Compiler:
|
||||||
or not isinstance(ystatement.value, ast.Yield) \
|
or not isinstance(ystatement.value, ast.Yield) \
|
||||||
or not isinstance(ystatement.value.value, ast.Name) \
|
or not isinstance(ystatement.value.value, ast.Name) \
|
||||||
or ystatement.value.value.id != modelname:
|
or ystatement.value.value.id != modelname:
|
||||||
|
print(ast.dump(ystatement))
|
||||||
raise NotImplementedError("Unrecognized I/O pattern")
|
raise NotImplementedError("Unrecognized I/O pattern")
|
||||||
|
|
||||||
# following optional statements are assignments to registers
|
# following optional statements are assignments to registers
|
||||||
|
@ -202,7 +159,7 @@ class _Compiler:
|
||||||
return fstatement
|
return fstatement
|
||||||
|
|
||||||
def visit_if(self, sa, node):
|
def visit_if(self, sa, node):
|
||||||
test = self.visit_expr(node.test)
|
test = self.ec.visit_expr(node.test)
|
||||||
states_t, exit_states_t = self.visit_block(node.body)
|
states_t, exit_states_t = self.visit_block(node.body)
|
||||||
states_f, exit_states_f = self.visit_block(node.orelse)
|
states_f, exit_states_f = self.visit_block(node.orelse)
|
||||||
exit_states = exit_states_t + exit_states_f
|
exit_states = exit_states_t + exit_states_f
|
||||||
|
@ -218,7 +175,7 @@ class _Compiler:
|
||||||
exit_states)
|
exit_states)
|
||||||
|
|
||||||
def visit_while(self, sa, node):
|
def visit_while(self, sa, node):
|
||||||
test = self.visit_expr(node.test)
|
test = self.ec.visit_expr(node.test)
|
||||||
states_b, exit_states_b = self.visit_block(node.body)
|
states_b, exit_states_b = self.visit_block(node.body)
|
||||||
|
|
||||||
test_state = [If(test, AbstractNextState(states_b[0]))]
|
test_state = [If(test, AbstractNextState(states_b[0]))]
|
||||||
|
@ -270,102 +227,6 @@ class _Compiler:
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
# expressions
|
|
||||||
def visit_expr(self, node):
|
|
||||||
if isinstance(node, ast.Call):
|
|
||||||
return self.visit_expr_call(node)
|
|
||||||
elif isinstance(node, ast.BinOp):
|
|
||||||
return self.visit_expr_binop(node)
|
|
||||||
elif isinstance(node, ast.Compare):
|
|
||||||
return self.visit_expr_compare(node)
|
|
||||||
elif isinstance(node, ast.Name):
|
|
||||||
return self.visit_expr_name(node)
|
|
||||||
elif isinstance(node, ast.Num):
|
|
||||||
return self.visit_expr_num(node)
|
|
||||||
else:
|
|
||||||
raise NotImplementedError
|
|
||||||
|
|
||||||
def visit_expr_call(self, node):
|
|
||||||
if isinstance(node.func, ast.Name):
|
|
||||||
callee = self.symdict[node.func.id]
|
|
||||||
else:
|
|
||||||
raise NotImplementedError
|
|
||||||
if callee == transel.bitslice:
|
|
||||||
if len(node.args) != 2 and len(node.args) != 3:
|
|
||||||
raise TypeError("bitslice() takes 2 or 3 arguments")
|
|
||||||
val = self.visit_expr(node.args[0])
|
|
||||||
low = ast.literal_eval(node.args[1])
|
|
||||||
if len(node.args) == 3:
|
|
||||||
up = ast.literal_eval(node.args[2])
|
|
||||||
else:
|
|
||||||
up = low + 1
|
|
||||||
return _Slice(val, low, up)
|
|
||||||
else:
|
|
||||||
raise NotImplementedError
|
|
||||||
|
|
||||||
def visit_expr_binop(self, node):
|
|
||||||
left = self.visit_expr(node.left)
|
|
||||||
right = self.visit_expr(node.right)
|
|
||||||
if isinstance(node.op, ast.Add):
|
|
||||||
return left + right
|
|
||||||
elif isinstance(node.op, ast.Sub):
|
|
||||||
return left - right
|
|
||||||
elif isinstance(node.op, ast.Mult):
|
|
||||||
return left * right
|
|
||||||
elif isinstance(node.op, ast.LShift):
|
|
||||||
return left << right
|
|
||||||
elif isinstance(node.op, ast.RShift):
|
|
||||||
return left >> right
|
|
||||||
elif isinstance(node.op, ast.BitOr):
|
|
||||||
return left | right
|
|
||||||
elif isinstance(node.op, ast.BitXor):
|
|
||||||
return left ^ right
|
|
||||||
elif isinstance(node.op, ast.BitAnd):
|
|
||||||
return left & right
|
|
||||||
else:
|
|
||||||
raise NotImplementedError
|
|
||||||
|
|
||||||
def visit_expr_compare(self, node):
|
|
||||||
test = self.visit_expr(node.left)
|
|
||||||
r = None
|
|
||||||
for op, rcomparator in zip(node.ops, node.comparators):
|
|
||||||
comparator = self.visit_expr(rcomparator)
|
|
||||||
if isinstance(op, ast.Eq):
|
|
||||||
comparison = test == comparator
|
|
||||||
elif isinstance(op, ast.NotEq):
|
|
||||||
comparison = test != comparator
|
|
||||||
elif isinstance(op, ast.Lt):
|
|
||||||
comparison = test < comparator
|
|
||||||
elif isinstance(op, ast.LtE):
|
|
||||||
comparison = test <= comparator
|
|
||||||
elif isinstance(op, ast.Gt):
|
|
||||||
comparison = test > comparator
|
|
||||||
elif isinstance(op, ast.GtE):
|
|
||||||
comparison = test >= comparator
|
|
||||||
else:
|
|
||||||
raise NotImplementedError
|
|
||||||
if r is None:
|
|
||||||
r = comparison
|
|
||||||
else:
|
|
||||||
r = r & comparison
|
|
||||||
test = comparator
|
|
||||||
return r
|
|
||||||
|
|
||||||
def visit_expr_name(self, node):
|
|
||||||
if node.id == "True":
|
|
||||||
return Constant(1)
|
|
||||||
if node.id == "False":
|
|
||||||
return Constant(0)
|
|
||||||
r = self.symdict[node.id]
|
|
||||||
if isinstance(r, _Register):
|
|
||||||
r = r.storage
|
|
||||||
if isinstance(r, int):
|
|
||||||
r = Constant(r)
|
|
||||||
return r
|
|
||||||
|
|
||||||
def visit_expr_num(self, node):
|
|
||||||
return Constant(node.n)
|
|
||||||
|
|
||||||
def make_pytholite(func, **ioresources):
|
def make_pytholite(func, **ioresources):
|
||||||
ioo = make_io_object(**ioresources)
|
ioo = make_io_object(**ioresources)
|
||||||
|
|
||||||
|
@ -381,7 +242,7 @@ def make_pytholite(func, **ioresources):
|
||||||
regf += register.get_fragment()
|
regf += register.get_fragment()
|
||||||
|
|
||||||
fsm = implement_fsm(states)
|
fsm = implement_fsm(states)
|
||||||
fsmf = _LowerAbstractLoad().visit(fsm.get_fragment())
|
fsmf = LowerAbstractLoad().visit(fsm.get_fragment())
|
||||||
|
|
||||||
ioo.fragment = regf + fsmf
|
ioo.fragment = regf + fsmf
|
||||||
return ioo
|
return ioo
|
||||||
|
|
104
migen/pytholite/expr.py
Normal file
104
migen/pytholite/expr.py
Normal file
|
@ -0,0 +1,104 @@
|
||||||
|
import ast
|
||||||
|
|
||||||
|
from migen.fhdl.structure import *
|
||||||
|
from migen.pytholite import transel
|
||||||
|
from migen.pytholite.reg import *
|
||||||
|
|
||||||
|
class ExprCompiler:
|
||||||
|
def __init__(self, symdict):
|
||||||
|
self.symdict = symdict
|
||||||
|
|
||||||
|
def visit_expr(self, node):
|
||||||
|
if isinstance(node, ast.Call):
|
||||||
|
return self.visit_expr_call(node)
|
||||||
|
elif isinstance(node, ast.BinOp):
|
||||||
|
return self.visit_expr_binop(node)
|
||||||
|
elif isinstance(node, ast.Compare):
|
||||||
|
return self.visit_expr_compare(node)
|
||||||
|
elif isinstance(node, ast.Name):
|
||||||
|
return self.visit_expr_name(node)
|
||||||
|
elif isinstance(node, ast.Num):
|
||||||
|
return self.visit_expr_num(node)
|
||||||
|
else:
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def visit_expr_call(self, node):
|
||||||
|
if isinstance(node.func, ast.Name):
|
||||||
|
callee = self.symdict[node.func.id]
|
||||||
|
else:
|
||||||
|
raise NotImplementedError
|
||||||
|
if callee == transel.bitslice:
|
||||||
|
if len(node.args) != 2 and len(node.args) != 3:
|
||||||
|
raise TypeError("bitslice() takes 2 or 3 arguments")
|
||||||
|
val = self.visit_expr(node.args[0])
|
||||||
|
low = ast.literal_eval(node.args[1])
|
||||||
|
if len(node.args) == 3:
|
||||||
|
up = ast.literal_eval(node.args[2])
|
||||||
|
else:
|
||||||
|
up = low + 1
|
||||||
|
return _Slice(val, low, up)
|
||||||
|
else:
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def visit_expr_binop(self, node):
|
||||||
|
left = self.visit_expr(node.left)
|
||||||
|
right = self.visit_expr(node.right)
|
||||||
|
if isinstance(node.op, ast.Add):
|
||||||
|
return left + right
|
||||||
|
elif isinstance(node.op, ast.Sub):
|
||||||
|
return left - right
|
||||||
|
elif isinstance(node.op, ast.Mult):
|
||||||
|
return left * right
|
||||||
|
elif isinstance(node.op, ast.LShift):
|
||||||
|
return left << right
|
||||||
|
elif isinstance(node.op, ast.RShift):
|
||||||
|
return left >> right
|
||||||
|
elif isinstance(node.op, ast.BitOr):
|
||||||
|
return left | right
|
||||||
|
elif isinstance(node.op, ast.BitXor):
|
||||||
|
return left ^ right
|
||||||
|
elif isinstance(node.op, ast.BitAnd):
|
||||||
|
return left & right
|
||||||
|
else:
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def visit_expr_compare(self, node):
|
||||||
|
test = self.visit_expr(node.left)
|
||||||
|
r = None
|
||||||
|
for op, rcomparator in zip(node.ops, node.comparators):
|
||||||
|
comparator = self.visit_expr(rcomparator)
|
||||||
|
if isinstance(op, ast.Eq):
|
||||||
|
comparison = test == comparator
|
||||||
|
elif isinstance(op, ast.NotEq):
|
||||||
|
comparison = test != comparator
|
||||||
|
elif isinstance(op, ast.Lt):
|
||||||
|
comparison = test < comparator
|
||||||
|
elif isinstance(op, ast.LtE):
|
||||||
|
comparison = test <= comparator
|
||||||
|
elif isinstance(op, ast.Gt):
|
||||||
|
comparison = test > comparator
|
||||||
|
elif isinstance(op, ast.GtE):
|
||||||
|
comparison = test >= comparator
|
||||||
|
else:
|
||||||
|
raise NotImplementedError
|
||||||
|
if r is None:
|
||||||
|
r = comparison
|
||||||
|
else:
|
||||||
|
r = r & comparison
|
||||||
|
test = comparator
|
||||||
|
return r
|
||||||
|
|
||||||
|
def visit_expr_name(self, node):
|
||||||
|
if node.id == "True":
|
||||||
|
return Constant(1)
|
||||||
|
if node.id == "False":
|
||||||
|
return Constant(0)
|
||||||
|
r = self.symdict[node.id]
|
||||||
|
if isinstance(r, ImplRegister):
|
||||||
|
r = r.storage
|
||||||
|
if isinstance(r, int):
|
||||||
|
r = Constant(r)
|
||||||
|
return r
|
||||||
|
|
||||||
|
def visit_expr_num(self, node):
|
||||||
|
return Constant(node.n)
|
52
migen/pytholite/reg.py
Normal file
52
migen/pytholite/reg.py
Normal file
|
@ -0,0 +1,52 @@
|
||||||
|
from operator import itemgetter
|
||||||
|
|
||||||
|
from migen.fhdl.structure import *
|
||||||
|
from migen.fhdl import visit as fhdl
|
||||||
|
|
||||||
|
class FinalizeError(Exception):
|
||||||
|
pass
|
||||||
|
|
||||||
|
class AbstractLoad:
|
||||||
|
def __init__(self, target, source):
|
||||||
|
self.target = target
|
||||||
|
self.source = source
|
||||||
|
|
||||||
|
def lower(self):
|
||||||
|
if not self.target.finalized:
|
||||||
|
raise FinalizeError
|
||||||
|
return self.target.sel.eq(self.target.source_encoding[self.source])
|
||||||
|
|
||||||
|
class LowerAbstractLoad(fhdl.NodeTransformer):
|
||||||
|
def visit_unknown(self, node):
|
||||||
|
if isinstance(node, AbstractLoad):
|
||||||
|
return node.lower()
|
||||||
|
else:
|
||||||
|
return node
|
||||||
|
|
||||||
|
class ImplRegister:
|
||||||
|
def __init__(self, name, nbits):
|
||||||
|
self.name = name
|
||||||
|
self.storage = Signal(BV(nbits), name=self.name)
|
||||||
|
self.source_encoding = {}
|
||||||
|
self.finalized = False
|
||||||
|
|
||||||
|
def load(self, source):
|
||||||
|
if source not in self.source_encoding:
|
||||||
|
self.source_encoding[source] = len(self.source_encoding) + 1
|
||||||
|
return AbstractLoad(self, source)
|
||||||
|
|
||||||
|
def finalize(self):
|
||||||
|
if self.finalized:
|
||||||
|
raise FinalizeError
|
||||||
|
self.sel = Signal(BV(bits_for(len(self.source_encoding) + 1)), name="pl_regsel_"+self.name)
|
||||||
|
self.finalized = True
|
||||||
|
|
||||||
|
def get_fragment(self):
|
||||||
|
if not self.finalized:
|
||||||
|
raise FinalizeError
|
||||||
|
# do nothing when sel == 0
|
||||||
|
items = sorted(self.source_encoding.items(), key=itemgetter(1))
|
||||||
|
cases = [(Constant(v, self.sel.bv),
|
||||||
|
self.storage.eq(k)) for k, v in items]
|
||||||
|
sync = [Case(self.sel, *cases)]
|
||||||
|
return Fragment(sync=sync)
|
Loading…
Reference in a new issue