import unittest
from types import ModuleType

from torchgen import local
from torchgen.api import cpp as aten_cpp, types as aten_types
from torchgen.api.types import (
    ArgName,
    BaseCType,
    ConstRefCType,
    MutRefCType,
    NamedCType,
)
from torchgen.executorch.api import et_cpp, types as et_types
from torchgen.executorch.api.unboxing import Unboxing
from torchgen.model import BaseTy, BaseType, ListType, OptionalType, Type


def aten_argumenttype_type_wrapper(
    t: Type, *, mutable: bool, binds: ArgName, remove_non_owning_ref_types: bool = False
) -> NamedCType:
    """Adapt ``aten_cpp.argumenttype_type`` to the callable shape ``Unboxing`` takes.

    Every argument is forwarded unchanged; the wrapper only fixes the keyword
    signature passed as ``Unboxing(argument_type_gen=...)``.
    """
    return aten_cpp.argumenttype_type(
        t,
        mutable=mutable,
        binds=binds,
        remove_non_owning_ref_types=remove_non_owning_ref_types,
    )


# Unboxing helpers under test: one backed by the ATen C++ type mapping, one by
# the ExecuTorch mapping.
ATEN_UNBOXING = Unboxing(argument_type_gen=aten_argumenttype_type_wrapper)
ET_UNBOXING = Unboxing(argument_type_gen=et_cpp.argumenttype_type)


class TestUnboxing(unittest.TestCase):
    """
    Could use torch.testing._internal.common_utils to reduce boilerplate.
    GH CI job doesn't build torch before running tools unit tests, hence
    manually adding these parametrized tests.

    Each ``_test_*`` helper is shared by an ATen and an ExecuTorch variant;
    the helper receives the ``Unboxing`` instance plus the matching ``types``
    module. Every check runs under ``local.parametrize`` with
    ``use_const_ref_for_mutable_tensors`` and ``use_ilistref_for_tensor_lists``
    both set to False.
    """

    @local.parametrize(
        use_const_ref_for_mutable_tensors=False, use_ilistref_for_tensor_lists=False
    )
    def test_symint_argument_translate_ctype_aten(self) -> None:
        # Test that a `SymInt[]` JIT argument is translated into the right C++
        # argument. It should be `IntArrayRef` (`intArrayRefT`) because
        # ExecuTorch doesn't use the symint signature.

        # pyre-fixme[16]: `enum.Enum` has no attribute `SymInt`
        # pyre-fixme[19]: Call `BaseType.__init__` expects 0 positional arguments, 1 was provided.
        symint_list_type = ListType(elem=BaseType(BaseTy.SymInt), size=None)

        out_name, ctype, _, _ = ATEN_UNBOXING.argumenttype_evalue_convert(
            t=symint_list_type, arg_name="size", mutable=False
        )

        self.assertEqual(out_name, "size_list_out")
        self.assertIsInstance(ctype, BaseCType)
        # pyre-fixme[16]:
        self.assertEqual(ctype, aten_types.BaseCType(aten_types.intArrayRefT))

    @local.parametrize(
        use_const_ref_for_mutable_tensors=False, use_ilistref_for_tensor_lists=False
    )
    def test_symint_argument_translate_ctype_executorch(self) -> None:
        # Test that a `SymInt[]` JIT argument is translated into the right C++
        # argument. With the ExecuTorch mapping it should be an ExecuTorch
        # `ArrayRefCType` whose element is `longT`, not a symint type, because
        # ExecuTorch doesn't use the symint signature.

        # pyre-fixme[16]: `enum.Enum` has no attribute `SymInt`
        # pyre-fixme[19]: Call `BaseType.__init__` expects 0 positional arguments, 1 was provided.
        symint_list_type = ListType(elem=BaseType(BaseTy.SymInt), size=None)

        out_name, ctype, _, _ = ET_UNBOXING.argumenttype_evalue_convert(
            t=symint_list_type, arg_name="size", mutable=False
        )

        self.assertEqual(out_name, "size_list_out")
        self.assertIsInstance(ctype, et_types.ArrayRefCType)
        # pyre-fixme[16]:
        self.assertEqual(
            ctype, et_types.ArrayRefCType(elem=BaseCType(aten_types.longT))
        )

    @local.parametrize(
        use_const_ref_for_mutable_tensors=False, use_ilistref_for_tensor_lists=False
    )
    def _test_const_tensor_argument_translate_ctype(
        self, unboxing: Unboxing, types: ModuleType
    ) -> None:
        """Check that an immutable ``Tensor`` becomes ``const tensorT&``.

        Args:
            unboxing: The ``Unboxing`` instance under test.
            types: Module providing the backend's ``tensorT``.
        """
        # pyre-fixme[16]: `enum.Enum` has no attribute `Tensor`
        # pyre-fixme[19]: Call `BaseType.__init__` expects 0 positional arguments, 1 was provided.
        tensor_type = BaseType(BaseTy.Tensor)

        out_name, ctype, _, _ = unboxing.argumenttype_evalue_convert(
            t=tensor_type, arg_name="self", mutable=False
        )

        self.assertEqual(out_name, "self_base")
        # pyre-fixme[16]:
        self.assertEqual(ctype, ConstRefCType(BaseCType(types.tensorT)))

    def test_const_tensor_argument_translate_ctype_aten(self) -> None:
        self._test_const_tensor_argument_translate_ctype(ATEN_UNBOXING, aten_types)

    def test_const_tensor_argument_translate_ctype_executorch(self) -> None:
        self._test_const_tensor_argument_translate_ctype(ET_UNBOXING, et_types)

    @local.parametrize(
        use_const_ref_for_mutable_tensors=False, use_ilistref_for_tensor_lists=False
    )
    def _test_mutable_tensor_argument_translate_ctype(
        self, unboxing: Unboxing, types: ModuleType
    ) -> None:
        """Check that a mutable ``Tensor`` becomes a mutable ``tensorT&``.

        Args:
            unboxing: The ``Unboxing`` instance under test.
            types: Module providing the backend's ``tensorT``.
        """
        # pyre-fixme[16]: `enum.Enum` has no attribute `Tensor`
        # pyre-fixme[19]: Call `BaseType.__init__` expects 0 positional arguments, 1 was provided.
        tensor_type = BaseType(BaseTy.Tensor)

        out_name, ctype, _, _ = unboxing.argumenttype_evalue_convert(
            t=tensor_type, arg_name="out", mutable=True
        )

        self.assertEqual(out_name, "out_base")
        # pyre-fixme[16]:
        self.assertEqual(ctype, MutRefCType(BaseCType(types.tensorT)))

    def test_mutable_tensor_argument_translate_ctype_aten(self) -> None:
        self._test_mutable_tensor_argument_translate_ctype(ATEN_UNBOXING, aten_types)

    def test_mutable_tensor_argument_translate_ctype_executorch(self) -> None:
        self._test_mutable_tensor_argument_translate_ctype(ET_UNBOXING, et_types)

    @local.parametrize(
        use_const_ref_for_mutable_tensors=False, use_ilistref_for_tensor_lists=False
    )
    def _test_tensor_list_argument_translate_ctype(
        self, unboxing: Unboxing, types: ModuleType
    ) -> None:
        """Check that a ``Tensor[]`` argument becomes the backend's ``tensorListT``.

        Args:
            unboxing: The ``Unboxing`` instance under test.
            types: Module providing the backend's ``tensorListT``.
        """
        # pyre-fixme[16]: `enum.Enum` has no attribute `Tensor`
        # pyre-fixme[19]: Call `BaseType.__init__` expects 0 positional arguments, 1 was provided.
        tensor_list_type = ListType(elem=BaseType(BaseTy.Tensor), size=None)

        out_name, ctype, _, _ = unboxing.argumenttype_evalue_convert(
            t=tensor_list_type, arg_name="out", mutable=True
        )

        self.assertEqual(out_name, "out_list_out")
        # pyre-fixme[16]:
        self.assertEqual(ctype, BaseCType(types.tensorListT))

    def test_tensor_list_argument_translate_ctype_aten(self) -> None:
        self._test_tensor_list_argument_translate_ctype(ATEN_UNBOXING, aten_types)

    def test_tensor_list_argument_translate_ctype_executorch(self) -> None:
        self._test_tensor_list_argument_translate_ctype(ET_UNBOXING, et_types)

    @local.parametrize(
        use_const_ref_for_mutable_tensors=False, use_ilistref_for_tensor_lists=False
    )
    def _test_optional_int_argument_translate_ctype(
        self, unboxing: Unboxing, types: ModuleType
    ) -> None:
        """Check that an ``int?`` argument becomes the backend's ``OptionalCType(longT)``.

        Args:
            unboxing: The ``Unboxing`` instance under test.
            types: Module providing the backend's ``OptionalCType`` and ``longT``.
        """
        # pyre-fixme[16]: `enum.Enum` has no attribute `int`
        # pyre-fixme[19]: Call `BaseType.__init__` expects 0 positional arguments, 1 was provided.
        optional_int_type = OptionalType(elem=BaseType(BaseTy.int))

        out_name, ctype, _, _ = unboxing.argumenttype_evalue_convert(
            t=optional_int_type, arg_name="something", mutable=True
        )

        self.assertEqual(out_name, "something_opt_out")
        # pyre-fixme[16]:
        self.assertEqual(ctype, types.OptionalCType(BaseCType(types.longT)))

    def test_optional_int_argument_translate_ctype_aten(self) -> None:
        self._test_optional_int_argument_translate_ctype(ATEN_UNBOXING, aten_types)

    def test_optional_int_argument_translate_ctype_executorch(self) -> None:
        self._test_optional_int_argument_translate_ctype(ET_UNBOXING, et_types)