Merge pull request #1822 from trabucayre/rework_tools

tools: litex_server, litex_client and remote: don't hardcode bus address
This commit is contained in:
enjoy-digital 2023-10-30 17:13:46 +01:00 committed by GitHub
commit 4b9601bdab
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
5 changed files with 130 additions and 58 deletions

View file

@ -21,16 +21,21 @@ from litex.tools.remote.csr_builder import CSRBuilder
# Remote Client ------------------------------------------------------------------------------------
class RemoteClient(EtherboneIPC, CSRBuilder):
def __init__(self, host="localhost", port=1234, base_address=0, csr_csv=None, csr_data_width=None, debug=False):
def __init__(self, host="localhost", port=1234, base_address=0, csr_csv=None, csr_data_width=None,
csr_bus_address_width=None, debug=False):
# If csr_csv set to None and local csr.csv file exists, use it.
if csr_csv is None and os.path.exists("csr.csv"):
csr_csv = "csr.csv"
# If valid csr_csv file found, build the CSRs.
if csr_csv is not None:
CSRBuilder.__init__(self, self, csr_csv, csr_data_width)
# Else if csr_data_width set to None, force to csr_data_width 32-bit.
elif csr_data_width is None:
csr_data_width = 32
else:
# Else if csr_data_width set to None, force to csr_data_width 32-bit.
if csr_data_width is None:
csr_data_width = 32
# Else if csr_bus_address_width set to None, force to csr_bus_address_width 32-bit.
if self.csr_bus_address_width is None:
self.csr_bus_address_width = 32
self.host = host
self.port = port
self.debug = debug
@ -61,20 +66,27 @@ class RemoteClient(EtherboneIPC, CSRBuilder):
def read(self, addr, length=None, burst="incr"):
length_int = 1 if length is None else length
addr_size = self.csr_bus_address_width // 8
# Prepare packet
record = EtherboneRecord()
record = EtherboneRecord(addr_size)
incr = (burst == "incr")
record.reads = EtherboneReads(addrs=[self.base_address + addr + 4*incr*j for j in range(length_int)])
record.reads = EtherboneReads(
addr_size = addr_size,
addrs = [self.base_address + addr + 4*incr*j for j in range(length_int)]
)
record.rcount = len(record.reads)
# Send packet
packet = EtherbonePacket()
packet = EtherbonePacket(self.csr_bus_address_width)
packet.records = [record]
packet.encode()
self.send_packet(self.socket, packet)
# Receive response
packet = EtherbonePacket(self.receive_packet(self.socket))
packet = EtherbonePacket(
addr_width = self.csr_bus_address_width,
init = self.receive_packet(self.socket, addr_size)
)
packet.decode()
datas = packet.records.pop().writes.get_datas()
if self.debug:
@ -84,11 +96,16 @@ class RemoteClient(EtherboneIPC, CSRBuilder):
def write(self, addr, datas):
datas = datas if isinstance(datas, list) else [datas]
record = EtherboneRecord()
record.writes = EtherboneWrites(base_addr=self.base_address + addr, datas=[d for d in datas])
addr_size = self.csr_bus_address_width // 8
record = EtherboneRecord(addr_size)
record.writes = EtherboneWrites(
base_addr = self.base_address + addr,
addr_size = addr_size,
datas = [d for d in datas]
)
record.wcount = len(record.writes)
packet = EtherbonePacket()
packet = EtherbonePacket(self.csr_bus_address_width)
packet.records = [record]
packet.encode()
self.send_packet(self.socket, packet)

View file

@ -70,11 +70,12 @@ def _read_merger(addrs, max_length=256, bursts=["incr", "fixed"]):
# Remote Server ------------------------------------------------------------------------------------
class RemoteServer(EtherboneIPC):
def __init__(self, comm, bind_ip, bind_port=1234):
self.comm = comm
self.bind_ip = bind_ip
self.bind_port = bind_port
self.lock = False
def __init__(self, comm, bind_ip, bind_port=1234, addr_width=32):
self.comm = comm
self.bind_ip = bind_ip
self.bind_port = bind_port
self.lock = False
self.addr_width = addr_width
def open(self):
if hasattr(self, "socket"):
@ -115,14 +116,13 @@ class RemoteServer(EtherboneIPC):
while True:
# Receive packet.
try:
packet = self.receive_packet(client_socket)
packet = self.receive_packet(client_socket, self.addr_width // 8)
if packet == 0:
break
except:
break
# Decode Packet.
packet = EtherbonePacket(packet)
packet = EtherbonePacket(self.addr_width, packet)
packet.decode()
# Get Packet's Record.
@ -152,11 +152,12 @@ class RemoteServer(EtherboneIPC):
bursts = bursts):
reads += self.comm.read(addr, length, burst)
record = EtherboneRecord()
record.writes = EtherboneWrites(datas=reads)
addr_size = self.addr_width // 8
record = EtherboneRecord(addr_size)
record.writes = EtherboneWrites(addr_size=addr_size, datas=reads)
record.wcount = len(record.writes)
packet = EtherbonePacket()
packet = EtherbonePacket(self.addr_width)
packet.records = [record]
packet.encode()
self.send_packet(client_socket, packet)
@ -181,6 +182,7 @@ def main():
# Common arguments
parser.add_argument("--bind-ip", default="localhost", help="Host bind address.")
parser.add_argument("--bind-port", default=1234, help="Host bind port.")
parser.add_argument("--addr-width", default=32, help="bus address width.")
parser.add_argument("--debug", action="store_true", help="Enable debug.")
# UART arguments
@ -220,7 +222,7 @@ def main():
uart_port = args.uart_port
uart_baudrate = int(float(args.uart_baudrate))
print("[CommUART] port: {} / baudrate: {} / ".format(uart_port, uart_baudrate), end="")
comm = CommUART(uart_port, uart_baudrate, debug=args.debug)
comm = CommUART(uart_port, uart_baudrate, debug=args.debug, addr_width=int(args.addr_width))
# JTAG mode
elif args.jtag:
@ -279,7 +281,7 @@ def main():
parser.print_help()
exit()
server = RemoteServer(comm, args.bind_ip, int(args.bind_port))
server = RemoteServer(comm, args.bind_ip, int(args.bind_port), addr_width=int(args.addr_width))
server.open()
server.start(4)
try:

View file

@ -19,11 +19,12 @@ CMD_READ_BURST_FIXED = 0x04
# CommUART -----------------------------------------------------------------------------------------
class CommUART(CSRBuilder):
def __init__(self, port, baudrate=115200, csr_csv=None, debug=False):
def __init__(self, port, baudrate=115200, csr_csv=None, debug=False, addr_width=32):
CSRBuilder.__init__(self, comm=self, csr_csv=csr_csv)
self.port = serial.serial_for_url(port, baudrate)
self.baudrate = str(baudrate)
self.debug = debug
self.port = serial.serial_for_url(port, baudrate)
self.baudrate = str(baudrate)
self.debug = debug
self.addr_bytes = addr_width // 8
def open(self):
if hasattr(self, "port"):
@ -63,7 +64,7 @@ class CommUART(CSRBuilder):
"fixed": CMD_READ_BURST_FIXED,
}[burst]
self._write([cmd, length_int])
self._write(list((addr//4).to_bytes(4, byteorder="big")))
self._write(list((addr//4).to_bytes(self.addr_bytes, byteorder="big")))
for i in range(length_int):
value = int.from_bytes(self._read(4), "big")
if self.debug:
@ -85,7 +86,7 @@ class CommUART(CSRBuilder):
"fixed": CMD_WRITE_BURST_FIXED,
}[burst]
self._write([cmd, size])
self._write(list(((addr//4 + offset).to_bytes(4, byteorder="big"))))
self._write(list(((addr//4 + offset).to_bytes(self.addr_bytes, byteorder="big"))))
for i, value in enumerate(data[offset:offset+size]):
self._write(list(value.to_bytes(4, byteorder="big")))
if self.debug:

View file

@ -64,7 +64,7 @@ class CSRMemoryRegion:
# CSR Builder --------------------------------------------------------------------------------------
class CSRBuilder:
def __init__(self, comm, csr_csv, csr_data_width=None):
def __init__(self, comm, csr_csv, csr_data_width=None, csr_bus_address_width=None):
if csr_csv is not None:
self.items = self.get_csr_items(csr_csv)
self.constants = self.build_constants()
@ -79,7 +79,18 @@ class CSRBuilder:
raise KeyError("csr_data_width of {} provided but {} found in constants".format(
csr_data_width, constant_csr_data_width))
self.csr_data_width = csr_data_width
# Load csr_data_width from the constants, otherwise it must be provided
constant_csr_bus_address_width = self.constants.d.get("config_bus_address_width", None)
if csr_bus_address_width is None:
csr_bus_address_width = constant_csr_bus_address_width
if csr_bus_address_width is None:
raise KeyError("csr_bus_address_width not found in constants, please provide!")
if csr_bus_address_width != constant_csr_bus_address_width:
raise KeyError("csr_bus_address_width of {} provided but {} found in constants".format(
csr_bus_address_width, constant_csr_bus_address_width))
self.csr_data_width = csr_data_width
self.csr_bus_address_width = csr_bus_address_width
self.bases = self.build_bases()
self.regs = self.build_registers(comm.read, comm.write)
self.mems = self.build_memories()

View file

@ -59,6 +59,8 @@ def get_field_data(field, datas):
pack_to_uint32 = struct.Struct('>I').pack
unpack_uint32_from = struct.Struct('>I').unpack
pack_to_uint64 = struct.Struct('>Q').pack
unpack_uint64_from = struct.Struct('>Q').unpack
# Packet -------------------------------------------------------------------------------------------
@ -88,13 +90,15 @@ class EtherboneRead:
# Etherbone Writes ---------------------------------------------------------------------------------
class EtherboneWrites(Packet):
def __init__(self, init=[], base_addr=0, datas=[]):
def __init__(self, addr_size, init=[], base_addr=0, datas=[]):
if isinstance(datas, list) and len(datas) > 255:
raise ValueError(f"Burst size of {len(datas)} exceeds maximum of 255 allowed by Etherbone.")
assert addr_size in [1, 2, 4, 8]
Packet.__init__(self, init)
self.base_addr = base_addr
self.writes = []
self.encoded = init != []
self.addr_size = addr_size
for data in datas:
self.add(EtherboneWrite(data))
@ -111,7 +115,10 @@ class EtherboneWrites(Packet):
if self.encoded:
raise ValueError
ba = bytearray()
ba += pack_to_uint32(self.base_addr)
if self.addr_size == 4:
ba += pack_to_uint32(self.base_addr)
else:
ba += pack_to_uint64(self.base_addr)
for write in self.writes:
ba += pack_to_uint32(write.data)
self.bytes = ba
@ -121,9 +128,12 @@ class EtherboneWrites(Packet):
if not self.encoded:
raise ValueError
ba = self.bytes
self.base_addr = unpack_uint32_from(ba[:4])[0]
if self.addr_size == 4:
self.base_addr = unpack_uint32_from(ba[:self.addr_size])[0]
else:
self.base_addr = unpack_uint64_from(ba[:self.addr_size])[0]
writes = []
offset = 4
offset = self.addr_size
length = len(ba)
while length > offset:
writes.append(EtherboneWrite(unpack_uint32_from(ba[offset:offset+4])[0]))
@ -142,13 +152,15 @@ class EtherboneWrites(Packet):
# Etherbone Reads ----------------------------------------------------------------------------------
class EtherboneReads(Packet):
def __init__(self, init=[], base_ret_addr=0, addrs=[]):
def __init__(self, addr_size, init=[], base_ret_addr=0, addrs=[]):
if isinstance(addrs, list) and len(addrs) > 255:
raise ValueError(f"Burst size of {len(addrs)} exceeds maximum of 255 allowed by Etherbone.")
assert addr_size in [1, 2, 4, 8]
Packet.__init__(self, init)
self.base_ret_addr = base_ret_addr
self.reads = []
self.encoded = init != []
self.reads = []
self.encoded = init != []
self.addr_size = addr_size
for addr in addrs:
self.add(EtherboneRead(addr))
@ -165,9 +177,15 @@ class EtherboneReads(Packet):
if self.encoded:
raise ValueError
ba = bytearray()
ba += pack_to_uint32(self.base_ret_addr)
if (self.addr_size == 4):
ba += pack_to_uint32(self.base_ret_addr)
else:
ba += pack_to_uint64(self.base_ret_addr)
for read in self.reads:
ba += pack_to_uint32(read.addr)
if self.addr_size == 4:
ba += pack_to_uint32(read.addr)
else:
ba += pack_to_uint64(read.addr)
self.bytes = ba
self.encoded = True
@ -175,13 +193,20 @@ class EtherboneReads(Packet):
if not self.encoded:
raise ValueError
ba = self.bytes
base_ret_addr = unpack_uint32_from(ba[:4])[0]
if self.addr_size == 4:
base_ret_addr = unpack_uint32_from(ba[:self.addr_size])[0]
else:
base_ret_addr = unpack_uint64_from(ba[:self.addr_size])[0]
reads = []
offset = 4
offset = self.addr_size
length = len(ba)
while length > offset:
reads.append(EtherboneRead(unpack_uint32_from(ba[offset:offset+4])[0]))
offset += 4
v = ba[offset:offset+self.addr_size]
if self.addr_size == 4:
reads.append(EtherboneRead(unpack_uint32_from(v)[0]))
else:
reads.append(EtherboneRead(unpack_uint64_from(v)[0]))
offset += self.addr_size
self.reads = reads
self.encoded = False
@ -196,7 +221,9 @@ class EtherboneReads(Packet):
# Etherbone Record ---------------------------------------------------------------------------------
class EtherboneRecord(Packet):
def __init__(self, init=[]):
def __init__(self, addr_size, init=[]):
assert addr_size in [1, 2, 4, 8]
Packet.__init__(self, init)
self.writes = None
self.reads = None
@ -210,6 +237,7 @@ class EtherboneRecord(Packet):
self.wcount = 0
self.rcount = 0
self.encoded = init != []
self.addr_size = addr_size
def decode(self):
if not self.encoded:
@ -223,14 +251,20 @@ class EtherboneRecord(Packet):
# Decode writes
if self.wcount:
self.writes = EtherboneWrites(self.bytes[offset:offset + 4*(self.wcount+1)])
offset += 4*(self.wcount+1)
init_length = (4 * self.wcount) + (self.addr_size)
self.writes = EtherboneWrites(
addr_size = self.addr_size,
init = self.bytes[offset:offset + init_length])
offset += init_length
self.writes.decode()
# Decode reads
if self.rcount:
self.reads = EtherboneReads(self.bytes[offset:offset + 4*(self.rcount+1)])
offset += 4*(self.rcount+1)
init_length = (self.rcount + 1) * self.addr_size
self.reads = EtherboneReads(
addr_size = self.addr_size,
init = self.bytes[offset:offset + init_length])
offset += init_length
self.reads.decode()
self.encoded = False
@ -283,15 +317,17 @@ class EtherboneRecord(Packet):
# Etherbone Packet ---------------------------------------------------------------------------------
class EtherbonePacket(Packet):
def __init__(self, init=[]):
def __init__(self, addr_width, init=[]):
assert addr_width in [8, 16, 32, 64]
Packet.__init__(self, init)
self.encoded = init != []
self.records = []
self.magic = etherbone_magic
self.version = etherbone_version
self.addr_size = 32//8
self.port_size = 32//8
self.addr_size = addr_width//8
self.port_size = 4 # FIXME: use data_size
self.nr = 0
self.pr = 0
self.pf = 0
@ -311,14 +347,14 @@ class EtherbonePacket(Packet):
# Decode records
length = len(ba)
while length > offset:
record = EtherboneRecord(ba[offset:])
record = EtherboneRecord(addr_size=self.addr_size, init=ba[offset:])
record.decode()
self.records.append(record)
offset += etherbone_record_header.length
if record.wcount:
offset += 4*(record.wcount + 1)
offset += (4*record.wcount) + self.addr_size
if record.rcount:
offset += 4*(record.rcount + 1)
offset += (record.rcount + 1) * self.addr_size
self.encoded = False
@ -362,7 +398,8 @@ class EtherboneIPC:
def send_packet(self, socket, packet):
socket.sendall(packet.bytes)
def receive_packet(self, socket):
def receive_packet(self, socket, addr_size):
assert addr_size in [1, 2, 4, 8]
header_length = etherbone_packet_header_length + etherbone_record_header_length
packet = bytes()
while len(packet) < header_length:
@ -373,7 +410,11 @@ class EtherboneIPC:
packet += chunk
wcount, rcount = struct.unpack(">BB", packet[header_length-2:])
counts = wcount + rcount
packet_size = header_length + 4*(counts + 1)
packet_size = header_length
if wcount != 0:
packet_size += 4 * (wcount ) + addr_size
if rcount != 0:
packet_size += (rcount + 1 ) * addr_size
while len(packet) < packet_size:
chunk = socket.recv(packet_size - len(packet))
if len(chunk) == 0: