import argparse
import itertools
import os

# from . import conv           # noqa: F401
# from . import normalization  # noqa: F401
# from . import pooling        # noqa: F401
from . import (  # noqa: F401
    attention,
    benchmark,
    broadcast,
    concat,
    elementwise,
    matmul,
    reduction,
    rnn_eltwise,
    softmax,
    swish,
    tensor_engine,
)


def main():
    parser = argparse.ArgumentParser(
        formatter_class=argparse.RawDescriptionHelpFormatter,
        description="""Benchmark operators in specific shapes.
Works only with Python 3.
A few examples:
  * benchmark.py: runs all the benchmarks with all the default configs.
  * benchmark.py reduce: runs all benchmarks whose names contain 'reduce' with all the default configs.
  * benchmark.py layernorm_fwd_cpu_128_32_128_128: runs a particular benchmark in that config.""",
    )
    parser.add_argument(
        "benchmark_names",
        type=str,
        default=None,
        nargs="*",
        help="names of the benchmarks to run",
    )
    parser.add_argument(
        "--device",
        type=str,
        default="cpu,cuda",
        help="a comma-separated list of device names",
    )
    parser.add_argument(
        "--mode",
        type=str,
        default="fwd,both",
        help="a comma-separated list of running modes",
    )
    parser.add_argument(
        "--dtype",
        type=str,
        default="float32",
        help="a comma-separated list of data types: {float32[default], float16}",
    )
    parser.add_argument(
        "--input-iter",
        type=str,
        default=None,
        help="a comma-separated list of tensor dimension specs, each given as \
              start:stop:inc, where inc is a constant increment, 'pow2', or 'pow2+1' \
              {start:stop:inc,start:stop:pow2}",
    )
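    # For example, "--input-iter 1:1024:pow2" sweeps a dimension over
    # 1, 2, 4, ..., 1024 (see run_with_input_iter below for the exact expansion).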
    parser.add_argument(
        "--engine",
        type=str,
        default="pt",
        help="the underlying tensor engine; only 'pt' is supported for now",
    )
    parser.add_argument(
        "--jit-mode",
        "--jit_mode",
        type=str,
        default="trace",
        help="the JIT mode to use: one of {trace, none}",
    )
    parser.add_argument(
        "--cuda-pointwise-loop-levels",
        "--cuda_pointwise_loop_levels",
        type=int,
        default=None,
        help="number of loop levels for CUDA pointwise operations: 2 or 3",
    )
    parser.add_argument(
        "--cuda-pointwise-block-count",
        "--cuda_pointwise_block_count",
        type=int,
        default=None,
        help="number of blocks for CUDA pointwise operations",
    )
    parser.add_argument(
        "--cuda-pointwise-block-size",
        "--cuda_pointwise_block_size",
        type=int,
        default=None,
        help="block size for CUDA pointwise operations",
    )
    parser.add_argument(
        "--cuda-fuser",
        "--cuda_fuser",
        type=str,
        default="te",
        help="The CUDA fuser backend to use: one of {te, nvf, old, none}",
    )
    parser.add_argument(
        "--output",
        type=str,
        default="stdout",
        help="The output format of the benchmark run: {stdout[default], json}",
    )
    parser.add_argument(
        "--print-ir",
        action="store_true",
        help="Print the IR graph of the fusion.",
    )
    parser.add_argument(
        "--print-kernel",
        action="store_true",
        help="Print generated kernel(s).",
    )
    parser.add_argument(
        "--no-dynamic-shape",
        action="store_true",
        help="Disable shape randomization in dynamic benchmarks.",
    )
    parser.add_argument(
        "--cpu-fusion",
        "--cpu_fusion",
        default=False,
        action="store_true",
        help="Enable CPU fusion.",
    )
    parser.add_argument(
        "--cat-wo-conditionals",
        "--cat_wo_conditionals",
        default=False,
        action="store_true",
        help="Enable lowering of cat (concatenation) without conditionals.",
    )

    args = parser.parse_args()

    import torch

    if args.cuda_fuser == "te":
        torch._C._jit_set_profiling_executor(True)
        torch._C._jit_set_texpr_fuser_enabled(True)
        torch._C._jit_override_can_fuse_on_gpu(True)
        # note: passing a value to _get_graph_executor_optimize also sets the flag
        torch._C._get_graph_executor_optimize(True)
    elif args.cuda_fuser == "old":
        torch._C._jit_set_profiling_executor(False)
        torch._C._jit_set_texpr_fuser_enabled(False)
        torch._C._jit_override_can_fuse_on_gpu(True)
    elif args.cuda_fuser == "nvf":
        torch._C._jit_set_profiling_executor(True)
        torch._C._jit_set_texpr_fuser_enabled(False)
        torch._C._jit_set_nvfuser_enabled(True)
        torch._C._get_graph_executor_optimize(True)
    elif args.cuda_fuser == "none":
        # 'none' is advertised in the help string; leave the fuser settings
        # at their defaults rather than raising
        pass
    else:
        raise ValueError(f"Undefined fuser: {args.cuda_fuser}")

    torch._C._jit_override_can_fuse_on_cpu(args.cpu_fusion)
    torch._C._jit_cat_wo_conditionals(args.cat_wo_conditionals)

    def set_global_threads(num_threads):
        os.environ["OMP_NUM_THREADS"] = str(num_threads)
        os.environ["MKL_NUM_THREADS"] = str(num_threads)
        os.environ["TVM_NUM_THREADS"] = str(num_threads)
        os.environ["NNC_NUM_THREADS"] = str(num_threads)

    devices = args.device.split(",")
    # accept 'gpu' as an alias for the 'cuda' device
    devices = ["cuda" if device == "gpu" else device for device in devices]
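    # A device such as 'cpu4' requests 4 threads via the environment variables
    # set in set_global_threads; note that these variables may be ignored if the
    # corresponding runtime has already initialized its thread pool by the time
    # they are set.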
    cpu_count = 0
    for index, device in enumerate(devices):
        if device.startswith("cpu"):
            cpu_count += 1
            if cpu_count > 1:
                raise ValueError(f"more than one CPU device is not allowed: {cpu_count}")
            if device == "cpu":
                continue
            num_threads_str = device[3:]
            try:
                # see if the device is in the 'cpu1' or 'cpu4' format
                num_threads = int(num_threads_str)
                set_global_threads(num_threads)
                devices[index] = "cpu"
            except ValueError:
                continue

    modes = args.mode.split(",")

    datatypes = args.dtype.split(",")
    for index, dtype in enumerate(datatypes):
        datatypes[index] = getattr(torch, dtype, None)
        if not datatypes[index]:
            raise AttributeError(f"DataType: {dtype} is not valid!")

    tensor_engine.set_engine_mode(args.engine)

    def run_default_configs(bench_cls, allow_skip=True):
        for mode, device, dtype, config in itertools.product(
            modes, devices, datatypes, bench_cls.default_configs()
        ):
            bench = bench_cls(mode, device, dtype, *config)
            bench.output_type = args.output
            bench.jit_mode = args.jit_mode
            if not bench.is_supported():
                if allow_skip:
                    continue
                else:
                    raise ValueError(
                        f"attempted to run an unsupported benchmark: {bench.desc()}"
                    )
            bench.run(args)

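    # Expand an --input-iter spec of the form start:stop:inc into explicit
    # dimension lists and run the benchmark over their cross product.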
    def run_with_input_iter(bench_cls, input_iter, allow_skip=True):
        tensor_dim_specs = input_iter.split(",")
        tensor_dim_specs = [dim.split(":") for dim in tensor_dim_specs]

        configs = []
        for start, stop, inc in tensor_dim_specs:
            dim_list = []
            if inc == "pow2":
                curr = int(start)
                while curr <= int(stop):
                    dim_list.append(curr)
                    curr <<= 1
            elif inc == "pow2+1":
                # generates curr -> 2 * (curr - 1) + 1, i.e. values of the
                # form 2^k + 1 when starting from such a value
                curr = int(start)
                while curr <= int(stop):
                    dim_list.append(curr)
                    curr -= 1
                    curr <<= 1
                    curr += 1
            else:
                dim_list = list(range(int(start), int(stop) + int(inc), int(inc)))
            configs.append(dim_list)
        configs = itertools.product(*configs)
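        # e.g. "1:8:pow2" expands to [1, 2, 4, 8] and "2:9:pow2+1" to [2, 3, 5, 9]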

        for mode, device, dtype, config in itertools.product(
            modes, devices, datatypes, list(configs)
        ):
            bench = bench_cls(mode, device, dtype, *config)
            bench.output_type = args.output
            bench.jit_mode = args.jit_mode
            if not bench.is_supported():
                if allow_skip:
                    continue
                else:
                    raise ValueError(
                        f"attempted to run an unsupported benchmark: {bench.desc()}"
                    )
            bench.run(args)

    benchmark_classes = benchmark.benchmark_classes
    if not args.benchmark_names:
        # by default, run all the benchmarks
        for benchmark_cls in benchmark_classes:
            run_default_configs(benchmark_cls, allow_skip=True)
    else:
        for name in args.benchmark_names:
            # if the name is a substring of a benchmark class module, run all the benchmarks for that class
            match_class_name = False
            for bench_cls in benchmark_classes:
                if name in bench_cls.module():
                    match_class_name = True
                    if (args.input_iter is not None) and bench_cls.input_iterable():
                        run_with_input_iter(bench_cls, args.input_iter, allow_skip=True)
                    else:
                        if args.input_iter is not None:
                            print(
                                f"WARNING: Incompatible benchmark class called with input_iter arg: {name}"
                            )
                        run_default_configs(bench_cls, allow_skip=True)

            if match_class_name:
                continue

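            # A fully specified benchmark name has the form
            # <module>_<mode>_<device>_<config...>, e.g. layernorm_fwd_cpu_128_32_128_128.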
            # if the name is not a class module, parse the config and call it that way
            match_class_name = False
            for bench_cls in benchmark_classes:
                cls_module = bench_cls.module()
                if name.startswith(cls_module):
                    match_class_name = True
                    if name[len(cls_module)] != "_":
                        raise ValueError(f"invalid name: {name}")
                    config_str = name[(len(cls_module) + 1) :]
                    config = config_str.split("_")
                    if len(config) < 2:
                        raise ValueError(f"invalid config: {config}")
                    mode, device = config[0:2]
                    # TODO: make sure virtual devices such as 'cpu1' and 'cpu4' are supported.
                    if mode not in ["fwd", "both"]:
                        raise ValueError(f"invalid mode: {mode}")
                    for i, entry in enumerate(config):
                        try:
                            value = int(entry)
                            config[i] = value
                        except ValueError:
                            pass
                    # TODO: output dtype in the config and parse it back from the str
                    bench = bench_cls(config[0], config[1], torch.float32, *config[2:])
                    bench.jit_mode = args.jit_mode
                    bench.output_type = args.output
                    bench.run(args)

            if not match_class_name:
                available_classes = ", ".join(
                    [bench_cls.module() for bench_cls in benchmark_classes]
                )
                raise ValueError(
                    f"invalid name: {name}\nAvailable benchmark classes:\n{available_classes}"
                )

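# A few example invocations, assuming this package is run as a module from the
# repository root (the exact entry point may differ in your checkout):
#   python -m benchmarks.tensorexpr --device cpu --mode fwd
#   python -m benchmarks.tensorexpr reduce
#   python -m benchmarks.tensorexpr layernorm_fwd_cpu_128_32_128_128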

if __name__ == "__main__":
    main()