diff --git a/litex/gen/fhdl/structure.py b/litex/gen/fhdl/structure.py index 652f26685..c4afc433f 100644 --- a/litex/gen/fhdl/structure.py +++ b/litex/gen/fhdl/structure.py @@ -755,7 +755,11 @@ class _Fragment: self.clock_domains += other.clock_domains return self + class Display(_Statement): def __init__(self, s, *args): self.s = s self.args = args + +class Finish(_Statement): + pass diff --git a/litex/gen/fhdl/verilog.py b/litex/gen/fhdl/verilog.py index 33873154f..8ce67bc0e 100644 --- a/litex/gen/fhdl/verilog.py +++ b/litex/gen/fhdl/verilog.py @@ -115,9 +115,11 @@ def _printexpr(ns, node): (_AT_BLOCKING, _AT_NONBLOCKING, _AT_SIGNAL) = range(3) -def _printnode(ns, at, level, node): - if isinstance(node, Display): - s = "\"" + node.s + "\\r\"" +def _printnode(ns, at, level, node, target_filter=None): + if target_filter is not None and target_filter not in list_targets(node): + return "" + elif isinstance(node, Display): + s = "\"" + node.s + "\"" for arg in node.args: s += ", " if isinstance(arg, Signal): @@ -125,6 +127,8 @@ def _printnode(ns, at, level, node): else: s += str(arg) return "\t"*level + "$display(" + s + ");\n" + elif isinstance(node, Finish): + return "\t"*level + "$finish;\n" elif isinstance(node, _Assign): if at == _AT_BLOCKING: assignment = " = " @@ -136,13 +140,13 @@ def _printnode(ns, at, level, node): assignment = " <= " return "\t"*level + _printexpr(ns, node.l)[0] + assignment + _printexpr(ns, node.r)[0] + ";\n" elif isinstance(node, collections.Iterable): - return "".join(list(map(partial(_printnode, ns, at, level), node))) + return "".join(_printnode(ns, at, level, n, target_filter) for n in node) elif isinstance(node, If): r = "\t"*level + "if (" + _printexpr(ns, node.cond)[0] + ") begin\n" - r += _printnode(ns, at, level + 1, node.t) + r += _printnode(ns, at, level + 1, node.t, target_filter) if node.f: r += "\t"*level + "end else begin\n" - r += _printnode(ns, at, level + 1, node.f) + r += _printnode(ns, at, level + 1, node.f, target_filter) r += "\t"*level + "end\n" return r elif isinstance(node, Case): @@ -152,11 +156,11 @@ def _printnode(ns, at, level, node): css = sorted(css, key=lambda x: x[0].value) for choice, statements in css: r += "\t"*(level + 1) + _printexpr(ns, choice)[0] + ": begin\n" - r += _printnode(ns, at, level + 2, statements) + r += _printnode(ns, at, level + 2, statements, target_filter) r += "\t"*(level + 1) + "end\n" if "default" in node.cases: r += "\t"*(level + 1) + "default: begin\n" - r += _printnode(ns, at, level + 2, node.cases["default"]) + r += _printnode(ns, at, level + 2, node.cases["default"], target_filter) r += "\t"*(level + 1) + "end\n" r += "\t"*level + "endcase\n" return r @@ -238,32 +242,39 @@ def _printheader(f, ios, name, ns, attr_translate, return r -def _printcomb(f, ns, - display_run, - dummy_signal, - blocking_assign): +def _printcomb_simulation(f, ns, + display_run, + dummy_signal, + blocking_assign): r = "" if f.comb: if dummy_signal: - explanation = """ -// Adding a dummy event (using a dummy signal 'dummy_s') to get the simulator -// to run the combinatorial process once at the beginning. -""" + # Generate a dummy event to get the simulator + # to run the combinatorial process once at the beginning. syn_off = "// synthesis translate_off\n" syn_on = "// synthesis translate_on\n" dummy_s = Signal(name_override="dummy_s") - r += explanation r += syn_off r += "reg " + _printsig(ns, dummy_s) + ";\n" r += "initial " + ns.get_name(dummy_s) + " <= 1'd0;\n" r += syn_on - r += "\n" + + + from collections import defaultdict + + target_stmt_map = defaultdict(list) + + for statement in flat_iteration(f.comb): + targets = list_targets(statement) + for t in targets: + target_stmt_map[t].append(statement) groups = group_by_targets(f.comb) - for n, g in enumerate(groups): - if len(g[1]) == 1 and isinstance(g[1][0], _Assign): - r += "assign " + _printnode(ns, _AT_BLOCKING, 0, g[1][0]) + for n, (t, stmts) in enumerate(target_stmt_map.items()): + assert isinstance(t, Signal) + if len(stmts) == 1 and isinstance(stmts[0], _Assign): + r += "assign " + _printnode(ns, _AT_BLOCKING, 0, stmts[0]) else: if dummy_signal: dummy_d = Signal(name_override="dummy_d") @@ -274,6 +285,31 @@ def _printcomb(f, ns, r += "always @(*) begin\n" if display_run: r += "\t$display(\"Running comb block #" + str(n) + "\");\n" + if blocking_assign: + r += "\t" + ns.get_name(t) + " = " + _printexpr(ns, t.reset)[0] + ";\n" + r += _printnode(ns, _AT_BLOCKING, 1, stmts, t) + else: + r += "\t" + ns.get_name(t) + " <= " + _printexpr(ns, t.reset)[0] + ";\n" + r += _printnode(ns, _AT_NONBLOCKING, 1, stmts, t) + if dummy_signal: + r += syn_off + r += "\t" + ns.get_name(dummy_d) + " = " + ns.get_name(dummy_s) + ";\n" + r += syn_on + r += "end\n" + r += "\n" + return r + + +def _printcomb_regular(f, ns, blocking_assign): + r = "" + if f.comb: + groups = group_by_targets(f.comb) + + for n, g in enumerate(groups): + if len(g[1]) == 1 and isinstance(g[1][0], _Assign): + r += "assign " + _printnode(ns, _AT_BLOCKING, 0, g[1][0]) + else: + r += "always @(*) begin\n" if blocking_assign: for t in g[0]: r += "\t" + ns.get_name(t) + " = " + _printexpr(ns, t.reset)[0] + ";\n" @@ -282,10 +318,6 @@ def _printcomb(f, ns, for t in g[0]: r += "\t" + ns.get_name(t) + " <= " + _printexpr(ns, t.reset)[0] + ";\n" r += _printnode(ns, _AT_NONBLOCKING, 1, g[1]) - if dummy_signal: - r += syn_off - r += "\t" + ns.get_name(dummy_d) + " <= " + ns.get_name(dummy_s) + ";\n" - r += syn_on r += "end\n" r += "\n" return r @@ -368,7 +400,11 @@ def convert(f, ios=None, name="top", src = "/* Machine-generated using LiteX gen */\n" src += _printheader(f, ios, name, ns, attr_translate, reg_initialization=reg_initialization) - src += _printcomb(f, ns, + if regular_comb: + src += _printcomb_regular(f, ns, + blocking_assign=blocking_assign) + else: + src += _printcomb_simulation(f, ns, display_run=display_run, dummy_signal=dummy_signal, blocking_assign=blocking_assign)