Named Tensors using First-class Dimensions in PyTorch
=====================================================

-- Zachary DeVito [@Zachary_DeVito](https://twitter.com/Zachary_DeVito)

_An implementation of [named tensors](https://namedtensor.github.io) with the functionality of [einsum](http://einops.rocks), batching ([vmap](https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap), [xmap](https://jax.readthedocs.io/en/latest/notebooks/xmap_tutorial.html)), and tensor indexing by adding dimension objects to PyTorch_.

The tensor input to a resnet might have the shape [8, 3, 224, 224] but informally we think of those dimensions as 'batch', 'channel', 'width', and 'height'. Even though 'width' and 'height' have the same _size_ we still think of them as separate dimensions, and if we have two _different_ images, we think of both as sharing the _same_ 'channel' dimension.

Named tensors give these dimensions names. [PyTorch's current implementation](https://pytorch.org/docs/stable/named_tensor.html) uses strings to name dimensions. Instead, this library introduces a Python object, a `Dim`, to represent the concept. By expanding the semantics of tensors with dim objects, in addition to naming dimensions, we can get behavior equivalent to batching transforms (xmap, vmap), einops-style rearrangement, and loop-style tensor indexing.

A preview:

```py
from torchdim import dims

# einsum
def mm(A: torch.Tensor, B: torch.Tensor):
    i, j, k = dims(3)
    r = (A[i, k] * B[k, j]).sum(k)
    return r.order(i, j)

# rearrange
def pixel_shuffle(img: torch.Tensor, upscale_factor=2):
    h2, w2, c, b, h, w = dims(6)
    h2.size = w2.size = upscale_factor
    return img[b, (c, h2, w2), h, w].order(b, c, (h, h2), (w, w2))

# batching
def bmm(A: torch.Tensor, B: torch.Tensor):
    i = dims(1)
    return mm(A[i], B[i]).order(i)

# indexing
def embedding_bag(input: torch.Tensor, embedding_weights: torch.Tensor):
    batch, sequence, features = dims(3)
    r = embedding_weights[input[batch, sequence], features].sum(sequence)
    return r.order(batch, features)
```

Installation
============

_torchdim is a preview release so that we can collect feedback on the API. It may have bugs, and there are known places where performance can be improved._

First-class dims are a library that extends PyTorch, so they need to be installed separately.
We may eventually upstream them into PyTorch itself along with `functorch`.

We have to install a nightly build of PyTorch so first set up an environment:

```sh
conda create --name dim
conda activate dim
```

First-class dims requires a fairly recent nightly build of PyTorch so that functorch will work.
You can install it using one of these commands:

```sh
# For CUDA 10.2
conda install pytorch torchvision torchaudio cudatoolkit=10.2 -c pytorch-nightly
# For CUDA 11.3
conda install pytorch torchvision torchaudio cudatoolkit=11.3 -c pytorch-nightly
# For CPU-only build
conda install pytorch torchvision torchaudio cpuonly -c pytorch-nightly
```

Install `torchdim`. You will be asked for GitHub credentials to access the fairinternal organization.

```sh
pip install ninja # Makes the build go faster
pip install --user "git+https://github.com/facebookresearch/torchdim"
```

Creating and Binding Dims
=========================

Python objects that represent dimensions are created using the `dims` operator.[^1]

```py
import torch
from torchdim import dims

batch, channel, width, height = dims(4)
```

The existing implementation of [Named Tensors](https://pytorch.org/docs/stable/named_tensor.html) in PyTorch and [JAX's xmap](https://jax.readthedocs.io/en/latest/notebooks/xmap_tutorial.html) use strings to name dimensions. We call the dimensions created here _first class_ because they are Python objects.

In addition to the normal _positional_ dimensions in a tensor, tensors can also have a separate set of first-class dimensions.

You can create tensors with first-class dimensions by indexing the normal positional dimensions of a tensor with a dimension object. The `ndim` property continues to list the number of positional dimensions, while the new `dims` property lists all the bound first-class dimensions.

```py
input = torch.rand(2, 3, 224, 224)
print(input.ndim)
> 4

input_fc = input[batch, channel, width, height]
print(input_fc.dims) # first class dimensions
> (batch, channel, width, height)

# since we converted all the positional dimensions to first-class ones,
# `input_fc` now has 0 positional dimensions.
print(input_fc.ndim)
> 0
```

Notice that indexing creates a _new_ Tensor, `input_fc`, with bound first-class dimensions. It does not modify the original tensor `input`, which still has 4 positional dimensions.

```py
print(input.ndim) # unchanged
> 4
```

Importantly, indexing with square brackets _applies only to positional dimensions_, so attempting to index a tensor that has only first-class dims will error[^2]:

```py
try:
    input_fc[0]
except ValueError as ve:
    print(ve)
> at least 1 indices were supplied but the tensor only has 0 dimensions
```

Generally, it is possible to construct tensors with a mixture of positional and first-class dimensions:

```py
input_mixed = input[batch, :, :, height]
print(input_mixed.dims)
> (batch, height)

print(input_mixed.ndim)
> 2
```

Dimension Sizes
---------------

Dimensions will take on the size of the first thing they are bound to:

```py
input = torch.rand(3)
x = dims(1)
input_fc = input[x]
print(x.size)
> 3
```

But you can also directly set the size of a dimension:

```py
i = dims(1)

i.size = 5 # ok, i previously did not have a size

i.size = 5 # ok, it already had the size 5
try:
    i.size = 3
except Exception as e:
    print(e)
> Dim 'i' previously bound to a dimension of size 5 cannot bind to a dimension of size 3

j = dims(sizes=[4]) # can also be set on construction
```

[^1]: We use a bit of Python introspection to set the debug names for the dimensions based on the names of the variables they are assigned to.
[^2]: Indexing of first-class dimensions can be done with the `index` method by specifying the dimension to index into (e.g. `input_fc.index(batch, 0)`).

Semantics of Dimensions
=======================
The power of named tensors arises from how the first-class dimensions in Tensors compose with existing operations.

Three rules define how dimension objects behave with existing Tensors.

Rule 1: Implicit Batching
-------------------------
**Tensor operations (e.g. `input + bias`) are implicitly batched over the union of the first-class dimensions in their inputs.**

If `input` has dimensions `batch, channel` and `bias` has dimension `channel`, the output will have the union of those dimensions (`batch, channel`), and the result will be computed as if there were a loop over all the first-class dimensions.[^3]

```py
input_positional = torch.rand(128, 32)
bias_positional = torch.rand(32)

batch, channel = dims(2)
input = input_positional[batch, channel]
bias = bias_positional[channel]

result = input + bias
print(result.dims)
> (batch, channel)
```

It is helpful to think of operators on tensors with first-class dimensions by analogy to code with explicit loops over dimensions, with the first-class dimensions of the inputs acting as implicit `for` loops, and the values in the tensor being scalars within the body of the loop:

```py
# mental model: loop-level analogy
for batch in range(batch.size):
    for channel in range(channel.size):
        input = input_positional[batch, channel]
        bias = bias_positional[channel]
        result[batch, channel] = input + bias # arithmetic on scalars
```

Positional dimensions behave as they did before (e.g. for `+` they will broadcast), and can be thought of as being a standard tensor _used within the implicit loops_ defined by first-class dimensions.

In this example, we broke the expression down into lines that bind the dimensions to positional tensors, and then another line to do the compute. In practice, we often combine these in one statement:

```py
result = input_positional[batch, channel] + bias_positional[channel]
result.dims
```

[^3]: This rule is similar to how named dimensions in xmap behave within a function, but instead of introducing the dimensions via a functional transform, they are bound on the objects using indexing.
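
Because the output ranges over the union of the first-class dimensions, operands bound to entirely _different_ dims batch against each other, much like an outer product. A small sketch of this behavior (the sizes here are arbitrary):

```py
i, j = dims(sizes=[3, 4])
x = torch.rand(3)[i]
y = torch.rand(4)[j]
print((x + y).dims)   # the result carries both i and j
```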

Rule 2: Specifying dimensions
-----------------------------
**Wherever an integer is used to specify a dimension in an existing torch operator, a first-class dimension can be used instead to tell the operator to work over that dimension.**

```py
batch, channel, width, height = dims(4)
input_positional = torch.rand(2, 3, 224, 224)
input = input_positional[batch, channel, width, height]
avg_pixel_color = input.mean((width, height))

print(avg_pixel_color.dims)
> (batch, channel)
```

Any other first-class dimensions (e.g. batch, channel) are still implicitly batched according to Rule #1.

Rule 3: Dims are Tensors
------------------------
**A first-class dimension `d` can be used wherever a Tensor is expected. It will act as if it were a tensor whose only dimension is itself, `d`, and the values along the dimension are the indices of each entry `(0, 1, 2, ..., d.size - 1)`.**

```py
print(channel.dims)
> (channel,)

print(channel + 1000)
> tensor([1000, 1001, 1002])
> with dims=(channel,) sizes=(3,)
```

This means that a dimension used as a tensor acts as an index into that dimension. Going back to our loop-level analogy, it is analogous to using the loop variable as a value:

```py
# mental model: loop-level analogy
for channel in range(channel.size):
    result[channel] = channel + 1000
```

Arithmetic using dimension indices comes up a lot, such as the mask for the upper triangular part of a matrix.
Using dims as tensors makes it easy:

```py
from torchdim import dims
i, j = dims(sizes=[4, 4])
print(i <= j)
> tensor([[ True,  True,  True,  True],
>         [False,  True,  True,  True],
>         [False, False,  True,  True],
>         [False, False, False,  True]])
> with dims=(i, j) sizes=(4, 4)
```

Because of the intentional similarity to loop-level code, using dimensions as tensors makes complicated indexing arithmetic easier to read.

Here is code that looks up features in an embedding table given a sequence of ids:

```py
sequence, features = dims(2)
embeddings = torch.rand(8, 128)
words = torch.tensor([5, 4, 0,])

state = embeddings[words[sequence], features]
print(state.dims)
> (sequence, features)
```

With the following analogy to loops:

```py
# mental model: loop-level analogy

for sequence in range(words.size(0)):
    for features in range(embeddings.size(1)):
        state = embeddings[words[sequence], features]
```

Earlier we showed how binding a tensor's dimensions is done with indexing `A[i, j]`. In fact, this binding is just the normal indexing operator. Its behavior follows directly from the behavior of indexing with tensor indices combined with Rule #3 and Rule #1. The expression `A[i + 1, j]` also creates a tensor with dimensions `i` and `j` but with different indexing math. The implementation knows when simple indexing patterns are used and only actually runs a kernel to do indexing when needed.
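
For example, the `A[i + 1, j]` pattern mentioned above can be sketched as follows (a toy illustration; the sizes are chosen so the shifted index stays in bounds):

```py
i, j = dims(sizes=[2, 4])
A = torch.rand(3, 4)
shifted = A[i + 1, j]   # selects rows 1 and 2 of A via indexing math
assert torch.allclose(shifted.order(i, j), A[1:3])
```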

Unbinding Dims
--------------
The `order` method converts first-class dimensions in a tensor back to normal positional dimensions by specifying an order for those dimensions.[^4]

By specifying a different order from how things were originally bound, it is easy to do transpositions.

```py
i, j = dims(2)
A = torch.rand(3, 4)
A_T = A[i, j].order(j, i)
assert torch.allclose(A.T, A_T)
```

Indexing acts left-to-right, and `order` also places the new dimensions back on the left, so it is possible to work on tensors that have mixed positional and first-class dimensions:

```py
B = torch.rand(3, 4, 5)
B_T = B[i, j].order(j, i)
assert torch.allclose(B.permute(1, 0, 2), B_T)
```

[^4]: `order` is actually just a synonym for the already-existing `permute` method, which takes a list of dimension specifiers and puts the tensor in that order, because Rule #2 says that first-class dims can be passed as arguments to functions that previously took only integers as dimensions. However, the name `permute` is confusing in this context since it implies dim objects have an original order, so we prefer to use `order` when writing code.

Flattening and Splitting Dims
-----------------------------

**Tuples of dimensions** can be passed to both indexing and `order`. In indexing, this will split the dimension being indexed across the dimensions in the tuple.
In `order`, it will flatten the dimensions into a single positional dimension:

```py
i, j, k = dims(3)
j.size = 2
A = torch.rand(6, 4)
a = A[(i, j), k] # split dim 0 into i,j
print(i.size, j.size, k.size)
> 3 2 4

r = a.order(i, (j, k)) # flatten j and k
print(r.shape)
> torch.Size([3, 8])
```

The size of one unsized dimension in a tuple such as `i` can be inferred if the other sizes are known.

Examples
========

The usefulness of dimension objects is best seen through examples. Let's look at some different ways they can be used.

Einsum-style Products
---------------------
Rather than having [einsum](https://pytorch.org/docs/stable/generated/torch.einsum.html) as a custom operator, it is possible to express matrix products directly as a composition of multiplies and summations. The implementation will pattern match any multiplication followed by a sum to the right matrix-multiply operator.

```py
def mm(A, B):
    i, j, k = dims(3)
    r = (A[i, k] * B[k, j]).sum(k)
    return r.order(i, j)
mm(torch.rand(3, 4), torch.rand(4, 5)).shape
```

The implementation of named tensors delays the execution of the multiply to see if a summation follows it, as it does above. If so, it will turn this pattern into the correct _optimized matrix product_, similar to how the `einsum` function works.
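
A quick sanity check of this equivalence might look like the following, comparing against PyTorch's built-in matrix multiply up to floating-point tolerance:

```py
A, B = torch.rand(3, 4), torch.rand(4, 5)
assert torch.allclose(mm(A, B), torch.mm(A, B), atol=1e-6)
```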

Since it is no longer necessary to manually match math to matrix functions, other tensor products are easier to express, like the Gram matrix used in style transfer:

```py
def gram_matrix_new(y):
    b, c, c2, h, w = dims()
    r = (y[b, c, h, w] * y[b, c2, h, w]).sum((h, w))
    r = r / (h.size * w.size)
    return r.order(b, c, c2)

gram_matrix_new(torch.rand(1, 2, 3, 4))
# [example adapted from http://einops.rocks/pytorch-examples.html]
```

Attention is another example that has several matrix products embedded inside it:

```py
from torchdim import softmax
def attention(K, Q, V):
    batch, channel, key, query = dims(4)
    k = K[batch, channel, key]
    q = Q[batch, channel, query]
    v = V[batch, channel, key]

    a = (k * q).sum(channel) # matrix multiply
    a = softmax(a * (channel.size ** -0.5), dim=key)
    r = (v * a).sum(key) # matrix multiply
    return torch.cat((r.order(batch, channel, query), Q), dim=1)

inputs = (torch.rand(2, 3, 4) for _ in range(3))
attention(*inputs)
# [example adapted from http://einops.rocks/pytorch-examples.html]
```

Reshaping tensors (einops)
--------------------------

Lots of operations in deep learning are just different ways of reshaping, splitting, and joining dimensions, such as the pixel shuffle used to upscale an image by turning channels into pixels:

```py
def pixel_shuffle(img, upscale_factor=2):
    h2, w2, c, b, h, w = dims(6)
    h2.size = w2.size = upscale_factor
    return img[b, (c, h2, w2), h, w].order(b, c, (h, h2), (w, w2))
```
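
As an illustrative check, this should agree with the built-in `torch.nn.functional.pixel_shuffle`, assuming the channel count is divisible by `upscale_factor**2`:

```py
img = torch.rand(2, 8, 10, 10)
assert torch.allclose(pixel_shuffle(img), torch.nn.functional.pixel_shuffle(img, 2))
```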
[Einops](http://einops.rocks) is an extension to einsum that adds support for the manipulation of dimensions through a few custom operators such as `rearrange`:

```py
def pixel_shuffle_einops(img, upscale_factor=2):
    from einops import rearrange
    return rearrange(img, 'b (c h2 w2) h w -> b c (h h2) (w w2)', h2=upscale_factor, w2=upscale_factor)
```

Named tensors with first-class dimensions can accomplish the same goal, but using PyTorch's existing operator set.

Automatically batching Code (`vmap`, `xmap`)
-----------------------------

The implicit batching of Rule #1 means it is easy to create batched versions of existing PyTorch code. Simply bind a dim to the dimensions that should act as a batch, and then pass the tensor to the unbatched function. Since the unbatched function does not know about the dim, the dim will be implicitly batched over:

```py
batch_size, feature_size = 3, 5
weights = torch.randn(feature_size)

def model(feature_vec):
    # Very simple linear model with activation
    assert feature_vec.dim() == 1
    return feature_vec.dot(weights).relu()

examples = torch.randn(batch_size, feature_size)
batch = dims(1)
r = model(examples[batch])
print(r)
# in functorch: result = functorch.vmap(model)(examples)
> tensor([0.4775, 0.0000, 0.3423])
> with dims=(batch,) sizes=(3,)
```

This pattern composes well with other code that also uses first-class dimensions. For instance, we can write batched matrix multiply `bmm` by batching the `mm` operator.

It doesn't matter whether the implementation of the function uses dimension objects; it is always possible to add additional batch dimensions and then call the function:

```py
def bmm(A, B):
    i = dims(1) # note: i here is a different value from i inside mm so it works
    return mm(A[i], B[i]).order(i)
```

The equivalent code in JAX uses [xmap or vmap](https://jax.readthedocs.io/en/latest/notebooks/quickstart.html#auto-vectorization-with-vmap), which are transforms over functions, so there is a lot of syntactic distance between the specification of the dimension mappings and the values where those mappings apply. Dims express the mapping as indexing of the tensor, right at the place where the function is being applied.

[xmap examples](https://jax.readthedocs.io/en/latest/notebooks/xmap_tutorial.html):

```py
in_axes = [['inputs', 'hidden', ...],
           ['hidden', 'classes', ...],
           ['batch', 'inputs', ...],
           ['batch', ...]]

loss = xmap(named_loss, in_axes=in_axes, out_axes=[...])
print(loss(w1, w2, images, labels))
```

Equivalent with dimension objects:

```py
batch, inputs, hidden, classes = dims(4)
print(loss(w1[inputs, hidden], w2[hidden, classes], images[batch, inputs], labels[batch],
           batch, inputs, hidden, classes))
```


Composing matrix products, reshaping, and batching
---------------------

Multi-headed attention is a good example of how these different uses compose. It reshapes the inputs, splitting out different attention heads. It batches over those attention heads, and it uses matrix products to compute attention scores.

```py
from torchdim import softmax
def multiheadattention(q, k, v, num_attention_heads, dropout_prob, use_positional_embedding):
    batch, query_sequence, key_sequence, heads, features = dims(5)
    heads.size = num_attention_heads

    # binding dimensions, and unflattening the heads from the feature dimension
    q = q[batch, query_sequence, [heads, features]]
    k = k[batch, key_sequence, [heads, features]]
    v = v[batch, key_sequence, [heads, features]]

    # einsum-style operators to calculate scores
    attention_scores = (q*k).sum(features) * (features.size ** -0.5)

    # use first-class dim to specify dimension for softmax
    attention_probs = softmax(attention_scores, dim=key_sequence)

    # dropout works pointwise, following Rule #1
    attention_probs = torch.nn.functional.dropout(attention_probs, p=dropout_prob)

    # another matrix product
    context_layer = (attention_probs*v).sum(key_sequence)

    # flatten heads back into features
    return context_layer.order(batch, query_sequence, [heads, features])
```

Indexing
--------

Rule #3 enables indexing because dimensions act as loop indices when used as a tensor. This allows for a lot of powerful behavior.
The simplest might be using the dimensions to compute masks, such as extracting the upper triangular part of a matrix:

```py
from torch import where
def triu(A):
    i, j = dims()
    a = A[i, j]
    return where(i <= j, a, 0).order(i, j)
triu(torch.rand(3, 4))
```

Embedding bag does an embedding table lookup followed by a sum, which can be expressed concisely:

```py
def embedding_bag(input, embedding_weights):
    batch, sequence, features = dims(3)
    r = embedding_weights[input[batch, sequence], features].sum(sequence)
    return r.order(batch, features)

input = torch.tensor([[1, 0, 4, 3]])
W = torch.rand(5, 2)
embedding_bag(input, W)
```

Relative positional embeddings associate an embedding vector with the distance between the query and the key in the sequence.
For instance, a key at position 3 and a query at position 5 will have embedding ID `(5-3)=2`.
We can use first-class dimensions to do the indexing arithmetic and the embedding lookup:

```py
def relative_positional_embedding(q, k, distance_embedding_weight):
    batch, query_sequence, key_sequence, heads, features = dims(5)
    q = q[batch, query_sequence, [heads, features]]
    k = k[batch, key_sequence, [heads, features]]

    distance = query_sequence - key_sequence
    n_embeddings = distance_embedding_weight.size(0)
    index_bias = n_embeddings // 2

    assert key_sequence.size + index_bias <= n_embeddings

    # indexing with dims
    positional_embedding = distance_embedding_weight[distance + index_bias, features]

    # matrix multiplies with dims
    relative_position_scores_query = (q*positional_embedding).sum(features)
    relative_position_scores_key = (k*positional_embedding).sum(features)
    return (relative_position_scores_query + relative_position_scores_key).order(batch, heads, key_sequence, query_sequence)
```

Tensor Puzzlers
===============

[Tensor Puzzlers](https://github.com/srush/Tensor-Puzzles), created by Sasha Rush, are a good exercise for learning the numpy and torch APIs by figuring out how to define common operations using a small set of primitive tensor operations.

However, the difficulty of many of the puzzlers lies not in how to compute the answer but in the awkwardness of the primitives themselves.

**With first-class dimensions, these puzzlers are nearly the same as the spec that defines them.**


### Puzzle 3 - outer

Compute [outer](https://numpy.org/doc/stable/reference/generated/numpy.outer.html) - the outer product of two vectors.

```py
def outer_spec(a, b, out):
    for i in range(len(out)):
        for j in range(len(out[0])):
            out[i][j] = a[i] * b[j]

def outer(a, b):
    i, j = dims(2)
    return (a[i] * b[j]).order(i, j)
```

### Puzzle 4 - diag

Compute [diag](https://numpy.org/doc/stable/reference/generated/numpy.diag.html) - the diagonal vector of a square matrix.

```py
def diag_spec(a, out):
    for i in range(len(a)):
        out[i] = a[i][i]

def diag(a):
    i = dims(1)
    return a[i, i].order(i)
```

### Puzzle 5 - eye

Compute [eye](https://numpy.org/doc/stable/reference/generated/numpy.eye.html) - the identity matrix.

```py
from torch import where
def eye_spec(out):
    for i in range(len(out)):
        out[i][i] = 1

def eye(j: int):
    i, j = dims(sizes=[j, j])
    return where(i == j, 1, 0).order(i, j)
```

### Puzzle 6 - triu

Compute [triu](https://numpy.org/doc/stable/reference/generated/numpy.triu.html) - the upper triangular matrix.

```py
def triu_spec(out):
    for i in range(len(out)):
        for j in range(len(out)):
            if i <= j:
                out[i][j] = 1
            else:
                out[i][j] = 0

def triu(j: int):
    i, j = dims(sizes=[j, j])
    return where(i <= j, 1, 0).order(i, j)
```

### Puzzle 8 - diff

Compute [diff](https://numpy.org/doc/stable/reference/generated/numpy.diff.html) - the running difference.

```py
def diff_spec(a, out):
    out[0] = a[0]
    for i in range(1, len(out)):
        out[i] = a[i] - a[i - 1]

def diff(a, i: int):
    i = dims(1)
    d = a[i] - a[i - 1]
    return where(i - 1 >= 0, d, a[i]).order(i)
```

### Puzzle 9 - vstack

Compute [vstack](https://numpy.org/doc/stable/reference/generated/numpy.vstack.html) - the matrix of two vectors

```py
def vstack_spec(a, b, out):
    for i in range(len(out[0])):
        out[0][i] = a[i]
        out[1][i] = b[i]

def vstack(a, b):
    v, i = dims(sizes=[2, None])
    return where(v == 0, a[i], b[i]).order(v, i)
```

### Puzzle 10 - roll

Compute [roll](https://numpy.org/doc/stable/reference/generated/numpy.roll.html) - the vector shifted 1 circular position.

```py
def roll_spec(a, out):
    for i in range(len(out)):
        if i + 1 < len(out):
            out[i] = a[i + 1]
        else:
            out[i] = a[i + 1 - len(out)]

def roll(a, i: int):
    i = dims(sizes=[a.size(0)])
    return a[where(i + 1 < i.size, i + 1, 0)].order(i)
```

### Puzzle 11 - flip

Compute [flip](https://numpy.org/doc/stable/reference/generated/numpy.flip.html) - the reversed vector

```py
def flip_spec(a, out):
    for i in range(len(out)):
        out[i] = a[len(out) - i - 1]

def flip(a, i: int):
    i = dims(sizes=[a.size(0)])
    return a[i.size - i - 1].order(i)
```

### Puzzle 14 - sequence_mask

Compute [sequence_mask](https://www.tensorflow.org/api_docs/python/tf/sequence_mask) - pad out to length per batch.

```py
def sequence_mask_spec(values, length, out):
    for i in range(len(out)):
        for j in range(len(out[0])):
            if j < length[i]:
                out[i][j] = values[i][j]
            else:
                out[i][j] = 0

def sequence_mask(values, length):
    j, i = dims()
    v = values[i, j]
    return where(j < length[i], v, 0).order(i, j)
```

Advantages of First-class Dimensions over String Dimensions
============================================================

The most prominent difference between named tensors using first-class dimensions and alternatives (einops, named tensors implemented in PyTorch today, [tensors considered harmful](https://nlp.seas.harvard.edu/NamedTensor), or xmap) is that dimensions are objects rather than strings. Using objects has a number of nice properties.

### Avoiding naming conflicts

Using strings for dimensions introduces the possibility that two unrelated dimensions are given the same name. Using objects instead makes it clear that dimensions with the same name are not the same dimension. It's like the difference between having only global variables and having the ability to locally bind names in functions.
For instance, we defined `bmm` by batching a call to `mm`, even though they both use the name `i` to identify a dimension. Because each `i` is a different object, there is no naming conflict:

```py
def mm(A, B):
    i, j, k = dims()
    r = (A[i, k] * B[k, j]).sum(k)
    return r.order(i, j)

def bmm(A, B):
    i = dims() # note: it doesn't matter that mm internally also uses i
    return mm(A[i], B[i])
```

Einops avoids conflicts by ensuring names are all introduced and removed in a single expression, but this precludes using long-lived dimensions to provide implicit batching similar to xmap. When nested, JAX's xmap seems to consider axes the same if the string name matches.
In the above example it would consider the `i` dimension to be the same dimension in both `bmm` and `mm`, so the code would error.


### Reuse the same operator set

Having a new object type allows us to extend the existing operator set of PyTorch rather than come up with new operators. For instance, binding dimensions using indexing follows semantically from Rules #1 and #3, so there is no need for a special operator to do binding. Even unbinding is just the `permute` operator, which follows from Rule #2, though we call it `order` for clarity. In contrast, using strings requires coming up with new APIs such as `einsum` for matrix multiplies, or `rearrange` for doing permutations.

### Allows dims to act as tensors

Rule #3 is not possible with strings since we cannot make strings behave as tensors. Without this rule, all of the indirect indexing that dims enable would not be easy to express.

### Dims can have methods

For instance, as objects, dims can have a size, which allows us to do size inference of dimensions in various places in the API where string-based APIs would have to take additional arguments specifying size.


Comparison to tensor compilers or languages (e.g. TVM or Dex)
=============================================================

The semantics and surface syntax of dimension objects resemble the kind of code written in tensor compilers such as [Halide](https://halide-lang.org), [TVM](https://tvm.apache.org), [Tensor Comprehensions](https://github.com/facebookresearch/TensorComprehensions), or the language [Dex](https://github.com/google-research/dex-lang).

These compilers and languages have syntax and semantics that resemble the loop-level analogy of first-class dimensions. However, as compilers or statically typed languages, they require some binding code to go from running deep learning framework code in Python to using the compiled language. This often at least requires refactoring the compiled parts into their own functions, and may require defining a gradient function. Similar to graph-mode frameworks, this adds friction to using and debugging the code.

Dimension objects are just an extension of the existing PyTorch tensors and eager semantics, so there is no friction switching between normal Python code and code that uses them. However, since loops over the dimensions are defined implicitly, they can still execute in Python with good performance compared to explicit loops.
Furthermore, with dimension objects, a tensor containing dimensions can compute through code that is oblivious to the dimension, such as the batching examples above. There is no need to separate code into 'compiled' vs 'eager'.

In this way, first-class dims are a way of adapting the nicer syntax of these array compilers and languages to eager numpy-style libraries.


Performance Expectations
========================
First-class dimensions are not a compiler. They provide syntax for existing PyTorch operations, such as advanced indexing, that is easier to read and write. For large tensors, the performance of any statements including them will be the same as using the already existing operations. An important exception is the pattern matching of products and summation, where performance will be improved by dispatching to a matrix-multiply kernel. The C++ implementation of dimensions adds a small overhead of around 2us on top of PyTorch's normal overhead of 8us to each function that uses them. In the future, the implementation can incorporate more fusion optimization to further improve the performance of this style of code.


## License
Functorch has a BSD-style license, as found in the [LICENSE](LICENSE) file.