# Generates C++ functions that wrap ATen tensor factory methods to turn them into Variables.
#
# This writes one file: variable_factories.h

from __future__ import annotations

import re

import torchgen.api.python as python
from torchgen.api import cpp
from torchgen.api.types import CppSignatureGroup
from torchgen.context import with_native_function
from torchgen.gen import parse_native_yaml
from torchgen.model import NativeFunction, TensorOptionsArguments, Variant
from torchgen.utils import FileManager, mapMaybe


OPTIONAL_TYPE_PATTERN = re.compile(r"std::optional<(.+)>")
TYPE_PATTERN = re.compile(r"(?:const\s+)?([A-Z]\w+)")


# Add 'at::' to types defined in ATen namespace, e.g. Tensor, TensorList, IntArrayRef, etc.
# TODO: maybe update the cpp argument API to take optional namespace argument?
def fully_qualified_type(argument_type: str) -> str:
    """Return ``argument_type`` with ``at::`` inserted before a leading capitalized type name.

    If the string starts with ``std::optional<...>``, the text inside the angle
    brackets is qualified and then wrapped in ``std::optional<...>`` again.
    A string that does not start (optionally after ``const``) with a
    capitalized identifier is returned without an ``at::`` prefix.

    NOTE: for the optional case only the text captured inside the angle
    brackets is kept; anything following the final ``>`` is not carried over.
    """
    opt_match = OPTIONAL_TYPE_PATTERN.match(argument_type)
    is_opt = opt_match is not None
    if opt_match is not None:
        # Work on the wrapped type; the optional wrapper is re-added on return.
        argument_type = opt_match.group(1)

    match = TYPE_PATTERN.match(argument_type)
    if match is not None:
        # Insert the namespace right before the type name, i.e. after any
        # leading "const " that the pattern skipped over.
        index = match.start(1)
        argument_type = f"{argument_type[:index]}at::{argument_type[index:]}"

    return f"std::optional<{argument_type}>" if is_opt else argument_type


def gen_variable_factories(
    out: str, native_yaml_path: str, tags_yaml_path: str, template_path: str
) -> None:
    """Write ``variable_factories.h`` from its template.

    Args:
        out: Directory passed to ``FileManager`` as ``install_dir``.
        native_yaml_path: Path handed to ``parse_native_yaml``.
        tags_yaml_path: Path handed to ``parse_native_yaml``.
        template_path: Directory passed to ``FileManager`` as ``template_dir``.
    """
    native_functions = parse_native_yaml(
        native_yaml_path, tags_yaml_path
    ).native_functions
    factory_functions = [fn for fn in native_functions if is_factory_function(fn)]
    fm = FileManager(install_dir=out, template_dir=template_path, dry_run=False)
    fm.write_with_template(
        "variable_factories.h",
        "variable_factories.h",
        lambda: {
            "generated_comment": "@"
            + f"generated from {fm.template_dir_for_comments()}/variable_factories.h",
            "ops_headers": [
                f"#include <ATen/ops/{fn.root_name}.h>" for fn in factory_functions
            ],
            "function_definitions": list(mapMaybe(process_function, factory_functions)),
        },
    )


def _is_factory(f: NativeFunction) -> bool:
    """Shared predicate: ``f`` has a function variant and either takes tensor options or is a ``*_like`` op.

    Intended to be called from inside a ``with_native_function`` wrapper.
    """
    if Variant.function not in f.variants:
        return False
    return python.has_tensor_options(f) or cpp.name(f.func).endswith("_like")


@with_native_function
def is_factory_function(f: NativeFunction) -> bool:
    """Return True if a wrapper should be generated for ``f``."""
    return _is_factory(f)


@with_native_function
def process_function(f: NativeFunction) -> str | None:
    """Return the C++ wrapper definition(s) for ``f``, or None if ``f`` is not a factory function.

    One ``inline at::Tensor ...`` definition is emitted for the regular C++
    signature, followed by one for the symint signature when the signature
    group has one.
    """
    if not _is_factory(f):
        return None

    cpp_sigs = CppSignatureGroup.from_native_function(f, method=False)
    sigs = [cpp_sigs.signature]
    if cpp_sigs.symint_signature is not None:
        sigs.append(cpp_sigs.symint_signature)

    definitions: list[str] = []
    for sig in sigs:
        formals: list[str] = []
        exprs: list[str] = []
        requires_grad = "false"
        for arg in sig.arguments():
            qualified_type = fully_qualified_type(arg.type)
            if arg.default:
                formals.append(f"{qualified_type} {arg.name} = {arg.default}")
            else:
                formals.append(f"{qualified_type} {arg.name}")

            if isinstance(arg.argument, TensorOptionsArguments):
                # note: we remove the requires_grad setting from the TensorOptions because
                # it is ignored anyways (and we actually have an assertion that it isn't set
                # which would fail otherwise). We handle requires_grad explicitly here
                # instead of passing it through to the kernel.
                exprs.append(
                    f"at::TensorOptions({arg.name}).requires_grad(::std::nullopt)"
                )
                # Manually set the requires_grad bit on the result tensor.
                requires_grad = f"{arg.name}.requires_grad()"
            else:
                exprs.append(arg.name)

        definitions.append(
            f"""\
inline at::Tensor {sig.name()}({', '.join(formals)}) {{
  at::AutoDispatchBelowADInplaceOrView guard;
  return autograd::make_variable(at::{sig.name()}({', '.join(exprs)}), /*requires_grad=*/{requires_grad});
}}
"""
        )
    return "".join(definitions)