diff --git a/litex/soc/integration/export.py b/litex/soc/integration/export.py index 0973a47c0..731ea09f0 100644 --- a/litex/soc/integration/export.py +++ b/litex/soc/integration/export.py @@ -274,74 +274,118 @@ def _get_rw_functions_c(reg_name, reg_base, nwords, busword, alignment, read_onl return result +def _generate_csr_header_includes_c(with_access_functions): + """ + Generate the necessary include directives for the CSR header file. + """ + includes = "#include \n" + includes += "#ifndef __GENERATED_CSR_H\n" + includes += "#define __GENERATED_CSR_H\n" + if with_access_functions: + includes += "#include \n" + includes += "#include \n" + includes += "#ifndef CSR_ACCESSORS_DEFINED\n" + includes += "#include \n" + includes += "#endif /* ! CSR_ACCESSORS_DEFINED */\n" + return includes + +def _generate_csr_base_define_c(csr_base, with_csr_base_define): + """ + Generate the CSR base address define directive. + """ + if with_csr_base_define: + return f"\n#ifndef CSR_BASE\n#define CSR_BASE {hex(csr_base)}L\n#endif\n" + return "" + +def _generate_field_definitions_c(csr, name): + """ + Generate definitions for CSR fields. + """ + field_defs = "" + for field in csr.fields.fields: + offset = str(field.offset) + size = str(field.size) + field_defs += f"#define CSR_{name.upper()}_{csr.name.upper()}_{field.name.upper()}_OFFSET {offset}\n" + field_defs += f"#define CSR_{name.upper()}_{csr.name.upper()}_{field.name.upper()}_SIZE {size}\n" + field_defs += _generate_field_accessors_c(name, csr, field) + return field_defs + +def _generate_field_accessors_c(name, csr, field): + """ + Generate access functions for CSR fields if the CSR size is <= 32 bits. + """ + accessors = "" + if csr.size <= 32: + reg_name = name + "_" + csr.name.lower() + field_name = reg_name + "_" + field.name.lower() + offset = str(field.offset) + size = str(field.size) + accessors += f"static inline uint32_t {field_name}_extract(uint32_t oldword) {{\n" + accessors += f"\tuint32_t mask = 0x{(1 << int(size)) - 1:x};\n" + accessors += f"\treturn ((oldword >> {offset}) & mask);\n}}\n" + accessors += f"static inline uint32_t {field_name}_read(void) {{\n" + accessors += f"\tuint32_t word = {reg_name}_read();\n" + accessors += f"\treturn {field_name}_extract(word);\n}}\n" + if not getattr(csr, "read_only", False): + accessors += f"static inline uint32_t {field_name}_replace(uint32_t oldword, uint32_t plain_value) {{\n" + accessors += f"\tuint32_t mask = 0x{(1 << int(size)) - 1:x};\n" + accessors += f"\treturn (oldword & (~(mask << {offset}))) | ((mask & plain_value) << {offset});\n}}\n" + accessors += f"static inline void {field_name}_write(uint32_t plain_value) {{\n" + accessors += f"\tuint32_t oldword = {reg_name}_read();\n" + accessors += f"\tuint32_t newword = {field_name}_replace(oldword, plain_value);\n" + accessors += f"\t{reg_name}_write(newword);\n}}\n" + return accessors + +def _generate_csr_region_definitions_c(name, region, origin, alignment, csr_base, with_csr_base_define, with_access_functions): + """ + Generate CSR address and size definitions for a region. + """ + base_define = with_csr_base_define and not isinstance(region, MockCSRRegion) + region_defs = f"\n/* {name} */\n" + region_defs += f"#define CSR_{name.upper()}_BASE {_get_csr_addr(csr_base, origin, base_define)}\n" + + if not isinstance(region.obj, Memory): + for csr in region.obj: + nr = (csr.size + region.busword - 1) // region.busword + region_defs += _get_rw_functions_c( + reg_name = name + "_" + csr.name, + reg_base = origin, + nwords = nr, + busword = region.busword, + alignment = alignment, + read_only = getattr(csr, "read_only", False), + csr_base = csr_base, + with_csr_base_define = base_define, + with_access_functions = with_access_functions, + ) + origin += alignment // 8 * nr + if hasattr(csr, "fields"): + region_defs += _generate_field_definitions_c(csr, name) + return region_defs + def get_csr_header(regions, constants, csr_base=None, with_csr_base_define=True, with_access_functions=True): + """ + Generate the CSR header file content. + """ alignment = constants.get("CONFIG_CSR_ALIGNMENT", 32) r = generated_banner("//") - if with_access_functions: # FIXME - r += "#include \n" - r += "#ifndef __GENERATED_CSR_H\n" - r += "#define __GENERATED_CSR_H\n" - if with_access_functions: - r += "#include \n" - r += "#include \n" - r += "#ifndef CSR_ACCESSORS_DEFINED\n" - r += "#include \n" - r += "#endif /* ! CSR_ACCESSORS_DEFINED */\n" + + r += _generate_csr_header_includes_c(with_access_functions) + _csr_base = regions[next(iter(regions))].origin csr_base = csr_base if csr_base is not None else _csr_base - if with_csr_base_define: - r += "\n#ifndef CSR_BASE\n" - r += f"#define CSR_BASE {hex(csr_base)}L\n" - r += "#endif\n" + + r += _generate_csr_base_define_c(csr_base, with_csr_base_define) + for name, region in regions.items(): - origin = region.origin - _csr_base - base_define = with_csr_base_define and (not isinstance(region, MockCSRRegion)) - r += "\n/* "+name+" */\n" - r += f"#define CSR_{name.upper()}_BASE {_get_csr_addr(csr_base, origin, base_define)}\n" - if not isinstance(region.obj, Memory): - for csr in region.obj: - nr = (csr.size + region.busword - 1)//region.busword - r += _get_rw_functions_c( - reg_name = name + "_" + csr.name, - reg_base = origin, - nwords = nr, - busword = region.busword, - alignment = alignment, - read_only = getattr(csr, "read_only", False), - csr_base = csr_base, - with_csr_base_define = base_define, - with_access_functions = with_access_functions, - ) - origin += alignment//8*nr - if hasattr(csr, "fields"): - for field in csr.fields.fields: - offset = str(field.offset) - size = str(field.size) - r += f"#define CSR_{name.upper()}_{csr.name.upper()}_{field.name.upper()}_OFFSET {offset}\n" - r += f"#define CSR_{name.upper()}_{csr.name.upper()}_{field.name.upper()}_SIZE {size}\n" - if with_access_functions and csr.size <= 32: # FIXME: Implement extract/read functions for csr.size > 32-bit. - reg_name = name + "_" + csr.name.lower() - field_name = reg_name + "_" + field.name.lower() - r += "static inline uint32_t " + field_name + "_extract(uint32_t oldword) {\n" - r += f"\tuint32_t mask = 0x{(1<