diff --git a/litex/soc/interconnect/wishbonebridge.py b/litex/soc/interconnect/wishbonebridge.py index 51d84c8c7..7448733df 100644 --- a/litex/soc/interconnect/wishbonebridge.py +++ b/litex/soc/interconnect/wishbonebridge.py @@ -144,7 +144,7 @@ class WishboneStreamingBridge(Module): phy.sink.stb.eq(1), If(phy.sink.ack, byte_counter_ce.eq(1), - If(byte_counter.value == 3, + If(byte_counter == 3, word_counter_ce.eq(1), If(word_counter == (length-1), NextState("IDLE") diff --git a/litex/soc/tools/remote/__init__.py b/litex/soc/tools/remote/__init__.py new file mode 100644 index 000000000..f5a29bcb9 --- /dev/null +++ b/litex/soc/tools/remote/__init__.py @@ -0,0 +1,3 @@ +from litex.soc.tools.remote.comm_uart import CommUART +from litex.soc.tools.remote.server import RemoteServer +from litex.soc.tools.remote.client import RemoteClient diff --git a/litex/soc/tools/remote/client.py b/litex/soc/tools/remote/client.py new file mode 100644 index 000000000..2c1e64d16 --- /dev/null +++ b/litex/soc/tools/remote/client.py @@ -0,0 +1,66 @@ +import socket + +from litex.soc.tools.remote.etherbone import EtherbonePacket, EtherboneRecord +from litex.soc.tools.remote.etherbone import EtherboneReads, EtherboneWrites +from litex.soc.tools.remote.etherbone import EtherboneIPC +from litex.soc.tools.remote.csr_builder import CSRBuilder + + +class RemoteClient(EtherboneIPC, CSRBuilder): + def __init__(self, host="localhost", port=1234, csr_csv="csr.csv", csr_data_width=32, debug=False): + CSRBuilder.__init__(self, self, csr_csv, csr_data_width) + self.host = host + self.port = port + self.debug = debug + + def open(self): + if hasattr(self, "socket"): + return + self.socket = socket.create_connection((self.host, self.port), 5.0) + self.socket.settimeout(1.0) + + def close(self): + if not hasattr(self, "socket"): + return + self.socket.close() + del self.socket + + def read(self, addr, length=1): + # prepare packet + record = EtherboneRecord() + record.reads = EtherboneReads(addrs=[addr + 4*j for j in range(length)]) + record.rcount = len(record.reads) + + # send packet + packet = EtherbonePacket() + packet.records = [record] + packet.encode() + self.send_packet(self.socket, packet[:]) + + # receive response + packet = EtherbonePacket(self.receive_packet(self.socket)) + packet.decode() + datas = packet.records.pop().writes.get_datas() + if self.debug: + for i, data in enumerate(datas): + print("read {:08x} @ {:08x}".format(data, addr + 4*i)) + if length == 1: + return datas[0] + else: + return datas + + def write(self, addr, datas): + if not isinstance(datas, list): + datas = [datas] + record = EtherboneRecord() + record.writes = EtherboneWrites(base_addr=addr, datas=[d for d in datas]) + record.wcount = len(record.writes) + + packet = EtherbonePacket() + packet.records = [record] + packet.encode() + self.send_packet(self.socket, packet) + + if self.debug: + for i, data in enumerate(datas): + print("write {:08x} @ {:08x}".format(data, addr + 4*i)) diff --git a/litex/soc/tools/remote/comm_uart.py b/litex/soc/tools/remote/comm_uart.py new file mode 100644 index 000000000..5845d102a --- /dev/null +++ b/litex/soc/tools/remote/comm_uart.py @@ -0,0 +1,64 @@ +import serial +import struct + + +class CommUART: + msg_type = { + "write": 0x01, + "read": 0x02 + } + def __init__(self, port, baudrate=115200, debug=False): + self.port = port + self.baudrate = str(baudrate) + self.csr_data_width = None + self.debug = debug + self.port = serial.serial_for_url(port, baudrate) + + def open(self, csr_data_width): + self.csr_data_width = csr_data_width + if hasattr(self, "port"): + return + self.port.open() + + def close(self): + if not hasattr(self, "port"): + return + del self.port + + def _read(self, length): + r = bytes() + while len(r) < length: + r += self.port.read(length - len(r)) + return r + + def _write(self, data): + remaining = len(data) + pos = 0 + while remaining: + written = self.port.write(data[pos:]) + remaining -= written + pos += written + + def read(self, addr, length=None): + r = [] + length_int = 1 if length is None else length + self._write([self.msg_type["read"], length_int]) + self._write(list((addr//4).to_bytes(4, byteorder="big"))) + for i in range(length_int): + data = int.from_bytes(self._read(4), "big") + if self.debug: + print("read {:08x} @ {:08x}".format(data, addr + 4*i)) + if length is None: + return data + r.append(data) + return r + + def write(self, addr, data): + data = data if isinstance(data, list) else [data] + length = len(data) + self._write([self.msg_type["write"], length]) + self._write(list((addr//4).to_bytes(4, byteorder="big"))) + for i in range(len(data)): + self._write(list(data[i].to_bytes(4, byteorder="big"))) + if self.debug: + print("write {:08x} @ {:08x}".format(data[i], addr + 4*i)) diff --git a/litex/soc/tools/remote/csr_builder.py b/litex/soc/tools/remote/csr_builder.py new file mode 100644 index 000000000..942929152 --- /dev/null +++ b/litex/soc/tools/remote/csr_builder.py @@ -0,0 +1,84 @@ +import csv + + +class CSRElements: + def __init__(self, d): + self.d = d + + def __getattr__(self, attr): + try: + return self.__dict__['d'][attr] + except KeyError: + pass + raise KeyError("No such element " + attr) + + +class CSRRegister: + def __init__(self, readfn, writefn, name, addr, length, data_width, mode): + self.readfn = readfn + self.writefn = writefn + self.addr = addr + self.length = length + self.data_width = data_width + self.mode = mode + + def read(self): + if self.mode not in ["rw", "ro"]: + raise KeyError(name + "register not readable") + datas = self.readfn(self.addr, length=self.length) + if isinstance(datas, int): + return datas + else: + data = 0 + for i in range(self.length): + data = data << self.data_width + data |= datas[i] + return data + + def write(self, value): + if self.mode not in ["rw", "wo"]: + raise KeyError(name + "register not writable") + datas = [] + for i in range(self.length): + datas.append((value >> ((self.length-1-i)*self.data_width)) & (2**self.data_width-1)) + self.writefn(self.addr, datas) + + +class CSRBuilder: + def __init__(self, comm, csr_csv, csr_data_width): + self.csr_data_width = csr_data_width + self.constants = self.build_constants(csr_csv) + self.bases = self.build_bases(csr_csv) + self.regs = self.build_registers(csr_csv, comm.read, comm.write) + + def build_bases(self, csr_csv): + csv_reader = csv.reader(open(csr_csv), delimiter=',', quotechar='#') + d = {} + for item in csv_reader: + group, name, addr, dummy0, dummy1 = item + if group == "csr_base": + d[name] = int(addr.replace("0x", ""), 16) + return CSRElements(d) + + def build_registers(self, csr_csv, readfn, writefn): + csv_reader = csv.reader(open(csr_csv), delimiter=',', quotechar='#') + d = {} + for item in csv_reader: + group, name, addr, length, mode = item + if group == "csr_register": + addr = int(addr.replace("0x", ""), 16) + length = int(length) + d[name] = CSRRegister(readfn, writefn, name, addr, length, self.csr_data_width, mode) + return CSRElements(d) + + def build_constants(self, csr_csv): + csv_reader = csv.reader(open(csr_csv), delimiter=',', quotechar='#') + d = {} + for item in csv_reader: + group, name, value, dummy0, dummy1 = item + if group == "constant": + try: + d[name] = int(value) + except: + d[name] = value + return CSRElements(d) diff --git a/litex/soc/tools/remote/etherbone.py b/litex/soc/tools/remote/etherbone.py new file mode 100644 index 000000000..e119d162d --- /dev/null +++ b/litex/soc/tools/remote/etherbone.py @@ -0,0 +1,376 @@ +import math +from copy import deepcopy +import struct + +from litex.soc.interconnect.stream_packet import HeaderField, Header + +etherbone_magic = 0x4e6f +etherbone_version = 1 +etherbone_packet_header_length = 8 +etherbone_packet_header_fields = { + "magic": HeaderField(0, 0, 16), + + "version": HeaderField(2, 4, 4), + "nr": HeaderField(2, 2, 1), + "pr": HeaderField(2, 1, 1), + "pf": HeaderField(2, 0, 1), + + "addr_size": HeaderField(3, 4, 4), + "port_size": HeaderField(3, 0, 4) +} +etherbone_packet_header = Header(etherbone_packet_header_fields, + etherbone_packet_header_length, + swap_field_bytes=True) + +etherbone_record_header_length = 4 +etherbone_record_header_fields = { + "bca": HeaderField(0, 0, 1), + "rca": HeaderField(0, 1, 1), + "rff": HeaderField(0, 2, 1), + "cyc": HeaderField(0, 4, 1), + "wca": HeaderField(0, 5, 1), + "wff": HeaderField(0, 6, 1), + + "byte_enable": HeaderField(1, 0, 8), + + "wcount": HeaderField(2, 0, 8), + + "rcount": HeaderField(3, 0, 8) +} +etherbone_record_header = Header(etherbone_record_header_fields, + etherbone_record_header_length, + swap_field_bytes=True) + + +def split_bytes(v, n, endianness="big"): + r = [] + r_bytes = v.to_bytes(n, byteorder=endianness) + for byte in r_bytes: + r.append(int(byte)) + return r + + +def merge_bytes(b, endianness="big"): + return int.from_bytes(bytes(b), endianness) + + +def get_field_data(field, datas): + v = merge_bytes(datas[field.byte:field.byte+math.ceil(field.width/8)]) + return (v >> field.offset) & (2**field.width-1) + + +class Packet(list): + def __init__(self, init=[]): + self.ongoing = False + self.done = False + for data in init: + self.append(data) + + +class EtherboneWrite: + def __init__(self, data): + self.data = data + + def __repr__(self): + return "WR32 0x{:08x}".format(self.data) + + +class EtherboneRead: + def __init__(self, addr): + self.addr = addr + + def __repr__(self): + return "RD32 @ 0x{:08x}".format(self.addr) + + +class EtherboneWrites(Packet): + def __init__(self, init=[], base_addr=0, datas=[]): + Packet.__init__(self, init) + self.base_addr = base_addr + self.writes = [] + self.encoded = init != [] + for data in datas: + self.add(EtherboneWrite(data)) + + def add(self, write): + self.writes.append(write) + + def get_datas(self): + datas = [] + for write in self.writes: + datas.append(write.data) + return datas + + def encode(self): + if self.encoded: + raise ValueError + for byte in split_bytes(self.base_addr, 4): + self.append(byte) + for write in self.writes: + for byte in split_bytes(write.data, 4): + self.append(byte) + self.encoded = True + + def decode(self): + if not self.encoded: + raise ValueError + base_addr = [] + for i in range(4): + base_addr.append(self.pop(0)) + self.base_addr = merge_bytes(base_addr) + self.writes = [] + while len(self) != 0: + write = [] + for i in range(4): + write.append(self.pop(0)) + self.writes.append(EtherboneWrite(merge_bytes(write))) + self.encoded = False + + def __repr__(self): + r = "Writes\n" + r += "--------\n" + r += "BaseAddr @ 0x{:08x}\n".format(self.base_addr) + for write in self.writes: + r += write.__repr__() + "\n" + return r + + +class EtherboneReads(Packet): + def __init__(self, init=[], base_ret_addr=0, addrs=[]): + Packet.__init__(self, init) + self.base_ret_addr = base_ret_addr + self.reads = [] + self.encoded = init != [] + for addr in addrs: + self.add(EtherboneRead(addr)) + + def add(self, read): + self.reads.append(read) + + def get_addrs(self): + addrs = [] + for read in self.reads: + addrs.append(read.addr) + return addrs + + def encode(self): + if self.encoded: + raise ValueError + for byte in split_bytes(self.base_ret_addr, 4): + self.append(byte) + for read in self.reads: + for byte in split_bytes(read.addr, 4): + self.append(byte) + self.encoded = True + + def decode(self): + if not self.encoded: + raise ValueError + base_ret_addr = [] + for i in range(4): + base_ret_addr.append(self.pop(0)) + self.base_ret_addr = merge_bytes(base_ret_addr) + self.reads = [] + while len(self) != 0: + read = [] + for i in range(4): + read.append(self.pop(0)) + self.reads.append(EtherboneRead(merge_bytes(read))) + self.encoded = False + + def __repr__(self): + r = "Reads\n" + r += "--------\n" + r += "BaseRetAddr @ 0x{:08x}\n".format(self.base_ret_addr) + for read in self.reads: + r += read.__repr__() + "\n" + return r + + +class EtherboneRecord(Packet): + def __init__(self, init=[]): + Packet.__init__(self, init) + self.writes = None + self.reads = None + self.bca = 0 + self.rca = 0 + self.rff = 0 + self.cyc = 0 + self.wca = 0 + self.wff = 0 + self.byte_enable = 0xf + self.wcount = 0 + self.rcount = 0 + self.encoded = init != [] + + + def get_writes(self): + if self.wcount == 0: + return None + else: + writes = [] + for i in range((self.wcount+1)*4): + writes.append(self.pop(0)) + return EtherboneWrites(writes) + + def get_reads(self): + if self.rcount == 0: + return None + else: + reads = [] + for i in range((self.rcount+1)*4): + reads.append(self.pop(0)) + return EtherboneReads(reads) + + def decode(self): + if not self.encoded: + raise ValueError + header = [] + for byte in self[:etherbone_record_header.length]: + header.append(self.pop(0)) + for k, v in sorted(etherbone_record_header.fields.items()): + setattr(self, k, get_field_data(v, header)) + self.writes = self.get_writes() + if self.writes is not None: + self.writes.decode() + self.reads = self.get_reads() + if self.reads is not None: + self.reads.decode() + self.encoded = False + + def set_writes(self, writes): + self.wcount = len(writes.writes) + writes.encode() + for byte in writes: + self.append(byte) + + def set_reads(self, reads): + self.rcount = len(reads.reads) + reads.encode() + for byte in reads: + self.append(byte) + + def encode(self): + if self.encoded: + raise ValueError + if self.writes is not None: + self.set_writes(self.writes) + if self.reads is not None: + self.set_reads(self.reads) + header = 0 + for k, v in sorted(etherbone_record_header.fields.items()): + value = merge_bytes(split_bytes(getattr(self, k), + math.ceil(v.width/8)), + "little") + header += (value << v.offset+(v.byte*8)) + for d in split_bytes(header, etherbone_record_header.length): + self.insert(0, d) + self.encoded = True + + def __repr__(self, n=0): + r = "Record {}\n".format(n) + r += "--------\n" + if self.encoded: + for d in self: + r += "{:02x}".format(d) + else: + for k in sorted(etherbone_record_header.fields.keys()): + r += k + " : 0x{:0x}\n".format(getattr(self, k)) + if self.wcount != 0: + r += self.writes.__repr__() + if self.rcount != 0: + r += self.reads.__repr__() + return r + + +class EtherbonePacket(Packet): + def __init__(self, init=[]): + Packet.__init__(self, init) + self.encoded = init != [] + self.records = [] + + self.magic = etherbone_magic + self.version = etherbone_version + self.addr_size = 32//8 + self.port_size = 32//8 + self.nr = 0 + self.pr = 0 + self.pf = 0 + + def get_records(self): + records = [] + done = False + payload = self + while len(payload) != 0: + record = EtherboneRecord(payload) + record.decode() + records.append(deepcopy(record)) + payload = record + return records + + def decode(self): + if not self.encoded: + raise ValueError + header = [] + for byte in self[:etherbone_packet_header.length]: + header.append(self.pop(0)) + for k, v in sorted(etherbone_packet_header.fields.items()): + setattr(self, k, get_field_data(v, header)) + self.records = self.get_records() + self.encoded = False + + def set_records(self, records): + for record in records: + record.encode() + for byte in record: + self.append(byte) + + def encode(self): + if self.encoded: + raise ValueError + self.set_records(self.records) + header = 0 + for k, v in sorted(etherbone_packet_header.fields.items()): + value = merge_bytes(split_bytes(getattr(self, k), math.ceil(v.width/8)), "little") + header += (value << v.offset+(v.byte*8)) + for d in split_bytes(header, etherbone_packet_header.length): + self.insert(0, d) + self.encoded = True + + def __repr__(self): + r = "Packet\n" + r += "--------\n" + if self.encoded: + for d in self: + r += "{:02x}".format(d) + else: + for k in sorted(etherbone_packet_header.fields.keys()): + r += k + " : 0x{:0x}\n".format(getattr(self, k)) + for i, record in enumerate(self.records): + r += record.__repr__(i) + return r + + +class EtherboneIPC: + def send_packet(self, socket, packet): + socket.sendall(bytes(packet)) + + def receive_packet(self, socket): + header_length = etherbone_packet_header_length + etherbone_record_header_length + packet = bytes() + while len(packet) < header_length: + chunk = socket.recv(header_length - len(packet)) + if len(chunk) == 0: + return 0 + else: + packet += chunk + wcount, rcount = struct.unpack(">BB", packet[header_length-2:]) + counts = wcount + rcount + packet_size = header_length + 4*(counts + 1) + while len(packet) < packet_size: + chunk = socket.recv(packet_size - len(packet)) + if len(chunk) == 0: + return 0 + else: + packet += chunk + return packet diff --git a/litex/soc/tools/remote/server.py b/litex/soc/tools/remote/server.py new file mode 100644 index 000000000..5f6bb99f6 --- /dev/null +++ b/litex/soc/tools/remote/server.py @@ -0,0 +1,104 @@ +import socket +import threading +import argparse + +from litex.soc.tools.remote.etherbone import EtherbonePacket, EtherboneRecord, EtherboneWrites +from litex.soc.tools.remote.etherbone import EtherboneIPC + + +class RemoteServer(EtherboneIPC): + def __init__(self, comm, port=1234, csr_data_width=32): + self.comm = comm + self.port = port + self.csr_data_width = 32 + + def open(self): + if hasattr(self, "socket"): + return + self.socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + self.socket.bind(("localhost", self.port)) + self.socket.listen(1) + self.comm.open(self.csr_data_width) + + def close(self): + self.comm.close() + if not hasattr(self, "socket"): + return + self.socket.close() + del self.socket + + def _serve_thread(self): + while True: + client_socket, addr = self.socket.accept() + print("Connected with " + addr[0] + ":" + str(addr[1])) + try: + while True: + packet = self.receive_packet(client_socket) + if packet == 0: + break + packet = EtherbonePacket(packet) + packet.decode() + + record = packet.records.pop() + + # writes: + if record.writes != None: + self.comm.write(record.writes.base_addr, record.writes.get_datas()) + + # reads + if record.reads != None: + reads = [] + for addr in record.reads.get_addrs(): + reads.append(self.comm.read(addr)) + + record = EtherboneRecord() + record.writes = EtherboneWrites(datas=reads) + record.wcount = len(record.writes) + + packet = EtherbonePacket() + packet.records = [record] + packet.encode() + self.send_packet(client_socket, packet) + finally: + print("Disconnect") + client_socket.close() + + def start(self): + self.serve_thread = threading.Thread(target=self._serve_thread) + self.serve_thread.setDaemon(True) + self.serve_thread.start() + + def join(self, writer_only=False): + if not hasattr(self, "serve_thread"): + return + self.serve_thread.join() + +def _get_args(): + parser = argparse.ArgumentParser() + parser.add_argument("--comm", default="uart", help="comm interface") + parser.add_argument("--port", default="2", help="UART port") + parser.add_argument("--baudrate", default=115200, help="UART baudrate") + parser.add_argument("--csr_data_width", default=32, help="CSR data_width") + return parser.parse_args() + +def main(): + args = _get_args() + if args.comm == "uart": + from litex.soc.tools.remote import CommUART + port = args.port if not args.port.isdigit() else int(args.port) + comm = CommUART(args.port if not args.port.isdigit() else int(args.port), + args.baudrate, + debug=False) + else: + raise NotImplementedError + + server = RemoteServer(comm, csr_data_width=args.csr_data_width) + server.open() + server.start() + try: + server.join(True) + except KeyboardInterrupt: # FIXME + pass + +if __name__ == "__main__": + main() diff --git a/setup.py b/setup.py index f1bb2abe9..e10cf41ec 100755 --- a/setup.py +++ b/setup.py @@ -36,6 +36,8 @@ setup( "console_scripts": [ "flterm=litex.soc.tools.flterm:main", "mkmscimg=litex.soc.tools.mkmscimg:main", + "remote_server=litex.soc.tools.remote.server:main", + "remote_client=litex.soc.tools.remote.client:main" ], }, )