xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/tf2xla/shape_util.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/tf2xla/shape_util.h"
17 
18 #include <numeric>
19 
20 #include "tensorflow/compiler/tf2xla/type_util.h"
21 #include "tensorflow/compiler/xla/layout_util.h"
22 #include "tensorflow/compiler/xla/shape_util.h"
23 #include "tensorflow/core/lib/core/status.h"
24 
25 namespace tensorflow {
26 namespace {
27 
PopulateInfeedLayoutVector(const xla::Shape & shape,std::vector<int> * layouts)28 Status PopulateInfeedLayoutVector(const xla::Shape& shape,
29                                   std::vector<int>* layouts) {
30   if (shape.IsTuple()) {
31     int64_t tuple_elements = xla::ShapeUtil::TupleElementCount(shape);
32     for (int64_t i = 0; i < tuple_elements; ++i) {
33       const xla::Shape& subshape =
34           xla::ShapeUtil::GetTupleElementShape(shape, i);
35       TF_RETURN_IF_ERROR(PopulateInfeedLayoutVector(subshape, layouts));
36     }
37   } else if (xla::LayoutUtil::HasLayout(shape)) {
38     for (auto dim : xla::LayoutUtil::MinorToMajor(shape)) {
39       layouts->push_back(dim);
40     }
41   } else {
42     layouts->insert(layouts->end(), shape.rank(), -1);
43   }
44   return OkStatus();
45 }
46 
47 // Populate the output layout unless the minor_to_major array contains all -1
48 // value, in which case the layout is considered missing and the API returns
49 // false.
MakeLayout(absl::Span<const int64_t> minor_to_major,xla::Layout * layout)50 StatusOr<bool> MakeLayout(absl::Span<const int64_t> minor_to_major,
51                           xla::Layout* layout) {
52   if (std::all_of(minor_to_major.begin(), minor_to_major.end(),
53                   [](int64_t dim) { return dim == -1; })) {
54     return false;
55   }
56   std::vector<bool> dim_present(minor_to_major.size(), false);
57   for (auto dim : minor_to_major) {
58     const int minor_to_major_size = minor_to_major.size();
59     if (dim < 0 || dim >= minor_to_major_size) {
60       return errors::InvalidArgument("Layout dimension out of range: dim=", dim,
61                                      " rank=", minor_to_major.size());
62     }
63     if (dim_present[dim]) {
64       return errors::InvalidArgument("Repeated layout dimension: dim=", dim);
65     }
66     dim_present[dim] = true;
67   }
68   *layout = xla::LayoutUtil::MakeLayout(minor_to_major);
69   return true;
70 }
71 
AssignLayout(absl::Span<const int64_t> minor_to_major,const std::function<xla::Layout (const xla::Shape &)> & layout_func,xla::Shape * shape)72 Status AssignLayout(
73     absl::Span<const int64_t> minor_to_major,
74     const std::function<xla::Layout(const xla::Shape&)>& layout_func,
75     xla::Shape* shape) {
76   xla::Layout layout;
77   TF_ASSIGN_OR_RETURN(bool has_layout, MakeLayout(minor_to_major, &layout));
78   if (!has_layout && layout_func) {
79     layout = layout_func(*shape);
80   }
81   *shape->mutable_layout() = layout;
82   return OkStatus();
83 }
84 
85 }  // namespace
86 
87 // Convert an XLA Shape into the equivalent TensorFlow shape.
XLAShapeToTensorShape(const xla::Shape & shape,TensorShape * tensor_shape)88 Status XLAShapeToTensorShape(const xla::Shape& shape,
89                              TensorShape* tensor_shape) {
90   if (shape.IsTuple()) {
91     return errors::InvalidArgument("XLA shape ",
92                                    xla::ShapeUtil::HumanString(shape),
93                                    " cannot be converted to a TensorShape");
94   }
95   *tensor_shape = TensorShape();
96   for (int i = 0; i < shape.rank(); ++i) {
97     tensor_shape->AddDim(shape.dimensions(i));
98   }
99   return OkStatus();
100 }
101 
102 // Convert a TensorShape into the equivalent XLA Shape proto.
TensorShapeToXLAShape(DataType dtype,const PartialTensorShape & tensor_shape,xla::Shape * shape)103 Status TensorShapeToXLAShape(DataType dtype,
104                              const PartialTensorShape& tensor_shape,
105                              xla::Shape* shape) {
106   xla::PrimitiveType type;
107   TF_RETURN_IF_ERROR(DataTypeToPrimitiveType(dtype, &type));
108   *shape = TensorShapeToXLAShape(type, tensor_shape);
109   return OkStatus();
110 }
111 
TensorShapeToXLAShape(xla::PrimitiveType type,const PartialTensorShape & tensor_shape)112 xla::Shape TensorShapeToXLAShape(xla::PrimitiveType type,
113                                  const PartialTensorShape& tensor_shape) {
114   if (tensor_shape.unknown_rank()) {
115     // For unknown shape, create a rank 1 size 0 tensor.
116     return xla::ShapeUtil::MakeShapeWithLayout(type, {0}, {0});
117   }
118   int rank = tensor_shape.dims();
119   std::vector<int64_t> dimensions(rank);
120   std::vector<int64_t> layout(rank);
121   for (int d = 0; d < rank; ++d) {
122     dimensions[d] = tensor_shape.dim_size(d);
123     if (dimensions[d] < 0) {
124       LOG(WARNING) << "Unable to convert TF shape with dynamic size to XLA "
125                       "shape; returning unknown sentinel value";
126       return xla::ShapeUtil::MakeShapeWithLayout(type, {0}, {0});
127     }
128   }
129   // XLA uses minor-to-major; Tensorflow uses major-to-minor.
130   std::iota(layout.rbegin(), layout.rend(), 0);
131   xla::Shape result =
132       xla::ShapeUtil::MakeShapeWithLayout(type, dimensions, layout);
133   return result;
134 }
135 
136 // Convert a TensorShape into the equivalent XLA Shape proto.
TensorShapeToXLAShape(DataType dtype,const TensorShape & tensor_shape,xla::Shape * shape)137 Status TensorShapeToXLAShape(DataType dtype, const TensorShape& tensor_shape,
138                              xla::Shape* shape) {
139   xla::PrimitiveType type;
140   TF_RETURN_IF_ERROR(DataTypeToPrimitiveType(dtype, &type));
141   *shape = TensorShapeToXLAShape(type, tensor_shape);
142   return OkStatus();
143 }
144 
TensorShapeToXLAShape(DataType dtype,const TensorShape & tensor_shape)145 StatusOr<xla::Shape> TensorShapeToXLAShape(DataType dtype,
146                                            const TensorShape& tensor_shape) {
147   xla::Shape out;
148   TF_RETURN_IF_ERROR(TensorShapeToXLAShape(dtype, tensor_shape, &out));
149   return out;
150 }
151 
TensorShapeToXLAShape(xla::PrimitiveType type,const TensorShape & tensor_shape)152 xla::Shape TensorShapeToXLAShape(xla::PrimitiveType type,
153                                  const TensorShape& tensor_shape) {
154   int rank = tensor_shape.dims();
155   std::vector<int64_t> dimensions(rank);
156   std::vector<int64_t> layout(rank);
157   for (int d = 0; d < rank; ++d) {
158     dimensions[d] = tensor_shape.dim_size(d);
159   }
160   // XLA uses minor-to-major; Tensorflow uses major-to-minor.
161   std::iota(layout.rbegin(), layout.rend(), 0);
162 
163   return xla::ShapeUtil::MakeShapeWithLayout(type, dimensions, layout);
164 }
165 
GetShapeLayoutVector(const xla::Shape & shape)166 StatusOr<std::vector<int>> GetShapeLayoutVector(const xla::Shape& shape) {
167   std::vector<int> layouts;
168   TF_RETURN_IF_ERROR(PopulateInfeedLayoutVector(shape, &layouts));
169   return layouts;
170 }
171 
GetShapeWithLayout(const xla::Shape & input_shape,absl::Span<const int64_t> minor_to_major,const std::function<xla::Layout (const xla::Shape &)> & layout_func,xla::Shape * output_shape)172 Status GetShapeWithLayout(
173     const xla::Shape& input_shape, absl::Span<const int64_t> minor_to_major,
174     const std::function<xla::Layout(const xla::Shape&)>& layout_func,
175     xla::Shape* output_shape) {
176   if (input_shape.IsTuple()) {
177     int64_t tuple_elements = xla::ShapeUtil::TupleElementCount(input_shape);
178     std::vector<xla::Shape> shapes;
179     shapes.reserve(tuple_elements);
180     size_t position = 0;
181     for (int64_t i = 0; i < tuple_elements; ++i) {
182       const xla::Shape& shape =
183           xla::ShapeUtil::GetTupleElementShape(input_shape, i);
184       if (shape.IsTuple()) {
185         return errors::InvalidArgument(
186             "Nested tuples not supported: ",
187             xla::ShapeUtil::HumanString(input_shape));
188       }
189       int64_t rank = shape.rank();
190       if (position + rank > minor_to_major.size()) {
191         return errors::InvalidArgument(
192             "Not enough layout attribute elements: position=", position,
193             " rank=", rank, " elements=", minor_to_major.size());
194       }
195       shapes.push_back(shape);
196       TF_RETURN_IF_ERROR(AssignLayout(
197           absl::Span<const int64_t>(minor_to_major).subspan(position, rank),
198           layout_func, &shapes.back()));
199       position += rank;
200 
201       VLOG(4) << "Shape[" << i
202               << "] = " << xla::ShapeUtil::HumanStringWithLayout(shapes.back());
203     }
204     if (position != minor_to_major.size()) {
205       return errors::InvalidArgument(
206           "Too many elements passed in the layout attribute: position=",
207           position, " size=", minor_to_major.size());
208     }
209     *output_shape = xla::ShapeUtil::MakeTupleShape(shapes);
210   } else {
211     int64_t rank = input_shape.rank();
212     const int64_t minor_to_major_size = minor_to_major.size();
213     if (rank != minor_to_major_size) {
214       return errors::InvalidArgument(
215           "Wrong number of layout attribute elements: rank=", rank,
216           " elements=", minor_to_major.size());
217     }
218     *output_shape = input_shape;
219     TF_RETURN_IF_ERROR(AssignLayout(minor_to_major, layout_func, output_shape));
220 
221     VLOG(4) << "Shape[] = "
222             << xla::ShapeUtil::HumanStringWithLayout(*output_shape);
223   }
224   return OkStatus();
225 }
226 
227 }  // namespace tensorflow
228