"""Helpers for registering PyTorch operator benchmark tests.

Every function here forwards to ``benchmark_core._register_test`` with
``create_pytorch_op_test_case`` as the test-case factory. The forward-only
variants pass ``False`` as the backward flag; the ``gradient`` variants pass
``True``.
"""

from benchmark_core import _register_test
from benchmark_pytorch import create_pytorch_op_test_case


def generate_pt_test(configs, pt_bench_op):
    """Create PyTorch op tests for ``pt_bench_op`` from ``configs``.

    Args:
        configs: the benchmark configurations to register tests for
            (e.g. the value returned by ``op_bench.config_list``).
        pt_bench_op: the benchmark class the tests are built from
            (e.g. an ``op_bench.TorchBenchmarkBase`` subclass).
    """
    _register_test(configs, pt_bench_op, create_pytorch_op_test_case, False)


def generate_pt_gradient_test(configs, pt_bench_op):
    """Create PyTorch op gradient tests for ``pt_bench_op`` from ``configs``.

    Identical to :func:`generate_pt_test` except that ``True`` is passed as
    the backward flag to ``_register_test`` (presumably so the backward pass
    is benchmarked too -- confirm against ``benchmark_core``).

    Args:
        configs: the benchmark configurations to register tests for.
        pt_bench_op: the benchmark class the tests are built from.
    """
    _register_test(configs, pt_bench_op, create_pytorch_op_test_case, True)


def _register_op_list_tests(ops_list, configs, pt_bench_op, run_backward):
    """Register one test per entry of ``ops_list``; shared by the op-list APIs."""
    for op in ops_list:
        _register_test(
            configs, pt_bench_op, create_pytorch_op_test_case, run_backward, op
        )


def generate_pt_tests_from_op_list(ops_list, configs, pt_bench_op):
    """Create PyTorch op tests one by one from a list of dictionaries.

    ``ops_list`` is a list of dictionaries. Each dictionary includes the name
    of the operator and the math operation. One registration is made per
    entry, each using the same ``configs`` and ``pt_bench_op``.

    Example::

        unary_ops_configs = op_bench.config_list(
            attrs=[...],
            attr_names=["M", "N"],
        )
        unary_ops_list = op_bench.op_list(
            attr_names=["op_name", "op_func"],
            attrs=[
                ["abs", torch.abs],
            ],
        )

        class UnaryOpBenchmark(op_bench.TorchBenchmarkBase):
            def init(self, M, N, op_name, op_func):
                ...

            def forward(self):
                ...

        op_bench.generate_pt_tests_from_op_list(
            unary_ops_list, unary_ops_configs, UnaryOpBenchmark
        )

    Args:
        ops_list: iterable of per-operator entries (e.g. from
            ``op_bench.op_list``).
        configs: the benchmark configurations shared by every operator.
        pt_bench_op: the benchmark class the tests are built from.
    """
    _register_op_list_tests(ops_list, configs, pt_bench_op, False)


def generate_pt_gradient_tests_from_op_list(ops_list, configs, pt_bench_op):
    """Create PyTorch op gradient tests one by one from a list of dictionaries.

    Same as :func:`generate_pt_tests_from_op_list`, but ``True`` is passed as
    the backward flag to ``_register_test`` for every entry (presumably
    enabling backward/gradient benchmarking -- confirm against
    ``benchmark_core``).

    Args:
        ops_list: iterable of per-operator entries (e.g. from
            ``op_bench.op_list``).
        configs: the benchmark configurations shared by every operator.
        pt_bench_op: the benchmark class the tests are built from.
    """
    _register_op_list_tests(ops_list, configs, pt_bench_op, True)