xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/mlir/xla/transforms/mhlo_to_lhlo_with_xla.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/mlir/xla/transforms/mhlo_to_lhlo_with_xla.h"
17 
18 #include <climits>
19 #include <memory>
20 #include <tuple>
21 
22 #include "absl/algorithm/container.h"
23 #include "absl/cleanup/cleanup.h"
24 #include "absl/types/optional.h"
25 #include "llvm/ADT/STLExtras.h"
26 #include "llvm/ADT/SmallVector.h"
27 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"  // from @llvm-project
28 #include "mlir/Dialect/Bufferization/IR/Bufferization.h"  // from @llvm-project
29 #include "mlir/Dialect/Func/IR/FuncOps.h"  // from @llvm-project
30 #include "mlir/Dialect/MemRef/IR/MemRef.h"  // from @llvm-project
31 #include "mlir/IR/AffineExpr.h"  // from @llvm-project
32 #include "mlir/IR/AffineMap.h"  // from @llvm-project
33 #include "mlir/IR/Attributes.h"  // from @llvm-project
34 #include "mlir/IR/Builders.h"  // from @llvm-project
35 #include "mlir/IR/BuiltinAttributes.h"  // from @llvm-project
36 #include "mlir/IR/BuiltinOps.h"  // from @llvm-project
37 #include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
38 #include "mlir/IR/Dialect.h"  // from @llvm-project
39 #include "mlir/IR/Location.h"  // from @llvm-project
40 #include "mlir/IR/MLIRContext.h"  // from @llvm-project
41 #include "mlir/IR/OpDefinition.h"  // from @llvm-project
42 #include "mlir/IR/Operation.h"  // from @llvm-project
43 #include "mlir/IR/PatternMatch.h"  // from @llvm-project
44 #include "mlir/IR/SymbolTable.h"  // from @llvm-project
45 #include "mlir/IR/Verifier.h"  // from @llvm-project
46 #include "mlir/Pass/Pass.h"  // from @llvm-project
47 #include "mlir/Pass/PassOptions.h"  // from @llvm-project
48 #include "mlir/Tools/mlir-translate/Translation.h"  // from @llvm-project
49 #include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h"
50 #include "tensorflow/compiler/mlir/xla/attribute_importer.h"
51 #include "tensorflow/compiler/mlir/xla/hlo_function_importer.h"
52 #include "tensorflow/compiler/mlir/xla/hlo_utils.h"
53 #include "tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.h"
54 #include "tensorflow/compiler/mlir/xla/type_to_shape.h"
55 #include "tensorflow/compiler/xla/debug_options_flags.h"
56 #include "tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/lhlo/IR/lhlo_ops.h"
57 #include "tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/lhlo_gpu/IR/lhlo_gpu_ops.h"
58 #include "tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
59 #include "tensorflow/compiler/xla/service/backend.h"
60 #include "tensorflow/compiler/xla/service/buffer_assignment.h"
61 #include "tensorflow/compiler/xla/service/gpu/backend_configs.pb.h"
62 #include "tensorflow/compiler/xla/service/gpu/cublas_cudnn.h"
63 #include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h"
64 #include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
65 #include "tensorflow/compiler/xla/service/hlo_computation.h"
66 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
67 #include "tensorflow/compiler/xla/service/hlo_instructions.h"
68 #include "tensorflow/compiler/xla/service/hlo_module.h"
69 #include "tensorflow/compiler/xla/service/hlo_parser.h"
70 #include "tensorflow/compiler/xla/service/llvm_ir/buffer_assignment_util.h"
71 #include "tensorflow/compiler/xla/shape_util.h"
72 #include "tensorflow/compiler/xla/statusor.h"
73 #include "tensorflow/compiler/xla/util.h"
74 #include "tensorflow/compiler/xla/window_util.h"
75 #include "tensorflow/compiler/xla/xla_data.pb.h"
76 
77 using xla::BufferAllocation;
78 using xla::BufferAssignment;
79 using xla::HloComputation;
80 using xla::HloCustomCallInstruction;
81 using xla::HloInfeedInstruction;
82 using xla::HloInstruction;
83 using xla::HloModule;
84 using xla::HloModuleProto;
85 using xla::HloOutfeedInstruction;
86 using xla::HloProto;
87 using xla::Shape;
88 using xla::StatusOr;
89 
90 namespace mlir {
91 namespace {
92 
93 absl::string_view StringRefToView(llvm::StringRef ref) {
94   return {ref.data(), ref.size()};
95 }
96 
97 StatusOr<std::unique_ptr<HloModule>> HloModuleFromProto(
98     const HloProto& hlo_proto) {
99   const HloModuleProto& module_proto = hlo_proto.hlo_module();
100   TF_ASSIGN_OR_RETURN(const xla::HloModuleConfig module_config,
101                       HloModule::CreateModuleConfigFromProto(
102                           module_proto, xla::GetDebugOptionsFromFlags()));
103   return HloModule::CreateFromProto(module_proto, module_config);
104 }
105 
106 }  // namespace
107 
108 // Convert the MLIR `module` from HLO dialect to LHLO dialect using XLA for the
109 // given platform.
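// The conversion below looks up the requested platform, creates an XLA Backend
// for it, optionally runs the XLA HLO optimization pipeline, runs buffer
// assignment, and finally rebuilds `module` in LHLO form via HloToLhloModule().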
110 Status OptimizeAndConvertHloToLmhlo(std::unique_ptr<HloModule> hlo_module,
111                                     ModuleOp module, StringRef platform_name,
112                                     bool optimize_xla_hlo) {
113   auto platform = xla::se::MultiPlatformManager::PlatformWithName(
114       StringRefToView(platform_name));
115   if (!platform.ok()) {
116     std::string error_msg;
117     llvm::raw_string_ostream os(error_msg);
118     os << "failed to get platform: " << platform.status().ToString()
119        << " (available Platform: ";
120     std::vector<std::string> available_platforms;
121     (void)xla::se::MultiPlatformManager::PlatformsWithFilter(
122         [&](const stream_executor::Platform* p) {
123           available_platforms.push_back(p->Name());
124           return false;
125         });
126     llvm::interleaveComma(available_platforms, os);
127     os << ")";
128     return xla::InvalidArgument("%s", os.str().c_str());
129   }
130 
131   xla::BackendOptions backend_options;
132   backend_options.set_platform(platform.ValueOrDie());
133   auto backend_or_err = xla::Backend::CreateBackend(backend_options);
134   TF_RETURN_WITH_CONTEXT_IF_ERROR(backend_or_err.status(),
135                                   "failed to create XLA Backend ");
136   auto backend = std::move(backend_or_err.ValueOrDie());
137 
138   StatusOr<std::unique_ptr<HloModule>> optimized_hlo_module;
139 
140   if (optimize_xla_hlo) {
141     // Run all HLO passes to produce an optimized module.
142     optimized_hlo_module = backend->compiler()->RunHloPasses(
143         std::move(hlo_module), backend->default_stream_executor(),
144         backend->memory_allocator());
145     TF_RETURN_WITH_CONTEXT_IF_ERROR(optimized_hlo_module.status(),
146                                     "running XLA pass pipeline");
147   } else {
148     optimized_hlo_module = std::move(hlo_module);
149   }
150 
151   StatusOr<std::unique_ptr<BufferAssignment>> assignment =
152       backend->compiler()->AssignBuffers(optimized_hlo_module->get());
153   TF_RETURN_WITH_CONTEXT_IF_ERROR(assignment.status(),
154                                   "running XLA buffer assignment");
155 
156   // Clear the module before populating it back with the result of the
157   // conversion.
158   module.getBody()->clear();
159   OpBuilder builder(module);
160 
161   TF_RETURN_WITH_CONTEXT_IF_ERROR(
162       HloToLhloModule(**assignment, **optimized_hlo_module, module),
163       "converting HLO to LHLO");
164 
165   return ::tensorflow::OkStatus();
166 }
167 
168 namespace {
169 // This pass takes an MLIR HLO module, converts it to XLA HLO, runs the HLO
170 // optimization pipeline for the given platform, and then converts the result
171 // back to MLIR LHLO.
172 class XlaHloToLhloPass
173     : public PassWrapper<XlaHloToLhloPass, OperationPass<ModuleOp>> {
174   void getDependentDialects(DialectRegistry& registry) const override {
175     registry
176         .insert<arith::ArithmeticDialect, bufferization::BufferizationDialect,
177                 func::FuncDialect, memref::MemRefDialect, mhlo::MhloDialect,
178                 lmhlo::LmhloDialect, lmhlo_gpu::LmhloGpuDialect>();
179   }
180 
181  public:
182   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(XlaHloToLhloPass)
183 
184   XlaHloToLhloPass() = default;
185   XlaHloToLhloPass(const XlaHloToLhloPass&) {}
186   StringRef getArgument() const final { return "xla-hlo-to-lhlo-with-xla"; }
187   StringRef getDescription() const final {
188     return "Emit LHLO from HLO using the existing XLA implementation";
189   }
190 
191  private:
192   void runOnOperation() final {
193     ModuleOp module = getOperation();
194 
195     auto status = [&module, this]() -> Status {
196       SymbolTable symbol_table(module);
197       if (!symbol_table.lookup("main")) {
198         return xla::InvalidArgument(
199             "conversion to HLO module failed: missing main()");
200       }
201       HloProto hlo_proto;
202       TF_RETURN_WITH_CONTEXT_IF_ERROR(
203           ConvertMlirHloToHlo(module, &hlo_proto,
204                               /*use_tuple_args=*/false,
205                               /*return_tuple=*/false),
206           "conversion to XLA HLO proto failed");
207 
208       auto statusOrHloModule = HloModuleFromProto(hlo_proto);
209       TF_RETURN_WITH_CONTEXT_IF_ERROR(statusOrHloModule.status(),
210                                       "parsing HLO proto to HLO module failed");
211       std::unique_ptr<HloModule> hlo_module =
212           std::move(statusOrHloModule.ValueOrDie());
213 
214       return OptimizeAndConvertHloToLmhlo(std::move(hlo_module), module,
215                                           platform_, optimize_xla_hlo_);
216     }();
217     if (!status.ok()) {
218       module.emitError() << status.ToString();
219       return signalPassFailure();
220     }
221   }
222 
223   Option<std::string> platform_{
224       *this, "platform",
225       llvm::cl::desc("The platform to use for the XLA optimization pipeline."),
226       llvm::cl::init("Host")};
227   Option<bool> optimize_xla_hlo_{
228       *this, "optimize-xla-hlo",
229       llvm::cl::desc("Whether to apply HLO optimizations."),
230       llvm::cl::init(true)};
231 };
232 
233 }  // namespace
234 
235 // Creates MLIR operands corresponding to operands and results of the XLA HLO
236 // instruction. If `num_operands` is valid, then only the first `num_operands`
237 // operands of the HLO instruction will be considered.
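// On success, `operands` holds the buffer views for the instruction's operands
// followed by the views for its results; `num_arguments` and `num_results`
// record how that vector is split.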
238 Status LhloDialectEmitter::CreateOperands(
239     const HloInstruction* instr, std::optional<int64_t> num_operands,
240     TokenLoweringMode token_mode, llvm::SmallVectorImpl<Value>& operands,
241     size_t& num_arguments, size_t& num_results) {
242   if (num_operands.value_or(0) > instr->operand_count())
243     return xla::InvalidArgument("num_operands must be <= operand count");
244   for (int64_t i = 0; i < num_operands.value_or(instr->operand_count()); ++i) {
245     TF_RETURN_IF_ERROR(GetOrCreateView(instr->operand(i), &operands,
246                                        /*result_subset=*/{}, token_mode));
247   }
248   num_arguments = operands.size();
249   TF_RETURN_IF_ERROR(
250       GetOrCreateView(instr, &operands, /*result_subset=*/{}, token_mode));
251   num_results = operands.size() - num_arguments;
252   return ::tensorflow::OkStatus();
253 }
254 
255 template <typename OpType>
256 OpType LhloDialectEmitter::CreateOpWithoutAttrs(const HloInstruction* instr,
257                                                 ValueRange operands) {
258   Location loc = getLocation(instr);
259   return builder_.create<OpType>(loc, llvm::None, operands,
260                                  llvm::ArrayRef<NamedAttribute>{});
261 }
262 
263 template <typename OpType>
264 StatusOr<OpType> LhloDialectEmitter::CreateOpWithoutAttrs(
265     const HloInstruction* instr, size_t& num_arguments, size_t& num_results,
266     std::optional<int64_t> num_operands) {
267   llvm::SmallVector<Value, 4> operands;
268   TF_RETURN_IF_ERROR(CreateOperands(instr, num_operands,
269                                     TokenLoweringMode::kFailToLower, operands,
270                                     num_arguments, num_results));
271   return CreateOpWithoutAttrs<OpType>(instr, operands);
272 }
273 
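// Wraps a single HLO instruction in an lmhlo.fusion region: the buffer
// operands are loaded with bufferization.to_tensor, the instruction itself is
// imported at the tensor level through HloFunctionImporter, and each result is
// stored back to its output buffer with memref.tensor_store.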
274 StatusOr<mlir::Operation*> LhloDialectEmitter::CreateOpInFusion(
275     const HloInstruction* instr, ValueRange buffer_operands,
276     size_t num_arguments, size_t num_results) {
277   Location loc = getLocation(instr);
278   std::vector<Value> buffers(buffer_operands.begin(), buffer_operands.end());
279   absl::Span<Value> arguments =
280       absl::MakeSpan(buffers).subspan(0, num_arguments);
281   absl::Span<Value> results =
282       absl::MakeSpan(buffers).subspan(num_arguments, num_results);
283 
284   mlir::lmhlo::FusionOp fusion = builder_.create<mlir::lmhlo::FusionOp>(loc);
285   mlir::OpBuilder b(&fusion.getRegion());
286 
287   llvm::SmallVector<mlir::Value, 4> loads;
288   for (Value arg : arguments) {
289     auto load = b.create<mlir::bufferization::ToTensorOp>(loc, arg);
290     Shape shape = xla::TypeToShape(arg.getType());
291     TF_RET_CHECK(shape.IsArray());
292     if (shape.layout() !=
293         xla::LayoutUtil::MakeDescendingLayout(shape.dimensions().size())) {
294       load->setAttr("xla_shape",
295                     b.getStringAttr(shape.ToString(/*print_layout=*/true)));
296     }
297     loads.push_back(load);
298   }
299   mlir::Operation* op = nullptr;
300   if (instr->opcode() == xla::HloOpcode::kReduce) {
301     TF_RET_CHECK(loads.size() % 2 == 0);
302     std::vector<int64_t> dimensions(instr->dimensions().begin(),
303                                     instr->dimensions().end());
304     auto reduce_op = b.create<mhlo::ReduceOp>(
305         loc, llvm::makeArrayRef(loads).take_front(loads.size() / 2),
306         llvm::makeArrayRef(loads).drop_front(loads.size() / 2),
307         GetI64DenseElementsAttr(dimensions));
308 
309     TF_RETURN_IF_ERROR(xla::HloFunctionImporter::ImportAsRegion(
310         *instr->called_computations()[0], &reduce_op.body(), &builder_,
311         /*flatten_region_arg_tuple=*/true));
312     op = reduce_op;
313   } else {
314     TF_ASSIGN_OR_RETURN(
315         op,
316         xla::HloFunctionImporter::ImportInstruction(
317             instr, loads, &b, xla::DynamicShapeHandlingMode::kConvertToStatic));
318   }
319   TF_RET_CHECK(op->getNumResults() == num_results);
320   for (int i = 0; i < results.size(); i++) {
321     b.create<mlir::memref::TensorStoreOp>(loc, op->getResult(i), results[i]);
322   }
323   return op;
324 }
325 
326 StatusOr<mlir::Operation*> LhloDialectEmitter::CreateOpInFusion(
327     const HloInstruction* instr) {
328   llvm::SmallVector<Value, 4> operands;
329   size_t num_arguments, num_results;
330   TF_RETURN_IF_ERROR(CreateOperands(instr, std::nullopt,
331                                     TokenLoweringMode::kFailToLower, operands,
332                                     num_arguments, num_results));
333   TF_ASSIGN_OR_RETURN(
334       auto op, CreateOpInFusion(instr, operands, num_arguments, num_results));
335   return op->getParentOp();
336 }
337 
338 StatusOr<mlir::Operation*> LhloDialectEmitter::EmitOp(
339     const HloInstruction* instr) {
340   using xla::HloOpcode;
341   switch (instr->opcode()) {
342     case HloOpcode::kAddDependency:
343       return nullptr;
344     case HloOpcode::kAfterAll:
345       // LMHLO is already ordered. This assumption may be broken after
346       // introducing async regions and partial orders.
347       return nullptr;
348     case HloOpcode::kAllToAll:
349       return EmitAllToAllOp(instr);
350     case HloOpcode::kAllGather:
351       return EmitAllGatherOp(instr);
352     case HloOpcode::kAllReduce:
353       return EmitAllReduceOp(instr);
354     case HloOpcode::kAllReduceStart:
355       return EmitAllReduceStartOp(instr);
356     case HloOpcode::kAllReduceDone:
357       return EmitAllReduceDoneOp(instr);
358     case HloOpcode::kReduceScatter:
359       return EmitReduceScatterOp(instr);
360     case HloOpcode::kBitcast:
361       return EmitBitcast(instr);
362     case HloOpcode::kCollectivePermute:
363       return EmitCollectivePermuteOp(instr);
364     case HloOpcode::kConditional:
365       return EmitCaseOp(instr);
366     case HloOpcode::kFft:
367       return EmitFftOp(instr);
368     case HloOpcode::kGetTupleElement:
369       return nullptr;
370     case HloOpcode::kInfeed:
371       return EmitInfeedOp(instr);
372     case HloOpcode::kOutfeed:
373       return EmitOutfeedOp(instr);
374     case HloOpcode::kPartitionId:
375       return CreateOpWithoutAttrs<lmhlo::PartitionIdOp>(instr);
376     case HloOpcode::kReplicaId:
377       return CreateOpWithoutAttrs<lmhlo::ReplicaIdOp>(instr);
378     case HloOpcode::kTriangularSolve:
379       return EmitTriangularSolveOp(instr);
380     case HloOpcode::kTuple:
381       return nullptr;
382     case HloOpcode::kSort:
383       return EmitSortOp(instr);
384     case HloOpcode::kFusion:
385       return EmitFusionOp(instr);
386     case HloOpcode::kScatter:
387       return EmitScatterOp(instr);
388     case HloOpcode::kSelectAndScatter:
389       return EmitSelectAndScatterOp(instr);
390     case HloOpcode::kCustomCall:
391       return EmitCustomCallOp(instr);
392     case HloOpcode::kConstant:
393       return EmitConstant(instr);
394     case HloOpcode::kRngGetAndUpdateState:
395       return EmitRngGetAndUpdateStateOp(instr);
396     case HloOpcode::kWhile:
397       return EmitWhileOp(instr);
398 
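    // The remaining element-wise and data-movement opcodes have no dedicated
    // LMHLO op here; each is emitted inside a single-instruction lmhlo.fusion
    // via CreateOpInFusion below.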
399     case HloOpcode::kAbs:
400     case HloOpcode::kAdd:
401     case HloOpcode::kAnd:
402     case HloOpcode::kAtan2:
403     case HloOpcode::kBitcastConvert:
404     case HloOpcode::kBroadcast:
405     case HloOpcode::kCeil:
406     case HloOpcode::kCbrt:
407     case HloOpcode::kClamp:
408     case HloOpcode::kClz:
409     case HloOpcode::kCompare:
410     case HloOpcode::kComplex:
411     case HloOpcode::kConcatenate:
412     case HloOpcode::kConvert:
413     case HloOpcode::kCos:
414     case HloOpcode::kDivide:
415     case HloOpcode::kDot:
416     case HloOpcode::kDynamicSlice:
417     case HloOpcode::kDynamicUpdateSlice:
418     case HloOpcode::kExp:
419     case HloOpcode::kExpm1:
420     case HloOpcode::kFloor:
421     case HloOpcode::kGather:
422     case HloOpcode::kImag:
423     case HloOpcode::kIota:
424     case HloOpcode::kIsFinite:
425     case HloOpcode::kLog:
426     case HloOpcode::kLog1p:
427     case HloOpcode::kMap:
428     case HloOpcode::kMaximum:
429     case HloOpcode::kMinimum:
430     case HloOpcode::kMultiply:
431     case HloOpcode::kNegate:
432     case HloOpcode::kNot:
433     case HloOpcode::kOr:
434     case HloOpcode::kPad:
435     case HloOpcode::kPopulationCount:
436     case HloOpcode::kPower:
437     case HloOpcode::kReal:
438     case HloOpcode::kReshape:
439     case HloOpcode::kReducePrecision:
440     case HloOpcode::kReduceWindow:
441     case HloOpcode::kRemainder:
442     case HloOpcode::kReverse:
443     case HloOpcode::kRoundNearestAfz:
444     case HloOpcode::kRoundNearestEven:
445     case HloOpcode::kRsqrt:
446     case HloOpcode::kSelect:
447     case HloOpcode::kShiftLeft:
448     case HloOpcode::kShiftRightLogical:
449     case HloOpcode::kShiftRightArithmetic:
450     case HloOpcode::kSign:
451     case HloOpcode::kSin:
452     case HloOpcode::kSlice:
453     case HloOpcode::kSqrt:
454     case HloOpcode::kSubtract:
455     case HloOpcode::kTanh:
456     case HloOpcode::kTranspose:
457     case HloOpcode::kXor:
458     case HloOpcode::kCopy:
459     case HloOpcode::kReduce:
460       return CreateOpInFusion(instr);
461     default:
462       llvm::errs() << instr->ToString();
463       return tensorflow::errors::Internal(
464           absl::StrCat("LHLO opcode ", xla::HloOpcodeString(instr->opcode()),
465                        " is not supported."));
466   }
467 }
468 
469 Status LhloDialectEmitter::DefaultAction(const HloInstruction* instr) {
470   return EmitOp(instr).status();
471 }
472 
473 StatusOr<lmhlo::SortOp> LhloDialectEmitter::EmitSortOp(
474     const HloInstruction* instr) {
475   TF_ASSIGN_OR_RETURN(auto sort, CreateOpWithoutAttrs<lmhlo::SortOp>(instr));
476   auto* sort_instr = xla::Cast<xla::HloSortInstruction>(instr);
477   sort.setDimensionAttr(
478       builder_.getI64IntegerAttr(sort_instr->sort_dimension()));
479   sort.setIsStableAttr(builder_.getBoolAttr(sort_instr->is_stable()));
480   TF_RETURN_IF_ERROR(xla::HloFunctionImporter::ImportAsRegion(
481       *sort_instr->called_computations()[0], &sort.getComparator(), &builder_));
482   return sort;
483 }
484 
485 // Walks MHLO::TupleOp recursively.
486 Status WalkTuplePostOrder(Value v,
487                           const std::function<Status(Value)>& visitor) {
488   if (auto* op = v.getDefiningOp()) {
489     if (auto tuple = dyn_cast<mhlo::TupleOp>(op)) {
490       for (Value sub_v : tuple.val()) {
491         TF_RETURN_IF_ERROR(WalkTuplePostOrder(sub_v, visitor));
492       }
493       return ::tensorflow::OkStatus();
494     }
495   }
496   return visitor(v);
497 }
498 
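// Rewrites one fusion operand into tensor values usable inside the fused
// region: tuple shapes are expanded recursively into an mhlo.tuple of
// bufferization.to_tensor loads, and a non-default layout is recorded on the
// load via the "xla_shape" attribute.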
499 StatusOr<Value> LhloDialectEmitter::RewriteFusionOperand(
500     const HloInstruction* root, const Shape& shape,
501     xla::ShapeIndex* shape_index, OpBuilder* b, Location loc) {
502   if (shape.IsTuple()) {
503     llvm::SmallVector<Value, 4> values;
504     for (int i = 0; i < shape.tuple_shapes_size(); ++i) {
505       shape_index->push_back(i);
506       TF_ASSIGN_OR_RETURN(
507           auto v, RewriteFusionOperand(root, shape.tuple_shapes(i), shape_index,
508                                        b, loc));
509       values.push_back(v);
510       shape_index->pop_back();
511     }
512     return Value(b->create<mhlo::TupleOp>(loc, values));
513   }
514   TF_ASSIGN_OR_RETURN(Value memref,
515                       GetOrCreateArrayView(root, shape, *shape_index));
516   auto load = b->create<bufferization::ToTensorOp>(loc, memref);
517   if (shape.layout() !=
518       xla::LayoutUtil::MakeDescendingLayout(shape.dimensions().size())) {
519     llvm::SmallVector<int64_t, 4> minor_to_major(
520         shape.layout().minor_to_major().begin(),
521         shape.layout().minor_to_major().end());
522     load->setAttr("xla_shape",
523                   b->getStringAttr(shape.ToString(/*print_layout=*/true)));
524   }
525   return load.getResult();
526 }
527 
528 // Emit a lmhlo.fusion based on XLA HLO fusion. Structurally they are not neatly
529 // equivalent. Specifically, XLA HLO fusion:
530 //     fused_computation {
531 //       %p0 = parameter(0)
532 //       %p1 = parameter(1)
533 //       ...
534 //       ROOT %ret = ...
535 //     }
536 // will be converted to
537 //     lmhlo.fusion() {  // no explicit operands
538 //       // capturing outside buffers
539 //       %p0 = bufferization.to_tensor(%arg0) : memref<...> -> tensor<...>
540 //       %p1 = bufferization.to_tensor(%arg1) : memref<...> -> tensor<...>
541 //       ...
542 //       tensor_store ..., %ret // store a tensor to a memref
543 //     }
544 StatusOr<lmhlo::FusionOp> LhloDialectEmitter::EmitFusionOp(
545     const HloInstruction* instr) {
546   Location loc = getLocation(instr);
547 
548   auto* fusion_instr = xla::Cast<xla::HloFusionInstruction>(instr);
549 
550   auto fusion = builder_.create<lmhlo::FusionOp>(getLocation(instr));
551   auto after_fusion = builder_.saveInsertionPoint();
552   auto reverter = absl::MakeCleanup(
553       [this, after_fusion] { builder_.restoreInsertionPoint(after_fusion); });
554   builder_ = mlir::OpBuilder(fusion);
555 
556   auto region_builder = OpBuilder::atBlockBegin(&fusion.getRegion().front());
557 
558   llvm::SmallVector<Value, 8> arguments;
559   for (int i = 0; i < instr->operands().size(); ++i) {
560     const HloInstruction* operand = instr->operand(i);
561     xla::ShapeIndex shape_index;
562     TF_ASSIGN_OR_RETURN(
563         auto arg, RewriteFusionOperand(operand, operand->shape(), &shape_index,
564                                        &region_builder, loc));
565     arguments.push_back(arg);
566   }
567 
568   TF_ASSIGN_OR_RETURN(Value result,
569                       xla::HloFunctionImporter::ImportInstructions(
570                           *fusion_instr->fused_instructions_computation(),
571                           arguments, &region_builder));
572   {
573     int i = 0;
574     llvm::SmallVector<Value, 4> output;
575     TF_RETURN_IF_ERROR(GetOrCreateView(instr, &output));
576     TF_RETURN_IF_ERROR(WalkTuplePostOrder(result, [&](Value v) mutable {
577       region_builder.create<memref::TensorStoreOp>(loc, v, output[i++]);
578       return ::tensorflow::OkStatus();
579     }));
580     if (i != output.size()) {
581       return xla::InternalError("output sizes don't match");
582     }
583   }
584 
585   // Fold GTE/Tuple pairs.
586   //
587   // Since the fused region refers to values in its parent region, we can't
588   // call applyPatternsAndFoldGreedily. We optimize it manually.
589   //
590   // Only walk once, because post-ordering is exactly what we need for GTE
591   // optimizations.
592   fusion.getRegion().walk([](mhlo::GetTupleElementOp gte) {
593     SmallVector<Value, 4> folded_values;
594     if (succeeded(OpBuilder(gte).tryFold(gte, folded_values))) {
595       gte.replaceAllUsesWith(folded_values[0]);
596     }
597   });
598 
599   // Effectively a DCE on the region.
600   {
601     llvm::SmallVector<mlir::Operation*, 4> ops;
602     fusion.getRegion().walk([&](mlir::Operation* op) { ops.push_back(op); });
603     // Visit the user first.
604     std::reverse(ops.begin(), ops.end());
605     for (auto op : ops) {
606       if (isOpTriviallyDead(op)) op->erase();
607     }
608   }
609 
610   return fusion;
611 }
612 
613 StatusOr<mhlo::ScatterDimensionNumbersAttr>
614 LhloDialectEmitter::GetScatterDimensionNumbers(const HloInstruction* instr,
615                                                mlir::MLIRContext* context) {
616   auto* scatter_instr = xla::Cast<xla::HloScatterInstruction>(instr);
617 
618   const xla::ScatterDimensionNumbers& xla_scatter_dim =
619       scatter_instr->scatter_dimension_numbers();
620 
621   auto get_i64_array = [](absl::Span<const int64_t> container) {
622     return ArrayRef<int64_t>{container.data(),
623                              static_cast<size_t>(container.size())};
624   };
625   auto scatter_dimension_numbers = mhlo::ScatterDimensionNumbersAttr::get(
626       context, get_i64_array(xla_scatter_dim.update_window_dims()),
627       get_i64_array(xla_scatter_dim.inserted_window_dims()),
628       get_i64_array(xla_scatter_dim.scatter_dims_to_operand_dims()),
629       xla_scatter_dim.index_vector_dim());
630   return scatter_dimension_numbers;
631 }
632 
633 StatusOr<lmhlo::ScatterOp> LhloDialectEmitter::EmitScatterOp(
634     const HloInstruction* instr) {
635   TF_ASSIGN_OR_RETURN(auto scatter,
636                       CreateOpWithoutAttrs<lmhlo::ScatterOp>(instr));
637 
638   // copy attributes
639   auto* scatter_instr = xla::Cast<xla::HloScatterInstruction>(instr);
640 
641   TF_ASSIGN_OR_RETURN(auto scatter_dimension_numbers,
642                       GetScatterDimensionNumbers(instr, builder_.getContext()));
643   scatter.setScatterDimensionNumbersAttr(scatter_dimension_numbers);
644   scatter.setIndicesAreSortedAttr(
645       builder_.getBoolAttr(scatter_instr->indices_are_sorted()));
646   scatter.setUniqueIndicesAttr(
647       builder_.getBoolAttr(scatter_instr->unique_indices()));
648 
649   // import update computation as region
650   TF_RETURN_IF_ERROR(xla::HloFunctionImporter::ImportAsRegion(
651       *scatter_instr->called_computations()[0], &scatter.getUpdateComputation(),
652       &builder_));
653 
654   return scatter;
655 }
656 
657 StatusOr<lmhlo::SelectAndScatterOp> LhloDialectEmitter::EmitSelectAndScatterOp(
658     const HloInstruction* instr) {
659   TF_ASSIGN_OR_RETURN(auto select_and_scatter,
660                       CreateOpWithoutAttrs<lmhlo::SelectAndScatterOp>(instr));
661 
662   // copy attributes
663   auto* select_and_scatter_instr =
664       xla::Cast<xla::HloSelectAndScatterInstruction>(instr);
665   const xla::Window& window = select_and_scatter_instr->window();
666 
667   if (xla::window_util::HasDilation(window)) {
668     return xla::Unimplemented("Dilation for SelectAndScatter is not supported");
669   }
670 
671   select_and_scatter.setWindowDimensionsAttr(
672       GetWindowElements(window, [](const xla::WindowDimension& dim) {
673         return static_cast<int64_t>(dim.size());
674       }));
675   select_and_scatter.setWindowStridesAttr(
676       GetWindowElements(window, [](const xla::WindowDimension& dim) {
677         return static_cast<int64_t>(dim.stride());
678       }));
679   select_and_scatter.setPaddingAttr(
680       GetWindowElements(window, [](const xla::WindowDimension& dim) {
681         return static_cast<int64_t>(dim.padding_low());
682       }));
683 
684   // import select and scatter computation as region
685   TF_RETURN_IF_ERROR(xla::HloFunctionImporter::ImportAsRegion(
686       *select_and_scatter_instr->select(), &select_and_scatter.getSelect(),
687       &builder_));
688   TF_RETURN_IF_ERROR(xla::HloFunctionImporter::ImportAsRegion(
689       *select_and_scatter_instr->scatter(), &select_and_scatter.getScatter(),
690       &builder_));
691   return select_and_scatter;
692 }
693 
694 StatusOr<mlir::Operation*> LhloDialectEmitter::EmitCustomCallOp(
695     const HloInstruction* instr) {
696   auto* custom_call_instr = xla::Cast<xla::HloCustomCallInstruction>(instr);
697 
698   if (xla::gpu::IsCustomCallToCusolver(*instr)) {
699     return EmitCholesky(custom_call_instr);
700   }
701 
702   if (xla::gpu::IsCublasGemm(*instr)) {
703     return EmitGemm(custom_call_instr);
704   }
705 
706   if (xla::gpu::IsCublasLtMatmul(*instr)) {
707     return EmitCublasLtMatmul(custom_call_instr);
708   }
709 
710   if (xla::gpu::IsCustomCallToDnnConvolution(*instr)) {
711     return EmitDnnConvolution(custom_call_instr);
712   }
713 
714   // For custom call, if there are any token operands or results, they will not
715   // be represented in LHLO so we need to remember the mapping. First create
716   // operands where each token is replaced with a null Value.
717   llvm::SmallVector<Value, 4> operands;
718   size_t num_arguments, num_results;
719   TF_RETURN_IF_ERROR(CreateOperands(instr, /*num_operands=*/std::nullopt,
720                                     TokenLoweringMode::kUseNull, operands,
721                                     num_arguments, num_results));
722 
723   // Now check if any of the operands is Null, which would indicate the presence
724   // of a token in the input or output.
725   bool has_token = llvm::any_of(operands, [](Value v) { return !v; });
726 
727   lmhlo::CustomCallTargetArgMappingAttr target_mapping;
728   if (has_token) {
729     // If there was a token, squeeze all the non-token arguments and results
730     // (in-place) and remember the mapping.
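    // For example (hypothetical operand list): if the target expects
    // (buffer0, token, buffer1), LHLO keeps only (buffer0, buffer1) and
    // arg_to_target_arg_mapping records the original indices {0, 2}.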
731     int next_index = 0;
732     llvm::SmallVector<int64_t> arg_to_target_arg_mapping;
733     for (int i = 0; i < num_arguments; ++i) {
734       if (operands[i]) {
735         arg_to_target_arg_mapping.push_back(i);
736         operands[next_index++] = operands[i];
737       }
738     }
739     // Size of arg_to_target_arg_mapping is the number of arguments in LHLO.
740     llvm::SmallVector<int64_t> result_to_target_result_mapping;
741     for (int i = num_arguments; i < operands.size(); ++i) {
742       if (operands[i]) {
743         result_to_target_result_mapping.push_back(i - num_arguments);
744         operands[next_index++] = operands[i];
745       }
746     }
747 
748     // Build the mapping attribute.
749     target_mapping = lmhlo::CustomCallTargetArgMappingAttr::get(
750         builder_.getContext(), num_arguments, num_results,
751         arg_to_target_arg_mapping, result_to_target_result_mapping);
752 
753     // Drop the remaining operands and adjust num_arguments and num_results
754     // for LMHLO creation.
755     operands.resize(next_index);
756     num_arguments = arg_to_target_arg_mapping.size();
757     num_results = result_to_target_result_mapping.size();
758   }
759 
760   auto custom_call = CreateOpWithoutAttrs<lmhlo::CustomCallOp>(instr, operands);
761   TF_ASSIGN_OR_RETURN(
762       auto mlir_api_version,
763       ConvertCustomCallApiVersion(custom_call_instr->api_version()));
764   custom_call.setCallTargetNameAttr(
765       builder_.getStringAttr(custom_call_instr->custom_call_target()));
766   custom_call.setBackendConfigAttr(
767       builder_.getStringAttr(custom_call_instr->opaque()));
768   custom_call.setApiVersionAttr(mhlo::CustomCallApiVersionAttr::get(
769       builder_.getContext(), mlir_api_version));
770   const int32_t segments[2] = {static_cast<int32_t>(num_arguments),
771                                static_cast<int32_t>(num_results)};
772   custom_call->setAttr(lmhlo::CustomCallOp::getOperandSegmentSizeAttr(),
773                        builder_.getDenseI32ArrayAttr(segments));
774   if (target_mapping) custom_call.setTargetArgMappingAttr(target_mapping);
775   return custom_call.getOperation();
776 }
777 
778 StatusOr<lmhlo_gpu::CholeskyOp> LhloDialectEmitter::EmitCholesky(
779     const HloCustomCallInstruction* custom_call) {
780   TF_ASSIGN_OR_RETURN(auto cholesky_op,
781                       CreateOpWithoutAttrs<lmhlo_gpu::CholeskyOp>(custom_call));
782   TF_ASSIGN_OR_RETURN(xla::CholeskyOptions options,
783                       custom_call->backend_config<xla::CholeskyOptions>());
784   cholesky_op.setIsLowerAttr(builder_.getBoolAttr(options.lower()));
785   return cholesky_op;
786 }
787 
788 namespace {
789 
790 template <typename OpT>
791 void SetMatmulAttributes(OpT op, const xla::gpu::GemmBackendConfig& config,
792                          OpBuilder& builder) {
793   auto arrayref = [](absl::Span<const int64_t> array) {
794     return llvm::ArrayRef<int64_t>{array.data(), array.size()};
795   };
796 
797   auto hlo_dims = config.dot_dimension_numbers();
798   auto mlir_dims = mhlo::DotDimensionNumbersAttr::get(
799       builder.getContext(), arrayref(hlo_dims.lhs_batch_dimensions()),
800       arrayref(hlo_dims.rhs_batch_dimensions()),
801       arrayref(hlo_dims.lhs_contracting_dimensions()),
802       arrayref(hlo_dims.rhs_contracting_dimensions()));
803   op.setDotDimensionNumbersAttr(mlir_dims);
804   op.setAlphaRealAttr(builder.getF64FloatAttr(config.alpha_real()));
805   op.setAlphaImagAttr(builder.getF64FloatAttr(config.alpha_imag()));
806   op.setBetaAttr(builder.getF64FloatAttr(config.beta()));
807   if (config.algorithm_case() ==
808       xla::gpu::GemmBackendConfig::kSelectedAlgorithm) {
809     op.setAlgorithmAttr(builder.getI64IntegerAttr(config.selected_algorithm()));
810   }
811   op.setPrecisionConfigAttr(
812       xla::ConvertPrecisionConfig(&config.precision_config(), &builder));
813 }
814 
815 StatusOr<lmhlo_gpu::CublasLtMatmulEpilogue> AsLhloEpilogue(
816     xla::gpu::GemmBackendConfig_Epilogue epilogue) {
817   switch (epilogue) {
818     case xla::gpu::GemmBackendConfig::DEFAULT:
819       return lmhlo_gpu::CublasLtMatmulEpilogue::Default;
820       break;
821     case xla::gpu::GemmBackendConfig::BIAS:
822       return lmhlo_gpu::CublasLtMatmulEpilogue::Bias;
823       break;
824     default:
825       return xla::InternalError("unknown epilogue");
826   }
827 }
828 
829 }  // namespace
830 
831 StatusOr<Operation*> LhloDialectEmitter::EmitGemm(
832     const HloCustomCallInstruction* custom_call) {
833   TF_ASSIGN_OR_RETURN(
834       auto const config,
835       custom_call->backend_config<xla::gpu::GemmBackendConfig>());
836 
837   if (custom_call->operand_count() == 2) {
838     TF_RET_CHECK(config.beta() == 0.);
839   } else if (custom_call->operand_count() != 3) {
840     return xla::InvalidArgument("GEMM custom call should have 2 or 3 operands");
841   }
842 
843   // GEMM may have two or three operands. However, in the three operand case,
844   // the third operand is updated in-place, so we treat that as an output here.
845   TF_ASSIGN_OR_RETURN(
846       lmhlo_gpu::GEMMOp op,
847       CreateOpWithoutAttrs<lmhlo_gpu::GEMMOp>(custom_call,
848                                               /*num_operands=*/2));
849 
850   SetMatmulAttributes(op, config, builder_);
851   return op.getOperation();
852 }
853 
854 StatusOr<Operation*> LhloDialectEmitter::EmitCublasLtMatmul(
855     const HloCustomCallInstruction* custom_call) {
856   TF_ASSIGN_OR_RETURN(
857       auto const config,
858       custom_call->backend_config<xla::gpu::GemmBackendConfig>());
859 
860   bool has_matrix_bias = config.beta() != 0.;
861   bool has_vector_bias = config.epilogue() == xla::gpu::GemmBackendConfig::BIAS;
862   TF_RET_CHECK(custom_call->operand_count() ==
863                2 + int{has_matrix_bias} + int{has_vector_bias});
864 
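  // Operand order assembled below: a, b, c (the matrix bias when beta != 0,
  // otherwise the output buffer), the output itself, and optionally the
  // vector bias.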
865   llvm::SmallVector<Value, 5> operands;
866   TF_RETURN_IF_ERROR(GetOrCreateView(custom_call->operand(0), &operands));
867   TF_RETURN_IF_ERROR(GetOrCreateView(custom_call->operand(1), &operands));
868   TF_RETURN_IF_ERROR(GetOrCreateView(
869       has_matrix_bias ? custom_call->operand(2) : custom_call, &operands));
870   TF_RETURN_IF_ERROR(GetOrCreateView(custom_call, &operands));
871 
872   if (has_vector_bias) {
873     TF_RETURN_IF_ERROR(GetOrCreateView(
874         custom_call->operand(has_matrix_bias ? 3 : 2), &operands));
875   }
876 
877   auto op =
878       CreateOpWithoutAttrs<lmhlo_gpu::CublasLtMatmulOp>(custom_call, operands);
879   SetMatmulAttributes(op, config, builder_);
880 
881   TF_ASSIGN_OR_RETURN(lmhlo_gpu::CublasLtMatmulEpilogue epilogue,
882                       AsLhloEpilogue(config.epilogue()));
883   op.setEpilogueAttr(lmhlo_gpu::CublasLtMatmulEpilogueAttr::get(
884       builder_.getContext(), epilogue));
885 
886   // Use the first algorithm by default (i.e. fastest according to heuristics).
887   if (config.algorithm_case() !=
888       xla::gpu::GemmBackendConfig::kSelectedAlgorithm) {
889     op.setAlgorithmAttr(builder_.getI64IntegerAttr(0));
890   }
891 
892   return op.getOperation();
893 }
894 
895 static StatusOr<mlir::lmhlo_gpu::Activation> GetLHLOActivation(
896     stream_executor::dnn::ActivationMode activation) {
897   switch (activation) {
898     case stream_executor::dnn::kNone:
899       return mlir::lmhlo_gpu::Activation::None;
900     case stream_executor::dnn::kSigmoid:
901       return mlir::lmhlo_gpu::Activation::Sigmoid;
902     case stream_executor::dnn::kRelu:
903       return mlir::lmhlo_gpu::Activation::Relu;
904     case stream_executor::dnn::kRelu6:
905       return mlir::lmhlo_gpu::Activation::Relu6;
906     case stream_executor::dnn::kReluX:
907       return mlir::lmhlo_gpu::Activation::ReluX;
908     case stream_executor::dnn::kTanh:
909       return mlir::lmhlo_gpu::Activation::Tanh;
910     case stream_executor::dnn::kBandPass:
911       return mlir::lmhlo_gpu::Activation::BandPass;
912     default:
913       return xla::InternalError("Unknown activation");
914   }
915 }
916 
917 StatusOr<Operation*> LhloDialectEmitter::EmitDnnConvolution(
918     const HloCustomCallInstruction* custom_call) {
919   TF_ASSIGN_OR_RETURN(
920       auto const backend_config,
921       custom_call->backend_config<xla::gpu::CudnnConvBackendConfig>());
922 
923   TF_ASSIGN_OR_RETURN(const xla::gpu::CudnnConvKind kind,
924                       xla::gpu::GetCudnnConvKind(custom_call));
925 
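  // The convolution kind selects the lmhlo_gpu op emitted in the switch below:
  // kForward -> ConvForwardOp, kBackwardInput -> ConvBackwardInputOp,
  // kBackwardFilter -> ConvBackwardFilterOp, and kForwardActivation ->
  // ConvForwardFusedOp or ConvForwardFusedSideInputOp.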
926   auto get_layout_attribute = [&](const xla::Layout& layout) {
927     std::vector<int64_t> minor_to_major(layout.minor_to_major_size());
928     absl::c_transform(layout.minor_to_major(), minor_to_major.begin(),
929                       [](int64_t x) { return static_cast<int64_t>(x); });
930     return minor_to_major;
931   };
932 
933   auto set_common_conv_attributes = [&, this](auto op) -> Operation* {
934     const xla::Window& window = custom_call->window();
935     // Window size for Cudnn Conv is the same as the kernel size.
936     NamedAttrList attrs(op->getAttrDictionary());
937     DenseIntElementsAttr window_strides;
938     attrs.set(op.getWindowStridesAttrName(),
939               window_strides = GetWindowElements(
940                   window, [](const xla::WindowDimension& dim) {
941                     return static_cast<int64_t>(dim.stride());
942                   }));
943     // Cudnn Conv requires low and high padding to be equal.
944     attrs.set(op.getPaddingAttrName(),
945               GetWindowElements(window, [](const xla::WindowDimension& dim) {
946                 return static_cast<int64_t>(dim.padding_low());
947               }));
948     // LHS dilation is encoded in base_dilation of the backend config.
949     // RHS dilation is encoded in window_dilation of the backend config.
950     attrs.set(op.getLhsDilationAttrName(),
951               GetWindowElements(window, [](const xla::WindowDimension& dim) {
952                 return static_cast<int64_t>(dim.base_dilation());
953               }));
954     attrs.set(op.getRhsDilationAttrName(),
955               GetWindowElements(window, [](const xla::WindowDimension& dim) {
956                 return static_cast<int64_t>(dim.window_dilation());
957               }));
958     // Setup window reversal.
959     auto window_reversal = llvm::to_vector<4>(llvm::map_range(
960         window.dimensions(),
961         [](const xla::WindowDimension& dim) { return dim.window_reversal(); }));
962     auto type = RankedTensorType::get(window_strides.getType().getShape(),
963                                       builder_.getIntegerType(/*width=*/1));
964     attrs.set(op.getWindowReversalAttrName(),
965               DenseElementsAttr::get(type, window_reversal));
966 
967     attrs.set(op.getDimensionNumbersAttrName(),
968               xla::ConvertConvDimensionNumbers(
969                   custom_call->convolution_dimension_numbers(), &builder_));
970     attrs.set(op.getFeatureGroupCountAttrName(),
971               builder_.getI64IntegerAttr(custom_call->feature_group_count()));
972     attrs.set(op.getBatchGroupCountAttrName(),
973               builder_.getI64IntegerAttr(custom_call->batch_group_count()));
974     attrs.set(op.getPrecisionConfigAttrName(),
975               xla::ConvertPrecisionConfig(&custom_call->precision_config(),
976                                           &builder_));
977     attrs.set(op.getResultScaleAttrName(),
978               builder_.getF64FloatAttr(backend_config.conv_result_scale()));
979 
980     const auto& algorithm = backend_config.algorithm();
981     std::vector<int64_t> knob_ids;
982     std::vector<int64_t> knob_values;
983     for (const auto& entry : algorithm.tuning_knobs()) {
984       knob_ids.push_back(entry.first);
985       knob_values.push_back(entry.second);
986     }
987 
988     auto config = mlir::lmhlo_gpu::ConvolutionBackendConfigAttr::get(
989         builder_.getContext(), algorithm.algo_id(),
990 
991         algorithm.math_type() ==
992             stream_executor::dnn::AlgorithmProto::TENSOR_OP_MATH,
993         knob_ids, knob_values, algorithm.is_cudnn_frontend(),
994         algorithm.has_workspace_size() ? algorithm.workspace_size().value()
995                                        : -1,
996         get_layout_attribute(custom_call->operand(0)->shape().layout()),
997         get_layout_attribute(custom_call->operand(1)->shape().layout()),
998         get_layout_attribute(custom_call->shape().tuple_shapes(0).layout()));
999     attrs.set(op.getBackendConfigAttrName(), config);
1000     op->setAttrs(attrs.getDictionary(op->getContext()));
1001 
1002     return op.getOperation();
1003   };
1004 
1005   auto set_activation = [&, this](auto op) -> Status {
1006     auto se_activation = static_cast<stream_executor::dnn::ActivationMode>(
1007         backend_config.activation_mode());
1008     TF_ASSIGN_OR_RETURN(mlir::lmhlo_gpu::Activation activation,
1009                         GetLHLOActivation(se_activation));
1010     auto activation_attr = ::mlir::lmhlo_gpu::ActivationAttr::get(
1011         getLocation(custom_call).getContext(), activation);
1012     op.setActivationModeAttr(activation_attr);
1013     return ::tensorflow::OkStatus();
1014   };
1015 
1016   switch (kind) {
1017     case xla::gpu::CudnnConvKind::kForward: {
1018       TF_ASSIGN_OR_RETURN(
1019           auto cnn_forward,
1020           CreateOpWithoutAttrs<lmhlo_gpu::ConvForwardOp>(custom_call));
1021       return set_common_conv_attributes(cnn_forward);
1022     }
1023     case xla::gpu::CudnnConvKind::kBackwardInput: {
1024       TF_ASSIGN_OR_RETURN(
1025           auto cnn_backward,
1026           CreateOpWithoutAttrs<lmhlo_gpu::ConvBackwardInputOp>(custom_call));
1027       return set_common_conv_attributes(cnn_backward);
1028     }
1029     case xla::gpu::CudnnConvKind::kBackwardFilter: {
1030       TF_ASSIGN_OR_RETURN(
1031           auto cnn_backward,
1032           CreateOpWithoutAttrs<lmhlo_gpu::ConvBackwardFilterOp>(custom_call));
1033       return set_common_conv_attributes(cnn_backward);
1034     }
1035     case xla::gpu::CudnnConvKind::kForwardActivation: {
1036       // Fused conv can be either with side input or without.
1037       if (custom_call->operand_count() == 3) {
1038         TF_ASSIGN_OR_RETURN(
1039             auto cnn_fused,
1040             CreateOpWithoutAttrs<lmhlo_gpu::ConvForwardFusedOp>(custom_call));
1041         TF_RETURN_IF_ERROR(set_activation(cnn_fused));
1042         return set_common_conv_attributes(cnn_fused);
1043       }
1044 
1045       TF_RET_CHECK(custom_call->operand_count() == 4);
1046       TF_ASSIGN_OR_RETURN(
1047           auto cnn_fused_side_input,
1048           CreateOpWithoutAttrs<lmhlo_gpu::ConvForwardFusedSideInputOp>(
1049               custom_call));
1050       cnn_fused_side_input.setSideInputScaleAttr(
1051           builder_.getF64FloatAttr(backend_config.side_input_scale()));
1052       TF_RETURN_IF_ERROR(set_activation(cnn_fused_side_input));
1053       return set_common_conv_attributes(cnn_fused_side_input);
1054     }
1055   }
1056 }
1057 
1058 // Convert an XLA HLO constant to a memref.global + memref.get_global pair.
1059 StatusOr<mlir::memref::GetGlobalOp> LhloDialectEmitter::EmitConstant(
1060     const HloInstruction* instr) {
1061   auto& cached_value = slices_[std::make_pair(instr, xla::ShapeIndex())];
1062   if (cached_value) {
1063     return dyn_cast<mlir::memref::GetGlobalOp>(cached_value.getDefiningOp());
1064   }
1065 
1066   // Insert a global_memref in the module.
1067   Location loc = getLocation(instr);
1068 
1069   auto const_instr = xla::Cast<xla::HloConstantInstruction>(instr);
1070   TF_RET_CHECK(const_instr->shape().IsArray() &&
1071                const_instr->shape().is_static());
1072   TF_ASSIGN_OR_RETURN(Type type, xla::ConvertShapeToType<MemRefType>(
1073                                      const_instr->shape(), builder_));
1074   auto memref_type = type.dyn_cast<MemRefType>();
1075   TF_RET_CHECK(memref_type != nullptr);
1076 
1077   TF_ASSIGN_OR_RETURN(
1078       DenseElementsAttr initial_value,
1079       CreateDenseElementsAttrFromLiteral(const_instr->literal(), builder_));
1080 
1081   std::string constant_name = xla::llvm_ir::ConstantNameToGlobalName(
1082       xla::llvm_ir::SanitizeConstantName(instr->name()));
1083 
1084   // Insert the global memref at the top level.
1085   {
1086     OpBuilder::InsertionGuard guard(builder_);
1087     builder_.clearInsertionPoint();
1088     auto global_var = builder_.create<memref::GlobalOp>(
1089         loc, constant_name, builder_.getStringAttr("private"), memref_type,
1090         initial_value, true, /*alignment=*/IntegerAttr());
1091     SymbolTable(module_).insert(global_var);
1092     global_var.getOperation()->moveBefore(&module_.front());
1093 
1094     // For operations that do not fold this constant value in their codegen, we
1095     // still need to materialize it into a buffer. Since buffer allocation is
1096     // already done, annotate the global_memref with the information to get to
1097     // the allocated buffer slice for this constant if need be.
1098     TF_ASSIGN_OR_RETURN(BufferAllocation::Slice slice,
1099                         assignment_.GetUniqueTopLevelSlice(instr));
1100     global_var->setAttr(
1101         "lmhlo.alloc",
1102         builder_.getIndexAttr(allocations_.find(slice.allocation())
1103                                   ->second.cast<BlockArgument>()
1104                                   .getArgNumber()));
1105     TF_RET_CHECK(slice.offset() == 0)
1106         << "Each constant should have its own allocation from BufferAssignment";
1107     TF_RET_CHECK(slice.allocation()->size() == slice.size())
1108         << "Each constant should have its own allocation from BufferAssignment";
1109   }
1110 
1111   auto get_global_memref =
1112       builder_.create<memref::GetGlobalOp>(loc, memref_type, constant_name);
1113 
1114   // Update the cache to remember this value.
1115   cached_value = get_global_memref;
1116   return get_global_memref;
1117 }
1118 
1119 namespace {
1120 template <typename OpT>
1121 void SetupChannelIdAttribute(OpT op, const xla::HloChannelInstruction* instr,
1122                              mlir::Builder builder) {
1123   if (instr->channel_id().has_value()) {
1124     op.setChannelIdAttr(mlir::mhlo::ChannelHandleAttr::get(
1125         builder.getContext(), *instr->channel_id(), 0));
1126   }
1127 }
1128 
1129 template <typename OpT>
1130 Status SetupCommonCollectiveOpAttributes(OpT op, const HloInstruction* instr,
1131                                          mlir::OpBuilder& builder) {
1132   auto* collective = xla::Cast<xla::HloCollectiveInstruction>(instr);
1133   auto replica_groups_attr = xla::HloFunctionImporter::ConvertReplicaGroups(
1134       collective->replica_groups(), &builder);
1135   op->setAttr(replica_groups_attr.getName(), replica_groups_attr.getValue());
1136   op.setConstrainLayoutAttr(
1137       builder.getBoolAttr(collective->constrain_layout()));
1138   SetupChannelIdAttribute(op, collective, builder);
1139   return ::tensorflow::OkStatus();
1140 }
1141 }  // namespace
1142 
1143 StatusOr<lmhlo::AllToAllOp> LhloDialectEmitter::EmitAllToAllOp(
1144     const HloInstruction* instr) {
1145   TF_ASSIGN_OR_RETURN(auto all_to_all_op,
1146                       CreateOpWithoutAttrs<lmhlo::AllToAllOp>(instr));
1147   auto* all_to_all = xla::Cast<xla::HloAllToAllInstruction>(instr);
1148   TF_RETURN_IF_ERROR(
1149       SetupCommonCollectiveOpAttributes(all_to_all_op, instr, builder_));
1150   if (all_to_all->split_dimension().has_value()) {
1151     all_to_all_op.setSplitDimensionAttr(
1152         builder_.getI64IntegerAttr(*all_to_all->split_dimension()));
1153   }
1154   return all_to_all_op;
1155 }
1156 
1157 StatusOr<lmhlo::AllGatherOp> LhloDialectEmitter::EmitAllGatherOp(
1158     const HloInstruction* instr) {
1159   TF_ASSIGN_OR_RETURN(auto all_gather_op,
1160                       CreateOpWithoutAttrs<lmhlo::AllGatherOp>(instr));
1161   auto* all_gather = xla::Cast<xla::HloAllGatherInstruction>(instr);
1162   TF_RETURN_IF_ERROR(
1163       SetupCommonCollectiveOpAttributes(all_gather_op, instr, builder_));
1164   all_gather_op.setUseGlobalDeviceIdsAttr(
1165       builder_.getBoolAttr(all_gather->use_global_device_ids()));
1166   all_gather_op.setAllGatherDimensionAttr(
1167       builder_.getI64IntegerAttr(all_gather->all_gather_dimension()));
1168   return all_gather_op;
1169 }
1170 
1171 StatusOr<lmhlo::AllReduceOp> LhloDialectEmitter::EmitAllReduceOp(
1172     const HloInstruction* instr) {
1173   TF_ASSIGN_OR_RETURN(auto all_reduce_op,
1174                       CreateOpWithoutAttrs<lmhlo::AllReduceOp>(instr));
1175   auto* all_reduce = xla::Cast<xla::HloAllReduceInstruction>(instr);
1176   TF_RETURN_IF_ERROR(
1177       SetupCommonCollectiveOpAttributes(all_reduce_op, instr, builder_));
1178   all_reduce_op.setUseGlobalDeviceIdsAttr(
1179       builder_.getBoolAttr(all_reduce->use_global_device_ids()));
1180   TF_RETURN_IF_ERROR(xla::HloFunctionImporter::ImportAsRegion(
1181       *instr->called_computations()[0], &all_reduce_op.getComputation(),
1182       &builder_));
1183   return all_reduce_op;
1184 }
1185 
1186 StatusOr<lmhlo_gpu::AllReduceStartOp> LhloDialectEmitter::EmitAllReduceStartOp(
1187     const HloInstruction* instr) {
1188   llvm::SmallVector<Value, 4> operands;
1189   for (const HloInstruction* operand : instr->operands()) {
1190     TF_RETURN_IF_ERROR(GetOrCreateView(operand, &operands));
1191   }
1192   TF_RETURN_IF_ERROR(GetOrCreateView(instr, &operands, /*result_subset=*/{}));
1193 
1194   Location loc = getLocation(instr);
1195   mlir::Type token_type = mlir::mhlo::TokenType::get(builder_.getContext());
1196   std::array<mlir::Type, 1> result_types = {token_type};
1197   lmhlo_gpu::AllReduceStartOp all_reduce_start_op =
1198       builder_.create<lmhlo_gpu::AllReduceStartOp>(loc, result_types, operands);
1199 
1200   auto* all_reduce = xla::Cast<xla::HloAllReduceInstruction>(instr);
1201   TF_RETURN_IF_ERROR(
1202       SetupCommonCollectiveOpAttributes(all_reduce_start_op, instr, builder_));
1203   all_reduce_start_op.setUseGlobalDeviceIdsAttr(
1204       builder_.getBoolAttr(all_reduce->use_global_device_ids()));
1205   TF_RETURN_IF_ERROR(xla::HloFunctionImporter::ImportAsRegion(
1206       *instr->called_computations()[0], &all_reduce_start_op.getComputation(),
1207       &builder_));
1208 
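  // Remember the start op so that the matching all-reduce-done can consume its
  // token result (see EmitAllReduceDoneOp).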
1209   TF_RET_CHECK(all_reduce_start_ops_.emplace(instr, all_reduce_start_op).second)
1210       << "all-reduce-start already lowered";
1211   return all_reduce_start_op;
1212 }
1213 
1214 StatusOr<lmhlo_gpu::AllReduceDoneOp> LhloDialectEmitter::EmitAllReduceDoneOp(
1215     const HloInstruction* instr) {
1216   auto it = all_reduce_start_ops_.find(instr->operand(0));
1217   TF_RET_CHECK(it != all_reduce_start_ops_.end())
1218       << "didn't find all-reduce-start op";
1219 
1220   llvm::SmallVector<Value, 4> operands;
1221   operands.push_back(it->second.getToken());
1222   all_reduce_start_ops_.erase(it);
1223 
1224   for (const HloInstruction* operand : instr->operands()) {
1225     TF_RETURN_IF_ERROR(GetOrCreateView(operand, &operands));
1226   }
1227   // We don't need to add buffers for the outputs, as these always alias inputs.
1228   return builder_.create<lmhlo_gpu::AllReduceDoneOp>(
1229       getLocation(instr), /*resultTypes=*/llvm::None, operands);
1230 }
1231 
1232 StatusOr<lmhlo::ReduceScatterOp> LhloDialectEmitter::EmitReduceScatterOp(
1233     const HloInstruction* instr) {
1234   TF_ASSIGN_OR_RETURN(auto reduce_scatter_op,
1235                       CreateOpWithoutAttrs<lmhlo::ReduceScatterOp>(instr));
1236   auto* ars = xla::Cast<xla::HloReduceScatterInstruction>(instr);
1237   TF_RETURN_IF_ERROR(
1238       SetupCommonCollectiveOpAttributes(reduce_scatter_op, instr, builder_));
1239   reduce_scatter_op.setUseGlobalDeviceIdsAttr(
1240       builder_.getBoolAttr(ars->use_global_device_ids()));
1241   TF_RETURN_IF_ERROR(xla::HloFunctionImporter::ImportAsRegion(
1242       *instr->called_computations()[0], &reduce_scatter_op.getComputation(),
1243       &builder_));
1244   reduce_scatter_op.setScatterDimensionAttr(
1245       builder_.getI64IntegerAttr(ars->scatter_dimension()));
1246   return reduce_scatter_op;
1247 }
1248 
1249 StatusOr<lmhlo::CollectivePermuteOp>
1250 LhloDialectEmitter::EmitCollectivePermuteOp(const HloInstruction* instr) {
1251   TF_ASSIGN_OR_RETURN(auto permute_op,
1252                       CreateOpWithoutAttrs<lmhlo::CollectivePermuteOp>(instr));
1253   auto* permute = xla::Cast<xla::HloCollectivePermuteInstruction>(instr);
1254   SetupChannelIdAttribute(permute_op, permute, builder_);
1255   mlir::NamedAttribute source_target_pairs_attr =
1256       xla::HloFunctionImporter::ConvertSourceTargetPairs(
1257           permute->source_target_pairs(), &builder_);
1258   permute_op->setAttr(source_target_pairs_attr.getName(),
1259                       source_target_pairs_attr.getValue());
1260   return permute_op;
1261 }
1262 
1263 StatusOr<lmhlo::InfeedOp> LhloDialectEmitter::EmitInfeedOp(
1264     const HloInstruction* instr) {
1265   const HloInfeedInstruction* infeed = xla::Cast<HloInfeedInstruction>(instr);
1266   // HLO Infeed instruction has a single operand of token type and a tuple
1267   // with buffers and a token as its output. LMHLO Infeed operation does not
1268   // need the token operand or result, so drop it.
1269   SmallVector<Value, 2> operands;
1270   TF_RETURN_IF_ERROR(GetOrCreateView(instr, &operands, /*result_subset=*/{0}));
1271   auto infeed_op = CreateOpWithoutAttrs<lmhlo::InfeedOp>(instr, operands);
1272   infeed_op.setConfigAttr(builder_.getStringAttr(infeed->infeed_config()));
1273   return infeed_op;
1274 }
1275 
1276 StatusOr<lmhlo::OutfeedOp> LhloDialectEmitter::EmitOutfeedOp(
1277     const HloInstruction* instr) {
1278   const HloOutfeedInstruction* outfeed =
1279       xla::Cast<HloOutfeedInstruction>(instr);
1280   // HLO outfeed instruction has 2 operands, the source and a token, and a
1281   // single token output. LMHLO Outfeed does not need the token operand or
1282   // result, so drop it.
1283   SmallVector<Value, 2> operands;
1284   TF_RETURN_IF_ERROR(GetOrCreateView(instr->operand(0), &operands));
1285   auto outfeed_op = CreateOpWithoutAttrs<lmhlo::OutfeedOp>(instr, operands);
1286   outfeed_op.setConfigAttr(builder_.getStringAttr(outfeed->outfeed_config()));
1287   return outfeed_op;
1288 }
1289 
1290 xla::StatusOr<lmhlo::RngGetAndUpdateStateOp>
1291 LhloDialectEmitter::EmitRngGetAndUpdateStateOp(
1292     const xla::HloInstruction* instr) {
1293   TF_ASSIGN_OR_RETURN(
1294       auto rng, CreateOpWithoutAttrs<lmhlo::RngGetAndUpdateStateOp>(instr));
1295   auto hlo_rng = xla::Cast<xla::HloRngGetAndUpdateStateInstruction>(instr);
1296   rng.setDeltaAttr(builder_.getI64IntegerAttr(hlo_rng->delta()));
1297   return rng;
1298 }
1299 
1300 xla::StatusOr<lmhlo::FftOp> LhloDialectEmitter::EmitFftOp(
1301     const HloInstruction* instr) {
1302   auto hlo_fft = xla::Cast<xla::HloFftInstruction>(instr);
1303   TF_ASSIGN_OR_RETURN(auto fft, CreateOpWithoutAttrs<lmhlo::FftOp>(instr));
1304   TF_ASSIGN_OR_RETURN(mlir::mhlo::FftType fft_type,
1305                       xla::ConvertFftType(hlo_fft->fft_type()));
1306   fft.setFftTypeAttr(
1307       mlir::mhlo::FftTypeAttr::get(builder_.getContext(), fft_type));
1308   fft.setFftLengthAttr(GetI64DenseElementsAttr(instr->fft_length()));
1309   return fft;
1310 }
1311 
1312 xla::StatusOr<lmhlo::TriangularSolveOp>
1313 LhloDialectEmitter::EmitTriangularSolveOp(const xla::HloInstruction* instr) {
1314   auto hlo_triangular_solve =
1315       xla::Cast<xla::HloTriangularSolveInstruction>(instr);
1316   TF_ASSIGN_OR_RETURN(auto triangular_solve,
1317                       CreateOpWithoutAttrs<lmhlo::TriangularSolveOp>(instr));
1318   const xla::TriangularSolveOptions& options =
1319       hlo_triangular_solve->triangular_solve_options();
1320   triangular_solve.setLeftSideAttr(builder_.getBoolAttr(options.left_side()));
1321   triangular_solve.setLowerAttr(builder_.getBoolAttr(options.lower()));
1322   triangular_solve.setUnitDiagonalAttr(
1323       builder_.getBoolAttr(options.unit_diagonal()));
1324   TF_ASSIGN_OR_RETURN(mlir::mhlo::Transpose transpose,
1325                       xla::ConvertTranspose(options.transpose_a()));
1326   triangular_solve.setTransposeAAttr(
1327       mlir::mhlo::TransposeAttr::get(builder_.getContext(), transpose));
1328   triangular_solve.setLayoutAAttr(
1329       GetLayoutAttribute(instr->operand(0)->shape().layout(), &builder_));
1330   triangular_solve.setLayoutBAttr(
1331       GetLayoutAttribute(instr->operand(1)->shape().layout(), &builder_));
1332   triangular_solve.setLayoutOutputAttr(
1333       GetLayoutAttribute(instr->shape().layout(), &builder_));
1334   return triangular_solve;
1335 }
1336 
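// A bitcast does not produce an LMHLO op: buffer assignment is expected to
// place its input and output in the same slice, which is verified here, and
// nullptr is returned to signal that nothing was emitted.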
1337 xla::StatusOr<Operation*> LhloDialectEmitter::EmitBitcast(
1338     const xla::HloInstruction* instr) {
1339   // XLA buffer assignment should assign the same slice to a bitcast input and
1340   // output.
1341   const xla::ShapeIndex top_index;
1342   TF_ASSIGN_OR_RETURN(BufferAllocation::Slice result_slice,
1343                       assignment_.GetUniqueSlice(instr, top_index));
1344   TF_ASSIGN_OR_RETURN(BufferAllocation::Slice input_slice,
1345                       assignment_.GetUniqueSlice(instr->operand(0), top_index));
1346 
1347   if (input_slice != result_slice) {
1348     return xla::InvalidArgument(
1349         "Bitcast input and result slice should be the same");
1350   }
1351   return nullptr;
1352 }
1353 
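// Converts an XLA layout's minor_to_major ordering into an index tensor
// attribute. For example (illustrative), a row-major 2D layout with
// minor_to_major = {1, 0} becomes the attribute dense<[1, 0]>.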
1354 mlir::DenseIntElementsAttr LhloDialectEmitter::GetLayoutAttribute(
1355     const xla::Layout& layout, Builder* builder) {
1356   llvm::SmallVector<int64_t, 4> minor_to_major(layout.minor_to_major().begin(),
1357                                                layout.minor_to_major().end());
1358   return builder->getIndexTensorAttr(minor_to_major);
1359 }
1360 
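// Imports `computation` into `region` as LMHLO: the builder is temporarily
// redirected into the region (and restored on exit by the cleanup), the
// instructions are visited in the sequential order recorded by the buffer
// assignment's HLO ordering, and the region is closed with an lmhlo
// terminator.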
1361 Status LhloDialectEmitter::ImportAsLmhloRegion(xla::HloComputation* computation,
1362                                                mlir::Region* region) {
1363   auto after = builder_.saveInsertionPoint();
1364   auto reverter = absl::MakeCleanup(
1365       [this, after] { builder_.restoreInsertionPoint(after); });
1366 
1367   builder_ = OpBuilder(region);
1368   const xla::HloInstructionSequence* schedule =
1369       assignment_.hlo_ordering().SequentialOrder(*computation);
1370   if (!schedule)
1371     return xla::Unimplemented("Missing sequential order for the computation");
1372   TF_RETURN_IF_ERROR(
1373       computation->AcceptOrdered(this, schedule->instructions()));
1374   builder_.create<lmhlo::TerminatorOp>(builder_.getUnknownLoc());
1375   return ::tensorflow::OkStatus();
1376 }
1377 
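// Lowers kCase: operands[0] is the buffer holding the branch index (or
// predicate), one region is created per branch, and each branch computation
// is imported with ImportAsLmhloRegion.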
1378 StatusOr<lmhlo::CaseOp> LhloDialectEmitter::EmitCaseOp(
1379     const HloInstruction* instr) {
1380   Location loc = getLocation(instr);
1381   llvm::SmallVector<Value, 4> operands;
1382   size_t num_arguments, num_results;
1383   TF_RETURN_IF_ERROR(CreateOperands(instr, 1, TokenLoweringMode::kUseNull,
1384                                     operands, num_arguments, num_results));
1385 
1386   auto case_op =
1387       builder_.create<lmhlo::CaseOp>(loc, operands[0], instr->branch_count());
1388 
1389   for (int i = 0; i < instr->branch_count(); i++) {
1390     case_op.getBranches()[i].push_back(new mlir::Block());
1391     TF_RETURN_IF_ERROR(ImportAsLmhloRegion(instr->called_computations()[i],
1392                                            &case_op.getBranches()[i]));
1393   }
1394 
1395   return case_op;
1396 }
1397 
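// Lowers kWhile. The single operand is the buffer written by the condition
// computation's root (called_computations()[1]); the condition and body
// computations are imported into the op's two regions, and a trip count
// attribute is attached when the backend config knows it.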
1398 xla::StatusOr<lmhlo::WhileOp> LhloDialectEmitter::EmitWhileOp(
1399     const xla::HloInstruction* instr) {
1400   Location loc = getLocation(instr);
1401   SmallVector<Value, 1> operands;
1402   TF_RETURN_IF_ERROR(GetOrCreateView(
1403       instr->called_computations()[1]->root_instruction(), &operands));
1404   TF_RET_CHECK(operands.size() == 1);
1405 
1406   TF_ASSIGN_OR_RETURN(auto config,
1407                       instr->backend_config<xla::WhileLoopBackendConfig>());
1408   mlir::IntegerAttr trip_count;
1409   if (config.has_known_trip_count()) {
1410     trip_count = builder_.getI64IntegerAttr(config.known_trip_count().n());
1411   }
1412   lmhlo::WhileOp while_op =
1413       builder_.create<lmhlo::WhileOp>(loc, operands[0], trip_count);
1414 
1415   while_op.getCond().push_back(new mlir::Block());
1416   while_op.getBody().push_back(new mlir::Block());
1417   TF_RETURN_IF_ERROR(ImportAsLmhloRegion(instr->called_computations()[1],
1418                                          &while_op.getCond()));
1419 
1420   TF_RETURN_IF_ERROR(ImportAsLmhloRegion(instr->called_computations()[0],
1421                                          &while_op.getBody()));
1422 
1423   return while_op;
1424 }
1425 
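// Returns (and caches) a memref view for the array at `shape_index` inside
// `instr`'s result. Constants are delegated to EmitConstant; everything else
// becomes a memref.view into the underlying allocation at the slice's byte
// offset, optionally followed by a memref.reinterpret_cast when the logical
// layout differs from the physical one. A rough, illustrative sketch of the
// emitted IR (names, sizes, and types below are made up for illustration):
//   %c256 = arith.constant 256 : index
//   %view = memref.view %alloc[%c256][] : memref<1024xi8> to memref<2x3xf32>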
1426 StatusOr<Value> LhloDialectEmitter::GetOrCreateArrayView(
1427     const xla::HloInstruction* instr, const xla::Shape& current_shape,
1428     const xla::ShapeIndex& shape_index) {
1429   // For constants, the cache is managed inside EmitConstant since it can
1430   // be called either from here or when we see a top-level HloConstant instr.
1431   if (instr->IsConstant() && shape_index.empty()) {
1432     TF_ASSIGN_OR_RETURN(Value constant_memref, EmitConstant(instr));
1433     return constant_memref;
1434   }
1435 
1436   // Cache generated ViewOp and StaticMemRefCastOp by (instruction,
1437   // shape_index).
1438   auto& cached_value = slices_[std::make_pair(instr, shape_index)];
1439   if (cached_value) {
1440     return cached_value;
1441   }
1442 
1443   // If the shape happens to have dynamic dimensions, create the memref using
1444   // the underlying static shape.
1445   // TODO(jurahul): Revisit this when we can model memrefs with dynamic shape
1446   // but static bounds in MLIR.
1447   const Shape static_shape = xla::ShapeUtil::MakeStaticShape(current_shape);
1448 
1449   TF_ASSIGN_OR_RETURN(Type out_type, xla::ConvertShapeToType<MemRefType>(
1450                                          static_shape, builder_));
1451   TF_ASSIGN_OR_RETURN(BufferAllocation::Slice slice,
1452                       assignment_.GetUniqueSlice(instr, shape_index));
1453   Value alloc = allocations_[slice.allocation()];
1454 
1455   // TODO(timshen): revisit location handling.
1456   Location loc = builder_.getUnknownLoc();
1457 
1458   Value byte_shift =
1459       builder_.create<arith::ConstantIndexOp>(alloc.getLoc(), slice.offset());
1460 
1461   xla::Shape physical_shape =
1462       xla::ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout(
1463           static_shape);
1464   TF_ASSIGN_OR_RETURN(
1465       Type physical_out_type,
1466       xla::ConvertShapeToType<MemRefType>(physical_shape, builder_));
1467 
1468   // ViewOp only takes memrefs without affine maps (layouts). Let ViewOp
1469   // produce the physical shape (where dimensions are ordered in major to
1470   // minor) first, then follow up with a MemRefReinterpretCast to cast the
1471   // resulting memref to the original layout.
1472   Value result =
1473       builder_.create<memref::ViewOp>(loc, physical_out_type, alloc, byte_shift,
1474                                       /*sizes=*/ValueRange{});
1475   if (result.getType() != out_type) {
1476     int64_t out_offset;
1477     SmallVector<int64_t, 4> out_strides;
1478     auto out_memref_type = out_type.dyn_cast<MemRefType>();
1479     if (!out_memref_type)
1480       return tensorflow::errors::Internal(
1481           "Expected memref type when creating a view for leaf type of a "
1482           "tuple.");
1483     if (failed(getStridesAndOffset(out_memref_type, out_strides, out_offset)))
1484       return tensorflow::errors::Internal(
1485           "Failed to get strides and offset from the output type.");
1486     result = builder_.create<memref::ReinterpretCastOp>(
1487         loc, out_memref_type, result, out_offset, out_memref_type.getShape(),
1488         out_strides);
1489   }
1490   return cached_value = result;
1491 }
1492 
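// Recursive helper for GetOrCreateView: tuples are flattened element by
// element, arrays are turned into buffer views via GetOrCreateArrayView, and
// tokens either fail the lowering or produce a null Value depending on
// `token_mode`.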
1493 Status LhloDialectEmitter::GetOrCreateViewImpl(
1494     const HloInstruction* instr, const Shape& current_shape,
1495     xla::ShapeIndex* current_shape_index, SmallVectorImpl<Value>* values,
1496     TokenLoweringMode token_mode) {
1497   if (current_shape.IsTuple()) {
1498     for (int i = 0; i < current_shape.tuple_shapes().size(); ++i) {
1499       current_shape_index->push_back(i);
1500       TF_RETURN_IF_ERROR(
1501           GetOrCreateViewImpl(instr, current_shape.tuple_shapes(i),
1502                               current_shape_index, values, token_mode));
1503       current_shape_index->pop_back();
1504     }
1505     return ::tensorflow::OkStatus();
1506   }
1507   if (current_shape.IsArray()) {
1508     TF_ASSIGN_OR_RETURN(auto v, GetOrCreateArrayView(instr, current_shape,
1509                                                      *current_shape_index));
1510     values->push_back(v);
1511     return ::tensorflow::OkStatus();
1512   }
1513   if (current_shape.IsToken()) {
1514     switch (token_mode) {
1515       case TokenLoweringMode::kFailToLower:
1516         return xla::InternalError(
1517             "Unexpected token kind for %s and shape index %s",
1518             instr->ToString(), current_shape_index->ToString());
1519 
1520       case TokenLoweringMode::kUseNull:
1521         values->push_back(Value{});
1522         return ::tensorflow::OkStatus();
1523     }
1524   }
1525   return xla::InternalError("Unexpected shape kind for %s and shape index %s",
1526                             instr->ToString(), current_shape_index->ToString());
1527 }
1528 
1529 // Returns a view for the result of an instruction.
1530 // We first get a view for the slice in the allocation, and then may need to
1531 // create another view to adjust the slice for the shape of the instruction.
1532 Status LhloDialectEmitter::GetOrCreateView(const HloInstruction* instr,
1533                                            SmallVectorImpl<Value>* values,
1534                                            const xla::ShapeIndex& result_subset,
1535                                            TokenLoweringMode token_mode) {
1536   xla::ShapeIndex shape_index = result_subset;
1537   const Shape& sub_shape =
1538       xla::ShapeUtil::GetSubshape(instr->shape(), shape_index);
1539   return GetOrCreateViewImpl(instr, sub_shape, &shape_index, values,
1540                              token_mode);
1541 }
1542 
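// Builds the entry func.func for the module: one memref<...xi8> block argument
// per non-thread-local buffer allocation, with parameter allocations sorted
// first, and per-argument attributes (lmhlo.params, lmhlo.param_shape_index,
// lmhlo.constant_name, lmhlo.output_index, lmhlo.must_alias) describing how
// the buffers map to the XLA entry computation's inputs and outputs. The
// resulting signature looks roughly like this (illustrative only; names and
// sizes are made up):
//   func.func @main(%arg0: memref<16xi8> {lmhlo.params = 0 : index},
//                   %arg1: memref<64xi8> {lmhlo.output_index = dense<0> : tensor<1xi64>}) {
//     ...
//   }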
1543 Status LhloDialectEmitter::Initialize() {
1544   TF_RET_CHECK(computation_.IsEntryComputation());
1545 
1546   mlir::IntegerAttr unique_id =
1547       builder_.getI32IntegerAttr(computation_.parent()->unique_id());
1548   module_->setAttr("hlo.unique_id", unique_id);
1549   std::string function_name =
1550       computation_.name().empty() ? "__compute" : computation_.name();
1551 
1552   // Create the function as () -> (), we'll compute the arguments from the
1553   // buffer allocation and update the type then.
1554   auto func_op = func::FuncOp::create(builder_.getUnknownLoc(), function_name,
1555                                       builder_.getFunctionType({}, {}));
1556 
1557   {
1558     // This is an optional attribute used by the XLA backend. If the resulting
1559     // LMHLO doesn't go through XLA, this is not needed.
1560     const Shape& shape = computation_.root_instruction()->shape();
1561     func_op->setAttr(
1562         "result_xla_shape",
1563         builder_.getStringAttr(shape.ToString(/*print_layout=*/true)));
1564   }
1565   Block* block = func_op.addEntryBlock();
1566 
1567   llvm::SmallVector<const BufferAllocation*, 8> ordered_allocations;
1568   for (const BufferAllocation& alloc : assignment_.Allocations())
1569     ordered_allocations.push_back(&alloc);
1570 
1571   if (computation_.IsEntryComputation()) {
1572     // Sort the rather arbitrarily ordered allocations to match the input/output
1573     // parameters. Specifically we want to sort buffer allocations in the
1574     // following order:
1575     // * Parameters always order before non-parameters.
1576     // * Different parameters order by parameter number.
1577     // * Different allocations for the same parameter order by the shape index.
1578     //
1579     // TODO(timshen): there should be only one non-parameter buffer, the temp
1580     // buffer. Check on that.
1581     const auto allocation_comparator = [](const BufferAllocation* lhs,
1582                                           const BufferAllocation* rhs) {
1583       if (lhs->is_entry_computation_parameter() !=
1584           rhs->is_entry_computation_parameter()) {
1585         return lhs->is_entry_computation_parameter() >
1586                rhs->is_entry_computation_parameter();
1587       }
1588       if (lhs->is_entry_computation_parameter()) {
1589         return std::tuple<int, const xla::ShapeIndex&>(
1590                    lhs->parameter_number(), lhs->param_shape_index()) <
1591                std::tuple<int, const xla::ShapeIndex&>(
1592                    rhs->parameter_number(), rhs->param_shape_index());
1593       }
1594       return false;
1595     };
1596 
1597     std::stable_sort(ordered_allocations.begin(), ordered_allocations.end(),
1598                      allocation_comparator);
1599   }
1600 
1601   absl::flat_hash_map<const BufferAllocation*,
1602                       std::pair<const Shape*, xla::ShapeIndex>>
1603       allocation_to_output_info;
1604   TF_RETURN_IF_ERROR(xla::ShapeUtil::ForEachSubshapeWithStatus(
1605       computation_.root_instruction()->shape(),
1606       [&](const Shape& sub_shape, xla::ShapeIndex index) -> Status {
1607         TF_ASSIGN_OR_RETURN(
1608             auto slice,
1609             assignment_.GetUniqueSlice(computation_.root_instruction(), index));
1610         const BufferAllocation* alloc = slice.allocation();
1611         TF_RET_CHECK(slice.offset() == 0);
1612         TF_RET_CHECK(slice.size() == alloc->size());
1613         allocation_to_output_info[alloc] = std::make_pair(&sub_shape, index);
1614         return ::tensorflow::OkStatus();
1615       }));
1616 
1617   // The function signature will be composed of:
1618   // - one memref for each of the parameters.
1619   // - one memref for each other buffer allocation.
1620   llvm::SmallVector<DictionaryAttr, 8> args_attrs;
1621   for (const BufferAllocation* alloc : ordered_allocations) {
1622     if (alloc->is_thread_local()) {
1623       continue;
1624     }
1625 
1626     // There are optional attributes to help the program run through XLA. XLA
1627     // defines ExecutionInput and ExecutionOutput structures to carry
1628     // input-output type and buffer information, so any information they need
1629     // (mainly the type structure, potentially containing tuples) has to be
1630     // preserved. They are not needed if the generated LMHLO is not sent to XLA.
1631     NamedAttrList arg_attr_list;
1632     mlir::Type arg_type = MemRefType::get({alloc->size()}, i8_type_);
1633 
1634     // Propagate source location information for every HLOInstruction that
1635     // uses this allocation.
1636     std::vector<mlir::Location> buf_locs;
1637     buf_locs.reserve(alloc->assigned_buffers().size());
1638     for (const auto& entry : alloc->assigned_buffers()) {
1639       const xla::HloValue* hlo_value = entry.first;
1640       buf_locs.push_back(getLocation(hlo_value->instruction()));
1641     }
1642     mlir::Location loc = builder_.getFusedLoc(buf_locs);
1643 
1644     if (alloc->is_entry_computation_parameter()) {
1645       arg_attr_list.set("lmhlo.params",
1646                         builder_.getIndexAttr(alloc->parameter_number()));
1647       if (!alloc->param_shape_index().empty()) {
1648         arg_attr_list.set("lmhlo.param_shape_index",
1649                           builder_.getI64TensorAttr(llvm::makeArrayRef(
1650                               alloc->param_shape_index().begin(),
1651                               alloc->param_shape_index().end())));
1652       }
1653     }
1654     // Optional: an attribute for optimization. If a kernel uses this
1655     // allocation and the allocation carries lmhlo.constant_name, the kernel
1656     // will instead use the global value indicated by the name, which enables
1657     // further optimizations (e.g. constant propagation).
1658     if (alloc->is_constant()) {
1659       arg_attr_list.set(
1660           "lmhlo.constant_name",
1661           builder_.getStringAttr(
1662               xla::llvm_ir::ConstantBufferAllocationToGlobalName(*alloc)));
1663     }
1664     auto iter = allocation_to_output_info.find(alloc);
1665     if (iter != allocation_to_output_info.end()) {
1666       const Shape* sub_shape = iter->second.first;
1667       const xla::ShapeIndex& shape_index = iter->second.second;
1668       if (!sub_shape->IsArray()) {
1669         continue;
1670       }
1671       arg_attr_list.set("lmhlo.output_index",
1672                         builder_.getI64TensorAttr(llvm::makeArrayRef(
1673                             shape_index.begin(), shape_index.end())));
1674       if (auto alias = computation_.parent()
1675                            ->input_output_alias_config()
1676                            .GetAliasedParameter(shape_index)) {
1677         if (alias->must_alias()) {
1678           arg_attr_list.set("lmhlo.must_alias", builder_.getUnitAttr());
1679         }
1680       }
1681     }
1682     block->addArgument(arg_type, loc);
1683     allocations_[alloc] = block->getArguments().back();
1684     args_attrs.push_back(arg_attr_list.getDictionary(builder_.getContext()));
1685   }
1686 
1687   FunctionType function_type =
1688       builder_.getFunctionType(block->getArgumentTypes(), {});
1689   func_op.setType(function_type);
1690   func_op.setAllArgAttrs(args_attrs);
1691 
1692   SymbolTable symbol_table(module_);
1693   symbol_table.insert(func_op);
1694   builder_.setInsertionPointToEnd(block);
1695 
1696   auto return_op =
1697       builder_.create<lmhlo::TerminatorOp>(builder_.getUnknownLoc());
1698   builder_ = OpBuilder(return_op);
1699 
1700   return ::tensorflow::OkStatus();
1701 }
1702 
1703 std::unique_ptr<OperationPass<ModuleOp>> createXlaHloToLhloWithXlaPass() {
1704   return std::make_unique<XlaHloToLhloPass>();
1705 }
1706 
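// Entry point of the conversion: loads the dialects the lowering produces,
// tags the MLIR module with the HloModule's name and unique_id, and then runs
// LhloDialectEmitter over the sequential instruction order recorded in the
// buffer assignment. A minimal usage sketch (buffer_assignment and hlo_module
// are placeholders for an already scheduled module and its assignment):
//   mlir::MLIRContext context;
//   mlir::OwningOpRef<mlir::ModuleOp> m =
//       mlir::ModuleOp::create(mlir::UnknownLoc::get(&context));
//   TF_RETURN_IF_ERROR(HloToLhloModule(*buffer_assignment, *hlo_module, m.get()));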
1707 Status HloToLhloModule(const BufferAssignment& assignment,
1708                        const HloModule& hlo_module, ModuleOp module) {
1709   module.getContext()
1710       ->loadDialect<arith::ArithmeticDialect,
1711                     bufferization::BufferizationDialect, func::FuncDialect,
1712                     memref::MemRefDialect, mhlo::MhloDialect,
1713                     lmhlo::LmhloDialect, lmhlo_gpu::LmhloGpuDialect>();
1714 
1715   module->setLoc(mlir::NameLoc::get(
1716       mlir::StringAttr::get(module.getContext(), hlo_module.name())));
1717 
1718   // Store the HloModule's unique_id in the MLIR module.
1719   Builder builder(module.getContext());
1720   module->setAttr("mhlo.unique_id",
1721                   builder.getI64IntegerAttr(hlo_module.unique_id()));
1722 
1723   const HloComputation* computation = hlo_module.entry_computation();
1724 
1725   LhloDialectEmitter emitter(assignment, *computation, module);
1726   TF_RETURN_IF_ERROR(emitter.Initialize());
1727 
1728   const xla::HloInstructionSequence* schedule =
1729       assignment.hlo_ordering().SequentialOrder(*computation);
1730   if (!schedule)
1731     return xla::Unimplemented("Missing sequential order for the computation");
1732 
1733   StatusScopedDiagnosticHandler status_handler(module.getContext());
1734 
1735   const std::vector<HloInstruction*>& ordering = schedule->instructions();
1736   TF_RETURN_IF_ERROR(computation->AcceptOrdered(&emitter, ordering));
1737   TF_RETURN_IF_ERROR(status_handler.ConsumeStatus());
1738 
1739   (void)mlir::verify(module);
1740   return status_handler.ConsumeStatus();
1741 }
1742 
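// Parses HLO text into an (unverified) HloModule, creates an empty MLIR
// module, and converts it to LMHLO via OptimizeAndConvertHloToLmhlo targeting
// the "Host" platform; parse or conversion failures are fatal (TF_CHECK_OK).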
1743 OwningOpRef<mlir::ModuleOp> HloTextToLhloTranslateFunction(
1744     llvm::StringRef input, MLIRContext* context, bool optimize_xla_hlo) {
1745   StatusOr<std::unique_ptr<HloModule>> maybe_module =
1746       xla::ParseAndReturnUnverifiedModule(
1747           absl::string_view(input.data(), input.size()));
1748   TF_CHECK_OK(maybe_module.status());
1749 
1750   OwningOpRef<mlir::ModuleOp> module =
1751       ModuleOp::create(UnknownLoc::get(context));
1752 
1753   TF_CHECK_OK(OptimizeAndConvertHloToLmhlo(
1754       std::move(maybe_module).value(), module.get(), "Host", optimize_xla_hlo));
1755 
1756   return module;
1757 }
1758 
1759 void RegisterMhloToLhloWithXlaPass() {
1760   static PassRegistration<XlaHloToLhloPass> registration;
1761 }
1762 
1763 }  // namespace mlir
1764