1*da0073e9SAndroid Build Coastguard Worker# Owner(s): ["module: dynamo"] 2*da0073e9SAndroid Build Coastguard Workerfrom unittest.mock import patch 3*da0073e9SAndroid Build Coastguard Worker 4*da0073e9SAndroid Build Coastguard Workerimport torch 5*da0073e9SAndroid Build Coastguard Workerimport torch._dynamo.test_case 6*da0073e9SAndroid Build Coastguard Workerimport torch._dynamo.testing 7*da0073e9SAndroid Build Coastguard Workerfrom torch._C import ( 8*da0073e9SAndroid Build Coastguard Worker _len_torch_function_stack, 9*da0073e9SAndroid Build Coastguard Worker _pop_torch_function_stack, 10*da0073e9SAndroid Build Coastguard Worker _push_on_torch_function_stack, 11*da0073e9SAndroid Build Coastguard Worker) 12*da0073e9SAndroid Build Coastguard Workerfrom torch.overrides import _get_current_function_mode_stack, BaseTorchFunctionMode 13*da0073e9SAndroid Build Coastguard Workerfrom torch.utils._device import DeviceContext 14*da0073e9SAndroid Build Coastguard Workerfrom torch.utils._python_dispatch import TorchDispatchMode 15*da0073e9SAndroid Build Coastguard Worker 16*da0073e9SAndroid Build Coastguard Worker 17*da0073e9SAndroid Build Coastguard Workerclass TorchDispatchModeTests(torch._dynamo.test_case.TestCase): 18*da0073e9SAndroid Build Coastguard Worker @classmethod 19*da0073e9SAndroid Build Coastguard Worker def setUpClass(cls): 20*da0073e9SAndroid Build Coastguard Worker super().setUpClass() 21*da0073e9SAndroid Build Coastguard Worker 22*da0073e9SAndroid Build Coastguard Worker @classmethod 23*da0073e9SAndroid Build Coastguard Worker def tearDownClass(cls): 24*da0073e9SAndroid Build Coastguard Worker super().tearDownClass() 25*da0073e9SAndroid Build Coastguard Worker 26*da0073e9SAndroid Build Coastguard Worker def test_skip_torch_dispatch_modes(self): 27*da0073e9SAndroid Build Coastguard Worker class RewriteAddToMul(TorchDispatchMode): 28*da0073e9SAndroid Build Coastguard Worker def __torch_dispatch__(self, func, types, args=(), kwargs=None): 29*da0073e9SAndroid Build Coastguard Worker if func is torch.ops.aten.add.Tensor: 30*da0073e9SAndroid Build Coastguard Worker func = torch.ops.aten.mul.Tensor 31*da0073e9SAndroid Build Coastguard Worker return func(*args, **kwargs) 32*da0073e9SAndroid Build Coastguard Worker 33*da0073e9SAndroid Build Coastguard Worker def fn(x): 34*da0073e9SAndroid Build Coastguard Worker return x + x 35*da0073e9SAndroid Build Coastguard Worker 36*da0073e9SAndroid Build Coastguard Worker cnt = torch._dynamo.testing.CompileCounter() 37*da0073e9SAndroid Build Coastguard Worker 38*da0073e9SAndroid Build Coastguard Worker x = torch.tensor([3.0]) 39*da0073e9SAndroid Build Coastguard Worker with RewriteAddToMul(): 40*da0073e9SAndroid Build Coastguard Worker eager_res = fn(x) 41*da0073e9SAndroid Build Coastguard Worker compiled_res = torch._dynamo.optimize(cnt)(fn)(x) 42*da0073e9SAndroid Build Coastguard Worker 43*da0073e9SAndroid Build Coastguard Worker self.assertEqual(eager_res, compiled_res) 44*da0073e9SAndroid Build Coastguard Worker self.assertEqual(cnt.frame_count, 0) 45*da0073e9SAndroid Build Coastguard Worker 46*da0073e9SAndroid Build Coastguard Worker 47*da0073e9SAndroid Build Coastguard Workerclass TorchFunctionModeTests(torch._dynamo.test_case.TestCase): 48*da0073e9SAndroid Build Coastguard Worker @classmethod 49*da0073e9SAndroid Build Coastguard Worker def setUpClass(cls): 50*da0073e9SAndroid Build Coastguard Worker cls.default_device_old = torch.get_default_device() 51*da0073e9SAndroid Build Coastguard Worker super().setUpClass() 52*da0073e9SAndroid Build Coastguard Worker 53*da0073e9SAndroid Build Coastguard Worker @classmethod 54*da0073e9SAndroid Build Coastguard Worker def tearDownClass(cls): 55*da0073e9SAndroid Build Coastguard Worker torch.set_default_device(cls.default_device_old) 56*da0073e9SAndroid Build Coastguard Worker super().tearDownClass() 57*da0073e9SAndroid Build Coastguard Worker 58*da0073e9SAndroid Build Coastguard Worker def setUp(self): 59*da0073e9SAndroid Build Coastguard Worker torch.set_default_device(None) 60*da0073e9SAndroid Build Coastguard Worker 61*da0073e9SAndroid Build Coastguard Worker def tearDown(self): 62*da0073e9SAndroid Build Coastguard Worker torch.set_default_device(None) 63*da0073e9SAndroid Build Coastguard Worker 64*da0073e9SAndroid Build Coastguard Worker def _run_torch_function_mode_guard_test(self): 65*da0073e9SAndroid Build Coastguard Worker class TestMode1(BaseTorchFunctionMode): 66*da0073e9SAndroid Build Coastguard Worker pass 67*da0073e9SAndroid Build Coastguard Worker 68*da0073e9SAndroid Build Coastguard Worker class TestMode2(BaseTorchFunctionMode): 69*da0073e9SAndroid Build Coastguard Worker pass 70*da0073e9SAndroid Build Coastguard Worker 71*da0073e9SAndroid Build Coastguard Worker cnt = torch._dynamo.testing.CompileCounter() 72*da0073e9SAndroid Build Coastguard Worker 73*da0073e9SAndroid Build Coastguard Worker @torch.compile(backend=cnt.__call__) 74*da0073e9SAndroid Build Coastguard Worker def fn(x): 75*da0073e9SAndroid Build Coastguard Worker return x + 1 76*da0073e9SAndroid Build Coastguard Worker 77*da0073e9SAndroid Build Coastguard Worker inp = torch.ones(2, 2) 78*da0073e9SAndroid Build Coastguard Worker fn(inp) 79*da0073e9SAndroid Build Coastguard Worker self.assertEqual(cnt.frame_count, 1) 80*da0073e9SAndroid Build Coastguard Worker 81*da0073e9SAndroid Build Coastguard Worker with TestMode1(): 82*da0073e9SAndroid Build Coastguard Worker fn(inp) 83*da0073e9SAndroid Build Coastguard Worker self.assertEqual(cnt.frame_count, 2) 84*da0073e9SAndroid Build Coastguard Worker 85*da0073e9SAndroid Build Coastguard Worker with TestMode1(), TestMode2(): 86*da0073e9SAndroid Build Coastguard Worker fn(inp) 87*da0073e9SAndroid Build Coastguard Worker self.assertEqual(cnt.frame_count, 3) 88*da0073e9SAndroid Build Coastguard Worker 89*da0073e9SAndroid Build Coastguard Worker with TestMode2(), TestMode1(): 90*da0073e9SAndroid Build Coastguard Worker fn(inp) 91*da0073e9SAndroid Build Coastguard Worker self.assertEqual(cnt.frame_count, 4) 92*da0073e9SAndroid Build Coastguard Worker 93*da0073e9SAndroid Build Coastguard Worker with TestMode1(): 94*da0073e9SAndroid Build Coastguard Worker fn(inp) 95*da0073e9SAndroid Build Coastguard Worker self.assertEqual(cnt.frame_count, 4) 96*da0073e9SAndroid Build Coastguard Worker 97*da0073e9SAndroid Build Coastguard Worker def _run_ignored_mode_types_test(self): 98*da0073e9SAndroid Build Coastguard Worker class IgnoredMode(BaseTorchFunctionMode): 99*da0073e9SAndroid Build Coastguard Worker pass 100*da0073e9SAndroid Build Coastguard Worker 101*da0073e9SAndroid Build Coastguard Worker cnt = torch._dynamo.testing.CompileCounter() 102*da0073e9SAndroid Build Coastguard Worker 103*da0073e9SAndroid Build Coastguard Worker @torch.compile(backend=cnt.__call__, fullgraph=True) 104*da0073e9SAndroid Build Coastguard Worker def fn(x): 105*da0073e9SAndroid Build Coastguard Worker return x + 1 106*da0073e9SAndroid Build Coastguard Worker 107*da0073e9SAndroid Build Coastguard Worker inp = torch.ones(2, 2) 108*da0073e9SAndroid Build Coastguard Worker 109*da0073e9SAndroid Build Coastguard Worker with patch( 110*da0073e9SAndroid Build Coastguard Worker "torch._dynamo.variables.torch_function.IGNORED_MODES", {IgnoredMode} 111*da0073e9SAndroid Build Coastguard Worker ): 112*da0073e9SAndroid Build Coastguard Worker # initial compile 113*da0073e9SAndroid Build Coastguard Worker fn(inp) 114*da0073e9SAndroid Build Coastguard Worker 115*da0073e9SAndroid Build Coastguard Worker # no recompile, mode ignored 116*da0073e9SAndroid Build Coastguard Worker # note: the ref stack is length 0, and the stack we are checking against has length 2 117*da0073e9SAndroid Build Coastguard Worker # we want to check both ref stack len > runtime stack, and ref stack len < runtime stack 118*da0073e9SAndroid Build Coastguard Worker with IgnoredMode(), IgnoredMode(): 119*da0073e9SAndroid Build Coastguard Worker fn(inp) 120*da0073e9SAndroid Build Coastguard Worker 121*da0073e9SAndroid Build Coastguard Worker self.assertEqual(cnt.frame_count, 1) 122*da0073e9SAndroid Build Coastguard Worker 123*da0073e9SAndroid Build Coastguard Worker # recompile due to new mode on the stack 124*da0073e9SAndroid Build Coastguard Worker with BaseTorchFunctionMode(), BaseTorchFunctionMode(), BaseTorchFunctionMode(): 125*da0073e9SAndroid Build Coastguard Worker fn(inp) 126*da0073e9SAndroid Build Coastguard Worker 127*da0073e9SAndroid Build Coastguard Worker self.assertEqual(cnt.frame_count, 2) 128*da0073e9SAndroid Build Coastguard Worker 129*da0073e9SAndroid Build Coastguard Worker # recompile 130*da0073e9SAndroid Build Coastguard Worker # tests both ref stack len > runtime stack len for the above guard check 131*da0073e9SAndroid Build Coastguard Worker # and ref stack len < runtime stack len for the initial zero mode case 132*da0073e9SAndroid Build Coastguard Worker with BaseTorchFunctionMode(), IgnoredMode(), BaseTorchFunctionMode(): 133*da0073e9SAndroid Build Coastguard Worker fn(inp) 134*da0073e9SAndroid Build Coastguard Worker 135*da0073e9SAndroid Build Coastguard Worker self.assertEqual(cnt.frame_count, 3) 136*da0073e9SAndroid Build Coastguard Worker 137*da0073e9SAndroid Build Coastguard Worker # no recompile 138*da0073e9SAndroid Build Coastguard Worker with IgnoredMode(), IgnoredMode(), BaseTorchFunctionMode(), BaseTorchFunctionMode(): 139*da0073e9SAndroid Build Coastguard Worker fn(inp) 140*da0073e9SAndroid Build Coastguard Worker 141*da0073e9SAndroid Build Coastguard Worker self.assertEqual(cnt.frame_count, 3) 142*da0073e9SAndroid Build Coastguard Worker 143*da0073e9SAndroid Build Coastguard Worker # This is tricky, basically the ignored modes are baked into the guard 144*da0073e9SAndroid Build Coastguard Worker # IgnoredMode will be ignored forever by that guard. 145*da0073e9SAndroid Build Coastguard Worker # This is okay since we don't expect to be modifying IGNORED_MODES 146*da0073e9SAndroid Build Coastguard Worker # in the middle of execution except for the purposes of testing. 147*da0073e9SAndroid Build Coastguard Worker torch._dynamo.reset() 148*da0073e9SAndroid Build Coastguard Worker 149*da0073e9SAndroid Build Coastguard Worker with IgnoredMode(): 150*da0073e9SAndroid Build Coastguard Worker fn(inp) 151*da0073e9SAndroid Build Coastguard Worker 152*da0073e9SAndroid Build Coastguard Worker self.assertEqual(cnt.frame_count, 4) 153*da0073e9SAndroid Build Coastguard Worker 154*da0073e9SAndroid Build Coastguard Worker @torch._dynamo.config.patch("enable_cpp_guard_manager", False) 155*da0073e9SAndroid Build Coastguard Worker def test_torch_function_mode_guards_ignored_types_py(self): 156*da0073e9SAndroid Build Coastguard Worker self._run_ignored_mode_types_test() 157*da0073e9SAndroid Build Coastguard Worker 158*da0073e9SAndroid Build Coastguard Worker def test_torch_function_mode_guards_ignored_types_cpp(self): 159*da0073e9SAndroid Build Coastguard Worker self._run_ignored_mode_types_test() 160*da0073e9SAndroid Build Coastguard Worker 161*da0073e9SAndroid Build Coastguard Worker @torch._dynamo.config.patch("enable_cpp_guard_manager", False) 162*da0073e9SAndroid Build Coastguard Worker def test_torch_function_mode_guards_py(self): 163*da0073e9SAndroid Build Coastguard Worker self._run_torch_function_mode_guard_test() 164*da0073e9SAndroid Build Coastguard Worker 165*da0073e9SAndroid Build Coastguard Worker def test_torch_function_mode_guards_cpp(self): 166*da0073e9SAndroid Build Coastguard Worker self._run_torch_function_mode_guard_test() 167*da0073e9SAndroid Build Coastguard Worker 168*da0073e9SAndroid Build Coastguard Worker def test_stack_state_mutation_default_device(self): 169*da0073e9SAndroid Build Coastguard Worker m = BaseTorchFunctionMode() 170*da0073e9SAndroid Build Coastguard Worker m1 = BaseTorchFunctionMode() 171*da0073e9SAndroid Build Coastguard Worker with m, m1: 172*da0073e9SAndroid Build Coastguard Worker 173*da0073e9SAndroid Build Coastguard Worker @torch.compile(fullgraph=True) 174*da0073e9SAndroid Build Coastguard Worker def fn(x): 175*da0073e9SAndroid Build Coastguard Worker torch.set_default_device("cpu") 176*da0073e9SAndroid Build Coastguard Worker _pop_torch_function_stack() 177*da0073e9SAndroid Build Coastguard Worker 178*da0073e9SAndroid Build Coastguard Worker fn(torch.ones(2, 2)) 179*da0073e9SAndroid Build Coastguard Worker _push_on_torch_function_stack(m1) 180*da0073e9SAndroid Build Coastguard Worker 181*da0073e9SAndroid Build Coastguard Worker stack = _get_current_function_mode_stack() 182*da0073e9SAndroid Build Coastguard Worker self.assertIsInstance(stack[0], DeviceContext) 183*da0073e9SAndroid Build Coastguard Worker self.assertEqual(stack[0].device, torch.device("cpu")) 184*da0073e9SAndroid Build Coastguard Worker self.assertIs(stack[1], m) 185*da0073e9SAndroid Build Coastguard Worker self.assertIs(stack[2], m1) 186*da0073e9SAndroid Build Coastguard Worker 187*da0073e9SAndroid Build Coastguard Worker def test_stack_state_clear_default_device(self): 188*da0073e9SAndroid Build Coastguard Worker @torch.compile(fullgraph=True) 189*da0073e9SAndroid Build Coastguard Worker def fn(x): 190*da0073e9SAndroid Build Coastguard Worker torch.set_default_device(None) 191*da0073e9SAndroid Build Coastguard Worker return x + 1 192*da0073e9SAndroid Build Coastguard Worker 193*da0073e9SAndroid Build Coastguard Worker fn(torch.ones(2, 2)) 194*da0073e9SAndroid Build Coastguard Worker stack = _get_current_function_mode_stack() 195*da0073e9SAndroid Build Coastguard Worker self.assertEqual(len(stack), 0) 196*da0073e9SAndroid Build Coastguard Worker 197*da0073e9SAndroid Build Coastguard Worker m = BaseTorchFunctionMode() 198*da0073e9SAndroid Build Coastguard Worker m1 = BaseTorchFunctionMode() 199*da0073e9SAndroid Build Coastguard Worker 200*da0073e9SAndroid Build Coastguard Worker # Stack populated, add device 201*da0073e9SAndroid Build Coastguard Worker with m, m1: 202*da0073e9SAndroid Build Coastguard Worker 203*da0073e9SAndroid Build Coastguard Worker @torch.compile(fullgraph=True) 204*da0073e9SAndroid Build Coastguard Worker def fn(x): 205*da0073e9SAndroid Build Coastguard Worker torch.set_default_device("cpu") 206*da0073e9SAndroid Build Coastguard Worker torch.set_default_device(None) 207*da0073e9SAndroid Build Coastguard Worker torch.set_default_device("cpu") 208*da0073e9SAndroid Build Coastguard Worker return x + 1 209*da0073e9SAndroid Build Coastguard Worker 210*da0073e9SAndroid Build Coastguard Worker fn(torch.ones(2, 2)) 211*da0073e9SAndroid Build Coastguard Worker stack = _get_current_function_mode_stack() 212*da0073e9SAndroid Build Coastguard Worker self.assertEqual(stack[0].device, torch.device("cpu")) 213*da0073e9SAndroid Build Coastguard Worker self.assertIs(stack[1], m) 214*da0073e9SAndroid Build Coastguard Worker self.assertIs(stack[2], m1) 215*da0073e9SAndroid Build Coastguard Worker 216*da0073e9SAndroid Build Coastguard Worker # Stack populated, remove device 217*da0073e9SAndroid Build Coastguard Worker torch.set_default_device("cpu") 218*da0073e9SAndroid Build Coastguard Worker with m, m1: 219*da0073e9SAndroid Build Coastguard Worker 220*da0073e9SAndroid Build Coastguard Worker @torch.compile(fullgraph=True) 221*da0073e9SAndroid Build Coastguard Worker def fn(x): 222*da0073e9SAndroid Build Coastguard Worker torch.set_default_device(None) 223*da0073e9SAndroid Build Coastguard Worker return x + 1 224*da0073e9SAndroid Build Coastguard Worker 225*da0073e9SAndroid Build Coastguard Worker fn(torch.ones(2, 2)) 226*da0073e9SAndroid Build Coastguard Worker stack = _get_current_function_mode_stack() 227*da0073e9SAndroid Build Coastguard Worker self.assertIs(stack[0], m) 228*da0073e9SAndroid Build Coastguard Worker self.assertIs(stack[1], m1) 229*da0073e9SAndroid Build Coastguard Worker 230*da0073e9SAndroid Build Coastguard Worker @torch.compile(fullgraph=True) 231*da0073e9SAndroid Build Coastguard Worker def fn(x): 232*da0073e9SAndroid Build Coastguard Worker torch.set_default_device("cpu") 233*da0073e9SAndroid Build Coastguard Worker torch.set_default_device("cpu") 234*da0073e9SAndroid Build Coastguard Worker return x + 1 235*da0073e9SAndroid Build Coastguard Worker 236*da0073e9SAndroid Build Coastguard Worker fn(torch.ones(2, 2)) 237*da0073e9SAndroid Build Coastguard Worker stack = _get_current_function_mode_stack() 238*da0073e9SAndroid Build Coastguard Worker self.assertEqual(stack[0].device, torch.device("cpu")) 239*da0073e9SAndroid Build Coastguard Worker torch.set_default_device(None) 240*da0073e9SAndroid Build Coastguard Worker 241*da0073e9SAndroid Build Coastguard Worker def test_pop_torch_function_mode(self): 242*da0073e9SAndroid Build Coastguard Worker m = BaseTorchFunctionMode() 243*da0073e9SAndroid Build Coastguard Worker with m: 244*da0073e9SAndroid Build Coastguard Worker 245*da0073e9SAndroid Build Coastguard Worker @torch.compile(fullgraph=True) 246*da0073e9SAndroid Build Coastguard Worker def fn(x): 247*da0073e9SAndroid Build Coastguard Worker _pop_torch_function_stack() 248*da0073e9SAndroid Build Coastguard Worker return x + 1 249*da0073e9SAndroid Build Coastguard Worker 250*da0073e9SAndroid Build Coastguard Worker fn(torch.ones(2, 2)) 251*da0073e9SAndroid Build Coastguard Worker 252*da0073e9SAndroid Build Coastguard Worker self.assertEqual(_len_torch_function_stack(), 0) 253*da0073e9SAndroid Build Coastguard Worker # reset stack so __exit__ doesn't crash 254*da0073e9SAndroid Build Coastguard Worker _push_on_torch_function_stack(m) 255*da0073e9SAndroid Build Coastguard Worker 256*da0073e9SAndroid Build Coastguard Worker self.assertEqual(_len_torch_function_stack(), 0) 257*da0073e9SAndroid Build Coastguard Worker 258*da0073e9SAndroid Build Coastguard Worker def test_error_empty_stack_pop_torch_function_mode(self): 259*da0073e9SAndroid Build Coastguard Worker @torch.compile(fullgraph=True) 260*da0073e9SAndroid Build Coastguard Worker def fn(x): 261*da0073e9SAndroid Build Coastguard Worker _pop_torch_function_stack() 262*da0073e9SAndroid Build Coastguard Worker return x + 1 263*da0073e9SAndroid Build Coastguard Worker 264*da0073e9SAndroid Build Coastguard Worker self.assertRaisesRegex( 265*da0073e9SAndroid Build Coastguard Worker torch._dynamo.exc.Unsupported, 266*da0073e9SAndroid Build Coastguard Worker "Popping from an empty torch function mode stack", 267*da0073e9SAndroid Build Coastguard Worker lambda: fn(torch.ones(2, 2)), 268*da0073e9SAndroid Build Coastguard Worker ) 269*da0073e9SAndroid Build Coastguard Worker 270*da0073e9SAndroid Build Coastguard Worker def test_push_torch_function_mode(self): 271*da0073e9SAndroid Build Coastguard Worker m = BaseTorchFunctionMode() 272*da0073e9SAndroid Build Coastguard Worker with m: 273*da0073e9SAndroid Build Coastguard Worker 274*da0073e9SAndroid Build Coastguard Worker @torch.compile(fullgraph=True) 275*da0073e9SAndroid Build Coastguard Worker def fn(x, m): 276*da0073e9SAndroid Build Coastguard Worker _push_on_torch_function_stack(m) 277*da0073e9SAndroid Build Coastguard Worker return x + 1 278*da0073e9SAndroid Build Coastguard Worker 279*da0073e9SAndroid Build Coastguard Worker fn(torch.ones(2, 2), m) 280*da0073e9SAndroid Build Coastguard Worker 281*da0073e9SAndroid Build Coastguard Worker self.assertEqual(_len_torch_function_stack(), 2) 282*da0073e9SAndroid Build Coastguard Worker # reset stack state 283*da0073e9SAndroid Build Coastguard Worker _pop_torch_function_stack() 284*da0073e9SAndroid Build Coastguard Worker 285*da0073e9SAndroid Build Coastguard Worker self.assertEqual(_len_torch_function_stack(), 0) 286*da0073e9SAndroid Build Coastguard Worker 287*da0073e9SAndroid Build Coastguard Worker def test_len_torch_function_mode(self): 288*da0073e9SAndroid Build Coastguard Worker m = BaseTorchFunctionMode() 289*da0073e9SAndroid Build Coastguard Worker with m: 290*da0073e9SAndroid Build Coastguard Worker 291*da0073e9SAndroid Build Coastguard Worker @torch.compile(fullgraph=True) 292*da0073e9SAndroid Build Coastguard Worker def fn(x): 293*da0073e9SAndroid Build Coastguard Worker z = _len_torch_function_stack() 294*da0073e9SAndroid Build Coastguard Worker return x + z 295*da0073e9SAndroid Build Coastguard Worker 296*da0073e9SAndroid Build Coastguard Worker res = fn(torch.ones(2, 2)) 297*da0073e9SAndroid Build Coastguard Worker self.assertEqual(res, torch.ones(2, 2) + 1) 298*da0073e9SAndroid Build Coastguard Worker self.assertEqual(_len_torch_function_stack(), 1) 299*da0073e9SAndroid Build Coastguard Worker 300*da0073e9SAndroid Build Coastguard Worker def test_intermedate_torch_function_mode_construction_mutation(self): 301*da0073e9SAndroid Build Coastguard Worker class TestMode(BaseTorchFunctionMode): 302*da0073e9SAndroid Build Coastguard Worker def __init__(self, x): 303*da0073e9SAndroid Build Coastguard Worker self.x = x 304*da0073e9SAndroid Build Coastguard Worker 305*da0073e9SAndroid Build Coastguard Worker @torch.compile(fullgraph=True) 306*da0073e9SAndroid Build Coastguard Worker def fn(x): 307*da0073e9SAndroid Build Coastguard Worker z = TestMode(2) 308*da0073e9SAndroid Build Coastguard Worker z.y = 2 309*da0073e9SAndroid Build Coastguard Worker return x + 1, z 310*da0073e9SAndroid Build Coastguard Worker 311*da0073e9SAndroid Build Coastguard Worker fn(torch.ones(2, 2)) 312*da0073e9SAndroid Build Coastguard Worker 313*da0073e9SAndroid Build Coastguard Worker def test_torch_function_mode_enabled_guard(self): 314*da0073e9SAndroid Build Coastguard Worker cnt = torch._dynamo.testing.CompileCounter() 315*da0073e9SAndroid Build Coastguard Worker inp = torch.ones(2, 2) 316*da0073e9SAndroid Build Coastguard Worker 317*da0073e9SAndroid Build Coastguard Worker @torch.compile(backend=cnt.__call__) 318*da0073e9SAndroid Build Coastguard Worker def fn(x): 319*da0073e9SAndroid Build Coastguard Worker return x + 1 320*da0073e9SAndroid Build Coastguard Worker 321*da0073e9SAndroid Build Coastguard Worker with BaseTorchFunctionMode(), torch._C.DisableTorchFunctionSubclass(): 322*da0073e9SAndroid Build Coastguard Worker with torch._C.DisableTorchFunction(): 323*da0073e9SAndroid Build Coastguard Worker fn(inp) 324*da0073e9SAndroid Build Coastguard Worker fn(inp) 325*da0073e9SAndroid Build Coastguard Worker self.assertEqual(cnt.frame_count, 2) 326*da0073e9SAndroid Build Coastguard Worker 327*da0073e9SAndroid Build Coastguard Worker 328*da0073e9SAndroid Build Coastguard Workerif __name__ == "__main__": 329*da0073e9SAndroid Build Coastguard Worker from torch._dynamo.test_case import run_tests 330*da0073e9SAndroid Build Coastguard Worker 331*da0073e9SAndroid Build Coastguard Worker run_tests() 332