1# Owner(s): ["module: dynamo"] 2from unittest.mock import patch 3 4import torch 5import torch._dynamo.test_case 6import torch._dynamo.testing 7from torch._C import ( 8 _len_torch_function_stack, 9 _pop_torch_function_stack, 10 _push_on_torch_function_stack, 11) 12from torch.overrides import _get_current_function_mode_stack, BaseTorchFunctionMode 13from torch.utils._device import DeviceContext 14from torch.utils._python_dispatch import TorchDispatchMode 15 16 17class TorchDispatchModeTests(torch._dynamo.test_case.TestCase): 18 @classmethod 19 def setUpClass(cls): 20 super().setUpClass() 21 22 @classmethod 23 def tearDownClass(cls): 24 super().tearDownClass() 25 26 def test_skip_torch_dispatch_modes(self): 27 class RewriteAddToMul(TorchDispatchMode): 28 def __torch_dispatch__(self, func, types, args=(), kwargs=None): 29 if func is torch.ops.aten.add.Tensor: 30 func = torch.ops.aten.mul.Tensor 31 return func(*args, **kwargs) 32 33 def fn(x): 34 return x + x 35 36 cnt = torch._dynamo.testing.CompileCounter() 37 38 x = torch.tensor([3.0]) 39 with RewriteAddToMul(): 40 eager_res = fn(x) 41 compiled_res = torch._dynamo.optimize(cnt)(fn)(x) 42 43 self.assertEqual(eager_res, compiled_res) 44 self.assertEqual(cnt.frame_count, 0) 45 46 47class TorchFunctionModeTests(torch._dynamo.test_case.TestCase): 48 @classmethod 49 def setUpClass(cls): 50 cls.default_device_old = torch.get_default_device() 51 super().setUpClass() 52 53 @classmethod 54 def tearDownClass(cls): 55 torch.set_default_device(cls.default_device_old) 56 super().tearDownClass() 57 58 def setUp(self): 59 torch.set_default_device(None) 60 61 def tearDown(self): 62 torch.set_default_device(None) 63 64 def _run_torch_function_mode_guard_test(self): 65 class TestMode1(BaseTorchFunctionMode): 66 pass 67 68 class TestMode2(BaseTorchFunctionMode): 69 pass 70 71 cnt = torch._dynamo.testing.CompileCounter() 72 73 @torch.compile(backend=cnt.__call__) 74 def fn(x): 75 return x + 1 76 77 inp = torch.ones(2, 2) 78 fn(inp) 79 self.assertEqual(cnt.frame_count, 1) 80 81 with TestMode1(): 82 fn(inp) 83 self.assertEqual(cnt.frame_count, 2) 84 85 with TestMode1(), TestMode2(): 86 fn(inp) 87 self.assertEqual(cnt.frame_count, 3) 88 89 with TestMode2(), TestMode1(): 90 fn(inp) 91 self.assertEqual(cnt.frame_count, 4) 92 93 with TestMode1(): 94 fn(inp) 95 self.assertEqual(cnt.frame_count, 4) 96 97 def _run_ignored_mode_types_test(self): 98 class IgnoredMode(BaseTorchFunctionMode): 99 pass 100 101 cnt = torch._dynamo.testing.CompileCounter() 102 103 @torch.compile(backend=cnt.__call__, fullgraph=True) 104 def fn(x): 105 return x + 1 106 107 inp = torch.ones(2, 2) 108 109 with patch( 110 "torch._dynamo.variables.torch_function.IGNORED_MODES", {IgnoredMode} 111 ): 112 # initial compile 113 fn(inp) 114 115 # no recompile, mode ignored 116 # note: the ref stack is length 0, and the stack we are checking against has length 2 117 # we want to check both ref stack len > runtime stack, and ref stack len < runtime stack 118 with IgnoredMode(), IgnoredMode(): 119 fn(inp) 120 121 self.assertEqual(cnt.frame_count, 1) 122 123 # recompile due to new mode on the stack 124 with BaseTorchFunctionMode(), BaseTorchFunctionMode(), BaseTorchFunctionMode(): 125 fn(inp) 126 127 self.assertEqual(cnt.frame_count, 2) 128 129 # recompile 130 # tests both ref stack len > runtime stack len for the above guard check 131 # and ref stack len < runtime stack len for the initial zero mode case 132 with BaseTorchFunctionMode(), IgnoredMode(), BaseTorchFunctionMode(): 133 fn(inp) 134 135 self.assertEqual(cnt.frame_count, 3) 136 137 # no recompile 138 with IgnoredMode(), IgnoredMode(), BaseTorchFunctionMode(), BaseTorchFunctionMode(): 139 fn(inp) 140 141 self.assertEqual(cnt.frame_count, 3) 142 143 # This is tricky, basically the ignored modes are baked into the guard 144 # IgnoredMode will be ignored forever by that guard. 145 # This is okay since we don't expect to be modifying IGNORED_MODES 146 # in the middle of execution except for the purposes of testing. 147 torch._dynamo.reset() 148 149 with IgnoredMode(): 150 fn(inp) 151 152 self.assertEqual(cnt.frame_count, 4) 153 154 @torch._dynamo.config.patch("enable_cpp_guard_manager", False) 155 def test_torch_function_mode_guards_ignored_types_py(self): 156 self._run_ignored_mode_types_test() 157 158 def test_torch_function_mode_guards_ignored_types_cpp(self): 159 self._run_ignored_mode_types_test() 160 161 @torch._dynamo.config.patch("enable_cpp_guard_manager", False) 162 def test_torch_function_mode_guards_py(self): 163 self._run_torch_function_mode_guard_test() 164 165 def test_torch_function_mode_guards_cpp(self): 166 self._run_torch_function_mode_guard_test() 167 168 def test_stack_state_mutation_default_device(self): 169 m = BaseTorchFunctionMode() 170 m1 = BaseTorchFunctionMode() 171 with m, m1: 172 173 @torch.compile(fullgraph=True) 174 def fn(x): 175 torch.set_default_device("cpu") 176 _pop_torch_function_stack() 177 178 fn(torch.ones(2, 2)) 179 _push_on_torch_function_stack(m1) 180 181 stack = _get_current_function_mode_stack() 182 self.assertIsInstance(stack[0], DeviceContext) 183 self.assertEqual(stack[0].device, torch.device("cpu")) 184 self.assertIs(stack[1], m) 185 self.assertIs(stack[2], m1) 186 187 def test_stack_state_clear_default_device(self): 188 @torch.compile(fullgraph=True) 189 def fn(x): 190 torch.set_default_device(None) 191 return x + 1 192 193 fn(torch.ones(2, 2)) 194 stack = _get_current_function_mode_stack() 195 self.assertEqual(len(stack), 0) 196 197 m = BaseTorchFunctionMode() 198 m1 = BaseTorchFunctionMode() 199 200 # Stack populated, add device 201 with m, m1: 202 203 @torch.compile(fullgraph=True) 204 def fn(x): 205 torch.set_default_device("cpu") 206 torch.set_default_device(None) 207 torch.set_default_device("cpu") 208 return x + 1 209 210 fn(torch.ones(2, 2)) 211 stack = _get_current_function_mode_stack() 212 self.assertEqual(stack[0].device, torch.device("cpu")) 213 self.assertIs(stack[1], m) 214 self.assertIs(stack[2], m1) 215 216 # Stack populated, remove device 217 torch.set_default_device("cpu") 218 with m, m1: 219 220 @torch.compile(fullgraph=True) 221 def fn(x): 222 torch.set_default_device(None) 223 return x + 1 224 225 fn(torch.ones(2, 2)) 226 stack = _get_current_function_mode_stack() 227 self.assertIs(stack[0], m) 228 self.assertIs(stack[1], m1) 229 230 @torch.compile(fullgraph=True) 231 def fn(x): 232 torch.set_default_device("cpu") 233 torch.set_default_device("cpu") 234 return x + 1 235 236 fn(torch.ones(2, 2)) 237 stack = _get_current_function_mode_stack() 238 self.assertEqual(stack[0].device, torch.device("cpu")) 239 torch.set_default_device(None) 240 241 def test_pop_torch_function_mode(self): 242 m = BaseTorchFunctionMode() 243 with m: 244 245 @torch.compile(fullgraph=True) 246 def fn(x): 247 _pop_torch_function_stack() 248 return x + 1 249 250 fn(torch.ones(2, 2)) 251 252 self.assertEqual(_len_torch_function_stack(), 0) 253 # reset stack so __exit__ doesn't crash 254 _push_on_torch_function_stack(m) 255 256 self.assertEqual(_len_torch_function_stack(), 0) 257 258 def test_error_empty_stack_pop_torch_function_mode(self): 259 @torch.compile(fullgraph=True) 260 def fn(x): 261 _pop_torch_function_stack() 262 return x + 1 263 264 self.assertRaisesRegex( 265 torch._dynamo.exc.Unsupported, 266 "Popping from an empty torch function mode stack", 267 lambda: fn(torch.ones(2, 2)), 268 ) 269 270 def test_push_torch_function_mode(self): 271 m = BaseTorchFunctionMode() 272 with m: 273 274 @torch.compile(fullgraph=True) 275 def fn(x, m): 276 _push_on_torch_function_stack(m) 277 return x + 1 278 279 fn(torch.ones(2, 2), m) 280 281 self.assertEqual(_len_torch_function_stack(), 2) 282 # reset stack state 283 _pop_torch_function_stack() 284 285 self.assertEqual(_len_torch_function_stack(), 0) 286 287 def test_len_torch_function_mode(self): 288 m = BaseTorchFunctionMode() 289 with m: 290 291 @torch.compile(fullgraph=True) 292 def fn(x): 293 z = _len_torch_function_stack() 294 return x + z 295 296 res = fn(torch.ones(2, 2)) 297 self.assertEqual(res, torch.ones(2, 2) + 1) 298 self.assertEqual(_len_torch_function_stack(), 1) 299 300 def test_intermedate_torch_function_mode_construction_mutation(self): 301 class TestMode(BaseTorchFunctionMode): 302 def __init__(self, x): 303 self.x = x 304 305 @torch.compile(fullgraph=True) 306 def fn(x): 307 z = TestMode(2) 308 z.y = 2 309 return x + 1, z 310 311 fn(torch.ones(2, 2)) 312 313 def test_torch_function_mode_enabled_guard(self): 314 cnt = torch._dynamo.testing.CompileCounter() 315 inp = torch.ones(2, 2) 316 317 @torch.compile(backend=cnt.__call__) 318 def fn(x): 319 return x + 1 320 321 with BaseTorchFunctionMode(), torch._C.DisableTorchFunctionSubclass(): 322 with torch._C.DisableTorchFunction(): 323 fn(inp) 324 fn(inp) 325 self.assertEqual(cnt.frame_count, 2) 326 327 328if __name__ == "__main__": 329 from torch._dynamo.test_case import run_tests 330 331 run_tests() 332