1# Owner(s): ["oncall: pt2"] 2import functools 3import sys 4import unittest 5from unittest.mock import patch 6 7import torch 8import torch.utils.checkpoint 9from functorch.compile import aot_function, min_cut_rematerialization_partition, nop 10 11from torch.testing._internal.common_device_type import ( 12 dtypes, 13 instantiate_device_type_tests, 14) 15 16from torch.testing._internal.common_utils import IS_CI, IS_WINDOWS, run_tests, TestCase 17 18if IS_WINDOWS and IS_CI: 19 sys.stderr.write("torch.compile not supported on windows") 20 if __name__ == "__main__": 21 sys.exit(0) 22 raise unittest.SkipTest("torch.compile not supported on windows") 23 24 25def count_philox_rand(gm, args, freq): 26 assert [node.target for node in gm.graph.nodes].count( 27 torch.ops.rngprims.philox_rand.default 28 ) == freq 29 return gm 30 31 32class TestFunctionalizationRngOps(TestCase): 33 @dtypes(torch.float32) 34 @patch.object(torch._functorch.config, "functionalize_rng_ops", True) 35 def test_rand_like(self, dtype, device): 36 def fn(x): 37 a = torch.rand_like(x) * x 38 a = torch.rand_like(x) * a 39 return a 40 41 x = torch.rand(10, device=device, dtype=dtype) 42 43 for seed in range(10): 44 torch.cuda.manual_seed(seed) 45 ref = fn(x) 46 47 torch.cuda.manual_seed(seed) 48 aot_fn = aot_function(fn, functools.partial(count_philox_rand, freq=2)) 49 res = aot_fn(x) 50 51 self.assertEqual(ref, res) 52 53 @dtypes(torch.float32) 54 @patch.object(torch._functorch.config, "functionalize_rng_ops", True) 55 def test_rand_like_dynamic(self, dtype, device): 56 def fn(x): 57 a = torch.rand_like(x) * x 58 a = torch.rand_like(x) * a 59 return a 60 61 for seed in range(1, 10): 62 shape = (seed, seed) 63 x = torch.rand(shape, device=device, dtype=dtype) 64 torch.cuda.manual_seed(seed) 65 ref = fn(x) 66 67 torch.cuda.manual_seed(seed) 68 opt_fn = torch.compile(fn, backend="aot_eager", dynamic=True) 69 res = opt_fn(x) 70 71 self.assertEqual(ref, res) 72 73 @dtypes(torch.float32) 74 @patch.object(torch._functorch.config, "functionalize_rng_ops", True) 75 def test_rand_like_dynamic_bwd(self, dtype, device): 76 def fn(x): 77 a = torch.rand_like(x) * x 78 a = torch.rand_like(x) * a 79 return a 80 81 for seed in range(1, 10): 82 shape = (seed, seed) 83 x = torch.rand(shape, device=device, dtype=dtype, requires_grad=True) 84 torch.cuda.manual_seed(seed) 85 ref = fn(x) 86 ref.sum().backward() 87 88 torch.cuda.manual_seed(seed) 89 opt_fn = torch.compile(fn, backend="aot_eager", dynamic=True) 90 res = opt_fn(x) 91 res.sum().backward() 92 93 self.assertEqual(ref, res) 94 95 @dtypes(torch.float32) 96 @patch.object(torch._functorch.config, "functionalize_rng_ops", True) 97 def test_rand(self, dtype, device): 98 shape = (10,) 99 100 def fn(x): 101 a = torch.rand(*shape, device=device, dtype=dtype) * x 102 a = torch.rand(*shape, device=device, dtype=dtype) * a 103 return a 104 105 x = torch.rand(*shape, device=device, dtype=dtype) 106 107 for seed in range(10): 108 torch.cuda.manual_seed(seed) 109 ref = fn(x) 110 111 torch.cuda.manual_seed(seed) 112 aot_fn = aot_function(fn, functools.partial(count_philox_rand, freq=2)) 113 res = aot_fn(x) 114 115 self.assertEqual(ref, res) 116 117 @dtypes(torch.float32) 118 @patch.object(torch._functorch.config, "functionalize_rng_ops", True) 119 def test_autograd_function(self, dtype, device): 120 shape = (16, 16) 121 122 class Custom(torch.autograd.Function): 123 @staticmethod 124 def forward(ctx, x): 125 ctx.save_for_backward(x) 126 a = torch.rand_like(x) * x 127 a = torch.rand_like(x) * a 128 return a 
            @staticmethod
            def backward(ctx, grad_out):
                (x,) = ctx.saved_tensors
                return grad_out * torch.rand_like(grad_out) * torch.cos(x)

        custom = Custom.apply

        x = torch.rand(*shape, device=device, dtype=dtype, requires_grad=True)

        x_clone = x.clone().detach().requires_grad_(True)

        torch.cuda.manual_seed(123)
        ref = custom(x)
        ref.sum().backward()

        torch.cuda.manual_seed(123)
        fwd_compiler = functools.partial(count_philox_rand, freq=2)
        bwd_compiler = functools.partial(count_philox_rand, freq=1)
        aot_custom = aot_function(custom, fwd_compiler, bwd_compiler)
        res = aot_custom(x_clone)
        res.sum().backward()

        self.assertEqual(ref, res)
        self.assertEqual(x.grad, x_clone.grad)

    @dtypes(torch.float32)
    @patch.object(torch._functorch.config, "functionalize_rng_ops", True)
    def test_multiple_subgraphs(self, dtype, device):
        # Checks that rng state is maintained when there are multiple aot traced graphs.
        shape = (16, 16)

        class CustomOp1(torch.autograd.Function):
            @staticmethod
            def forward(ctx, x):
                ctx.save_for_backward(x)
                a = torch.rand_like(x) * x
                a = torch.rand_like(x) * a
                return a

            @staticmethod
            def backward(ctx, grad_out):
                (x,) = ctx.saved_tensors
                return grad_out * torch.rand_like(grad_out) * torch.cos(x)

        class CustomOp2(torch.autograd.Function):
            @staticmethod
            def forward(ctx, x):
                ctx.save_for_backward(x)
                a = torch.rand_like(x) * x
                return a

            @staticmethod
            def backward(ctx, grad_out):
                (x,) = ctx.saved_tensors
                return grad_out * torch.rand_like(grad_out) * torch.rand_like(x)

        custom_op1 = CustomOp1.apply
        custom_op2 = CustomOp2.apply

        def fn(x):
            a = custom_op1(x)
            b = a.sin()
            return custom_op2(b)

        fwd_compiler = functools.partial(count_philox_rand, freq=2)
        bwd_compiler = functools.partial(count_philox_rand, freq=1)
        aot_custom_op1 = aot_function(custom_op1, fwd_compiler, bwd_compiler)
        fwd_compiler = functools.partial(count_philox_rand, freq=1)
        bwd_compiler = functools.partial(count_philox_rand, freq=2)
        aot_custom_op2 = aot_function(custom_op2, fwd_compiler, bwd_compiler)

        def aot_fn(x):
            a = aot_custom_op1(x)
            b = a.sin()
            return aot_custom_op2(b)

        for seed in range(10):
            torch.cuda.manual_seed(seed)
            x = torch.rand(*shape, device=device, dtype=dtype, requires_grad=True)
            x_clone = x.clone().detach().requires_grad_(True)

            torch.cuda.manual_seed(seed)
            ref = fn(x)
            ref.sum().backward()

            torch.cuda.manual_seed(seed)
            res = aot_fn(x_clone)
            res.sum().backward()

            self.assertEqual(ref, res)
            self.assertEqual(x.grad, x_clone.grad)

    @dtypes(torch.float32)
    @patch.object(torch._functorch.config, "functionalize_rng_ops", True)
    def test_set_get_rng_state(self, dtype, device):
        def fn(x):
            a = torch.rand_like(x) * x
            state = torch.cuda.get_rng_state()
            a = torch.rand_like(x) * a
            torch.cuda.set_rng_state(state)
            a = torch.rand_like(x) * a
            return a

        x = torch.rand(10, device=device, dtype=dtype)

        for seed in range(10):
            torch.cuda.manual_seed(seed)
            ref = fn(x)

            torch.cuda.manual_seed(seed)
            fwd_compiler = functools.partial(count_philox_rand, freq=3)
            aot_fn = aot_function(fn, fwd_compiler)
            res = aot_fn(x)

            self.assertEqual(ref, res)

    @dtypes(torch.float32)
    @patch.object(torch._functorch.config, "functionalize_rng_ops", True)
"functionalize_rng_ops", True) 249 def test_min_cut_partitioner(self, dtype, device): 250 # Checks that the calling convention is maintained 251 shape = (16, 16) 252 253 def fn(x): 254 a = torch.rand_like(x) * x 255 a = torch.rand_like(x) * a 256 a = torch.sin(a) 257 a = torch.sin(a) 258 a = torch.sin(a) 259 return a 260 261 x = torch.rand(*shape, device=device, dtype=dtype, requires_grad=True) 262 263 x_clone = x.clone().detach().requires_grad_(True) 264 265 torch.cuda.manual_seed(123) 266 ref = fn(x) 267 ref.sum().backward() 268 269 torch.cuda.manual_seed(123) 270 fwd_compiler = functools.partial(count_philox_rand, freq=2) 271 bwd_compiler = functools.partial(count_philox_rand, freq=0) 272 aot_custom = aot_function( 273 fn, 274 fwd_compiler, 275 bwd_compiler, 276 partition_fn=min_cut_rematerialization_partition, 277 ) 278 # aot_custom = aot_function(fn, fwd_compiler, bwd_compiler) 279 res = aot_custom(x_clone) 280 res.sum().backward() 281 282 self.assertEqual(ref, res) 283 self.assertEqual(x.grad, x_clone.grad) 284 285 # TODO - Dropout needs more work because of offset calculation 286 @patch.object(torch._functorch.config, "functionalize_rng_ops", True) 287 @dtypes(torch.float32) 288 def test_checkpoint(self, dtype, device): 289 def g(x, y): 290 return torch.nn.functional.dropout(x, 0.6) 291 292 def fn(x, y): 293 return torch.utils.checkpoint.checkpoint(g, x, y, use_reentrant=False) 294 295 # x = torch.rand(2, 2, device="cuda", requires_grad=True) 296 x = torch.ones(2, 2, device="cuda", requires_grad=True) 297 y = torch.rand(2, 2, device="cuda", requires_grad=True) 298 torch.cuda.manual_seed(123) 299 ref = fn(x, y) 300 301 # With checkpointing we should recompute dropout in bwd, and philox_rand is passed from fwd 302 fwd_compiler = functools.partial(count_philox_rand, freq=1) 303 bwd_compiler = functools.partial(count_philox_rand, freq=0) 304 aot_fn = aot_function(fn, fwd_compiler, bwd_compiler) 305 # We cant check accuracy here because rand_like generated different rand numbers than dropout 306 res = aot_fn(x, y) 307 res.sum().backward() 308 309 @dtypes(torch.float32) 310 @patch.object(torch._functorch.config, "functionalize_rng_ops", True) 311 def test_dropout_decomp(self, dtype, device): 312 def fn(x): 313 return torch.nn.functional.dropout(x, 0.6) * x 314 315 x = torch.rand(10, device=device, dtype=dtype) 316 317 # Ensure the decomp is happening 318 aot_fn = aot_function(fn, functools.partial(count_philox_rand, freq=1)) 319 # We cant check accuracy here because rand_like generated different rand numbers than dropout 320 aot_fn(x) 321 322 323only_for = ("cuda",) 324instantiate_device_type_tests(TestFunctionalizationRngOps, globals(), only_for=only_for) 325 326 327class NegativeTest(TestCase): 328 @dtypes(torch.float32) 329 @patch.object(torch._functorch.config, "functionalize_rng_ops", True) 330 def test_on_cpu(self, dtype, device): 331 def fn(x): 332 a = torch.rand_like(x) * x 333 a = torch.rand_like(x) * a 334 return a 335 336 x = torch.rand(10, device=device, dtype=dtype) 337 338 aot_fn = aot_function(fn, nop) 339 with self.assertRaises(RuntimeError): 340 aot_fn(x) 341 342 343only_for = ("cpu",) 344instantiate_device_type_tests(NegativeTest, globals(), only_for=only_for) 345 346if __name__ == "__main__": 347 run_tests() 348