# Owner(s): ["module: dynamo"]
"""
PYTEST_DONT_REWRITE (prevents pytest from rewriting assertions, which interferes
with test_adam in OptimizerTests)
"""
import functools

import torch
import torch._dynamo
import torch._dynamo.test_case
import torch._dynamo.testing
from torch.nn import Parameter


class MyOptimizer(torch.optim.Optimizer):
    def __init__(self, params):
        super().__init__(params, {})

    def _init_group(self, params, group):
        any_complex = False
        for p in group["params"]:
            params.append(p)
            any_complex |= p.is_complex()
        return any_complex

    def step(self):
        for group in self.param_groups:
            params = []
            any_complex = self._init_group(params, group)
            if any_complex:
                params[0] -= 1
            else:
                params[0] += 1


class End2EndTests(torch._dynamo.test_case.TestCase):
    # https://github.com/pytorch/torchdynamo/issues/1604
    def test_optimizing_over_tensor_with_requires_grad(self):
        class Net(torch.nn.Module):
            def forward(self, x, y):
                z = torch.bmm(x, y)
                z = torch.flatten(z, 1)
                return z

        def training_iter_fn(batch, model, optimizer):
            optimizer.zero_grad()
            out = model(**batch)
            target = torch.tensor([0, 7])
            loss = torch.nn.CrossEntropyLoss()(out, target)
            loss.backward()
            optimizer.step()
            return loss

        net = Net()
        input1 = torch.randn(2, 1, 4)
        input2 = torch.randn(2, 4, 8, requires_grad=True)
        optimizer = torch.optim.Adam([input2], lr=0.1)

        cnts = torch._dynamo.testing.CompileCounter()
        opt_training_iter_fn = torch._dynamo.optimize(cnts)(training_iter_fn)
        batch = {"x": input1, "y": input2}
        for _ in range(2):
            opt_training_iter_fn(batch, net, optimizer)
        self.assertEqual(cnts.frame_count, 2)

    def test_state_dict(self):
        @torch.compile(backend="eager")
        def _test_state_dict(weight, bias, input):
            def fn_base(optimizer, weight, bias):
                optimizer.zero_grad()
                i = input
                loss = (weight.mv(i) + bias).pow(2).sum()
                loss.backward()
                return loss

            optimizer = torch.optim.Adagrad([weight, bias])
            fn = functools.partial(fn_base, optimizer, weight, bias)
            return optimizer, fn

        optimizer, fn = _test_state_dict(
            Parameter(torch.randn(10, 5)),
            Parameter(torch.randn(10)),
            torch.randn(5, requires_grad=True),
        )
        optimizer.step(fn)

    def test_init_group(self):
        for dtype in [torch.float32, torch.cfloat]:
            tensor = torch.randn(5, 5, dtype=dtype)
            params = Parameter(tensor.detach().clone(), requires_grad=False)
            opt_params = Parameter(tensor.detach().clone(), requires_grad=False)

            optim = MyOptimizer([params])
            optim.step()

            opt_optim = MyOptimizer([opt_params])
            opt_step = torch.compile(backend="eager", fullgraph=True)(opt_optim.step)
            opt_step()

            self.assertEqual(params, opt_params)


if __name__ == "__main__":
    from torch._dynamo.test_case import run_tests

    run_tests()