xref: /aosp_15_r20/external/pytorch/benchmarks/fastrnns/bench.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1import argparse
2import copy
3import gc
4import json
5import sys
6import time
7from collections import namedtuple
8
9import torch
10from torch.autograd.profiler import record_function
11
12from .fuser import set_fuser
13from .runner import get_nn_runners
14
15
# One benchmark run's aggregated timings: mean/stddev of the forward and
# backward passes, plus the raw per-iteration timing tensors ("info_*").
BenchResult = namedtuple(
    "BenchResult",
    "name avg_fwd std_fwd info_fwd avg_bwd std_bwd info_bwd",
)
28
29
def fit_str(string, colwidth=16):
    """Fit *string* into a fixed column: right-align when short, truncate when long."""
    return string[:colwidth].rjust(colwidth)
35
36
def to_str(item):
    """Render *item* for table output: floats get 4 significant digits, everything else str()."""
    return f"{item:.4g}" if isinstance(item, float) else str(item)
41
42
43def print_header(colwidth=16, sep=" "):
44    items = []
45    for item in BenchResult._fields:
46        items.append(fit_str(item))
47    return sep.join(items)
48
49
50def pretty_print(benchresult, colwidth=16, sep=" "):
51    items = []
52    for thing in benchresult:
53        items.append(fit_str(to_str(thing)))
54    return sep.join(items)
55
56
# CPU stand-in for torch.cuda.Event so the timing code below works on both devices.
class Event:
    def __init__(self, enable_timing):
        # `enable_timing` is accepted only for signature compatibility with
        # torch.cuda.Event; it has no effect here.
        pass

    def record(self):
        # Stamp the current wall-clock time.
        self._stamp = time.perf_counter()

    def elapsed_time(self, end_event):
        """Return the time elapsed between this event's record() and *end_event*'s.

        NOTE(review): torch.cuda.Event.elapsed_time reports milliseconds while
        perf_counter differences are seconds — confirm before comparing CPU
        and CUDA numbers directly.
        """
        assert isinstance(end_event, Event)
        return end_event._stamp - self._stamp
68
69
def trainbench(
    name,
    rnn_creator,
    nloops=100,
    warmup=10,
    seqLength=100,
    numLayers=1,
    inputSize=512,
    hiddenSize=512,
    miniBatch=64,
    device="cuda",
    seed=None,
):
    """Time ``nloops`` training iterations of the model built by *rnn_creator*.

    Args:
        name: label stored in the returned BenchResult.
        rnn_creator: callable returning a modeldef exposing ``forward``,
            ``inputs``, ``backward_setup``, ``backward`` and ``params``.
        nloops: number of timed iterations.
        warmup: untimed iterations run (and discarded) before timing.
        seqLength, numLayers, inputSize, hiddenSize, miniBatch, device, seed:
            forwarded verbatim to *rnn_creator*.

    Returns:
        BenchResult with mean/std and the raw per-iteration forward/backward
        timing tensors. Units are milliseconds on "cuda" (torch.cuda.Event)
        and perf_counter seconds on CPU (the Event shim above).
    """

    def train_batch(modeldef):
        # CUDA events for timing; fall back to the wall-clock Event shim on CPU.
        if device == "cuda":
            timer_class = torch.cuda.Event
        else:
            timer_class = Event

        fwd_start_event = timer_class(enable_timing=True)
        fwd_end_event = timer_class(enable_timing=True)
        bwd_start_event = timer_class(enable_timing=True)
        bwd_end_event = timer_class(enable_timing=True)

        # Collect garbage now so GC pauses don't land inside the timed region.
        gc.collect()

        fwd_start_event.record()
        with record_function("## forward ##"):
            forward_output = modeldef.forward(*modeldef.inputs)
        fwd_end_event.record()

        # XXX: Use if need to print something
        # print(modeldef.forward.graph_for(*modeldef.inputs))

        if modeldef.backward_setup is not None:
            backward_input = modeldef.backward_setup(forward_output)
        else:
            backward_input = forward_output

        gc.collect()

        bwd_start_event.record()
        if modeldef.backward is not None:
            modeldef.backward(*backward_input)
        bwd_end_event.record()

        if modeldef.backward is not None:
            # Zero gradients by hand so the next iteration starts fresh.
            with torch.no_grad():
                for param in modeldef.params:
                    assert param.grad is not None
                    param.grad.zero_()

        if device == "cuda":
            # CUDA events are recorded asynchronously; synchronize before
            # reading the elapsed times.
            torch.cuda.synchronize()

        fwd_time = fwd_start_event.elapsed_time(fwd_end_event)
        bwd_time = bwd_start_event.elapsed_time(bwd_end_event)
        return fwd_time, bwd_time

    # Fix: was `creator_args = creator_args = {...}` (duplicated assignment).
    creator_args = {
        "seqLength": seqLength,
        "numLayers": numLayers,
        "inputSize": inputSize,
        "hiddenSize": hiddenSize,
        "miniBatch": miniBatch,
        "device": device,
        "seed": seed,
    }

    modeldef = rnn_creator(**creator_args)

    # Warmup iterations: run for side effects only (was a list comprehension
    # whose result was discarded).
    for _ in range(warmup):
        train_batch(modeldef)

    results = [train_batch(modeldef) for _ in range(nloops)]
    fwd_times, bwd_times = zip(*results)

    fwd_times = torch.tensor(fwd_times)
    bwd_times = torch.tensor(bwd_times)
    return BenchResult(
        name=name,
        avg_fwd=fwd_times.mean().item(),
        std_fwd=fwd_times.std().item(),
        info_fwd=fwd_times,
        avg_bwd=bwd_times.mean().item(),
        std_bwd=bwd_times.std().item(),
        info_bwd=bwd_times,
    )
158
159
def print_stderr(*args, **kwargs):
    """print() that always goes to sys.stderr (any caller-supplied `file` is overridden)."""
    return print(*args, **{**kwargs, "file": sys.stderr})
163
164
def print_json_oss_format(results):
    """Print one JSON object mapping group name -> model name -> average runtime.

    Input shape: {group: {model: {"avg": ..., "std": ..., "info": ...}}};
    only the "avg" entry is emitted (OSS output format).
    """
    oss_results = {
        group_name: {model: stats["avg"] for model, stats in group.items()}
        for group_name, group in results.items()
    }
    print(json.dumps(oss_results))
174
175
def print_json_pep_format(results):
    """Print one "Caffe2Observer <json>" line per timed iteration (AI-PEP format).

    Input shape: {group: {model: {"info": tensor_of_times, ...}}}; each element
    of the "info" tensor becomes its own observer line.
    """
    for group_name, group in results.items():
        for model_name, run_time in group.items():
            # One line per recorded iteration.
            for value in run_time["info"].tolist():
                payload = {
                    "type": "NET",
                    "metric": group_name + "-" + model_name,
                    "unit": "ms",
                    "value": str(value),
                }
                print("Caffe2Observer " + json.dumps(payload))
195
196
197def bench(rnn_runners, group_name, print_json=False, sep=" ", **params):
198    print_stderr(print_header(sep=sep))
199    results = {}
200    for name, creator, context in rnn_runners:
201        with context():
202            try:
203                result = trainbench(name, creator, **params)
204                # Replace the value of info_fwd and info_bwd to None
205                result_with_no_info = result._replace(info_fwd="None", info_bwd="None")
206                print_stderr(pretty_print(result_with_no_info, sep=sep))
207                results[name] = result
208            except Exception as e:
209                if not print_json:
210                    raise
211
212    return {
213        group_name: {
214            k: {"avg": v.avg_fwd, "std": v.std_fwd, "info": v.info_fwd}
215            for k, v in results.items()
216        },
217        group_name
218        + "-backward": {
219            k: {"avg": v.avg_bwd, "std": v.std_bwd, "info": v.info_bwd}
220            for k, v in results.items()
221        },
222    }
223
224
def bench_group(model_list, bench_name, bench_group, bench_args):
    """Benchmark one family of models and return its results dict.

    *bench_name* is the human-readable label printed to stderr; *bench_group*
    is the key used for the result groups.
    """
    print_stderr(f"Benchmarking {bench_name}s...")
    group_results = bench(get_nn_runners(*model_list), bench_group, **bench_args)
    print_stderr("")
    return group_results
230
231
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Profile RNNs")

    # Groups control which benchmark families run. To run just one or two,
    # combine the model and group flags, e.g.:
    #   python -m fastrnns.bench --rnns jit --group rnns
    default_groups = ["cnns", "rnns"]

    # NOTE: string defaults like "100" are fine here — argparse applies
    # type=int to string defaults as well.
    parser.add_argument("--seqLength", default="100", type=int)
    parser.add_argument("--numLayers", default="1", type=int)
    parser.add_argument("--inputSize", default="512", type=int)
    parser.add_argument("--hiddenSize", default="512", type=int)
    parser.add_argument("--miniBatch", default="64", type=int)
    parser.add_argument("--warmup", default="10", type=int)
    parser.add_argument("--nloops", default="100", type=int)
    parser.add_argument("--device", default="cuda", type=str)
    parser.add_argument(
        "--variable-lstms",
        "--variable_lstms",
        action="store_true",
        help="Also benchmark variable sequence length lstms "
        "Note that some of these run really slowly "
        "and that the `seqLength` flag will be ignored.",
    )
    parser.add_argument("--sep", default=" ", type=str)
    # Bare `--print-json` means "oss"; `--print-json pep` selects the AI-PEP
    # format (see the dispatch at the bottom of this block).
    parser.add_argument("--print-json", nargs="?", default=None, const="oss")
    parser.add_argument("--rnns", nargs="*", help="What to run. cudnn, aten, jit, etc")
    parser.add_argument(
        "--cnns", nargs="*", help="What to run. resnet18, resnet18_jit, resnet50, etc"
    )
    parser.add_argument(
        "--group",
        nargs="*",
        default=default_groups,
        help="Which group to run. cnns, rnns, etc.",
    )
    parser.add_argument(
        "--fuser",
        default="te",
        type=str,
        help="The fuser backend to use. One of: te, old, or none",
    )
    parser.add_argument(
        "--executor",
        default=None,
        type=str,
        help="The executor to use. One of: legacy, simple, profiling",
    )
    parser.add_argument(
        "--cuda-pointwise-loop-level",
        "--cuda_pointwise_loop_level",
        default=None,
        type=int,
    )
    parser.add_argument(
        "--cuda-pointwise-block-count",
        "--cuda_pointwise_block_count",
        default=None,
        type=int,
    )
    parser.add_argument(
        "--cuda-pointwise-block-size",
        "--cuda_pointwise_block_size",
        default=None,
        type=int,
    )

    args = parser.parse_args()
    set_fuser(args.fuser, args.executor)

    # NOTE: truthiness checks, so an explicit 0 is treated like "unset".
    if args.cuda_pointwise_loop_level:
        torch._C._jit_set_te_cuda_pointwise_loop_levels(args.cuda_pointwise_loop_level)
    if args.cuda_pointwise_block_count:
        torch._C._jit_set_te_cuda_pointwise_block_count(args.cuda_pointwise_block_count)
    if args.cuda_pointwise_block_size:
        torch._C._jit_set_te_cuda_pointwise_block_size(args.cuda_pointwise_block_size)

    # Model lists: fall back to the full default sets when the flag is omitted
    # (or given with no values — an empty list is falsy).
    rnns = args.rnns or [
        "cudnn",
        "aten",
        "jit",
        "jit_premul",
        "jit_premul_bias",
        "jit_simple",
        "jit_multilayer",
        "py",
    ]
    cnns = args.cnns or ["resnet18", "resnet18_jit", "resnet50", "resnet50_jit"]
    # TODO: Maybe add a separate section for the layernorm/dropout lstms
    # 'cudnn_layernorm', jit_layernorm', 'jit_layernom_decom',
    # 'jit', 'jit_dropout', 'cudnn_dropout'
    vlrnns = ["vl_cudnn", "vl_jit", "vl_py"]

    # In JSON mode, silence all progress/table output by rebinding the
    # module-level print_stderr to a no-op (bench() calls it too).
    if args.print_json:
        print_stderr = lambda *args, **kwargs: None  # noqa: E731,F811
    print_stderr(args)

    # trainbench() receives bench_args via **kwargs, so strip every flag it
    # does not accept before passing the dict through bench()/bench_group().
    bench_args = copy.deepcopy(vars(args))
    should_bench_varlen_lstms = args.variable_lstms
    del bench_args["group"]
    del bench_args["rnns"]
    del bench_args["cnns"]
    del bench_args["variable_lstms"]
    del bench_args["fuser"]
    del bench_args["executor"]
    del bench_args["cuda_pointwise_loop_level"]
    del bench_args["cuda_pointwise_block_count"]
    del bench_args["cuda_pointwise_block_size"]

    results = {}
    if should_bench_varlen_lstms:
        if args.nloops + args.warmup > 30:
            print_stderr(
                "WARNING: some of the variable sequence length lstms are "
                "very unoptimized and therefore take forever to run."
            )
        results.update(
            bench_group(vlrnns, "variable-length sequence LSTM", "vl_lstm", bench_args)
        )

    if "rnns" in args.group:
        results.update(bench_group(rnns, "LSTM", "lstm", bench_args))
    if "cnns" in args.group:
        results.update(bench_group(cnns, "ResNet", "resnet", bench_args))

    # Emit the accumulated results in the requested JSON flavor (if any).
    if args.print_json == "oss":
        print_json_oss_format(results)
    elif args.print_json == "pep":
        print_json_pep_format(results)
361