xref: /aosp_15_r20/external/tensorflow/tensorflow/core/kernels/transpose_functor.h (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_CORE_KERNELS_TRANSPOSE_FUNCTOR_H_
17 #define TENSORFLOW_CORE_KERNELS_TRANSPOSE_FUNCTOR_H_
18 
19 #include <numeric>
20 #include <string>
21 #include <vector>
22 
23 #include "tensorflow/core/framework/tensor.h"
24 #include "tensorflow/core/framework/tensor_types.h"
25 #include "tensorflow/core/platform/logging.h"
26 
27 namespace tensorflow {
28 // Transpose tensor 'in' into tensor 'out' according to dimension
29 // permutation 'perm'.
30 //
// 'perm[i]' names the input dimension that becomes output dimension 'i' (see
// the last REQUIRES clause below). The result is reported through the returned
// Status. Only the declaration is visible in this header.
// NOTE(review): what happens when a REQUIRES clause is violated depends on the
// per-device definition -- confirm against the implementation file.
//
31 // REQUIRES: in.dtype() == out->dtype()
32 // REQUIRES: in.dims() == out->dims()
33 // REQUIRES: in.dims() == perm.size()
34 // REQUIRES: in.dim_size(perm[i]) == out->dim_size(i)
35 template <typename Device>
36 Status DoTranspose(const Device& device, const Tensor& in,
37                    const gtl::ArraySlice<int32> perm, Tensor* out);
38 
39 // Conjugate and transpose tensor 'in' into tensor 'out' according to dimension
40 // permutation 'perm'.
41 //
// 'perm[i]' names the input dimension that becomes output dimension 'i' (see
// the last REQUIRES clause below). The result is reported through the returned
// Status. Only the declaration is visible in this header.
// NOTE(review): behavior for non-complex dtypes is decided by the per-device
// definition -- confirm against the implementation file.
//
42 // REQUIRES: in.dtype() == out->dtype()
43 // REQUIRES: in.dims() == out->dims()
44 // REQUIRES: in.dims() == perm.size()
45 // REQUIRES: in.dim_size(perm[i]) == out->dim_size(i)
46 template <typename Device>
47 Status DoConjugateTranspose(const Device& device, const Tensor& in,
48                             const gtl::ArraySlice<int32> perm, Tensor* out);
49 
50 // Convenience version of DoTranspose that only swaps the last (inner) two
51 // dimensions.
//
// No permutation argument is taken; the result is reported through the
// returned Status. Only the declaration is visible in this header.
// NOTE(review): handling of inputs with fewer than two dimensions is decided
// by the definition -- confirm against the implementation file.
52 template <typename Device>
53 Status DoMatrixTranspose(const Device& device, const Tensor& in, Tensor* out);
54 
55 // Convenience version of DoConjugateTranspose that only swaps the last (inner)
56 // two dimensions.
//
// No permutation argument is taken; the result is reported through the
// returned Status. Only the declaration is visible in this header.
// NOTE(review): handling of inputs with fewer than two dimensions is decided
// by the definition -- confirm against the implementation file.
57 template <typename Device>
58 Status DoConjugateMatrixTranspose(const Device& device, const Tensor& in,
59                                   Tensor* out);
60 
61 // Primary device specific functor to be specialized for each device and type.
//
// Template parameters:
//   Device:    device type the specialization targets.
//   T:         element type moved by the transpose.
//   conjugate: defaults to false. internal::DoTransposeImpl in this header
//              passes true only for the complex64 / complex128 conjugate
//              paths.
//
// run() takes the same (device, in, perm, out) arguments as DoTranspose but
// returns void. It is declared here without a definition.
// NOTE(review): each specialization presumably supplies the definition --
// confirm where the specializations live.
62 template <typename Device, typename T, bool conjugate = false>
63 struct Transpose {
64   static void run(const Device& d, const Tensor& in,
65                   const gtl::ArraySlice<int32> perm, Tensor* out);
66 };
67 
68 // Implementation details.
69 namespace internal {
70 
71 typedef gtl::InlinedVector<int64_t, 8> TransposeDimsVec;
72 typedef gtl::InlinedVector<int32, 8> TransposePermsVec;
73 
74 // Helper function that takes a tensor shape, a permutation, combines the
75 // neighboring shapes if their indices in the permutation are consecutive.
76 // The function outputs the combined shape and new permutation.
77 // Example: Tensor shape {2, 3, 4, 5, 120} and permutation {0, 4, 1, 2, 3} will
78 // produce new shape {2, 60, 120} and new permutation {0, 2, 1}.
ReduceTransposeDimensions(const TensorShape & shape,gtl::ArraySlice<int32> perm,TransposePermsVec * new_perm,TransposeDimsVec * new_dims)79 inline void ReduceTransposeDimensions(const TensorShape& shape,
80                                       gtl::ArraySlice<int32> perm,
81                                       TransposePermsVec* new_perm,
82                                       TransposeDimsVec* new_dims) {
83   CHECK_EQ(shape.dims(), perm.size());
84   if (shape.dims() == 1) {
85     // If input dimension is already 1, no need to reduce dimension.
86     new_perm->resize(1);
87     (*new_perm)[0] = perm[0];
88     (*new_dims)[0] = shape.dim_size(0);
89     return;
90   }
91   TransposePermsVec new_dim_position(shape.dims(), -1);
92   TransposeDimsVec combined_dims(shape.dims(), 0);
93   int cur_head = perm[0];
94   new_dim_position[cur_head] = 0;
95   combined_dims[0] = shape.dim_size(cur_head);
96   int dim_idx = 0;
97   for (int perm_idx = 1; perm_idx < shape.dims(); ++perm_idx) {
98     // If two indices in permutation are consecutive numbers, combine their
99     // dimensions.
100     if (cur_head + 1 == perm[perm_idx]) {
101       cur_head = perm[perm_idx];
102       combined_dims[dim_idx] *= shape.dim_size(cur_head);
103     } else {
104       // Else start a new dimension.
105       cur_head = perm[perm_idx];
106       dim_idx++;
107       new_dim_position[cur_head] = dim_idx;
108       combined_dims[dim_idx] = shape.dim_size(cur_head);
109     }
110   }
111   // Compact the new permutations and dimension sizes.
112   new_perm->resize(dim_idx + 1);
113   new_dims->resize(dim_idx + 1);
114   dim_idx = 0;
115   for (int i = 0; i < new_dim_position.size(); ++i) {
116     if (new_dim_position[i] >= 0) {
117       int new_perm_idx = new_dim_position[i];
118       (*new_perm)[dim_idx] = new_perm_idx;
119       (*new_dims)[dim_idx] = combined_dims[new_perm_idx];
120       dim_idx++;
121     }
122   }
123 }
124 
125 // If all non-singleton dimensions remain in ascending order, the shuffled
126 // singletons can be transposed by a reshape, saving a memory allocation & copy.
127 // |permutation| must be a permutation of {0, .., input_shape.dims() - 1}.
128 // That is, for all i, 0 <= perm[i] < input_shape.dims().
129 // In practice, this is checked in TransposeOp::Compute prior to calling this
130 // function, and the function sits here to facilitate unit testing.
NonSingletonDimensionsAlign(const TensorShape & input_shape,const std::vector<int32> & permutation)131 inline bool NonSingletonDimensionsAlign(const TensorShape& input_shape,
132                                         const std::vector<int32>& permutation) {
133   int last_nonsingleton_perm_dim = -1;
134   for (int perm_dim : permutation) {
135     if (input_shape.dim_size(perm_dim) == 1) {
136       continue;
137     }
138     if (perm_dim < last_nonsingleton_perm_dim) {
139       return false;
140     }
141     last_nonsingleton_perm_dim = perm_dim;
142   }
143   return true;
144 }
145 
146 // Uses Eigen to transpose.
147 template <typename Device, typename T, int NDIMS>
TransposeUsingEigen(const Device & d,const Tensor & in,const gtl::ArraySlice<int32> perm,bool conjugate,Tensor * out)148 void TransposeUsingEigen(const Device& d, const Tensor& in,
149                          const gtl::ArraySlice<int32> perm, bool conjugate,
150                          Tensor* out) {
151   Eigen::array<int, NDIMS> p;
152   for (int i = 0; i < NDIMS; ++i) p[i] = perm[i];
153   auto x = typename TTypes<T, NDIMS>::ConstTensor(
154       reinterpret_cast<const T*>(in.tensor_data().data()),
155       in.shape().AsEigenDSizes<NDIMS>());
156   auto y = typename TTypes<T, NDIMS>::Tensor(
157       reinterpret_cast<T*>(const_cast<char*>(out->tensor_data().data())),
158       out->shape().AsEigenDSizes<NDIMS>());
159   if (conjugate) {
160     y.device(d) = x.conjugate().shuffle(p);
161   } else {
162     y.device(d) = x.shuffle(p);
163   }
164 }
165 
166 template <typename Device>
DoTransposeImpl(const Device & d,const Tensor & in,const gtl::ArraySlice<int32> perm,bool conjugate,Tensor * out)167 Status DoTransposeImpl(const Device& d, const Tensor& in,
168                        const gtl::ArraySlice<int32> perm, bool conjugate,
169                        Tensor* out) {
170   CHECK_EQ(in.dims(), out->dims());
171   CHECK_EQ(in.dims(), perm.size());
172   CHECK_EQ(in.dtype(), out->dtype());
173   switch (in.dtype()) {
174     case DT_BOOL:
175     case DT_INT8:
176     case DT_QINT8:
177     case DT_QUINT8:
178     case DT_UINT8:
179       Transpose<Device, uint8>::run(d, in, perm, out);
180       break;
181 
182     case DT_BFLOAT16:
183     case DT_HALF:
184     case DT_INT16:
185     case DT_QINT16:
186     case DT_QUINT16:
187     case DT_UINT16:
188       Transpose<Device, uint16>::run(d, in, perm, out);
189       break;
190 
191     case DT_FLOAT:
192     case DT_INT32:
193     case DT_QINT32:
194     case DT_UINT32:
195       Transpose<Device, uint32>::run(d, in, perm, out);
196       break;
197 
198     case DT_DOUBLE:
199     case DT_INT64:
200     case DT_UINT64:
201       Transpose<Device, uint64>::run(d, in, perm, out);
202       break;
203 
204     case DT_COMPLEX64:
205       if (conjugate) {
206 #if defined(__ANDROID__) and !defined(__clang__)
207         // Workaround for GCC compiler bug in Android toolchain.
208         return errors::Unimplemented(
209             "Conjugate transpose of complex64 not supported for GCC on "
210             "Android.");
211 #else
212         Transpose<Device, complex64, /*conjugate=*/true>::run(d, in, perm, out);
213 #endif
214       } else {
215         Transpose<Device, uint64>::run(d, in, perm, out);
216       }
217       break;
218 
219     case DT_COMPLEX128:
220       if (conjugate) {
221         Transpose<Device, complex128, /*conjugate=*/true>::run(d, in, perm,
222                                                                out);
223       } else {
224         Transpose<Device, complex128, /*conjugate=*/false>::run(d, in, perm,
225                                                                 out);
226       }
227       break;
228 
229     case DT_STRING:
230       Transpose<Device, tstring>::run(d, in, perm, out);
231       break;
232 
233     default:
234       return errors::Unimplemented("Unsupported dtype on CPU: ", in.dtype());
235   }
236   return OkStatus();
237 }
238 
239 template <typename Device>
DoMatrixTransposeImpl(const Device & device,const Tensor & in,bool conjugate,Tensor * out)240 inline Status DoMatrixTransposeImpl(const Device& device, const Tensor& in,
241                                     bool conjugate, Tensor* out) {
242   const int ndims = in.dims();
243   if (ndims == 0) return OkStatus();
244   TransposePermsVec perm(ndims);
245   std::iota(perm.begin(), perm.end(), 0);
246   std::swap(perm[ndims - 2], perm[ndims - 1]);
247   return DoTransposeImpl(device, in, perm, conjugate, out);
248 }
249 
250 }  // namespace internal
251 }  // namespace tensorflow
252 
253 #endif  // TENSORFLOW_CORE_KERNELS_TRANSPOSE_FUNCTOR_H_
254