# Owner(s): ["module: dynamo"]

import os
import unittest
from unittest.mock import patch

import torch
import torch._dynamo
import torch._dynamo.test_case
import torch._functorch._aot_autograd
from torch._dynamo import config as dynamo_config
from torch._dynamo.utils import counters
from torch._functorch import config as functorch_config
from torch._functorch._aot_autograd.autograd_cache import (
    AOTAutogradCache,
    autograd_cache_key,
    BypassAOTAutogradCache,
)
from torch._functorch._aot_autograd.schemas import AOTConfig
from torch._inductor import config as inductor_config
from torch._inductor.test_case import TestCase as InductorTestCase
from torch.testing._internal.common_cuda import SM80OrLater
from torch.testing._internal.common_device_type import largeTensorTest
from torch.testing._internal.common_utils import (
    instantiate_parametrized_tests,
    parametrize,
    skipIfWindows,
)
from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_GPU


@instantiate_parametrized_tests
class AOTAutogradCacheTests(InductorTestCase):
    def setUp(self):
        """
        Reset all counters and caches before each unit test
        """
        super().setUp()
        counters.clear()
        self._clear_all_caches()

    def _clear_all_caches(self):
        """
        Clear every cache, including AOTAutogradCache and FxGraphCache
        """
        torch._inductor.codecache.FxGraphCache.clear()
        AOTAutogradCache.clear()
        self._clear_dynamo_and_codecache()

    def _clear_dynamo_and_codecache(self):
        """
        Clear the caches that are not under test: dynamo and PyCodeCache
        """
        torch._dynamo.reset()
        for m in torch._inductor.codecache.PyCodeCache.cache.values():
            os.remove(m.__file__)
        torch._inductor.codecache.PyCodeCache.cache_clear()

    @inductor_config.patch("fx_graph_remote_cache", False)
    @inductor_config.patch("fx_graph_cache", True)
    @functorch_config.patch({"enable_autograd_cache": True})
    def test_basic(self):
        """
        Verify the interactions between FxGraphCache and AOTAutogradCache.
        """

        def fn(x, y):
            return (x * 2, y @ y)

        a = torch.rand(25)
        b = torch.rand(5, 5)

        compiled_fn = torch.compile(fn, backend="inductor")

        # A first call should miss in the cache.
        self.assertEqual(fn(a, b), compiled_fn(a, b))
        self.assertEqual(counters["aot_autograd"]["autograd_cache_miss"], 1)
        self.assertEqual(counters["aot_autograd"]["autograd_cache_hit"], 0)
        self.assertEqual(counters["aot_autograd"]["autograd_cache_saved"], 1)

        # A second call should hit. (First reset so in-memory guards
        # don't prevent compilation).
        self._clear_dynamo_and_codecache()
        self.assertEqual(fn(a, b), compiled_fn(a, b))

        self.assertEqual(counters["aot_autograd"]["autograd_cache_miss"], 1)
        self.assertEqual(counters["aot_autograd"]["autograd_cache_hit"], 1)
        self.assertEqual(counters["aot_autograd"]["autograd_cache_saved"], 1)

    @inductor_config.patch("fx_graph_remote_cache", False)
    @inductor_config.patch("fx_graph_cache", True)
    @functorch_config.patch({"enable_autograd_cache": True})
    @skipIfWindows(
        msg="Known issue: Windows can't delete loaded modules, so we can't clear the module cache."
    )
    def test_clear_fx_graph_cache(self):
        """
        Verify that clearing the FxGraphCache forces new AOTAutogradCache misses and saves.
        """

        def fn(x, y):
            return (x * 2, y @ y)

        a = torch.rand(25)
        b = torch.rand(5, 5)

        compiled_fn = torch.compile(fn, backend="inductor")

        # A first call should miss in the cache.
        self.assertEqual(fn(a, b), compiled_fn(a, b))
        self.assertEqual(counters["aot_autograd"]["autograd_cache_miss"], 1)
        self.assertEqual(counters["aot_autograd"]["autograd_cache_hit"], 0)
        self.assertEqual(counters["aot_autograd"]["autograd_cache_saved"], 1)

        # Clear FX graph cache: second call should also be a miss
        self._clear_dynamo_and_codecache()
        torch._inductor.codecache.FxGraphCache.clear()
        self.assertEqual(fn(a, b), compiled_fn(a, b))
        self.assertEqual(counters["aot_autograd"]["autograd_cache_miss"], 2)
        self.assertEqual(counters["aot_autograd"]["autograd_cache_hit"], 0)
        # We save again into the cache
        self.assertEqual(counters["aot_autograd"]["autograd_cache_saved"], 2)

    @inductor_config.patch("fx_graph_remote_cache", False)
    @inductor_config.patch("fx_graph_cache", False)
    @functorch_config.patch({"enable_autograd_cache": True})
    def test_fx_graph_cache_off(self):
        """
        AOTAutogradCache should bypass if FxGraphCache is not enabled
        """

        def fn(x, y):
            return (x * 2, y @ y)

        a = torch.rand(25)
        b = torch.rand(5, 5)

        compiled_fn = torch.compile(fn, backend="inductor")

        # A first call should bypass the cache.
        self.assertEqual(fn(a, b), compiled_fn(a, b))
        self.assertEqual(counters["aot_autograd"]["autograd_cache_bypass"], 1)
        self.assertEqual(counters["aot_autograd"]["autograd_cache_hit"], 0)
        self.assertEqual(counters["aot_autograd"]["autograd_cache_saved"], 0)

        # A second call should also bypass, since FxGraphCache is disabled
        self._clear_dynamo_and_codecache()

        self.assertEqual(fn(a, b), compiled_fn(a, b))
        self.assertEqual(counters["aot_autograd"]["autograd_cache_bypass"], 2)
        self.assertEqual(counters["aot_autograd"]["autograd_cache_hit"], 0)
        self.assertEqual(counters["aot_autograd"]["autograd_cache_saved"], 0)

    @inductor_config.patch("fx_graph_remote_cache", False)
    @inductor_config.patch("fx_graph_cache", True)
    @functorch_config.patch({"enable_autograd_cache": True})
    @dynamo_config.patch("compiled_autograd", True)
    def test_compiled_autograd_bypass(self):
        def fn(a, b):
            out = a.cos() + b
            loss = out.sum()
            ga, gb = torch.autograd.grad(loss, inputs=[a, b])

        a = torch.randn(25, requires_grad=True)
        b = torch.randn(25, requires_grad=True)
        a2 = a.detach().clone().requires_grad_(True)
        b2 = b.detach().clone().requires_grad_(True)
        compiled_fn = torch.compile(fn, backend="inductor")
        # Note: fn returns None, so this equality just checks that both calls run;
        # the cache counters below carry the real assertions.
        self.assertEqual(fn(a, b), compiled_fn(a2, b2))
        self.assertEqual(
            counters["aot_autograd"]["autograd_cache_miss"], 1
        )  # from compiled forward
        self.assertEqual(
            counters["aot_autograd"]["autograd_cache_bypass"], 1
        )  # from compiled autograd

    @inductor_config.patch("fx_graph_remote_cache", False)
    @inductor_config.patch("fx_graph_cache", True)
    @functorch_config.patch({"enable_autograd_cache": True})
    @dynamo_config.patch("compiled_autograd", True)
    def test_inference_graph_cache_hit_with_compiled_autograd_enabled(self):
        def fn(a, b):
            out = a.cos() + b
            return out.sum()

        a = torch.randn(25)
        b = torch.randn(25)
        compiled_fn = torch.compile(fn, backend="inductor")
        self.assertEqual(fn(a, b), compiled_fn(a, b))
        self.assertEqual(counters["aot_autograd"]["autograd_cache_miss"], 1)
        self.assertEqual(counters["aot_autograd"]["autograd_cache_saved"], 1)

        # Clear dynamo and run again. Should be a cache hit.
        counters.clear()
        self._clear_dynamo_and_codecache()
        self.assertEqual(fn(a, b), compiled_fn(a, b))
        self.assertEqual(counters["aot_autograd"]["autograd_cache_miss"], 0)
        self.assertEqual(counters["aot_autograd"]["autograd_cache_hit"], 1)
        self.assertEqual(counters["aot_autograd"]["autograd_cache_saved"], 0)

    @inductor_config.patch("fx_graph_remote_cache", False)
    @inductor_config.patch({"fx_graph_cache": True})
    @functorch_config.patch({"enable_autograd_cache": True})
    def test_autograd_lazy_backward(self):
        """
        Lazily compile the backward, and lazily save to cache
        """

        def fn(a, b):
            return a.cos() + b

        a = torch.randn(25, requires_grad=True)
        b = torch.randn(25, requires_grad=True)
        a2 = a.detach().clone().requires_grad_(True)
        b2 = b.detach().clone().requires_grad_(True)
        compiled_fn = torch.compile(fn, backend="inductor")
        self.assertEqual(fn(a, b), compiled_fn(a2, b2))
        self.assertEqual(counters["aot_autograd"]["autograd_cache_miss"], 1)
        self.assertEqual(counters["aot_autograd"]["autograd_cache_hit"], 0)
        self.assertEqual(counters["aot_autograd"]["autograd_cache_saved"], 0)

        # Clear dynamo and run again. Should still be a cache miss, because the backward hasn't run yet
        self._clear_dynamo_and_codecache()
        self.assertEqual(fn(a, b), compiled_fn(a2, b2))
        self.assertEqual(counters["aot_autograd"]["autograd_cache_miss"], 2)
        self.assertEqual(counters["aot_autograd"]["autograd_cache_hit"], 0)
        self.assertEqual(counters["aot_autograd"]["autograd_cache_saved"], 0)

        # Now let's run the backward
        fn(a, b).sum().backward()
        compiled_fn(a2, b2).sum().backward()
        self.assertEqual(a.grad, a2.grad)
        self.assertEqual(b.grad, b2.grad)
        self.assertEqual(counters["aot_autograd"]["autograd_cache_saved"], 1)

        # Clear dynamo and rerun everything; now there should be a cache hit
        self._clear_dynamo_and_codecache()
        a = torch.randn(25, requires_grad=True)
        b = torch.randn(25, requires_grad=True)
        a2 = a.detach().clone().requires_grad_(True)
        b2 = b.detach().clone().requires_grad_(True)
        self.assertEqual(fn(a, b), compiled_fn(a2, b2))
        self.assertEqual(counters["aot_autograd"]["autograd_cache_miss"], 2)
        self.assertEqual(counters["aot_autograd"]["autograd_cache_hit"], 1)
        self.assertEqual(counters["aot_autograd"]["autograd_cache_saved"], 1)
        fn(a, b).sum().backward()
        compiled_fn(a2, b2).sum().backward()
        self.assertEqual(a.grad, a2.grad)
        self.assertEqual(b.grad, b2.grad)

    @inductor_config.patch("fx_graph_remote_cache", False)
    @inductor_config.patch("fx_graph_cache", True)
    @functorch_config.patch({"enable_autograd_cache": True})
    def test_autograd_function(self):
        """
        Tests autograd cache hits
        """

        def fn(a, b):
            return a.sin() + b

        a = torch.randn(25, requires_grad=True)
        b = torch.randn(25, requires_grad=True)
        a2 = a.detach().clone().requires_grad_(True)
        b2 = b.detach().clone().requires_grad_(True)

        compiled_fn = torch.compile(fn, backend="inductor")

        # A first call should miss in the cache.
        self.assertEqual(fn(a, b), compiled_fn(a2, b2))
        fn(a, b).sum().backward()
        compiled_fn(a2, b2).sum().backward()
        self.assertEqual(a.grad, a2.grad)
        self.assertEqual(b.grad, b2.grad)

        self.assertEqual(counters["aot_autograd"]["autograd_cache_miss"], 1)
        self.assertEqual(counters["aot_autograd"]["autograd_cache_hit"], 0)
        self.assertEqual(counters["aot_autograd"]["autograd_cache_saved"], 1)

        # Reset all tensors
        a = torch.randn(25, requires_grad=True)
        b = torch.randn(25, requires_grad=True)
        a2 = a.detach().clone().requires_grad_(True)
        b2 = b.detach().clone().requires_grad_(True)

        # A second call should hit. (First reset so in-memory guards
        # don't prevent compilation).
        self._clear_dynamo_and_codecache()
        self.assertEqual(fn(a, b), compiled_fn(a2, b2))
        fn(a, b).sum().backward()
        compiled_fn(a2, b2).sum().backward()
        self.assertEqual(a.grad, a2.grad)
        self.assertEqual(b.grad, b2.grad)

        self.assertEqual(counters["aot_autograd"]["autograd_cache_miss"], 1)
        self.assertEqual(counters["aot_autograd"]["autograd_cache_hit"], 1)
        self.assertEqual(counters["aot_autograd"]["autograd_cache_saved"], 1)

    @largeTensorTest("64GB", device=GPU_TYPE)
    @parametrize("device", (GPU_TYPE,))
    @parametrize("dtype", (torch.float16, torch.bfloat16))
    @inductor_config.patch("fx_graph_cache", True)
    @inductor_config.patch("fx_graph_remote_cache", False)
    @functorch_config.patch({"enable_autograd_cache": True})
    def test_autograd_guard_single_entry(self, device, dtype):
        """
        Test caching the same graph, but under conditions that introduce guards
        on whether tensor sizes are < int32. See test_codecache::TestFxGraphCache::test_cache_load_with_guards_int32_bounds.

        This test in particular exercises the behavior of a single-entry cache. If we ever make AOTAutogradCache
        support multiple entries under the same key, this test should be updated.
        """
        if device == GPU_TYPE and not HAS_GPU:
            raise unittest.SkipTest(f"requires {GPU_TYPE}")
        if device == "cuda" and dtype == torch.bfloat16 and not SM80OrLater:
            raise unittest.SkipTest("requires CUDA SM80 or later")

        def fn(x, y):
            return (x + x, y + y)

        def expect_miss(compiled_fn, a, b):
            self._clear_dynamo_and_codecache()
            counters.clear()
            res = compiled_fn(a, b)
            self.assertEqual(counters["aot_autograd"]["autograd_cache_miss"], 1)
            self.assertEqual(
                counters["aot_autograd"]["autograd_cache_guard_miss"],
                0,
            )
            self.assertEqual(counters["aot_autograd"]["autograd_cache_hit"], 0)
            return res

        def expect_hit(compiled_fn, a, b):
            self._clear_dynamo_and_codecache()
            counters.clear()
            res = compiled_fn(a, b)
            self.assertEqual(counters["aot_autograd"]["autograd_cache_miss"], 0)
            self.assertEqual(
                counters["aot_autograd"]["autograd_cache_guard_miss"],
                0,
            )
            self.assertEqual(
                counters["aot_autograd"]["autograd_cache_hit"],
                1,
            )
            return res

        def expect_guard_miss(compiled_fn, a, b):
            self._clear_dynamo_and_codecache()
            counters.clear()
            res = compiled_fn(a, b)
            self.assertEqual(counters["aot_autograd"]["autograd_cache_miss"], 1)
            self.assertEqual(
                counters["aot_autograd"]["autograd_cache_guard_miss"],
                1,
            )
            self.assertEqual(
                counters["aot_autograd"]["autograd_cache_hit"],
                0,
            )
            return res

        compiled_fn = torch.compile(fn, dynamic=True)

        a_shape = (5, 6)
        b_shape = (7, 8)
        a = torch.rand(a_shape, device=device, dtype=dtype)
        b = torch.rand(b_shape, device=device, dtype=dtype)
        res1 = expect_miss(compiled_fn, a, b)

        # Same shape, should cache hit
        a2 = a.detach().clone()
        b2 = b.detach().clone()

        res2 = expect_hit(compiled_fn, a2, b2)

        self.assertEqual(res1, res2)

        # By changing the shape greatly, despite the exact same input
        # graph, inductor should report a guard miss, leading
        # to a cache miss on our end.
        a_shape = (5, 6)
        b_shape = (47000, 47001)
        a3 = torch.rand(a_shape, device=device, dtype=dtype)
        b3 = torch.rand(b_shape, device=device, dtype=dtype)

        expect_guard_miss(compiled_fn, a3, b3)

        # Wobble the shape a bit, but not enough
        # to trigger a guard miss (6 and 7 are still well below the int32 bound).
        # Should result in a cache hit
        a_shape = (6, 7)
        b_shape = (47000, 47001)
        a4 = torch.rand(a_shape, device=device, dtype=dtype)
        b4 = torch.rand(b_shape, device=device, dtype=dtype)
        expect_hit(compiled_fn, a4, b4)

        # Change the shape back to the original;
        # FxGraphCache should hit because it stores
        # multiple entries
        a_shape = (5, 6)
        b_shape = (7, 8)
        a5 = torch.rand(a_shape, device=device, dtype=dtype)
        b5 = torch.rand(b_shape, device=device, dtype=dtype)
        expect_hit(compiled_fn, a5, b5)

    @largeTensorTest("64GB", device=GPU_TYPE)
    @parametrize("device", (GPU_TYPE,))
    @parametrize("dtype", (torch.float16, torch.bfloat16))
    @parametrize("requires_grad", (True, False))
    @inductor_config.patch("fx_graph_cache", True)
    @inductor_config.patch("fx_graph_remote_cache", False)
    @functorch_config.patch({"enable_autograd_cache": True})
    def test_autograd_inductor_guards(self, device, dtype, requires_grad):
        """
        Test caching the same graph, but under conditions that introduce guards
        on whether tensor sizes are < int32.
        See test_codecache::TestFxGraphCache::test_cache_load_with_guards_int32_bounds.
        """
        if device == GPU_TYPE and not HAS_GPU:
            raise unittest.SkipTest(f"requires {GPU_TYPE}")
        if device == "cuda" and dtype == torch.bfloat16 and not SM80OrLater:
            raise unittest.SkipTest("requires CUDA SM80 or later")

        def fn(x, y):
            return (x + x, y + y)

        compiled_fn = torch.compile(fn, dynamic=True)

        # Iterate over different shapes, varying whether the total
        # size is below or above the int32 limit. For each combination, we expect
        # different guards around whether the symbolic sizes do or do
        # not exceed int32.
        shapes = (
            ((5, 6), (7, 8)),
            ((5, 6), (47000, 47001)),
            ((47000, 47001), (5, 6)),
        )
        expected_hits = expected_misses = expected_saves = 0
        expected_guard_misses = 0
        for a_shape, b_shape in shapes:
            a = torch.rand(
                a_shape, device=device, dtype=dtype, requires_grad=requires_grad
            )
            b = torch.rand(
                b_shape, device=device, dtype=dtype, requires_grad=requires_grad
            )

            # AVOID a dynamo reset here. We expect guards to have been
            # added that will be violated with the new shape. We should
            # see a recompilation (along with a cache miss).
            res1 = compiled_fn(a, b)
            # A first call should miss in the cache.
            expected_misses += 1
            self.assertEqual(
                counters["aot_autograd"]["autograd_cache_miss"], expected_misses
            )
            self.assertEqual(
                counters["aot_autograd"]["autograd_cache_guard_miss"],
                expected_guard_misses,
            )

            self.assertEqual(
                counters["aot_autograd"]["autograd_cache_hit"], expected_hits
            )
            # Because dynamic shapes are enabled, we expect the backward to be
            # compiled ahead of time, so we should see a cache save here
            expected_saves += 1
            self.assertEqual(
                counters["aot_autograd"]["autograd_cache_saved"], expected_saves
            )
            if requires_grad:
                res1[0].sum().backward()
                # No extra saves
                self.assertEqual(
                    counters["aot_autograd"]["autograd_cache_saved"], expected_saves
                )

            a2 = a.detach().clone().requires_grad_(requires_grad)
            b2 = b.detach().clone().requires_grad_(requires_grad)
            # A second call should hit. (First reset so in-memory guards
            # don't prevent compilation).

            # Now clear dynamo and we should see a cache hit.
            # This should populate guards into dynamo's cache, so that a subsequent run with a different
            # shape will still trigger a second call to autograd_cache.
            self._clear_dynamo_and_codecache()
            res2 = compiled_fn(a2, b2)
            expected_hits += 1
            self.assertEqual(
                counters["aot_autograd"]["autograd_cache_miss"], expected_misses
            )
            self.assertEqual(
                counters["aot_autograd"]["autograd_cache_guard_miss"],
                expected_guard_misses,
            )
            # First compile is a regular cache miss, subsequent are guard misses
            expected_guard_misses += 1
            self.assertEqual(
                counters["aot_autograd"]["autograd_cache_hit"], expected_hits
            )
            self.assertEqual(
                counters["aot_autograd"]["autograd_cache_saved"], expected_saves
            )
            self.assertEqual(res1, res2)
            if requires_grad:
                res2[0].sum().backward()
                self.assertEqual(a.grad, a2.grad)

    @inductor_config.patch("fx_graph_cache", True)
    @inductor_config.patch("fx_graph_remote_cache", False)
    @functorch_config.patch({"enable_autograd_cache": True})
    def test_nn_module_with_params_global_constant(self):
        class MyMod(torch.nn.Module):
            CONSTANT = torch.tensor([[2, 2], [2, 2]])

            def __init__(self) -> None:
                super().__init__()
                self.param = torch.nn.Parameter(torch.randn([2, 2]))

            def forward(self, x):
                return x.sin() + self.param + MyMod.CONSTANT

        with torch.no_grad():
            compiled_fn = torch.compile(MyMod(), backend="inductor", fullgraph=True)
            res1 = compiled_fn(torch.ones([2, 2]))
            self.assertEqual(counters["aot_autograd"]["autograd_cache_miss"], 1)
            self.assertEqual(counters["aot_autograd"]["autograd_cache_hit"], 0)
            self.assertEqual(counters["aot_autograd"]["autograd_cache_saved"], 1)

            self._clear_dynamo_and_codecache()
            res2 = compiled_fn(torch.ones([2, 2]))
            self.assertEqual(counters["aot_autograd"]["autograd_cache_miss"], 1)
            self.assertEqual(counters["aot_autograd"]["autograd_cache_hit"], 1)
            self.assertEqual(counters["aot_autograd"]["autograd_cache_saved"], 1)

            self.assertEqual(res1, res2)
            # Edit the "constant". We'll get a cache hit,
            # but it should produce a different result when run,
            # because MyMod.CONSTANT is an input to the graph
            MyMod.CONSTANT = torch.tensor([[3, 3], [3, 3]])
            self._clear_dynamo_and_codecache()
            res3 = compiled_fn(torch.ones([2, 2]))
            self.assertEqual(counters["aot_autograd"]["autograd_cache_miss"], 1)
            self.assertEqual(counters["aot_autograd"]["autograd_cache_hit"], 2)
            self.assertEqual(counters["aot_autograd"]["autograd_cache_saved"], 1)
            self.assertNotEqual(res1, res3)
            self.assertEqual(res1, res3.sub(torch.ones(2, 2)))


@inductor_config.patch("fx_graph_cache", True)
class AOTAutogradCachePicklerTests(torch._dynamo.test_case.TestCase):
    @property
    def device_type(self) -> str:
        return "cuda" if torch.cuda.is_available() else "cpu"

    def default_config(self):
        return AOTConfig(
            fw_compiler=None,
            bw_compiler=None,
            inference_compiler=None,
            partition_fn=None,
            decompositions={},
            num_params_buffers=0,
            aot_id=0,
            keep_inference_input_mutations=False,
            dynamic_shapes=True,
            aot_autograd_arg_pos_to_source=None,
            is_export=False,
            no_tangents=False,
            enable_log=False,
        )

    def _get_dynamo_output(self, fn, *args, **kwargs):
        # Reset dynamo between runs
        torch._dynamo.reset()
        fx_graph = None
        example_inputs = None

        def compiler(gm, inputs, **kwargs):
            nonlocal fx_graph
            nonlocal example_inputs
            fx_graph = gm
            example_inputs = inputs
            return gm

        g = torch.compile(fn, backend=compiler, fullgraph=True)
        result = g(*args, **kwargs)
        return (result, fx_graph, example_inputs)

    def gen_cache_key(self, f, config, inputs=None):
        if inputs is None:
            inputs = [torch.ones(3)]
        _, fx_g, example_inputs = self._get_dynamo_output(f, *inputs)
        return autograd_cache_key(fx_g, example_inputs, config, {})

    def test_basic_hash_key(self):
        def fn(x):
            return x.sin().cos()

        config = self.default_config()
        # Check that the hash is stable across multiple runs
        c1 = self.gen_cache_key(fn, config)
        c2 = self.gen_cache_key(fn, config)
        self.assertEqual(c1, c2)

    def test_identical_graphs_and_configs(self):
        def fn(x):
            return x.sin().cos()

        def fn2(x):
            y = x.sin()
            z = y.cos()
            return z

        # Make the aot_id different, but otherwise identical
        config = self.default_config()
        config2 = self.default_config()
        config2.aot_id = 1

        c1 = self.gen_cache_key(fn, config)
        c2 = self.gen_cache_key(fn, config2)
        self.assertEqual(c1, c2)

    def test_different_graphs(self):
        def fn(x):
            return x.cos().sin()

        def fn2(x):
            return x.sin().cos()

        config = self.default_config()
        c1 = self.gen_cache_key(fn, config)
        c2 = self.gen_cache_key(fn2, config)
        self.assertNotEqual(c1, c2)

    def test_different_configs(self):
        def fn(x):
            return x.cos().sin()

        config = self.default_config()
        config2 = self.default_config()
        config2.dynamic_shapes = False
        c1 = self.gen_cache_key(fn, config)
        c2 = self.gen_cache_key(fn, config2)
        self.assertNotEqual(c1, c2)

    def test_different_inputs(self):
        def fn(x):
            return x.cos().sin()

        config = self.default_config()
        c1 = self.gen_cache_key(fn, config, inputs=[torch.ones(3)])
        c2 = self.gen_cache_key(fn, config, inputs=[torch.ones(2)])
        self.assertNotEqual(c1, c2)

    def test_different_global_configs(self):
        def fn(x):
            return x.cos().sin()

        config = self.default_config()

        c1 = self.gen_cache_key(fn, config)
        c2 = self.gen_cache_key(fn, config)
        self.assertEqual(c1, c2)

        c1 = self.gen_cache_key(fn, config)

        # Change a functorch config
        with functorch_config.patch(
            {"debug_assert": not functorch_config.debug_assert}
        ):
            c2 = self.gen_cache_key(fn, config)

        self.assertNotEqual(c1, c2)

        c1 = self.gen_cache_key(fn, config)
        # Change an inductor config
        with inductor_config.patch({"debug": not inductor_config.debug}):
            c2 = self.gen_cache_key(fn, config)

        self.assertNotEqual(c1, c2)

        c1 = self.gen_cache_key(fn, config)
        # Change whether grad mode is enabled
        with torch.no_grad():
            c2 = self.gen_cache_key(fn, config)
        self.assertNotEqual(c1, c2)

    def test_incompatible_function(self):
        @torch._dynamo.allow_in_graph
        class AllowInGraphFunc(torch.autograd.Function):
            @staticmethod
            def forward(_, x):
                torch._dynamo.graph_break()
                return x.sin()

        def fn(x):
            return AllowInGraphFunc.apply(x)

        config = self.default_config()
        self.assertRaises(
            BypassAOTAutogradCache, lambda: self.gen_cache_key(fn, config)
        )

    def test_private_namespace(self):
        # TODO: anyone who monkeypatches a **public** function into the torch namespace with @allow_in_graph
        # could still break our sanity check and cache something bad. But that's an edge case we'll take the risk on.
        # Monkeypatch some random private function into torch and check that cache key generation fails.
        @torch._dynamo.allow_in_graph
        def my_private_fun(x):
            return x.sin()

        with patch("torch._my_priv", new=my_private_fun, create=True):

            def fn(x):
                return torch._my_priv(x)

            config = self.default_config()
            self.assertRaises(
                BypassAOTAutogradCache, lambda: self.gen_cache_key(fn, config)
            )

    def test_private_builtin(self):
        # _foreach_add is a private torch function, but
        # it's also a builtin_function_or_method, so it should be allowed to be cached,
        # since dynamo allows it in the graph
        def fn(x, b):
            y = (x, x)
            return torch._foreach_add(y, b)

        config = self.default_config()
        r1 = self.gen_cache_key(fn, config, inputs=[torch.ones(3), 1])
        r2 = self.gen_cache_key(fn, config, inputs=[torch.ones(3), 2])
        self.assertNotEqual(r1, r2)

    def test_nn_module_with_params(self):
        class MyMod(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.seq = torch.nn.Parameter(torch.ones((3, 3)))

            def forward(self, x):
                return self.seq + x

        config = self.default_config()
        # Different inputs and parameters, but all the same size
        c1 = self.gen_cache_key(MyMod(), config, inputs=[torch.ones((3, 3))])
        c2 = self.gen_cache_key(MyMod(), config, inputs=[torch.ones((3, 3))])
        self.assertEqual(c1, c2)

    def test_normal_torch_function(self):
        @torch._dynamo.allow_in_graph
        def fn(x):
            y = torch.sin(x)
            z = torch.cos(x)
            w = y + z
            w.abs()
            return w

        config = self.default_config()
        self.gen_cache_key(fn, config)


if __name__ == "__main__":
    from torch._dynamo.test_case import run_tests

    run_tests()