1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 // This file defines the operations used in the tf_framework dialect.
17
18 #include "tensorflow/compiler/mlir/tools/kernel_gen/ir/tf_framework_ops.h"
19
20 #include "mlir/Dialect/Bufferization/IR/Bufferization.h" // from @llvm-project
21 #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project
22 #include "mlir/Dialect/MemRef/IR/MemRef.h" // from @llvm-project
23 #include "mlir/IR/Builders.h" // from @llvm-project
24 #include "mlir/IR/DialectImplementation.h" // from @llvm-project
25 #include "tensorflow/compiler/mlir/tools/kernel_gen/ir/tf_status.cc.inc"
26
27 // Generated dialect definitions.
28 #include "tensorflow/compiler/mlir/tools/kernel_gen/ir/tf_framework_dialect.cc.inc"
29
30 namespace mlir {
31 namespace kernel_gen {
32 namespace tf_framework {
33
initialize()34 void TFFrameworkDialect::initialize() {
35 addOperations<
36 #define GET_OP_LIST
37 #include "tensorflow/compiler/mlir/tools/kernel_gen/ir/tf_framework_ops.cc.inc"
38 >();
39 addTypes<JITCallableType, OpKernelContextType>();
40 }
41
42 /// Parse a type registered to this dialect.
parseType(DialectAsmParser & parser) const43 Type TFFrameworkDialect::parseType(DialectAsmParser &parser) const {
44 StringRef keyword;
45 if (parser.parseKeyword(&keyword)) return Type();
46
47 if (keyword == "op_kernel_context") {
48 return OpKernelContextType::get(getContext());
49 }
50 if (keyword == "jit_callable") {
51 return JITCallableType::get(getContext());
52 }
53
54 parser.emitError(parser.getNameLoc(), "unknown TF Framework type: ")
55 << keyword;
56 return Type();
57 }
58
59 /// Print a type registered to this dialect.
printType(Type type,DialectAsmPrinter & os) const60 void TFFrameworkDialect::printType(Type type, DialectAsmPrinter &os) const {
61 if (type.isa<OpKernelContextType>()) {
62 os << "op_kernel_context";
63 return;
64 }
65 if (type.isa<JITCallableType>()) {
66 os << "jit_callable";
67 return;
68 }
69 llvm_unreachable("unexpected TF Framework type kind");
70 }
71
72 //===----------------------------------------------------------------------===//
73 // TFAllocOp
74 //===----------------------------------------------------------------------===//
verify()75 LogicalResult TFAllocOp::verify() {
76 TFAllocOp op = *this;
77 // Check that the total number of operands matches the number of dynamic
78 // dimensions specified in the memref type.
79 unsigned result_dyn_dims = op.getType().getNumDynamicDims();
80 unsigned dyn_sizes_count = op.dyn_sizes().size();
81 if (dyn_sizes_count != result_dyn_dims)
82 return op.emitOpError()
83 << "`dyn_sizes` count " << dyn_sizes_count
84 << " does not match dynamic dimensions count in the result type"
85 << op.getType();
86 return success();
87 }
88
buildDealloc(OpBuilder & builder,Value alloc)89 Optional<Operation *> TFAllocOp::buildDealloc(OpBuilder &builder, Value alloc) {
90 auto funcop = alloc.getParentRegion()->getParentOfType<func::FuncOp>();
91 return builder
92 .create<TFDeallocOp>(alloc.getLoc(), funcop.getArgument(0), alloc)
93 .getOperation();
94 }
95
buildClone(OpBuilder & builder,Value alloc)96 Optional<Value> TFAllocOp::buildClone(OpBuilder &builder, Value alloc) {
97 // TODO(herhut): We should have our own clone op if one of these survives.
98 return builder.create<mlir::bufferization::CloneOp>(alloc.getLoc(), alloc)
99 .getResult();
100 }
101
102 //===----------------------------------------------------------------------===//
103 // JITExecuteOp
104 //===----------------------------------------------------------------------===//
105
buildDealloc(OpBuilder & builder,Value alloc)106 Optional<Operation *> JITExecuteOp::buildDealloc(OpBuilder &builder,
107 Value alloc) {
108 auto funcop = alloc.getParentRegion()->getParentOfType<func::FuncOp>();
109 return builder
110 .create<TFDeallocOp>(alloc.getLoc(), funcop.getArgument(0), alloc)
111 .getOperation();
112 }
113
buildClone(OpBuilder & builder,Value alloc)114 Optional<Value> JITExecuteOp::buildClone(OpBuilder &builder, Value alloc) {
115 // TODO(herhut): We should have our own clone op if one of these survives.
116 return builder.create<mlir::bufferization::CloneOp>(alloc.getLoc(), alloc)
117 .getResult();
118 }
119
ConvertAttrToEnumValue(ErrorCode error_code)120 ::tensorflow::error::Code ConvertAttrToEnumValue(ErrorCode error_code) {
121 using ::tensorflow::error::Code;
122 switch (error_code) {
123 case ErrorCode::OK:
124 return Code::OK;
125 case ErrorCode::CANCELLED:
126 return Code::CANCELLED;
127 case ErrorCode::UNKNOWN:
128 return Code::UNKNOWN;
129 case ErrorCode::INVALID_ARGUMENT:
130 return Code::INVALID_ARGUMENT;
131 case ErrorCode::DEADLINE_EXCEEDED:
132 return Code::DEADLINE_EXCEEDED;
133 case ErrorCode::NOT_FOUND:
134 return Code::NOT_FOUND;
135 case ErrorCode::ALREADY_EXISTS:
136 return Code::ALREADY_EXISTS;
137 case ErrorCode::PERMISSION_DENIED:
138 return Code::PERMISSION_DENIED;
139 case ErrorCode::UNAUTHENTICATED:
140 return Code::UNAUTHENTICATED;
141 case ErrorCode::RESOURCE_EXHAUSTED:
142 return Code::RESOURCE_EXHAUSTED;
143 case ErrorCode::FAILED_PRECONDITION:
144 return Code::FAILED_PRECONDITION;
145 case ErrorCode::ABORTED:
146 return Code::ABORTED;
147 case ErrorCode::OUT_OF_RANGE:
148 return Code::OUT_OF_RANGE;
149 case ErrorCode::UNIMPLEMENTED:
150 return Code::UNIMPLEMENTED;
151 case ErrorCode::INTERNAL:
152 return Code::INTERNAL;
153 case ErrorCode::UNAVAILABLE:
154 return Code::UNAVAILABLE;
155 case ErrorCode::DATA_LOSS:
156 return Code::DATA_LOSS;
157 }
158 }
159
160 } // namespace tf_framework
161 } // namespace kernel_gen
162 } // namespace mlir
163
164 #define GET_OP_CLASSES
165 #include "tensorflow/compiler/mlir/tools/kernel_gen/ir/tf_framework_ops.cc.inc"
166