1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/lite/delegates/gpu/common/operations.h"
17
18 #include <algorithm>
19 #include <cstdint>
20 #include <set>
21 #include <string>
22 #include <unordered_map>
23 #include <utility>
24 #include <vector>
25
26 #include "tensorflow/lite/delegates/gpu/common/shape.h"
27 #include "tensorflow/lite/delegates/gpu/common/status.h"
28 #include "tensorflow/lite/delegates/gpu/common/tensor.h"
29
30 namespace tflite {
31 namespace gpu {
32
// Member-wise copy assignment; self-assignment is harmless here.
Padding2D& Padding2D::operator=(const Padding2D& value) {
  this->prepended = value.prepended;
  this->appended = value.appended;
  return *this;
}
38
operator ==(const Padding2D & value)39 bool Padding2D::operator==(const Padding2D& value) {
40 return this->prepended == value.prepended && this->appended == value.appended;
41 }
42
operator !=(const Padding2D & value)43 bool Padding2D::operator!=(const Padding2D& value) { return !(*this == value); }
44
// NOTE(review): despite the operator- spelling, this mutates *this in place
// and returns it — i.e. it has operator-= semantics. Signature is fixed by
// the header, so it is kept as-is; callers rely on the in-place behavior.
Padding2D& Padding2D::operator-(const Padding2D& value) {
  // Shrink both borders element-wise by `value`.
  prepended.h -= value.prepended.h;
  prepended.w -= value.prepended.w;
  appended.h -= value.appended.h;
  appended.w -= value.appended.w;
  return *this;
}
52
// Member-wise copy assignment; self-assignment is harmless here.
Padding3D& Padding3D::operator=(const Padding3D& value) {
  this->prepended = value.prepended;
  this->appended = value.appended;
  return *this;
}
58
operator ==(const Padding3D & value)59 bool Padding3D::operator==(const Padding3D& value) {
60 return this->prepended == value.prepended && this->appended == value.appended;
61 }
62
operator !=(const Padding3D & value)63 bool Padding3D::operator!=(const Padding3D& value) { return !(*this == value); }
64
// NOTE(review): despite the operator- spelling, this mutates *this in place
// and returns it — i.e. it has operator-= semantics. Signature is fixed by
// the header, so it is kept as-is; callers rely on the in-place behavior.
Padding3D& Padding3D::operator-(const Padding3D& value) {
  // Shrink both borders element-wise by `value`.
  prepended.h -= value.prepended.h;
  prepended.w -= value.prepended.w;
  prepended.d -= value.prepended.d;
  appended.h -= value.appended.h;
  appended.w -= value.appended.w;
  appended.d -= value.appended.d;
  return *this;
}
74
ToString(enum OperationType op)75 std::string ToString(enum OperationType op) {
76 switch (op) {
77 case OperationType::ABS:
78 return "abs";
79 case OperationType::ADD:
80 return "add";
81 case OperationType::BATCH_NORMALIZATION:
82 return "batch_normalization";
83 case OperationType::BATCH_TO_SPACE:
84 return "batch_to_space";
85 case OperationType::BATCHED_MATMUL:
86 return "batched_matmul";
87 case OperationType::CAST:
88 return "cast";
89 case OperationType::CONCAT:
90 return "concat";
91 case OperationType::CONSTANT:
92 return "const";
93 case OperationType::CONVOLUTION_2D:
94 return "convolution_2d";
95 case OperationType::CONVOLUTION_TRANSPOSED:
96 return "convolution_transposed";
97 case OperationType::COPY:
98 return "copy";
99 case OperationType::COS:
100 return "cos";
101 case OperationType::CUMSUM:
102 return "cumsum";
103 case OperationType::DENSIFY:
104 return "densify";
105 case OperationType::DEPTHWISE_CONVOLUTION:
106 return "depthwise_convolution";
107 case OperationType::DEPTH_TO_SPACE:
108 return "depth_to_space";
109 case OperationType::DIV:
110 return "div";
111 case OperationType::ELU:
112 return "elu";
113 case OperationType::EQUAL:
114 return "equal";
115 case OperationType::EXP:
116 return "exp";
117 case OperationType::FLOOR:
118 return "floor";
119 case OperationType::FLOOR_DIV:
120 return "floor_div";
121 case OperationType::FLOOR_MOD:
122 return "floor_mod";
123 case OperationType::FULLY_CONNECTED:
124 return "fully_connected";
125 case OperationType::FULLY_CONNECTED_INT8:
126 return "fully_connected_int8";
127 case OperationType::GATHER:
128 return "gather";
129 case OperationType::GREATER:
130 return "greater";
131 case OperationType::GREATER_EQUAL:
132 return "greater_equal";
133 case OperationType::HARD_SWISH:
134 return "hard_swish";
135 case OperationType::LESS:
136 return "less";
137 case OperationType::LESS_EQUAL:
138 return "less_equal";
139 case OperationType::LOG:
140 return "log";
141 case OperationType::LSTM:
142 return "lstm";
143 case OperationType::MAXIMUM:
144 return "maximum";
145 case OperationType::MAX_UNPOOLING_2D:
146 return "max_unpooling";
147 case OperationType::MEAN:
148 return "mean";
149 case OperationType::MEAN_STDDEV_NORMALIZATION:
150 return "mean_stddev_normalization";
151 case OperationType::MINIMUM:
152 return "minimum";
153 case OperationType::MUL:
154 return "mul";
155 case OperationType::NEG:
156 return "neg";
157 case OperationType::NOT_EQUAL:
158 return "not_equal";
159 case OperationType::ONE_HOT:
160 return "one_hot";
161 case OperationType::PAD:
162 return "pad";
163 case OperationType::POOLING_2D:
164 return "pooling_2d";
165 case OperationType::POW:
166 return "pow";
167 case OperationType::PRELU:
168 return "prelu";
169 case OperationType::QUANTIZE_AND_DEQUANTIZE:
170 return "quantize_and_dequantize";
171 case OperationType::REDUCE_MAXIMUM:
172 return "reduce_maximum";
173 case OperationType::REDUCE_MINIMUM:
174 return "reduce_minimum";
175 case OperationType::REDUCE_PRODUCT:
176 return "reduce_product";
177 case OperationType::REDUCE_SUM:
178 return "reduce_sum";
179 case OperationType::RELU:
180 return "relu";
181 case OperationType::RESAMPLER:
182 return "resampler";
183 case OperationType::RESHAPE:
184 return "reshape";
185 case OperationType::RESIZE:
186 return "resize";
187 case OperationType::RSQRT:
188 return "rsqrt";
189 case OperationType::SELECT_V2:
190 return "select_v2";
191 case OperationType::SIGMOID:
192 return "sigmoid";
193 case OperationType::SIN:
194 return "sin";
195 case OperationType::SLICE:
196 return "slice";
197 case OperationType::SOFTMAX:
198 return "softmax";
199 case OperationType::SPACE_TO_BATCH:
200 return "space_to_batch";
201 case OperationType::SPACE_TO_DEPTH:
202 return "space_to_depth";
203 case OperationType::SPLIT:
204 return "split";
205 case OperationType::SQRT:
206 return "sqrt";
207 case OperationType::SQUARE:
208 return "square";
209 case OperationType::SQUARED_DIFF:
210 return "squared_diff";
211 case OperationType::SUB:
212 return "subtract";
213 case OperationType::TANH:
214 return "tanh";
215 case OperationType::TILE:
216 return "tile";
217 case OperationType::TRANSPOSE:
218 return "transpose";
219 case OperationType::UNKNOWN:
220 return "unknown_operation";
221 }
222 }
223
OperationTypeFromString(const std::string & name)224 OperationType OperationTypeFromString(const std::string& name) {
225 static const auto* operations =
226 new std::unordered_map<std::string, OperationType>({
227 {"abs", OperationType::ABS},
228 {"add", OperationType::ADD},
229 {"batch_normalization", OperationType::BATCH_NORMALIZATION},
230 {"batched_matmul", OperationType::BATCHED_MATMUL},
231 {"cast", OperationType::CAST},
232 {"concat", OperationType::CONCAT},
233 {"const", OperationType::CONSTANT},
234 {"convolution_2d", OperationType::CONVOLUTION_2D},
235 {"convolution_transposed", OperationType::CONVOLUTION_TRANSPOSED},
236 {"copy", OperationType::COPY},
237 {"cos", OperationType::COS},
238 {"cumsum", OperationType::CUMSUM},
239 {"densify", OperationType::DENSIFY},
240 {"depthwise_convolution", OperationType::DEPTHWISE_CONVOLUTION},
241 {"depth_to_space", OperationType::DEPTH_TO_SPACE},
242 {"div", OperationType::DIV},
243 {"elu", OperationType::ELU},
244 {"equal", OperationType::EQUAL},
245 {"exp", OperationType::EXP},
246 {"floor", OperationType::FLOOR},
247 {"floor_div", OperationType::FLOOR_DIV},
248 {"floor_mod", OperationType::FLOOR_MOD},
249 {"fully_connected", OperationType::FULLY_CONNECTED},
250 {"fully_connected_int8", OperationType::FULLY_CONNECTED_INT8},
251 {"gather", OperationType::GATHER},
252 {"greater", OperationType::GREATER},
253 {"greater_equal", OperationType::GREATER_EQUAL},
254 {"hard_swish", OperationType::HARD_SWISH},
255 {"less", OperationType::LESS},
256 {"less_equal", OperationType::LESS_EQUAL},
257 {"log", OperationType::LOG},
258 {"lstm", OperationType::LSTM},
259 {"maximum", OperationType::MAXIMUM},
260 {"max_unpooling", OperationType::MAX_UNPOOLING_2D},
261 {"mean", OperationType::MEAN},
262 {"mean_stddev_normalization",
263 OperationType::MEAN_STDDEV_NORMALIZATION},
264 {"minimum", OperationType::MINIMUM},
265 {"mul", OperationType::MUL},
266 {"neg", OperationType::NEG},
267 {"not_equal", OperationType::NOT_EQUAL},
268 {"one_hot", OperationType::ONE_HOT},
269 {"pad", OperationType::PAD},
270 {"pooling_2d", OperationType::POOLING_2D},
271 {"pow", OperationType::POW},
272 {"prelu", OperationType::PRELU},
273 {"quantize_and_dequantize", OperationType::QUANTIZE_AND_DEQUANTIZE},
274 {"reduce_maximum", OperationType::REDUCE_MAXIMUM},
275 {"reduce_minimum", OperationType::REDUCE_MINIMUM},
276 {"reduce_product", OperationType::REDUCE_PRODUCT},
277 {"reduce_sum", OperationType::REDUCE_SUM},
278 {"relu", OperationType::RELU},
279 {"resampler", OperationType::RESAMPLER},
280 {"resize", OperationType::RESIZE},
281 {"reshape", OperationType::RESHAPE},
282 {"rsqrt", OperationType::RSQRT},
283 {"select_v2", OperationType::SELECT_V2},
284 {"sigmoid", OperationType::SIGMOID},
285 {"sin", OperationType::SIN},
286 {"slice", OperationType::SLICE},
287 {"softmax", OperationType::SOFTMAX},
288 {"space_to_depth", OperationType::SPACE_TO_DEPTH},
289 {"split", OperationType::SPLIT},
290 {"sqrt", OperationType::SQRT},
291 {"square", OperationType::SQUARE},
292 {"squared_diff", OperationType::SQUARED_DIFF},
293 {"subtract", OperationType::SUB},
294 {"tanh", OperationType::TANH},
295 {"tile", OperationType::TILE},
296 {"transpose", OperationType::TRANSPOSE},
297 });
298 auto op = operations->find(name);
299 return op == operations->end() ? OperationType::UNKNOWN : op->second;
300 }
301
302 namespace {
303
// Ceiling integer division: smallest integer >= n / divisor for n >= 1.
// NOTE(review): for n == 0 the formula yields 1, not 0 — callers pass
// positive sizes, so this is never hit in practice.
template <typename T>
T DivideRoundUp(T n, T divisor) {
  const T floored = (n - 1) / divisor;
  return floored + 1;
}
308
// Output extent of a sliding-window op before stride subsampling:
// input + total_padding - effective_kernel + 1.
int32_t CalculateOutputSizeBeforeStrides(int32_t input, int32_t kernel,
                                         int32_t padding, int32_t dilation) {
  // Effective kernel extent once dilation gaps are inserted between taps.
  const int32_t effective_kernel = dilation * (kernel - 1) + 1;
  return input - effective_kernel + padding + 1;
}
314
315 template <Axis T>
CalculateOutputWithoutStrides(const BHWC & input,const Convolution2DAttributes & attr)316 int32_t CalculateOutputWithoutStrides(const BHWC& input,
317 const Convolution2DAttributes& attr) {
318 return CalculateOutputSizeBeforeStrides(
319 input.get<T>(), attr.weights.shape.get<T>(),
320 attr.padding.prepended.get<T>() + attr.padding.appended.get<T>(),
321 attr.dilations.get<T>());
322 }
323
324 template <Axis T>
CalculateOutputWithoutStrides(const BHWDC & input,const Convolution3DAttributes & attr)325 int32_t CalculateOutputWithoutStrides(const BHWDC& input,
326 const Convolution3DAttributes& attr) {
327 return CalculateOutputSizeBeforeStrides(
328 input.get<T>(), attr.weights.shape.get<T>(),
329 attr.padding.prepended.get<T>() + attr.padding.appended.get<T>(),
330 attr.dilations.get<T>());
331 }
332
333 template <Axis T>
CalculateOutputWithoutStrides(const BHWC & input,const Pooling2DAttributes & attr)334 int32_t CalculateOutputWithoutStrides(const BHWC& input,
335 const Pooling2DAttributes& attr) {
336 return CalculateOutputSizeBeforeStrides(
337 input.get<T>(), attr.kernel.get<T>(),
338 attr.padding.prepended.get<T>() + attr.padding.appended.get<T>(),
339 /*dilation=*/1);
340 }
341
342 template <Axis T>
CalculateOutputWithoutStrides(const BHWDC & input,const Pooling3DAttributes & attr)343 int32_t CalculateOutputWithoutStrides(const BHWDC& input,
344 const Pooling3DAttributes& attr) {
345 return CalculateOutputSizeBeforeStrides(
346 input.get<T>(), attr.kernel.get<T>(),
347 attr.padding.prepended.get<T>() + attr.padding.appended.get<T>(),
348 /*dilation=*/1);
349 }
350
351 template <Axis T>
CalculateOutput(const BHWC & input,const ConvolutionTransposedAttributes & attr)352 int32_t CalculateOutput(const BHWC& input,
353 const ConvolutionTransposedAttributes& attr) {
354 return (input.get<T>() - 1) * attr.stride.get<T>() -
355 (attr.padding.prepended.get<T>() + attr.padding.appended.get<T>()) +
356 attr.weights.shape.get<T>() + attr.adjacent.get<T>();
357 }
358
359 template <Axis T>
CalculateOutput(const BHWDC & input,const ConvolutionTransposed3DAttributes & attr)360 int32_t CalculateOutput(const BHWDC& input,
361 const ConvolutionTransposed3DAttributes& attr) {
362 return (input.get<T>() - 1) * attr.stride.get<T>() -
363 (attr.padding.prepended.get<T>() + attr.padding.appended.get<T>()) +
364 attr.weights.shape.get<T>();
365 }
366
StridedSize(int32_t size,int32_t stride)367 inline int32_t StridedSize(int32_t size, int32_t stride) {
368 return stride == 0 ? -1 : DivideRoundUp(size, stride);
369 }
370
371 template <Axis AxisT, typename AttrT>
CalculateOutput(const BHWC & input,const AttrT & attr)372 int32_t CalculateOutput(const BHWC& input, const AttrT& attr) {
373 return StridedSize(CalculateOutputWithoutStrides<AxisT>(input, attr),
374 attr.strides.template get<AxisT>());
375 }
376
377 template <Axis AxisT, typename AttrT>
CalculateOutput(const BHWDC & input,const AttrT & attr)378 int32_t CalculateOutput(const BHWDC& input, const AttrT& attr) {
379 return StridedSize(CalculateOutputWithoutStrides<AxisT>(input, attr),
380 attr.strides.template get<AxisT>());
381 }
382
// Total SAME padding for one axis: the amount needed so the strided output
// covers the whole input (clamped at zero).
int32_t CalculateSamePadding(int32_t input, int32_t kernel, int32_t dilation,
                             int32_t stride) {
  const int32_t effective_kernel = dilation * (kernel - 1) + 1;
  const int32_t remainder = (input - 1) % stride;
  return std::max(0, effective_kernel - remainder - 1);
}
388
389 // Returns a padding that should be present to make sure image size stays
390 // the same.
391 template <Axis AxisT>
CalculateSamePadding(const BHWC & input,const Convolution2DAttributes & attr)392 int32_t CalculateSamePadding(const BHWC& input,
393 const Convolution2DAttributes& attr) {
394 return CalculateSamePadding(
395 input.get<AxisT>(), attr.weights.shape.get<AxisT>(),
396 attr.dilations.get<AxisT>(), attr.strides.get<AxisT>());
397 }
398
399 // Returns a padding that should be present to make sure image size stays
400 // the same.
401 template <Axis AxisT>
CalculateSamePadding(const BHWDC & input,const Convolution3DAttributes & attr)402 int32_t CalculateSamePadding(const BHWDC& input,
403 const Convolution3DAttributes& attr) {
404 return CalculateSamePadding(
405 input.get<AxisT>(), attr.weights.shape.get<AxisT>(),
406 attr.dilations.get<AxisT>(), attr.strides.get<AxisT>());
407 }
408
409 template <Axis AxisT>
CalculateSamePadding(const BHWC & input,const ConvolutionTransposedAttributes & attr)410 int32_t CalculateSamePadding(const BHWC& input,
411 const ConvolutionTransposedAttributes& attr) {
412 return CalculateSamePadding(input.get<AxisT>(),
413 attr.weights.shape.get<AxisT>(),
414 /*dilation=*/1, attr.stride.get<AxisT>());
415 }
416
417 template <Axis AxisT>
CalculateSamePadding(const BHWDC & input,const ConvolutionTransposed3DAttributes & attr)418 int32_t CalculateSamePadding(const BHWDC& input,
419 const ConvolutionTransposed3DAttributes& attr) {
420 return CalculateSamePadding(input.get<AxisT>(),
421 attr.weights.shape.get<AxisT>(),
422 /*dilation=*/1, attr.stride.get<AxisT>());
423 }
424
425 template <Axis AxisT>
CalculateSamePadding(const BHWC & input,const Pooling2DAttributes & attr)426 int32_t CalculateSamePadding(const BHWC& input,
427 const Pooling2DAttributes& attr) {
428 return CalculateSamePadding(input.get<AxisT>(), attr.kernel.get<AxisT>(),
429 /*dilation=*/1, attr.strides.get<AxisT>());
430 }
431
432 template <Axis AxisT>
CalculateSamePadding(const BHWDC & input,const Pooling3DAttributes & attr)433 int32_t CalculateSamePadding(const BHWDC& input,
434 const Pooling3DAttributes& attr) {
435 return CalculateSamePadding(input.get<AxisT>(), attr.kernel.get<AxisT>(),
436 /*dilation=*/1, attr.strides.get<AxisT>());
437 }
438
439 template <Axis AxisT>
CalculateSamePadding(const BHWC & input,const MaxUnpooling2DAttributes & attr)440 int32_t CalculateSamePadding(const BHWC& input,
441 const MaxUnpooling2DAttributes& attr) {
442 return CalculateSamePadding(input.get<AxisT>(), attr.kernel.get<AxisT>(),
443 /*dilation=*/1, attr.strides.get<AxisT>());
444 }
445
446 template <Axis AxisT>
CalculateSamePadding(const BHWDC & input,const MaxUnpooling3DAttributes & attr)447 int32_t CalculateSamePadding(const BHWDC& input,
448 const MaxUnpooling3DAttributes& attr) {
449 return CalculateSamePadding(input.get<AxisT>(), attr.kernel.get<AxisT>(),
450 /*dilation=*/1, attr.strides.get<AxisT>());
451 }
452
MakeSamePadding(const BHWC & input,const ConvolutionTransposedAttributes & attr)453 Padding2D MakeSamePadding(const BHWC& input,
454 const ConvolutionTransposedAttributes& attr) {
455 int32_t padding_height = CalculateSamePadding<Axis::HEIGHT>(input, attr);
456 int32_t padding_width = CalculateSamePadding<Axis::WIDTH>(input, attr);
457 Padding2D padding;
458 padding.prepended = HW(padding_height / 2, padding_width / 2);
459 padding.appended = HW(padding_height - padding_height / 2,
460 padding_width - padding_width / 2);
461 return padding;
462 }
463
MakeSamePadding(const BHWDC & input,const ConvolutionTransposed3DAttributes & attr)464 Padding3D MakeSamePadding(const BHWDC& input,
465 const ConvolutionTransposed3DAttributes& attr) {
466 int32_t padding_height = CalculateSamePadding<Axis::HEIGHT>(input, attr);
467 int32_t padding_width = CalculateSamePadding<Axis::WIDTH>(input, attr);
468 int32_t padding_depth = CalculateSamePadding<Axis::DEPTH>(input, attr);
469 Padding3D padding;
470 padding.prepended =
471 HWD(padding_height / 2, padding_width / 2, padding_depth / 2);
472 padding.appended =
473 HWD(padding_height - padding_height / 2,
474 padding_width - padding_width / 2, padding_depth - padding_depth / 2);
475 return padding;
476 }
477
478 // If padding depends on input, convert it into fixed padding.
479 template <class AttrT>
MakeSamePadding(const BHWC & input,const AttrT & attr)480 Padding2D MakeSamePadding(const BHWC& input, const AttrT& attr) {
481 int32_t padding_height = CalculateSamePadding<Axis::HEIGHT>(input, attr);
482 int32_t padding_width = CalculateSamePadding<Axis::WIDTH>(input, attr);
483 Padding2D padding;
484 padding.prepended = HW(padding_height / 2, padding_width / 2);
485 padding.appended = HW(padding_height - padding_height / 2,
486 padding_width - padding_width / 2);
487 return padding;
488 }
489
490 // If padding depends on input, convert it into fixed padding.
491 template <class AttrT>
MakeSamePadding(const BHWDC & input,const AttrT & attr)492 Padding3D MakeSamePadding(const BHWDC& input, const AttrT& attr) {
493 int32_t padding_height = CalculateSamePadding<Axis::HEIGHT>(input, attr);
494 int32_t padding_width = CalculateSamePadding<Axis::WIDTH>(input, attr);
495 int32_t padding_depth = CalculateSamePadding<Axis::DEPTH>(input, attr);
496 Padding3D padding;
497 padding.prepended =
498 HWD(padding_height / 2, padding_width / 2, padding_depth / 2);
499 padding.appended =
500 HWD(padding_height - padding_height / 2,
501 padding_width - padding_width / 2, padding_depth - padding_depth / 2);
502 return padding;
503 }
504
505 } // namespace
506
// Output shape of 2D max-unpooling: spatial dims expand by the stride and
// shrink by the explicit padding; batch and channels are preserved.
BHWC CalculateOutputShape(const BHWC& input,
                          const MaxUnpooling2DAttributes& attr) {
  const int32_t out_h = input.h * attr.strides.h - attr.padding.prepended.h -
                        attr.padding.appended.h;
  const int32_t out_w = input.w * attr.strides.w - attr.padding.prepended.w -
                        attr.padding.appended.w;
  return BHWC(input.b, out_h, out_w, input.c);
}
516
// Output shape of 3D max-unpooling: spatial dims expand by the stride and
// shrink by the explicit padding; batch and channels are preserved.
BHWDC CalculateOutputShape(const BHWDC& input,
                           const MaxUnpooling3DAttributes& attr) {
  const int32_t out_h = input.h * attr.strides.h - attr.padding.prepended.h -
                        attr.padding.appended.h;
  const int32_t out_w = input.w * attr.strides.w - attr.padding.prepended.w -
                        attr.padding.appended.w;
  const int32_t out_d = input.d * attr.strides.d - attr.padding.prepended.d -
                        attr.padding.appended.d;
  return BHWDC(input.b, out_h, out_w, out_d, input.c);
}
528
// Output shape of 2D pooling; batch and channels are preserved.
BHWC CalculateOutputShape(const BHWC& input, const Pooling2DAttributes& attr) {
  const int32_t out_h = CalculateOutput<Axis::HEIGHT>(input, attr);
  const int32_t out_w = CalculateOutput<Axis::WIDTH>(input, attr);
  return BHWC(input.b, out_h, out_w, input.c);
}
533
// Output shape of 3D pooling; batch and channels are preserved.
BHWDC CalculateOutputShape(const BHWDC& input,
                           const Pooling3DAttributes& attr) {
  const int32_t out_h = CalculateOutput<Axis::HEIGHT>(input, attr);
  const int32_t out_w = CalculateOutput<Axis::WIDTH>(input, attr);
  const int32_t out_d = CalculateOutput<Axis::DEPTH>(input, attr);
  return BHWDC(input.b, out_h, out_w, out_d, input.c);
}
540
// Output shape of a 2D convolution; channels come from the filter count.
BHWC CalculateOutputShape(const BHWC& input,
                          const Convolution2DAttributes& attr) {
  const int32_t out_h = CalculateOutput<Axis::HEIGHT>(input, attr);
  const int32_t out_w = CalculateOutput<Axis::WIDTH>(input, attr);
  const int32_t out_c = attr.weights.shape.get<Axis::OUTPUT_CHANNELS>();
  return BHWC(input.b, out_h, out_w, out_c);
}
547
// Output shape of a 3D convolution; channels come from the filter count.
BHWDC CalculateOutputShape(const BHWDC& input,
                           const Convolution3DAttributes& attr) {
  const int32_t out_h = CalculateOutput<Axis::HEIGHT>(input, attr);
  const int32_t out_w = CalculateOutput<Axis::WIDTH>(input, attr);
  const int32_t out_d = CalculateOutput<Axis::DEPTH>(input, attr);
  const int32_t out_c = attr.weights.shape.get<Axis::OUTPUT_CHANNELS>();
  return BHWDC(input.b, out_h, out_w, out_d, out_c);
}
555
// Output shape of a transposed 2D convolution.
BHWC CalculateOutputShape(const BHWC& input,
                          const ConvolutionTransposedAttributes& attr) {
  const int32_t out_h = CalculateOutput<Axis::HEIGHT>(input, attr);
  const int32_t out_w = CalculateOutput<Axis::WIDTH>(input, attr);
  const int32_t out_c = attr.weights.shape.get<Axis::OUTPUT_CHANNELS>();
  return BHWC(input.b, out_h, out_w, out_c);
}
562
// Output shape of a transposed 3D convolution.
BHWDC CalculateOutputShape(const BHWDC& input,
                           const ConvolutionTransposed3DAttributes& attr) {
  const int32_t out_h = CalculateOutput<Axis::HEIGHT>(input, attr);
  const int32_t out_w = CalculateOutput<Axis::WIDTH>(input, attr);
  const int32_t out_d = CalculateOutput<Axis::DEPTH>(input, attr);
  const int32_t out_c = attr.weights.shape.get<Axis::OUTPUT_CHANNELS>();
  return BHWDC(input.b, out_h, out_w, out_d, out_c);
}
570
// Output shape of a 2D depthwise convolution: channels are the per-input
// channel multiplier times the input channel count stored in the weights.
BHWC CalculateOutputShape(const BHWC& input,
                          const DepthwiseConvolution2DAttributes& attr) {
  const int32_t out_h = CalculateOutput<Axis::HEIGHT>(input, attr);
  const int32_t out_w = CalculateOutput<Axis::WIDTH>(input, attr);
  const int32_t out_c = attr.weights.shape.get<Axis::OUTPUT_CHANNELS>() *
                        attr.weights.shape.get<Axis::INPUT_CHANNELS>();
  return BHWC(input.b, out_h, out_w, out_c);
}
578
// Output shape of a 3D depthwise convolution: channels are the per-input
// channel multiplier times the input channel count stored in the weights.
BHWDC CalculateOutputShape(const BHWDC& input,
                           const DepthwiseConvolution3DAttributes& attr) {
  const int32_t out_h = CalculateOutput<Axis::HEIGHT>(input, attr);
  const int32_t out_w = CalculateOutput<Axis::WIDTH>(input, attr);
  const int32_t out_d = CalculateOutput<Axis::DEPTH>(input, attr);
  const int32_t out_c = attr.weights.shape.get<Axis::OUTPUT_CHANNELS>() *
                        attr.weights.shape.get<Axis::INPUT_CHANNELS>();
  return BHWDC(input.b, out_h, out_w, out_d, out_c);
}
587
// Output shape of a strided slice: each dim is the strided span between
// starts and ends. `input` participates only via the declared interface.
BHWC CalculateOutputShape(const BHWC& input, const SliceAttributes& attr) {
  const int32_t out_b = StridedSize(attr.ends.b - attr.starts.b,
                                    attr.strides.b);
  const int32_t out_h = StridedSize(attr.ends.h - attr.starts.h,
                                    attr.strides.h);
  const int32_t out_w = StridedSize(attr.ends.w - attr.starts.w,
                                    attr.strides.w);
  const int32_t out_c = StridedSize(attr.ends.c - attr.starts.c,
                                    attr.strides.c);
  return BHWC(out_b, out_h, out_w, out_c);
}
594
// Output shape of a strided 3D slice: each dim is the strided span between
// starts and ends. `input` participates only via the declared interface.
BHWDC CalculateOutputShape(const BHWDC& input, const Slice3DAttributes& attr) {
  const int32_t out_b = StridedSize(attr.ends.b - attr.starts.b,
                                    attr.strides.b);
  const int32_t out_h = StridedSize(attr.ends.h - attr.starts.h,
                                    attr.strides.h);
  const int32_t out_w = StridedSize(attr.ends.w - attr.starts.w,
                                    attr.strides.w);
  const int32_t out_d = StridedSize(attr.ends.d - attr.starts.d,
                                    attr.strides.d);
  const int32_t out_c = StridedSize(attr.ends.c - attr.starts.c,
                                    attr.strides.c);
  return BHWDC(out_b, out_h, out_w, out_d, out_c);
}
602
// Output shape of a pad op: each dim grows by the padding on both sides.
BHWC CalculateOutputShape(const BHWC& input, const PadAttributes& attr) {
  return BHWC(input.b + attr.prepended.b + attr.appended.b,
              input.h + attr.prepended.h + attr.appended.h,
              input.w + attr.prepended.w + attr.appended.w,
              input.c + attr.prepended.c + attr.appended.c);
}
609
// Output shape of a 3D pad op: each dim grows by the padding on both sides.
BHWDC CalculateOutputShape(const BHWDC& input, const Pad3DAttributes& attr) {
  return BHWDC(input.b + attr.prepended.b + attr.appended.b,
               input.h + attr.prepended.h + attr.appended.h,
               input.w + attr.prepended.w + attr.appended.w,
               input.d + attr.prepended.d + attr.appended.d,
               input.c + attr.prepended.c + attr.appended.c);
}
617
// Fully connected collapses the spatial dimensions to 1x1 and emits one
// value per output neuron.
BHWC CalculateOutputShape(const BHWC& input,
                          const FullyConnectedAttributes& attr) {
  const auto out_channels = attr.weights.shape.o;
  return BHWC(input.b, /*h=*/1, /*w=*/1, out_channels);
}
622
// Output shape of a mean reduction: a dimension collapses to 1 iff its
// axis is listed in attr.dims; otherwise it passes through unchanged.
BHWC CalculateOutputShape(const BHWC& input, const MeanAttributes& attr) {
  const auto out_size = [&attr](Axis axis, int size) {
    return attr.dims.find(axis) == attr.dims.end() ? size : 1;
  };
  return BHWC(out_size(Axis::BATCH, input.b), out_size(Axis::HEIGHT, input.h),
              out_size(Axis::WIDTH, input.w),
              out_size(Axis::CHANNELS, input.c));
}
630
// Output shape of a 3D mean reduction: a dimension collapses to 1 iff its
// axis is listed in attr.dims; otherwise it passes through unchanged.
BHWDC CalculateOutputShape(const BHWDC& input, const MeanAttributes& attr) {
  const auto out_size = [&attr](Axis axis, int size) {
    return attr.dims.find(axis) == attr.dims.end() ? size : 1;
  };
  return BHWDC(out_size(Axis::BATCH, input.b),
               out_size(Axis::HEIGHT, input.h), out_size(Axis::WIDTH, input.w),
               out_size(Axis::DEPTH, input.d),
               out_size(Axis::CHANNELS, input.c));
}
639
CalculateOutputShape(const std::vector<BHWC> & input,const ConcatAttributes & attr,BHWC * output_shape)640 absl::Status CalculateOutputShape(const std::vector<BHWC>& input,
641 const ConcatAttributes& attr,
642 BHWC* output_shape) {
643 BHWC new_shape = input[0];
644 switch (attr.axis) {
645 case Axis::CHANNELS:
646 for (int i = 1; i < input.size(); i++) {
647 if (input[i].h != new_shape.h || input[i].w != new_shape.w ||
648 input[i].b != new_shape.b) {
649 return absl::InvalidArgumentError(
650 "Height, Width and Batch must be the same when concatenating "
651 "by channels axis");
652 }
653 new_shape.c += input[i].c;
654 }
655 break;
656 case Axis::HEIGHT:
657 for (int i = 1; i < input.size(); i++) {
658 if (input[i].w != new_shape.w || input[i].c != new_shape.c ||
659 input[i].b != new_shape.b) {
660 return absl::InvalidArgumentError(
661 "Channels, Width and Batch must be the same when concatenating "
662 "by height axis");
663 }
664 new_shape.h += input[i].h;
665 }
666 break;
667 case Axis::WIDTH:
668 for (int i = 1; i < input.size(); i++) {
669 if (input[i].h != new_shape.h || input[i].c != new_shape.c ||
670 input[i].b != new_shape.b) {
671 return absl::InvalidArgumentError(
672 "Height, Channels and Batch must be the same when concatenating "
673 "by width axis");
674 }
675 new_shape.w += input[i].w;
676 }
677 break;
678 case Axis::BATCH:
679 for (int i = 1; i < input.size(); i++) {
680 if (input[i].h != new_shape.h || input[i].c != new_shape.c ||
681 input[i].w != new_shape.w) {
682 return absl::InvalidArgumentError(
683 "Width, Height and Channels must be the same when concatenating "
684 "by batch axis");
685 }
686 new_shape.b += input[i].b;
687 }
688 break;
689 default:
690 return absl::InvalidArgumentError("Invalid axis");
691 break;
692 }
693 *output_shape = new_shape;
694 return absl::OkStatus();
695 }
696
CalculateOutputShape(const std::vector<BHWDC> & input,const ConcatAttributes & attr,BHWDC * output_shape)697 absl::Status CalculateOutputShape(const std::vector<BHWDC>& input,
698 const ConcatAttributes& attr,
699 BHWDC* output_shape) {
700 BHWDC new_shape = input[0];
701 switch (attr.axis) {
702 case Axis::CHANNELS:
703 for (int i = 1; i < input.size(); ++i) {
704 if (input[i].h != new_shape.h || input[i].w != new_shape.w ||
705 input[i].d != new_shape.d || input[i].b != new_shape.b) {
706 return absl::InvalidArgumentError(
707 "Height, Width, Batch and Depth must be the same when "
708 "concatenating "
709 "by channels axis");
710 }
711 new_shape.c += input[i].c;
712 }
713 break;
714 case Axis::HEIGHT:
715 for (int i = 1; i < input.size(); ++i) {
716 if (input[i].w != new_shape.w || input[i].c != new_shape.c ||
717 input[i].d != new_shape.d || input[i].b != new_shape.b) {
718 return absl::InvalidArgumentError(
719 "Width, Depth, Batch and Channels must be the same when "
720 "concatenating "
721 "by height axis");
722 }
723 new_shape.h += input[i].h;
724 }
725 break;
726 case Axis::WIDTH:
727 for (int i = 1; i < input.size(); ++i) {
728 if (input[i].h != new_shape.h || input[i].c != new_shape.c ||
729 input[i].d != new_shape.d || input[i].b != new_shape.b) {
730 return absl::InvalidArgumentError(
731 "Height, Depth, Batch and Channels must be the same when "
732 "concatenating "
733 "by width axis");
734 }
735 new_shape.w += input[i].w;
736 }
737 break;
738 case Axis::DEPTH:
739 for (int i = 1; i < input.size(); ++i) {
740 if (input[i].w != new_shape.w || input[i].h != new_shape.h ||
741 input[i].c != new_shape.c || input[i].b != new_shape.b) {
742 return absl::InvalidArgumentError(
743 "Width, Height, Batch and Channels must be the same when "
744 "concatenating "
745 "by depth axis");
746 }
747 new_shape.d += input[i].d;
748 }
749 break;
750 case Axis::BATCH:
751 for (int i = 1; i < input.size(); ++i) {
752 if (input[i].w != new_shape.w || input[i].h != new_shape.h ||
753 input[i].c != new_shape.c || input[i].d != new_shape.d) {
754 return absl::InvalidArgumentError(
755 "Width, Height, Depth and Channels must be the same when "
756 "concatenating "
757 "by batch axis");
758 }
759 new_shape.b += input[i].b;
760 }
761 break;
762 default:
763 return absl::InvalidArgumentError("Invalid axis");
764 }
765 *output_shape = new_shape;
766 return absl::OkStatus();
767 }
768
// Public entry point: resolves SAME padding into explicit 2D padding for a
// convolution, given the input shape.
Padding2D CalculateSamePadding(const BHWC& input,
                               const Convolution2DAttributes& attr) {
  return MakeSamePadding(input, attr);
}
773
// Public entry point: resolves SAME padding into explicit 3D padding for a
// convolution, given the input shape.
Padding3D CalculateSamePadding(const BHWDC& input,
                               const Convolution3DAttributes& attr) {
  return MakeSamePadding(input, attr);
}
778
// Public entry point: resolves SAME padding into explicit 2D padding for a
// transposed convolution, given the input shape.
Padding2D CalculateSamePadding(const BHWC& input,
                               const ConvolutionTransposedAttributes& attr) {
  return MakeSamePadding(input, attr);
}
783
// Public entry point: resolves SAME padding into explicit 3D padding for a
// transposed convolution, given the input shape.
Padding3D CalculateSamePadding(const BHWDC& input,
                               const ConvolutionTransposed3DAttributes& attr) {
  return MakeSamePadding(input, attr);
}
788
// Computes TF-style SAME padding for a 2D depthwise convolution. Delegates
// to the MakeSamePadding helper defined earlier in this file.
Padding2D CalculateSamePadding(const BHWC& input,
                               const DepthwiseConvolution2DAttributes& attr) {
  return MakeSamePadding(input, attr);
}
793
// Computes TF-style SAME padding for a 3D depthwise convolution. Delegates
// to the MakeSamePadding helper defined earlier in this file.
Padding3D CalculateSamePadding(const BHWDC& input,
                               const DepthwiseConvolution3DAttributes& attr) {
  return MakeSamePadding(input, attr);
}
798
// Computes TF-style SAME padding for 2D pooling. Delegates to the
// MakeSamePadding helper defined earlier in this file.
Padding2D CalculateSamePadding(const BHWC& input,
                               const Pooling2DAttributes& attr) {
  return MakeSamePadding(input, attr);
}
803
// Computes TF-style SAME padding for 3D pooling. Delegates to the
// MakeSamePadding helper defined earlier in this file.
Padding3D CalculateSamePadding(const BHWDC& input,
                               const Pooling3DAttributes& attr) {
  return MakeSamePadding(input, attr);
}
808
// Computes TF-style SAME padding for 2D max-unpooling. Delegates to the
// MakeSamePadding helper defined earlier in this file.
Padding2D CalculateSamePadding(const BHWC& input,
                               const MaxUnpooling2DAttributes& attr) {
  return MakeSamePadding(input, attr);
}
813
// Computes TF-style SAME padding for 3D max-unpooling. Delegates to the
// MakeSamePadding helper defined earlier in this file.
Padding3D CalculateSamePadding(const BHWDC& input,
                               const MaxUnpooling3DAttributes& attr) {
  return MakeSamePadding(input, attr);
}
818
CalculateResizeScale(int32_t input_size,int32_t output_size,const Resize2DAttributes & attr)819 float CalculateResizeScale(int32_t input_size, int32_t output_size,
820 const Resize2DAttributes& attr) {
821 return attr.align_corners && input_size > 1 && output_size > 1
822 ? static_cast<float>(input_size - 1) / (output_size - 1)
823 : static_cast<float>(input_size) / output_size;
824 }
825
CalculateResizeScale(int32_t input_size,int32_t output_size,const Resize3DAttributes & attr)826 float CalculateResizeScale(int32_t input_size, int32_t output_size,
827 const Resize3DAttributes& attr) {
828 return attr.align_corners && input_size > 1 && output_size > 1
829 ? static_cast<float>(input_size - 1) / (output_size - 1)
830 : static_cast<float>(input_size) / output_size;
831 }
832
// Output shape of a 2D resize: batch and channels pass through unchanged;
// height and width come from the requested new_shape.
BHWC CalculateOutputShape(const BHWC& input, const Resize2DAttributes& attr) {
  const auto& target = attr.new_shape;
  return BHWC(input.b, target.h, target.w, input.c);
}
836
// Output shape of a 3D resize: batch and channels pass through unchanged;
// height, width and depth come from the requested new_shape.
BHWDC CalculateOutputShape(const BHWDC& input, const Resize3DAttributes& attr) {
  const auto& target = attr.new_shape;
  return BHWDC(input.b, target.h, target.w, target.d, input.c);
}
841
// Output shape of a transpose: each output dimension is the input dimension
// selected by the corresponding entry of attr.perm.
BHWC CalculateOutputShape(const BHWC& input, const TransposeAttributes& attr) {
  const auto& perm = attr.perm;
  return BHWC(input.get(perm.b), input.get(perm.h), input.get(perm.w),
              input.get(perm.c));
}
846
// Output shape of a 3D transpose: each output dimension is the input
// dimension selected by the corresponding entry of attr.perm.
BHWDC CalculateOutputShape(const BHWDC& input,
                           const Transpose3DAttributes& attr) {
  const auto& perm = attr.perm;
  return BHWDC(input.get(perm.b), input.get(perm.h), input.get(perm.w),
               input.get(perm.d), input.get(perm.c));
}
853
DequatizeFullyConnectedAttr(const FullyConnectedInt8Attributes & attr)854 FullyConnectedAttributes DequatizeFullyConnectedAttr(
855 const FullyConnectedInt8Attributes& attr) {
856 FullyConnectedAttributes dequant_attr;
857 dequant_attr.weights.id = attr.weights.id;
858 dequant_attr.weights.shape = attr.weights.shape;
859 dequant_attr.weights.data.resize(
860 dequant_attr.weights.shape.DimensionsProduct());
861 dequant_attr.bias = attr.bias;
862
863 // weights dequantization to float32
864 for (int i = 0; i < attr.weights.data.size(); i++) {
865 const int32_t val = attr.weights.data[i];
866 dequant_attr.weights.data[i] = attr.scale * (val - attr.zero_point);
867 }
868 return dequant_attr;
869 }
870
871 } // namespace gpu
872 } // namespace tflite
873