# Owner(s): ["oncall: distributed"]

import unittest
from typing import List, Optional, Tuple

import torch
import torch.distributed
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from torch.optim import Adam, AdamW, SGD
from torch.testing._internal.common_utils import run_tests, TestCase


class MyModule(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        torch.manual_seed(0)
        self.lin1 = nn.Linear(3, 3, bias=False)
        self.lin2 = nn.Linear(3, 3, bias=False)

    def forward(self, t1):
        return self.lin2(F.relu(self.lin1(t1)))


# Dummy class to showcase custom optimizer registration with the functional wrapper.
class MyDummyFnOptimizer:
    def __init__(
        self,
        params: List[Tensor],
        lr: float = 1e-3,
        betas: Tuple[float, float] = (0.9, 0.999),
        eps: float = 1e-6,
        weight_decay: float = 0.0,
        _allow_empty_param_list: bool = False,
    ):
        if not 0.0 <= lr:
            raise ValueError(f"Invalid learning rate: {lr}")
        if not 0.0 <= eps:
            raise ValueError(f"Invalid epsilon value: {eps}")
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError(f"Invalid beta parameter at index 0: {betas[0]}")
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError(f"Invalid beta parameter at index 1: {betas[1]}")
        if not 0.0 <= weight_decay:
            raise ValueError(f"Invalid weight_decay value: {weight_decay}")

        self.defaults = {
            "lr": lr,
            "eps": eps,
            "beta1": betas[0],
            "beta2": betas[1],
            "weight_decay": weight_decay,
        }

        if len(params) == 0 and not _allow_empty_param_list:
            raise ValueError("optimizer got an empty parameter list")

    def step_param(self, param: Tensor, grad: Optional[Tensor]):
        # call the custom optimizer step_param implementation
        with torch.no_grad():
            raise RuntimeError(
                "MyDummyFnOptimizer does not support step_param() as of now"
            )

    def step(self, gradients: List[Optional[Tensor]]):
        # call the custom optimizer step implementation
        with torch.no_grad():
            raise RuntimeError("MyDummyFnOptimizer does not support step() as of now")

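# The parity test below requires a functional optimizer to expose
# step_param(param, grad), which applies the update to a single parameter; the
# dummy class also mirrors the step(gradients) entry point. Illustrative sketch
# (comments only, not executed here) of registering and looking up this class,
# mirroring _test_functional_optim_registration:
#
#     register_functional_optim("MyDummyFnOptimizer", MyDummyFnOptimizer)
#     fn_cls = functional_optim_map["MyDummyFnOptimizer"]
#     fn_optim = fn_cls([], _allow_empty_param_list=True)
#     fn_optim.step_param(param, param.grad)  # raises: step_param is unsupported here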

if torch.distributed.is_available():
    from torch.distributed.optim.utils import (
        functional_optim_map,
        register_functional_optim,
    )

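# functional_optim_map maps the standard optimizer classes exercised below (SGD,
# Adam, AdamW) to their functional counterparts. A minimal usage sketch, assuming
# torch.distributed is available and `model` is a hypothetical nn.Module whose
# parameters already have populated .grad fields (the concrete functional class
# names are an implementation detail of torch.distributed.optim):
#
#     functional_sgd_cls = functional_optim_map[SGD]
#     functional_sgd = functional_sgd_cls([], 1e-2, _allow_empty_param_list=True)
#     for param in model.parameters():
#         functional_sgd.step_param(param, param.grad)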

@unittest.skipIf(
    not torch.distributed.is_available(), "These are testing distributed functions"
)
class TestFunctionalOptimParity(TestCase):
    def _validate_parameters(self, params_1, params_2):
        for p1, p2 in zip(params_1, params_2):
            self.assertEqual(p1, p2)

    # Dynamo fails to compile this helper under Python 3.8/3.11. Since the actual
    # code under test compiles fine, we disable Dynamo for this frame only
    # (recursive=False), so the optimizer code it calls is still compiled.
    @torch._disable_dynamo(recursive=False)
    def _test_functional_optim_parity(self, optim_cls, *args, **kwargs):
        module_optim = MyModule()
        module_functional = MyModule()
        optim_params = module_optim.parameters()
        functional_params = module_functional.parameters()
        optim = optim_cls(optim_params, *args, **kwargs)
        functional_optim_cls = functional_optim_map.get(optim_cls, None)
        if not functional_optim_cls:
            raise ValueError(f"Functional optimizer not implemented for {optim_cls}")
        optim_functional = functional_optim_cls(
            [], *args, **kwargs, _allow_empty_param_list=True
        )
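        # The functional optimizer is deliberately constructed with an empty
        # parameter list; parameters are fed to it one at a time via step_param()
        # in the training loop below.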
        if not hasattr(optim_functional, "step_param"):
            raise ValueError(
                f"Functional optimizer class {optim_functional} must implement step_param method."
            )

        # Initial weights should match
        self._validate_parameters(
            module_optim.parameters(), module_functional.parameters()
        )
        # Save old parameters to verify optimizer modifies them.
        old_module_optim_params = [
            param.clone().detach() for param in module_optim.parameters()
        ]
        old_module_functional_params = [
            param.clone().detach() for param in module_functional.parameters()
        ]

        t1 = torch.randn(3, 3)
        for _ in range(10):
            module_optim.zero_grad()
            module_functional.zero_grad()
            # Forward + Backward
            optim_out = module_optim(t1).sum()
            functional_out = module_functional(t1).sum()
            optim_out.backward()
            functional_out.backward()
            # Optimizer step
            optim.step()
            # Functional optimizer: apply the update one parameter at a time via step_param
            for param in module_functional.parameters():
                grad = param.grad
                optim_functional.step_param(param, grad)

            # Validate parameters are equal
            for optim_param, functional_param in zip(
                module_optim.parameters(), module_functional.parameters()
            ):
                self.assertEqual(optim_param, functional_param)
            # Validate parameters are modified.
            for i, (optim_param, functional_param) in enumerate(
                zip(module_optim.parameters(), module_functional.parameters())
            ):
                self.assertNotEqual(old_module_optim_params[i], optim_param)
                self.assertNotEqual(old_module_functional_params[i], functional_param)

    def _test_functional_optim_registration(self):
        fn_map_key = "MyDummyFnOptimizer"
        fn_optim = MyDummyFnOptimizer
        register_functional_optim(fn_map_key, fn_optim)
        functional_optim_cls = functional_optim_map.get(fn_map_key, None)
        if not functional_optim_cls:
            raise ValueError(f"Functional optimizer not registered for {fn_map_key}")

    def test_functional_optim_registration(self):
        self._test_functional_optim_registration()

    def test_functional_optim_parity_sgd(self):
        self._test_functional_optim_parity(SGD, 1e-2, momentum=0.9, weight_decay=0.01)

    def test_functional_optim_parity_adam(self):
        self._test_functional_optim_parity(Adam, 1e-2, betas=(0.9, 0.999), eps=1e-6)

    def test_functional_optim_parity_adam_w(self):
        self._test_functional_optim_parity(AdamW, 1e-2, betas=(0.9, 0.999), eps=1e-6)


if __name__ == "__main__":
    run_tests()