# Owner(s): ["module: meta tensors"]


import contextlib
import copy
import dataclasses
import inspect
import itertools
import pickle
import unittest
import weakref
from unittest.mock import patch

import numpy as np
import torch
import torch._dynamo
import torch._functorch.config
import torch._prims as prims
import torch.testing._internal.optests as optests
import torch.utils._pytree as pytree

from torch import distributed as dist
from torch._C._functorch import _add_batch_dim, get_unwrapped, is_batchedtensor
from torch._dynamo.testing import make_test_cls_with_patches, rand_strided
from torch._guards import tracing, TracingContext
from torch._subclasses.fake_tensor import (
    DynamicOutputShapeException,
    extract_tensor_metadata,
    FakeTensor,
    FakeTensorConverter,
    FakeTensorMode,
    unset_fake_temporarily,
    UnsupportedOperatorException,
    _CacheKeyState,
)
from torch.fx.experimental.proxy_tensor import make_fx
from torch.fx.experimental.symbolic_shapes import (
    DimDynamic,
    free_symbols,
    ShapeEnv,
    ShapeEnvSettings,
    StatelessSymbolicContext,
    statically_known_true,
)
from torch.fx.passes.fake_tensor_prop import FakeTensorProp
from torch.testing import FileCheck
from torch.testing._internal.common_cuda import PLATFORM_SUPPORTS_FLASH_ATTENTION
from torch.testing._internal.common_device_type import (
    instantiate_device_type_tests,
    OpDTypes,
    ops,
)
from torch.testing._internal.common_utils import (
    instantiate_parametrized_tests,
    parametrize,
    run_tests,
    skipIfCrossRef,
    skipIfRocm,
    skipIfTorchDynamo,
    TemporaryFileName,
    TEST_WITH_TORCHDYNAMO,
    TestCase,
)

from torch.testing._internal.inductor_utils import GPU_TYPE
from torch.testing._internal.custom_op_db import custom_op_db
from torch.testing._internal.jit_utils import RUN_CUDA
from torch.utils._mode_utils import no_dispatch
from torch.utils._python_dispatch import TorchDispatchMode

aten = torch.ops.aten

torch._dynamo.config.fake_tensor_cache_enabled = True
torch._dynamo.config.fake_tensor_cache_crosscheck_enabled = True
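# The two config flags above run this whole file with the fake tensor
# dispatch cache enabled and cross-check cached results against an
# uncached recomputation.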


def expectedFailurePropagateRealTensors(fn):
    fn._expected_failure_propagate_real_tensors = True
    return fn


class FakeTensorTest(TestCase):
    def checkType(self, t, device_str, size):
        self.assertTrue(isinstance(t, FakeTensor))
        self.assertEqual(t.device.type, device_str)
        self.assertEqual(list(t.size()), size)

    @unittest.skipIf(not RUN_CUDA, "requires cuda")
    def test_cuda_initialized(self):
        # does not error
        with FakeTensorMode():
            p = torch.randn(4, 2, requires_grad=True, device="cuda")
            x = torch.randn(8, 4, device="cuda")
            y = torch.mm(x, p).square().sum()
            y.backward()

    def test_basic(self):
        x = torch.empty(2, 2, device="cpu")
        y = torch.empty(4, 2, 2, device="cpu")
        with FakeTensorMode() as mode:
            x = mode.from_tensor(x)
            y = mode.from_tensor(y)
            z = x + y
            self.assertEqual(z.shape, (4, 2, 2))
            self.assertEqual(z.device, torch.device("cpu"))
            self.assertTrue(isinstance(z, FakeTensor))

    def test_custom_op_fallback(self):
        from torch.library import impl, Library

        try:
            test_lib = Library("my_test_op", "DEF")  # noqa: TOR901
            test_lib.define("foo(Tensor self) -> Tensor")

            @impl(test_lib, "foo", "CPU")
            def foo_impl(self):
                return self.cos()

            x = torch.empty(2, 2, device="cpu")
            with self.assertRaisesRegex(
                UnsupportedOperatorException, "my_test_op.foo.default"
            ):
                with FakeTensorMode(allow_fallback_kernels=True) as mode:
                    x = mode.from_tensor(x)
                    torch.ops.my_test_op.foo(x)

        finally:
            test_lib._destroy()

    def test_parameter_instantiation(self):
        with FakeTensorMode():
            x = torch.rand([4])
            y = torch.nn.parameter.Parameter(x)
            self.assertTrue(isinstance(y, torch.nn.Parameter))

    @unittest.skipIf(not dist.is_available(), "requires distributed")
    def test_fsdp_flat_param(self):
        from torch.distributed.fsdp._flat_param import FlatParameter

        with FakeTensorMode() as m:
            data = torch.randn(2, 2)
            param = FlatParameter(data, requires_grad=True)
        self.assertIsInstance(param, FlatParameter)
        self.assertIsInstance(param, torch.nn.Parameter)
        self.assertIsInstance(param, FakeTensor)

    def test_non_parameter_grad(self):
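        # from_tensor should preserve requires_grad even for plain tensors
        # that are not nn.Parameters.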
        mode = FakeTensorMode()
        t = torch.rand([4], requires_grad=True)
        fake_t = mode.from_tensor(t)
        self.assertEqual(fake_t.requires_grad, t.requires_grad)

    @unittest.skipIf(
        TEST_WITH_TORCHDYNAMO, "isinstance check for FakeTensor won't work with compile"
    )
    @unittest.skipIf(not RUN_CUDA, "requires cuda")
    def test_index_cuda_with_cpu(self):
        with FakeTensorMode():
            x = torch.rand([2048], device="cuda")
            out = x[torch.zeros([36], dtype=torch.int64)]
            self.checkType(out, "cuda", [36])

    @unittest.skipIf(not RUN_CUDA, "requires cuda")
    def test_shape_take_not_device(self):
        with FakeTensorMode():
            x = torch.empty(1, device="cpu")
            y = torch.empty(8, 8, device="cuda")
            out = x.resize_as_(y)
            self.assertEqual(out.shape, (8, 8))
            self.assertEqual(out.device.type, "cpu")
            self.assertTrue(isinstance(out, FakeTensor))

    def test_repr(self):
        with FakeTensorMode():
            x = torch.empty(2, 2, device="cpu")
            self.assertEqual(repr(x), "FakeTensor(..., size=(2, 2))")
            x = torch.empty(2, 2, device="meta")
            self.assertEqual(repr(x), "FakeTensor(..., device='meta', size=(2, 2))")

    @unittest.skipIf(not RUN_CUDA, "requires cuda")
    def test_zero_dim(self):
        with FakeTensorMode() as mode:
            x = torch.tensor(0.0)
            y = torch.rand([4, 4], device="cuda")
            out = x + y
            self.assertEqual(out.shape, (4, 4))
            self.assertEqual(out.device, y.device)
            self.assertTrue(isinstance(out, FakeTensor))

    def test_nan_to_num(self):
        with FakeTensorMode():
            for dtype in [torch.float16, torch.float32]:
                x = torch.rand([4], dtype=dtype)
                y = torch.nan_to_num(x, nan=None)
                z = torch.nan_to_num(x, 0.0)
                self.assertEqual(dtype, y.dtype)
                self.assertEqual(dtype, z.dtype)

    @unittest.skipIf(not RUN_CUDA, "requires cuda")
    def test_throw(self):
        x = torch.tensor(0.0)  # TODO: tensor() errors
        with FakeTensorMode() as mode:
            x_conv = mode.from_tensor(x)
            y = torch.rand([4, 4], device="cuda")
            z = torch.rand([4, 4], device="cpu")
            self.assertRaises(Exception, lambda: torch.lerp(x_conv, y, z))

    @unittest.skipIf(not RUN_CUDA, "requires cuda")
    def test_type_as(self):
        with FakeTensorMode():
            x = torch.rand([16, 1], device="cpu")
            y = torch.rand([4, 4], device="cuda")
            out = x.type_as(y)
            self.assertEqual(out.device.type, "cuda")
            self.assertTrue(isinstance(out, FakeTensor))

    @unittest.skipIf(not RUN_CUDA, "requires cuda")
    def test_setitem(self):
        for device in ["cpu", "cuda"]:
            with FakeTensorMode():
                x = torch.rand([16, 1], device=device)
                x[..., 0] = 0

    @unittest.skipIf(not RUN_CUDA, "requires cuda")
    def test_device_inplace_copy(self):
        with FakeTensorMode():
            x = torch.rand([8, 8], device="cpu")
            y = torch.rand([8, 8], device="cuda")
            assert x.copy_(y).device.type == "cpu"
            assert y.copy_(x).device.type == "cuda"

    def test_fake_dispatch_keys(self):
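        # A fake CPU tensor should report the dispatch keys of the device it
        # simulates (CPU, autograd, autocast) rather than behaving like a
        # meta tensor.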
        with FakeTensorMode():
            x = torch.rand([4])
            f = (
                FileCheck()
                .check("CPU")
                .check("ADInplaceOrView")
                .check("AutogradCPU")
                .check("AutocastCPU")
            )
            f.run(torch._C._dispatch_key_set(x))

            with torch.inference_mode():
                x = torch.rand([4])
                y = x + x
                FileCheck().check("CPU").check("AutocastCPU").run(
                    torch._C._dispatch_key_set(y)
                )
                FileCheck().check_not("ADInplaceOrView").check_not("Autograd").run(
                    torch._C._dispatch_key_set(y)
                )

    def test_batch_tensor(self):
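        # Fake-ifying a functorch BatchedTensor should keep the batched
        # wrapper structure and fake-ify only the innermost tensor.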
        x = torch.rand((3, 4, 5))
        b = _add_batch_dim(x, 0, 0)
        mode = FakeTensorMode()
        fake_b = mode.from_tensor(b)
        prims.utils.compare_tensor_meta(b, fake_b, check_strides=True)

        b1 = _add_batch_dim(x, 1, 1)
        b2 = _add_batch_dim(b1, 0, 2)
        fake_b2 = mode.from_tensor(b2)
        prims.utils.compare_tensor_meta(b2, fake_b2, check_strides=True)
        self.assertTrue(is_batchedtensor(fake_b2))
        fake_b1 = get_unwrapped(fake_b2)
        self.assertTrue(is_batchedtensor(fake_b1))
        fake_tensor = get_unwrapped(fake_b1)
        self.assertIsInstance(fake_tensor, FakeTensor)

    def test_constructor(self):
        with FakeTensorMode():
            x = torch.rand([4, 4], device="cpu")

        self.assertTrue(isinstance(x, FakeTensor))
        self.assertTrue(x.device.type == "cpu")

    def test_mode(self):
        with FakeTensorMode():
            y = torch.rand([4], device="cpu")
            out = y + y

        self.assertTrue(isinstance(out, FakeTensor))

    def test_full(self):
        # Test torch.full returns tensor with correct dtype
        with torch._subclasses.CrossRefFakeMode():
            y = torch.full((4, 4), 1)

    def check_function_with_fake(self, fn):
        out = fn()
        with torch._subclasses.FakeTensorMode():
            out_fake = fn()

        for a, b in zip(pytree.tree_leaves(out), pytree.tree_leaves(out_fake)):
            if not isinstance(a, torch.Tensor):
                self.assertTrue(not isinstance(b, torch.Tensor))
                continue

            prims.utils.compare_tensor_meta(a, b, check_strides=True)

    @unittest.skipIf(not RUN_CUDA, "requires cuda")
    def test_non_kwarg_device(self):
        with FakeTensorMode():
            x = torch.rand([16, 1], device="cpu")
            y = x.to(torch.device("cpu"))
            self.assertIs(x, y)
            z = x.to(torch.device("cuda"))
            self.assertEqual(z.device.type, "cuda")

    def test_non_overlapping_stride_zero(self):
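        # A stride of 0 in a size-1 dimension is still non-overlapping and
        # dense; the .half() cast should preserve the layout metadata.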
        def foo():
            x = torch.empty_strided([1, 3, 427, 640], (0, 1, 1920, 3))
            return x.half()

        self.check_function_with_fake(foo)

    def test_fake_mode_error(self):
        x = torch.rand([4, 4])

        with self.assertRaisesRegex(Exception, "Please convert all Tensors"):
            with FakeTensorMode():
                y = x[0]

    @unittest.skipIf(
        TEST_WITH_TORCHDYNAMO, "isinstance check for FakeTensor won't work with compile"
    )
    def test_fake_grad_copy(self):
        x = torch.rand([4, 4], requires_grad=True)
        x.grad = torch.rand([4, 4])
        mode = FakeTensorMode()
        fake_x = mode.from_tensor(x)
        prims.utils.compare_tensor_meta(fake_x, x)
        prims.utils.compare_tensor_meta(fake_x.grad, x.grad)

        self.assertTrue(isinstance(fake_x.grad, FakeTensor))

    @unittest.skipIf(not RUN_CUDA, "requires cuda")
    def test_index_put_error(self):
        mode = FakeTensorMode()
        for context in [contextlib.nullcontext, lambda: mode]:
            with context():
                y = torch.randn(2, 2, 3)
                x = torch.randn(2, 2, 3).to("cuda")
                with self.assertRaises(RuntimeError):
                    x[[1, 1]] = y

                with self.assertRaises(RuntimeError):
                    torch.ops.aten.index_put(x, torch.tensor([1, 1], device="cuda"), y)

                # no error
                torch.ops.aten.index_put(
                    x, torch.tensor([1, 1], device="cuda"), torch.tensor(5.0)
                )
                torch.ops.aten.index_put_(
                    x, torch.tensor([1, 1], device="cuda"), torch.tensor(5.0)
                )

    @unittest.skipIf(not RUN_CUDA, "requires cuda")
    def test_like_constructor(self):
        with FakeTensorMode():
            x = torch.rand([4, 4])
            y = torch.ones_like(x)
            self.assertTrue(isinstance(y, FakeTensor))
            self.assertEqual(y.device.type, "cpu")
            z = torch.ones_like(x, device="cuda")
            self.assertTrue(isinstance(z, FakeTensor))
            self.assertEqual(z.device.type, "cuda")

    def test_binary_op_type_promotion(self):
        with FakeTensorMode():
            x = torch.empty([2, 2], dtype=torch.float)
            y = torch.empty([2, 2], dtype=torch.int64)
            out = x / y
            self.assertEqual(out.dtype, torch.float)
            self.assertEqual(out.device.type, "cpu")

    @unittest.skipIf(
        TEST_WITH_TORCHDYNAMO, "isinstance check for FakeTensor won't work with compile"
    )
    def test_from_numpy(self):
        with FakeTensorMode():
            x = torch.tensor(np.zeros([4, 4]))
            self.checkType(x, "cpu", [4, 4])

    def test_randperm(self):
        x = torch.randperm(10)
        y = torch.randperm(5, device="cpu")
        with FakeTensorMode():
            x1 = torch.randperm(10)
            prims.utils.compare_tensor_meta(x, x1)
            y1 = torch.randperm(5, device="cpu")
            prims.utils.compare_tensor_meta(y, y1)

    def test_print_in_fake_mode(self):
        x = torch.zeros(2)
        # does not fail
        with FakeTensorMode():
            out = str(x)
        assert "FakeTensor" not in out

    @unittest.skipIf(not RUN_CUDA, "requires cuda")
    def test_upsample_bilinear_small_channels(self):
        out = []
        mode = FakeTensorMode()
        for i, context in enumerate([contextlib.nullcontext, lambda: mode]):
            with context():
                arg0_1 = torch.empty_strided(
                    (3, 427, 640), (1, 1920, 3), dtype=torch.float32, device="cuda"
                )
                unsqueeze = torch.ops.aten.unsqueeze.default(arg0_1, 0)
                out.append(
                    torch.ops.aten.upsample_bilinear2d.default(
                        unsqueeze, [800, 1199], False
                    )
                )

        self.assertTrue(out[1].is_contiguous())
        self.checkMetaProps(out[0], out[1])

    @unittest.skipIf(not RUN_CUDA, "requires cuda")
    def test_cpu_fallback(self):
        with FakeTensorMode(allow_fallback_kernels=False):
            filters = torch.randn(8, 4, 3, 3).cuda()
            inputs = torch.randn(1, 4, 5, 5).cuda()
            out = torch.nn.functional.conv2d(inputs, filters, padding=1)
            self.assertEqual(out.device.type, "cuda")
            self.assertEqual(list(out.size()), [1, 8, 5, 5])

        with FakeTensorMode(allow_fallback_kernels=True):
            # intentionally bad inputs
            filters = torch.randn(8, 20, 3, 3).cuda()
            inputs = torch.randn(1, 7, 10, 5).cuda()
            with self.assertRaises(RuntimeError):
                torch.nn.functional.conv2d(inputs, filters, padding=1)

        with FakeTensorMode(allow_fallback_kernels=True):
            filters = torch.randn(8, 4, 3, 3).cuda()
            inputs = torch.randn(1, 4, 5, 5).cuda()

            out = torch.nn.functional.conv2d(inputs, filters, padding=1)
            self.assertEqual(out.device.type, "cuda")
            self.assertEqual(list(out.size()), [1, 8, 5, 5])

    @unittest.skipIf(not RUN_CUDA, "requires cuda")
    def test_out_multi_device(self):
        with FakeTensorMode():
            x = torch.rand([4])
            y = torch.rand([4], device="cuda")

            with self.assertRaisesRegex(Exception, "found.+two.+devices"):
                torch.sin(x, out=y)

            with self.assertRaisesRegex(Exception, "found.+two.+devices"):
                x.add_(y)

    @unittest.skipIf(
        TEST_WITH_TORCHDYNAMO, "isinstance check for FakeTensor won't work with compile"
    )
    @unittest.skipIf(not RUN_CUDA, "requires cuda")
    def test_normalize_device(self):
        with FakeTensorMode():
            x = torch.empty(1, device="cuda")
            y = torch.empty(1, device=f"cuda:{torch.cuda.current_device()}")
            out = x + y
        self.checkType(out, "cuda", [1])

    def test_recursive_invocation(self):
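        # Dispatch should still work while in_kernel_invocation is set, and
        # the flag should remain set afterwards.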
        mode = FakeTensorMode()
        with mode:
            x = torch.tensor(2)
            mode.in_kernel_invocation = True
            y = x + x
            self.assertTrue(mode.in_kernel_invocation)

    @unittest.skipIf(
        TEST_WITH_TORCHDYNAMO, "isinstance check for FakeTensor won't work with compile"
    )
    @skipIfRocm
    @parametrize(
        "allow_fallback_kernels",
        [False, True],
        lambda a: "with_fallback" if a else "without_fallback",
    )
    @unittest.skipIf(not RUN_CUDA, "requires cuda")
    def test_cudnn_rnn(self, allow_fallback_kernels):
        def fn(
            a0,
            b0,
            b1,
            b2,
            b3,
            b4,
            b5,
            b6,
            b7,
            b8,
            b9,
            b10,
            b11,
            b12,
            b13,
            b14,
            b15,
            a3,
            a4,
            a5,
        ):
            a1 = [
                b0,
                b1,
                b2,
                b3,
                b4,
                b5,
                b6,
                b7,
                b8,
                b9,
                b10,
                b11,
                b12,
                b13,
                b14,
                b15,
            ]
            return torch.ops.aten._cudnn_rnn(
                a0,
                a1,
                4,
                a3,
                a4,
                a5,
                2,
                2048,
                0,
                2,
                False,
                0.0,
                False,
                True,
                [],
                None,
            )

        mode = FakeTensorMode(allow_fallback_kernels=allow_fallback_kernels)
        for i, context in enumerate([contextlib.nullcontext, lambda: mode]):
            with context():
                inps1 = [
                    torch.randn([92, 8, 2048]).cuda(),
                    torch.randn([8192, 2048]).cuda(),
                    torch.randn([8192, 2048]).cuda(),
                    torch.randn([8192]).cuda(),
                    torch.randn([8192]).cuda(),
                    torch.randn([8192, 2048]).cuda(),
                    torch.randn([8192, 2048]).cuda(),
                    torch.randn([8192]).cuda(),
                    torch.randn([8192]).cuda(),
                    torch.randn([8192, 4096]).cuda(),
                    torch.randn([8192, 2048]).cuda(),
                    torch.randn([8192]).cuda(),
                    torch.randn([8192]).cuda(),
                    torch.randn([8192, 4096]).cuda(),
                    torch.randn([8192, 2048]).cuda(),
                    torch.randn([8192]).cuda(),
                    torch.randn([8192]).cuda(),
                    torch.randn([167837696]).cuda(),
                    torch.randn([4, 8, 2048]).cuda(),
                    torch.randn([4, 8, 2048]).cuda(),
                ]
                inps2 = list(inps1)  # copy so inps1 keeps its cx argument
                inps2[len(inps2) - 1] = None  # argument `cx` can be None

                for inps in [inps1, inps2]:
                    out = fn(*inps)
                    self.assertIs(out[4], inps[-3])
                    for ten in out:
                        if i == 1:
                            self.assertTrue(isinstance(ten, FakeTensor))
                        self.assertEqual(ten.device.type, "cuda")

    @unittest.skipIf(not RUN_CUDA, "requires cuda")
    def test_cuda_lstm(self):
        # Ensure CUDA (non-cuDNN) impl succeeds with fake tensors.
        with torch.backends.cudnn.flags(enabled=False):
            fake_tensor_mode = FakeTensorMode(allow_fallback_kernels=False)
            with fake_tensor_mode:
                N = 5
                L = 4
                H_in = 2
                hidden_size = 3
                proj_size = 2
                num_layers = 2
                bidir = False
                D = 2 if bidir else 1
                H_out = proj_size if proj_size > 0 else hidden_size

                lstm = torch.nn.LSTM(
                    input_size=H_in,
                    hidden_size=hidden_size,
                    num_layers=num_layers,
                    proj_size=proj_size,
                    batch_first=False,
                    bias=True,
                    bidirectional=bidir,
                    device="cuda",
                )

                h_0 = torch.randn((num_layers * D, N, H_out), device="cuda")
                c_0 = torch.randn((num_layers * D, N, hidden_size), device="cuda")
                inp = torch.randn((L, N, H_in), device="cuda")
                (output, (h_n, c_n)) = lstm(inp, (h_0, c_0))
                output.sum().backward()

                self.assertEqual(output.shape, (L, N, D * H_out))
                self.assertEqual(h_n.shape, (D * num_layers, N, H_out))
                self.assertEqual(c_n.shape, (D * num_layers, N, hidden_size))

    def test_data_dependent_operator(self):
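        # torch.nonzero has a data-dependent output shape that a fake tensor
        # cannot know, so without fallback kernels this must raise.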
        with FakeTensorMode(allow_fallback_kernels=False):
            x = torch.rand([10, 10])

            self.assertRaises(DynamicOutputShapeException, lambda: torch.nonzero(x))

    def test_parameter_view(self):
        x = torch.nn.Parameter(torch.randn(4))
        x_view = x.view(4)
        mode = FakeTensorMode()
        fake_x_view = mode.from_tensor(x_view)
        fake_x = mode.from_tensor(x)
        self.assertFalse(isinstance(fake_x_view, torch.nn.Parameter))
        self.assertTrue(isinstance(fake_x, torch.nn.Parameter))

    def test_tolist(self):
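        # With a ShapeEnv attached, .tolist() should succeed rather than
        # raise, since the data-dependent values can be represented
        # symbolically.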
        shape_env = ShapeEnv()
        with FakeTensorMode(allow_fallback_kernels=False, shape_env=shape_env):
            x = torch.rand([10])
            x.tolist()

    # Propagate real tensors doesn't work with fake-on-fake
    @expectedFailurePropagateRealTensors
    def test_same_shape_env_preserved(self):
        shape_env = ShapeEnv()
        mode1 = FakeTensorMode(shape_env=shape_env)
        t1 = mode1.from_tensor(
            torch.randn(10),
            symbolic_context=StatelessSymbolicContext(
                dynamic_sizes=[DimDynamic.DYNAMIC], constraint_sizes=[None]
            ),
        )
        mode2 = FakeTensorMode(shape_env=shape_env)
        t2 = mode2.from_tensor(t1)
        # t2.size(0) is still dynamic, even though we didn't pass DYNAMIC here
        self.assertIsNot(t2, t1)
        self.assertIs(t1.fake_mode, mode1)
        self.assertIs(t2.fake_mode, mode2)
        self.assertIs(t2.size(0).node.shape_env, t1.size(0).node.shape_env)
        self.assertEqual(str(t2.size(0)), str(t1.size(0)))

    # TODO: Support NJT.  There's also some funny business with dynamic shapes
    # which would need to be dealt with as well
    @expectedFailurePropagateRealTensors
    def test_jagged_fake_to_fake_preserved(self):
        from torch.nested._internal.nested_tensor import jagged_from_list

        S0, S1, S2 = 3, 4, 5
        D = 4
        a = torch.randn(S0, D, requires_grad=True, dtype=torch.float64)
        b = torch.randn(S1, D, requires_grad=True, dtype=torch.float64)
        c = torch.randn(S2, D, requires_grad=True, dtype=torch.float64)
        offsets = None
        jt, _ = jagged_from_list([a, b, c], offsets)
        shape_env = ShapeEnv()
        mode1 = FakeTensorMode(shape_env=shape_env)
        t1 = mode1.from_tensor(jt)
        mode2 = FakeTensorMode(shape_env=shape_env)
        t2 = mode2.from_tensor(t1)
        # It's not obvious that the invocation above makes it dynamic but it
        # does!
        self.assertTrue(free_symbols(t1.size()))
        self.assertIsNot(t2, t1)
        self.assertIs(t1.offsets().fake_mode, mode1)
        self.assertIs(t2.offsets().fake_mode, mode2)
        self.assertIs(t2.size(1).node.shape_env, t1.size(1).node.shape_env)
        self.assertEqual(str(t2.size(1)), str(t1.size(1)))

    def checkMetaProps(self, t1, t2):
        prims.utils.compare_tensor_meta(t1, t2, check_strides=True)

    @skipIfCrossRef
    def test_deepcopy(self):
        with FakeTensorMode() as mode:
            pass
        mod = torch.nn.BatchNorm2d(10)
        with torch._subclasses.fake_tensor.FakeCopyMode(mode):
            mod_copied = copy.deepcopy(mod)

        def check_copy(mod, mod_copied):
            for name, param in itertools.chain(
                mod.named_parameters(), mod.named_buffers()
            ):
                param_copied = getattr(mod_copied, name)
                self.checkMetaProps(param, param_copied)
                self.assertTrue(isinstance(param_copied, FakeTensor))
                self.assertEqual(
                    isinstance(param, torch.nn.Parameter),
                    isinstance(param_copied, torch.nn.Parameter),
                )
                self.assertEqual(param.requires_grad, param_copied.requires_grad)

        check_copy(mod, mod_copied)

        class ModuleNew(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.a = torch.rand([10, 2])
                self.b = self.a
                self.c = self.a[0]

        mod = ModuleNew()
        with torch._subclasses.fake_tensor.FakeCopyMode(mode):
            mod_copied = copy.deepcopy(mod)

        self.assertIs(mod_copied.a, mod_copied.b)
        self.assertEqual(mod_copied.b.storage()._cdata, mod_copied.a.storage()._cdata)

    @unittest.skipIf(
        TEST_WITH_TORCHDYNAMO, "isinstance check for FakeTensor won't work with compile"
    )
    @unittest.skipIf(not RUN_CUDA, "requires cuda")
    def test_new(self):
        with FakeTensorMode():
            a = torch.rand([16, 1])
            self.checkType(a.new(10, 10), "cpu", [10, 10])
            self.checkType(a.new([1, 2, 3, 4]), "cpu", [4])
            b = torch.rand([4, 4], device="cuda")
            self.checkType(b.new(device="cuda"), "cuda", [0])
            self.checkType(a.new(torch.rand([1])), "cpu", [1])

    @unittest.skipIf(
        TEST_WITH_TORCHDYNAMO, "isinstance check for FakeTensor won't work with compile"
    )
    def test_scalar_inputs(self):
        with FakeTensorMode():
            self.checkType(torch.div(3, 2), "cpu", [])
            ten = torch.zeros(2, dtype=torch.int32) * 2.0
            self.assertEqual(ten.dtype, torch.float)
            self.checkType(ten, "cpu", [2])

    @unittest.skipIf(
        TEST_WITH_TORCHDYNAMO, "isinstance check for FakeTensor won't work with compile"
    )
    def test_allow_meta(self):
        def run_meta():
            with FakeTensorMode():
                x = torch.rand([4], device="meta")
                return x + x

        self.checkType(run_meta(), "meta", [4])

        with patch.object(torch._functorch.config, "fake_tensor_allow_meta", False):
            self.assertRaises(Exception, run_meta)

    def test_embedding_bag_meta(self):
        def f():
            # This behavior was originally unintentional but we see people
            # relying on it
            embedding = torch.nn.EmbeddingBag(10, 3, mode="sum", device="meta")
            input = torch.tensor([1, 2, 4, 5, 4, 3, 2, 9], dtype=torch.long)
            offsets = torch.tensor([0, 4], dtype=torch.long)
            return embedding(input, offsets)

        real_out = f()
        with FakeTensorMode():
            fake_out = f()

        for r, f in zip(real_out, fake_out):
            self.assertEqual(r.size(), f.size())
            self.assertEqual(r.device, f.device)

    @unittest.skipIf(
        TEST_WITH_TORCHDYNAMO, "isinstance check for FakeTensor won't work with compile"
    )
    def test_mixed_real_and_fake_inputs(self):
        class _TestPattern(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.conv = torch.nn.Conv2d(1, 1, 1)
                self.bn = torch.nn.BatchNorm2d(1)

            def forward(self, input):
                running_std = torch.sqrt(self.bn.running_var + self.bn.eps)
                scale_factor = self.bn.weight / running_std
                weight_shape = [1] * len(self.conv.weight.shape)
                weight_shape[0] = -1
                bias_shape = [1] * len(self.conv.weight.shape)
                bias_shape[1] = -1
                scaled_weight = self.conv.weight * scale_factor.reshape(weight_shape)
                zero_bias = torch.zeros_like(self.conv.bias, dtype=input.dtype)
                conv = self.conv._conv_forward(input, scaled_weight, zero_bias)
                conv_orig = conv / scale_factor.reshape(bias_shape)
                conv_orig = conv_orig + self.conv.bias.reshape(bias_shape)
                conv = self.bn(conv_orig)
                return conv

        example_inputs = (torch.randn(1, 1, 3, 3),)
        mod = _TestPattern()
        with FakeTensorMode(allow_non_fake_inputs=True):
            out = mod(torch.randn(1, 1, 3, 3))
        self.checkType(out, "cpu", (1, 1, 3, 3))

    @unittest.skipIf(
        TEST_WITH_TORCHDYNAMO, "isinstance check for FakeTensor won't work with compile"
    )
    @unittest.skipIf(not RUN_CUDA, "requires cuda")
    def test_aten_copy_multi_device(self):
        with FakeTensorMode():
            x1 = torch.rand(4, device="cpu")
            x2 = torch.rand(4, device="cuda")
            copy1 = torch.ops.aten.copy.default(x1, x2)
            copy2 = torch.ops.aten.copy.default(x2, x1)
            out = torch.empty(4, device="cpu")
            torch.ops.aten.copy.out(x1, x2, out=out)
        self.checkType(copy1, "cpu", (4,))
        self.checkType(copy2, "cuda", (4,))
        self.checkType(out, "cpu", (4,))

    @unittest.skipIf(
        TEST_WITH_TORCHDYNAMO, "isinstance check for FakeTensor won't work with compile"
    )
    @unittest.skipIf(not RUN_CUDA, "requires cuda")
    def test_aten_index_multi_device(self):
        with FakeTensorMode():
            x1 = torch.rand(4, 4, device="cpu")
            x2 = torch.rand(4, 4, device="cuda")
            i1 = torch.tensor([0, 1], device="cuda")
            i2 = torch.tensor([0, 1], device="cpu")
            # NB: This one does not work: cuda indices not allowed on cpu
            # tensor
            # r1 = torch.ops.aten.index(x1, i1)
            r2 = torch.ops.aten.index(x2, i2)

            y1 = torch.rand(4, device="cpu")
            y2 = torch.rand(4, device="cuda")
            j1 = torch.tensor([2], device="cuda")
            j2 = torch.tensor([2], device="cpu")
            r3 = torch.ops.aten.index_put.default(x1, j1, y1)
            r4 = torch.ops.aten.index_put.default(x2, j2, y2)
        # self.checkType(r1, "cpu", ())
        self.checkType(r2, "cuda", ())
        self.checkType(r3, "cpu", (4, 4))
        self.checkType(r4, "cuda", (4, 4))

    @unittest.skipIf(
        TEST_WITH_TORCHDYNAMO, "isinstance check for FakeTensor won't work with compile"
    )
    @unittest.skipIf(not RUN_CUDA, "requires cuda")
    def test_aten_slice_scatter_multi_device(self):
        with FakeTensorMode():
            x1 = torch.rand(4, 4, device="cpu")
            y1 = torch.rand(2, 4, device="cuda")
            x2 = torch.rand(4, 4, device="cuda")
            y2 = torch.rand(2, 4, device="cpu")
            out = torch.empty(4, 4, device="cpu")
            r1 = torch.ops.aten.slice_scatter.default(x1, y1, start=2)
            r2 = torch.ops.aten.slice_scatter.default(x2, y2, start=2)
            r3 = torch.ops.aten.slice_scatter.out(x1, y1, out=out, start=2)
        self.checkType(r1, "cpu", (4, 4))
        self.checkType(r2, "cuda", (4, 4))
        self.checkType(r3, "cpu", (4, 4))
        self.checkType(out, "cpu", (4, 4))

    def test__adaptive_avg_pool2d_backward(self):
        with FakeTensorMode():
            grad_out = torch.rand(2, 3, 4, 4)
            inp = torch.rand(2, 3, 4, 4).to(memory_format=torch.channels_last)
            grad_in = torch.ops.aten._adaptive_avg_pool2d_backward(grad_out, inp)
            self.assertTrue(
                torch._prims_common.suggest_memory_format(grad_in)
                == torch.channels_last
            )

    def test_export_numpy(self):
        class MyNumpyModel(torch.nn.Module):
            def forward(self, input):
                input = input.numpy()
                return input + np.random.randn(*input.shape)

        with FakeTensorMode():
            ep = torch.export.export(MyNumpyModel(), args=(torch.randn(1000),))
            self.assertTrue(isinstance(ep, torch.export.ExportedProgram))

    def test_unsqueeze_copy(self):
        shape_env = ShapeEnv()
        t1 = torch.ones(2, 2, 768)
        with FakeTensorMode(shape_env=shape_env) as fake_mode:
            t = fake_mode.from_tensor(
                t1,
                symbolic_context=StatelessSymbolicContext(
                    dynamic_sizes=[
                        DimDynamic.DYNAMIC,
                        DimDynamic.STATIC,
                        DimDynamic.STATIC,
                    ],
                ),
            )

        self.assertEqual(t.shape[0], torch.ops.aten.unsqueeze_copy(t, 1).shape[0])

    def test_alias_call(self):
        fwAD = torch.autograd.forward_ad

        def f(x):
            return 4312491 * x

        with torch._subclasses.fake_tensor.FakeTensorMode():
            with fwAD.dual_level():
                x = torch.randn(3, device="cpu")
                y = torch.ones_like(x)
                dual = fwAD.make_dual(x, y)
                r = f(dual)

        self.assertIsInstance(r, FakeTensor)
        self.assertEqual(r.size(), [3])


instantiate_parametrized_tests(FakeTensorTest)


def make_propagate_real_tensors_cls(cls):
    cls = make_test_cls_with_patches(
        cls,
        "PropagateRealTensors",
        "_propagate_real_tensors",
        (torch._functorch.config, "fake_tensor_propagate_real_tensors", True),
        xfail_prop="_expected_failure_propagate_real_tensors",
        decorator=skipIfTorchDynamo("propagate_real_tensors affects Dynamo"),
    )
    cls.__file__ = __file__
    cls.__module__ = __name__
    globals()[cls.__name__] = cls


make_propagate_real_tensors_cls(FakeTensorTest)


class FakeTensorConstHandling(TestCase):
    def assertConst(self, *args):
        for arg in args:
            self.assertTrue(arg.constant is not None)

    def assertNotConst(self, *args):
        for arg in args:
            self.assertTrue(arg.constant is None)

    def test_simple(self):
        with FakeTensorMode():
            x = torch.tensor(4.0)
            self.assertEqual(x.item(), 4.0)

    def test_inplace_add(self):
        with FakeTensorMode():
            x = torch.tensor(4.0)
            y = x.add_(1)
            self.assertEqual(x.item(), 5.0)
            self.assertEqual(y.item(), 5.0)
            self.assertConst(x, y)

    def test_shared_storages(self):
        with FakeTensorMode():
            x = torch.tensor([4.0])
            y = x[:]

            self.assertEqual(x.storage()._cdata, y.storage()._cdata)
            self.assertEqual(x.constant.storage()._cdata, y.constant.storage()._cdata)

    def test_constant_invalidation(self):
        with FakeTensorMode():
            x = torch.tensor([1.0])
            self.assertConst(x)
            y = torch.rand([1])
            x.add_(y)
            self.assertNotConst(x)

    def test_inplace_view_invalidation(self):
        with FakeTensorMode():
            x = torch.tensor([1])
            self.assertConst(x)
            x.resize_([2])
            self.assertEqual(x.size(0), 2)
            self.assertNotConst(x)

    def test_fake_tensor_in_intlist_repro(self):
        def fn(tensors):
            max_size = torch.tensor([800, 1216], dtype=torch.int64)
            batch_shape = [len(tensors)] + list(tensors[0].shape[:-2]) + list(max_size)
            return tensors[0].new_full(batch_shape, 0.0)

        with self.assertRaises(
            torch._subclasses.fake_tensor.DataDependentOutputException
        ):
            with torch._subclasses.fake_tensor.FakeTensorMode():
                a = torch.randn(3, 800, 1199)
                b = torch.randn(3, 800, 800)
                inputs = [a, b]
                ref = fn(inputs)

    def test_fake_tensor_batch_norm_cpu(self):
        with torch._subclasses.CrossRefFakeMode():
            m = torch.nn.Sequential(
                torch.nn.BatchNorm2d(10),
                torch.nn.ReLU(),
            )
            m.eval()
            out = m(torch.randn([2, 10, 8, 8]))

    def test_shared_storage_invalidation(self):
        with FakeTensorMode():
            x = torch.tensor([1.0])
            y = x[:]
            self.assertConst(x, y)
            y.add_(torch.rand([1]))
            self.assertNotConst(x, y)

    def test_aliased_const_write(self):
        with FakeTensorMode():
            x = torch.tensor([1])
            y = x.expand([4])
            self.assertNotConst(y)
            y[0] = 1
            self.assertNotConst(x)

    def test_constant_propagate_through_functions(self):
        with FakeTensorMode():
            y = torch.div(4, 4, rounding_mode="trunc")
            self.assertConst(y)


make_propagate_real_tensors_cls(FakeTensorConstHandling)


def contains_type(type: torch.Type, maybe_contained_type: torch.Type):
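    # True if maybe_contained_type is a subtype of `type` or of any type
    # nested within it (e.g. the element type of a List[Tensor]).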
    return maybe_contained_type.isSubtypeOf(type) or any(
        contains_type(e, maybe_contained_type) for e in type.containedTypes()
    )


class FakeTensorOpInfoTest(TestCase):
    @ops(custom_op_db, dtypes=OpDTypes.any_one)
    def test_fake(self, device, dtype, op):
        sample_inputs_itr = op.sample_inputs(device, dtype, requires_grad=False)
        for sample_input in sample_inputs_itr:
            args = (sample_input.input,) + sample_input.args
            kwargs = sample_input.kwargs
            optests.fake_check(op, args, kwargs)


make_propagate_real_tensors_cls(FakeTensorOpInfoTest)
instantiate_device_type_tests(FakeTensorOpInfoTest, globals(), only_for=("cpu", "cuda"))
instantiate_device_type_tests(
    PropagateRealTensorsFakeTensorOpInfoTest, globals(), only_for=("cpu",)  # noqa: F821
)


class FakeTensorConverterTest(TestCase):
    def test_memoized_conversion_to_meta(self):
        x = torch.rand(2, 2, 2)
        mode = FakeTensorMode()
        self.assertTrue(mode.from_tensor(x) is mode.from_tensor(x))

    def test_memoized_conversion_from_meta(self):
        x = torch.rand(2, 2).to(device="meta")
        mode = FakeTensorMode()
        converter = mode.fake_tensor_converter
        self.assertTrue(
            converter.from_meta_and_device(mode, x, "cpu")
            is converter.from_meta_and_device(mode, x, "cpu")
        )

    def test_separate_tensor_storages_view(self):
        x = torch.rand(2, 2, 2)
        y = x[0]
        mode = FakeTensorMode()
        converter = mode.fake_tensor_converter
        x_conv = converter.from_real_tensor(mode, x)
        y_conv = converter.from_real_tensor(mode, y)
        self.assertEqual(torch._C._storage_id(x_conv), torch._C._storage_id(y_conv))

    @skipIfTorchDynamo("https://github.com/pytorch/torchdynamo/issues/1991")
    def test_separate_tensor_storages_non_view(self):
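        # Real tensors that share a storage should convert to fake tensors
        # sharing a fake storage, and the converter's weak memo tables
        # should drain as the originals die.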
        x = torch.rand(2, 2, 2)
        y = torch.rand(4, 2)
        y.set_(x.storage())
        mode = FakeTensorMode()
        converter = mode.fake_tensor_converter
        x_conv = converter.from_real_tensor(mode, x)
        y_conv = converter.from_real_tensor(mode, y)
        stor_id = torch._C._storage_id(x_conv)
        self.assertEqual(stor_id, torch._C._storage_id(y_conv))
        del x
        del x_conv
        self.assertEqual(len(converter.tensor_memo), 1)
        self.assertEqual(len(converter.meta_converter.storage_memo), 1)
        del y
        del y_conv
        self.assertEqual(len(converter.tensor_memo), 0)
        self.assertEqual(len(converter.meta_converter.storage_memo), 0)

    @skipIfTorchDynamo("https://github.com/pytorch/torchdynamo/issues/1991")
    def test_dead_weak_ref(self):
        x = torch.rand(2, 2, 2)
        y = x[0]
        mode = FakeTensorMode()
        converter = FakeTensorConverter()
        x_conv = converter.from_real_tensor(mode, x)
        x_conv_storage = x_conv.untyped_storage()
        del x_conv
        self.assertFalse(x in converter.tensor_memo)
        y_conv = converter.from_real_tensor(mode, y)
        self.assertIs(x_conv_storage, y_conv.untyped_storage())

    @skipIfTorchDynamo("https://github.com/pytorch/torchdynamo/issues/1991")
    def test_dead_key(self):
        x = torch.rand(2, 2, 2)
        mode = FakeTensorMode()
        converter = FakeTensorConverter()
        x_conv = converter.from_real_tensor(mode, x)
        self.assertEqual(len(converter.tensor_memo), 1)
        x_conv2 = converter.from_real_tensor(mode, x)
        assert x_conv2 is x_conv
        del x
        del x_conv
        del x_conv2
        self.assertEqual(len(converter.tensor_memo), 0)

    def test_no_active_mode(self):
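        # Fake tensors used outside any active mode should re-enter the mode
        # that owns them.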
        with FakeTensorMode() as mode:
            x = torch.empty(2, 2, device="cpu")
            y = torch.empty(2, 2, device="cpu")

        out = x + y
        self.assertEqual(mode, out.fake_mode)
        self.assertTrue(isinstance(out, FakeTensor))
        self.assertEqual(out.device.type, "cpu")

    def test_multiple_modes(self):
        t = torch.rand([4])
        t2 = torch.rand([4])
        with FakeTensorMode() as m:
            with FakeTensorMode() as m2:
                t_fake = m.from_tensor(t)
                t2_fake = m2.from_tensor(t2)

                with self.assertRaisesRegex(Exception, "Mixing fake modes"):
                    t_fake + t2_fake

    def test_separate_mode_error(self):
        with FakeTensorMode():
            x = torch.empty(2, 2, device="cpu")
        with FakeTensorMode():
            y = torch.empty(2, 2, device="cpu")
        self.assertRaises(Exception, lambda: x + y)

    @skipIfTorchDynamo("https://github.com/pytorch/torchdynamo/issues/1991")
    def test_no_ref_cycle(self):
        x = torch.rand([4])
        mode = FakeTensorMode()
        y = mode.from_tensor(x)
        self.assertEqual(len(mode.fake_tensor_converter.tensor_memo), 1)
        mode_weak = weakref.ref(mode)
        y_weak = weakref.ref(y)
        del mode
        del y
        assert mode_weak() is None
        assert y_weak() is None


make_propagate_real_tensors_cls(FakeTensorConverterTest)


class FakeTensorOperatorInvariants(TestCase):
    def get_aten_op(self, schema):
        namespace, name = schema.name.split("::")
        overload = schema.overload_name if schema.overload_name else "default"
        assert namespace == "aten"
        return getattr(getattr(torch.ops.aten, name), overload)

    def get_all_aten_schemas(self):
        for schema in torch._C._jit_get_all_schemas():
            namespace = schema.name.split("::")[0]
            if namespace != "aten":
                continue
            yield schema

    def test_non_kwarg_only_device(self):
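        # Any aten op that takes a device as a positional argument must be
        # registered in _device_not_kwarg_ops so fake tensor can handle it.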
        for schema in self.get_all_aten_schemas():
            ten_type = torch._C.TensorType.get()
            if not any(
                contains_type(arg.type, ten_type)
                for arg in itertools.chain(schema.arguments, schema.returns)
            ):
                continue

            opt_device = torch._C.OptionalType(torch._C.DeviceObjType.get())
            has_non_kwarg_device = any(
                not arg.kwarg_only and arg.type.isSubtypeOf(opt_device)
                for arg in schema.arguments
            )
            if has_non_kwarg_device:
                self.assertTrue(
                    self.get_aten_op(schema)
                    in torch._subclasses.fake_tensor._device_not_kwarg_ops
                )

    def test_tensor_constructors_all_have_kwarg_device(self):
        for schema in self.get_all_aten_schemas():
            op = self.get_aten_op(schema)
            if not torch._subclasses.fake_tensor._is_tensor_constructor(op):
                continue

            opt_device = torch._C.OptionalType(torch._C.DeviceObjType.get())
            has_kwarg_device = any(
                arg.kwarg_only and arg.type.isSubtypeOf(opt_device)
                for arg in schema.arguments
            )

            self.assertTrue(
                has_kwarg_device or op == torch.ops.aten._list_to_tensor.default
            )

    @unittest.expectedFailure
    def test_sparse_new(self):
        with FakeTensorMode():
            indices = torch.randn(1, 1, dtype=torch.int64)
            values = torch.randn(1)
            extra = (2,)
            sparse = torch.randn(1).to_sparse()
            # This used to segfault, now it does not, but it still raises an
            # error
            sparse2 = sparse.new(indices, values, extra)

    def test_tensor_new(self):
        with FakeTensorMode():
            x = torch.Tensor([1, 2, 3])
        self.assertIsInstance(x, FakeTensor)

    def test_like_ops(self):
        for schema in self.get_all_aten_schemas():
            if schema.name.endswith("_like"):
                op = self.get_aten_op(schema)
                self.assertIn(
                    op, torch._subclasses.fake_tensor._like_tensor_constructors
                )

    def test_str_storage(self):
        x = torch.zeros(3)
        with FakeTensorMode() as m:
            y = m.from_tensor(x)
            self.assertExpectedInline(
                str(x.storage()),
                """\
 0.0
 0.0
 0.0
[torch.storage.TypedStorage(dtype=torch.float32, device=cpu) of size 3]""",
            )
            self.assertExpectedInline(
                str(y.storage()),
                """\
...
[torch.storage.TypedStorage(dtype=torch.float32, device=meta) of size 3]""",
            )

        self.assertExpectedInline(
            str(y.storage()),
            """\
...
[torch.storage.TypedStorage(dtype=torch.float32, device=meta) of size 3]""",
        )

    # at::_embedding_bag has no OpInfo, and it returns extra tensors that
    # at::embedding_bag throws away
    def test_embedding_bag_private(self):
        args = [
            torch.ones(6, 1),
            torch.ones(6, dtype=torch.int64),
            torch.arange(2, dtype=torch.int64),
            False,
            2,  # mode = max
        ]

        ref_out = torch.ops.aten._embedding_bag(*args)
        with FakeTensorMode() as m:
            meta_args = [
                m.from_tensor(a) if isinstance(a, torch.Tensor) else a for a in args
            ]
            meta_out = torch.ops.aten._embedding_bag(*meta_args)

        self.assertEqual(len(ref_out), len(meta_out))
        for ref_o, meta_o in zip(ref_out, meta_out):
            self.assertEqual(ref_o.size(), meta_o.size())

    def test_cross_entropy_loss(self):
        inp = torch.randn(3, 5)
        target = torch.randint(5, (3,), dtype=torch.long)
        weight = torch.rand(5)
        fn = torch.nn.functional.cross_entropy
        for w in (weight, None):
            args = (inp, target, w)
            ref = fn(*args)
            with FakeTensorMode() as m:
                meta_args = [
                    m.from_tensor(a) if isinstance(a, torch.Tensor) else a for a in args
                ]
                meta_out = torch.nn.functional.cross_entropy(
                    *meta_args, label_smoothing=0.5
                )

            self.assertEqual(ref.size(), meta_out.size())

    @skipIfRocm
    @unittest.skipIf(
        not PLATFORM_SUPPORTS_FLASH_ATTENTION,
        "Does not support SDPA or pre-SM80 hardware",
    )
    def test_flash_attention(self):
        class Repro(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()

            def forward(self, arg1, arg2, arg3):
                torch.ops.aten._scaled_dot_product_flash_attention(
                    arg1, arg2, arg3, scale=0.17677669529663687
                )

        args_new = [
            [
                ((1, 48, 64, 64), (0, 4096, 64, 1), torch.float16, "cuda"),
                ((1, 48, 64, 64), (0, 4096, 64, 1), torch.float16, "cuda"),
                ((1, 48, 64, 64), (0, 4096, 64, 1), torch.float16, "cuda"),
            ],
            [
                ((4, 2, 16, 32), (1024, 512, 32, 1), torch.float16, "cuda"),
                ((4, 2, 16, 32), (1024, 512, 32, 1), torch.float16, "cuda"),
                ((4, 2, 16, 32), (1024, 512, 32, 1), torch.float16, "cuda"),
            ],
        ]
        for args_list in args_new:
            args = [
                rand_strided(bsz, num_heads, seq_len, head_dim)
                for (bsz, num_heads, seq_len, head_dim) in args_list
            ]
            try:
                with torch._subclasses.CrossRefFakeMode():
                    Repro()(*args)
            except RuntimeError as e:
                # We expect the cross ref to succeed for the first output but
                # to fail for the rng state; see Note [Seed and Offset]
                self.assertTrue("output[0]" not in str(e))
                self.assertTrue(
                    "found mismatched tensor metadata for output[6]: Devices cpu and cuda:0 are not equal!"
                    in str(e)
                )

    # IMPORTANT!!! Always run even if CUDA is not available
    def test_fake_gpu_no_init(self):
        # Skip under real-tensor propagation: it would run real CUDA
        # operations, which cannot work on a CPU-only runner.
        if torch._functorch.config.fake_tensor_propagate_real_tensors:
            return
        with FakeTensorMode():
            torch.empty(10, device=GPU_TYPE)
            torch.ones(10, device=GPU_TYPE)
            torch.zeros(10, device=GPU_TYPE)
            torch.rand(10, device=GPU_TYPE)
            torch.tensor(3.14, device=GPU_TYPE)
            torch.tensor([[3.14, 2], [1, 2]], device=GPU_TYPE)

    @skipIfRocm
    @unittest.skipIf(not RUN_CUDA, "requires cuda")
    def test_conv_c1_backward(self):
        class Repro(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()

            def forward(self, arg1, arg2, arg3):
                torch.ops.aten.convolution_backward.default(
                    arg1,
                    arg2,
                    arg3,
                    [1],
                    [1, 1],
                    [1, 1],
                    [1, 1],
                    False,
                    [0, 0],
                    1,
                    [True, True, False],
                )

        args_new = [
            ((16, 1, 128, 128), (16384, 16384, 128, 1), torch.float16, "cuda"),
            ((16, 64, 128, 128), (1048576, 1, 8192, 64), torch.float16, "cuda"),
            ((1, 64, 3, 3), (576, 9, 3, 1), torch.float16, "cuda"),
        ]
        args = [rand_strided(sh, st, dt, dev) for (sh, st, dt, dev) in args_new]

        with torch._subclasses.CrossRefFakeMode():
            Repro()(*args)

    def test_no_dispatch_with_like_function(self):
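        # no_dispatch() should bypass ambient dispatch modes entirely, so
        # CountingMode must not observe the zeros_like call.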
        class CountingMode(TorchDispatchMode):
            def __init__(self) -> None:
                self.count = 0

            def __torch_dispatch__(self, func, types, args=(), kwargs=None):
                self.count += 1
                return func(*args, **kwargs)

        with FakeTensorMode():
            x = torch.randn(2)
            with CountingMode() as mode:
                with no_dispatch():
                    torch.zeros_like(x)

        self.assertEqual(mode.count, 0)


make_propagate_real_tensors_cls(FakeTensorOperatorInvariants)


class FakeTensorPropTest(TestCase):
    def test_fake_tensor_prop_on_nn_module(self):
        class ToyNnModuleWithParameters(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.layer1 = torch.nn.Linear(4, 3)
                self.layer2 = torch.nn.Linear(3, 2)

            def forward(self, value):
                value = self.layer1(value)
                value = torch.relu(value)
                value = self.layer2(value)
                return value

        model = ToyNnModuleWithParameters()
        value = torch.randn(5, 4)
        # Convert nn.Module to GraphModule so that FakeTensorProp runs.
        graph_model = torch.fx.symbolic_trace(model, (value,))
        # The following block runs FakeTensorProp on graph_model w/ and w/o
        # the same FakeTensorMode.
        #
        # TODO(wschin): there should be an API to run FakeTensorProp for GraphModule
        # with parameters and buffers.
        with FakeTensorMode() as fake_tensor_mode:

            def to_fake_tensor(x):
                if isinstance(x, torch.Tensor) and not isinstance(x, FakeTensor):
                    return fake_tensor_mode.from_tensor(x)
                return x

            fake_parameters_and_buffers = {
                k: to_fake_tensor(v)
                for k, v in itertools.chain(
                    graph_model.named_parameters(), graph_model.named_buffers()
                )
            }
            with torch.nn.utils.stateless._reparametrize_module(
                graph_model, fake_parameters_and_buffers
            ):
                # This case uses the **same** fake tensor mode to
                #  1. create fake parameters and fake buffers, and
                #  2. run FakeTensorProp
                # The result should be correct.
                result = FakeTensorProp(graph_model, fake_tensor_mode).propagate(value)
                self.assertTrue(isinstance(result, FakeTensor))
                self.assertEqual(result.shape, (5, 2))
                # This case uses **different** fake tensor modes to
                #  1. create fake parameters and fake buffers, and
                #  2. run FakeTensorProp
                # The following code should fail.
                failed = False
                try:
                    FakeTensorProp(graph_model).propagate(value)
                except AssertionError:
                    # AssertionError: tensor's device must be `meta`, got cpu instead
                    failed = True
                self.assertTrue(failed)
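                # Hedged note (our reading of the failure): FakeTensorProp
                # with no explicit mode constructs a fresh FakeTensorMode
                # internally, and tensors faked under one mode cannot be
                # consumed under another, hence the AssertionError above.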

    def test_fake_tensor_prop_on_nn_module_with_optional_args(self):
        class OptionalArgumentInBetween(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.layer1 = torch.nn.Linear(4, 3)
                self.layer2 = torch.nn.Linear(3, 2)

            def forward(self, value, another_value=None, another_optional_value=None):
                # Mimic huggingface's `forward` methods which have several optional arguments.
                # For example, GPT accepts forward(self, input_ids, None, attention_mask, ...).
                # To apply FakeTensorProp, its from_real_tensor(...) needs to accept None.
                if another_value is None:
                    another_value = torch.rand_like(value)
                if another_optional_value is None:
                    another_optional_value = torch.rand_like(value)
                value = value + another_value + another_optional_value
                return value * value

        fake_mode = FakeTensorMode(
            allow_non_fake_inputs=True, allow_fallback_kernels=False
        )
        with fake_mode:
            model = OptionalArgumentInBetween()
            value = torch.randn(5, 4)
            another_optional_value = torch.randn(5, 4)
            graph_model = torch.fx.symbolic_trace(
                model, (value, None, another_optional_value)
            )
            FakeTensorProp(graph_model, fake_mode).propagate(
                value, None, another_optional_value
            )

    def test_unbacked_shape_realloc(self):
        def f(x):
            return x.nonzero()

        shape_env = ShapeEnv()
        fake_mode = FakeTensorMode(shape_env=shape_env)
        with fake_mode:
            value = torch.randn(5)
            gm = make_fx(f)(value)
        nonzero_nodes = [
            n for n in gm.graph.nodes if n.target is torch.ops.aten.nonzero.default
        ]
        self.assertEqual(len(nonzero_nodes), 1)
        self.assertIsInstance(nonzero_nodes[0].meta["val"].shape[0], torch.SymInt)
        u0 = nonzero_nodes[0].meta["val"].shape[0]
        FakeTensorProp(gm, fake_mode).propagate(value)
        u1 = nonzero_nodes[0].meta["val"].shape[0]
        # Check that this test is actually exercising something, i.e. that
        # FakeTensorProp really did trigger a reallocation.  If this assert
        # fails, it could be because we started memoizing the nnz count for
        # nonzero, which is nice in some sense (no reallocation) but not
        # helpful for this test, which checks what we do when we have to
        # reallocate.  If so, you need to make this example more complicated
        # (e.g., maybe have a nontrivial computation on the input before
        # feeding it into nonzero, or have some sort of randomness)
        self.assertIsNot(u0, u1)
        self.assertTrue(statically_known_true(u0 == u1))
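        # Hedged note: `nonzero` has a data-dependent output length, so under
        # a ShapeEnv its fake output is sized with an "unbacked" SymInt that
        # carries no concrete hint. Re-propagation allocates a fresh symbol
        # (hence u0 is not u1) but records their equality in the ShapeEnv,
        # which is why statically_known_true(u0 == u1) holds.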

    def test_torch_load_with_fake_mode(self):
        class TheModelClass(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.fc1 = torch.nn.Linear(5, 10)

            def forward(self, x):
                return self.fc1(x)

        with TemporaryFileName() as state_dict_file:
            # Create state_dict to be loaded later
            model = TheModelClass()
            torch.save(model.state_dict(), state_dict_file)

            fake_mode = FakeTensorMode()
            with fake_mode:
                torch.load(state_dict_file)  # scenario 1
                torch.load(state_dict_file, map_location="cpu")  # scenario 2
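
    # Hedged sketch (illustrative only, not part of the original suite):
    # tensors materialized by torch.load while a FakeTensorMode is active
    # come back as FakeTensors, so a checkpoint can be "loaded" without
    # allocating real storage:
    def _sketch_load_returns_fake(self, state_dict_file):
        with FakeTensorMode():
            state_dict = torch.load(state_dict_file)
            assert all(isinstance(v, FakeTensor) for v in state_dict.values())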


make_propagate_real_tensors_cls(FakeTensorPropTest)


class FakeTensorSerialization(TestCase):
    def test_serialization(self):
        x = torch.tensor([0], device="cpu")
        with FakeTensorMode():
            y = pickle.loads(pickle.dumps(x))
            self.assertEqual(type(y), FakeTensor)
            self.assertEqual(y.device.type, "meta")

            with unset_fake_temporarily():
                y = pickle.loads(pickle.dumps(x))
                self.assertEqual(x.device, y.device)

    def test_serialization_with_tracing(self):
        x = torch.tensor([0], device="cpu")
        with tracing(TracingContext(FakeTensorMode())):
            y = pickle.loads(pickle.dumps(x))
            self.assertEqual(x.device, y.device)


class FakeTensorDispatchCache(TestCase):
    def test_shape_env_settings(self):
        """
        Validate that any boolean settings in ShapeEnv are present in
        ShapeEnvSettings. This helps ensure that any new settings that might
        affect FakeTensor dispatch are included in the cache key calculation.
        If this test fails, consider updating ShapeEnvSettings or changing
        this test to omit checking for the new field.
        """
        init_sig = inspect.signature(ShapeEnv._init)
        args = [
            name
            for name, param in init_sig.parameters.items()
            if type(param.default) is bool
        ]

        settings = [f.name for f in dataclasses.fields(ShapeEnvSettings)]
        for arg in args:
            self.assertTrue(arg in settings)

    def _test_cache_key(self, fm, x, y, z):
        """
        Helper for all test_cache_key_* tests below. Assert that the
        cache keys for inputs x and y are the same, but z's is different.
        """
        func = aten.add.Tensor
        state = _CacheKeyState()
        key_x = fm._cache_key(state, func, [x], {})
        key_y = fm._cache_key(state, func, [y], {})
        key_z = fm._cache_key(state, func, [z], {})

        self.assertEqual(key_x, key_y)
        self.assertNotEqual(key_x, key_z)
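
    # Hedged note: _cache_key and _CacheKeyState are private FakeTensorMode
    # internals. The key folds in the op identity plus per-tensor metadata
    # (dtype, shape, stride, device, conjugate/negative bits, etc.), so any
    # metadata difference that could change the result of dispatch must make
    # the keys differ; the tests below probe one metadata field at a time.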

    def test_cache_key_dtype(self):
        with FakeTensorMode() as fm:
            x = torch.randn(4, 3, dtype=torch.float16)
            y = torch.randn(4, 3, dtype=torch.float16)
            z = x.to(dtype=torch.float32)
            self._test_cache_key(fm, x, y, z)

    def test_cache_key_shape(self):
        with FakeTensorMode() as fm:
            x = torch.randn(4, 3)
            y = torch.randn(4, 3)
            z = torch.randn(4, 2)
            self._test_cache_key(fm, x, y, z)

    def test_cache_key_stride(self):
        with FakeTensorMode() as fm:
            x = torch.randn(4, 2)
            y = torch.randn(4, 2)
            z = x.as_strided((4, 2), (1, 2))
            self._test_cache_key(fm, x, y, z)

    @unittest.skipIf(not RUN_CUDA, "requires cuda")
    def test_cache_key_device(self):
        with FakeTensorMode() as fm:
            x = torch.randn(4, 3)
            y = torch.randn(4, 3)
            z = x.to(device="cuda")
            self._test_cache_key(fm, x, y, z)

    def test_cache_key_memory_format(self):
        with FakeTensorMode() as fm:
            x = torch.randn(1, 2, 3, 4)
            y = torch.randn(1, 2, 3, 4)
            z = x.to(memory_format=torch.channels_last)
            self._test_cache_key(fm, x, y, z)

    def test_cache_key_storage_offset(self):
        with FakeTensorMode() as fm:
            x = torch.randn(3)[1:]
            y = torch.randn(3)[1:]
            z = torch.randn(2)
            self._test_cache_key(fm, x, y, z)

    def test_cache_key_requires_grad(self):
        with FakeTensorMode() as fm:
            x = torch.randn(4, 3)
            y = torch.randn(4, 3)
            z = torch.randn(4, 3, requires_grad=True)
            self._test_cache_key(fm, x, y, z)

    def test_cache_key_is_conj(self):
        with FakeTensorMode() as fm:
            x = torch.randn(4, 3, dtype=torch.complex64)
            y = torch.randn(4, 3, dtype=torch.complex64)
            z = torch.randn(4, 3, dtype=torch.complex64)
            torch._C._set_conj(z, not z.is_conj())
            self._test_cache_key(fm, x, y, z)

    def test_cache_key_is_neg(self):
        with FakeTensorMode() as fm:
            x = torch.randn(4, 3, dtype=torch.complex64)
            y = torch.randn(4, 3, dtype=torch.complex64)
            z = torch.randn(4, 3, dtype=torch.complex64)
            torch._C._set_neg(z, not z.is_neg())
            self._test_cache_key(fm, x, y, z)

    def test_cache_key_is_inference(self):
        with torch.inference_mode(True):
            t = torch.randn(4, 3)
        with FakeTensorMode() as fm:
            x = torch.randn(4, 3)
            y = torch.randn(4, 3)
            z = fm.from_tensor(t)
            self._test_cache_key(fm, x, y, z)

    def test_cache_key_constants(self):
        with FakeTensorMode() as fm:
            # Python hashes 1.0 to the same value as 1. Make sure the
            # cache key calculation differentiates them.
            self._test_cache_key(fm, 1.0, 1.0, 1)
            self._test_cache_key(fm, 0.0, 0.0, 0)

    def assertHitsMisses(self, hits, misses):
        """
        Helper to assert on the number of recorded hits and misses.
        """
        info = FakeTensorMode.cache_info()
        self.assertEqual(info.hits, hits)
        self.assertEqual(info.misses, misses)

    def assertBypasses(self, reason, count):
        """
        Helper to assert on the number of recorded bypasses.
        """
        info = FakeTensorMode.cache_info()
        if count > 0:
            self.assertIn(reason, info.bypasses)
            self.assertEqual(info.bypasses[reason], count)
        else:
            self.assertNotIn(reason, info.bypasses)

    def test_cache_hit(self):
        """
        Test that cache hit/miss counters are updated correctly.
        """
        with FakeTensorMode():
            x = torch.randn(4, 3)
            y = torch.randn(4, 3)

            FakeTensorMode.cache_clear()
            self.assertHitsMisses(0, 0)
            res1 = x + y
            self.assertHitsMisses(0, 1)
            res2 = x + y
            self.assertHitsMisses(1, 1)

            self.assertEqual(
                extract_tensor_metadata(res1),
                extract_tensor_metadata(res2),
            )
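
    # Hedged sketch (illustrative only, not part of the original suite):
    # FakeTensorMode.cache_info() exposes the counters asserted above, so the
    # hit/miss accounting can also be inspected directly:
    def _sketch_cache_info(self):
        with FakeTensorMode():
            FakeTensorMode.cache_clear()
            a = torch.randn(2, 2)
            b = torch.randn(2, 2)
            a + b  # miss: first time this (op, metadata) key is seen
            a + b  # hit: same key, metadata served from the cache
            info = FakeTensorMode.cache_info()
            assert (info.hits, info.misses) == (1, 1)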

    def test_cache_bypass(self):
        """
        Test that cache bypass counters are updated correctly.
        """
        with FakeTensorMode():
            x = torch.randn(1, 2)

            FakeTensorMode.cache_clear()
            self.assertBypasses("inplace view", 0)

            x.unsqueeze_(0)
            self.assertBypasses("inplace view", 1)

    def test_cache_default_dtype(self):
        """
        Test that the default dtype is respected when serving cached results.
        """
        with FakeTensorMode():
            x = torch.tensor([1, 2], dtype=torch.int32)
            torch.set_default_dtype(torch.float32)

            FakeTensorMode.cache_clear()
            self.assertHitsMisses(0, 0)

            y = x + 1.0
            self.assertEqual(y.dtype, torch.float32)
            self.assertHitsMisses(0, 1)

            torch.set_default_dtype(torch.float16)
            y = x + 1.0
            self.assertEqual(y.dtype, torch.float16)
            self.assertHitsMisses(0, 2)

            torch.set_default_dtype(torch.float32)
            y = x + 1.0
            self.assertEqual(y.dtype, torch.float32)
            self.assertHitsMisses(1, 2)

    @unittest.skipIf(not RUN_CUDA, "requires cuda")
    def test_cache_default_device(self):
        """
        Test that the default device is respected when serving cached results.
        """
        with FakeTensorMode():
            FakeTensorMode.cache_clear()
            self.assertHitsMisses(0, 0)

            torch.set_default_device("cpu")
            x = torch.tensor([1, 2])
            y = x + 1.0
            self.assertEqual(y.device.type, "cpu")
            self.assertHitsMisses(0, 1)

            torch.set_default_device("cuda")
            x = torch.tensor([1, 2])
            y = x + 1.0
            self.assertEqual(y.device.type, "cuda")
            self.assertHitsMisses(0, 2)

            torch.set_default_device("cpu")
            x = torch.tensor([1, 2])
            y = x + 1.0
            self.assertEqual(y.device.type, "cpu")
            self.assertHitsMisses(1, 2)

    def test_cache_inplace_op(self):
        """
        Test that inplace ops served from the cache correctly reference the
        input parameter.
        """
        with FakeTensorMode():
            x = torch.randn(1, 2)
            y = torch.randn(1, 2)

            FakeTensorMode.cache_clear()
            self.assertHitsMisses(0, 0)

            z = x.add_(y)
            self.assertHitsMisses(0, 1)
            self.assertEqual(id(x), id(z))

            w = x.add_(y)
            self.assertHitsMisses(1, 1)
            self.assertEqual(id(x), id(w))
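            # Hedged note: on a hit for an in-place op the cache does not
            # fabricate a fresh output; it records that the result aliases
            # the first argument and returns that same object, which is why
            # id(x) == id(w) holds even for the cached second call.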

    def test_cache_view_op(self):
        """
        Test that view ops are handled correctly when served from the cache.
        """
        with FakeTensorMode():
            x1 = torch.ones(2, requires_grad=True).clone()
            x2 = torch.ones(2, requires_grad=True).clone()
            y2 = x2.view(-1)

            # Test operating on a non-view tensor, then the same operation
            # on a view tensor. Assert that the view property is set correctly.
            z1 = x1.mul_(2)
            self.assertFalse(z1._is_view())

            z2 = y2.mul_(2)
            self.assertTrue(z2._is_view())

            # Now the other way around: first operate on a view tensor, then
            # the same operation on a non-view tensor.
            z2 = y2.mul_(2)
            self.assertTrue(z2._is_view())

            z1 = x1.mul_(2)
            self.assertFalse(z1._is_view())

    def test_cache_dispatch_key_set(self):
        """
        Test that operations that change the dispatch key set bypass caching.
        """
        with FakeTensorMode():
            FakeTensorMode.cache_clear()
            self.assertBypasses("dispatch_key_set mismatch", 0)

            x = torch._efficientzerotensor(3)
            self.assertTrue(x._is_zerotensor())
            self.assertBypasses("dispatch_key_set mismatch", 1)

            y = torch._efficientzerotensor(3)
            self.assertTrue(y._is_zerotensor())
            self.assertBypasses("dispatch_key_set mismatch", 2)
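            # Hedged note: _efficientzerotensor outputs carry the ZeroTensor
            # dispatch key, which the cached-metadata path cannot reproduce,
            # so every such call bypasses the cache; the second call above
            # counts as another bypass, not a hit.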

    def test_inference_mode(self):
        """
        Test that caching handles inference mode correctly.
        """
        with FakeTensorMode():
            x = torch.randn(4, 3)
            y = torch.randn(4, 3)

            FakeTensorMode.cache_clear()
            self.assertHitsMisses(0, 0)

            # Expect a miss when the inference mode is different
            res1 = x + y
            with torch.inference_mode():
                res2 = x + y

            self.assertHitsMisses(0, 2)
            self.assertFalse(res1.is_inference())
            self.assertTrue(res2.is_inference())

            # Second attempts should see hits
            res3 = x + y

            self.assertHitsMisses(1, 2)
            self.assertFalse(res3.is_inference())
            self.assertEqual(
                extract_tensor_metadata(res1),
                extract_tensor_metadata(res3),
            )

            with torch.inference_mode():
                res4 = x + y

            self.assertHitsMisses(2, 2)
            self.assertTrue(res4.is_inference())
            self.assertEqual(
                extract_tensor_metadata(res2),
                extract_tensor_metadata(res4),
            )


if __name__ == "__main__":
    run_tests()