1 /* Copyright 2022 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_CORE_IR_TF_OP_WRAPPER_H_ 17 #define TENSORFLOW_CORE_IR_TF_OP_WRAPPER_H_ 18 19 #include <cstddef> 20 21 #include "llvm/ADT/iterator_range.h" 22 #include "mlir/IR/Operation.h" // from @llvm-project 23 #include "mlir/IR/OperationSupport.h" // from @llvm-project 24 #include "mlir/IR/TypeRange.h" // from @llvm-project 25 #include "tensorflow/core/ir/dialect.h" 26 #include "tensorflow/core/ir/types/dialect.h" 27 #include "tensorflow/core/ir/utility.h" 28 29 namespace mlir { 30 namespace detail { 31 // This class iterates over the control dependencies of the values. 32 template <typename ValueIteratorT> 33 class ControlRetIterator final 34 : public llvm::mapped_iterator_base<ControlRetIterator<ValueIteratorT>, 35 ValueIteratorT, Value> { 36 public: 37 using llvm::mapped_iterator_base<ControlRetIterator<ValueIteratorT>, 38 ValueIteratorT, Value>::mapped_iterator_base; 39 mapElement(Value value)40 Value mapElement(Value value) const { 41 return value.getType().isa<tf_type::ControlType>() 42 ? value 43 : tfg::LookupControlDependency(value); 44 } 45 }; 46 } // namespace detail 47 48 namespace tfg { 49 50 // Wrapper class exposing convenience methods to manipulate TensorFlow graph 51 // nodes uniformly. 52 class TFOp { 53 public: 54 // Wrap an operation. The operation can be null. The constructor must be 55 // marked as implicit to support `llvm::dyn_cast`. 56 TFOp(Operation *op = nullptr); // NOLINT 57 TFOp(Operation & op)58 explicit TFOp(Operation &op) : TFOp(&op) {} 59 60 // Support LLVM-style RTTI. classof(Operation * op)61 static bool classof(Operation *op) { 62 return isa<TFGraphDialect>(op->getDialect()); 63 } 64 65 // Get the wrapped operation. getOperation()66 Operation *getOperation() { return op_; } 67 68 // Returns a pointer to the TensorFlow Graph Dialect. It nevers returns 69 // nullptr. getDialect()70 TFGraphDialect *getDialect() { 71 return cast<TFGraphDialect>(op_->getDialect()); 72 } 73 74 // Split the operands into data and control operands. splitOperands()75 std::pair<OperandRange, OperandRange> splitOperands() { 76 ControlType ctl_type = getDialect()->getControlType(); 77 return SplitDataAndControlValues(op_->getOperands(), ctl_type); 78 } 79 80 // Returns the regular operands, the control operands will be excluded. getNonControlOperands()81 OperandRange getNonControlOperands() { return splitOperands().first; } 82 83 // The control operands are always after the regular inputs. getControlOperands()84 OperandRange getControlOperands() { return splitOperands().second; } 85 86 // Returns the control token produced by this operation. controlRet()87 Value controlRet() { return op_->getResult(op_->getNumResults() - 1); } 88 89 // Returns the non-control results produced by this operation. getNonControlResults()90 ResultRange getNonControlResults() { 91 return op_->getResults().slice(0, op_->getNumResults() - 1); 92 } 93 94 // Returns the node name for this operation. 95 StringAttr nameAttr(); 96 StringRef name(); 97 // Set a new node name for this operation. 98 void setName(const Twine &name); 99 void setName(StringAttr name); 100 101 // Returns the requested device, which is also the "device" field in a 102 // GraphDef. 103 StringAttr requestedDeviceAttr(); 104 StringRef requestedDevice(); 105 // Set a new requested device for this operation. 106 void setRequestedDevice(const Twine &requested_device); 107 void setRequestedDevice(StringAttr requested_device); 108 109 // Returns the assigned device, this field is set by placer in general. 110 StringAttr assignedDeviceAttr(); 111 StringRef assignedDevice(); 112 // Set a new assigned device for this operation. 113 void setAssignedDevice(const Twine &assigned_device); 114 void setAssignedDevice(StringAttr assigned_device); 115 116 // Returns the assigned TPU cluster name. 117 StringAttr tpuReplicate(); 118 // Set the assigned TPU cluster name. 119 void setTpuReplicate(StringAttr tpu_replicate); 120 121 // Returns the device, preferring the assigned device if set, and the 122 // requested device otherwise. deviceAttr()123 StringAttr deviceAttr() { 124 StringAttr device = assignedDeviceAttr(); 125 if (device) { 126 assert(!device.getValue().empty()); 127 return device; 128 } 129 return requestedDeviceAttr(); 130 } device()131 StringRef device() { 132 StringAttr device_attr = deviceAttr(); 133 if (device_attr) return device_attr.getValue(); 134 return ""; 135 } 136 137 // Forward `->` to the underlying operation, exposing the `Operation` methods. 138 Operation *operator->() { return op_; } 139 Operation &operator*() { return *op_; } 140 141 // Converts to true if there is a wrapped operation. 142 explicit operator bool() const { return op_; } 143 144 private: 145 // The wrapped operation. 146 Operation *op_; 147 }; 148 149 // A range iterator to get the control tokens associated with a value range. 150 // This range allows to wrap a ValueRange (or an OperandRange) and iterates on 151 // the control token associated to the producer of each value. For example, if 152 // you wrap the operands of an operation: 153 // OperandControlRetRange range = op->getOperands(); 154 // iterating this range will yield the control edges from each of the operations 155 // (or block arguments) producing these operands. 156 template <typename ValueRangeT> 157 class ControlRetRange final 158 : public llvm::iterator_range< 159 ::mlir::detail::ControlRetIterator<typename ValueRangeT::iterator>> { 160 public: 161 using Base = llvm::iterator_range< 162 ::mlir::detail::ControlRetIterator<typename ValueRangeT::iterator>>; ControlRetRange(ValueRangeT c)163 explicit ControlRetRange(ValueRangeT c) : Base(c.begin(), c.end()) {} 164 165 /// Return the value at the given index. 166 Value operator[](size_t index) const { 167 assert(index < size() && "invalid index into value range"); 168 return *(this->begin() + index); 169 } 170 171 // Return the size of this range. size()172 size_t size() const { return llvm::size(*this); } 173 174 // Return first value in the range. front()175 Value front() { return (*this)[0]; } 176 177 // Compare this range with another. 178 template <typename OtherT> 179 bool operator==(const OtherT &other) const { 180 return llvm::size(*this) == llvm::size(other) && 181 std::equal(this->begin(), this->end(), other.begin()); 182 } 183 template <typename OtherT> 184 bool operator!=(const OtherT &other) const { 185 return !(*this == other); 186 } 187 }; 188 189 using OperandControlRetRange = ControlRetRange<OperandRange>; 190 using ValueControlRetRange = ControlRetRange<ValueRange>; 191 192 } // namespace tfg 193 } // namespace mlir 194 195 #endif // TENSORFLOW_CORE_IR_TF_OP_WRAPPER_H_ 196