# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import os
import re
import sys

from abc import ABC, abstractmethod

from enum import Enum

import torch

17"""
18A helper library to generate test cases for ET kernels.
19
20It simplifies the steps to generate a new c++ test case. User just need
21to specify the inputs and we use pytorch kernel to calculate the result.
22"""


# Seed the RNG in all the common libraries for test reproducibility
torch.manual_seed(0)


def make_out_static_shape(tensor: torch.Tensor):
    sizes = list(tensor.size())
    sizes = [str(s) for s in sizes]
    sizes_str = "{" + ", ".join(sizes) + "}"
    return sizes_str


def make_out_dynamic_shape_bound_shape_same(tensor: torch.Tensor):
    sizes = list(tensor.size())
    sizes = [str(s) for s in sizes]
    sizes_str = "{" + ", ".join(sizes) + "}"
    return sizes_str + ", torch::executor::TensorShapeDynamism::DYNAMIC_BOUND"


def make_out_dynamic_shape_bound_shape_larger(tensor: torch.Tensor):
    sizes = list(tensor.size())
    extra_sizes = [x * 2 for x in sizes]
    extra_sizes = [str(s) for s in extra_sizes]
    sizes_str = "{" + ", ".join(extra_sizes) + "}"
    return sizes_str + ", torch::executor::TensorShapeDynamism::DYNAMIC_BOUND"


def make_out_dynamic_shape_unbound_shape(tensor: torch.Tensor):
    sizes = list(tensor.size())
    smaller_sizes = ["1"] * len(sizes)
    sizes_str = "{" + ", ".join(smaller_sizes) + "}"
    return sizes_str + ", torch::executor::TensorShapeDynamism::DYNAMIC_UNBOUND"
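

# Illustrative outputs of the helpers above for a (3, 2) input tensor
# (derived directly from the string-building code; "..." elides the
# torch::executor::TensorShapeDynamism prefix):
#   make_out_static_shape                     -> "{3, 2}"
#   make_out_dynamic_shape_bound_shape_same   -> "{3, 2}, ...DYNAMIC_BOUND"
#   make_out_dynamic_shape_bound_shape_larger -> "{6, 4}, ...DYNAMIC_BOUND"
#   make_out_dynamic_shape_unbound_shape      -> "{1, 1}, ...DYNAMIC_UNBOUND"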


class ShapeDynamism(Enum):
    # Static shape; shape is determined from pytorch output
    STATIC = 1
    # Dynamic bound with same size; shape is determined from pytorch output using the same size as static
    DYNAMIC_BOUND_SAME_SHAPE = 2
    # Dynamic bound with a larger size to test functionality; shape is determined from pytorch output
    DYNAMIC_BOUND_LARGER_SHAPE = 3
    # Dynamic unbound with a smaller size to test functionality
    DYNAMIC_UNBOUND = 4


out_dynamic_shape_fn_map = {
    ShapeDynamism.STATIC: make_out_static_shape,
    ShapeDynamism.DYNAMIC_BOUND_SAME_SHAPE: make_out_dynamic_shape_bound_shape_same,
    ShapeDynamism.DYNAMIC_BOUND_LARGER_SHAPE: make_out_dynamic_shape_bound_shape_larger,
    ShapeDynamism.DYNAMIC_UNBOUND: make_out_dynamic_shape_unbound_shape,
}
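

# gen_test_cases() below dispatches through this map, e.g. (illustrative):
#   out_dynamic_shape_fn_map[ShapeDynamism.STATIC](torch.rand(3, 2))
#   -> "{3, 2}"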


def make_test_cases_dynamic_shape(*args):
    """
    A helper to make a list of tuples (test cases). Each tuple contains:
    - the test case name,
    - the inputs and expected output (expanded from *args),
    - the dynamic shape type.
    """
    return [
        (
            "DynamicShapeUpperBoundSameAsExpected",
            *args,
            ShapeDynamism.DYNAMIC_BOUND_SAME_SHAPE,
        ),
        (
            "DynamicShapeUpperBoundLargerThanExpected",
            *args,
            ShapeDynamism.DYNAMIC_BOUND_LARGER_SHAPE,
        ),
        (
            "DynamicShapeUnbound",
            *args,
            ShapeDynamism.DYNAMIC_UNBOUND,
        ),
    ]
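

# Illustrative usage (mirrors gen_test_case_op_log below):
#   x = torch.rand(3, 2)
#   make_test_cases_dynamic_shape(x, torch.log(x))
# yields three ("<CaseName>", x, torch.log(x), <ShapeDynamism>) tuples, one
# per dynamic-shape variant.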


def make_test_cases_broadcast_two_input_tensor(x, y, cpp_args, torch_args, torch_fn):
    """
    A helper to make a list of tuples (test cases). Each tuple contains:
    - the test case name,
    - the inputs and expected output,
    - the dynamic shape type (static here).

    Used when we have two input tensors (like add, mul, div).
    Generates test cases where:
    - the first dimension of the first/second tensor is set to one,
    - the first dimension is dropped from the first/second tensor.
    """
    x_first_dim_1 = x[0:1]  # first dim set to 1, e.g. (3, 2) -> (1, 2)
    x_remove_dim = x_first_dim_1.squeeze(0)  # first dim dropped, e.g. (3, 2) -> (2,)
    y_first_dim_1 = y[0:1]
    y_remove_dim = y_first_dim_1.squeeze(0)

    return [
        (
            "BroadcastDimSizeIsOneAB",
            x_first_dim_1,
            y,
            *cpp_args,
            torch_fn(x_first_dim_1, y, *torch_args),
            ShapeDynamism.STATIC,
        ),
        (
            "BroadcastDimSizeMissingAB",
            x_remove_dim,
            y,
            *cpp_args,
            torch_fn(x_remove_dim, y, *torch_args),
            ShapeDynamism.STATIC,
        ),
        (
            "BroadcastDimSizeIsOneBA",
            x,
            y_first_dim_1,
            *cpp_args,
            torch_fn(x, y_first_dim_1, *torch_args),
            ShapeDynamism.STATIC,
        ),
        (
            "BroadcastDimSizeMissingBA",
            x,
            y_remove_dim,
            *cpp_args,
            torch_fn(x, y_remove_dim, *torch_args),
            ShapeDynamism.STATIC,
        ),
    ]
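

# For x = y = torch.rand(3, 2) (as in gen_test_case_op_add below), the cases
# above pair a (1, 2) or (2,) variant of one operand against the full (3, 2)
# other operand, so both the PyTorch reference and the ET kernel broadcast.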


class ArgType(ABC):
    """
    Represents an argument for generated C++ code and for the pytorch call
    """

    @abstractmethod
    def to_pytorch(self):
        return None

    @abstractmethod
    def to_cpp(self) -> str:
        return ""


class Scalar(ArgType):
    def __init__(self, val):
        self.val = val

    def to_pytorch(self):
        return self.val

    def to_cpp(self):
        return f"Scalar({self.val})"


class OptScalar(ArgType):
    def __init__(self, val):
        self.val = val

    def to_pytorch(self):
        return self.val

    def to_cpp(self):
        return f"OptScalar({self.val})"


class ArrayRef(ArgType):
    def __init__(self, dtype, data: list):
        self.dtype = dtype
        self.data = data

    def to_pytorch(self):
        return self.data

    def to_cpp(self):
        array_str = "{" + ",".join(str(data) for data in self.data) + "}"
        return f"ArrayRef<{self.dtype}>({array_str})"


class EnumArg(ArgType):
    def __init__(self, text):
        self.text = text

    def to_pytorch(self):
        # Most likely it cannot be directly used
        return ""

    def to_cpp(self):
        return self.text


class StringArg(ArgType):
    def __init__(self, text):
        self.text = text

    def to_pytorch(self):
        return self.text

    def to_cpp(self):
        return f'"{self.text}"'


def tensor_to_cpp_code(tensor: torch.Tensor) -> str:
    sizes = list(tensor.size())
    sizes = [str(s) for s in sizes]
    sizes_str = "{" + ", ".join(sizes) + "}"
    data = torch.flatten(tensor).tolist()
    data = [str(d) for d in data]
    data_str = "{" + ", ".join(data) + "}"
    if tensor.dtype == torch.bool:
        return f"""tf_bool.make({sizes_str}, {data_str})""".replace(
            "True", "true"
        ).replace("False", "false")
    return f"""tf.make({sizes_str}, {data_str})"""
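

# Example of the emitted factory call (illustrative):
#   tensor_to_cpp_code(torch.tensor([[1.0, 2.0], [3.0, 4.0]]))
#   -> 'tf.make({2, 2}, {1.0, 2.0, 3.0, 4.0})'
# Bool tensors go through tf_bool with lowercase true/false literals.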


def argument_to_cpp_code(arg):
    if isinstance(arg, str):
        return arg
    elif isinstance(arg, bool):
        # Note: bool must be checked before int because bool is a subclass
        # of int
        return "true" if arg else "false"
    elif isinstance(arg, (int, float)):
        return str(arg)
    elif isinstance(arg, torch.Tensor):
        return tensor_to_cpp_code(arg)
    elif isinstance(arg, ArgType):
        return arg.to_cpp()
    return "?"


def argument_to_pytorch(arg):
    if isinstance(arg, (str, int, float, torch.Tensor)):
        return arg
    elif isinstance(arg, ArgType):
        return arg.to_pytorch()
    return "?"


class ArgForPyTorch:
    """
    Wraps an argument whose C++ and PyTorch usages differ: it may be used
    only on the C++ side, only on the PyTorch side, or passed to the PyTorch
    function as a kwarg.
    """

    def __init__(self, cpp_arg, torch_kwarg_key, torch_kwarg_val):
        self.cpp_arg = cpp_arg
        self.kwarg_pair = torch_kwarg_key, torch_kwarg_val

    def used_in_cpp(self):
        return self.cpp_arg is not None

    def used_in_torch(self):
        return self.kwarg_pair != (None, None)


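# ArgForPyTorch usage patterns in the generators below:
#   ArgForPyTorch(StringArg("tanh"), None, None)  - C++-side only (gelu)
#   ArgForPyTorch(None, "dtype", torch.float)     - PyTorch kwarg only
#                                                   (log_softmax)
#   ArgForPyTorch(1, "dim", 1)                    - both sides (cumsum)

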
def make_simple_generated_case(*args, torch_fn):
    cpp_args = tuple(
        arg.cpp_arg if isinstance(arg, ArgForPyTorch) else arg
        for arg in args
        if not isinstance(arg, ArgForPyTorch) or arg.used_in_cpp()
    )
    torch_args = tuple(
        argument_to_pytorch(arg) for arg in args if not isinstance(arg, ArgForPyTorch)
    )
    kwargs_for_torch_fn = dict(
        arg.kwarg_pair
        for arg in args
        if isinstance(arg, ArgForPyTorch) and arg.used_in_torch()
    )
    return [
        (
            "SimpleGeneratedCase",
            *cpp_args,
            torch_fn(*torch_args, **kwargs_for_torch_fn),
            ShapeDynamism.STATIC,
        )
    ]
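

# E.g. make_simple_generated_case(torch.ones(10, 10), torch_fn=torch.log)
# (as used in gen_test_case_op_log below) returns a single
# ("SimpleGeneratedCase", <input>, torch.log(<input>), ShapeDynamism.STATIC)
# tuple.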


def gen_test_cases(suite_name: str, op_name: str, test_cases, test_f=False):
    """
    Generate C++ test case code from the given test cases. Inputs that are
    not a Tensor or a scalar are treated as code text and emitted verbatim.
    Each test case should be a tuple of
    (test_case_name, inputs, expected_result, shape_dynamism),
    where shape_dynamism determines the pre-allocated size of the out tensor.
    Set test_f to True if we want TEST_F (a gtest fixture).

    For example, in https://www.internalfb.com/code/fbsource/[7280e42e309e85294a77fbb51ccc6de1948f2497]/fbcode/executorch/kernels/test/op_add_test.cpp?lines=19-23, we have an additional alpha parameter.
    """

    variable_names = "xyzabcdefghijk"
    newline = "\n"

    generated_cases = []

    for test_name, *inputs, expected_result, shape_dynamism in test_cases:
        out_dynamic_shape_fn = out_dynamic_shape_fn_map[shape_dynamism]
        input_code = [argument_to_cpp_code(i) for i in inputs]
        input_lines = [
            f"auto {variable_names[i]} = {input_code[i]};" for i in range(len(inputs))
        ]

        need_tf_bool = any(
            isinstance(i, torch.Tensor) and i.dtype == torch.bool for i in inputs
        )

        ret_value = f"""{op_name}({", ".join(variable_names[:len(inputs)])}, out)"""

        generated_cases.append(
            f"""
{"TEST_F" if test_f else "TEST"}({suite_name}, {test_name}) {{
  TensorFactory<ScalarType::Float> tf;
  {"TensorFactory<ScalarType::Bool> tf_bool;" if need_tf_bool else ""}

  {newline.join(input_lines)}
  Tensor expected_result = {tensor_to_cpp_code(expected_result)};

  Tensor out = tf.zeros({out_dynamic_shape_fn(expected_result)});
  Tensor ret = {ret_value};
  EXPECT_TENSOR_CLOSE(out, expected_result);
}}
"""
        )
    return generated_cases
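

# Sketch of the C++ a single generated case expands to (names and tensor
# data are illustrative; "..." elides the flattened data values):
#   TEST(OpLogOutKernelTest, SimpleGeneratedCase) {
#     TensorFactory<ScalarType::Float> tf;
#     auto x = tf.make({10, 10}, {...});
#     Tensor expected_result = tf.make({10, 10}, {...});
#     Tensor out = tf.zeros({10, 10});
#     Tensor ret = _log_out(x, out);
#     EXPECT_TENSOR_CLOSE(out, expected_result);
#   }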


def gen_test_case_op_arange():
    return gen_test_cases(
        "OpArangeOutTest",
        "arange_out",
        make_test_cases_dynamic_shape(Scalar(5), torch.arange(5)),
        test_f=True,
    )


def gen_test_case_op_as_strided_copy():
    # TODO: Implement
    return


def gen_test_case_op_bitwise_not():
    # TODO: Implement
    return


def gen_test_case_op_cat():
    # TODO: Implement
    return


def gen_test_case_op_clamp():
    x = torch.rand(3, 2)

    return gen_test_cases(
        "OpClampOutTest",
        "clamp_out",
        make_simple_generated_case(
            torch.ones(10, 10), OptScalar(-0.5), OptScalar(0.5), torch_fn=torch.clamp
        )
        + make_test_cases_dynamic_shape(
            x, OptScalar(-0.5), OptScalar(0.5), torch.clamp(x, -0.5, 0.5)
        ),
    )


def gen_test_case_op_clone():
    x = torch.rand(3, 2)

    return gen_test_cases(
        "OpCloneTest",
        "clone_out",
        make_simple_generated_case(
            torch.ones(10, 10),
            ArgForPyTorch(
                EnumArg("exec_aten::MemoryFormat::Contiguous"),
                "memory_format",
                torch.contiguous_format,
            ),
            torch_fn=torch.clone,
        )
        + make_test_cases_dynamic_shape(
            x,
            EnumArg("exec_aten::MemoryFormat::Contiguous"),
            torch.clone(x, memory_format=torch.contiguous_format),
        ),
    )


def gen_test_case_op_cumsum():
    x = torch.rand(3, 2)

    return gen_test_cases(
        "OpCumSumOutTest",
        "cumsum_out",
        make_simple_generated_case(
            torch.ones(10, 10),
            ArgForPyTorch(1, "dim", 1),
            ArgForPyTorch(EnumArg("ScalarType::Float"), "dtype", torch.float),
            torch_fn=torch.cumsum,
        )
        + make_test_cases_dynamic_shape(
            x,
            1,
            EnumArg("ScalarType::Float"),
            torch.cumsum(x, dim=1, dtype=torch.float),
        ),
    )


def gen_test_case_op_detach_copy():
    x = torch.rand(3, 2)

    return gen_test_cases(
        "OpDetachCopyOutKernelTest",
        "_detach_copy_out",
        make_simple_generated_case(torch.ones(10, 10), torch_fn=torch.detach)
        + make_test_cases_dynamic_shape(x, torch.Tensor.detach(x)),
    )


def gen_test_case_op_exp():
    # TODO: Implement
    return


def gen_test_case_op_expand():
    # TODO: Implement
    return


def gen_test_case_op_full_like():
    x = torch.rand(3, 2)

    return gen_test_cases(
        "OpFullLikeTest",
        "full_like_out",
        make_simple_generated_case(
            torch.ones(10, 10),
            Scalar(3.0),
            ArgForPyTorch(
                EnumArg("MemoryFormat::Contiguous"),
                "memory_format",
                torch.contiguous_format,
            ),
            torch_fn=torch.full_like,
        )
        + make_test_cases_dynamic_shape(
            x,
            Scalar(3.0),
            EnumArg("MemoryFormat::Contiguous"),
            torch.full_like(x, 3.0, memory_format=torch.contiguous_format),
        ),
    )


def gen_test_case_op_gelu():
    x = torch.rand(3, 2)

    m = torch.nn.GELU(approximate="tanh")

    return gen_test_cases(
        "OpGeluKernelTest",
        "gelu_out",
        make_simple_generated_case(
            torch.ones(10, 10), ArgForPyTorch(StringArg("tanh"), None, None), torch_fn=m
        )
        + make_test_cases_dynamic_shape(x, StringArg("tanh"), m(x)),
    )


def gen_test_case_op_glu():
    x = torch.rand(4, 2)

    m = torch.nn.GLU(0)

    return gen_test_cases(
        "OpGluOutKernelTest",
        "glu_out",
        make_test_cases_dynamic_shape(x, 0, m(x)),
    )


def gen_test_case_op_log():
    x = torch.rand(3, 2)

    return gen_test_cases(
        "OpLogOutKernelTest",
        "_log_out",
        make_simple_generated_case(torch.ones(10, 10), torch_fn=torch.log)
        + make_test_cases_dynamic_shape(x, torch.log(x)),
    )


def gen_test_case_op_log_softmax():
    x = torch.rand(3, 2)

    return gen_test_cases(
        "OpLogSoftmaxOutTest",
        "log_softmax_out",
        make_simple_generated_case(
            torch.ones(10, 10),
            1,
            ArgForPyTorch(False, None, None),
            ArgForPyTorch(None, "dtype", torch.float),
            torch_fn=torch.log_softmax,
        )
        + make_test_cases_dynamic_shape(
            x, 1, False, torch.log_softmax(x, 1, torch.float)
        ),
    )


def gen_test_case_op_logit():
    x = torch.rand(3, 2)

    return gen_test_cases(
        "OpLogitOutKernelTest",
        "logit_out",
        make_simple_generated_case(torch.ones(10, 10), 0.1, torch_fn=torch.logit)
        + make_test_cases_dynamic_shape(x, 0.1, torch.logit(x, 0.1)),
    )


def gen_test_case_op_mean():
    x = torch.rand(3, 2)

    return gen_test_cases(
        "OpMeanOutTest",
        "mean_dim_out",
        make_simple_generated_case(
            torch.ones(10, 10),
            ArgForPyTorch(ArrayRef("int64_t", [1]), "dim", 1),
            ArgForPyTorch(False, "keepdim", False),
            ArgForPyTorch(EnumArg("ScalarType::Float"), "dtype", torch.float),
            torch_fn=torch.mean,
        )
        + make_test_cases_dynamic_shape(
            x,
            ArrayRef("int64_t", [1]),
            False,
            EnumArg("ScalarType::Float"),
            torch.Tensor.mean(x, dim=1, keepdim=False, dtype=torch.float),
        ),
    )


def gen_test_case_op_nonzero():
    # TODO: Implement
    return


def gen_test_case_op_permute():
    # TODO: Implement
    return


def gen_test_case_op_relu():
    x = torch.rand(3, 2)

    return gen_test_cases(
        "OpReluOutKernelTest",
        "_relu_out",
        make_simple_generated_case(torch.ones(10, 10), torch_fn=torch.relu)
        + make_test_cases_dynamic_shape(x, torch.relu(x)),
    )


def gen_test_case_op_repeat():
    # TODO: Implement
    return


def gen_test_case_op_round():
    # TODO: Implement
    return


def gen_test_case_op_sigmoid():
    # TODO: Implement
    return


def gen_test_case_op_slice():
    # TODO: Implement
    return


def gen_test_case_op_softmax():
    x = torch.rand(3, 2)

    return gen_test_cases(
        "OpSoftmaxOutTest",
        "softmax_out",
        make_simple_generated_case(
            torch.ones(10, 10),
            1,
            ArgForPyTorch(False, "dtype", torch.float),
            torch_fn=torch.softmax,
        )
        + make_test_cases_dynamic_shape(x, 1, False, torch.softmax(x, 1, torch.float)),
    )


def gen_test_case_op_squeeze():
    # TODO: Implement
    return


def gen_test_case_op_sum():
    # TODO: Implement
    return


def gen_test_case_op_t():
    # TODO: Implement
    return


def gen_test_case_op_tanh():
    x = torch.rand(3, 2)

    return gen_test_cases(
        "OpTanhOutKernelTest",
        "_tanh_out",
        make_simple_generated_case(torch.ones(10, 10), torch_fn=torch.tanh)
        + make_test_cases_dynamic_shape(x, torch.tanh(x)),
    )


def gen_test_case_op_to():
    # TODO: Implement
    return


def gen_test_case_op_transpose():
    # TODO: Implement
    return


def gen_test_case_op_unsqueeze():
    # TODO: Implement
    return


def gen_test_case_op_view():
    # TODO: Implement
    return


def gen_test_case_op_zeros():
    # TODO: Implement
    return


def gen_test_case_op_add():
    x = torch.rand(3, 2)
    y = torch.rand(3, 2)

    return gen_test_cases(
        "OpAddOutKernelTest",
        "add_out",
        make_simple_generated_case(
            torch.ones(10, 10), torch.ones(10, 10), torch_fn=torch.add
        )
        + make_test_cases_broadcast_two_input_tensor(x, y, (1,), (), torch_fn=torch.add)
        + make_test_cases_dynamic_shape(x, y, 1, torch.add(x, y)),
    )


def gen_test_case_op_bmm():
    x = torch.rand(3, 3, 6)
    y = torch.rand(3, 6, 2)

    return gen_test_cases(
        "OpBmmOutKernelTest",
        "_bmm_out",
        make_test_cases_dynamic_shape(x, y, torch.bmm(x, y)),
    )


def gen_test_case_op_copy():
    # TODO: Implement
    return


def gen_test_case_op_div():
    x = torch.rand(3, 2)
    y = torch.rand(3, 2)

    return gen_test_cases(
        "OpDivOutKernelTest",
        "_div_out",
        make_test_cases_broadcast_two_input_tensor(x, y, (), (), torch_fn=torch.div)
        + make_test_cases_dynamic_shape(x, y, torch.div(x, y)),
    )


def gen_test_case_op_embedding():
    # TODO: Implement
    return


def gen_test_case_op_eq():
    # TODO: Implement
    return


def gen_test_case_op_floor_divide():
    x = torch.rand(3, 2)
    y = torch.rand(3, 2)

    return gen_test_cases(
        "OpFloorDivideKernelTest",
        "_floor_divide_out",
        make_test_cases_broadcast_two_input_tensor(
            x, y, (), (), torch_fn=torch.floor_divide
        )
        + make_test_cases_dynamic_shape(x, y, torch.floor_divide(x, y)),
    )


def gen_test_case_op_le():
    # TODO: Implement
    return


def gen_test_case_op_minimum():
    # TODO: Implement
    return


def gen_test_case_op_mm():
    x = torch.rand(3, 2)
    y = torch.rand(2, 4)

    return gen_test_cases(
        "OpMmOutKernelTest",
        "_mm_out",
        make_test_cases_dynamic_shape(x, y, torch.mm(x, y)),
    )


def gen_test_case_op_mul():
    x = torch.rand(3, 2)
    y = torch.rand(3, 2)

    return gen_test_cases(
        "OpMulOutKernelTest",
        "_mul_out",
        make_test_cases_broadcast_two_input_tensor(x, y, (), (), torch_fn=torch.mul)
        + make_test_cases_dynamic_shape(x, y, torch.mul(x, y)),
    )


def gen_test_case_op_ne():
    # TODO: Implement
    return


def gen_test_case_op_select():
    # TODO: Implement
    return


def gen_test_case_op_select_scatter():
    # TODO: Implement
    return


def gen_test_case_op_sub():
    x = torch.rand(3, 2)
    y = torch.rand(3, 2)

    return gen_test_cases(
        "OpSubOutKernelTest",
        "sub_out",
        make_test_cases_broadcast_two_input_tensor(x, y, (1,), (), torch_fn=torch.sub)
        + make_test_cases_dynamic_shape(x, y, 1, torch.sub(x, y)),
    )


def gen_test_case_op_addmm():
    x = torch.rand(3, 6)
    y = torch.rand(6, 2)

    b = torch.rand(3, 2)
    b_dim_is_1 = torch.rand(1, 2)
    b_miss_dim = torch.squeeze(b_dim_is_1)

    return gen_test_cases(
        "OpAddmmOutKernelTest",
        "addmm_out",
        [
            (
                "BroadcastDimSizeIsOne",
                b_dim_is_1,
                x,
                y,
                Scalar(1),
                Scalar(1),
                torch.addmm(b_dim_is_1, x, y),
                ShapeDynamism.STATIC,
            ),
            (
                "BroadcastDimSizeMissing",
                b_miss_dim,
                x,
                y,
                Scalar(1),
                Scalar(1),
                torch.addmm(b_miss_dim, x, y),
                ShapeDynamism.STATIC,
            ),
        ]
        + make_test_cases_dynamic_shape(
            b, x, y, Scalar(1), Scalar(1), torch.addmm(b, x, y)
        ),
    )


def gen_test_case_op_convolution():
    # TODO: Implement
    return


def gen_test_case_op_where():
    # TODO: Implement
    return


def gen_test_case_op_masked_fill():
    a = torch.rand(3, 2)

    b = torch.rand(3, 2) > 0.5

    return gen_test_cases(
        "OpMaskedFillTest",
        "masked_fill_scalar_out",
        make_test_cases_broadcast_two_input_tensor(
            a, b, (Scalar(3.0),), (3.0,), torch_fn=torch.masked_fill
        )
        + make_test_cases_dynamic_shape(
            a, b, Scalar(3.0), torch.masked_fill(a, b, 3.0)
        ),
    )


def get_test_case_name(generated_test_case: str):
    m = re.search(r"TEST(_F)?\(.*\)", generated_test_case)
    if m is not None:
        test_case = m.group(0)
        return "".join(test_case.split())
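

# E.g. the generated line 'TEST_F(OpArangeOutTest, DynamicShapeUnbound) {'
# yields 'TEST_F(OpArangeOutTest,DynamicShapeUnbound)' (the match with all
# whitespace removed), which is used for duplicate detection below.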


def gen_test_cases_for_file(path_to_tests: str, op_name: str):
    if ("gen_test_case_" + op_name) not in globals():
        print(f"generator function is not defined for {op_name}")
        return
    gen_func = globals()[("gen_test_case_" + op_name)]
    generated_test_cases = gen_func()
    if generated_test_cases is None:
        print(f"generator function is not implemented for {op_name}")
        return
    file_name = op_name + "_test.cpp"
    with open(os.path.join(path_to_tests, file_name), "r+") as f:
        previous = f.read()
        # Remove all whitespace and newlines for robust substring matching
        previous = "".join(previous.split())
        for generated_test_case in generated_test_cases:
            if get_test_case_name(generated_test_case) not in previous:
                f.write(generated_test_case)
                print(f"test case {get_test_case_name(generated_test_case)} added")


def main():
    print("Generating test cases...")
    if len(sys.argv) < 2:
        print("Usage: test_case_gen.py <path-to-kernels/test>")
        return
    test_dir = sys.argv[1]
    ops = [
        f[:-9]  # strip the "_test.cpp" suffix (9 characters)
        for f in os.listdir(test_dir)
        if f.startswith("op_") and f.endswith("_test.cpp")
    ]
    for op in ops:
        gen_test_cases_for_file(test_dir, op)
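

# Example invocation (the path is illustrative; pass your local
# kernels/test directory):
#   python test_case_gen.py executorch/kernels/test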


if __name__ == "__main__":
    main()