xref: /aosp_15_r20/external/executorch/kernels/quantized/test/test_out_variants.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1#!/usr/bin/env fbpython
2# Copyright (c) Meta Platforms, Inc. and affiliates.
3# All rights reserved.
4#
5# This source code is licensed under the BSD-style license found in the
6# LICENSE file in the root directory of this source tree.
7
8import unittest
9
10import executorch.kernels.quantized  # noqa[F401] 'executorch.kernels.quantized' imported but unused
11
12import torch
13import torch.ao.quantization.fx._decomposed  # noqa[F401] 'torch.ao.quantization.fx._decomposed' imported but unused
14from executorch.exir.dialects._ops import ops
15from executorch.exir.passes._quant_patterns_and_replacements import (  # noqa
16    quantized_decomposed_lib,  # noqa
17)
18
19
class TestOutVariants(unittest.TestCase):
    """Verify the functional -> out-variant mapping of quantized_decomposed edge ops.

    Each test names one operator in ``ops.edge.quantized_decomposed``, the
    functional overload to convert, and the out overload that
    ``to_out_variant()`` is expected to produce. The shared checks live in
    ``_assert_out_variant`` so the expected schema name is always built from
    the same overload whose presence is checked.
    """

    def _assert_out_variant(self, op_name: str, *, functional: str, out: str) -> None:
        """Check that ``op_name.<functional>`` converts to ``op_name.<out>``.

        Args:
            op_name: Operator name inside ``ops.edge.quantized_decomposed``
                (e.g. ``"add"``).
            functional: Overload whose ``to_out_variant()`` is called
                (e.g. ``"default"`` or ``"tensor"``).
            out: Expected out-variant overload (e.g. ``"out"`` or
                ``"Tensor_out"``); the returned operator's ``name()`` must be
                ``"quantized_decomposed::<op_name>.<out>"``.
        """
        packet = getattr(ops.edge.quantized_decomposed, op_name)
        # The out overload must be reachable through the edge op namespace.
        # NOTE(review): the lookup itself presumably raises for an unregistered
        # overload rather than returning None — confirm; kept as an explicit check.
        self.assertIsNotNone(getattr(packet, out))
        out_variant = getattr(packet, functional).to_out_variant()
        self.assertEqual(
            out_variant.name(), f"quantized_decomposed::{op_name}.{out}"
        )

    def test_add_to_out_variant(self) -> None:
        self._assert_out_variant("add", functional="default", out="out")

    def test_choose_qparams_tensor_to_out_variant(self) -> None:
        self._assert_out_variant(
            "choose_qparams", functional="tensor", out="Tensor_out"
        )

    def test_dequantize_per_tensor_to_out_variant(self) -> None:
        self._assert_out_variant(
            "dequantize_per_tensor", functional="default", out="out"
        )

    def test_dequantize_per_tensor_tensor_to_out_variant(self) -> None:
        self._assert_out_variant(
            "dequantize_per_tensor", functional="tensor", out="Tensor_out"
        )

    def test_dequantize_per_channel_to_out_variant(self) -> None:
        self._assert_out_variant(
            "dequantize_per_channel", functional="default", out="out"
        )

    def test_mixed_linear_to_out_variant(self) -> None:
        self._assert_out_variant("mixed_linear", functional="default", out="out")

    def test_mixed_mm_to_out_variant(self) -> None:
        self._assert_out_variant("mixed_mm", functional="default", out="out")

    def test_quantize_per_tensor_to_out_variant(self) -> None:
        self._assert_out_variant(
            "quantize_per_tensor", functional="default", out="out"
        )

    def test_quantize_per_tensor_tensor_to_out_variant(self) -> None:
        self._assert_out_variant(
            "quantize_per_tensor", functional="tensor", out="Tensor_out"
        )

    def test_quantize_per_channel_to_out_variant(self) -> None:
        self._assert_out_variant(
            "quantize_per_channel", functional="default", out="out"
        )
101