# Owner(s): ["oncall: pt2"] import tempfile import unittest import torch from torch._prims.debug_prims import load_tensor_reader from torch._subclasses.fake_tensor import FakeTensor, FakeTensorMode from torch.multiprocessing.reductions import StorageWeakRef from torch.testing._internal.common_device_type import instantiate_device_type_tests from torch.testing._internal.common_utils import ( IS_WINDOWS, run_tests, skipIfRocm, TestCase, ) from torch.utils._content_store import ( ContentStoreReader, ContentStoreWriter, hash_storage, ) @unittest.skipIf(IS_WINDOWS, "Test case not supported on Windows") class TestContentStore(TestCase): def test_basic(self, device): # setup test data x = torch.randn(4, device=device) y = torch.randn(6, device=device) z = x.view(2, 2) # start writing with tempfile.TemporaryDirectory() as loc: writer = ContentStoreWriter(loc) writer.write_tensor("x", x) writer.write_tensor("y", y) writer.write_tensor("z", z) # do some mutation that is VC UNTRACKED x.data.add_(1) writer.write_tensor("x2", x) writer.write_tensor("y2", y) writer.write_tensor("z2", z) del writer reader = ContentStoreReader(loc) n_x = reader.read_tensor("x") n_y = reader.read_tensor("y") n_z = reader.read_tensor("z") self.assertEqual(n_x + 1, x) self.assertEqual(n_y, y) self.assertEqual(n_z + 1, z) self.assertEqual( StorageWeakRef(n_x.untyped_storage()), StorageWeakRef(n_z.untyped_storage()), ) n_x2 = reader.read_tensor("x2") n_y2 = reader.read_tensor("y2") n_z2 = reader.read_tensor("z2") self.assertEqual(n_x2, x) self.assertEqual(n_y2, y) self.assertEqual(n_z2, z) self.assertEqual( StorageWeakRef(n_y2.untyped_storage()), StorageWeakRef(n_y.untyped_storage()), ) def test_scalar(self, device): # Should not raise an error hash_storage(torch.tensor(2, device=device).untyped_storage()) @torch._dynamo.config.patch(cache_size_limit=1) def test_repeated_hash(self, device): # Test that repeated hashing doesn't trigger a recompile in dynamo # If it does, we will execute prims.xor_sum in eager which fails for _ in range(4): hash_storage(torch.tensor(2, device=device).untyped_storage()) @skipIfRocm def test_load_tensor(self, device): with tempfile.TemporaryDirectory() as loc: writer = ContentStoreWriter(loc) x = torch.randn(4, device=device) def same_meta_as_x(t): self.assertEqual(t.size(), x.size()) self.assertEqual(t.stride(), x.stride()) self.assertEqual(t.dtype, x.dtype) self.assertEqual(t.device, x.device) writer.write_tensor("x", x) with load_tensor_reader(loc): x2 = torch.ops.debugprims.load_tensor.default( "x", (4,), (1,), dtype=torch.float32, device=device ) self.assertEqual(x, x2) x3 = torch.ops.debugprims.load_tensor.default( "x", (4,), (1,), dtype=torch.float32, device=device ) self.assertEqual(x, x3) # Must not alias! self.assertNotEqual( StorageWeakRef(x.untyped_storage()), StorageWeakRef(x2.untyped_storage()), ) self.assertNotEqual( StorageWeakRef(x2.untyped_storage()), StorageWeakRef(x3.untyped_storage()), ) # Check fake tensor mode works too with FakeTensorMode(): x4 = torch.ops.debugprims.load_tensor.default( "x", (4,), (1,), dtype=torch.float32, device=device ) self.assertIsInstance(x4, FakeTensor) same_meta_as_x(x4) # Check fp64 works x5 = torch.ops.debugprims.load_tensor.default( "x", (4,), (1,), dtype=torch.float64, device=device ) self.assertEqual(x5.float(), x) self.assertEqual(x5.dtype, torch.float64) x6 = torch.ops.debugprims.load_tensor.default( "x", (4,), (1,), dtype=torch.float32, device=device ) same_meta_as_x(x6) instantiate_device_type_tests(TestContentStore, globals()) if __name__ == "__main__": run_tests()