diff --git a/litex/soc/integration/export.py b/litex/soc/integration/export.py index 7915c136c..675522ff9 100644 --- a/litex/soc/integration/export.py +++ b/litex/soc/integration/export.py @@ -136,6 +136,9 @@ def get_linker_regions(regions): # C Export ----------------------------------------------------------------------------------------- + +# Header. + def get_git_header(): from litex.build.tools import get_litex_git_revision r = generated_banner("//") @@ -193,91 +196,7 @@ def get_soc_header(constants, with_access_functions=True): r += "\n#endif\n" return r -def _get_csr_addr(csr_base, addr, with_csr_base_define=True): - """ - Generate the CSR address string. - """ - if with_csr_base_define: - return f"(CSR_BASE + {hex(addr)}L)" - else: - return f"{hex(csr_base + addr)}L" - -def _generate_csr_definitions_c(reg_name, reg_base, nwords, csr_base, with_csr_base_define): - """ - Generate C code for CSR address and size definitions. - """ - addr_str = f"CSR_{reg_name.upper()}_ADDR" - size_str = f"CSR_{reg_name.upper()}_SIZE" - definitions = f"#define {addr_str} {_get_csr_addr(csr_base, reg_base, with_csr_base_define)}\n" - definitions += f"#define {size_str} {nwords}\n" - return definitions - -def _determine_ctype_and_stride_c(size, alignment): - """ - Determine the C type and stride based on the size. - """ - if size > 8: - return None, None - elif size > 4: - ctype = "uint64_t" - elif size > 2: - ctype = "uint32_t" - elif size > 1: - ctype = "uint16_t" - else: - ctype = "uint8_t" - stride = alignment // 8 - return ctype, stride - -def _generate_read_function_c(reg_name, reg_base, nwords, busword, ctype, stride, csr_base, with_csr_base_define): - """ - Generate C code for the read function of a CSR. - """ - read_function = f"static inline {ctype} {reg_name}_read(void) {{\n" - if nwords > 1: - read_function += f"\t{ctype} r = csr_read_simple({_get_csr_addr(csr_base, reg_base, with_csr_base_define)});\n" - for sub in range(1, nwords): - read_function += f"\tr <<= {busword};\n" - read_function += f"\tr |= csr_read_simple({_get_csr_addr(csr_base, reg_base + sub * stride, with_csr_base_define)});\n" - read_function += "\treturn r;\n}\n" - else: - read_function += f"\treturn csr_read_simple({_get_csr_addr(csr_base, reg_base, with_csr_base_define)});\n}}\n" - return read_function - -def _generate_write_function_c(reg_name, reg_base, nwords, busword, ctype, stride, csr_base, with_csr_base_define): - """ - Generate C code for the write function of a CSR. - """ - write_function = f"static inline void {reg_name}_write({ctype} v) {{\n" - for sub in range(nwords): - shift = (nwords - sub - 1) * busword - v_shift = f"v >> {shift}" if shift else "v" - write_function += f"\tcsr_write_simple({v_shift}, {_get_csr_addr(csr_base, reg_base + sub * stride, with_csr_base_define)});\n" - write_function += "}\n" - return write_function - -def _get_rw_functions_c(reg_name, reg_base, nwords, busword, alignment, read_only, csr_base, with_csr_base_define, with_access_functions): - """ - Generate C code for CSR read/write functions and definitions. - """ - result = _generate_csr_definitions_c(reg_name, reg_base, nwords, csr_base, with_csr_base_define) - size = nwords * busword // 8 - - ctype, stride = _determine_ctype_and_stride_c(size, alignment) - if ctype is None: - return result - - if with_access_functions: - result += _generate_read_function_c(reg_name, reg_base, nwords, busword, ctype, stride, csr_base, with_csr_base_define) - if not read_only: - result += _generate_write_function_c(reg_name, reg_base, nwords, busword, ctype, stride, csr_base, with_csr_base_define) - - return result - def _generate_csr_header_includes_c(with_access_functions): - """ - Generate the necessary include directives for the CSR header file. - """ includes = "" if with_access_functions: includes += "#include \n" @@ -292,9 +211,6 @@ def _generate_csr_header_includes_c(with_access_functions): return includes def _generate_csr_base_define_c(csr_base, with_csr_base_define): - """ - Generate the CSR base address define directive. - """ includes = "" if with_csr_base_define: includes += "\n" @@ -303,24 +219,129 @@ def _generate_csr_base_define_c(csr_base, with_csr_base_define): includes += "#endif /* ! CSR_BASE */\n" return includes -def _generate_field_definitions_c(csr, name, with_fields_access_functions): - """ - Generate definitions for CSR fields. - """ +# CSR Definitions. + +def _get_csr_addr(csr_base, addr, with_csr_base_define=True): + if with_csr_base_define: + return f"(CSR_BASE + {hex(addr)}L)" + else: + return f"{hex(csr_base + addr)}L" + +def _generate_csr_definitions_c(reg_name, reg_base, nwords, csr_base, with_csr_base_define): + addr_str = f"CSR_{reg_name.upper()}_ADDR" + size_str = f"CSR_{reg_name.upper()}_SIZE" + definitions = f"#define {addr_str} {_get_csr_addr(csr_base, reg_base, with_csr_base_define)}\n" + definitions += f"#define {size_str} {nwords}\n" + return definitions + +def _generate_csr_region_definitions_c(name, region, origin, alignment, csr_base, with_csr_base_define): + base_define = with_csr_base_define and not isinstance(region, MockCSRRegion) + region_defs = f"\n/* {name.upper()} Registers */\n" + region_defs += f"#define CSR_{name.upper()}_BASE {_get_csr_addr(csr_base, origin, base_define)}\n" + + if not isinstance(region.obj, Memory): + for csr in region.obj: + nr = (csr.size + region.busword - 1) // region.busword + region_defs += _generate_csr_definitions_c( + reg_name = name + "_" + csr.name, + reg_base = origin, + nwords = nr, + csr_base = csr_base, + with_csr_base_define = base_define, + ) + origin += alignment // 8 * nr + + region_defs += f"\n/* {name.upper()} Fields */\n" + if not isinstance(region.obj, Memory): + for csr in region.obj: + if hasattr(csr, "fields"): + region_defs += _generate_csr_field_definitions_c(csr, name) + + return region_defs + +# CSR Read/Write Access Functions. + +def _determine_ctype_and_stride_c(size, alignment): + if size > 8: + return None, None + elif size > 4: + ctype = "uint64_t" + elif size > 2: + ctype = "uint32_t" + elif size > 1: + ctype = "uint16_t" + else: + ctype = "uint8_t" + stride = alignment // 8 + return ctype, stride + +def _generate_csr_read_function_c(reg_name, reg_base, nwords, busword, ctype, stride, csr_base, with_csr_base_define): + read_function = f"static inline {ctype} {reg_name}_read(void) {{\n" + if nwords > 1: + read_function += f"\t{ctype} r = csr_read_simple({_get_csr_addr(csr_base, reg_base, with_csr_base_define)});\n" + for sub in range(1, nwords): + read_function += f"\tr <<= {busword};\n" + read_function += f"\tr |= csr_read_simple({_get_csr_addr(csr_base, reg_base + sub * stride, with_csr_base_define)});\n" + read_function += "\treturn r;\n}\n" + else: + read_function += f"\treturn csr_read_simple({_get_csr_addr(csr_base, reg_base, with_csr_base_define)});\n}}\n" + return read_function + +def _generate_csr_write_function_c(reg_name, reg_base, nwords, busword, ctype, stride, csr_base, with_csr_base_define): + write_function = f"static inline void {reg_name}_write({ctype} v) {{\n" + for sub in range(nwords): + shift = (nwords - sub - 1) * busword + v_shift = f"v >> {shift}" if shift else "v" + write_function += f"\tcsr_write_simple({v_shift}, {_get_csr_addr(csr_base, reg_base + sub * stride, with_csr_base_define)});\n" + write_function += "}\n" + return write_function + +def _get_csr_read_write_access_functions_c(reg_name, reg_base, nwords, busword, alignment, read_only, csr_base, with_csr_base_define): + result = "" + size = nwords * busword // 8 + + ctype, stride = _determine_ctype_and_stride_c(size, alignment) + if ctype is None: + return result + + result += _generate_csr_read_function_c(reg_name, reg_base, nwords, busword, ctype, stride, csr_base, with_csr_base_define) + if not read_only: + result += _generate_csr_write_function_c(reg_name, reg_base, nwords, busword, ctype, stride, csr_base, with_csr_base_define) + + return result + +def _generate_csr_region_access_functions_c(name, region, origin, alignment, csr_base, with_csr_base_define): + base_define = with_csr_base_define and not isinstance(region, MockCSRRegion) + region_defs = f"\n/* {name.upper()} Access Functions */\n" + + if not isinstance(region.obj, Memory): + for csr in region.obj: + nr = (csr.size + region.busword - 1) // region.busword + region_defs += _get_csr_read_write_access_functions_c( + reg_name = name + "_" + csr.name, + reg_base = origin, + nwords = nr, + busword = region.busword, + alignment = alignment, + read_only = getattr(csr, "read_only", False), + csr_base = csr_base, + with_csr_base_define = base_define, + ) + origin += alignment // 8 * nr + return region_defs + +# CSR Fields. + +def _generate_csr_field_definitions_c(csr, name): field_defs = "" for field in csr.fields.fields: offset = str(field.offset) size = str(field.size) field_defs += f"#define CSR_{name.upper()}_{csr.name.upper()}_{field.name.upper()}_OFFSET {offset}\n" field_defs += f"#define CSR_{name.upper()}_{csr.name.upper()}_{field.name.upper()}_SIZE {size}\n" - if with_fields_access_functions: - field_defs += _generate_field_accessors_c(name, csr, field) return field_defs -def _generate_field_accessors_c(name, csr, field): - """ - Generate access functions for CSR fields if the CSR size is <= 32 bits. - """ +def _generate_csr_field_accessors_c(name, csr, field): accessors = "" if csr.size <= 32: reg_name = name + "_" + csr.name.lower() @@ -343,40 +364,35 @@ def _generate_field_accessors_c(name, csr, field): accessors += f"\t{reg_name}_write(newword);\n}}\n" return accessors -def _generate_csr_region_definitions_c(name, region, origin, alignment, csr_base, with_csr_base_define, with_access_functions, with_fields_access_functions): - """ - Generate CSR address and size definitions for a region. - """ +def _generate_csr_field_functions_c(csr, name): + field_funcs = "" + for field in csr.fields.fields: + field_funcs += _generate_csr_field_accessors_c(name, csr, field) + return field_funcs + +def _generate_csr_fields_access_functions_c(name, region, origin, alignment, csr_base, with_csr_base_define): base_define = with_csr_base_define and not isinstance(region, MockCSRRegion) - region_defs = f"\n/* {name.upper()} */\n" - region_defs += f"#define CSR_{name.upper()}_BASE {_get_csr_addr(csr_base, origin, base_define)}\n" + region_defs = f"\n/* {name.upper()} Fields Access Functions */\n" if not isinstance(region.obj, Memory): for csr in region.obj: nr = (csr.size + region.busword - 1) // region.busword - region_defs += _get_rw_functions_c( - reg_name = name + "_" + csr.name, - reg_base = origin, - nwords = nr, - busword = region.busword, - alignment = alignment, - read_only = getattr(csr, "read_only", False), - csr_base = csr_base, - with_csr_base_define = base_define, - with_access_functions = with_access_functions, - ) origin += alignment // 8 * nr if hasattr(csr, "fields"): - region_defs += _generate_field_definitions_c(csr, name, with_access_functions and with_fields_access_functions) + region_defs += _generate_csr_field_functions_c(csr, name) return region_defs +# CSR Header. + def get_csr_header(regions, constants, csr_base=None, with_csr_base_define=True, with_access_functions=True, with_fields_access_functions=True): """ Generate the CSR header file content. """ + alignment = constants.get("CONFIG_CSR_ALIGNMENT", 32) r = generated_banner("//") + # CSR Includes. r += "\n" r += generated_separator("//", "CSR Includes.") r += "\n" @@ -385,18 +401,28 @@ def get_csr_header(regions, constants, csr_base=None, with_csr_base_define=True, csr_base = csr_base if csr_base is not None else _csr_base r += _generate_csr_base_define_c(csr_base, with_csr_base_define) - + # CSR Registers/Fields Definition. r += "\n" r += generated_separator("//", "CSR Registers/Fields Definition.") for name, region in regions.items(): origin = region.origin - _csr_base - r += _generate_csr_region_definitions_c(name, region, origin, alignment, csr_base, with_csr_base_define, with_access_functions, with_fields_access_functions) + r += _generate_csr_region_definitions_c(name, region, origin, alignment, csr_base, with_csr_base_define) - r += "\n" - r += generated_separator("//", "CSR Registers Access Functions.") + # CSR Registers Access Functions. + if with_access_functions: + r += "\n" + r += generated_separator("//", "CSR Registers Access Functions.") + for name, region in regions.items(): + origin = region.origin - _csr_base + r += _generate_csr_region_access_functions_c(name, region, origin, alignment, csr_base, with_csr_base_define) - r += "\n" - r += generated_separator("//", "CSR Registers Field Access Functions.") + # CSR Registers Field Access Functions. + if with_fields_access_functions: + r += "\n" + r += generated_separator("//", "CSR Registers Field Access Functions.") + for name, region in regions.items(): + origin = region.origin - _csr_base + r += _generate_csr_fields_access_functions_c(name, region, origin, alignment, csr_base, with_csr_base_define) r += "\n#endif /* ! __GENERATED_CSR_H */\n" return r