# Generates C++ functions that wrap ATen tensor factory methods to turn them into Variables.
#
# This writes one file: variable_factories.h

from __future__ import annotations

import re

import torchgen.api.python as python
from torchgen.api import cpp
from torchgen.api.types import CppSignatureGroup
from torchgen.context import with_native_function
from torchgen.gen import parse_native_yaml
from torchgen.model import NativeFunction, TensorOptionsArguments, Variant
from torchgen.utils import FileManager, mapMaybe


OPTIONAL_TYPE_PATTERN = re.compile(r"std::optional<(.+)>")
TYPE_PATTERN = re.compile(r"(?:const\s+)?([A-Z]\w+)")


# Add 'at::' to types defined in the ATen namespace, e.g. Tensor, TensorList, IntArrayRef, etc.
# TODO: maybe update the cpp argument API to take an optional namespace argument?
def fully_qualified_type(argument_type: str) -> str:
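    # A few illustrative rewrites (as implied by the patterns above):
    #   "Tensor"                -> "at::Tensor"
    #   "const Tensor &"        -> "const at::Tensor &"
    #   "std::optional<Tensor>" -> "std::optional<at::Tensor>"
    #   "int64_t"               -> "int64_t"  (no capitalized ATen type name, left unchanged)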
    def maybe_optional_type(type: str, is_opt: bool) -> str:
        return f"std::optional<{type}>" if is_opt else type

    opt_match = OPTIONAL_TYPE_PATTERN.match(argument_type)
    is_opt = opt_match is not None
    if opt_match:
        argument_type = argument_type[opt_match.start(1) : opt_match.end(1)]
    match = TYPE_PATTERN.match(argument_type)
    if match is None:
        return maybe_optional_type(argument_type, is_opt)
    index = match.start(1)
    qualified_type = f"{argument_type[:index]}at::{argument_type[index:]}"
    return maybe_optional_type(qualified_type, is_opt)


def gen_variable_factories(
    out: str, native_yaml_path: str, tags_yaml_path: str, template_path: str
) -> None:
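    # Parse native_functions.yaml, keep only the factory ops, and render the
    # variable_factories.h template with the per-operator ATen headers and the
    # generated wrapper definitions.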
    native_functions = parse_native_yaml(
        native_yaml_path, tags_yaml_path
    ).native_functions
    factory_functions = [fn for fn in native_functions if is_factory_function(fn)]
    fm = FileManager(install_dir=out, template_dir=template_path, dry_run=False)
    fm.write_with_template(
        "variable_factories.h",
        "variable_factories.h",
        lambda: {
            "generated_comment": "@"
            + f"generated from {fm.template_dir_for_comments()}/variable_factories.h",
            "ops_headers": [
                f"#include <ATen/ops/{fn.root_name}.h>" for fn in factory_functions
            ],
            "function_definitions": list(mapMaybe(process_function, factory_functions)),
        },
    )


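# A native function counts as a factory here if it has a function (non-method)
# variant and either takes TensorOptions or is a "*_like" op.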
@with_native_function
def is_factory_function(f: NativeFunction) -> bool:
    if Variant.function not in f.variants:
        return False

    name = cpp.name(f.func)
    has_tensor_options = python.has_tensor_options(f)
    return has_tensor_options or name.endswith("_like")


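# Emits one inline C++ wrapper per factory signature. As a rough sketch (the
# exact formals depend on the op's C++ signature), the wrapper generated for a
# TensorOptions-taking factory such as at::ones looks like:
#
#   inline at::Tensor ones(at::IntArrayRef size, at::TensorOptions options = {}) {
#     at::AutoDispatchBelowADInplaceOrView guard;
#     return autograd::make_variable(
#         at::ones(size, at::TensorOptions(options).requires_grad(::std::nullopt)),
#         /*requires_grad=*/options.requires_grad());
#   }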
@with_native_function
def process_function(f: NativeFunction) -> str | None:
    name = cpp.name(f.func)
    has_tensor_options = python.has_tensor_options(f)
    is_factory = has_tensor_options or name.endswith("_like")

    if Variant.function not in f.variants or not is_factory:
        return None

    cpp_sigs = CppSignatureGroup.from_native_function(f, method=False)
    sigs = [cpp_sigs.signature]
    if cpp_sigs.symint_signature is not None:
        sigs.append(cpp_sigs.symint_signature)
    r = ""
    for sig in sigs:
        formals: list[str] = []
        exprs: list[str] = []
        requires_grad = "false"
        for arg in sig.arguments():
            qualified_type = fully_qualified_type(arg.type)
            if arg.default:
                formals.append(f"{qualified_type} {arg.name} = {arg.default}")
            else:
                formals.append(f"{qualified_type} {arg.name}")

            if isinstance(arg.argument, TensorOptionsArguments):
                # note: we remove the requires_grad setting from the TensorOptions because
                # it is ignored anyway (and we actually have an assertion that it isn't set,
                # which would fail otherwise). We handle requires_grad explicitly here
                # instead of passing it through to the kernel.
                exprs.append(
                    f"at::TensorOptions({arg.name}).requires_grad(::std::nullopt)"
                )
                # Manually set the requires_grad bit on the result tensor.
                requires_grad = f"{arg.name}.requires_grad()"
            else:
                exprs.append(arg.name)

        r += f"""\
inline at::Tensor {sig.name()}({', '.join(formals)}) {{
  at::AutoDispatchBelowADInplaceOrView guard;
  return autograd::make_variable(at::{sig.name()}({', '.join(exprs)}), /*requires_grad=*/{requires_grad});
}}
"""
    return r