/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/topk_rewriter.h"

#include <limits>
#include <optional>

#include "absl/algorithm/container.h"
#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/pattern_matcher.h"
#include "tensorflow/compiler/xla/shape_util.h"

namespace xla {

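// Returns true if `comp` implements a NaN-safe floating-point greater-than.
// The matched comparator bitcasts each F32 operand (or BF16 operand widened
// to F32) to a signed integer and, for negative inputs, flips the ordering by
// subtracting the operand's U32 bits from INT32_MAX. This maps floats onto a
// total order in which every bit pattern, including NaNs, sorts consistently,
// which is the comparator shape TopK lowerings typically emit. As a rough
// sketch (illustrative, not verbatim HLO):
//
//   s32 = bitcast-convert(p)   // reinterpret the F32 bits as S32
//   u32 = bitcast-convert(p)   // reinterpret the F32 bits as U32
//   key = select(s32 < 0, bitcast-convert(INT32_MAX - u32), s32)
//   root = gt(key(p0), key(p1))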
static bool IsNanSafeGt(HloComputation* comp) {
  namespace m = match;
  auto match_bitcast_f32 = [](int64_t parameter_number) {
    auto param = m::Parameter(parameter_number)
                     .WithShape(m::Shape().WithElementType(F32));
    auto param_s32 =
        m::BitcastConvert(param).WithShape(m::Shape().WithElementType(S32));
    auto param_u32 =
        m::BitcastConvert(param).WithShape(m::Shape().WithElementType(U32));
    return m::Select(
        m::Lt(param_s32, m::ConstantScalar(0)),
        m::BitcastConvert(
            m::Subtract(m::ConstantScalar(std::numeric_limits<int32_t>::max()),
                        param_u32))
            .WithShape(m::Shape().WithElementType(S32)),
        param_s32);
  };

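  // The remaining matchers cover variants of the same select pattern: BF16
  // inputs widened to F32 through a convert, and forms where the INT32_MAX
  // constant appears behind an explicit convert to U32 rather than as a raw
  // scalar constant.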
  auto match_bitcast_f32_with_convert = [](int64_t parameter_number) {
    auto param = m::Parameter(parameter_number)
                     .WithShape(m::Shape().WithElementType(F32));
    auto param_s32 =
        m::BitcastConvert(param).WithShape(m::Shape().WithElementType(S32));
    auto param_u32 =
        m::BitcastConvert(param).WithShape(m::Shape().WithElementType(U32));
    auto max_u32 =
        m::Convert(m::ConstantScalar(std::numeric_limits<int32_t>::max()))
            .WithShape(m::Shape().WithElementType(U32));
    return m::Select(m::Lt(param_s32, m::ConstantScalar(0)),
                     m::BitcastConvert(m::Subtract(max_u32, param_u32))
                         .WithShape(m::Shape().WithElementType(S32)),
                     param_s32);
  };

  auto match_bitcast_bf16 = [](int64_t parameter_number) {
    auto param = m::Convert(m::Parameter(parameter_number)
                                .WithShape(m::Shape().WithElementType(BF16)))
                     .WithShape(m::Shape().WithElementType(F32));
    auto param_s32 =
        m::BitcastConvert(param).WithShape(m::Shape().WithElementType(S32));
    auto param_u32 =
        m::BitcastConvert(param).WithShape(m::Shape().WithElementType(U32));
    return m::Select(
        m::Lt(param_s32, m::ConstantScalar(0)),
        m::BitcastConvert(
            m::Subtract(m::ConstantScalar(std::numeric_limits<int32_t>::max()),
                        param_u32))
            .WithShape(m::Shape().WithElementType(S32)),
        param_s32);
  };

  auto match_bitcast_bf16_with_convert = [](int64_t parameter_number) {
    auto param = m::Convert(m::Parameter(parameter_number)
                                .WithShape(m::Shape().WithElementType(BF16)))
                     .WithShape(m::Shape().WithElementType(F32));
    auto param_s32 =
        m::BitcastConvert(param).WithShape(m::Shape().WithElementType(S32));
    auto param_u32 =
        m::BitcastConvert(param).WithShape(m::Shape().WithElementType(U32));
    auto max_u32 =
        m::Convert(m::ConstantScalar(std::numeric_limits<int32_t>::max()))
            .WithShape(m::Shape().WithElementType(U32));
    return m::Select(m::Lt(param_s32, m::ConstantScalar(0)),
                     m::BitcastConvert(m::Subtract(max_u32, param_u32))
                         .WithShape(m::Shape().WithElementType(S32)),
                     param_s32);
  };

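  // Plain S32 comparator parameters: integer comparison is already total, so
  // no NaN handling is needed.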
  auto match_s32 = [](int64_t parameter_number) {
    auto param = m::Parameter(parameter_number)
                     .WithShape(m::Shape().WithElementType(S32));
    return param;
  };

  return Match(comp->root_instruction(),
               m::Gt(match_bitcast_f32(0), match_bitcast_f32(1))) ||
         Match(comp->root_instruction(),
               m::Gt(match_bitcast_bf16(0), match_bitcast_bf16(1))) ||
         Match(comp->root_instruction(),
               m::Gt(match_bitcast_f32_with_convert(0),
                     match_bitcast_f32_with_convert(1))) ||
         Match(comp->root_instruction(),
               m::Gt(match_bitcast_bf16_with_convert(0),
                     match_bitcast_bf16_with_convert(1))) ||
         Match(comp->root_instruction(), m::Gt(match_s32(0), match_s32(1)));
}

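// If `inst` is a sort feeding a TopK-shaped pattern, returns the k value;
// otherwise returns std::nullopt. The pattern requires a NaN-safe comparator,
// an optional S32 iota along the sort dimension as the second sort operand,
// and users that are slices (behind single-use get-tuple-elements for
// two-operand sorts) starting at 0 with stride 1, all taking the same k along
// the sort dimension and leaving any batch dimension intact.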
std::optional<int64_t> TopkRewriter::SortIsInTopK(HloInstruction* inst) {
  HloSortInstruction* sort = DynCast<HloSortInstruction>(inst);
  if (sort == nullptr) {
    return std::nullopt;
  }
  if (sort->operand_count() != 1 && sort->operand_count() != 2) {
    return std::nullopt;
  }
  HloInstruction* data = sort->mutable_operand(0);

  if (sort->operand_count() == 2) {
    HloIotaInstruction* iota =
        DynCast<HloIotaInstruction>(sort->mutable_operand(1));
    if (iota == nullptr || iota->shape().rank() != data->shape().rank() ||
        iota->shape().element_type() != S32 ||
        iota->opcode() != HloOpcode::kIota ||
        iota->iota_dimension() != sort->sort_dimension()) {
      return std::nullopt;
    }
  }
  if (!IsNanSafeGt(sort->to_apply())) {
    return std::nullopt;
  }
  const int64_t sort_dim = sort->sort_dimension();
  const int64_t batch_dim = sort_dim == 1 ? 0 : 1;
  const bool has_batch = data->shape().rank() == 2;

  bool supported = true;
  std::optional<int64_t> k;
  for (HloInstruction* user : sort->users()) {
    const HloInstruction* slice = user;
    if (sort->operand_count() == 2) {
      if (user->opcode() != HloOpcode::kGetTupleElement ||
          user->user_count() != 1) {
        supported = false;
        break;
      }
      slice = user->users()[0];
    }
    if (slice->opcode() != HloOpcode::kSlice) {
      // A non-slice user means this isn't a TopK pattern.
      supported = false;
      break;
    }
    if (absl::c_any_of(slice->slice_starts(), [](int x) { return x != 0; }) ||
        absl::c_any_of(slice->slice_strides(), [](int x) { return x != 1; })) {
      // Strided slices, or slices that don't start at the beginning, aren't
      // supported.
      supported = false;
      break;
    }
    if (has_batch && slice->slice_limits(batch_dim) !=
                         slice->operand(0)->shape().dimensions(batch_dim)) {
      // Slicing along the batch dimension isn't supported.
      supported = false;
      break;
    }
    if (k == std::nullopt) {
      k = slice->slice_limits(sort_dim);
    } else if (k != slice->slice_limits(sort_dim)) {
      // Different k values across users aren't supported.
      supported = false;
      break;
    }
  }
  if (k == std::nullopt || !supported) {
    return std::nullopt;
  }
  return k;
}

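// Rewrites each detected sort+slice TopK pattern into a "TopK" custom call.
// Roughly (illustrative HLO; shapes assume a [b, n] input sorted along dim 1):
//
//   before: sort = (f32[b,n], s32[b,n]) sort(data, iota), dimensions={1}
//           vals = f32[b,k] slice(get-tuple-element(sort, 0))
//           idxs = s32[b,k] slice(get-tuple-element(sort, 1))
//
//   after:  topk = (f32[b,k], s32[b,k]) custom-call(data),
//                  custom_call_target="TopK"
//
// When the sort dimension is 0, the input and both outputs are transposed
// around the custom call so that "TopK" always reduces the last dimension.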
StatusOr<bool> TopkRewriter::TransformToCustomCall(
    HloModule* module,
    const absl::flat_hash_set<absl::string_view>& execution_threads) {
  bool changed = false;
  for (HloComputation* comp : module->computations(execution_threads)) {
    for (HloInstruction* inst : comp->MakeInstructionPostOrder()) {
      // Check whether this sort is part of a TopK pattern.
      std::optional<int64_t> k = SortIsInTopK(inst);
      if (!k) {
        continue;
      }

      HloSortInstruction* sort = DynCast<HloSortInstruction>(inst);
      HloInstruction* data = sort->mutable_operand(0);
      const PrimitiveType element_type = data->shape().element_type();

      if ((data->shape().rank() != 1 && data->shape().rank() != 2) ||
          (element_type != F32 && element_type != BF16)) {
        continue;
      }

      const int64_t sort_dim = sort->sort_dimension();
      const int64_t batch_dim = sort_dim == 1 ? 0 : 1;
      const bool has_batch = data->shape().rank() == 2;

      // Profitability check.
      if (!is_profitable_to_convert_(sort, *k)) {
        continue;
      }

      const int64_t batch_size =
          has_batch ? sort->operand(0)->shape().dimensions(batch_dim) : 1;
      const int64_t input_size = sort->operand(0)->shape().dimensions(sort_dim);
      HloInstruction* input = sort->mutable_operand(0);
      if (has_batch && sort_dim == 0) {
        input = comp->AddInstruction(HloInstruction::CreateTranspose(
            ShapeUtil::MakeShape(element_type, {batch_size, input_size}), input,
            {1, 0}));
      }

      Shape topk_shape =
          has_batch ? ShapeUtil::MakeTupleShape(
                          {ShapeUtil::MakeShape(element_type,
                                                {batch_size, k.value()}),
                           ShapeUtil::MakeShape(S32, {batch_size, k.value()})})
                    : ShapeUtil::MakeTupleShape(
                          {ShapeUtil::MakeShape(element_type, {k.value()}),
                           ShapeUtil::MakeShape(S32, {k.value()})});
      HloInstruction* topk = comp->AddInstruction(
          HloInstruction::CreateCustomCall(topk_shape, {input}, "TopK"));
      HloInstruction* value_gte =
          comp->AddInstruction(HloInstruction::CreateGetTupleElement(
              topk->shape().tuple_shapes(0), topk, 0));
      HloInstruction* index_gte =
          comp->AddInstruction(HloInstruction::CreateGetTupleElement(
              topk->shape().tuple_shapes(1), topk, 1));

      if (has_batch && sort_dim == 0) {
        value_gte = comp->AddInstruction(HloInstruction::CreateTranspose(
            ShapeUtil::MakeShape(element_type, {k.value(), batch_size}),
            value_gte, {1, 0}));
        index_gte = comp->AddInstruction(HloInstruction::CreateTranspose(
            ShapeUtil::MakeShape(S32, {k.value(), batch_size}), index_gte,
            {1, 0}));
      }

      for (HloInstruction* user : sort->users()) {
        if (sort->operand_count() == 2) {
          HloInstruction* gte = user;
          for (HloInstruction* slice : gte->users()) {
            if (gte->tuple_index() == 0) {
              TF_RETURN_IF_ERROR(slice->ReplaceAllUsesWith(value_gte));
            } else if (gte->tuple_index() == 1) {
              TF_RETURN_IF_ERROR(slice->ReplaceAllUsesWith(index_gte));
            } else {
              LOG(FATAL) << "Sort with more than 2 outputs isn't supported in "
                            "topk rewriter";
            }
          }
        } else {
          TF_RETURN_IF_ERROR(user->ReplaceAllUsesWith(value_gte));
        }
      }
      changed = true;
    }
  }
  return changed;
}

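// Entry point for the pass; currently it just runs the custom-call
// transformation above. A minimal usage sketch (the pipeline name and the
// k <= 16 profitability threshold are illustrative, not part of this file):
//
//   HloPassPipeline pipeline("topk-rewrite");
//   pipeline.AddPass<TopkRewriter>(
//       [](const HloSortInstruction*, int64_t k) { return k <= 16; });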
StatusOr<bool> TopkRewriter::Run(
    HloModule* module,
    const absl::flat_hash_set<absl::string_view>& execution_threads) {
  bool changed = false;
  TF_ASSIGN_OR_RETURN(auto transform_to_customcall_changed,
                      TransformToCustomCall(module, execution_threads));
  changed |= transform_to_customcall_changed;
  return changed;
}

}  // namespace xla