xref: /aosp_15_r20/external/pytorch/torchgen/api/unboxing.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1from __future__ import annotations
2
3from torchgen.api import cpp
4from torchgen.api.types import Binding, CppSignatureGroup, CType
5from torchgen.model import (
6    Argument,
7    BaseTy,
8    BaseType,
9    ListType,
10    NativeFunction,
11    OptionalType,
12    Type,
13)
14
15
# This file generates the code for unboxing wrappers, i.e., the glue logic that unboxes a boxed operator call by
# converting the ivalues on the stack into the correct arguments for the unboxed kernel, based on the corresponding
# JIT schema. This codegen is an alternative to the existing C++ metaprogramming approach for producing unboxing
# wrappers, but it does the job statically. The generated wrappers are useful when we need to register a fixed set
# of operators known at compile time, which saves time during the runtime initialization phase.
21#
# Here's an example of how the codegen works:
23#
24# - Function Schema (source of truth)
25#
26#      aten::empty.names(int[] size, *, Dimname[]? names,
27#                        ScalarType? dtype=None, Layout? layout=None,
28#                        Device? device=None, bool? pin_memory=None,
29#                        MemoryFormat? memory_format=None) -> Tensor
30# - Argument Conversion
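#       Generates C++ code to convert an ivalue (from the stack) to its underlying C++ type.
#       (The snippets below are simplified: the generated code first moves each ivalue off the stack into a named
#       `c10::IValue` local, and it declares any `std::vector` that backs an `ArrayRef` ahead of the conversion code
#       in the enclosing scope, so the `ArrayRef` never outlives its data.)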
32#    - int[] size
33#        ```cpp
34#           const c10::List<c10::IValue> size_list_in = (std::move(peek(stack, 0, 7))).toList();
35#
36#           std::vector<int64_t> size_vec;
37#           for (c10::IValue size_elem: size_list_in) {
38#               int64_t size_base = size_elem.to<int64_t>();
39#               size_vec.push_back(size_base);
40#           }
41#           at::ArrayRef<int64_t> size_list_out(size_vec);
42#                                 ~~~~~~~~~~~~~ <-- The converted argument from ivalues in the stack.
43#                                                   Will be passed to unboxed kernel.
44#       ```
45#    - Dimname[]? names
46#       ```cpp
47#           ::std::optional<c10::IValue> names_opt = (std::move(peek(stack, 1, 7))).toOptional<c10::IValue>();
48#           ::std::optional<at::ArrayRef<at::Dimname>> names_opt_out;
49#           if (names_opt.has_value()) {
50#                         ~~~~~~~~~~~ <-- Unwrapping optional shell
51#               const c10::IValue names_opt_in = names_opt.value();
52#               const c10::List<c10::IValue> names_list_in = names_opt_in.toList();
53#
54#               std::vector<at::Dimname> names_vec;
55#               for (c10::IValue names_elem: names_list_in) {
56#                                ~~~~~~~~~~~~~~~~~~~~~~~~~ <-- Unrolling list, then convert elements one by one.
57#                   at::Dimname names_base = names_elem.to<at::Dimname>();
58#                   names_vec.push_back(names_base);
59#               }
60#               at::ArrayRef<at::Dimname> names_list_out(names_vec);
61#
62#               names_opt_out = ::std::optional<at::ArrayRef<at::Dimname>>(names_list_out);
63#           } else {
64#               names_opt_out = ::std::optional<at::ArrayRef<at::Dimname>>();
65#           }
66#       ```
67#    - ScalarType? dtype (similarly for the rest of the arguments)
68#       ```cpp
69#           ::std::optional<c10::IValue> dtype_opt = (std::move(peek(stack, 2, 7))).toOptional<c10::IValue>();
70#           ::std::optional<at::ScalarType> dtype_opt_out;
71#           if (dtype_opt.has_value()) {
72#               const c10::IValue dtype_opt_in = dtype_opt.value();
73#               at::ScalarType dtype_base = dtype_opt_in.to<at::ScalarType>();
74#                                                        ~~~~~~~~~~~~~~~~~~~~ <-- For base types, convert ivalue to it
75#                                                                                 directly using ".to<T>()" API.
76#               dtype_opt_out = ::std::optional<at::ScalarType>(dtype_base);
77#           } else {
78#               dtype_opt_out = ::std::optional<at::ScalarType>();
79#           }
80#       ```
81#
82# - Unboxed Kernel Call
83#   ```cpp
84#       auto result_ = torch::empty(
85#           size_list_out,
86#           names_opt_out,
87#           options,
88#           memory_format_opt_out
89#       );
90#   ```
91#
92# - Push Result Back to Stack
93#   ```cpp
94#       drop(stack, 7);
95#       pack(stack, std::move(result_));
96#   ```
# Separator used to splice the lines generated for a nested conversion into an
# enclosing f-string template: every line after the first starts on a new line
# and is prefixed with a tab.
connector: str = "\n\t"
98
99
def name(f: NativeFunction) -> str:
    """Return the name of the generated unboxing function for ``f``.

    The name is the operator's unambiguous schema name, i.e. the result of
    ``f.func.name.unambiguous_name()``.
    """
    operator_name = f.func.name
    return operator_name.unambiguous_name()
103
104
def convert_arguments(f: NativeFunction) -> tuple[list[Binding], list[str]]:
    """Generate C++ that unboxes every argument of ``f`` from the stack.

    The emitted code first moves each of the N ivalues into a ``c10::IValue``
    local named after its binding (via ``peek(stack, i, N)``), followed by an
    empty line. It then emits, argument by argument, the declarations and
    conversion code produced by ``argumenttype_ivalue_convert``. It does not
    drop the ivalues from the stack.

    Returns:
        ``(binding_list, code_list)``: the bindings of ``f``'s most faithful C++
        signature, each renamed to the local that holds its unboxed value, and
        the generated C++ source lines.

    Raises:
        Exception: if a binding does not wrap a plain ``Argument``.
    """
    # Build the function (method=False) signature so that ``self`` shows up as
    # an ordinary argument that has to be unboxed too.
    group = CppSignatureGroup.from_native_function(f, method=False)
    args = group.most_faithful_signature().arguments()
    total = len(args)

    code_list = [
        f"c10::IValue {binding.name} = std::move(peek(stack, {idx}, {total}));"
        for idx, binding in enumerate(args)
    ]
    code_list.append("")

    binding_list = []
    for binding in args:
        argument = binding.argument
        # Only plain schema arguments are expected here.
        if not isinstance(argument, Argument):
            raise Exception(  # noqa: TRY002
                f"Unexpected argument type, expecting `Argument` but got {binding}"
            )
        unboxed_name, _, conversion, declarations = argumenttype_ivalue_convert(
            argument.type,
            argument.name,
            mutable=argument.is_write,
        )
        # Declarations go before the conversion code so that buffers they
        # introduce (e.g. the std::vector backing an ArrayRef) live in the
        # outer scope of the generated wrapper.
        code_list += declarations
        code_list += conversion
        binding_list.append(binding.with_name(unboxed_name))
    return binding_list, code_list
134
135
def argumenttype_ivalue_convert(
    t: Type, arg_name: str, *, mutable: bool = False
) -> tuple[str, CType, list[str], list[str]]:
    """Generate C++ that unboxes the ``c10::IValue`` named ``arg_name`` as type ``t``.

    Args:
        t: schema type of the argument (BaseType, OptionalType or ListType).
        arg_name: name of the C++ ``c10::IValue`` variable holding the value.
        mutable: whether the argument is written to; forwarded to
            ``cpp.argumenttype_type`` when computing the C++ type.

    Returns:
        A 4-tuple ``(out_name, ctype, code, decl)``:
          * ``out_name``: name of the C++ local holding the unboxed value
            (``<arg_name>_base``, ``<arg_name>_opt_out`` or ``<arg_name>_list_out``);
          * ``ctype``: the argument's C++ type from ``cpp.argumenttype_type``;
          * ``code``: C++ source lines performing the conversion;
          * ``decl``: C++ declarations meant to be emitted ahead of ``code`` in
            the enclosing scope (e.g. the ``std::vector`` backing an ArrayRef).

    Raises:
        Exception: if ``t`` is none of the supported type kinds.
    """
    # Unboxing is for mobile, which doesn't care about SymInts
    ctype = cpp.argumenttype_type(
        t=t, mutable=mutable, binds=arg_name, symint=False
    ).type

    if isinstance(t, BaseType):
        out_name = f"{arg_name}_base"
        code, decl = _gen_code_base_type(
            arg_name=arg_name, out_name=out_name, ctype=ctype
        )
        return out_name, ctype, code, decl

    # Container types share a generator signature; each gets its own suffix
    # for the output variable name.
    container_generators = (
        (OptionalType, "_opt_out", _gen_code_optional_type),
        (ListType, "_list_out", _gen_code_list_type),
    )
    for container_cls, suffix, generate in container_generators:
        if isinstance(t, container_cls):
            out_name = arg_name + suffix
            code, decl = generate(
                arg_name=arg_name,
                out_name=out_name,
                t=t,
                ctype=ctype,
            )
            return out_name, ctype, code, decl

    raise Exception(f"Cannot handle type {t}. arg_name: {arg_name}")  # noqa: TRY002
171
172
173def _gen_code_base_type(
174    arg_name: str, out_name: str, ctype: CType
175) -> tuple[list[str], list[str]]:
176    return [
177        f"{ctype.cpp_type(strip_ref=True)} {out_name} = {arg_name}.to<{ctype.cpp_type(strip_ref=True)}>();"
178    ], []
179
180
def _gen_code_optional_type(
    arg_name: str, out_name: str, t: OptionalType, ctype: CType
) -> tuple[list[str], list[str]]:
    """Generate C++ that unboxes an optional ivalue ``arg_name`` into ``out_name``.

    The emitted code reads the ivalue with ``toOptional<c10::IValue>()`` into
    ``<arg_name>_opt``.

    * When a value is present, it is copied into ``<arg_name>_opt_in``. That
      value is converted with the element type's conversion (``t.elem``) and
      then wrapped in ``ctype`` (reference stripped).
    * Otherwise ``out_name`` is assigned a default-constructed ``ctype``.

    Returns:
        ``(code, decl)``: ``code`` is the snippet split into lines, including
        the whitespace-only lines at its start and end. ``decl`` is the element
        conversion's declarations, passed through unchanged so the caller emits
        them outside the ``if`` block and they stay alive after it ends.
    """
    in_name = f"{arg_name}_opt_in"
    opt_name = f"{arg_name}_opt"
    elem_out, _, elem_code, decl = argumenttype_ivalue_convert(t.elem, in_name)
    value_type = ctype.cpp_type(strip_ref=True)
    snippet = "\n".join(
        [
            "",
            f"auto {opt_name} = {arg_name}.toOptional<c10::IValue>();",
            f"{value_type} {out_name};",
            f"if ({opt_name}.has_value()) {{",
            f"    const c10::IValue {in_name} = {opt_name}.value();",
            f"    {connector.join(elem_code)}",
            f"    {out_name} = {value_type}({elem_out});",
            "} else {",
            f"    {out_name} = {value_type}();",
            "}",
            " " * 8,
        ]
    )
    # Re-split after joining: the nested element code may itself contain
    # newlines, and each of those must become a separate line entry.
    return snippet.split("\n"), decl
202
203
def _gen_code_list_type(
    arg_name: str, out_name: str, t: ListType, ctype: CType
) -> tuple[list[str], list[str]]:
    """Generate C++ that unboxes a list ivalue ``arg_name`` into ``out_name``.

    The ivalue is first read with ``toList()`` into ``<arg_name>_list_in``.
    After that, one of three strategies applies:

    * A fixed-size ``bool[N]`` list is converted in one call to ``as_array``.
    * A list of optionals is accumulated with ``push_back`` directly into a
      ``ctype`` container. For optional elements that container is a
      ``c10::List``, e.g. ``c10::List<::std::optional<at::Tensor>>``.
    * Every other list is converted element by element into a ``std::vector``
      named ``<arg_name>_vec``, which is then wrapped in ``ctype`` (reference
      stripped; an ArrayRef by default). The vector's declaration is returned
      in ``decl`` rather than ``code`` so the caller can place it in an outer
      scope, keeping the ArrayRef pointing at live data.

    Returns:
        ``(code, decl)``: the generated C++ lines, and the declarations to be
        emitted ahead of them.

    NOTE(review): declarations produced by the element conversion are hoisted
    out of the per-element loop. Base-type and optional-of-base-type elements
    produce none. An element type that did produce one would share a single
    buffer across iterations — confirm before supporting such element types.
    """
    in_name = f"{arg_name}_list_in"
    elem_name = f"{arg_name}_elem"
    code = [f"const c10::List<c10::IValue> {in_name} = {arg_name}.toList();"]
    elem_out, elem_ctype, elem_code, decl = argumenttype_ivalue_convert(
        t.elem, elem_name
    )
    list_type = ctype.cpp_type(strip_ref=True)
    trailer = " " * 12

    elem = t.elem
    # handle list type with size, e.g., bool[4]
    is_fixed_bool_array = (
        isinstance(elem, BaseType) and elem.name == BaseTy.bool and t.size
    )
    if is_fixed_bool_array:
        elem_type = elem_ctype.cpp_type(strip_ref=True)
        rows = [
            "",
            f"{list_type} {out_name} = as_array<{elem_type}, {t.size}>({in_name});",
            trailer,
        ]
    elif isinstance(elem, OptionalType):
        rows = [
            "",
            f"{list_type} {out_name};",
            f"for (c10::IValue {elem_name}: {in_name}) {{",
            f"    {connector.join(elem_code)}",
            f"    {out_name}.push_back({elem_out});",
            "}",
            trailer,
        ]
    else:
        elem_type = elem_ctype.cpp_type(strip_ref=True)
        vec_name = f"{arg_name}_vec"
        # Declared outside the generated loop's scope so the ArrayRef built
        # from it still points at valid data afterwards.
        decl.append(f"std::vector<{elem_type}> {vec_name};")
        rows = [
            "",
            f"for (c10::IValue {elem_name}: {in_name}) {{",
            f"    {connector.join(elem_code)}",
            f"    {vec_name}.push_back({elem_out});",
            "}",
            f"{list_type} {out_name}({vec_name});",
            trailer,
        ]
    # Join then re-split so any newlines inside the nested element code become
    # separate line entries, exactly as a triple-quoted template would.
    code.extend("\n".join(rows).split("\n"))
    return code, decl
250