xref: /aosp_15_r20/external/pytorch/test/test_stateless.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# Owner(s): ["module: nn"]
2
3import contextlib
4import os
5import re
6import subprocess
7import sys
8import unittest
9
10import torch
11import torch.nn.utils.stateless as stateless
12from torch.testing._internal.common_cuda import TEST_MULTIGPU
13from torch.testing._internal.common_utils import run_tests, TestCase, parametrize, instantiate_parametrized_tests, \
14    subtest
15
16
17class MockModule(torch.nn.Module):
18    def __init__(self) -> None:
19        super().__init__()
20        self.l1 = torch.nn.Linear(1, 1)
21        self.buffer = torch.nn.Buffer(torch.ones(1))
22        self.foo = 0.0
23
24    def forward(self, x):
25        return self.l1(x) + self.buffer
26
27
28class MockTiedModule(torch.nn.Module):
29    def __init__(self) -> None:
30        super().__init__()
31        self.l1 = torch.nn.Linear(1, 1)
32        self.tied_bias = self.l1.bias
33        self.buffer = torch.nn.Buffer(torch.ones(1))
34        self.tied_buffer = self.buffer
35
36    def forward(self, x):
37        return self.l1(x) + self.tied_bias + self.buffer + self.tied_buffer
38
39
40class TestStatelessFunctionalAPI(TestCase):
41    def _run_call_with_mock_module(self, module, functional_call, device='cpu', prefix=''):
42
43        x = torch.rand((1, 1)).to(device)
44        weight = torch.tensor([[1.0]], device=device)
45        bias = torch.tensor([0.0], device=device)
46        buffer = torch.tensor([0.0], device=device)
47        if prefix != '':
48            parameters = {f'{prefix}.l1.weight': weight,
49                          f'{prefix}.l1.bias': bias,
50                          f'{prefix}.buffer': buffer}
51        else:
52            parameters = {'l1.weight': weight,
53                          'l1.bias': bias,
54                          'buffer': buffer}
55        to_check = module
56        if prefix != '':
57            to_check = getattr(module, prefix)
58        prev_weight = to_check.l1.weight.clone()
59        prev_buffer = to_check.buffer.clone()
60        # the parameters represent an identity function contrary to the
61        # existing params in module. So here we expect the result to be the
62        # same as the input if the weight swapping went well.
63        res = functional_call(module, parameters, x)
64        self.assertEqual(x, res)
65        # check that the weight remain unmodified
66        cur_weight = to_check.l1.weight
67        cur_buffer = to_check.buffer
68        self.assertEqual(cur_weight, prev_weight)
69        self.assertEqual(cur_buffer, prev_buffer)
70
71    @contextlib.contextmanager
72    def _ensure_module_unchanged(self, module, message):
73        orig_parameters, orig_buffers = tuple(module.parameters()), tuple(module.buffers())
74        orig_tensors = orig_parameters + orig_buffers
75        orig_tensors_values = tuple(t.clone() for t in orig_tensors)
76        try:
77            yield module
78        finally:
79            parameters, buffers = tuple(module.parameters()), tuple(module.buffers())
80            self.assertTrue(
81                len(parameters) == len(orig_parameters)
82                and len(buffers) == len(orig_buffers)
83                and all(
84                    t1 is t2 and torch.allclose(t1, t3)
85                    for t1, t2, t3 in zip(
86                        orig_tensors,
87                        parameters + buffers,
88                        orig_tensors_values,
89                    )
90                ),
91                message,
92            )
93
94    @parametrize("functional_call", [
95        subtest(torch.func.functional_call, "torch_func"),
96        subtest(stateless.functional_call, "stateless")
97    ])
98    def test_functional_call(self, functional_call):
99        module = MockModule()
100        self._run_call_with_mock_module(module, functional_call)
101
102    @parametrize("functional_call", [
103        subtest(torch.func.functional_call, "torch_func"),
104        subtest(stateless.functional_call, "stateless")
105    ])
106    def test_functional_call_with_jit(self, functional_call):
107        module = MockModule()
108        jit_module = torch.jit.script(module)
109        with self.assertRaisesRegex(
110            RuntimeError,
111            r'used with Jitted modules'
112        ):
113            self._run_call_with_mock_module(jit_module, functional_call)
114        x = torch.rand((1, 1))
115        traced_module = torch.jit.trace(module, x)
116        with self.assertRaisesRegex(
117            RuntimeError,
118            r'used with Jitted modules'
119        ):
120            self._run_call_with_mock_module(traced_module, functional_call)
121
122    @unittest.skipIf(not TEST_MULTIGPU, 'multi-GPU not supported')
123    @unittest.skip("This doesn't work right now")
124    @parametrize("functional_call", [
125        subtest(torch.func.functional_call, "torch_func"),
126        subtest(stateless.functional_call, "stateless")
127    ])
128    def test_functional_call_with_data_parallel(self, functional_call):
129        module = MockModule()
130        module.cuda()
131        dp_module = torch.nn.DataParallel(module, [0, 1])
132        self._run_call_with_mock_module(dp_module, functional_call, device='cuda', prefix='module')
133
134    @unittest.skipIf(not TEST_MULTIGPU, 'multi-GPU not supported')
135    @parametrize("functional_call", [
136        subtest(torch.func.functional_call, "torch_func"),
137        subtest(stateless.functional_call, "stateless")
138    ])
139    def test_functional_call_with_data_parallel_error(self, functional_call):
140        module = MockModule()
141        module.cuda()
142        dp_module = torch.nn.DataParallel(module, [0, 1])
143        with self.assertRaisesRegex(RuntimeError, r'used with nn.DataParallel module'):
144            functional_call(
145                dp_module,
146                {'module.weight': torch.zeros(5, device='cuda')},
147                (torch.ones(2, 5, device='cuda'),))
148
149    @parametrize("functional_call", [
150        subtest(torch.func.functional_call, "torch_func"),
151        subtest(stateless.functional_call, "stateless")
152    ])
153    def test_functional_call_with_gradient(self, functional_call):
154        module = MockModule()
155        x = torch.rand((1, 1))
156        weight = torch.tensor([[1.0]], requires_grad=True)
157        bias = torch.tensor([0.0], requires_grad=True)
158        buffer = torch.tensor([0.0])
159        parameters = {'l1.weight': weight,
160                      'l1.bias': bias,
161                      'buffer': buffer}
162        res = functional_call(module, parameters, x)
163        # Check that a backward step calculates the gradient of the supplied parameters
164        res.backward()
165        self.assertIsNotNone(weight.grad)
166        self.assertIsNotNone(bias.grad)
167        self.assertIsNone(buffer.grad)
168        # Gradient was not calculated for the module stated and buffers
169        self.assertIsNone(module.l1.weight.grad)
170        self.assertIsNone(module.l1.bias.grad)
171        self.assertIsNone(module.buffer.grad)
172
173    @parametrize("functional_call", [
174        subtest(torch.func.functional_call, "torch_func"),
175        subtest(stateless.functional_call, "stateless")
176    ])
177    def test_functional_batch_norm(self, functional_call):
178        module = torch.nn.BatchNorm1d(10)
179        module.train()  # Allow stats update
180        # lets replace the running_mean buffer and check if its correctly updated
181        x = torch.full((20, 10), 128.0)
182        rm = torch.zeros(10)
183        parameters = {'running_mean': rm}
184        prev_rm = module.running_mean.clone()
185        res = functional_call(module, parameters, x)
186        cur_rm = module.running_mean
187        self.assertEqual(cur_rm, prev_rm)
188        self.assertEqual(rm, torch.full((10,), 12.8))
189        # Now run functional without reparametrization and check that the module has
190        # been updated
191        res = functional_call(module, {}, x)
192        self.assertEqual(module.running_mean, torch.full((10,), 12.8))
193
194    @parametrize("functional_call", [
195        subtest(torch.func.functional_call, "torch_func"),
196        subtest(stateless.functional_call, "stateless")
197    ])
198    def test_circular_references(self, functional_call):
199        module = MockModule()
200        # Add a circular reference
201        module.l1.m = module
202        x = torch.rand((1, 1))
203        weight = torch.tensor([[1.0]])
204        bias = torch.tensor([0.0])
205        buffer = torch.tensor([0.0])
206        parameters = {'l1.m.l1.weight': weight,
207                      'l1.bias': bias,
208                      'l1.m.buffer': buffer}
209        prev_weight = module.l1.weight.clone()
210        prev_buffer = module.buffer.clone()
211        res = functional_call(module, parameters, x, tie_weights=False)
212        self.assertEqual(x, res)
213        # check that the weights remain unmodified and were correctly accesed
214        cur_weight = module.l1.weight
215        cur_buffer = module.buffer
216        self.assertEqual(cur_weight, prev_weight)
217        self.assertEqual(cur_buffer, prev_buffer)
218
219    @parametrize("functional_call", [
220        subtest(torch.func.functional_call, "torch_func"),
221        subtest(stateless.functional_call, "stateless")
222    ])
223    def test_reparametrized_module_change_parametrization_original(self, functional_call):
224        module = MockModule()
225        torch.nn.utils.parametrizations.spectral_norm(module.l1)
226        self.assertTrue('l1.parametrizations.weight.original' in dict(module.named_parameters()))
227        orig_sn_weight = module.l1.weight.clone()
228        x = torch.rand((1, 1))
229        # We substitute the parameter inside the parametrization
230        # the parametrization itself is not overwritten so it will be applied with a different
231        # value for the original tensor
232        parameters = {'l1.parametrizations.weight.original': torch.nn.Parameter(torch.tensor([[1.0]])),
233                      'l1.bias': torch.tensor([0.0]),
234                      'buffer': torch.tensor([0.0])}
235        res = functional_call(module, parameters, x)
236        self.assertEqual(x, res)
237        # verify that the spectral normalization is still applied
238        self.assertTrue('l1.parametrizations.weight.original' in dict(module.named_parameters()))
239        self.assertEqual(orig_sn_weight, module.l1.weight)
240
241    @parametrize("functional_call", [
242        subtest(torch.func.functional_call, "torch_func"),
243        subtest(stateless.functional_call, "stateless")
244    ])
245    def test_reparametrize_module_fail_reset_to_original(self, functional_call):
246        module = MockModule()
247        torch.nn.utils.parametrizations.spectral_norm(module.l1)
248        self.assertTrue('l1.parametrizations.weight.original' in dict(module.named_parameters()))
249        orig_sn_weight = module.l1.weight.clone()
250        # We substitute the parameter inside the parametrization
251        # the parametrization itself is not overwritten so it will be applied with a different
252        # value for the original tensor
253        parameters = {'l1.parametrizations.weight.original': torch.nn.Parameter(torch.tensor([[1.0]])),
254                      'l1.bias': torch.tensor([0.0]),
255                      'buffer': torch.tensor([0.0])}
256
257        with self.assertRaisesRegex(RuntimeError, "shapes cannot be multiplied"):
258            @torch._dynamo.disable
259            def _error_case():
260                x = torch.rand((4, 5))  # to work, it should be of size (1, 1)
261                functional_call(module, parameters, x)  # this call will fail because x is the wrong size
262            _error_case()
263
264        # verify that the spectral normalization is still applied
265        self.assertTrue('l1.parametrizations.weight.original' in dict(module.named_parameters()))
266        self.assertEqual(orig_sn_weight, module.l1.weight)
267
268    @parametrize("functional_call", [
269        subtest(torch.func.functional_call, "torch_func"),
270        subtest(stateless.functional_call, "stateless")
271    ])
272    def test_reparametrize_some_weights(self, functional_call):
273        module = MockModule()
274        weight = torch.tensor([[2.0]])
275        bias = torch.tensor([5.0])
276        buffer = torch.tensor([3.0])
277        extra = torch.tensor([1.0])
278
279        parameters = {'l1.weight': weight}
280        x = torch.randn(1, 1)
281        out = functional_call(module, parameters, x)
282        self.assertEqual(out, x * weight + module.l1.bias + module.buffer)
283
284        parameters = {'l1.weight': weight,
285                      'extra': extra}
286        x = torch.randn(1, 1)
287        out = functional_call(module, parameters, x)
288        self.assertEqual(out, x * weight + module.l1.bias + module.buffer)
289
290    @parametrize("functional_call", [
291        subtest(torch.func.functional_call, "torch_func"),
292        subtest(stateless.functional_call, "stateless")
293    ])
294    def test_reparametrize_strict(self, functional_call):
295        module = MockModule()
296        weight = torch.tensor([[2.0]])
297        bias = torch.tensor([5.0])
298        buffer = torch.tensor([3.0])
299        extra = torch.tensor([1.0])
300
301        # All weights no error
302        parameters = {'l1.weight': weight,
303                      'l1.bias': bias,
304                      'buffer': buffer}
305        x = torch.randn(1, 1)
306        with self._ensure_module_unchanged(
307            module,
308            'the module should not have been modified by a successful call',
309        ):
310            out = functional_call(module, parameters, x, strict=True)
311            self.assertEqual(out, x * weight + bias + buffer)
312
313        # Some weights
314        parameters = {'l1.weight': weight}
315        x = torch.randn(1, 1)
316        with self._ensure_module_unchanged(
317            module,
318            'the module should not have been modified by a failed call',
319        ):
320            with self.assertRaisesRegex(
321                RuntimeError,
322                re.escape("Missing key(s): 'buffer', 'l1.bias'."),
323            ):
324                out = functional_call(module, parameters, x, strict=True)
325
326        # Extra keys
327        parameters = {'l1.weight': weight,
328                      'l1.bias': bias,
329                      'buffer': buffer,
330                      'extra': extra}
331        x = torch.randn(1, 1)
332        with self._ensure_module_unchanged(
333            module,
334            'the module should not have been modified by a failed call',
335        ):
336            with self.assertRaisesRegex(
337                RuntimeError,
338                re.escape("Unexpected key(s): 'extra'."),
339            ):
340                out = functional_call(module, parameters, x, strict=True)
341
342        # Some weights with extra keys
343        parameters = {'l1.weight': weight,
344                      'extra': extra}
345        x = torch.randn(1, 1)
346        with self._ensure_module_unchanged(
347            module,
348            'the module should not have been modified by a failed call',
349        ):
350            with self.assertRaisesRegex(
351                RuntimeError,
352                re.escape("Unexpected key(s): 'extra'.") + r'\s+' + re.escape("Missing key(s): 'buffer', 'l1.bias'."),
353            ):
354                out = functional_call(module, parameters, x, strict=True)
355
356    @parametrize("functional_call", [
357        subtest(torch.func.functional_call, "torch_func"),
358        subtest(stateless.functional_call, "stateless")
359    ])
360    def test_reparametrize_special(self, functional_call):
361        class NonTensor:
362            def __repr__(self):
363                return f'<{self.__class__.__name__}>'
364
365        module = MockModule()
366        weight = torch.tensor([[2.0]])
367        bias = torch.tensor([5.0])
368        buffer = torch.tensor([3.0])
369        non_tensor = NonTensor()
370
371        # Set to None
372        parameters = {'l1.weight': weight,
373                      'l1.bias': None,
374                      'buffer': buffer}
375        x = torch.randn(1, 1)
376        with self._ensure_module_unchanged(
377            module,
378            'the module should not have been modified by a successful call',
379        ):
380            out = functional_call(module, parameters, x)
381            self.assertEqual(out, x * weight + buffer)
382
383        # Set non-tensor
384        parameters = {'l1.weight': non_tensor}
385        x = torch.randn(1, 1)
386        with self._ensure_module_unchanged(
387            module,
388            'the module should not have been modified by a failed call',
389        ):
390            with self.assertRaisesRegex(
391                TypeError,
392                re.escape("<NonTensor> is not an instance of torch.Tensor"),
393            ):
394                out = functional_call(module, parameters, x)
395
396        # Set non-tensor attribute
397        parameters = {'l1.weight': weight, 'foo': torch.tensor([1.0])}
398        x = torch.randn(1, 1)
399        with self._ensure_module_unchanged(
400            module,
401            'the module should not have been modified by a failed call',
402        ):
403            with self.assertRaisesRegex(
404                TypeError,
405                re.escape("attribute `foo`: 0.0 is not an instance of torch.Tensor"),
406            ):
407                out = functional_call(module, parameters, x)
408
409        # Set non-exist submodule
410        parameters = {'l1.weight': weight,
411                      'l2.bias': bias}
412        x = torch.randn(1, 1)
413        with self._ensure_module_unchanged(
414            module,
415            'the module should not have been modified by a failed call',
416        ):
417            with self.assertRaisesRegex(
418                AttributeError,
419                re.escape("MockModule has no attribute `l2`"),
420            ):
421                out = functional_call(module, parameters, x)
422
423    @parametrize("functional_call", [
424        subtest(torch.func.functional_call, "torch_func"),
425        subtest(stateless.functional_call, "stateless")
426    ])
427    def test_tied_weights_warns(self, functional_call):
428        module = MockModule()
429        module.tied_bias = module.l1.bias
430        module.tied_buffer = torch.nn.Buffer(module.buffer)
431
432    @parametrize("functional_call", [
433        subtest(torch.func.functional_call, "torch_func"),
434        subtest(stateless.functional_call, "stateless")
435    ])
436    def test_reparametrize_tie_weights(self, functional_call):
437        module = MockTiedModule()
438        weight = torch.tensor([[2.0]])
439        bias = torch.tensor([5.0])
440        buffer = torch.tensor([3.0])
441        extra = torch.tensor([1.0])
442
443        parameters = {'l1.weight': weight,
444                      'l1.bias': bias,
445                      'buffer': buffer}
446        x = torch.randn(1, 1)
447        out = functional_call(module, parameters, x, tie_weights=True)
448        self.assertEqual(out, x * weight + bias + bias + buffer + buffer)
449
450        parameters = {'l1.weight': weight,
451                      'l1.bias': bias,
452                      'buffer': buffer,
453                      'extra': extra}
454        x = torch.randn(1, 1)
455        out = functional_call(module, parameters, x, tie_weights=True)
456        self.assertEqual(out, x * weight + bias + bias + buffer + buffer)
457
458    @parametrize("functional_call", [
459        subtest(torch.func.functional_call, "torch_func"),
460        subtest(stateless.functional_call, "stateless")
461    ])
462    def test_reparametrize_tie_some_weights(self, functional_call):
463        module = MockTiedModule()
464        weight = torch.tensor([[2.0]])
465        buffer = torch.tensor([3.0])
466
467        parameters = {'l1.weight': weight,
468                      'buffer': buffer}
469        x = torch.randn(1, 1)
470        out = stateless.functional_call(module, parameters, x, tie_weights=True)
471        self.assertEqual(out, x * 2. + module.l1.bias + module.tied_bias + buffer + buffer)
472
473    @parametrize("functional_call", [
474        subtest(torch.func.functional_call, "torch_func"),
475        subtest(stateless._functional_call, "stateless")
476    ])
477    def test_tied_weights_errors(self, functional_call):
478        module = MockTiedModule()
479        weight = torch.tensor([[1.0]])
480        bias = torch.tensor([0.0])
481        buffer = torch.tensor([0.0])
482
483        parameters = {'l1.weight': weight,
484                      'l1.bias': bias,
485                      'buffer': buffer}
486        x = torch.randn(1, 1)
487        self.assertNotWarn(lambda: functional_call(module, parameters, x, tie_weights=True))
488
489        # if tied values are the same tensors, shouldn't warn
490        parameters['tied_bias'] = bias
491        parameters['tied_buffer'] = buffer
492        self.assertNotWarn(lambda: functional_call(module, parameters, x, tie_weights=True))
493        del parameters['tied_bias']
494        del parameters['tied_buffer']
495
496        with self.assertRaisesRegex(
497            ValueError,
498            re.escape("functional_call got multiple values for keys ['l1.bias', 'tied_bias']"),
499        ):
500            parameters['tied_bias'] = torch.tensor([5.0])
501            functional_call(module, parameters, x, tie_weights=True)
502        del parameters['tied_bias']
503
504        with self.assertRaisesRegex(
505            ValueError,
506            re.escape("functional_call got multiple values for keys ['buffer', 'tied_buffer']"),
507        ):
508            parameters['tied_buffer'] = torch.tensor([5.0])
509            functional_call(module, parameters, x, tie_weights=True)
510
511    def test_tied_weights_no_error_without_flag(self):
512        module = MockTiedModule()
513        weight = torch.tensor([[1.0]])
514        bias = torch.tensor([0.0])
515        buffer = torch.tensor([0.0])
516
517        parameters = {'l1.weight': weight,
518                      'l1.bias': bias,
519                      'buffer': buffer}
520        x = torch.randn(1, 1)
521        self.assertNotWarn(lambda: stateless._functional_call(module, parameters, x, tie_weights=False))
522        parameters['tied_bias'] = torch.tensor([5.0])
523        self.assertNotWarn(lambda: stateless._functional_call(module, parameters, x, tie_weights=False))
524        del parameters['tied_bias']
525        parameters['tied_buffer'] = torch.tensor([5.0])
526        self.assertNotWarn(lambda: stateless._functional_call(module, parameters, x, tie_weights=False))
527
528    @parametrize("functional_call", [
529        subtest(torch.func.functional_call, "torch_func"),
530        subtest(stateless.functional_call, "stateless")
531    ])
532    def test_reparametrize_tie_weights_strict(self, functional_call):
533        module = MockTiedModule()
534        weight = torch.tensor([[2.0]])
535        bias = torch.tensor([5.0])
536        buffer = torch.tensor([3.0])
537        extra = torch.tensor([1.0])
538
539        # Tie weights no error
540        parameters = {'l1.weight': weight,
541                      'l1.bias': bias,
542                      'buffer': buffer}
543        x = torch.randn(1, 1)
544        with self._ensure_module_unchanged(
545            module,
546            'the module should not have been modified by a successful call',
547        ):
548            out = functional_call(module, parameters, x, tie_weights=True, strict=True)
549            self.assertEqual(out, x * weight + bias + bias + buffer + buffer)
550
551        # Tie weights without flag
552        parameters = {'l1.weight': weight,
553                      'l1.bias': bias,
554                      'buffer': buffer}
555        x = torch.randn(1, 1)
556        with self._ensure_module_unchanged(
557            module,
558            'the module should not have been modified by a failed call',
559        ):
560            with self.assertRaisesRegex(
561                RuntimeError,
562                re.escape("Missing key(s): 'tied_bias', 'tied_buffer'."),
563            ):
564                out = functional_call(module, parameters, x, tie_weights=False, strict=True)
565
566        # Tie some weights
567        parameters = {'l1.weight': weight,
568                      'buffer': buffer}
569        x = torch.randn(1, 1)
570        with self._ensure_module_unchanged(
571            module,
572            'the module should not have been modified by a failed call',
573        ):
574            with self.assertRaisesRegex(
575                RuntimeError,
576                re.escape("Missing key(s): 'l1.bias', 'tied_bias'."),
577            ):
578                out = stateless.functional_call(module, parameters, x, tie_weights=True, strict=True)
579
580        # Tie weights with extra keys
581        parameters = {'l1.weight': weight,
582                      'l1.bias': bias,
583                      'buffer': buffer,
584                      'extra': extra}
585        x = torch.randn(1, 1)
586        with self._ensure_module_unchanged(
587            module,
588            'the module should not have been modified by a failed call',
589        ):
590            with self.assertRaisesRegex(
591                RuntimeError,
592                re.escape("Unexpected key(s): 'extra'."),
593            ):
594                out = stateless.functional_call(module, parameters, x, tie_weights=True, strict=True)
595
596        # Tie weights with extra keys and without flag
597        parameters = {'l1.weight': weight,
598                      'l1.bias': bias,
599                      'buffer': buffer,
600                      'extra': extra}
601        x = torch.randn(1, 1)
602        with self._ensure_module_unchanged(
603            module,
604            'the module should not have been modified by a failed call',
605        ):
606            with self.assertRaisesRegex(
607                RuntimeError,
608                re.escape("Unexpected key(s): 'extra'.") + r'\s+' + re.escape("Missing key(s): 'tied_bias', 'tied_buffer'."),
609            ):
610                out = stateless.functional_call(module, parameters, x, tie_weights=False, strict=True)
611
612        # Tie some weights with extra keys
613        parameters = {'l1.weight': weight,
614                      'buffer': buffer,
615                      'extra': extra}
616        x = torch.randn(1, 1)
617        with self._ensure_module_unchanged(
618            module,
619            'the module should not have been modified by a failed call',
620        ):
621            with self.assertRaisesRegex(
622                RuntimeError,
623                re.escape("Unexpected key(s): 'extra'.") + r'\s+' + re.escape("Missing key(s): 'l1.bias', 'tied_bias'."),
624            ):
625                out = stateless.functional_call(module, parameters, x, tie_weights=True, strict=True)
626
627    @parametrize("functional_call", [
628        subtest(torch.func.functional_call, "torch_func"),
629        subtest(stateless.functional_call, "stateless")
630    ])
631    def test_setattr(self, functional_call):
632        class Foo(torch.nn.Module):
633            def __init__(self) -> None:
634                super().__init__()
635                self.foo = torch.nn.Buffer(torch.tensor([0.0]))
636
637            def forward(self, x):
638                self.foo = self.foo + 1
639                return x + self.foo
640
641        foo = torch.tensor([2.0])
642        x = torch.randn(1)
643        a = {'foo': foo}
644        mod = Foo()
645        functional_call(mod, a, x)
646        self.assertEqual(mod.foo, torch.tensor([0.0]))
647        self.assertEqual(a['foo'], torch.tensor([3.0]))
648        self.assertEqual(foo, torch.tensor([2.0]))
649        self.assertTrue(a['foo'] is not foo)
650
651    @parametrize("functional_call", [
652        subtest(torch.func.functional_call, "torch_func"),
653        subtest(stateless.functional_call, "stateless")
654    ])
655    def test_in_place_operator(self, functional_call):
656        class Foo(torch.nn.Module):
657            def __init__(self) -> None:
658                super().__init__()
659                self.foo = torch.nn.Buffer(torch.tensor([0.0]))
660
661            def forward(self, x):
662                self.foo.add_(1)
663                return x + self.foo
664
665        foo = torch.tensor([2.0])
666        x = torch.randn(1)
667        a = {'foo': foo}
668        mod = Foo()
669        functional_call(mod, a, x)
670        self.assertEqual(mod.foo, torch.tensor([0.0]))
671        self.assertEqual(a['foo'], torch.tensor([3.0]))
672        self.assertEqual(foo, torch.tensor([3.0]))
673        self.assertTrue(a['foo'] is foo)
674
675    @parametrize("functional_call", [
676        subtest(torch.func.functional_call, "torch_func"),
677        subtest(stateless.functional_call, "stateless")
678    ])
679    def test_setattr_strict(self, functional_call):
680        class Bar(torch.nn.Module):
681            def __init__(self) -> None:
682                super().__init__()
683                assert not hasattr(self, 'extra')
684
685            def forward(self, x):
686                return x + self.extra
687
688        a = {'extra': torch.zeros(())}
689        mod = Bar()
690        self.assertTrue(not hasattr(mod, 'extra'))
691        out = functional_call(mod, a, torch.ones(()))
692        self.assertEqual(out, torch.ones(()))
693        self.assertTrue(not hasattr(mod, 'extra'))
694
695        a = {'extra': torch.zeros(())}
696        with self.assertRaisesRegex(
697            RuntimeError,
698            re.escape("Unexpected key(s): 'extra'."),
699        ):
700            out = functional_call(mod, a, torch.ones(()), strict=True)
701        self.assertTrue(not hasattr(mod, 'extra'))
702
703        a = {}
704        with self.assertRaisesRegex(
705            AttributeError,
706            re.escape("'Bar' object has no attribute 'extra'"),
707        ):
708            out = functional_call(mod, a, torch.ones(()))
709        self.assertTrue(not hasattr(mod, 'extra'))
710
711        a = {}
712        with self.assertRaisesRegex(
713            AttributeError,
714            re.escape("'Bar' object has no attribute 'extra'"),
715        ):
716            out = functional_call(mod, a, torch.ones(()), strict=True)
717        self.assertTrue(not hasattr(mod, 'extra'))
718
719    @parametrize("functional_call", [
720        subtest(torch.func.functional_call, "torch_func"),
721        subtest(stateless.functional_call, "stateless")
722    ])
723    def test_functional_call_with_kwargs(self, functional_call):
724        class Foo(torch.nn.Module):
725            def __init__(self, x):
726                super().__init__()
727                self.x = x
728
729            def forward(self, inp, *, other_inp):
730                return inp * self.x + other_inp
731
732        a = {'x': torch.zeros(2, 3)}
733        mod = Foo(torch.randn(2, 3))
734        inp, other_inp = torch.randn(2, 3), torch.randn(2, 3)
735        with self.assertRaisesRegex(TypeError, "missing 1 required keyword-only argument: 'other_inp'"):
736            functional_call(mod, a, inp)
737        res = functional_call(mod, a, inp, {'other_inp': other_inp})
738        self.assertEqual(res, other_inp)
739        res_1 = functional_call(mod, a, (), {'inp': inp, 'other_inp': other_inp})
740        self.assertEqual(res, res_1)
741
742    def test_functional_call_tuple_dicts(self):
743        mod = MockModule()
744        x = torch.rand((1, 1))
745        parameters = {k: torch.ones_like(v) for k, v in mod.named_parameters()}
746        buffers = {k: torch.zeros_like(v) for k, v in mod.named_buffers()}
747
748        # two dictionaries
749        res = torch.func.functional_call(mod, (parameters, buffers), x)
750        self.assertEqual(res, x + 1)
751
752        # no dictionaries
753        res = torch.func.functional_call(mod, (), x)
754        self.assertEqual(res, mod(x))
755
756        # three dictonaries
757        a = ({'l1.weight': torch.ones(1, 1)}, {'l1.bias': torch.ones(1)}, {'buffer': torch.zeros(1)})
758        res = torch.func.functional_call(mod, a, x)
759        self.assertEqual(res, x + 1)
760
761    def test_functional_call_multiple_dicts_error(self):
762        mod = MockModule()
763        x = torch.rand((1, 1))
764        parameters = {'l1.weight': torch.zeros((1, 1)), 'l1.bias': torch.zeros((1, 1))}
765        repeated_parameters = {'l1.weight': torch.ones((1, 1))}
766        with self.assertRaisesRegex(
767            ValueError,
768            re.escape("['l1.weight'] appeared in multiple dictionaries"),
769        ):
770            torch.func.functional_call(mod, (parameters, repeated_parameters), x)
771
772    @parametrize("functional_call", [
773        subtest(torch.func.functional_call, "torch_func"),
774        subtest(stateless.functional_call, "stateless")
775    ])
776    def test_functional_call_member_reference(self, functional_call):
777        class Module(torch.nn.Module):
778            def __init__(self) -> None:
779                super().__init__()
780                self.l1 = torch.nn.Linear(1, 1)
781                self.buffer = torch.nn.Buffer(torch.ones(1))
782
783            def forward(self, x):
784                parameters = tuple(self.parameters())
785                buffers = tuple(self.buffers())
786                return self.l1(x) + self.buffer, parameters, buffers
787
788        module = Module()
789        weight = torch.tensor([[2.0]])
790        bias = torch.tensor([5.0])
791        buffer = torch.tensor([3.0])
792        extra = torch.tensor([1.0])
793        extra_p = torch.nn.Parameter(extra)
794
795        # All weights
796        parameters = {'l1.weight': weight,
797                      'l1.bias': bias,
798                      'buffer': buffer}
799        x = torch.randn(1, 1)
800        out, parameters, buffers = functional_call(module, parameters, x)
801        self.assertEqual(out, x * weight + bias + buffer)
802        self.assertEqual(parameters, (weight, bias))
803        self.assertEqual(buffers, (buffer,))
804        self.assertTrue(all(t1 is t2 for t1, t2 in zip(parameters, (weight, bias))))
805        self.assertTrue(all(t1 is t2 for t1, t2 in zip(buffers, (buffer,))))
806
807        # Some weights
808        parameters = {'l1.weight': weight}
809        x = torch.randn(1, 1)
810        out, parameters, buffers = functional_call(module, parameters, x)
811        self.assertEqual(out, x * weight + module.l1.bias + module.buffer)
812        self.assertEqual(parameters, (weight, module.l1.bias))
813        self.assertEqual(buffers, (module.buffer,))
814        self.assertTrue(all(t1 is t2 for t1, t2 in zip(parameters, (weight, module.l1.bias))))
815        self.assertTrue(all(t1 is t2 for t1, t2 in zip(buffers, (module.buffer,))))
816
817        # All weights with extra keys
818        parameters = {'l1.weight': weight,
819                      'l1.bias': bias,
820                      'buffer': buffer,
821                      'l1.extra': extra}
822        x = torch.randn(1, 1)
823        out, parameters, buffers = functional_call(module, parameters, x)
824        self.assertEqual(out, x * weight + bias + buffer)
825        self.assertEqual(parameters, (weight, bias))
826        self.assertEqual(buffers, (buffer,))
827        self.assertTrue(all(t1 is t2 for t1, t2 in zip(parameters, (weight, bias))))
828        self.assertTrue(all(t1 is t2 for t1, t2 in zip(buffers, (buffer,))))
829
830        # All weights with extra keys with parameters
831        parameters = {'l1.weight': weight,
832                      'l1.bias': bias,
833                      'buffer': buffer,
834                      'l1.extra': extra_p}
835        x = torch.randn(1, 1)
836        out, parameters, buffers = functional_call(module, parameters, x)
837        self.assertEqual(out, x * weight + bias + buffer)
838        self.assertEqual(parameters, (weight, bias, extra_p))
839        self.assertEqual(buffers, (buffer,))
840        self.assertTrue(all(t1 is t2 for t1, t2 in zip(parameters, (weight, bias, extra_p))))
841        self.assertTrue(all(t1 is t2 for t1, t2 in zip(buffers, (buffer,))))
842
843        # Some weights with extra keys
844        parameters = {'l1.weight': weight,
845                      'l1.extra': extra}
846        x = torch.randn(1, 1)
847        out, parameters, buffers = functional_call(module, parameters, x)
848        self.assertEqual(out, x * weight + module.l1.bias + module.buffer)
849        self.assertEqual(parameters, (weight, module.l1.bias))
850        self.assertEqual(buffers, (module.buffer))
851        self.assertTrue(all(t1 is t2 for t1, t2 in zip(parameters, (weight, module.l1.bias))))
852        self.assertTrue(all(t1 is t2 for t1, t2 in zip(buffers, (module.buffer,))))
853
854        # Some weights with extra keys with parameters
855        parameters = {'l1.weight': weight,
856                      'l1.extra': extra_p}
857        x = torch.randn(1, 1)
858        out, parameters, buffers = functional_call(module, parameters, x)
859        self.assertEqual(out, x * weight + module.l1.bias + module.buffer)
860        self.assertEqual(parameters, (weight, module.l1.bias, extra_p))
861        self.assertEqual(buffers, (module.buffer))
862        self.assertTrue(all(t1 is t2 for t1, t2 in zip(parameters, (weight, module.l1.bias, extra_p))))
863        self.assertTrue(all(t1 is t2 for t1, t2 in zip(buffers, (module.buffer,))))
864
865        # Set None
866        parameters = {'l1.weight': weight,
867                      'l1.bias': None}
868        x = torch.randn(1, 1)
869        out, parameters, buffers = functional_call(module, parameters, x)
870        self.assertEqual(out, x * weight + module.buffer)
871        self.assertEqual(parameters, (weight,))
872        self.assertEqual(buffers, (module.buffer))
873        self.assertTrue(all(t1 is t2 for t1, t2 in zip(parameters, (weight,))))
874        self.assertTrue(all(t1 is t2 for t1, t2 in zip(buffers, (module.buffer,))))
875
876
class TestStatelessDeprecation(TestCase):
    """Deprecation warnings for the legacy stateless entry points."""

    def test_private_stateless_warns(self):
        # Importing the private ``_stateless`` shim must emit exactly one
        # warning. The import runs in a fresh interpreter so nothing imported
        # by this process can hide it; the child reports the number of
        # recorded warnings through its exit code (0 means none were raised).
        script = """
import torch
import warnings

with warnings.catch_warnings(record=True) as w:
    from torch.nn.utils import _stateless

exit(len(w))
"""
        try:
            subprocess.check_output(
                [sys.executable, '-W', 'always', '-c', script],
                stderr=subprocess.STDOUT,
                # On Windows, opening the subprocess with the default CWD makes `import torch`
                # fail, so just set CWD to this script's directory
                cwd=os.path.dirname(os.path.realpath(__file__)),)
        except subprocess.CalledProcessError as e:
            # Any exit code other than 1 means extra warnings or a crash; show
            # the child's combined stdout/stderr to make that diagnosable.
            self.assertEqual(
                e.returncode, 1,
                msg=f"Expected exactly one warning; subprocess output:\n{e.output.decode(errors='replace')}",
            )
        else:
            self.fail("No warning was raised.")

    def test_stateless_functional_call_warns(self):
        # The public stateless.functional_call points users at torch.func.
        m = torch.nn.Linear(1, 1)
        params = dict(m.named_parameters())
        x = torch.randn(3, 1)
        with self.assertWarnsRegex(FutureWarning, "Please use `torch.func.functional_call`"):
            stateless.functional_call(m, params, x)
906
class TestPythonOptimizeMode(TestCase):
    """Importing torch modules must keep working under ``python -OO``."""

    def test_runs_with_optimize_flag(self):
        # ``-OO`` strips assert statements and docstrings, so code that relies
        # on either at import time can break; import in a fresh interpreter
        # started with that flag.
        script = "import torch; import torch._functorch.deprecated"
        try:
            subprocess.check_output(
                [sys.executable, "-OO", "-c", script],
                stderr=subprocess.STDOUT,
                # On Windows, opening the subprocess with the default CWD makes `import torch`
                # fail, so just set CWD to this script's directory
                cwd=os.path.dirname(os.path.realpath(__file__)),)
        except subprocess.CalledProcessError as e:
            # check_output raises only for a non-zero exit status, so reaching
            # this branch is always a failure; surface the child's output.
            self.fail(
                "Import failed while running python in optimized mode "
                f"(exit code {e.returncode}):\n{e.output.decode(errors='replace')}"
            )
919
920
# Expand the @parametrize-decorated methods of TestStatelessFunctionalAPI into
# concrete test cases.
instantiate_parametrized_tests(TestStatelessFunctionalAPI)

if __name__ == '__main__':
    run_tests()
927