# Owner(s): ["oncall: jit"] import re import torch import torch._lazy.metrics as metrics import torch._lazy.ts_backend from torch.testing._internal.common_utils import run_tests, TestCase torch._lazy.ts_backend.init() NODE_TYPE_PATTERN = re.compile(r", NodeType=[^\n]+") class LazyFuncionalizationTest(TestCase): def test_lazy_init_with_view(self): def f(device, reset_storage=False): torch.manual_seed(2023) if device == "lazy": metrics.reset() class Model(torch.nn.Module): def __init__(self) -> None: super().__init__() self.fc1 = torch.nn.Linear(4, 2, bias=False) def forward(self, x): return x @ self.fc1.weight.transpose(0, 1) with torch.device(device): model = Model() if device == "lazy": if reset_storage: torch._C._unsafe_reset_storage(model.fc1.weight) torch._lazy.mark_step() sync_tensors = metrics.counter_value("SyncedTensorsWithIR") if reset_storage: assert sync_tensors == 1 else: # There is an extra tensor being unnecessarily synced if # the functional storage is not reset. assert sync_tensors == 2 x = torch.ones(4) out = model(x) if device == "lazy": torch._lazy.mark_step() return out cpu_out = f("cpu") lazy_out_1 = f("lazy", reset_storage=False) lazy_out_2 = f("lazy", reset_storage=True) self.assertEqual(cpu_out, lazy_out_1.to("cpu")) self.assertEqual(cpu_out, lazy_out_2.to("cpu")) def test_data_assign(self): def text(lazyt): raw = torch._C._lazy._get_tensors_text([lazyt]) return NODE_TYPE_PATTERN.sub("", raw) origin = torch.rand(3, dtype=torch.float32) tensor = origin.to("lazy") self.assertExpectedInline( text(tensor), """\ IR { %0 = [Float[3]] lazy_tensors::device_data(), device=CPU0, ROOT=0 } """, ) # Modify the data-type of tensor, and assign it to 'data'. # This should update the inner tensor of FunctionalTensorWrapper, # changing the corresponding IR node. modified_tensor = tensor.to(torch.bfloat16) tensor.data = modified_tensor self.assertExpectedInline( text(tensor), """\ IR { %0 = [Float[3]] lazy_tensors::device_data(), device=CPU0 %1 = [BFloat16[3]] aten::_to_copy(%0), dtype=BFloat16, layout=null, device=null, pin_memory=null, non_blocking=0, memory_format=null, ROOT=0 } """, # noqa: B950 ) if __name__ == "__main__": run_tests()