integration/export: Split _get_rw_functions_c in simpler functions.

This commit is contained in:
Florent Kermarrec 2024-05-14 10:47:38 +02:00
parent 2613ae606a
commit c42cc350c6
1 changed files with 62 additions and 33 deletions

View File

@ -194,24 +194,30 @@ def get_soc_header(constants, with_access_functions=True):
return r
def _get_csr_addr(csr_base, addr, with_csr_base_define=True):
"""
Generate the CSR address string.
"""
if with_csr_base_define:
return f"(CSR_BASE + {hex(addr)}L)"
else:
return f"{hex(csr_base + addr)}L"
def _get_rw_functions_c(reg_name, reg_base, nwords, busword, alignment, read_only, csr_base, with_csr_base_define, with_access_functions):
r = ""
def _generate_csr_definitions_c(reg_name, reg_base, nwords, csr_base, with_csr_base_define):
"""
Generate C code for CSR address and size definitions.
"""
addr_str = f"CSR_{reg_name.upper()}_ADDR"
size_str = f"CSR_{reg_name.upper()}_SIZE"
r += f"#define {addr_str} {_get_csr_addr(csr_base, reg_base, with_csr_base_define)}\n"
definitions = f"#define {addr_str} {_get_csr_addr(csr_base, reg_base, with_csr_base_define)}\n"
definitions += f"#define {size_str} {nwords}\n"
return definitions
r += f"#define {size_str} {nwords}\n"
size = nwords*busword//8
def _determine_ctype_and_stride_c(size, alignment):
"""
Determine the C type and stride based on the size.
"""
if size > 8:
# Downstream should select appropriate `csr_[rd|wr]_buf_uintX()` pair!
return r
return None, None
elif size > 4:
ctype = "uint64_t"
elif size > 2:
@ -220,38 +226,61 @@ def _get_rw_functions_c(reg_name, reg_base, nwords, busword, alignment, read_onl
ctype = "uint16_t"
else:
ctype = "uint8_t"
stride = alignment // 8
return ctype, stride
stride = alignment//8;
if with_access_functions:
r += f"static inline {ctype} {reg_name}_read(void) {{\n"
def _generate_read_function_c(reg_name, reg_base, nwords, busword, ctype, stride, csr_base, with_csr_base_define):
"""
Generate C code for the read function of a CSR.
"""
read_function = f"static inline {ctype} {reg_name}_read(void) {{\n"
if nwords > 1:
r += f"\t{ctype} r = csr_read_simple({_get_csr_addr(csr_base, reg_base, with_csr_base_define)});\n"
read_function += f"\t{ctype} r = csr_read_simple({_get_csr_addr(csr_base, reg_base, with_csr_base_define)});\n"
for sub in range(1, nwords):
r += f"\tr <<= {busword};\n"
r += f"\tr |= csr_read_simple({_get_csr_addr(csr_base, reg_base+sub*stride, with_csr_base_define)});\n"
r += "\treturn r;\n}\n"
read_function += f"\tr <<= {busword};\n"
read_function += f"\tr |= csr_read_simple({_get_csr_addr(csr_base, reg_base + sub * stride, with_csr_base_define)});\n"
read_function += "\treturn r;\n}}\n"
else:
r += f"\treturn csr_read_simple({_get_csr_addr(csr_base, reg_base, with_csr_base_define)});\n}}\n"
read_function += f"\treturn csr_read_simple({_get_csr_addr(csr_base, reg_base, with_csr_base_define)});\n}}\n"
return read_function
if not read_only:
r += f"static inline void {reg_name}_write({ctype} v) {{\n"
def _generate_write_function_c(reg_name, reg_base, nwords, busword, ctype, stride, csr_base, with_csr_base_define):
"""
Generate C code for the write function of a CSR.
"""
write_function = f"static inline void {reg_name}_write({ctype} v) {{\n"
for sub in range(nwords):
shift = (nwords - sub - 1) * busword
if shift:
v_shift = "v >> {}".format(shift)
else:
v_shift = "v"
r += f"\tcsr_write_simple({v_shift}, {_get_csr_addr(csr_base, reg_base+sub*stride, with_csr_base_define)});\n"
r += "}\n"
return r
v_shift = f"v >> {shift}" if shift else "v"
write_function += f"\tcsr_write_simple({v_shift}, {_get_csr_addr(csr_base, reg_base + sub * stride, with_csr_base_define)});\n"
write_function += "}\n"
return write_function
def _get_rw_functions_c(reg_name, reg_base, nwords, busword, alignment, read_only, csr_base, with_csr_base_define, with_access_functions):
"""
Generate C code for CSR read/write functions and definitions.
"""
result = _generate_csr_definitions_c(reg_name, reg_base, nwords, csr_base, with_csr_base_define)
size = nwords * busword // 8
ctype, stride = _determine_ctype_and_stride_c(size, alignment)
if ctype is None:
return result
if with_access_functions:
result += _generate_read_function_c(reg_name, reg_base, nwords, busword, ctype, stride, csr_base, with_csr_base_define)
if not read_only:
result += _generate_write_function_c(reg_name, reg_base, nwords, busword, ctype, stride, csr_base, with_csr_base_define)
return result
def get_csr_header(regions, constants, csr_base=None, with_csr_base_define=True, with_access_functions=True):
alignment = constants.get("CONFIG_CSR_ALIGNMENT", 32)
r = generated_banner("//")
if with_access_functions: # FIXME
r += "#include <generated/soc.h>\n"
r += "#ifndef __GENERATED_CSR_H\n#define __GENERATED_CSR_H\n"
r += "#ifndef __GENERATED_CSR_H\n"
r += "#define __GENERATED_CSR_H\n"
if with_access_functions:
r += "#include <stdint.h>\n"
r += "#include <system.h>\n"