1 #include <ATen/LegacyVmapTransforms.h>
2 #include <ATen/ATen.h>
3 #include <ATen/core/IListRef.h>
4 #include <c10/util/irange.h>
5
6 namespace at {
7
8 // Checks if the batch dims in `bdims` appear at the front of the tensor.
areBdimsAtFrontInOrder(BatchDimsRef bdims)9 static bool areBdimsAtFrontInOrder(BatchDimsRef bdims) {
10 for (const auto idx : c10::irange(static_cast<int64_t>(bdims.size()))) {
11 if (bdims[idx].dim() != idx) {
12 return false;
13 }
14 }
15 return true;
16 }
17
18 // Takes a BatchedTensorImpl, permutes all of the batch dims to the front,
19 // and then returns a physical version of the Tensor.
permuteBatchDimsToFront(BatchedTensorImpl * batched)20 static Tensor permuteBatchDimsToFront(BatchedTensorImpl* batched) {
21 auto bdims = batched->bdims();
22 const Tensor& physical_tensor = batched->value();
23 if (areBdimsAtFrontInOrder(bdims)) {
24 return physical_tensor;
25 }
26 const auto sizes = physical_tensor.sizes();
27 VmapDimVector permutation(sizes.size(), 0);
28 permutation.reserve(sizes.size());
29 const auto is_bdim = createBatchDimBitset(bdims);
30 int64_t idx = 0;
31 for (const auto& bdim : bdims) {
32 permutation[idx++] = bdim.dim();
33 }
34 for (const auto ptr : c10::irange(sizes.size())) {
35 if (is_bdim[ptr]) {
36 continue;
37 }
38 permutation[idx++] = ptr;
39 }
40 return physical_tensor.permute(permutation);
41 }
42
logicalToPhysical(const Tensor & logical_tensor)43 VmapPhysicalView MultiBatchVmapTransform::logicalToPhysical(const Tensor& logical_tensor) {
44 auto* batched = maybeGetBatchedImpl(logical_tensor);
45 TORCH_INTERNAL_ASSERT(
46 batched,
47 "logicalToPhysical(tensor) should only be passed a BatchedTensor");
48 return { permuteBatchDimsToFront(batched), createVmapLevelsBitset(batched->bdims()) };
49 }
50
numBatchDims() const51 int64_t VmapPhysicalView::numBatchDims() const {
52 return levels_.count();
53 }
54
numLogicalDims() const55 int64_t VmapPhysicalView::numLogicalDims() const {
56 return /*physical*/tensor_.dim() - numBatchDims();
57 }
58
getPhysicalDims(OptionalIntArrayRef opt_logical_dims) const59 VmapDimVector VmapPhysicalView::getPhysicalDims(OptionalIntArrayRef opt_logical_dims) const {
60 auto logical_ndim = numLogicalDims();
61 // NB: fmap doesn't have a SmallVector variant, so we don't use it here.
62 VmapDimVector result;
63 result.reserve(logical_ndim);
64 if (opt_logical_dims.has_value() && !opt_logical_dims.value().empty()) {
65 auto logical_dims = opt_logical_dims.value();
66 for (auto dim : logical_dims) {
67 result.push_back(maybe_wrap_dim(dim, logical_ndim) + numBatchDims());
68 }
69 } else {
70 for (int64_t dim = 0; dim < logical_ndim; dim++) {
71 result.push_back(dim + numBatchDims());
72 }
73 }
74 return result;
75 }
76
getPhysicalDim(int64_t logical_dim) const77 int64_t VmapPhysicalView::getPhysicalDim(int64_t logical_dim) const {
78 auto logical_ndim = numLogicalDims();
79 return maybe_wrap_dim(logical_dim, logical_ndim) + numBatchDims();
80 }
81
getPhysicalShape(IntArrayRef logical_shape) const82 VmapDimVector VmapPhysicalView::getPhysicalShape(IntArrayRef logical_shape) const {
83 VmapDimVector result;
84 result.reserve(logical_shape.size() + numBatchDims());
85 auto tensor_sizes = tensor_.sizes();
86 result.insert(result.end(), tensor_sizes.begin(), tensor_sizes.begin() + numBatchDims());
87 result.insert(result.end(), logical_shape.begin(), logical_shape.end());
88 return result;
89 }
90
computeFrontBatchDimsFromLevels(std::bitset<kVmapNumLevels> levels_bitset)91 static BatchDims computeFrontBatchDimsFromLevels(std::bitset<kVmapNumLevels> levels_bitset) {
92 BatchDims bdims;
93 int64_t dim = 0;
94 for (const auto level : c10::irange(kVmapNumLevels)) {
95 if (!levels_bitset[level]) {
96 continue;
97 }
98 bdims.emplace_back(level, dim++);
99 }
100 return bdims;
101 }
102
103 // Given a Tensor or a BatchedTensor, returns the underlying physical tensor
104 // with all vmapped dimensions permuted to the front, if they exist, and a
105 // bitset of vmap levels that were present in the tensor.
106 static std::pair<Tensor,std::bitset<kVmapNumLevels>>
getPhysicalTensorAndLevels(const Tensor & self)107 getPhysicalTensorAndLevels(const Tensor& self) {
108 auto* batched = maybeGetBatchedImpl(self);
109 if (batched) {
110 return {permuteBatchDimsToFront(batched), createVmapLevelsBitset(batched->bdims())};
111 }
112 return {self, 0};
113 }
114
115 // Given a Tensor or a BatchedTensor, creates a physical view of the tensor
116 // such that it has a batch dimension for each level in `requested_levels`
117 // and `requested_example_dim` number of non-batch-dimensions.
118 //
119 // This function is useful in preparing physical views on tensors that can
120 // then be passed into broadcasting operations. For example, when adding
121 // two BatchedTensors of sizes [B0, 3] and [B0, B1, 2, 3], where the Bi are the
122 // batch dimensions, we must align the batch dimensions and non-batch-dimensions
123 // (henceforth referred to as the "example" dimensions) separately to produce
124 // tensors of size [B0, 1, 1, 3] and [B0, B1, 2, 3] so that they can be added.
125 //
126 // Here's a direct example of using alignBatchDimsAtFront on the above two tensors.
127 //
128 // 1) alignBatchDimsAtFront([B0, 3], requested_levels={0, 1}, requested_example_dim=2)
129 // returns a physical view of size [B0, 1, 1, 3] by adding an extra dimension for
130 // level 1 and another extra dimension to pad the example dimensions to 2.
131 //
132 // 2) alignBatchDimsAtFront([B0, B1, 2, 3], requested_levels={0, 1}, requested_example_dim=2)
133 // returns a physical view of size [B0, B1, 2, 3]
alignBatchDimsAtFront(const Tensor & self,std::bitset<kVmapNumLevels> requested_levels,int64_t requested_example_dim)134 static Tensor alignBatchDimsAtFront(
135 const Tensor& self,
136 std::bitset<kVmapNumLevels> requested_levels,
137 int64_t requested_example_dim) {
138 auto [physical_tensor, tensor_levels] = getPhysicalTensorAndLevels(self);
139
140 TORCH_INTERNAL_ASSERT(
141 (tensor_levels | requested_levels) == requested_levels,
142 "`requested_levels` must be a superset of `self`'s levels");
143
144 auto physical_sizes = physical_tensor.sizes();
145
146 const auto tensor_example_dim = (
147 static_cast<int64_t>(physical_sizes.size())
148 - /*num_batch_dims*/static_cast<int64_t>(tensor_levels.count())
149 );
150 TORCH_INTERNAL_ASSERT(tensor_example_dim <= requested_example_dim);
151
152 if (tensor_levels == requested_levels && tensor_example_dim == requested_example_dim) {
153 // Optimization: no need to do another view if the physical tensor is
154 // already the correct shape
155 return physical_tensor;
156 }
157
158 VmapDimVector aligned_sizes(requested_levels.count() + requested_example_dim, 1);
159
160 // align the example dims (non-bdims dims) first
161 // aligned_sizes[-tensor_example_dim:] = tensor_sizes[-tensor_example_dim:]
162 std::copy(
163 physical_sizes.rbegin(),
164 physical_sizes.rbegin() + tensor_example_dim,
165 aligned_sizes.rbegin());
166
167 // align the bdims
168 int64_t level = 0;
169 int64_t tensor_dim = 0;
170 for (const auto bdim : c10::irange(requested_levels.count())) {
171 // Determine the level of the bdim
172 while (!requested_levels[level]) level++;
173 if (tensor_levels[level]) {
174 aligned_sizes[bdim] = physical_sizes[tensor_dim++];
175 }
176 level++;
177 }
178 return physical_tensor.view(aligned_sizes);
179 }
180
181 // The algorithm is as follows:
182 // 1. Figure out what all of the collective levels in `logical_tensors` is.
183 // 2. Move all batch dims to the front of the tensors and add extra dims
184 // of size 1. At this point, every tensor will have a dimension for
185 // each of the collective levels.
186 // 3. Compute the batch_sizes.
187 // 4. Expand each physical tensor so that they have output batch size equal
188 // to `batch_sizes`
189 VmapPhysicalViewVec
logicalToPhysical(ITensorListRef logical_tensors)190 MultiBatchVmapTransform::logicalToPhysical(ITensorListRef logical_tensors) {
191 // Figure out all of the collective vmap levels in `logical_tensors`.
192 std::bitset<kVmapNumLevels> collective_levels;
193 for (const auto& logical_tensor : logical_tensors) {
194 auto* batched = maybeGetBatchedImpl(logical_tensor);
195 if (batched) {
196 collective_levels |= createVmapLevelsBitset(batched->bdims());
197 }
198 }
199
200 // Populate physical_tensors.
201 // This contains a list of regular (non-Batched) Tensors where all of the
202 // batch dims have been moved to the front of the tensor. Any previously
203 // non-existing batch dims get added to the tensors as new dimensions of size 1.
204 std::vector<Tensor> physical_tensors;
205 int64_t num_batch_dims = collective_levels.count();
206 for (const auto& logical_tensor : logical_tensors) {
207 auto requested_example_dim = /*logical_dim*/logical_tensor.dim();
208 auto physical_tensor = alignBatchDimsAtFront(
209 logical_tensor, collective_levels, requested_example_dim);
210 physical_tensors.push_back(std::move(physical_tensor));
211 }
212
213 // Compute batch_sizes
214 VmapDimVector batch_sizes(num_batch_dims, 1);
215 for (const auto& physical_tensor : physical_tensors) {
216 auto physical_sizes = physical_tensor.sizes();
217 for (const auto dim : c10::irange(num_batch_dims)) {
218 if (physical_sizes[dim] != 1) {
219 batch_sizes[dim] = physical_sizes[dim];
220 }
221 }
222 }
223
224 // Expand each physical_tensor so that it has batch sizes `batch_sizes`
225 VmapPhysicalViewVec result;
226 for (const auto& physical_tensor : physical_tensors) {
227 VmapDimVector expanded_size(batch_sizes.begin(), batch_sizes.end());
228 auto physical_sizes = physical_tensor.sizes();
229 expanded_size.insert(
230 expanded_size.end(),
231 physical_sizes.begin() + num_batch_dims,
232 physical_sizes.end());
233 result.emplace_back(physical_tensor.expand(expanded_size), collective_levels);
234 }
235 return result;
236 }
237
238 static std::pair<std::bitset<kVmapNumLevels>,int64_t>
getLevelsAndLargestLogicalDim(TensorList logical_tensors)239 getLevelsAndLargestLogicalDim(TensorList logical_tensors) {
240 TORCH_INTERNAL_ASSERT(!logical_tensors.empty());
241 std::bitset<kVmapNumLevels> levels;
242 int64_t largest_logical_dim = -1;
243 for (const auto& tensor : logical_tensors) {
244 auto* batched = maybeGetBatchedImpl(tensor);
245 if (batched) {
246 levels = levels | createVmapLevelsBitset(batched->bdims());
247 }
248 auto tensor_logical_dim = /*logical dim*/tensor.dim();
249 if (tensor_logical_dim > largest_logical_dim) {
250 largest_logical_dim = tensor_logical_dim;
251 }
252 }
253 return { levels, largest_logical_dim };
254 }
255
logicalToPhysical(TensorList logical_tensors)256 VmapPhysicalViewVec BroadcastingVmapTransform::logicalToPhysical(TensorList logical_tensors) {
257 TORCH_INTERNAL_ASSERT(
258 logical_tensors.size() == 2,
259 "This function has only been tested for two tensors. Please add more tests ",
260 "before removing this check ");
261
262 VmapPhysicalViewVec result;
263
264 auto [levels, largest_logical_dim] = getLevelsAndLargestLogicalDim(logical_tensors);
265
266 for (const auto& tensor : logical_tensors) {
267 // NB: It's possible that we didn't actually need to align `tensor`.
268 // For example, when adding two tensors of size (B, 2), and (3, 2), where
269 // the first Tensor is a BatchedTensor with batch dim B and the second is
270 // a regular Tensor, we will return views of size (B, 1, 2) and (1, 3, 2).
271 // However, the view on the second tensor is unnecessary: broadcasting
272 // semantics allow for the addition of two tensors of size (B, 1, 2) and (3, 2)!
273 //
274 // If this unnecessary view is a problem, consider optimizing it away in
275 // the future. This may involve creating a new type of VmapPhysicalView
276 auto aligned = alignBatchDimsAtFront(tensor, levels, largest_logical_dim) ;
277 result.emplace_back(std::move(aligned), levels);
278 }
279 return result;
280 }
281
getPhysicalToLogicalMap() const282 VmapPhysicalToLogicalMap VmapPhysicalView::getPhysicalToLogicalMap() const {
283 return VmapPhysicalToLogicalMap(levels_);
284 }
285
// Wraps `physical_tensor` as a BatchedTensor. Its leading dims are taken to
// be the batch dims for `levels_`, assigned in increasing level order.
Tensor VmapPhysicalToLogicalMap::apply(const Tensor& physical_tensor) const {
  auto front_bdims = computeFrontBatchDimsFromLevels(levels_);
  return makeBatched(physical_tensor, std::move(front_bdims));
}
289
applyInplace(std::vector<Tensor> & physical_tensors) const290 void VmapPhysicalToLogicalMap::applyInplace(std::vector<Tensor>& physical_tensors) const {
291 for (auto & physical_tensor : physical_tensors) {
292 physical_tensor = apply(physical_tensor);
293 }
294 }
295
296 } // namespace at
297