/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

//===- kernel_creator.cc ----------------------------------------*- C++ -*-===//
//
// This file implements the function to compile a TF kernel function to a GPU
// binary (hsaco for AMD, cubin for NVIDIA), optionally together with the
// host-side code.
//
//===----------------------------------------------------------------------===//
#include "tensorflow/compiler/mlir/tools/kernel_gen/kernel_creator.h"

#include <string>

#include "mlir/Conversion/AffineToStandard/AffineToStandard.h"  // from @llvm-project
#include "mlir/Conversion/ComplexToStandard/ComplexToStandard.h"  // from @llvm-project
#include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVMPass.h"  // from @llvm-project
#include "mlir/Conversion/GPUCommon/GPUCommonPass.h"  // from @llvm-project
#include "mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h"  // from @llvm-project
#include "mlir/Conversion/ReconcileUnrealizedCasts/ReconcileUnrealizedCasts.h"  // from @llvm-project
#include "mlir/Conversion/SCFToControlFlow/SCFToControlFlow.h"  // from @llvm-project
#include "mlir/Conversion/SCFToGPU/SCFToGPUPass.h"  // from @llvm-project
#include "mlir/Conversion/ShapeToStandard/ShapeToStandard.h"  // from @llvm-project
#include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"  // from @llvm-project
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"  // from @llvm-project
#include "mlir/Dialect/Arithmetic/Transforms/Passes.h"  // from @llvm-project
#include "mlir/Dialect/Bufferization/Transforms/Passes.h"  // from @llvm-project
#include "mlir/Dialect/GPU/IR/GPUDialect.h"  // from @llvm-project
#include "mlir/Dialect/GPU/Transforms/Passes.h"  // from @llvm-project
#include "mlir/Dialect/LLVMIR/Transforms/OptimizeForNVVM.h"  // from @llvm-project
#include "mlir/Dialect/Linalg/Passes.h"  // from @llvm-project
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"  // from @llvm-project
#include "mlir/Dialect/MemRef/Transforms/Passes.h"  // from @llvm-project
#include "mlir/Dialect/SCF/Transforms/Passes.h"  // from @llvm-project
#include "mlir/Parser/Parser.h"  // from @llvm-project
#include "mlir/Pass/Pass.h"  // from @llvm-project
#include "mlir/Pass/PassManager.h"  // from @llvm-project
#include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h"  // from @llvm-project
#include "mlir/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.h"  // from @llvm-project
#include "mlir/Target/LLVMIR/Dialect/ROCDL/ROCDLToLLVMIRTranslation.h"  // from @llvm-project
#include "mlir/Transforms/DialectConversion.h"  // from @llvm-project
#include "mlir/Transforms/Passes.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/dialect_registration.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.h"
#include "tensorflow/compiler/mlir/tools/kernel_gen/transforms/passes.h"
#include "tensorflow/compiler/mlir/tools/kernel_gen/transforms/rewriters.h"
#include "tensorflow/compiler/mlir/xla/transforms/passes.h"
#include "tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.h"
#include "tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#include "tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/mhlo/transforms/passes.h"
#include "tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Transforms/gpu_passes.h"
#include "tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Transforms/passes.h"
#include "tensorflow/core/platform/statusor.h"

namespace tensorflow {
namespace kernel_gen {
namespace {

using mlir::Value;
using mlir::func::FuncOp;
using mlir::memref::RankOp;

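// Name of the attribute under which the compiled GPU binary blob is attached
// to the GPU module.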
constexpr llvm::StringRef kGpuBinaryAttrName = "gpu.binary";

/// Check if the size of the allocation is less than the given size. The
/// transformation is only applied to small buffers since large buffers could
/// exceed the stack space.
bool IsSmallAlloc(Value alloc) {
  constexpr unsigned kMaximumSizeInBytes = 64;
  constexpr unsigned kMaxRankOfAllocatedMemRef = 1;

  auto type = alloc.getType().dyn_cast<mlir::ShapedType>();
  if (!type || !alloc.getDefiningOp<mlir::memref::AllocOp>()) return false;
  if (!type.hasStaticShape()) {
    // Check if the dynamic shape dimension of the alloc is produced by RankOp
    // or SelectOp(_, RankOp, RankOp).
    // If this is the case, it is likely to be small. Furthermore, the dimension
    // is limited to the maximum rank of the allocated memref to avoid large
    // values by multiplying several small values.
    if (type.getRank() <= kMaxRankOfAllocatedMemRef) {
      for (Value alloc_arg : alloc.getDefiningOp()->getOperands()) {
        if (auto select = alloc_arg.getDefiningOp<mlir::arith::SelectOp>()) {
          if (!select.getTrueValue().getDefiningOp<RankOp>() ||
              !select.getFalseValue().getDefiningOp<RankOp>())
            return false;
        } else if (!alloc_arg.getDefiningOp<RankOp>()) {
          return false;
        }
      }
      return true;
    }
    return false;
  }
  unsigned bitwidth = mlir::DataLayout::closest(alloc.getDefiningOp())
                          .getTypeSizeInBits(type.getElementType());
  return type.getNumElements() * bitwidth <= kMaximumSizeInBytes * 8;
}

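// Runs a pass pipeline that rewrites TF ops into kernel-generator JIT
// invocations, embeds the TF framework, and bufferizes the result.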
Status LowerTFToJITInvocation(mlir::ModuleOp module,
                              llvm::ArrayRef<int64_t> tile_sizes,
                              llvm::ArrayRef<int64_t> unroll_factors,
                              int64_t max_supported_rank, bool enable_ftz,
                              bool index_64bit,
                              bool jit_i64_indexed_for_large_tensors,
                              bool apply_cl_options) {
  mlir::PassManager pm(module.getContext());
  if (apply_cl_options) applyTensorflowAndCLOptions(pm);

  pm.addNestedPass<FuncOp>(
      mlir::kernel_gen::transforms::CreateTFToJITInvocationPass(
          tile_sizes, unroll_factors, max_supported_rank, enable_ftz,
          index_64bit, jit_i64_indexed_for_large_tensors));
  pm.addPass(mlir::kernel_gen::tf_framework::CreateEmbedTFFrameworkPass());
  pm.addNestedPass<FuncOp>(mlir::createLinalgInitTensorToAllocTensorPass());
  pm.addPass(mlir::createComputeOpAndFuncBufferizePass());

  pm.addPass(mlir::createFinalBufferizePass(
      /*alignment=*/64,
      mlir::kernel_gen::transforms::populateExtraBufferizeDialects,
      mlir::kernel_gen::transforms::populateExtraBufferizePatterns));

  if (failed(pm.run(module))) {
    return tensorflow::errors::Internal(
        "Lowering TF to JIT invocation failed.");
  }
  return OkStatus();
}

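// Lowers TF ops through MHLO and Linalg to tiled parallel loops and bufferizes
// the intermediate tensors.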
Status LowerTFtoLoops(mlir::ModuleOp module, llvm::ArrayRef<int64_t> tile_sizes,
                      llvm::ArrayRef<int64_t> unroll_factors,
                      int64_t max_supported_rank, bool enable_ftz,
                      bool index_64bit, bool jit_i64_indexed_for_large_tensors,
                      bool apply_cl_options) {
  mlir::PassManager pm(module.getContext());
  if (apply_cl_options) applyTensorflowAndCLOptions(pm);
  if (jit_i64_indexed_for_large_tensors) {
    pm.addNestedPass<FuncOp>(
        mlir::kernel_gen::transforms::CreateTFToJITInvocationPass(
            tile_sizes, unroll_factors, max_supported_rank, enable_ftz,
            index_64bit,
            /*jit_i64_indexed_for_large_tensors=*/true));
  }
  pm.addNestedPass<FuncOp>(mlir::mhlo::createLegalizeTFNoFallbackPass(
      /*allow_partial_conversion=*/true));
  pm.addNestedPass<FuncOp>(mlir::mhlo::createRankSpecializationClusterPass());
  pm.addNestedPass<FuncOp>(
      mlir::mhlo::createRankSpecializationToSCFPass(max_supported_rank));
  pm.addNestedPass<FuncOp>(mlir::mhlo::createChloLegalizeToHloPass());

  pm.addNestedPass<FuncOp>(mlir::createCanonicalizerPass());
  pm.addNestedPass<FuncOp>(mlir::createCSEPass());
  pm.addNestedPass<FuncOp>(mlir::createCanonicalizerPass());
  pm.addNestedPass<FuncOp>(mlir::createShapeSimplification());
  pm.addNestedPass<FuncOp>(mlir::mhlo::createMergeAssumingOpsPass());
  pm.addNestedPass<FuncOp>(mlir::mhlo::createBroadcastPropagationPass());
  pm.addNestedPass<FuncOp>(mlir::createCanonicalizerPass());
  pm.addNestedPass<FuncOp>(mlir::createCSEPass());

  // Transform HLO operations to Linalg and standard.
  pm.addNestedPass<FuncOp>(::mlir::mhlo::createLegalizeHloToLinalgPass());
  pm.addPass(::mlir::mhlo::createLegalizeToArithmeticPass());
  pm.addNestedPass<FuncOp>(
      mlir::mhlo::createLegalizeHloShapeOpsToStandardPass());

  // Remove the remaining references to unsigned types after all HLO compute
  // operations were converted.
  pm.addPass(mlir::mhlo::createConvertToSignlessPass());

  pm.addPass(mlir::createCanonicalizerPass());
  pm.addNestedPass<FuncOp>(mlir::createCSEPass());

  // Convert operations from the Complex dialect to the Standard/Math dialects.
  pm.addNestedPass<FuncOp>(::mlir::createConvertComplexToStandardPass());

  // Fuse linalg operations.
  pm.addPass(mlir::memref::createResolveShapedTypeResultDimsPass());
  pm.addPass(mlir::createCanonicalizerPass());
  pm.addNestedPass<FuncOp>(mlir::createLinalgElementwiseOpFusionPass());

  // Partial bufferization: Transforms in particular HLO and Linalg operations
  // to their corresponding LHLO operations and converts the function
  // signature. Leaves shape operations untouched.
  //
  // TODO(pifon): Rename the pass to CreateHloLinalgBufferizePass or bufferize
  // in 2 steps: first Linalg, then Hlo. That would need refactoring of
  // BufferizeTypeConverter.
  pm.addNestedPass<FuncOp>(mlir::createLinalgInitTensorToAllocTensorPass());
  pm.addPass(mlir::createComputeOpAndFuncBufferizePass());
  pm.addNestedPass<FuncOp>(::mlir::createCanonicalizerPass());
  pm.addNestedPass<FuncOp>(::mlir::createCSEPass());
  // Remove copies which are introduced by canonicalizing
  // BufferCastOp(TensorLoadOp).
  pm.addNestedPass<FuncOp>(
      mlir::kernel_gen::transforms::CreateCopyCleanupPass());
  // Find candidates for buffer reuse. This is only successful if buffer size
  // equality can be determined based on `linalg.generic` operations.
  pm.addNestedPass<FuncOp>(
      mlir::kernel_gen::transforms::CreateBufferReusePass());
  // Approximate Tanh using standard operations.
  pm.addNestedPass<FuncOp>(
      ::mlir::mhlo::createLegalizeTrigonometricToApproximationPass());
  // Transform the Linalg ops inside of the loop nest into parallel loops.
  pm.addNestedPass<FuncOp>(::mlir::createConvertLinalgToParallelLoopsPass());

  // Canonicalize the code to simplify index computations. This is needed so
  // that loop bounds have the same value.
  pm.addNestedPass<FuncOp>(::mlir::createCanonicalizerPass());
  // Run CSE to ensure that loads and stores to the same subview get
  // recognized as such.
  pm.addNestedPass<FuncOp>(::mlir::createCSEPass());
  // Collapse and tile parallel loops for GPU only.
  pm.addNestedPass<FuncOp>(mlir::createCollapseParallelLoopsTo1DPass());
  pm.addNestedPass<FuncOp>(
      mlir::createTileLoopsPass(tile_sizes, unroll_factors));

  pm.addNestedPass<FuncOp>(::mlir::createCanonicalizerPass());
  pm.addNestedPass<FuncOp>(::mlir::createCSEPass());
  if (failed(pm.run(module))) {
    return tensorflow::errors::Internal("Lowering TF to loops failed.");
  }
  return OkStatus();
}

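// Maps the parallel loops to GPU hardware dimensions, bufferizes the remaining
// ops, and outlines GPU kernels from the resulting launches.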
Status LowerLoopsToGPU(mlir::ModuleOp module, bool embed_memref_prints,
                       bool index_64bit, bool apply_cl_options) {
  mlir::PassManager pm(module.getContext());
  if (apply_cl_options) applyTensorflowAndCLOptions(pm);

  // Greedily map the remaining loop to GPU hardware dimensions.
  pm.addNestedPass<FuncOp>(mlir::createGpuMapParallelLoopsPass());

  // Expand memref_reshape to its ranked form so that we can propagate
  // scalars and avoid allocation.
  pm.addNestedPass<FuncOp>(mlir::arith::createArithmeticExpandOpsPass());
  pm.addNestedPass<FuncOp>(mlir::memref::createExpandOpsPass());
  pm.addPass(mlir::createCanonicalizerPass());
  pm.addPass(mlir::kernel_gen::transforms::CreateShapeToDescriptorsPass());
  // Before bufferizing further, remove unused tensor_to_memref, so that we do
  // not create allocations for tensor computations that are not actually
  // needed.
  pm.addPass(mlir::createCanonicalizerPass());
  pm.addNestedPass<FuncOp>(mlir::createCSEPass());
  // Before inserting more allocs, map the ones we already have to the
  // tf runtime. That ensures that all allocations for the actual computation
  // end up on the device, whereas allocations for shape computation and host
  // side things remain on the host.
  // Longer term, this should be handled by proper device placement.
  pm.addPass(mlir::kernel_gen::tf_framework::CreateEmbedTFFrameworkPass());
  // Now lower the shape computations, bufferize all remaining ops and insert
  // deallocs.
  pm.addPass(mlir::createFinalBufferizePass(
      /*alignment=*/64,
      mlir::kernel_gen::transforms::populateExtraBufferizeDialects,
      mlir::kernel_gen::transforms::populateExtraBufferizePatterns));
  // TODO(herhut): Enable once no-longer broken.
  pm.addNestedPass<FuncOp>(::mlir::bufferization::createBufferHoistingPass());
  pm.addNestedPass<FuncOp>(mlir::bufferization::createPromoteBuffersToStackPass(
      [](Value alloc) { return IsSmallAlloc(alloc); }));
  // Free all temporaries.
  pm.addNestedPass<FuncOp>(
      ::mlir::bufferization::createBufferDeallocationPass());
  pm.addPass(mlir::createCanonicalizerPass());

  // Apply the mapping and go to GPU. We cannot do this earlier due to missing
  // interfaces on the GPU dialect.
  // TODO(b/174830459): Move up once implemented.
  pm.addNestedPass<FuncOp>(mlir::createParallelLoopToGpuPass());

  // Some basic cleanup.
  pm.addNestedPass<FuncOp>(::mlir::createCanonicalizerPass());
  pm.addNestedPass<FuncOp>(::mlir::createCSEPass());
  // Make loops with min bounds into a conditional plus static bounds.
  pm.addNestedPass<FuncOp>(mlir::createForLoopSpecializationPass());
  // Sink index computations into gpu.launch bodies and then outline the
  // launches into GPU kernels.
  pm.addPass(mlir::createGpuLauchSinkIndexComputationsPass());
  const std::string gpuDataLayoutSpec =
      index_64bit ? "#dlti.dl_spec<#dlti.dl_entry<index,64:i64>>"
                  : "#dlti.dl_spec<#dlti.dl_entry<index,32:i32>>";
  pm.addPass(mlir::createGpuKernelOutliningPass(gpuDataLayoutSpec));

  pm.addPass(::mlir::createLowerAffinePass());
  // Constraints are removed as late as possible and before lowering to CFG.
  pm.addNestedPass<FuncOp>(::mlir::createConvertShapeConstraintsPass());
  pm.addNestedPass<FuncOp>(::mlir::createCanonicalizerPass());
  pm.addPass(::mlir::createConvertSCFToCFPass());
  // Map asserts to the tensorflow framework.
  pm.addPass(mlir::kernel_gen::tf_framework::CreateRewriteTFFrameworkAssert());
  if (embed_memref_prints) {
    pm.addPass(mlir::kernel_gen::transforms::CreateEmbedMemRefPrintsPass());
  }
  if (failed(pm.run(module))) {
    return tensorflow::errors::Internal("Lowering to GPU kernels failed.");
  }
  return OkStatus();
}

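// Lowers the outlined GPU kernel bodies to NVVM (CUDA) or ROCDL (ROCm) IR.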
Status LowerKernelBodiesToLowLevelIr(mlir::ModuleOp module,
                                     bool apply_cl_options) {
#if !defined(TENSORFLOW_USE_ROCM) && !defined(GOOGLE_CUDA)
  return tensorflow::errors::Internal(
312 "Neither TENSORFLOW_USE_ROCM nor GOOGLE_CUDA are defined."
313 " Did you specify either --config=rocm or --config=cuda ?");
#endif

#if TENSORFLOW_USE_ROCM
  auto gpu_modules = module.getOps<::mlir::gpu::GPUModuleOp>();
  for (::mlir::gpu::GPUModuleOp gpu_module : gpu_modules) {
    gpu_module.walk([&](mlir::gpu::GPUFuncOp gpu_kernel) {
      if (gpu_kernel.isKernel()) {
        gpu_kernel->setAttr(
            "rocdl.max_flat_work_group_size",
            mlir::IntegerAttr::get(
                mlir::IntegerType::get(module.getContext(), 32), 1024));
      }
    });
  }
#endif

  mlir::PassManager pm(module.getContext());
  // We cannot verify as the signature of the kernel is rewritten.
  // pm.enableVerifier(false);
  if (apply_cl_options) tensorflow::applyTensorflowAndCLOptions(pm);
  auto& kernelPm = pm.nest<::mlir::gpu::GPUModuleOp>();
  kernelPm.addPass(::mlir::createConvertSCFToCFPass());
#if TENSORFLOW_USE_ROCM
  kernelPm.addPass(mlir::createGpuKernelToRocdlPass());
#elif GOOGLE_CUDA
  kernelPm.addPass(mlir::createGpuKernelToNvvmPass());
  kernelPm.addPass(mlir::NVVM::createOptimizeForTargetPass());
#endif
  // Remove all location information to prevent a debug build.
  pm.addPass(::mlir::createStripDebugInfoPass());

  if (failed(pm.run(module))) {
    return tensorflow::errors::Internal(
        "Lowering to low-level device IR failed.");
  }

  return OkStatus();
}

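// Propagates statically known shape and TF ABI information into the generated
// kernels.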
Status AmendKernelLLVMIRWithStaticKnowledge(mlir::ModuleOp module,
                                            bool apply_cl_options) {
  mlir::PassManager pm(module.getContext());
  if (apply_cl_options) applyTensorflowAndCLOptions(pm);

  pm.addNestedPass<FuncOp>(
      mlir::kernel_gen::transforms::CreatePropagateShapeKnowledgeToKernels());
  pm.addNestedPass<FuncOp>(
      mlir::kernel_gen::transforms::CreatePropagateTfAbiKnowledgeToKernels());

  return failed(pm.run(module))
             ? tensorflow::errors::Internal(
                   "Amending LLVMIR with static knowledge failed.")
             : OkStatus();
}

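// Compiles the GPU kernel modules to a device binary blob and attaches it to
// each GPU module under `gpu_binary_attr_name`.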
Status GenerateDeviceCode(mlir::ModuleOp module,
                          llvm::StringRef gpu_binary_attr_name,
                          llvm::ArrayRef<std::string> architectures,
                          bool print_ptx, bool print_llvmir, bool enable_ftz,
                          bool apply_cl_options) {
  mlir::PassManager pm(module.getContext());
  if (apply_cl_options) applyTensorflowAndCLOptions(pm);
  mlir::registerLLVMDialectTranslation(*module->getContext());

  auto& kernel_pm = pm.nest<mlir::gpu::GPUModuleOp>();
  // Remove debug information to ensure we do not create debug PTX.
  kernel_pm.addPass(mlir::createStripDebugInfoPass());
  kernel_pm.addPass(mlir::kernel_gen::transforms::CreateGpuKernelToBlobPass(
      gpu_binary_attr_name, architectures, print_ptx, print_llvmir,
      enable_ftz));

  return failed(pm.run(module))
             ? tensorflow::errors::Internal("Generating device code failed.")
             : OkStatus();
}

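// Lowers the remaining host-side code, including the kernel launches, to the
// LLVM dialect.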
Status LowerHostSideToFinalForm(mlir::ModuleOp module, bool apply_cl_options) {
  mlir::PassManager pm(module.getContext());
  if (apply_cl_options) applyTensorflowAndCLOptions(pm);

  pm.addPass(mlir::kernel_gen::transforms::CreateTFKernelToLLVMPass(
      kGpuBinaryAttrName));
  pm.addPass(mlir::createReconcileUnrealizedCastsPass());
  pm.addPass(mlir::createCanonicalizerPass());
  pm.addPass(mlir::createCSEPass());

  return failed(pm.run(module)) ? tensorflow::errors::Internal(
                                      "Final lowering of host side failed.")
                                : OkStatus();
}

}  // namespace

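// Registers the required dialects and translations with `context` and parses
// `tf_code` into an MLIR module.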
StatusOr<mlir::OwningOpRef<mlir::ModuleOp>> SetupContextAndParseModule(
    mlir::MLIRContext& context, llvm::StringRef tf_code) {
  mlir::DialectRegistry registry;
  mlir::RegisterAllTensorFlowDialects(registry);
  registry.insert<mlir::chlo::ChloDialect, mlir::mhlo::MhloDialect>();
  mlir::registerLLVMDialectTranslation(registry);
  mlir::registerNVVMDialectTranslation(registry);
  mlir::registerROCDLDialectTranslation(registry);
  context.appendDialectRegistry(registry);
  mlir::OwningOpRef<mlir::ModuleOp> module =
      mlir::parseSourceString<mlir::ModuleOp>(tf_code, &context);
  if (!module)
    return tensorflow::Status(tensorflow::error::Code::INVALID_ARGUMENT,
                              "invalid kernel IR");
  return module;
}

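// End-to-end driver: parses `tf_code` and either lowers it to a JIT invocation
// (when `jit_compile` is set) or compiles it into a GPU binary together with
// the host-side launch code.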
StatusOr<mlir::OwningOpRef<mlir::ModuleOp>> GenerateKernelForTfCode(
    mlir::MLIRContext& context, llvm::StringRef tf_code,
    llvm::ArrayRef<std::string> architectures,
    llvm::ArrayRef<int64_t> tile_sizes, llvm::ArrayRef<int64_t> unroll_factors,
    int64_t max_supported_rank, bool embed_memref_prints, bool print_ptx,
    bool print_llvmir, bool enable_ftz, bool index_64bit, bool jit_compile,
    bool jit_i64_indexed_for_large_tensors, bool apply_cl_options) {
  TF_ASSIGN_OR_RETURN(mlir::OwningOpRef<mlir::ModuleOp> module,
                      SetupContextAndParseModule(context, tf_code));

  if (jit_compile) {
    TF_RETURN_IF_ERROR(LowerTFToJITInvocation(
        module.get(), tile_sizes, unroll_factors, max_supported_rank,
        enable_ftz, index_64bit,
        /*jit_i64_indexed_for_large_tensors=*/false, apply_cl_options));
  } else {
    TF_RETURN_IF_ERROR(
        LowerTFtoLoops(module.get(), tile_sizes, unroll_factors,
                       max_supported_rank, enable_ftz, index_64bit,
                       jit_i64_indexed_for_large_tensors, apply_cl_options));
    TF_RETURN_IF_ERROR(LowerLoopsToGPU(module.get(), embed_memref_prints,
                                       index_64bit, apply_cl_options));
    TF_RETURN_IF_ERROR(
        LowerKernelBodiesToLowLevelIr(module.get(), apply_cl_options));
    TF_RETURN_IF_ERROR(
        AmendKernelLLVMIRWithStaticKnowledge(module.get(), apply_cl_options));
    TF_RETURN_IF_ERROR(GenerateDeviceCode(
        module.get(), kGpuBinaryAttrName, architectures, print_ptx,
        print_llvmir, enable_ftz, apply_cl_options));
  }

  TF_RETURN_IF_ERROR(LowerHostSideToFinalForm(module.get(), apply_cl_options));

  return module;
}

}  // namespace kernel_gen
}  // namespace tensorflow