# Owner(s): ["module: nn"]

import contextlib
import os
import re
import subprocess
import sys
import unittest

import torch
import torch.nn.utils.stateless as stateless
from torch.testing._internal.common_cuda import TEST_MULTIGPU
from torch.testing._internal.common_utils import run_tests, TestCase, parametrize, instantiate_parametrized_tests, \
    subtest


class MockModule(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.l1 = torch.nn.Linear(1, 1)
        self.buffer = torch.nn.Buffer(torch.ones(1))
        self.foo = 0.0

    def forward(self, x):
        return self.l1(x) + self.buffer


class MockTiedModule(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.l1 = torch.nn.Linear(1, 1)
        self.tied_bias = self.l1.bias
        self.buffer = torch.nn.Buffer(torch.ones(1))
        self.tied_buffer = self.buffer

    def forward(self, x):
        return self.l1(x) + self.tied_bias + self.buffer + self.tied_buffer
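

# The tests below exercise the stateless "functional call" API through both of
# its entry points. A minimal sketch of the idea, assuming only the public
# `torch.func.functional_call(module, parameters_and_buffers, args)` signature:
#
#     module = MockModule()
#     overrides = {'l1.weight': torch.ones(1, 1),
#                  'l1.bias': torch.zeros(1),
#                  'buffer': torch.zeros(1)}
#     out = torch.func.functional_call(module, overrides, torch.ones(1, 1))
#     # out == input: identity weight, zero bias, zero buffer
#
# The overrides are swapped in only for the duration of the forward pass; the
# module's own parameters and buffers are restored afterwards.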


class TestStatelessFunctionalAPI(TestCase):
    def _run_call_with_mock_module(self, module, functional_call, device='cpu', prefix=''):
        x = torch.rand((1, 1)).to(device)
        weight = torch.tensor([[1.0]], device=device)
        bias = torch.tensor([0.0], device=device)
        buffer = torch.tensor([0.0], device=device)
        if prefix != '':
            parameters = {f'{prefix}.l1.weight': weight,
                          f'{prefix}.l1.bias': bias,
                          f'{prefix}.buffer': buffer}
        else:
            parameters = {'l1.weight': weight,
                          'l1.bias': bias,
                          'buffer': buffer}
        to_check = module
        if prefix != '':
            to_check = getattr(module, prefix)
        prev_weight = to_check.l1.weight.clone()
        prev_buffer = to_check.buffer.clone()
        # Unlike the module's existing parameters, the supplied parameters make
        # the module compute the identity function, so if the weight swap
        # worked we expect the result to equal the input.
        res = functional_call(module, parameters, x)
        self.assertEqual(x, res)
        # Check that the module's own weights remain unmodified
        cur_weight = to_check.l1.weight
        cur_buffer = to_check.buffer
        self.assertEqual(cur_weight, prev_weight)
        self.assertEqual(cur_buffer, prev_buffer)

    @contextlib.contextmanager
    def _ensure_module_unchanged(self, module, message):
        orig_parameters, orig_buffers = tuple(module.parameters()), tuple(module.buffers())
        orig_tensors = orig_parameters + orig_buffers
        orig_tensors_values = tuple(t.clone() for t in orig_tensors)
        try:
            yield module
        finally:
            parameters, buffers = tuple(module.parameters()), tuple(module.buffers())
            self.assertTrue(
                len(parameters) == len(orig_parameters)
                and len(buffers) == len(orig_buffers)
                and all(
                    t1 is t2 and torch.allclose(t1, t3)
                    for t1, t2, t3 in zip(
                        orig_tensors,
                        parameters + buffers,
                        orig_tensors_values,
                    )
                ),
                message,
            )

    @parametrize("functional_call", [
        subtest(torch.func.functional_call, "torch_func"),
        subtest(stateless.functional_call, "stateless")
    ])
    def test_functional_call(self, functional_call):
        module = MockModule()
        self._run_call_with_mock_module(module, functional_call)

    @parametrize("functional_call", [
        subtest(torch.func.functional_call, "torch_func"),
        subtest(stateless.functional_call, "stateless")
    ])
    def test_functional_call_with_jit(self, functional_call):
        module = MockModule()
        jit_module = torch.jit.script(module)
        with self.assertRaisesRegex(
            RuntimeError,
            r'used with Jitted modules'
        ):
            self._run_call_with_mock_module(jit_module, functional_call)
        x = torch.rand((1, 1))
        traced_module = torch.jit.trace(module, x)
        with self.assertRaisesRegex(
            RuntimeError,
            r'used with Jitted modules'
        ):
            self._run_call_with_mock_module(traced_module, functional_call)

    @unittest.skipIf(not TEST_MULTIGPU, 'multi-GPU not supported')
    @unittest.skip("This doesn't work right now")
    @parametrize("functional_call", [
        subtest(torch.func.functional_call, "torch_func"),
        subtest(stateless.functional_call, "stateless")
    ])
    def test_functional_call_with_data_parallel(self, functional_call):
        module = MockModule()
        module.cuda()
        dp_module = torch.nn.DataParallel(module, [0, 1])
        self._run_call_with_mock_module(dp_module, functional_call, device='cuda', prefix='module')

    @unittest.skipIf(not TEST_MULTIGPU, 'multi-GPU not supported')
    @parametrize("functional_call", [
        subtest(torch.func.functional_call, "torch_func"),
        subtest(stateless.functional_call, "stateless")
    ])
    def test_functional_call_with_data_parallel_error(self, functional_call):
        module = MockModule()
        module.cuda()
        dp_module = torch.nn.DataParallel(module, [0, 1])
        with self.assertRaisesRegex(RuntimeError, r'used with nn.DataParallel module'):
            functional_call(
                dp_module,
                {'module.weight': torch.zeros(5, device='cuda')},
                (torch.ones(2, 5, device='cuda'),))

    @parametrize("functional_call", [
        subtest(torch.func.functional_call, "torch_func"),
        subtest(stateless.functional_call, "stateless")
    ])
    def test_functional_call_with_gradient(self, functional_call):
        module = MockModule()
        x = torch.rand((1, 1))
        weight = torch.tensor([[1.0]], requires_grad=True)
        bias = torch.tensor([0.0], requires_grad=True)
        buffer = torch.tensor([0.0])
        parameters = {'l1.weight': weight,
                      'l1.bias': bias,
                      'buffer': buffer}
        res = functional_call(module, parameters, x)
        # Check that a backward step computes gradients for the supplied parameters
        res.backward()
        self.assertIsNotNone(weight.grad)
        self.assertIsNotNone(bias.grad)
        self.assertIsNone(buffer.grad)
        # No gradients were computed for the module's own parameters and buffers
        self.assertIsNone(module.l1.weight.grad)
        self.assertIsNone(module.l1.bias.grad)
        self.assertIsNone(module.buffer.grad)
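
    # A rough sketch (not asserted by these tests) of why the behaviour above
    # matters: because gradients flow to the supplied tensors rather than the
    # module's own parameters, functional_call composes with plain autograd,
    # e.g.
    #
    #     params = {k: v.detach().requires_grad_() for k, v in module.named_parameters()}
    #     loss = functional_call(module, params, x).sum()
    #     grads = torch.autograd.grad(loss, tuple(params.values()))
    #
    # (hypothetical usage; `module` and `x` stand for any module and input).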

    @parametrize("functional_call", [
        subtest(torch.func.functional_call, "torch_func"),
        subtest(stateless.functional_call, "stateless")
    ])
    def test_functional_batch_norm(self, functional_call):
        module = torch.nn.BatchNorm1d(10)
        module.train()  # Allow stats update
        # Let's replace the running_mean buffer and check that it is correctly updated
        x = torch.full((20, 10), 128.0)
        rm = torch.zeros(10)
        parameters = {'running_mean': rm}
        prev_rm = module.running_mean.clone()
        res = functional_call(module, parameters, x)
        cur_rm = module.running_mean
        self.assertEqual(cur_rm, prev_rm)
        self.assertEqual(rm, torch.full((10,), 12.8))
        # Now run the functional call without reparametrization and check that
        # the module has been updated
        res = functional_call(module, {}, x)
        self.assertEqual(module.running_mean, torch.full((10,), 12.8))

    @parametrize("functional_call", [
        subtest(torch.func.functional_call, "torch_func"),
        subtest(stateless.functional_call, "stateless")
    ])
    def test_circular_references(self, functional_call):
        module = MockModule()
        # Add a circular reference
        module.l1.m = module
        x = torch.rand((1, 1))
        weight = torch.tensor([[1.0]])
        bias = torch.tensor([0.0])
        buffer = torch.tensor([0.0])
        parameters = {'l1.m.l1.weight': weight,
                      'l1.bias': bias,
                      'l1.m.buffer': buffer}
        prev_weight = module.l1.weight.clone()
        prev_buffer = module.buffer.clone()
        res = functional_call(module, parameters, x, tie_weights=False)
        self.assertEqual(x, res)
        # Check that the weights remain unmodified and were correctly accessed
        cur_weight = module.l1.weight
        cur_buffer = module.buffer
        self.assertEqual(cur_weight, prev_weight)
        self.assertEqual(cur_buffer, prev_buffer)

    @parametrize("functional_call", [
        subtest(torch.func.functional_call, "torch_func"),
        subtest(stateless.functional_call, "stateless")
    ])
    def test_reparametrized_module_change_parametrization_original(self, functional_call):
        module = MockModule()
        torch.nn.utils.parametrizations.spectral_norm(module.l1)
        self.assertTrue('l1.parametrizations.weight.original' in dict(module.named_parameters()))
        orig_sn_weight = module.l1.weight.clone()
        x = torch.rand((1, 1))
        # We substitute the parameter inside the parametrization;
        # the parametrization itself is not overwritten, so it is applied
        # to the different value of the original tensor
        parameters = {'l1.parametrizations.weight.original': torch.nn.Parameter(torch.tensor([[1.0]])),
                      'l1.bias': torch.tensor([0.0]),
                      'buffer': torch.tensor([0.0])}
        res = functional_call(module, parameters, x)
        self.assertEqual(x, res)
        # Verify that the spectral normalization is still applied
        self.assertTrue('l1.parametrizations.weight.original' in dict(module.named_parameters()))
        self.assertEqual(orig_sn_weight, module.l1.weight)
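
    # Both spectral-norm tests (above and below) rely on the layout created by
    # torch.nn.utils.parametrizations.spectral_norm: the learnable tensor is
    # registered as 'l1.parametrizations.weight.original' and 'l1.weight'
    # becomes a property recomputed from it, so overriding the '...original'
    # key re-runs spectral normalization on the substituted tensor.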

    @parametrize("functional_call", [
        subtest(torch.func.functional_call, "torch_func"),
        subtest(stateless.functional_call, "stateless")
    ])
    def test_reparametrize_module_fail_reset_to_original(self, functional_call):
        module = MockModule()
        torch.nn.utils.parametrizations.spectral_norm(module.l1)
        self.assertTrue('l1.parametrizations.weight.original' in dict(module.named_parameters()))
        orig_sn_weight = module.l1.weight.clone()
        # We substitute the parameter inside the parametrization;
        # the parametrization itself is not overwritten, so it is applied
        # to the different value of the original tensor
        parameters = {'l1.parametrizations.weight.original': torch.nn.Parameter(torch.tensor([[1.0]])),
                      'l1.bias': torch.tensor([0.0]),
                      'buffer': torch.tensor([0.0])}

        with self.assertRaisesRegex(RuntimeError, "shapes cannot be multiplied"):
            @torch._dynamo.disable
            def _error_case():
                x = torch.rand((4, 5))  # to work, it should be of size (1, 1)
                functional_call(module, parameters, x)  # this call fails because x is the wrong size
            _error_case()

        # Verify that the spectral normalization is still applied
        self.assertTrue('l1.parametrizations.weight.original' in dict(module.named_parameters()))
        self.assertEqual(orig_sn_weight, module.l1.weight)

    @parametrize("functional_call", [
        subtest(torch.func.functional_call, "torch_func"),
        subtest(stateless.functional_call, "stateless")
    ])
    def test_reparametrize_some_weights(self, functional_call):
        module = MockModule()
        weight = torch.tensor([[2.0]])
        bias = torch.tensor([5.0])
        buffer = torch.tensor([3.0])
        extra = torch.tensor([1.0])

        parameters = {'l1.weight': weight}
        x = torch.randn(1, 1)
        out = functional_call(module, parameters, x)
        self.assertEqual(out, x * weight + module.l1.bias + module.buffer)

        parameters = {'l1.weight': weight,
                      'extra': extra}
        x = torch.randn(1, 1)
        out = functional_call(module, parameters, x)
        self.assertEqual(out, x * weight + module.l1.bias + module.buffer)
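
    # strict=True changes the merge semantics checked above: instead of
    # overlaying a partial dict on the module's existing state, the dict must
    # cover every parameter and buffer exactly, so both missing and unexpected
    # keys raise. A sketch of the expected failure shape, mirroring the
    # assertions below:
    #
    #     functional_call(module, {'l1.weight': w}, x, strict=True)
    #     # -> RuntimeError: Missing key(s): 'buffer', 'l1.bias'.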

    @parametrize("functional_call", [
        subtest(torch.func.functional_call, "torch_func"),
        subtest(stateless.functional_call, "stateless")
    ])
    def test_reparametrize_strict(self, functional_call):
        module = MockModule()
        weight = torch.tensor([[2.0]])
        bias = torch.tensor([5.0])
        buffer = torch.tensor([3.0])
        extra = torch.tensor([1.0])

        # All weights, no error
        parameters = {'l1.weight': weight,
                      'l1.bias': bias,
                      'buffer': buffer}
        x = torch.randn(1, 1)
        with self._ensure_module_unchanged(
            module,
            'the module should not have been modified by a successful call',
        ):
            out = functional_call(module, parameters, x, strict=True)
            self.assertEqual(out, x * weight + bias + buffer)

        # Some weights
        parameters = {'l1.weight': weight}
        x = torch.randn(1, 1)
        with self._ensure_module_unchanged(
            module,
            'the module should not have been modified by a failed call',
        ):
            with self.assertRaisesRegex(
                RuntimeError,
                re.escape("Missing key(s): 'buffer', 'l1.bias'."),
            ):
                out = functional_call(module, parameters, x, strict=True)

        # Extra keys
        parameters = {'l1.weight': weight,
                      'l1.bias': bias,
                      'buffer': buffer,
                      'extra': extra}
        x = torch.randn(1, 1)
        with self._ensure_module_unchanged(
            module,
            'the module should not have been modified by a failed call',
        ):
            with self.assertRaisesRegex(
                RuntimeError,
                re.escape("Unexpected key(s): 'extra'."),
            ):
                out = functional_call(module, parameters, x, strict=True)

        # Some weights with extra keys
        parameters = {'l1.weight': weight,
                      'extra': extra}
        x = torch.randn(1, 1)
        with self._ensure_module_unchanged(
            module,
            'the module should not have been modified by a failed call',
        ):
            with self.assertRaisesRegex(
                RuntimeError,
                re.escape("Unexpected key(s): 'extra'.") + r'\s+' + re.escape("Missing key(s): 'buffer', 'l1.bias'."),
            ):
                out = functional_call(module, parameters, x, strict=True)

    @parametrize("functional_call", [
        subtest(torch.func.functional_call, "torch_func"),
        subtest(stateless.functional_call, "stateless")
    ])
    def test_reparametrize_special(self, functional_call):
        class NonTensor:
            def __repr__(self):
                return f'<{self.__class__.__name__}>'

        module = MockModule()
        weight = torch.tensor([[2.0]])
        bias = torch.tensor([5.0])
        buffer = torch.tensor([3.0])
        non_tensor = NonTensor()

        # Set to None
        parameters = {'l1.weight': weight,
                      'l1.bias': None,
                      'buffer': buffer}
        x = torch.randn(1, 1)
        with self._ensure_module_unchanged(
            module,
            'the module should not have been modified by a successful call',
        ):
            out = functional_call(module, parameters, x)
            self.assertEqual(out, x * weight + buffer)

        # Set a non-tensor
        parameters = {'l1.weight': non_tensor}
        x = torch.randn(1, 1)
        with self._ensure_module_unchanged(
            module,
            'the module should not have been modified by a failed call',
        ):
            with self.assertRaisesRegex(
                TypeError,
                re.escape("<NonTensor> is not an instance of torch.Tensor"),
            ):
                out = functional_call(module, parameters, x)

        # Set a non-tensor attribute
        parameters = {'l1.weight': weight, 'foo': torch.tensor([1.0])}
        x = torch.randn(1, 1)
        with self._ensure_module_unchanged(
            module,
            'the module should not have been modified by a failed call',
        ):
            with self.assertRaisesRegex(
                TypeError,
                re.escape("attribute `foo`: 0.0 is not an instance of torch.Tensor"),
            ):
                out = functional_call(module, parameters, x)

        # Set a nonexistent submodule
        parameters = {'l1.weight': weight,
                      'l2.bias': bias}
        x = torch.randn(1, 1)
        with self._ensure_module_unchanged(
            module,
            'the module should not have been modified by a failed call',
        ):
            with self.assertRaisesRegex(
                AttributeError,
                re.escape("MockModule has no attribute `l2`"),
            ):
                out = functional_call(module, parameters, x)

    @parametrize("functional_call", [
        subtest(torch.func.functional_call, "torch_func"),
        subtest(stateless.functional_call, "stateless")
    ])
    def test_tied_weights_warns(self, functional_call):
        module = MockModule()
        module.tied_bias = module.l1.bias
        module.tied_buffer = torch.nn.Buffer(module.buffer)
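
    # tie_weights=True (the default) propagates an override to every alias of
    # the same tensor: in MockTiedModule, 'l1.bias' and 'tied_bias' refer to
    # one tensor, as do 'buffer' and 'tied_buffer', which is why the expected
    # outputs below count each overridden value twice.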

    @parametrize("functional_call", [
        subtest(torch.func.functional_call, "torch_func"),
        subtest(stateless.functional_call, "stateless")
    ])
    def test_reparametrize_tie_weights(self, functional_call):
        module = MockTiedModule()
        weight = torch.tensor([[2.0]])
        bias = torch.tensor([5.0])
        buffer = torch.tensor([3.0])
        extra = torch.tensor([1.0])

        parameters = {'l1.weight': weight,
                      'l1.bias': bias,
                      'buffer': buffer}
        x = torch.randn(1, 1)
        out = functional_call(module, parameters, x, tie_weights=True)
        self.assertEqual(out, x * weight + bias + bias + buffer + buffer)

        parameters = {'l1.weight': weight,
                      'l1.bias': bias,
                      'buffer': buffer,
                      'extra': extra}
        x = torch.randn(1, 1)
        out = functional_call(module, parameters, x, tie_weights=True)
        self.assertEqual(out, x * weight + bias + bias + buffer + buffer)

    @parametrize("functional_call", [
        subtest(torch.func.functional_call, "torch_func"),
        subtest(stateless.functional_call, "stateless")
    ])
    def test_reparametrize_tie_some_weights(self, functional_call):
        module = MockTiedModule()
        weight = torch.tensor([[2.0]])
        buffer = torch.tensor([3.0])

        parameters = {'l1.weight': weight,
                      'buffer': buffer}
        x = torch.randn(1, 1)
        out = functional_call(module, parameters, x, tie_weights=True)
        self.assertEqual(out, x * 2. + module.l1.bias + module.tied_bias + buffer + buffer)

    @parametrize("functional_call", [
        subtest(torch.func.functional_call, "torch_func"),
        subtest(stateless._functional_call, "stateless")
    ])
    def test_tied_weights_errors(self, functional_call):
        module = MockTiedModule()
        weight = torch.tensor([[1.0]])
        bias = torch.tensor([0.0])
        buffer = torch.tensor([0.0])

        parameters = {'l1.weight': weight,
                      'l1.bias': bias,
                      'buffer': buffer}
        x = torch.randn(1, 1)
        self.assertNotWarn(lambda: functional_call(module, parameters, x, tie_weights=True))

        # If the tied values are the same tensors, the call shouldn't warn
        parameters['tied_bias'] = bias
        parameters['tied_buffer'] = buffer
        self.assertNotWarn(lambda: functional_call(module, parameters, x, tie_weights=True))
        del parameters['tied_bias']
        del parameters['tied_buffer']

        with self.assertRaisesRegex(
            ValueError,
            re.escape("functional_call got multiple values for keys ['l1.bias', 'tied_bias']"),
        ):
            parameters['tied_bias'] = torch.tensor([5.0])
            functional_call(module, parameters, x, tie_weights=True)
        del parameters['tied_bias']

        with self.assertRaisesRegex(
            ValueError,
            re.escape("functional_call got multiple values for keys ['buffer', 'tied_buffer']"),
        ):
            parameters['tied_buffer'] = torch.tensor([5.0])
            functional_call(module, parameters, x, tie_weights=True)

    def test_tied_weights_no_error_without_flag(self):
        module = MockTiedModule()
        weight = torch.tensor([[1.0]])
        bias = torch.tensor([0.0])
        buffer = torch.tensor([0.0])

        parameters = {'l1.weight': weight,
                      'l1.bias': bias,
                      'buffer': buffer}
        x = torch.randn(1, 1)
        self.assertNotWarn(lambda: stateless._functional_call(module, parameters, x, tie_weights=False))
        parameters['tied_bias'] = torch.tensor([5.0])
        self.assertNotWarn(lambda: stateless._functional_call(module, parameters, x, tie_weights=False))
        del parameters['tied_bias']
        parameters['tied_buffer'] = torch.tensor([5.0])
        self.assertNotWarn(lambda: stateless._functional_call(module, parameters, x, tie_weights=False))
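
    # With strict=True the tie handling becomes visible in the error messages:
    # under tie_weights=True one key can stand in for all of its aliases,
    # while tie_weights=False treats each untied alias as its own (possibly
    # missing) key, as the cases below check.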

    @parametrize("functional_call", [
        subtest(torch.func.functional_call, "torch_func"),
        subtest(stateless.functional_call, "stateless")
    ])
    def test_reparametrize_tie_weights_strict(self, functional_call):
        module = MockTiedModule()
        weight = torch.tensor([[2.0]])
        bias = torch.tensor([5.0])
        buffer = torch.tensor([3.0])
        extra = torch.tensor([1.0])

        # Tie weights, no error
        parameters = {'l1.weight': weight,
                      'l1.bias': bias,
                      'buffer': buffer}
        x = torch.randn(1, 1)
        with self._ensure_module_unchanged(
            module,
            'the module should not have been modified by a successful call',
        ):
            out = functional_call(module, parameters, x, tie_weights=True, strict=True)
            self.assertEqual(out, x * weight + bias + bias + buffer + buffer)

        # Tie weights without the flag
        parameters = {'l1.weight': weight,
                      'l1.bias': bias,
                      'buffer': buffer}
        x = torch.randn(1, 1)
        with self._ensure_module_unchanged(
            module,
            'the module should not have been modified by a failed call',
        ):
            with self.assertRaisesRegex(
                RuntimeError,
                re.escape("Missing key(s): 'tied_bias', 'tied_buffer'."),
            ):
                out = functional_call(module, parameters, x, tie_weights=False, strict=True)

        # Tie some weights
        parameters = {'l1.weight': weight,
                      'buffer': buffer}
        x = torch.randn(1, 1)
        with self._ensure_module_unchanged(
            module,
            'the module should not have been modified by a failed call',
        ):
            with self.assertRaisesRegex(
                RuntimeError,
                re.escape("Missing key(s): 'l1.bias', 'tied_bias'."),
            ):
                out = functional_call(module, parameters, x, tie_weights=True, strict=True)

        # Tie weights with extra keys
        parameters = {'l1.weight': weight,
                      'l1.bias': bias,
                      'buffer': buffer,
                      'extra': extra}
        x = torch.randn(1, 1)
        with self._ensure_module_unchanged(
            module,
            'the module should not have been modified by a failed call',
        ):
            with self.assertRaisesRegex(
                RuntimeError,
                re.escape("Unexpected key(s): 'extra'."),
            ):
                out = functional_call(module, parameters, x, tie_weights=True, strict=True)

        # Tie weights with extra keys and without the flag
        parameters = {'l1.weight': weight,
                      'l1.bias': bias,
                      'buffer': buffer,
                      'extra': extra}
        x = torch.randn(1, 1)
        with self._ensure_module_unchanged(
            module,
            'the module should not have been modified by a failed call',
        ):
            with self.assertRaisesRegex(
                RuntimeError,
                re.escape("Unexpected key(s): 'extra'.") + r'\s+' + re.escape("Missing key(s): 'tied_bias', 'tied_buffer'."),
            ):
                out = functional_call(module, parameters, x, tie_weights=False, strict=True)

        # Tie some weights with extra keys
        parameters = {'l1.weight': weight,
                      'buffer': buffer,
                      'extra': extra}
        x = torch.randn(1, 1)
        with self._ensure_module_unchanged(
            module,
            'the module should not have been modified by a failed call',
        ):
            with self.assertRaisesRegex(
                RuntimeError,
                re.escape("Unexpected key(s): 'extra'.") + r'\s+' + re.escape("Missing key(s): 'l1.bias', 'tied_bias'."),
            ):
                out = functional_call(module, parameters, x, tie_weights=True, strict=True)

    @parametrize("functional_call", [
        subtest(torch.func.functional_call, "torch_func"),
        subtest(stateless.functional_call, "stateless")
    ])
    def test_setattr(self, functional_call):
        class Foo(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.foo = torch.nn.Buffer(torch.tensor([0.0]))

            def forward(self, x):
                self.foo = self.foo + 1
                return x + self.foo

        foo = torch.tensor([2.0])
        x = torch.randn(1)
        a = {'foo': foo}
        mod = Foo()
        functional_call(mod, a, x)
        self.assertEqual(mod.foo, torch.tensor([0.0]))
        self.assertEqual(a['foo'], torch.tensor([3.0]))
        self.assertEqual(foo, torch.tensor([2.0]))
        self.assertTrue(a['foo'] is not foo)
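
    # The next test contrasts with test_setattr above: reassignment
    # (`self.foo = self.foo + 1`) rebinds the attribute, so the override dict
    # ends up holding a new tensor and the caller's tensor is untouched,
    # whereas the in-place `self.foo.add_(1)` below writes through the very
    # tensor that was passed in.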

    @parametrize("functional_call", [
        subtest(torch.func.functional_call, "torch_func"),
        subtest(stateless.functional_call, "stateless")
    ])
    def test_in_place_operator(self, functional_call):
        class Foo(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.foo = torch.nn.Buffer(torch.tensor([0.0]))

            def forward(self, x):
                self.foo.add_(1)
                return x + self.foo

        foo = torch.tensor([2.0])
        x = torch.randn(1)
        a = {'foo': foo}
        mod = Foo()
        functional_call(mod, a, x)
        self.assertEqual(mod.foo, torch.tensor([0.0]))
        self.assertEqual(a['foo'], torch.tensor([3.0]))
        self.assertEqual(foo, torch.tensor([3.0]))
        self.assertTrue(a['foo'] is foo)

    @parametrize("functional_call", [
        subtest(torch.func.functional_call, "torch_func"),
        subtest(stateless.functional_call, "stateless")
    ])
    def test_setattr_strict(self, functional_call):
        class Bar(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                assert not hasattr(self, 'extra')

            def forward(self, x):
                return x + self.extra

        a = {'extra': torch.zeros(())}
        mod = Bar()
        self.assertTrue(not hasattr(mod, 'extra'))
        out = functional_call(mod, a, torch.ones(()))
        self.assertEqual(out, torch.ones(()))
        self.assertTrue(not hasattr(mod, 'extra'))

        a = {'extra': torch.zeros(())}
        with self.assertRaisesRegex(
            RuntimeError,
            re.escape("Unexpected key(s): 'extra'."),
        ):
            out = functional_call(mod, a, torch.ones(()), strict=True)
        self.assertTrue(not hasattr(mod, 'extra'))

        a = {}
        with self.assertRaisesRegex(
            AttributeError,
            re.escape("'Bar' object has no attribute 'extra'"),
        ):
            out = functional_call(mod, a, torch.ones(()))
        self.assertTrue(not hasattr(mod, 'extra'))

        a = {}
        with self.assertRaisesRegex(
            AttributeError,
            re.escape("'Bar' object has no attribute 'extra'"),
        ):
            out = functional_call(mod, a, torch.ones(()), strict=True)
        self.assertTrue(not hasattr(mod, 'extra'))

    @parametrize("functional_call", [
        subtest(torch.func.functional_call, "torch_func"),
        subtest(stateless.functional_call, "stateless")
    ])
    def test_functional_call_with_kwargs(self, functional_call):
        class Foo(torch.nn.Module):
            def __init__(self, x):
                super().__init__()
                self.x = x

            def forward(self, inp, *, other_inp):
                return inp * self.x + other_inp

        a = {'x': torch.zeros(2, 3)}
        mod = Foo(torch.randn(2, 3))
        inp, other_inp = torch.randn(2, 3), torch.randn(2, 3)
        with self.assertRaisesRegex(TypeError, "missing 1 required keyword-only argument: 'other_inp'"):
            functional_call(mod, a, inp)
        res = functional_call(mod, a, inp, {'other_inp': other_inp})
        self.assertEqual(res, other_inp)
        res_1 = functional_call(mod, a, (), {'inp': inp, 'other_inp': other_inp})
        self.assertEqual(res, res_1)

    def test_functional_call_tuple_dicts(self):
        mod = MockModule()
        x = torch.rand((1, 1))
        parameters = {k: torch.ones_like(v) for k, v in mod.named_parameters()}
        buffers = {k: torch.zeros_like(v) for k, v in mod.named_buffers()}

        # Two dictionaries
        res = torch.func.functional_call(mod, (parameters, buffers), x)
        self.assertEqual(res, x + 1)

        # No dictionaries
        res = torch.func.functional_call(mod, (), x)
        self.assertEqual(res, mod(x))

        # Three dictionaries
        a = ({'l1.weight': torch.ones(1, 1)}, {'l1.bias': torch.ones(1)}, {'buffer': torch.zeros(1)})
        res = torch.func.functional_call(mod, a, x)
        self.assertEqual(res, x + 1)
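
    # torch.func.functional_call also accepts a tuple of dicts (exercised
    # above), merging them into a single override mapping; the next test
    # checks that a key appearing in more than one dict is rejected rather
    # than silently resolved.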

    def test_functional_call_multiple_dicts_error(self):
        mod = MockModule()
        x = torch.rand((1, 1))
        parameters = {'l1.weight': torch.zeros((1, 1)), 'l1.bias': torch.zeros((1, 1))}
        repeated_parameters = {'l1.weight': torch.ones((1, 1))}
        with self.assertRaisesRegex(
            ValueError,
            re.escape("['l1.weight'] appeared in multiple dictionaries"),
        ):
            torch.func.functional_call(mod, (parameters, repeated_parameters), x)

    @parametrize("functional_call", [
        subtest(torch.func.functional_call, "torch_func"),
        subtest(stateless.functional_call, "stateless")
    ])
    def test_functional_call_member_reference(self, functional_call):
        class Module(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.l1 = torch.nn.Linear(1, 1)
                self.buffer = torch.nn.Buffer(torch.ones(1))

            def forward(self, x):
                parameters = tuple(self.parameters())
                buffers = tuple(self.buffers())
                return self.l1(x) + self.buffer, parameters, buffers

        module = Module()
        weight = torch.tensor([[2.0]])
        bias = torch.tensor([5.0])
        buffer = torch.tensor([3.0])
        extra = torch.tensor([1.0])
        extra_p = torch.nn.Parameter(extra)

        # All weights
        parameters = {'l1.weight': weight,
                      'l1.bias': bias,
                      'buffer': buffer}
        x = torch.randn(1, 1)
        out, parameters, buffers = functional_call(module, parameters, x)
        self.assertEqual(out, x * weight + bias + buffer)
        self.assertEqual(parameters, (weight, bias))
        self.assertEqual(buffers, (buffer,))
        self.assertTrue(all(t1 is t2 for t1, t2 in zip(parameters, (weight, bias))))
        self.assertTrue(all(t1 is t2 for t1, t2 in zip(buffers, (buffer,))))

        # Some weights
        parameters = {'l1.weight': weight}
        x = torch.randn(1, 1)
        out, parameters, buffers = functional_call(module, parameters, x)
        self.assertEqual(out, x * weight + module.l1.bias + module.buffer)
        self.assertEqual(parameters, (weight, module.l1.bias))
        self.assertEqual(buffers, (module.buffer,))
        self.assertTrue(all(t1 is t2 for t1, t2 in zip(parameters, (weight, module.l1.bias))))
        self.assertTrue(all(t1 is t2 for t1, t2 in zip(buffers, (module.buffer,))))

        # All weights with extra keys
        parameters = {'l1.weight': weight,
                      'l1.bias': bias,
                      'buffer': buffer,
                      'l1.extra': extra}
        x = torch.randn(1, 1)
        out, parameters, buffers = functional_call(module, parameters, x)
        self.assertEqual(out, x * weight + bias + buffer)
        self.assertEqual(parameters, (weight, bias))
        self.assertEqual(buffers, (buffer,))
        self.assertTrue(all(t1 is t2 for t1, t2 in zip(parameters, (weight, bias))))
        self.assertTrue(all(t1 is t2 for t1, t2 in zip(buffers, (buffer,))))

        # All weights with extra keys that are parameters
        parameters = {'l1.weight': weight,
                      'l1.bias': bias,
                      'buffer': buffer,
                      'l1.extra': extra_p}
        x = torch.randn(1, 1)
        out, parameters, buffers = functional_call(module, parameters, x)
        self.assertEqual(out, x * weight + bias + buffer)
        self.assertEqual(parameters, (weight, bias, extra_p))
        self.assertEqual(buffers, (buffer,))
        self.assertTrue(all(t1 is t2 for t1, t2 in zip(parameters, (weight, bias, extra_p))))
        self.assertTrue(all(t1 is t2 for t1, t2 in zip(buffers, (buffer,))))
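
        # Note: only the torch.nn.Parameter override shows up in
        # tuple(self.parameters()) above, because substituting a plain tensor
        # under a nonexistent name does not register it as a parameter of the
        # module, while substituting an nn.Parameter does.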

        # Some weights with extra keys
        parameters = {'l1.weight': weight,
                      'l1.extra': extra}
        x = torch.randn(1, 1)
        out, parameters, buffers = functional_call(module, parameters, x)
        self.assertEqual(out, x * weight + module.l1.bias + module.buffer)
        self.assertEqual(parameters, (weight, module.l1.bias))
        self.assertEqual(buffers, (module.buffer,))
        self.assertTrue(all(t1 is t2 for t1, t2 in zip(parameters, (weight, module.l1.bias))))
        self.assertTrue(all(t1 is t2 for t1, t2 in zip(buffers, (module.buffer,))))

        # Some weights with extra keys that are parameters
        parameters = {'l1.weight': weight,
                      'l1.extra': extra_p}
        x = torch.randn(1, 1)
        out, parameters, buffers = functional_call(module, parameters, x)
        self.assertEqual(out, x * weight + module.l1.bias + module.buffer)
        self.assertEqual(parameters, (weight, module.l1.bias, extra_p))
        self.assertEqual(buffers, (module.buffer,))
        self.assertTrue(all(t1 is t2 for t1, t2 in zip(parameters, (weight, module.l1.bias, extra_p))))
        self.assertTrue(all(t1 is t2 for t1, t2 in zip(buffers, (module.buffer,))))

        # Set None
        parameters = {'l1.weight': weight,
                      'l1.bias': None}
        x = torch.randn(1, 1)
        out, parameters, buffers = functional_call(module, parameters, x)
        self.assertEqual(out, x * weight + module.buffer)
        self.assertEqual(parameters, (weight,))
        self.assertEqual(buffers, (module.buffer,))
        self.assertTrue(all(t1 is t2 for t1, t2 in zip(parameters, (weight,))))
        self.assertTrue(all(t1 is t2 for t1, t2 in zip(buffers, (module.buffer,))))


class TestStatelessDeprecation(TestCase):
    def test_private_stateless_warns(self):
        script = """
import torch
import warnings

with warnings.catch_warnings(record=True) as w:
    from torch.nn.utils import _stateless

exit(len(w))
"""
        try:
            subprocess.check_output(
                [sys.executable, '-W', 'always', '-c', script],
                stderr=subprocess.STDOUT,
                # On Windows, opening the subprocess with the default CWD makes `import torch`
                # fail, so just set CWD to this script's directory
                cwd=os.path.dirname(os.path.realpath(__file__)),)
        except subprocess.CalledProcessError as e:
            self.assertEqual(e.returncode, 1)
        else:
            self.fail("No warning was raised.")

    def test_stateless_functional_call_warns(self):
        m = torch.nn.Linear(1, 1)
        params = dict(m.named_parameters())
        x = torch.randn(3, 1)
        with self.assertWarnsRegex(FutureWarning, "Please use `torch.func.functional_call`"):
            stateless.functional_call(m, params, x)


class TestPythonOptimizeMode(TestCase):
    def test_runs_with_optimize_flag(self):
        script = "import torch; import torch._functorch.deprecated"
        try:
            subprocess.check_output(
                [sys.executable, "-OO", "-c", script],
                stderr=subprocess.STDOUT,
                # On Windows, opening the subprocess with the default CWD makes `import torch`
                # fail, so just set CWD to this script's directory
                cwd=os.path.dirname(os.path.realpath(__file__)),)
        except subprocess.CalledProcessError as e:
            self.assertFalse(e.returncode, "Import failed while running python in optimized mode")


instantiate_parametrized_tests(
    TestStatelessFunctionalAPI,
)

if __name__ == '__main__':
    run_tests()