xref: /aosp_15_r20/external/pytorch/aten/src/ATen/functorch/PlumbingHelper.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 // Copyright (c) Facebook, Inc. and its affiliates.
2 // All rights reserved.
3 //
4 // This source code is licensed under the BSD-style license found in the
5 // LICENSE file in the root directory of this source tree.
6 
7 #include <ATen/functorch/TensorWrapper.h>
8 #include <ATen/functorch/DynamicLayer.h>
9 #include <ATen/functorch/BatchedTensorImpl.h>
10 #include <ATen/functorch/PlumbingHelper.h>
11 
12 namespace at::functorch {
13 
vmap_check_escaped(const std::optional<DynamicLayer> & layer,const char * what)14 void vmap_check_escaped(const std::optional<DynamicLayer> &layer, const char* what) {
15   TORCH_CHECK(
16     layer.has_value(),
17     "Either your tensor may have escaped from inside a function being vmapped and this is a user error ",
18     "(see https://pytorch.org/functorch/stable/ux_limitations.html), "
19     "or there is an internal functorch error in `",
20     what,
21     "` Please file an issue if it looks like the latter"
22   )
23 }
24 
// Optional-aware front end for makeBatched.
//
// With no batch dimension (`bdim` is empty) the input tensor is returned
// as-is. Otherwise `bdim` is asserted to lie in [0, tensor.dim()) and the
// work is forwarded to the makeBatched overload that takes a concrete
// batch dimension, using the same `level`.
Tensor makeBatched(const Tensor& tensor, std::optional<int64_t> bdim, int64_t level) {
  if (!bdim.has_value()) {
    // Nothing to wrap: the tensor is not batched.
    return tensor;
  }
  // Internal invariants on the batch dimension; violating them indicates a
  // bug in the caller rather than bad user input.
  TORCH_INTERNAL_ASSERT(*bdim >= 0);
  TORCH_INTERNAL_ASSERT(*bdim < tensor.dim());
  return makeBatched(tensor, *bdim, level);
}
33 
makeBatchedVector(const std::vector<Tensor> & tensors,std::optional<int64_t> bdim,int64_t level)34 std::vector<Tensor> makeBatchedVector(const std::vector<Tensor>& tensors, std::optional<int64_t> bdim, int64_t level) {
35   std::vector<Tensor> res;
36   res.reserve(tensors.size());
37   for (const auto & tensor : tensors) {
38     res.emplace_back(makeBatched(tensor, bdim, level));
39   }
40   return res;
41 }
42 
unwrapTensorAtLevel(const Tensor & tensor,int64_t level)43 std::tuple<Tensor, std::optional<int64_t>> unwrapTensorAtLevel(const Tensor& tensor, int64_t level) {
44   auto* batched = maybeGetBatchedImpl(tensor);
45   if (!batched) {
46     return std::make_tuple(tensor, std::nullopt);
47   }
48   if (batched->level() == level) {
49     return std::make_tuple(batched->value(), batched->bdim());
50   }
51   return std::make_tuple(tensor, std::nullopt);
52 }
53 
isBatchedAtLevel(const Tensor & tensor,int64_t level)54 bool isBatchedAtLevel(const Tensor& tensor, int64_t level) {
55   auto result = unwrapTensorAtLevel(tensor, level);
56   return std::get<1>(result).has_value();
57 }
58 
isBatchedAtLevel(const std::optional<Tensor> & maybe_tensor,int64_t level)59 bool isBatchedAtLevel(const std::optional<Tensor>& maybe_tensor, int64_t level) {
60   if (!maybe_tensor.has_value()) {
61     return false;
62   }
63   return isBatchedAtLevel(*maybe_tensor, level);
64 }
65 
isBatchedAtLevel(ITensorListRef tensors,int64_t level)66 bool isBatchedAtLevel(ITensorListRef tensors, int64_t level) {
67   for (const auto& tensor : tensors) {
68     if (isBatchedAtLevel(tensor, level)) {
69       return true;
70     }
71   }
72   return false;
73 }
74 
isBatchedAtLevel(const c10::List<std::optional<Tensor>> & maybe_tensors,int64_t level)75 bool isBatchedAtLevel(const c10::List<std::optional<Tensor>>& maybe_tensors, int64_t level) {
76   for (const auto idx : c10::irange(0, maybe_tensors.size())) {
77     const auto& maybe_tensor = maybe_tensors.get(idx);
78     if (isBatchedAtLevel(maybe_tensor, level)) {
79       return true;
80     }
81   }
82   return false;
83 }
84 
areAnyBatchedAtLevel(ArrayRef<std::optional<Tensor>> maybe_tensors,int64_t level)85 bool areAnyBatchedAtLevel(ArrayRef<std::optional<Tensor>> maybe_tensors, int64_t level) {
86   for (const auto& maybe_tensor : maybe_tensors) {
87     if (isBatchedAtLevel(maybe_tensor, level)) {
88       return true;
89     }
90   }
91   return false;
92 }
93 
94 
95 } // namespace at::functorch
96