xref: /aosp_15_r20/external/pytorch/tools/test/test_executorch_unboxing.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1import unittest
2from types import ModuleType
3
4from torchgen import local
5from torchgen.api import cpp as aten_cpp, types as aten_types
6from torchgen.api.types import (
7    ArgName,
8    BaseCType,
9    ConstRefCType,
10    MutRefCType,
11    NamedCType,
12)
13from torchgen.executorch.api import et_cpp as et_cpp, types as et_types
14from torchgen.executorch.api.unboxing import Unboxing
15from torchgen.model import BaseTy, BaseType, ListType, OptionalType, Type
16
17
def aten_argumenttype_type_wrapper(
    t: Type, *, mutable: bool, binds: ArgName, remove_non_owning_ref_types: bool = False
) -> NamedCType:
    """Adapt `torchgen.api.cpp.argumenttype_type` for use as the
    `argument_type_gen` callable of an `Unboxing` instance.

    Only `mutable`, `binds` and `remove_non_owning_ref_types` are forwarded
    (all by keyword). Any other parameters of `argumenttype_type` are left
    unspecified.
    """
    forwarded_kwargs = {
        "mutable": mutable,
        "binds": binds,
        "remove_non_owning_ref_types": remove_non_owning_ref_types,
    }
    return aten_cpp.argumenttype_type(t, **forwarded_kwargs)
27
28
# Unboxing driven by ExecuTorch's own C++ argument-type generator.
ET_UNBOXING = Unboxing(argument_type_gen=et_cpp.argumenttype_type)
# Unboxing driven by the ATen generator, via the keyword-forwarding wrapper
# defined above.
ATEN_UNBOXING = Unboxing(argument_type_gen=aten_argumenttype_type_wrapper)
31
32
class TestUnboxing(unittest.TestCase):
    """
    Check how `Unboxing.argumenttype_evalue_convert` translates JIT schema
    argument types into C++ types.

    Each case runs against two instances: ATEN_UNBOXING (the ATen argument-type
    generator) and ET_UNBOXING (the ExecuTorch one).

    torch.testing._internal.common_utils could parametrize these cases with
    less boilerplate. However, the GitHub CI job runs the tools unit tests
    without building torch. Each backend variant is therefore written out as
    its own test method, usually delegating to a shared `_test_*` helper.
    """

    @local.parametrize(
        use_const_ref_for_mutable_tensors=False, use_ilistref_for_tensor_lists=False
    )
    def test_symint_argument_translate_ctype_aten(self) -> None:
        # Check that a `SymInt[]` JIT argument is translated into the right C++
        # type. The ATen generator should yield `IntArrayRef` (not a SymInt
        # array type) because ExecuTorch does not use the SymInt signature.

        # pyre-fixme[16]: `enum.Enum` has no attribute `SymInt`
        # pyre-fixme[19]: Call `BaseType.__init__` expects 0 positional arguments, 1 was provided.
        symint_list_type = ListType(elem=BaseType(BaseTy.SymInt), size=None)

        out_name, ctype, _, _ = ATEN_UNBOXING.argumenttype_evalue_convert(
            t=symint_list_type, arg_name="size", mutable=False
        )

        self.assertEqual(out_name, "size_list_out")
        self.assertIsInstance(ctype, BaseCType)
        # pyre-fixme[16]:
        self.assertEqual(ctype, aten_types.BaseCType(aten_types.intArrayRefT))

    @local.parametrize(
        use_const_ref_for_mutable_tensors=False, use_ilistref_for_tensor_lists=False
    )
    def test_symint_argument_translate_ctype_executorch(self) -> None:
        # Check that a `SymInt[]` JIT argument is translated into the right C++
        # type. The ExecuTorch generator should yield an `ArrayRefCType` whose
        # element is `longT` (not a SymInt array type) because ExecuTorch does
        # not use the SymInt signature.

        # pyre-fixme[16]: `enum.Enum` has no attribute `SymInt`
        # pyre-fixme[19]: Call `BaseType.__init__` expects 0 positional arguments, 1 was provided.
        symint_list_type = ListType(elem=BaseType(BaseTy.SymInt), size=None)

        out_name, ctype, _, _ = ET_UNBOXING.argumenttype_evalue_convert(
            t=symint_list_type, arg_name="size", mutable=False
        )

        self.assertEqual(out_name, "size_list_out")
        self.assertIsInstance(ctype, et_types.ArrayRefCType)
        # pyre-fixme[16]:
        self.assertEqual(
            ctype, et_types.ArrayRefCType(elem=BaseCType(aten_types.longT))
        )

    @local.parametrize(
        use_const_ref_for_mutable_tensors=False, use_ilistref_for_tensor_lists=False
    )
    def _test_const_tensor_argument_translate_ctype(
        self, unboxing: Unboxing, types: ModuleType
    ) -> None:
        """Assert that a non-mutable `Tensor` argument named `self` unboxes to
        a const reference to the backend's `tensorT`, with output name
        `self_base`.

        Args:
            unboxing: The `Unboxing` instance under test (ATen or ExecuTorch).
            types: The module holding that backend's C++ type constants
                (`aten_types` or `et_types`).
        """
        # pyre-fixme[16]: `enum.Enum` has no attribute `Tensor`
        # pyre-fixme[19]: Call `BaseType.__init__` expects 0 positional arguments, 1 was provided.
        tensor_type = BaseType(BaseTy.Tensor)

        out_name, ctype, _, _ = unboxing.argumenttype_evalue_convert(
            t=tensor_type, arg_name="self", mutable=False
        )

        self.assertEqual(out_name, "self_base")
        # pyre-fixme[16]:
        self.assertEqual(ctype, ConstRefCType(BaseCType(types.tensorT)))

    def test_const_tensor_argument_translate_ctype_aten(self) -> None:
        self._test_const_tensor_argument_translate_ctype(ATEN_UNBOXING, aten_types)

    def test_const_tensor_argument_translate_ctype_executorch(self) -> None:
        self._test_const_tensor_argument_translate_ctype(ET_UNBOXING, et_types)

    @local.parametrize(
        use_const_ref_for_mutable_tensors=False, use_ilistref_for_tensor_lists=False
    )
    def _test_mutable_tensor_argument_translate_ctype(
        self, unboxing: Unboxing, types: ModuleType
    ) -> None:
        """Assert that a mutable `Tensor` argument named `out` unboxes to a
        mutable reference to the backend's `tensorT`, with output name
        `out_base`.

        Args:
            unboxing: The `Unboxing` instance under test (ATen or ExecuTorch).
            types: The module holding that backend's C++ type constants
                (`aten_types` or `et_types`).
        """
        # pyre-fixme[16]: `enum.Enum` has no attribute `Tensor`
        # pyre-fixme[19]: Call `BaseType.__init__` expects 0 positional arguments, 1 was provided.
        tensor_type = BaseType(BaseTy.Tensor)

        out_name, ctype, _, _ = unboxing.argumenttype_evalue_convert(
            t=tensor_type, arg_name="out", mutable=True
        )

        self.assertEqual(out_name, "out_base")
        # pyre-fixme[16]:
        self.assertEqual(ctype, MutRefCType(BaseCType(types.tensorT)))

    def test_mutable_tensor_argument_translate_ctype_aten(self) -> None:
        self._test_mutable_tensor_argument_translate_ctype(ATEN_UNBOXING, aten_types)

    def test_mutable_tensor_argument_translate_ctype_executorch(self) -> None:
        self._test_mutable_tensor_argument_translate_ctype(ET_UNBOXING, et_types)

    @local.parametrize(
        use_const_ref_for_mutable_tensors=False, use_ilistref_for_tensor_lists=False
    )
    def _test_tensor_list_argument_translate_ctype(
        self, unboxing: Unboxing, types: ModuleType
    ) -> None:
        """Assert that a mutable `Tensor[]` argument named `out` unboxes to the
        backend's `tensorListT`, with output name `out_list_out`.

        Args:
            unboxing: The `Unboxing` instance under test (ATen or ExecuTorch).
            types: The module holding that backend's C++ type constants
                (`aten_types` or `et_types`).
        """
        # pyre-fixme[16]: `enum.Enum` has no attribute `Tensor`
        # pyre-fixme[19]: Call `BaseType.__init__` expects 0 positional arguments, 1 was provided.
        tensor_list_type = ListType(elem=BaseType(BaseTy.Tensor), size=None)

        out_name, ctype, _, _ = unboxing.argumenttype_evalue_convert(
            t=tensor_list_type, arg_name="out", mutable=True
        )

        self.assertEqual(out_name, "out_list_out")
        # pyre-fixme[16]:
        self.assertEqual(ctype, BaseCType(types.tensorListT))

    def test_tensor_list_argument_translate_ctype_aten(self) -> None:
        self._test_tensor_list_argument_translate_ctype(ATEN_UNBOXING, aten_types)

    def test_tensor_list_argument_translate_ctype_executorch(self) -> None:
        self._test_tensor_list_argument_translate_ctype(ET_UNBOXING, et_types)

    @local.parametrize(
        use_const_ref_for_mutable_tensors=False, use_ilistref_for_tensor_lists=False
    )
    def _test_optional_int_argument_translate_ctype(
        self, unboxing: Unboxing, types: ModuleType
    ) -> None:
        """Assert that an `int?` argument named `something` unboxes to the
        backend's `OptionalCType` wrapping `longT`, with output name
        `something_opt_out`.

        Args:
            unboxing: The `Unboxing` instance under test (ATen or ExecuTorch).
            types: The module holding that backend's C++ type constants and
                `OptionalCType` (`aten_types` or `et_types`).
        """
        # pyre-fixme[16]: `enum.Enum` has no attribute `int`
        # pyre-fixme[19]: Call `BaseType.__init__` expects 0 positional arguments, 1 was provided.
        optional_int_type = OptionalType(elem=BaseType(BaseTy.int))

        out_name, ctype, _, _ = unboxing.argumenttype_evalue_convert(
            t=optional_int_type, arg_name="something", mutable=True
        )

        self.assertEqual(out_name, "something_opt_out")
        # pyre-fixme[16]:
        self.assertEqual(ctype, types.OptionalCType(BaseCType(types.longT)))

    def test_optional_int_argument_translate_ctype_aten(self) -> None:
        self._test_optional_int_argument_translate_ctype(ATEN_UNBOXING, aten_types)

    def test_optional_int_argument_translate_ctype_executorch(self) -> None:
        self._test_optional_int_argument_translate_ctype(ET_UNBOXING, et_types)
177