1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" // from @llvm-project
17 #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project
18 #include "mlir/Dialect/MemRef/IR/MemRef.h" // from @llvm-project
19 #include "mlir/IR/BuiltinOps.h" // from @llvm-project
20 #include "mlir/IR/BuiltinTypes.h" // from @llvm-project
21 #include "mlir/IR/TypeRange.h" // from @llvm-project
22 #include "mlir/Transforms/DialectConversion.h" // from @llvm-project
23 #include "tensorflow/compiler/mlir/tools/kernel_gen/ir/tf_framework_ops.h"
24 #include "tensorflow/compiler/mlir/tools/kernel_gen/transforms/rewriters.h"
25
26 namespace mlir {
27 namespace kernel_gen {
28 namespace tf_framework {
29 namespace {
30
31 // Prepends argument type list of the function with an OpKernelContextType arg.
32 class FuncOpConverter : public OpConversionPattern<func::FuncOp> {
33 public:
34 using OpConversionPattern<func::FuncOp>::OpConversionPattern;
35
matchAndRewrite(func::FuncOp func,OpAdaptor,ConversionPatternRewriter & rewriter) const36 LogicalResult matchAndRewrite(
37 func::FuncOp func, OpAdaptor /*adaptor*/,
38 ConversionPatternRewriter &rewriter) const override {
39 // Convert function arguments using the provided TypeConverter.
40 auto func_type = func.getFunctionType();
41 TypeConverter::SignatureConversion conversion(func_type.getNumInputs());
42
43 conversion.addInputs(OpKernelContextType::get(rewriter.getContext()));
44 for (auto arg_type : llvm::enumerate(func_type.getInputs())) {
45 conversion.addInputs(arg_type.index(), arg_type.value());
46 }
47
48 rewriter.applySignatureConversion(&func.getBody(), conversion);
49
50 // Update the signature of the function.
51 rewriter.updateRootInPlace(func, [&] {
52 func.setType(rewriter.getFunctionType(conversion.getConvertedTypes(),
53 func_type.getResults()));
54 });
55 return success();
56 }
57 };
58
FindOpKernelContext(Operation * op)59 llvm::Optional<Value> FindOpKernelContext(Operation *op) {
60 auto func = op->getParentOfType<func::FuncOp>();
61 if (func.getNumArguments() == 0) {
62 return llvm::None;
63 }
64 Value ctx = func.getArgument(0);
65 if (!ctx.getType().isa<OpKernelContextType>()) {
66 return llvm::None;
67 }
68 return ctx;
69 }
70
71 // Converts std.alloc to tf_framework.alloc_raw using OpKernelContextType arg of
72 // the parent function.
73 struct AllocOpConverter : public OpConversionPattern<memref::AllocOp> {
74 using OpConversionPattern<memref::AllocOp>::OpConversionPattern;
75
matchAndRewritemlir::kernel_gen::tf_framework::__anona0bbadb60111::AllocOpConverter76 LogicalResult matchAndRewrite(
77 memref::AllocOp alloc, OpAdaptor adaptor,
78 ConversionPatternRewriter &rewriter) const override {
79 llvm::Optional<Value> ctx = FindOpKernelContext(alloc);
80 if (!ctx) return failure();
81
82 // Symbolic operands that bind to the symbols of the memref's layout map are
83 // not supported by TFAllocOp.
84 if (!alloc.getSymbolOperands().empty()) {
85 return failure();
86 }
87 auto reuse_input_candidates = alloc->getAttrOfType<ArrayAttr>(
88 TFAllocOp::kReuseInputCandidatesAttrName);
89 auto reuse_output_index =
90 alloc->getAttrOfType<IntegerAttr>(TFAllocOp::kReuseOutputAttrName);
91 Value buffer = rewriter.replaceOpWithNewOp<TFAllocOp>(
92 alloc, alloc.getType(), *ctx, adaptor.getOperands(),
93 reuse_input_candidates, reuse_output_index);
94 Location loc = buffer.getLoc();
95 Value cond = rewriter.create<IsValidMemRefOp>(
96 loc, rewriter.getIntegerType(1), buffer);
97 rewriter.create<TFAssertOp>(loc, *ctx, cond, ErrorCode::RESOURCE_EXHAUSTED,
98 "failed to allocate memory");
99 return success();
100 }
101 };
102
103 // Converts std.dealloc to tf_framework.dealloc_raw using OpKernelContextType
104 // arg of the parent function.
105 struct DeallocOpConverter : public OpConversionPattern<memref::DeallocOp> {
106 using OpConversionPattern<memref::DeallocOp>::OpConversionPattern;
107
matchAndRewritemlir::kernel_gen::tf_framework::__anona0bbadb60111::DeallocOpConverter108 LogicalResult matchAndRewrite(
109 memref::DeallocOp dealloc, OpAdaptor adaptor,
110 ConversionPatternRewriter &rewriter) const override {
111 llvm::Optional<Value> ctx = FindOpKernelContext(dealloc);
112 if (!ctx) return failure();
113
114 // Operand with no layout is expected.
115 auto operand_memref_type = dealloc.getMemref().getType().cast<MemRefType>();
116 if (!operand_memref_type.getLayout().isIdentity()) {
117 return failure();
118 }
119 rewriter.replaceOpWithNewOp<TFDeallocOp>(dealloc, *ctx,
120 adaptor.getMemref());
121 return success();
122 }
123 };
124
125 // Converts std.assert to tf_framework.assert with using OpKernelContextType
126 // arg of the parent function.
127 struct AssertOpConverter : public OpConversionPattern<cf::AssertOp> {
128 public:
129 using OpConversionPattern<cf::AssertOp>::OpConversionPattern;
130
matchAndRewritemlir::kernel_gen::tf_framework::__anona0bbadb60111::AssertOpConverter131 LogicalResult matchAndRewrite(
132 cf::AssertOp op, OpAdaptor adaptor,
133 ConversionPatternRewriter &rewriter) const override {
134 llvm::Optional<Value> ctx = FindOpKernelContext(op);
135 if (!ctx) return failure();
136 rewriter.replaceOpWithNewOp<TFAssertOp>(op, *ctx, adaptor.getArg(),
137 ErrorCode::INVALID_ARGUMENT,
138 adaptor.getMsg());
139 return success();
140 }
141 };
142
143 // Amends `tf_framework.jit_execute` with the newly introduced OpKernelContext.
144 struct JITExecuteOpConverter : public OpConversionPattern<JITExecuteOp> {
145 using OpConversionPattern<JITExecuteOp>::OpConversionPattern;
146
matchAndRewritemlir::kernel_gen::tf_framework::__anona0bbadb60111::JITExecuteOpConverter147 LogicalResult matchAndRewrite(
148 JITExecuteOp op, OpAdaptor /*adaptor*/,
149 ConversionPatternRewriter &rewriter) const override {
150 llvm::Optional<Value> ctx = FindOpKernelContext(op);
151 if (!ctx) return failure();
152 rewriter.replaceOpWithNewOp<JITExecuteOp>(op, op.result().getType(), *ctx,
153 op.callable(), op.operands());
154 return success();
155 }
156 };
157
158 // Amends `tf_framework.jit_compile_from_str` with the newly introduced
159 // OpKernelContext.
160 struct JITCompileFromStrOpConverter
161 : public OpConversionPattern<JITCompileFromStrOp> {
162 using OpConversionPattern<JITCompileFromStrOp>::OpConversionPattern;
163
matchAndRewritemlir::kernel_gen::tf_framework::__anona0bbadb60111::JITCompileFromStrOpConverter164 LogicalResult matchAndRewrite(
165 JITCompileFromStrOp op, OpAdaptor /*adaptor*/,
166 ConversionPatternRewriter &rewriter) const override {
167 llvm::Optional<Value> ctx = FindOpKernelContext(op);
168 if (!ctx) return failure();
169 rewriter.replaceOpWithNewOp<JITCompileFromStrOp>(
170 op, rewriter.getType<JITCallableType>(), *ctx, op->getAttrs());
171 return success();
172 }
173 };
174
175 } // namespace
176
PopulateEmbedTFFrameworkAssertPattern(RewritePatternSet * patterns)177 void PopulateEmbedTFFrameworkAssertPattern(RewritePatternSet *patterns) {
178 patterns->add<AssertOpConverter>(patterns->getContext());
179 }
180
PopulateEmbedTFFrameworkPatterns(RewritePatternSet * patterns)181 void PopulateEmbedTFFrameworkPatterns(RewritePatternSet *patterns) {
182 // clang-format off
183 patterns->add<
184 AllocOpConverter,
185 AssertOpConverter,
186 DeallocOpConverter,
187 FuncOpConverter,
188 JITCompileFromStrOpConverter,
189 JITExecuteOpConverter>(patterns->getContext());
190 // clang-format on
191 }
192
193 } // namespace tf_framework
194 } // namespace kernel_gen
195 } // namespace mlir
196