/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

//===- kernel_creator.cc ----------------------------------------*- C++ -*-===//
//
// This file implements the function to compile a TF kernel function to a GPU
// binary (hsaco for AMD, cubin for NVIDIA) or to a GPU binary together with
// the host-side code.
//
//===----------------------------------------------------------------------===//
22 #include "tensorflow/compiler/mlir/tools/kernel_gen/kernel_creator.h"
23 
24 #include <string>
25 
26 #include "mlir/Conversion/AffineToStandard/AffineToStandard.h"  // from @llvm-project
27 #include "mlir/Conversion/ComplexToStandard/ComplexToStandard.h"  // from @llvm-project
28 #include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVMPass.h"  // from @llvm-project
29 #include "mlir/Conversion/GPUCommon/GPUCommonPass.h"  // from @llvm-project
30 #include "mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h"  // from @llvm-project
31 #include "mlir/Conversion/ReconcileUnrealizedCasts/ReconcileUnrealizedCasts.h"  // from @llvm-project
32 #include "mlir/Conversion/SCFToControlFlow/SCFToControlFlow.h"  // from @llvm-project
33 #include "mlir/Conversion/SCFToGPU/SCFToGPUPass.h"  // from @llvm-project
34 #include "mlir/Conversion/ShapeToStandard/ShapeToStandard.h"  // from @llvm-project
35 #include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"  // from @llvm-project
36 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"  // from @llvm-project
37 #include "mlir/Dialect/Arithmetic/Transforms/Passes.h"  // from @llvm-project
38 #include "mlir/Dialect/Bufferization/Transforms/Passes.h"  // from @llvm-project
39 #include "mlir/Dialect/GPU/IR/GPUDialect.h"  // from @llvm-project
40 #include "mlir/Dialect/GPU/Transforms/Passes.h"  // from @llvm-project
41 #include "mlir/Dialect/LLVMIR/Transforms/OptimizeForNVVM.h"  // from @llvm-project
42 #include "mlir/Dialect/Linalg/Passes.h"  // from @llvm-project
43 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"  // from @llvm-project
44 #include "mlir/Dialect/MemRef/Transforms/Passes.h"  // from @llvm-project
45 #include "mlir/Dialect/SCF/Transforms/Passes.h"  // from @llvm-project
46 #include "mlir/Parser/Parser.h"  // from @llvm-project
47 #include "mlir/Pass/Pass.h"  // from @llvm-project
48 #include "mlir/Pass/PassManager.h"  // from @llvm-project
49 #include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h"  // from @llvm-project
50 #include "mlir/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.h"  // from @llvm-project
51 #include "mlir/Target/LLVMIR/Dialect/ROCDL/ROCDLToLLVMIRTranslation.h"  // from @llvm-project
52 #include "mlir/Transforms/DialectConversion.h"  // from @llvm-project
53 #include "mlir/Transforms/Passes.h"  // from @llvm-project
54 #include "tensorflow/compiler/mlir/tensorflow/dialect_registration.h"
55 #include "tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.h"
56 #include "tensorflow/compiler/mlir/tools/kernel_gen/transforms/passes.h"
57 #include "tensorflow/compiler/mlir/tools/kernel_gen/transforms/rewriters.h"
58 #include "tensorflow/compiler/mlir/xla/transforms/passes.h"
59 #include "tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.h"
60 #include "tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
61 #include "tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/mhlo/transforms/passes.h"
62 #include "tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Transforms/gpu_passes.h"
63 #include "tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Transforms/passes.h"
64 #include "tensorflow/core/platform/statusor.h"
65 
namespace tensorflow {
namespace kernel_gen {
namespace {

using mlir::Value;
using mlir::func::FuncOp;
using mlir::memref::RankOp;

constexpr llvm::StringRef kGpuBinaryAttrName = "gpu.binary";

/// Check whether the size of the allocation is below `kMaximumSizeInBytes`.
/// The transformation is only applied to small buffers, since large buffers
/// could exceed the stack space.
bool IsSmallAlloc(Value alloc) {
  constexpr unsigned kMaximumSizeInBytes = 64;
  constexpr unsigned kMaxRankOfAllocatedMemRef = 1;

  auto type = alloc.getType().dyn_cast<mlir::ShapedType>();
  if (!type || !alloc.getDefiningOp<mlir::memref::AllocOp>()) return false;
  if (!type.hasStaticShape()) {
    // Check if the dynamic shape dimension of the alloc is produced by RankOp
    // or SelectOp(_, RankOp, RankOp).
    // If this is the case, it is likely to be small. Furthermore, the dimension
    // is limited to the maximum rank of the allocated memref to avoid large
    // values by multiplying several small values.
    if (type.getRank() <= kMaxRankOfAllocatedMemRef) {
      for (Value alloc_arg : alloc.getDefiningOp()->getOperands()) {
        if (auto select = alloc_arg.getDefiningOp<mlir::arith::SelectOp>()) {
          if (!select.getTrueValue().getDefiningOp<RankOp>() ||
              !select.getFalseValue().getDefiningOp<RankOp>())
            return false;
        } else if (!alloc_arg.getDefiningOp<RankOp>()) {
          return false;
        }
      }
      return true;
    }
    return false;
  }
  unsigned bitwidth = mlir::DataLayout::closest(alloc.getDefiningOp())
                          .getTypeSizeInBits(type.getElementType());
  return type.getNumElements() * bitwidth <= kMaximumSizeInBytes * 8;
}

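// Lowers TF operations to kernel_gen JIT invocation ops, embeds the TF
// framework, and bufferizes the result. This is the path taken when the
// kernel is JIT-compiled at runtime instead of being lowered to loops here.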
Status LowerTFToJITInvocation(mlir::ModuleOp module,
                              llvm::ArrayRef<int64_t> tile_sizes,
                              llvm::ArrayRef<int64_t> unroll_factors,
                              int64_t max_supported_rank, bool enable_ftz,
                              bool index_64bit,
                              bool jit_i64_indexed_for_large_tensors,
                              bool apply_cl_options) {
  mlir::PassManager pm(module.getContext());
  if (apply_cl_options) applyTensorflowAndCLOptions(pm);

  pm.addNestedPass<FuncOp>(
      mlir::kernel_gen::transforms::CreateTFToJITInvocationPass(
          tile_sizes, unroll_factors, max_supported_rank, enable_ftz,
          index_64bit, jit_i64_indexed_for_large_tensors));
  pm.addPass(mlir::kernel_gen::tf_framework::CreateEmbedTFFrameworkPass());
  pm.addNestedPass<FuncOp>(mlir::createLinalgInitTensorToAllocTensorPass());
  pm.addPass(mlir::createComputeOpAndFuncBufferizePass());

  pm.addPass(mlir::createFinalBufferizePass(
      /*alignment=*/64,
      mlir::kernel_gen::transforms::populateExtraBufferizeDialects,
      mlir::kernel_gen::transforms::populateExtraBufferizePatterns));

  if (failed(pm.run(module))) {
    return tensorflow::errors::Internal(
        "Lowering TF to JIT invocation failed.");
  }
  return OkStatus();
}

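// Lowers TF operations via CHLO/MHLO and Linalg to tiled and unrolled
// parallel loops, performing rank specialization, elementwise fusion, and
// partial bufferization along the way.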
Status LowerTFtoLoops(mlir::ModuleOp module, llvm::ArrayRef<int64_t> tile_sizes,
                      llvm::ArrayRef<int64_t> unroll_factors,
                      int64_t max_supported_rank, bool enable_ftz,
                      bool index_64bit, bool jit_i64_indexed_for_large_tensors,
                      bool apply_cl_options) {
  mlir::PassManager pm(module.getContext());
  if (apply_cl_options) applyTensorflowAndCLOptions(pm);
  if (jit_i64_indexed_for_large_tensors) {
    pm.addNestedPass<FuncOp>(
        mlir::kernel_gen::transforms::CreateTFToJITInvocationPass(
            tile_sizes, unroll_factors, max_supported_rank, enable_ftz,
            index_64bit,
            /*jit_i64_indexed_for_large_tensors=*/true));
  }
  pm.addNestedPass<FuncOp>(mlir::mhlo::createLegalizeTFNoFallbackPass(
      /*allow_partial_conversion=*/true));
  pm.addNestedPass<FuncOp>(mlir::mhlo::createRankSpecializationClusterPass());
  pm.addNestedPass<FuncOp>(
      mlir::mhlo::createRankSpecializationToSCFPass(max_supported_rank));
  pm.addNestedPass<FuncOp>(mlir::mhlo::createChloLegalizeToHloPass());

  pm.addNestedPass<FuncOp>(mlir::createCanonicalizerPass());
  pm.addNestedPass<FuncOp>(mlir::createCSEPass());
  pm.addNestedPass<FuncOp>(mlir::createCanonicalizerPass());
  pm.addNestedPass<FuncOp>(mlir::createShapeSimplification());
  pm.addNestedPass<FuncOp>(mlir::mhlo::createMergeAssumingOpsPass());
  pm.addNestedPass<FuncOp>(mlir::mhlo::createBroadcastPropagationPass());
  pm.addNestedPass<FuncOp>(mlir::createCanonicalizerPass());
  pm.addNestedPass<FuncOp>(mlir::createCSEPass());

  // Transform HLO operations to LinAlg and standard.
  pm.addNestedPass<FuncOp>(::mlir::mhlo::createLegalizeHloToLinalgPass());
  pm.addPass(::mlir::mhlo::createLegalizeToArithmeticPass());
  pm.addNestedPass<FuncOp>(
      mlir::mhlo::createLegalizeHloShapeOpsToStandardPass());

  // Remove the remaining references to unsigned types after all HLO compute
  // operations were converted.
  pm.addPass(mlir::mhlo::createConvertToSignlessPass());

  pm.addPass(mlir::createCanonicalizerPass());
  pm.addNestedPass<FuncOp>(mlir::createCSEPass());

  // Convert operations from the Complex dialect to the Standard/Math dialects.
  pm.addNestedPass<FuncOp>(::mlir::createConvertComplexToStandardPass());

  // Fuse linalg operations.
  pm.addPass(mlir::memref::createResolveShapedTypeResultDimsPass());
  pm.addPass(mlir::createCanonicalizerPass());
  pm.addNestedPass<FuncOp>(mlir::createLinalgElementwiseOpFusionPass());

  // Partial bufferization: Transforms in particular HLO and Linalg operations
  // to their corresponding LHLO operations and converts the function
  // signature. Leaves shape operations untouched.
  //
  // TODO(pifon): Rename the pass to CreateHloLinalgBufferizePass or bufferize
  // in 2 steps: first Linalg, then Hlo. That would need refactoring of
  // BufferizeTypeConverter.
  pm.addNestedPass<FuncOp>(mlir::createLinalgInitTensorToAllocTensorPass());
  pm.addPass(mlir::createComputeOpAndFuncBufferizePass());
  pm.addNestedPass<FuncOp>(::mlir::createCanonicalizerPass());
  pm.addNestedPass<FuncOp>(::mlir::createCSEPass());
  // Remove copies which are introduced by canonicalizing
  // BufferCastOp(TensorLoadOp).
  pm.addNestedPass<FuncOp>(
      mlir::kernel_gen::transforms::CreateCopyCleanupPass());
  // Find candidates for buffer reuse. This is only successful if buffer size
  // equality can be determined based on `linalg.generic` operations.
  pm.addNestedPass<FuncOp>(
      mlir::kernel_gen::transforms::CreateBufferReusePass());
  // Approximate Tanh using standard operations.
  pm.addNestedPass<FuncOp>(
      ::mlir::mhlo::createLegalizeTrigonometricToApproximationPass());
  // Transform the Linalg ops inside of the loop nest into parallel loops.
  pm.addNestedPass<FuncOp>(::mlir::createConvertLinalgToParallelLoopsPass());

  // Canonicalize the code to simplify index computations. This is needed so
  // that loop bounds have the same value.
  pm.addNestedPass<FuncOp>(::mlir::createCanonicalizerPass());
  // Run CSE to ensure that loads and stores to the same subview get
  // recognized as such.
  pm.addNestedPass<FuncOp>(::mlir::createCSEPass());
  // Collapse and tile parallel loops for GPU only.
  pm.addNestedPass<FuncOp>(mlir::createCollapseParallelLoopsTo1DPass());
  pm.addNestedPass<FuncOp>(
      mlir::createTileLoopsPass(tile_sizes, unroll_factors));

  pm.addNestedPass<FuncOp>(::mlir::createCanonicalizerPass());
  pm.addNestedPass<FuncOp>(::mlir::createCSEPass());
  if (failed(pm.run(module))) {
    return tensorflow::errors::Internal("Lowering TF to loops failed.");
  }
  return OkStatus();
}

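// Maps the parallel loops onto GPU hardware dimensions, finishes
// bufferization, and outlines the loop bodies into GPU kernel functions.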
Status LowerLoopsToGPU(mlir::ModuleOp module, bool embed_memref_prints,
                       bool index_64bit, bool apply_cl_options) {
  mlir::PassManager pm(module.getContext());
  if (apply_cl_options) applyTensorflowAndCLOptions(pm);

  // Greedily map the remaining loops to GPU hardware dimensions.
  pm.addNestedPass<FuncOp>(mlir::createGpuMapParallelLoopsPass());

  // Expand memref_reshape to its ranked form so that we can propagate
  // scalars and avoid allocation.
  pm.addNestedPass<FuncOp>(mlir::arith::createArithmeticExpandOpsPass());
  pm.addNestedPass<FuncOp>(mlir::memref::createExpandOpsPass());
  pm.addPass(mlir::createCanonicalizerPass());
  pm.addPass(mlir::kernel_gen::transforms::CreateShapeToDescriptorsPass());
  // Before bufferizing further, remove unused tensor_to_memref, so that we do
  // not create allocations for tensor computations that are not actually
  // needed.
  pm.addPass(mlir::createCanonicalizerPass());
  pm.addNestedPass<FuncOp>(mlir::createCSEPass());
  // Before inserting more allocs, map the ones we already have to the
  // tf runtime. That ensures that all allocations for the actual computation
  // end up on the device, whereas allocations for shape computation and host
  // side things remain on the host.
  // Longer term, this should be handled by proper device placement.
  pm.addPass(mlir::kernel_gen::tf_framework::CreateEmbedTFFrameworkPass());
  // Now lower the shape computations, bufferize all remaining ops and insert
  // deallocs.
  pm.addPass(mlir::createFinalBufferizePass(
      /*alignment=*/64,
      mlir::kernel_gen::transforms::populateExtraBufferizeDialects,
      mlir::kernel_gen::transforms::populateExtraBufferizePatterns));
  // TODO(herhut): Enable once no longer broken.
  pm.addNestedPass<FuncOp>(::mlir::bufferization::createBufferHoistingPass());
  pm.addNestedPass<FuncOp>(mlir::bufferization::createPromoteBuffersToStackPass(
      [](Value alloc) { return IsSmallAlloc(alloc); }));
  // Free all temporaries.
  pm.addNestedPass<FuncOp>(
      ::mlir::bufferization::createBufferDeallocationPass());
  pm.addPass(mlir::createCanonicalizerPass());

  // Apply the mapping and go to GPU. We cannot do this earlier due to missing
  // interfaces on the GPU dialect.
  // TODO(b/174830459): Move up once implemented.
  pm.addNestedPass<FuncOp>(mlir::createParallelLoopToGpuPass());

  // Some basic cleanup.
  pm.addNestedPass<FuncOp>(::mlir::createCanonicalizerPass());
  pm.addNestedPass<FuncOp>(::mlir::createCSEPass());
  // Make loops with min bounds into a conditional plus static bounds.
  pm.addNestedPass<FuncOp>(mlir::createForLoopSpecializationPass());
  // Rewrite launches with inlined bodies into launches of outlined kernels.
  pm.addPass(mlir::createGpuLauchSinkIndexComputationsPass());
  const std::string gpuDataLayoutSpec =
      index_64bit ? "#dlti.dl_spec<#dlti.dl_entry<index,64:i64>>"
                  : "#dlti.dl_spec<#dlti.dl_entry<index,32:i32>>";
  pm.addPass(mlir::createGpuKernelOutliningPass(gpuDataLayoutSpec));

  pm.addPass(::mlir::createLowerAffinePass());
  // Constraints are removed as late as possible and before lowering to CFG.
  pm.addNestedPass<FuncOp>(::mlir::createConvertShapeConstraintsPass());
  pm.addNestedPass<FuncOp>(::mlir::createCanonicalizerPass());
  pm.addPass(::mlir::createConvertSCFToCFPass());
  // Map asserts to the tensorflow framework.
  pm.addPass(mlir::kernel_gen::tf_framework::CreateRewriteTFFrameworkAssert());
  if (embed_memref_prints) {
    pm.addPass(mlir::kernel_gen::transforms::CreateEmbedMemRefPrintsPass());
  }
  if (failed(pm.run(module))) {
    return tensorflow::errors::Internal("Lowering to GPU kernels failed.");
  }
  return OkStatus();
}

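// Lowers the outlined GPU kernel modules to low-level device IR (ROCDL for
// AMD, NVVM for NVIDIA), depending on the build configuration.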
Status LowerKernelBodiesToLowLevelIr(mlir::ModuleOp module,
                                     bool apply_cl_options) {
#if !defined(TENSORFLOW_USE_ROCM) && !defined(GOOGLE_CUDA)
  return tensorflow::errors::Internal(
      "Neither TENSORFLOW_USE_ROCM nor GOOGLE_CUDA are defined."
      " Did you specify either --config=rocm or --config=cuda ?");
#endif

#if TENSORFLOW_USE_ROCM
  auto gpu_modules = module.getOps<::mlir::gpu::GPUModuleOp>();
  for (::mlir::gpu::GPUModuleOp gpu_module : gpu_modules) {
    gpu_module.walk([&](mlir::gpu::GPUFuncOp gpu_kernel) {
      if (gpu_kernel.isKernel()) {
        gpu_kernel->setAttr(
            "rocdl.max_flat_work_group_size",
            mlir::IntegerAttr::get(
                mlir::IntegerType::get(module.getContext(), 32), 1024));
      }
    });
  }
#endif

  mlir::PassManager pm(module.getContext());
  // We cannot verify as the signature of the kernel is rewritten.
  // pm.enableVerifier(false);
  if (apply_cl_options) tensorflow::applyTensorflowAndCLOptions(pm);
  auto& kernelPm = pm.nest<::mlir::gpu::GPUModuleOp>();
  kernelPm.addPass(::mlir::createConvertSCFToCFPass());
#if TENSORFLOW_USE_ROCM
  kernelPm.addPass(mlir::createGpuKernelToRocdlPass());
#elif GOOGLE_CUDA
  kernelPm.addPass(mlir::createGpuKernelToNvvmPass());
  kernelPm.addPass(mlir::NVVM::createOptimizeForTargetPass());
#endif
  // Remove all location information to prevent a debug build.
  pm.addPass(::mlir::createStripDebugInfoPass());

  if (failed(pm.run(module))) {
    return tensorflow::errors::Internal(
        "Lowering to low-level device IR failed.");
  }

  return OkStatus();
}

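// Propagates statically known shape and TF framework ABI information into
// the generated kernels.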
Status AmendKernelLLVMIRWithStaticKnowledge(mlir::ModuleOp module,
                                            bool apply_cl_options) {
  mlir::PassManager pm(module.getContext());
  if (apply_cl_options) applyTensorflowAndCLOptions(pm);

  pm.addNestedPass<FuncOp>(
      mlir::kernel_gen::transforms::CreatePropagateShapeKnowledgeToKernels());
  pm.addNestedPass<FuncOp>(
      mlir::kernel_gen::transforms::CreatePropagateTfAbiKnowledgeToKernels());

  return failed(pm.run(module))
             ? tensorflow::errors::Internal(
                   "Amending LLVMIR with static knowledge failed.")
             : OkStatus();
}

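// Translates the GPU kernel modules to device blobs (cubin/hsaco) for the
// requested architectures and attaches them to the module under
// `gpu_binary_attr_name`.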
Status GenerateDeviceCode(mlir::ModuleOp module,
                          llvm::StringRef gpu_binary_attr_name,
                          llvm::ArrayRef<std::string> architectures,
                          bool print_ptx, bool print_llvmir, bool enable_ftz,
                          bool apply_cl_options) {
  mlir::PassManager pm(module.getContext());
  if (apply_cl_options) applyTensorflowAndCLOptions(pm);
  mlir::registerLLVMDialectTranslation(*module->getContext());

  auto& kernel_pm = pm.nest<mlir::gpu::GPUModuleOp>();
  // Remove debug information to ensure we do not create debug PTX.
  kernel_pm.addPass(mlir::createStripDebugInfoPass());
  kernel_pm.addPass(mlir::kernel_gen::transforms::CreateGpuKernelToBlobPass(
      gpu_binary_attr_name, architectures, print_ptx, print_llvmir,
      enable_ftz));

  return failed(pm.run(module))
             ? tensorflow::errors::Internal("Generating device code failed.")
             : OkStatus();
}

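// Lowers the remaining host-side code, including the kernel launches that
// reference the embedded GPU binary, to the LLVM dialect.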
Status LowerHostSideToFinalForm(mlir::ModuleOp module, bool apply_cl_options) {
  mlir::PassManager pm(module.getContext());
  if (apply_cl_options) applyTensorflowAndCLOptions(pm);

  pm.addPass(mlir::kernel_gen::transforms::CreateTFKernelToLLVMPass(
      kGpuBinaryAttrName));
  pm.addPass(mlir::createReconcileUnrealizedCastsPass());
  pm.addPass(mlir::createCanonicalizerPass());
  pm.addPass(mlir::createCSEPass());

  return failed(pm.run(module)) ? tensorflow::errors::Internal(
                                      "Final lowering of host side failed.")
                                : OkStatus();
}

}  // namespace

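// Registers all dialects and translations needed by the pipeline with
// `context` and parses `tf_code` into an MLIR module.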
StatusOr<mlir::OwningOpRef<mlir::ModuleOp>> SetupContextAndParseModule(
    mlir::MLIRContext& context, llvm::StringRef tf_code) {
  mlir::DialectRegistry registry;
  mlir::RegisterAllTensorFlowDialects(registry);
  registry.insert<mlir::chlo::ChloDialect, mlir::mhlo::MhloDialect>();
  mlir::registerLLVMDialectTranslation(registry);
  mlir::registerNVVMDialectTranslation(registry);
  mlir::registerROCDLDialectTranslation(registry);
  context.appendDialectRegistry(registry);
  mlir::OwningOpRef<mlir::ModuleOp> module =
      mlir::parseSourceString<mlir::ModuleOp>(tf_code, &context);
  if (!module)
    return tensorflow::Status(tensorflow::error::Code::INVALID_ARGUMENT,
                              "invalid kernel IR");
  return module;
}

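// End-to-end entry point: parses `tf_code` and runs either the JIT-invocation
// pipeline or the full pipeline (loops -> GPU -> device blobs), followed by
// the final host-side lowering.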
StatusOr<mlir::OwningOpRef<mlir::ModuleOp>> GenerateKernelForTfCode(
    mlir::MLIRContext& context, llvm::StringRef tf_code,
    llvm::ArrayRef<std::string> architectures,
    llvm::ArrayRef<int64_t> tile_sizes, llvm::ArrayRef<int64_t> unroll_factors,
    int64_t max_supported_rank, bool embed_memref_prints, bool print_ptx,
    bool print_llvmir, bool enable_ftz, bool index_64bit, bool jit_compile,
    bool jit_i64_indexed_for_large_tensors, bool apply_cl_options) {
  TF_ASSIGN_OR_RETURN(mlir::OwningOpRef<mlir::ModuleOp> module,
                      SetupContextAndParseModule(context, tf_code));

  if (jit_compile) {
    TF_RETURN_IF_ERROR(LowerTFToJITInvocation(
        module.get(), tile_sizes, unroll_factors, max_supported_rank,
        enable_ftz, index_64bit,
        /*jit_i64_indexed_for_large_tensors=*/false, apply_cl_options));
  } else {
    TF_RETURN_IF_ERROR(
        LowerTFtoLoops(module.get(), tile_sizes, unroll_factors,
                       max_supported_rank, enable_ftz, index_64bit,
                       jit_i64_indexed_for_large_tensors, apply_cl_options));
    TF_RETURN_IF_ERROR(LowerLoopsToGPU(module.get(), embed_memref_prints,
                                       index_64bit, apply_cl_options));
    TF_RETURN_IF_ERROR(
        LowerKernelBodiesToLowLevelIr(module.get(), apply_cl_options));
    TF_RETURN_IF_ERROR(
        AmendKernelLLVMIRWithStaticKnowledge(module.get(), apply_cl_options));
    TF_RETURN_IF_ERROR(GenerateDeviceCode(
        module.get(), kGpuBinaryAttrName, architectures, print_ptx,
        print_llvmir, enable_ftz, apply_cl_options));
  }

  TF_RETURN_IF_ERROR(LowerHostSideToFinalForm(module.get(), apply_cl_options));

  return module;
}

}  // namespace kernel_gen
}  // namespace tensorflow