Named Tensors using First-class Dimensions in PyTorch
=====================================================

-- Zachary DeVito [@Zachary_DeVito](https://twitter.com/Zachary_DeVito)

_An implementation of [named tensors](https://namedtensor.github.io) with the functionality of [einsum](http://einops.rocks), batching ([vmap](https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap), [xmap](https://jax.readthedocs.io/en/latest/notebooks/xmap_tutorial.html)), and tensor indexing, by adding dimension objects to PyTorch_.

The tensor input to a resnet might have the shape [8, 3, 224, 224] but informally we think of those dimensions as 'batch', 'channel', 'width', and 'height'. Even though 'width' and 'height' have the same _size_ we still think of them as separate dimensions, and if we have two _different_ images, we think of both as sharing the _same_ 'channel' dimension.

Named tensors give these dimensions names. [PyTorch's current implementation](https://pytorch.org/docs/stable/named_tensor.html) uses strings to name dimensions. Instead, this library introduces a Python object, a `Dim`, to represent the concept. By expanding the semantics of tensors with dim objects, in addition to naming dimensions, we can get behavior equivalent to batching transforms (xmap, vmap), einops-style rearrangement, and loop-style tensor indexing.

A preview:

```py
import torch
from torchdim import dims

# einsum
def mm(A: torch.Tensor, B: torch.Tensor):
    i, j, k = dims(3)
    r = (A[i, k] * B[k, j]).sum(k)
    return r.order(i, j)

# rearrange
def pixel_shuffle(img: torch.Tensor, upscale_factor=2):
    h2, w2, c, b, h, w = dims(6)
    h2.size = w2.size = upscale_factor
    return img[b, (c, h2, w2), h, w].order(b, c, (h, h2), (w, w2))

# batching
def bmm(A: torch.Tensor, B: torch.Tensor):
    i = dims(1)
    return mm(A[i], B[i]).order(i)

# indexing
def embedding_bag(input: torch.Tensor, embedding_weights: torch.Tensor):
    batch, sequence, features = dims(3)
    r = embedding_weights[input[batch, sequence], features].sum(sequence)
    return r.order(batch, features)
```

Installation
============

_torchdim is a preview release so that we can collect feedback on the API. It may have bugs, and there are known places where performance can be improved._

First-class dims are a library that extends PyTorch, so they need to be installed separately.
We may eventually upstream them into PyTorch itself along with `functorch`.

We have to install a nightly build of PyTorch, so first set up an environment:

```sh
conda create --name dim
conda activate dim
```

First-class dims require a fairly recent nightly build of PyTorch so that functorch will work. You can install it using one of these commands:

```sh
# For CUDA 10.2
conda install pytorch torchvision torchaudio cudatoolkit=10.2 -c pytorch-nightly
# For CUDA 11.3
conda install pytorch torchvision torchaudio cudatoolkit=11.3 -c pytorch-nightly
# For CPU-only build
conda install pytorch torchvision torchaudio cpuonly -c pytorch-nightly
```

Install dim. You will be asked for GitHub credentials to access the fairinternal organization.

```sh
pip install ninja  # Makes the build go faster
pip install --user "git+https://github.com/facebookresearch/torchdim"
```

Creating and Binding Dims
=========================

Python objects that represent dimensions are created using the `dims` operator.[^1]

```py
import torch
from torchdim import dims

batch, channel, width, height = dims(4)
```

The existing implementation of [Named Tensors](https://pytorch.org/docs/stable/named_tensor.html) in PyTorch and [JAX's xmap](https://jax.readthedocs.io/en/latest/notebooks/xmap_tutorial.html) use strings to name dimensions. We call these dimensions _first class_ because they are Python objects instead.

In addition to the normal _positional_ dimensions in a tensor, tensors can also have a separate set of first-class dimensions.

You can create tensors with first-class dimensions by indexing the normal positional dimensions of a tensor with a dimension object. The `ndim` property continues to list the number of positional dimensions, while the new `dims` property lists all the bound first-class dimensions.

```py
input = torch.rand(2, 3, 224, 224)
print(input.ndim)
> 4

input_fc = input[batch, channel, width, height]
print(input_fc.dims) # first-class dimensions
> (batch, channel, width, height)


# since we converted all the positional dimensions to
# first-class ones, `input_fc` has 0 positional dimensions now.
print(input_fc.ndim)
> 0
```

Notice that indexing creates a _new_ Tensor, `input_fc`, with bound first-class dimensions. It does not modify the original tensor `input`, which still has 4 positional dimensions.

```py
print(input.ndim) # unchanged
> 4
```

Importantly, indexing with square brackets _applies only to positional dimensions_, so attempting to positionally index a tensor that has only first-class dims will error[^2]:

```py
try:
    input_fc[0]
except ValueError as ve:
    print(ve)
> at least 1 indices were supplied but the tensor only has 0 dimensions
```
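
As footnote 2 describes, the `index` method selects an entry along a first-class dimension by naming that dimension explicitly. A minimal sketch, continuing with the tensors above:

```py
# sketch: select entry 0 along the first-class `batch` dimension;
# the result keeps the remaining first-class dimensions
first_image = input_fc.index(batch, 0)
print(first_image.dims)
> (channel, width, height)
```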

Generally, it is possible to construct tensors with a mixture of positional and first-class dimensions:

```py
input_mixed = input[batch, :, :, height]
print(input_mixed.dims)
> (batch, height)

print(input_mixed.ndim)
> 2
```

Dimension Sizes
---------------

Dimensions will take on the size of the first thing they are bound to:

```py
input = torch.rand(3)
x = dims(1)
input_fc = input[x]
print(x.size)
> 3
```

But you can also directly set the size of a dimension:

```py
i = dims(1)

i.size = 5 # ok, i previously did not have a size

i.size = 5 # ok, it already had the size 5
try:
    i.size = 3
except Exception as e:
    print(e)
> Dim 'i' previously bound to a dimension of size 5 cannot bind to a dimension of size 3

j = dims(sizes=[4]) # can also be set on construction
```

[^1]: We use a bit of Python introspection to set the debug names for the dimensions based on the names of the variables they are assigned to.
[^2]: Indexing of first-class dimensions can be done with the `index` method by specifying the dimension to index into (e.g. `input_fc.index(batch, 0)`).

Semantics of Dimensions
=======================
The power of named tensors arises from how the first-class dimensions in tensors compose with existing operations.

Three rules define how dimension objects behave with existing Tensors.

Rule 1: Implicit Batching
-------------------------
**Tensor operations (e.g. `input + bias`) are implicitly batched over the union of the first-class dimensions in their inputs.**

If `input` has dimensions `batch, channel` and `bias` has dimension `channel`, the output will have the union of those dimensions (`batch, channel`), and the result will be computed as if there were a loop over all the first-class dimensions.[^3]

```py
input_positional = torch.rand(128, 32)
bias_positional = torch.rand(32)

batch, channel = dims(2)
input = input_positional[batch, channel]
bias = bias_positional[channel]

result = input + bias
print(result.dims)
> (batch, channel)
```

It is helpful to think of operators on tensors with first-class dimensions by analogy to code with explicit loops over dimensions, with the first-class dimensions of the inputs acting as implicit `for` loops, and the values in the tensor being scalars within the body of the loop:

```py
# mental model: loop-level analogy
for batch in range(batch.size):
    for channel in range(channel.size):
        input = input_positional[batch, channel]
        bias = bias_positional[channel]
        result[batch, channel] = input + bias # arithmetic on scalars
```

Positional dimensions behave as they did before (e.g. for `+` they will broadcast), and can be thought of as being a standard tensor _used within the implicit loops_ defined by first-class dimensions.
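
For example, here is a small sketch (the dim name `b` and the shapes are illustrative assumptions) of a positional tensor broadcasting inside the implicit loop over a first-class dimension:

```py
b = dims(1)
x = torch.rand(4, 3, 5)   # the first dimension is bound to b below
y = torch.rand(5)         # broadcasts against the remaining positional dims

r = x[b] + y              # b is implicitly batched; (3, 5) + (5,) broadcasts
print(r.dims, r.ndim)
> (b,) 2
```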

In the `input + bias` example, we broke the expression into separate lines that bind the dimensions to positional tensors, and then another line to do the compute. In practice, we often combine these in one statement:

```py
result = input_positional[batch, channel] + bias_positional[channel]
print(result.dims)
> (batch, channel)
```

[^3]: This rule is similar to how named dimensions in xmap behave within a function, but instead of introducing the dimensions via a functional transform, they are bound on the objects using indexing.


Rule 2: Specifying dimensions
-----------------------------
**Wherever an integer is used to specify a dimension in an existing torch operator, a first-class dimension can be used instead to tell the operator to work over that dimension.**

```py
batch, channel, width, height = dims(4)
input_positional = torch.rand(2, 3, 224, 224)
input = input_positional[batch, channel, width, height]
avg_pixel_color = input.mean((width, height))

print(avg_pixel_color.dims)
> (batch, channel)
```

Any other first-class dimensions (e.g. batch, channel) are still implicitly batched according to Rule #1.
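
For instance, in a sketch continuing the example above, reducing over just `width` leaves the other first-class dimensions implicitly batched:

```py
# assumption for illustration: mean over a single first-class dim;
# batch, channel, and height remain batched by Rule #1
row_mean = input.mean(width)
print(row_mean.dims)
> (batch, channel, height)
```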

Rule 3: Dims are Tensors
------------------------
**A first-class dimension `d` can be used wherever a Tensor is expected. It will act as if it were a tensor whose only dimension is itself, `d`, and the values along the dimension are the indices of each entry `(0, 1, 2, ..., d.size - 1)`.**

```py
print(channel.dims)
> (channel,)

print(channel + 1000)
> tensor([1000, 1001, 1002])
> with dims=(channel,) sizes=(3,)
```

This means that a dimension used as a tensor acts as an index into that dimension. Going back to our loop-level analogy, it is analogous to using the loop variable as a value:

```py
# mental model: loop-level analogy
for channel in range(channel.size):
    result[channel] = channel + 1000
```

Arithmetic using dimension indices comes up a lot, for example when computing the mask for the upper-triangular part of a matrix. Using dims as tensors makes it easy:

```py
from torchdim import dims
i, j = dims(sizes=[4, 4])
print(i <= j)
> tensor([[ True,  True,  True,  True],
>         [False,  True,  True,  True],
>         [False, False,  True,  True],
>         [False, False, False,  True]])
> with dims=(i, j) sizes=(4, 4)
```

Because of the intentional similarity to loop-level code, using dimensions as tensors makes complicated indexing arithmetic easier to read.

Here is code that looks up features in an embedding table, given a sequence of ids:

```py
sequence, features = dims(2)
embeddings = torch.rand(8, 128)
words = torch.tensor([5, 4, 0])

state = embeddings[words[sequence], features]
print(state.dims)
> (sequence, features)
```

With the following analogy to loops:

```py
# mental model: loop-level analogy

for sequence in range(words.size(0)):
    for features in range(embeddings.size(1)):
        state = embeddings[words[sequence], features]
```

Earlier we showed how binding a tensor's dimensions is done with indexing `A[i, j]`. In fact, this binding is just the normal indexing operator. Its behavior follows directly from the behavior of indexing with tensor indices combined with Rule #3 and Rule #1. The expression `A[i + 1, j]` also creates a tensor with dimensions `i` and `j` but with different indexing math. The implementation knows when simple indexing patterns are used and only actually runs a kernel to do indexing when needed.
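
As a small sketch of that indexing math (the shapes here are illustrative assumptions), shifting a vector by one position is just arithmetic on the dim:

```py
i = dims(1)
i.size = 4                 # choose how many entries to take
a = torch.arange(5)
shifted = a[i + 1]         # indexes a at positions 1, 2, 3, 4 (Rules #1 and #3)
print(shifted.order(i))
> tensor([1, 2, 3, 4])
```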

Unbinding Dims
--------------
The `order` method converts first-class dimensions in a tensor back to normal positional dimensions by specifying an order for those dimensions.[^4]

By specifying a different order from how things were originally bound, it is easy to do transpositions.

```py
i, j = dims(2)
A = torch.rand(3, 4)
A_T = A[i, j].order(j, i)
assert torch.allclose(A.T, A_T)
```

Indexing acts left-to-right, and `order` also places the new dimensions back on the left, so it is possible to work on tensors that have mixed positional and first-class dimensions:

```py
B = torch.rand(3, 4, 5)
B_T = B[i, j].order(j, i)
assert torch.allclose(B.permute(1, 0, 2), B_T)
```

[^4]: `order` is actually just a synonym for the already-existing `permute` method, which takes a list of dimension specifiers and puts the tensor in that order, because Rule #2 says that first-class dims can be passed as arguments to functions that previously took only integers as dimensions. However, the name `permute` is confusing in this context since it implies dim objects have an original order, so we prefer to use `order` when writing code.

Flattening and Splitting Dims
-----------------------------

**Tuples of dimensions** can be passed to both indexing and `order`. In indexing, this will split the dimension being indexed across the dimensions in the tuple. In `order` it will flatten the dimensions into a single positional dimension:

```py
i, j, k = dims(3)
j.size = 2
A = torch.rand(6, 4)
a = A[(i, j), k] # split dim 0 into i,j
print(i.size, j.size, k.size)
> 3 2 4

r = a.order(i, (j, k)) # flatten j and k
print(r.shape)
> torch.Size([3, 8])
```

The size of one unsized dimension in a tuple, such as `i` above, can be inferred if the other sizes are known.

Examples
========

The usefulness of dimension objects is best seen through examples. Let's look at some different ways they can be used.

Einsum-style Products
---------------------
Rather than having [einsum](https://pytorch.org/docs/stable/generated/torch.einsum.html) as a custom operator, it is possible to express matrix products directly as a composition of multiplies and summations. The implementation will pattern-match any multiplication followed by a sum to the right matrix-multiply operator.

```py
def mm(A, B):
    i, j, k = dims(3)
    r = (A[i, k] * B[k, j]).sum(k)
    return r.order(i, j)
mm(torch.rand(3, 4), torch.rand(4, 5)).shape
```

The implementation of named tensors delays the execution of the multiply to see whether a summation follows it, as it does above. If so, it will turn this pattern into the correct _optimized matrix product_, similar to how the `einsum` function works.
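
As a quick sanity check, a sketch using the `mm` defined above (the shapes are arbitrary):

```py
A, B = torch.rand(3, 4), torch.rand(4, 5)
# the pattern-matched multiply+sum agrees with the builtin matrix multiply
assert torch.allclose(mm(A, B), torch.mm(A, B))
```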

Since it is no longer necessary to manually match math to matrix functions, other tensor products are easier to express, like the Gram matrix used in style transfer:

```py
def gram_matrix_new(y):
    b, c, c2, h, w = dims()
    r = (y[b, c, h, w] * y[b, c2, h, w]).sum((h, w))
    r = r / (h.size * w.size)
    return r.order(b, c, c2)

gram_matrix_new(torch.rand(1, 2, 3, 4))
# [example adapted from http://einops.rocks/pytorch-examples.html]
```

Attention is another example that has several matrix products embedded inside it:

```py
from torchdim import softmax
def attention(K, Q, V):
    batch, channel, key, query = dims(4)
    k = K[batch, channel, key]
    q = Q[batch, channel, query]
    v = V[batch, channel, key]

    a = (k * q).sum(channel) # matrix multiply
    a = softmax(a * (channel.size ** -0.5), dim=key)
    r = (v * a).sum(key) # matrix multiply
    return torch.cat((r.order(batch, channel, query), Q), dim=1)

inputs = (torch.rand(2, 3, 4) for _ in range(3))
attention(*inputs)
# [example adapted from http://einops.rocks/pytorch-examples.html]
```

Reshaping tensors (einops)
--------------------------

Lots of operations in deep learning are just different ways of reshaping, splitting, and joining dimensions, such as the pixel shuffle used to upscale an image by turning channels into pixels:

```py
def pixel_shuffle(img, upscale_factor=2):
    h2, w2, c, b, h, w = dims(6)
    h2.size = w2.size = upscale_factor
    return img[b, (c, h2, w2), h, w].order(b, c, (h, h2), (w, w2))
```

[Einops](http://einops.rocks) is an extension to einsum that adds support for the manipulation of dimensions through a few custom operators such as `rearrange`:

```py
def pixel_shuffle_einops(img, upscale_factor=2):
    from einops import rearrange
    return rearrange(img, 'b (c h2 w2) h w -> b c (h h2) (w w2)', h2=upscale_factor, w2=upscale_factor)
```

Named tensors with first-class dimensions can accomplish the same goal, but using PyTorch's existing operator set.

Automatically batching Code (`vmap`, `xmap`)
--------------------------------------------

The implicit batching of Rule #1 means it is easy to create batched versions of existing PyTorch code. Simply bind a dim to the dimensions that should act as a batch, and then pass the tensor to the unbatched function. Since the unbatched function does not know about the dim, the dim will be implicitly batched over:

```py
batch_size, feature_size = 3, 5
weights = torch.randn(feature_size)

def model(feature_vec):
    # Very simple linear model with activation
    assert feature_vec.dim() == 1
    return feature_vec.dot(weights).relu()

examples = torch.randn(batch_size, feature_size)
batch = dims(1)
# in functorch: result = functorch.vmap(model)(examples)
r = model(examples[batch])
print(r)
> tensor([0.4775, 0.0000, 0.3423])
> with dims=(batch,) sizes=(3,)
```

This pattern also composes well with other code that uses first-class dimensions. For instance, we can write a batched matrix multiply `bmm` by batching the `mm` operator.

It doesn't matter whether the implementation of the function uses dimension objects internally; it is always possible to add additional batch dimensions and then call the function:

```py
def bmm(A, B):
    i = dims(1) # note: i here is a different value from i inside mm so it works
    return mm(A[i], B[i]).order(i)
```

The equivalent code in JAX uses [xmap or vmap](https://jax.readthedocs.io/en/latest/notebooks/quickstart.html#auto-vectorization-with-vmap), which are transforms over functions. As a result, there is a lot of syntactic distance between the specification of the dimension mappings and the values where those mappings apply. Dims express the mapping as indexing of the tensor, right at the place where the function is being applied.


[xmap examples](https://jax.readthedocs.io/en/latest/notebooks/xmap_tutorial.html):

```py
in_axes = [['inputs', 'hidden', ...],
           ['hidden', 'classes', ...],
           ['batch', 'inputs', ...],
           ['batch', ...]]

loss = xmap(named_loss, in_axes=in_axes, out_axes=[...])
print(loss(w1, w2, images, labels))
```

Equivalent with dimension objects:

```py
batch, inputs, hidden, classes = dims(4)
print(loss(w1[inputs, hidden], w2[hidden, classes], images[batch, inputs], labels[batch],
      batch, inputs, hidden, classes))
```


Composing matrix products, reshaping, and batching
--------------------------------------------------

Multi-headed attention is a good example of how these different uses compose. It reshapes the inputs, splitting out different attention heads. It batches over those attention heads, and it uses matrix products to compute attention scores.

```py
from torchdim import softmax
def multiheadattention(q, k, v, num_attention_heads, dropout_prob, use_positional_embedding):
    batch, query_sequence, key_sequence, heads, features = dims(5)
    heads.size = num_attention_heads

    # binding dimensions, and unflattening the heads from the feature dimension
    q = q[batch, query_sequence, [heads, features]]
    k = k[batch, key_sequence, [heads, features]]
    v = v[batch, key_sequence, [heads, features]]

    # einsum-style operators to calculate scores
    attention_scores = (q*k).sum(features) * (features.size ** -0.5)

    # use first-class dim to specify dimension for softmax
    attention_probs = softmax(attention_scores, dim=key_sequence)

    # dropout works pointwise, following Rule #1
    attention_probs = torch.nn.functional.dropout(attention_probs, p=dropout_prob)

    # another matrix product
    context_layer = (attention_probs*v).sum(key_sequence)

    # flatten heads back into features
    return context_layer.order(batch, query_sequence, [heads, features])
```

Indexing
--------

Rule #3 enables indexing because dimensions act as loop indices when used as a tensor. This allows for a lot of powerful behavior. The simplest might be using the dimensions to compute masks, such as extracting the upper triangular part of a matrix:

```py
from torch import where
def triu(A):
    i, j = dims()
    a = A[i, j]
    return where(i <= j, a, 0).order(i, j)
triu(torch.rand(3, 4))
```

Embedding bag does an embedding table lookup followed by a sum, which can be expressed concisely:

```py
def embedding_bag(input, embedding_weights):
    batch, sequence, features = dims(3)
    r = embedding_weights[input[batch, sequence], features].sum(sequence)
    return r.order(batch, features)

input = torch.tensor([[1, 0, 4, 3]])
W = torch.rand(5, 2)
embedding_bag(input, W)
```

Relative positional embeddings associate an embedding vector with the distance between the query and the key in the sequence.
For instance, a key at position 3 and a query at position 5 will have embedding ID `(5-3)=2`. We can use first-class dimensions to do the indexing arithmetic, and the embedding lookup:

```py
def relative_positional_embedding(q, k, distance_embedding_weight):
    batch, query_sequence, key_sequence, heads, features = dims(5)
    q = q[batch, query_sequence, [heads, features]]
    k = k[batch, key_sequence, [heads, features]]

    distance = query_sequence - key_sequence
    n_embeddings = distance_embedding_weight.size(0)
    index_bias = n_embeddings // 2

    assert key_sequence.size + index_bias <= n_embeddings

    # indexing with dims
    positional_embedding = distance_embedding_weight[distance + index_bias, features]

    # matrix multiplies with dims
    relative_position_scores_query = (q*positional_embedding).sum(features)
    relative_position_scores_key = (k*positional_embedding).sum(features)
    return (relative_position_scores_query + relative_position_scores_key).order(batch, heads, key_sequence, query_sequence)
```

Tensor Puzzlers
===============

[Tensor Puzzlers](https://github.com/srush/Tensor-Puzzles), created by Sasha Rush, are a good exercise for learning the numpy and torch APIs by figuring out how to define common operations using a small set of primitive tensor operations.

However, the difficulty of many of the puzzlers lies not in how to compute the answer but in the awkwardness of the primitives themselves.

**With first-class dimensions, these puzzlers are nearly the same as the spec that defines them.**

### Puzzle 3 - outer

Compute [outer](https://numpy.org/doc/stable/reference/generated/numpy.outer.html) - the outer product of two vectors.

```py
def outer_spec(a, b, out):
    for i in range(len(out)):
        for j in range(len(out[0])):
            out[i][j] = a[i] * b[j]

def outer(a, b):
    i, j = dims(2)
    return (a[i] * b[j]).order(i, j)
```

### Puzzle 4 - diag

Compute [diag](https://numpy.org/doc/stable/reference/generated/numpy.diag.html) - the diagonal vector of a square matrix.

```py
def diag_spec(a, out):
    for i in range(len(a)):
        out[i] = a[i][i]

def diag(a):
    i = dims(1)
    return a[i, i].order(i)
```

### Puzzle 5 - eye

Compute [eye](https://numpy.org/doc/stable/reference/generated/numpy.eye.html) - the identity matrix.

```py
from torch import where
def eye_spec(out):
    for i in range(len(out)):
        out[i][i] = 1

def eye(j: int):
    i, j = dims(sizes=[j, j])
    return where(i == j, 1, 0).order(i, j)
```

### Puzzle 6 - triu

Compute [triu](https://numpy.org/doc/stable/reference/generated/numpy.triu.html) - the upper triangular matrix.

```py
def triu_spec(out):
    for i in range(len(out)):
        for j in range(len(out)):
            if i <= j:
                out[i][j] = 1
            else:
                out[i][j] = 0

def triu(j: int):
    i, j = dims(sizes=[j, j])
    return where(i <= j, 1, 0).order(i, j)
```

### Puzzle 8 - diff

Compute [diff](https://numpy.org/doc/stable/reference/generated/numpy.diff.html) - the running difference.

```py
def diff_spec(a, out):
    out[0] = a[0]
    for i in range(1, len(out)):
        out[i] = a[i] - a[i - 1]
def diff(a, i: int):
    i = dims(1)
    d = a[i] - a[i - 1]
    return where(i - 1 >= 0, d, a[i]).order(i)
```

### Puzzle 9 - vstack

Compute [vstack](https://numpy.org/doc/stable/reference/generated/numpy.vstack.html) - the matrix of two vectors

```py
def vstack_spec(a, b, out):
    for i in range(len(out[0])):
        out[0][i] = a[i]
        out[1][i] = b[i]

def vstack(a, b):
    v, i = dims(sizes=[2, None])
    return where(v == 0, a[i], b[i]).order(v, i)
```

### Puzzle 10 - roll

Compute [roll](https://numpy.org/doc/stable/reference/generated/numpy.roll.html) - the vector shifted 1 circular position.

```py
def roll_spec(a, out):
    for i in range(len(out)):
        if i + 1 < len(out):
            out[i] = a[i + 1]
        else:
            out[i] = a[i + 1 - len(out)]

def roll(a, i: int):
    i = dims(sizes=[a.size(0)])
    return a[where(i + 1 < i.size, i + 1, 0)].order(i)
```

### Puzzle 11 - flip

Compute [flip](https://numpy.org/doc/stable/reference/generated/numpy.flip.html) - the reversed vector

```py
def flip_spec(a, out):
    for i in range(len(out)):
        out[i] = a[len(out) - i - 1]

def flip(a, i: int):
    i = dims(sizes=[a.size(0)])
    return a[i.size - i - 1].order(i)
```

### Puzzle 14 - sequence_mask

Compute [sequence_mask](https://www.tensorflow.org/api_docs/python/tf/sequence_mask) - pad out to length per batch.

```py
def sequence_mask_spec(values, length, out):
    for i in range(len(out)):
        for j in range(len(out[0])):
            if j < length[i]:
                out[i][j] = values[i][j]
            else:
                out[i][j] = 0

def sequence_mask(values, length):
    j, i = dims()
    v = values[i, j]
    return where(j < length[i], v, 0).order(i, j)
```

Advantages of First-class Dimensions over String Dimensions
===========================================================

The most prominent difference between named tensors using first-class dimensions and the alternatives (einops, named tensors implemented in PyTorch today, [tensors considered harmful](https://nlp.seas.harvard.edu/NamedTensor), or xmap) is that dimensions are objects rather than strings. Using objects has a number of nice properties.

### Avoiding naming conflicts

Using strings for dimensions introduces the possibility that two unrelated dimensions are given the same name. Using objects instead makes it clear that identical names are not the same dimension. It's like the difference between having only global variables and having the ability to locally bind names in functions.
For instance, we defined `bmm` by batching a call to `mm`, and both use the name `i` to identify a dimension. Because each `i` is a different object, there is no naming conflict:

```py
def mm(A, B):
    i, j, k = dims()
    r = (A[i, k] * B[k, j]).sum(k)
    return r.order(i, j)

def bmm(A, B):
    i = dims() # note: doesn't matter that mm internally also uses i
    return mm(A[i], B[i])
```

Einops avoids conflicts by ensuring names are all introduced and removed in a single expression, but this precludes using long-lived dimensions to provide implicit batching similar to xmap. When nested, JAX's xmap seems to consider axes the same if the string name matches. In the above example it would consider the `i` dimension to be the same dimension in both `bmm` and `mm`, so the code would error.


### Reuse the same operator set

Having a new object type allows us to extend the existing operator set of PyTorch rather than come up with new operators. For instance, binding dimensions using indexing follows semantically from Rules #1 and #3, so there is no need for a special operator to do binding. Even unbinding is just the `permute` operator, which follows from Rule #2, though we call it `order` for clarity. In contrast, using strings requires coming up with new APIs such as `einsum` for matrix multiplies, or `rearrange` for doing permutations.
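
A small sketch of this equivalence (footnote 4 notes that `order` is a synonym for `permute`):

```py
i, j = dims(2)
A = torch.rand(3, 4)
# `permute` accepts dim objects by Rule #2; `order` is just a clearer name
assert torch.allclose(A[i, j].order(j, i), A[i, j].permute(j, i))
```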

### Allows dims to act as tensors

Rule #3 is not possible with strings since we cannot make strings behave as tensors. Without this rule, all of the indirect indexing that dims enable would not be easy to express.

### Dims can have methods
For instance, as objects, dims can have a size, which allows us to do size inference of dimensions in various places in the API where string-based APIs would have to take additional arguments specifying size.
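
A sketch of that size inference, using the `sizes=[2, None]` pattern from the `vstack` puzzle above:

```py
v, i = dims(sizes=[2, None]) # v is fixed at 2; i's size is inferred on binding
x = torch.rand(2, 5)
xv = x[v, i]
print(i.size)
> 5
```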

Comparison to tensor compilers or languages (e.g. TVM or Dex)
=============================================================

The semantics and surface syntax of dimension objects resemble the kind of code written in tensor compilers such as [Halide](https://halide-lang.org), [TVM](https://tvm.apache.org), [Tensor Comprehensions](https://github.com/facebookresearch/TensorComprehensions), or the language [Dex](https://github.com/google-research/dex-lang).

These compilers and languages have syntax and semantics that resemble the loop-level analogy of first-class dimensions. However, as compilers or statically typed languages, they require some binding code to go from running deep learning framework code in Python to using the compiled language. This often at least requires refactoring the compiled parts into their own functions, and may require defining a gradient function. Similar to graph-mode frameworks, this adds friction to using and debugging the code.

Dimension objects are just an extension of the existing PyTorch tensors and eager semantics, so there is no friction switching between normal Python code and code that uses them. However, since the loops over the dimensions are defined implicitly, they can still execute in Python with good performance compared to explicit loops. Furthermore, with dimension objects, a tensor containing dimensions can compute through code that is oblivious to the dimensions, such as when batching examples. There is no need to separate code into 'compiled' vs 'eager'.

In this way, first-class dims are a way of adapting the nicer syntax of these array compilers and languages to eager numpy-style libraries.


Performance Expectations
========================
First-class dimensions are not a compiler. They provide syntax for existing PyTorch operations, such as advanced indexing, that is easier to read and write. For large tensors, the performance of any statements including them will be the same as using the already existing operations. An important exception is the pattern matching of products and summation, where performance is improved by dispatching to a matrix-multiply kernel. The C++ implementation of dimensions adds a small overhead of around 2us on top of PyTorch's normal overhead of 8us to each function that uses them. In the future, the implementation can incorporate more fusion optimization to further improve the performance of this style of code.


## License
Functorch has a BSD-style license, as found in the [LICENSE](LICENSE) file.