xref: /aosp_15_r20/external/executorch/kernels/portable/cpu/util/transpose_util.h (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1 /*
2  * Copyright (c) Meta Platforms, Inc. and affiliates.
3  * All rights reserved.
4  *
5  * This source code is licensed under the BSD-style license found in the
6  * LICENSE file in the root directory of this source tree.
7  */
8 
9 #pragma once
10 
11 #include <executorch/runtime/kernel/kernel_includes.h>
12 #include <string.h>
13 
14 namespace torch {
15 namespace executor {
16 
17 using SizesType = exec_aten::SizesType;
18 using StridesType = exec_aten::StridesType;
19 
20 /**
21  * Returns a tensor that is a transposed version of input in out.
22  * The given dimensions dim0 and dim1 are swapped.
23  *
24  * @param[in] a the input tensor.
25  * @param[in] dim0 the first dimension to be transposed
26  * @param[in] dim1 the second dimension to be transposed.
27  *
28  */
29 template <typename T>
30 void transpose_tensors(
31     const Tensor& a,
32     int64_t dim0,
33     int64_t dim1,
34     Tensor& out);
35 
36 namespace {
37 /**
38  * Increments an N dimensional index like x[0,0,0] to x[0, 0, 1] to x[0, 0, 2]
39  * to x[0, 1, 0] to x[0, 1, 1] etc...
40  *
41  * @param index An array of the same size as sizes. This stores the "counter"
42  * being incremented.
43  *
44  * @param new_sizes The output tensor dimensions. Allows us to compute the
45  * offset into the input tensor.
46  *
47  * @param non_one_indices A list of indices into index that contain non-1
48  * dimension values. This allows us to eliminate an O(dim) factor from the
49  * runtime in case many dimensions have a value of 1.
50  *
51  * @param new_strides Strides corresponding to new_sizes.
52  *
53  * @param offset The computed offset to index into the input tensor's memory
54  * array.
55  */
increment_index_and_offset(size_t * index,const SizesType * new_sizes,const StridesType * new_strides,const ArrayRef<size_t> non_one_indices,size_t & offset)56 inline void increment_index_and_offset(
57     size_t* index,
58     const SizesType* new_sizes,
59     const StridesType* new_strides,
60     const ArrayRef<size_t> non_one_indices,
61     size_t& offset) {
62   for (size_t j = non_one_indices.size(); j > 0; --j) {
63     const size_t i = non_one_indices[j - 1];
64 
65     index[i]++;
66     // Impossible to happen at i = 0 due to precondition check before this
67     // function is called
68     offset += new_strides[i];
69     if (index[i] == new_sizes[i]) {
70       offset -= new_sizes[i] * new_strides[i];
71       index[i] = 0;
72     } else {
73       return;
74     }
75   }
76 }
77 
78 } // namespace
79 
80 template <typename T>
transpose_tensors(const Tensor & a,int64_t dim0,int64_t dim1,Tensor & out)81 void transpose_tensors(
82     const Tensor& a,
83     int64_t dim0,
84     int64_t dim1,
85     Tensor& out) {
86   auto dim = a.dim();
87   auto data_a = a.const_data_ptr<T>();
88   auto data_out = out.mutable_data_ptr<T>();
89 
90   size_t out_index[kTensorDimensionLimit];
91   memset(out_index, 0, sizeof(out_index));
92 
93   StridesType new_strides[kTensorDimensionLimit];
94   SizesType new_sizes[kTensorDimensionLimit];
95 
96   if (dim != 0) {
97     auto a_strides = a.strides();
98     memcpy(new_strides, a_strides.data(), dim * sizeof(StridesType));
99 
100     auto a_sizes = a.sizes();
101     memcpy(new_sizes, a_sizes.data(), dim * sizeof(SizesType));
102 
103     std::swap(new_sizes[dim0], new_sizes[dim1]);
104     std::swap(new_strides[dim1], new_strides[dim0]);
105   }
106 
107   // non_1_dim_indices stores the indices of the dimensions that have a value
108   // greater than 1. Dimensions can only have a value of 1 or larger.
109   //
110   // This list is stored in the increasing order of the output (not input)
111   // dimension. i.e. lower index of non-1 output dimension first). This
112   // allows us to loop over only the non-1 indices (and skip the ones that
113   // have a value of 1 since they don't contribute to any meaningful computation
114   // in terms of increasing the number of elements to be copied).
115   //
116   // We loop over these non-1 indices in the reverse order since we want to
117   // process the last output dimension first (to be able to walk the input
118   // tensor in output tensor order.
119   size_t non_1_dim_indices[kTensorDimensionLimit];
120   size_t num_non_1_dim_indices = 0;
121   for (size_t cur_dim = 0; cur_dim < dim; cur_dim++) {
122     if (new_sizes[cur_dim] != 1) {
123       non_1_dim_indices[num_non_1_dim_indices++] = cur_dim;
124     }
125   }
126 
127   ArrayRef<size_t> indices(non_1_dim_indices, num_non_1_dim_indices);
128 
129   // Loop over and copy input elements into output
130   size_t a_offset = 0;
131   for (ssize_t out_offset = 0; out_offset < a.numel(); out_offset++) {
132     data_out[out_offset] = data_a[a_offset];
133     increment_index_and_offset(
134         out_index, new_sizes, new_strides, indices, a_offset);
135   }
136 }
137 
check_t_copy_args(const Tensor & in,Tensor & out)138 inline bool check_t_copy_args(const Tensor& in, Tensor& out) {
139   ET_LOG_AND_RETURN_IF_FALSE(tensors_have_same_dtype(in, out));
140   ET_LOG_AND_RETURN_IF_FALSE(tensor_has_rank_smaller_or_equal_to(in, 2));
141   return true;
142 }
143 
check_transpose_copy_args(const Tensor & in,int64_t dim0,int64_t dim1,Tensor & out)144 inline bool check_transpose_copy_args(
145     const Tensor& in,
146     int64_t dim0,
147     int64_t dim1,
148     Tensor& out) {
149   ET_LOG_AND_RETURN_IF_FALSE(tensors_have_same_dtype(in, out));
150   ET_LOG_AND_RETURN_IF_FALSE(tensor_has_dim(in, dim0));
151   ET_LOG_AND_RETURN_IF_FALSE(tensor_has_dim(in, dim1));
152   return true;
153 }
154 
get_transpose_out_target_size(const Tensor & in,SizesType dim0,SizesType dim1,SizesType * out_sizes,size_t * out_ndim)155 inline void get_transpose_out_target_size(
156     const Tensor& in,
157     SizesType dim0,
158     SizesType dim1,
159     SizesType* out_sizes,
160     size_t* out_ndim) {
161   *out_ndim = in.dim();
162 
163   if (in.dim() == 0) {
164     return;
165   }
166 
167   for (size_t i = 0; i < in.dim(); ++i) {
168     out_sizes[i] = in.size(i);
169   }
170   out_sizes[dim0] = in.size(dim1);
171   out_sizes[dim1] = in.size(dim0);
172 }
173 
174 } // namespace executor
175 } // namespace torch
176