1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_BUFFER_VALUE_H_ 17 #define TENSORFLOW_COMPILER_XLA_SERVICE_BUFFER_VALUE_H_ 18 19 #include <functional> 20 #include <string> 21 22 #include "absl/types/span.h" 23 #include "tensorflow/compiler/xla/service/hlo.pb.h" 24 #include "tensorflow/compiler/xla/service/hlo_instruction.h" 25 #include "tensorflow/compiler/xla/shape_util.h" 26 #include "tensorflow/compiler/xla/types.h" 27 #include "tensorflow/compiler/xla/xla_data.pb.h" 28 #include "tensorflow/core/platform/logging.h" 29 30 namespace xla { 31 32 // Abstract class describing a value used by one of the dataflow analyses - 33 // TuplePointsToAnalysis or HloDataflowAnalysis. 34 // TODO(b/78906445) Delete this class when TuplePointsToAnalysis is unused. 35 // 36 // XLA arrays are trivially a single BufferValue. Tuples are made up of more 37 // than one BufferValue: an BufferValue for the pointer vector, and an 38 // BufferValue for each child element. 39 // 40 // Every BufferValue is defined by a particular instruction and most 41 // instructions define only a single BufferValue. Instructions which define a 42 // single BufferValue include array-shaped instructions such as Add but also 43 // includes Tuple-shaped instructions such as Tuple. The Tuple instruction 44 // defines a single BufferValue which is a vector of pointers to the values 45 // containing the Tuple instruction's operands. Though the result of the Tuple 46 // instruction includes multiple values only the top-level BufferValue (the 47 // vector of pointers) is defined by the Tuple instruction. The values 48 // containing the tuple elements are defined by earlier instructions, usually 49 // the operands of the Tuple instruction. 50 // 51 // Instructions which construct both the tuple *and* the tuple elements define 52 // more than one BufferValue. This includes (at least) tuple-shaped Constant, 53 // Parameter, Infeed and While instructions. These tuple-shaped instructions do 54 // not assemble a tuple from existing BufferValues like the Tuple instruction 55 // does, but rather define all the BufferValues in the tuple. 56 // 57 // Some instructions, such as Bitcast, define no buffers. These instructions 58 // simply forward buffers from their operands. 59 // 60 // The BufferValue object describes which HLO instruction defines a buffer and 61 // where within that instruction's output shape the buffer is defined. The 62 // location within the output shape is indicated by BufferValue::index() which 63 // is defined identically to the index used in ShapeUtil::GetSubshape(). 64 // Examples: 65 // 66 // %add = Add(%foo, %bar) 67 // %tuple_constant = Constant({1, {42, 43}}) 68 // 69 // %add defines a single array-shaped buffer BufferValue(%add, {}) which holds 70 // the array result of the add operation. The nested-tuple-shaped 71 // %tuple_constant defines 5 buffers described by the following BufferValue 72 // objects: 73 // 74 // BufferValue(%tuple_constant, {}) // "Top-level" buffer: vector of 75 // // pointers to BufferValues at 76 // // indices {0} and {1} 77 // BufferValue(%tuple_constant, {0}) // Holds value "1" 78 // BufferValue(%tuple_constant, {1}) // Holds nested tuple: vector of 79 // // pointers to BufferValues at 80 // // indices {1, 0} and {1, 1} 81 // BufferValue(%tuple_constant, {1, 0}) // Holds value "42" 82 // BufferValue(%tuple_constant, {1, 1}) // Holds value "43" 83 84 class BufferValue { 85 public: 86 using Color = int64_t; 87 88 // Id is a unique identifier for the BufferValue to facilitate efficient 89 // collections of BufferValues with stable iteration order. 90 using Id = int64_t; 91 92 // Functions which return the size and alignment of a logical buffer in bytes. 93 using SizeFunction = std::function<int64_t(const BufferValue&)>; 94 using AlignmentFunction = std::function<int64_t(BufferValue::Color)>; 95 96 // Prevent value being copied, allowing comparison by pointer, 97 BufferValue(const BufferValue&) = delete; 98 BufferValue& operator=(const BufferValue&) = delete; 99 // ... but allow moves. 100 BufferValue(BufferValue&&) = default; 101 BufferValue& operator=(BufferValue&&) = default; 102 ~BufferValue()103 virtual ~BufferValue() {} 104 id()105 Id id() const { return id_; } 106 107 // Return the instruction that defines the buffer. 108 virtual HloInstruction* instruction() const = 0; 109 110 // Return the index within the output of the instruction where the buffer is 111 // defined. Index used defined as in ShapeUtil::GetSubshape() 112 virtual const ShapeIndex& index() const = 0; 113 114 // Return the color of the BufferValue. Differently colored buffers can not be 115 // parts of the same allocation. 116 ABSL_DEPRECATED("Use Layout::memory_space instead.") color()117 Color color() const { 118 CHECK_NE(color_, kInvalidColor) 119 << "Should not query the color of a buffer that was never colored"; 120 return color_; 121 } 122 123 ABSL_DEPRECATED("Use Layout::memory_space instead.") set_color(Color color)124 void set_color(Color color) { 125 CHECK_NE(color, kInvalidColor) 126 << "Should not set the color of a buffer to the invalid color"; 127 color_ = color; 128 } 129 130 ABSL_DEPRECATED("Use Layout::memory_space instead.") has_color()131 bool has_color() const { return color_ != kInvalidColor; } 132 133 // Return the shape of the buffer. This reference points into the shape field 134 // of the instruction defining the buffer. Therefore, the returned shape will 135 // contain the layout of instruction, if any. 136 virtual const Shape& shape() const = 0; 137 138 // Returns true if this buffer is the top-level output buffer of the defining 139 // HLO instruction. This is equivalent to index == {}. IsTopLevel()140 bool IsTopLevel() const { return index().empty(); } 141 142 // Whether this buffer contains a tuple. IsTuple()143 bool IsTuple() const { return is_tuple_; } 144 145 // Whether this buffer contains an array. IsArray()146 bool IsArray() const { return is_array_; } 147 148 bool operator<(const BufferValue& other) const { return id_ < other.id_; } 149 150 virtual std::string ToString() const = 0; 151 152 // TODO(lauj) rename LogicalBufferProto to BufferValueProto. 153 LogicalBufferProto ToProto(const SizeFunction& size_fn) const; 154 155 // Returns the LogicalBufferProto::Location that serializes the given 156 // instruction and index. 157 static LogicalBufferProto::Location ToLocationProto( 158 const HloInstruction& instruction, const ShapeIndex& index); 159 160 const Color kInvalidColor = -1; 161 162 protected: 163 BufferValue(HloInstruction* instruction, const ShapeIndex& index, Id id); 164 165 private: 166 // The defining instruction and index are not stored here; they can be found 167 // in the LogicalBuffer and HloValue subclasses. This class exists only to 168 // support migrations from TuplePointsToAnalysis to HloDataflowAnalysis, by 169 // allowing abstract use of LogicalBuffer or HloValue. After those migrations 170 // are complete, this class should be deleted (b/78906445). Because we plan to 171 // delete LogicalBuffer and this class, we don't refactor all the shared 172 // features from LogicalBuffer and HloValue into this class. 173 Id id_ : 62; 174 bool is_array_ : 1; 175 bool is_tuple_ : 1; 176 Color color_ = kInvalidColor; 177 }; 178 179 std::ostream& operator<<(std::ostream& out, const BufferValue& buffer); 180 181 } // namespace xla 182 183 #endif // TENSORFLOW_COMPILER_XLA_SERVICE_BUFFER_VALUE_H_ 184