# Owner(s): ["oncall: pt2"]
import functools
import sys
import unittest
from unittest.mock import patch

import torch
import torch.utils.checkpoint
from functorch.compile import aot_function, min_cut_rematerialization_partition, nop

from torch.testing._internal.common_device_type import (
    dtypes,
    instantiate_device_type_tests,
)

from torch.testing._internal.common_utils import IS_CI, IS_WINDOWS, run_tests, TestCase

if IS_WINDOWS and IS_CI:
    sys.stderr.write("torch.compile not supported on windows")
    if __name__ == "__main__":
        sys.exit(0)
    raise unittest.SkipTest("torch.compile not supported on windows")


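# AOT Autograd compiler callback: asserts that the traced FX graph contains
# exactly `freq` calls to the functional philox_rand op, then returns the
# graph unchanged so it still runs as-is.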
def count_philox_rand(gm, args, freq):
    assert [node.target for node in gm.graph.nodes].count(
        torch.ops.rngprims.philox_rand.default
    ) == freq
    return gm


class TestFunctionalizationRngOps(TestCase):
    @dtypes(torch.float32)
    @patch.object(torch._functorch.config, "functionalize_rng_ops", True)
    def test_rand_like(self, dtype, device):
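        # With functionalize_rng_ops enabled, each rand_like call should show
        # up as a philox_rand op in the forward graph (freq=2), and the
        # compiled output should match eager mode for the same CUDA seed.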
        def fn(x):
            a = torch.rand_like(x) * x
            a = torch.rand_like(x) * a
            return a

        x = torch.rand(10, device=device, dtype=dtype)

        for seed in range(10):
            torch.cuda.manual_seed(seed)
            ref = fn(x)

            torch.cuda.manual_seed(seed)
            aot_fn = aot_function(fn, functools.partial(count_philox_rand, freq=2))
            res = aot_fn(x)

            self.assertEqual(ref, res)

    @dtypes(torch.float32)
    @patch.object(torch._functorch.config, "functionalize_rng_ops", True)
    def test_rand_like_dynamic(self, dtype, device):
        def fn(x):
            a = torch.rand_like(x) * x
            a = torch.rand_like(x) * a
            return a

        for seed in range(1, 10):
            shape = (seed, seed)
            x = torch.rand(shape, device=device, dtype=dtype)
            torch.cuda.manual_seed(seed)
            ref = fn(x)

            torch.cuda.manual_seed(seed)
            opt_fn = torch.compile(fn, backend="aot_eager", dynamic=True)
            res = opt_fn(x)

            self.assertEqual(ref, res)

    @dtypes(torch.float32)
    @patch.object(torch._functorch.config, "functionalize_rng_ops", True)
    def test_rand_like_dynamic_bwd(self, dtype, device):
        def fn(x):
            a = torch.rand_like(x) * x
            a = torch.rand_like(x) * a
            return a

        for seed in range(1, 10):
            shape = (seed, seed)
            x = torch.rand(shape, device=device, dtype=dtype, requires_grad=True)
            torch.cuda.manual_seed(seed)
            ref = fn(x)
            ref.sum().backward()

            torch.cuda.manual_seed(seed)
            opt_fn = torch.compile(fn, backend="aot_eager", dynamic=True)
            res = opt_fn(x)
            res.sum().backward()

            self.assertEqual(ref, res)

    @dtypes(torch.float32)
    @patch.object(torch._functorch.config, "functionalize_rng_ops", True)
    def test_rand(self, dtype, device):
        shape = (10,)

        def fn(x):
            a = torch.rand(*shape, device=device, dtype=dtype) * x
            a = torch.rand(*shape, device=device, dtype=dtype) * a
            return a

        x = torch.rand(*shape, device=device, dtype=dtype)

        for seed in range(10):
            torch.cuda.manual_seed(seed)
            ref = fn(x)

            torch.cuda.manual_seed(seed)
            aot_fn = aot_function(fn, functools.partial(count_philox_rand, freq=2))
            res = aot_fn(x)

            self.assertEqual(ref, res)

    @dtypes(torch.float32)
    @patch.object(torch._functorch.config, "functionalize_rng_ops", True)
    def test_autograd_function(self, dtype, device):
        shape = (16, 16)

        class Custom(torch.autograd.Function):
            @staticmethod
            def forward(ctx, x):
                ctx.save_for_backward(x)
                a = torch.rand_like(x) * x
                a = torch.rand_like(x) * a
                return a

            @staticmethod
            def backward(ctx, grad_out):
                (x,) = ctx.saved_tensors
                return grad_out * torch.rand_like(grad_out) * torch.cos(x)

        custom = Custom.apply

        x = torch.rand(*shape, device=device, dtype=dtype, requires_grad=True)

        x_clone = x.clone().detach().requires_grad_(True)

        torch.cuda.manual_seed(123)
        ref = custom(x)
        ref.sum().backward()

        torch.cuda.manual_seed(123)
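        # The forward has two rand_like calls and the backward has one, so the
        # functionalized graphs should contain two and one philox_rand ops
        # respectively.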
        fwd_compiler = functools.partial(count_philox_rand, freq=2)
        bwd_compiler = functools.partial(count_philox_rand, freq=1)
        aot_custom = aot_function(custom, fwd_compiler, bwd_compiler)
        res = aot_custom(x_clone)
        res.sum().backward()

        self.assertEqual(ref, res)
        self.assertEqual(x.grad, x_clone.grad)

    @dtypes(torch.float32)
    @patch.object(torch._functorch.config, "functionalize_rng_ops", True)
    def test_multiple_subgraphs(self, dtype, device):
        # Checks that rng state is maintained when there are multiple aot traced
        # graphs.
        shape = (16, 16)

        class CustomOp1(torch.autograd.Function):
            @staticmethod
            def forward(ctx, x):
                ctx.save_for_backward(x)
                a = torch.rand_like(x) * x
                a = torch.rand_like(x) * a
                return a

            @staticmethod
            def backward(ctx, grad_out):
                (x,) = ctx.saved_tensors
                return grad_out * torch.rand_like(grad_out) * torch.cos(x)

        class CustomOp2(torch.autograd.Function):
            @staticmethod
            def forward(ctx, x):
                ctx.save_for_backward(x)
                a = torch.rand_like(x) * x
                return a

            @staticmethod
            def backward(ctx, grad_out):
                (x,) = ctx.saved_tensors
                return grad_out * torch.rand_like(grad_out) * torch.rand_like(x)

        custom_op1 = CustomOp1.apply
        custom_op2 = CustomOp2.apply

        def fn(x):
            a = custom_op1(x)
            b = a.sin()
            return custom_op2(b)

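        # Expected philox_rand counts per traced graph: CustomOp1 has two
        # rand_like calls in its forward and one in its backward; CustomOp2 has
        # one in its forward and two in its backward.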
        fwd_compiler = functools.partial(count_philox_rand, freq=2)
        bwd_compiler = functools.partial(count_philox_rand, freq=1)
        aot_custom_op1 = aot_function(custom_op1, fwd_compiler, bwd_compiler)
        fwd_compiler = functools.partial(count_philox_rand, freq=1)
        bwd_compiler = functools.partial(count_philox_rand, freq=2)
        aot_custom_op2 = aot_function(custom_op2, fwd_compiler, bwd_compiler)

        def aot_fn(x):
            a = aot_custom_op1(x)
            b = a.sin()
            return aot_custom_op2(b)

        for seed in range(10):
            torch.cuda.manual_seed(seed)
            x = torch.rand(*shape, device=device, dtype=dtype, requires_grad=True)
            x_clone = x.clone().detach().requires_grad_(True)

            torch.cuda.manual_seed(seed)
            ref = fn(x)
            ref.sum().backward()

            torch.cuda.manual_seed(seed)
            res = aot_fn(x_clone)
            res.sum().backward()

            self.assertEqual(ref, res)
            self.assertEqual(x.grad, x_clone.grad)

    @dtypes(torch.float32)
    @patch.object(torch._functorch.config, "functionalize_rng_ops", True)
    def test_set_get_rng_state(self, dtype, device):
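        # fn saves the CUDA RNG state after the first rand_like and restores it
        # before the last one, so in eager mode the second and third rand_like
        # calls draw the same random values; the functionalized graph should
        # still contain three philox_rand ops.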
        def fn(x):
            a = torch.rand_like(x) * x
            state = torch.cuda.get_rng_state()
            a = torch.rand_like(x) * a
            torch.cuda.set_rng_state(state)
            a = torch.rand_like(x) * a
            return a

        x = torch.rand(10, device=device, dtype=dtype)

        for seed in range(10):
            torch.cuda.manual_seed(seed)
            ref = fn(x)

            torch.cuda.manual_seed(seed)
            fwd_compiler = functools.partial(count_philox_rand, freq=3)
            aot_fn = aot_function(fn, fwd_compiler)
            res = aot_fn(x)

            self.assertEqual(ref, res)

    @dtypes(torch.float32)
    @patch.object(torch._functorch.config, "functionalize_rng_ops", True)
    def test_min_cut_partitioner(self, dtype, device):
        # Checks that the calling convention is maintained
        shape = (16, 16)

        def fn(x):
            a = torch.rand_like(x) * x
            a = torch.rand_like(x) * a
            a = torch.sin(a)
            a = torch.sin(a)
            a = torch.sin(a)
            return a

        x = torch.rand(*shape, device=device, dtype=dtype, requires_grad=True)

        x_clone = x.clone().detach().requires_grad_(True)

        torch.cuda.manual_seed(123)
        ref = fn(x)
        ref.sum().backward()

        torch.cuda.manual_seed(123)
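        # Both rand_like calls stay in the forward graph (freq=2); with the
        # min-cut partitioner the backward graph is expected to contain no
        # philox_rand ops at all (freq=0).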
        fwd_compiler = functools.partial(count_philox_rand, freq=2)
        bwd_compiler = functools.partial(count_philox_rand, freq=0)
        aot_custom = aot_function(
            fn,
            fwd_compiler,
            bwd_compiler,
            partition_fn=min_cut_rematerialization_partition,
        )
        # aot_custom = aot_function(fn, fwd_compiler, bwd_compiler)
        res = aot_custom(x_clone)
        res.sum().backward()

        self.assertEqual(ref, res)
        self.assertEqual(x.grad, x_clone.grad)

    # TODO - Dropout needs more work because of offset calculation
    @patch.object(torch._functorch.config, "functionalize_rng_ops", True)
    @dtypes(torch.float32)
    def test_checkpoint(self, dtype, device):
        def g(x, y):
            return torch.nn.functional.dropout(x, 0.6)

        def fn(x, y):
            return torch.utils.checkpoint.checkpoint(g, x, y, use_reentrant=False)

        # x = torch.rand(2, 2, device="cuda", requires_grad=True)
        x = torch.ones(2, 2, device="cuda", requires_grad=True)
        y = torch.rand(2, 2, device="cuda", requires_grad=True)
        torch.cuda.manual_seed(123)
        ref = fn(x, y)

        # With checkpointing, we should recompute dropout in bwd, and philox_rand is passed from fwd
        fwd_compiler = functools.partial(count_philox_rand, freq=1)
        bwd_compiler = functools.partial(count_philox_rand, freq=0)
        aot_fn = aot_function(fn, fwd_compiler, bwd_compiler)
        # We can't check accuracy here because rand_like generates different random numbers than dropout
        res = aot_fn(x, y)
        res.sum().backward()

    @dtypes(torch.float32)
    @patch.object(torch._functorch.config, "functionalize_rng_ops", True)
    def test_dropout_decomp(self, dtype, device):
        def fn(x):
            return torch.nn.functional.dropout(x, 0.6) * x

        x = torch.rand(10, device=device, dtype=dtype)

        # Ensure the decomp is happening
        aot_fn = aot_function(fn, functools.partial(count_philox_rand, freq=1))
        # We can't check accuracy here because rand_like generates different random numbers than dropout
        aot_fn(x)


only_for = ("cuda",)
instantiate_device_type_tests(TestFunctionalizationRngOps, globals(), only_for=only_for)


class NegativeTest(TestCase):
    @dtypes(torch.float32)
    @patch.object(torch._functorch.config, "functionalize_rng_ops", True)
    def test_on_cpu(self, dtype, device):
        def fn(x):
            a = torch.rand_like(x) * x
            a = torch.rand_like(x) * a
            return a

        x = torch.rand(10, device=device, dtype=dtype)

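        # functionalize_rng_ops is not expected to work for CPU tensors, so
        # tracing this function through aot_function should raise a RuntimeError.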
        aot_fn = aot_function(fn, nop)
        with self.assertRaises(RuntimeError):
            aot_fn(x)


only_for = ("cpu",)
instantiate_device_type_tests(NegativeTest, globals(), only_for=only_for)

if __name__ == "__main__":
    run_tests()