Named Tensors using First-class Dimensions in PyTorch
=====================================================

-- Zachary DeVito [@Zachary_DeVito](https://twitter.com/Zachary_DeVito)

_An implementation of [named tensors](https://namedtensor.github.io) with the functionality of [einsum](https://pytorch.org/docs/stable/generated/torch.einsum.html), batching ([vmap](https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap), [xmap](https://jax.readthedocs.io/en/latest/notebooks/xmap_tutorial.html)), and tensor indexing, by adding dimension objects to PyTorch_.

The tensor input to a resnet might have the shape [8, 3, 224, 224], but informally we think of those dimensions as 'batch', 'channel', 'width', and 'height'. Even though 'width' and 'height' have the same _size_, we still think of them as separate dimensions, and if we have two _different_ images, we think of both as sharing the _same_ 'channel' dimension.

Named tensors give these dimensions names. [PyTorch's current implementation](https://pytorch.org/docs/stable/named_tensor.html) uses strings to name dimensions. Instead, this library introduces a Python object, a `Dim`, to represent the concept. By expanding the semantics of tensors with dim objects, in addition to naming dimensions, we can get behavior equivalent to batching transforms (xmap, vmap), einops-style rearrangement, and loop-style tensor indexing.

A preview:

```py
from torchdim import dims

# einsum
def mm(A: torch.Tensor, B: torch.Tensor):
    i, j, k = dims(3)
    r = (A[i, k] * B[k, j]).sum(k)
    return r.order(i, j)

# rearrange
def pixel_shuffle(img: torch.Tensor, upscale_factor=2):
    h2, w2, c, b, h, w = dims(6)
    h2.size = w2.size = upscale_factor
    return img[b, (c, h2, w2), h, w].order(b, c, (h, h2), (w, w2))

# batching
def bmm(A: torch.Tensor, B: torch.Tensor):
    i = dims(1)
    return mm(A[i], B[i]).order(i)

# indexing
def embedding_bag(input: torch.Tensor, embedding_weights: torch.Tensor):
    batch, sequence, features = dims(3)
    r = embedding_weights[input[batch, sequence], features].sum(sequence)
    return r.order(batch, features)
```

Installation
============

_torchdim is a preview release so that we can collect feedback on the API. It may have bugs, and there are known places where performance can be improved._

First-class dims are a library that extends PyTorch, so they need to be installed separately.
We may eventually upstream them into PyTorch itself along with `functorch`.

We have to install a nightly build of PyTorch, so first set up an environment:

```sh
conda create --name dim
conda activate dim
```

First-class dims require a fairly recent nightly build of PyTorch so that functorch will work. You can install it using one of these commands:

```sh
# For CUDA 10.2
conda install pytorch torchvision torchaudio cudatoolkit=10.2 -c pytorch-nightly
# For CUDA 11.3
conda install pytorch torchvision torchaudio cudatoolkit=11.3 -c pytorch-nightly
# For CPU-only build
conda install pytorch torchvision torchaudio cpuonly -c pytorch-nightly
```

Install torchdim. You will be asked for GitHub credentials to access the fairinternal organization.
```sh
pip install ninja  # Makes the build go faster
pip install --user "git+https://github.com/facebookresearch/torchdim"
```

Creating and Binding Dims
=========================

Python objects that represent dimensions are created using the `dims` operator.[^1]

```py
import torch
from torchdim import dims

batch, channel, width, height = dims(4)
```

The existing implementation of [Named Tensors](https://pytorch.org/docs/stable/named_tensor.html) in PyTorch, as well as [JAX's xmap](https://jax.readthedocs.io/en/latest/notebooks/xmap_tutorial.html), uses strings to name dimensions. We call these dimensions _first class_ because they are Python objects.

In addition to the normal _positional_ dimensions in a tensor, tensors can also have a separate set of first-class dimensions.

You can create tensors with first-class dimensions by indexing the normal positional dimensions of a tensor with a dimension object. The `ndim` property continues to list the number of positional dimensions, while the new `dims` property lists all the bound first-class dimensions.

```py
input = torch.rand(2, 3, 224, 224)
print(input.ndim)
> 4

input_fc = input[batch, channel, width, height]
print(input_fc.dims) # first class dimensions
> (batch, channel, width, height)


# since we converted all the positional dimensions to first class,
# `input_fc` has 0 positional dimensions now.
print(input_fc.ndim)
> 0
```

Notice that indexing creates a _new_ tensor, `input_fc`, with bound first-class dimensions. It does not modify the original tensor `input`, which still has 4 positional dimensions.

```py
print(input.ndim) # unchanged
> 4
```

Importantly, indexing with square brackets _applies only to positional dimensions_, so attempting to index a tensor that has only first-class dims will error[^2]:

```py
try:
    input_fc[0]
except ValueError as ve:
    print(ve)
> at least 1 indices were supplied but the tensor only has 0 dimensions
```

Generally, it is possible to construct tensors with a mixture of positional and first-class dimensions:

```py
input_mixed = input[batch, :, :, height]
print(input_mixed.dims)
> (batch, height)

print(input_mixed.ndim)
> 2
```
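As footnote 2 notes, first-class dimensions can be indexed with the `index` method instead of square brackets. A minimal sketch, continuing the `input_fc` example above (the dims noted in the comment are what we would expect, not captured output):

```py
# select entry 0 along the first-class `batch` dimension of `input_fc`
first_image = input_fc.index(batch, 0)
print(first_image.dims)  # expected: the remaining dims (channel, width, height)
```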
Dimension Sizes
---------------

Dimensions will take on the size of the first thing they are bound to:

```py
input = torch.rand(3)
x = dims(1)
input_fc = input[x]
print(x.size)
> 3
```

But you can also directly set the size of a dimension:

```py
i = dims(1)

i.size = 5 # ok, i previously did not have a size

i.size = 5 # ok, it already had the size 5
try:
    i.size = 3
except Exception as e:
    print(e)
> Dim 'i' previously bound to a dimension of size 5 cannot bind to a dimension of size 3

j = dims(sizes=[4]) # can also be set on construction
```

[^1]: We use a bit of Python introspection to set the debug names for the dimensions based on the names of the variables they are assigned to.
[^2]: Indexing of first-class dimensions can be done with the `index` method by specifying the dimension to be indexed into (e.g. `input_fc.index(batch, 0)`).

Semantics of Dimensions
=======================
The power of named tensors arises from how the first-class dimensions in tensors compose with existing operations.

Three rules define how dimension objects behave with existing tensors.

Rule 1: Implicit Batching
-------------------------
**Tensor operations (e.g. `input + bias`) are implicitly batched over the union of the first-class dimensions in their inputs.**

If `input` has dimensions `batch, channel` and `bias` has dimension `channel`, the output will have the union of those dimensions (`batch, channel`), and the result will be computed as if there were a loop over all the first-class dimensions.[^3]

```py
input_positional = torch.rand(128, 32)
bias_positional = torch.rand(32)

batch, channel = dims(2)
input = input_positional[batch, channel]
bias = bias_positional[channel]

result = input + bias
print(result.dims)
> (batch, channel)
```

It is helpful to think of operators on tensors with first-class dimensions by analogy to code with explicit loops over dimensions, with the first-class dimensions of the inputs acting as implicit `for` loops, and the values in the tensor being scalars within the body of the loop:

```py
# mental model: loop-level analogy
for batch in range(batch.size):
    for channel in range(channel.size):
        input = input_positional[batch, channel]
        bias = bias_positional[channel]
        result[batch, channel] = input + bias # arithmetic on scalars
```

Positional dimensions behave as they did before (e.g. for `+` they will broadcast), and can be thought of as a standard tensor _used within the implicit loops_ defined by the first-class dimensions.

In this example, we broke the expression down into lines that bind the dimensions to positional tensors and another line that does the compute. In practice, we often combine these in one statement:

```py
result = input_positional[batch, channel] + bias_positional[channel]
result.dims
```

[^3]: This rule is similar to how named dimensions in xmap behave within a function, but instead of introducing the dimensions via a functional transform, they are bound on the objects using indexing.


Rule 2: Specifying dimensions
-----------------------------
**Wherever an integer is used to specify a dimension in an existing torch operator, a first-class dimension can be used instead to tell the operator to work over that dimension.**

```py
batch, channel, width, height = dims(4)
input_positional = torch.rand(2, 3, 224, 224)
input = input_positional[batch, channel, width, height]
avg_pixel_color = input.mean((width, height))

print(avg_pixel_color.dims)
> (batch, channel)
```

Any other first-class dimensions (e.g. batch, channel) are still implicitly batched according to Rule #1.

Rule 3: Dims are Tensors
------------------------
**A first-class dimension `d` can be used wherever a Tensor is expected. It will act as if it were a tensor whose only dimension is itself, `d`, and the values along the dimension are the indices of each entry `(0, 1, 2, ..., d.size - 1)`.**

```py
print(channel.dims)
> (channel,)

print(channel + 1000)
> tensor([1000, 1001, 1002])
> with dims=(channel,) sizes=(3,)
```

This means that a dimension used as a tensor acts as an index into that dimension.
Going back to our loop-level analogy, it is analogous to using the loop variable as a value:

```py
# mental model: loop-level analogy
for channel in range(channel.size):
    result[channel] = channel + 1000
```

Arithmetic using dimension indices comes up a lot, such as computing the mask for the upper-triangular part of a matrix. Using dims as tensors makes it easy:

```py
from torchdim import dims
i, j = dims(sizes=[4, 4])
print(i <= j)
> tensor([[ True,  True,  True,  True],
>         [False,  True,  True,  True],
>         [False, False,  True,  True],
>         [False, False, False,  True]])
> with dims=(i, j) sizes=(4, 4)
```

Because of the intentional similarity to loop-level code, using dimensions as tensors makes complicated indexing arithmetic easier to read.

Here is code that looks up features in an embedding table given a sequence of ids:

```py
sequence, features = dims(2)
embeddings = torch.rand(8, 128)
words = torch.tensor([5, 4, 0,])

state = embeddings[words[sequence], features]
print(state.dims)
> (sequence, features)
```

With the following analogy to loops:

```py
# mental model: loop-level analogy

for sequence in range(words.size(0)):
    for features in range(embeddings.size(1)):
        state[sequence, features] = embeddings[words[sequence], features]
```

Earlier we showed how binding tensor dimensions is done with indexing `A[i, j]`. In fact, this binding is just the normal indexing operator. Its behavior follows directly from the behavior of indexing with tensor indices combined with Rule #3 and Rule #1. The expression `A[i + 1, j]` also creates a tensor with dimensions `i` and `j`, but with different indexing math. The implementation knows when simple indexing patterns are used and only actually runs a kernel to do indexing when needed.
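For example, indexing with dimension arithmetic can express a shifted read of a matrix. A minimal sketch (the sizes here are made up for illustration):

```py
i, j = dims(sizes=[3, 4])
A = torch.rand(4, 4)

# shifted[i, j] reads A[i + 1, j], i.e. rows 1..3 of A
shifted = A[i + 1, j]
assert torch.allclose(shifted.order(i, j), A[1:])
```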
Unbinding Dims
--------------
The `order` method converts first-class dimensions in a tensor back to normal positional dimensions by specifying an order for those dimensions.[^4]

By specifying a different order from how things were originally bound, it is easy to do transpositions.

```py
i, j = dims(2)
A = torch.rand(3, 4)
A_T = A[i, j].order(j, i)
assert torch.allclose(A.T, A_T)
```

Indexing acts left-to-right, and `order` also places the new dimensions back on the left, so it is possible to work on tensors that have mixed positional and first-class dimensions:

```py
B = torch.rand(3, 4, 5)
B_T = B[i, j].order(j, i)
assert torch.allclose(B.permute(1, 0, 2), B_T)
```

[^4]: `order` is actually just a synonym for the already-existing `permute` method, which takes a list of dimension specifiers and puts the tensor in that order, because Rule #2 says that first-class dims can be passed as arguments to functions that previously took only integers as dimensions. However, the name `permute` is confusing in this context since it implies dim objects have an original order, so we prefer to use `order` when writing code.

Flattening and Splitting Dims
-----------------------------

**Tuples of dimensions** can be passed to both indexing and `order`. In indexing, this will split the dimension being indexed across the dimensions in the tuple. In `order`, it will flatten the dimensions into a single positional dimension:

```py
i, j, k = dims(3)
j.size = 2
A = torch.rand(6, 4)
a = A[(i, j), k] # split dim 0 into i,j
print(i.size, j.size, k.size)
> 3 2 4

r = a.order(i, (j, k)) # flatten j and k
print(r.shape)
> torch.Size([3, 8])
```

The size of one unsized dimension in a tuple, such as `i` above, can be inferred if the other sizes are known.

Examples
========

The usefulness of dimension objects is best seen through examples. Let's look at some different ways they can be used.

Einsum-style Products
---------------------
Rather than having [einsum](https://pytorch.org/docs/stable/generated/torch.einsum.html) as a custom operator, it is possible to express matrix products directly as a composition of multiplies and summations. The implementation will pattern-match any multiplication followed by a summation and use the right matrix-multiply operator.

```py
def mm(A, B):
    i, j, k = dims(3)
    r = (A[i, k] * B[k, j]).sum(k)
    return r.order(i, j)
mm(torch.rand(3, 4), torch.rand(4, 5)).shape
```

The implementation of named tensors delays the execution of the multiply to see if a summation follows it, as it does above. If so, it turns this pattern into the correct _optimized matrix product_, similar to how the `einsum` function works.

Since it is no longer necessary to manually match math to matrix functions, other tensor products are easier to express, like the Gram matrix used in style transfer:

```py
def gram_matrix_new(y):
    b, c, c2, h, w = dims()
    r = (y[b, c, h, w] * y[b, c2, h, w]).sum((h, w))
    r = r / (h.size * w.size)
    return r.order(b, c, c2)

gram_matrix_new(torch.rand(1, 2, 3, 4))
# [example adapted from http://einops.rocks/pytorch-examples.html]
```

Attention is another example that has several matrix products embedded inside it:

```py
from torchdim import softmax
def attention(K, Q, V):
    batch, channel, key, query = dims(4)
    k = K[batch, channel, key]
    q = Q[batch, channel, query]
    v = V[batch, channel, key]

    a = (k * q).sum(channel) # matrix multiply
    a = softmax(a * (channel.size ** -0.5), dim=key)
    r = (v * a).sum(key) # matrix multiply
    return torch.cat((r.order(batch, channel, query), Q), dim=1)

inputs = (torch.rand(2, 3, 4) for _ in range(3))
attention(*inputs)
# [example adapted from http://einops.rocks/pytorch-examples.html]
```

Reshaping tensors (einops)
--------------------------

Lots of operations in deep learning are just different ways of reshaping, splitting, and joining dimensions, such as the pixel shuffle used to upscale an image by turning channels into pixels:

```py
def pixel_shuffle(img, upscale_factor=2):
    h2, w2, c, b, h, w = dims(6)
    h2.size = w2.size = upscale_factor
    return img[b, (c, h2, w2), h, w].order(b, c, (h, h2), (w, w2))
```

[Einops](http://einops.rocks) is an extension to einsum that adds support for the manipulation of dimensions through a few custom operators such as `rearrange`:

```py
def pixel_shuffle_einops(img, upscale_factor=2):
    from einops import rearrange
    return rearrange(img, 'b (c h2 w2) h w -> b c (h h2) (w w2)', h2=upscale_factor, w2=upscale_factor)
```

Named tensors with first-class dimensions can accomplish the same goal, but using PyTorch's existing operator set.
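As a quick sanity check (a sketch with a made-up input shape), both versions should agree with PyTorch's built-in `torch.nn.functional.pixel_shuffle`:

```py
img = torch.rand(2, 8, 5, 5)  # channels = 2 * upscale_factor**2
out = pixel_shuffle(img, upscale_factor=2)
assert out.shape == (2, 2, 10, 10)
assert torch.allclose(out, torch.nn.functional.pixel_shuffle(img, 2))
assert torch.allclose(out, pixel_shuffle_einops(img, 2))
```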
Automatically batching Code (`vmap`, `xmap`)
--------------------------------------------

The implicit batching of Rule #1 means it is easy to create batched versions of existing PyTorch code. Simply bind a dim to the dimensions that should act as a batch, and then pass the tensor to the unbatched function. Since the unbatched function does not know about the dim, the dim will be implicitly batched over:

```py
batch_size, feature_size = 3, 5
weights = torch.randn(feature_size)

def model(feature_vec):
    # Very simple linear model with activation
    assert feature_vec.dim() == 1
    return feature_vec.dot(weights).relu()

examples = torch.randn(batch_size, feature_size)
batch = dims(1)
r = model(examples[batch])
print(r)
# in functorch: result = functorch.vmap(model)(examples)
> tensor([0.4775, 0.0000, 0.3423])
> with dims=(batch,) sizes=(3,)
```

This pattern composes well with other code that also uses first-class dimensions. For instance, we can write a batched matrix multiply `bmm` by batching the `mm` operator.

It doesn't matter whether the implementation of the function uses dimension objects; it is always possible to add additional batch dimensions and then call the function:

```py
def bmm(A, B):
    i = dims(1) # note: i here is a different value from i inside mm so it works
    return mm(A[i], B[i]).order(i)
```

The equivalents in JAX, [xmap or vmap](https://jax.readthedocs.io/en/latest/notebooks/quickstart.html#auto-vectorization-with-vmap), are transforms over functions, so there is a lot of syntactic distance between the specification of the dimension mappings and the values where those mappings apply. Dims express the mapping as indexing of the tensor, right at the place where the function is being applied.

[xmap examples](https://jax.readthedocs.io/en/latest/notebooks/xmap_tutorial.html):

```py
in_axes = [['inputs', 'hidden', ...],
           ['hidden', 'classes', ...],
           ['batch', 'inputs', ...],
           ['batch', ...]]

loss = xmap(named_loss, in_axes=in_axes, out_axes=[...])
print(loss(w1, w2, images, labels))
```

Equivalent with dimension objects:

```py
batch, inputs, hidden, classes = dims(4)
print(loss(w1[inputs, hidden], w2[hidden, classes], images[batch, inputs], labels[batch],
           batch, inputs, hidden, classes))
```
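For concreteness, here is one way such a `loss` function could be written with first-class dimensions. This is a hedged sketch, not the xmap tutorial's `named_loss`: the two-layer model and the cross-entropy-style loss are assumptions made for illustration.

```py
def loss(w1, w2, images, labels, batch, inputs, hidden, classes):
    # w1: (inputs, hidden), w2: (hidden, classes), images: (batch, inputs), labels: (batch,)
    hiddens = (images * w1).sum(inputs).relu()            # matrix product over `inputs`
    logits = (hiddens * w2).sum(hidden)                   # matrix product over `hidden`
    log_probs = logits - logits.exp().sum(classes).log()  # log-softmax over `classes`
    # pick out the log-probability of the correct class with a mask (Rule #3)
    nll = -torch.where(classes == labels, log_probs, 0.).sum(classes)
    return nll.mean(batch)
```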
Composing matrix products, reshaping, and batching
--------------------------------------------------

Multi-headed attention is a good example of how these different uses compose. It reshapes the inputs, splitting out the different attention heads. It batches over those attention heads, and it uses matrix products to compute attention scores.

```py
from torchdim import softmax
def multiheadattention(q, k, v, num_attention_heads, dropout_prob, use_positional_embedding):
    batch, query_sequence, key_sequence, heads, features = dims(5)
    heads.size = num_attention_heads

    # bind the dimensions, unflattening the heads from the feature dimension
    q = q[batch, query_sequence, [heads, features]]
    k = k[batch, key_sequence, [heads, features]]
    v = v[batch, key_sequence, [heads, features]]

    # einsum-style operators to calculate scores
    attention_scores = (q*k).sum(features) * (features.size ** -0.5)

    # use a first-class dim to specify the dimension for softmax
    attention_probs = softmax(attention_scores, dim=key_sequence)

    # dropout works pointwise, following Rule #1
    attention_probs = torch.nn.functional.dropout(attention_probs, p=dropout_prob)

    # another matrix product
    context_layer = (attention_probs*v).sum(key_sequence)

    # flatten the heads back into features
    return context_layer.order(batch, query_sequence, [heads, features])
```
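A minimal usage sketch (the shapes are made up for illustration; the last dimension must be divisible by the number of heads):

```py
q, k, v = (torch.rand(2, 10, 64) for _ in range(3))  # (batch, sequence, heads * features)
out = multiheadattention(q, k, v, num_attention_heads=8, dropout_prob=0.0,
                         use_positional_embedding=False)
print(out.shape)
> torch.Size([2, 10, 64])
```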
Indexing
--------

Rule #3 enables indexing because dimensions act as loop indices when used as a tensor. This allows for a lot of powerful behavior. The simplest might be using the dimensions to compute masks, such as extracting the upper triangular part of a matrix:

```py
from torch import where
def triu(A):
    i, j = dims()
    a = A[i, j]
    return where(i <= j, a, 0).order(i, j)
triu(torch.rand(3, 4))
```

Embedding bag does an embedding table lookup followed by a sum, which can be expressed concisely:

```py
def embedding_bag(input, embedding_weights):
    batch, sequence, features = dims(3)
    r = embedding_weights[input[batch, sequence], features].sum(sequence)
    return r.order(batch, features)

input = torch.tensor([[1, 0, 4, 3]])
W = torch.rand(5, 2)
embedding_bag(input, W)
```

Relative positional embeddings associate an embedding vector with the distance between the query and the key in the sequence.
For instance, a key at position 3 and a query at position 5 will have embedding ID `(5-3)=2`. We can use first-class dimensions to do both the indexing arithmetic and the embedding lookup:

```py
def relative_positional_embedding(q, k, distance_embedding_weight):
    batch, query_sequence, key_sequence, heads, features = dims(5)
    q = q[batch, query_sequence, [heads, features]]
    k = k[batch, key_sequence, [heads, features]]

    distance = query_sequence - key_sequence
    n_embeddings = distance_embedding_weight.size(0)
    index_bias = n_embeddings // 2

    assert key_sequence.size + index_bias <= n_embeddings

    # indexing with dims
    positional_embedding = distance_embedding_weight[distance + index_bias, features]

    # matrix multiplies with dims
    relative_position_scores_query = (q*positional_embedding).sum(features)
    relative_position_scores_key = (k*positional_embedding).sum(features)
    return (relative_position_scores_query + relative_position_scores_key).order(batch, heads, key_sequence, query_sequence)
```

Tensor Puzzlers
===============

[Tensor Puzzlers](https://github.com/srush/Tensor-Puzzles), created by Sasha Rush, are a good exercise for learning the numpy and torch APIs by figuring out how to define common operations using a small set of primitive tensor operations.

However, the difficulty of many of the puzzlers lies not in how to compute the answer but in the awkwardness of the primitives themselves.

**With first-class dimensions, these puzzlers are nearly the same as the spec that defines them.**


### Puzzle 3 - outer

Compute [outer](https://numpy.org/doc/stable/reference/generated/numpy.outer.html) - the outer product of two vectors.

```py
def outer_spec(a, b, out):
    for i in range(len(out)):
        for j in range(len(out[0])):
            out[i][j] = a[i] * b[j]

def outer(a, b):
    i, j = dims(2)
    return (a[i] * b[j]).order(i, j)
```

### Puzzle 4 - diag

Compute [diag](https://numpy.org/doc/stable/reference/generated/numpy.diag.html) - the diagonal vector of a square matrix.

```py
def diag_spec(a, out):
    for i in range(len(a)):
        out[i] = a[i][i]

def diag(a):
    i = dims(1)
    return a[i, i].order(i)
```

### Puzzle 5 - eye

Compute [eye](https://numpy.org/doc/stable/reference/generated/numpy.eye.html) - the identity matrix.

```py
from torch import where
def eye_spec(out):
    for i in range(len(out)):
        out[i][i] = 1

def eye(j: int):
    i, j = dims(sizes=[j, j])
    return where(i == j, 1, 0).order(i, j)
```

### Puzzle 6 - triu

Compute [triu](https://numpy.org/doc/stable/reference/generated/numpy.triu.html) - the upper triangular matrix.

```py
def triu_spec(out):
    for i in range(len(out)):
        for j in range(len(out)):
            if i <= j:
                out[i][j] = 1
            else:
                out[i][j] = 0

def triu(j: int):
    i, j = dims(sizes=[j, j])
    return where(i <= j, 1, 0).order(i, j)
```

### Puzzle 8 - diff

Compute [diff](https://numpy.org/doc/stable/reference/generated/numpy.diff.html) - the running difference.

```py
def diff_spec(a, out):
    out[0] = a[0]
    for i in range(1, len(out)):
        out[i] = a[i] - a[i - 1]

def diff(a, i: int):
    i = dims(1)
    d = a[i] - a[i - 1]
    return where(i - 1 >= 0, d, a[i]).order(i)
```

### Puzzle 9 - vstack

Compute [vstack](https://numpy.org/doc/stable/reference/generated/numpy.vstack.html) - the matrix of two vectors.

```py
def vstack_spec(a, b, out):
    for i in range(len(out[0])):
        out[0][i] = a[i]
        out[1][i] = b[i]

def vstack(a, b):
    v, i = dims(sizes=[2, None])
    return where(v == 0, a[i], b[i]).order(v, i)
```

### Puzzle 10 - roll

Compute [roll](https://numpy.org/doc/stable/reference/generated/numpy.roll.html) - the vector shifted 1 circular position.

```py
def roll_spec(a, out):
    for i in range(len(out)):
        if i + 1 < len(out):
            out[i] = a[i + 1]
        else:
            out[i] = a[i + 1 - len(out)]

def roll(a, i: int):
    i = dims(sizes=[a.size(0)])
    return a[where(i + 1 < i.size, i + 1, 0)].order(i)
```

### Puzzle 11 - flip

Compute [flip](https://numpy.org/doc/stable/reference/generated/numpy.flip.html) - the reversed vector.

```py
def flip_spec(a, out):
    for i in range(len(out)):
        out[i] = a[len(out) - i - 1]

def flip(a, i: int):
    i = dims(sizes=[a.size(0)])
    return a[i.size - i - 1].order(i)
```
### Puzzle 14 - sequence_mask

Compute [sequence_mask](https://www.tensorflow.org/api_docs/python/tf/sequence_mask) - pad out to length per batch.

```py
def sequence_mask_spec(values, length, out):
    for i in range(len(out)):
        for j in range(len(out[0])):
            if j < length[i]:
                out[i][j] = values[i][j]
            else:
                out[i][j] = 0

def sequence_mask(values, length):
    j, i = dims()
    v = values[i, j]
    return where(j < length[i], v, 0).order(i, j)
```
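A few quick checks of the puzzle solutions above against the corresponding torch functions (illustrative only; the inputs are made up):

```py
a, b = torch.arange(5), torch.arange(5, 10)
assert torch.equal(flip(a, 5), torch.flip(a, (0,)))
assert torch.equal(roll(a, 5), torch.roll(a, -1))
assert torch.equal(vstack(a, b), torch.vstack((a, b)))
```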
Advantages of First-class Dimensions over String Dimensions
============================================================

The most prominent difference between named tensors using first-class dimensions and the alternatives (einops, named tensors as implemented in PyTorch today, [tensors considered harmful](https://nlp.seas.harvard.edu/NamedTensor), or xmap) is that dimensions are objects rather than strings. Using objects has a number of nice properties.

### Avoiding naming conflicts

Using strings for dimensions introduces the possibility that two unrelated dimensions are given the same name. Using objects instead makes it clear that two dimensions with the same name are not necessarily the same dimension. It's like the difference between having only global variables and having the ability to locally bind names in functions.
For instance, we defined `bmm` by batching a call to `mm`, and both use the name `i` to identify a dimension. Because each `i` is a different object, there is no naming conflict:

```py
def mm(A, B):
    i, j, k = dims()
    r = (A[i, k] * B[k, j]).sum(k)
    return r.order(i, j)

def bmm(A, B):
    i = dims() # note: it doesn't matter that mm internally also uses i
    return mm(A[i], B[i])
```

Einops avoids conflicts by ensuring names are all introduced and removed in a single expression, but this precludes using long-lived dimensions to provide implicit batching similar to xmap. When nested, JAX's xmap seems to consider axes the same if the string name matches. In the above example, it would consider the `i` dimension to be the same dimension in both `bmm` and `mm`, so the code would error.


### Reuse the same operator set

Having a new object type allows us to extend the existing operator set of PyTorch rather than come up with new operators. For instance, binding dimensions using indexing follows semantically from Rules #1 and #3, so there is no need for a special operator to do binding. Even unbinding is just the `permute` operator, which follows from Rule #2, though we call it `order` for clarity. In contrast, using strings requires coming up with new APIs such as `einsum` for matrix multiplies, or `rearrange` for doing permutations.

### Allows dims to act as tensors

Rule #3 is not possible with strings since we cannot make strings behave as tensors. Without this rule, all of the indirect indexing that dims enable would not be easy to express.

### Dims can have methods
For instance, as objects, dims can have a size, which allows us to infer the sizes of dimensions in various places in the API where string-based APIs would have to take additional arguments specifying size.


Comparison to tensor compilers or languages (e.g. TVM or Dex)
=============================================================

The semantics and surface syntax of dimension objects resemble the kind of code written in tensor compilers such as [Halide](https://halide-lang.org), [TVM](https://tvm.apache.org), and [Tensor Comprehensions](https://github.com/facebookresearch/TensorComprehensions), or in the language [Dex](https://github.com/google-research/dex-lang).

These compilers and languages have syntax and semantics that resemble the loop-level analogy used by first-class dimensions. However, as compilers or statically typed languages, they require some binding code to go from running deep learning framework code in Python to using the compiled language. This often requires at least refactoring the compiled parts into their own functions, and may require defining a gradient function. Similar to graph-mode frameworks, this adds friction to using and debugging the code.

Dimension objects are just an extension of the existing PyTorch tensors and eager semantics, so there is no friction switching between normal Python code and code that uses them. However, since the loops over the dimensions are defined implicitly, they can still execute in Python with good performance compared to explicit loops. Furthermore, with dimension objects, a tensor containing dimensions can flow through code that is oblivious to those dimensions, such as the batching examples above. There is no need to separate code into 'compiled' vs 'eager'.

In this way, first-class dims are a way of adapting the nicer syntax of these array compilers and languages to eager numpy-style libraries.


Performance Expectations
========================
First-class dimensions are not a compiler. They provide syntax for existing PyTorch operations, such as advanced indexing, that is easier to read and write. For large tensors, the performance of statements using them will be the same as using the already-existing operations. An important exception is the pattern matching of products and summations, where performance is improved by dispatching to a matrix-multiply kernel. The C++ implementation of dimensions adds a small overhead of around 2us on top of PyTorch's normal overhead of 8us for each function that uses them. In the future, the implementation can incorporate more fusion optimizations to further improve the performance of this style of code.


## License
Functorch has a BSD-style license, as found in the [LICENSE](LICENSE) file.