xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/xla/shape.h (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_COMPILER_XLA_SHAPE_H_
17 #define TENSORFLOW_COMPILER_XLA_SHAPE_H_
18 
19 #include <optional>
20 #include <ostream>
21 #include <string>
22 #include <utility>
23 #include <vector>
24 
25 #include "absl/container/inlined_vector.h"
26 #include "tensorflow/compiler/xla/layout.h"
27 #include "tensorflow/compiler/xla/primitive_util.h"
28 #include "tensorflow/compiler/xla/xla_data.pb.h"
29 
30 namespace xla {
31 
32 // A shape describes the number of dimensions in a array, the bounds of each
33 // dimension, and the primitive component type. For tuples, shape describes the
34 // structure (number of elements and nesting).
35 class Shape {
36  public:
37   Shape() = default;
38 
39   // Construct a shape from a ShapeProto.
40   explicit Shape(const ShapeProto& shape_proto);
41 
Shape(PrimitiveType element_type,absl::Span<const int64_t> dimensions,absl::Span<const bool> dynamic_dimensions,std::vector<Shape> tuple_shapes)42   Shape(PrimitiveType element_type, absl::Span<const int64_t> dimensions,
43         absl::Span<const bool> dynamic_dimensions,
44         std::vector<Shape> tuple_shapes)
45       : element_type_(element_type),
46         dimensions_(dimensions.begin(), dimensions.end()),
47         dynamic_dimensions_(dynamic_dimensions.begin(),
48                             dynamic_dimensions.end()),
49         tuple_shapes_(std::move(tuple_shapes)) {}
50 
51   // Returns a ShapeProto representation of the Shape.
52   ShapeProto ToProto() const;
53 
54   // Returns a human-readable string that represents the given shape, with or
55   // without layout. e.g. "F32[42,12] {0, 1}" or "F32[64]".
56   std::string ToString(bool print_layout = false) const;
57 
58   // Returns the rank (number of dimensions) of the given shape. Shape must be
59   // an array.
rank()60   int64_t rank() const {
61     DCHECK(IsArray()) << "Non-arrays do not have a rank, shape: " << ToString();
62     return dimensions_.size();
63   }
64 
65   // Returns whether the shape is of the specified type (array, tuple, etc).
IsArray()66   bool IsArray() const { return primitive_util::IsArrayType(element_type()); }
IsTuple()67   bool IsTuple() const { return element_type() == TUPLE; }
IsToken()68   bool IsToken() const { return element_type() == TOKEN; }
IsOpaque()69   bool IsOpaque() const { return element_type() == OPAQUE_TYPE; }
70 
71   // Returns whether all elements in the shape are integer.
72   // A nested tuple of integers is considered as integer.
73   bool IsInteger() const;
74 
75   // Returns true if no array dimension in the shape is dynamically sized. Tuple
76   // shapes are traversed recursively.
77   bool is_static() const;
78 
is_dynamic()79   bool is_dynamic() const { return !is_static(); }
80 
81   // Returns true if the given dimension is dynamically-sized.
is_dynamic_dimension(int dimension)82   bool is_dynamic_dimension(int dimension) const {
83     return dynamic_dimensions_.at(dimension);
84   }
85 
86   // Sets whether or not the given dimension is dynamically-sized.
set_dynamic_dimension(int dimension,bool is_dynamic)87   void set_dynamic_dimension(int dimension, bool is_dynamic) {
88     dynamic_dimensions_[dimension] = is_dynamic;
89   }
90 
dynamic_dimensions()91   absl::Span<const bool> dynamic_dimensions() const {
92     return dynamic_dimensions_;
93   }
94 
mutable_dynamic_dimensions()95   absl::Span<bool> mutable_dynamic_dimensions() {
96     return absl::MakeSpan(dynamic_dimensions_);
97   }
98 
99   // Add dimension_upper_bound().
100 
101   // Removes the given dimension form the shape. Layout, if it exists, is
102   // adjusted to match the modified shape.
103   void DeleteDimension(int64_t dim_to_delete);
104 
105   // The following methods mirror the protobuf generated code interface for the
106   // message ShapeProto. This enabled easy migration of this data structure
107   // from a proto to a proper C++ class.
108   // TODO(b/29771030): Replace or augment these methods with a more ergonomic
109   // interface.
110 
111   // Methods for accessing the primitive type.
element_type()112   PrimitiveType element_type() const { return element_type_; }
set_element_type(PrimitiveType value)113   void set_element_type(PrimitiveType value) { element_type_ = value; }
114 
115   // Methods for accessing the dimensions array.
dimensions_size()116   int dimensions_size() const { return dimensions_.size(); }
dimensions(int index)117   int64_t dimensions(int index) const { return dimensions_.at(index); }
dimensions_minor(int index)118   int64_t dimensions_minor(int index) const {
119     CHECK(has_layout());
120     return dimensions_.at(layout_->minor_to_major(index));
121   }
set_dimensions(int index,int64_t value)122   void set_dimensions(int index, int64_t value) {
123     dimensions_.at(index) = value;
124   }
set_dimensions_minor(int index,int64_t value)125   void set_dimensions_minor(int index, int64_t value) {
126     CHECK(has_layout());
127     dimensions_.at(layout_->minor_to_major(index)) = value;
128   }
add_dimensions(int64_t value)129   void add_dimensions(int64_t value) {
130     dimensions_.push_back(value);
131     dynamic_dimensions_.push_back(false);
132   }
clear_dimensions()133   void clear_dimensions() {
134     dimensions_.clear();
135     dynamic_dimensions_.clear();
136   }
dimensions()137   absl::Span<const int64_t> dimensions() const { return dimensions_; }
mutable_dimensions()138   absl::Span<int64_t> mutable_dimensions() {
139     return absl::MakeSpan(dimensions_);
140   }
141 
142   // Methods for accessing the tuple subshapes. This field only non-empty for
143   // tuple shapes.
tuple_shapes_size()144   int tuple_shapes_size() const { return tuple_shapes_.size(); }
tuple_shapes(int index)145   const Shape& tuple_shapes(int index) const { return tuple_shapes_.at(index); }
mutable_tuple_shapes(int index)146   Shape* mutable_tuple_shapes(int index) { return &tuple_shapes_.at(index); }
add_tuple_shapes()147   Shape* add_tuple_shapes() {
148     tuple_shapes_.push_back(Shape());
149     return &tuple_shapes_.back();
150   }
clear_tuple_shapes()151   void clear_tuple_shapes() { tuple_shapes_.clear(); }
tuple_shapes()152   const std::vector<Shape>& tuple_shapes() const { return tuple_shapes_; }
mutable_tuple_shapes()153   std::vector<Shape>* mutable_tuple_shapes() { return &tuple_shapes_; }
154 
155   // Methods for accessing the layout field.
has_layout()156   bool has_layout() const { return layout_ != std::nullopt; }
layout()157   const Layout& layout() const {
158     CHECK(has_layout()) << ShortDebugString();
159     return *layout_;
160   }
mutable_layout()161   Layout* mutable_layout() {
162     CHECK(IsArray()) << ShortDebugString();
163     if (layout_ == std::nullopt) {
164       layout_.emplace();
165     }
166     return &(*layout_);
167   }
clear_layout()168   void clear_layout() { layout_ = std::nullopt; }
169 
170   // Recursively clear dynamic dimension of a shape.
clear_dynamic_dimensions()171   void clear_dynamic_dimensions() {
172     if (!IsTuple()) {
173       for (int64_t i = 0; i < dynamic_dimensions_.size(); ++i) {
174         dynamic_dimensions_[i] = false;
175       }
176       return;
177     }
178     for (auto& subshape : tuple_shapes_) {
179       subshape.clear_dynamic_dimensions();
180     }
181   }
182 
Swap(Shape * other)183   void Swap(Shape* other) {
184     using std::swap;
185     swap(*this, *other);
186   }
187 
Clear()188   void Clear() {
189     element_type_ = PRIMITIVE_TYPE_INVALID;
190     clear_dimensions();
191     tuple_shapes_.clear();
192     clear_layout();
193   }
194 
SerializeAsString()195   std::string SerializeAsString() const {
196     return ToProto().SerializeAsString();
197   }
ShortDebugString()198   std::string ShortDebugString() const { return ToProto().ShortDebugString(); }
DebugString()199   std::string DebugString() const { return ToProto().DebugString(); }
200 
201   // Equal is a configurable functor to check the equality of two shapes.
202   //
203   // Examples:
204   //
205   // - Comparing two shapes ignoring their layout difference:
206   //   Equal().IgnoreLayout()(shape1, shape2);
207   //
208   // - Comparing two shapes ignoring their layout and element type difference:
209   //   Equal().IgnoreLayout().IgnoreElementType()(shape1, shape2);
210   class Equal {
211    public:
212     Equal() = default;
213 
214     bool operator()(const Shape& lhs, const Shape& rhs);
215 
IgnoreLayout()216     Equal& IgnoreLayout() {
217       ignore_layout_ = true;
218       return *this;
219     }
IgnoreTilesInLayout()220     Equal& IgnoreTilesInLayout() {
221       ignore_tiles_in_layout_ = true;
222       return *this;
223     }
IgnoreElementSizeInLayout()224     Equal& IgnoreElementSizeInLayout() {
225       ignore_element_size_in_layout_ = true;
226       return *this;
227     }
IgnoreMemorySpaceInLayout()228     Equal& IgnoreMemorySpaceInLayout() {
229       ignore_memory_space_in_layout_ = true;
230       return *this;
231     }
MinorToMajorOnlyInLayout()232     Equal& MinorToMajorOnlyInLayout() {
233       ignore_tiles_in_layout_ = true;
234       ignore_element_size_in_layout_ = true;
235       ignore_memory_space_in_layout_ = true;
236       return *this;
237     }
IgnoreElementType()238     Equal& IgnoreElementType() {
239       ignore_element_type_ = true;
240       return *this;
241     }
IgnoreFpPrecision()242     Equal& IgnoreFpPrecision() {
243       ignore_fp_precision_ = true;
244       return *this;
245     }
IgnoreDynamicDimension()246     Equal& IgnoreDynamicDimension() {
247       ignore_dynamic_dimension_ = true;
248       return *this;
249     }
IgnoreDimensions()250     Equal& IgnoreDimensions() {
251       ignore_dimensions_ = true;
252       return *this;
253     }
254 
255    private:
256     bool ignore_layout_ = false;
257     bool ignore_tiles_in_layout_ = false;
258     bool ignore_element_size_in_layout_ = false;
259     bool ignore_memory_space_in_layout_ = false;
260     bool ignore_element_type_ = false;
261     bool ignore_fp_precision_ = false;
262     bool ignore_dynamic_dimension_ = false;
263     bool ignore_dimensions_ = false;
264   };
265 
266   // Test that all fields of the shape are the same, equivalent to Equal().
267   bool operator==(const Shape& other) const { return Equal()(*this, other); }
268   bool operator!=(const Shape& other) const { return !(*this == other); }
269 
270   template <typename H, bool kIsLayoutSensitive = true>
Hash(H h,const Shape & s)271   static H Hash(H h, const Shape& s) {
272     if (s.IsTuple()) {
273       for (const Shape& subshape : s.tuple_shapes_) {
274         h = Shape::Hash<H, kIsLayoutSensitive>(std::move(h), subshape);
275       }
276       return H::combine(std::move(h), s.tuple_shapes_size());
277     }
278     h = H::combine(std::move(h), s.element_type_, s.dimensions_,
279                    s.dynamic_dimensions_);
280     if (kIsLayoutSensitive) {
281       h = H::combine(std::move(h), s.layout_);
282     }
283     return std::move(h);
284   }
285 
286   template <typename H>
AbslHashValue(H h,const Shape & s)287   friend H AbslHashValue(H h, const Shape& s) {
288     return Shape::Hash(std::move(h), s);
289   }
290 
291  private:
292   // The element type of this shape (tuple, array, etc).
293   PrimitiveType element_type_ = PRIMITIVE_TYPE_INVALID;
294 
295   // The array bounds of the dimensions. This is nonempty only for array
296   // shapes. For a dynamically-sized dimension, the respective value in this
297   // vector is an inclusive upper limit of the array bound.
298   DimensionVector dimensions_;
299 
300   // This vector is the same size as 'dimensions_' and indicates whether the
301   // respective dimension is dynamically sized.
302   absl::InlinedVector<bool, InlineRank()> dynamic_dimensions_;
303 
304   // The tuple element subshapes. This is nonempty only for tuple shapes.
305   std::vector<Shape> tuple_shapes_;
306 
307   // The layout of the shape. Only relevant for arrays.
308   std::optional<Layout> layout_;
309 };
310 
311 // Shape of the parameters and output of an XLA computation. This is analogous
312 // to a traditional function signature.
313 class ProgramShape {
314  public:
315   ProgramShape() = default;
316 
317   // Creates a ProgramShape from a ProgramShapeProto protobuf.
318   explicit ProgramShape(const ProgramShapeProto& program_shape_proto);
319 
320   // Returns a proto representation of the object.
321   ProgramShapeProto ToProto() const;
322 
323   std::string ToString() const;
324 
325   // The following methods mirror the protobuf generated code interface for the
326   // message ProgramShapeProto. This enabled easy migration of this data
327   // structure from a proto to a proper C++ class.
328   // TODO(b/29771030): Replace or augment these methods with a more ergonomic
329   // interface.
330 
331   // Methods for accessing and manipulating the Shape of the parameters.
parameters_size()332   int parameters_size() const { return parameters_.size(); }
parameters(int index)333   const Shape& parameters(int index) const { return parameters_.at(index); }
mutable_parameters(int index)334   Shape* mutable_parameters(int index) { return &parameters_.at(index); }
add_parameters()335   Shape* add_parameters() {
336     parameters_.emplace_back();
337     return &parameters_.back();
338   }
clear_parameters()339   void clear_parameters() { parameters_.clear(); }
parameters()340   const std::vector<Shape>& parameters() const { return parameters_; }
mutable_parameters()341   std::vector<Shape>* mutable_parameters() { return &parameters_; }
342 
343   // Methods for accessing and manipulating the Shape of the result.
result()344   const Shape& result() const { return result_; }
mutable_result()345   Shape* mutable_result() { return &result_; }
346 
347   // Methods for accessing and manipulating the names of the parameters.
parameter_names_size()348   int parameter_names_size() const { return parameter_names_.size(); }
parameter_names(int index)349   const std::string& parameter_names(int index) const {
350     return parameter_names_.at(index);
351   }
set_parameter_names(int index,const std::string & value)352   void set_parameter_names(int index, const std::string& value) {
353     parameter_names_.at(index) = value;
354   }
mutable_parameter_names(int index)355   std::string* mutable_parameter_names(int index) {
356     return &parameter_names_.at(index);
357   }
add_parameter_names(const std::string & value)358   void add_parameter_names(const std::string& value) {
359     parameter_names_.push_back(value);
360   }
add_parameter_names()361   std::string* add_parameter_names() {
362     parameter_names_.push_back("");
363     return &parameter_names_.back();
364   }
clear_parameter_names()365   void clear_parameter_names() { parameter_names_.clear(); }
parameter_names()366   const std::vector<std::string>& parameter_names() const {
367     return parameter_names_;
368   }
mutable_parameter_names()369   std::vector<std::string>* mutable_parameter_names() {
370     return &parameter_names_;
371   }
372 
ShortDebugString()373   std::string ShortDebugString() const { return ToProto().ShortDebugString(); }
DebugString()374   std::string DebugString() const { return ToProto().DebugString(); }
375 
376  private:
377   // The shapes of the parameters of the computation represented by this object.
378   std::vector<Shape> parameters_;
379 
380   // The names of the parameters of the computation represented by this object.
381   std::vector<std::string> parameter_names_;
382 
383   // The shape of the result of the computation represented by this object.
384   Shape result_;
385 };
386 
387 std::ostream& operator<<(std::ostream& out, const Shape& shape);
388 std::ostream& operator<<(std::ostream& out, const ProgramShape& program_shape);
389 
390 }  // namespace xla
391 
392 #endif  // TENSORFLOW_COMPILER_XLA_SHAPE_H_
393