1*89c4ff92SAndroid Build Coastguard Worker //
2*89c4ff92SAndroid Build Coastguard Worker // Copyright © 2019 Arm Ltd. All rights reserved.
3*89c4ff92SAndroid Build Coastguard Worker // SPDX-License-Identifier: MIT
4*89c4ff92SAndroid Build Coastguard Worker //
5*89c4ff92SAndroid Build Coastguard Worker
6*89c4ff92SAndroid Build Coastguard Worker #pragma once
7*89c4ff92SAndroid Build Coastguard Worker
8*89c4ff92SAndroid Build Coastguard Worker #include <ResolveType.hpp>
9*89c4ff92SAndroid Build Coastguard Worker
10*89c4ff92SAndroid Build Coastguard Worker
11*89c4ff92SAndroid Build Coastguard Worker #include <armnnUtils/QuantizeHelper.hpp>
12*89c4ff92SAndroid Build Coastguard Worker
13*89c4ff92SAndroid Build Coastguard Worker #include <armnnTestUtils/DataLayoutUtils.hpp>
14*89c4ff92SAndroid Build Coastguard Worker
15*89c4ff92SAndroid Build Coastguard Worker namespace
16*89c4ff92SAndroid Build Coastguard Worker {
17*89c4ff92SAndroid Build Coastguard Worker
CreateDepthToSpaceNetwork(const armnn::TensorInfo & inputInfo,const armnn::TensorInfo & outputInfo,const armnn::DepthToSpaceDescriptor & descriptor)18*89c4ff92SAndroid Build Coastguard Worker armnn::INetworkPtr CreateDepthToSpaceNetwork(const armnn::TensorInfo& inputInfo,
19*89c4ff92SAndroid Build Coastguard Worker const armnn::TensorInfo& outputInfo,
20*89c4ff92SAndroid Build Coastguard Worker const armnn::DepthToSpaceDescriptor& descriptor)
21*89c4ff92SAndroid Build Coastguard Worker {
22*89c4ff92SAndroid Build Coastguard Worker using namespace armnn;
23*89c4ff92SAndroid Build Coastguard Worker
24*89c4ff92SAndroid Build Coastguard Worker INetworkPtr network(INetwork::Create());
25*89c4ff92SAndroid Build Coastguard Worker
26*89c4ff92SAndroid Build Coastguard Worker IConnectableLayer* input = network->AddInputLayer(0, "input");
27*89c4ff92SAndroid Build Coastguard Worker IConnectableLayer* depthToSpace = network->AddDepthToSpaceLayer(descriptor, "depthToSpace");
28*89c4ff92SAndroid Build Coastguard Worker IConnectableLayer* output = network->AddOutputLayer(0, "output");
29*89c4ff92SAndroid Build Coastguard Worker
30*89c4ff92SAndroid Build Coastguard Worker Connect(input, depthToSpace, inputInfo, 0, 0);
31*89c4ff92SAndroid Build Coastguard Worker Connect(depthToSpace, output, outputInfo, 0, 0);
32*89c4ff92SAndroid Build Coastguard Worker
33*89c4ff92SAndroid Build Coastguard Worker return network;
34*89c4ff92SAndroid Build Coastguard Worker }
35*89c4ff92SAndroid Build Coastguard Worker
36*89c4ff92SAndroid Build Coastguard Worker template<armnn::DataType ArmnnType, typename T = armnn::ResolveType<ArmnnType>>
DepthToSpaceEndToEndImpl(const std::vector<armnn::BackendId> & backends,const DepthToSpaceDescriptor & descriptor,const armnn::TensorShape & nhwcInputShape,const armnn::TensorShape & nhwcOutputShape,const std::vector<float> & floatInputData,const std::vector<float> & floatExpectedOutputData)37*89c4ff92SAndroid Build Coastguard Worker void DepthToSpaceEndToEndImpl(const std::vector<armnn::BackendId>& backends,
38*89c4ff92SAndroid Build Coastguard Worker const DepthToSpaceDescriptor& descriptor,
39*89c4ff92SAndroid Build Coastguard Worker const armnn::TensorShape& nhwcInputShape,
40*89c4ff92SAndroid Build Coastguard Worker const armnn::TensorShape& nhwcOutputShape,
41*89c4ff92SAndroid Build Coastguard Worker const std::vector<float>& floatInputData,
42*89c4ff92SAndroid Build Coastguard Worker const std::vector<float>& floatExpectedOutputData)
43*89c4ff92SAndroid Build Coastguard Worker {
44*89c4ff92SAndroid Build Coastguard Worker using namespace armnn;
45*89c4ff92SAndroid Build Coastguard Worker
46*89c4ff92SAndroid Build Coastguard Worker TensorInfo inputInfo(nhwcInputShape, ArmnnType);
47*89c4ff92SAndroid Build Coastguard Worker inputInfo.SetConstant(true);
48*89c4ff92SAndroid Build Coastguard Worker TensorInfo outputInfo(nhwcOutputShape, ArmnnType);
49*89c4ff92SAndroid Build Coastguard Worker
50*89c4ff92SAndroid Build Coastguard Worker constexpr float qScale = 0.25f;
51*89c4ff92SAndroid Build Coastguard Worker constexpr int32_t qOffset = 128;
52*89c4ff92SAndroid Build Coastguard Worker
53*89c4ff92SAndroid Build Coastguard Worker // Set quantization parameters for quantized types
54*89c4ff92SAndroid Build Coastguard Worker if (IsQuantizedType<T>())
55*89c4ff92SAndroid Build Coastguard Worker {
56*89c4ff92SAndroid Build Coastguard Worker inputInfo.SetQuantizationScale(qScale);
57*89c4ff92SAndroid Build Coastguard Worker inputInfo.SetQuantizationOffset(qOffset);
58*89c4ff92SAndroid Build Coastguard Worker outputInfo.SetQuantizationScale(qScale);
59*89c4ff92SAndroid Build Coastguard Worker outputInfo.SetQuantizationOffset(qOffset);
60*89c4ff92SAndroid Build Coastguard Worker }
61*89c4ff92SAndroid Build Coastguard Worker
62*89c4ff92SAndroid Build Coastguard Worker std::vector<T> inputData = armnnUtils::QuantizedVector<T>(floatInputData, qScale, qOffset);
63*89c4ff92SAndroid Build Coastguard Worker std::vector<T> expectedOutputData = armnnUtils::QuantizedVector<T>(floatExpectedOutputData, qScale, qOffset);
64*89c4ff92SAndroid Build Coastguard Worker
65*89c4ff92SAndroid Build Coastguard Worker // Permute tensors from NHWC to NCHW (if needed)
66*89c4ff92SAndroid Build Coastguard Worker if (descriptor.m_DataLayout == DataLayout::NCHW)
67*89c4ff92SAndroid Build Coastguard Worker {
68*89c4ff92SAndroid Build Coastguard Worker PermuteTensorNhwcToNchw(inputInfo, inputData);
69*89c4ff92SAndroid Build Coastguard Worker PermuteTensorNhwcToNchw(outputInfo, expectedOutputData);
70*89c4ff92SAndroid Build Coastguard Worker }
71*89c4ff92SAndroid Build Coastguard Worker
72*89c4ff92SAndroid Build Coastguard Worker INetworkPtr network = CreateDepthToSpaceNetwork(inputInfo, outputInfo, descriptor);
73*89c4ff92SAndroid Build Coastguard Worker EndToEndLayerTestImpl<ArmnnType, ArmnnType>(std::move(network),
74*89c4ff92SAndroid Build Coastguard Worker { { 0, inputData } },
75*89c4ff92SAndroid Build Coastguard Worker { { 0, expectedOutputData } },
76*89c4ff92SAndroid Build Coastguard Worker backends);
77*89c4ff92SAndroid Build Coastguard Worker }
78*89c4ff92SAndroid Build Coastguard Worker
79*89c4ff92SAndroid Build Coastguard Worker } // anonymous namespace
80*89c4ff92SAndroid Build Coastguard Worker
81*89c4ff92SAndroid Build Coastguard Worker template<armnn::DataType ArmnnType>
DepthToSpaceEndToEnd(const std::vector<armnn::BackendId> & defaultBackends,armnn::DataLayout dataLayout)82*89c4ff92SAndroid Build Coastguard Worker void DepthToSpaceEndToEnd(const std::vector<armnn::BackendId>& defaultBackends,
83*89c4ff92SAndroid Build Coastguard Worker armnn::DataLayout dataLayout)
84*89c4ff92SAndroid Build Coastguard Worker {
85*89c4ff92SAndroid Build Coastguard Worker using namespace armnn;
86*89c4ff92SAndroid Build Coastguard Worker
87*89c4ff92SAndroid Build Coastguard Worker TensorShape inputShape = { 2, 2, 2, 4 };
88*89c4ff92SAndroid Build Coastguard Worker TensorShape outputShape = { 2, 4, 4, 1 };
89*89c4ff92SAndroid Build Coastguard Worker
90*89c4ff92SAndroid Build Coastguard Worker std::vector<float> inputData =
91*89c4ff92SAndroid Build Coastguard Worker {
92*89c4ff92SAndroid Build Coastguard Worker 1.f, 2.f, 3.f, 4.f,
93*89c4ff92SAndroid Build Coastguard Worker 5.f, 6.f, 7.f, 8.f,
94*89c4ff92SAndroid Build Coastguard Worker 9.f, 10.f, 11.f, 12.f,
95*89c4ff92SAndroid Build Coastguard Worker 13.f, 14.f, 15.f, 16.f,
96*89c4ff92SAndroid Build Coastguard Worker
97*89c4ff92SAndroid Build Coastguard Worker 17.f, 18.f, 19.f, 20.f,
98*89c4ff92SAndroid Build Coastguard Worker 21.f, 22.f, 23.f, 24.f,
99*89c4ff92SAndroid Build Coastguard Worker 25.f, 26.f, 27.f, 28.f,
100*89c4ff92SAndroid Build Coastguard Worker 29.f, 30.f, 31.f, 32.f
101*89c4ff92SAndroid Build Coastguard Worker };
102*89c4ff92SAndroid Build Coastguard Worker
103*89c4ff92SAndroid Build Coastguard Worker std::vector<float> expectedOutputData =
104*89c4ff92SAndroid Build Coastguard Worker {
105*89c4ff92SAndroid Build Coastguard Worker 1.f, 2.f, 5.f, 6.f,
106*89c4ff92SAndroid Build Coastguard Worker 3.f, 4.f, 7.f, 8.f,
107*89c4ff92SAndroid Build Coastguard Worker 9.f, 10.f, 13.f, 14.f,
108*89c4ff92SAndroid Build Coastguard Worker 11.f, 12.f, 15.f, 16.f,
109*89c4ff92SAndroid Build Coastguard Worker
110*89c4ff92SAndroid Build Coastguard Worker 17.f, 18.f, 21.f, 22.f,
111*89c4ff92SAndroid Build Coastguard Worker 19.f, 20.f, 23.f, 24.f,
112*89c4ff92SAndroid Build Coastguard Worker 25.f, 26.f, 29.f, 30.f,
113*89c4ff92SAndroid Build Coastguard Worker 27.f, 28.f, 31.f, 32.f
114*89c4ff92SAndroid Build Coastguard Worker };
115*89c4ff92SAndroid Build Coastguard Worker
116*89c4ff92SAndroid Build Coastguard Worker DepthToSpaceEndToEndImpl<ArmnnType>(defaultBackends,
117*89c4ff92SAndroid Build Coastguard Worker DepthToSpaceDescriptor(2, dataLayout),
118*89c4ff92SAndroid Build Coastguard Worker inputShape,
119*89c4ff92SAndroid Build Coastguard Worker outputShape,
120*89c4ff92SAndroid Build Coastguard Worker inputData,
121*89c4ff92SAndroid Build Coastguard Worker expectedOutputData);
122*89c4ff92SAndroid Build Coastguard Worker }
123