# Owner(s): ["oncall: jit"]

import re

import torch
import torch._lazy.metrics as metrics
import torch._lazy.ts_backend
from torch.testing._internal.common_utils import run_tests, TestCase


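# The TorchScript-based lazy tensor backend must be initialized before any
# tensors are created on the "lazy" device.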
torch._lazy.ts_backend.init()

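# Strip the trailing ", NodeType=..." annotation from IR dumps so the expected
# inline strings below do not depend on it.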
NODE_TYPE_PATTERN = re.compile(r", NodeType=[^\n]+")


class LazyFunctionalizationTest(TestCase):
    def test_lazy_init_with_view(self):
        def f(device, reset_storage=False):
            torch.manual_seed(2023)

            if device == "lazy":
                metrics.reset()

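            # A minimal module whose forward pass consumes a view (transpose)
            # of its weight.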
            class Model(torch.nn.Module):
                def __init__(self) -> None:
                    super().__init__()
                    self.fc1 = torch.nn.Linear(4, 2, bias=False)

                def forward(self, x):
                    return x @ self.fc1.weight.transpose(0, 1)

            with torch.device(device):
                model = Model()

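                # On the lazy device, optionally reset the weight's functional
                # storage before flushing pending work with mark_step, and
                # check how many tensors were synced as a result.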
                if device == "lazy":
                    if reset_storage:
                        torch._C._unsafe_reset_storage(model.fc1.weight)

                    torch._lazy.mark_step()

                    sync_tensors = metrics.counter_value("SyncedTensorsWithIR")
                    if reset_storage:
                        assert sync_tensors == 1
                    else:
                        # There is an extra tensor being unnecessarily synced
                        # if the functional storage is not reset.
                        assert sync_tensors == 2

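                # Run a forward pass; on the lazy device, flush the resulting
                # computation with another mark_step before returning.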
                x = torch.ones(4)
                out = model(x)

                if device == "lazy":
                    torch._lazy.mark_step()

                return out

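        # The lazy results must match eager CPU execution regardless of
        # whether the functional storage was reset.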
        cpu_out = f("cpu")
        lazy_out_1 = f("lazy", reset_storage=False)
        lazy_out_2 = f("lazy", reset_storage=True)

        self.assertEqual(cpu_out, lazy_out_1.to("cpu"))
        self.assertEqual(cpu_out, lazy_out_2.to("cpu"))

    def test_data_assign(self):
        def text(lazyt):
            raw = torch._C._lazy._get_tensors_text([lazyt])
            return NODE_TYPE_PATTERN.sub("", raw)

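        # A freshly transferred tensor should be represented by a single
        # device_data IR node.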
        origin = torch.rand(3, dtype=torch.float32)
        tensor = origin.to("lazy")

        self.assertExpectedInline(
            text(tensor),
            """\
IR {
  %0 = [Float[3]] lazy_tensors::device_data(), device=CPU0, ROOT=0
}
""",
        )

        # Change the tensor's dtype and assign the result to its 'data'
        # attribute. This should update the inner tensor of the
        # FunctionalTensorWrapper, changing the corresponding IR node.
        modified_tensor = tensor.to(torch.bfloat16)
        tensor.data = modified_tensor

        self.assertExpectedInline(
            text(tensor),
            """\
IR {
  %0 = [Float[3]] lazy_tensors::device_data(), device=CPU0
  %1 = [BFloat16[3]] aten::_to_copy(%0), dtype=BFloat16, layout=null, device=null, pin_memory=null, non_blocking=0, memory_format=null, ROOT=0
}
""",  # noqa: B950
        )


if __name__ == "__main__":
    run_tests()