xref: /aosp_15_r20/external/tensorflow/tensorflow/core/ir/tf_op_wrapper.h (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_CORE_IR_TF_OP_WRAPPER_H_
17 #define TENSORFLOW_CORE_IR_TF_OP_WRAPPER_H_
18 
19 #include <cstddef>
20 
21 #include "llvm/ADT/iterator_range.h"
22 #include "mlir/IR/Operation.h"  // from @llvm-project
23 #include "mlir/IR/OperationSupport.h"  // from @llvm-project
24 #include "mlir/IR/TypeRange.h"  // from @llvm-project
25 #include "tensorflow/core/ir/dialect.h"
26 #include "tensorflow/core/ir/types/dialect.h"
27 #include "tensorflow/core/ir/utility.h"
28 
29 namespace mlir {
30 namespace detail {
31 // This class iterates over the control dependencies of the values.
32 template <typename ValueIteratorT>
33 class ControlRetIterator final
34     : public llvm::mapped_iterator_base<ControlRetIterator<ValueIteratorT>,
35                                         ValueIteratorT, Value> {
36  public:
37   using llvm::mapped_iterator_base<ControlRetIterator<ValueIteratorT>,
38                                    ValueIteratorT, Value>::mapped_iterator_base;
39 
mapElement(Value value)40   Value mapElement(Value value) const {
41     return value.getType().isa<tf_type::ControlType>()
42                ? value
43                : tfg::LookupControlDependency(value);
44   }
45 };
46 }  // namespace detail
47 
48 namespace tfg {
49 
50 // Wrapper class exposing convenience methods to manipulate TensorFlow graph
51 // nodes uniformly.
52 class TFOp {
53  public:
54   // Wrap an operation. The operation can be null. The constructor must be
55   // marked as implicit to support `llvm::dyn_cast`.
56   TFOp(Operation *op = nullptr);  // NOLINT
57 
TFOp(Operation & op)58   explicit TFOp(Operation &op) : TFOp(&op) {}
59 
60   // Support LLVM-style RTTI.
classof(Operation * op)61   static bool classof(Operation *op) {
62     return isa<TFGraphDialect>(op->getDialect());
63   }
64 
65   // Get the wrapped operation.
getOperation()66   Operation *getOperation() { return op_; }
67 
68   // Returns a pointer to the TensorFlow Graph Dialect. It nevers returns
69   // nullptr.
getDialect()70   TFGraphDialect *getDialect() {
71     return cast<TFGraphDialect>(op_->getDialect());
72   }
73 
74   // Split the operands into data and control operands.
splitOperands()75   std::pair<OperandRange, OperandRange> splitOperands() {
76     ControlType ctl_type = getDialect()->getControlType();
77     return SplitDataAndControlValues(op_->getOperands(), ctl_type);
78   }
79 
80   // Returns the regular operands, the control operands will be excluded.
getNonControlOperands()81   OperandRange getNonControlOperands() { return splitOperands().first; }
82 
83   // The control operands are always after the regular inputs.
getControlOperands()84   OperandRange getControlOperands() { return splitOperands().second; }
85 
86   // Returns the control token produced by this operation.
controlRet()87   Value controlRet() { return op_->getResult(op_->getNumResults() - 1); }
88 
89   // Returns the non-control results produced by this operation.
getNonControlResults()90   ResultRange getNonControlResults() {
91     return op_->getResults().slice(0, op_->getNumResults() - 1);
92   }
93 
94   // Returns the node name for this operation.
95   StringAttr nameAttr();
96   StringRef name();
97   // Set a new node name for this operation.
98   void setName(const Twine &name);
99   void setName(StringAttr name);
100 
101   // Returns the requested device, which is also the "device" field in a
102   // GraphDef.
103   StringAttr requestedDeviceAttr();
104   StringRef requestedDevice();
105   // Set a new requested device for this operation.
106   void setRequestedDevice(const Twine &requested_device);
107   void setRequestedDevice(StringAttr requested_device);
108 
109   // Returns the assigned device, this field is set by placer in general.
110   StringAttr assignedDeviceAttr();
111   StringRef assignedDevice();
112   // Set a new assigned device for this operation.
113   void setAssignedDevice(const Twine &assigned_device);
114   void setAssignedDevice(StringAttr assigned_device);
115 
116   // Returns the assigned TPU cluster name.
117   StringAttr tpuReplicate();
118   // Set the assigned TPU cluster name.
119   void setTpuReplicate(StringAttr tpu_replicate);
120 
121   // Returns the device, preferring the assigned device if set, and the
122   // requested device otherwise.
deviceAttr()123   StringAttr deviceAttr() {
124     StringAttr device = assignedDeviceAttr();
125     if (device) {
126       assert(!device.getValue().empty());
127       return device;
128     }
129     return requestedDeviceAttr();
130   }
device()131   StringRef device() {
132     StringAttr device_attr = deviceAttr();
133     if (device_attr) return device_attr.getValue();
134     return "";
135   }
136 
137   // Forward `->` to the underlying operation, exposing the `Operation` methods.
138   Operation *operator->() { return op_; }
139   Operation &operator*() { return *op_; }
140 
141   // Converts to true if there is a wrapped operation.
142   explicit operator bool() const { return op_; }
143 
144  private:
145   // The wrapped operation.
146   Operation *op_;
147 };
148 
149 // A range iterator to get the control tokens associated with a value range.
150 // This range allows to wrap a ValueRange (or an OperandRange) and iterates on
151 // the control token associated to the producer of each value. For example, if
152 // you wrap the operands of an operation:
153 //     OperandControlRetRange range = op->getOperands();
154 // iterating this range will yield the control edges from each of the operations
155 // (or block arguments) producing these operands.
156 template <typename ValueRangeT>
157 class ControlRetRange final
158     : public llvm::iterator_range<
159           ::mlir::detail::ControlRetIterator<typename ValueRangeT::iterator>> {
160  public:
161   using Base = llvm::iterator_range<
162       ::mlir::detail::ControlRetIterator<typename ValueRangeT::iterator>>;
ControlRetRange(ValueRangeT c)163   explicit ControlRetRange(ValueRangeT c) : Base(c.begin(), c.end()) {}
164 
165   /// Return the value at the given index.
166   Value operator[](size_t index) const {
167     assert(index < size() && "invalid index into value range");
168     return *(this->begin() + index);
169   }
170 
171   // Return the size of this range.
size()172   size_t size() const { return llvm::size(*this); }
173 
174   // Return first value in the range.
front()175   Value front() { return (*this)[0]; }
176 
177   // Compare this range with another.
178   template <typename OtherT>
179   bool operator==(const OtherT &other) const {
180     return llvm::size(*this) == llvm::size(other) &&
181            std::equal(this->begin(), this->end(), other.begin());
182   }
183   template <typename OtherT>
184   bool operator!=(const OtherT &other) const {
185     return !(*this == other);
186   }
187 };
188 
189 using OperandControlRetRange = ControlRetRange<OperandRange>;
190 using ValueControlRetRange = ControlRetRange<ValueRange>;
191 
192 }  // namespace tfg
193 }  // namespace mlir
194 
195 #endif  // TENSORFLOW_CORE_IR_TF_OP_WRAPPER_H_
196