xref: /aosp_15_r20/external/tensorflow/tensorflow/lite/delegates/gpu/common/transformations/fuse_mul_to_conv.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/lite/delegates/gpu/common/transformations/fuse_mul_to_conv.h"
17 
18 #include <any>
19 #include <memory>
20 #include <string>
21 #include <variant>
22 #include <vector>
23 
24 #include "absl/strings/string_view.h"
25 #include "tensorflow/lite/delegates/gpu/common/data_type.h"
26 #include "tensorflow/lite/delegates/gpu/common/model.h"
27 #include "tensorflow/lite/delegates/gpu/common/model_transformer.h"
28 #include "tensorflow/lite/delegates/gpu/common/operations.h"
29 #include "tensorflow/lite/delegates/gpu/common/shape.h"
30 #include "tensorflow/lite/delegates/gpu/common/status.h"
31 #include "tensorflow/lite/delegates/gpu/common/tensor.h"
32 
33 namespace tflite {
34 namespace gpu {
35 namespace {
36 
37 class MergeConvolutionWithMul : public SequenceTransformation {
38  public:
ExpectedSequenceLength() const39   int ExpectedSequenceLength() const final { return 2; }
40 
ApplyToNodesSequence(const std::vector<Node * > & sequence,GraphFloat32 * graph)41   TransformResult ApplyToNodesSequence(const std::vector<Node*>& sequence,
42                                        GraphFloat32* graph) final {
43     auto& conv_node = *sequence[0];
44     if (graph->FindInputs(conv_node.id).size() != 1) {
45       return {TransformStatus::DECLINED,
46               "This fusion is only applicable to ops with one runtime input."};
47     }
48 
49     auto& mul_node = *sequence[1];
50     if (mul_node.operation.type != ToString(OperationType::MUL) ||
51         !mul_node.operation.attributes.has_value()) {
52       return {TransformStatus::SKIPPED, ""};
53     }
54 
55     ElementwiseAttributes mul_attr =
56         absl::any_cast<ElementwiseAttributes>(mul_node.operation.attributes);
57     if (!absl::holds_alternative<Tensor<Linear, DataType::FLOAT32>>(
58             mul_attr.param) &&
59         !absl::holds_alternative<float>(mul_attr.param)) {
60       return {
61           TransformStatus::DECLINED,
62           "This fuse applicable only for broadcast or scalar multiplication."};
63     }
64 
65     if (conv_node.operation.type == ToString(OperationType::CONVOLUTION_2D)) {
66       Convolution2DAttributes* conv_attr =
67           absl::any_cast<Convolution2DAttributes>(
68               &conv_node.operation.attributes);
69       FuseConvolution2DWithMultiply(mul_attr, conv_attr);
70     } else if (conv_node.operation.type ==
71                ToString(OperationType::CONVOLUTION_TRANSPOSED)) {
72       ConvolutionTransposedAttributes* conv_attr =
73           absl::any_cast<ConvolutionTransposedAttributes>(
74               &conv_node.operation.attributes);
75       FuseConvolutionTransposedWithMultiply(mul_attr, conv_attr);
76     } else if (conv_node.operation.type ==
77                ToString(OperationType::DEPTHWISE_CONVOLUTION)) {
78       DepthwiseConvolution2DAttributes* conv_attr =
79           absl::any_cast<DepthwiseConvolution2DAttributes>(
80               &conv_node.operation.attributes);
81       FuseDepthwiseConvolution2DWithMultiply(mul_attr, conv_attr);
82     } else if (conv_node.operation.type ==
83                ToString(OperationType::FULLY_CONNECTED)) {
84       FullyConnectedAttributes* conv_attr =
85           absl::any_cast<FullyConnectedAttributes>(
86               &conv_node.operation.attributes);
87       FuseFullyConnectedWithMultiply(mul_attr, conv_attr);
88     } else {
89       return {TransformStatus::SKIPPED, ""};
90     }
91 
92     absl::Status status = RemoveFollowingNode(graph, &mul_node, &conv_node);
93     if (!status.ok()) {
94       return {TransformStatus::INVALID,
95               "Unable to remove mul node after convolution: " +
96                   std::string(status.message())};
97     }
98     return {TransformStatus::APPLIED, ""};
99   }
100 };
101 
102 class MergeMulWithConvolution : public SequenceTransformation {
103  public:
ExpectedSequenceLength() const104   int ExpectedSequenceLength() const final { return 2; }
105 
ApplyToNodesSequence(const std::vector<Node * > & sequence,GraphFloat32 * graph)106   TransformResult ApplyToNodesSequence(const std::vector<Node*>& sequence,
107                                        GraphFloat32* graph) final {
108     auto& conv_node = *sequence[1];
109     if (graph->FindInputs(conv_node.id).size() != 1) {
110       return {TransformStatus::DECLINED,
111               "This fusion is only applicable to ops with one runtime input."};
112     }
113     auto& mul_node = *sequence[0];
114     if (mul_node.operation.type != ToString(OperationType::MUL) ||
115         !mul_node.operation.attributes.has_value()) {
116       return {TransformStatus::SKIPPED, ""};
117     }
118 
119     ElementwiseAttributes mul_attr =
120         absl::any_cast<ElementwiseAttributes>(mul_node.operation.attributes);
121     if (!absl::holds_alternative<Tensor<Linear, DataType::FLOAT32>>(
122             mul_attr.param) &&
123         !absl::holds_alternative<float>(mul_attr.param)) {
124       return {
125           TransformStatus::DECLINED,
126           "This fuse applicable only for broadcast or scalar multiplication."};
127     }
128 
129     if (conv_node.operation.type == ToString(OperationType::CONVOLUTION_2D)) {
130       Convolution2DAttributes* conv_attr =
131           absl::any_cast<Convolution2DAttributes>(
132               &conv_node.operation.attributes);
133       FuseMultiplyWithConvolution2D(mul_attr, conv_attr);
134     } else if (conv_node.operation.type ==
135                ToString(OperationType::CONVOLUTION_TRANSPOSED)) {
136       ConvolutionTransposedAttributes* conv_attr =
137           absl::any_cast<ConvolutionTransposedAttributes>(
138               &conv_node.operation.attributes);
139       FuseMultiplyWithConvolutionTransposed(mul_attr, conv_attr);
140     } else if (conv_node.operation.type ==
141                ToString(OperationType::DEPTHWISE_CONVOLUTION)) {
142       DepthwiseConvolution2DAttributes* conv_attr =
143           absl::any_cast<DepthwiseConvolution2DAttributes>(
144               &conv_node.operation.attributes);
145       FuseMultiplyWithDepthwiseConvolution2D(mul_attr, conv_attr);
146     } else if (conv_node.operation.type ==
147                ToString(OperationType::FULLY_CONNECTED)) {
148       FullyConnectedAttributes* conv_attr =
149           absl::any_cast<FullyConnectedAttributes>(
150               &conv_node.operation.attributes);
151       FuseMultiplyWithFullyConnected(mul_attr, conv_attr);
152     } else {
153       return {TransformStatus::SKIPPED, ""};
154     }
155 
156     absl::Status status = RemovePrecedingNode(graph, &mul_node, &conv_node);
157     if (!status.ok()) {
158       return {TransformStatus::INVALID,
159               "Unable to remove mul node after convolution: " +
160                   std::string(status.message())};
161     }
162     return {TransformStatus::APPLIED, ""};
163   }
164 };
165 
166 }  // namespace
167 
NewMergeConvolutionWithMul()168 std::unique_ptr<SequenceTransformation> NewMergeConvolutionWithMul() {
169   return absl::make_unique<MergeConvolutionWithMul>();
170 }
171 
NewMergeMulWithConvolution()172 std::unique_ptr<SequenceTransformation> NewMergeMulWithConvolution() {
173   return absl::make_unique<MergeMulWithConvolution>();
174 }
175 
FuseConvolution2DWithMultiply(const ElementwiseAttributes & mul_attr,Convolution2DAttributes * attr)176 void FuseConvolution2DWithMultiply(const ElementwiseAttributes& mul_attr,
177                                    Convolution2DAttributes* attr) {
178   auto mul = absl::get_if<Tensor<Linear, DataType::FLOAT32>>(&mul_attr.param);
179   auto mul_scalar = absl::get_if<float>(&mul_attr.param);
180   for (int d = 0; d < attr->weights.shape.o; ++d) {
181     const float multiplier = mul ? mul->data[d] : *mul_scalar;
182     for (int s = 0; s < attr->weights.shape.i; ++s) {
183       for (int k_y = 0; k_y < attr->weights.shape.h; ++k_y) {
184         for (int k_x = 0; k_x < attr->weights.shape.w; ++k_x) {
185           const int index = attr->weights.shape.LinearIndex({{d, k_y, k_x, s}});
186           attr->weights.data[index] *= multiplier;
187         }
188       }
189     }
190     if (!attr->bias.data.empty()) {
191       attr->bias.data[d] *= multiplier;
192     }
193   }
194 }
195 
FuseDepthwiseConvolution2DWithMultiply(const ElementwiseAttributes & mul_attr,DepthwiseConvolution2DAttributes * attr)196 void FuseDepthwiseConvolution2DWithMultiply(
197     const ElementwiseAttributes& mul_attr,
198     DepthwiseConvolution2DAttributes* attr) {
199   auto mul = absl::get_if<Tensor<Linear, DataType::FLOAT32>>(&mul_attr.param);
200   auto mul_scalar = absl::get_if<float>(&mul_attr.param);
201   for (int g = 0; g < attr->weights.shape.o; ++g) {
202     for (int s = 0; s < attr->weights.shape.i; ++s) {
203       const int d = s * attr->weights.shape.o + g;
204       const float multiplier = mul ? mul->data[d] : *mul_scalar;
205       for (int k_y = 0; k_y < attr->weights.shape.h; ++k_y) {
206         for (int k_x = 0; k_x < attr->weights.shape.w; ++k_x) {
207           const int index = attr->weights.shape.LinearIndex({{g, k_y, k_x, s}});
208           attr->weights.data[index] *= multiplier;
209         }
210       }
211       if (!attr->bias.data.empty()) {
212         attr->bias.data[d] *= multiplier;
213       }
214     }
215   }
216 }
217 
FuseConvolutionTransposedWithMultiply(const ElementwiseAttributes & mul_attr,ConvolutionTransposedAttributes * attr)218 void FuseConvolutionTransposedWithMultiply(
219     const ElementwiseAttributes& mul_attr,
220     ConvolutionTransposedAttributes* attr) {
221   auto mul = absl::get_if<Tensor<Linear, DataType::FLOAT32>>(&mul_attr.param);
222   auto mul_scalar = absl::get_if<float>(&mul_attr.param);
223   for (int d = 0; d < attr->weights.shape.o; ++d) {
224     const float multiplier = mul ? mul->data[d] : *mul_scalar;
225     for (int s = 0; s < attr->weights.shape.i; ++s) {
226       for (int k_y = 0; k_y < attr->weights.shape.h; ++k_y) {
227         for (int k_x = 0; k_x < attr->weights.shape.w; ++k_x) {
228           const int index = attr->weights.shape.LinearIndex({{d, k_y, k_x, s}});
229           attr->weights.data[index] *= multiplier;
230         }
231       }
232     }
233     if (!attr->bias.data.empty()) {
234       attr->bias.data[d] *= multiplier;
235     }
236   }
237 }
238 
FuseFullyConnectedWithMultiply(const ElementwiseAttributes & mul_attr,FullyConnectedAttributes * attr)239 void FuseFullyConnectedWithMultiply(const ElementwiseAttributes& mul_attr,
240                                     FullyConnectedAttributes* attr) {
241   auto mul = absl::get_if<Tensor<Linear, DataType::FLOAT32>>(&mul_attr.param);
242   auto mul_scalar = absl::get_if<float>(&mul_attr.param);
243   for (int d = 0; d < attr->weights.shape.o; ++d) {
244     const float multiplier = mul ? mul->data[d] : *mul_scalar;
245     for (int s = 0; s < attr->weights.shape.i; ++s) {
246       const int index = attr->weights.shape.LinearIndex({{d, 0, 0, s}});
247       attr->weights.data[index] *= multiplier;
248     }
249     if (!attr->bias.data.empty()) {
250       attr->bias.data[d] *= multiplier;
251     }
252   }
253 }
254 
FuseMultiplyWithConvolution2D(const ElementwiseAttributes & mul_attr,Convolution2DAttributes * attr)255 void FuseMultiplyWithConvolution2D(const ElementwiseAttributes& mul_attr,
256                                    Convolution2DAttributes* attr) {
257   auto mul = absl::get_if<Tensor<Linear, DataType::FLOAT32>>(&mul_attr.param);
258   auto mul_scalar = absl::get_if<float>(&mul_attr.param);
259   for (int s = 0; s < attr->weights.shape.i; ++s) {
260     const float multiplier = mul ? mul->data[s] : *mul_scalar;
261     for (int d = 0; d < attr->weights.shape.o; ++d) {
262       for (int k_y = 0; k_y < attr->weights.shape.h; ++k_y) {
263         for (int k_x = 0; k_x < attr->weights.shape.w; ++k_x) {
264           const int index = attr->weights.shape.LinearIndex({{d, k_y, k_x, s}});
265           attr->weights.data[index] *= multiplier;
266         }
267       }
268     }
269   }
270 }
271 
FuseMultiplyWithDepthwiseConvolution2D(const ElementwiseAttributes & mul_attr,DepthwiseConvolution2DAttributes * attr)272 void FuseMultiplyWithDepthwiseConvolution2D(
273     const ElementwiseAttributes& mul_attr,
274     DepthwiseConvolution2DAttributes* attr) {
275   auto mul = absl::get_if<Tensor<Linear, DataType::FLOAT32>>(&mul_attr.param);
276   auto mul_scalar = absl::get_if<float>(&mul_attr.param);
277   for (int s = 0; s < attr->weights.shape.i; ++s) {
278     const float multiplier = mul ? mul->data[s] : *mul_scalar;
279     for (int g = 0; g < attr->weights.shape.o; ++g) {
280       for (int k_y = 0; k_y < attr->weights.shape.h; ++k_y) {
281         for (int k_x = 0; k_x < attr->weights.shape.w; ++k_x) {
282           const int index = attr->weights.shape.LinearIndex({{g, k_y, k_x, s}});
283           attr->weights.data[index] *= multiplier;
284         }
285       }
286     }
287   }
288 }
289 
FuseMultiplyWithConvolutionTransposed(const ElementwiseAttributes & mul_attr,ConvolutionTransposedAttributes * attr)290 void FuseMultiplyWithConvolutionTransposed(
291     const ElementwiseAttributes& mul_attr,
292     ConvolutionTransposedAttributes* attr) {
293   auto mul = absl::get_if<Tensor<Linear, DataType::FLOAT32>>(&mul_attr.param);
294   auto mul_scalar = absl::get_if<float>(&mul_attr.param);
295   for (int s = 0; s < attr->weights.shape.i; ++s) {
296     const float multiplier = mul ? mul->data[s] : *mul_scalar;
297     for (int d = 0; d < attr->weights.shape.o; ++d) {
298       for (int k_y = 0; k_y < attr->weights.shape.h; ++k_y) {
299         for (int k_x = 0; k_x < attr->weights.shape.w; ++k_x) {
300           const int index = attr->weights.shape.LinearIndex({{d, k_y, k_x, s}});
301           attr->weights.data[index] *= multiplier;
302         }
303       }
304     }
305   }
306 }
307 
FuseMultiplyWithFullyConnected(const ElementwiseAttributes & mul_attr,FullyConnectedAttributes * attr)308 void FuseMultiplyWithFullyConnected(const ElementwiseAttributes& mul_attr,
309                                     FullyConnectedAttributes* attr) {
310   auto mul = absl::get_if<Tensor<Linear, DataType::FLOAT32>>(&mul_attr.param);
311   auto mul_scalar = absl::get_if<float>(&mul_attr.param);
312   for (int s = 0; s < attr->weights.shape.i; ++s) {
313     const float multiplier = mul ? mul->data[s] : *mul_scalar;
314     for (int d = 0; d < attr->weights.shape.o; ++d) {
315       const int index = attr->weights.shape.LinearIndex({{d, 0, 0, s}});
316       attr->weights.data[index] *= multiplier;
317     }
318   }
319 }
320 
321 }  // namespace gpu
322 }  // namespace tflite
323