# xref: /aosp_15_r20/external/executorch/backends/vulkan/test/op_tests/cases.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
6
7
from collections import namedtuple
from typing import Callable, List, Tuple, Union

from executorch.backends.vulkan.test.op_tests.utils.test_suite import VkTestSuite
12
13
# Prime-valued dimension sizes used to build test shapes, from extra-large (XL)
# down to extra-small (XS). Presumably primes are used so that shapes do not
# divide evenly into packed texel groups — NOTE(review): confirm intent.
XL: int = 113
L: int = 89
M2: int = 41
M1: int = 37
M: int = 29
S2: int = 11
S1: int = 7
S: int = 5
XS: int = 3

# Registry filled by ``register_test_suite``: ATen op name -> suite (or list of suites).
test_suites: dict = {}
26
27
def register_test_suite(aten_op: Union[str, List[str], Tuple[str, ...]]):
    """Return a decorator that registers a test-suite factory for ATen op(s).

    The decorated factory is called once per op name, and each result (a
    ``VkTestSuite`` or a list of them) is stored in the module-level
    ``test_suites`` dict under that name. Calling the factory separately per
    op gives every op its own suite object, so mutating one never affects
    another. The decorated function itself is returned unchanged.

    Args:
        aten_op: A single op name such as ``"aten.add.Tensor"``, or a list or
            tuple of op names that share the same inputs.

    Raises:
        TypeError: if ``aten_op`` is neither a string nor a list/tuple, so an
            unsupported argument cannot silently register nothing.
    """
    if isinstance(aten_op, str):
        op_names = [aten_op]
    elif isinstance(aten_op, (list, tuple)):
        op_names = list(aten_op)
    else:
        raise TypeError(
            "register_test_suite expects an op name or a list/tuple of op names, "
            f"got {type(aten_op).__name__}"
        )

    def test_suite_decorator(fn: Callable) -> Callable:
        for op in op_names:
            test_suites[op] = fn()
        return fn

    return test_suite_decorator
38
39
@register_test_suite(
    ["aten.add.Tensor", "aten.sub.Tensor", "aten.div.Tensor", "aten.mul.Tensor"]
)
def get_binary_elementwise_inputs():
    """Build the suite shared by add / sub / div / mul.

    Each case is ``(self_shape, other_shape[, scalar])``. The cases cover
    same-shape operands and ``other`` shapes with size-1 dims that must
    broadcast against ``self``. The optional trailing ``2.0`` is an extra
    scalar argument (presumably ``alpha`` for add/sub).
    """
    cases = [
        # 2-D
        ((M1, M2), (M1, M2)),
        ((M1, M2), (M1, 1), 2.0),
        ((M1, M2), (1, M2)),
        # 3-D
        ((S, S1, S2), (S, S1, S2)),
        ((S, S1, S2), (S, S1, 1), 2.0),
        ((S, S1, S2), (S, 1, S2), 2.0),
        # 4-D, broadcasting over the two innermost dims
        ((XS, S, S1, S2), (XS, S, 1, 1), 2.0),
    ]
    suite = VkTestSuite(cases)
    suite.layouts = ["utils::kWidthPacked", "utils::kChannelsPacked"]
    return suite
60
61
@register_test_suite("aten.mm.default")
def get_mm_inputs():
    """Build the ``aten.mm`` suite; each case is ``(mat1_shape, mat2_shape)``."""
    shape_pairs = [
        ((M1, L), (L, M2)),
        ((S1, S2), (S2, M)),
        ((6, 32), (32, 64)),
    ]
    suite = VkTestSuite(shape_pairs)
    suite.prepacked_args = ["mat2"]
    # ATen matmul doesn't support half, so only float32 is exercised.
    suite.dtypes = ["at::kFloat"]
    suite.storage_types = ["utils::kTexture3D", "utils::kBuffer"]
    suite.layouts = ["utils::kWidthPacked", "utils::kChannelsPacked"]
    return suite
80
81
@register_test_suite("aten.bmm.default")
def get_bmm_inputs():
    """Build the ``aten.bmm`` suite; each case is ``(batch1_shape, batch2_shape)``."""
    shape_pairs = [
        ((S, M1, L), (S, L, M2)),
        ((M, S1, S2), (M, S2, M)),
        ((4, 6, 32), (4, 32, 16)),
    ]
    suite = VkTestSuite(shape_pairs)
    suite.prepacked_args = ["mat2"]
    # ATen matmul doesn't support half, so only float32 is exercised.
    suite.dtypes = ["at::kFloat"]
    suite.layouts = ["utils::kWidthPacked", "utils::kChannelsPacked"]
    return suite
99
100
@register_test_suite("aten.addmm.default")
def get_addmm_inputs():
    """Build the ``aten.addmm`` suite.

    Each case is ``(self_shape, mat1_shape, mat2_shape[, s0, s1])``. The two
    optional trailing scalars presumably map to ``beta`` and ``alpha``. The
    ``self`` shapes include size-1 dims that broadcast against
    ``mat1 @ mat2``.
    """
    cases = [
        ((1, S), (S1, S), (S, S), 1.0, 1.5),
        ((S, 1), (S, S1), (S1, S1), 1.0, 1.0),
        ((M1, M2), (M1, M2), (M2, M2)),
        ((M1, M2), (M1, M2), (M2, M2), 4.2, 2.3),
        ((M1, 1), (M1, L), (L, L), 2.0, 3.0),
        # self given as a bare int (not a tuple); presumably expanded to a
        # 1-D shape by the test generator.
        (M2, (M1, M2), (M2, M2)),
        ((6, M2), (6, M2), (M2, M2)),
    ]
    suite = VkTestSuite(cases)
    # ATen matmul doesn't support half, so only float32 is exercised.
    suite.dtypes = ["at::kFloat"]
    suite.layouts = ["utils::kWidthPacked", "utils::kChannelsPacked"]
    return suite
121
122
# (M, K, N) size triples shared by the linear and int8-weight matmul suites:
# inputs end in (M, K) and weights are (N, K).
common_MKN_list = [(S2, M2, M1), (L, L, M1)]
127
128
def get_linear_texture_inputs():
    """Build the texture-storage suite for ``aten.linear``.

    Each case is ``(input_shape, weight_shape, bias)``. For every (M, K, N)
    triple in ``common_MKN_list``, cases are emitted group by group:
    - 2-D input without bias;
    - 2-D input with a bias of size N;
    - batched 3-D input without bias;
    - batched 3-D input with bias;
    - batched 3-D input with the row count fixed at 6.
    """
    triples = common_MKN_list
    cases = (
        [((m, k), (n, k), None) for m, k, n in triples]
        + [((m, k), (n, k), n) for m, k, n in triples]
        + [((3, m, k), (n, k), None) for m, k, n in triples]
        + [((3, m, k), (n, k), n) for m, k, n in triples]
        # Row count pinned to 6, so the triple's M is ignored here.
        + [((3, 6, k), (n, k), n) for _m, k, n in triples]
    )

    suite = VkTestSuite(cases)
    suite.dtypes = ["at::kFloat"]
    suite.layouts = ["utils::kWidthPacked", "utils::kChannelsPacked"]
    suite.test_name_suffix = "texture"
    return suite
146
147
def get_linear_buffer_inputs():
    """Build the buffer-storage suite for ``aten.linear``.

    Each case is ``(input_shape, weight_shape, bias)`` with no bias. The suite
    covers 2-D inputs followed by batched 3-D inputs for every triple in
    ``common_MKN_list``.
    """
    plain_2d = [((m, k), (n, k), None) for m, k, n in common_MKN_list]
    batched_3d = [((3, m, k), (n, k), None) for m, k, n in common_MKN_list]

    suite = VkTestSuite(plain_2d + batched_3d)
    suite.dtypes = ["at::kFloat"]
    suite.layouts = ["utils::kWidthPacked", "utils::kChannelsPacked"]
    suite.storage_types = ["utils::kBuffer"]
    suite.test_name_suffix = "buffer"
    return suite
163
164
@register_test_suite("aten.linear.default")
def get_linear_test_suites():
    """Register both the texture- and buffer-storage suites for ``aten.linear``."""
    texture_suite = get_linear_texture_inputs()
    buffer_suite = get_linear_buffer_inputs()
    return [texture_suite, buffer_suite]
168
169
@register_test_suite("aten._weight_int8pack_mm.default")
def get_weight_int8pack_mm_inputs():
    """Build the ``aten._weight_int8pack_mm`` suite.

    Each case is ``((M, K), (N, K), N)`` — the input, ``mat2`` and ``scales``
    shapes — for every triple in ``common_MKN_list``. ``mat2`` uses dtype
    ``at::kChar`` with data range (0, 100). ``scales`` uses data range
    (0.0008, 0.001).
    """
    cases = [((m, k), (n, k), n) for m, k, n in common_MKN_list]

    suite = VkTestSuite(cases)
    suite.dtypes = ["at::kFloat", "at::kHalf"]
    suite.layouts = ["utils::kWidthPacked"]
    suite.storage_types = ["utils::kTexture3D", "utils::kBuffer"]
    suite.prepacked_args = ["mat2", "scales"]
    suite.requires_prepack = True

    # Quantized weight matrix: int8 storage with small non-negative values.
    suite.arg_dtype["mat2"] = "at::kChar"
    suite.arg_data_range["mat2"] = (0, 100)

    suite.arg_data_range["scales"] = (0.0008, 0.001)

    return suite
189
190
@register_test_suite("aten.avg_pool2d.default")
def get_avg_pool2d_inputs():
    """Build the ``aten.avg_pool2d`` suite.

    A single 2x2, stride-1, unpadded pooling window over an (S, M1, M2) input
    is swept over every combination of ``ceil_mode``, ``count_include_pad``
    and ``divisor_override``.
    """
    Test = namedtuple(
        "VkAvgPoolTest",
        [
            "self",
            "kernel_size",
            "stride",
            "padding",
            "ceil_mode",
            "count_include_pad",
            "divisor_override",
        ],
    )

    # Loop order (ceil_mode outermost, divisor_override innermost) fixes the
    # order in which cases are emitted.
    test_cases = [
        Test(
            self=(S, M1, M2),
            kernel_size=[2, 2],
            stride=[1, 1],
            padding=[0, 0],
            ceil_mode=ceil,
            count_include_pad=include_pad,
            divisor_override=divisor,
        )
        for ceil in (True, False)
        for include_pad in (True, False)
        for divisor in (None, 5)
    ]

    suite = VkTestSuite([tuple(case) for case in test_cases])
    suite.dtypes = ["at::kFloat"]
    return suite
225
226
@register_test_suite("aten.max_pool2d_with_indices.default")
def get_max_pool2d_inputs():
    """Build the ``aten.max_pool2d_with_indices`` suite.

    The single case is ``(self, kernel_size, stride, padding, dilation)``:
    a 2x2 window, stride 1, no padding, no dilation.
    """
    pool_case = ((S, M1, M2), [2, 2], [1, 1], [0, 0], [1, 1])
    return VkTestSuite([pool_case])
235
236
@register_test_suite("aten.convolution.default")
def get_conv_inputs():
    """Build the ``aten.convolution`` suite.

    Each case is ``(input, weight, bias, stride, padding, dilation,
    transposed, output_padding, groups)``. The cases cover 2-D and 1-D
    convolutions, including transposed, depthwise, pointwise and grouped
    variants, with and without a bias.
    """
    conv2d_cases = [
        # Strided, padded 3x3
        ((1, 6, 40, 50), (8, 6, 3, 3), (8,), [1, 2], [2, 3], [1, 1], False, [0, 0], 1),
        # Transposed 3x3 with output padding
        ((1, 6, 40, 50), (6, 8, 3, 3), (8,), [1, 2], [2, 3], [1, 1], True, [0, 1], 1),
        # Depthwise 3x3 (groups == input channels)
        ((1, 8, 72, 96), (8, 1, 3, 3), (8,), [1, 1], [1, 1], [1, 1], False, [0, 0], 8),
        # Pointwise 1x1
        ((1, 8, 72, 96), (8, 8, 1, 1), (8,), [1, 1], [1, 1], [1, 1], False, [0, 0], 1),
        # Strided, padded 3x3 without a bias
        ((1, 6, 40, 50), (8, 6, 3, 3), None, [1, 2], [2, 3], [1, 1], False, [0, 0], 1),
    ]
    conv1d_cases = [
        # Depthwise
        ((1, 6, 7), (6, 1, 3), (6,), [1], [0], [1], False, [0], 6),
        # Grouped, with stride, padding and dilation
        ((2, 20, 30), (10, 4, 6), (10,), [5], [5], [3], False, [0], 5),
        # Depthwise without a bias
        ((1, 9, 11), (9, 1, 3), None, [1], [0], [1], False, [0], 9),
        # Grouped, with stride, padding and dilation, without a bias
        ((5, 15, 30), (20, 3, 3), None, [3], [5], [7], False, [0], 5),
    ]
    large_cases = [
        # Large pointwise 2-D conv
        ((1, 16, 672, 512), (64, 16, 1, 1), (64,), [1, 1], [0, 0], [1, 1], False, [0, 0], 1),
    ]
    return VkTestSuite(conv2d_cases + conv1d_cases + large_cases)
354
355
@register_test_suite("aten.native_layer_norm.default")
def get_native_layer_norm_inputs():
    """Build the ``aten.native_layer_norm`` suite.

    Each case is ``(input_shape, normalized_shape, weight, bias, eps)`` and
    normalizes over the innermost dim. Weight and bias shapes are given as
    bare ints (presumably expanded to 1-D shapes by the test generator).
    """
    eps = 0.001
    cases = [
        ((S1, S2), [S2], S2, S2, eps),
        ((M, M1, M2), [M2], M2, M2, eps),
        ((S, XL, M1, M2), [M2], M2, M2, eps),
    ]
    return VkTestSuite(cases)
366
367
@register_test_suite("aten.upsample_nearest2d.vec")
def get_upsample_inputs():
    """Build the ``aten.upsample_nearest2d.vec`` suite.

    Each case is ``(input_shape, output_size, scale_factors)``. Exactly one of
    ``output_size`` (an (H, W) pair) or ``scale_factors`` is given; the other
    is ``None``.
    """
    by_scale = [
        ((2, 2, 2, 2), None, [1, 1]),
        ((1, 1, 2, 2), None, [2, 2]),
        ((1, 1, 2, 2), None, [2, 4]),
        ((1, 1, 2, 2), None, [4, 2]),
    ]
    by_size = [
        ((1, 1, 2, 2), [2, 2], None),
        ((1, 1, 2, 2), [2, 4], None),
        ((1, 1, 2, 2), [3, 2], None),
    ]
    return VkTestSuite(by_scale + by_size)
383
384
@register_test_suite(["aten.full.default", "aten.full_like.default"])
def get_full_inputs():
    """Build the suite shared by ``aten.full`` and ``aten.full_like``.

    Each case is ``(shape, fill_value)``.
    """
    shape_and_fill = [
        ([S1, S2], 42.0),
        ([M, M1, M2], 3.14),
        ([L, M, M1, M2], 2.72),
    ]
    return VkTestSuite(shape_and_fill)
395
396
@register_test_suite(
    [
        "aten.zeros.default",
        "aten.zeros_like.default",
        "aten.ones.default",
        "aten.ones_like.default",
    ]
)
def get_ones_inputs():
    """Build the suite shared by zeros / zeros_like / ones / ones_like.

    Each case is a bare size list, not wrapped in an argument tuple
    (presumably VkTestSuite treats a non-tuple case as its single argument).
    """
    sizes = [
        [S1, S2],
        [M, M1, M2],
        [L, M, M1, M2],
    ]
    return VkTestSuite(sizes)
414
415
@register_test_suite(["aten.select.int", "aten.select_copy.int"])
def get_select_int_inputs():
    """Build the ``aten.select`` / ``aten.select_copy`` suite.

    Each case is ``(input_shape, dim, index)``. The cases include the first and
    last index along a dim, and inputs with size-1 trailing dims.
    """
    cases = [
        # 3-D inputs
        ((6, 2, 7), 0, 3),
        ((6, 2, 7), 1, 0),
        ((6, 2, 7), 2, 3),
        ((6, 10, 7), 0, 3),
        ((6, 10, 7), 1, 0),
        ((6, 10, 7), 1, 9),
        ((6, 10, 7), 2, 6),
        # 4-D inputs
        ((9, 2, 9, 4), 0, 8),
        ((9, 2, 9, 4), 1, 1),
        ((9, 2, 9, 4), 2, 0),
        ((9, 2, 9, 4), 2, 8),
        ((9, 2, 9, 4), 3, 3),
        # 4-D inputs with size-1 height and width
        ((8, 6, 1, 1), 0, 4),
        ((8, 6, 1, 1), 1, 4),
    ]
    return VkTestSuite(cases)
437
438
@register_test_suite(["aten.permute.default", "aten.permute_copy.default"])
def get_permute_inputs():
    """Build the ``aten.permute`` / ``aten.permute_copy`` suite.

    Each case is ``(input_shape, dims)``. A 4-D shape is run through a set of
    permutations, followed by a few 3-D and 2-D cases.
    """
    perms_4d = [
        [0, 1, 2, 3],
        [0, 1, 3, 2],
        [0, 2, 1, 3],
        [0, 2, 3, 1],
        [0, 3, 1, 2],
        [0, 3, 2, 1],
        [3, 0, 1, 2],
        [3, 2, 0, 1],
        [2, 3, 0, 1],
        [2, 0, 3, 1],
    ]
    cases = [((9, 2, 9, 4), perm) for perm in perms_4d]
    cases += [
        ((9, 2, 9), [2, 0, 1]),
        ((9, 2, 9), [1, 2, 0]),
        ((9, 2), [0, 1]),
        ((9, 2), [1, 0]),
    ]

    suite = VkTestSuite(cases)
    suite.layouts = ["utils::kChannelsPacked"]
    return suite
462
463
@register_test_suite("aten.view_copy.default")
def get_view_inputs():
    """Build the ``aten.view_copy`` suite.

    Each case is ``(input_shape, target_shape)``. A ``-1`` in the target is
    inferred from the element count.
    """
    reshapes = [
        # 3-D inputs
        ((3, 4, 5), [1, 1, -1]),
        ((3, 4, 5), [1, -1, 1]),
        ((3, 4, 5), [-1, 1, 1]),
        # 4-D inputs
        ((8, 7, 2, 3), [4, 3, 7, 4]),
        ((8, 7, 2, 3), [7, -1, 2, 1]),
        ((8, 7, 2, 3), [1, 1, 1, -1]),
        ((8, 7, 2, 3), [-1]),
        ((2, 3, 3, 7), [2, -1, 1, 1]),
        ((3, 5, 2, 7), [7, -1, 2, 1]),
        ((2, 2, 8, 6), [2, 6, -1, 1]),
        ((2, 2, 8, 6), [6, -1, 1]),
        # Prime-sized 4-D inputs
        ((S1, S2, S1, S2), [S2, -1, 1, S1]),
        ((S1, S2, S1, S2), [S1, 1, -1, S2]),
        ((S1, S2, S1, S2), [-1, 1, S1, S2]),
    ]

    suite = VkTestSuite(reshapes)
    suite.layouts = [
        "utils::kWidthPacked",
        "utils::kHeightPacked",
        "utils::kChannelsPacked",
    ]
    return suite
490
491
@register_test_suite("aten.slice_copy.Tensor")
def get_slice_out_inputs():
    """Build the copying ``aten.slice_copy`` suite.

    Each case is ``(self, dim, start, end, step)``. Fields omitted from a
    ``Test(...)`` default to ``dim=0``, ``start=None``, ``end=None`` and
    ``step=1``.
    """
    Test = namedtuple(
        "VkSliceTest",
        ["self", "dim", "start", "end", "step"],
        defaults=(None, 0, None, None, 1),
    )

    width_height_cases = [
        Test(self=[1, 1, 4, 10], dim=3, start=3),
        Test(self=[1, 1, 4, 10], dim=3, start=3, step=2),
        Test(self=[1, 1, 4, 10], dim=3, start=3, end=4, step=2),
        Test(self=[1, 1, 4, 10], dim=2, start=3),
        Test(self=[9, 9, 9, 9], dim=2, start=0, end=9, step=1),
        Test(self=[9, 9, 9, 9], dim=2, start=1, end=8, step=1),
        Test(self=[9, 9, 9, 9], dim=2, start=1, end=2, step=1),
        Test(self=[9, 9, 9, 9], dim=3, start=1, end=5, step=1),
        Test(self=[9, 9, 9, 9], dim=3, start=1, end=5, step=2),
        Test(self=[9, 9, 9, 9], dim=-1, start=1, end=5, step=2),
        Test(self=[9, 9, 9, 9], dim=-2, start=1, end=5, step=2),
        Test(self=[9, 9, 9], dim=1, start=2, step=1),
        Test(self=[9, 9, 9], dim=1, start=2, step=2),
        Test(self=[9, 9, 9], dim=2, start=2, step=1),
        Test(self=[9, 9, 9], dim=2, start=2, step=2),
        Test(self=[9, 9], dim=0, start=2, step=1),
        Test(self=[9, 9], dim=0, start=2, step=2),
        Test(self=[9, 9], dim=1, start=2, step=1),
        Test(self=[9, 9], dim=1, start=2, step=2),
    ]

    batch_cases = [
        Test(self=[6, 5, 3, 2], dim=0),
        Test(self=[6, 5, 3, 2], dim=0, step=2),
        Test(self=[13, 13, 3, 2], dim=0, step=2),
        Test(self=[13, 13, 3, 2], dim=0, start=1, step=2),
        Test(self=[13, 13, 3, 2], dim=0, start=1, step=5),
        Test(self=[13, 13, 3, 2], dim=0, start=1, step=20),
        Test(self=[13, 2, 3, 2], dim=0, start=1, step=2),
        Test(self=[13, 2, 3, 2], dim=0, start=1, step=5),
        Test(self=[13, 2, 3, 2], dim=0, start=1, step=20),
    ]

    channel_cases = [
        Test(self=[2, 5, 1, 10], dim=1),
        Test(self=[2, 5, 1, 10], dim=1, start=1),
        Test(self=[2, 5, 1, 10], dim=1, start=1, step=2),
        Test(self=[5, 13, 1, 10], dim=1),
        Test(self=[5, 13, 1, 10], dim=1, start=1),
        Test(self=[5, 13, 1, 10], dim=1, start=1, step=2),
        Test(self=[5, 13, 1, 10], dim=1, start=1, step=5),
        Test(self=[5, 13, 1, 10], dim=1, start=1, step=20),
        Test(self=[13, 1, 10], dim=0),
        Test(self=[13, 1, 10], dim=0, start=1),
        Test(self=[13, 1, 10], dim=0, start=1, step=2),
        Test(self=[13, 1, 10], dim=0, start=1, step=5),
        Test(self=[13, 1, 10], dim=0, start=1, step=20),
    ]

    # INT64_MAX as ``end`` stands for an open-ended slice (``arr[:]``).
    INT64_MAX = 2**63 - 1
    # Negative / open-ended bounds, swept per dim as
    # (start=-2) -> (end=-2) -> (end=INT64_MAX), each with steps 1 and 2.
    negative_index_cases = [
        Test(self=[8, 9], dim=dim, start=start, end=end, step=step)
        for dim in (0, 1)
        for start, end in ((-2, None), (None, -2), (None, INT64_MAX))
        for step in (1, 2)
    ]

    all_cases = width_height_cases + batch_cases + channel_cases + negative_index_cases
    suite = VkTestSuite([tuple(case) for case in all_cases])

    suite.dtypes = ["at::kFloat", "at::kHalf"]
    suite.layouts = ["utils::kChannelsPacked"]
    suite.data_gen = "make_seq_tensor"
    return suite
573
574
def get_slice_view_inputs():
    """Build the view-based ``aten.slice`` suite.

    Every case slices along the channel dim starting at 0. Each case is
    ``(self, dim, start, end, step)``, with ``step`` defaulting to 1.
    """
    Test = namedtuple(
        "VkSliceTest",
        ["self", "dim", "start", "end", "step"],
        defaults=(None, 0, None, None, 1),
    )

    channel_cases = [
        Test(self=[1, 17, 1, 10], dim=1, start=0, end=4),
        Test(self=[1, 17, 1, 10], dim=1, start=0, end=8),
        Test(self=[1, 17, 3, 7], dim=1, start=0, end=12),
    ]

    suite = VkTestSuite([tuple(case) for case in channel_cases])
    suite.dtypes = ["at::kFloat"]
    suite.storage_types = ["utils::kBuffer", "utils::kTexture3D"]
    suite.layouts = ["utils::kWidthPacked"]
    suite.data_gen = "make_seq_tensor"
    suite.is_view_op = True
    return suite
595
596
@register_test_suite(["aten.slice.Tensor"])
def get_slice_inputs():
    """Register the copying ("no_view") and view variants of ``aten.slice``.

    The view suite is listed first in the returned list.
    """
    copy_suite = get_slice_out_inputs()
    copy_suite.test_name_suffix = "no_view"

    view_suite = get_slice_view_inputs()
    view_suite.test_name_suffix = "view"

    return [view_suite, copy_suite]
606
607
@register_test_suite(["aten.transpose.int"])
def get_transpose_inputs():
    """Build the view-based ``aten.transpose.int`` suite.

    Each case is ``(self, dim0, dim1)`` over 2-D, 3-D and 4-D inputs.
    """
    Test = namedtuple(
        "VkTransposeViewTest", ["self", "dim0", "dim1"], defaults=(None, 0, 1)
    )

    test_cases = [
        Test(self=[M1, M2], dim0=0, dim1=1),
        Test(self=[M1, S2, M], dim0=0, dim1=1),
        Test(self=[M1, S2, M], dim0=0, dim1=2),
        Test(self=[M1, S2, M], dim0=2, dim1=1),
        Test(self=[S, M, S2, M2], dim0=3, dim1=2),
        Test(self=[S, M, S2, M2], dim0=1, dim1=2),
        Test(self=[S, M, S2, M2], dim0=3, dim1=1),
    ]

    suite = VkTestSuite([tuple(case) for case in test_cases])
    suite.dtypes = ["at::kFloat"]
    suite.storage_types = ["utils::kBuffer", "utils::kTexture3D"]
    suite.layouts = ["utils::kWidthPacked", "utils::kChannelsPacked"]
    suite.data_gen = "make_seq_tensor"
    suite.is_view_op = True
    return suite
631
632
@register_test_suite("aten.index_select.default")
def get_index_select_inputs():
    """Build the ``aten.index_select`` suite.

    Every index pattern is applied along each of the four dims of a
    [9, 9, 9, 9] input. The patterns include repeated and out-of-order
    indices. Each case is ``(self, dim, index)``.
    """
    Test = namedtuple(
        "VkIndexSelectTest", ["self", "dim", "index"], defaults=(None, 0, None)
    )

    index_patterns = (
        (0,),
        (2,),
        (0, 2),
        (3, 1),
        (5, 5),
        (2, 3, 4, 5, 7),
    )
    # Dim is the outer loop: all patterns for dim 0, then dim 1, and so on.
    test_cases = [
        Test(self=[9, 9, 9, 9], dim=dim, index=list(pattern))
        for dim in range(4)
        for pattern in index_patterns
    ]

    suite = VkTestSuite([tuple(case) for case in test_cases])
    suite.dtypes = ["at::kFloat"]
    suite.layouts = ["utils::kChannelsPacked"]
    return suite
655
656
@register_test_suite("aten.embedding.default")
def get_embedding_inputs():
    """Build the ``aten.embedding`` suite.

    Each case is ``(weight_shape, indices, -1, "false", "false")``, using
    1-D, 2-D and 3-D index tensors. The trailing values are presumably
    ``padding_idx``, ``scale_grad_by_freq`` and ``sparse``.
    """
    Test = namedtuple("VkEmbeddingTest", ["weight", "indices"], defaults=(None, None))

    test_cases = [
        Test(weight=[10, 9], indices=[0, 2]),
        Test(weight=[10, 9], indices=[2, 3, 4, 5, 7]),
        Test(weight=[10, 9], indices=[[0, 2], [1, 4], [7, 7]]),
        Test(weight=[10, 9], indices=[[1, 2, 3], [1, 2, 3], [1, 2, 3], [1, 2, 3]]),
        Test(weight=[10, 9], indices=[[[3, 1, 4], [1, 5, 9]], [[2, 6, 5], [3, 5, 8]]]),
    ]

    # Extra positional arguments appended to every case.
    trailing_args = (-1, "false", "false")
    suite = VkTestSuite([(*case, *trailing_args) for case in test_cases])

    suite.dtypes = ["at::kFloat"]
    suite.layouts = ["utils::kChannelsPacked"]
    return suite
675
676
@register_test_suite("aten.unsqueeze_copy.default")
def get_unsqueeze_inputs():
    """Build the ``aten.unsqueeze_copy`` suite.

    Each case is ``(input_shape, dim)``. After one leading (2, 3, 4) case,
    every valid insertion position ``0..rank`` is tried for each remaining
    shape.
    """
    swept_shapes = ((1, 1, 1), (9, 9, 9), (9, 9), (9,))
    cases = [((2, 3, 4), 0)]
    cases += [
        (shape, dim) for shape in swept_shapes for dim in range(len(shape) + 1)
    ]

    suite = VkTestSuite(cases)
    suite.layouts = ["utils::kChannelsPacked"]
    suite.data_gen = "make_seq_tensor"
    return suite
702
703
@register_test_suite("aten.clone.default")
def get_clone_inputs():
    """Build the ``aten.clone`` suite: 1-D to 4-D inputs, one shape argument per case."""
    shapes = [
        (S2, S1, S2, S1),
        (S2, S1, S2),
        (S2, S1),
        (S2,),
        (XS, S1, XS, S1),
        (XS, S1, XS),
        (S1, XS, S1),
        (XS, S1),
        (S1, XS),
        (S1,),
        (XS,),
    ]
    # Each case is a 1-tuple holding only the input shape.
    suite = VkTestSuite([(shape,) for shape in shapes])
    suite.layouts = ["utils::kChannelsPacked"]
    suite.data_gen = "make_seq_tensor"
    return suite
726
727
@register_test_suite("aten.repeat.default")
def get_repeat_inputs():
    """Build the ``aten.repeat`` suites for Texture2D and Texture3D storage.

    Each case is ``(input_shape, repeats)``. ``repeats`` may have more entries
    than the input has dims, which expands the output's rank.
    """

    def _make_suite(cases, storage_type, suffix):
        # The two suites differ only in cases, storage type and name suffix.
        suite = VkTestSuite(cases)
        suite.layouts = ["utils::kChannelsPacked"]
        suite.storage_types = [storage_type]
        suite.data_gen = "make_seq_tensor"
        suite.dtypes = ["at::kFloat"]
        suite.test_name_suffix = suffix
        return suite

    cases_2d = [
        ((2, 3), [1, 4]),
        ((2, 3), [4, 1]),
        ((2, 3), [4, 4]),
        ((2, 3), [3, 1, 4]),
    ]
    suite_2d = _make_suite(cases_2d, "utils::kTexture2D", "2d")

    cases_3d = [
        # Repeat channels only (most challenging case)
        ((3, XS, S), [2, 1, 1]),
        ((7, XS, S), [4, 1, 1]),
        ((1, 7, XS, S), [1, 4, 1, 1]),
        ((3, 7, XS, S), [1, 4, 1, 1]),
        # Repeat channels together with other dims
        ((1, 7, XS, S), [1, 4, 1, 3]),
        ((3, 7, XS, S), [1, 4, 1, 3]),
        ((3, 7, XS, S), [1, 4, 3, 1]),
        ((3, 7, XS, S), [1, 4, 3, 3]),
        # Repeat batch
        ((3, 7, XS, S), [3, 4, 3, 3]),
        ((3, 7, XS, S), [3, 1, 3, 3]),
        # Other cases
        ((3, 7, 1, 1), [1, 4, 1, 1]),
        ((2, 3), [1, 4]),
        ((2, 3), [4, 1]),
        ((2, 3), [4, 4]),
        ((S1, S2, S2), [1, 3, 1]),
        ((S1, S2, S2), [1, 3, 3]),
        ((S1, S2, S2), [3, 3, 1]),
        ((S1, S2, S2), [3, 3, 3]),
        ((S1, S2, S2, S2), [1, 1, 3, 1]),
        ((S1, S2, S2, S2), [1, 1, 1, 3]),
        ((S1, S2, S2, S2), [1, 1, 3, 3]),
        ((S1, S2, S2, S2), [1, 3, 1, 3]),
        ((S1, S2, S2, S2), [3, 3, 3, 3]),
        ((S1, S2, S2, S2), [3, 3, 1, 1]),
        # Rank-expanding cases
        ((2, 3), [3, 1, 4]),
        ((2, 3), [3, 3, 2, 4]),
    ]
    suite_3d = _make_suite(cases_3d, "utils::kTexture3D", "3d")

    return [suite_2d, suite_3d]
786
787
@register_test_suite("aten.repeat_interleave.self_int")
def get_repeat_interleave_inputs():
    """Build the ``aten.repeat_interleave.self_int`` suites.

    Each case is ``(input_shape, repeats, dim)``. One suite uses width-packed
    inputs and one uses channels-packed inputs; each repeats along dims other
    than its packed dim, with and without a batch dim.
    """

    def _make_suite(cases, layout, suffix):
        # The two suites differ only in cases, memory layout and name suffix.
        suite = VkTestSuite(cases)
        suite.layouts = [layout]
        suite.data_gen = "make_seq_tensor"
        suite.dtypes = ["at::kFloat"]
        suite.test_name_suffix = suffix
        return suite

    width_packed_cases = [
        ((4, 32, 256), 3, -2),
        # Repeat on each non-packed dim
        ((16, 32, 64), 5, -2),
        ((16, 32, 64), 5, -3),
        # Batched inputs
        ((3, 5, 32, 64), 4, -2),
        ((3, 5, 32, 64), 4, -3),
    ]
    suite_w = _make_suite(width_packed_cases, "utils::kWidthPacked", "W_packed")

    channels_packed_cases = [
        # Repeat on each non-packed dim
        ((32, 32, 16), 5, -1),
        ((32, 32, 16), 5, -2),
        # Batched inputs
        ((3, 16, 8, 64), 4, -1),
        ((3, 16, 8, 64), 4, -2),
    ]
    suite_c = _make_suite(channels_packed_cases, "utils::kChannelsPacked", "C_packed")

    return [suite_w, suite_c]
826
827
@register_test_suite("aten.cat.default")
def get_cat_inputs():
    """Build the ``aten.cat`` suite.

    Each case is ``(tensor_shapes, dim)``. The TensorList argument must be a
    list of shape tuples. Cases include zero-sized inputs and single-tensor
    lists.
    """
    height_cases = [
        ([(S1, S1, 3, 5), (S1, S1, 0, 5)], 2),
        ([(S1, S1, 3, 5), (S1, S1, 4, 5)], 2),
        ([(S1, 3, 5), (S1, 4, 5)], 1),
        ([(3, 5), (4, 5)], 0),
        ([(3, 5), (4, 5), (1, 5)], 0),
        ([(3, 5)], 0),
    ]
    width_cases = [
        ([(S1, S1, 5, 3), (S1, S1, 5, 4)], 3),
        ([(S1, 5, 3), (S1, 5, 4)], 2),
        ([(5, 0), (5, 4)], 1),
        ([(5, 3), (5, 4)], 1),
        ([(5, 3), (5, 4), (5, 1)], 1),
        ([(5, 4)], 1),
        ([(5,), (6,)], 0),
    ]
    batch_cases = [
        ([(S, S1, 5, 4), (S1, S1, 5, 4)], 0),
        ([(S, XS, 5, 4), (S1, XS, 5, 4)], 0),
        ([(S, S2, 5, 4), (S1, S2, 5, 4)], 0),
        ([(3, 1, 2, 5)] * 3, 0),
    ]
    channel_cases = [
        ([(S, 5, 4), (0, 5, 4), (S2, 5, 4)], 0),
        ([(S, 5, 4), (S1, 5, 4), (S2, 5, 4)], 0),
        ([(XS, 5, 4), (XS, 5, 4), (S2, 5, 4)], 0),
        ([(XS, S, 5, 4), (XS, S1, 5, 4), (XS, S2, 5, 4)], 1),
        ([(XS, XS, 5, 4), (XS, XS, 5, 4), (XS, S2, 5, 4)], 1),
        ([(XS, 1, 2, 5)] * 3, 1),
    ]

    suite = VkTestSuite(height_cases + width_cases + batch_cases + channel_cases)
    suite.layouts = ["utils::kChannelsPacked"]
    suite.data_gen = "make_seq_tensor"
    suite.dtypes = ["at::kFloat"]
    return suite
888
889
@register_test_suite("aten.split_with_sizes_copy.default")
def get_split_with_sizes_inputs():
    """Build the ``aten.split_with_sizes_copy`` suite.

    Each case is ``(self, split_sizes, dim)``. The cases split along width,
    height, batch and channel.

    ``split_with_sizes`` requires ``split_sizes`` to sum to the size of the
    split dim. Every case is checked here so that a malformed case fails
    while the suite is being defined rather than inside the generated test.

    Raises:
        ValueError: if a case's ``split_sizes`` do not sum to ``self[dim]``.
    """
    Test = namedtuple("VkSplitWithSizesTest", ["self", "sizes", "dim"])
    test_cases = [
        # Split on Width
        Test(self=(S1, 7, 10, 10), sizes=[1, 2, 3, 4], dim=3),
        Test(self=(7, 10, 10), sizes=[1, 2, 3, 4], dim=2),
        Test(self=(7, 10, 10), sizes=[1, 9], dim=2),
        Test(self=(10, 10), sizes=[1, 9], dim=1),
        Test(self=(10,), sizes=[1, 9], dim=0),
        # Split on Height
        Test(self=(S1, 7, 10, 10), sizes=[1, 2, 3, 4], dim=2),
        Test(self=(7, 10, 10), sizes=[1, 2, 3, 4], dim=1),
        Test(self=(7, 10, 10), sizes=[10], dim=1),
        Test(self=(7, 6, 10), sizes=[1, 1, 1, 1, 1, 1], dim=1),
        Test(self=(10, 10), sizes=[1, 2, 3, 4], dim=0),
        # Split on Batch
        Test(self=(10, 7, 10, 10), sizes=[3, 6, 1], dim=0),
        Test(self=(10, 7, 10, 10), sizes=[10], dim=0),
        # Split on Channel
        Test(self=(7, 13, 4, 8), sizes=[3, 6, 1, 3], dim=1),
        Test(self=(7, 13, 4, 8), sizes=[3, 3, 3, 3, 1], dim=1),
        Test(self=(13, 4, 8), sizes=[3, 3, 3, 3, 1], dim=0),
        Test(self=(13, 4, 8), sizes=[2, 9, 2], dim=0),
        Test(self=(13, 4, 8), sizes=[13], dim=0),
    ]

    for case in test_cases:
        extent = case.self[case.dim]
        total = sum(case.sizes)
        if total != extent:
            raise ValueError(
                f"split_with_sizes case {tuple(case)}: sizes sum to {total}, "
                f"but dim {case.dim} has size {extent}"
            )

    test_suite = VkTestSuite([tuple(tc) for tc in test_cases])

    test_suite.layouts = [
        "utils::kChannelsPacked",
    ]
    test_suite.data_gen = "make_seq_tensor"
    test_suite.dtypes = ["at::kFloat"]
    return test_suite
924
925
@register_test_suite("aten.split.Tensor")
def get_split_tensor_inputs():
    """Build the ``aten.split.Tensor`` suite.

    Each case is ``(input_shape, split_size, dim)``. For every (shape, dim)
    pair, split sizes are swept from one chunk covering the whole dim down to
    size-1 chunks.
    """
    # (shape, dim to split, split sizes to try), in emission order.
    sweeps = [
        # Split on width
        ((S1, 7, 10, 12), 3, (12, 3, 1)),
        ((7, 10, 12), 2, (12, 3, 1)),
        ((10, 12), 1, (12, 3, 1)),
        ((12,), 0, (12, 3, 1)),
        # Split on height
        ((S1, 7, 12, 8), 2, (12, 3, 1)),
        ((7, 12, 8), 1, (12, 3, 1)),
        ((12, 8), 0, (12, 3, 1)),
        # Split on batch
        ((12, 7, 10, 10), 0, (12, 3, 1)),
        # Split on channel
        ((7, 15, 10, 10), 1, (15, 5, 3, 1)),
        ((15, 10, 10), 0, (15, 5, 3, 1)),
    ]
    cases = [
        (shape, split_size, dim)
        for shape, dim, split_sizes in sweeps
        for split_size in split_sizes
    ]

    suite = VkTestSuite(cases)
    suite.layouts = ["utils::kChannelsPacked"]
    suite.data_gen = "make_seq_tensor"
    suite.dtypes = ["at::kFloat"]
    return suite
975
976
def get_reduce_inputs(is_softmax: bool = False):
    """Return (shape, dim, flag) cases shared by softmax and reduction suites.

    Args:
        is_softmax: When truthy, the trailing boolean of every case is False;
            otherwise it is True. (Presumably this is ``half_to_float`` for
            softmax and ``keepdim`` for the reductions -- TODO confirm
            against the op schemas.)

    Returns:
        A list of 3-tuples covering 1-D through 4-D inputs with both positive
        and negative dims.
    """
    flag = not is_softmax

    # NOTE: the 1-D entries give the shape as a bare int ``L``, not a 1-tuple.
    shape_and_dims = [
        (L, 0),
        (L, -1),
        ((M, L), 0),
        ((M, L), 1),
        ((L, M), -1),
        ((M, L), -2),
    ]
    shape_and_dims += [((S, S1, S2), d) for d in (0, 1, 2, -1, -2, -3)]
    shape_and_dims += [((1, S, S1, S2), d) for d in (1, 2, 3, -1, -2, -3)]
    # Test batches > 1 where the reduction dim is not the concat dim
    shape_and_dims.append(((S, S2, S1, 128), -1))

    return [(shape, dim, flag) for shape, dim in shape_and_dims]
1001
1002
@register_test_suite(["aten._softmax.default", "aten._log_softmax.default"])
def get_softmax_inputs():
    """Build the suite shared by softmax and log_softmax.

    Cases come from ``get_reduce_inputs(is_softmax=True)``, so the trailing
    boolean is False. Both width- and channel-packed layouts are exercised.
    """
    suite = VkTestSuite(get_reduce_inputs(is_softmax=True))
    suite.layouts = ["utils::kWidthPacked", "utils::kChannelsPacked"]
    return suite
1011
1012
@register_test_suite(
    ["aten.amax.default", "aten.amin.default", "aten.sum.dim_IntList", "aten.mean.dim"]
)
def get_reduce_op_inputs():
    """Build the suite shared by amax, amin, sum.dim_IntList and mean.dim.

    Cases come from ``get_reduce_inputs()`` with its default arguments, so the
    trailing boolean is True. Channel- and width-packed layouts are exercised.
    """
    suite = VkTestSuite(get_reduce_inputs())
    suite.layouts = ["utils::kChannelsPacked", "utils::kWidthPacked"]
    return suite
1023
1024
@register_test_suite(
    [
        "aten.sqrt.default",
        "aten.rsqrt.default",
        "aten.exp.default",
        "aten.hardshrink.default",
        "aten.sin.default",
        "aten.neg.default",
        "aten.cos.default",
        "aten.hardswish.default",
        "aten.hardsigmoid.default",
    ]
)
def get_unary_ops_inputs():
    """Build the suite shared by the elementwise unary ops listed above.

    Input ranks 1 through 4 are run on both texture and buffer storage. The
    absolute and relative tolerances are both 1e-4.
    """
    shapes = [(M1,), (M1, M2), (S1, M1, M2), (S1, S2, S2, M2)]
    suite = VkTestSuite(shapes)
    suite.storage_types = ["utils::kTexture3D", "utils::kBuffer"]
    suite.atol = suite.rtol = "1e-4"
    return suite
1051
1052
@register_test_suite("aten._native_batch_norm_legit_no_training.default")
def get_native_batch_norm_inputs():
    """Build cases for the no-training (inference) batch norm op.

    Each case supplies:
    - the input shape;
    - four 1-D per-channel argument shapes (weight, bias, mean, var) whose
      length equals ``self[1]``;
    - a momentum of 0.0;
    - an ``eps``.

    The per-channel arguments are marked for prepacking.
    """
    # Typename matches this op; it was a copy-pasted "VkSliceTest".
    Test = namedtuple(
        "VkBatchNormTest", ["self", "weight", "bias", "mean", "var", "momentum", "eps"]
    )

    def make_case(self_shape, eps=0.001):
        # weight/bias/mean/var are all 1-D with one entry per channel (dim 1).
        per_channel = (self_shape[1],)
        return Test(
            self=self_shape,
            weight=per_channel,
            bias=per_channel,
            mean=per_channel,
            var=per_channel,
            momentum=0.0,
            eps=eps,
        )

    test_cases = [
        make_case((1, 1, 2, 5)),
        make_case((S2, 1, 2, 5)),
        make_case((1, S2, 2, 5)),
        make_case((9, S1, 2, 5), eps=0.01),
        make_case((3, S1, 2, 5)),
        make_case((3, S2, 2, 5)),
        # eps == 0 edge case.
        make_case((3, S2, 2, 5), eps=0.0),
    ]

    test_suite = VkTestSuite(test_cases)
    test_suite.requires_prepack = True
    test_suite.prepacked_args = ["weight", "bias", "mean", "var"]

    return test_suite
1130
1131
@register_test_suite("aten.gelu.default")
def get_gelu_inputs():
    """Build gelu cases as (input shape, "tanh") pairs for ranks 1 through 4.

    The rank-1 shape is given as a bare int (``M1``), not a 1-tuple. The
    "tanh" string is presumably gelu's ``approximate`` mode -- TODO confirm.
    """
    shapes = [M1, (M1, M2), (S1, M1, M2), (S1, S2, S2, M2)]
    return VkTestSuite([(shape, "tanh") for shape in shapes])
1143
1144
@register_test_suite("aten.arange.start_step")
def get_arange_inputs():
    """Build arange.start_step cases as (start, end[, step]) tuples.

    The int vs float literals are deliberate: they mix integer and floating
    point bounds, including descending ranges with negative steps.
    """
    # Cases that omit the step.
    without_step = [
        (1, 13),
        (1.0, 11),
        (-13, 3),
        (-11.0, 2),
    ]
    # Cases with an explicit (possibly negative) step.
    with_step = [
        (3, 15, 3),
        (3, 23, 2),
        (3, 23.0, 4),
        (13, 1, -1),
        (-3, -13, -2),
        (13, -2.0, -4),
    ]

    suite = VkTestSuite(without_step + with_step)
    suite.layouts = ["utils::kChannelsPacked"]
    return suite
1166
1167
@register_test_suite("aten.constant_pad_nd.default")
def get_constant_pad_nd_inputs():
    """Build constant_pad_nd cases as (input shape, pad list, fill value).

    Pad lists with one, two and three (before, after) pairs are covered.
    """
    base_cases = [
        ([S1, S2], [1, 1], 24.0),
        ([M, M1, M2], [2, 2], 23.2),
        ([L, M, M1, M2], [3, 5], 12.2),
    ]
    # One pad pair, then the same pair repeated twice. Fresh lists per case.
    cases = [
        (list(shape), pad * repeats, value)
        for repeats in (1, 2)
        for shape, pad, value in base_cases
    ]
    # Three pad pairs.
    cases.append(([M, M1, M2], [1, 2, 3, 4, 5, 6], 23.2))
    cases.append(([L, M, M1, M2], [3, 3, 3, 3, 3, 3], 12.2))
    return VkTestSuite(cases)
1183
1184
@register_test_suite("aten.minimum.default")
def get_minimum_inputs():
    """Build minimum cases as (self shape, other shape) pairs.

    Equal shapes are mixed with pairs whose ranks differ (presumably to
    exercise broadcasting). ``M2`` in the first pair is a bare int, not a
    1-tuple.
    """
    self_shapes = [
        (M1, M2),
        (M1, M2),
        (M1, M2, M),
        (M1, M1, S1, S2),
        (S1, S1, S2, S),
        (M1, S1, S2),
        (S1, S2),
    ]
    other_shapes = [
        M2,
        (M1, M2),
        (M2, M),
        (M1, M1, S1, S2),
        (S1, S2, S),
        (L, M1, S1, S2),
        (L, M1, S1, S2),
    ]
    return VkTestSuite(list(zip(self_shapes, other_shapes)))
1199
1200
@register_test_suite("aten.squeeze_copy.dims")
def get_squeeze_copy_dim_inputs():
    """Build squeeze_copy.dims cases as (input shape, dim or dims) pairs."""
    # 4-D inputs with size-1 dims, paired with the dim(s) to squeeze.
    cases = [
        ([S, S, S, 1], 3),
        ([S, 1, S, S], 1),
        ([S, 1, 1, S], [1, 2]),
        ([1, S, S, S], 0),
    ]
    # Squeezing dims whose size is S (not 1).
    cases.extend(([S] * 4, dim) for dim in (3, 2, 1))
    # 3-D inputs with a single size-1 dim.
    cases.extend([([M, M1, 1], 2), ([M, 1, M1], 1), ([1, M1, M1], 0)])
    return VkTestSuite(cases)
1218
1219
@register_test_suite("aten.flip.default")
def get_flip_inputs():
    """Build flip cases as (input shape, dims to flip) tuples.

    Every input dim has size 9. Ranks 1 through 4 are flipped along various
    subsets of their dims.
    """
    Flip = namedtuple("Flip", ["self", "dim"], defaults=(None, 0))

    # (rank of the all-9 input, dims to flip)
    rank_and_dims = [
        (1, [0]),
        (2, [0, 1]),
        (3, [0, 2]),
        (3, [0, 1, 2]),
        (4, [0]),
        (4, [0, 2, 3]),
        (4, [1, 3]),
        (4, [0, 1, 2, 3]),
    ]
    cases = [Flip(self=[9] * rank, dim=dims) for rank, dims in rank_and_dims]
    return VkTestSuite([tuple(case) for case in cases])
1238