xref: /aosp_15_r20/external/pytorch/aten/src/ATen/CollapseDims.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
#pragma once

#include <cstdint>
#include <utility>

#include <c10/util/Exception.h>
3*da0073e9SAndroid Build Coastguard Worker 
4*da0073e9SAndroid Build Coastguard Worker namespace at {
5*da0073e9SAndroid Build Coastguard Worker 
6*da0073e9SAndroid Build Coastguard Worker /*
7*da0073e9SAndroid Build Coastguard Worker [collapse dims] Updates sizes, and strides to reflect a "collapse" of
8*da0073e9SAndroid Build Coastguard Worker the info, possibly excluding the optional excludeDim. A "collapsed" version
9*da0073e9SAndroid Build Coastguard Worker of the info is the fewest dims that order the tensor's elements in the same
10*da0073e9SAndroid Build Coastguard Worker way as the original info. If excludeDim is specified, the collapse is the
11*da0073e9SAndroid Build Coastguard Worker fewest dims that order the tensor's elements as the original and preserve the
12*da0073e9SAndroid Build Coastguard Worker excluded dimension, unless the tensor collapses to a point.
13*da0073e9SAndroid Build Coastguard Worker 
14*da0073e9SAndroid Build Coastguard Worker This function returns a pair of values.
15*da0073e9SAndroid Build Coastguard Worker 
16*da0073e9SAndroid Build Coastguard Worker 1) The (new) index of the preserved dimension if excludeDim is
17*da0073e9SAndroid Build Coastguard Worker specified. 0 if the tensor is collapsed to a point. -1
18*da0073e9SAndroid Build Coastguard Worker otherwise.
19*da0073e9SAndroid Build Coastguard Worker 
20*da0073e9SAndroid Build Coastguard Worker 2) The new number of dimensions.
21*da0073e9SAndroid Build Coastguard Worker */
22*da0073e9SAndroid Build Coastguard Worker template <typename T>
23*da0073e9SAndroid Build Coastguard Worker inline std::pair<int64_t, int64_t> collapse_dims(
24*da0073e9SAndroid Build Coastguard Worker     T* sizes,
25*da0073e9SAndroid Build Coastguard Worker     T* strides,
26*da0073e9SAndroid Build Coastguard Worker     int64_t dims,
27*da0073e9SAndroid Build Coastguard Worker     const int excludeDim = -1) {
28*da0073e9SAndroid Build Coastguard Worker   TORCH_CHECK(
29*da0073e9SAndroid Build Coastguard Worker       excludeDim >= -1 && excludeDim < dims,
30*da0073e9SAndroid Build Coastguard Worker       "expected excluded dim between -1 and dims - 1");
31*da0073e9SAndroid Build Coastguard Worker 
32*da0073e9SAndroid Build Coastguard Worker   int64_t stopDim = (excludeDim == -1) ? dims : excludeDim;
33*da0073e9SAndroid Build Coastguard Worker   int64_t newIndex = -1;
34*da0073e9SAndroid Build Coastguard Worker   int64_t oldIndex = 0;
35*da0073e9SAndroid Build Coastguard Worker   int64_t remappedExcludedDim = -1;
36*da0073e9SAndroid Build Coastguard Worker 
37*da0073e9SAndroid Build Coastguard Worker   while (oldIndex < dims) {
38*da0073e9SAndroid Build Coastguard Worker     // Finds a dimension to collapse into
39*da0073e9SAndroid Build Coastguard Worker     for (; oldIndex < stopDim; ++oldIndex) {
40*da0073e9SAndroid Build Coastguard Worker       if (sizes[oldIndex] == 1) {
41*da0073e9SAndroid Build Coastguard Worker         continue;
42*da0073e9SAndroid Build Coastguard Worker       }
43*da0073e9SAndroid Build Coastguard Worker 
44*da0073e9SAndroid Build Coastguard Worker       ++newIndex;
45*da0073e9SAndroid Build Coastguard Worker       sizes[newIndex] = sizes[oldIndex];
46*da0073e9SAndroid Build Coastguard Worker       strides[newIndex] = strides[oldIndex];
47*da0073e9SAndroid Build Coastguard Worker       ++oldIndex;
48*da0073e9SAndroid Build Coastguard Worker       break;
49*da0073e9SAndroid Build Coastguard Worker     }
50*da0073e9SAndroid Build Coastguard Worker 
51*da0073e9SAndroid Build Coastguard Worker     // Collapses dims
52*da0073e9SAndroid Build Coastguard Worker     for (; oldIndex < stopDim; ++oldIndex) {
53*da0073e9SAndroid Build Coastguard Worker       if (sizes[oldIndex] == 1) {
54*da0073e9SAndroid Build Coastguard Worker         continue;
55*da0073e9SAndroid Build Coastguard Worker       }
56*da0073e9SAndroid Build Coastguard Worker 
57*da0073e9SAndroid Build Coastguard Worker       if (strides[newIndex] == sizes[oldIndex] * strides[oldIndex]) {
58*da0073e9SAndroid Build Coastguard Worker         sizes[newIndex] *= sizes[oldIndex];
59*da0073e9SAndroid Build Coastguard Worker         strides[newIndex] = strides[oldIndex];
60*da0073e9SAndroid Build Coastguard Worker       } else {
61*da0073e9SAndroid Build Coastguard Worker         ++newIndex;
62*da0073e9SAndroid Build Coastguard Worker         sizes[newIndex] = sizes[oldIndex];
63*da0073e9SAndroid Build Coastguard Worker         strides[newIndex] = strides[oldIndex];
64*da0073e9SAndroid Build Coastguard Worker       }
65*da0073e9SAndroid Build Coastguard Worker     }
66*da0073e9SAndroid Build Coastguard Worker 
67*da0073e9SAndroid Build Coastguard Worker     // Handles excludeDim being set (oldIndex == excludeDim)
68*da0073e9SAndroid Build Coastguard Worker     if (oldIndex != dims) {
69*da0073e9SAndroid Build Coastguard Worker       // Preserves excluded dimension
70*da0073e9SAndroid Build Coastguard Worker       ++newIndex;
71*da0073e9SAndroid Build Coastguard Worker       sizes[newIndex] = sizes[oldIndex];
72*da0073e9SAndroid Build Coastguard Worker       strides[newIndex] = strides[oldIndex];
73*da0073e9SAndroid Build Coastguard Worker       remappedExcludedDim = newIndex;
74*da0073e9SAndroid Build Coastguard Worker 
75*da0073e9SAndroid Build Coastguard Worker       // Restarts iteration after excludeDim
76*da0073e9SAndroid Build Coastguard Worker       ++oldIndex;
77*da0073e9SAndroid Build Coastguard Worker       stopDim = dims;
78*da0073e9SAndroid Build Coastguard Worker     }
79*da0073e9SAndroid Build Coastguard Worker   }
80*da0073e9SAndroid Build Coastguard Worker 
81*da0073e9SAndroid Build Coastguard Worker   // Handles special case of all dims size 1
82*da0073e9SAndroid Build Coastguard Worker   if (newIndex == -1 || (newIndex == 0 && sizes[0] == 1)) {
83*da0073e9SAndroid Build Coastguard Worker     dims = 1;
84*da0073e9SAndroid Build Coastguard Worker     sizes[0] = 1;
85*da0073e9SAndroid Build Coastguard Worker     strides[0] = 1;
86*da0073e9SAndroid Build Coastguard Worker 
87*da0073e9SAndroid Build Coastguard Worker     return std::pair<int64_t, int64_t>(0, 1);
88*da0073e9SAndroid Build Coastguard Worker   }
89*da0073e9SAndroid Build Coastguard Worker 
90*da0073e9SAndroid Build Coastguard Worker   dims = newIndex + 1;
91*da0073e9SAndroid Build Coastguard Worker   return std::pair<int64_t, int64_t>(remappedExcludedDim, dims);
92*da0073e9SAndroid Build Coastguard Worker }
93*da0073e9SAndroid Build Coastguard Worker 
94*da0073e9SAndroid Build Coastguard Worker } // namespace at
95