1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"  // from @llvm-project
17 #include "mlir/Dialect/Func/IR/FuncOps.h"  // from @llvm-project
18 #include "mlir/Dialect/MemRef/IR/MemRef.h"  // from @llvm-project
19 #include "mlir/IR/BuiltinOps.h"  // from @llvm-project
20 #include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
21 #include "mlir/IR/TypeRange.h"  // from @llvm-project
22 #include "mlir/Transforms/DialectConversion.h"  // from @llvm-project
23 #include "tensorflow/compiler/mlir/tools/kernel_gen/ir/tf_framework_ops.h"
24 #include "tensorflow/compiler/mlir/tools/kernel_gen/transforms/rewriters.h"
25 
26 namespace mlir {
27 namespace kernel_gen {
28 namespace tf_framework {
29 namespace {
30 
31 // Prepends argument type list of the function with an OpKernelContextType arg.
32 class FuncOpConverter : public OpConversionPattern<func::FuncOp> {
33  public:
34   using OpConversionPattern<func::FuncOp>::OpConversionPattern;
35 
matchAndRewrite(func::FuncOp func,OpAdaptor,ConversionPatternRewriter & rewriter) const36   LogicalResult matchAndRewrite(
37       func::FuncOp func, OpAdaptor /*adaptor*/,
38       ConversionPatternRewriter &rewriter) const override {
39     // Convert function arguments using the provided TypeConverter.
40     auto func_type = func.getFunctionType();
41     TypeConverter::SignatureConversion conversion(func_type.getNumInputs());
42 
43     conversion.addInputs(OpKernelContextType::get(rewriter.getContext()));
44     for (auto arg_type : llvm::enumerate(func_type.getInputs())) {
45       conversion.addInputs(arg_type.index(), arg_type.value());
46     }
47 
48     rewriter.applySignatureConversion(&func.getBody(), conversion);
49 
50     // Update the signature of the function.
51     rewriter.updateRootInPlace(func, [&] {
52       func.setType(rewriter.getFunctionType(conversion.getConvertedTypes(),
53                                             func_type.getResults()));
54     });
55     return success();
56   }
57 };
58 
FindOpKernelContext(Operation * op)59 llvm::Optional<Value> FindOpKernelContext(Operation *op) {
60   auto func = op->getParentOfType<func::FuncOp>();
61   if (func.getNumArguments() == 0) {
62     return llvm::None;
63   }
64   Value ctx = func.getArgument(0);
65   if (!ctx.getType().isa<OpKernelContextType>()) {
66     return llvm::None;
67   }
68   return ctx;
69 }
70 
71 // Converts std.alloc to tf_framework.alloc_raw using OpKernelContextType arg of
72 // the parent function.
73 struct AllocOpConverter : public OpConversionPattern<memref::AllocOp> {
74   using OpConversionPattern<memref::AllocOp>::OpConversionPattern;
75 
matchAndRewritemlir::kernel_gen::tf_framework::__anona0bbadb60111::AllocOpConverter76   LogicalResult matchAndRewrite(
77       memref::AllocOp alloc, OpAdaptor adaptor,
78       ConversionPatternRewriter &rewriter) const override {
79     llvm::Optional<Value> ctx = FindOpKernelContext(alloc);
80     if (!ctx) return failure();
81 
82     // Symbolic operands that bind to the symbols of the memref's layout map are
83     // not supported by TFAllocOp.
84     if (!alloc.getSymbolOperands().empty()) {
85       return failure();
86     }
87     auto reuse_input_candidates = alloc->getAttrOfType<ArrayAttr>(
88         TFAllocOp::kReuseInputCandidatesAttrName);
89     auto reuse_output_index =
90         alloc->getAttrOfType<IntegerAttr>(TFAllocOp::kReuseOutputAttrName);
91     Value buffer = rewriter.replaceOpWithNewOp<TFAllocOp>(
92         alloc, alloc.getType(), *ctx, adaptor.getOperands(),
93         reuse_input_candidates, reuse_output_index);
94     Location loc = buffer.getLoc();
95     Value cond = rewriter.create<IsValidMemRefOp>(
96         loc, rewriter.getIntegerType(1), buffer);
97     rewriter.create<TFAssertOp>(loc, *ctx, cond, ErrorCode::RESOURCE_EXHAUSTED,
98                                 "failed to allocate memory");
99     return success();
100   }
101 };
102 
103 // Converts std.dealloc to tf_framework.dealloc_raw using OpKernelContextType
104 // arg of the parent function.
105 struct DeallocOpConverter : public OpConversionPattern<memref::DeallocOp> {
106   using OpConversionPattern<memref::DeallocOp>::OpConversionPattern;
107 
matchAndRewritemlir::kernel_gen::tf_framework::__anona0bbadb60111::DeallocOpConverter108   LogicalResult matchAndRewrite(
109       memref::DeallocOp dealloc, OpAdaptor adaptor,
110       ConversionPatternRewriter &rewriter) const override {
111     llvm::Optional<Value> ctx = FindOpKernelContext(dealloc);
112     if (!ctx) return failure();
113 
114     // Operand with no layout is expected.
115     auto operand_memref_type = dealloc.getMemref().getType().cast<MemRefType>();
116     if (!operand_memref_type.getLayout().isIdentity()) {
117       return failure();
118     }
119     rewriter.replaceOpWithNewOp<TFDeallocOp>(dealloc, *ctx,
120                                              adaptor.getMemref());
121     return success();
122   }
123 };
124 
125 // Converts std.assert to tf_framework.assert with using OpKernelContextType
126 // arg of the parent function.
127 struct AssertOpConverter : public OpConversionPattern<cf::AssertOp> {
128  public:
129   using OpConversionPattern<cf::AssertOp>::OpConversionPattern;
130 
matchAndRewritemlir::kernel_gen::tf_framework::__anona0bbadb60111::AssertOpConverter131   LogicalResult matchAndRewrite(
132       cf::AssertOp op, OpAdaptor adaptor,
133       ConversionPatternRewriter &rewriter) const override {
134     llvm::Optional<Value> ctx = FindOpKernelContext(op);
135     if (!ctx) return failure();
136     rewriter.replaceOpWithNewOp<TFAssertOp>(op, *ctx, adaptor.getArg(),
137                                             ErrorCode::INVALID_ARGUMENT,
138                                             adaptor.getMsg());
139     return success();
140   }
141 };
142 
143 // Amends `tf_framework.jit_execute` with the newly introduced OpKernelContext.
144 struct JITExecuteOpConverter : public OpConversionPattern<JITExecuteOp> {
145   using OpConversionPattern<JITExecuteOp>::OpConversionPattern;
146 
matchAndRewritemlir::kernel_gen::tf_framework::__anona0bbadb60111::JITExecuteOpConverter147   LogicalResult matchAndRewrite(
148       JITExecuteOp op, OpAdaptor /*adaptor*/,
149       ConversionPatternRewriter &rewriter) const override {
150     llvm::Optional<Value> ctx = FindOpKernelContext(op);
151     if (!ctx) return failure();
152     rewriter.replaceOpWithNewOp<JITExecuteOp>(op, op.result().getType(), *ctx,
153                                               op.callable(), op.operands());
154     return success();
155   }
156 };
157 
158 // Amends `tf_framework.jit_compile_from_str` with the newly introduced
159 // OpKernelContext.
160 struct JITCompileFromStrOpConverter
161     : public OpConversionPattern<JITCompileFromStrOp> {
162   using OpConversionPattern<JITCompileFromStrOp>::OpConversionPattern;
163 
matchAndRewritemlir::kernel_gen::tf_framework::__anona0bbadb60111::JITCompileFromStrOpConverter164   LogicalResult matchAndRewrite(
165       JITCompileFromStrOp op, OpAdaptor /*adaptor*/,
166       ConversionPatternRewriter &rewriter) const override {
167     llvm::Optional<Value> ctx = FindOpKernelContext(op);
168     if (!ctx) return failure();
169     rewriter.replaceOpWithNewOp<JITCompileFromStrOp>(
170         op, rewriter.getType<JITCallableType>(), *ctx, op->getAttrs());
171     return success();
172   }
173 };
174 
175 }  // namespace
176 
PopulateEmbedTFFrameworkAssertPattern(RewritePatternSet * patterns)177 void PopulateEmbedTFFrameworkAssertPattern(RewritePatternSet *patterns) {
178   patterns->add<AssertOpConverter>(patterns->getContext());
179 }
180 
PopulateEmbedTFFrameworkPatterns(RewritePatternSet * patterns)181 void PopulateEmbedTFFrameworkPatterns(RewritePatternSet *patterns) {
182   // clang-format off
183   patterns->add<
184       AllocOpConverter,
185       AssertOpConverter,
186       DeallocOpConverter,
187       FuncOpConverter,
188       JITCompileFromStrOpConverter,
189       JITExecuteOpConverter>(patterns->getContext());
190   // clang-format on
191 }
192 
193 }  // namespace tf_framework
194 }  // namespace kernel_gen
195 }  // namespace mlir
196