xref: /aosp_15_r20/external/executorch/backends/xnnpack/test/models/w2l.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1# Copyright (c) Meta Platforms, Inc. and affiliates.
2# All rights reserved.
3#
4# This source code is licensed under the BSD-style license found in the
5# LICENSE file in the root directory of this source tree.
6
7import unittest
8
9import torch
10from executorch.backends.xnnpack.test.tester import Tester
11from torchaudio import models
12
13
class TestW2L(unittest.TestCase):
    """XNNPACK delegation tests for torchaudio's Wav2Letter model.

    Each test drives the model through the ``Tester`` pipeline: export, edge
    lowering, ExecuTorch conversion and serialization. It then checks that the
    targeted ops were delegated, runs the serialized method, and compares
    outputs via ``run_method_and_compare_outputs``.
    """

    batch_size = 10
    input_frames = 700
    vocab_size = 4096
    num_features = 1
    # Pass num_features explicitly so the model's expected input channel count
    # stays tied to the shape of ``model_inputs`` below.
    wav2letter = models.Wav2Letter(
        num_classes=vocab_size, num_features=num_features
    ).eval()

    # Example input laid out as (batch, num_features, input_frames).
    model_inputs = (torch.randn(batch_size, num_features, input_frames),)
    # Only dim 0 (batch) is dynamic, bounded to [2, 10]; batch_size above sits
    # at the upper bound.
    dynamic_shape = ({0: torch.export.Dim("batch", min=2, max=10)},)

    # Edge-dialect op names that must NOT remain in the lowered graph: every
    # convolution and relu is expected to be absorbed into a delegate call.
    # The exact spelling matters. A misspelled name (e.g. "dialectes") never
    # matches any node, which silently turns the check_not into a no-op.
    _UNDELEGATED_OPS = (
        "executorch_exir_dialects_edge__ops_aten_convolution_default",
        "executorch_exir_dialects_edge__ops_aten_relu_default",
    )
    # Node name that must appear once lowering has produced a delegate.
    _DELEGATE_CALL = "torch.ops.higher_order.executorch_call_delegate"

    def test_fp32_w2l(self):
        """Lower the float model and check conv/relu are fully delegated."""
        (
            Tester(self.wav2letter, self.model_inputs, self.dynamic_shape)
            .export()
            .to_edge_transform_and_lower()
            .check_not(list(self._UNDELEGATED_OPS))
            .check([self._DELEGATE_CALL])
            .to_executorch()
            .serialize()
            .run_method_and_compare_outputs(num_runs=10)
        )

    def test_qs8_w2l(self):
        """Quantize, then lower, and check conv/relu are fully delegated."""
        (
            Tester(self.wav2letter.eval(), self.model_inputs, self.dynamic_shape)
            .quantize()
            .export()
            .to_edge_transform_and_lower()
            .check_not(list(self._UNDELEGATED_OPS))
            .check([self._DELEGATE_CALL])
            .to_executorch()
            .serialize()
            .run_method_and_compare_outputs(num_runs=10)
        )
58