/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
15
16 #include "tensorflow/compiler/xla/service/hlo_query.h"
17
18 #include "tensorflow/compiler/xla/literal.h"
19 #include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
20 #include "tensorflow/compiler/xla/service/hlo_instructions.h"
21 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
22 #include "tensorflow/compiler/xla/shape_util.h"
23
24 namespace xla {
25 namespace hlo_query {
26
IsCollectiveCommunicationOp(HloOpcode op)27 bool IsCollectiveCommunicationOp(HloOpcode op) {
28 return op == HloOpcode::kAllReduce || op == HloOpcode::kAllGather ||
29 op == HloOpcode::kAllToAll || op == HloOpcode::kCollectivePermute ||
30 op == HloOpcode::kReduceScatter;
31 }
32
IsConstantR0F32(HloInstruction * instruction,float * out)33 bool IsConstantR0F32(HloInstruction* instruction, float* out) {
34 if (instruction->opcode() == HloOpcode::kConstant &&
35 ShapeUtil::IsScalarWithElementType(instruction->shape(), F32)) {
36 *out = instruction->literal().Get<float>({});
37 return true;
38 }
39
40 return false;
41 }
42
AllOperandsAreParametersOrConstants(const HloInstruction & instruction)43 bool AllOperandsAreParametersOrConstants(const HloInstruction& instruction) {
44 for (const auto& operand : instruction.operands()) {
45 if (operand->opcode() != HloOpcode::kParameter &&
46 operand->opcode() != HloOpcode::kConstant) {
47 return false;
48 }
49 }
50 return true;
51 }
52
AllOperandsAreParameters(const HloInstruction & instruction)53 bool AllOperandsAreParameters(const HloInstruction& instruction) {
54 for (const auto& operand : instruction.operands()) {
55 if (operand->opcode() != HloOpcode::kParameter) {
56 return false;
57 }
58 }
59 return true;
60 }
61
AllOperandsAreConstants(const HloInstruction & instruction)62 bool AllOperandsAreConstants(const HloInstruction& instruction) {
63 for (const auto& operand : instruction.operands()) {
64 if (operand->opcode() != HloOpcode::kConstant) {
65 return false;
66 }
67 }
68 return true;
69 }
70
GetMatchingOperand(const HloPredicate & matcher,HloInstruction * instruction)71 HloInstruction* GetMatchingOperand(const HloPredicate& matcher,
72 HloInstruction* instruction) {
73 for (HloInstruction* op : instruction->operands()) {
74 if (matcher(op)) {
75 return op;
76 }
77 }
78 return nullptr;
79 }
80
MatchBinaryInstructionOperand(const HloPredicate & matcher,HloInstruction * instruction,HloInstruction ** matching_operand,HloInstruction ** other_operand)81 bool MatchBinaryInstructionOperand(const HloPredicate& matcher,
82 HloInstruction* instruction,
83 HloInstruction** matching_operand,
84 HloInstruction** other_operand) {
85 CHECK_EQ(instruction->operand_count(), 2);
86 if (matcher(instruction->operand(0))) {
87 *matching_operand = instruction->mutable_operand(0);
88 *other_operand = instruction->mutable_operand(1);
89 return true;
90 }
91 if (matcher(instruction->operand(1))) {
92 *matching_operand = instruction->mutable_operand(1);
93 *other_operand = instruction->mutable_operand(0);
94 return true;
95 }
96 return false;
97 }
98
MatchBinaryInstructionOperandOpcode(HloOpcode opcode,HloInstruction * instruction,HloInstruction ** matching_operand,HloInstruction ** other_operand)99 bool MatchBinaryInstructionOperandOpcode(HloOpcode opcode,
100 HloInstruction* instruction,
101 HloInstruction** matching_operand,
102 HloInstruction** other_operand) {
103 return MatchBinaryInstructionOperand(
104 [opcode](const HloInstruction* instruction) {
105 return instruction->opcode() == opcode;
106 },
107 instruction, matching_operand, other_operand);
108 }
109
IsScalarConstant(const HloInstruction * instruction)110 bool IsScalarConstant(const HloInstruction* instruction) {
111 return instruction->IsConstant() && ShapeUtil::IsScalar(instruction->shape());
112 }
113
ContainsInstrWithOpcode(const HloComputation * comp,const absl::flat_hash_set<HloOpcode> & opcodes)114 bool ContainsInstrWithOpcode(const HloComputation* comp,
115 const absl::flat_hash_set<HloOpcode>& opcodes) {
116 for (const auto* instr : comp->instructions()) {
117 if (opcodes.count(instr->opcode())) {
118 return true;
119 }
120 for (const HloComputation* subcomp : instr->called_computations()) {
121 if (ContainsInstrWithOpcode(subcomp, opcodes)) {
122 return true;
123 }
124 }
125 }
126 return false;
127 }
128
ContainsLayoutConstrainedCollective(const HloModule & module,HloOpcode op)129 bool ContainsLayoutConstrainedCollective(const HloModule& module,
130 HloOpcode op) {
131 CHECK(IsCollectiveCommunicationOp(op));
132
133 for (auto computation : module.computations()) {
134 for (auto hlo : computation->instructions()) {
135 if (hlo->opcode() == op &&
136 DynCast<HloCollectiveInstruction>(hlo)->constrain_layout()) {
137 return true;
138 }
139 }
140 }
141 return false;
142 }
143
NextChannelId(const HloModule & module)144 int64_t NextChannelId(const HloModule& module) {
145 int64_t next_channel_id = 1;
146 for (const HloComputation* comp : module.computations()) {
147 for (const HloInstruction* hlo : comp->instructions()) {
148 const HloChannelInstruction* channel_instr =
149 DynCast<HloChannelInstruction>(hlo);
150 if (channel_instr && channel_instr->channel_id()) {
151 next_channel_id =
152 std::max(next_channel_id, *channel_instr->channel_id() + 1);
153 }
154 }
155 }
156 return next_channel_id;
157 }
158
HasX64TransformedHostTransfer(const HloModule & module)159 bool HasX64TransformedHostTransfer(const HloModule& module) {
160 for (auto computation : module.computations()) {
161 for (auto hlo : computation->instructions()) {
162 if (hlo->opcode() == HloOpcode::kSend) {
163 auto send = DynCast<HloSendInstruction>(hlo);
164 if (send->is_host_transfer() && send->operand(0)->shape().IsTuple()) {
165 return true;
166 }
167 } else if (hlo->opcode() == HloOpcode::kRecv) {
168 auto recv = DynCast<HloRecvInstruction>(hlo);
169 if (recv->is_host_transfer() &&
170 recv->shape().tuple_shapes(0).IsTuple()) {
171 return true;
172 }
173 }
174 }
175 }
176 return false;
177 }
178
179 } // namespace hlo_query
180 } // namespace xla
181