/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
15 #ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_PROCESS_BROADCAST_SHAPES_H_
16 #define TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_PROCESS_BROADCAST_SHAPES_H_
17
18 #include <algorithm>
19
20 #include "tensorflow/lite/kernels/internal/types.h"
21
22 namespace tflite {
23
24 namespace reference_ops {
25
26 // Consolidates dimensions in broadcast inputs, checks for five-fold pattern.
27 //
28 // For example, if sequence of dimensions of one input is
29 // ..., 1, 3, 1, 7, 9, 5,... and the other is ..., 2, 3, 1, 7, 1, 1, ...
30 // we can consolidate these as
31 // ..., 1, 3*7, 9*5, ... and 2, 3*7, 1.
32 //
33 // The category is updated in the less-frequent case of shapes that are
34 // not suited to a fivefold-loop broadcast.
35 //
36 // Falls back to generic pattern when it does not know how to process properly.
37 //
38 // Returns true iff there is some sort of broadcast, which includes five-fold
39 // patterns and falling back to generic broadcast.
ProcessBroadcastShapes(const RuntimeShape & shape0,const RuntimeShape & shape1,tflite::ArithmeticParams * params)40 inline bool ProcessBroadcastShapes(const RuntimeShape& shape0,
41 const RuntimeShape& shape1,
42 tflite::ArithmeticParams* params) {
43 const int dims_count =
44 std::max(shape0.DimensionsCount(), shape1.DimensionsCount());
45
46 params->broadcast_category = BroadcastableOpCategory::kGenericBroadcast;
47 RuntimeShape scalar_shape(dims_count, 1);
48
49 auto extended_shape0 = RuntimeShape::ExtendedShape(dims_count, shape0);
50 auto extended_shape1 = RuntimeShape::ExtendedShape(dims_count, shape1);
51
52 // Check for "exact" match, implicitly accepting any scalar shapes.
53 if (extended_shape0 == extended_shape1) {
54 params->broadcast_category = BroadcastableOpCategory::kNonBroadcast;
55 return false;
56 }
57
58 for (int i = dims_count - 1; i >= 0; --i) {
59 if (extended_shape0.Dims(i) == extended_shape1.Dims(i)) {
60 continue;
61 } else if (extended_shape0.Dims(i) == 1) {
62 params->broadcast_category =
63 BroadcastableOpCategory::kFirstInputBroadcastsFast;
64 break;
65 } else if (extended_shape1.Dims(i) == 1) {
66 params->broadcast_category =
67 BroadcastableOpCategory::kSecondInputBroadcastsFast;
68 break;
69 } else {
70 // This case is erroneous: there is a dimension that does not match and
71 // is not a broadcast from one shape to the other.
72 params->broadcast_category = BroadcastableOpCategory::kGenericBroadcast;
73 return true;
74 }
75 }
76
77 if (params->broadcast_category !=
78 BroadcastableOpCategory::kFirstInputBroadcastsFast &&
79 params->broadcast_category !=
80 BroadcastableOpCategory::kSecondInputBroadcastsFast) {
81 // This is unreachable because at least one else clause in the above loop
82 // must be reached.
83 TFLITE_DCHECK(false);
84 params->broadcast_category = BroadcastableOpCategory::kNonBroadcast;
85 return false;
86 }
87
88 // From this point it is assumed contractually that corresponding dimensions
89 // in shape0 and shape1 are either (a) equal or (b) one or other equals 1.
90 const bool swap_inputs = params->broadcast_category ==
91 BroadcastableOpCategory::kSecondInputBroadcastsFast;
92 const RuntimeShape* shape_a =
93 swap_inputs ? &extended_shape1 : &extended_shape0;
94 const RuntimeShape* shape_b =
95 swap_inputs ? &extended_shape0 : &extended_shape1;
96
97 int i = dims_count - 1;
98 params->broadcast_shape[0] = 1;
99 params->broadcast_shape[1] = 1;
100 params->broadcast_shape[2] = 1;
101 params->broadcast_shape[3] = 1;
102 params->broadcast_shape[4] = 1;
103 // y_0 is greedy: include dims if both or neither equal 1: in other words,
104 // test for equality rather than (shape_a->Dims(i) != 1).
105 while (i >= 0 && shape_a->Dims(i) == shape_b->Dims(i)) {
106 params->broadcast_shape[4] *= shape_b->Dims(i);
107 --i;
108 }
109 // Here either input_a or input_b has dim of 1 (if i >= 0). If it is input_b
110 // that has the unit dimension, the next two loops are not entered.
111 while (i >= 0 && shape_a->Dims(i) == 1) {
112 params->broadcast_shape[3] *= shape_b->Dims(i);
113 --i;
114 }
115 while (i >= 0 && shape_a->Dims(i) == shape_b->Dims(i)) {
116 params->broadcast_shape[2] *= shape_a->Dims(i);
117 --i;
118 }
119 // Here either input_a or input_b has dim of 1 (if i >= 0).
120 while (i >= 0 && shape_b->Dims(i) == 1) {
121 params->broadcast_shape[1] *= shape_a->Dims(i);
122 --i;
123 }
124 while (i >= 0 && shape_a->Dims(i) == shape_b->Dims(i)) {
125 params->broadcast_shape[0] *= shape_b->Dims(i);
126 --i;
127 }
128
129 // Rarer case is when the broadcast dimensions cannot be handled by a fivefold
130 // loop.
131 if (i >= 0) {
132 params->broadcast_category = BroadcastableOpCategory::kGenericBroadcast;
133 }
134 return true;
135 }
136
137 } // namespace reference_ops
138 } // namespace tflite
139
140 #endif // TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_PROCESS_BROADCAST_SHAPES_H_
141