1*da0073e9SAndroid Build Coastguard Worker #include <c10/util/Exception.h> 2*da0073e9SAndroid Build Coastguard Worker #include <utility> 3*da0073e9SAndroid Build Coastguard Worker 4*da0073e9SAndroid Build Coastguard Worker namespace at { 5*da0073e9SAndroid Build Coastguard Worker 6*da0073e9SAndroid Build Coastguard Worker /* 7*da0073e9SAndroid Build Coastguard Worker [collapse dims] Updates sizes, and strides to reflect a "collapse" of 8*da0073e9SAndroid Build Coastguard Worker the info, possibly excluding the optional excludeDim. A "collapsed" version 9*da0073e9SAndroid Build Coastguard Worker of the info is the fewest dims that order the tensor's elements in the same 10*da0073e9SAndroid Build Coastguard Worker way as the original info. If excludeDim is specified, the collapse is the 11*da0073e9SAndroid Build Coastguard Worker fewest dims that order the tensor's elements as the original and preserve the 12*da0073e9SAndroid Build Coastguard Worker excluded dimension, unless the tensor collapses to a point. 13*da0073e9SAndroid Build Coastguard Worker 14*da0073e9SAndroid Build Coastguard Worker This function returns a pair of values. 15*da0073e9SAndroid Build Coastguard Worker 16*da0073e9SAndroid Build Coastguard Worker 1) The (new) index of the preserved dimension if excludeDim is 17*da0073e9SAndroid Build Coastguard Worker specified. 0 if the tensor is collapsed to a point. -1 18*da0073e9SAndroid Build Coastguard Worker otherwise. 19*da0073e9SAndroid Build Coastguard Worker 20*da0073e9SAndroid Build Coastguard Worker 2) The new number of dimensions. 21*da0073e9SAndroid Build Coastguard Worker */ 22*da0073e9SAndroid Build Coastguard Worker template <typename T> 23*da0073e9SAndroid Build Coastguard Worker inline std::pair<int64_t, int64_t> collapse_dims( 24*da0073e9SAndroid Build Coastguard Worker T* sizes, 25*da0073e9SAndroid Build Coastguard Worker T* strides, 26*da0073e9SAndroid Build Coastguard Worker int64_t dims, 27*da0073e9SAndroid Build Coastguard Worker const int excludeDim = -1) { 28*da0073e9SAndroid Build Coastguard Worker TORCH_CHECK( 29*da0073e9SAndroid Build Coastguard Worker excludeDim >= -1 && excludeDim < dims, 30*da0073e9SAndroid Build Coastguard Worker "expected excluded dim between -1 and dims - 1"); 31*da0073e9SAndroid Build Coastguard Worker 32*da0073e9SAndroid Build Coastguard Worker int64_t stopDim = (excludeDim == -1) ? dims : excludeDim; 33*da0073e9SAndroid Build Coastguard Worker int64_t newIndex = -1; 34*da0073e9SAndroid Build Coastguard Worker int64_t oldIndex = 0; 35*da0073e9SAndroid Build Coastguard Worker int64_t remappedExcludedDim = -1; 36*da0073e9SAndroid Build Coastguard Worker 37*da0073e9SAndroid Build Coastguard Worker while (oldIndex < dims) { 38*da0073e9SAndroid Build Coastguard Worker // Finds a dimension to collapse into 39*da0073e9SAndroid Build Coastguard Worker for (; oldIndex < stopDim; ++oldIndex) { 40*da0073e9SAndroid Build Coastguard Worker if (sizes[oldIndex] == 1) { 41*da0073e9SAndroid Build Coastguard Worker continue; 42*da0073e9SAndroid Build Coastguard Worker } 43*da0073e9SAndroid Build Coastguard Worker 44*da0073e9SAndroid Build Coastguard Worker ++newIndex; 45*da0073e9SAndroid Build Coastguard Worker sizes[newIndex] = sizes[oldIndex]; 46*da0073e9SAndroid Build Coastguard Worker strides[newIndex] = strides[oldIndex]; 47*da0073e9SAndroid Build Coastguard Worker ++oldIndex; 48*da0073e9SAndroid Build Coastguard Worker break; 49*da0073e9SAndroid Build Coastguard Worker } 50*da0073e9SAndroid Build Coastguard Worker 51*da0073e9SAndroid Build Coastguard Worker // Collapses dims 52*da0073e9SAndroid Build Coastguard Worker for (; oldIndex < stopDim; ++oldIndex) { 53*da0073e9SAndroid Build Coastguard Worker if (sizes[oldIndex] == 1) { 54*da0073e9SAndroid Build Coastguard Worker continue; 55*da0073e9SAndroid Build Coastguard Worker } 56*da0073e9SAndroid Build Coastguard Worker 57*da0073e9SAndroid Build Coastguard Worker if (strides[newIndex] == sizes[oldIndex] * strides[oldIndex]) { 58*da0073e9SAndroid Build Coastguard Worker sizes[newIndex] *= sizes[oldIndex]; 59*da0073e9SAndroid Build Coastguard Worker strides[newIndex] = strides[oldIndex]; 60*da0073e9SAndroid Build Coastguard Worker } else { 61*da0073e9SAndroid Build Coastguard Worker ++newIndex; 62*da0073e9SAndroid Build Coastguard Worker sizes[newIndex] = sizes[oldIndex]; 63*da0073e9SAndroid Build Coastguard Worker strides[newIndex] = strides[oldIndex]; 64*da0073e9SAndroid Build Coastguard Worker } 65*da0073e9SAndroid Build Coastguard Worker } 66*da0073e9SAndroid Build Coastguard Worker 67*da0073e9SAndroid Build Coastguard Worker // Handles excludeDim being set (oldIndex == excludeDim) 68*da0073e9SAndroid Build Coastguard Worker if (oldIndex != dims) { 69*da0073e9SAndroid Build Coastguard Worker // Preserves excluded dimension 70*da0073e9SAndroid Build Coastguard Worker ++newIndex; 71*da0073e9SAndroid Build Coastguard Worker sizes[newIndex] = sizes[oldIndex]; 72*da0073e9SAndroid Build Coastguard Worker strides[newIndex] = strides[oldIndex]; 73*da0073e9SAndroid Build Coastguard Worker remappedExcludedDim = newIndex; 74*da0073e9SAndroid Build Coastguard Worker 75*da0073e9SAndroid Build Coastguard Worker // Restarts iteration after excludeDim 76*da0073e9SAndroid Build Coastguard Worker ++oldIndex; 77*da0073e9SAndroid Build Coastguard Worker stopDim = dims; 78*da0073e9SAndroid Build Coastguard Worker } 79*da0073e9SAndroid Build Coastguard Worker } 80*da0073e9SAndroid Build Coastguard Worker 81*da0073e9SAndroid Build Coastguard Worker // Handles special case of all dims size 1 82*da0073e9SAndroid Build Coastguard Worker if (newIndex == -1 || (newIndex == 0 && sizes[0] == 1)) { 83*da0073e9SAndroid Build Coastguard Worker dims = 1; 84*da0073e9SAndroid Build Coastguard Worker sizes[0] = 1; 85*da0073e9SAndroid Build Coastguard Worker strides[0] = 1; 86*da0073e9SAndroid Build Coastguard Worker 87*da0073e9SAndroid Build Coastguard Worker return std::pair<int64_t, int64_t>(0, 1); 88*da0073e9SAndroid Build Coastguard Worker } 89*da0073e9SAndroid Build Coastguard Worker 90*da0073e9SAndroid Build Coastguard Worker dims = newIndex + 1; 91*da0073e9SAndroid Build Coastguard Worker return std::pair<int64_t, int64_t>(remappedExcludedDim, dims); 92*da0073e9SAndroid Build Coastguard Worker } 93*da0073e9SAndroid Build Coastguard Worker 94*da0073e9SAndroid Build Coastguard Worker } // namespace at 95