xref: /aosp_15_r20/external/pytorch/test/benchmark_utils/test_benchmark_utils.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# Owner(s): ["module: unknown"]
2
3import collections
4import json
5import os
6import re
7import textwrap
8import timeit
9import unittest
10from typing import Any, List, Tuple
11
12import expecttest
13import numpy as np
14
15import torch
16import torch.utils.benchmark as benchmark_utils
17from torch.testing._internal.common_utils import (
18    IS_SANDCASTLE,
19    IS_WINDOWS,
20    run_tests,
21    slowTest,
22    TEST_WITH_ASAN,
23    TestCase,
24)
25
26
# Absolute path of the JSON file holding the recorded Callgrind counts. It is
# resolved relative to this module's own directory, not the working directory.
CALLGRIND_ARTIFACTS: str = os.path.join(
    os.path.dirname(os.path.abspath(__file__)),
    "callgrind_artifacts.json",
)
30
31
def generate_callgrind_artifacts() -> None:
    """Regenerate `callgrind_artifacts.json`.

    Collects Callgrind counts for `torch.ones(())` and `torch.ones((1,))`
    (1000 iterations each) and writes the baseline and statement counts, in
    both inclusive and exclusive form, to `CALLGRIND_ARTIFACTS` as JSON.

    This is deliberately a manual step rather than part of the expect-test
    flow: the instruction strings embed build and conda/pip directories, so a
    regeneration yields a large diff; Python jitter makes the counts not fully
    deterministic; and a run takes over a minute.
    """
    print("Regenerating callgrind artifact.")

    def collect(stmt):
        # Both measurements use the same iteration count.
        return benchmark_utils.Timer(stmt).collect_callgrind(number=1000)

    stats_no_data = collect("y = torch.ones(())")
    stats_with_data = collect("y = torch.ones((1,))")

    # Path component identifying the current user; it is swapped for a fixed
    # placeholder so the stored artifact does not depend on who generated it.
    user_fragment = f"/{os.getenv('USER')}/"

    def to_entry(fn_counts):
        entries = []
        for count, fn in fn_counts:
            scrubbed = fn.replace(user_fragment, "/test_user/")
            entries.append(f"{count} {scrubbed}")
        return entries

    baseline_inclusive = to_entry(stats_no_data.baseline_inclusive_stats)
    baseline_exclusive = to_entry(stats_no_data.baseline_exclusive_stats)
    no_data_inclusive = to_entry(stats_no_data.stmt_inclusive_stats)
    no_data_exclusive = to_entry(stats_no_data.stmt_exclusive_stats)
    with_data_inclusive = to_entry(stats_with_data.stmt_inclusive_stats)
    with_data_exclusive = to_entry(stats_with_data.stmt_exclusive_stats)

    artifacts = {
        "baseline_inclusive": baseline_inclusive,
        "baseline_exclusive": baseline_exclusive,
        "ones_no_data_inclusive": no_data_inclusive,
        "ones_no_data_exclusive": no_data_exclusive,
        "ones_with_data_inclusive": with_data_inclusive,
        "ones_with_data_exclusive": with_data_exclusive,
    }

    with open(CALLGRIND_ARTIFACTS, "w") as artifact_file:
        json.dump(artifacts, artifact_file, indent=4)
67
68
def load_callgrind_artifacts() -> (
    Tuple[benchmark_utils.CallgrindStats, benchmark_utils.CallgrindStats]
):
    """Rebuild the hermetic Callgrind fixtures used by the unit tests.

    Besides collecting counts, the Callgrind wrapper offers helpers for
    manipulating and displaying them. Those helpers are tested against
    measurements recorded in `callgrind_artifacts.json`.

    FunctionCounts and CallgrindStats can be pickled, but the artifact keeps
    plain "<count> <function>" strings instead: they are easier to inspect and
    do not bake implementation details into the stored file.

    Returns:
        A `(stats_no_data, stats_with_data)` pair for the statements
        `y = torch.ones(())` and `y = torch.ones((1,))`. Both share the same
        baseline counts.
    """
    with open(CALLGRIND_ARTIFACTS) as artifact_file:
        raw = json.load(artifact_file)

    entry_re = re.compile(r"^\s*([0-9]+)\s(.+)$")

    def parse_entry(entry: str) -> benchmark_utils.FunctionCount:
        # Entries are stored as f"{c} {fn}" rather than [c, fn]. Reviving them
        # takes a little parsing, but the json is far more readable that way.
        parsed = entry_re.search(entry)
        assert parsed is not None
        count, function = parsed.groups()
        return benchmark_utils.FunctionCount(count=int(count), function=function)

    def to_function_counts(
        count_strings: List[str], inclusive: bool
    ) -> benchmark_utils.FunctionCounts:
        ordered = [parse_entry(cs) for cs in count_strings]
        ordered.sort(reverse=True)
        return benchmark_utils.FunctionCounts(tuple(ordered), inclusive=inclusive)

    baseline_inclusive = to_function_counts(raw["baseline_inclusive"], True)
    baseline_exclusive = to_function_counts(raw["baseline_exclusive"], False)

    def build_stats(
        stmt: str, inclusive_key: str, exclusive_key: str
    ) -> benchmark_utils.CallgrindStats:
        # Only the statement and its two count tables differ between fixtures.
        return benchmark_utils.CallgrindStats(
            benchmark_utils.TaskSpec(stmt, "pass"),
            number_per_run=1000,
            built_with_debug_symbols=True,
            baseline_inclusive_stats=baseline_inclusive,
            baseline_exclusive_stats=baseline_exclusive,
            stmt_inclusive_stats=to_function_counts(raw[inclusive_key], True),
            stmt_exclusive_stats=to_function_counts(raw[exclusive_key], False),
            stmt_callgrind_out=None,
        )

    stats_no_data = build_stats(
        "y = torch.ones(())",
        "ones_no_data_inclusive",
        "ones_no_data_exclusive",
    )
    stats_with_data = build_stats(
        "y = torch.ones((1,))",
        "ones_with_data_inclusive",
        "ones_with_data_exclusive",
    )

    return stats_no_data, stats_with_data
137
138
class MyModule(torch.nn.Module):
    """Minimal module whose forward pass returns its input plus one.

    It lives at module level so that it can be imported by name
    (`from test_benchmark_utils import MyModule`) in setup code.
    """

    def forward(self, x):
        incremented = x + 1
        return incremented
142
143
144class TestBenchmarkUtils(TestCase):
145    def regularizeAndAssertExpectedInline(
146        self, x: Any, expect: str, indent: int = 12
147    ) -> None:
148        x_str: str = re.sub(
149            "object at 0x[0-9a-fA-F]+>",
150            "object at 0xXXXXXXXXXXXX>",
151            x if isinstance(x, str) else repr(x),
152        )
153        if "\n" in x_str:
154            # Indent makes the reference align at the call site.
155            x_str = textwrap.indent(x_str, " " * indent)
156
157        self.assertExpectedInline(x_str, expect, skip=1)
158
159    def test_timer(self):
160        timer = benchmark_utils.Timer(
161            stmt="torch.ones(())",
162        )
163        sample = timer.timeit(5).median
164        self.assertIsInstance(sample, float)
165
166        median = timer.blocked_autorange(min_run_time=0.01).median
167        self.assertIsInstance(median, float)
168
169        # We set a very high threshold to avoid flakiness in CI.
170        # The internal algorithm is tested in `test_adaptive_timer`
171        median = timer.adaptive_autorange(threshold=0.5).median
172
173        # Test that multi-line statements work properly.
174        median = (
175            benchmark_utils.Timer(
176                stmt="""
177                with torch.no_grad():
178                    y = x + 1""",
179                setup="""
180                x = torch.ones((1,), requires_grad=True)
181                for _ in range(5):
182                    x = x + 1.0""",
183            )
184            .timeit(5)
185            .median
186        )
187        self.assertIsInstance(sample, float)
188
189    @slowTest
190    @unittest.skipIf(IS_SANDCASTLE, "C++ timing is OSS only.")
191    @unittest.skipIf(True, "Failing on clang, see 74398")
192    def test_timer_tiny_fast_snippet(self):
193        timer = benchmark_utils.Timer(
194            "auto x = 1;(void)x;",
195            timer=timeit.default_timer,
196            language=benchmark_utils.Language.CPP,
197        )
198        median = timer.blocked_autorange().median
199        self.assertIsInstance(median, float)
200
201    @slowTest
202    @unittest.skipIf(IS_SANDCASTLE, "C++ timing is OSS only.")
203    @unittest.skipIf(True, "Failing on clang, see 74398")
204    def test_cpp_timer(self):
205        timer = benchmark_utils.Timer(
206            """
207                #ifndef TIMER_GLOBAL_CHECK
208                static_assert(false);
209                #endif
210
211                torch::Tensor y = x + 1;
212            """,
213            setup="torch::Tensor x = torch::empty({1});",
214            global_setup="#define TIMER_GLOBAL_CHECK",
215            timer=timeit.default_timer,
216            language=benchmark_utils.Language.CPP,
217        )
218        t = timer.timeit(10)
219        self.assertIsInstance(t.median, float)
220
221    class _MockTimer:
222        _seed = 0
223
224        _timer_noise_level = 0.05
225        _timer_cost = 100e-9  # 100 ns
226
227        _function_noise_level = 0.05
228        _function_costs = (
229            ("pass", 8e-9),
230            ("cheap_fn()", 4e-6),
231            ("expensive_fn()", 20e-6),
232            ("with torch.no_grad():\n    y = x + 1", 10e-6),
233        )
234
235        def __init__(self, stmt, setup, timer, globals):
236            self._random_state = np.random.RandomState(seed=self._seed)
237            self._mean_cost = dict(self._function_costs)[stmt]
238
239        def sample(self, mean, noise_level):
240            return max(self._random_state.normal(mean, mean * noise_level), 5e-9)
241
242        def timeit(self, number):
243            return sum(
244                [
245                    # First timer invocation
246                    self.sample(self._timer_cost, self._timer_noise_level),
247                    # Stmt body
248                    self.sample(self._mean_cost * number, self._function_noise_level),
249                    # Second timer invocation
250                    self.sample(self._timer_cost, self._timer_noise_level),
251                ]
252            )
253
254    def test_adaptive_timer(self):
255        class MockTimer(benchmark_utils.Timer):
256            _timer_cls = self._MockTimer
257
258        class _MockCudaTimer(self._MockTimer):
259            # torch.cuda.synchronize is much more expensive than
260            # just timeit.default_timer
261            _timer_cost = 10e-6
262
263            _function_costs = (
264                self._MockTimer._function_costs[0],
265                self._MockTimer._function_costs[1],
266                # GPU should be faster once there is enough work.
267                ("expensive_fn()", 5e-6),
268            )
269
270        class MockCudaTimer(benchmark_utils.Timer):
271            _timer_cls = _MockCudaTimer
272
273        m = MockTimer("pass").blocked_autorange(min_run_time=10)
274        self.regularizeAndAssertExpectedInline(
275            m,
276            """\
277            <torch.utils.benchmark.utils.common.Measurement object at 0xXXXXXXXXXXXX>
278            pass
279              Median: 7.98 ns
280              IQR:    0.52 ns (7.74 to 8.26)
281              125 measurements, 10000000 runs per measurement, 1 thread""",
282        )
283
284        self.regularizeAndAssertExpectedInline(
285            MockTimer("pass").adaptive_autorange(),
286            """\
287            <torch.utils.benchmark.utils.common.Measurement object at 0xXXXXXXXXXXXX>
288            pass
289              Median: 7.86 ns
290              IQR:    0.71 ns (7.63 to 8.34)
291              6 measurements, 1000000 runs per measurement, 1 thread""",
292        )
293
294        # Check against strings so we can reuse expect infra.
295        self.regularizeAndAssertExpectedInline(m.mean, """8.0013658357956e-09""")
296        self.regularizeAndAssertExpectedInline(m.median, """7.983151323215967e-09""")
297        self.regularizeAndAssertExpectedInline(len(m.times), """125""")
298        self.regularizeAndAssertExpectedInline(m.number_per_run, """10000000""")
299
300        self.regularizeAndAssertExpectedInline(
301            MockTimer("cheap_fn()").blocked_autorange(min_run_time=10),
302            """\
303            <torch.utils.benchmark.utils.common.Measurement object at 0xXXXXXXXXXXXX>
304            cheap_fn()
305              Median: 3.98 us
306              IQR:    0.27 us (3.85 to 4.12)
307              252 measurements, 10000 runs per measurement, 1 thread""",
308        )
309
310        self.regularizeAndAssertExpectedInline(
311            MockTimer("cheap_fn()").adaptive_autorange(),
312            """\
313            <torch.utils.benchmark.utils.common.Measurement object at 0xXXXXXXXXXXXX>
314            cheap_fn()
315              Median: 4.16 us
316              IQR:    0.22 us (4.04 to 4.26)
317              4 measurements, 1000 runs per measurement, 1 thread""",
318        )
319
320        self.regularizeAndAssertExpectedInline(
321            MockTimer("expensive_fn()").blocked_autorange(min_run_time=10),
322            """\
323            <torch.utils.benchmark.utils.common.Measurement object at 0xXXXXXXXXXXXX>
324            expensive_fn()
325              Median: 19.97 us
326              IQR:    1.35 us (19.31 to 20.65)
327              501 measurements, 1000 runs per measurement, 1 thread""",
328        )
329
330        self.regularizeAndAssertExpectedInline(
331            MockTimer("expensive_fn()").adaptive_autorange(),
332            """\
333            <torch.utils.benchmark.utils.common.Measurement object at 0xXXXXXXXXXXXX>
334            expensive_fn()
335              Median: 20.79 us
336              IQR:    1.09 us (20.20 to 21.29)
337              4 measurements, 1000 runs per measurement, 1 thread""",
338        )
339
340        self.regularizeAndAssertExpectedInline(
341            MockCudaTimer("pass").blocked_autorange(min_run_time=10),
342            """\
343            <torch.utils.benchmark.utils.common.Measurement object at 0xXXXXXXXXXXXX>
344            pass
345              Median: 7.92 ns
346              IQR:    0.43 ns (7.75 to 8.17)
347              13 measurements, 100000000 runs per measurement, 1 thread""",
348        )
349
350        self.regularizeAndAssertExpectedInline(
351            MockCudaTimer("pass").adaptive_autorange(),
352            """\
353            <torch.utils.benchmark.utils.common.Measurement object at 0xXXXXXXXXXXXX>
354            pass
355              Median: 7.75 ns
356              IQR:    0.57 ns (7.56 to 8.13)
357              4 measurements, 10000000 runs per measurement, 1 thread""",
358        )
359
360        self.regularizeAndAssertExpectedInline(
361            MockCudaTimer("cheap_fn()").blocked_autorange(min_run_time=10),
362            """\
363            <torch.utils.benchmark.utils.common.Measurement object at 0xXXXXXXXXXXXX>
364            cheap_fn()
365              Median: 4.04 us
366              IQR:    0.30 us (3.90 to 4.19)
367              25 measurements, 100000 runs per measurement, 1 thread""",
368        )
369
370        self.regularizeAndAssertExpectedInline(
371            MockCudaTimer("cheap_fn()").adaptive_autorange(),
372            """\
373            <torch.utils.benchmark.utils.common.Measurement object at 0xXXXXXXXXXXXX>
374            cheap_fn()
375              Median: 4.09 us
376              IQR:    0.38 us (3.90 to 4.28)
377              4 measurements, 100000 runs per measurement, 1 thread""",
378        )
379
380        self.regularizeAndAssertExpectedInline(
381            MockCudaTimer("expensive_fn()").blocked_autorange(min_run_time=10),
382            """\
383            <torch.utils.benchmark.utils.common.Measurement object at 0xXXXXXXXXXXXX>
384            expensive_fn()
385              Median: 4.98 us
386              IQR:    0.31 us (4.83 to 5.13)
387              20 measurements, 100000 runs per measurement, 1 thread""",
388        )
389
390        self.regularizeAndAssertExpectedInline(
391            MockCudaTimer("expensive_fn()").adaptive_autorange(),
392            """\
393            <torch.utils.benchmark.utils.common.Measurement object at 0xXXXXXXXXXXXX>
394            expensive_fn()
395              Median: 5.01 us
396              IQR:    0.28 us (4.87 to 5.15)
397              4 measurements, 10000 runs per measurement, 1 thread""",
398        )
399
400        # Make sure __repr__ is reasonable for
401        # multi-line / label / sub_label / description, but we don't need to
402        # check numerics.
403        multi_line_stmt = """
404        with torch.no_grad():
405            y = x + 1
406        """
407
408        self.regularizeAndAssertExpectedInline(
409            MockTimer(multi_line_stmt).blocked_autorange(),
410            """\
411            <torch.utils.benchmark.utils.common.Measurement object at 0xXXXXXXXXXXXX>
412            stmt:
413              with torch.no_grad():
414                  y = x + 1
415
416              Median: 10.06 us
417              IQR:    0.54 us (9.73 to 10.27)
418              20 measurements, 1000 runs per measurement, 1 thread""",
419        )
420
421        self.regularizeAndAssertExpectedInline(
422            MockTimer(multi_line_stmt, sub_label="scalar_add").blocked_autorange(),
423            """\
424            <torch.utils.benchmark.utils.common.Measurement object at 0xXXXXXXXXXXXX>
425            stmt: (scalar_add)
426              with torch.no_grad():
427                  y = x + 1
428
429              Median: 10.06 us
430              IQR:    0.54 us (9.73 to 10.27)
431              20 measurements, 1000 runs per measurement, 1 thread""",
432        )
433
434        self.regularizeAndAssertExpectedInline(
435            MockTimer(
436                multi_line_stmt,
437                label="x + 1 (no grad)",
438                sub_label="scalar_add",
439            ).blocked_autorange(),
440            """\
441            <torch.utils.benchmark.utils.common.Measurement object at 0xXXXXXXXXXXXX>
442            x + 1 (no grad): scalar_add
443              Median: 10.06 us
444              IQR:    0.54 us (9.73 to 10.27)
445              20 measurements, 1000 runs per measurement, 1 thread""",
446        )
447
448        self.regularizeAndAssertExpectedInline(
449            MockTimer(
450                multi_line_stmt,
451                setup="setup_fn()",
452                sub_label="scalar_add",
453            ).blocked_autorange(),
454            """\
455            <torch.utils.benchmark.utils.common.Measurement object at 0xXXXXXXXXXXXX>
456            stmt: (scalar_add)
457              with torch.no_grad():
458                  y = x + 1
459
460            setup: setup_fn()
461              Median: 10.06 us
462              IQR:    0.54 us (9.73 to 10.27)
463              20 measurements, 1000 runs per measurement, 1 thread""",
464        )
465
466        self.regularizeAndAssertExpectedInline(
467            MockTimer(
468                multi_line_stmt,
469                setup="""
470                    x = torch.ones((1,), requires_grad=True)
471                    for _ in range(5):
472                        x = x + 1.0""",
473                sub_label="scalar_add",
474                description="Multi-threaded scalar math!",
475                num_threads=16,
476            ).blocked_autorange(),
477            """\
478            <torch.utils.benchmark.utils.common.Measurement object at 0xXXXXXXXXXXXX>
479            stmt: (scalar_add)
480              with torch.no_grad():
481                  y = x + 1
482
483            Multi-threaded scalar math!
484            setup:
485              x = torch.ones((1,), requires_grad=True)
486              for _ in range(5):
487                  x = x + 1.0
488
489              Median: 10.06 us
490              IQR:    0.54 us (9.73 to 10.27)
491              20 measurements, 1000 runs per measurement, 16 threads""",
492        )
493
494    @slowTest
495    @unittest.skipIf(IS_WINDOWS, "Valgrind is not supported on Windows.")
496    @unittest.skipIf(IS_SANDCASTLE, "Valgrind is OSS only.")
497    @unittest.skipIf(TEST_WITH_ASAN, "fails on asan")
498    def test_collect_callgrind(self):
499        with self.assertRaisesRegex(
500            ValueError,
501            r"`collect_callgrind` requires that globals be wrapped "
502            r"in `CopyIfCallgrind` so that serialization is explicit.",
503        ):
504            benchmark_utils.Timer("pass", globals={"x": 1}).collect_callgrind(
505                collect_baseline=False
506            )
507
508        with self.assertRaisesRegex(
509            # Subprocess raises AttributeError (from pickle),
510            # _ValgrindWrapper re-raises as generic OSError.
511            OSError,
512            "AttributeError: Can't get attribute 'MyModule'",
513        ):
514            benchmark_utils.Timer(
515                "model(1)",
516                globals={"model": benchmark_utils.CopyIfCallgrind(MyModule())},
517            ).collect_callgrind(collect_baseline=False)
518
519        @torch.jit.script
520        def add_one(x):
521            return x + 1
522
523        timer = benchmark_utils.Timer(
524            "y = add_one(x) + k",
525            setup="x = torch.ones((1,))",
526            globals={
527                "add_one": benchmark_utils.CopyIfCallgrind(add_one),
528                "k": benchmark_utils.CopyIfCallgrind(5),
529                "model": benchmark_utils.CopyIfCallgrind(
530                    MyModule(),
531                    setup=f"""\
532                    import sys
533                    sys.path.append({repr(os.path.split(os.path.abspath(__file__))[0])})
534                    from test_benchmark_utils import MyModule
535                    """,
536                ),
537            },
538        )
539
540        stats = timer.collect_callgrind(number=1000)
541        counts = stats.counts(denoise=False)
542
543        self.assertIsInstance(counts, int)
544        self.assertGreater(counts, 0)
545
546        # There is some jitter with the allocator, so we use a simpler task to
547        # test reproducibility.
548        timer = benchmark_utils.Timer(
549            "x += 1",
550            setup="x = torch.ones((1,))",
551        )
552
553        stats = timer.collect_callgrind(number=1000, repeats=20)
554        assert isinstance(stats, tuple)
555
556        # Check that the repeats are at least somewhat repeatable. (within 10 instructions per iter)
557        counts = collections.Counter(
558            [s.counts(denoise=True) // 10_000 * 10_000 for s in stats]
559        )
560        self.assertGreater(
561            max(counts.values()),
562            1,
563            f"Every instruction count total was unique: {counts}",
564        )
565
566        from torch.utils.benchmark.utils.valgrind_wrapper.timer_interface import (
567            wrapper_singleton,
568        )
569
570        self.assertIsNone(
571            wrapper_singleton()._bindings_module,
572            "JIT'd bindings are only for back testing.",
573        )
574
575    @slowTest
576    @unittest.skipIf(IS_WINDOWS, "Valgrind is not supported on Windows.")
577    @unittest.skipIf(IS_SANDCASTLE, "Valgrind is OSS only.")
578    @unittest.skipIf(True, "Failing on clang, see 74398")
579    def test_collect_cpp_callgrind(self):
580        timer = benchmark_utils.Timer(
581            "x += 1;",
582            setup="torch::Tensor x = torch::ones({1});",
583            timer=timeit.default_timer,
584            language="c++",
585        )
586        stats = [timer.collect_callgrind() for _ in range(3)]
587        counts = [s.counts() for s in stats]
588
589        self.assertGreater(min(counts), 0, "No stats were collected")
590        self.assertEqual(
591            min(counts), max(counts), "C++ Callgrind should be deterministic"
592        )
593
594        for s in stats:
595            self.assertEqual(
596                s.counts(denoise=True),
597                s.counts(denoise=False),
598                "De-noising should not apply to C++.",
599            )
600
601        stats = timer.collect_callgrind(number=1000, repeats=20)
602        assert isinstance(stats, tuple)
603
604        # NB: Unlike the example above, there is no expectation that all
605        #     repeats will be identical.
606        counts = collections.Counter(
607            [s.counts(denoise=True) // 10_000 * 10_000 for s in stats]
608        )
609        self.assertGreater(max(counts.values()), 1, repr(counts))
610
611    def test_manipulate_callgrind_stats(self):
612        stats_no_data, stats_with_data = load_callgrind_artifacts()
613
614        # Mock `torch.set_printoptions(linewidth=160)`
615        wide_linewidth = benchmark_utils.FunctionCounts(
616            stats_no_data.stats(inclusive=False)._data, False, _linewidth=160
617        )
618
619        for l in repr(wide_linewidth).splitlines(keepends=False):
620            self.assertLessEqual(len(l), 160)
621
622        self.assertEqual(
623            # `delta` is just a convenience method.
624            stats_with_data.delta(stats_no_data)._data,
625            (stats_with_data.stats() - stats_no_data.stats())._data,
626        )
627
628        deltas = stats_with_data.as_standardized().delta(
629            stats_no_data.as_standardized()
630        )
631
632        def custom_transforms(fn: str):
633            fn = re.sub(re.escape("/usr/include/c++/8/bits/"), "", fn)
634            fn = re.sub(r"build/../", "", fn)
635            fn = re.sub(".+" + re.escape("libsupc++"), "libsupc++", fn)
636            return fn
637
638        self.regularizeAndAssertExpectedInline(
639            stats_no_data,
640            """\
641            <torch.utils.benchmark.utils.valgrind_wrapper.timer_interface.CallgrindStats object at 0xXXXXXXXXXXXX>
642            y = torch.ones(())
643                                       All          Noisy symbols removed
644                Instructions:      8869966                    8728096
645                Baseline:             6682                       5766
646            1000 runs per measurement, 1 thread""",
647        )
648
649        self.regularizeAndAssertExpectedInline(
650            stats_no_data.counts(),
651            """8869966""",
652        )
653
654        self.regularizeAndAssertExpectedInline(
655            stats_no_data.counts(denoise=True),
656            """8728096""",
657        )
658
659        self.regularizeAndAssertExpectedInline(
660            stats_no_data.stats(),
661            """\
662            <torch.utils.benchmark.utils.valgrind_wrapper.timer_interface.FunctionCounts object at 0xXXXXXXXXXXXX>
663              408000  ???:__tls_get_addr [/usr/lib64/ld-2.28.so]
664              388193  ???:_int_free [/usr/lib64/libc-2.28.so]
665              274000  build/../torch/csrc/utils/python ... rch/torch/lib/libtorch_python.so]
666              264000  build/../aten/src/ATen/record_fu ... ytorch/torch/lib/libtorch_cpu.so]
667              192000  build/../c10/core/Device.h:c10:: ... epos/pytorch/torch/lib/libc10.so]
668              169855  ???:_int_malloc [/usr/lib64/libc-2.28.so]
669              154000  build/../c10/core/TensorOptions. ... ytorch/torch/lib/libtorch_cpu.so]
670              148561  /tmp/build/80754af9/python_15996 ... da3/envs/throwaway/bin/python3.6]
671              135000  ???:malloc [/usr/lib64/libc-2.28.so]
672                 ...
673                2000  /usr/include/c++/8/ext/new_allocator.h:torch::PythonArgs::intlist(int)
674                2000  /usr/include/c++/8/bits/stl_vect ... *, _object*, _object*, _object**)
675                2000  /usr/include/c++/8/bits/stl_vect ... rningHandler::~PyWarningHandler()
676                2000  /usr/include/c++/8/bits/stl_vect ... ject*, _object*, _object**, bool)
677                2000  /usr/include/c++/8/bits/stl_algobase.h:torch::PythonArgs::intlist(int)
678                2000  /usr/include/c++/8/bits/shared_p ... ad_accumulator(at::Tensor const&)
679                2000  /usr/include/c++/8/bits/move.h:c ... te<c10::AutogradMetaInterface> >)
680                2000  /usr/include/c++/8/bits/atomic_b ... DispatchKey&&, caffe2::TypeMeta&)
681                2000  /usr/include/c++/8/array:at::Ten ... , at::Tensor&, c10::Scalar) const
682
683            Total: 8869966""",
684        )
685
686        self.regularizeAndAssertExpectedInline(
687            stats_no_data.stats(inclusive=True),
688            """\
689            <torch.utils.benchmark.utils.valgrind_wrapper.timer_interface.FunctionCounts object at 0xXXXXXXXXXXXX>
690              8959166  ???:0x0000000000001050 [/usr/lib64/ld-2.28.so]
691              8959166  ???:(below main) [/usr/lib64/libc-2.28.so]
692              8959166  /tmp/build/80754af9/python_15996 ... a3/envs/throwaway/bin/python3.6]
693              8959166  /tmp/build/80754af9/python_15996 ... a3/envs/throwaway/bin/python3.6]
694              8959166  /tmp/build/80754af9/python_15996 ... a3/envs/throwaway/bin/python3.6]
695              8959166  /tmp/build/80754af9/python_15996 ... a3/envs/throwaway/bin/python3.6]
696              8959166  /tmp/build/80754af9/python_15996 ... a3/envs/throwaway/bin/python3.6]
697              8959166  /tmp/build/80754af9/python_15996 ... a3/envs/throwaway/bin/python3.6]
698              8959166  /tmp/build/80754af9/python_15996 ... a3/envs/throwaway/bin/python3.6]
699                  ...
700                92821  /tmp/build/80754af9/python_15996 ... a3/envs/throwaway/bin/python3.6]
701                91000  build/../torch/csrc/tensor/pytho ... ch/torch/lib/libtorch_python.so]
702                91000  /data/users/test_user/repos/pyto ... nsors::get_default_scalar_type()
703                90090  ???:pthread_mutex_lock [/usr/lib64/libpthread-2.28.so]
704                90000  build/../c10/core/TensorImpl.h:c ... ch/torch/lib/libtorch_python.so]
705                90000  build/../aten/src/ATen/record_fu ... torch/torch/lib/libtorch_cpu.so]
706                90000  /data/users/test_user/repos/pyto ... uard(c10::optional<c10::Device>)
707                90000  /data/users/test_user/repos/pyto ... ersionCounter::~VersionCounter()
708                88000  /data/users/test_user/repos/pyto ... ratorKernel*, at::Tensor const&)""",
709        )
710
711        self.regularizeAndAssertExpectedInline(
712            wide_linewidth,
713            """\
714            <torch.utils.benchmark.utils.valgrind_wrapper.timer_interface.FunctionCounts object at 0xXXXXXXXXXXXX>
715              408000  ???:__tls_get_addr [/usr/lib64/ld-2.28.so]
716              388193  ???:_int_free [/usr/lib64/libc-2.28.so]
717              274000  build/../torch/csrc/utils/python_arg_parser.cpp:torch::FunctionSignature ...  bool) [/data/users/test_user/repos/pytorch/torch/lib/libtorch_python.so]
718              264000  build/../aten/src/ATen/record_function.cpp:at::RecordFunction::RecordFun ... ordScope) [/data/users/test_user/repos/pytorch/torch/lib/libtorch_cpu.so]
719              192000  build/../c10/core/Device.h:c10::Device::validate() [/data/users/test_user/repos/pytorch/torch/lib/libc10.so]
720              169855  ???:_int_malloc [/usr/lib64/libc-2.28.so]
721              154000  build/../c10/core/TensorOptions.h:c10::TensorOptions::merge_in(c10::Tens ... ns) const [/data/users/test_user/repos/pytorch/torch/lib/libtorch_cpu.so]
722              148561  /tmp/build/80754af9/python_1599604603603/work/Python/ceval.c:_PyEval_EvalFrameDefault [/home/test_user/miniconda3/envs/throwaway/bin/python3.6]
723              135000  ???:malloc [/usr/lib64/libc-2.28.so]
724                 ...
725                2000  /usr/include/c++/8/ext/new_allocator.h:torch::PythonArgs::intlist(int)
726                2000  /usr/include/c++/8/bits/stl_vector.h:torch::PythonArgParser::raw_parse(_object*, _object*, _object*, _object**)
727                2000  /usr/include/c++/8/bits/stl_vector.h:torch::PyWarningHandler::~PyWarningHandler()
728                2000  /usr/include/c++/8/bits/stl_vector.h:torch::FunctionSignature::parse(_object*, _object*, _object*, _object**, bool)
729                2000  /usr/include/c++/8/bits/stl_algobase.h:torch::PythonArgs::intlist(int)
730                2000  /usr/include/c++/8/bits/shared_ptr_base.h:torch::autograd::impl::try_get_grad_accumulator(at::Tensor const&)
731                2000  /usr/include/c++/8/bits/move.h:c10::TensorImpl::set_autograd_meta(std::u ... AutogradMetaInterface, std::default_delete<c10::AutogradMetaInterface> >)
732                2000  /usr/include/c++/8/bits/atomic_base.h:at::Tensor at::detail::make_tensor ... t_null_type<c10::StorageImpl> >&&, c10::DispatchKey&&, caffe2::TypeMeta&)
733                2000  /usr/include/c++/8/array:at::Tensor& c10::Dispatcher::callWithDispatchKe ... , c10::Scalar)> const&, c10::DispatchKey, at::Tensor&, c10::Scalar) const
734
735            Total: 8869966""",  # noqa: B950
736        )
737
738        self.regularizeAndAssertExpectedInline(
739            stats_no_data.as_standardized().stats(),
740            """\
741            <torch.utils.benchmark.utils.valgrind_wrapper.timer_interface.FunctionCounts object at 0xXXXXXXXXXXXX>
742              408000  ???:__tls_get_addr
743              388193  ???:_int_free
744              274000  build/../torch/csrc/utils/python ... ject*, _object*, _object**, bool)
745              264000  build/../aten/src/ATen/record_fu ... ::RecordFunction(at::RecordScope)
746              192000  build/../c10/core/Device.h:c10::Device::validate()
747              169855  ???:_int_malloc
748              154000  build/../c10/core/TensorOptions. ... erge_in(c10::TensorOptions) const
749              148561  Python/ceval.c:_PyEval_EvalFrameDefault
750              135000  ???:malloc
751                 ...
752                2000  /usr/include/c++/8/ext/new_allocator.h:torch::PythonArgs::intlist(int)
753                2000  /usr/include/c++/8/bits/stl_vect ... *, _object*, _object*, _object**)
754                2000  /usr/include/c++/8/bits/stl_vect ... rningHandler::~PyWarningHandler()
755                2000  /usr/include/c++/8/bits/stl_vect ... ject*, _object*, _object**, bool)
756                2000  /usr/include/c++/8/bits/stl_algobase.h:torch::PythonArgs::intlist(int)
757                2000  /usr/include/c++/8/bits/shared_p ... ad_accumulator(at::Tensor const&)
758                2000  /usr/include/c++/8/bits/move.h:c ... te<c10::AutogradMetaInterface> >)
759                2000  /usr/include/c++/8/bits/atomic_b ... DispatchKey&&, caffe2::TypeMeta&)
760                2000  /usr/include/c++/8/array:at::Ten ... , at::Tensor&, c10::Scalar) const
761
762            Total: 8869966""",
763        )
764
765        self.regularizeAndAssertExpectedInline(
766            deltas,
767            """\
768            <torch.utils.benchmark.utils.valgrind_wrapper.timer_interface.FunctionCounts object at 0xXXXXXXXXXXXX>
769                85000  Objects/dictobject.c:lookdict_unicode
770                59089  ???:_int_free
771                43000  ???:malloc
772                25000  build/../torch/csrc/utils/python ... :torch::PythonArgs::intlist(int)
773                24000  ???:__tls_get_addr
774                23000  ???:free
775                21067  Objects/dictobject.c:lookdict_unicode_nodummy
776                20000  build/../torch/csrc/utils/python ... :torch::PythonArgs::intlist(int)
777                18000  Objects/longobject.c:PyLong_AsLongLongAndOverflow
778                  ...
779                 2000  /home/nwani/m3/conda-bld/compile ... del_op.cc:operator delete(void*)
780                 1000  /usr/include/c++/8/bits/stl_vector.h:torch::PythonArgs::intlist(int)
781                  193  ???:_int_malloc
782                   75  ???:_int_memalign
783                -1000  build/../c10/util/SmallVector.h: ... _contiguous(c10::ArrayRef<long>)
784                -1000  build/../c10/util/SmallVector.h: ... nsor_restride(c10::MemoryFormat)
785                -1000  /usr/include/c++/8/bits/stl_vect ... es(_object*, _object*, _object*)
786                -8000  Python/ceval.c:_PyEval_EvalFrameDefault
787               -16000  Objects/tupleobject.c:PyTuple_New
788
789            Total: 432917""",
790        )
791
792        self.regularizeAndAssertExpectedInline(len(deltas), """35""")
793
794        self.regularizeAndAssertExpectedInline(
795            deltas.transform(custom_transforms),
796            """\
797            <torch.utils.benchmark.utils.valgrind_wrapper.timer_interface.FunctionCounts object at 0xXXXXXXXXXXXX>
798                85000  Objects/dictobject.c:lookdict_unicode
799                59089  ???:_int_free
800                43000  ???:malloc
801                25000  torch/csrc/utils/python_numbers.h:torch::PythonArgs::intlist(int)
802                24000  ???:__tls_get_addr
803                23000  ???:free
804                21067  Objects/dictobject.c:lookdict_unicode_nodummy
805                20000  torch/csrc/utils/python_arg_parser.h:torch::PythonArgs::intlist(int)
806                18000  Objects/longobject.c:PyLong_AsLongLongAndOverflow
807                  ...
808                 2000  c10/util/SmallVector.h:c10::TensorImpl::compute_contiguous() const
809                 1000  stl_vector.h:torch::PythonArgs::intlist(int)
810                  193  ???:_int_malloc
811                   75  ???:_int_memalign
812                -1000  stl_vector.h:torch::autograd::TH ... es(_object*, _object*, _object*)
813                -1000  c10/util/SmallVector.h:c10::Tens ... _contiguous(c10::ArrayRef<long>)
814                -1000  c10/util/SmallVector.h:c10::Tens ... nsor_restride(c10::MemoryFormat)
815                -8000  Python/ceval.c:_PyEval_EvalFrameDefault
816               -16000  Objects/tupleobject.c:PyTuple_New
817
818            Total: 432917""",
819        )
820
821        self.regularizeAndAssertExpectedInline(
822            deltas.filter(lambda fn: fn.startswith("???")),
823            """\
824            <torch.utils.benchmark.utils.valgrind_wrapper.timer_interface.FunctionCounts object at 0xXXXXXXXXXXXX>
825              59089  ???:_int_free
826              43000  ???:malloc
827              24000  ???:__tls_get_addr
828              23000  ???:free
829                193  ???:_int_malloc
830                 75  ???:_int_memalign
831
832            Total: 149357""",
833        )
834
835        self.regularizeAndAssertExpectedInline(
836            deltas[:5],
837            """\
838            <torch.utils.benchmark.utils.valgrind_wrapper.timer_interface.FunctionCounts object at 0xXXXXXXXXXXXX>
839              85000  Objects/dictobject.c:lookdict_unicode
840              59089  ???:_int_free
841              43000  ???:malloc
842              25000  build/../torch/csrc/utils/python_ ... h:torch::PythonArgs::intlist(int)
843              24000  ???:__tls_get_addr
844
845            Total: 236089""",
846        )
847
848    def test_compare(self):
849        # Simulate several approaches.
850        costs = (
851            # overhead_optimized_fn()
852            (1e-6, 1e-9),
853            # compute_optimized_fn()
854            (3e-6, 5e-10),
855            # special_case_fn()  [square inputs only]
856            (1e-6, 4e-10),
857        )
858
859        sizes = (
860            (16, 16),
861            (16, 128),
862            (128, 128),
863            (4096, 1024),
864            (2048, 2048),
865        )
866
867        # overhead_optimized_fn()
868        class _MockTimer_0(self._MockTimer):
869            _function_costs = tuple(
870                (f"fn({i}, {j})", costs[0][0] + costs[0][1] * i * j) for i, j in sizes
871            )
872
873        class MockTimer_0(benchmark_utils.Timer):
874            _timer_cls = _MockTimer_0
875
876        # compute_optimized_fn()
877        class _MockTimer_1(self._MockTimer):
878            _function_costs = tuple(
879                (f"fn({i}, {j})", costs[1][0] + costs[1][1] * i * j) for i, j in sizes
880            )
881
882        class MockTimer_1(benchmark_utils.Timer):
883            _timer_cls = _MockTimer_1
884
885        # special_case_fn()
886        class _MockTimer_2(self._MockTimer):
887            _function_costs = tuple(
888                (f"fn({i}, {j})", costs[2][0] + costs[2][1] * i * j)
889                for i, j in sizes
890                if i == j
891            )
892
893        class MockTimer_2(benchmark_utils.Timer):
894            _timer_cls = _MockTimer_2
895
896        results = []
897        for i, j in sizes:
898            results.append(
899                MockTimer_0(
900                    f"fn({i}, {j})",
901                    label="fn",
902                    description=f"({i}, {j})",
903                    sub_label="overhead_optimized",
904                ).blocked_autorange(min_run_time=10)
905            )
906
907            results.append(
908                MockTimer_1(
909                    f"fn({i}, {j})",
910                    label="fn",
911                    description=f"({i}, {j})",
912                    sub_label="compute_optimized",
913                ).blocked_autorange(min_run_time=10)
914            )
915
916            if i == j:
917                results.append(
918                    MockTimer_2(
919                        f"fn({i}, {j})",
920                        label="fn",
921                        description=f"({i}, {j})",
922                        sub_label="special_case (square)",
923                    ).blocked_autorange(min_run_time=10)
924                )
925
926        def rstrip_lines(s: str) -> str:
927            # VSCode will rstrip the `expected` string literal whether you like
928            # it or not. So we have to rstrip the compare table as well.
929            return "\n".join([i.rstrip() for i in s.splitlines(keepends=False)])
930
931        compare = benchmark_utils.Compare(results)
932        self.regularizeAndAssertExpectedInline(
933            rstrip_lines(str(compare).strip()),
934            """\
935            [------------------------------------------------- fn ------------------------------------------------]
936                                         |  (16, 16)  |  (16, 128)  |  (128, 128)  |  (4096, 1024)  |  (2048, 2048)
937            1 threads: --------------------------------------------------------------------------------------------
938                  overhead_optimized     |    1.3     |     3.0     |     17.4     |     4174.4     |     4174.4
939                  compute_optimized      |    3.1     |     4.0     |     11.2     |     2099.3     |     2099.3
940                  special_case (square)  |    1.1     |             |      7.5     |                |     1674.7
941
942            Times are in microseconds (us).""",
943        )
944
945        compare.trim_significant_figures()
946        self.regularizeAndAssertExpectedInline(
947            rstrip_lines(str(compare).strip()),
948            """\
949            [------------------------------------------------- fn ------------------------------------------------]
950                                         |  (16, 16)  |  (16, 128)  |  (128, 128)  |  (4096, 1024)  |  (2048, 2048)
951            1 threads: --------------------------------------------------------------------------------------------
952                  overhead_optimized     |     1      |     3.0     |      17      |      4200      |      4200
953                  compute_optimized      |     3      |     4.0     |      11      |      2100      |      2100
954                  special_case (square)  |     1      |             |       8      |                |      1700
955
956            Times are in microseconds (us).""",
957        )
958
959        compare.colorize()
960        columnwise_colored_actual = rstrip_lines(str(compare).strip())
961        columnwise_colored_expected = textwrap.dedent(
962            """\
963            [------------------------------------------------- fn ------------------------------------------------]
964                                         |  (16, 16)  |  (16, 128)  |  (128, 128)  |  (4096, 1024)  |  (2048, 2048)
965            1 threads: --------------------------------------------------------------------------------------------
966                  overhead_optimized     |     1      |  \x1b[92m\x1b[1m   3.0   \x1b[0m\x1b[0m  |  \x1b[2m\x1b[91m    17    \x1b[0m\x1b[0m  |      4200      |  \x1b[2m\x1b[91m    4200    \x1b[0m\x1b[0m
967                  compute_optimized      |  \x1b[2m\x1b[91m   3    \x1b[0m\x1b[0m  |     4.0     |      11      |  \x1b[92m\x1b[1m    2100    \x1b[0m\x1b[0m  |      2100
968                  special_case (square)  |  \x1b[92m\x1b[1m   1    \x1b[0m\x1b[0m  |             |  \x1b[92m\x1b[1m     8    \x1b[0m\x1b[0m  |                |  \x1b[92m\x1b[1m    1700    \x1b[0m\x1b[0m
969
970            Times are in microseconds (us)."""  # noqa: B950
971        )
972
973        compare.colorize(rowwise=True)
974        rowwise_colored_actual = rstrip_lines(str(compare).strip())
975        rowwise_colored_expected = textwrap.dedent(
976            """\
977            [------------------------------------------------- fn ------------------------------------------------]
978                                         |  (16, 16)  |  (16, 128)  |  (128, 128)  |  (4096, 1024)  |  (2048, 2048)
979            1 threads: --------------------------------------------------------------------------------------------
980                  overhead_optimized     |  \x1b[92m\x1b[1m   1    \x1b[0m\x1b[0m  |  \x1b[2m\x1b[91m   3.0   \x1b[0m\x1b[0m  |  \x1b[31m\x1b[1m    17    \x1b[0m\x1b[0m  |  \x1b[31m\x1b[1m    4200    \x1b[0m\x1b[0m  |  \x1b[31m\x1b[1m    4200    \x1b[0m\x1b[0m
981                  compute_optimized      |  \x1b[92m\x1b[1m   3    \x1b[0m\x1b[0m  |     4.0     |  \x1b[2m\x1b[91m    11    \x1b[0m\x1b[0m  |  \x1b[31m\x1b[1m    2100    \x1b[0m\x1b[0m  |  \x1b[31m\x1b[1m    2100    \x1b[0m\x1b[0m
982                  special_case (square)  |  \x1b[92m\x1b[1m   1    \x1b[0m\x1b[0m  |             |  \x1b[31m\x1b[1m     8    \x1b[0m\x1b[0m  |                |  \x1b[31m\x1b[1m    1700    \x1b[0m\x1b[0m
983
984            Times are in microseconds (us)."""  # noqa: B950
985        )
986
987        def print_new_expected(s: str) -> None:
988            print(f'{"":>12}"""\\', end="")
989            for l in s.splitlines(keepends=False):
990                print("\n" + textwrap.indent(repr(l)[1:-1], " " * 12), end="")
991            print('"""\n')
992
993        if expecttest.ACCEPT:
994            # expecttest does not currently support non-printable characters,
995            # so these two entries have to be updated manually.
996            if columnwise_colored_actual != columnwise_colored_expected:
997                print("New columnwise coloring:\n")
998                print_new_expected(columnwise_colored_actual)
999
1000            if rowwise_colored_actual != rowwise_colored_expected:
1001                print("New rowwise coloring:\n")
1002                print_new_expected(rowwise_colored_actual)
1003
1004        self.assertEqual(columnwise_colored_actual, columnwise_colored_expected)
1005        self.assertEqual(rowwise_colored_actual, rowwise_colored_expected)
1006
1007    @unittest.skipIf(
1008        IS_WINDOWS and os.getenv("VC_YEAR") == "2019", "Random seed only accepts int32"
1009    )
1010    def test_fuzzer(self):
1011        fuzzer = benchmark_utils.Fuzzer(
1012            parameters=[
1013                benchmark_utils.FuzzedParameter(
1014                    "n", minval=1, maxval=16, distribution="loguniform"
1015                )
1016            ],
1017            tensors=[benchmark_utils.FuzzedTensor("x", size=("n",))],
1018            seed=0,
1019        )
1020
1021        expected_results = [
1022            (0.7821, 0.0536, 0.9888, 0.1949, 0.5242, 0.1987, 0.5094),
1023            (0.7166, 0.5961, 0.8303, 0.005),
1024        ]
1025
1026        for i, (tensors, _, _) in enumerate(fuzzer.take(2)):
1027            x = tensors["x"]
1028            self.assertEqual(x, torch.tensor(expected_results[i]), rtol=1e-3, atol=1e-3)
1029
1030
# Script entry point: when this file is executed directly, hand control to
# `run_tests` (imported from torch.testing._internal.common_utils).
if __name__ == "__main__":
    run_tests()
1033