xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/xla/service/gpu/cudnn_fused_conv_rewriter.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/service/gpu/cudnn_fused_conv_rewriter.h"
17 
18 #include <functional>
19 #include <string>
20 
21 #include "tensorflow/compiler/xla/literal_util.h"
22 #include "tensorflow/compiler/xla/primitive_util.h"
23 #include "tensorflow/compiler/xla/service/gpu/backend_configs.pb.h"
24 #include "tensorflow/compiler/xla/service/gpu/cublas_cudnn.h"
25 #include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
26 #include "tensorflow/compiler/xla/service/hlo_creation_utils.h"
27 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
28 #include "tensorflow/compiler/xla/service/pattern_matcher.h"
29 #include "tensorflow/compiler/xla/stream_executor/dnn.pb.h"
30 #include "tensorflow/core/platform/errors.h"
31 #include "tensorflow/core/platform/statusor.h"
32 #include "tensorflow/core/platform/stream_executor_no_cuda.h"
33 
34 namespace xla {
35 namespace gpu {
36 namespace {
37 
38 namespace m = match;
39 
IsConvCustomCall(const HloInstruction * instr)40 bool IsConvCustomCall(const HloInstruction* instr) {
41   return instr->opcode() == HloOpcode::kCustomCall &&
42          (instr->custom_call_target() == kCudnnConvForwardCallTarget ||
43           instr->custom_call_target() ==
44               kCudnnConvBiasActivationForwardCallTarget);
45 }
46 
47 // Can instr be converted to type `dst_ty` without losing any precision?  For
48 // our purposes, this is true if:
49 //
50 //  - instr already has type dst_ty, or
51 //  - instr is convert<wider type>(op_with_dst_ty), or
52 //  - instr is a constant which we can convert orig_ty -> dst_ty -> orig_ty and
53 //    get back exactly the original value, or
54 //  - instr is a broadcast, reshape, or transpose of one of the above.
IsLosslesslyConvertibleTo(const HloInstruction * instr,PrimitiveType dst_ty)55 bool IsLosslesslyConvertibleTo(const HloInstruction* instr,
56                                PrimitiveType dst_ty) {
57   if (instr->shape().element_type() == dst_ty) {
58     return true;
59   }
60 
61   if (Match(instr, m::Convert(m::Op().WithElementType(dst_ty)))) {
62     // Check that the convert from dst_ty to instr->element_type() doesn't lose
63     // precision.  Otherwise, this convert is not lossless.
64     return primitive_util::CastPreservesValues(dst_ty,
65                                                instr->shape().element_type());
66   }
67 
68   if (instr->opcode() == HloOpcode::kConstant) {
69     if (!instr->shape().IsArray()) {
70       return false;
71     }
72     // Check if instr's literal roundtrips to ty and back to its original type
73     // without modification.
74     PrimitiveType orig_ty = instr->shape().element_type();
75 
76     // The only reason Convert() should fail is if we don't support converting
77     // from x to y, which indeed means it's not losslessly-convertible.
78     StatusOr<Literal> converted1 = instr->literal().Convert(dst_ty);
79     if (!converted1.ok()) {
80       return false;
81     }
82     StatusOr<Literal> converted2 = converted1->Convert(orig_ty);
83     if (!converted2.ok()) {
84       return false;
85     }
86 
87     return instr->literal() == *converted2;
88   }
89 
90   if (instr->opcode() == HloOpcode::kBroadcast ||
91       instr->opcode() == HloOpcode::kReshape ||
92       instr->opcode() == HloOpcode::kTranspose) {
93     return IsLosslesslyConvertibleTo(instr->operand(0), dst_ty);
94   }
95 
96   return false;
97 }
98 
99 // Helpers suitable for use in m::Op().WithPredicate(...).
IsLosslesslyConvertibleToS8(const HloInstruction * instr)100 bool IsLosslesslyConvertibleToS8(const HloInstruction* instr) {
101   return IsLosslesslyConvertibleTo(instr, S8);
102 }
IsLosslesslyConvertibleToF16(const HloInstruction * instr)103 bool IsLosslesslyConvertibleToF16(const HloInstruction* instr) {
104   return IsLosslesslyConvertibleTo(instr, F16);
105 }
106 
// If `conv` is a vanilla forward conv, transforms it into a
// conv-bias-activation.  If it's already a conv-bias-activation, does nothing.
//
// If `conv` is anything else, returns an error.
StatusOr<HloInstruction*> EnsureIsConvBiasActivation(HloInstruction* conv) {
  CHECK_EQ(conv->opcode(), HloOpcode::kCustomCall);

  if (conv->custom_call_target() == kCudnnConvBiasActivationForwardCallTarget) {
    return conv;
  }

  if (conv->custom_call_target() == kCudnnConvForwardCallTarget) {
    HloComputation* comp = conv->parent();

    // The conv's result proper is element 0 of its tuple-shaped output.
    const Shape& shape = conv->shape().tuple_shapes(0);
    int64_t num_output_features = shape.dimensions(
        conv->convolution_dimension_numbers().output_feature_dimension());

    // bias for integer convs is always f32, see
    // https://docs.nvidia.com/deeplearning/cudnn/api/index.html#cudnnConvolutionBiasActivationForward
    PrimitiveType bias_ty;
    if (primitive_util::IsIntegralType(shape.element_type())) {
      bias_ty = F32;
    } else {
      bias_ty = shape.element_type();
    }
    // An all-zeros bias is a no-op placeholder; FuseBiasOrSideInput may later
    // replace it with a real bias.
    auto bias = BroadcastZeros(comp, bias_ty, {num_output_features});

    // Clone the conv with the bias appended as a new trailing operand.
    absl::InlinedVector<HloInstruction*, 3> new_operands(
        conv->operands().begin(), conv->operands().end());
    new_operands.push_back(bias);

    HloInstruction* new_conv = comp->AddInstruction(
        conv->CloneWithNewOperands(conv->shape(), new_operands));
    TF_RETURN_IF_ERROR(comp->ReplaceInstruction(conv, new_conv));
    // Retarget only after the replacement succeeded so `conv` itself is never
    // left half-converted.
    new_conv->set_custom_call_target(kCudnnConvBiasActivationForwardCallTarget);
    comp->parent()->SetAndUniquifyInstrName(new_conv,
                                            "cudnn-conv-bias-activation");
    return new_conv;
  }

  return FailedPrecondition("Unsupported conv: %s", conv->ToString());
}
150 
// convert<float>(gte(custom-call<int32>(int8_x, int8_w))) ->
// gte(custom-call<float>(int8_x, int8_w))
StatusOr<bool> FuseConvertToFloat(HloComputation* comp) {
  bool changed = false;
  for (auto instr : comp->MakeInstructionPostOrder()) {
    HloInstruction* conv = nullptr;
    auto pattern =
        m::Convert(
            m::GetTupleElement(m::Op(&conv).WithPredicate(IsConvCustomCall), 0)
                .WithElementType(S32))
            .WithElementType(F32);
    if (!Match(instr, pattern)) {
      continue;
    }
    // Fuel lets us bisect miscompiles by disabling individual rewrites.
    if (!ConsumeFuel("cudnn-fused-convolution-rewriter", [&] {
          return absl::StrCat("FuseConvertToFloat: ", conv->ToString());
        })) {
      continue;
    }

    // Re-create the conv with its result element type switched from s32 to
    // f32; the matched convert (and the old gte/conv) then become dead.
    Shape new_shape = conv->shape();
    new_shape.mutable_tuple_shapes(0)->set_element_type(F32);
    HloInstruction* new_conv =
        comp->AddInstruction(conv->CloneWithNewShape(new_shape));
    comp->parent()->SetAndUniquifyInstrName(new_conv, conv->name());
    TF_ASSIGN_OR_RETURN(HloInstruction * new_gte,
                        MakeGetTupleElementHlo(new_conv, 0));
    TF_RETURN_IF_ERROR(comp->ReplaceInstruction(instr, new_gte));

    changed = true;
  }

  return changed;
}
185 
// alpha * gte(custom-call(...)) ->
// gte(custom-call(..., backend_config={alpha})).
StatusOr<bool> FuseConvAlpha(HloComputation* comp) {
  bool changed = false;
  for (auto instr : comp->MakeInstructionPostOrder()) {
    HloInstruction* conv = nullptr;
    HloInstruction* gte = nullptr;
    HloInstruction* alpha = nullptr;
    // Match multiply(gte(conv), broadcast(scalar)) in either operand order.
    auto pattern = m::MultiplyAnyOrder(
        m::GetTupleElement(&gte, m::Op(&conv).WithPredicate(IsConvCustomCall),
                           0)
            .WithOneUse(),
        m::Broadcast(m::ConstantEffectiveScalar(&alpha)));
    if (!Match(instr, pattern)) {
      continue;
    }

    // alpha is f32 except for f64 convs, where it's f64.  See
    // https://docs.nvidia.com/deeplearning/cudnn/api/index.html#cudnnConvolutionBiasActivationForward
    PrimitiveType alpha_ty = gte->shape().element_type() == F64 ? F64 : F32;
    if (!IsLosslesslyConvertibleTo(alpha, alpha_ty)) {
      continue;
    }

    // The backend config holds a single result scale; bail if one is already
    // set to something other than the identity.
    TF_ASSIGN_OR_RETURN(auto config,
                        conv->backend_config<CudnnConvBackendConfig>());
    if (config.conv_result_scale() != 1) {
      continue;
    }
    if (!ConsumeFuel("cudnn-fused-convolution-rewriter", [&] {
          return absl::StrCat("FuseConvAlpha: ", conv->ToString());
        })) {
      continue;
    }

    // StreamExecutor doesn't support the alpha parameter on non-bias-activation
    // convs, so we have to upgrade `conv`.
    TF_ASSIGN_OR_RETURN(conv, EnsureIsConvBiasActivation(conv));

    // Store alpha as f64 in the config (the widest type we accept).
    TF_ASSIGN_OR_RETURN(Literal alpha_f64, alpha->literal().Convert(F64));
    config.set_conv_result_scale(alpha_f64.GetFirstElement<double>());

    TF_RETURN_IF_ERROR(conv->set_backend_config(config));
    TF_RETURN_IF_ERROR(conv->parent()->ReplaceInstruction(instr, gte));

    changed = true;
  }
  return changed;
}
235 
// add(gte(custom-call(...)), addend) -> gte(custom-call(..., addend)),
// fusing `addend` either into the conv's bias operand (operand 2) or as its
// side-input operand (operand 3), whichever is legal for the addend's shape.
StatusOr<bool> FuseBiasOrSideInput(HloComputation* comp) {
  bool changed = false;
  for (auto instr : comp->MakeInstructionPostOrder()) {
    HloInstruction* conv = nullptr;
    HloInstruction* gte = nullptr;
    HloInstruction* addend = nullptr;
    auto pattern = m::AddAnyOrder(
        m::GetTupleElement(
            &gte, m::Op(&conv).WithPredicate(IsConvCustomCall).WithOneUse(), 0)
            .WithOneUse(),
        m::Op(&addend));
    if (!Match(instr, pattern)) {
      continue;
    }

    // If it's a vanilla forward conv, upgrade it to a bias-activation conv.  We
    // only want to do this if the fusion will succeed, but we're guaranteed
    // that it will, because the only reason we'll bail at this point is if
    // !can_accept_bias && !can_accept_side_input, and our shiny new
    // bias-activation conv will be able to accept both.
    if (conv->custom_call_target() == kCudnnConvForwardCallTarget) {
      TF_ASSIGN_OR_RETURN(conv, EnsureIsConvBiasActivation(conv));
    }

    // Can't fuse bias or side-input if the conv already has a relu (or other
    // activation), because bias and side-input are added before the activation
    // is applied.
    TF_ASSIGN_OR_RETURN(auto config,
                        conv->backend_config<CudnnConvBackendConfig>());
    if (config.activation_mode() != se::dnn::kNone) {
      continue;
    }

    // Does `conv` already have a (nonzero) bias?  Does it already have a
    // side_input?  (Operand 2 is the bias; the all-zeros broadcast is the
    // placeholder installed by EnsureIsConvBiasActivation.)
    bool can_accept_bias =
        Match(conv->operand(2), m::Broadcast(m::ConstantEffectiveScalar(0)));
    bool can_accept_side_input = conv->operand_count() < 4;

    // The addend can be fused as a bias if
    //  - it is 1D broadcasted in the output feature dimension, and
    //  - it is losslessly-convertible to the correct type (f32 for s8/f32/u32
    //    convs, and conv_ty for floating-point convs)
    PrimitiveType conv_ty = gte->shape().element_type();
    PrimitiveType bias_ty =
        primitive_util::IsFloatingPointType(conv_ty) ? conv_ty : F32;
    bool addend_may_be_rank1_bias =
        addend->opcode() == HloOpcode::kBroadcast &&
        addend->dimensions().size() == 1 &&
        addend->dimensions(0) ==
            conv->convolution_dimension_numbers().output_feature_dimension() &&
        IsLosslesslyConvertibleTo(addend, bias_ty);

    // A broadcast scalar also works as a bias once re-broadcast to 1D over the
    // output-feature dimension.
    bool addend_may_be_rank0_bias = addend->opcode() == HloOpcode::kBroadcast &&
                                    addend->dimensions().empty() &&
                                    IsLosslesslyConvertibleTo(addend, bias_ty);

    absl::InlinedVector<HloInstruction*, 4> new_operands(
        conv->operands().begin(), conv->operands().end());
    if (can_accept_bias && addend_may_be_rank1_bias) {
      // Fuse the pre-broadcast 1D value directly as the bias operand.
      new_operands[2] = MakeConvertToHlo(addend->mutable_operand(0), bias_ty,
                                         &addend->operand(0)->metadata());
    } else if (can_accept_bias && addend_may_be_rank0_bias) {
      // Fuse the scalar, re-broadcast to [num_output_features].
      new_operands[2] = MakeBroadcastHlo(
          MakeConvertToHlo(addend->mutable_operand(0), bias_ty,
                           &addend->operand(0)->metadata()),
          /*broadcast_dimensions=*/{},
          /*result_shape_bounds=*/
          {gte->shape().dimensions(conv->convolution_dimension_numbers()
                                       .output_feature_dimension())});
    } else if (can_accept_side_input) {
      // Append the addend as the side-input operand with identity scale.
      CHECK_EQ(new_operands.size(), 3);
      new_operands.push_back(addend);
      config.set_side_input_scale(1);
    } else {
      // Can't fuse; this op already has a bias and a side-input.
      continue;
    }

    if (!ConsumeFuel("cudnn-fused-convolution-rewriter", [&] {
          return absl::StrCat("FuseBiasOrSideInput: ", conv->ToString());
        })) {
      continue;
    }

    HloInstruction* new_conv = comp->AddInstruction(
        conv->CloneWithNewOperands(conv->shape(), new_operands));
    comp->parent()->SetAndUniquifyInstrName(new_conv, conv->name());
    TF_RETURN_IF_ERROR(new_conv->set_backend_config(config));
    TF_ASSIGN_OR_RETURN(HloInstruction * new_instr,
                        MakeGetTupleElementHlo(new_conv, 0));
    TF_RETURN_IF_ERROR(comp->ReplaceInstruction(instr, new_instr));
    changed = true;
  }
  return changed;
}
332 
// custom-call(..., alpha * side_input) ->
// custom-call(..., side_input, backend_config={alpha}).
//
// We also have to support the more complicated case of
//
//   custom-call(..., reshape(side_input * alpha)) -->
//   custom-call(..., reshape(side_input), backend_config={alpha}),
//
// where `reshape` can be an arbitrary chain of reshapes+transposes.  This idiom
// is created by the ReshapeMover pass.
StatusOr<bool> FuseSideInputAlpha(HloComputation* comp) {
  bool changed = false;
  for (HloInstruction* instr : comp->MakeInstructionPostOrder()) {
    HloInstruction* conv;
    HloInstruction* side_input;
    // Only convs that already have a side input (operand 3) qualify.
    auto pattern = m::Op(&conv)
                       .WithPredicate(IsConvCustomCall)
                       .WithOperand(3, m::Op(&side_input));
    if (!Match(instr, pattern)) {
      continue;
    }
    // The config holds a single side-input scale; a scale != 1 means one was
    // already fused.
    TF_ASSIGN_OR_RETURN(auto config,
                        conv->backend_config<CudnnConvBackendConfig>());
    if (config.side_input_scale() != 1) {
      continue;
    }

    // Given side_input, pattern match the following (working from bottom up).
    //
    // before_reshape = multiply(base, broadcast(alpha))
    // side_input = chain_of_reshapes_and_transposes(before_reshape)
    //
    // where alpha is a scalar constant.
    //
    // alpha is f32 except for f64 convs, where it's f64.  See
    // https://docs.nvidia.com/deeplearning/cudnn/api/index.html#cudnnConvolutionBiasActivationForward
    HloInstruction* before_reshape = side_input;
    while (before_reshape->opcode() == HloOpcode::kReshape ||
           before_reshape->opcode() == HloOpcode::kTranspose) {
      before_reshape = before_reshape->mutable_operand(0);
    }

    PrimitiveType conv_ty = conv->shape().tuple_shapes(0).element_type();
    PrimitiveType alpha_ty = conv_ty == F64 ? F64 : F32;
    HloInstruction* base;
    HloInstruction* alpha;
    if (!Match(
            before_reshape,
            m::MultiplyAnyOrder(
                m::Op(&base),
                m::Broadcast(m::ConstantEffectiveScalar(&alpha).WithPredicate(
                    [&](const HloInstruction* instr) {
                      return IsLosslesslyConvertibleTo(instr, alpha_ty);
                    }))))) {
      continue;
    }
    if (!ConsumeFuel("cudnn-fused-convolution-rewriter", [&] {
          return absl::StrCat("FuseSideInputAlpha: ", conv->ToString());
        })) {
      continue;
    }

    // Rewrite conv's operand 3 to
    //
    //   chain_of_reshapes_and_transposes(before_reshape).
    //
    // and store alpha in the conv's backend config.
    //
    // We're going to do something bad here: We aren't going to check that the
    // chain of reshapes/transposes has one use, so we're potentially
    // duplicating all these instructions (once with alpha and once without).
    //
    // This is justified because
    //
    //  - duplicating reshapes/transposes shouldn't be "that bad" -- these
    //    instructions can usually be fused, and
    //
    //  - *not* fusing alpha can be catastrophic.  For s8->s8 convolutions, the
    //    side-input must be s8.  But the product side_input * alpha is f32, so
    //    we can only see that side-input is s8 if we fuse alpha. IOW not fusing
    //    alpha means we'll run this s8->s8 conv as s8->f32, which is *much*
    //    slower than some extra transposes.

    // Recursively clone chain_of_reshapes_and_transposes until we get to
    // `before_reshape`, at which point we skip the multiply(base, alpha) and
    // just return base.
    std::function<HloInstruction*(const HloInstruction*)> clone =
        [&](const HloInstruction* instr) {
          if (instr == before_reshape) {
            return base;
          }
          CHECK(instr->opcode() == HloOpcode::kReshape ||
                instr->opcode() == HloOpcode::kTranspose)
              << "Must be reshape or transpose: " << instr->ToString();
          return comp->AddInstruction(instr->CloneWithNewOperands(
              instr->shape(), {clone(instr->operand(0))}));
        };
    absl::InlinedVector<HloInstruction*, 4> new_operands(
        conv->operands().begin(), conv->operands().end());
    new_operands[3] = clone(side_input);

    HloInstruction* new_conv = comp->AddInstruction(
        conv->CloneWithNewOperands(conv->shape(), new_operands));
    comp->parent()->SetAndUniquifyInstrName(new_conv, conv->name());

    // Store alpha as f64 in the config (the widest type we accept).
    TF_ASSIGN_OR_RETURN(Literal alpha_f64, alpha->literal().Convert(F64));
    config.set_side_input_scale(alpha_f64.GetFirstElement<double>());
    TF_RETURN_IF_ERROR(new_conv->set_backend_config(config));

    TF_RETURN_IF_ERROR(comp->ReplaceInstruction(conv, new_conv));
    changed = true;
  }
  return changed;
}
447 
// max(0, gte(custom-call(...))) ->
// gte(custom-call(..., backend_config={activation_mode=kRelu})).
StatusOr<bool> FuseRelu(HloComputation* comp) {
  bool changed = false;
  for (HloInstruction* instr : comp->MakeInstructionPostOrder()) {
    HloInstruction* gte;
    HloInstruction* conv;
    if (!Match(
            instr,
            m::MaximumAnyOrder(
                m::Broadcast(m::ConstantEffectiveScalar(0)),
                m::GetTupleElement(
                    &gte,
                    m::Op(&conv).WithPredicate(IsConvCustomCall).WithOneUse())
                    .WithOneUse()))) {
      continue;
    }
    // Only one activation fits in the config; skip convs that already have one.
    TF_ASSIGN_OR_RETURN(CudnnConvBackendConfig config,
                        conv->backend_config<CudnnConvBackendConfig>());
    if (config.activation_mode() != se::dnn::kNone) {
      continue;
    }

    if (!ConsumeFuel("cudnn-fused-convolution-rewriter", [&] {
          return absl::StrCat("FuseRelu: ", conv->ToString());
        })) {
      continue;
    }
    // Activations are only expressible on the bias-activation custom-call, so
    // upgrade the conv if needed.
    TF_ASSIGN_OR_RETURN(conv, EnsureIsConvBiasActivation(conv));
    config.set_activation_mode(se::dnn::kRelu);
    TF_RETURN_IF_ERROR(conv->set_backend_config(config));
    TF_RETURN_IF_ERROR(comp->ReplaceInstruction(instr, gte));
    changed = true;
  }
  return changed;
}
482 
// convert<f16>(gte(custom-call<f32>(...))) -> gte(custom-call<f16>(...)),
// provided every operand is losslessly convertible to f16.
StatusOr<bool> FuseConvertToF16(HloComputation* comp) {
  bool changed = false;
  for (HloInstruction* instr : comp->MakeInstructionPostOrder()) {
    HloInstruction* gte = nullptr;
    HloInstruction* conv = nullptr;

    auto f32_convertible_to_f16_pat =
        m::Op().WithElementType(F32).WithPredicate(
            IsLosslesslyConvertibleToF16);
    // NOTE(review): MatchAndLogIfFailed is defined elsewhere in this file;
    // presumably it logs (when VLOG(3) is on) cases where the loose fallback
    // pattern matches but the strict one doesn't — confirm at its definition.
    if (!MatchAndLogIfFailed(
            instr, "f16 conv",
            m::Convert(
                m::GetTupleElement(
                    &gte,
                    m::Op(&conv)
                        .WithPredicate(IsConvCustomCall)
                        .WithOperand(0, f32_convertible_to_f16_pat)
                        .WithOperand(1, f32_convertible_to_f16_pat)
                        .WithOperandIfPresent(2, f32_convertible_to_f16_pat)
                        .WithOperandIfPresent(3, f32_convertible_to_f16_pat),
                    0)
                    .WithOneUse())
                .WithElementType(F16),
            VLOG_IS_ON(3),
            m::Op().WithOperand(0, m::GetTupleElement(m::Op().WithPredicate(
                                       IsConvCustomCall))))) {
      continue;
    }
    if (!ConsumeFuel("cudnn-fused-convolution-rewriter", [&] {
          return absl::StrCat("FuseConvertToF16: ", conv->ToString());
        })) {
      continue;
    }

    VLOG(2) << "Matched fp16 conv: " << conv->ToString();

    // In fp16 convs, all operands, including `bias`, must be fp16.  This is
    // different from int8 convs, where the bias is fp32.  See table of
    // supported datatypes at
    // https://docs.nvidia.com/deeplearning/cudnn/api/index.html#cudnnConvolutionBiasActivationForward
    absl::InlinedVector<HloInstruction*, 4> new_operands;
    for (HloInstruction* operand : conv->operands()) {
      new_operands.push_back(
          MakeConvertToHlo(operand, F16, &operand->metadata()));
    }

    // Clone the conv with f16 operands and an f16 result element type.
    Shape new_shape = conv->shape();
    new_shape.mutable_tuple_shapes(0)->set_element_type(F16);

    HloInstruction* new_conv = comp->AddInstruction(
        conv->CloneWithNewOperands(new_shape, new_operands));
    comp->parent()->SetAndUniquifyInstrName(new_conv, conv->name());
    TF_ASSIGN_OR_RETURN(HloInstruction * new_instr,
                        MakeGetTupleElementHlo(new_conv, 0));
    TF_RETURN_IF_ERROR(comp->ReplaceInstruction(instr, new_instr));
    changed = true;
  }
  return changed;
}
542 
// Rewrites convs with s8-convertible input/filter into s8->s8 or s8->f32
// cudnn convs:
//   convert<s8>(clamp(-128, gte(conv), 127)) -> gte(custom-call<s8>(...))
//   gte(conv).WithElementType(F32)           -> gte(custom-call<f32>(...))
StatusOr<bool> FuseConvertToS8(HloComputation* comp) {
  bool changed = false;
  for (HloInstruction* instr : comp->MakeInstructionPostOrder()) {
    HloInstruction* gte = nullptr;
    HloInstruction* conv = nullptr;

    auto conv_pattern =
        m::Op(&conv)
            .WithPredicate(IsConvCustomCall)
            .WithOperand(0, m::Op().WithPredicate(IsLosslesslyConvertibleToS8))
            .WithOperand(1, m::Op().WithPredicate(IsLosslesslyConvertibleToS8));

    PrimitiveType conv_output_ty;
    // s8 output: the clamp to [-128, 127] followed by convert<s8> is the HLO
    // idiom for saturating int8 conversion.
    if (MatchAndLogIfFailed(
            instr, "s8->s8 conv",
            m::Convert(m::Clamp(m::Broadcast(m::ConstantEffectiveScalar(-128)),
                                m::GetTupleElement(
                                    &gte,
                                    conv_pattern.WithOperandIfPresent(
                                        3, m::Op().WithPredicate(
                                               IsLosslesslyConvertibleToS8)),
                                    0)
                                    .WithOneUse(),
                                m::Broadcast(m::ConstantEffectiveScalar(127))))
                .WithElementType(S8),
            VLOG_IS_ON(3),
            m::Convert(m::Clamp(m::Op(),
                                m::GetTupleElement(
                                    m::Op().WithPredicate(IsConvCustomCall)),
                                m::Op()))
                .WithElementType(S8))) {
      conv_output_ty = S8;
    } else if (MatchAndLogIfFailed(
                   instr, "s8->f32 conv",
                   m::GetTupleElement(&gte,
                                      conv_pattern.WithOperandIfPresent(
                                          3, m::Op().WithElementType(F32)),
                                      0)
                       .WithElementType(F32),
                   VLOG_IS_ON(3),
                   m::GetTupleElement(m::Op().WithPredicate(IsConvCustomCall))
                       .WithElementType(F32))) {
      conv_output_ty = F32;
    } else {
      continue;
    }
    if (!ConsumeFuel("cudnn-fused-convolution-rewriter", [&] {
          return absl::StrCat("FuseConvertToS8: ", conv->ToString());
        })) {
      continue;
    }

    // Convert input (operand 0) and filter (operand 1) to s8.
    absl::InlinedVector<HloInstruction*, 4> new_operands(
        conv->operands().begin(), conv->operands().end());
    new_operands[0] =
        MakeConvertToHlo(new_operands[0], S8, &new_operands[0]->metadata());
    new_operands[1] =
        MakeConvertToHlo(new_operands[1], S8, &new_operands[1]->metadata());
    // Don't convert bias (operand 2); it's always f32 for s8 ops in cudnn.  See
    // https://docs.nvidia.com/deeplearning/cudnn/api/index.html#cudnnConvolutionBiasActivationForward
    if (new_operands.size() >= 4) {
      // side-input always matches conv output type.  We checked in the patterns
      // above that it's losslessly-convertible to this type.
      new_operands[3] = MakeConvertToHlo(new_operands[3], conv_output_ty,
                                         &new_operands[3]->metadata());
    }

    Shape new_shape = conv->shape();
    new_shape.mutable_tuple_shapes(0)->set_element_type(conv_output_ty);

    HloInstruction* new_conv = comp->AddInstruction(
        conv->CloneWithNewOperands(new_shape, new_operands));
    comp->parent()->SetAndUniquifyInstrName(new_conv, conv->name());
    TF_ASSIGN_OR_RETURN(HloInstruction * new_instr,
                        MakeGetTupleElementHlo(new_conv, 0));
    TF_RETURN_IF_ERROR(comp->ReplaceInstruction(instr, new_instr));
    changed = true;
  }
  return changed;
}
623 
// Returns Unimplemented if `comp` still contains a cudnn conv custom-call
// with an integer operand or output whose type is not s8 — the fusion passes
// above should have lowered all supported integer convs to s8 idioms.
Status CheckNoIllegalIntegerConvs(HloComputation* comp) {
  auto is_integral_not_s8 = [](const Shape& s) {
    return primitive_util::IsIntegralType(s.element_type()) &&
           s.element_type() != S8;
  };

  std::vector<HloInstruction*> bad_convs;
  for (HloInstruction* instr : comp->instructions()) {
    if (!IsConvCustomCall(instr)) {
      continue;
    }
    // Check the output (tuple element 0), input (operand 0), filter
    // (operand 1), and the side-input (operand 3) if present.  The bias
    // (operand 2) is legitimately f32 and so needs no check here.
    if (is_integral_not_s8(instr->shape().tuple_shapes(0)) ||
        is_integral_not_s8(instr->operand(0)->shape()) ||
        is_integral_not_s8(instr->operand(1)->shape()) ||
        (instr->operand_count() >= 4 &&
         is_integral_not_s8(instr->operand(3)->shape()))) {
      bad_convs.push_back(instr);
    }
  }

  if (bad_convs.empty()) {
    return OkStatus();
  }

  return Unimplemented(
      R"(
Can't lower one or more integer convolutions to idioms supported by CuDNN.

CuDNN integer convolutions must have:

  - s8 input and filter,
  - f32 bias (if present),
  - s8 or f32 output, and
  - s8 side_input (if present) if output is s8.

For each of the unsupported convs below, we weren't able to lower one of the
operands or the output to the appropriate type.

See specific HLO idioms in cudnn_fused_conv_rewriter.h, and see cudnn semantics:

https://docs.nvidia.com/deeplearning/cudnn/api/index.html#cudnnConvolutionBiasActivationForward and
https://docs.nvidia.com/deeplearning/cudnn/developer-guide/index.html#scaling-parameters

Unsupported convs:
%s

******* Full HLO module *******
%s
)",
      absl::StrJoin(bad_convs, "\n",
                    [](std::string* out, HloInstruction* instr) {
                      absl::StrAppend(out, " - ", instr->ToString());
                    }),
      comp->parent()->ToString());
}
679 
// Logs (at VLOG(1)) a histogram of the kinds of cudnn convs present in
// `module` after rewriting; a no-op when VLOG(1) is off.
void VlogStats(HloModule* module) {
  if (!VLOG_IS_ON(1)) {
    return;
  }

  VLOG(1) << "Results of CudnnFusedConvRewriter for " << module->name();
  // Keys carry a "NN " numeric prefix that fixes the sort order below; the
  // prefix is stripped (substr(3)) before printing.
  absl::flat_hash_map<std::string, int> stats;
  for (HloComputation* comp : module->MakeNonfusionComputations()) {
    for (HloInstruction* instr : comp->instructions()) {
      if (!Match(instr, m::Op().WithPredicate(IsConvCustomCall))) {
        continue;
      }

      VLOG(3) << instr->ToString();

      if (instr->custom_call_target() == kCudnnConvForwardCallTarget) {
        stats["01 non-fused forward convs"]++;
      } else if (instr->custom_call_target() ==
                 kCudnnConvBiasActivationForwardCallTarget) {
        stats["02 fused forward convs"]++;
      }

      PrimitiveType conv_in_ty = instr->operand(0)->shape().element_type();
      PrimitiveType conv_out_ty = instr->shape().tuple_shapes(0).element_type();
      if (conv_in_ty == F32) {
        stats["10 f32 convs"]++;
      } else if (conv_in_ty == F16) {
        stats["11 f16 convs"]++;
      } else if (conv_in_ty == S8) {
        if (conv_out_ty == S8) {
          stats["12 s8->s8 convs"]++;
        } else if (conv_out_ty == F32) {
          stats["13 s8->f32 convs"]++;
        } else {
          // CheckNoIllegalIntegerConvs should have rejected anything else.
          LOG(ERROR) << "Unexpected conv: " << instr->ToString();
        }
      }

      if (instr->operand_count() > 2) {
        stats["20 convs with bias"]++;
        // An all-zeros bias is the placeholder added when upgrading a conv.
        if (Match(instr->operand(2),
                  m::Broadcast(m::ConstantEffectiveScalar(0)))) {
          stats["21 convs with 0 bias"]++;
        }
      }
      if (instr->operand_count() > 3) {
        stats["22 convs with side-input"]++;
      }

      auto config = instr->backend_config<CudnnConvBackendConfig>();
      if (!config.ok()) {
        LOG(ERROR) << "Couldn't parse backend config for " << instr->ToString();
        continue;
      }

      if (config->conv_result_scale() != 1) {
        stats["30 convs with result scale"]++;
      }
      if (config->side_input_scale() != 0 && config->side_input_scale() != 1) {
        stats["31 convs with side-input scale"]++;
      }
      stats[absl::StrCat(
          "32 convs with activation mode ",
          se::dnn::ActivationMode_Name(config->activation_mode()))]++;
    }
  }

  // Print in stable, prefix-determined order, dropping the "NN " prefix.
  std::vector<std::pair<std::string, int>> stats_sorted(stats.begin(),
                                                        stats.end());
  absl::c_sort(stats_sorted);
  for (const auto& kv : stats_sorted) {
    VLOG(1) << absl::StreamFormat("%4d %s", kv.second,
                                  absl::string_view(kv.first).substr(3));
  }
}
755 
756 }  // namespace
757 
// Pass entry point: runs each fusion sub-pass over every non-fusion
// computation, then verifies no unsupported integer convs remain.  Returns
// true iff any computation was changed.
StatusOr<bool> CudnnFusedConvRewriter::Run(
    HloModule* module,
    const absl::flat_hash_set<absl::string_view>& execution_threads) {
  bool any_changed = false;

  for (HloComputation* comp :
       module->MakeNonfusionComputations(execution_threads)) {
    // Fuse "inside out" starting with the operations closest to the conv.
    bool changed = false;

    TF_ASSIGN_OR_RETURN(changed, FuseConvertToFloat(comp));
    any_changed |= changed;

    TF_ASSIGN_OR_RETURN(changed, FuseConvAlpha(comp));
    any_changed |= changed;

    // s8 convs' bias and side-input appear before conversion to s8.
    //
    // Run FuseBiasOrSideInput twice, so we get both the bias and the side
    // input, if both are present.
    TF_ASSIGN_OR_RETURN(changed, FuseBiasOrSideInput(comp));
    any_changed |= changed;
    TF_ASSIGN_OR_RETURN(changed, FuseBiasOrSideInput(comp));
    any_changed |= changed;
    TF_ASSIGN_OR_RETURN(changed, FuseSideInputAlpha(comp));
    any_changed |= changed;

    // Relu might appear before or after convert-to-f16/s8, so we check in both
    // cases.
    TF_ASSIGN_OR_RETURN(changed, FuseRelu(comp));
    any_changed |= changed;

    TF_ASSIGN_OR_RETURN(changed, FuseConvertToF16(comp));
    any_changed |= changed;

    TF_ASSIGN_OR_RETURN(changed, FuseConvertToS8(comp));
    any_changed |= changed;

    // f16 convs' bias+side-input can appear before or after conversion to f16.
    TF_ASSIGN_OR_RETURN(changed, FuseBiasOrSideInput(comp));
    any_changed |= changed;
    TF_ASSIGN_OR_RETURN(changed, FuseBiasOrSideInput(comp));
    any_changed |= changed;
    TF_ASSIGN_OR_RETURN(changed, FuseSideInputAlpha(comp));
    any_changed |= changed;

    TF_ASSIGN_OR_RETURN(changed, FuseRelu(comp));
    any_changed |= changed;

    // Check that we don't have any convs outputing integer types other than s8.
    // cudnn does not support these.  They should have been transformed to
    // int8->int8 or int8->float above.
    TF_RETURN_IF_ERROR(CheckNoIllegalIntegerConvs(comp));
  }

  VlogStats(module);

  return any_changed;
}
817 }  // namespace gpu
818 }  // namespace xla
819