# Owner(s): ["oncall: jit"]

import unittest

from torch._lazy.ts_backend import init as init_ts_backend


# Initialize the TorchScript-based lazy backend before importing the lazy
# submodules below.
init_ts_backend()
import copy
import dis
import inspect
import re
from contextlib import contextmanager

import torch
from torch import fx, nn
from torch._lazy import config
from torch._lazy.extract_compiled_graph import extract_compiled_graph


class ModuleConstScale(nn.Module):
    def forward(self, a):
        return a * 2


class ModuleSub(nn.Module):
    def forward(self, a, b):
        return a - b


class ModuleAddcmul(nn.Module):
    """
    The addcmul function takes an at::Scalar, which results in a special
    TSData containing a Scalar rather than a Tensor.
    """

    def forward(self, a, b, c):
        return torch.addcmul(a, b, c, value=5)


class ModuleReturnMulti(nn.Module):
    def forward(self, a, b):
        return (b + 1, a - 1)


# The default fx tracer will convert torch.randn to a constant, so we may
# need a custom tracer.
# class ModuleEagerTensor(nn.Module):
#     def __init__(self) -> None:
#         super().__init__()
#
#     def forward(self, a):
#         b = torch.randn(2, 3, device="cpu")  # eager device
#         return a + b
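
# A hypothetical workaround (untested sketch, not used by the tests below):
# registering a free function with torch.fx.wrap makes the tracer record a
# call_function node for it instead of executing it eagerly and baking the
# result into the graph as a constant. The names here are illustrative only.
#
# def _traced_randn(*shape, device=None):
#     return torch.randn(*shape, device=device)
#
# torch.fx.wrap("_traced_randn")  # must be called at module-level scope
#
# class ModuleEagerTensorWrapped(nn.Module):
#     def forward(self, a):
#         b = _traced_randn(2, 3, device="cpu")  # traced as a node, not folded
#         return a + b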

# The module was planned to cover the case where an FX graph returns an eager
# tensor on the default device. It is harder than ModuleEagerTensor because
# we cannot simply override the device argument to make the tensor lazy:
# there is no explicit device argument.
#
# Unfortunately, the default fx tracer converts the return value of the
# forward method to a constant, so the module is commented out for now.
# class ModuleReturnEagerTensorOnDefaultDevice(nn.Module):
#     def __init__(self) -> None:
#         super().__init__()
#
#     def forward(self):
#         return torch.tensor((2, 3), dtype=torch.float32)
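#
# The folding is easy to observe (hypothetical demo): tracing such a module
# yields a GraphModule whose graph holds the returned tensor as a baked-in
# attribute (a get_attr node such as _tensor_constant0) rather than a call
# to torch.tensor:
#
# gm = fx.symbolic_trace(ModuleReturnEagerTensorOnDefaultDevice())
# print(gm.graph)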


class ModuleReturnDupTensor(nn.Module):
    """
    Handle the corner case where the same tensor appears multiple times in
    the returned tuple. Torchbench models like drq hit this corner case when
    run through TorchDynamo.
    """

    def forward(self, a, b):
        c = a + b
        return a - b, c, a + 1, c


class ModuleInplaceUpdate(nn.Module):
    def forward(self, a, b):
        a.sub_(b)
        return b - 1, b + 1


@contextmanager
def force_fallback_ctx_mgr(fallback_op):
    """Force LTC to fall back for `fallback_op` (e.g. "aten::sub"),
    restoring the previous setting on exit."""
    oldconfig = config.get_force_fallback()
    config.set_force_fallback(fallback_op)
    try:
        yield None
    finally:
        config.set_force_fallback(oldconfig)


@contextmanager
def nop_ctx_mgr():
    # No-op context manager, equivalent to contextlib.nullcontext().
    yield None


def gen_rand_args(mod):
    # One random 2x3 tensor per parameter of mod.forward.
    args = []
    for _ in range(len(inspect.signature(mod.forward).parameters)):
        args.append(torch.randn(2, 3))
    return args


def allclose(expected, actual):
    def unwrap(cont):
        # Compare a singleton list/tuple by its sole element.
        if isinstance(cont, (list, tuple)) and len(cont) == 1:
            return cont[0]
        return cont

    expected = unwrap(expected)
    actual = unwrap(actual)

    if isinstance(expected, torch.Tensor) and isinstance(actual, torch.Tensor):
        return torch.allclose(expected, actual)
    elif isinstance(expected, (tuple, list)) and isinstance(actual, (tuple, list)):
        return len(expected) == len(actual) and all(
            torch.allclose(a, b) for a, b in zip(expected, actual)
        )
    else:
        raise RuntimeError(f"Unexpected types: {type(expected)}, {type(actual)}")


def verify_reusing_compiled_graph(mod, exception_msg_pattern, ncase=10):
    args = gen_rand_args(mod)
    out = mod(*args)

    dis.dis(mod.forward)

    try:
        optimized_mod = extract_compiled_graph(fx.symbolic_trace(mod), args)
    except RuntimeError as e:
        if exception_msg_pattern is None:
            raise  # reraise the exception
        exception_message = str(e)
        if not re.search(exception_msg_pattern, exception_message):
            raise RuntimeError(
                f"Exception message does not match the required pattern: {exception_message}"
            ) from e
        else:
            # We are done for the test case that expects an exception
            return

    if exception_msg_pattern is not None:
        raise RuntimeError(
            f"Expected an exception matching pattern {exception_msg_pattern}, but none was raised"
        )
    print("return value of optimized_mod", optimized_mod(*args))

    # check correctness
    failed_index = []
    for i in range(ncase):
        rand_args = gen_rand_args(mod)
        rand_args_copy = copy.deepcopy(rand_args)
        expected = mod(*rand_args)
        actual = optimized_mod(*rand_args_copy)

        if not allclose(expected, actual):
            print(f"Incorrect results. expected {expected}, actual {actual}")
            failed_index.append(i)
            continue

        # Make sure the arguments match after calling the model's forward
        # method, to handle in-place updates.
        if not allclose(rand_args, rand_args_copy):
            print(
                f"Incorrect updated arguments. expected {rand_args}, actual {rand_args_copy}"
            )
            failed_index.append(i)
            continue

    if len(failed_index) > 0:
        raise RuntimeError(f"Failed {len(failed_index)}/{ncase} cases")


def maketest(module_cls, exception_msg_pattern=None, ctxmgr=None):
    def wrapper(self):
        nonlocal ctxmgr
        if not ctxmgr:
            ctxmgr = nop_ctx_mgr()
        with ctxmgr:
            verify_reusing_compiled_graph(module_cls(), exception_msg_pattern)

    return wrapper


class OptimizeTest(unittest.TestCase):
    test_sub = maketest(ModuleSub)
    # Same as test_sub, but force aten::sub to fall back.
    # We expect an exception to be raised because of the LTC fallback.
    test_ltc_fallback = maketest(
        ModuleSub,
        exception_msg_pattern="fallback.*aten::sub",
        ctxmgr=force_fallback_ctx_mgr("aten::sub"),
    )
    test_const_scale = maketest(ModuleConstScale)
    test_addcmul = maketest(ModuleAddcmul)
    test_return_multi = maketest(ModuleReturnMulti)
    test_return_dup_tensor = maketest(ModuleReturnDupTensor)
    test_inplace_update = maketest(ModuleInplaceUpdate)

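
# Extending the suite follows the same pattern (hypothetical example, not one
# of the original test cases):
#
# class ModuleRelu(nn.Module):
#     def forward(self, a):
#         return torch.relu(a)
#
# and then, inside OptimizeTest:
#     test_relu = maketest(ModuleRelu)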