# mypy: allow-untyped-defs
# Owner(s): ["module: unknown"]

import os
import random
import re
import shutil
import subprocess
import sys
import tempfile
import textwrap
import traceback
import unittest
import warnings
from typing import Any, Dict, List

import torch
import torch.cuda
import torch.nn as nn
import torch.utils.cpp_extension
import torch.utils.data
from torch.autograd._functions.utils import check_onnx_broadcast
from torch.onnx.symbolic_opset9 import _prepare_onnx_paddings
from torch.testing._internal.common_cuda import TEST_MULTIGPU
from torch.testing._internal.common_device_type import (
    instantiate_device_type_tests,
    onlyCPU,
    ops,
)
from torch.testing._internal.common_methods_invocations import op_db
from torch.testing._internal.common_utils import (  # type: ignore[attr-defined]
    IS_FBCODE,
    IS_SANDCASTLE,
    IS_WINDOWS,
    load_tests,
)
from torch.utils._device import set_device
from torch.utils._pytree import tree_all_only, tree_any
from torch.utils._traceback import (
    CapturedTraceback,
    format_traceback_short,
    report_compile_source_on_error,
)
from torch.utils.checkpoint import (
    _infer_device_type,
    checkpoint,
    checkpoint_sequential,
    get_device_states,
)
from torch.utils.data import DataLoader


# load_tests from torch.testing._internal.common_utils is used to automatically filter tests for
# sharding on sandcastle. This line silences flake warnings
load_tests = load_tests

HAS_CUDA = torch.cuda.is_available()


from torch.testing._internal.common_utils import run_tests, TestCase


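# Dataset whose items combine torch's RNG (torch.rand) with Python's random
# module; TestDataLoaderUtils.test_random_seed below uses it to check that
# seeding makes multi-worker DataLoader output reproducible.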
class RandomDatasetMock(torch.utils.data.Dataset):
    def __getitem__(self, index):
        return torch.tensor([torch.rand(1).item(), random.uniform(0, 1)])

    def __len__(self):
        return 1000


class TestCheckpoint(TestCase):
    # This runs checkpoint_sequential on each of the nets in
    # module_lists_to_compare, and compares them against the uncheckpointed model.
    # To compare, it checks outputs as well as input gradients and parameter gradients
    def _check_checkpoint_sequential(
        self,
        model,
        module_lists_to_compare,
        num_chunks,
        input,
        use_reentrant,
    ):
        # not checkpointed
        out = model(input)
        out_not_checkpointed = out.detach().clone()
        model.zero_grad()
        out.sum().backward()
        grad_not_checkpointed = {
            name: param.grad.detach().clone()
            for name, param in model.named_parameters()
        }
        input_grad_not_checkpointed = input.grad.detach().clone()
        for model_to_compare in module_lists_to_compare:
            # checkpointed model by passing list of modules
            detached = input.detach()
            detached.requires_grad = True

            # pass list of modules to checkpoint
            out = checkpoint_sequential(
                model_to_compare, num_chunks, detached, use_reentrant=use_reentrant
            )
            out_checkpointed = out.detach().clone()
            model.zero_grad()
            out.sum().backward()
            grad_checkpointed = {
                name: param.grad.detach().clone()
                for name, param in model.named_parameters()
            }
            input_grad_checkpointed = detached.grad.detach().clone()
            # compare outputs as well as the gradients of input and parameters
            self.assertEqual(out_checkpointed, out_not_checkpointed)
            self.assertEqual(input_grad_not_checkpointed, input_grad_checkpointed)
            for name in grad_checkpointed:
                self.assertEqual(grad_checkpointed[name], grad_not_checkpointed[name])

    # Test whether checkpoint is being triggered or not. For this, we check
    # the number of times the forward pass runs
    def test_checkpoint_trigger(self):
        class Net(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.counter = 0

            def forward(self, input_var):
                self.counter += 1
                # For reentrant, need to have autograd actually
                # pack a tensor to trigger recomp
                ret = input_var * torch.tensor(2.0)
                return ret

        # checkpointed
        for use_reentrant in [True, False]:
            with self.subTest(use_reentrant=use_reentrant):
                modules = [Net() for _ in range(10)]
                for m in modules:
                    self.assertEqual(m.counter, 0)
                input_var = torch.randn(3, 4, requires_grad=True)
                out = checkpoint_sequential(
                    modules, 2, input_var, use_reentrant=use_reentrant
                )
                for m in modules:
                    self.assertEqual(m.counter, 1)
                out.sum().backward()
                for m in modules[: (len(modules) // 2)]:
                    self.assertEqual(m.counter, 2)
                for m in modules[(len(modules) // 2) :]:
                    self.assertEqual(m.counter, 1)

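    # Reentrant checkpointing only supports .backward(); calling
    # torch.autograd.grad() on its output should raise. The non-reentrant
    # variant supports .grad() and should match the non-checkpointed gradients.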
    def test_checkpoint_valid(self):
        model = nn.Sequential(
            nn.Linear(100, 50),
            nn.ReLU(),
            nn.Linear(50, 20),
            nn.ReLU(),
            nn.Linear(20, 5),
            nn.ReLU(),
        )

        input_var = torch.randn(1, 100, requires_grad=True)

        # checkpointed
        chunks = 2
        modules = list(model.children())
        out = checkpoint_sequential(modules, chunks, input_var, use_reentrant=True)
        with self.assertRaisesRegex(
            RuntimeError, "torch.utils.checkpoint is incompatible"
        ):
            torch.autograd.grad(
                outputs=[out],
                grad_outputs=[torch.ones(1, 5)],
                inputs=[input_var],
                create_graph=True,
            )
        # works with use_reentrant=False, and grads are the same
        out = model(input_var)
        grads_no_checkpoint = torch.autograd.grad(
            outputs=[out],
            grad_outputs=[torch.ones(1, 5)],
            inputs=[input_var],
            create_graph=True,
        )
        out_checkpoint = checkpoint_sequential(
            modules, chunks, input_var, use_reentrant=False
        )
        # check outputs are the same
        self.assertEqual(out_checkpoint, out)
        grads_checkpoint = torch.autograd.grad(
            outputs=[out_checkpoint],
            grad_outputs=[torch.ones(1, 5)],
            inputs=[input_var],
            create_graph=True,
        )
        self.assertEqual(grads_no_checkpoint, grads_checkpoint)

    def test_checkpoint(self):
        for use_reentrant in [True, False]:
            with self.subTest(use_reentrant=use_reentrant):
                model = nn.Sequential(
                    nn.Linear(100, 50),
                    nn.ReLU(),
                    nn.Linear(50, 20),
                    nn.ReLU(),
                    nn.Linear(20, 5),
                    nn.ReLU(),
                )

                # Compare the uncheckpointed model with its checkpointed counterparts.
                # In addition to running checkpoint_sequential on the nn.Sequential
                # instance, we also run it on the list of child modules within
                # the module.
                self._check_checkpoint_sequential(
                    model,
                    [list(model.children()), model],
                    2,
                    torch.randn(1, 100, requires_grad=True),
                    use_reentrant=use_reentrant,
                )

    def test_checkpoint_module_list(self):
        class ModuleListNet(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                module_list = [
                    nn.Linear(100, 50),
                    nn.ReLU(),
                    nn.Linear(50, 20),
                    nn.ReLU(),
                    nn.Linear(20, 5),
                    nn.ReLU(),
                ]
                self.module_list = nn.ModuleList(module_list)

            def forward(self, input):
                for layer in self.module_list:
                    input = layer(input)
                return input

        for use_reentrant in [True, False]:
            with self.subTest(use_reentrant=use_reentrant):
                model = ModuleListNet()

                # Compare uncheckpointed model with its checkpointed counterparts.
                self._check_checkpoint_sequential(
                    model,
                    [list(model.module_list.children()), model.module_list],
                    2,
                    torch.randn(1, 100, requires_grad=True),
                    use_reentrant=use_reentrant,
                )

    def test_checkpoint_sequential_deprecated_multiple_args(self):
        class Two(nn.Module):
            def forward(self, a, b):
                return a, b

        model = nn.Sequential(Two())
        a = torch.randn(1, 100, requires_grad=True)
        b = torch.randn(1, 100, requires_grad=True)

        for use_reentrant in [True, False]:
            with self.subTest(use_reentrant=use_reentrant):
                with self.assertRaises(TypeError):
                    checkpoint_sequential(model, 1, a, b)  # type: ignore[call-arg]

    def test_checkpoint_sequential_deprecated_no_args(self):
        class Noop(nn.Module):
            def forward(self):
                pass

        model = nn.Sequential(Noop())
        for use_reentrant in [True, False]:
            with self.subTest(use_reentrant=use_reentrant):
                with self.assertRaises(TypeError):
                    checkpoint_sequential(model, 1)  # type: ignore[call-arg]

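    # Dropout makes run_fn RNG-dependent; preserve_rng_state defaults to True,
    # so the recomputation during backward should replay the same dropout mask
    # and give the same gradients as running without checkpointing.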
    def test_checkpoint_rng_cpu(self):
        for _ in range(5):
            inp = torch.randn(20000, device="cpu").requires_grad_()
            phase1 = torch.nn.Dropout()
            phase2 = torch.nn.Dropout()

            def run_fn(input):
                return phase2(input)

            state = torch.get_rng_state()

            out = phase1(inp)
            out = checkpoint(run_fn, out, use_reentrant=True)
            out.sum().backward()
            grad_with_checkpointing = inp.grad

            torch.set_rng_state(state)

            inp.grad = None

            out = phase1(inp)
            out = run_fn(out)
            out.sum().backward()
            grad_no_checkpointing = inp.grad

            self.assertEqual(grad_with_checkpointing, grad_no_checkpointing)

    @unittest.skipIf(not HAS_CUDA, "No CUDA")
    def test_checkpoint_rng_cuda(self):
        for _ in range(5):
            inp = torch.randn(20000, device="cuda").requires_grad_()
            phase1 = torch.nn.Dropout()
            phase2 = torch.nn.Dropout()

            def run_fn(input):
                return phase2(input)

            state = torch.cuda.get_rng_state()

            out = phase1(inp)
            out = checkpoint(run_fn, out, use_reentrant=True)
            out.sum().backward()
            grad_with_checkpointing = inp.grad

            torch.cuda.set_rng_state(state)

            inp.grad = None

            out = phase1(inp)
            out = run_fn(out)
            out.sum().backward()
            grad_no_checkpointing = inp.grad

            self.assertEqual(grad_with_checkpointing, grad_no_checkpointing)

    @unittest.skipIf(not HAS_CUDA, "No CUDA")
    def test_checkpoint_not_preserve_rng_state_and_without_reentrant(self):
        inp = torch.randn(2, device="cuda").requires_grad_()
        layer = torch.nn.Dropout()

        def run_fn(input):
            return layer(input)

        out = checkpoint(run_fn, inp, use_reentrant=False, preserve_rng_state=False)
        out.sum().backward()
        # This should run without error

    def test_checkpoint_non_tensor(self):
        def run_fn(tensor1, tensor2):
            if tensor2 is None:
                return tensor1
            return tensor1 + tensor2

        input_var = torch.randn(1, 100, requires_grad=True)
        out = checkpoint(run_fn, input_var, None, use_reentrant=True)
        out.sum().backward()

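    # Non-Tensor inputs and outputs (int, None, bool, str) should pass through
    # checkpoint unchanged, while the Tensor outputs still receive gradients.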
    def test_checkpoint_non_tensor_inputs_outputs(self):
        def foo(t1, t2, scale, t3):
            t4 = t1 + t2 * t3
            t5 = t1 * t2 + t3
            t4 *= scale
            t5 *= scale
            return scale, t4, None, True, t5, "bar", t1

        t1 = torch.rand(10, requires_grad=True)
        t2 = torch.rand(10, requires_grad=True)
        t3 = torch.rand(10)
        scale = random.randint(0, 10)
        res = checkpoint(foo, t1, t2, scale, t3, use_reentrant=True)
        self.assertEqual(scale, res[0])
        self.assertEqual((t1 + t2 * t3) * scale, res[1])
        self.assertEqual(None, res[2])
        self.assertEqual(True, res[3])
        self.assertEqual((t1 * t2 + t3) * scale, res[4])
        self.assertEqual("bar", res[5])
        self.assertEqual(t1, res[6])

        # Validate running backward.
        res[1].sum().backward(retain_graph=True)
        res[4].sum().backward(retain_graph=True)
        res[6].sum().backward()
        with self.assertRaisesRegex(
            RuntimeError, "Trying to backward through the graph a second time"
        ):
            res[6].sum().backward()
        t1_grad = t1.grad
        t2_grad = t2.grad

        # Reset grads, run without checkpoint and validate we receive same grads.
        t1.grad = None
        t2.grad = None
        res = foo(t1, t2, scale, t3)
        torch.autograd.backward([res[1].sum(), res[4].sum(), res[6].sum()])
        self.assertEqual(t1.grad, t1_grad)
        self.assertEqual(t2.grad, t2_grad)

    def test_checkpoint_no_tensors(self):
        def foo(t1, t2, scale, t3):
            t4 = t1 + t2 * t3
            t5 = t1 * t2 + t3
            t4 *= scale
            t5 *= scale
            return scale, t4, None, True, t5, "bar", t1

        t1 = random.random()
        t2 = random.random()
        t3 = random.random()
        scale = random.randint(0, 10)
        res = checkpoint(foo, t1, t2, scale, t3, use_reentrant=True)
        self.assertEqual(scale, res[0])
        self.assertEqual((t1 + t2 * t3) * scale, res[1])
        self.assertEqual(None, res[2])
        self.assertEqual(True, res[3])
        self.assertEqual((t1 * t2 + t3) * scale, res[4])
        self.assertEqual("bar", res[5])
        self.assertEqual(t1, res[6])

    def test_checkpoint_partial_grad(self):
        def run_fn(tensor1, tensor2):
            # tensor 2 is used for other application logic
            return tensor1, tensor2

        input_var = torch.randn(1, 4, requires_grad=True)
        input_var2 = torch.randn(1, 4, requires_grad=False)
        out = checkpoint(run_fn, input_var, input_var2, use_reentrant=True)
        out[0].sum().backward()

        def run_fn2(tensor1, tensor2):
            return tensor1

        input_var = torch.randn(1, 4, requires_grad=False)
        input_var2 = torch.randn(1, 4, requires_grad=True)
        with self.assertRaisesRegex(
            RuntimeError,
            r"none of output has requires_grad=True, this checkpoint\(\) is not necessary",
        ):
            out = checkpoint(run_fn2, input_var, input_var2, use_reentrant=True)
            out.sum().backward()

    @unittest.skipIf(not torch.cuda.is_available(), "Test requires CUDA")
    def test_checkpointing_without_reentrant_early_free(self):
        # I don't know how to check whether the temporary saved variable buffers
        # get deallocated directly, so CUDA memory usage is used as a proxy

        def _do_test(fn, should_free):
            stats: List[int] = []

            def track(x, idx):
                # Track that at each step of the backward, some Tensors were
                # deallocated (which corresponds to the checkpoint storage being
                # emptied at each step)
                def hook(_unused):
                    self.assertEqual(len(stats), idx)
                    torch.cuda.synchronize()
                    stats.append(torch.cuda.memory_allocated())
                    if idx > 0:
                        if should_free:
                            self.assertLess(stats[idx], stats[idx - 1])
                        else:
                            self.assertEqual(stats[idx], stats[idx - 1])

                x.register_hook(hook)

            def test_fn(x):
                # The main property of this function is that it contains multiple
                # operations that save tensors for backward in a chain.
                x = x**2
                track(x, 2)
                x = x**2
                track(x, 1)
                x = x**2
                track(x, 0)
                x = x**2
                return x.sum()

            fn(test_fn)

            return stats

        x = torch.zeros(10, device="cuda", requires_grad=True)
        x.grad = torch.zeros_like(x)

        # In a regular backward, buffers get eagerly freed
        non_retain_stats = _do_test(lambda fn: fn(x).backward(), True)

        # In a retain_graph backward, buffers get preserved
        _unused_retain_stats = _do_test(
            lambda fn: fn(x).backward(retain_graph=True), False
        )

        # In a regular backward with checkpoint, buffers get eagerly freed
        checkpoint_non_retain_stats = _do_test(
            lambda fn: checkpoint(fn, x, use_reentrant=False).backward(), True
        )

        # In a retain_graph backward with checkpoint, buffers get eagerly freed
        checkpoint_retain_stats = _do_test(
            lambda fn: checkpoint(fn, x, use_reentrant=False).backward(
                retain_graph=True
            ),
            True,
        )

        self.assertEqual(non_retain_stats, checkpoint_non_retain_stats)
        self.assertEqual(non_retain_stats, checkpoint_retain_stats)

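    # get_device_states should walk nested containers (a dict holding a list
    # here) and return the RNG state for every CUDA device it finds.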
    @unittest.skipIf(not TEST_MULTIGPU, "multi-GPU not supported")
    def test_get_device_states_recursive(self):
        inp = {
            "foo": torch.rand(10, device="cuda:0"),
            "bar": [torch.rand(10, device="cuda:1")],
        }
        device_ids, device_states = get_device_states(inp)
        self.assertEqual(2, len(device_ids))
        self.assertEqual(2, len(device_states))
        self.assertEqual(0, device_ids[0])
        self.assertEqual(1, device_ids[1])
        self.assertTrue(isinstance(device_states[0], torch.Tensor))
        self.assertTrue(isinstance(device_states[1], torch.Tensor))

    def test_infer_device_state_recursive_meta(self):
        inp = {"foo": torch.rand(10, device="meta")}
        device_type = _infer_device_type(inp)
        self.assertEqual("meta", device_type)

    @unittest.skipIf(not TEST_MULTIGPU, "multi-GPU not supported")
    def test_infer_device_state_recursive_multi_cuda(self):
        # Check that no warning is issued for the cuda:0, cuda:1 case or the
        # cuda:0, cuda:0 case, since all tensors share the same device type in both
        inp = {
            "foo": torch.rand(10, device="cuda:0"),
            "bar": [torch.rand(10, device="cuda:1")],
        }
        with warnings.catch_warnings():
            warnings.simplefilter("error")
            device_type = _infer_device_type(inp)
            self.assertEqual("cuda", device_type)
        inp = {
            "foo": torch.rand(10, device="cuda:0"),
            "bar": [torch.rand(10, device="cuda:0")],
        }
        with warnings.catch_warnings():
            warnings.simplefilter("error")
            device_type = _infer_device_type(inp)
            self.assertEqual("cuda", device_type)
        # Check that a warning is issued for cuda:0, meta and that it includes
        # device type information
        inp = {
            "foo": torch.rand(10, device="cuda:0"),
            "bar": [torch.rand(10, device="meta")],
        }
        with warnings.catch_warnings(record=True) as w:
            device_type = _infer_device_type(inp)
            self.assertEqual("cuda", device_type)
        self.assertEqual(len(w), 1)
        warning_msg = str(w[-1].message)
        self.assertTrue(
            "Tensor arguments, excluding CPU tensors, are detected on at least two types of devices"
            in warning_msg
        )
        self.assertTrue("Device types: ['cuda', 'meta']" in warning_msg)
        self.assertTrue("first device type: cuda" in warning_msg)


class TestDataLoaderUtils(TestCase):
    MAX_TIMEOUT_IN_SECOND = 300

    def setUp(self):
        super().setUp()
        self.dataset = torch.randn(5, 3, 3, 2)
        self.batch_size = 3

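    # With the same torch.manual_seed, a shuffled multi-worker DataLoader
    # should yield the same first batch on repeated runs.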
    def test_random_seed(self):
        def run():
            dataloader = torch.utils.data.DataLoader(
                RandomDatasetMock(),
                batch_size=2,
                num_workers=4,
                shuffle=True,
                timeout=self.MAX_TIMEOUT_IN_SECOND,
            )
            return next(iter(dataloader))

        torch.manual_seed(2018)
        x1 = run()
        torch.manual_seed(2018)
        x2 = run()
        self.assertEqual(x1, x2)

    def test_single_keep(self):
        # self.dataset is a Tensor here; technically not a valid input because it is
        # not a Dataset subclass, but it needs to keep working, so add ignores
        # for type checking with mypy
        dataloader: DataLoader = DataLoader(
            self.dataset,  # type: ignore[arg-type]
            batch_size=self.batch_size,
            num_workers=0,
            drop_last=False,
        )
        dataiter = iter(dataloader)
        self.assertEqual(len(list(dataiter)), 2)

    def test_single_drop(self):
        dataloader: DataLoader = DataLoader(
            self.dataset,  # type: ignore[arg-type]
            batch_size=self.batch_size,
            num_workers=0,
            drop_last=True,
        )
        dataiter = iter(dataloader)
        self.assertEqual(len(list(dataiter)), 1)

    @unittest.skip(
        "FIXME: Intermittent CUDA out-of-memory error on Windows and time-out under ASAN"
    )
    def test_multi_keep(self):
        dataloader: DataLoader = DataLoader(
            self.dataset,  # type: ignore[arg-type]
            batch_size=self.batch_size,
            num_workers=2,
            drop_last=False,
            timeout=self.MAX_TIMEOUT_IN_SECOND,
        )
        dataiter = iter(dataloader)
        self.assertEqual(len(list(dataiter)), 2)

    def test_multi_drop(self):
        dataloader: DataLoader = DataLoader(
            self.dataset,  # type: ignore[arg-type]
            batch_size=self.batch_size,
            num_workers=2,
            drop_last=True,
            timeout=self.MAX_TIMEOUT_IN_SECOND,
        )
        dataiter = iter(dataloader)
        self.assertEqual(len(list(dataiter)), 1)


test_dir = os.path.abspath(os.path.dirname(str(__file__)))


@unittest.skipIf(
    "SKIP_TEST_BOTTLENECK" in os.environ.keys(), "SKIP_TEST_BOTTLENECK is set"
)
class TestBottleneck(TestCase):
    def _run(self, command, timeout=30):
        """Returns (return-code, stdout, stderr)"""
        import subprocess

        p = subprocess.Popen(
            command,
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
            shell=True,
        )
        try:
            output, err = p.communicate(timeout=timeout)
        except subprocess.TimeoutExpired:
            p.kill()
            output, err = p.communicate()
        rc = p.returncode
        output_str = output.decode("ascii")
        err_str = err.decode("ascii")
        return (rc, output_str, err_str)

    def _run_bottleneck(self, test_file, scriptargs=""):
        curdir = os.path.dirname(os.path.abspath(__file__))
        filepath = f"{curdir}/{test_file}"
        if scriptargs != "":
            scriptargs = f" {scriptargs}"
        rc, out, err = self._run(
            f"{sys.executable} -m torch.utils.bottleneck {filepath}{scriptargs}"
        )
        return rc, out, err
671*da0073e9SAndroid Build Coastguard Worker
672*da0073e9SAndroid Build Coastguard Worker    def _check_run_args(self):
673*da0073e9SAndroid Build Coastguard Worker        # Check that this fails due to missing args
674*da0073e9SAndroid Build Coastguard Worker        rc, out, err = self._run_bottleneck("bottleneck_test/test_args.py")
675*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
676*da0073e9SAndroid Build Coastguard Worker            rc,
677*da0073e9SAndroid Build Coastguard Worker            2,
678*da0073e9SAndroid Build Coastguard Worker            atol=0,
679*da0073e9SAndroid Build Coastguard Worker            rtol=0,
680*da0073e9SAndroid Build Coastguard Worker            msg=self._fail_msg("Missing args should error", out + err),
681*da0073e9SAndroid Build Coastguard Worker        )
682*da0073e9SAndroid Build Coastguard Worker
683*da0073e9SAndroid Build Coastguard Worker        # This should succeed
684*da0073e9SAndroid Build Coastguard Worker        rc, out, err = self._run_bottleneck(
685*da0073e9SAndroid Build Coastguard Worker            "bottleneck_test/test_args.py", "--foo foo --bar bar"
686*da0073e9SAndroid Build Coastguard Worker        )
687*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
688*da0073e9SAndroid Build Coastguard Worker            rc,
689*da0073e9SAndroid Build Coastguard Worker            0,
690*da0073e9SAndroid Build Coastguard Worker            atol=0,
691*da0073e9SAndroid Build Coastguard Worker            rtol=0,
692*da0073e9SAndroid Build Coastguard Worker            msg=self._fail_msg("Should pass args to script", out + err),
693*da0073e9SAndroid Build Coastguard Worker        )
694*da0073e9SAndroid Build Coastguard Worker
695*da0073e9SAndroid Build Coastguard Worker    def _fail_msg(self, msg, output):
696*da0073e9SAndroid Build Coastguard Worker        return f"{msg}, output was:\n{output}"
697*da0073e9SAndroid Build Coastguard Worker
698*da0073e9SAndroid Build Coastguard Worker    def _check_environment_summary(self, output):
699*da0073e9SAndroid Build Coastguard Worker        results = re.search("Environment Summary", output)
700*da0073e9SAndroid Build Coastguard Worker        self.assertIsNotNone(
701*da0073e9SAndroid Build Coastguard Worker            results, self._fail_msg("Should have Environment Summary", output)
702*da0073e9SAndroid Build Coastguard Worker        )
703*da0073e9SAndroid Build Coastguard Worker
704*da0073e9SAndroid Build Coastguard Worker        # Up to five lines away from the heading, there should be the version number
705*da0073e9SAndroid Build Coastguard Worker        results = re.search(
706*da0073e9SAndroid Build Coastguard Worker            r"Environment Summary.*(\n.*){,5}\nPyTorch \d+\.\d+", output
707*da0073e9SAndroid Build Coastguard Worker        )
708*da0073e9SAndroid Build Coastguard Worker        self.assertIsNotNone(
709*da0073e9SAndroid Build Coastguard Worker            results, self._fail_msg("Should have PyTorch version", output)
710*da0073e9SAndroid Build Coastguard Worker        )
711*da0073e9SAndroid Build Coastguard Worker
712*da0073e9SAndroid Build Coastguard Worker    def _check_cprof_summary(self, output):
713*da0073e9SAndroid Build Coastguard Worker        results = re.search("cProfile output", output)
714*da0073e9SAndroid Build Coastguard Worker        self.assertIsNotNone(
715*da0073e9SAndroid Build Coastguard Worker            results, self._fail_msg("Should have cProfile output", output)
716*da0073e9SAndroid Build Coastguard Worker        )
717*da0073e9SAndroid Build Coastguard Worker
718*da0073e9SAndroid Build Coastguard Worker        # This assumes that after the cProfile output section we have
719*da0073e9SAndroid Build Coastguard Worker        # the autograd profiler output
720*da0073e9SAndroid Build Coastguard Worker        results = re.search(
721*da0073e9SAndroid Build Coastguard Worker            r"cProfile output.*(\n.*){6,50}\n.*autograd profiler output", output
722*da0073e9SAndroid Build Coastguard Worker        )
723*da0073e9SAndroid Build Coastguard Worker        self.assertIsNotNone(
724*da0073e9SAndroid Build Coastguard Worker            results,
725*da0073e9SAndroid Build Coastguard Worker            self._fail_msg(
726*da0073e9SAndroid Build Coastguard Worker                "Distance between cProfile and autograd prof out not in [6, 50] lines",
727*da0073e9SAndroid Build Coastguard Worker                output,
728*da0073e9SAndroid Build Coastguard Worker            ),
729*da0073e9SAndroid Build Coastguard Worker        )
730*da0073e9SAndroid Build Coastguard Worker
731*da0073e9SAndroid Build Coastguard Worker    def _check_autograd_summary(self, output):
732*da0073e9SAndroid Build Coastguard Worker        results = re.search("autograd profiler output", output)
733*da0073e9SAndroid Build Coastguard Worker        self.assertIsNotNone(
734*da0073e9SAndroid Build Coastguard Worker            results, self._fail_msg("Should have autograd profiler output", output)
735*da0073e9SAndroid Build Coastguard Worker        )
736*da0073e9SAndroid Build Coastguard Worker
737*da0073e9SAndroid Build Coastguard Worker        # This assumes that the autograd profiler output is the last section of
738*da0073e9SAndroid Build Coastguard Worker        # the output.
739*da0073e9SAndroid Build Coastguard Worker        results = re.search(r"autograd profiler output.*(\n.*){6,100}", output)
740*da0073e9SAndroid Build Coastguard Worker        self.assertIsNotNone(
741*da0073e9SAndroid Build Coastguard Worker            results,
742*da0073e9SAndroid Build Coastguard Worker            self._fail_msg(
743*da0073e9SAndroid Build Coastguard Worker                "Distance between autograd prof output and end of output not in [6, 100] lines",
744*da0073e9SAndroid Build Coastguard Worker                output,
745*da0073e9SAndroid Build Coastguard Worker            ),
746*da0073e9SAndroid Build Coastguard Worker        )
747*da0073e9SAndroid Build Coastguard Worker
748*da0073e9SAndroid Build Coastguard Worker    def _check_cuda(self, output):
749*da0073e9SAndroid Build Coastguard Worker        if HAS_CUDA:
750*da0073e9SAndroid Build Coastguard Worker            results = re.search("CUDA mode", output)
751*da0073e9SAndroid Build Coastguard Worker            self.assertIsNotNone(
752*da0073e9SAndroid Build Coastguard Worker                results, self._fail_msg("Should tell users about CUDA", output)
753*da0073e9SAndroid Build Coastguard Worker            )
754*da0073e9SAndroid Build Coastguard Worker        else:
755*da0073e9SAndroid Build Coastguard Worker            results = re.search("CUDA mode", output)
756*da0073e9SAndroid Build Coastguard Worker            self.assertIsNone(
757*da0073e9SAndroid Build Coastguard Worker                results, self._fail_msg("Should not tell users about CUDA", output)
758*da0073e9SAndroid Build Coastguard Worker            )
759*da0073e9SAndroid Build Coastguard Worker
760*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(HAS_CUDA, "CPU-only test")
761*da0073e9SAndroid Build Coastguard Worker    def test_bottleneck_cpu_only(self):
762*da0073e9SAndroid Build Coastguard Worker        rc, out, err = self._run_bottleneck("bottleneck_test/test.py")
763*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(rc, 0, msg=f"Run failed with\n{err}")
764*da0073e9SAndroid Build Coastguard Worker
765*da0073e9SAndroid Build Coastguard Worker        self._check_run_args()
766*da0073e9SAndroid Build Coastguard Worker        self._check_environment_summary(out)
767*da0073e9SAndroid Build Coastguard Worker        self._check_autograd_summary(out)
768*da0073e9SAndroid Build Coastguard Worker        self._check_cprof_summary(out)
769*da0073e9SAndroid Build Coastguard Worker        self._check_cuda(out)
770*da0073e9SAndroid Build Coastguard Worker
771*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(not HAS_CUDA, "No CUDA")
772*da0073e9SAndroid Build Coastguard Worker    def test_bottleneck_cuda(self):
773*da0073e9SAndroid Build Coastguard Worker        rc, out, err = self._run_bottleneck("bottleneck_test/test_cuda.py")
774*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(rc, 0, msg=f"Run failed with\n{err}")
775*da0073e9SAndroid Build Coastguard Worker
776*da0073e9SAndroid Build Coastguard Worker        self._check_run_args()
777*da0073e9SAndroid Build Coastguard Worker        self._check_environment_summary(out)
778*da0073e9SAndroid Build Coastguard Worker        self._check_autograd_summary(out)
779*da0073e9SAndroid Build Coastguard Worker        self._check_cprof_summary(out)
780*da0073e9SAndroid Build Coastguard Worker        self._check_cuda(out)
781*da0073e9SAndroid Build Coastguard Worker
782*da0073e9SAndroid Build Coastguard Worker
783*da0073e9SAndroid Build Coastguard Workerfrom torch.utils.collect_env import get_pretty_env_info
784*da0073e9SAndroid Build Coastguard Worker
785*da0073e9SAndroid Build Coastguard Worker
786*da0073e9SAndroid Build Coastguard Worker@unittest.skipIf(IS_FBCODE, "runs pip which is not available internally")
787*da0073e9SAndroid Build Coastguard Workerclass TestCollectEnv(TestCase):
788*da0073e9SAndroid Build Coastguard Worker    def test_smoke(self):
789*da0073e9SAndroid Build Coastguard Worker        info_output = get_pretty_env_info()
790*da0073e9SAndroid Build Coastguard Worker        self.assertGreaterEqual(info_output.count("\n"), 17)
791*da0073e9SAndroid Build Coastguard Worker
792*da0073e9SAndroid Build Coastguard Worker
793*da0073e9SAndroid Build Coastguard Workerclass TestONNXUtils(TestCase):
794*da0073e9SAndroid Build Coastguard Worker    def test_prepare_onnx_paddings(self):
795*da0073e9SAndroid Build Coastguard Worker        sizes = [2, 3, 4]
796*da0073e9SAndroid Build Coastguard Worker        pad = [1, 2, 3, 4]
797*da0073e9SAndroid Build Coastguard Worker        paddings = _prepare_onnx_paddings(len(sizes), pad)
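        # pad follows the PyTorch convention: pairs of (begin, end), last
        # dimension first, so [1, 2, 3, 4] pads dim 2 by (1, 2) and dim 1 by
        # (3, 4) while dim 0 is left unpadded.  ONNX instead expects all the
        # begin values (in dimension order) followed by all the end values,
        # which yields [0, 3, 1, 0, 4, 2] here.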
798*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(paddings, [0, 3, 1, 0, 4, 2])
799*da0073e9SAndroid Build Coastguard Worker
800*da0073e9SAndroid Build Coastguard Worker    def test_check_onnx_broadcast(self):
801*da0073e9SAndroid Build Coastguard Worker        def try_check_onnx_broadcast(dims1, dims2, expect_broadcast, expect_fail):
802*da0073e9SAndroid Build Coastguard Worker            broadcast = True
803*da0073e9SAndroid Build Coastguard Worker            fail = False
804*da0073e9SAndroid Build Coastguard Worker            try:
805*da0073e9SAndroid Build Coastguard Worker                broadcast = check_onnx_broadcast(dims1, dims2)
806*da0073e9SAndroid Build Coastguard Worker            except ValueError:
807*da0073e9SAndroid Build Coastguard Worker                fail = True
808*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(broadcast, expect_broadcast)
809*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(fail, expect_fail)
810*da0073e9SAndroid Build Coastguard Worker
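        # As exercised by the cases below, check_onnx_broadcast returns True
        # whenever dims1 and dims2 differ (a broadcast is needed) and raises
        # ValueError when that broadcast cannot be expressed in ONNX, i.e. when
        # dims2 is neither a suffix of dims1 nor all ones.
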
811*da0073e9SAndroid Build Coastguard Worker        # Case 1, check the case when len(dims1) < len(dims2) and numel(dims2) > 1
812*da0073e9SAndroid Build Coastguard Worker        dims1 = [3, 4]
813*da0073e9SAndroid Build Coastguard Worker        dims2 = [2, 3, 4]
814*da0073e9SAndroid Build Coastguard Worker        try_check_onnx_broadcast(dims1, dims2, True, True)
815*da0073e9SAndroid Build Coastguard Worker
816*da0073e9SAndroid Build Coastguard Worker        # Case 2, check the case when len(dims1) < len(dims2) and numel(dims2) == 1
817*da0073e9SAndroid Build Coastguard Worker        dims1 = [3, 4]
818*da0073e9SAndroid Build Coastguard Worker        dims2 = [1, 1, 1]
819*da0073e9SAndroid Build Coastguard Worker        try_check_onnx_broadcast(dims1, dims2, True, False)
820*da0073e9SAndroid Build Coastguard Worker
821*da0073e9SAndroid Build Coastguard Worker        # Case 3, check the case when len(dims1) > len(dims2) and numel(dims2) == 1
822*da0073e9SAndroid Build Coastguard Worker        dims1 = [1, 1]
823*da0073e9SAndroid Build Coastguard Worker        dims2 = [1]
824*da0073e9SAndroid Build Coastguard Worker        try_check_onnx_broadcast(dims1, dims2, True, False)
825*da0073e9SAndroid Build Coastguard Worker
826*da0073e9SAndroid Build Coastguard Worker        # Case 4, check the case when len(dims1) > len(dims2) and dims1[x:] == dims2
827*da0073e9SAndroid Build Coastguard Worker        dims1 = [2, 3, 4]
828*da0073e9SAndroid Build Coastguard Worker        dims2 = [3, 4]
829*da0073e9SAndroid Build Coastguard Worker        try_check_onnx_broadcast(dims1, dims2, True, False)
830*da0073e9SAndroid Build Coastguard Worker
831*da0073e9SAndroid Build Coastguard Worker        # Case 5, check the case when len(dims1) > len(dims2), but dims1[x:] != dims2
832*da0073e9SAndroid Build Coastguard Worker        dims1 = [2, 3, 4]
833*da0073e9SAndroid Build Coastguard Worker        dims2 = [1, 4]
834*da0073e9SAndroid Build Coastguard Worker        try_check_onnx_broadcast(dims1, dims2, True, True)
835*da0073e9SAndroid Build Coastguard Worker
836*da0073e9SAndroid Build Coastguard Worker        # Case 6, check the equal case, no broadcast
837*da0073e9SAndroid Build Coastguard Worker        dims1 = [3, 4]
838*da0073e9SAndroid Build Coastguard Worker        dims2 = [3, 4]
839*da0073e9SAndroid Build Coastguard Worker        try_check_onnx_broadcast(dims1, dims2, False, False)
840*da0073e9SAndroid Build Coastguard Worker
841*da0073e9SAndroid Build Coastguard Worker        # Case 7, check the case when len(dims1) == len(dims2), but dims1 != dims2
842*da0073e9SAndroid Build Coastguard Worker        dims1 = [3, 4]
843*da0073e9SAndroid Build Coastguard Worker        dims2 = [1, 4]
844*da0073e9SAndroid Build Coastguard Worker        try_check_onnx_broadcast(dims1, dims2, True, True)
845*da0073e9SAndroid Build Coastguard Worker
846*da0073e9SAndroid Build Coastguard Worker        # Case 8, check the case when len(dims1) == len(dims2) and numel(dims2) == 1
847*da0073e9SAndroid Build Coastguard Worker        dims1 = [3, 4]
848*da0073e9SAndroid Build Coastguard Worker        dims2 = [1, 1]
849*da0073e9SAndroid Build Coastguard Worker        try_check_onnx_broadcast(dims1, dims2, True, False)
850*da0073e9SAndroid Build Coastguard Worker
851*da0073e9SAndroid Build Coastguard Worker
852*da0073e9SAndroid Build Coastguard Workerclass TestHipify(TestCase):
853*da0073e9SAndroid Build Coastguard Worker    def test_import_hipify(self):
854*da0073e9SAndroid Build Coastguard Worker        from torch.utils.hipify import hipify_python  # noqa: F401
855*da0073e9SAndroid Build Coastguard Worker
856*da0073e9SAndroid Build Coastguard Worker
857*da0073e9SAndroid Build Coastguard Workerclass TestHipifyTrie(TestCase):
858*da0073e9SAndroid Build Coastguard Worker    def setUp(self):
859*da0073e9SAndroid Build Coastguard Worker        self.trie = torch.utils.hipify.hipify_python.Trie()
860*da0073e9SAndroid Build Coastguard Worker
861*da0073e9SAndroid Build Coastguard Worker    def test_add_and_search_trie(self):
862*da0073e9SAndroid Build Coastguard Worker        self.trie.add("banana")
863*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(self.trie.search("banana"))
864*da0073e9SAndroid Build Coastguard Worker        self.assertFalse(self.trie.search("ban"))
865*da0073e9SAndroid Build Coastguard Worker        self.assertFalse(self.trie.search("dog"))
866*da0073e9SAndroid Build Coastguard Worker
867*da0073e9SAndroid Build Coastguard Worker    def test_add_multiple_and_search_trie(self):
868*da0073e9SAndroid Build Coastguard Worker        words_to_add = ["banana", "apple", "orange"]
869*da0073e9SAndroid Build Coastguard Worker        for word in words_to_add:
870*da0073e9SAndroid Build Coastguard Worker            self.trie.add(word)
871*da0073e9SAndroid Build Coastguard Worker
872*da0073e9SAndroid Build Coastguard Worker        for word in words_to_add:
873*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(self.trie.search(word))
874*da0073e9SAndroid Build Coastguard Worker
875*da0073e9SAndroid Build Coastguard Worker        for word in ["ban", "dog", "okay", "app"]:
876*da0073e9SAndroid Build Coastguard Worker            self.assertFalse(self.trie.search(word))
877*da0073e9SAndroid Build Coastguard Worker
878*da0073e9SAndroid Build Coastguard Worker    def test_quote_escape(self):
879*da0073e9SAndroid Build Coastguard Worker        orig_chars = ["*", "[", ".", "+", "a", "z", "-"]
880*da0073e9SAndroid Build Coastguard Worker        quoted_strs = ["\\*", "\\[", "\\.", "\\+", "a", "z", "\\-"]
881*da0073e9SAndroid Build Coastguard Worker        for orig_char, quoted_str in zip(orig_chars, quoted_strs):
882*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(self.trie.quote(orig_char), quoted_str)
883*da0073e9SAndroid Build Coastguard Worker
884*da0073e9SAndroid Build Coastguard Worker    def test_export_trie_to_regex(self):
885*da0073e9SAndroid Build Coastguard Worker        words_to_add = [
886*da0073e9SAndroid Build Coastguard Worker            "__CUDACC__",
887*da0073e9SAndroid Build Coastguard Worker            "CUDA_ERROR_CONTEXT_ALREADY_CURRENT",
888*da0073e9SAndroid Build Coastguard Worker            "CUDA_ERROR_ARRAY_IS_MAPPED",
889*da0073e9SAndroid Build Coastguard Worker            "CUDA_ERROR_NOT_MAPPED",
890*da0073e9SAndroid Build Coastguard Worker            "CUDA_ERROR_INVALID_SOURCE",
891*da0073e9SAndroid Build Coastguard Worker        ]
892*da0073e9SAndroid Build Coastguard Worker        for word in words_to_add:
893*da0073e9SAndroid Build Coastguard Worker            self.trie.add(word)
894*da0073e9SAndroid Build Coastguard Worker        regex = self.trie.export_to_regex()
895*da0073e9SAndroid Build Coastguard Worker        expected_regex = r"(?:CUDA_ERROR_(?:ARRAY_IS_MAPPED|CONTEXT_ALREADY_CURRENT|INVALID_SOURCE|NOT_MAPPED)|__CUDACC__)"
896*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(regex, expected_regex)
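        # A minimal sketch of how such an exported regex is typically consumed
        # (illustrative only; `cuda_to_hip` and `source_text` are hypothetical):
        #   pattern = re.compile(regex)
        #   hipified = pattern.sub(lambda m: cuda_to_hip[m.group(0)], source_text)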
897*da0073e9SAndroid Build Coastguard Worker
898*da0073e9SAndroid Build Coastguard Worker    def test_prefix_words_export_trie_to_regex(self):
899*da0073e9SAndroid Build Coastguard Worker        # Test the case where some nodes both have children and terminate a complete word.
900*da0073e9SAndroid Build Coastguard Worker        words_to_add = ["apple", "app", "ban", "banana"]
901*da0073e9SAndroid Build Coastguard Worker        for word in words_to_add:
902*da0073e9SAndroid Build Coastguard Worker            self.trie.add(word)
903*da0073e9SAndroid Build Coastguard Worker        regex = self.trie.export_to_regex()
904*da0073e9SAndroid Build Coastguard Worker        expected_regex = r"(?:app(?:le)?|ban(?:ana)?)"
905*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(regex, expected_regex)
906*da0073e9SAndroid Build Coastguard Worker
907*da0073e9SAndroid Build Coastguard Worker    def test_single_export_trie_to_regex(self):
908*da0073e9SAndroid Build Coastguard Worker        words_to_add = ["cudaErrorInvalidMemcpyDirection"]
909*da0073e9SAndroid Build Coastguard Worker        for word in words_to_add:
910*da0073e9SAndroid Build Coastguard Worker            self.trie.add(word)
911*da0073e9SAndroid Build Coastguard Worker        regex = self.trie.export_to_regex()
912*da0073e9SAndroid Build Coastguard Worker        expected_regex = "cudaErrorInvalidMemcpyDirection"
913*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(regex, expected_regex)
914*da0073e9SAndroid Build Coastguard Worker
915*da0073e9SAndroid Build Coastguard Worker    def test_char_export_trie_to_regex(self):
916*da0073e9SAndroid Build Coastguard Worker        self.trie.add("a")
917*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(self.trie.export_to_regex(), "a")
918*da0073e9SAndroid Build Coastguard Worker        self.trie.add("b")
919*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(self.trie.export_to_regex(), "[ab]")
920*da0073e9SAndroid Build Coastguard Worker
921*da0073e9SAndroid Build Coastguard Worker    def test_special_char_export_trie_to_regex(self):
922*da0073e9SAndroid Build Coastguard Worker        self.trie.add(r"c*")
923*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(self.trie.export_to_regex(), r"c\*")
924*da0073e9SAndroid Build Coastguard Worker
925*da0073e9SAndroid Build Coastguard Worker
926*da0073e9SAndroid Build Coastguard Workerclass TestAssert(TestCase):
927*da0073e9SAndroid Build Coastguard Worker    def test_assert_true(self):
928*da0073e9SAndroid Build Coastguard Worker        # verify assertions work as expected
929*da0073e9SAndroid Build Coastguard Worker        # bool argument
930*da0073e9SAndroid Build Coastguard Worker        torch._assert(True, "foo")
931*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(AssertionError, "bar"):
932*da0073e9SAndroid Build Coastguard Worker            torch._assert(False, "bar")
933*da0073e9SAndroid Build Coastguard Worker        # tensor argument
934*da0073e9SAndroid Build Coastguard Worker        torch._assert(torch.tensor([True], dtype=torch.bool), "foo")
935*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(AssertionError, "bar"):
936*da0073e9SAndroid Build Coastguard Worker            torch._assert(torch.tensor([False], dtype=torch.bool), "bar")
937*da0073e9SAndroid Build Coastguard Worker
938*da0073e9SAndroid Build Coastguard Worker    def test_assert_scriptable(self):
939*da0073e9SAndroid Build Coastguard Worker        class M(torch.nn.Module):
940*da0073e9SAndroid Build Coastguard Worker            def forward(self, x):
941*da0073e9SAndroid Build Coastguard Worker                torch._assert(x.sum() > 0, "foo")
942*da0073e9SAndroid Build Coastguard Worker                return x
943*da0073e9SAndroid Build Coastguard Worker
944*da0073e9SAndroid Build Coastguard Worker        m = M()
945*da0073e9SAndroid Build Coastguard Worker        # scriptable
946*da0073e9SAndroid Build Coastguard Worker        ms = torch.jit.script(m)
947*da0073e9SAndroid Build Coastguard Worker        # data can be passed without errors
948*da0073e9SAndroid Build Coastguard Worker        x = torch.randn(4, 4).fill_(1.0)
949*da0073e9SAndroid Build Coastguard Worker        ms(x)
950*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(torch.jit.Error, "foo"):
951*da0073e9SAndroid Build Coastguard Worker            ms(torch.tensor([False], dtype=torch.bool))
952*da0073e9SAndroid Build Coastguard Worker
953*da0073e9SAndroid Build Coastguard Worker
954*da0073e9SAndroid Build Coastguard Worker@unittest.skipIf(IS_SANDCASTLE, "cpp_extension is OSS only")
955*da0073e9SAndroid Build Coastguard Workerclass TestStandaloneCPPJIT(TestCase):
956*da0073e9SAndroid Build Coastguard Worker    def test_load_standalone(self):
957*da0073e9SAndroid Build Coastguard Worker        build_dir = tempfile.mkdtemp()
958*da0073e9SAndroid Build Coastguard Worker        try:
959*da0073e9SAndroid Build Coastguard Worker            src_path = os.path.join(build_dir, "main.cpp")
960*da0073e9SAndroid Build Coastguard Worker            src = textwrap.dedent(
961*da0073e9SAndroid Build Coastguard Worker                """\
962*da0073e9SAndroid Build Coastguard Worker                #include <iostream>
963*da0073e9SAndroid Build Coastguard Worker                #include <torch/torch.h>
964*da0073e9SAndroid Build Coastguard Worker                int main() {
965*da0073e9SAndroid Build Coastguard Worker                    auto x = torch::eye(3);
966*da0073e9SAndroid Build Coastguard Worker                    std::cout << x << std::endl;
967*da0073e9SAndroid Build Coastguard Worker                }
968*da0073e9SAndroid Build Coastguard Worker            """
969*da0073e9SAndroid Build Coastguard Worker            )
970*da0073e9SAndroid Build Coastguard Worker            with open(src_path, "w") as f:
971*da0073e9SAndroid Build Coastguard Worker                f.write(src)
972*da0073e9SAndroid Build Coastguard Worker
973*da0073e9SAndroid Build Coastguard Worker            exec_path = torch.utils.cpp_extension.load(
974*da0073e9SAndroid Build Coastguard Worker                "standalone_load_test",
975*da0073e9SAndroid Build Coastguard Worker                src_path,
976*da0073e9SAndroid Build Coastguard Worker                build_directory=build_dir,
977*da0073e9SAndroid Build Coastguard Worker                is_python_module=False,
978*da0073e9SAndroid Build Coastguard Worker                is_standalone=True,
979*da0073e9SAndroid Build Coastguard Worker            )
980*da0073e9SAndroid Build Coastguard Worker
981*da0073e9SAndroid Build Coastguard Worker            ext = ".exe" if IS_WINDOWS else ""
982*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(
983*da0073e9SAndroid Build Coastguard Worker                exec_path, os.path.join(build_dir, f"standalone_load_test{ext}")
984*da0073e9SAndroid Build Coastguard Worker            )
985*da0073e9SAndroid Build Coastguard Worker
986*da0073e9SAndroid Build Coastguard Worker            for shell in [True, False]:
987*da0073e9SAndroid Build Coastguard Worker                r = subprocess.run(
988*da0073e9SAndroid Build Coastguard Worker                    [exec_path],
989*da0073e9SAndroid Build Coastguard Worker                    shell=shell,
990*da0073e9SAndroid Build Coastguard Worker                    stdout=subprocess.PIPE,
991*da0073e9SAndroid Build Coastguard Worker                )
992*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(r.returncode, 0)
993*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(
994*da0073e9SAndroid Build Coastguard Worker                    # Windows prints "\r\n" for newlines.
995*da0073e9SAndroid Build Coastguard Worker                    textwrap.dedent(r.stdout.decode("utf-8")).replace("\r\n", "\n"),
996*da0073e9SAndroid Build Coastguard Worker                    textwrap.dedent(
997*da0073e9SAndroid Build Coastguard Worker                        """\
998*da0073e9SAndroid Build Coastguard Worker                     1  0  0
999*da0073e9SAndroid Build Coastguard Worker                     0  1  0
1000*da0073e9SAndroid Build Coastguard Worker                     0  0  1
1001*da0073e9SAndroid Build Coastguard Worker                    [ CPUFloatType{3,3} ]
1002*da0073e9SAndroid Build Coastguard Worker                    """
1003*da0073e9SAndroid Build Coastguard Worker                    ),
1004*da0073e9SAndroid Build Coastguard Worker                )
1005*da0073e9SAndroid Build Coastguard Worker
1006*da0073e9SAndroid Build Coastguard Worker        finally:
1007*da0073e9SAndroid Build Coastguard Worker            shutil.rmtree(build_dir)
1008*da0073e9SAndroid Build Coastguard Worker
1009*da0073e9SAndroid Build Coastguard Worker
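# Minimal stand-in for a PrivateUse1 backend module: it implements only the
# hooks that torch._register_device_module and torch.autocast query in the
# tests below (availability plus the autocast/AMP dtype plumbing).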
1010*da0073e9SAndroid Build Coastguard Workerclass DummyPrivateUse1Module:
1011*da0073e9SAndroid Build Coastguard Worker    @staticmethod
1012*da0073e9SAndroid Build Coastguard Worker    def is_available():
1013*da0073e9SAndroid Build Coastguard Worker        return True
1014*da0073e9SAndroid Build Coastguard Worker
1015*da0073e9SAndroid Build Coastguard Worker    @staticmethod
1016*da0073e9SAndroid Build Coastguard Worker    def is_autocast_enabled():
1017*da0073e9SAndroid Build Coastguard Worker        return True
1018*da0073e9SAndroid Build Coastguard Worker
1019*da0073e9SAndroid Build Coastguard Worker    @staticmethod
1020*da0073e9SAndroid Build Coastguard Worker    def get_autocast_dtype():
1021*da0073e9SAndroid Build Coastguard Worker        return torch.float16
1022*da0073e9SAndroid Build Coastguard Worker
1023*da0073e9SAndroid Build Coastguard Worker    @staticmethod
1024*da0073e9SAndroid Build Coastguard Worker    def set_autocast_enabled(enable):
1025*da0073e9SAndroid Build Coastguard Worker        pass
1026*da0073e9SAndroid Build Coastguard Worker
1027*da0073e9SAndroid Build Coastguard Worker    @staticmethod
1028*da0073e9SAndroid Build Coastguard Worker    def set_autocast_dtype(dtype):
1029*da0073e9SAndroid Build Coastguard Worker        pass
1030*da0073e9SAndroid Build Coastguard Worker
1031*da0073e9SAndroid Build Coastguard Worker    @staticmethod
1032*da0073e9SAndroid Build Coastguard Worker    def get_amp_supported_dtype():
1033*da0073e9SAndroid Build Coastguard Worker        return [torch.float16]
1034*da0073e9SAndroid Build Coastguard Worker
1035*da0073e9SAndroid Build Coastguard Worker
1036*da0073e9SAndroid Build Coastguard Workerclass TestExtensionUtils(TestCase):
1037*da0073e9SAndroid Build Coastguard Worker    def tearDown(self):
1038*da0073e9SAndroid Build Coastguard Worker        # Clean up
1039*da0073e9SAndroid Build Coastguard Worker        backend_name = torch._C._get_privateuse1_backend_name()
1040*da0073e9SAndroid Build Coastguard Worker        if hasattr(torch, backend_name):
1041*da0073e9SAndroid Build Coastguard Worker            delattr(torch, backend_name)
1042*da0073e9SAndroid Build Coastguard Worker        if f"torch.{backend_name}" in sys.modules:
1043*da0073e9SAndroid Build Coastguard Worker            del sys.modules[f"torch.{backend_name}"]
1044*da0073e9SAndroid Build Coastguard Worker
1045*da0073e9SAndroid Build Coastguard Worker    def test_external_module_register(self):
1046*da0073e9SAndroid Build Coastguard Worker        # Built-in module
1047*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, "The runtime module of"):
1048*da0073e9SAndroid Build Coastguard Worker            torch._register_device_module("cuda", torch.cuda)
1049*da0073e9SAndroid Build Coastguard Worker
1050*da0073e9SAndroid Build Coastguard Worker        # Wrong device type
1051*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, "Expected one of cpu"):
1052*da0073e9SAndroid Build Coastguard Worker            torch._register_device_module("dummmy", DummyPrivateUse1Module)
1053*da0073e9SAndroid Build Coastguard Worker
1054*da0073e9SAndroid Build Coastguard Worker        with self.assertRaises(AttributeError):
1055*da0073e9SAndroid Build Coastguard Worker            torch.privateuseone.is_available()  # type: ignore[attr-defined]
1056*da0073e9SAndroid Build Coastguard Worker
1057*da0073e9SAndroid Build Coastguard Worker        torch._register_device_module("privateuseone", DummyPrivateUse1Module)
1058*da0073e9SAndroid Build Coastguard Worker
1059*da0073e9SAndroid Build Coastguard Worker        torch.privateuseone.is_available()  # type: ignore[attr-defined]
1060*da0073e9SAndroid Build Coastguard Worker
1061*da0073e9SAndroid Build Coastguard Worker        # Overriding an already-registered module is not supported
1062*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, "The runtime module of"):
1063*da0073e9SAndroid Build Coastguard Worker            torch._register_device_module("privateuseone", DummyPrivateUse1Module)
1064*da0073e9SAndroid Build Coastguard Worker
1065*da0073e9SAndroid Build Coastguard Worker    def test_external_module_register_with_renamed_backend(self):
1066*da0073e9SAndroid Build Coastguard Worker        torch.utils.rename_privateuse1_backend("foo")
1067*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, "has already been set"):
1068*da0073e9SAndroid Build Coastguard Worker            torch.utils.rename_privateuse1_backend("dummmy")
1069*da0073e9SAndroid Build Coastguard Worker
1070*da0073e9SAndroid Build Coastguard Worker        custom_backend_name = torch._C._get_privateuse1_backend_name()
1071*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(custom_backend_name, "foo")
1072*da0073e9SAndroid Build Coastguard Worker
1073*da0073e9SAndroid Build Coastguard Worker        with self.assertRaises(AttributeError):
1074*da0073e9SAndroid Build Coastguard Worker            torch.foo.is_available()  # type: ignore[attr-defined]
1075*da0073e9SAndroid Build Coastguard Worker
1076*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(AssertionError, "Tried to use AMP with the"):
1077*da0073e9SAndroid Build Coastguard Worker            with torch.autocast(device_type=custom_backend_name):
1078*da0073e9SAndroid Build Coastguard Worker                pass
1079*da0073e9SAndroid Build Coastguard Worker        torch._register_device_module("foo", DummyPrivateUse1Module)
1080*da0073e9SAndroid Build Coastguard Worker
1081*da0073e9SAndroid Build Coastguard Worker        torch.foo.is_available()  # type: ignore[attr-defined]
1082*da0073e9SAndroid Build Coastguard Worker        with torch.autocast(device_type=custom_backend_name):
1083*da0073e9SAndroid Build Coastguard Worker            pass
1084*da0073e9SAndroid Build Coastguard Worker
1085*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(torch._utils._get_device_index("foo:1"), 1)
1086*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(torch._utils._get_device_index(torch.device("foo:2")), 2)
1087*da0073e9SAndroid Build Coastguard Worker
1088*da0073e9SAndroid Build Coastguard Worker
1089*da0073e9SAndroid Build Coastguard Workerclass TestRenderUtils(TestCase):
1090*da0073e9SAndroid Build Coastguard Worker    def test_basic(self):
1091*da0073e9SAndroid Build Coastguard Worker        self.assertExpectedInline(
1092*da0073e9SAndroid Build Coastguard Worker            torch._utils.render_call(torch.sum, [torch.randn(100)], {"dim": 0}),
1093*da0073e9SAndroid Build Coastguard Worker            """torch.sum(tensor([...], size=(100,)), dim=0)""",
1094*da0073e9SAndroid Build Coastguard Worker        )
1095*da0073e9SAndroid Build Coastguard Worker        self.assertExpectedInline(
1096*da0073e9SAndroid Build Coastguard Worker            torch._utils.render_call(torch.sum, [torch.randn(100, 100)], {"dim": 0}),
1097*da0073e9SAndroid Build Coastguard Worker            """torch.sum(tensor([...], size=(100, 100)), dim=0)""",
1098*da0073e9SAndroid Build Coastguard Worker        )
1099*da0073e9SAndroid Build Coastguard Worker
1100*da0073e9SAndroid Build Coastguard Worker
1101*da0073e9SAndroid Build Coastguard Workerclass TestDeviceUtils(TestCase):
1102*da0073e9SAndroid Build Coastguard Worker    def test_basic(self):
1103*da0073e9SAndroid Build Coastguard Worker        with torch.device("meta") as dev:
1104*da0073e9SAndroid Build Coastguard Worker            x = torch.empty(3, 3)
1105*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(x.device.type, "meta")
1106*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(dev, torch.device("meta"))
1107*da0073e9SAndroid Build Coastguard Worker
1108*da0073e9SAndroid Build Coastguard Worker    def test_decorator(self):
1109*da0073e9SAndroid Build Coastguard Worker        @set_device("meta")
1110*da0073e9SAndroid Build Coastguard Worker        def f():
1111*da0073e9SAndroid Build Coastguard Worker            return torch.empty(3, 3)
1112*da0073e9SAndroid Build Coastguard Worker
1113*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(f().device.type, "meta")
1114*da0073e9SAndroid Build Coastguard Worker
1115*da0073e9SAndroid Build Coastguard Worker    def test_decorator_generator(self):
1116*da0073e9SAndroid Build Coastguard Worker        @set_device("meta")
1117*da0073e9SAndroid Build Coastguard Worker        def f():
1118*da0073e9SAndroid Build Coastguard Worker            yield torch.empty(3, 3)
1119*da0073e9SAndroid Build Coastguard Worker            yield torch.empty(3, 3)
1120*da0073e9SAndroid Build Coastguard Worker
1121*da0073e9SAndroid Build Coastguard Worker        r1, r2 = list(f())
1122*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(r1.device.type, "meta")
1123*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(r2.device.type, "meta")
1124*da0073e9SAndroid Build Coastguard Worker
1125*da0073e9SAndroid Build Coastguard Worker    def test_nn_module(self):
1126*da0073e9SAndroid Build Coastguard Worker        with torch.device("meta"):
1127*da0073e9SAndroid Build Coastguard Worker            m = nn.Linear(40, 50)
1128*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(m.weight.device.type, "meta")
1129*da0073e9SAndroid Build Coastguard Worker
1130*da0073e9SAndroid Build Coastguard Worker    def test_set_default_device(self):
1131*da0073e9SAndroid Build Coastguard Worker        try:
1132*da0073e9SAndroid Build Coastguard Worker            torch.set_default_device("meta")
1133*da0073e9SAndroid Build Coastguard Worker            r = torch.empty(2, 2)
1134*da0073e9SAndroid Build Coastguard Worker        finally:
1135*da0073e9SAndroid Build Coastguard Worker            torch.set_default_device(None)
1136*da0073e9SAndroid Build Coastguard Worker
1137*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(r.device.type, "meta")
1138*da0073e9SAndroid Build Coastguard Worker
1139*da0073e9SAndroid Build Coastguard Worker    def test_get_default_device(self):
1140*da0073e9SAndroid Build Coastguard Worker        torch.set_default_device("meta")
1141*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(torch.get_default_device().type, "meta")
1142*da0073e9SAndroid Build Coastguard Worker        torch.set_default_device(None)
1143*da0073e9SAndroid Build Coastguard Worker
1144*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(not TEST_MULTIGPU, "multi-GPU not supported")
1145*da0073e9SAndroid Build Coastguard Worker    def test_get_default_device_more(self):
1146*da0073e9SAndroid Build Coastguard Worker        torch.set_default_device("cuda")
1147*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(torch.get_default_device(), torch.tensor([]).device)
1148*da0073e9SAndroid Build Coastguard Worker        torch.set_default_device(None)
1149*da0073e9SAndroid Build Coastguard Worker
1150*da0073e9SAndroid Build Coastguard Worker        torch.set_default_device("cuda")
1151*da0073e9SAndroid Build Coastguard Worker        torch.cuda.set_device("cuda:1")
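        # With a bare "cuda" default, the effective device index follows the
        # current CUDA device, so get_default_device() should report cuda:1 here.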
1152*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(torch.get_default_device(), torch.tensor([]).device)
1153*da0073e9SAndroid Build Coastguard Worker        torch.set_default_device(None)
1154*da0073e9SAndroid Build Coastguard Worker
1155*da0073e9SAndroid Build Coastguard Worker        torch.set_default_device("cuda:1")
1156*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(torch.get_default_device(), torch.tensor([]).device)
1157*da0073e9SAndroid Build Coastguard Worker        torch.set_default_device(None)
1158*da0073e9SAndroid Build Coastguard Worker
1159*da0073e9SAndroid Build Coastguard Worker    @onlyCPU
1160*da0073e9SAndroid Build Coastguard Worker    @ops(op_db)
1161*da0073e9SAndroid Build Coastguard Worker    def test_device_mode_ops(self, device, dtype, op):
1162*da0073e9SAndroid Build Coastguard Worker        func = op.get_op()
1163*da0073e9SAndroid Build Coastguard Worker        samples = op.sample_inputs(device, dtype, requires_grad=False)
1164*da0073e9SAndroid Build Coastguard Worker        for sample in samples:
1165*da0073e9SAndroid Build Coastguard Worker            # Only test samples that don't have Tensor inputs.  (We don't rely
1166*da0073e9SAndroid Build Coastguard Worker            # on OpInfo's factory-function property here because it is very
1167*da0073e9SAndroid Build Coastguard Worker            # incomplete.)
1168*da0073e9SAndroid Build Coastguard Worker            if tree_any(
1169*da0073e9SAndroid Build Coastguard Worker                lambda x: isinstance(x, torch.Tensor),
1170*da0073e9SAndroid Build Coastguard Worker                (sample.input, sample.args, sample.kwargs),
1171*da0073e9SAndroid Build Coastguard Worker            ):
1172*da0073e9SAndroid Build Coastguard Worker                continue
1173*da0073e9SAndroid Build Coastguard Worker            # Many OpInfos will explicitly pass in a device.  DeviceContext
1174*da0073e9SAndroid Build Coastguard Worker            # will respect device if it is explicitly specified.  To test
1175*da0073e9SAndroid Build Coastguard Worker            # DeviceContext, we have to remove the device kwarg in this case.
1176*da0073e9SAndroid Build Coastguard Worker            # NB: Can't pass None to sample_inputs, the function can't
1177*da0073e9SAndroid Build Coastguard Worker            # handle it.
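            # For example, under `with torch.device("meta")`, torch.ones(3)
            # produces a meta tensor, while torch.ones(3, device="cpu") still
            # lands on CPU because an explicit device argument wins.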
1178*da0073e9SAndroid Build Coastguard Worker            kwargs = sample.kwargs.copy()
1179*da0073e9SAndroid Build Coastguard Worker            kwargs.pop("device", None)
1180*da0073e9SAndroid Build Coastguard Worker            with torch.device("meta"):
1181*da0073e9SAndroid Build Coastguard Worker                r = func(sample.input, *sample.args, **kwargs)
1182*da0073e9SAndroid Build Coastguard Worker
1183*da0073e9SAndroid Build Coastguard Worker            def is_meta_device(x: torch.Tensor) -> bool:
1184*da0073e9SAndroid Build Coastguard Worker                return x.device.type == "meta"
1185*da0073e9SAndroid Build Coastguard Worker
1186*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(tree_all_only(torch.Tensor, is_meta_device, r))
1187*da0073e9SAndroid Build Coastguard Worker
1188*da0073e9SAndroid Build Coastguard Worker
1189*da0073e9SAndroid Build Coastguard Workerinstantiate_device_type_tests(TestDeviceUtils, globals())
1190*da0073e9SAndroid Build Coastguard Worker
1191*da0073e9SAndroid Build Coastguard Worker
1192*da0073e9SAndroid Build Coastguard Workerclass TestCppExtensionUtils(TestCase):
1193*da0073e9SAndroid Build Coastguard Worker    def test_cpp_compiler_is_ok(self):
1194*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(torch.utils.cpp_extension.check_compiler_ok_for_platform("c++"))
1195*da0073e9SAndroid Build Coastguard Worker
1196*da0073e9SAndroid Build Coastguard Worker    def test_cc_compiler_is_ok(self):
1197*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(torch.utils.cpp_extension.check_compiler_ok_for_platform("cc"))
1198*da0073e9SAndroid Build Coastguard Worker
1199*da0073e9SAndroid Build Coastguard Worker
1200*da0073e9SAndroid Build Coastguard Workerclass TestTraceback(TestCase):
1201*da0073e9SAndroid Build Coastguard Worker    def test_basic(self):
1202*da0073e9SAndroid Build Coastguard Worker        source = """\
1203*da0073e9SAndroid Build Coastguard Workerdef f(x):
1204*da0073e9SAndroid Build Coastguard Worker    def g(x):
1205*da0073e9SAndroid Build Coastguard Worker        raise RuntimeError  # HEYA
1206*da0073e9SAndroid Build Coastguard Worker
1207*da0073e9SAndroid Build Coastguard Worker    x = x * 3
1208*da0073e9SAndroid Build Coastguard Worker    return g(x) + 1
1209*da0073e9SAndroid Build Coastguard Worker"""
1210*da0073e9SAndroid Build Coastguard Worker
1211*da0073e9SAndroid Build Coastguard Worker        out: Dict[str, Any] = {}
1212*da0073e9SAndroid Build Coastguard Worker        scope = {"__compile_source__": source}
1213*da0073e9SAndroid Build Coastguard Worker        exec(source, scope, out)
1214*da0073e9SAndroid Build Coastguard Worker
1215*da0073e9SAndroid Build Coastguard Worker        try:
1216*da0073e9SAndroid Build Coastguard Worker            with report_compile_source_on_error():
1217*da0073e9SAndroid Build Coastguard Worker                out["f"](1)
1218*da0073e9SAndroid Build Coastguard Worker        except RuntimeError as e:
1219*da0073e9SAndroid Build Coastguard Worker            self.assertIn("HEYA", "".join(traceback.format_tb(e.__traceback__)))
1220*da0073e9SAndroid Build Coastguard Worker
1221*da0073e9SAndroid Build Coastguard Worker    def test_format_traceback_short(self):
1222*da0073e9SAndroid Build Coastguard Worker        try:
1223*da0073e9SAndroid Build Coastguard Worker            raise RuntimeError
1224*da0073e9SAndroid Build Coastguard Worker        except RuntimeError as e:
1225*da0073e9SAndroid Build Coastguard Worker            self.assertRegex(
1226*da0073e9SAndroid Build Coastguard Worker                format_traceback_short(e.__traceback__),
1227*da0073e9SAndroid Build Coastguard Worker                r".*test_utils.py:\d+ in test_format_traceback_short",
1228*da0073e9SAndroid Build Coastguard Worker            )
1229*da0073e9SAndroid Build Coastguard Worker
1230*da0073e9SAndroid Build Coastguard Worker    def test_captured_traceback(self):
1231*da0073e9SAndroid Build Coastguard Worker        self.assertIn(
1232*da0073e9SAndroid Build Coastguard Worker            "test_captured_traceback", "".join(CapturedTraceback.extract().format())
1233*da0073e9SAndroid Build Coastguard Worker        )
1234*da0073e9SAndroid Build Coastguard Worker
1235*da0073e9SAndroid Build Coastguard Worker    def test_captured_traceback_format_all(self):
1236*da0073e9SAndroid Build Coastguard Worker        rs = CapturedTraceback.format_all(
1237*da0073e9SAndroid Build Coastguard Worker            [CapturedTraceback.extract(), CapturedTraceback.extract()]
1238*da0073e9SAndroid Build Coastguard Worker        )
1239*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(len(rs), 2)
1240*da0073e9SAndroid Build Coastguard Worker        self.assertIn("test_captured_traceback_format_all", "".join(rs[0]))
1241*da0073e9SAndroid Build Coastguard Worker
1242*da0073e9SAndroid Build Coastguard Worker    def test_captured_traceback_format_all_cached(self):
1243*da0073e9SAndroid Build Coastguard Worker        tb = CapturedTraceback.extract()
1244*da0073e9SAndroid Build Coastguard Worker        tb.format()  # cached
1245*da0073e9SAndroid Build Coastguard Worker        rs = CapturedTraceback.format_all([tb, CapturedTraceback.extract()])
1246*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(len(rs), 2)
1247*da0073e9SAndroid Build Coastguard Worker        self.assertIn("test_captured_traceback_format_all", "".join(rs[0]))
1248*da0073e9SAndroid Build Coastguard Worker
1249*da0073e9SAndroid Build Coastguard Worker
1250*da0073e9SAndroid Build Coastguard Workerif __name__ == "__main__":
1251*da0073e9SAndroid Build Coastguard Worker    run_tests()
1252