#!/usr/bin/env python3
#
# Copyright (c) Facebook, Inc. and its affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
This example shows how to use higher to do Model Agnostic Meta Learning (MAML)
for few-shot Omniglot classification.
For more details see the original MAML paper:
https://arxiv.org/abs/1703.03400

This code has been modified from Jackie Loong's PyTorch MAML implementation:
https://github.com/dragen1860/MAML-Pytorch/blob/master/omniglot_train.py

Our MAML++ fork and experiments are available at:
https://github.com/bamos/HowToTrainYourMAMLPytorch
"""

import argparse
import time

import higher
import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from support.omniglot_loaders import OmniglotNShot

import torch
import torch.nn.functional as F
import torch.optim as optim
from torch import nn


mpl.use("Agg")
plt.style.use("bmh")


def main():
    argparser = argparse.ArgumentParser()
    argparser.add_argument("--n-way", "--n_way", type=int, help="n way", default=5)
    argparser.add_argument(
        "--k-spt", "--k_spt", type=int, help="k shot for support set", default=5
    )
    argparser.add_argument(
        "--k-qry", "--k_qry", type=int, help="k shot for query set", default=15
    )
    argparser.add_argument("--device", type=str, help="device", default="cuda")
    argparser.add_argument(
        "--task-num",
        "--task_num",
        type=int,
        help="meta batch size, namely task num",
        default=32,
    )
    argparser.add_argument("--seed", type=int, help="random seed", default=1)
    args = argparser.parse_args()

    torch.manual_seed(args.seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(args.seed)
    np.random.seed(args.seed)

    # Set up the Omniglot loader.
    device = args.device
    db = OmniglotNShot(
        "/tmp/omniglot-data",
        batchsz=args.task_num,
        n_way=args.n_way,
        k_shot=args.k_spt,
        k_query=args.k_qry,
        imgsz=28,
        device=device,
    )

    # Create a vanilla PyTorch neural network that will be
    # automatically monkey-patched by higher later.
    # Before higher, models could *not* be created like this
    # and the parameters needed to be manually updated and copied
    # for the updates.
    net = nn.Sequential(
        nn.Conv2d(1, 64, 3),
        nn.BatchNorm2d(64, momentum=1, affine=True),
        nn.ReLU(inplace=True),
        nn.MaxPool2d(2, 2),
        nn.Conv2d(64, 64, 3),
        nn.BatchNorm2d(64, momentum=1, affine=True),
        nn.ReLU(inplace=True),
        nn.MaxPool2d(2, 2),
        nn.Conv2d(64, 64, 3),
        nn.BatchNorm2d(64, momentum=1, affine=True),
        nn.ReLU(inplace=True),
        nn.MaxPool2d(2, 2),
        Flatten(),
        nn.Linear(64, args.n_way),
    ).to(device)
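
    # For reference only: a rough sketch of the differentiable inner-loop
    # adaptation that higher automates in `train` below. The names
    # `functional_forward` and `inner_lr` are hypothetical and not part of
    # this script or of higher's API:
    #
    #   adapted = list(net.parameters())
    #   for _ in range(n_inner_iter):
    #       loss = F.cross_entropy(functional_forward(adapted, x_spt[i]), y_spt[i])
    #       grads = torch.autograd.grad(loss, adapted, create_graph=True)
    #       adapted = [p - inner_lr * g for p, g in zip(adapted, grads)]
    #
    # higher.innerloop_ctx instead hands back a monkey-patched functional
    # version of `net` (`fnet`) and a differentiable optimizer (`diffopt`)
    # that keep this unrolled update inside the autograd graph.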

    # We will use Adam to (meta-)optimize the initial parameters
    # to be adapted.
    meta_opt = optim.Adam(net.parameters(), lr=1e-3)

    log = []
    for epoch in range(100):
        train(db, net, device, meta_opt, epoch, log)
        test(db, net, device, epoch, log)
        plot(log)


def train(db, net, device, meta_opt, epoch, log):
    net.train()
    n_train_iter = db.x_train.shape[0] // db.batchsz

    for batch_idx in range(n_train_iter):
        start_time = time.time()
        # Sample a batch of support and query images and labels.
        x_spt, y_spt, x_qry, y_qry = db.next()

        task_num, setsz, c_, h, w = x_spt.size()
        querysz = x_qry.size(1)

        # TODO: Maybe pull this out into a separate module so it
        # doesn't have to be duplicated between `train` and `test`?

        # Initialize the inner optimizer to adapt the parameters to
        # the support set.
        n_inner_iter = 5
        inner_opt = torch.optim.SGD(net.parameters(), lr=1e-1)

        qry_losses = []
        qry_accs = []
        meta_opt.zero_grad()
        for i in range(task_num):
            with higher.innerloop_ctx(net, inner_opt, copy_initial_weights=False) as (
                fnet,
                diffopt,
            ):
                # Optimize the likelihood of the support set by taking
                # gradient steps w.r.t. the model's parameters.
                # This adapts the model's meta-parameters to the task.
                # higher is able to automatically keep copies of
                # your network's parameters as they are being updated.
                for _ in range(n_inner_iter):
                    spt_logits = fnet(x_spt[i])
                    spt_loss = F.cross_entropy(spt_logits, y_spt[i])
                    diffopt.step(spt_loss)

                # The final set of adapted parameters will induce some
                # final loss and accuracy on the query dataset.
                # These will be used to update the model's meta-parameters.
                qry_logits = fnet(x_qry[i])
                qry_loss = F.cross_entropy(qry_logits, y_qry[i])
                qry_losses.append(qry_loss.detach())
                qry_acc = (qry_logits.argmax(dim=1) == y_qry[i]).sum().item() / querysz
                qry_accs.append(qry_acc)

                # print([b.shape for b in fnet[1].buffers()])

                # Update the model's meta-parameters to optimize the query
                # losses across all of the tasks sampled in this batch.
                # This unrolls through the gradient steps.
                qry_loss.backward()

        meta_opt.step()
        qry_losses = sum(qry_losses) / task_num
        qry_accs = 100.0 * sum(qry_accs) / task_num
        i = epoch + float(batch_idx) / n_train_iter
        iter_time = time.time() - start_time
        if batch_idx % 4 == 0:
            print(
                f"[Epoch {i:.2f}] Train Loss: {qry_losses:.2f} | Acc: {qry_accs:.2f} | Time: {iter_time:.2f}"
            )

        log.append(
            {
                "epoch": i,
                "loss": qry_losses,
                "acc": qry_accs,
                "mode": "train",
                "time": time.time(),
            }
        )
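

# A note on the two `higher.innerloop_ctx` configurations used in this script
# (both keyword arguments are part of higher's documented API):
# - `train` passes `copy_initial_weights=False`, so the patched module's
#   initial parameters are `net`'s own parameters and the query losses
#   backpropagate through the unrolled inner loop into the meta-parameters.
# - `test` passes `track_higher_grads=False`, which skips building that
#   differentiable graph (no meta-update happens at evaluation time) and
#   saves memory.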


def test(db, net, device, epoch, log):
    # Crucially, we do *not* fine-tune the model during testing, for
    # simplicity. Most research papers using MAML for this task do an
    # extra stage of fine-tuning here, which should be added if you are
    # adapting this code for research.
    net.train()
    n_test_iter = db.x_test.shape[0] // db.batchsz

    qry_losses = []
    qry_accs = []

    for _ in range(n_test_iter):
        x_spt, y_spt, x_qry, y_qry = db.next("test")

        task_num, setsz, c_, h, w = x_spt.size()

        # TODO: Maybe pull this out into a separate module so it
        # doesn't have to be duplicated between `train` and `test`?
        n_inner_iter = 5
        inner_opt = torch.optim.SGD(net.parameters(), lr=1e-1)

        for i in range(task_num):
            with higher.innerloop_ctx(net, inner_opt, track_higher_grads=False) as (
                fnet,
                diffopt,
            ):
                # Optimize the likelihood of the support set by taking
                # gradient steps w.r.t. the model's parameters.
                # This adapts the model's meta-parameters to the task.
                for _ in range(n_inner_iter):
                    spt_logits = fnet(x_spt[i])
                    spt_loss = F.cross_entropy(spt_logits, y_spt[i])
                    diffopt.step(spt_loss)

                # The query loss and acc induced by these parameters.
                qry_logits = fnet(x_qry[i]).detach()
                qry_loss = F.cross_entropy(qry_logits, y_qry[i], reduction="none")
                qry_losses.append(qry_loss.detach())
                qry_accs.append((qry_logits.argmax(dim=1) == y_qry[i]).detach())

    qry_losses = torch.cat(qry_losses).mean().item()
    qry_accs = 100.0 * torch.cat(qry_accs).float().mean().item()
    print(f"[Epoch {epoch+1:.2f}] Test Loss: {qry_losses:.2f} | Acc: {qry_accs:.2f}")
    log.append(
        {
            "epoch": epoch + 1,
            "loss": qry_losses,
            "acc": qry_accs,
            "mode": "test",
            "time": time.time(),
        }
    )


def plot(log):
    # Generally you should pull your plotting code out of your training
    # script, but we are doing it here for brevity.
    df = pd.DataFrame(log)

    fig, ax = plt.subplots(figsize=(6, 4))
    train_df = df[df["mode"] == "train"]
    test_df = df[df["mode"] == "test"]
    ax.plot(train_df["epoch"], train_df["acc"], label="Train")
    ax.plot(test_df["epoch"], test_df["acc"], label="Test")
    ax.set_xlabel("Epoch")
    ax.set_ylabel("Accuracy")
    ax.set_ylim(70, 100)
    fig.legend(ncol=2, loc="lower right")
    fig.tight_layout()
    fname = "maml-accs.png"
    print(f"--- Plotting accuracy to {fname}")
    fig.savefig(fname)
    plt.close(fig)


# Won't need this after this PR is merged in:
# https://github.com/pytorch/pytorch/pull/22245
class Flatten(nn.Module):
    def forward(self, input):
        return input.view(input.size(0), -1)


if __name__ == "__main__":
    main()