from misoclib.mem.litesata.common import *
from misoclib.mem.litesata.core.link.scrambler import Scrambler

from migen.fhdl.decorators import ModuleDecorator
from migen.bank.description import *

class LiteSATABISTGenerator(Module):
	"""Write-side BIST unit: streams a known data pattern to the drive.

	When `start` is seen while idle, one write packet is sent on `user_port`
	for `count` sectors starting at `sector`. The payload is an incrementing
	beat counter (`random` = 0) or the output of a `Scrambler`
	(`random` = 1). The FSM then waits for one response beat and returns
	to IDLE.

	Inputs:  start, sector (48 bits), count (16 bits), random.
	Outputs: done (high while idle), aborted (the `failed` flag of the last
	acknowledged response beat), errors (kept for interface symmetry with
	the checker; not driven here).
	"""
	def __init__(self, user_port):
		self.start = Signal()
		self.sector = Signal(48)
		self.count = Signal(16)
		self.random = Signal()

		self.done = Signal()
		self.aborted = Signal()
		self.errors = Signal(32) # Note: Not used for writes

		###

		# Naming is from this unit's point of view: we drive the port's sink
		# (command + data) and consume the port's source (response).
		source, sink = user_port.sink, user_port.source

		# Counts transmitted beats; doubles as the pattern when random == 0.
		counter = Counter(32)
		self.submodules += counter

		# The scrambler restarts and advances in lockstep with the beat
		# counter, so the pseudo-random sequence restarts for every command.
		scrambler = InsertReset(Scrambler())
		self.submodules += scrambler
		self.comb += [
			scrambler.reset.eq(counter.reset),
			scrambler.ce.eq(counter.ce)
		]

		self.fsm = fsm = FSM(reset_state="IDLE")
		self.submodules += fsm
		fsm.act("IDLE",
			self.done.eq(1),
			counter.reset.eq(1),
			If(self.start,
				NextState("SEND_CMD_AND_DATA")
			)
		)
		# Command fields are held constant for the whole packet. sop/eop come
		# from the beat count, with logical_sector_size//4 beats per sector.
		# NOTE(review): with count == 0 the eop target is -1 and may never
		# match — confirm callers never request zero sectors.
		self.comb += [
			source.sop.eq(counter.value == 0),
			source.eop.eq(counter.value == (logical_sector_size//4*self.count)-1),
			source.write.eq(1),
			source.sector.eq(self.sector),
			source.count.eq(self.count),
			If(self.random,
				source.data.eq(scrambler.value)
			).Else(
				source.data.eq(counter.value)
			)
		]
		fsm.act("SEND_CMD_AND_DATA",
			source.stb.eq(1),
			If(source.stb & source.ack,
				counter.ce.eq(1),
				If(source.eop,
					NextState("WAIT_ACK")
				)
			)
		)
		# Accept the first response beat and go back to IDLE.
		fsm.act("WAIT_ACK",
			sink.ack.eq(1),
			If(sink.stb,
				NextState("IDLE")
			)
		)
		self.sync += If(sink.stb & sink.ack, self.aborted.eq(sink.failed))

class LiteSATABISTChecker(Module):
	"""Read-side BIST unit: reads sectors back and counts pattern mismatches.

	Inputs:  start (sampled while idle), sector (48 bits), count (16 bits),
	         random (compare against scrambler output instead of a counter).
	Outputs: done (high while idle), aborted (the `failed` flag of the last
	         acknowledged response beat), errors (mismatch count).

	On `start` the error counter is cleared and one single-beat read command
	is issued. The FSM then waits for a response beat flagged `read` and
	compares every received data beat against the locally regenerated
	pattern. Beats flagged `last` are not counted as errors.
	NOTE(review): presumably the `last` beat carries no pattern data —
	confirm. A packet ending without `last` returns to WAIT_ACK for the next
	packet; one ending with `last` returns to IDLE.
	"""
	def __init__(self, user_port):
		self.start = Signal()
		self.sector = Signal(48)
		self.count = Signal(16)
		self.random = Signal()

		self.done = Signal()
		self.aborted = Signal()
		self.errors = Signal(32)

		###

		# We drive the port's sink (commands) and read its source (responses).
		cmd = user_port.sink
		rsp = user_port.source

		# counter indexes received beats (and is the pattern when random == 0);
		# error_counter accumulates data mismatches.
		counter = Counter(32)
		error_counter = Counter(32)
		self.submodules += counter, error_counter
		self.comb += self.errors.eq(error_counter.value)

		# Restart/advance the scrambler exactly like the beat counter.
		scrambler = InsertReset(Scrambler())
		self.submodules += scrambler
		self.comb += [
			scrambler.reset.eq(counter.reset),
			scrambler.ce.eq(counter.ce)
		]

		# The read command is a single-beat packet with fixed fields; only
		# cmd.stb is controlled by the FSM.
		self.comb += [
			cmd.sop.eq(1),
			cmd.eop.eq(1),
			cmd.read.eq(1),
			cmd.sector.eq(self.sector),
			cmd.count.eq(self.count),
		]

		fsm = FSM(reset_state="IDLE")
		self.fsm = fsm
		self.submodules += fsm

		expected_data = Signal(32)
		self.comb += If(self.random,
			expected_data.eq(scrambler.value)
		).Else(
			expected_data.eq(counter.value)
		)

		fsm.act("IDLE",
			self.done.eq(1),
			counter.reset.eq(1),
			If(self.start,
				error_counter.reset.eq(1),
				NextState("SEND_CMD")
			)
		)
		fsm.act("SEND_CMD",
			cmd.stb.eq(1),
			If(cmd.ack,
				counter.reset.eq(1),
				NextState("WAIT_ACK")
			)
		)
		fsm.act("WAIT_ACK",
			If(rsp.stb & rsp.read,
				NextState("RECEIVE_DATA")
			)
		)
		fsm.act("RECEIVE_DATA",
			rsp.ack.eq(1),
			If(rsp.stb,
				counter.ce.eq(1),
				# Count a mismatch unless the beat is flagged `last`.
				error_counter.ce.eq((rsp.data != expected_data) & ~rsp.last),
				If(rsp.eop & rsp.last,
					NextState("IDLE")
				).Elif(rsp.eop,
					NextState("WAIT_ACK")
				)
			)
		)
		self.sync += If(rsp.stb & rsp.ack, self.aborted.eq(rsp.failed))

class LiteSATABISTUnitCSR(Module, AutoCSR):
	"""CSR front-end that runs a generator/checker BIST unit in a loop.

	Writing 1 to `start` while idle launches `loops` back-to-back runs of
	`bist_unit`. `sector`, `count` and `random` are fed from the matching
	CSRs. With `loops` == 0 the FSM goes straight back to IDLE.

	Status registers:
	  done    - 1 while idle
	  aborted - mirrors bist_unit.aborted
	  errors  - mirrors bist_unit.errors (not driven by the generator)
	  cycles  - clock cycles spent outside IDLE since the last start

	NOTE: bist_unit.start is pulsed once per iteration. LiteSATABISTChecker
	clears its error counter on every start, so with loops > 1 `errors`
	reflects only the final iteration.
	"""
	def __init__(self, bist_unit):
		self._start = CSR()
		self._sector = CSRStorage(48)
		self._count = CSRStorage(16)
		self._loops = CSRStorage(8)
		self._random = CSRStorage()

		self._done = CSRStatus()
		self._aborted = CSRStatus()
		self._errors = CSRStatus(32)
		self._cycles = CSRStatus(32)

		###

		self.submodules += bist_unit

		# One-cycle pulse when software writes 1 to the start register.
		start = self._start.r & self._start.re
		loops = self._loops.storage

		self.comb += [
			bist_unit.sector.eq(self._sector.storage),
			bist_unit.count.eq(self._count.storage),
			bist_unit.random.eq(self._random.storage),

			self._aborted.status.eq(bist_unit.aborted),
			self._errors.status.eq(bist_unit.errors)
		]

		self.fsm = fsm = FSM(reset_state="IDLE")
		loop_counter = Counter(8)
		self.submodules += fsm, loop_counter
		fsm.act("IDLE",
			self._done.status.eq(1),
			loop_counter.reset.eq(1),
			If(start,
				NextState("CHECK")
			)
		)
		# Decide whether another iteration is needed.
		fsm.act("CHECK",
			If(loop_counter.value < loops,
				NextState("START")
			).Else(
				NextState("IDLE")
			)
		)
		fsm.act("START",
			bist_unit.start.eq(1),
			NextState("WAIT_DONE")
		)
		fsm.act("WAIT_DONE",
			If(bist_unit.done,
				loop_counter.ce.eq(1),
				NextState("CHECK")
			)
		)

		# reset/ce are registered (self.sync), so the count lags the FSM
		# state and the start pulse by one cycle.
		cycles_counter = Counter(32)
		self.submodules += cycles_counter
		self.sync += [
			cycles_counter.reset.eq(start),
			cycles_counter.ce.eq(~fsm.ongoing("IDLE")),
			self._cycles.status.eq(cycles_counter.value)
		]

class LiteSATABISTIdentify(Module):
	"""Issues an IDENTIFY command and buffers the returned data in a FIFO.

	When `start` is seen while idle, a single-beat identify command is sent
	on `user_port`. The data beats of the response (the first one flagged
	`identify`) are pushed into a SyncFIFO of depth 512. The FIFO output is
	exposed as `self.source` for a consumer (e.g. LiteSATABISTIdentifyCSR)
	to drain. `done` is high while idle.
	"""
	def __init__(self, user_port):
		self.start = Signal()
		self.done  = Signal()

		fifo = SyncFIFO([("data", 32)], 512, buffered=True)
		self.submodules += fifo
		self.source = fifo.source

		###

		# We drive the port's sink (command) and read its source (response).
		source, sink = user_port.sink, user_port.source

		self.fsm = fsm = FSM(reset_state="IDLE")
		self.submodules += fsm
		fsm.act("IDLE",
			self.done.eq(1),
			If(self.start,
				NextState("SEND_CMD")
			)
		)
		self.comb += [
			source.sop.eq(1),
			source.eop.eq(1),
			source.identify.eq(1),
		]
		fsm.act("SEND_CMD",
			source.stb.eq(1),
			If(source.stb & source.ack,
				NextState("WAIT_ACK")
			)
		)
		fsm.act("WAIT_ACK",
			If(sink.stb & sink.identify,
				NextState("RECEIVE_DATA")
			)
		)
		self.comb += fifo.sink.data.eq(sink.data)
		# Backpressure comes from the FIFO: a response beat is consumed only
		# when the FIFO can accept it.
		fsm.act("RECEIVE_DATA",
			sink.ack.eq(fifo.sink.ack),
			If(sink.stb,
				fifo.sink.stb.eq(1),
				# Leave only once the final beat has actually been accepted.
				# Otherwise (FIFO full) the beat stays pending on the port, and
				# IDLE never acks it. It could then be taken as part of the
				# next identify's response.
				If(sink.eop & fifo.sink.ack,
					NextState("IDLE")
				)
			)
		)

class LiteSATABISTIdentifyCSR(Module, AutoCSR):
	"""CSR front-end for LiteSATABISTIdentify.

	Software writes 1 to `start` and polls `done`. While `source_stb` reads
	1, `source_data` holds the current FIFO word. Writing 1 to `source_ack`
	acknowledges that word.
	"""
	def __init__(self, bist_identify):
		self._start = CSR()
		self._done = CSRStatus()
		self._source_stb = CSRStatus()
		self._source_ack = CSR()
		self._source_data = CSRStatus(32)

		###

		self.submodules += bist_identify

		# Write strobes become one-cycle pulses on the identify unit.
		fifo_out = bist_identify.source
		start_pulse = self._start.r & self._start.re
		ack_pulse = self._source_ack.r & self._source_ack.re

		control = [
			bist_identify.start.eq(start_pulse),
			self._done.status.eq(bist_identify.done),
		]
		readout = [
			self._source_stb.status.eq(fifo_out.stb),
			self._source_data.status.eq(fifo_out.data),
			fifo_out.ack.eq(ack_pulse),
		]
		self.comb += control + readout

class LiteSATABIST(Module, AutoCSR):
	"""Top-level BIST: a generator, a checker and an identify unit.

	Each unit gets its own port from `crossbar`, requested in the order
	generator, checker, identify. Each unit is registered as the submodule
	of the same name. With `with_csr=True` the generator and checker are
	wrapped in LiteSATABISTUnitCSR and the identify unit in
	LiteSATABISTIdentifyCSR, so they expose CSRs (this class inherits
	AutoCSR).
	"""
	def __init__(self, crossbar, with_csr=False):
		# NOTE(review): migen's tracer may derive signal names from these
		# local variable names; keep the one-assignment-per-line form.
		generator = LiteSATABISTGenerator(crossbar.get_port())
		checker = LiteSATABISTChecker(crossbar.get_port())
		identify = LiteSATABISTIdentify(crossbar.get_port())
		if with_csr:
			generator = LiteSATABISTUnitCSR(generator)
			checker = LiteSATABISTUnitCSR(checker)
			identify = LiteSATABISTIdentifyCSR(identify)
		self.submodules.generator = generator
		self.submodules.checker = checker
		self.submodules.identify = identify