xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/xla/service/bfloat16_support.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/service/bfloat16_support.h"
17 #include "tensorflow/compiler/xla/service/hlo_computation.h"
18 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
19 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
20 
21 namespace xla {
22 
SupportsBF16Operand(const HloInstruction & hlo,int64_t operand_index) const23 bool BFloat16Support::SupportsBF16Operand(const HloInstruction& hlo,
24                                           int64_t operand_index) const {
25   switch (hlo.opcode()) {
26     case HloOpcode::kCall:
27     case HloOpcode::kConditional:
28     case HloOpcode::kCustomCall:
29     case HloOpcode::kDomain:
30     case HloOpcode::kGetTupleElement:
31     case HloOpcode::kTuple:
32     case HloOpcode::kWhile:
33     case HloOpcode::kOptimizationBarrier:
34       return true;
35     case HloOpcode::kConvert:
36       CHECK_EQ(operand_index, 0);
37       return hlo.operand(0)->shape().element_type() == BF16;
38     default:
39       break;
40   }
41   return false;
42 }
43 
SupportsBF16Output(const HloInstruction & hlo) const44 bool BFloat16Support::SupportsBF16Output(const HloInstruction& hlo) const {
45   switch (hlo.opcode()) {
46     case HloOpcode::kCall:
47     case HloOpcode::kConditional:
48     case HloOpcode::kCustomCall:
49     case HloOpcode::kDomain:
50     case HloOpcode::kGetTupleElement:
51     case HloOpcode::kTuple:
52     case HloOpcode::kWhile:
53     case HloOpcode::kOptimizationBarrier:
54       return true;
55     case HloOpcode::kConvert:
56       return hlo.shape().element_type() == BF16;
57     default:
58       break;
59   }
60   return false;
61 }
62 
SupportsMixedPrecisions(const HloInstruction & hlo) const63 bool BFloat16Support::SupportsMixedPrecisions(const HloInstruction& hlo) const {
64   switch (hlo.opcode()) {
65     case HloOpcode::kCall:
66     case HloOpcode::kConditional:
67     case HloOpcode::kConvert:
68     case HloOpcode::kCustomCall:
69     case HloOpcode::kGetTupleElement:
70     case HloOpcode::kTuple:
71     case HloOpcode::kWhile:
72     case HloOpcode::kOptimizationBarrier:
73       return true;
74     default:
75       break;
76   }
77   return false;
78 }
79 
80 /* static */
EffectiveOperandPrecisionIsOutputPrecision(const HloInstruction & hlo,int64_t operand_index)81 bool BFloat16Support::EffectiveOperandPrecisionIsOutputPrecision(
82     const HloInstruction& hlo, int64_t operand_index) {
83   switch (hlo.opcode()) {
84     case HloOpcode::kAbs:
85     case HloOpcode::kAllGather:
86     case HloOpcode::kAllToAll:
87     case HloOpcode::kBroadcast:
88     case HloOpcode::kClamp:
89     case HloOpcode::kCollectivePermute:
90     case HloOpcode::kConcatenate:
91     case HloOpcode::kConvert:
92     case HloOpcode::kCopy:
93     case HloOpcode::kDomain:
94     case HloOpcode::kGetTupleElement:
95     case HloOpcode::kMaximum:
96     case HloOpcode::kMinimum:
97     case HloOpcode::kPad:
98     case HloOpcode::kReshape:
99     case HloOpcode::kReverse:
100     case HloOpcode::kSlice:
101     case HloOpcode::kSort:
102     case HloOpcode::kTranspose:
103     case HloOpcode::kTuple:
104     case HloOpcode::kOptimizationBarrier:
105       return true;
106     case HloOpcode::kBitcast:
107       return hlo.shape().element_type() ==
108              hlo.operand(0)->shape().element_type();
109     case HloOpcode::kDynamicSlice:
110       return operand_index == 0;
111     case HloOpcode::kDynamicUpdateSlice:
112       return operand_index == 0 || operand_index == 1;
113     case HloOpcode::kGather:
114       return operand_index == 0;
115     case HloOpcode::kSelect:
116       return operand_index == 1 || operand_index == 2;
117     case HloOpcode::kReduce:
118     case HloOpcode::kReduceWindow: {
119       HloComputation* reduce_comp = hlo.called_computations()[0];
120       for (HloInstruction* inst : reduce_comp->instructions()) {
121         if (inst->opcode() == HloOpcode::kParameter) {
122           continue;
123         }
124         for (int64_t i = 0; i < inst->operand_count(); ++i) {
125           if (!EffectiveOperandPrecisionIsOutputPrecision(*inst, i)) {
126             return false;
127           }
128         }
129       }
130       return true;
131     }
132     default:
133       break;
134   }
135   return false;
136 }
137 
EffectiveOperandPrecisionIsBF16(const HloInstruction & hlo,int64_t operand_index) const138 bool BFloat16Support::EffectiveOperandPrecisionIsBF16(
139     const HloInstruction& hlo, int64_t operand_index) const {
140   return false;
141 }
142 
143 }  // namespace xla
144