1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_COMPILER_XLA_SHAPE_H_ 17 #define TENSORFLOW_COMPILER_XLA_SHAPE_H_ 18 19 #include <optional> 20 #include <ostream> 21 #include <string> 22 #include <utility> 23 #include <vector> 24 25 #include "absl/container/inlined_vector.h" 26 #include "tensorflow/compiler/xla/layout.h" 27 #include "tensorflow/compiler/xla/primitive_util.h" 28 #include "tensorflow/compiler/xla/xla_data.pb.h" 29 30 namespace xla { 31 32 // A shape describes the number of dimensions in a array, the bounds of each 33 // dimension, and the primitive component type. For tuples, shape describes the 34 // structure (number of elements and nesting). 35 class Shape { 36 public: 37 Shape() = default; 38 39 // Construct a shape from a ShapeProto. 40 explicit Shape(const ShapeProto& shape_proto); 41 Shape(PrimitiveType element_type,absl::Span<const int64_t> dimensions,absl::Span<const bool> dynamic_dimensions,std::vector<Shape> tuple_shapes)42 Shape(PrimitiveType element_type, absl::Span<const int64_t> dimensions, 43 absl::Span<const bool> dynamic_dimensions, 44 std::vector<Shape> tuple_shapes) 45 : element_type_(element_type), 46 dimensions_(dimensions.begin(), dimensions.end()), 47 dynamic_dimensions_(dynamic_dimensions.begin(), 48 dynamic_dimensions.end()), 49 tuple_shapes_(std::move(tuple_shapes)) {} 50 51 // Returns a ShapeProto representation of the Shape. 52 ShapeProto ToProto() const; 53 54 // Returns a human-readable string that represents the given shape, with or 55 // without layout. e.g. "F32[42,12] {0, 1}" or "F32[64]". 56 std::string ToString(bool print_layout = false) const; 57 58 // Returns the rank (number of dimensions) of the given shape. Shape must be 59 // an array. rank()60 int64_t rank() const { 61 DCHECK(IsArray()) << "Non-arrays do not have a rank, shape: " << ToString(); 62 return dimensions_.size(); 63 } 64 65 // Returns whether the shape is of the specified type (array, tuple, etc). IsArray()66 bool IsArray() const { return primitive_util::IsArrayType(element_type()); } IsTuple()67 bool IsTuple() const { return element_type() == TUPLE; } IsToken()68 bool IsToken() const { return element_type() == TOKEN; } IsOpaque()69 bool IsOpaque() const { return element_type() == OPAQUE_TYPE; } 70 71 // Returns whether all elements in the shape are integer. 72 // A nested tuple of integers is considered as integer. 73 bool IsInteger() const; 74 75 // Returns true if no array dimension in the shape is dynamically sized. Tuple 76 // shapes are traversed recursively. 77 bool is_static() const; 78 is_dynamic()79 bool is_dynamic() const { return !is_static(); } 80 81 // Returns true if the given dimension is dynamically-sized. is_dynamic_dimension(int dimension)82 bool is_dynamic_dimension(int dimension) const { 83 return dynamic_dimensions_.at(dimension); 84 } 85 86 // Sets whether or not the given dimension is dynamically-sized. set_dynamic_dimension(int dimension,bool is_dynamic)87 void set_dynamic_dimension(int dimension, bool is_dynamic) { 88 dynamic_dimensions_[dimension] = is_dynamic; 89 } 90 dynamic_dimensions()91 absl::Span<const bool> dynamic_dimensions() const { 92 return dynamic_dimensions_; 93 } 94 mutable_dynamic_dimensions()95 absl::Span<bool> mutable_dynamic_dimensions() { 96 return absl::MakeSpan(dynamic_dimensions_); 97 } 98 99 // Add dimension_upper_bound(). 100 101 // Removes the given dimension form the shape. Layout, if it exists, is 102 // adjusted to match the modified shape. 103 void DeleteDimension(int64_t dim_to_delete); 104 105 // The following methods mirror the protobuf generated code interface for the 106 // message ShapeProto. This enabled easy migration of this data structure 107 // from a proto to a proper C++ class. 108 // TODO(b/29771030): Replace or augment these methods with a more ergonomic 109 // interface. 110 111 // Methods for accessing the primitive type. element_type()112 PrimitiveType element_type() const { return element_type_; } set_element_type(PrimitiveType value)113 void set_element_type(PrimitiveType value) { element_type_ = value; } 114 115 // Methods for accessing the dimensions array. dimensions_size()116 int dimensions_size() const { return dimensions_.size(); } dimensions(int index)117 int64_t dimensions(int index) const { return dimensions_.at(index); } dimensions_minor(int index)118 int64_t dimensions_minor(int index) const { 119 CHECK(has_layout()); 120 return dimensions_.at(layout_->minor_to_major(index)); 121 } set_dimensions(int index,int64_t value)122 void set_dimensions(int index, int64_t value) { 123 dimensions_.at(index) = value; 124 } set_dimensions_minor(int index,int64_t value)125 void set_dimensions_minor(int index, int64_t value) { 126 CHECK(has_layout()); 127 dimensions_.at(layout_->minor_to_major(index)) = value; 128 } add_dimensions(int64_t value)129 void add_dimensions(int64_t value) { 130 dimensions_.push_back(value); 131 dynamic_dimensions_.push_back(false); 132 } clear_dimensions()133 void clear_dimensions() { 134 dimensions_.clear(); 135 dynamic_dimensions_.clear(); 136 } dimensions()137 absl::Span<const int64_t> dimensions() const { return dimensions_; } mutable_dimensions()138 absl::Span<int64_t> mutable_dimensions() { 139 return absl::MakeSpan(dimensions_); 140 } 141 142 // Methods for accessing the tuple subshapes. This field only non-empty for 143 // tuple shapes. tuple_shapes_size()144 int tuple_shapes_size() const { return tuple_shapes_.size(); } tuple_shapes(int index)145 const Shape& tuple_shapes(int index) const { return tuple_shapes_.at(index); } mutable_tuple_shapes(int index)146 Shape* mutable_tuple_shapes(int index) { return &tuple_shapes_.at(index); } add_tuple_shapes()147 Shape* add_tuple_shapes() { 148 tuple_shapes_.push_back(Shape()); 149 return &tuple_shapes_.back(); 150 } clear_tuple_shapes()151 void clear_tuple_shapes() { tuple_shapes_.clear(); } tuple_shapes()152 const std::vector<Shape>& tuple_shapes() const { return tuple_shapes_; } mutable_tuple_shapes()153 std::vector<Shape>* mutable_tuple_shapes() { return &tuple_shapes_; } 154 155 // Methods for accessing the layout field. has_layout()156 bool has_layout() const { return layout_ != std::nullopt; } layout()157 const Layout& layout() const { 158 CHECK(has_layout()) << ShortDebugString(); 159 return *layout_; 160 } mutable_layout()161 Layout* mutable_layout() { 162 CHECK(IsArray()) << ShortDebugString(); 163 if (layout_ == std::nullopt) { 164 layout_.emplace(); 165 } 166 return &(*layout_); 167 } clear_layout()168 void clear_layout() { layout_ = std::nullopt; } 169 170 // Recursively clear dynamic dimension of a shape. clear_dynamic_dimensions()171 void clear_dynamic_dimensions() { 172 if (!IsTuple()) { 173 for (int64_t i = 0; i < dynamic_dimensions_.size(); ++i) { 174 dynamic_dimensions_[i] = false; 175 } 176 return; 177 } 178 for (auto& subshape : tuple_shapes_) { 179 subshape.clear_dynamic_dimensions(); 180 } 181 } 182 Swap(Shape * other)183 void Swap(Shape* other) { 184 using std::swap; 185 swap(*this, *other); 186 } 187 Clear()188 void Clear() { 189 element_type_ = PRIMITIVE_TYPE_INVALID; 190 clear_dimensions(); 191 tuple_shapes_.clear(); 192 clear_layout(); 193 } 194 SerializeAsString()195 std::string SerializeAsString() const { 196 return ToProto().SerializeAsString(); 197 } ShortDebugString()198 std::string ShortDebugString() const { return ToProto().ShortDebugString(); } DebugString()199 std::string DebugString() const { return ToProto().DebugString(); } 200 201 // Equal is a configurable functor to check the equality of two shapes. 202 // 203 // Examples: 204 // 205 // - Comparing two shapes ignoring their layout difference: 206 // Equal().IgnoreLayout()(shape1, shape2); 207 // 208 // - Comparing two shapes ignoring their layout and element type difference: 209 // Equal().IgnoreLayout().IgnoreElementType()(shape1, shape2); 210 class Equal { 211 public: 212 Equal() = default; 213 214 bool operator()(const Shape& lhs, const Shape& rhs); 215 IgnoreLayout()216 Equal& IgnoreLayout() { 217 ignore_layout_ = true; 218 return *this; 219 } IgnoreTilesInLayout()220 Equal& IgnoreTilesInLayout() { 221 ignore_tiles_in_layout_ = true; 222 return *this; 223 } IgnoreElementSizeInLayout()224 Equal& IgnoreElementSizeInLayout() { 225 ignore_element_size_in_layout_ = true; 226 return *this; 227 } IgnoreMemorySpaceInLayout()228 Equal& IgnoreMemorySpaceInLayout() { 229 ignore_memory_space_in_layout_ = true; 230 return *this; 231 } MinorToMajorOnlyInLayout()232 Equal& MinorToMajorOnlyInLayout() { 233 ignore_tiles_in_layout_ = true; 234 ignore_element_size_in_layout_ = true; 235 ignore_memory_space_in_layout_ = true; 236 return *this; 237 } IgnoreElementType()238 Equal& IgnoreElementType() { 239 ignore_element_type_ = true; 240 return *this; 241 } IgnoreFpPrecision()242 Equal& IgnoreFpPrecision() { 243 ignore_fp_precision_ = true; 244 return *this; 245 } IgnoreDynamicDimension()246 Equal& IgnoreDynamicDimension() { 247 ignore_dynamic_dimension_ = true; 248 return *this; 249 } IgnoreDimensions()250 Equal& IgnoreDimensions() { 251 ignore_dimensions_ = true; 252 return *this; 253 } 254 255 private: 256 bool ignore_layout_ = false; 257 bool ignore_tiles_in_layout_ = false; 258 bool ignore_element_size_in_layout_ = false; 259 bool ignore_memory_space_in_layout_ = false; 260 bool ignore_element_type_ = false; 261 bool ignore_fp_precision_ = false; 262 bool ignore_dynamic_dimension_ = false; 263 bool ignore_dimensions_ = false; 264 }; 265 266 // Test that all fields of the shape are the same, equivalent to Equal(). 267 bool operator==(const Shape& other) const { return Equal()(*this, other); } 268 bool operator!=(const Shape& other) const { return !(*this == other); } 269 270 template <typename H, bool kIsLayoutSensitive = true> Hash(H h,const Shape & s)271 static H Hash(H h, const Shape& s) { 272 if (s.IsTuple()) { 273 for (const Shape& subshape : s.tuple_shapes_) { 274 h = Shape::Hash<H, kIsLayoutSensitive>(std::move(h), subshape); 275 } 276 return H::combine(std::move(h), s.tuple_shapes_size()); 277 } 278 h = H::combine(std::move(h), s.element_type_, s.dimensions_, 279 s.dynamic_dimensions_); 280 if (kIsLayoutSensitive) { 281 h = H::combine(std::move(h), s.layout_); 282 } 283 return std::move(h); 284 } 285 286 template <typename H> AbslHashValue(H h,const Shape & s)287 friend H AbslHashValue(H h, const Shape& s) { 288 return Shape::Hash(std::move(h), s); 289 } 290 291 private: 292 // The element type of this shape (tuple, array, etc). 293 PrimitiveType element_type_ = PRIMITIVE_TYPE_INVALID; 294 295 // The array bounds of the dimensions. This is nonempty only for array 296 // shapes. For a dynamically-sized dimension, the respective value in this 297 // vector is an inclusive upper limit of the array bound. 298 DimensionVector dimensions_; 299 300 // This vector is the same size as 'dimensions_' and indicates whether the 301 // respective dimension is dynamically sized. 302 absl::InlinedVector<bool, InlineRank()> dynamic_dimensions_; 303 304 // The tuple element subshapes. This is nonempty only for tuple shapes. 305 std::vector<Shape> tuple_shapes_; 306 307 // The layout of the shape. Only relevant for arrays. 308 std::optional<Layout> layout_; 309 }; 310 311 // Shape of the parameters and output of an XLA computation. This is analogous 312 // to a traditional function signature. 313 class ProgramShape { 314 public: 315 ProgramShape() = default; 316 317 // Creates a ProgramShape from a ProgramShapeProto protobuf. 318 explicit ProgramShape(const ProgramShapeProto& program_shape_proto); 319 320 // Returns a proto representation of the object. 321 ProgramShapeProto ToProto() const; 322 323 std::string ToString() const; 324 325 // The following methods mirror the protobuf generated code interface for the 326 // message ProgramShapeProto. This enabled easy migration of this data 327 // structure from a proto to a proper C++ class. 328 // TODO(b/29771030): Replace or augment these methods with a more ergonomic 329 // interface. 330 331 // Methods for accessing and manipulating the Shape of the parameters. parameters_size()332 int parameters_size() const { return parameters_.size(); } parameters(int index)333 const Shape& parameters(int index) const { return parameters_.at(index); } mutable_parameters(int index)334 Shape* mutable_parameters(int index) { return ¶meters_.at(index); } add_parameters()335 Shape* add_parameters() { 336 parameters_.emplace_back(); 337 return ¶meters_.back(); 338 } clear_parameters()339 void clear_parameters() { parameters_.clear(); } parameters()340 const std::vector<Shape>& parameters() const { return parameters_; } mutable_parameters()341 std::vector<Shape>* mutable_parameters() { return ¶meters_; } 342 343 // Methods for accessing and manipulating the Shape of the result. result()344 const Shape& result() const { return result_; } mutable_result()345 Shape* mutable_result() { return &result_; } 346 347 // Methods for accessing and manipulating the names of the parameters. parameter_names_size()348 int parameter_names_size() const { return parameter_names_.size(); } parameter_names(int index)349 const std::string& parameter_names(int index) const { 350 return parameter_names_.at(index); 351 } set_parameter_names(int index,const std::string & value)352 void set_parameter_names(int index, const std::string& value) { 353 parameter_names_.at(index) = value; 354 } mutable_parameter_names(int index)355 std::string* mutable_parameter_names(int index) { 356 return ¶meter_names_.at(index); 357 } add_parameter_names(const std::string & value)358 void add_parameter_names(const std::string& value) { 359 parameter_names_.push_back(value); 360 } add_parameter_names()361 std::string* add_parameter_names() { 362 parameter_names_.push_back(""); 363 return ¶meter_names_.back(); 364 } clear_parameter_names()365 void clear_parameter_names() { parameter_names_.clear(); } parameter_names()366 const std::vector<std::string>& parameter_names() const { 367 return parameter_names_; 368 } mutable_parameter_names()369 std::vector<std::string>* mutable_parameter_names() { 370 return ¶meter_names_; 371 } 372 ShortDebugString()373 std::string ShortDebugString() const { return ToProto().ShortDebugString(); } DebugString()374 std::string DebugString() const { return ToProto().DebugString(); } 375 376 private: 377 // The shapes of the parameters of the computation represented by this object. 378 std::vector<Shape> parameters_; 379 380 // The names of the parameters of the computation represented by this object. 381 std::vector<std::string> parameter_names_; 382 383 // The shape of the result of the computation represented by this object. 384 Shape result_; 385 }; 386 387 std::ostream& operator<<(std::ostream& out, const Shape& shape); 388 std::ostream& operator<<(std::ostream& out, const ProgramShape& program_shape); 389 390 } // namespace xla 391 392 #endif // TENSORFLOW_COMPILER_XLA_SHAPE_H_ 393