#!/usr/bin/env fbpython
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import unittest

import executorch.kernels.quantized  # noqa[F401] 'executorch.kernels.quantized' imported but unused

import torch
import torch.ao.quantization.fx._decomposed  # noqa[F401] 'torch.ao.quantization.fx._decomposed' imported but unused
from executorch.exir.dialects._ops import ops
from executorch.exir.passes._quant_patterns_and_replacements import (  # noqa
    quantized_decomposed_lib,  # noqa
)


class TestOutVariants(unittest.TestCase):
    """Check that each quantized_decomposed edge op maps to its out variant."""

    def _assert_out_variant(
        self, packet_name: str, functional: str, out: str, expected: str
    ) -> None:
        """Verify that ``packet_name.functional`` converts to ``packet_name.out``.

        Args:
            packet_name: attribute name of the op packet under
                ``ops.edge.quantized_decomposed``.
            functional: attribute name of the functional overload that is
                converted with ``to_out_variant()``.
            out: attribute name of the out overload; it must resolve to a
                non-None object.
            expected: the full schema name that the converted op's ``name()``
                must return.
        """
        packet = getattr(ops.edge.quantized_decomposed, packet_name)
        # Check the out overload is reachable before converting, the same
        # order the individual tests always used.
        self.assertIsNotNone(getattr(packet, out))
        converted = getattr(packet, functional).to_out_variant()
        self.assertEqual(converted.name(), expected)

    def test_add_to_out_variant(self) -> None:
        self._assert_out_variant(
            "add", "default", "out", "quantized_decomposed::add.out"
        )

    def test_choose_qparams_tensor_to_out_variant(self) -> None:
        self._assert_out_variant(
            "choose_qparams",
            "tensor",
            "Tensor_out",
            "quantized_decomposed::choose_qparams.Tensor_out",
        )

    def test_dequantize_per_tensor_to_out_variant(self) -> None:
        self._assert_out_variant(
            "dequantize_per_tensor",
            "default",
            "out",
            "quantized_decomposed::dequantize_per_tensor.out",
        )

    def test_dequantize_per_tensor_tensor_to_out_variant(self) -> None:
        self._assert_out_variant(
            "dequantize_per_tensor",
            "tensor",
            "Tensor_out",
            "quantized_decomposed::dequantize_per_tensor.Tensor_out",
        )

    def test_dequantize_per_channel_to_out_variant(self) -> None:
        self._assert_out_variant(
            "dequantize_per_channel",
            "default",
            "out",
            "quantized_decomposed::dequantize_per_channel.out",
        )

    def test_mixed_linear_to_out_variant(self) -> None:
        self._assert_out_variant(
            "mixed_linear",
            "default",
            "out",
            "quantized_decomposed::mixed_linear.out",
        )

    def test_mixed_mm_to_out_variant(self) -> None:
        self._assert_out_variant(
            "mixed_mm", "default", "out", "quantized_decomposed::mixed_mm.out"
        )

    def test_quantize_per_tensor_to_out_variant(self) -> None:
        self._assert_out_variant(
            "quantize_per_tensor",
            "default",
            "out",
            "quantized_decomposed::quantize_per_tensor.out",
        )

    def test_quantize_per_tensor_tensor_to_out_variant(self) -> None:
        self._assert_out_variant(
            "quantize_per_tensor",
            "tensor",
            "Tensor_out",
            "quantized_decomposed::quantize_per_tensor.Tensor_out",
        )

    def test_quantize_per_channel_to_out_variant(self) -> None:
        self._assert_out_variant(
            "quantize_per_channel",
            "default",
            "out",
            "quantized_decomposed::quantize_per_channel.out",
        )