1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_CC_FRAMEWORK_OPS_H_ 17 #define TENSORFLOW_CC_FRAMEWORK_OPS_H_ 18 19 #include <type_traits> 20 21 #include "tensorflow/core/framework/tensor.h" 22 #include "tensorflow/core/framework/tensor.pb.h" 23 #include "tensorflow/core/graph/graph.h" 24 #include "tensorflow/core/lib/hash/hash.h" 25 #include "tensorflow/core/lib/strings/strcat.h" 26 27 namespace tensorflow { 28 29 /// @defgroup core Core Tensorflow API 30 31 class Output; 32 33 /// @addtogroup core 34 /// @{ 35 36 /// Represents a node in the computation graph. 37 class Operation { 38 public: Operation()39 Operation() : node_(nullptr) {} 40 explicit Operation(Node* n); 41 num_inputs()42 int32 num_inputs() const { return node_->num_inputs(); } input_type(int32_t o)43 DataType input_type(int32_t o) const { return node_->input_type(o); } 44 Output input(int32_t i) const; 45 num_outputs()46 int32 num_outputs() const { return node_->num_outputs(); } output_type(int32_t o)47 DataType output_type(int32_t o) const { return node_->output_type(o); } 48 Output output(int32_t i) const; 49 node()50 Node* node() const { return node_; } 51 52 uint64 hash(int32_t index) const; 53 54 bool operator==(const Operation& other) const { return node_ == other.node_; } 55 56 private: 57 typedef std::vector<std::pair<Node*, int32>> Inputs; 58 static Inputs GetInputs(Node* node); 59 60 Inputs inputs_; 61 Node* node_; 62 }; 63 64 /// Represents a tensor value produced by an Operation. 65 class Output { 66 public: 67 Output() = default; Output(Node * n)68 explicit Output(Node* n) : op_(n) {} Output(Node * n,int32_t index)69 Output(Node* n, int32_t index) : op_(n), index_(index) {} Output(const Operation & op,int32_t index)70 Output(const Operation& op, int32_t index) : op_(op), index_(index) {} 71 op()72 Operation op() const { return op_; } node()73 Node* node() const { return op().node(); } index()74 int32 index() const { return index_; } type()75 DataType type() const { return op_.output_type(index_); } name()76 std::string name() const { 77 return strings::StrCat(node()->name(), ":", index()); 78 } 79 bool operator==(const Output& other) const { 80 return op_ == other.op_ && index_ == other.index_; 81 } 82 hash()83 uint64 hash() const { return op_.hash(index_); } 84 85 private: 86 Operation op_ = Operation(nullptr); 87 int32 index_ = 0; 88 }; 89 90 /// Hash class that can be used for e.g. storing Outputs in an unordered_map 91 struct OutputHash { operatorOutputHash92 std::size_t operator()(const Output& output) const { 93 return Hash64Combine(std::hash<Node*>()(output.node()), 94 std::hash<int32>()(output.index())); 95 } 96 }; 97 98 /// Represents a tensor value that can be used as an operand to an Operation. 99 class Input { 100 public: 101 /// Initializer enables constructing an Input object from various kinds of C++ 102 /// constants such as simple primitive constants and nested initializer lists 103 /// representing a multi-dimensional array. Initializer constructors are all 104 /// templates, so the aforementioned kinds of C++ constants can be used to 105 /// construct an Initializer. Initializer stores the value it got constructed 106 /// with in a Tensor object. 107 struct Initializer { 108 /// Construct from a scalar value of an arithmetic type or a type that can 109 /// be converted to a string (eg. a string literal). 110 template <typename T, typename = typename std::enable_if< 111 std::is_arithmetic<T>::value || 112 std::is_convertible<T, std::string>::value>::type> InitializerInitializer113 Initializer(const T& v) { // NOLINT(runtime/explicit) 114 typedef typename RealType<T>::type RealT; 115 Tensor t(DataTypeToEnum<RealT>::v(), TensorShape()); 116 t.flat<RealT>()(0) = RealT(v); 117 tensor = t; 118 } 119 InitializerInitializer120 Initializer(const Tensor& t) : tensor(t) {} // NOLINT(runtime/explicit) 121 122 /// Construct from a scalar value and an explicit shape 123 template <typename T, typename = typename std::enable_if< 124 std::is_arithmetic<T>::value || 125 std::is_convertible<T, std::string>::value>::type> InitializerInitializer126 Initializer(const T& v, const TensorShape& shape) { 127 typedef typename RealType<T>::type RealT; 128 Tensor t(DataTypeToEnum<RealT>::v(), shape); 129 for (int64_t i = 0; i < t.NumElements(); ++i) { 130 t.flat<RealT>()(i) = RealT(v); 131 } 132 tensor = t; 133 } 134 135 /// Construct from a initializer list of scalars (a one-dimensional tensor). 136 template <typename T, typename = typename std::enable_if< 137 std::is_arithmetic<T>::value || 138 std::is_convertible<T, std::string>::value>::type> InitializerInitializer139 Initializer( 140 const std::initializer_list<T>& v) { // NOLINT(runtime/explicit) 141 typedef typename RealType<T>::type RealT; 142 Tensor t(DataTypeToEnum<RealT>::v(), 143 TensorShape{static_cast<int>(v.size())}); 144 std::copy_n(v.begin(), v.size(), t.flat<RealT>().data()); 145 tensor = t; 146 } 147 148 /// Construct from a initializer list of scalars and an explicit shape. 149 template <typename T, typename = typename std::enable_if< 150 std::is_arithmetic<T>::value || 151 std::is_convertible<T, std::string>::value>::type> InitializerInitializer152 Initializer(const std::initializer_list<T>& v, const TensorShape& shape) { 153 typedef typename RealType<T>::type RealT; 154 Tensor t(DataTypeToEnum<RealT>::v(), shape); 155 if (t.NumElements() != static_cast<int64_t>(v.size())) { 156 status = errors::InvalidArgument( 157 "Cannot construct a tensor with ", t.NumElements(), 158 " from an initializer list with ", v.size(), " elements"); 159 return; 160 } 161 std::copy_n(v.begin(), v.size(), t.flat<RealT>().data()); 162 tensor = t; 163 } 164 165 /// Construct a multi-dimensional tensor from a nested initializer 166 /// list. Note that C++ syntax allows nesting of arbitrarily typed 167 /// initializer lists, so such invalid initializers cannot be disallowed at 168 /// compile time. This function performs checks to make sure that the nested 169 /// initializer list is indeed a valid multi-dimensional tensor. 170 Initializer(const std::initializer_list<Initializer>& v); 171 172 // START_SKIP_DOXYGEN 173 template <typename T, bool = std::is_convertible<T, std::string>::value> 174 struct RealType { 175 typedef tstring type; 176 }; 177 178 template <typename T> 179 struct RealType<T, false> { 180 typedef T type; 181 }; 182 // END_SKIP_DOXYGEN 183 184 TensorProto AsTensorProto() { 185 TensorProto tensor_proto; 186 if (tensor.NumElements() > 1) { 187 tensor.AsProtoTensorContent(&tensor_proto); 188 } else { 189 tensor.AsProtoField(&tensor_proto); 190 } 191 return tensor_proto; 192 } 193 194 Status status; 195 Tensor tensor; 196 }; 197 198 /// All of Input's constructors are implicit. Input can be implicitly 199 /// constructed from the following objects : 200 /// * Output: This is so that the output of an Operation can be directly used 201 /// as the input to a op wrapper, which takes Inputs. 202 /// * A scalar, or a multi-dimensional tensor specified as a recursive 203 /// initializer list. This enables directly passing constants as 204 /// inputs to op wrappers. 205 /// * A Tensor object. 206 Input(const Output& o) : output_(o) {} // NOLINT(runtime/explicit) 207 208 template <typename T, typename = typename std::enable_if< 209 std::is_arithmetic<T>::value || 210 std::is_convertible<T, std::string>::value>::type> 211 Input(const T& v) // NOLINT(runtime/explicit) 212 : Input(Initializer(v)) {} 213 214 Input(const Initializer& init) // NOLINT(runtime/explicit) 215 : status_(init.status), 216 tensor_(init.tensor) {} 217 218 Input(const Tensor& t) // NOLINT(runtime/explicit) 219 : status_(OkStatus()), tensor_(t) {} 220 221 Input(const std::initializer_list<Initializer>& 222 init) { // NOLINT(runtime/explicit) 223 for (const auto& i : init) { 224 if (!i.status.ok()) { 225 status_ = i.status; 226 return; 227 } 228 } 229 tensor_ = Initializer(init).tensor; 230 } 231 232 /// Constructor specifying a node name, index and datatype. This should only 233 /// be used for specifying a backward edge, needed by control flow. 234 Input(const std::string& name, int32_t i, DataType dt) 235 : node_name_(name), index_(i), data_type_(dt) {} 236 237 Node* node() const { return output_.node(); } 238 std::string node_name() const { return node_name_; } 239 int32 index() const { return node_name_.empty() ? output_.index() : index_; } 240 DataType data_type() const { return data_type_; } 241 Status status() const { return status_; } 242 const Tensor& tensor() const { return tensor_; } 243 244 private: 245 Status status_; 246 Output output_ = Output(Operation(nullptr), 0); 247 Tensor tensor_; 248 const std::string node_name_ = ""; 249 int32 index_ = 0; 250 DataType data_type_ = DT_INVALID; 251 }; 252 253 /// A type for representing the output of ops that produce more than one output, 254 /// or a list of tensors. 255 typedef std::vector<Output> OutputList; 256 257 /// A type for representing the input to ops that require a list of tensors. 258 class InputList { 259 public: 260 /// Implicitly convert a list of outputs to a list of inputs. This is useful 261 /// to write code such as ops::Concat(ops::Split(x, 4)). 262 InputList(const OutputList& out) { // NOLINT(runtime/explicit) 263 for (auto const& x : out) { 264 inputs_.push_back(x); 265 } 266 } 267 268 InputList( 269 const std::initializer_list<Input>& inputs) // NOLINT(runtime/explicit) 270 : inputs_(inputs.begin(), inputs.end()) {} 271 272 InputList(const tensorflow::gtl::ArraySlice<Input>& 273 inputs) // NOLINT(runtime/explicit) 274 : inputs_(inputs.begin(), inputs.end()) {} 275 276 InputList( 277 const std::initializer_list<Output>& out) { // NOLINT(runtime/explicit) 278 for (auto const& x : out) { 279 inputs_.push_back(x); 280 } 281 } 282 283 typename std::vector<Input>::iterator begin() { return inputs_.begin(); } 284 typename std::vector<Input>::iterator end() { return inputs_.end(); } 285 typename std::vector<Input>::const_iterator begin() const { 286 return inputs_.begin(); 287 } 288 typename std::vector<Input>::const_iterator end() const { 289 return inputs_.end(); 290 } 291 292 private: 293 std::vector<Input> inputs_; 294 }; 295 296 /// @} 297 298 } // namespace tensorflow 299 300 #endif // TENSORFLOW_CC_FRAMEWORK_OPS_H_ 301