soc/tools: initialize wishbone remote control (for now only uart)

This commit is contained in:
Florent Kermarrec 2015-11-16 17:46:36 +01:00
parent 1cde84dccf
commit 71483b8935
8 changed files with 700 additions and 1 deletions

View file

@ -144,7 +144,7 @@ class WishboneStreamingBridge(Module):
phy.sink.stb.eq(1),
If(phy.sink.ack,
byte_counter_ce.eq(1),
If(byte_counter.value == 3,
If(byte_counter == 3,
word_counter_ce.eq(1),
If(word_counter == (length-1),
NextState("IDLE")

View file

@ -0,0 +1,3 @@
from litex.soc.tools.remote.comm_uart import CommUART
from litex.soc.tools.remote.server import RemoteServer
from litex.soc.tools.remote.client import RemoteClient

View file

@ -0,0 +1,66 @@
import socket
from litex.soc.tools.remote.etherbone import EtherbonePacket, EtherboneRecord
from litex.soc.tools.remote.etherbone import EtherboneReads, EtherboneWrites
from litex.soc.tools.remote.etherbone import EtherboneIPC
from litex.soc.tools.remote.csr_builder import CSRBuilder
class RemoteClient(EtherboneIPC, CSRBuilder):
def __init__(self, host="localhost", port=1234, csr_csv="csr.csv", csr_data_width=32, debug=False):
CSRBuilder.__init__(self, self, csr_csv, csr_data_width)
self.host = host
self.port = port
self.debug = debug
def open(self):
if hasattr(self, "socket"):
return
self.socket = socket.create_connection((self.host, self.port), 5.0)
self.socket.settimeout(1.0)
def close(self):
if not hasattr(self, "socket"):
return
self.socket.close()
del self.socket
def read(self, addr, length=1):
# prepare packet
record = EtherboneRecord()
record.reads = EtherboneReads(addrs=[addr + 4*j for j in range(length)])
record.rcount = len(record.reads)
# send packet
packet = EtherbonePacket()
packet.records = [record]
packet.encode()
self.send_packet(self.socket, packet[:])
# receive response
packet = EtherbonePacket(self.receive_packet(self.socket))
packet.decode()
datas = packet.records.pop().writes.get_datas()
if self.debug:
for i, data in enumerate(datas):
print("read {:08x} @ {:08x}".format(data, addr + 4*i))
if length == 1:
return datas[0]
else:
return datas
def write(self, addr, datas):
if not isinstance(datas, list):
datas = [datas]
record = EtherboneRecord()
record.writes = EtherboneWrites(base_addr=addr, datas=[d for d in datas])
record.wcount = len(record.writes)
packet = EtherbonePacket()
packet.records = [record]
packet.encode()
self.send_packet(self.socket, packet)
if self.debug:
for i, data in enumerate(datas):
print("write {:08x} @ {:08x}".format(data, addr + 4*i))

View file

@ -0,0 +1,64 @@
import serial
import struct
class CommUART:
msg_type = {
"write": 0x01,
"read": 0x02
}
def __init__(self, port, baudrate=115200, debug=False):
self.port = port
self.baudrate = str(baudrate)
self.csr_data_width = None
self.debug = debug
self.port = serial.serial_for_url(port, baudrate)
def open(self, csr_data_width):
self.csr_data_width = csr_data_width
if hasattr(self, "port"):
return
self.port.open()
def close(self):
if not hasattr(self, "port"):
return
del self.port
def _read(self, length):
r = bytes()
while len(r) < length:
r += self.port.read(length - len(r))
return r
def _write(self, data):
remaining = len(data)
pos = 0
while remaining:
written = self.port.write(data[pos:])
remaining -= written
pos += written
def read(self, addr, length=None):
r = []
length_int = 1 if length is None else length
self._write([self.msg_type["read"], length_int])
self._write(list((addr//4).to_bytes(4, byteorder="big")))
for i in range(length_int):
data = int.from_bytes(self._read(4), "big")
if self.debug:
print("read {:08x} @ {:08x}".format(data, addr + 4*i))
if length is None:
return data
r.append(data)
return r
def write(self, addr, data):
data = data if isinstance(data, list) else [data]
length = len(data)
self._write([self.msg_type["write"], length])
self._write(list((addr//4).to_bytes(4, byteorder="big")))
for i in range(len(data)):
self._write(list(data[i].to_bytes(4, byteorder="big")))
if self.debug:
print("write {:08x} @ {:08x}".format(data[i], addr + 4*i))

View file

@ -0,0 +1,84 @@
import csv
class CSRElements:
def __init__(self, d):
self.d = d
def __getattr__(self, attr):
try:
return self.__dict__['d'][attr]
except KeyError:
pass
raise KeyError("No such element " + attr)
class CSRRegister:
def __init__(self, readfn, writefn, name, addr, length, data_width, mode):
self.readfn = readfn
self.writefn = writefn
self.addr = addr
self.length = length
self.data_width = data_width
self.mode = mode
def read(self):
if self.mode not in ["rw", "ro"]:
raise KeyError(name + "register not readable")
datas = self.readfn(self.addr, length=self.length)
if isinstance(datas, int):
return datas
else:
data = 0
for i in range(self.length):
data = data << self.data_width
data |= datas[i]
return data
def write(self, value):
if self.mode not in ["rw", "wo"]:
raise KeyError(name + "register not writable")
datas = []
for i in range(self.length):
datas.append((value >> ((self.length-1-i)*self.data_width)) & (2**self.data_width-1))
self.writefn(self.addr, datas)
class CSRBuilder:
def __init__(self, comm, csr_csv, csr_data_width):
self.csr_data_width = csr_data_width
self.constants = self.build_constants(csr_csv)
self.bases = self.build_bases(csr_csv)
self.regs = self.build_registers(csr_csv, comm.read, comm.write)
def build_bases(self, csr_csv):
csv_reader = csv.reader(open(csr_csv), delimiter=',', quotechar='#')
d = {}
for item in csv_reader:
group, name, addr, dummy0, dummy1 = item
if group == "csr_base":
d[name] = int(addr.replace("0x", ""), 16)
return CSRElements(d)
def build_registers(self, csr_csv, readfn, writefn):
csv_reader = csv.reader(open(csr_csv), delimiter=',', quotechar='#')
d = {}
for item in csv_reader:
group, name, addr, length, mode = item
if group == "csr_register":
addr = int(addr.replace("0x", ""), 16)
length = int(length)
d[name] = CSRRegister(readfn, writefn, name, addr, length, self.csr_data_width, mode)
return CSRElements(d)
def build_constants(self, csr_csv):
csv_reader = csv.reader(open(csr_csv), delimiter=',', quotechar='#')
d = {}
for item in csv_reader:
group, name, value, dummy0, dummy1 = item
if group == "constant":
try:
d[name] = int(value)
except:
d[name] = value
return CSRElements(d)

View file

@ -0,0 +1,376 @@
import math
from copy import deepcopy
import struct
from litex.soc.interconnect.stream_packet import HeaderField, Header
etherbone_magic = 0x4e6f
etherbone_version = 1
etherbone_packet_header_length = 8
etherbone_packet_header_fields = {
"magic": HeaderField(0, 0, 16),
"version": HeaderField(2, 4, 4),
"nr": HeaderField(2, 2, 1),
"pr": HeaderField(2, 1, 1),
"pf": HeaderField(2, 0, 1),
"addr_size": HeaderField(3, 4, 4),
"port_size": HeaderField(3, 0, 4)
}
etherbone_packet_header = Header(etherbone_packet_header_fields,
etherbone_packet_header_length,
swap_field_bytes=True)
etherbone_record_header_length = 4
etherbone_record_header_fields = {
"bca": HeaderField(0, 0, 1),
"rca": HeaderField(0, 1, 1),
"rff": HeaderField(0, 2, 1),
"cyc": HeaderField(0, 4, 1),
"wca": HeaderField(0, 5, 1),
"wff": HeaderField(0, 6, 1),
"byte_enable": HeaderField(1, 0, 8),
"wcount": HeaderField(2, 0, 8),
"rcount": HeaderField(3, 0, 8)
}
etherbone_record_header = Header(etherbone_record_header_fields,
etherbone_record_header_length,
swap_field_bytes=True)
def split_bytes(v, n, endianness="big"):
r = []
r_bytes = v.to_bytes(n, byteorder=endianness)
for byte in r_bytes:
r.append(int(byte))
return r
def merge_bytes(b, endianness="big"):
return int.from_bytes(bytes(b), endianness)
def get_field_data(field, datas):
v = merge_bytes(datas[field.byte:field.byte+math.ceil(field.width/8)])
return (v >> field.offset) & (2**field.width-1)
class Packet(list):
def __init__(self, init=[]):
self.ongoing = False
self.done = False
for data in init:
self.append(data)
class EtherboneWrite:
def __init__(self, data):
self.data = data
def __repr__(self):
return "WR32 0x{:08x}".format(self.data)
class EtherboneRead:
def __init__(self, addr):
self.addr = addr
def __repr__(self):
return "RD32 @ 0x{:08x}".format(self.addr)
class EtherboneWrites(Packet):
def __init__(self, init=[], base_addr=0, datas=[]):
Packet.__init__(self, init)
self.base_addr = base_addr
self.writes = []
self.encoded = init != []
for data in datas:
self.add(EtherboneWrite(data))
def add(self, write):
self.writes.append(write)
def get_datas(self):
datas = []
for write in self.writes:
datas.append(write.data)
return datas
def encode(self):
if self.encoded:
raise ValueError
for byte in split_bytes(self.base_addr, 4):
self.append(byte)
for write in self.writes:
for byte in split_bytes(write.data, 4):
self.append(byte)
self.encoded = True
def decode(self):
if not self.encoded:
raise ValueError
base_addr = []
for i in range(4):
base_addr.append(self.pop(0))
self.base_addr = merge_bytes(base_addr)
self.writes = []
while len(self) != 0:
write = []
for i in range(4):
write.append(self.pop(0))
self.writes.append(EtherboneWrite(merge_bytes(write)))
self.encoded = False
def __repr__(self):
r = "Writes\n"
r += "--------\n"
r += "BaseAddr @ 0x{:08x}\n".format(self.base_addr)
for write in self.writes:
r += write.__repr__() + "\n"
return r
class EtherboneReads(Packet):
def __init__(self, init=[], base_ret_addr=0, addrs=[]):
Packet.__init__(self, init)
self.base_ret_addr = base_ret_addr
self.reads = []
self.encoded = init != []
for addr in addrs:
self.add(EtherboneRead(addr))
def add(self, read):
self.reads.append(read)
def get_addrs(self):
addrs = []
for read in self.reads:
addrs.append(read.addr)
return addrs
def encode(self):
if self.encoded:
raise ValueError
for byte in split_bytes(self.base_ret_addr, 4):
self.append(byte)
for read in self.reads:
for byte in split_bytes(read.addr, 4):
self.append(byte)
self.encoded = True
def decode(self):
if not self.encoded:
raise ValueError
base_ret_addr = []
for i in range(4):
base_ret_addr.append(self.pop(0))
self.base_ret_addr = merge_bytes(base_ret_addr)
self.reads = []
while len(self) != 0:
read = []
for i in range(4):
read.append(self.pop(0))
self.reads.append(EtherboneRead(merge_bytes(read)))
self.encoded = False
def __repr__(self):
r = "Reads\n"
r += "--------\n"
r += "BaseRetAddr @ 0x{:08x}\n".format(self.base_ret_addr)
for read in self.reads:
r += read.__repr__() + "\n"
return r
class EtherboneRecord(Packet):
def __init__(self, init=[]):
Packet.__init__(self, init)
self.writes = None
self.reads = None
self.bca = 0
self.rca = 0
self.rff = 0
self.cyc = 0
self.wca = 0
self.wff = 0
self.byte_enable = 0xf
self.wcount = 0
self.rcount = 0
self.encoded = init != []
def get_writes(self):
if self.wcount == 0:
return None
else:
writes = []
for i in range((self.wcount+1)*4):
writes.append(self.pop(0))
return EtherboneWrites(writes)
def get_reads(self):
if self.rcount == 0:
return None
else:
reads = []
for i in range((self.rcount+1)*4):
reads.append(self.pop(0))
return EtherboneReads(reads)
def decode(self):
if not self.encoded:
raise ValueError
header = []
for byte in self[:etherbone_record_header.length]:
header.append(self.pop(0))
for k, v in sorted(etherbone_record_header.fields.items()):
setattr(self, k, get_field_data(v, header))
self.writes = self.get_writes()
if self.writes is not None:
self.writes.decode()
self.reads = self.get_reads()
if self.reads is not None:
self.reads.decode()
self.encoded = False
def set_writes(self, writes):
self.wcount = len(writes.writes)
writes.encode()
for byte in writes:
self.append(byte)
def set_reads(self, reads):
self.rcount = len(reads.reads)
reads.encode()
for byte in reads:
self.append(byte)
def encode(self):
if self.encoded:
raise ValueError
if self.writes is not None:
self.set_writes(self.writes)
if self.reads is not None:
self.set_reads(self.reads)
header = 0
for k, v in sorted(etherbone_record_header.fields.items()):
value = merge_bytes(split_bytes(getattr(self, k),
math.ceil(v.width/8)),
"little")
header += (value << v.offset+(v.byte*8))
for d in split_bytes(header, etherbone_record_header.length):
self.insert(0, d)
self.encoded = True
def __repr__(self, n=0):
r = "Record {}\n".format(n)
r += "--------\n"
if self.encoded:
for d in self:
r += "{:02x}".format(d)
else:
for k in sorted(etherbone_record_header.fields.keys()):
r += k + " : 0x{:0x}\n".format(getattr(self, k))
if self.wcount != 0:
r += self.writes.__repr__()
if self.rcount != 0:
r += self.reads.__repr__()
return r
class EtherbonePacket(Packet):
def __init__(self, init=[]):
Packet.__init__(self, init)
self.encoded = init != []
self.records = []
self.magic = etherbone_magic
self.version = etherbone_version
self.addr_size = 32//8
self.port_size = 32//8
self.nr = 0
self.pr = 0
self.pf = 0
def get_records(self):
records = []
done = False
payload = self
while len(payload) != 0:
record = EtherboneRecord(payload)
record.decode()
records.append(deepcopy(record))
payload = record
return records
def decode(self):
if not self.encoded:
raise ValueError
header = []
for byte in self[:etherbone_packet_header.length]:
header.append(self.pop(0))
for k, v in sorted(etherbone_packet_header.fields.items()):
setattr(self, k, get_field_data(v, header))
self.records = self.get_records()
self.encoded = False
def set_records(self, records):
for record in records:
record.encode()
for byte in record:
self.append(byte)
def encode(self):
if self.encoded:
raise ValueError
self.set_records(self.records)
header = 0
for k, v in sorted(etherbone_packet_header.fields.items()):
value = merge_bytes(split_bytes(getattr(self, k), math.ceil(v.width/8)), "little")
header += (value << v.offset+(v.byte*8))
for d in split_bytes(header, etherbone_packet_header.length):
self.insert(0, d)
self.encoded = True
def __repr__(self):
r = "Packet\n"
r += "--------\n"
if self.encoded:
for d in self:
r += "{:02x}".format(d)
else:
for k in sorted(etherbone_packet_header.fields.keys()):
r += k + " : 0x{:0x}\n".format(getattr(self, k))
for i, record in enumerate(self.records):
r += record.__repr__(i)
return r
class EtherboneIPC:
def send_packet(self, socket, packet):
socket.sendall(bytes(packet))
def receive_packet(self, socket):
header_length = etherbone_packet_header_length + etherbone_record_header_length
packet = bytes()
while len(packet) < header_length:
chunk = socket.recv(header_length - len(packet))
if len(chunk) == 0:
return 0
else:
packet += chunk
wcount, rcount = struct.unpack(">BB", packet[header_length-2:])
counts = wcount + rcount
packet_size = header_length + 4*(counts + 1)
while len(packet) < packet_size:
chunk = socket.recv(packet_size - len(packet))
if len(chunk) == 0:
return 0
else:
packet += chunk
return packet

View file

@ -0,0 +1,104 @@
import socket
import threading
import argparse
from litex.soc.tools.remote.etherbone import EtherbonePacket, EtherboneRecord, EtherboneWrites
from litex.soc.tools.remote.etherbone import EtherboneIPC
class RemoteServer(EtherboneIPC):
def __init__(self, comm, port=1234, csr_data_width=32):
self.comm = comm
self.port = port
self.csr_data_width = 32
def open(self):
if hasattr(self, "socket"):
return
self.socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
self.socket.bind(("localhost", self.port))
self.socket.listen(1)
self.comm.open(self.csr_data_width)
def close(self):
self.comm.close()
if not hasattr(self, "socket"):
return
self.socket.close()
del self.socket
def _serve_thread(self):
while True:
client_socket, addr = self.socket.accept()
print("Connected with " + addr[0] + ":" + str(addr[1]))
try:
while True:
packet = self.receive_packet(client_socket)
if packet == 0:
break
packet = EtherbonePacket(packet)
packet.decode()
record = packet.records.pop()
# writes:
if record.writes != None:
self.comm.write(record.writes.base_addr, record.writes.get_datas())
# reads
if record.reads != None:
reads = []
for addr in record.reads.get_addrs():
reads.append(self.comm.read(addr))
record = EtherboneRecord()
record.writes = EtherboneWrites(datas=reads)
record.wcount = len(record.writes)
packet = EtherbonePacket()
packet.records = [record]
packet.encode()
self.send_packet(client_socket, packet)
finally:
print("Disconnect")
client_socket.close()
def start(self):
self.serve_thread = threading.Thread(target=self._serve_thread)
self.serve_thread.setDaemon(True)
self.serve_thread.start()
def join(self, writer_only=False):
if not hasattr(self, "serve_thread"):
return
self.serve_thread.join()
def _get_args():
parser = argparse.ArgumentParser()
parser.add_argument("--comm", default="uart", help="comm interface")
parser.add_argument("--port", default="2", help="UART port")
parser.add_argument("--baudrate", default=115200, help="UART baudrate")
parser.add_argument("--csr_data_width", default=32, help="CSR data_width")
return parser.parse_args()
def main():
args = _get_args()
if args.comm == "uart":
from litex.soc.tools.remote import CommUART
port = args.port if not args.port.isdigit() else int(args.port)
comm = CommUART(args.port if not args.port.isdigit() else int(args.port),
args.baudrate,
debug=False)
else:
raise NotImplementedError
server = RemoteServer(comm, csr_data_width=args.csr_data_width)
server.open()
server.start()
try:
server.join(True)
except KeyboardInterrupt: # FIXME
pass
if __name__ == "__main__":
main()

View file

@ -36,6 +36,8 @@ setup(
"console_scripts": [
"flterm=litex.soc.tools.flterm:main",
"mkmscimg=litex.soc.tools.mkmscimg:main",
"remote_server=litex.soc.tools.remote.server:main",
"remote_client=litex.soc.tools.remote.client:main"
],
},
)