1*da0073e9SAndroid Build Coastguard Worker{ 2*da0073e9SAndroid Build Coastguard Worker "cells": [ 3*da0073e9SAndroid Build Coastguard Worker { 4*da0073e9SAndroid Build Coastguard Worker "cell_type": "markdown", 5*da0073e9SAndroid Build Coastguard Worker "id": "de1548fb-a313-4e9c-ae5d-8ec4c12ddd94", 6*da0073e9SAndroid Build Coastguard Worker "metadata": { 7*da0073e9SAndroid Build Coastguard Worker "id": "de1548fb-a313-4e9c-ae5d-8ec4c12ddd94" 8*da0073e9SAndroid Build Coastguard Worker }, 9*da0073e9SAndroid Build Coastguard Worker "source": [ 10*da0073e9SAndroid Build Coastguard Worker "# Model ensembling\n", 11*da0073e9SAndroid Build Coastguard Worker "\n", 12*da0073e9SAndroid Build Coastguard Worker "This example illustrates how to vectorize model ensembling using vmap.\n", 13*da0073e9SAndroid Build Coastguard Worker "\n", 14*da0073e9SAndroid Build Coastguard Worker "<a href=\"https://colab.research.google.com/github/pytorch/pytorch/blob/master/functorch/notebooks/ensembling.ipynb\">\n", 15*da0073e9SAndroid Build Coastguard Worker " <img style=\"width: auto\" src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/>\n", 16*da0073e9SAndroid Build Coastguard Worker "</a>\n", 17*da0073e9SAndroid Build Coastguard Worker "\n", 18*da0073e9SAndroid Build Coastguard Worker "## What is model ensembling?\n", 19*da0073e9SAndroid Build Coastguard Worker "Model ensembling combines the predictions from multiple models together.\n", 20*da0073e9SAndroid Build Coastguard Worker "Traditionally this is done by running each model on some inputs separately\n", 21*da0073e9SAndroid Build Coastguard Worker "and then combining the predictions. However, if you're running models with\n", 22*da0073e9SAndroid Build Coastguard Worker "the same architecture, then it may be possible to combine them together\n", 23*da0073e9SAndroid Build Coastguard Worker "using `vmap`. 
`vmap` is a function transform that maps functions across\n", 24*da0073e9SAndroid Build Coastguard Worker "dimensions of the input tensors. One of its use cases is eliminating\n", 25*da0073e9SAndroid Build Coastguard Worker "for-loops and speeding them up through vectorization.\n", 26*da0073e9SAndroid Build Coastguard Worker "\n", 27*da0073e9SAndroid Build Coastguard Worker "Let's demonstrate how to do this using an ensemble of simple MLPs." 28*da0073e9SAndroid Build Coastguard Worker ] 29*da0073e9SAndroid Build Coastguard Worker }, 30*da0073e9SAndroid Build Coastguard Worker { 31*da0073e9SAndroid Build Coastguard Worker "cell_type": "code", 32*da0073e9SAndroid Build Coastguard Worker "source": [ 33*da0073e9SAndroid Build Coastguard Worker "import torch\n", 34*da0073e9SAndroid Build Coastguard Worker "import torch.nn as nn\n", 35*da0073e9SAndroid Build Coastguard Worker "import torch.nn.functional as F\n", 36*da0073e9SAndroid Build Coastguard Worker "from functools import partial\n", 37*da0073e9SAndroid Build Coastguard Worker "torch.manual_seed(0);" 38*da0073e9SAndroid Build Coastguard Worker ], 39*da0073e9SAndroid Build Coastguard Worker "metadata": { 40*da0073e9SAndroid Build Coastguard Worker "id": "Gb-yt4VKUUuc" 41*da0073e9SAndroid Build Coastguard Worker }, 42*da0073e9SAndroid Build Coastguard Worker "execution_count": null, 43*da0073e9SAndroid Build Coastguard Worker "outputs": [], 44*da0073e9SAndroid Build Coastguard Worker "id": "Gb-yt4VKUUuc" 45*da0073e9SAndroid Build Coastguard Worker }, 46*da0073e9SAndroid Build Coastguard Worker { 47*da0073e9SAndroid Build Coastguard Worker "cell_type": "code", 48*da0073e9SAndroid Build Coastguard Worker "source": [ 49*da0073e9SAndroid Build Coastguard Worker "# Here's a simple MLP\n", 50*da0073e9SAndroid Build Coastguard Worker "class SimpleMLP(nn.Module):\n", 51*da0073e9SAndroid Build Coastguard Worker " def __init__(self):\n", 52*da0073e9SAndroid Build Coastguard Worker " super().__init__()\n", 53*da0073e9SAndroid 
Build Coastguard Worker " self.fc1 = nn.Linear(784, 128)\n", 54*da0073e9SAndroid Build Coastguard Worker " self.fc2 = nn.Linear(128, 128)\n", 55*da0073e9SAndroid Build Coastguard Worker " self.fc3 = nn.Linear(128, 10)\n", 56*da0073e9SAndroid Build Coastguard Worker "\n", 57*da0073e9SAndroid Build Coastguard Worker " def forward(self, x):\n", 58*da0073e9SAndroid Build Coastguard Worker " x = x.flatten(1)\n", 59*da0073e9SAndroid Build Coastguard Worker " x = self.fc1(x)\n", 60*da0073e9SAndroid Build Coastguard Worker " x = F.relu(x)\n", 61*da0073e9SAndroid Build Coastguard Worker " x = self.fc2(x)\n", 62*da0073e9SAndroid Build Coastguard Worker " x = F.relu(x)\n", 63*da0073e9SAndroid Build Coastguard Worker " x = self.fc3(x)\n", 64*da0073e9SAndroid Build Coastguard Worker " return x\n" 65*da0073e9SAndroid Build Coastguard Worker ], 66*da0073e9SAndroid Build Coastguard Worker "metadata": { 67*da0073e9SAndroid Build Coastguard Worker "id": "tf-HKHjUUbyY" 68*da0073e9SAndroid Build Coastguard Worker }, 69*da0073e9SAndroid Build Coastguard Worker "execution_count": null, 70*da0073e9SAndroid Build Coastguard Worker "outputs": [], 71*da0073e9SAndroid Build Coastguard Worker "id": "tf-HKHjUUbyY" 72*da0073e9SAndroid Build Coastguard Worker }, 73*da0073e9SAndroid Build Coastguard Worker { 74*da0073e9SAndroid Build Coastguard Worker "cell_type": "markdown", 75*da0073e9SAndroid Build Coastguard Worker "source": [ 76*da0073e9SAndroid Build Coastguard Worker "Let’s generate a batch of dummy data and pretend that we’re working with an MNIST dataset. Thus, the dummy images are 28 by 28, and we have a minibatch of size 64. Furthermore, let’s say we want to combine the predictions from 10 different models. 
\n" 77*da0073e9SAndroid Build Coastguard Worker ], 78*da0073e9SAndroid Build Coastguard Worker "metadata": { 79*da0073e9SAndroid Build Coastguard Worker "id": "VEDPe-EoU5Fa" 80*da0073e9SAndroid Build Coastguard Worker }, 81*da0073e9SAndroid Build Coastguard Worker "id": "VEDPe-EoU5Fa" 82*da0073e9SAndroid Build Coastguard Worker }, 83*da0073e9SAndroid Build Coastguard Worker { 84*da0073e9SAndroid Build Coastguard Worker "cell_type": "code", 85*da0073e9SAndroid Build Coastguard Worker "source": [ 86*da0073e9SAndroid Build Coastguard Worker "device = 'cuda'\n", 87*da0073e9SAndroid Build Coastguard Worker "num_models = 10\n", 88*da0073e9SAndroid Build Coastguard Worker "\n", 89*da0073e9SAndroid Build Coastguard Worker "data = torch.randn(100, 64, 1, 28, 28, device=device)\n", 90*da0073e9SAndroid Build Coastguard Worker "targets = torch.randint(10, (6400,), device=device)\n", 91*da0073e9SAndroid Build Coastguard Worker "\n", 92*da0073e9SAndroid Build Coastguard Worker "models = [SimpleMLP().to(device) for _ in range(num_models)]" 93*da0073e9SAndroid Build Coastguard Worker ], 94*da0073e9SAndroid Build Coastguard Worker "metadata": { 95*da0073e9SAndroid Build Coastguard Worker "id": "WB2Qe3AHUvPN" 96*da0073e9SAndroid Build Coastguard Worker }, 97*da0073e9SAndroid Build Coastguard Worker "execution_count": null, 98*da0073e9SAndroid Build Coastguard Worker "outputs": [], 99*da0073e9SAndroid Build Coastguard Worker "id": "WB2Qe3AHUvPN" 100*da0073e9SAndroid Build Coastguard Worker }, 101*da0073e9SAndroid Build Coastguard Worker { 102*da0073e9SAndroid Build Coastguard Worker "cell_type": "markdown", 103*da0073e9SAndroid Build Coastguard Worker "source": [ 104*da0073e9SAndroid Build Coastguard Worker "We have a couple of options for generating predictions. Maybe we want to give each model a different randomized minibatch of data. Alternatively, maybe we want to run the same minibatch of data through each model (e.g. 
if we were testing the effect of different model initializations).\n", 105*da0073e9SAndroid Build Coastguard Worker "\n", 106*da0073e9SAndroid Build Coastguard Worker "\n", 107*da0073e9SAndroid Build Coastguard Worker "\n" 108*da0073e9SAndroid Build Coastguard Worker ], 109*da0073e9SAndroid Build Coastguard Worker "metadata": { 110*da0073e9SAndroid Build Coastguard Worker "id": "GOGJ-OUxVcT5" 111*da0073e9SAndroid Build Coastguard Worker }, 112*da0073e9SAndroid Build Coastguard Worker "id": "GOGJ-OUxVcT5" 113*da0073e9SAndroid Build Coastguard Worker }, 114*da0073e9SAndroid Build Coastguard Worker { 115*da0073e9SAndroid Build Coastguard Worker "cell_type": "markdown", 116*da0073e9SAndroid Build Coastguard Worker "source": [ 117*da0073e9SAndroid Build Coastguard Worker "Option 1: different minibatch for each model" 118*da0073e9SAndroid Build Coastguard Worker ], 119*da0073e9SAndroid Build Coastguard Worker "metadata": { 120*da0073e9SAndroid Build Coastguard Worker "id": "CwJBb09MxCN3" 121*da0073e9SAndroid Build Coastguard Worker }, 122*da0073e9SAndroid Build Coastguard Worker "id": "CwJBb09MxCN3" 123*da0073e9SAndroid Build Coastguard Worker }, 124*da0073e9SAndroid Build Coastguard Worker { 125*da0073e9SAndroid Build Coastguard Worker "cell_type": "code", 126*da0073e9SAndroid Build Coastguard Worker "source": [ 127*da0073e9SAndroid Build Coastguard Worker "minibatches = data[:num_models]\n", 128*da0073e9SAndroid Build Coastguard Worker "predictions_diff_minibatch_loop = [model(minibatch) for model, minibatch in zip(models, minibatches)]" 129*da0073e9SAndroid Build Coastguard Worker ], 130*da0073e9SAndroid Build Coastguard Worker "metadata": { 131*da0073e9SAndroid Build Coastguard Worker "id": "WYjMx8QTUvRu" 132*da0073e9SAndroid Build Coastguard Worker }, 133*da0073e9SAndroid Build Coastguard Worker "execution_count": null, 134*da0073e9SAndroid Build Coastguard Worker "outputs": [], 135*da0073e9SAndroid Build Coastguard Worker "id": "WYjMx8QTUvRu" 136*da0073e9SAndroid 
Build Coastguard Worker }, 137*da0073e9SAndroid Build Coastguard Worker { 138*da0073e9SAndroid Build Coastguard Worker "cell_type": "markdown", 139*da0073e9SAndroid Build Coastguard Worker "source": [ 140*da0073e9SAndroid Build Coastguard Worker "Option 2: Same minibatch" 141*da0073e9SAndroid Build Coastguard Worker ], 142*da0073e9SAndroid Build Coastguard Worker "metadata": { 143*da0073e9SAndroid Build Coastguard Worker "id": "HNw4_IVzU5Pz" 144*da0073e9SAndroid Build Coastguard Worker }, 145*da0073e9SAndroid Build Coastguard Worker "id": "HNw4_IVzU5Pz" 146*da0073e9SAndroid Build Coastguard Worker }, 147*da0073e9SAndroid Build Coastguard Worker { 148*da0073e9SAndroid Build Coastguard Worker "cell_type": "code", 149*da0073e9SAndroid Build Coastguard Worker "source": [ 150*da0073e9SAndroid Build Coastguard Worker "minibatch = data[0]\n", 151*da0073e9SAndroid Build Coastguard Worker "predictions2 = [model(minibatch) for model in models]" 152*da0073e9SAndroid Build Coastguard Worker ], 153*da0073e9SAndroid Build Coastguard Worker "metadata": { 154*da0073e9SAndroid Build Coastguard Worker "id": "vUsb3VfexJrY" 155*da0073e9SAndroid Build Coastguard Worker }, 156*da0073e9SAndroid Build Coastguard Worker "execution_count": null, 157*da0073e9SAndroid Build Coastguard Worker "outputs": [], 158*da0073e9SAndroid Build Coastguard Worker "id": "vUsb3VfexJrY" 159*da0073e9SAndroid Build Coastguard Worker }, 160*da0073e9SAndroid Build Coastguard Worker { 161*da0073e9SAndroid Build Coastguard Worker "cell_type": "markdown", 162*da0073e9SAndroid Build Coastguard Worker "source": [ 163*da0073e9SAndroid Build Coastguard Worker "## Using vmap to vectorize the ensemble\n", 164*da0073e9SAndroid Build Coastguard Worker "\n", 165*da0073e9SAndroid Build Coastguard Worker "\n", 166*da0073e9SAndroid Build Coastguard Worker "\n", 167*da0073e9SAndroid Build Coastguard Worker "\n" 168*da0073e9SAndroid Build Coastguard Worker ], 169*da0073e9SAndroid Build Coastguard Worker "metadata": { 
170*da0073e9SAndroid Build Coastguard Worker "id": "aNkX6lFIxzcm" 171*da0073e9SAndroid Build Coastguard Worker }, 172*da0073e9SAndroid Build Coastguard Worker "id": "aNkX6lFIxzcm" 173*da0073e9SAndroid Build Coastguard Worker }, 174*da0073e9SAndroid Build Coastguard Worker { 175*da0073e9SAndroid Build Coastguard Worker "cell_type": "markdown", 176*da0073e9SAndroid Build Coastguard Worker "source": [ 177*da0073e9SAndroid Build Coastguard Worker "Let’s use vmap to speed up the for-loop. We must first prepare the models for use with vmap.\n", 178*da0073e9SAndroid Build Coastguard Worker "\n", 179*da0073e9SAndroid Build Coastguard Worker "First, let’s combine the states of the models together by stacking each parameter. For example, `models[i].fc1.weight` has shape `[128, 784]` (`nn.Linear` stores its weight as `[out_features, in_features]`); we are going to stack the `.fc1.weight` of each of the 10 models to produce a big weight of shape `[10, 128, 784]`.\n", 180*da0073e9SAndroid Build Coastguard Worker "\n", 181*da0073e9SAndroid Build Coastguard Worker "functorch offers the 'combine_state_for_ensemble' convenience function to do that. 
It returns a stateless version of the model (fmodel) and stacked parameters and buffers.\n", 182*da0073e9SAndroid Build Coastguard Worker "\n" 183*da0073e9SAndroid Build Coastguard Worker ], 184*da0073e9SAndroid Build Coastguard Worker "metadata": { 185*da0073e9SAndroid Build Coastguard Worker "id": "-sFMojhryviM" 186*da0073e9SAndroid Build Coastguard Worker }, 187*da0073e9SAndroid Build Coastguard Worker "id": "-sFMojhryviM" 188*da0073e9SAndroid Build Coastguard Worker }, 189*da0073e9SAndroid Build Coastguard Worker { 190*da0073e9SAndroid Build Coastguard Worker "cell_type": "code", 191*da0073e9SAndroid Build Coastguard Worker "source": [ 192*da0073e9SAndroid Build Coastguard Worker "from functorch import combine_state_for_ensemble\n", 193*da0073e9SAndroid Build Coastguard Worker "\n", 194*da0073e9SAndroid Build Coastguard Worker "fmodel, params, buffers = combine_state_for_ensemble(models)\n", 195*da0073e9SAndroid Build Coastguard Worker "[p.requires_grad_() for p in params];\n" 196*da0073e9SAndroid Build Coastguard Worker ], 197*da0073e9SAndroid Build Coastguard Worker "metadata": { 198*da0073e9SAndroid Build Coastguard Worker "id": "C3a9_clvyPho" 199*da0073e9SAndroid Build Coastguard Worker }, 200*da0073e9SAndroid Build Coastguard Worker "execution_count": null, 201*da0073e9SAndroid Build Coastguard Worker "outputs": [], 202*da0073e9SAndroid Build Coastguard Worker "id": "C3a9_clvyPho" 203*da0073e9SAndroid Build Coastguard Worker }, 204*da0073e9SAndroid Build Coastguard Worker { 205*da0073e9SAndroid Build Coastguard Worker "cell_type": "markdown", 206*da0073e9SAndroid Build Coastguard Worker "source": [ 207*da0073e9SAndroid Build Coastguard Worker "Option 1: get predictions using a different minibatch for each model. \n", 208*da0073e9SAndroid Build Coastguard Worker "\n", 209*da0073e9SAndroid Build Coastguard Worker "By default, vmap maps a function across the first dimension of all inputs to the passed-in function. 
After using the combine_state_for_ensemble, each of the params and buffers have an additional dimension of size 'num_models' at the front, and minibatches has a dimension of size 'num_models'.\n", 210*da0073e9SAndroid Build Coastguard Worker "\n", 211*da0073e9SAndroid Build Coastguard Worker "\n", 212*da0073e9SAndroid Build Coastguard Worker "\n", 213*da0073e9SAndroid Build Coastguard Worker "\n" 214*da0073e9SAndroid Build Coastguard Worker ], 215*da0073e9SAndroid Build Coastguard Worker "metadata": { 216*da0073e9SAndroid Build Coastguard Worker "id": "mFJDWMM9yaYZ" 217*da0073e9SAndroid Build Coastguard Worker }, 218*da0073e9SAndroid Build Coastguard Worker "id": "mFJDWMM9yaYZ" 219*da0073e9SAndroid Build Coastguard Worker }, 220*da0073e9SAndroid Build Coastguard Worker { 221*da0073e9SAndroid Build Coastguard Worker "cell_type": "code", 222*da0073e9SAndroid Build Coastguard Worker "source": [ 223*da0073e9SAndroid Build Coastguard Worker "print([p.size(0) for p in params]) # show the leading 'num_models' dimension\n", 224*da0073e9SAndroid Build Coastguard Worker "\n", 225*da0073e9SAndroid Build Coastguard Worker "assert minibatches.shape == (num_models, 64, 1, 28, 28) # verify minibatch has leading dimension of size 'num_models'" 226*da0073e9SAndroid Build Coastguard Worker ], 227*da0073e9SAndroid Build Coastguard Worker "metadata": { 228*da0073e9SAndroid Build Coastguard Worker "colab": { 229*da0073e9SAndroid Build Coastguard Worker "base_uri": "https://localhost:8080/" 230*da0073e9SAndroid Build Coastguard Worker }, 231*da0073e9SAndroid Build Coastguard Worker "id": "ezuFQx1G1zLG", 232*da0073e9SAndroid Build Coastguard Worker "outputId": "ab260da3-77f2-4ff9-d843-e0d0f1e0a884" 233*da0073e9SAndroid Build Coastguard Worker }, 234*da0073e9SAndroid Build Coastguard Worker "execution_count": null, 235*da0073e9SAndroid Build Coastguard Worker "outputs": [ 236*da0073e9SAndroid Build Coastguard Worker { 237*da0073e9SAndroid Build Coastguard Worker "output_type": "stream", 
238*da0073e9SAndroid Build Coastguard Worker "name": "stdout", 239*da0073e9SAndroid Build Coastguard Worker "text": [ 240*da0073e9SAndroid Build Coastguard Worker "[10, 10, 10, 10, 10, 10]\n" 241*da0073e9SAndroid Build Coastguard Worker ] 242*da0073e9SAndroid Build Coastguard Worker } 243*da0073e9SAndroid Build Coastguard Worker ], 244*da0073e9SAndroid Build Coastguard Worker "id": "ezuFQx1G1zLG" 245*da0073e9SAndroid Build Coastguard Worker }, 246*da0073e9SAndroid Build Coastguard Worker { 247*da0073e9SAndroid Build Coastguard Worker "cell_type": "code", 248*da0073e9SAndroid Build Coastguard Worker "source": [ 249*da0073e9SAndroid Build Coastguard Worker "from functorch import vmap\n", 250*da0073e9SAndroid Build Coastguard Worker "\n", 251*da0073e9SAndroid Build Coastguard Worker "predictions1_vmap = vmap(fmodel)(params, buffers, minibatches)\n", 252*da0073e9SAndroid Build Coastguard Worker "\n", 253*da0073e9SAndroid Build Coastguard Worker "# verify the vmap predictions match the for-loop predictions\n", 254*da0073e9SAndroid Build Coastguard Worker "assert torch.allclose(predictions1_vmap, torch.stack(predictions_diff_minibatch_loop), atol=1e-3, rtol=1e-5)" 255*da0073e9SAndroid Build Coastguard Worker ], 256*da0073e9SAndroid Build Coastguard Worker "metadata": { 257*da0073e9SAndroid Build Coastguard Worker "id": "VroLnfD82DDf" 258*da0073e9SAndroid Build Coastguard Worker }, 259*da0073e9SAndroid Build Coastguard Worker "execution_count": null, 260*da0073e9SAndroid Build Coastguard Worker "outputs": [], 261*da0073e9SAndroid Build Coastguard Worker "id": "VroLnfD82DDf" 262*da0073e9SAndroid Build Coastguard Worker }, 263*da0073e9SAndroid Build Coastguard Worker { 264*da0073e9SAndroid Build Coastguard Worker "cell_type": "markdown", 265*da0073e9SAndroid Build Coastguard Worker "source": [ 266*da0073e9SAndroid Build Coastguard Worker "Option 2: get predictions using the same minibatch of data.\n", 267*da0073e9SAndroid Build Coastguard Worker "\n", 268*da0073e9SAndroid Build Coastguard 
Worker "vmap has an in_dims arg that specifies which dimensions to map over. By using `None`, we tell vmap we want the same minibatch to apply for all of the 10 models.\n", 269*da0073e9SAndroid Build Coastguard Worker "\n", 270*da0073e9SAndroid Build Coastguard Worker "\n" 271*da0073e9SAndroid Build Coastguard Worker ], 272*da0073e9SAndroid Build Coastguard Worker "metadata": { 273*da0073e9SAndroid Build Coastguard Worker "id": "tlkmyQyfY6XU" 274*da0073e9SAndroid Build Coastguard Worker }, 275*da0073e9SAndroid Build Coastguard Worker "id": "tlkmyQyfY6XU" 276*da0073e9SAndroid Build Coastguard Worker }, 277*da0073e9SAndroid Build Coastguard Worker { 278*da0073e9SAndroid Build Coastguard Worker "cell_type": "code", 279*da0073e9SAndroid Build Coastguard Worker "source": [ 280*da0073e9SAndroid Build Coastguard Worker "predictions2_vmap = vmap(fmodel, in_dims=(0, 0, None))(params, buffers, minibatch)\n", 281*da0073e9SAndroid Build Coastguard Worker "\n", 282*da0073e9SAndroid Build Coastguard Worker "assert torch.allclose(predictions2_vmap, torch.stack(predictions2), atol=1e-3, rtol=1e-5)" 283*da0073e9SAndroid Build Coastguard Worker ], 284*da0073e9SAndroid Build Coastguard Worker "metadata": { 285*da0073e9SAndroid Build Coastguard Worker "id": "WiSMupvCyecd" 286*da0073e9SAndroid Build Coastguard Worker }, 287*da0073e9SAndroid Build Coastguard Worker "execution_count": null, 288*da0073e9SAndroid Build Coastguard Worker "outputs": [], 289*da0073e9SAndroid Build Coastguard Worker "id": "WiSMupvCyecd" 290*da0073e9SAndroid Build Coastguard Worker }, 291*da0073e9SAndroid Build Coastguard Worker { 292*da0073e9SAndroid Build Coastguard Worker "cell_type": "markdown", 293*da0073e9SAndroid Build Coastguard Worker "source": [ 294*da0073e9SAndroid Build Coastguard Worker "A quick note: there are limitations around what types of functions can be transformed by vmap. 
The best functions to transform are pure functions: functions whose outputs are determined only by their inputs and that have no side effects (e.g. mutation). vmap is unable to handle mutation of arbitrary Python data structures, but it is able to handle many in-place PyTorch operations." 295*da0073e9SAndroid Build Coastguard Worker ], 296*da0073e9SAndroid Build Coastguard Worker "metadata": { 297*da0073e9SAndroid Build Coastguard Worker "id": "KrXQsUCIGLWm" 298*da0073e9SAndroid Build Coastguard Worker }, 299*da0073e9SAndroid Build Coastguard Worker "id": "KrXQsUCIGLWm" 300*da0073e9SAndroid Build Coastguard Worker }, 301*da0073e9SAndroid Build Coastguard Worker { 302*da0073e9SAndroid Build Coastguard Worker "cell_type": "markdown", 303*da0073e9SAndroid Build Coastguard Worker "source": [ 304*da0073e9SAndroid Build Coastguard Worker "## Performance\n", 305*da0073e9SAndroid Build Coastguard Worker "\n", 306*da0073e9SAndroid Build Coastguard Worker "Curious about performance numbers? Here's how the numbers look on Google Colab." 
307*da0073e9SAndroid Build Coastguard Worker ], 308*da0073e9SAndroid Build Coastguard Worker "metadata": { 309*da0073e9SAndroid Build Coastguard Worker "id": "MCjBhMrVF5hH" 310*da0073e9SAndroid Build Coastguard Worker }, 311*da0073e9SAndroid Build Coastguard Worker "id": "MCjBhMrVF5hH" 312*da0073e9SAndroid Build Coastguard Worker }, 313*da0073e9SAndroid Build Coastguard Worker { 314*da0073e9SAndroid Build Coastguard Worker "cell_type": "code", 315*da0073e9SAndroid Build Coastguard Worker "source": [ 316*da0073e9SAndroid Build Coastguard Worker "from torch.utils.benchmark import Timer\n", 317*da0073e9SAndroid Build Coastguard Worker "without_vmap = Timer(\n", 318*da0073e9SAndroid Build Coastguard Worker " stmt=\"[model(minibatch) for model, minibatch in zip(models, minibatches)]\",\n", 319*da0073e9SAndroid Build Coastguard Worker " globals=globals())\n", 320*da0073e9SAndroid Build Coastguard Worker "with_vmap = Timer(\n", 321*da0073e9SAndroid Build Coastguard Worker " stmt=\"vmap(fmodel)(params, buffers, minibatches)\",\n", 322*da0073e9SAndroid Build Coastguard Worker " globals=globals())\n", 323*da0073e9SAndroid Build Coastguard Worker "print(f'Predictions without vmap {without_vmap.timeit(100)}')\n", 324*da0073e9SAndroid Build Coastguard Worker "print(f'Predictions with vmap {with_vmap.timeit(100)}')" 325*da0073e9SAndroid Build Coastguard Worker ], 326*da0073e9SAndroid Build Coastguard Worker "metadata": { 327*da0073e9SAndroid Build Coastguard Worker "colab": { 328*da0073e9SAndroid Build Coastguard Worker "base_uri": "https://localhost:8080/" 329*da0073e9SAndroid Build Coastguard Worker }, 330*da0073e9SAndroid Build Coastguard Worker "id": "gJPrGdS0GBjz", 331*da0073e9SAndroid Build Coastguard Worker "outputId": "04e75950-b964-419c-fa9c-f1590e0081bb" 332*da0073e9SAndroid Build Coastguard Worker }, 333*da0073e9SAndroid Build Coastguard Worker "execution_count": null, 334*da0073e9SAndroid Build Coastguard Worker "outputs": [ 335*da0073e9SAndroid Build Coastguard 
Worker { 336*da0073e9SAndroid Build Coastguard Worker "output_type": "stream", 337*da0073e9SAndroid Build Coastguard Worker "name": "stdout", 338*da0073e9SAndroid Build Coastguard Worker "text": [ 339*da0073e9SAndroid Build Coastguard Worker "Predictions without vmap <torch.utils.benchmark.utils.common.Measurement object at 0x7fe22c58b3d0>\n", 340*da0073e9SAndroid Build Coastguard Worker "[model(minibatch) for model, minibatch in zip(models, minibatches)]\n", 341*da0073e9SAndroid Build Coastguard Worker " 3.25 ms\n", 342*da0073e9SAndroid Build Coastguard Worker " 1 measurement, 100 runs , 1 thread\n", 343*da0073e9SAndroid Build Coastguard Worker "Predictions with vmap <torch.utils.benchmark.utils.common.Measurement object at 0x7fe22c50c450>\n", 344*da0073e9SAndroid Build Coastguard Worker "vmap(fmodel)(params, buffers, minibatches)\n", 345*da0073e9SAndroid Build Coastguard Worker " 879.28 us\n", 346*da0073e9SAndroid Build Coastguard Worker " 1 measurement, 100 runs , 1 thread\n" 347*da0073e9SAndroid Build Coastguard Worker ] 348*da0073e9SAndroid Build Coastguard Worker } 349*da0073e9SAndroid Build Coastguard Worker ], 350*da0073e9SAndroid Build Coastguard Worker "id": "gJPrGdS0GBjz" 351*da0073e9SAndroid Build Coastguard Worker }, 352*da0073e9SAndroid Build Coastguard Worker { 353*da0073e9SAndroid Build Coastguard Worker "cell_type": "markdown", 354*da0073e9SAndroid Build Coastguard Worker "source": [ 355*da0073e9SAndroid Build Coastguard Worker "There's a large speedup using vmap! \n", 356*da0073e9SAndroid Build Coastguard Worker "\n", 357*da0073e9SAndroid Build Coastguard Worker "In general, vectorization with vmap should be faster than running a function in a for-loop and competitive with manual batching. There are some exceptions though, like if we haven’t implemented the vmap rule for a particular operation or if the underlying kernels weren’t optimized for older hardware (GPUs). 
If you see any of these cases, please let us know by opening an issue at our [GitHub](https://github.com/pytorch/functorch)!\n", 358*da0073e9SAndroid Build Coastguard Worker "\n" 359*da0073e9SAndroid Build Coastguard Worker ], 360*da0073e9SAndroid Build Coastguard Worker "metadata": { 361*da0073e9SAndroid Build Coastguard Worker "id": "UI74G9JarQU8" 362*da0073e9SAndroid Build Coastguard Worker }, 363*da0073e9SAndroid Build Coastguard Worker "id": "UI74G9JarQU8" 364*da0073e9SAndroid Build Coastguard Worker } 365*da0073e9SAndroid Build Coastguard Worker ], 366*da0073e9SAndroid Build Coastguard Worker "metadata": { 367*da0073e9SAndroid Build Coastguard Worker "kernelspec": { 368*da0073e9SAndroid Build Coastguard Worker "display_name": "Python 3", 369*da0073e9SAndroid Build Coastguard Worker "language": "python", 370*da0073e9SAndroid Build Coastguard Worker "name": "python3" 371*da0073e9SAndroid Build Coastguard Worker }, 372*da0073e9SAndroid Build Coastguard Worker "language_info": { 373*da0073e9SAndroid Build Coastguard Worker "codemirror_mode": { 374*da0073e9SAndroid Build Coastguard Worker "name": "ipython", 375*da0073e9SAndroid Build Coastguard Worker "version": 3 376*da0073e9SAndroid Build Coastguard Worker }, 377*da0073e9SAndroid Build Coastguard Worker "file_extension": ".py", 378*da0073e9SAndroid Build Coastguard Worker "mimetype": "text/x-python", 379*da0073e9SAndroid Build Coastguard Worker "name": "python", 380*da0073e9SAndroid Build Coastguard Worker "nbconvert_exporter": "python", 381*da0073e9SAndroid Build Coastguard Worker "pygments_lexer": "ipython3", 382*da0073e9SAndroid Build Coastguard Worker "version": "3.8.5" 383*da0073e9SAndroid Build Coastguard Worker }, 384*da0073e9SAndroid Build Coastguard Worker "colab": { 385*da0073e9SAndroid Build Coastguard Worker "name": "ensembling.ipynb", 386*da0073e9SAndroid Build Coastguard Worker "provenance": [] 387*da0073e9SAndroid Build Coastguard Worker } 388*da0073e9SAndroid Build Coastguard Worker }, 
389*da0073e9SAndroid Build Coastguard Worker "nbformat": 4, 390*da0073e9SAndroid Build Coastguard Worker "nbformat_minor": 5 391*da0073e9SAndroid Build Coastguard Worker} 392