// xref: /aosp_15_r20/external/eigen/unsupported/Eigen/CXX11/src/Tensor/TensorMap.h (revision bf2c37156dfe67e5dfebd6d394bad8b2ab5804d4)
// This file is part of Eigen, a lightweight C++ template library
// for linear algebra.
//
// Copyright (C) 2014 Benoit Steiner <benoit.steiner.goog@gmail.com>
//
// This Source Code Form is subject to the terms of the Mozilla
// Public License v. 2.0. If a copy of the MPL was not distributed
// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.

10 #ifndef EIGEN_CXX11_TENSOR_TENSOR_MAP_H
11 #define EIGEN_CXX11_TENSOR_TENSOR_MAP_H
12 
13 namespace Eigen {
14 
15 // FIXME use proper doxygen documentation (e.g. \tparam MakePointer_)
16 
17 /** \class TensorMap
18   * \ingroup CXX11_Tensor_Module
19   *
20   * \brief A tensor expression mapping an existing array of data.
21   *
22   */
23 /// `template <class> class MakePointer_` is added to convert the host pointer to the device pointer.
24 /// It is added due to the fact that for our device compiler `T*` is not allowed.
25 /// If we wanted to use the same Evaluator functions we have to convert that type to our pointer `T`.
26 /// This is done through our `MakePointer_` class. By default the Type in the `MakePointer_<T>` is `T*` .
27 /// Therefore, by adding the default value, we managed to convert the type and it does not break any
28 /// existing code as its default value is `T*`.
29 template<typename PlainObjectType, int Options_, template <class> class MakePointer_> class TensorMap : public TensorBase<TensorMap<PlainObjectType, Options_, MakePointer_> >
30 {
31   public:
32     typedef TensorMap<PlainObjectType, Options_, MakePointer_> Self;
33     typedef TensorBase<TensorMap<PlainObjectType, Options_, MakePointer_> > Base;
34   #ifdef EIGEN_USE_SYCL
35     typedef  typename Eigen::internal::remove_reference<typename Eigen::internal::nested<Self>::type>::type Nested;
36   #else
37      typedef typename Eigen::internal::nested<Self>::type Nested;
38   #endif
39    typedef typename internal::traits<PlainObjectType>::StorageKind StorageKind;
40     typedef typename internal::traits<PlainObjectType>::Index Index;
41     typedef typename internal::traits<PlainObjectType>::Scalar Scalar;
42     typedef typename NumTraits<Scalar>::Real RealScalar;
43     typedef typename PlainObjectType::Base::CoeffReturnType CoeffReturnType;
44 
45     typedef typename MakePointer_<Scalar>::Type PointerType;
46     typedef typename MakePointer_<Scalar>::ConstType PointerConstType;
47 
48     // WARN: PointerType still can be a pointer to const (const Scalar*), for
49     // example in TensorMap<Tensor<const Scalar, ...>> expression. This type of
50     // expression should be illegal, but adding this restriction is not possible
51     // in practice (see https://bitbucket.org/eigen/eigen/pull-requests/488).
52     typedef typename internal::conditional<
53         bool(internal::is_lvalue<PlainObjectType>::value),
54         PointerType,      // use simple pointer in lvalue expressions
55         PointerConstType  // use const pointer in rvalue expressions
56         >::type StoragePointerType;
57 
58     // If TensorMap was constructed over rvalue expression (e.g. const Tensor),
59     // we should return a reference to const from operator() (and others), even
60     // if TensorMap itself is not const.
61     typedef typename internal::conditional<
62         bool(internal::is_lvalue<PlainObjectType>::value),
63         Scalar&,
64         const Scalar&
65         >::type StorageRefType;
66 
67     static const int Options = Options_;
68 
69     static const Index NumIndices = PlainObjectType::NumIndices;
70     typedef typename PlainObjectType::Dimensions Dimensions;
71 
72     enum {
73       IsAligned = ((int(Options_)&Aligned)==Aligned),
74       Layout = PlainObjectType::Layout,
75       CoordAccess = true,
76       RawAccess = true
77     };
78 
79     EIGEN_DEVICE_FUNC
TensorMap(StoragePointerType dataPtr)80     EIGEN_STRONG_INLINE TensorMap(StoragePointerType dataPtr) : m_data(dataPtr), m_dimensions() {
81       // The number of dimensions used to construct a tensor must be equal to the rank of the tensor.
82       EIGEN_STATIC_ASSERT((0 == NumIndices || NumIndices == Dynamic), YOU_MADE_A_PROGRAMMING_MISTAKE)
83     }
84 
85 #if EIGEN_HAS_VARIADIC_TEMPLATES
86     template<typename... IndexTypes> EIGEN_DEVICE_FUNC
TensorMap(StoragePointerType dataPtr,Index firstDimension,IndexTypes...otherDimensions)87     EIGEN_STRONG_INLINE TensorMap(StoragePointerType dataPtr, Index firstDimension, IndexTypes... otherDimensions) : m_data(dataPtr), m_dimensions(firstDimension, otherDimensions...) {
88       // The number of dimensions used to construct a tensor must be equal to the rank of the tensor.
89       EIGEN_STATIC_ASSERT((sizeof...(otherDimensions) + 1 == NumIndices || NumIndices == Dynamic), YOU_MADE_A_PROGRAMMING_MISTAKE)
90     }
91 #else
92     EIGEN_DEVICE_FUNC
TensorMap(StoragePointerType dataPtr,Index firstDimension)93     EIGEN_STRONG_INLINE TensorMap(StoragePointerType dataPtr, Index firstDimension) : m_data(dataPtr), m_dimensions(firstDimension) {
94       // The number of dimensions used to construct a tensor must be equal to the rank of the tensor.
95       EIGEN_STATIC_ASSERT((1 == NumIndices || NumIndices == Dynamic), YOU_MADE_A_PROGRAMMING_MISTAKE)
96     }
97     EIGEN_DEVICE_FUNC
TensorMap(StoragePointerType dataPtr,Index dim1,Index dim2)98     EIGEN_STRONG_INLINE TensorMap(StoragePointerType dataPtr, Index dim1, Index dim2) : m_data(dataPtr), m_dimensions(dim1, dim2) {
99       EIGEN_STATIC_ASSERT(2 == NumIndices || NumIndices == Dynamic, YOU_MADE_A_PROGRAMMING_MISTAKE)
100     }
101     EIGEN_DEVICE_FUNC
TensorMap(StoragePointerType dataPtr,Index dim1,Index dim2,Index dim3)102     EIGEN_STRONG_INLINE TensorMap(StoragePointerType dataPtr, Index dim1, Index dim2, Index dim3) : m_data(dataPtr), m_dimensions(dim1, dim2, dim3) {
103       EIGEN_STATIC_ASSERT(3 == NumIndices || NumIndices == Dynamic, YOU_MADE_A_PROGRAMMING_MISTAKE)
104     }
105     EIGEN_DEVICE_FUNC
TensorMap(StoragePointerType dataPtr,Index dim1,Index dim2,Index dim3,Index dim4)106     EIGEN_STRONG_INLINE TensorMap(StoragePointerType dataPtr, Index dim1, Index dim2, Index dim3, Index dim4) : m_data(dataPtr), m_dimensions(dim1, dim2, dim3, dim4) {
107       EIGEN_STATIC_ASSERT(4 == NumIndices || NumIndices == Dynamic, YOU_MADE_A_PROGRAMMING_MISTAKE)
108     }
109     EIGEN_DEVICE_FUNC
TensorMap(StoragePointerType dataPtr,Index dim1,Index dim2,Index dim3,Index dim4,Index dim5)110     EIGEN_STRONG_INLINE TensorMap(StoragePointerType dataPtr, Index dim1, Index dim2, Index dim3, Index dim4, Index dim5) : m_data(dataPtr), m_dimensions(dim1, dim2, dim3, dim4, dim5) {
111       EIGEN_STATIC_ASSERT(5 == NumIndices || NumIndices == Dynamic, YOU_MADE_A_PROGRAMMING_MISTAKE)
112     }
113 #endif
114 
TensorMap(StoragePointerType dataPtr,const array<Index,NumIndices> & dimensions)115    EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorMap(StoragePointerType dataPtr, const array<Index, NumIndices>& dimensions)
116       : m_data(dataPtr), m_dimensions(dimensions)
117     { }
118 
119     template <typename Dimensions>
TensorMap(StoragePointerType dataPtr,const Dimensions & dimensions)120     EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorMap(StoragePointerType dataPtr, const Dimensions& dimensions)
121       : m_data(dataPtr), m_dimensions(dimensions)
122     { }
123 
TensorMap(PlainObjectType & tensor)124     EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorMap(PlainObjectType& tensor)
125       : m_data(tensor.data()), m_dimensions(tensor.dimensions())
126     { }
127 
128     EIGEN_DEVICE_FUNC
rank()129     EIGEN_STRONG_INLINE Index rank() const { return m_dimensions.rank(); }
130     EIGEN_DEVICE_FUNC
dimension(Index n)131     EIGEN_STRONG_INLINE Index dimension(Index n) const { return m_dimensions[n]; }
132     EIGEN_DEVICE_FUNC
dimensions()133     EIGEN_STRONG_INLINE const Dimensions& dimensions() const { return m_dimensions; }
134     EIGEN_DEVICE_FUNC
size()135     EIGEN_STRONG_INLINE Index size() const { return m_dimensions.TotalSize(); }
136     EIGEN_DEVICE_FUNC
data()137     EIGEN_STRONG_INLINE StoragePointerType data() { return m_data; }
138     EIGEN_DEVICE_FUNC
data()139     EIGEN_STRONG_INLINE StoragePointerType data() const { return m_data; }
140 
141     EIGEN_DEVICE_FUNC
operator()142     EIGEN_STRONG_INLINE StorageRefType operator()(const array<Index, NumIndices>& indices) const
143     {
144       //      eigen_assert(checkIndexRange(indices));
145       if (PlainObjectType::Options&RowMajor) {
146         const Index index = m_dimensions.IndexOfRowMajor(indices);
147         return m_data[index];
148       } else {
149         const Index index = m_dimensions.IndexOfColMajor(indices);
150         return m_data[index];
151       }
152     }
153 
154     EIGEN_DEVICE_FUNC
operator()155     EIGEN_STRONG_INLINE StorageRefType operator()() const
156     {
157       EIGEN_STATIC_ASSERT(NumIndices == 0, YOU_MADE_A_PROGRAMMING_MISTAKE)
158       return m_data[0];
159     }
160 
161     EIGEN_DEVICE_FUNC
operator()162     EIGEN_STRONG_INLINE StorageRefType operator()(Index index) const
163     {
164       eigen_internal_assert(index >= 0 && index < size());
165       return m_data[index];
166     }
167 
168 #if EIGEN_HAS_VARIADIC_TEMPLATES
169     template<typename... IndexTypes> EIGEN_DEVICE_FUNC
operator()170     EIGEN_STRONG_INLINE StorageRefType operator()(Index firstIndex, Index secondIndex, IndexTypes... otherIndices) const
171     {
172       EIGEN_STATIC_ASSERT(sizeof...(otherIndices) + 2 == NumIndices, YOU_MADE_A_PROGRAMMING_MISTAKE)
173       eigen_assert(internal::all((Eigen::NumTraits<Index>::highest() >= otherIndices)...));
174       if (PlainObjectType::Options&RowMajor) {
175         const Index index = m_dimensions.IndexOfRowMajor(array<Index, NumIndices>{{firstIndex, secondIndex, otherIndices...}});
176         return m_data[index];
177       } else {
178         const Index index = m_dimensions.IndexOfColMajor(array<Index, NumIndices>{{firstIndex, secondIndex, otherIndices...}});
179         return m_data[index];
180       }
181     }
182 #else
183     EIGEN_DEVICE_FUNC
operator()184     EIGEN_STRONG_INLINE StorageRefType operator()(Index i0, Index i1) const
185     {
186       if (PlainObjectType::Options&RowMajor) {
187         const Index index = i1 + i0 * m_dimensions[1];
188         return m_data[index];
189       } else {
190         const Index index = i0 + i1 * m_dimensions[0];
191         return m_data[index];
192       }
193     }
194     EIGEN_DEVICE_FUNC
operator()195     EIGEN_STRONG_INLINE StorageRefType operator()(Index i0, Index i1, Index i2) const
196     {
197       if (PlainObjectType::Options&RowMajor) {
198          const Index index = i2 + m_dimensions[2] * (i1 + m_dimensions[1] * i0);
199          return m_data[index];
200       } else {
201          const Index index = i0 + m_dimensions[0] * (i1 + m_dimensions[1] * i2);
202         return m_data[index];
203       }
204     }
205     EIGEN_DEVICE_FUNC
operator()206     EIGEN_STRONG_INLINE StorageRefType operator()(Index i0, Index i1, Index i2, Index i3) const
207     {
208       if (PlainObjectType::Options&RowMajor) {
209         const Index index = i3 + m_dimensions[3] * (i2 + m_dimensions[2] * (i1 + m_dimensions[1] * i0));
210         return m_data[index];
211       } else {
212         const Index index = i0 + m_dimensions[0] * (i1 + m_dimensions[1] * (i2 + m_dimensions[2] * i3));
213         return m_data[index];
214       }
215     }
216     EIGEN_DEVICE_FUNC
operator()217     EIGEN_STRONG_INLINE StorageRefType operator()(Index i0, Index i1, Index i2, Index i3, Index i4) const
218     {
219       if (PlainObjectType::Options&RowMajor) {
220         const Index index = i4 + m_dimensions[4] * (i3 + m_dimensions[3] * (i2 + m_dimensions[2] * (i1 + m_dimensions[1] * i0)));
221         return m_data[index];
222       } else {
223         const Index index = i0 + m_dimensions[0] * (i1 + m_dimensions[1] * (i2 + m_dimensions[2] * (i3 + m_dimensions[3] * i4)));
224         return m_data[index];
225       }
226     }
227 #endif
228 
229     EIGEN_DEVICE_FUNC
operator()230     EIGEN_STRONG_INLINE StorageRefType operator()(const array<Index, NumIndices>& indices)
231     {
232       //      eigen_assert(checkIndexRange(indices));
233       if (PlainObjectType::Options&RowMajor) {
234         const Index index = m_dimensions.IndexOfRowMajor(indices);
235         return m_data[index];
236       } else {
237         const Index index = m_dimensions.IndexOfColMajor(indices);
238         return m_data[index];
239       }
240     }
241 
242     EIGEN_DEVICE_FUNC
operator()243     EIGEN_STRONG_INLINE StorageRefType operator()()
244     {
245       EIGEN_STATIC_ASSERT(NumIndices == 0, YOU_MADE_A_PROGRAMMING_MISTAKE)
246       return m_data[0];
247     }
248 
249     EIGEN_DEVICE_FUNC
operator()250     EIGEN_STRONG_INLINE StorageRefType operator()(Index index)
251     {
252       eigen_internal_assert(index >= 0 && index < size());
253       return m_data[index];
254     }
255 
256 #if EIGEN_HAS_VARIADIC_TEMPLATES
257     template<typename... IndexTypes> EIGEN_DEVICE_FUNC
operator()258     EIGEN_STRONG_INLINE StorageRefType operator()(Index firstIndex, Index secondIndex, IndexTypes... otherIndices)
259     {
260       static_assert(sizeof...(otherIndices) + 2 == NumIndices || NumIndices == Dynamic, "Number of indices used to access a tensor coefficient must be equal to the rank of the tensor.");
261        eigen_assert(internal::all((Eigen::NumTraits<Index>::highest() >= otherIndices)...));
262       const std::size_t NumDims = sizeof...(otherIndices) + 2;
263       if (PlainObjectType::Options&RowMajor) {
264         const Index index = m_dimensions.IndexOfRowMajor(array<Index, NumDims>{{firstIndex, secondIndex, otherIndices...}});
265         return m_data[index];
266       } else {
267         const Index index = m_dimensions.IndexOfColMajor(array<Index, NumDims>{{firstIndex, secondIndex, otherIndices...}});
268         return m_data[index];
269       }
270     }
271 #else
272     EIGEN_DEVICE_FUNC
operator()273     EIGEN_STRONG_INLINE StorageRefType operator()(Index i0, Index i1)
274     {
275        if (PlainObjectType::Options&RowMajor) {
276          const Index index = i1 + i0 * m_dimensions[1];
277         return m_data[index];
278       } else {
279         const Index index = i0 + i1 * m_dimensions[0];
280         return m_data[index];
281       }
282     }
283     EIGEN_DEVICE_FUNC
operator()284     EIGEN_STRONG_INLINE StorageRefType operator()(Index i0, Index i1, Index i2)
285     {
286        if (PlainObjectType::Options&RowMajor) {
287          const Index index = i2 + m_dimensions[2] * (i1 + m_dimensions[1] * i0);
288         return m_data[index];
289       } else {
290          const Index index = i0 + m_dimensions[0] * (i1 + m_dimensions[1] * i2);
291         return m_data[index];
292       }
293     }
294     EIGEN_DEVICE_FUNC
operator()295     EIGEN_STRONG_INLINE StorageRefType operator()(Index i0, Index i1, Index i2, Index i3)
296     {
297       if (PlainObjectType::Options&RowMajor) {
298         const Index index = i3 + m_dimensions[3] * (i2 + m_dimensions[2] * (i1 + m_dimensions[1] * i0));
299         return m_data[index];
300       } else {
301         const Index index = i0 + m_dimensions[0] * (i1 + m_dimensions[1] * (i2 + m_dimensions[2] * i3));
302         return m_data[index];
303       }
304     }
305     EIGEN_DEVICE_FUNC
operator()306     EIGEN_STRONG_INLINE StorageRefType operator()(Index i0, Index i1, Index i2, Index i3, Index i4)
307     {
308       if (PlainObjectType::Options&RowMajor) {
309         const Index index = i4 + m_dimensions[4] * (i3 + m_dimensions[3] * (i2 + m_dimensions[2] * (i1 + m_dimensions[1] * i0)));
310         return m_data[index];
311       } else {
312         const Index index = i0 + m_dimensions[0] * (i1 + m_dimensions[1] * (i2 + m_dimensions[2] * (i3 + m_dimensions[3] * i4)));
313         return m_data[index];
314       }
315     }
316 #endif
317 
318     EIGEN_TENSOR_INHERIT_ASSIGNMENT_OPERATORS(TensorMap)
319 
320   private:
321     StoragePointerType m_data;
322     Dimensions m_dimensions;
323 };
324 
325 } // end namespace Eigen
326 
327 #endif // EIGEN_CXX11_TENSOR_TENSOR_MAP_H
328