diff --git a/litedram/init.py b/litedram/init.py index e84010c..41637da 100644 --- a/litedram/init.py +++ b/litedram/init.py @@ -636,6 +636,8 @@ def get_sdram_phy_c_header(phy_settings, timing_settings): r += "#define SDRAM_PHY_READ_LEVELING_CAPABLE\n" if phytype in ["ECP5DDRPHY"]: r += "#define SDRAM_PHY_READ_LEVELING_CAPABLE\n" + if phytype in ["LPDDR4SIMPHY"]: + r += "#define SDRAM_PHY_READ_LEVELING_CAPABLE\n" # Define number of modules/delays/bitslips if phytype in ["USDDRPHY", "USPDDRPHY"]: @@ -650,6 +652,10 @@ def get_sdram_phy_c_header(phy_settings, timing_settings): r += "#define SDRAM_PHY_MODULES DFII_PIX_DATA_BYTES/4\n" r += "#define SDRAM_PHY_DELAYS 8\n" r += "#define SDRAM_PHY_BITSLIPS 4\n" + elif phytype in ["LPDDR4SIMPHY"]: + r += "#define SDRAM_PHY_MODULES 2\n" + r += "#define SDRAM_PHY_DELAYS 1\n" + r += "#define SDRAM_PHY_BITSLIPS 16\n" if phy_settings.is_rdimm: assert phy_settings.memtype == "DDR4" diff --git a/litedram/phy/lpddr4/basephy.py b/litedram/phy/lpddr4/basephy.py index ec7e6e6..36e4ccc 100644 --- a/litedram/phy/lpddr4/basephy.py +++ b/litedram/phy/lpddr4/basephy.py @@ -93,8 +93,6 @@ class LPDDR4PHY(Module, AutoCSR): # Registers -------------------------------------------------------------------------------- self._rst = CSRStorage() - self._dly_sel = CSRStorage(databits//8) - self._wlevel_en = CSRStorage() self._wlevel_strobe = CSR() diff --git a/litedram/phy/lpddr4/sim.py b/litedram/phy/lpddr4/sim.py new file mode 100644 index 0000000..0ee9883 --- /dev/null +++ b/litedram/phy/lpddr4/sim.py @@ -0,0 +1,561 @@ +import math +from operator import or_ +from functools import reduce +from collections import defaultdict + +from migen import * + +from litex.soc.interconnect.stream import ClockDomainCrossing +from litex.soc.interconnect.csr import AutoCSR + +from litedram.common import TappedDelayLine, tXXDController +from litedram.phy.lpddr4.utils import delayed, once, SimLogger +from litedram.phy.lpddr4.commands import MPC + + +def log_level_getter(log_level): + def get_level(name): + return getattr(SimLogger, name.upper()) + # simple log_level, e.g. "INFO" + if "=" not in log_level: + return lambda _: get_level(log_level) + # parse log_level in the per-module form, e.g. "--log-level=all=INFO,data=DEBUG" + per_module = dict(part.split("=") for part in log_level.strip().split(",")) + return lambda module: get_level(per_module.get(module, per_module.get("all", None))) + + +class LPDDR4Sim(Module, AutoCSR): + def __init__(self, pads, *, sys_clk_freq, disable_delay, settings, log_level): + log_level = log_level_getter(log_level) + + cd_cmd = "sys8x_90" + cd_dq_wr = "sys8x_90_ddr" + cd_dqs_wr = "sys8x_ddr" + cd_dq_rd = "sys8x_90_ddr" + cd_dqs_rd = "sys8x_ddr" + + self.submodules.data = ClockDomainCrossing( + [("we", 1), ("masked", 1), ("bank", 3), ("row", 17), ("col", 10)], + cd_from=cd_cmd, cd_to=cd_dq_wr) + + cmd = CommandsSim(pads, + data_cdc = self.data, + clk_freq = 8*sys_clk_freq, + log_level = log_level("cmd"), + init_delays = not disable_delay, + ) + self.submodules.cmd = ClockDomainsRenamer(cd_cmd)(cmd) + + data = DataSim(pads, self.cmd, + cd_dq_wr = cd_dq_wr, + cd_dqs_wr = cd_dqs_wr, + cd_dq_rd = cd_dq_rd, + cd_dqs_rd = cd_dqs_rd, + clk_freq = 2*8*sys_clk_freq, + cl = settings.phy.cl, + cwl = settings.phy.cwl, + log_level = log_level("data"), + ) + self.submodules.data = ClockDomainsRenamer(cd_dq_wr)(data) + +# Commands ----------------------------------------------------------------------------------------- + +class CommandsSim(Module, AutoCSR): # clock domain: clk_p + def __init__(self, pads, data_cdc, *, clk_freq, log_level, init_delays=False): + self.submodules.log = log = SimLogger(log_level=log_level, clk_freq=clk_freq) + self.log.add_csrs() + + # Mode Registers storage + self.mode_regs = Array([Signal(8) for _ in range(64)]) + # Active banks + self.active_banks = Array([Signal() for _ in range(8)]) + self.active_rows = Array([Signal(17) for _ in range(8)]) + # Connection to DataSim + self.data_en = TappedDelayLine(ntaps=20) + self.data = data_cdc + self.submodules += self.data, self.data_en + + # CS/CA shift registers + cs = TappedDelayLine(pads.cs, ntaps=2) + ca = TappedDelayLine(pads.ca, ntaps=2) + self.submodules += cs, ca + + self.cs_low = Signal(6) + self.cs_high = Signal(6) + self.handle_cmd = Signal() + self.mpc_op = Signal(7) + + cmds_enabled = Signal() + cmd_handlers = { + "MRW": self.mrw_handler(), + "REF": self.refresh_handler(), + "ACT": self.activate_handler(), + "PRE": self.precharge_handler(), + "CAS": self.cas_handler(), + "MPC": self.mpc_handler(), + } + self.comb += [ + If(cmds_enabled, + If(Cat(cs.taps) == 0b10, + self.handle_cmd.eq(1), + self.cs_high.eq(ca.taps[1]), + self.cs_low.eq(ca.taps[0]), + ) + ), + If(self.handle_cmd & ~reduce(or_, cmd_handlers.values()), + self.log.error("Unexpected command: cs_high=0b%06b cs_low=0b%06b", self.cs_high, self.cs_low) + ), + ] + + def ck(t): + return math.ceil(t * clk_freq) + + self.submodules.tinit0 = tXXDController(ck(20e-3)) + self.submodules.tinit1 = tXXDController(ck(200e-6)) + self.submodules.tinit2 = tXXDController(ck(10e-9)) + self.submodules.tinit3 = tXXDController(ck(2e-3)) + self.submodules.tinit4 = tXXDController(5) # TODO: would require counting pads.clk_p ticks + self.submodules.tinit5 = tXXDController(ck(2e-6)) + self.submodules.tzqcal = tXXDController(ck(1e-6)) + self.submodules.tzqlat = tXXDController(max(8, ck(30e-9))) + + self.comb += [ + If(~delayed(self, pads.reset_n) & pads.reset_n, + self.log.info("RESET released"), + ), + If(delayed(self, pads.reset_n) & ~pads.reset_n, + self.log.info("RESET asserted"), + ), + If(delayed(self, pads.cke) & ~pads.cke, + self.log.info("CKE falling edge"), + ), + If(~delayed(self, pads.cke) & pads.cke, + self.log.info("CKE rising edge"), + ), + ] + + self.submodules.fsm = fsm = FSM() + fsm.act("POWER-RAMP", + self.tinit0.valid.eq(1), + If(~pads.reset_n, + If(self.tinit0.ready, # tINIT0 is MAX, so should be not ready + self.log.warn("tINIT0 violated") + ), + NextState("RESET") # Tb + ) + ) + fsm.act("RESET", + self.tinit1.valid.eq(1), + self.tinit2.valid.eq(~pads.cke), + If(pads.reset_n, + If(~self.tinit1.ready, + self.log.warn("tINIT1 violated") + ), + If(~self.tinit2.ready, + self.log.warn("tINIT2 violated") + ), + NextState("WAIT-PD"), # Tc + ) + ) + fsm.act("WAIT-PD", + self.tinit3.valid.eq(1), + If(self.tinit3.ready | (not init_delays), + NextState("EXIT-PD") # Td + ) + ) + fsm.act("EXIT-PD", + self.tinit5.valid.eq(1), + If(self.tinit5.ready | (not init_delays), + NextState("MRW") # Te + ) + ) + fsm.act("MRW", + cmds_enabled.eq(1), + If(self.handle_cmd & ~cmd_handlers["MRW"] & ~cmd_handlers["MPC"], + self.log.warn("Only MRW/MRR commands expected before ZQ calibration") + ), + If(cmd_handlers["MPC"], + If(self.mpc_op != MPC["ZQC-START"], + self.log.error("ZQC-START expected, got op=0b%07b", self.mpc_op) + ).Else( + NextState("ZQC") # Tf + ) + ), + ) + fsm.act("ZQC", + self.tzqcal.valid.eq(1), + cmds_enabled.eq(1), + If(self.handle_cmd, + If(~(cmd_handlers["MPC"] & (self.mpc_op == MPC["ZQC-LATCH"])), + self.log.error("Expected ZQC-LATCH") + ).Else( + If(~self.tzqcal.ready, + self.log.warn("tZQCAL violated") + ), + NextState("NORMAL") # Tg + ) + ), + ) + # TODO: Bus training currently is not performed in the simulation + fsm.act("NORMAL", + cmds_enabled.eq(1), + self.tzqlat.valid.eq(1), + once(self, self.handle_cmd & ~self.tzqlat.ready, + self.log.warn("tZQLAT violated") + ), + ) + + # Log state transitions + fsm.finalize() + prev_state = delayed(self, fsm.state) + self.comb += If(prev_state != fsm.state, + Case(prev_state, { + state: Case(fsm.state, { + next_state: self.log.info(f"FSM: {state_name} -> {next_state_name}") + for next_state, next_state_name in fsm.decoding.items() + }) + for state, state_name in fsm.decoding.items() + }) + ) + + def cmd_one_step(self, name, cond, comb, sync=None): + matched = Signal() + self.comb += If(self.handle_cmd & cond, + self.log.debug(name), + matched.eq(1), + *comb + ) + if sync is not None: + self.sync += If(self.handle_cmd & cond, + *sync + ) + return matched + + def cmd_two_step(self, name, cond1, body1, cond2, body2): + state1, state2 = f"{name}-1", f"{name}-2" + matched = Signal() + + fsm = FSM() + fsm.act(state1, + If(self.handle_cmd & cond1, + self.log.debug(state1), + matched.eq(1), + *body1, + NextState(state2) + ) + ) + fsm.act(state2, + If(self.handle_cmd, + If(cond2, + self.log.debug(state2), + matched.eq(1), + *body2 + ).Else( + self.log.error(f"Waiting for {state2} but got unexpected cs_high=0b%06b cs_low=0b%06b", self.cs_high, self.cs_low) + ), + NextState(state1) # always back to first + ) + ) + self.submodules += fsm + + return matched + + def mrw_handler(self): + ma = Signal(6) + op7 = Signal() + op = Signal(8) + return self.cmd_two_step("MRW", + cond1 = self.cs_high[:5] == 0b00110, + body1 = [ + NextValue(ma, self.cs_low), + NextValue(op7, self.cs_high[5]), + ], + cond2 = self.cs_high[:5] == 0b10110, + body2 = [ + self.log.info("MRW: MR[%d] = 0x%02x", ma, op), + op.eq(Cat(self.cs_low, self.cs_high[5], op7)), + NextValue(self.mode_regs[ma], op), + ] + ) + + def refresh_handler(self): + return self.cmd_one_step("REFRESH", + cond = self.cs_high[:5] == 0b01000, + comb = [ + If(reduce(or_, self.active_banks), + self.log.error("Not all banks precharged during REFRESH") + ) + ] + ) + + def activate_handler(self): + bank = Signal(3) + row1 = Signal(7) + row2 = Signal(10) + row = Signal(17) + return self.cmd_two_step("ACTIVATE", + cond1 = self.cs_high[:2] == 0b01, + body1 = [ + NextValue(bank, self.cs_low[:3]), + NextValue(row1, Cat(self.cs_low[4:6], self.cs_high[2:6], self.cs_low[3])), + ], + cond2 = self.cs_high[:2] == 0b11, + body2 = [ + self.log.info("ACT: bank=%d row=%d", bank, row), + row2.eq(Cat(self.cs_low, self.cs_high[2:])), + row.eq(Cat(row2, row1)), + NextValue(self.active_banks[bank], 1), + NextValue(self.active_rows[bank], row), + If(self.active_banks[bank], + self.log.error("ACT on already active bank: bank=%d row=%d", bank, row) + ), + ] + ) + + def precharge_handler(self): + bank = Signal(3) + return self.cmd_one_step("PRECHARGE", + cond = self.cs_high[:5] == 0b10000, + comb = [ + If(self.cs_high[5], + self.log.info("PRE: all banks"), + bank.eq(2**len(bank) - 1), + ).Else( + self.log.info("PRE: bank = %d", bank), + bank.eq(self.cs_low[:3]), + ), + ], + sync = [ + If(self.cs_high[5], + *[self.active_banks[b].eq(0) for b in range(2**len(bank))] + ).Else( + self.active_banks[bank].eq(0), + If(~self.active_banks[bank], + self.log.warn("PRE on inactive bank: bank=%d", bank) + ), + ), + ] + ) + + def mpc_handler(self): + cases = {value: self.log.info(f"MPC: {name}") for name, value in MPC.items()} + cases["default"] = self.log.error("Invalid MPC op=0b%07b", self.mpc_op) + return self.cmd_one_step("MPC", + cond = self.cs_high[:5] == 0b00000, + comb = [ + self.mpc_op.eq(Cat(self.cs_low, self.cs_high[5])), + If(self.cs_high[5] == 0, + self.log.info("MPC: NOOP") + ).Else( + Case(self.mpc_op, cases) + ) + ], + ) + + def cas_handler(self): + cas1 = Signal(5) + cas2 = 0b10010 + cas1_cmds = { + "WRITE": 0b00100, + "MASKED-WRITE": 0b01100, + "READ": 0b00010, + } + + bank = Signal(3) + row = Signal(17) + col9 = Signal() + col = Signal(10) + burst_len = Signal() + auto_precharge = Signal() + + return self.cmd_two_step("CAS", + cond1 = reduce(or_, [self.cs_high[:5] == cmd for cmd in cas1_cmds.values()]), + body1 = [ + NextValue(cas1, self.cs_high[:5]), + NextValue(bank, self.cs_low[:3]), + NextValue(col9, self.cs_low[4]), + NextValue(burst_len, self.cs_high[5]), + NextValue(auto_precharge, self.cs_low[5]), + ], + cond2 = self.cs_high[:5] == cas2, + body2 = [ + row.eq(self.active_rows[bank]), + col.eq(Cat(Replicate(0, 2), self.cs_low, self.cs_high[5], col9)), + # command type info + Case(cas1, { + value: self.log.info(f"{name}: bank=%d row=%d col=%d", bank, row, col) + for name, value in cas1_cmds.items() + }), + # sanity checks + If(~self.active_banks[bank], + self.log.error("CAS command on inactive bank: bank=%d row=%d col=%d", bank, row, col) + ), + If((cas1 != cas1_cmds["READ"]) & (col[:4] != 0), + self.log.error("WRITE commands must use C[3:2]=0 (must be aligned to full burst)") + ), + If(self.mode_regs[3][6] | self.mode_regs[3][7], + self.log.error("DBI currently not supported in the simulator") + ), + If((cas1 == cas1_cmds['MASKED-WRITE']) & (self.mode_regs[13][5] == 1), + self.log.error("MASKED-WRITE but Data Mask operation disabled in MR13[5]") + ), + If(auto_precharge, + self.log.info("AUTO-PRECHARGE: bank=%d row=%d", bank, row), + NextValue(self.active_banks[bank], 0), + ), + # pass the data to data simulator + self.data_en.input.eq(1), + self.data.sink.valid.eq(1), + self.data.sink.we.eq(cas1 != cas1_cmds["READ"]), + self.data.sink.masked.eq(cas1 == cas1_cmds["MASKED-WRITE"]), + self.data.sink.bank.eq(bank), + self.data.sink.row.eq(row), + self.data.sink.col.eq(col), + If(~self.data.sink.ready, + self.log.error("Simulator data FIFO overflow") + ) + ], + ) + +# Data --------------------------------------------------------------------------------------------- + +class DataSim(Module, AutoCSR): # clock domain: ddr + def __init__(self, pads, cmds_sim, *, cd_dq_wr, cd_dq_rd, cd_dqs_wr, cd_dqs_rd, cl, cwl, clk_freq, log_level): + self.submodules.log = log = SimLogger(log_level=log_level, clk_freq=clk_freq) + self.log.add_csrs() + + bl = 16 + + # Per-bank memory + nrows, ncols = 32768, 1024 + mems = [Memory(len(pads.dq), depth=nrows * ncols) for _ in range(8)] + ports = [mem.get_port(write_capable=True, we_granularity=8, async_read=True) for mem in mems] + self.specials += *mems, *ports + ports = Array(ports) + + bank = Signal(3) + row = Signal(17) + col = Signal(10) + + dq_kwargs = dict(bank=bank, row=row, col=col, bl=bl, nrows=nrows, ncols=ncols, + log_level=log_level, clk_freq=clk_freq) + dqs_kwargs = dict(bl=bl, log_level=log_level, clk_freq=clk_freq) + + self.submodules.dq_wr = ClockDomainsRenamer(cd_dq_wr)(DQWrite(dq=pads.dq, dmi=pads.dmi, ports=ports, **dq_kwargs)) + self.submodules.dq_rd = ClockDomainsRenamer(cd_dq_rd)(DQRead(dq=pads.dq_i, ports=ports, **dq_kwargs)) + self.submodules.dqs_wr = ClockDomainsRenamer(cd_dqs_wr)(DQSWrite(dqs=pads.dqs, **dqs_kwargs)) + self.submodules.dqs_rd = ClockDomainsRenamer(cd_dqs_rd)(DQSRead(dqs=pads.dqs_i,**dqs_kwargs)) + + write = Signal() + read = Signal() + + self.comb += [ + write.eq(cmds_sim.data_en.taps[cwl-1] & cmds_sim.data.source.valid & cmds_sim.data.source.we), + read.eq(cmds_sim.data_en.taps[cl-1] & cmds_sim.data.source.valid & ~cmds_sim.data.source.we), + cmds_sim.data.source.ready.eq(write | read), + self.dq_wr.masked.eq(write & cmds_sim.data.source.masked), + self.dq_wr.trigger.eq(write), + self.dq_rd.trigger.eq(read), + self.dqs_wr.trigger.eq(write), + self.dqs_rd.trigger.eq(read), + ] + + self.sync += [ + If(cmds_sim.data.source.ready, + bank.eq(cmds_sim.data.source.bank), + row.eq(cmds_sim.data.source.row), + col.eq(cmds_sim.data.source.col), + ) + ] + + +class DataBurst(Module, AutoCSR): + def __init__(self, *, bl, log_level, clk_freq): + self.submodules.log = log = SimLogger(log_level=log_level, clk_freq=clk_freq) + self.log.add_csrs() + + self.bl = bl + self.trigger = Signal() + self.burst_counter = Signal(max=bl - 1) + + def add_fsm(self, ops, on_trigger=[]): + self.submodules.fsm = fsm = FSM() + fsm.act("IDLE", + NextValue(self.burst_counter, 0), + If(self.trigger, + *on_trigger, + NextState("BURST") + ) + ) + fsm.act("BURST", + *ops, + NextValue(self.burst_counter, self.burst_counter + 1), + If(self.burst_counter == self.bl - 1, + NextState("IDLE") + ), + ) + +class DQBurst(DataBurst): + def __init__(self, *, nrows, ncols, row, col, **kwargs): + super().__init__(**kwargs) + self.addr = Signal(max=nrows * ncols) + self.col_burst = Signal(10) + self.comb += [ + self.col_burst.eq(col + self.burst_counter), + self.addr.eq(row * ncols + self.col_burst), + ] + +class DQWrite(DQBurst): + def __init__(self, *, dq, dmi, ports, nrows, ncols, bank, row, col, **kwargs): + super().__init__(nrows=nrows, ncols=ncols, row=row, col=col, **kwargs) + + assert len(dmi) == len(ports[0].we), "port.we should have the same width as the DMI line" + self.masked = Signal() + masked = Signal() + + self.add_fsm( + on_trigger = [ + NextValue(masked, self.masked), + ], + ops = [ + self.log.debug("WRITE[%d]: bank=%d, row=%d, col=%d, data=0x%04x", + self.burst_counter, bank, row, self.col_burst, dq, once=False), + If(masked, + ports[bank].we.eq(~dmi), # DMI high masks the beat + ).Else( + ports[bank].we.eq(2**len(ports[bank].we) - 1), + ), + ports[bank].adr.eq(self.addr), + ports[bank].dat_w.eq(dq), + ] + ) + +class DQRead(DQBurst): + def __init__(self, *, dq, ports, nrows, ncols, bank, row, col, **kwargs): + super().__init__(nrows=nrows, ncols=ncols, row=row, col=col, **kwargs) + self.add_fsm([ + self.log.debug("READ[%d]: bank=%d, row=%d, col=%d, data=0x%04x", + self.burst_counter, bank, row, self.col_burst, dq, once=False), + ports[bank].we.eq(0), + ports[bank].adr.eq(self.addr), + dq.eq(ports[bank].dat_r), + ]) + +class DQSWrite(DataBurst): + def __init__(self, *, dqs, **kwargs): + super().__init__(**kwargs) + dqs0 = Signal() + self.add_fsm([ + dqs0.eq(dqs[0]), + If(dqs[0] != self.burst_counter[0], + self.log.warn("Wrong DQS=%d for cycle=%d", dqs0, self.burst_counter, once=False) + ), + ]) + +class DQSRead(DataBurst): + def __init__(self, *, dqs, **kwargs): + super().__init__(**kwargs) + dqs0 = Signal() + self.add_fsm([ + *[i.eq(self.burst_counter[0]) for i in dqs], + ]) diff --git a/litedram/phy/lpddr4/simsoc.py b/litedram/phy/lpddr4/simsoc.py new file mode 100644 index 0000000..e503ff3 --- /dev/null +++ b/litedram/phy/lpddr4/simsoc.py @@ -0,0 +1,526 @@ +import os +import re +import argparse +from collections import namedtuple, defaultdict + +from migen import * + +from litex.build.generic_platform import Pins, Subsignal +from litex.build.sim import SimPlatform +from litex.build.sim.config import SimConfig + +from litex.soc.interconnect.csr import CSR +from litex.soc.integration.soc_core import SoCCore +from litex.soc.integration.soc_sdram import soc_sdram_args, soc_sdram_argdict +from litex.soc.integration.builder import builder_args, builder_argdict, Builder +from litex.soc.cores.cpu import CPUS + +from litedram import modules as litedram_modules +from litedram.core.controller import ControllerSettings +from litedram.phy.model import DFITimingsChecker, _speedgrade_timings, _technology_timings + +from litedram.phy.lpddr4.simphy import LPDDR4SimPHY +from litedram.phy.lpddr4.sim import LPDDR4Sim + +# Platform ----------------------------------------------------------------------------------------- + +_io = [ + # clocks added later + ("sys_rst", 0, Pins(1)), + + ("serial", 0, + Subsignal("source_valid", Pins(1)), + Subsignal("source_ready", Pins(1)), + Subsignal("source_data", Pins(8)), + Subsignal("sink_valid", Pins(1)), + Subsignal("sink_ready", Pins(1)), + Subsignal("sink_data", Pins(8)), + ), + + ("lpddr4", 0, + Subsignal("clk_p", Pins(1)), + Subsignal("clk_n", Pins(1)), + Subsignal("cke", Pins(1)), + Subsignal("odt", Pins(1)), + Subsignal("reset_n", Pins(1)), + Subsignal("cs", Pins(1)), + Subsignal("ca", Pins(6)), + Subsignal("dqs", Pins(2)), + # Subsignal("dqs_n", Pins(2)), + Subsignal("dmi", Pins(2)), + Subsignal("dq", Pins(16)), + ), +] + +class Platform(SimPlatform): + def __init__(self): + SimPlatform.__init__(self, "SIM", _io) + +# Clocks ------------------------------------------------------------------------------------------- + +class Clocks(dict): # FORMAT: {name: {"freq_hz": _, "phase_deg": _}, ...} + def names(self): + return list(self.keys()) + + def add_io(self, io): + for name in self.names(): + print((name + "_clk", 0, Pins(1))) + io.append((name + "_clk", 0, Pins(1))) + + def add_clockers(self, sim_config): + for name, desc in self.items(): + sim_config.add_clocker(name + "_clk", **desc) + +class _CRG(Module): + def __init__(self, platform, domains=None): + if domains is None: + domains = ["sys"] + # request() before creating domains to avoid signal renaming problem + domains = {name: platform.request(name + "_clk") for name in domains} + + self.clock_domains.cd_por = ClockDomain(reset_less=True) + for name in domains.keys(): + setattr(self.clock_domains, "cd_" + name, ClockDomain(name=name)) + + int_rst = Signal(reset=1) + self.sync.por += int_rst.eq(0) + self.comb += self.cd_por.clk.eq(self.cd_sys.clk) + + for name, clk in domains.items(): + cd = getattr(self, "cd_" + name) + self.comb += cd.clk.eq(clk) + self.comb += cd.rst.eq(int_rst) + +def get_clocks(sys_clk_freq): + return Clocks({ + "sys": dict(freq_hz=sys_clk_freq), + "sys_11_25": dict(freq_hz=sys_clk_freq, phase_deg=11.25), + "sys8x": dict(freq_hz=8*sys_clk_freq), + "sys8x_ddr": dict(freq_hz=2*8*sys_clk_freq), + "sys8x_90": dict(freq_hz=8*sys_clk_freq, phase_deg=90), + "sys8x_90_ddr": dict(freq_hz=2*8*sys_clk_freq, phase_deg=2*90), + }) + +# SoC ---------------------------------------------------------------------------------------------- + +class SimSoC(SoCCore): + def __init__(self, clocks, log_level, auto_precharge=False, with_refresh=True, trace_reset=0, + disable_delay=False, **kwargs): + platform = Platform() + sys_clk_freq = clocks["sys"]["freq_hz"] + + # SoCCore ---------------------------------------------------------------------------------- + super().__init__(platform, + clk_freq = sys_clk_freq, + ident = "LiteX Simulation", + ident_version = True, + cpu_variant = "minimal", + **kwargs) + + # CRG -------------------------------------------------------------------------------------- + self.submodules.crg = _CRG(platform, clocks.names()) + + # Debugging -------------------------------------------------------------------------------- + platform.add_debug(self, reset=trace_reset) + + # LPDDR4 ----------------------------------------------------------------------------------- + sdram_module = litedram_modules.MT53E256M16D1(sys_clk_freq, "1:8") + pads = platform.request("lpddr4") + self.submodules.ddrphy = LPDDR4SimPHY(sys_clk_freq=sys_clk_freq, aligned_reset_zero=True) + # fake delays (make no nsense in simulation, but sdram.c expects them) + self.ddrphy._rdly_dq_rst = CSR() + self.ddrphy._rdly_dq_inc = CSR() + self.add_csr("ddrphy") + + for p in ["clk_p", "clk_n", "cke", "odt", "reset_n", "cs", "ca", "dq", "dqs", "dmi"]: + self.comb += getattr(pads, p).eq(getattr(self.ddrphy.pads, p)) + + controller_settings = ControllerSettings() + controller_settings.auto_precharge = auto_precharge + controller_settings.with_refresh = with_refresh + + self.add_sdram("sdram", + phy = self.ddrphy, + module = sdram_module, + origin = self.mem_map["main_ram"], + size = kwargs.get("max_sdram_size", 0x40000000), + l2_cache_size = kwargs.get("l2_size", 8192), + l2_cache_min_data_width = kwargs.get("min_l2_data_width", 128), + l2_cache_reverse = False, + controller_settings = controller_settings + ) + # Reduce memtest size for simulation speedup + self.add_constant("MEMTEST_DATA_SIZE", 8*1024) + self.add_constant("MEMTEST_ADDR_SIZE", 8*1024) + + # LPDDR4 Sim ------------------------------------------------------------------------------- + self.submodules.lpddr4sim = LPDDR4Sim( + pads = self.ddrphy.pads, + settings = self.sdram.controller.settings, + sys_clk_freq = sys_clk_freq, + log_level = log_level, + disable_delay = disable_delay, + ) + self.add_csr("lpddr4sim") + + self.add_constant("CONFIG_SIM_DISABLE_BIOS_PROMPT") + if disable_delay: + self.add_constant("CONFIG_SIM_DISABLE_DELAYS") + + # Reuse DFITimingsChecker from phy/model.py + nphases = self.sdram.controller.settings.phy.nphases + timings = {"tCK": (1e9 / sys_clk_freq) / nphases} + for name in _speedgrade_timings + _technology_timings: + timings[name] = sdram_module.get(name) + + self.submodules.dfi_timings_checker = DFITimingsChecker( + dfi = self.ddrphy.dfi, + nbanks = 2**self.sdram.controller.settings.geom.bankbits, + nphases = nphases, + timings = timings, + refresh_mode = sdram_module.timing_settings.fine_refresh_mode, + memtype = self.sdram.controller.settings.phy.memtype, + verbose = False, + ) + + # Debug info ------------------------------------------------------------------------------- + def dump(obj): + print() + print(" " + obj.__class__.__name__) + print(" " + "-" * len(obj.__class__.__name__)) + d = obj if isinstance(obj, dict) else vars(obj) + for var, val in d.items(): + if var == "self": + continue + print(" {}: {}".format(var, val)) + + print("=" * 80) + dump(clocks) + dump(self.ddrphy.settings) + dump(sdram_module.geom_settings) + dump(sdram_module.timing_settings) + print() + print("=" * 80) + +# GTKWave ------------------------------------------------------------------------------------------ + +class SigTrace: + def __init__(self, name, alias=None, color=None, filter_file=None): + self.name = name + self.alias = alias + self.color = color + self.filter_file = filter_file + +def strip_bits(name): + if name.endswith("]") and "[" in name: + name = name[:name.rfind("[")] + return name + +def regex_map(sig, patterns, on_match, on_no_match, remove_bits=True): + # Given `patterns` return `on_match(sig, pattern)` if any pattern matches or else `on_no_match(sig)` + alias = sig.alias + if remove_bits: # get rid of signal bits (e.g. wb_adr[29:0]) + alias = strip_bits(alias) + for pattern in patterns: + if pattern.search(alias): + return on_match(sig, pattern) + return on_no_match(sig) + +def regex_filter(patterns, negate=False, **kwargs): + patterns = list(map(re.compile, patterns)) + def filt(sigs): + return list(filter(None, map(lambda sig: regex_map(sig, patterns, + on_match = lambda s, p: (s if not negate else None), + on_no_match = lambda s: (None if not negate else s), + **kwargs), sigs))) + return filt + +def regex_sorter(patterns, unmatched_last=True, **kwargs): + def sort(sigs): + order = {re.compile(pattern): i for i, pattern in enumerate(patterns)} + return sorted(sigs, key=lambda sig: regex_map(sig, order.keys(), + on_match = lambda s, p: order[p], + on_no_match = lambda s: len(order) if unmatched_last else -1, + **kwargs)) + return sort + +def suffixes2re(strings): + return ["{}$".format(s) for s in strings] + +def prefixes2re(strings): + return ["^{}".format(s) for s in strings] + +def strings2re(strings): + return suffixes2re(prefixes2re(strings)) + +def wishbone_sorter(**kwargs): + suffixes = ["cyc", "stb", "ack", "we", "sel", "adr", "dat_w", "dat_r"] + return regex_sorter(suffixes2re(suffixes), **kwargs) + +def dfi_sorter(phases=True, nphases_max=8, **kwargs): + suffixes = [ + "cas_n", "ras_n", "we_n", + "address", "bank", + "wrdata_en", "wrdata", "wrdata_mask", + "rddata_en", "rddata", "rddata_valid", + ] + if phases: + patterns = [] + for phase in range(nphases_max): + patterns.extend(["p{}_{}".format(phase, suffix) for suffix in suffixes]) + else: + patterns = suffixes + return regex_sorter(suffixes2re(patterns), **kwargs) + +def regex_colorer(color_patterns, default=None, **kwargs): + colors = {} + for color, patterns in color_patterns.items(): + for pattern in patterns: + colors[re.compile(pattern)] = color + + def add_color(sig, color): + sig.color = color + + def add_colors(sigs): + for sig in sigs: + regex_map(sig, colors.keys(), + on_match = lambda s, p: add_color(s, colors[p]), + on_no_match = lambda s: add_color(s, default), + **kwargs) + return sigs + return add_colors + +def dfi_per_phase_colorer(nphases_max=8, **kwargs): + colors = ["normal", "yellow", "orange", "red"] + color_patterns = {} + for p in range(nphases_max): + color = colors[p % len(colors)] + patterns = color_patterns.get(color, []) + patterns.append("p{}_".format(p)) + color_patterns[color] = patterns + return regex_colorer(color_patterns, default="indigo", **kwargs) + +def dfi_in_phase_colorer(**kwargs): + return regex_colorer({ + "normal": suffixes2re(["cas_n", "ras_n", "we_n"]), + "yellow": suffixes2re(["address", "bank"]), + "orange": suffixes2re(["wrdata_en", "wrdata", "wrdata_mask"]), + "red": suffixes2re(["rddata_en", "rddata", "rddata_valid"]), + }, default="indigo", **kwargs) + +class LitexGTKWSave: + def __init__(self, vns, savefile, dumpfile, filtersdir=None, prefix="TOP.sim."): + self.vns = vns # Namespace output of Builder.build, required to resolve signal names + self.prefix = prefix + self.savefile = savefile + self.dumpfile = dumpfile + self.filtersdir = filtersdir + if self.filtersdir is None: + self.filtersdir = os.path.dirname(self.dumpfile) + + def __enter__(self): + # pyvcd: https://pyvcd.readthedocs.io/en/latest/vcd.gtkw.html + from vcd.gtkw import GTKWSave + self.file = open(self.savefile, "w") + self.gtkw = GTKWSave(self.file) + self.gtkw.dumpfile(self.dumpfile) + self.gtkw.treeopen("TOP") + self.gtkw.sst_expanded(True) + return self + + def __exit__(self, type, value, traceback): + self.file.close() + print("\nGenerated GTKWave save file at: {}\n".format(self.savefile)) + + def name(self, sig): + bits = "" + if len(sig) > 1: + bits = "[{}:0]".format(len(sig) - 1) + return self.vns.get_name(sig) + bits + + def signal(self, signal): + self.gtkw.trace(self.prefix + self.name(signal)) + + def common_prefix(self, names): + prefix = os.path.commonprefix(names) + last_underscore = prefix.rfind("_") + return prefix[:last_underscore + 1] + + def group(self, signals, group_name=None, alias=True, closed=True, + filter=None, sorter=None, colorer=None, translation_files=None, **kwargs): + translation_files = translation_files or {} + if len(signals) == 1: + return self.signal(signals[0]) + + names = [self.name(s) for s in signals] + common = self.common_prefix(names) + + make_alias = (lambda n: n[len(common):]) if alias else (lambda n: n) + sigs = [SigTrace(name=n, alias=make_alias(n)) for i, n in enumerate(names)] + if translation_files is not None: + for sig, file in zip(sigs, translation_files): + sig.filter_file = file + + for mapper in [filter, sorter, colorer]: + if mapper is not None: + sigs = list(mapper(sigs)) + + with self.gtkw.group(group_name or common.strip("_"), closed=closed): + for s in sigs: + self.gtkw.trace(self.prefix + s.name, alias=s.alias, color=s.color, + translate_filter_file=s.filter_file, **kwargs) + + def by_regex(self, regex, **kwargs): + pattern = re.compile(regex) + for sig in self.vns.pnd.keys(): + m = pattern.search(self.vns.pnd[sig]) + signals = list(filter(lambda sig: pattern.search(self.vns.pnd[sig]), self.vns.pnd.keys())) + assert len(signals) > 0, "No match found for {}".format(regex) + return self.group(signals, **kwargs) + + def clocks(self, **kwargs): + clks = [cd.clk for cd in self.vns.clock_domains] + self.group(clks, group_name="clocks", alias=False, closed=False, **kwargs) + + def add(self, obj, **kwargs): + if isinstance(obj, Record): + self.group([s for s, _ in obj.iter_flat()], **kwargs) + elif isinstance(obj, Signal): + self.signal(obj) + else: + raise NotImplementedError(type(obj), obj) + + def make_fsm_state_translation(self, fsm): + # generate filter file + from vcd.gtkw import make_translation_filter + translations = list(fsm.decoding.items()) + filename = "filter__{}.txt".format(strip_bits(self.name(fsm.state))) + filepath = os.path.join(self.filtersdir, filename) + with open(filepath, 'w') as f: + f.write(make_translation_filter(translations, size=len(fsm.state))) + return filepath + + def iter_submodules(self, fragment): + for name, module in getattr(fragment, "_submodules", []): + yield module + yield from self.iter_submodules(module) + + def fsm_states(self, soc, **kwargs): + # TODO: generate alias names for the machines, because the defaults are hard to decipher + fsms = list(filter(lambda module: isinstance(module, FSM), self.iter_submodules(soc))) + states = [fsm.state for fsm in fsms] + files = [self.make_fsm_state_translation(fsm) for fsm in fsms] + self.group(states, group_name="FSM states", translation_files=files, **kwargs) + +# Build -------------------------------------------------------------------------------------------- + +def generate_gtkw_savefile(builder, vns, trace_fst): + dumpfile = os.path.join(builder.gateware_dir, "sim.{}".format("fst" if trace_fst else "vcd")) + savefile = os.path.join(builder.gateware_dir, "sim.gtkw") + soc = builder.soc + + with LitexGTKWSave(vns, savefile=savefile, dumpfile=dumpfile) as gtkw: + gtkw.clocks() + gtkw.add(soc.bus.slaves["main_ram"], sorter=wishbone_sorter()) + # all dfi signals + gtkw.add(soc.ddrphy.dfi, sorter=dfi_sorter(), colorer=dfi_in_phase_colorer()) + # each phase in separate group + with gtkw.gtkw.group("dfi phaseX", closed=True): + for i, phase in enumerate(soc.ddrphy.dfi.phases): + gtkw.add(phase, + group_name = "dfi p{}".format(i), + sorter = dfi_sorter(phases=False), + colorer = dfi_in_phase_colorer()) + # only dfi command signals + gtkw.add(soc.ddrphy.dfi, + group_name = "dfi commands", + filter = regex_filter(suffixes2re(["cas_n", "ras_n", "we_n"])), + sorter = dfi_sorter(), + colorer = dfi_per_phase_colorer()) + # only dfi data signals + gtkw.add(soc.ddrphy.dfi, + group_name = "dfi wrdata", + filter = regex_filter(suffixes2re(["wrdata"])), + sorter = dfi_sorter(), + colorer = dfi_per_phase_colorer()) + gtkw.add(soc.ddrphy.dfi, + group_name = "dfi rddata", + filter = regex_filter(suffixes2re(["rddata"])), + sorter = dfi_sorter(), + colorer = dfi_per_phase_colorer()) + # dram apds + gtkw.by_regex("pads_", + filter = regex_filter(["clk_n$", "_[io]$", "_oe$"], negate=True), + sorter = regex_sorter(suffixes2re(["cke", "odt", "reset_n", "clk_p", "cs", "ca", "dq", "dqs", "dmi"])), + colorer = regex_colorer({ + "yellow": suffixes2re(["cs", "ca"]), + "orange": suffixes2re(["dq"]), + }), + ) + gtkw.fsm_states(soc) + +def main(): + parser = argparse.ArgumentParser(description="Generic LiteX SoC Simulation") + builder_args(parser) + soc_sdram_args(parser) + parser.add_argument("--sdram-verbosity", default=0, help="Set SDRAM checker verbosity") + parser.add_argument("--trace", action="store_true", help="Enable Tracing") + parser.add_argument("--trace-fst", action="store_true", help="Enable FST tracing (default=VCD)") + parser.add_argument("--trace-start", default=0, help="Cycle to start tracing") + parser.add_argument("--trace-end", default=-1, help="Cycle to end tracing") + parser.add_argument("--trace-reset", default=0, help="Initial traceing state") + parser.add_argument("--sys-clk-freq", default="50e6", help="Core clock frequency") + parser.add_argument("--auto-precharge", action="store_true", help="Use DRAM auto precharge") + parser.add_argument("--no-refresh", action="store_true", help="Disable DRAM refresher") + parser.add_argument("--log-level", default="all=INFO", help="Set simulation logging level") + parser.add_argument("--disable-delay", action="store_true", help="Disable CPU delays") + parser.add_argument("--gtkw-savefile", action="store_true", help="Generate GTKWave savefile") + args = parser.parse_args() + + soc_kwargs = soc_sdram_argdict(args) + builder_kwargs = builder_argdict(args) + + sim_config = SimConfig() + sys_clk_freq = int(float(args.sys_clk_freq)) + clocks = get_clocks(sys_clk_freq) + clocks.add_io(_io) + clocks.add_clockers(sim_config) + + # Configuration -------------------------------------------------------------------------------- + cpu = CPUS[soc_kwargs.get("cpu_type", "vexriscv")] + if soc_kwargs["uart_name"] == "serial": + soc_kwargs["uart_name"] = "sim" + sim_config.add_module("serial2console", "serial") + args.with_sdram = True + soc_kwargs["integrated_main_ram_size"] = 0x0 + soc_kwargs["sdram_verbosity"] = int(args.sdram_verbosity) + + # SoC ------------------------------------------------------------------------------------------ + soc = SimSoC( + clocks = clocks, + auto_precharge = args.auto_precharge, + with_refresh = not args.no_refresh, + trace_reset = int(args.trace_reset), + log_level = args.log_level, + disable_delay = args.disable_delay, + **soc_kwargs) + + # Build/Run ------------------------------------------------------------------------------------ + builder_kwargs["csr_csv"] = "csr.csv" + builder = Builder(soc, **builder_kwargs) + build_kwargs = dict( + sim_config = sim_config, + trace = args.trace, + trace_fst = args.trace_fst, + trace_start = int(args.trace_start), + trace_end = int(args.trace_end) + ) + vns = builder.build(run=False, **build_kwargs) + + if args.gtkw_savefile: + generate_gtkw_savefile(builder, vns, trace_fst=args.trace_fst) + + builder.build(build=False, **build_kwargs) + +if __name__ == "__main__": + main() diff --git a/litedram/phy/lpddr4/utils.py b/litedram/phy/lpddr4/utils.py index 2ef092e..92a5e9a 100644 --- a/litedram/phy/lpddr4/utils.py +++ b/litedram/phy/lpddr4/utils.py @@ -3,6 +3,8 @@ from operator import or_ from migen import * +from litex.soc.interconnect.csr import CSRStorage, AutoCSR + from litedram.common import TappedDelayLine @@ -22,6 +24,11 @@ def delayed(mod, sig, cycles=1): mod.submodules += delay return delay.output +def once(mod, cond, *ops): + sig = Signal() + mod.sync += If(cond, sig.eq(1)) + return If(~sig & cond, *ops) + class ConstBitSlip(Module): def __init__(self, dw, i=None, o=None, slp=None, cycles=1): self.i = Signal(dw, name='i') if i is None else i @@ -69,3 +76,57 @@ class DQSPattern(Module): o = Signal.like(self.o) self.sync += o.eq(self.o) self.o = o + + +class SimLogger(Module, AutoCSR): + # Allows to use Display inside FSM and to filter log messages by level (statically or dynamically) + DEBUG = 0 + INFO = 1 + WARN = 2 + ERROR = 3 + NONE = 4 + + def __init__(self, log_level=INFO, clk_freq=None): + self.ops = [] + self.level = Signal(reset=log_level, max=self.NONE) + self.time_ps = None + if clk_freq is not None: + self.time_ps = Signal(64) + cnt = Signal(64) + self.sync += cnt.eq(cnt + 1) + self.comb += self.time_ps.eq(cnt * int(1e12/clk_freq)) + + def debug(self, fmt, *args, **kwargs): + return self.log("[DEBUG] " + fmt, *args, level=self.DEBUG, **kwargs) + + def info(self, fmt, *args, **kwargs): + return self.log("[INFO] " + fmt, *args, level=self.INFO, **kwargs) + + def warn(self, fmt, *args, **kwargs): + return self.log("[WARN] " + fmt, *args, level=self.WARN, **kwargs) + + def error(self, fmt, *args, **kwargs): + return self.log("[ERROR] " + fmt, *args, level=self.ERROR, **kwargs) + + def log(self, fmt, *args, level=DEBUG, once=True): + cond = Signal() + if once: # make the condition be triggered only on rising edge + cond_d = Signal() + self.sync += cond_d.eq(cond) + condition = ~cond_d & cond + else: + condition = cond + + self.ops.append((level, condition, fmt, args)) + return cond.eq(1) + + def add_csrs(self): + self._level = CSRStorage(len(self.level), reset=self.level.reset.value) + self.comb += self.level.eq(self._level.storage) + + def do_finalize(self): + for level, cond, fmt, args in self.ops: + if self.time_ps is not None: + fmt = f"[%16d ps] {fmt}" + args = (self.time_ps, *args) + self.sync += If((level >= self.level) & cond, Display(fmt, *args))