xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/xla/shape_tree.h (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_COMPILER_XLA_SHAPE_TREE_H_
17 #define TENSORFLOW_COMPILER_XLA_SHAPE_TREE_H_
18 
19 #include <algorithm>
20 #include <functional>
21 #include <iterator>
22 #include <memory>
23 #include <type_traits>
24 #include <utility>
25 #include <vector>
26 
27 #include "absl/algorithm/container.h"
28 #include "absl/container/inlined_vector.h"
29 #include "absl/functional/function_ref.h"
30 #include "absl/types/span.h"
31 #include "tensorflow/compiler/xla/shape_util.h"
32 #include "tensorflow/compiler/xla/status_macros.h"
33 #include "tensorflow/core/lib/core/status.h"
34 #include "tensorflow/core/lib/gtl/iterator_range.h"
35 #include "tensorflow/core/platform/logging.h"
36 
37 namespace xla {
38 
39 namespace internal {
40 
41 class IndexTable {
42  public:
43   // Use indices, rather than pointers, so index table can be copied between
44   // ShapeTrees.
45   struct Entry {
46     // Index of the node in the nodes vector.
47     size_t node_id;
48     // Index of the first child of this node in the index table (-1 for leaves).
49     std::make_signed_t<size_t> children_start_id = -1;
50   };
51 
52   IndexTable() = default;
53   explicit IndexTable(const Shape& shape);
54 
empty()55   bool empty() const { return entries_.empty(); }
56 
57   const Entry& operator[](ShapeIndexView index) const;
58 
59  private:
60   void CreateEntry(Entry& entry, const Shape& shape, size_t& next_node_id);
61 
62   absl::InlinedVector<Entry, 1> entries_;
63 };
64 
65 }  // namespace internal
66 
67 // A ShapeTree<T> is a recursive data structure which mirrors the structure of a
68 // XLA shape and holds a value of type T for each subshape (i.e. tuple or array)
69 // in the shape. For array shapes, a ShapeTree trivially holds a single value of
70 // type T.
71 //
72 // For tuple shapes which can be an arbitrary tree with arrays at the leaves, a
73 // ShapeTree is an identically structured tree with data elements of type T at
74 // every node. I.e. the root is a tuple by definition, all interior nodes are
75 // also tuples, and all leaves are arrays.
76 //
77 // Like the Shape data structure, this is a tree and tuple elements cannot be
78 // duplicated. That is, every distinct ShapeIndex in the Shape has a unique T
79 // object.
80 //
81 // Normally a ShapeTree owns its Shape, but for efficiency reasons, sometimes
82 // it's helpful not to copy a Shape just to make a ShapeTree.  In these cases,
83 // you can pass a Shape* instead of a Shape to the ShapeTree constructor.  It's
84 // then up to you to ensure that the pointed-to Shape isn't freed, moved or
85 // modified before its ShapeTree goes away.
86 template <typename T>
87 class ShapeTree {
88   template <typename U>
89   friend class ShapeTree;
90 
91  public:
92   // TODO(cjfj): Don't store ShapeIndex with data. Generate it or cache it?
93   using Node = std::pair<ShapeIndex, T>;
94   using Nodes = absl::InlinedVector<Node, 1>;
95   using IndexTable = internal::IndexTable;
96 
97   template <typename Iterator, typename ValueType>
98   class LeafIterator;
99 
100   // Default constructor creates a tree with a nil shape (i.e. an empty tuple).
ShapeTree()101   ShapeTree() : ShapeTree(ShapeUtil::MakeNil()) {}
102 
103   // Create ShapeTree with the given shape, and default-constructed T values for
104   // all nodes.
105   //
106   // The version that takes a pointer may be cheaper because it doesn't require
107   // any Shape copies, but then it's up to you to ensure that the pointer stays
108   // alive longer than this ShapeTree.
ShapeTree(Shape shape)109   explicit ShapeTree(Shape shape)
110       : ShapeTree(std::make_shared<Shape>(std::move(shape))) {}
111 
ShapeTree(const Shape * shape)112   explicit ShapeTree(const Shape* shape)
113       : ShapeTree(shape, CreateNodes(*shape)) {}
114 
115   // Create ShapeTree with the given shape, and init_value for all nodes.
ShapeTree(Shape shape,const T & init_value)116   ShapeTree(Shape shape, const T& init_value)
117       : ShapeTree(std::make_shared<Shape>(std::move(shape)), init_value) {}
118 
ShapeTree(const Shape * shape,const T & init_value)119   ShapeTree(const Shape* shape, const T& init_value)
120       : ShapeTree(shape, CreateNodes(*shape, init_value)) {}
121 
122   // Returns the data element associated with the array in the shape at the
123   // given index (see ShapeUtil::GetSubshape for how indexes are defined).
element(ShapeIndexView index)124   const T& element(ShapeIndexView index) const { return find(index)->second; }
mutable_element(ShapeIndexView index)125   T* mutable_element(ShapeIndexView index) { return &find(index)->second; }
126 
127   // Return the shape represented with this ShapeTree.
shape()128   const Shape& shape() const { return *shape_; }
129 
130   // A ShapeTree object can own the underlying Shape pointer (via the
131   // shape_storage_ member), or can point to a Shape object owned by the caller.
132   // This API replaces the underlying Shape object to the one supplied by the
133   // caller, whom must ensure the object remain valid for the whole lifetime of
134   // this ShapeTree object, and also that the Shape is consistent with it.
replace_shape_ptr(const Shape & shape)135   void replace_shape_ptr(const Shape& shape) {
136     if (shape_storage_ != nullptr) {
137       DCHECK_EQ(shape, *shape_storage_);
138       shape_storage_ = nullptr;
139     }
140     shape_ = &shape;
141   }
142 
143   // Returns true if the node at the given index is a leaf node (an array
144   // shape).
IsLeaf(ShapeIndexView index)145   bool IsLeaf(ShapeIndexView index) const {
146     return index_table_[index].children_start_id == -1;
147   }
148 
149   using iterator = typename Nodes::iterator;
150   using const_iterator = typename Nodes::const_iterator;
151   using reverse_iterator = typename Nodes::reverse_iterator;
152   using const_reverse_iterator = typename Nodes::const_reverse_iterator;
153 
154   using leaf_iterator = LeafIterator<iterator, Node>;
155   using const_leaf_iterator = LeafIterator<const_iterator, const Node>;
156   using reverse_leaf_iterator = std::reverse_iterator<leaf_iterator>;
157   using const_reverse_leaf_iterator =
158       std::reverse_iterator<const_leaf_iterator>;
159 
begin()160   iterator begin() { return nodes_.begin(); }
end()161   iterator end() { return nodes_.end(); }
begin()162   const_iterator begin() const { return nodes_.begin(); }
end()163   const_iterator end() const { return nodes_.end(); }
164 
rbegin()165   reverse_iterator rbegin() { return nodes_.rbegin(); }
rend()166   reverse_iterator rend() { return nodes_.rend(); }
rbegin()167   const_reverse_iterator rbegin() const { return nodes_.rbegin(); }
rend()168   const_reverse_iterator rend() const { return nodes_.rend(); }
169 
170   // leaf_begin()/leaf_end() iterates over all leaf nodes (nodes with no
171   // children).
leaf_begin()172   leaf_iterator leaf_begin() { return leaf_iterator(*this, nodes_.begin()); }
leaf_end()173   leaf_iterator leaf_end() { return leaf_iterator(*this, nodes_.end()); }
leaf_begin()174   const_leaf_iterator leaf_begin() const {
175     return const_leaf_iterator(*this, nodes_.begin());
176   }
leaf_end()177   const_leaf_iterator leaf_end() const {
178     return const_leaf_iterator(*this, nodes_.end());
179   }
180   // range-based iterator for leaf_begin()/leaf_end().
leaves()181   tensorflow::gtl::iterator_range<leaf_iterator> leaves() {
182     return tensorflow::gtl::make_range(leaf_begin(), leaf_end());
183   }
leaves()184   tensorflow::gtl::iterator_range<const_leaf_iterator> leaves() const {
185     return tensorflow::gtl::make_range(leaf_begin(), leaf_end());
186   }
187 
leaf_rbegin()188   reverse_leaf_iterator leaf_rbegin() {
189     return reverse_leaf_iterator(leaf_end());
190   }
leaf_rend()191   reverse_leaf_iterator leaf_rend() {
192     return reverse_leaf_iterator(leaf_begin());
193   }
leaf_rbegin()194   const_reverse_leaf_iterator leaf_rbegin() const {
195     return const_reverse_leaf_iterator(leaf_end());
196   }
leaf_rend()197   const_reverse_leaf_iterator leaf_rend() const {
198     return const_reverse_leaf_iterator(leaf_begin());
199   }
200 
201   // Returns an iterator pointing to the given ShapeIndex.
202   // REQUIRES: index must exist in the ShapeTree.
find(ShapeIndexView index)203   iterator find(ShapeIndexView index) {
204     return nodes_.begin() + index_table_[index].node_id;
205   }
find(ShapeIndexView index)206   const_iterator find(ShapeIndexView index) const {
207     return nodes_.begin() + index_table_[index].node_id;
208   }
209 
210   // Returns the number of leaf nodes in the tree.
leaf_count()211   int64_t leaf_count() const { return std::distance(leaf_begin(), leaf_end()); }
212 
213   // TODO(cjfj): Remove the `ForEach...` methods. They are redundant.
214   // Recursively traverses the shape and calls the given function at each
215   // element.
ForEachElement(absl::FunctionRef<void (const ShapeIndex &,const T &)> func)216   void ForEachElement(
217       absl::FunctionRef<void(const ShapeIndex&, const T&)> func) const {
218     for (const Node& node : nodes_) {
219       func(node.first, node.second);
220     }
221   }
222 
ForEachMutableElement(absl::FunctionRef<void (const ShapeIndex &,T *)> func)223   void ForEachMutableElement(
224       absl::FunctionRef<void(const ShapeIndex&, T*)> func) {
225     for (Node& node : nodes_) {
226       func(node.first, &node.second);
227     }
228   }
229 
230   // Like ForEach(Mutable)Element, but the callable returns a Status instead of
231   // void.  The first non-OK return value is returned by the ForEach* function.
ForEachElementWithStatus(absl::FunctionRef<Status (const ShapeIndex &,const T &)> func)232   Status ForEachElementWithStatus(
233       absl::FunctionRef<Status(const ShapeIndex&, const T&)> func) const {
234     for (const Node& node : nodes_) {
235       TF_RETURN_IF_ERROR(func(node.first, node.second));
236     }
237     return OkStatus();
238   }
239 
ForEachMutableElementWithStatus(absl::FunctionRef<Status (const ShapeIndex &,T *)> func)240   Status ForEachMutableElementWithStatus(
241       absl::FunctionRef<Status(const ShapeIndex&, T*)> func) {
242     for (Node& node : nodes_) {
243       TF_RETURN_IF_ERROR(func(node.first, &node.second));
244     }
245     return OkStatus();
246   }
247 
248   // Maps each element to generate a new tree with the same shape.
249   template <typename U>
Map(absl::FunctionRef<U (const T &)> func)250   ShapeTree<U> Map(absl::FunctionRef<U(const T&)> func) {
251     typename ShapeTree<U>::Nodes result_nodes;
252     result_nodes.reserve(nodes_.size());
253     for (const Node& node : nodes_) {
254       result_nodes.push_back({node.first, func(node.second)});
255     }
256 
257     ShapeTree<U> result(shape_, std::move(result_nodes));
258     result.index_table_ = index_table_;
259     result.shape_storage_ = shape_storage_;
260     return result;
261   }
262 
263   // Copy the subtree of values from 'other' rooted at ShapeIndex 'src_index'
264   // into the subtree of value in this ShapeTree rooted at 'dst_index'.
265   //
266   // Precondition: The subshape of other.shape() at index src_index must be
267   // compatible with the subshape of shape() at index dst_index.
CopySubtreeFrom(const ShapeTree<T> & other,const ShapeIndex & src_index,const ShapeIndex & dst_index)268   void CopySubtreeFrom(const ShapeTree<T>& other, const ShapeIndex& src_index,
269                        const ShapeIndex& dst_index) {
270     const Shape& src_shape = ShapeUtil::GetSubshape(other.shape(), src_index);
271     const Shape& dst_shape = ShapeUtil::GetSubshape(shape(), dst_index);
272     CHECK(ShapeUtil::Compatible(src_shape, dst_shape))
273         << src_shape << ", " << dst_shape;
274 
275     // Replace the prefix `src_index` with `dst_index`.
276     auto replace_shape_index_prefix = [&](const ShapeIndex& index) {
277       auto without_prefix = ShapeIndexView(index).subspan(src_index.size());
278       ShapeIndex result;
279       result.reserve(dst_index.size() + without_prefix.size());
280       result.insert(result.end(), dst_index.begin(), dst_index.end());
281       result.insert(result.end(), without_prefix.begin(), without_prefix.end());
282       return result;
283     };
284 
285     auto first = other.find(src_index);
286     auto last = first + ShapeUtil::SubshapeCount(src_shape);
287 
288     std::transform(first, last, find(dst_index), [&](const Node& node) -> Node {
289       return {replace_shape_index_prefix(node.first), node.second};
290     });
291   }
292 
SubShapeTree(const ShapeIndex & index)293   StatusOr<ShapeTree<T>> SubShapeTree(const ShapeIndex& index) const {
294     TF_ASSIGN_OR_RETURN(const Shape* sub_shape,
295                         ShapeUtil::TryGetSubshape(shape(), index));
296     size_t count = ShapeUtil::SubshapeCount(*sub_shape);
297     Nodes sub_tree_nodes;
298     sub_tree_nodes.reserve(count);
299     for (auto it = find(index), end = it + count; it != end; ++it) {
300       // For each shape index, remove the prefix `index`.
301       auto without_prefix = ShapeIndexView(it->first).subspan(index.size());
302       sub_tree_nodes.push_back(Node{without_prefix, it->second});
303     }
304     return ShapeTree(sub_shape, std::move(sub_tree_nodes));
305   }
306 
307   bool operator==(const ShapeTree<T>& other) const {
308     return nodes_ == other.nodes_;
309   }
310   bool operator!=(const ShapeTree<T>& other) const { return !(*this == other); }
311 
312  private:
ShapeTree(std::shared_ptr<Shape> shape)313   explicit ShapeTree(std::shared_ptr<Shape> shape) : ShapeTree(shape.get()) {
314     shape_storage_.swap(shape);
315   }
316 
ShapeTree(std::shared_ptr<Shape> shape,const T & init_value)317   ShapeTree(std::shared_ptr<Shape> shape, const T& init_value)
318       : ShapeTree(shape.get(), init_value) {
319     shape_storage_.swap(shape);
320   }
321 
ShapeTree(const Shape * shape,Nodes nodes)322   ShapeTree(const Shape* shape, Nodes nodes)
323       : nodes_(std::move(nodes)), index_table_(*shape), shape_(shape) {
324     DCHECK_EQ(nodes_.size(), ShapeUtil::SubshapeCount(*shape));
325   }
326 
327   template <typename... Ts>
CreateNodes(const Shape & shape,Ts &&...args)328   static Nodes CreateNodes(const Shape& shape, Ts&&... args) {
329     Nodes nodes;
330     ShapeUtil::ForEachSubshape(
331         shape, [&](const Shape&, const ShapeIndex& index) {
332           nodes.push_back({index, T(std::forward<Ts>(args)...)});
333         });
334     return nodes;
335   }
336 
337   // The nodes in this shape tree.
338   Nodes nodes_;
339 
340   // Index table for node lookups. Each entry contains the index of the first
341   // child of the node at that index, or -1 for leaf nodes. Evaluated lazily.
342   IndexTable index_table_;
343 
344   // If we own our Shape, this field contains it, and shape_ is a pointer into
345   // here.  Otherwise if we don't own our shape, this is nullptr.
346   std::shared_ptr<Shape> shape_storage_;
347 
348   // The XLA shape mirrored in this ShapeTree.  This is either
349   // shape_storage_.get() or the Shape pointer passed to our constructor.
350   const Shape* shape_;
351 };
352 
353 // Internal iterator that performs a pre-order walk of the leaves. This is cheap
354 // to copy. The iterator value_type is equivalent to a std::pair<ShapeIndex,T>&,
355 // similar to std::map.
356 template <typename T>
357 template <typename Iterator, typename ValueType>
358 class ShapeTree<T>::LeafIterator {
359  public:
360   using iterator_category = std::bidirectional_iterator_tag;
361   using value_type = ValueType;
362   using difference_type = ptrdiff_t;
363   using pointer = value_type*;
364   using reference = value_type&;
365 
LeafIterator(const ShapeTree & tree,Iterator it)366   LeafIterator(const ShapeTree& tree, Iterator it) : tree_(tree), it_(it) {
367     while ((it_ != tree_.nodes_.end()) && !IsLeaf()) ++it_;
368   }
369 
370   LeafIterator& operator++() {
371     do {
372       ++it_;
373     } while ((it_ != tree_.nodes_.end()) && !IsLeaf());
374     return *this;
375   }
376 
377   LeafIterator operator++(int) {
378     auto prev = *this;
379     ++(*this);
380     return prev;
381   }
382 
383   LeafIterator& operator--() {
384     do {
385       --it_;
386     } while ((it_ != tree_.nodes_.begin()) && !IsLeaf());
387     return *this;
388   }
389 
390   LeafIterator operator--(int) {
391     auto prev = *this;
392     --(*this);
393     return prev;
394   }
395 
396   bool operator==(const LeafIterator& other) const { return it_ == other.it_; }
397   bool operator!=(const LeafIterator& other) const { return !(*this == other); }
398   ValueType& operator*() const { return *it_; }
399   ValueType* operator->() const { return &*it_; }
400 
401  private:
IsLeaf()402   bool IsLeaf() const { return tree_.IsLeaf(it_->first); }
403 
404   const ShapeTree<T>& tree_;
405   Iterator it_;
406 };
407 
408 }  // namespace xla
409 
410 #endif  // TENSORFLOW_COMPILER_XLA_SHAPE_TREE_H_
411