1"""Script to generate baseline values from PyTorch optimization algorithms"""

import argparse
import math
import sys

import torch
import torch.optim


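# Templates for the generated C++ header: HEADER opens the namespace, FOOTER
# closes it, and PARAMETERS is the signature of one per-optimizer function.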
HEADER = """
#include <torch/types.h>

#include <vector>

namespace expected_parameters {
"""

FOOTER = "} // namespace expected_parameters"

PARAMETERS = "inline std::vector<std::vector<torch::Tensor>> {}() {{"
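# A generated function looks roughly like this (illustrative sketch; the real
# output contains the sampled parameter values):
#
#   inline std::vector<std::vector<torch::Tensor>> Adam() {
#     return {
#       {
#         torch::tensor({...}),
#         ...
#       },
#     };
#   }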
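# One factory per test case: each entry maps a name (which becomes the name of
# the generated C++ function) to a constructor over the model's parameters.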
OPTIMIZERS = {
    "LBFGS": lambda p: torch.optim.LBFGS(p, 1.0),
    "LBFGS_with_line_search": lambda p: torch.optim.LBFGS(
        p, 1.0, line_search_fn="strong_wolfe"
    ),
    "Adam": lambda p: torch.optim.Adam(p, 1.0),
    "Adam_with_weight_decay": lambda p: torch.optim.Adam(p, 1.0, weight_decay=1e-2),
    "Adam_with_weight_decay_and_amsgrad": lambda p: torch.optim.Adam(
        p, 1.0, weight_decay=1e-6, amsgrad=True
    ),
    "AdamW": lambda p: torch.optim.AdamW(p, 1.0),
    "AdamW_without_weight_decay": lambda p: torch.optim.AdamW(p, 1.0, weight_decay=0),
    "AdamW_with_amsgrad": lambda p: torch.optim.AdamW(p, 1.0, amsgrad=True),
    "Adagrad": lambda p: torch.optim.Adagrad(p, 1.0),
    "Adagrad_with_weight_decay": lambda p: torch.optim.Adagrad(
        p, 1.0, weight_decay=1e-2
    ),
    "Adagrad_with_weight_decay_and_lr_decay": lambda p: torch.optim.Adagrad(
        p, 1.0, weight_decay=1e-6, lr_decay=1e-3
    ),
    "RMSprop": lambda p: torch.optim.RMSprop(p, 0.1),
    "RMSprop_with_weight_decay": lambda p: torch.optim.RMSprop(
        p, 0.1, weight_decay=1e-2
    ),
    "RMSprop_with_weight_decay_and_centered": lambda p: torch.optim.RMSprop(
        p, 0.1, weight_decay=1e-6, centered=True
    ),
    "RMSprop_with_weight_decay_and_centered_and_momentum": lambda p: torch.optim.RMSprop(
        p, 0.1, weight_decay=1e-6, centered=True, momentum=0.9
    ),
    "SGD": lambda p: torch.optim.SGD(p, 0.1),
    "SGD_with_weight_decay": lambda p: torch.optim.SGD(p, 0.1, weight_decay=1e-2),
    "SGD_with_weight_decay_and_momentum": lambda p: torch.optim.SGD(
        p, 0.1, momentum=0.9, weight_decay=1e-2
    ),
    "SGD_with_weight_decay_and_nesterov_momentum": lambda p: torch.optim.SGD(
        p, 0.1, momentum=0.9, weight_decay=1e-6, nesterov=True
    ),
}


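# Reinitialize Linear layers uniformly in [-1/sqrt(fan_in), 1/sqrt(fan_in)].
# Applied after torch.manual_seed (see run()), so the starting weights are
# reproducible each time the baselines are regenerated.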
def weight_init(module):
    if isinstance(module, torch.nn.Linear):
        stdev = 1.0 / math.sqrt(module.weight.size(1))
        for p in module.parameters():
            # nn.init.uniform_ writes in place under no_grad; it replaces the
            # deprecated p.data.uniform_ idiom without changing the values drawn.
            torch.nn.init.uniform_(p, -stdev, stdev)


def run(optimizer_name, iterations, sample_every):
    torch.manual_seed(0)
    model = torch.nn.Sequential(
        torch.nn.Linear(2, 3),
        torch.nn.Sigmoid(),
        torch.nn.Linear(3, 1),
        torch.nn.Sigmoid(),
    )
    model = model.to(torch.float64).apply(weight_init)

    optimizer = OPTIMIZERS[optimizer_name](model.parameters())

    input = torch.tensor([[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]], dtype=torch.float64)

    values = []
    for i in range(iterations):
        optimizer.zero_grad()

        # Call the module rather than model.forward() so hooks, if any, still run.
        output = model(input)
        loss = output.sum()
        loss.backward()

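        # step() is handed a closure because LBFGS requires one; the other
        # optimizers simply ignore the argument. The closure returns a dummy
        # constant rather than re-evaluating the model, so LBFGS works from
        # the gradients computed above.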
        def closure():
            return torch.tensor([10.0])

        optimizer.step(closure)

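        # Snapshot every parameter every `sample_every` iterations, starting
        # at i == 0 (i.e. the state right after the first optimizer step).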
        if i % sample_every == 0:
            values.append(
                [p.detach().clone().flatten().numpy() for p in model.parameters()]
            )

    return values


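# Print the generated C++ header to stdout: HEADER, then one function per
# optimizer returning the sampled parameter values, then FOOTER.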
def emit(optimizer_parameter_map):
    # Assemble the "@generated" marker at runtime; spelling it out literally
    # would make tooling treat this script itself as a generated file.
    print("// @{} from {}".format("generated", __file__))
    print(HEADER)
    for optimizer_name, parameters in optimizer_parameter_map.items():
        print(PARAMETERS.format(optimizer_name))
        print("  return {")
        for sample in parameters:
            print("    {")
            for parameter in sample:
                parameter_values = "{{{}}}".format(", ".join(map(str, parameter)))
                print(f"      torch::tensor({parameter_values}),")
            print("    },")
        print("  };")
        print("}\n")
    print(FOOTER)


def main():
    parser = argparse.ArgumentParser(
        description="Produce optimization output baseline from PyTorch"
    )
    parser.add_argument("-i", "--iterations", default=1001, type=int)
    parser.add_argument("-s", "--sample-every", default=100, type=int)
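    # The defaults (1001 iterations, sampled every 100) record 11 snapshots
    # per optimizer: i = 0, 100, ..., 1000.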
    options = parser.parse_args()

    optimizer_parameter_map = {}
    for optimizer in OPTIMIZERS:
        sys.stderr.write(f"Evaluating {optimizer} ...\n")
        optimizer_parameter_map[optimizer] = run(
            optimizer, options.iterations, options.sample_every
        )

    emit(optimizer_parameter_map)


if __name__ == "__main__":
    main()