/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_PROCESS_BROADCAST_SHAPES_H_
#define TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_PROCESS_BROADCAST_SHAPES_H_

#include <algorithm>

#include "tensorflow/lite/kernels/internal/types.h"

namespace tflite {

namespace reference_ops {

// Consolidates dimensions in broadcast inputs and checks for a fivefold
// pattern.
//
// For example, if the sequence of dimensions of one input is
// ..., 1, 3, 1, 7, 9, 5, ... and the other is ..., 2, 3, 1, 7, 1, 1, ...,
// we can consolidate these as
// ..., 1, 3*7, 9*5, ... and 2, 3*7, 1.
//
// The category is updated in the less-frequent case of shapes that are
// not suited to a fivefold-loop broadcast.
//
// Falls back to the generic pattern when the shapes cannot be processed as a
// fivefold broadcast.
//
// Returns true iff there is some sort of broadcast, which includes both the
// fivefold patterns and the fall-back to a generic broadcast.
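//
// An illustrative worked example of the loops below: for the concrete shapes
// {1, 3, 1, 7, 9, 5} and {2, 3, 1, 7, 1, 1}, the function selects
// kSecondInputBroadcastsFast (the second input has the size-1 dim at the
// innermost mismatch) and fills params->broadcast_shape = {1, 2, 21, 45, 1},
// i.e. consolidated runs of sizes 2, 3*7 = 21 and 9*5 = 45.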
inline bool ProcessBroadcastShapes(const RuntimeShape& shape0,
                                   const RuntimeShape& shape1,
                                   tflite::ArithmeticParams* params) {
  const int dims_count =
      std::max(shape0.DimensionsCount(), shape1.DimensionsCount());

  params->broadcast_category = BroadcastableOpCategory::kGenericBroadcast;
  RuntimeShape scalar_shape(dims_count, 1);

  auto extended_shape0 = RuntimeShape::ExtendedShape(dims_count, shape0);
  auto extended_shape1 = RuntimeShape::ExtendedShape(dims_count, shape1);

  // Check for "exact" match, implicitly accepting any scalar shapes.
  if (extended_shape0 == extended_shape1) {
    params->broadcast_category = BroadcastableOpCategory::kNonBroadcast;
    return false;
  }

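  // Scan from the innermost dimension outward: the first mismatching
  // dimension determines which input has the size-1 dim there, and hence
  // which "fast" broadcast category applies.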
  for (int i = dims_count - 1; i >= 0; --i) {
    if (extended_shape0.Dims(i) == extended_shape1.Dims(i)) {
      continue;
    } else if (extended_shape0.Dims(i) == 1) {
      params->broadcast_category =
          BroadcastableOpCategory::kFirstInputBroadcastsFast;
      break;
    } else if (extended_shape1.Dims(i) == 1) {
      params->broadcast_category =
          BroadcastableOpCategory::kSecondInputBroadcastsFast;
      break;
    } else {
      // This case is erroneous: there is a dimension that does not match and
      // is not a broadcast from one shape to the other.
      params->broadcast_category = BroadcastableOpCategory::kGenericBroadcast;
      return true;
    }
  }

  if (params->broadcast_category !=
          BroadcastableOpCategory::kFirstInputBroadcastsFast &&
      params->broadcast_category !=
          BroadcastableOpCategory::kSecondInputBroadcastsFast) {
    // This should be unreachable: the shapes are not identical, so at least
    // one of the else-if branches in the loop above must have been taken.
    TFLITE_DCHECK(false);
    params->broadcast_category = BroadcastableOpCategory::kNonBroadcast;
    return false;
  }

  // From this point it is assumed contractually that corresponding dimensions
  // in shape0 and shape1 are either (a) equal or (b) one or other equals 1.
  const bool swap_inputs = params->broadcast_category ==
                           BroadcastableOpCategory::kSecondInputBroadcastsFast;
  const RuntimeShape* shape_a =
      swap_inputs ? &extended_shape1 : &extended_shape0;
  const RuntimeShape* shape_b =
      swap_inputs ? &extended_shape0 : &extended_shape1;

  int i = dims_count - 1;
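  // Note: broadcast_shape[0..4] accumulates the sizes of five consolidated
  // runs of dimensions, with index 4 corresponding to the innermost run; any
  // run that is never entered below keeps its initial size of 1.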
  params->broadcast_shape[0] = 1;
  params->broadcast_shape[1] = 1;
  params->broadcast_shape[2] = 1;
  params->broadcast_shape[3] = 1;
  params->broadcast_shape[4] = 1;
  // The innermost run (broadcast_shape[4]) is greedy: include dims if both or
  // neither equal 1; in other words, test for equality rather than
  // (shape_a->Dims(i) != 1).
  while (i >= 0 && shape_a->Dims(i) == shape_b->Dims(i)) {
    params->broadcast_shape[4] *= shape_b->Dims(i);
    --i;
  }
  // Here either shape_a or shape_b has a dim of 1 (if i >= 0).  If it is
  // shape_b that has the unit dimension, the next two loops are not entered.
  while (i >= 0 && shape_a->Dims(i) == 1) {
    params->broadcast_shape[3] *= shape_b->Dims(i);
    --i;
  }
  while (i >= 0 && shape_a->Dims(i) == shape_b->Dims(i)) {
    params->broadcast_shape[2] *= shape_a->Dims(i);
    --i;
  }
  // Here either shape_a or shape_b has a dim of 1 (if i >= 0).
  while (i >= 0 && shape_b->Dims(i) == 1) {
    params->broadcast_shape[1] *= shape_a->Dims(i);
    --i;
  }
  while (i >= 0 && shape_a->Dims(i) == shape_b->Dims(i)) {
    params->broadcast_shape[0] *= shape_b->Dims(i);
    --i;
  }

  // The rarer case is when the broadcast dimensions cannot be handled by a
  // fivefold loop, so fall back to the generic broadcast category.
  if (i >= 0) {
    params->broadcast_category = BroadcastableOpCategory::kGenericBroadcast;
  }
  return true;
}
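
// Illustrative usage sketch (an assumption about a typical caller, not part of
// what this header defines): a binary elementwise kernel fills in its
// ArithmeticParams, calls ProcessBroadcastShapes with both input shapes, and
// dispatches on the result. Variable names (input1_shape, input1_data, ...)
// are hypothetical.
//
//   tflite::ArithmeticParams op_params;
//   // ... fill activation/quantization fields as the op requires ...
//   const bool need_broadcast = reference_ops::ProcessBroadcastShapes(
//       input1_shape, input2_shape, &op_params);
//   if (need_broadcast) {
//     // Reference kernels use a generic slow path; optimized kernels can
//     // exploit the fivefold categories and broadcast_shape computed above.
//     reference_ops::BroadcastAdd4DSlow(op_params, input1_shape, input1_data,
//                                       input2_shape, input2_data,
//                                       output_shape, output_data);
//   } else {
//     reference_ops::Add(op_params, input1_shape, input1_data, input2_shape,
//                        input2_data, output_shape, output_data);
//   }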

}  // namespace reference_ops
}  // namespace tflite

#endif  // TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_PROCESS_BROADCAST_SHAPES_H_