xref: /aosp_15_r20/external/executorch/exir/tests/test_op_convert.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1# Copyright (c) Meta Platforms, Inc. and affiliates.
2# All rights reserved.
3#
4# This source code is licensed under the BSD-style license found in the
5# LICENSE file in the root directory of this source tree.
6
7# pyre-strict
8
9import unittest
10
11import torch
12from executorch.exir.operator import convert as op_convert
13from executorch.exir.operator.convert import to_out_variant
14from torch._ops import OpOverload
15
16
class TestToOutVariant(unittest.TestCase):
    """Tests for ``is_out_variant`` and ``to_out_variant`` from
    ``executorch.exir.operator.convert``."""

    def test_already_out_var(self) -> None:
        """``is_out_variant`` accepts ``aten::topk`` with overload ``values``."""
        self.assertTrue(op_convert.is_out_variant("aten::topk", "values"))

    def test_to_out_variant_already_out(self) -> None:
        """Passing ``topk.values`` to ``to_out_variant`` yields the same op."""
        op_overload = torch.ops.aten.topk.values
        out_var_op = op_convert.to_out_variant(op_overload)[0]
        # assertIs reports both operands on failure, unlike assertTrue(a is b).
        self.assertIs(out_var_op, op_overload)

    def test_to_out_variant_success(self) -> None:
        """``topk.default`` maps to ``topk.values`` and both compute the same
        result for the same input."""
        op_overload = torch.ops.aten.topk.default
        out_var_op, out_args = op_convert.to_out_variant(op_overload)

        self.assertIsNot(out_var_op, torch.ops.aten.topk.default)
        self.assertIs(out_var_op, torch.ops.aten.topk.values)

        input_tensor = torch.randn(100, 200)
        k = 10

        # Reference result from the functional overload.
        expect_values, expect_indices = op_overload(input_tensor, k)

        # One zero-element buffer per out argument, with the dtype of the
        # matching functional result; names are paired with results by
        # position.
        out_kwargs = {
            name: torch.empty(0, dtype=expected.dtype)
            for name, expected in zip(out_args, (expect_values, expect_indices))
        }

        actual_values, actual_indices = out_var_op(input_tensor, k, **out_kwargs)

        self.assertTrue(torch.equal(expect_values, actual_values))
        self.assertTrue(torch.equal(expect_indices, actual_indices))

    # These checks are copied from the unsafe_replace_to_out_variant method
    # (https://www.fburl.com/code/ukwq31xz),
    # which holds patch rules for the functional ops that could not be
    # handled generically before. These unit tests show off that we can now
    # handle the custom ops generically!
    def test_to_out_variant_batch(self) -> None:
        """Check ``to_out_variant`` against a table of functional ops.

        A table value is either the expected out-variant overload alone, in
        which case the expected out argument names are ``("out",)``, or an
        ``(overload, out_arg_names)`` pair.
        """
        aten = torch.ops.aten
        # NOTE(review): this namespace is not registered anywhere in this
        # file; presumably a test op library is loaded elsewhere -- confirm.
        custom_op = torch.ops.my_awesome_3rdparty_ns.my_awesome_op
        checklist = {
            aten.topk.default: (aten.topk.values, ("values", "indices")),
            aten.view_copy.default: aten.view_copy.out,
            aten.log_softmax.int: aten.log_softmax.int_out,
            aten.softmax.int: aten.softmax.int_out,
            aten.relu.default: aten.relu.out,
            custom_op.func: custom_op.out,
        }
        for func_op, expected_any in checklist.items():
            # subTest keeps checking the remaining ops after a failure and
            # names the failing op in the report.
            with self.subTest(func_op=func_op):
                if isinstance(expected_any, OpOverload):
                    # the default case where the out args are ("out",)
                    expected_out_var = expected_any
                    expected_out_args = ("out",)
                else:
                    expected_out_var, expected_out_args = expected_any
                actual_out_var, actual_out_args = op_convert.to_out_variant(func_op)
                self.assertEqual(expected_out_var, actual_out_var)
                self.assertEqual(expected_out_args, actual_out_args)

    def test_to_out_variant_schema_mismatch(self) -> None:
        """``to_out_variant`` raises ``RuntimeError`` for
        ``schema_mismatch_op.default``, with a message reporting the schema
        mismatch."""
        func_var_op: OpOverload = (
            torch.ops.my_awesome_3rdparty_ns.schema_mismatch_op.default
        )
        with self.assertRaisesRegex(
            RuntimeError,
            "Found an out variant for operator name .* but its schema mismatched with functional op.",
        ):
            to_out_variant(func_var_op)
89