xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/mlir/tensorflow/ir/tf_remaining_ops.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_remaining_ops.h"
17 
18 #include <algorithm>
19 #include <cstdint>
20 #include <functional>
21 #include <limits>
22 #include <numeric>
23 #include <string>
24 #include <tuple>
25 #include <type_traits>
26 
27 #include "llvm/ADT/APFloat.h"
28 #include "llvm/ADT/APInt.h"
29 #include "llvm/ADT/ArrayRef.h"
30 #include "llvm/ADT/Optional.h"
31 #include "llvm/ADT/STLExtras.h"
32 #include "llvm/ADT/Sequence.h"
33 #include "llvm/ADT/SmallVector.h"
34 #include "llvm/ADT/StringExtras.h"
35 #include "llvm/ADT/StringRef.h"
36 #include "llvm/ADT/StringSwitch.h"
37 #include "llvm/ADT/iterator_range.h"
38 #include "llvm/Support/Casting.h"
39 #include "llvm/Support/FormatVariadic.h"
40 #include "mlir/Dialect/Func/IR/FuncOps.h"  // from @llvm-project
41 #include "mlir/Dialect/Traits.h"  // from @llvm-project
42 #include "mlir/IR/Attributes.h"  // from @llvm-project
43 #include "mlir/IR/Builders.h"  // from @llvm-project
44 #include "mlir/IR/BuiltinOps.h"  // from @llvm-project
45 #include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
46 #include "mlir/IR/Diagnostics.h"  // from @llvm-project
47 #include "mlir/IR/DialectImplementation.h"  // from @llvm-project
48 #include "mlir/IR/Location.h"  // from @llvm-project
49 #include "mlir/IR/MLIRContext.h"  // from @llvm-project
50 #include "mlir/IR/Matchers.h"  // from @llvm-project
51 #include "mlir/IR/OpDefinition.h"  // from @llvm-project
52 #include "mlir/IR/OpImplementation.h"  // from @llvm-project
53 #include "mlir/IR/PatternMatch.h"  // from @llvm-project
54 #include "mlir/IR/TypeUtilities.h"  // from @llvm-project
55 #include "mlir/IR/Types.h"  // from @llvm-project
56 #include "mlir/IR/Value.h"  // from @llvm-project
57 #include "mlir/Parser/Parser.h"  // from @llvm-project
58 #include "mlir/Support/LLVM.h"  // from @llvm-project
59 #include "mlir/Support/LogicalResult.h"  // from @llvm-project
60 #include "mlir/Transforms/InliningUtils.h"  // from @llvm-project
61 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_attributes.h"
62 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_op_interfaces.h"
63 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
64 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_side_effects.h"
65 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_structs.h"
66 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
67 #include "tensorflow/compiler/mlir/tensorflow/transforms/rewrite_util.h"
68 #include "tensorflow/compiler/mlir/tensorflow/utils/serialize_mlir_module_utils.h"
69 #include "tensorflow/core/platform/logging.h"
70 #include "tensorflow/core/util/tensor_format.h"
71 
72 namespace mlir {
73 namespace TF {
74 namespace {
75 #include "tensorflow/compiler/mlir/tensorflow/transforms/generated_canonicalize.inc"
76 }  // namespace
77 
78 //===----------------------------------------------------------------------===//
79 // _XlaHostComputeMlirOp
80 //===----------------------------------------------------------------------===//
81 
82 // This verifies that `_XlaHostComputeMlirOp` has a well-formed
83 // `host_mlir_module` attribute.
84 // For other attributes, there is no additional verification beyond the default.
verify()85 LogicalResult _XlaHostComputeMlirOp::verify() {
86   _XlaHostComputeMlirOp op = *this;
87   // Extract the module and function.
88   StringRef host_module = op.host_mlir_module();
89 
90   if (host_module.empty()) return success();
91 
92   mlir::OwningOpRef<mlir::ModuleOp> module_for_func;
93   tensorflow::Status status = tensorflow::DeserializeMlirModule(
94       host_module.str(), op->getContext(), &module_for_func);
95   if (!status.ok()) {
96     return op.emitError()
97            << "attribute 'host_mlir_module' can not be deserialized. "
98            << status.error_message();
99   }
100 
101   func::FuncOp func = module_for_func->lookupSymbol<func::FuncOp>("host_func");
102   if (!func)
103     return op.emitError()
104            << "serialized module in attribute 'host_mlir_module' does not "
105               "contain 'host_func' function.";
106 
107   if (op->getNumOperands() != func.getFunctionType().getNumInputs())
108     return op.emitError()
109            << "'host_func' has " << func.getFunctionType().getNumInputs()
110            << " inputs and '_XlaHostComputeMlir' has " << op->getNumOperands()
111            << " operands.  Number of operands/inputs should be the same.";
112 
113   if (op->getNumResults() != func.getFunctionType().getNumResults())
114     return op.emitError() << "'host_func' has "
115                           << func.getFunctionType().getNumResults()
116                           << " results and '_XlaHostComputeMlir' has "
117                           << op->getNumResults()
118                           << " results.  Number of results should be the same.";
119 
120   return success();
121 }
122 
GetHostFunc(mlir::OwningOpRef<mlir::ModuleOp> * mlir_module)123 func::FuncOp _XlaHostComputeMlirOp::GetHostFunc(
124     mlir::OwningOpRef<mlir::ModuleOp>* mlir_module) {
125   if (!tensorflow::DeserializeMlirModule(host_mlir_module().str(),
126                                          this->getContext(), mlir_module)
127            .ok())
128     return nullptr;
129   return (*mlir_module)->lookupSymbol<func::FuncOp>("host_func");
130 }
131 
132 //===----------------------------------------------------------------------===//
133 // XLA Send/Recv ops
134 //===----------------------------------------------------------------------===//
135 
136 // For XLA Send/Recv ops the key corresponds to the resource instance.
137 
GetResourceInstanceStr()138 std::string _XlaRecvAtHostOp::GetResourceInstanceStr() { return key().str(); }
139 
GetResourceInstanceStr()140 std::string _XlaRecvAtHostV2Op::GetResourceInstanceStr() { return key().str(); }
141 
GetResourceInstanceStr()142 std::string _XlaSendFromHostOp::GetResourceInstanceStr() { return key().str(); }
143 
GetResourceInstanceStr()144 std::string _XlaSendFromHostV2Op::GetResourceInstanceStr() {
145   return key().str();
146 }
147 
namespace {
// Builds the resource instance string shared by the host/device Send/Recv
// ops by joining the four components with ';':
//
//   <send_device>;<send_device_incarnation>;<recv_device>;<tensor_name>
//
// The incarnation is rendered as an unsigned decimal number. Implemented with
// std::string only (from <string>, which this file includes) so that the
// helper does not rely on a transitively included string-concatenation
// header.
std::string GetRendezvousKey(const std::string& send_device,
                             const uint64_t send_device_incarnation,
                             const std::string& recv_device,
                             const std::string& tensor_name) {
  const std::string incarnation = std::to_string(send_device_incarnation);

  std::string rendezvous_key;
  // Three ';' separators plus the four components: reserve once so the
  // appends below do not reallocate.
  rendezvous_key.reserve(send_device.size() + incarnation.size() +
                         recv_device.size() + tensor_name.size() + 3);
  rendezvous_key.append(send_device);
  rendezvous_key.push_back(';');
  rendezvous_key.append(incarnation);
  rendezvous_key.push_back(';');
  rendezvous_key.append(recv_device);
  rendezvous_key.push_back(';');
  rendezvous_key.append(tensor_name);
  return rendezvous_key;
}
}  // namespace
157 
GetResourceInstanceStr()158 std::string _HostRecvOp::GetResourceInstanceStr() {
159   return GetRendezvousKey(send_device().str(), send_device_incarnation(),
160                           recv_device().str(), tensor_name().str());
161 }
162 
GetResourceInstanceStr()163 std::string _HostSendOp::GetResourceInstanceStr() {
164   return GetRendezvousKey(send_device().str(), send_device_incarnation(),
165                           recv_device().str(), tensor_name().str());
166 }
167 
GetResourceInstanceStr()168 std::string _RecvOp::GetResourceInstanceStr() {
169   return GetRendezvousKey(send_device().str(), send_device_incarnation(),
170                           recv_device().str(), tensor_name().str());
171 }
172 
GetResourceInstanceStr()173 std::string _SendOp::GetResourceInstanceStr() {
174   return GetRendezvousKey(send_device().str(), send_device_incarnation(),
175                           recv_device().str(), tensor_name().str());
176 }
177 
178 }  // namespace TF
179 }  // namespace mlir
180 
181 //===----------------------------------------------------------------------===//
182 // TableGen'd op method definitions
183 //===----------------------------------------------------------------------===//
184 
185 #define GET_OP_CLASSES
186 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_remaining_ops.cc.inc"
187