1 /*
2 * Copyright 2004 The WebRTC Project Authors. All rights reserved.
3 *
4 * Use of this source code is governed by a BSD-style license
5 * that can be found in the LICENSE file in the root of the source
6 * tree. An additional intellectual property rights grant can be found
7 * in the file PATENTS. All contributing project authors may
8 * be found in the AUTHORS file in the root of the source tree.
9 */
10
11 #include <string.h>
12
13 #include <algorithm>
14 #include <memory>
15 #include <string>
16 #include <vector>
17
18 #include "absl/memory/memory.h"
19 #include "rtc_base/async_packet_socket.h"
20 #include "rtc_base/async_tcp_socket.h"
21 #include "rtc_base/async_udp_socket.h"
22 #include "rtc_base/gunit.h"
23 #include "rtc_base/ip_address.h"
24 #include "rtc_base/logging.h"
25 #include "rtc_base/nat_server.h"
26 #include "rtc_base/nat_socket_factory.h"
27 #include "rtc_base/nat_types.h"
28 #include "rtc_base/net_helpers.h"
29 #include "rtc_base/network.h"
30 #include "rtc_base/physical_socket_server.h"
31 #include "rtc_base/socket.h"
32 #include "rtc_base/socket_address.h"
33 #include "rtc_base/socket_factory.h"
34 #include "rtc_base/socket_server.h"
35 #include "rtc_base/test_client.h"
36 #include "rtc_base/third_party/sigslot/sigslot.h"
37 #include "rtc_base/thread.h"
38 #include "rtc_base/virtual_socket_server.h"
39 #include "test/gtest.h"
40 #include "test/scoped_key_value_config.h"
41
42 namespace rtc {
43 namespace {
44
CheckReceive(TestClient * client,bool should_receive,const char * buf,size_t size)45 bool CheckReceive(TestClient* client,
46 bool should_receive,
47 const char* buf,
48 size_t size) {
49 return (should_receive) ? client->CheckNextPacket(buf, size, 0)
50 : client->CheckNoPacket();
51 }
52
CreateTestClient(SocketFactory * factory,const SocketAddress & local_addr)53 TestClient* CreateTestClient(SocketFactory* factory,
54 const SocketAddress& local_addr) {
55 return new TestClient(
56 absl::WrapUnique(AsyncUDPSocket::Create(factory, local_addr)));
57 }
58
CreateTCPTestClient(Socket * socket)59 TestClient* CreateTCPTestClient(Socket* socket) {
60 return new TestClient(std::make_unique<AsyncTCPSocket>(socket));
61 }
62
63 // Tests that when sending from internal_addr to external_addrs through the
64 // NAT type specified by nat_type, all external addrs receive the sent packet
65 // and, if exp_same is true, all use the same mapped-address on the NAT.
TestSend(SocketServer * internal,const SocketAddress & internal_addr,SocketServer * external,const SocketAddress external_addrs[4],NATType nat_type,bool exp_same)66 void TestSend(SocketServer* internal,
67 const SocketAddress& internal_addr,
68 SocketServer* external,
69 const SocketAddress external_addrs[4],
70 NATType nat_type,
71 bool exp_same) {
72 Thread th_int(internal);
73 Thread th_ext(external);
74
75 SocketAddress server_addr = internal_addr;
76 server_addr.SetPort(0); // Auto-select a port
77 NATServer* nat = new NATServer(nat_type, internal, server_addr, server_addr,
78 external, external_addrs[0]);
79 NATSocketFactory* natsf = new NATSocketFactory(
80 internal, nat->internal_udp_address(), nat->internal_tcp_address());
81
82 TestClient* in = CreateTestClient(natsf, internal_addr);
83 TestClient* out[4];
84 for (int i = 0; i < 4; i++)
85 out[i] = CreateTestClient(external, external_addrs[i]);
86
87 th_int.Start();
88 th_ext.Start();
89
90 const char* buf = "filter_test";
91 size_t len = strlen(buf);
92
93 in->SendTo(buf, len, out[0]->address());
94 SocketAddress trans_addr;
95 EXPECT_TRUE(out[0]->CheckNextPacket(buf, len, &trans_addr));
96
97 for (int i = 1; i < 4; i++) {
98 in->SendTo(buf, len, out[i]->address());
99 SocketAddress trans_addr2;
100 EXPECT_TRUE(out[i]->CheckNextPacket(buf, len, &trans_addr2));
101 bool are_same = (trans_addr == trans_addr2);
102 ASSERT_EQ(are_same, exp_same) << "same translated address";
103 ASSERT_NE(AF_UNSPEC, trans_addr.family());
104 ASSERT_NE(AF_UNSPEC, trans_addr2.family());
105 }
106
107 th_int.Stop();
108 th_ext.Stop();
109
110 delete nat;
111 delete natsf;
112 delete in;
113 for (int i = 0; i < 4; i++)
114 delete out[i];
115 }
116
117 // Tests that when sending from external_addrs to internal_addr, the packet
118 // is delivered according to the specified filter_ip and filter_port rules.
TestRecv(SocketServer * internal,const SocketAddress & internal_addr,SocketServer * external,const SocketAddress external_addrs[4],NATType nat_type,bool filter_ip,bool filter_port)119 void TestRecv(SocketServer* internal,
120 const SocketAddress& internal_addr,
121 SocketServer* external,
122 const SocketAddress external_addrs[4],
123 NATType nat_type,
124 bool filter_ip,
125 bool filter_port) {
126 Thread th_int(internal);
127 Thread th_ext(external);
128
129 SocketAddress server_addr = internal_addr;
130 server_addr.SetPort(0); // Auto-select a port
131 NATServer* nat = new NATServer(nat_type, internal, server_addr, server_addr,
132 external, external_addrs[0]);
133 NATSocketFactory* natsf = new NATSocketFactory(
134 internal, nat->internal_udp_address(), nat->internal_tcp_address());
135
136 TestClient* in = CreateTestClient(natsf, internal_addr);
137 TestClient* out[4];
138 for (int i = 0; i < 4; i++)
139 out[i] = CreateTestClient(external, external_addrs[i]);
140
141 th_int.Start();
142 th_ext.Start();
143
144 const char* buf = "filter_test";
145 size_t len = strlen(buf);
146
147 in->SendTo(buf, len, out[0]->address());
148 SocketAddress trans_addr;
149 EXPECT_TRUE(out[0]->CheckNextPacket(buf, len, &trans_addr));
150
151 out[1]->SendTo(buf, len, trans_addr);
152 EXPECT_TRUE(CheckReceive(in, !filter_ip, buf, len));
153
154 out[2]->SendTo(buf, len, trans_addr);
155 EXPECT_TRUE(CheckReceive(in, !filter_port, buf, len));
156
157 out[3]->SendTo(buf, len, trans_addr);
158 EXPECT_TRUE(CheckReceive(in, !filter_ip && !filter_port, buf, len));
159
160 th_int.Stop();
161 th_ext.Stop();
162
163 delete nat;
164 delete natsf;
165 delete in;
166 for (int i = 0; i < 4; i++)
167 delete out[i];
168 }
169
170 // Tests that NATServer allocates bindings properly.
TestBindings(SocketServer * internal,const SocketAddress & internal_addr,SocketServer * external,const SocketAddress external_addrs[4])171 void TestBindings(SocketServer* internal,
172 const SocketAddress& internal_addr,
173 SocketServer* external,
174 const SocketAddress external_addrs[4]) {
175 TestSend(internal, internal_addr, external, external_addrs, NAT_OPEN_CONE,
176 true);
177 TestSend(internal, internal_addr, external, external_addrs,
178 NAT_ADDR_RESTRICTED, true);
179 TestSend(internal, internal_addr, external, external_addrs,
180 NAT_PORT_RESTRICTED, true);
181 TestSend(internal, internal_addr, external, external_addrs, NAT_SYMMETRIC,
182 false);
183 }
184
185 // Tests that NATServer filters packets properly.
TestFilters(SocketServer * internal,const SocketAddress & internal_addr,SocketServer * external,const SocketAddress external_addrs[4])186 void TestFilters(SocketServer* internal,
187 const SocketAddress& internal_addr,
188 SocketServer* external,
189 const SocketAddress external_addrs[4]) {
190 TestRecv(internal, internal_addr, external, external_addrs, NAT_OPEN_CONE,
191 false, false);
192 TestRecv(internal, internal_addr, external, external_addrs,
193 NAT_ADDR_RESTRICTED, true, false);
194 TestRecv(internal, internal_addr, external, external_addrs,
195 NAT_PORT_RESTRICTED, true, true);
196 TestRecv(internal, internal_addr, external, external_addrs, NAT_SYMMETRIC,
197 true, true);
198 }
199
TestConnectivity(const SocketAddress & src,const IPAddress & dst)200 bool TestConnectivity(const SocketAddress& src, const IPAddress& dst) {
201 // The physical NAT tests require connectivity to the selected ip from the
202 // internal address used for the NAT. Things like firewalls can break that, so
203 // check to see if it's worth even trying with this ip.
204 std::unique_ptr<PhysicalSocketServer> pss(new PhysicalSocketServer());
205 std::unique_ptr<Socket> client(pss->CreateSocket(src.family(), SOCK_DGRAM));
206 std::unique_ptr<Socket> server(pss->CreateSocket(src.family(), SOCK_DGRAM));
207 if (client->Bind(SocketAddress(src.ipaddr(), 0)) != 0 ||
208 server->Bind(SocketAddress(dst, 0)) != 0) {
209 return false;
210 }
211 const char* buf = "hello other socket";
212 size_t len = strlen(buf);
213 int sent = client->SendTo(buf, len, server->GetLocalAddress());
214 SocketAddress addr;
215 const size_t kRecvBufSize = 64;
216 char recvbuf[kRecvBufSize];
217 Thread::Current()->SleepMs(100);
218 int received = server->RecvFrom(recvbuf, kRecvBufSize, &addr, nullptr);
219 return received == sent && ::memcmp(buf, recvbuf, len) == 0;
220 }
221
TestPhysicalInternal(const SocketAddress & int_addr)222 void TestPhysicalInternal(const SocketAddress& int_addr) {
223 webrtc::test::ScopedKeyValueConfig field_trials;
224 rtc::AutoThread main_thread;
225 PhysicalSocketServer socket_server;
226 BasicNetworkManager network_manager(nullptr, &socket_server, &field_trials);
227 network_manager.StartUpdating();
228 // Process pending messages so the network list is updated.
229 Thread::Current()->ProcessMessages(0);
230
231 std::vector<const Network*> networks = network_manager.GetNetworks();
232 networks.erase(std::remove_if(networks.begin(), networks.end(),
233 [](const rtc::Network* network) {
234 return rtc::kDefaultNetworkIgnoreMask &
235 network->type();
236 }),
237 networks.end());
238 if (networks.empty()) {
239 RTC_LOG(LS_WARNING) << "Not enough network adapters for test.";
240 return;
241 }
242
243 SocketAddress ext_addr1(int_addr);
244 SocketAddress ext_addr2;
245 // Find an available IP with matching family. The test breaks if int_addr
246 // can't talk to ip, so check for connectivity as well.
247 for (const Network* const network : networks) {
248 const IPAddress& ip = network->GetBestIP();
249 if (ip.family() == int_addr.family() && TestConnectivity(int_addr, ip)) {
250 ext_addr2.SetIP(ip);
251 break;
252 }
253 }
254 if (ext_addr2.IsNil()) {
255 RTC_LOG(LS_WARNING) << "No available IP of same family as "
256 << int_addr.ToString();
257 return;
258 }
259
260 RTC_LOG(LS_INFO) << "selected ip " << ext_addr2.ipaddr().ToString();
261
262 SocketAddress ext_addrs[4] = {
263 SocketAddress(ext_addr1), SocketAddress(ext_addr2),
264 SocketAddress(ext_addr1), SocketAddress(ext_addr2)};
265
266 std::unique_ptr<PhysicalSocketServer> int_pss(new PhysicalSocketServer());
267 std::unique_ptr<PhysicalSocketServer> ext_pss(new PhysicalSocketServer());
268
269 TestBindings(int_pss.get(), int_addr, ext_pss.get(), ext_addrs);
270 TestFilters(int_pss.get(), int_addr, ext_pss.get(), ext_addrs);
271 }
272
// Runs the physical NAT tests with the NAT's internal side on IPv4 loopback.
TEST(NatTest, TestPhysicalIPv4) {
  const SocketAddress loopback_v4("127.0.0.1", 0);
  TestPhysicalInternal(loopback_v4);
}
276
// Runs the physical NAT tests on IPv6 loopback. If HasIPv6Enabled() reports
// no IPv6, logs a warning and does nothing.
TEST(NatTest, TestPhysicalIPv6) {
  if (!HasIPv6Enabled()) {
    RTC_LOG(LS_WARNING) << "No IPv6, skipping";
    return;
  }
  const SocketAddress loopback_v6("::1", 0);
  TestPhysicalInternal(loopback_v6);
}
284
285 namespace {
286
287 class TestVirtualSocketServer : public VirtualSocketServer {
288 public:
289 // Expose this publicly
GetNextIP(int af)290 IPAddress GetNextIP(int af) { return VirtualSocketServer::GetNextIP(af); }
291 };
292
293 } // namespace
294
TestVirtualInternal(int family)295 void TestVirtualInternal(int family) {
296 rtc::AutoThread main_thread;
297 std::unique_ptr<TestVirtualSocketServer> int_vss(
298 new TestVirtualSocketServer());
299 std::unique_ptr<TestVirtualSocketServer> ext_vss(
300 new TestVirtualSocketServer());
301
302 SocketAddress int_addr;
303 SocketAddress ext_addrs[4];
304 int_addr.SetIP(int_vss->GetNextIP(family));
305 ext_addrs[0].SetIP(ext_vss->GetNextIP(int_addr.family()));
306 ext_addrs[1].SetIP(ext_vss->GetNextIP(int_addr.family()));
307 ext_addrs[2].SetIP(ext_addrs[0].ipaddr());
308 ext_addrs[3].SetIP(ext_addrs[1].ipaddr());
309
310 TestBindings(int_vss.get(), int_addr, ext_vss.get(), ext_addrs);
311 TestFilters(int_vss.get(), int_addr, ext_vss.get(), ext_addrs);
312 }
313
// Runs the virtual-socket NAT tests with IPv4 addresses.
TEST(NatTest, TestVirtualIPv4) {
  const int family = AF_INET;
  TestVirtualInternal(family);
}
317
// Runs the virtual-socket NAT tests with IPv6 addresses. If HasIPv6Enabled()
// reports no IPv6, logs a warning and does nothing.
TEST(NatTest, TestVirtualIPv6) {
  if (!HasIPv6Enabled()) {
    RTC_LOG(LS_WARNING) << "No IPv6, skipping";
    return;
  }
  TestVirtualInternal(AF_INET6);
}
325
326 class NatTcpTest : public ::testing::Test, public sigslot::has_slots<> {
327 public:
NatTcpTest()328 NatTcpTest()
329 : int_addr_("192.168.0.1", 0),
330 ext_addr_("10.0.0.1", 0),
331 connected_(false),
332 int_vss_(new TestVirtualSocketServer()),
333 ext_vss_(new TestVirtualSocketServer()),
334 int_thread_(new Thread(int_vss_.get())),
335 ext_thread_(new Thread(ext_vss_.get())),
336 nat_(new NATServer(NAT_OPEN_CONE,
337 int_vss_.get(),
338 int_addr_,
339 int_addr_,
340 ext_vss_.get(),
341 ext_addr_)),
342 natsf_(new NATSocketFactory(int_vss_.get(),
343 nat_->internal_udp_address(),
344 nat_->internal_tcp_address())) {
345 int_thread_->Start();
346 ext_thread_->Start();
347 }
348
OnConnectEvent(Socket * socket)349 void OnConnectEvent(Socket* socket) { connected_ = true; }
350
OnAcceptEvent(Socket * socket)351 void OnAcceptEvent(Socket* socket) {
352 accepted_.reset(server_->Accept(nullptr));
353 }
354
OnCloseEvent(Socket * socket,int error)355 void OnCloseEvent(Socket* socket, int error) {}
356
ConnectEvents()357 void ConnectEvents() {
358 server_->SignalReadEvent.connect(this, &NatTcpTest::OnAcceptEvent);
359 client_->SignalConnectEvent.connect(this, &NatTcpTest::OnConnectEvent);
360 }
361
362 SocketAddress int_addr_;
363 SocketAddress ext_addr_;
364 bool connected_;
365 std::unique_ptr<TestVirtualSocketServer> int_vss_;
366 std::unique_ptr<TestVirtualSocketServer> ext_vss_;
367 std::unique_ptr<Thread> int_thread_;
368 std::unique_ptr<Thread> ext_thread_;
369 std::unique_ptr<NATServer> nat_;
370 std::unique_ptr<NATSocketFactory> natsf_;
371 std::unique_ptr<Socket> client_;
372 std::unique_ptr<Socket> server_;
373 std::unique_ptr<Socket> accepted_;
374 };
375
// Connects a NAT-routed internal TCP client to an external listening socket,
// checks the endpoints the connection reports, and then checks that a packet
// makes it through in each direction.
TEST_F(NatTcpTest, DISABLED_TestConnectOut) {
  server_.reset(ext_vss_->CreateSocket(AF_INET, SOCK_STREAM));
  client_.reset(natsf_->CreateSocket(AF_INET, SOCK_STREAM));
  ASSERT_TRUE(server_ != nullptr);
  ASSERT_TRUE(client_ != nullptr);

  // Hook up the accept/connect slots before starting the handshake. The
  // events are emitted on the socket-server threads and could otherwise fire
  // before the slots are connected.
  ConnectEvents();

  EXPECT_EQ(0, server_->Bind(ext_addr_));
  EXPECT_EQ(0, server_->Listen(5));

  // Bind must succeed outright; the old EXPECT_GE(0, ...) also accepted -1.
  EXPECT_EQ(0, client_->Bind(int_addr_));
  const int connect_result = client_->Connect(server_->GetLocalAddress());
  // A non-blocking connect may legitimately report that it is in progress.
  EXPECT_TRUE(connect_result == 0 || client_->IsBlocking());

  EXPECT_TRUE_WAIT(connected_, 1000);
  // accepted_ is filled in by OnAcceptEvent(); don't dereference it before it
  // exists.
  ASSERT_TRUE_WAIT(accepted_ != nullptr, 1000);
  EXPECT_EQ(server_->GetLocalAddress(), client_->GetRemoteAddress());
  EXPECT_EQ(ext_addr_.ipaddr(), accepted_->GetRemoteAddress().ipaddr());

  std::unique_ptr<rtc::TestClient> in(CreateTCPTestClient(client_.release()));
  std::unique_ptr<rtc::TestClient> out(
      CreateTCPTestClient(accepted_.release()));

  const char* buf = "test_packet";
  const size_t len = strlen(buf);

  in->Send(buf, len);
  SocketAddress trans_addr;
  EXPECT_TRUE(out->CheckNextPacket(buf, len, &trans_addr));

  out->Send(buf, len);
  EXPECT_TRUE(in->CheckNextPacket(buf, len, &trans_addr));
}
405
406 } // namespace
407 } // namespace rtc
408