# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

import unittest

import torch
from executorch.exir.operator import convert as op_convert
from executorch.exir.operator.convert import to_out_variant
from torch._ops import OpOverload


class TestToOutVariant(unittest.TestCase):
    """Tests for converting functional op overloads to their out variants."""

    def test_already_out_var(self) -> None:
        # aten::topk's "values" overload writes into caller-supplied tensors,
        # so it must be recognized as an out variant.
        self.assertTrue(
            op_convert.is_out_variant(
                "aten::topk",
                "values",
            )
        )

    def test_to_out_variant_already_out(self) -> None:
        # Converting an op that is already an out variant must return the
        # very same op object (identity, not just equality).
        op_overload = torch.ops.aten.topk.values
        out_var_op = op_convert.to_out_variant(op_overload)[0]
        self.assertTrue(op_overload is out_var_op)

    def test_to_out_variant_success(self) -> None:
        # Converting the functional aten.topk.default should yield
        # aten.topk.values plus the names of its out arguments, and calling
        # the out variant should produce the same results as the functional op.
        op_overload = torch.ops.aten.topk.default
        out_var_op, out_args = op_convert.to_out_variant(op_overload)

        input_tensor = torch.randn(100, 200)
        k = 10

        self.assertTrue(out_var_op is not torch.ops.aten.topk.default)
        self.assertTrue(out_var_op is torch.ops.aten.topk.values)

        expect_values, expect_indices = op_overload(input_tensor, k)

        # Supply empty out tensors of the right dtypes; the out-variant op
        # resizes them to hold the results.
        kwargs = {
            name: torch.empty(0, dtype=val.dtype)
            for name, val in zip(out_args, (expect_values, expect_indices))
        }

        actual_values, actual_indices = out_var_op(input_tensor, k, **kwargs)

        self.assertTrue(torch.equal(expect_values, actual_values))
        self.assertTrue(torch.equal(expect_indices, actual_indices))

    # These checks are copied from the unsafe_replace_to_out_variant method
    # (https://www.fburl.com/code/ukwq31xz)
    # which are patch rules for the functional ops that can not be
    # handled generically before. Add unit tests to showoff that we can handle
    # the custom ops generically now!
59 def test_to_out_variant_batch(self) -> None: 60 aten = torch.ops.aten 61 checklist = { 62 aten.topk.default: (aten.topk.values, ("values", "indices")), 63 aten.view_copy.default: aten.view_copy.out, 64 aten.log_softmax.int: aten.log_softmax.int_out, 65 aten.softmax.int: aten.softmax.int_out, 66 aten.relu.default: aten.relu.out, 67 torch.ops.my_awesome_3rdparty_ns.my_awesome_op.func: torch.ops.my_awesome_3rdparty_ns.my_awesome_op.out, 68 } 69 for func_op, expected_any in checklist.items(): 70 if isinstance(expected_any, OpOverload): 71 # the default case where the out args are ("out",) 72 expected_out_var = expected_any 73 expected_out_args = ("out",) 74 else: 75 expected_out_var, expected_out_args = expected_any 76 actual_out_var, actual_out_args = op_convert.to_out_variant(func_op) 77 self.assertEqual(expected_out_var, actual_out_var) 78 self.assertEqual(expected_out_args, actual_out_args) 79 80 def test_to_out_variant_schema_mismatch(self) -> None: 81 func_var_op: OpOverload = ( 82 torch.ops.my_awesome_3rdparty_ns.schema_mismatch_op.default 83 ) 84 with self.assertRaisesRegex( 85 RuntimeError, 86 "Found an out variant for operator name .* but its schema mismatched with functional op.", 87 ): 88 to_out_variant(func_var_op) 89