# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

"""
A helper library to generate test cases for ET kernels.

It simplifies the steps to generate a new C++ test case: the user only needs
to specify the inputs, and the PyTorch kernel is used to calculate the
expected result.
"""

import os
import re
import sys

from abc import ABC, abstractmethod

from enum import Enum

import torch


# Seed the PyTorch RNG so the random inputs (torch.rand) used by the
# generators are reproducible across runs.
torch.manual_seed(0)


def _brace_list(items) -> str:
    """Render an iterable as a C++ brace initializer list, e.g. ``{3, 2}``."""
    return "{" + ", ".join(str(item) for item in items) + "}"


def make_out_static_shape(tensor: torch.Tensor) -> str:
    """Return the C++ size list for a statically shaped ``out`` tensor.

    The size is exactly the size of ``tensor`` (the PyTorch reference output).
    """
    return _brace_list(tensor.size())


def make_out_dynamic_shape_bound_shape_same(tensor: torch.Tensor) -> str:
    """Return C++ ``out`` args: the size of ``tensor`` with DYNAMIC_BOUND dynamism."""
    return (
        _brace_list(tensor.size())
        + ", torch::executor::TensorShapeDynamism::DYNAMIC_BOUND"
    )


def make_out_dynamic_shape_bound_shape_larger(tensor: torch.Tensor) -> str:
    """Return C++ ``out`` args: every dim of ``tensor`` doubled, with DYNAMIC_BOUND.

    Presumably the kernel is then expected to resize ``out`` down to the
    expected shape (TODO confirm against the C++ runtime semantics).
    """
    return (
        _brace_list(size * 2 for size in tensor.size())
        + ", torch::executor::TensorShapeDynamism::DYNAMIC_BOUND"
    )


def make_out_dynamic_shape_unbound_shape(tensor: torch.Tensor) -> str:
    """Return C++ ``out`` args: every dim set to 1, with DYNAMIC_UNBOUND dynamism."""
    return (
        _brace_list([1] * tensor.dim())
        + ", torch::executor::TensorShapeDynamism::DYNAMIC_UNBOUND"
    )


class ShapeDynamism(Enum):
    # Static shape; shape is determined from pytorch output
    STATIC = 1
    # Dynamic bound with same size; shape is determined from pytorch output using the same size as static
    DYNAMIC_BOUND_SAME_SHAPE = 2
    # Dynamic bound with a larger size to test functionality; shape is determined from pytorch output
    DYNAMIC_BOUND_LARGER_SHAPE = 3
    # Dynamic unbound with a smaller size to test functionality
    DYNAMIC_UNBOUND = 4


# Maps each ShapeDynamism to the function rendering the C++ ``out`` tensor args.
out_dynamic_shape_fn_map = {
    ShapeDynamism.STATIC: make_out_static_shape,
    ShapeDynamism.DYNAMIC_BOUND_SAME_SHAPE: make_out_dynamic_shape_bound_shape_same,
    ShapeDynamism.DYNAMIC_BOUND_LARGER_SHAPE: make_out_dynamic_shape_bound_shape_larger,
    ShapeDynamism.DYNAMIC_UNBOUND: make_out_dynamic_shape_unbound_shape,
}


def make_test_cases_dynamic_shape(*args):
    """
    Build the dynamic-shape variants of one test case.

    Each returned tuple is
        (test_case_name, *args, shape_dynamism)
    where ``args`` are the inputs followed by the expected output. One tuple is
    produced for each non-static ShapeDynamism value.
    """
    variants = (
        ("DynamicShapeUpperBoundSameAsExpected", ShapeDynamism.DYNAMIC_BOUND_SAME_SHAPE),
        (
            "DynamicShapeUpperBoundLargerThanExpected",
            ShapeDynamism.DYNAMIC_BOUND_LARGER_SHAPE,
        ),
        ("DynamicShapeUnbound", ShapeDynamism.DYNAMIC_UNBOUND),
    )
    return [(name, *args, dynamism) for name, dynamism in variants]


def make_test_cases_broadcast_two_input_tensor(x, y, cpp_args, torch_args, torch_fn):
    """
    Build static-shape broadcasting test cases for a two-tensor-input op
    (like add, mul, div).

    Each returned tuple is
        (test_case_name, lhs, rhs, *cpp_args, expected_output, ShapeDynamism.STATIC)
    where ``expected_output = torch_fn(lhs, rhs, *torch_args)``.

    In each case one operand is reshaped and the other is left unchanged:
      * ``BroadcastDimSizeIsOne*``: the leading dimension is reduced to size 1.
      * ``BroadcastDimSizeMissing*``: the leading dimension is dropped.
    The ``AB`` suffix reshapes the first tensor, ``BA`` the second.
    """
    x_remove_dim = x[0]
    # unsqueeze(0) re-adds a leading dim of size 1, e.g. (3, 2) -> (2,) -> (1, 2),
    # so this operand must be broadcast along that dim.
    x_first_dim_1 = x_remove_dim.unsqueeze(0)
    y_remove_dim = y[0]
    y_first_dim_1 = y_remove_dim.unsqueeze(0)

    def _case(name, lhs, rhs):
        expected = torch_fn(lhs, rhs, *torch_args)
        return (name, lhs, rhs, *cpp_args, expected, ShapeDynamism.STATIC)

    return [
        _case("BroadcastDimSizeIsOneAB", x_first_dim_1, y),
        _case("BroadcastDimSizeMissingAB", x_remove_dim, y),
        _case("BroadcastDimSizeIsOneBA", x, y_first_dim_1),
        _case("BroadcastDimSizeMissingBA", x, y_remove_dim),
    ]


class ArgType(ABC):
    """
    Represents an argument for generated C++ code and for pytorch call
    """

    @abstractmethod
    def to_pytorch(self):
        """Return the value to pass to the PyTorch reference function."""
        return None

    @abstractmethod
    def to_cpp(self) -> str:
        """Return the C++ source text for this argument."""
        return ""


class Scalar(ArgType):
    """A value emitted as ``Scalar(<val>)`` in C++ and passed as-is to torch."""

    def __init__(self, val):
        self.val = val

    def to_pytorch(self):
        return self.val

    def to_cpp(self):
        return f"Scalar({self.val})"


class OptScalar(ArgType):
    """A value emitted as ``OptScalar(<val>)`` in C++ and passed as-is to torch."""

    def __init__(self, val):
        self.val = val

    def to_pytorch(self):
        return self.val

    def to_cpp(self):
        return f"OptScalar({self.val})"


class ArrayRef(ArgType):
    """A list emitted as ``ArrayRef<dtype>({a,b,...})`` in C++."""

    def __init__(self, dtype, data: list):
        self.dtype = dtype
        self.data = data

    def to_pytorch(self):
        return self.data

    def to_cpp(self):
        array_str = "{" + ",".join(str(item) for item in self.data) + "}"
        return f"ArrayRef<{self.dtype}>({array_str})"


class EnumArg(ArgType):
    """Raw C++ text (e.g. an enum constant) emitted verbatim."""

    def __init__(self, text):
        self.text = text

    def to_pytorch(self):
        # Most likely it cannot be directly used
        return ""

    def to_cpp(self):
        return self.text


class StringArg(ArgType):
    """A string emitted as a double-quoted C++ string literal."""

    def __init__(self, text):
        self.text = text

    def to_pytorch(self):
        return self.text

    def to_cpp(self):
        return f'"{self.text}"'


def tensor_to_cpp_code(tensor: torch.Tensor) -> str:
    """
    Render ``tensor`` as a C++ TensorFactory ``make`` call.

    Bool tensors use the ``tf_bool`` factory with C++ ``true``/``false``
    literals; every other dtype uses the ``tf`` factory.
    """
    sizes_str = _brace_list(tensor.size())
    values = torch.flatten(tensor).tolist()
    if tensor.dtype == torch.bool:
        data_str = _brace_list("true" if v else "false" for v in values)
        return f"tf_bool.make({sizes_str}, {data_str})"
    return f"tf.make({sizes_str}, {_brace_list(values)})"


def argument_to_cpp_code(arg):
    """
    Render one test input as C++ source text.

    Strings are treated as literal C++ code; unsupported types yield ``"?"``.
    """
    if isinstance(arg, str):
        return arg
    # bool must be checked before int because bool is a subclass of int.
    if isinstance(arg, bool):
        return "true" if arg else "false"
    if isinstance(arg, (int, float)):
        return str(arg)
    if isinstance(arg, torch.Tensor):
        return tensor_to_cpp_code(arg)
    if isinstance(arg, ArgType):
        return arg.to_cpp()
    return "?"


def argument_to_pytorch(arg):
    """
    Convert one test input to the value passed to the PyTorch reference.

    Plain str/int/float/bool/Tensor values pass through; ArgType values use
    ``to_pytorch()``; anything else yields the placeholder ``"?"``.
    """
    if isinstance(arg, (str, int, float, torch.Tensor)):
        return arg
    if isinstance(arg, ArgType):
        return arg.to_pytorch()
    return "?"
class ArgForPyTorch:
    """
    Wrap an argument whose C++ and PyTorch forms differ.

    Sometimes an arg for C++ cannot be used directly in torch: it may be unused
    by torch, used only by torch, or passed to torch as a keyword argument.

    Args:
        cpp_arg: value emitted in the C++ call, or None to omit it from C++.
        torch_kwarg_key: keyword name for the torch call, or None.
        torch_kwarg_val: keyword value for the torch call, or None.
    """

    def __init__(self, cpp_arg, torch_kwarg_key, torch_kwarg_val):
        self.cpp_arg = cpp_arg
        self.kwarg_pair = torch_kwarg_key, torch_kwarg_val

    def used_in_cpp(self):
        return self.cpp_arg is not None

    def used_in_torch(self):
        # Only a (None, None) pair means "not passed to torch".
        return self.kwarg_pair != (None, None)


def make_simple_generated_case(*args, torch_fn):
    """
    Build a single static-shape test case named "SimpleGeneratedCase".

    Plain ``args`` are passed positionally to both the C++ op and ``torch_fn``
    (converted with argument_to_pytorch for torch). For ArgForPyTorch args,
    ``cpp_arg`` goes to C++ when it is not None, and the kwarg pair goes to
    ``torch_fn`` as a keyword argument when it is set.
    """
    cpp_args = tuple(
        arg.cpp_arg if isinstance(arg, ArgForPyTorch) else arg
        for arg in args
        if not isinstance(arg, ArgForPyTorch) or arg.used_in_cpp()
    )
    torch_args = tuple(
        argument_to_pytorch(arg) for arg in args if not isinstance(arg, ArgForPyTorch)
    )
    kwargs_for_torch_fn = dict(
        arg.kwarg_pair
        for arg in args
        if isinstance(arg, ArgForPyTorch) and arg.used_in_torch()
    )
    return [
        (
            "SimpleGeneratedCase",
            *cpp_args,
            torch_fn(*torch_args, **kwargs_for_torch_fn),
            ShapeDynamism.STATIC,
        )
    ]


def gen_test_cases(suite_name: str, op_name: str, test_cases, test_f=False):
    """
    Render gtest C++ source for each test case.

    Inputs that are neither tensors nor scalars are treated as code text
    (see argument_to_cpp_code). Each test case must be a tuple of
        (test_case_name, *inputs, expected_result, shape_dynamism)
    The inputs are bound, in order, to C++ variables named x, y, z, a, b, ...
    and then passed to ``op_name(<inputs>, out)``. ``out`` is pre-allocated
    with the size and dynamism selected by ``shape_dynamism``. Extra
    non-tensor parameters (e.g. the ``alpha`` scalar of ``add_out``) are just
    additional inputs.

    Args:
        suite_name: gtest suite name.
        op_name: C++ function to call.
        test_cases: iterable of test-case tuples as described above.
        test_f: emit ``TEST_F`` (gtest fixture) instead of ``TEST``.

    Returns:
        A list of C++ test-case source strings.

    Raises:
        ValueError: if a test case has more inputs than available variable names.
    """
    variable_names = "xyzabcdefghijk"
    macro = "TEST_F" if test_f else "TEST"

    generated_cases = []

    for test_name, *inputs, expected_result, shape_dynamism in test_cases:
        if len(inputs) > len(variable_names):
            raise ValueError(
                f"{test_name}: {len(inputs)} inputs exceed the "
                f"{len(variable_names)} available variable names"
            )
        out_dynamic_shape_fn = out_dynamic_shape_fn_map[shape_dynamism]
        names = variable_names[: len(inputs)]
        input_lines = "\n".join(
            f"auto {name} = {argument_to_cpp_code(value)};"
            for name, value in zip(names, inputs)
        )

        # A bool expected result is rendered with tf_bool (see
        # tensor_to_cpp_code), so the factory is needed for it too, and out
        # must be allocated with the matching dtype.
        expected_is_bool = expected_result.dtype == torch.bool
        need_tf_bool = expected_is_bool or any(
            isinstance(i, torch.Tensor) and i.dtype == torch.bool for i in inputs
        )
        tf_bool_decl = "TensorFactory<ScalarType::Bool> tf_bool;" if need_tf_bool else ""
        out_factory = "tf_bool" if expected_is_bool else "tf"

        ret_value = f"{op_name}({', '.join(names)}, out)"

        generated_cases.append(
            f"""
{macro}({suite_name}, {test_name}) {{
  TensorFactory<ScalarType::Float> tf;
  {tf_bool_decl}

  {input_lines}
  Tensor expected_result = {tensor_to_cpp_code(expected_result)};

  Tensor out = {out_factory}.zeros({out_dynamic_shape_fn(expected_result)});
  Tensor ret = {ret_value};
  EXPECT_TENSOR_CLOSE(out, expected_result);
}}
"""
        )
    return generated_cases


def gen_test_case_op_arange():
    return gen_test_cases(
        "OpArangeOutTest",
        "arange_out",
        make_test_cases_dynamic_shape(Scalar(5), torch.arange(5)),
        test_f=True,
    )


# Generators whose body is only `return` (None) are placeholders;
# gen_test_cases_for_file reports them as "not implemented".
def gen_test_case_op_as_strided_copy():
    # TODO: Implement
    return


def gen_test_case_op_bitwise_not():
    # TODO: Implement
    return


def gen_test_case_op_cat():
    # TODO: Implement
    return


def gen_test_case_op_clamp():
    x = torch.rand(3, 2)

    return gen_test_cases(
        "OpClampOutTest",
        "clamp_out",
        make_simple_generated_case(
            torch.ones(10, 10), OptScalar(-0.5), OptScalar(0.5), torch_fn=torch.clamp
        )
        + make_test_cases_dynamic_shape(
            x, OptScalar(-0.5), OptScalar(0.5), torch.clamp(x, -0.5, 0.5)
        ),
    )


def gen_test_case_op_clone():
    x = torch.rand(3, 2)

    return gen_test_cases(
        "OpCloneTest",
        "clone_out",
        make_simple_generated_case(
            torch.ones(10, 10),
            ArgForPyTorch(
                EnumArg("exec_aten::MemoryFormat::Contiguous"),
                "memory_format",
                torch.contiguous_format,
            ),
            torch_fn=torch.clone,
        )
        + make_test_cases_dynamic_shape(
            x,
            EnumArg("exec_aten::MemoryFormat::Contiguous"),
            torch.clone(x, memory_format=torch.contiguous_format),
        ),
    )


def gen_test_case_op_cumsum():
    x = torch.rand(3, 2)

    return gen_test_cases(
        "OpCumSumOutTest",
        "cumsum_out",
        make_simple_generated_case(
            torch.ones(10, 10),
            ArgForPyTorch(1, "dim", 1),
            ArgForPyTorch(EnumArg("ScalarType::Float"), "dtype", torch.float),
            torch_fn=torch.cumsum,
        )
        + make_test_cases_dynamic_shape(
            x,
            1,
            EnumArg("ScalarType::Float"),
            torch.cumsum(x, dim=1, dtype=torch.float),
        ),
    )


def gen_test_case_op_detach_copy():
    x = torch.rand(3, 2)

    return gen_test_cases(
        "OpDetachCopyOutKernelTest",
        "_detach_copy_out",
        make_simple_generated_case(torch.ones(10, 10), torch_fn=torch.detach)
        + make_test_cases_dynamic_shape(x, torch.Tensor.detach(x)),
    )


def gen_test_case_op_exp():
    # TODO: Implement
    return


def gen_test_case_op_expand():
    # TODO: Implement
    return


def gen_test_case_op_full_like():
    x = torch.rand(3, 2)

    return gen_test_cases(
        "OpFullLikeTest",
        "full_like_out",
        make_simple_generated_case(
            torch.ones(10, 10),
            Scalar(3.0),
            ArgForPyTorch(
                EnumArg("MemoryFormat::Contiguous"),
                "memory_format",
                torch.contiguous_format,
            ),
            torch_fn=torch.full_like,
        )
        + make_test_cases_dynamic_shape(
            x,
            Scalar(3.0),
            EnumArg("MemoryFormat::Contiguous"),
            torch.full_like(x, 3.0, memory_format=torch.contiguous_format),
        ),
    )


def gen_test_case_op_gelu():
    x = torch.rand(3, 2)

    m = torch.nn.GELU(approximate="tanh")

    return gen_test_cases(
        "OpGeluKernelTest",
        "gelu_out",
        make_simple_generated_case(
            torch.ones(10, 10), ArgForPyTorch(StringArg("tanh"), None, None), torch_fn=m
        )
        + make_test_cases_dynamic_shape(x, StringArg("tanh"), m(x)),
    )


def gen_test_case_op_glu():
    x = torch.rand(4, 2)

    m = torch.nn.GLU(0)

    return gen_test_cases(
        "OpGluOutKernelTest",
        "glu_out",
        make_test_cases_dynamic_shape(x, 0, m(x)),
    )


def gen_test_case_op_log():
    x = torch.rand(3, 2)

    return gen_test_cases(
        "OpLogOutKernelTest",
        "_log_out",
        make_simple_generated_case(torch.ones(10, 10), torch_fn=torch.log)
        + make_test_cases_dynamic_shape(x, torch.log(x)),
    )


def gen_test_case_op_log_softmax():
    x = torch.rand(3, 2)

    return gen_test_cases(
        "OpLogSoftmaxOutTest",
        "log_softmax_out",
        make_simple_generated_case(
            torch.ones(10, 10),
            1,
            ArgForPyTorch(False, None, None),
            ArgForPyTorch(None, "dtype", torch.float),
            torch_fn=torch.log_softmax,
        )
        + make_test_cases_dynamic_shape(
            x, 1, False, torch.log_softmax(x, 1, torch.float)
        ),
    )


def gen_test_case_op_logit():
    x = torch.rand(3, 2)

    return gen_test_cases(
        "OpLogitOutKernelTest",
        "logit_out",
        make_simple_generated_case(torch.ones(10, 10), 0.1, torch_fn=torch.logit)
        + make_test_cases_dynamic_shape(x, 0.1, torch.logit(x, 0.1)),
    )


def gen_test_case_op_mean():
    x = torch.rand(3, 2)

    return gen_test_cases(
        "OpMeanOutTest",
        "mean_dim_out",
        make_simple_generated_case(
            torch.ones(10, 10),
            ArgForPyTorch(ArrayRef("int64_t", [1]), "dim", 1),
            ArgForPyTorch(False, "keepdim", False),
            ArgForPyTorch(EnumArg("ScalarType::Float"), "dtype", torch.float),
            torch_fn=torch.mean,
        )
        + make_test_cases_dynamic_shape(
            x,
            ArrayRef("int64_t", [1]),
            False,
            EnumArg("ScalarType::Float"),
            torch.Tensor.mean(x, dim=1, keepdim=False, dtype=torch.float),
        ),
    )


def gen_test_case_op_nonzero():
    # TODO: Implement
    return


def gen_test_case_op_permute():
    # TODO: Implement
    return


def gen_test_case_op_relu():
    x = torch.rand(3, 2)

    return gen_test_cases(
        "OpReluOutKernelTest",
        "_relu_out",
        make_simple_generated_case(torch.ones(10, 10), torch_fn=torch.relu)
        + make_test_cases_dynamic_shape(x, torch.relu(x)),
    )


def gen_test_case_op_repeat():
    # TODO: Implement
    return


def gen_test_case_op_round():
    # TODO: Implement
    return


def gen_test_case_op_sigmoid():
    # TODO: Implement
    return


def gen_test_case_op_slice():
    # TODO: Implement
    return


def gen_test_case_op_softmax():
    x = torch.rand(3, 2)

    return gen_test_cases(
        "OpSoftmaxOutTest",
        "softmax_out",
        make_simple_generated_case(
            torch.ones(10, 10),
            1,
            ArgForPyTorch(False, "dtype", torch.float),
            torch_fn=torch.softmax,
        )
        + make_test_cases_dynamic_shape(x, 1, False, torch.softmax(x, 1, torch.float)),
    )


def gen_test_case_op_squeeze():
    # TODO: Implement
    return


def gen_test_case_op_sum():
    # TODO: Implement
    return


def gen_test_case_op_t():
    # TODO: Implement
    return


def gen_test_case_op_tanh():
    x = torch.rand(3, 2)

    return gen_test_cases(
        "OpTanhOutKernelTest",
        "_tanh_out",
        make_simple_generated_case(torch.ones(10, 10), torch_fn=torch.tanh)
        + make_test_cases_dynamic_shape(x, torch.tanh(x)),
    )


def gen_test_case_op_to():
    # TODO: Implement
    return


def gen_test_case_op_transpose():
    # TODO: Implement
    return


def gen_test_case_op_unsqueeze():
    # TODO: Implement
    return


def gen_test_case_op_view():
    # TODO: Implement
    return


def gen_test_case_op_zeros():
    # TODO: Implement
    return


def gen_test_case_op_add():
    x = torch.rand(3, 2)
    y = torch.rand(3, 2)

    return gen_test_cases(
        "OpAddOutKernelTest",
        "add_out",
        make_simple_generated_case(
            torch.ones(10, 10),
            torch.ones(10, 10),
            # add_out takes an alpha Scalar, as in the other add cases below.
            ArgForPyTorch(1, "alpha", 1),
            torch_fn=torch.add,
        )
        + make_test_cases_broadcast_two_input_tensor(x, y, (1,), (), torch_fn=torch.add)
        + make_test_cases_dynamic_shape(x, y, 1, torch.add(x, y)),
    )


def gen_test_case_op_bmm():
    x = torch.rand(3, 3, 6)
    y = torch.rand(3, 6, 2)

    return gen_test_cases(
        "OpBmmOutKernelTest",
        "_bmm_out",
        make_test_cases_dynamic_shape(x, y, torch.bmm(x, y)),
    )


def gen_test_case_op_copy():
    # TODO: Implement
    return


def gen_test_case_op_div():
    x = torch.rand(3, 2)
    y = torch.rand(3, 2)

    return gen_test_cases(
        "OpDivOutKernelTest",
        "_div_out",
        make_test_cases_broadcast_two_input_tensor(x, y, (), (), torch_fn=torch.div)
        + make_test_cases_dynamic_shape(x, y, torch.div(x, y)),
    )


def gen_test_case_op_embedding():
    # TODO: Implement
    return


def gen_test_case_op_eq():
    # TODO: Implement
    return


def gen_test_case_op_floor_divide():
    x = torch.rand(3, 2)
    y = torch.rand(3, 2)

    return gen_test_cases(
        "OpFloorDivideKernelTest",
        "_floor_divide_out",
        make_test_cases_broadcast_two_input_tensor(
            x, y, (), (), torch_fn=torch.floor_divide
        )
        + make_test_cases_dynamic_shape(x, y, torch.floor_divide(x, y)),
    )


def gen_test_case_op_le():
    # TODO: Implement
    return


def gen_test_case_op_minimum():
    # TODO: Implement
    return


def gen_test_case_op_mm():
    x = torch.rand(3, 2)
    y = torch.rand(2, 4)

    return gen_test_cases(
        "OpMmOutKernelTest",
        "_mm_out",
        make_test_cases_dynamic_shape(x, y, torch.mm(x, y)),
    )


def gen_test_case_op_mul():
    x = torch.rand(3, 2)
    y = torch.rand(3, 2)

    return gen_test_cases(
        "OpMulOutKernelTest",
        "_mul_out",
        make_test_cases_broadcast_two_input_tensor(x, y, (), (), torch_fn=torch.mul)
        + make_test_cases_dynamic_shape(x, y, torch.mul(x, y)),
    )


def gen_test_case_op_ne():
    # TODO: Implement
    return


def gen_test_case_op_select():
    # TODO: Implement
    return


def gen_test_case_op_select_scatter():
    # TODO: Implement
    return


def gen_test_case_op_sub():
    x = torch.rand(3, 2)
    y = torch.rand(3, 2)

    return gen_test_cases(
        "OpSubOutKernelTest",
        "sub_out",
        make_test_cases_broadcast_two_input_tensor(x, y, (1,), (), torch_fn=torch.sub)
        + make_test_cases_dynamic_shape(x, y, 1, torch.sub(x, y)),
    )


def gen_test_case_op_addmm():
    x = torch.rand(3, 6)
    y = torch.rand(6, 2)

    b = torch.rand(3, 2)
    b_dim_is_1 = torch.rand(1, 2)
    b_miss_dim = torch.squeeze(b_dim_is_1)

    return gen_test_cases(
        "OpAddmmOutKernelTest",
        "addmm_out",
        [
            (
                "BroadcastDimSizeIsOne",
                b_dim_is_1,
                x,
                y,
                Scalar(1),
                Scalar(1),
                torch.addmm(b_dim_is_1, x, y),
                ShapeDynamism.STATIC,
            ),
            (
                "BroadcastDimSizeMissing",
                b_miss_dim,
                x,
                y,
                Scalar(1),
                Scalar(1),
                # Compute the expectation from the input actually passed.
                torch.addmm(b_miss_dim, x, y),
                ShapeDynamism.STATIC,
            ),
        ]
        + make_test_cases_dynamic_shape(
            b, x, y, Scalar(1), Scalar(1), torch.addmm(b, x, y)
        ),
    )


def gen_test_case_op_convolution():
    # TODO: Implement
    return


def gen_test_case_op_where():
    # TODO: Implement
    return


def gen_test_case_op_masked_fill():
    a = torch.rand(3, 2)

    b = torch.rand(3, 2) > 0.5

    return gen_test_cases(
        "OpMaskedFillTest",
        "masked_fill_scalar_out",
        make_test_cases_broadcast_two_input_tensor(
            a, b, (Scalar(3.0),), (3.0,), torch_fn=torch.masked_fill
        )
        + make_test_cases_dynamic_shape(a, b, Scalar(3.0), torch.masked_fill(a, b, 3.0)),
    )


def get_test_case_name(generated_test_case: str):
    """
    Return the whitespace-free ``TEST(...)`` / ``TEST_F(...)`` header of a
    generated test case (e.g. ``TEST(Suite,Name)``), or None if there is none.
    """
    m = re.search(r"TEST(_F)?\(.*\)", generated_test_case)
    if m is None:
        return None
    return "".join(m.group(0).split())


def gen_test_cases_for_file(path_to_tests: str, op_name: str):
    """
    Append newly generated test cases for ``op_name`` to ``<op_name>_test.cpp``.

    Looks up ``gen_test_case_<op_name>`` in this module. A case is skipped when
    its whitespace-free TEST header already appears in the file or was already
    written in this run.
    """
    gen_func = globals().get("gen_test_case_" + op_name)
    if gen_func is None:
        print(f"generator function is not defined for {op_name}")
        return
    generated_test_cases = gen_func()
    if generated_test_cases is None:
        print(f"generator function is not implemented for {op_name}")
        return
    file_name = op_name + "_test.cpp"
    with open(os.path.join(path_to_tests, file_name), "r+") as f:
        # Remove all white spaces and new lines so formatting differences in
        # the existing file don't hide an already-present test.
        previous = "".join(f.read().split())
        written = set()
        # After read() the file position is at EOF, so each write appends.
        for generated_test_case in generated_test_cases:
            name = get_test_case_name(generated_test_case)
            if name is None or name in written or name in previous:
                continue
            f.write(generated_test_case)
            written.add(name)
            print(f"test case {name} added")


def main():
    """Entry point: ``test_case_gen.py <path-to-kernels/test>``."""
    print("Generating test cases...")
    if len(sys.argv) < 2:
        print("Usage: test_case_gen.py <path-to-kernels/test>")
        return
    test_dir = sys.argv[1]
    suffix = "_test.cpp"
    # os.listdir order is arbitrary; sort so generators run (and consume the
    # seeded RNG) in a stable order, keeping generated data reproducible.
    ops = sorted(
        f[: -len(suffix)]
        for f in os.listdir(test_dir)
        if f.startswith("op_") and f.endswith(suffix)
    )
    for op in ops:
        gen_test_cases_for_file(test_dir, op)


if __name__ == "__main__":
    main()