/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/service/gpu/cudnn_fused_conv_rewriter.h"

#include <functional>
#include <string>
#include <utility>
#include <vector>

#include "absl/algorithm/container.h"
#include "absl/container/flat_hash_map.h"
#include "absl/container/inlined_vector.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_format.h"
#include "absl/strings/str_join.h"
#include "absl/strings/string_view.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/primitive_util.h"
#include "tensorflow/compiler/xla/service/gpu/backend_configs.pb.h"
#include "tensorflow/compiler/xla/service/gpu/cublas_cudnn.h"
#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
#include "tensorflow/compiler/xla/service/hlo_creation_utils.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/pattern_matcher.h"
#include "tensorflow/compiler/xla/stream_executor/dnn.pb.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/statusor.h"
#include "tensorflow/core/platform/stream_executor_no_cuda.h"

namespace xla {
namespace gpu {
namespace {

namespace m = match;

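// Matches a cudnn forward-conv custom-call: either the plain forward conv or
// the fused conv-bias-activation variant.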
bool IsConvCustomCall(const HloInstruction* instr) {
  return instr->opcode() == HloOpcode::kCustomCall &&
         (instr->custom_call_target() == kCudnnConvForwardCallTarget ||
          instr->custom_call_target() ==
              kCudnnConvBiasActivationForwardCallTarget);
}

// Can instr be converted to type `dst_ty` without losing any precision? For
// our purposes, this is true if:
//
// - instr already has type dst_ty, or
// - instr is convert<wider type>(op_with_dst_ty), or
// - instr is a constant which we can convert orig_ty -> dst_ty -> orig_ty and
//   get back exactly the original value, or
// - instr is a broadcast, reshape, or transpose of one of the above.
bool IsLosslesslyConvertibleTo(const HloInstruction* instr,
                               PrimitiveType dst_ty) {
  if (instr->shape().element_type() == dst_ty) {
    return true;
  }

  if (Match(instr, m::Convert(m::Op().WithElementType(dst_ty)))) {
    // Check that the convert from dst_ty to instr->element_type() doesn't lose
    // precision. Otherwise, this convert is not lossless.
    return primitive_util::CastPreservesValues(dst_ty,
                                               instr->shape().element_type());
  }

  if (instr->opcode() == HloOpcode::kConstant) {
    if (!instr->shape().IsArray()) {
      return false;
    }
    // Check if instr's literal roundtrips from its original type to dst_ty and
    // back without modification.
    PrimitiveType orig_ty = instr->shape().element_type();

    // The only reason Convert() should fail is if we don't support converting
    // from x to y, which indeed means it's not losslessly-convertible.
    StatusOr<Literal> converted1 = instr->literal().Convert(dst_ty);
    if (!converted1.ok()) {
      return false;
    }
    StatusOr<Literal> converted2 = converted1->Convert(orig_ty);
    if (!converted2.ok()) {
      return false;
    }

    return instr->literal() == *converted2;
  }

  if (instr->opcode() == HloOpcode::kBroadcast ||
      instr->opcode() == HloOpcode::kReshape ||
      instr->opcode() == HloOpcode::kTranspose) {
    return IsLosslesslyConvertibleTo(instr->operand(0), dst_ty);
  }

  return false;
}

// Helpers suitable for use in m::Op().WithPredicate(...).
bool IsLosslesslyConvertibleToS8(const HloInstruction* instr) {
  return IsLosslesslyConvertibleTo(instr, S8);
}
bool IsLosslesslyConvertibleToF16(const HloInstruction* instr) {
  return IsLosslesslyConvertibleTo(instr, F16);
}

// If `conv` is a vanilla forward conv, transforms it into a
// conv-bias-activation. If it's already a conv-bias-activation, does nothing.
//
// If `conv` is anything else, returns an error.
StatusOr<HloInstruction*> EnsureIsConvBiasActivation(HloInstruction* conv) {
  CHECK_EQ(conv->opcode(), HloOpcode::kCustomCall);

  if (conv->custom_call_target() ==
      kCudnnConvBiasActivationForwardCallTarget) {
    return conv;
  }

  if (conv->custom_call_target() == kCudnnConvForwardCallTarget) {
    HloComputation* comp = conv->parent();

    const Shape& shape = conv->shape().tuple_shapes(0);
    int64_t num_output_features = shape.dimensions(
        conv->convolution_dimension_numbers().output_feature_dimension());

    // bias for integer convs is always f32, see
    // https://docs.nvidia.com/deeplearning/cudnn/api/index.html#cudnnConvolutionBiasActivationForward
    PrimitiveType bias_ty;
    if (primitive_util::IsIntegralType(shape.element_type())) {
      bias_ty = F32;
    } else {
      bias_ty = shape.element_type();
    }
    auto bias = BroadcastZeros(comp, bias_ty, {num_output_features});

    absl::InlinedVector<HloInstruction*, 3> new_operands(
        conv->operands().begin(), conv->operands().end());
    new_operands.push_back(bias);

    HloInstruction* new_conv = comp->AddInstruction(
        conv->CloneWithNewOperands(conv->shape(), new_operands));
    TF_RETURN_IF_ERROR(comp->ReplaceInstruction(conv, new_conv));
    new_conv->set_custom_call_target(kCudnnConvBiasActivationForwardCallTarget);
    comp->parent()->SetAndUniquifyInstrName(new_conv,
                                            "cudnn-conv-bias-activation");
    return new_conv;
  }

  return FailedPrecondition("Unsupported conv: %s", conv->ToString());
}

// convert<float>(gte(custom-call<int32>(int8_x, int8_w))) ->
// gte(custom-call<float>(int8_x, int8_w))
StatusOr<bool> FuseConvertToFloat(HloComputation* comp) {
  bool changed = false;
  for (auto instr : comp->MakeInstructionPostOrder()) {
    HloInstruction* conv = nullptr;
    auto pattern =
        m::Convert(
            m::GetTupleElement(m::Op(&conv).WithPredicate(IsConvCustomCall), 0)
                .WithElementType(S32))
            .WithElementType(F32);
    if (!Match(instr, pattern)) {
      continue;
    }
    if (!ConsumeFuel("cudnn-fused-convolution-rewriter", [&] {
          return absl::StrCat("FuseConvertToFloat: ", conv->ToString());
        })) {
      continue;
    }

    Shape new_shape = conv->shape();
    new_shape.mutable_tuple_shapes(0)->set_element_type(F32);
    HloInstruction* new_conv =
        comp->AddInstruction(conv->CloneWithNewShape(new_shape));
    comp->parent()->SetAndUniquifyInstrName(new_conv, conv->name());
    TF_ASSIGN_OR_RETURN(HloInstruction * new_gte,
                        MakeGetTupleElementHlo(new_conv, 0));
    TF_RETURN_IF_ERROR(comp->ReplaceInstruction(instr, new_gte));

    changed = true;
  }

  return changed;
}

// alpha * gte(custom-call(...)) ->
// gte(custom-call(..., backend_config={alpha})).
StatusOr<bool> FuseConvAlpha(HloComputation* comp) {
  bool changed = false;
  for (auto instr : comp->MakeInstructionPostOrder()) {
    HloInstruction* conv = nullptr;
    HloInstruction* gte = nullptr;
    HloInstruction* alpha = nullptr;
    auto pattern = m::MultiplyAnyOrder(
        m::GetTupleElement(&gte, m::Op(&conv).WithPredicate(IsConvCustomCall),
                           0)
            .WithOneUse(),
        m::Broadcast(m::ConstantEffectiveScalar(&alpha)));
    if (!Match(instr, pattern)) {
      continue;
    }

    // alpha is f32 except for f64 convs, where it's f64. See
    // https://docs.nvidia.com/deeplearning/cudnn/api/index.html#cudnnConvolutionBiasActivationForward
    PrimitiveType alpha_ty = gte->shape().element_type() == F64 ? F64 : F32;
    if (!IsLosslesslyConvertibleTo(alpha, alpha_ty)) {
      continue;
    }

    TF_ASSIGN_OR_RETURN(auto config,
                        conv->backend_config<CudnnConvBackendConfig>());
    if (config.conv_result_scale() != 1) {
      continue;
    }
    if (!ConsumeFuel("cudnn-fused-convolution-rewriter", [&] {
          return absl::StrCat("FuseConvAlpha: ", conv->ToString());
        })) {
      continue;
    }

    // StreamExecutor doesn't support the alpha parameter on
    // non-bias-activation convs, so we have to upgrade `conv`.
    TF_ASSIGN_OR_RETURN(conv, EnsureIsConvBiasActivation(conv));

    TF_ASSIGN_OR_RETURN(Literal alpha_f64, alpha->literal().Convert(F64));
    config.set_conv_result_scale(alpha_f64.GetFirstElement<double>());

    TF_RETURN_IF_ERROR(conv->set_backend_config(config));
    TF_RETURN_IF_ERROR(conv->parent()->ReplaceInstruction(instr, gte));

    changed = true;
  }
  return changed;
}

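// gte(custom-call(...)) + addend ->
// gte(custom-call(..., addend)).
//
// The addend is fused either into the conv's bias operand (operand 2), if the
// addend is a broadcast of a bias-shaped value and the conv doesn't already
// have a nonzero bias, or else into its side-input operand (operand 3), if
// that slot is still free.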
StatusOr<bool> FuseBiasOrSideInput(HloComputation* comp) {
  bool changed = false;
  for (auto instr : comp->MakeInstructionPostOrder()) {
    HloInstruction* conv = nullptr;
    HloInstruction* gte = nullptr;
    HloInstruction* addend = nullptr;
    auto pattern = m::AddAnyOrder(
        m::GetTupleElement(
            &gte, m::Op(&conv).WithPredicate(IsConvCustomCall).WithOneUse(), 0)
            .WithOneUse(),
        m::Op(&addend));
    if (!Match(instr, pattern)) {
      continue;
    }

    // If it's a vanilla forward conv, upgrade it to a bias-activation conv. We
    // only want to do this if the fusion will succeed, but we're guaranteed
    // that it will, because the only reason we'll bail at this point is if
    // !can_accept_bias && !can_accept_side_input, and our shiny new
    // bias-activation conv will be able to accept both.
    if (conv->custom_call_target() == kCudnnConvForwardCallTarget) {
      TF_ASSIGN_OR_RETURN(conv, EnsureIsConvBiasActivation(conv));
    }

    // Can't fuse bias or side-input if the conv already has a relu (or other
    // activation), because bias and side-input are added before the activation
    // is applied.
    TF_ASSIGN_OR_RETURN(auto config,
                        conv->backend_config<CudnnConvBackendConfig>());
    if (config.activation_mode() != se::dnn::kNone) {
      continue;
    }

    // Does `conv` already have a (nonzero) bias? Does it already have a
    // side_input?
    bool can_accept_bias =
        Match(conv->operand(2), m::Broadcast(m::ConstantEffectiveScalar(0)));
    bool can_accept_side_input = conv->operand_count() < 4;

    // The addend can be fused as a bias if
    // - it is 1D broadcasted in the output feature dimension, and
    // - it is losslessly-convertible to the correct type (f32 for s8/f32/u32
    //   convs, and conv_ty for floating-point convs)
    PrimitiveType conv_ty = gte->shape().element_type();
    PrimitiveType bias_ty =
        primitive_util::IsFloatingPointType(conv_ty) ? conv_ty : F32;
    bool addend_may_be_rank1_bias =
        addend->opcode() == HloOpcode::kBroadcast &&
        addend->dimensions().size() == 1 &&
        addend->dimensions(0) ==
            conv->convolution_dimension_numbers().output_feature_dimension() &&
        IsLosslesslyConvertibleTo(addend, bias_ty);

    bool addend_may_be_rank0_bias =
        addend->opcode() == HloOpcode::kBroadcast &&
        addend->dimensions().empty() &&
        IsLosslesslyConvertibleTo(addend, bias_ty);

    absl::InlinedVector<HloInstruction*, 4> new_operands(
        conv->operands().begin(), conv->operands().end());
    if (can_accept_bias && addend_may_be_rank1_bias) {
      new_operands[2] = MakeConvertToHlo(addend->mutable_operand(0), bias_ty,
                                         &addend->operand(0)->metadata());
    } else if (can_accept_bias && addend_may_be_rank0_bias) {
      new_operands[2] = MakeBroadcastHlo(
          MakeConvertToHlo(addend->mutable_operand(0), bias_ty,
                           &addend->operand(0)->metadata()),
          /*broadcast_dimensions=*/{},
          /*result_shape_bounds=*/
          {gte->shape().dimensions(conv->convolution_dimension_numbers()
                                       .output_feature_dimension())});
    } else if (can_accept_side_input) {
      CHECK_EQ(new_operands.size(), 3);
      new_operands.push_back(addend);
      config.set_side_input_scale(1);
    } else {
      // Can't fuse; this op already has a bias and a side-input.
      continue;
    }

    if (!ConsumeFuel("cudnn-fused-convolution-rewriter", [&] {
          return absl::StrCat("FuseBiasOrSideInput: ", conv->ToString());
        })) {
      continue;
    }

    HloInstruction* new_conv = comp->AddInstruction(
        conv->CloneWithNewOperands(conv->shape(), new_operands));
    comp->parent()->SetAndUniquifyInstrName(new_conv, conv->name());
    TF_RETURN_IF_ERROR(new_conv->set_backend_config(config));
    TF_ASSIGN_OR_RETURN(HloInstruction * new_instr,
                        MakeGetTupleElementHlo(new_conv, 0));
    TF_RETURN_IF_ERROR(comp->ReplaceInstruction(instr, new_instr));
    changed = true;
  }
  return changed;
}

// custom-call(..., alpha * side_input) ->
// custom-call(..., side_input, backend_config={alpha}).
//
// We also have to support the more complicated case of
//
//   custom-call(..., reshape(side_input * alpha)) -->
//   custom-call(..., reshape(side_input), backend_config={alpha}),
//
// where `reshape` can be an arbitrary chain of reshapes+transposes. This idiom
// is created by the ReshapeMover pass.
StatusOr<bool> FuseSideInputAlpha(HloComputation* comp) {
  bool changed = false;
  for (HloInstruction* instr : comp->MakeInstructionPostOrder()) {
    HloInstruction* conv;
    HloInstruction* side_input;
    auto pattern = m::Op(&conv)
                       .WithPredicate(IsConvCustomCall)
                       .WithOperand(3, m::Op(&side_input));
    if (!Match(instr, pattern)) {
      continue;
    }
    TF_ASSIGN_OR_RETURN(auto config,
                        conv->backend_config<CudnnConvBackendConfig>());
    if (config.side_input_scale() != 1) {
      continue;
    }

    // Given side_input, pattern match the following (working from bottom up).
    //
    // before_reshape = multiply(base, broadcast(alpha))
    // side_input = chain_of_reshapes_and_transposes(before_reshape)
    //
    // where alpha is a scalar constant.
    //
    // alpha is f32 except for f64 convs, where it's f64. See
    // https://docs.nvidia.com/deeplearning/cudnn/api/index.html#cudnnConvolutionBiasActivationForward
    HloInstruction* before_reshape = side_input;
    while (before_reshape->opcode() == HloOpcode::kReshape ||
           before_reshape->opcode() == HloOpcode::kTranspose) {
      before_reshape = before_reshape->mutable_operand(0);
    }

    PrimitiveType conv_ty = conv->shape().tuple_shapes(0).element_type();
    PrimitiveType alpha_ty = conv_ty == F64 ? F64 : F32;
    HloInstruction* base;
    HloInstruction* alpha;
    if (!Match(
            before_reshape,
            m::MultiplyAnyOrder(
                m::Op(&base),
                m::Broadcast(m::ConstantEffectiveScalar(&alpha).WithPredicate(
                    [&](const HloInstruction* instr) {
                      return IsLosslesslyConvertibleTo(instr, alpha_ty);
                    }))))) {
      continue;
    }
    if (!ConsumeFuel("cudnn-fused-convolution-rewriter", [&] {
          return absl::StrCat("FuseSideInputAlpha: ", conv->ToString());
        })) {
      continue;
    }

    // Rewrite conv's operand 3 to
    //
    //   chain_of_reshapes_and_transposes(before_reshape),
    //
    // and store alpha in the conv's backend config.
    //
    // We're going to do something bad here: We aren't going to check that the
    // chain of reshapes/transposes has one use, so we're potentially
    // duplicating all these instructions (once with alpha and once without).
    //
    // This is justified because
    //
    // - duplicating reshapes/transposes shouldn't be "that bad" -- these
    //   instructions can usually be fused, and
    //
    // - *not* fusing alpha can be catastrophic. For s8->s8 convolutions, the
    //   side-input must be s8. But the product side_input * alpha is f32, so
    //   we can only see that side-input is s8 if we fuse alpha. IOW not fusing
    //   alpha means we'll run this s8->s8 conv as s8->f32, which is *much*
    //   slower than some extra transposes.

    // Recursively clone chain_of_reshapes_and_transposes until we get to
    // `before_reshape`, at which point we skip the multiply(base, alpha) and
    // just return base.
    std::function<HloInstruction*(const HloInstruction*)> clone =
        [&](const HloInstruction* instr) {
          if (instr == before_reshape) {
            return base;
          }
          CHECK(instr->opcode() == HloOpcode::kReshape ||
                instr->opcode() == HloOpcode::kTranspose)
              << "Must be reshape or transpose: " << instr->ToString();
          return comp->AddInstruction(instr->CloneWithNewOperands(
              instr->shape(), {clone(instr->operand(0))}));
        };
    absl::InlinedVector<HloInstruction*, 4> new_operands(
        conv->operands().begin(), conv->operands().end());
    new_operands[3] = clone(side_input);

    HloInstruction* new_conv = comp->AddInstruction(
        conv->CloneWithNewOperands(conv->shape(), new_operands));
    comp->parent()->SetAndUniquifyInstrName(new_conv, conv->name());

    TF_ASSIGN_OR_RETURN(Literal alpha_f64, alpha->literal().Convert(F64));
    config.set_side_input_scale(alpha_f64.GetFirstElement<double>());
    TF_RETURN_IF_ERROR(new_conv->set_backend_config(config));

    TF_RETURN_IF_ERROR(comp->ReplaceInstruction(conv, new_conv));
    changed = true;
  }
  return changed;
}

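// max(0, gte(custom-call(...))) ->
// gte(custom-call(..., backend_config={activation_mode=kRelu})).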
StatusOr<bool> FuseRelu(HloComputation* comp) {
  bool changed = false;
  for (HloInstruction* instr : comp->MakeInstructionPostOrder()) {
    HloInstruction* gte;
    HloInstruction* conv;
    if (!Match(
            instr,
            m::MaximumAnyOrder(
                m::Broadcast(m::ConstantEffectiveScalar(0)),
                m::GetTupleElement(
                    &gte,
                    m::Op(&conv).WithPredicate(IsConvCustomCall).WithOneUse())
                    .WithOneUse()))) {
      continue;
    }
    TF_ASSIGN_OR_RETURN(CudnnConvBackendConfig config,
                        conv->backend_config<CudnnConvBackendConfig>());
    if (config.activation_mode() != se::dnn::kNone) {
      continue;
    }

    if (!ConsumeFuel("cudnn-fused-convolution-rewriter", [&] {
          return absl::StrCat("FuseRelu: ", conv->ToString());
        })) {
      continue;
    }
    TF_ASSIGN_OR_RETURN(conv, EnsureIsConvBiasActivation(conv));
    config.set_activation_mode(se::dnn::kRelu);
    TF_RETURN_IF_ERROR(conv->set_backend_config(config));
    TF_RETURN_IF_ERROR(comp->ReplaceInstruction(instr, gte));
    changed = true;
  }
  return changed;
}

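// convert<f16>(gte(custom-call<f32>(...))) ->
// gte(custom-call<f16>(...)), provided that all of the conv's operands are
// losslessly convertible to f16.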
StatusOr<bool> FuseConvertToF16(HloComputation* comp) {
  bool changed = false;
  for (HloInstruction* instr : comp->MakeInstructionPostOrder()) {
    HloInstruction* gte = nullptr;
    HloInstruction* conv = nullptr;

    auto f32_convertible_to_f16_pat =
        m::Op().WithElementType(F32).WithPredicate(
            IsLosslesslyConvertibleToF16);
    if (!MatchAndLogIfFailed(
            instr, "f16 conv",
            m::Convert(
                m::GetTupleElement(
                    &gte,
                    m::Op(&conv)
                        .WithPredicate(IsConvCustomCall)
                        .WithOperand(0, f32_convertible_to_f16_pat)
                        .WithOperand(1, f32_convertible_to_f16_pat)
                        .WithOperandIfPresent(2, f32_convertible_to_f16_pat)
                        .WithOperandIfPresent(3, f32_convertible_to_f16_pat),
                    0)
                    .WithOneUse())
                .WithElementType(F16),
            VLOG_IS_ON(3),
            m::Op().WithOperand(0, m::GetTupleElement(m::Op().WithPredicate(
                                       IsConvCustomCall))))) {
      continue;
    }
    if (!ConsumeFuel("cudnn-fused-convolution-rewriter", [&] {
          return absl::StrCat("FuseConvertToF16: ", conv->ToString());
        })) {
      continue;
    }

    VLOG(2) << "Matched fp16 conv: " << conv->ToString();

    // In fp16 convs, all operands, including `bias`, must be fp16. This is
    // different from int8 convs, where the bias is fp32. See table of
    // supported datatypes at
    // https://docs.nvidia.com/deeplearning/cudnn/api/index.html#cudnnConvolutionBiasActivationForward
    absl::InlinedVector<HloInstruction*, 4> new_operands;
    for (HloInstruction* operand : conv->operands()) {
      new_operands.push_back(
          MakeConvertToHlo(operand, F16, &operand->metadata()));
    }

    Shape new_shape = conv->shape();
    new_shape.mutable_tuple_shapes(0)->set_element_type(F16);

    HloInstruction* new_conv = comp->AddInstruction(
        conv->CloneWithNewOperands(new_shape, new_operands));
    comp->parent()->SetAndUniquifyInstrName(new_conv, conv->name());
    TF_ASSIGN_OR_RETURN(HloInstruction * new_instr,
                        MakeGetTupleElementHlo(new_conv, 0));
    TF_RETURN_IF_ERROR(comp->ReplaceInstruction(instr, new_instr));
    changed = true;
  }
  return changed;
}

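// Rewrites a conv whose input and filter are losslessly convertible to s8 into
// an s8 conv. Matches two idioms:
//
//   convert<s8>(clamp(-128, gte(custom-call(...)), 127))   // s8 output
//   gte(custom-call(...)) with f32 element type            // f32 output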
StatusOr<bool> FuseConvertToS8(HloComputation* comp) {
  bool changed = false;
  for (HloInstruction* instr : comp->MakeInstructionPostOrder()) {
    HloInstruction* gte = nullptr;
    HloInstruction* conv = nullptr;

    auto conv_pattern =
        m::Op(&conv)
            .WithPredicate(IsConvCustomCall)
            .WithOperand(0, m::Op().WithPredicate(IsLosslesslyConvertibleToS8))
            .WithOperand(1, m::Op().WithPredicate(IsLosslesslyConvertibleToS8));

    PrimitiveType conv_output_ty;
    if (MatchAndLogIfFailed(
            instr, "s8->s8 conv",
            m::Convert(m::Clamp(m::Broadcast(m::ConstantEffectiveScalar(-128)),
                                m::GetTupleElement(
                                    &gte,
                                    conv_pattern.WithOperandIfPresent(
                                        3, m::Op().WithPredicate(
                                               IsLosslesslyConvertibleToS8)),
                                    0)
                                    .WithOneUse(),
                                m::Broadcast(m::ConstantEffectiveScalar(127))))
                .WithElementType(S8),
            VLOG_IS_ON(3),
            m::Convert(m::Clamp(m::Op(),
                                m::GetTupleElement(
                                    m::Op().WithPredicate(IsConvCustomCall)),
                                m::Op()))
                .WithElementType(S8))) {
      conv_output_ty = S8;
    } else if (MatchAndLogIfFailed(
                   instr, "s8->f32 conv",
                   m::GetTupleElement(&gte,
                                      conv_pattern.WithOperandIfPresent(
                                          3, m::Op().WithElementType(F32)),
                                      0)
                       .WithElementType(F32),
                   VLOG_IS_ON(3),
                   m::GetTupleElement(m::Op().WithPredicate(IsConvCustomCall))
                       .WithElementType(F32))) {
      conv_output_ty = F32;
    } else {
      continue;
    }
    if (!ConsumeFuel("cudnn-fused-convolution-rewriter", [&] {
          return absl::StrCat("FuseConvertToS8: ", conv->ToString());
        })) {
      continue;
    }

    absl::InlinedVector<HloInstruction*, 4> new_operands(
        conv->operands().begin(), conv->operands().end());
    new_operands[0] =
        MakeConvertToHlo(new_operands[0], S8, &new_operands[0]->metadata());
    new_operands[1] =
        MakeConvertToHlo(new_operands[1], S8, &new_operands[1]->metadata());
    // Don't convert bias (operand 2); it's always f32 for s8 ops in cudnn. See
    // https://docs.nvidia.com/deeplearning/cudnn/api/index.html#cudnnConvolutionBiasActivationForward
    if (new_operands.size() >= 4) {
      // side-input always matches conv output type. We checked in the patterns
      // above that it's losslessly-convertible to this type.
      new_operands[3] = MakeConvertToHlo(new_operands[3], conv_output_ty,
                                         &new_operands[3]->metadata());
    }

    Shape new_shape = conv->shape();
    new_shape.mutable_tuple_shapes(0)->set_element_type(conv_output_ty);

    HloInstruction* new_conv = comp->AddInstruction(
        conv->CloneWithNewOperands(new_shape, new_operands));
    comp->parent()->SetAndUniquifyInstrName(new_conv, conv->name());
    TF_ASSIGN_OR_RETURN(HloInstruction * new_instr,
                        MakeGetTupleElementHlo(new_conv, 0));
    TF_RETURN_IF_ERROR(comp->ReplaceInstruction(instr, new_instr));
    changed = true;
  }
  return changed;
}

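// Returns an error if `comp` still contains a conv custom-call whose input,
// filter, side-input, or output has an integer type other than s8; cudnn
// can't run such convolutions.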
Status CheckNoIllegalIntegerConvs(HloComputation* comp) {
  auto is_integral_not_s8 = [](const Shape& s) {
    return primitive_util::IsIntegralType(s.element_type()) &&
           s.element_type() != S8;
  };

  std::vector<HloInstruction*> bad_convs;
  for (HloInstruction* instr : comp->instructions()) {
    if (!IsConvCustomCall(instr)) {
      continue;
    }
    if (is_integral_not_s8(instr->shape().tuple_shapes(0)) ||
        is_integral_not_s8(instr->operand(0)->shape()) ||
        is_integral_not_s8(instr->operand(1)->shape()) ||
        (instr->operand_count() >= 4 &&
         is_integral_not_s8(instr->operand(3)->shape()))) {
      bad_convs.push_back(instr);
    }
  }

  if (bad_convs.empty()) {
    return OkStatus();
  }

  return Unimplemented(
      R"(
Can't lower one or more integer convolutions to idioms supported by CuDNN.

CuDNN integer convolutions must have:

  - s8 input and filter,
  - f32 bias (if present),
  - s8 or f32 output, and
  - s8 side_input (if present) if output is s8.

For each of the unsupported convs below, we weren't able to lower one of the
operands or the output to the appropriate type.

See specific HLO idioms in cudnn_fused_conv_rewriter.h, and see cudnn semantics:

https://docs.nvidia.com/deeplearning/cudnn/api/index.html#cudnnConvolutionBiasActivationForward and
https://docs.nvidia.com/deeplearning/cudnn/developer-guide/index.html#scaling-parameters

Unsupported convs:
%s

******* Full HLO module *******
%s
)",
      absl::StrJoin(bad_convs, "\n",
                    [](std::string* out, HloInstruction* instr) {
                      absl::StrAppend(out, " - ", instr->ToString());
                    }),
      comp->parent()->ToString());
}

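// At VLOG level 1, logs a summary of the conv custom-calls in `module`:
// fused vs. non-fused, input/output types, and which convs carry a bias,
// side-input, result/side-input scale, or activation.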
void VlogStats(HloModule* module) {
  if (!VLOG_IS_ON(1)) {
    return;
  }

  VLOG(1) << "Results of CudnnFusedConvRewriter for " << module->name();
  absl::flat_hash_map<std::string, int> stats;
  for (HloComputation* comp : module->MakeNonfusionComputations()) {
    for (HloInstruction* instr : comp->instructions()) {
      if (!Match(instr, m::Op().WithPredicate(IsConvCustomCall))) {
        continue;
      }

      VLOG(3) << instr->ToString();

      if (instr->custom_call_target() == kCudnnConvForwardCallTarget) {
        stats["01 non-fused forward convs"]++;
      } else if (instr->custom_call_target() ==
                 kCudnnConvBiasActivationForwardCallTarget) {
        stats["02 fused forward convs"]++;
      }

      PrimitiveType conv_in_ty = instr->operand(0)->shape().element_type();
      PrimitiveType conv_out_ty =
          instr->shape().tuple_shapes(0).element_type();
      if (conv_in_ty == F32) {
        stats["10 f32 convs"]++;
      } else if (conv_in_ty == F16) {
        stats["11 f16 convs"]++;
      } else if (conv_in_ty == S8) {
        if (conv_out_ty == S8) {
          stats["12 s8->s8 convs"]++;
        } else if (conv_out_ty == F32) {
          stats["13 s8->f32 convs"]++;
        } else {
          LOG(ERROR) << "Unexpected conv: " << instr->ToString();
        }
      }

      if (instr->operand_count() > 2) {
        stats["20 convs with bias"]++;
        if (Match(instr->operand(2),
                  m::Broadcast(m::ConstantEffectiveScalar(0)))) {
          stats["21 convs with 0 bias"]++;
        }
      }
      if (instr->operand_count() > 3) {
        stats["22 convs with side-input"]++;
      }

      auto config = instr->backend_config<CudnnConvBackendConfig>();
      if (!config.ok()) {
        LOG(ERROR) << "Couldn't parse backend config for "
                   << instr->ToString();
        continue;
      }

      if (config->conv_result_scale() != 1) {
        stats["30 convs with result scale"]++;
      }
      if (config->side_input_scale() != 0 && config->side_input_scale() != 1) {
        stats["31 convs with side-input scale"]++;
      }
      stats[absl::StrCat(
          "32 convs with activation mode ",
          se::dnn::ActivationMode_Name(config->activation_mode()))]++;
    }
  }

  std::vector<std::pair<std::string, int>> stats_sorted(stats.begin(),
                                                        stats.end());
  absl::c_sort(stats_sorted);
  for (const auto& kv : stats_sorted) {
    VLOG(1) << absl::StreamFormat("%4d %s", kv.second,
                                  absl::string_view(kv.first).substr(3));
  }
}

}  // namespace

StatusOr<bool> CudnnFusedConvRewriter::Run(
    HloModule* module,
    const absl::flat_hash_set<absl::string_view>& execution_threads) {
  bool any_changed = false;

  for (HloComputation* comp :
       module->MakeNonfusionComputations(execution_threads)) {
    // Fuse "inside out" starting with the operations closest to the conv.
    bool changed = false;

    TF_ASSIGN_OR_RETURN(changed, FuseConvertToFloat(comp));
    any_changed |= changed;

    TF_ASSIGN_OR_RETURN(changed, FuseConvAlpha(comp));
    any_changed |= changed;

    // s8 convs' bias and side-input appear before conversion to s8.
    //
    // Run FuseBiasOrSideInput twice, so we get both the bias and the side
    // input, if both are present.
    TF_ASSIGN_OR_RETURN(changed, FuseBiasOrSideInput(comp));
    any_changed |= changed;
    TF_ASSIGN_OR_RETURN(changed, FuseBiasOrSideInput(comp));
    any_changed |= changed;
    TF_ASSIGN_OR_RETURN(changed, FuseSideInputAlpha(comp));
    any_changed |= changed;

    // Relu might appear before or after convert-to-f16/s8, so we check in both
    // cases.
    TF_ASSIGN_OR_RETURN(changed, FuseRelu(comp));
    any_changed |= changed;

    TF_ASSIGN_OR_RETURN(changed, FuseConvertToF16(comp));
    any_changed |= changed;

    TF_ASSIGN_OR_RETURN(changed, FuseConvertToS8(comp));
    any_changed |= changed;

    // f16 convs' bias+side-input can appear before or after conversion to f16.
    TF_ASSIGN_OR_RETURN(changed, FuseBiasOrSideInput(comp));
    any_changed |= changed;
    TF_ASSIGN_OR_RETURN(changed, FuseBiasOrSideInput(comp));
    any_changed |= changed;
    TF_ASSIGN_OR_RETURN(changed, FuseSideInputAlpha(comp));
    any_changed |= changed;

    TF_ASSIGN_OR_RETURN(changed, FuseRelu(comp));
    any_changed |= changed;

    // Check that we don't have any convs outputting integer types other than
    // s8. cudnn does not support these. They should have been transformed to
    // int8->int8 or int8->float above.
    TF_RETURN_IF_ERROR(CheckNoIllegalIntegerConvs(comp));
  }

  VlogStats(module);

  return any_changed;
}

}  // namespace gpu
}  // namespace xla