xref: /aosp_15_r20/external/pytorch/test/cpp/rpc/test_wire_serialization.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <gtest/gtest.h>
2 
3 #include <c10/util/irange.h>
4 #include <torch/csrc/distributed/rpc/utils.h>
5 #include <torch/torch.h>
6 
7 #include <memory>
8 #include <string>
9 #include <vector>
10 
11 using ::testing::IsSubstring;
12 
// Round-trips a payload and a list of tensors through wireSerialize /
// wireDeserialize and checks that both come back unchanged.
TEST(WireSerialize, Base) {
  // Serializes copies of `payload` and `tensors`, deserializes the result, and
  // compares it against the inputs.
  auto run = [](const std::string& payload,
                const std::vector<at::Tensor>& tensors) {
    std::string serialized;
    {
      // The copies live only in this scope, so deserialization below works
      // purely from the serialized bytes.
      std::vector<char> mpayload(payload.begin(), payload.end());
      std::vector<at::Tensor> mtensors = tensors;
      serialized = torch::distributed::rpc::wireSerialize(
          std::move(mpayload), std::move(mtensors));
    }
    auto deser = torch::distributed::rpc::wireDeserialize(
        serialized.data(), serialized.size());
    // ASSERT (not EXPECT) on the sizes: the checks below compare the payload
    // bytes and index into deser.second, which would go out of bounds if the
    // sizes disagreed. A fatal failure returns from this lambda.
    ASSERT_EQ(payload.size(), deser.first.size());
    ASSERT_EQ(tensors.size(), deser.second.size());
    // Comparing as strings also covers the empty payload and prints both
    // sides on mismatch.
    EXPECT_EQ(payload, std::string(deser.first.begin(), deser.first.end()));
    for (const auto i : c10::irange(tensors.size())) {
      EXPECT_TRUE(torch::equal(tensors[i], deser.second[i])) << "tensor " << i;
    }
  };
  run("", {});
  run("hi", {});
  run("", {torch::randn({5, 5})});
  run("hi", {torch::randn({5, 5})});
  run("more", {torch::randn({5, 5}), torch::rand({10, 10})});
}
41 
// "Sparse" here means a view that uses only a small fraction of its (dense)
// underlying storage. Serializing such a view must send only the view's own
// elements, not the whole storage it points into.
TEST(WireSerialize, RecopySparseTensors) {
  // Take one 1K-element row of a 1K x 1K (1M-element) tensor and make sure we
  // don't send all 1M elements across the wire.
  constexpr size_t k1K = 1024;
  at::Tensor main = torch::randn({k1K, k1K});
  at::Tensor tiny = main.select(0, 2); // View of row 2; shares main's storage.
  EXPECT_EQ(tiny.numel(), k1K);
  EXPECT_EQ(tiny.storage().nbytes() / tiny.dtype().itemsize(), k1K * k1K);
  auto ser = torch::distributed::rpc::wireSerialize({}, {tiny});
  auto deser = torch::distributed::rpc::wireDeserialize(ser.data(), ser.size());
  // Guard the deser.second[0] access below.
  ASSERT_EQ(deser.second.size(), 1u);
  EXPECT_TRUE(torch::equal(tiny, deser.second[0]));
  // Serialized size: one row's bytes plus less than k1K bytes of overhead.
  EXPECT_LT(ser.size(), (tiny.element_size() * k1K) + k1K);
}
54 
// cloneSparseTensors should:
// - leave a tensor that uses its whole storage uncloned,
// - clone a view that uses a small fraction of a large storage,
// - pass a sparse-layout tensor through as the same tensor.
TEST(WireSerialize, CloneSparseTensors) {
  constexpr size_t k1K = 1024;
  at::Tensor big = torch::randn({k1K, k1K});
  auto v1 = torch::distributed::rpc::cloneSparseTensors({big});
  // NOTE(review): comparing the Storage handles directly with EXPECT_EQ may
  // fall back to an implicit bool conversion, in which case any two non-null
  // storages compare equal. is_alias_of compares the underlying storage.
  EXPECT_TRUE(v1.get(0).is_alias_of(big)); // Not cloned.

  at::Tensor tiny = big.select(0, 2); // View of row 2 of big.
  auto v2 = torch::distributed::rpc::cloneSparseTensors({tiny});
  // Comparing the addresses of storage() only shows the two tensors are
  // distinct objects; a new view of the same storage would also pass that
  // check. Verify the returned tensor no longer shares tiny's storage.
  EXPECT_FALSE(v2.get(0).is_alias_of(tiny)); // Cloned.
  EXPECT_TRUE(torch::equal(v2.get(0), tiny));

  at::Tensor sparse = at::empty({2, 3}, at::dtype<float>().layout(at::kSparse));
  auto v3 = torch::distributed::rpc::cloneSparseTensors({sparse});
  // There is no storage() to compare, but at least confirm identity.
  EXPECT_TRUE(v3.get(0).is_same(sparse));
}
71 
// Empty, garbage, or truncated input to wireDeserialize must throw an
// exception whose message names the kind of failure.
TEST(WireSerialize, Errors) {
  // Runs `f` and checks that it throws a std::exception whose what() contains
  // `msg`. Every failure is reported together with `msg`, so the failing case
  // can be identified.
  auto checkMessage = [](auto&& f, const char* msg) {
    try {
      f();
    } catch (const std::exception& e) {
      EXPECT_PRED_FORMAT2(IsSubstring, msg, e.what());
      return;
    } catch (...) {
      FAIL() << "expected a std::exception containing \"" << msg
             << "\", but a non-std::exception was thrown";
    }
    // Only reached when f() returned normally. This FAIL() sits outside the
    // try block, so a throwing FAIL() (--gtest_throw_on_failure) cannot be
    // caught by the handlers above.
    FAIL() << "expected an exception containing \"" << msg
           << "\", but none was thrown";
  };
  checkMessage(
      []() { (void)torch::distributed::rpc::wireDeserialize("", 0); },
      "failed parse");
  checkMessage(
      []() { (void)torch::distributed::rpc::wireDeserialize(" ", 1); },
      "failed parse");
  auto serialized =
      torch::distributed::rpc::wireSerialize({}, {torch::randn({5, 5})});
  checkMessage(
      [&]() {
        (void)torch::distributed::rpc::wireDeserialize(
            serialized.data(), serialized.size() / 2);
      },
      "failed bounds");
}
98 
99 // Enable this once JIT Pickler supports sparse tensors.
TEST(WireSerialize,DISABLED_Sparse)100 TEST(WireSerialize, DISABLED_Sparse) {
101   at::Tensor main = at::empty({2, 3}, at::dtype<float>().layout(at::kSparse));
102   auto ser = torch::distributed::rpc::wireSerialize({}, {main.to(at::kSparse)});
103   auto deser = torch::distributed::rpc::wireDeserialize(ser.data(), ser.size());
104   EXPECT_TRUE(torch::equal(main, deser.second[0]));
105 }
106