# xref: /aosp_15_r20/external/pytorch/torchgen/executorch/api/unboxing.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1from __future__ import annotations
2
3from dataclasses import dataclass
4from typing import Callable, Sequence, TYPE_CHECKING
5
6from torchgen.model import (
7    Argument,
8    BaseTy,
9    BaseType,
10    ListType,
11    NativeFunction,
12    OptionalType,
13    Type,
14)
15
16
17if TYPE_CHECKING:
18    from torchgen.api.types import Binding, CType, NamedCType
19
20
# Separator used when splicing a list of generated C++ lines into a larger
# code template: each joined line starts on a new line, indented by a tab.
connector: str = "\n\t"
22
23
24# Return unboxing function name for a NativeFunction
25def name(f: NativeFunction) -> str:
26    return f.func.name.unambiguous_name()
27
28
@dataclass(frozen=True)
class Unboxing:
    """
    Takes a sequence of Bindings and unboxes EValues into those Bindings, returning generated C++ code that performs the unboxing.
    A sample generated code (this class only produces the per-argument lines; the enclosing function and the
    kernel call are presumably emitted by the caller):
    // aten::mul.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
    void mul_out(EValue** stack) {
        EValue& self = *stack[0];
        EValue& other = *stack[1];
        EValue& out = *stack[2];
        const torch::executor::Tensor & self_base = self.to<torch::executor::Tensor>();
        const torch::executor::Tensor & other_base = other.to<torch::executor::Tensor>();
        torch::executor::Tensor & out_base = out.to<torch::executor::Tensor>();

        EXECUTORCH_SCOPE_PROF("native_call_mul.out");
        torch::executor::mul_outf(self_base, other_base, out_base);


    }
    """

    # Callable that converts a JIT argument type into its C++ type. It is invoked as
    # ``argument_type_gen(t, mutable=..., binds=...)`` and must return a NamedCType,
    # e.g. torchgen.api.cpp.argumenttype_type.
    argument_type_gen: Callable[
        ...,
        NamedCType,
    ]

    def convert_arguments(
        self, args: Sequence[Binding]
    ) -> tuple[list[Binding], list[str]]:
        """
        Generate C++ code that unboxes every argument of a NativeFunction.

        :param args: Bindings for the function's arguments, in stack order; each
            Binding's ``argument`` must be an ``Argument``.
        :return: ``(bindings, code)``. ``bindings`` are the input Bindings renamed to
            their unboxed C++ variables. ``code`` is the list of C++ lines: first one
            ``EValue& <name> = *stack[i];`` per argument, then, for each argument in
            order, its declarations followed by its conversion code.
        :raises Exception: if a Binding's ``argument`` is not an ``Argument``.
        """
        # One named EValue reference per stack slot, emitted before any conversion.
        code_list = [f"EValue& {arg.name} = *stack[{i}];" for i, arg in enumerate(args)]
        binding_list = []
        for arg in args:
            # expecting only Argument
            if not isinstance(arg.argument, Argument):
                raise Exception(  # noqa: TRY002
                    f"Unexpected argument type, expecting `Argument` but got {arg}"
                )
            argument: Argument = arg.argument
            unboxed_name, _, code, decl = self.argumenttype_evalue_convert(
                argument.type, argument.name, mutable=argument.is_write
            )
            # Declarations must precede the conversion code that uses them.
            code_list.extend(decl)
            code_list.extend(code)
            binding_list.append(arg.with_name(unboxed_name))
        return binding_list, code_list

    def argumenttype_evalue_convert(
        self, t: Type, arg_name: str, *, mutable: bool = False
    ) -> tuple[str, CType, list[str], list[str]]:
        """
        Generate C++ code that unboxes the EValue variable ``arg_name`` into a new
        variable whose C++ type corresponds to ``t``.

        :param t: a `Type` of an argument
        :param arg_name: name of the C++ EValue variable holding the argument
        :param mutable: whether this argument type is mutable (forwarded to
            ``argument_type_gen``)
        :return: a tuple ``(out_name, ctype, code, decl)``:
            ``out_name`` is the name of the newly created unboxed C++ variable;
            ``ctype`` is its CType, as produced by ``argument_type_gen``;
            ``code`` is the list of C++ lines performing the conversion;
            ``decl`` is a list of C++ declarations that must be emitted before
            ``code`` (e.g. backing vectors that have to outlive an ArrayRef).
        :raises Exception: if ``t`` is not a BaseType, OptionalType or ListType.
        """
        ctype = self.argument_type_gen(t, mutable=mutable, binds=arg_name).type

        if isinstance(t, BaseType):
            out_name = f"{arg_name}_base"
            code, decl = self._gen_code_base_type(
                arg_name=arg_name, out_name=out_name, ctype=ctype
            )
        elif isinstance(t, OptionalType):
            out_name = f"{arg_name}_opt_out"
            code, decl = self._gen_code_optional_type(
                arg_name=arg_name, out_name=out_name, t=t, ctype=ctype
            )
        elif isinstance(t, ListType):
            out_name = f"{arg_name}_list_out"
            code, decl = self._gen_code_list_type(
                arg_name=arg_name, out_name=out_name, t=t, ctype=ctype
            )
        else:
            raise Exception(  # noqa: TRY002
                f"Cannot handle type {t}. arg_name: {arg_name}"
            )
        return out_name, ctype, code, decl

    def _gen_code_base_type(
        self, arg_name: str, out_name: str, ctype: CType
    ) -> tuple[list[str], list[str]]:
        """
        Emit a single ``<ctype> <out_name> = <arg_name>.to<T>();`` line, where ``T``
        is ``ctype`` with any reference stripped.

        :return: ``(code, decl)``; ``decl`` is always empty for base types.
        """
        return [
            f"{ctype.cpp_type()} {out_name} = {arg_name}.to<{ctype.cpp_type(strip_ref=True)}>();"
        ], []

    def _gen_code_optional_type(
        self, arg_name: str, out_name: str, t: OptionalType, ctype: CType
    ) -> tuple[list[str], list[str]]:
        """
        Emit ``auto <out_name> = <arg_name>.toOptional<T>();``, where ``T`` is the
        element type's CType with any reference stripped.

        ``ctype`` is currently unused.

        :return: ``(code, decl)``. ``code`` is the emitted text split on newlines, so
            it includes blank/whitespace-only lines. ``decl`` is whatever
            declarations the element type required.
        """
        in_name = f"{arg_name}_opt_in"
        # Only the element's CType and declarations are used. Its own conversion
        # code is discarded, because the emitted line calls toOptional<T>() on the
        # EValue directly.
        _, base_type, _, decl = self.argumenttype_evalue_convert(t.elem, in_name)
        return (
            f"""
    auto {out_name} = {arg_name}.toOptional<{base_type.cpp_type(strip_ref=True)}>();
            """.split(
                "\n"
            ),
            decl,
        )

    def _gen_code_list_type(
        self, arg_name: str, out_name: str, t: ListType, ctype: CType
    ) -> tuple[list[str], list[str]]:
        """
        Emit C++ that unboxes the list EValue ``arg_name`` into ``out_name``.

        The branches are:
        * Tensor lists, int/SymInt lists, float lists and bool lists use the
          dedicated EValue accessors.
        * Tensor?[] uses toListOptionalTensor().
        * Every other element type falls back to building a std::vector that is
          declared in ``decl``.

        :return: ``(code, decl)``. ``code`` lines come from splitting multi-line
            templates on newlines, so they include blank/whitespace-only lines.
        """
        in_name = f"{arg_name}_list_in"
        elem_name = f"{arg_name}_elem"
        code = []
        # Element conversion is needed for its CType/declarations in all branches,
        # and for its code/name in the default (vector) branch.
        res_name, res_ctype, res_code, decl = self.argumenttype_evalue_convert(
            t.elem, elem_name
        )

        if isinstance(t.elem, BaseType) and t.elem.name == BaseTy.Tensor:
            code.extend(
                f"""
    auto {out_name} = {arg_name}.toTensorList();
                """.split(
                    "\n"
                )
            )
        elif isinstance(t.elem, BaseType) and t.elem.name in (
            BaseTy.int,
            BaseTy.SymInt,
        ):
            code.extend(
                f"""
    auto {out_name} = {arg_name}.toIntList();
                """.split(
                    "\n"
                )
            )
        elif isinstance(t.elem, BaseType) and t.elem.name == BaseTy.float:
            code.extend(
                f"""
    auto {out_name} = {arg_name}.toDoubleList();
                """.split(
                    "\n"
                )
            )
        elif isinstance(t.elem, BaseType) and t.elem.name == BaseTy.bool:
            # handle list type with size, e.g., bool[4]
            # NOTE(review): the ATen branch interpolates t.size directly. If a bool
            # list without a fixed size reaches here, this would presumably emit
            # `std::array<bool, None>` -- TODO confirm callers only pass sized lists.
            code.extend(
                f"""
#ifdef USE_ATEN_LIB
std::array<bool, {t.size}> {out_name};
auto {in_name} = {arg_name}.toBoolList();
size_t _i = 0;
for (auto {elem_name}: {in_name}) {{
    {out_name}[_i++] = {elem_name};
}}
#else
auto {out_name} = {arg_name}.toBoolList();
#endif
                """.split(
                    "\n"
                )
            )
        # pytorch codegen:
        # we have to use c10::List for optional element. e.g., Tensor?[] -> c10::List<::std::optional<at::Tensor>>
        elif (
            isinstance(t.elem, OptionalType)
            and isinstance(t.elem.elem, BaseType)
            and t.elem.elem.name == BaseTy.Tensor
        ):
            code.extend(
                f"""
#ifdef USE_ATEN_LIB
auto {in_name} = {arg_name}.toListOptionalTensor();
c10::List<::std::optional<at::Tensor>> {out_name};
for (auto {elem_name}: {in_name}) {{
    {out_name}.push_back({elem_name});
}}
#else
auto {out_name} = {arg_name}.toListOptionalTensor();
#endif
                """.split(
                    "\n"
                )
            )
        else:
            # use ArrayRef as default.
            vec_name = f"{arg_name}_vec"
            # need to bring vector instantiation out of scope so that ArrayRef has valid data
            decl.append(
                f"std::vector<{res_ctype.cpp_type(strip_ref=True)}> {vec_name};"
            )
            # NOTE(review): unlike the bool and Tensor?[] branches, nothing on this
            # path emits a declaration for `in_name` (`<arg>_list_in`), since `code`
            # starts empty. The loop below therefore refers to an undeclared C++
            # variable. TODO: decide how a generic list should be read from an
            # EValue and emit that declaration first.
            code.extend(
                f"""
    for (EValue {elem_name}: {in_name}) {{
        {connector.join(res_code)}
        {vec_name}.push_back({res_name});
    }}
    {ctype.cpp_type(strip_ref=True)} {out_name}({vec_name});
                """.split(
                    "\n"
                )
            )
        return code, decl
230        return code, decl
231