From c42cc350c664524c409ef99bd38adfae3b8a0d5e Mon Sep 17 00:00:00 2001 From: Florent Kermarrec Date: Tue, 14 May 2024 10:47:38 +0200 Subject: [PATCH] integration/export: Split _get_rw_functions_c in simpler functions. --- litex/soc/integration/export.py | 95 +++++++++++++++++++++------------ 1 file changed, 62 insertions(+), 33 deletions(-) diff --git a/litex/soc/integration/export.py b/litex/soc/integration/export.py index c443b0f62..0973a47c0 100644 --- a/litex/soc/integration/export.py +++ b/litex/soc/integration/export.py @@ -194,24 +194,30 @@ def get_soc_header(constants, with_access_functions=True): return r def _get_csr_addr(csr_base, addr, with_csr_base_define=True): + """ + Generate the CSR address string. + """ if with_csr_base_define: return f"(CSR_BASE + {hex(addr)}L)" else: return f"{hex(csr_base + addr)}L" -def _get_rw_functions_c(reg_name, reg_base, nwords, busword, alignment, read_only, csr_base, with_csr_base_define, with_access_functions): - r = "" +def _generate_csr_definitions_c(reg_name, reg_base, nwords, csr_base, with_csr_base_define): + """ + Generate C code for CSR address and size definitions. + """ + addr_str = f"CSR_{reg_name.upper()}_ADDR" + size_str = f"CSR_{reg_name.upper()}_SIZE" + definitions = f"#define {addr_str} {_get_csr_addr(csr_base, reg_base, with_csr_base_define)}\n" + definitions += f"#define {size_str} {nwords}\n" + return definitions - addr_str = f"CSR_{reg_name.upper()}_ADDR" - size_str = f"CSR_{reg_name.upper()}_SIZE" - r += f"#define {addr_str} {_get_csr_addr(csr_base, reg_base, with_csr_base_define)}\n" - - r += f"#define {size_str} {nwords}\n" - - size = nwords*busword//8 +def _determine_ctype_and_stride_c(size, alignment): + """ + Determine the C type and stride based on the size. + """ if size > 8: - # Downstream should select appropriate `csr_[rd|wr]_buf_uintX()` pair! - return r + return None, None elif size > 4: ctype = "uint64_t" elif size > 2: @@ -220,38 +226,61 @@ def _get_rw_functions_c(reg_name, reg_base, nwords, busword, alignment, read_onl ctype = "uint16_t" else: ctype = "uint8_t" + stride = alignment // 8 + return ctype, stride + +def _generate_read_function_c(reg_name, reg_base, nwords, busword, ctype, stride, csr_base, with_csr_base_define): + """ + Generate C code for the read function of a CSR. + """ + read_function = f"static inline {ctype} {reg_name}_read(void) {{\n" + if nwords > 1: + read_function += f"\t{ctype} r = csr_read_simple({_get_csr_addr(csr_base, reg_base, with_csr_base_define)});\n" + for sub in range(1, nwords): + read_function += f"\tr <<= {busword};\n" + read_function += f"\tr |= csr_read_simple({_get_csr_addr(csr_base, reg_base + sub * stride, with_csr_base_define)});\n" + read_function += "\treturn r;\n}}\n" + else: + read_function += f"\treturn csr_read_simple({_get_csr_addr(csr_base, reg_base, with_csr_base_define)});\n}}\n" + return read_function + +def _generate_write_function_c(reg_name, reg_base, nwords, busword, ctype, stride, csr_base, with_csr_base_define): + """ + Generate C code for the write function of a CSR. + """ + write_function = f"static inline void {reg_name}_write({ctype} v) {{\n" + for sub in range(nwords): + shift = (nwords - sub - 1) * busword + v_shift = f"v >> {shift}" if shift else "v" + write_function += f"\tcsr_write_simple({v_shift}, {_get_csr_addr(csr_base, reg_base + sub * stride, with_csr_base_define)});\n" + write_function += "}\n" + return write_function + +def _get_rw_functions_c(reg_name, reg_base, nwords, busword, alignment, read_only, csr_base, with_csr_base_define, with_access_functions): + """ + Generate C code for CSR read/write functions and definitions. + """ + result = _generate_csr_definitions_c(reg_name, reg_base, nwords, csr_base, with_csr_base_define) + size = nwords * busword // 8 + + ctype, stride = _determine_ctype_and_stride_c(size, alignment) + if ctype is None: + return result - stride = alignment//8; if with_access_functions: - r += f"static inline {ctype} {reg_name}_read(void) {{\n" - if nwords > 1: - r += f"\t{ctype} r = csr_read_simple({_get_csr_addr(csr_base, reg_base, with_csr_base_define)});\n" - for sub in range(1, nwords): - r += f"\tr <<= {busword};\n" - r += f"\tr |= csr_read_simple({_get_csr_addr(csr_base, reg_base+sub*stride, with_csr_base_define)});\n" - r += "\treturn r;\n}\n" - else: - r += f"\treturn csr_read_simple({_get_csr_addr(csr_base, reg_base, with_csr_base_define)});\n}}\n" - + result += _generate_read_function_c(reg_name, reg_base, nwords, busword, ctype, stride, csr_base, with_csr_base_define) if not read_only: - r += f"static inline void {reg_name}_write({ctype} v) {{\n" - for sub in range(nwords): - shift = (nwords-sub-1)*busword - if shift: - v_shift = "v >> {}".format(shift) - else: - v_shift = "v" - r += f"\tcsr_write_simple({v_shift}, {_get_csr_addr(csr_base, reg_base+sub*stride, with_csr_base_define)});\n" - r += "}\n" - return r + result += _generate_write_function_c(reg_name, reg_base, nwords, busword, ctype, stride, csr_base, with_csr_base_define) + return result def get_csr_header(regions, constants, csr_base=None, with_csr_base_define=True, with_access_functions=True): alignment = constants.get("CONFIG_CSR_ALIGNMENT", 32) r = generated_banner("//") if with_access_functions: # FIXME r += "#include \n" - r += "#ifndef __GENERATED_CSR_H\n#define __GENERATED_CSR_H\n" + r += "#ifndef __GENERATED_CSR_H\n" + r += "#define __GENERATED_CSR_H\n" if with_access_functions: r += "#include \n" r += "#include \n"