#!/usr/bin/env python3
#
# Copyright (c) Facebook, Inc. and its affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
This example shows how to do Model-Agnostic Meta-Learning (MAML) for few-shot
Omniglot classification using plain PyTorch together with functorch's
make_functional_with_buffers, without relying on the higher library.
For more details see the original MAML paper:
https://arxiv.org/abs/1703.03400

This code has been modified from Jackie Loong's PyTorch MAML implementation:
https://github.com/dragen1860/MAML-Pytorch/blob/master/omniglot_train.py

Our MAML++ fork and experiments are available at:
https://github.com/bamos/HowToTrainYourMAMLPytorch
"""
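# A quick sketch of the MAML objective implemented below (notation follows the
# paper linked above): for each sampled task i, the inner loop adapts the
# meta-parameters theta with a few SGD steps on that task's support set,
#
#     theta_i' = theta - alpha * grad_theta L_spt_i(theta)   (one step shown),
#
# and the outer loop then updates theta using the query-set loss of the
# adapted parameters,
#
#     theta <- theta - beta * grad_theta sum_i L_qry_i(theta_i'),
#
# which requires differentiating through the inner SGD steps. In this script
# alpha is the hard-coded 1e-1 inner step size, and the outer update is taken
# by Adam rather than plain SGD.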

import argparse
import time

import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from support.omniglot_loaders import OmniglotNShot

import torch
import torch.nn.functional as F
import torch.optim as optim
from functorch import make_functional_with_buffers
from torch import nn


mpl.use("Agg")
plt.style.use("bmh")


def main():
    argparser = argparse.ArgumentParser()
    argparser.add_argument("--n-way", "--n_way", type=int, help="n way", default=5)
    argparser.add_argument(
        "--k-spt", "--k_spt", type=int, help="k shot for support set", default=5
    )
    argparser.add_argument(
        "--k-qry", "--k_qry", type=int, help="k shot for query set", default=15
    )
    argparser.add_argument("--device", type=str, help="device", default="cuda")
    argparser.add_argument(
        "--task-num",
        "--task_num",
        type=int,
        help="meta batch size, namely task num",
        default=32,
    )
    argparser.add_argument("--seed", type=int, help="random seed", default=1)
    args = argparser.parse_args()

    torch.manual_seed(args.seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(args.seed)
    np.random.seed(args.seed)

    # Set up the Omniglot loader.
    device = args.device
    db = OmniglotNShot(
        "/tmp/omniglot-data",
        batchsz=args.task_num,
        n_way=args.n_way,
        k_shot=args.k_spt,
        k_query=args.k_qry,
        imgsz=28,
        device=device,
    )
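    # A note on shapes (based on the OmniglotNShot loader above): each call to
    # db.next() returns tensors already on `device`, roughly
    #   x_spt: [task_num, n_way * k_spt, 1, 28, 28], y_spt: [task_num, n_way * k_spt]
    #   x_qry: [task_num, n_way * k_qry, 1, 28, 28], y_qry: [task_num, n_way * k_qry]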

    # Create a vanilla PyTorch neural network. Below, functorch's
    # make_functional_with_buffers turns it into a stateless function plus
    # its parameters and buffers, and the inner loop then updates those
    # parameters manually instead of going through a stateful nn.Module.
    net = nn.Sequential(
        nn.Conv2d(1, 64, 3),
        nn.BatchNorm2d(64, momentum=1, affine=True),
        nn.ReLU(inplace=True),
        nn.MaxPool2d(2, 2),
        nn.Conv2d(64, 64, 3),
        nn.BatchNorm2d(64, momentum=1, affine=True),
        nn.ReLU(inplace=True),
        nn.MaxPool2d(2, 2),
        nn.Conv2d(64, 64, 3),
        nn.BatchNorm2d(64, momentum=1, affine=True),
        nn.ReLU(inplace=True),
        nn.MaxPool2d(2, 2),
        Flatten(),
        nn.Linear(64, args.n_way),
    ).to(device)

    net.train()
    fnet, params, buffers = make_functional_with_buffers(net)
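    # make_functional_with_buffers splits `net` into a stateless callable
    # `fnet` plus tuples of its parameters and buffers; predictions are made
    # with fnet(params, buffers, x), so the inner loop below can substitute
    # adapted parameters explicitly, while `params` can still be handed to a
    # regular torch.optim optimizer for the meta-update.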

    # We will use Adam to (meta-)optimize the initial parameters
    # to be adapted.
    meta_opt = optim.Adam(params, lr=1e-3)

    log = []
    for epoch in range(100):
        train(db, [params, buffers, fnet], device, meta_opt, epoch, log)
        test(db, [params, buffers, fnet], device, epoch, log)
        plot(log)


def train(db, net, device, meta_opt, epoch, log):
    params, buffers, fnet = net
    n_train_iter = db.x_train.shape[0] // db.batchsz

    for batch_idx in range(n_train_iter):
        start_time = time.time()
        # Sample a batch of support and query images and labels.
        x_spt, y_spt, x_qry, y_qry = db.next()

        task_num, setsz, c_, h, w = x_spt.size()
        querysz = x_qry.size(1)

        # TODO: Maybe pull this out into a separate module so it
        # doesn't have to be duplicated between `train` and `test`?

        # Adapt the parameters to the support set with a few plain SGD steps
        # (fixed step size 1e-1), written out manually below rather than
        # through a torch.optim optimizer.
        n_inner_iter = 5

        qry_losses = []
        qry_accs = []
        meta_opt.zero_grad()
        for i in range(task_num):
            # Optimize the likelihood of the support set by taking
            # gradient steps w.r.t. the model's parameters.
            # This adapts the model's meta-parameters to the task.
            new_params = params
            for _ in range(n_inner_iter):
                spt_logits = fnet(new_params, buffers, x_spt[i])
                spt_loss = F.cross_entropy(spt_logits, y_spt[i])
                grads = torch.autograd.grad(spt_loss, new_params, create_graph=True)
                new_params = [p - g * 1e-1 for p, g in zip(new_params, grads)]

            # The final set of adapted parameters will induce some
            # final loss and accuracy on the query dataset.
            # These will be used to update the model's meta-parameters.
            qry_logits = fnet(new_params, buffers, x_qry[i])
            qry_loss = F.cross_entropy(qry_logits, y_qry[i])
            qry_losses.append(qry_loss.detach())
            qry_acc = (qry_logits.argmax(dim=1) == y_qry[i]).sum().item() / querysz
            qry_accs.append(qry_acc)

            # Update the model's meta-parameters to optimize the query
            # losses across all of the tasks sampled in this batch.
            # This unrolls through the gradient steps.
            qry_loss.backward()
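            # Because the inner updates were built with `create_graph=True`,
            # this backward pass differentiates through the unrolled SGD steps
            # and accumulates the meta-gradient into the original `params`;
            # meta_opt.step() below then applies the accumulated gradients.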

        meta_opt.step()
        qry_losses = sum(qry_losses) / task_num
        qry_accs = 100.0 * sum(qry_accs) / task_num
        frac_epoch = epoch + float(batch_idx) / n_train_iter
        iter_time = time.time() - start_time
        if batch_idx % 4 == 0:
            print(
                f"[Epoch {frac_epoch:.2f}] Train Loss: {qry_losses:.2f} | Acc: {qry_accs:.2f} | Time: {iter_time:.2f}"
            )

        log.append(
            {
                "epoch": frac_epoch,
                "loss": qry_losses,
                "acc": qry_accs,
                "mode": "train",
                "time": time.time(),
            }
        )


def test(db, net, device, epoch, log):
    # Crucially, in our testing procedure here we do *not* fine-tune
    # the model during testing, for simplicity.
    # Most research papers using MAML for this task do an extra
    # stage of fine-tuning here, which should be added if you are
    # adapting this code for research.
    params, buffers, fnet = net
    n_test_iter = db.x_test.shape[0] // db.batchsz

    qry_losses = []
    qry_accs = []

    for batch_idx in range(n_test_iter):
        x_spt, y_spt, x_qry, y_qry = db.next("test")
        task_num, setsz, c_, h, w = x_spt.size()

        # TODO: Maybe pull this out into a separate module so it
        # doesn't have to be duplicated between `train` and `test`?
        n_inner_iter = 5

        for i in range(task_num):
            new_params = params
            for _ in range(n_inner_iter):
                spt_logits = fnet(new_params, buffers, x_spt[i])
                spt_loss = F.cross_entropy(spt_logits, y_spt[i])
                grads = torch.autograd.grad(spt_loss, new_params)
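                # No `create_graph=True` here: at test time we only adapt and
                # evaluate, so there is no need to backprop through the inner
                # updates.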
                new_params = [p - g * 1e-1 for p, g in zip(new_params, grads)]

            # The query loss and acc induced by these parameters.
            qry_logits = fnet(new_params, buffers, x_qry[i]).detach()
            qry_loss = F.cross_entropy(qry_logits, y_qry[i], reduction="none")
            qry_losses.append(qry_loss.detach())
            qry_accs.append((qry_logits.argmax(dim=1) == y_qry[i]).detach())

    qry_losses = torch.cat(qry_losses).mean().item()
    qry_accs = 100.0 * torch.cat(qry_accs).float().mean().item()
    print(f"[Epoch {epoch+1:.2f}] Test Loss: {qry_losses:.2f} | Acc: {qry_accs:.2f}")
    log.append(
        {
            "epoch": epoch + 1,
            "loss": qry_losses,
            "acc": qry_accs,
            "mode": "test",
            "time": time.time(),
        }
    )


def plot(log):
    # Generally you should pull your plotting code out of your training
    # script but we are doing it here for brevity.
    df = pd.DataFrame(log)

    fig, ax = plt.subplots(figsize=(6, 4))
    train_df = df[df["mode"] == "train"]
    test_df = df[df["mode"] == "test"]
    ax.plot(train_df["epoch"], train_df["acc"], label="Train")
    ax.plot(test_df["epoch"], test_df["acc"], label="Test")
    ax.set_xlabel("Epoch")
    ax.set_ylabel("Accuracy")
    ax.set_ylim(70, 100)
    fig.legend(ncol=2, loc="lower right")
    fig.tight_layout()
    fname = "maml-accs.png"
    print(f"--- Plotting accuracy to {fname}")
    fig.savefig(fname)
    plt.close(fig)


# nn.Flatten is now available in torch.nn (see the PR below), so this local
# module is only kept to keep the example self-contained:
# https://github.com/pytorch/pytorch/pull/22245
class Flatten(nn.Module):
    def forward(self, input):
        return input.view(input.size(0), -1)


if __name__ == "__main__":
    main()