diff --git a/litex/soc/cores/spi.py b/litex/soc/cores/spi.py index 1ebc3ff34..d31812265 100644 --- a/litex/soc/cores/spi.py +++ b/litex/soc/cores/spi.py @@ -41,94 +41,96 @@ class SPIMaster(Module, AutoCSR): # # # - done = Signal() - bits = Signal(8) - xfer = Signal() - shift = Signal() + clk_enable = Signal() + cs_enable = Signal() + count = Signal(max=data_width) + mosi_latch = Signal() + miso_latch = Signal() # Clock generation ------------------------------------------------------------------------- clk_divider = Signal(16) clk_rise = Signal() clk_fall = Signal() + self.comb += clk_rise.eq(clk_divider == (self.clk_divider[1:] - 1)) + self.comb += clk_fall.eq(clk_divider == (self.clk_divider - 1)) self.sync += [ - If(clk_rise, pads.clk.eq(xfer)), - If(clk_fall, pads.clk.eq(0)), - If(clk_fall, - clk_divider.eq(0) - ).Else( - clk_divider.eq(clk_divider + 1) + clk_divider.eq(clk_divider + 1), + If(clk_rise, + pads.clk.eq(clk_enable), + ).Elif(clk_fall, + clk_divider.eq(0), + pads.clk.eq(0), ) ] - self.comb += clk_rise.eq(clk_divider == (self.clk_divider[1:] - 1)) - self.comb += clk_fall.eq(clk_divider == (self.clk_divider - 1)) # Control FSM ------------------------------------------------------------------------------ self.submodules.fsm = fsm = FSM(reset_state="IDLE") fsm.act("IDLE", - done.eq(1), + self.done.eq(1), If(self.start, - NextValue(bits, 0), - NextState("WAIT-CLK-FALL") + self.done.eq(0), + mosi_latch.eq(1), + NextState("START") ) ) - fsm.act("WAIT-CLK-FALL", + fsm.act("START", + NextValue(count, 0), If(clk_fall, - NextState("XFER") + cs_enable.eq(1), + NextState("RUN") ) ) - fsm.act("XFER", - If(bits == self.length, - NextState("END") - ).Elif(clk_fall, - NextValue(bits, bits + 1) - ), - xfer.eq(1), - shift.eq(1) + fsm.act("RUN", + clk_enable.eq(1), + cs_enable.eq(1), + If(clk_fall, + NextValue(count, count + 1), + If(count == (self.length - 1), + NextState("STOP") + ) + ) ) - fsm.act("END", + fsm.act("STOP", + cs_enable.eq(1), If(clk_rise, + miso_latch.eq(1), + self.irq.eq(1), NextState("IDLE") - ), - shift.eq(1), - self.irq.eq(1) + ) ) - self.sync += self.done.eq(done & ~self.start) # Chip Select generation ------------------------------------------------------------------- if hasattr(pads, "cs_n"): for i in range(len(pads.cs_n)): - self.comb += pads.cs_n[i].eq(~self.cs[i] | ~xfer) + self.sync += pads.cs_n[i].eq(~self.cs[i] | ~cs_enable) # Master Out Slave In (MOSI) generation (generated on spi_clk falling edge) ---------------- - mosi_data = Array(self.mosi[i] for i in range(data_width)) - mosi_bit = Signal(max=data_width) + mosi_data = Signal(data_width) + mosi_array = Array(mosi_data[i] for i in range(data_width)) + mosi_sel = Signal(max=data_width) self.sync += [ - If(self.start, - mosi_bit.eq(self.length - 1 if mode == "aligned" else data_width - 1), - ).Elif(clk_rise & shift, - mosi_bit.eq(mosi_bit - 1) + If(mosi_latch, + mosi_data.eq(self.mosi), + mosi_sel.eq((self.length-1) if mode == "aligned" else (data_width-1)), + ).Elif(clk_fall, + If(cs_enable, pads.mosi.eq(mosi_array[mosi_sel])), + mosi_sel.eq(mosi_sel - 1) ), - If(clk_fall, - pads.mosi.eq(mosi_data[mosi_bit]) - ) ] # Master In Slave Out (MISO) capture (captured on spi_clk rising edge) -------------------- miso = Signal() miso_data = Signal(data_width) self.sync += [ - If(clk_rise & shift, + If(clk_rise, If(self.loopback, - miso.eq(pads.mosi) + miso_data.eq(Cat(pads.mosi, miso_data)) ).Else( - miso.eq(pads.miso) + miso_data.eq(Cat(pads.miso, miso_data)) ) - ), - If(clk_fall & shift, - miso_data.eq(Cat(miso, miso_data)) - ), - If(done, self.miso.eq(miso_data)), + ) ] + self.sync += If(miso_latch, self.miso.eq(miso_data)) def add_csr(self, with_cs=True, with_loopback=True): self._control = CSRStorage(fields=[ diff --git a/test/test_spi.py b/test/test_spi.py index 92d1ad827..e9684dae3 100644 --- a/test/test_spi.py +++ b/test/test_spi.py @@ -16,6 +16,7 @@ class TestSPI(unittest.TestCase): def test_spi_master_xfer_loopback_32b_32b(self): def generator(dut): yield dut.loopback.eq(1) + yield dut.clk_divider.eq(2) yield dut.mosi.eq(0xdeadbeef) yield dut.length.eq(32) yield dut.start.eq(1) @@ -24,7 +25,8 @@ class TestSPI(unittest.TestCase): yield while (yield dut.done) == 0: yield - self.assertEqual((yield dut.miso), 0xdeadbeef) + yield + self.assertEqual(hex((yield dut.miso)), hex(0xdeadbeef)) dut = SPIMaster(pads=None, data_width=32, sys_clk_freq=100e6, spi_clk_freq=5e6, with_csr=False) run_simulation(dut, generator(dut)) @@ -40,7 +42,8 @@ class TestSPI(unittest.TestCase): yield while (yield dut.done) == 0: yield - self.assertEqual((yield dut.miso), 0xbeef) + yield + self.assertEqual(hex((yield dut.miso)), hex(0xbeef)) dut = SPIMaster(pads=None, data_width=32, sys_clk_freq=100e6, spi_clk_freq=5e6, with_csr=False, mode="aligned") run_simulation(dut, generator(dut)) @@ -59,6 +62,8 @@ class TestSPI(unittest.TestCase): self.submodules.slave = SPISlave(pads, data_width=32) def master_generator(dut): + for i in range(8): + yield yield dut.master.mosi.eq(0xdeadbeef) yield dut.master.length.eq(32) yield dut.master.start.eq(1) @@ -67,15 +72,19 @@ class TestSPI(unittest.TestCase): yield while (yield dut.master.done) == 0: yield - self.assertEqual((yield dut.master.miso), 0x12345678) + yield + self.assertEqual(hex((yield dut.master.miso)), hex(0x12345678)) def slave_generator(dut): + for i in range(8): + yield yield dut.slave.miso.eq(0x12345678) while (yield dut.slave.start) == 0: yield while (yield dut.slave.done) == 0: yield - self.assertEqual((yield dut.slave.mosi), 0xdeadbeef) + yield + self.assertEqual(hex((yield dut.slave.mosi)), hex(0xdeadbeef)) self.assertEqual((yield dut.slave.length), 32) dut = DUT()