diff --git a/ci/generate_checked_functions.py b/ci/generate_checked_functions.py index 572d94646..f682fd676 100644 --- a/ci/generate_checked_functions.py +++ b/ci/generate_checked_functions.py @@ -2,10 +2,78 @@ from pycparser import c_parser, c_ast, parse_file import os -# Updated generate_checked_function to dynamically update Result definition for new return types +def collect_typedefs(ast): + """Collect all typedefs in the AST.""" + typedefs = {} + for node in ast.ext: + if not isinstance(node, c_ast.Typedef): + continue + + typedef_name = node.name + typedef_type = node.type + typedefs[typedef_name] = typedef_type + return typedefs -def generate_checked_function(func): +def resolve_typedef(typedefs, type_name): + """Resolve a typedef to its underlying type.""" + + def resolve_base_type(ptr_decl): + # handle cases like: typedef int******* ptr; + cur_type = ptr_decl + pointer_type_name = "" + + while isinstance(cur_type, c_ast.PtrDecl) or isinstance( + cur_type, c_ast.TypeDecl + ): + if isinstance(cur_type, c_ast.PtrDecl): + cur_type = cur_type.type + pointer_type_name += "*" + elif isinstance(cur_type, c_ast.TypeDecl): + if isinstance(cur_type.type, c_ast.IdentifierType): + base_type_name = " ".join(cur_type.type.names) + pointer_type_name = base_type_name + pointer_type_name + return pointer_type_name + else: + pointer_type_name = "".join(cur_type.type.name) + pointer_type_name + return pointer_type_name + return None + + resolved_type = typedefs.get(type_name) + + if resolved_type is None: + return type_name + + if isinstance(resolved_type, c_ast.TypeDecl): + if isinstance(resolved_type.type, c_ast.Enum): + print(f"Resolved enum typedef {type_name}") + return type_name + + if isinstance(resolved_type.type, c_ast.Struct): + print(f"Resolved struct typedef {type_name}") + return type_name + + if isinstance(resolved_type.type, c_ast.Union): + print(f"Resolved union typedef {type_name}") + return type_name + + if isinstance(resolved_type.type, c_ast.IdentifierType): + base_type_name = " ".join(resolved_type.type.names) + print(f"Resolved base typedef {type_name} to {base_type_name}") + return type_name + + resolved_type.show() + raise Exception(f"Unhandled TypeDecl typedef {type_name}") + elif isinstance(resolved_type, c_ast.PtrDecl): + pointer_type_name = resolve_base_type(resolved_type) + print(f"Resolved pointer typedef {type_name} to {pointer_type_name}") + return pointer_type_name + else: + resolved_type.show() + raise Exception(f"Unhandled typedef {type_name}") + + +def generate_checked_function(func, typedefs): func_name = func.name # Access the name directly from Decl new_func_name = f"{func_name}_checked" @@ -17,8 +85,9 @@ def generate_checked_function(func): return_type = "void" # Default to void if no return type is specified if isinstance(func.type.type, c_ast.TypeDecl): return_type = " ".join(func.type.type.type.names) - # TODO: figure out a better way to detect typedef from pointer - if isinstance(func.type.type, c_ast.PtrDecl): + + resolved_type = resolve_typedef(typedefs, return_type) + if resolved_type.endswith("*"): return_pointer = True # Start building the new function @@ -88,12 +157,13 @@ def generate_checked_function(func): else: new_func.append(f" if (original_result == 0) {{") new_func.append(f" res.error_code = 0;") - new_func.append(f" res.value._Bool_value = original_result;") + new_func.append(f" res.value.{return_type}_value = original_result;") new_func.append(f" }} else {{") new_func.append(f" res.error_code = -2;") new_func.append(f" }}") new_func.append(f" return res;") + new_func.append(f"}}") return "\n".join(new_func) @@ -139,6 +209,9 @@ def process_header(): ], ) + # Collect all typedefs + typedefs = collect_typedefs(ast) + # Collect all function declarations functions = [ node @@ -151,7 +224,8 @@ def process_header(): for func in functions: if isinstance(func.type.type, c_ast.TypeDecl): return_type = " ".join(func.type.type.type.names) - return_types.add(return_type) + resolved_type = resolve_typedef(typedefs, return_type) + return_types.add(resolved_type) # Update the Result struct with all return types for return_type in return_types: @@ -180,7 +254,7 @@ def process_header(): f.write(RESULT_STRUCT + "\n") for func in functions: - new_func = generate_checked_function(func) + new_func = generate_checked_function(func, typedefs) f.write(new_func + "\n\n") f.write("#endif // WASM_EXPORT_CHECKED_H\n")