// Copyright (c) Facebook, Inc. and its affiliates.
// All rights reserved.
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

#include <ATen/functorch/TensorWrapper.h>
#include <ATen/functorch/DynamicLayer.h>
#include <ATen/functorch/BatchedTensorImpl.h>
#include <ATen/functorch/PlumbingHelper.h>

namespace at::functorch {

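// Checks that the vmap dynamic layer for a wrapped tensor is still alive.
// If it is not, either the tensor escaped the function being vmapped (a user
// error) or functorch hit an internal error; `what` names the call site
// reported in the error message.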
void vmap_check_escaped(const std::optional<DynamicLayer> &layer, const char* what) {
  TORCH_CHECK(
    layer.has_value(),
    "Either your tensor may have escaped from inside a function being vmapped and this is a user error ",
    "(see https://pytorch.org/functorch/stable/ux_limitations.html), "
    "or there is an internal functorch error in `",
    what,
    "` Please file an issue if it looks like the latter");
}

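// Wraps `tensor` as a BatchedTensor at `level` when `bdim` carries a batch
// dimension; with an empty `bdim` the tensor is returned unchanged.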
Tensor makeBatched(const Tensor& tensor, std::optional<int64_t> bdim, int64_t level) {
  if (bdim.has_value()) {
    TORCH_INTERNAL_ASSERT(*bdim >= 0);
    TORCH_INTERNAL_ASSERT(*bdim < tensor.dim());
    return makeBatched(tensor, bdim.value(), level);
  }
  return tensor;
}

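// Applies makeBatched to every tensor in `tensors`, using the same `bdim` and
// `level` for each element.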
std::vector<Tensor> makeBatchedVector(const std::vector<Tensor>& tensors, std::optional<int64_t> bdim, int64_t level) {
  std::vector<Tensor> res;
  res.reserve(tensors.size());
  for (const auto& tensor : tensors) {
    res.emplace_back(makeBatched(tensor, bdim, level));
  }
  return res;
}

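// If `tensor` is a BatchedTensor batched at exactly `level`, returns the
// underlying tensor along with its batch dimension; otherwise returns the
// tensor unchanged paired with std::nullopt.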
std::tuple<Tensor, std::optional<int64_t>> unwrapTensorAtLevel(const Tensor& tensor, int64_t level) {
  auto* batched = maybeGetBatchedImpl(tensor);
  if (!batched) {
    return std::make_tuple(tensor, std::nullopt);
  }
  if (batched->level() == level) {
    return std::make_tuple(batched->value(), batched->bdim());
  }
  return std::make_tuple(tensor, std::nullopt);
}

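// True if `tensor` is a BatchedTensor whose batching lives at `level`.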
bool isBatchedAtLevel(const Tensor& tensor, int64_t level) {
  auto result = unwrapTensorAtLevel(tensor, level);
  return std::get<1>(result).has_value();
}

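// Overload for optional tensors: an empty optional is never batched.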
bool isBatchedAtLevel(const std::optional<Tensor>& maybe_tensor, int64_t level) {
  if (!maybe_tensor.has_value()) {
    return false;
  }
  return isBatchedAtLevel(*maybe_tensor, level);
}

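// True if any tensor in the list is batched at `level`.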
bool isBatchedAtLevel(ITensorListRef tensors, int64_t level) {
  for (const auto& tensor : tensors) {
    if (isBatchedAtLevel(tensor, level)) {
      return true;
    }
  }
  return false;
}

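// True if any present tensor in a list of optional tensors is batched at
// `level`; empty slots are skipped.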
bool isBatchedAtLevel(const c10::List<std::optional<Tensor>>& maybe_tensors, int64_t level) {
  for (const auto idx : c10::irange(0, maybe_tensors.size())) {
    const auto& maybe_tensor = maybe_tensors.get(idx);
    if (isBatchedAtLevel(maybe_tensor, level)) {
      return true;
    }
  }
  return false;
}

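// True if any of the optional tensors in the array is batched at `level`.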
bool areAnyBatchedAtLevel(ArrayRef<std::optional<Tensor>> maybe_tensors, int64_t level) {
  for (const auto& maybe_tensor : maybe_tensors) {
    if (isBatchedAtLevel(maybe_tensor, level)) {
      return true;
    }
  }
  return false;
}


} // namespace at::functorch