pytholite: support generator arguments

This commit is contained in:
Sebastien Bourdeauducq 2013-07-03 16:35:07 +02:00
parent 04efee7847
commit 0aa58f5dcf
1 changed files with 69 additions and 4 deletions

View File

@ -1,5 +1,6 @@
import inspect import inspect
import ast import ast
from collections import OrderedDict
from migen.fhdl.structure import * from migen.fhdl.structure import *
from migen.fhdl.visit import TransformModule from migen.fhdl.visit import TransformModule
@ -17,6 +18,66 @@ def _is_name_used(node, name):
return True return True
return False return False
def _make_function_args_dict(undefined, symdict, args, defaults):
d = OrderedDict()
for argument in args:
d[argument.arg] = undefined
for default, argname in zip(defaults, reversed(list(d.keys()))):
default_val = eval_ast(default, symdict)
d[argname] = default_val
return d
def _process_function_args(symdict, function_def, args, kwargs):
defargs = function_def.args
undefined = object()
ad_positional = _make_function_args_dict(undefined, symdict, defargs.args, defargs.defaults)
vararg_name = defargs.vararg
kwarg_name = defargs.kwarg
ad_kwonly = _make_function_args_dict(undefined, symdict, defargs.kwonlyargs, defargs.kw_defaults)
# grab argument values
current_argvalue = iter(args)
try:
for argname in ad_positional.keys():
ad_positional[argname] = next(current_argvalue)
except StopIteration:
pass
vararg = tuple(current_argvalue)
kwarg = OrderedDict()
for k, v in kwarg.items():
if k in ad_positional:
ad_positional[k] = v
elif k in ad_kwonly:
ad_kwonly[k] = v
else:
kwarg[k] = v
# check
undefined_pos = [k for k, v in ad_positional.items() if v is undefined]
if undefined_pos:
formatted = " and ".join("'" + k + "'" for k in undefined_pos)
raise TypeError("Missing required positional arguments: " + formatted)
if vararg and vararg_name is None:
raise TypeError("Function takes {} positional arguments but {} were given".format(len(ad_positional),
len(ad_positional) + len(vararg)))
ad_kwonly = [k for k, v in ad_positional.items() if v is undefined]
if undefined_pos:
formatted = " and ".join("'" + k + "'" for k in undefined_pos)
raise TypeError("Missing required keyword-only arguments: " + formatted)
if kwarg and kwarg_name is None:
formatted = " and ".join("'" + k + "'" for k in kwarg.keys())
raise TypeError("Got unexpected keyword arguments: " + formatted)
# update symdict
symdict.update(ad_positional)
if vararg_name is not None:
symdict[vararg_name] = vararg
symdict.update(ad_kwonly)
if kwarg_name is not None:
symdict[kwarg_name] = kwarg
class _Compiler: class _Compiler:
def __init__(self, ioo, symdict, registers): def __init__(self, ioo, symdict, registers):
self.ioo = ioo self.ioo = ioo
@ -24,11 +85,13 @@ class _Compiler:
self.registers = registers self.registers = registers
self.ec = ExprCompiler(self.symdict) self.ec = ExprCompiler(self.symdict)
def visit_top(self, node): def visit_top(self, node, args, kwargs):
if isinstance(node, ast.Module) \ if isinstance(node, ast.Module) \
and len(node.body) == 1 \ and len(node.body) == 1 \
and isinstance(node.body[0], ast.FunctionDef): and isinstance(node.body[0], ast.FunctionDef):
states, exit_states = self.visit_block(node.body[0].body) function_def = node.body[0]
_process_function_args(self.symdict, function_def, args, kwargs)
states, exit_states = self.visit_block(function_def.body)
return states return states
else: else:
raise NotImplementedError raise NotImplementedError
@ -220,8 +283,10 @@ class _Compiler:
raise NotImplementedError raise NotImplementedError
class Pytholite(UnifiedIOObject): class Pytholite(UnifiedIOObject):
def __init__(self, func): def __init__(self, func, *args, **kwargs):
self.func = func self.func = func
self.args = args
self.kwargs = kwargs
def do_finalize(self): def do_finalize(self):
UnifiedIOObject.do_finalize(self) UnifiedIOObject.do_finalize(self)
@ -240,7 +305,7 @@ class Pytholite(UnifiedIOObject):
symdict = self.func.__globals__.copy() symdict = self.func.__globals__.copy()
registers = [] registers = []
states = _Compiler(self, symdict, registers).visit_top(tree) states = _Compiler(self, symdict, registers).visit_top(tree, self.args, self.kwargs)
for register in registers: for register in registers:
if register.source_encoding: if register.source_encoding: