xref: /aosp_15_r20/external/pytorch/test/test_utils.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: allow-untyped-defs
2# Owner(s): ["module: unknown"]
3
4import os
5import random
6import re
7import shutil
8import subprocess
9import sys
10import tempfile
11import textwrap
12import traceback
13import unittest
14import warnings
15from typing import Any, Dict, List
16
17import torch
18import torch.cuda
19import torch.nn as nn
20import torch.utils.cpp_extension
21import torch.utils.data
22from torch.autograd._functions.utils import check_onnx_broadcast
23from torch.onnx.symbolic_opset9 import _prepare_onnx_paddings
24from torch.testing._internal.common_cuda import TEST_MULTIGPU
25from torch.testing._internal.common_device_type import (
26    instantiate_device_type_tests,
27    onlyCPU,
28    ops,
29)
30from torch.testing._internal.common_methods_invocations import op_db
31from torch.testing._internal.common_utils import (  # type: ignore[attr-defined]
32    IS_FBCODE,
33    IS_SANDCASTLE,
34    IS_WINDOWS,
35    load_tests,
36)
37from torch.utils._device import set_device
38from torch.utils._pytree import tree_all_only, tree_any
39from torch.utils._traceback import (
40    CapturedTraceback,
41    format_traceback_short,
42    report_compile_source_on_error,
43)
44from torch.utils.checkpoint import (
45    _infer_device_type,
46    checkpoint,
47    checkpoint_sequential,
48    get_device_states,
49)
50from torch.utils.data import DataLoader
51
52
53# load_tests from torch.testing._internal.common_utils is used to automatically filter tests for
54# sharding on sandcastle. This line silences flake warnings
55load_tests = load_tests
56
57HAS_CUDA = torch.cuda.is_available()
58
59
60from torch.testing._internal.common_utils import run_tests, TestCase
61
62
63class RandomDatasetMock(torch.utils.data.Dataset):
64    def __getitem__(self, index):
65        return torch.tensor([torch.rand(1).item(), random.uniform(0, 1)])
66
67    def __len__(self):
68        return 1000
69
70
71class TestCheckpoint(TestCase):
72    # This runs checkpoint_sequential on each of the nets in
73    # module_lists_to_compare, and compares them against the uncheckpointed model.
74    # To compare, it checks outputs as well as input gradients and parameter gradients
75    def _check_checkpoint_sequential(
76        self,
77        model,
78        module_lists_to_compare,
79        num_chunks,
80        input,
81        use_reentrant,
82    ):
83        # not checkpointed
84        out = model(input)
85        out_not_checkpointed = out.detach().clone()
86        model.zero_grad()
87        out.sum().backward()
88        grad_not_checkpointed = {
89            name: param.grad.detach().clone()
90            for name, param in model.named_parameters()
91        }
92        input_grad_not_checkpointed = input.grad.detach().clone()
93        for model_to_compare in module_lists_to_compare:
94            # checkpointed model by passing list of modules
95            detached = input.detach()
96            detached.requires_grad = True
97
98            # pass list of modules to checkpoint
99            out = checkpoint_sequential(
100                model_to_compare, num_chunks, detached, use_reentrant=use_reentrant
101            )
102            out_checkpointed = out.detach().clone()
103            model.zero_grad()
104            out.sum().backward()
105            grad_checkpointed = {
106                name: param.grad.detach().clone()
107                for name, param in model.named_parameters()
108            }
109            input_grad_checkpointed = detached.grad.detach().clone()
110            # compare outputs as well as the gradients of input and parameters
111            self.assertEqual(out_checkpointed, out_not_checkpointed)
112            self.assertEqual(input_grad_not_checkpointed, input_grad_checkpointed)
113            for name in grad_checkpointed:
114                self.assertEqual(grad_checkpointed[name], grad_not_checkpointed[name])
115
116    # Test whether checkpoint is being triggered or not. For this, we check
117    # the number of times forward pass happens
118    def test_checkpoint_trigger(self):
119        class Net(nn.Module):
120            def __init__(self) -> None:
121                super().__init__()
122                self.counter = 0
123
124            def forward(self, input_var):
125                self.counter += 1
126                # For reentrant, need to have autograd actually
127                # pack a tensor to trigger recomp
128                ret = input_var * torch.tensor(2.0)
129                return ret
130
131        # checkpointed
132        for use_reentrant in [True, False]:
133            with self.subTest(use_reentrant=use_reentrant):
134                modules = [Net() for _ in range(10)]
135                for m in modules:
136                    self.assertEqual(m.counter, 0)
137                input_var = torch.randn(3, 4, requires_grad=True)
138                out = checkpoint_sequential(
139                    modules, 2, input_var, use_reentrant=use_reentrant
140                )
141                for m in modules:
142                    self.assertEqual(m.counter, 1)
143                out.sum().backward()
144                for m in modules[: (len(modules) // 2)]:
145                    self.assertEqual(m.counter, 2)
146                for m in modules[(len(modules) // 2) :]:
147                    self.assertEqual(m.counter, 1)
148
149    def test_checkpoint_valid(self):
150        model = nn.Sequential(
151            nn.Linear(100, 50),
152            nn.ReLU(),
153            nn.Linear(50, 20),
154            nn.ReLU(),
155            nn.Linear(20, 5),
156            nn.ReLU(),
157        )
158
159        input_var = torch.randn(1, 100, requires_grad=True)
160
161        # checkpointed
162        chunks = 2
163        modules = list(model.children())
164        out = checkpoint_sequential(modules, chunks, input_var, use_reentrant=True)
165        with self.assertRaisesRegex(
166            RuntimeError, "torch.utils.checkpoint is incompatible"
167        ):
168            torch.autograd.grad(
169                outputs=[out],
170                grad_outputs=[torch.ones(1, 5)],
171                inputs=[input_var],
172                create_graph=True,
173            )
174        # works with use_reentrant=False, and grads are the same
175        out = model(input_var)
176        grads_no_checkpoint = torch.autograd.grad(
177            outputs=[out],
178            grad_outputs=[torch.ones(1, 5)],
179            inputs=[input_var],
180            create_graph=True,
181        )
182        out_checkpoint = checkpoint_sequential(
183            modules, chunks, input_var, use_reentrant=False
184        )
185        # check outputs are the same
186        self.assertEqual(out_checkpoint, out)
187        grads_checkpoint = torch.autograd.grad(
188            outputs=[out_checkpoint],
189            grad_outputs=[torch.ones(1, 5)],
190            inputs=[input_var],
191            create_graph=True,
192        )
193        self.assertEqual(grads_no_checkpoint, grads_checkpoint)
194
195    def test_checkpoint(self):
196        for use_reentrant in [True, False]:
197            with self.subTest(use_reentrant=use_reentrant):
198                model = nn.Sequential(
199                    nn.Linear(100, 50),
200                    nn.ReLU(),
201                    nn.Linear(50, 20),
202                    nn.ReLU(),
203                    nn.Linear(20, 5),
204                    nn.ReLU(),
205                )
206
207                # Compare uncheckpointed model with its checkpointed counterparts
208                # In addition to running checkpoint_sequential on the nn.Sequential
209                # instance, we also run the function on the list of functions within
210                # the module.
211                self._check_checkpoint_sequential(
212                    model,
213                    [list(model.children()), model],
214                    2,
215                    torch.randn(1, 100, requires_grad=True),
216                    use_reentrant=use_reentrant,
217                )
218
219    def test_checkpoint_module_list(self):
220        class ModuleListNet(nn.Module):
221            def __init__(self) -> None:
222                super().__init__()
223                module_list = [
224                    nn.Linear(100, 50),
225                    nn.ReLU(),
226                    nn.Linear(50, 20),
227                    nn.ReLU(),
228                    nn.Linear(20, 5),
229                    nn.ReLU(),
230                ]
231                self.module_list = nn.ModuleList(module_list)
232
233            def forward(self, input):
234                for layer in self.module_list:
235                    input = layer(input)
236                return input
237
238        for use_reentrant in [True, False]:
239            with self.subTest(use_reentrant=use_reentrant):
240                model = ModuleListNet()
241
242                # Compare uncheckpointed model with its checkpointed counterparts.
243                self._check_checkpoint_sequential(
244                    model,
245                    [list(model.module_list.children()), model.module_list],
246                    2,
247                    torch.randn(1, 100, requires_grad=True),
248                    use_reentrant=use_reentrant,
249                )
250
251    def test_checkpoint_sequential_deprecated_multiple_args(self):
252        class Two(nn.Module):
253            def forward(self, a, b):
254                return a, b
255
256        model = nn.Sequential(Two())
257        a = torch.randn(1, 100, requires_grad=True)
258        b = torch.randn(1, 100, requires_grad=True)
259
260        for use_reentrant in [True, False]:
261            with self.subTest(use_reentrant=use_reentrant):
262                with self.assertRaises(TypeError):
263                    checkpoint_sequential(model, 1, a, b)  # type: ignore[call-arg]
264
265    def test_checkpoint_sequential_deprecated_no_args(self):
266        class Noop(nn.Module):
267            def forward(self):
268                pass
269
270        model = nn.Sequential(Noop())
271        for use_reentrant in [True, False]:
272            with self.subTest(use_reentrant=use_reentrant):
273                with self.assertRaises(TypeError):
274                    checkpoint_sequential(model, 1)  # type: ignore[call-arg]
275
276    def test_checkpoint_rng_cpu(self):
277        for _ in range(5):
278            inp = torch.randn(20000, device="cpu").requires_grad_()
279            phase1 = torch.nn.Dropout()
280            phase2 = torch.nn.Dropout()
281
282            def run_fn(input):
283                return phase2(input)
284
285            state = torch.get_rng_state()
286
287            out = phase1(inp)
288            out = checkpoint(run_fn, out, use_reentrant=True)
289            out.sum().backward()
290            grad_with_checkpointing = inp.grad
291
292            torch.set_rng_state(state)
293
294            inp.grad = None
295
296            out = phase1(inp)
297            out = run_fn(out)
298            out.sum().backward()
299            grad_no_checkpointing = inp.grad
300
301            self.assertEqual(grad_with_checkpointing, grad_no_checkpointing)
302
303    @unittest.skipIf(not HAS_CUDA, "No CUDA")
304    def test_checkpoint_rng_cuda(self):
305        for _ in range(5):
306            inp = torch.randn(20000, device="cuda").requires_grad_()
307            phase1 = torch.nn.Dropout()
308            phase2 = torch.nn.Dropout()
309
310            def run_fn(input):
311                return phase2(input)
312
313            state = torch.cuda.get_rng_state()
314
315            out = phase1(inp)
316            out = checkpoint(run_fn, out, use_reentrant=True)
317            out.sum().backward()
318            grad_with_checkpointing = inp.grad
319
320            torch.cuda.set_rng_state(state)
321
322            inp.grad = None
323
324            out = phase1(inp)
325            out = run_fn(out)
326            out.sum().backward()
327            grad_no_checkpointing = inp.grad
328
329            self.assertEqual(grad_with_checkpointing, grad_no_checkpointing)
330
331    @unittest.skipIf(not HAS_CUDA, "No CUDA")
332    def test_checkpoint_not_preserve_rng_state_and_without_reentrant(self):
333        inp = torch.randn(2, device="cuda").requires_grad_()
334        layer = torch.nn.Dropout()
335
336        def run_fn(input):
337            return layer(input)
338
339        out = checkpoint(run_fn, inp, use_reentrant=False, preserve_rng_state=False)
340        out.sum().backward()
341        # This should run without error
342
343    def test_checkpoint_non_tensor(self):
344        def run_fn(tensor1, tensor2):
345            if tensor2 is None:
346                return tensor1
347            return tensor1 + tensor2
348
349        input_var = torch.randn(1, 100, requires_grad=True)
350        out = checkpoint(run_fn, input_var, None, use_reentrant=True)
351        out.sum().backward()
352
353    def test_checkpoint_non_tensor_inputs_outputs(self):
354        def foo(t1, t2, scale, t3):
355            t4 = t1 + t2 * t3
356            t5 = t1 * t2 + t3
357            t4 *= scale
358            t5 *= scale
359            return scale, t4, None, True, t5, "bar", t1
360
361        t1 = torch.rand(10, requires_grad=True)
362        t2 = torch.rand(10, requires_grad=True)
363        t3 = torch.rand(10)
364        scale = random.randint(0, 10)
365        res = checkpoint(foo, t1, t2, scale, t3, use_reentrant=True)
366        self.assertEqual(scale, res[0])
367        self.assertEqual((t1 + t2 * t3) * scale, res[1])
368        self.assertEqual(None, res[2])
369        self.assertEqual(True, res[3])
370        self.assertEqual((t1 * t2 + t3) * scale, res[4])
371        self.assertEqual("bar", res[5])
372        self.assertEqual(t1, res[6])
373
374        # Validate running backward.
375        res[1].sum().backward(retain_graph=True)
376        res[4].sum().backward(retain_graph=True)
377        res[6].sum().backward()
378        with self.assertRaisesRegex(
379            RuntimeError, "Trying to backward through the graph a second time"
380        ):
381            res[6].sum().backward()
382        t1_grad = t1.grad
383        t2_grad = t2.grad
384
385        # Reset grads, run without checkpoint and validate we receive same grads.
386        t1.grad = None
387        t2.grad = None
388        res = foo(t1, t2, scale, t3)
389        torch.autograd.backward([res[1].sum(), res[4].sum(), res[6].sum()])
390        self.assertEqual(t1.grad, t1_grad)
391        self.assertEqual(t2.grad, t2_grad)
392
393    def test_checkpoint_no_tensors(self):
394        def foo(t1, t2, scale, t3):
395            t4 = t1 + t2 * t3
396            t5 = t1 * t2 + t3
397            t4 *= scale
398            t5 *= scale
399            return scale, t4, None, True, t5, "bar", t1
400
401        t1 = random.random()
402        t2 = random.random()
403        t3 = random.random()
404        scale = random.randint(0, 10)
405        res = checkpoint(foo, t1, t2, scale, t3, use_reentrant=True)
406        self.assertEqual(scale, res[0])
407        self.assertEqual((t1 + t2 * t3) * scale, res[1])
408        self.assertEqual(None, res[2])
409        self.assertEqual(True, res[3])
410        self.assertEqual((t1 * t2 + t3) * scale, res[4])
411        self.assertEqual("bar", res[5])
412        self.assertEqual(t1, res[6])
413
414    def test_checkpoint_partial_grad(self):
415        def run_fn(tensor1, tensor2):
416            # tensor 2 is used for other application logic
417            return tensor1, tensor2
418
419        input_var = torch.randn(1, 4, requires_grad=True)
420        input_var2 = torch.randn(1, 4, requires_grad=False)
421        out = checkpoint(run_fn, input_var, input_var2, use_reentrant=True)
422        out[0].sum().backward()
423
424        def run_fn2(tensor1, tensor2):
425            return tensor1
426
427        input_var = torch.randn(1, 4, requires_grad=False)
428        input_var2 = torch.randn(1, 4, requires_grad=True)
429        with self.assertRaisesRegex(
430            RuntimeError,
431            r"none of output has requires_grad=True, this checkpoint\(\) is not necessary",
432        ):
433            out = checkpoint(run_fn2, input_var, input_var2, use_reentrant=True)
434            out.sum().backward()
435
436    @unittest.skipIf(not torch.cuda.is_available(), "Test requires CUDA")
437    def test_checkpointing_without_reentrant_early_free(self):
438        # I don't know how to check if the temporary saved variable buffer
439        # get de-allocated directly. So using cuda memory usage as a proxy
440
441        def _do_test(fn, should_free):
442            stats: List[int] = []
443
444            def track(x, idx):
445                # Track that at each step of the backward, some Tensor were
446                # de-allocated (which correspond to the checkpoint storage being
447                # emptied at each step)
448                def hook(_unused):
449                    self.assertEqual(len(stats), idx)
450                    torch.cuda.synchronize()
451                    stats.append(torch.cuda.memory_allocated())
452                    if idx > 0:
453                        if should_free:
454                            self.assertLess(stats[idx], stats[idx - 1])
455                        else:
456                            self.assertEqual(stats[idx], stats[idx - 1])
457
458                x.register_hook(hook)
459
460            def test_fn(x):
461                # The main property of this function is that it contains multiple
462                # operations that save gradients in a chain.
463                x = x**2
464                track(x, 2)
465                x = x**2
466                track(x, 1)
467                x = x**2
468                track(x, 0)
469                x = x**2
470                return x.sum()
471
472            fn(test_fn)
473
474            return stats
475
476        x = torch.zeros(10, device="cuda", requires_grad=True)
477        x.grad = torch.zeros_like(x)
478
479        # In a regular backward, buffers get eagerly freed
480        non_retain_stats = _do_test(lambda fn: fn(x).backward(), True)
481
482        # In a retain_grad backward, buffers get preserved
483        _unused_retain_stats = _do_test(
484            lambda fn: fn(x).backward(retain_graph=True), False
485        )
486
487        # In a regular backward with checkpoint, buffers get eagerly freed
488        checkpoint_non_retain_stats = _do_test(
489            lambda fn: checkpoint(fn, x, use_reentrant=False).backward(), True
490        )
491
492        # In a retain_grad backward with checkpoint, buffers get eagerly freed
493        checkpoint_retain_stats = _do_test(
494            lambda fn: checkpoint(fn, x, use_reentrant=False).backward(
495                retain_graph=True
496            ),
497            True,
498        )
499
500        self.assertEqual(non_retain_stats, checkpoint_non_retain_stats)
501        self.assertEqual(non_retain_stats, checkpoint_retain_stats)
502
503    @unittest.skipIf(not TEST_MULTIGPU, "multi-GPU not supported")
504    def test_get_device_states_recursive(self):
505        inp = {
506            "foo": torch.rand(10, device="cuda:0"),
507            "bar": [torch.rand(10, device="cuda:1")],
508        }
509        device_ids, device_states = get_device_states(inp)
510        self.assertEqual(2, len(device_ids))
511        self.assertEqual(2, len(device_states))
512        self.assertEqual(0, device_ids[0])
513        self.assertEqual(1, device_ids[1])
514        self.assertTrue(isinstance(device_states[0], torch.Tensor))
515        self.assertTrue(isinstance(device_states[1], torch.Tensor))
516
517    def test_infer_device_state_recursive_meta(self):
518        inp = {"foo": torch.rand(10, device="meta")}
519        device_type = _infer_device_type(inp)
520        self.assertEqual("meta", device_type)
521
522    @unittest.skipIf(not TEST_MULTIGPU, "multi-GPU not supported")
523    def test_infer_device_state_recursive_multi_cuda(self):
524        # Check that no warning is issued for either cuda:0, cuda:1 or
525        # cuda:0, cuda:0 cases since they are both the same device type
526        inp = {
527            "foo": torch.rand(10, device="cuda:0"),
528            "bar": [torch.rand(10, device="cuda:1")],
529        }
530        with warnings.catch_warnings():
531            warnings.simplefilter("error")
532            device_type = _infer_device_type(inp)
533            self.assertEqual("cuda", device_type)
534        inp = {
535            "foo": torch.rand(10, device="cuda:0"),
536            "bar": [torch.rand(10, device="cuda:0")],
537        }
538        with warnings.catch_warnings():
539            warnings.simplefilter("error")
540            device_type = _infer_device_type(inp)
541            self.assertEqual("cuda", device_type)
542        # Check that a warning is issued for cuda:0, meta and that it includes
543        # device type information
544        inp = {
545            "foo": torch.rand(10, device="cuda:0"),
546            "bar": [torch.rand(10, device="meta")],
547        }
548        with warnings.catch_warnings(record=True) as w:
549            device_type = _infer_device_type(inp)
550            self.assertEqual("cuda", device_type)
551        self.assertEqual(len(w), 1)
552        warning_msg = str(w[-1].message)
553        self.assertTrue(
554            "Tensor arguments, excluding CPU tensors, are detected on at least two types of devices"
555            in warning_msg
556        )
557        self.assertTrue("Device types: ['cuda', 'meta']" in warning_msg)
558        self.assertTrue("first device type: cuda" in warning_msg)
559
560
561class TestDataLoaderUtils(TestCase):
562    MAX_TIMEOUT_IN_SECOND = 300
563
564    def setUp(self):
565        super().setUp()
566        self.dataset = torch.randn(5, 3, 3, 2)
567        self.batch_size = 3
568
569    def test_random_seed(self):
570        def run():
571            dataloader = torch.utils.data.DataLoader(
572                RandomDatasetMock(),
573                batch_size=2,
574                num_workers=4,
575                shuffle=True,
576                timeout=self.MAX_TIMEOUT_IN_SECOND,
577            )
578            return next(iter(dataloader))
579
580        torch.manual_seed(2018)
581        x1 = run()
582        torch.manual_seed(2018)
583        x2 = run()
584        self.assertEqual(x1, x2)
585
586    def test_single_keep(self):
587        # self.dataset is a Tensor here; technically not a valid input because
588        # not a Dataset subclass, but needs to stay working so add ignore's
589        # for type checking with mypy
590        dataloader: DataLoader = DataLoader(
591            self.dataset,  # type: ignore[arg-type]
592            batch_size=self.batch_size,
593            num_workers=0,
594            drop_last=False,
595        )
596        dataiter = iter(dataloader)
597        self.assertEqual(len(list(dataiter)), 2)
598
599    def test_single_drop(self):
600        dataloader: DataLoader = DataLoader(
601            self.dataset,  # type: ignore[arg-type]
602            batch_size=self.batch_size,
603            num_workers=0,
604            drop_last=True,
605        )
606        dataiter = iter(dataloader)
607        self.assertEqual(len(list(dataiter)), 1)
608
609    @unittest.skip(
610        "FIXME: Intermittent CUDA out-of-memory error on Windows and time-out under ASAN"
611    )
612    def test_multi_keep(self):
613        dataloader: DataLoader = DataLoader(
614            self.dataset,  # type: ignore[arg-type]
615            batch_size=self.batch_size,
616            num_workers=2,
617            drop_last=False,
618            timeout=self.MAX_TIMEOUT_IN_SECOND,
619        )
620        dataiter = iter(dataloader)
621        self.assertEqual(len(list(dataiter)), 2)
622
623    def test_multi_drop(self):
624        dataloader: DataLoader = DataLoader(
625            self.dataset,  # type: ignore[arg-type]
626            batch_size=self.batch_size,
627            num_workers=2,
628            drop_last=True,
629            timeout=self.MAX_TIMEOUT_IN_SECOND,
630        )
631        dataiter = iter(dataloader)
632        self.assertEqual(len(list(dataiter)), 1)
633
634
635test_dir = os.path.abspath(os.path.dirname(str(__file__)))
636
637
638@unittest.skipIf(
639    "SKIP_TEST_BOTTLENECK" in os.environ.keys(), "SKIP_TEST_BOTTLENECK is set"
640)
641class TestBottleneck(TestCase):
642    def _run(self, command, timeout=30):
643        """Returns (return-code, stdout, stderr)"""
644        import subprocess
645
646        p = subprocess.Popen(
647            command,
648            stdout=subprocess.PIPE,
649            stderr=subprocess.PIPE,
650            shell=True,
651        )
652        try:
653            output, err = p.communicate(timeout=timeout)
654        except subprocess.TimeoutExpired:
655            p.kill()
656            output, err = p.communicate()
657        rc = p.returncode
658        output_str = output.decode("ascii")
659        err_str = err.decode("ascii")
660        return (rc, output_str, err_str)
661
662    def _run_bottleneck(self, test_file, scriptargs=""):
663        curdir = os.path.dirname(os.path.abspath(__file__))
664        filepath = f"{curdir}/{test_file}"
665        if scriptargs != "":
666            scriptargs = f" {scriptargs}"
667        rc, out, err = self._run(
668            f"{sys.executable} -m torch.utils.bottleneck {filepath}{scriptargs}"
669        )
670        return rc, out, err
671
672    def _check_run_args(self):
673        # Check that this fails due to missing args
674        rc, out, err = self._run_bottleneck("bottleneck_test/test_args.py")
675        self.assertEqual(
676            rc,
677            2,
678            atol=0,
679            rtol=0,
680            msg=self._fail_msg("Missing args should error", out + err),
681        )
682
683        # This should succeed
684        rc, out, err = self._run_bottleneck(
685            "bottleneck_test/test_args.py", "--foo foo --bar bar"
686        )
687        self.assertEqual(
688            rc,
689            0,
690            atol=0,
691            rtol=0,
692            msg=self._fail_msg("Should pass args to script", out + err),
693        )
694
695    def _fail_msg(self, msg, output):
696        return f"{msg}, output was:\n{output}"
697
698    def _check_environment_summary(self, output):
699        results = re.search("Environment Summary", output)
700        self.assertIsNotNone(
701            results, self._fail_msg("Should have Environment Summary", output)
702        )
703
704        # Up to five lines away from the heading, there should be the version number
705        results = re.search(
706            r"Environment Summary.*(\n.*){,5}\nPyTorch \d+\.\d+", output
707        )
708        self.assertIsNotNone(
709            results, self._fail_msg("Should have PyTorch version", output)
710        )
711
712    def _check_cprof_summary(self, output):
713        results = re.search("cProfile output", output)
714        self.assertIsNotNone(
715            results, self._fail_msg("Should have cProfile output", output)
716        )
717
718        # This assumes that after the cProfile output section we have
719        # the autograd profiler output
720        results = re.search(
721            r"cProfile output.*(\n.*){6,50}\n.*autograd profiler output", output
722        )
723        self.assertIsNotNone(
724            results,
725            self._fail_msg(
726                "Distance between cProfile and autograd prof out not in [6, 50] lines",
727                output,
728            ),
729        )
730
731    def _check_autograd_summary(self, output):
732        results = re.search("autograd profiler output", output)
733        self.assertIsNotNone(
734            results, self._fail_msg("Should have autograd profiler output", output)
735        )
736
737        # This assumes that after the autograd profiler output is the end of the
738        # output.
739        results = re.search(r"autograd profiler output.*(\n.*){6,100}", output)
740        self.assertIsNotNone(
741            results,
742            self._fail_msg(
743                "Distance between autograd prof output and end of output not in [6, 100] lines",
744                output,
745            ),
746        )
747
748    def _check_cuda(self, output):
749        if HAS_CUDA:
750            results = re.search("CUDA mode", output)
751            self.assertIsNotNone(
752                results, self._fail_msg("Should tell users CUDA", output)
753            )
754        else:
755            results = re.search("CUDA mode", output)
756            self.assertIsNone(
757                results, self._fail_msg("Should not tell users about CUDA", output)
758            )
759
760    @unittest.skipIf(HAS_CUDA, "CPU-only test")
761    def test_bottleneck_cpu_only(self):
762        rc, out, err = self._run_bottleneck("bottleneck_test/test.py")
763        self.assertEqual(rc, 0, msg=f"Run failed with\n{err}")
764
765        self._check_run_args()
766        self._check_environment_summary(out)
767        self._check_autograd_summary(out)
768        self._check_cprof_summary(out)
769        self._check_cuda(out)
770
771    @unittest.skipIf(not HAS_CUDA, "No CUDA")
772    def test_bottleneck_cuda(self):
773        rc, out, err = self._run_bottleneck("bottleneck_test/test_cuda.py")
774        self.assertEqual(rc, 0, msg=f"Run failed with\n{err}")
775
776        self._check_run_args()
777        self._check_environment_summary(out)
778        self._check_autograd_summary(out)
779        self._check_cprof_summary(out)
780        self._check_cuda(out)
781
782
783from torch.utils.collect_env import get_pretty_env_info
784
785
786@unittest.skipIf(IS_FBCODE, "runs pip which is not available internally")
787class TestCollectEnv(TestCase):
788    def test_smoke(self):
789        info_output = get_pretty_env_info()
790        self.assertTrue(info_output.count("\n") >= 17)
791
792
793class TestONNXUtils(TestCase):
794    def test_prepare_onnx_paddings(self):
795        sizes = [2, 3, 4]
796        pad = [1, 2, 3, 4]
797        paddings = _prepare_onnx_paddings(len(sizes), pad)
798        self.assertEqual(paddings, [0, 3, 1, 0, 4, 2])
799
800    def test_check_onnx_broadcast(self):
801        def try_check_onnx_broadcast(dims1, dims2, expect_broadcast, expect_fail):
802            broadcast = True
803            fail = False
804            try:
805                broadcast = check_onnx_broadcast(dims1, dims2)
806            except ValueError:
807                fail = True
808            self.assertEqual(broadcast, expect_broadcast)
809            self.assertEqual(fail, expect_fail)
810
811        # Case 1, check the case when len(dims1) < len(dims2) and numel(dims2) > 1
812        dims1 = [3, 4]
813        dims2 = [2, 3, 4]
814        try_check_onnx_broadcast(dims1, dims2, True, True)
815
816        # Case 2, check the case when len(dims1) < len(dims2) and numel(dims2) == 1
817        dims1 = [3, 4]
818        dims2 = [1, 1, 1]
819        try_check_onnx_broadcast(dims1, dims2, True, False)
820
821        # Case 3, check the case when len(dims1) > len(dims2) and numel(dims2) == 1
822        dims1 = [1, 1]
823        dims2 = [1]
824        try_check_onnx_broadcast(dims1, dims2, True, False)
825
826        # Case 4, check the case when len(dims1) > len(dims2) and dims1[x:] == dims2
827        dims1 = [2, 3, 4]
828        dims2 = [3, 4]
829        try_check_onnx_broadcast(dims1, dims2, True, False)
830
831        # Case 5, check the case when len(dims1) > len(dims2), but dims1[x:] != dims2
832        dims1 = [2, 3, 4]
833        dims2 = [1, 4]
834        try_check_onnx_broadcast(dims1, dims2, True, True)
835
836        # Case 6, check the equal case, no broadcast
837        dims1 = [3, 4]
838        dims2 = [3, 4]
839        try_check_onnx_broadcast(dims1, dims2, False, False)
840
841        # Case 7, check the case when len(dims1) == len(dims2), but dims1 != dims2
842        dims1 = [3, 4]
843        dims2 = [1, 4]
844        try_check_onnx_broadcast(dims1, dims2, True, True)
845
846        # Case 8, check the case when len(dims1) == len(dims2) and numel(s2) == 1
847        dims1 = [3, 4]
848        dims2 = [1, 1]
849        try_check_onnx_broadcast(dims1, dims2, True, False)
850
851
852class TestHipify(TestCase):
853    def test_import_hipify(self):
854        from torch.utils.hipify import hipify_python  # noqa: F401
855
856
857class TestHipifyTrie(TestCase):
858    def setUp(self):
859        self.trie = torch.utils.hipify.hipify_python.Trie()
860
861    def test_add_and_search_trie(self):
862        self.trie.add("banana")
863        self.assertTrue(self.trie.search("banana"))
864        self.assertFalse(self.trie.search("ban"))
865        self.assertFalse(self.trie.search("dog"))
866
867    def test_add_multiple_and_search_trie(self):
868        words_to_add = ["banana", "apple", "orange"]
869        for word in words_to_add:
870            self.trie.add(word)
871
872        for word in words_to_add:
873            self.assertTrue(self.trie.search(word))
874
875        for word in ["ban", "dog", "okay", "app"]:
876            self.assertFalse(self.trie.search(word))
877
878    def test_quote_escape(self):
879        orig_chars = ["*", "[", ".", "+", "a", "z", "-"]
880        quoted_strs = ["\\*", "\\[", "\\.", "\\+", "a", "z", "\\-"]
881        for i in range(len(orig_chars)):
882            self.assertEqual(self.trie.quote(orig_chars[i]), quoted_strs[i])
883
884    def test_export_trie_to_regex(self):
885        words_to_add = [
886            "__CUDACC__",
887            "CUDA_ERROR_CONTEXT_ALREADY_CURRENT",
888            "CUDA_ERROR_ARRAY_IS_MAPPED",
889            "CUDA_ERROR_NOT_MAPPED",
890            "CUDA_ERROR_INVALID_SOURCE",
891        ]
892        for word in words_to_add:
893            self.trie.add(word)
894        regex = self.trie.export_to_regex()
895        expected_regex = r"(?:CUDA_ERROR_(?:ARRAY_IS_MAPPED|CONTEXT_ALREADY_CURRENT|INVALID_SOURCE|NOT_MAPPED)|__CUDACC__)"
896        self.assertEqual(regex, expected_regex)
897
898    def test_prefix_words_export_trie_to_regex(self):
899        # test case where some nodes have both children and are also leaf nodes.
900        words_to_add = ["apple", "app", "ban", "banana"]
901        for word in words_to_add:
902            self.trie.add(word)
903        regex = self.trie.export_to_regex()
904        expected_regex = r"(?:app(?:le)?|ban(?:ana)?)"
905        self.assertEqual(regex, expected_regex)
906
907    def test_single_export_trie_to_regex(self):
908        words_to_add = ["cudaErrorInvalidMemcpyDirection"]
909        for word in words_to_add:
910            self.trie.add(word)
911        regex = self.trie.export_to_regex()
912        expected_regex = "cudaErrorInvalidMemcpyDirection"
913        self.assertEqual(regex, expected_regex)
914
915    def test_char_export_trie_to_regex(self):
916        self.trie.add("a")
917        self.assertEqual(self.trie.export_to_regex(), "a")
918        self.trie.add("b")
919        self.assertEqual(self.trie.export_to_regex(), "[ab]")
920
921    def test_special_char_export_trie_to_regex(self):
922        self.trie.add(r"c*")
923        self.assertEqual(self.trie.export_to_regex(), r"c\*")
924
925
926class TestAssert(TestCase):
927    def test_assert_true(self):
928        # verify assertions work as expected
929        # bool argument
930        torch._assert(True, "foo")
931        with self.assertRaisesRegex(AssertionError, "bar"):
932            torch._assert(False, "bar")
933        # tensor argument
934        torch._assert(torch.tensor([True], dtype=torch.bool), "foo")
935        with self.assertRaisesRegex(AssertionError, "bar"):
936            torch._assert(torch.tensor([False], dtype=torch.bool), "bar")
937
938    def test_assert_scriptable(self):
939        class M(torch.nn.Module):
940            def forward(self, x):
941                torch._assert(x.sum() > 0, "foo")
942                return x
943
944        m = M()
945        # scriptable
946        ms = torch.jit.script(m)
947        # data can be passed without errors
948        x = torch.randn(4, 4).fill_(1.0)
949        ms(x)
950        with self.assertRaisesRegex(torch.jit.Error, "foo"):
951            ms(torch.tensor([False], dtype=torch.bool))
952
953
954@unittest.skipIf(IS_SANDCASTLE, "cpp_extension is OSS only")
955class TestStandaloneCPPJIT(TestCase):
956    def test_load_standalone(self):
957        build_dir = tempfile.mkdtemp()
958        try:
959            src_path = os.path.join(build_dir, "main.cpp")
960            src = textwrap.dedent(
961                """\
962                #include <iostream>
963                #include <torch/torch.h>
964                int main() {
965                    auto x = torch::eye(3);
966                    std::cout << x << std::endl;
967                }
968            """
969            )
970            with open(src_path, "w") as f:
971                f.write(src)
972
973            exec_path = torch.utils.cpp_extension.load(
974                "standalone_load_test",
975                src_path,
976                build_directory=build_dir,
977                is_python_module=False,
978                is_standalone=True,
979            )
980
981            ext = ".exe" if IS_WINDOWS else ""
982            self.assertEqual(
983                exec_path, os.path.join(build_dir, f"standalone_load_test{ext}")
984            )
985
986            for shell in [True, False]:
987                r = subprocess.run(
988                    [exec_path],
989                    shell=shell,
990                    stdout=subprocess.PIPE,
991                )
992                self.assertEqual(r.returncode, 0)
993                self.assertEqual(
994                    # Windows prints "\r\n" for newlines.
995                    textwrap.dedent(r.stdout.decode("utf-8")).replace("\r\n", "\n"),
996                    textwrap.dedent(
997                        """\
998                     1  0  0
999                     0  1  0
1000                     0  0  1
1001                    [ CPUFloatType{3,3} ]
1002                    """
1003                    ),
1004                )
1005
1006        finally:
1007            shutil.rmtree(build_dir)
1008
1009
1010class DummyPrivateUse1Module:
1011    @staticmethod
1012    def is_available():
1013        return True
1014
1015    @staticmethod
1016    def is_autocast_enabled():
1017        return True
1018
1019    @staticmethod
1020    def get_autocast_dtype():
1021        return torch.float16
1022
1023    @staticmethod
1024    def set_autocast_enabled(enable):
1025        pass
1026
1027    @staticmethod
1028    def set_autocast_dtype(dtype):
1029        pass
1030
1031    @staticmethod
1032    def get_amp_supported_dtype():
1033        return [torch.float16]
1034
1035
1036class TestExtensionUtils(TestCase):
1037    def tearDown(self):
1038        # Clean up
1039        backend_name = torch._C._get_privateuse1_backend_name()
1040        if hasattr(torch, backend_name):
1041            delattr(torch, backend_name)
1042        if f"torch.{backend_name}" in sys.modules:
1043            del sys.modules[f"torch.{backend_name}"]
1044
1045    def test_external_module_register(self):
1046        # Built-in module
1047        with self.assertRaisesRegex(RuntimeError, "The runtime module of"):
1048            torch._register_device_module("cuda", torch.cuda)
1049
1050        # Wrong device type
1051        with self.assertRaisesRegex(RuntimeError, "Expected one of cpu"):
1052            torch._register_device_module("dummmy", DummyPrivateUse1Module)
1053
1054        with self.assertRaises(AttributeError):
1055            torch.privateuseone.is_available()  # type: ignore[attr-defined]
1056
1057        torch._register_device_module("privateuseone", DummyPrivateUse1Module)
1058
1059        torch.privateuseone.is_available()  # type: ignore[attr-defined]
1060
1061        # No supporting for override
1062        with self.assertRaisesRegex(RuntimeError, "The runtime module of"):
1063            torch._register_device_module("privateuseone", DummyPrivateUse1Module)
1064
1065    def test_external_module_register_with_renamed_backend(self):
1066        torch.utils.rename_privateuse1_backend("foo")
1067        with self.assertRaisesRegex(RuntimeError, "has already been set"):
1068            torch.utils.rename_privateuse1_backend("dummmy")
1069
1070        custom_backend_name = torch._C._get_privateuse1_backend_name()
1071        self.assertEqual(custom_backend_name, "foo")
1072
1073        with self.assertRaises(AttributeError):
1074            torch.foo.is_available()  # type: ignore[attr-defined]
1075
1076        with self.assertRaisesRegex(AssertionError, "Tried to use AMP with the"):
1077            with torch.autocast(device_type=custom_backend_name):
1078                pass
1079        torch._register_device_module("foo", DummyPrivateUse1Module)
1080
1081        torch.foo.is_available()  # type: ignore[attr-defined]
1082        with torch.autocast(device_type=custom_backend_name):
1083            pass
1084
1085        self.assertEqual(torch._utils._get_device_index("foo:1"), 1)
1086        self.assertEqual(torch._utils._get_device_index(torch.device("foo:2")), 2)
1087
1088
1089class TestRenderUtils(TestCase):
1090    def test_basic(self):
1091        self.assertExpectedInline(
1092            torch._utils.render_call(torch.sum, [torch.randn(100)], {"dim": 0}),
1093            """torch.sum(tensor([...], size=(100,)), dim=0)""",
1094        )
1095        self.assertExpectedInline(
1096            torch._utils.render_call(torch.sum, [torch.randn(100, 100)], {"dim": 0}),
1097            """torch.sum(tensor([...], size=(100, 100)), dim=0)""",
1098        )
1099
1100
1101class TestDeviceUtils(TestCase):
1102    def test_basic(self):
1103        with torch.device("meta") as dev:
1104            x = torch.empty(3, 3)
1105        self.assertEqual(x.device.type, "meta")
1106        self.assertEqual(dev, torch.device("meta"))
1107
1108    def test_decorator(self):
1109        @set_device("meta")
1110        def f():
1111            return torch.empty(3, 3)
1112
1113        self.assertEqual(f().device.type, "meta")
1114
1115    def test_decorator_generator(self):
1116        @set_device("meta")
1117        def f():
1118            yield torch.empty(3, 3)
1119            yield torch.empty(3, 3)
1120
1121        r1, r2 = list(f())
1122        self.assertEqual(r1.device.type, "meta")
1123        self.assertEqual(r2.device.type, "meta")
1124
1125    def test_nn_module(self):
1126        with torch.device("meta"):
1127            m = nn.Linear(40, 50)
1128        self.assertEqual(m.weight.device.type, "meta")
1129
1130    def test_set_default_device(self):
1131        try:
1132            torch.set_default_device("meta")
1133            r = torch.empty(2, 2)
1134        finally:
1135            torch.set_default_device(None)
1136
1137        self.assertEqual(r.device.type, "meta")
1138
1139    def test_get_default_device(self):
1140        torch.set_default_device("meta")
1141        self.assertEqual(torch.get_default_device().type, "meta")
1142        torch.set_default_device(None)
1143
1144    @unittest.skipIf(not TEST_MULTIGPU, "multi-GPU not supported")
1145    def test_get_default_device_more(self):
1146        torch.set_default_device("cuda")
1147        self.assertEqual(torch.get_default_device(), torch.tensor([]).device)
1148        torch.set_default_device(None)
1149
1150        torch.set_default_device("cuda")
1151        torch.cuda.set_device("cuda:1")
1152        self.assertEqual(torch.get_default_device(), torch.tensor([]).device)
1153        torch.set_default_device(None)
1154
1155        torch.set_default_device("cuda:1")
1156        self.assertEqual(torch.get_default_device(), torch.tensor([]).device)
1157        torch.set_default_device(None)
1158
1159    @onlyCPU
1160    @ops(op_db)
1161    def test_device_mode_ops(self, device, dtype, op):
1162        func = op.get_op()
1163        samples = op.sample_inputs(device, dtype, requires_grad=False)
1164        for sample in samples:
1165            # Only test samples which don't have Tensor inputs.  However,
1166            # we don't test the factory property on OpInfo as it is very,
1167            # very incomplete
1168            if tree_any(
1169                lambda x: isinstance(x, torch.Tensor),
1170                (sample.input, sample.args, sample.kwargs),
1171            ):
1172                continue
1173            # Many OpInfos will explicitly pass in a device.  DeviceContext
1174            # will respect device if it is explicitly specified.  To test
1175            # DeviceContext, we have to remove the device kwarg in this case.
1176            # NB: Can't pass None to sample_inputs, the function can't
1177            # handle it.
1178            kwargs = sample.kwargs.copy()
1179            kwargs.pop("device", None)
1180            with torch.device("meta"):
1181                r = func(sample.input, *sample.args, **kwargs)
1182
1183            def is_meta_device(x: torch.Tensor) -> bool:
1184                return x.device.type == "meta"
1185
1186            self.assertTrue(tree_all_only(torch.Tensor, is_meta_device, r))
1187
1188
1189instantiate_device_type_tests(TestDeviceUtils, globals())
1190
1191
1192class TestCppExtensionUtils(TestCase):
1193    def test_cpp_compiler_is_ok(self):
1194        self.assertTrue(torch.utils.cpp_extension.check_compiler_ok_for_platform("c++"))
1195
1196    def test_cc_compiler_is_ok(self):
1197        self.assertTrue(torch.utils.cpp_extension.check_compiler_ok_for_platform("cc"))
1198
1199
1200class TestTraceback(TestCase):
1201    def test_basic(self):
1202        source = """\
1203def f(x):
1204    def g(x):
1205        raise RuntimeError  # HEYA
1206
1207    x = x * 3
1208    return g(x) + 1
1209"""
1210
1211        out: Dict[str, Any] = {}
1212        scope = {"__compile_source__": source}
1213        exec(source, scope, out)
1214
1215        try:
1216            with report_compile_source_on_error():
1217                out["f"](1)
1218        except RuntimeError as e:
1219            self.assertIn("HEYA", "".join(traceback.format_tb(e.__traceback__)))
1220
1221    def test_format_traceback_short(self):
1222        try:
1223            raise RuntimeError
1224        except RuntimeError as e:
1225            self.assertRegex(
1226                format_traceback_short(e.__traceback__),
1227                r".*test_utils.py:\d+ in test_format_traceback_short",
1228            )
1229
1230    def test_captured_traceback(self):
1231        self.assertIn(
1232            "test_captured_traceback", "".join(CapturedTraceback.extract().format())
1233        )
1234
1235    def test_captured_traceback_format_all(self):
1236        rs = CapturedTraceback.format_all(
1237            [CapturedTraceback.extract(), CapturedTraceback.extract()]
1238        )
1239        self.assertEqual(len(rs), 2)
1240        self.assertIn("test_captured_traceback_format_all", "".join(rs[0]))
1241
1242    def test_captured_traceback_format_all_cached(self):
1243        tb = CapturedTraceback.extract()
1244        tb.format()  # cached
1245        rs = CapturedTraceback.format_all([tb, CapturedTraceback.extract()])
1246        self.assertEqual(len(rs), 2)
1247        self.assertIn("test_captured_traceback_format_all", "".join(rs[0]))
1248
1249
1250if __name__ == "__main__":
1251    run_tests()
1252