import serial
from struct import *
from misoclib.tools.litescope.host.driver.reg import *

def write_b(uart, data):
	"""Send `data` to `uart` as a single unsigned byte.

	`uart` only needs a `write` method that accepts bytes. `data` is packed
	with the struct format 'B', so values outside 0..255 raise `struct.error`.
	"""
	encoded = pack('B', data)
	uart.write(encoded)

class LiteScopeUARTDriver:
	"""Host-side driver issuing 32-bit bus reads and writes over a serial port.

	Every transfer starts with a 6-byte header: the command byte (see `cmds`),
	a burst length byte and the word address (`addr // 4`) as a big-endian
	32-bit value. Data words follow as big-endian 32-bit values: written by
	the host for a "write" command, read back from the port for a "read".

	When `addrmap` is given, a register map is built with `build_map` and
	stored in `self.regs`; otherwise `self.regs` is not set.
	"""
	cmds = {
		"write"	: 0x01,
		"read"	: 0x02
	}
	def __init__(self, port, baudrate=115200, addrmap=None, busword=8, debug=False):
		"""Create the serial port object and, optionally, the register map.

		port     -- serial port identifier handed to `serial.Serial`.
		baudrate -- line speed; kept as a string in `self.baudrate`.
		addrmap  -- optional address map passed to `build_map`.
		busword  -- bus word width passed to `build_map`.
		debug    -- when true, every transferred word is printed.
		"""
		self.port = port
		self.baudrate = str(baudrate)
		self.debug = debug
		self.uart = serial.Serial(port, baudrate, timeout=0.25)
		if addrmap is not None:
			self.regs = build_map(addrmap, busword, self.read, self.write)

	def _select_uart2wb(self, value):
		"""Write `value` to the `uart2wb_sel` register if there is one."""
		try:
			reg = self.regs.uart2wb_sel
		except (AttributeError, KeyError):
			# No register map was built, or it has no such register; the
			# map may report a missing register with either exception.
			return
		reg.write(value)

	def _header(self, cmd, burst_length, word_addr):
		"""Return the 6-byte frame header for command `cmd`.

		Raises `struct.error` when `burst_length` does not fit in one
		unsigned byte, before anything has been sent.
		"""
		return pack(">BBI", self.cmds[cmd], burst_length, word_addr & 0xffffffff)

	def _read_word(self):
		"""Read one big-endian 32-bit word, raising IOError on a short read."""
		raw = self.uart.read(4)
		if len(raw) != 4:
			raise IOError("UART read timed out: got %d of 4 bytes" % len(raw))
		return unpack(">I", raw)[0]

	def open(self):
		"""Re-open the port with empty buffers and select the bridge."""
		self.uart.flushOutput()
		self.uart.close()
		self.uart.open()
		self.uart.flushInput()
		self._select_uart2wb(1)

	def close(self):
		"""Deselect the bridge, then flush and close the port."""
		try:
			self._select_uart2wb(0)
		finally:
			# The port is released even if the register write failed.
			self.uart.flushOutput()
			self.uart.close()

	def read(self, addr, burst_length=None, repeats=None):
		"""Read words starting at byte address `addr`.

		burst_length -- words per read command; None means 1.
		repeats      -- how many times the same command is issued; None
		                means 1.

		Returns a flat list of `repeats * burst_length` integers.
		Raises IOError when fewer than four bytes arrive for a word.
		"""
		if burst_length is None:
			burst_length = 1
		if repeats is None:
			repeats = 1
		word_addr = addr // 4
		header = self._header("read", burst_length, word_addr)
		datas = []
		for _ in range(repeats):
			# Drop stale input so the reply lines up with this command.
			self.uart.flushInput()
			self.uart.write(header)
			for j in range(burst_length):
				data = self._read_word()
				if self.debug:
					# Print the byte address of the word actually accessed.
					print("RD %08X @ %08X" %(data, (word_addr + j)*4))
				datas.append(data)
		return datas

	def write(self, addr, data):
		"""Write `data` starting at byte address `addr`.

		`data` is either one integer or a list/tuple of integers sent as a
		single burst. Each value is truncated to its low 32 bits. The whole
		frame is built before it is sent, so invalid input raises without
		leaving a partial command on the wire.
		"""
		if isinstance(data, (list, tuple)):
			words = list(data)
		else:
			words = [data]
		word_addr = addr // 4
		frame = self._header("write", len(words), word_addr)
		frame += b"".join(pack(">I", word & 0xffffffff) for word in words)
		self.uart.write(frame)
		if self.debug:
			for i, word in enumerate(words):
				print("WR %08X @ %08X" %(word, (word_addr + i)*4))