xref: /aosp_15_r20/external/tensorflow/tensorflow/dtensor/tests/tensor_layout_test.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
15 
#include "tensorflow/dtensor/cc/tensor_layout.h"

#include <map>
#include <memory>
#include <ostream>
#include <string>
#include <vector>

#include "tensorflow/core/platform/test.h"
#include "tensorflow/dtensor/proto/layout.pb.h"
26 
27 namespace tensorflow {
28 namespace dtensor {
29 namespace {
30 
31 using ::testing::ElementsAre;
32 using ::testing::IsEmpty;
33 using ::testing::SizeIs;
34 
35 // Simple implementation of a proto matcher comparing string representations.
36 // Only works as ShapeProto's textual representation is deterministic.
37 class ProtoStringMatcher {
38  public:
ProtoStringMatcher(const tensorflow::protobuf::Message & expected)39   explicit ProtoStringMatcher(const tensorflow::protobuf::Message& expected)
40       : expected_(expected.SerializeAsString()) {}
41 
42   template <typename Message>
MatchAndExplain(const Message & p,::testing::MatchResultListener *) const43   bool MatchAndExplain(const Message& p,
44                        ::testing::MatchResultListener*) const {
45     return p.SerializeAsString() == expected_;
46   }
47 
DescribeTo(::std::ostream * os) const48   void DescribeTo(::std::ostream* os) const { *os << expected_; }
DescribeNegationTo(::std::ostream * os) const49   void DescribeNegationTo(::std::ostream* os) const {
50     *os << "not equal to expected message: " << expected_;
51   }
52 
53  private:
54   const std::string expected_;
55 };
56 
EqualsProto(const tensorflow::protobuf::Message & x)57 inline ::testing::PolymorphicMatcher<ProtoStringMatcher> EqualsProto(
58     const tensorflow::protobuf::Message& x) {
59   return ::testing::MakePolymorphicMatcher(ProtoStringMatcher(x));
60 }
61 
62 class LayoutTest : public ::testing::Test {
63  protected:
BatchLayout()64   Layout BatchLayout() {
65     return Layout::FromString("sharding_specs:x,batch, mesh:|x=4,batch=8|*TPU")
66         .ValueOrDie();
67   }
68 };
69 
// The empty mesh round-trips to its canonical string form.
TEST_F(LayoutTest, FromStringEmptyMesh) {
  Mesh mesh = Mesh::Empty();
  std::string mesh_str = mesh.ToString();
  EXPECT_EQ(mesh_str, Mesh::kEmptyMeshString);
}

// The empty layout survives a ToString -> FromString round trip (compared via
// proto serialization).
TEST_F(LayoutTest, FromStringEmptyLayout) {
  Layout layout = Layout::Empty();
  std::string layout_str = layout.ToString();
  EXPECT_THAT(
      layout.ToProto(),
      EqualsProto(Layout::FromString(layout_str).ValueOrDie().ToProto()));
}

// A sharded batch layout survives a ToString -> FromString round trip.
TEST_F(LayoutTest, LayoutToFromString) {
  Layout layout = BatchLayout();
  std::string layout_str = layout.ToString();
  EXPECT_THAT(
      layout.ToProto(),
      EqualsProto(Layout::FromString(layout_str).ValueOrDie().ToProto()));
}

// A layout containing an unsharded dimension round-trips exactly.
TEST_F(LayoutTest, LayoutToFromStringNotSharded) {
  // std::string spelled explicitly for consistency with the rest of the file
  // (was the bare `string` alias).
  std::string layout_str = "sharding_specs:x," +
                           std::string(Layout::kUnshardedDim) +
                           ", mesh:|x=1|0|0|/job:localhost/task:0/device:CPU:0";
  EXPECT_EQ(layout_str, Layout::FromString(layout_str)->ToString());
}

// The special "any" sharding spec round-trips exactly.
TEST_F(LayoutTest, LayoutToFromStringAny) {
  std::string layout_str =
      "sharding_specs:any, mesh:|x=1|0|0|/job:localhost/task:0/device:CPU:0";
  EXPECT_EQ(layout_str, Layout::FromString(layout_str)->ToString());
}

// The "*CPU" shorthand expands into explicit global/local device-id lists and
// fully qualified device names.
TEST_F(LayoutTest, AutoGenerateLayout) {
  std::string layout_str = "sharding_specs:x, mesh:|x=2,y=2|*CPU";
  std::string exp_layout_str =
      "sharding_specs:x, "
      "mesh:|x=2,y=2|0,1,2,3|0,1,2,3|/job:localhost/task:0/device:CPU:0,/"
      "job:localhost/task:0/device:CPU:1,/job:localhost/task:0/device:CPU:2,/"
      "job:localhost/task:0/device:CPU:3";
  EXPECT_EQ(exp_layout_str, Layout::FromString(layout_str)->ToString());
}

// A mesh survives a ToString -> FromString round trip (proto comparison).
TEST_F(LayoutTest, MeshToFromString) {
  Mesh mesh = BatchLayout().mesh();
  std::string mesh_str = mesh.ToString();
  EXPECT_THAT(mesh.ToProto(),
              EqualsProto(Mesh::FromString(mesh_str).ValueOrDie().ToProto()));
}

// A mesh built from "*TPU" reports itself as a TPU mesh.
TEST_F(LayoutTest, GetType) {
  Mesh mesh = BatchLayout().mesh();
  EXPECT_TRUE(mesh.is_tpu_mesh());
}

// A layout on a TPU mesh exposes that property via its mesh.
TEST_F(LayoutTest, OnTPUMesh) {
  Layout layout = BatchLayout();
  EXPECT_TRUE(layout.mesh().is_tpu_mesh());
}

// num_shards() reports one entry per sharding dim: x=4, batch=8.
TEST_F(LayoutTest, NumShardsAsVector) {
  std::vector<int32> shards = {4, 8};
  EXPECT_EQ(BatchLayout().num_shards(), shards);
}

// A layout sharded on every dimension is not fully replicated.
TEST_F(LayoutTest, IsReplicated) {
  EXPECT_FALSE(BatchLayout().IsFullyReplicated());
}

// device_location maps a linear device id to per-dim mesh coordinates for the
// x=4,batch=8 mesh (row-major: id = x * 8 + batch), and rejects out-of-range
// ids (valid range is [0, 32)).
TEST_F(LayoutTest, LayoutDimLocations) {
  Layout layout = BatchLayout();
  absl::InlinedVector<int64, 4> offset = {1, 2};
  EXPECT_EQ(layout.device_location(10).ValueOrDie(), offset);
  offset = {2, 2};
  EXPECT_EQ(layout.device_location(18).ValueOrDie(), offset);
  offset = {3, 7};
  EXPECT_EQ(layout.device_location(31).ValueOrDie(), offset);

  EXPECT_FALSE(layout.device_location(32).ok());
  EXPECT_FALSE(layout.device_location(-1).ok());
}

// A "scalar" sharding spec yields a rank-0 layout over the full mesh.
TEST_F(LayoutTest, ScalarLayout) {
  Layout layout =
      Layout::FromString("sharding_specs:scalar, mesh:|x=4,y=4|*TPU")
          .ValueOrDie();
  EXPECT_EQ(layout.num_devices(), 16);
  EXPECT_TRUE(layout.mesh().is_tpu_mesh());
  EXPECT_EQ(layout.ToProto().mesh_config().mesh_dimensions(0).size(), 4);
  EXPECT_EQ(layout.rank(), 0);
}

// Parsing a 4x4 "*TPU" mesh produces 16 TPU devices with dim x of size 4.
TEST_F(LayoutTest, ParseSimpleTpuMesh) {
  Layout layout =
      Layout::FromString("sharding_specs:x, mesh:|x=4,y=4|*TPU").ValueOrDie();
  EXPECT_EQ(layout.num_devices(), 16);
  EXPECT_TRUE(layout.mesh().is_tpu_mesh());
  EXPECT_EQ(layout.ToProto().mesh_config().mesh_dimensions(0).size(), 4);
}

// Parsing a 4x4 "*CPU" mesh produces 16 devices that are not TPUs.
TEST_F(LayoutTest, ParseSimpleCpuMesh) {
  auto layout =
      Layout::FromString("sharding_specs:x,unsharded, mesh:|x=4,y=4|*CPU")
          .ValueOrDie();
  EXPECT_EQ(layout.num_devices(), 16);
  EXPECT_FALSE(layout.mesh().is_tpu_mesh());

  EXPECT_EQ(layout.ToProto().mesh_config().mesh_dimensions(0).size(), 4);
}

// Using the same mesh dim twice in sharding_specs is rejected.
TEST_F(LayoutTest, ParseFailsOnRepeatedShardingSpec) {
  StatusOr<Layout> maybe_layout =
      Layout::FromString("sharding_specs:x,x, mesh:|x=1,y=2|*CPU");
  EXPECT_FALSE(maybe_layout.ok());
}

// "scalar" may not be combined with other sharding specs.
TEST_F(LayoutTest, ParseFailsOnInvalidScalarShardingSpec) {
  StatusOr<Layout> maybe_layout =
      Layout::FromString("sharding_specs:x,scalar, mesh:|x=1,y=2|*CPU");
  EXPECT_FALSE(maybe_layout.ok());
}

// Referencing a mesh dim ("z") that does not exist is rejected.
TEST_F(LayoutTest, ParseFailsOnShardingSpecOverNonExistentMeshDim) {
  StatusOr<Layout> maybe_layout =
      Layout::FromString("sharding_specs:x,z, mesh:|x=1,y=2|*CPU");
  EXPECT_FALSE(maybe_layout.ok());
}

// A malformed mesh/device section fails to parse; log the status on failure.
TEST_F(LayoutTest, ParseFailsOnBadDeviceString) {
  auto layout =
      Layout::FromString("sharding_specs:x,unsharded, d:TPU mesh:x=4,y=4");
  EXPECT_FALSE(layout.ok()) << layout.status();
}

// An all-unsharded layout parses as fully replicated.
TEST_F(LayoutTest, ParseReplicatedLayout) {
  auto layout = Layout::FromString(
                    "sharding_specs:unsharded,unsharded, mesh:|x=4,y=4|*CPU")
                    .ValueOrDie();
  EXPECT_EQ(layout.num_devices(), 16);
  EXPECT_FALSE(layout.mesh().is_tpu_mesh());
  EXPECT_TRUE(layout.IsFullyReplicated());
  EXPECT_EQ(layout.ToProto().mesh_config().mesh_dimensions(0).size(), 4);
}

// ReducedMesh of a fully replicated single-host layout collapses to one
// device on one host.
TEST_F(LayoutTest, SingleHostFullyReplicatedReducedMesh) {
  Layout replicated_layout =
      Layout::FromString(
          "sharding_specs:unsharded,unsharded, mesh:|x=2,y=2|*CPU")
          .ValueOrDie();
  Mesh reduced_mesh = replicated_layout.ReducedMesh();
  EXPECT_EQ(reduced_mesh.size(), 1);
  EXPECT_THAT(reduced_mesh.hosts(), SizeIs(1));
}

// ReducedMesh of a fully sharded layout is the original mesh unchanged.
TEST_F(LayoutTest, SingleHostFullShardedReducedMesh) {
  Layout layout = BatchLayout();
  Mesh original_mesh = layout.mesh();
  Mesh reduced_mesh = layout.ReducedMesh();
  EXPECT_EQ(original_mesh.ToString(), reduced_mesh.ToString());
  EXPECT_EQ(reduced_mesh.size(), 32);
  EXPECT_THAT(reduced_mesh.hosts(), SizeIs(1));
}

// A replicated multi-host layout reduces to the single global device 0; this
// host (task:1) owns no local devices in the reduced mesh.
TEST_F(LayoutTest, MultiHostReplicatedReducedMesh) {
  StatusOr<Layout> layout = Layout::FromString(
      "sharding_specs:unsharded,unsharded, "
      "mesh:|x=4,y=2|0,1,2,3,4,5,6,7|4,5,6,7|"
      "/job:localhost/task:1/device:CPU:0,/job:localhost/task:1/device:CPU:1,"
      "/job:localhost/task:1/device:CPU:2,/job:localhost/task:1/device:CPU:3");

  Mesh reduced_mesh = layout->ReducedMesh();
  EXPECT_EQ(reduced_mesh.size(), 1);
  EXPECT_THAT(reduced_mesh.global_device_ids(), ElementsAre(0));
  EXPECT_THAT(reduced_mesh.local_device_ids(), IsEmpty());
  EXPECT_THAT(reduced_mesh.local_devices(), IsEmpty());
  EXPECT_THAT(reduced_mesh.hosts(), IsEmpty());
}

// Sharding only over x keeps one device per x-slice (global ids 0,2,4,6);
// this host keeps its local members of that set.
TEST_F(LayoutTest, MultiHostPartiallyShardedReducedMesh) {
  StatusOr<Layout> layout = Layout::FromString(
      "sharding_specs:x,unsharded, "
      "mesh:|x=4,y=2|0,1,2,3,4,5,6,7|4,5,6,7|"
      "/job:localhost/task:1/device:CPU:0,/job:localhost/task:1/device:CPU:1,"
      "/job:localhost/task:1/device:CPU:2,/job:localhost/task:1/device:CPU:3");

  Mesh reduced_mesh = layout->ReducedMesh();
  EXPECT_EQ(reduced_mesh.size(), 4);
  EXPECT_THAT(reduced_mesh.global_device_ids(), ElementsAre(0, 2, 4, 6));
  EXPECT_THAT(reduced_mesh.local_device_ids(), ElementsAre(4, 6));
  EXPECT_THAT(reduced_mesh.local_devices(),
              ElementsAre("/job:localhost/task:1/device:CPU:0",
                          "/job:localhost/task:1/device:CPU:2"));
  EXPECT_THAT(reduced_mesh.hosts(), SizeIs(1));
}

// A fully sharded multi-host layout keeps every device in the reduced mesh.
TEST_F(LayoutTest, MultiHostFullyShardedReducedMesh) {
  StatusOr<Layout> layout = Layout::FromString(
      "sharding_specs:x,y, "
      "mesh:|x=4,y=2|0,1,2,3,4,5,6,7|4,5,6,7|"
      "/job:localhost/task:1/device:CPU:0,/job:localhost/task:1/device:CPU:1,"
      "/job:localhost/task:1/device:CPU:2,/job:localhost/task:1/device:CPU:3");

  Mesh reduced_mesh = layout->ReducedMesh();
  EXPECT_EQ(reduced_mesh.size(), 8);
  EXPECT_THAT(reduced_mesh.global_device_ids(),
              ElementsAre(0, 1, 2, 3, 4, 5, 6, 7));
  EXPECT_THAT(reduced_mesh.local_device_ids(), ElementsAre(4, 5, 6, 7));
  EXPECT_THAT(reduced_mesh.local_devices(),
              ElementsAre("/job:localhost/task:1/device:CPU:0",
                          "/job:localhost/task:1/device:CPU:1",
                          "/job:localhost/task:1/device:CPU:2",
                          "/job:localhost/task:1/device:CPU:3"));
  EXPECT_EQ(reduced_mesh.hosts().size(), 1);
}

286 // TODO(luispazos) Decide if we want this to be the case.
TEST_F(LayoutTest,FlippedShardedMultiHostMeshes)287 TEST_F(LayoutTest, FlippedShardedMultiHostMeshes) {
288   StatusOr<Layout> multi_host_layout_1 = Layout::FromString(
289       "sharding_specs:x,y, "
290       "mesh:|x=4,y=2|0,1,2,3,4,5,6,7|4,5,6,7|"
291       "/job:localhost/task:1/device:CPU:0,/job:localhost/task:1/device:CPU:1,"
292       "/job:localhost/task:1/device:CPU:2,/job:localhost/task:1/device:CPU:3");
293   StatusOr<Layout> multi_host_layout_2 = Layout::FromString(
294       "sharding_specs:x,y, "
295       "mesh:|x=4,y=2|0,1,2,3,4,5,6,7|6,7,4,5|"
296       "/job:localhost/task:1/device:CPU:2,/job:localhost/task:1/device:CPU:3,"
297       "/job:localhost/task:1/device:CPU:0,/job:localhost/task:1/device:CPU:1");
298 
299   Mesh reduced_mesh_1 = multi_host_layout_1->ReducedMesh();
300   Mesh reduced_mesh_2 = multi_host_layout_2->ReducedMesh();
301   EXPECT_FALSE(reduced_mesh_1 == reduced_mesh_2);
302 }
303 
// ShardVector equality is defined up to shard expansion: a single shard over
// 1 partition equals the full set of shards over 3 partitions.
TEST_F(LayoutTest, ShardEqualityOneDim) {
  ShardVector shard_vec1;
  Shard shard1{1};
  shard_vec1.shards.push_back(shard1);
  shard_vec1.num_shards_per_dim.push_back(1);

  ShardVector shard_vec2;
  Shard shard2{2};
  Shard shard3{3};
  shard_vec2.shards.push_back(shard1);
  shard_vec2.shards.push_back(shard2);
  shard_vec2.shards.push_back(shard3);
  shard_vec2.num_shards_per_dim.push_back(3);

  EXPECT_EQ(shard_vec1, shard_vec2);
}

// Equality also holds under proportional scaling: shard 3-of-3 covers the
// same region as shards 7,8,9 of 9.
TEST_F(LayoutTest, ShardEqualityOneDimOffset) {
  ShardVector shard_vec1;
  Shard shard1{3};
  shard_vec1.shards.push_back(shard1);
  shard_vec1.num_shards_per_dim.push_back(3);

  ShardVector shard_vec2;
  Shard shard2{7};
  Shard shard3{8};
  Shard shard4{9};
  shard_vec2.shards.push_back(shard2);
  shard_vec2.shards.push_back(shard3);
  shard_vec2.shards.push_back(shard4);
  shard_vec2.num_shards_per_dim.push_back(9);

  EXPECT_EQ(shard_vec1, shard_vec2);
}

// Two complete 2-D shard grids (2x4 and 3x3) cover the same full region and
// therefore compare equal.
TEST_F(LayoutTest, ShardEqualityTwoDims) {
  // Builds the full shard grid for the given per-dimension shard counts.
  auto GenFullVector = [](std::vector<int> num_shards_per_dim) -> ShardVector {
    ShardVector shard_vec;
    shard_vec.num_shards_per_dim = num_shards_per_dim;
    for (int i = 1; i <= num_shards_per_dim[0]; ++i)
      for (int j = 1; j <= num_shards_per_dim[1]; ++j) {
        Shard shard{i, j};
        shard_vec.shards.push_back(shard);
      }
    return shard_vec;
  };
  std::vector<int> num_shards_per_dim_1{2, 4};
  ShardVector shard_vec1 = GenFullVector(num_shards_per_dim_1);

  std::vector<int> num_shards_per_dim_2{3, 3};
  ShardVector shard_vec2 = GenFullVector(num_shards_per_dim_2);
  EXPECT_EQ(shard_vec1, shard_vec2);
}

// GetShardVector enumerates shards in row-major order for an x,y layout.
TEST_F(LayoutTest, Shards) {
  Layout layout =
      Layout::FromString("sharding_specs:x,y, mesh:|x=2,y=3|*CPU").ValueOrDie();
  ShardVector shard_vec = layout.GetShardVector();

  std::string expected_shard_vec_str =
      "shards:[(1,1),(1,2),(1,3),(2,1),(2,2),(2,3)] num_shards_per_dim:(2,3)";
  EXPECT_EQ(shard_vec.ToString(), expected_shard_vec_str);
}

// Transposing the sharding specs (y,x) transposes the shard enumeration.
TEST_F(LayoutTest, ShardsInverted) {
  Layout layout =
      Layout::FromString("sharding_specs:y,x, mesh:|x=2,y=3|*CPU").ValueOrDie();
  ShardVector shards = layout.GetShardVector();
  std::string expected_shards =
      "shards:[(1,1),(2,1),(3,1),(1,2),(2,2),(3,2)] num_shards_per_dim:(3,2)";
  EXPECT_EQ(shards.ToString(), expected_shards);
}

// On a single-host mesh, HostShardMap assigns all shards to that host.
TEST_F(LayoutTest, HostShardMap) {
  Layout layout =
      Layout::FromString("sharding_specs:x,y, mesh:TPU|x=2,y=2|*TPU")
          .ValueOrDie();
  std::string host_name = layout.mesh().hosts()[0];
  auto host_map = layout.HostShardMap();

  std::string expected_shards =
      "shards:[(1,1),(1,2),(2,1),(2,2)] num_shards_per_dim:(2,2)";
  EXPECT_EQ(host_map.find(host_name)->second.ToString(), expected_shards);
}

// With two hosts and two devices each on an x=4 mesh, HostShardMap splits the
// four x-shards evenly: (1,1),(2,1) on host 1 and (3,1),(4,1) on host 2.
TEST_F(LayoutTest, MultiHostMultiDeviceShards) {
  std::string host1 = "/job:localhost/task:0";
  std::string host2 = "/job:localhost/task:1";
  std::string device1 = "/device:TPU:0";
  std::string device2 = "/device:TPU:1";
  Layout layout =
      Layout::FromString(
          "sharding_specs:x,unsharded, mesh:TPU|x=4,y=1|0,1,2,3|0,1,2,3|" +
          host1 + device1 + "," + host1 + device2 + "," + host2 + device1 +
          "," + host2 + device2)
          .ValueOrDie();
  std::string expected_shard_vec =
      "shards:[(1,1),(2,1),(3,1),(4,1)] num_shards_per_dim:(4,1)";
  EXPECT_EQ(layout.GetShardVector().ToString(), expected_shard_vec);

  std::map<std::string, ShardVector> host_shard_map = layout.HostShardMap();

  std::string expected_shards_host1 =
      "shards:[(1,1),(2,1)] num_shards_per_dim:(4,1)";
  ShardVector host1_shard_vec = host_shard_map.find(host1)->second;
  EXPECT_EQ(host1_shard_vec.ToString(), expected_shards_host1);

  std::string expected_shards_host2 =
      "shards:[(3,1),(4,1)] num_shards_per_dim:(4,1)";
  ShardVector host2_shard_vec = host_shard_map.find(host2)->second;
  EXPECT_EQ(host2_shard_vec.ToString(), expected_shards_host2);
}

// Send (CPU) and receive (TPU) layouts over the same hosts reduce to meshes
// that agree on their host lists, regardless of device type and transposed
// sharding specs.
TEST_F(LayoutTest, MultiHostCommXYSharded) {
  std::string host_0 = "/job:localhost/task:0/";
  std::string host_1 = "/job:localhost/task:1/";

  StatusOr<Layout> send_layout =
      Layout::FromString("sharding_specs:y,x, mesh:|x=2,y=2|0,1,2,3|0,1,2,3|" +
                         host_0 + "device:CPU:0," + host_0 + "device:CPU:1," +
                         host_1 + "device:CPU:0," + host_1 + "device:CPU:1");
  StatusOr<Layout> recv_layout =
      Layout::FromString("sharding_specs:x,y, mesh:|x=2,y=2|0,1,2,3|0,1,2,3|" +
                         host_0 + "device:TPU:0," + host_0 + "device:TPU:1," +
                         host_1 + "device:TPU:0," + host_1 + "device:TPU:1");

  std::vector<std::string> send_hosts = send_layout->ReducedMesh().hosts();
  std::vector<std::string> recv_hosts = recv_layout->ReducedMesh().hosts();
  EXPECT_TRUE(send_hosts == recv_hosts);
}

// For x-sharded send/recv layouts over the same hosts, both the reduced-mesh
// host lists and the per-host shard assignments must match.
TEST_F(LayoutTest, MultiHostCommXSharded) {
  std::vector<std::string> hosts{"/job:localhost/task:0",
                                 "/job:localhost/task:1"};

  StatusOr<Layout> send_layout = Layout::FromString(
      "sharding_specs:x, mesh:|x=2,y=2|0,1,2,3|0,1,2,3|" + hosts[0] +
      "/device:CPU:0," + hosts[0] + "/device:CPU:1," + hosts[1] +
      "/device:CPU:0," + hosts[1] + "/device:CPU:1");
  StatusOr<Layout> recv_layout = Layout::FromString(
      "sharding_specs:x, mesh:|x=2,y=2|0,1,2,3|0,1,2,3|" + hosts[0] +
      "/device:TPU:0," + hosts[0] + "/device:TPU:1," + hosts[1] +
      "/device:TPU:0," + hosts[1] + "/device:TPU:1");

  std::vector<std::string> send_hosts = send_layout->ReducedMesh().hosts();
  std::vector<std::string> recv_hosts = recv_layout->ReducedMesh().hosts();
  EXPECT_TRUE(send_hosts == recv_hosts);

  std::map<std::string, ShardVector> send_host_shard_map =
      send_layout->HostShardMap();
  std::map<std::string, ShardVector> recv_host_shard_map =
      recv_layout->HostShardMap();

  // Check shards match in each host.
  for (const std::string& host : hosts) {
    ShardVector shard_vec_in_send_host = send_host_shard_map.find(host)->second;
    ShardVector shard_vec_in_recv_host = recv_host_shard_map.find(host)->second;
    EXPECT_EQ(shard_vec_in_send_host, shard_vec_in_recv_host);
  }
}

// Transposed2D swaps the two sharding specs: x,y -> y,x.
TEST_F(LayoutTest, Transposed2DLayout) {
  auto layout =
      Layout::FromString("sharding_specs:x,y, mesh:|x=2,y=2|*CPU").ValueOrDie();
  auto expected_layout =
      Layout::FromString("sharding_specs:y,x, mesh:|x=2,y=2|*CPU").ValueOrDie();
  EXPECT_EQ(Layout::Transposed2D(layout).ValueOrDie(), expected_layout);
}

// Transposed2D only swaps the trailing two dims; leading batch dims stay put.
TEST_F(LayoutTest, Transposed2DLayoutWithBatch) {
  auto layout = Layout::FromString(
                    "sharding_specs:b1,b2,x,y, mesh:|x=2,y=2,b1=2,b2=2|*CPU")
                    .ValueOrDie();
  auto expected_layout =
      Layout::FromString(
          "sharding_specs:b1,b2,y,x, mesh:|x=2,y=2,b1=2,b2=2|*CPU")
          .ValueOrDie();
  EXPECT_EQ(Layout::Transposed2D(layout).ValueOrDie(), expected_layout);
}

// idx_for_dim returns each mesh dimension's position in declaration order.
TEST_F(LayoutTest, MeshDimensionIndex) {
  auto layout =
      Layout::FromString("sharding_specs:x,y, mesh:|x=2,y=2|*CPU").ValueOrDie();
  EXPECT_EQ(layout.mesh().idx_for_dim("x").ValueOrDie(), 0);
  EXPECT_EQ(layout.mesh().idx_for_dim("y").ValueOrDie(), 1);
}

// Truncate(1) keeps the sharding specs before the split point.
TEST_F(LayoutTest, TruncateBeginning) {
  auto layout = Layout::FromString("sharding_specs:x,y, mesh:CPU|x=2,y=2|*CPU")
                    .ValueOrDie();
  auto expected_layout =
      Layout::FromString("sharding_specs:x, mesh:CPU|x=2,y=2|*CPU")
          .ValueOrDie();
  EXPECT_EQ(layout.Truncate(/*split_point=*/1), expected_layout);
}

// Truncate(1, end=true) keeps the sharding specs after the split point.
TEST_F(LayoutTest, TruncateEnd) {
  auto layout = Layout::FromString("sharding_specs:x,y, mesh:CPU|x=2,y=2|*CPU")
                    .ValueOrDie();
  auto expected_layout =
      Layout::FromString("sharding_specs:y, mesh:CPU|x=2,y=2|*CPU")
          .ValueOrDie();
  EXPECT_EQ(layout.Truncate(/*split_point=*/1, /*end=*/true), expected_layout);
}

// Concatenating x-sharded and y-sharded layouts on the same mesh yields x,y.
TEST_F(LayoutTest, Concatenate) {
  auto layout_1 = Layout::FromString("sharding_specs:x, mesh:CPU|x=2,y=2|*CPU")
                      .ValueOrDie();
  auto layout_2 = Layout::FromString("sharding_specs:y, mesh:CPU|x=2,y=2|*CPU")
                      .ValueOrDie();
  auto expected_layout =
      Layout::FromString("sharding_specs:x,y, mesh:CPU|x=2,y=2|*CPU")
          .ValueOrDie();
  EXPECT_EQ(ConcatenateLayouts(layout_1, layout_2).ValueOrDie(),
            expected_layout);
}

// Concatenating layouts that live on different meshes is an error.
TEST_F(LayoutTest, ConcatenateDifferentMesh) {
  auto layout_1 =
      Layout::FromString("sharding_specs:x, mesh:CPU|x=2|*CPU").ValueOrDie();
  auto layout_2 =
      Layout::FromString("sharding_specs:y, mesh:CPU|y=2|*CPU").ValueOrDie();
  auto layout = ConcatenateLayouts(layout_1, layout_2);
  EXPECT_FALSE(layout.ok()) << layout.status();
}

// Concatenating layouts that both shard over the same mesh dim is an error.
TEST_F(LayoutTest, ConcatenateSameDimension) {
  auto layout_1 = Layout::FromString("sharding_specs:x, mesh:CPU|x=2,y=2|*CPU")
                      .ValueOrDie();
  auto layout_2 = Layout::FromString("sharding_specs:x, mesh:CPU|x=2,y=2|*CPU")
                      .ValueOrDie();
  auto layout = ConcatenateLayouts(layout_1, layout_2);
  EXPECT_FALSE(layout.ok()) << layout.status();
}

// The empty mesh has no device type.
TEST_F(LayoutTest, EmptyMeshDeviceType) {
  auto mesh = Mesh::Empty();
  EXPECT_EQ(mesh.device_type(), std::string());
}

// ToDeviceType("CPU") converts a TPU mesh into an equivalent CPU mesh with
// explicit device ids and fully qualified CPU device names.
TEST_F(LayoutTest, ConvertMeshDeviceType) {
  Mesh mesh = Mesh::FromString("mesh:|x=2,batch=1|*TPU").ValueOrDie();
  Mesh cpu_mesh = mesh.ToDeviceType("CPU").ValueOrDie();
  EXPECT_TRUE(cpu_mesh.is_cpu_mesh());

  std::string expected_task_name = "/job:localhost/replica:0/task:0/";
  Mesh expected_mesh =
      Mesh::FromString("mesh:|x=2,batch=1|0,1|0,1|" + expected_task_name +
                       "device:CPU:0," + expected_task_name + "device:CPU:1")
          .ValueOrDie();
  EXPECT_EQ(cpu_mesh, expected_mesh);
}

// Layouts are equivalent when they shard data identically: with y=1, sharding
// over y is a no-op, so x,y == x,unsharded but != unsharded,y.
TEST_F(LayoutTest, EquivalentLayout) {
  Layout fully_sharded =
      Layout::FromString("sharding_specs:x,y, mesh:|x=2,y=1|*TPU").ValueOrDie();
  Layout x_sharded =
      Layout::FromString("sharding_specs:x,unsharded, mesh:|x=2,y=1|*TPU")
          .ValueOrDie();
  Layout y_sharded =
      Layout::FromString("sharding_specs:unsharded,y, mesh:|x=2,y=1|*TPU")
          .ValueOrDie();

  EXPECT_TRUE(fully_sharded.IsEquivalent(x_sharded));
  EXPECT_TRUE(x_sharded.IsEquivalent(fully_sharded));
  EXPECT_FALSE(fully_sharded.IsEquivalent(y_sharded));
  EXPECT_FALSE(y_sharded.IsEquivalent(fully_sharded));
}

}  // namespace
}  // namespace dtensor
}  // namespace tensorflow