1 /*
2  * Copyright (C) 2018 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #define LOG_TAG "Operations"
18 
19 #include "RoiPooling.h"
20 
21 #include <algorithm>
22 #include <cfloat>
23 #include <cmath>
24 #include <vector>
25 
26 #include "OperationResolver.h"
27 #include "OperationsExecutionUtils.h"
28 #include "Tracing.h"
29 
30 #ifdef NN_INCLUDE_CPU_IMPLEMENTATION
31 #include "CpuOperationUtils.h"
32 #endif  // NN_INCLUDE_CPU_IMPLEMENTATION
33 
34 namespace android {
35 namespace nn {
36 namespace roi_pooling {
37 
38 #ifdef NN_INCLUDE_CPU_IMPLEMENTATION
39 namespace {
40 
41 template <typename T_Input, typename T_Roi>
roiPoolingNhwc(const T_Input * inputData,const Shape & inputShape,const T_Roi * roiData,const Shape & roiShape,const int32_t * batchSplitData,const Shape &,float heightStride,float widthStride,T_Input * outputData,const Shape & outputShape)42 inline bool roiPoolingNhwc(const T_Input* inputData, const Shape& inputShape, const T_Roi* roiData,
43                            const Shape& roiShape, const int32_t* batchSplitData,
44                            const Shape& /*batchSplitShape*/, float heightStride, float widthStride,
45                            T_Input* outputData, const Shape& outputShape) {
46     NNTRACE_TRANS("RoiPooling");
47 
48     const uint32_t kRoiDim = 4;
49     const T_Roi heightScale = 1.0f / heightStride;
50     const T_Roi widthScale = 1.0f / widthStride;
51 
52     uint32_t numBatches = getSizeOfDimension(inputShape, 0);
53     uint32_t inHeight = getSizeOfDimension(inputShape, 1);
54     uint32_t inWidth = getSizeOfDimension(inputShape, 2);
55     uint32_t inDepth = getSizeOfDimension(inputShape, 3);
56     uint32_t outHeight = getSizeOfDimension(outputShape, 1);
57     uint32_t outWidth = getSizeOfDimension(outputShape, 2);
58     uint32_t numRois = getSizeOfDimension(roiShape, 0);
59     uint32_t roiInfoLength = getSizeOfDimension(roiShape, 1);
60 
61     T_Input* outPtr = outputData;
62     const T_Roi* roiDataEnd = roiData + numRois * roiInfoLength;
63     uint32_t roiIndex = 0;
64     for (const T_Roi* roiInfo = roiData; roiInfo < roiDataEnd; roiInfo += kRoiDim, roiIndex++) {
65         uint32_t batchId = batchSplitData[roiIndex];
66         // Check for malformed data
67         // 1. invalid batch id
68         // 2. Region out of bound: x1|x2|y1|y2 < 0 || x1|x2 > inWidth || y1|y2 > inHeight
69         // 3. Invalid region: x2 < x1 || y2 < y1
70         NN_RET_CHECK_GE(batchId, 0u);
71         NN_RET_CHECK_LT(batchId, numBatches);
72         NN_RET_CHECK(roiInfo[0] >= 0);
73         NN_RET_CHECK(roiInfo[1] >= 0);
74         NN_RET_CHECK(roiInfo[2] >= 0);
75         NN_RET_CHECK(roiInfo[3] >= 0);
76         NN_RET_CHECK(roiInfo[0] * widthScale <= inWidth);
77         NN_RET_CHECK(roiInfo[1] * heightScale <= inHeight);
78         NN_RET_CHECK(roiInfo[2] * widthScale <= inWidth);
79         NN_RET_CHECK(roiInfo[3] * heightScale <= inHeight);
80         NN_RET_CHECK(roiInfo[0] <= roiInfo[2]);
81         NN_RET_CHECK(roiInfo[1] <= roiInfo[3]);
82 
83         int32_t wRoiStart = std::round(static_cast<float>(roiInfo[0] * widthScale));
84         int32_t hRoiStart = std::round(static_cast<float>(roiInfo[1] * heightScale));
85         int32_t wRoiEnd = std::round(static_cast<float>(roiInfo[2] * widthScale));
86         int32_t hRoiEnd = std::round(static_cast<float>(roiInfo[3] * heightScale));
87 
88         // Rois with width/height < 1 are considered malformed and are forced to be 1
89         T_Roi roiWidth = static_cast<T_Roi>(std::max(wRoiEnd - wRoiStart + 1, 1));
90         T_Roi roiHeight = static_cast<T_Roi>(std::max(hRoiEnd - hRoiStart + 1, 1));
91         T_Roi wStepSize = roiWidth / static_cast<T_Roi>(outWidth);
92         T_Roi hStepSize = roiHeight / static_cast<T_Roi>(outHeight);
93 
94         const T_Input* batchBase = inputData + batchId * inHeight * inWidth * inDepth;
95         for (uint32_t i = 0; i < outHeight; i++) {
96             for (uint32_t j = 0; j < outWidth; j++) {
97                 // Take floor on start, ceil on end, start included, end excluded, i.e. [start, end)
98                 // end is guaranteed to larger than start by at least 1
99                 uint32_t wStart = std::floor(static_cast<float>(wStepSize * j + wRoiStart));
100                 uint32_t wEnd = std::ceil(static_cast<float>(wStepSize * (j + 1) + wRoiStart));
101                 uint32_t hStart = std::floor(static_cast<float>(hStepSize * i + hRoiStart));
102                 uint32_t hEnd = std::ceil(static_cast<float>(hStepSize * (i + 1) + hRoiStart));
103 
104                 wStart = std::min(wStart, inWidth);
105                 wEnd = std::min(wEnd, inWidth);
106                 hStart = std::min(hStart, inHeight);
107                 hEnd = std::min(hEnd, inHeight);
108 
109                 for (uint32_t k = 0; k < inDepth; k++) {
110                     T_Input maxValue = static_cast<T_Input>(inputShape.offset);
111                     bool first = true;
112                     for (uint32_t h = hStart; h < hEnd; h++) {
113                         for (uint32_t w = wStart; w < wEnd; w++) {
114                             T_Input inputValue = batchBase[h * inWidth * inDepth + w * inDepth + k];
115                             if (first || inputValue > maxValue) {
116                                 maxValue = inputValue;
117                                 first = false;
118                             }
119                         }
120                     }
121                     outPtr[k] = maxValue;
122                 }
123                 outPtr += inDepth;
124             }
125         }
126     }
127     return true;
128 }
129 
130 template <typename T_Input, typename T_Roi>
roiPooling(const T_Input * inputData,const Shape & inputShape,const T_Roi * roiData,const Shape & roiShape,const int32_t * batchSplitData,const Shape & batchSplitShape,float heightStride,float widthStride,bool useNchw,T_Input * outputData,const Shape & outputShape)131 inline bool roiPooling(const T_Input* inputData, const Shape& inputShape, const T_Roi* roiData,
132                        const Shape& roiShape, const int32_t* batchSplitData,
133                        const Shape& batchSplitShape, float heightStride, float widthStride,
134                        bool useNchw, T_Input* outputData, const Shape& outputShape) {
135     InputWithLayout<T_Input> input(useNchw);
136     OutputWithLayout<T_Input> output(useNchw);
137     NN_RET_CHECK(input.initialize(inputData, inputShape));
138     NN_RET_CHECK(output.initialize(outputData, outputShape));
139     NN_RET_CHECK(roiPoolingNhwc(input.getNhwcBuffer(), input.getNhwcShape(), roiData, roiShape,
140                                 batchSplitData, batchSplitShape, heightStride, widthStride,
141                                 output.getNhwcBuffer(), output.getNhwcShape()));
142     NN_RET_CHECK(output.commit());
143     return true;
144 }
145 
146 template <>
roiPooling(const uint8_t * inputData,const Shape & inputShape,const uint16_t * roiData,const Shape & roiShape,const int32_t * batchSplitData,const Shape & batchSplitShape,float heightStride,float widthStride,bool useNchw,uint8_t * outputData,const Shape & outputShape)147 inline bool roiPooling<uint8_t, uint16_t>(const uint8_t* inputData, const Shape& inputShape,
148                                           const uint16_t* roiData, const Shape& roiShape,
149                                           const int32_t* batchSplitData,
150                                           const Shape& batchSplitShape, float heightStride,
151                                           float widthStride, bool useNchw, uint8_t* outputData,
152                                           const Shape& outputShape) {
153     std::vector<float> roi_float32(getNumberOfElements(roiShape));
154     convertQuantToFloat32(roiData, roiShape.scale, roiShape.offset, &roi_float32);
155     NN_RET_CHECK(roiPooling(inputData, inputShape, roi_float32.data(), roiShape, batchSplitData,
156                             batchSplitShape, heightStride, widthStride, useNchw, outputData,
157                             outputShape));
158     return true;
159 }
160 
161 template <>
roiPooling(const int8_t * inputData,const Shape & inputShape,const uint16_t * roiData,const Shape & roiShape,const int32_t * batchSplitData,const Shape & batchSplitShape,float heightStride,float widthStride,bool useNchw,int8_t * outputData,const Shape & outputShape)162 inline bool roiPooling<int8_t, uint16_t>(const int8_t* inputData, const Shape& inputShape,
163                                          const uint16_t* roiData, const Shape& roiShape,
164                                          const int32_t* batchSplitData,
165                                          const Shape& batchSplitShape, float heightStride,
166                                          float widthStride, bool useNchw, int8_t* outputData,
167                                          const Shape& outputShape) {
168     std::vector<float> roi_float32(getNumberOfElements(roiShape));
169     convertQuantToFloat32(roiData, roiShape.scale, roiShape.offset, &roi_float32);
170     NN_RET_CHECK(roiPooling(inputData, inputShape, roi_float32.data(), roiShape, batchSplitData,
171                             batchSplitShape, heightStride, widthStride, useNchw, outputData,
172                             outputShape));
173     return true;
174 }
175 
176 }  // namespace
177 
prepare(IOperationExecutionContext * context)178 bool prepare(IOperationExecutionContext* context) {
179     bool useNchw = context->getInputValue<bool>(kLayoutScalar);
180     Shape input = context->getInputShape(kInputTensor);
181     Shape roiShape = context->getInputShape(kRoiTensor);
182     Shape batchSplitShape = context->getInputShape(kBatchSplitTensor);
183     NN_RET_CHECK_EQ(getNumberOfDimensions(input), 4u);
184     NN_RET_CHECK_EQ(getNumberOfDimensions(roiShape), 2u);
185 
186     [[maybe_unused]] uint32_t numBatches = getSizeOfDimension(input, 0);
187     [[maybe_unused]] uint32_t inHeight = getSizeOfDimension(input, useNchw ? 2 : 1);
188     [[maybe_unused]] uint32_t inWidth = getSizeOfDimension(input, useNchw ? 3 : 2);
189     uint32_t inDepth = getSizeOfDimension(input, useNchw ? 1 : 3);
190     uint32_t numRois = getSizeOfDimension(roiShape, 0);
191     NN_RET_CHECK_EQ(getSizeOfDimension(roiShape, 1), 4u);
192     NN_RET_CHECK_EQ(getSizeOfDimension(batchSplitShape, 0), numRois);
193 
194     auto outputHeight = context->getInputValue<int32_t>(kOutputHeightScalar);
195     auto outputWidth = context->getInputValue<int32_t>(kOutputWidthScalar);
196     float heightStride, widthStride;
197     if (context->getInputType(kInputTensor) == OperandType::TENSOR_FLOAT16) {
198         heightStride = context->getInputValue<_Float16>(kHeightStrideSalar);
199         widthStride = context->getInputValue<_Float16>(kWidthStrideScalar);
200     } else {
201         heightStride = context->getInputValue<float>(kHeightStrideSalar);
202         widthStride = context->getInputValue<float>(kWidthStrideScalar);
203     }
204     NN_RET_CHECK_GT(outputHeight, 0);
205     NN_RET_CHECK_GT(outputWidth, 0);
206     NN_RET_CHECK_GT(heightStride, 0);
207     NN_RET_CHECK_GT(widthStride, 0);
208 
209     if (roiShape.type == OperandType::TENSOR_QUANT16_ASYMM) {
210         NN_RET_CHECK_EQ(roiShape.scale, 0.125f);
211         NN_RET_CHECK_EQ(roiShape.offset, 0);
212     }
213 
214     Shape output = input;
215     if (useNchw) {
216         output.dimensions = {numRois, inDepth, static_cast<uint32_t>(outputHeight),
217                              static_cast<uint32_t>(outputWidth)};
218     } else {
219         output.dimensions = {numRois, static_cast<uint32_t>(outputHeight),
220                              static_cast<uint32_t>(outputWidth), inDepth};
221     }
222     return context->setOutputShape(kOutputTensor, output);
223 }
224 
execute(IOperationExecutionContext * context)225 bool execute(IOperationExecutionContext* context) {
226     switch (context->getInputType(kInputTensor)) {
227         case OperandType::TENSOR_FLOAT16:
228             return roiPooling(context->getInputBuffer<_Float16>(kInputTensor),
229                               context->getInputShape(kInputTensor),
230                               context->getInputBuffer<_Float16>(kRoiTensor),
231                               context->getInputShape(kRoiTensor),
232                               context->getInputBuffer<int32_t>(kBatchSplitTensor),
233                               context->getInputShape(kBatchSplitTensor),
234                               context->getInputValue<_Float16>(kHeightStrideSalar),
235                               context->getInputValue<_Float16>(kWidthStrideScalar),
236                               context->getInputValue<bool>(kLayoutScalar),
237                               context->getOutputBuffer<_Float16>(kOutputTensor),
238                               context->getOutputShape(kOutputTensor));
239         case OperandType::TENSOR_FLOAT32:
240             return roiPooling(context->getInputBuffer<float>(kInputTensor),
241                               context->getInputShape(kInputTensor),
242                               context->getInputBuffer<float>(kRoiTensor),
243                               context->getInputShape(kRoiTensor),
244                               context->getInputBuffer<int32_t>(kBatchSplitTensor),
245                               context->getInputShape(kBatchSplitTensor),
246                               context->getInputValue<float>(kHeightStrideSalar),
247                               context->getInputValue<float>(kWidthStrideScalar),
248                               context->getInputValue<bool>(kLayoutScalar),
249                               context->getOutputBuffer<float>(kOutputTensor),
250                               context->getOutputShape(kOutputTensor));
251         case OperandType::TENSOR_QUANT8_ASYMM:
252             return roiPooling(context->getInputBuffer<uint8_t>(kInputTensor),
253                               context->getInputShape(kInputTensor),
254                               context->getInputBuffer<uint16_t>(kRoiTensor),
255                               context->getInputShape(kRoiTensor),
256                               context->getInputBuffer<int32_t>(kBatchSplitTensor),
257                               context->getInputShape(kBatchSplitTensor),
258                               context->getInputValue<float>(kHeightStrideSalar),
259                               context->getInputValue<float>(kWidthStrideScalar),
260                               context->getInputValue<bool>(kLayoutScalar),
261                               context->getOutputBuffer<uint8_t>(kOutputTensor),
262                               context->getOutputShape(kOutputTensor));
263         case OperandType::TENSOR_QUANT8_ASYMM_SIGNED:
264             return roiPooling(context->getInputBuffer<int8_t>(kInputTensor),
265                               context->getInputShape(kInputTensor),
266                               context->getInputBuffer<uint16_t>(kRoiTensor),
267                               context->getInputShape(kRoiTensor),
268                               context->getInputBuffer<int32_t>(kBatchSplitTensor),
269                               context->getInputShape(kBatchSplitTensor),
270                               context->getInputValue<float>(kHeightStrideSalar),
271                               context->getInputValue<float>(kWidthStrideScalar),
272                               context->getInputValue<bool>(kLayoutScalar),
273                               context->getOutputBuffer<int8_t>(kOutputTensor),
274                               context->getOutputShape(kOutputTensor));
275         default:
276             NN_RET_CHECK_FAIL() << "Unsupported tensor type for operation " << kOperationName;
277     }
278 }
279 #endif  // NN_INCLUDE_CPU_IMPLEMENTATION
280 
281 }  // namespace roi_pooling
282 
283 NN_REGISTER_OPERATION_DEFAULT_VALIDATION(ROI_POOLING, roi_pooling::prepare, roi_pooling::execute);
284 
285 }  // namespace nn
286 }  // namespace android
287