# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import unittest

import torch
from executorch.backends.xnnpack.test.tester import Tester
from torchaudio import models


class TestW2L(unittest.TestCase):
    """Lower torchaudio's Wav2Letter through the XNNPACK delegate and compare outputs.

    Both tests export the model with a dynamic batch dimension, lower it with
    ``to_edge_transform_and_lower``, and check that the convolution / relu
    edge ops no longer appear while an ``executorch_call_delegate`` node does.
    They then serialize the program and compare its outputs against the eager
    model over several runs. ``test_qs8_w2l`` adds a quantization stage
    before export.
    """

    batch_size = 10
    input_frames = 700
    vocab_size = 4096
    num_features = 1
    wav2letter = models.Wav2Letter(num_classes=vocab_size).eval()

    model_inputs = (torch.randn(batch_size, num_features, input_frames),)
    # Only dim 0 (batch) is dynamic; the example input uses the upper bound.
    dynamic_shape = ({0: torch.export.Dim("batch", min=2, max=10)},)

    # Edge-dialect op names that must NOT remain in the lowered graph.
    # NOTE: the spelling is "dialects". A previous "dialectes" typo on the
    # convolution entry made that check unable to ever match, so it passed
    # vacuously.
    undelegated_ops = [
        "executorch_exir_dialects_edge__ops_aten_convolution_default",
        "executorch_exir_dialects_edge__ops_aten_relu_default",
    ]

    def _lower_and_compare(self, tester):
        """Run the shared export -> lower -> check -> serialize -> compare chain.

        Args:
            tester: A ``Tester`` built for ``self.wav2letter``, optionally
                already advanced through a ``quantize()`` stage.
        """
        (
            tester.export()
            .to_edge_transform_and_lower()
            .check_not(self.undelegated_ops)
            .check(["torch.ops.higher_order.executorch_call_delegate"])
            .to_executorch()
            .serialize()
            .run_method_and_compare_outputs(num_runs=10)
        )

    def test_fp32_w2l(self):
        self._lower_and_compare(
            Tester(self.wav2letter, self.model_inputs, self.dynamic_shape)
        )

    def test_qs8_w2l(self):
        # The class-level model is already in eval mode (see `wav2letter`).
        self._lower_and_compare(
            Tester(self.wav2letter, self.model_inputs, self.dynamic_shape).quantize()
        )