1 /* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/dtensor/cc/tensor_layout.h"
17
18 #include <map>
19 #include <memory>
20 #include <ostream>
21 #include <string>
22 #include <vector>
23
24 #include "tensorflow/core/platform/test.h"
25 #include "tensorflow/dtensor/proto/layout.pb.h"
26
27 namespace tensorflow {
28 namespace dtensor {
29 namespace {
30
31 using ::testing::ElementsAre;
32 using ::testing::IsEmpty;
33 using ::testing::SizeIs;
34
35 // Simple implementation of a proto matcher comparing string representations.
36 // Only works as ShapeProto's textual representation is deterministic.
37 class ProtoStringMatcher {
38 public:
ProtoStringMatcher(const tensorflow::protobuf::Message & expected)39 explicit ProtoStringMatcher(const tensorflow::protobuf::Message& expected)
40 : expected_(expected.SerializeAsString()) {}
41
42 template <typename Message>
MatchAndExplain(const Message & p,::testing::MatchResultListener *) const43 bool MatchAndExplain(const Message& p,
44 ::testing::MatchResultListener*) const {
45 return p.SerializeAsString() == expected_;
46 }
47
DescribeTo(::std::ostream * os) const48 void DescribeTo(::std::ostream* os) const { *os << expected_; }
DescribeNegationTo(::std::ostream * os) const49 void DescribeNegationTo(::std::ostream* os) const {
50 *os << "not equal to expected message: " << expected_;
51 }
52
53 private:
54 const std::string expected_;
55 };
56
EqualsProto(const tensorflow::protobuf::Message & x)57 inline ::testing::PolymorphicMatcher<ProtoStringMatcher> EqualsProto(
58 const tensorflow::protobuf::Message& x) {
59 return ::testing::MakePolymorphicMatcher(ProtoStringMatcher(x));
60 }
61
62 class LayoutTest : public ::testing::Test {
63 protected:
BatchLayout()64 Layout BatchLayout() {
65 return Layout::FromString("sharding_specs:x,batch, mesh:|x=4,batch=8|*TPU")
66 .ValueOrDie();
67 }
68 };
69
// An empty mesh prints as the canonical empty-mesh string.
TEST_F(LayoutTest, FromStringEmptyMesh) {
  EXPECT_EQ(Mesh::Empty().ToString(), Mesh::kEmptyMeshString);
}
75
// An empty layout survives a ToString/FromString round trip.
TEST_F(LayoutTest, FromStringEmptyLayout) {
  const Layout empty = Layout::Empty();
  EXPECT_THAT(
      empty.ToProto(),
      EqualsProto(Layout::FromString(empty.ToString()).ValueOrDie().ToProto()));
}
83
// A batch-sharded layout survives a ToString/FromString round trip.
TEST_F(LayoutTest, LayoutToFromString) {
  const Layout original = BatchLayout();
  EXPECT_THAT(
      original.ToProto(),
      EqualsProto(
          Layout::FromString(original.ToString()).ValueOrDie().ToProto()));
}
91
// A layout with an explicitly unsharded dimension round-trips through its
// string form.
TEST_F(LayoutTest, LayoutToFromStringNotSharded) {
  std::string layout_str = "sharding_specs:x," + string(Layout::kUnshardedDim) +
                           ", mesh:|x=1|0|0|/job:localhost/task:0/device:CPU:0";
  StatusOr<Layout> layout = Layout::FromString(layout_str);
  // Check the parse before dereferencing; operator-> on a failed StatusOr is
  // undefined behavior.
  ASSERT_TRUE(layout.ok()) << layout.status();
  EXPECT_EQ(layout_str, layout->ToString());
}
97
// A layout using the "any" sharding spec round-trips through its string form.
TEST_F(LayoutTest, LayoutToFromStringAny) {
  std::string layout_str =
      "sharding_specs:any, mesh:|x=1|0|0|/job:localhost/task:0/device:CPU:0";
  StatusOr<Layout> layout = Layout::FromString(layout_str);
  // Check the parse before dereferencing; operator-> on a failed StatusOr is
  // undefined behavior.
  ASSERT_TRUE(layout.ok()) << layout.status();
  EXPECT_EQ(layout_str, layout->ToString());
}
103
// Parsing a mesh given with the "*CPU" shorthand expands it to contiguous
// device ids and fully-qualified device names.
TEST_F(LayoutTest, AutoGenerateLayout) {
  std::string layout_str = "sharding_specs:x, mesh:|x=2,y=2|*CPU";
  std::string exp_layout_str =
      "sharding_specs:x, "
      "mesh:|x=2,y=2|0,1,2,3|0,1,2,3|/job:localhost/task:0/device:CPU:0,/"
      "job:localhost/task:0/device:CPU:1,/job:localhost/task:0/device:CPU:2,/"
      "job:localhost/task:0/device:CPU:3";
  StatusOr<Layout> layout = Layout::FromString(layout_str);
  // Check the parse before dereferencing; operator-> on a failed StatusOr is
  // undefined behavior.
  ASSERT_TRUE(layout.ok()) << layout.status();
  EXPECT_EQ(exp_layout_str, layout->ToString());
}
113
// A mesh survives a ToString/FromString round trip.
TEST_F(LayoutTest, MeshToFromString) {
  const Mesh mesh = BatchLayout().mesh();
  EXPECT_THAT(
      mesh.ToProto(),
      EqualsProto(Mesh::FromString(mesh.ToString()).ValueOrDie().ToProto()));
}
120
// The batch layout's mesh reports itself as a TPU mesh.
TEST_F(LayoutTest, GetType) {
  EXPECT_TRUE(BatchLayout().mesh().is_tpu_mesh());
}
125
// A layout built on a TPU mesh exposes that property through mesh().
TEST_F(LayoutTest, OnTPUMesh) {
  const Layout layout = BatchLayout();
  EXPECT_TRUE(layout.mesh().is_tpu_mesh());
}
130
// num_shards reports the per-dimension shard counts (x=4, batch=8).
TEST_F(LayoutTest, NumShardsAsVector) {
  const std::vector<int32> expected_shards = {4, 8};
  EXPECT_EQ(BatchLayout().num_shards(), expected_shards);
}
135
// A sharded layout is, by definition, not fully replicated.
TEST_F(LayoutTest, IsReplicated) {
  const Layout layout = BatchLayout();
  EXPECT_FALSE(layout.IsFullyReplicated());
}
139
// device_location maps a linear device id onto per-dimension shard offsets
// and rejects out-of-range ids.
TEST_F(LayoutTest, LayoutDimLocations) {
  const Layout layout = BatchLayout();
  // Mesh is x=4, batch=8, so device d lives at offsets (d / 8, d % 8).
  EXPECT_EQ(layout.device_location(10).ValueOrDie(),
            (absl::InlinedVector<int64, 4>{1, 2}));
  EXPECT_EQ(layout.device_location(18).ValueOrDie(),
            (absl::InlinedVector<int64, 4>{2, 2}));
  EXPECT_EQ(layout.device_location(31).ValueOrDie(),
            (absl::InlinedVector<int64, 4>{3, 7}));

  // Ids outside [0, 32) are invalid.
  EXPECT_FALSE(layout.device_location(32).ok());
  EXPECT_FALSE(layout.device_location(-1).ok());
}
152
// A scalar layout has rank zero but still spans the whole mesh.
TEST_F(LayoutTest, ScalarLayout) {
  const Layout layout =
      Layout::FromString("sharding_specs:scalar, mesh:|x=4,y=4|*TPU")
          .ValueOrDie();
  EXPECT_EQ(layout.rank(), 0);
  EXPECT_EQ(layout.num_devices(), 16);
  EXPECT_TRUE(layout.mesh().is_tpu_mesh());
  EXPECT_EQ(layout.ToProto().mesh_config().mesh_dimensions(0).size(), 4);
}
162
// A "*TPU" shorthand mesh expands to 16 TPU devices.
TEST_F(LayoutTest, ParseSimpleTpuMesh) {
  const Layout layout =
      Layout::FromString("sharding_specs:x, mesh:|x=4,y=4|*TPU").ValueOrDie();
  EXPECT_TRUE(layout.mesh().is_tpu_mesh());
  EXPECT_EQ(layout.num_devices(), 16);
  EXPECT_EQ(layout.ToProto().mesh_config().mesh_dimensions(0).size(), 4);
}
170
// A "*CPU" shorthand mesh expands to 16 non-TPU devices.
TEST_F(LayoutTest, ParseSimpleCpuMesh) {
  const Layout layout =
      Layout::FromString("sharding_specs:x,unsharded, mesh:|x=4,y=4|*CPU")
          .ValueOrDie();
  EXPECT_FALSE(layout.mesh().is_tpu_mesh());
  EXPECT_EQ(layout.num_devices(), 16);
  EXPECT_EQ(layout.ToProto().mesh_config().mesh_dimensions(0).size(), 4);
}
180
// A mesh dimension may appear at most once in the sharding specs.
TEST_F(LayoutTest, ParseFailsOnRepeatedShardingSpec) {
  EXPECT_FALSE(
      Layout::FromString("sharding_specs:x,x, mesh:|x=1,y=2|*CPU").ok());
}
186
// "scalar" cannot be combined with other sharding specs.
TEST_F(LayoutTest, ParseFailsOnInvalidScalarShardingSpec) {
  EXPECT_FALSE(
      Layout::FromString("sharding_specs:x,scalar, mesh:|x=1,y=2|*CPU").ok());
}
192
// Sharding specs must reference dimensions that exist in the mesh.
TEST_F(LayoutTest, ParseFailsOnShardingSpecOverNonExistentMeshDim) {
  EXPECT_FALSE(
      Layout::FromString("sharding_specs:x,z, mesh:|x=1,y=2|*CPU").ok());
}
198
// A malformed mesh/device section is rejected by the parser.
TEST_F(LayoutTest, ParseFailsOnBadDeviceString) {
  auto result =
      Layout::FromString("sharding_specs:x,unsharded, d:TPU mesh:x=4,y=4");
  EXPECT_FALSE(result.ok()) << result.status();
}
204
// A layout unsharded in every dimension is fully replicated.
TEST_F(LayoutTest, ParseReplicatedLayout) {
  const Layout layout =
      Layout::FromString(
          "sharding_specs:unsharded,unsharded, mesh:|x=4,y=4|*CPU")
          .ValueOrDie();
  EXPECT_TRUE(layout.IsFullyReplicated());
  EXPECT_FALSE(layout.mesh().is_tpu_mesh());
  EXPECT_EQ(layout.num_devices(), 16);
  EXPECT_EQ(layout.ToProto().mesh_config().mesh_dimensions(0).size(), 4);
}
214
// Reducing a fully replicated layout collapses the mesh to a single device
// on a single host.
TEST_F(LayoutTest, SingleHostFullyReplicatedReducedMesh) {
  Layout replicated =
      Layout::FromString(
          "sharding_specs:unsharded,unsharded, mesh:|x=2,y=2|*CPU")
          .ValueOrDie();
  const Mesh reduced = replicated.ReducedMesh();
  EXPECT_EQ(reduced.size(), 1);
  EXPECT_THAT(reduced.hosts(), SizeIs(1));
}
224
// Reducing a fully sharded layout leaves the mesh unchanged.
TEST_F(LayoutTest, SingleHostFullShardedReducedMesh) {
  const Layout layout = BatchLayout();
  const Mesh reduced = layout.ReducedMesh();
  EXPECT_EQ(layout.mesh().ToString(), reduced.ToString());
  EXPECT_EQ(reduced.size(), 32);
  EXPECT_THAT(reduced.hosts(), SizeIs(1));
}
233
// Reducing a replicated multi-host layout keeps only global device 0, which
// is remote here, so the local device/host lists come back empty.
TEST_F(LayoutTest, MultiHostReplicatedReducedMesh) {
  StatusOr<Layout> layout = Layout::FromString(
      "sharding_specs:unsharded,unsharded, "
      "mesh:|x=4,y=2|0,1,2,3,4,5,6,7|4,5,6,7|"
      "/job:localhost/task:1/device:CPU:0,/job:localhost/task:1/device:CPU:1,"
      "/job:localhost/task:1/device:CPU:2,/job:localhost/task:1/device:CPU:3");
  // Check the parse before dereferencing; operator-> on a failed StatusOr is
  // undefined behavior.
  ASSERT_TRUE(layout.ok()) << layout.status();

  Mesh reduced_mesh = layout->ReducedMesh();
  EXPECT_EQ(reduced_mesh.size(), 1);
  EXPECT_THAT(reduced_mesh.global_device_ids(), ElementsAre(0));
  EXPECT_THAT(reduced_mesh.local_device_ids(), IsEmpty());
  EXPECT_THAT(reduced_mesh.local_devices(), IsEmpty());
  EXPECT_THAT(reduced_mesh.hosts(), IsEmpty());
}
248
// Reducing a partially sharded multi-host layout keeps one device per shard
// and drops devices that only hold replicas.
TEST_F(LayoutTest, MultiHostPartiallyShardedReducedMesh) {
  StatusOr<Layout> layout = Layout::FromString(
      "sharding_specs:x,unsharded, "
      "mesh:|x=4,y=2|0,1,2,3,4,5,6,7|4,5,6,7|"
      "/job:localhost/task:1/device:CPU:0,/job:localhost/task:1/device:CPU:1,"
      "/job:localhost/task:1/device:CPU:2,/job:localhost/task:1/device:CPU:3");
  // Check the parse before dereferencing; operator-> on a failed StatusOr is
  // undefined behavior.
  ASSERT_TRUE(layout.ok()) << layout.status();

  Mesh reduced_mesh = layout->ReducedMesh();
  EXPECT_EQ(reduced_mesh.size(), 4);
  EXPECT_THAT(reduced_mesh.global_device_ids(), ElementsAre(0, 2, 4, 6));
  EXPECT_THAT(reduced_mesh.local_device_ids(), ElementsAre(4, 6));
  EXPECT_THAT(reduced_mesh.local_devices(),
              ElementsAre("/job:localhost/task:1/device:CPU:0",
                          "/job:localhost/task:1/device:CPU:2"));
  EXPECT_THAT(reduced_mesh.hosts(), SizeIs(1));
}
265
// Reducing a fully sharded multi-host layout keeps every device.
TEST_F(LayoutTest, MultiHostFullyShardedReducedMesh) {
  StatusOr<Layout> layout = Layout::FromString(
      "sharding_specs:x,y, "
      "mesh:|x=4,y=2|0,1,2,3,4,5,6,7|4,5,6,7|"
      "/job:localhost/task:1/device:CPU:0,/job:localhost/task:1/device:CPU:1,"
      "/job:localhost/task:1/device:CPU:2,/job:localhost/task:1/device:CPU:3");
  // Check the parse before dereferencing; operator-> on a failed StatusOr is
  // undefined behavior.
  ASSERT_TRUE(layout.ok()) << layout.status();

  Mesh reduced_mesh = layout->ReducedMesh();
  EXPECT_EQ(reduced_mesh.size(), 8);
  EXPECT_THAT(reduced_mesh.global_device_ids(),
              ElementsAre(0, 1, 2, 3, 4, 5, 6, 7));
  EXPECT_THAT(reduced_mesh.local_device_ids(), ElementsAre(4, 5, 6, 7));
  EXPECT_THAT(reduced_mesh.local_devices(),
              ElementsAre("/job:localhost/task:1/device:CPU:0",
                          "/job:localhost/task:1/device:CPU:1",
                          "/job:localhost/task:1/device:CPU:2",
                          "/job:localhost/task:1/device:CPU:3"));
  EXPECT_EQ(reduced_mesh.hosts().size(), 1);
}
285
286 // TODO(luispazos) Decide if we want this to be the case.
// Meshes that assign the same devices to permuted device ids currently
// reduce to different meshes (see TODO above).
TEST_F(LayoutTest, FlippedShardedMultiHostMeshes) {
  StatusOr<Layout> multi_host_layout_1 = Layout::FromString(
      "sharding_specs:x,y, "
      "mesh:|x=4,y=2|0,1,2,3,4,5,6,7|4,5,6,7|"
      "/job:localhost/task:1/device:CPU:0,/job:localhost/task:1/device:CPU:1,"
      "/job:localhost/task:1/device:CPU:2,/job:localhost/task:1/device:CPU:3");
  StatusOr<Layout> multi_host_layout_2 = Layout::FromString(
      "sharding_specs:x,y, "
      "mesh:|x=4,y=2|0,1,2,3,4,5,6,7|6,7,4,5|"
      "/job:localhost/task:1/device:CPU:2,/job:localhost/task:1/device:CPU:3,"
      "/job:localhost/task:1/device:CPU:0,/job:localhost/task:1/device:CPU:1");
  // Check both parses before dereferencing; operator-> on a failed StatusOr
  // is undefined behavior.
  ASSERT_TRUE(multi_host_layout_1.ok()) << multi_host_layout_1.status();
  ASSERT_TRUE(multi_host_layout_2.ok()) << multi_host_layout_2.status();

  Mesh reduced_mesh_1 = multi_host_layout_1->ReducedMesh();
  Mesh reduced_mesh_2 = multi_host_layout_2->ReducedMesh();
  EXPECT_FALSE(reduced_mesh_1 == reduced_mesh_2);
}
303
// One-dimensional shard vectors compare equal when they tile the dimension
// equivalently, even with different shard counts.
TEST_F(LayoutTest, ShardEqualityOneDim) {
  ShardVector vec_a;
  vec_a.shards.push_back(Shard{1});
  vec_a.num_shards_per_dim.push_back(1);

  ShardVector vec_b;
  vec_b.shards.push_back(Shard{1});
  vec_b.shards.push_back(Shard{2});
  vec_b.shards.push_back(Shard{3});
  vec_b.num_shards_per_dim.push_back(3);

  EXPECT_EQ(vec_a, vec_b);
}
320
// Shard vectors at proportional offsets along one dimension compare equal.
TEST_F(LayoutTest, ShardEqualityOneDimOffset) {
  ShardVector vec_a;
  vec_a.shards.push_back(Shard{3});
  vec_a.num_shards_per_dim.push_back(3);

  ShardVector vec_b;
  vec_b.shards.push_back(Shard{7});
  vec_b.shards.push_back(Shard{8});
  vec_b.shards.push_back(Shard{9});
  vec_b.num_shards_per_dim.push_back(9);

  EXPECT_EQ(vec_a, vec_b);
}
338
// Complete 2-D shard grids compare equal regardless of their grid sizes.
TEST_F(LayoutTest, ShardEqualityTwoDims) {
  // Builds the full shard vector for a 2-D grid of the given shape.
  auto make_full_vector = [](const std::vector<int>& dims) -> ShardVector {
    ShardVector vec;
    vec.num_shards_per_dim = dims;
    for (int i = 1; i <= dims[0]; ++i) {
      for (int j = 1; j <= dims[1]; ++j) {
        vec.shards.push_back(Shard{i, j});
      }
    }
    return vec;
  };
  EXPECT_EQ(make_full_vector({2, 4}), make_full_vector({3, 3}));
}
357
// GetShardVector enumerates shards in row-major order over the mesh dims.
TEST_F(LayoutTest, Shards) {
  const Layout layout =
      Layout::FromString("sharding_specs:x,y, mesh:|x=2,y=3|*CPU").ValueOrDie();
  EXPECT_EQ(
      layout.GetShardVector().ToString(),
      "shards:[(1,1),(1,2),(1,3),(2,1),(2,2),(2,3)] num_shards_per_dim:(2,3)");
}
367
// Swapping the sharding specs transposes the shard enumeration order.
TEST_F(LayoutTest, ShardsInverted) {
  const Layout layout =
      Layout::FromString("sharding_specs:y,x, mesh:|x=2,y=3|*CPU").ValueOrDie();
  EXPECT_EQ(
      layout.GetShardVector().ToString(),
      "shards:[(1,1),(2,1),(3,1),(1,2),(2,2),(3,2)] num_shards_per_dim:(3,2)");
}
376
// On a single-host mesh, HostShardMap assigns every shard to that host.
TEST_F(LayoutTest, HostShardMap) {
  const Layout layout =
      Layout::FromString("sharding_specs:x,y, mesh:TPU|x=2,y=2|*TPU")
          .ValueOrDie();
  const std::string host_name = layout.mesh().hosts()[0];
  auto host_map = layout.HostShardMap();
  EXPECT_EQ(host_map.find(host_name)->second.ToString(),
            "shards:[(1,1),(1,2),(2,1),(2,2)] num_shards_per_dim:(2,2)");
}
388
// Shards and per-host shard maps for a 4-device mesh split across two hosts.
TEST_F(LayoutTest, MultiHostMultiDeviceShards) {
  const std::string host1 = "/job:localhost/task:0";
  const std::string host2 = "/job:localhost/task:1";
  const std::string device1 = "/device:TPU:0";
  const std::string device2 = "/device:TPU:1";
  const Layout layout =
      Layout::FromString(
          "sharding_specs:x,unsharded, mesh:TPU|x=4,y=1|0,1,2,3|0,1,2,3|" +
          host1 + device1 + "," + host1 + device2 + "," + host2 + device1 +
          "," + host2 + device2)
          .ValueOrDie();
  EXPECT_EQ(layout.GetShardVector().ToString(),
            "shards:[(1,1),(2,1),(3,1),(4,1)] num_shards_per_dim:(4,1)");

  // Each host owns exactly the two shards backed by its local devices.
  std::map<std::string, ShardVector> host_shard_map = layout.HostShardMap();
  EXPECT_EQ(host_shard_map.find(host1)->second.ToString(),
            "shards:[(1,1),(2,1)] num_shards_per_dim:(4,1)");
  EXPECT_EQ(host_shard_map.find(host2)->second.ToString(),
            "shards:[(3,1),(4,1)] num_shards_per_dim:(4,1)");
}
416
// Send/recv layouts that differ only in device type must reduce to the same
// set of hosts.
TEST_F(LayoutTest, MultiHostCommXYSharded) {
  std::string host_0 = "/job:localhost/task:0/";
  std::string host_1 = "/job:localhost/task:1/";

  StatusOr<Layout> send_layout =
      Layout::FromString("sharding_specs:y,x, mesh:|x=2,y=2|0,1,2,3|0,1,2,3|" +
                         host_0 + "device:CPU:0," + host_0 + "device:CPU:1," +
                         host_1 + "device:CPU:0," + host_1 + "device:CPU:1");
  StatusOr<Layout> recv_layout =
      Layout::FromString("sharding_specs:x,y, mesh:|x=2,y=2|0,1,2,3|0,1,2,3|" +
                         host_0 + "device:TPU:0," + host_0 + "device:TPU:1," +
                         host_1 + "device:TPU:0," + host_1 + "device:TPU:1");
  // Check both parses before dereferencing; operator-> on a failed StatusOr
  // is undefined behavior.
  ASSERT_TRUE(send_layout.ok()) << send_layout.status();
  ASSERT_TRUE(recv_layout.ok()) << recv_layout.status();

  std::vector<std::string> send_hosts = send_layout->ReducedMesh().hosts();
  std::vector<std::string> recv_hosts = recv_layout->ReducedMesh().hosts();
  EXPECT_TRUE(send_hosts == recv_hosts);
}
434
// Send/recv layouts that differ only in device type must reduce to the same
// hosts with identical per-host shards.
TEST_F(LayoutTest, MultiHostCommXSharded) {
  std::vector<std::string> hosts{"/job:localhost/task:0",
                                 "/job:localhost/task:1"};

  StatusOr<Layout> send_layout = Layout::FromString(
      "sharding_specs:x, mesh:|x=2,y=2|0,1,2,3|0,1,2,3|" + hosts[0] +
      "/device:CPU:0," + hosts[0] + "/device:CPU:1," + hosts[1] +
      "/device:CPU:0," + hosts[1] + "/device:CPU:1");
  StatusOr<Layout> recv_layout = Layout::FromString(
      "sharding_specs:x, mesh:|x=2,y=2|0,1,2,3|0,1,2,3|" + hosts[0] +
      "/device:TPU:0," + hosts[0] + "/device:TPU:1," + hosts[1] +
      "/device:TPU:0," + hosts[1] + "/device:TPU:1");
  // Check both parses before dereferencing; operator-> on a failed StatusOr
  // is undefined behavior.
  ASSERT_TRUE(send_layout.ok()) << send_layout.status();
  ASSERT_TRUE(recv_layout.ok()) << recv_layout.status();

  std::vector<std::string> send_hosts = send_layout->ReducedMesh().hosts();
  std::vector<std::string> recv_hosts = recv_layout->ReducedMesh().hosts();
  EXPECT_TRUE(send_hosts == recv_hosts);

  std::map<std::string, ShardVector> send_host_shard_map =
      send_layout->HostShardMap();
  std::map<std::string, ShardVector> recv_host_shard_map =
      recv_layout->HostShardMap();

  // Check shards match in each host.
  for (const std::string& host : hosts) {
    ShardVector shard_vec_in_send_host = send_host_shard_map.find(host)->second;
    ShardVector shard_vec_in_recv_host = recv_host_shard_map.find(host)->second;
    EXPECT_EQ(shard_vec_in_send_host, shard_vec_in_recv_host);
  }
}
464
// Transposed2D swaps the two sharding specs of a rank-2 layout.
TEST_F(LayoutTest, Transposed2DLayout) {
  const auto input =
      Layout::FromString("sharding_specs:x,y, mesh:|x=2,y=2|*CPU").ValueOrDie();
  const auto expected =
      Layout::FromString("sharding_specs:y,x, mesh:|x=2,y=2|*CPU").ValueOrDie();
  EXPECT_EQ(Layout::Transposed2D(input).ValueOrDie(), expected);
}
472
// Transposed2D swaps only the two trailing specs; batch dims are untouched.
TEST_F(LayoutTest, Transposed2DLayoutWithBatch) {
  const auto input =
      Layout::FromString(
          "sharding_specs:b1,b2,x,y, mesh:|x=2,y=2,b1=2,b2=2|*CPU")
          .ValueOrDie();
  const auto expected =
      Layout::FromString(
          "sharding_specs:b1,b2,y,x, mesh:|x=2,y=2,b1=2,b2=2|*CPU")
          .ValueOrDie();
  EXPECT_EQ(Layout::Transposed2D(input).ValueOrDie(), expected);
}
483
// idx_for_dim returns each dimension's position within the mesh.
TEST_F(LayoutTest, MeshDimensionIndex) {
  const auto layout =
      Layout::FromString("sharding_specs:x,y, mesh:|x=2,y=2|*CPU").ValueOrDie();
  const Mesh mesh = layout.mesh();
  EXPECT_EQ(mesh.idx_for_dim("x").ValueOrDie(), 0);
  EXPECT_EQ(mesh.idx_for_dim("y").ValueOrDie(), 1);
}
490
// Truncate at split point 1 keeps only the leading sharding spec.
TEST_F(LayoutTest, TruncateBeginning) {
  const auto layout =
      Layout::FromString("sharding_specs:x,y, mesh:CPU|x=2,y=2|*CPU")
          .ValueOrDie();
  const auto expected =
      Layout::FromString("sharding_specs:x, mesh:CPU|x=2,y=2|*CPU")
          .ValueOrDie();
  EXPECT_EQ(layout.Truncate(/*split_point=*/1), expected);
}
499
// Truncate with end=true keeps the specs after the split point.
TEST_F(LayoutTest, TruncateEnd) {
  const auto layout =
      Layout::FromString("sharding_specs:x,y, mesh:CPU|x=2,y=2|*CPU")
          .ValueOrDie();
  const auto expected =
      Layout::FromString("sharding_specs:y, mesh:CPU|x=2,y=2|*CPU")
          .ValueOrDie();
  EXPECT_EQ(layout.Truncate(/*split_point=*/1, /*end=*/true), expected);
}
508
// Concatenating two rank-1 layouts over the same mesh yields a rank-2 layout.
TEST_F(LayoutTest, Concatenate) {
  const auto layout_x =
      Layout::FromString("sharding_specs:x, mesh:CPU|x=2,y=2|*CPU")
          .ValueOrDie();
  const auto layout_y =
      Layout::FromString("sharding_specs:y, mesh:CPU|x=2,y=2|*CPU")
          .ValueOrDie();
  const auto expected =
      Layout::FromString("sharding_specs:x,y, mesh:CPU|x=2,y=2|*CPU")
          .ValueOrDie();
  EXPECT_EQ(ConcatenateLayouts(layout_x, layout_y).ValueOrDie(), expected);
}
520
// Layouts defined over different meshes cannot be concatenated.
TEST_F(LayoutTest, ConcatenateDifferentMesh) {
  auto layout_1 =
      Layout::FromString("sharding_specs:x, mesh:CPU|x=2|*CPU").ValueOrDie();
  auto layout_2 =
      Layout::FromString("sharding_specs:y, mesh:CPU|y=2|*CPU").ValueOrDie();
  auto result = ConcatenateLayouts(layout_1, layout_2);
  EXPECT_FALSE(result.ok()) << result.status();
}
529
// Concatenation fails if both layouts shard over the same mesh dimension.
TEST_F(LayoutTest, ConcatenateSameDimension) {
  auto layout_1 =
      Layout::FromString("sharding_specs:x, mesh:CPU|x=2,y=2|*CPU")
          .ValueOrDie();
  auto layout_2 =
      Layout::FromString("sharding_specs:x, mesh:CPU|x=2,y=2|*CPU")
          .ValueOrDie();
  auto result = ConcatenateLayouts(layout_1, layout_2);
  EXPECT_FALSE(result.ok()) << result.status();
}
538
// An empty mesh reports an empty device type.
TEST_F(LayoutTest, EmptyMeshDeviceType) {
  EXPECT_EQ(Mesh::Empty().device_type(), std::string());
}
543
// ToDeviceType rewrites a TPU mesh into the equivalent CPU mesh with
// fully-qualified local device names.
TEST_F(LayoutTest, ConvertMeshDeviceType) {
  const Mesh tpu_mesh =
      Mesh::FromString("mesh:|x=2,batch=1|*TPU").ValueOrDie();
  const Mesh cpu_mesh = tpu_mesh.ToDeviceType("CPU").ValueOrDie();
  EXPECT_TRUE(cpu_mesh.is_cpu_mesh());

  const std::string task = "/job:localhost/replica:0/task:0/";
  const Mesh expected =
      Mesh::FromString("mesh:|x=2,batch=1|0,1|0,1|" + task + "device:CPU:0," +
                       task + "device:CPU:1")
          .ValueOrDie();
  EXPECT_EQ(cpu_mesh, expected);
}
556
// Sharding over a size-1 mesh dimension is equivalent to leaving that
// dimension unsharded; sharding differences on larger dims are not.
TEST_F(LayoutTest, EquivalentLayout) {
  const Layout fully_sharded =
      Layout::FromString("sharding_specs:x,y, mesh:|x=2,y=1|*TPU").ValueOrDie();
  const Layout x_sharded =
      Layout::FromString("sharding_specs:x,unsharded, mesh:|x=2,y=1|*TPU")
          .ValueOrDie();
  const Layout y_sharded =
      Layout::FromString("sharding_specs:unsharded,y, mesh:|x=2,y=1|*TPU")
          .ValueOrDie();

  EXPECT_TRUE(fully_sharded.IsEquivalent(x_sharded));
  EXPECT_TRUE(x_sharded.IsEquivalent(fully_sharded));
  EXPECT_FALSE(fully_sharded.IsEquivalent(y_sharded));
  EXPECT_FALSE(y_sharded.IsEquivalent(fully_sharded));
}
572
573 } // namespace
574 } // namespace dtensor
575 } // namespace tensorflow
576