1"""Script to generate baseline values from PyTorch optimization algorithms""" 2 3import argparse 4import math 5import sys 6 7import torch 8import torch.optim 9 10 11HEADER = """ 12#include <torch/types.h> 13 14#include <vector> 15 16namespace expected_parameters { 17""" 18 19FOOTER = "} // namespace expected_parameters" 20 21PARAMETERS = "inline std::vector<std::vector<torch::Tensor>> {}() {{" 22 23OPTIMIZERS = { 24 "LBFGS": lambda p: torch.optim.LBFGS(p, 1.0), 25 "LBFGS_with_line_search": lambda p: torch.optim.LBFGS( 26 p, 1.0, line_search_fn="strong_wolfe" 27 ), 28 "Adam": lambda p: torch.optim.Adam(p, 1.0), 29 "Adam_with_weight_decay": lambda p: torch.optim.Adam(p, 1.0, weight_decay=1e-2), 30 "Adam_with_weight_decay_and_amsgrad": lambda p: torch.optim.Adam( 31 p, 1.0, weight_decay=1e-6, amsgrad=True 32 ), 33 "AdamW": lambda p: torch.optim.AdamW(p, 1.0), 34 "AdamW_without_weight_decay": lambda p: torch.optim.AdamW(p, 1.0, weight_decay=0), 35 "AdamW_with_amsgrad": lambda p: torch.optim.AdamW(p, 1.0, amsgrad=True), 36 "Adagrad": lambda p: torch.optim.Adagrad(p, 1.0), 37 "Adagrad_with_weight_decay": lambda p: torch.optim.Adagrad( 38 p, 1.0, weight_decay=1e-2 39 ), 40 "Adagrad_with_weight_decay_and_lr_decay": lambda p: torch.optim.Adagrad( 41 p, 1.0, weight_decay=1e-6, lr_decay=1e-3 42 ), 43 "RMSprop": lambda p: torch.optim.RMSprop(p, 0.1), 44 "RMSprop_with_weight_decay": lambda p: torch.optim.RMSprop( 45 p, 0.1, weight_decay=1e-2 46 ), 47 "RMSprop_with_weight_decay_and_centered": lambda p: torch.optim.RMSprop( 48 p, 0.1, weight_decay=1e-6, centered=True 49 ), 50 "RMSprop_with_weight_decay_and_centered_and_momentum": lambda p: torch.optim.RMSprop( 51 p, 0.1, weight_decay=1e-6, centered=True, momentum=0.9 52 ), 53 "SGD": lambda p: torch.optim.SGD(p, 0.1), 54 "SGD_with_weight_decay": lambda p: torch.optim.SGD(p, 0.1, weight_decay=1e-2), 55 "SGD_with_weight_decay_and_momentum": lambda p: torch.optim.SGD( 56 p, 0.1, momentum=0.9, weight_decay=1e-2 57 ), 58 "SGD_with_weight_decay_and_nesterov_momentum": lambda p: torch.optim.SGD( 59 p, 0.1, momentum=0.9, weight_decay=1e-6, nesterov=True 60 ), 61} 62 63 64def weight_init(module): 65 if isinstance(module, torch.nn.Linear): 66 stdev = 1.0 / math.sqrt(module.weight.size(1)) 67 for p in module.parameters(): 68 p.data.uniform_(-stdev, stdev) 69 70 71def run(optimizer_name, iterations, sample_every): 72 torch.manual_seed(0) 73 model = torch.nn.Sequential( 74 torch.nn.Linear(2, 3), 75 torch.nn.Sigmoid(), 76 torch.nn.Linear(3, 1), 77 torch.nn.Sigmoid(), 78 ) 79 model = model.to(torch.float64).apply(weight_init) 80 81 optimizer = OPTIMIZERS[optimizer_name](model.parameters()) 82 83 input = torch.tensor([[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]], dtype=torch.float64) 84 85 values = [] 86 for i in range(iterations): 87 optimizer.zero_grad() 88 89 output = model.forward(input) 90 loss = output.sum() 91 loss.backward() 92 93 def closure(): 94 return torch.tensor([10.0]) 95 96 optimizer.step(closure) 97 98 if i % sample_every == 0: 99 values.append( 100 [p.clone().flatten().data.numpy() for p in model.parameters()] 101 ) 102 103 return values 104 105 106def emit(optimizer_parameter_map): 107 # Don't write generated with an @ in front, else this file is recognized as generated. 
    print("// @{} from {}".format("generated", __file__))
    print(HEADER)
    for optimizer_name, parameters in optimizer_parameter_map.items():
        print(PARAMETERS.format(optimizer_name))
        print("  return {")
        for sample in parameters:
            print("    {")
            for parameter in sample:
                parameter_values = "{{{}}}".format(", ".join(map(str, parameter)))
                print(f"      torch::tensor({parameter_values}),")
            print("    },")
        print("  };")
        print("}\n")
    print(FOOTER)


def main():
    parser = argparse.ArgumentParser(
        description="Produce optimization output baseline from PyTorch"
    )
    parser.add_argument("-i", "--iterations", default=1001, type=int)
    parser.add_argument("-s", "--sample-every", default=100, type=int)
    options = parser.parse_args()

    optimizer_parameter_map = {}
    for optimizer in OPTIMIZERS.keys():
        sys.stderr.write(f"Evaluating {optimizer} ...\n")
        optimizer_parameter_map[optimizer] = run(
            optimizer, options.iterations, options.sample_every
        )

    emit(optimizer_parameter_map)


if __name__ == "__main__":
    main()