import operator_benchmark as op_bench

import torch

# Microbenchmarks for quantized comparator operators (eq, ne, lt, gt, le, ge).

qcomparators_configs = op_bench.cross_product_configs(
    N=(8, 64),
    dtype=(torch.quint8, torch.qint8, torch.qint32),
    contig=(False, True),
    other_scalar=(False, True),
    out_variant=(False, True),
    tags=("short",),
)

qcomparators_ops = op_bench.op_list(
    attrs=(
        ("eq", torch.eq),
        ("ne", torch.ne),
        ("lt", torch.lt),
        ("gt", torch.gt),
        ("le", torch.le),
        ("ge", torch.ge),
    ),
    attr_names=("op_name", "op_func"),
)


class QComparatorBenchmark(op_bench.TorchBenchmarkBase):
    def init(self, N, dtype, contig, other_scalar, out_variant, op_func):
        # TODO: Consider more diverse shapes
        f_input = (torch.rand(N, N) - 0.5) * 256
        scale = 1.0
        zero_point = 0

        q_input_a = torch.quantize_per_tensor(
            f_input, scale=scale, zero_point=zero_point, dtype=dtype
        )
        q_input_b = q_input_a.clone()

        if not contig:
            # Reverse the dimension order to get a non-contiguous view of the
            # same data, so the benchmark also covers strided inputs.
            permute_dims = list(range(f_input.ndim))[::-1]
            q_input_a = q_input_a.permute(permute_dims)

        self.qop = op_func
        self.inputs = {
            "q_input_a": q_input_a,
            "q_input_b": q_input_b,
            "out_variant": out_variant,
            "other_scalar": other_scalar,
        }

    def forward(self, q_input_a, q_input_b, out_variant: bool, other_scalar: bool):
        # The out= variant writes into a bool tensor, which the op resizes to
        # the result shape; the scalar variant compares against the literal 42.
        if out_variant:
            if other_scalar:
                return self.qop(q_input_a, 42, out=torch.tensor(True, dtype=torch.bool))
            else:
                return self.qop(
                    q_input_a, q_input_b, out=torch.tensor(True, dtype=torch.bool)
                )
        else:
            if other_scalar:
                return self.qop(q_input_a, 42)
            else:
                return self.qop(q_input_a, q_input_b)


op_bench.generate_pt_tests_from_op_list(
    qcomparators_ops, qcomparators_configs, QComparatorBenchmark
)


if __name__ == "__main__":
    op_bench.benchmark_runner.main()
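# Usage sketch (assumptions, not part of the original file): with this file
# saved under PyTorch's benchmarks/operator_benchmark tree (e.g. as
# pt/qcomparators_test.py), benchmark_runner.main() parses CLI flags such as
# --tag_filter to select which tagged configs to run; the exact flag set
# depends on your PyTorch checkout:
#
#   python -m pt.qcomparators_test --tag_filter short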