#include <ATen/Context.h>
#include <ATen/LegacyBatchedFallback.h>
#include <ATen/MatrixRef.h>
#include <ATen/LegacyVmapTransforms.h>
#include <ATen/core/dispatch/Dispatcher.h>
#include <c10/util/accumulate.h>
#include <c10/util/llvmMathExtras.h>
#include <c10/util/irange.h>

namespace at {

// Given a linear index, return the corresponding multi-dimensional index.
// Example: Given linear_idx = 3, sizes = [5, 2], we would return [1, 1]
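// A brief trace of that example: iterating sizes in reverse, 3 % 2 = 1 (push 1),
// then (3 - 1) / 2 = 1 and 1 % 5 = 1 (push 1); reversing the result gives [1, 1].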
static SmallVector<indexing::TensorIndex,kVmapStaticDimVecSize>
computeIndex(int64_t linear_idx, IntArrayRef sizes) {
  SmallVector<indexing::TensorIndex,kVmapStaticDimVecSize> result;
  result.reserve(sizes.size());
  for (auto it = sizes.rbegin(); it != sizes.rend(); it++) {
    auto remainder = linear_idx % *it;
    result.push_back(remainder);
    linear_idx -= remainder;
    linear_idx /= *it;
  }
  std::reverse(std::begin(result), std::end(result));
  return result;
}

static bool areAllReturnsTensors(const FunctionSchema& schema) {
  return std::all_of(
      schema.returns().begin(),
      schema.returns().end(),
      [] (const Argument& arg) { return arg.type() == TensorType::get(); });
}

static bool areAnyArgumentsTensorList(const FunctionSchema& schema) {
  return std::any_of(
      schema.arguments().begin(),
      schema.arguments().end(),
      [] (const Argument& arg) { return arg.type()->isSubtypeOf(*ListType::ofTensors()); });
}

// Returns true if an operator is in-place. An operator is in-place if:
// 1. The first argument is a Tensor and it is being written to
// 2. The first argument is being returned
// 3. No other arguments are aliased
// Here is an example of an in-place operator:
// add_(Tensor(a!) self, Tensor other, *, Scalar alpha=1) -> Tensor(a!)
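// For contrast, an out= variant such as
// add.out(Tensor self, Tensor other, *, Scalar alpha=1, Tensor(a!) out) -> Tensor(a!)
// is mutable but does not count as in-place by this definition, because the
// written-to tensor (`out`) is not the first argument.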
static bool isInplaceOp(const c10::FunctionSchema& schema) {
  if (!schema.is_mutable() || schema.returns().size() != 1) {
    return false;
  }
  // Check that the first argument is being written to
  const AliasInfo* first_arg_alias_info = schema.arguments().begin()->alias_info();
  if (!first_arg_alias_info || !first_arg_alias_info->isWrite()) {
    return false;
  }
  // Check that none of the other args are being aliased
  for (auto it = schema.arguments().begin() + 1; it != schema.arguments().end(); ++it) {
    const AliasInfo* alias_info = it->alias_info();
    if (alias_info) {
      return false;
    }
  }
  // Check that the first tensor is being returned (i.e., output has a (a!))
  const AliasInfo* return_alias_info = schema.returns()[0].alias_info();
  return return_alias_info && return_alias_info->isWrite();
}

static void warnFallback(const c10::FunctionSchema& schema) {
  if (!globalContext().areVmapFallbackWarningsEnabled()) {
    return;
  }
  TORCH_WARN("There is a performance drop because we have not yet implemented ",
             "the batching rule for ", schema.operator_name(), ". ",
             "You are using the legacy vmap prototype (torch._vmap_internals.vmap). ",
             "If you are using torch.autograd.functional.{jacobian, hessian} ",
             "or torch._vmap_internals.vmap: please switch to using ",
             "torch.func.{jacrev, jacfwd, hessian} and/or torch.vmap instead ",
             "for better operator coverage and performance improvements.");
}

// The general flow of the algorithm is as follows.
// - First, we figure out which arguments are BatchedTensors and save them
//   to a vector. We also store a vector of which index of the arguments list
//   each BatchedTensor appears in. This will be useful for bookkeeping later.
// - Next, we apply the MultiBatchVmapTransform to all of the BatchedTensors.
//   This returns a vector of VmapPhysicalViews that hold tensors containing
//   all of the collective batch dimensions at the front of the tensors.
// - Then, we attempt to call `op` once per slice of the inputs. To do this,
//   we repeatedly slice the input arguments (if they are BatchedTensors),
//   put the sliced (or unsliced) version of the input onto the stack, invoke
//   the operator, and then pop the results off the stack.
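// Conceptually, for an in-place op the fallback runs the op once per example:
// e.g. vmap(Tensor.add_)(x, y) with x and y of shape [B, 3] is (roughly) carried
// out as `for i in range(B): x[i].add_(y[i])`, after which x is returned.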
static void batchedTensorInplaceForLoopFallback(const c10::OperatorHandle& op, torch::jit::Stack* stack) {
  const auto& schema = op.schema();
  warnFallback(schema);

  const auto num_arguments = static_cast<int64_t>(schema.arguments().size());
  const auto arguments = torch::jit::last(stack, num_arguments);
  const auto arguments_begin = stack->size() - num_arguments;

  // `self` is the Tensor being modified in-place
  Tensor self = arguments[0].toTensor();
  const auto* self_impl = maybeGetBatchedImpl(self);
  std::bitset<kVmapMaxTensorDims> self_vmap_levels;
  if (self_impl) {
    self_vmap_levels = createVmapLevelsBitset(self_impl->bdims());
  }

  // Figure out which arguments are BatchedTensor. Save them to a vector.
  // For each BatchedTensor, also record what position of `arguments` they came from.
  SmallVector<Tensor,kVmapTransformStaticInputSize> batched_tensor_inputs;
  VmapDimVector batched_tensor_inputs_position;
  for (const auto idx : c10::irange(arguments.size())) {
    const auto& ivalue = arguments[idx];
    if (!ivalue.isTensor()) {
      continue;
    }
    const auto& tensor = ivalue.toTensor();
    if (!tensor.defined()) {
      continue;
    }
    const auto* batched = maybeGetBatchedImpl(tensor);
    if (!batched) {
      continue;
    }

    // NOTE: [vmap-incompatible in-place operations]
    // In-place operations on `self` are not possible if there exists some vmap
    // level `l` such that `self` is not being vmapped on that level but another
    // argument is. For example, let B0 be a batch dim inside vmap and consider
    // vmap(Tensor.add_, in_dims=(None, 0))(torch.ones(3), torch.ones(B0, 3))
    // - self is torch.ones(3) and does not participate in this vmap
    // - other is BatchedTensor(torch.ones(B0, 3))
    // There's no way to do self.add_(other) because `other` has more elements
    // than `self` due to being vmapped over.
    //
    // In the vmap fallback, we should error out when we detect this.
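    // As a concrete illustration of the check below: if self_vmap_levels = {1}
    // and other_vmap_levels = {1, 2}, the union {1, 2} differs from
    // self_vmap_levels, additional_bdims = {2}, and level 2 is reported.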
    auto other_vmap_levels = createVmapLevelsBitset(batched->bdims());
    if (self_vmap_levels != (self_vmap_levels | other_vmap_levels)) {
      // Find one vmap level to complain about
      auto additional_bdims = (self_vmap_levels | other_vmap_levels) ^ self_vmap_levels;
      [[maybe_unused]] auto offending_level = llvm::findLastSet(additional_bdims.to_ulong());
      // The following prints out "vmap: aten::add_(tensor, ...) is not possible",
      // but it would be better to print out "tensor.add_(...) is not possible".
      // Afaict there's no official way to get the method name (add_), and there
      // is no way to tell if an operator has method or function variants.
      TORCH_CHECK(false,
        "vmap: ", schema.name(), "(self, *extra_args) is not possible because ",
        "there exists a Tensor `other` in extra_args that has more elements ",
        "than `self`. This happened due to `other` being vmapped over but ",
        "`self` not being vmapped over at level ", offending_level, ". ",
        "Please try to use out-of-place operators instead of ", schema.name(), ". ",
        "If said operator is being called inside the PyTorch framework, ",
        "please file a bug report instead.");
    }
    batched_tensor_inputs.push_back(tensor);
    batched_tensor_inputs_position.push_back(idx);
  }
  TORCH_INTERNAL_ASSERT(!batched_tensor_inputs.empty());

  // MultiBatchVmapTransform the BatchedTensor arguments. This returns
  // VmapPhysicalViews that contain all of the batch dimensions.
  const auto input_physical_views = MultiBatchVmapTransform::logicalToPhysical(
      batched_tensor_inputs);

  // Compute the total number of batches
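  // For example, if the physical view has sizes [2, 3, 5] and there are two
  // batch dims, then batch_sizes = [2, 3] and num_batches = 6.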
  auto num_batch_dims = input_physical_views.front().numBatchDims();
  auto first_physical_view_sizes = input_physical_views.front().tensor().sizes();
  auto batch_sizes = ArrayRef<int64_t>(
      first_physical_view_sizes.begin(), first_physical_view_sizes.begin() + num_batch_dims);
  const auto num_batches = c10::multiply_integers(batch_sizes);
  // Without a shape-checking API, we're unable to compute the correct shape of
  // the output so we just error out.
  TORCH_CHECK(num_batches > 0,
      "Batching rule not implemented for ", schema.operator_name(), ". ",
      "The fallback path does not support vmap over dims of size 0.");

  // Strategy: For each batch, we are going to push slices (where applicable)
  // of the arguments onto `stack`, and call `op`.
  for (const auto linear_idx : c10::irange(num_batches)) {
    auto index = computeIndex(linear_idx, batch_sizes);
    auto batched_tensor_inputs_pos_iter = batched_tensor_inputs_position.begin();
    auto input_physical_views_iter = input_physical_views.begin();
    for (const auto arg_idx : c10::irange(num_arguments)) {
      // We assume that torch::jit::Stack is backed by vector<IValue> for
      // simplicity. When that is not the case, this code should be updated.
      const auto& argument = (*stack)[arguments_begin + arg_idx];
      if (batched_tensor_inputs_pos_iter == batched_tensor_inputs_position.end()
          || arg_idx != *batched_tensor_inputs_pos_iter) {
        // argument isn't a BatchedTensor
        torch::jit::push(stack, argument);
        continue;
      }
      // argument is a BatchedTensor
      TORCH_INTERNAL_ASSERT(input_physical_views_iter != input_physical_views.end());
      const auto& physical_view_for_argument = *input_physical_views_iter;
      torch::jit::push(stack, physical_view_for_argument.tensor().index(index));
      batched_tensor_inputs_pos_iter++;
      input_physical_views_iter++;
    }

    op.callBoxed(stack);
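    // The in-place op's single return aliases the written-to slice; we don't
    // need it on the stack, so drop it and rely on the in-place mutation.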
    torch::jit::drop(stack, 1);
  }

  // Return the tensor that was written to in-place
  torch::jit::drop(stack, num_arguments);
  torch::jit::push(stack, self);
}

static Tensor safeStack(TensorList tensors) {
  auto is_defined = [](const Tensor& t) { return t.defined(); };
  if (std::all_of(tensors.begin(), tensors.end(), is_defined)) {
    return at::stack(tensors);
  }
  // NOTE [vmap through backward and undefined grad]
  // While vmapping through backward functions (to compute batched grad), it
  // is possible for the backward function to return an undefined grad for some
  // grad_input for each example. In that case, we return an undefined grad.
  //
  // It is theoretically possible for *some* of the examples to produce an
  // undefined grad (a kernel could peek at the gradient values and return an
  // undefined tensor if it determines the gradient is full of zeros). We
  // could handle this by treating the undefined grad as a zero-filled tensor
  // of the correct shape while stacking the tensors together. However I expect
  // this to happen very rarely (I have not been able to find an example in our
  // codebase) so we just error out in this case.
  if (std::none_of(tensors.begin(), tensors.end(), is_defined)) {
    return Tensor();
  }
  TORCH_CHECK(false,
      "vmap: slow fallback received a mix of undefined and defined tensors ",
      "as the result of an operation. This is not supported, please file us ",
      "an issue on github.");
}

// The general flow of the algorithm is as follows.
// - First, we figure out which arguments are BatchedTensors and save them
//   to a vector. We also store a vector of which index of the arguments list
//   each BatchedTensor appears in. This will be useful for bookkeeping later.
// - Next, we apply the MultiBatchVmapTransform to all of the BatchedTensors.
//   This returns a vector of VmapPhysicalViews that hold tensors containing
//   all of the collective batch dimensions at the front of the tensors.
// - Then, we attempt to call `op` once per slice of the inputs. To do this,
//   we repeatedly slice the input arguments (if they are BatchedTensors),
//   put the sliced (or unsliced) version of the input onto the stack, invoke
//   the operator, and then pop the results off the stack.
// - Each result obtained from the previous step is a slice of the total result,
//   so we stack those tensors together to form the final result.
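// Conceptually, for an out-of-place op the fallback runs the op once per example
// and stacks the results: e.g. vmap(torch.mul)(x, y) with x and y of shape [B, 3]
// is (roughly) computed as `torch.stack([torch.mul(x[i], y[i]) for i in range(B)])`.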
void batchedTensorForLoopFallback(const c10::OperatorHandle& op, torch::jit::Stack* stack) {
  const auto& schema = op.schema();
  const auto num_returns = schema.returns().size();

  if (isInplaceOp(schema)) {
    batchedTensorInplaceForLoopFallback(op, stack);
    return;
  }
  TORCH_CHECK(!schema.is_mutable() && !schema.hasAnyAliasInfo(),
              "Batching rule not implemented for ", schema.operator_name(), "; ",
              "the fallback path doesn't work on out= or view ops.");
  TORCH_CHECK(areAllReturnsTensors(schema) && !areAnyArgumentsTensorList(schema),
              "Batching rule not implemented for ", schema.operator_name(), ". ",
              "We could not generate a fallback.");
  TORCH_CHECK(num_returns >= 1,
              "Batching rule not implemented for ", schema.operator_name(), ". ",
              "The fallback path does not support operations with no returns.");
  warnFallback(schema);

  const auto num_arguments = static_cast<int64_t>(schema.arguments().size());
  const auto arguments = torch::jit::last(stack, num_arguments);
  const auto arguments_begin = stack->size() - num_arguments;

  // Figure out which arguments are BatchedTensor. Save them to a vector.
  // For each BatchedTensor, also record what position of `arguments` they came from.
  SmallVector<Tensor,kVmapTransformStaticInputSize> batched_tensor_inputs;
  VmapDimVector batched_tensor_inputs_position;
  for (const auto idx : c10::irange(arguments.size())) {
    const auto& ivalue = arguments[idx];
    if (!ivalue.isTensor()) {
      continue;
    }
    const auto& tensor = ivalue.toTensor();
    if (!tensor.defined()) {
      continue;
    }
    const auto* batched = maybeGetBatchedImpl(tensor);
    if (!batched) {
      continue;
    }
    batched_tensor_inputs.push_back(tensor);
    batched_tensor_inputs_position.push_back(idx);
  }
  TORCH_INTERNAL_ASSERT(!batched_tensor_inputs.empty());

  // MultiBatchVmapTransform the BatchedTensor arguments. This returns
  // VmapPhysicalViews that contain all of the batch dimensions.
  const auto input_physical_views = MultiBatchVmapTransform::logicalToPhysical(
      batched_tensor_inputs);

  // Compute the total number of batches
  auto num_batch_dims = input_physical_views.front().numBatchDims();
  auto some_sizes = input_physical_views.front().tensor().sizes();
  auto batch_sizes = ArrayRef<int64_t>(some_sizes.begin(), some_sizes.begin() + num_batch_dims);
  const auto num_batches = c10::multiply_integers(batch_sizes);
  // Without a shape-checking API, we're unable to compute the correct shape of
  // the output so we just error out.
  TORCH_CHECK(num_batches > 0,
      "Batching rule not implemented for ", schema.operator_name(), ". ",
      "The fallback path does not support vmap over dims of size 0.");

  // Strategy: For each batch, we are going to push slices (where applicable)
  // of the arguments onto `stack`, call `op`, and store the result in
  // `output_shards`.
  //
  // NOTE: [Output shards layout]
  // Assume that the operator has three outputs: a, b, c.
  // The layout of output_shards is as follows:
  // [ a0, a1, a2, a3, b0, b1, b2, b3, c0, c1, c2, c3]
  // This is so that we can call at::stack([a0...a3]), at::stack([b0...b3])
  // more easily in the next step.
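  // With this layout, shard j of return i lives at output_shards[num_batches * i + j];
  // e.g. with num_batches = 4, shard b2 sits at index 4 * 1 + 2 = 6.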
  std::vector<Tensor> output_shards(num_batches * num_returns);

  for (const auto linear_idx : c10::irange(num_batches)) {
    auto index = computeIndex(linear_idx, batch_sizes);
    auto batched_tensor_inputs_pos_iter = batched_tensor_inputs_position.begin();
    auto input_physical_views_iter = input_physical_views.begin();
    for (const auto arg_idx : c10::irange(num_arguments)) {
      // We assume that torch::jit::Stack is backed by vector<IValue> for
      // simplicity. When that is not the case, this code should be updated.
      const auto& argument = (*stack)[arguments_begin + arg_idx];
      if (batched_tensor_inputs_pos_iter == batched_tensor_inputs_position.end()
          || arg_idx != *batched_tensor_inputs_pos_iter) {
        // argument isn't a BatchedTensor
        torch::jit::push(stack, argument);
        continue;
      }
      // argument is a BatchedTensor
      TORCH_INTERNAL_ASSERT(input_physical_views_iter != input_physical_views.end());
      const auto& physical_view_for_argument = *input_physical_views_iter;
      torch::jit::push(stack, physical_view_for_argument.tensor().index(index));
      batched_tensor_inputs_pos_iter++;
      input_physical_views_iter++;
    }

    op.callBoxed(stack);

    // Store the result into `output_shards`. See NOTE: [Output shards layout]
    // to learn about the details of how we store the shards.
    const auto returns = torch::jit::last(stack, num_returns);
    for (const auto return_idx : c10::irange(returns.size())) {
      output_shards[num_batches * return_idx + linear_idx] = returns[return_idx].toTensor();
    }
    torch::jit::drop(stack, num_returns);
  }

  // For each output Tensor, stack the shards of the tensor together to form a return
  torch::jit::drop(stack, num_arguments);
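  // Each row of this MatrixRef is the `num_batches` shards belonging to one
  // return value (see NOTE: [Output shards layout]).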
  auto output_shards_chunks = MatrixRef<Tensor>(output_shards, num_batches);
  for (const auto return_idx : c10::irange(num_returns)) {
    auto shards = output_shards_chunks[return_idx];
    auto flat_output = safeStack(shards);
    // See NOTE [vmap through backward and undefined grad]
    if (!flat_output.defined()) {
      torch::jit::push(stack, flat_output);
      continue;
    }
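    // The stacked result has shape [num_batches, *example_shape]; expand the
    // flattened batch dim back into batch_sizes, e.g. batch_sizes = [2, 3] and
    // flat_output of shape [6, 5] become output_sizes = [2, 3, 5].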
    VmapDimVector output_sizes(batch_sizes);
    output_sizes.insert(
        output_sizes.end(),
        flat_output.sizes().begin() + 1,
        flat_output.sizes().end());
    torch::jit::push(
        stack,
        input_physical_views.front().getPhysicalToLogicalMap().apply(flat_output.view(output_sizes)));
  }
}

} // namespace at