1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/mlir/xla/transforms/mhlo_to_lhlo_with_xla.h"
17
18 #include <climits>
19 #include <memory>
20 #include <tuple>
21
22 #include "absl/algorithm/container.h"
23 #include "absl/cleanup/cleanup.h"
24 #include "absl/types/optional.h"
25 #include "llvm/ADT/STLExtras.h"
26 #include "llvm/ADT/SmallVector.h"
27 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" // from @llvm-project
28 #include "mlir/Dialect/Bufferization/IR/Bufferization.h" // from @llvm-project
29 #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project
30 #include "mlir/Dialect/MemRef/IR/MemRef.h" // from @llvm-project
31 #include "mlir/IR/AffineExpr.h" // from @llvm-project
32 #include "mlir/IR/AffineMap.h" // from @llvm-project
33 #include "mlir/IR/Attributes.h" // from @llvm-project
34 #include "mlir/IR/Builders.h" // from @llvm-project
35 #include "mlir/IR/BuiltinAttributes.h" // from @llvm-project
36 #include "mlir/IR/BuiltinOps.h" // from @llvm-project
37 #include "mlir/IR/BuiltinTypes.h" // from @llvm-project
38 #include "mlir/IR/Dialect.h" // from @llvm-project
39 #include "mlir/IR/Location.h" // from @llvm-project
40 #include "mlir/IR/MLIRContext.h" // from @llvm-project
41 #include "mlir/IR/OpDefinition.h" // from @llvm-project
42 #include "mlir/IR/Operation.h" // from @llvm-project
43 #include "mlir/IR/PatternMatch.h" // from @llvm-project
44 #include "mlir/IR/SymbolTable.h" // from @llvm-project
45 #include "mlir/IR/Verifier.h" // from @llvm-project
46 #include "mlir/Pass/Pass.h" // from @llvm-project
47 #include "mlir/Pass/PassOptions.h" // from @llvm-project
48 #include "mlir/Tools/mlir-translate/Translation.h" // from @llvm-project
49 #include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h"
50 #include "tensorflow/compiler/mlir/xla/attribute_importer.h"
51 #include "tensorflow/compiler/mlir/xla/hlo_function_importer.h"
52 #include "tensorflow/compiler/mlir/xla/hlo_utils.h"
53 #include "tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.h"
54 #include "tensorflow/compiler/mlir/xla/type_to_shape.h"
55 #include "tensorflow/compiler/xla/debug_options_flags.h"
56 #include "tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/lhlo/IR/lhlo_ops.h"
57 #include "tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/lhlo_gpu/IR/lhlo_gpu_ops.h"
58 #include "tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
59 #include "tensorflow/compiler/xla/service/backend.h"
60 #include "tensorflow/compiler/xla/service/buffer_assignment.h"
61 #include "tensorflow/compiler/xla/service/gpu/backend_configs.pb.h"
62 #include "tensorflow/compiler/xla/service/gpu/cublas_cudnn.h"
63 #include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h"
64 #include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
65 #include "tensorflow/compiler/xla/service/hlo_computation.h"
66 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
67 #include "tensorflow/compiler/xla/service/hlo_instructions.h"
68 #include "tensorflow/compiler/xla/service/hlo_module.h"
69 #include "tensorflow/compiler/xla/service/hlo_parser.h"
70 #include "tensorflow/compiler/xla/service/llvm_ir/buffer_assignment_util.h"
71 #include "tensorflow/compiler/xla/shape_util.h"
72 #include "tensorflow/compiler/xla/statusor.h"
73 #include "tensorflow/compiler/xla/util.h"
74 #include "tensorflow/compiler/xla/window_util.h"
75 #include "tensorflow/compiler/xla/xla_data.pb.h"
76
77 using xla::BufferAllocation;
78 using xla::BufferAssignment;
79 using xla::HloComputation;
80 using xla::HloCustomCallInstruction;
81 using xla::HloInfeedInstruction;
82 using xla::HloInstruction;
83 using xla::HloModule;
84 using xla::HloModuleProto;
85 using xla::HloOutfeedInstruction;
86 using xla::HloProto;
87 using xla::Shape;
88 using xla::StatusOr;
89
90 namespace mlir {
91 namespace {
92
93 absl::string_view StringRefToView(llvm::StringRef ref) {
94 return {ref.data(), ref.size()};
95 }
96
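// Deserializes an HloModule from the given HloProto, building the module
// config from the proto together with the debug options taken from flags.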
97 StatusOr<std::unique_ptr<HloModule>> HloModuleFromProto(
98 const HloProto& hlo_proto) {
99 const HloModuleProto& module_proto = hlo_proto.hlo_module();
100 TF_ASSIGN_OR_RETURN(const xla::HloModuleConfig module_config,
101 HloModule::CreateModuleConfigFromProto(
102 module_proto, xla::GetDebugOptionsFromFlags()));
103 return HloModule::CreateFromProto(module_proto, module_config);
104 }
105
106 } // namespace
107
108 // Convert the MLIR `module` from HLO dialect to LHLO dialect using XLA for the
109 // given platform.
110 Status OptimizeAndConvertHloToLmhlo(std::unique_ptr<HloModule> hlo_module,
111 ModuleOp module, StringRef platform_name,
112 bool optimize_xla_hlo) {
113 auto platform = xla::se::MultiPlatformManager::PlatformWithName(
114 StringRefToView(platform_name));
115 if (!platform.ok()) {
116 std::string error_msg;
117 llvm::raw_string_ostream os(error_msg);
118 os << "failed to get platform: " << platform.status().ToString()
119 << " (available Platform: ";
120 std::vector<std::string> available_platforms;
121 (void)xla::se::MultiPlatformManager::PlatformsWithFilter(
122 [&](const stream_executor::Platform* p) {
123 available_platforms.push_back(p->Name());
124 return false;
125 });
126 llvm::interleaveComma(available_platforms, os);
127 os << ")";
128 return xla::InvalidArgument("%s", os.str().c_str());
129 }
130
131 xla::BackendOptions backend_options;
132 backend_options.set_platform(platform.ValueOrDie());
133 auto backend_or_err = xla::Backend::CreateBackend(backend_options);
134 TF_RETURN_WITH_CONTEXT_IF_ERROR(backend_or_err.status(),
135 "failed to create XLA Backend ");
136 auto backend = std::move(backend_or_err.ValueOrDie());
137
138 StatusOr<std::unique_ptr<HloModule>> optimized_hlo_module;
139
140 if (optimize_xla_hlo) {
141 // Run all HLO passes to produce an optimized module.
142 optimized_hlo_module = backend->compiler()->RunHloPasses(
143 std::move(hlo_module), backend->default_stream_executor(),
144 backend->memory_allocator());
145 TF_RETURN_WITH_CONTEXT_IF_ERROR(optimized_hlo_module.status(),
146 "running XLA pass pipeline");
147 } else {
148 optimized_hlo_module = std::move(hlo_module);
149 }
150
151 StatusOr<std::unique_ptr<BufferAssignment>> assignment =
152 backend->compiler()->AssignBuffers(optimized_hlo_module->get());
153 TF_RETURN_WITH_CONTEXT_IF_ERROR(assignment.status(),
154 "running XLA buffer assigment");
155
156 // Clear the module before populating it back with the result of the
157 // conversion.
158 module.getBody()->clear();
159 OpBuilder builder(module);
160
161 TF_RETURN_WITH_CONTEXT_IF_ERROR(
162 HloToLhloModule(**assignment, **optimized_hlo_module, module),
163 "converting HLO to LHLO");
164
165 return ::tensorflow::OkStatus();
166 }
167
168 namespace {
169 // This pass takes an MLIR HLO module, converts it to XLA to perform the HLO
170 // optimization pipeline for the required platform, and then converts it back to
171 // MLIR LHLO.
172 class XlaHloToLhloPass
173 : public PassWrapper<XlaHloToLhloPass, OperationPass<ModuleOp>> {
174 void getDependentDialects(DialectRegistry& registry) const override {
175 registry
176 .insert<arith::ArithmeticDialect, bufferization::BufferizationDialect,
177 func::FuncDialect, memref::MemRefDialect, mhlo::MhloDialect,
178 lmhlo::LmhloDialect, lmhlo_gpu::LmhloGpuDialect>();
179 }
180
181 public:
182 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(XlaHloToLhloPass)
183
184 XlaHloToLhloPass() = default;
185 XlaHloToLhloPass(const XlaHloToLhloPass&) {}
186 StringRef getArgument() const final { return "xla-hlo-to-lhlo-with-xla"; }
187 StringRef getDescription() const final {
188 return "Emit LHLO from HLO using the existing XLA implementation";
189 }
190
191 private:
192 void runOnOperation() final {
193 ModuleOp module = getOperation();
194
195 auto status = [&module, this]() -> Status {
196 SymbolTable symbol_table(module);
197 if (!symbol_table.lookup("main")) {
198 return xla::InvalidArgument(
199 "conversion to HLO module failed: missing main()");
200 }
201 HloProto hlo_proto;
202 TF_RETURN_WITH_CONTEXT_IF_ERROR(
203 ConvertMlirHloToHlo(module, &hlo_proto,
204 /*use_tuple_args=*/false,
205 /*return_tuple=*/false),
206 "conversion to XLA HLO proto failed");
207
208 auto statusOrHloModule = HloModuleFromProto(hlo_proto);
209 TF_RETURN_WITH_CONTEXT_IF_ERROR(statusOrHloModule.status(),
210 "parsing HLO proto to HLO module failed");
211 std::unique_ptr<HloModule> hlo_module =
212 std::move(statusOrHloModule.ValueOrDie());
213
214 return OptimizeAndConvertHloToLmhlo(std::move(hlo_module), module,
215 platform_, optimize_xla_hlo_);
216 }();
217 if (!status.ok()) {
218 module.emitError() << status.ToString();
219 return signalPassFailure();
220 }
221 }
222
223 Option<std::string> platform_{
224 *this, "platform",
225 llvm::cl::desc("The platform to use for the XLA optimization pipeline."),
226 llvm::cl::init("Host")};
227 Option<bool> optimize_xla_hlo_{
228 *this, "optimize-xla-hlo",
229 llvm::cl::desc("Whether to apply HLO optimizations."),
230 llvm::cl::init(true)};
231 };
232
233 } // namespace
234
235 // Creates MLIR operands corresponding to operands and results of the XLA HLO
236 // instruction. If `num_operands` is valid, then only the first `num_operands`
237 // operands of the HLO instruction will be considered.
238 Status LhloDialectEmitter::CreateOperands(
239 const HloInstruction* instr, std::optional<int64_t> num_operands,
240 TokenLoweringMode token_mode, llvm::SmallVectorImpl<Value>& operands,
241 size_t& num_arguments, size_t& num_results) {
242 if (num_operands.value_or(0) > instr->operand_count())
243 return xla::InvalidArgument("num_operands must be <= operand count");
244 for (int64_t i = 0; i < num_operands.value_or(instr->operand_count()); ++i) {
245 TF_RETURN_IF_ERROR(GetOrCreateView(instr->operand(i), &operands,
246 /*result_subset=*/{}, token_mode));
247 }
248 num_arguments = operands.size();
249 TF_RETURN_IF_ERROR(
250 GetOrCreateView(instr, &operands, /*result_subset=*/{}, token_mode));
251 num_results = operands.size() - num_arguments;
252 return ::tensorflow::OkStatus();
253 }
254
255 template <typename OpType>
256 OpType LhloDialectEmitter::CreateOpWithoutAttrs(const HloInstruction* instr,
257 ValueRange operands) {
258 Location loc = getLocation(instr);
259 return builder_.create<OpType>(loc, llvm::None, operands,
260 llvm::ArrayRef<NamedAttribute>{});
261 }
262
263 template <typename OpType>
264 StatusOr<OpType> LhloDialectEmitter::CreateOpWithoutAttrs(
265 const HloInstruction* instr, size_t& num_arguments, size_t& num_results,
266 std::optional<int64_t> num_operands) {
267 llvm::SmallVector<Value, 4> operands;
268 TF_RETURN_IF_ERROR(CreateOperands(instr, num_operands,
269 TokenLoweringMode::kFailToLower, operands,
270 num_arguments, num_results));
271 return CreateOpWithoutAttrs<OpType>(instr, operands);
272 }
273
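// Wraps a single element-wise/data-movement HLO in an lmhlo.fusion region:
// operand buffers are loaded with bufferization.to_tensor, the instruction is
// imported at the tensor (mhlo) level, and each result is written back to its
// output buffer with memref.tensor_store.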
274 StatusOr<mlir::Operation*> LhloDialectEmitter::CreateOpInFusion(
275 const HloInstruction* instr, ValueRange buffer_operands,
276 size_t num_arguments, size_t num_results) {
277 Location loc = getLocation(instr);
278 std::vector<Value> buffers(buffer_operands.begin(), buffer_operands.end());
279 absl::Span<Value> arguments =
280 absl::MakeSpan(buffers).subspan(0, num_arguments);
281 absl::Span<Value> results =
282 absl::MakeSpan(buffers).subspan(num_arguments, num_results);
283
284 mlir::lmhlo::FusionOp fusion = builder_.create<mlir::lmhlo::FusionOp>(loc);
285 mlir::OpBuilder b(&fusion.getRegion());
286
287 llvm::SmallVector<mlir::Value, 4> loads;
288 for (Value arg : arguments) {
289 auto load = b.create<mlir::bufferization::ToTensorOp>(loc, arg);
290 Shape shape = xla::TypeToShape(arg.getType());
291 TF_RET_CHECK(shape.IsArray());
292 if (shape.layout() !=
293 xla::LayoutUtil::MakeDescendingLayout(shape.dimensions().size())) {
294 load->setAttr("xla_shape",
295 b.getStringAttr(shape.ToString(/*print_layout=*/true)));
296 }
297 loads.push_back(load);
298 }
299 mlir::Operation* op = nullptr;
300 if (instr->opcode() == xla::HloOpcode::kReduce) {
301 TF_RET_CHECK(loads.size() % 2 == 0);
302 std::vector<int64_t> dimensions(instr->dimensions().begin(),
303 instr->dimensions().end());
304 auto reduce_op = b.create<mhlo::ReduceOp>(
305 loc, llvm::makeArrayRef(loads).take_front(loads.size() / 2),
306 llvm::makeArrayRef(loads).drop_front(loads.size() / 2),
307 GetI64DenseElementsAttr(dimensions));
308
309 TF_RETURN_IF_ERROR(xla::HloFunctionImporter::ImportAsRegion(
310 *instr->called_computations()[0], &reduce_op.body(), &builder_,
311 /*flatten_region_arg_tuple=*/true));
312 op = reduce_op;
313 } else {
314 TF_ASSIGN_OR_RETURN(
315 op,
316 xla::HloFunctionImporter::ImportInstruction(
317 instr, loads, &b, xla::DynamicShapeHandlingMode::kConvertToStatic));
318 }
319 TF_RET_CHECK(op->getNumResults() == num_results);
320 for (int i = 0; i < results.size(); i++) {
321 b.create<mlir::memref::TensorStoreOp>(loc, op->getResult(i), results[i]);
322 }
323 return op;
324 }
325
326 StatusOr<mlir::Operation*> LhloDialectEmitter::CreateOpInFusion(
327 const HloInstruction* instr) {
328 llvm::SmallVector<Value, 4> operands;
329 size_t num_arguments, num_results;
330 TF_RETURN_IF_ERROR(CreateOperands(instr, std::nullopt,
331 TokenLoweringMode::kFailToLower, operands,
332 num_arguments, num_results));
333 TF_ASSIGN_OR_RETURN(
334 auto op, CreateOpInFusion(instr, operands, num_arguments, num_results));
335 return op->getParentOp();
336 }
337
338 StatusOr<mlir::Operation*> LhloDialectEmitter::EmitOp(
339 const HloInstruction* instr) {
340 using xla::HloOpcode;
341 switch (instr->opcode()) {
342 case HloOpcode::kAddDependency:
343 return nullptr;
344 case HloOpcode::kAfterAll:
345 // LMHLO is already ordered. This assumption may be broken after
346 // introducing async regions and partial orders.
347 return nullptr;
348 case HloOpcode::kAllToAll:
349 return EmitAllToAllOp(instr);
350 case HloOpcode::kAllGather:
351 return EmitAllGatherOp(instr);
352 case HloOpcode::kAllReduce:
353 return EmitAllReduceOp(instr);
354 case HloOpcode::kAllReduceStart:
355 return EmitAllReduceStartOp(instr);
356 case HloOpcode::kAllReduceDone:
357 return EmitAllReduceDoneOp(instr);
358 case HloOpcode::kReduceScatter:
359 return EmitReduceScatterOp(instr);
360 case HloOpcode::kBitcast:
361 return EmitBitcast(instr);
362 case HloOpcode::kCollectivePermute:
363 return EmitCollectivePermuteOp(instr);
364 case HloOpcode::kConditional:
365 return EmitCaseOp(instr);
366 case HloOpcode::kFft:
367 return EmitFftOp(instr);
368 case HloOpcode::kGetTupleElement:
369 return nullptr;
370 case HloOpcode::kInfeed:
371 return EmitInfeedOp(instr);
372 case HloOpcode::kOutfeed:
373 return EmitOutfeedOp(instr);
374 case HloOpcode::kPartitionId:
375 return CreateOpWithoutAttrs<lmhlo::PartitionIdOp>(instr);
376 case HloOpcode::kReplicaId:
377 return CreateOpWithoutAttrs<lmhlo::ReplicaIdOp>(instr);
378 case HloOpcode::kTriangularSolve:
379 return EmitTriangularSolveOp(instr);
380 case HloOpcode::kTuple:
381 return nullptr;
382 case HloOpcode::kSort:
383 return EmitSortOp(instr);
384 case HloOpcode::kFusion:
385 return EmitFusionOp(instr);
386 case HloOpcode::kScatter:
387 return EmitScatterOp(instr);
388 case HloOpcode::kSelectAndScatter:
389 return EmitSelectAndScatterOp(instr);
390 case HloOpcode::kCustomCall:
391 return EmitCustomCallOp(instr);
392 case HloOpcode::kConstant:
393 return EmitConstant(instr);
394 case HloOpcode::kRngGetAndUpdateState:
395 return EmitRngGetAndUpdateStateOp(instr);
396 case HloOpcode::kWhile:
397 return EmitWhileOp(instr);
398
399 case HloOpcode::kAbs:
400 case HloOpcode::kAdd:
401 case HloOpcode::kAnd:
402 case HloOpcode::kAtan2:
403 case HloOpcode::kBitcastConvert:
404 case HloOpcode::kBroadcast:
405 case HloOpcode::kCeil:
406 case HloOpcode::kCbrt:
407 case HloOpcode::kClamp:
408 case HloOpcode::kClz:
409 case HloOpcode::kCompare:
410 case HloOpcode::kComplex:
411 case HloOpcode::kConcatenate:
412 case HloOpcode::kConvert:
413 case HloOpcode::kCos:
414 case HloOpcode::kDivide:
415 case HloOpcode::kDot:
416 case HloOpcode::kDynamicSlice:
417 case HloOpcode::kDynamicUpdateSlice:
418 case HloOpcode::kExp:
419 case HloOpcode::kExpm1:
420 case HloOpcode::kFloor:
421 case HloOpcode::kGather:
422 case HloOpcode::kImag:
423 case HloOpcode::kIota:
424 case HloOpcode::kIsFinite:
425 case HloOpcode::kLog:
426 case HloOpcode::kLog1p:
427 case HloOpcode::kMap:
428 case HloOpcode::kMaximum:
429 case HloOpcode::kMinimum:
430 case HloOpcode::kMultiply:
431 case HloOpcode::kNegate:
432 case HloOpcode::kNot:
433 case HloOpcode::kOr:
434 case HloOpcode::kPad:
435 case HloOpcode::kPopulationCount:
436 case HloOpcode::kPower:
437 case HloOpcode::kReal:
438 case HloOpcode::kReshape:
439 case HloOpcode::kReducePrecision:
440 case HloOpcode::kReduceWindow:
441 case HloOpcode::kRemainder:
442 case HloOpcode::kReverse:
443 case HloOpcode::kRoundNearestAfz:
444 case HloOpcode::kRoundNearestEven:
445 case HloOpcode::kRsqrt:
446 case HloOpcode::kSelect:
447 case HloOpcode::kShiftLeft:
448 case HloOpcode::kShiftRightLogical:
449 case HloOpcode::kShiftRightArithmetic:
450 case HloOpcode::kSign:
451 case HloOpcode::kSin:
452 case HloOpcode::kSlice:
453 case HloOpcode::kSqrt:
454 case HloOpcode::kSubtract:
455 case HloOpcode::kTanh:
456 case HloOpcode::kTranspose:
457 case HloOpcode::kXor:
458 case HloOpcode::kCopy:
459 case HloOpcode::kReduce:
460 return CreateOpInFusion(instr);
461 default:
462 llvm::errs() << instr->ToString();
463 return tensorflow::errors::Internal(
464 absl::StrCat("LHLO opcode ", xla::HloOpcodeString(instr->opcode()),
465 " is not supported."));
466 }
467 }
468
469 Status LhloDialectEmitter::DefaultAction(const HloInstruction* instr) {
470 return EmitOp(instr).status();
471 }
472
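// Lowers HLO sort to lmhlo.sort, attaching the sort dimension and stability
// attributes and importing the comparator computation as the op's region.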
473 StatusOr<lmhlo::SortOp> LhloDialectEmitter::EmitSortOp(
474 const HloInstruction* instr) {
475 TF_ASSIGN_OR_RETURN(auto sort, CreateOpWithoutAttrs<lmhlo::SortOp>(instr));
476 auto* sort_instr = xla::Cast<xla::HloSortInstruction>(instr);
477 sort.setDimensionAttr(
478 builder_.getI64IntegerAttr(sort_instr->sort_dimension()));
479 sort.setIsStableAttr(builder_.getBoolAttr(sort_instr->is_stable()));
480 TF_RETURN_IF_ERROR(xla::HloFunctionImporter::ImportAsRegion(
481 *sort_instr->called_computations()[0], &sort.getComparator(), &builder_));
482 return sort;
483 }
484
485 // Walks MHLO::TupleOp recursively.
486 Status WalkTuplePostOrder(Value v,
487 const std::function<Status(Value)>& visitor) {
488 if (auto* op = v.getDefiningOp()) {
489 if (auto tuple = dyn_cast<mhlo::TupleOp>(op)) {
490 for (Value sub_v : tuple.val()) {
491 TF_RETURN_IF_ERROR(WalkTuplePostOrder(sub_v, visitor));
492 }
493 return ::tensorflow::OkStatus();
494 }
495 }
496 return visitor(v);
497 }
498
499 StatusOr<Value> LhloDialectEmitter::RewriteFusionOperand(
500 const HloInstruction* root, const Shape& shape,
501 xla::ShapeIndex* shape_index, OpBuilder* b, Location loc) {
502 if (shape.IsTuple()) {
503 llvm::SmallVector<Value, 4> values;
504 for (int i = 0; i < shape.tuple_shapes_size(); ++i) {
505 shape_index->push_back(i);
506 TF_ASSIGN_OR_RETURN(
507 auto v, RewriteFusionOperand(root, shape.tuple_shapes(i), shape_index,
508 b, loc));
509 values.push_back(v);
510 shape_index->pop_back();
511 }
512 return Value(b->create<mhlo::TupleOp>(loc, values));
513 }
514 TF_ASSIGN_OR_RETURN(Value memref,
515 GetOrCreateArrayView(root, shape, *shape_index));
516 auto load = b->create<bufferization::ToTensorOp>(loc, memref);
517 if (shape.layout() !=
518 xla::LayoutUtil::MakeDescendingLayout(shape.dimensions().size())) {
519 llvm::SmallVector<int64_t, 4> minor_to_major(
520 shape.layout().minor_to_major().begin(),
521 shape.layout().minor_to_major().end());
522 load->setAttr("xla_shape",
523 b->getStringAttr(shape.ToString(/*print_layout=*/true)));
524 }
525 return load.getResult();
526 }
527
528 // Emit an lmhlo.fusion based on XLA HLO fusion. Structurally they are not neatly
529 // equivalent. Specifically, XLA HLO fusion:
530 // fused_computation {
531 // %p0 = parameter(0)
532 // %p1 = parameter(1)
533 // ...
534 // ROOT %ret = ...
535 // }
536 // will be converted to
537 // lmhlo.fusion() { // no explicit operands
538 // // capturing outside buffers
539 // %p0 = bufferization.to_tensor(%arg0) : memref<...> -> tensor<...>
540 // %p1 = bufferization.to_tensor(%arg1) : memref<...> -> tensor<...>
541 // ...
542 // tensor_store ..., %ret // store a tensor to a memref
543 // }
544 StatusOr<lmhlo::FusionOp> LhloDialectEmitter::EmitFusionOp(
545 const HloInstruction* instr) {
546 Location loc = getLocation(instr);
547
548 auto* fusion_instr = xla::Cast<xla::HloFusionInstruction>(instr);
549
550 auto fusion = builder_.create<lmhlo::FusionOp>(getLocation(instr));
551 auto after_fusion = builder_.saveInsertionPoint();
552 auto reverter = absl::MakeCleanup(
553 [this, after_fusion] { builder_.restoreInsertionPoint(after_fusion); });
554 builder_ = mlir::OpBuilder(fusion);
555
556 auto region_builder = OpBuilder::atBlockBegin(&fusion.getRegion().front());
557
558 llvm::SmallVector<Value, 8> arguments;
559 for (int i = 0; i < instr->operands().size(); ++i) {
560 const HloInstruction* operand = instr->operand(i);
561 xla::ShapeIndex shape_index;
562 TF_ASSIGN_OR_RETURN(
563 auto arg, RewriteFusionOperand(operand, operand->shape(), &shape_index,
564 &region_builder, loc));
565 arguments.push_back(arg);
566 }
567
568 TF_ASSIGN_OR_RETURN(Value result,
569 xla::HloFunctionImporter::ImportInstructions(
570 *fusion_instr->fused_instructions_computation(),
571 arguments, &region_builder));
572 {
573 int i = 0;
574 llvm::SmallVector<Value, 4> output;
575 TF_RETURN_IF_ERROR(GetOrCreateView(instr, &output));
576 TF_RETURN_IF_ERROR(WalkTuplePostOrder(result, [&](Value v) mutable {
577 region_builder.create<memref::TensorStoreOp>(loc, v, output[i++]);
578 return ::tensorflow::OkStatus();
579 }));
580 if (i != output.size()) {
581 return xla::InternalError("output sizes don't match");
582 }
583 }
584
585 // Fold GTE/Tuple pairs.
586 //
587 // Since the fused region refers to values in its parent region, we can't
588 // call applyPatternAndFoldGreedily. We optimize it manually.
589 //
590 // Only walk once, because post-ordering is exactly what we need for GTE
591 // optimizations.
592 fusion.getRegion().walk([](mhlo::GetTupleElementOp gte) {
593 SmallVector<Value, 4> folded_values;
594 if (succeeded(OpBuilder(gte).tryFold(gte, folded_values))) {
595 gte.replaceAllUsesWith(folded_values[0]);
596 }
597 });
598
599 // Effectively a DCE on the region.
600 {
601 llvm::SmallVector<mlir::Operation*, 4> ops;
602 fusion.getRegion().walk([&](mlir::Operation* op) { ops.push_back(op); });
603 // Visit the user first.
604 std::reverse(ops.begin(), ops.end());
605 for (auto op : ops) {
606 if (isOpTriviallyDead(op)) op->erase();
607 }
608 }
609
610 return fusion;
611 }
612
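// Converts the ScatterDimensionNumbers proto of an HLO scatter into the
// equivalent mhlo::ScatterDimensionNumbersAttr.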
613 StatusOr<mhlo::ScatterDimensionNumbersAttr>
614 LhloDialectEmitter::GetScatterDimensionNumbers(const HloInstruction* instr,
615 mlir::MLIRContext* context) {
616 auto* scatter_instr = xla::Cast<xla::HloScatterInstruction>(instr);
617
618 const xla::ScatterDimensionNumbers& xla_scatter_dim =
619 scatter_instr->scatter_dimension_numbers();
620
621 auto get_i64_array = [](absl::Span<const int64_t> container) {
622 return ArrayRef<int64_t>{container.data(),
623 static_cast<size_t>(container.size())};
624 };
625 auto scatter_dimension_numbers = mhlo::ScatterDimensionNumbersAttr::get(
626 context, get_i64_array(xla_scatter_dim.update_window_dims()),
627 get_i64_array(xla_scatter_dim.inserted_window_dims()),
628 get_i64_array(xla_scatter_dim.scatter_dims_to_operand_dims()),
629 xla_scatter_dim.index_vector_dim());
630 return scatter_dimension_numbers;
631 }
632
633 StatusOr<lmhlo::ScatterOp> LhloDialectEmitter::EmitScatterOp(
634 const HloInstruction* instr) {
635 TF_ASSIGN_OR_RETURN(auto scatter,
636 CreateOpWithoutAttrs<lmhlo::ScatterOp>(instr));
637
638 // copy attributes
639 auto* scatter_instr = xla::Cast<xla::HloScatterInstruction>(instr);
640
641 TF_ASSIGN_OR_RETURN(auto scatter_dimension_numbers,
642 GetScatterDimensionNumbers(instr, builder_.getContext()));
643 scatter.setScatterDimensionNumbersAttr(scatter_dimension_numbers);
644 scatter.setIndicesAreSortedAttr(
645 builder_.getBoolAttr(scatter_instr->indices_are_sorted()));
646 scatter.setUniqueIndicesAttr(
647 builder_.getBoolAttr(scatter_instr->unique_indices()));
648
649 // import update computation as region
650 TF_RETURN_IF_ERROR(xla::HloFunctionImporter::ImportAsRegion(
651 *scatter_instr->called_computations()[0], &scatter.getUpdateComputation(),
652 &builder_));
653
654 return scatter;
655 }
656
657 StatusOr<lmhlo::SelectAndScatterOp> LhloDialectEmitter::EmitSelectAndScatterOp(
658 const HloInstruction* instr) {
659 TF_ASSIGN_OR_RETURN(auto select_and_scatter,
660 CreateOpWithoutAttrs<lmhlo::SelectAndScatterOp>(instr));
661
662 // copy attributes
663 auto* select_and_scatter_instr =
664 xla::Cast<xla::HloSelectAndScatterInstruction>(instr);
665 const xla::Window& window = select_and_scatter_instr->window();
666
667 if (xla::window_util::HasDilation(window)) {
668 return xla::Unimplemented("Dilation for SelectAndScatter is not supported");
669 }
670
671 select_and_scatter.setWindowDimensionsAttr(
672 GetWindowElements(window, [](const xla::WindowDimension& dim) {
673 return static_cast<int64_t>(dim.size());
674 }));
675 select_and_scatter.setWindowStridesAttr(
676 GetWindowElements(window, [](const xla::WindowDimension& dim) {
677 return static_cast<int64_t>(dim.stride());
678 }));
679 select_and_scatter.setPaddingAttr(
680 GetWindowElements(window, [](const xla::WindowDimension& dim) {
681 return static_cast<int64_t>(dim.padding_low());
682 }));
683
684 // import select and scatter computation as region
685 TF_RETURN_IF_ERROR(xla::HloFunctionImporter::ImportAsRegion(
686 *select_and_scatter_instr->select(), &select_and_scatter.getSelect(),
687 &builder_));
688 TF_RETURN_IF_ERROR(xla::HloFunctionImporter::ImportAsRegion(
689 *select_and_scatter_instr->scatter(), &select_and_scatter.getScatter(),
690 &builder_));
691 return select_and_scatter;
692 }
693
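// Custom calls matching known cuSOLVER/cuBLAS/cuDNN patterns are lowered to
// dedicated lmhlo_gpu ops; everything else becomes a generic lmhlo.custom_call
// with token operands/results squeezed out and recorded in the
// target_arg_mapping attribute.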
694 StatusOr<mlir::Operation*> LhloDialectEmitter::EmitCustomCallOp(
695 const HloInstruction* instr) {
696 auto* custom_call_instr = xla::Cast<xla::HloCustomCallInstruction>(instr);
697
698 if (xla::gpu::IsCustomCallToCusolver(*instr)) {
699 return EmitCholesky(custom_call_instr);
700 }
701
702 if (xla::gpu::IsCublasGemm(*instr)) {
703 return EmitGemm(custom_call_instr);
704 }
705
706 if (xla::gpu::IsCublasLtMatmul(*instr)) {
707 return EmitCublasLtMatmul(custom_call_instr);
708 }
709
710 if (xla::gpu::IsCustomCallToDnnConvolution(*instr)) {
711 return EmitDnnConvolution(custom_call_instr);
712 }
713
714 // For custom call, if there are any token operands or results, they will not
715 // be represented in LHLO so we need to remember the mapping. First create
716 // operands where each token is replaced with a null Value.
717 llvm::SmallVector<Value, 4> operands;
718 size_t num_arguments, num_results;
719 TF_RETURN_IF_ERROR(CreateOperands(instr, /*num_operands=*/std::nullopt,
720 TokenLoweringMode::kUseNull, operands,
721 num_arguments, num_results));
722
723 // Now check if any of the operands is Null, which would indicate the presence
724 // of a token in the input or output.
725 bool has_token = llvm::any_of(operands, [](Value v) { return !v; });
726
727 lmhlo::CustomCallTargetArgMappingAttr target_mapping;
728 if (has_token) {
729 // If there was a token, squeeze all the non-token arguments and results
730 // (in-place) and remember the mapping.
731 int next_index = 0;
732 llvm::SmallVector<int64_t> arg_to_target_arg_mapping;
733 for (int i = 0; i < num_arguments; ++i) {
734 if (operands[i]) {
735 arg_to_target_arg_mapping.push_back(i);
736 operands[next_index++] = operands[i];
737 }
738 }
739 // Size of arg_to_target_arg_mapping is the number of arguments in LHLO.
740 llvm::SmallVector<int64_t> result_to_target_result_mapping;
741 for (int i = num_arguments; i < operands.size(); ++i) {
742 if (operands[i]) {
743 result_to_target_result_mapping.push_back(i - num_arguments);
744 operands[next_index++] = operands[i];
745 }
746 }
747
748 // Build the mapping attribute.
749 target_mapping = lmhlo::CustomCallTargetArgMappingAttr::get(
750 builder_.getContext(), num_arguments, num_results,
751 arg_to_target_arg_mapping, result_to_target_result_mapping);
752
753 // Drop the remaining operands and adjust num_arguments and num_results
754 // for LMHLO creation.
755 operands.resize(next_index);
756 num_arguments = arg_to_target_arg_mapping.size();
757 num_results = result_to_target_result_mapping.size();
758 }
759
760 auto custom_call = CreateOpWithoutAttrs<lmhlo::CustomCallOp>(instr, operands);
761 TF_ASSIGN_OR_RETURN(
762 auto mlir_api_version,
763 ConvertCustomCallApiVersion(custom_call_instr->api_version()));
764 custom_call.setCallTargetNameAttr(
765 builder_.getStringAttr(custom_call_instr->custom_call_target()));
766 custom_call.setBackendConfigAttr(
767 builder_.getStringAttr(custom_call_instr->opaque()));
768 custom_call.setApiVersionAttr(mhlo::CustomCallApiVersionAttr::get(
769 builder_.getContext(), mlir_api_version));
770 const int32_t segments[2] = {static_cast<int32_t>(num_arguments),
771 static_cast<int32_t>(num_results)};
772 custom_call->setAttr(lmhlo::CustomCallOp::getOperandSegmentSizeAttr(),
773 builder_.getDenseI32ArrayAttr(segments));
774 if (target_mapping) custom_call.setTargetArgMappingAttr(target_mapping);
775 return custom_call.getOperation();
776 }
777
778 StatusOr<lmhlo_gpu::CholeskyOp> LhloDialectEmitter::EmitCholesky(
779 const HloCustomCallInstruction* custom_call) {
780 TF_ASSIGN_OR_RETURN(auto cholesky_op,
781 CreateOpWithoutAttrs<lmhlo_gpu::CholeskyOp>(custom_call));
782 TF_ASSIGN_OR_RETURN(xla::CholeskyOptions options,
783 custom_call->backend_config<xla::CholeskyOptions>());
784 cholesky_op.setIsLowerAttr(builder_.getBoolAttr(options.lower()));
785 return cholesky_op;
786 }
787
788 namespace {
789
790 template <typename OpT>
791 void SetMatmulAttributes(OpT op, const xla::gpu::GemmBackendConfig& config,
792 OpBuilder& builder) {
793 auto arrayref = [](absl::Span<const int64_t> array) {
794 return llvm::ArrayRef<int64_t>{array.data(), array.size()};
795 };
796
797 auto hlo_dims = config.dot_dimension_numbers();
798 auto mlir_dims = mhlo::DotDimensionNumbersAttr::get(
799 builder.getContext(), arrayref(hlo_dims.lhs_batch_dimensions()),
800 arrayref(hlo_dims.rhs_batch_dimensions()),
801 arrayref(hlo_dims.lhs_contracting_dimensions()),
802 arrayref(hlo_dims.rhs_contracting_dimensions()));
803 op.setDotDimensionNumbersAttr(mlir_dims);
804 op.setAlphaRealAttr(builder.getF64FloatAttr(config.alpha_real()));
805 op.setAlphaImagAttr(builder.getF64FloatAttr(config.alpha_imag()));
806 op.setBetaAttr(builder.getF64FloatAttr(config.beta()));
807 if (config.algorithm_case() ==
808 xla::gpu::GemmBackendConfig::kSelectedAlgorithm) {
809 op.setAlgorithmAttr(builder.getI64IntegerAttr(config.selected_algorithm()));
810 }
811 op.setPrecisionConfigAttr(
812 xla::ConvertPrecisionConfig(&config.precision_config(), &builder));
813 }
814
815 StatusOr<lmhlo_gpu::CublasLtMatmulEpilogue> AsLhloEpilogue(
816 xla::gpu::GemmBackendConfig_Epilogue epilogue) {
817 switch (epilogue) {
818 case xla::gpu::GemmBackendConfig::DEFAULT:
819 return lmhlo_gpu::CublasLtMatmulEpilogue::Default;
820 break;
821 case xla::gpu::GemmBackendConfig::BIAS:
822 return lmhlo_gpu::CublasLtMatmulEpilogue::Bias;
823 break;
824 default:
825 return xla::InternalError("unknown epilogue");
826 }
827 }
828
829 } // namespace
830
831 StatusOr<Operation*> LhloDialectEmitter::EmitGemm(
832 const HloCustomCallInstruction* custom_call) {
833 TF_ASSIGN_OR_RETURN(
834 auto const config,
835 custom_call->backend_config<xla::gpu::GemmBackendConfig>());
836
837 if (custom_call->operand_count() == 2) {
838 TF_RET_CHECK(config.beta() == 0.);
839 } else if (custom_call->operand_count() != 3) {
840 return xla::InvalidArgument("GEMM custom call should have 2 or 3 operands");
841 }
842
843 // GEMM may have two or three operands. However, in the three operand case,
844 // the third operand is updated in-place, so we treat that as an output here.
845 TF_ASSIGN_OR_RETURN(
846 lmhlo_gpu::GEMMOp op,
847 CreateOpWithoutAttrs<lmhlo_gpu::GEMMOp>(custom_call,
848 /*num_operands=*/2));
849
850 SetMatmulAttributes(op, config, builder_);
851 return op.getOperation();
852 }
853
854 StatusOr<Operation*> LhloDialectEmitter::EmitCublasLtMatmul(
855 const HloCustomCallInstruction* custom_call) {
856 TF_ASSIGN_OR_RETURN(
857 auto const config,
858 custom_call->backend_config<xla::gpu::GemmBackendConfig>());
859
860 bool has_matrix_bias = config.beta() != 0.;
861 bool has_vector_bias = config.epilogue() == xla::gpu::GemmBackendConfig::BIAS;
862 TF_RET_CHECK(custom_call->operand_count() ==
863 2 + int{has_matrix_bias} + int{has_vector_bias});
864
865 llvm::SmallVector<Value, 5> operands;
866 TF_RETURN_IF_ERROR(GetOrCreateView(custom_call->operand(0), &operands));
867 TF_RETURN_IF_ERROR(GetOrCreateView(custom_call->operand(1), &operands));
868 TF_RETURN_IF_ERROR(GetOrCreateView(
869 has_matrix_bias ? custom_call->operand(2) : custom_call, &operands));
870 TF_RETURN_IF_ERROR(GetOrCreateView(custom_call, &operands));
871
872 if (has_vector_bias) {
873 TF_RETURN_IF_ERROR(GetOrCreateView(
874 custom_call->operand(has_matrix_bias ? 3 : 2), &operands));
875 }
876
877 auto op =
878 CreateOpWithoutAttrs<lmhlo_gpu::CublasLtMatmulOp>(custom_call, operands);
879 SetMatmulAttributes(op, config, builder_);
880
881 TF_ASSIGN_OR_RETURN(lmhlo_gpu::CublasLtMatmulEpilogue epilogue,
882 AsLhloEpilogue(config.epilogue()));
883 op.setEpilogueAttr(lmhlo_gpu::CublasLtMatmulEpilogueAttr::get(
884 builder_.getContext(), epilogue));
885
886 // Use the first algorithm by default (i.e. fastest according to heuristics).
887 if (config.algorithm_case() !=
888 xla::gpu::GemmBackendConfig::kSelectedAlgorithm) {
889 op.setAlgorithmAttr(builder_.getI64IntegerAttr(0));
890 }
891
892 return op.getOperation();
893 }
894
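// Maps a StreamExecutor DNN activation mode to the corresponding
// lmhlo_gpu::Activation enum value.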
895 static StatusOr<mlir::lmhlo_gpu::Activation> GetLHLOActivation(
896 stream_executor::dnn::ActivationMode activation) {
897 switch (activation) {
898 case stream_executor::dnn::kNone:
899 return mlir::lmhlo_gpu::Activation::None;
900 case stream_executor::dnn::kSigmoid:
901 return mlir::lmhlo_gpu::Activation::Sigmoid;
902 case stream_executor::dnn::kRelu:
903 return mlir::lmhlo_gpu::Activation::Relu;
904 case stream_executor::dnn::kRelu6:
905 return mlir::lmhlo_gpu::Activation::Relu6;
906 case stream_executor::dnn::kReluX:
907 return mlir::lmhlo_gpu::Activation::ReluX;
908 case stream_executor::dnn::kTanh:
909 return mlir::lmhlo_gpu::Activation::Tanh;
910 case stream_executor::dnn::kBandPass:
911 return mlir::lmhlo_gpu::Activation::BandPass;
912 default:
913 return xla::InternalError("Unknown activation");
914 }
915 }
916
917 StatusOr<Operation*> LhloDialectEmitter::EmitDnnConvolution(
918 const HloCustomCallInstruction* custom_call) {
919 TF_ASSIGN_OR_RETURN(
920 auto const backend_config,
921 custom_call->backend_config<xla::gpu::CudnnConvBackendConfig>());
922
923 TF_ASSIGN_OR_RETURN(const xla::gpu::CudnnConvKind kind,
924 xla::gpu::GetCudnnConvKind(custom_call));
925
926 auto get_layout_attribute = [&](const xla::Layout& layout) {
927 std::vector<int64_t> minor_to_major(layout.minor_to_major_size());
928 absl::c_transform(layout.minor_to_major(), minor_to_major.begin(),
929 [](int64_t x) { return static_cast<int64_t>(x); });
930 return minor_to_major;
931 };
932
933 auto set_common_conv_attributes = [&, this](auto op) -> Operation* {
934 const xla::Window& window = custom_call->window();
935 // Window size for Cudnn Conv is the same as the kernel size.
936 NamedAttrList attrs(op->getAttrDictionary());
937 DenseIntElementsAttr window_strides;
938 attrs.set(op.getWindowStridesAttrName(),
939 window_strides = GetWindowElements(
940 window, [](const xla::WindowDimension& dim) {
941 return static_cast<int64_t>(dim.stride());
942 }));
943 // Cudnn Conv requires low and high padding to be equal.
944 attrs.set(op.getPaddingAttrName(),
945 GetWindowElements(window, [](const xla::WindowDimension& dim) {
946 return static_cast<int64_t>(dim.padding_low());
947 }));
948 // LHS dilation is encoded in base_dilation of the backend config.
949 // RHS dilation is encoded in window_dilation of the backend config.
950 attrs.set(op.getLhsDilationAttrName(),
951 GetWindowElements(window, [](const xla::WindowDimension& dim) {
952 return static_cast<int64_t>(dim.base_dilation());
953 }));
954 attrs.set(op.getRhsDilationAttrName(),
955 GetWindowElements(window, [](const xla::WindowDimension& dim) {
956 return static_cast<int64_t>(dim.window_dilation());
957 }));
958 // Setup window reversal.
959 auto window_reversal = llvm::to_vector<4>(llvm::map_range(
960 window.dimensions(),
961 [](const xla::WindowDimension& dim) { return dim.window_reversal(); }));
962 auto type = RankedTensorType::get(window_strides.getType().getShape(),
963 builder_.getIntegerType(/*width=*/1));
964 attrs.set(op.getWindowReversalAttrName(),
965 DenseElementsAttr::get(type, window_reversal));
966
967 attrs.set(op.getDimensionNumbersAttrName(),
968 xla::ConvertConvDimensionNumbers(
969 custom_call->convolution_dimension_numbers(), &builder_));
970 attrs.set(op.getFeatureGroupCountAttrName(),
971 builder_.getI64IntegerAttr(custom_call->feature_group_count()));
972 attrs.set(op.getBatchGroupCountAttrName(),
973 builder_.getI64IntegerAttr(custom_call->batch_group_count()));
974 attrs.set(op.getPrecisionConfigAttrName(),
975 xla::ConvertPrecisionConfig(&custom_call->precision_config(),
976 &builder_));
977 attrs.set(op.getResultScaleAttrName(),
978 builder_.getF64FloatAttr(backend_config.conv_result_scale()));
979
980 const auto& algorithm = backend_config.algorithm();
981 std::vector<int64_t> knob_ids;
982 std::vector<int64_t> knob_values;
983 for (const auto& entry : algorithm.tuning_knobs()) {
984 knob_ids.push_back(entry.first);
985 knob_values.push_back(entry.second);
986 }
987
988 auto config = mlir::lmhlo_gpu::ConvolutionBackendConfigAttr::get(
989 builder_.getContext(), algorithm.algo_id(),
990
991 algorithm.math_type() ==
992 stream_executor::dnn::AlgorithmProto::TENSOR_OP_MATH,
993 knob_ids, knob_values, algorithm.is_cudnn_frontend(),
994 algorithm.has_workspace_size() ? algorithm.workspace_size().value()
995 : -1,
996 get_layout_attribute(custom_call->operand(0)->shape().layout()),
997 get_layout_attribute(custom_call->operand(1)->shape().layout()),
998 get_layout_attribute(custom_call->shape().tuple_shapes(0).layout()));
999 attrs.set(op.getBackendConfigAttrName(), config);
1000 op->setAttrs(attrs.getDictionary(op->getContext()));
1001
1002 return op.getOperation();
1003 };
1004
1005 auto set_activation = [&, this](auto op) -> Status {
1006 auto se_activation = static_cast<stream_executor::dnn::ActivationMode>(
1007 backend_config.activation_mode());
1008 TF_ASSIGN_OR_RETURN(mlir::lmhlo_gpu::Activation activation,
1009 GetLHLOActivation(se_activation));
1010 auto activation_attr = ::mlir::lmhlo_gpu::ActivationAttr::get(
1011 getLocation(custom_call).getContext(), activation);
1012 op.setActivationModeAttr(activation_attr);
1013 return ::tensorflow::OkStatus();
1014 };
1015
1016 switch (kind) {
1017 case xla::gpu::CudnnConvKind::kForward: {
1018 TF_ASSIGN_OR_RETURN(
1019 auto cnn_forward,
1020 CreateOpWithoutAttrs<lmhlo_gpu::ConvForwardOp>(custom_call));
1021 return set_common_conv_attributes(cnn_forward);
1022 }
1023 case xla::gpu::CudnnConvKind::kBackwardInput: {
1024 TF_ASSIGN_OR_RETURN(
1025 auto cnn_backward,
1026 CreateOpWithoutAttrs<lmhlo_gpu::ConvBackwardInputOp>(custom_call));
1027 return set_common_conv_attributes(cnn_backward);
1028 }
1029 case xla::gpu::CudnnConvKind::kBackwardFilter: {
1030 TF_ASSIGN_OR_RETURN(
1031 auto cnn_backward,
1032 CreateOpWithoutAttrs<lmhlo_gpu::ConvBackwardFilterOp>(custom_call));
1033 return set_common_conv_attributes(cnn_backward);
1034 }
1035 case xla::gpu::CudnnConvKind::kForwardActivation: {
1036 // Fused conv can be either with side input or without.
1037 if (custom_call->operand_count() == 3) {
1038 TF_ASSIGN_OR_RETURN(
1039 auto cnn_fused,
1040 CreateOpWithoutAttrs<lmhlo_gpu::ConvForwardFusedOp>(custom_call));
1041 TF_RETURN_IF_ERROR(set_activation(cnn_fused));
1042 return set_common_conv_attributes(cnn_fused);
1043 }
1044
1045 TF_RET_CHECK(custom_call->operand_count() == 4);
1046 TF_ASSIGN_OR_RETURN(
1047 auto cnn_fused_side_input,
1048 CreateOpWithoutAttrs<lmhlo_gpu::ConvForwardFusedSideInputOp>(
1049 custom_call));
1050 cnn_fused_side_input.setSideInputScaleAttr(
1051 builder_.getF64FloatAttr(backend_config.side_input_scale()));
1052 TF_RETURN_IF_ERROR(set_activation(cnn_fused_side_input));
1053 return set_common_conv_attributes(cnn_fused_side_input);
1054 }
1055 }
1056 }
1057
1058 // Convert an XLA HLO constant to a global_memref + get_global_memref pair.
1059 StatusOr<mlir::memref::GetGlobalOp> LhloDialectEmitter::EmitConstant(
1060 const HloInstruction* instr) {
1061 auto& cached_value = slices_[std::make_pair(instr, xla::ShapeIndex())];
1062 if (cached_value) {
1063 return dyn_cast<mlir::memref::GetGlobalOp>(cached_value.getDefiningOp());
1064 }
1065
1066 // Insert a global_memref in the module.
1067 Location loc = getLocation(instr);
1068
1069 auto const_instr = xla::Cast<xla::HloConstantInstruction>(instr);
1070 TF_RET_CHECK(const_instr->shape().IsArray() &&
1071 const_instr->shape().is_static());
1072 TF_ASSIGN_OR_RETURN(Type type, xla::ConvertShapeToType<MemRefType>(
1073 const_instr->shape(), builder_));
1074 auto memref_type = type.dyn_cast<MemRefType>();
1075 TF_RET_CHECK(memref_type != nullptr);
1076
1077 TF_ASSIGN_OR_RETURN(
1078 DenseElementsAttr initial_value,
1079 CreateDenseElementsAttrFromLiteral(const_instr->literal(), builder_));
1080
1081 std::string constant_name = xla::llvm_ir::ConstantNameToGlobalName(
1082 xla::llvm_ir::SanitizeConstantName(instr->name()));
1083
1084 // Insert the global memref at the top level.
1085 {
1086 OpBuilder::InsertionGuard guard(builder_);
1087 builder_.clearInsertionPoint();
1088 auto global_var = builder_.create<memref::GlobalOp>(
1089 loc, constant_name, builder_.getStringAttr("private"), memref_type,
1090 initial_value, true, /*alignment=*/IntegerAttr());
1091 SymbolTable(module_).insert(global_var);
1092 global_var.getOperation()->moveBefore(&module_.front());
1093
1094 // For operations that do not fold this constant value in their codegen, we
1095 // still need to materialize it into a buffer. Since buffer allocation is
1096 // already done, annotate the global_memref with the information to get to
1097 // the allocated buffer slice for this constant if need be.
1098 TF_ASSIGN_OR_RETURN(BufferAllocation::Slice slice,
1099 assignment_.GetUniqueTopLevelSlice(instr));
1100 global_var->setAttr(
1101 "lmhlo.alloc",
1102 builder_.getIndexAttr(allocations_.find(slice.allocation())
1103 ->second.cast<BlockArgument>()
1104 .getArgNumber()));
1105 TF_RET_CHECK(slice.offset() == 0)
1106 << "Each constant should have its own allocation from BufferAssignment";
1107 TF_RET_CHECK(slice.allocation()->size() == slice.size())
1108 << "Each constant should have its own allocation from BufferAssignment";
1109 }
1110
1111 auto get_global_memref =
1112 builder_.create<memref::GetGlobalOp>(loc, memref_type, constant_name);
1113
1114 // Update the cache to remember this value.
1115 cached_value = get_global_memref;
1116 return get_global_memref;
1117 }
1118
1119 namespace {
1120 template <typename OpT>
1121 void SetupChannelIdAttribute(OpT op, const xla::HloChannelInstruction* instr,
1122 mlir::Builder builder) {
1123 if (instr->channel_id().has_value()) {
1124 op.setChannelIdAttr(mlir::mhlo::ChannelHandleAttr::get(
1125 builder.getContext(), *instr->channel_id(), 0));
1126 }
1127 }
1128
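// Copies the replica groups, layout constraint, and channel id shared by all
// HLO collectives onto the corresponding LMHLO collective op.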
1129 template <typename OpT>
1130 Status SetupCommonCollectiveOpAttributes(OpT op, const HloInstruction* instr,
1131 mlir::OpBuilder& builder) {
1132 auto* collective = xla::Cast<xla::HloCollectiveInstruction>(instr);
1133 auto replica_groups_attr = xla::HloFunctionImporter::ConvertReplicaGroups(
1134 collective->replica_groups(), &builder);
1135 op->setAttr(replica_groups_attr.getName(), replica_groups_attr.getValue());
1136 op.setConstrainLayoutAttr(
1137 builder.getBoolAttr(collective->constrain_layout()));
1138 SetupChannelIdAttribute(op, collective, builder);
1139 return ::tensorflow::OkStatus();
1140 }
1141 } // namespace
1142
1143 StatusOr<lmhlo::AllToAllOp> LhloDialectEmitter::EmitAllToAllOp(
1144 const HloInstruction* instr) {
1145 TF_ASSIGN_OR_RETURN(auto all_to_all_op,
1146 CreateOpWithoutAttrs<lmhlo::AllToAllOp>(instr));
1147 auto* all_to_all = xla::Cast<xla::HloAllToAllInstruction>(instr);
1148 TF_RETURN_IF_ERROR(
1149 SetupCommonCollectiveOpAttributes(all_to_all_op, instr, builder_));
1150 if (all_to_all->split_dimension().has_value()) {
1151 all_to_all_op.setSplitDimensionAttr(
1152 builder_.getI64IntegerAttr(*all_to_all->split_dimension()));
1153 }
1154 return all_to_all_op;
1155 }
1156
1157 StatusOr<lmhlo::AllGatherOp> LhloDialectEmitter::EmitAllGatherOp(
1158 const HloInstruction* instr) {
1159 TF_ASSIGN_OR_RETURN(auto all_gather_op,
1160 CreateOpWithoutAttrs<lmhlo::AllGatherOp>(instr));
1161 auto* all_gather = xla::Cast<xla::HloAllGatherInstruction>(instr);
1162 TF_RETURN_IF_ERROR(
1163 SetupCommonCollectiveOpAttributes(all_gather_op, instr, builder_));
1164 all_gather_op.setUseGlobalDeviceIdsAttr(
1165 builder_.getBoolAttr(all_gather->use_global_device_ids()));
1166 all_gather_op.setAllGatherDimensionAttr(
1167 builder_.getI64IntegerAttr(all_gather->all_gather_dimension()));
1168 return all_gather_op;
1169 }
1170
1171 StatusOr<lmhlo::AllReduceOp> LhloDialectEmitter::EmitAllReduceOp(
1172 const HloInstruction* instr) {
1173 TF_ASSIGN_OR_RETURN(auto all_reduce_op,
1174 CreateOpWithoutAttrs<lmhlo::AllReduceOp>(instr));
1175 auto* all_reduce = xla::Cast<xla::HloAllReduceInstruction>(instr);
1176 TF_RETURN_IF_ERROR(
1177 SetupCommonCollectiveOpAttributes(all_reduce_op, instr, builder_));
1178 all_reduce_op.setUseGlobalDeviceIdsAttr(
1179 builder_.getBoolAttr(all_reduce->use_global_device_ids()));
1180 TF_RETURN_IF_ERROR(xla::HloFunctionImporter::ImportAsRegion(
1181 *instr->called_computations()[0], &all_reduce_op.getComputation(),
1182 &builder_));
1183 return all_reduce_op;
1184 }
1185
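// all-reduce-start produces an mhlo token result that is recorded in
// all_reduce_start_ops_ so the matching all-reduce-done can consume it later.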
1186 StatusOr<lmhlo_gpu::AllReduceStartOp> LhloDialectEmitter::EmitAllReduceStartOp(
1187 const HloInstruction* instr) {
1188 llvm::SmallVector<Value, 4> operands;
1189 for (const HloInstruction* operand : instr->operands()) {
1190 TF_RETURN_IF_ERROR(GetOrCreateView(operand, &operands));
1191 }
1192 TF_RETURN_IF_ERROR(GetOrCreateView(instr, &operands, /*result_subset=*/{}));
1193
1194 Location loc = getLocation(instr);
1195 mlir::Type token_type = mlir::mhlo::TokenType::get(builder_.getContext());
1196 std::array<mlir::Type, 1> result_types = {token_type};
1197 lmhlo_gpu::AllReduceStartOp all_reduce_start_op =
1198 builder_.create<lmhlo_gpu::AllReduceStartOp>(loc, result_types, operands);
1199
1200 auto* all_reduce = xla::Cast<xla::HloAllReduceInstruction>(instr);
1201 TF_RETURN_IF_ERROR(
1202 SetupCommonCollectiveOpAttributes(all_reduce_start_op, instr, builder_));
1203 all_reduce_start_op.setUseGlobalDeviceIdsAttr(
1204 builder_.getBoolAttr(all_reduce->use_global_device_ids()));
1205 TF_RETURN_IF_ERROR(xla::HloFunctionImporter::ImportAsRegion(
1206 *instr->called_computations()[0], &all_reduce_start_op.getComputation(),
1207 &builder_));
1208
1209 TF_RET_CHECK(all_reduce_start_ops_.emplace(instr, all_reduce_start_op).second)
1210 << "all-reduce-start already lowered";
1211 return all_reduce_start_op;
1212 }
1213
1214 StatusOr<lmhlo_gpu::AllReduceDoneOp> LhloDialectEmitter::EmitAllReduceDoneOp(
1215 const HloInstruction* instr) {
1216 auto it = all_reduce_start_ops_.find(instr->operand(0));
1217 TF_RET_CHECK(it != all_reduce_start_ops_.end())
1218 << "didn't find all-reduce-start op";
1219
1220 llvm::SmallVector<Value, 4> operands;
1221 operands.push_back(it->second.getToken());
1222 all_reduce_start_ops_.erase(it);
1223
1224 for (const HloInstruction* operand : instr->operands()) {
1225 TF_RETURN_IF_ERROR(GetOrCreateView(operand, &operands));
1226 }
1227 // We don't need to add buffers for the outputs, as these always alias inputs.
1228 return builder_.create<lmhlo_gpu::AllReduceDoneOp>(
1229 getLocation(instr), /*resultTypes=*/llvm::None, operands);
1230 }
1231
1232 StatusOr<lmhlo::ReduceScatterOp> LhloDialectEmitter::EmitReduceScatterOp(
1233 const HloInstruction* instr) {
1234 TF_ASSIGN_OR_RETURN(auto reduce_scatter_op,
1235 CreateOpWithoutAttrs<lmhlo::ReduceScatterOp>(instr));
1236 auto* ars = xla::Cast<xla::HloReduceScatterInstruction>(instr);
1237 TF_RETURN_IF_ERROR(
1238 SetupCommonCollectiveOpAttributes(reduce_scatter_op, instr, builder_));
1239 reduce_scatter_op.setUseGlobalDeviceIdsAttr(
1240 builder_.getBoolAttr(ars->use_global_device_ids()));
1241 TF_RETURN_IF_ERROR(xla::HloFunctionImporter::ImportAsRegion(
1242 *instr->called_computations()[0], &reduce_scatter_op.getComputation(),
1243 &builder_));
1244 reduce_scatter_op.setScatterDimensionAttr(
1245 builder_.getI64IntegerAttr(ars->scatter_dimension()));
1246 return reduce_scatter_op;
1247 }
1248
1249 StatusOr<lmhlo::CollectivePermuteOp>
1250 LhloDialectEmitter::EmitCollectivePermuteOp(const HloInstruction* instr) {
1251 TF_ASSIGN_OR_RETURN(auto permute_op,
1252 CreateOpWithoutAttrs<lmhlo::CollectivePermuteOp>(instr));
1253 auto* permute = xla::Cast<xla::HloCollectivePermuteInstruction>(instr);
1254 SetupChannelIdAttribute(permute_op, permute, builder_);
1255 mlir::NamedAttribute source_target_pairs_attr =
1256 xla::HloFunctionImporter::ConvertSourceTargetPairs(
1257 permute->source_target_pairs(), &builder_);
1258 permute_op->setAttr(source_target_pairs_attr.getName(),
1259 source_target_pairs_attr.getValue());
1260 return permute_op;
1261 }
1262
1263 StatusOr<lmhlo::InfeedOp> LhloDialectEmitter::EmitInfeedOp(
1264 const HloInstruction* instr) {
1265 const HloInfeedInstruction* infeed = xla::Cast<HloInfeedInstruction>(instr);
1266 // HLO Infeed instruction has a single operand of token type and a tuple
1267 // with buffers and a token as its output. LMHLO Infeed operation does not
1268 // need the token operand or result, so drop it.
1269 SmallVector<Value, 2> operands;
1270 TF_RETURN_IF_ERROR(GetOrCreateView(instr, &operands, /*result_subset=*/{0}));
1271 auto infeed_op = CreateOpWithoutAttrs<lmhlo::InfeedOp>(instr, operands);
1272 infeed_op.setConfigAttr(builder_.getStringAttr(infeed->infeed_config()));
1273 return infeed_op;
1274 }
1275
1276 StatusOr<lmhlo::OutfeedOp> LhloDialectEmitter::EmitOutfeedOp(
1277 const HloInstruction* instr) {
1278 const HloOutfeedInstruction* outfeed =
1279 xla::Cast<HloOutfeedInstruction>(instr);
1280 // HLO outfeed instruction has 2 operands, the source and a token, and a
1281 // single token output. LMHLO Outfeed does not need the token operand and
1282 // result, so drop it.
1283 SmallVector<Value, 2> operands;
1284 TF_RETURN_IF_ERROR(GetOrCreateView(instr->operand(0), &operands));
1285 auto outfeed_op = CreateOpWithoutAttrs<lmhlo::OutfeedOp>(instr, operands);
1286 outfeed_op.setConfigAttr(builder_.getStringAttr(outfeed->outfeed_config()));
1287 return outfeed_op;
1288 }
1289
1290 xla::StatusOr<lmhlo::RngGetAndUpdateStateOp>
1291 LhloDialectEmitter::EmitRngGetAndUpdateStateOp(
1292 const xla::HloInstruction* instr) {
1293 TF_ASSIGN_OR_RETURN(
1294 auto rng, CreateOpWithoutAttrs<lmhlo::RngGetAndUpdateStateOp>(instr));
1295 auto hlo_rng = xla::Cast<xla::HloRngGetAndUpdateStateInstruction>(instr);
1296 rng.setDeltaAttr(builder_.getI64IntegerAttr(hlo_rng->delta()));
1297 return rng;
1298 }
1299
1300 xla::StatusOr<lmhlo::FftOp> LhloDialectEmitter::EmitFftOp(
1301 const HloInstruction* instr) {
1302 auto hlo_fft = xla::Cast<xla::HloFftInstruction>(instr);
1303 TF_ASSIGN_OR_RETURN(auto fft, CreateOpWithoutAttrs<lmhlo::FftOp>(instr));
1304 TF_ASSIGN_OR_RETURN(mlir::mhlo::FftType fft_type,
1305 xla::ConvertFftType(hlo_fft->fft_type()));
1306 fft.setFftTypeAttr(
1307 mlir::mhlo::FftTypeAttr::get(builder_.getContext(), fft_type));
1308 fft.setFftLengthAttr(GetI64DenseElementsAttr(instr->fft_length()));
1309 return fft;
1310 }
1311
1312 xla::StatusOr<lmhlo::TriangularSolveOp>
1313 LhloDialectEmitter::EmitTriangularSolveOp(const xla::HloInstruction* instr) {
1314 auto hlo_triangular_solve =
1315 xla::Cast<xla::HloTriangularSolveInstruction>(instr);
1316 TF_ASSIGN_OR_RETURN(auto triangular_solve,
1317 CreateOpWithoutAttrs<lmhlo::TriangularSolveOp>(instr));
1318 const xla::TriangularSolveOptions& options =
1319 hlo_triangular_solve->triangular_solve_options();
1320 triangular_solve.setLeftSideAttr(builder_.getBoolAttr(options.left_side()));
1321 triangular_solve.setLowerAttr(builder_.getBoolAttr(options.lower()));
1322 triangular_solve.setUnitDiagonalAttr(
1323 builder_.getBoolAttr(options.unit_diagonal()));
1324 TF_ASSIGN_OR_RETURN(mlir::mhlo::Transpose transpose,
1325 xla::ConvertTranspose(options.transpose_a()));
1326 triangular_solve.setTransposeAAttr(
1327 mlir::mhlo::TransposeAttr::get(builder_.getContext(), transpose));
1328 triangular_solve.setLayoutAAttr(
1329 GetLayoutAttribute(instr->operand(0)->shape().layout(), &builder_));
1330 triangular_solve.setLayoutBAttr(
1331 GetLayoutAttribute(instr->operand(1)->shape().layout(), &builder_));
1332 triangular_solve.setLayoutOutputAttr(
1333 GetLayoutAttribute(instr->shape().layout(), &builder_));
1334 return triangular_solve;
1335 }
1336
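// A bitcast is a no-op at the buffer level: XLA buffer assignment is expected
// to place the bitcast input and result in the same slice, so no LMHLO op is
// emitted (nullptr is returned). A slice mismatch is reported as an error.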
xla::StatusOr<Operation*> LhloDialectEmitter::EmitBitcast(
    const xla::HloInstruction* instr) {
  // XLA buffer assignment should assign the same slice to a bitcast input and
  // output.
  const xla::ShapeIndex top_index;
  TF_ASSIGN_OR_RETURN(BufferAllocation::Slice result_slice,
                      assignment_.GetUniqueSlice(instr, top_index));
  TF_ASSIGN_OR_RETURN(BufferAllocation::Slice input_slice,
                      assignment_.GetUniqueSlice(instr->operand(0), top_index));

  if (input_slice != result_slice) {
    return xla::InvalidArgument(
        "Bitcast input and result slices should be the same");
  }
  return nullptr;
}

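// Converts an XLA layout's minor-to-major dimension order into an index
// tensor attribute.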
mlir::DenseIntElementsAttr LhloDialectEmitter::GetLayoutAttribute(
    const xla::Layout& layout, Builder* builder) {
  llvm::SmallVector<int64_t, 4> minor_to_major(layout.minor_to_major().begin(),
                                               layout.minor_to_major().end());
  return builder->getIndexTensorAttr(minor_to_major);
}

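// Imports the instructions of `computation` into `region`, following the
// sequential order chosen by buffer assignment, and terminates the region
// with an lmhlo.terminator. The builder's insertion point is restored on exit.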
Status LhloDialectEmitter::ImportAsLmhloRegion(xla::HloComputation* computation,
                                               mlir::Region* region) {
  auto after = builder_.saveInsertionPoint();
  auto reverter = absl::MakeCleanup(
      [this, after] { builder_.restoreInsertionPoint(after); });

  builder_ = OpBuilder(region);
  const xla::HloInstructionSequence* schedule =
      assignment_.hlo_ordering().SequentialOrder(*computation);
  if (!schedule)
    return xla::Unimplemented("Missing sequential order for the computation");
  TF_RETURN_IF_ERROR(
      computation->AcceptOrdered(this, schedule->instructions()));
  builder_.create<lmhlo::TerminatorOp>(builder_.getUnknownLoc());
  return ::tensorflow::OkStatus();
}

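// Emits an lmhlo.case op. The branch index buffer is taken from the HLO
// instruction's first operand, and each branch computation is imported into
// its own region.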
StatusOr<lmhlo::CaseOp> LhloDialectEmitter::EmitCaseOp(
    const HloInstruction* instr) {
  Location loc = getLocation(instr);
  llvm::SmallVector<Value, 4> operands;
  size_t num_arguments, num_results;
  TF_RETURN_IF_ERROR(CreateOperands(instr, 1, TokenLoweringMode::kUseNull,
                                    operands, num_arguments, num_results));

  auto case_op =
      builder_.create<lmhlo::CaseOp>(loc, operands[0], instr->branch_count());

  for (int i = 0; i < instr->branch_count(); i++) {
    case_op.getBranches()[i].push_back(new mlir::Block());
    TF_RETURN_IF_ERROR(ImportAsLmhloRegion(instr->called_computations()[i],
                                           &case_op.getBranches()[i]));
  }

  return case_op;
}

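// Emits an lmhlo.while op. The operand is the buffer that holds the result of
// the condition computation's root (a pred); a statically known trip count
// from the backend config becomes the trip_count attribute; the condition and
// body computations are imported into the op's two regions. Roughly, in
// generic MLIR form (illustrative only; the exact printed syntax and types
// may differ):
//
//   "lmhlo.while"(%cond_result_buffer) ({
//     ...  // condition region
//   }, {
//     ...  // body region
//   }) {trip_count = 16 : i64} : (memref<i1>) -> ()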
xla::StatusOr<lmhlo::WhileOp> LhloDialectEmitter::EmitWhileOp(
    const xla::HloInstruction* instr) {
  Location loc = getLocation(instr);
  SmallVector<Value, 1> operands;
  TF_RETURN_IF_ERROR(GetOrCreateView(
      instr->called_computations()[1]->root_instruction(), &operands));
  TF_RET_CHECK(operands.size() == 1);

  TF_ASSIGN_OR_RETURN(auto config,
                      instr->backend_config<xla::WhileLoopBackendConfig>());
  mlir::IntegerAttr trip_count;
  if (config.has_known_trip_count()) {
    trip_count = builder_.getI64IntegerAttr(config.known_trip_count().n());
  }
  lmhlo::WhileOp while_op =
      builder_.create<lmhlo::WhileOp>(loc, operands[0], trip_count);

  while_op.getCond().push_back(new mlir::Block());
  while_op.getBody().push_back(new mlir::Block());
  TF_RETURN_IF_ERROR(ImportAsLmhloRegion(instr->called_computations()[1],
                                         &while_op.getCond()));

  TF_RETURN_IF_ERROR(ImportAsLmhloRegion(instr->called_computations()[0],
                                         &while_op.getBody()));

  return while_op;
}

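// Returns (and caches) a memref view into the buffer allocation slice assigned
// to (instr, shape_index). Constants are routed through EmitConstant, which
// maintains its own cache.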
StatusOr<Value> LhloDialectEmitter::GetOrCreateArrayView(
    const xla::HloInstruction* instr, const xla::Shape& current_shape,
    const xla::ShapeIndex& shape_index) {
  // For constants, the cache is managed inside EmitConstant since it can
  // be called either from here or when we see a top-level HloConstant instr.
  if (instr->IsConstant() && shape_index.empty()) {
    TF_ASSIGN_OR_RETURN(Value constant_memref, EmitConstant(instr));
    return constant_memref;
  }

  // Cache generated ViewOp and ReinterpretCastOp by (instruction,
  // shape_index).
  auto& cached_value = slices_[std::make_pair(instr, shape_index)];
  if (cached_value) {
    return cached_value;
  }

  // If the shape happens to have dynamic dimensions, create the memref using
  // the underlying static shape.
  // TODO(jurahul): Revisit this when we can model memrefs with dynamic shape
  // but static bounds in MLIR.
  const Shape static_shape = xla::ShapeUtil::MakeStaticShape(current_shape);

  TF_ASSIGN_OR_RETURN(Type out_type, xla::ConvertShapeToType<MemRefType>(
                                         static_shape, builder_));
  TF_ASSIGN_OR_RETURN(BufferAllocation::Slice slice,
                      assignment_.GetUniqueSlice(instr, shape_index));
  Value alloc = allocations_[slice.allocation()];

  // TODO(timshen): revisit location handling.
  Location loc = builder_.getUnknownLoc();

  Value byte_shift =
      builder_.create<arith::ConstantIndexOp>(alloc.getLoc(), slice.offset());

  xla::Shape physical_shape =
      xla::ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout(
          static_shape);
  TF_ASSIGN_OR_RETURN(
      Type physical_out_type,
      xla::ConvertShapeToType<MemRefType>(physical_shape, builder_));

  // ViewOp only takes memrefs without affine maps (layouts). Let ViewOp
  // produce the physical shape (where dimensions are ordered in major to
  // minor) first, then follow up with a MemRefReinterpretCast to cast the
  // resulting memref to the original layout.
  Value result =
      builder_.create<memref::ViewOp>(loc, physical_out_type, alloc, byte_shift,
                                      /*sizes=*/ValueRange{});
  if (result.getType() != out_type) {
    int64_t out_offset;
    SmallVector<int64_t, 4> out_strides;
    auto out_memref_type = out_type.dyn_cast<MemRefType>();
    if (!out_memref_type)
      return tensorflow::errors::Internal(
          "Expected memref type when creating a view for leaf type of a "
          "tuple.");
    if (failed(getStridesAndOffset(out_memref_type, out_strides, out_offset)))
      return tensorflow::errors::Internal(
          "Failed to get strides and offset from the output type.");
    result = builder_.create<memref::ReinterpretCastOp>(
        loc, out_memref_type, result, out_offset, out_memref_type.getShape(),
        out_strides);
  }
  return cached_value = result;
}

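// Recursive helper for GetOrCreateView: walks tuple shapes, emits one view per
// array leaf, and lowers token leaves according to `token_mode`.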
Status LhloDialectEmitter::GetOrCreateViewImpl(
    const HloInstruction* instr, const Shape& current_shape,
    xla::ShapeIndex* current_shape_index, SmallVectorImpl<Value>* values,
    TokenLoweringMode token_mode) {
  if (current_shape.IsTuple()) {
    for (int i = 0; i < current_shape.tuple_shapes().size(); ++i) {
      current_shape_index->push_back(i);
      TF_RETURN_IF_ERROR(
          GetOrCreateViewImpl(instr, current_shape.tuple_shapes(i),
                              current_shape_index, values, token_mode));
      current_shape_index->pop_back();
    }
    return ::tensorflow::OkStatus();
  }
  if (current_shape.IsArray()) {
    TF_ASSIGN_OR_RETURN(auto v, GetOrCreateArrayView(instr, current_shape,
                                                     *current_shape_index));
    values->push_back(v);
    return ::tensorflow::OkStatus();
  }
  if (current_shape.IsToken()) {
    switch (token_mode) {
      case TokenLoweringMode::kFailToLower:
        return xla::InternalError(
            "Unexpected token kind for %s and shape index %s",
            instr->ToString(), current_shape_index->ToString());

      case TokenLoweringMode::kUseNull:
        values->push_back(Value{});
        return ::tensorflow::OkStatus();
    }
  }
  return xla::InternalError("Unexpected shape kind for %s and shape index %s",
                            instr->ToString(), current_shape_index->ToString());
}

// Returns a view for the result of an instruction.
// We first get a view for the slice in the allocation, and then may need to
// create another view to adjust the slice for the shape of the instruction.
Status LhloDialectEmitter::GetOrCreateView(const HloInstruction* instr,
                                           SmallVectorImpl<Value>* values,
                                           const xla::ShapeIndex& result_subset,
                                           TokenLoweringMode token_mode) {
  xla::ShapeIndex shape_index = result_subset;
  const Shape& sub_shape =
      xla::ShapeUtil::GetSubshape(instr->shape(), shape_index);
  return GetOrCreateViewImpl(instr, sub_shape, &shape_index, values,
                             token_mode);
}

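// Sets up the entry func.func for the LMHLO module: one i8 memref argument per
// non-thread-local buffer allocation, ordered so that entry parameters come
// first, with lmhlo.* argument attributes describing parameters, constants,
// and outputs. Also positions the builder so that subsequently emitted ops
// land before the function's terminator.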
Status LhloDialectEmitter::Initialize() {
  TF_RET_CHECK(computation_.IsEntryComputation());

  mlir::IntegerAttr unique_id =
      builder_.getI32IntegerAttr(computation_.parent()->unique_id());
  module_->setAttr("hlo.unique_id", unique_id);
  std::string function_name =
      computation_.name().empty() ? "__compute" : computation_.name();

  // Create the function as () -> (); we'll compute the arguments from the
  // buffer allocations and update the function type afterwards.
  auto func_op = func::FuncOp::create(builder_.getUnknownLoc(), function_name,
                                      builder_.getFunctionType({}, {}));

  {
    // This is an optional attribute used by the XLA backend. If the resulting
    // LMHLO doesn't go through XLA, this is not needed.
    const Shape& shape = computation_.root_instruction()->shape();
    func_op->setAttr(
        "result_xla_shape",
        builder_.getStringAttr(shape.ToString(/*print_layout=*/true)));
  }
  Block* block = func_op.addEntryBlock();

  llvm::SmallVector<const BufferAllocation*, 8> ordered_allocations;
  for (const BufferAllocation& alloc : assignment_.Allocations())
    ordered_allocations.push_back(&alloc);

  if (computation_.IsEntryComputation()) {
    // Sort the rather arbitrarily ordered allocations to match the
    // input/output parameters. Specifically we want to sort buffer
    // allocations in the following order:
    // * Parameters always order before non-parameters.
    // * Different parameters order by parameter number.
    // * Different allocations for the same parameter order by the shape index.
    //
    // TODO(timshen): there should be only one non-parameter buffer, the temp
    // buffer. Check on that.
    const auto allocation_comparator = [](const BufferAllocation* lhs,
                                          const BufferAllocation* rhs) {
      if (lhs->is_entry_computation_parameter() !=
          rhs->is_entry_computation_parameter()) {
        return lhs->is_entry_computation_parameter() >
               rhs->is_entry_computation_parameter();
      }
      if (lhs->is_entry_computation_parameter()) {
        return std::tuple<int, const xla::ShapeIndex&>(
                   lhs->parameter_number(), lhs->param_shape_index()) <
               std::tuple<int, const xla::ShapeIndex&>(
                   rhs->parameter_number(), rhs->param_shape_index());
      }
      return false;
    };

    std::stable_sort(ordered_allocations.begin(), ordered_allocations.end(),
                     allocation_comparator);
  }

  absl::flat_hash_map<const BufferAllocation*,
                      std::pair<const Shape*, xla::ShapeIndex>>
      allocation_to_output_info;
  TF_RETURN_IF_ERROR(xla::ShapeUtil::ForEachSubshapeWithStatus(
      computation_.root_instruction()->shape(),
      [&](const Shape& sub_shape, xla::ShapeIndex index) -> Status {
        TF_ASSIGN_OR_RETURN(
            auto slice,
            assignment_.GetUniqueSlice(computation_.root_instruction(), index));
        const BufferAllocation* alloc = slice.allocation();
        TF_RET_CHECK(slice.offset() == 0);
        TF_RET_CHECK(slice.size() == alloc->size());
        allocation_to_output_info[alloc] = std::make_pair(&sub_shape, index);
        return ::tensorflow::OkStatus();
      }));

  // The function signature will be composed of:
  // - one memref for each of the parameters.
  // - one memref for each other buffer allocation.
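  //
  // For illustration only (hypothetical sizes, names, and shapes; the exact
  // printed form may differ), an entry computation with one parameter and one
  // array result would yield a function roughly like:
  //
  //   func.func @main(
  //       %arg0: memref<64xi8> {lmhlo.params = 0 : index},
  //       %arg1: memref<64xi8> {lmhlo.output_index = dense<> : tensor<0xi64>})
  //       attributes {result_xla_shape = "f32[4,4]{1,0}"} {
  //     ...
  //     "lmhlo.terminator"() : () -> ()
  //   }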
  llvm::SmallVector<DictionaryAttr, 8> args_attrs;
  for (const BufferAllocation* alloc : ordered_allocations) {
    if (alloc->is_thread_local()) {
      continue;
    }

    // These attributes are optional and only help the program run through XLA.
    // XLA defines ExecutionInput and ExecutionOutput structures to carry
    // input-output type and buffer information, so any information they need
    // (mainly the type structure, potentially containing tuples) must be
    // preserved. The attributes are not needed if the generated LMHLO is not
    // sent to XLA.
    NamedAttrList arg_attr_list;
    mlir::Type arg_type = MemRefType::get({alloc->size()}, i8_type_);

    // Propagate source location information for every HloInstruction that
    // uses this allocation.
    std::vector<mlir::Location> buf_locs;
    buf_locs.reserve(alloc->assigned_buffers().size());
    for (const auto& entry : alloc->assigned_buffers()) {
      const xla::HloValue* hlo_value = entry.first;
      buf_locs.push_back(getLocation(hlo_value->instruction()));
    }
    mlir::Location loc = builder_.getFusedLoc(buf_locs);

    if (alloc->is_entry_computation_parameter()) {
      arg_attr_list.set("lmhlo.params",
                        builder_.getIndexAttr(alloc->parameter_number()));
      if (!alloc->param_shape_index().empty()) {
        arg_attr_list.set("lmhlo.param_shape_index",
                          builder_.getI64TensorAttr(llvm::makeArrayRef(
                              alloc->param_shape_index().begin(),
                              alloc->param_shape_index().end())));
      }
    }
    // Optional: an attribute used for optimization. If a kernel uses this
    // allocation and the allocation carries lmhlo.constant_name, the kernel
    // will instead use the global value indicated by that name, which enables
    // further optimizations (e.g. constant propagation).
    if (alloc->is_constant()) {
      arg_attr_list.set(
          "lmhlo.constant_name",
          builder_.getStringAttr(
              xla::llvm_ir::ConstantBufferAllocationToGlobalName(*alloc)));
    }
    auto iter = allocation_to_output_info.find(alloc);
    if (iter != allocation_to_output_info.end()) {
      const Shape* sub_shape = iter->second.first;
      const xla::ShapeIndex& shape_index = iter->second.second;
      if (!sub_shape->IsArray()) {
        continue;
      }
      arg_attr_list.set("lmhlo.output_index",
                        builder_.getI64TensorAttr(llvm::makeArrayRef(
                            shape_index.begin(), shape_index.end())));
      if (auto alias = computation_.parent()
                           ->input_output_alias_config()
                           .GetAliasedParameter(shape_index)) {
        if (alias->must_alias()) {
          arg_attr_list.set("lmhlo.must_alias", builder_.getUnitAttr());
        }
      }
    }
    block->addArgument(arg_type, loc);
    allocations_[alloc] = block->getArguments().back();
    args_attrs.push_back(arg_attr_list.getDictionary(builder_.getContext()));
  }

  FunctionType function_type =
      builder_.getFunctionType(block->getArgumentTypes(), {});
  func_op.setType(function_type);
  func_op.setAllArgAttrs(args_attrs);

  SymbolTable symbol_table(module_);
  symbol_table.insert(func_op);
  builder_.setInsertionPointToEnd(block);

  auto return_op =
      builder_.create<lmhlo::TerminatorOp>(builder_.getUnknownLoc());
  builder_ = OpBuilder(return_op);

  return ::tensorflow::OkStatus();
}

std::unique_ptr<OperationPass<ModuleOp>> createXlaHloToLhloWithXlaPass() {
  return std::make_unique<XlaHloToLhloPass>();
}

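// Converts a buffer-assigned HloModule into an LMHLO module: loads the
// required dialects, emits the entry computation in the sequential order
// chosen by buffer assignment, and verifies the resulting module.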
Status HloToLhloModule(const BufferAssignment& assignment,
                       const HloModule& hlo_module, ModuleOp module) {
  module.getContext()
      ->loadDialect<arith::ArithmeticDialect,
                    bufferization::BufferizationDialect, func::FuncDialect,
                    memref::MemRefDialect, mhlo::MhloDialect,
                    lmhlo::LmhloDialect, lmhlo_gpu::LmhloGpuDialect>();

  module->setLoc(mlir::NameLoc::get(
      mlir::StringAttr::get(module.getContext(), hlo_module.name())));

  // Store the HloModule's unique_id in the MLIR module.
  Builder builder(module.getContext());
  module->setAttr("mhlo.unique_id",
                  builder.getI64IntegerAttr(hlo_module.unique_id()));

  const HloComputation* computation = hlo_module.entry_computation();

  LhloDialectEmitter emitter(assignment, *computation, module);
  TF_RETURN_IF_ERROR(emitter.Initialize());

  const xla::HloInstructionSequence* schedule =
      assignment.hlo_ordering().SequentialOrder(*computation);
  if (!schedule)
    return xla::Unimplemented("Missing sequential order for the computation");

  StatusScopedDiagnosticHandler status_handler(module.getContext());

  const std::vector<HloInstruction*>& ordering = schedule->instructions();
  TF_RETURN_IF_ERROR(computation->AcceptOrdered(&emitter, ordering));
  TF_RETURN_IF_ERROR(status_handler.ConsumeStatus());

  (void)mlir::verify(module);
  return status_handler.ConsumeStatus();
}

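// Translation hook for mlir-translate: parses the HLO text into an HloModule,
// then (optionally optimizing the HLO) converts it to an LMHLO module via
// OptimizeAndConvertHloToLmhlo, targeting the Host platform.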
OwningOpRef<mlir::ModuleOp> HloTextToLhloTranslateFunction(
    llvm::StringRef input, MLIRContext* context, bool optimize_xla_hlo) {
  StatusOr<std::unique_ptr<HloModule>> maybe_module =
      xla::ParseAndReturnUnverifiedModule(
          absl::string_view(input.data(), input.size()));
  TF_CHECK_OK(maybe_module.status());

  OwningOpRef<mlir::ModuleOp> module =
      ModuleOp::create(UnknownLoc::get(context));

  TF_CHECK_OK(OptimizeAndConvertHloToLmhlo(
      std::move(maybe_module).value(), module.get(), "Host", optimize_xla_hlo));

  return module;
}

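// Registers the XlaHloToLhloPass with the global MLIR pass registry.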
void RegisterMhloToLhloWithXlaPass() {
  static PassRegistration<XlaHloToLhloPass> registration;
}

}  // namespace mlir