#!/usr/bin/env python3
#
# Copyright (c) Facebook, Inc. and its affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
This example shows how to use functorch to do Model Agnostic Meta Learning
(MAML) for few-shot Omniglot classification.
For more details see the original MAML paper:
https://arxiv.org/abs/1703.03400

This code has been modified from Jackie Loong's PyTorch MAML implementation:
https://github.com/dragen1860/MAML-Pytorch/blob/master/omniglot_train.py

Our MAML++ fork and experiments are available at:
https://github.com/bamos/HowToTrainYourMAMLPytorch
"""

import argparse
import time

import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from support.omniglot_loaders import OmniglotNShot

import torch
import torch.nn.functional as F
import torch.optim as optim
from functorch import make_functional_with_buffers
from torch import nn


mpl.use("Agg")
plt.style.use("bmh")


def main():
    argparser = argparse.ArgumentParser()
    argparser.add_argument("--n-way", "--n_way", type=int, help="n way", default=5)
    argparser.add_argument(
        "--k-spt", "--k_spt", type=int, help="k shot for support set", default=5
    )
    argparser.add_argument(
        "--k-qry", "--k_qry", type=int, help="k shot for query set", default=15
    )
    argparser.add_argument("--device", type=str, help="device", default="cuda")
    argparser.add_argument(
        "--task-num",
        "--task_num",
        type=int,
        help="meta batch size, namely task num",
        default=32,
    )
    argparser.add_argument("--seed", type=int, help="random seed", default=1)
    args = argparser.parse_args()

    torch.manual_seed(args.seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(args.seed)
    np.random.seed(args.seed)

    # Set up the Omniglot loader.
    device = args.device
    db = OmniglotNShot(
        "/tmp/omniglot-data",
        batchsz=args.task_num,
        n_way=args.n_way,
        k_shot=args.k_spt,
        k_query=args.k_qry,
        imgsz=28,
        device=device,
    )

    # Create a vanilla PyTorch neural network. make_functional_with_buffers
    # below turns it into a pure function of (params, buffers, inputs), so
    # the per-task adapted parameters can be created and updated explicitly
    # without mutating the module's own state.
    net = nn.Sequential(
        nn.Conv2d(1, 64, 3),
        nn.BatchNorm2d(64, momentum=1, affine=True),
        nn.ReLU(inplace=True),
        nn.MaxPool2d(2, 2),
        nn.Conv2d(64, 64, 3),
        nn.BatchNorm2d(64, momentum=1, affine=True),
        nn.ReLU(inplace=True),
        nn.MaxPool2d(2, 2),
        nn.Conv2d(64, 64, 3),
        nn.BatchNorm2d(64, momentum=1, affine=True),
        nn.ReLU(inplace=True),
        nn.MaxPool2d(2, 2),
        Flatten(),
        nn.Linear(64, args.n_way),
    ).to(device)

    net.train()
    fnet, params, buffers = make_functional_with_buffers(net)

    # We will use Adam to (meta-)optimize the initial parameters
    # to be adapted.
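    # Note: the `params` returned above behave like a tuple of trainable leaf
    # tensors, so they can be handed directly to a standard torch.optim
    # optimizer here and differentiated through with torch.autograd.grad in
    # the inner loop below.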
    meta_opt = optim.Adam(params, lr=1e-3)

    log = []
    for epoch in range(100):
        train(db, [params, buffers, fnet], device, meta_opt, epoch, log)
        test(db, [params, buffers, fnet], device, epoch, log)
        plot(log)


def train(db, net, device, meta_opt, epoch, log):
    params, buffers, fnet = net
    n_train_iter = db.x_train.shape[0] // db.batchsz

    for batch_idx in range(n_train_iter):
        start_time = time.time()
        # Sample a batch of support and query images and labels.
        x_spt, y_spt, x_qry, y_qry = db.next()

        task_num, setsz, c_, h, w = x_spt.size()
        querysz = x_qry.size(1)

        # TODO: Maybe pull this out into a separate module so it
        # doesn't have to be duplicated between `train` and `test`?

        # The inner loop adapts the parameters to the support set with a
        # few manual SGD steps (learning rate 1e-1).
        n_inner_iter = 5

        qry_losses = []
        qry_accs = []
        meta_opt.zero_grad()
        for i in range(task_num):
            # Optimize the likelihood of the support set by taking
            # gradient steps w.r.t. the model's parameters.
            # This adapts the model's meta-parameters to the task.
            new_params = params
            for _ in range(n_inner_iter):
                spt_logits = fnet(new_params, buffers, x_spt[i])
                spt_loss = F.cross_entropy(spt_logits, y_spt[i])
                grads = torch.autograd.grad(spt_loss, new_params, create_graph=True)
                new_params = [p - g * 1e-1 for p, g in zip(new_params, grads)]

            # The final set of adapted parameters will induce some
            # final loss and accuracy on the query dataset.
            # These will be used to update the model's meta-parameters.
            qry_logits = fnet(new_params, buffers, x_qry[i])
            qry_loss = F.cross_entropy(qry_logits, y_qry[i])
            qry_losses.append(qry_loss.detach())
            qry_acc = (qry_logits.argmax(dim=1) == y_qry[i]).sum().item() / querysz
            qry_accs.append(qry_acc)

            # Update the model's meta-parameters to optimize the query
            # losses across all of the tasks sampled in this batch.
            # This unrolls through the gradient steps.
            qry_loss.backward()

        meta_opt.step()
        qry_losses = sum(qry_losses) / task_num
        qry_accs = 100.0 * sum(qry_accs) / task_num
        i = epoch + float(batch_idx) / n_train_iter
        iter_time = time.time() - start_time
        if batch_idx % 4 == 0:
            print(
                f"[Epoch {i:.2f}] Train Loss: {qry_losses:.2f} | Acc: {qry_accs:.2f} | Time: {iter_time:.2f}"
            )

        log.append(
            {
                "epoch": i,
                "loss": qry_losses,
                "acc": qry_accs,
                "mode": "train",
                "time": time.time(),
            }
        )


def test(db, net, device, epoch, log):
    # Crucially, in our testing procedure here we do *not* fine-tune
    # the model during testing, for simplicity.
    # Most research papers using MAML for this task do an extra
    # stage of fine-tuning here that should be added if you are
    # adapting this code for research.
    [params, buffers, fnet] = net
    n_test_iter = db.x_test.shape[0] // db.batchsz

    qry_losses = []
    qry_accs = []

    for batch_idx in range(n_test_iter):
        x_spt, y_spt, x_qry, y_qry = db.next("test")
        task_num, setsz, c_, h, w = x_spt.size()

        # TODO: Maybe pull this out into a separate module so it
        # doesn't have to be duplicated between `train` and `test`?
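        # Adapt to each task's support set with the same manual SGD inner
        # loop as in `train`, but without `create_graph=True`, since no
        # meta-gradient is needed at test time.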
        n_inner_iter = 5

        for i in range(task_num):
            new_params = params
            for _ in range(n_inner_iter):
                spt_logits = fnet(new_params, buffers, x_spt[i])
                spt_loss = F.cross_entropy(spt_logits, y_spt[i])
                grads = torch.autograd.grad(spt_loss, new_params)
                new_params = [p - g * 1e-1 for p, g in zip(new_params, grads)]

            # The query loss and acc induced by these parameters.
            qry_logits = fnet(new_params, buffers, x_qry[i]).detach()
            qry_loss = F.cross_entropy(qry_logits, y_qry[i], reduction="none")
            qry_losses.append(qry_loss.detach())
            qry_accs.append((qry_logits.argmax(dim=1) == y_qry[i]).detach())

    qry_losses = torch.cat(qry_losses).mean().item()
    qry_accs = 100.0 * torch.cat(qry_accs).float().mean().item()
    print(f"[Epoch {epoch+1:.2f}] Test Loss: {qry_losses:.2f} | Acc: {qry_accs:.2f}")
    log.append(
        {
            "epoch": epoch + 1,
            "loss": qry_losses,
            "acc": qry_accs,
            "mode": "test",
            "time": time.time(),
        }
    )


def plot(log):
    # Generally you should pull your plotting code out of your training
    # script, but we are doing it here for brevity.
    df = pd.DataFrame(log)

    fig, ax = plt.subplots(figsize=(6, 4))
    train_df = df[df["mode"] == "train"]
    test_df = df[df["mode"] == "test"]
    ax.plot(train_df["epoch"], train_df["acc"], label="Train")
    ax.plot(test_df["epoch"], test_df["acc"], label="Test")
    ax.set_xlabel("Epoch")
    ax.set_ylabel("Accuracy")
    ax.set_ylim(70, 100)
    fig.legend(ncol=2, loc="lower right")
    fig.tight_layout()
    fname = "maml-accs.png"
    print(f"--- Plotting accuracy to {fname}")
    fig.savefig(fname)
    plt.close(fig)


# Won't need this after this PR is merged in:
# https://github.com/pytorch/pytorch/pull/22245
class Flatten(nn.Module):
    def forward(self, input):
        return input.view(input.size(0), -1)


if __name__ == "__main__":
    main()
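# Example invocation (a sketch; the filename `maml-omniglot.py` is assumed,
# and `--device cpu` can be passed on machines without a CUDA GPU):
#
#   python maml-omniglot.py --n-way 5 --k-spt 5 --k-qry 15 --task-num 32
#
# Train and test accuracy curves are written to maml-accs.png by plot().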