xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/mlir/tools/kernel_gen/ir/tf_framework_ops.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 // This file defines the operations used in the tf_framework dialect.
17 
18 #include "tensorflow/compiler/mlir/tools/kernel_gen/ir/tf_framework_ops.h"
19 
20 #include "mlir/Dialect/Bufferization/IR/Bufferization.h"  // from @llvm-project
21 #include "mlir/Dialect/Func/IR/FuncOps.h"  // from @llvm-project
22 #include "mlir/Dialect/MemRef/IR/MemRef.h"  // from @llvm-project
23 #include "mlir/IR/Builders.h"  // from @llvm-project
24 #include "mlir/IR/DialectImplementation.h"  // from @llvm-project
25 #include "tensorflow/compiler/mlir/tools/kernel_gen/ir/tf_status.cc.inc"
26 
27 // Generated dialect definitions.
28 #include "tensorflow/compiler/mlir/tools/kernel_gen/ir/tf_framework_dialect.cc.inc"
29 
30 namespace mlir {
31 namespace kernel_gen {
32 namespace tf_framework {
33 
initialize()34 void TFFrameworkDialect::initialize() {
35   addOperations<
36 #define GET_OP_LIST
37 #include "tensorflow/compiler/mlir/tools/kernel_gen/ir/tf_framework_ops.cc.inc"
38       >();
39   addTypes<JITCallableType, OpKernelContextType>();
40 }
41 
42 /// Parse a type registered to this dialect.
parseType(DialectAsmParser & parser) const43 Type TFFrameworkDialect::parseType(DialectAsmParser &parser) const {
44   StringRef keyword;
45   if (parser.parseKeyword(&keyword)) return Type();
46 
47   if (keyword == "op_kernel_context") {
48     return OpKernelContextType::get(getContext());
49   }
50   if (keyword == "jit_callable") {
51     return JITCallableType::get(getContext());
52   }
53 
54   parser.emitError(parser.getNameLoc(), "unknown TF Framework type: ")
55       << keyword;
56   return Type();
57 }
58 
59 /// Print a type registered to this dialect.
printType(Type type,DialectAsmPrinter & os) const60 void TFFrameworkDialect::printType(Type type, DialectAsmPrinter &os) const {
61   if (type.isa<OpKernelContextType>()) {
62     os << "op_kernel_context";
63     return;
64   }
65   if (type.isa<JITCallableType>()) {
66     os << "jit_callable";
67     return;
68   }
69   llvm_unreachable("unexpected TF Framework type kind");
70 }
71 
72 //===----------------------------------------------------------------------===//
73 // TFAllocOp
74 //===----------------------------------------------------------------------===//
verify()75 LogicalResult TFAllocOp::verify() {
76   TFAllocOp op = *this;
77   // Check that the total number of operands matches the number of dynamic
78   // dimensions specified in the memref type.
79   unsigned result_dyn_dims = op.getType().getNumDynamicDims();
80   unsigned dyn_sizes_count = op.dyn_sizes().size();
81   if (dyn_sizes_count != result_dyn_dims)
82     return op.emitOpError()
83            << "`dyn_sizes` count " << dyn_sizes_count
84            << " does not match dynamic dimensions count in the result type"
85            << op.getType();
86   return success();
87 }
88 
buildDealloc(OpBuilder & builder,Value alloc)89 Optional<Operation *> TFAllocOp::buildDealloc(OpBuilder &builder, Value alloc) {
90   auto funcop = alloc.getParentRegion()->getParentOfType<func::FuncOp>();
91   return builder
92       .create<TFDeallocOp>(alloc.getLoc(), funcop.getArgument(0), alloc)
93       .getOperation();
94 }
95 
buildClone(OpBuilder & builder,Value alloc)96 Optional<Value> TFAllocOp::buildClone(OpBuilder &builder, Value alloc) {
97   // TODO(herhut): We should have our own clone op if one of these survives.
98   return builder.create<mlir::bufferization::CloneOp>(alloc.getLoc(), alloc)
99       .getResult();
100 }
101 
102 //===----------------------------------------------------------------------===//
103 // JITExecuteOp
104 //===----------------------------------------------------------------------===//
105 
buildDealloc(OpBuilder & builder,Value alloc)106 Optional<Operation *> JITExecuteOp::buildDealloc(OpBuilder &builder,
107                                                  Value alloc) {
108   auto funcop = alloc.getParentRegion()->getParentOfType<func::FuncOp>();
109   return builder
110       .create<TFDeallocOp>(alloc.getLoc(), funcop.getArgument(0), alloc)
111       .getOperation();
112 }
113 
buildClone(OpBuilder & builder,Value alloc)114 Optional<Value> JITExecuteOp::buildClone(OpBuilder &builder, Value alloc) {
115   // TODO(herhut): We should have our own clone op if one of these survives.
116   return builder.create<mlir::bufferization::CloneOp>(alloc.getLoc(), alloc)
117       .getResult();
118 }
119 
ConvertAttrToEnumValue(ErrorCode error_code)120 ::tensorflow::error::Code ConvertAttrToEnumValue(ErrorCode error_code) {
121   using ::tensorflow::error::Code;
122   switch (error_code) {
123     case ErrorCode::OK:
124       return Code::OK;
125     case ErrorCode::CANCELLED:
126       return Code::CANCELLED;
127     case ErrorCode::UNKNOWN:
128       return Code::UNKNOWN;
129     case ErrorCode::INVALID_ARGUMENT:
130       return Code::INVALID_ARGUMENT;
131     case ErrorCode::DEADLINE_EXCEEDED:
132       return Code::DEADLINE_EXCEEDED;
133     case ErrorCode::NOT_FOUND:
134       return Code::NOT_FOUND;
135     case ErrorCode::ALREADY_EXISTS:
136       return Code::ALREADY_EXISTS;
137     case ErrorCode::PERMISSION_DENIED:
138       return Code::PERMISSION_DENIED;
139     case ErrorCode::UNAUTHENTICATED:
140       return Code::UNAUTHENTICATED;
141     case ErrorCode::RESOURCE_EXHAUSTED:
142       return Code::RESOURCE_EXHAUSTED;
143     case ErrorCode::FAILED_PRECONDITION:
144       return Code::FAILED_PRECONDITION;
145     case ErrorCode::ABORTED:
146       return Code::ABORTED;
147     case ErrorCode::OUT_OF_RANGE:
148       return Code::OUT_OF_RANGE;
149     case ErrorCode::UNIMPLEMENTED:
150       return Code::UNIMPLEMENTED;
151     case ErrorCode::INTERNAL:
152       return Code::INTERNAL;
153     case ErrorCode::UNAVAILABLE:
154       return Code::UNAVAILABLE;
155     case ErrorCode::DATA_LOSS:
156       return Code::DATA_LOSS;
157   }
158 }
159 
160 }  // namespace tf_framework
161 }  // namespace kernel_gen
162 }  // namespace mlir
163 
164 #define GET_OP_CLASSES
165 #include "tensorflow/compiler/mlir/tools/kernel_gen/ir/tf_framework_ops.cc.inc"
166