xref: /aosp_15_r20/external/cronet/net/socket/udp_socket_unittest.cc (revision 6777b5387eb2ff775bb5750e3f5d96f37fb7352b)
1 // Copyright 2012 The Chromium Authors
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4 
5 #include "net/socket/udp_socket.h"
6 
#include <algorithm>
#include <string>
#include <utility>
#include <vector>
8 
9 #include "base/containers/circular_deque.h"
10 #include "base/functional/bind.h"
11 #include "base/location.h"
12 #include "base/memory/raw_ptr.h"
13 #include "base/memory/weak_ptr.h"
14 #include "base/run_loop.h"
15 #include "base/scoped_clear_last_error.h"
16 #include "base/strings/string_number_conversions.h"
17 #include "base/task/single_thread_task_runner.h"
18 #include "base/test/scoped_feature_list.h"
19 #include "base/threading/thread.h"
20 #include "base/time/time.h"
21 #include "build/build_config.h"
22 #include "build/chromeos_buildflags.h"
23 #include "net/base/features.h"
24 #include "net/base/io_buffer.h"
25 #include "net/base/ip_address.h"
26 #include "net/base/ip_endpoint.h"
27 #include "net/base/net_errors.h"
28 #include "net/base/network_interfaces.h"
29 #include "net/base/test_completion_callback.h"
30 #include "net/log/net_log_event_type.h"
31 #include "net/log/net_log_source.h"
32 #include "net/log/test_net_log.h"
33 #include "net/log/test_net_log_util.h"
34 #include "net/socket/socket_test_util.h"
35 #include "net/socket/udp_client_socket.h"
36 #include "net/socket/udp_server_socket.h"
37 #include "net/socket/udp_socket_global_limits.h"
38 #include "net/test/gtest_util.h"
39 #include "net/test/test_with_task_environment.h"
40 #include "net/traffic_annotation/network_traffic_annotation_test_helper.h"
41 #include "testing/gmock/include/gmock/gmock.h"
42 #include "testing/gtest/include/gtest/gtest.h"
43 #include "testing/platform_test.h"
44 
45 #if !BUILDFLAG(IS_WIN)
46 #include <netinet/in.h>
47 #include <sys/socket.h>
48 #else
49 #include <winsock2.h>
50 #endif
51 
52 #if BUILDFLAG(IS_ANDROID)
53 #include "base/android/build_info.h"
54 #include "net/android/network_change_notifier_factory_android.h"
55 #include "net/base/network_change_notifier.h"
56 #endif
57 
58 #if BUILDFLAG(IS_IOS)
59 #include <TargetConditionals.h>
60 #endif
61 
62 #if BUILDFLAG(IS_MAC)
63 #include "base/mac/mac_util.h"
64 #endif  // BUILDFLAG(IS_MAC)
65 
66 using net::test::IsError;
67 using net::test::IsOk;
68 using testing::DoAll;
69 using testing::Not;
70 
71 namespace net {
72 
73 namespace {
74 
75 // Creates an address from ip address and port and writes it to |*address|.
CreateUDPAddress(const std::string & ip_str,uint16_t port,IPEndPoint * address)76 bool CreateUDPAddress(const std::string& ip_str,
77                       uint16_t port,
78                       IPEndPoint* address) {
79   IPAddress ip_address;
80   if (!ip_address.AssignFromIPLiteral(ip_str))
81     return false;
82 
83   *address = IPEndPoint(ip_address, port);
84   return true;
85 }
86 
87 class UDPSocketTest : public PlatformTest, public WithTaskEnvironment {
88  public:
UDPSocketTest()89   UDPSocketTest() : buffer_(base::MakeRefCounted<IOBufferWithSize>(kMaxRead)) {}
90 
91   // Blocks until data is read from the socket.
RecvFromSocket(UDPServerSocket * socket)92   std::string RecvFromSocket(UDPServerSocket* socket) {
93     TestCompletionCallback callback;
94 
95     int rv = socket->RecvFrom(buffer_.get(), kMaxRead, &recv_from_address_,
96                               callback.callback());
97     rv = callback.GetResult(rv);
98     if (rv < 0)
99       return std::string();
100     return std::string(buffer_->data(), rv);
101   }
102 
103   // Sends UDP packet.
104   // If |address| is specified, then it is used for the destination
105   // to send to. Otherwise, will send to the last socket this server
106   // received from.
SendToSocket(UDPServerSocket * socket,const std::string & msg)107   int SendToSocket(UDPServerSocket* socket, const std::string& msg) {
108     return SendToSocket(socket, msg, recv_from_address_);
109   }
110 
SendToSocket(UDPServerSocket * socket,std::string msg,const IPEndPoint & address)111   int SendToSocket(UDPServerSocket* socket,
112                    std::string msg,
113                    const IPEndPoint& address) {
114     scoped_refptr<StringIOBuffer> io_buffer =
115         base::MakeRefCounted<StringIOBuffer>(msg);
116     TestCompletionCallback callback;
117     int rv = socket->SendTo(io_buffer.get(), io_buffer->size(), address,
118                             callback.callback());
119     return callback.GetResult(rv);
120   }
121 
ReadSocket(UDPClientSocket * socket)122   std::string ReadSocket(UDPClientSocket* socket) {
123     return ReadSocket(socket, DSCP_DEFAULT, ECN_DEFAULT);
124   }
125 
ReadSocket(UDPClientSocket * socket,DiffServCodePoint dscp,EcnCodePoint ecn)126   std::string ReadSocket(UDPClientSocket* socket,
127                          DiffServCodePoint dscp,
128                          EcnCodePoint ecn) {
129     TestCompletionCallback callback;
130 
131     int rv = socket->Read(buffer_.get(), kMaxRead, callback.callback());
132     rv = callback.GetResult(rv);
133     if (rv < 0)
134       return std::string();
135 #if BUILDFLAG(IS_WIN)
136     // The DSCP value is not populated on Windows, in order to avoid incurring
137     // an extra system call.
138     EXPECT_EQ(socket->GetLastTos().dscp, DSCP_DEFAULT);
139 #else
140     EXPECT_EQ(socket->GetLastTos().dscp, dscp);
141 #endif
142     EXPECT_EQ(socket->GetLastTos().ecn, ecn);
143     return std::string(buffer_->data(), rv);
144   }
145 
146   // Writes specified message to the socket.
WriteSocket(UDPClientSocket * socket,const std::string & msg)147   int WriteSocket(UDPClientSocket* socket, const std::string& msg) {
148     scoped_refptr<StringIOBuffer> io_buffer =
149         base::MakeRefCounted<StringIOBuffer>(msg);
150     TestCompletionCallback callback;
151     int rv = socket->Write(io_buffer.get(), io_buffer->size(),
152                            callback.callback(), TRAFFIC_ANNOTATION_FOR_TESTS);
153     return callback.GetResult(rv);
154   }
155 
WriteSocketIgnoreResult(UDPClientSocket * socket,const std::string & msg)156   void WriteSocketIgnoreResult(UDPClientSocket* socket,
157                                const std::string& msg) {
158     WriteSocket(socket, msg);
159   }
160 
161   // And again for a bare socket
SendToSocket(UDPSocket * socket,std::string msg,const IPEndPoint & address)162   int SendToSocket(UDPSocket* socket,
163                    std::string msg,
164                    const IPEndPoint& address) {
165     auto io_buffer = base::MakeRefCounted<StringIOBuffer>(msg);
166     TestCompletionCallback callback;
167     int rv = socket->SendTo(io_buffer.get(), io_buffer->size(), address,
168                             callback.callback());
169     return callback.GetResult(rv);
170   }
171 
172   // Run unit test for a connection test.
173   // |use_nonblocking_io| is used to switch between overlapped and non-blocking
174   // IO on Windows. It has no effect in other ports.
175   void ConnectTest(bool use_nonblocking_io, bool use_async);
176 
177  protected:
178   static const int kMaxRead = 1024;
179   scoped_refptr<IOBufferWithSize> buffer_;
180   IPEndPoint recv_from_address_;
181 };
182 
183 const int UDPSocketTest::kMaxRead;
184 
ReadCompleteCallback(int * result_out,base::OnceClosure callback,int result)185 void ReadCompleteCallback(int* result_out,
186                           base::OnceClosure callback,
187                           int result) {
188   *result_out = result;
189   std::move(callback).Run();
190 }
191 
ConnectTest(bool use_nonblocking_io,bool use_async)192 void UDPSocketTest::ConnectTest(bool use_nonblocking_io, bool use_async) {
193   std::string simple_message("hello world!");
194   RecordingNetLogObserver net_log_observer;
195   // Setup the server to listen.
196   IPEndPoint server_address(IPAddress::IPv4Localhost(), 0 /* port */);
197   auto server =
198       std::make_unique<UDPServerSocket>(NetLog::Get(), NetLogSource());
199   if (use_nonblocking_io)
200     server->UseNonBlockingIO();
201   server->AllowAddressReuse();
202   ASSERT_THAT(server->Listen(server_address), IsOk());
203   // Get bound port.
204   ASSERT_THAT(server->GetLocalAddress(&server_address), IsOk());
205 
206   // Setup the client.
207   auto client = std::make_unique<UDPClientSocket>(
208       DatagramSocket::DEFAULT_BIND, NetLog::Get(), NetLogSource());
209   if (use_nonblocking_io)
210     client->UseNonBlockingIO();
211 
212   if (!use_async) {
213     EXPECT_THAT(client->Connect(server_address), IsOk());
214   } else {
215     TestCompletionCallback callback;
216     int rv = client->ConnectAsync(server_address, callback.callback());
217     if (rv != OK) {
218       ASSERT_EQ(rv, ERR_IO_PENDING);
219       rv = callback.WaitForResult();
220       EXPECT_EQ(rv, OK);
221     } else {
222       EXPECT_EQ(rv, OK);
223     }
224   }
225   // Client sends to the server.
226   EXPECT_EQ(simple_message.length(),
227             static_cast<size_t>(WriteSocket(client.get(), simple_message)));
228 
229   // Server waits for message.
230   std::string str = RecvFromSocket(server.get());
231   EXPECT_EQ(simple_message, str);
232 
233   // Server echoes reply.
234   EXPECT_EQ(simple_message.length(),
235             static_cast<size_t>(SendToSocket(server.get(), simple_message)));
236 
237   // Client waits for response.
238   str = ReadSocket(client.get());
239   EXPECT_EQ(simple_message, str);
240 
241   // Test asynchronous read. Server waits for message.
242   base::RunLoop run_loop;
243   int read_result = 0;
244   int rv = server->RecvFrom(buffer_.get(), kMaxRead, &recv_from_address_,
245                             base::BindOnce(&ReadCompleteCallback, &read_result,
246                                            run_loop.QuitClosure()));
247   EXPECT_THAT(rv, IsError(ERR_IO_PENDING));
248 
249   // Client sends to the server.
250   base::SingleThreadTaskRunner::GetCurrentDefault()->PostTask(
251       FROM_HERE,
252       base::BindOnce(&UDPSocketTest::WriteSocketIgnoreResult,
253                      base::Unretained(this), client.get(), simple_message));
254   run_loop.Run();
255   EXPECT_EQ(simple_message.length(), static_cast<size_t>(read_result));
256   EXPECT_EQ(simple_message, std::string(buffer_->data(), read_result));
257 
258   NetLogSource server_net_log_source = server->NetLog().source();
259   NetLogSource client_net_log_source = client->NetLog().source();
260 
261   // Delete sockets so they log their final events.
262   server.reset();
263   client.reset();
264 
265   // Check the server's log.
266   auto server_entries =
267       net_log_observer.GetEntriesForSource(server_net_log_source);
268   ASSERT_EQ(6u, server_entries.size());
269   EXPECT_TRUE(
270       LogContainsBeginEvent(server_entries, 0, NetLogEventType::SOCKET_ALIVE));
271   EXPECT_TRUE(LogContainsEvent(server_entries, 1,
272                                NetLogEventType::UDP_LOCAL_ADDRESS,
273                                NetLogEventPhase::NONE));
274   EXPECT_TRUE(LogContainsEvent(server_entries, 2,
275                                NetLogEventType::UDP_BYTES_RECEIVED,
276                                NetLogEventPhase::NONE));
277   EXPECT_TRUE(LogContainsEvent(server_entries, 3,
278                                NetLogEventType::UDP_BYTES_SENT,
279                                NetLogEventPhase::NONE));
280   EXPECT_TRUE(LogContainsEvent(server_entries, 4,
281                                NetLogEventType::UDP_BYTES_RECEIVED,
282                                NetLogEventPhase::NONE));
283   EXPECT_TRUE(
284       LogContainsEndEvent(server_entries, 5, NetLogEventType::SOCKET_ALIVE));
285 
286   // Check the client's log.
287   auto client_entries =
288       net_log_observer.GetEntriesForSource(client_net_log_source);
289   EXPECT_EQ(7u, client_entries.size());
290   EXPECT_TRUE(
291       LogContainsBeginEvent(client_entries, 0, NetLogEventType::SOCKET_ALIVE));
292   EXPECT_TRUE(
293       LogContainsBeginEvent(client_entries, 1, NetLogEventType::UDP_CONNECT));
294   EXPECT_TRUE(
295       LogContainsEndEvent(client_entries, 2, NetLogEventType::UDP_CONNECT));
296   EXPECT_TRUE(LogContainsEvent(client_entries, 3,
297                                NetLogEventType::UDP_BYTES_SENT,
298                                NetLogEventPhase::NONE));
299   EXPECT_TRUE(LogContainsEvent(client_entries, 4,
300                                NetLogEventType::UDP_BYTES_RECEIVED,
301                                NetLogEventPhase::NONE));
302   EXPECT_TRUE(LogContainsEvent(client_entries, 5,
303                                NetLogEventType::UDP_BYTES_SENT,
304                                NetLogEventPhase::NONE));
305   EXPECT_TRUE(
306       LogContainsEndEvent(client_entries, 6, NetLogEventType::SOCKET_ALIVE));
307 }
308 
// Runs ConnectTest() with the default (overlapped on Windows) IO mode, first
// with a synchronous Connect() and then with ConnectAsync().
TEST_F(UDPSocketTest, Connect) {
  // |use_nonblocking_io| has no effect in non-Windows ports.
  for (const bool use_async : {false, true})
    ConnectTest(/*use_nonblocking_io=*/false, use_async);
}
315 
316 #if BUILDFLAG(IS_WIN)
TEST_F(UDPSocketTest,ConnectNonBlocking)317 TEST_F(UDPSocketTest, ConnectNonBlocking) {
318   ConnectTest(true, false);
319   ConnectTest(true, true);
320 }
321 #endif
322 
// Verifies that receiving a datagram into a too-small buffer fails with
// ERR_MSG_TOO_BIG, and that the oversized datagram does not disturb the next
// receive.
TEST_F(UDPSocketTest, PartialRecv) {
  UDPServerSocket server_socket(nullptr, NetLogSource());
  ASSERT_THAT(server_socket.Listen(IPEndPoint(IPAddress::IPv4Localhost(), 0)),
              IsOk());
  IPEndPoint server_address;
  ASSERT_THAT(server_socket.GetLocalAddress(&server_address), IsOk());

  UDPClientSocket client_socket(DatagramSocket::DEFAULT_BIND, nullptr,
                                NetLogSource());
  ASSERT_THAT(client_socket.Connect(server_address), IsOk());

  const std::string test_packet("hello world!");
  ASSERT_EQ(static_cast<int>(test_packet.size()),
            WriteSocket(&client_socket, test_packet));

  // Receive into a 2-byte buffer, which is smaller than the datagram.
  // RecvFrom() is expected to report ERR_MSG_TOO_BIG rather than return a
  // truncated prefix of the packet.
  const int kPartialReadSize = 2;
  auto buffer = base::MakeRefCounted<IOBufferWithSize>(kPartialReadSize);
  TestCompletionCallback recv_callback;
  int rv =
      server_socket.RecvFrom(buffer.get(), kPartialReadSize,
                             &recv_from_address_, recv_callback.callback());
  rv = recv_callback.GetResult(rv);

  EXPECT_EQ(rv, ERR_MSG_TOO_BIG);

  // Send a different message. Receiving it in full shows that the oversized
  // datagram above was discarded rather than left at the head of the queue.
  const std::string second_packet("Second packet");
  ASSERT_EQ(static_cast<int>(second_packet.size()),
            WriteSocket(&client_socket, second_packet));

  // Read whole packet now.
  std::string received = RecvFromSocket(&server_socket);
  EXPECT_EQ(second_packet, received);
}
360 
361 #if BUILDFLAG(IS_APPLE) || BUILDFLAG(IS_ANDROID)
362 // - MacOS: requires root permissions on OSX 10.7+.
363 // - Android: devices attached to testbots don't have default network, so
364 // broadcasting to 255.255.255.255 returns error -109 (Address not reachable).
365 // crbug.com/139144.
366 #define MAYBE_LocalBroadcast DISABLED_LocalBroadcast
367 #else
368 #define MAYBE_LocalBroadcast LocalBroadcast
369 #endif
TEST_F(UDPSocketTest,MAYBE_LocalBroadcast)370 TEST_F(UDPSocketTest, MAYBE_LocalBroadcast) {
371   std::string first_message("first message"), second_message("second message");
372 
373   IPEndPoint listen_address;
374   ASSERT_TRUE(CreateUDPAddress("0.0.0.0", 0 /* port */, &listen_address));
375 
376   auto server1 =
377       std::make_unique<UDPServerSocket>(NetLog::Get(), NetLogSource());
378   auto server2 =
379       std::make_unique<UDPServerSocket>(NetLog::Get(), NetLogSource());
380   server1->AllowAddressReuse();
381   server1->AllowBroadcast();
382   server2->AllowAddressReuse();
383   server2->AllowBroadcast();
384 
385   EXPECT_THAT(server1->Listen(listen_address), IsOk());
386   // Get bound port.
387   EXPECT_THAT(server1->GetLocalAddress(&listen_address), IsOk());
388   EXPECT_THAT(server2->Listen(listen_address), IsOk());
389 
390   IPEndPoint broadcast_address;
391   ASSERT_TRUE(CreateUDPAddress("127.255.255.255", listen_address.port(),
392                                &broadcast_address));
393   ASSERT_EQ(static_cast<int>(first_message.size()),
394             SendToSocket(server1.get(), first_message, broadcast_address));
395   std::string str = RecvFromSocket(server1.get());
396   ASSERT_EQ(first_message, str);
397   str = RecvFromSocket(server2.get());
398   ASSERT_EQ(first_message, str);
399 
400   ASSERT_EQ(static_cast<int>(second_message.size()),
401             SendToSocket(server2.get(), second_message, broadcast_address));
402   str = RecvFromSocket(server1.get());
403   ASSERT_EQ(second_message, str);
404   str = RecvFromSocket(server2.get());
405   ASSERT_EQ(second_message, str);
406 }
407 
408 // ConnectRandomBind verifies RANDOM_BIND is handled correctly. It connects
409 // 1000 sockets and then verifies that the allocated port numbers satisfy the
410 // following 2 conditions:
411 //  1. Range from min port value to max is greater than 10000.
412 //  2. There is at least one port in the 5 buckets in the [min, max] range.
413 //
414 // These conditions are not enough to verify that the port numbers are truly
415 // random, but they are enough to protect from most common non-random port
416 // allocation strategies (e.g. counter, pool of available ports, etc.) False
417 // positive result is theoretically possible, but its probability is negligible.
TEST_F(UDPSocketTest,ConnectRandomBind)418 TEST_F(UDPSocketTest, ConnectRandomBind) {
419   const int kIterations = 1000;
420 
421   std::vector<int> used_ports;
422   for (int i = 0; i < kIterations; ++i) {
423     UDPClientSocket socket(DatagramSocket::RANDOM_BIND, nullptr,
424                            NetLogSource());
425     EXPECT_THAT(socket.Connect(IPEndPoint(IPAddress::IPv4Localhost(), 53)),
426                 IsOk());
427 
428     IPEndPoint client_address;
429     EXPECT_THAT(socket.GetLocalAddress(&client_address), IsOk());
430     used_ports.push_back(client_address.port());
431   }
432 
433   int min_port = *std::min_element(used_ports.begin(), used_ports.end());
434   int max_port = *std::max_element(used_ports.begin(), used_ports.end());
435   int range = max_port - min_port + 1;
436 
437   // Verify that the range of ports used by the random port allocator is wider
438   // than 10k. Assuming that socket implementation limits port range to 16k
439   // ports (default on Fuchsia) probability of false negative is below
440   // 10^-200.
441   static int kMinRange = 10000;
442   EXPECT_GT(range, kMinRange);
443 
444   static int kBuckets = 5;
445   std::vector<int> bucket_sizes(kBuckets, 0);
446   for (int port : used_ports) {
447     bucket_sizes[(port - min_port) * kBuckets / range] += 1;
448   }
449 
450   // Verify that there is at least one value in each bucket. Probability of
451   // false negative is below (kBuckets * (1 - 1 / kBuckets) ^ kIterations),
452   // which is less than 10^-96.
453   for (int size : bucket_sizes) {
454     EXPECT_GT(size, 0);
455   }
456 }
457 
// A socket opened for IPv4 must refuse to connect to an IPv6 destination and
// must not be left connected afterwards.
TEST_F(UDPSocketTest, ConnectFail) {
  UDPSocket socket(DatagramSocket::DEFAULT_BIND, nullptr, NetLogSource());
  EXPECT_THAT(socket.Open(ADDRESS_FAMILY_IPV4), IsOk());

  // Address family mismatch: the socket was created for IPv4.
  const IPEndPoint ipv6_destination(IPAddress::IPv6Localhost(), 53);
  EXPECT_THAT(socket.Connect(ipv6_destination), Not(IsOk()));

  // The failed Connect() must not leave the socket in a connected state.
  EXPECT_FALSE(socket.is_connected());
}
471 
472 // Similar to ConnectFail but UDPSocket adopts an opened socket instead of
473 // opening one directly.
TEST_F(UDPSocketTest,AdoptedSocket)474 TEST_F(UDPSocketTest, AdoptedSocket) {
475   auto socketfd =
476       CreatePlatformSocket(ConvertAddressFamily(ADDRESS_FAMILY_IPV4),
477                            SOCK_DGRAM, AF_UNIX ? 0 : IPPROTO_UDP);
478   UDPSocket socket(DatagramSocket::DEFAULT_BIND, nullptr, NetLogSource());
479 
480   EXPECT_THAT(socket.AdoptOpenedSocket(ADDRESS_FAMILY_IPV4, socketfd), IsOk());
481 
482   // Connect to an IPv6 address should fail since the socket was created for
483   // IPv4.
484   EXPECT_THAT(socket.Connect(net::IPEndPoint(IPAddress::IPv6Localhost(), 53)),
485               Not(IsOk()));
486 
487   // Make sure that UDPSocket actually closed the socket.
488   EXPECT_FALSE(socket.is_connected());
489 }
490 
491 // Tests that UDPSocket updates the global counter correctly.
TEST_F(UDPSocketTest,LimitAdoptSocket)492 TEST_F(UDPSocketTest, LimitAdoptSocket) {
493   ASSERT_EQ(0, GetGlobalUDPSocketCountForTesting());
494   {
495     // Creating a platform socket does not increase count.
496     auto socketfd =
497         CreatePlatformSocket(ConvertAddressFamily(ADDRESS_FAMILY_IPV4),
498                              SOCK_DGRAM, AF_UNIX ? 0 : IPPROTO_UDP);
499     ASSERT_EQ(0, GetGlobalUDPSocketCountForTesting());
500 
501     // Simply allocating a UDPSocket does not increase count.
502     UDPSocket socket(DatagramSocket::DEFAULT_BIND, nullptr, NetLogSource());
503     EXPECT_EQ(0, GetGlobalUDPSocketCountForTesting());
504 
505     // Calling AdoptOpenedSocket() allocates the socket and increases the global
506     // counter.
507     EXPECT_THAT(socket.AdoptOpenedSocket(ADDRESS_FAMILY_IPV4, socketfd),
508                 IsOk());
509     EXPECT_EQ(1, GetGlobalUDPSocketCountForTesting());
510 
511     // Connect to an IPv6 address should fail since the socket was created for
512     // IPv4.
513     EXPECT_THAT(socket.Connect(net::IPEndPoint(IPAddress::IPv6Localhost(), 53)),
514                 Not(IsOk()));
515 
516     // That Connect() failed doesn't change the global counter.
517     EXPECT_EQ(1, GetGlobalUDPSocketCountForTesting());
518   }
519   // Finally, destroying UDPSocket decrements the global counter.
520   EXPECT_EQ(0, GetGlobalUDPSocketCountForTesting());
521 }
522 
523 // In this test, we verify that connect() on a socket will have the effect
524 // of filtering reads on this socket only to data read from the destination
525 // we connected to.
526 //
527 // The purpose of this test is that some documentation indicates that connect
528 // binds the client's sends to send to a particular server endpoint, but does
529 // not bind the client's reads to only be from that endpoint, and that we need
530 // to always use recvfrom() to disambiguate.
TEST_F(UDPSocketTest,VerifyConnectBindsAddr)531 TEST_F(UDPSocketTest, VerifyConnectBindsAddr) {
532   std::string simple_message("hello world!");
533   std::string foreign_message("BAD MESSAGE TO GET!!");
534 
535   // Setup the first server to listen.
536   IPEndPoint server1_address(IPAddress::IPv4Localhost(), 0 /* port */);
537   UDPServerSocket server1(nullptr, NetLogSource());
538   ASSERT_THAT(server1.Listen(server1_address), IsOk());
539   // Get the bound port.
540   ASSERT_THAT(server1.GetLocalAddress(&server1_address), IsOk());
541 
542   // Setup the second server to listen.
543   IPEndPoint server2_address(IPAddress::IPv4Localhost(), 0 /* port */);
544   UDPServerSocket server2(nullptr, NetLogSource());
545   ASSERT_THAT(server2.Listen(server2_address), IsOk());
546 
547   // Setup the client, connected to server 1.
548   UDPClientSocket client(DatagramSocket::DEFAULT_BIND, nullptr, NetLogSource());
549   EXPECT_THAT(client.Connect(server1_address), IsOk());
550 
551   // Client sends to server1.
552   EXPECT_EQ(simple_message.length(),
553             static_cast<size_t>(WriteSocket(&client, simple_message)));
554 
555   // Server1 waits for message.
556   std::string str = RecvFromSocket(&server1);
557   EXPECT_EQ(simple_message, str);
558 
559   // Get the client's address.
560   IPEndPoint client_address;
561   EXPECT_THAT(client.GetLocalAddress(&client_address), IsOk());
562 
563   // Server2 sends reply.
564   EXPECT_EQ(foreign_message.length(),
565             static_cast<size_t>(
566                 SendToSocket(&server2, foreign_message, client_address)));
567 
568   // Server1 sends reply.
569   EXPECT_EQ(simple_message.length(),
570             static_cast<size_t>(
571                 SendToSocket(&server1, simple_message, client_address)));
572 
573   // Client waits for response.
574   str = ReadSocket(&client);
575   EXPECT_EQ(simple_message, str);
576 }
577 
TEST_F(UDPSocketTest,ClientGetLocalPeerAddresses)578 TEST_F(UDPSocketTest, ClientGetLocalPeerAddresses) {
579   struct TestData {
580     std::string remote_address;
581     std::string local_address;
582     bool may_fail;
583   } tests[] = {
584     {"127.0.00.1", "127.0.0.1", false},
585     {"::1", "::1", true},
586 #if !BUILDFLAG(IS_ANDROID) && !BUILDFLAG(IS_IOS)
587     // Addresses below are disabled on Android. See crbug.com/161248
588     // They are also disabled on iOS. See https://crbug.com/523225
589     {"192.168.1.1", "127.0.0.1", false},
590     {"2001:db8:0::42", "::1", true},
591 #endif
592   };
593   for (const auto& test : tests) {
594     SCOPED_TRACE(std::string("Connecting from ") + test.local_address +
595                  std::string(" to ") + test.remote_address);
596 
597     IPAddress ip_address;
598     EXPECT_TRUE(ip_address.AssignFromIPLiteral(test.remote_address));
599     IPEndPoint remote_address(ip_address, 80);
600     EXPECT_TRUE(ip_address.AssignFromIPLiteral(test.local_address));
601     IPEndPoint local_address(ip_address, 80);
602 
603     UDPClientSocket client(DatagramSocket::DEFAULT_BIND, nullptr,
604                            NetLogSource());
605     int rv = client.Connect(remote_address);
606     if (test.may_fail && rv == ERR_ADDRESS_UNREACHABLE) {
607       // Connect() may return ERR_ADDRESS_UNREACHABLE for IPv6
608       // addresses if IPv6 is not configured.
609       continue;
610     }
611 
612     EXPECT_LE(ERR_IO_PENDING, rv);
613 
614     IPEndPoint fetched_local_address;
615     rv = client.GetLocalAddress(&fetched_local_address);
616     EXPECT_THAT(rv, IsOk());
617 
618     // TODO(mbelshe): figure out how to verify the IP and port.
619     //                The port is dynamically generated by the udp stack.
620     //                The IP is the real IP of the client, not necessarily
621     //                loopback.
622     // EXPECT_EQ(local_address.address(), fetched_local_address.address());
623 
624     IPEndPoint fetched_remote_address;
625     rv = client.GetPeerAddress(&fetched_remote_address);
626     EXPECT_THAT(rv, IsOk());
627 
628     EXPECT_EQ(remote_address, fetched_remote_address);
629   }
630 }
631 
// A server bound to localhost:0 must report an allocated (non-zero) port on
// the requested address.
TEST_F(UDPSocketTest, ServerGetLocalAddress) {
  IPEndPoint bind_address(IPAddress::IPv4Localhost(), 0);
  UDPServerSocket server(nullptr, NetLogSource());
  ASSERT_THAT(server.Listen(bind_address), IsOk());

  IPEndPoint local_address;
  // Use the IsOk() matcher, as elsewhere in this file, rather than
  // comparing against a bare 0.
  ASSERT_THAT(server.GetLocalAddress(&local_address), IsOk());

  // Verify that port was allocated.
  EXPECT_GT(local_address.port(), 0);
  EXPECT_EQ(local_address.address(), bind_address.address());
}
646 
// A listening server socket is not connected to any peer, so GetPeerAddress()
// must report ERR_SOCKET_NOT_CONNECTED.
TEST_F(UDPSocketTest, ServerGetPeerAddress) {
  UDPServerSocket server(nullptr, NetLogSource());
  IPEndPoint bind_address(IPAddress::IPv4Localhost(), 0);
  EXPECT_THAT(server.Listen(bind_address), IsOk());

  IPEndPoint peer;
  const int result = server.GetPeerAddress(&peer);
  EXPECT_EQ(result, ERR_SOCKET_NOT_CONNECTED);
}
657 
TEST_F(UDPSocketTest,ClientSetDoNotFragment)658 TEST_F(UDPSocketTest, ClientSetDoNotFragment) {
659   for (std::string ip : {"127.0.0.1", "::1"}) {
660     UDPClientSocket client(DatagramSocket::DEFAULT_BIND, nullptr,
661                            NetLogSource());
662     IPAddress ip_address;
663     EXPECT_TRUE(ip_address.AssignFromIPLiteral(ip));
664     IPEndPoint remote_address(ip_address, 80);
665     int rv = client.Connect(remote_address);
666     // May fail on IPv6 is IPv6 is not configured.
667     if (ip_address.IsIPv6() && rv == ERR_ADDRESS_UNREACHABLE)
668       return;
669     EXPECT_THAT(rv, IsOk());
670 
671     rv = client.SetDoNotFragment();
672 #if BUILDFLAG(IS_IOS) || BUILDFLAG(IS_FUCHSIA)
673     // TODO(crbug.com/945590): IP_MTU_DISCOVER is not implemented on Fuchsia.
674     EXPECT_THAT(rv, IsError(ERR_NOT_IMPLEMENTED));
675 #elif BUILDFLAG(IS_MAC)
676     if (base::mac::MacOSMajorVersion() >= 11) {
677       EXPECT_THAT(rv, IsOk());
678     } else {
679       EXPECT_THAT(rv, IsError(ERR_NOT_IMPLEMENTED));
680     }
681 #else
682     EXPECT_THAT(rv, IsOk());
683 #endif
684   }
685 }
686 
TEST_F(UDPSocketTest,ServerSetDoNotFragment)687 TEST_F(UDPSocketTest, ServerSetDoNotFragment) {
688   for (std::string ip : {"127.0.0.1", "::1"}) {
689     IPEndPoint bind_address;
690     ASSERT_TRUE(CreateUDPAddress(ip, 0, &bind_address));
691     UDPServerSocket server(nullptr, NetLogSource());
692     int rv = server.Listen(bind_address);
693     // May fail on IPv6 is IPv6 is not configure
694     if (bind_address.address().IsIPv6() &&
695         (rv == ERR_ADDRESS_INVALID || rv == ERR_ADDRESS_UNREACHABLE))
696       return;
697     EXPECT_THAT(rv, IsOk());
698 
699     rv = server.SetDoNotFragment();
700 #if BUILDFLAG(IS_IOS) || BUILDFLAG(IS_FUCHSIA)
701     // TODO(crbug.com/945590): IP_MTU_DISCOVER is not implemented on Fuchsia.
702     EXPECT_THAT(rv, IsError(ERR_NOT_IMPLEMENTED));
703 #elif BUILDFLAG(IS_MAC)
704     if (base::mac::MacOSMajorVersion() >= 11) {
705       EXPECT_THAT(rv, IsOk());
706     } else {
707       EXPECT_THAT(rv, IsError(ERR_NOT_IMPLEMENTED));
708     }
709 #else
710     EXPECT_THAT(rv, IsOk());
711 #endif
712   }
713 }
714 
715 // Close the socket while read is pending.
TEST_F(UDPSocketTest,CloseWithPendingRead)716 TEST_F(UDPSocketTest, CloseWithPendingRead) {
717   IPEndPoint bind_address(IPAddress::IPv4Localhost(), 0);
718   UDPServerSocket server(nullptr, NetLogSource());
719   int rv = server.Listen(bind_address);
720   EXPECT_THAT(rv, IsOk());
721 
722   TestCompletionCallback callback;
723   IPEndPoint from;
724   rv = server.RecvFrom(buffer_.get(), kMaxRead, &from, callback.callback());
725   EXPECT_EQ(rv, ERR_IO_PENDING);
726 
727   server.Close();
728 
729   EXPECT_FALSE(callback.have_result());
730 }
731 
732 // Some Android devices do not support multicast.
// The ones supporting multicast need WifiManager.MulticastLock to enable it.
734 // http://goo.gl/jjAk9
735 #if !BUILDFLAG(IS_ANDROID)
TEST_F(UDPSocketTest,JoinMulticastGroup)736 TEST_F(UDPSocketTest, JoinMulticastGroup) {
737   const char kGroup[] = "237.132.100.17";
738 
739   IPAddress group_ip;
740   EXPECT_TRUE(group_ip.AssignFromIPLiteral(kGroup));
741 // TODO(https://github.com/google/gvisor/issues/3839): don't guard on
742 // OS_FUCHSIA.
743 #if BUILDFLAG(IS_WIN) || BUILDFLAG(IS_FUCHSIA)
744   IPEndPoint bind_address(IPAddress::AllZeros(group_ip.size()), 0 /* port */);
745 #else
746   IPEndPoint bind_address(group_ip, 0 /* port */);
747 #endif  // BUILDFLAG(IS_WIN) || BUILDFLAG(IS_FUCHSIA)
748 
749   UDPSocket socket(DatagramSocket::DEFAULT_BIND, nullptr, NetLogSource());
750   EXPECT_THAT(socket.Open(bind_address.GetFamily()), IsOk());
751 
752   EXPECT_THAT(socket.Bind(bind_address), IsOk());
753   EXPECT_THAT(socket.JoinGroup(group_ip), IsOk());
754   // Joining group multiple times.
755   EXPECT_NE(OK, socket.JoinGroup(group_ip));
756   EXPECT_THAT(socket.LeaveGroup(group_ip), IsOk());
757   // Leaving group multiple times.
758   EXPECT_NE(OK, socket.LeaveGroup(group_ip));
759 
760   socket.Close();
761 }
762 
763 // TODO(https://crbug.com/947115): failing on device on iOS 12.2.
764 // TODO(https://crbug.com/1227554): flaky on Mac 11.
765 #if BUILDFLAG(IS_IOS) || BUILDFLAG(IS_MAC)
766 #define MAYBE_SharedMulticastAddress DISABLED_SharedMulticastAddress
767 #else
768 #define MAYBE_SharedMulticastAddress SharedMulticastAddress
769 #endif
TEST_F(UDPSocketTest,MAYBE_SharedMulticastAddress)770 TEST_F(UDPSocketTest, MAYBE_SharedMulticastAddress) {
771   const char kGroup[] = "224.0.0.251";
772 
773   IPAddress group_ip;
774   ASSERT_TRUE(group_ip.AssignFromIPLiteral(kGroup));
775 // TODO(https://github.com/google/gvisor/issues/3839): don't guard on
776 // OS_FUCHSIA.
777 #if BUILDFLAG(IS_WIN) || BUILDFLAG(IS_FUCHSIA)
778   IPEndPoint receive_address(IPAddress::AllZeros(group_ip.size()),
779                              0 /* port */);
780 #else
781   IPEndPoint receive_address(group_ip, 0 /* port */);
782 #endif  // BUILDFLAG(IS_WIN) || BUILDFLAG(IS_FUCHSIA)
783 
784   NetworkInterfaceList interfaces;
785   ASSERT_TRUE(GetNetworkList(&interfaces, 0));
786   // The test fails with the Hyper-V switch interface (on the host side).
787   interfaces.erase(std::remove_if(interfaces.begin(), interfaces.end(),
788                                   [](const auto& iface) {
789                                     return iface.friendly_name.rfind(
790                                                "vEthernet", 0) == 0;
791                                   }),
792                    interfaces.end());
793   ASSERT_FALSE(interfaces.empty());
794 
795   // Setup first receiving socket.
796   UDPServerSocket socket1(nullptr, NetLogSource());
797   socket1.AllowAddressSharingForMulticast();
798   ASSERT_THAT(socket1.SetMulticastInterface(interfaces[0].interface_index),
799               IsOk());
800   ASSERT_THAT(socket1.Listen(receive_address), IsOk());
801   ASSERT_THAT(socket1.JoinGroup(group_ip), IsOk());
802   // Get the bound port.
803   ASSERT_THAT(socket1.GetLocalAddress(&receive_address), IsOk());
804 
805   // Setup second receiving socket.
806   UDPServerSocket socket2(nullptr, NetLogSource());
807   socket2.AllowAddressSharingForMulticast(), IsOk();
808   ASSERT_THAT(socket2.SetMulticastInterface(interfaces[0].interface_index),
809               IsOk());
810   ASSERT_THAT(socket2.Listen(receive_address), IsOk());
811   ASSERT_THAT(socket2.JoinGroup(group_ip), IsOk());
812 
813   // Setup client socket.
814   IPEndPoint send_address(group_ip, receive_address.port());
815   UDPClientSocket client_socket(DatagramSocket::DEFAULT_BIND, nullptr,
816                                 NetLogSource());
817   ASSERT_THAT(client_socket.Connect(send_address), IsOk());
818 
819 #if !BUILDFLAG(IS_CHROMEOS_ASH)
820   // Send a message via the multicast group. That message is expected be be
821   // received by both receving sockets.
822   //
823   // Skip on ChromeOS where it's known to sometimes not work.
824   // TODO(crbug.com/898964): If possible, fix and reenable.
825   const char kMessage[] = "hello!";
826   ASSERT_GE(WriteSocket(&client_socket, kMessage), 0);
827   EXPECT_EQ(kMessage, RecvFromSocket(&socket1));
828   EXPECT_EQ(kMessage, RecvFromSocket(&socket2));
829 #endif  // !BUILDFLAG(IS_CHROMEOS_ASH)
830 }
831 #endif  // !BUILDFLAG(IS_ANDROID)
832 
// Multicast options can be changed freely before the socket is bound (except
// for a negative TTL), but every setter is expected to fail after Bind().
TEST_F(UDPSocketTest, MulticastOptions) {
  IPEndPoint any_address;
  ASSERT_TRUE(CreateUDPAddress("0.0.0.0", 0 /* port */, &any_address));

  UDPSocket udp_socket(DatagramSocket::DEFAULT_BIND, nullptr, NetLogSource());

  // Not yet opened or bound: valid values are accepted.
  EXPECT_THAT(udp_socket.SetMulticastLoopbackMode(false), IsOk());
  EXPECT_THAT(udp_socket.SetMulticastLoopbackMode(true), IsOk());
  EXPECT_THAT(udp_socket.SetMulticastTimeToLive(0), IsOk());
  EXPECT_THAT(udp_socket.SetMulticastTimeToLive(3), IsOk());
  EXPECT_NE(OK, udp_socket.SetMulticastTimeToLive(-1));
  EXPECT_THAT(udp_socket.SetMulticastInterface(0), IsOk());

  EXPECT_THAT(udp_socket.Open(any_address.GetFamily()), IsOk());
  EXPECT_THAT(udp_socket.Bind(any_address), IsOk());

  // Once bound, the same setters are all rejected.
  EXPECT_NE(OK, udp_socket.SetMulticastLoopbackMode(false));
  EXPECT_NE(OK, udp_socket.SetMulticastTimeToLive(0));
  EXPECT_NE(OK, udp_socket.SetMulticastInterface(0));

  udp_socket.Close();
}
855 
856 // Checking that DSCP bits are set correctly is difficult,
857 // but let's check that the code doesn't crash at least.
TEST_F(UDPSocketTest,SetDSCP)858 TEST_F(UDPSocketTest, SetDSCP) {
859   // Setup the server to listen.
860   IPEndPoint bind_address;
861   UDPSocket client(DatagramSocket::DEFAULT_BIND, nullptr, NetLogSource());
862   // We need a real IP, but we won't actually send anything to it.
863   ASSERT_TRUE(CreateUDPAddress("8.8.8.8", 9999, &bind_address));
864   int rv = client.Open(bind_address.GetFamily());
865   EXPECT_THAT(rv, IsOk());
866 
867   rv = client.Connect(bind_address);
868   if (rv != OK) {
869     // Let's try localhost then.
870     bind_address = IPEndPoint(IPAddress::IPv4Localhost(), 9999);
871     rv = client.Connect(bind_address);
872   }
873   EXPECT_THAT(rv, IsOk());
874 
875   client.SetDiffServCodePoint(DSCP_NO_CHANGE);
876   client.SetDiffServCodePoint(DSCP_AF41);
877   client.SetDiffServCodePoint(DSCP_DEFAULT);
878   client.SetDiffServCodePoint(DSCP_CS2);
879   client.SetDiffServCodePoint(DSCP_NO_CHANGE);
880   client.SetDiffServCodePoint(DSCP_DEFAULT);
881   client.Close();
882 }
883 
884 // Send DSCP + ECN marked packets from server to client and verify the TOS
885 // bytes that arrive.
TEST_F(UDPSocketTest,VerifyDscpAndEcnExchange)886 TEST_F(UDPSocketTest, VerifyDscpAndEcnExchange) {
887   IPEndPoint server_address(IPAddress::IPv4Localhost(), 0);
888   UDPServerSocket server(nullptr, NetLogSource());
889   server.AllowAddressReuse();
890   ASSERT_THAT(server.Listen(server_address), IsOk());
891   // Get bound port.
892   ASSERT_THAT(server.GetLocalAddress(&server_address), IsOk());
893   UDPClientSocket client(DatagramSocket::DEFAULT_BIND, nullptr, NetLogSource());
894   client.Connect(server_address);
895   EXPECT_EQ(client.SetRecvTos(), 0);
896   IPEndPoint client_address;
897   client.GetLocalAddress(&client_address);
898 
899   EXPECT_EQ(server.SetTos(DSCP_AF41, ECN_ECT1), 0);
900   std::string first_message = "foobar";
901   EXPECT_EQ(SendToSocket(&server, first_message, client_address),
902             static_cast<int>(first_message.length()));
903   EXPECT_EQ(ReadSocket(&client, DSCP_AF41, ECN_ECT1), first_message.data());
904 
905   std::string second_message = "foo";
906   EXPECT_EQ(server.SetTos(DSCP_CS2, ECN_ECT0), 0);
907   EXPECT_EQ(SendToSocket(&server, second_message, client_address),
908             static_cast<int>(second_message.length()));
909   EXPECT_EQ(ReadSocket(&client, DSCP_CS2, ECN_ECT0), second_message.data());
910 
911 #if BUILDFLAG(IS_WIN)
912   // The Windows sendmsg API does not allow setting ECN_CE as the outgoing mark.
913   EcnCodePoint final_ecn = ECN_ECT1;
914 #else
915   EcnCodePoint final_ecn = ECN_CE;
916 #endif
917 
918   EXPECT_EQ(server.SetTos(DSCP_NO_CHANGE, final_ecn), 0);
919   EXPECT_EQ(SendToSocket(&server, second_message, client_address),
920             static_cast<int>(second_message.length()));
921   EXPECT_EQ(ReadSocket(&client, DSCP_CS2, final_ecn), second_message.data());
922 
923   EXPECT_EQ(server.SetTos(DSCP_AF41, ECN_NO_CHANGE), 0);
924   EXPECT_EQ(SendToSocket(&server, second_message, client_address),
925             static_cast<int>(second_message.length()));
926   EXPECT_EQ(ReadSocket(&client, DSCP_AF41, final_ecn), second_message.data());
927 
928   EXPECT_EQ(server.SetTos(DSCP_NO_CHANGE, ECN_NO_CHANGE), 0);
929   EXPECT_EQ(SendToSocket(&server, second_message, client_address),
930             static_cast<int>(second_message.length()));
931   EXPECT_EQ(ReadSocket(&client, DSCP_AF41, final_ecn), second_message.data());
932 
933   server.Close();
934   client.Close();
935 }
936 
937 // For windows, test with Nonblocking sockets. For other platforms, this test
938 // is identical to VerifyDscpAndEcnExchange, above.
TEST_F(UDPSocketTest,VerifyDscpAndEcnExchangeNonBlocking)939 TEST_F(UDPSocketTest, VerifyDscpAndEcnExchangeNonBlocking) {
940   IPEndPoint server_address(IPAddress::IPv4Localhost(), 0);
941   UDPServerSocket server(nullptr, NetLogSource());
942   server.UseNonBlockingIO();
943   server.AllowAddressReuse();
944   ASSERT_THAT(server.Listen(server_address), IsOk());
945   // Get bound port.
946   ASSERT_THAT(server.GetLocalAddress(&server_address), IsOk());
947   UDPClientSocket client(DatagramSocket::DEFAULT_BIND, nullptr, NetLogSource());
948   client.UseNonBlockingIO();
949   client.Connect(server_address);
950   EXPECT_EQ(client.SetRecvTos(), 0);
951   IPEndPoint client_address;
952   client.GetLocalAddress(&client_address);
953 
954   EXPECT_EQ(server.SetTos(DSCP_AF41, ECN_ECT1), 0);
955   std::string first_message = "foobar";
956   EXPECT_EQ(SendToSocket(&server, first_message, client_address),
957             static_cast<int>(first_message.length()));
958   EXPECT_EQ(ReadSocket(&client, DSCP_AF41, ECN_ECT1), first_message.data());
959 
960   std::string second_message = "foo";
961   EXPECT_EQ(server.SetTos(DSCP_CS2, ECN_ECT0), 0);
962   EXPECT_EQ(SendToSocket(&server, second_message, client_address),
963             static_cast<int>(second_message.length()));
964   EXPECT_EQ(ReadSocket(&client, DSCP_CS2, ECN_ECT0), second_message.data());
965 
966   // The Windows sendmsg API does not allow setting ECN_CE as the outgoing mark.
967   EcnCodePoint final_ecn = ECN_ECT1;
968 
969   EXPECT_EQ(server.SetTos(DSCP_NO_CHANGE, final_ecn), 0);
970   EXPECT_EQ(SendToSocket(&server, second_message, client_address),
971             static_cast<int>(second_message.length()));
972   EXPECT_EQ(ReadSocket(&client, DSCP_CS2, final_ecn), second_message.data());
973 
974   EXPECT_EQ(server.SetTos(DSCP_AF41, ECN_NO_CHANGE), 0);
975   EXPECT_EQ(SendToSocket(&server, second_message, client_address),
976             static_cast<int>(second_message.length()));
977   EXPECT_EQ(ReadSocket(&client, DSCP_AF41, final_ecn), second_message.data());
978 
979   EXPECT_EQ(server.SetTos(DSCP_NO_CHANGE, ECN_NO_CHANGE), 0);
980   EXPECT_EQ(SendToSocket(&server, second_message, client_address),
981             static_cast<int>(second_message.length()));
982   EXPECT_EQ(ReadSocket(&client, DSCP_AF41, final_ecn), second_message.data());
983 
984   server.Close();
985   client.Close();
986 }
987 
TEST_F(UDPSocketTest,ConnectUsingNetwork)988 TEST_F(UDPSocketTest, ConnectUsingNetwork) {
989   // The specific value of this address doesn't really matter, and no
990   // server needs to be running here. The test only needs to call
991   // ConnectUsingNetwork() and won't send any datagrams.
992   const IPEndPoint fake_server_address(IPAddress::IPv4Localhost(), 8080);
993   const handles::NetworkHandle wrong_network_handle = 65536;
994 #if BUILDFLAG(IS_ANDROID)
995   NetworkChangeNotifierFactoryAndroid ncn_factory;
996   NetworkChangeNotifier::DisableForTest ncn_disable_for_test;
997   std::unique_ptr<NetworkChangeNotifier> ncn(ncn_factory.CreateInstance());
998   if (!NetworkChangeNotifier::AreNetworkHandlesSupported())
999     GTEST_SKIP() << "Network handles are required to test BindToNetwork.";
1000 
1001   {
1002     // Connecting using a not existing network should fail but not report
1003     // ERR_NOT_IMPLEMENTED when network handles are supported.
1004     UDPClientSocket socket(DatagramSocket::RANDOM_BIND, nullptr,
1005                            NetLogSource());
1006     int rv =
1007         socket.ConnectUsingNetwork(wrong_network_handle, fake_server_address);
1008     EXPECT_NE(ERR_NOT_IMPLEMENTED, rv);
1009     EXPECT_NE(OK, rv);
1010     EXPECT_NE(wrong_network_handle, socket.GetBoundNetwork());
1011   }
1012 
1013   {
1014     // Connecting using an existing network should succeed when
1015     // NetworkChangeNotifier returns a valid default network.
1016     UDPClientSocket socket(DatagramSocket::RANDOM_BIND, nullptr,
1017                            NetLogSource());
1018     const handles::NetworkHandle network_handle =
1019         NetworkChangeNotifier::GetDefaultNetwork();
1020     if (network_handle != handles::kInvalidNetworkHandle) {
1021       EXPECT_EQ(
1022           OK, socket.ConnectUsingNetwork(network_handle, fake_server_address));
1023       EXPECT_EQ(network_handle, socket.GetBoundNetwork());
1024     }
1025   }
1026 #else
1027   UDPClientSocket socket(DatagramSocket::RANDOM_BIND, nullptr, NetLogSource());
1028   EXPECT_EQ(
1029       ERR_NOT_IMPLEMENTED,
1030       socket.ConnectUsingNetwork(wrong_network_handle, fake_server_address));
1031 #endif  // BUILDFLAG(IS_ANDROID)
1032 }
1033 
TEST_F(UDPSocketTest,ConnectUsingNetworkAsync)1034 TEST_F(UDPSocketTest, ConnectUsingNetworkAsync) {
1035   // The specific value of this address doesn't really matter, and no
1036   // server needs to be running here. The test only needs to call
1037   // ConnectUsingNetwork() and won't send any datagrams.
1038   const IPEndPoint fake_server_address(IPAddress::IPv4Localhost(), 8080);
1039   const handles::NetworkHandle wrong_network_handle = 65536;
1040 #if BUILDFLAG(IS_ANDROID)
1041   NetworkChangeNotifierFactoryAndroid ncn_factory;
1042   NetworkChangeNotifier::DisableForTest ncn_disable_for_test;
1043   std::unique_ptr<NetworkChangeNotifier> ncn(ncn_factory.CreateInstance());
1044   if (!NetworkChangeNotifier::AreNetworkHandlesSupported())
1045     GTEST_SKIP() << "Network handles are required to test BindToNetwork.";
1046 
1047   {
1048     // Connecting using a not existing network should fail but not report
1049     // ERR_NOT_IMPLEMENTED when network handles are supported.
1050     UDPClientSocket socket(DatagramSocket::RANDOM_BIND, nullptr,
1051                            NetLogSource());
1052     TestCompletionCallback callback;
1053     int rv = socket.ConnectUsingNetworkAsync(
1054         wrong_network_handle, fake_server_address, callback.callback());
1055 
1056     if (rv == ERR_IO_PENDING) {
1057       rv = callback.WaitForResult();
1058     }
1059     EXPECT_NE(ERR_NOT_IMPLEMENTED, rv);
1060     EXPECT_NE(OK, rv);
1061   }
1062 
1063   {
1064     // Connecting using an existing network should succeed when
1065     // NetworkChangeNotifier returns a valid default network.
1066     UDPClientSocket socket(DatagramSocket::RANDOM_BIND, nullptr,
1067                            NetLogSource());
1068     TestCompletionCallback callback;
1069     const handles::NetworkHandle network_handle =
1070         NetworkChangeNotifier::GetDefaultNetwork();
1071     if (network_handle != handles::kInvalidNetworkHandle) {
1072       int rv = socket.ConnectUsingNetworkAsync(
1073           network_handle, fake_server_address, callback.callback());
1074       if (rv == ERR_IO_PENDING) {
1075         rv = callback.WaitForResult();
1076       }
1077       EXPECT_EQ(OK, rv);
1078       EXPECT_EQ(network_handle, socket.GetBoundNetwork());
1079     }
1080   }
1081 #else
1082   UDPClientSocket socket(DatagramSocket::RANDOM_BIND, nullptr, NetLogSource());
1083   TestCompletionCallback callback;
1084   EXPECT_EQ(ERR_NOT_IMPLEMENTED, socket.ConnectUsingNetworkAsync(
1085                                      wrong_network_handle, fake_server_address,
1086                                      callback.callback()));
1087 #endif  // BUILDFLAG(IS_ANDROID)
1088 }
1089 
1090 }  // namespace
1091 
1092 #if BUILDFLAG(IS_WIN)
1093 
1094 namespace {
1095 
1096 const HANDLE kFakeHandle1 = (HANDLE)12;
1097 const HANDLE kFakeHandle2 = (HANDLE)13;
1098 
1099 const QOS_FLOWID kFakeFlowId1 = (QOS_FLOWID)27;
1100 const QOS_FLOWID kFakeFlowId2 = (QOS_FLOWID)38;
1101 
1102 class TestUDPSocketWin : public UDPSocketWin {
1103  public:
TestUDPSocketWin(QwaveApi * qos,DatagramSocket::BindType bind_type,net::NetLog * net_log,const net::NetLogSource & source)1104   TestUDPSocketWin(QwaveApi* qos,
1105                    DatagramSocket::BindType bind_type,
1106                    net::NetLog* net_log,
1107                    const net::NetLogSource& source)
1108       : UDPSocketWin(bind_type, net_log, source), qos_(qos) {}
1109 
1110   TestUDPSocketWin(const TestUDPSocketWin&) = delete;
1111   TestUDPSocketWin& operator=(const TestUDPSocketWin&) = delete;
1112 
1113   // Overriding GetQwaveApi causes the test class to use the injected mock
1114   // QwaveApi instance instead of the singleton.
GetQwaveApi() const1115   QwaveApi* GetQwaveApi() const override { return qos_; }
1116 
1117  private:
1118   raw_ptr<QwaveApi> qos_;
1119 };
1120 
1121 class MockQwaveApi : public QwaveApi {
1122  public:
1123   MOCK_CONST_METHOD0(qwave_supported, bool());
1124   MOCK_METHOD0(OnFatalError, void());
1125   MOCK_METHOD2(CreateHandle, BOOL(PQOS_VERSION version, PHANDLE handle));
1126   MOCK_METHOD1(CloseHandle, BOOL(HANDLE handle));
1127   MOCK_METHOD6(AddSocketToFlow,
1128                BOOL(HANDLE handle,
1129                     SOCKET socket,
1130                     PSOCKADDR addr,
1131                     QOS_TRAFFIC_TYPE traffic_type,
1132                     DWORD flags,
1133                     PQOS_FLOWID flow_id));
1134 
1135   MOCK_METHOD4(
1136       RemoveSocketFromFlow,
1137       BOOL(HANDLE handle, SOCKET socket, QOS_FLOWID flow_id, DWORD reserved));
1138   MOCK_METHOD7(SetFlow,
1139                BOOL(HANDLE handle,
1140                     QOS_FLOWID flow_id,
1141                     QOS_SET_FLOW op,
1142                     ULONG size,
1143                     PVOID data,
1144                     DWORD reserved,
1145                     LPOVERLAPPED overlapped));
1146 };
1147 
OpenedDscpTestClient(QwaveApi * api,IPEndPoint bind_address)1148 std::unique_ptr<UDPSocket> OpenedDscpTestClient(QwaveApi* api,
1149                                                 IPEndPoint bind_address) {
1150   auto client = std::make_unique<TestUDPSocketWin>(
1151       api, DatagramSocket::DEFAULT_BIND, nullptr, NetLogSource());
1152   int rv = client->Open(bind_address.GetFamily());
1153   EXPECT_THAT(rv, IsOk());
1154 
1155   return client;
1156 }
1157 
ConnectedDscpTestClient(QwaveApi * api)1158 std::unique_ptr<UDPSocket> ConnectedDscpTestClient(QwaveApi* api) {
1159   IPEndPoint bind_address;
1160   // We need a real IP, but we won't actually send anything to it.
1161   EXPECT_TRUE(CreateUDPAddress("8.8.8.8", 9999, &bind_address));
1162   auto client = OpenedDscpTestClient(api, bind_address);
1163   EXPECT_THAT(client->Connect(bind_address), IsOk());
1164   return client;
1165 }
1166 
UnconnectedDscpTestClient(QwaveApi * api)1167 std::unique_ptr<UDPSocket> UnconnectedDscpTestClient(QwaveApi* api) {
1168   IPEndPoint bind_address;
1169   EXPECT_TRUE(CreateUDPAddress("0.0.0.0", 9999, &bind_address));
1170   auto client = OpenedDscpTestClient(api, bind_address);
1171   EXPECT_THAT(client->Bind(bind_address), IsOk());
1172   return client;
1173 }
1174 
1175 }  // namespace
1176 
1177 using ::testing::Return;
1178 using ::testing::SetArgPointee;
1179 using ::testing::_;
1180 
// Setting DSCP_NO_CHANGE succeeds without ever adding the socket to a qWAVE
// flow.
TEST_F(UDPSocketTest, SetDSCPNoopIfPassedNoChange) {
  MockQwaveApi mock_qwave;
  EXPECT_CALL(mock_qwave, qwave_supported()).WillRepeatedly(Return(true));
  // No flow may be created at any point in this test.
  EXPECT_CALL(mock_qwave, AddSocketToFlow(_, _, _, _, _, _)).Times(0);

  std::unique_ptr<UDPSocket> socket = ConnectedDscpTestClient(&mock_qwave);
  EXPECT_THAT(socket->SetDiffServCodePoint(DSCP_NO_CHANGE), IsOk());
}
1189 
// When the qWAVE API reports itself unsupported, SetDiffServCodePoint()
// returns ERR_NOT_IMPLEMENTED and no handle is ever created.
TEST_F(UDPSocketTest, SetDSCPFailsIfQOSDoesntLink) {
  MockQwaveApi mock_qwave;
  EXPECT_CALL(mock_qwave, qwave_supported()).WillRepeatedly(Return(false));
  EXPECT_CALL(mock_qwave, CreateHandle(_, _)).Times(0);

  std::unique_ptr<UDPSocket> socket = ConnectedDscpTestClient(&mock_qwave);
  EXPECT_EQ(ERR_NOT_IMPLEMENTED, socket->SetDiffServCodePoint(DSCP_AF41));
}
1198 
// When qWAVE is supported but CreateHandle() fails, SetDiffServCodePoint()
// returns ERR_INVALID_HANDLE and OnFatalError() is expected exactly once.
// After the API then reports itself unsupported, the same call returns
// ERR_NOT_IMPLEMENTED.
TEST_F(UDPSocketTest, SetDSCPFailsIfHandleCantBeCreated) {
  MockQwaveApi mock_qwave;
  EXPECT_CALL(mock_qwave, qwave_supported()).WillRepeatedly(Return(true));
  EXPECT_CALL(mock_qwave, CreateHandle(_, _)).WillOnce(Return(false));
  EXPECT_CALL(mock_qwave, OnFatalError()).Times(1);

  std::unique_ptr<UDPSocket> socket = ConnectedDscpTestClient(&mock_qwave);
  EXPECT_EQ(ERR_INVALID_HANDLE, socket->SetDiffServCodePoint(DSCP_AF41));

  RunUntilIdle();

  // This newer expectation takes precedence over the one set above.
  EXPECT_CALL(mock_qwave, qwave_supported()).WillRepeatedly(Return(false));
  EXPECT_EQ(ERR_NOT_IMPLEMENTED, socket->SetDiffServCodePoint(DSCP_AF41));
}
1213 
1214 MATCHER_P(DscpPointee, dscp, "") {
1215   return *(DWORD*)arg == (DWORD)dscp;
1216 }
1217 
// For a connected socket, the first SetDiffServCodePoint() fails with
// ERR_INVALID_HANDLE until the run loop has run. After that, setting the code
// point succeeds, and changing it removes the socket from the old flow and
// adds it to a new one.
TEST_F(UDPSocketTest, ConnectedSocketDelayedInitAndUpdate) {
  MockQwaveApi api;
  std::unique_ptr<UDPSocket> client = ConnectedDscpTestClient(&api);
  EXPECT_CALL(api, qwave_supported()).WillRepeatedly(Return(true));
  // Handle creation succeeds and yields kFakeHandle1.
  EXPECT_CALL(api, CreateHandle(_, _))
      .WillOnce(DoAll(SetArgPointee<1>(kFakeHandle1), Return(true)));

  // The successful DSCP_AF41 set below is expected to add the socket to flow
  // kFakeFlowId1 and configure that flow once via SetFlow().
  EXPECT_CALL(api, AddSocketToFlow(_, _, _, _, _, _))
      .WillOnce(DoAll(SetArgPointee<5>(kFakeFlowId1), Return(true)));
  EXPECT_CALL(api, SetFlow(_, _, _, _, _, _, _));

  // First set on connected sockets will fail since init is async and
  // we haven't given the runloop a chance to execute the callback.
  EXPECT_EQ(ERR_INVALID_HANDLE, client->SetDiffServCodePoint(DSCP_AF41));
  RunUntilIdle();
  EXPECT_THAT(client->SetDiffServCodePoint(DSCP_AF41), IsOk());

  // New dscp value should reset the flow: the socket leaves kFakeFlowId1 and
  // is re-added as best-effort traffic whose outgoing DSCP is DSCP_DEFAULT.
  EXPECT_CALL(api, RemoveSocketFromFlow(_, _, kFakeFlowId1, _));
  EXPECT_CALL(api, AddSocketToFlow(_, _, _, QOSTrafficTypeBestEffort, _, _))
      .WillOnce(DoAll(SetArgPointee<5>(kFakeFlowId2), Return(true)));
  EXPECT_CALL(api, SetFlow(_, _, QOSSetOutgoingDSCPValue, _,
                           DscpPointee(DSCP_DEFAULT), _, _));
  EXPECT_THAT(client->SetDiffServCodePoint(DSCP_DEFAULT), IsOk());

  // Called from DscpManager destructor, when |client| goes out of scope at the
  // end of the test (before |api|, which was declared first).
  EXPECT_CALL(api, RemoveSocketFromFlow(_, _, kFakeFlowId2, _));
  EXPECT_CALL(api, CloseHandle(kFakeHandle1));
}
1247 
// An unconnected (bound-only) socket accepts SetDiffServCodePoint() both
// before CreateHandle() has completed and after the run loop has run.
// NOTE: "Unonnected" in the test name is a misspelling of "Unconnected". It is
// left unchanged because renaming would change the gtest name used by filters.
TEST_F(UDPSocketTest, UnonnectedSocketDelayedInitAndUpdate) {
  MockQwaveApi api;
  EXPECT_CALL(api, qwave_supported()).WillRepeatedly(Return(true));
  EXPECT_CALL(api, CreateHandle(_, _))
      .WillOnce(DoAll(SetArgPointee<1>(kFakeHandle1), Return(true)));

  // CreateHandle won't have completed yet.  Set passes.
  std::unique_ptr<UDPSocket> client = UnconnectedDscpTestClient(&api);
  EXPECT_THAT(client->SetDiffServCodePoint(DSCP_AF41), IsOk());

  // After the run loop has run, changing the code point still succeeds. No
  // AddSocketToFlow()/SetFlow() expectations are set in this test.
  RunUntilIdle();
  EXPECT_THAT(client->SetDiffServCodePoint(DSCP_AF42), IsOk());

  // Called from DscpManager destructor.
  EXPECT_CALL(api, CloseHandle(kFakeHandle1));
}
1264 
// TODO(zstein): Mocking out DscpManager might be simpler here
// (just verify that DscpManager::Set and DscpManager::PrepareForSend are
// called).
//
// With a code point set on an unconnected socket, SendTo() adds the socket to
// a qWAVE flow once per destination address. SetFlow() is expected only once,
// when the flow is first created.
TEST_F(UDPSocketTest, SendToCallsQwaveApis) {
  MockQwaveApi api;
  std::unique_ptr<UDPSocket> client = UnconnectedDscpTestClient(&api);
  EXPECT_CALL(api, qwave_supported()).WillRepeatedly(Return(true));
  EXPECT_CALL(api, CreateHandle(_, _))
      .WillOnce(DoAll(SetArgPointee<1>(kFakeHandle1), Return(true)));
  EXPECT_THAT(client->SetDiffServCodePoint(DSCP_AF41), IsOk());
  RunUntilIdle();

  // First send to |server_address|: the socket joins flow kFakeFlowId1, which
  // is configured via SetFlow().
  EXPECT_CALL(api, AddSocketToFlow(_, _, _, _, _, _))
      .WillOnce(DoAll(SetArgPointee<5>(kFakeFlowId1), Return(true)));
  EXPECT_CALL(api, SetFlow(_, _, _, _, _, _, _));
  std::string simple_message("hello world");
  IPEndPoint server_address(IPAddress::IPv4Localhost(), 9438);
  int rv = SendToSocket(client.get(), simple_message, server_address);
  EXPECT_EQ(simple_message.length(), static_cast<size_t>(rv));

  // A repeat send to the same address adds no qWAVE calls; the WillOnce()
  // expectations above are already used up.
  // TODO(zstein): Move to second test case (Qwave APIs called once per address)
  rv = SendToSocket(client.get(), simple_message, server_address);
  EXPECT_EQ(simple_message.length(), static_cast<size_t>(rv));

  // A new destination triggers one more AddSocketToFlow() but no SetFlow().
  // TODO(zstein): Move to third test case (Qwave APIs called for each
  // destination address).
  EXPECT_CALL(api, AddSocketToFlow(_, _, _, _, _, _)).WillOnce(Return(true));
  IPEndPoint server_address2(IPAddress::IPv4Localhost(), 9439);

  rv = SendToSocket(client.get(), simple_message, server_address2);
  EXPECT_EQ(simple_message.length(), static_cast<size_t>(rv));

  // Called from DscpManager destructor.
  EXPECT_CALL(api, RemoveSocketFromFlow(_, _, _, _));
  EXPECT_CALL(api, CloseHandle(kFakeHandle1));
}
1301 
// On an unconnected socket, SetDiffServCodePoint() succeeds before qWAVE
// initialization has finished. SendTo() then works but applies no QoS until the
// run loop has run; after that it adds the socket to a flow using the stored
// code point.
TEST_F(UDPSocketTest, SendToCallsApisAfterDeferredInit) {
  MockQwaveApi api;
  std::unique_ptr<UDPSocket> client = UnconnectedDscpTestClient(&api);
  EXPECT_CALL(api, qwave_supported()).WillRepeatedly(Return(true));
  EXPECT_CALL(api, CreateHandle(_, _))
      .WillOnce(DoAll(SetArgPointee<1>(kFakeHandle1), Return(true)));

  // SetDiffServCodepoint works even if qos api hasn't finished initing.
  EXPECT_THAT(client->SetDiffServCodePoint(DSCP_CS7), IsOk());

  std::string simple_message("hello world");
  IPEndPoint server_address(IPAddress::IPv4Localhost(), 9438);

  // SendTo works, but doesn't yet apply TOS
  EXPECT_CALL(api, AddSocketToFlow(_, _, _, _, _, _)).Times(0);
  int rv = SendToSocket(client.get(), simple_message, server_address);
  EXPECT_EQ(simple_message.length(), static_cast<size_t>(rv));

  RunUntilIdle();
  // Now we're initialized, SendTo triggers qos calls with correct codepoint.
  // DSCP_CS7 is expected to be mapped to QOSTrafficTypeControl.
  EXPECT_CALL(api, AddSocketToFlow(_, _, _, QOSTrafficTypeControl, _, _))
      .WillOnce(DoAll(SetArgPointee<5>(kFakeFlowId1), Return(true)));
  EXPECT_CALL(api, SetFlow(_, _, _, _, _, _, _)).WillOnce(Return(true));
  rv = SendToSocket(client.get(), simple_message, server_address);
  EXPECT_EQ(simple_message.length(), static_cast<size_t>(rv));

  // Called from DscpManager destructor.
  EXPECT_CALL(api, RemoveSocketFromFlow(_, _, kFakeFlowId1, _));
  EXPECT_CALL(api, CloseHandle(kFakeHandle1));
}
1332 
1333 class DscpManagerTest : public TestWithTaskEnvironment {
1334  protected:
DscpManagerTest()1335   DscpManagerTest() {
1336     EXPECT_CALL(api_, qwave_supported()).WillRepeatedly(Return(true));
1337     EXPECT_CALL(api_, CreateHandle(_, _))
1338         .WillOnce(DoAll(SetArgPointee<1>(kFakeHandle1), Return(true)));
1339     dscp_manager_ = std::make_unique<DscpManager>(&api_, INVALID_SOCKET);
1340 
1341     CreateUDPAddress("1.2.3.4", 9001, &address1_);
1342     CreateUDPAddress("1234:5678:90ab:cdef:1234:5678:90ab:cdef", 9002,
1343                      &address2_);
1344   }
1345 
1346   MockQwaveApi api_;
1347   std::unique_ptr<DscpManager> dscp_manager_;
1348 
1349   IPEndPoint address1_;
1350   IPEndPoint address2_;
1351 };
1352 
// PrepareForSend() without a prior Set() is meant to be a no-op. No flow
// expectations are registered here, so this only checks that the call
// completes.
TEST_F(DscpManagerTest, PrepareForSendIsNoopIfNoSet) {
  // Drain pending tasks first (presumably the deferred CreateHandle()).
  RunUntilIdle();

  DscpManager* const manager = dscp_manager_.get();
  manager->PrepareForSend(address1_);
}
1357 
// After Set(), PrepareForSend() adds the socket to the flow for every new
// destination, but SetFlow() configures the flow only when it is first
// created.
TEST_F(DscpManagerTest, PrepareForSendCallsQwaveApisAfterSet) {
  // Drain pending tasks first (presumably the deferred CreateHandle()).
  RunUntilIdle();
  dscp_manager_->Set(DSCP_CS2);

  // AddSocketToFlow should be called for each address.
  // SetFlow should only be called when the flow is first created.
  EXPECT_CALL(api_, AddSocketToFlow(_, _, _, _, _, _))
      .WillOnce(DoAll(SetArgPointee<5>(kFakeFlowId1), Return(true)));
  EXPECT_CALL(api_, SetFlow(_, kFakeFlowId1, _, _, _, _, _));
  dscp_manager_->PrepareForSend(address1_);

  // The second destination joins the same flow id; no further SetFlow().
  EXPECT_CALL(api_, AddSocketToFlow(_, _, _, _, _, _))
      .WillOnce(DoAll(SetArgPointee<5>(kFakeFlowId1), Return(true)));
  EXPECT_CALL(api_, SetFlow(_, _, _, _, _, _, _)).Times(0);
  dscp_manager_->PrepareForSend(address2_);

  // Called from DscpManager destructor, when the fixture is torn down.
  EXPECT_CALL(api_, RemoveSocketFromFlow(_, _, kFakeFlowId1, _));
  EXPECT_CALL(api_, CloseHandle(kFakeHandle1));
}
1378 
// Calling PrepareForSend() again for an address that is already in the flow
// makes no further qWAVE calls.
TEST_F(DscpManagerTest, PrepareForSendCallsQwaveApisOncePerAddress) {
  // Drain pending tasks first (presumably the deferred CreateHandle()).
  RunUntilIdle();
  dscp_manager_->Set(DSCP_CS2);

  // First send to |address1_| creates and configures the flow.
  EXPECT_CALL(api_, AddSocketToFlow(_, _, _, _, _, _))
      .WillOnce(DoAll(SetArgPointee<5>(kFakeFlowId1), Return(true)));
  EXPECT_CALL(api_, SetFlow(_, kFakeFlowId1, _, _, _, _, _));
  dscp_manager_->PrepareForSend(address1_);
  // Second send to the same address: neither API may be called.
  EXPECT_CALL(api_, AddSocketToFlow(_, _, _, _, _, _)).Times(0);
  EXPECT_CALL(api_, SetFlow(_, _, _, _, _, _, _)).Times(0);
  dscp_manager_->PrepareForSend(address1_);

  // Called from DscpManager destructor, when the fixture is torn down.
  EXPECT_CALL(api_, RemoveSocketFromFlow(_, _, kFakeFlowId1, _));
  EXPECT_CALL(api_, CloseHandle(kFakeHandle1));
}
1395 
// Set() with a new code point removes the socket from the existing flow, and
// the next PrepareForSend() creates and configures a fresh flow.
TEST_F(DscpManagerTest, SetDestroysExistingFlow) {
  // Drain pending tasks first (presumably the deferred CreateHandle()).
  RunUntilIdle();
  dscp_manager_->Set(DSCP_CS2);

  EXPECT_CALL(api_, AddSocketToFlow(_, _, _, _, _, _))
      .WillOnce(DoAll(SetArgPointee<5>(kFakeFlowId1), Return(true)));
  EXPECT_CALL(api_, SetFlow(_, kFakeFlowId1, _, _, _, _, _));
  dscp_manager_->PrepareForSend(address1_);

  // Calling Set should destroy the existing flow. The SOCKET argument is
  // expected to be NULL here.
  // TODO(zstein): Verify that RemoveSocketFromFlow with no address
  // destroys the flow for all destinations.
  EXPECT_CALL(api_, RemoveSocketFromFlow(_, NULL, kFakeFlowId1, _));
  dscp_manager_->Set(DSCP_CS5);

  // The next send builds a new flow (kFakeFlowId2) and configures it.
  EXPECT_CALL(api_, AddSocketToFlow(_, _, _, _, _, _))
      .WillOnce(DoAll(SetArgPointee<5>(kFakeFlowId2), Return(true)));
  EXPECT_CALL(api_, SetFlow(_, kFakeFlowId2, _, _, _, _, _));
  dscp_manager_->PrepareForSend(address1_);

  // Called from DscpManager destructor, when the fixture is torn down.
  EXPECT_CALL(api_, RemoveSocketFromFlow(_, _, kFakeFlowId2, _));
  EXPECT_CALL(api_, CloseHandle(kFakeHandle1));
}
1420 
// If AddSocketToFlow() fails with ERROR_DEVICE_REINITIALIZATION_NEEDED, the
// manager closes its qWAVE handle, creates a new one, and reports
// ERR_INVALID_HANDLE. Once the run loop has run, the next send succeeds with
// the previously set code point without another Set().
TEST_F(DscpManagerTest, SocketReAddedOnRecreateHandle) {
  // Drain pending tasks first (presumably the deferred CreateHandle()).
  RunUntilIdle();
  dscp_manager_->Set(DSCP_CS2);

  // First Set and Send work fine.
  EXPECT_CALL(api_, AddSocketToFlow(_, _, _, _, _, _))
      .WillOnce(DoAll(SetArgPointee<5>(kFakeFlowId1), Return(true)));
  EXPECT_CALL(api_, SetFlow(_, kFakeFlowId1, _, _, _, _, _))
      .WillOnce(Return(true));
  EXPECT_THAT(dscp_manager_->PrepareForSend(address1_), IsOk());

  // Make Second flow operation fail (requires resetting the codepoint).
  EXPECT_CALL(api_, RemoveSocketFromFlow(_, _, kFakeFlowId1, _))
      .WillOnce(Return(true));
  dscp_manager_->Set(DSCP_CS7);

  // Fake the thread's last-error value that the failing AddSocketToFlow()
  // leaves behind. |error| scopes it; it is reset below, presumably so the
  // fake value does not leak past this step.
  auto error = std::make_unique<base::ScopedClearLastError>();
  ::SetLastError(ERROR_DEVICE_REINITIALIZATION_NEEDED);
  EXPECT_CALL(api_, AddSocketToFlow(_, _, _, _, _, _)).WillOnce(Return(false));
  EXPECT_CALL(api_, SetFlow(_, _, _, _, _, _, _)).Times(0);
  // The old handle is closed and a replacement (kFakeHandle2) is created.
  EXPECT_CALL(api_, CloseHandle(kFakeHandle1));
  EXPECT_CALL(api_, CreateHandle(_, _))
      .WillOnce(DoAll(SetArgPointee<1>(kFakeHandle2), Return(true)));
  EXPECT_EQ(ERR_INVALID_HANDLE, dscp_manager_->PrepareForSend(address1_));
  error = nullptr;
  RunUntilIdle();

  // Next Send should work fine, without requiring another Set. DSCP_CS7 is
  // expected to map to QOSTrafficTypeControl.
  EXPECT_CALL(api_, AddSocketToFlow(_, _, _, QOSTrafficTypeControl, _, _))
      .WillOnce(DoAll(SetArgPointee<5>(kFakeFlowId2), Return(true)));
  EXPECT_CALL(api_, SetFlow(_, kFakeFlowId2, _, _, _, _, _))
      .WillOnce(Return(true));
  EXPECT_THAT(dscp_manager_->PrepareForSend(address1_), IsOk());

  // Called from DscpManager destructor, when the fixture is torn down.
  EXPECT_CALL(api_, RemoveSocketFromFlow(_, _, kFakeFlowId2, _));
  EXPECT_CALL(api_, CloseHandle(kFakeHandle2));
}
1459 
1460 #endif
1461 
// Tests that a client socket with the receive optimization enabled still
// receives a datagram that a server sends to it.
TEST_F(UDPSocketTest, ReadWithSocketOptimization) {
  std::string simple_message("hello world!");

  // Set up the server to listen.
  IPEndPoint server_address(IPAddress::IPv4Localhost(), 0 /* port */);
  UDPServerSocket server(nullptr, NetLogSource());
  server.AllowAddressReuse();
  ASSERT_THAT(server.Listen(server_address), IsOk());
  // Get bound port.
  ASSERT_THAT(server.GetLocalAddress(&server_address), IsOk());

  // Set up the client, enable the experimental optimization and connect it to
  // the server. Everything below needs a connected client with a known local
  // address, so these steps are fatal on failure rather than letting the test
  // continue with an unconnected socket and an empty address.
  UDPClientSocket client(DatagramSocket::DEFAULT_BIND, nullptr, NetLogSource());
  client.EnableRecvOptimization();
  ASSERT_THAT(client.Connect(server_address), IsOk());

  // Get the client's address.
  IPEndPoint client_address;
  ASSERT_THAT(client.GetLocalAddress(&client_address), IsOk());

  // Server sends the message to the client.
  EXPECT_EQ(simple_message.length(),
            static_cast<size_t>(
                SendToSocket(&server, simple_message, client_address)));

  // Client receives the message.
  std::string str = ReadSocket(&client);
  EXPECT_EQ(simple_message, str);

  server.Close();
  client.Close();
}
1495 
1496 // Tests that read from a socket correctly returns
1497 // |ERR_MSG_TOO_BIG| when the buffer is too small and
1498 // returns the actual message when it fits the buffer.
1499 // For the optimized path, the buffer size should be at least
1500 // 1 byte greater than the message.
TEST_F(UDPSocketTest,ReadWithSocketOptimizationTruncation)1501 TEST_F(UDPSocketTest, ReadWithSocketOptimizationTruncation) {
1502   std::string too_long_message(kMaxRead + 1, 'A');
1503   std::string right_length_message(kMaxRead - 1, 'B');
1504   std::string exact_length_message(kMaxRead, 'C');
1505 
1506   // Setup the server to listen.
1507   IPEndPoint server_address(IPAddress::IPv4Localhost(), 0 /* port */);
1508   UDPServerSocket server(nullptr, NetLogSource());
1509   server.AllowAddressReuse();
1510   ASSERT_THAT(server.Listen(server_address), IsOk());
1511   // Get bound port.
1512   ASSERT_THAT(server.GetLocalAddress(&server_address), IsOk());
1513 
1514   // Setup the client, enable experimental optimization and connected to the
1515   // server.
1516   UDPClientSocket client(DatagramSocket::DEFAULT_BIND, nullptr, NetLogSource());
1517   client.EnableRecvOptimization();
1518   EXPECT_THAT(client.Connect(server_address), IsOk());
1519 
1520   // Get the client's address.
1521   IPEndPoint client_address;
1522   EXPECT_THAT(client.GetLocalAddress(&client_address), IsOk());
1523 
1524   // Send messages to the client.
1525   EXPECT_EQ(too_long_message.length(),
1526             static_cast<size_t>(
1527                 SendToSocket(&server, too_long_message, client_address)));
1528   EXPECT_EQ(right_length_message.length(),
1529             static_cast<size_t>(
1530                 SendToSocket(&server, right_length_message, client_address)));
1531   EXPECT_EQ(exact_length_message.length(),
1532             static_cast<size_t>(
1533                 SendToSocket(&server, exact_length_message, client_address)));
1534 
1535   // Client receives the messages.
1536 
1537   // 1. The first message is |too_long_message|. Its size exceeds the buffer.
1538   // In that case, the client is expected to get |ERR_MSG_TOO_BIG| when the
1539   // data is read.
1540   TestCompletionCallback callback;
1541   int rv = client.Read(buffer_.get(), kMaxRead, callback.callback());
1542   EXPECT_EQ(ERR_MSG_TOO_BIG, callback.GetResult(rv));
1543   EXPECT_EQ(client.GetLastTos().dscp, DSCP_DEFAULT);
1544   EXPECT_EQ(client.GetLastTos().ecn, ECN_DEFAULT);
1545 
1546   // 2. The second message is |right_length_message|. Its size is
1547   // one byte smaller than the size of the buffer. In that case, the client
1548   // is expected to read the whole message successfully.
1549   rv = client.Read(buffer_.get(), kMaxRead, callback.callback());
1550   rv = callback.GetResult(rv);
1551   EXPECT_EQ(static_cast<int>(right_length_message.length()), rv);
1552   EXPECT_EQ(right_length_message, std::string(buffer_->data(), rv));
1553   EXPECT_EQ(client.GetLastTos().dscp, DSCP_DEFAULT);
1554   EXPECT_EQ(client.GetLastTos().ecn, ECN_DEFAULT);
1555 
1556   // 3. The third message is |exact_length_message|. Its size is equal to
1557   // the read buffer size. In that case, the client expects to get
1558   // |ERR_MSG_TOO_BIG| when the socket is read. Internally, the optimized
1559   // path uses read() system call that requires one extra byte to detect
1560   // truncated messages; therefore, messages that fill the buffer exactly
1561   // are considered truncated.
1562   // The optimization is only enabled on POSIX platforms. On Windows,
1563   // the optimization is turned off; therefore, the client
1564   // should be able to read the whole message without encountering
1565   // |ERR_MSG_TOO_BIG|.
1566   rv = client.Read(buffer_.get(), kMaxRead, callback.callback());
1567   rv = callback.GetResult(rv);
1568   EXPECT_EQ(client.GetLastTos().dscp, DSCP_DEFAULT);
1569   EXPECT_EQ(client.GetLastTos().ecn, ECN_DEFAULT);
1570 #if BUILDFLAG(IS_POSIX)
1571   EXPECT_EQ(ERR_MSG_TOO_BIG, rv);
1572 #else
1573   EXPECT_EQ(static_cast<int>(exact_length_message.length()), rv);
1574   EXPECT_EQ(exact_length_message, std::string(buffer_->data(), rv));
1575 #endif
1576   server.Close();
1577   client.Close();
1578 }
1579 
// On Android, where socket tagging is supported, verify that tagging a UDP
// client socket via ApplySocketTag() works as expected.
1582 #if BUILDFLAG(IS_ANDROID)
TEST_F(UDPSocketTest,Tag)1583 TEST_F(UDPSocketTest, Tag) {
1584   if (!CanGetTaggedBytes()) {
1585     DVLOG(0) << "Skipping test - GetTaggedBytes unsupported.";
1586     return;
1587   }
1588 
1589   UDPServerSocket server(nullptr, NetLogSource());
1590   ASSERT_THAT(server.Listen(IPEndPoint(IPAddress::IPv4Localhost(), 0)), IsOk());
1591   IPEndPoint server_address;
1592   ASSERT_THAT(server.GetLocalAddress(&server_address), IsOk());
1593 
1594   UDPClientSocket client(DatagramSocket::DEFAULT_BIND, nullptr, NetLogSource());
1595   ASSERT_THAT(client.Connect(server_address), IsOk());
1596 
1597   // Verify UDP packets are tagged and counted properly.
1598   int32_t tag_val1 = 0x12345678;
1599   uint64_t old_traffic = GetTaggedBytes(tag_val1);
1600   SocketTag tag1(SocketTag::UNSET_UID, tag_val1);
1601   client.ApplySocketTag(tag1);
1602   // Client sends to the server.
1603   std::string simple_message("hello world!");
1604   int rv = WriteSocket(&client, simple_message);
1605   EXPECT_EQ(simple_message.length(), static_cast<size_t>(rv));
1606   // Server waits for message.
1607   std::string str = RecvFromSocket(&server);
1608   EXPECT_EQ(simple_message, str);
1609   // Server echoes reply.
1610   rv = SendToSocket(&server, simple_message);
1611   EXPECT_EQ(simple_message.length(), static_cast<size_t>(rv));
1612   // Client waits for response.
1613   str = ReadSocket(&client);
1614   EXPECT_EQ(simple_message, str);
1615   EXPECT_GT(GetTaggedBytes(tag_val1), old_traffic);
1616 
1617   // Verify socket can be retagged with a new value and the current process's
1618   // UID.
1619   int32_t tag_val2 = 0x87654321;
1620   old_traffic = GetTaggedBytes(tag_val2);
1621   SocketTag tag2(getuid(), tag_val2);
1622   client.ApplySocketTag(tag2);
1623   // Client sends to the server.
1624   rv = WriteSocket(&client, simple_message);
1625   EXPECT_EQ(simple_message.length(), static_cast<size_t>(rv));
1626   // Server waits for message.
1627   str = RecvFromSocket(&server);
1628   EXPECT_EQ(simple_message, str);
1629   // Server echoes reply.
1630   rv = SendToSocket(&server, simple_message);
1631   EXPECT_EQ(simple_message.length(), static_cast<size_t>(rv));
1632   // Client waits for response.
1633   str = ReadSocket(&client);
1634   EXPECT_EQ(simple_message, str);
1635   EXPECT_GT(GetTaggedBytes(tag_val2), old_traffic);
1636 
1637   // Verify socket can be retagged with a new value and the current process's
1638   // UID.
1639   old_traffic = GetTaggedBytes(tag_val1);
1640   client.ApplySocketTag(tag1);
1641   // Client sends to the server.
1642   rv = WriteSocket(&client, simple_message);
1643   EXPECT_EQ(simple_message.length(), static_cast<size_t>(rv));
1644   // Server waits for message.
1645   str = RecvFromSocket(&server);
1646   EXPECT_EQ(simple_message, str);
1647   // Server echoes reply.
1648   rv = SendToSocket(&server, simple_message);
1649   EXPECT_EQ(simple_message.length(), static_cast<size_t>(rv));
1650   // Client waits for response.
1651   str = ReadSocket(&client);
1652   EXPECT_EQ(simple_message, str);
1653   EXPECT_GT(GetTaggedBytes(tag_val1), old_traffic);
1654 }
1655 
// Verifies that a UDP client socket constructed with a network handle binds to
// that network at Connect() time, and fails to connect for a bogus handle.
TEST_F(UDPSocketTest, BindToNetwork) {
  // Nothing needs to listen on this endpoint: the test only calls Connect()
  // and never sends datagrams, so the exact address does not matter.
  const IPEndPoint fake_server_address(IPAddress::IPv4Localhost(), 8080);
  NetworkChangeNotifierFactoryAndroid ncn_factory;
  NetworkChangeNotifier::DisableForTest ncn_disable_for_test;
  std::unique_ptr<NetworkChangeNotifier> ncn(ncn_factory.CreateInstance());
  if (!NetworkChangeNotifier::AreNetworkHandlesSupported()) {
    GTEST_SKIP() << "Network handles are required to test BindToNetwork.";
  }

  // A socket bound to a network that does not exist should fail at connect
  // time and must not report that network as the one it is bound to.
  const handles::NetworkHandle wrong_network_handle = 65536;
  UDPClientSocket wrong_socket(DatagramSocket::RANDOM_BIND, nullptr,
                               NetLogSource(), wrong_network_handle);
  const int wrong_connect_result = wrong_socket.Connect(fake_server_address);
  // Android versions differ in the exact error they report, so only rule out
  // the results that must never happen.
  EXPECT_NE(OK, wrong_connect_result);
  EXPECT_NE(ERR_NOT_IMPLEMENTED, wrong_connect_result);
  EXPECT_NE(wrong_network_handle, wrong_socket.GetBoundNetwork());

  // A socket bound to an existing network (the default one, when there is
  // one) should connect and report that network as bound.
  const handles::NetworkHandle default_network =
      NetworkChangeNotifier::GetDefaultNetwork();
  if (default_network == handles::kInvalidNetworkHandle) {
    return;
  }
  UDPClientSocket correct_socket(DatagramSocket::RANDOM_BIND, nullptr,
                                 NetLogSource(), default_network);
  EXPECT_EQ(OK, correct_socket.Connect(fake_server_address));
  EXPECT_EQ(default_network, correct_socket.GetBoundNetwork());
}
1688 
1689 #endif  // BUILDFLAG(IS_ANDROID)
1690 
1691 // Scoped helper to override the process-wide UDP socket limit.
1692 class OverrideUDPSocketLimit {
1693  public:
OverrideUDPSocketLimit(int new_limit)1694   explicit OverrideUDPSocketLimit(int new_limit) {
1695     base::FieldTrialParams params;
1696     params[features::kLimitOpenUDPSocketsMax.name] =
1697         base::NumberToString(new_limit);
1698 
1699     scoped_feature_list_.InitAndEnableFeatureWithParameters(
1700         features::kLimitOpenUDPSockets, params);
1701   }
1702 
1703  private:
1704   base::test::ScopedFeatureList scoped_feature_list_;
1705 };
1706 
1707 // Tests that UDPClientSocket respects the global UDP socket limits.
TEST_F(UDPSocketTest,LimitClientSocket)1708 TEST_F(UDPSocketTest, LimitClientSocket) {
1709   // Reduce the global UDP limit to 2.
1710   OverrideUDPSocketLimit set_limit(2);
1711 
1712   ASSERT_EQ(0, GetGlobalUDPSocketCountForTesting());
1713 
1714   auto socket1 = std::make_unique<UDPClientSocket>(DatagramSocket::DEFAULT_BIND,
1715                                                    nullptr, NetLogSource());
1716   auto socket2 = std::make_unique<UDPClientSocket>(DatagramSocket::DEFAULT_BIND,
1717                                                    nullptr, NetLogSource());
1718 
1719   // Simply constructing a UDPClientSocket does not increase the limit (no
1720   // Connect() or Bind() has been called yet).
1721   ASSERT_EQ(0, GetGlobalUDPSocketCountForTesting());
1722 
1723   // The specific value of this address doesn't really matter, and no server
1724   // needs to be running here. The test only needs to call Connect() and won't
1725   // send any datagrams.
1726   IPEndPoint server_address(IPAddress::IPv4Localhost(), 8080);
1727 
1728   // Successful Connect() on socket1 increases socket count.
1729   EXPECT_THAT(socket1->Connect(server_address), IsOk());
1730   EXPECT_EQ(1, GetGlobalUDPSocketCountForTesting());
1731 
1732   // Successful Connect() on socket2 increases socket count.
1733   EXPECT_THAT(socket2->Connect(server_address), IsOk());
1734   EXPECT_EQ(2, GetGlobalUDPSocketCountForTesting());
1735 
1736   // Attempting a third Connect() should fail with ERR_INSUFFICIENT_RESOURCES,
1737   // as the limit is currently 2.
1738   auto socket3 = std::make_unique<UDPClientSocket>(DatagramSocket::DEFAULT_BIND,
1739                                                    nullptr, NetLogSource());
1740   EXPECT_THAT(socket3->Connect(server_address),
1741               IsError(ERR_INSUFFICIENT_RESOURCES));
1742   EXPECT_EQ(2, GetGlobalUDPSocketCountForTesting());
1743 
1744   // Check that explicitly closing socket2 free up a count.
1745   socket2->Close();
1746   EXPECT_EQ(1, GetGlobalUDPSocketCountForTesting());
1747 
1748   // Since the socket was already closed, deleting it will not affect the count.
1749   socket2.reset();
1750   EXPECT_EQ(1, GetGlobalUDPSocketCountForTesting());
1751 
1752   // Now that the count is below limit, try to connect another socket. This time
1753   // it will work.
1754   auto socket4 = std::make_unique<UDPClientSocket>(DatagramSocket::DEFAULT_BIND,
1755                                                    nullptr, NetLogSource());
1756   EXPECT_THAT(socket4->Connect(server_address), IsOk());
1757   EXPECT_EQ(2, GetGlobalUDPSocketCountForTesting());
1758 
1759   // Verify that closing the two remaining sockets brings the open count back to
1760   // 0.
1761   socket1.reset();
1762   EXPECT_EQ(1, GetGlobalUDPSocketCountForTesting());
1763   socket4.reset();
1764   EXPECT_EQ(0, GetGlobalUDPSocketCountForTesting());
1765 }
1766 
1767 // Tests that UDPSocketClient updates the global counter
1768 // correctly when Connect() fails.
TEST_F(UDPSocketTest,LimitConnectFail)1769 TEST_F(UDPSocketTest, LimitConnectFail) {
1770   ASSERT_EQ(0, GetGlobalUDPSocketCountForTesting());
1771 
1772   {
1773     // Simply allocating a UDPSocket does not increase count.
1774     UDPSocket socket(DatagramSocket::DEFAULT_BIND, nullptr, NetLogSource());
1775     EXPECT_EQ(0, GetGlobalUDPSocketCountForTesting());
1776 
1777     // Calling Open() allocates the socket and increases the global counter.
1778     EXPECT_THAT(socket.Open(ADDRESS_FAMILY_IPV4), IsOk());
1779     EXPECT_EQ(1, GetGlobalUDPSocketCountForTesting());
1780 
1781     // Connect to an IPv6 address should fail since the socket was created for
1782     // IPv4.
1783     EXPECT_THAT(socket.Connect(net::IPEndPoint(IPAddress::IPv6Localhost(), 53)),
1784                 Not(IsOk()));
1785 
1786     // That Connect() failed doesn't change the global counter.
1787     EXPECT_EQ(1, GetGlobalUDPSocketCountForTesting());
1788   }
1789 
1790   // Finally, destroying UDPSocket decrements the global counter.
1791   EXPECT_EQ(0, GetGlobalUDPSocketCountForTesting());
1792 }
1793 
1794 // Tests allocating UDPClientSockets and Connect()ing them in parallel.
1795 //
1796 // This is primarily intended for coverage under TSAN, to check for races
1797 // enforcing the global socket counter.
TEST_F(UDPSocketTest,LimitConnectMultithreaded)1798 TEST_F(UDPSocketTest, LimitConnectMultithreaded) {
1799   ASSERT_EQ(0, GetGlobalUDPSocketCountForTesting());
1800 
1801   // Start up some threads.
1802   std::vector<std::unique_ptr<base::Thread>> threads;
1803   for (size_t i = 0; i < 5; ++i) {
1804     threads.push_back(std::make_unique<base::Thread>("Worker thread"));
1805     ASSERT_TRUE(threads.back()->Start());
1806   }
1807 
1808   // Post tasks to each of the threads.
1809   for (const auto& thread : threads) {
1810     thread->task_runner()->PostTask(
1811         FROM_HERE, base::BindOnce([] {
1812           // The specific value of this address doesn't really matter, and no
1813           // server needs to be running here. The test only needs to call
1814           // Connect() and won't send any datagrams.
1815           IPEndPoint server_address(IPAddress::IPv4Localhost(), 8080);
1816 
1817           UDPClientSocket socket(DatagramSocket::DEFAULT_BIND, nullptr,
1818                                  NetLogSource());
1819           EXPECT_THAT(socket.Connect(server_address), IsOk());
1820         }));
1821   }
1822 
1823   // Complete all the tasks.
1824   threads.clear();
1825 
1826   EXPECT_EQ(0, GetGlobalUDPSocketCountForTesting());
1827 }
1828 
1829 }  // namespace net
1830