soc/integration_export: Split C header generation by sections.

- CSR Includes.
- CSR Registers/Fields Definition.
- CSR Registers Access Functions.
- CSR Registers Field Access Functions.
This commit is contained in:
Florent Kermarrec 2024-05-21 09:58:05 +02:00
parent 4502edd33e
commit 5b297f5601
1 changed files with 147 additions and 121 deletions

View File

@ -136,6 +136,9 @@ def get_linker_regions(regions):
# C Export -----------------------------------------------------------------------------------------
# Header.
def get_git_header():
from litex.build.tools import get_litex_git_revision
r = generated_banner("//")
@ -193,91 +196,7 @@ def get_soc_header(constants, with_access_functions=True):
r += "\n#endif\n"
return r
def _get_csr_addr(csr_base, addr, with_csr_base_define=True):
"""
Generate the CSR address string.
"""
if with_csr_base_define:
return f"(CSR_BASE + {hex(addr)}L)"
else:
return f"{hex(csr_base + addr)}L"
def _generate_csr_definitions_c(reg_name, reg_base, nwords, csr_base, with_csr_base_define):
"""
Generate C code for CSR address and size definitions.
"""
addr_str = f"CSR_{reg_name.upper()}_ADDR"
size_str = f"CSR_{reg_name.upper()}_SIZE"
definitions = f"#define {addr_str} {_get_csr_addr(csr_base, reg_base, with_csr_base_define)}\n"
definitions += f"#define {size_str} {nwords}\n"
return definitions
def _determine_ctype_and_stride_c(size, alignment):
"""
Determine the C type and stride based on the size.
"""
if size > 8:
return None, None
elif size > 4:
ctype = "uint64_t"
elif size > 2:
ctype = "uint32_t"
elif size > 1:
ctype = "uint16_t"
else:
ctype = "uint8_t"
stride = alignment // 8
return ctype, stride
def _generate_read_function_c(reg_name, reg_base, nwords, busword, ctype, stride, csr_base, with_csr_base_define):
"""
Generate C code for the read function of a CSR.
"""
read_function = f"static inline {ctype} {reg_name}_read(void) {{\n"
if nwords > 1:
read_function += f"\t{ctype} r = csr_read_simple({_get_csr_addr(csr_base, reg_base, with_csr_base_define)});\n"
for sub in range(1, nwords):
read_function += f"\tr <<= {busword};\n"
read_function += f"\tr |= csr_read_simple({_get_csr_addr(csr_base, reg_base + sub * stride, with_csr_base_define)});\n"
read_function += "\treturn r;\n}\n"
else:
read_function += f"\treturn csr_read_simple({_get_csr_addr(csr_base, reg_base, with_csr_base_define)});\n}}\n"
return read_function
def _generate_write_function_c(reg_name, reg_base, nwords, busword, ctype, stride, csr_base, with_csr_base_define):
"""
Generate C code for the write function of a CSR.
"""
write_function = f"static inline void {reg_name}_write({ctype} v) {{\n"
for sub in range(nwords):
shift = (nwords - sub - 1) * busword
v_shift = f"v >> {shift}" if shift else "v"
write_function += f"\tcsr_write_simple({v_shift}, {_get_csr_addr(csr_base, reg_base + sub * stride, with_csr_base_define)});\n"
write_function += "}\n"
return write_function
def _get_rw_functions_c(reg_name, reg_base, nwords, busword, alignment, read_only, csr_base, with_csr_base_define, with_access_functions):
"""
Generate C code for CSR read/write functions and definitions.
"""
result = _generate_csr_definitions_c(reg_name, reg_base, nwords, csr_base, with_csr_base_define)
size = nwords * busword // 8
ctype, stride = _determine_ctype_and_stride_c(size, alignment)
if ctype is None:
return result
if with_access_functions:
result += _generate_read_function_c(reg_name, reg_base, nwords, busword, ctype, stride, csr_base, with_csr_base_define)
if not read_only:
result += _generate_write_function_c(reg_name, reg_base, nwords, busword, ctype, stride, csr_base, with_csr_base_define)
return result
def _generate_csr_header_includes_c(with_access_functions):
"""
Generate the necessary include directives for the CSR header file.
"""
includes = ""
if with_access_functions:
includes += "#include <generated/soc.h>\n"
@ -292,9 +211,6 @@ def _generate_csr_header_includes_c(with_access_functions):
return includes
def _generate_csr_base_define_c(csr_base, with_csr_base_define):
"""
Generate the CSR base address define directive.
"""
includes = ""
if with_csr_base_define:
includes += "\n"
@ -303,24 +219,129 @@ def _generate_csr_base_define_c(csr_base, with_csr_base_define):
includes += "#endif /* ! CSR_BASE */\n"
return includes
def _generate_field_definitions_c(csr, name, with_fields_access_functions):
"""
Generate definitions for CSR fields.
"""
# CSR Definitions.
def _get_csr_addr(csr_base, addr, with_csr_base_define=True):
if with_csr_base_define:
return f"(CSR_BASE + {hex(addr)}L)"
else:
return f"{hex(csr_base + addr)}L"
def _generate_csr_definitions_c(reg_name, reg_base, nwords, csr_base, with_csr_base_define):
addr_str = f"CSR_{reg_name.upper()}_ADDR"
size_str = f"CSR_{reg_name.upper()}_SIZE"
definitions = f"#define {addr_str} {_get_csr_addr(csr_base, reg_base, with_csr_base_define)}\n"
definitions += f"#define {size_str} {nwords}\n"
return definitions
def _generate_csr_region_definitions_c(name, region, origin, alignment, csr_base, with_csr_base_define):
base_define = with_csr_base_define and not isinstance(region, MockCSRRegion)
region_defs = f"\n/* {name.upper()} Registers */\n"
region_defs += f"#define CSR_{name.upper()}_BASE {_get_csr_addr(csr_base, origin, base_define)}\n"
if not isinstance(region.obj, Memory):
for csr in region.obj:
nr = (csr.size + region.busword - 1) // region.busword
region_defs += _generate_csr_definitions_c(
reg_name = name + "_" + csr.name,
reg_base = origin,
nwords = nr,
csr_base = csr_base,
with_csr_base_define = base_define,
)
origin += alignment // 8 * nr
region_defs += f"\n/* {name.upper()} Fields */\n"
if not isinstance(region.obj, Memory):
for csr in region.obj:
if hasattr(csr, "fields"):
region_defs += _generate_csr_field_definitions_c(csr, name)
return region_defs
# CSR Read/Write Access Functions.
def _determine_ctype_and_stride_c(size, alignment):
if size > 8:
return None, None
elif size > 4:
ctype = "uint64_t"
elif size > 2:
ctype = "uint32_t"
elif size > 1:
ctype = "uint16_t"
else:
ctype = "uint8_t"
stride = alignment // 8
return ctype, stride
def _generate_csr_read_function_c(reg_name, reg_base, nwords, busword, ctype, stride, csr_base, with_csr_base_define):
read_function = f"static inline {ctype} {reg_name}_read(void) {{\n"
if nwords > 1:
read_function += f"\t{ctype} r = csr_read_simple({_get_csr_addr(csr_base, reg_base, with_csr_base_define)});\n"
for sub in range(1, nwords):
read_function += f"\tr <<= {busword};\n"
read_function += f"\tr |= csr_read_simple({_get_csr_addr(csr_base, reg_base + sub * stride, with_csr_base_define)});\n"
read_function += "\treturn r;\n}\n"
else:
read_function += f"\treturn csr_read_simple({_get_csr_addr(csr_base, reg_base, with_csr_base_define)});\n}}\n"
return read_function
def _generate_csr_write_function_c(reg_name, reg_base, nwords, busword, ctype, stride, csr_base, with_csr_base_define):
write_function = f"static inline void {reg_name}_write({ctype} v) {{\n"
for sub in range(nwords):
shift = (nwords - sub - 1) * busword
v_shift = f"v >> {shift}" if shift else "v"
write_function += f"\tcsr_write_simple({v_shift}, {_get_csr_addr(csr_base, reg_base + sub * stride, with_csr_base_define)});\n"
write_function += "}\n"
return write_function
def _get_csr_read_write_access_functions_c(reg_name, reg_base, nwords, busword, alignment, read_only, csr_base, with_csr_base_define):
result = ""
size = nwords * busword // 8
ctype, stride = _determine_ctype_and_stride_c(size, alignment)
if ctype is None:
return result
result += _generate_csr_read_function_c(reg_name, reg_base, nwords, busword, ctype, stride, csr_base, with_csr_base_define)
if not read_only:
result += _generate_csr_write_function_c(reg_name, reg_base, nwords, busword, ctype, stride, csr_base, with_csr_base_define)
return result
def _generate_csr_region_access_functions_c(name, region, origin, alignment, csr_base, with_csr_base_define):
base_define = with_csr_base_define and not isinstance(region, MockCSRRegion)
region_defs = f"\n/* {name.upper()} Access Functions */\n"
if not isinstance(region.obj, Memory):
for csr in region.obj:
nr = (csr.size + region.busword - 1) // region.busword
region_defs += _get_csr_read_write_access_functions_c(
reg_name = name + "_" + csr.name,
reg_base = origin,
nwords = nr,
busword = region.busword,
alignment = alignment,
read_only = getattr(csr, "read_only", False),
csr_base = csr_base,
with_csr_base_define = base_define,
)
origin += alignment // 8 * nr
return region_defs
# CSR Fields.
def _generate_csr_field_definitions_c(csr, name):
field_defs = ""
for field in csr.fields.fields:
offset = str(field.offset)
size = str(field.size)
field_defs += f"#define CSR_{name.upper()}_{csr.name.upper()}_{field.name.upper()}_OFFSET {offset}\n"
field_defs += f"#define CSR_{name.upper()}_{csr.name.upper()}_{field.name.upper()}_SIZE {size}\n"
if with_fields_access_functions:
field_defs += _generate_field_accessors_c(name, csr, field)
return field_defs
def _generate_field_accessors_c(name, csr, field):
"""
Generate access functions for CSR fields if the CSR size is <= 32 bits.
"""
def _generate_csr_field_accessors_c(name, csr, field):
accessors = ""
if csr.size <= 32:
reg_name = name + "_" + csr.name.lower()
@ -343,40 +364,35 @@ def _generate_field_accessors_c(name, csr, field):
accessors += f"\t{reg_name}_write(newword);\n}}\n"
return accessors
def _generate_csr_region_definitions_c(name, region, origin, alignment, csr_base, with_csr_base_define, with_access_functions, with_fields_access_functions):
"""
Generate CSR address and size definitions for a region.
"""
def _generate_csr_field_functions_c(csr, name):
field_funcs = ""
for field in csr.fields.fields:
field_funcs += _generate_csr_field_accessors_c(name, csr, field)
return field_funcs
def _generate_csr_fields_access_functions_c(name, region, origin, alignment, csr_base, with_csr_base_define):
base_define = with_csr_base_define and not isinstance(region, MockCSRRegion)
region_defs = f"\n/* {name.upper()} */\n"
region_defs += f"#define CSR_{name.upper()}_BASE {_get_csr_addr(csr_base, origin, base_define)}\n"
region_defs = f"\n/* {name.upper()} Fields Access Functions */\n"
if not isinstance(region.obj, Memory):
for csr in region.obj:
nr = (csr.size + region.busword - 1) // region.busword
region_defs += _get_rw_functions_c(
reg_name = name + "_" + csr.name,
reg_base = origin,
nwords = nr,
busword = region.busword,
alignment = alignment,
read_only = getattr(csr, "read_only", False),
csr_base = csr_base,
with_csr_base_define = base_define,
with_access_functions = with_access_functions,
)
origin += alignment // 8 * nr
if hasattr(csr, "fields"):
region_defs += _generate_field_definitions_c(csr, name, with_access_functions and with_fields_access_functions)
region_defs += _generate_csr_field_functions_c(csr, name)
return region_defs
# CSR Header.
def get_csr_header(regions, constants, csr_base=None, with_csr_base_define=True, with_access_functions=True, with_fields_access_functions=True):
"""
Generate the CSR header file content.
"""
alignment = constants.get("CONFIG_CSR_ALIGNMENT", 32)
r = generated_banner("//")
# CSR Includes.
r += "\n"
r += generated_separator("//", "CSR Includes.")
r += "\n"
@ -385,18 +401,28 @@ def get_csr_header(regions, constants, csr_base=None, with_csr_base_define=True,
csr_base = csr_base if csr_base is not None else _csr_base
r += _generate_csr_base_define_c(csr_base, with_csr_base_define)
# CSR Registers/Fields Definition.
r += "\n"
r += generated_separator("//", "CSR Registers/Fields Definition.")
for name, region in regions.items():
origin = region.origin - _csr_base
r += _generate_csr_region_definitions_c(name, region, origin, alignment, csr_base, with_csr_base_define, with_access_functions, with_fields_access_functions)
r += _generate_csr_region_definitions_c(name, region, origin, alignment, csr_base, with_csr_base_define)
r += "\n"
r += generated_separator("//", "CSR Registers Access Functions.")
# CSR Registers Access Functions.
if with_access_functions:
r += "\n"
r += generated_separator("//", "CSR Registers Access Functions.")
for name, region in regions.items():
origin = region.origin - _csr_base
r += _generate_csr_region_access_functions_c(name, region, origin, alignment, csr_base, with_csr_base_define)
r += "\n"
r += generated_separator("//", "CSR Registers Field Access Functions.")
# CSR Registers Field Access Functions.
if with_fields_access_functions:
r += "\n"
r += generated_separator("//", "CSR Registers Field Access Functions.")
for name, region in regions.items():
origin = region.origin - _csr_base
r += _generate_csr_fields_access_functions_c(name, region, origin, alignment, csr_base, with_csr_base_define)
r += "\n#endif /* ! __GENERATED_CSR_H */\n"
return r