1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_UTILS_ATTRIBUTE_UTILS_H_
17 #define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_UTILS_ATTRIBUTE_UTILS_H_
18
#include "mlir/IR/Attributes.h"  // from @llvm-project
#include "mlir/IR/Matchers.h"  // from @llvm-project
#include "mlir/IR/Operation.h"  // from @llvm-project
#include "tensorflow/compiler/tf2xla/tf2xla_defs.h"
22
23 namespace mlir {
24 namespace TF {
25
26 // TODO(b/229028654) Use definitions from tf2xla_defs.h directly. We currently
27 // don't do this to avoid explicit casts (implicit conversion from
28 // `absl::string_view` to `llvm::StringRef` is not supported until C++17).
29
// Attribute names mirrored from tf2xla_defs.h (see the TODO above for why
// they are redeclared here as `llvm::StringRef`). The string values must stay
// identical to those definitions.

// Marks a node for XLA compilation. The attribute value indicates the
// compilation device type.
inline constexpr llvm::StringRef kCompileDeviceTypeAttr =
    "_xla_compile_device_type";
// Marks a node for replication. The attribute value indicates the replication
// metadata op.
inline constexpr llvm::StringRef kReplicationInfoAttr = "_replication_info";
// Marks a node for XLA-TPU compilation. The attribute value indicates the
// associated compilation cluster and replication metadata op.
inline constexpr llvm::StringRef kTpuReplicateAttr = "_tpu_replicate";
// Device types.
// Device-type string for TPU. NOTE(review): presumably compared against the
// value carried by `kCompileDeviceTypeAttr`; confirm at call sites.
inline constexpr llvm::StringRef kTpuDevice = "TPU";
// Function attribute to signal that a function should be skipped from TPU
// island outlining. The attribute is set in
// `TpuV1BridgeExecutorIslandCoarsening` and removed in the subsequent
// `TPUBridgeExecutorIslandOutlining` pass.
inline constexpr llvm::StringRef kSkipIslandOutlining =
    "_skip_island_outlining";
48
49 // Copies attributes that satisfy the given predicate from `from` to `to`.
50 template <typename Predicate>
CopyAttributes(Operation * from,Operation * to,Predicate P)51 void CopyAttributes(Operation *from, Operation *to, Predicate P) {
52 for (const NamedAttribute &attr : from->getAttrs())
53 if (P(attr)) to->setAttr(attr.getName(), attr.getValue());
54 }
55
56 // Copies attributes whose name begins with an _ from `from` to `to`.
CopyUnderscoredAttributes(Operation * from,Operation * to)57 inline void CopyUnderscoredAttributes(Operation *from, Operation *to) {
58 CopyAttributes(from, to, [](const NamedAttribute &attr) {
59 return attr.getName().strref().front() == '_';
60 });
61 }
62
63 // Copies attributes that are either `device` or whose name begins with an _
64 // from `from` to `to`.
65 // TODO(b/158769932): This should be a general feature instead post some policy
66 // discussion.
CopyDeviceAndUnderscoredAttributes(Operation * from,Operation * to)67 inline void CopyDeviceAndUnderscoredAttributes(Operation *from, Operation *to) {
68 auto device = mlir::StringAttr::get(from->getContext(), "device");
69 CopyAttributes(from, to, [&device](const NamedAttribute &attr) {
70 return attr.getName().strref().front() == '_' || attr.getName() == device;
71 });
72 }
73
74 // Forward declare these passthrough ops.
75 // TODO(jpienaar): Remove these and use trait instead.
76 class IdentityOp;
77 class IdentityNOp;
78
79 // Returns if a value corresponds to a constant, returns the matched constant
80 // as an attribute.
81 template <typename AttrT>
GetValueAsConstant(Value val,AttrT & attr)82 bool GetValueAsConstant(Value val, AttrT &attr) {
83 while (auto result = val.dyn_cast<OpResult>()) {
84 Operation *op = result.getOwner();
85 if (!isa<IdentityOp>(op) && !isa<IdentityNOp>(op)) break;
86 val = op->getOperand(result.getResultNumber());
87 }
88 return matchPattern(val, m_Constant(&attr));
89 }
90
// Declared here and defined out of line (the definition is not in this
// header). NOTE(review): judging by the name, this presumably verifies that
// `op`'s compilation/replication attributes (e.g. `kCompileDeviceTypeAttr`,
// `kReplicationInfoAttr`) are present and mutually consistent. Confirm the
// exact contract against the definition before relying on it.
LogicalResult HasValidCompilationAndReplicationAttributes(Operation &op);
92
93 } // namespace TF
94 } // namespace mlir
95
96 #endif // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_UTILS_ATTRIBUTE_UTILS_H_
97