# Owner(s): ["oncall: jit"]

import unittest

from torch._lazy.ts_backend import init as init_ts_backend


init_ts_backend()
import copy
import dis
import inspect
import re
from contextlib import contextmanager

import torch
from torch import fx, nn
from torch._lazy import config
from torch._lazy.extract_compiled_graph import extract_compiled_graph


class ModuleConstScale(nn.Module):
    def forward(self, a):
        return a * 2


class ModuleSub(nn.Module):
    def forward(self, a, b):
        return a - b


class ModuleAddcmul(nn.Module):
    """
    The addcmul function takes an at::Scalar, which results in a special TSData
    containing a Scalar rather than a Tensor.
    """

    def forward(self, a, b, c):
        return torch.addcmul(a, b, c, value=5)


class ModuleReturnMulti(nn.Module):
    def forward(self, a, b):
        return (b + 1, a - 1)


# The default fx tracer will convert torch.randn to a constant. We may need
# a custom tracer.
# class ModuleEagerTensor(nn.Module):
#     def __init__(self) -> None:
#         super().__init__()
#
#     def forward(self, a):
#         b = torch.randn(2, 3, device="cpu")  # eager device
#         return a + b

# This module was planned to cover the case that an FX graph returns an eager
# tensor on the default device. It's harder than ModuleEagerTensor because
# we cannot just override the device argument to Lazy since there is no
# explicit device argument.
#
# Unfortunately, the default fx tracer converts the return value of the forward
# method to a constant. Commented out for now.
# class ModuleReturnEagerTensorOnDefaultDevice(nn.Module):
#     def __init__(self) -> None:
#         super().__init__()
#
#     def forward(self):
#         return torch.tensor((2, 3), dtype=torch.float32)
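
# A possible workaround (an assumption; not something this test exercises) is to
# keep the eager tensor creation as a graph node by moving it into a module-level
# helper and registering that helper as a leaf function with torch.fx.wrap:
#
#     def _rand_cpu():
#         return torch.randn(2, 3, device="cpu")
#
#     torch.fx.wrap("_rand_cpu")
#
# fx.symbolic_trace would then record a call_function node for _rand_cpu instead
# of baking the random tensor into the traced graph as a constant, which is what
# blocks the commented-out modules above.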


class ModuleReturnDupTensor(nn.Module):
    """
    Handle the corner case that the same tensor appears multiple times in the
    returned tuple. torchbench models like drq hit this corner case when
    running through torchdynamo.
    """

    def forward(self, a, b):
        c = a + b
        return a - b, c, a + 1, c


class ModuleInplaceUpdate(nn.Module):
    def forward(self, a, b):
        a.sub_(b)
        return b - 1, b + 1


@contextmanager
def force_fallback_ctx_mgr(fallback_op):
    # Temporarily force LTC to fall back to eager for the given op.
    oldconfig = config.get_force_fallback()
    config.set_force_fallback(fallback_op)
    try:
        yield None
    finally:
        config.set_force_fallback(oldconfig)


@contextmanager
def nop_ctx_mgr():
    try:
        yield None
    finally:
        pass


def gen_rand_args(mod):
    # Generate one random 2x3 tensor per parameter of mod.forward.
    args = []
    for _ in range(len(inspect.signature(mod.forward).parameters)):
        args.append(torch.randn(2, 3))
    return args


def allclose(expected, actual):
    def unwrap(cont):
        if isinstance(cont, (list, tuple)) and len(cont) == 1:
            return cont[0]
        return cont

    expected = unwrap(expected)
    actual = unwrap(actual)

    if isinstance(expected, torch.Tensor) and isinstance(actual, torch.Tensor):
        return torch.allclose(expected, actual)
    elif isinstance(expected, (tuple, list)) and isinstance(actual, (tuple, list)):
        return len(expected) == len(actual) and all(
            torch.allclose(a, b) for a, b in zip(expected, actual)
        )
    else:
        raise RuntimeError("Unexpected types")


def verify_reusing_compiled_graph(mod, exception_msg_pattern, ncase=10):
    args = gen_rand_args(mod)
    out = mod(*args)

    dis.dis(mod.forward)

    try:
        optimized_mod = extract_compiled_graph(fx.symbolic_trace(mod), args)
    except RuntimeError as e:
        if exception_msg_pattern is None:
            raise e  # reraise the exception
        exception_message = str(e)
        if not re.search(exception_msg_pattern, exception_message):
            raise RuntimeError(
                f"Exception message does not match the required pattern: {exception_message}"
            ) from e
        else:
            # We are done for the test case that expects an exception
            return

    if exception_msg_pattern is not None:
        raise RuntimeError(
            f"Expected an exception matching pattern {exception_msg_pattern}"
        )
    print("return value of optimized_mod", optimized_mod(*args))

    # check correctness
    failed_index = []
    for i in range(ncase):
        rand_args = gen_rand_args(mod)
        rand_args_copy = copy.deepcopy(rand_args)
        expected = mod(*rand_args)
        actual = optimized_mod(*rand_args_copy)

        if not allclose(expected, actual):
            print(f"Incorrect results. expected {expected}, actual {actual}")
            failed_index.append(i)
            continue

        # make sure arguments match after calling the model forward method to
        # handle in-place updates
        if not allclose(rand_args, rand_args_copy):
            print(
                f"Incorrect updated arguments. expected {rand_args}, actual {rand_args_copy}"
            )
            failed_index.append(i)
            continue

    if len(failed_index) > 0:
        raise RuntimeError(f"Failed {len(failed_index)}/{ncase} cases")


def maketest(module_cls, exception_msg_pattern=None, ctxmgr=None):
    # Build a test method that runs verify_reusing_compiled_graph on an
    # instance of module_cls, optionally inside the given context manager.
    def wrapper(self):
        nonlocal ctxmgr
        if not ctxmgr:
            ctxmgr = nop_ctx_mgr()
        with ctxmgr:
            verify_reusing_compiled_graph(module_cls(), exception_msg_pattern)

    return wrapper
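
# For illustration only (an assumption that mirrors test_ltc_fallback below, not
# additional test coverage): forcing an op to the LTC fallback while extracting
# a compiled graph is expected to fail with a RuntimeError naming the op, e.g.:
#
#     with force_fallback_ctx_mgr("aten::sub"):
#         args = gen_rand_args(ModuleSub())
#         extract_compiled_graph(fx.symbolic_trace(ModuleSub()), args)
#     # raises RuntimeError matching "fallback.*aten::sub"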


class OptimizeTest(unittest.TestCase):
    test_sub = maketest(ModuleSub)
    # Same as test_sub but forces aten::sub to fall back.
    # We expect an exception to be raised because of the LTC fallback.
    test_ltc_fallback = maketest(
        ModuleSub,
        exception_msg_pattern="fallback.*aten::sub",
        ctxmgr=force_fallback_ctx_mgr("aten::sub"),
    )
    test_const_scale = maketest(ModuleConstScale)
    test_addcmul = maketest(ModuleAddcmul)
    test_return_multi = maketest(ModuleReturnMulti)
    test_return_dup_tensor = maketest(ModuleReturnDupTensor)
    test_inplace_update = maketest(ModuleInplaceUpdate)
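

if __name__ == "__main__":
    # How this file is launched is not shown in the excerpt; falling back to the
    # standard unittest runner (already imported above) is an assumption. PyTorch
    # test files are usually driven via run_tests() from
    # torch.testing._internal.common_utils instead.
    unittest.main()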