xref: /aosp_15_r20/external/armnn/src/armnnUtils/DataLayoutIndexed.cpp (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
1 //
2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 
6 #include <armnnUtils/DataLayoutIndexed.hpp>
7 
8 using namespace armnn;
9 
10 namespace armnnUtils
11 {
12 
DataLayoutIndexed(armnn::DataLayout dataLayout)13 DataLayoutIndexed::DataLayoutIndexed(armnn::DataLayout dataLayout)
14     : m_DataLayout(dataLayout)
15 {
16     switch (dataLayout)
17     {
18         case armnn::DataLayout::NHWC:
19             m_ChannelsIndex = 3;
20             m_HeightIndex   = 1;
21             m_WidthIndex    = 2;
22             break;
23         case armnn::DataLayout::NCHW:
24             m_ChannelsIndex = 1;
25             m_HeightIndex   = 2;
26             m_WidthIndex    = 3;
27             break;
28         case armnn::DataLayout::NDHWC:
29             m_DepthIndex    = 1;
30             m_HeightIndex   = 2;
31             m_WidthIndex    = 3;
32             m_ChannelsIndex = 4;
33             break;
34         case armnn::DataLayout::NCDHW:
35             m_ChannelsIndex = 1;
36             m_DepthIndex    = 2;
37             m_HeightIndex   = 3;
38             m_WidthIndex    = 4;
39             break;
40         default:
41             throw armnn::InvalidArgumentException("Unknown DataLayout value: " +
42                                                   std::to_string(static_cast<int>(dataLayout)));
43     }
44 }
45 
operator ==(const DataLayout & dataLayout,const DataLayoutIndexed & indexed)46 bool operator==(const DataLayout& dataLayout, const DataLayoutIndexed& indexed)
47 {
48     return dataLayout == indexed.GetDataLayout();
49 }
50 
operator ==(const DataLayoutIndexed & indexed,const DataLayout & dataLayout)51 bool operator==(const DataLayoutIndexed& indexed, const DataLayout& dataLayout)
52 {
53     return indexed.GetDataLayout() == dataLayout;
54 }
55 
56 } // namespace armnnUtils
57