gen/fhdl/verilog: Rename _print_xy to _generate_xy and cleanup imports.

This commit is contained in:
Florent Kermarrec 2023-11-03 10:14:38 +01:00
parent 71ae8fe828
commit b60bd92533

View file

@ -15,18 +15,18 @@
import time
import datetime
from functools import partial
from operator import itemgetter
import collections
from migen.fhdl.structure import *
from migen.fhdl.structure import _Operator, _Slice, _Assign, _Fragment
from migen.fhdl.tools import *
from operator import itemgetter
from migen.fhdl.structure import *
from migen.fhdl.structure import _Operator, _Slice, _Assign, _Fragment
from migen.fhdl.tools import *
from migen.fhdl.conv_output import ConvOutput
from migen.fhdl.specials import Memory
from migen.fhdl.specials import Memory
from litex.gen.fhdl.namer import build_namespace
from litex.build.tools import get_litex_git_revision
# ------------------------------------------------------------------------------------------------ #
@ -35,7 +35,7 @@ from litex.build.tools import get_litex_git_revision
_tab = " "*4
def _print_banner(filename, device):
def _generate_banner(filename, device):
return """\
// -----------------------------------------------------------------------------
// Auto-Generated by: __ _ __ _ __
@ -57,7 +57,7 @@ def _print_banner(filename, device):
date = datetime.datetime.fromtimestamp(time.time()).strftime("%Y-%m-%d %H:%M:%S")
)
def _print_trailer():
def _generate_trailer():
return """
// -----------------------------------------------------------------------------
// Auto-Generated by LiteX on {date}.
@ -66,7 +66,7 @@ def _print_trailer():
date=datetime.datetime.fromtimestamp(time.time()).strftime("%Y-%m-%d %H:%M:%S")
)
def _print_separator(msg=""):
def _generate_separator(msg=""):
r = "\n"
r += "//" + "-"*78 + "\n"
r += f"// {msg}\n"
@ -78,7 +78,7 @@ def _print_separator(msg=""):
# TIMESCALE #
# ------------------------------------------------------------------------------------------------ #
def _print_timescale(time_unit="1ns", time_precision="1ps"):
def _generate_timescale(time_unit="1ns", time_precision="1ps"):
r = f"`timescale {time_unit} / {time_precision}\n"
return r
@ -145,7 +145,7 @@ _ieee_1800_2017_verilog_reserved_keywords = {
# Print Constant -----------------------------------------------------------------------------------
def _print_constant(node):
def _generate_constant(node):
return "{sign}{bits}'d{value}".format(
sign = "" if node.value >= 0 else "-",
bits = str(node.nbits),
@ -154,7 +154,7 @@ def _print_constant(node):
# Print Signal -------------------------------------------------------------------------------------
def _print_signal(ns, s):
def _generate_signal(ns, s):
length = 8
vector = f"[{str(len(s)-1)}:0] "
vector = " "*(length-len(vector)) + vector
@ -168,7 +168,7 @@ def _print_signal(ns, s):
(UNARY, BINARY, TERNARY) = (1, 2, 3)
def _print_operator(ns, node):
def _generate_operator(ns, node):
operator = node.op
operands = node.operands
arity = len(operands)
@ -179,7 +179,7 @@ def _print_operator(ns, node):
# Unary Operator.
if arity == UNARY:
r1, s1 = _print_expression(ns, operands[0])
r1, s1 = _generate_expression(ns, operands[0])
# Negation Operator.
if operator == "-":
# Negate and convert to signed if not already.
@ -192,8 +192,8 @@ def _print_operator(ns, node):
# Binary Operator.
if arity == BINARY:
r1, s1 = _print_expression(ns, operands[0])
r2, s2 = _print_expression(ns, operands[1])
r1, s1 = _generate_expression(ns, operands[0])
r2, s2 = _generate_expression(ns, operands[1])
# Convert all expressions to signed when at least one is signed.
if operator not in ["<<<", ">>>"]:
if s2 and not s1:
@ -206,9 +206,9 @@ def _print_operator(ns, node):
# Ternary Operator.
if arity == TERNARY:
assert operator == "m"
r1, s1 = _print_expression(ns, operands[0])
r2, s2 = _print_expression(ns, operands[1])
r3, s3 = _print_expression(ns, operands[2])
r1, s1 = _generate_expression(ns, operands[0])
r2, s2 = _generate_expression(ns, operands[1])
r3, s3 = _generate_expression(ns, operands[2])
# Convert all expressions to signed when at least one is signed.
if s2 and not s3:
r3 = to_signed(r3)
@ -221,33 +221,33 @@ def _print_operator(ns, node):
# Print Slice --------------------------------------------------------------------------------------
def _print_slice(ns, node):
def _generate_slice(ns, node):
assert (node.stop - node.start) >= 1
if (isinstance(node.value, Signal) and len(node.value) == 1):
assert node.start == 0
sr = "" # Avoid slicing 1-bit Signals.
else:
sr = f"[{node.stop-1}:{node.start}]" if (node.stop - node.start) > 1 else f"[{node.start}]"
r, s = _print_expression(ns, node.value)
r, s = _generate_expression(ns, node.value)
return r + sr, s
# Print Cat ----------------------------------------------------------------------------------------
def _print_cat(ns, node):
l = [_print_expression(ns, v)[0] for v in reversed(node.l)]
def _generate_cat(ns, node):
l = [_generate_expression(ns, v)[0] for v in reversed(node.l)]
return "{" + ", ".join(l) + "}", False
# Print Replicate ----------------------------------------------------------------------------------
def _print_replicate(ns, node):
return "{" + str(node.n) + "{" + _print_expression(ns, node.v)[0] + "}}", False
def _generate_replicate(ns, node):
return "{" + str(node.n) + "{" + _generate_expression(ns, node.v)[0] + "}}", False
# Print Expression ---------------------------------------------------------------------------------
def _print_expression(ns, node):
def _generate_expression(ns, node):
# Constant.
if isinstance(node, Constant):
return _print_constant(node)
return _generate_constant(node)
# Signal.
elif isinstance(node, Signal):
@ -255,19 +255,19 @@ def _print_expression(ns, node):
# Operator.
elif isinstance(node, _Operator):
return _print_operator(ns, node)
return _generate_operator(ns, node)
# Slice.
elif isinstance(node, _Slice):
return _print_slice(ns, node)
return _generate_slice(ns, node)
# Cat.
elif isinstance(node, Cat):
return _print_cat(ns, node)
return _generate_cat(ns, node)
# Replicate.
elif isinstance(node, Replicate):
return _print_replicate(ns, node)
return _generate_replicate(ns, node)
# Unknown.
else:
@ -279,7 +279,7 @@ def _print_expression(ns, node):
(_AT_BLOCKING, _AT_NONBLOCKING, _AT_SIGNAL) = range(3)
def _print_node(ns, at, level, node, target_filter=None):
def _generate_node(ns, at, level, node, target_filter=None):
if target_filter is not None and target_filter not in list_targets(node):
return ""
@ -293,35 +293,35 @@ def _print_node(ns, at, level, node, target_filter=None):
assignment = " = "
else:
assignment = " <= "
return _tab*level + _print_expression(ns, node.l)[0] + assignment + _print_expression(ns, node.r)[0] + ";\n"
return _tab*level + _generate_expression(ns, node.l)[0] + assignment + _generate_expression(ns, node.r)[0] + ";\n"
# Iterable.
elif isinstance(node, collections.abc.Iterable):
return "".join(_print_node(ns, at, level, n, target_filter) for n in node)
return "".join(_generate_node(ns, at, level, n, target_filter) for n in node)
# If.
elif isinstance(node, If):
r = _tab*level + "if (" + _print_expression(ns, node.cond)[0] + ") begin\n"
r += _print_node(ns, at, level + 1, node.t, target_filter)
r = _tab*level + "if (" + _generate_expression(ns, node.cond)[0] + ") begin\n"
r += _generate_node(ns, at, level + 1, node.t, target_filter)
if node.f:
r += _tab*level + "end else begin\n"
r += _print_node(ns, at, level + 1, node.f, target_filter)
r += _generate_node(ns, at, level + 1, node.f, target_filter)
r += _tab*level + "end\n"
return r
# Case.
elif isinstance(node, Case):
if node.cases:
r = _tab*level + "case (" + _print_expression(ns, node.test)[0] + ")\n"
r = _tab*level + "case (" + _generate_expression(ns, node.test)[0] + ")\n"
css = [(k, v) for k, v in node.cases.items() if isinstance(k, Constant)]
css = sorted(css, key=lambda x: x[0].value)
for choice, statements in css:
r += _tab*(level + 1) + _print_expression(ns, choice)[0] + ": begin\n"
r += _print_node(ns, at, level + 2, statements, target_filter)
r += _tab*(level + 1) + _generate_expression(ns, choice)[0] + ": begin\n"
r += _generate_node(ns, at, level + 2, statements, target_filter)
r += _tab*(level + 1) + "end\n"
if "default" in node.cases:
r += _tab*(level + 1) + "default: begin\n"
r += _print_node(ns, at, level + 2, node.cases["default"], target_filter)
r += _generate_node(ns, at, level + 2, node.cases["default"], target_filter)
r += _tab*(level + 1) + "end\n"
r += _tab*level + "endcase\n"
return r
@ -351,7 +351,7 @@ def _print_node(ns, at, level, node, target_filter=None):
# ATTRIBUTES #
# ------------------------------------------------------------------------------------------------ #
def _print_attribute(attr, attr_translate):
def _generate_attribute(attr, attr_translate):
r = ""
first = True
for attr in sorted(attr, key=lambda x: ("", x) if isinstance(x, str) else x):
@ -389,7 +389,7 @@ def _list_comb_wires(f):
r |= g[0]
return r
def _print_module(f, ios, name, ns, attr_translate):
def _generate_module(f, ios, name, ns, attr_translate):
sigs = list_signals(f) | list_special_ios(f, ins=True, outs=True, inouts=True)
special_outs = list_special_ios(f, ins=False, outs=True, inouts=True)
inouts = list_special_ios(f, ins=False, outs=False, inouts=True)
@ -402,7 +402,7 @@ def _print_module(f, ios, name, ns, attr_translate):
if not firstp:
r += ",\n"
firstp = False
attr = _print_attribute(sig.attr, attr_translate)
attr = _generate_attribute(sig.attr, attr_translate)
if attr:
r += _tab + attr
sig.type = "wire"
@ -410,22 +410,22 @@ def _print_module(f, ios, name, ns, attr_translate):
sig.port = True
if sig in inouts:
sig.direction = "inout"
r += _tab + "inout wire " + _print_signal(ns, sig)
r += _tab + "inout wire " + _generate_signal(ns, sig)
elif sig in targets:
sig.direction = "output"
if sig in wires:
r += _tab + "output wire " + _print_signal(ns, sig)
r += _tab + "output wire " + _generate_signal(ns, sig)
else:
sig.type = "reg"
r += _tab + "output reg " + _print_signal(ns, sig)
r += _tab + "output reg " + _generate_signal(ns, sig)
else:
sig.direction = "input"
r += _tab + "input wire " + _print_signal(ns, sig)
r += _tab + "input wire " + _generate_signal(ns, sig)
r += "\n);\n\n"
return r
def _print_signals(f, ios, name, ns, attr_translate, regs_init):
def _generate_signals(f, ios, name, ns, attr_translate, regs_init):
sigs = list_signals(f) | list_special_ios(f, ins=True, outs=True, inouts=True)
special_outs = list_special_ios(f, ins=False, outs=True, inouts=True)
inouts = list_special_ios(f, ins=False, outs=False, inouts=True)
@ -434,13 +434,13 @@ def _print_signals(f, ios, name, ns, attr_translate, regs_init):
r = ""
for sig in sorted(sigs - ios, key=lambda x: ns.get_name(x)):
r += _print_attribute(sig.attr, attr_translate)
r += _generate_attribute(sig.attr, attr_translate)
if sig in wires:
r += "wire " + _print_signal(ns, sig) + ";\n"
r += "wire " + _generate_signal(ns, sig) + ";\n"
else:
r += "reg " + _print_signal(ns, sig)
r += "reg " + _generate_signal(ns, sig)
if regs_init:
r += " = " + _print_expression(ns, sig.reset)[0]
r += " = " + _generate_expression(ns, sig.reset)[0]
r += ";\n"
return r
@ -448,12 +448,10 @@ def _print_signals(f, ios, name, ns, attr_translate, regs_init):
# COMBINATORIAL LOGIC #
# ------------------------------------------------------------------------------------------------ #
def _print_combinatorial_logic_sim(f, ns):
def _generate_combinatorial_logic_sim(f, ns):
r = ""
if f.comb:
from collections import defaultdict
target_stmt_map = defaultdict(list)
target_stmt_map = collections.defaultdict(list)
for statement in flat_iteration(f.comb):
targets = list_targets(statement)
@ -465,28 +463,28 @@ def _print_combinatorial_logic_sim(f, ns):
for n, (t, stmts) in enumerate(target_stmt_map.items()):
assert isinstance(t, Signal)
if _use_wire(stmts):
r += "assign " + _print_node(ns, _AT_BLOCKING, 0, stmts[0])
r += "assign " + _generate_node(ns, _AT_BLOCKING, 0, stmts[0])
else:
r += "always @(*) begin\n"
r += _tab + ns.get_name(t) + " <= " + _print_expression(ns, t.reset)[0] + ";\n"
r += _print_node(ns, _AT_NONBLOCKING, 1, stmts, t)
r += _tab + ns.get_name(t) + " <= " + _generate_expression(ns, t.reset)[0] + ";\n"
r += _generate_node(ns, _AT_NONBLOCKING, 1, stmts, t)
r += "end\n"
r += "\n"
return r
def _print_combinatorial_logic_synth(f, ns):
def _generate_combinatorial_logic_synth(f, ns):
r = ""
if f.comb:
groups = group_by_targets(f.comb)
for n, g in enumerate(groups):
if _use_wire(g[1]):
r += "assign " + _print_node(ns, _AT_BLOCKING, 0, g[1][0])
r += "assign " + _generate_node(ns, _AT_BLOCKING, 0, g[1][0])
else:
r += "always @(*) begin\n"
for t in sorted(g[0], key=lambda x: ns.get_name(x)):
r += _tab + ns.get_name(t) + " <= " + _print_expression(ns, t.reset)[0] + ";\n"
r += _print_node(ns, _AT_NONBLOCKING, 1, g[1])
r += _tab + ns.get_name(t) + " <= " + _generate_expression(ns, t.reset)[0] + ";\n"
r += _generate_node(ns, _AT_NONBLOCKING, 1, g[1])
r += "end\n"
r += "\n"
return r
@ -495,11 +493,11 @@ def _print_combinatorial_logic_synth(f, ns):
# SYNCHRONOUS LOGIC #
# ------------------------------------------------------------------------------------------------ #
def _print_synchronous_logic(f, ns):
def _generate_synchronous_logic(f, ns):
r = ""
for k, v in sorted(f.sync.items(), key=itemgetter(0)):
r += "always @(posedge " + ns.get_name(f.clock_domains[k].clk) + ") begin\n"
r += _print_node(ns, _AT_SIGNAL, 1, v)
r += _generate_node(ns, _AT_SIGNAL, 1, v)
r += "end\n\n"
return r
@ -507,11 +505,11 @@ def _print_synchronous_logic(f, ns):
# SPECIALS #
# ------------------------------------------------------------------------------------------------ #
def _print_specials(name, overrides, specials, namespace, add_data_file, attr_translate):
def _generate_specials(name, overrides, specials, namespace, add_data_file, attr_translate):
r = ""
for special in sorted(specials, key=lambda x: x.duid):
if hasattr(special, "attr"):
r += _print_attribute(special.attr, attr_translate)
r += _generate_attribute(special.attr, attr_translate)
# Replace Migen Memory's emit_verilog with LiteX's implementation.
if isinstance(special, Memory):
from litex.gen.fhdl.memory import memory_emit_verilog
@ -600,7 +598,8 @@ def convert(f, ios=set(), name="top", platform=None,
signals = (
list_signals(f) |
list_special_ios(f, ins=True, outs=True, inouts=True) |
ios),
ios
),
reserved_keywords = _ieee_1800_2017_verilog_reserved_keywords
)
ns.clock_domains = f.clock_domains
@ -608,40 +607,41 @@ def convert(f, ios=set(), name="top", platform=None,
# Build Verilog.
# --------------
verilog = ""
# Banner.
verilog += _print_banner(
verilog += _generate_banner(
filename = name,
device = getattr(platform, "device", "Unknown")
)
# Timescale.
verilog += _print_timescale(
verilog += _generate_timescale(
time_unit = time_unit,
time_precision = time_precision
)
# Module Definition.
verilog += _print_separator("Module")
verilog += _print_module(f, ios, name, ns, attr_translate)
verilog += _generate_separator("Module")
verilog += _generate_module(f, ios, name, ns, attr_translate)
# Module Signals.
verilog += _print_separator("Signals")
verilog += _print_signals(f, ios, name, ns, attr_translate, regs_init)
verilog += _generate_separator("Signals")
verilog += _generate_signals(f, ios, name, ns, attr_translate, regs_init)
# Combinatorial Logic.
verilog += _print_separator("Combinatorial Logic")
verilog += _generate_separator("Combinatorial Logic")
if regular_comb:
verilog += _print_combinatorial_logic_synth(f, ns)
verilog += _generate_combinatorial_logic_synth(f, ns)
else:
verilog += _print_combinatorial_logic_sim(f, ns)
verilog += _generate_combinatorial_logic_sim(f, ns)
# Synchronous Logic.
verilog += _print_separator("Synchronous Logic")
verilog += _print_synchronous_logic(f, ns)
verilog += _generate_separator("Synchronous Logic")
verilog += _generate_synchronous_logic(f, ns)
# Specials
verilog += _print_separator("Specialized Logic")
verilog += _print_specials(
verilog += _generate_separator("Specialized Logic")
verilog += _generate_specials(
name = name,
overrides = special_overrides,
specials = f.specials - lowered_specials,
@ -654,7 +654,7 @@ def convert(f, ios=set(), name="top", platform=None,
verilog += "endmodule\n"
# Trailer.
verilog += _print_trailer()
verilog += _generate_trailer()
r.set_main_source(verilog)
r.ns = ns