#include <torch/library.h>
#include <ATen/ATen.h>
#include <ATen/LegacyVmapTransforms.h>
#include <ATen/LegacyBatchedFallback.h>
#include <ATen/RedispatchFunctions.h>
#include <ATen/native/ResizeCommon.h>
#include <ATen/core/IListRef.h>
#include <c10/util/irange.h>
#include <c10/core/SymIntArrayRef.h>

#include <utility>

namespace at {

// NOTE: [What is a batching rule?]
//
// A *batching rule* implements the logic of how to call an operator on inputs
// that have zero or more additional batch dimensions. When one does a vmap, the
// dimension(s) being vmap'ed over get recorded as batch dimensions.
//
// For example, vmap(torch.add)(x, y)
// 1. wraps `x` into batched_x = BatchedTensor(x, bdims=[(lvl=1, dim=0)]);
// 2. wraps `y` into batched_y = BatchedTensor(y, bdims=[(lvl=1, dim=0)]);
// 3. and then runs `torch.add(batched_x, batched_y)`.
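//
// Conceptually (a sketch, assuming x and y share a single vmapped dimension of
// size B0 at dim 0), the result matches stacking per-example calls:
//   torch.stack([torch.add(x[i], y[i]) for i in range(B0)])
// The batching rule's job is to compute this without the Python-level loop.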

// NOTE: [When should I add a batching rule?]
// When you are adding a new operator, you'll need to add a batching rule so
// that vmap can work efficiently with said operator. If you do not, we'll attempt
// to generate a slow fallback for the batching rule.

// NOTE: [How to write batching rules?]
// The signature of a batching rule should look exactly like the C++ signature
// of its operator.
//
// First, see NOTE: [Logical vs physical args] in VmapTransforms.h for terminology.
//
// At a high level, what a batching rule does is the following:
// 1. Converts (logical) BatchedTensors to views on physical tensors.
// 2. Converts logical arguments (e.g. dimension indexes, shapes) to physical
//    arguments that correspond to the physical tensors.
// 3. Calls at:: operations on the physical tensors and arguments to produce
//    some physical results.
// 4. Converts physical results back to BatchedTensors.
//
// Steps 1, 2, and 4 differ for operators with different batching behaviors. When
// writing a new batching rule, please select a VmapTransform that matches the
// batching behavior of your operation. The VmapTransform provides helper functions
// to do steps (1), (2), and (4).
// (see NOTE: [What is a VmapTransform?] in VmapTransforms.h)
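//
// As an illustrative sketch only (hypothetical; not an operator registered in
// this file), a unary batching rule following steps (1)-(4) could look like:
//
//   Tensor cos_batching_rule(const Tensor& self) {
//     // (1) Convert the logical BatchedTensor into a view on a physical tensor.
//     auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self);
//     // (2) cos takes no dim/shape arguments, so there is nothing to translate.
//     // (3) Call the at:: operation on the physical tensor.
//     auto result = at::cos(self_physical.tensor());
//     // (4) Convert the physical result back to a BatchedTensor.
//     return self_physical.getPhysicalToLogicalMap().apply(result);
//   }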

// Note: [Future plans]
// The API for writing a batching rule isn't stable. In the future, we'd like
// to think about the problem of translating these batching rules to TorchScript.
// Ideally batching rules in eager mode vs TorchScript would look pretty similar,
// if not use the same mechanism. In order to accomplish that we might have to
// do some refactoring.

namespace {

// PyTorch allows operations to specify dim 0 and dim -1 on a scalar tensor.
static bool is_allowed_dim_on_scalar_tensor(int64_t dim) {
  return dim == 0 || dim == -1;
}

Tensor sum_batching_rule(const Tensor& self, OptionalIntArrayRef opt_dims, bool keepdim, std::optional<ScalarType> dtype) {
  if (opt_dims.has_value()) {
    auto dims = opt_dims.value();
    // PyTorch has a special case where sum(scalar_tensor, dim=0) does not fail
    // and instead returns a new scalar tensor (this also happens for dim=-1).
    // If the following happens:
    // >>> x = torch.randn(B0) # the per-examples are all scalars
    // >>> vmap(partial(torch.sum, dim=0))(x)
    // then we replicate the behavior of sum(scalar_tensor, dim=0).
    if (/*logical*/self.dim() == 0 && (dims.empty() || (dims.size() == 1 && is_allowed_dim_on_scalar_tensor(dims[0])))) {
      return self.clone();
    }
  }
  auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self);
  auto dims_physical = self_physical.getPhysicalDims(opt_dims);
  auto result = at::sum(self_physical.tensor(), dims_physical, keepdim, dtype);
  return self_physical.getPhysicalToLogicalMap().apply(result);
}


bool isPhysicalScalarTensor(const Tensor& logical_tensor) {
  if (logical_tensor.dim() > 0) {
    return false;
  }
  auto* batched = maybeGetBatchedImpl(logical_tensor);
  if (batched) {
    return false;
  }
  return true;
}

template <typename F, F Func, typename... ExtraArgs>
Tensor binary_pointwise_batching_rule(
    const Tensor& self, const Tensor& other, ExtraArgs... args) {
  if (self.dim() > 0 && other.dim() > 0) {
    auto physical_args = BroadcastingVmapTransform::logicalToPhysical({self, other});
    auto result = Func(physical_args[0].tensor(), physical_args[1].tensor(), args...);
    return physical_args[0].getPhysicalToLogicalMap().apply(result);
  }
  if (isPhysicalScalarTensor(self)) {
    auto other_physical = MultiBatchVmapTransform::logicalToPhysical(other);
    auto result = Func(self, other_physical.tensor(), args...);
    return other_physical.getPhysicalToLogicalMap().apply(result);
  }
  if (isPhysicalScalarTensor(other)) {
    auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self);
    auto result = Func(self_physical.tensor(), other, args...);
    return self_physical.getPhysicalToLogicalMap().apply(result);
  }

  // At this point, we know at least one of the operands is a logical Scalar tensor.
  // Here we must emulate TensorIterator's special behavior on Scalars.
  //
  // As a motivating example, consider the following:
  //   x = torch.randn(3, 10)
  //   y = torch.randn(3, dtype=torch.double)
  //   vmap(torch.mul)(x, y)
  //
  // At a per-example level, we are multiplying FloatTensor[10] and DoubleTensor[];
  // Type Promotion dictates that the result should be FloatTensor[10].
  // This means we cannot directly pass the physical tensors (x and y) to
  // TensorIterator (if we did, it would promote them to DoubleTensor).
  //
  // FIXME(rzou): I didn't want to go down the slippery slope of emulating
  // everything TensorIterator does (it would be better to refactor out the
  // TensorIterator logic). The one thing that this code doesn't handle
  // is cross-device logical scalar tensors.
  //   cpu_tensor = torch.randn(3)
  //   cuda_tensor = torch.randn(3, 10, device='cuda')
  //   vmap(torch.mul)(cpu_tensor, cuda_tensor)
  //
  // At a per-example level, we are multiplying CPUTensor[] and CUDATensor[10].
  // TensorIterator allows for this cross-device operation because one of the
  // tensors is a Scalar CPU tensor. However, the following code will throw an
  // error in that case. I don't expect to see many use cases for this, so
  // this is probably fine as-is.
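  //
  // A worked sketch of what the code below computes for the first example:
  // the logical per-example types are FloatTensor[10] and DoubleTensor[], and
  // result_type(FloatTensor[10], DoubleTensor[]) is Float, since a zero-dim
  // tensor only participates in type promotion at the category level. The
  // double scalar tensor is therefore cast to Float before broadcasting,
  // which reproduces the FloatTensor[10] per-example result.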
  auto logical_self = self;
  auto logical_other = other;
  auto result_type = at::native::result_type(logical_self, logical_other);
  if (logical_self.scalar_type() != result_type) {
    logical_self = logical_self.to(result_type);
  }
  if (logical_other.scalar_type() != result_type) {
    logical_other = logical_other.to(result_type);
  }
  auto physical_args = BroadcastingVmapTransform::logicalToPhysical(
      {std::move(logical_self), std::move(logical_other)});
  auto result = Func(physical_args[0].tensor(), physical_args[1].tensor(), args...);
  return physical_args[0].getPhysicalToLogicalMap().apply(result);
}

Tensor expand_batching_rule(const Tensor& self, IntArrayRef size, bool implicit) {
  auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self);
  auto size_physical = self_physical.getPhysicalShape(size);
  auto self_physical_dim = self_physical.tensor().dim();

  TORCH_CHECK(self_physical_dim <= static_cast<int64_t>(size_physical.size()),
       "expand: the number of sizes provided (", /*logical*/size.size(), ") ",
       "must be greater or equal to the number of dimensions in the tensor (",
       /*logical dim*/self.dim(), ")");

  if (self_physical_dim == static_cast<int64_t>(size_physical.size())) {
    auto result = self_physical.tensor().expand(size_physical, implicit);
    return self_physical.getPhysicalToLogicalMap().apply(result);
  }

  TORCH_INTERNAL_ASSERT(self_physical_dim < static_cast<int64_t>(size_physical.size()));
  // Here, we know we are expanding a (logical) tensor to a larger number
  // of dimensions. We have to be careful because we can't call expand directly
  // due to the presence of batch dimensions.
  //
  // As an example, let B0 be a batch dimension and consider expand(Tensor[B0, 3], [2, 3]).
  // The result should be a tensor of size [B0, 2, 3].
  // A physical view of size [B0, 3] can't directly be expanded to size [B0, 2, 3]
  // so the strategy here is to view it first as a tensor of size [B0, 1, 3] and
  // then expand.
  auto self_physical_size = self_physical.tensor().sizes();
  auto extra_dims = size_physical.size() - self_physical_dim;
  VmapDimVector view_shape(size_physical.size(), 1);
  std::copy(self_physical_size.begin(),
            self_physical_size.begin() + self_physical.numBatchDims(),
            view_shape.begin());
  std::copy(self_physical_size.begin() + self_physical.numBatchDims(),
            self_physical_size.end(),
            view_shape.begin() + self_physical.numBatchDims() + extra_dims);
  auto result = self_physical.tensor().view(view_shape).expand(size_physical, implicit);
  return self_physical.getPhysicalToLogicalMap().apply(result);
}

std::vector<Tensor> chunk_batching_rule(const Tensor& self, int64_t chunks, int64_t dim) {
  auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self);
  auto dim_physical = self_physical.getPhysicalDim(dim);
  auto result = at::chunk(self_physical.tensor(), chunks, dim_physical);
  self_physical.getPhysicalToLogicalMap().applyInplace(result);
  return result;
}

Tensor clamp_batching_rule(const Tensor& self, const std::optional<Scalar>& min, const std::optional<Scalar>& max) {
  auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self);
  auto result = at::clamp(self_physical.tensor(), min, max);
  return self_physical.getPhysicalToLogicalMap().apply(result);
}

Tensor clamp_min_batching_rule(const Tensor& self, const Scalar& min) {
  auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self);
  auto result = at::clamp_min(self_physical.tensor(), min);
  return self_physical.getPhysicalToLogicalMap().apply(result);
}

Tensor clamp_max_batching_rule(const Tensor& self, const Scalar& max) {
  auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self);
  auto result = at::clamp_max(self_physical.tensor(), max);
  return self_physical.getPhysicalToLogicalMap().apply(result);
}

std::vector<Tensor> tensor_split_sections_batching_rule(const Tensor& self, int64_t sections, int64_t dim) {
  auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self);
  auto dim_physical = self_physical.getPhysicalDim(dim);
  auto result = at::tensor_split(self_physical.tensor(), sections, dim_physical);
  self_physical.getPhysicalToLogicalMap().applyInplace(result);
  return result;
}

std::vector<Tensor> tensor_split_indices_batching_rule(const Tensor& self, IntArrayRef indices, int64_t dim) {
  auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self);
  auto dim_physical = self_physical.getPhysicalDim(dim);
  auto result = at::tensor_split(self_physical.tensor(), indices, dim_physical);
  self_physical.getPhysicalToLogicalMap().applyInplace(result);
  return result;
}

Tensor unsqueeze_batching_rule(const Tensor& self, int64_t dim) {
  auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self);
  // NB: unsqueeze has some special handling of its `dim` argument so we can't call
  // self_physical.getPhysicalDim directly. In particular, native::unsqueeze
  // wraps the dim to (the logical dimension) + 1, so we need to do that here too.
  // https://github.com/pytorch/pytorch/blob/b623bdeabb0aa8da44285d303246e7f8ac06c2a9/aten/src/ATen/native/TensorShape.cpp#L1413
  auto dim_physical =
      self_physical.numBatchDims() + maybe_wrap_dim(dim, /*logical_dim*/self.dim() + 1);
  auto result = self_physical.tensor().unsqueeze(dim_physical);
  return self_physical.getPhysicalToLogicalMap().apply(result);
}

Tensor& fill_inplace_scalar_batching_rule(Tensor& self, const Scalar& value) {
  auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self);
  self_physical.tensor().fill_(value);
  return self;
}

Tensor& fill_inplace_tensor_batching_rule(Tensor& self, const Tensor& value) {
  auto value_batched = isBatchedTensor(value);

  if (value_batched) {
    auto physical_args =
        BroadcastingVmapTransform::logicalToPhysical({self, value});
    physical_args[0].tensor().copy_(physical_args[1].tensor());
  } else {
    auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self);
    self_physical.tensor().fill_(value);
  }
  return self;
}

Tensor& zero_inplace_batching_rule(Tensor &self) {
  auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self);
  self_physical.tensor().zero_();
  return self;
}

Tensor squeeze_batching_rule(const Tensor& self) {
  auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self);
  auto physical_sizes = self_physical.tensor().sizes();

  // Don't squeeze the batch dims!
  VmapDimVector squeezed_sizes;
  int64_t num_batch_dims = self_physical.numBatchDims();
  squeezed_sizes.insert(
      squeezed_sizes.end(),
      physical_sizes.begin(),
      physical_sizes.begin() + num_batch_dims);
  for (auto it = physical_sizes.begin() + num_batch_dims; it != physical_sizes.end(); ++it) {
    if (*it != 1) {
      squeezed_sizes.push_back(*it);
    }
  }

  auto result = self_physical.tensor().view(squeezed_sizes);
  return self_physical.getPhysicalToLogicalMap().apply(result);
}

Tensor squeeze_dim_batching_rule(const Tensor& self, int64_t dim) {
  auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self);
  auto dim_physical = self_physical.getPhysicalDim(dim);
  auto result = self_physical.tensor().squeeze(dim_physical);
  return self_physical.getPhysicalToLogicalMap().apply(result);
}

Tensor squeeze_dims_batching_rule(const Tensor& self, IntArrayRef dims) {
  auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self);
  auto dims_physical = self_physical.getPhysicalDims(dims);
  auto result = self_physical.tensor().squeeze(dims_physical);
  return self_physical.getPhysicalToLogicalMap().apply(result);
}

Tensor trace_batching_rule(const Tensor& self) {
  auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self);
  // Batched Diagonal View
  auto self_diag = at::diagonal(self_physical.tensor(), /*offset*/0, /*dim1*/-2, /*dim2*/-1);
  auto result = at::sum(self_diag, -1);
  return self_physical.getPhysicalToLogicalMap().apply(result);
}

Tensor trace_backward_batching_rule(const Tensor& grad, IntArrayRef input_sizes) {
  auto grad_physical = MultiBatchVmapTransform::logicalToPhysical(grad);
  auto grad_input = at::zeros(grad_physical.getPhysicalShape(input_sizes), grad.options());
  // Batched Diagonal View
  auto grad_input_diag = at::diagonal(grad_input, /*offset*/0, /*dim1*/-2, /*dim2*/-1);
  // Append a dimension of size one to the grad output
  auto grad_physical_tensor = grad_physical.tensor().unsqueeze(-1);
  grad_input_diag.copy_(grad_physical_tensor);
  return grad_physical.getPhysicalToLogicalMap().apply(grad_input);
}

Tensor transpose_int_batching_rule(const Tensor& self, int64_t dim0, int64_t dim1) {
  // PyTorch has a special case where scalar_tensor.transpose(dim0, dim1) works
  // for dim0, dim1 in {0, -1} and returns the scalar tensor. If the following happens:
  // >>> x = torch.randn(B0) # the per-examples are all scalars
  // >>> vmap(lambda x: x.transpose(0, -1))(x)
  // then we replicate this behavior.
  if (/*logical*/self.dim() == 0 && is_allowed_dim_on_scalar_tensor(dim0) &&
      is_allowed_dim_on_scalar_tensor(dim1)) {
    return self;
  }
  auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self);
  auto dim0_physical = self_physical.getPhysicalDim(dim0);
  auto dim1_physical = self_physical.getPhysicalDim(dim1);
  auto result = self_physical.tensor().transpose(dim0_physical, dim1_physical);
  return self_physical.getPhysicalToLogicalMap().apply(result);
}

Tensor permute_batching_rule(const Tensor& self, IntArrayRef dims) {
  auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self);
  auto dims_physical = self_physical.getPhysicalDims(dims);

  VmapDimVector all_dims_physical;
  all_dims_physical.reserve(self_physical.tensor().dim());
  for (const auto bdim : c10::irange(self_physical.numBatchDims())) {
    all_dims_physical.push_back(bdim);
  }
  all_dims_physical.insert(
      all_dims_physical.end(),
      dims_physical.begin(),
      dims_physical.end());
  auto result = self_physical.tensor().permute(all_dims_physical);
  return self_physical.getPhysicalToLogicalMap().apply(result);
}

Tensor select_batching_rule(const Tensor& self, int64_t dim, int64_t index) {
  auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self);
  auto dim_physical = self_physical.getPhysicalDim(dim);
  auto result = self_physical.tensor().select(dim_physical, index);
  return self_physical.getPhysicalToLogicalMap().apply(result);
}

static int64_t getGradInputPhysicalDim(int64_t dim, IntArrayRef input_sizes, int64_t num_batch_dims) {
  return maybe_wrap_dim(dim, input_sizes.size()) + num_batch_dims;
}

Tensor select_backward_batching_rule(const Tensor& grad, IntArrayRef input_sizes, int64_t dim, int64_t index) {
  auto grad_physical = MultiBatchVmapTransform::logicalToPhysical(grad);
  auto grad_input = at::zeros(grad_physical.getPhysicalShape(input_sizes), grad.options());
  auto physical_dim = getGradInputPhysicalDim(dim, input_sizes, grad_physical.numBatchDims());
  grad_input.select(physical_dim, index).copy_(grad_physical.tensor());
  return grad_physical.getPhysicalToLogicalMap().apply(grad_input);
}

Tensor slice_batching_rule(
    const Tensor& self,
    int64_t dim,
    std::optional<int64_t> start,
    std::optional<int64_t> end,
    int64_t step) {
  auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self);
  auto dim_physical = self_physical.getPhysicalDim(dim);
  auto result = self_physical.tensor().slice(dim_physical, start, end, step);
  return self_physical.getPhysicalToLogicalMap().apply(result);
}

Tensor slice_backward_batching_rule(const Tensor& grad, IntArrayRef input_sizes, int64_t dim, int64_t start, int64_t end, int64_t step) {
  auto grad_physical = MultiBatchVmapTransform::logicalToPhysical(grad);
  auto grad_input = at::zeros(grad_physical.getPhysicalShape(input_sizes), grad.options());
  auto physical_dim = getGradInputPhysicalDim(dim, input_sizes, grad_physical.numBatchDims());
  grad_input.slice(physical_dim, start, end, step).copy_(grad_physical.tensor());
  return grad_physical.getPhysicalToLogicalMap().apply(grad_input);
}

Tensor diagonal_batching_rule(const Tensor& self, int64_t offset, int64_t dim1, int64_t dim2) {
  auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self);
  auto dim1_physical = self_physical.getPhysicalDim(dim1);
  auto dim2_physical = self_physical.getPhysicalDim(dim2);
  auto result = at::diagonal(self_physical.tensor(), offset, dim1_physical, dim2_physical);
  return self_physical.getPhysicalToLogicalMap().apply(result);
}

Tensor diagonal_backward_batching_rule(const Tensor& grad, IntArrayRef input_sizes, int64_t offset, int64_t dim1, int64_t dim2) {
  auto grad_physical = MultiBatchVmapTransform::logicalToPhysical(grad);
  auto grad_input = at::zeros(grad_physical.getPhysicalShape(input_sizes), grad.options());
  auto dim1_physical = getGradInputPhysicalDim(dim1, input_sizes, grad_physical.numBatchDims());
  auto dim2_physical = getGradInputPhysicalDim(dim2, input_sizes, grad_physical.numBatchDims());
  grad_input.diagonal(offset, dim1_physical, dim2_physical).copy_(grad_physical.tensor());
  return grad_physical.getPhysicalToLogicalMap().apply(grad_input);
}

Tensor movedim_batching_rule(const Tensor& self, IntArrayRef source, IntArrayRef destination) {
  auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self);
  auto source_physical = self_physical.getPhysicalDims(source);
  auto destination_physical = self_physical.getPhysicalDims(destination);
  auto result = at::movedim(self_physical.tensor(), source_physical, destination_physical);
  return self_physical.getPhysicalToLogicalMap().apply(result);
}

Tensor reshape_batching_rule(const Tensor& self, IntArrayRef shape) {
  auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self);
  auto shape_physical = self_physical.getPhysicalShape(shape);
  auto result = self_physical.tensor().reshape(shape_physical);
  return self_physical.getPhysicalToLogicalMap().apply(result);
}

std::vector<Tensor> split_batching_rule(const Tensor& self, int64_t split_size, int64_t dim) {
  auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self);
  auto dim_physical = self_physical.getPhysicalDim(dim);
  auto result = at::split(self_physical.tensor(), split_size, dim_physical);
  self_physical.getPhysicalToLogicalMap().applyInplace(result);
  return result;
}

std::vector<Tensor> split_with_sizes_batching_rule(const Tensor& self, IntArrayRef split_sizes, int64_t dim) {
  auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self);
  auto dim_physical = self_physical.getPhysicalDim(dim);
  auto result = at::split_with_sizes(self_physical.tensor(), split_sizes, dim_physical);
  self_physical.getPhysicalToLogicalMap().applyInplace(result);
  return result;
}

std::vector<Tensor> unbind_batching_rule(const Tensor& self, int64_t dim) {
  auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self);
  auto dim_physical = self_physical.getPhysicalDim(dim);
  auto result = at::unbind(self_physical.tensor(), dim_physical);
  self_physical.getPhysicalToLogicalMap().applyInplace(result);
  return result;
}

Tensor unfold_batching_rule(const Tensor& self, int64_t dim, int64_t size, int64_t step) {
  auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self);
  auto dim_physical = self_physical.getPhysicalDim(dim);
  auto result = self_physical.tensor().unfold(dim_physical, size, step);
  return self_physical.getPhysicalToLogicalMap().apply(result);
}

Tensor contiguous_batching_rule(const Tensor& self, MemoryFormat memory_format) {
  TORCH_CHECK(memory_format == MemoryFormat::Contiguous,
      "NYI: Tensor.contiguous(...) inside of vmap for memory_format other ",
      "than torch.contiguous_format");
  auto physical_view = MultiBatchVmapTransform::logicalToPhysical(self);
  auto result = physical_view.tensor().contiguous(memory_format);
  return physical_view.getPhysicalToLogicalMap().apply(result);
}

Tensor view_batching_rule(const Tensor& self, IntArrayRef size) {
  auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self);
  auto size_physical = self_physical.getPhysicalShape(size);
  auto result = self_physical.tensor().view(size_physical);
  return self_physical.getPhysicalToLogicalMap().apply(result);
}

Tensor view_as_complex_batching_rule(const Tensor& self) {
  // guard against the user passing in a batch of scalar tensors with batch
  // size equal to 2.
  TORCH_CHECK(!self.sizes().empty(), "Input tensor must have one or more dimensions");
  auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self);
  auto result = at::view_as_complex(self_physical.tensor());
  return self_physical.getPhysicalToLogicalMap().apply(result);
}

// Checks that the smallest batch stride is greater than the largest example
// stride. This is something we can support but we choose not to because it's
// potentially error prone.
static void checkBatchDimsAtFrontInLayout(IntArrayRef physical_strides, int64_t num_batch_dims) {
  auto smallest_batch_stride = std::min_element(
      physical_strides.begin(), physical_strides.begin() + num_batch_dims);
  auto largest_example_stride = std::max_element(
      physical_strides.begin() + num_batch_dims, physical_strides.end());
  if (largest_example_stride == physical_strides.end()) {
    // No example dimensions
    return;
  }
  TORCH_CHECK(*smallest_batch_stride >= *largest_example_stride,
    "vmap: Calling Tensor.as_strided is not supported unless the batch dims being ",
    "vmapped over are at the front of the tensor (in memory layout). When they are ",
    "not at the front of the tensor this operation can be error prone so we "
    "actively discourage it; please file us a bug report and/or try to ",
    "express the as_strided operation in terms of PyTorch view operations");
}

// given (sizes, strides, storage_offset) returns the maximum location that
// can be indexed (or nullopt if such a location doesn't exist, e.g., tensors
// with zero-size dims).
static std::optional<int64_t> maximum_indexable_location(
    IntArrayRef sizes, IntArrayRef strides, int64_t storage_offset) {
  auto result = native::storage_size_for(sizes, strides);
  if (result == 0) {
    return std::nullopt;
  }
  return result + storage_offset;
}

// Let x be the "first slice" of physical_tensor.
// This checks that the range of possible memory locations accessible by
// x.as_strided(sizes, strides, maybe_storage_offset)
// is within the bounds of possible memory locations accessible by x.
static void checkBasicAsStridedValidForSlice(
    const Tensor& physical_tensor,
    int64_t num_batch_dims,
    IntArrayRef sizes,
    IntArrayRef strides,
    std::optional<int64_t> maybe_storage_offset) {
  auto slice_sizes = physical_tensor.sizes().slice(num_batch_dims);
  auto slice_strides = physical_tensor.strides().slice(num_batch_dims);
  auto base_offset = physical_tensor.storage_offset();

  auto storage_offset = maybe_storage_offset.value_or(base_offset);

  auto max_as_strided_loc = maximum_indexable_location(sizes, strides, storage_offset);
  auto max_slice_loc = maximum_indexable_location(slice_sizes, slice_strides, base_offset);

  if (!max_as_strided_loc.has_value()) {
    return;
  }
  if (!max_slice_loc.has_value()) {
    TORCH_CHECK(false,
        "result = tensor.as_strided(", sizes, ",", strides, ",", storage_offset, ") ",
        "can access memory outside of `tensor`. `tensor` has no storage but the ",
        "passed-in (size, stride, storage_offset) imply a result with some storage. ",
        "This is not supported inside of vmap, please try to rewrite the ",
        "`as_strided` call as a sequence of PyTorch view operations");
  }

  TORCH_CHECK(
      *max_as_strided_loc <= *max_slice_loc && base_offset <= storage_offset,
      "result = tensor.as_strided(", sizes, ",", strides, ",", storage_offset, ") ",
      "can access memory outside of `tensor`. `result` can access some ",
      "memory in range [", storage_offset, ", ", *max_as_strided_loc, "], but ",
      "`tensor` can only access some memory in range [", base_offset, ", ",
      *max_slice_loc, "]. This is not supported inside of vmap, please try to ",
      "rewrite the `as_strided` call as a sequence of PyTorch view operations");
}

Tensor _reshape_alias_batching_rule(const Tensor& self, IntArrayRef sizes, IntArrayRef strides [[maybe_unused]]) {
  return reshape_batching_rule(self, sizes);
}

Tensor _new_zeros_with_same_feature_meta_batching_rule(
    const Tensor& self,
    const Tensor& other,
    int64_t unused_num_batch_dims) {
  TORCH_CHECK(isBatchedTensor(self) && !isBatchedTensor(other),
    "Only the 'batched grad' use case is supported in PyTorch core.");

  TORCH_INTERNAL_ASSERT(unused_num_batch_dims == 0,
    "num_batch_dims should not be explicitly passed in because it will be overridden");
  auto self_physical_view = at::MultiBatchVmapTransform::logicalToPhysical(self);
  const auto& self_physical_tensor = self_physical_view.tensor();
  int64_t num_batch_dims = self_physical_view.numBatchDims();
  checkBatchDimsAtFrontInLayout(self_physical_tensor.strides(), num_batch_dims);
  auto result = at::_new_zeros_with_same_feature_meta(self_physical_tensor, other, num_batch_dims);
  return self_physical_view.getPhysicalToLogicalMap().apply(result);
}

bool _has_same_storage_numel_batching_rule(const Tensor& self, const Tensor& other) {
  TORCH_CHECK(isBatchedTensor(self) && !isBatchedTensor(other),
    "Only the 'batched grad' use case is supported in PyTorch core.");
  // The _has_same_storage_numel check is skipped if the tangent is a batched
  // tensor because using as_strided to access storage locations not indexable
  // by the input tensor is not supported in vmap
  return true;
}

591*da0073e9SAndroid Build Coastguard Worker // What are the semantics of as_strided inside of vmap?
592*da0073e9SAndroid Build Coastguard Worker // y = vmap(lambda x: x.as_strided(sizes, strides, offset))(xs)
// This returns a view `y` on `xs` such that each y[i] has:
// - sizes: `sizes`
// - strides: `strides`
// - storage_offset: offset + i * xs.stride(batch_dim)
//
// In other words, it is as if we had treated each xs[i] as having storage
// offset equal to xs.offset() and called as_strided(sizes, strides, offset).
// (that is equivalent to xs[i].as_strided(
//    sizes, strides, offset + xs[i].storage_offset() - xs.offset()) for all i)
//
// Note that this *may* be different from actually running as_strided
// in a for-loop. This is due to how as_strided takes in `offset` to be
// an *absolute* offset. As an example, consider:
// >>> x = torch.tensor([0., 1., 2., 3., 4.]).as_strided([4], [1], 1)
// >>> z = [x[i].as_strided([1], [1], 1) for i in range(4)]
// Each z[i] is actually the same view on x (z[i] == torch.tensor([1.]))!
// However, we consider the above for-loop comprehension to be a user error:
// a user should have written the following if they wanted to use as_strided
// in a per-sample way:
// >>> z = [x[i].as_strided([1], [1], 1 + x[i].storage_offset() - 1) for i in range(4)]
Tensor as_strided_batching_rule(
    const Tensor& tensor,
    IntArrayRef sizes,
    IntArrayRef strides,
    std::optional<int64_t> storage_offset) {
  auto physical_view = at::MultiBatchVmapTransform::logicalToPhysical(tensor);
  auto num_batch_dims = physical_view.numBatchDims();
  auto physical_sizes = physical_view.getPhysicalShape(sizes);
  const auto& physical_tensor = physical_view.tensor();

  // We can't rely on the physical as_strided call to do this for us because
  // we do some sanity checks on the size/strides before calling into as_strided.
  TORCH_CHECK(sizes.size() == strides.size(),
      "Tensor.as_strided(size, stride, ...): size and stride must have the ",
      "same length! Got size ", sizes, " and stride ", strides);

  // Sanity checks:
  // 1. All batch dims are at the front in memory layout (not necessary for
  //    correctness, but we are worried the user might be doing crazy things)
  // 2. as_strided(sizes, strides, storage_offset + tensor[i].offset() - tensor.offset())
  //    is valid for a slice of the input tensor.
  //    See Note: [When will the as_strided batching rule fail?] for details.
  checkBatchDimsAtFrontInLayout(physical_tensor.strides(), num_batch_dims);
  checkBasicAsStridedValidForSlice(
      physical_tensor, num_batch_dims, sizes, strides, storage_offset);

  // physical_strides = physical tensor's batch strides + (logical) strides
  auto batch_strides = physical_tensor.strides().slice(0, num_batch_dims);
  at::VmapDimVector physical_strides;
  physical_strides.reserve(num_batch_dims + strides.size());
  physical_strides.insert(
      physical_strides.end(), batch_strides.begin(), batch_strides.end());
  physical_strides.insert(
      physical_strides.end(), strides.begin(), strides.end());

  // If zi = xs[i].as_strided(sizes, strides, offset + xs[i].offset() - xs.offset())
  // is valid for all i, then it turns out that
  // xs.as_strided(physical_sizes, physical_strides, offset) always succeeds
  // and creates a tensor y such that each y[i] references the same memory
  // locations as zi. See NOTE: [When will the as_strided batching rule fail?]
  auto result = physical_view.tensor().as_strided(
      physical_sizes, physical_strides, storage_offset);
  return physical_view.getPhysicalToLogicalMap().apply(result);
}
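
// A concrete sketch of what the rule above builds (illustrative only; B0 and S
// are hypothetical values, not names used elsewhere in this file):
//   Suppose xs has physical shape [B0, 5] with batch stride S = xs.stride(0).
//   >>> vmap(lambda x: x.as_strided([2], [1], 0))(xs)
//   produces a single physical view
//     xs.as_strided([B0, 2], [S, 1], 0)
//   rather than B0 separate per-sample as_strided calls; each logical sample i
//   then sees memory starting at offset 0 + i * S, matching [Equation 1] in the
//   note below.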

// NOTE: [When will the as_strided batching rule fail?]
// If zi = xs[i].as_strided(sizes, strides, offset + xs[i].offset() - xs.offset())
// is valid for all i, then it turns out that
// xs.as_strided(physical_sizes, physical_strides, offset) always succeeds and
// creates a tensor y such that each y[i] refers to the same memory as zi.
//
// Let's say we have xs[i].as_strided(sizes, strides, offset + xs[i].offset() - xs.offset()).
// Furthermore, let's say that as a part of being "valid" this as_strided call
// does not return a result that can index memory not indexable by xs[i].
//
// WLOG, assume that there's only one batch dim and it is at the front of the
// `xs` tensor. Let B be the batch size and S be the stride of the batch dim.
// - If the batch dim isn't at the front of the tensor, then we can just move it
//   to the front with movedim/permute. This is always valid because it just swaps
//   some strides around.
// - This proof also works for tensors with multiple batch dims. We just have to
//   do a little accounting:
//   - instead of [B], we'd have [B0, B1, ..., Bk].
//   - instead of [S], we'd have [S0, S1, ..., Sk].
//   - instead of i, we'd have a list of indices [I0, I1, ..., Ik]
//   - instead of S * I, we'd have \sum_{i=0}^k S_i * I_i
//
// [Equation 1]
// xs[i].as_strided(sizes, strides, offset + xs[i].offset() - xs.offset()) has:
// - sizes: sizes
// - strides: strides
// - offset: offset + S * i
//
// x.as_strided itself checks that:
// - (sizes, strides, offset) are in bounds for `x`'s storage.
// - strides are non-negative
// - offset is non-negative
//
// Claim 1: if xs[i].as_strided(sizes, strides, offset + xs[i].offset() - xs.offset())
// is valid, then
// ([B] + sizes, [S] + strides, offset + xs.offset()) are in bounds for `xs`'s storage.
//
// If we have the claim, then xs.as_strided([B] + sizes, [S] + strides, offset)
// won't error out. So all we need to check is that the memory locations are
// what we expected. See [Hand-wavy proof of Claim 1] for proof (it's not very important)
//
// xs.as_strided(physical_sizes, physical_strides, offset) is equivalent to
// xs.as_strided([B] + sizes, [S] + strides, offset)
//
// xs.as_strided([B] + sizes, [S] + strides, offset) has:
// - sizes: [B] + sizes
// - strides: [S] + strides
// - offset: offset
//
// xs.as_strided([B] + sizes, [S] + strides, offset)[i] has:
// - sizes: sizes
// - strides: strides
// - offset: offset + S * i
// These memory locations are exactly the same as what we got for [Equation 1],
// so the xs.as_strided([B] + sizes, [S] + strides, offset) is valid.
//
// [Hand-wavy proof of Claim 1]
// Part of our definition of being valid is that xs[i].as_strided(...)
// must return a tensor that only uses memory indexable by xs[i].
// This means that (sizes, strides, offset + xs[i].offset() - xs.offset()) satisfies:
//    offset + xs[i].offset() - xs.offset() + 1 + \sum_j (sizes[j] - 1) * strides[j]
//    <= xs[i].offset() + 1 + \sum_j (xs[i].size(j) - 1) * xs[i].stride(j)
// (the largest-index memory location of xs[i].as_strided(...) must be \leq
// the largest-index memory location of xs[i])
//
// Fiddling that inequality gives us:
//    offset - xs.offset() + 1 + \sum_j (sizes[j] - 1) * strides[j]
//    <= 1 + \sum_j (xs[i].size(j) - 1) * xs[i].stride(j)
//
//    offset - xs.offset() + 1 + (B-1)*S + \sum_j (sizes[j] - 1) * strides[j]
//    <= 1 + (B-1)*S + \sum_j (xs[i].size(j) - 1) * xs[i].stride(j)
//
//    offset - xs.offset() + 1 + (B-1)*S + \sum_j (sizes[j] - 1) * strides[j]
//    <= 1 + \sum_j (xs.size(j) - 1) * xs.stride(j)
//
//    offset + 1 + (B-1)*S + \sum_j (sizes[j] - 1) * strides[j]
//    <= xs.offset() + 1 + \sum_j (xs.size(j) - 1) * xs.stride(j)
// (the largest-index memory location of xs.as_strided(size, stride, offset)
// is \leq than the largest-index memory location of xs)
// Under the assumptions we've made, the lower bound (lowest indexed memory)
// is trivially within the storage.
//
// Therefore ([B] + sizes, [S] + strides, offset) are in bounds for
// `xs`'s storage.
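//
// A concrete instance of Claim 1 (hypothetical numbers, only to make the
// inequality above tangible): let xs have physical shape [2, 5], strides
// [5, 1], and xs.offset() == 0, so B = 2 and S = 5. Take sizes = [3],
// strides = [1], offset = 1.
// - Per-sample: zi = xs[i].as_strided([3], [1], 1 + 5*i) touches memory up to
//   index 1 + 5*i + 2 = 5*i + 3, which is <= 5*i + 4, the largest index of
//   xs[i], so every zi is valid.
// - Physical call: xs.as_strided([2, 3], [5, 1], 1) touches memory up to index
//   1 + (2-1)*5 + (3-1)*1 = 8 <= 9, the largest index of xs, as Claim 1
//   predicts, and its i-th slice starts at offset 1 + 5*i, exactly where zi starts.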

template <typename F, F Func, typename... ExtraArgs>
Tensor unwrap_and_call(const Tensor& input, ExtraArgs... args) {
  auto* input_batched = unsafeGetBatchedImpl(input);
  auto output_physical = Func(input_batched->value(), args...);
  auto old_bdims = input_batched->bdims();
  return makeBatched(output_physical, BatchDims(old_bdims.begin(), old_bdims.end()));
}

template <typename F, F Func, typename... ExtraArgs>
Tensor unwrap_and_call_method(const Tensor& input, ExtraArgs... extra_args) {
  auto* input_batched = unsafeGetBatchedImpl(input);
  auto output_physical = (input_batched->value().*Func)(extra_args...);
  auto old_bdims = input_batched->bdims();
  return makeBatched(output_physical, BatchDims(old_bdims.begin(), old_bdims.end()));
}
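
// Example instantiations (these mirror registrations further down in this file,
// so the names are not hypothetical): a free function is wrapped as
//   unwrap_and_call<Tensor (*)(const Tensor&), at::abs>
// (see the UNARY_POINTWISE macro below), while a Tensor method such as
// Tensor::to is wrapped as
//   unwrap_and_call_method<Tensor (Tensor::*)(ScalarType, bool, bool,
//       std::optional<MemoryFormat>) const, &Tensor::to,
//       ScalarType, bool, bool, std::optional<MemoryFormat>>
// (see the TO_BATCHING_RULE macro below). Both simply call the wrapped function
// on the physical tensor and re-attach the original batch dims to the output.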

Tensor pow_scalar_Tensor_batching_rule(const Scalar& other, const Tensor& self) {
  auto* self_batched = unsafeGetBatchedImpl(self);
  auto output_physical = at::pow(other, self_batched->value());
  auto old_bdims = self_batched->bdims();
  return makeBatched(output_physical, BatchDims(old_bdims.begin(), old_bdims.end()));
}

Tensor clone_batching_rule(const Tensor& self, std::optional<MemoryFormat> memory_format) {
  // Memory format support is a little tricky because vmap is allowed to move
  // around batch dimensions and some memory formats are rank-dependent.
  // Another weird case is:
  // - a tensor with MemoryFormat::ChannelsLast MUST have 4 dimensions. Do we
  //   allow the user to clone a Tensor with 3 logical dimensions and 1 batch
  //   dim into a ChannelsLast Tensor? What about a Tensor with 3 logical dims
  //   and N>1 batch dims?
  TORCH_CHECK(!memory_format.has_value() || memory_format == MemoryFormat::Preserve
      || memory_format == MemoryFormat::Contiguous,
      "NYI: Tensor.clone(memory_format) inside vmap is only supported with ",
      "memory_format torch.preserve_format or torch.contiguous_format (got ",
      *memory_format, ")");

  if (memory_format == MemoryFormat::Contiguous) {
    // There is an ambiguity here when the batch dims are not at the front of
    // the tensor.
    // >>> x = torch.randn(3, B0, 5)
    // >>> y = vmap(lambda x: x.clone(torch.contiguous_format), in_dims=1, out_dims=0)(x)
    // >>> y[0].is_contiguous()
    // ???
    // Should we make the whole tensor contiguous, or should we
    // make the non-batch dims contiguous? We've chosen the latter because
    // philosophically vmap hides the batch dims and operates on a per-sample level.
    auto physical_view = MultiBatchVmapTransform::logicalToPhysical(self);
    auto output_physical = at::clone(physical_view.tensor(), memory_format);
    return physical_view.getPhysicalToLogicalMap().apply(output_physical);
  }

  TORCH_INTERNAL_ASSERT(!memory_format.has_value() || memory_format == MemoryFormat::Preserve);
  auto* self_batched = unsafeGetBatchedImpl(self);
  auto output_physical = at::clone(self_batched->value(), memory_format);
  auto old_bdims = self_batched->bdims();
  return makeBatched(output_physical, BatchDims(old_bdims.begin(), old_bdims.end()));
}
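
// In user-facing terms, a hedged sketch of what the check above permits
// (`x` and its shape are hypothetical):
// >>> x = torch.randn(B0, 2, 3, 4)
// >>> vmap(lambda x: x.clone(memory_format=torch.contiguous_format))(x)   # ok
// >>> vmap(lambda x: x.clone(memory_format=torch.preserve_format))(x)     # ok
// >>> vmap(lambda x: x.clone(memory_format=torch.channels_last))(x)       # raises the NYI error above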

// Note [Batching rules for matmul-like operators]
// at::matmul doesn't "de-expand" arguments to get better performance (maybe
// it should). In the batching rules for matmul-like operators (dot, mv, mm),
// we should be careful not to expand any unnecessary dimensions. e.g., if
// only one of the two arguments is a BatchedTensor, then we should try
// not to expand batch dimensions onto the other arg.
Tensor mv_batching_rule(const Tensor& self, const Tensor& other) {
  auto self_batched = isBatchedTensor(self);
  auto other_batched = isBatchedTensor(other);

  // A shape checking API would be nice...
  TORCH_CHECK(self.dim() == 2 && other.dim() == 1,
      "mv(self, other): Shape mismatch: expected matrix "
      "(got `self` of size ", self.sizes(), ") ",
      "and vector (got `other` of size ", other.sizes(), ")");

  // See Note [Batching rules for matmul-like operators] for why we have cases
  if (self_batched && !other_batched) {
    auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self);
    auto result = at::matmul(self_physical.tensor(), other);
    return self_physical.getPhysicalToLogicalMap().apply(result);
  }
  if (!self_batched && other_batched) {
    // self_physical: [L, K], other_physical: [..., K]
    // We view the tensors as [L, K], [..., K, 1], perform matmul to get
    // a tensor of size [..., L, 1], and squeeze the last dim.
    auto other_physical = MultiBatchVmapTransform::logicalToPhysical(other);
    auto result = at::matmul(self, other_physical.tensor().unsqueeze(-1));
    return other_physical.getPhysicalToLogicalMap().apply(result.squeeze(-1));
  }
  if (self_batched && other_batched) {
    // self_physical: [..., L, K], other_physical: [..., K]
    // We view the tensors as [..., L, K], [..., K, 1], perform matmul to get
    // a tensor of size [..., L, 1], and squeeze the last dim.
    auto physical_args = MultiBatchVmapTransform::logicalToPhysical({self, other});
    auto result = at::matmul(
        physical_args[0].tensor(),
        physical_args[1].tensor().unsqueeze(-1));
    return physical_args[0].getPhysicalToLogicalMap().apply(result.squeeze(-1));
  }
  TORCH_INTERNAL_ASSERT(false, "either self or other must be a BatchedTensor");
}
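
// Shape sketch for the rule above (B0, L, K stand for hypothetical sizes):
// >>> A = torch.randn(B0, L, K)                       # batched matrix
// >>> v = torch.randn(K)                              # unbatched vector
// >>> vmap(torch.mv, in_dims=(0, None))(A, v).shape   # per-sample A[i] @ v
// torch.Size([B0, L])
// Only `A` carries a batch dim, so the first branch above runs a single
// at::matmul([B0, L, K], [K]) without expanding `v`.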

Tensor _make_dual_batching_rule(
    c10::DispatchKeySet ks,
    const Tensor& primal,
    const Tensor& tangent,
    int64_t level
) {
  DispatchKeySet after_batched_keyset =
      DispatchKeySet(DispatchKeySet::FULL_AFTER, c10::DispatchKey::Batched);
  return at::redispatch::_make_dual(ks & after_batched_keyset, primal, tangent, level);
}

Tensor dot_batching_rule(const Tensor& self, const Tensor& other) {
  auto self_batched = isBatchedTensor(self);
  auto other_batched = isBatchedTensor(other);

  TORCH_CHECK(/*logical*/self.dim() == 1 && /*logical*/other.dim() == 1,
      "dot(self, other): Shape mismatch: vector "
      "(got `self` of size ", self.sizes(), ") ",
      "and vector (got `other` of size ", other.sizes(), ")");

  // See Note [Batching rules for matmul-like operators] for why we have cases
  if (self_batched && !other_batched) {
    // self_physical: [..., K], other_physical: [K]
    // View the tensors as [..., 1, K] and [K], perform matmul, and squeeze the last dim.
    auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self);
    auto result = at::matmul(self_physical.tensor().unsqueeze(-2), other);
    return self_physical.getPhysicalToLogicalMap().apply(result.squeeze(-1));
  }
  if (!self_batched && other_batched) {
    // self_physical: [K], other_physical: [..., K]
    // View the tensors as [K] and [..., K, 1], perform matmul, and squeeze the last dim.
    auto other_physical = MultiBatchVmapTransform::logicalToPhysical(other);
    auto result = at::matmul(self, other_physical.tensor().unsqueeze(-1));
    return other_physical.getPhysicalToLogicalMap().apply(result.squeeze(-1));
  }
  if (self_batched && other_batched) {
    // self_physical: [..., K], other_physical: [..., K]
    // View the tensors as [..., 1, K] and [..., K, 1], perform matmul, and squeeze the last two dims.
    auto physical_args = MultiBatchVmapTransform::logicalToPhysical({self, other});
    auto result = at::matmul(
        physical_args[0].tensor().unsqueeze(-2),
        physical_args[1].tensor().unsqueeze(-1));
    return physical_args[0].getPhysicalToLogicalMap().apply(result.squeeze(-1).squeeze(-1));
  }
  TORCH_INTERNAL_ASSERT(false, "either self or other must be a BatchedTensor");
}
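
// Shape sketch for the both-batched case (B0 and K are hypothetical sizes):
// >>> u = torch.randn(B0, K)
// >>> w = torch.randn(B0, K)
// >>> vmap(torch.dot)(u, w).shape   # per-sample u[i] @ w[i]
// torch.Size([B0])
// Physically this is matmul([B0, 1, K], [B0, K, 1]) -> [B0, 1, 1], and the two
// trailing singleton dims are squeezed away before mapping back to logical space.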

Tensor bmm_batching_rule(const Tensor& self, const Tensor& other) {
  TORCH_CHECK(/*logical*/self.dim() == 3 && /*logical*/other.dim() == 3,
      "bmm(self, other): Shape mismatch: expected 3D `self` "
      "(got `self` of size ", self.sizes(), ") ",
      "and 3D `other` (got `other` of size ", other.sizes(), ")");

  auto physical_args = BroadcastingVmapTransform::logicalToPhysical({self, other});
  auto result = at::matmul(physical_args[0].tensor(), physical_args[1].tensor());
  return physical_args[0].getPhysicalToLogicalMap().apply(result);
}

Tensor mm_batching_rule(const Tensor& self, const Tensor& other) {
  auto self_batched = isBatchedTensor(self);
  auto other_batched = isBatchedTensor(other);

  TORCH_CHECK(/*logical*/self.dim() == 2 && /*logical*/other.dim() == 2,
      "mm(self, other): Shape mismatch: expected matrix "
      "(got `self` of size ", self.sizes(), ") ",
      "and matrix (got `other` of size ", other.sizes(), ")");

  // See Note [Batching rules for matmul-like operators] for why we have cases
  if (self_batched && !other_batched) {
    auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self);
    auto result = at::matmul(self_physical.tensor(), other);
    return self_physical.getPhysicalToLogicalMap().apply(result);
  }
  if (!self_batched && other_batched) {
    auto other_physical = MultiBatchVmapTransform::logicalToPhysical(other);
    auto result = at::matmul(self, other_physical.tensor());
    return other_physical.getPhysicalToLogicalMap().apply(result);
  }
  if (self_batched && other_batched) {
    auto physical_args = MultiBatchVmapTransform::logicalToPhysical({self, other});
    auto result = at::matmul(physical_args[0].tensor(), physical_args[1].tensor());
    return physical_args[0].getPhysicalToLogicalMap().apply(result);
  }
  TORCH_INTERNAL_ASSERT(false, "either self or other must be a BatchedTensor");
}

Tensor cat_batching_rule(const ITensorListRef& tensors, int64_t dim) {
  auto physical_views = MultiBatchVmapTransform::logicalToPhysical(tensors);
  auto physical_tensors = fmap(
      physical_views, [](const VmapPhysicalView& view) -> Tensor { return view.tensor(); });
  TORCH_INTERNAL_ASSERT(
      !tensors.empty(), "The dispatcher should not have dispatched here otherwise.");
  auto result = at::cat(physical_tensors, physical_views[0].getPhysicalDim(dim));
  return physical_views[0].getPhysicalToLogicalMap().apply(result);
}

Tensor stack_batching_rule(TensorList tensors, int64_t dim) {
  auto physical_views = MultiBatchVmapTransform::logicalToPhysical(tensors);
  auto physical_tensors = fmap(
      physical_views, [](const VmapPhysicalView& view) -> Tensor { return view.tensor(); });
  TORCH_INTERNAL_ASSERT(
      !tensors.empty(), "The dispatcher should not have dispatched here otherwise.");
  // NB: stack wraps the dimensionality to (logical dim + 1), so we have to
  // manually handle that here.
  auto dim_physical =
      physical_views[0].numBatchDims() + maybe_wrap_dim(dim, /*logical*/tensors[0].dim() + 1);
  auto result = at::stack(physical_tensors, dim_physical);
  return physical_views[0].getPhysicalToLogicalMap().apply(result);
}
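
// Example of the dim wrapping above (B0 is a hypothetical batch size): stacking
// logical 2-D samples along dim=-1 wraps to logical dim 2 (there are
// tensors[0].dim() + 1 valid positions), and with one batch dim at the front
// the physical stack dim is numBatchDims() + 2 = 3:
// >>> xs = [torch.randn(B0, 4, 5), torch.randn(B0, 4, 5)]
// >>> vmap(lambda a, b: torch.stack([a, b], dim=-1))(*xs).shape
// torch.Size([B0, 4, 5, 2])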

// I am quite sad that we need to register operators with exploded TensorOptions,
// even though the native:: implementations can use TensorOptions&.
// This also makes it hard to metaprogram: i.e., we can't use
// unwrap_and_call<..., at::to> because at::to takes TensorOptions& (!!)
Tensor to_dtype_layout_batching_rule(
    const Tensor& self,
    std::optional<ScalarType> dtype,
    std::optional<Layout> layout,
    std::optional<Device> device,
    std::optional<bool> pin_memory,
    bool non_blocking, bool copy,
    std::optional<MemoryFormat> memory_format) {
  auto options = TensorOptions()
      .dtype(dtype)
      .layout(layout)
      .device(device)
      .pinned_memory(pin_memory);
  auto* input_batched = unsafeGetBatchedImpl(self);
  auto output_physical = input_batched->value().to(options, non_blocking, copy, memory_format);
  auto old_bdims = input_batched->bdims();
  return makeBatched(output_physical, BatchDims(old_bdims.begin(), old_bdims.end()));
}

Tensor new_zeros_batching_rule(
    const Tensor& self,
    IntArrayRef size,
    std::optional<ScalarType> dtype,
    std::optional<Layout> layout,
    std::optional<Device> device,
    std::optional<bool> pin_memory) {
  auto physical_view = MultiBatchVmapTransform::logicalToPhysical(self);
  auto physical_size = physical_view.getPhysicalShape(size);
  auto options = TensorOptions()
      .dtype(dtype)
      .layout(layout)
      .device(device)
      .pinned_memory(pin_memory);
  auto result = physical_view.tensor().new_zeros(physical_size, options);
  return physical_view.getPhysicalToLogicalMap().apply(result);
}
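
// Example of the shape plumbing above (B0 is a hypothetical batch size): the
// user asks for a per-sample [3, 4] tensor, so we allocate one physical
// [B0, 3, 4] tensor and hand each logical sample a view of its slice:
// >>> x = torch.randn(B0, 2)
// >>> vmap(lambda x: x.new_zeros(3, 4))(x).shape
// torch.Size([B0, 3, 4])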

Tensor new_empty_batching_rule(
    const Tensor& self,
    IntArrayRef size,
    std::optional<ScalarType> dtype,
    std::optional<Layout> layout,
    std::optional<Device> device,
    std::optional<bool> pin_memory) {
  auto physical_view = MultiBatchVmapTransform::logicalToPhysical(self);
  auto physical_size = physical_view.getPhysicalShape(size);
  auto result = physical_view.tensor().new_empty(physical_size, TensorOptions().dtype(dtype).layout(layout).device(device).pinned_memory(pin_memory));
  return physical_view.getPhysicalToLogicalMap().apply(result);
}

Tensor new_empty_strided_batching_rule(
    const Tensor& self,
    IntArrayRef size,
    IntArrayRef stride,
    std::optional<ScalarType> dtype,
    std::optional<Layout> layout,
    std::optional<Device> device,
    std::optional<bool> pin_memory) {
  auto physical_view = MultiBatchVmapTransform::logicalToPhysical(self);
  auto physical_size = physical_view.getPhysicalShape(size);

  // Let [B0, B1, B2] be the shape of the batch dims. We're going to create
  // the batch dimensions at the front of the tensor (in memory layout),
  // irrespective of whether or not they are actually at the front (in memory layout)
  // in the original `self` tensor. This is because when a user calls
  // `new_empty_strided` in general, the `strides` they provide are for a new
  // tensor and have no relation to the strides of the original tensor.
  //
  // So, the physical shape of the result should be ([B0, B1, B2] + size),
  // but what about the physical strides?
  //
  // We're actually free to pick whatever stride we want:
  // e.g., for size=[5, 3], stride=[0, 1], we could decide to
  // use
  // - physical size: [B0, B1, B2, 5, 3]
  // - physical stride: [9999*B1*B2, 9999*B2, 9999, 0, 1]
  //
  // Let's select some reasonable strides such that:
  // - The batch dims are "contiguous" with respect to each other
  // - if empty_strided(size, stride) would have created a contiguous Tensor,
  //   then this new physical Tensor (with batch dims) is also contiguous
  //
  // Let S be the size of the storage if one were to construct a tensor
  // with `size` and `stride` via empty_strided(size, stride).
  // Then the physical sizes/strides should be:
  // - physical size: [B0, B1, B2, 5, 3]
  // - physical stride: [B1 * B2 * S, B2 * S, S, 0, 1]
  auto batch_shape = IntArrayRef(
      physical_view.tensor().sizes().begin(), physical_view.numBatchDims());

  // physical_strides = [B1 * B2 * S, B2 * S, S]
  auto physical_strides = at::detail::defaultStrides(batch_shape);
  TORCH_CHECK(size.size() == stride.size(),
      "new_empty_strided(sizes, strides): dimensionality of sizes (",
      size.size(), ") must match dimensionality of strides (",
      stride.size(), ")");
  auto storage_size = native::storage_size_for(size, stride);
  for (auto& physical_stride : physical_strides) {
    physical_stride *= storage_size;
  }

  // physical_strides = [B1 * B2 * S, B2 * S, S] + strides
  physical_strides.insert(physical_strides.end(), stride.begin(), stride.end());

  auto result = physical_view.tensor().new_empty_strided(
      physical_size, physical_strides, dtype, layout, device, pin_memory);
  return physical_view.getPhysicalToLogicalMap().apply(result);
}
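
// Worked instance of the stride construction above (hypothetical numbers):
// with one batch dim [B0], size = [5, 3], stride = [3, 1], the per-sample
// storage size is S = storage_size_for([5, 3], [3, 1]) = 1 + 4*3 + 2*1 = 15,
// so defaultStrides([B0]) = [1] is scaled to [15] and the final allocation is
//   physical size:   [B0, 5, 3]
//   physical stride: [15, 3, 1]
// i.e. a contiguous tensor, matching the "stays contiguous" goal stated above.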

template <typename F, F Func>
Tensor comparison_pointwise_batching_rule(const Tensor& self, const Tensor& other) {
  auto physical_args = BroadcastingVmapTransform::logicalToPhysical({self, other});
  auto result = Func(physical_args[0].tensor(), physical_args[1].tensor());
  return physical_args[0].getPhysicalToLogicalMap().apply(result);
}
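
// A sketch of how this template is meant to be instantiated when registering a
// comparison op (the line below is illustrative, not copied from the
// registration block in this file):
//   m.impl("eq.Tensor",
//          comparison_pointwise_batching_rule<TensorTensorType, at::eq>);
// BroadcastingVmapTransform aligns both inputs to a common physical shape, so
// the comparison broadcasts per sample exactly as it would outside of vmap.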
1074*da0073e9SAndroid Build Coastguard Worker }
TORCH_LIBRARY_IMPL(_,Batched,m)1075*da0073e9SAndroid Build Coastguard Worker TORCH_LIBRARY_IMPL(_, Batched, m) {
1076*da0073e9SAndroid Build Coastguard Worker m.fallback(torch::CppFunction::makeFromBoxedFunction<&batchedTensorForLoopFallback>());
1077*da0073e9SAndroid Build Coastguard Worker }
1078*da0073e9SAndroid Build Coastguard Worker
TORCH_LIBRARY_IMPL(aten,Batched,m)1079*da0073e9SAndroid Build Coastguard Worker TORCH_LIBRARY_IMPL(aten, Batched, m) {
1080*da0073e9SAndroid Build Coastguard Worker // NB: Ideally we would like some operators, like size.int, to "fallthrough"
1081*da0073e9SAndroid Build Coastguard Worker // to the underlying implementation. However, because a BatchedTensor is a
1082*da0073e9SAndroid Build Coastguard Worker // Tensor wrapper, it only has one dispatch key (Batched) on it. The resolution
1083*da0073e9SAndroid Build Coastguard Worker // here is to just directly call the underlying implementation.
1084*da0073e9SAndroid Build Coastguard Worker m.impl("size.int", static_cast<int64_t (*)(const Tensor&, int64_t)>(native::size));
1085*da0073e9SAndroid Build Coastguard Worker m.impl("_add_batch_dim", native::_add_batch_dim);
1086*da0073e9SAndroid Build Coastguard Worker m.impl("_remove_batch_dim", native::_remove_batch_dim);
1087*da0073e9SAndroid Build Coastguard Worker m.impl("_make_dual", _make_dual_batching_rule);
1088*da0073e9SAndroid Build Coastguard Worker m.impl("_has_same_storage_numel", _has_same_storage_numel_batching_rule);
1089*da0073e9SAndroid Build Coastguard Worker m.impl("is_same_size", native::is_same_size);
1090*da0073e9SAndroid Build Coastguard Worker m.impl("_new_zeros_with_same_feature_meta", _new_zeros_with_same_feature_meta_batching_rule);
1091*da0073e9SAndroid Build Coastguard Worker
1092*da0073e9SAndroid Build Coastguard Worker m.impl("sum.dim_IntList", sum_batching_rule);
1093*da0073e9SAndroid Build Coastguard Worker m.impl("is_complex", native::is_complex);
1094*da0073e9SAndroid Build Coastguard Worker
1095*da0073e9SAndroid Build Coastguard Worker // inplace operations
1096*da0073e9SAndroid Build Coastguard Worker m.impl("fill_.Scalar", fill_inplace_scalar_batching_rule);
1097*da0073e9SAndroid Build Coastguard Worker m.impl("fill_.Tensor", fill_inplace_tensor_batching_rule);
1098*da0073e9SAndroid Build Coastguard Worker m.impl("zero_", zero_inplace_batching_rule);
1099*da0073e9SAndroid Build Coastguard Worker
1100*da0073e9SAndroid Build Coastguard Worker // view operations
1101*da0073e9SAndroid Build Coastguard Worker m.impl("as_strided", as_strided_batching_rule);
1102*da0073e9SAndroid Build Coastguard Worker m.impl("chunk", chunk_batching_rule);
1103*da0073e9SAndroid Build Coastguard Worker m.impl("tensor_split.sections", tensor_split_sections_batching_rule);
1104*da0073e9SAndroid Build Coastguard Worker m.impl("tensor_split.indices", tensor_split_indices_batching_rule);
1105*da0073e9SAndroid Build Coastguard Worker m.impl("diagonal", diagonal_batching_rule);
1106*da0073e9SAndroid Build Coastguard Worker m.impl("expand", expand_batching_rule);
1107*da0073e9SAndroid Build Coastguard Worker m.impl("expand_as", native::expand_as); // composite wrt autograd
1108*da0073e9SAndroid Build Coastguard Worker m.impl("movedim.intlist", movedim_batching_rule);
1109*da0073e9SAndroid Build Coastguard Worker m.impl("movedim.int", static_cast<Tensor(*)(const Tensor&,int64_t,int64_t)>(native::movedim)); // composite wrt autograd
1110*da0073e9SAndroid Build Coastguard Worker // There is another variant of narrow. However, we don't
1111*da0073e9SAndroid Build Coastguard Worker // want to support the other variant yet bc it isn't documented...
1112*da0073e9SAndroid Build Coastguard Worker m.impl("narrow", native::narrow_symint); // composite wrt autograd
1113*da0073e9SAndroid Build Coastguard Worker m.impl("numpy_T", native::numpy_T); // composite wrt autograd
1114*da0073e9SAndroid Build Coastguard Worker m.impl("matrix_H", native::matrix_H); // composite wrt autograd
1115*da0073e9SAndroid Build Coastguard Worker m.impl("mT", native::mT); // composite wrt autograd
1116*da0073e9SAndroid Build Coastguard Worker m.impl("mH", native::mH); // composite wrt autograd
1117*da0073e9SAndroid Build Coastguard Worker m.impl("permute", permute_batching_rule);
1118*da0073e9SAndroid Build Coastguard Worker m.impl("reshape", reshape_batching_rule);
1119*da0073e9SAndroid Build Coastguard Worker m.impl("_reshape_alias", _reshape_alias_batching_rule);
1120*da0073e9SAndroid Build Coastguard Worker m.impl("reshape_as", native::reshape_as); // composite wrt autograd
1121*da0073e9SAndroid Build Coastguard Worker m.impl("select.int", select_batching_rule);
1122*da0073e9SAndroid Build Coastguard Worker m.impl("slice.Tensor", slice_batching_rule);
1123*da0073e9SAndroid Build Coastguard Worker m.impl("split.Tensor", split_batching_rule);
1124*da0073e9SAndroid Build Coastguard Worker m.impl("split.sizes", split_with_sizes_batching_rule);
1125*da0073e9SAndroid Build Coastguard Worker m.impl("split_with_sizes", split_with_sizes_batching_rule);
1126*da0073e9SAndroid Build Coastguard Worker m.impl("squeeze", squeeze_batching_rule);
1127*da0073e9SAndroid Build Coastguard Worker m.impl("squeeze.dim", squeeze_dim_batching_rule);
1128*da0073e9SAndroid Build Coastguard Worker m.impl("squeeze.dims", squeeze_dims_batching_rule);
1129*da0073e9SAndroid Build Coastguard Worker m.impl("t", native::t); // composite wrt autograd
1130*da0073e9SAndroid Build Coastguard Worker m.impl("trace", trace_batching_rule);
1131*da0073e9SAndroid Build Coastguard Worker m.impl("transpose.int", transpose_int_batching_rule);
1132*da0073e9SAndroid Build Coastguard Worker m.impl("unbind.int", unbind_batching_rule);
1133*da0073e9SAndroid Build Coastguard Worker m.impl("unfold", unfold_batching_rule);
1134*da0073e9SAndroid Build Coastguard Worker m.impl("unsqueeze", unsqueeze_batching_rule);
1135*da0073e9SAndroid Build Coastguard Worker m.impl("view", view_batching_rule);
1136*da0073e9SAndroid Build Coastguard Worker m.impl("view_as", native::view_as); // composite wrt autograd
1137*da0073e9SAndroid Build Coastguard Worker
1138*da0073e9SAndroid Build Coastguard Worker // clamp operations
1139*da0073e9SAndroid Build Coastguard Worker m.impl("clamp", clamp_batching_rule);
1140*da0073e9SAndroid Build Coastguard Worker m.impl("clamp_min", clamp_min_batching_rule);
1141*da0073e9SAndroid Build Coastguard Worker m.impl("clamp_max", clamp_max_batching_rule);
1142*da0073e9SAndroid Build Coastguard Worker
1143*da0073e9SAndroid Build Coastguard Worker // unary pointwise, out-of-place, no additional arguments.
1144*da0073e9SAndroid Build Coastguard Worker #define UNARY_POINTWISE(op) m.impl(#op, \
1145*da0073e9SAndroid Build Coastguard Worker unwrap_and_call<Tensor (*)(const Tensor&), at::op>);
1146*da0073e9SAndroid Build Coastguard Worker UNARY_POINTWISE(abs);
1147*da0073e9SAndroid Build Coastguard Worker UNARY_POINTWISE(acos);
1148*da0073e9SAndroid Build Coastguard Worker UNARY_POINTWISE(asin);
1149*da0073e9SAndroid Build Coastguard Worker UNARY_POINTWISE(atan);
1150*da0073e9SAndroid Build Coastguard Worker UNARY_POINTWISE(ceil);
1151*da0073e9SAndroid Build Coastguard Worker UNARY_POINTWISE(cos);
1152*da0073e9SAndroid Build Coastguard Worker UNARY_POINTWISE(cosh);
1153*da0073e9SAndroid Build Coastguard Worker UNARY_POINTWISE(conj_physical);
1154*da0073e9SAndroid Build Coastguard Worker UNARY_POINTWISE(digamma);
1155*da0073e9SAndroid Build Coastguard Worker UNARY_POINTWISE(exp);
1156*da0073e9SAndroid Build Coastguard Worker UNARY_POINTWISE(expm1);
1157*da0073e9SAndroid Build Coastguard Worker UNARY_POINTWISE(floor);
1158*da0073e9SAndroid Build Coastguard Worker UNARY_POINTWISE(frac);
1159*da0073e9SAndroid Build Coastguard Worker UNARY_POINTWISE(lgamma);
1160*da0073e9SAndroid Build Coastguard Worker UNARY_POINTWISE(log);
1161*da0073e9SAndroid Build Coastguard Worker UNARY_POINTWISE(log10);
1162*da0073e9SAndroid Build Coastguard Worker UNARY_POINTWISE(log1p);
1163*da0073e9SAndroid Build Coastguard Worker UNARY_POINTWISE(log2);
1164*da0073e9SAndroid Build Coastguard Worker UNARY_POINTWISE(neg);
1165*da0073e9SAndroid Build Coastguard Worker UNARY_POINTWISE(reciprocal);
1166*da0073e9SAndroid Build Coastguard Worker UNARY_POINTWISE(relu);
1167*da0073e9SAndroid Build Coastguard Worker UNARY_POINTWISE(round);
1168*da0073e9SAndroid Build Coastguard Worker UNARY_POINTWISE(rsqrt);
1169*da0073e9SAndroid Build Coastguard Worker UNARY_POINTWISE(sigmoid);
1170*da0073e9SAndroid Build Coastguard Worker UNARY_POINTWISE(sign);
1171*da0073e9SAndroid Build Coastguard Worker UNARY_POINTWISE(sin);
1172*da0073e9SAndroid Build Coastguard Worker UNARY_POINTWISE(sinh);
1173*da0073e9SAndroid Build Coastguard Worker UNARY_POINTWISE(sqrt);
1174*da0073e9SAndroid Build Coastguard Worker UNARY_POINTWISE(tan);
1175*da0073e9SAndroid Build Coastguard Worker UNARY_POINTWISE(tanh);
1176*da0073e9SAndroid Build Coastguard Worker UNARY_POINTWISE(trunc);
1177*da0073e9SAndroid Build Coastguard Worker #undef UNARY_POINTWISE
1178*da0073e9SAndroid Build Coastguard Worker #define TO_BATCHING_RULE(name, ...) \
1179*da0073e9SAndroid Build Coastguard Worker { \
1180*da0073e9SAndroid Build Coastguard Worker using to_type = Tensor(Tensor::*)(__VA_ARGS__) const; \
1181*da0073e9SAndroid Build Coastguard Worker m.impl(name, unwrap_and_call_method< \
1182*da0073e9SAndroid Build Coastguard Worker to_type, &Tensor::to, __VA_ARGS__>);\
1183*da0073e9SAndroid Build Coastguard Worker }
  TO_BATCHING_RULE("to.device", Device, ScalarType, bool, bool, std::optional<MemoryFormat>)
  TO_BATCHING_RULE("to.dtype", ScalarType, bool, bool, std::optional<MemoryFormat>)
  TO_BATCHING_RULE("to.other", const Tensor&, bool, bool, std::optional<MemoryFormat>)
  m.impl("to.dtype_layout", to_dtype_layout_batching_rule);
#undef TO_BATCHING_RULE
  m.impl("clone", clone_batching_rule);

  using TensorTensorScalarType = Tensor (*)(const Tensor&, const Tensor&, const Scalar&);
  using TensorTensorType = Tensor (*)(const Tensor&, const Tensor&);
  using TensorScalarType = Tensor (*)(const Tensor&, const Scalar&);

#define BINARY_POINTWISE(op) \
  m.impl(#op".Tensor", binary_pointwise_batching_rule<TensorTensorType, at::op>); \
  m.impl(#op".Scalar", unwrap_and_call<TensorScalarType, at::op, const Scalar&>);
#define BINARY_POINTWISE_VA(op, ...) \
  { \
    using Binop = Tensor (*)(const Tensor&, const Tensor&, __VA_ARGS__); \
    using Unop = Tensor (*)(const Tensor&, const Scalar&, __VA_ARGS__); \
    m.impl(#op".Tensor", binary_pointwise_batching_rule<Binop, at::op, __VA_ARGS__>); \
    m.impl(#op".Scalar", unwrap_and_call<Unop, at::op, const Scalar&, __VA_ARGS__>); \
  }
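  // Illustrative note (approximate expansion): BINARY_POINTWISE_VA(add, const Scalar&)
  // registers both overloads of at::add, roughly as
  //   {
  //     using Binop = Tensor (*)(const Tensor&, const Tensor&, const Scalar&);
  //     using Unop = Tensor (*)(const Tensor&, const Scalar&, const Scalar&);
  //     m.impl("add.Tensor", binary_pointwise_batching_rule<Binop, at::add, const Scalar&>);
  //     m.impl("add.Scalar", unwrap_and_call<Unop, at::add, const Scalar&, const Scalar&>);
  //   }
  // The Tensor overload goes through the pointwise batching rule, which aligns
  // the batch dimensions of both inputs; the Scalar overload only needs to
  // unwrap its single tensor argument.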

  BINARY_POINTWISE_VA(add, const Scalar&);
  BINARY_POINTWISE_VA(sub, const Scalar&);
  BINARY_POINTWISE_VA(rsub, const Scalar&);
  BINARY_POINTWISE(mul);
  BINARY_POINTWISE(div);
  {
    using Binop = Tensor (*)(const Tensor&, const Tensor&, std::optional<c10::string_view>);
    using Unop = Tensor (*)(const Tensor&, const Scalar&, std::optional<c10::string_view>);
    m.impl("div.Tensor_mode", binary_pointwise_batching_rule<Binop, at::div, std::optional<c10::string_view>>);
    m.impl("div.Scalar_mode", unwrap_and_call<Unop, at::div, const Scalar&, std::optional<c10::string_view>>);
  }

  // at::pow has three out-of-place overloads.
  m.impl("pow.Tensor_Tensor", binary_pointwise_batching_rule<TensorTensorType, at::pow>);
  m.impl("pow.Tensor_Scalar", unwrap_and_call<TensorScalarType, at::pow, const Scalar&>);
  m.impl("pow.Scalar", pow_scalar_Tensor_batching_rule);
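  // Note (added for clarity): pow.Tensor_Tensor takes a tensor base and tensor
  // exponent, pow.Tensor_Scalar a tensor base and scalar exponent, and
  // pow.Scalar a scalar base with a tensor exponent. The last overload gets a
  // dedicated batching rule because the batched tensor is the second argument,
  // so the generic unwrap helpers used above do not apply.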

  m.impl("sigmoid_backward", binary_pointwise_batching_rule<TensorTensorType, at::sigmoid_backward>);
  m.impl(
      "threshold_backward",
      binary_pointwise_batching_rule<
          TensorTensorScalarType,
          at::threshold_backward,
          const Scalar&>);

  // For at::result_type, call the native::result_type implementation.
  // We don't have to do anything special because native::result_type operates
  // on the logical shape of the tensors.
  m.impl("result_type.Tensor", static_cast<ScalarType (*)(const Tensor&, const Tensor&)>(native::result_type));
  m.impl("result_type.Scalar", static_cast<ScalarType (*)(const Tensor&, const Scalar&)>(native::result_type));
  m.impl("result_type.Scalar_Tensor", static_cast<ScalarType (*)(const Scalar&, const Tensor&)>(native::result_type));
  m.impl("result_type.Scalar_Scalar", static_cast<ScalarType (*)(const Scalar&, const Scalar&)>(native::result_type));

#undef BINARY_POINTWISE_VA
#undef BINARY_POINTWISE

#define TRIVIAL_OP(op) m.impl(#op, \
    unwrap_and_call<Tensor (*)(const Tensor&), at::op>);
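  // Illustrative note (approximate expansion): TRIVIAL_OP(imag) expands to
  //   m.impl("imag", unwrap_and_call<Tensor (*)(const Tensor&), at::imag>);
  // i.e. the op is forwarded directly to the unwrapped physical tensor; no
  // logical-to-physical dimension remapping is needed for these ops.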
  // complex number view operators
  TRIVIAL_OP(imag);
  TRIVIAL_OP(real);
  TRIVIAL_OP(view_as_real);
  TRIVIAL_OP(conj);
  TRIVIAL_OP(_conj);
  TRIVIAL_OP(resolve_conj);
  TRIVIAL_OP(resolve_neg);
  m.impl("view_as_complex", view_as_complex_batching_rule);
#undef TRIVIAL_OP

  // matmul-like operators
  m.impl("mv", mv_batching_rule);
  m.impl("dot", dot_batching_rule);
  m.impl("bmm", bmm_batching_rule);
  m.impl("mm", mm_batching_rule);

  // cat/stack
  m.impl("cat", cat_batching_rule);
  m.impl("stack", stack_batching_rule);

  // backward operators
  m.impl("select_backward", select_backward_batching_rule);
  m.impl("slice_backward", slice_backward_batching_rule);
  m.impl("trace_backward", trace_backward_batching_rule);
  m.impl("diagonal_backward", diagonal_backward_batching_rule);

  // Tensor.new_* operators
  m.impl("new_empty", new_empty_batching_rule);
  m.impl("new_empty_strided", new_empty_strided_batching_rule);
  m.impl("new_zeros", new_zeros_batching_rule);

  m.impl("contiguous", contiguous_batching_rule);

  // Comparison ops
#define COMPARISON_POINTWISE(op) \
  m.impl(#op".Tensor", comparison_pointwise_batching_rule<TensorTensorType, at::op>); \
  m.impl(#op".Scalar", unwrap_and_call<TensorScalarType, at::op, const Scalar&>);
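  // Illustrative note (approximate expansion): COMPARISON_POINTWISE(eq) registers
  //   m.impl("eq.Tensor", comparison_pointwise_batching_rule<TensorTensorType, at::eq>);
  //   m.impl("eq.Scalar", unwrap_and_call<TensorScalarType, at::eq, const Scalar&>);
  // A separate comparison_pointwise_batching_rule is used here, presumably
  // because comparison ops produce bool outputs rather than the promoted input
  // dtype that the generic binary pointwise rule assumes.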

  COMPARISON_POINTWISE(eq);
  COMPARISON_POINTWISE(gt);
  COMPARISON_POINTWISE(ge);
  COMPARISON_POINTWISE(le);
  COMPARISON_POINTWISE(lt);
  COMPARISON_POINTWISE(ne);

#undef COMPARISON_POINTWISE
}

} // namespace at