# Owner(s): ["module: dynamo"]
"""
PYTEST_DONT_REWRITE (prevents pytest from rewriting assertions, which interferes
with test_adam in OptimizerTests)
"""
import functools

import torch
import torch._dynamo
import torch._dynamo.test_case
import torch._dynamo.testing
from torch.nn import Parameter


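# A deliberately minimal optimizer: step() branches on whether any parameter
# is complex, so compiling it exercises _init_group and data-dependent
# control flow inside an Optimizer subclass.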
class MyOptimizer(torch.optim.Optimizer):
    def __init__(self, params):
        super().__init__(params, {})

    def _init_group(self, params, group):
        any_complex = False
        for p in group["params"]:
            params.append(p)
            any_complex |= p.is_complex()
        return any_complex

    def step(self):
        for group in self.param_groups:
            params = []
            any_complex = self._init_group(params, group)
            if any_complex:
                params[0] -= 1
            else:
                params[0] += 1


class End2EndTests(torch._dynamo.test_case.TestCase):
    # https://github.com/pytorch/torchdynamo/issues/1604
    def test_optimizing_over_tensor_with_requires_grad(self):
        class Net(torch.nn.Module):
            def forward(self, x, y):
                z = torch.bmm(x, y)
                z = torch.flatten(z, 1)
                return z

        def training_iter_fn(batch, model, optimizer):
            optimizer.zero_grad()
            out = model(**batch)
            target = torch.tensor([0, 7])
            loss = torch.nn.CrossEntropyLoss()(out, target)
            loss.backward()
            optimizer.step()
            return loss

        net = Net()
        input1 = torch.randn(2, 1, 4)
        input2 = torch.randn(2, 4, 8, requires_grad=True)
        optimizer = torch.optim.Adam([input2], lr=0.1)
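        # Note: the optimizer above updates input2, a graph input with
        # requires_grad=True, rather than the module's parameters (the case
        # from the linked issue).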

        cnts = torch._dynamo.testing.CompileCounter()
        opt_training_iter_fn = torch._dynamo.optimize(cnts)(training_iter_fn)
        batch = {"x": input1, "y": input2}
        for _ in range(2):
            opt_training_iter_fn(batch, net, optimizer)
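        # frame_count is the number of graphs Dynamo compiled; expecting 2
        # likely reflects a graph break inside training_iter_fn (e.g., around
        # loss.backward()), while the second iteration hits the compile cache
        # instead of recompiling.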
        self.assertEqual(cnts.frame_count, 2)

    def test_state_dict(self):
        @torch.compile(backend="eager")
        def _test_state_dict(weight, bias, input):
            def fn_base(optimizer, weight, bias):
                optimizer.zero_grad()
                i = input
                loss = (weight.mv(i) + bias).pow(2).sum()
                loss.backward()
                return loss

            optimizer = torch.optim.Adagrad([weight, bias])
            fn = functools.partial(fn_base, optimizer, weight, bias)
            return optimizer, fn

        optimizer, fn = _test_state_dict(
            Parameter(torch.randn(10, 5)),
            Parameter(torch.randn(10)),
            torch.randn(5, requires_grad=True),
        )
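        # Optimizer.step() accepts an optional closure that re-evaluates the
        # loss; fn zeroes the grads, recomputes the loss, and runs backward()
        # before Adagrad applies the update.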
        optimizer.step(fn)

    def test_init_group(self):
        for dtype in [torch.float32, torch.cfloat]:
            tensor = torch.randn(5, 5, dtype=dtype)
            params = Parameter(tensor.detach().clone(), requires_grad=False)
            opt_params = Parameter(tensor.detach().clone(), requires_grad=False)

            optim = MyOptimizer([params])
            optim.step()

            opt_optim = MyOptimizer([opt_params])
            opt_step = torch.compile(backend="eager", fullgraph=True)(opt_optim.step)
            opt_step()

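            # Eager and compiled step() should take the same dtype-dependent
            # branch, so the two parameter copies must end up identical.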
            self.assertEqual(params, opt_params)


if __name__ == "__main__":
    from torch._dynamo.test_case import run_tests

    run_tests()