1 // Copyright 2021 The Chromium Authors
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4
5 #include "net/test/embedded_test_server/connection_tracker.h"
6
7 #include "base/containers/contains.h"
8 #include "base/run_loop.h"
9 #include "base/task/single_thread_task_runner.h"
10 #include "net/test/embedded_test_server/embedded_test_server.h"
11 #include "testing/gtest/include/gtest/gtest.h"
12
13 namespace {
14
GetPort(const net::StreamSocket & connection,uint16_t * port)15 bool GetPort(const net::StreamSocket& connection, uint16_t* port) {
16 // Gets the remote port of the peer, since the local port will always be
17 // the port the test server is listening on. This isn't strictly correct -
18 // it's possible for multiple peers to connect with the same remote port
19 // but different remote IPs - but the tests here assume that connections
20 // to the test server (running on localhost) will always come from
21 // localhost, and thus the peer port is all that's needed to distinguish
22 // two connections. This also would be problematic if the OS reused ports,
23 // but that's not something to worry about for these tests.
24 net::IPEndPoint address;
25 int result = connection.GetPeerAddress(&address);
26 if (result != net::OK)
27 return false;
28 *port = address.port();
29 return true;
30 }
31
32 } // namespace
33
34 namespace net::test_server {
35
ConnectionTracker(EmbeddedTestServer * test_server)36 ConnectionTracker::ConnectionTracker(EmbeddedTestServer* test_server)
37 : connection_listener_(this) {
38 test_server->SetConnectionListener(&connection_listener_);
39 }
40
// Defaulted: members are destroyed in the usual order.
// NOTE(review): the listener passed to SetConnectionListener() in the
// constructor is not unregistered here; presumably the test server must stop
// before this tracker is destroyed — verify with EmbeddedTestServer's contract.
ConnectionTracker::~ConnectionTracker() = default;
42
AcceptedSocketWithPort(uint16_t port)43 void ConnectionTracker::AcceptedSocketWithPort(uint16_t port) {
44 num_connected_sockets_++;
45 sockets_[port] = SocketStatus::kAccepted;
46 CheckAccepted();
47 }
48
ReadFromSocketWithPort(uint16_t port)49 void ConnectionTracker::ReadFromSocketWithPort(uint16_t port) {
50 EXPECT_TRUE(base::Contains(sockets_, port));
51 if (sockets_[port] == SocketStatus::kAccepted)
52 num_read_sockets_++;
53 sockets_[port] = SocketStatus::kReadFrom;
54 if (read_loop_) {
55 read_loop_->Quit();
56 read_loop_ = nullptr;
57 }
58 }
59
60 // Returns the number of sockets that were accepted by the server.
GetAcceptedSocketCount() const61 size_t ConnectionTracker::GetAcceptedSocketCount() const {
62 return num_connected_sockets_;
63 }
64
65 // Returns the number of sockets that were read from by the server.
GetReadSocketCount() const66 size_t ConnectionTracker::GetReadSocketCount() const {
67 return num_read_sockets_;
68 }
69
WaitUntilConnectionRead()70 void ConnectionTracker::WaitUntilConnectionRead() {
71 base::RunLoop run_loop;
72 read_loop_ = &run_loop;
73 read_loop_->Run();
74 }
75
76 // This will wait for exactly |num_connections| items in |sockets_|. This method
77 // expects the server will not accept more than |num_connections| connections.
78 // |num_connections| must be greater than 0.
WaitForAcceptedConnections(size_t num_connections)79 void ConnectionTracker::WaitForAcceptedConnections(size_t num_connections) {
80 DCHECK(!num_accepted_connections_loop_);
81 DCHECK_GT(num_connections, 0u);
82 base::RunLoop run_loop;
83 EXPECT_GE(num_connections, num_connected_sockets_);
84 num_accepted_connections_loop_ = &run_loop;
85 num_accepted_connections_needed_ = num_connections;
86 CheckAccepted();
87 // Note that the previous call to CheckAccepted can quit this run loop
88 // before this call, which will make this call a no-op.
89 run_loop.Run();
90 EXPECT_EQ(num_connections, num_connected_sockets_);
91 }
92
93 // Helper function to stop the waiting for sockets to be accepted for
94 // WaitForAcceptedConnections. |num_accepted_connections_loop_| spins
95 // until |num_accepted_connections_needed_| sockets are accepted by the test
96 // server. The values will be null/0 if the loop is not running.
CheckAccepted()97 void ConnectionTracker::CheckAccepted() {
98 // |num_accepted_connections_loop_| null implies
99 // |num_accepted_connections_needed_| == 0.
100 DCHECK(num_accepted_connections_loop_ ||
101 num_accepted_connections_needed_ == 0);
102 if (!num_accepted_connections_loop_ ||
103 num_accepted_connections_needed_ != num_connected_sockets_) {
104 return;
105 }
106
107 num_accepted_connections_loop_->Quit();
108 num_accepted_connections_needed_ = 0;
109 num_accepted_connections_loop_ = nullptr;
110 }
111
ResetCounts()112 void ConnectionTracker::ResetCounts() {
113 sockets_.clear();
114 num_connected_sockets_ = 0;
115 num_read_sockets_ = 0;
116 }
117
// Captures the current thread's default task runner. AcceptedSocket() and
// ReadFromSocket() post their notifications to it, so |tracker| is only
// called on the thread that created this listener.
ConnectionTracker::ConnectionListener::ConnectionListener(
    ConnectionTracker* tracker)
    : task_runner_(base::SingleThreadTaskRunner::GetCurrentDefault()),
      tracker_(tracker) {}
122
// Defaulted: the listener owns nothing that needs explicit teardown.
ConnectionTracker::ConnectionListener::~ConnectionListener() = default;
124
125 // Gets called from the EmbeddedTestServer thread to be notified that
126 // a connection was accepted.
127 std::unique_ptr<net::StreamSocket>
AcceptedSocket(std::unique_ptr<net::StreamSocket> connection)128 ConnectionTracker::ConnectionListener::AcceptedSocket(
129 std::unique_ptr<net::StreamSocket> connection) {
130 uint16_t port;
131 if (GetPort(*connection, &port)) {
132 task_runner_->PostTask(
133 FROM_HERE, base::BindOnce(&ConnectionTracker::AcceptedSocketWithPort,
134 base::Unretained(tracker_), port));
135 }
136 return connection;
137 }
138
139 // Gets called from the EmbeddedTestServer thread to be notified that
140 // a connection was read from.
ReadFromSocket(const net::StreamSocket & connection,int rv)141 void ConnectionTracker::ConnectionListener::ReadFromSocket(
142 const net::StreamSocket& connection,
143 int rv) {
144 // Don't log a read if no data was transferred. This case often happens if
145 // the sockets of the test server are being flushed and disconnected.
146 if (rv <= 0)
147 return;
148 uint16_t port;
149 if (GetPort(connection, &port)) {
150 task_runner_->PostTask(
151 FROM_HERE, base::BindOnce(&ConnectionTracker::ReadFromSocketWithPort,
152 base::Unretained(tracker_), port));
153 }
154 }
155
156 } // namespace net::test_server
157