1 // Copyright 2018 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4
5 #include "osp/impl/quic/quic_server.h"
6
7 #include <memory>
8
9 #include "gmock/gmock.h"
10 #include "gtest/gtest.h"
11 #include "osp/impl/quic/testing/fake_quic_connection_factory.h"
12 #include "osp/impl/quic/testing/quic_test_support.h"
13 #include "osp/public/network_metrics.h"
14 #include "osp/public/network_service_manager.h"
15 #include "osp/public/testing/message_demuxer_test_support.h"
16 #include "platform/base/error.h"
17 #include "platform/test/fake_clock.h"
18 #include "platform/test/fake_task_runner.h"
19
20 namespace openscreen {
21 namespace osp {
22 namespace {
23
24 using ::testing::_;
25 using ::testing::Invoke;
26 using ::testing::Test;
27
28 class MockConnectRequest final
29 : public ProtocolConnectionClient::ConnectionRequestCallback {
30 public:
31 ~MockConnectRequest() override = default;
32
OnConnectionOpened(uint64_t request_id,std::unique_ptr<ProtocolConnection> connection)33 void OnConnectionOpened(
34 uint64_t request_id,
35 std::unique_ptr<ProtocolConnection> connection) override {
36 OnConnectionOpenedMock();
37 }
38 MOCK_METHOD0(OnConnectionOpenedMock, void());
39 MOCK_METHOD1(OnConnectionFailed, void(uint64_t request_id));
40 };
41
42 class MockConnectionObserver final : public ProtocolConnection::Observer {
43 public:
44 ~MockConnectionObserver() override = default;
45
46 MOCK_METHOD1(OnConnectionClosed, void(const ProtocolConnection& connection));
47 };
48
49 class QuicServerTest : public Test {
50 public:
QuicServerTest()51 QuicServerTest() {
52 fake_clock_ = std::make_unique<FakeClock>(
53 Clock::time_point(std::chrono::milliseconds(1298424)));
54 task_runner_ = std::make_unique<FakeTaskRunner>(fake_clock_.get());
55 quic_bridge_ =
56 std::make_unique<FakeQuicBridge>(task_runner_.get(), FakeClock::now);
57 }
58
59 protected:
ExpectIncomingConnection()60 std::unique_ptr<ProtocolConnection> ExpectIncomingConnection() {
61 MockConnectRequest mock_connect_request;
62 NetworkServiceManager::Get()->GetProtocolConnectionClient()->Connect(
63 quic_bridge_->kReceiverEndpoint, &mock_connect_request);
64 std::unique_ptr<ProtocolConnection> stream;
65 EXPECT_CALL(mock_connect_request, OnConnectionOpenedMock());
66 EXPECT_CALL(quic_bridge_->mock_server_observer, OnIncomingConnectionMock(_))
67 .WillOnce(
68 Invoke([&stream](std::unique_ptr<ProtocolConnection>& connection) {
69 stream = std::move(connection);
70 }));
71 quic_bridge_->RunTasksUntilIdle();
72 return stream;
73 }
74
SetUp()75 void SetUp() override {
76 server_ = quic_bridge_->quic_server.get();
77 NetworkServiceManager::Create(nullptr, nullptr,
78 std::move(quic_bridge_->quic_client),
79 std::move(quic_bridge_->quic_server));
80 }
81
TearDown()82 void TearDown() override { NetworkServiceManager::Dispose(); }
83
SendTestMessage(ProtocolConnection * connection)84 void SendTestMessage(ProtocolConnection* connection) {
85 MockMessageCallback mock_message_callback;
86 MessageDemuxer::MessageWatch message_watch =
87 quic_bridge_->controller_demuxer->WatchMessageType(
88 0, msgs::Type::kPresentationConnectionMessage,
89 &mock_message_callback);
90
91 msgs::CborEncodeBuffer buffer;
92 msgs::PresentationConnectionMessage message;
93 message.connection_id = 7;
94 message.message.which = decltype(message.message.which)::kString;
95 new (&message.message.str) std::string("message from server");
96 ASSERT_TRUE(msgs::EncodePresentationConnectionMessage(message, &buffer));
97 connection->Write(buffer.data(), buffer.size());
98 connection->CloseWriteEnd();
99
100 ssize_t decode_result = 0;
101 msgs::PresentationConnectionMessage received_message;
102 EXPECT_CALL(mock_message_callback,
103 OnStreamMessage(
104 0, _, msgs::Type::kPresentationConnectionMessage, _, _, _))
105 .WillOnce(Invoke([&decode_result, &received_message](
106 uint64_t endpoint_id, uint64_t connection_id,
107 msgs::Type message_type, const uint8_t* buffer,
108 size_t buffer_size, Clock::time_point now) {
109 decode_result = msgs::DecodePresentationConnectionMessage(
110 buffer, buffer_size, &received_message);
111 if (decode_result < 0)
112 return ErrorOr<size_t>(Error::Code::kCborParsing);
113 return ErrorOr<size_t>(decode_result);
114 }));
115 quic_bridge_->RunTasksUntilIdle();
116
117 ASSERT_GT(decode_result, 0);
118 EXPECT_EQ(decode_result, static_cast<ssize_t>(buffer.size() - 1));
119 EXPECT_EQ(received_message.connection_id, message.connection_id);
120 ASSERT_EQ(received_message.message.which,
121 decltype(received_message.message.which)::kString);
122 EXPECT_EQ(received_message.message.str, message.message.str);
123 }
124
125 std::unique_ptr<FakeClock> fake_clock_;
126 std::unique_ptr<FakeTaskRunner> task_runner_;
127 std::unique_ptr<FakeQuicBridge> quic_bridge_;
128 QuicServer* server_;
129 };
130
131 } // namespace
132
// A client connection to the receiver endpoint is reported to the server, and a
// message written on the server-side connection reaches the controller demuxer.
TEST_F(QuicServerTest, Connect) {
  auto connection = ExpectIncomingConnection();
  ASSERT_TRUE(connection);

  ProtocolConnection* const raw_connection = connection.get();
  SendTestMessage(raw_connection);

  server_->Stop();
}
141
// Once a client has connected, the server can open another connection to that
// endpoint directly, and messages written on it are delivered.
TEST_F(QuicServerTest, OpenImmediate) {
  // Before any client has connected, opening a connection to endpoint id 1
  // must fail.
  EXPECT_FALSE(server_->CreateProtocolConnection(1));

  std::unique_ptr<ProtocolConnection> connection1 = ExpectIncomingConnection();
  ASSERT_TRUE(connection1);

  std::unique_ptr<ProtocolConnection> connection2 =
      server_->CreateProtocolConnection(connection1->endpoint_id());
  // SendTestMessage() dereferences its argument. Fail this test cleanly rather
  // than crash the test binary if the connection could not be created.
  ASSERT_TRUE(connection2);

  SendTestMessage(connection2.get());

  server_->Stop();
}
155
// Walks the server through Start/Stop/Suspend/Resume. It checks which
// transitions succeed, which are rejected, and which observer notifications
// fire. Stopping the server also closes an open connection.
TEST_F(QuicServerTest, States) {
  server_->Stop();
  EXPECT_CALL(quic_bridge_->mock_server_observer, OnRunning());
  EXPECT_TRUE(server_->Start());
  EXPECT_FALSE(server_->Start());

  // Declared before |connection| so that it outlives it: SetObserver() is given
  // a raw pointer, which must not dangle while the connection is destroyed.
  MockConnectionObserver mock_connection_observer;
  std::unique_ptr<ProtocolConnection> connection = ExpectIncomingConnection();
  ASSERT_TRUE(connection);
  connection->SetObserver(&mock_connection_observer);

  // Running -> stopped: open connections are reported closed.
  EXPECT_CALL(mock_connection_observer, OnConnectionClosed(_));
  EXPECT_CALL(quic_bridge_->mock_server_observer, OnStopped());
  EXPECT_TRUE(server_->Stop());
  EXPECT_FALSE(server_->Stop());

  // Stopped -> running.
  EXPECT_CALL(quic_bridge_->mock_server_observer, OnRunning());
  EXPECT_TRUE(server_->Start());

  // Running -> suspended; repeated Suspend() and Start() are rejected.
  EXPECT_CALL(quic_bridge_->mock_server_observer, OnSuspended());
  EXPECT_TRUE(server_->Suspend());
  EXPECT_FALSE(server_->Suspend());
  EXPECT_FALSE(server_->Start());

  // Suspended -> running; repeated Resume() and Start() are rejected.
  EXPECT_CALL(quic_bridge_->mock_server_observer, OnRunning());
  EXPECT_TRUE(server_->Resume());
  EXPECT_FALSE(server_->Resume());
  EXPECT_FALSE(server_->Start());

  // Suspended -> stopped is allowed.
  EXPECT_CALL(quic_bridge_->mock_server_observer, OnSuspended());
  EXPECT_TRUE(server_->Suspend());

  EXPECT_CALL(quic_bridge_->mock_server_observer, OnStopped());
  EXPECT_TRUE(server_->Stop());
}
191
// Per-endpoint request ids advance 1, 3, 5, ... They keep advancing after a
// connection is closed and dropped, and restart at 1 once the server is
// stopped.
TEST_F(QuicServerTest, RequestIds) {
  std::unique_ptr<ProtocolConnection> connection = ExpectIncomingConnection();
  ASSERT_TRUE(connection);

  const uint64_t endpoint_id = connection->endpoint_id();
  // Re-fetches the id source on every call, exactly like the inline form.
  auto next_request_id = [this, endpoint_id] {
    return server_->endpoint_request_ids()->GetNextRequestId(endpoint_id);
  };

  EXPECT_EQ(1u, next_request_id());
  EXPECT_EQ(3u, next_request_id());

  // Closing and releasing the connection does not reset the sequence.
  connection->CloseWriteEnd();
  connection.reset();
  quic_bridge_->RunTasksUntilIdle();
  EXPECT_EQ(5u, next_request_id());

  // Stopping the server does.
  server_->Stop();
  EXPECT_EQ(1u, next_request_id());
}
208
209 } // namespace osp
210 } // namespace openscreen
211