xref: /aosp_15_r20/external/cronet/net/test/embedded_test_server/connection_tracker.cc (revision 6777b5387eb2ff775bb5750e3f5d96f37fb7352b)
1 // Copyright 2021 The Chromium Authors
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4 
5 #include "net/test/embedded_test_server/connection_tracker.h"
6 
7 #include "base/containers/contains.h"
8 #include "base/run_loop.h"
9 #include "base/task/single_thread_task_runner.h"
10 #include "net/test/embedded_test_server/embedded_test_server.h"
11 #include "testing/gtest/include/gtest/gtest.h"
12 
13 namespace {
14 
GetPort(const net::StreamSocket & connection,uint16_t * port)15 bool GetPort(const net::StreamSocket& connection, uint16_t* port) {
16   // Gets the remote port of the peer, since the local port will always be
17   // the port the test server is listening on. This isn't strictly correct -
18   // it's possible for multiple peers to connect with the same remote port
19   // but different remote IPs - but the tests here assume that connections
20   // to the test server (running on localhost) will always come from
21   // localhost, and thus the peer port is all that's needed to distinguish
22   // two connections. This also would be problematic if the OS reused ports,
23   // but that's not something to worry about for these tests.
24   net::IPEndPoint address;
25   int result = connection.GetPeerAddress(&address);
26   if (result != net::OK)
27     return false;
28   *port = address.port();
29   return true;
30 }
31 
32 }  // namespace
33 
34 namespace net::test_server {
35 
ConnectionTracker(EmbeddedTestServer * test_server)36 ConnectionTracker::ConnectionTracker(EmbeddedTestServer* test_server)
37     : connection_listener_(this) {
38   test_server->SetConnectionListener(&connection_listener_);
39 }
40 
41 ConnectionTracker::~ConnectionTracker() = default;
42 
AcceptedSocketWithPort(uint16_t port)43 void ConnectionTracker::AcceptedSocketWithPort(uint16_t port) {
44   num_connected_sockets_++;
45   sockets_[port] = SocketStatus::kAccepted;
46   CheckAccepted();
47 }
48 
ReadFromSocketWithPort(uint16_t port)49 void ConnectionTracker::ReadFromSocketWithPort(uint16_t port) {
50   EXPECT_TRUE(base::Contains(sockets_, port));
51   if (sockets_[port] == SocketStatus::kAccepted)
52     num_read_sockets_++;
53   sockets_[port] = SocketStatus::kReadFrom;
54   if (read_loop_) {
55     read_loop_->Quit();
56     read_loop_ = nullptr;
57   }
58 }
59 
60 // Returns the number of sockets that were accepted by the server.
GetAcceptedSocketCount() const61 size_t ConnectionTracker::GetAcceptedSocketCount() const {
62   return num_connected_sockets_;
63 }
64 
65 // Returns the number of sockets that were read from by the server.
GetReadSocketCount() const66 size_t ConnectionTracker::GetReadSocketCount() const {
67   return num_read_sockets_;
68 }
69 
WaitUntilConnectionRead()70 void ConnectionTracker::WaitUntilConnectionRead() {
71   base::RunLoop run_loop;
72   read_loop_ = &run_loop;
73   read_loop_->Run();
74 }
75 
76 // This will wait for exactly |num_connections| items in |sockets_|. This method
77 // expects the server will not accept more than |num_connections| connections.
78 // |num_connections| must be greater than 0.
WaitForAcceptedConnections(size_t num_connections)79 void ConnectionTracker::WaitForAcceptedConnections(size_t num_connections) {
80   DCHECK(!num_accepted_connections_loop_);
81   DCHECK_GT(num_connections, 0u);
82   base::RunLoop run_loop;
83   EXPECT_GE(num_connections, num_connected_sockets_);
84   num_accepted_connections_loop_ = &run_loop;
85   num_accepted_connections_needed_ = num_connections;
86   CheckAccepted();
87   // Note that the previous call to CheckAccepted can quit this run loop
88   // before this call, which will make this call a no-op.
89   run_loop.Run();
90   EXPECT_EQ(num_connections, num_connected_sockets_);
91 }
92 
93 // Helper function to stop the waiting for sockets to be accepted for
94 // WaitForAcceptedConnections. |num_accepted_connections_loop_| spins
95 // until |num_accepted_connections_needed_| sockets are accepted by the test
96 // server. The values will be null/0 if the loop is not running.
CheckAccepted()97 void ConnectionTracker::CheckAccepted() {
98   // |num_accepted_connections_loop_| null implies
99   // |num_accepted_connections_needed_| == 0.
100   DCHECK(num_accepted_connections_loop_ ||
101          num_accepted_connections_needed_ == 0);
102   if (!num_accepted_connections_loop_ ||
103       num_accepted_connections_needed_ != num_connected_sockets_) {
104     return;
105   }
106 
107   num_accepted_connections_loop_->Quit();
108   num_accepted_connections_needed_ = 0;
109   num_accepted_connections_loop_ = nullptr;
110 }
111 
ResetCounts()112 void ConnectionTracker::ResetCounts() {
113   sockets_.clear();
114   num_connected_sockets_ = 0;
115   num_read_sockets_ = 0;
116 }
117 
ConnectionListener(ConnectionTracker * tracker)118 ConnectionTracker::ConnectionListener::ConnectionListener(
119     ConnectionTracker* tracker)
120     : task_runner_(base::SingleThreadTaskRunner::GetCurrentDefault()),
121       tracker_(tracker) {}
122 
123 ConnectionTracker::ConnectionListener::~ConnectionListener() = default;
124 
125 // Gets called from the EmbeddedTestServer thread to be notified that
126 // a connection was accepted.
127 std::unique_ptr<net::StreamSocket>
AcceptedSocket(std::unique_ptr<net::StreamSocket> connection)128 ConnectionTracker::ConnectionListener::AcceptedSocket(
129     std::unique_ptr<net::StreamSocket> connection) {
130   uint16_t port;
131   if (GetPort(*connection, &port)) {
132     task_runner_->PostTask(
133         FROM_HERE, base::BindOnce(&ConnectionTracker::AcceptedSocketWithPort,
134                                   base::Unretained(tracker_), port));
135   }
136   return connection;
137 }
138 
139 // Gets called from the EmbeddedTestServer thread to be notified that
140 // a connection was read from.
ReadFromSocket(const net::StreamSocket & connection,int rv)141 void ConnectionTracker::ConnectionListener::ReadFromSocket(
142     const net::StreamSocket& connection,
143     int rv) {
144   // Don't log a read if no data was transferred. This case often happens if
145   // the sockets of the test server are being flushed and disconnected.
146   if (rv <= 0)
147     return;
148   uint16_t port;
149   if (GetPort(connection, &port)) {
150     task_runner_->PostTask(
151         FROM_HERE, base::BindOnce(&ConnectionTracker::ReadFromSocketWithPort,
152                                   base::Unretained(tracker_), port));
153   }
154 }
155 
156 }  // namespace net::test_server
157