xref: /aosp_15_r20/external/webrtc/rtc_base/nat_unittest.cc (revision d9f758449e529ab9291ac668be2861e7a55c2422)
1 /*
2  *  Copyright 2004 The WebRTC Project Authors. All rights reserved.
3  *
4  *  Use of this source code is governed by a BSD-style license
5  *  that can be found in the LICENSE file in the root of the source
6  *  tree. An additional intellectual property rights grant can be found
7  *  in the file PATENTS.  All contributing project authors may
8  *  be found in the AUTHORS file in the root of the source tree.
9  */
10 
11 #include <string.h>
12 
13 #include <algorithm>
14 #include <memory>
15 #include <string>
16 #include <vector>
17 
18 #include "absl/memory/memory.h"
19 #include "rtc_base/async_packet_socket.h"
20 #include "rtc_base/async_tcp_socket.h"
21 #include "rtc_base/async_udp_socket.h"
22 #include "rtc_base/gunit.h"
23 #include "rtc_base/ip_address.h"
24 #include "rtc_base/logging.h"
25 #include "rtc_base/nat_server.h"
26 #include "rtc_base/nat_socket_factory.h"
27 #include "rtc_base/nat_types.h"
28 #include "rtc_base/net_helpers.h"
29 #include "rtc_base/network.h"
30 #include "rtc_base/physical_socket_server.h"
31 #include "rtc_base/socket.h"
32 #include "rtc_base/socket_address.h"
33 #include "rtc_base/socket_factory.h"
34 #include "rtc_base/socket_server.h"
35 #include "rtc_base/test_client.h"
36 #include "rtc_base/third_party/sigslot/sigslot.h"
37 #include "rtc_base/thread.h"
38 #include "rtc_base/virtual_socket_server.h"
39 #include "test/gtest.h"
40 #include "test/scoped_key_value_config.h"
41 
42 namespace rtc {
43 namespace {
44 
CheckReceive(TestClient * client,bool should_receive,const char * buf,size_t size)45 bool CheckReceive(TestClient* client,
46                   bool should_receive,
47                   const char* buf,
48                   size_t size) {
49   return (should_receive) ? client->CheckNextPacket(buf, size, 0)
50                           : client->CheckNoPacket();
51 }
52 
CreateTestClient(SocketFactory * factory,const SocketAddress & local_addr)53 TestClient* CreateTestClient(SocketFactory* factory,
54                              const SocketAddress& local_addr) {
55   return new TestClient(
56       absl::WrapUnique(AsyncUDPSocket::Create(factory, local_addr)));
57 }
58 
CreateTCPTestClient(Socket * socket)59 TestClient* CreateTCPTestClient(Socket* socket) {
60   return new TestClient(std::make_unique<AsyncTCPSocket>(socket));
61 }
62 
63 // Tests that when sending from internal_addr to external_addrs through the
64 // NAT type specified by nat_type, all external addrs receive the sent packet
65 // and, if exp_same is true, all use the same mapped-address on the NAT.
TestSend(SocketServer * internal,const SocketAddress & internal_addr,SocketServer * external,const SocketAddress external_addrs[4],NATType nat_type,bool exp_same)66 void TestSend(SocketServer* internal,
67               const SocketAddress& internal_addr,
68               SocketServer* external,
69               const SocketAddress external_addrs[4],
70               NATType nat_type,
71               bool exp_same) {
72   Thread th_int(internal);
73   Thread th_ext(external);
74 
75   SocketAddress server_addr = internal_addr;
76   server_addr.SetPort(0);  // Auto-select a port
77   NATServer* nat = new NATServer(nat_type, internal, server_addr, server_addr,
78                                  external, external_addrs[0]);
79   NATSocketFactory* natsf = new NATSocketFactory(
80       internal, nat->internal_udp_address(), nat->internal_tcp_address());
81 
82   TestClient* in = CreateTestClient(natsf, internal_addr);
83   TestClient* out[4];
84   for (int i = 0; i < 4; i++)
85     out[i] = CreateTestClient(external, external_addrs[i]);
86 
87   th_int.Start();
88   th_ext.Start();
89 
90   const char* buf = "filter_test";
91   size_t len = strlen(buf);
92 
93   in->SendTo(buf, len, out[0]->address());
94   SocketAddress trans_addr;
95   EXPECT_TRUE(out[0]->CheckNextPacket(buf, len, &trans_addr));
96 
97   for (int i = 1; i < 4; i++) {
98     in->SendTo(buf, len, out[i]->address());
99     SocketAddress trans_addr2;
100     EXPECT_TRUE(out[i]->CheckNextPacket(buf, len, &trans_addr2));
101     bool are_same = (trans_addr == trans_addr2);
102     ASSERT_EQ(are_same, exp_same) << "same translated address";
103     ASSERT_NE(AF_UNSPEC, trans_addr.family());
104     ASSERT_NE(AF_UNSPEC, trans_addr2.family());
105   }
106 
107   th_int.Stop();
108   th_ext.Stop();
109 
110   delete nat;
111   delete natsf;
112   delete in;
113   for (int i = 0; i < 4; i++)
114     delete out[i];
115 }
116 
117 // Tests that when sending from external_addrs to internal_addr, the packet
118 // is delivered according to the specified filter_ip and filter_port rules.
TestRecv(SocketServer * internal,const SocketAddress & internal_addr,SocketServer * external,const SocketAddress external_addrs[4],NATType nat_type,bool filter_ip,bool filter_port)119 void TestRecv(SocketServer* internal,
120               const SocketAddress& internal_addr,
121               SocketServer* external,
122               const SocketAddress external_addrs[4],
123               NATType nat_type,
124               bool filter_ip,
125               bool filter_port) {
126   Thread th_int(internal);
127   Thread th_ext(external);
128 
129   SocketAddress server_addr = internal_addr;
130   server_addr.SetPort(0);  // Auto-select a port
131   NATServer* nat = new NATServer(nat_type, internal, server_addr, server_addr,
132                                  external, external_addrs[0]);
133   NATSocketFactory* natsf = new NATSocketFactory(
134       internal, nat->internal_udp_address(), nat->internal_tcp_address());
135 
136   TestClient* in = CreateTestClient(natsf, internal_addr);
137   TestClient* out[4];
138   for (int i = 0; i < 4; i++)
139     out[i] = CreateTestClient(external, external_addrs[i]);
140 
141   th_int.Start();
142   th_ext.Start();
143 
144   const char* buf = "filter_test";
145   size_t len = strlen(buf);
146 
147   in->SendTo(buf, len, out[0]->address());
148   SocketAddress trans_addr;
149   EXPECT_TRUE(out[0]->CheckNextPacket(buf, len, &trans_addr));
150 
151   out[1]->SendTo(buf, len, trans_addr);
152   EXPECT_TRUE(CheckReceive(in, !filter_ip, buf, len));
153 
154   out[2]->SendTo(buf, len, trans_addr);
155   EXPECT_TRUE(CheckReceive(in, !filter_port, buf, len));
156 
157   out[3]->SendTo(buf, len, trans_addr);
158   EXPECT_TRUE(CheckReceive(in, !filter_ip && !filter_port, buf, len));
159 
160   th_int.Stop();
161   th_ext.Stop();
162 
163   delete nat;
164   delete natsf;
165   delete in;
166   for (int i = 0; i < 4; i++)
167     delete out[i];
168 }
169 
170 // Tests that NATServer allocates bindings properly.
TestBindings(SocketServer * internal,const SocketAddress & internal_addr,SocketServer * external,const SocketAddress external_addrs[4])171 void TestBindings(SocketServer* internal,
172                   const SocketAddress& internal_addr,
173                   SocketServer* external,
174                   const SocketAddress external_addrs[4]) {
175   TestSend(internal, internal_addr, external, external_addrs, NAT_OPEN_CONE,
176            true);
177   TestSend(internal, internal_addr, external, external_addrs,
178            NAT_ADDR_RESTRICTED, true);
179   TestSend(internal, internal_addr, external, external_addrs,
180            NAT_PORT_RESTRICTED, true);
181   TestSend(internal, internal_addr, external, external_addrs, NAT_SYMMETRIC,
182            false);
183 }
184 
185 // Tests that NATServer filters packets properly.
TestFilters(SocketServer * internal,const SocketAddress & internal_addr,SocketServer * external,const SocketAddress external_addrs[4])186 void TestFilters(SocketServer* internal,
187                  const SocketAddress& internal_addr,
188                  SocketServer* external,
189                  const SocketAddress external_addrs[4]) {
190   TestRecv(internal, internal_addr, external, external_addrs, NAT_OPEN_CONE,
191            false, false);
192   TestRecv(internal, internal_addr, external, external_addrs,
193            NAT_ADDR_RESTRICTED, true, false);
194   TestRecv(internal, internal_addr, external, external_addrs,
195            NAT_PORT_RESTRICTED, true, true);
196   TestRecv(internal, internal_addr, external, external_addrs, NAT_SYMMETRIC,
197            true, true);
198 }
199 
TestConnectivity(const SocketAddress & src,const IPAddress & dst)200 bool TestConnectivity(const SocketAddress& src, const IPAddress& dst) {
201   // The physical NAT tests require connectivity to the selected ip from the
202   // internal address used for the NAT. Things like firewalls can break that, so
203   // check to see if it's worth even trying with this ip.
204   std::unique_ptr<PhysicalSocketServer> pss(new PhysicalSocketServer());
205   std::unique_ptr<Socket> client(pss->CreateSocket(src.family(), SOCK_DGRAM));
206   std::unique_ptr<Socket> server(pss->CreateSocket(src.family(), SOCK_DGRAM));
207   if (client->Bind(SocketAddress(src.ipaddr(), 0)) != 0 ||
208       server->Bind(SocketAddress(dst, 0)) != 0) {
209     return false;
210   }
211   const char* buf = "hello other socket";
212   size_t len = strlen(buf);
213   int sent = client->SendTo(buf, len, server->GetLocalAddress());
214   SocketAddress addr;
215   const size_t kRecvBufSize = 64;
216   char recvbuf[kRecvBufSize];
217   Thread::Current()->SleepMs(100);
218   int received = server->RecvFrom(recvbuf, kRecvBufSize, &addr, nullptr);
219   return received == sent && ::memcmp(buf, recvbuf, len) == 0;
220 }
221 
TestPhysicalInternal(const SocketAddress & int_addr)222 void TestPhysicalInternal(const SocketAddress& int_addr) {
223   webrtc::test::ScopedKeyValueConfig field_trials;
224   rtc::AutoThread main_thread;
225   PhysicalSocketServer socket_server;
226   BasicNetworkManager network_manager(nullptr, &socket_server, &field_trials);
227   network_manager.StartUpdating();
228   // Process pending messages so the network list is updated.
229   Thread::Current()->ProcessMessages(0);
230 
231   std::vector<const Network*> networks = network_manager.GetNetworks();
232   networks.erase(std::remove_if(networks.begin(), networks.end(),
233                                 [](const rtc::Network* network) {
234                                   return rtc::kDefaultNetworkIgnoreMask &
235                                          network->type();
236                                 }),
237                  networks.end());
238   if (networks.empty()) {
239     RTC_LOG(LS_WARNING) << "Not enough network adapters for test.";
240     return;
241   }
242 
243   SocketAddress ext_addr1(int_addr);
244   SocketAddress ext_addr2;
245   // Find an available IP with matching family. The test breaks if int_addr
246   // can't talk to ip, so check for connectivity as well.
247   for (const Network* const network : networks) {
248     const IPAddress& ip = network->GetBestIP();
249     if (ip.family() == int_addr.family() && TestConnectivity(int_addr, ip)) {
250       ext_addr2.SetIP(ip);
251       break;
252     }
253   }
254   if (ext_addr2.IsNil()) {
255     RTC_LOG(LS_WARNING) << "No available IP of same family as "
256                         << int_addr.ToString();
257     return;
258   }
259 
260   RTC_LOG(LS_INFO) << "selected ip " << ext_addr2.ipaddr().ToString();
261 
262   SocketAddress ext_addrs[4] = {
263       SocketAddress(ext_addr1), SocketAddress(ext_addr2),
264       SocketAddress(ext_addr1), SocketAddress(ext_addr2)};
265 
266   std::unique_ptr<PhysicalSocketServer> int_pss(new PhysicalSocketServer());
267   std::unique_ptr<PhysicalSocketServer> ext_pss(new PhysicalSocketServer());
268 
269   TestBindings(int_pss.get(), int_addr, ext_pss.get(), ext_addrs);
270   TestFilters(int_pss.get(), int_addr, ext_pss.get(), ext_addrs);
271 }
272 
// Runs the physical-socket NAT checks with the IPv4 loopback address as the
// NAT's internal address.
TEST(NatTest, TestPhysicalIPv4) {
  const SocketAddress loopback("127.0.0.1", 0);
  TestPhysicalInternal(loopback);
}
276 
// Runs the physical-socket NAT checks with the IPv6 loopback address as the
// NAT's internal address. When the host has no IPv6, the test is reported as
// skipped rather than silently passing.
TEST(NatTest, TestPhysicalIPv6) {
  if (!HasIPv6Enabled()) {
    RTC_LOG(LS_WARNING) << "No IPv6, skipping";
    GTEST_SKIP() << "No IPv6, skipping";
  }
  TestPhysicalInternal(SocketAddress("::1", 0));
}
284 
285 namespace {
286 
287 class TestVirtualSocketServer : public VirtualSocketServer {
288  public:
289   // Expose this publicly
GetNextIP(int af)290   IPAddress GetNextIP(int af) { return VirtualSocketServer::GetNextIP(af); }
291 };
292 
293 }  // namespace
294 
TestVirtualInternal(int family)295 void TestVirtualInternal(int family) {
296   rtc::AutoThread main_thread;
297   std::unique_ptr<TestVirtualSocketServer> int_vss(
298       new TestVirtualSocketServer());
299   std::unique_ptr<TestVirtualSocketServer> ext_vss(
300       new TestVirtualSocketServer());
301 
302   SocketAddress int_addr;
303   SocketAddress ext_addrs[4];
304   int_addr.SetIP(int_vss->GetNextIP(family));
305   ext_addrs[0].SetIP(ext_vss->GetNextIP(int_addr.family()));
306   ext_addrs[1].SetIP(ext_vss->GetNextIP(int_addr.family()));
307   ext_addrs[2].SetIP(ext_addrs[0].ipaddr());
308   ext_addrs[3].SetIP(ext_addrs[1].ipaddr());
309 
310   TestBindings(int_vss.get(), int_addr, ext_vss.get(), ext_addrs);
311   TestFilters(int_vss.get(), int_addr, ext_vss.get(), ext_addrs);
312 }
313 
// Runs the virtual-network NAT checks with IPv4 addresses.
TEST(NatTest, TestVirtualIPv4) {
  const int kFamily = AF_INET;
  TestVirtualInternal(kFamily);
}
317 
// Runs the virtual-network NAT checks with IPv6 addresses. When the host has
// no IPv6, the test is reported as skipped rather than silently passing.
TEST(NatTest, TestVirtualIPv6) {
  if (!HasIPv6Enabled()) {
    RTC_LOG(LS_WARNING) << "No IPv6, skipping";
    GTEST_SKIP() << "No IPv6, skipping";
  }
  TestVirtualInternal(AF_INET6);
}
325 
326 class NatTcpTest : public ::testing::Test, public sigslot::has_slots<> {
327  public:
NatTcpTest()328   NatTcpTest()
329       : int_addr_("192.168.0.1", 0),
330         ext_addr_("10.0.0.1", 0),
331         connected_(false),
332         int_vss_(new TestVirtualSocketServer()),
333         ext_vss_(new TestVirtualSocketServer()),
334         int_thread_(new Thread(int_vss_.get())),
335         ext_thread_(new Thread(ext_vss_.get())),
336         nat_(new NATServer(NAT_OPEN_CONE,
337                            int_vss_.get(),
338                            int_addr_,
339                            int_addr_,
340                            ext_vss_.get(),
341                            ext_addr_)),
342         natsf_(new NATSocketFactory(int_vss_.get(),
343                                     nat_->internal_udp_address(),
344                                     nat_->internal_tcp_address())) {
345     int_thread_->Start();
346     ext_thread_->Start();
347   }
348 
OnConnectEvent(Socket * socket)349   void OnConnectEvent(Socket* socket) { connected_ = true; }
350 
OnAcceptEvent(Socket * socket)351   void OnAcceptEvent(Socket* socket) {
352     accepted_.reset(server_->Accept(nullptr));
353   }
354 
OnCloseEvent(Socket * socket,int error)355   void OnCloseEvent(Socket* socket, int error) {}
356 
ConnectEvents()357   void ConnectEvents() {
358     server_->SignalReadEvent.connect(this, &NatTcpTest::OnAcceptEvent);
359     client_->SignalConnectEvent.connect(this, &NatTcpTest::OnConnectEvent);
360   }
361 
362   SocketAddress int_addr_;
363   SocketAddress ext_addr_;
364   bool connected_;
365   std::unique_ptr<TestVirtualSocketServer> int_vss_;
366   std::unique_ptr<TestVirtualSocketServer> ext_vss_;
367   std::unique_ptr<Thread> int_thread_;
368   std::unique_ptr<Thread> ext_thread_;
369   std::unique_ptr<NATServer> nat_;
370   std::unique_ptr<NATSocketFactory> natsf_;
371   std::unique_ptr<Socket> client_;
372   std::unique_ptr<Socket> server_;
373   std::unique_ptr<Socket> accepted_;
374 };
375 
// Connects a TCP client created through the NAT socket factory to a server
// listening on the external network. Checks the addresses seen by both ends,
// then exchanges one packet in each direction.
TEST_F(NatTcpTest, DISABLED_TestConnectOut) {
  server_.reset(ext_vss_->CreateSocket(AF_INET, SOCK_STREAM));
  ASSERT_TRUE(server_ != nullptr);
  EXPECT_EQ(0, server_->Bind(ext_addr_));
  EXPECT_EQ(0, server_->Listen(5));

  client_.reset(natsf_->CreateSocket(AF_INET, SOCK_STREAM));
  ASSERT_TRUE(client_ != nullptr);
  // Connect the slots before calling Connect(). The socket servers are
  // attached to the fixture's already-started threads, so connect/accept
  // events may be delivered before control returns here.
  ConnectEvents();

  // Bind() and Connect() return 0 on success, so require exactly 0. An
  // upper-bound check would also accept a negative error result.
  EXPECT_EQ(0, client_->Bind(int_addr_));
  EXPECT_EQ(0, client_->Connect(server_->GetLocalAddress()));

  ASSERT_TRUE_WAIT(connected_, 1000);
  // accepted_ is filled in asynchronously by OnAcceptEvent; wait for it
  // instead of dereferencing a possibly-null pointer.
  ASSERT_TRUE_WAIT(accepted_ != nullptr, 1000);
  EXPECT_EQ(client_->GetRemoteAddress(), server_->GetLocalAddress());
  EXPECT_EQ(accepted_->GetRemoteAddress().ipaddr(), ext_addr_.ipaddr());

  std::unique_ptr<rtc::TestClient> in(CreateTCPTestClient(client_.release()));
  std::unique_ptr<rtc::TestClient> out(
      CreateTCPTestClient(accepted_.release()));

  const char* buf = "test_packet";
  const size_t len = strlen(buf);

  in->Send(buf, len);
  SocketAddress trans_addr;
  EXPECT_TRUE(out->CheckNextPacket(buf, len, &trans_addr));

  out->Send(buf, len);
  EXPECT_TRUE(in->CheckNextPacket(buf, len, &trans_addr));
}
405 
406 }  // namespace
407 }  // namespace rtc
408