xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/xla/service/buffer_value.h (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_BUFFER_VALUE_H_
17 #define TENSORFLOW_COMPILER_XLA_SERVICE_BUFFER_VALUE_H_
18 
19 #include <functional>
20 #include <string>
21 
22 #include "absl/types/span.h"
23 #include "tensorflow/compiler/xla/service/hlo.pb.h"
24 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
25 #include "tensorflow/compiler/xla/shape_util.h"
26 #include "tensorflow/compiler/xla/types.h"
27 #include "tensorflow/compiler/xla/xla_data.pb.h"
28 #include "tensorflow/core/platform/logging.h"
29 
30 namespace xla {
31 
32 // Abstract class describing a value used by one of the dataflow analyses -
33 // TuplePointsToAnalysis or HloDataflowAnalysis.
34 // TODO(b/78906445) Delete this class when TuplePointsToAnalysis is unused.
35 //
36 // XLA arrays are trivially a single BufferValue. Tuples are made up of more
37 // than one BufferValue: a BufferValue for the pointer vector, and a
38 // BufferValue for each child element.
39 //
40 // Every BufferValue is defined by a particular instruction and most
41 // instructions define only a single BufferValue. Instructions which define a
42 // single BufferValue include array-shaped instructions such as Add but also
43 // include tuple-shaped instructions such as Tuple. The Tuple instruction
44 // defines a single BufferValue which is a vector of pointers to the values
45 // containing the Tuple instruction's operands. Though the result of the Tuple
46 // instruction includes multiple values only the top-level BufferValue (the
47 // vector of pointers) is defined by the Tuple instruction. The values
48 // containing the tuple elements are defined by earlier instructions, usually
49 // the operands of the Tuple instruction.
50 //
51 // Instructions which construct both the tuple *and* the tuple elements define
52 // more than one BufferValue. This includes (at least) tuple-shaped Constant,
53 // Parameter, Infeed and While instructions. These tuple-shaped instructions do
54 // not assemble a tuple from existing BufferValues like the Tuple instruction
55 // does, but rather define all the BufferValues in the tuple.
56 //
57 // Some instructions, such as Bitcast, define no buffers. These instructions
58 // simply forward buffers from their operands.
59 //
60 // The BufferValue object describes which HLO instruction defines a buffer and
61 // where within that instruction's output shape the buffer is defined. The
62 // location within the output shape is indicated by BufferValue::index() which
63 // is defined identically to the index used in ShapeUtil::GetSubshape().
64 // Examples:
65 //
66 // %add = Add(%foo, %bar)
67 // %tuple_constant = Constant({1, {42, 43}})
68 //
69 // %add defines a single array-shaped buffer BufferValue(%add, {}) which holds
70 // the array result of the add operation. The nested-tuple-shaped
71 // %tuple_constant defines 5 buffers described by the following BufferValue
72 // objects:
73 //
74 //   BufferValue(%tuple_constant, {})      // "Top-level" buffer: vector of
75 //                                         //  pointers to BufferValues at
76 //                                         //  indices {0} and {1}
77 //   BufferValue(%tuple_constant, {0})     // Holds value "1"
78 //   BufferValue(%tuple_constant, {1})     // Holds nested tuple: vector of
79 //                                         //  pointers to BufferValues at
80 //                                         //  indices {1, 0} and {1, 1}
81 //   BufferValue(%tuple_constant, {1, 0})  // Holds value "42"
82 //   BufferValue(%tuple_constant, {1, 1})  // Holds value "43"
83 
84 class BufferValue {
85  public:
86   using Color = int64_t;
87 
88   // Id is a unique identifier for the BufferValue to facilitate efficient
89   // collections of BufferValues with stable iteration order.
90   using Id = int64_t;
91 
92   // Functions which return the size and alignment of a logical buffer in bytes.
93   using SizeFunction = std::function<int64_t(const BufferValue&)>;
94   using AlignmentFunction = std::function<int64_t(BufferValue::Color)>;
95 
96   // Prevent value being copied, allowing comparison by pointer,
97   BufferValue(const BufferValue&) = delete;
98   BufferValue& operator=(const BufferValue&) = delete;
99   // ... but allow moves.
100   BufferValue(BufferValue&&) = default;
101   BufferValue& operator=(BufferValue&&) = default;
102 
~BufferValue()103   virtual ~BufferValue() {}
104 
id()105   Id id() const { return id_; }
106 
107   // Return the instruction that defines the buffer.
108   virtual HloInstruction* instruction() const = 0;
109 
110   // Return the index within the output of the instruction where the buffer is
111   // defined. Index used defined as in ShapeUtil::GetSubshape()
112   virtual const ShapeIndex& index() const = 0;
113 
114   // Return the color of the BufferValue. Differently colored buffers can not be
115   // parts of the same allocation.
116   ABSL_DEPRECATED("Use Layout::memory_space instead.")
color()117   Color color() const {
118     CHECK_NE(color_, kInvalidColor)
119         << "Should not query the color of a buffer that was never colored";
120     return color_;
121   }
122 
123   ABSL_DEPRECATED("Use Layout::memory_space instead.")
set_color(Color color)124   void set_color(Color color) {
125     CHECK_NE(color, kInvalidColor)
126         << "Should not set the color of a buffer to the invalid color";
127     color_ = color;
128   }
129 
130   ABSL_DEPRECATED("Use Layout::memory_space instead.")
has_color()131   bool has_color() const { return color_ != kInvalidColor; }
132 
133   // Return the shape of the buffer. This reference points into the shape field
134   // of the instruction defining the buffer.  Therefore, the returned shape will
135   // contain the layout of instruction, if any.
136   virtual const Shape& shape() const = 0;
137 
138   // Returns true if this buffer is the top-level output buffer of the defining
139   // HLO instruction. This is equivalent to index == {}.
IsTopLevel()140   bool IsTopLevel() const { return index().empty(); }
141 
142   // Whether this buffer contains a tuple.
IsTuple()143   bool IsTuple() const { return is_tuple_; }
144 
145   // Whether this buffer contains an array.
IsArray()146   bool IsArray() const { return is_array_; }
147 
148   bool operator<(const BufferValue& other) const { return id_ < other.id_; }
149 
150   virtual std::string ToString() const = 0;
151 
152   // TODO(lauj) rename LogicalBufferProto to BufferValueProto.
153   LogicalBufferProto ToProto(const SizeFunction& size_fn) const;
154 
155   // Returns the LogicalBufferProto::Location that serializes the given
156   // instruction and index.
157   static LogicalBufferProto::Location ToLocationProto(
158       const HloInstruction& instruction, const ShapeIndex& index);
159 
160   const Color kInvalidColor = -1;
161 
162  protected:
163   BufferValue(HloInstruction* instruction, const ShapeIndex& index, Id id);
164 
165  private:
166   // The defining instruction and index are not stored here; they can be found
167   // in the LogicalBuffer and HloValue subclasses. This class exists only to
168   // support migrations from TuplePointsToAnalysis to HloDataflowAnalysis, by
169   // allowing abstract use of LogicalBuffer or HloValue. After those migrations
170   // are complete, this class should be deleted (b/78906445). Because we plan to
171   // delete LogicalBuffer and this class, we don't refactor all the shared
172   // features from LogicalBuffer and HloValue into this class.
173   Id id_ : 62;
174   bool is_array_ : 1;
175   bool is_tuple_ : 1;
176   Color color_ = kInvalidColor;
177 };
178 
// Stream-insertion operator for BufferValue. Only the declaration is visible
// in this header; NOTE(review): it presumably writes buffer.ToString() to
// `out` and returns `out` -- confirm against the definition.
179 std::ostream& operator<<(std::ostream& out, const BufferValue& buffer);
180 
181 }  // namespace xla
182 
183 #endif  // TENSORFLOW_COMPILER_XLA_SERVICE_BUFFER_VALUE_H_
184