1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_COMPILER_XLA_SHAPE_TREE_H_ 17 #define TENSORFLOW_COMPILER_XLA_SHAPE_TREE_H_ 18 19 #include <algorithm> 20 #include <functional> 21 #include <iterator> 22 #include <memory> 23 #include <type_traits> 24 #include <utility> 25 #include <vector> 26 27 #include "absl/algorithm/container.h" 28 #include "absl/container/inlined_vector.h" 29 #include "absl/functional/function_ref.h" 30 #include "absl/types/span.h" 31 #include "tensorflow/compiler/xla/shape_util.h" 32 #include "tensorflow/compiler/xla/status_macros.h" 33 #include "tensorflow/core/lib/core/status.h" 34 #include "tensorflow/core/lib/gtl/iterator_range.h" 35 #include "tensorflow/core/platform/logging.h" 36 37 namespace xla { 38 39 namespace internal { 40 41 class IndexTable { 42 public: 43 // Use indices, rather than pointers, so index table can be copied between 44 // ShapeTrees. 45 struct Entry { 46 // Index of the node in the nodes vector. 47 size_t node_id; 48 // Index of the first child of this node in the index table (-1 for leaves). 49 std::make_signed_t<size_t> children_start_id = -1; 50 }; 51 52 IndexTable() = default; 53 explicit IndexTable(const Shape& shape); 54 empty()55 bool empty() const { return entries_.empty(); } 56 57 const Entry& operator[](ShapeIndexView index) const; 58 59 private: 60 void CreateEntry(Entry& entry, const Shape& shape, size_t& next_node_id); 61 62 absl::InlinedVector<Entry, 1> entries_; 63 }; 64 65 } // namespace internal 66 67 // A ShapeTree<T> is a recursive data structure which mirrors the structure of a 68 // XLA shape and holds a value of type T for each subshape (i.e. tuple or array) 69 // in the shape. For array shapes, a ShapeTree trivially holds a single value of 70 // type T. 71 // 72 // For tuple shapes which can be an arbitrary tree with arrays at the leaves, a 73 // ShapeTree is an identically structured tree with data elements of type T at 74 // every node. I.e. the root is a tuple by definition, all interior nodes are 75 // also tuples, and all leaves are arrays. 76 // 77 // Like the Shape data structure, this is a tree and tuple elements cannot be 78 // duplicated. That is, every distinct ShapeIndex in the Shape has a unique T 79 // object. 80 // 81 // Normally a ShapeTree owns its Shape, but for efficiency reasons, sometimes 82 // it's helpful not to copy a Shape just to make a ShapeTree. In these cases, 83 // you can pass a Shape* instead of a Shape to the ShapeTree constructor. It's 84 // then up to you to ensure that the pointed-to Shape isn't freed, moved or 85 // modified before its ShapeTree goes away. 86 template <typename T> 87 class ShapeTree { 88 template <typename U> 89 friend class ShapeTree; 90 91 public: 92 // TODO(cjfj): Don't store ShapeIndex with data. Generate it or cache it? 93 using Node = std::pair<ShapeIndex, T>; 94 using Nodes = absl::InlinedVector<Node, 1>; 95 using IndexTable = internal::IndexTable; 96 97 template <typename Iterator, typename ValueType> 98 class LeafIterator; 99 100 // Default constructor creates a tree with a nil shape (i.e. an empty tuple). ShapeTree()101 ShapeTree() : ShapeTree(ShapeUtil::MakeNil()) {} 102 103 // Create ShapeTree with the given shape, and default-constructed T values for 104 // all nodes. 105 // 106 // The version that takes a pointer may be cheaper because it doesn't require 107 // any Shape copies, but then it's up to you to ensure that the pointer stays 108 // alive longer than this ShapeTree. ShapeTree(Shape shape)109 explicit ShapeTree(Shape shape) 110 : ShapeTree(std::make_shared<Shape>(std::move(shape))) {} 111 ShapeTree(const Shape * shape)112 explicit ShapeTree(const Shape* shape) 113 : ShapeTree(shape, CreateNodes(*shape)) {} 114 115 // Create ShapeTree with the given shape, and init_value for all nodes. ShapeTree(Shape shape,const T & init_value)116 ShapeTree(Shape shape, const T& init_value) 117 : ShapeTree(std::make_shared<Shape>(std::move(shape)), init_value) {} 118 ShapeTree(const Shape * shape,const T & init_value)119 ShapeTree(const Shape* shape, const T& init_value) 120 : ShapeTree(shape, CreateNodes(*shape, init_value)) {} 121 122 // Returns the data element associated with the array in the shape at the 123 // given index (see ShapeUtil::GetSubshape for how indexes are defined). element(ShapeIndexView index)124 const T& element(ShapeIndexView index) const { return find(index)->second; } mutable_element(ShapeIndexView index)125 T* mutable_element(ShapeIndexView index) { return &find(index)->second; } 126 127 // Return the shape represented with this ShapeTree. shape()128 const Shape& shape() const { return *shape_; } 129 130 // A ShapeTree object can own the underlying Shape pointer (via the 131 // shape_storage_ member), or can point to a Shape object owned by the caller. 132 // This API replaces the underlying Shape object to the one supplied by the 133 // caller, whom must ensure the object remain valid for the whole lifetime of 134 // this ShapeTree object, and also that the Shape is consistent with it. replace_shape_ptr(const Shape & shape)135 void replace_shape_ptr(const Shape& shape) { 136 if (shape_storage_ != nullptr) { 137 DCHECK_EQ(shape, *shape_storage_); 138 shape_storage_ = nullptr; 139 } 140 shape_ = &shape; 141 } 142 143 // Returns true if the node at the given index is a leaf node (an array 144 // shape). IsLeaf(ShapeIndexView index)145 bool IsLeaf(ShapeIndexView index) const { 146 return index_table_[index].children_start_id == -1; 147 } 148 149 using iterator = typename Nodes::iterator; 150 using const_iterator = typename Nodes::const_iterator; 151 using reverse_iterator = typename Nodes::reverse_iterator; 152 using const_reverse_iterator = typename Nodes::const_reverse_iterator; 153 154 using leaf_iterator = LeafIterator<iterator, Node>; 155 using const_leaf_iterator = LeafIterator<const_iterator, const Node>; 156 using reverse_leaf_iterator = std::reverse_iterator<leaf_iterator>; 157 using const_reverse_leaf_iterator = 158 std::reverse_iterator<const_leaf_iterator>; 159 begin()160 iterator begin() { return nodes_.begin(); } end()161 iterator end() { return nodes_.end(); } begin()162 const_iterator begin() const { return nodes_.begin(); } end()163 const_iterator end() const { return nodes_.end(); } 164 rbegin()165 reverse_iterator rbegin() { return nodes_.rbegin(); } rend()166 reverse_iterator rend() { return nodes_.rend(); } rbegin()167 const_reverse_iterator rbegin() const { return nodes_.rbegin(); } rend()168 const_reverse_iterator rend() const { return nodes_.rend(); } 169 170 // leaf_begin()/leaf_end() iterates over all leaf nodes (nodes with no 171 // children). leaf_begin()172 leaf_iterator leaf_begin() { return leaf_iterator(*this, nodes_.begin()); } leaf_end()173 leaf_iterator leaf_end() { return leaf_iterator(*this, nodes_.end()); } leaf_begin()174 const_leaf_iterator leaf_begin() const { 175 return const_leaf_iterator(*this, nodes_.begin()); 176 } leaf_end()177 const_leaf_iterator leaf_end() const { 178 return const_leaf_iterator(*this, nodes_.end()); 179 } 180 // range-based iterator for leaf_begin()/leaf_end(). leaves()181 tensorflow::gtl::iterator_range<leaf_iterator> leaves() { 182 return tensorflow::gtl::make_range(leaf_begin(), leaf_end()); 183 } leaves()184 tensorflow::gtl::iterator_range<const_leaf_iterator> leaves() const { 185 return tensorflow::gtl::make_range(leaf_begin(), leaf_end()); 186 } 187 leaf_rbegin()188 reverse_leaf_iterator leaf_rbegin() { 189 return reverse_leaf_iterator(leaf_end()); 190 } leaf_rend()191 reverse_leaf_iterator leaf_rend() { 192 return reverse_leaf_iterator(leaf_begin()); 193 } leaf_rbegin()194 const_reverse_leaf_iterator leaf_rbegin() const { 195 return const_reverse_leaf_iterator(leaf_end()); 196 } leaf_rend()197 const_reverse_leaf_iterator leaf_rend() const { 198 return const_reverse_leaf_iterator(leaf_begin()); 199 } 200 201 // Returns an iterator pointing to the given ShapeIndex. 202 // REQUIRES: index must exist in the ShapeTree. find(ShapeIndexView index)203 iterator find(ShapeIndexView index) { 204 return nodes_.begin() + index_table_[index].node_id; 205 } find(ShapeIndexView index)206 const_iterator find(ShapeIndexView index) const { 207 return nodes_.begin() + index_table_[index].node_id; 208 } 209 210 // Returns the number of leaf nodes in the tree. leaf_count()211 int64_t leaf_count() const { return std::distance(leaf_begin(), leaf_end()); } 212 213 // TODO(cjfj): Remove the `ForEach...` methods. They are redundant. 214 // Recursively traverses the shape and calls the given function at each 215 // element. ForEachElement(absl::FunctionRef<void (const ShapeIndex &,const T &)> func)216 void ForEachElement( 217 absl::FunctionRef<void(const ShapeIndex&, const T&)> func) const { 218 for (const Node& node : nodes_) { 219 func(node.first, node.second); 220 } 221 } 222 ForEachMutableElement(absl::FunctionRef<void (const ShapeIndex &,T *)> func)223 void ForEachMutableElement( 224 absl::FunctionRef<void(const ShapeIndex&, T*)> func) { 225 for (Node& node : nodes_) { 226 func(node.first, &node.second); 227 } 228 } 229 230 // Like ForEach(Mutable)Element, but the callable returns a Status instead of 231 // void. The first non-OK return value is returned by the ForEach* function. ForEachElementWithStatus(absl::FunctionRef<Status (const ShapeIndex &,const T &)> func)232 Status ForEachElementWithStatus( 233 absl::FunctionRef<Status(const ShapeIndex&, const T&)> func) const { 234 for (const Node& node : nodes_) { 235 TF_RETURN_IF_ERROR(func(node.first, node.second)); 236 } 237 return OkStatus(); 238 } 239 ForEachMutableElementWithStatus(absl::FunctionRef<Status (const ShapeIndex &,T *)> func)240 Status ForEachMutableElementWithStatus( 241 absl::FunctionRef<Status(const ShapeIndex&, T*)> func) { 242 for (Node& node : nodes_) { 243 TF_RETURN_IF_ERROR(func(node.first, &node.second)); 244 } 245 return OkStatus(); 246 } 247 248 // Maps each element to generate a new tree with the same shape. 249 template <typename U> Map(absl::FunctionRef<U (const T &)> func)250 ShapeTree<U> Map(absl::FunctionRef<U(const T&)> func) { 251 typename ShapeTree<U>::Nodes result_nodes; 252 result_nodes.reserve(nodes_.size()); 253 for (const Node& node : nodes_) { 254 result_nodes.push_back({node.first, func(node.second)}); 255 } 256 257 ShapeTree<U> result(shape_, std::move(result_nodes)); 258 result.index_table_ = index_table_; 259 result.shape_storage_ = shape_storage_; 260 return result; 261 } 262 263 // Copy the subtree of values from 'other' rooted at ShapeIndex 'src_index' 264 // into the subtree of value in this ShapeTree rooted at 'dst_index'. 265 // 266 // Precondition: The subshape of other.shape() at index src_index must be 267 // compatible with the subshape of shape() at index dst_index. CopySubtreeFrom(const ShapeTree<T> & other,const ShapeIndex & src_index,const ShapeIndex & dst_index)268 void CopySubtreeFrom(const ShapeTree<T>& other, const ShapeIndex& src_index, 269 const ShapeIndex& dst_index) { 270 const Shape& src_shape = ShapeUtil::GetSubshape(other.shape(), src_index); 271 const Shape& dst_shape = ShapeUtil::GetSubshape(shape(), dst_index); 272 CHECK(ShapeUtil::Compatible(src_shape, dst_shape)) 273 << src_shape << ", " << dst_shape; 274 275 // Replace the prefix `src_index` with `dst_index`. 276 auto replace_shape_index_prefix = [&](const ShapeIndex& index) { 277 auto without_prefix = ShapeIndexView(index).subspan(src_index.size()); 278 ShapeIndex result; 279 result.reserve(dst_index.size() + without_prefix.size()); 280 result.insert(result.end(), dst_index.begin(), dst_index.end()); 281 result.insert(result.end(), without_prefix.begin(), without_prefix.end()); 282 return result; 283 }; 284 285 auto first = other.find(src_index); 286 auto last = first + ShapeUtil::SubshapeCount(src_shape); 287 288 std::transform(first, last, find(dst_index), [&](const Node& node) -> Node { 289 return {replace_shape_index_prefix(node.first), node.second}; 290 }); 291 } 292 SubShapeTree(const ShapeIndex & index)293 StatusOr<ShapeTree<T>> SubShapeTree(const ShapeIndex& index) const { 294 TF_ASSIGN_OR_RETURN(const Shape* sub_shape, 295 ShapeUtil::TryGetSubshape(shape(), index)); 296 size_t count = ShapeUtil::SubshapeCount(*sub_shape); 297 Nodes sub_tree_nodes; 298 sub_tree_nodes.reserve(count); 299 for (auto it = find(index), end = it + count; it != end; ++it) { 300 // For each shape index, remove the prefix `index`. 301 auto without_prefix = ShapeIndexView(it->first).subspan(index.size()); 302 sub_tree_nodes.push_back(Node{without_prefix, it->second}); 303 } 304 return ShapeTree(sub_shape, std::move(sub_tree_nodes)); 305 } 306 307 bool operator==(const ShapeTree<T>& other) const { 308 return nodes_ == other.nodes_; 309 } 310 bool operator!=(const ShapeTree<T>& other) const { return !(*this == other); } 311 312 private: ShapeTree(std::shared_ptr<Shape> shape)313 explicit ShapeTree(std::shared_ptr<Shape> shape) : ShapeTree(shape.get()) { 314 shape_storage_.swap(shape); 315 } 316 ShapeTree(std::shared_ptr<Shape> shape,const T & init_value)317 ShapeTree(std::shared_ptr<Shape> shape, const T& init_value) 318 : ShapeTree(shape.get(), init_value) { 319 shape_storage_.swap(shape); 320 } 321 ShapeTree(const Shape * shape,Nodes nodes)322 ShapeTree(const Shape* shape, Nodes nodes) 323 : nodes_(std::move(nodes)), index_table_(*shape), shape_(shape) { 324 DCHECK_EQ(nodes_.size(), ShapeUtil::SubshapeCount(*shape)); 325 } 326 327 template <typename... Ts> CreateNodes(const Shape & shape,Ts &&...args)328 static Nodes CreateNodes(const Shape& shape, Ts&&... args) { 329 Nodes nodes; 330 ShapeUtil::ForEachSubshape( 331 shape, [&](const Shape&, const ShapeIndex& index) { 332 nodes.push_back({index, T(std::forward<Ts>(args)...)}); 333 }); 334 return nodes; 335 } 336 337 // The nodes in this shape tree. 338 Nodes nodes_; 339 340 // Index table for node lookups. Each entry contains the index of the first 341 // child of the node at that index, or -1 for leaf nodes. Evaluated lazily. 342 IndexTable index_table_; 343 344 // If we own our Shape, this field contains it, and shape_ is a pointer into 345 // here. Otherwise if we don't own our shape, this is nullptr. 346 std::shared_ptr<Shape> shape_storage_; 347 348 // The XLA shape mirrored in this ShapeTree. This is either 349 // shape_storage_.get() or the Shape pointer passed to our constructor. 350 const Shape* shape_; 351 }; 352 353 // Internal iterator that performs a pre-order walk of the leaves. This is cheap 354 // to copy. The iterator value_type is equivalent to a std::pair<ShapeIndex,T>&, 355 // similar to std::map. 356 template <typename T> 357 template <typename Iterator, typename ValueType> 358 class ShapeTree<T>::LeafIterator { 359 public: 360 using iterator_category = std::bidirectional_iterator_tag; 361 using value_type = ValueType; 362 using difference_type = ptrdiff_t; 363 using pointer = value_type*; 364 using reference = value_type&; 365 LeafIterator(const ShapeTree & tree,Iterator it)366 LeafIterator(const ShapeTree& tree, Iterator it) : tree_(tree), it_(it) { 367 while ((it_ != tree_.nodes_.end()) && !IsLeaf()) ++it_; 368 } 369 370 LeafIterator& operator++() { 371 do { 372 ++it_; 373 } while ((it_ != tree_.nodes_.end()) && !IsLeaf()); 374 return *this; 375 } 376 377 LeafIterator operator++(int) { 378 auto prev = *this; 379 ++(*this); 380 return prev; 381 } 382 383 LeafIterator& operator--() { 384 do { 385 --it_; 386 } while ((it_ != tree_.nodes_.begin()) && !IsLeaf()); 387 return *this; 388 } 389 390 LeafIterator operator--(int) { 391 auto prev = *this; 392 --(*this); 393 return prev; 394 } 395 396 bool operator==(const LeafIterator& other) const { return it_ == other.it_; } 397 bool operator!=(const LeafIterator& other) const { return !(*this == other); } 398 ValueType& operator*() const { return *it_; } 399 ValueType* operator->() const { return &*it_; } 400 401 private: IsLeaf()402 bool IsLeaf() const { return tree_.IsLeaf(it_->first); } 403 404 const ShapeTree<T>& tree_; 405 Iterator it_; 406 }; 407 408 } // namespace xla 409 410 #endif // TENSORFLOW_COMPILER_XLA_SHAPE_TREE_H_ 411