// aten/src/ATen/LegacyVmapTransforms.cpp
// (xref revision da0073e96a02ea20f0ac840b70461e3646d07c45)
#include <ATen/LegacyVmapTransforms.h>
#include <ATen/ATen.h>
#include <ATen/core/IListRef.h>
#include <c10/util/irange.h>

namespace at {

// Checks if the batch dims in `bdims` appear at the front of the tensor.
static bool areBdimsAtFrontInOrder(BatchDimsRef bdims) {
  for (const auto idx : c10::irange(static_cast<int64_t>(bdims.size()))) {
    if (bdims[idx].dim() != idx) {
      return false;
    }
  }
  return true;
}

// Takes a BatchedTensorImpl, permutes all of the batch dims to the front,
// and then returns a physical version of the Tensor.
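// For example, a physical tensor of size [2, B0, 3] with a single batch dim
// at index 1 is returned as a permuted view of size [B0, 2, 3].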
static Tensor permuteBatchDimsToFront(BatchedTensorImpl* batched) {
  auto bdims = batched->bdims();
  const Tensor& physical_tensor = batched->value();
  if (areBdimsAtFrontInOrder(bdims)) {
    return physical_tensor;
  }
  const auto sizes = physical_tensor.sizes();
  VmapDimVector permutation(sizes.size(), 0);
  permutation.reserve(sizes.size());
  const auto is_bdim = createBatchDimBitset(bdims);
  int64_t idx = 0;
  for (const auto& bdim : bdims) {
    permutation[idx++] = bdim.dim();
  }
  for (const auto ptr : c10::irange(sizes.size())) {
    if (is_bdim[ptr]) {
      continue;
    }
    permutation[idx++] = ptr;
  }
  return physical_tensor.permute(permutation);
}

VmapPhysicalView MultiBatchVmapTransform::logicalToPhysical(const Tensor& logical_tensor) {
  auto* batched = maybeGetBatchedImpl(logical_tensor);
  TORCH_INTERNAL_ASSERT(
      batched,
      "logicalToPhysical(tensor) should only be passed a BatchedTensor");
  return { permuteBatchDimsToFront(batched), createVmapLevelsBitset(batched->bdims()) };
}

int64_t VmapPhysicalView::numBatchDims() const {
  return levels_.count();
}

int64_t VmapPhysicalView::numLogicalDims() const {
  return /*physical*/tensor_.dim() - numBatchDims();
}

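// Maps logical dims (as seen through this physical view) to the corresponding
// dims on the physical tensor, which are offset by numBatchDims() because the
// batch dims sit at the front. If `opt_logical_dims` is empty or not provided,
// all logical dims are mapped; negative dims are wrapped with respect to the
// logical rank. For example, with 2 batch dims and logical rank 3, logical
// dim -1 maps to physical dim 4.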
VmapDimVector VmapPhysicalView::getPhysicalDims(OptionalIntArrayRef opt_logical_dims) const {
  auto logical_ndim = numLogicalDims();
  // NB: fmap doesn't have a SmallVector variant, so we don't use it here.
  VmapDimVector result;
  result.reserve(logical_ndim);
  if (opt_logical_dims.has_value() && !opt_logical_dims.value().empty()) {
    auto logical_dims = opt_logical_dims.value();
    for (auto dim : logical_dims) {
      result.push_back(maybe_wrap_dim(dim, logical_ndim) + numBatchDims());
    }
  } else {
    for (int64_t dim = 0; dim < logical_ndim; dim++) {
      result.push_back(dim + numBatchDims());
    }
  }
  return result;
}

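// Same as getPhysicalDims, but for a single logical dim.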
int64_t VmapPhysicalView::getPhysicalDim(int64_t logical_dim) const {
  auto logical_ndim = numLogicalDims();
  return maybe_wrap_dim(logical_dim, logical_ndim) + numBatchDims();
}

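// Given the logical shape of a result, returns the corresponding physical
// shape: the sizes of this view's batch dims (taken from the front of the
// physical tensor) followed by `logical_shape`.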
VmapDimVector VmapPhysicalView::getPhysicalShape(IntArrayRef logical_shape) const {
  VmapDimVector result;
  result.reserve(logical_shape.size() + numBatchDims());
  auto tensor_sizes = tensor_.sizes();
  result.insert(result.end(), tensor_sizes.begin(), tensor_sizes.begin() + numBatchDims());
  result.insert(result.end(), logical_shape.begin(), logical_shape.end());
  return result;
}

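// Given a bitset of vmap levels, constructs BatchDims that map each set level
// to a dim at the front of the tensor, in increasing level order
// (i-th set level -> dim i).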
static BatchDims computeFrontBatchDimsFromLevels(std::bitset<kVmapNumLevels> levels_bitset) {
  BatchDims bdims;
  int64_t dim = 0;
  for (const auto level : c10::irange(kVmapNumLevels)) {
    if (!levels_bitset[level]) {
      continue;
    }
    bdims.emplace_back(level, dim++);
  }
  return bdims;
}

// Given a Tensor or a BatchedTensor, returns the underlying physical tensor
// with all vmapped dimensions permuted to the front, if they exist, and a
// bitset of vmap levels that were present in the tensor.
static std::pair<Tensor,std::bitset<kVmapNumLevels>>
getPhysicalTensorAndLevels(const Tensor& self) {
  auto* batched = maybeGetBatchedImpl(self);
  if (batched) {
    return {permuteBatchDimsToFront(batched), createVmapLevelsBitset(batched->bdims())};
  }
  return {self, 0};
}

// Given a Tensor or a BatchedTensor, creates a physical view of the tensor
// such that it has a batch dimension for each level in `requested_levels`
// and `requested_example_dim` non-batch dimensions.
//
// This function is useful in preparing physical views on tensors that can
// then be passed into broadcasting operations. For example, when adding
// two BatchedTensors of sizes [B0, 3] and [B0, B1, 2, 3], where the Bi are the
// batch dimensions, we must align the batch dimensions and non-batch dimensions
// (henceforth referred to as the "example" dimensions) separately to produce
// tensors of size [B0, 1, 1, 3] and [B0, B1, 2, 3] so that they can be added.
//
// Here's a direct example of using alignBatchDimsAtFront on the above two tensors.
//
// 1) alignBatchDimsAtFront([B0, 3], requested_levels={0, 1}, requested_example_dim=2)
// returns a physical view of size [B0, 1, 1, 3] by adding an extra dimension for
// level 1 and another extra dimension to pad the example dimensions to 2.
//
// 2) alignBatchDimsAtFront([B0, B1, 2, 3], requested_levels={0, 1}, requested_example_dim=2)
// returns a physical view of size [B0, B1, 2, 3].
static Tensor alignBatchDimsAtFront(
    const Tensor& self,
    std::bitset<kVmapNumLevels> requested_levels,
    int64_t requested_example_dim) {
  auto [physical_tensor, tensor_levels] = getPhysicalTensorAndLevels(self);

  TORCH_INTERNAL_ASSERT(
    (tensor_levels | requested_levels) == requested_levels,
    "`requested_levels` must be a superset of `self`'s levels");

  auto physical_sizes = physical_tensor.sizes();

  const auto tensor_example_dim = (
    static_cast<int64_t>(physical_sizes.size())
    - /*num_batch_dims*/static_cast<int64_t>(tensor_levels.count())
  );
  TORCH_INTERNAL_ASSERT(tensor_example_dim <= requested_example_dim);

  if (tensor_levels == requested_levels && tensor_example_dim == requested_example_dim) {
    // Optimization: no need to do another view if the physical tensor is
    // already the correct shape
    return physical_tensor;
  }

  VmapDimVector aligned_sizes(requested_levels.count() + requested_example_dim, 1);

  // align the example dims (non-bdims dims) first
  // aligned_sizes[-tensor_example_dim:] = tensor_sizes[-tensor_example_dim:]
  std::copy(
      physical_sizes.rbegin(),
      physical_sizes.rbegin() + tensor_example_dim,
      aligned_sizes.rbegin());

  // align the bdims
  int64_t level = 0;
  int64_t tensor_dim = 0;
  for (const auto bdim : c10::irange(requested_levels.count())) {
    // Determine the level of the bdim
    while (!requested_levels[level]) level++;
    if (tensor_levels[level]) {
      aligned_sizes[bdim] = physical_sizes[tensor_dim++];
    }
    level++;
  }
  return physical_tensor.view(aligned_sizes);
}

// The algorithm is as follows:
// 1. Figure out what all of the collective levels in `logical_tensors` are.
// 2. Move all batch dims to the front of the tensors and add extra dims
//    of size 1. At this point, every tensor will have a dimension for
//    each of the collective levels.
// 3. Compute the batch_sizes.
// 4. Expand each physical tensor so that it has batch sizes equal
//    to `batch_sizes`.
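//
// For example, given logical tensors of sizes [B0, 3] and [B1, 2, 3] (where
// B0 and B1 are batch dims for levels 0 and 1), the collective levels are
// {0, 1}; the aligned physical tensors have sizes [B0, 1, 3] and [1, B1, 2, 3],
// batch_sizes is [B0, B1], and the returned physical views have sizes
// [B0, B1, 3] and [B0, B1, 2, 3].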
VmapPhysicalViewVec
MultiBatchVmapTransform::logicalToPhysical(ITensorListRef logical_tensors) {
  // Figure out all of the collective vmap levels in `logical_tensors`.
  std::bitset<kVmapNumLevels> collective_levels;
  for (const auto& logical_tensor : logical_tensors) {
    auto* batched = maybeGetBatchedImpl(logical_tensor);
    if (batched) {
      collective_levels |= createVmapLevelsBitset(batched->bdims());
    }
  }

  // Populate physical_tensors.
  // This contains a list of regular (non-Batched) Tensors where all of the
  // batch dims have been moved to the front of the tensor. Any previously
  // non-existing batch dims get added to the tensors as new dimensions of size 1.
  std::vector<Tensor> physical_tensors;
  int64_t num_batch_dims = collective_levels.count();
  for (const auto& logical_tensor : logical_tensors) {
    auto requested_example_dim = /*logical_dim*/logical_tensor.dim();
    auto physical_tensor = alignBatchDimsAtFront(
        logical_tensor, collective_levels, requested_example_dim);
    physical_tensors.push_back(std::move(physical_tensor));
  }

  // Compute batch_sizes
  VmapDimVector batch_sizes(num_batch_dims, 1);
  for (const auto& physical_tensor : physical_tensors) {
    auto physical_sizes = physical_tensor.sizes();
    for (const auto dim : c10::irange(num_batch_dims)) {
      if (physical_sizes[dim] != 1) {
        batch_sizes[dim] = physical_sizes[dim];
      }
    }
  }

  // Expand each physical_tensor so that it has batch sizes `batch_sizes`
  VmapPhysicalViewVec result;
  for (const auto& physical_tensor : physical_tensors) {
    VmapDimVector expanded_size(batch_sizes.begin(), batch_sizes.end());
    auto physical_sizes = physical_tensor.sizes();
    expanded_size.insert(
        expanded_size.end(),
        physical_sizes.begin() + num_batch_dims,
        physical_sizes.end());
    result.emplace_back(physical_tensor.expand(expanded_size), collective_levels);
  }
  return result;
}

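// Returns the union of vmap levels across `logical_tensors` together with the
// largest logical rank among them.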
static std::pair<std::bitset<kVmapNumLevels>,int64_t>
getLevelsAndLargestLogicalDim(TensorList logical_tensors) {
  TORCH_INTERNAL_ASSERT(!logical_tensors.empty());
  std::bitset<kVmapNumLevels> levels;
  int64_t largest_logical_dim = -1;
  for (const auto& tensor : logical_tensors) {
    auto* batched = maybeGetBatchedImpl(tensor);
    if (batched) {
      levels = levels | createVmapLevelsBitset(batched->bdims());
    }
    auto tensor_logical_dim = /*logical dim*/tensor.dim();
    if (tensor_logical_dim > largest_logical_dim) {
      largest_logical_dim = tensor_logical_dim;
    }
  }
  return { levels, largest_logical_dim };
}

VmapPhysicalViewVec BroadcastingVmapTransform::logicalToPhysical(TensorList logical_tensors) {
  TORCH_INTERNAL_ASSERT(
      logical_tensors.size() == 2,
      "This function has only been tested for two tensors. Please add more tests ",
      "before removing this check ");

  VmapPhysicalViewVec result;

  auto [levels, largest_logical_dim] = getLevelsAndLargestLogicalDim(logical_tensors);

  for (const auto& tensor : logical_tensors) {
    // NB: It's possible that we didn't actually need to align `tensor`.
    // For example, when adding two tensors of size (B, 2), and (3, 2), where
    // the first Tensor is a BatchedTensor with batch dim B and the second is
    // a regular Tensor, we will return views of size (B, 1, 2) and (1, 3, 2).
    // However, the view on the second tensor is unnecessary: broadcasting
    // semantics allow for the addition of two tensors of size (B, 1, 2) and (3, 2)!
    //
    // If this unnecessary view is a problem, consider optimizing it away in
    // the future. This may involve creating a new type of VmapPhysicalView
    auto aligned = alignBatchDimsAtFront(tensor, levels, largest_logical_dim);
    result.emplace_back(std::move(aligned), levels);
  }
  return result;
}

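// Returns a map that converts physical tensors (with this view's batch dims at
// the front) back into logical BatchedTensors for the levels in this view.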
VmapPhysicalToLogicalMap VmapPhysicalView::getPhysicalToLogicalMap() const {
  return VmapPhysicalToLogicalMap(levels_);
}

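// Wraps `physical_tensor` in a BatchedTensor with a batch dim at the front for
// each level in `levels_`.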
Tensor VmapPhysicalToLogicalMap::apply(const Tensor& physical_tensor) const {
  return makeBatched(physical_tensor, computeFrontBatchDimsFromLevels(levels_));
}

void VmapPhysicalToLogicalMap::applyInplace(std::vector<Tensor>& physical_tensors) const {
  for (auto & physical_tensor : physical_tensors) {
    physical_tensor = apply(physical_tensor);
  }
}

} // namespace at