/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
15
#include "tensorflow/compiler/xla/service/bfloat16_support.h"

#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
20
21 namespace xla {
22
SupportsBF16Operand(const HloInstruction & hlo,int64_t operand_index) const23 bool BFloat16Support::SupportsBF16Operand(const HloInstruction& hlo,
24 int64_t operand_index) const {
25 switch (hlo.opcode()) {
26 case HloOpcode::kCall:
27 case HloOpcode::kConditional:
28 case HloOpcode::kCustomCall:
29 case HloOpcode::kDomain:
30 case HloOpcode::kGetTupleElement:
31 case HloOpcode::kTuple:
32 case HloOpcode::kWhile:
33 case HloOpcode::kOptimizationBarrier:
34 return true;
35 case HloOpcode::kConvert:
36 CHECK_EQ(operand_index, 0);
37 return hlo.operand(0)->shape().element_type() == BF16;
38 default:
39 break;
40 }
41 return false;
42 }
43
SupportsBF16Output(const HloInstruction & hlo) const44 bool BFloat16Support::SupportsBF16Output(const HloInstruction& hlo) const {
45 switch (hlo.opcode()) {
46 case HloOpcode::kCall:
47 case HloOpcode::kConditional:
48 case HloOpcode::kCustomCall:
49 case HloOpcode::kDomain:
50 case HloOpcode::kGetTupleElement:
51 case HloOpcode::kTuple:
52 case HloOpcode::kWhile:
53 case HloOpcode::kOptimizationBarrier:
54 return true;
55 case HloOpcode::kConvert:
56 return hlo.shape().element_type() == BF16;
57 default:
58 break;
59 }
60 return false;
61 }
62
SupportsMixedPrecisions(const HloInstruction & hlo) const63 bool BFloat16Support::SupportsMixedPrecisions(const HloInstruction& hlo) const {
64 switch (hlo.opcode()) {
65 case HloOpcode::kCall:
66 case HloOpcode::kConditional:
67 case HloOpcode::kConvert:
68 case HloOpcode::kCustomCall:
69 case HloOpcode::kGetTupleElement:
70 case HloOpcode::kTuple:
71 case HloOpcode::kWhile:
72 case HloOpcode::kOptimizationBarrier:
73 return true;
74 default:
75 break;
76 }
77 return false;
78 }
79
80 /* static */
EffectiveOperandPrecisionIsOutputPrecision(const HloInstruction & hlo,int64_t operand_index)81 bool BFloat16Support::EffectiveOperandPrecisionIsOutputPrecision(
82 const HloInstruction& hlo, int64_t operand_index) {
83 switch (hlo.opcode()) {
84 case HloOpcode::kAbs:
85 case HloOpcode::kAllGather:
86 case HloOpcode::kAllToAll:
87 case HloOpcode::kBroadcast:
88 case HloOpcode::kClamp:
89 case HloOpcode::kCollectivePermute:
90 case HloOpcode::kConcatenate:
91 case HloOpcode::kConvert:
92 case HloOpcode::kCopy:
93 case HloOpcode::kDomain:
94 case HloOpcode::kGetTupleElement:
95 case HloOpcode::kMaximum:
96 case HloOpcode::kMinimum:
97 case HloOpcode::kPad:
98 case HloOpcode::kReshape:
99 case HloOpcode::kReverse:
100 case HloOpcode::kSlice:
101 case HloOpcode::kSort:
102 case HloOpcode::kTranspose:
103 case HloOpcode::kTuple:
104 case HloOpcode::kOptimizationBarrier:
105 return true;
106 case HloOpcode::kBitcast:
107 return hlo.shape().element_type() ==
108 hlo.operand(0)->shape().element_type();
109 case HloOpcode::kDynamicSlice:
110 return operand_index == 0;
111 case HloOpcode::kDynamicUpdateSlice:
112 return operand_index == 0 || operand_index == 1;
113 case HloOpcode::kGather:
114 return operand_index == 0;
115 case HloOpcode::kSelect:
116 return operand_index == 1 || operand_index == 2;
117 case HloOpcode::kReduce:
118 case HloOpcode::kReduceWindow: {
119 HloComputation* reduce_comp = hlo.called_computations()[0];
120 for (HloInstruction* inst : reduce_comp->instructions()) {
121 if (inst->opcode() == HloOpcode::kParameter) {
122 continue;
123 }
124 for (int64_t i = 0; i < inst->operand_count(); ++i) {
125 if (!EffectiveOperandPrecisionIsOutputPrecision(*inst, i)) {
126 return false;
127 }
128 }
129 }
130 return true;
131 }
132 default:
133 break;
134 }
135 return false;
136 }
137
EffectiveOperandPrecisionIsBF16(const HloInstruction & hlo,int64_t operand_index) const138 bool BFloat16Support::EffectiveOperandPrecisionIsBF16(
139 const HloInstruction& hlo, int64_t operand_index) const {
140 return false;
141 }
142
143 } // namespace xla
144