"""Unit tests for the argument lists and C++ declarations of ExecutorchCppSignature."""

import unittest

from torchgen.executorch.api.types import ExecutorchCppSignature
from torchgen.local import parametrize
from torchgen.model import Location, NativeFunction


# Single out-variant operator shared by every test in this module.
DEFAULT_NATIVE_FUNCTION, _ = NativeFunction.from_yaml(
    {"func": "foo.out(Tensor input, *, Tensor(a!) out) -> Tensor(a!)"},
    loc=Location(__file__, 1),
    valid_tags=set(),
)

# Keyword flags handed to `parametrize` around every signature query below.
_CODEGEN_FLAGS = {
    "use_const_ref_for_mutable_tensors": False,
    "use_ilistref_for_tensor_lists": False,
}


class ExecutorchCppSignatureTest(unittest.TestCase):
    """Checks `ExecutorchCppSignature` built from `DEFAULT_NATIVE_FUNCTION`.

    Each test queries the signature with and/or without the runtime
    context argument and compares against the expected arguments or
    declaration text.
    """

    def setUp(self) -> None:
        # Build a fresh signature object for every test.
        self.sig = ExecutorchCppSignature.from_native_function(
            DEFAULT_NATIVE_FUNCTION
        )

    def test_runtime_signature_contains_runtime_context(self) -> None:
        # Asking for the context must yield a third argument named "context".
        with parametrize(**_CODEGEN_FLAGS):
            arguments = self.sig.arguments(include_context=True)
            self.assertEqual(len(arguments), 3)
            argument_names = [argument.name for argument in arguments]
            self.assertIn("context", argument_names)

    def test_runtime_signature_does_not_contain_runtime_context(self) -> None:
        # Without the context only the two arguments from the schema remain.
        with parametrize(**_CODEGEN_FLAGS):
            arguments = self.sig.arguments(include_context=False)
            self.assertEqual(len(arguments), 2)
            argument_names = [argument.name for argument in arguments]
            self.assertNotIn("context", argument_names)

    def test_runtime_signature_declaration_correct(self) -> None:
        # Expected declaration text for both settings of `include_context`.
        expected_with_context = (
            "torch::executor::Tensor & foo_outf("
            "torch::executor::KernelRuntimeContext & context, "
            "const torch::executor::Tensor & input, "
            "torch::executor::Tensor & out)"
        )
        expected_without_context = (
            "torch::executor::Tensor & foo_outf("
            "const torch::executor::Tensor & input, "
            "torch::executor::Tensor & out)"
        )
        with parametrize(**_CODEGEN_FLAGS):
            with_context = self.sig.decl(include_context=True)
            self.assertEqual(with_context, expected_with_context)
            without_context = self.sig.decl(include_context=False)
            self.assertEqual(without_context, expected_without_context)