1# Owner(s): ["module: meta tensors"] 2 3 4import contextlib 5import copy 6import dataclasses 7import inspect 8import itertools 9import pickle 10import unittest 11import weakref 12from unittest.mock import patch 13 14import numpy as np 15import torch 16import torch._dynamo 17import torch._functorch.config 18import torch._prims as prims 19import torch.testing._internal.optests as optests 20import torch.utils._pytree as pytree 21 22from torch import distributed as dist 23from torch._C._functorch import _add_batch_dim, get_unwrapped, is_batchedtensor 24from torch._dynamo.testing import make_test_cls_with_patches, rand_strided 25from torch._guards import tracing, TracingContext 26from torch._subclasses.fake_tensor import ( 27 DynamicOutputShapeException, 28 extract_tensor_metadata, 29 FakeTensor, 30 FakeTensorConverter, 31 FakeTensorMode, 32 unset_fake_temporarily, 33 UnsupportedOperatorException, 34 _CacheKeyState 35) 36from torch.fx.experimental.proxy_tensor import make_fx 37from torch.fx.experimental.symbolic_shapes import ( 38 DimDynamic, 39 free_symbols, 40 ShapeEnv, 41 ShapeEnvSettings, 42 StatelessSymbolicContext, 43 statically_known_true, 44) 45from torch.fx.passes.fake_tensor_prop import FakeTensorProp 46from torch.testing import FileCheck 47from torch.testing._internal.common_cuda import PLATFORM_SUPPORTS_FLASH_ATTENTION 48from torch.testing._internal.common_device_type import ( 49 instantiate_device_type_tests, 50 OpDTypes, 51 ops, 52) 53from torch.testing._internal.common_utils import ( 54 instantiate_parametrized_tests, 55 parametrize, 56 run_tests, 57 skipIfCrossRef, 58 skipIfRocm, 59 skipIfTorchDynamo, 60 TemporaryFileName, 61 TEST_WITH_TORCHDYNAMO, 62 TestCase, 63) 64 65from torch.testing._internal.inductor_utils import GPU_TYPE 66from torch.testing._internal.custom_op_db import custom_op_db 67from torch.testing._internal.jit_utils import RUN_CUDA 68from torch.utils._mode_utils import no_dispatch 69from torch.utils._python_dispatch import TorchDispatchMode 70 71aten = torch.ops.aten 72 73torch._dynamo.config.fake_tensor_cache_enabled = True 74torch._dynamo.config.fake_tensor_cache_crosscheck_enabled = True 75 76 77def expectedFailurePropagateRealTensors(fn): 78 fn._expected_failure_propagate_real_tensors = True 79 return fn 80 81 82class FakeTensorTest(TestCase): 83 def checkType(self, t, device_str, size): 84 self.assertTrue(isinstance(t, FakeTensor)) 85 self.assertEqual(t.device.type, device_str) 86 self.assertEqual(list(t.size()), size) 87 88 @unittest.skipIf(not RUN_CUDA, "requires cuda") 89 def test_cuda_initialized(self): 90 # doesnt error 91 with FakeTensorMode(): 92 p = torch.randn(4, 2, requires_grad=True, device="cuda") 93 x = torch.randn(8, 4, device="cuda") 94 y = torch.mm(x, p).square().sum() 95 y.backward() 96 97 def test_basic(self): 98 x = torch.empty(2, 2, device="cpu") 99 y = torch.empty(4, 2, 2, device="cpu") 100 with FakeTensorMode() as mode: 101 x = mode.from_tensor(x) 102 y = mode.from_tensor(y) 103 z = x + y 104 self.assertEqual(z.shape, (4, 2, 2)) 105 self.assertEqual(z.device, torch.device("cpu")) 106 self.assertTrue(isinstance(z, FakeTensor)) 107 108 def test_custom_op_fallback(self): 109 from torch.library import impl, Library 110 111 try: 112 test_lib = Library("my_test_op", "DEF") # noqa: TOR901 113 test_lib.define("foo(Tensor self) -> Tensor") 114 115 @impl(test_lib, "foo", "CPU") 116 def foo_impl(self): 117 return self.cos() 118 119 x = torch.empty(2, 2, device="cpu") 120 with self.assertRaisesRegex( 121 UnsupportedOperatorException, "my_test_op.foo.default" 122 ): 123 with FakeTensorMode(allow_fallback_kernels=True) as mode: 124 x = mode.from_tensor(x) 125 torch.ops.my_test_op.foo(x) 126 127 finally: 128 test_lib._destroy() 129 130 def test_parameter_instantiation(self): 131 with FakeTensorMode(): 132 x = torch.rand([4]) 133 y = torch.nn.parameter.Parameter(x) 134 self.assertTrue(isinstance(y, torch.nn.Parameter)) 135 136 @unittest.skipIf(not dist.is_available(), "requires distributed") 137 def test_fsdp_flat_param(self): 138 from torch.distributed.fsdp._flat_param import FlatParameter 139 140 with FakeTensorMode() as m: 141 data = torch.randn(2, 2) 142 param = FlatParameter(data, requires_grad=True) 143 self.assertIsInstance(param, FlatParameter) 144 self.assertIsInstance(param, torch.nn.Parameter) 145 self.assertIsInstance(param, FakeTensor) 146 147 def test_non_parameter_grad(self): 148 mode = FakeTensorMode() 149 t = torch.rand([4], requires_grad=True) 150 fake_t = mode.from_tensor(t) 151 self.assertEqual(fake_t.requires_grad, t.requires_grad) 152 153 @unittest.skipIf( 154 TEST_WITH_TORCHDYNAMO, "isinstance check for FakeTensor won't work with compile" 155 ) 156 @unittest.skipIf(not RUN_CUDA, "requires cuda") 157 def test_index_cuda_with_cpu(self): 158 with FakeTensorMode(): 159 x = torch.rand([2048], device="cuda") 160 out = x[torch.zeros([36], dtype=torch.int64)] 161 self.checkType(out, "cuda", [36]) 162 163 @unittest.skipIf(not RUN_CUDA, "requires cuda") 164 def test_shape_take_not_device(self): 165 with FakeTensorMode(): 166 x = torch.empty(1, device="cpu") 167 y = torch.empty(8, 8, device="cuda") 168 out = x.resize_as_(y) 169 self.assertEqual(out.shape, (8, 8)) 170 self.assertEqual(out.device.type, "cpu") 171 self.assertTrue(isinstance(out, FakeTensor)) 172 173 def test_repr(self): 174 with FakeTensorMode(): 175 x = torch.empty(2, 2, device="cpu") 176 self.assertEqual(repr(x), "FakeTensor(..., size=(2, 2))") 177 x = torch.empty(2, 2, device="meta") 178 self.assertEqual(repr(x), "FakeTensor(..., device='meta', size=(2, 2))") 179 180 @unittest.skipIf(not RUN_CUDA, "requires cuda") 181 def test_zero_dim(self): 182 with FakeTensorMode() as mode: 183 x = torch.tensor(0.0) 184 y = torch.rand([4, 4], device="cuda") 185 out = x + y 186 self.assertEqual(out.shape, (4, 4)) 187 self.assertEqual(out.device, y.device) 188 self.assertTrue(isinstance(out, FakeTensor)) 189 190 def test_nan_to_num(self): 191 with FakeTensorMode(): 192 for dtype in [torch.float16, torch.float32]: 193 x = torch.rand([4], dtype=dtype) 194 y = torch.nan_to_num(x, nan=None) 195 z = torch.nan_to_num(x, 0.0) 196 self.assertEqual(dtype, y.dtype) 197 self.assertEqual(dtype, z.dtype) 198 199 @unittest.skipIf(not RUN_CUDA, "requires cuda") 200 def test_throw(self): 201 x = torch.tensor(0.0) # TODO: tensor() errors 202 with FakeTensorMode() as mode: 203 x_conv = mode.from_tensor(x) 204 y = torch.rand([4, 4], device="cuda") 205 z = torch.rand([4, 4], device="cpu") 206 self.assertRaises(Exception, lambda: torch.lerp(x_conv, y, z)) 207 208 @unittest.skipIf(not RUN_CUDA, "requires cuda") 209 def test_type_as(self): 210 with FakeTensorMode(): 211 x = torch.rand([16, 1], device="cpu") 212 y = torch.rand([4, 4], device="cuda") 213 out = x.type_as(y) 214 self.assertEqual(out.device.type, "cuda") 215 self.assertTrue(isinstance(out, FakeTensor)) 216 217 @unittest.skipIf(not RUN_CUDA, "requires cuda") 218 def test_setitem(self): 219 for device in ["cpu", "cuda"]: 220 with FakeTensorMode(): 221 x = torch.rand([16, 1], device=device) 222 x[..., 0] = 0 223 224 @unittest.skipIf(not RUN_CUDA, "requires cuda") 225 def test_device_inplace_copy(self): 226 with FakeTensorMode(): 227 x = torch.rand([8, 8], device="cpu") 228 y = torch.rand([8, 8], device="cuda") 229 assert x.copy_(y).device.type == "cpu" 230 assert y.copy_(x).device.type == "cuda" 231 232 def test_fake_dispatch_keys(self): 233 with FakeTensorMode(): 234 x = torch.rand([4]) 235 f = ( 236 FileCheck() 237 .check("CPU") 238 .check("ADInplaceOrView") 239 .check("AutogradCPU") 240 .check("AutocastCPU") 241 ) 242 f.run(torch._C._dispatch_key_set(x)) 243 244 with torch.inference_mode(): 245 x = torch.rand([4]) 246 y = x + x 247 FileCheck().check("CPU").check("AutocastCPU").run( 248 torch._C._dispatch_key_set(y) 249 ) 250 FileCheck().check_not("ADInplaceOrView").check_not("Autograd").run( 251 torch._C._dispatch_key_set(y) 252 ) 253 254 def test_batch_tensor(self): 255 x = torch.rand((3, 4, 5)) 256 b = _add_batch_dim(x, 0, 0) 257 mode = FakeTensorMode() 258 fake_b = mode.from_tensor(b) 259 prims.utils.compare_tensor_meta(b, fake_b, check_strides=True) 260 261 b1 = _add_batch_dim(x, 1, 1) 262 b2 = _add_batch_dim(b1, 0, 2) 263 fake_b2 = mode.from_tensor(b2) 264 prims.utils.compare_tensor_meta(b2, fake_b2, check_strides=True) 265 self.assertTrue(is_batchedtensor(fake_b2)) 266 fake_b1 = get_unwrapped(fake_b2) 267 self.assertTrue(is_batchedtensor(fake_b1)) 268 fake_tensor = get_unwrapped(fake_b1) 269 self.assertIsInstance(fake_tensor, FakeTensor) 270 271 def test_constructor(self): 272 with FakeTensorMode(): 273 x = torch.rand([4, 4], device="cpu") 274 275 self.assertTrue(isinstance(x, FakeTensor)) 276 self.assertTrue(x.device.type == "cpu") 277 278 def test_mode(self): 279 with FakeTensorMode(): 280 y = torch.rand([4], device="cpu") 281 out = y + y 282 283 self.assertTrue(isinstance(out, FakeTensor)) 284 285 def test_full(self): 286 # Test torch.full returns tensor with correct dtype 287 with torch._subclasses.CrossRefFakeMode(): 288 y = torch.full((4, 4), 1) 289 290 def check_function_with_fake(self, fn): 291 out = fn() 292 with torch._subclasses.FakeTensorMode(): 293 out_fake = fn() 294 295 for a, b in zip(pytree.tree_leaves(out), pytree.tree_leaves(out_fake)): 296 if not isinstance(a, torch.Tensor): 297 self.assertTrue(not isinstance(b, torch.Tensor)) 298 continue 299 300 prims.utils.compare_tensor_meta(a, b, check_strides=True) 301 302 @unittest.skipIf(not RUN_CUDA, "requires cuda") 303 def test_non_kwarg_device(self): 304 with FakeTensorMode(): 305 x = torch.rand([16, 1], device="cpu") 306 y = x.to(torch.device("cpu")) 307 self.assertIs(x, y) 308 z = x.to(torch.device("cuda")) 309 self.assertEqual(z.device.type, "cuda") 310 311 def test_non_overlapping_stride_zero(self): 312 def foo(): 313 x = torch.empty_strided([1, 3, 427, 640], (0, 1, 1920, 3)) 314 return x.half() 315 316 self.check_function_with_fake(foo) 317 318 def test_fake_mode_error(self): 319 x = torch.rand([4, 4]) 320 321 with self.assertRaisesRegex(Exception, "Please convert all Tensors"): 322 with FakeTensorMode(): 323 y = x[0] 324 325 @unittest.skipIf( 326 TEST_WITH_TORCHDYNAMO, "isinstance check for FakeTensor won't work with compile" 327 ) 328 def test_fake_grad_copy(self): 329 x = torch.rand([4, 4], requires_grad=True) 330 x.grad = torch.rand([4, 4]) 331 mode = FakeTensorMode() 332 fake_x = mode.from_tensor(x) 333 prims.utils.compare_tensor_meta(fake_x, x) 334 prims.utils.compare_tensor_meta(fake_x.grad, x.grad) 335 336 self.assertTrue(isinstance(fake_x.grad, FakeTensor)) 337 338 @unittest.skipIf(not RUN_CUDA, "requires cuda") 339 def test_index_put_error(self): 340 mode = FakeTensorMode() 341 for context in [contextlib.nullcontext, lambda: mode]: 342 with context(): 343 y = torch.randn(2, 2, 3) 344 x = torch.randn(2, 2, 3).to("cuda") 345 with self.assertRaises(RuntimeError): 346 x[[1, 1]] = y 347 348 with self.assertRaises(RuntimeError): 349 torch.ops.aten.index_put(x, torch.tensor([1, 1], device="cuda"), y) 350 351 # no error 352 torch.ops.aten.index_put( 353 x, torch.tensor([1, 1], device="cuda"), torch.tensor(5.0) 354 ) 355 torch.ops.aten.index_put_( 356 x, torch.tensor([1, 1], device="cuda"), torch.tensor(5.0) 357 ) 358 359 @unittest.skipIf(not RUN_CUDA, "requires cuda") 360 def test_like_constructor(self): 361 with FakeTensorMode(): 362 x = torch.rand([4, 4]) 363 y = torch.ones_like(x) 364 self.assertTrue(isinstance(y, FakeTensor)) 365 self.assertEqual(y.device.type, "cpu") 366 z = torch.ones_like(x, device="cuda") 367 self.assertTrue(isinstance(z, FakeTensor)) 368 self.assertEqual(z.device.type, "cuda") 369 370 def test_binary_op_type_promotion(self): 371 with FakeTensorMode(): 372 x = torch.empty([2, 2], dtype=torch.float) 373 y = torch.empty([2, 2], dtype=torch.int64) 374 out = x / y 375 self.assertEqual(out.dtype, torch.float) 376 self.assertEqual(out.device.type, "cpu") 377 378 @unittest.skipIf( 379 TEST_WITH_TORCHDYNAMO, "isinstance check for FakeTensor won't work with compile" 380 ) 381 def test_from_numpy(self): 382 with FakeTensorMode(): 383 x = torch.tensor(np.zeros([4, 4])) 384 self.checkType(x, "cpu", [4, 4]) 385 386 def test_randperm(self): 387 x = torch.randperm(10) 388 y = torch.randperm(5, device="cpu") 389 with FakeTensorMode(): 390 x1 = torch.randperm(10) 391 prims.utils.compare_tensor_meta(x, x1) 392 y1 = torch.randperm(5, device="cpu") 393 prims.utils.compare_tensor_meta(y, y1) 394 395 def test_print_in_fake_mode(self): 396 x = torch.zeros(2) 397 # does not fail 398 with FakeTensorMode(): 399 out = str(x) 400 assert "FakeTensor" not in out 401 402 @unittest.skipIf(not RUN_CUDA, "requires cuda") 403 def test_upsample_bilinear_small_channels(self): 404 out = [] 405 mode = FakeTensorMode() 406 for i, context in enumerate([contextlib.nullcontext, lambda: mode]): 407 with context(): 408 arg0_1 = torch.empty_strided( 409 (3, 427, 640), (1, 1920, 3), dtype=torch.float32, device="cuda" 410 ) 411 unsqueeze = torch.ops.aten.unsqueeze.default(arg0_1, 0) 412 out.append( 413 torch.ops.aten.upsample_bilinear2d.default( 414 unsqueeze, [800, 1199], False 415 ) 416 ) 417 418 self.assertTrue(out[1].is_contiguous()) 419 self.checkMetaProps(out[0], out[1]) 420 421 @unittest.skipIf(not RUN_CUDA, "requires cuda") 422 def test_cpu_fallback(self): 423 with FakeTensorMode(allow_fallback_kernels=False): 424 filters = torch.randn(8, 4, 3, 3).cuda() 425 inputs = torch.randn(1, 4, 5, 5).cuda() 426 out = torch.nn.functional.conv2d(inputs, filters, padding=1) 427 self.assertEqual(out.device.type, "cuda") 428 self.assertEqual(list(out.size()), [1, 8, 5, 5]) 429 430 with FakeTensorMode(allow_fallback_kernels=True): 431 # intentionally bad inputs 432 filters = torch.randn(8, 20, 3, 3).cuda() 433 inputs = torch.randn(1, 7, 10, 5).cuda() 434 with self.assertRaises(RuntimeError): 435 torch.nn.functional.conv2d(inputs, filters, padding=1) 436 437 with FakeTensorMode(allow_fallback_kernels=True): 438 filters = torch.randn(8, 4, 3, 3).cuda() 439 inputs = torch.randn(1, 4, 5, 5).cuda() 440 441 out = torch.nn.functional.conv2d(inputs, filters, padding=1) 442 self.assertEqual(out.device.type, "cuda") 443 self.assertEqual(list(out.size()), [1, 8, 5, 5]) 444 445 @unittest.skipIf(not RUN_CUDA, "requires cuda") 446 def test_out_multi_device(self): 447 with FakeTensorMode(): 448 x = torch.rand([4]) 449 y = torch.rand([4], device="cuda") 450 451 with self.assertRaisesRegex(Exception, "found.+two.+devices"): 452 torch.sin(x, out=y) 453 454 with self.assertRaisesRegex(Exception, "found.+two.+devices"): 455 x.add_(y) 456 457 @unittest.skipIf( 458 TEST_WITH_TORCHDYNAMO, "isinstance check for FakeTensor won't work with compile" 459 ) 460 @unittest.skipIf(not RUN_CUDA, "requires cuda") 461 def test_normalize_device(self): 462 with FakeTensorMode(): 463 x = torch.empty(1, device="cuda") 464 y = torch.empty(1, device=f"cuda:{torch.cuda.current_device()}") 465 out = x + y 466 self.checkType(out, "cuda", [1]) 467 468 def test_recursive_invocation(self): 469 mode = FakeTensorMode() 470 with mode: 471 x = torch.tensor(2) 472 mode.in_kernel_invocation = True 473 y = x + x 474 self.assertTrue(mode.in_kernel_invocation) 475 476 @unittest.skipIf( 477 TEST_WITH_TORCHDYNAMO, "isinstance check for FakeTensor won't work with compile" 478 ) 479 @skipIfRocm 480 @parametrize( 481 "allow_fallback_kernels", 482 [False, True], 483 lambda a: "with_fallback" if a else "without_fallback", 484 ) 485 @unittest.skipIf(not RUN_CUDA, "requires cuda") 486 def test_cudnn_rnn(self, allow_fallback_kernels): 487 def fn( 488 a0, 489 b0, 490 b1, 491 b2, 492 b3, 493 b4, 494 b5, 495 b6, 496 b7, 497 b8, 498 b9, 499 b10, 500 b11, 501 b12, 502 b13, 503 b14, 504 b15, 505 a3, 506 a4, 507 a5, 508 ): 509 a1 = [ 510 b0, 511 b1, 512 b2, 513 b3, 514 b4, 515 b5, 516 b6, 517 b7, 518 b8, 519 b9, 520 b10, 521 b11, 522 b12, 523 b13, 524 b14, 525 b15, 526 ] 527 return torch.ops.aten._cudnn_rnn( 528 a0, 529 a1, 530 4, 531 a3, 532 a4, 533 a5, 534 2, 535 2048, 536 0, 537 2, 538 False, 539 0.0, 540 False, 541 True, 542 [], 543 None, 544 ) 545 546 mode = FakeTensorMode(allow_fallback_kernels=allow_fallback_kernels) 547 for i, context in enumerate([contextlib.nullcontext, lambda: mode]): 548 with context(): 549 inps1 = [ 550 torch.randn([92, 8, 2048]).cuda(), 551 torch.randn([8192, 2048]).cuda(), 552 torch.randn([8192, 2048]).cuda(), 553 torch.randn([8192]).cuda(), 554 torch.randn([8192]).cuda(), 555 torch.randn([8192, 2048]).cuda(), 556 torch.randn([8192, 2048]).cuda(), 557 torch.randn([8192]).cuda(), 558 torch.randn([8192]).cuda(), 559 torch.randn([8192, 4096]).cuda(), 560 torch.randn([8192, 2048]).cuda(), 561 torch.randn([8192]).cuda(), 562 torch.randn([8192]).cuda(), 563 torch.randn([8192, 4096]).cuda(), 564 torch.randn([8192, 2048]).cuda(), 565 torch.randn([8192]).cuda(), 566 torch.randn([8192]).cuda(), 567 torch.randn([167837696]).cuda(), 568 torch.randn([4, 8, 2048]).cuda(), 569 torch.randn([4, 8, 2048]).cuda(), 570 ] 571 inps2 = inps1 572 inps2[len(inps2) - 1] = None # argument `cx` can be None 573 574 for inps in [inps1, inps2]: 575 out = fn(*inps) 576 self.assertIs(out[4], inps[-3]) 577 for ten in out: 578 if i == 1: 579 self.assertTrue(isinstance(ten, FakeTensor)) 580 self.assertEqual(ten.device.type, "cuda") 581 582 @unittest.skipIf(not RUN_CUDA, "requires cuda") 583 def test_cuda_lstm(self): 584 # Ensure CUDA (non-cuDNN) impl succeeds with fake tensors. 585 with torch.backends.cudnn.flags(enabled=False): 586 fake_tensor_mode = FakeTensorMode(allow_fallback_kernels=False) 587 with fake_tensor_mode: 588 N = 5 589 L = 4 590 H_in = 2 591 hidden_size = 3 592 proj_size = 2 593 num_layers = 2 594 bidir = False 595 D = 2 if bidir else 1 596 H_out = proj_size if proj_size > 0 else hidden_size 597 598 lstm = torch.nn.LSTM( 599 input_size=H_in, 600 hidden_size=hidden_size, 601 num_layers=num_layers, 602 proj_size=proj_size, 603 batch_first=False, 604 bias=True, 605 bidirectional=bidir, 606 device="cuda", 607 ) 608 609 h_0 = torch.randn((num_layers * D, N, H_out), device="cuda") 610 c_0 = torch.randn((num_layers * D, N, hidden_size), device="cuda") 611 inp = torch.randn((L, N, H_in), device="cuda") 612 (output, (h_n, c_n)) = lstm(inp, (h_0, c_0)) 613 output.sum().backward() 614 615 self.assertEqual(output.shape, (L, N, D * H_out)) 616 self.assertEqual(h_n.shape, (D * num_layers, N, H_out)) 617 self.assertEqual(c_n.shape, (D * num_layers, N, hidden_size)) 618 619 def test_data_dependent_operator(self): 620 with FakeTensorMode(allow_fallback_kernels=False): 621 x = torch.rand([10, 10]) 622 623 self.assertRaises(DynamicOutputShapeException, lambda: torch.nonzero(x)) 624 625 def test_parameter_view(self): 626 x = torch.nn.Parameter(torch.randn(4)) 627 x_view = x.view(4) 628 mode = FakeTensorMode() 629 fake_x_view = mode.from_tensor(x_view) 630 fake_x = mode.from_tensor(x) 631 self.assertFalse(isinstance(fake_x_view, torch.nn.Parameter)) 632 self.assertTrue(isinstance(fake_x, torch.nn.Parameter)) 633 634 def test_tolist(self): 635 shape_env = ShapeEnv() 636 with FakeTensorMode(allow_fallback_kernels=False, shape_env=shape_env): 637 x = torch.rand([10]) 638 x.tolist() 639 640 # Propagate real tensors doesn't work with fake-on-fake 641 @expectedFailurePropagateRealTensors 642 def test_same_shape_env_preserved(self): 643 shape_env = ShapeEnv() 644 mode1 = FakeTensorMode(shape_env=shape_env) 645 t1 = mode1.from_tensor( 646 torch.randn(10), 647 symbolic_context=StatelessSymbolicContext( 648 dynamic_sizes=[DimDynamic.DYNAMIC], constraint_sizes=[None] 649 ), 650 ) 651 mode2 = FakeTensorMode(shape_env=shape_env) 652 t2 = mode2.from_tensor(t1) 653 # t2.size(0) is still dynamic, even though we didn't pass DYNAMIC here 654 self.assertIsNot(t2, t1) 655 self.assertIs(t1.fake_mode, mode1) 656 self.assertIs(t2.fake_mode, mode2) 657 self.assertIs(t2.size(0).node.shape_env, t1.size(0).node.shape_env) 658 self.assertEqual(str(t2.size(0)), str(t1.size(0))) 659 660 # TODO: Support NJT. There's also some funny business with dynamic shapes 661 # which would need to be dealt with as well 662 @expectedFailurePropagateRealTensors 663 def test_jagged_fake_to_fake_preserved(self): 664 from torch.nested._internal.nested_tensor import jagged_from_list 665 666 S0, S1, S2 = 3, 4, 5 667 D = 4 668 a = torch.randn(S0, D, requires_grad=True, dtype=torch.float64) 669 b = torch.randn(S1, D, requires_grad=True, dtype=torch.float64) 670 c = torch.randn(S2, D, requires_grad=True, dtype=torch.float64) 671 offsets = None 672 jt, _ = jagged_from_list([a, b, c], offsets) 673 shape_env = ShapeEnv() 674 mode1 = FakeTensorMode(shape_env=shape_env) 675 t1 = mode1.from_tensor(jt) 676 mode2 = FakeTensorMode(shape_env=shape_env) 677 t2 = mode2.from_tensor(t1) 678 # It's not obvious that the invocation above makes it dynamic but it 679 # does! 680 self.assertTrue(free_symbols(t1.size())) 681 self.assertIsNot(t2, t1) 682 self.assertIs(t1.offsets().fake_mode, mode1) 683 self.assertIs(t2.offsets().fake_mode, mode2) 684 self.assertIs(t2.size(1).node.shape_env, t1.size(1).node.shape_env) 685 self.assertEqual(str(t2.size(1)), str(t1.size(1))) 686 687 def checkMetaProps(self, t1, t2): 688 prims.utils.compare_tensor_meta(t1, t2, check_strides=True) 689 690 @skipIfCrossRef 691 def test_deepcopy(self): 692 with FakeTensorMode() as mode: 693 pass 694 mod = torch.nn.BatchNorm2d(10) 695 with torch._subclasses.fake_tensor.FakeCopyMode(mode): 696 mod_copied = copy.deepcopy(mod) 697 698 def check_copy(mod, mod_copied): 699 for name, param in itertools.chain( 700 mod.named_parameters(), mod.named_buffers() 701 ): 702 param_copied = getattr(mod_copied, name) 703 self.checkMetaProps(param, param_copied) 704 self.assertTrue(isinstance(param_copied, FakeTensor)) 705 self.assertEqual( 706 isinstance(param, torch.nn.Parameter), 707 isinstance(param_copied, torch.nn.Parameter), 708 ) 709 self.assertEqual(param.requires_grad, param_copied.requires_grad) 710 711 check_copy(mod, mod_copied) 712 713 class ModuleNew(torch.nn.Module): 714 def __init__(self) -> None: 715 super().__init__() 716 self.a = torch.rand([10, 2]) 717 self.b = self.a 718 self.c = self.a[0] 719 720 mod = ModuleNew() 721 with torch._subclasses.fake_tensor.FakeCopyMode(mode): 722 mod_copied = copy.deepcopy(mod) 723 724 self.assertIs(mod_copied.a, mod_copied.b) 725 self.assertEqual(mod_copied.b.storage()._cdata, mod_copied.a.storage()._cdata) 726 727 @unittest.skipIf( 728 TEST_WITH_TORCHDYNAMO, "isinstance check for FakeTensor won't work with compile" 729 ) 730 @unittest.skipIf(not RUN_CUDA, "requires cuda") 731 def test_new(self): 732 with FakeTensorMode(): 733 a = torch.rand([16, 1]) 734 self.checkType(a.new(10, 10), "cpu", [10, 10]) 735 self.checkType(a.new([1, 2, 3, 4]), "cpu", [4]) 736 b = torch.rand([4, 4], device="cuda") 737 self.checkType(b.new(device="cuda"), "cuda", [0]) 738 self.checkType(a.new(torch.rand([1])), "cpu", [1]) 739 740 @unittest.skipIf( 741 TEST_WITH_TORCHDYNAMO, "isinstance check for FakeTensor won't work with compile" 742 ) 743 def test_scalar_inputs(self): 744 with FakeTensorMode(): 745 self.checkType(torch.div(3, 2), "cpu", []) 746 ten = torch.zeros(2, dtype=torch.int32) * 2.0 747 self.assertEqual(ten.dtype, torch.float) 748 self.checkType(ten, "cpu", [2]) 749 750 @unittest.skipIf( 751 TEST_WITH_TORCHDYNAMO, "isinstance check for FakeTensor won't work with compile" 752 ) 753 def test_allow_meta(self): 754 def run_meta(): 755 with FakeTensorMode(): 756 x = torch.rand([4], device="meta") 757 return x + x 758 759 self.checkType(run_meta(), "meta", [4]) 760 761 with patch.object(torch._functorch.config, "fake_tensor_allow_meta", False): 762 self.assertRaises(Exception, run_meta) 763 764 def test_embedding_bag_meta(self): 765 def f(): 766 # This behavior was originally unintentional but we see people 767 # relying on it 768 embedding = torch.nn.EmbeddingBag(10, 3, mode="sum", device="meta") 769 input = torch.tensor([1, 2, 4, 5, 4, 3, 2, 9], dtype=torch.long) 770 offsets = torch.tensor([0, 4], dtype=torch.long) 771 return embedding(input, offsets) 772 773 real_out = f() 774 with FakeTensorMode(): 775 fake_out = f() 776 777 for r, f in zip(real_out, fake_out): 778 self.assertEqual(r.size(), f.size()) 779 self.assertEqual(r.device, f.device) 780 781 @unittest.skipIf( 782 TEST_WITH_TORCHDYNAMO, "isinstance check for FakeTensor won't work with compile" 783 ) 784 def test_mixed_real_and_fake_inputs(self): 785 class _TestPattern(torch.nn.Module): 786 def __init__(self) -> None: 787 super().__init__() 788 self.conv = torch.nn.Conv2d(1, 1, 1) 789 self.bn = torch.nn.BatchNorm2d(1) 790 791 def forward(self, input): 792 running_std = torch.sqrt(self.bn.running_var + self.bn.eps) 793 scale_factor = self.bn.weight / running_std 794 weight_shape = [1] * len(self.conv.weight.shape) 795 weight_shape[0] = -1 796 bias_shape = [1] * len(self.conv.weight.shape) 797 bias_shape[1] = -1 798 scaled_weight = self.conv.weight * scale_factor.reshape(weight_shape) 799 zero_bias = torch.zeros_like(self.conv.bias, dtype=input.dtype) 800 conv = self.conv._conv_forward(input, scaled_weight, zero_bias) 801 conv_orig = conv / scale_factor.reshape(bias_shape) 802 conv_orig = conv_orig + self.conv.bias.reshape(bias_shape) 803 conv = self.bn(conv_orig) 804 return conv 805 806 example_inputs = (torch.randn(1, 1, 3, 3),) 807 mod = _TestPattern() 808 with FakeTensorMode(allow_non_fake_inputs=True): 809 out = mod(torch.randn(1, 1, 3, 3)) 810 self.checkType(out, "cpu", (1, 1, 3, 3)) 811 812 @unittest.skipIf( 813 TEST_WITH_TORCHDYNAMO, "isinstance check for FakeTensor won't work with compile" 814 ) 815 @unittest.skipIf(not RUN_CUDA, "requires cuda") 816 def test_aten_copy_multi_device(self): 817 with FakeTensorMode(): 818 x1 = torch.rand(4, device="cpu") 819 x2 = torch.rand(4, device="cuda") 820 copy1 = torch.ops.aten.copy.default(x1, x2) 821 copy2 = torch.ops.aten.copy.default(x2, x1) 822 out = torch.empty(4, device="cpu") 823 torch.ops.aten.copy.out(x1, x2, out=out) 824 self.checkType(copy1, "cpu", (4,)) 825 self.checkType(copy2, "cuda", (4,)) 826 self.checkType(out, "cpu", (4,)) 827 828 @unittest.skipIf( 829 TEST_WITH_TORCHDYNAMO, "isinstance check for FakeTensor won't work with compile" 830 ) 831 @unittest.skipIf(not RUN_CUDA, "requires cuda") 832 def test_aten_index_multi_device(self): 833 with FakeTensorMode(): 834 x1 = torch.rand(4, 4, device="cpu") 835 x2 = torch.rand(4, 4, device="cuda") 836 i1 = torch.tensor([0, 1], device="cuda") 837 i2 = torch.tensor([0, 1], device="cpu") 838 # NB: This one does not work: cuda indices not allowed on cpu 839 # tensor 840 # r1 = torch.ops.aten.index(x1, i1) 841 r2 = torch.ops.aten.index(x2, i2) 842 843 y1 = torch.rand(4, device="cpu") 844 y2 = torch.rand(4, device="cuda") 845 j1 = torch.tensor([2], device="cuda") 846 j2 = torch.tensor([2], device="cpu") 847 r3 = torch.ops.aten.index_put.default(x1, j1, y1) 848 r4 = torch.ops.aten.index_put.default(x2, j2, y2) 849 # self.checkType(r1, "cpu", ()) 850 self.checkType(r2, "cuda", ()) 851 self.checkType(r3, "cpu", (4, 4)) 852 self.checkType(r4, "cuda", (4, 4)) 853 854 @unittest.skipIf( 855 TEST_WITH_TORCHDYNAMO, "isinstance check for FakeTensor won't work with compile" 856 ) 857 @unittest.skipIf(not RUN_CUDA, "requires cuda") 858 def test_aten_slice_scatter_multi_device(self): 859 with FakeTensorMode(): 860 x1 = torch.rand(4, 4, device="cpu") 861 y1 = torch.rand(2, 4, device="cuda") 862 x2 = torch.rand(4, 4, device="cuda") 863 y2 = torch.rand(2, 4, device="cpu") 864 out = torch.empty(4, 4, device="cpu") 865 r1 = torch.ops.aten.slice_scatter.default(x1, y1, start=2) 866 r2 = torch.ops.aten.slice_scatter.default(x2, y2, start=2) 867 r3 = torch.ops.aten.slice_scatter.out(x1, y1, out=out, start=2) 868 self.checkType(r1, "cpu", (4, 4)) 869 self.checkType(r2, "cuda", (4, 4)) 870 self.checkType(r3, "cpu", (4, 4)) 871 self.checkType(out, "cpu", (4, 4)) 872 873 def test__adaptive_avg_pool2d_backward(self): 874 with FakeTensorMode(): 875 grad_out = torch.rand(2, 3, 4, 4) 876 inp = torch.rand(2, 3, 4, 4).to(memory_format=torch.channels_last) 877 grad_in = torch.ops.aten._adaptive_avg_pool2d_backward(grad_out, inp) 878 self.assertTrue( 879 torch._prims_common.suggest_memory_format(grad_in) 880 == torch.channels_last 881 ) 882 883 def test_export_numpy(self): 884 class MyNumpyModel(torch.nn.Module): 885 def forward(self, input): 886 input = input.numpy() 887 return input + np.random.randn(*input.shape) 888 889 with FakeTensorMode(): 890 ep = torch.export.export(MyNumpyModel(), args=(torch.randn(1000),)) 891 self.assertTrue(isinstance(ep, torch.export.ExportedProgram)) 892 893 def test_unsqueeze_copy(self): 894 shape_env = ShapeEnv() 895 t1 = torch.ones(2, 2, 768) 896 with FakeTensorMode(shape_env=shape_env) as fake_mode: 897 t = fake_mode.from_tensor( 898 t1, 899 symbolic_context=StatelessSymbolicContext( 900 dynamic_sizes=[ 901 DimDynamic.DYNAMIC, 902 DimDynamic.STATIC, 903 DimDynamic.STATIC, 904 ], 905 ), 906 ) 907 908 self.assertEqual(t.shape[0], torch.ops.aten.unsqueeze_copy(t, 1).shape[0]) 909 910 def test_alias_call(self): 911 fwAD = torch.autograd.forward_ad 912 913 def f(x): 914 return 4312491 * x 915 916 with torch._subclasses.fake_tensor.FakeTensorMode(): 917 with fwAD.dual_level(): 918 x = torch.randn(3, device="cpu") 919 y = torch.ones_like(x) 920 dual = fwAD.make_dual(x, y) 921 r = f(dual) 922 923 self.assertIsInstance(r, FakeTensor) 924 self.assertEqual(r.size(), [3]) 925 926 927instantiate_parametrized_tests(FakeTensorTest) 928 929 930def make_propagate_real_tensors_cls(cls): 931 cls = make_test_cls_with_patches( 932 cls, 933 "PropagateRealTensors", 934 "_propagate_real_tensors", 935 (torch._functorch.config, "fake_tensor_propagate_real_tensors", True), 936 xfail_prop="_expected_failure_propagate_real_tensors", 937 decorator=skipIfTorchDynamo("propagate_real_tensors affects Dynamo"), 938 ) 939 cls.__file__ = __file__ 940 cls.__module__ = __name__ 941 globals()[cls.__name__] = cls 942 943 944make_propagate_real_tensors_cls(FakeTensorTest) 945 946 947class FakeTensorConstHandling(TestCase): 948 def assertConst(self, *args): 949 for arg in args: 950 self.assertTrue(arg.constant is not None) 951 952 def assertNotConst(self, *args): 953 for arg in args: 954 self.assertTrue(arg.constant is None) 955 956 def test_simple(self): 957 with FakeTensorMode(): 958 x = torch.tensor(4.0) 959 self.assertEqual(x.item(), 4.0) 960 961 def test_inplace_add(self): 962 with FakeTensorMode(): 963 x = torch.tensor(4.0) 964 y = x.add_(1) 965 self.assertEqual(x.item(), 5.0) 966 self.assertEqual(y.item(), 5.0) 967 self.assertConst(x, y) 968 969 def test_shared_storages(self): 970 with FakeTensorMode(): 971 x = torch.tensor([4.0]) 972 y = x[:] 973 974 self.assertEqual(x.storage()._cdata, y.storage()._cdata) 975 self.assertEqual(x.constant.storage()._cdata, y.constant.storage()._cdata) 976 977 def test_constant_invalidation(self): 978 with FakeTensorMode(): 979 x = torch.tensor([1.0]) 980 self.assertConst(x) 981 y = torch.rand([1]) 982 x.add_(y) 983 self.assertNotConst(x) 984 985 def test_inplace_view_invalidation(self): 986 with FakeTensorMode(): 987 x = torch.tensor([1]) 988 self.assertConst(x) 989 x.resize_([2]) 990 self.assertEqual(x.size(0), 2) 991 self.assertNotConst(x) 992 993 def test_fake_tensor_in_intlist_repro(self): 994 def fn(tensors): 995 max_size = torch.tensor([800, 1216], dtype=torch.int64) 996 batch_shape = [len(tensors)] + list(tensors[0].shape[:-2]) + list(max_size) 997 return tensors[0].new_full(batch_shape, 0.0) 998 999 with self.assertRaises( 1000 torch._subclasses.fake_tensor.DataDependentOutputException 1001 ): 1002 with torch._subclasses.fake_tensor.FakeTensorMode(): 1003 a = torch.randn(3, 800, 1199) 1004 b = torch.randn(3, 800, 800) 1005 inputs = [a, b] 1006 ref = fn(inputs) 1007 1008 def test_fake_tensor_batch_norm_cpu(self): 1009 with torch._subclasses.CrossRefFakeMode(): 1010 m = torch.nn.Sequential( 1011 torch.nn.BatchNorm2d(10), 1012 torch.nn.ReLU(), 1013 ) 1014 m.eval() 1015 out = m(torch.randn([2, 10, 8, 8])) 1016 1017 def test_shared_storage_invalidation(self): 1018 with FakeTensorMode(): 1019 x = torch.tensor([1.0]) 1020 y = x[:] 1021 self.assertConst(x, y) 1022 y.add_(torch.rand([1])) 1023 self.assertNotConst(x, y) 1024 1025 def test_aliased_const_write(self): 1026 with FakeTensorMode(): 1027 x = torch.tensor([1]) 1028 y = x.expand([4]) 1029 self.assertNotConst(y) 1030 y[0] = 1 1031 self.assertNotConst(x) 1032 1033 def test_constant_propagate_through_functions(self): 1034 with FakeTensorMode(): 1035 y = torch.div(4, 4, rounding_mode="trunc") 1036 self.assertConst(y) 1037 1038 1039make_propagate_real_tensors_cls(FakeTensorConstHandling) 1040 1041 1042def contains_type(type: torch.Type, maybe_contained_type: torch.Type): 1043 return maybe_contained_type.isSubtypeOf(type) or any( 1044 contains_type(e, maybe_contained_type) for e in type.containedTypes() 1045 ) 1046 1047 1048class FakeTensorOpInfoTest(TestCase): 1049 @ops(custom_op_db, dtypes=OpDTypes.any_one) 1050 def test_fake(self, device, dtype, op): 1051 sample_inputs_itr = op.sample_inputs(device, dtype, requires_grad=False) 1052 for sample_input in sample_inputs_itr: 1053 args = (sample_input.input,) + sample_input.args 1054 kwargs = sample_input.kwargs 1055 optests.fake_check(op, args, kwargs) 1056 1057 1058make_propagate_real_tensors_cls(FakeTensorOpInfoTest) 1059instantiate_device_type_tests(FakeTensorOpInfoTest, globals(), only_for=("cpu", "cuda")) 1060instantiate_device_type_tests( 1061 PropagateRealTensorsFakeTensorOpInfoTest, globals(), only_for=("cpu",) # noqa: F821 1062) 1063 1064 1065class FakeTensorConverterTest(TestCase): 1066 def test_memoized_conversion_to_meta(self): 1067 x = torch.rand(2, 2, 2) 1068 mode = FakeTensorMode() 1069 self.assertTrue(mode.from_tensor(x) is mode.from_tensor(x)) 1070 1071 def test_memoized_conversion_from_meta(self): 1072 x = torch.rand(2, 2).to(device="meta") 1073 mode = FakeTensorMode() 1074 converter = mode.fake_tensor_converter 1075 self.assertTrue( 1076 converter.from_meta_and_device(mode, x, "cpu") 1077 is converter.from_meta_and_device(mode, x, "cpu") 1078 ) 1079 1080 def test_separate_tensor_storages_view(self): 1081 x = torch.rand(2, 2, 2) 1082 y = x[0] 1083 mode = FakeTensorMode() 1084 converter = mode.fake_tensor_converter 1085 x_conv = converter.from_real_tensor(mode, x) 1086 y_conv = converter.from_real_tensor(mode, y) 1087 self.assertEqual(torch._C._storage_id(x_conv), torch._C._storage_id(y_conv)) 1088 1089 @skipIfTorchDynamo("https://github.com/pytorch/torchdynamo/issues/1991") 1090 def test_separate_tensor_storages_non_view(self): 1091 x = torch.rand(2, 2, 2) 1092 y = torch.rand(4, 2) 1093 y.set_(x.storage()) 1094 mode = FakeTensorMode() 1095 converter = mode.fake_tensor_converter 1096 x_conv = converter.from_real_tensor(mode, x) 1097 y_conv = converter.from_real_tensor(mode, y) 1098 stor_id = torch._C._storage_id(x_conv) 1099 self.assertEqual(stor_id, torch._C._storage_id(y_conv)) 1100 del x 1101 del x_conv 1102 self.assertEqual(len(converter.tensor_memo), 1) 1103 self.assertEqual(len(converter.meta_converter.storage_memo), 1) 1104 del y 1105 del y_conv 1106 self.assertEqual(len(converter.tensor_memo), 0) 1107 self.assertEqual(len(converter.meta_converter.storage_memo), 0) 1108 1109 @skipIfTorchDynamo("https://github.com/pytorch/torchdynamo/issues/1991") 1110 def test_dead_weak_ref(self): 1111 x = torch.rand(2, 2, 2) 1112 y = x[0] 1113 mode = FakeTensorMode() 1114 converter = FakeTensorConverter() 1115 x_conv = converter.from_real_tensor(mode, x) 1116 x_conv_storage = x_conv.untyped_storage() 1117 del x_conv 1118 self.assertFalse(x in converter.tensor_memo) 1119 y_conv = converter.from_real_tensor(mode, y) 1120 self.assertIs(x_conv_storage, y_conv.untyped_storage()) 1121 1122 @skipIfTorchDynamo("https://github.com/pytorch/torchdynamo/issues/1991") 1123 def test_dead_key(self): 1124 x = torch.rand(2, 2, 2) 1125 mode = FakeTensorMode() 1126 converter = FakeTensorConverter() 1127 x_conv = converter.from_real_tensor(mode, x) 1128 self.assertEqual(len(converter.tensor_memo), 1) 1129 x_conv2 = converter.from_real_tensor(mode, x) 1130 assert x_conv2 is x_conv 1131 del x 1132 del x_conv 1133 del x_conv2 1134 self.assertEqual(len(converter.tensor_memo), 0) 1135 1136 def test_no_active_mode(self): 1137 with FakeTensorMode() as mode: 1138 x = torch.empty(2, 2, device="cpu") 1139 y = torch.empty(2, 2, device="cpu") 1140 1141 out = x + y 1142 self.assertEqual(mode, out.fake_mode) 1143 self.assertTrue(isinstance(out, FakeTensor)) 1144 self.assertEqual(out.device.type, "cpu") 1145 1146 def test_multiple_modes(self): 1147 t = torch.rand([4]) 1148 t2 = torch.rand([4]) 1149 with FakeTensorMode() as m: 1150 with FakeTensorMode() as m2: 1151 t_fake = m.from_tensor(t) 1152 t2_fake = m2.from_tensor(t2) 1153 1154 with self.assertRaisesRegex(Exception, "Mixing fake modes"): 1155 t_fake + t2_fake 1156 1157 def test_separate_mode_error(self): 1158 with FakeTensorMode(): 1159 x = torch.empty(2, 2, device="cpu") 1160 with FakeTensorMode(): 1161 y = torch.empty(2, 2, device="cpu") 1162 self.assertRaises(Exception, lambda: x, y) 1163 1164 @skipIfTorchDynamo("https://github.com/pytorch/torchdynamo/issues/1991") 1165 def test_no_ref_cycle(self): 1166 x = torch.rand([4]) 1167 mode = FakeTensorMode() 1168 y = mode.from_tensor(x) 1169 self.assertEqual(len(mode.fake_tensor_converter.tensor_memo), 1) 1170 mode_weak = weakref.ref(mode) 1171 y_weak = weakref.ref(mode) 1172 del mode 1173 del y 1174 assert mode_weak() is None 1175 assert y_weak() is None 1176 1177 1178make_propagate_real_tensors_cls(FakeTensorConverterTest) 1179 1180 1181class FakeTensorOperatorInvariants(TestCase): 1182 def get_aten_op(self, schema): 1183 namespace, name = schema.name.split("::") 1184 overload = schema.overload_name if schema.overload_name else "default" 1185 assert namespace == "aten" 1186 return getattr(getattr(torch.ops.aten, name), overload) 1187 1188 def get_all_aten_schemas(self): 1189 for schema in torch._C._jit_get_all_schemas(): 1190 namespace = schema.name.split("::")[0] 1191 if namespace != "aten": 1192 continue 1193 yield schema 1194 1195 def test_non_kwarg_only_device(self): 1196 for schema in self.get_all_aten_schemas(): 1197 ten_type = torch._C.TensorType.get() 1198 if not any( 1199 contains_type(arg.type, ten_type) 1200 for arg in itertools.chain(schema.arguments, schema.returns) 1201 ): 1202 continue 1203 1204 opt_device = torch._C.OptionalType(torch._C.DeviceObjType.get()) 1205 has_non_kwarg_device = any( 1206 not arg.kwarg_only and arg.type.isSubtypeOf(opt_device) 1207 for arg in schema.arguments 1208 ) 1209 if has_non_kwarg_device: 1210 self.assertTrue( 1211 self.get_aten_op(schema) 1212 in torch._subclasses.fake_tensor._device_not_kwarg_ops 1213 ) 1214 1215 def test_tensor_constructors_all_have_kwarg_device(self): 1216 for schema in self.get_all_aten_schemas(): 1217 op = self.get_aten_op(schema) 1218 if not torch._subclasses.fake_tensor._is_tensor_constructor(op): 1219 continue 1220 1221 opt_device = torch._C.OptionalType(torch._C.DeviceObjType.get()) 1222 has_kwarg_device = any( 1223 arg.kwarg_only and arg.type.isSubtypeOf(opt_device) 1224 for arg in schema.arguments 1225 ) 1226 1227 self.assertTrue( 1228 has_kwarg_device or op == torch.ops.aten._list_to_tensor.default 1229 ) 1230 1231 @unittest.expectedFailure 1232 def test_sparse_new(self): 1233 with FakeTensorMode(): 1234 indices = torch.randn(1, 1, dtype=torch.int64) 1235 values = torch.randn(1) 1236 extra = (2,) 1237 sparse = torch.randn(1).to_sparse() 1238 # This used to segfault, now it does not, but it still raises an 1239 # error 1240 sparse2 = sparse.new(indices, values, extra) 1241 1242 def test_tensor_new(self): 1243 with FakeTensorMode(): 1244 x = torch.Tensor([1, 2, 3]) 1245 self.assertIsInstance(x, FakeTensor) 1246 1247 def test_like_ops(self): 1248 for schema in self.get_all_aten_schemas(): 1249 if "_like" == schema.name[-5:]: 1250 op = self.get_aten_op(schema) 1251 self.assertIn( 1252 op, torch._subclasses.fake_tensor._like_tensor_constructors 1253 ) 1254 1255 def test_str_storage(self): 1256 x = torch.zeros(3) 1257 with FakeTensorMode() as m: 1258 y = m.from_tensor(x) 1259 self.assertExpectedInline( 1260 str(x.storage()), 1261 """\ 1262 0.0 1263 0.0 1264 0.0 1265[torch.storage.TypedStorage(dtype=torch.float32, device=cpu) of size 3]""", 1266 ) 1267 self.assertExpectedInline( 1268 str(y.storage()), 1269 """\ 1270... 1271[torch.storage.TypedStorage(dtype=torch.float32, device=meta) of size 3]""", 1272 ) 1273 1274 self.assertExpectedInline( 1275 str(y.storage()), 1276 """\ 1277... 1278[torch.storage.TypedStorage(dtype=torch.float32, device=meta) of size 3]""", 1279 ) 1280 1281 # at::_embedding_bag has no op info, 1282 # and returns extra tensors that at::embedding bag throws away 1283 def test_embedding_bag_private(self): 1284 args = [ 1285 torch.ones(6, 1), 1286 torch.ones(6, dtype=torch.int64), 1287 torch.arange(2, dtype=torch.int64), 1288 False, 1289 2, # mode = max 1290 ] 1291 1292 ref_out = torch.ops.aten._embedding_bag(*args) 1293 with FakeTensorMode() as m: 1294 meta_args = [ 1295 m.from_tensor(a) if isinstance(a, torch.Tensor) else a for a in args 1296 ] 1297 meta_out = torch.ops.aten._embedding_bag(*meta_args) 1298 1299 self.assertEqual(len(ref_out), len(meta_out)) 1300 for ref_o, meta_o in zip(ref_out, meta_out): 1301 self.assertEqual(ref_o.size(), meta_o.size()) 1302 1303 def test_cross_entropy_loss(self): 1304 inp = torch.randn(3, 5) 1305 target = torch.randint(5, (3,), dtype=torch.long) 1306 weight = torch.rand(5) 1307 fn = torch.nn.functional.cross_entropy 1308 for w in (weight, None): 1309 args = (inp, target, w) 1310 ref = fn(*args) 1311 with FakeTensorMode() as m: 1312 meta_args = [ 1313 m.from_tensor(a) if isinstance(a, torch.Tensor) else a for a in args 1314 ] 1315 meta_out = torch.nn.functional.cross_entropy( 1316 *meta_args, label_smoothing=0.5 1317 ) 1318 1319 self.assertEqual(ref.size(), meta_out.size()) 1320 1321 @skipIfRocm 1322 @unittest.skipIf( 1323 not PLATFORM_SUPPORTS_FLASH_ATTENTION, 1324 "Does not support SDPA or pre-SM80 hardware", 1325 ) 1326 def test_flash_attention(self): 1327 class Repro(torch.nn.Module): 1328 def __init__(self) -> None: 1329 super().__init__() 1330 1331 def forward(self, arg1, arg2, arg3): 1332 torch.ops.aten._scaled_dot_product_flash_attention( 1333 arg1, arg2, arg3, scale=0.17677669529663687 1334 ) 1335 1336 args_new = [ 1337 [ 1338 ((1, 48, 64, 64), (0, 4096, 64, 1), torch.float16, "cuda"), 1339 ((1, 48, 64, 64), (0, 4096, 64, 1), torch.float16, "cuda"), 1340 ((1, 48, 64, 64), (0, 4096, 64, 1), torch.float16, "cuda"), 1341 ], 1342 [ 1343 ((4, 2, 16, 32), (1024, 512, 32, 1), torch.float16, "cuda"), 1344 ((4, 2, 16, 32), (1024, 512, 32, 1), torch.float16, "cuda"), 1345 ((4, 2, 16, 32), (1024, 512, 32, 1), torch.float16, "cuda"), 1346 ], 1347 ] 1348 for args_list in args_new: 1349 args = [ 1350 rand_strided(bsz, num_heads, seq_len, head_dim) 1351 for (bsz, num_heads, seq_len, head_dim) in args_list 1352 ] 1353 try: 1354 with torch._subclasses.CrossRefFakeMode(): 1355 Repro()(*args) 1356 except RuntimeError as e: 1357 # We expect the cross ref to succed for the first output to fail 1358 # for the rng state, see Note [Seed and Offset] 1359 self.assertTrue("output[0]" not in str(e)) 1360 self.assertTrue( 1361 "found mismatched tensor metadata for output[6]: Devices cpu and cuda:0 are not equal!" 1362 in str(e) 1363 ) 1364 1365 # IMPORTANT!!! Always run even if CUDA is not available 1366 def test_fake_gpu_no_init(self): 1367 # Skip this test, we will try to run CUDA operations to real prop so 1368 # it clearly will not work on CPU runner 1369 if torch._functorch.config.fake_tensor_propagate_real_tensors: 1370 return 1371 with FakeTensorMode(): 1372 torch.empty(10, device=GPU_TYPE) 1373 torch.ones(10, device=GPU_TYPE) 1374 torch.zeros(10, device=GPU_TYPE) 1375 torch.rand(10, device=GPU_TYPE) 1376 torch.tensor(3.14, device=GPU_TYPE) 1377 torch.tensor([[3.14, 2], [1, 2]], device=GPU_TYPE) 1378 1379 @skipIfRocm 1380 @unittest.skipIf(not RUN_CUDA, "requires cuda") 1381 def test_conv_c1_backward(self): 1382 class Repro(torch.nn.Module): 1383 def __init__(self) -> None: 1384 super().__init__() 1385 1386 def forward(self, arg1, arg2, arg3): 1387 torch.ops.aten.convolution_backward.default( 1388 arg1, 1389 arg2, 1390 arg3, 1391 [1], 1392 [1, 1], 1393 [1, 1], 1394 [1, 1], 1395 False, 1396 [0, 0], 1397 1, 1398 [True, True, False], 1399 ) 1400 1401 args_new = [ 1402 ((16, 1, 128, 128), (16384, 16384, 128, 1), torch.float16, "cuda"), 1403 ((16, 64, 128, 128), (1048576, 1, 8192, 64), torch.float16, "cuda"), 1404 ((1, 64, 3, 3), (576, 9, 3, 1), torch.float16, "cuda"), 1405 ] 1406 args = [rand_strided(sh, st, dt, dev) for (sh, st, dt, dev) in args_new] 1407 1408 with torch._subclasses.CrossRefFakeMode(): 1409 Repro()(*args) 1410 1411 def test_no_dispatch_with_like_function(self): 1412 class CountingMode(TorchDispatchMode): 1413 def __init__(self) -> None: 1414 self.count = 0 1415 1416 def __torch_dispatch__(self, func, types, args=(), kwargs=None): 1417 self.count += 1 1418 return func(*args, **kwargs) 1419 1420 with FakeTensorMode(): 1421 x = torch.randn(2) 1422 with CountingMode() as mode: 1423 with no_dispatch(): 1424 torch.zeros_like(x) 1425 1426 self.assertEqual(mode.count, 0) 1427 1428 1429make_propagate_real_tensors_cls(FakeTensorOperatorInvariants) 1430 1431 1432class FakeTensorPropTest(TestCase): 1433 def test_fake_tensor_prop_on_nn_module(self): 1434 class ToyNnModuleWithParameters(torch.nn.Module): 1435 def __init__(self) -> None: 1436 super().__init__() 1437 self.layer1 = torch.nn.Linear(4, 3) 1438 self.layer2 = torch.nn.Linear(3, 2) 1439 1440 def forward(self, value): 1441 value = self.layer1(value) 1442 value = torch.relu(value) 1443 value = self.layer2(value) 1444 return value 1445 1446 model = ToyNnModuleWithParameters() 1447 value = torch.randn(5, 4) 1448 # Convert nn.Module to GraphModule so that FakeTensorProp runs. 1449 graph_model = torch.fx.symbolic_trace(model, (value,)) 1450 # The following block runs FakeTensorProp on graph_module w/to the same FakeTensorMode 1451 # 1452 # TODO(wschin): there should be an API to run FakeTensorProp for GraphModule 1453 # with parameters and buffers. 1454 with FakeTensorMode() as fake_tensor_mode: 1455 1456 def to_fake_tensor(x): 1457 if isinstance(x, torch.Tensor) and not isinstance(x, FakeTensor): 1458 return fake_tensor_mode.from_tensor(x) 1459 return x 1460 1461 fake_parameters_and_buffers = { 1462 k: to_fake_tensor(v) 1463 for k, v in itertools.chain( 1464 graph_model.named_parameters(), graph_model.named_buffers() 1465 ) 1466 } 1467 with torch.nn.utils.stateless._reparametrize_module( 1468 graph_model, fake_parameters_and_buffers 1469 ): 1470 # This case uses the **same** fake tensor mode to 1471 # 1. create fake parameters and fake buffers, and 1472 # 2. run FakeTensorProp 1473 # The result should be correct. 1474 result = FakeTensorProp(graph_model, fake_tensor_mode).propagate(value) 1475 self.assertTrue(isinstance(result, FakeTensor)) 1476 self.assertEqual(result.shape, (5, 2)) 1477 # This case uses the **different** fake tensor modes to 1478 # 1. create fake parameters and fake buffers, and 1479 # 2. run FakeTensorProp 1480 # The following code should fail. 1481 failed = False 1482 try: 1483 FakeTensorProp(graph_model).propagate(value) 1484 except AssertionError: 1485 # AssertionError: tensor's device must be `meta`, got cpu instead 1486 failed = True 1487 self.assertTrue(failed) 1488 1489 def test_fake_tensor_prop_on_nn_module_with_optional_args(self): 1490 class OptionalArgumentInBetween(torch.nn.Module): 1491 def __init__(self) -> None: 1492 super().__init__() 1493 self.layer1 = torch.nn.Linear(4, 3) 1494 self.layer2 = torch.nn.Linear(3, 2) 1495 1496 def forward(self, value, another_value=None, another_optional_value=None): 1497 # Mimic huggingface's `forward` methods which have several optional arguments. 1498 # For example, GPT accepts forward(self, input_ids, None, attention_mask, ...). 1499 # To apply FakeTensorProp, its from_real_tensor(...) needs to accept None. 1500 if another_value is None: 1501 another_value = torch.rand_like(value) 1502 if another_optional_value is None: 1503 another_optional_value = torch.rand_like(value) 1504 value = value + another_value + another_optional_value 1505 return value * value 1506 1507 fake_mode = FakeTensorMode( 1508 allow_non_fake_inputs=True, allow_fallback_kernels=False 1509 ) 1510 with fake_mode: 1511 model = OptionalArgumentInBetween() 1512 value = torch.randn(5, 4) 1513 another_optional_value = torch.randn(5, 4) 1514 graph_model = torch.fx.symbolic_trace( 1515 model, (value, None, another_optional_value) 1516 ) 1517 FakeTensorProp(graph_model, fake_mode).propagate( 1518 value, None, another_optional_value 1519 ) 1520 1521 def test_unbacked_shape_realloc(self): 1522 def f(x): 1523 return x.nonzero() 1524 1525 shape_env = ShapeEnv() 1526 fake_mode = FakeTensorMode(shape_env=shape_env) 1527 with fake_mode: 1528 value = torch.randn(5) 1529 gm = make_fx(f)(value) 1530 nonzero_nodes = [ 1531 n for n in gm.graph.nodes if n.target is torch.ops.aten.nonzero.default 1532 ] 1533 self.assertEqual(len(nonzero_nodes), 1) 1534 self.assertIsInstance(nonzero_nodes[0].meta["val"].shape[0], torch.SymInt) 1535 u0 = nonzero_nodes[0].meta["val"].shape[0] 1536 FakeTensorProp(gm, fake_mode).propagate(value) 1537 u1 = nonzero_nodes[0].meta["val"].shape[0] 1538 # Test that this test is actually doing something in that the 1539 # FakeTensorProp actually triggered a reallocation. If this assert is 1540 # failing, it could be because we started memoizing the nnz count for 1541 # nonzero, which is nice in some sense (no reallocation) but not 1542 # helpful for this test, which is checking what we do when we have 1543 # to reallocate. If so, you need to make this example more 1544 # complicated (e.g., maybe have a nontrivial computation on the input 1545 # before feeding it into nonzero, or have some sort of randomness) 1546 self.assertIsNot(u0, u1) 1547 self.assertTrue(statically_known_true(u0 == u1)) 1548 1549 def test_torch_load_with_fake_mode(self): 1550 class TheModelClass(torch.nn.Module): 1551 def __init__(self) -> None: 1552 super().__init__() 1553 self.fc1 = torch.nn.Linear(5, 10) 1554 1555 def forward(self, x): 1556 return self.fc1(x) 1557 1558 with TemporaryFileName() as state_dict_file: 1559 # Create state_dict to be loaded later 1560 model = TheModelClass() 1561 torch.save(model.state_dict(), state_dict_file) 1562 1563 fake_mode = FakeTensorMode() 1564 with fake_mode: 1565 torch.load(state_dict_file) # scenario 1 1566 torch.load(state_dict_file, map_location="cpu") # scenario 2 1567 1568 1569make_propagate_real_tensors_cls(FakeTensorPropTest) 1570 1571 1572class FakeTensorSerialization(TestCase): 1573 def test_serialization(self): 1574 x = torch.tensor([0], device="cpu") 1575 with FakeTensorMode(): 1576 y = pickle.loads(pickle.dumps(x)) 1577 self.assertEqual(type(y), FakeTensor) 1578 self.assertEqual(y.device.type, "meta") 1579 1580 with unset_fake_temporarily(): 1581 y = pickle.loads(pickle.dumps(x)) 1582 self.assertEqual(x.device, y.device) 1583 1584 def test_serialization_with_tracing(self): 1585 x = torch.tensor([0], device="cpu") 1586 with tracing(TracingContext(FakeTensorMode())): 1587 y = pickle.loads(pickle.dumps(x)) 1588 self.assertEqual(x.device, y.device) 1589 1590 1591class FakeTensorDispatchCache(TestCase): 1592 def test_shape_env_settings(self): 1593 """ 1594 Validation that any boolean settings in ShapeEnv are present in the 1595 ShapeEnvSettings. We hope to ensure that any new settings that might 1596 affect FakeTensor dispatch are included in the cache key calculation. 1597 If this test fails, consider updating ShapeEnvSettings or change this 1598 test to omit checking for the new field. 1599 """ 1600 init_sig = inspect.signature(ShapeEnv._init) 1601 args = [ 1602 name 1603 for name, param in init_sig.parameters.items() 1604 if type(param.default) is bool 1605 ] 1606 1607 settings = [f.name for f in dataclasses.fields(ShapeEnvSettings)] 1608 for arg in args: 1609 self.assertTrue(arg in settings) 1610 1611 def _test_cache_key(self, fm, x, y, z): 1612 """ 1613 Helper for all test_cache_key_* tests below. Assert that the 1614 cache keys for inputs x and y are the same, but z is different. 1615 """ 1616 func = aten.add.Tensor 1617 state = _CacheKeyState() 1618 key_x = fm._cache_key(state, func, [x], {}) 1619 key_y = fm._cache_key(state, func, [y], {}) 1620 key_z = fm._cache_key(state, func, [z], {}) 1621 1622 self.assertEqual(key_x, key_y) 1623 self.assertNotEqual(key_x, key_z) 1624 1625 def test_cache_key_dtype(self): 1626 with FakeTensorMode() as fm: 1627 x = torch.randn(4, 3, dtype=torch.float16) 1628 y = torch.randn(4, 3, dtype=torch.float16) 1629 z = x.to(dtype=torch.float32) 1630 self._test_cache_key(fm, x, y, z) 1631 1632 def test_cache_key_shape(self): 1633 with FakeTensorMode() as fm: 1634 x = torch.randn(4, 3) 1635 y = torch.randn(4, 3) 1636 z = torch.randn(4, 2) 1637 self._test_cache_key(fm, x, y, z) 1638 1639 def test_cache_key_stride(self): 1640 with FakeTensorMode() as fm: 1641 x = torch.randn(4, 2) 1642 y = torch.randn(4, 2) 1643 z = x.as_strided((4, 2), (1, 2)) 1644 self._test_cache_key(fm, x, y, z) 1645 1646 @unittest.skipIf(not RUN_CUDA, "requires cuda") 1647 def test_cache_key_device(self): 1648 with FakeTensorMode() as fm: 1649 x = torch.randn(4, 3) 1650 y = torch.randn(4, 3) 1651 z = x.to(device="cuda") 1652 self._test_cache_key(fm, x, y, z) 1653 1654 def test_cache_key_memory_format(self): 1655 with FakeTensorMode() as fm: 1656 x = torch.randn(1, 2, 3, 4) 1657 y = torch.randn(1, 2, 3, 4) 1658 z = x.to(memory_format=torch.channels_last) 1659 self._test_cache_key(fm, x, y, z) 1660 1661 def test_cache_key_storage_offset(self): 1662 with FakeTensorMode() as fm: 1663 x = torch.randn(3)[1:] 1664 y = torch.randn(3)[1:] 1665 z = torch.randn(2) 1666 self._test_cache_key(fm, x, y, z) 1667 1668 def test_cache_key_requires_grad(self): 1669 with FakeTensorMode() as fm: 1670 x = torch.randn(4, 3) 1671 y = torch.randn(4, 3) 1672 z = torch.randn(4, 3, requires_grad=True) 1673 self._test_cache_key(fm, x, y, z) 1674 1675 def test_cache_key_is_conj(self): 1676 with FakeTensorMode() as fm: 1677 x = torch.randn(4, 3, dtype=torch.complex64) 1678 y = torch.randn(4, 3, dtype=torch.complex64) 1679 z = torch.randn(4, 3, dtype=torch.complex64) 1680 torch._C._set_conj(z, not z.is_conj()) 1681 self._test_cache_key(fm, x, y, z) 1682 1683 def test_cache_key_is_neg(self): 1684 with FakeTensorMode() as fm: 1685 x = torch.randn(4, 3, dtype=torch.complex64) 1686 y = torch.randn(4, 3, dtype=torch.complex64) 1687 z = torch.randn(4, 3, dtype=torch.complex64) 1688 torch._C._set_neg(z, not z.is_neg()) 1689 self._test_cache_key(fm, x, y, z) 1690 1691 def test_cache_key_is_inference(self): 1692 with torch.inference_mode(True): 1693 t = torch.randn(4, 3) 1694 with FakeTensorMode() as fm: 1695 x = torch.randn(4, 3) 1696 y = torch.randn(4, 3) 1697 z = fm.from_tensor(t) 1698 self._test_cache_key(fm, x, y, z) 1699 1700 def test_cache_key_constants(self): 1701 with FakeTensorMode() as fm: 1702 # Python hashes 1.0 to the same value as 1. Make sure the 1703 # cache key calculation differentiates them. 1704 self._test_cache_key(fm, 1.0, 1.0, 1) 1705 self._test_cache_key(fm, 0.0, 0.0, 0) 1706 1707 def assertHitsMisses(self, hits, misses): 1708 """ 1709 Helper to assert on the number of recorded hits and misses. 1710 """ 1711 info = FakeTensorMode.cache_info() 1712 self.assertEqual(info.hits, hits) 1713 self.assertEqual(info.misses, misses) 1714 1715 def assertBypasses(self, reason, count): 1716 """ 1717 Helper to assert on the number of recorded bypasses. 1718 """ 1719 info = FakeTensorMode.cache_info() 1720 if count > 0: 1721 self.assertIn(reason, info.bypasses) 1722 self.assertEqual(info.bypasses[reason], count) 1723 else: 1724 self.assertNotIn(reason, info.bypasses) 1725 1726 def test_cache_hit(self): 1727 """ 1728 Test that cache hit/miss counters are updated correctly. 1729 """ 1730 with FakeTensorMode(): 1731 x = torch.randn(4, 3) 1732 y = torch.randn(4, 3) 1733 1734 FakeTensorMode.cache_clear() 1735 self.assertHitsMisses(0, 0) 1736 res1 = x + y 1737 self.assertHitsMisses(0, 1) 1738 res2 = x + y 1739 self.assertHitsMisses(1, 1) 1740 1741 self.assertEqual( 1742 extract_tensor_metadata(res1), 1743 extract_tensor_metadata(res2), 1744 ) 1745 1746 def test_cache_bypass(self): 1747 """ 1748 Test that cache bypass counters are updated correctly. 1749 """ 1750 with FakeTensorMode(): 1751 x = torch.randn(1, 2) 1752 1753 FakeTensorMode.cache_clear() 1754 self.assertBypasses("inplace view", 0) 1755 1756 x.unsqueeze_(0) 1757 self.assertBypasses("inplace view", 1) 1758 1759 def test_cache_default_dtype(self): 1760 """ 1761 Test that the default dtype is respected when serving cached results. 1762 """ 1763 with FakeTensorMode(): 1764 x = torch.tensor([1, 2], dtype=torch.int32) 1765 torch.set_default_dtype(torch.float32) 1766 1767 FakeTensorMode.cache_clear() 1768 self.assertHitsMisses(0, 0) 1769 1770 y = x + 1.0 1771 self.assertEqual(y.dtype, torch.float32) 1772 self.assertHitsMisses(0, 1) 1773 1774 torch.set_default_dtype(torch.float16) 1775 y = x + 1.0 1776 self.assertEqual(y.dtype, torch.float16) 1777 self.assertHitsMisses(0, 2) 1778 1779 torch.set_default_dtype(torch.float32) 1780 y = x + 1.0 1781 self.assertEqual(y.dtype, torch.float32) 1782 self.assertHitsMisses(1, 2) 1783 1784 @unittest.skipIf(not RUN_CUDA, "requires cuda") 1785 def test_cache_default_device(self): 1786 """ 1787 Test that the default device is respected when serving cached results. 1788 """ 1789 with FakeTensorMode(): 1790 FakeTensorMode.cache_clear() 1791 self.assertHitsMisses(0, 0) 1792 1793 torch.set_default_device("cpu") 1794 x = torch.tensor([1, 2]) 1795 y = x + 1.0 1796 self.assertEqual(y.device.type, "cpu") 1797 self.assertHitsMisses(0, 1) 1798 1799 torch.set_default_device("cuda") 1800 x = torch.tensor([1, 2]) 1801 y = x + 1.0 1802 self.assertEqual(y.device.type, "cuda") 1803 self.assertHitsMisses(0, 2) 1804 1805 torch.set_default_device("cpu") 1806 x = torch.tensor([1, 2]) 1807 y = x + 1.0 1808 self.assertEqual(y.device.type, "cpu") 1809 self.assertHitsMisses(1, 2) 1810 1811 def test_cache_inplace_op(self): 1812 """ 1813 Test that inplace ops served from the cache correctly reference the 1814 input parameter. 1815 """ 1816 with FakeTensorMode(): 1817 x = torch.randn(1, 2) 1818 y = torch.randn(1, 2) 1819 1820 FakeTensorMode.cache_clear() 1821 self.assertHitsMisses(0, 0) 1822 1823 z = x.add_(y) 1824 self.assertHitsMisses(0, 1) 1825 self.assertEqual(id(x), id(z)) 1826 1827 w = x.add_(y) 1828 self.assertHitsMisses(1, 1) 1829 self.assertEqual(id(x), id(w)) 1830 1831 def test_cache_view_op(self): 1832 """ 1833 Test that view ops are handled correctly when served from the cache. 1834 """ 1835 with FakeTensorMode(): 1836 x1 = torch.ones(2, requires_grad=True).clone() 1837 x2 = torch.ones(2, requires_grad=True).clone() 1838 y2 = x2.view(-1) 1839 1840 # Test operating on a non-view tensor, then the same operation 1841 # on a view tensor. Assert that the view property is set correctly. 1842 z1 = x1.mul_(2) 1843 self.assertFalse(z1._is_view()) 1844 1845 z2 = y2.mul_(2) 1846 self.assertTrue(z2._is_view()) 1847 1848 # Now the other way around: first operate on a view tensor, then 1849 # the same operation on a non-view tensor. 1850 z2 = y2.mul_(2) 1851 self.assertTrue(z2._is_view()) 1852 1853 z1 = x1.mul_(2) 1854 self.assertFalse(z1._is_view()) 1855 1856 def test_cache_dispatch_key_set(self): 1857 """ 1858 Test that operations that change the dispatch key set bypass caching. 1859 """ 1860 with FakeTensorMode(): 1861 FakeTensorMode.cache_clear() 1862 self.assertBypasses("dispatch_key_set mismatch", 0) 1863 1864 x = torch._efficientzerotensor(3) 1865 self.assertTrue(x._is_zerotensor()) 1866 self.assertBypasses("dispatch_key_set mismatch", 1) 1867 1868 y = torch._efficientzerotensor(3) 1869 self.assertTrue(y._is_zerotensor()) 1870 self.assertBypasses("dispatch_key_set mismatch", 2) 1871 1872 def test_inference_mode(self): 1873 """ 1874 Test that caching handles inference mode correctly. 1875 """ 1876 with FakeTensorMode(): 1877 x = torch.randn(4, 3) 1878 y = torch.randn(4, 3) 1879 1880 FakeTensorMode.cache_clear() 1881 self.assertHitsMisses(0, 0) 1882 1883 # Expect a miss when the inference mode is different 1884 res1 = x + y 1885 with torch.inference_mode(): 1886 res2 = x + y 1887 1888 self.assertHitsMisses(0, 2) 1889 self.assertFalse(res1.is_inference()) 1890 self.assertTrue(res2.is_inference()) 1891 1892 # Second tries should see hits 1893 res3 = x + y 1894 1895 self.assertHitsMisses(1, 2) 1896 self.assertFalse(res3.is_inference()) 1897 self.assertEqual( 1898 extract_tensor_metadata(res1), 1899 extract_tensor_metadata(res3), 1900 ) 1901 1902 with torch.inference_mode(): 1903 res4 = x + y 1904 1905 self.assertHitsMisses(2, 2) 1906 self.assertTrue(res4.is_inference()) 1907 self.assertEqual( 1908 extract_tensor_metadata(res2), 1909 extract_tensor_metadata(res4), 1910 ) 1911 1912 1913if __name__ == "__main__": 1914 run_tests() 1915