diff --git a/litex/soc/integration/builder.py b/litex/soc/integration/builder.py index 0e01155ee..9f7239a75 100644 --- a/litex/soc/integration/builder.py +++ b/litex/soc/integration/builder.py @@ -252,7 +252,10 @@ class Builder: csr_contents = export.get_csr_header( regions = self.soc.csr_regions, constants = self.soc.constants, - csr_base = self.soc.mem_regions["csr"].origin) + csr_base = self.soc.mem_regions["csr"].origin, + with_access_functions = True, + with_fields_access_functions = False, + ) write_to_file(os.path.join(self.generated_dir, "csr.h"), csr_contents) # Generate Git SHA1 of tools to git.h diff --git a/litex/soc/integration/export.py b/litex/soc/integration/export.py index c443b0f62..e9d0e9167 100644 --- a/litex/soc/integration/export.py +++ b/litex/soc/integration/export.py @@ -194,24 +194,30 @@ def get_soc_header(constants, with_access_functions=True): return r def _get_csr_addr(csr_base, addr, with_csr_base_define=True): + """ + Generate the CSR address string. + """ if with_csr_base_define: return f"(CSR_BASE + {hex(addr)}L)" else: return f"{hex(csr_base + addr)}L" -def _get_rw_functions_c(reg_name, reg_base, nwords, busword, alignment, read_only, csr_base, with_csr_base_define, with_access_functions): - r = "" +def _generate_csr_definitions_c(reg_name, reg_base, nwords, csr_base, with_csr_base_define): + """ + Generate C code for CSR address and size definitions. + """ + addr_str = f"CSR_{reg_name.upper()}_ADDR" + size_str = f"CSR_{reg_name.upper()}_SIZE" + definitions = f"#define {addr_str} {_get_csr_addr(csr_base, reg_base, with_csr_base_define)}\n" + definitions += f"#define {size_str} {nwords}\n" + return definitions - addr_str = f"CSR_{reg_name.upper()}_ADDR" - size_str = f"CSR_{reg_name.upper()}_SIZE" - r += f"#define {addr_str} {_get_csr_addr(csr_base, reg_base, with_csr_base_define)}\n" - - r += f"#define {size_str} {nwords}\n" - - size = nwords*busword//8 +def _determine_ctype_and_stride_c(size, alignment): + """ + Determine the C type and stride based on the size. + """ if size > 8: - # Downstream should select appropriate `csr_[rd|wr]_buf_uintX()` pair! - return r + return None, None elif size > 4: ctype = "uint64_t" elif size > 2: @@ -220,99 +226,167 @@ def _get_rw_functions_c(reg_name, reg_base, nwords, busword, alignment, read_onl ctype = "uint16_t" else: ctype = "uint8_t" + stride = alignment // 8 + return ctype, stride + +def _generate_read_function_c(reg_name, reg_base, nwords, busword, ctype, stride, csr_base, with_csr_base_define): + """ + Generate C code for the read function of a CSR. + """ + read_function = f"static inline {ctype} {reg_name}_read(void) {{\n" + if nwords > 1: + read_function += f"\t{ctype} r = csr_read_simple({_get_csr_addr(csr_base, reg_base, with_csr_base_define)});\n" + for sub in range(1, nwords): + read_function += f"\tr <<= {busword};\n" + read_function += f"\tr |= csr_read_simple({_get_csr_addr(csr_base, reg_base + sub * stride, with_csr_base_define)});\n" + read_function += "\treturn r;\n}\n" + else: + read_function += f"\treturn csr_read_simple({_get_csr_addr(csr_base, reg_base, with_csr_base_define)});\n}}\n" + return read_function + +def _generate_write_function_c(reg_name, reg_base, nwords, busword, ctype, stride, csr_base, with_csr_base_define): + """ + Generate C code for the write function of a CSR. + """ + write_function = f"static inline void {reg_name}_write({ctype} v) {{\n" + for sub in range(nwords): + shift = (nwords - sub - 1) * busword + v_shift = f"v >> {shift}" if shift else "v" + write_function += f"\tcsr_write_simple({v_shift}, {_get_csr_addr(csr_base, reg_base + sub * stride, with_csr_base_define)});\n" + write_function += "}\n" + return write_function + +def _get_rw_functions_c(reg_name, reg_base, nwords, busword, alignment, read_only, csr_base, with_csr_base_define, with_access_functions): + """ + Generate C code for CSR read/write functions and definitions. + """ + result = _generate_csr_definitions_c(reg_name, reg_base, nwords, csr_base, with_csr_base_define) + size = nwords * busword // 8 + + ctype, stride = _determine_ctype_and_stride_c(size, alignment) + if ctype is None: + return result - stride = alignment//8; if with_access_functions: - r += f"static inline {ctype} {reg_name}_read(void) {{\n" - if nwords > 1: - r += f"\t{ctype} r = csr_read_simple({_get_csr_addr(csr_base, reg_base, with_csr_base_define)});\n" - for sub in range(1, nwords): - r += f"\tr <<= {busword};\n" - r += f"\tr |= csr_read_simple({_get_csr_addr(csr_base, reg_base+sub*stride, with_csr_base_define)});\n" - r += "\treturn r;\n}\n" - else: - r += f"\treturn csr_read_simple({_get_csr_addr(csr_base, reg_base, with_csr_base_define)});\n}}\n" - + result += _generate_read_function_c(reg_name, reg_base, nwords, busword, ctype, stride, csr_base, with_csr_base_define) if not read_only: - r += f"static inline void {reg_name}_write({ctype} v) {{\n" - for sub in range(nwords): - shift = (nwords-sub-1)*busword - if shift: - v_shift = "v >> {}".format(shift) - else: - v_shift = "v" - r += f"\tcsr_write_simple({v_shift}, {_get_csr_addr(csr_base, reg_base+sub*stride, with_csr_base_define)});\n" - r += "}\n" - return r + result += _generate_write_function_c(reg_name, reg_base, nwords, busword, ctype, stride, csr_base, with_csr_base_define) + return result -def get_csr_header(regions, constants, csr_base=None, with_csr_base_define=True, with_access_functions=True): +def _generate_csr_header_includes_c(with_access_functions): + """ + Generate the necessary include directives for the CSR header file. + """ + includes = "#include \n" + includes += "#ifndef __GENERATED_CSR_H\n" + includes += "#define __GENERATED_CSR_H\n" + if with_access_functions: + includes += "#include \n" + includes += "#include \n" + includes += "#ifndef CSR_ACCESSORS_DEFINED\n" + includes += "#include \n" + includes += "#endif /* ! CSR_ACCESSORS_DEFINED */\n" + return includes + +def _generate_csr_base_define_c(csr_base, with_csr_base_define): + """ + Generate the CSR base address define directive. + """ + if with_csr_base_define: + return f"\n#ifndef CSR_BASE\n#define CSR_BASE {hex(csr_base)}L\n#endif\n" + return "" + +def _generate_field_definitions_c(csr, name, with_fields_access_functions): + """ + Generate definitions for CSR fields. + """ + field_defs = "" + for field in csr.fields.fields: + offset = str(field.offset) + size = str(field.size) + field_defs += f"#define CSR_{name.upper()}_{csr.name.upper()}_{field.name.upper()}_OFFSET {offset}\n" + field_defs += f"#define CSR_{name.upper()}_{csr.name.upper()}_{field.name.upper()}_SIZE {size}\n" + if with_fields_access_functions: + field_defs += _generate_field_accessors_c(name, csr, field) + return field_defs + +def _generate_field_accessors_c(name, csr, field): + """ + Generate access functions for CSR fields if the CSR size is <= 32 bits. + """ + accessors = "" + if csr.size <= 32: + reg_name = name + "_" + csr.name.lower() + field_name = reg_name + "_" + field.name.lower() + offset = str(field.offset) + size = str(field.size) + accessors += f"static inline uint32_t {field_name}_extract(uint32_t oldword) {{\n" + accessors += f"\tuint32_t mask = 0x{(1 << int(size)) - 1:x};\n" + accessors += f"\treturn ((oldword >> {offset}) & mask);\n}}\n" + accessors += f"static inline uint32_t {field_name}_read(void) {{\n" + accessors += f"\tuint32_t word = {reg_name}_read();\n" + accessors += f"\treturn {field_name}_extract(word);\n}}\n" + if not getattr(csr, "read_only", False): + accessors += f"static inline uint32_t {field_name}_replace(uint32_t oldword, uint32_t plain_value) {{\n" + accessors += f"\tuint32_t mask = 0x{(1 << int(size)) - 1:x};\n" + accessors += f"\treturn (oldword & (~(mask << {offset}))) | ((mask & plain_value) << {offset});\n}}\n" + accessors += f"static inline void {field_name}_write(uint32_t plain_value) {{\n" + accessors += f"\tuint32_t oldword = {reg_name}_read();\n" + accessors += f"\tuint32_t newword = {field_name}_replace(oldword, plain_value);\n" + accessors += f"\t{reg_name}_write(newword);\n}}\n" + return accessors + +def _generate_csr_region_definitions_c(name, region, origin, alignment, csr_base, with_csr_base_define, with_access_functions, with_fields_access_functions): + """ + Generate CSR address and size definitions for a region. + """ + base_define = with_csr_base_define and not isinstance(region, MockCSRRegion) + region_defs = f"\n/* {name} */\n" + region_defs += f"#define CSR_{name.upper()}_BASE {_get_csr_addr(csr_base, origin, base_define)}\n" + + if not isinstance(region.obj, Memory): + for csr in region.obj: + nr = (csr.size + region.busword - 1) // region.busword + region_defs += _get_rw_functions_c( + reg_name = name + "_" + csr.name, + reg_base = origin, + nwords = nr, + busword = region.busword, + alignment = alignment, + read_only = getattr(csr, "read_only", False), + csr_base = csr_base, + with_csr_base_define = base_define, + with_access_functions = with_access_functions, + ) + origin += alignment // 8 * nr + if hasattr(csr, "fields"): + region_defs += _generate_field_definitions_c(csr, name, with_access_functions and with_fields_access_functions) + return region_defs + +def get_csr_header(regions, constants, csr_base=None, with_csr_base_define=True, with_access_functions=True, with_fields_access_functions=True): + """ + Generate the CSR header file content. + """ alignment = constants.get("CONFIG_CSR_ALIGNMENT", 32) r = generated_banner("//") - if with_access_functions: # FIXME - r += "#include \n" - r += "#ifndef __GENERATED_CSR_H\n#define __GENERATED_CSR_H\n" - if with_access_functions: - r += "#include \n" - r += "#include \n" - r += "#ifndef CSR_ACCESSORS_DEFINED\n" - r += "#include \n" - r += "#endif /* ! CSR_ACCESSORS_DEFINED */\n" + + r += _generate_csr_header_includes_c(with_access_functions) + _csr_base = regions[next(iter(regions))].origin csr_base = csr_base if csr_base is not None else _csr_base - if with_csr_base_define: - r += "\n#ifndef CSR_BASE\n" - r += f"#define CSR_BASE {hex(csr_base)}L\n" - r += "#endif\n" + + r += _generate_csr_base_define_c(csr_base, with_csr_base_define) + for name, region in regions.items(): - origin = region.origin - _csr_base - base_define = with_csr_base_define and (not isinstance(region, MockCSRRegion)) - r += "\n/* "+name+" */\n" - r += f"#define CSR_{name.upper()}_BASE {_get_csr_addr(csr_base, origin, base_define)}\n" - if not isinstance(region.obj, Memory): - for csr in region.obj: - nr = (csr.size + region.busword - 1)//region.busword - r += _get_rw_functions_c( - reg_name = name + "_" + csr.name, - reg_base = origin, - nwords = nr, - busword = region.busword, - alignment = alignment, - read_only = getattr(csr, "read_only", False), - csr_base = csr_base, - with_csr_base_define = base_define, - with_access_functions = with_access_functions, - ) - origin += alignment//8*nr - if hasattr(csr, "fields"): - for field in csr.fields.fields: - offset = str(field.offset) - size = str(field.size) - r += f"#define CSR_{name.upper()}_{csr.name.upper()}_{field.name.upper()}_OFFSET {offset}\n" - r += f"#define CSR_{name.upper()}_{csr.name.upper()}_{field.name.upper()}_SIZE {size}\n" - if with_access_functions and csr.size <= 32: # FIXME: Implement extract/read functions for csr.size > 32-bit. - reg_name = name + "_" + csr.name.lower() - field_name = reg_name + "_" + field.name.lower() - r += "static inline uint32_t " + field_name + "_extract(uint32_t oldword) {\n" - r += f"\tuint32_t mask = 0x{(1<