xref: /aosp_15_r20/external/pytorch/test/cpp/api/optim_baseline.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 // @generated from optim_baseline.py
2 
3 #include <torch/types.h>
4 
5 #include <vector>
6 
7 namespace expected_parameters {
8 
LBFGS()9 inline std::vector<std::vector<torch::Tensor>> LBFGS() {
10   return {
11       {
12           torch::tensor(
13               {-0.20959197386869663,
14                -0.49580870398532073,
15                -0.1313442585372408,
16                -0.3287331939506787,
17                -0.24613947168465267,
18                0.705889510763571}),
19           torch::tensor(
20               {-0.10412662274500666, -0.2644705062031845, 0.7102859961803084}),
21           torch::tensor(
22               {-0.19787984636009417, -0.5320223708266223, -0.5396083236337847}),
23           torch::tensor({-0.43108206822505857}),
24       },
25       {
26           torch::tensor(
27               {0.4377600774755075,
28                0.3828823919505896,
29                0.5308031277873992,
30                0.5752746453369446,
31                0.23943592910168343,
32                1.3739197373644627}),
33           torch::tensor(
34               {2.209263823172053, 2.154134023426646, 2.534834254325867}),
35           torch::tensor(
36               {-4.091952315741579, -4.67916063385269, -4.781279234594454}),
37           torch::tensor({-4.776742087865583}),
38       },
39       {
40           torch::tensor(
41               {0.4377600774755075,
42                0.3828823919505896,
43                0.5308031277873992,
44                0.5752746453369446,
45                0.23943592910168343,
46                1.3739197373644627}),
47           torch::tensor(
48               {2.209263823172053, 2.154134023426646, 2.534834254325867}),
49           torch::tensor(
50               {-4.091952315741579, -4.67916063385269, -4.781279234594454}),
51           torch::tensor({-4.776742087865583}),
52       },
53       {
54           torch::tensor(
55               {0.4377600774755075,
56                0.3828823919505896,
57                0.5308031277873992,
58                0.5752746453369446,
59                0.23943592910168343,
60                1.3739197373644627}),
61           torch::tensor(
62               {2.209263823172053, 2.154134023426646, 2.534834254325867}),
63           torch::tensor(
64               {-4.091952315741579, -4.67916063385269, -4.781279234594454}),
65           torch::tensor({-4.776742087865583}),
66       },
67       {
68           torch::tensor(
69               {0.4377600774755075,
70                0.3828823919505896,
71                0.5308031277873992,
72                0.5752746453369446,
73                0.23943592910168343,
74                1.3739197373644627}),
75           torch::tensor(
76               {2.209263823172053, 2.154134023426646, 2.534834254325867}),
77           torch::tensor(
78               {-4.091952315741579, -4.67916063385269, -4.781279234594454}),
79           torch::tensor({-4.776742087865583}),
80       },
81       {
82           torch::tensor(
83               {0.4377600774755075,
84                0.3828823919505896,
85                0.5308031277873992,
86                0.5752746453369446,
87                0.23943592910168343,
88                1.3739197373644627}),
89           torch::tensor(
90               {2.209263823172053, 2.154134023426646, 2.534834254325867}),
91           torch::tensor(
92               {-4.091952315741579, -4.67916063385269, -4.781279234594454}),
93           torch::tensor({-4.776742087865583}),
94       },
95       {
96           torch::tensor(
97               {0.4377600774755075,
98                0.3828823919505896,
99                0.5308031277873992,
100                0.5752746453369446,
101                0.23943592910168343,
102                1.3739197373644627}),
103           torch::tensor(
104               {2.209263823172053, 2.154134023426646, 2.534834254325867}),
105           torch::tensor(
106               {-4.091952315741579, -4.67916063385269, -4.781279234594454}),
107           torch::tensor({-4.776742087865583}),
108       },
109       {
110           torch::tensor(
111               {0.4377600774755075,
112                0.3828823919505896,
113                0.5308031277873992,
114                0.5752746453369446,
115                0.23943592910168343,
116                1.3739197373644627}),
117           torch::tensor(
118               {2.209263823172053, 2.154134023426646, 2.534834254325867}),
119           torch::tensor(
120               {-4.091952315741579, -4.67916063385269, -4.781279234594454}),
121           torch::tensor({-4.776742087865583}),
122       },
123       {
124           torch::tensor(
125               {0.4377600774755075,
126                0.3828823919505896,
127                0.5308031277873992,
128                0.5752746453369446,
129                0.23943592910168343,
130                1.3739197373644627}),
131           torch::tensor(
132               {2.209263823172053, 2.154134023426646, 2.534834254325867}),
133           torch::tensor(
134               {-4.091952315741579, -4.67916063385269, -4.781279234594454}),
135           torch::tensor({-4.776742087865583}),
136       },
137       {
138           torch::tensor(
139               {0.4377600774755075,
140                0.3828823919505896,
141                0.5308031277873992,
142                0.5752746453369446,
143                0.23943592910168343,
144                1.3739197373644627}),
145           torch::tensor(
146               {2.209263823172053, 2.154134023426646, 2.534834254325867}),
147           torch::tensor(
148               {-4.091952315741579, -4.67916063385269, -4.781279234594454}),
149           torch::tensor({-4.776742087865583}),
150       },
151       {
152           torch::tensor(
153               {0.4377600774755075,
154                0.3828823919505896,
155                0.5308031277873992,
156                0.5752746453369446,
157                0.23943592910168343,
158                1.3739197373644627}),
159           torch::tensor(
160               {2.209263823172053, 2.154134023426646, 2.534834254325867}),
161           torch::tensor(
162               {-4.091952315741579, -4.67916063385269, -4.781279234594454}),
163           torch::tensor({-4.776742087865583}),
164       },
165   };
166 }
167 
LBFGS_with_line_search()168 inline std::vector<std::vector<torch::Tensor>> LBFGS_with_line_search() {
169   return {
170       {
171           torch::tensor(
172               {-0.2108988568338871,
173                -0.4975560466422629,
174                -0.14129216202471762,
175                -0.3420288967865903,
176                -0.2523635082803723,
177                0.6975570255493777}),
178           torch::tensor(
179               {-0.10853121966252377, -0.297948499687533, 0.6892015099955717}),
180           torch::tensor(
181               {-0.05080313011597659,
182                -0.39413518751058996,
183                -0.28433759745928844}),
184           torch::tensor({-0.07113812430174116}),
185       },
186       {
187           torch::tensor(
188               {-0.2108988568338871,
189                -0.4975560466422629,
190                -0.14129216202471762,
191                -0.3420288967865903,
192                -0.2523635082803723,
193                0.6975570255493777}),
194           torch::tensor(
195               {-0.10853121966252377, -0.297948499687533, 0.6892015099955717}),
196           torch::tensor(
197               {-0.05080313011597659,
198                -0.39413518751058996,
199                -0.28433759745928844}),
200           torch::tensor({-0.07113812430174116}),
201       },
202       {
203           torch::tensor(
204               {-0.2108988568338871,
205                -0.4975560466422629,
206                -0.14129216202471762,
207                -0.3420288967865903,
208                -0.2523635082803723,
209                0.6975570255493777}),
210           torch::tensor(
211               {-0.10853121966252377, -0.297948499687533, 0.6892015099955717}),
212           torch::tensor(
213               {-0.05080313011597659,
214                -0.39413518751058996,
215                -0.28433759745928844}),
216           torch::tensor({-0.07113812430174116}),
217       },
218       {
219           torch::tensor(
220               {-0.2108988568338871,
221                -0.4975560466422629,
222                -0.14129216202471762,
223                -0.3420288967865903,
224                -0.2523635082803723,
225                0.6975570255493777}),
226           torch::tensor(
227               {-0.10853121966252377, -0.297948499687533, 0.6892015099955717}),
228           torch::tensor(
229               {-0.05080313011597659,
230                -0.39413518751058996,
231                -0.28433759745928844}),
232           torch::tensor({-0.07113812430174116}),
233       },
234       {
235           torch::tensor(
236               {-0.2108988568338871,
237                -0.4975560466422629,
238                -0.14129216202471762,
239                -0.3420288967865903,
240                -0.2523635082803723,
241                0.6975570255493777}),
242           torch::tensor(
243               {-0.10853121966252377, -0.297948499687533, 0.6892015099955717}),
244           torch::tensor(
245               {-0.05080313011597659,
246                -0.39413518751058996,
247                -0.28433759745928844}),
248           torch::tensor({-0.07113812430174116}),
249       },
250       {
251           torch::tensor(
252               {-0.2108988568338871,
253                -0.4975560466422629,
254                -0.14129216202471762,
255                -0.3420288967865903,
256                -0.2523635082803723,
257                0.6975570255493777}),
258           torch::tensor(
259               {-0.10853121966252377, -0.297948499687533, 0.6892015099955717}),
260           torch::tensor(
261               {-0.05080313011597659,
262                -0.39413518751058996,
263                -0.28433759745928844}),
264           torch::tensor({-0.07113812430174116}),
265       },
266       {
267           torch::tensor(
268               {-0.2108988568338871,
269                -0.4975560466422629,
270                -0.14129216202471762,
271                -0.3420288967865903,
272                -0.2523635082803723,
273                0.6975570255493777}),
274           torch::tensor(
275               {-0.10853121966252377, -0.297948499687533, 0.6892015099955717}),
276           torch::tensor(
277               {-0.05080313011597659,
278                -0.39413518751058996,
279                -0.28433759745928844}),
280           torch::tensor({-0.07113812430174116}),
281       },
282       {
283           torch::tensor(
284               {-0.2108988568338871,
285                -0.4975560466422629,
286                -0.14129216202471762,
287                -0.3420288967865903,
288                -0.2523635082803723,
289                0.6975570255493777}),
290           torch::tensor(
291               {-0.10853121966252377, -0.297948499687533, 0.6892015099955717}),
292           torch::tensor(
293               {-0.05080313011597659,
294                -0.39413518751058996,
295                -0.28433759745928844}),
296           torch::tensor({-0.07113812430174116}),
297       },
298       {
299           torch::tensor(
300               {-0.2108988568338871,
301                -0.4975560466422629,
302                -0.14129216202471762,
303                -0.3420288967865903,
304                -0.2523635082803723,
305                0.6975570255493777}),
306           torch::tensor(
307               {-0.10853121966252377, -0.297948499687533, 0.6892015099955717}),
308           torch::tensor(
309               {-0.05080313011597659,
310                -0.39413518751058996,
311                -0.28433759745928844}),
312           torch::tensor({-0.07113812430174116}),
313       },
314       {
315           torch::tensor(
316               {-0.2108988568338871,
317                -0.4975560466422629,
318                -0.14129216202471762,
319                -0.3420288967865903,
320                -0.2523635082803723,
321                0.6975570255493777}),
322           torch::tensor(
323               {-0.10853121966252377, -0.297948499687533, 0.6892015099955717}),
324           torch::tensor(
325               {-0.05080313011597659,
326                -0.39413518751058996,
327                -0.28433759745928844}),
328           torch::tensor({-0.07113812430174116}),
329       },
330       {
331           torch::tensor(
332               {-0.2108988568338871,
333                -0.4975560466422629,
334                -0.14129216202471762,
335                -0.3420288967865903,
336                -0.2523635082803723,
337                0.6975570255493777}),
338           torch::tensor(
339               {-0.10853121966252377, -0.297948499687533, 0.6892015099955717}),
340           torch::tensor(
341               {-0.05080313011597659,
342                -0.39413518751058996,
343                -0.28433759745928844}),
344           torch::tensor({-0.07113812430174116}),
345       },
346   };
347 }
348 
Adam()349 inline std::vector<std::vector<torch::Tensor>> Adam() {
350   return {
351       {
352           torch::tensor(
353               {0.7890972864438472,
354                0.5024410688121617,
355                0.8587073313055582,
356                0.6579707241208395,
357                0.7476356819075531,
358                1.697556420651692}),
359           torch::tensor(
360               {0.891467636010675, 0.7020513497567501, 1.6892012709428947}),
361           torch::tensor(
362               {-1.0508030958460797, -1.3941351509567657, -1.284337577714353}),
363           torch::tensor({-1.071138110298716}),
364       },
365       {
366           torch::tensor(
367               {8.233039313231828,
368                7.971150747377481,
369                6.6436209506776,
370                6.470977407900541,
371                6.170125488259256,
372                7.1507391033435015}),
373           torch::tensor(
374               {8.417695070103735, 6.597188212844593, 7.23175710827678}),
375           torch::tensor(
376               {-6.729624357635757, -7.09743493108154, -6.753301896575352}),
377           torch::tensor({-6.435639096011218}),
378       },
379       {
380           torch::tensor(
381               {8.233424596059296,
382                7.971537360032308,
383                6.643920150720394,
384                6.47127807553724,
385                6.170405874224489,
386                7.151021086137982}),
387           torch::tensor(
388               {8.418084791214294, 6.597493171180545, 7.232043740621598}),
389           torch::tensor(
390               {-6.729918250724671, -7.097730102046093, -6.753584809755359}),
391           torch::tensor({-6.4359165566974985}),
392       },
393       {
394           torch::tensor(
395               {8.233424610557648,
396                7.971537374586563,
397                6.643920161995285,
398                6.471278086877829,
399                6.170405884785074,
400                7.151021096766405}),
401           torch::tensor(
402               {8.418084805901902, 6.597493182713584, 7.2320437514477875}),
403           torch::tensor(
404               {-6.72991829363266, -7.097730147102975, -6.753584838821182}),
405           torch::tensor({-6.435916580217771}),
406       },
407       {
408           torch::tensor(
409               {8.233424610575101,
410                7.971537374611125,
411                6.643920162027962,
412                6.471278086923278,
413                6.170405884809245,
414                7.15102109680004}),
415           torch::tensor(
416               {8.418084805946389, 6.597493182796847, 7.232043751509309}),
417           torch::tensor(
418               {-6.729918332327653, -7.097730188349552, -6.753584861205486}),
419           torch::tensor({-6.435916596115672}),
420       },
421       {
422           torch::tensor(
423               {8.233424610594858,
424                7.971537374639166,
425                6.643920162065571,
426                6.471278086975759,
427                6.170405884836981,
428                7.1510210968387975}),
429           torch::tensor(
430               {8.418084805997614, 6.59749318289335, 7.232043751580523}),
431           torch::tensor(
432               {-6.72991837738045, -7.097730236373201, -6.753584887267492}),
433           torch::tensor({-6.43591661462546}),
434       },
435       {
436           torch::tensor(
437               {8.233424610617288,
438                7.971537374671012,
439                6.643920162108285,
440                6.471278087035362,
441                6.170405884868481,
442                7.151021096882811}),
443           torch::tensor(
444               {8.418084806055795, 6.59749318300295, 7.232043751661401}),
445           torch::tensor(
446               {-6.729918428547273, -7.09773029091405, -6.753584916866329}),
447           torch::tensor({-6.4359166356471755}),
448       },
449       {
450           torch::tensor(
451               {8.233424610642352,
452                7.9715373747065925,
453                6.6439201621560064,
454                6.471278087101955,
455                6.1704058849036745,
456                7.151021096931989}),
457           torch::tensor(
458               {8.418084806120799, 6.597493183125404, 7.232043751751764}),
459           torch::tensor(
460               {-6.729918485714688, -7.0977303518511805, -6.753584949936365}),
461           torch::tensor({-6.43591665913422}),
462       },
463       {
464           torch::tensor(
465               {8.233424610670035,
466                7.97153737474589,
467                6.6439201622087145,
468                6.471278087175502,
469                6.170405884942545,
470                7.151021096986302}),
471           torch::tensor(
472               {8.418084806192592, 6.597493183260647, 7.232043751851564}),
473           torch::tensor(
474               {-6.729918548853505, -7.097730419153473, -6.753584986460725}),
475           torch::tensor({-6.435916685074594}),
476       },
477       {
478           torch::tensor(
479               {8.233424610700348,
480                7.971537374788922,
481                6.643920162266433,
482                6.4712780872560405,
483                6.17040588498511,
484                7.151021097045779}),
485           torch::tensor(
486               {8.418084806271214, 6.597493183408747, 7.232043751960854}),
487           torch::tensor(
488               {-6.7299186179943, -7.097730492853521, -6.753585026457088}),
489           torch::tensor({-6.435916713480863}),
490       },
491       {
492           torch::tensor(
493               {8.233424610733326,
494                7.971537374835737,
495                6.643920162329225,
496                6.471278087343659,
497                6.170405885031416,
498                7.151021097110483}),
499           torch::tensor(
500               {8.418084806356743, 6.597493183569867, 7.232043752079749}),
501           torch::tensor(
502               {-6.729918693213275, -7.097730573032567, -6.753585069969552}),
503           torch::tensor({-6.43591674438434}),
504       },
505   };
506 }
507 
Adam_with_weight_decay()508 inline std::vector<std::vector<torch::Tensor>> Adam_with_weight_decay() {
509   return {
510       {
511           torch::tensor(
512               {0.7890990163499767,
513                0.5024427688479549,
514                0.858707365154099,
515                0.65797076763247,
516                0.7476358193232038,
517                1.6975559791029715}),
518           torch::tensor(
519               {0.8914677624298939, 0.7020513562204098, 1.6892012237887575}),
520           torch::tensor(
521               {-1.050803095786311, -1.3941351504224309, -1.2843375776028747}),
522           torch::tensor({-1.0711381102847533}),
523       },
524       {
525           torch::tensor(
526               {0.17835734655765323,
527                0.2542117171890537,
528                0.19681971909229715,
529                0.23522651199260597,
530                0.17806083719648957,
531                0.22943655675307303}),
532           torch::tensor(
533               {0.6227676931552837, 0.6058596954431213, 0.6077176546857177}),
534           torch::tensor(
535               {-1.4259755901844118, -1.4333355461952704, -1.408545526635006}),
536           torch::tensor({-2.0710783081666215}),
537       },
538       {
539           torch::tensor(
540               {0.17965695035191162,
541                0.24254352340441693,
542                0.17964663531482672,
543                0.24250834976541322,
544                0.17962893833698693,
545                0.24249920074277215}),
546           torch::tensor(
547               {0.6287144967638043, 0.6286955805603279, 0.6286563093833837}),
548           torch::tensor(
549               {-1.4123887230853596, -1.4124007126659273, -1.4122701589749163}),
550           torch::tensor({-2.063357041247863}),
551       },
552       {
553           torch::tensor(
554               {0.1796366651819,
555                0.24250861931831874,
556                0.17963731759793083,
557                0.24250861142436989,
558                0.1796372002681969,
559                0.24250890248031373}),
560           torch::tensor(
561               {0.6287221269294724, 0.6287225821354421, 0.6287220274975922}),
562           torch::tensor(
563               {-1.4123466103044011, -1.4123465669572683, -1.4123462614739388}),
564           torch::tensor({-2.063368365143669}),
565       },
566       {
567           torch::tensor(
568               {0.17963666103165563,
569                0.24250882317446784,
570                0.17963665831217887,
571                0.24250882481082656,
572                0.17963666029066117,
573                0.24250882426223175}),
574           torch::tensor(
575               {0.6287216329900817, 0.6287216340515608, 0.6287216326960158}),
576           torch::tensor(
577               {-1.4123467542623926, -1.4123467542350234, -1.4123467478191443}),
578           torch::tensor({-2.0633690432440437}),
579       },
580       {
581           torch::tensor(
582               {0.17963666098500394,
583                0.24250882442377164,
584                0.17963666099348902,
585                0.2425088244120223,
586                0.1796366609725109,
587                0.24250882441058697}),
588           torch::tensor(
589               {0.6287216343798432, 0.6287216343800675, 0.6287216343742645}),
590           torch::tensor(
591               {-1.4123467490742723, -1.412346749072554, -1.4123467490678536}),
592           torch::tensor({-2.0633690434425396}),
593       },
594       {
595           torch::tensor(
596               {0.17963666098407144,
597                0.24250882442250174,
598                0.17963666098407347,
599                0.2425088244224233,
600                0.17963666098409325,
601                0.2425088244225157}),
602           torch::tensor(
603               {0.6287216343836609, 0.6287216343836147, 0.6287216343836255}),
604           torch::tensor(
605               {-1.412346749067226, -1.412346749067243, -1.412346749067169}),
606           torch::tensor({-2.063369043434909}),
607       },
608       {
609           torch::tensor(
610               {0.17963666098406988,
611                0.2425088244224408,
612                0.17963666098407077,
613                0.24250882442244073,
614                0.17963666098407008,
615                0.2425088244224409}),
616           torch::tensor(
617               {0.6287216343837067, 0.6287216343837065, 0.6287216343837069}),
618           torch::tensor(
619               {-1.4123467490671706, -1.412346749067171, -1.4123467490671713}),
620           torch::tensor({-2.0633690434349057}),
621       },
622       {
623           torch::tensor(
624               {0.17963666098407038,
625                0.24250882442244104,
626                0.17963666098407027,
627                0.24250882442244104,
628                0.17963666098407025,
629                0.24250882442244098}),
630           torch::tensor(
631               {0.6287216343837067, 0.628721634383707, 0.6287216343837067}),
632           torch::tensor(
633               {-1.4123467490671706, -1.4123467490671708, -1.4123467490671706}),
634           torch::tensor({-2.0633690434349052}),
635       },
636       {
637           torch::tensor(
638               {0.1796366609840706,
639                0.24250882442244143,
640                0.17963666098407047,
641                0.24250882442244096,
642                0.17963666098407025,
643                0.24250882442244098}),
644           torch::tensor(
645               {0.6287216343837069, 0.6287216343837067, 0.6287216343837067}),
646           torch::tensor(
647               {-1.4123467490671706, -1.4123467490671706, -1.4123467490671708}),
648           torch::tensor({-2.0633690434349052}),
649       },
650       {
651           torch::tensor(
652               {0.1796366609840692,
653                0.24250882442244046,
654                0.17963666098407022,
655                0.24250882442244082,
656                0.17963666098407,
657                0.24250882442244104}),
658           torch::tensor(
659               {0.6287216343837063, 0.6287216343837068, 0.6287216343837067}),
660           torch::tensor(
661               {-1.4123467490671708, -1.4123467490671706, -1.4123467490671708}),
662           torch::tensor({-2.0633690434349052}),
663       },
664   };
665 }
666 
667 inline std::vector<std::vector<torch::Tensor>>
Adam_with_weight_decay_and_amsgrad()668 Adam_with_weight_decay_and_amsgrad() {
669   return {
670       {
671           torch::tensor(
672               {0.7890972867575196,
673                0.5024410692260988,
674                0.8587073313091852,
675                0.6579707241257546,
676                0.7476356819241026,
677                1.6975564206261673}),
678           torch::tensor(
679               {0.8914676360248869, 0.7020513497574256, 1.6892012709389561}),
680           torch::tensor(
681               {-1.050803095846074, -1.3941351509567128, -1.284337577714342}),
682           torch::tensor({-1.0711381102987145}),
683       },
684       {
685           torch::tensor(
686               {6.790598887618061,
687                6.914995398136696,
688                6.41533478566264,
689                6.297644005485053,
690                5.845162499872375,
691                6.862229173597117}),
692           torch::tensor(
693               {7.958707058914726, 6.511338624975532, 7.100969502256063}),
694           torch::tensor(
695               {-6.690689640539306, -7.056584601121166, -6.72114879738572}),
696           torch::tensor({-6.406608295022552}),
697       },
698       {
699           torch::tensor(
700               {4.707506618354547,
701                5.291519064582759,
702                6.0451502264500006,
703                6.024403702678936,
704                5.309533822430375,
705                6.388110918107735}),
706           torch::tensor(
707               {7.200495189000188, 6.398387074819269, 6.904125817198589}),
708           torch::tensor(
709               {-6.664150387053514, -7.026716194929788, -6.705821732490459}),
710           torch::tensor({-6.396310969025695}),
711       },
712       {
713           torch::tensor(
714               {2.9508632109188633,
715                3.7657755643994775,
716                5.60741774331852,
717                5.6957903028180565,
718                4.70145185833677,
719                5.835064148607041}),
720           torch::tensor(
721               {6.343524400109462, 6.258242740866945, 6.663022973860484}),
722           torch::tensor(
723               {-6.630461605133603, -6.988854886932907, -6.686194352796841}),
724           torch::tensor({-6.383010489575922}),
725       },
726       {
727           torch::tensor(
728               {1.7128944635692829,
729                2.536365345915568,
730                5.140416924817106,
731                5.33803266121343,
732                4.083921806116116,
733                5.254596369238127}),
734           torch::tensor(
735               {5.477917349690043, 6.100068192681452, 6.394747035239918}),
736           torch::tensor(
737               {-6.591742325353548, -6.945355504749947, -6.663584873152447}),
738           torch::tensor({-6.367675348676994}),
739       },
740       {
741           torch::tensor(
742               {0.9341502247285258,
743                1.6339620765410685,
744                4.6679910940755835,
745                4.967688979298023,
746                3.4933141073198866,
747                4.678195347615295}),
748           torch::tensor(
749               {4.655117321178743, 5.9294836450698245, 6.110061909503652}),
750           torch::tensor(
751               {-6.549116242899458, -6.897486328511809, -6.638629863638681}),
752           torch::tensor({-6.350731975696792}),
753       },
754       {
755           torch::tensor(
756               {0.483518223008081,
757                1.014261673266094,
758                4.205928052060015,
759                4.596195204751035,
760                2.9502123780175378,
761                4.125826973031755}),
762           torch::tensor(
763               {3.90359317709392, 5.750505227817309, 5.816617309738603}),
764           torch::tensor(
765               {-6.503371263778851, -6.846137347295816, -6.611773185243846}),
766           torch::tensor({-6.3324769650576656}),
767       },
768       {
769           torch::tensor(
770               {0.2393576446248765,
771                0.6100241101779533,
772                3.764601561942264,
773                4.231602962540335,
774                2.4647709193637635,
775                3.6096961114614476}),
776           torch::tensor(
777               {3.236721556734532, 5.566168160977344, 5.520085344708356}),
778           torch::tensor(
779               {-6.455100840665729, -6.791979259673106, -6.583347815648856}),
780           torch::tensor({-6.3131323905801136}),
781       },
782       {
783           torch::tensor(
784               {0.11401265024593016,
785                0.3570521972760832,
786                3.350517925954931,
787                3.8795333009419823,
788                2.0402068661130683,
789                3.136550602110189}),
790           torch::tensor(
791               {2.6579570162250215, 5.378834741966309, 5.224742933241745}),
792           torch::tensor(
793               {-6.4047717084756375, -6.7355397098532155, -6.55361495837462}),
794           torch::tensor({-6.2928722155069945}),
795       },
796       {
797           torch::tensor(
798               {0.05251515193791458,
799                0.20410212473600725,
800                2.9673680881961273,
801                3.543794405777883,
802                1.6752677855061209,
803                2.709287985107431}),
804           torch::tensor(
805               {2.164479166686583, 5.190372657839918, 4.9338240234040756}),
806           torch::tensor(
807               {-6.352761531270841, -6.6772456859648175, -6.52278570167088}),
808           torch::tensor({-6.271836898876738}),
809       },
810       {
811           torch::tensor(
812               {0.023489947480426834,
813                0.11428338573638941,
814                2.616797245623764,
815                3.226821571853439,
816                1.3659994589608537,
817                2.328136084816453}),
818           torch::tensor(
819               {1.74978620664146, 5.002269811977871, 4.649756802441968}),
820           torch::tensor(
821               {-6.299382007948917, -6.617449564196286, -6.491034254081261}),
822           torch::tensor({-6.250142306631259}),
823       },
824   };
825 }
826 
AdamW()827 inline std::vector<std::vector<torch::Tensor>> AdamW() {
828   return {
829       {
830           torch::tensor(
831               {0.7912062750121864,
832                0.5074166292785842,
833                0.8601202529258052,
834                0.6613910130887053,
835                0.7501593169903569,
836                1.6905808503961983}),
837           torch::tensor(
838               {0.8925529482073002, 0.7050308347536254, 1.682309255842939}),
839           torch::tensor(
840               {-1.05029506454492, -1.3901937990816595, -1.2814942017397601}),
841           torch::tensor({-1.0704267290556988}),
842       },
843       {
844           torch::tensor(
845               {3.3165329599188507,
846                3.223120441823618,
847                2.665544565239194,
848                2.6044341406663225,
849                2.479859063483047,
850                2.836831717112226}),
851           torch::tensor(
852               {3.3885192024669744, 2.6544147219174556, 2.8709245656887328}),
853           torch::tensor(
854               {-2.70172647102137, -2.836731459490802, -2.69652471546253}),
855           torch::tensor({-2.575239255076019}),
856       },
857       {
858           torch::tensor(
859               {2.231471944853865,
860                2.3549328325971755,
861                1.5699078054795328,
862                1.6160272935884685,
863                1.5339085081403547,
864                1.7397405105941612}),
865           torch::tensor(
866               {2.8552579170807926, 1.8369866847839356, 1.9735168512425862}),
867           torch::tensor(
868               {-2.6042083360293855, -2.6996673713262336, -1.8976087706977893}),
869           torch::tensor({-1.6180915942867784}),
870       },
871       {
872           torch::tensor(
873               {2.084688381515552,
874                2.3141612674892946,
875                1.4850714710140511,
876                1.5961047256668386,
877                1.440300645879787,
878                1.6065354941586025}),
879           torch::tensor(
880               {3.0111385685659444, 1.955556497153507, 1.9596562467797627}),
881           torch::tensor(
882               {-2.889337305884852, -2.965249100126337, -1.7721676671605975}),
883           torch::tensor({-1.4001341655590005}),
884       },
885       {
886           torch::tensor(
887               {2.0465343456006604,
888                2.311613891239368,
889                1.4666717526896398,
890                1.601383980913499,
891                1.4223660595993763,
892                1.5711552625612757}),
893           torch::tensor(
894               {3.07151984580744, 2.0112690538174802, 1.9592484602763875}),
895           torch::tensor(
896               {-3.0186469726426863, -3.093855445542849, -1.7367953899738784}),
897           torch::tensor({-1.3299011560804312}),
898       },
899       {
900           torch::tensor(
901               {2.039659777412556,
902                2.3178034179536273,
903                1.4654302718412722,
904                1.6094701969162322,
905                1.4230510816446773,
906                1.565168902852383}),
907           torch::tensor(
908               {3.1007583934270064, 2.039757113618415, 1.9652096140698696}),
909           torch::tensor(
910               {-3.0880626664330832, -3.166705422245348, -1.73538367534238}),
911           torch::tensor({-1.3130428735015893}),
912       },
913       {
914           torch::tensor(
915               {2.0413773043991963,
916                2.3251469369586366,
917                1.4690808101517236,
918                1.6174065798291044,
919                1.4280274009117935,
920                1.5682418226469732}),
921           torch::tensor(
922               {3.118843540209399, 2.057729936485249, 1.9742319629710936}),
923           torch::tensor(
924               {-3.1331019663177013, -3.2154332694373107, -1.7459831639793468}),
925           torch::tensor({-1.3148644134154366}),
926       },
927       {
928           torch::tensor(
929               {2.0452604138357113,
930                2.332074253989847,
931                1.4738845773449165,
932                1.6246403004735728,
933                1.4335712611625357,
934                1.573826630920094}),
935           torch::tensor(
936               {3.1324088069784093, 2.0711763619826575, 1.9841582498316732}),
937           torch::tensor(
938               {-3.16737058959847, -3.2529206463859146, -1.7602788393925501}),
939           torch::tensor({-1.32281766461531}),
940       },
941       {
942           torch::tensor(
943               {2.0495243704493262,
944                2.338413341249581,
945                1.4787599440132637,
946                1.631210274009555,
947                1.438849155552895,
948                1.5798736919537595}),
949           torch::tensor(
950               {3.1438209015414227, 2.0823943437659658, 1.9940075805973108}),
951           torch::tensor(
952               {-3.19628690363529, -3.2845941643030367, -1.7752333900055153}),
953           torch::tensor({-1.332456718314933}),
954       },
955       {
956           torch::tensor(
957               {2.0536979895206295,
958                2.3442520601250334,
959                1.4834272584224222,
960                1.6372462654983486,
961                1.4437517398490174,
962                1.585780877834892}),
963           torch::tensor(
964               {3.1540081461072447, 2.0923381262560454, 2.0034284957296107}),
965           torch::tensor(
966               {-3.222142968201519, -3.312867602521477, -1.7898220261118043}),
967           torch::tensor({-1.3422692037690986}),
968       },
969       {
970           torch::tensor(
971               {2.0576784836825315,
972                2.3496934395759377,
973                1.4878413407927933,
974                1.6428612479757005,
975                1.4483225979568104,
976                1.5914034339763325}),
977           torch::tensor(
978               {3.163383747232199, 2.101446878895216, 2.012344413569353}),
979           torch::tensor(
980               {-3.246000281299229, -3.338904166978488, -1.8037666936489785}),
981           torch::tensor({-1.3517884775416527}),
982       },
983   };
984 }
985 
AdamW_without_weight_decay()986 inline std::vector<std::vector<torch::Tensor>> AdamW_without_weight_decay() {
987   return {
988       {
989           torch::tensor(
990               {0.7890972864438476,
991                0.5024410688121617,
992                0.858707331305558,
993                0.6579707241208395,
994                0.7476356819075531,
995                1.6975564206516922}),
996           torch::tensor(
997               {0.891467636010675, 0.70205134975675, 1.689201270942895}),
998           torch::tensor(
999               {-1.0508030958460797, -1.3941351509567654, -1.284337577714353}),
1000           torch::tensor({-1.071138110298716}),
1001       },
1002       {
1003           torch::tensor(
1004               {8.233039313231831,
1005                7.971150747377481,
1006                6.643620950677599,
1007                6.47097740790054,
1008                6.170125488259256,
1009                7.150739103343502}),
1010           torch::tensor(
1011               {8.417695070103738, 6.597188212844593, 7.23175710827678}),
1012           torch::tensor(
1013               {-6.729624357635757, -7.09743493108154, -6.753301896575352}),
1014           torch::tensor({-6.435639096011218}),
1015       },
1016       {
1017           torch::tensor(
1018               {8.233424596059299,
1019                7.971537360032308,
1020                6.643920150720393,
1021                6.471278075537239,
1022                6.170405874224489,
1023                7.151021086137983}),
1024           torch::tensor(
1025               {8.418084791214298, 6.597493171180545, 7.232043740621598}),
1026           torch::tensor(
1027               {-6.729918250724671, -7.097730102046093, -6.753584809755359}),
1028           torch::tensor({-6.4359165566974985}),
1029       },
1030       {
1031           torch::tensor(
1032               {8.233424610557652,
1033                7.971537374586563,
1034                6.643920161995284,
1035                6.471278086877828,
1036                6.170405884785074,
1037                7.151021096766406}),
1038           torch::tensor(
1039               {8.418084805901906, 6.597493182713584, 7.2320437514477875}),
1040           torch::tensor(
1041               {-6.72991829363266, -7.097730147102975, -6.753584838821182}),
1042           torch::tensor({-6.435916580217771}),
1043       },
1044       {
1045           torch::tensor(
1046               {8.233424610575105,
1047                7.971537374611125,
1048                6.643920162027961,
1049                6.471278086923277,
1050                6.170405884809245,
1051                7.151021096800041}),
1052           torch::tensor(
1053               {8.418084805946393, 6.597493182796847, 7.232043751509309}),
1054           torch::tensor(
1055               {-6.729918332327653, -7.097730188349552, -6.753584861205486}),
1056           torch::tensor({-6.435916596115672}),
1057       },
1058       {
1059           torch::tensor(
1060               {8.233424610594861,
1061                7.971537374639166,
1062                6.64392016206557,
1063                6.471278086975758,
1064                6.170405884836981,
1065                7.151021096838798}),
1066           torch::tensor(
1067               {8.418084805997617, 6.59749318289335, 7.232043751580523}),
1068           torch::tensor(
1069               {-6.72991837738045, -7.097730236373201, -6.753584887267492}),
1070           torch::tensor({-6.43591661462546}),
1071       },
1072       {
1073           torch::tensor(
1074               {8.233424610617291,
1075                7.971537374671012,
1076                6.643920162108284,
1077                6.471278087035361,
1078                6.170405884868481,
1079                7.151021096882812}),
1080           torch::tensor(
1081               {8.418084806055798, 6.59749318300295, 7.232043751661401}),
1082           torch::tensor(
1083               {-6.729918428547273, -7.09773029091405, -6.753584916866329}),
1084           torch::tensor({-6.4359166356471755}),
1085       },
1086       {
1087           torch::tensor(
1088               {8.233424610642356,
1089                7.9715373747065925,
1090                6.643920162156006,
1091                6.471278087101954,
1092                6.1704058849036745,
1093                7.15102109693199}),
1094           torch::tensor(
1095               {8.418084806120802, 6.597493183125404, 7.232043751751764}),
1096           torch::tensor(
1097               {-6.729918485714688, -7.0977303518511805, -6.753584949936365}),
1098           torch::tensor({-6.43591665913422}),
1099       },
1100       {
1101           torch::tensor(
1102               {8.233424610670038,
1103                7.97153737474589,
1104                6.643920162208714,
1105                6.471278087175501,
1106                6.170405884942545,
1107                7.151021096986303}),
1108           torch::tensor(
1109               {8.418084806192596, 6.597493183260647, 7.232043751851564}),
1110           torch::tensor(
1111               {-6.729918548853505, -7.097730419153473, -6.753584986460725}),
1112           torch::tensor({-6.435916685074594}),
1113       },
1114       {
1115           torch::tensor(
1116               {8.233424610700352,
1117                7.971537374788922,
1118                6.643920162266432,
1119                6.47127808725604,
1120                6.17040588498511,
1121                7.1510210970457795}),
1122           torch::tensor(
1123               {8.418084806271217, 6.597493183408747, 7.232043751960854}),
1124           torch::tensor(
1125               {-6.7299186179943, -7.097730492853521, -6.753585026457088}),
1126           torch::tensor({-6.435916713480863}),
1127       },
1128       {
1129           torch::tensor(
1130               {8.23342461073333,
1131                7.971537374835737,
1132                6.643920162329224,
1133                6.471278087343658,
1134                6.170405885031416,
1135                7.151021097110484}),
1136           torch::tensor(
1137               {8.418084806356747, 6.597493183569867, 7.232043752079749}),
1138           torch::tensor(
1139               {-6.729918693213275, -7.097730573032567, -6.753585069969552}),
1140           torch::tensor({-6.43591674438434}),
1141       },
1142   };
1143 }
1144 
AdamW_with_amsgrad()1145 inline std::vector<std::vector<torch::Tensor>> AdamW_with_amsgrad() {
1146   return {
1147       {
1148           torch::tensor(
1149               {0.7912062750121864,
1150                0.5074166292785842,
1151                0.8601202529258052,
1152                0.6613910130887053,
1153                0.7501593169903569,
1154                1.6905808503961983}),
1155           torch::tensor(
1156               {0.8925529482073002, 0.7050308347536254, 1.682309255842939}),
1157           torch::tensor(
1158               {-1.05029506454492, -1.3901937990816595, -1.2814942017397601}),
1159           torch::tensor({-1.0704267290556988}),
1160       },
1161       {
1162           torch::tensor(
1163               {3.3017259270507915,
1164                3.2082991753694565,
1165                2.653930978510442,
1166                2.5927674339810585,
1167                2.4689608790182933,
1168                2.825873703467739}),
1169           torch::tensor(
1170               {3.373698198112671, 2.6425942964586664, 2.8597930424244304}),
1171           torch::tensor(
1172               {-2.690360632302962, -2.8253191596069525, -2.6855499873057473}),
1173           torch::tensor({-2.5644658591929406}),
1174       },
1175       {
1176           torch::tensor(
1177               {2.222607725541013,
1178                2.3447188854637004,
1179                1.5614270655258826,
1180                1.606610018462357,
1181                1.5260497191448619,
1182                1.7309643622674138}),
1183           torch::tensor(
1184               {2.84137462783552, 1.824806600633721, 1.9620493659996037}),
1185           torch::tensor(
1186               {-2.576642773625787, -2.6706153846815766, -1.8799876863754623}),
1187           torch::tensor({-1.6044722984810953}),
1188       },
1189       {
1190           torch::tensor(
1191               {2.0739558768648205,
1192                2.3008338863863496,
1193                1.4738888208638767,
1194                1.5829485271829449,
1195                1.4296176764284294,
1196                1.5939984909850073}),
1197           torch::tensor(
1198               {2.9908013612792415, 1.936590940953305, 1.941691630199464}),
1199           torch::tensor(
1200               {-2.846562884997548, -2.9195962101501203, -1.746484716887341}),
1201           torch::tensor({-1.381525131003179}),
1202       },
1203       {
1204           torch::tensor(
1205               {2.0333926094256953,
1206                2.294977109171754,
1207                1.452870514716895,
1208                1.584853677999522,
1209                1.4086299433181402,
1210                1.5548201727855224}),
1211           torch::tensor(
1212               {3.0454817801193976, 1.9867169062383696, 1.935312753106444}),
1213           torch::tensor(
1214               {-2.9612116762746394, -3.0322275992001084, -1.7026114905180725}),
1215           torch::tensor({-1.30563541393247}),
1216       },
1217       {
1218           torch::tensor(
1219               {2.02392417168201,
1220                2.2977279587859,
1221                1.4488511120131309,
1222                1.5894646930743725,
1223                1.406253686759073,
1224                1.5450647949022756}),
1225           torch::tensor(
1226               {3.069025602955343, 2.0096872138967488, 1.935438546309299}),
1227           torch::tensor(
1228               {-3.016103148166836, -3.0893062953033583, -1.6925290685615872}),
1229           torch::tensor({-1.282870120405012}),
1230       },
1231       {
1232           torch::tensor(
1233               {2.0230257817348316,
1234                2.3016167065040647,
1235                1.4496901978629444,
1236                1.5939034289777392,
1237                1.4082421794430946,
1238                1.5444538756003756}),
1239           torch::tensor(
1240               {3.0814016132787954, 2.022150201844143, 1.9387429991308658}),
1241           torch::tensor(
1242               {-3.0466485946438406, -3.1223144611322446, -1.6944083009127773}),
1243           torch::tensor({-1.2786980736911064}),
1244       },
1245       {
1246           torch::tensor(
1247               {2.024305922065404,
1248                2.305101549510461,
1249                1.4516863420588493,
1250                1.5976447954882376,
1251                1.4108596097183552,
1252                1.5464236284303425}),
1253           torch::tensor(
1254               {3.0892574233545065, 2.0300944858242236, 1.943040321845021}),
1255           torch::tensor(
1256               {-3.0664189567897306, -3.1440820888166425, -1.6999750448893618}),
1257           torch::tensor({-1.2806281811826203}),
1258       },
1259       {
1260           torch::tensor(
1261               {2.0259854116237364,
1262                2.3080169152218293,
1263                1.4537680296813915,
1264                1.6007369432392426,
1265                1.4132529823064277,
1266                1.5489035046525346}),
1267           torch::tensor(
1268               {3.0949717180851337, 2.0358251379915764, 1.9473249654893}),
1269           torch::tensor(
1270               {-3.0808231377905426, -3.160021699873689, -1.7062031001273494}),
1271           torch::tensor({-1.28423401369705}),
1272       },
1273       {
1274           torch::tensor(
1275               {2.0275923948348638,
1276                2.3104512601389637,
1277                1.455657715721078,
1278                1.6033123357613526,
1279                1.4153003204463288,
1280                1.5512775896116622}),
1281           torch::tensor(
1282               {3.099479021299846, 2.0403012223048775, 1.9512285931847464}),
1283           torch::tensor(
1284               {-3.092151979336299, -3.1725453680885267, -1.7120689614428697}),
1285           torch::tensor({-1.2880095517062655}),
1286       },
1287       {
1288           torch::tensor(
1289               {2.029022468328371,
1290                2.3125066985045892,
1291                1.4573100228823295,
1292                1.605484259933419,
1293                1.417038245960655,
1294                1.5533932056240227}),
1295           torch::tensor(
1296               {3.103195011518616, 2.043964003458376, 1.9546640840748621}),
1297           torch::tensor(
1298               {-3.1014680843131184, -3.1828179298513968, -1.7172933346797972}),
1299           torch::tensor({-1.2914899987134136}),
1300       },
1301   };
1302 }
1303 
Adagrad()1304 inline std::vector<std::vector<torch::Tensor>> Adagrad() {
1305   return {
1306       {
1307           torch::tensor(
1308               {0.7891011045987429,
1309                0.502443924512199,
1310                0.8587078329085825,
1311                0.6579710994224826,
1312                0.7476364836215006,
1313                1.697557019500397}),
1314           torch::tensor(
1315               {0.8914687688941954, 0.7020514988069096, 1.6892015076050444}),
1316           torch::tensor(
1317               {-1.0508031297732776, -1.3941351871450518, -1.284337597261839}),
1318           torch::tensor({-1.071138124161711}),
1319       },
1320       {
1321           torch::tensor(
1322               {2.4079229696892583,
1323                2.2346803754764286,
1324                1.6967885588547365,
1325                1.552279695827649,
1326                1.2259044248443602,
1327                2.221279696180243}),
1328           torch::tensor(
1329               {2.9334079162217193, 1.7619824934767887, 2.3464577179091473}),
1330           torch::tensor(
1331               {-2.221396083069719, -2.549950976011168, -1.9709315957317095}),
1332           torch::tensor({-1.5858816837541876}),
1333       },
1334       {
1335           torch::tensor(
1336               {2.510404433941812,
1337                2.3522584510262887,
1338                1.7921695110761213,
1339                1.657755825836846,
1340                1.2891186618593045,
1341                2.291878516133922}),
1342           torch::tensor(
1343               {3.092171180776419, 1.8971624370952997, 2.438734251283465}),
1344           torch::tensor(
1345               {-2.437641633486504, -2.7704264590526573, -2.0949471699460225}),
1346           torch::tensor({-1.6769121890401757}),
1347       },
1348       {
1349           torch::tensor(
1350               {2.5652648968109415,
1351                2.4155313947260972,
1352                1.844241233613541,
1353                1.7156513351246399,
1354                1.3245206506797171,
1355                2.3315409972138825}),
1356           torch::tensor(
1357               {3.178399916514377, 1.9721945764936502, 2.4909037706250428}),
1358           torch::tensor(
1359               {-2.5658710403147933, -2.901921821645266, -2.168560672193225}),
1360           torch::tensor({-1.7307903926154131}),
1361       },
1362       {
1363           torch::tensor(
1364               {2.6021584494332592,
1365                2.4582101324909065,
1366                1.8796060082750778,
1367                1.7550965207414717,
1368                1.3489253597999988,
1369                2.3589345190118247}),
1370           torch::tensor(
1371               {3.2368674310041516, 2.0236468833666894, 2.52707132741292}),
1372           torch::tensor(
1373               {-2.6573969292994164, -2.9960731060650505, -2.2211375717304076}),
1374           torch::tensor({-1.7692090167089707}),
1375       },
1376       {
1377           torch::tensor(
1378               {2.629700772579208,
1379                2.4901377017698683,
1380                1.906173477530586,
1381                1.7847957161833832,
1382                1.3674517119505822,
1383                2.3797578857769905}),
1384           torch::tensor(
1385               {3.2807643102638546, 2.062561811940094, 2.5546379424362775}),
1386           torch::tensor(
1387               {-2.7286379977755035, -3.0695109399636236, -2.262081199960513}),
1388           torch::tensor({-1.7990936323432214}),
1389       },
1390       {
1391           torch::tensor(
1392               {2.6515471766995247,
1393                2.51550257362603,
1394                1.927341363452414,
1395                1.8084994719811576,
1396                1.3823309942932445,
1397                2.3964995243914373}),
1398           torch::tensor(
1399               {3.3157334001309473, 2.093728023484945, 2.5768468697402924}),
1400           torch::tensor(
1401               {-2.786981763434855, -3.129746439571402, -2.29562487034177}),
1402           torch::tensor({-1.8235564908139104}),
1403       },
1404       {
1405           torch::tensor(
1406               {2.6695780544837886,
1407                2.53646401614724,
1408                1.9448721033433505,
1409                1.828157582353901,
1410                1.3947329882074622,
1411                2.4104657178934947}),
1412           torch::tensor(
1413               {3.344694775590452, 2.1196465761628516, 2.5954050923596252}),
1414           torch::tensor(
1415               {-2.8363936812536537, -3.1808219609745194, -2.32404190866147}),
1416           torch::tensor({-1.8442667636913117}),
1417       },
1418       {
1419           torch::tensor(
1420               {2.684883801533072,
1421                2.5542762192735515,
1422                1.9597939532350015,
1423                1.844909608012419,
1424                1.4053459079217485,
1425                2.4224257790968386}),
1426           torch::tensor(
1427               {3.369349515259956, 2.1417845308976795, 2.611319989214332}),
1428           torch::tensor(
1429               {-2.879251075341889, -3.225165734647855, -2.3486956737228057}),
1430           torch::tensor({-1.86222449978646}),
1431       },
1432       {
1433           torch::tensor(
1434               {2.698151012423769,
1435                2.56972998600169,
1436                1.972757472697587,
1437                1.8594775691681182,
1438                1.4146081751022495,
1439                2.43287021079559}),
1440           torch::tensor(
1441               {3.390772758897601, 2.1610741754331757, 2.6252349489549824}),
1442           torch::tensor(
1443               {-2.917092322961074, -3.264351563375218, -2.370468664387175}),
1444           torch::tensor({-1.8780765115117757}),
1445       },
1446       {
1447           torch::tensor(
1448               {2.7098389356033783,
1449                2.5833548721723747,
1450                1.9841994925173085,
1451                1.8723468731726323,
1452                1.4228158926355312,
1453                2.4421305315945085}),
1454           torch::tensor(
1455               {3.4096859099156673, 2.178143852041279, 2.6375854547611364}),
1456           torch::tensor(
1457               {-2.9509704554208467, -3.2994581338995044, -2.3899651139415874}),
1458           torch::tensor({-1.8922653655195538}),
1459       },
1460   };
1461 }
1462 
Adagrad_with_weight_decay()1463 inline std::vector<std::vector<torch::Tensor>> Adagrad_with_weight_decay() {
1464   return {
1465       {
1466           torch::tensor(
1467               {0.7891011218979068,
1468                0.5024439415126254,
1469                0.8587078332470682,
1470                0.6579710998575992,
1471                0.7476364849956589,
1472                1.6975570150849029}),
1473           torch::tensor(
1474               {0.8914687701583902, 0.7020514988715463, 1.6892015071335027}),
1475           torch::tensor(
1476               {-1.0508031297726799, -1.3941351871397083, -1.2843375972607243}),
1477           torch::tensor({-1.0711381241615712}),
1478       },
1479       {
1480           torch::tensor(
1481               {0.1846116678522213,
1482                0.24944077103107917,
1483                0.18651745437755768,
1484                0.25219093533041764,
1485                0.18712037968446713,
1486                0.25289206444055234}),
1487           torch::tensor(
1488               {0.6482869597891656, 0.6580215784646755, 0.6581256007663537}),
1489           torch::tensor(
1490               {-1.454709711443681, -1.4748063405174818, -1.4811625946604765}),
1491           torch::tensor({-1.905292836544363}),
1492       },
1493       {
1494           torch::tensor(
1495               {0.18059895999281475,
1496                0.2438515539257779,
1497                0.18067177884778182,
1498                0.24397186395008694,
1499                0.18168388351830797,
1500                0.24533853846052017}),
1501           torch::tensor(
1502               {0.6325250261983028, 0.6331827793513023, 0.6366659383355598}),
1503           torch::tensor(
1504               {-1.420803333750877, -1.4215627240541653, -1.4320264544533396}),
1505           torch::tensor({-2.030135641848322}),
1506       },
1507       {
1508           torch::tensor(
1509               {0.17981392697398363,
1510                0.2427571544305695,
1511                0.17981150414451733,
1512                0.24275725992310523,
1513                0.18014798619115763,
1514                0.2432144956227816}),
1515           torch::tensor(
1516               {0.6294321320817985, 0.6294873737410742, 0.6306958589251878}),
1517           torch::tensor(
1518               {-1.4139253354785764, -1.413902680470981, -1.4173628530293867}),
1519           torch::tensor({-2.056210117690093}),
1520       },
1521       {
1522           torch::tensor(
1523               {0.17967006242163747,
1524                0.24255582734557277,
1525                0.17966873677301953,
1526                0.24255462870545766,
1527                0.1797588230898851,
1528                0.24267729072765756}),
1529           torch::tensor(
1530               {0.6288576295241085, 0.6288643132826753, 0.6291921485342001}),
1531           torch::tensor(
1532               {-1.4126465879787569, -1.4126335126907266, -1.4135586793353685}),
1533           torch::tensor({-2.0618018405404825}),
1534       },
1535       {
1536           torch::tensor(
1537               {0.17964321284685653,
1538                0.24251808241139364,
1539                0.17964291377171066,
1540                0.24251779104198598,
1541                0.1796651574178102,
1542                0.2425481059085456}),
1543           torch::tensor(
1544               {0.628748693136779, 0.6287498167193976, 0.6288312441271762}),
1545           torch::tensor(
1546               {-1.412405895385289, -1.4124029484481164, -1.4126313051315378}),
1547           torch::tensor({-2.0630223163099304}),
1548       },
1549       {
1550           torch::tensor(
1551               {0.1796379973927849,
1552                0.2425107191246215,
1553                0.1796379363134342,
1554                0.2425106592536174,
1555                0.17964321802205502,
1556                0.24251786094585864}),
1557           torch::tensor(
1558               {0.6287272170354161, 0.6287274414587727, 0.6287468362309863}),
1559           torch::tensor(
1560               {-1.4123588626342634, -1.412358263650784, -1.412412480569672}),
1561           torch::tensor({-2.0632918101480255}),
1562       },
1563       {
1564           torch::tensor(
1565               {0.17963694231402444,
1566                0.24250922426893615,
1567                0.1796369298071621,
1568                0.2425092121007451,
1569                0.17963815759528073,
1570                0.24251088666939838}),
1571           torch::tensor(
1572               {0.6287228195255881, 0.6287228675172439, 0.628727383936762}),
1573           torch::tensor(
1574               {-1.4123493065102781, -1.412349184462438, -1.4123617872243597}),
1575           torch::tensor({-2.063351765096138}),
1576       },
1577       {
1578           torch::tensor(
1579               {0.1796367215904632,
1580                0.24250891070978667,
1581                0.17963671897003045,
1582                0.24250890818318074,
1583                0.17963700091084855,
1584                0.24250929278196054}),
1585           torch::tensor(
1586               {0.6287218911936107, 0.6287219017313679, 0.6287229399204574}),
1587           torch::tensor(
1588               {-1.4123473011084142, -1.412347275640343, -1.41235016959507}),
1589           torch::tensor({-2.0633651674043505}),
1590       },
1591       {
1592           torch::tensor(
1593               {0.17963667424978783,
1594                0.24250884333120687,
1595                0.17963667368829764,
1596                0.24250884279379462,
1597                0.17963673796131557,
1598                0.24250893047794023}),
1599           torch::tensor(
1600               {0.6287216908150736, 0.6287216931558691, 0.6287219299749583}),
1601           torch::tensor(
1602               {-1.4123468700596724, -1.4123468646187736, -1.4123475243360133}),
1603           torch::tensor({-2.0633681724342527}),
1604       },
1605       {
1606           torch::tensor(
1607               {0.1796366639185348,
1608                0.24250882860835257,
1609                0.17963666379614568,
1610                0.24250882849182892,
1611                0.17963667838367053,
1612                0.2425088483939741}),
1613           torch::tensor(
1614               {0.6287216468984888, 0.6287216474215305, 0.6287217011907862}),
1615           torch::tensor(
1616               {-1.4123467758545658, -1.412346774671038, -1.4123469244007658}),
1617           torch::tensor({-2.0633688474977467}),
1618       },
1619   };
1620 }
1621 
1622 inline std::vector<std::vector<torch::Tensor>>
Adagrad_with_weight_decay_and_lr_decay()1623 Adagrad_with_weight_decay_and_lr_decay() {
1624   return {
1625       {
1626           torch::tensor(
1627               {0.7891011046018798,
1628                0.5024439245163383,
1629                0.8587078329086189,
1630                0.6579710994225316,
1631                0.747636483621666,
1632                1.697557019500142}),
1633           torch::tensor(
1634               {0.8914687688943375, 0.7020514988069164, 1.6892015076050049}),
1635           torch::tensor(
1636               {-1.0508031297732776, -1.3941351871450511, -1.284337597261839}),
1637           torch::tensor({-1.0711381241617108}),
1638       },
1639       {
1640           torch::tensor(
1641               {2.346218944110103,
1642                2.191939439502003,
1643                1.683355201740813,
1644                1.5405520021635604,
1645                1.2137800230828062,
1646                2.205283463717303}),
1647           torch::tensor(
1648               {2.9090564593404, 1.7509657336815554, 2.336166413186925}),
1649           torch::tensor(
1650               {-2.206159683368316, -2.5344318233445415, -1.9622783535807609}),
1651           torch::tensor({-1.5796101463783623}),
1652       },
1653       {
1654           torch::tensor(
1655               {2.3889328781057233,
1656                2.2678221038007296,
1657                1.7667624725138267,
1658                1.6358015176639822,
1659                1.2655767687152566,
1660                2.261088056711282}),
1661           torch::tensor(
1662               {3.045569451994985, 1.8770196253823253, 2.4192707519566765}),
1663           torch::tensor(
1664               {-2.4079300017528613, -2.7399112002234305, -2.0780613510632375}),
1665           torch::tensor({-1.664722108226537}),
1666       },
1667       {
1668           torch::tensor(
1669               {2.3886137557806384,
1670                2.2922158071009178,
1671                1.8078384116424007,
1672                1.684352474440932,
1673                1.290353948335789,
1674                2.2870715509706496}),
1675           torch::tensor(
1676               {3.111110355394278, 1.9438501730282314, 2.4630249355872826}),
1677           torch::tensor(
1678               {-2.5226122034499263, -2.857315093916292, -2.143964860243905}),
1679           torch::tensor({-1.7130685809905042}),
1680       },
1681       {
1682           torch::tensor(
1683               {2.374703352203156,
1684                2.298804499257456,
1685                1.8330249458212446,
1686                1.7151661013307244,
1687                1.3048586226945842,
1688                2.3017650590464274}),
1689           torch::tensor(
1690               {3.150318222034133, 1.9877926185369321, 2.491399976401679}),
1691           torch::tensor(
1692               {-2.601415913361488, -2.938203895113964, -2.1892988334550028}),
1693           torch::tensor({-1.7462964261966805}),
1694       },
1695       {
1696           torch::tensor(
1697               {2.3553658567303812,
1698                2.297191758042688,
1699                1.8501154749072124,
1700                1.736836058688188,
1701                1.3141313000193942,
1702                2.3107452592153854}),
1703           torch::tensor(
1704               {3.1762315339155434, 2.0197585204578647, 2.5117041377790197}),
1705           torch::tensor(
1706               {-2.6606644002288697, -2.9991216074293856, -2.223413376189609}),
1707           torch::tensor({-1.7712905233118807}),
1708       },
1709       {
1710           torch::tensor(
1711               {2.3338052201696207,
1712                2.2913023710914993,
1713                1.8624163948044772,
1714                1.7530300731725454,
1715                1.3203313209234842,
1716                2.3163969478854747}),
1717           torch::tensor(
1718               {3.1943525925688934, 2.044447386769377, 2.527109724607397}),
1719           torch::tensor(
1720               {-2.7076634717294894, -3.047500808469036, -2.250495807208967}),
1721           torch::tensor({-1.7911288238757486}),
1722       },
1723       {
1724           torch::tensor(
1725               {2.3114979154644892,
1726                2.2830501835377808,
1727                1.871616142999356,
1728                1.765632597660841,
1729                1.324565631636651,
1730                2.319939234205203}),
1731           torch::tensor(
1732               {3.2074779925809085, 2.0642940833670544, 2.5392671301471235}),
1733           torch::tensor(
1734               {-2.7463093287485925, -3.087315541134716, -2.272780318857348}),
1735           torch::tensor({-1.8074516661537263}),
1736       },
1737       {
1738           torch::tensor(
1739               {2.2891841627387346,
1740                2.2734716995793693,
1741                1.8786818999895825,
1742                1.7757301317117602,
1743                1.3274682997719436,
1744                2.3220679353993825}),
1745           torch::tensor(
1746               {3.2172019619454075, 2.0807140893178175, 2.5491374815141876}),
1747           torch::tensor(
1748               {-2.7789204504423823, -3.1209351402429175, -2.2915969523376867}),
1749           torch::tensor({-1.821234722948421}),
1750       },
1751       {
1752           torch::tensor(
1753               {2.2672498238343066,
1754                2.2631678037928893,
1755                1.8842131287032622,
1756                1.7840007705383882,
1757                1.3294311820750493,
1758                2.323211243034543}),
1759           torch::tensor(
1760               {3.224507488068445, 2.094598223519413, 2.5573257155791715}),
1761           torch::tensor(
1762               {-2.8069849199086647, -3.1498826045022925, -2.3077996970997727}),
1763           torch::tensor({-1.8331040438272388}),
1764       },
1765       {
1766           torch::tensor(
1767               {2.2458961718688957,
1768                2.2525031725114775,
1769                1.8886034384961112,
1770                1.7908930341267952,
1771                1.3307102435291205,
1772                2.323647497600462}),
1773           torch::tensor(
1774               {3.230036385413078, 2.1065407459636134, 2.5642349249609664}),
1775           torch::tensor(
1776               {-2.83151249424399, -3.1751926295316566, -2.3219682378974036}),
1777           torch::tensor({-1.8434843744626483}),
1778       },
1779   };
1780 }
1781 
RMSprop()1782 inline std::vector<std::vector<torch::Tensor>> RMSprop() {
1783   return {
1784       {
1785           torch::tensor(
1786               {0.7890625772821005,
1787                0.502415108650816,
1788                0.8587027713011453,
1789                0.657967312300643,
1790                0.7476283936579036,
1791                1.6975509766054537}),
1792           torch::tensor(
1793               {0.8914573371873159, 0.7020499947573374, 1.6891991194739453}),
1794           torch::tensor(
1795               {-1.0508027874171133, -1.3941348219724659, -1.2843374000099703}),
1796           torch::tensor({-1.0711379842715099}),
1797       },
1798       {
1799           torch::tensor(
1800               {2.448571858277443,
1801                2.2809152044417678,
1802                1.7346424449151965,
1803                1.5940004770230667,
1804                1.250761131839982,
1805                2.248993270255382}),
1806           torch::tensor(
1807               {2.994661478530102, 1.8150485290864256, 2.382542610897819}),
1808           torch::tensor(
1809               {-2.3036981738757825, -2.6337299521275646, -2.018370122358821}),
1810           torch::tensor({-1.620787559800898}),
1811       },
1812       {
1813           torch::tensor(
1814               {2.5837582475607785,
1815                2.4365737242301537,
1816                1.8622886519354538,
1817                1.7357065282848232,
1818                1.3369695670141974,
1819                2.3454934716983695}),
1820           torch::tensor(
1821               {3.2061266499381618, 1.9981112525417788, 2.5092495986614}),
1822           torch::tensor(
1823               {-2.6110809365525958, -2.9484807193016787, -2.194898560798439}),
1824           torch::tensor({-1.7501043480625826}),
1825       },
1826       {
1827           torch::tensor(
1828               {2.669969051134511,
1829                2.536559412710799,
1830                1.9456091681389671,
1831                1.828914948091767,
1832                1.3952956766999587,
1833                2.4110816686341923}),
1834           torch::tensor(
1835               {3.343672975593657, 2.1204057198913002, 2.5961524902119497}),
1836           torch::tensor(
1837               {-2.8372329851331006, -3.1817729538857207, -2.3249971853996954}),
1838           torch::tensor({-1.8450422173907486}),
1839       },
1840       {
1841           torch::tensor(
1842               {2.7375365004059122,
1843                2.6153071545358633,
1844                2.0117493624534313,
1845                1.9033001982031035,
1846                1.4427501882445097,
1847                2.4646213743186127}),
1848           torch::tensor(
1849               {3.452912454199796, 2.2190451524127535, 2.667552790123282}),
1850           torch::tensor(
1851               {-3.0329479456731505, -3.384582488936652, -2.4377299824997136}),
1852           torch::tensor({-1.9271014784118226}),
1853       },
1854       {
1855           torch::tensor(
1856               {2.7952372917068753,
1857                2.682820220375722,
1858                2.0687223272686994,
1859                1.967654548778711,
1860                1.4844410726622166,
1861                2.5117888904510117}),
1862           torch::tensor(
1863               {3.5471904628565745, 2.305113548262141, 2.7307948248967304}),
1864           torch::tensor(
1865               {-3.2141190290332537, -3.572944633614449, -2.5421970206546827}),
1866           torch::tensor({-2.0029976985219666}),
1867       },
1868       {
1869           torch::tensor(
1870               {2.8467333937519483,
1871                2.7432785177110395,
1872                2.119898810135385,
1873                2.0256805416741255,
1874                1.5225256464280221,
1875                2.554983108087885}),
1876           torch::tensor(
1877               {3.632098323876194, 2.383289304179778, 2.7889864719999222}),
1878           torch::tensor(
1879               {-3.387579944167926, -3.7537658010839294, -2.6423123266260427}),
1880           torch::tensor({-2.075616951445725}),
1881       },
1882       {
1883           torch::tensor(
1884               {2.8938600498060203,
1885                2.798776984183615,
1886                2.16697381475156,
1887                2.079238538430203,
1888                1.5580820887115123,
1889                2.5954023023969692}),
1890           torch::tensor(
1891               {3.71043881435304, 2.455919099321953, 2.8436784441941008}),
1892           torch::tensor(
1893               {-3.5567368287146417, -3.930484868709691, -2.740026479434574}),
1894           torch::tensor({-2.146398256871758}),
1895       },
1896       {
1897           torch::tensor(
1898               {2.937649394399939,
1899                2.850492965312311,
1900                2.210902777232446,
1901                2.1293746183147633,
1902                1.5917084661873866,
1903                2.6337100421079533}),
1904           torch::tensor(
1905               {3.7837853328443516, 2.5243155701130604, 2.8957265009949373}),
1906           torch::tensor(
1907               {-3.7234268485210475, -4.104949193318518, -2.836390693799751}),
1908           torch::tensor({-2.216118773360611}),
1909       },
1910       {
1911           torch::tensor(
1912               {2.9787316798887558,
1913                2.8991451078473207,
1914                2.252272487010975,
1915                2.1767288729202767,
1916                1.6237627746697592,
1917                2.670302507579268}),
1918           torch::tensor(
1919               {3.8530980655045086, 2.589275531553025, 2.9456388178450936}),
1920           torch::tensor(
1921               {-3.8886856459619636, -4.278191888396593, -2.9319964928350313}),
1922           torch::tensor({-2.285217699505124}),
1923       },
1924       {
1925           torch::tensor(
1926               {3.017515620579049,
1927                2.9452004251453268,
1928                2.291469925522962,
1929                2.2217217025782916,
1930                1.6544730927621272,
1931                2.705431441204422}),
1932           torch::tensor(
1933               {3.9190041004420166, 2.651317624465938, 2.993736489599992}),
1934           torch::tensor(
1935               {-4.053111559341913, -4.450801801162238, -3.0271845519131957}),
1936           torch::tensor({-2.3539498905973444}),
1937       },
1938   };
1939 }
1940 
RMSprop_with_weight_decay()1941 inline std::vector<std::vector<torch::Tensor>> RMSprop_with_weight_decay() {
1942   return {
1943       {
1944           torch::tensor(
1945               {0.7890798754118442,
1946                0.5024321083861885,
1947                0.8587031097835685,
1948                0.6579677474141494,
1949                0.7476297677960806,
1950                1.6975465611838714}),
1951           torch::tensor(
1952               {0.891458601354904, 0.7020500593937647, 1.6891986479348047}),
1953           torch::tensor(
1954               {-1.0508027868194278, -1.3941348166291232, -1.2843373988951865}),
1955           torch::tensor({-1.0711379841318796}),
1956       },
1957       {
1958           torch::tensor(
1959               {0.2139892652405453,
1960                0.2779011713353896,
1961                0.18684802794665187,
1962                0.2507569370785562,
1963                0.19145335235130007,
1964                0.2557687813140708}),
1965           torch::tensor(
1966               {0.6720959116689083, 0.6480734848064099, 0.654263007004671}),
1967           torch::tensor(
1968               {-1.4357633640899097, -1.4493557950073235, -1.4619011018357073}),
1969           torch::tensor({-1.9673083558727926}),
1970       },
1971       {
1972           torch::tensor(
1973               {0.23961935744660673,
1974                0.30354236865888945,
1975                0.19567694278583514,
1976                0.2544696440133763,
1977                0.21982879020261814,
1978                0.27495711471979495}),
1979           torch::tensor(
1980               {0.6927895724658635, 0.638015535479105, 0.6523245375960234}),
1981           torch::tensor(
1982               {-1.413722583500382, -1.4170291001633526, -1.4166977298480703}),
1983           torch::tensor({-2.0626651115437147}),
1984       },
1985       {
1986           torch::tensor(
1987               {0.2506635865272117,
1988                0.314639511428635,
1989                0.25116892910818034,
1990                0.30431399579592144,
1991                0.25219625048710015,
1992                0.3160110008170742}),
1993           torch::tensor(
1994               {0.7051419232960522, 0.6699011906397543, 0.699097284678438}),
1995           torch::tensor(
1996               {-1.4206083241624232, -1.4257037444100107, -1.4171061826065132}),
1997           torch::tensor({-2.075537874763694}),
1998       },
1999       {
2000           torch::tensor(
2001               {0.23285924743063605,
2002                0.29652494304777544,
2003                0.2335322002738168,
2004                0.2969991261380461,
2005                0.23358272245229555,
2006                0.2973997498166104}),
2007           torch::tensor(
2008               {0.6855589594925036, 0.6796983775695974, 0.6864174803983276}),
2009           torch::tensor(
2010               {-1.43110762794651, -1.4334934742818164, -1.422739552145125}),
2011           torch::tensor({-2.0842642493046184}),
2012       },
2013       {
2014           torch::tensor(
2015               {0.23356397699389828,
2016                0.29737142391985355,
2017                0.23367622061822368,
2018                0.29749447597160267,
2019                0.23418481357395918,
2020                0.29818122925156104}),
2021           torch::tensor(
2022               {0.6866530583001205, 0.6858933385102559, 0.6883944045412603}),
2023           torch::tensor(
2024               {-1.4564955509607018, -1.4583548131500643, -1.4418225445708595}),
2025           torch::tensor({-2.1064103749186183}),
2026       },
2027       {
2028           torch::tensor(
2029               {0.2318717301174723,
2030                0.2952904159872858,
2031                0.23194024439476665,
2032                0.29537019824987687,
2033                0.2316421336904657,
2034                0.2951041425983894}),
2035           torch::tensor(
2036               {0.6834813130194509, 0.6834401711464199, 0.6837275457100463}),
2037           torch::tensor(
2038               {-1.4647835805276763, -1.4653452408179053, -1.4571142112777709}),
2039           torch::tensor({-2.1209598505912086}),
2040       },
2041       {
2042           torch::tensor(
2043               {0.2308683396504178,
2044                0.2940474629750448,
2045                0.23089067678260966,
2046                0.29407615110959306,
2047                0.23064069314214175,
2048                0.29379043611390243}),
2049           torch::tensor(
2050               {0.6815062281792611, 0.6815233687209215, 0.6812759203026146}),
2051           torch::tensor(
2052               {-1.4643013018530682, -1.4644523635284246, -1.4617493939684878}),
2053           torch::tensor({-2.1247293635678854}),
2054       },
2055       {
2056           torch::tensor(
2057               {0.23066464201462678,
2058                0.29376059273730426,
2059                0.23067069245857366,
2060                0.2937690399784267,
2061                0.23057551211606675,
2062                0.2936477517373107}),
2063           torch::tensor(
2064               {0.6809028781780304, 0.6809134105028244, 0.6807404613096301}),
2065           torch::tensor(
2066               {-1.4637927352177986, -1.4638374228010727, -1.46299287102643}),
2067           torch::tensor({-2.1258082720638107}),
2068       },
2069       {
2070           torch::tensor(
2071               {0.23062625079199173,
2072                0.2936990787425707,
2073                0.2306278924729115,
2074                0.2937014834661651,
2075                0.23059813368157003,
2076                0.29366073890476396}),
2077           torch::tensor(
2078               {0.6807251804689082, 0.6807295616357246, 0.6806523640328994}),
2079           torch::tensor(
2080               {-1.4635790398985618, -1.463592926902286, -1.4633272688236565}),
2081           torch::tensor({-2.1261396358141798}),
2082       },
2083       {
2084           torch::tensor(
2085               {0.23061701122700193,
2086                0.293683983817782,
2087                0.23061747865501653,
2088                0.29368467998690806,
2089                0.23060855595638208,
2090                0.2936719507340021}),
2091           torch::tensor(
2092               {0.6806714673830832, 0.6806730903175793, 0.6806434720800856}),
2093           torch::tensor(
2094               {-1.4635008778278134, -1.4635052859178375, -1.4634208375068285}),
2095           torch::tensor({-2.1262432969587723}),
2096       },
2097   };
2098 }
2099 
2100 inline std::vector<std::vector<torch::Tensor>>
RMSprop_with_weight_decay_and_centered()2101 RMSprop_with_weight_decay_and_centered() {
2102   return {
2103       {
2104           torch::tensor(
2105               {0.7941000061626792,
2106                0.507452636734552,
2107                0.8637405354185987,
2108                0.663005089317529,
2109                0.7526661272860107,
2110                1.7025887305065852}),
2111           torch::tensor(
2112               {0.8964950370033696, 0.7070877948157552, 1.6942369105467197}),
2113           torch::tensor(
2114               {-1.055840599214661, -1.3991726335388424, -1.2893752132746332}),
2115           torch::tensor({-1.0761757981162612}),
2116       },
2117       {
2118           torch::tensor(
2119               {2.3762999876885833,
2120                2.239095829416783,
2121                1.726175067071914,
2122                1.5891569459230444,
2123                1.2410074108588462,
2124                2.2345431036725723}),
2125           torch::tensor(
2126               {2.990896455635836, 1.8152108764849464, 2.377985429759037}),
2127           torch::tensor(
2128               {-2.3071822180635286, -2.636859516619699, -2.0198181394256642}),
2129           torch::tensor({-1.622583045791722}),
2130       },
2131       {
2132           torch::tensor(
2133               {2.372800588647971,
2134                2.3022753207224254,
2135                1.836028714221617,
2136                1.7190937269287105,
2137                1.3068955839895078,
2138                2.3035835673200364}),
2139           torch::tensor(
2140               {3.1656599892042343, 1.9942937608209466, 2.4947143457182657}),
2141           torch::tensor(
2142               {-2.6139790332516775, -2.9507738987695404, -2.1954425128779516}),
2143           torch::tensor({-1.7513053380188806}),
2144       },
2145       {
2146           torch::tensor(
2147               {2.2398453700818455,
2148                2.2513384246965904,
2149                1.8892176431436287,
2150                1.7921873754661686,
2151                1.3310951408713538,
2152                2.3236392222350397}),
2153           torch::tensor(
2154               {3.240166119454613, 2.109742813600189, 2.5651614461576973}),
2155           torch::tensor(
2156               {-2.8388734382997454, -3.1824200770676123, -2.324831397600949}),
2157           torch::tensor({-1.8460315737386976}),
2158       },
2159       {
2160           torch::tensor(
2161               {1.9829606312242465,
2162                2.097356567850692,
2163                1.9050263843525033,
2164                1.8325835415812346,
2165                1.3222762370713104,
2166                2.3024963133870147}),
2167           torch::tensor(
2168               {3.2465360572089974, 2.1967266045869915, 2.6091992649970672}),
2169           torch::tensor(
2170               {-3.0326878099587207, -3.3827004807595005, -2.436989182250496}),
2171           torch::tensor({-1.928273216206344}),
2172       },
2173       {
2174           torch::tensor(
2175               {1.6051175329080525,
2176                1.8332107491649114,
2177                1.8794767349053179,
2178                1.8403588051948856,
2179                1.273824111314107,
2180                2.2296571379436823}),
2181           torch::tensor(
2182               {3.1814362940910437, 2.263019214072847, 2.6273016977574013}),
2183           torch::tensor(
2184               {-3.210932646440219, -3.567153254014387, -2.541016943923914}),
2185           torch::tensor({-2.0049155134617154}),
2186       },
2187       {
2188           torch::tensor(
2189               {1.1588059349082709,
2190                1.477861379523226,
2191                1.7992410089026634,
2192                1.806460009198667,
2193                1.1739931551629919,
2194                2.08647960875392}),
2195           torch::tensor(
2196               {3.03843703712275, 2.308203068375877, 2.6125393914734083}),
2197           torch::tensor(
2198               {-3.379830678608588, -3.741970414470626, -2.6410082400846546}),
2199           torch::tensor({-2.079294995910487}),
2200       },
2201       {
2202           torch::tensor(
2203               {0.7701433312419088,
2204                1.1105026677424745,
2205                1.646507516936639,
2206                1.71625269098179,
2207                1.013748545414221,
2208                1.8532966501655352}),
2209           torch::tensor(
2210               {2.827176875885245, 2.327401948159928, 2.5535309398603405}),
2211           torch::tensor(
2212               {-3.54193329850986, -3.9096652952123145, -2.739408870192437}),
2213           torch::tensor({-2.1537939241668997}),
2214       },
2215       {
2216           torch::tensor(
2217               {0.5598923129351211,
2218                0.8460500042788701,
2219                1.4084175549165017,
2220                1.5547314210944563,
2221                0.8019580519338424,
2222                1.5258384663629627}),
2223           torch::tensor(
2224               {2.5774950379490265, 2.313101306699127, 2.4388695757441745}),
2225           torch::tensor(
2226               {-3.6974974230160087, -4.070190514312716, -2.8378932675718405}),
2227           torch::tensor({-2.2307225014430423}),
2228       },
2229       {
2230           torch::tensor(
2231               {0.5016784472836648,
2232                0.7258690889265433,
2233                1.0976902935953956,
2234                1.319949187972513,
2235                0.5853930356154851,
2236                1.1446978015944624}),
2237           torch::tensor(
2238               {2.3235249877284945, 2.2592840970420176, 2.2681461698609375}),
2239           torch::tensor(
2240               {-3.8444921272569115, -4.22021051361099, -2.9373192115434263}),
2241           torch::tensor({-2.312733063937045}),
2242       },
2243       {
2244           torch::tensor(
2245               {0.4875468895095056,
2246                0.6878747871467128,
2247                0.7787871237567606,
2248                1.0462592546102176,
2249                0.4416468896022397,
2250                0.8122992916762792}),
2251           torch::tensor(
2252               {2.1078734515587483, 2.17034337037527, 2.0666325968568535}),
2253           torch::tensor(
2254               {-3.9782695475825216, -4.352093055115415, -3.0377809502927033}),
2255           torch::tensor({-2.403496388200805}),
2256       },
2257   };
2258 }
2259 
2260 inline std::vector<std::vector<torch::Tensor>>
RMSprop_with_weight_decay_and_centered_and_momentum()2261 RMSprop_with_weight_decay_and_centered_and_momentum() {
2262   return {
2263       {
2264           torch::tensor(
2265               {0.7941000061626794,
2266                0.507452636734552,
2267                0.8637405354185985,
2268                0.663005089317529,
2269                0.7526661272860107,
2270                1.7025887305065852}),
2271           torch::tensor(
2272               {0.8964950370033699, 0.7070877948157552, 1.6942369105467197}),
2273           torch::tensor(
2274               {-1.055840599214661, -1.3991726335388424, -1.2893752132746332}),
2275           torch::tensor({-1.0761757981162612}),
2276       },
2277       {
2278           torch::tensor(
2279               {11.587263945492355,
2280                12.552112516667206,
2281                10.773002960161074,
2282                10.782117868337808,
2283                9.675467654064093,
2284                10.830689360054789}),
2285           torch::tensor(
2286               {15.298238342006444, 11.252244653209866, 11.423905295074075}),
2287           torch::tensor(
2288               {-11.287147147258441, -11.673871066494183, -11.143068139029769}),
2289           torch::tensor({-10.744790465364126}),
2290       },
2291       {
2292           torch::tensor(
2293               {5.993130757784388,
2294                7.778269455146452,
2295                9.705741295559012,
2296                9.974952848613889,
2297                8.171307305871647,
2298                9.551498426643077}),
2299           torch::tensor(
2300               {12.811268477045155, 10.912201832960703, 10.87477550647832}),
2301           torch::tensor(
2302               {-11.20842921856976, -11.58706973895515, -11.098172235374586}),
2303           torch::tensor({-10.714110383698559}),
2304       },
2305       {
2306           torch::tensor(
2307               {1.917316794757853,
2308                3.442098373003915,
2309                8.160846071267297,
2310                8.76673426856121,
2311                6.163892823252042,
2312                7.748894752821816}),
2313           torch::tensor(
2314               {9.52929937981379, 10.371703621802425, 10.02242566317017}),
2315           torch::tensor(
2316               {-11.07914626767133, -11.444639737948599, -11.02397978065452}),
2317           torch::tensor({-10.663204622623406}),
2318       },
2319       {
2320           torch::tensor(
2321               {0.24211162925745067,
2322                0.8235150923738451,
2323                6.109652191353378,
2324                7.070860554523037,
2325                3.8366635637770212,
2326                5.46037058418296}),
2327           torch::tensor(
2328               {5.7908039507441, 9.534309069066389, 8.752252906881251}),
2329           torch::tensor(
2330               {-10.868651889371552, -11.212965695734527, -10.90242744782103}),
2331           torch::tensor({-10.579596899816439}),
2332       },
2333       {
2334           torch::tensor(
2335               {0.0024206009020476234,
2336                0.05521740497689468,
2337                3.753606156332189,
2338                4.9331546064599685,
2339                1.7094621184709604,
2340                3.022224882400484}),
2341           torch::tensor(
2342               {2.4729429920325234, 8.290211439306459, 6.983317870704776}),
2343           torch::tensor(
2344               {-10.529133489023623, -10.839885990130032, -10.704345435808353}),
2345           torch::tensor({-10.44279235413811}),
2346       },
2347       {
2348           torch::tensor(
2349               {8.523664833406631e-06,
2350                -0.00018498015809617104,
2351                1.6343074841140277,
2352                2.683608480982546,
2353                0.41425107807132744,
2354                1.092111816609512}),
2355           torch::tensor(
2356               {0.553119873538318, 6.566845593450314, 4.783317472190566}),
2357           torch::tensor(
2358               {-9.990101114696575, -10.24914448933998, -10.38447825909146}),
2359           torch::tensor({-10.220382375374728}),
2360       },
2361       {
2362           torch::tensor(
2363               {5.3669182339397725e-08,
2364                -2.899704029399283e-07,
2365                0.37916783268568177,
2366                0.9399553431452395,
2367                0.02859528129337607,
2368                0.17650614337704745}),
2369           torch::tensor(
2370               {0.03166973497545419, 4.442846994093523, 2.5203464928754724}),
2371           torch::tensor(
2372               {-9.15653357178671, -9.339631853060773, -9.875729313751442}),
2373           torch::tensor({-9.862669711962374}),
2374       },
2375       {
2376           torch::tensor(
2377               {2.1133356499004335e-06,
2378                2.4524630407768025e-06,
2379                0.023655729923601883,
2380                0.14273709578291396,
2381                -8.950192389690758e-05,
2382                0.004237697008964042}),
2383           torch::tensor(
2384               {-0.00012364097582548376, 2.291191859107928, 0.8331414409602524}),
2385           torch::tensor(
2386               {-7.922566174765117, -8.003055545094796, -9.086673634672907}),
2387           torch::tensor({-9.297519364373224}),
2388       },
2389       {
2390           torch::tensor(
2391               {0.0023497430294992434,
2392                0.0028611316714725037,
2393                0.0006998739627296072,
2394                0.003657156536057531,
2395                0.001654303471369622,
2396                0.0018171459470053366}),
2397           torch::tensor(
2398               {0.004569191565477355, 0.7292466599711233, 0.11475431260766135}),
2399           torch::tensor(
2400               {-6.223834483308681, -6.185383631607397, -7.912955414853613}),
2401           torch::tensor({-8.430731662958186}),
2402       },
2403       {
2404           torch::tensor(
2405               {0.10393820340367545,
2406                0.13982074666181732,
2407                0.0831407198272949,
2408                0.10183584198629944,
2409                0.13949594516972202,
2410                0.17822672100147108}),
2411           torch::tensor(
2412               {0.340394645020639, 0.24860888862359687, 0.3191404515531066}),
2413           torch::tensor(
2414               {-4.174294597914298, -4.037528929635062, -6.297198700024484}),
2415           torch::tensor({-7.182093090194918}),
2416       },
2417   };
2418 }
2419 
SGD()2420 inline std::vector<std::vector<torch::Tensor>> SGD() {
2421   return {
2422       {
2423           torch::tensor(
2424               {-0.21063957030131192,
2425                -0.4972093725858961,
2426                -0.13931849072410168,
2427                -0.33939101965581686,
2428                -0.25112865488453673,
2429                0.6992101966874735}),
2430           torch::tensor(
2431               {-0.1076573444246077, -0.2913064413859577, 0.6933846874181748}),
2432           torch::tensor(
2433               {-0.07998325778863398,
2434                -0.42149210515421365,
2435                -0.33498349553944556}),
2436           torch::tensor({-0.14255126505509488}),
2437       },
2438       {
2439           torch::tensor(
2440               {-0.15543131540224012,
2441                -0.42351103963720343,
2442                -0.04196796248622072,
2443                -0.2095223178068499,
2444                -0.16031407286541022,
2445                0.8209742464453325}),
2446           torch::tensor(
2447               {0.07724343607160136, 0.03387529472490231, 1.0028793648054941}),
2448           torch::tensor(
2449               {-0.8213382425894498, -1.1570800333254736, -1.615476033165743}),
2450           torch::tensor({-1.8734090731084845}),
2451       },
2452       {
2453           torch::tensor(
2454               {-0.13342791770744886,
2455                -0.3941509709488104,
2456                -0.011470356542661934,
2457                -0.16885142516066962,
2458                -0.13306680693528108,
2459                0.8576491729785701}),
2460           torch::tensor(
2461               {0.15081014600761677, 0.13560816175111742, 1.0971559708365837}),
2462           torch::tensor(
2463               {-0.9780975407869251, -1.3215153697157924, -1.876021387605152}),
2464           torch::tensor({-2.2024413056528886}),
2465       },
2466       {
2467           torch::tensor(
2468               {-0.11963097684681223,
2469                -0.37573675130134543,
2470                0.0069987166413883715,
2471                -0.14420855651125972,
2472                -0.11733423659038758,
2473                0.8788673419128562}),
2474           torch::tensor(
2475               {0.1969829338759005, 0.1973461164047132, 1.1520119567305152}),
2476           torch::tensor(
2477               {-1.0677802792431819, -1.4166561260631119, -2.022033753216991}),
2478           torch::tensor({-2.383452427292781}),
2479       },
2480       {
2481           torch::tensor(
2482               {-0.10950806441156272,
2483                -0.3622226699218595,
2484                0.02028489243523426,
2485                -0.1264725422838007,
2486                -0.10635775660996463,
2487                0.8936912722040982}),
2488           torch::tensor(
2489               {0.23089462331826793, 0.24184450074084418, 1.1904864598387046}),
2490           torch::tensor(
2491               {-1.1306213044009719, -1.4837186483578142, -2.122884602514208}),
2492           torch::tensor({-2.5071352505158395}),
2493       },
2494       {
2495           torch::tensor(
2496               {-0.10149090356585248,
2497                -0.3515172115812867,
2498                0.030662536099764083,
2499                -0.11261325211798616,
2500                -0.09797248308626623,
2501                0.905027632401109}),
2502           torch::tensor(
2503               {0.25777759826689434, 0.2766609657536915, 1.2199973265718322}),
2504           torch::tensor(
2505               {-1.1789655573653979, -1.5355073692636774, -2.199612583884608}),
2506           torch::tensor({-2.600529541471662}),
2507       },
2508       {
2509           torch::tensor(
2510               {-0.09484472748389533,
2511                -0.3426405023243085,
2512                0.03917399284640637,
2513                -0.10124188994381228,
2514                -0.09121264836307835,
2515                0.9141743475340721}),
2516           torch::tensor(
2517               {0.2800829300171032, 0.3052600200290069, 1.2438661306695873}),
2518           torch::tensor(
2519               {-1.2182324765944266, -1.5776851394085492, -2.2613704866316295}),
2520           torch::tensor({-2.6752743361973184}),
2521       },
2522       {
2523           torch::tensor(
2524               {-0.08916446117741175,
2525                -0.33505233521798666,
2526                0.04638527943959316,
2527                -0.09160422984057517,
2528                -0.08556486270584644,
2529                0.9218219103015535}),
2530           torch::tensor(
2531               {0.2991619380154852, 0.3295237551295101, 1.2638639017720827}),
2532           torch::tensor(
2533               {-1.251282493526328, -1.6132564639504312, -2.3129529937213853}),
2534           torch::tensor({-2.73741957239466}),
2535       },
2536       {
2537           torch::tensor(
2538               {-0.08420245801272856,
2539                -0.3284224385121882,
2540                0.05263847708646642,
2541                -0.08324438788845245,
2542                -0.08072424164719598,
2543                0.9283806476306355}),
2544           torch::tensor(
2545               {0.31584087342663564, 0.35059019818200393, 1.2810450644764015}),
2546           torch::tensor(
2547               {-1.2798091496372141, -1.6440072538210193, -2.357180462961105}),
2548           torch::tensor({-2.7905023459395872}),
2549       },
2550       {
2551           torch::tensor(
2552               {-0.07979600214534928,
2553                -0.3225337978155753,
2554                0.0581562720006689,
2555                -0.07586555700667826,
2556                -0.07649523955108037,
2557                0.9341138824526719}),
2558           torch::tensor(
2559               {0.3306627217189733, 0.3692005578577212, 1.2960873917356066}),
2560           torch::tensor(
2561               {-1.3048976883823566, -1.6710855742501123, -2.395849898454614}),
2562           torch::tensor({-2.836765085555123}),
2563       },
2564       {
2565           torch::tensor(
2566               {-0.07583232846497832,
2567                -0.3172360102461862,
2568                0.06309179259248046,
2569                -0.06926361352067158,
2570                -0.07274510848082802,
2571                0.9392004636935606}),
2572           torch::tensor(
2573               {0.3440038606091545, 0.3858647867996722, 1.3094518934419668}),
2574           torch::tensor(
2575               {-1.3272851146877218, -1.6952731308502653, -2.4301754289421598}),
2576           torch::tensor({-2.8777164728823017}),
2577       },
2578   };
2579 }
2580 
SGD_with_weight_decay()2581 inline std::vector<std::vector<torch::Tensor>> SGD_with_weight_decay() {
2582   return {
2583       {
2584           torch::tensor(
2585               {-0.21042867144447805,
2586                -0.49671181653925384,
2587                -0.13917719856207697,
2588                -0.3390489907590303,
2589                -0.2508762913762564,
2590                0.6985126396619242}),
2591           torch::tensor(
2592               {-0.10754881320494518, -0.2910084928862701, 0.6926954859081793}),
2593           torch::tensor(
2594               {-0.079932454658518, -0.42109796996670307, -0.33469915794198624}),
2595           torch::tensor({-0.14248012693079315}),
2596       },
2597       {
2598           torch::tensor(
2599               {-0.13579982290274883,
2600                -0.3765456284475787,
2601                -0.03166970700350034,
2602                -0.18102559254681197,
2603                -0.1373234786735746,
2604                0.7522156177001302}),
2605           torch::tensor(
2606               {0.08550003826014418, 0.051563225553454196, 0.9321399061276381}),
2607           torch::tensor(
2608               {-0.796312238882584, -1.1010063686038731, -1.5363716774172782}),
2609           torch::tensor({-1.8045854907382846}),
2610       },
2611       {
2612           torch::tensor(
2613               {-0.09659168723529124,
2614                -0.30562076936588267,
2615                0.0067128671455129185,
2616                -0.1166002367977548,
2617                -0.09012083166238948,
2618                0.7264953102453368}),
2619           torch::tensor(
2620               {0.16531808496504802, 0.16488328577596398, 0.9610743966573317}),
2621           torch::tensor(
2622               {-0.9202466399245914, -1.2052829272891832, -1.7049756710541348}),
2623           torch::tensor({-2.0415977924493043}),
2624       },
2625       {
2626           torch::tensor(
2627               {-0.06728100597713035,
2628                -0.2496589601654196,
2629                0.03186158526394668,
2630                -0.07105441484407878,
2631                -0.056478595544178806,
2632                0.6910758436366733}),
2633           torch::tensor(
2634               {0.21707768347081777, 0.23575238192099465, 0.9564382346520686}),
2635           torch::tensor(
2636               {-0.9788195039029999, -1.2447191597975946, -1.762020156061963}),
2637           torch::tensor({-2.131504419683077}),
2638       },
2639       {
2640           torch::tensor(
2641               {-0.04304955053155505,
2642                -0.20206572730420902,
2643                0.050959513946324475,
2644                -0.034700093557440984,
2645                -0.02922465201167018,
2646                0.6547611705604361}),
2647           torch::tensor(
2648               {0.2563898231537708, 0.2878867158887637, 0.9414221685252802}),
2649           torch::tensor(
2650               {-1.0143969472996655, -1.2623288365082088, -1.7800471460065668}),
2651           torch::tensor({-2.170255083720924}),
2652       },
2653       {
2654           torch::tensor(
2655               {-0.022154717038262738,
2656                -0.16036518660639862,
2657                0.06644401410758827,
2658                -0.004183373274651896,
2659                -0.005965877978527781,
2660                0.6200298215101535}),
2661           torch::tensor(
2662               {0.2886406829874717, 0.32924516791460257, 0.9230983700837223}),
2663           torch::tensor(
2664               {-1.0397895250773481, -1.2710914166240181, -1.7807758009603087}),
2665           torch::tensor({-2.1862978976514738}),
2666       },
2667       {
2668           torch::tensor(
2669               {-0.0037439139848317077,
2670                -0.12328293308251938,
2671                0.07944696186805641,
2672                0.022100305718442022,
2673                0.014399113804332037,
2674                0.587697912745227}),
2675           torch::tensor(
2676               {0.3162871074692008, 0.36346293565421134, 0.9042402154310412}),
2677           torch::tensor(
2678               {-1.060234961430088, -1.2762264965487675, -1.7731268727630662}),
2679           torch::tensor({-2.191253945056341}),
2680       },
2681       {
2682           torch::tensor(
2683               {0.012675985938854726,
2684                -0.09003711893222131,
2685                0.09059095692632844,
2686                0.04506778924310349,
2687                0.03247299240601001,
2688                0.5579755127260052}),
2689           torch::tensor(
2690               {0.3406226998933173, 0.3924947745885882, 0.8860121369119325}),
2691           torch::tensor(
2692               {-1.0781407849705034, -1.2800528898634018, -1.7613120374342215}),
2693           torch::tensor({-2.190575043873577}),
2694       },
2695       {
2696           torch::tensor(
2697               {0.027425440985777993,
2698                -0.06008809958617219,
2699                0.10026092920861808,
2700                0.06531092947039244,
2701                0.048628754907931976,
2702                0.5308215072596255}),
2703           torch::tensor(
2704               {0.36239744520280553, 0.4175162387638887, 0.8688788105023479}),
2705           torch::tensor(
2706               {-1.0946579691370502, -1.283610342226948, -1.7474706191775764}),
2707           torch::tensor({-2.1870021744944763}),
2708       },
2709       {
2710           torch::tensor(
2711               {0.04073250980147411,
2712                -0.03303024103555013,
2713                0.1087177047593139,
2714                0.08324870459183518,
2715                0.0631222868881554,
2716                0.5060892094042873}),
2717           torch::tensor(
2718               {0.38208249693950175, 0.4393002654989596, 0.8529817924677643}),
2719           torch::tensor(
2720               {-1.1103326127955466, -1.287332405916359, -1.73273866274852}),
2721           torch::tensor({-2.1819672316721337}),
2722       },
2723       {
2724           torch::tensor(
2725               {0.05277160918732605,
2726                -0.008539186625351441,
2727                0.1161515444487197,
2728                0.09919929206676087,
2729                0.07614530177703588,
2730                0.48359250162323586}),
2731           torch::tensor(
2732               {0.3999968617221315, 0.45839442009256354, 0.8383132966805791}),
2733           torch::tensor(
2734               {-1.1254107858333455, -1.2913604197768889, -1.717739109221235}),
2735           torch::tensor({-2.1762368071604308}),
2736       },
2737   };
2738 }
2739 
2740 inline std::vector<std::vector<torch::Tensor>>
SGD_with_weight_decay_and_momentum()2741 SGD_with_weight_decay_and_momentum() {
2742   return {
2743       {
2744           torch::tensor(
2745               {-0.21042867144447805,
2746                -0.49671181653925384,
2747                -0.13917719856207697,
2748                -0.3390489907590303,
2749                -0.2508762913762564,
2750                0.6985126396619242}),
2751           torch::tensor(
2752               {-0.10754881320494518, -0.2910084928862701, 0.6926954859081793}),
2753           torch::tensor(
2754               {-0.079932454658518, -0.42109796996670307, -0.33469915794198624}),
2755           torch::tensor({-0.14248012693079315}),
2756       },
2757       {
2758           torch::tensor(
2759               {0.0056118487251954775,
2760                -0.0710915563059199,
2761                0.07701400891926036,
2762                0.047067327035013866,
2763                0.0428654052972598,
2764                0.4352977220593751}),
2765           torch::tensor(
2766               {0.23834837300214828, 0.32366382503704183, 0.7128321016634689}),
2767           torch::tensor(
2768               {-1.041947788394885, -1.1730950187020548, -1.7648205873351157}),
2769           torch::tensor({-2.3359277661920594}),
2770       },
2771       {
2772           torch::tensor(
2773               {0.11520007183759418,
2774                0.1289453768763286,
2775                0.14586845555951963,
2776                0.1775341535876219,
2777                0.15614155642578995,
2778                0.33379126147460536}),
2779           torch::tensor(
2780               {0.465853656413685, 0.520197917876909, 0.7274876508280723}),
2781           torch::tensor(
2782               {-1.2034746444882527, -1.286126969233868, -1.604528340632377}),
2783           torch::tensor({-2.203215909196624}),
2784       },
2785       {
2786           torch::tensor(
2787               {0.15331258730374997,
2788                0.197909036233604,
2789                0.16663814647374195,
2790                0.2183320498727895,
2791                0.1803274550482287,
2792                0.28362745794417826}),
2793           torch::tensor(
2794               {0.5532312776994917, 0.5834224152126115, 0.6903579410976886}),
2795           torch::tensor(
2796               {-1.3052171323471546, -1.3514190497186434, -1.5153574535010634}),
2797           torch::tensor({-2.123181139806548}),
2798       },
2799       {
2800           torch::tensor(
2801               {0.16814113185552507,
2802                0.22386572201448868,
2803                0.17413795101952864,
2804                0.23280515326261633,
2805                0.1839142207976228,
2806                0.2614499495870909}),
2807           torch::tensor(
2808               {0.592282876576759, 0.6083877519652824, 0.663438748699906}),
2809           torch::tensor(
2810               {-1.3591143274292896, -1.383673065830997, -1.467157893517277}),
2811           torch::tensor({-2.087859547998447}),
2812       },
2813       {
2814           torch::tensor(
2815               {0.1743742243877178,
2816                0.2343126153059798,
2817                0.17716942927642254,
2818                0.23838669643330088,
2819                0.18308461132092924,
2820                0.25149544624452974}),
2821           torch::tensor(
2822               {0.6108281747800746, 0.6192657661217672, 0.6475519545045926}),
2823           torch::tensor(
2824               {-1.3860527054444407, -1.398816664238087, -1.4412527948055516}),
2825           torch::tensor({-2.0731939075659627}),
2826       },
2827       {
2828           torch::tensor(
2829               {0.1771465478751462,
2830                0.23875859951719522,
2831                0.1784868271584857,
2832                0.2406786372566496,
2833                0.18181103291606765,
2834                0.24687877342069478}),
2835           torch::tensor(
2836               {0.6198586021174767, 0.6242349464856269, 0.638736845373371}),
2837           torch::tensor(
2838               {-1.3993307716862977, -1.4058965193851591, -1.42747775986796}),
2839           torch::tensor({-2.0672675843404598}),
2840       },
2841       {
2842           torch::tensor(
2843               {0.17843093585357683,
2844                0.24073954802700465,
2845                0.17908697027440873,
2846                0.2416675839909268,
2847                0.18088350526559058,
2848                0.24467193314356378}),
2849           torch::tensor(
2850               {0.6243071074374693, 0.6265628975677455, 0.6339840865876518}),
2851           torch::tensor(
2852               {-1.4058750036106915, -1.4092362337714568, -1.4202202926903085}),
2853           torch::tensor({-2.0649062340635584}),
2854       },
2855       {
2856           torch::tensor(
2857               {0.17904350645021613,
2858                0.24165496946247034,
2859                0.17936920658487726,
2860                0.24211164489776849,
2861                0.18031858582735988,
2862                0.2435923992630521}),
2863           torch::tensor(
2864               {0.626513445507806, 0.6276715667697311, 0.6314641991686346}),
2865           torch::tensor(
2866               {-1.409113940967948, -1.410830795235453, -1.4164247285253404}),
2867           torch::tensor({-2.0639728292802046}),
2868       },
2869       {
2870           torch::tensor(
2871               {0.17934167113683835,
2872                0.242089962404631,
2873                0.17950490408309286,
2874                0.24231745350706005,
2875                0.17999989292556767,
2876                0.2430557755257577}),
2877           torch::tensor(
2878               {0.6276131793232345, 0.6282062328090801, 0.6301427155170752}),
2879           torch::tensor(
2880               {-1.4107251789010826, -1.4116011824171857, -1.4144511767962422}),
2881           torch::tensor({-2.0636056316673934}),
2882       },
2883       {
2884           torch::tensor(
2885               {0.17948886155124505,
2886                0.24230096332204806,
2887                0.17957117450689372,
2888                0.242415213133214,
2889                0.17982712042628357,
2890                0.2427862039224869}),
2891           torch::tensor(
2892               {0.6281635672171683, 0.6284667582211864, 0.6294549191500093}),
2893           torch::tensor(
2894               {-1.4115305541843781, -1.4119772978756444, -1.4134296522818641}),
2895           torch::tensor({-2.0634616066978615}),
2896       },
2897   };
2898 }
2899 
2900 inline std::vector<std::vector<torch::Tensor>>
SGD_with_weight_decay_and_nesterov_momentum()2901 SGD_with_weight_decay_and_nesterov_momentum() {
2902   return {
2903       {
2904           torch::tensor(
2905               {-0.21040617235121148,
2906                -0.49689727139951717,
2907                -0.13754215970803657,
2908                -0.33701686525263036,
2909                -0.2500172388792182,
2910                0.700697918175925}),
2911           torch::tensor(
2912               {-0.1068708360895515, -0.2853285323043249, 0.6971494161502307}),
2913           torch::tensor(
2914               {-0.10624536304143092, -0.4461132561477894, -0.3805647497874434}),
2915           torch::tensor({-0.2068230782168696}),
2916       },
2917       {
2918           torch::tensor(
2919               {-0.1262387113548655,
2920                -0.3844658218758334,
2921                0.03124406856508884,
2922                -0.11170532152425781,
2923                -0.09823268522398332,
2924                0.9040698525178972}),
2925           torch::tensor(
2926               {0.17551336074135096, 0.27976614792027166, 1.2138399680985128}),
2927           torch::tensor(
2928               {-1.592840413595591, -1.8986806244521564, -2.966181914454827}),
2929           torch::tensor({-3.7728444542017687}),
2930       },
2931       {
2932           torch::tensor(
2933               {-0.11614716303292183,
2934                -0.3709539909720773,
2935                0.04307078045512772,
2936                -0.09588329367245825,
2937                -0.08795603365024904,
2938                0.9178771227283019}),
2939           torch::tensor(
2940               {0.20944042006388683, 0.3195483889401668, 1.2500270348310718}),
2941           torch::tensor(
2942               {-1.635011052494502, -1.9463243375558272, -3.035708036973984}),
2943           torch::tensor({-3.8570351018212796}),
2944       },
2945       {
2946           torch::tensor(
2947               {-0.10793942832760066,
2948                -0.35995697973682966,
2949                0.05260329955808716,
2950                -0.08312010825923577,
2951                -0.07986326997915319,
2952                0.9287409473303162}),
2953           torch::tensor(
2954               {0.2370574459090396, 0.35168415020524857, 1.278618438127574}),
2955           torch::tensor(
2956               {-1.669141810658011, -1.984894370767313, -3.091259532917102}),
2957           torch::tensor({-3.923827025320545}),
2958       },
2959       {
2960           torch::tensor(
2961               {-0.1010142826857921,
2962                -0.35067247612415425,
2963                0.06058642765135953,
2964                -0.07242353828264116,
2965                -0.07320722520220559,
2966                0.9376663294528951}),
2967           torch::tensor(
2968               {0.26037531373638517, 0.37864768429039036, 1.3021925174954938}),
2969           torch::tensor(
2970               {-1.6978623668013235, -2.017346013780729, -3.137511248751908}),
2971           torch::tensor({-3.9791368472670334}),
2972       },
2973       {
2974           torch::tensor(
2975               {-0.0950223925827384,
2976                -0.34263425874631004,
2977                0.06744912149060932,
2978                -0.0632219668955612,
2979                -0.06756850374320933,
2980                0.9452179348012486}),
2981           torch::tensor(
2982               {0.2805629021730173, 0.40186559210837897, 1.322201974233735}),
2983           torch::tensor(
2984               {-1.7226667375672964, -2.0453651314263936, -3.177094625235675}),
2985           torch::tensor({-4.02626353351958}),
2986       },
2987       {
2988           torch::tensor(
2989               {-0.08974074058929343,
2990                -0.3355446553621404,
2991                0.07346375443579244,
2992                -0.055152336279101065,
2993                -0.06268648871001672,
2994                0.9517469338705254}),
2995           torch::tensor(
2996               {0.2983667593362755, 0.42224471824689497, 1.3395523443811077}),
2997           torch::tensor(
2998               {-1.7445022249557358, -2.070022023204061, -3.2116640112699977}),
2999           torch::tensor({-4.0672678681014025}),
3000       },
3001       {
3002           torch::tensor(
3003               {-0.08501805567029425,
3004                -0.32920173535901165,
3005                0.07881418855733362,
3006                -0.04796950604202488,
3007                -0.058388411064112675,
3008                0.9574862878804136}),
3009           torch::tensor(
3010               {0.314293480998118, 0.44039784591234016, 1.3548455497581404}),
3011           torch::tensor(
3012               {-1.7640079833055848, -2.092039516337239, -3.2423272727017984}),
3013           torch::tensor({-4.1035226231344275}),
3014       },
3015       {
3016           torch::tensor(
3017               {-0.08074691916762683,
3018                -0.32346209125355385,
3019                0.08363041954091821,
3020                -0.04150011326440421,
3021                -0.05455400195824525,
3022                0.9625982669377223}),
3023           torch::tensor(
3024               {0.3287029554083447, 0.456758647445438, 1.3685016029692183}),
3025           torch::tensor(
3026               {-1.781635969965323, -2.1119291060463254, -3.2698625668054278}),
3027           torch::tensor({-4.135987715622241}),
3028       },
3029       {
3030           torch::tensor(
3031               {-0.07684825926741544,
3032                -0.3182201487496606,
3033                0.08800775942949753,
3034                -0.03561702050647377,
3035                -0.05109626940874274,
3036                0.967200308063431}),
3037           torch::tensor(
3038               {0.3418602073865105, 0.47164537681262036, 1.3808249543962092}),
3039           torch::tensor(
3040               {-1.7977177360253929, -2.1300662313234127, -3.2948372910743537}),
3041           torch::tensor({-4.165360368453936}),
3042       },
3043       {
3044           torch::tensor(
3045               {-0.07326215742204903,
3046                -0.31339589848358795,
3047                0.09201816976416921,
3048                -0.030224217178854797,
3049                -0.0479503174605994,
3050                0.9713800632469923}),
3051           torch::tensor(
3052               {0.35396614509118257, 0.485298524494989, 1.3920431643924076}),
3053           torch::tensor(
3054               {-1.8125038126190638, -2.146734711618823, -3.3176778240157505}),
3055           torch::tensor({-4.192162739857097}),
3056       },
3057   };
3058 }
3059 
3060 } // namespace expected_parameters
3061