xref: /aosp_15_r20/external/tensorflow/tensorflow/lite/delegates/gpu/common/operations.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/lite/delegates/gpu/common/operations.h"

#include <algorithm>
#include <cstdint>
#include <set>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

#include "tensorflow/lite/delegates/gpu/common/shape.h"
#include "tensorflow/lite/delegates/gpu/common/status.h"
#include "tensorflow/lite/delegates/gpu/common/tensor.h"

namespace tflite {
namespace gpu {

Padding2D& Padding2D::operator=(const Padding2D& value) {
  prepended = value.prepended;
  appended = value.appended;
  return *this;
}

bool Padding2D::operator==(const Padding2D& value) {
  return this->prepended == value.prepended && this->appended == value.appended;
}

bool Padding2D::operator!=(const Padding2D& value) { return !(*this == value); }

Padding2D& Padding2D::operator-(const Padding2D& value) {
  prepended.h -= value.prepended.h;
  prepended.w -= value.prepended.w;
  appended.h -= value.appended.h;
  appended.w -= value.appended.w;
  return *this;
}

Padding3D& Padding3D::operator=(const Padding3D& value) {
  prepended = value.prepended;
  appended = value.appended;
  return *this;
}

bool Padding3D::operator==(const Padding3D& value) {
  return this->prepended == value.prepended && this->appended == value.appended;
}

bool Padding3D::operator!=(const Padding3D& value) { return !(*this == value); }

Padding3D& Padding3D::operator-(const Padding3D& value) {
  prepended.h -= value.prepended.h;
  prepended.w -= value.prepended.w;
  prepended.d -= value.prepended.d;
  appended.h -= value.appended.h;
  appended.w -= value.appended.w;
  appended.d -= value.appended.d;
  return *this;
}

std::string ToString(enum OperationType op) {
  switch (op) {
    case OperationType::ABS:
      return "abs";
    case OperationType::ADD:
      return "add";
    case OperationType::BATCH_NORMALIZATION:
      return "batch_normalization";
    case OperationType::BATCH_TO_SPACE:
      return "batch_to_space";
    case OperationType::BATCHED_MATMUL:
      return "batched_matmul";
    case OperationType::CAST:
      return "cast";
    case OperationType::CONCAT:
      return "concat";
    case OperationType::CONSTANT:
      return "const";
    case OperationType::CONVOLUTION_2D:
      return "convolution_2d";
    case OperationType::CONVOLUTION_TRANSPOSED:
      return "convolution_transposed";
    case OperationType::COPY:
      return "copy";
    case OperationType::COS:
      return "cos";
    case OperationType::CUMSUM:
      return "cumsum";
    case OperationType::DENSIFY:
      return "densify";
    case OperationType::DEPTHWISE_CONVOLUTION:
      return "depthwise_convolution";
    case OperationType::DEPTH_TO_SPACE:
      return "depth_to_space";
    case OperationType::DIV:
      return "div";
    case OperationType::ELU:
      return "elu";
    case OperationType::EQUAL:
      return "equal";
    case OperationType::EXP:
      return "exp";
    case OperationType::FLOOR:
      return "floor";
    case OperationType::FLOOR_DIV:
      return "floor_div";
    case OperationType::FLOOR_MOD:
      return "floor_mod";
    case OperationType::FULLY_CONNECTED:
      return "fully_connected";
    case OperationType::FULLY_CONNECTED_INT8:
      return "fully_connected_int8";
    case OperationType::GATHER:
      return "gather";
    case OperationType::GREATER:
      return "greater";
    case OperationType::GREATER_EQUAL:
      return "greater_equal";
    case OperationType::HARD_SWISH:
      return "hard_swish";
    case OperationType::LESS:
      return "less";
    case OperationType::LESS_EQUAL:
      return "less_equal";
    case OperationType::LOG:
      return "log";
    case OperationType::LSTM:
      return "lstm";
    case OperationType::MAXIMUM:
      return "maximum";
    case OperationType::MAX_UNPOOLING_2D:
      return "max_unpooling";
    case OperationType::MEAN:
      return "mean";
    case OperationType::MEAN_STDDEV_NORMALIZATION:
      return "mean_stddev_normalization";
    case OperationType::MINIMUM:
      return "minimum";
    case OperationType::MUL:
      return "mul";
    case OperationType::NEG:
      return "neg";
    case OperationType::NOT_EQUAL:
      return "not_equal";
    case OperationType::ONE_HOT:
      return "one_hot";
    case OperationType::PAD:
      return "pad";
    case OperationType::POOLING_2D:
      return "pooling_2d";
    case OperationType::POW:
      return "pow";
    case OperationType::PRELU:
      return "prelu";
    case OperationType::QUANTIZE_AND_DEQUANTIZE:
      return "quantize_and_dequantize";
    case OperationType::REDUCE_MAXIMUM:
      return "reduce_maximum";
    case OperationType::REDUCE_MINIMUM:
      return "reduce_minimum";
    case OperationType::REDUCE_PRODUCT:
      return "reduce_product";
    case OperationType::REDUCE_SUM:
      return "reduce_sum";
    case OperationType::RELU:
      return "relu";
    case OperationType::RESAMPLER:
      return "resampler";
    case OperationType::RESHAPE:
      return "reshape";
    case OperationType::RESIZE:
      return "resize";
    case OperationType::RSQRT:
      return "rsqrt";
    case OperationType::SELECT_V2:
      return "select_v2";
    case OperationType::SIGMOID:
      return "sigmoid";
    case OperationType::SIN:
      return "sin";
    case OperationType::SLICE:
      return "slice";
    case OperationType::SOFTMAX:
      return "softmax";
    case OperationType::SPACE_TO_BATCH:
      return "space_to_batch";
    case OperationType::SPACE_TO_DEPTH:
      return "space_to_depth";
    case OperationType::SPLIT:
      return "split";
    case OperationType::SQRT:
      return "sqrt";
    case OperationType::SQUARE:
      return "square";
    case OperationType::SQUARED_DIFF:
      return "squared_diff";
    case OperationType::SUB:
      return "subtract";
    case OperationType::TANH:
      return "tanh";
    case OperationType::TILE:
      return "tile";
    case OperationType::TRANSPOSE:
      return "transpose";
    case OperationType::UNKNOWN:
      return "unknown_operation";
  }
}

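// Maps an operation name back to its OperationType; unrecognized names
// resolve to OperationType::UNKNOWN.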
OperationType OperationTypeFromString(const std::string& name) {
  static const auto* operations =
      new std::unordered_map<std::string, OperationType>({
          {"abs", OperationType::ABS},
          {"add", OperationType::ADD},
          {"batch_normalization", OperationType::BATCH_NORMALIZATION},
          {"batched_matmul", OperationType::BATCHED_MATMUL},
          {"cast", OperationType::CAST},
          {"concat", OperationType::CONCAT},
          {"const", OperationType::CONSTANT},
          {"convolution_2d", OperationType::CONVOLUTION_2D},
          {"convolution_transposed", OperationType::CONVOLUTION_TRANSPOSED},
          {"copy", OperationType::COPY},
          {"cos", OperationType::COS},
          {"cumsum", OperationType::CUMSUM},
          {"densify", OperationType::DENSIFY},
          {"depthwise_convolution", OperationType::DEPTHWISE_CONVOLUTION},
          {"depth_to_space", OperationType::DEPTH_TO_SPACE},
          {"div", OperationType::DIV},
          {"elu", OperationType::ELU},
          {"equal", OperationType::EQUAL},
          {"exp", OperationType::EXP},
          {"floor", OperationType::FLOOR},
          {"floor_div", OperationType::FLOOR_DIV},
          {"floor_mod", OperationType::FLOOR_MOD},
          {"fully_connected", OperationType::FULLY_CONNECTED},
          {"fully_connected_int8", OperationType::FULLY_CONNECTED_INT8},
          {"gather", OperationType::GATHER},
          {"greater", OperationType::GREATER},
          {"greater_equal", OperationType::GREATER_EQUAL},
          {"hard_swish", OperationType::HARD_SWISH},
          {"less", OperationType::LESS},
          {"less_equal", OperationType::LESS_EQUAL},
          {"log", OperationType::LOG},
          {"lstm", OperationType::LSTM},
          {"maximum", OperationType::MAXIMUM},
          {"max_unpooling", OperationType::MAX_UNPOOLING_2D},
          {"mean", OperationType::MEAN},
          {"mean_stddev_normalization",
           OperationType::MEAN_STDDEV_NORMALIZATION},
          {"minimum", OperationType::MINIMUM},
          {"mul", OperationType::MUL},
          {"neg", OperationType::NEG},
          {"not_equal", OperationType::NOT_EQUAL},
          {"one_hot", OperationType::ONE_HOT},
          {"pad", OperationType::PAD},
          {"pooling_2d", OperationType::POOLING_2D},
          {"pow", OperationType::POW},
          {"prelu", OperationType::PRELU},
          {"quantize_and_dequantize", OperationType::QUANTIZE_AND_DEQUANTIZE},
          {"reduce_maximum", OperationType::REDUCE_MAXIMUM},
          {"reduce_minimum", OperationType::REDUCE_MINIMUM},
          {"reduce_product", OperationType::REDUCE_PRODUCT},
          {"reduce_sum", OperationType::REDUCE_SUM},
          {"relu", OperationType::RELU},
          {"resampler", OperationType::RESAMPLER},
          {"resize", OperationType::RESIZE},
          {"reshape", OperationType::RESHAPE},
          {"rsqrt", OperationType::RSQRT},
          {"select_v2", OperationType::SELECT_V2},
          {"sigmoid", OperationType::SIGMOID},
          {"sin", OperationType::SIN},
          {"slice", OperationType::SLICE},
          {"softmax", OperationType::SOFTMAX},
          {"space_to_depth", OperationType::SPACE_TO_DEPTH},
          {"split", OperationType::SPLIT},
          {"sqrt", OperationType::SQRT},
          {"square", OperationType::SQUARE},
          {"squared_diff", OperationType::SQUARED_DIFF},
          {"subtract", OperationType::SUB},
          {"tanh", OperationType::TANH},
          {"tile", OperationType::TILE},
          {"transpose", OperationType::TRANSPOSE},
      });
  auto op = operations->find(name);
  return op == operations->end() ? OperationType::UNKNOWN : op->second;
}

namespace {

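// Integer division rounding up, e.g. DivideRoundUp(7, 2) == 4.
// Assumes n > 0 and divisor > 0.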
template <typename T>
T DivideRoundUp(T n, T divisor) {
  return (n - 1) / divisor + 1;
}

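// Number of valid output positions along one axis before striding: the
// dilated kernel spans (kernel - 1) * dilation + 1 elements, and `padding`
// is the total (prepended + appended) padding along that axis.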
int32_t CalculateOutputSizeBeforeStrides(int32_t input, int32_t kernel,
                                         int32_t padding, int32_t dilation) {
  const int32_t dilated_kernel = (kernel - 1) * dilation + 1;
  return input + padding - dilated_kernel + 1;
}

template <Axis T>
int32_t CalculateOutputWithoutStrides(const BHWC& input,
                                      const Convolution2DAttributes& attr) {
  return CalculateOutputSizeBeforeStrides(
      input.get<T>(), attr.weights.shape.get<T>(),
      attr.padding.prepended.get<T>() + attr.padding.appended.get<T>(),
      attr.dilations.get<T>());
}

template <Axis T>
int32_t CalculateOutputWithoutStrides(const BHWDC& input,
                                      const Convolution3DAttributes& attr) {
  return CalculateOutputSizeBeforeStrides(
      input.get<T>(), attr.weights.shape.get<T>(),
      attr.padding.prepended.get<T>() + attr.padding.appended.get<T>(),
      attr.dilations.get<T>());
}

template <Axis T>
int32_t CalculateOutputWithoutStrides(const BHWC& input,
                                      const Pooling2DAttributes& attr) {
  return CalculateOutputSizeBeforeStrides(
      input.get<T>(), attr.kernel.get<T>(),
      attr.padding.prepended.get<T>() + attr.padding.appended.get<T>(),
      /*dilation=*/1);
}

template <Axis T>
int32_t CalculateOutputWithoutStrides(const BHWDC& input,
                                      const Pooling3DAttributes& attr) {
  return CalculateOutputSizeBeforeStrides(
      input.get<T>(), attr.kernel.get<T>(),
      attr.padding.prepended.get<T>() + attr.padding.appended.get<T>(),
      /*dilation=*/1);
}

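// Output size of a transposed convolution along one axis:
// (input - 1) * stride - total_padding + kernel, plus the per-axis
// attr.adjacent term in the 2D variant.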
template <Axis T>
int32_t CalculateOutput(const BHWC& input,
                        const ConvolutionTransposedAttributes& attr) {
  return (input.get<T>() - 1) * attr.stride.get<T>() -
         (attr.padding.prepended.get<T>() + attr.padding.appended.get<T>()) +
         attr.weights.shape.get<T>() + attr.adjacent.get<T>();
}

template <Axis T>
int32_t CalculateOutput(const BHWDC& input,
                        const ConvolutionTransposed3DAttributes& attr) {
  return (input.get<T>() - 1) * attr.stride.get<T>() -
         (attr.padding.prepended.get<T>() + attr.padding.appended.get<T>()) +
         attr.weights.shape.get<T>();
}

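// Applies a stride to a size computed without striding; a stride of 0 is
// treated as invalid and yields -1.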
inline int32_t StridedSize(int32_t size, int32_t stride) {
  return stride == 0 ? -1 : DivideRoundUp(size, stride);
}

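// Combines the unstrided size with the stride. For example, a 224-wide input
// with a 3-wide kernel, dilation 1, total padding 2 and stride 2 gives
// StridedSize(224 + 2 - 3 + 1, 2) = DivideRoundUp(224, 2) = 112.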
template <Axis AxisT, typename AttrT>
int32_t CalculateOutput(const BHWC& input, const AttrT& attr) {
  return StridedSize(CalculateOutputWithoutStrides<AxisT>(input, attr),
                     attr.strides.template get<AxisT>());
}

template <Axis AxisT, typename AttrT>
int32_t CalculateOutput(const BHWDC& input, const AttrT& attr) {
  return StridedSize(CalculateOutputWithoutStrides<AxisT>(input, attr),
                     attr.strides.template get<AxisT>());
}

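// Total SAME padding along one axis: the padding needed so that the strided
// output size stays at DivideRoundUp(input, stride).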
int32_t CalculateSamePadding(int32_t input, int32_t kernel, int32_t dilation,
                             int32_t stride) {
  const int32_t dilated_kernel = (kernel - 1) * dilation + 1;
  return std::max(0, dilated_kernel - (input - 1) % stride - 1);
}

// Returns the padding that should be present to make sure the image size
// stays the same.
template <Axis AxisT>
int32_t CalculateSamePadding(const BHWC& input,
                             const Convolution2DAttributes& attr) {
  return CalculateSamePadding(
      input.get<AxisT>(), attr.weights.shape.get<AxisT>(),
      attr.dilations.get<AxisT>(), attr.strides.get<AxisT>());
}

// Returns the padding that should be present to make sure the image size
// stays the same.
template <Axis AxisT>
int32_t CalculateSamePadding(const BHWDC& input,
                             const Convolution3DAttributes& attr) {
  return CalculateSamePadding(
      input.get<AxisT>(), attr.weights.shape.get<AxisT>(),
      attr.dilations.get<AxisT>(), attr.strides.get<AxisT>());
}

template <Axis AxisT>
int32_t CalculateSamePadding(const BHWC& input,
                             const ConvolutionTransposedAttributes& attr) {
  return CalculateSamePadding(input.get<AxisT>(),
                              attr.weights.shape.get<AxisT>(),
                              /*dilation=*/1, attr.stride.get<AxisT>());
}

template <Axis AxisT>
int32_t CalculateSamePadding(const BHWDC& input,
                             const ConvolutionTransposed3DAttributes& attr) {
  return CalculateSamePadding(input.get<AxisT>(),
                              attr.weights.shape.get<AxisT>(),
                              /*dilation=*/1, attr.stride.get<AxisT>());
}

template <Axis AxisT>
int32_t CalculateSamePadding(const BHWC& input,
                             const Pooling2DAttributes& attr) {
  return CalculateSamePadding(input.get<AxisT>(), attr.kernel.get<AxisT>(),
                              /*dilation=*/1, attr.strides.get<AxisT>());
}

template <Axis AxisT>
int32_t CalculateSamePadding(const BHWDC& input,
                             const Pooling3DAttributes& attr) {
  return CalculateSamePadding(input.get<AxisT>(), attr.kernel.get<AxisT>(),
                              /*dilation=*/1, attr.strides.get<AxisT>());
}

template <Axis AxisT>
int32_t CalculateSamePadding(const BHWC& input,
                             const MaxUnpooling2DAttributes& attr) {
  return CalculateSamePadding(input.get<AxisT>(), attr.kernel.get<AxisT>(),
                              /*dilation=*/1, attr.strides.get<AxisT>());
}

template <Axis AxisT>
int32_t CalculateSamePadding(const BHWDC& input,
                             const MaxUnpooling3DAttributes& attr) {
  return CalculateSamePadding(input.get<AxisT>(), attr.kernel.get<AxisT>(),
                              /*dilation=*/1, attr.strides.get<AxisT>());
}

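// Splits the total SAME padding of each axis between the two sides; for an
// odd total, the extra element goes to the appended side.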
Padding2D MakeSamePadding(const BHWC& input,
                          const ConvolutionTransposedAttributes& attr) {
  int32_t padding_height = CalculateSamePadding<Axis::HEIGHT>(input, attr);
  int32_t padding_width = CalculateSamePadding<Axis::WIDTH>(input, attr);
  Padding2D padding;
  padding.prepended = HW(padding_height / 2, padding_width / 2);
  padding.appended = HW(padding_height - padding_height / 2,
                        padding_width - padding_width / 2);
  return padding;
}

Padding3D MakeSamePadding(const BHWDC& input,
                          const ConvolutionTransposed3DAttributes& attr) {
  int32_t padding_height = CalculateSamePadding<Axis::HEIGHT>(input, attr);
  int32_t padding_width = CalculateSamePadding<Axis::WIDTH>(input, attr);
  int32_t padding_depth = CalculateSamePadding<Axis::DEPTH>(input, attr);
  Padding3D padding;
  padding.prepended =
      HWD(padding_height / 2, padding_width / 2, padding_depth / 2);
  padding.appended =
      HWD(padding_height - padding_height / 2,
          padding_width - padding_width / 2, padding_depth - padding_depth / 2);
  return padding;
}

// If padding depends on input, convert it into fixed padding.
template <class AttrT>
Padding2D MakeSamePadding(const BHWC& input, const AttrT& attr) {
  int32_t padding_height = CalculateSamePadding<Axis::HEIGHT>(input, attr);
  int32_t padding_width = CalculateSamePadding<Axis::WIDTH>(input, attr);
  Padding2D padding;
  padding.prepended = HW(padding_height / 2, padding_width / 2);
  padding.appended = HW(padding_height - padding_height / 2,
                        padding_width - padding_width / 2);
  return padding;
}

// If padding depends on input, convert it into fixed padding.
template <class AttrT>
Padding3D MakeSamePadding(const BHWDC& input, const AttrT& attr) {
  int32_t padding_height = CalculateSamePadding<Axis::HEIGHT>(input, attr);
  int32_t padding_width = CalculateSamePadding<Axis::WIDTH>(input, attr);
  int32_t padding_depth = CalculateSamePadding<Axis::DEPTH>(input, attr);
  Padding3D padding;
  padding.prepended =
      HWD(padding_height / 2, padding_width / 2, padding_depth / 2);
  padding.appended =
      HWD(padding_height - padding_height / 2,
          padding_width - padding_width / 2, padding_depth - padding_depth / 2);
  return padding;
}

}  // namespace

BHWC CalculateOutputShape(const BHWC& input,
                          const MaxUnpooling2DAttributes& attr) {
  return BHWC(input.b,
              input.h * attr.strides.h - attr.padding.prepended.h -
                  attr.padding.appended.h,
              input.w * attr.strides.w - attr.padding.prepended.w -
                  attr.padding.appended.w,
              input.c);
}

BHWDC CalculateOutputShape(const BHWDC& input,
                           const MaxUnpooling3DAttributes& attr) {
  return BHWDC(input.b,
               input.h * attr.strides.h - attr.padding.prepended.h -
                   attr.padding.appended.h,
               input.w * attr.strides.w - attr.padding.prepended.w -
                   attr.padding.appended.w,
               input.d * attr.strides.d - attr.padding.prepended.d -
                   attr.padding.appended.d,
               input.c);
}

BHWC CalculateOutputShape(const BHWC& input, const Pooling2DAttributes& attr) {
  return BHWC(input.b, CalculateOutput<Axis::HEIGHT>(input, attr),
              CalculateOutput<Axis::WIDTH>(input, attr), input.c);
}

BHWDC CalculateOutputShape(const BHWDC& input,
                           const Pooling3DAttributes& attr) {
  return BHWDC(input.b, CalculateOutput<Axis::HEIGHT>(input, attr),
               CalculateOutput<Axis::WIDTH>(input, attr),
               CalculateOutput<Axis::DEPTH>(input, attr), input.c);
}

BHWC CalculateOutputShape(const BHWC& input,
                          const Convolution2DAttributes& attr) {
  return BHWC(input.b, CalculateOutput<Axis::HEIGHT>(input, attr),
              CalculateOutput<Axis::WIDTH>(input, attr),
              attr.weights.shape.get<Axis::OUTPUT_CHANNELS>());
}

BHWDC CalculateOutputShape(const BHWDC& input,
                           const Convolution3DAttributes& attr) {
  return BHWDC(input.b, CalculateOutput<Axis::HEIGHT>(input, attr),
               CalculateOutput<Axis::WIDTH>(input, attr),
               CalculateOutput<Axis::DEPTH>(input, attr),
               attr.weights.shape.get<Axis::OUTPUT_CHANNELS>());
}

BHWC CalculateOutputShape(const BHWC& input,
                          const ConvolutionTransposedAttributes& attr) {
  return BHWC(input.b, CalculateOutput<Axis::HEIGHT>(input, attr),
              CalculateOutput<Axis::WIDTH>(input, attr),
              attr.weights.shape.get<Axis::OUTPUT_CHANNELS>());
}

BHWDC CalculateOutputShape(const BHWDC& input,
                           const ConvolutionTransposed3DAttributes& attr) {
  return BHWDC(input.b, CalculateOutput<Axis::HEIGHT>(input, attr),
               CalculateOutput<Axis::WIDTH>(input, attr),
               CalculateOutput<Axis::DEPTH>(input, attr),
               attr.weights.shape.get<Axis::OUTPUT_CHANNELS>());
}

BHWC CalculateOutputShape(const BHWC& input,
                          const DepthwiseConvolution2DAttributes& attr) {
  return BHWC(input.b, CalculateOutput<Axis::HEIGHT>(input, attr),
              CalculateOutput<Axis::WIDTH>(input, attr),
              attr.weights.shape.get<Axis::OUTPUT_CHANNELS>() *
                  attr.weights.shape.get<Axis::INPUT_CHANNELS>());
}

BHWDC CalculateOutputShape(const BHWDC& input,
                           const DepthwiseConvolution3DAttributes& attr) {
  return BHWDC(input.b, CalculateOutput<Axis::HEIGHT>(input, attr),
               CalculateOutput<Axis::WIDTH>(input, attr),
               CalculateOutput<Axis::DEPTH>(input, attr),
               attr.weights.shape.get<Axis::OUTPUT_CHANNELS>() *
                   attr.weights.shape.get<Axis::INPUT_CHANNELS>());
}

BHWC CalculateOutputShape(const BHWC& input, const SliceAttributes& attr) {
  return BHWC(StridedSize(attr.ends.b - attr.starts.b, attr.strides.b),
              StridedSize(attr.ends.h - attr.starts.h, attr.strides.h),
              StridedSize(attr.ends.w - attr.starts.w, attr.strides.w),
              StridedSize(attr.ends.c - attr.starts.c, attr.strides.c));
}

BHWDC CalculateOutputShape(const BHWDC& input, const Slice3DAttributes& attr) {
  return BHWDC(StridedSize(attr.ends.b - attr.starts.b, attr.strides.b),
               StridedSize(attr.ends.h - attr.starts.h, attr.strides.h),
               StridedSize(attr.ends.w - attr.starts.w, attr.strides.w),
               StridedSize(attr.ends.d - attr.starts.d, attr.strides.d),
               StridedSize(attr.ends.c - attr.starts.c, attr.strides.c));
}

BHWC CalculateOutputShape(const BHWC& input, const PadAttributes& attr) {
  return BHWC(attr.appended.b + attr.prepended.b + input.b,
              attr.appended.h + attr.prepended.h + input.h,
              attr.appended.w + attr.prepended.w + input.w,
              attr.appended.c + attr.prepended.c + input.c);
}

BHWDC CalculateOutputShape(const BHWDC& input, const Pad3DAttributes& attr) {
  return BHWDC(attr.appended.b + attr.prepended.b + input.b,
               attr.appended.h + attr.prepended.h + input.h,
               attr.appended.w + attr.prepended.w + input.w,
               attr.appended.d + attr.prepended.d + input.d,
               attr.appended.c + attr.prepended.c + input.c);
}

BHWC CalculateOutputShape(const BHWC& input,
                          const FullyConnectedAttributes& attr) {
  return BHWC(input.b, 1, 1, attr.weights.shape.o);
}

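// Axes listed in attr.dims are reduced to size 1; all other axes keep the
// input extent.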
BHWC CalculateOutputShape(const BHWC& input, const MeanAttributes& attr) {
  const int b = attr.dims.find(Axis::BATCH) == attr.dims.end() ? input.b : 1;
  const int h = attr.dims.find(Axis::HEIGHT) == attr.dims.end() ? input.h : 1;
  const int w = attr.dims.find(Axis::WIDTH) == attr.dims.end() ? input.w : 1;
  const int c = attr.dims.find(Axis::CHANNELS) == attr.dims.end() ? input.c : 1;
  return BHWC(b, h, w, c);
}

BHWDC CalculateOutputShape(const BHWDC& input, const MeanAttributes& attr) {
  const int b = attr.dims.find(Axis::BATCH) == attr.dims.end() ? input.b : 1;
  const int h = attr.dims.find(Axis::HEIGHT) == attr.dims.end() ? input.h : 1;
  const int w = attr.dims.find(Axis::WIDTH) == attr.dims.end() ? input.w : 1;
  const int d = attr.dims.find(Axis::DEPTH) == attr.dims.end() ? input.d : 1;
  const int c = attr.dims.find(Axis::CHANNELS) == attr.dims.end() ? input.c : 1;
  return BHWDC(b, h, w, d, c);
}

absl::Status CalculateOutputShape(const std::vector<BHWC>& input,
                                  const ConcatAttributes& attr,
                                  BHWC* output_shape) {
  BHWC new_shape = input[0];
  switch (attr.axis) {
    case Axis::CHANNELS:
      for (int i = 1; i < input.size(); i++) {
        if (input[i].h != new_shape.h || input[i].w != new_shape.w ||
            input[i].b != new_shape.b) {
          return absl::InvalidArgumentError(
              "Height, Width and Batch must be the same when concatenating "
              "by channels axis");
        }
        new_shape.c += input[i].c;
      }
      break;
    case Axis::HEIGHT:
      for (int i = 1; i < input.size(); i++) {
        if (input[i].w != new_shape.w || input[i].c != new_shape.c ||
            input[i].b != new_shape.b) {
          return absl::InvalidArgumentError(
              "Channels, Width and Batch must be the same when concatenating "
              "by height axis");
        }
        new_shape.h += input[i].h;
      }
      break;
    case Axis::WIDTH:
      for (int i = 1; i < input.size(); i++) {
        if (input[i].h != new_shape.h || input[i].c != new_shape.c ||
            input[i].b != new_shape.b) {
          return absl::InvalidArgumentError(
              "Height, Channels and Batch must be the same when concatenating "
              "by width axis");
        }
        new_shape.w += input[i].w;
      }
      break;
    case Axis::BATCH:
      for (int i = 1; i < input.size(); i++) {
        if (input[i].h != new_shape.h || input[i].c != new_shape.c ||
            input[i].w != new_shape.w) {
          return absl::InvalidArgumentError(
              "Width, Height and Channels must be the same when concatenating "
              "by batch axis");
        }
        new_shape.b += input[i].b;
      }
      break;
    default:
      return absl::InvalidArgumentError("Invalid axis");
      break;
  }
  *output_shape = new_shape;
  return absl::OkStatus();
}

absl::Status CalculateOutputShape(const std::vector<BHWDC>& input,
                                  const ConcatAttributes& attr,
                                  BHWDC* output_shape) {
  BHWDC new_shape = input[0];
  switch (attr.axis) {
    case Axis::CHANNELS:
      for (int i = 1; i < input.size(); ++i) {
        if (input[i].h != new_shape.h || input[i].w != new_shape.w ||
            input[i].d != new_shape.d || input[i].b != new_shape.b) {
          return absl::InvalidArgumentError(
              "Height, Width, Batch and Depth must be the same when "
              "concatenating "
              "by channels axis");
        }
        new_shape.c += input[i].c;
      }
      break;
    case Axis::HEIGHT:
      for (int i = 1; i < input.size(); ++i) {
        if (input[i].w != new_shape.w || input[i].c != new_shape.c ||
            input[i].d != new_shape.d || input[i].b != new_shape.b) {
          return absl::InvalidArgumentError(
              "Width, Depth, Batch and Channels must be the same when "
              "concatenating "
              "by height axis");
        }
        new_shape.h += input[i].h;
      }
      break;
    case Axis::WIDTH:
      for (int i = 1; i < input.size(); ++i) {
        if (input[i].h != new_shape.h || input[i].c != new_shape.c ||
            input[i].d != new_shape.d || input[i].b != new_shape.b) {
          return absl::InvalidArgumentError(
              "Height, Depth, Batch and Channels must be the same when "
              "concatenating "
              "by width axis");
        }
        new_shape.w += input[i].w;
      }
      break;
    case Axis::DEPTH:
      for (int i = 1; i < input.size(); ++i) {
        if (input[i].w != new_shape.w || input[i].h != new_shape.h ||
            input[i].c != new_shape.c || input[i].b != new_shape.b) {
          return absl::InvalidArgumentError(
              "Width, Height, Batch and Channels must be the same when "
              "concatenating "
              "by depth axis");
        }
        new_shape.d += input[i].d;
      }
      break;
    case Axis::BATCH:
      for (int i = 1; i < input.size(); ++i) {
        if (input[i].w != new_shape.w || input[i].h != new_shape.h ||
            input[i].c != new_shape.c || input[i].d != new_shape.d) {
          return absl::InvalidArgumentError(
              "Width, Height, Depth and Channels must be the same when "
              "concatenating "
              "by batch axis");
        }
        new_shape.b += input[i].b;
      }
      break;
    default:
      return absl::InvalidArgumentError("Invalid axis");
  }
  *output_shape = new_shape;
  return absl::OkStatus();
}

Padding2D CalculateSamePadding(const BHWC& input,
                               const Convolution2DAttributes& attr) {
  return MakeSamePadding(input, attr);
}

Padding3D CalculateSamePadding(const BHWDC& input,
                               const Convolution3DAttributes& attr) {
  return MakeSamePadding(input, attr);
}

Padding2D CalculateSamePadding(const BHWC& input,
                               const ConvolutionTransposedAttributes& attr) {
  return MakeSamePadding(input, attr);
}

Padding3D CalculateSamePadding(const BHWDC& input,
                               const ConvolutionTransposed3DAttributes& attr) {
  return MakeSamePadding(input, attr);
}

Padding2D CalculateSamePadding(const BHWC& input,
                               const DepthwiseConvolution2DAttributes& attr) {
  return MakeSamePadding(input, attr);
}

Padding3D CalculateSamePadding(const BHWDC& input,
                               const DepthwiseConvolution3DAttributes& attr) {
  return MakeSamePadding(input, attr);
}

Padding2D CalculateSamePadding(const BHWC& input,
                               const Pooling2DAttributes& attr) {
  return MakeSamePadding(input, attr);
}

Padding3D CalculateSamePadding(const BHWDC& input,
                               const Pooling3DAttributes& attr) {
  return MakeSamePadding(input, attr);
}

Padding2D CalculateSamePadding(const BHWC& input,
                               const MaxUnpooling2DAttributes& attr) {
  return MakeSamePadding(input, attr);
}

Padding3D CalculateSamePadding(const BHWDC& input,
                               const MaxUnpooling3DAttributes& attr) {
  return MakeSamePadding(input, attr);
}

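// Scale for mapping output coordinates back to input coordinates during
// resize; with align_corners (and more than one pixel on each side) the
// corner pixels of the input and output grids coincide.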
float CalculateResizeScale(int32_t input_size, int32_t output_size,
                           const Resize2DAttributes& attr) {
  return attr.align_corners && input_size > 1 && output_size > 1
             ? static_cast<float>(input_size - 1) / (output_size - 1)
             : static_cast<float>(input_size) / output_size;
}

float CalculateResizeScale(int32_t input_size, int32_t output_size,
                           const Resize3DAttributes& attr) {
  return attr.align_corners && input_size > 1 && output_size > 1
             ? static_cast<float>(input_size - 1) / (output_size - 1)
             : static_cast<float>(input_size) / output_size;
}

BHWC CalculateOutputShape(const BHWC& input, const Resize2DAttributes& attr) {
  return BHWC(input.b, attr.new_shape.h, attr.new_shape.w, input.c);
}

BHWDC CalculateOutputShape(const BHWDC& input, const Resize3DAttributes& attr) {
  return BHWDC(input.b, attr.new_shape.h, attr.new_shape.w, attr.new_shape.d,
               input.c);
}

BHWC CalculateOutputShape(const BHWC& input, const TransposeAttributes& attr) {
  return BHWC(input.get(attr.perm.b), input.get(attr.perm.h),
              input.get(attr.perm.w), input.get(attr.perm.c));
}

BHWDC CalculateOutputShape(const BHWDC& input,
                           const Transpose3DAttributes& attr) {
  return BHWDC(input.get(attr.perm.b), input.get(attr.perm.h),
               input.get(attr.perm.w), input.get(attr.perm.d),
               input.get(attr.perm.c));
}

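// Expands int8 fully-connected weights into float32 weights using affine
// dequantization, w = scale * (q - zero_point); the bias is copied unchanged.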
FullyConnectedAttributes DequatizeFullyConnectedAttr(
    const FullyConnectedInt8Attributes& attr) {
  FullyConnectedAttributes dequant_attr;
  dequant_attr.weights.id = attr.weights.id;
  dequant_attr.weights.shape = attr.weights.shape;
  dequant_attr.weights.data.resize(
      dequant_attr.weights.shape.DimensionsProduct());
  dequant_attr.bias = attr.bias;

  // weights dequantization to float32
  for (int i = 0; i < attr.weights.data.size(); i++) {
    const int32_t val = attr.weights.data[i];
    dequant_attr.weights.data[i] = attr.scale * (val - attr.zero_point);
  }
  return dequant_attr;
}

}  // namespace gpu
}  // namespace tflite