# Owner(s): ["oncall: pt2"]
import dataclasses
import functools

import torch
from torch import nn
from torch._dynamo import compiled_autograd
from torch._dynamo.test_case import run_tests, TestCase
from torch._dynamo.testing import CompileCounter
from torch.testing._internal.common_utils import IS_MACOS, skipIfRocm, skipIfXpu
from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_CPU, requires_gpu


# Fake distributed
WORLD_SIZE = 2


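# A single-process stand-in for FSDP-style sharding: all_gather just
# concatenates the local shard WORLD_SIZE times and reduce_scatter keeps the
# first 1/WORLD_SIZE slice, so no process group is needed. The forward/backward
# hooks below swap mod.weight between the sharded and unsharded parameters and
# resize the unsharded storage to materialize/free it around each use.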
def init_fake_distributed(device="cpu"):
    @torch.no_grad
    def all_gather(t):
        return torch.cat([t] * WORLD_SIZE, 0)

    @torch.no_grad
    def reduce_scatter(t):
        # clone since reduce_scatter input and output should not be aliases.
        return t.narrow(0, 0, t.size(0) // WORLD_SIZE).clone()

    def fw_pre_hook(mod, inp):
        if not compiled_autograd.compiled_autograd_enabled:
            # torch.ops.fsdp.set_ doesn't work well in eager mode, so use the slow copy_ path instead.
            mod.unsharded_weight.untyped_storage().resize_(
                mod.unsharded_weight.nelement() * mod.unsharded_weight.element_size()
            )
            with torch.no_grad(), torch.autograd._unsafe_preserve_version_counter(
                mod.unsharded_weight
            ):
                mod.unsharded_weight.copy_(all_gather(mod.sharded_weight))
        else:
            with torch.no_grad(), torch.autograd._unsafe_preserve_version_counter(
                mod.unsharded_weight
            ):
                torch.ops.fsdp.set_(
                    mod.unsharded_weight, all_gather(mod.sharded_weight)
                )
        mod._parameters["weight"] = mod.unsharded_weight

    # Forward:
    #   mod.sharded_weight = local_shard (always)
    #   Before:
    #     mod.weight = local_shard
    #     mod.unsharded_weight = zero-sized allgather
    #   After:
    #     mod.weight = local_shard
    #     mod.unsharded_weight = zero-sized allgather
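    #   During the forward itself (between fw_pre_hook and fw_post_hook):
    #     mod.weight = mod.unsharded_weight = materialized allgather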

    def fw_post_hook(mod, inp, out):
        mod._parameters["weight"] = mod.sharded_weight
        mod.unsharded_weight.untyped_storage().resize_(0)

    def bw_pre_hook(mod, gO):
        if not compiled_autograd.compiled_autograd_enabled:
            # torch.ops.fsdp.set_ doesn't work well in eager mode, so use the slow copy_ path instead.
            mod.unsharded_weight.untyped_storage().resize_(
                mod.unsharded_weight.nelement() * mod.unsharded_weight.element_size()
            )
            with torch.no_grad(), torch.autograd._unsafe_preserve_version_counter(
                mod.unsharded_weight
            ):
                mod.unsharded_weight.copy_(all_gather(mod.sharded_weight))
        else:
            with torch.no_grad(), torch.autograd._unsafe_preserve_version_counter(
                mod.unsharded_weight
            ):
                torch.ops.fsdp.set_(
                    mod.unsharded_weight, all_gather(mod.sharded_weight)
                )
        mod._parameters["weight"] = mod.unsharded_weight

    # Backward:
    #   mod.sharded_weight = local_shard (always)
    #   Before:
    #     mod.weight = local_shard
    #     mod.unsharded_weight = zero-sized allgather
    #   After:
    #     mod.weight = local_shard
    #     mod.unsharded_weight = zero-sized allgather
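    #   During the backward itself (between bw_pre_hook and bw_post_hook):
    #     mod.weight = mod.unsharded_weight = materialized allgather
    #   bw_post_hook then reduce-scatters mod.weight.grad back to shard size
    #   and points mod.weight at the sharded parameter again.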

    def bw_post_hook(mod, gI, gO):
        grad = mod.weight.grad
        new_grad = reduce_scatter(grad)
        mod._parameters["weight"] = mod.sharded_weight
        mod.weight.grad = new_grad
        mod.unsharded_weight.untyped_storage().resize_(0)

    torch.manual_seed(1234)
    m = nn.Linear(20, 10, bias=False, device=device)

    # Mimics eager 1st iteration
    m.sharded_weight = nn.Parameter(reduce_scatter(m.weight))
    m.unsharded_weight = nn.Parameter(all_gather(m.sharded_weight))
    m.unsharded_weight.untyped_storage().resize_(0)

    m.register_full_backward_pre_hook(bw_pre_hook)
    m.register_full_backward_hook(bw_post_hook)
    m.register_forward_pre_hook(fw_pre_hook)
    m.register_forward_hook(fw_post_hook)
    return m, torch.rand(2, 20, requires_grad=True, device=device)


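# Build an nn.Linear whose full-backward pre/post hooks count their invocations
# and perturb the incoming/outgoing gradients. With allow_eager=False the hooks
# assert they only ever run while compiling (under torch.compile / compiled
# autograd).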
def init_module_bw_hooks(allow_eager):
    def bw_pre_hook(mod, gO):
        assert allow_eager or torch._dynamo.is_compiling()
        assert mod.weight.size() == (10, 10)
        mod.hook_count_pre.add_(1)
        return (torch.sin(gO[0] + 1.2),)

    def bw_post_hook(mod, gI, gO):
        assert allow_eager or torch._dynamo.is_compiling()
        assert mod.weight.size() == (10, 10)
        mod.hook_count_post.add_(1)
        return (torch.sin(gI[0] + 3.4),)

    torch.manual_seed(1234)
    m = nn.Linear(10, 10)
    m.hook_count_pre = torch.tensor(0)
    m.hook_count_post = torch.tensor(0)
    m.register_full_backward_pre_hook(bw_pre_hook)
    m.register_full_backward_hook(bw_post_hook)
    return m, torch.rand(2, 10, requires_grad=True)


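# Run several forward/backward iterations. The first backward sees .grad=None
# and later ones accumulate into an existing .grad, which is why some
# compiled-autograd counters below expect exactly two frames.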
def steps(m, inp):
    for _ in range(4):
        out = m(inp)
        out.sum().backward()
    return out


class DistributedPatternTests(TestCase):
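    # Tensor-level .register_hook with a closure capturing both a nonlocal
    # tensor and a dataclass field, checked eager vs. torch.compile with
    # compiled autograd enabled.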
    def test_intermediate_hook_with_closure(self):
        @dataclasses.dataclass
        class CustomObj:
            val: torch.Tensor

        def fn(x, obj):
            y = x.sin()
            closure_var = y + 1
            y.register_hook(lambda grad: grad + obj.val + closure_var)
            z = y.sin()
            return z

        opt = torch.compile(fn, fullgraph=True)

        obj1 = CustomObj(torch.tensor(88))
        obj2 = CustomObj(torch.tensor(99))
        x0 = torch.ones(4, requires_grad=True)
        x1 = torch.ones(4, requires_grad=True)
        x2 = torch.ones(4, requires_grad=True)
        x3 = torch.ones(4, requires_grad=True)
        fn(x0, obj1).sum().backward()
        fn(x1, obj2).sum().backward()

        with compiled_autograd.enable(functools.partial(torch.compile, fullgraph=True)):
            opt(x2, obj1).sum().backward()
            opt(x3, obj2).sum().backward()

        self.assertEqual(x0.grad, x2.grad)
        self.assertEqual(x1.grad, x3.grad)

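    # Resizing an input's untyped storage to zero inside a compiled region,
    # as the fake-FSDP hooks above do to free the unsharded weight; the
    # storage must really end up freed after the call.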
    @torch.no_grad()
    def _test_storage_resize_zero(self, device):
        @torch.compile(fullgraph=True)
        def fn(x):
            y = torch.sin(x)
            x.untyped_storage().resize_(0)
            return torch.cos(y)

        x = torch.randn(10, device=device)
        expected = torch.cos(torch.sin(x))
        y = fn(x)
        self.assertEqual(y, expected)
        self.assertEqual(x.untyped_storage().size(), 0)

    def test_storage_resize_zero_cpu(self):
        self._test_storage_resize_zero("cpu")

    @skipIfRocm
    @requires_gpu()
    def test_storage_resize_zero_gpu(self):
        self._test_storage_resize_zero(GPU_TYPE)

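    # The reverse direction: a zero-sized storage is resized back up inside
    # the compiled region and then filled via copy_, matching the unshard
    # path of the hooks above.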
    @torch.no_grad()
    def _test_storage_resize_nonzero(self, device):
        @torch.compile(fullgraph=True)
        def fn(x, out):
            y = torch.sin(x)
            assert out.untyped_storage().size() == 0
            out.untyped_storage().resize_(x.untyped_storage().size())
            out.copy_(y.cos())

        x = torch.randn(10, device=device)
        out = torch.randn(10, device=device)
        expected = torch.cos(torch.sin(x))
        out.untyped_storage().resize_(0)
        fn(x, out)
        self.assertEqual(out.untyped_storage().size(), x.untyped_storage().size())
        self.assertEqual(out, expected)

    def test_storage_resize_nonzero_cpu(self):
        self._test_storage_resize_nonzero("cpu")

    @skipIfRocm
    @requires_gpu()
    def test_storage_resize_nonzero_gpu(self):
        self._test_storage_resize_nonzero(GPU_TYPE)

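    # torch._C._autograd._unsafe_set_version_counter lets the compiled region
    # restore a tensor's version counter after an in-place copy_, so the
    # mutation does not bump the version seen by autograd.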
    @torch.no_grad()
    def test_unsafe_set_version_counter1(self):
        cnt = CompileCounter()

        @torch.compile(backend=cnt, fullgraph=True)
        def fn(w, x):
            x = x.sin()
            v = w._version
            w.copy_(x + 1)
            torch._C._autograd._unsafe_set_version_counter(w, v)
            return w, v

        for v in (3, 0, 1):
            w1 = torch.randn(16)
            for i in range(v):
                w1.fill_(i)  # bump w1._version
            self.assertEqual(w1._version, v)
            x1 = torch.randn(16)
            w2, v2 = fn(w1, x1)

            self.assertIs(w1, w2)
            self.assertEqual(w1, x1.sin() + 1)
            self.assertEqual(v2, v)
            self.assertEqual(w1._version, v)
            self.assertEqual(cnt.frame_count, 1)

    def test_unsafe_set_version_counter2(self):
        @torch.compile(backend="inductor", fullgraph=True)
        def fn(w, x):
            r = w.sin()
            with torch.no_grad():
                v = w._version
                w.copy_(x)
                torch._C._autograd._unsafe_set_version_counter(w, v)
            return r

        w1 = torch.randn(1, requires_grad=True)
        x1 = torch.randn(1)
        expected_r1 = w1.detach().sin()

        r1 = fn(w1, x1)
        r1.backward()
        self.assertEqual(r1, expected_r1)
        self.assertEqual(w1, x1)
        self.assertEqual(w1.grad, x1.cos())

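    # Same idea via the context-manager API used by the fake-FSDP hooks:
    # torch.autograd._unsafe_preserve_version_counter restores the version
    # counter on exit, so the in-place copy_ is invisible to version checks.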
    @torch.no_grad()
    def test_unsafe_preserve_version_counter1(self):
        @torch.compile(backend="eager", fullgraph=True)
        def fn(w, x):
            x = x.sin()
            with torch.autograd._unsafe_preserve_version_counter(w):
                w.copy_(x + 1)
            return w

        w1 = torch.randn(16).fill_(0).fill_(1)
        x1 = torch.randn(16)
        v1 = w1._version
        w2 = fn(w1, x1)
        v2 = w1._version

        self.assertIs(w1, w2)
        self.assertEqual(w1, x1.sin() + 1)
        self.assertEqual(v1, v2)

    def test_unsafe_preserve_version_counter2(self):
        @torch.compile(backend="inductor", fullgraph=True)
        def fn(w, x):
            r = w.sin()
            with torch.no_grad(), torch.autograd._unsafe_preserve_version_counter(w):
                w.copy_(x)
            return r

        w1 = torch.randn(1, requires_grad=True)
        x1 = torch.randn(1)
        expected_r1 = w1.detach().sin()

        r1 = fn(w1, x1)
        r1.backward()
        self.assertEqual(r1, expected_r1)
        self.assertEqual(w1, x1)
        self.assertEqual(w1.grad, x1.cos())

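    # The module-hook tests below always run an eager reference module first,
    # then the same module under torch.compile with compiled autograd, and
    # compare hook counts, outputs, and gradients between the two.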
    def test_module_backward_hooks_eager(self):
        m1, inp1 = init_module_bw_hooks(True)
        out1 = steps(m1, inp1)

        m2, inp2 = init_module_bw_hooks(False)
        fw_cnt = CompileCounter()
        bw_cnt = CompileCounter()
        with compiled_autograd.enable(torch.compile(backend=bw_cnt, fullgraph=True)):
            m2 = torch.compile(m2, backend=fw_cnt, fullgraph=True)
            out2 = steps(m2, inp2)

        self.assertEqual(m1.hook_count_pre, m2.hook_count_pre)
        self.assertEqual(m1.hook_count_post, m2.hook_count_post)
        self.assertEqual(out1, out2)
        self.assertEqual(inp1.grad, inp2.grad)
        self.assertEqual(m1.weight.grad, m2.weight.grad)
        self.assertEqual(m1.bias.grad, m2.bias.grad)

        self.assertEqual(fw_cnt.frame_count, 1)
        self.assertEqual(fw_cnt.op_count, 5)
        self.assertEqual(bw_cnt.frame_count, 2)  # grad=None and grad!=None
        self.assertEqual(bw_cnt.op_count, 48)

    def test_module_backward_hooks_aot(self):
        m1, inp1 = init_module_bw_hooks(True)
        out1 = steps(m1, inp1)

        m2, inp2 = init_module_bw_hooks(True)
        m2 = torch.compile(m2, backend="aot_eager", fullgraph=True)
        with compiled_autograd.enable(lambda gm: gm):
            out2 = steps(m2, inp2)

        self.assertEqual(m1.hook_count_pre, m2.hook_count_pre)
        self.assertEqual(m1.hook_count_post, m2.hook_count_post)
        self.assertEqual(out1, out2)
        self.assertEqual(inp1.grad, inp2.grad)
        self.assertEqual(m1.weight.grad, m2.weight.grad)
        self.assertEqual(m1.bias.grad, m2.bias.grad)

    def test_module_backward_hooks_inductor(self):
        m1, inp1 = init_module_bw_hooks(True)
        out1 = steps(m1, inp1)

        m2, inp2 = init_module_bw_hooks(False)
        m2 = torch.compile(m2, fullgraph=True)
        with compiled_autograd.enable(torch.compile(fullgraph=True)):
            out2 = steps(m2, inp2)

        self.assertEqual(m1.hook_count_pre, m2.hook_count_pre)
        self.assertEqual(m1.hook_count_post, m2.hook_count_post)
        self.assertEqual(out1, out2)
        self.assertEqual(inp1.grad, inp2.grad)
        self.assertEqual(m1.weight.grad, m2.weight.grad)
        self.assertEqual(m1.bias.grad, m2.bias.grad)

    def test_module_backward_hooks_multi_layers(self):
        a1, inp1 = init_module_bw_hooks(True)
        b1, _ = init_module_bw_hooks(True)
        out1 = steps(torch.nn.Sequential(a1, b1), inp1)

        a2, inp2 = init_module_bw_hooks(False)
        b2, _ = init_module_bw_hooks(False)
        with compiled_autograd.enable(torch.compile(fullgraph=True)):
            out2 = steps(
                torch.compile(torch.nn.Sequential(a2, b2), fullgraph=True), inp2
            )

        self.assertEqual(a1.hook_count_pre, a2.hook_count_pre)
        self.assertEqual(a1.hook_count_post, a2.hook_count_post)
        self.assertEqual(b1.hook_count_pre, b2.hook_count_pre)
        self.assertEqual(b1.hook_count_post, b2.hook_count_post)
        self.assertEqual(out1, out2)
        self.assertEqual(inp1.grad, inp2.grad)
        self.assertEqual(a1.weight.grad, a2.weight.grad)
        self.assertEqual(a1.bias.grad, a2.bias.grad)
        self.assertEqual(b1.weight.grad, b2.weight.grad)
        self.assertEqual(b1.bias.grad, b2.bias.grad)

    # TODO(jansel): support bw hooks with graph break

    def _assert_same_grad(self, a, b):
        self.assertEqual(type(a), type(b))
        self.assertEqual(a, b)
        self.assertEqual(a.grad, b.grad)
        self.assertEqual(a.requires_grad, b.requires_grad)

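    # Constructing nn.Parameter inside the compiled function (from an input or
    # from an intermediate, with and without requires_grad) must return a real
    # Parameter that matches eager, including its gradient behavior.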
    def test_nn_param_return1(self):
        def fn(x):
            p = torch.nn.Parameter(x)
            return p, p.sin()

        opt = torch.compile(fn, fullgraph=True)
        x1 = torch.randn(16)
        x2 = x1.clone()

        p1, r1 = fn(x1)
        r1.sum().backward()
        p2, r2 = opt(x2)
        r2.sum().backward()
        self._assert_same_grad(r1, r2)
        self._assert_same_grad(p1, p2)

    def test_nn_param_return2(self):
        def fn(x):
            p = torch.nn.Parameter(x, requires_grad=False)
            return p, x + 1

        opt = torch.compile(fn, fullgraph=True)
        x1 = torch.randn(16)
        x2 = x1.clone()

        p1, r1 = fn(x1)
        p2, r2 = opt(x2)
        self._assert_same_grad(r1, r2)
        self._assert_same_grad(p1, p2)

    def test_nn_param_return3(self):
        def fn(x):
            p = torch.nn.Parameter(x + 123)
            return p, p.sin()

        opt = torch.compile(fn, fullgraph=True)
        x1 = torch.randn(16)
        x2 = x1.clone()

        p1, r1 = fn(x1)
        r1.sum().backward()
        p2, r2 = opt(x2)
        r2.sum().backward()
        self._assert_same_grad(r1, r2)
        self._assert_same_grad(p1, p2)

    def test_nn_param_return4(self):
        def fn(x):
            p = torch.nn.Parameter(x + 123, requires_grad=False)
            return p, x + 1

        opt = torch.compile(fn, fullgraph=True)
        x1 = torch.randn(16)
        x2 = x1.clone()

        p1, r1 = fn(x1)
        p2, r2 = opt(x2)
        self._assert_same_grad(r1, r2)
        self._assert_same_grad(p1, p2)

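    # End-to-end checks of the fake-FSDP module from init_fake_distributed:
    # eager vs. torch.compile with compiled autograd (under
    # recompute_views=True), comparing weights, inputs, and outputs along
    # with their gradients.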
    @torch._functorch.config.patch(recompute_views=True)
    def test_fake_distributed_aot_eager(self):
        m1, inp1 = init_fake_distributed()
        out1 = steps(m1, inp1)

        m2, inp2 = init_fake_distributed()
        m2 = torch.compile(m2, backend="aot_eager", fullgraph=True)
        bw_cnt = CompileCounter()
        with compiled_autograd.enable(torch.compile(backend=bw_cnt, fullgraph=True)):
            out2 = steps(m2, inp2)

        self._assert_same_grad(m1.weight, m2.weight)
        self._assert_same_grad(inp1, inp2)
        self._assert_same_grad(out1, out2)
        # Recompile on grad==None/grad!=None
        self.assertEqual(bw_cnt.frame_count, 2)

    @skipIfRocm
    @skipIfXpu
    @requires_gpu()
    @torch._functorch.config.patch(recompute_views=True)
    def test_fake_distributed_inductor(self):
        # TODO: fix .set_ lowering in CPU inductor, and enable the CPU test.
        m1, inp1 = init_fake_distributed(GPU_TYPE)
        out1 = steps(m1, inp1)

        m2, inp2 = init_fake_distributed(GPU_TYPE)
        m2 = torch.compile(m2, fullgraph=True)
        with compiled_autograd.enable(torch.compile(fullgraph=True)):
            out2 = steps(m2, inp2)

        self._assert_same_grad(m1.weight, m2.weight)
        self._assert_same_grad(inp1, inp2)
        self._assert_same_grad(out1, out2)


if __name__ == "__main__":
    if HAS_CPU and not IS_MACOS:
        run_tests(needs="filelock")