Named Tensors using First-class Dimensions in PyTorch
=====================================================

-- Zachary DeVito [@Zachary_DeVito](https://twitter.com/Zachary_DeVito)

_An implementation of [named tensors](https://namedtensor.github.io) with the functionality of [einsum](http://einops.rocks), batching ([vmap](https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap), [xmap](https://jax.readthedocs.io/en/latest/notebooks/xmap_tutorial.html)), and tensor indexing by adding dimension objects to PyTorch_.

The tensor input to a resnet might have the shape [8, 3, 224, 224] but informally we think of those dimensions as 'batch', 'channel', 'width', and 'height'. Even though 'width' and 'height' have the same _size_ we still think of them as separate dimensions, and if we have two _different_ images, we think of both as sharing the _same_ 'channel' dimension.

Named tensors give these dimensions names. [PyTorch's current implementation](https://pytorch.org/docs/stable/named_tensor.html) uses strings to name dimensions. Instead, this library introduces a Python object, a `Dim`, to represent the concept. By expanding the semantics of tensors with dim objects, in addition to naming dimensions, we can get behavior equivalent to batching transforms (xmap, vmap), einops-style rearrangement, and loop-style tensor indexing.

A preview:

```py
import torch
from torchdim import dims

# einsum
def mm(A: torch.Tensor, B: torch.Tensor):
    i, j, k = dims(3)
    r = (A[i, k] * B[k, j]).sum(k)
    return r.order(i, j)

# rearrange
def pixel_shuffle(img: torch.Tensor, upscale_factor=2):
    h2, w2, c, b, h, w = dims(6)
    h2.size = w2.size = upscale_factor
    return img[b, (c, h2, w2), h, w].order(b, c, (h, h2), (w, w2))

# batching
def bmm(A: torch.Tensor, B: torch.Tensor):
    i = dims(1)
    return mm(A[i], B[i]).order(i)

# indexing
def embedding_bag(input: torch.Tensor, embedding_weights: torch.Tensor):
    batch, sequence, features = dims(3)
    r = embedding_weights[input[batch, sequence], features].sum(sequence)
    return r.order(batch, features)
```

Installation
============


_torchdim is a preview release so that we can collect feedback on the API. It may have bugs, and there are known places where performance can be improved._

First-class dims are a library that extends PyTorch, so they need to be installed separately.
We may eventually upstream them into PyTorch itself along with `functorch`.


We have to install a nightly build of PyTorch, so first set up an environment:

```sh
conda create --name dim
conda activate dim
```

First-class dims require a fairly recent nightly build of PyTorch so that functorch will work. You can install it using one of these commands:

```sh
# For CUDA 10.2
conda install pytorch torchvision torchaudio cudatoolkit=10.2 -c pytorch-nightly
# For CUDA 11.3
conda install pytorch torchvision torchaudio cudatoolkit=11.3 -c pytorch-nightly
# For CPU-only build
conda install pytorch torchvision torchaudio cpuonly -c pytorch-nightly
```

Install torchdim. You will be asked for GitHub credentials to access the fairinternal organization.

```sh
pip install ninja  # Makes the build go faster
pip install --user "git+https://github.com/facebookresearch/torchdim"
```

Creating and Binding Dims
=========================

Python objects that represent dimensions are created using the `dims` operator.[^1]

```py
import torch
from torchdim import dims

batch, channel, width, height = dims(4)
```

The existing implementation of [Named Tensors](https://pytorch.org/docs/stable/named_tensor.html) in PyTorch and [JAX's xmap](https://jax.readthedocs.io/en/latest/notebooks/xmap_tutorial.html) use strings to name dimensions. We call our dimensions _first class_ because they are Python objects.

In addition to the normal _positional_ dimensions in a tensor, tensors can also have a separate set of first-class dimensions.

You can create tensors with first-class dimensions by indexing the normal positional dimensions of a tensor with a dimension object. The `ndim` property continues to list the number of positional dimensions, while the new `dims` property lists all the bound first-class dimensions.

```py
input = torch.rand(2, 3, 224, 224)
print(input.ndim)
> 4

input_fc = input[batch, channel, width, height]
print(input_fc.dims) # first class dimensions
> (batch, channel, width, height)


# since we converted all the positional dimensions to
# first-class ones, `input_fc` now has 0 positional dimensions.
print(input_fc.ndim)
> 0
```

Notice that indexing creates a _new_ Tensor, `input_fc`, with bound first-class dimensions. It does not modify the original tensor `input`, which still has 4 positional dimensions.

```py
print(input.ndim) # unchanged
> 4
```

Importantly, indexing with square brackets _applies only to positional dimensions_, so attempting to index a tensor with only first-class dims will error[^2]:

```py
try:
    input_fc[0]
except ValueError as ve:
    print(ve)
> at least 1 indices were supplied but the tensor only has 0 dimensions
```
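
As footnote [^2] below notes, entries along a first-class dimension can instead be selected with the `index` method. A minimal sketch, reusing `input_fc` and `batch` from above:

```py
# select entry 0 along the first-class `batch` dimension (see footnote 2)
first_example = input_fc.index(batch, 0)
```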

Generally, it is possible to construct tensors with a mixture of positional and first-class dimensions:

```py
input_mixed = input[batch, :, :, height]
print(input_mixed.dims)
> (batch, height)

print(input_mixed.ndim)
> 2
```

Dimension Sizes
---------------

Dimensions will take on the size of the first thing they are bound to:

```py
input = torch.rand(3)
x = dims(1)
input_fc = input[x]
print(x.size)
> 3
```

But you can also directly set the size of a dimension:

```py
i = dims(1)

i.size = 5 # ok, i previously did not have a size

i.size = 5 # ok, it already had the size 5
try:
    i.size = 3
except Exception as e:
    print(e)
> Dim 'i' previously bound to a dimension of size 5 cannot bind to a dimension of size 3

j = dims(sizes=[4]) # can also be set on construction
```

[^1]: We use a bit of Python introspection to set the debug names for the dimensions based on the names of the variables they are assigned to.
[^2]: Indexing of first-class dimensions can be done with the `index` method by specifying the dimension to be indexed into (e.g. `input_fc.index(batch, 0)`).

Semantics of Dimensions
=======================
The power of named tensors arises from how the first-class dimensions in a Tensor compose with existing operations.

Three rules define how dimension objects behave with existing Tensors.

Rule 1: Implicit Batching
-------------------------
**Tensor operations (e.g. `input + bias`) are implicitly batched over the union of the first-class dimensions in their inputs.**

If `input` has dimensions `batch, channel` and `bias` has dimension `channel`, the output will have the union of those dimensions (`batch, channel`), and the result will be computed as if there was a loop over all the first-class dimensions.[^3]

```py
input_positional = torch.rand(128, 32)
bias_positional = torch.rand(32)

batch, channel = dims(2)
input = input_positional[batch, channel]
bias = bias_positional[channel]

result = input + bias
print(result.dims)
> (batch, channel)
```

It is helpful to think of operators on tensors with first-class dimensions by analogy to code with explicit loops over dimensions, with the first-class dimensions of the inputs acting as implicit `for` loops, and the values in the tensor being scalars within the body of the loop:

```py
# mental model: loop-level analogy
for batch in range(batch.size):
    for channel in range(channel.size):
        input = input_positional[batch, channel]
        bias = bias_positional[channel]
        result[batch, channel] = input + bias # arithmetic on scalars
```

Positional dimensions behave as they did before (e.g. for + they will broadcast), and can be thought of as being a standard tensor _used within the implicit loops_ defined by first-class dimensions.
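
For instance (a minimal sketch; the names `b`, `examples`, and `row` are new and chosen for illustration), a positional tensor broadcasts as usual inside the implicit loop over a first-class dimension:

```py
b = dims(1)
examples = torch.rand(4, 3, 5)
row = torch.rand(5)

# for each value of b, a (3, 5) positional tensor broadcasts with a (5,) positional tensor
r = examples[b] + row
print(r.dims)
> (b,)
print(r.ndim)
> 2
```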

In this example, we broke the expression into separate lines that bind the dimensions to the positional tensors, and then another line to do the compute. In practice, we often combine these in one statement:

```py
result = input_positional[batch, channel] + bias_positional[channel]
result.dims
```

[^3]: This rule is similar to how named dimensions in xmap behave within a function, but instead of introducing the dimensions via a functional transform, they are bound on the objects using indexing.


Rule 2: Specifying dimensions
-----------------------------
**Wherever an integer is used to specify a dimension in an existing torch operator, a first-class dimension can be used instead to tell the operator to work over that dimension.**

```py
batch, channel, width, height = dims(4)
input_positional = torch.rand(2, 3, 224, 224)
input = input_positional[batch, channel, width, height]
avg_pixel_color = input.mean((width, height))

print(avg_pixel_color.dims)
> (batch, channel)
```

Any other first-class dimensions (e.g. batch, channel) are still implicitly batched according to Rule #1.

Rule 3: Dims are Tensors
------------------------
**A first-class dimension `d` can be used wherever a Tensor is expected. It will act as if it were a tensor whose only dimension is itself, `d`, and the values along the dimension are the indices of each entry `(0, 1, 2, ..., d.size - 1)`.**

```py
print(channel.dims)
> (channel,)

print(channel + 1000)
> tensor([1000, 1001, 1002])
> with dims=(channel,) sizes=(3,)
```

This means that a dimension used as a tensor acts as an index into that dimension. Going back to our loop-level analogy, it is analogous to using the loop variable as a value:

```py
# mental model: loop-level analogy
for channel in range(channel.size):
    result[channel] = channel + 1000
```

Arithmetic using dimension indices comes up a lot, such as the mask for an upper triangular part of a matrix. Using dims as tensors makes it easy:

```py
from torchdim import dims
i, j = dims(sizes=[4, 4])
print(i <= j)
> tensor([[ True,  True,  True,  True],
>         [False,  True,  True,  True],
>         [False, False,  True,  True],
>         [False, False, False,  True]])
> with dims=(i, j) sizes=(4, 4)
```

Because of the intentional similarity to loop-level code, using dimensions as tensors makes complicated indexing arithmetic easier to read.

Here is code that looks up features in an embedding table given a sequence of ids:

```py
sequence, features = dims(2)
embeddings = torch.rand(8, 128)
words = torch.tensor([5, 4, 0,])

state = embeddings[words[sequence], features]
print(state.dims)
> (sequence, features)
```

With the following analogy to loops:

```py
# mental model: loop-level analogy

for sequence in range(words.size(0)):
    for features in range(embeddings.size(1)):
        state = embeddings[words[sequence], features]
```

Earlier we showed how binding a tensor's dimensions is done with indexing `A[i, j]`. In fact, this binding is just the normal indexing operator. Its behavior follows directly from the behavior of indexing with tensor indices combined with Rule #3 and Rule #1. The expression `A[i + 1, j]` also creates a tensor with dimensions `i` and `j` but with different indexing math. The implementation knows when simple indexing patterns are used and only actually runs a kernel to do indexing when needed.
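
A minimal sketch of this indexing arithmetic (sizes and values chosen for illustration):

```py
i, j = dims(sizes=[3, 4])
A = torch.arange(16).reshape(4, 4)
shifted = A[i + 1, j]    # row r of the result reads row r + 1 of A
assert torch.equal(shifted.order(i, j), A[1:])
```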

Unbinding Dims
-------------
The `order` method converts first-class dimensions in a tensor back to normal positional dimensions by specifying an order for those dimensions.[^4]

By specifying a different order from how things were originally bound, it is easy to do transpositions.

```py
i, j = dims(2)
A = torch.rand(3, 4)
A_T = A[i, j].order(j, i)
assert torch.allclose(A.T, A_T)
```

Indexing acts left-to-right, and `order` also places the new dimensions back on the left, so it is possible to work on tensors that have mixed positional and first-class dimensions:

```py
B = torch.rand(3, 4, 5)
B_T = B[i, j].order(j, i)
assert torch.allclose(B.permute(1, 0, 2), B_T)
```

[^4]: `order` is actually just a synonym for the already-existing `permute` method, which takes a list of dimension specifiers and puts the tensor in that order, because Rule #2 says that first-class dims can be passed as arguments to functions that previously took only integers as dimensions. However, the name `permute` is confusing in this context since it implies dim objects have an original order, so we prefer to use `order` when writing code.

Flattening and Splitting Dims
-----------------------------

**Tuples of dimensions** can be passed to both indexing and `order`. In indexing, this will split the dimension being indexed across the dimensions in the tuple. In `order` it will flatten the dimensions into a single positional dimension:

```py
i, j, k = dims(3)
j.size = 2
A = torch.rand(6, 4)
a = A[(i, j), k] # split dim 0 into i,j
print(i.size, j.size, k.size)
> 3 2 4

r = a.order(i, (j, k)) # flatten j and k
print(r.shape)
> torch.Size([3, 8])
```

The size of one unsized dimension in a tuple such as `i` can be inferred if the other sizes are known.
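
A minimal sketch of this inference, leaving `i` unsized at construction (values chosen for illustration):

```py
i, j = dims(sizes=[None, 2])
A = torch.rand(6, 4)
a = A[(i, j), :]   # i's size is inferred as 6 // 2
print(i.size)
> 3
```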

Examples
========

The usefulness of dimension objects is best seen through examples. Let's look at some different ways they can be used.

Einsum-style Products
---------------------
Rather than having [einsum](https://pytorch.org/docs/stable/generated/torch.einsum.html) as a custom operator, it is possible to express matrix products directly as a composition of multiplies and summations. The implementation will pattern match any multiplication followed by a sum to the right matrix-multiply operator.

```py
def mm(A, B):
    i, j, k = dims(3)
    r = (A[i, k] * B[k, j]).sum(k)
    return r.order(i, j)
mm(torch.rand(3, 4), torch.rand(4, 5)).shape
```

The implementation of named tensors delays the execution of the multiply to see whether a summation follows it, as it does above. If so, it will turn this pattern into the correct _optimized matrix product_, similar to how the `einsum` function works.
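
As a quick sanity check (assuming the `mm` defined above is in scope), the result agrees with a plain matrix product:

```py
A, B = torch.rand(3, 4), torch.rand(4, 5)
assert torch.allclose(mm(A, B), A @ B, atol=1e-6)
```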

Since it is no longer necessary to manually match math to matrix functions, other tensor products are easier to express, like the Gram matrix used in style transfer:

```py
def gram_matrix_new(y):
    b, c, c2, h, w = dims()
    r = (y[b, c, h, w] * y[b, c2, h, w]).sum((h, w))
    r = r / (h.size * w.size)
    return r.order(b, c, c2)

gram_matrix_new(torch.rand(1, 2, 3, 4))
# [example adapted from http://einops.rocks/pytorch-examples.html]
```

Attention is another example that has several matrix products embedded inside it:

```py
from torchdim import softmax
def attention(K, Q, V):
    batch, channel, key, query = dims(4)
    k = K[batch, channel, key]
    q = Q[batch, channel, query]
    v = V[batch, channel, key]

    a = (k * q).sum(channel) # matrix multiply
    a = softmax(a * (channel.size ** -0.5), dim=key)
    r = (v * a).sum(key) # matrix multiply
    return torch.cat((r.order(batch, channel, query), Q), dim=1)

inputs = (torch.rand(2, 3, 4) for _ in range(3))
attention(*inputs)
# [example adapted from http://einops.rocks/pytorch-examples.html]
```

Reshaping tensors (einops)
--------------------------

Lots of operations in deep learning are just different ways of reshaping, splitting, and joining dimensions, such as the pixel shuffle used to upscale an image by turning channels into pixels:

```py
def pixel_shuffle(img, upscale_factor=2):
    h2, w2, c, b, h, w = dims(6)
    h2.size = w2.size = upscale_factor
    return img[b, (c, h2, w2), h, w].order(b, c, (h, h2), (w, w2))
```

[Einops](http://einops.rocks) is an extension to einsum that adds support for the manipulation of dimensions through a few custom operators such as `rearrange`:

```py
def pixel_shuffle_einops(img, upscale_factor=2):
    from einops import rearrange
    return rearrange(img, 'b (c h2 w2) h w -> b c (h h2) (w w2)', h2=upscale_factor, w2=upscale_factor)
```

Named tensors with first-class dimensions can accomplish the same goal, but using PyTorch's existing operator set.
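
As a sanity check (assuming the `pixel_shuffle` above uses the same channel ordering as PyTorch's built-in), it can be compared against `torch.nn.functional.pixel_shuffle`:

```py
img = torch.rand(2, 8, 5, 5)   # 8 channels = 2 output channels * upscale_factor**2
assert torch.allclose(pixel_shuffle(img), torch.nn.functional.pixel_shuffle(img, 2))
```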

Automatically batching Code (`vmap`, `xmap`)
-----------------------------

The implicit batching of Rule #1 means it is easy to create batched versions of existing PyTorch code. Simply bind a dim to the dimensions that should act as a batch, and then pass the tensor to the unbatched function. Since the unbatched function does not know about the dim, the dim will be implicitly batched over:

```py
batch_size, feature_size = 3, 5
weights = torch.randn(feature_size)

def model(feature_vec):
    # Very simple linear model with activation
    assert feature_vec.dim() == 1
    return feature_vec.dot(weights).relu()

examples = torch.randn(batch_size, feature_size)
batch = dims(1)
r = model(examples[batch])
print(r)
# in functorch: result = functorch.vmap(model)(examples)
> tensor([0.4775, 0.0000, 0.3423])
> with dims=(batch,) sizes=(3,)
```

This pattern also composes well with other code that uses first-class dimensions. For instance, we can write a batched matrix multiply, `bmm`, by batching the `mm` operator.

It doesn't matter whether the implementation of the function uses dimension objects; it is also possible to add additional batch dimensions and then call a function:

```py
def bmm(A, B):
    i = dims(1) # note: i here is a different value from i inside mm so it works
    return mm(A[i], B[i]).order(i)
```
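
A quick sanity check (assuming the `mm` and `bmm` above are in scope) against `torch.bmm`:

```py
A, B = torch.rand(5, 3, 4), torch.rand(5, 4, 5)
assert torch.allclose(bmm(A, B), torch.bmm(A, B), atol=1e-6)
```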

The equivalent code in JAX, using [xmap or vmap](https://jax.readthedocs.io/en/latest/notebooks/quickstart.html#auto-vectorization-with-vmap), relies on transforms over functions, so there is a lot of syntactic distance between the specification of the dimension mappings and the values where those mappings apply. Dims express the mapping as indexing of the tensor, right at the place where the function is being applied.


[xmap examples](https://jax.readthedocs.io/en/latest/notebooks/xmap_tutorial.html):

```py
in_axes = [['inputs', 'hidden', ...],
           ['hidden', 'classes', ...],
           ['batch', 'inputs', ...],
           ['batch', ...]]

loss = xmap(named_loss, in_axes=in_axes, out_axes=[...])
print(loss(w1, w2, images, labels))
```

Equivalent with dimension objects:

```py
batch, inputs, hidden, classes = dims(4)
print(loss(w1[inputs, hidden], w2[hidden, classes], images[batch, inputs], labels[batch],
      batch, inputs, hidden, classes))
```


Composing matrix products, reshaping, and batching
---------------------

Multi-headed attention is a good example of how these different uses compose. It reshapes the inputs, splitting out different attention heads. It batches over those attention heads, and it uses matrix products to compute attention scores.

```py
from torchdim import softmax
def multiheadattention(q, k, v, num_attention_heads, dropout_prob, use_positional_embedding):
    batch, query_sequence, key_sequence, heads, features = dims(5)
    heads.size = num_attention_heads

    # binding dimensions, and unflattening the heads from the feature dimension
    q = q[batch, query_sequence, [heads, features]]
    k = k[batch, key_sequence, [heads, features]]
    v = v[batch, key_sequence, [heads, features]]

    # einsum-style operators to calculate scores
    attention_scores = (q*k).sum(features) * (features.size ** -0.5)

    # use first-class dim to specify dimension for softmax
    attention_probs = softmax(attention_scores, dim=key_sequence)

    # dropout works pointwise, following Rule #1
    attention_probs = torch.nn.functional.dropout(attention_probs, p=dropout_prob)

    # another matrix product
    context_layer = (attention_probs*v).sum(key_sequence)

    # flatten heads back into features
    return context_layer.order(batch, query_sequence, [heads, features])
```
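
A usage sketch with hypothetical shapes (the last input dimension is `num_attention_heads * features`):

```py
q = k = v = torch.rand(2, 7, 16)   # (batch, sequence, heads * features)
out = multiheadattention(q, k, v, num_attention_heads=4, dropout_prob=0.0,
                         use_positional_embedding=False)
print(out.shape)
> torch.Size([2, 7, 16])
```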

Indexing
--------

Rule #3 enables indexing because dimensions act as loop indices when used as a tensor. This allows for a lot of powerful behavior. The simplest might be using the dimensions to compute masks, such as extracting the upper triangular part of a matrix:

```py
from torch import where
def triu(A):
    i, j = dims()
    a = A[i, j]
    return where(i <= j, a, 0).order(i, j)
triu(torch.rand(3, 4))
```

Embedding bag does an embedding table lookup followed by a sum, which can be expressed concisely:

```py
def embedding_bag(input, embedding_weights):
    batch, sequence, features = dims(3)
    r = embedding_weights[input[batch, sequence], features].sum(sequence)
    return r.order(batch, features)

input = torch.tensor([[1, 0, 4, 3]])
W = torch.rand(5, 2)
embedding_bag(input, W)
```
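
A quick check (using the `input` and `W` defined above) against PyTorch's built-in `embedding_bag` in sum mode:

```py
expected = torch.nn.functional.embedding_bag(input, W, mode='sum')
assert torch.allclose(embedding_bag(input, W), expected)
```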

Relative positional embeddings associate an embedding vector with the distance between the query and the key in the sequence.
For instance, a key 3 and query 5 will have embedding ID `(5-3)=2`. We can use first-class dimensions to do the indexing arithmetic, and the embedding lookup:

```py
def relative_positional_embedding(q, k, distance_embedding_weight):
    batch, query_sequence, key_sequence, heads, features = dims(5)
    q = q[batch, query_sequence, [heads, features]]
    k = k[batch, key_sequence, [heads, features]]

    distance = query_sequence - key_sequence
    n_embeddings = distance_embedding_weight.size(0)
    index_bias = n_embeddings // 2

    assert key_sequence.size + index_bias <= n_embeddings

    # indexing with dims
    positional_embedding = distance_embedding_weight[distance + index_bias, features]

    # matrix multiplies with dims
    relative_position_scores_query = (q*positional_embedding).sum(features)
    relative_position_scores_key = (k*positional_embedding).sum(features)
    return (relative_position_scores_query + relative_position_scores_key).order(batch, heads, key_sequence, query_sequence)
```

Tensor Puzzlers
===============

[Tensor Puzzlers](https://github.com/srush/Tensor-Puzzles), created by Sasha Rush, are a good exercise for learning the numpy and torch APIs by figuring out how to define common operations using a small set of primitive tensor operations.

However, the difficulty of many of the puzzlers lies not in how to compute the answer but in the awkwardness of the primitives themselves.

**With first-class dimensions, these puzzlers are nearly the same as the spec that defines them.**


### Puzzle 3 - outer

Compute [outer](https://numpy.org/doc/stable/reference/generated/numpy.outer.html) - the outer product of two vectors.

```py
def outer_spec(a, b, out):
    for i in range(len(out)):
        for j in range(len(out[0])):
            out[i][j] = a[i] * b[j]

def outer(a, b):
    i, j = dims(2)
    return (a[i] * b[j]).order(i, j)
```
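
A quick check against `torch.outer`:

```py
a, b = torch.rand(3), torch.rand(4)
assert torch.allclose(outer(a, b), torch.outer(a, b))
```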

### Puzzle 4 - diag

Compute [diag](https://numpy.org/doc/stable/reference/generated/numpy.diag.html) - the diagonal vector of a square matrix.

```py
def diag_spec(a, out):
    for i in range(len(a)):
        out[i] = a[i][i]

def diag(a):
    i = dims(1)
    return a[i, i].order(i)
```

### Puzzle 5 - eye

Compute [eye](https://numpy.org/doc/stable/reference/generated/numpy.eye.html) - the identity matrix.

```py
from torch import where
def eye_spec(out):
    for i in range(len(out)):
        out[i][i] = 1

def eye(j: int):
    i, j = dims(sizes=[j, j])
    return where(i == j, 1, 0).order(i, j)
```

### Puzzle 6 - triu

Compute [triu](https://numpy.org/doc/stable/reference/generated/numpy.triu.html) - the upper triangular matrix.

```py
def triu_spec(out):
    for i in range(len(out)):
        for j in range(len(out)):
            if i <= j:
                out[i][j] = 1
            else:
                out[i][j] = 0

def triu(j: int):
    i, j = dims(sizes=[j, j])
    return where(i <= j, 1, 0).order(i, j)
```

### Puzzle 8 - diff

Compute [diff](https://numpy.org/doc/stable/reference/generated/numpy.diff.html) - the running difference.

```py
def diff_spec(a, out):
    out[0] = a[0]
    for i in range(1, len(out)):
        out[i] = a[i] - a[i - 1]

def diff(a, i: int):
    i = dims(1)
    d = a[i] - a[i - 1]
    return where(i - 1 >= 0, d, a[i]).order(i)
```

### Puzzle 9 - vstack

Compute [vstack](https://numpy.org/doc/stable/reference/generated/numpy.vstack.html) - the matrix of two vectors.

```py
def vstack_spec(a, b, out):
    for i in range(len(out[0])):
        out[0][i] = a[i]
        out[1][i] = b[i]

def vstack(a, b):
    v, i = dims(sizes=[2, None])
    return where(v == 0, a[i], b[i]).order(v, i)
```

### Puzzle 10 - roll

Compute [roll](https://numpy.org/doc/stable/reference/generated/numpy.roll.html) - the vector shifted 1 circular position.

```py
def roll_spec(a, out):
    for i in range(len(out)):
        if i + 1 < len(out):
            out[i] = a[i + 1]
        else:
            out[i] = a[i + 1 - len(out)]

def roll(a, i: int):
    i = dims(sizes=[a.size(0)])
    return a[where(i + 1 < i.size, i + 1, 0)].order(i)
```

### Puzzle 11 - flip

Compute [flip](https://numpy.org/doc/stable/reference/generated/numpy.flip.html) - the reversed vector.

```py
def flip_spec(a, out):
    for i in range(len(out)):
        out[i] = a[len(out) - i - 1]

def flip(a, i: int):
    i = dims(sizes=[a.size(0)])
    return a[i.size - i - 1].order(i)
```

### Puzzle 14 - sequence_mask


Compute [sequence_mask](https://www.tensorflow.org/api_docs/python/tf/sequence_mask) - pad out to length per batch.

```py
def sequence_mask_spec(values, length, out):
    for i in range(len(out)):
        for j in range(len(out[0])):
            if j < length[i]:
                out[i][j] = values[i][j]
            else:
                out[i][j] = 0

def sequence_mask(values, length):
    j, i = dims()
    v = values[i, j]
    return where(j < length[i], v, 0).order(i, j)
```

Advantages of First-class Dimensions over String Dimensions
===================================================================

The most prominent difference between named tensors using first-class dimensions and alternatives (einops, named tensors implemented in PyTorch today, [tensors considered harmful](https://nlp.seas.harvard.edu/NamedTensor), or xmap) is that dimensions are objects rather than strings. Using objects has a number of nice properties.

### Avoiding naming conflicts

Using strings for dimensions introduces the possibility that two unrelated dimensions are given the same name. Using objects instead makes it clear that dimensions with the same name are not necessarily the same dimension. It's like the difference between having only global variables, and having the ability to locally bind names in functions.
For instance, we defined `bmm` by batching a call to `mm`, and even though they both use the name `i` to identify a dimension, there is no naming conflict because each `i` is a different object:

```py
def mm(A, B):
    i, j, k = dims()
    r = (A[i, k] * B[k, j]).sum(k)
    return r.order(i, j)

def bmm(A, B):
    i = dims() # note: doesn't matter that mm internally also uses i
    return mm(A[i], B[i])
```

Einops avoids conflicts by ensuring names are all introduced and removed in a single expression, but this precludes using long-lived dimensions to provide implicit batching similar to xmap. When nested, JAX's xmap seems to consider axes the same if the string name matches. In the above example it would consider the `i` dimension to be the same dimension in both `bmm` and `mm`, so the code would error.


### Reuse the same operator set

Having a new object type allows us to extend the existing operator set of PyTorch rather than come up with new operators. For instance, binding dimensions using indexing follows semantically from Rules #1 and #3, so there is no need for a special operator to do binding. Even unbinding is just the `permute` operator which follows from Rule #2, though we call it `order` for clarity. In contrast, using strings requires coming up with new APIs such as `einsum` for matrix multiplies, or `rearrange` for doing permutations.

### Allows dims to act as tensors

Rule #3 is not possible with strings since we cannot make strings behave as tensors. Without this rule, all of the indirect indexing that dims enable would not be easy to express.

### Dims can have methods
For instance, as objects, dims can have a size, which allows us to do size inference of dimensions in various places in the API where string-based APIs would have to take additional arguments specifying size.


Comparison to tensor compilers or languages (e.g. TVM or Dex)
=============================================================

The semantics and surface syntax of dimension objects resemble the kind of code written in tensor compilers such as [Halide](https://halide-lang.org), [TVM](https://tvm.apache.org), [Tensor Comprehensions](https://github.com/facebookresearch/TensorComprehensions), or the language [Dex](https://github.com/google-research/dex-lang).

These compilers and languages have syntax and semantics that resemble the loop-level analogy of first-class dimensions. However, as compilers or statically typed languages, they require some binding code to go from running deep learning framework code in Python to using the compiled language. This often requires at least refactoring the compiled parts into their own functions, and may require defining a gradient function. Similar to graph mode frameworks, this adds friction to using and debugging the code.

Dimension objects are just an extension of the existing PyTorch tensors and eager semantics, so there is no friction switching between normal Python code and code that uses them. However, since loops over the dimensions are defined implicitly, they can still execute in Python with good performance compared to explicit loops. Furthermore, with dimension objects, a tensor containing dimensions can flow through code that is oblivious to the dimensions, such as the batching examples above. There is no need to separate code into 'compiled' vs 'eager'.

In this way, first-class dims are a way of adapting the nicer syntax of these array compilers and languages to eager numpy-style libraries.


Performance Expectations
========================
First-class dimensions are not a compiler. They provide syntax for existing PyTorch operations, such as advanced indexing, that is easier to read and write. For large tensors, the performance of any statement using them will be the same as using the already-existing operations. An important exception is the pattern matching of products and summation, where performance is improved by dispatching to a matrix-multiply kernel. The C++ implementation of dimensions adds a small overhead of around 2us on top of PyTorch's normal overhead of 8us to each function that uses them. In the future, the implementation can incorporate more fusion optimization to further improve performance of this style of code.


## License
Functorch has a BSD-style license, as found in the [LICENSE](LICENSE) file.
760