1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
#include "tensorflow/compiler/tf2xla/layout_util.h"

#include <optional>
#include <vector>

#include "tensorflow/compiler/tf2xla/shape_util.h"
#include "tensorflow/compiler/tf2xla/type_util.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/errors.h"
21
22 namespace tensorflow {
23
ShapeDeterminationFns()24 XlaShapeLayoutHelpers::ShapeDeterminationFns::ShapeDeterminationFns() {
25 layout_preference_fn = UseNoPreferenceLayoutFn();
26 shape_representation_fn = IdentityShapeRepresentationFn();
27 }
28
UseNoPreferenceLayoutFn()29 XlaShapeLayoutHelpers::LayoutPreferenceFn UseNoPreferenceLayoutFn() {
30 return [](const TensorShape& shape, DataType dtype,
31 std::optional<XlaArgument::Kind>) -> XlaLayoutPreference {
32 return XlaLayoutPreference::kNoPreference;
33 };
34 }
35
36 // Rewrites the layout of xla_shape if there is tiled sharding.
RewriteLayoutWithShardedShape(const std::optional<xla::HloSharding> & sharding,bool use_fast_memory,XlaShapeLayoutHelpers::ShapeDeterminationFns shape_determination_fns,xla::Shape * xla_shape)37 Status RewriteLayoutWithShardedShape(
38 const std::optional<xla::HloSharding>& sharding, bool use_fast_memory,
39 XlaShapeLayoutHelpers::ShapeDeterminationFns shape_determination_fns,
40 xla::Shape* xla_shape) {
41 if (sharding && !sharding->IsTileMaximal() && !sharding->IsManual()) {
42 // After sharding, per core shape might have different layout. For example,
43 // before sharding, a shape [128, 128] will be assigned default
44 // minor-to-major {1, 0}. But after we shard this shape to [128, 64] * 2,
45 // the sharded shapes will have minor-to-major {0, 1}.
46 //
47 // As a result, for sharded shapes, we set their layout to per core shape's
48 // layout.
49 //
50 // TODO(endlessroad): for variable input & update, we might have
51 // different layouts which will prevent input output aliasing and
52 // increase memory usage. Investigate such cases.
53 int64_t device = *sharding->tile_assignment().begin();
54 std::vector<int64_t> offset =
55 sharding->TileOffsetForDevice(*xla_shape, device);
56 std::vector<int64_t> limit =
57 sharding->TileLimitForDevice(*xla_shape, device);
58 std::vector<int64_t> dimensions(xla_shape->rank());
59 for (int64_t i = 0; i < xla_shape->rank(); ++i) {
60 dimensions[i] = limit[i] - offset[i];
61 }
62 xla::Shape per_device_xla_shape =
63 xla::ShapeUtil::MakeShape(xla_shape->element_type(), dimensions);
64 TensorShape per_device_tensor_shape;
65 TF_RETURN_IF_ERROR(
66 XLAShapeToTensorShape(per_device_xla_shape, &per_device_tensor_shape));
67 TF_ASSIGN_OR_RETURN(DataType dtype, EncodePrimitiveTypeAsDataType(
68 xla_shape->element_type()));
69 auto layout_preference = shape_determination_fns.layout_preference_fn(
70 per_device_tensor_shape, dtype, std::nullopt);
71 TF_ASSIGN_OR_RETURN(per_device_xla_shape,
72 shape_determination_fns.shape_representation_fn(
73 per_device_tensor_shape, dtype, use_fast_memory,
74 layout_preference));
75 *xla_shape->mutable_layout() = per_device_xla_shape.layout();
76 }
77 return OkStatus();
78 }
79
80 // There is a shape_representation_fn or sharding for an output, this function
81 // uses a reshape to fix the layout.
ReshapeWithCorrectRepresentationAndSharding(xla::XlaBuilder * builder,xla::XlaOp original,xla::Shape original_shape,XlaShapeLayoutHelpers::ShapeDeterminationFns shape_determination_fns,std::optional<xla::OpSharding> sharding,bool fast_mem)82 StatusOr<xla::XlaOp> ReshapeWithCorrectRepresentationAndSharding(
83 xla::XlaBuilder* builder, xla::XlaOp original, xla::Shape original_shape,
84 XlaShapeLayoutHelpers::ShapeDeterminationFns shape_determination_fns,
85 std::optional<xla::OpSharding> sharding, bool fast_mem) {
86 if (original_shape.IsTuple()) {
87 std::vector<xla::XlaOp> elements;
88 for (int i = 0; i < original_shape.tuple_shapes_size(); ++i) {
89 auto subsharding = sharding ? sharding->tuple_shardings(i) : sharding;
90 TF_ASSIGN_OR_RETURN(auto element,
91 ReshapeWithCorrectRepresentationAndSharding(
92 builder, xla::GetTupleElement(original, i),
93 original_shape.tuple_shapes(i),
94 shape_determination_fns, subsharding, fast_mem));
95 elements.push_back(element);
96 }
97 return xla::Tuple(builder, elements);
98 }
99 if (!original_shape.IsArray()) return original;
100 TensorShape shape;
101 TF_RETURN_IF_ERROR(XLAShapeToTensorShape(original_shape, &shape));
102 TF_ASSIGN_OR_RETURN(DataType dtype, EncodePrimitiveTypeAsDataType(
103 original_shape.element_type()));
104 auto layout_preference =
105 shape_determination_fns.layout_preference_fn(shape, dtype, std::nullopt);
106 TF_ASSIGN_OR_RETURN(auto to_shape,
107 shape_determination_fns.shape_representation_fn(
108 shape, dtype, fast_mem, layout_preference));
109 if (sharding) {
110 TF_ASSIGN_OR_RETURN(auto hlo_sharding,
111 xla::HloSharding::FromProto(*sharding));
112
113 TF_RETURN_IF_ERROR(RewriteLayoutWithShardedShape(
114 hlo_sharding, fast_mem, shape_determination_fns, &to_shape));
115 }
116 if (xla::ShapeUtil::Compatible(original_shape, to_shape)) {
117 for (int64_t i = 0; i < original_shape.rank(); ++i) {
118 to_shape.set_dynamic_dimension(i, original_shape.is_dynamic_dimension(i));
119 }
120 }
121 return xla::Reshape(to_shape, original);
122 }
123
124 } // namespace tensorflow
125