/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/lite/delegates/gpu/common/transformations/fuse_mul_to_conv.h"

#include <any>
#include <memory>
#include <string>
#include <variant>
#include <vector>

#include "absl/strings/string_view.h"
#include "tensorflow/lite/delegates/gpu/common/data_type.h"
#include "tensorflow/lite/delegates/gpu/common/model.h"
#include "tensorflow/lite/delegates/gpu/common/model_transformer.h"
#include "tensorflow/lite/delegates/gpu/common/operations.h"
#include "tensorflow/lite/delegates/gpu/common/shape.h"
#include "tensorflow/lite/delegates/gpu/common/status.h"
#include "tensorflow/lite/delegates/gpu/common/tensor.h"

namespace tflite {
namespace gpu {
namespace {

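// Fuses a multiplication by a per-channel tensor or a scalar that immediately
// follows a convolution-like node (2D convolution, transposed convolution,
// depthwise convolution or fully connected) into that node's weights and bias.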
class MergeConvolutionWithMul : public SequenceTransformation {
 public:
  int ExpectedSequenceLength() const final { return 2; }

  TransformResult ApplyToNodesSequence(const std::vector<Node*>& sequence,
                                       GraphFloat32* graph) final {
    auto& conv_node = *sequence[0];
    if (graph->FindInputs(conv_node.id).size() != 1) {
      return {TransformStatus::DECLINED,
              "This fusion is only applicable to ops with one runtime input."};
    }

    auto& mul_node = *sequence[1];
    if (mul_node.operation.type != ToString(OperationType::MUL) ||
        !mul_node.operation.attributes.has_value()) {
      return {TransformStatus::SKIPPED, ""};
    }

    ElementwiseAttributes mul_attr =
        std::any_cast<ElementwiseAttributes>(mul_node.operation.attributes);
    if (!std::holds_alternative<Tensor<Linear, DataType::FLOAT32>>(
            mul_attr.param) &&
        !std::holds_alternative<float>(mul_attr.param)) {
      return {
          TransformStatus::DECLINED,
          "This fusion is applicable only to broadcast or scalar "
          "multiplication."};
    }

    if (conv_node.operation.type == ToString(OperationType::CONVOLUTION_2D)) {
      Convolution2DAttributes* conv_attr =
          std::any_cast<Convolution2DAttributes>(
              &conv_node.operation.attributes);
      FuseConvolution2DWithMultiply(mul_attr, conv_attr);
    } else if (conv_node.operation.type ==
               ToString(OperationType::CONVOLUTION_TRANSPOSED)) {
      ConvolutionTransposedAttributes* conv_attr =
          std::any_cast<ConvolutionTransposedAttributes>(
              &conv_node.operation.attributes);
      FuseConvolutionTransposedWithMultiply(mul_attr, conv_attr);
    } else if (conv_node.operation.type ==
               ToString(OperationType::DEPTHWISE_CONVOLUTION)) {
      DepthwiseConvolution2DAttributes* conv_attr =
          std::any_cast<DepthwiseConvolution2DAttributes>(
              &conv_node.operation.attributes);
      FuseDepthwiseConvolution2DWithMultiply(mul_attr, conv_attr);
    } else if (conv_node.operation.type ==
               ToString(OperationType::FULLY_CONNECTED)) {
      FullyConnectedAttributes* conv_attr =
          std::any_cast<FullyConnectedAttributes>(
              &conv_node.operation.attributes);
      FuseFullyConnectedWithMultiply(mul_attr, conv_attr);
    } else {
      return {TransformStatus::SKIPPED, ""};
    }

    absl::Status status = RemoveFollowingNode(graph, &mul_node, &conv_node);
    if (!status.ok()) {
      return {TransformStatus::INVALID,
              "Unable to remove mul node after convolution: " +
                  std::string(status.message())};
    }
    return {TransformStatus::APPLIED, ""};
  }
};

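// Fuses a multiplication by a per-channel tensor or a scalar that immediately
// precedes a convolution-like node into that node's weights. The bias is left
// untouched because the multiplication is applied to the node's input.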
class MergeMulWithConvolution : public SequenceTransformation {
 public:
  int ExpectedSequenceLength() const final { return 2; }

  TransformResult ApplyToNodesSequence(const std::vector<Node*>& sequence,
                                       GraphFloat32* graph) final {
    auto& conv_node = *sequence[1];
    if (graph->FindInputs(conv_node.id).size() != 1) {
      return {TransformStatus::DECLINED,
              "This fusion is only applicable to ops with one runtime input."};
    }
    auto& mul_node = *sequence[0];
    if (mul_node.operation.type != ToString(OperationType::MUL) ||
        !mul_node.operation.attributes.has_value()) {
      return {TransformStatus::SKIPPED, ""};
    }

    ElementwiseAttributes mul_attr =
        std::any_cast<ElementwiseAttributes>(mul_node.operation.attributes);
    if (!std::holds_alternative<Tensor<Linear, DataType::FLOAT32>>(
            mul_attr.param) &&
        !std::holds_alternative<float>(mul_attr.param)) {
      return {
          TransformStatus::DECLINED,
          "This fusion is applicable only to broadcast or scalar "
          "multiplication."};
    }

    if (conv_node.operation.type == ToString(OperationType::CONVOLUTION_2D)) {
      Convolution2DAttributes* conv_attr =
          std::any_cast<Convolution2DAttributes>(
              &conv_node.operation.attributes);
      FuseMultiplyWithConvolution2D(mul_attr, conv_attr);
    } else if (conv_node.operation.type ==
               ToString(OperationType::CONVOLUTION_TRANSPOSED)) {
      ConvolutionTransposedAttributes* conv_attr =
          std::any_cast<ConvolutionTransposedAttributes>(
              &conv_node.operation.attributes);
      FuseMultiplyWithConvolutionTransposed(mul_attr, conv_attr);
    } else if (conv_node.operation.type ==
               ToString(OperationType::DEPTHWISE_CONVOLUTION)) {
      DepthwiseConvolution2DAttributes* conv_attr =
          std::any_cast<DepthwiseConvolution2DAttributes>(
              &conv_node.operation.attributes);
      FuseMultiplyWithDepthwiseConvolution2D(mul_attr, conv_attr);
    } else if (conv_node.operation.type ==
               ToString(OperationType::FULLY_CONNECTED)) {
      FullyConnectedAttributes* conv_attr =
          std::any_cast<FullyConnectedAttributes>(
              &conv_node.operation.attributes);
      FuseMultiplyWithFullyConnected(mul_attr, conv_attr);
    } else {
      return {TransformStatus::SKIPPED, ""};
    }

    absl::Status status = RemovePrecedingNode(graph, &mul_node, &conv_node);
    if (!status.ok()) {
      return {TransformStatus::INVALID,
              "Unable to remove mul node before convolution: " +
                  std::string(status.message())};
    }
    return {TransformStatus::APPLIED, ""};
  }
};

}  // namespace

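// Returns a transformation that merges a MUL node into the convolution-like
// node that precedes it.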
std::unique_ptr<SequenceTransformation> NewMergeConvolutionWithMul() {
  return std::make_unique<MergeConvolutionWithMul>();
}

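// Returns a transformation that merges a MUL node into the convolution-like
// node that follows it.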
std::unique_ptr<SequenceTransformation> NewMergeMulWithConvolution() {
  return std::make_unique<MergeMulWithConvolution>();
}

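// Scales the convolution weights and bias per output channel so that the
// modified convolution is equivalent to the original convolution followed by
// the multiplication.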
void FuseConvolution2DWithMultiply(const ElementwiseAttributes& mul_attr,
                                   Convolution2DAttributes* attr) {
  auto mul = std::get_if<Tensor<Linear, DataType::FLOAT32>>(&mul_attr.param);
  auto mul_scalar = std::get_if<float>(&mul_attr.param);
  for (int d = 0; d < attr->weights.shape.o; ++d) {
    const float multiplier = mul ? mul->data[d] : *mul_scalar;
    for (int s = 0; s < attr->weights.shape.i; ++s) {
      for (int k_y = 0; k_y < attr->weights.shape.h; ++k_y) {
        for (int k_x = 0; k_x < attr->weights.shape.w; ++k_x) {
          const int index = attr->weights.shape.LinearIndex({{d, k_y, k_x, s}});
          attr->weights.data[index] *= multiplier;
        }
      }
    }
    if (!attr->bias.data.empty()) {
      attr->bias.data[d] *= multiplier;
    }
  }
}

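// Scales the depthwise convolution weights and bias so that the modified op is
// equivalent to the original op followed by the multiplication. The multiplier
// index d = s * O + g matches the depthwise output channel layout.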
void FuseDepthwiseConvolution2DWithMultiply(
    const ElementwiseAttributes& mul_attr,
    DepthwiseConvolution2DAttributes* attr) {
  auto mul = std::get_if<Tensor<Linear, DataType::FLOAT32>>(&mul_attr.param);
  auto mul_scalar = std::get_if<float>(&mul_attr.param);
  for (int g = 0; g < attr->weights.shape.o; ++g) {
    for (int s = 0; s < attr->weights.shape.i; ++s) {
      const int d = s * attr->weights.shape.o + g;
      const float multiplier = mul ? mul->data[d] : *mul_scalar;
      for (int k_y = 0; k_y < attr->weights.shape.h; ++k_y) {
        for (int k_x = 0; k_x < attr->weights.shape.w; ++k_x) {
          const int index = attr->weights.shape.LinearIndex({{g, k_y, k_x, s}});
          attr->weights.data[index] *= multiplier;
        }
      }
      if (!attr->bias.data.empty()) {
        attr->bias.data[d] *= multiplier;
      }
    }
  }
}

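// Scales the transposed convolution weights and bias per output channel so
// that the modified op is equivalent to the original op followed by the
// multiplication.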
void FuseConvolutionTransposedWithMultiply(
    const ElementwiseAttributes& mul_attr,
    ConvolutionTransposedAttributes* attr) {
  auto mul = std::get_if<Tensor<Linear, DataType::FLOAT32>>(&mul_attr.param);
  auto mul_scalar = std::get_if<float>(&mul_attr.param);
  for (int d = 0; d < attr->weights.shape.o; ++d) {
    const float multiplier = mul ? mul->data[d] : *mul_scalar;
    for (int s = 0; s < attr->weights.shape.i; ++s) {
      for (int k_y = 0; k_y < attr->weights.shape.h; ++k_y) {
        for (int k_x = 0; k_x < attr->weights.shape.w; ++k_x) {
          const int index = attr->weights.shape.LinearIndex({{d, k_y, k_x, s}});
          attr->weights.data[index] *= multiplier;
        }
      }
    }
    if (!attr->bias.data.empty()) {
      attr->bias.data[d] *= multiplier;
    }
  }
}

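// Scales the fully connected weights and bias per output channel so that the
// modified op is equivalent to the original op followed by the multiplication.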
void FuseFullyConnectedWithMultiply(const ElementwiseAttributes& mul_attr,
                                    FullyConnectedAttributes* attr) {
  auto mul = std::get_if<Tensor<Linear, DataType::FLOAT32>>(&mul_attr.param);
  auto mul_scalar = std::get_if<float>(&mul_attr.param);
  for (int d = 0; d < attr->weights.shape.o; ++d) {
    const float multiplier = mul ? mul->data[d] : *mul_scalar;
    for (int s = 0; s < attr->weights.shape.i; ++s) {
      const int index = attr->weights.shape.LinearIndex({{d, 0, 0, s}});
      attr->weights.data[index] *= multiplier;
    }
    if (!attr->bias.data.empty()) {
      attr->bias.data[d] *= multiplier;
    }
  }
}

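// Scales the convolution weights per input channel so that the modified
// convolution is equivalent to the multiplication followed by the original
// convolution.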
void FuseMultiplyWithConvolution2D(const ElementwiseAttributes& mul_attr,
                                   Convolution2DAttributes* attr) {
  auto mul = std::get_if<Tensor<Linear, DataType::FLOAT32>>(&mul_attr.param);
  auto mul_scalar = std::get_if<float>(&mul_attr.param);
  for (int s = 0; s < attr->weights.shape.i; ++s) {
    const float multiplier = mul ? mul->data[s] : *mul_scalar;
    for (int d = 0; d < attr->weights.shape.o; ++d) {
      for (int k_y = 0; k_y < attr->weights.shape.h; ++k_y) {
        for (int k_x = 0; k_x < attr->weights.shape.w; ++k_x) {
          const int index = attr->weights.shape.LinearIndex({{d, k_y, k_x, s}});
          attr->weights.data[index] *= multiplier;
        }
      }
    }
  }
}

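// Scales the depthwise convolution weights per input channel so that the
// modified op is equivalent to the multiplication followed by the original op.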
void FuseMultiplyWithDepthwiseConvolution2D(
    const ElementwiseAttributes& mul_attr,
    DepthwiseConvolution2DAttributes* attr) {
  auto mul = std::get_if<Tensor<Linear, DataType::FLOAT32>>(&mul_attr.param);
  auto mul_scalar = std::get_if<float>(&mul_attr.param);
  for (int s = 0; s < attr->weights.shape.i; ++s) {
    const float multiplier = mul ? mul->data[s] : *mul_scalar;
    for (int g = 0; g < attr->weights.shape.o; ++g) {
      for (int k_y = 0; k_y < attr->weights.shape.h; ++k_y) {
        for (int k_x = 0; k_x < attr->weights.shape.w; ++k_x) {
          const int index = attr->weights.shape.LinearIndex({{g, k_y, k_x, s}});
          attr->weights.data[index] *= multiplier;
        }
      }
    }
  }
}

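// Scales the transposed convolution weights per input channel so that the
// modified op is equivalent to the multiplication followed by the original op.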
void FuseMultiplyWithConvolutionTransposed(
    const ElementwiseAttributes& mul_attr,
    ConvolutionTransposedAttributes* attr) {
  auto mul = std::get_if<Tensor<Linear, DataType::FLOAT32>>(&mul_attr.param);
  auto mul_scalar = std::get_if<float>(&mul_attr.param);
  for (int s = 0; s < attr->weights.shape.i; ++s) {
    const float multiplier = mul ? mul->data[s] : *mul_scalar;
    for (int d = 0; d < attr->weights.shape.o; ++d) {
      for (int k_y = 0; k_y < attr->weights.shape.h; ++k_y) {
        for (int k_x = 0; k_x < attr->weights.shape.w; ++k_x) {
          const int index = attr->weights.shape.LinearIndex({{d, k_y, k_x, s}});
          attr->weights.data[index] *= multiplier;
        }
      }
    }
  }
}

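// Scales the fully connected weights per input channel so that the modified op
// is equivalent to the multiplication followed by the original op.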
void FuseMultiplyWithFullyConnected(const ElementwiseAttributes& mul_attr,
                                    FullyConnectedAttributes* attr) {
  auto mul = std::get_if<Tensor<Linear, DataType::FLOAT32>>(&mul_attr.param);
  auto mul_scalar = std::get_if<float>(&mul_attr.param);
  for (int s = 0; s < attr->weights.shape.i; ++s) {
    const float multiplier = mul ? mul->data[s] : *mul_scalar;
    for (int d = 0; d < attr->weights.shape.o; ++d) {
      const int index = attr->weights.shape.LinearIndex({{d, 0, 0, s}});
      attr->weights.data[index] *= multiplier;
    }
  }
}

}  // namespace gpu
}  // namespace tflite