# Owner(s): ["oncall: distributed"]

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.


import unittest
from copy import deepcopy

import torch
import torch.nn as nn
from torch.distributed.optim import (
    _apply_optimizer_in_backward,
    _get_in_backward_optimizers,
)

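# These tests exercise _apply_optimizer_in_backward, which sets up one
# in-backward optimizer per parameter so that the optimizer step for a
# parameter happens during backward() rather than in a separate
# optimizer.step() call, and _get_in_backward_optimizers, which collects
# those per-parameter optimizers from a module.
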
# TODO (rohan-varma): Add FSDP & DDP tests once supported


def _validate_params(params_list, fn):
    # Materialize the reference parameters so they can be compared against
    # every other parameter list (parameters() returns a one-shot generator).
    ref_params = list(params_list[0])
    for param_list in params_list[1:]:
        for p1, p2 in zip(ref_params, param_list):
            fn(p1, p2)


class ApplyOverlappedOptimizerTest(unittest.TestCase):
    def _run_training_loop_and_validate(self, inp, models, optimizers):
        # Models with in-backward optimizers are updated during backward();
        # the remaining models are stepped explicitly via `optimizers`. After
        # every iteration, all models must have identical parameters.
        for i in range(6):
            for model in models:
                model(inp).sum().backward()
            for opt in optimizers:
                opt.step()

            with self.subTest(i):
                _validate_params(
                    [model.parameters() for model in models],
                    torch.testing.assert_allclose,
                )

            for opt in optimizers:
                opt.zero_grad(set_to_none=True)

    def _test_apply_optimizer_in_backward(self, share_params) -> None:
        weight_optimizer_kwargs = {"lr": 1.0}
        bias_optimizer_kwargs = {"lr": 0.5}
        model = nn.Sequential(nn.Linear(10, 10), nn.Linear(10, 10))
        if share_params:
            model[0].weight = model[1].weight

        # Use different optimizers for weights & biases.
        weights = [m.weight for m in model]
        biases = [m.bias for m in model]
        optim_weight = torch.optim.SGD(weights, **weight_optimizer_kwargs)
        optim_bias = torch.optim.SGD(biases, **bias_optimizer_kwargs)
        model_with_opt_in_bwd = deepcopy(model)

        # Apply different in-backward optimizers for weights and biases.
        _apply_optimizer_in_backward(
            torch.optim.SGD,
            [m.weight for m in model_with_opt_in_bwd],
            optimizer_kwargs=weight_optimizer_kwargs,
        )

        _apply_optimizer_in_backward(
            torch.optim.SGD,
            [m.bias for m in model_with_opt_in_bwd],
            optimizer_kwargs=bias_optimizer_kwargs,
        )

        # Applying the in-backward optimizers should not by itself change the
        # parameters.
        _validate_params(
            [
                model.parameters(),
                model_with_opt_in_bwd.parameters(),
            ],
            torch.testing.assert_allclose,
        )

        self._run_training_loop_and_validate(
            torch.randn(4, 10),
            [model, model_with_opt_in_bwd],
            [optim_weight, optim_bias],
        )

    def test_apply_optimizer_in_backward(self) -> None:
        self._test_apply_optimizer_in_backward(share_params=False)

    def test_apply_optimizer_in_backward_shared_params(self) -> None:
        self._test_apply_optimizer_in_backward(share_params=True)

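    # With register_hook=False, backward() alone must not change the
    # parameters; with the default register_hook=True, it must.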
    def test_no_register_hook(self):
        model_with_hook = nn.Sequential(nn.Linear(10, 10), nn.Linear(10, 10))
        initial_model = deepcopy(model_with_hook)
        model_no_hook = deepcopy(model_with_hook)
        _apply_optimizer_in_backward(
            torch.optim.SGD,
            model_with_hook.parameters(),
            optimizer_kwargs={"lr": 0.03},
        )
        _apply_optimizer_in_backward(
            torch.optim.SGD,
            model_no_hook.parameters(),
            optimizer_kwargs={"lr": 0.03},
            register_hook=False,
        )
        inp = torch.randn(4, 10)
        model_with_hook(inp).sum().backward()
        model_no_hook(inp).sum().backward()

        for p1, p2 in zip(model_with_hook.parameters(), initial_model.parameters()):
            with self.assertRaises(AssertionError):
                torch.testing.assert_allclose(p1, p2)

        for p1, p2 in zip(model_no_hook.parameters(), initial_model.parameters()):
            torch.testing.assert_allclose(p1, p2)

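    # Applying _apply_optimizer_in_backward twice to the same parameters
    # should behave like stepping two separate SGD optimizers over those
    # parameters, which the parity check below verifies.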
    def test_multiple_optim_for_params(self) -> None:
        model = nn.Sequential(nn.Linear(10, 10), nn.Linear(10, 10))
        opt_0_kwargs = {"lr": 0.03}
        opt_1_kwargs = {"lr": 0.01}
        opt_0 = torch.optim.SGD(model.parameters(), **opt_0_kwargs)
        opt_1 = torch.optim.SGD(model.parameters(), **opt_1_kwargs)
        model_with_opt_in_bwd = deepcopy(model)
        _apply_optimizer_in_backward(
            torch.optim.SGD,
            model_with_opt_in_bwd.parameters(),
            optimizer_kwargs=opt_0_kwargs,
        )
        _apply_optimizer_in_backward(
            torch.optim.SGD,
            model_with_opt_in_bwd.parameters(),
            optimizer_kwargs=opt_1_kwargs,
        )
        self._run_training_loop_and_validate(
            torch.randn(4, 10),
            [model, model_with_opt_in_bwd],
            [opt_0, opt_1],
        )

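    # _get_in_backward_optimizers should return one optimizer per parameter,
    # matching the optimizers recorded on each parameter's
    # _in_backward_optimizers attribute.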
    def test_get_optimizers_in_backward(self):
        # Create a simple test model
        class TestModel(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.linear1 = torch.nn.Linear(10, 5)
                self.linear2 = torch.nn.Linear(5, 2)

        model = TestModel()

        # Apply optimizers in backward
        _apply_optimizer_in_backward(torch.optim.SGD, model.parameters(), {"lr": 0.01})
        in_backward_optims = _get_in_backward_optimizers(model)
        self.assertEqual(len(list(model.parameters())), len(in_backward_optims))
        result = set(in_backward_optims)
        expected = {
            optim for p in model.parameters() for optim in p._in_backward_optimizers
        }
        self.assertEqual(result, expected)
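

# Convenience entry point so this file can also be run standalone with plain
# unittest.
if __name__ == "__main__":
    unittest.main()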
163