xref: /aosp_15_r20/external/armnn/src/backends/reference/workloads/TensorBufferArrayView.hpp (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
1 //
2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 
6 #pragma once
7 
8 #include <armnn/Tensor.hpp>
9 
10 #include <armnnUtils/DataLayoutIndexed.hpp>
11 
12 #include <armnn/utility/Assert.hpp>
13 
14 namespace armnn
15 {
16 
17 // Utility class providing access to raw tensor memory based on indices along each dimension.
18 template <typename DataType>
19 class TensorBufferArrayView
20 {
21 public:
TensorBufferArrayView(const TensorShape & shape,DataType * data,armnnUtils::DataLayoutIndexed dataLayout=DataLayout::NCHW)22     TensorBufferArrayView(const TensorShape& shape, DataType* data,
23                           armnnUtils::DataLayoutIndexed dataLayout = DataLayout::NCHW)
24         : m_Shape(shape)
25         , m_Data(data)
26         , m_DataLayout(dataLayout)
27     {
28         ARMNN_ASSERT(m_Shape.GetNumDimensions() == 4);
29     }
30 
Get(unsigned int b,unsigned int c,unsigned int h,unsigned int w) const31     DataType& Get(unsigned int b, unsigned int c, unsigned int h, unsigned int w) const
32     {
33         return m_Data[m_DataLayout.GetIndex(m_Shape, b, c, h, w)];
34     }
35 
36 private:
37     const TensorShape             m_Shape;
38     DataType*                     m_Data;
39     armnnUtils::DataLayoutIndexed m_DataLayout;
40 };
41 
42 } //namespace armnn
43