#include <torch/library.h>
#include <ATen/ATen.h>
#include <ATen/LegacyVmapTransforms.h>
#include <ATen/LegacyBatchedFallback.h>
#include <ATen/RedispatchFunctions.h>
#include <ATen/native/ResizeCommon.h>
#include <ATen/core/IListRef.h>
#include <c10/util/irange.h>
#include <c10/core/SymIntArrayRef.h>

#include <utility>

namespace at {

// NOTE: [What is a batching rule?]
//
// A *batching rule* implements the logic of how to call an operator on inputs
// that have zero or more additional batch dimensions. When one does a vmap, the
// dimension(s) being vmap'ed over get recorded as batch dimensions.
//
// For example, vmap(torch.add)(x, y)
// 1. wraps `x` into batched_x = BatchedTensor(x, bdims=[(lvl=1, dim=0)]);
// 2. wraps `y` into batched_y = BatchedTensor(y, bdims=[(lvl=1, dim=0)]);
// 3. and then runs `torch.add(batched_x, batched_y)`.
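//
// As a shape-level illustration (the batch size B0 = 8 is made up here): if `x`
// and `y` both have shape [8, 3] and are vmapped over dim 0, each per-example
// call adds two tensors of logical shape [3], and the overall result has
// shape [8, 3].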

// NOTE: [When should I add a batching rule?]
// When you are adding a new operator, you'll need to add a batching rule so
// that vmap can work efficiently with said operator. If you do not, vmap will
// attempt to fall back to a slow, for-loop-based implementation (the Batched
// fallback registered at the bottom of this file).

// NOTE: [How to write batching rules?]
// The signature of a batching rule should look exactly like the C++ signature
// of its operator.
//
// First, see NOTE: [Logical vs physical args] in VmapTransforms.h for terminology.
//
// At a high level, what a batching rule does is the following:
// 1. Converts (logical) BatchedTensors to views on physical tensors.
// 2. Converts logical arguments (e.g. dimension indexes, shapes) to physical
//    arguments that correspond to the physical tensors.
// 3. Calls at:: operations on the physical tensors and arguments to produce
//    some physical results.
// 4. Converts physical results back to BatchedTensors.
//
// Steps 1, 2, and 4 differ for operators with different batching behaviors. When
// writing a new batching rule, please select a VmapTransform that matches the
// batching behavior of your operation. The VmapTransform provides helper functions
// to do steps (1), (2), and (4).
// (see NOTE: [What is a VmapTransform?] in VmapTransforms.h)
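//
// As an illustrative sketch only (`my_op` / `at::my_op` are hypothetical names;
// the real rules below follow exactly this pattern), a typical single-tensor
// batching rule looks like:
//
//   Tensor my_op_batching_rule(const Tensor& self, int64_t dim) {
//     // (1) logical BatchedTensor -> view on a physical tensor
//     auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self);
//     // (2) logical arguments -> physical arguments
//     auto dim_physical = self_physical.getPhysicalDim(dim);
//     // (3) run the at:: op on the physical tensor
//     auto result = at::my_op(self_physical.tensor(), dim_physical);
//     // (4) physical result -> logical BatchedTensor
//     return self_physical.getPhysicalToLogicalMap().apply(result);
//   }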

// Note: [Future plans]
// The API for writing a batching rule isn't stable. In the future, we'd like
// to think about the problem of translating these batching rules to TorchScript.
// Ideally batching rules in eager mode vs TorchScript would look pretty similar,
// if not use the same mechanism. In order to accomplish that we might have to
// do some refactoring.

namespace {

// PyTorch allows operations to specify dim 0 and dim -1 on a scalar tensor.
static bool is_allowed_dim_on_scalar_tensor(int64_t dim) {
  return dim == 0 || dim == -1;
}

Tensor sum_batching_rule(const Tensor& self, OptionalIntArrayRef opt_dims, bool keepdim, std::optional<ScalarType> dtype) {
  if (opt_dims.has_value()) {
    auto dims = opt_dims.value();
    // PyTorch has a special case where sum(scalar_tensor, dim=0) does not fail
    // and instead returns a new scalar tensor (this also happens for dim=-1)
    // If the following happens:
    // >>> x = torch.randn(B0)  # the per-examples are all scalars
    // >>> vmap(partial(torch.sum, dim=0), x)
    // then we replicate the behavior of sum(scalar_tensor, dim=0).
    if (/*logical*/self.dim() == 0 && (dims.empty() || (dims.size() == 1 && is_allowed_dim_on_scalar_tensor(dims[0])))) {
      return self.clone();
    }
  }
  auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self);
  auto dims_physical = self_physical.getPhysicalDims(opt_dims);
  auto result = at::sum(self_physical.tensor(), dims_physical, keepdim, dtype);
  return self_physical.getPhysicalToLogicalMap().apply(result);
}

bool isPhysicalScalarTensor(const Tensor& logical_tensor) {
  if (logical_tensor.dim() > 0) {
    return false;
  }
  auto* batched = maybeGetBatchedImpl(logical_tensor);
  if (batched) {
    return false;
  }
  return true;
}

template <typename F, F Func, typename... ExtraArgs>
Tensor binary_pointwise_batching_rule(
    const Tensor& self, const Tensor& other, ExtraArgs... args) {
  if (self.dim() > 0 && other.dim() > 0) {
    auto physical_args = BroadcastingVmapTransform::logicalToPhysical({self, other});
    auto result = Func(physical_args[0].tensor(), physical_args[1].tensor(), args...);
    return physical_args[0].getPhysicalToLogicalMap().apply(result);
  }
  if (isPhysicalScalarTensor(self)) {
    auto other_physical = MultiBatchVmapTransform::logicalToPhysical(other);
    auto result = Func(self, other_physical.tensor(), args...);
    return other_physical.getPhysicalToLogicalMap().apply(result);
  }
  if (isPhysicalScalarTensor(other)) {
    auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self);
    auto result = Func(self_physical.tensor(), other, args...);
    return self_physical.getPhysicalToLogicalMap().apply(result);
  }

  // At this point, we know at least one of the operands is a logical Scalar tensor.
  // Here we must emulate TensorIterator's special behavior on Scalars.
  //
  // As a motivating example, consider the following:
  //   x = torch.randn(3, 10)
  //   y = torch.randn(3, dtype=torch.double)
  //   vmap(torch.mul)(x, y)
  //
  // At a per-example level, we are multiplying FloatTensor[10] and DoubleTensor[];
  // Type Promotion dictates that the result should be FloatTensor[10].
  // This means we cannot directly pass the physical tensors (x and y) to
  // TensorIterator (if we did, it would promote them to DoubleTensor).
  //
  // FIXME(rzou): I didn't want to go down the slippery slope of emulating
  // everything TensorIterator does (it would be better to refactor out the
  // TensorIterator logic). The one thing that this code doesn't handle
  // is cross-device logical scalar tensors.
  //   cpu_tensor = torch.randn(3)
  //   cuda_tensor = torch.randn(3, 10, device='cuda')
  //   vmap(torch.mul)(cpu_tensor, cuda_tensor)
  //
  // At a per-example level, we are multiplying CPUTensor[] and CUDATensor[10].
  // TensorIterator allows for this cross-device operation because one of the
  // tensors is a Scalar CPU tensor. However, the following code will throw an
  // error in that case. I don't expect to see many use cases for this, so
  // this is probably fine as-is.
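  //
  // Promote the *logical* tensors to the common result dtype before transforming
  // them, so the physical call below already sees correctly-typed operands
  // (a narrow emulation of TensorIterator's scalar type promotion; per the
  // FIXME above, the cross-device scalar case is deliberately not handled).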
  auto logical_self = self;
  auto logical_other = other;
  auto result_type = at::native::result_type(logical_self, logical_other);
  if (logical_self.scalar_type() != result_type) {
    logical_self = logical_self.to(result_type);
  }
  if (logical_other.scalar_type() != result_type) {
    logical_other = logical_other.to(result_type);
  }
  auto physical_args = BroadcastingVmapTransform::logicalToPhysical(
      {std::move(logical_self), std::move(logical_other)});
  auto result = Func(physical_args[0].tensor(), physical_args[1].tensor(), args...);
  return physical_args[0].getPhysicalToLogicalMap().apply(result);
}

Tensor expand_batching_rule(const Tensor& self, IntArrayRef size, bool implicit) {
  auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self);
  auto size_physical = self_physical.getPhysicalShape(size);
  auto self_physical_dim = self_physical.tensor().dim();

  TORCH_CHECK(self_physical_dim <= static_cast<int64_t>(size_physical.size()),
      "expand: the number of sizes provided (", /*logical*/size.size(), ") ",
      "must be greater than or equal to the number of dimensions in the tensor (",
      /*logical dim*/self.dim(), ")");

  if (self_physical_dim == static_cast<int64_t>(size_physical.size())) {
    auto result = self_physical.tensor().expand(size_physical, implicit);
    return self_physical.getPhysicalToLogicalMap().apply(result);
  }

  TORCH_INTERNAL_ASSERT(self_physical_dim < static_cast<int64_t>(size_physical.size()));
  // Here, we know we are expanding a (logical) tensor to a larger number
  // of dimensions. We have to be careful because we can't call expand directly
  // due to the presence of batch dimensions.
  //
  // As an example, let B0 be a batch dimension and consider expand(Tensor[B0, 3], [2, 3]).
  // The result should be a tensor of size [B0, 2, 3].
  // A physical view of size [B0, 3] can't directly be expanded to size [B0, 2, 3]
  // so the strategy here is to view it first as a tensor of size [B0, 1, 3] and
  // then expand.
  auto self_physical_size = self_physical.tensor().sizes();
  auto extra_dims = size_physical.size() - self_physical_dim;
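  // Build view_shape = [batch sizes] + [1] * extra_dims + [per-example sizes];
  // e.g. a physical tensor of size [B0, 3] with size_physical [B0, 2, 3] is
  // first viewed as [B0, 1, 3] below (matching the comment above) and then expanded.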
  VmapDimVector view_shape(size_physical.size(), 1);
  std::copy(self_physical_size.begin(),
            self_physical_size.begin() + self_physical.numBatchDims(),
            view_shape.begin());
  std::copy(self_physical_size.begin() + self_physical.numBatchDims(),
            self_physical_size.end(),
            view_shape.begin() + self_physical.numBatchDims() + extra_dims);
  auto result = self_physical.tensor().view(view_shape).expand(size_physical, implicit);
  return self_physical.getPhysicalToLogicalMap().apply(result);
}

std::vector<Tensor> chunk_batching_rule(const Tensor& self, int64_t chunks, int64_t dim) {
  auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self);
  auto dim_physical = self_physical.getPhysicalDim(dim);
  auto result = at::chunk(self_physical.tensor(), chunks, dim_physical);
  self_physical.getPhysicalToLogicalMap().applyInplace(result);
  return result;
}

Tensor clamp_batching_rule(const Tensor& self, const std::optional<Scalar>& min, const std::optional<Scalar>& max) {
  auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self);
  auto result = at::clamp(self_physical.tensor(), min, max);
  return self_physical.getPhysicalToLogicalMap().apply(result);
}

Tensor clamp_min_batching_rule(const Tensor& self, const Scalar& min) {
  auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self);
  auto result = at::clamp_min(self_physical.tensor(), min);
  return self_physical.getPhysicalToLogicalMap().apply(result);
}

Tensor clamp_max_batching_rule(const Tensor& self, const Scalar& max) {
  auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self);
  auto result = at::clamp_max(self_physical.tensor(), max);
  return self_physical.getPhysicalToLogicalMap().apply(result);
}

std::vector<Tensor> tensor_split_sections_batching_rule(const Tensor& self, int64_t sections, int64_t dim) {
  auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self);
  auto dim_physical = self_physical.getPhysicalDim(dim);
  auto result = at::tensor_split(self_physical.tensor(), sections, dim_physical);
  self_physical.getPhysicalToLogicalMap().applyInplace(result);
  return result;
}

std::vector<Tensor> tensor_split_indices_batching_rule(const Tensor& self, IntArrayRef indices, int64_t dim) {
  auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self);
  auto dim_physical = self_physical.getPhysicalDim(dim);
  auto result = at::tensor_split(self_physical.tensor(), indices, dim_physical);
  self_physical.getPhysicalToLogicalMap().applyInplace(result);
  return result;
}

Tensor unsqueeze_batching_rule(const Tensor& self, int64_t dim) {
  auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self);
  // NB: unsqueeze has some special handling of its `dim` argument so we can't call
  // self_physical.getPhysicalDim directly. In particular, native::unsqueeze
  // wraps the dim to (the logical dimension) + 1, so we need to do that here too.
  // https://github.com/pytorch/pytorch/blob/b623bdeabb0aa8da44285d303246e7f8ac06c2a9/aten/src/ATen/native/TensorShape.cpp#L1413
  auto dim_physical =
      self_physical.numBatchDims() + maybe_wrap_dim(dim, /*logical_dim*/self.dim() + 1);
  auto result = self_physical.tensor().unsqueeze(dim_physical);
  return self_physical.getPhysicalToLogicalMap().apply(result);
}

Tensor& fill_inplace_scalar_batching_rule(Tensor& self, const Scalar& value) {
  auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self);
  self_physical.tensor().fill_(value);
  return self;
}

Tensor& fill_inplace_tensor_batching_rule(Tensor& self, const Tensor& value) {
  auto value_batched = isBatchedTensor(value);

  if (value_batched) {
    auto physical_args =
        BroadcastingVmapTransform::logicalToPhysical({self, value});
    physical_args[0].tensor().copy_(physical_args[1].tensor());
  } else {
    auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self);
    self_physical.tensor().fill_(value);
  }
  return self;
}

Tensor& zero_inplace_batching_rule(Tensor& self) {
  auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self);
  self_physical.tensor().zero_();
  return self;
}

Tensor squeeze_batching_rule(const Tensor& self) {
  auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self);
  auto physical_sizes = self_physical.tensor().sizes();

  // Don't squeeze the batch dims!
  VmapDimVector squeezed_sizes;
  int64_t num_batch_dims = self_physical.numBatchDims();
  squeezed_sizes.insert(
      squeezed_sizes.end(),
      physical_sizes.begin(),
      physical_sizes.begin() + num_batch_dims);
  for (auto it = physical_sizes.begin() + num_batch_dims; it != physical_sizes.end(); ++it) {
    if (*it != 1) {
      squeezed_sizes.push_back(*it);
    }
  }

  auto result = self_physical.tensor().view(squeezed_sizes);
  return self_physical.getPhysicalToLogicalMap().apply(result);
}

Tensor squeeze_dim_batching_rule(const Tensor& self, int64_t dim) {
  auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self);
  auto dim_physical = self_physical.getPhysicalDim(dim);
  auto result = self_physical.tensor().squeeze(dim_physical);
  return self_physical.getPhysicalToLogicalMap().apply(result);
}

Tensor squeeze_dims_batching_rule(const Tensor& self, IntArrayRef dims) {
  auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self);
  auto dims_physical = self_physical.getPhysicalDims(dims);
  auto result = self_physical.tensor().squeeze(dims_physical);
  return self_physical.getPhysicalToLogicalMap().apply(result);
}

Tensor trace_batching_rule(const Tensor& self) {
  auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self);
  // Batched Diagonal View
  auto self_diag = at::diagonal(self_physical.tensor(), /*offset*/0, /*dim1*/-2, /*dim2*/-1);
  auto result = at::sum(self_diag, -1);
  return self_physical.getPhysicalToLogicalMap().apply(result);
}

Tensor trace_backward_batching_rule(const Tensor& grad, IntArrayRef input_sizes) {
  auto grad_physical = MultiBatchVmapTransform::logicalToPhysical(grad);
  auto grad_input = at::zeros(grad_physical.getPhysicalShape(input_sizes), grad.options());
  // Batched Diagonal View
  auto grad_input_diag = at::diagonal(grad_input, /*offset*/0, /*dim1*/-2, /*dim2*/-1);
  // Append a dimension of size one to the grad output
  auto grad_physical_tensor = grad_physical.tensor().unsqueeze(-1);
  grad_input_diag.copy_(grad_physical_tensor);
  return grad_physical.getPhysicalToLogicalMap().apply(grad_input);
}

Tensor transpose_int_batching_rule(const Tensor& self, int64_t dim0, int64_t dim1) {
  // PyTorch has a special case where scalar_tensor.transpose(dim0, dim1) works
  // for dim0, dim1 in {0, -1} and returns the scalar tensor. If the following happens:
  // >>> x = torch.randn(B0)  # the per-examples are all scalars
  // >>> vmap(lambda x: x.transpose(0, -1), x)
  // then we replicate this behavior.
  if (/*logical*/self.dim() == 0 && is_allowed_dim_on_scalar_tensor(dim0) &&
      is_allowed_dim_on_scalar_tensor(dim1)) {
    return self;
  }
  auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self);
  auto dim0_physical = self_physical.getPhysicalDim(dim0);
  auto dim1_physical = self_physical.getPhysicalDim(dim1);
  auto result = self_physical.tensor().transpose(dim0_physical, dim1_physical);
  return self_physical.getPhysicalToLogicalMap().apply(result);
}

Tensor permute_batching_rule(const Tensor& self, IntArrayRef dims) {
  auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self);
  auto dims_physical = self_physical.getPhysicalDims(dims);

  VmapDimVector all_dims_physical;
  all_dims_physical.reserve(self_physical.tensor().dim());
  for (const auto bdim : c10::irange(self_physical.numBatchDims())) {
    all_dims_physical.push_back(bdim);
  }
  all_dims_physical.insert(
      all_dims_physical.end(),
      dims_physical.begin(),
      dims_physical.end());
  auto result = self_physical.tensor().permute(all_dims_physical);
  return self_physical.getPhysicalToLogicalMap().apply(result);
}

Tensor select_batching_rule(const Tensor& self, int64_t dim, int64_t index) {
  auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self);
  auto dim_physical = self_physical.getPhysicalDim(dim);
  auto result = self_physical.tensor().select(dim_physical, index);
  return self_physical.getPhysicalToLogicalMap().apply(result);
}

static int64_t getGradInputPhysicalDim(int64_t dim, IntArrayRef input_sizes, int64_t num_batch_dims) {
  return maybe_wrap_dim(dim, input_sizes.size()) + num_batch_dims;
}

Tensor select_backward_batching_rule(const Tensor& grad, IntArrayRef input_sizes, int64_t dim, int64_t index) {
  auto grad_physical = MultiBatchVmapTransform::logicalToPhysical(grad);
  auto grad_input = at::zeros(grad_physical.getPhysicalShape(input_sizes), grad.options());
  auto physical_dim = getGradInputPhysicalDim(dim, input_sizes, grad_physical.numBatchDims());
  grad_input.select(physical_dim, index).copy_(grad_physical.tensor());
  return grad_physical.getPhysicalToLogicalMap().apply(grad_input);
}

Tensor slice_batching_rule(
    const Tensor& self,
    int64_t dim,
    std::optional<int64_t> start,
    std::optional<int64_t> end,
    int64_t step) {
  auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self);
  auto dim_physical = self_physical.getPhysicalDim(dim);
  auto result = self_physical.tensor().slice(dim_physical, start, end, step);
  return self_physical.getPhysicalToLogicalMap().apply(result);
}

Tensor slice_backward_batching_rule(const Tensor& grad, IntArrayRef input_sizes, int64_t dim, int64_t start, int64_t end, int64_t step) {
  auto grad_physical = MultiBatchVmapTransform::logicalToPhysical(grad);
  auto grad_input = at::zeros(grad_physical.getPhysicalShape(input_sizes), grad.options());
  auto physical_dim = getGradInputPhysicalDim(dim, input_sizes, grad_physical.numBatchDims());
  grad_input.slice(physical_dim, start, end, step).copy_(grad_physical.tensor());
  return grad_physical.getPhysicalToLogicalMap().apply(grad_input);
}

Tensor diagonal_batching_rule(const Tensor& self, int64_t offset, int64_t dim1, int64_t dim2) {
  auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self);
  auto dim1_physical = self_physical.getPhysicalDim(dim1);
  auto dim2_physical = self_physical.getPhysicalDim(dim2);
  auto result = at::diagonal(self_physical.tensor(), offset, dim1_physical, dim2_physical);
  return self_physical.getPhysicalToLogicalMap().apply(result);
}

Tensor diagonal_backward_batching_rule(const Tensor& grad, IntArrayRef input_sizes, int64_t offset, int64_t dim1, int64_t dim2) {
  auto grad_physical = MultiBatchVmapTransform::logicalToPhysical(grad);
  auto grad_input = at::zeros(grad_physical.getPhysicalShape(input_sizes), grad.options());
  auto dim1_physical = getGradInputPhysicalDim(dim1, input_sizes, grad_physical.numBatchDims());
  auto dim2_physical = getGradInputPhysicalDim(dim2, input_sizes, grad_physical.numBatchDims());
  grad_input.diagonal(offset, dim1_physical, dim2_physical).copy_(grad_physical.tensor());
  return grad_physical.getPhysicalToLogicalMap().apply(grad_input);
}

Tensor movedim_batching_rule(const Tensor& self, IntArrayRef source, IntArrayRef destination) {
  auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self);
  auto source_physical = self_physical.getPhysicalDims(source);
  auto destination_physical = self_physical.getPhysicalDims(destination);
  auto result = at::movedim(self_physical.tensor(), source_physical, destination_physical);
  return self_physical.getPhysicalToLogicalMap().apply(result);
}

Tensor reshape_batching_rule(const Tensor& self, IntArrayRef shape) {
  auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self);
  auto shape_physical = self_physical.getPhysicalShape(shape);
  auto result = self_physical.tensor().reshape(shape_physical);
  return self_physical.getPhysicalToLogicalMap().apply(result);
}

std::vector<Tensor> split_batching_rule(const Tensor& self, int64_t split_size, int64_t dim) {
  auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self);
  auto dim_physical = self_physical.getPhysicalDim(dim);
  auto result = at::split(self_physical.tensor(), split_size, dim_physical);
  self_physical.getPhysicalToLogicalMap().applyInplace(result);
  return result;
}

std::vector<Tensor> split_with_sizes_batching_rule(const Tensor& self, IntArrayRef split_sizes, int64_t dim) {
  auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self);
  auto dim_physical = self_physical.getPhysicalDim(dim);
  auto result = at::split_with_sizes(self_physical.tensor(), split_sizes, dim_physical);
  self_physical.getPhysicalToLogicalMap().applyInplace(result);
  return result;
}

std::vector<Tensor> unbind_batching_rule(const Tensor& self, int64_t dim) {
  auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self);
  auto dim_physical = self_physical.getPhysicalDim(dim);
  auto result = at::unbind(self_physical.tensor(), dim_physical);
  self_physical.getPhysicalToLogicalMap().applyInplace(result);
  return result;
}

Tensor unfold_batching_rule(const Tensor& self, int64_t dim, int64_t size, int64_t step) {
  auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self);
  auto dim_physical = self_physical.getPhysicalDim(dim);
  auto result = self_physical.tensor().unfold(dim_physical, size, step);
  return self_physical.getPhysicalToLogicalMap().apply(result);
}

Tensor contiguous_batching_rule(const Tensor& self, MemoryFormat memory_format) {
  TORCH_CHECK(memory_format == MemoryFormat::Contiguous,
      "NYI: Tensor.contiguous(...) inside of vmap for memory_format other ",
      "than torch.contiguous_format");
  auto physical_view = MultiBatchVmapTransform::logicalToPhysical(self);
  auto result = physical_view.tensor().contiguous(memory_format);
  return physical_view.getPhysicalToLogicalMap().apply(result);
}

Tensor view_batching_rule(const Tensor& self, IntArrayRef size) {
  auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self);
  auto size_physical = self_physical.getPhysicalShape(size);
  auto result = self_physical.tensor().view(size_physical);
  return self_physical.getPhysicalToLogicalMap().apply(result);
}

Tensor view_as_complex_batching_rule(const Tensor& self) {
  // guard against the user passing in a batch of scalar tensors with batch
  // size equal to 2.
  TORCH_CHECK(!self.sizes().empty(), "Input tensor must have one or more dimensions");
  auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self);
  auto result = at::view_as_complex(self_physical.tensor());
  return self_physical.getPhysicalToLogicalMap().apply(result);
}

// Checks that the smallest batch stride is greater than the largest example
// stride. This is something we can support but we choose not to because it's
// potentially error prone.
static void checkBatchDimsAtFrontInLayout(IntArrayRef physical_strides, int64_t num_batch_dims) {
  auto smallest_batch_stride = std::min_element(
      physical_strides.begin(), physical_strides.begin() + num_batch_dims);
  auto largest_example_stride = std::max_element(
      physical_strides.begin() + num_batch_dims, physical_strides.end());
  if (largest_example_stride == physical_strides.end()) {
    // No example dimensions
    return;
  }
  TORCH_CHECK(*smallest_batch_stride >= *largest_example_stride,
      "vmap: Calling Tensor.as_strided is not supported unless the batch dims being ",
      "vmapped over are at the front of the tensor (in memory layout). When they are ",
      "not at the front of the tensor this operation can be error prone so we "
      "actively discourage it; please file us a bug report and/or try to ",
      "express the as_strided operation in terms of PyTorch view operations");
}

// given (sizes, strides, storage_offset) returns the maximum location that
// can be indexed (or nullopt if such a location doesn't exist, e.g., tensors
// with zero-size dims).
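// For instance (numbers chosen only for illustration): sizes=[2, 3], strides=[3, 1],
// storage_offset=7 gives native::storage_size_for(sizes, strides) == 1 + (2-1)*3 + (3-1)*1 == 6,
// so this returns 6 + 7 == 13.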
static std::optional<int64_t> maximum_indexable_location(
    IntArrayRef sizes, IntArrayRef strides, int64_t storage_offset) {
  auto result = native::storage_size_for(sizes, strides);
  if (result == 0) {
    return std::nullopt;
  }
  return result + storage_offset;
}

// Let x be the "first slice" of physical_tensor.
// This checks that the range of possible memory locations accessible by
// x.as_strided(sizes, strides, maybe_storage_offset)
// is within the bounds of possible memory locations accessible by x.
static void checkBasicAsStridedValidForSlice(
    const Tensor& physical_tensor,
    int64_t num_batch_dims,
    IntArrayRef sizes,
    IntArrayRef strides,
    std::optional<int64_t> maybe_storage_offset) {
  auto slice_sizes = physical_tensor.sizes().slice(num_batch_dims);
  auto slice_strides = physical_tensor.strides().slice(num_batch_dims);
  auto base_offset = physical_tensor.storage_offset();

  auto storage_offset = maybe_storage_offset.value_or(base_offset);

  auto max_as_strided_loc = maximum_indexable_location(sizes, strides, storage_offset);
  auto max_slice_loc = maximum_indexable_location(slice_sizes, slice_strides, base_offset);

  if (!max_as_strided_loc.has_value()) {
    return;
  }
  if (!max_slice_loc.has_value()) {
    TORCH_CHECK(false,
        "result = tensor.as_strided(", sizes, ", ", strides, ", ", storage_offset, ") ",
        "can access memory outside of `tensor`. `tensor` has no storage but the ",
        "passed-in (size, stride, storage_offset) imply a result with some storage. ",
        "This is not supported inside of vmap, please try to rewrite the ",
        "`as_strided` call as a sequence of PyTorch view operations");
  }

  TORCH_CHECK(
      *max_as_strided_loc <= *max_slice_loc && base_offset <= storage_offset,
      "result = tensor.as_strided(", sizes, ", ", strides, ", ", storage_offset, ") ",
      "can access memory outside of `tensor`. `result` can access some ",
      "memory in range [", storage_offset, ", ", *max_as_strided_loc, "], but ",
      "`tensor` can only access some memory in range [", base_offset, ", ",
      *max_slice_loc, "]. This is not supported inside of vmap, please try to ",
      "rewrite the `as_strided` call as a sequence of PyTorch view operations");
}

Tensor _reshape_alias_batching_rule(const Tensor& self, IntArrayRef sizes, IntArrayRef strides [[maybe_unused]]) {
  return reshape_batching_rule(self, sizes);
}

Tensor _new_zeros_with_same_feature_meta_batching_rule(
    const Tensor& self,
    const Tensor& other,
    int64_t unused_num_batch_dims) {
  TORCH_CHECK(isBatchedTensor(self) && !isBatchedTensor(other),
      "Only the 'batched grad' use case is supported in PyTorch core.");

  TORCH_INTERNAL_ASSERT(unused_num_batch_dims == 0,
      "num_batch_dims should not be explicitly passed in because it will be overridden");
  auto self_physical_view = at::MultiBatchVmapTransform::logicalToPhysical(self);
  const auto& self_physical_tensor = self_physical_view.tensor();
  int64_t num_batch_dims = self_physical_view.numBatchDims();
  checkBatchDimsAtFrontInLayout(self_physical_tensor.strides(), num_batch_dims);
  auto result = at::_new_zeros_with_same_feature_meta(self_physical_tensor, other, num_batch_dims);
  return self_physical_view.getPhysicalToLogicalMap().apply(result);
}

bool _has_same_storage_numel_batching_rule(const Tensor& self, const Tensor& other) {
  TORCH_CHECK(isBatchedTensor(self) && !isBatchedTensor(other),
      "Only the 'batched grad' use case is supported in PyTorch core.");
  // The _has_same_storage_numel check is skipped if the tangent is a batched
  // tensor because using as_strided to access storage locations not indexable
  // by the input tensor is not supported in vmap
  return true;
}

// What are the semantics of as_strided inside of vmap?
// y = vmap(lambda x: x.as_strided(sizes, strides, offset))(xs)
// This returns a view on `xs`, `y`, such that each y[i] has:
// - sizes: `sizes`
// - strides: `strides`
// - storage_offset: offset + i * xs.stride(batch_dim)
//
// In other words, it is as if we had treated each xs[i] as having storage
// offset equal to xs.offset() and called as_strided(sizes, strides, offset).
// (that is equivalent to xs[i].as_strided(
//    sizes, strides, offset + xs[i].storage_offset() - xs.offset()) for all i)
//
// Note that this *may* be different from actually running as_strided
// in a for-loop. This is due to how as_strided takes in `offset` to be
// an *absolute* offset. As an example, consider:
// >>> x = torch.tensor([0., 1., 2., 3., 4.]).as_strided([4], [1], 1)
// >>> z = [x[i].as_strided([1], [1], 1) for i in range(4)]
// Each z[i] is actually the same view on x (z[i] == torch.tensor([1.]))!
// However, we consider the above for-loop comprehension to be a user error:
// a user should have written the following if they wanted to use as_strided
// in a per-sample way:
// >>> z = [x[i].as_strided([1], [1], 1 + x[i].storage_offset() - 1) for i in range(4)]
Tensor as_strided_batching_rule(
    const Tensor& tensor,
    IntArrayRef sizes,
    IntArrayRef strides,
    std::optional<int64_t> storage_offset) {
  auto physical_view = at::MultiBatchVmapTransform::logicalToPhysical(tensor);
  auto num_batch_dims = physical_view.numBatchDims();
  auto physical_sizes = physical_view.getPhysicalShape(sizes);
  const auto& physical_tensor = physical_view.tensor();

  // We can't rely on the physical as_strided call to do this for us because
  // we do some sanity checks on the size/strides before calling into as_strided.
  TORCH_CHECK(sizes.size() == strides.size(),
      "Tensor.as_strided(size, stride, ...): size and stride must have the ",
      "same length! Got size ", sizes, " and stride ", strides);

  // Sanity checks:
  // 1. All batch dims are at the front in memory layout (not necessary for
  //    correctness, but we are worried the user might be doing crazy things)
  // 2. as_strided(sizes, strides, storage_offset + tensor[i].offset() - tensor.offset())
  //    is valid for a slice of the input tensor.
  //    See Note: [When will the as_strided batching rule fail?] for details.
  checkBatchDimsAtFrontInLayout(physical_tensor.strides(), num_batch_dims);
  checkBasicAsStridedValidForSlice(
      physical_tensor, num_batch_dims, sizes, strides, storage_offset);

  // physical_strides = physical tensor's batch strides + (logical) strides
  auto batch_strides = physical_tensor.strides().slice(0, num_batch_dims);
  at::VmapDimVector physical_strides;
  physical_strides.reserve(num_batch_dims + strides.size());
  physical_strides.insert(
      physical_strides.end(), batch_strides.begin(), batch_strides.end());
  physical_strides.insert(
      physical_strides.end(), strides.begin(), strides.end());

  // If zi = xs[i].as_strided(sizes, strides, offset + xs[i].offset() - xs.offset())
  // is valid for all i, then it turns out that
  // xs.as_strided(physical_sizes, physical_strides, offset) always succeeds
  // and creates a tensor y such that each y[i] references the same memory
  // locations as zi. See NOTE: [When will the as_strided batching rule fail?]
  auto result = physical_view.tensor().as_strided(
      physical_sizes, physical_strides, storage_offset);
  return physical_view.getPhysicalToLogicalMap().apply(result);
}

// NOTE: [When will the as_strided batching rule fail?]
// If zi = xs[i].as_strided(sizes, strides, offset + xs[i].offset() - xs.offset())
// is valid for all i, then it turns out that
// xs.as_strided(physical_sizes, physical_strides, offset) always succeeds and
// creates a tensor y such that each y[i] refers to the same memory as zi.
//
// Let's say we have xs[i].as_strided(sizes, strides, offset + xs[i].offset() - xs.offset()).
// Furthermore, let's say that as a part of being "valid" this as_strided call
// does not return a result that can index memory not indexable by xs[i].
//
// WLOG, assume that there's only one batch dim and it is at the front of the
// `xs` tensor. Let B be the batch size and S be the stride of the batch dim.
// - If the batch dim isn't at the front of the tensor, then we can just move it
//   to the front with movedim/permute. This is always valid because it just swaps
//   some strides around.
// - This proof also works for tensors with multiple batch dims. We just have to
//   do a little accounting:
//   - instead of [B], we'd have [B0, B1, ..., Bk].
//   - instead of [S], we'd have [S0, S1, ..., Sk].
//   - instead of i, we'd have a list of indices [I0, I1, ..., Ik].
//   - instead of S * i, we'd have \sum_{i=0}^k S_i * I_i.
//
// [Equation 1]
// xs[i].as_strided(sizes, strides, offset + xs[i].offset() - xs.offset()) has:
// - sizes: sizes
// - strides: strides
// - offset: offset + S * i
//
// x.as_strided itself checks that:
// - (sizes, strides, offset) are in bounds for `x`'s storage.
// - strides are positive
// - offset is positive
//
// Claim 1: if xs[i].as_strided(sizes, strides, offset + xs[i].offset() - xs.offset())
// is valid, then
// ([B] + sizes, [S] + strides, offset + xs.offset()) are in bounds for `xs`'s storage.
//
// If we have the claim, then xs.as_strided([B] + sizes, [S] + strides, offset)
// won't error out. So all we need to check is that the memory locations are
// what we expected. See [Hand-wavy proof of Claim 1] for proof (it's not very important).
//
// xs.as_strided(physical_sizes, physical_strides, offset) is equivalent to
// xs.as_strided([B] + sizes, [S] + strides, offset)
//
// xs.as_strided([B] + sizes, [S] + strides, offset) has:
// - sizes: [B] + sizes
// - strides: [S] + strides
// - offset: offset
//
// xs.as_strided([B] + sizes, [S] + strides, offset)[i] has:
// - sizes: sizes
// - strides: strides
// - offset: offset + S * i
// These memory locations are exactly the same as what we got for [Equation 1],
// so xs.as_strided([B] + sizes, [S] + strides, offset) is valid.
//
// [Hand-wavy proof of Claim 1]
// Part of our definition of being valid is that xs[i].as_strided(...)
// must return a tensor that only uses memory indexable by xs[i].
// This means that (sizes, strides, offset + xs[i].offset() - xs.offset()) satisfies:
//    offset + xs[i].offset() - xs.offset() + 1 + \sum_j (sizes[j] - 1) * strides[j]
//    <= xs[i].offset() + 1 + \sum_j (xs[i].size(j) - 1) * xs[i].stride(j)
// (the largest-index memory location of xs[i].as_strided(...) must be \leq
// the largest-index memory location of xs[i])
//
// Fiddling that inequality gives us:
//    offset - xs.offset() + 1 + \sum_j (sizes[j] - 1) * strides[j]
//    <= 1 + \sum_j (xs[i].size(j) - 1) * xs[i].stride(j)
//
//    offset - xs.offset() + 1 + (B-1)*S + \sum_j (sizes[j] - 1) * strides[j]
//    <= 1 + (B-1)*S + \sum_j (xs[i].size(j) - 1) * xs[i].stride(j)
//
//    offset - xs.offset() + 1 + (B-1)*S + \sum_j (sizes[j] - 1) * strides[j]
//    <= 1 + \sum_j (xs.size(j) - 1) * xs.stride(j)
//
//    offset + 1 + (B-1)*S + \sum_j (sizes[j] - 1) * strides[j]
//    <= xs.offset() + 1 + \sum_j (xs.size(j) - 1) * xs.stride(j)
// (the largest-index memory location of xs.as_strided(size, stride, offset)
// is \leq the largest-index memory location of xs)
// Under the assumptions we've made, the lower bound (lowest indexed memory)
// is trivially within the storage.
//
// Therefore ([B] + sizes, [S] + strides, offset) are in bounds for
// `xs`'s storage.
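//
// As a concrete instance of the above (values chosen only for illustration):
// let xs be contiguous with shape [2, 3, 4], strides [12, 4, 1] and offset 0,
// so B = 2 and S = 12, and let the vmapped call be x.as_strided([2, 2], [4, 1], 1).
// Each zi = xs[i].as_strided([2, 2], [4, 1], 1 + 12*i) touches at most location
// 1 + 12*i + (2-1)*4 + (2-1)*1 = 6 + 12*i, which lies inside xs[i]'s range
// [12*i, 12*i + 11]. The batching rule issues xs.as_strided([2, 2, 2], [12, 4, 1], 1),
// and indexing its first dim at i reproduces exactly (sizes=[2, 2], strides=[4, 1],
// offset=1 + 12*i), i.e. the same memory as zi.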

template <typename F, F Func, typename... ExtraArgs>
Tensor unwrap_and_call(const Tensor& input, ExtraArgs... args) {
  auto* input_batched = unsafeGetBatchedImpl(input);
  auto output_physical = Func(input_batched->value(), args...);
  auto old_bdims = input_batched->bdims();
  return makeBatched(output_physical, BatchDims(old_bdims.begin(), old_bdims.end()));
}

template <typename F, F Func, typename... ExtraArgs>
Tensor unwrap_and_call_method(const Tensor& input, ExtraArgs... extra_args) {
  auto* input_batched = unsafeGetBatchedImpl(input);
  auto output_physical = (input_batched->value().*Func)(extra_args...);
  auto old_bdims = input_batched->bdims();
  return makeBatched(output_physical, BatchDims(old_bdims.begin(), old_bdims.end()));
}

Tensor pow_scalar_Tensor_batching_rule(const Scalar& other, const Tensor& self) {
  auto* self_batched = unsafeGetBatchedImpl(self);
  auto output_physical = at::pow(other, self_batched->value());
  auto old_bdims = self_batched->bdims();
  return makeBatched(output_physical, BatchDims(old_bdims.begin(), old_bdims.end()));
}

Tensor clone_batching_rule(const Tensor& self, std::optional<MemoryFormat> memory_format) {
  // Memory format support is a little tricky because vmap is allowed to move
  // around batch dimensions and some memory formats are rank-dependent.
  // Another weird case is:
  // - a tensor with MemoryFormat::ChannelsLast MUST have 4 dimensions. Do we
  //   allow the user to clone a Tensor with 3 logical dimensions and 1 batch
  //   dim into a ChannelsLast Tensor? What about a Tensor with 3 logical dims
  //   and N>1 batch dims?
  TORCH_CHECK(!memory_format.has_value() || memory_format == MemoryFormat::Preserve
      || memory_format == MemoryFormat::Contiguous,
      "NYI: Tensor.clone(memory_format) inside vmap is only supported with ",
      "memory_format torch.preserve_format or torch.contiguous_format (got ",
      *memory_format, ")");

  if (memory_format == MemoryFormat::Contiguous) {
    // There is an ambiguity here when the batch dims are not at the front of
    // the tensor.
    // >>> x = torch.randn(3, B0, 5)
    // >>> y = vmap(lambda x: x.clone(torch.contiguous_format), in_dims=1, out_dims=0)(x)
    // >>> y[0].is_contiguous()
    // ???
    // Should we make the whole tensor contiguous, or should we
    // make the non-batch dims contiguous? We've chosen the latter because
    // philosophically vmap hides the batch dims and operates on a per-sample level.
    auto physical_view = MultiBatchVmapTransform::logicalToPhysical(self);
    auto output_physical = at::clone(physical_view.tensor(), memory_format);
    return physical_view.getPhysicalToLogicalMap().apply(output_physical);
  }

  TORCH_INTERNAL_ASSERT(!memory_format.has_value() || memory_format == MemoryFormat::Preserve);
  auto* self_batched = unsafeGetBatchedImpl(self);
  auto output_physical = at::clone(self_batched->value(), memory_format);
  auto old_bdims = self_batched->bdims();
  return makeBatched(output_physical, BatchDims(old_bdims.begin(), old_bdims.end()));
}

// Note [Batching rules for matmul-like operators]
// at::matmul doesn't "de-expand" arguments to get better performance (maybe
// it should). In the batching rules for matmul-like operators (dot, mv, mm),
// we should be careful not to expand any unnecessary dimensions. e.g., if
// only one of the two arguments is a BatchedTensor, then we should try
// not to expand batch dimensions onto the other arg.
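// For example, in mv_batching_rule below, when only `self` is batched we call
// at::matmul on the physical [..., M, K] tensor and the plain [K] vector directly,
// rather than first expanding `other` to carry the batch dimensions.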
Tensor mv_batching_rule(const Tensor& self, const Tensor& other) {
  auto self_batched = isBatchedTensor(self);
  auto other_batched = isBatchedTensor(other);

  // A shape checking API would be nice...
  TORCH_CHECK(self.dim() == 2 && other.dim() == 1,
      "mv(self, other): Shape mismatch: expected matrix "
      "(got `self` of size ", self.sizes(), ") ",
      "and vector (got `other` of size ", other.sizes(), ")");

  // See Note [Batching rules for matmul-like operators] for why we have cases
  if (self_batched && !other_batched) {
    auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self);
    auto result = at::matmul(self_physical.tensor(), other);
    return self_physical.getPhysicalToLogicalMap().apply(result);
  }
  if (!self_batched && other_batched) {
    // self_physical: [L, K], other_physical: [..., K]
    // We view the tensors as [L, K], [..., K, 1], perform matmul to get
    // a tensor of size [..., L, 1], and then squeeze the last dim.
    auto other_physical = MultiBatchVmapTransform::logicalToPhysical(other);
    auto result = at::matmul(self, other_physical.tensor().unsqueeze(-1));
    return other_physical.getPhysicalToLogicalMap().apply(result.squeeze(-1));
  }
  if (self_batched && other_batched) {
    // self_physical: [..., L, K], other_physical: [..., K]
    // We view the tensors as [..., L, K], [..., K, 1], perform matmul to get
    // a tensor of size [..., L, 1], and then squeeze the last dim.
    auto physical_args = MultiBatchVmapTransform::logicalToPhysical({self, other});
    auto result = at::matmul(
        physical_args[0].tensor(),
        physical_args[1].tensor().unsqueeze(-1));
    return physical_args[0].getPhysicalToLogicalMap().apply(result.squeeze(-1));
  }
  TORCH_INTERNAL_ASSERT(false, "either self or other must be a BatchedTensor");
}

Tensor _make_dual_batching_rule(
    c10::DispatchKeySet ks,
    const Tensor& primal,
    const Tensor& tangent,
    int64_t level
) {
  DispatchKeySet after_batched_keyset =
      DispatchKeySet(DispatchKeySet::FULL_AFTER, c10::DispatchKey::Batched);
  return at::redispatch::_make_dual(ks & after_batched_keyset, primal, tangent, level);
}

Tensor dot_batching_rule(const Tensor& self, const Tensor& other) {
  auto self_batched = isBatchedTensor(self);
  auto other_batched = isBatchedTensor(other);

  TORCH_CHECK(/*logical*/self.dim() == 1 && /*logical*/other.dim() == 1,
      "dot(self, other): Shape mismatch: expected vector "
      "(got `self` of size ", self.sizes(), ") ",
      "and vector (got `other` of size ", other.sizes(), ")");

  // See Note [Batching rules for matmul-like operators] for why we have cases
  if (self_batched && !other_batched) {
    // self_physical: [..., K], other_physical: [K]
    // View the tensors as [..., 1, K] and [K], perform matmul, and squeeze the result.
    auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self);
    auto result = at::matmul(self_physical.tensor().unsqueeze(-2), other);
    return self_physical.getPhysicalToLogicalMap().apply(result.squeeze(-1));
  }
  if (!self_batched && other_batched) {
    // self_physical: [K], other_physical: [..., K]
    // View the tensors as [K] and [..., K, 1], perform matmul, and squeeze the result.
    auto other_physical = MultiBatchVmapTransform::logicalToPhysical(other);
    auto result = at::matmul(self, other_physical.tensor().unsqueeze(-1));
    return other_physical.getPhysicalToLogicalMap().apply(result.squeeze(-1));
  }
  if (self_batched && other_batched) {
    // self_physical: [..., K], other_physical: [..., K]
    // View the tensors as [..., 1, K] and [..., K, 1], perform matmul, and squeeze the result twice.
    auto physical_args = MultiBatchVmapTransform::logicalToPhysical({self, other});
    auto result = at::matmul(
        physical_args[0].tensor().unsqueeze(-2),
        physical_args[1].tensor().unsqueeze(-1));
    return physical_args[0].getPhysicalToLogicalMap().apply(result.squeeze(-1).squeeze(-1));
  }
  TORCH_INTERNAL_ASSERT(false, "either self or other must be a BatchedTensor");
}

Tensor bmm_batching_rule(const Tensor& self, const Tensor& other) {
  TORCH_CHECK(/*logical*/self.dim() == 3 && /*logical*/other.dim() == 3,
      "bmm(self, other): Shape mismatch: expected 3D `self` "
      "(got `self` of size ", self.sizes(), ") ",
      "and 3D `other` (got `other` of size ", other.sizes(), ")");

  auto physical_args = BroadcastingVmapTransform::logicalToPhysical({self, other});
  auto result = at::matmul(physical_args[0].tensor(), physical_args[1].tensor());
  return physical_args[0].getPhysicalToLogicalMap().apply(result);
}

Tensor mm_batching_rule(const Tensor& self, const Tensor& other) {
  auto self_batched = isBatchedTensor(self);
  auto other_batched = isBatchedTensor(other);

  TORCH_CHECK(/*logical*/self.dim() == 2 && /*logical*/other.dim() == 2,
      "mm(self, other): Shape mismatch: expected matrix "
      "(got `self` of size ", self.sizes(), ") ",
      "and matrix (got `other` of size ", other.sizes(), ")");

  // See Note [Batching rules for matmul-like operators] for why we have cases
  if (self_batched && !other_batched) {
    auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self);
    auto result = at::matmul(self_physical.tensor(), other);
    return self_physical.getPhysicalToLogicalMap().apply(result);
  }
  if (!self_batched && other_batched) {
    auto other_physical = MultiBatchVmapTransform::logicalToPhysical(other);
    auto result = at::matmul(self, other_physical.tensor());
    return other_physical.getPhysicalToLogicalMap().apply(result);
  }
  if (self_batched && other_batched) {
    auto physical_args = MultiBatchVmapTransform::logicalToPhysical({self, other});
    auto result = at::matmul(physical_args[0].tensor(), physical_args[1].tensor());
    // NB: unlike dot/mv above, no unsqueezes were inserted here, so the matmul
    // result already has the right shape and must not be squeezed.
    return physical_args[0].getPhysicalToLogicalMap().apply(result);
  }
  TORCH_INTERNAL_ASSERT(false, "either self or other must be a BatchedTensor");
}

Tensor cat_batching_rule(const ITensorListRef& tensors, int64_t dim) {
  auto physical_views = MultiBatchVmapTransform::logicalToPhysical(tensors);
  auto physical_tensors = fmap(
      physical_views, [](const VmapPhysicalView& view) -> Tensor { return view.tensor(); });
  TORCH_INTERNAL_ASSERT(
      !tensors.empty(), "The dispatcher should not have dispatched here otherwise.");
  auto result = at::cat(physical_tensors, physical_views[0].getPhysicalDim(dim));
  return physical_views[0].getPhysicalToLogicalMap().apply(result);
}

Tensor stack_batching_rule(TensorList tensors, int64_t dim) {
  auto physical_views = MultiBatchVmapTransform::logicalToPhysical(tensors);
  auto physical_tensors = fmap(
      physical_views, [](const VmapPhysicalView& view) -> Tensor { return view.tensor(); });
  TORCH_INTERNAL_ASSERT(
      !tensors.empty(), "The dispatcher should not have dispatched here otherwise.");
  // NB: stack wraps the dimensionality to (logical dim + 1), so we have to
  // manually handle that here.
  auto dim_physical =
      physical_views[0].numBatchDims() + maybe_wrap_dim(dim, /*logical*/tensors[0].dim() + 1);
  auto result = at::stack(physical_tensors, dim_physical);
  return physical_views[0].getPhysicalToLogicalMap().apply(result);
}

// I am quite sad that we need to register operators with exploded TensorOptions,
// even though the native:: implementations can use TensorOptions&.
// This also makes it hard to metaprogram: i.e., we can't use
// unwrap_and_call<..., at::to> because at::to takes TensorOptions& (!!)
Tensor to_dtype_layout_batching_rule(
    const Tensor& self,
    std::optional<ScalarType> dtype,
    std::optional<Layout> layout,
    std::optional<Device> device,
    std::optional<bool> pin_memory,
    bool non_blocking, bool copy,
    std::optional<MemoryFormat> memory_format) {
  auto options = TensorOptions()
      .dtype(dtype)
      .layout(layout)
      .device(device)
      .pinned_memory(pin_memory);
  auto* input_batched = unsafeGetBatchedImpl(self);
  auto output_physical = input_batched->value().to(options, non_blocking, copy, memory_format);
  auto old_bdims = input_batched->bdims();
  return makeBatched(output_physical, BatchDims(old_bdims.begin(), old_bdims.end()));
}

Tensor new_zeros_batching_rule(
    const Tensor& self,
    IntArrayRef size,
    std::optional<ScalarType> dtype,
    std::optional<Layout> layout,
    std::optional<Device> device,
    std::optional<bool> pin_memory) {
  auto physical_view = MultiBatchVmapTransform::logicalToPhysical(self);
  auto physical_size = physical_view.getPhysicalShape(size);
  auto options = TensorOptions()
      .dtype(dtype)
      .layout(layout)
      .device(device)
      .pinned_memory(pin_memory);
  auto result = physical_view.tensor().new_zeros(physical_size, options);
  return physical_view.getPhysicalToLogicalMap().apply(result);
}

Tensor new_empty_batching_rule(
    const Tensor& self,
    IntArrayRef size,
    std::optional<ScalarType> dtype,
    std::optional<Layout> layout,
    std::optional<Device> device,
    std::optional<bool> pin_memory) {
  auto physical_view = MultiBatchVmapTransform::logicalToPhysical(self);
  auto physical_size = physical_view.getPhysicalShape(size);
  auto result = physical_view.tensor().new_empty(physical_size, TensorOptions().dtype(dtype).layout(layout).device(device).pinned_memory(pin_memory));
  return physical_view.getPhysicalToLogicalMap().apply(result);
}
1008
new_empty_strided_batching_rule(const Tensor & self,IntArrayRef size,IntArrayRef stride,std::optional<ScalarType> dtype,std::optional<Layout> layout,std::optional<Device> device,std::optional<bool> pin_memory)1009 Tensor new_empty_strided_batching_rule(
1010 const Tensor& self,
1011 IntArrayRef size,
1012 IntArrayRef stride,
1013 std::optional<ScalarType> dtype,
1014 std::optional<Layout> layout,
1015 std::optional<Device> device,
1016 std::optional<bool> pin_memory) {
1017 auto physical_view = MultiBatchVmapTransform::logicalToPhysical(self);
1018 auto physical_size = physical_view.getPhysicalShape(size);
1019
1020 // Let [B0, B1, B2] be the shape of the batch dims. We're going to create
1021 // the batch dimensions at the front of the tensor (in memory layout),
1022 // irrespective of whether or not they are actually at the front (in memory layout)
1023 // in the original `self` tensor. This is because when a user calls
1024 // `new_empty_strided` in general, the `strides` they provide are for a new
1025 // tensor and have no relation to the strides of the original tensor.
1026 //
1027 // So, the physical shape of the result should be ([B0, B1, B2] + size),
1028 // but what about the physical strides?
1029 //
1030 // We're actually free to pick whatever stride we want:
1031 // e.g., for size=[5, 3], stride=[0, 1], we could decide to
1032 // use
1033 // - physical size: [B0, B1, B2, 5, 3]
1034 // - physical stride: [9999*B1*B2, 9999*B2, 9999, 0, 1]
1035 //
1036 // Let's select some reasonable strides such that:
1037 // - The batch dims are "contiguous" with respect to each other
1038 // - if empty_strided(size, stride) would have created a contiguous Tensor,
1039 // then this new physical Tensor (with batch dims) is also contiguous
1040 //
1041 // Let S be the size of the storage if one were to construct a tensor
1042 // with `size` and `stride` via empty_strided(size, stride).
1043 // Then the physical sizes/strides should be:
1044 // - physical size: [B0, B1, B2, 5, 3]
1045 // - physical stride: [B1 * B2 * S, B2 * S, S, 0, 1]
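  //
  // Illustrative worked example (added; the concrete numbers are hypothetical):
  // with batch dims [B0=2, B1=3, B2=4], size=[5, 3], stride=[3, 1]:
  //   S = storage_size_for([5, 3], [3, 1]) = (5-1)*3 + (3-1)*1 + 1 = 15
  //   physical size   = [2, 3, 4, 5, 3]
  //   physical stride = [3*4*15, 4*15, 15, 3, 1] = [180, 60, 15, 3, 1]
  // i.e. defaultStrides([2, 3, 4]) scaled by S, concatenated with `stride`.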
  auto batch_shape = IntArrayRef(
      physical_view.tensor().sizes().begin(), physical_view.numBatchDims());

  // physical_strides = [B1 * B2 * S, B2 * S, S]
  auto physical_strides = at::detail::defaultStrides(batch_shape);
  TORCH_CHECK(size.size() == stride.size(),
        "new_empty_strided(sizes, strides): dimensionality of sizes (",
        size.size(), ") must match dimensionality of strides (",
        stride.size(), ")");
  auto storage_size = native::storage_size_for(size, stride);
  for (auto& physical_stride : physical_strides) {
    physical_stride *= storage_size;
  }

  // physical_strides = [B1 * B2 * S, B2 * S, S] + strides
  physical_strides.insert(physical_strides.end(), stride.begin(), stride.end());

  auto result = physical_view.tensor().new_empty_strided(
      physical_size, physical_strides, dtype, layout, device, pin_memory);
  return physical_view.getPhysicalToLogicalMap().apply(result);
}

template <typename F, F Func>
Tensor comparison_pointwise_batching_rule(const Tensor& self, const Tensor& other) {
  auto physical_args = BroadcastingVmapTransform::logicalToPhysical({self, other});
  auto result = Func(physical_args[0].tensor(), physical_args[1].tensor());
  return physical_args[0].getPhysicalToLogicalMap().apply(result);
}
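// Added commentary: BroadcastingVmapTransform moves the batch dims of both
// inputs to the front and aligns them so the two physical tensors broadcast
// against each other. So, e.g., comparison_pointwise_batching_rule
// <TensorTensorType, at::eq> (instantiated below) handles vmap over torch.eq
// by calling at::eq on the aligned physical tensors and rewrapping the result.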
}  // namespace

TORCH_LIBRARY_IMPL(_, Batched, m) {
  m.fallback(torch::CppFunction::makeFromBoxedFunction<&batchedTensorForLoopFallback>());
}

TORCH_LIBRARY_IMPL(aten, Batched, m) {
  // NB: Ideally we would like some operators, like size.int, to "fallthrough"
  // to the underlying implementation. However, because a BatchedTensor is a
  // Tensor wrapper, it only has one dispatch key (Batched) on it. The resolution
  // here is to just directly call the underlying implementation.
  m.impl("size.int", static_cast<int64_t (*)(const Tensor&, int64_t)>(native::size));
  m.impl("_add_batch_dim", native::_add_batch_dim);
  m.impl("_remove_batch_dim", native::_remove_batch_dim);
  m.impl("_make_dual", _make_dual_batching_rule);
  m.impl("_has_same_storage_numel", _has_same_storage_numel_batching_rule);
  m.impl("is_same_size", native::is_same_size);
  m.impl("_new_zeros_with_same_feature_meta", _new_zeros_with_same_feature_meta_batching_rule);

  m.impl("sum.dim_IntList", sum_batching_rule);
  m.impl("is_complex", native::is_complex);

  // inplace operations
  m.impl("fill_.Scalar", fill_inplace_scalar_batching_rule);
  m.impl("fill_.Tensor", fill_inplace_tensor_batching_rule);
  m.impl("zero_", zero_inplace_batching_rule);

  // view operations
  m.impl("as_strided", as_strided_batching_rule);
  m.impl("chunk", chunk_batching_rule);
  m.impl("tensor_split.sections", tensor_split_sections_batching_rule);
  m.impl("tensor_split.indices", tensor_split_indices_batching_rule);
  m.impl("diagonal", diagonal_batching_rule);
  m.impl("expand", expand_batching_rule);
  m.impl("expand_as", native::expand_as); // composite wrt autograd
  m.impl("movedim.intlist", movedim_batching_rule);
  m.impl("movedim.int", static_cast<Tensor(*)(const Tensor&,int64_t,int64_t)>(native::movedim)); // composite wrt autograd
  // There is another variant of narrow. However, we don't
  // want to support the other variant yet because it isn't documented...
  m.impl("narrow", native::narrow_symint); // composite wrt autograd
  m.impl("numpy_T", native::numpy_T); // composite wrt autograd
  m.impl("matrix_H", native::matrix_H); // composite wrt autograd
  m.impl("mT", native::mT); // composite wrt autograd
  m.impl("mH", native::mH); // composite wrt autograd
  m.impl("permute", permute_batching_rule);
  m.impl("reshape", reshape_batching_rule);
  m.impl("_reshape_alias", _reshape_alias_batching_rule);
  m.impl("reshape_as", native::reshape_as); // composite wrt autograd
  m.impl("select.int", select_batching_rule);
  m.impl("slice.Tensor", slice_batching_rule);
  m.impl("split.Tensor", split_batching_rule);
  m.impl("split.sizes", split_with_sizes_batching_rule);
  m.impl("split_with_sizes", split_with_sizes_batching_rule);
  m.impl("squeeze", squeeze_batching_rule);
  m.impl("squeeze.dim", squeeze_dim_batching_rule);
  m.impl("squeeze.dims", squeeze_dims_batching_rule);
  m.impl("t", native::t); // composite wrt autograd
  m.impl("trace", trace_batching_rule);
  m.impl("transpose.int", transpose_int_batching_rule);
  m.impl("unbind.int", unbind_batching_rule);
  m.impl("unfold", unfold_batching_rule);
  m.impl("unsqueeze", unsqueeze_batching_rule);
  m.impl("view", view_batching_rule);
  m.impl("view_as", native::view_as); // composite wrt autograd

  // clamp operations
  m.impl("clamp", clamp_batching_rule);
  m.impl("clamp_min", clamp_min_batching_rule);
  m.impl("clamp_max", clamp_max_batching_rule);

  // unary pointwise, out-of-place, no additional arguments.
#define UNARY_POINTWISE(op) m.impl(#op, \
    unwrap_and_call<Tensor (*)(const Tensor&), at::op>);
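  // Added note: UNARY_POINTWISE(abs), for example, expands (roughly) to
  //   m.impl("abs", unwrap_and_call<Tensor (*)(const Tensor&), at::abs>);
  // i.e. the BatchedTensor wrapper is removed, the plain at:: op is applied to
  // the underlying physical tensor, and the result is rewrapped with the same
  // batch dims -- which is valid because these ops are elementwise.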
  UNARY_POINTWISE(abs);
  UNARY_POINTWISE(acos);
  UNARY_POINTWISE(asin);
  UNARY_POINTWISE(atan);
  UNARY_POINTWISE(ceil);
  UNARY_POINTWISE(cos);
  UNARY_POINTWISE(cosh);
  UNARY_POINTWISE(conj_physical);
  UNARY_POINTWISE(digamma);
  UNARY_POINTWISE(exp);
  UNARY_POINTWISE(expm1);
  UNARY_POINTWISE(floor);
  UNARY_POINTWISE(frac);
  UNARY_POINTWISE(lgamma);
  UNARY_POINTWISE(log);
  UNARY_POINTWISE(log10);
  UNARY_POINTWISE(log1p);
  UNARY_POINTWISE(log2);
  UNARY_POINTWISE(neg);
  UNARY_POINTWISE(reciprocal);
  UNARY_POINTWISE(relu);
  UNARY_POINTWISE(round);
  UNARY_POINTWISE(rsqrt);
  UNARY_POINTWISE(sigmoid);
  UNARY_POINTWISE(sign);
  UNARY_POINTWISE(sin);
  UNARY_POINTWISE(sinh);
  UNARY_POINTWISE(sqrt);
  UNARY_POINTWISE(tan);
  UNARY_POINTWISE(tanh);
  UNARY_POINTWISE(trunc);
#undef UNARY_POINTWISE
#define TO_BATCHING_RULE(name, ...) \
  { \
    using to_type = Tensor(Tensor::*)(__VA_ARGS__) const; \
    m.impl(name, unwrap_and_call_method< \
        to_type, &Tensor::to, __VA_ARGS__>);\
  }
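  // Added note: e.g. TO_BATCHING_RULE("to.dtype", ScalarType, bool, bool,
  // std::optional<MemoryFormat>) registers a rule that (roughly) unwraps the
  // BatchedTensor, calls that Tensor::to overload on the underlying physical
  // tensor, and rewraps the result with the original batch dims.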
  TO_BATCHING_RULE("to.device", Device, ScalarType, bool, bool, std::optional<MemoryFormat>)
  TO_BATCHING_RULE("to.dtype", ScalarType, bool, bool, std::optional<MemoryFormat>)
  TO_BATCHING_RULE("to.other", const Tensor&, bool, bool, std::optional<MemoryFormat>)
  m.impl("to.dtype_layout", to_dtype_layout_batching_rule);
#undef TO_BATCHING_RULE
  m.impl("clone", clone_batching_rule);

  using TensorTensorScalarType = Tensor (*)(const Tensor&, const Tensor&, const Scalar&);
  using TensorTensorType = Tensor (*)(const Tensor&, const Tensor&);
  using TensorScalarType = Tensor (*)(const Tensor&, const Scalar&);

#define BINARY_POINTWISE(op) \
  m.impl(#op".Tensor", binary_pointwise_batching_rule<TensorTensorType, at::op>); \
  m.impl(#op".Scalar", unwrap_and_call<TensorScalarType, at::op, const Scalar&>);
#define BINARY_POINTWISE_VA(op, ...) \
  { \
    using Binop = Tensor (*)(const Tensor&, const Tensor&, __VA_ARGS__); \
    using Unop = Tensor (*)(const Tensor&, const Scalar&, __VA_ARGS__); \
    m.impl(#op".Tensor", binary_pointwise_batching_rule<Binop, at::op, __VA_ARGS__>); \
    m.impl(#op".Scalar", unwrap_and_call<Unop, at::op, const Scalar&, __VA_ARGS__>); \
  }
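  // Added note: BINARY_POINTWISE_VA(add, const Scalar&) expands (roughly) to
  //   m.impl("add.Tensor",
  //          binary_pointwise_batching_rule<
  //              Tensor (*)(const Tensor&, const Tensor&, const Scalar&),
  //              at::add, const Scalar&>);
  //   m.impl("add.Scalar",
  //          unwrap_and_call<
  //              Tensor (*)(const Tensor&, const Scalar&, const Scalar&),
  //              at::add, const Scalar&, const Scalar&>);
  // i.e. the Tensor overload goes through the broadcasting binary rule, while
  // the Scalar overload just unwraps, calls at::add, and rewraps.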

  BINARY_POINTWISE_VA(add, const Scalar&);
  BINARY_POINTWISE_VA(sub, const Scalar&);
  BINARY_POINTWISE_VA(rsub, const Scalar&);
  BINARY_POINTWISE(mul);
  BINARY_POINTWISE(div);
  {
    using Binop = Tensor (*)(const Tensor&, const Tensor&, std::optional<c10::string_view>);
    using Unop = Tensor (*)(const Tensor&, const Scalar&, std::optional<c10::string_view>);
    m.impl("div.Tensor_mode", binary_pointwise_batching_rule<Binop, at::div, std::optional<c10::string_view>>);
    m.impl("div.Scalar_mode", unwrap_and_call<Unop, at::div, const Scalar&, std::optional<c10::string_view>>);
  }

  // at::pow has three out-of-place overloads
  m.impl("pow.Tensor_Tensor", binary_pointwise_batching_rule<TensorTensorType, at::pow>);
  m.impl("pow.Tensor_Scalar", unwrap_and_call<TensorScalarType, at::pow, const Scalar&>);
  m.impl("pow.Scalar", pow_scalar_Tensor_batching_rule);

  m.impl("sigmoid_backward", binary_pointwise_batching_rule<TensorTensorType, at::sigmoid_backward>);
  m.impl(
      "threshold_backward",
      binary_pointwise_batching_rule<
          TensorTensorScalarType,
          at::threshold_backward,
          const Scalar&>);

  // For at::result_type, call the native::result_type implementation.
  // We don't have to do anything special because native::result_type operates
  // on the logical shape of the tensors.
  m.impl("result_type.Tensor", static_cast<ScalarType (*)(const Tensor&, const Tensor&)>(native::result_type));
  m.impl("result_type.Scalar", static_cast<ScalarType (*)(const Tensor&, const Scalar&)>(native::result_type));
  m.impl("result_type.Scalar_Tensor", static_cast<ScalarType (*)(const Scalar&, const Tensor&)>(native::result_type));
  m.impl("result_type.Scalar_Scalar", static_cast<ScalarType (*)(const Scalar&, const Scalar&)>(native::result_type));

#undef BINARY_POINTWISE_VA
#undef BINARY_POINTWISE

#define TRIVIAL_OP(op) m.impl(#op, \
    unwrap_and_call<Tensor (*)(const Tensor&), at::op>);
  // complex number view operators
  TRIVIAL_OP(imag);
  TRIVIAL_OP(real);
  TRIVIAL_OP(view_as_real);
  TRIVIAL_OP(conj);
  TRIVIAL_OP(_conj);
  TRIVIAL_OP(resolve_conj);
  TRIVIAL_OP(resolve_neg);
  m.impl("view_as_complex", view_as_complex_batching_rule);
#undef TRIVIAL_OP

  // matmul-like operators
  m.impl("mv", mv_batching_rule);
  m.impl("dot", dot_batching_rule);
  m.impl("bmm", bmm_batching_rule);
  m.impl("mm", mm_batching_rule);

  // cat/stack
  m.impl("cat", cat_batching_rule);
  m.impl("stack", stack_batching_rule);

  // backward operators
  m.impl("select_backward", select_backward_batching_rule);
  m.impl("slice_backward", slice_backward_batching_rule);
  m.impl("trace_backward", trace_backward_batching_rule);
  m.impl("diagonal_backward", diagonal_backward_batching_rule);

  // Tensor.new_* operators
  m.impl("new_empty", new_empty_batching_rule);
  m.impl("new_empty_strided", new_empty_strided_batching_rule);
  m.impl("new_zeros", new_zeros_batching_rule);

  m.impl("contiguous", contiguous_batching_rule);

  // Comparison ops
#define COMPARISON_POINTWISE(op) \
  m.impl(#op".Tensor", comparison_pointwise_batching_rule<TensorTensorType, at::op>); \
  m.impl(#op".Scalar", unwrap_and_call<TensorScalarType, at::op, const Scalar&>);

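  // Added note: COMPARISON_POINTWISE(eq) expands (roughly) to
  //   m.impl("eq.Tensor", comparison_pointwise_batching_rule<TensorTensorType, at::eq>);
  //   m.impl("eq.Scalar", unwrap_and_call<TensorScalarType, at::eq, const Scalar&>);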
  COMPARISON_POINTWISE(eq);
  COMPARISON_POINTWISE(gt);
  COMPARISON_POINTWISE(ge);
  COMPARISON_POINTWISE(le);
  COMPARISON_POINTWISE(lt);
  COMPARISON_POINTWISE(ne);

#undef COMPARISON_POINTWISE
}

} // namespace at