xref: /aosp_15_r20/external/executorch/backends/qualcomm/tests/test_qnn_delegate.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1# Copyright (c) Qualcomm Innovation Center, Inc.
2# All rights reserved
3#
4# This source code is licensed under the BSD-style license found in the
5# LICENSE file in the root directory of this source tree.
6import io
7import json
8import subprocess
9import sys
10import tempfile
11import unittest
12from multiprocessing.connection import Listener
13from pathlib import Path
14
15import torch
16from executorch.backends.qualcomm.tests.utils import (
17    generate_context_binary,
18    QnnPartitioner,
19    QnnQuantizer,
20    QuantDtype,
21    TestQNN,
22    to_backend,
23)
24from executorch.backends.qualcomm.utils.constants import (
25    QCOM_ANNOTATION,
26    QCOM_MODULE,
27    QCOM_QUANT_DTYPE,
28    QCOM_SAMPLE_INPUTS,
29)
30
31from executorch.backends.qualcomm.utils.utils import (
32    capture_program,
33    from_context_binary,
34    generate_htp_compiler_spec,
35    generate_multi_graph_program,
36    generate_qnn_executorch_compiler_spec,
37    skip_annotation,
38    update_spill_fill_size,
39)
40
41from executorch.examples.models.llama.llama_transformer import ModelArgs, MOEFeedForward
42
43from executorch.examples.qualcomm.utils import setup_common_args_and_variables
44
45from executorch.backends.qualcomm.tests.models import *  # noqa: F403
46
47from executorch.examples.models.deeplab_v3 import DeepLabV3ResNet101Model
48from executorch.examples.models.edsr import EdsrModel
49from executorch.examples.models.inception_v3 import InceptionV3Model
50from executorch.examples.models.inception_v4 import InceptionV4Model
51
52# from executorch.examples.models.llama import Llama2Model
53from executorch.examples.models.mobilebert import MobileBertModelExample
54from executorch.examples.models.mobilenet_v2 import MV2Model
55from executorch.examples.models.mobilenet_v3 import MV3Model
56from executorch.examples.models.torchvision_vit.model import TorchVisionViTModel
57from executorch.examples.models.wav2letter import Wav2LetterModel
58from executorch.exir import to_edge
59from executorch.exir.backend.backend_api import disable_validation
60
61
class TestQNNFloatingPointOperator(TestQNN):
    """Single-operator tests lowered with the fp16 HTP compiler spec.

    Each test builds one small module (from the star-imported test models),
    creates matching sample inputs, and hands both to
    ``lower_module_and_test_output``, which is inherited from ``TestQNN`` and
    is not defined in this class.
    """

    # TODO: refactor to support different backends
    def setUp(self):
        # These are assigned on the TestQNN base class itself (not on self), so
        # the tolerances and compiler specs are shared class-level state.
        TestQNN.atol = 1e-1
        TestQNN.rtol = 1e-1
        backend_options = generate_htp_compiler_spec(use_fp16=True)
        TestQNN.compiler_specs = generate_qnn_executorch_compiler_spec(
            soc_model=self.chipset_table[TestQNN.model],
            backend_options=backend_options,
            debug=False,
            saver=False,
            online_prepare=TestQNN.online_prepare,
            dump_intermediate_outputs=TestQNN.dump_intermediate_outputs,
            profile=TestQNN.enable_profile,
            shared_buffer=TestQNN.shared_buffer,
        )

    def _lower_and_test_combinations(self, test_comb):
        """Run every module / sample-input pairing in ``test_comb`` as a sub-test.

        ``test_comb`` is a list of dicts, each holding a list of modules under
        ``QCOM_MODULE`` and a list of input tuples under ``QCOM_SAMPLE_INPUTS``.
        Sub-tests are numbered consecutively in iteration order.
        """
        index = 0
        for comb in test_comb:
            for module in comb[QCOM_MODULE]:
                for sample_input in comb[QCOM_SAMPLE_INPUTS]:
                    with self.subTest(i=index):
                        self.lower_module_and_test_output(module, sample_input)
                    # Incremented outside the subTest body: a failing sub-test
                    # exits the `with` early, and the counter must still
                    # advance so every sub-test keeps a unique index.
                    index += 1

    def test_qnn_backend_arange(self):
        module = Arange(5)  # noqa: F405
        sample_input = (torch.randn(5),)
        self.lower_module_and_test_output(module, sample_input)

    def test_qnn_backend_avg_pool2d(self):
        module = AvgPoolModule()  # noqa: F405
        sample_input = (torch.randn(1, 3, 2, 2),)
        self.lower_module_and_test_output(module, sample_input)

    def test_qnn_backend_batch_norm(self):
        module = BatchNorm(32)  # noqa: F405
        sample_input = (torch.randn([4, 32, 16, 16]),)
        self.lower_module_and_test_output(module, sample_input)

    def test_qnn_backend_bmm(self):
        module = Bmm()  # noqa: F405
        # Fixed seed so the random operands are reproducible across runs.
        torch.manual_seed(8)
        sample_input = (torch.randn([4, 8, 32]), torch.randn([4, 32, 8]))
        self.lower_module_and_test_output(module, sample_input)

    def test_qnn_backend_cast(self):
        module = Cast()  # noqa: F405
        # torch.rand is uniform on [0, 1), so inputs lie in [0, 10).
        sample_input = (10 * torch.rand((9, 4, 5, 3)),)
        self.lower_module_and_test_output(module, sample_input)

    def test_qnn_backend_cat(self):
        # One sub-test per concatenation variant, all fed the same inputs.
        modules = [Cat2(), Cat3(), Cat4()]  # noqa: F405
        sample_input = (torch.randn(1, 1, 2, 2), torch.randn(1, 1, 4, 2))
        for i, module in enumerate(modules):
            with self.subTest(i=i):
                self.lower_module_and_test_output(module, sample_input)

    def test_qnn_backend_chunk_single(self):
        module = Chunk()  # noqa: F405
        sample_input = (torch.randn(1, 1, 4, 3),)
        self.lower_module_and_test_output(module, sample_input)

    def test_qnn_backend_clamp(self):
        module = Clamp()  # noqa: F405
        sample_input = (torch.randn((9, 4, 5, 3)),)
        self.lower_module_and_test_output(module, sample_input)

    def test_qnn_backend_conv1d(self):
        # Exercised both with the default arguments and with bias=False.
        modules = [Conv1dSequential(), Conv1dSequential(bias=False)]  # noqa: F405
        sample_input = (torch.randn([1, 1, 3]),)
        for i, module in enumerate(modules):
            with self.subTest(i=i):
                self.lower_module_and_test_output(module, sample_input)

    def test_qnn_backend_conv2d(self):
        # Exercised both with the default arguments and with bias=False.
        modules = [Conv2dSequential(), Conv2dSequential(bias=False)]  # noqa: F405
        sample_input = (torch.randn([1, 1, 3, 3]),)
        for i, module in enumerate(modules):
            with self.subTest(i=i):
                self.lower_module_and_test_output(module, sample_input)

    def test_qnn_backend_conv_transpose2d(self):
        # Exercised both with the default arguments and with bias=False.
        modules = [
            ConvTranspose2dSingle(),  # noqa: F405
            ConvTranspose2dSingle(bias=False),  # noqa: F405
        ]
        sample_input = (torch.randn([1, 1, 3, 3]),)
        for i, module in enumerate(modules):
            with self.subTest(i=i):
                self.lower_module_and_test_output(module, sample_input)

    def test_qnn_backend_einsum_outer_product(self):
        module = EinsumOuterProduct()  # noqa: F405
        x = torch.randn(5)
        y = torch.randn(4)
        sample_input = (
            x,
            y,
        )
        self.lower_module_and_test_output(module, sample_input)

    def test_qnn_backend_einsum_bilinear(self):
        module = EinsumBilinear()  # noqa: F405
        bn = torch.randn(2, 5)
        anm = torch.randn(3, 5, 4)
        bm = torch.randn(2, 4)
        sample_input = (
            bn,
            anm,
            bm,
        )
        self.lower_module_and_test_output(module, sample_input)

    def test_qnn_backend_element_wise_add(self):
        test_comb = [
            {
                QCOM_MODULE: [Add()],  # noqa: F405
                # Same-shape operands, then a (4, 1) second operand whose
                # shape differs from the first.
                QCOM_SAMPLE_INPUTS: [
                    (torch.randn(2, 5, 1, 3), torch.randn(2, 5, 1, 3)),
                    (torch.randn([2, 5, 1, 3]), torch.randn([4, 1])),
                ],
            },
            {
                QCOM_MODULE: [AddConstantFloat()],  # noqa: F405
                QCOM_SAMPLE_INPUTS: [(torch.randn(2, 5, 1, 3),)],
            },
        ]
        self._lower_and_test_combinations(test_comb)

    def test_qnn_backend_element_wise_ceil(self):
        module = Ceil()  # noqa: F405
        sample_input = (torch.randn([2, 5, 1, 3]),)
        self.lower_module_and_test_output(module, sample_input)

    def test_qnn_backend_element_wise_div(self):
        # NOTE(review): eps is added to the divisor operands, presumably to
        # nudge them away from exactly zero - confirm the intent.
        eps = 1e-03
        # Fixed seed so the random operands are reproducible across runs.
        torch.manual_seed(8)
        test_comb = [
            {
                QCOM_MODULE: [Div()],  # noqa: F405
                QCOM_SAMPLE_INPUTS: [
                    (torch.randn(2, 5, 1, 3), eps + torch.randn(2, 5, 1, 3)),
                    (torch.randn([2, 5, 1, 3]), eps + torch.randn([4, 1])),
                ],
            },
            {
                QCOM_MODULE: [DivConstantFloat()],  # noqa: F405
                QCOM_SAMPLE_INPUTS: [(torch.randn(2, 5, 1, 3),)],
            },
        ]
        self._lower_and_test_combinations(test_comb)

    def test_qnn_backend_element_wise_mul(self):
        test_comb = [
            {
                QCOM_MODULE: [Mul()],  # noqa: F405
                QCOM_SAMPLE_INPUTS: [
                    (torch.randn(2, 5, 1, 3), torch.randn(2, 5, 1, 3)),
                    (torch.randn([2, 5, 1, 3]), torch.randn([4, 1])),
                ],
            },
            {
                QCOM_MODULE: [MulConstantFloat()],  # noqa: F405
                QCOM_SAMPLE_INPUTS: [(torch.randn(2, 5, 1, 3),)],
            },
            {
                QCOM_MODULE: [MulScalar()],  # noqa: F405
                QCOM_SAMPLE_INPUTS: [(torch.randn(2, 5, 1, 3),)],
            },
        ]
        self._lower_and_test_combinations(test_comb)

    def test_qnn_backend_element_wise_sqrt(self):
        modules = [Sqrt(), SqrtConstant()]  # noqa: F405
        for i, module in enumerate(modules):
            # torch.rand yields values in [0, 1), so the input is never
            # negative. A fresh input is drawn for every module.
            sample_input = (torch.rand([3, 1]),)
            with self.subTest(i=i):
                self.lower_module_and_test_output(module, sample_input)

    def test_qnn_backend_element_wise_sub(self):
        test_comb = [
            {
                QCOM_MODULE: [Sub()],  # noqa: F405
                QCOM_SAMPLE_INPUTS: [
                    (torch.randn(2, 5, 1, 3), torch.randn(2, 5, 1, 3)),
                    (torch.randn([2, 5, 1, 3]), torch.randn([4, 1])),
                ],
            },
            {
                QCOM_MODULE: [SubConstantFloat()],  # noqa: F405
                QCOM_SAMPLE_INPUTS: [(torch.randn(2, 5, 1, 3),)],
            },
        ]
        self._lower_and_test_combinations(test_comb)

    def test_qnn_backend_embedding(self):
        module = Embedding()  # noqa: F405
        # The index tensor is built as float and then cast to int32.
        sample_input = (torch.Tensor([[1, 2, 4, 5], [4, 3, 2, 9]]).to(torch.int32),)
        self.lower_module_and_test_output(module, sample_input)

    def test_qnn_backend_expand_copy(self):
        module = ExpandCopy()  # noqa: F405
        sample_input = (torch.randn([3, 1]),)
        self.lower_module_and_test_output(module, sample_input)

    def test_qnn_backend_gelu(self):
        module = Gelu()  # noqa: F405
        sample_input = (torch.randn(2, 5, 1, 3),)
        self.lower_module_and_test_output(module, sample_input)

    def test_qnn_backend_group_norm(self):
        # Exercised both with the default arguments and with bias=False.
        modules = [GroupNorm(), GroupNorm(bias=False)]  # noqa: F405
        sample_input = (torch.randn(3, 32, 56, 56),)
        for i, module in enumerate(modules):
            with self.subTest(i=i):
                self.lower_module_and_test_output(module, sample_input)

    def test_qnn_backend_hardsigmoid(self):
        module = HardSigmoid()  # noqa: F405
        sample_input = (torch.randn(2, 5, 1, 3),)
        self.lower_module_and_test_output(module, sample_input)

    def test_qnn_backend_hardswish(self):
        module = HardSwish()  # noqa: F405
        sample_input = (torch.randn(2, 5, 1, 3),)
        self.lower_module_and_test_output(module, sample_input)

    def test_qnn_backend_hardtanh(self):
        module = HardTanh()  # noqa: F405
        sample_input = (torch.randn([2, 5, 1, 3]),)
        self.lower_module_and_test_output(module, sample_input)

    def test_qnn_backend_index(self):
        module = Index()  # noqa: F405
        sample_input = (torch.randn([8, 172, 64]),)
        self.lower_module_and_test_output(module, sample_input)

    def test_qnn_backend_index_put(self):
        module = IndexPut()  # noqa: F405
        # First element is an int32 index tensor, second the float values.
        sample_input = (
            torch.tensor([2], dtype=torch.int32),
            torch.randn([1, 1, 12, 64]),
        )
        self.lower_module_and_test_output(module, sample_input)

    def test_qnn_backend_interpolate_bilinear_2d(self):
        module = ResizeBilinear2D()  # noqa: F405
        sample_input = (torch.randn(2, 3, 4, 5),)
        self.lower_module_and_test_output(module, sample_input)

    def test_qnn_backend_interpolate_nearest_2d(self):
        module = ResizeNearest2D()  # noqa: F405
        sample_input = (torch.randn(2, 3, 4, 5),)
        self.lower_module_and_test_output(module, sample_input)

    def test_qnn_backend_layer_norm(self):
        module = LayerNorm()  # noqa: F405
        sample_input = (torch.randn(196, 768),)
        self.lower_module_and_test_output(module, sample_input)

    def test_qnn_backend_leaky_relu(self):
        test_comb = [
            {
                QCOM_MODULE: [LeakyReLUDefault()],  # noqa: F405
                QCOM_SAMPLE_INPUTS: [(torch.randn(2, 5, 1, 3),)],
            },
            {
                QCOM_MODULE: [LeakyReLUCustom(0.05)],  # noqa: F405
                QCOM_SAMPLE_INPUTS: [(torch.randn(2, 5, 1, 3),)],
            },
        ]
        self._lower_and_test_combinations(test_comb)

    def test_qnn_backend_linear(self):
        module = Linear()  # noqa: F405
        sample_input = (torch.randn([3, 4]),)
        self.lower_module_and_test_output(module, sample_input)

    def test_qnn_backend_log_softmax(self):
        module = LogSoftmax()  # noqa: F405
        sample_input = (torch.randn([1, 4, 8, 8]),)
        self.lower_module_and_test_output(module, sample_input)

    def test_qnn_backend_max_pool2d(self):
        module = MaxPool2d()  # noqa: F405
        sample_input = (torch.randn(4, 3, 24, 24),)
        self.lower_module_and_test_output(module, sample_input)

    def test_qnn_backend_mean_dim(self):
        modules = [MeanWKeppDim(), MeanWOKeppDim()]  # noqa: F405
        sample_input = (torch.randn([2, 5, 1, 3]),)
        for i, module in enumerate(modules):
            with self.subTest(i=i):
                self.lower_module_and_test_output(module, sample_input)

    @unittest.skip("failed to lower in QNN 2.26")
    def test_qnn_backend_mha(self):
        module = MultiheadAttention()  # noqa: F405
        sample_input = (torch.randn(1, 197, 96),)
        self.lower_module_and_test_output(module, sample_input)

    def test_qnn_backend_pad(self):
        module = Pad()  # noqa: F405
        sample_input = (torch.randn([1, 8, 128]),)
        self.lower_module_and_test_output(module, sample_input)

    def test_qnn_backend_pixel_shuffle(self):
        module = PixelShuffle(2)  # noqa: F405
        sample_input = (torch.ones([2, 4, 3, 3]),)
        self.lower_module_and_test_output(module, sample_input)

    def test_qnn_backend_pixel_unshuffle(self):
        module = PixelUnshuffle(2)  # noqa: F405
        sample_input = (torch.ones([2, 2, 6, 6]),)
        self.lower_module_and_test_output(module, sample_input)

    def test_qnn_backend_pow_tensor_scalar(self):
        module = PowTensorScalar()  # noqa: F405
        sample_input = (torch.rand([2, 4, 3, 3]),)
        self.lower_module_and_test_output(module, sample_input)

    def test_qnn_backend_prelu(self):
        test_comb = [
            {
                QCOM_MODULE: [PReLUDefault()],  # noqa: F405
                QCOM_SAMPLE_INPUTS: [(torch.randn(2, 5, 1, 3),)],
            },
            {
                QCOM_MODULE: [PReLUPerChannel(5)],  # noqa: F405
                QCOM_SAMPLE_INPUTS: [(torch.randn(2, 5, 1, 3),)],
            },
        ]
        self._lower_and_test_combinations(test_comb)

    def test_qnn_backend_relu(self):
        module = Relu()  # noqa: F405
        sample_input = (torch.randn([2, 5, 1, 3]),)
        self.lower_module_and_test_output(module, sample_input)

    def test_qnn_backend_reshape(self):
        module = Reshape()  # noqa: F405
        sample_input = (torch.randn([3, 4]),)
        self.lower_module_and_test_output(module, sample_input)

    def test_qnn_backend_rms_norm(self):
        module = RmsNorm()  # noqa: F405
        # torch.abs makes every input element non-negative.
        sample_input = (torch.abs(torch.randn([1, 1, 1, 4])),)
        self.lower_module_and_test_output(module, sample_input)

    def test_qnn_backend_rsqrt(self):
        module = Rsqrt()  # noqa: F405
        # torch.abs makes every input element non-negative.
        sample_input = (torch.abs(torch.randn([3, 4])),)
        self.lower_module_and_test_output(module, sample_input)

    def test_qnn_backend_sdpa(self):
        module = ScaledDotProductAttention()  # noqa: F405
        # Keep the lower triangle of a random matrix, then turn every zero
        # (i.e. the entries tril cleared above the diagonal) into -inf.
        mask = torch.tril(torch.randn(1, 1, 100, 100))
        mask[mask == 0] = float("-inf")
        sample_input = (
            torch.randn(1, 4, 100, 64),
            torch.randn(1, 4, 100, 64),
            torch.randn(1, 4, 100, 64),
            mask,
        )
        self.lower_module_and_test_output(module, sample_input)

    def test_qnn_backend_sigmoid(self):
        module = Sigmoid()  # noqa: F405
        sample_input = (torch.randn([1, 3, 3, 3]),)
        self.lower_module_and_test_output(module, sample_input)

    def test_qnn_backend_select_copy(self):
        module = SelectCopy()  # noqa: F405
        sample_input = (torch.randn([1, 3, 3, 3]),)
        self.lower_module_and_test_output(module, sample_input)

    def test_qnn_backend_slice_copy(self):
        modules = [SliceCopy(), SliceCopyWithStep()]  # noqa: F405
        sample_input = (
            torch.randn([1, 512]),
            torch.randn([1, 8]),
        )
        # Each variant runs as its own sub-test, matching the other
        # multi-module tests, so one failure does not hide the other variant.
        for i, module in enumerate(modules):
            with self.subTest(i=i):
                self.lower_module_and_test_output(module, sample_input)

    def test_qnn_backend_stack(self):
        module = Stack()  # noqa: F405
        sample_input = (torch.randn([1, 2, 3, 4]), torch.randn([1, 2, 3, 4]))
        self.lower_module_and_test_output(module, sample_input)

    def test_qnn_backend_softmax(self):
        module = Softmax()  # noqa: F405
        sample_input = (torch.randn([1, 4, 8, 8]),)
        self.lower_module_and_test_output(module, sample_input)

    def test_qnn_backend_squeeze(self):
        module = Squeeze()  # noqa: F405
        sample_input = (torch.randn([1, 3, 3]),)
        self.lower_module_and_test_output(module, sample_input)

    def test_qnn_backend_sum_int_list(self):
        module = SumIntList()  # noqa: F405
        sample_input = (torch.randn([1, 4, 8, 8]),)
        self.lower_module_and_test_output(module, sample_input)

    def test_qnn_backend_tanh(self):
        module = Tanh()  # noqa: F405
        sample_input = (torch.randn(2, 5, 1, 3),)
        self.lower_module_and_test_output(module, sample_input)

    def test_qnn_backend_unbind(self):
        module = Unbind()  # noqa: F405
        sample_input = (torch.randn([3, 3]),)
        self.lower_module_and_test_output(module, sample_input)

    def test_qnn_backend_unsqueeze(self):
        module = Unsqueeze()  # noqa: F405
        sample_input = (torch.randn([1, 3, 3]),)
        self.lower_module_and_test_output(module, sample_input)

    def test_qnn_backend_view(self):
        module = View()  # noqa: F405
        sample_input = (torch.randn([1, 8, 512]), torch.randn([1, 2, 8, 256]))
        self.lower_module_and_test_output(module, sample_input)
521
522
class TestQNNFloatingPointModel(TestQNN):
    """Multi-operator and example-model tests lowered with the fp16 HTP spec.

    Each test builds a module, creates sample inputs, and hands both to
    ``lower_module_and_test_output``, which is inherited from ``TestQNN`` and
    is not defined in this class.
    """

    # TODO: refactor to support different backends
    def setUp(self):
        # These are assigned on the TestQNN base class itself (not on self), so
        # the tolerances and compiler specs are shared class-level state.
        TestQNN.atol = 1e-1
        TestQNN.rtol = 1e-1
        backend_options = generate_htp_compiler_spec(use_fp16=True)
        TestQNN.compiler_specs = generate_qnn_executorch_compiler_spec(
            soc_model=self.chipset_table[TestQNN.model],
            backend_options=backend_options,
            debug=False,
            saver=False,
            online_prepare=TestQNN.online_prepare,
            dump_intermediate_outputs=TestQNN.dump_intermediate_outputs,
            profile=TestQNN.enable_profile,
            shared_buffer=TestQNN.shared_buffer,
        )

    def test_qnn_backend_chunk_add(self):
        module = ChunkAdd()  # noqa: F405
        # Fixed seed so the random input is reproducible across runs.
        torch.manual_seed(8)
        sample_input = (torch.randn(1, 2, 4, 2),)
        self.lower_module_and_test_output(module, sample_input)

    def test_qnn_backend_conv1d_relu_log_softmax(self):
        module = Conv1dReluLogSoftmax()  # noqa: F405
        sample_input = (torch.rand(1, 2, 28),)
        self.lower_module_and_test_output(module, sample_input)

    def test_qnn_backend_conv2d_avg_pool2d(self):
        module = Conv2dAvgPool2d()  # noqa: F405
        sample_input = (torch.randn(16, 3, 16, 16),)
        self.lower_module_and_test_output(module, sample_input)

    def test_qnn_backend_conv2d_bn_hardtanh_mean(self):
        module = Conv2dBnHardtanhMean()  # noqa: F405
        sample_input = (torch.randn(1, 1, 6, 6),)
        self.lower_module_and_test_output(module, sample_input)

    def test_qnn_backend_conv2d_cat(self):
        module = Conv2dCat()  # noqa: F405
        sample_input = (torch.randn(1, 3, 5, 5), torch.randn(1, 3, 5, 5))
        self.lower_module_and_test_output(module, sample_input)

    def test_qnn_backend_conv2d_down_up_sample(self):
        module = Conv2dDownUpSample()  # noqa: F405
        sample_input = (torch.randn(1, 16, 224, 224),)
        self.lower_module_and_test_output(module, sample_input)

    def test_qnn_backend_conv2d_max_pool2d(self):
        module = Conv2dMaxPool2d()  # noqa: F405
        sample_input = (torch.rand(1, 2, 14, 14),)
        self.lower_module_and_test_output(module, sample_input)

    def test_qnn_backend_conv2d_sum_reduce_dim(self):
        module = Conv2dSumReduceDim()  # noqa: F405
        sample_input = (torch.randn([1, 1, 3, 3]),)
        self.lower_module_and_test_output(module, sample_input)

    def test_qnn_backend_conv2d_topk(self):
        module = Conv2dTopK()  # noqa: F405
        sample_input = (torch.randn(1, 3, 32, 32),)
        self.lower_module_and_test_output(module, sample_input)

    def test_qnn_backend_einsum_outer_product_relu(self):
        module = EinsumOuterProductRelu()  # noqa: F405
        x = torch.randn(5)
        y = torch.randn(4)
        sample_input = (
            x,
            y,
        )
        self.lower_module_and_test_output(module, sample_input)

    @unittest.skip("Fail because of bad accuracy")
    def test_qnn_backend_moe_feed_forward(self):
        # Shrink the default ModelArgs before building the module.
        args = ModelArgs()
        args.dim = 32
        args.n_heads = 8
        args.n_layers = 2
        self.head_dim = args.dim // args.n_heads
        # MOEFeedForward is imported explicitly, so no F405 suppression needed.
        module = MOEFeedForward(args)
        # Integer-valued samples in [0, 100) stored as float32.
        sample_input = (
            torch.randint(low=0, high=100, size=(1, 32), dtype=torch.float32),
        )
        self.lower_module_and_test_output(module, sample_input)

    def test_qnn_backend_pixel_unshuffle_math_equivalent(self):
        module = PixelUnshuffleMathEquivalent(2)  # noqa: F405
        sample_input = (torch.rand(2, 2, 6, 6),)
        self.lower_module_and_test_output(module, sample_input)

    def test_qnn_backend_residual_block(self):
        module = ResidualBlockModule()  # noqa: F405
        sample_input = (torch.randn(1, 32, 28, 28),)
        self.lower_module_and_test_output(module, sample_input)

    def test_qnn_backend_simple_model(self):
        module = SimpleModel()  # noqa: F405
        sample_input = (torch.ones(1, 32, 28, 28), torch.ones(1, 32, 28, 28))
        self.lower_module_and_test_output(module, sample_input)

    def test_qnn_backend_topk_and_index(self):
        module = TopKandIndex()  # noqa: F405
        sample_input = (torch.randn(3, 10),)
        self.lower_module_and_test_output(module, sample_input)

    def test_qnn_backend_view_permute_matmul(self):
        module = ViewPermuteMatMul()  # noqa: F405
        # Fixed seed so the random inputs are reproducible across runs.
        torch.manual_seed(8)
        sample_input = (torch.randn([1, 8, 512]), torch.randn([1, 2, 8, 256]))
        self.lower_module_and_test_output(module, sample_input)

    def test_qnn_backend_example_models(self):
        # Each model is paired with its expected partition count in one list,
        # so the two can no longer drift apart when a model is added, removed
        # or commented out.
        cases = [
            (DeepLabV3ResNet101Model(), 1),
            (EdsrModel(), 1),
            (InceptionV3Model(), 1),
            (InceptionV4Model(), 1),
            # The module of llama is changing frequently. Reopen it when it's stable
            # (Llama2Model(), 1),
            (MV2Model(), 1),
            (MV3Model(), 1),
            (MobileBertModelExample(), 1),
            (TorchVisionViTModel(), 1),
            (Wav2LetterModel(), 1),
        ]
        # TODO: Validation is disabled because it currently triggers
        # "maximum recursion depth exceeded"; the cause still needs checking.
        disable_validation()
        for i, (instance, expected_partition) in enumerate(cases):
            with self.subTest(i=i):
                module = instance.get_eager_model().eval()
                sample_input = instance.get_example_inputs()
                # Only the partition count is checked here; output comparison
                # is switched off via assert_output_equal=False.
                self.lower_module_and_test_output(
                    module,
                    sample_input,
                    expected_partitions=expected_partition,
                    assert_output_equal=False,
                )
673
674
675class TestQNNQuantizedOperator(TestQNN):
676    # TODO: refactor to support different backends
677    def setUp(self):
678        TestQNN.atol = 1e-1
679        TestQNN.rtol = 1
680        backend_options = generate_htp_compiler_spec(use_fp16=False)
681        TestQNN.compiler_specs = generate_qnn_executorch_compiler_spec(
682            soc_model=self.chipset_table[TestQNN.model],
683            backend_options=backend_options,
684            debug=False,
685            saver=False,
686            online_prepare=TestQNN.online_prepare,
687            dump_intermediate_outputs=TestQNN.dump_intermediate_outputs,
688            profile=TestQNN.enable_profile,
689            shared_buffer=TestQNN.shared_buffer,
690        )
691
692    def test_qnn_backend_16a4w_conv2d(self):
693        modules = [Conv2dSingle(), Conv2dSingle(bias=False)]  # noqa: F405
694        sample_input = (torch.randn([1, 1, 3, 3]),)
695        for i, module in enumerate(modules):
696            with self.subTest(i=i):
697                module = self.get_qdq_module(
698                    module, sample_input, quant_dtype=QuantDtype.use_16a4w
699                )
700                self.lower_module_and_test_output(module, sample_input)
701
702    def test_qnn_backend_16a4w_conv2d_qat(self):
703        modules = [Conv2dSingle(), Conv2dSingle(bias=False)]  # noqa: F405
704        sample_input = (torch.randn([1, 1, 3, 3]),)
705        for i, module in enumerate(modules):
706            with self.subTest(i=i):
707                prepared = self.get_prepared_qat_module(module, sample_input)
708                converted = self.get_converted_sgd_trained_module(
709                    module, prepared, sample_input
710                )
711                self.lower_module_and_test_output(converted, sample_input)
712
713    def test_qnn_backend_16a4w_layer_norm(self):
714        module = LayerNorm()  # noqa: F405
715        sample_input = (torch.randn(196, 768),)
716        module = self.get_qdq_module(
717            module,
718            sample_input,
719            quant_dtype=QuantDtype.use_16a4w,
720        )
721        self.lower_module_and_test_output(module, sample_input)
722
723    def test_qnn_backend_16a4w_linear(self):
724        module = Linear()  # noqa: F405
725        sample_input = (torch.randn([3, 4]),)
726        module = self.get_qdq_module(
727            module,
728            sample_input,
729            quant_dtype=QuantDtype.use_16a4w,
730        )
731        self.lower_module_and_test_output(module, sample_input)
732
733    @unittest.skip("segfault happens in QNN 2.26")
734    def test_qnn_backend_16a4w_per_channel_linear(self):
735        module = Linear(use_bias=False)  # noqa: F405
736        sample_input = (torch.randn([3, 4]),)
737        module = self.get_qdq_module(
738            module,
739            sample_input,
740            is_linear_per_channel=True,
741            quant_dtype=QuantDtype.use_16a4w,
742        )
743        self.lower_module_and_test_output(module, sample_input)
744
745    def test_qnn_backend_16a4w_per_channel_linear_with_bias(self):
746        module = Linear()  # noqa: F405
747        sample_input = (torch.randn([3, 4]),)
748        module = self.get_qdq_module(
749            module,
750            sample_input,
751            is_linear_per_channel=True,
752            quant_dtype=QuantDtype.use_16a4w,
753        )
754        self.lower_module_and_test_output(module, sample_input)
755
756    def test_qnn_backend_arange(self):
757        module = Arange(5)  # noqa: F405
758        sample_input = (torch.randn(5),)
759        module = self.get_qdq_module(module, sample_input)
760        self.lower_module_and_test_output(module, sample_input)
761
762    def test_qnn_backend_avg_pool2d(self):
763        module = AvgPoolModule()  # noqa: F405
764        sample_input = (torch.randn(1, 3, 2, 2),)
765        module = self.get_qdq_module(module, sample_input)
766        self.lower_module_and_test_output(module, sample_input)
767
768    def test_qnn_backend_batch_norm(self):
769        module = BatchNorm(32)  # noqa: F405
770        sample_input = (torch.randn([4, 32, 16, 16]),)
771        module = self.get_qdq_module(module, sample_input)
772        self.lower_module_and_test_output(module, sample_input)
773
774    def test_qnn_backend_bmm(self):
775        module = Bmm()  # noqa: F405
776        torch.manual_seed(8)
777        sample_input = (torch.randn([4, 8, 32]), torch.randn([4, 32, 8]))
778        module = self.get_qdq_module(module, sample_input)
779        self.lower_module_and_test_output(module, sample_input)
780
781    def test_qnn_backend_cat(self):
782        modules = [Cat2(), Cat3(), Cat4()]  # noqa: F405
783        sample_input = (torch.randn(1, 1, 2, 2), torch.randn(1, 1, 4, 2))
784        for i, module in enumerate(modules):
785            with self.subTest(i=i):
786                module = self.get_qdq_module(module, sample_input)
787                self.lower_module_and_test_output(module, sample_input)
788
789    def test_qnn_backend_chunk_single(self):
790        module = Chunk()  # noqa: F405
791        sample_input = (torch.randn(1, 1, 4, 3),)
792        module = self.get_qdq_module(module, sample_input)
793        self.lower_module_and_test_output(module, sample_input)
794
795    def test_qnn_backend_clamp(self):
796        module = Clamp()  # noqa: F405
797        sample_input = (torch.randn((9, 4, 5, 3)),)
798        module = self.get_qdq_module(module, sample_input)
799        self.lower_module_and_test_output(module, sample_input)
800
801    def test_qnn_backend_conv1d(self):
802        modules = [Conv1dSequential(), Conv1dSequential(bias=False)]  # noqa: F405
803        sample_input = (torch.randn([1, 1, 3]),)
804        for i, module in enumerate(modules):
805            with self.subTest(i=i):
806                module = self.get_qdq_module(module, sample_input)
807                self.lower_module_and_test_output(module, sample_input)
808
809    def test_qnn_backend_conv2d(self):
810        modules = [Conv2dSequential(), Conv2dSequential(bias=False)]  # noqa: F405
811        sample_input = (torch.randn([1, 1, 3, 3]),)
812        for i, module in enumerate(modules):
813            with self.subTest(i=i):
814                module = self.get_qdq_module(module, sample_input)
815                self.lower_module_and_test_output(module, sample_input)
816
817    def test_qnn_backend_conv_transpose2d(self):
818        modules = [
819            ConvTranspose2dSingle(),  # noqa: F405
820            ConvTranspose2dSingle(bias=False),  # noqa: F405
821        ]  # noqa: F405
822        sample_input = (torch.randn([1, 1, 3, 3]),)
823        for i, module in enumerate(modules):
824            with self.subTest(i=i):
825                module = self.get_qdq_module(module, sample_input)
826                self.lower_module_and_test_output(module, sample_input)
827
828    def test_qnn_backend_einsum_outer_product(self):
829        module = EinsumOuterProduct()  # noqa: F405
830        x = torch.randn(5)
831        y = torch.randn(4)
832        sample_input = (
833            x,
834            y,
835        )
836        module = self.get_qdq_module(module, sample_input)
837        self.lower_module_and_test_output(module, sample_input)
838
839    def test_qnn_backend_einsum_bilinear(self):
840        module = EinsumBilinear()  # noqa: F405
841        bn = torch.randn(2, 5)
842        anm = torch.randn(3, 5, 4)
843        bm = torch.randn(2, 4)
844        sample_input = (
845            bn,
846            anm,
847            bm,
848        )
849        module = self.get_qdq_module(module, sample_input)
850        self.lower_module_and_test_output(module, sample_input)
851
852    def test_qnn_backend_element_wise_add(self):
853        test_comb = [
854            {
855                QCOM_MODULE: [Add()],  # noqa: F405
856                QCOM_SAMPLE_INPUTS: [
857                    (torch.randn(2, 5, 1, 3), torch.randn(2, 5, 1, 3)),
858                    (torch.randn([2, 5, 1, 3]), torch.randn([4, 1])),
859                ],
860            },
861            {
862                QCOM_MODULE: [AddConstantFloat(), AddConstantLong()],  # noqa: F405
863                QCOM_SAMPLE_INPUTS: [(torch.randn(2, 5, 1, 3),)],
864            },
865        ]
866
867        index = 0
868        for comb in test_comb:
869            for module in comb[QCOM_MODULE]:
870                for sample_input in comb[QCOM_SAMPLE_INPUTS]:
871                    with self.subTest(i=index):
872                        module = self.get_qdq_module(module, sample_input)
873                        self.lower_module_and_test_output(module, sample_input)
874                        index += 1
875
876    def test_qnn_backend_element_wise_ceil(self):
877        module = Ceil()  # noqa: F405
878        sample_input = (torch.randn([2, 5, 1, 3]),)
879        module = self.get_qdq_module(module, sample_input)
880        self.lower_module_and_test_output(module, sample_input)
881
882    def test_qnn_backend_element_wise_div(self):
883        eps = 1e-03
884        torch.manual_seed(8)
885        test_comb = [
886            {
887                QCOM_MODULE: [Div()],  # noqa: F405
888                QCOM_SAMPLE_INPUTS: [
889                    (torch.randn(2, 5, 1, 3), eps + torch.randn(2, 5, 1, 3)),
890                    (torch.randn([2, 5, 1, 3]), eps + torch.randn([4, 1])),
891                ],
892            },
893            {
894                QCOM_MODULE: [DivConstantFloat(), DivConstantLong()],  # noqa: F405
895                QCOM_SAMPLE_INPUTS: [(torch.randn(2, 5, 1, 3),)],
896            },
897        ]
898
899        index = 0
900        for comb in test_comb:
901            for module in comb[QCOM_MODULE]:
902                for sample_input in comb[QCOM_SAMPLE_INPUTS]:
903                    with self.subTest(i=index):
904                        module = self.get_qdq_module(module, sample_input)
905                        self.lower_module_and_test_output(module, sample_input)
906                        index += 1
907
908    def test_qnn_backend_element_wise_mul(self):
909        test_comb = [
910            {
911                QCOM_MODULE: [Mul()],  # noqa: F405
912                QCOM_SAMPLE_INPUTS: [
913                    (torch.randn(2, 5, 1, 3), torch.randn(2, 5, 1, 3)),
914                    (torch.randn([2, 5, 1, 3]), torch.randn([4, 1])),
915                ],
916            },
917            {
918                QCOM_MODULE: [MulConstantFloat(), MulConstantLong()],  # noqa: F405
919                QCOM_SAMPLE_INPUTS: [(torch.randn(2, 5, 1, 3),)],
920            },
921            {
922                QCOM_MODULE: [MulScalar()],  # noqa: F405
923                QCOM_SAMPLE_INPUTS: [(torch.randn(2, 5, 1, 3),)],
924            },
925        ]
926
927        index = 0
928        for comb in test_comb:
929            for module in comb[QCOM_MODULE]:
930                for sample_input in comb[QCOM_SAMPLE_INPUTS]:
931                    with self.subTest(i=index):
932                        module = self.get_qdq_module(module, sample_input)
933                        self.lower_module_and_test_output(module, sample_input)
934                        index += 1
935
936    def test_qnn_backend_element_wise_sqrt(self):
937        modules = [Sqrt(), SqrtConstant()]  # noqa: F405
938        for i, module in enumerate(modules):
939            sample_input = (torch.rand([3, 1]),)
940            with self.subTest(i=i):
941                module = self.get_qdq_module(module, sample_input)
942                self.lower_module_and_test_output(module, sample_input)
943
944    def test_qnn_backend_element_wise_sub(self):
945        test_comb = [
946            {
947                QCOM_MODULE: [Sub()],  # noqa: F405
948                QCOM_SAMPLE_INPUTS: [
949                    (torch.randn(2, 5, 1, 3), torch.randn(2, 5, 1, 3)),
950                    (torch.randn([2, 5, 1, 3]), torch.randn([4, 1])),
951                ],
952            },
953            {
954                QCOM_MODULE: [SubConstantFloat(), SubConstantLong()],  # noqa: F405
955                QCOM_SAMPLE_INPUTS: [(torch.randn(2, 5, 1, 3),)],
956            },
957        ]
958
959        index = 0
960        for comb in test_comb:
961            for module in comb[QCOM_MODULE]:
962                for sample_input in comb[QCOM_SAMPLE_INPUTS]:
963                    with self.subTest(i=index):
964                        module = self.get_qdq_module(module, sample_input)
965                        self.lower_module_and_test_output(module, sample_input)
966                        index += 1
967
968    def test_qnn_backend_embedding(self):
969        module = Embedding()  # noqa: F405
970        sample_input = (torch.Tensor([[1, 2, 4, 5], [4, 3, 2, 9]]).to(torch.int32),)
971        module = self.get_qdq_module(module, sample_input)
972        self.lower_module_and_test_output(module, sample_input)
973
974    def test_qnn_backend_expand_copy(self):
975        module = ExpandCopy()  # noqa: F405
976        sample_input = (torch.randn([3, 1]),)
977        module = self.get_qdq_module(module, sample_input)
978        self.lower_module_and_test_output(module, sample_input)
979
980    def test_qnn_backend_gelu(self):
981        module = Gelu()  # noqa: F405
982        sample_input = (torch.randn(2, 5, 1, 3),)
983        module = self.get_qdq_module(module, sample_input)
984        self.lower_module_and_test_output(module, sample_input)
985
986    def test_qnn_backend_group_norm(self):
987        modules = [GroupNorm(), GroupNorm(bias=False)]  # noqa: F405
988        sample_input = (torch.randn(3, 32, 56, 56),)
989        for i, module in enumerate(modules):
990            with self.subTest(i=i):
991                module = self.get_qdq_module(module, sample_input)
992                self.lower_module_and_test_output(module, sample_input)
993
994    def test_qnn_backend_hardsigmoid(self):
995        module = HardSigmoid()  # noqa: F405
996        sample_input = (torch.randn(2, 5, 1, 3),)
997        module = self.get_qdq_module(module, sample_input)
998        self.lower_module_and_test_output(module, sample_input)
999
1000    def test_qnn_backend_hardswish(self):
1001        module = HardSwish()  # noqa: F405
1002        sample_input = (torch.randn(2, 5, 1, 3),)
1003        module = self.get_qdq_module(module, sample_input)
1004        self.lower_module_and_test_output(module, sample_input)
1005
1006    def test_qnn_backend_hardtanh(self):
1007        module = HardTanh()  # noqa: F405
1008        sample_input = (torch.randn([2, 5, 1, 3]),)
1009        module = self.get_qdq_module(module, sample_input)
1010        self.lower_module_and_test_output(module, sample_input)
1011
1012    def test_qnn_backend_index(self):
1013        module = Index()  # noqa: F405
1014        sample_input = (torch.randn([8, 172, 64]),)
1015        module = self.get_qdq_module(module, sample_input)
1016        self.lower_module_and_test_output(module, sample_input)
1017
1018    def test_qnn_backend_index_put(self):
1019        module = IndexPut()  # noqa: F405
1020        sample_input = (
1021            torch.tensor([2], dtype=torch.int32),
1022            torch.randn([1, 1, 12, 64]),
1023        )
1024        module = self.get_qdq_module(module, sample_input)
1025        self.lower_module_and_test_output(module, sample_input)
1026
1027    def test_qnn_backend_interpolate_bilinear_2d(self):
1028        module = ResizeBilinear2D()  # noqa: F405
1029        sample_input = (torch.randn(2, 3, 4, 5),)
1030        module = self.get_qdq_module(module, sample_input)
1031        self.lower_module_and_test_output(module, sample_input)
1032
1033    def test_qnn_backend_interpolate_nearest_2d(self):
1034        module = ResizeNearest2D()  # noqa: F405
1035        sample_input = (torch.randn(2, 3, 4, 5),)
1036        module = self.get_qdq_module(module, sample_input)
1037        self.lower_module_and_test_output(module, sample_input)
1038
1039    def test_qnn_backend_layer_norm(self):
1040        module = LayerNorm()  # noqa: F405
1041        sample_input = (torch.randn(196, 768),)
1042        module = self.get_qdq_module(module, sample_input)
1043        self.lower_module_and_test_output(module, sample_input)
1044
1045    def test_qnn_backend_leaky_relu(self):
1046        test_comb = [
1047            {
1048                QCOM_MODULE: [LeakyReLUDefault()],  # noqa: F405
1049                QCOM_SAMPLE_INPUTS: [(torch.randn(2, 5, 1, 3),)],
1050            },
1051            {
1052                QCOM_MODULE: [LeakyReLUCustom(0.05)],  # noqa: F405
1053                QCOM_SAMPLE_INPUTS: [(torch.randn(2, 5, 1, 3),)],
1054            },
1055        ]
1056
1057        index = 0
1058        for comb in test_comb:
1059            for module in comb[QCOM_MODULE]:
1060                for sample_input in comb[QCOM_SAMPLE_INPUTS]:
1061                    with self.subTest(i=index):
1062                        module = self.get_qdq_module(module, sample_input)
1063                        self.lower_module_and_test_output(module, sample_input)
1064                        index += 1
1065
1066    def test_qnn_backend_linear(self):
1067        module = Linear()  # noqa: F405
1068        sample_input = (torch.randn([3, 4]),)
1069        module = self.get_qdq_module(module, sample_input)
1070        self.lower_module_and_test_output(module, sample_input)
1071
1072    def test_qnn_backend_linear_qat(self):
1073        """
1074        Prototype to test qat model
1075        """
1076        module = Linear()  # noqa: F405
1077        sample_input = (torch.randn([3, 4]),)
1078        prepared = self.get_prepared_qat_module(module, sample_input)
1079        module = self.get_converted_sgd_trained_module(module, prepared, sample_input)
1080        self.lower_module_and_test_output(module, sample_input)
1081
1082    def test_qnn_backend_log_softmax(self):
1083        module = LogSoftmax()  # noqa: F405
1084        sample_input = (torch.randn([1, 4, 8, 8]),)
1085        module = self.get_qdq_module(module, sample_input)
1086        self.lower_module_and_test_output(module, sample_input)
1087
1088    def test_qnn_backend_max_pool2d(self):
1089        module = MaxPool2d()  # noqa: F405
1090        sample_input = (torch.randn(4, 3, 24, 24),)
1091        module = self.get_qdq_module(module, sample_input)
1092        self.lower_module_and_test_output(module, sample_input)
1093
1094    def test_qnn_backend_mean_dim(self):
1095        modules = [MeanWKeppDim(), MeanWOKeppDim()]  # noqa: F405
1096        sample_input = (torch.randn([2, 5, 1, 3]),)
1097        for i, module in enumerate(modules):
1098            with self.subTest(i=i):
1099                module = self.get_qdq_module(module, sample_input)
1100                self.lower_module_and_test_output(module, sample_input)
1101
1102    def test_qnn_backend_mha(self):
1103        module = MultiheadAttention()  # noqa: F405
1104        sample_input = (torch.randn(1, 197, 96),)
1105        module = self.get_qdq_module(module, sample_input)
1106        self.lower_module_and_test_output(module, sample_input)
1107
1108    def test_qnn_backend_pad(self):
1109        module = Pad()  # noqa: F405
1110        sample_input = (torch.randn([1, 8, 128]),)
1111        module = self.get_qdq_module(module, sample_input)
1112        self.lower_module_and_test_output(module, sample_input)
1113
1114    def test_qnn_backend_pixel_shuffle(self):
1115        module = PixelShuffle(2)  # noqa: F405
1116        sample_input = (torch.ones([2, 4, 3, 3]),)
1117        module = self.get_qdq_module(module, sample_input)
1118        self.lower_module_and_test_output(module, sample_input)
1119
1120    def test_qnn_backend_pixel_unshuffle(self):
1121        module = PixelUnshuffle(2)  # noqa: F405
1122        sample_input = (torch.ones([2, 2, 6, 6]),)
1123        module = self.get_qdq_module(module, sample_input)
1124        self.lower_module_and_test_output(module, sample_input)
1125
1126    def test_qnn_backend_pow_tensor_scalar(self):
1127        module = PowTensorScalar()  # noqa: F405
1128        sample_input = (torch.rand([2, 4, 3, 3]),)
1129        module = self.get_qdq_module(module, sample_input)
1130        self.lower_module_and_test_output(module, sample_input)
1131
1132    def test_qnn_backend_prelu(self):
1133        test_comb = [
1134            {
1135                QCOM_MODULE: [PReLUDefault()],  # noqa: F405
1136                QCOM_SAMPLE_INPUTS: [(torch.randn(2, 5, 1, 3),)],
1137            },
1138            {
1139                QCOM_MODULE: [PReLUPerChannel(5)],  # noqa: F405
1140                QCOM_SAMPLE_INPUTS: [(torch.randn(2, 5, 1, 3),)],
1141            },
1142        ]
1143
1144        index = 0
1145        for comb in test_comb:
1146            for module in comb[QCOM_MODULE]:
1147                for sample_input in comb[QCOM_SAMPLE_INPUTS]:
1148                    with self.subTest(i=index):
1149                        module = self.get_qdq_module(module, sample_input)
1150                        self.lower_module_and_test_output(module, sample_input)
1151                        index += 1
1152
1153    def test_qnn_backend_relu(self):
1154        module = Relu()  # noqa: F405
1155        sample_input = (torch.randn([2, 5, 1, 3]),)
1156        module = self.get_qdq_module(module, sample_input)
1157        self.lower_module_and_test_output(module, sample_input)
1158
1159    def test_qnn_backend_reshape(self):
1160        module = Reshape()  # noqa: F405
1161        sample_input = (torch.randn([3, 4]),)
1162        module = self.get_qdq_module(module, sample_input)
1163        self.lower_module_and_test_output(module, sample_input)
1164
1165    def test_qnn_backend_rms_norm(self):
1166        module = RmsNorm()  # noqa: F405
1167        sample_input = (torch.abs(torch.randn([1, 1, 1, 4])),)
1168        module = self.get_qdq_module(
1169            module, sample_input, quant_dtype=QuantDtype.use_16a4w
1170        )
1171        self.lower_module_and_test_output(module, sample_input)
1172
1173    def test_qnn_backend_rsqrt(self):
1174        module = Rsqrt()  # noqa: F405
1175        sample_input = (torch.abs(torch.randn([3, 4])),)
1176        module = self.get_qdq_module(module, sample_input)
1177        self.lower_module_and_test_output(module, sample_input)
1178
1179    def test_qnn_backend_sdpa(self):
1180        module = ScaledDotProductAttention()  # noqa: F405
1181        mask = torch.tril(torch.randn(1, 1, 100, 100))
1182        mask[mask == 0] = torch.finfo(torch.float32).min
1183        sample_input = (
1184            torch.randn(1, 4, 100, 64),
1185            torch.randn(1, 4, 100, 64),
1186            torch.randn(1, 4, 100, 64),
1187            mask,
1188        )
1189        module = self.get_qdq_module(module, sample_input)
1190        self.lower_module_and_test_output(module, sample_input)
1191
1192    def test_qnn_backend_select_copy(self):
1193        module = SelectCopy()  # noqa: F405
1194        sample_input = (torch.randn([1, 3, 3, 3]),)
1195        module = self.get_qdq_module(module, sample_input)
1196        self.lower_module_and_test_output(module, sample_input)
1197
1198    def test_qnn_backend_sigmoid(self):
1199        module = Sigmoid()  # noqa: F405
1200        sample_input = (torch.randn([1, 3, 3, 3]),)
1201        module = self.get_qdq_module(module, sample_input)
1202        self.lower_module_and_test_output(module, sample_input)
1203
1204    def test_qnn_backend_slice_copy(self):
1205        modules = [SliceCopy(), SliceCopyWithStep()]  # noqa: F405
1206        sample_input = (
1207            torch.randn([1, 512]),
1208            torch.randn([1, 8]),
1209        )
1210        for module in modules:
1211            module = self.get_qdq_module(module, sample_input)
1212            self.lower_module_and_test_output(module, sample_input)
1213
1214    def test_qnn_backend_softmax(self):
1215        module = Softmax()  # noqa: F405
1216        sample_input = (torch.randn([1, 4, 8, 8]),)
1217        module = self.get_qdq_module(module, sample_input)
1218        self.lower_module_and_test_output(module, sample_input)
1219
1220    def test_qnn_backend_squeeze(self):
1221        module = Squeeze()  # noqa: F405
1222        sample_input = (torch.randn([1, 3, 3]),)
1223        module = self.get_qdq_module(module, sample_input)
1224        self.lower_module_and_test_output(module, sample_input)
1225
1226    def test_qnn_backend_stack(self):
1227        module = Stack()  # noqa: F405
1228        sample_input = (
1229            torch.randn([1, 2, 3, 4]),
1230            torch.randn([1, 2, 3, 4]),
1231        )
1232        module = self.get_qdq_module(module, sample_input)
1233        self.lower_module_and_test_output(module, sample_input)
1234
1235    def test_qnn_backend_sum_int_list(self):
1236        module = SumIntList()  # noqa: F405
1237        sample_input = (torch.randn([1, 4, 8, 8]),)
1238        module = self.get_qdq_module(module, sample_input)
1239        self.lower_module_and_test_output(module, sample_input)
1240
1241    def test_qnn_backend_tanh(self):
1242        module = Tanh()  # noqa: F405
1243        sample_input = (torch.randn(2, 5, 1, 3),)
1244        module = self.get_qdq_module(module, sample_input)
1245        self.lower_module_and_test_output(module, sample_input)
1246
1247    def test_qnn_backend_unbind(self):
1248        module = Unbind()  # noqa: F405
1249        sample_input = (torch.randn([3, 3]),)
1250        module = self.get_qdq_module(module, sample_input)
1251        self.lower_module_and_test_output(module, sample_input)
1252
1253    def test_qnn_backend_unsqueeze(self):
1254        module = Unsqueeze()  # noqa: F405
1255        sample_input = (torch.randn([1, 3, 3]),)
1256        module = self.get_qdq_module(module, sample_input)
1257        self.lower_module_and_test_output(module, sample_input)
1258
1259    def test_qnn_backend_view(self):
1260        module = View()  # noqa: F405
1261        sample_input = (torch.randn([1, 8, 512]), torch.randn([1, 2, 8, 256]))
1262        module = self.get_qdq_module(module, sample_input)
1263        self.lower_module_and_test_output(module, sample_input)
1264
1265
class TestQNNQuantizedModel(TestQNN):
    """Multi-op and example-model tests driven through ``get_qdq_module``.

    Every test builds a module and sample inputs, converts the module with
    ``get_qdq_module`` and hands the result to
    ``lower_module_and_test_output``.
    """

    # TODO: refactor to support different backends
    def setUp(self):
        """Set the shared ``TestQNN`` tolerances and compiler specs.

        HTP backend options are generated with ``use_fp16=False``; the
        remaining compiler-spec switches are read from ``TestQNN`` attributes.
        """
        TestQNN.atol = 1e-1
        TestQNN.rtol = 1
        backend_options = generate_htp_compiler_spec(use_fp16=False)
        TestQNN.compiler_specs = generate_qnn_executorch_compiler_spec(
            soc_model=self.chipset_table[TestQNN.model],
            backend_options=backend_options,
            debug=False,
            saver=False,
            online_prepare=TestQNN.online_prepare,
            dump_intermediate_outputs=TestQNN.dump_intermediate_outputs,
            profile=TestQNN.enable_profile,
            shared_buffer=TestQNN.shared_buffer,
        )

    def test_qnn_backend_chunk_add(self):
        """Run ``ChunkAdd`` on an input drawn after seeding the RNG with 8."""
        module = ChunkAdd()  # noqa: F405
        torch.manual_seed(8)
        sample_input = (torch.randn(1, 1, 4, 2),)
        module = self.get_qdq_module(module, sample_input)
        self.lower_module_and_test_output(module, sample_input)

    def test_qnn_backend_conv1d_relu_log_softmax(self):
        """Run ``Conv1dReluLogSoftmax`` on a ``torch.rand`` (1, 2, 28) input."""
        module = Conv1dReluLogSoftmax()  # noqa: F405
        sample_input = (torch.rand(1, 2, 28),)
        module = self.get_qdq_module(module, sample_input)
        self.lower_module_and_test_output(module, sample_input)

    def test_qnn_backend_conv2d_avg_pool2d(self):
        """Run ``Conv2dAvgPool2d`` on a random (16, 3, 16, 16) input."""
        module = Conv2dAvgPool2d()  # noqa: F405
        sample_input = (torch.randn(16, 3, 16, 16),)
        module = self.get_qdq_module(module, sample_input)
        self.lower_module_and_test_output(module, sample_input)

    def test_qnn_backend_conv2d_bn_hardtanh_mean(self):
        """Run ``Conv2dBnHardtanhMean`` on a random (1, 1, 6, 6) input."""
        module = Conv2dBnHardtanhMean()  # noqa: F405
        sample_input = (torch.randn(1, 1, 6, 6),)
        module = self.get_qdq_module(module, sample_input)
        self.lower_module_and_test_output(module, sample_input)

    def test_qnn_backend_conv2d_cat(self):
        """Run ``Conv2dCat`` on two random (1, 3, 5, 5) inputs."""
        module = Conv2dCat()  # noqa: F405
        sample_input = (torch.randn(1, 3, 5, 5), torch.randn(1, 3, 5, 5))
        module = self.get_qdq_module(module, sample_input)
        self.lower_module_and_test_output(module, sample_input)

    def test_qnn_backend_conv2d_down_up_sample(self):
        """Run ``Conv2dDownUpSample`` on a random (1, 16, 224, 224) input."""
        module = Conv2dDownUpSample()  # noqa: F405
        sample_input = (torch.randn(1, 16, 224, 224),)
        module = self.get_qdq_module(module, sample_input)
        self.lower_module_and_test_output(module, sample_input)

    def test_qnn_backend_conv2d_max_pool2d(self):
        """Run ``Conv2dMaxPool2d`` on a ``torch.rand`` (1, 2, 14, 14) input."""
        module = Conv2dMaxPool2d()  # noqa: F405
        sample_input = (torch.rand(1, 2, 14, 14),)
        module = self.get_qdq_module(module, sample_input)
        self.lower_module_and_test_output(module, sample_input)

    def test_qnn_backend_conv2d_sum_reduce_dim(self):
        """Run ``Conv2dSumReduceDim`` on a random (1, 1, 3, 3) input."""
        module = Conv2dSumReduceDim()  # noqa: F405
        sample_input = (torch.randn([1, 1, 3, 3]),)
        module = self.get_qdq_module(module, sample_input)
        self.lower_module_and_test_output(module, sample_input)

    def test_qnn_backend_conv2d_topk(self):
        """Run ``Conv2dTopK`` on a random (1, 3, 32, 32) input."""
        module = Conv2dTopK()  # noqa: F405
        sample_input = (torch.randn(1, 3, 32, 32),)
        module = self.get_qdq_module(module, sample_input)
        self.lower_module_and_test_output(module, sample_input)

    def test_qnn_backend_einsum_outer_product_relu(self):
        """Run ``EinsumOuterProductRelu`` on a length-5 and a length-4 vector."""
        module = EinsumOuterProductRelu()  # noqa: F405
        x = torch.randn(5)
        y = torch.randn(4)
        sample_input = (
            x,
            y,
        )
        module = self.get_qdq_module(module, sample_input)
        self.lower_module_and_test_output(module, sample_input)

    @unittest.skip("UT pass before QNN 2.26, segfault during partitioner")
    def test_qnn_backend_moe_feed_forward(self):
        """Run ``MOEFeedForward`` built from a small ``ModelArgs``.

        Currently skipped; see the reason given to the decorator.
        """
        args = ModelArgs()
        args.dim = 32
        args.n_heads = 8
        args.n_layers = 2
        self.head_dim = args.dim // args.n_heads
        module = MOEFeedForward(args)  # noqa: F405
        sample_input = (
            torch.randint(low=0, high=100, size=(1, 32), dtype=torch.float32),
        )
        module = self.get_qdq_module(module, sample_input)
        self.lower_module_and_test_output(module, sample_input)

    def test_qnn_backend_pixel_unshuffle_math_equivalent(self):
        """Run ``PixelUnshuffleMathEquivalent(2)`` on a (2, 2, 6, 6) input."""
        module = PixelUnshuffleMathEquivalent(2)  # noqa: F405
        sample_input = (torch.rand(2, 2, 6, 6),)
        module = self.get_qdq_module(module, sample_input)
        self.lower_module_and_test_output(module, sample_input)

    def test_qnn_backend_residual_block(self):
        """Run ``ResidualBlockModule`` on a random (1, 32, 28, 28) input."""
        module = ResidualBlockModule()  # noqa: F405
        sample_input = (torch.randn(1, 32, 28, 28),)
        module = self.get_qdq_module(module, sample_input)
        self.lower_module_and_test_output(module, sample_input)

    def test_qnn_backend_simple_model(self):
        """Run ``SimpleModel`` on two all-ones (1, 32, 28, 28) inputs."""
        module = SimpleModel()  # noqa: F405
        sample_input = (torch.ones(1, 32, 28, 28), torch.ones(1, 32, 28, 28))
        module = self.get_qdq_module(module, sample_input)
        self.lower_module_and_test_output(module, sample_input)

    def test_qnn_backend_topk_and_index(self):
        """Run ``TopKandIndex`` on a random (3, 10) input."""
        module = TopKandIndex()  # noqa: F405
        sample_input = (torch.randn(3, 10),)
        module = self.get_qdq_module(module, sample_input)
        self.lower_module_and_test_output(module, sample_input)

    def test_qnn_backend_view_permute_matmul(self):
        """Run ``ViewPermuteMatMul`` on inputs drawn after seeding the RNG with 8."""
        module = ViewPermuteMatMul()  # noqa: F405
        torch.manual_seed(8)
        sample_input = (torch.randn([1, 8, 512]), torch.randn([1, 2, 8, 256]))
        module = self.get_qdq_module(module, sample_input)
        self.lower_module_and_test_output(module, sample_input)

    def test_qnn_backend_example_models(self):
        """Lower each example model and check its partition count.

        For every entry of ``instances`` the eager model and example inputs
        are taken from the model wrapper, converted with ``get_qdq_module``
        and lowered. Outputs are not compared (``assert_output_equal=False``);
        only the expected number of partitions is passed along.
        """
        instances = [
            {
                QCOM_MODULE: DeepLabV3ResNet101Model(),
                QCOM_ANNOTATION: (),
                QCOM_QUANT_DTYPE: QuantDtype.use_8a8w,
            },
            {
                QCOM_MODULE: EdsrModel(),
                QCOM_ANNOTATION: (),
                QCOM_QUANT_DTYPE: QuantDtype.use_8a8w,
            },
            {
                QCOM_MODULE: InceptionV3Model(),
                QCOM_ANNOTATION: (),
                QCOM_QUANT_DTYPE: QuantDtype.use_8a8w,
            },
            {
                QCOM_MODULE: InceptionV4Model(),
                QCOM_ANNOTATION: (),
                QCOM_QUANT_DTYPE: QuantDtype.use_8a8w,
            },
            # The module of llama is changing frequently. Reopen it when it's stable
            # {QCOM_MODULE: Llama2Model(), QCOM_ANNOTATION: (), QCOM_QUANT_DTYPE: QuantDtype.use_8a8w},
            {
                QCOM_MODULE: MV2Model(),
                QCOM_ANNOTATION: (),
                QCOM_QUANT_DTYPE: QuantDtype.use_8a8w,
            },
            {
                QCOM_MODULE: MV3Model(),
                QCOM_ANNOTATION: (),
                QCOM_QUANT_DTYPE: QuantDtype.use_8a8w,
            },
            # only works on QNN 2.12 so far
            # { 'module': MobileBertModelExample(), 'annotation': (), QCOM_QUANT_DTYPE: QuantDtype.use_8a8w },
            {
                QCOM_MODULE: TorchVisionViTModel(),
                QCOM_ANNOTATION: (),
                QCOM_QUANT_DTYPE: QuantDtype.use_8a8w,
            },
            {
                QCOM_MODULE: Wav2LetterModel(),
                QCOM_ANNOTATION: (),
                QCOM_QUANT_DTYPE: QuantDtype.use_8a8w,
            },
        ]
        # One entry per active element of ``instances``, in the same order.
        # Entries for disabled models stay commented out so the two lists
        # cannot drift apart.
        expected_partitions = [
            1,  # DeepLabV3ResNet101Model
            1,  # EdsrModel
            1,  # InceptionV3Model
            1,  # InceptionV4Model
            # For Llama2Model
            # 1,
            1,  # MV2Model
            1,  # MV3Model
            # For MobileBertModelExample
            # 1,
            1,  # TorchVisionViTModel
            1,  # Wav2LetterModel
        ]
        self.assertEqual(len(instances), len(expected_partitions))
        # TODO: validation is disabled because it triggers "maximum recursion
        # depth exceeded"; this still needs to be investigated.
        disable_validation()
        for i, instance in enumerate(instances):
            with self.subTest(i=i):
                module = instance[QCOM_MODULE].get_eager_model().eval()
                sample_input = instance[QCOM_MODULE].get_example_inputs()
                module = self.get_qdq_module(
                    module,
                    sample_input,
                    custom_quant_annotations=instance[QCOM_ANNOTATION],
                    quant_dtype=instance[QCOM_QUANT_DTYPE],
                )
                self.lower_module_and_test_output(
                    module,
                    sample_input,
                    expected_partitions=expected_partitions[i],
                    assert_output_equal=False,
                )
1472
1473
class TestQNNFloatingPointUtils(TestQNN):
    """Utility-feature tests for the QNN delegate built with fp16 HTP specs.

    ``setUp`` rebuilds ``TestQNN.compiler_specs`` before every test; tests that
    need extra compiler options overwrite it. Flags a test flips on the shared
    ``TestQNN`` class are restored by a cleanup so they do not leak into the
    tests that run afterwards.
    """

    # TODO: refactor to support different backends
    def setUp(self):
        # Tolerances and compiler specs are stored on the shared base class.
        TestQNN.atol = 1e-1
        TestQNN.rtol = 1e-1
        TestQNN.compiler_specs = self._fp16_compiler_specs(debug=False, saver=False)

    def _fp16_compiler_specs(self, backend_options=None, **spec_kwargs):
        # Build compiler specs for the configured SoC. When no backend options
        # are supplied, plain fp16 HTP options are generated.
        if backend_options is None:
            backend_options = generate_htp_compiler_spec(use_fp16=True)
        return generate_qnn_executorch_compiler_spec(
            soc_model=self.chipset_table[TestQNN.model],
            backend_options=backend_options,
            **spec_kwargs,
        )

    def _set_shared_flag(self, name, value):
        # Set a class-level TestQNN attribute for the current test only. The
        # previous state is put back by a cleanup, because setUp does not
        # reset these flags and they would otherwise stay set for later tests.
        if hasattr(TestQNN, name):
            self.addCleanup(setattr, TestQNN, name, getattr(TestQNN, name))
        else:
            self.addCleanup(delattr, TestQNN, name)
        setattr(TestQNN, name, value)

    def test_qnn_backend_dump_intermediate_outputs(self):
        # Lower Relu with intermediate-output dumping enabled.
        TestQNN.compiler_specs = self._fp16_compiler_specs(
            dump_intermediate_outputs=True,
        )
        module = Relu()  # noqa: F405
        sample_input = (torch.randn([2, 5, 1, 3]),)
        self.lower_module_and_test_output(
            module,
            sample_input,
            expected_partitions=1,
            expected_intermediate_events=3,
        )

    def test_qnn_backend_skip_node_id(self):
        # Skip two nodes by id and expect three partitions.
        module = SimpleModel()  # noqa: F405
        sample_input = (torch.ones(1, 32, 28, 28), torch.ones(1, 32, 28, 28))
        self.lower_module_and_test_output(
            module,
            sample_input,
            expected_partitions=3,
            skip_node_id_set={"aten_add_tensor", "aten_mean_dim"},
        )

    def test_qnn_backend_skip_node_op(self):
        # Skip one operator type and expect two partitions.
        module = SimpleModel()  # noqa: F405
        sample_input = (torch.ones(1, 32, 28, 28), torch.ones(1, 32, 28, 28))
        self.lower_module_and_test_output(
            module,
            sample_input,
            expected_partitions=2,
            skip_node_op_set={"aten.add.Tensor"},
        )

    def test_qnn_backend_multi_contexts(self):
        # Split the captured graph, then lower it with multi-context options.
        module = SimpleModel()  # noqa: F405
        sample_input = (torch.ones(1, 32, 28, 28), torch.ones(1, 32, 28, 28))
        edge_prog = capture_program(module, sample_input)
        self.split_graph(edge_prog.exported_program.graph_module, 4)

        compiler_specs = self._fp16_compiler_specs(
            backend_options=generate_htp_compiler_spec(
                use_fp16=True,
                use_dlbc=True,
                use_multi_contexts=True,
            ),
        )
        partitioner = QnnPartitioner(compiler_specs)
        edge_prog.exported_program = to_backend(edge_prog.exported_program, partitioner)
        update_spill_fill_size(edge_prog.exported_program)
        exec_prog = edge_prog.to_executorch()
        self.verify_output(module, sample_input, exec_prog)

    def test_qnn_backend_multi_contexts_composite(self):
        # Same multi-context options, applied through CompositeDelegateModule.
        compiler_specs = self._fp16_compiler_specs(
            backend_options=generate_htp_compiler_spec(
                use_fp16=True,
                use_dlbc=True,
                use_multi_contexts=True,
            ),
        )
        module = CompositeDelegateModule(  # noqa: F405
            compiler_specs=compiler_specs,
            partitioner_type=QnnPartitioner,
            capture_method=capture_program,
            lowered_method=to_backend,
        )
        sample_input = module.get_random_input()
        edge_prog = to_edge(
            torch.export.export(module, sample_input),
        )
        update_spill_fill_size(edge_prog.exported_program())
        exec_prog = edge_prog.to_executorch()
        self.verify_output(module.get_reference_module(), sample_input, exec_prog)

    def test_qnn_backend_multi_graphs(self):
        if self.enable_x86_64:
            self.skipTest("weight sharing is not supported on host machine")

        seq_conv = Conv2dSequential()  # noqa: F405
        # weight sharing
        modules = [seq_conv, seq_conv.second]
        sample_inputs = [(torch.randn([1, 1, 3, 3]),), (torch.randn([1, 3, 3, 3]),)]
        graph_names = ["seq_conv", "single_conv"]
        edge_progs = [
            capture_program(module, sample_input)
            for module, sample_input in zip(modules, sample_inputs)
        ]
        # One set of backend options is shared by the specs of both graphs.
        backend_options = generate_htp_compiler_spec(
            use_fp16=True,
        )
        compiler_specs = [
            self._fp16_compiler_specs(
                backend_options=backend_options,
                multiple_graphs=True,
                graph_name=graph_name,
            )
            for graph_name in graph_names
        ]
        exported_programs = [
            to_backend(edge_prog.exported_program, QnnPartitioner(spec))
            for edge_prog, spec in zip(edge_progs, compiler_specs)
        ]
        prog_mgr = generate_multi_graph_program(
            compiler_specs=compiler_specs[0],
            processed_bytes=[
                prog.graph_module.lowered_module_0.processed_bytes
                for prog in exported_programs
            ],
        )
        # Each module is checked against the method at the same index.
        for index, module in enumerate(modules):
            self.verify_output(
                module=module,
                sample_inputs=sample_inputs[index],
                executorch_prog=prog_mgr,
                method_index=index,
            )

    def test_qnn_backend_profile_op(self):
        # Profiling is switched on for this test only.
        self._set_shared_flag("enable_profile", True)
        TestQNN.compiler_specs = self._fp16_compiler_specs(profile=True)
        module = SimpleModel()  # noqa: F405
        sample_input = (torch.ones(1, 32, 28, 28), torch.ones(1, 32, 28, 28))
        self.lower_module_and_test_output(
            module,
            sample_input,
            expected_partitions=1,
            expected_profile_events=24,
        )

    def test_qnn_backend_shared_buffer(self):
        # The shared-buffer flag is switched on for this test only.
        self._set_shared_flag("shared_buffer", True)
        TestQNN.compiler_specs = self._fp16_compiler_specs(shared_buffer=True)
        module = SimpleModel()  # noqa: F405
        sample_input = (torch.ones(1, 32, 28, 28), torch.ones(1, 32, 28, 28))
        self.lower_module_and_test_output(
            module,
            sample_input,
            expected_partitions=1,
        )

    def test_qnn_backend_online_prepare(self):
        TestQNN.compiler_specs = self._fp16_compiler_specs(online_prepare=True)
        module = SimpleModel()  # noqa: F405
        sample_input = (torch.ones(1, 32, 28, 28), torch.ones(1, 32, 28, 28))
        self.lower_module_and_test_output(module, sample_input)

    def test_qnn_backend_context_direct(self):
        # Generate a context binary in a temp dir, load it back, and verify it
        # with random inputs shaped like the inputs the loader reports.
        with tempfile.TemporaryDirectory() as tmp_dir:
            module = ContextBinaryExample()  # noqa: F405
            generate_context_binary(
                module=module,
                inputs=module.example_inputs(),
                quantized=False,
                artifact_dir=tmp_dir,
            )
            ctx_path = f"{tmp_dir}/model_ctx.bin"
            bundle_program = from_context_binary(ctx_path, "ctx_loader")
            random_inputs = tuple(
                torch.randn(size=v.shape, dtype=v.dtype)
                for v in bundle_program["inputs"].values()
            )
            self.verify_output(
                module,
                random_inputs,
                bundle_program["edge_program_manager"].to_executorch(),
            )
1677
1678
class TestQNNQuantizedUtils(TestQNN):
    """Utility-feature tests for the QNN delegate on quantized (QDQ) modules.

    ``setUp`` rebuilds ``TestQNN.compiler_specs`` before every test; tests that
    need extra compiler options overwrite it. Flags a test flips on the shared
    ``TestQNN`` class are restored by a cleanup so they do not leak into the
    tests that run afterwards.
    """

    # TODO: refactor to support different backends
    def setUp(self):
        # Tolerances and compiler specs are stored on the shared base class.
        TestQNN.atol = 1e-1
        TestQNN.rtol = 1
        TestQNN.compiler_specs = self._quantized_compiler_specs(
            debug=False,
            saver=False,
        )

    def _quantized_compiler_specs(self, backend_options=None, **spec_kwargs):
        # Build compiler specs for the configured SoC. When no backend options
        # are supplied, HTP options with use_fp16=False are generated.
        if backend_options is None:
            backend_options = generate_htp_compiler_spec(use_fp16=False)
        return generate_qnn_executorch_compiler_spec(
            soc_model=self.chipset_table[TestQNN.model],
            backend_options=backend_options,
            **spec_kwargs,
        )

    def _set_shared_flag(self, name, value):
        # Set a class-level TestQNN attribute for the current test only. The
        # previous state is put back by a cleanup, because setUp does not
        # reset these flags and they would otherwise stay set for later tests.
        if hasattr(TestQNN, name):
            self.addCleanup(setattr, TestQNN, name, getattr(TestQNN, name))
        else:
            self.addCleanup(delattr, TestQNN, name)
        setattr(TestQNN, name, value)

    def _run_skip_annotation(self, expected_programs, **skip_kwargs):
        # Shared body of the skip_annotation tests: partially lower
        # SimpleModel, check how many exported programs come back, then lower
        # the resulting graph module again and verify its output.
        module = SimpleModel()  # noqa: F405
        sample_input = (torch.ones(1, 32, 28, 28), torch.ones(1, 32, 28, 28))

        # define partitioner
        partitioner = QnnPartitioner(self._quantized_compiler_specs())
        # define quantizer
        quantizer = QnnQuantizer()

        # define calibration method
        def calibrator(gm):
            gm(*sample_input)

        # get partially lowered graph module
        graph_module, exported_progs = skip_annotation(
            nn_module=module,
            quantizer=quantizer,
            partitioner=partitioner,
            sample_input=sample_input,
            calibration_cb=calibrator,
            **skip_kwargs,
        )
        self.assertEqual(len(exported_progs), expected_programs)
        # lower all graph again
        exec_prog = to_edge(
            torch.export.export(graph_module, sample_input),
        ).to_executorch()
        self.verify_output(module, sample_input, exec_prog)

    def test_qnn_backend_dump_intermediate_outputs(self):
        # Lower a quantized Relu with intermediate-output dumping enabled.
        TestQNN.compiler_specs = self._quantized_compiler_specs(
            dump_intermediate_outputs=True,
        )
        module = Relu()  # noqa: F405
        sample_input = (torch.randn([2, 5, 1, 3]),)
        module = self.get_qdq_module(module, sample_input)
        self.lower_module_and_test_output(
            module,
            sample_input,
            expected_partitions=1,
            expected_intermediate_events=5,
        )

    def test_qnn_backend_skip_node_id_partitioner(self):
        # Skip two nodes by id at partition time and expect three partitions.
        module = SimpleModel()  # noqa: F405
        sample_input = (torch.ones(1, 32, 28, 28), torch.ones(1, 32, 28, 28))
        module = self.get_qdq_module(module, sample_input)
        self.lower_module_and_test_output(
            module,
            sample_input,
            expected_partitions=3,
            skip_node_id_set={"aten_add_tensor", "aten_mean_dim"},
        )

    def test_qnn_backend_skip_node_id_quantizer(self):
        # the skipped operators will be left in CPU
        self._run_skip_annotation(1, fp_node_id_set={"conv2d"})

    def test_qnn_backend_skip_node_op_partitioner(self):
        # Skip one operator type at partition time and expect two partitions.
        module = SimpleModel()  # noqa: F405
        sample_input = (torch.ones(1, 32, 28, 28), torch.ones(1, 32, 28, 28))
        module = self.get_qdq_module(module, sample_input)
        self.lower_module_and_test_output(
            module,
            sample_input,
            expected_partitions=2,
            skip_node_op_set={"aten.add.Tensor"},
        )

    def test_qnn_backend_skip_node_op_quantizer(self):
        # the skipped operators will be left in CPU
        self._run_skip_annotation(2, fp_node_op_set={torch.ops.aten.add.Tensor})

    def test_qnn_backend_graph_level_mixed_precision(self):
        # the skipped operators will be delegated with fp16
        self._run_skip_annotation(
            5,
            fp_node_id_set={"add", "mean"},
            fallback_to_cpu=False,
        )

    def test_qnn_backend_multi_contexts(self):
        # Quantize, split the captured graph, then lower it with
        # multi-context options.
        module = SimpleModel()  # noqa: F405
        sample_input = (torch.ones(1, 32, 28, 28), torch.ones(1, 32, 28, 28))
        module = self.get_qdq_module(module, sample_input)
        edge_prog = capture_program(module, sample_input)
        self.split_graph(edge_prog.exported_program.graph_module, 4)

        compiler_specs = self._quantized_compiler_specs(
            backend_options=generate_htp_compiler_spec(
                use_fp16=False,
                use_dlbc=True,
                use_multi_contexts=True,
            ),
        )
        partitioner = QnnPartitioner(compiler_specs)
        edge_prog.exported_program = to_backend(edge_prog.exported_program, partitioner)
        update_spill_fill_size(edge_prog.exported_program)
        exec_prog = edge_prog.to_executorch()
        self.verify_output(module, sample_input, exec_prog)

    def test_qnn_backend_multi_contexts_composite(self):
        # Same multi-context options, applied through CompositeDelegateModule
        # with get_qdq_module as the quantize method.
        compiler_specs = self._quantized_compiler_specs(
            backend_options=generate_htp_compiler_spec(
                use_fp16=False,
                use_dlbc=True,
                use_multi_contexts=True,
            ),
        )
        module = CompositeDelegateModule(  # noqa: F405
            compiler_specs=compiler_specs,
            partitioner_type=QnnPartitioner,
            capture_method=capture_program,
            lowered_method=to_backend,
            quantize_method=self.get_qdq_module,
        )
        sample_input = module.get_random_input()
        edge_prog = to_edge(
            torch.export.export(module, sample_input),
        )
        update_spill_fill_size(edge_prog.exported_program())
        exec_prog = edge_prog.to_executorch()
        self.verify_output(module.get_reference_module(), sample_input, exec_prog)

    def test_qnn_backend_multi_graphs(self):
        if self.enable_x86_64:
            self.skipTest("weight sharing is not supported on host machine")

        seq_conv = Conv2dSequential()  # noqa: F405
        # weight sharing
        modules = [seq_conv, seq_conv.second]
        sample_inputs = [(torch.randn([1, 1, 3, 3]),), (torch.randn([1, 3, 3, 3]),)]
        graph_names = ["seq_conv", "single_conv"]
        edge_progs = [
            capture_program(self.get_qdq_module(module, sample_input), sample_input)
            for module, sample_input in zip(modules, sample_inputs)
        ]
        # NOTE(review): every other quantized test builds its HTP options with
        # use_fp16=False; confirm that use_fp16=True is intentional here.
        backend_options = generate_htp_compiler_spec(
            use_fp16=True,
        )
        compiler_specs = [
            self._quantized_compiler_specs(
                backend_options=backend_options,
                multiple_graphs=True,
                graph_name=graph_name,
            )
            for graph_name in graph_names
        ]
        exported_programs = [
            to_backend(edge_prog.exported_program, QnnPartitioner(spec))
            for edge_prog, spec in zip(edge_progs, compiler_specs)
        ]
        prog_mgr = generate_multi_graph_program(
            compiler_specs=compiler_specs[0],
            processed_bytes=[
                prog.graph_module.lowered_module_0.processed_bytes
                for prog in exported_programs
            ],
        )
        # Each module is checked against the method at the same index.
        for index, module in enumerate(modules):
            self.verify_output(
                module=module,
                sample_inputs=sample_inputs[index],
                executorch_prog=prog_mgr,
                method_index=index,
            )

    def test_qnn_backend_profile_op(self):
        # Profiling is switched on for this test only.
        self._set_shared_flag("enable_profile", True)
        TestQNN.compiler_specs = self._quantized_compiler_specs(profile=True)
        module = SimpleModel()  # noqa: F405
        sample_input = (torch.ones(1, 32, 28, 28), torch.ones(1, 32, 28, 28))
        module = self.get_qdq_module(module, sample_input)
        self.lower_module_and_test_output(
            module,
            sample_input,
            expected_partitions=1,
            expected_profile_events=25,
        )

    def test_qnn_backend_shared_buffer(self):
        # The shared-buffer flag is switched on for this test only.
        self._set_shared_flag("shared_buffer", True)
        TestQNN.compiler_specs = self._quantized_compiler_specs(shared_buffer=True)
        module = SimpleModel()  # noqa: F405
        sample_input = (torch.ones(1, 32, 28, 28), torch.ones(1, 32, 28, 28))
        module = self.get_qdq_module(module, sample_input)
        self.lower_module_and_test_output(
            module,
            sample_input,
            expected_partitions=1,
        )

    def test_qnn_backend_online_prepare(self):
        TestQNN.compiler_specs = self._quantized_compiler_specs(online_prepare=True)
        module = SimpleModel()  # noqa: F405
        sample_input = (torch.ones(1, 32, 28, 28), torch.ones(1, 32, 28, 28))
        module = self.get_qdq_module(module, sample_input)
        self.lower_module_and_test_output(module, sample_input)

    def test_qnn_backend_context_direct(self):
        # Generate a quantized context binary in a temp dir, load it back, and
        # verify it with random inputs shaped like the inputs the loader
        # reports.
        with tempfile.TemporaryDirectory() as tmp_dir:
            module = ContextBinaryExample()  # noqa: F405
            generate_context_binary(
                module=module,
                inputs=module.example_inputs(),
                quantized=True,
                artifact_dir=tmp_dir,
            )
            ctx_path = f"{tmp_dir}/model_ctx.bin"
            bundle_program = from_context_binary(ctx_path, "ctx_loader")
            random_inputs = tuple(
                torch.randn(size=v.shape, dtype=v.dtype)
                for v in bundle_program["inputs"].values()
            )
            self.verify_output(
                module,
                random_inputs,
                bundle_program["edge_program_manager"].to_executorch(),
            )
1999
2000
2001class TestExampleOssScript(TestQNN):
2002    def required_envs(self, conditions=None) -> bool:
2003        conditions = [] if conditions is None else conditions
2004        return all(
2005            [
2006                self.executorch_root,
2007                self.artifact_dir,
2008                *conditions,
2009            ]
2010        )
2011
2012    def test_dino_v2(self):
2013        if not self.required_envs([self.image_dataset]):
2014            self.skipTest("missing required envs")
2015        cmds = [
2016            "python",
2017            f"{self.executorch_root}/examples/qualcomm/oss_scripts/dino_v2.py",
2018            "--dataset",
2019            self.image_dataset,
2020            "--artifact",
2021            self.artifact_dir,
2022            "--build_folder",
2023            self.build_folder,
2024            "--device",
2025            self.device,
2026            "--model",
2027            self.model,
2028            "--ip",
2029            self.ip,
2030            "--port",
2031            str(self.port),
2032        ]
2033        if self.host:
2034            cmds.extend(["--host", self.host])
2035
2036        p = subprocess.Popen(cmds, stdout=subprocess.DEVNULL)
2037        with Listener((self.ip, self.port)) as listener:
2038            conn = listener.accept()
2039            p.communicate()
2040            msg = json.loads(conn.recv())
2041            if "Error" in msg:
2042                self.fail(msg["Error"])
2043            else:
2044                self.assertGreaterEqual(msg["top_1"], 70)
2045                self.assertGreaterEqual(msg["top_5"], 85)
2046
2047    def test_esrgan(self):
2048        if not self.required_envs():
2049            self.skipTest("missing required envs")
2050
2051        cmds = [
2052            "python",
2053            f"{self.executorch_root}/examples/qualcomm/oss_scripts/esrgan.py",
2054            "--artifact",
2055            self.artifact_dir,
2056            "--build_folder",
2057            self.build_folder,
2058            "--device",
2059            self.device,
2060            "--model",
2061            self.model,
2062            "--default_dataset",
2063            "--oss_repo",
2064            self.oss_repo,
2065            "--ip",
2066            self.ip,
2067            "--port",
2068            str(self.port),
2069        ]
2070        if self.host:
2071            cmds.extend(["--host", self.host])
2072
2073        p = subprocess.Popen(cmds, stdout=subprocess.DEVNULL)
2074        with Listener((self.ip, self.port)) as listener:
2075            conn = listener.accept()
2076            p.communicate()
2077            msg = json.loads(conn.recv())
2078            if "Error" in msg:
2079                self.fail(msg["Error"])
2080            else:
2081                self.assertGreaterEqual(msg["PSNR"], 24)
2082                self.assertGreaterEqual(msg["SSIM"], 0.8)
2083
2084    def test_fastvit(self):
2085        if not self.required_envs(
2086            [self.image_dataset, self.pretrained_weight, self.oss_repo]
2087        ):
2088            self.skipTest("missing required envs")
2089        cmds = [
2090            "python",
2091            f"{self.executorch_root}/examples/qualcomm/oss_scripts/fastvit.py",
2092            "--dataset",
2093            self.image_dataset,
2094            "--artifact",
2095            self.artifact_dir,
2096            "--build_folder",
2097            self.build_folder,
2098            "--device",
2099            self.device,
2100            "--model",
2101            self.model,
2102            "--oss_repo",
2103            self.oss_repo,
2104            "--pretrained_weight",
2105            self.pretrained_weight,
2106            "--ip",
2107            self.ip,
2108            "--port",
2109            str(self.port),
2110        ]
2111        if self.host:
2112            cmds.extend(["--host", self.host])
2113
2114        p = subprocess.Popen(cmds, stdout=subprocess.DEVNULL)
2115        with Listener((self.ip, self.port)) as listener:
2116            conn = listener.accept()
2117            p.communicate()
2118            msg = json.loads(conn.recv())
2119            if "Error" in msg:
2120                self.fail(msg["Error"])
2121            else:
2122                self.assertGreaterEqual(msg["top_1"], 60)
2123                self.assertGreaterEqual(msg["top_5"], 80)
2124
2125    def test_fbnet(self):
2126        if not self.required_envs([self.image_dataset]):
2127            self.skipTest("missing required envs")
2128
2129        cmds = [
2130            "python",
2131            f"{self.executorch_root}/examples/qualcomm/oss_scripts/fbnet.py",
2132            "--dataset",
2133            self.image_dataset,
2134            "--artifact",
2135            self.artifact_dir,
2136            "--build_folder",
2137            self.build_folder,
2138            "--device",
2139            self.device,
2140            "--model",
2141            self.model,
2142            "--ip",
2143            self.ip,
2144            "--port",
2145            str(self.port),
2146        ]
2147        if self.host:
2148            cmds.extend(["--host", self.host])
2149
2150        p = subprocess.Popen(cmds, stdout=subprocess.DEVNULL)
2151        with Listener((self.ip, self.port)) as listener:
2152            conn = listener.accept()
2153            p.communicate()
2154            msg = json.loads(conn.recv())
2155            if "Error" in msg:
2156                self.fail(msg["Error"])
2157            else:
2158                self.assertGreaterEqual(msg["top_1"], 60)
2159                self.assertGreaterEqual(msg["top_5"], 90)
2160
2161    def test_gMLP(self):
2162        if not self.required_envs([self.image_dataset]):
2163            self.skipTest("missing required envs")
2164
2165        cmds = [
2166            "python",
2167            f"{self.executorch_root}/examples/qualcomm/oss_scripts/gMLP_image_classification.py",
2168            "--dataset",
2169            self.image_dataset,
2170            "--artifact",
2171            self.artifact_dir,
2172            "--build_folder",
2173            self.build_folder,
2174            "--device",
2175            self.device,
2176            "--model",
2177            self.model,
2178            "--ip",
2179            self.ip,
2180            "--port",
2181            str(self.port),
2182        ]
2183        if self.host:
2184            cmds.extend(["--host", self.host])
2185
2186        p = subprocess.Popen(cmds, stdout=subprocess.DEVNULL)
2187        with Listener((self.ip, self.port)) as listener:
2188            conn = listener.accept()
2189            p.communicate()
2190            msg = json.loads(conn.recv())
2191            if "Error" in msg:
2192                self.fail(msg["Error"])
2193            else:
2194                self.assertGreaterEqual(msg["top_1"], 60)
2195                self.assertGreaterEqual(msg["top_5"], 90)
2196
2197    def test_regnet(self):
2198        if not self.required_envs([self.image_dataset]):
2199            self.skipTest("missing required envs")
2200
2201        weights = ["regnet_y_400mf", "regnet_x_400mf"]
2202        cmds = [
2203            "python",
2204            f"{self.executorch_root}/examples/qualcomm/oss_scripts/regnet.py",
2205            "--dataset",
2206            self.image_dataset,
2207            "--artifact",
2208            self.artifact_dir,
2209            "--build_folder",
2210            self.build_folder,
2211            "--device",
2212            self.device,
2213            "--model",
2214            self.model,
2215            "--ip",
2216            self.ip,
2217            "--port",
2218            str(self.port),
2219        ]
2220        if self.host:
2221            cmds.extend(["--host", self.host])
2222
2223        for weight in weights:
2224            p = subprocess.Popen(
2225                cmds + ["--weights", weight], stdout=subprocess.DEVNULL
2226            )
2227            with Listener((self.ip, self.port)) as listener:
2228                conn = listener.accept()
2229                p.communicate()
2230                msg = json.loads(conn.recv())
2231                if "Error" in msg:
2232                    self.fail(msg["Error"])
2233                else:
2234                    self.assertGreaterEqual(msg["top_1"], 60)
2235                    self.assertGreaterEqual(msg["top_5"], 85)
2236
2237    def test_retinanet(self):
2238        if not self.required_envs([self.image_dataset]):
2239            self.skipTest("missing required envs")
2240
2241        cmds = [
2242            "python",
2243            f"{self.executorch_root}/examples/qualcomm/oss_scripts/retinanet.py",
2244            "--artifact",
2245            self.artifact_dir,
2246            "--build_folder",
2247            self.build_folder,
2248            "--device",
2249            self.device,
2250            "--model",
2251            self.model,
2252            "--dataset",
2253            self.image_dataset,
2254            "--ip",
2255            self.ip,
2256            "--port",
2257            str(self.port),
2258        ]
2259        if self.host:
2260            cmds.extend(["--host", self.host])
2261
2262        p = subprocess.Popen(cmds, stdout=subprocess.DEVNULL)
2263        with Listener((self.ip, self.port)) as listener:
2264            conn = listener.accept()
2265            p.communicate()
2266            msg = json.loads(conn.recv())
2267            if "Error" in msg:
2268                self.fail(msg["Error"])
2269            else:
2270                self.assertGreaterEqual(msg["mAP"], 0.6)
2271
2272    def test_squeezenet(self):
2273        if not self.required_envs([self.image_dataset]):
2274            self.skipTest("missing required envs")
2275
2276        cmds = [
2277            "python",
2278            f"{self.executorch_root}/examples/qualcomm/oss_scripts/squeezenet.py",
2279            "--dataset",
2280            self.image_dataset,
2281            "--artifact",
2282            self.artifact_dir,
2283            "--build_folder",
2284            self.build_folder,
2285            "--device",
2286            self.device,
2287            "--model",
2288            self.model,
2289            "--ip",
2290            self.ip,
2291            "--port",
2292            str(self.port),
2293        ]
2294        if self.host:
2295            cmds.extend(["--host", self.host])
2296
2297        p = subprocess.Popen(cmds, stdout=subprocess.DEVNULL)
2298        with Listener((self.ip, self.port)) as listener:
2299            conn = listener.accept()
2300            p.communicate()
2301            msg = json.loads(conn.recv())
2302            if "Error" in msg:
2303                self.fail(msg["Error"])
2304            else:
2305                self.assertGreaterEqual(msg["top_1"], 45)
2306                self.assertGreaterEqual(msg["top_5"], 70)
2307
2308    def test_ssd300_vgg16(self):
2309        if not self.required_envs([self.pretrained_weight, self.oss_repo]):
2310            self.skipTest("missing required envs")
2311
2312        cmds = [
2313            "python",
2314            f"{self.executorch_root}/examples/qualcomm/oss_scripts/ssd300_vgg16.py",
2315            "--artifact",
2316            self.artifact_dir,
2317            "--build_folder",
2318            self.build_folder,
2319            "--device",
2320            self.device,
2321            "--model",
2322            self.model,
2323            "--oss_repo",
2324            self.oss_repo,
2325            "--pretrained_weight",
2326            self.pretrained_weight,
2327            "--ip",
2328            self.ip,
2329            "--port",
2330            str(self.port),
2331        ]
2332        if self.host:
2333            cmds.extend(["--host", self.host])
2334
2335        p = subprocess.Popen(cmds, stdout=subprocess.DEVNULL)
2336        with Listener((self.ip, self.port)) as listener:
2337            conn = listener.accept()
2338            p.communicate()
2339            msg = json.loads(conn.recv())
2340            if "Error" in msg:
2341                self.fail(msg["Error"])
2342            else:
2343                self.assertGreaterEqual(msg["mAP"], 0.70)
2344
2345
class TestExampleQaihubScript(TestQNN):
    """Tests that drive the scripts under ``examples/qualcomm/qaihub_scripts``.

    Each script-based test launches the script as a subprocess, receives one
    JSON message from it over a ``multiprocessing.connection`` listener bound
    to ``(self.ip, self.port)``, and checks the reported result.
    """

    def required_envs(self, conditions=None) -> bool:
        """Return True when the mandatory settings and all ``conditions`` are truthy.

        Args:
            conditions: optional iterable of extra values that must be truthy.
        """
        conditions = [] if conditions is None else conditions
        return all(
            [
                self.executorch_root,
                self.artifact_dir,
                *conditions,
            ]
        )

    def _run_script_and_receive(self, cmds):
        """Launch ``cmds`` and return the JSON message it sends back.

        The listener is bound *before* the subprocess is spawned so that a
        script which reports back quickly can never find the port closed
        (in which case ``accept()`` would wait forever). The accepted
        connection is closed as soon as the message has been read.

        Args:
            cmds: argument vector passed to ``subprocess.Popen``.

        Returns:
            The decoded JSON payload received from the script.
        """
        with Listener((self.ip, self.port)) as listener:
            p = subprocess.Popen(cmds, stdout=subprocess.DEVNULL)
            with listener.accept() as conn:
                p.communicate()
                return json.loads(conn.recv())

    def test_utils_export(self):
        """Compile and execute a context binary through ``utils/export.py``.

        Generates a context binary for ``ContextBinaryExample``, runs the
        ``compile`` sub-command, checks that the expected artifacts exist,
        runs the ``execute`` sub-command and compares the outputs it wrote
        against the eager module's outputs.
        """
        with tempfile.TemporaryDirectory() as tmp_dir:
            module = ContextBinaryExample()  # noqa: F405
            generate_context_binary(
                module=module,
                inputs=module.example_inputs(),
                quantized=True,
                artifact_dir=tmp_dir,
            )
            ctx_path = f"{tmp_dir}/model_ctx.bin"
            fpath = f"{self.executorch_root}/examples/qualcomm/qaihub_scripts/utils/export.py"

            # do compilation
            compile_cmds = [
                "python",
                fpath,
                "compile",
                "-a",
                ctx_path,
                "-m",
                self.model,
                "-l",
                "False",
                "-b",
                self.build_folder,
                "-o",
                f"{tmp_dir}/output_pte",
            ]
            compile_process = subprocess.Popen(
                compile_cmds, stdout=subprocess.DEVNULL, cwd=self.executorch_root
            )
            output_pte_dir = f"{tmp_dir}/output_pte/model_ctx"
            compile_process.communicate()

            # check artifacts are correctly generated
            self.assertTrue(
                all(
                    [
                        Path(output_pte_dir).exists(),
                        Path(f"{output_pte_dir}/model_ctx.json").exists(),
                        Path(f"{output_pte_dir}/model_ctx.svg").exists(),
                    ]
                )
            )

            # prepare input files
            input_list, inputs = [], module.example_inputs()
            for name, tensor in inputs.items():
                tensor_path = f"{output_pte_dir}/{name}.pt"
                torch.save(tensor, tensor_path)
                input_list.append(tensor_path)

            # do execution
            output_data_dir = f"{tmp_dir}/output_data"
            execute_cmds = [
                "python",
                fpath,
                "execute",
                "-p",
                output_pte_dir,
                "-i",
                *input_list,
                "-s",
                self.device,
                "-z",
                "-b",
                self.build_folder,
                "-o",
                output_data_dir,
            ]
            if self.host is not None:
                # Flag and value must be separate argv entries; a single
                # "-H <host>" token is parsed as -H with the value " <host>"
                # (leading space included).
                execute_cmds.extend(["-H", self.host])
            execute_process = subprocess.Popen(execute_cmds, cwd=self.executorch_root)
            execute_process.communicate()

            # read outputs
            with open(f"{output_pte_dir}/model_ctx.json", "r") as f:
                graph_info = json.load(f)

            device_output = []
            for output in graph_info["outputs"]:
                with open(f"{output_data_dir}/{output['name']}.pt", "rb") as f:
                    buffer = io.BytesIO(f.read())
                    device_output.append(torch.load(buffer, weights_only=False))

            # validate outputs
            golden_output = module.forward(inputs["x"], inputs["y"])
            self.atol, self.rtol = 1e-1, 1
            self._assert_outputs_equal(golden_output, device_output)

    def test_llama2_7b(self):
        """Run the QAI Hub Llama2-7B script; the output must start with the prompt."""
        if not self.required_envs():
            self.skipTest("missing required envs")

        prompt = "Explain the rules of baseball"
        cmds = [
            "python",
            f"{self.executorch_root}/examples/qualcomm/qaihub_scripts/llama/llama2/qaihub_llama2_7b.py",
            "--artifact",
            self.artifact_dir,
            "--build_folder",
            self.build_folder,
            "--device",
            self.device,
            "--model",
            self.model,
            "--tokenizer_bin",
            f"{self.artifact_dir}/tokenizer.bin",
            "--context_binaries",
            f"{self.artifact_dir}",
            "--ip",
            self.ip,
            "--port",
            str(self.port),
            "--prompt",
            prompt,
        ]
        if self.host:
            cmds.extend(["--host", self.host])

        msg = self._run_script_and_receive(cmds)
        if "Error" in msg:
            self.fail(msg["Error"])
        model_out = msg["result"]
        self.assertTrue(model_out.startswith(prompt))

    def test_llama3_8b(self):
        """Run the QAI Hub Llama3-8B script; the output must start with the templated prompt."""
        if not self.required_envs():
            self.skipTest("missing required envs")

        prompt = "Explain the rules of baseball"
        cmds = [
            "python",
            f"{self.executorch_root}/examples/qualcomm/qaihub_scripts/llama/llama3/qaihub_llama3_8b.py",
            "--artifact",
            self.artifact_dir,
            "--build_folder",
            self.build_folder,
            "--device",
            self.device,
            "--model",
            self.model,
            "--tokenizer_model",
            f"{self.artifact_dir}/tokenizer.model",
            "--context_binaries",
            f"{self.artifact_dir}",
            "--ip",
            self.ip,
            "--port",
            str(self.port),
            "--prompt",
            prompt,
        ]
        if self.host:
            cmds.extend(["--host", self.host])

        msg = self._run_script_and_receive(cmds)
        if "Error" in msg:
            self.fail(msg["Error"])
        model_out = msg["result"]
        # The prompt wrapped in the header/eot markers the output is expected
        # to begin with.
        expected_result = (
            "<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n"
            + prompt
            + "<|eot_id|><|start_header_id|>assistant<|end_header_id|>"
        )
        self.assertTrue(model_out.startswith(expected_result))

    def test_stable_diffusion(self):
        """Run the QAI Hub Stable Diffusion script and check PSNR/SSIM."""
        if not self.required_envs():
            self.skipTest("missing required envs")

        prompt = "a photo of an astronaut riding a horse on mars"
        cmds = [
            "python",
            f"{self.executorch_root}/examples/qualcomm/qaihub_scripts/stable_diffusion/qaihub_stable_diffusion.py",
            "--artifact",
            self.artifact_dir,
            "--build_folder",
            self.build_folder,
            "--device",
            self.device,
            "--model",
            self.model,
            "--text_encoder_bin",
            f"{self.artifact_dir}/text_encoder.serialized.bin",
            "--unet_bin",
            f"{self.artifact_dir}/unet.serialized.bin",
            "--vae_bin",
            f"{self.artifact_dir}/vae.serialized.bin",
            "--vocab_json",
            f"{self.artifact_dir}/vocab.json",
            "--num_time_steps",
            "20",
            "--ip",
            self.ip,
            "--port",
            str(self.port),
            "--prompt",
            prompt,
            "--fix_latents",
        ]
        if self.host:
            cmds.extend(["--host", self.host])

        msg = self._run_script_and_receive(cmds)
        if "Error" in msg:
            self.fail(msg["Error"])
        # For the default settings and prompt, the expected results will be {PSNR: 23.258, SSIM: 0.852}
        self.assertGreaterEqual(msg["PSNR"], 20)
        self.assertGreaterEqual(msg["SSIM"], 0.8)
2583
2584
class TestExampleScript(TestQNN):
    """Tests that drive the example scripts under ``examples/qualcomm``.

    Each script-based test launches the script as a subprocess, receives one
    JSON message from it over a ``multiprocessing.connection`` listener bound
    to ``(self.ip, self.port)``, and checks the reported metrics.
    """

    def required_envs(self, conditions=None) -> bool:
        """Return True when the mandatory settings and all ``conditions`` are truthy.

        Args:
            conditions: optional iterable of extra values that must be truthy.
        """
        conditions = [] if conditions is None else conditions
        return all(
            [
                self.executorch_root,
                self.artifact_dir,
                *conditions,
            ]
        )

    def _run_script_and_receive(self, cmds):
        """Launch ``cmds`` and return the JSON message it sends back.

        The listener is bound *before* the subprocess is spawned so that a
        script which reports back quickly can never find the port closed
        (in which case ``accept()`` would wait forever). The accepted
        connection is closed as soon as the message has been read.

        Args:
            cmds: argument vector passed to ``subprocess.Popen``.

        Returns:
            The decoded JSON payload received from the script.
        """
        with Listener((self.ip, self.port)) as listener:
            p = subprocess.Popen(cmds, stdout=subprocess.DEVNULL)
            with listener.accept() as conn:
                p.communicate()
                return json.loads(conn.recv())

    def _check_image_classifier(self, script_name, min_top_1, min_top_5):
        """Run one image-classification example and check its accuracy.

        Args:
            script_name: file name under ``examples/qualcomm/scripts``.
            min_top_1: lowest acceptable ``top_1`` value in the reply.
            min_top_5: lowest acceptable ``top_5`` value in the reply.
        """
        if not self.required_envs([self.image_dataset]):
            self.skipTest("missing required envs")

        cmds = [
            "python",
            f"{self.executorch_root}/examples/qualcomm/scripts/{script_name}",
            "--dataset",
            self.image_dataset,
            "--artifact",
            self.artifact_dir,
            "--build_folder",
            self.build_folder,
            "--device",
            self.device,
            "--model",
            self.model,
            "--ip",
            self.ip,
            "--port",
            str(self.port),
        ]
        if self.host:
            cmds.extend(["--host", self.host])
        if self.shared_buffer:
            cmds.extend(["--shared_buffer"])

        msg = self._run_script_and_receive(cmds)
        if "Error" in msg:
            self.fail(msg["Error"])
        self.assertGreaterEqual(msg["top_1"], min_top_1)
        self.assertGreaterEqual(msg["top_5"], min_top_5)

    def test_mobilenet_v2(self):
        """MobileNetV2 example: top-1 >= 60, top-5 >= 80."""
        self._check_image_classifier("mobilenet_v2.py", 60, 80)

    def test_mobilenet_v3(self):
        """MobileNetV3 example: top-1 >= 60, top-5 >= 80."""
        self._check_image_classifier("mobilenet_v3.py", 60, 80)

    def test_inception_v3(self):
        """InceptionV3 example: top-1 >= 60, top-5 >= 80."""
        self._check_image_classifier("inception_v3.py", 60, 80)

    def test_inception_v4(self):
        """InceptionV4 example: top-1 >= 60, top-5 >= 80."""
        self._check_image_classifier("inception_v4.py", 60, 80)

    def test_vit(self):
        """TorchVision ViT example: top-1 >= 65, top-5 >= 90."""
        self._check_image_classifier("torchvision_vit.py", 65, 90)

    def test_edsr(self):
        """EDSR example on its default dataset: PSNR >= 25, SSIM >= 0.8."""
        if not self.required_envs():
            self.skipTest("missing required envs")

        cmds = [
            "python",
            f"{self.executorch_root}/examples/qualcomm/scripts/edsr.py",
            "--artifact",
            self.artifact_dir,
            "--build_folder",
            self.build_folder,
            "--device",
            self.device,
            "--model",
            self.model,
            "--default_dataset",
            "--ip",
            self.ip,
            "--port",
            str(self.port),
        ]
        if self.host:
            cmds.extend(["--host", self.host])
        if self.shared_buffer:
            cmds.extend(["--shared_buffer"])

        msg = self._run_script_and_receive(cmds)
        if "Error" in msg:
            self.fail(msg["Error"])
        self.assertGreaterEqual(msg["PSNR"], 25)
        self.assertGreaterEqual(msg["SSIM"], 0.8)

    def test_deeplab_v3(self):
        """DeepLabV3 example: PA >= 0.85, MPA >= 0.70, MIoU >= 0.55."""
        if not self.required_envs():
            self.skipTest("missing required envs")

        cmds = [
            "python",
            f"{self.executorch_root}/examples/qualcomm/scripts/deeplab_v3.py",
            "--artifact",
            self.artifact_dir,
            "--build_folder",
            self.build_folder,
            "--device",
            self.device,
            "--model",
            self.model,
            "--download",
            "--ip",
            self.ip,
            "--port",
            str(self.port),
        ]
        if self.host:
            cmds.extend(["--host", self.host])
        if self.shared_buffer:
            cmds.extend(["--shared_buffer"])

        msg = self._run_script_and_receive(cmds)
        if "Error" in msg:
            self.fail(msg["Error"])
        self.assertGreaterEqual(msg["PA"], 0.85)
        self.assertGreaterEqual(msg["MPA"], 0.70)
        self.assertGreaterEqual(msg["MIoU"], 0.55)

    def test_stories_single_llama(self):
        """Stories110M llama example: first result must start with the golden text."""
        if not self.required_envs():
            self.skipTest("missing required envs")

        cmds = [
            "python",
            f"{self.executorch_root}/examples/qualcomm/oss_scripts/llama2/llama.py",
            "--artifact",
            self.artifact_dir,
            "--build_folder",
            self.build_folder,
            "--device",
            self.device,
            "--model",
            self.model,
            "--checkpoint",
            f"{self.artifact_dir}/stories110M.pt",
            "--params",
            f"{self.artifact_dir}/params.json",
            "--tokenizer_model",
            f"{self.artifact_dir}/tokenizer.model",
            "--tokenizer_bin",
            f"{self.artifact_dir}/tokenizer.bin",
            "--ip",
            self.ip,
            "--port",
            str(self.port),
            "--prompt",
            "Once",
            "--ptq",
            "16a4w",
            "--temperature",
            "0",
        ]
        if self.host:
            cmds.extend(["--host", self.host])

        golden_start_with = "Once upon a time,"
        msg = self._run_script_and_receive(cmds)
        if "Error" in msg:
            self.fail(msg["Error"])
        model_out = msg["result"][0]
        self.assertTrue(model_out.startswith(golden_start_with))

    @unittest.skip("dynamic shape inputs appear in recent torch.export.export")
    def test_mobilebert(self):
        """MobileBERT fine-tune example in fp16: CPU and HTP values within 2."""
        if not self.required_envs([self.pretrained_weight]):
            self.skipTest("missing required envs")

        cmds = [
            "python",
            f"{self.executorch_root}/examples/qualcomm/scripts/mobilebert_fine_tune.py",
            "--artifact",
            self.artifact_dir,
            "--build_folder",
            self.build_folder,
            "--device",
            self.device,
            "--model",
            self.model,
            "--pretrained_weight",
            self.pretrained_weight,
            "--ip",
            self.ip,
            "--port",
            str(self.port),
            "--use_fp16",
        ]
        if self.host:
            cmds.extend(["--host", self.host])
        if self.shared_buffer:
            cmds.extend(["--shared_buffer"])

        msg = self._run_script_and_receive(cmds)
        if "Error" in msg:
            self.fail(msg["Error"])
        cpu, htp = msg["CPU"], msg["HTP"]
        # Compare the first element of each CPU entry with its HTP counterpart.
        for k, v in cpu.items():
            self.assertLessEqual(abs(v[0] - htp[k][0]), 2)

    @unittest.skip("eagar mode fake quant works well, need further investigation")
    def test_ptq_mobilebert(self):
        """MobileBERT fine-tune example with 16a16w PTQ: CPU and HTP values within 5."""
        if not self.required_envs([self.pretrained_weight]):
            self.skipTest("missing required envs")

        cmds = [
            "python",
            f"{self.executorch_root}/examples/qualcomm/scripts/mobilebert_fine_tune.py",
            "--artifact",
            self.artifact_dir,
            "--build_folder",
            self.build_folder,
            "--device",
            self.device,
            "--model",
            self.model,
            "--pretrained_weight",
            self.pretrained_weight,
            "--ptq",
            "16a16w",
            "--ip",
            self.ip,
            "--port",
            str(self.port),
        ]
        if self.host:
            cmds.extend(["--host", self.host])
        if self.shared_buffer:
            cmds.extend(["--shared_buffer"])

        msg = self._run_script_and_receive(cmds)
        if "Error" in msg:
            self.fail(msg["Error"])
        cpu, htp = msg["CPU"], msg["HTP"]
        # Compare the first element of each CPU entry with its HTP counterpart.
        for k, v in cpu.items():
            self.assertLessEqual(abs(v[0] - htp[k][0]), 5)

    def test_wav2letter(self):
        """Wav2Letter example: wer <= 0.5, cer <= 0.25."""
        if not self.required_envs([self.pretrained_weight]):
            self.skipTest("missing required envs")

        cmds = [
            "python",
            f"{self.executorch_root}/examples/qualcomm/scripts/wav2letter.py",
            "--artifact",
            self.artifact_dir,
            "--build_folder",
            self.build_folder,
            "--device",
            self.device,
            "--model",
            self.model,
            "--pretrained_weight",
            self.pretrained_weight,
            "--ip",
            self.ip,
            "--port",
            str(self.port),
        ]
        if self.host:
            cmds.extend(["--host", self.host])
        if self.shared_buffer:
            cmds.extend(["--shared_buffer"])

        msg = self._run_script_and_receive(cmds)
        if "Error" in msg:
            self.fail(msg["Error"])
        self.assertLessEqual(msg["wer"], 0.5)
        self.assertLessEqual(msg["cer"], 0.25)

    def test_export_example(self):
        """Run ``export_example.py`` and check that ``<model_name>.pte`` is written."""
        if not self.required_envs([self.model_name]):
            self.skipTest("missing required envs")

        with tempfile.TemporaryDirectory() as tmp_dir:
            cmds = [
                "python",
                "qualcomm/scripts/export_example.py",
                "--model_name",
                self.model_name,
                "--output_folder",
                f"{tmp_dir}/",
                "--generate_etrecord",
            ]

            # The script path above is relative, so run from the examples dir.
            p = subprocess.Popen(
                cmds, stdout=subprocess.DEVNULL, cwd=f"{self.executorch_root}/examples"
            )
            p.communicate()
            self.assertTrue(Path(f"{tmp_dir}/{self.model_name}.pte").exists())
3053
3054
def setup_environment():
    """Parse the QNN test options and publish them as ``TestQNN`` class attributes.

    Starts from the parser returned by ``setup_common_args_and_variables()``,
    adds the options specific to this test runner, and parses only the
    arguments it recognises. The parsed values are stored on the ``unittest``
    module object, because it is passed as the parsing namespace, and are then
    copied onto ``TestQNN``.

    Returns:
        list: the program name followed by every argument this parser did not
        recognise, suitable for ``unittest.main(argv=...)``.
    """
    parser = setup_common_args_and_variables()

    # (option strings, keyword arguments) in registration order.
    extra_options = (
        (
            ("-r", "--executorch_root"),
            {"help": "Root location of current repo", "type": str},
        ),
        (
            ("-a", "--artifact_dir"),
            {"help": "Location for putting generated artifacts", "type": str},
        ),
        (
            ("-i", "--image_dataset"),
            {"help": "Location for imagenet dataset", "type": str},
        ),
        (
            ("-p", "--pretrained_weight"),
            {
                "help": "Location for pretrained weighting",
                "default": "",
                "type": str,
            },
        ),
        (
            ("-n", "--model_name"),
            {"help": "Input the model to export", "type": str},
        ),
        (
            ("-o", "--online_prepare"),
            {
                "help": "Conduct on-device graph compilation",
                "action": "store_true",
            },
        ),
        (
            ("-P", "--enable_profile"),
            {
                "help": "Profile the performance of each operator with kProfileDetailed profile level",
                "action": "store_true",
            },
        ),
        (
            ("-e", "--error_only"),
            {"help": "Emit log only when error happened", "action": "store_true"},
        ),
        (
            ("--oss_repo",),
            {
                "help": "Path to open source software model repository",
                "type": str,
            },
        ),
        (
            ("-x", "--enable_x86_64"),
            {
                "help": "Enable unittest to be executed on x86_64 platform",
                "action": "store_true",
            },
        ),
    )
    for option_strings, settings in extra_options:
        parser.add_argument(*option_strings, **settings)

    args, ns_args = parser.parse_known_args(namespace=unittest)

    # Attributes copied one-to-one from the parsed namespace onto TestQNN.
    forwarded = (
        "host",
        "device",
        "model",
        "build_folder",
        "executorch_root",
        "artifact_dir",
        "image_dataset",
        "pretrained_weight",
        "model_name",
        "online_prepare",
        "enable_profile",
        "error_only",
        "oss_repo",
        "shared_buffer",
        "enable_x86_64",
        "dump_intermediate_outputs",
    )
    for attr in forwarded:
        setattr(TestQNN, attr, getattr(args, attr))

    return [*sys.argv[:1], *ns_args]
3137
3138
if __name__ == "__main__":
    # Consume the QNN-specific options first; unittest only receives the
    # program name plus whatever arguments were left unparsed.
    unittest.main(argv=setup_environment())
3142