1 #include <torch/library.h>
2 #include <ATen/ATen.h>
3 #include <ATen/LegacyVmapTransforms.h>
4 #include <ATen/LegacyBatchedFallback.h>
5 #include <ATen/RedispatchFunctions.h>
6 #include <ATen/native/ResizeCommon.h>
7 #include <ATen/core/IListRef.h>
8 #include <c10/util/irange.h>
9 #include <c10/core/SymIntArrayRef.h>
10 
11 #include <utility>
12 
13 namespace at {
14 
15 // NOTE: [What is a batching rule?]
16 //
17 // A *batching rule* implements the logic of how to call an operator on inputs
18 // that have zero or more additional batch dimensions. When one does a vmap, the
19 // dimension(s) being vmap'ed over get recorded as batch dimensions.
20 //
21 // For example, vmap(torch.add)(x, y)
22 // 1. wraps `x` into batched_x = BatchedTensor(x, bdims=[(lvl=1, dim=0)]);
23 // 2. wraps `y` into batched_y = BatchedTensor(y, bdims=[(lvl=1, dim=0)]);
24 // 3. and then runs `torch.add(batched_x, batched_y)`.
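//
// As a concrete shape-level illustration (a sketch, not code from this file):
// if `x` and `y` both have shape [B0, 3], then
// >>> out = vmap(torch.add)(x, y)   # out has shape [B0, 3]
// and each out[i] equals torch.add(x[i], y[i]).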
25 
26 // NOTE: [When should I add a batching rule?]
27 // When you are adding a new operator, you'll need to add a batching rule so
28 // that vmap can work efficiently with said operator. If you do not, vmap will
29 // attempt to use a slower fallback implementation for that operator.
30 
31 // NOTE: [How to write batching rules?]
32 // The signature of a batching rule should look exactly like the C++ signature
33 // of its operator.
34 //
35 // First, see NOTE: [Logical vs physical args] in VmapTransforms.h for terminology.
36 //
37 // At a high level, what a batching rule does is the following:
38 // 1. Converts (logical) BatchedTensors to views on physical tensors.
39 // 2. Converts logical arguments (e.g. dimension indexes, shapes) to physical
40 //    arguments that correspond to the physical tensors.
41 // 3. Calls at:: operations on the physical tensors and arguments to produce
42 //    some physical results.
43 // 4. Converts physical results back to BatchedTensors.
44 //
45 // Steps 1, 2, and 4 differ for operators with different batching behaviors. When
46 // writing a new batching rule, please select a VmapTransform that matches the
47 // batching behavior of your operation. The VmapTransform provides helper functions
48 // to do steps (1), (2), and (4).
49 // (see NOTE: [What is a VmapTransform?] in VmapTransforms.h)
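//
// As an illustrative sketch only (the operator and rule below are hypothetical
// and not registered in this file), a unary pointwise batching rule following
// steps (1)-(4) would look roughly like:
//
//   Tensor cos_batching_rule(const Tensor& self) {
//     // (1) convert the logical BatchedTensor to a view on a physical tensor
//     auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self);
//     // (2) a pointwise op has no dim/shape arguments to translate
//     // (3) call the at:: operation on the physical tensor
//     auto result = at::cos(self_physical.tensor());
//     // (4) wrap the physical result back up as a BatchedTensor
//     return self_physical.getPhysicalToLogicalMap().apply(result);
//   }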
50 
51 // Note: [Future plans]
52 // The API for writing a batching rule isn't stable. In the future, we'd like
53 // to think about the problem of translating these batching rules to TorchScript.
54 // Ideally batching rules in eager mode vs TorchScript would look pretty similar,
55 // if not use the same mechanism. In order to accomplish that we might have to
56 // do some refactoring.
57 
58 namespace{
59 
60 // PyTorch allows operations to specify dim 0 and dim -1 on a scalar tensor.
61 static bool is_allowed_dim_on_scalar_tensor(int64_t dim) {
62   return dim == 0 || dim == -1;
63 }
64 
65 Tensor sum_batching_rule(const Tensor& self, OptionalIntArrayRef opt_dims, bool keepdim, std::optional<ScalarType> dtype) {
66   if (opt_dims.has_value()) {
67     auto dims = opt_dims.value();
68     // PyTorch has a special case where sum(scalar_tensor, dim=0) does not fail
69     // and instead returns a new scalar tensor (this also happens for dim=-1)
70     // If the following happens:
71     // >>> x = torch.randn(B0)  # the per-examples are all scalars
72     // >>> vmap(partial(torch.sum, dim=0), x)
73     // then we replicate the behavior of sum(scalar_tensor, dim=0).
74     if (/*logical*/self.dim() == 0 && (dims.empty() || (dims.size() == 1 && is_allowed_dim_on_scalar_tensor(dims[0])))) {
75       return self.clone();
76     }
77   }
78   auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self);
79   auto dims_physical = self_physical.getPhysicalDims(opt_dims);
80   auto result = at::sum(self_physical.tensor(), dims_physical, keepdim, dtype);
81   return self_physical.getPhysicalToLogicalMap().apply(result);
82 }
83 
84 bool isPhysicalScalarTensor(const Tensor& logical_tensor) {
85   if (logical_tensor.dim() > 0) {
86     return false;
87   }
88   auto* batched = maybeGetBatchedImpl(logical_tensor);
89   if (batched) {
90     return false;
91   }
92   return true;
93 }
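// For example, under vmap over x = torch.randn(B0), each per-example value is
// a *logical* scalar, but the underlying physical tensor has dim 1 and is a
// BatchedTensor, so isPhysicalScalarTensor returns false for it; a plain,
// unbatched 0-dim tensor such as torch.tensor(1.) returns true.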
94 
95 template <typename F, F Func, typename... ExtraArgs>
96 Tensor binary_pointwise_batching_rule(
97     const Tensor& self, const Tensor& other, ExtraArgs... args) {
98   if (self.dim() > 0 && other.dim() > 0) {
99     auto physical_args = BroadcastingVmapTransform::logicalToPhysical({self, other});
100     auto result = Func(physical_args[0].tensor(), physical_args[1].tensor(), args...);
101     return physical_args[0].getPhysicalToLogicalMap().apply(result);
102   }
103   if (isPhysicalScalarTensor(self)) {
104     auto other_physical = MultiBatchVmapTransform::logicalToPhysical(other);
105     auto result = Func(self, other_physical.tensor(), args...);
106     return other_physical.getPhysicalToLogicalMap().apply(result);
107   }
108   if (isPhysicalScalarTensor(other)) {
109     auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self);
110     auto result = Func(self_physical.tensor(), other, args...);
111     return self_physical.getPhysicalToLogicalMap().apply(result);
112   }
113 
114   // At this point, we know at least one of the operands is a logical Scalar tensor.
115   // Here we must emulate TensorIterator's special behavior on Scalars.
116   //
117   // As a motivating example, consider the following:
118   //   x = torch.randn(3, 10)
119   //   y = torch.randn(3, dtype=torch.double)
120   //   vmap(torch.mul)(x, y)
121   //
122   // At a per-example level, we are multiplying FloatTensor[10] and DoubleTensor[];
123   // Type Promotion dictates that the result should be FloatTensor[10].
124   // This means we cannot directly pass the physical tensors (x and y) to
125   // TensorIterator (if we did, it would promote them to DoubleTensor).
126   //
127   // FIXME(rzou): I didn't want to go down the slippery slope of emulating
128   // everything TensorIterator does (it would be better to refactor out the
129   // TensorIterator logic). The one thing that this code doesn't handle
130   // is cross-device logical scalar tensors.
131   //   cpu_tensor = torch.randn(3)
132   //   cuda_tensor = torch.randn(3, 10, device='cuda')
133   //   vmap(torch.mul)(cpu_tensor, cuda_tensor)
134   //
135   // At a per-example level, we are multiplying CPUTensor[] and CUDATensor[10].
136   // TensorIterator allows for this cross-device operation because one of the
137   // tensors is a Scalar CPU tensor. However, the following code will throw an
138   // error in that case. I don't expect to see many use cases for this, so
139   // this is probably fine as-is.
140   auto logical_self = self;
141   auto logical_other = other;
142   auto result_type = at::native::result_type(logical_self, logical_other);
143   if (logical_self.scalar_type() != result_type) {
144     logical_self = logical_self.to(result_type);
145   }
146   if (logical_other.scalar_type() != result_type) {
147     logical_other = logical_other.to(result_type);
148   }
149   auto physical_args = BroadcastingVmapTransform::logicalToPhysical(
150       {std::move(logical_self), std::move(logical_other)});
151   auto result = Func(physical_args[0].tensor(), physical_args[1].tensor(), args...);
152   return physical_args[0].getPhysicalToLogicalMap().apply(result);
153 }
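// A shape/dtype sketch of the type-promotion path above:
// >>> x = torch.randn(B0, 10)                   # float, batched
// >>> y = torch.randn(B0, dtype=torch.double)   # per-example logical scalars
// >>> vmap(torch.mul)(x, y)
// result_type(FloatTensor[10], DoubleTensor[]) is float, so `y` is cast to
// float before the broadcasting transform, matching what TensorIterator would
// have done at the per-example level.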
154 
155 Tensor expand_batching_rule(const Tensor& self, IntArrayRef size, bool implicit) {
156   auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self);
157   auto size_physical = self_physical.getPhysicalShape(size);
158   auto self_physical_dim = self_physical.tensor().dim();
159 
160   TORCH_CHECK(self_physical_dim <= static_cast<int64_t>(size_physical.size()),
161        "expand: the number of sizes provided (", /*logical*/size.size(), ") ",
162        "must be greater or equal to the number of dimensions in the tensor (",
163        /*logical dim*/self.dim(), ")");
164 
165   if (self_physical_dim == static_cast<int64_t>(size_physical.size())) {
166     auto result = self_physical.tensor().expand(size_physical, implicit);
167     return self_physical.getPhysicalToLogicalMap().apply(result);
168   }
169 
170   TORCH_INTERNAL_ASSERT(self_physical_dim < static_cast<int64_t>(size_physical.size()));
171   // Here, we know we are expanding a (logical) tensor to a larger number
172   // of dimensions. We have to be careful because we can't call expand directly
173   // due to the presence of batch dimensions.
174   //
175   // As an example, let B0 be a batch dimension and consider expand(Tensor[B0, 3], [2, 3]).
176   // The result should be a tensor of size [B0, 2, 3].
177   // A physical view of size [B0, 3] can't directly be expanded to size [B0, 2, 3]
178   // so the strategy here is to view it first as a tensor of size [B0, 1, 3] and
179   // then expand.
180   auto self_physical_size = self_physical.tensor().sizes();
181   auto extra_dims = size_physical.size() - self_physical_dim;
182   VmapDimVector view_shape(size_physical.size(), 1);
183   std::copy(self_physical_size.begin(),
184             self_physical_size.begin() + self_physical.numBatchDims(),
185             view_shape.begin());
186   std::copy(self_physical_size.begin() + self_physical.numBatchDims(),
187             self_physical_size.end(),
188             view_shape.begin() + self_physical.numBatchDims() + extra_dims);
189   auto result = self_physical.tensor().view(view_shape).expand(size_physical, implicit);
190   return self_physical.getPhysicalToLogicalMap().apply(result);
191 }
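// Worked example of the branch above: expand(Tensor[B0, 3], [2, 3]) with one
// batch dim. size_physical is [B0, 2, 3] while the physical tensor has shape
// [B0, 3], so extra_dims == 1 and view_shape is assembled as [B0, 1, 3];
// viewing to [B0, 1, 3] and then expanding produces the desired [B0, 2, 3].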
192 
193 std::vector<Tensor> chunk_batching_rule(const Tensor& self, int64_t chunks, int64_t dim) {
194   auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self);
195   auto dim_physical = self_physical.getPhysicalDim(dim);
196   auto result = at::chunk(self_physical.tensor(), chunks, dim_physical);
197   self_physical.getPhysicalToLogicalMap().applyInplace(result);
198   return result;
199 }
200 
201 Tensor clamp_batching_rule(const Tensor& self, const std::optional<Scalar>& min, const std::optional<Scalar>& max) {
202   auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self);
203   auto result = at::clamp(self_physical.tensor(), min, max);
204   return self_physical.getPhysicalToLogicalMap().apply(result);
205 }
206 
207 Tensor clamp_min_batching_rule(const Tensor& self, const Scalar& min) {
208   auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self);
209   auto result = at::clamp_min(self_physical.tensor(), min);
210   return self_physical.getPhysicalToLogicalMap().apply(result);
211 }
212 
213 Tensor clamp_max_batching_rule(const Tensor& self, const Scalar& max) {
214   auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self);
215   auto result = at::clamp_max(self_physical.tensor(), max);
216   return self_physical.getPhysicalToLogicalMap().apply(result);
217 }
218 
219 std::vector<Tensor> tensor_split_sections_batching_rule(const Tensor& self, int64_t sections, int64_t dim) {
220   auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self);
221   auto dim_physical = self_physical.getPhysicalDim(dim);
222   auto result = at::tensor_split(self_physical.tensor(), sections, dim_physical);
223   self_physical.getPhysicalToLogicalMap().applyInplace(result);
224   return result;
225 }
226 
227 std::vector<Tensor> tensor_split_indices_batching_rule(const Tensor& self, IntArrayRef indices, int64_t dim) {
228   auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self);
229   auto dim_physical = self_physical.getPhysicalDim(dim);
230   auto result = at::tensor_split(self_physical.tensor(), indices, dim_physical);
231   self_physical.getPhysicalToLogicalMap().applyInplace(result);
232   return result;
233 }
234 
235 Tensor unsqueeze_batching_rule(const Tensor& self, int64_t dim) {
236   auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self);
237   // NB: unsqueeze has some special handling of its `dim` argument so we can't call
238   // self_physical.getPhysicalDim directly. In particular, native::unsqueeze
239   // wraps the dim to (the logical dimension) + 1, so we need to do that here too.
240   // https://github.com/pytorch/pytorch/blob/b623bdeabb0aa8da44285d303246e7f8ac06c2a9/aten/src/ATen/native/TensorShape.cpp#L1413
241   auto dim_physical =
242       self_physical.numBatchDims() + maybe_wrap_dim(dim, /*logical_dim*/self.dim() + 1);
243   auto result = self_physical.tensor().unsqueeze(dim_physical);
244   return self_physical.getPhysicalToLogicalMap().apply(result);
245 }
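// Worked example of the wrapping above: for a logical 2-dim tensor with one
// batch dim, unsqueeze(-1) wraps to maybe_wrap_dim(-1, 2 + 1) == 2, so
// dim_physical == 1 + 2 == 3 and the new size-1 dim lands after the last
// logical dim of the physical [B0, d0, d1] tensor.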
246 
247 Tensor& fill_inplace_scalar_batching_rule(Tensor& self, const Scalar& value) {
248   auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self);
249   self_physical.tensor().fill_(value);
250   return self;
251 }
252 
253 Tensor& fill_inplace_tensor_batching_rule(Tensor& self, const Tensor& value) {
254   auto value_batched = isBatchedTensor(value);
255 
256   if (value_batched) {
257     auto physical_args =
258       BroadcastingVmapTransform::logicalToPhysical({self, value});
259     physical_args[0].tensor().copy_(physical_args[1].tensor());
260   } else {
261     auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self);
262     self_physical.tensor().fill_(value);
263   }
264   return self;
265 }
266 
267 Tensor& zero_inplace_batching_rule(Tensor &self) {
268   auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self);
269   self_physical.tensor().zero_();
270   return self;
271 }
272 
273 Tensor squeeze_batching_rule(const Tensor& self) {
274   auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self);
275   auto physical_sizes = self_physical.tensor().sizes();
276 
277   // Don't squeeze the batch dims!
278   VmapDimVector squeezed_sizes;
279   int64_t num_batch_dims = self_physical.numBatchDims();
280   squeezed_sizes.insert(
281       squeezed_sizes.end(),
282       physical_sizes.begin(),
283       physical_sizes.begin() + num_batch_dims);
284   for (auto it = physical_sizes.begin() + num_batch_dims; it != physical_sizes.end(); ++it) {
285     if (*it != 1) {
286       squeezed_sizes.push_back(*it);
287     }
288   }
289 
290   auto result = self_physical.tensor().view(squeezed_sizes);
291   return self_physical.getPhysicalToLogicalMap().apply(result);
292 }
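// Worked example: with one batch dim and physical sizes [B0, 1, 3, 1], the
// batch prefix [B0] is copied unconditionally and only the size-1 example
// dims are dropped, so squeezed_sizes == [B0, 3].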
293 
294 Tensor squeeze_dim_batching_rule(const Tensor& self, int64_t dim) {
295   auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self);
296   auto dim_physical = self_physical.getPhysicalDim(dim);
297   auto result = self_physical.tensor().squeeze(dim_physical);
298   return self_physical.getPhysicalToLogicalMap().apply(result);
299 }
300 
301 Tensor squeeze_dims_batching_rule(const Tensor& self, IntArrayRef dims) {
302   auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self);
303   auto dims_physical = self_physical.getPhysicalDims(dims);
304   auto result = self_physical.tensor().squeeze(dims_physical);
305   return self_physical.getPhysicalToLogicalMap().apply(result);
306 }
307 
308 Tensor trace_batching_rule(const Tensor& self) {
309   auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self);
310   // Batched Diagonal View
311   auto self_diag = at::diagonal(self_physical.tensor(), /*offset*/0, /*dim1*/-2, /*dim2*/-1);
312   auto result = at::sum(self_diag, -1);
313   return self_physical.getPhysicalToLogicalMap().apply(result);
314 }
315 
316 Tensor trace_backward_batching_rule(const Tensor& grad, IntArrayRef input_sizes) {
317   auto grad_physical = MultiBatchVmapTransform::logicalToPhysical(grad);
318   auto grad_input = at::zeros(grad_physical.getPhysicalShape(input_sizes), grad.options());
319   // Batched Diagonal View
320   auto grad_input_diag = at::diagonal(grad_input, /*offset*/0, /*dim1*/-2, /*dim2*/-1);
321   // Append a dimension of size one to the grad output
322   auto grad_physical_tensor = grad_physical.tensor().unsqueeze(-1);
323   grad_input_diag.copy_(grad_physical_tensor);
324   return grad_physical.getPhysicalToLogicalMap().apply(grad_input);
325 }
326 
327 Tensor transpose_int_batching_rule(const Tensor& self, int64_t dim0, int64_t dim1) {
328   // PyTorch has a special case where scalar_tensor.transpose(dim0, dim1) works
329   // for dim0, dim1 in {0, -1} and returns the scalar tensor. If the following happens:
330   // >>> x = torch.randn(B0)  # the per-examples are all scalars
331   // >>> vmap(lambda x: x.transpose(0, -1), x)
332   // then we replicate this behavior.
333   if (/*logical*/self.dim() == 0 && is_allowed_dim_on_scalar_tensor(dim0) &&
334       is_allowed_dim_on_scalar_tensor(dim1)) {
335     return self;
336   }
337   auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self);
338   auto dim0_physical = self_physical.getPhysicalDim(dim0);
339   auto dim1_physical = self_physical.getPhysicalDim(dim1);
340   auto result = self_physical.tensor().transpose(dim0_physical, dim1_physical);
341   return self_physical.getPhysicalToLogicalMap().apply(result);
342 }
343 
344 Tensor permute_batching_rule(const Tensor& self, IntArrayRef dims) {
345   auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self);
346   auto dims_physical = self_physical.getPhysicalDims(dims);
347 
348   VmapDimVector all_dims_physical;
349   all_dims_physical.reserve(self_physical.tensor().dim());
350   for (const auto bdim : c10::irange(self_physical.numBatchDims())) {
351     all_dims_physical.push_back(bdim);
352   }
353   all_dims_physical.insert(
354       all_dims_physical.end(),
355       dims_physical.begin(),
356       dims_physical.end());
357   auto result = self_physical.tensor().permute(all_dims_physical);
358   return self_physical.getPhysicalToLogicalMap().apply(result);
359 }
360 
361 Tensor select_batching_rule(const Tensor& self, int64_t dim, int64_t index) {
362   auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self);
363   auto dim_physical = self_physical.getPhysicalDim(dim);
364   auto result = self_physical.tensor().select(dim_physical, index);
365   return self_physical.getPhysicalToLogicalMap().apply(result);
366 }
367 
368 static int64_t getGradInputPhysicalDim(int64_t dim, IntArrayRef input_sizes, int64_t num_batch_dims) {
369   return maybe_wrap_dim(dim, input_sizes.size()) + num_batch_dims;
370 }
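// Worked example: with input_sizes of length 3, dim == -1, and two batch dims,
// maybe_wrap_dim(-1, 3) == 2, so the returned physical dim is 2 + 2 == 4.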
371 
372 Tensor select_backward_batching_rule(const Tensor& grad, IntArrayRef input_sizes, int64_t dim, int64_t index) {
373   auto grad_physical = MultiBatchVmapTransform::logicalToPhysical(grad);
374   auto grad_input = at::zeros(grad_physical.getPhysicalShape(input_sizes), grad.options());
375   auto physical_dim = getGradInputPhysicalDim(dim, input_sizes, grad_physical.numBatchDims());
376   grad_input.select(physical_dim, index).copy_(grad_physical.tensor());
377   return grad_physical.getPhysicalToLogicalMap().apply(grad_input);
378 }
379 
380 Tensor slice_batching_rule(
381     const Tensor& self,
382     int64_t dim,
383     std::optional<int64_t> start,
384     std::optional<int64_t> end,
385     int64_t step) {
386   auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self);
387   auto dim_physical = self_physical.getPhysicalDim(dim);
388   auto result = self_physical.tensor().slice(dim_physical, start, end, step);
389   return self_physical.getPhysicalToLogicalMap().apply(result);
390 }
391 
392 Tensor slice_backward_batching_rule(const Tensor& grad, IntArrayRef input_sizes, int64_t dim, int64_t start, int64_t end, int64_t step) {
393   auto grad_physical = MultiBatchVmapTransform::logicalToPhysical(grad);
394   auto grad_input = at::zeros(grad_physical.getPhysicalShape(input_sizes), grad.options());
395   auto physical_dim = getGradInputPhysicalDim(dim, input_sizes, grad_physical.numBatchDims());
396   grad_input.slice(physical_dim, start, end, step).copy_(grad_physical.tensor());
397   return grad_physical.getPhysicalToLogicalMap().apply(grad_input);
398 }
399 
400 Tensor diagonal_batching_rule(const Tensor& self, int64_t offset, int64_t dim1, int64_t dim2) {
401   auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self);
402   auto dim1_physical = self_physical.getPhysicalDim(dim1);
403   auto dim2_physical = self_physical.getPhysicalDim(dim2);
404   auto result = at::diagonal(self_physical.tensor(), offset, dim1_physical, dim2_physical);
405   return self_physical.getPhysicalToLogicalMap().apply(result);
406 }
407 
408 Tensor diagonal_backward_batching_rule(const Tensor& grad, IntArrayRef input_sizes, int64_t offset, int64_t dim1, int64_t dim2) {
409   auto grad_physical = MultiBatchVmapTransform::logicalToPhysical(grad);
410   auto grad_input = at::zeros(grad_physical.getPhysicalShape(input_sizes), grad.options());
411   auto dim1_physical = getGradInputPhysicalDim(dim1, input_sizes, grad_physical.numBatchDims());
412   auto dim2_physical = getGradInputPhysicalDim(dim2, input_sizes, grad_physical.numBatchDims());
413   grad_input.diagonal(offset, dim1_physical, dim2_physical).copy_(grad_physical.tensor());
414   return grad_physical.getPhysicalToLogicalMap().apply(grad_input);
415 }
416 
417 Tensor movedim_batching_rule(const Tensor& self, IntArrayRef source, IntArrayRef destination) {
418   auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self);
419   auto source_physical = self_physical.getPhysicalDims(source);
420   auto destination_physical = self_physical.getPhysicalDims(destination);
421   auto result = at::movedim(self_physical.tensor(), source_physical, destination_physical);
422   return self_physical.getPhysicalToLogicalMap().apply(result);
423 }
424 
425 Tensor reshape_batching_rule(const Tensor& self, IntArrayRef shape) {
426   auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self);
427   auto shape_physical = self_physical.getPhysicalShape(shape);
428   auto result = self_physical.tensor().reshape(shape_physical);
429   return self_physical.getPhysicalToLogicalMap().apply(result);
430 }
431 
432 std::vector<Tensor> split_batching_rule(const Tensor& self, int64_t split_size, int64_t dim) {
433   auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self);
434   auto dim_physical = self_physical.getPhysicalDim(dim);
435   auto result = at::split(self_physical.tensor(), split_size, dim_physical);
436   self_physical.getPhysicalToLogicalMap().applyInplace(result);
437   return result;
438 }
439 
440 std::vector<Tensor> split_with_sizes_batching_rule(const Tensor& self, IntArrayRef split_sizes, int64_t dim) {
441   auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self);
442   auto dim_physical = self_physical.getPhysicalDim(dim);
443   auto result = at::split_with_sizes(self_physical.tensor(), split_sizes, dim_physical);
444   self_physical.getPhysicalToLogicalMap().applyInplace(result);
445   return result;
446 }
447 
448 std::vector<Tensor> unbind_batching_rule(const Tensor& self, int64_t dim) {
449   auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self);
450   auto dim_physical = self_physical.getPhysicalDim(dim);
451   auto result = at::unbind(self_physical.tensor(), dim_physical);
452   self_physical.getPhysicalToLogicalMap().applyInplace(result);
453   return result;
454 }
455 
456 Tensor unfold_batching_rule(const Tensor& self, int64_t dim, int64_t size, int64_t step) {
457   auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self);
458   auto dim_physical = self_physical.getPhysicalDim(dim);
459   auto result = self_physical.tensor().unfold(dim_physical, size, step);
460   return self_physical.getPhysicalToLogicalMap().apply(result);
461 }
462 
463 Tensor contiguous_batching_rule(const Tensor& self, MemoryFormat memory_format) {
464   TORCH_CHECK(memory_format == MemoryFormat::Contiguous,
465       "NYI: Tensor.contiguous(...) inside of vmap for memory_format other ",
466       "than torch.contiguous_format");
467   auto physical_view = MultiBatchVmapTransform::logicalToPhysical(self);
468   auto result = physical_view.tensor().contiguous(memory_format);
469   return physical_view.getPhysicalToLogicalMap().apply(result);
470 }
471 
472 Tensor view_batching_rule(const Tensor& self, IntArrayRef size) {
473   auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self);
474   auto size_physical = self_physical.getPhysicalShape(size);
475   auto result = self_physical.tensor().view(size_physical);
476   return self_physical.getPhysicalToLogicalMap().apply(result);
477 }
478 
479 Tensor view_as_complex_batching_rule(const Tensor& self) {
480   // guard against the user passing in a batch of scalar tensors with batch
481   // size equal to 2.
482   TORCH_CHECK(!self.sizes().empty(), "Input tensor must have one or more dimensions");
483   auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self);
484   auto result = at::view_as_complex(self_physical.tensor());
485   return self_physical.getPhysicalToLogicalMap().apply(result);
486 }
487 
488 // Checks that the smallest batch stride is greater than the largest example
489 // stride. This is something we can support but we choose not to because it's
490 // potentially error prone.
491 static void checkBatchDimsAtFrontInLayout(IntArrayRef physical_strides, int64_t num_batch_dims) {
492   auto smallest_batch_stride = std::min_element(
493       physical_strides.begin(), physical_strides.begin() + num_batch_dims);
494   auto largest_example_stride = std::max_element(
495       physical_strides.begin() + num_batch_dims, physical_strides.end());
496   if (largest_example_stride == physical_strides.end()) {
497     // No example dimensions
498     return;
499   }
500   TORCH_CHECK(*smallest_batch_stride >= *largest_example_stride,
501     "vmap: Calling Tensor.as_strided is not supported unless the batch dims being ",
502     "vmapped over are at the front of the tensor (in memory layout). When they are ",
503     "not at the front of the tensor this operation can be error prone so we "
504     "actively discourage it; please file us a bug report and/or try to ",
505     "express the as_strided operation in terms of PyTorch view operations");
506 }
507 
508 // given (sizes, strides, storage_offset) returns the maximum location that
509 // can be indexed (or nullopt if such a location doesn't exist, e.g., tensors
510 // with zero-size dims).
511 static std::optional<int64_t> maximum_indexable_location(
512     IntArrayRef sizes, IntArrayRef strides, int64_t storage_offset) {
513   auto result = native::storage_size_for(sizes, strides);
514   if (result == 0) {
515     return std::nullopt;
516   }
517   return result + storage_offset;
518 }
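// Worked example of the arithmetic above: for sizes = [2, 3], strides = [3, 1],
// storage_offset = 7, storage_size_for gives 1 + (2-1)*3 + (3-1)*1 == 6, so the
// returned location is 7 + 6 == 13; any zero-size dim instead yields nullopt.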
519 
520 // Let x be the "first slice" of physical_tensor.
521 // This checks that the range of possible memory locations accessible by
522 // x.as_strided(sizes, strides, maybe_storage_offset)
523 // are within the bounds of possible memory locations accessible by x.
524 static void checkBasicAsStridedValidForSlice(
525     const Tensor& physical_tensor,
526     int64_t num_batch_dims,
527     IntArrayRef sizes,
528     IntArrayRef strides,
529     std::optional<int64_t> maybe_storage_offset) {
530   auto slice_sizes = physical_tensor.sizes().slice(num_batch_dims);
531   auto slice_strides = physical_tensor.strides().slice(num_batch_dims);
532   auto base_offset = physical_tensor.storage_offset();
533 
534   auto storage_offset = maybe_storage_offset.value_or(base_offset);
535 
536   auto max_as_strided_loc = maximum_indexable_location(sizes, strides, storage_offset);
537   auto max_slice_loc = maximum_indexable_location(slice_sizes, slice_strides, base_offset);
538 
539   if (!max_as_strided_loc.has_value()) {
540     return;
541   }
542   if (!max_slice_loc.has_value()) {
543     TORCH_CHECK(false,
544         "result = tensor.as_strided(", sizes, ", ", strides, ", ", storage_offset, ") ",
545         "can access memory outside of `tensor`. `tensor` has no storage but the ",
546         "passed-in (size, stride, storage_offset) imply a result with some storage. ",
547         "This is not supported inside of vmap, please try to rewrite the ",
548         "`as_strided` call as a sequence of PyTorch view operations");
549   }
550 
551   TORCH_CHECK(
552       *max_as_strided_loc <= *max_slice_loc && base_offset <= storage_offset,
553       "result = tensor.as_strided(", sizes, ", ", strides, ", ", storage_offset, ") ",
554       "can access memory outside of `tensor`. `result` can access some ",
555       "memory in range [", storage_offset, ", ", *max_as_strided_loc, "], but ",
556       "`tensor` can only access some memory in range [", base_offset, ", ",
557       *max_slice_loc, "]. This is not supported inside of vmap, please try to ",
558       "rewrite the `as_strided` call as a sequence of PyTorch view operations");
559 }
560 
561 Tensor _reshape_alias_batching_rule(const Tensor& self, IntArrayRef sizes, IntArrayRef strides [[maybe_unused]]) {
562   return reshape_batching_rule(self, sizes);
563 }
564 
565 Tensor _new_zeros_with_same_feature_meta_batching_rule(
566     const Tensor& self,
567     const Tensor& other,
568     int64_t unused_num_batch_dims) {
569   TORCH_CHECK(isBatchedTensor(self) && !isBatchedTensor(other),
570     "Only the 'batched grad' use case is supported in PyTorch core.");
571 
572   TORCH_INTERNAL_ASSERT(unused_num_batch_dims == 0,
573     "num_batch_dims should not be explicitly passed in because it will be overridden");
574   auto self_physical_view = at::MultiBatchVmapTransform::logicalToPhysical(self);
575   const auto& self_physical_tensor = self_physical_view.tensor();
576   int64_t num_batch_dims = self_physical_view.numBatchDims();
577   checkBatchDimsAtFrontInLayout(self_physical_tensor.strides(), num_batch_dims);
578   auto result = at::_new_zeros_with_same_feature_meta(self_physical_tensor, other, num_batch_dims);
579   return self_physical_view.getPhysicalToLogicalMap().apply(result);
580 }
581 
582 bool _has_same_storage_numel_batching_rule(const Tensor& self, const Tensor& other) {
583   TORCH_CHECK(isBatchedTensor(self) && !isBatchedTensor(other),
584     "Only the 'batched grad' use case is supported in PyTorch core.");
585   // The _has_same_storage_numel check is skipped if the tangent is a batched
586   // tensor because using as_strided to access storage locations not indexable
587   // by the input tensor is not supported in vmap
588   return true;
589 }
590 
591 // What are the semantics of as_strided inside of vmap?
592 // y = vmap(lambda x: x.as_strided(sizes, strides, offset))(xs)
593 // This returns a view on `x`, `y`, such that each y[i] has:
594 // - sizes: `sizes`
595 // - strides: `strides`
596 // - storage_offset: offset + i * x.stride(batch_dim)
597 //
598 // In other words, it is as if we had treated each x[i] as having storage
599 // offset equal to xs.offset() and called as_strided(sizes, strides, offset).
600 // (that is equivalent to x[i].as_strided(
601 //    sizes, strides, offset + x[i].storage_offset() - xs.offset()) for all i)
602 //
603 // Note that this *may* be different from actually running as_strided
604 // in a for-loop. This is due to how as_strided takes in `offset` to be
605 // an *absolute* offset. As an example, consider:
606 // >>> x = torch.tensor([0., 1., 2., 3., 4.]).as_strided([4], [1], 1)
607 // >>> z = [x[i].as_strided([1], [1], 1) for i in range(4)]
608 // Each z[i] is actually the same view on x (z[i] == torch.tensor([1.]))!
609 // However, we consider the above for-loop comprehension to be a user error:
610 // a user should have written the following if they wanted to use as_strided
611 // in a per-sample way:
612 // >>> z = [x[i].as_strided([1], [1], 1 + x[i].storage_offset() - 1) for i in range(4)]
613 Tensor as_strided_batching_rule(
614     const Tensor& tensor,
615     IntArrayRef sizes,
616     IntArrayRef strides,
617     std::optional<int64_t> storage_offset) {
618   auto physical_view = at::MultiBatchVmapTransform::logicalToPhysical(tensor);
619   auto num_batch_dims = physical_view.numBatchDims();
620   auto physical_sizes = physical_view.getPhysicalShape(sizes);
621   const auto& physical_tensor = physical_view.tensor();
622 
623   // We can't rely on the physical as_strided call to do this for us because
624   // we do some sanity checks on the size/strides before calling into as_strided.
625   TORCH_CHECK(sizes.size() == strides.size(),
626       "Tensor.as_strided(size, stride, ...): size and stride must have the ",
627       "same length! Got size ", sizes, " and stride ", strides);
628 
629   // Sanity checks:
630   // 1. All batch dims are at the front in memory layout (not necessary for
631   // correctness, but we are worried the user might be doing crazy things)
632   // 2. as_strided(sizes, strides, storage_offset + tensor[i].offset() - tensor.offset())
633   // is valid for a slice of the input tensor.
634   // See Note: [When will the as_strided batching rule fail?] for details.
635   checkBatchDimsAtFrontInLayout(physical_tensor.strides(), num_batch_dims);
636   checkBasicAsStridedValidForSlice(
637       physical_tensor, num_batch_dims, sizes, strides, storage_offset);
638 
639   // physical_strides = physical tensor's batch strides + (logical) strides
640   auto batch_strides = physical_tensor.strides().slice(0, num_batch_dims);
641   at::VmapDimVector physical_strides;
642   physical_strides.reserve(num_batch_dims + strides.size());
643   physical_strides.insert(
644       physical_strides.end(), batch_strides.begin(), batch_strides.end());
645   physical_strides.insert(
646       physical_strides.end(), strides.begin(), strides.end());
647 
648   // If zi = xs[i].as_strided(sizes, strides, offset + xs[i].offset() - xs.offset())
649   // is valid for all i, then it turns out that
650   // xs.as_strided(physical_sizes, physical_strides, offset) always succeeds
651   // and creates a tensor y such that each y[i] references the same memory
652   // locations as zi. See NOTE: [When will the as_strided batching rule fail?]
653   auto result = physical_view.tensor().as_strided(
654       physical_sizes, physical_strides, storage_offset);
655   return physical_view.getPhysicalToLogicalMap().apply(result);
656 }
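// Worked example of the stride bookkeeping above: with xs of physical shape
// [B0, 5] and strides [5, 1] (one batch dim), the logical call
// x.as_strided([3], [1], 1) turns into the single physical call
// xs.as_strided([B0, 3], [5, 1], 1), so each y[i] views storage locations
// 1 + 5*i through 3 + 5*i.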
657 
658 // NOTE: [When will the as_strided batching rule fail?]
659 // If zi = xs[i].as_strided(sizes, strides, offset + xs[i].offset() - xs.offset())
660 // is valid for all i, then it turns out that
661 // xs.as_strided(physical_sizes, physical_strides, offset) always succeeds and
662 // creates a tensor y such that each y[i] refers to the same memory as zi.
663 //
664 // Let's say we have xs[i].as_strided(sizes, strides, offset + xs[i].offset() - xs.offset()).
665 // Furthermore, let's say that as a part of being "valid" this as_strided call
666 // does not return a result that can index memory not indexable by xs[i].
667 //
668 // WLOG, assume that there's only one batch dim and it is at the front of the
669 // `xs` tensor. Let B be the batch size and S be the stride of the batch dim.
670 // - If the batch dim isn't at the front of the tensor, then we can just move it
671 // to the front with movedim/permute. This is always valid because it just swaps
672 // some strides around.
673 // - This proof also works for tensors with multiple batch dims. We just have to
674 // do a little accounting:
675 //   - instead of [B], we'd have [B0, B1, ..., Bk].
676 //   - instead of [S], we'd have [S0, S1, ..., Sk].
677 //   - instead of i, we'd have a list of indices [I0, I1, ..., Ik]
678 //   - instead of S * I, we'd have \sum_{i=0}^k S_i * I_i
679 //
680 // [Equation 1]
681 // xs[i].as_strided(sizes, strides, offset + xs[i].offset() - xs.offset()) has:
682 // - sizes: sizes
683 // - strides: strides
684 // - offset: offset + S * i
685 //
686 // x.as_strided itself checks that:
687 // - (sizes, strides, offset) are in bounds for `x`'s storage.
688 // - strides are positive
689 // - offset is positive
690 //
691 // Claim 1: if xs[i].as_strided(sizes, strides, offset + xs[i].offset() - xs.offset())
692 // is valid, then
693 // ([B] + sizes, [S] + strides, offset + xs.offset()) are in bounds for `xs`'s storage.
694 //
695 // If we have the claim, then xs.as_strided([B] + sizes, [S] + strides, offset)
696 // won't error out. So all we need to check is that the memory locations are
697 // what we expected. See [Hand-wavy proof of Claim 1] for proof (it's not very important)
698 //
699 // xs.as_strided(physical_sizes, physical_strides, offset) is equivalent to
700 // xs.as_strided([B] + sizes, [S] + strides, offset)
701 //
702 // xs.as_strided([B] + sizes, [S] + strides, offset) has:
703 // - sizes: [B] + sizes
704 // - strides: [S] + strides
705 // - offset: offset
706 //
707 // xs.as_strided([B] + sizes, [S] + strides, offset)[i] has:
708 // - sizes: sizes
709 // - strides: strides
710 // - offset: offset + S * i
711 // These memory locations are exactly the same as what we got for [Equation 1],
712 // so the xs.as_strided([B] + sizes, [S] + strides, offset) is valid.
713 //
714 // [Hand-wavy proof of Claim 1]
715 // Part of our definition of being valid is that xs[i].as_strided(...)
716 // must return a tensor that only uses memory indexable by xs[i].
717 // This means that (sizes, strides, offset + xs[i].offset() - xs.offset()) satisfies:
718 //    offset + xs[i].offset() - xs.offset() + 1 + \sum_j (sizes[j] - 1) * strides[j]
719 //    <= xs[i].offset() + 1 + \sum_j (xs[i].size(j) - 1) * xs[i].stride(j)
720 // (the largest-index memory location of xs[i].as_strided(...) must be \leq
721 // the largest-index memory location of xs[i])
722 //
723 // Fiddling that inequality gives us:
724 //    offset - xs.offset() + 1 + \sum_j (sizes[j] - 1) * strides[j]
725 //    <= 1 + \sum_j (xs[i].size(j) - 1) * xs[i].stride(j)
726 //
727 //    offset - xs.offset() + 1 + (B-1)*S + \sum_j (sizes[j] - 1) * strides[j]
728 //    <= 1 + (B-1)*S + \sum_j (xs[i].size(j) - 1) * xs[i].stride(j)
729 //
730 //    offset - xs.offset() + 1 + (B-1)*S + \sum_j (sizes[j] - 1) * strides[j]
731 //    <= 1 + \sum_j (xs.size(j) - 1) * xs.stride(j)
732 //
733 //    offset + 1 + (B-1)*S + \sum_j (sizes[j] - 1) * strides[j]
734 //    <= xs.offset() + 1 + \sum_j (xs.size(j) - 1) * xs.stride(j)
735 // (the largest-index memory location of xs.as_strided(size, stride, offset)
736 // is \leq than the largest-index memory location of xs)
737 // Under the assumptions we've made, the lower bound (lowest indexed memory)
738 // is trivially within the storage.
739 //
740 // Therefore ([B] + sizes, [S] + strides, offset) are in bounds for
741 // `xs`'s storage.
742 
743 template <typename F, F Func, typename... ExtraArgs>
744 Tensor unwrap_and_call(const Tensor& input, ExtraArgs... args) {
745   auto* input_batched = unsafeGetBatchedImpl(input);
746   auto output_physical = Func(input_batched->value(), args...);
747   auto old_bdims = input_batched->bdims();
748   return makeBatched(output_physical, BatchDims(old_bdims.begin(), old_bdims.end()));
749 }
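// A hypothetical instantiation, for illustration only (not one of the
// registrations in this file): a unary op with no extra arguments could be
// routed through this helper as
//   unwrap_and_call<Tensor (*)(const Tensor&), at::cos>
// which unwraps the BatchedTensor, calls at::cos on the underlying physical
// tensor, and re-wraps the result with the original batch dims.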
750 
751 template <typename F, F Func, typename... ExtraArgs>
752 Tensor unwrap_and_call_method(const Tensor& input, ExtraArgs... extra_args) {
753   auto* input_batched = unsafeGetBatchedImpl(input);
754   auto output_physical = (input_batched->value().*Func)(extra_args...);
755   auto old_bdims = input_batched->bdims();
756   return makeBatched(output_physical, BatchDims(old_bdims.begin(), old_bdims.end()));
757 }
758 
759 Tensor pow_scalar_Tensor_batching_rule(const Scalar& other, const Tensor& self) {
760   auto* self_batched = unsafeGetBatchedImpl(self);
761   auto output_physical = at::pow(other, self_batched->value());
762   auto old_bdims = self_batched->bdims();
763   return makeBatched(output_physical, BatchDims(old_bdims.begin(), old_bdims.end()));
764 }
765 
766 Tensor clone_batching_rule(const Tensor& self, std::optional<MemoryFormat> memory_format) {
767   // Memory format support is a little tricky because vmap is allowed to move
768   // around batch dimensions and some memory formats are rank-dependent.
769   // Another weird case is:
770   // - a tensor with MemoryFormat::ChannelsLast MUST have 4 dimensions. Do we
771   //   allow the user to clone a Tensor with 3 logical dimensions and 1 batch
772   //   dim into a ChannelsLast Tensor? What about a Tensor with 3 logical dims
773   //   and N>1 batch dims?
774   TORCH_CHECK(!memory_format.has_value() || memory_format == MemoryFormat::Preserve
775       || memory_format == MemoryFormat::Contiguous,
776       "NYI: Tensor.clone(memory_format) inside vmap is only supported with ",
777       "memory_format torch.preserve_format or torch.contiguous_format (got ",
778       *memory_format, ")");
779 
780   if (memory_format == MemoryFormat::Contiguous) {
781     // There is an ambiguity here when the batch dims are not at the front of
782     // the tensor.
783     // >>> x = torch.randn(3, B0, 5)
784     // >>> y = vmap(lambda x: x.clone(memory_format=torch.contiguous_format), in_dims=1, out_dims=0)(x)
785     // >>> y[0].is_contiguous()
786     // ???
787     // Should we make the whole tensor contiguous, or should we
788     // make the non-batch dims contiguous? We've chosen the latter because
789     // philosophically vmap hides the batch dims and operates on a per-sample level.
790     auto physical_view = MultiBatchVmapTransform::logicalToPhysical(self);
791     auto output_physical = at::clone(physical_view.tensor(), memory_format);
792     return physical_view.getPhysicalToLogicalMap().apply(output_physical);
793   }
794 
795   TORCH_INTERNAL_ASSERT(!memory_format.has_value() || memory_format == MemoryFormat::Preserve);
796   auto* self_batched = unsafeGetBatchedImpl(self);
797   auto output_physical = at::clone(self_batched->value(), memory_format);
798   auto old_bdims = self_batched->bdims();
799   return makeBatched(output_physical, BatchDims(old_bdims.begin(), old_bdims.end()));
800 }
801 
802 // Note [Batching rules for matmul-like operators]
803 // at::matmul doesn't "de-expand" arguments to get better performance (maybe
804 // it should). In the batching rules for matmul-like operators (dot, mv, mm),
805 // we should be careful not to expand any unnecessary dimensions. e.g., if
806 // only one of the two arguments is a BatchedTensor, then we should try
807 // not to expand batch dimensions onto the other arg.
808 Tensor mv_batching_rule(const Tensor& self, const Tensor& other) {
809   auto self_batched = isBatchedTensor(self);
810   auto other_batched = isBatchedTensor(other);
811 
812   // A shape checking API would be nice...
813   TORCH_CHECK(self.dim() == 2 && other.dim() == 1,
814       "mv(self, other): Shape mismatch: expected matrix "
815       "(got `self` of size ", self.sizes(), ") ",
816       "and vector (got `other` of size ", other.sizes(), ")");
817 
818   // See Note [Batching rules for matmul-like operators] for why we have cases
819   if (self_batched && !other_batched) {
820     auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self);
821     auto result = at::matmul(self_physical.tensor(), other);
822     return self_physical.getPhysicalToLogicalMap().apply(result);
823   }
824   if (!self_batched && other_batched) {
825     // self_physical: [L, K], other_physical: [..., K]
826     // We view the tensors as [L, K], [..., K, 1], perform matmul to get
827     // a tensor of size [..., L, 1], and squeeze the last dim.
828     auto other_physical = MultiBatchVmapTransform::logicalToPhysical(other);
829     auto result = at::matmul(self, other_physical.tensor().unsqueeze(-1));
830     return other_physical.getPhysicalToLogicalMap().apply(result.squeeze(-1));
831   }
832   if (self_batched && other_batched) {
833     // self_physical: [..., L, K], other_physical: [..., K]
834     // We view the tensors as [..., L, K], [..., K, 1], perform matmul to get
835     // a tensor of size [..., L, 1], and squeeze the last dim.
836     auto physical_args = MultiBatchVmapTransform::logicalToPhysical({self, other});
837     auto result = at::matmul(
838         physical_args[0].tensor(),
839         physical_args[1].tensor().unsqueeze(-1));
840     return physical_args[0].getPhysicalToLogicalMap().apply(result.squeeze(-1));
841   }
842   TORCH_INTERNAL_ASSERT(false, "either self or other must be a BatchedTensor");
843 }
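// Shape-level illustration of the case split above: if only `self` is batched,
// with physical shape [B0, L, K], and `other` is a plain [K] vector, the first
// branch calls at::matmul([B0, L, K], [K]) and gets [B0, L] directly, without
// ever expanding `other` to [B0, K].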
844 
845 Tensor _make_dual_batching_rule(
846   c10::DispatchKeySet ks,
847   const Tensor& primal,
848   const Tensor& tangent,
849   int64_t level
850 ) {
851   DispatchKeySet after_batched_keyset =
852       DispatchKeySet(DispatchKeySet::FULL_AFTER, c10::DispatchKey::Batched);
853   return at::redispatch::_make_dual(ks & after_batched_keyset, primal, tangent, level);
854 }
855 
856 Tensor dot_batching_rule(const Tensor& self, const Tensor& other) {
857   auto self_batched = isBatchedTensor(self);
858   auto other_batched = isBatchedTensor(other);
859 
860   TORCH_CHECK(/*logical*/self.dim() == 1 && /*logical*/other.dim() == 1,
861       "dot(self, other): Shape mismatch: vector "
862       "(got `self` of size ", self.sizes(), ") ",
863       "and vector (got `other` of size ", other.sizes(), ")");
864 
865   // See Note [Batching rules for matmul-like operators] for why we have cases
866   if (self_batched && !other_batched) {
867     // self_physical: [..., K], other_physical: [K]
868     // View the tensors as [..., 1, K] and [K], perform matmul, and squeeze the last dim.
869     auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self);
870     auto result = at::matmul(self_physical.tensor().unsqueeze(-2), other);
871     return self_physical.getPhysicalToLogicalMap().apply(result.squeeze(-1));
872   }
873   if (!self_batched && other_batched) {
874     // self_physical: [K], other_physical: [..., K]
875     // View the tensors as [K] and [..., K, 1], perform matmul, and squeeze the last dim.
876     auto other_physical = MultiBatchVmapTransform::logicalToPhysical(other);
877     auto result = at::matmul(self, other_physical.tensor().unsqueeze(-1));
878     return other_physical.getPhysicalToLogicalMap().apply(result.squeeze(-1));
879   }
880   if (self_batched && other_batched) {
881     // self_physical: [..., K], other_physical: [..., K]
882     // View the tensors as [..., 1, K] and [..., K, 1], perform matmul, and squeeze the last two dims.
883     auto physical_args = MultiBatchVmapTransform::logicalToPhysical({self, other});
884     auto result = at::matmul(
885         physical_args[0].tensor().unsqueeze(-2),
886         physical_args[1].tensor().unsqueeze(-1));
887     return physical_args[0].getPhysicalToLogicalMap().apply(result.squeeze(-1).squeeze(-1));
888   }
889   TORCH_INTERNAL_ASSERT(false, "either self or other must be a BatchedTensor");
890 }
891 
892 Tensor bmm_batching_rule(const Tensor& self, const Tensor& other) {
893   TORCH_CHECK(/*logical*/self.dim() == 3 && /*logical*/other.dim() == 3,
894       "bmm(self, other): Shape mismatch: expected 3D `self` "
895       "(got `self` of size ", self.sizes(), ") ",
896       "and 3D `other` (got `other` of size ", other.sizes(), ")");
897 
898   auto physical_args = BroadcastingVmapTransform::logicalToPhysical({self, other});
899   auto result = at::matmul(physical_args[0].tensor(), physical_args[1].tensor());
900   return physical_args[0].getPhysicalToLogicalMap().apply(result);
901 }
902 
903 Tensor mm_batching_rule(const Tensor& self, const Tensor& other) {
904   auto self_batched = isBatchedTensor(self);
905   auto other_batched = isBatchedTensor(other);
906 
907   TORCH_CHECK(/*logical*/self.dim() == 2 && /*logical*/other.dim() == 2,
908       "mm(self, other): Shape mismatch: expected matrix "
909       "(got `self` of size ", self.sizes(), ") ",
910       "and matrix (got `other` of size ", other.sizes(), ")");
911 
912   // See Note [Batching rules for matmul-like operators] for why we have cases
913   if (self_batched && !other_batched) {
914     auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self);
915     auto result = at::matmul(self_physical.tensor(), other);
916     return self_physical.getPhysicalToLogicalMap().apply(result);
917   }
918   if (!self_batched && other_batched) {
919     auto other_physical = MultiBatchVmapTransform::logicalToPhysical(other);
920     auto result = at::matmul(self, other_physical.tensor());
921     return other_physical.getPhysicalToLogicalMap().apply(result);
922   }
923   if (self_batched && other_batched) {
924     auto physical_args = MultiBatchVmapTransform::logicalToPhysical({self, other});
925     auto result = at::matmul(physical_args[0].tensor(), physical_args[1].tensor());
926     return physical_args[0].getPhysicalToLogicalMap().apply(result);
927   }
928   TORCH_INTERNAL_ASSERT(false, "either self or other must be a BatchedTensor");
929 }
930 
931 Tensor cat_batching_rule(const ITensorListRef& tensors, int64_t dim) {
932   auto physical_views = MultiBatchVmapTransform::logicalToPhysical(tensors);
933   auto physical_tensors = fmap(
934       physical_views, [](const VmapPhysicalView& view) -> Tensor { return view.tensor(); });
935   TORCH_INTERNAL_ASSERT(
936       !tensors.empty(), "The dispatcher should not have dispatched here otherwise.");
937   auto result = at::cat(physical_tensors, physical_views[0].getPhysicalDim(dim));
938   return physical_views[0].getPhysicalToLogicalMap().apply(result);
939 }
940 
941 Tensor stack_batching_rule(TensorList tensors, int64_t dim) {
942   auto physical_views = MultiBatchVmapTransform::logicalToPhysical(tensors);
943   auto physical_tensors = fmap(
944       physical_views, [](const VmapPhysicalView& view) -> Tensor { return view.tensor(); });
945   TORCH_INTERNAL_ASSERT(
946       !tensors.empty(), "The dispatcher should not have dispatched here otherwise.");
947   // NB: stack wraps the dimensionality to (logical dim + 1), so we have to
948   // manually handle that here.
949   auto dim_physical =
950       physical_views[0].numBatchDims() + maybe_wrap_dim(dim, /*logical*/tensors[0].dim() + 1);
951   auto result = at::stack(physical_tensors, dim_physical);
952   return physical_views[0].getPhysicalToLogicalMap().apply(result);
953 }
954 
955 // I am quite sad that we need to register operators with exploded TensorOptions,
956 // even though the native:: implementations can use TensorOptions&.
957 // This also makes it hard to metaprogram: i.e., we can't use
958 // unwrap_and_call<..., at::to> because at::to takes TensorOptions& (!!)
959 Tensor to_dtype_layout_batching_rule(
960     const Tensor& self,
961     std::optional<ScalarType> dtype,
962     std::optional<Layout> layout,
963     std::optional<Device> device,
964     std::optional<bool> pin_memory,
965     bool non_blocking, bool copy,
966     std::optional<MemoryFormat> memory_format) {
967   auto options = TensorOptions()
968     .dtype(dtype)
969     .layout(layout)
970     .device(device)
971     .pinned_memory(pin_memory);
972   auto* input_batched = unsafeGetBatchedImpl(self);
973   auto output_physical = input_batched->value().to(options, non_blocking, copy, memory_format);
974   auto old_bdims = input_batched->bdims();
975   return makeBatched(output_physical, BatchDims(old_bdims.begin(), old_bdims.end()));
976 }
977 
978 Tensor new_zeros_batching_rule(
979     const Tensor& self,
980     IntArrayRef size,
981     std::optional<ScalarType> dtype,
982     std::optional<Layout> layout,
983     std::optional<Device> device,
984     std::optional<bool> pin_memory) {
985   auto physical_view = MultiBatchVmapTransform::logicalToPhysical(self);
986   auto physical_size = physical_view.getPhysicalShape(size);
987   auto options = TensorOptions()
988     .dtype(dtype)
989     .layout(layout)
990     .device(device)
991     .pinned_memory(pin_memory);
992   auto result = physical_view.tensor().new_zeros(physical_size, options);
993   return physical_view.getPhysicalToLogicalMap().apply(result);
994 }
995 
996 Tensor new_empty_batching_rule(
997     const Tensor& self,
998     IntArrayRef size,
999     std::optional<ScalarType> dtype,
1000     std::optional<Layout> layout,
1001     std::optional<Device> device,
1002     std::optional<bool> pin_memory) {
1003   auto physical_view = MultiBatchVmapTransform::logicalToPhysical(self);
1004   auto physical_size = physical_view.getPhysicalShape(size);
1005   auto result = physical_view.tensor().new_empty(physical_size, TensorOptions().dtype(dtype).layout(layout).device(device).pinned_memory(pin_memory));
1006   return physical_view.getPhysicalToLogicalMap().apply(result);
1007 }
1008 
new_empty_strided_batching_rule(const Tensor & self,IntArrayRef size,IntArrayRef stride,std::optional<ScalarType> dtype,std::optional<Layout> layout,std::optional<Device> device,std::optional<bool> pin_memory)1009 Tensor new_empty_strided_batching_rule(
1010     const Tensor& self,
1011     IntArrayRef size,
1012     IntArrayRef stride,
1013     std::optional<ScalarType> dtype,
1014     std::optional<Layout> layout,
1015     std::optional<Device> device,
1016     std::optional<bool> pin_memory) {
1017   auto physical_view = MultiBatchVmapTransform::logicalToPhysical(self);
1018   auto physical_size = physical_view.getPhysicalShape(size);
1019 
1020   // Let [B0, B1, B2] be the shape of the batch dims. We're going to create
1021   // the batch dimensions at the front of the tensor (in memory layout),
1022   // irrespective of whether or not they are actually at the front (in memory layout)
1023   // in the original `self` tensor. This is because when a user calls
1024   // `new_empty_strided` in general, the `strides` they provide are for a new
1025   // tensor and have no relation to the strides of the original tensor.
1026   //
1027   // So, the physical shape of the result should be ([B0, B1, B2] + size),
1028   // but what about the physical strides?
1029   //
1030   // We're actually free to pick whatever stride we want:
1031   // e.g., for size=[5, 3], stride=[0, 1], we could decide to
1032   // use
1033   // - physical size: [B0, B1, B2, 5, 3]
1034   // - physical stride: [9999*B1*B2, 9999*B2, 9999, 0, 1]
1035   //
1036   // Let's select some reasonable strides such that:
1037   // - The batch dims are "contiguous" with respect to each other
1038   // - if empty_strided(size, stride) would have created a contiguous Tensor,
1039   // then this new physical Tensor (with batch dims) is also contiguous
1040   //
1041   // Let S be the size of the storage if one were to construct a tensor
1042   // with `size` and `stride` via empty_strided(size, stride).
1043   // Then the physical sizes/strides should be:
1044   // - physical size: [B0, B1, B2, 5, 3]
1045   // - physical stride: [B1 * B2 * S, B2 * S, S, 0, 1]
1046   auto batch_shape = IntArrayRef(
1047       physical_view.tensor().sizes().begin(), physical_view.numBatchDims());
1048 
1049   // physical_strides = [B1 * B2 * S, B2 * S, S]
1050   auto physical_strides = at::detail::defaultStrides(batch_shape);
1051   TORCH_CHECK(size.size() == stride.size(),
1052         "new_empty_strided(sizes, strides): dimensionality of sizes (",
1053         size.size(), ") must match dimensionality of strides (",
1054         stride.size(), ")");
1055   auto storage_size = native::storage_size_for(size, stride);
1056   for (auto& physical_stride : physical_strides) {
1057     physical_stride *= storage_size;
1058   }
1059 
1060   // physical_strides = [B1 * B2 * S, B2 * S, S] + strides
1061   physical_strides.insert(physical_strides.end(), stride.begin(), stride.end());
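  // Worked example (illustrative, assuming storage_size_for returns the minimal
  // storage needed for `size`/`stride`): with batch_shape=[2, 3], size=[5, 3],
  // stride=[3, 1], we get S = 1 + 4*3 + 2*1 = 15, so
  // physical_strides = [3*15, 1*15] + [3, 1] = [45, 15, 3, 1], which are exactly
  // the contiguous strides for the physical size [2, 3, 5, 3].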

  auto result = physical_view.tensor().new_empty_strided(
      physical_size, physical_strides, dtype, layout, device, pin_memory);
  return physical_view.getPhysicalToLogicalMap().apply(result);
}

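// Instantiated below via the COMPARISON_POINTWISE macro, e.g. as
// comparison_pointwise_batching_rule<TensorTensorType, at::eq>.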
template <typename F, F Func>
Tensor comparison_pointwise_batching_rule(const Tensor& self, const Tensor& other) {
  auto physical_args = BroadcastingVmapTransform::logicalToPhysical({self, other});
  auto result = Func(physical_args[0].tensor(), physical_args[1].tensor());
  return physical_args[0].getPhysicalToLogicalMap().apply(result);
}
} // namespace

TORCH_LIBRARY_IMPL(_, Batched, m) {
  m.fallback(torch::CppFunction::makeFromBoxedFunction<&batchedTensorForLoopFallback>());
}

TORCH_LIBRARY_IMPL(aten, Batched, m) {
  // NB: Ideally we would like some operators, like size.int, to "fallthrough"
  // to the underlying implementation. However, because a BatchedTensor is a
  // Tensor wrapper, it only has one dispatch key (Batched) on it. The resolution
  // here is to just directly call the underlying implementation.
  m.impl("size.int", static_cast<int64_t (*)(const Tensor&, int64_t)>(native::size));
  m.impl("_add_batch_dim", native::_add_batch_dim);
  m.impl("_remove_batch_dim", native::_remove_batch_dim);
  m.impl("_make_dual", _make_dual_batching_rule);
  m.impl("_has_same_storage_numel", _has_same_storage_numel_batching_rule);
  m.impl("is_same_size", native::is_same_size);
  m.impl("_new_zeros_with_same_feature_meta", _new_zeros_with_same_feature_meta_batching_rule);

  m.impl("sum.dim_IntList", sum_batching_rule);
  m.impl("is_complex", native::is_complex);

  // inplace operations
  m.impl("fill_.Scalar", fill_inplace_scalar_batching_rule);
  m.impl("fill_.Tensor", fill_inplace_tensor_batching_rule);
  m.impl("zero_", zero_inplace_batching_rule);

  // view operations
  m.impl("as_strided", as_strided_batching_rule);
  m.impl("chunk", chunk_batching_rule);
  m.impl("tensor_split.sections", tensor_split_sections_batching_rule);
  m.impl("tensor_split.indices", tensor_split_indices_batching_rule);
  m.impl("diagonal", diagonal_batching_rule);
  m.impl("expand", expand_batching_rule);
  m.impl("expand_as", native::expand_as); // composite wrt autograd
  m.impl("movedim.intlist", movedim_batching_rule);
  m.impl("movedim.int", static_cast<Tensor(*)(const Tensor&,int64_t,int64_t)>(native::movedim)); // composite wrt autograd
  // There is another variant of narrow. However, we don't want to support
  // the other variant yet because it isn't documented...
  m.impl("narrow", native::narrow_symint); // composite wrt autograd
  m.impl("numpy_T", native::numpy_T);   // composite wrt autograd
  m.impl("matrix_H", native::matrix_H); // composite wrt autograd
  m.impl("mT", native::mT);             // composite wrt autograd
  m.impl("mH", native::mH);             // composite wrt autograd
  m.impl("permute", permute_batching_rule);
  m.impl("reshape", reshape_batching_rule);
  m.impl("_reshape_alias", _reshape_alias_batching_rule);
  m.impl("reshape_as", native::reshape_as); // composite wrt autograd
  m.impl("select.int", select_batching_rule);
  m.impl("slice.Tensor", slice_batching_rule);
  m.impl("split.Tensor", split_batching_rule);
  m.impl("split.sizes", split_with_sizes_batching_rule);
  m.impl("split_with_sizes", split_with_sizes_batching_rule);
  m.impl("squeeze", squeeze_batching_rule);
  m.impl("squeeze.dim", squeeze_dim_batching_rule);
  m.impl("squeeze.dims", squeeze_dims_batching_rule);
  m.impl("t", native::t); // composite wrt autograd
  m.impl("trace", trace_batching_rule);
  m.impl("transpose.int", transpose_int_batching_rule);
  m.impl("unbind.int", unbind_batching_rule);
  m.impl("unfold", unfold_batching_rule);
  m.impl("unsqueeze", unsqueeze_batching_rule);
  m.impl("view", view_batching_rule);
  m.impl("view_as", native::view_as); // composite wrt autograd

  // clamp operations
  m.impl("clamp", clamp_batching_rule);
  m.impl("clamp_min", clamp_min_batching_rule);
  m.impl("clamp_max", clamp_max_batching_rule);

  // unary pointwise, out-of-place, no additional arguments.
#define UNARY_POINTWISE(op) m.impl(#op, \
    unwrap_and_call<Tensor (*)(const Tensor&), at::op>);
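  // For example, UNARY_POINTWISE(abs) expands to
  //   m.impl("abs", unwrap_and_call<Tensor (*)(const Tensor&), at::abs>);
  // which (roughly speaking) unwraps the BatchedTensor, calls at::abs on the
  // underlying physical tensor, and re-wraps the result with the same batch
  // dims; this is valid because these ops are pointwise and shape-preserving.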
  UNARY_POINTWISE(abs);
  UNARY_POINTWISE(acos);
  UNARY_POINTWISE(asin);
  UNARY_POINTWISE(atan);
  UNARY_POINTWISE(ceil);
  UNARY_POINTWISE(cos);
  UNARY_POINTWISE(cosh);
  UNARY_POINTWISE(conj_physical);
  UNARY_POINTWISE(digamma);
  UNARY_POINTWISE(exp);
  UNARY_POINTWISE(expm1);
  UNARY_POINTWISE(floor);
  UNARY_POINTWISE(frac);
  UNARY_POINTWISE(lgamma);
  UNARY_POINTWISE(log);
  UNARY_POINTWISE(log10);
  UNARY_POINTWISE(log1p);
  UNARY_POINTWISE(log2);
  UNARY_POINTWISE(neg);
  UNARY_POINTWISE(reciprocal);
  UNARY_POINTWISE(relu);
  UNARY_POINTWISE(round);
  UNARY_POINTWISE(rsqrt);
  UNARY_POINTWISE(sigmoid);
  UNARY_POINTWISE(sign);
  UNARY_POINTWISE(sin);
  UNARY_POINTWISE(sinh);
  UNARY_POINTWISE(sqrt);
  UNARY_POINTWISE(tan);
  UNARY_POINTWISE(tanh);
  UNARY_POINTWISE(trunc);
#undef UNARY_POINTWISE
#define TO_BATCHING_RULE(name, ...) \
  { \
    using to_type = Tensor(Tensor::*)(__VA_ARGS__) const; \
    m.impl(name, unwrap_and_call_method< \
        to_type, &Tensor::to, __VA_ARGS__>);\
  }
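  // For example, TO_BATCHING_RULE("to.dtype", ScalarType, bool, bool, std::optional<MemoryFormat>)
  // expands to (roughly):
  //   {
  //     using to_type = Tensor(Tensor::*)(ScalarType, bool, bool, std::optional<MemoryFormat>) const;
  //     m.impl("to.dtype", unwrap_and_call_method<to_type, &Tensor::to,
  //         ScalarType, bool, bool, std::optional<MemoryFormat>>);
  //   }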
  TO_BATCHING_RULE("to.device", Device, ScalarType, bool, bool, std::optional<MemoryFormat>)
  TO_BATCHING_RULE("to.dtype", ScalarType, bool, bool, std::optional<MemoryFormat>)
  TO_BATCHING_RULE("to.other", const Tensor&, bool, bool, std::optional<MemoryFormat>)
  m.impl("to.dtype_layout", to_dtype_layout_batching_rule);
#undef TO_BATCHING_RULE
  m.impl("clone", clone_batching_rule);

  using TensorTensorScalarType = Tensor (*)(const Tensor&, const Tensor&, const Scalar&);
  using TensorTensorType = Tensor (*)(const Tensor&, const Tensor&);
  using TensorScalarType = Tensor (*)(const Tensor&, const Scalar&);

#define BINARY_POINTWISE(op) \
  m.impl(#op".Tensor", binary_pointwise_batching_rule<TensorTensorType, at::op>); \
  m.impl(#op".Scalar", unwrap_and_call<TensorScalarType, at::op, const Scalar&>);
#define BINARY_POINTWISE_VA(op, ...) \
  { \
    using Binop = Tensor (*)(const Tensor&, const Tensor&, __VA_ARGS__); \
    using Unop = Tensor (*)(const Tensor&, const Scalar&, __VA_ARGS__); \
    m.impl(#op".Tensor", binary_pointwise_batching_rule<Binop, at::op, __VA_ARGS__>); \
    m.impl(#op".Scalar", unwrap_and_call<Unop, at::op, const Scalar&, __VA_ARGS__>); \
  }
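  // For example, BINARY_POINTWISE_VA(add, const Scalar&) expands to (roughly):
  //   {
  //     using Binop = Tensor (*)(const Tensor&, const Tensor&, const Scalar&);
  //     using Unop = Tensor (*)(const Tensor&, const Scalar&, const Scalar&);
  //     m.impl("add.Tensor", binary_pointwise_batching_rule<Binop, at::add, const Scalar&>);
  //     m.impl("add.Scalar", unwrap_and_call<Unop, at::add, const Scalar&, const Scalar&>);
  //   }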

  BINARY_POINTWISE_VA(add, const Scalar&);
  BINARY_POINTWISE_VA(sub, const Scalar&);
  BINARY_POINTWISE_VA(rsub, const Scalar&);
  BINARY_POINTWISE(mul);
  BINARY_POINTWISE(div);
  {
    using Binop = Tensor (*)(const Tensor&, const Tensor&, std::optional<c10::string_view>);
    using Unop = Tensor (*)(const Tensor&, const Scalar&, std::optional<c10::string_view>);
    m.impl("div.Tensor_mode", binary_pointwise_batching_rule<Binop, at::div, std::optional<c10::string_view>>);
    m.impl("div.Scalar_mode", unwrap_and_call<Unop, at::div, const Scalar&, std::optional<c10::string_view>>);
  }

  // at::pow has three out-of-place overloads
  m.impl("pow.Tensor_Tensor", binary_pointwise_batching_rule<TensorTensorType, at::pow>);
  m.impl("pow.Tensor_Scalar", unwrap_and_call<TensorScalarType, at::pow, const Scalar&>);
  m.impl("pow.Scalar", pow_scalar_Tensor_batching_rule);

  m.impl("sigmoid_backward", binary_pointwise_batching_rule<TensorTensorType, at::sigmoid_backward>);
  m.impl(
      "threshold_backward",
      binary_pointwise_batching_rule<
          TensorTensorScalarType,
          at::threshold_backward,
          const Scalar&>);

  // for at::result_type, call the native::result_type implementation.
  // We don't have to do anything special because native::result_type operates
  // on the logical shape of the tensors.
  m.impl("result_type.Tensor", static_cast<ScalarType (*)(const Tensor&, const Tensor&)>(native::result_type));
  m.impl("result_type.Scalar", static_cast<ScalarType (*)(const Tensor&, const Scalar&)>(native::result_type));
  m.impl("result_type.Scalar_Tensor", static_cast<ScalarType (*)(const Scalar&, const Tensor&)>(native::result_type));
  m.impl("result_type.Scalar_Scalar", static_cast<ScalarType (*)(const Scalar&, const Scalar&)>(native::result_type));

#undef BINARY_POINTWISE_VA
#undef BINARY_POINTWISE

#define TRIVIAL_OP(op) m.impl(#op, \
    unwrap_and_call<Tensor (*)(const Tensor&), at::op>);
  // complex number view operators
  TRIVIAL_OP(imag);
  TRIVIAL_OP(real);
  TRIVIAL_OP(view_as_real);
  TRIVIAL_OP(conj);
  TRIVIAL_OP(_conj);
  TRIVIAL_OP(resolve_conj);
  TRIVIAL_OP(resolve_neg);
  m.impl("view_as_complex", view_as_complex_batching_rule);
#undef TRIVIAL_OP

  // matmul-like operators
  m.impl("mv", mv_batching_rule);
  m.impl("dot", dot_batching_rule);
  m.impl("bmm", bmm_batching_rule);
  m.impl("mm", mm_batching_rule);

  // cat/stack
  m.impl("cat", cat_batching_rule);
  m.impl("stack", stack_batching_rule);

  // backward operators
  m.impl("select_backward", select_backward_batching_rule);
  m.impl("slice_backward", slice_backward_batching_rule);
  m.impl("trace_backward", trace_backward_batching_rule);
  m.impl("diagonal_backward", diagonal_backward_batching_rule);

  // Tensor.new_* operators
  m.impl("new_empty", new_empty_batching_rule);
  m.impl("new_empty_strided", new_empty_strided_batching_rule);
  m.impl("new_zeros", new_zeros_batching_rule);

  m.impl("contiguous", contiguous_batching_rule);

  // Comparison ops
#define COMPARISON_POINTWISE(op) \
  m.impl(#op".Tensor", comparison_pointwise_batching_rule<TensorTensorType, at::op>); \
  m.impl(#op".Scalar", unwrap_and_call<TensorScalarType, at::op, const Scalar&>);

  COMPARISON_POINTWISE(eq);
  COMPARISON_POINTWISE(gt);
  COMPARISON_POINTWISE(ge);
  COMPARISON_POINTWISE(le);
  COMPARISON_POINTWISE(lt);
  COMPARISON_POINTWISE(ne);

#undef COMPARISON_POINTWISE
}

} // namespace at