# Owner(s): ["oncall: jit"]

import re

import torch
import torch._lazy.metrics as metrics
import torch._lazy.ts_backend
from torch.testing._internal.common_utils import run_tests, TestCase


torch._lazy.ts_backend.init()

# Strips the ", NodeType=..." suffix (up to end of line) from IR dump text so
# the expected strings below do not have to spell it out.
NODE_TYPE_PATTERN = re.compile(r", NodeType=[^\n]+")


class LazyFuncionalizationTest(TestCase):
    """Tests for lazy tensors running through the TS backend."""

    def test_lazy_init_with_view(self):
        """Compare a small model's output on "cpu" against the "lazy" device.

        The lazy run is performed twice, with and without resetting the
        storage of the weight, and the ``SyncedTensorsWithIR`` counter is
        checked after the first ``mark_step`` in each case.
        """

        def f(device, reset_storage=False):
            # Same seed for every run so all devices get identical weights.
            torch.manual_seed(2023)

            if device == "lazy":
                # Start from clean counters so the check below only sees
                # what this run produced.
                metrics.reset()

            class Model(torch.nn.Module):
                def __init__(self) -> None:
                    super().__init__()
                    self.fc1 = torch.nn.Linear(4, 2, bias=False)

                def forward(self, x):
                    return x @ self.fc1.weight.transpose(0, 1)

            with torch.device(device):
                model = Model()

                if device == "lazy":
                    if reset_storage:
                        torch._C._unsafe_reset_storage(model.fc1.weight)

                    torch._lazy.mark_step()

                    sync_tensors = metrics.counter_value("SyncedTensorsWithIR")
                    # There is an extra tensor being unnecessarily synced if
                    # the functional storage is not reset.
                    expected_sync_tensors = 1 if reset_storage else 2
                    # unittest assertion instead of a bare ``assert``: it is
                    # not stripped under ``python -O`` and reports both values.
                    self.assertEqual(sync_tensors, expected_sync_tensors)

                x = torch.ones(4)
                out = model(x)

                if device == "lazy":
                    torch._lazy.mark_step()

                return out

        cpu_out = f("cpu")
        lazy_out_1 = f("lazy", reset_storage=False)
        lazy_out_2 = f("lazy", reset_storage=True)

        self.assertEqual(cpu_out, lazy_out_1.to("cpu"))
        self.assertEqual(cpu_out, lazy_out_2.to("cpu"))

    def test_data_assign(self):
        """Check the IR text of a lazy tensor before and after assigning ``.data``."""

        def text(lazyt):
            # Dump the IR for a single tensor and drop the NodeType suffixes.
            raw = torch._C._lazy._get_tensors_text([lazyt])
            return NODE_TYPE_PATTERN.sub("", raw)

        origin = torch.rand(3, dtype=torch.float32)
        tensor = origin.to("lazy")

        # NOTE(review): the leading indentation of the "%N = ..." lines in the
        # expected IR strings should be confirmed against the actual dump.
        self.assertExpectedInline(
            text(tensor),
            """\
IR {
  %0 = [Float[3]] lazy_tensors::device_data(), device=CPU0, ROOT=0
}
""",
        )

        # Modify the data-type of tensor, and assign it to 'data'.
        # This should update the inner tensor of FunctionalTensorWrapper,
        # changing the corresponding IR node.
        modified_tensor = tensor.to(torch.bfloat16)
        tensor.data = modified_tensor

        self.assertExpectedInline(
            text(tensor),
            """\
IR {
  %0 = [Float[3]] lazy_tensors::device_data(), device=CPU0
  %1 = [BFloat16[3]] aten::_to_copy(%0), dtype=BFloat16, layout=null, device=null, pin_memory=null, non_blocking=0, memory_format=null, ROOT=0
}
""",  # noqa: B950
        )


if __name__ == "__main__":
    run_tests()