from __future__ import annotations

import torchgen.api.meta as meta
import torchgen.api.structured as structured
from torchgen.api.types import kernel_signature
from torchgen.context import with_native_function_and_index
from torchgen.model import BackendIndex, NativeFunction, NativeFunctionsGroup
from torchgen.utils import mapMaybe


@with_native_function_and_index
def gen_unstructured(f: NativeFunction, backend_index: BackendIndex) -> str | None:
    """Return the C++ forward declaration of an unstructured kernel.

    The declaration has the form ``<linkage> <kernel signature>;``. The
    linkage is ``static`` for external backends and ``TORCH_API`` otherwise.

    Returns ``None`` when ``backend_index`` has no kernel for ``f``, or when
    the kernel name contains ``legacy::``. Those kernels get no declaration.
    """
    signature = kernel_signature(f, backend_index)
    kernel_meta = backend_index.get_kernel(f)
    # Short-circuit: ``.kernel`` is only read once we know metadata exists.
    if kernel_meta is None or "legacy::" in kernel_meta.kernel:
        return None
    linkage = "static" if backend_index.external else "TORCH_API"
    return f"{linkage} {signature.decl(name=kernel_meta.kernel)};"


@with_native_function_and_index
def gen_structured(g: NativeFunctionsGroup, backend_index: BackendIndex) -> list[str]:
    """Return the declaration of the structured kernel struct for ``g``.

    The result is a single-element list holding a C++ struct named
    ``structured_<kernel>``. It derives from
    ``at::meta::structured_<meta name>`` and declares an ``impl`` method whose
    parameters are ``structured.impl_arguments(g)``. Non-external backends get
    a ``TORCH_API`` prefix on the struct.

    Returns an empty list when ``backend_index`` has no kernel for ``g``.
    """
    parent_meta = meta.name(g)
    impl_args = structured.impl_arguments(g)
    kernel_meta = backend_index.get_kernel(g)
    if kernel_meta is None:
        return []
    api_macro = "" if backend_index.external else "TORCH_API "
    impl_params = ", ".join(arg.decl() for arg in impl_args)
    # Assembled piecewise. Only the f-string pieces need doubled braces; the
    # final piece is a plain string, so it uses a single "}".
    declaration = (
        f"struct {api_macro}structured_{kernel_meta.kernel} : "
        f"public at::meta::structured_{parent_meta} {{\n"
        f"void impl({impl_params});\n"
        "};\n"
    )
    return [declaration]


# Generates NativeFunctions.h, a list of forward declarations of all
# actual kernel definitions we keep in aten/src/ATen/native/
@with_native_function_and_index
def compute_native_function_declaration(
    g: NativeFunctionsGroup | NativeFunction, backend_index: BackendIndex
) -> list[str]:
    """Return the forward declarations ``backend_index`` needs for ``g``.

    - A lone ``NativeFunction`` yields at most one unstructured declaration.
    - A ``NativeFunctionsGroup`` whose kernel metadata is structured yields
      the structured struct declaration.
    - Any other group yields the unstructured declarations of its member
      functions. ``mapMaybe`` is used to skip members with no declaration.

    Raises:
        AssertionError: if the group is structured and the backend is
            external, a combination that is not implemented.
    """
    kernel_meta = backend_index.get_kernel(g)

    if not isinstance(g, NativeFunctionsGroup):
        single = gen_unstructured(g, backend_index)
        return [single] if single is not None else []

    is_structured = kernel_meta is not None and kernel_meta.structured
    if not is_structured:
        return list(
            mapMaybe(lambda fn: gen_unstructured(fn, backend_index), g.functions())
        )

    if backend_index.external:
        # Structured hasn't been tested with external backends yet.
        raise AssertionError(
            "Structured external backend functions are not implemented yet."
        )
    return gen_structured(g, backend_index)