xref: /aosp_15_r20/external/pytorch/tools/test/test_executorch_signatures.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
import contextlib
import unittest

from torchgen.executorch.api.types import ExecutorchCppSignature
from torchgen.local import parametrize
from torchgen.model import Location, NativeFunction
6
7
# Schema shared by every test in this file: an out-variant `foo.out` taking one
# `input` tensor plus a keyword-only `out` tensor annotated `Tensor(a!)`, the
# same alias annotation used for its return type. The entry declares no tags,
# so no valid tags are supplied. `Location` points back at this file
# (presumably only used for torchgen diagnostics — confirm). The second element
# returned by `from_yaml` is not needed here and is discarded.
DEFAULT_NATIVE_FUNCTION, _ = NativeFunction.from_yaml(
    dict(func="foo.out(Tensor input, *, Tensor(a!) out) -> Tensor(a!)"),
    valid_tags=set(),
    loc=Location(__file__, 1),
)
13
14
class ExecutorchCppSignatureTest(unittest.TestCase):
    """Tests for ``ExecutorchCppSignature`` built from ``DEFAULT_NATIVE_FUNCTION``.

    Covers both the argument list and the rendered C++ declaration. Each is
    checked with and without the ``KernelRuntimeContext & context`` argument.
    """

    def setUp(self) -> None:
        self.sig = ExecutorchCppSignature.from_native_function(DEFAULT_NATIVE_FUNCTION)

    @staticmethod
    def _codegen_settings() -> "contextlib.AbstractContextManager[None]":
        """Return the ``torchgen.local.parametrize`` context the tests run under.

        Every signature query in this class uses the same configuration, with
        both flags disabled. Under ``use_const_ref_for_mutable_tensors=False``
        the expected declarations render the mutable ``out`` tensor as a
        non-const ``torch::executor::Tensor &``.
        """
        return parametrize(
            use_const_ref_for_mutable_tensors=False,
            use_ilistref_for_tensor_lists=False,
        )

    def test_runtime_signature_contains_runtime_context(self) -> None:
        # test if `KernelRuntimeContext` argument exists in `RuntimeSignature`
        with self._codegen_settings():
            names = [a.name for a in self.sig.arguments(include_context=True)]
        # context + input + out. assertIn reports the actual names on failure,
        # unlike assertTrue(any(...)).
        self.assertEqual(len(names), 3)
        self.assertIn("context", names)

    def test_runtime_signature_does_not_contain_runtime_context(self) -> None:
        # test if `KernelRuntimeContext` argument is missing in `RuntimeSignature`
        with self._codegen_settings():
            names = [a.name for a in self.sig.arguments(include_context=False)]
        # Only input + out remain.
        self.assertEqual(len(names), 2)
        self.assertNotIn("context", names)

    def test_runtime_signature_declaration_correct(self) -> None:
        with self._codegen_settings():
            decl = self.sig.decl(include_context=True)
            no_context_decl = self.sig.decl(include_context=False)
        # When included, the context argument comes first in the declaration.
        self.assertEqual(
            decl,
            (
                "torch::executor::Tensor & foo_outf("
                "torch::executor::KernelRuntimeContext & context, "
                "const torch::executor::Tensor & input, "
                "torch::executor::Tensor & out)"
            ),
        )
        self.assertEqual(
            no_context_decl,
            (
                "torch::executor::Tensor & foo_outf("
                "const torch::executor::Tensor & input, "
                "torch::executor::Tensor & out)"
            ),
        )
60