diff --git a/litedram/dfii.py b/litedram/dfii.py index 3aa50e7..c73b62c 100644 --- a/litedram/dfii.py +++ b/litedram/dfii.py @@ -48,9 +48,12 @@ class PhaseInjector(Module, AutoCSR): class DFIInjector(Module, AutoCSR): def __init__(self, addressbits, bankbits, nranks, databits, nphases=1): - inti = dfi.Interface(addressbits, bankbits, nranks, databits, nphases) - self.slave = dfi.Interface(addressbits, bankbits, nranks, databits, nphases) - self.master = dfi.Interface(addressbits, bankbits, nranks, databits, nphases) + self.slave = dfi.Interface(addressbits, bankbits, nranks, databits, nphases) + self.master = dfi.Interface(addressbits, bankbits, nranks, databits, nphases) + csr_dfi = dfi.Interface(addressbits, bankbits, nranks, databits, nphases) + + self.ext_dfi = dfi.Interface(addressbits, bankbits, nranks, databits, nphases) + self.ext_dfi_sel = Signal() self._control = CSRStorage(fields=[ CSRField("sel", size=1, values=[ @@ -62,17 +65,21 @@ class DFIInjector(Module, AutoCSR): CSRField("reset_n", size=1), ]) - for n, phase in enumerate(inti.phases): + for n, phase in enumerate(csr_dfi.phases): setattr(self.submodules, "pi" + str(n), PhaseInjector(phase)) # # # - self.comb += If(self._control.fields.sel, - self.slave.connect(self.master) - ).Else( - inti.connect(self.master) + self.comb += If(self._control.fields.sel, # Hardware + If(self.ext_dfi_sel, # Hardware through ext_dfi + self.ext_dfi.connect(self.master) + ).Else( # Hardware by LiteDRAM controller + self.slave.connect(self.master) + ) + ).Else( # Software through CSRs + csr_dfi.connect(self.master) ) for i in range(nranks): - self.comb += [phase.cke[i].eq(self._control.fields.cke) for phase in inti.phases] - self.comb += [phase.odt[i].eq(self._control.fields.odt) for phase in inti.phases if hasattr(phase, "odt")] - self.comb += [phase.reset_n.eq(self._control.fields.reset_n) for phase in inti.phases if hasattr(phase, "reset_n")] + self.comb += [phase.cke[i].eq(self._control.fields.cke) for phase in csr_dfi.phases] + self.comb += [phase.odt[i].eq(self._control.fields.odt) for phase in csr_dfi.phases if hasattr(phase, "odt")] + self.comb += [phase.reset_n.eq(self._control.fields.reset_n) for phase in csr_dfi.phases if hasattr(phase, "reset_n")]