xref: /aosp_15_r20/external/ComputeLibrary/tests/validation/fixtures/SpaceToBatchFixture.h (revision c217d954acce2dbc11938adb493fc0abd69584f3)
1 /*
2  * Copyright (c) 2018-2021 Arm Limited.
3  *
4  * SPDX-License-Identifier: MIT
5  *
6  * Permission is hereby granted, free of charge, to any person obtaining a copy
7  * of this software and associated documentation files (the "Software"), to
8  * deal in the Software without restriction, including without limitation the
9  * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10  * sell copies of the Software, and to permit persons to whom the Software is
11  * furnished to do so, subject to the following conditions:
12  *
13  * The above copyright notice and this permission notice shall be included in all
14  * copies or substantial portions of the Software.
15  *
16  * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17  * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18  * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19  * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20  * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21  * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22  * SOFTWARE.
23  */
24 #ifndef ARM_COMPUTE_TEST_SPACE_TO_BATCH_LAYER_FIXTURE
25 #define ARM_COMPUTE_TEST_SPACE_TO_BATCH_LAYER_FIXTURE
26 
27 #include "tests/Globals.h"
28 #include "tests/framework/Asserts.h"
29 #include "tests/framework/Fixture.h"
30 #include "tests/validation/reference/SpaceToBatch.h"
31 
32 namespace arm_compute
33 {
34 namespace test
35 {
36 namespace validation
37 {
38 template <typename TensorType, typename AccessorType, typename FunctionType, typename T>
39 class SpaceToBatchLayerValidationGenericFixture : public framework::Fixture
40 {
41 public:
42     template <typename...>
setup(TensorShape input_shape,TensorShape block_shape_shape,TensorShape paddings_shape,TensorShape output_shape,DataType data_type,DataLayout data_layout,QuantizationInfo quantization_info)43     void setup(TensorShape input_shape, TensorShape block_shape_shape, TensorShape paddings_shape, TensorShape output_shape,
44                DataType data_type, DataLayout data_layout, QuantizationInfo quantization_info)
45     {
46         _target    = compute_target(input_shape, block_shape_shape, paddings_shape, output_shape, data_type, data_layout, quantization_info);
47         _reference = compute_reference(input_shape, block_shape_shape, paddings_shape, output_shape, data_type, quantization_info);
48     }
49 
50 protected:
51     template <typename U>
fill(U && tensor,int i)52     void fill(U &&tensor, int i)
53     {
54         library->fill_tensor_uniform(tensor, i);
55     }
56 
57     template <typename U>
fill_pad(U && tensor)58     void fill_pad(U &&tensor)
59     {
60         library->fill_tensor_value(tensor, 0);
61     }
62 
compute_target(TensorShape input_shape,const TensorShape & block_shape_shape,const TensorShape & paddings_shape,TensorShape output_shape,DataType data_type,DataLayout data_layout,QuantizationInfo quantization_info)63     TensorType compute_target(TensorShape input_shape, const TensorShape &block_shape_shape, const TensorShape &paddings_shape, TensorShape output_shape,
64                               DataType data_type, DataLayout data_layout, QuantizationInfo quantization_info)
65     {
66         if(data_layout == DataLayout::NHWC)
67         {
68             permute(input_shape, PermutationVector(2U, 0U, 1U));
69             permute(output_shape, PermutationVector(2U, 0U, 1U));
70         }
71 
72         // Create tensors
73         TensorType input       = create_tensor<TensorType>(input_shape, data_type, 1, quantization_info, data_layout);
74         TensorType block_shape = create_tensor<TensorType>(block_shape_shape, DataType::S32);
75         TensorType paddings    = create_tensor<TensorType>(paddings_shape, DataType::S32);
76         TensorType output      = create_tensor<TensorType>(output_shape, data_type, 1, quantization_info, data_layout);
77 
78         // Create and configure function
79         FunctionType space_to_batch;
80         space_to_batch.configure(&input, &block_shape, &paddings, &output);
81 
82         ARM_COMPUTE_ASSERT(input.info()->is_resizable());
83         ARM_COMPUTE_ASSERT(block_shape.info()->is_resizable());
84         ARM_COMPUTE_ASSERT(paddings.info()->is_resizable());
85         ARM_COMPUTE_ASSERT(output.info()->is_resizable());
86 
87         // Allocate tensors
88         input.allocator()->allocate();
89         block_shape.allocator()->allocate();
90         paddings.allocator()->allocate();
91         output.allocator()->allocate();
92 
93         ARM_COMPUTE_ASSERT(!input.info()->is_resizable());
94         ARM_COMPUTE_ASSERT(!block_shape.info()->is_resizable());
95         ARM_COMPUTE_ASSERT(!paddings.info()->is_resizable());
96         ARM_COMPUTE_ASSERT(!output.info()->is_resizable());
97 
98         // Fill tensors
99         fill(AccessorType(input), 0);
100         fill_pad(AccessorType(paddings));
101         {
102             auto      block_shape_data = AccessorType(block_shape);
103             const int idx_width        = get_data_layout_dimension_index(data_layout, DataLayoutDimension::WIDTH);
104             for(unsigned int i = 0; i < block_shape_shape.x(); ++i)
105             {
106                 static_cast<int32_t *>(block_shape_data.data())[i] = input_shape[i + idx_width] / output_shape[i + idx_width];
107             }
108         }
109         // Compute function
110         space_to_batch.run();
111 
112         return output;
113     }
114 
compute_reference(const TensorShape & input_shape,const TensorShape & block_shape_shape,const TensorShape & paddings_shape,const TensorShape & output_shape,DataType data_type,QuantizationInfo quantization_info)115     SimpleTensor<T> compute_reference(const TensorShape &input_shape, const TensorShape &block_shape_shape, const TensorShape &paddings_shape,
116                                       const TensorShape &output_shape, DataType data_type, QuantizationInfo quantization_info)
117     {
118         // Create reference
119         SimpleTensor<T>       input{ input_shape, data_type, 1, quantization_info };
120         SimpleTensor<int32_t> block_shape{ block_shape_shape, DataType::S32 };
121         SimpleTensor<int32_t> paddings{ paddings_shape, DataType::S32 };
122 
123         // Fill reference
124         fill(input, 0);
125         fill_pad(paddings);
126         for(unsigned int i = 0; i < block_shape_shape.x(); ++i)
127         {
128             block_shape[i] = input_shape[i] / output_shape[i];
129         }
130 
131         // Compute reference
132         return reference::space_to_batch(input, block_shape, paddings, output_shape);
133     }
134 
135     TensorType      _target{};
136     SimpleTensor<T> _reference{};
137 };
138 
139 template <typename TensorType, typename AccessorType, typename FunctionType, typename T>
140 class SpaceToBatchLayerValidationFixture : public SpaceToBatchLayerValidationGenericFixture<TensorType, AccessorType, FunctionType, T>
141 {
142 public:
143     template <typename...>
setup(TensorShape input_shape,TensorShape block_shape_shape,TensorShape paddings_shape,TensorShape output_shape,DataType data_type,DataLayout data_layout)144     void setup(TensorShape input_shape, TensorShape block_shape_shape, TensorShape paddings_shape, TensorShape output_shape,
145                DataType data_type, DataLayout data_layout)
146     {
147         SpaceToBatchLayerValidationGenericFixture<TensorType, AccessorType, FunctionType, T>::setup(input_shape, block_shape_shape, paddings_shape, output_shape, data_type, data_layout, QuantizationInfo());
148     }
149 };
150 
151 template <typename TensorType, typename AccessorType, typename FunctionType, typename T>
152 class SpaceToBatchLayerValidationQuantizedFixture : public SpaceToBatchLayerValidationGenericFixture<TensorType, AccessorType, FunctionType, T>
153 {
154 public:
155     template <typename...>
setup(TensorShape input_shape,TensorShape block_shape_shape,TensorShape paddings_shape,TensorShape output_shape,DataType data_type,DataLayout data_layout,QuantizationInfo quantization_info)156     void setup(TensorShape input_shape, TensorShape block_shape_shape, TensorShape paddings_shape, TensorShape output_shape,
157                DataType data_type, DataLayout data_layout, QuantizationInfo quantization_info)
158     {
159         SpaceToBatchLayerValidationGenericFixture<TensorType, AccessorType, FunctionType, T>::setup(input_shape, block_shape_shape, paddings_shape, output_shape, data_type, data_layout, quantization_info);
160     }
161 };
162 } // namespace validation
163 } // namespace test
164 } // namespace arm_compute
165 #endif /* ARM_COMPUTE_TEST_SPACE_TO_BATCH_LAYER_FIXTURE */
166