/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/tf2xla/shape_util.h"

#include <numeric>

#include "tensorflow/compiler/tf2xla/type_util.h"
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/core/lib/core/status.h"

namespace tensorflow {
namespace {

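// Recursively walks `shape` and appends the minor-to-major layout of every
// leaf (non-tuple) subshape to `layouts`. A leaf without a layout contributes
// one -1 sentinel per dimension.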
Status PopulateInfeedLayoutVector(const xla::Shape& shape,
                                  std::vector<int>* layouts) {
  if (shape.IsTuple()) {
    int64_t tuple_elements = xla::ShapeUtil::TupleElementCount(shape);
    for (int64_t i = 0; i < tuple_elements; ++i) {
      const xla::Shape& subshape =
          xla::ShapeUtil::GetTupleElementShape(shape, i);
      TF_RETURN_IF_ERROR(PopulateInfeedLayoutVector(subshape, layouts));
    }
  } else if (xla::LayoutUtil::HasLayout(shape)) {
    for (auto dim : xla::LayoutUtil::MinorToMajor(shape)) {
      layouts->push_back(dim);
    }
  } else {
    layouts->insert(layouts->end(), shape.rank(), -1);
  }
  return OkStatus();
}

// Populate the output layout unless the minor_to_major array contains all -1
// values, in which case the layout is considered missing and the API returns
// false.
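// For example, minor_to_major = {0, 1} yields a layout with dimension 0
// minor-most, while {-1, -1} is treated as "no layout" and returns false.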
StatusOr<bool> MakeLayout(absl::Span<const int64_t> minor_to_major,
                          xla::Layout* layout) {
  if (std::all_of(minor_to_major.begin(), minor_to_major.end(),
                  [](int64_t dim) { return dim == -1; })) {
    return false;
  }
  std::vector<bool> dim_present(minor_to_major.size(), false);
  for (auto dim : minor_to_major) {
    const int minor_to_major_size = minor_to_major.size();
    if (dim < 0 || dim >= minor_to_major_size) {
      return errors::InvalidArgument("Layout dimension out of range: dim=",
                                     dim, " rank=", minor_to_major.size());
    }
    if (dim_present[dim]) {
      return errors::InvalidArgument("Repeated layout dimension: dim=", dim);
    }
    dim_present[dim] = true;
  }
  *layout = xla::LayoutUtil::MakeLayout(minor_to_major);
  return true;
}

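// Applies the layout described by `minor_to_major` to `shape`. If
// `minor_to_major` carries no layout (all -1) and `layout_func` is non-null,
// the layout is computed by `layout_func` instead; otherwise a
// default-constructed layout is assigned.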
Status AssignLayout(
    absl::Span<const int64_t> minor_to_major,
    const std::function<xla::Layout(const xla::Shape&)>& layout_func,
    xla::Shape* shape) {
  xla::Layout layout;
  TF_ASSIGN_OR_RETURN(bool has_layout, MakeLayout(minor_to_major, &layout));
  if (!has_layout && layout_func) {
    layout = layout_func(*shape);
  }
  *shape->mutable_layout() = layout;
  return OkStatus();
}

}  // namespace

// Convert an XLA Shape into the equivalent TensorFlow shape.
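// Tuple shapes are rejected, since TensorShape has no tuple equivalent.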
Status XLAShapeToTensorShape(const xla::Shape& shape,
                             TensorShape* tensor_shape) {
  if (shape.IsTuple()) {
    return errors::InvalidArgument("XLA shape ",
                                   xla::ShapeUtil::HumanString(shape),
                                   " cannot be converted to a TensorShape");
  }
  *tensor_shape = TensorShape();
  for (int i = 0; i < shape.rank(); ++i) {
    tensor_shape->AddDim(shape.dimensions(i));
  }
  return OkStatus();
}

// Convert a PartialTensorShape into the equivalent XLA Shape proto.
Status TensorShapeToXLAShape(DataType dtype,
                             const PartialTensorShape& tensor_shape,
                             xla::Shape* shape) {
  xla::PrimitiveType type;
  TF_RETURN_IF_ERROR(DataTypeToPrimitiveType(dtype, &type));
  *shape = TensorShapeToXLAShape(type, tensor_shape);
  return OkStatus();
}

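// Convert a PartialTensorShape into the equivalent XLA Shape. Shapes with
// unknown rank or dynamic dimension sizes cannot be represented exactly and
// are mapped to a rank-1, size-0 sentinel shape.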
xla::Shape TensorShapeToXLAShape(xla::PrimitiveType type,
                                 const PartialTensorShape& tensor_shape) {
  if (tensor_shape.unknown_rank()) {
    // For unknown shape, create a rank 1 size 0 tensor.
    return xla::ShapeUtil::MakeShapeWithLayout(type, {0}, {0});
  }
  int rank = tensor_shape.dims();
  std::vector<int64_t> dimensions(rank);
  std::vector<int64_t> layout(rank);
  for (int d = 0; d < rank; ++d) {
    dimensions[d] = tensor_shape.dim_size(d);
    if (dimensions[d] < 0) {
      LOG(WARNING) << "Unable to convert TF shape with dynamic size to XLA "
                      "shape; returning unknown sentinel value";
      return xla::ShapeUtil::MakeShapeWithLayout(type, {0}, {0});
    }
  }
  // XLA uses minor-to-major; TensorFlow uses major-to-minor.
  std::iota(layout.rbegin(), layout.rend(), 0);
  xla::Shape result =
      xla::ShapeUtil::MakeShapeWithLayout(type, dimensions, layout);
  return result;
}

// Convert a TensorShape into the equivalent XLA Shape proto.
Status TensorShapeToXLAShape(DataType dtype, const TensorShape& tensor_shape,
                             xla::Shape* shape) {
  xla::PrimitiveType type;
  TF_RETURN_IF_ERROR(DataTypeToPrimitiveType(dtype, &type));
  *shape = TensorShapeToXLAShape(type, tensor_shape);
  return OkStatus();
}

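// As above, but returns the converted shape through StatusOr rather than an
// output parameter.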
StatusOr<xla::Shape> TensorShapeToXLAShape(DataType dtype,
                                           const TensorShape& tensor_shape) {
  xla::Shape out;
  TF_RETURN_IF_ERROR(TensorShapeToXLAShape(dtype, tensor_shape, &out));
  return out;
}

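// Convert a fully-defined TensorShape into the equivalent XLA Shape with a
// major-to-minor (row-major) layout.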
xla::Shape TensorShapeToXLAShape(xla::PrimitiveType type,
                                 const TensorShape& tensor_shape) {
  int rank = tensor_shape.dims();
  std::vector<int64_t> dimensions(rank);
  std::vector<int64_t> layout(rank);
  for (int d = 0; d < rank; ++d) {
    dimensions[d] = tensor_shape.dim_size(d);
  }
  // XLA uses minor-to-major; TensorFlow uses major-to-minor.
  std::iota(layout.rbegin(), layout.rend(), 0);

  return xla::ShapeUtil::MakeShapeWithLayout(type, dimensions, layout);
}

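// Flattens `shape` into a single vector of minor-to-major layout values, one
// entry per dimension of each leaf shape, with -1 for dimensions that carry
// no layout.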
StatusOr<std::vector<int>> GetShapeLayoutVector(const xla::Shape& shape) {
  std::vector<int> layouts;
  TF_RETURN_IF_ERROR(PopulateInfeedLayoutVector(shape, &layouts));
  return layouts;
}

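// Rebuilds `input_shape` as `output_shape` with the layouts described by
// `minor_to_major`, which concatenates the per-element layouts of a (flat)
// tuple, or holds the layout of a plain array shape. For example, a
// hypothetical tuple (f32[2,3], s32[4]) with minor_to_major = {0, 1, 0}
// assigns layout {0,1} to the first element and {0} to the second.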
Status GetShapeWithLayout(
    const xla::Shape& input_shape, absl::Span<const int64_t> minor_to_major,
    const std::function<xla::Layout(const xla::Shape&)>& layout_func,
    xla::Shape* output_shape) {
  if (input_shape.IsTuple()) {
    int64_t tuple_elements = xla::ShapeUtil::TupleElementCount(input_shape);
    std::vector<xla::Shape> shapes;
    shapes.reserve(tuple_elements);
    size_t position = 0;
    for (int64_t i = 0; i < tuple_elements; ++i) {
      const xla::Shape& shape =
          xla::ShapeUtil::GetTupleElementShape(input_shape, i);
      if (shape.IsTuple()) {
        return errors::InvalidArgument(
            "Nested tuples not supported: ",
            xla::ShapeUtil::HumanString(input_shape));
      }
      int64_t rank = shape.rank();
      if (position + rank > minor_to_major.size()) {
        return errors::InvalidArgument(
            "Not enough layout attribute elements: position=", position,
            " rank=", rank, " elements=", minor_to_major.size());
      }
      shapes.push_back(shape);
      TF_RETURN_IF_ERROR(AssignLayout(
          absl::Span<const int64_t>(minor_to_major).subspan(position, rank),
          layout_func, &shapes.back()));
      position += rank;

      VLOG(4) << "Shape[" << i
              << "] = " << xla::ShapeUtil::HumanStringWithLayout(shapes.back());
    }
    if (position != minor_to_major.size()) {
      return errors::InvalidArgument(
          "Too many elements passed in the layout attribute: position=",
          position, " size=", minor_to_major.size());
    }
    *output_shape = xla::ShapeUtil::MakeTupleShape(shapes);
  } else {
    int64_t rank = input_shape.rank();
    const int64_t minor_to_major_size = minor_to_major.size();
    if (rank != minor_to_major_size) {
      return errors::InvalidArgument(
          "Wrong number of layout attribute elements: rank=", rank,
          " elements=", minor_to_major.size());
    }
    *output_shape = input_shape;
    TF_RETURN_IF_ERROR(AssignLayout(minor_to_major, layout_func, output_shape));

    VLOG(4) << "Shape[] = "
            << xla::ShapeUtil::HumanStringWithLayout(*output_shape);
  }
  return OkStatus();
}

}  // namespace tensorflow