diff --git a/litex/gen/fhdl/verilog.py b/litex/gen/fhdl/verilog.py index bc03183ed..d6b9536bc 100644 --- a/litex/gen/fhdl/verilog.py +++ b/litex/gen/fhdl/verilog.py @@ -15,18 +15,18 @@ import time import datetime - -from functools import partial -from operator import itemgetter import collections -from migen.fhdl.structure import * -from migen.fhdl.structure import _Operator, _Slice, _Assign, _Fragment -from migen.fhdl.tools import * +from operator import itemgetter + +from migen.fhdl.structure import * +from migen.fhdl.structure import _Operator, _Slice, _Assign, _Fragment +from migen.fhdl.tools import * from migen.fhdl.conv_output import ConvOutput -from migen.fhdl.specials import Memory +from migen.fhdl.specials import Memory from litex.gen.fhdl.namer import build_namespace + from litex.build.tools import get_litex_git_revision # ------------------------------------------------------------------------------------------------ # @@ -35,7 +35,7 @@ from litex.build.tools import get_litex_git_revision _tab = " "*4 -def _print_banner(filename, device): +def _generate_banner(filename, device): return """\ // ----------------------------------------------------------------------------- // Auto-Generated by: __ _ __ _ __ @@ -57,7 +57,7 @@ def _print_banner(filename, device): date = datetime.datetime.fromtimestamp(time.time()).strftime("%Y-%m-%d %H:%M:%S") ) -def _print_trailer(): +def _generate_trailer(): return """ // ----------------------------------------------------------------------------- // Auto-Generated by LiteX on {date}. @@ -66,7 +66,7 @@ def _print_trailer(): date=datetime.datetime.fromtimestamp(time.time()).strftime("%Y-%m-%d %H:%M:%S") ) -def _print_separator(msg=""): +def _generate_separator(msg=""): r = "\n" r += "//" + "-"*78 + "\n" r += f"// {msg}\n" @@ -78,7 +78,7 @@ def _print_separator(msg=""): # TIMESCALE # # ------------------------------------------------------------------------------------------------ # -def _print_timescale(time_unit="1ns", time_precision="1ps"): +def _generate_timescale(time_unit="1ns", time_precision="1ps"): r = f"`timescale {time_unit} / {time_precision}\n" return r @@ -145,7 +145,7 @@ _ieee_1800_2017_verilog_reserved_keywords = { # Print Constant ----------------------------------------------------------------------------------- -def _print_constant(node): +def _generate_constant(node): return "{sign}{bits}'d{value}".format( sign = "" if node.value >= 0 else "-", bits = str(node.nbits), @@ -154,7 +154,7 @@ def _print_constant(node): # Print Signal ------------------------------------------------------------------------------------- -def _print_signal(ns, s): +def _generate_signal(ns, s): length = 8 vector = f"[{str(len(s)-1)}:0] " vector = " "*(length-len(vector)) + vector @@ -168,7 +168,7 @@ def _print_signal(ns, s): (UNARY, BINARY, TERNARY) = (1, 2, 3) -def _print_operator(ns, node): +def _generate_operator(ns, node): operator = node.op operands = node.operands arity = len(operands) @@ -179,7 +179,7 @@ def _print_operator(ns, node): # Unary Operator. if arity == UNARY: - r1, s1 = _print_expression(ns, operands[0]) + r1, s1 = _generate_expression(ns, operands[0]) # Negation Operator. if operator == "-": # Negate and convert to signed if not already. @@ -192,8 +192,8 @@ def _print_operator(ns, node): # Binary Operator. if arity == BINARY: - r1, s1 = _print_expression(ns, operands[0]) - r2, s2 = _print_expression(ns, operands[1]) + r1, s1 = _generate_expression(ns, operands[0]) + r2, s2 = _generate_expression(ns, operands[1]) # Convert all expressions to signed when at least one is signed. if operator not in ["<<<", ">>>"]: if s2 and not s1: @@ -206,9 +206,9 @@ def _print_operator(ns, node): # Ternary Operator. if arity == TERNARY: assert operator == "m" - r1, s1 = _print_expression(ns, operands[0]) - r2, s2 = _print_expression(ns, operands[1]) - r3, s3 = _print_expression(ns, operands[2]) + r1, s1 = _generate_expression(ns, operands[0]) + r2, s2 = _generate_expression(ns, operands[1]) + r3, s3 = _generate_expression(ns, operands[2]) # Convert all expressions to signed when at least one is signed. if s2 and not s3: r3 = to_signed(r3) @@ -221,33 +221,33 @@ def _print_operator(ns, node): # Print Slice -------------------------------------------------------------------------------------- -def _print_slice(ns, node): +def _generate_slice(ns, node): assert (node.stop - node.start) >= 1 if (isinstance(node.value, Signal) and len(node.value) == 1): assert node.start == 0 sr = "" # Avoid slicing 1-bit Signals. else: sr = f"[{node.stop-1}:{node.start}]" if (node.stop - node.start) > 1 else f"[{node.start}]" - r, s = _print_expression(ns, node.value) + r, s = _generate_expression(ns, node.value) return r + sr, s # Print Cat ---------------------------------------------------------------------------------------- -def _print_cat(ns, node): - l = [_print_expression(ns, v)[0] for v in reversed(node.l)] +def _generate_cat(ns, node): + l = [_generate_expression(ns, v)[0] for v in reversed(node.l)] return "{" + ", ".join(l) + "}", False # Print Replicate ---------------------------------------------------------------------------------- -def _print_replicate(ns, node): - return "{" + str(node.n) + "{" + _print_expression(ns, node.v)[0] + "}}", False +def _generate_replicate(ns, node): + return "{" + str(node.n) + "{" + _generate_expression(ns, node.v)[0] + "}}", False # Print Expression --------------------------------------------------------------------------------- -def _print_expression(ns, node): +def _generate_expression(ns, node): # Constant. if isinstance(node, Constant): - return _print_constant(node) + return _generate_constant(node) # Signal. elif isinstance(node, Signal): @@ -255,19 +255,19 @@ def _print_expression(ns, node): # Operator. elif isinstance(node, _Operator): - return _print_operator(ns, node) + return _generate_operator(ns, node) # Slice. elif isinstance(node, _Slice): - return _print_slice(ns, node) + return _generate_slice(ns, node) # Cat. elif isinstance(node, Cat): - return _print_cat(ns, node) + return _generate_cat(ns, node) # Replicate. elif isinstance(node, Replicate): - return _print_replicate(ns, node) + return _generate_replicate(ns, node) # Unknown. else: @@ -279,7 +279,7 @@ def _print_expression(ns, node): (_AT_BLOCKING, _AT_NONBLOCKING, _AT_SIGNAL) = range(3) -def _print_node(ns, at, level, node, target_filter=None): +def _generate_node(ns, at, level, node, target_filter=None): if target_filter is not None and target_filter not in list_targets(node): return "" @@ -293,35 +293,35 @@ def _print_node(ns, at, level, node, target_filter=None): assignment = " = " else: assignment = " <= " - return _tab*level + _print_expression(ns, node.l)[0] + assignment + _print_expression(ns, node.r)[0] + ";\n" + return _tab*level + _generate_expression(ns, node.l)[0] + assignment + _generate_expression(ns, node.r)[0] + ";\n" # Iterable. elif isinstance(node, collections.abc.Iterable): - return "".join(_print_node(ns, at, level, n, target_filter) for n in node) + return "".join(_generate_node(ns, at, level, n, target_filter) for n in node) # If. elif isinstance(node, If): - r = _tab*level + "if (" + _print_expression(ns, node.cond)[0] + ") begin\n" - r += _print_node(ns, at, level + 1, node.t, target_filter) + r = _tab*level + "if (" + _generate_expression(ns, node.cond)[0] + ") begin\n" + r += _generate_node(ns, at, level + 1, node.t, target_filter) if node.f: r += _tab*level + "end else begin\n" - r += _print_node(ns, at, level + 1, node.f, target_filter) + r += _generate_node(ns, at, level + 1, node.f, target_filter) r += _tab*level + "end\n" return r # Case. elif isinstance(node, Case): if node.cases: - r = _tab*level + "case (" + _print_expression(ns, node.test)[0] + ")\n" + r = _tab*level + "case (" + _generate_expression(ns, node.test)[0] + ")\n" css = [(k, v) for k, v in node.cases.items() if isinstance(k, Constant)] css = sorted(css, key=lambda x: x[0].value) for choice, statements in css: - r += _tab*(level + 1) + _print_expression(ns, choice)[0] + ": begin\n" - r += _print_node(ns, at, level + 2, statements, target_filter) + r += _tab*(level + 1) + _generate_expression(ns, choice)[0] + ": begin\n" + r += _generate_node(ns, at, level + 2, statements, target_filter) r += _tab*(level + 1) + "end\n" if "default" in node.cases: r += _tab*(level + 1) + "default: begin\n" - r += _print_node(ns, at, level + 2, node.cases["default"], target_filter) + r += _generate_node(ns, at, level + 2, node.cases["default"], target_filter) r += _tab*(level + 1) + "end\n" r += _tab*level + "endcase\n" return r @@ -351,7 +351,7 @@ def _print_node(ns, at, level, node, target_filter=None): # ATTRIBUTES # # ------------------------------------------------------------------------------------------------ # -def _print_attribute(attr, attr_translate): +def _generate_attribute(attr, attr_translate): r = "" first = True for attr in sorted(attr, key=lambda x: ("", x) if isinstance(x, str) else x): @@ -389,7 +389,7 @@ def _list_comb_wires(f): r |= g[0] return r -def _print_module(f, ios, name, ns, attr_translate): +def _generate_module(f, ios, name, ns, attr_translate): sigs = list_signals(f) | list_special_ios(f, ins=True, outs=True, inouts=True) special_outs = list_special_ios(f, ins=False, outs=True, inouts=True) inouts = list_special_ios(f, ins=False, outs=False, inouts=True) @@ -402,7 +402,7 @@ def _print_module(f, ios, name, ns, attr_translate): if not firstp: r += ",\n" firstp = False - attr = _print_attribute(sig.attr, attr_translate) + attr = _generate_attribute(sig.attr, attr_translate) if attr: r += _tab + attr sig.type = "wire" @@ -410,22 +410,22 @@ def _print_module(f, ios, name, ns, attr_translate): sig.port = True if sig in inouts: sig.direction = "inout" - r += _tab + "inout wire " + _print_signal(ns, sig) + r += _tab + "inout wire " + _generate_signal(ns, sig) elif sig in targets: sig.direction = "output" if sig in wires: - r += _tab + "output wire " + _print_signal(ns, sig) + r += _tab + "output wire " + _generate_signal(ns, sig) else: sig.type = "reg" - r += _tab + "output reg " + _print_signal(ns, sig) + r += _tab + "output reg " + _generate_signal(ns, sig) else: sig.direction = "input" - r += _tab + "input wire " + _print_signal(ns, sig) + r += _tab + "input wire " + _generate_signal(ns, sig) r += "\n);\n\n" return r -def _print_signals(f, ios, name, ns, attr_translate, regs_init): +def _generate_signals(f, ios, name, ns, attr_translate, regs_init): sigs = list_signals(f) | list_special_ios(f, ins=True, outs=True, inouts=True) special_outs = list_special_ios(f, ins=False, outs=True, inouts=True) inouts = list_special_ios(f, ins=False, outs=False, inouts=True) @@ -434,13 +434,13 @@ def _print_signals(f, ios, name, ns, attr_translate, regs_init): r = "" for sig in sorted(sigs - ios, key=lambda x: ns.get_name(x)): - r += _print_attribute(sig.attr, attr_translate) + r += _generate_attribute(sig.attr, attr_translate) if sig in wires: - r += "wire " + _print_signal(ns, sig) + ";\n" + r += "wire " + _generate_signal(ns, sig) + ";\n" else: - r += "reg " + _print_signal(ns, sig) + r += "reg " + _generate_signal(ns, sig) if regs_init: - r += " = " + _print_expression(ns, sig.reset)[0] + r += " = " + _generate_expression(ns, sig.reset)[0] r += ";\n" return r @@ -448,12 +448,10 @@ def _print_signals(f, ios, name, ns, attr_translate, regs_init): # COMBINATORIAL LOGIC # # ------------------------------------------------------------------------------------------------ # -def _print_combinatorial_logic_sim(f, ns): +def _generate_combinatorial_logic_sim(f, ns): r = "" if f.comb: - from collections import defaultdict - - target_stmt_map = defaultdict(list) + target_stmt_map = collections.defaultdict(list) for statement in flat_iteration(f.comb): targets = list_targets(statement) @@ -465,28 +463,28 @@ def _print_combinatorial_logic_sim(f, ns): for n, (t, stmts) in enumerate(target_stmt_map.items()): assert isinstance(t, Signal) if _use_wire(stmts): - r += "assign " + _print_node(ns, _AT_BLOCKING, 0, stmts[0]) + r += "assign " + _generate_node(ns, _AT_BLOCKING, 0, stmts[0]) else: r += "always @(*) begin\n" - r += _tab + ns.get_name(t) + " <= " + _print_expression(ns, t.reset)[0] + ";\n" - r += _print_node(ns, _AT_NONBLOCKING, 1, stmts, t) + r += _tab + ns.get_name(t) + " <= " + _generate_expression(ns, t.reset)[0] + ";\n" + r += _generate_node(ns, _AT_NONBLOCKING, 1, stmts, t) r += "end\n" r += "\n" return r -def _print_combinatorial_logic_synth(f, ns): +def _generate_combinatorial_logic_synth(f, ns): r = "" if f.comb: groups = group_by_targets(f.comb) for n, g in enumerate(groups): if _use_wire(g[1]): - r += "assign " + _print_node(ns, _AT_BLOCKING, 0, g[1][0]) + r += "assign " + _generate_node(ns, _AT_BLOCKING, 0, g[1][0]) else: r += "always @(*) begin\n" for t in sorted(g[0], key=lambda x: ns.get_name(x)): - r += _tab + ns.get_name(t) + " <= " + _print_expression(ns, t.reset)[0] + ";\n" - r += _print_node(ns, _AT_NONBLOCKING, 1, g[1]) + r += _tab + ns.get_name(t) + " <= " + _generate_expression(ns, t.reset)[0] + ";\n" + r += _generate_node(ns, _AT_NONBLOCKING, 1, g[1]) r += "end\n" r += "\n" return r @@ -495,11 +493,11 @@ def _print_combinatorial_logic_synth(f, ns): # SYNCHRONOUS LOGIC # # ------------------------------------------------------------------------------------------------ # -def _print_synchronous_logic(f, ns): +def _generate_synchronous_logic(f, ns): r = "" for k, v in sorted(f.sync.items(), key=itemgetter(0)): r += "always @(posedge " + ns.get_name(f.clock_domains[k].clk) + ") begin\n" - r += _print_node(ns, _AT_SIGNAL, 1, v) + r += _generate_node(ns, _AT_SIGNAL, 1, v) r += "end\n\n" return r @@ -507,11 +505,11 @@ def _print_synchronous_logic(f, ns): # SPECIALS # # ------------------------------------------------------------------------------------------------ # -def _print_specials(name, overrides, specials, namespace, add_data_file, attr_translate): +def _generate_specials(name, overrides, specials, namespace, add_data_file, attr_translate): r = "" for special in sorted(specials, key=lambda x: x.duid): if hasattr(special, "attr"): - r += _print_attribute(special.attr, attr_translate) + r += _generate_attribute(special.attr, attr_translate) # Replace Migen Memory's emit_verilog with LiteX's implementation. if isinstance(special, Memory): from litex.gen.fhdl.memory import memory_emit_verilog @@ -600,7 +598,8 @@ def convert(f, ios=set(), name="top", platform=None, signals = ( list_signals(f) | list_special_ios(f, ins=True, outs=True, inouts=True) | - ios), + ios + ), reserved_keywords = _ieee_1800_2017_verilog_reserved_keywords ) ns.clock_domains = f.clock_domains @@ -608,40 +607,41 @@ def convert(f, ios=set(), name="top", platform=None, # Build Verilog. # -------------- verilog = "" + # Banner. - verilog += _print_banner( + verilog += _generate_banner( filename = name, device = getattr(platform, "device", "Unknown") ) # Timescale. - verilog += _print_timescale( + verilog += _generate_timescale( time_unit = time_unit, time_precision = time_precision ) # Module Definition. - verilog += _print_separator("Module") - verilog += _print_module(f, ios, name, ns, attr_translate) + verilog += _generate_separator("Module") + verilog += _generate_module(f, ios, name, ns, attr_translate) # Module Signals. - verilog += _print_separator("Signals") - verilog += _print_signals(f, ios, name, ns, attr_translate, regs_init) + verilog += _generate_separator("Signals") + verilog += _generate_signals(f, ios, name, ns, attr_translate, regs_init) # Combinatorial Logic. - verilog += _print_separator("Combinatorial Logic") + verilog += _generate_separator("Combinatorial Logic") if regular_comb: - verilog += _print_combinatorial_logic_synth(f, ns) + verilog += _generate_combinatorial_logic_synth(f, ns) else: - verilog += _print_combinatorial_logic_sim(f, ns) + verilog += _generate_combinatorial_logic_sim(f, ns) # Synchronous Logic. - verilog += _print_separator("Synchronous Logic") - verilog += _print_synchronous_logic(f, ns) + verilog += _generate_separator("Synchronous Logic") + verilog += _generate_synchronous_logic(f, ns) # Specials - verilog += _print_separator("Specialized Logic") - verilog += _print_specials( + verilog += _generate_separator("Specialized Logic") + verilog += _generate_specials( name = name, overrides = special_overrides, specials = f.specials - lowered_specials, @@ -654,7 +654,7 @@ def convert(f, ios=set(), name="top", platform=None, verilog += "endmodule\n" # Trailer. - verilog += _print_trailer() + verilog += _generate_trailer() r.set_main_source(verilog) r.ns = ns