// Copyright 2019 The Chromium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
4
#include "discovery/mdns/mdns_responder.h"

#include <algorithm>
#include <iterator>
#include <utility>
#include <vector>

#include "discovery/common/config.h"
#include "discovery/mdns/mdns_probe_manager.h"
#include "discovery/mdns/mdns_random.h"
#include "discovery/mdns/mdns_receiver.h"
#include "discovery/mdns/mdns_records.h"
#include "discovery/mdns/mdns_sender.h"
#include "platform/test/fake_clock.h"
#include "platform/test/fake_task_runner.h"
#include "platform/test/fake_udp_socket.h"
18
19 namespace openscreen {
20 namespace discovery {
21 namespace {
22
23 constexpr Clock::duration kMaximumSharedRecordResponseDelayMs(120 * 1000);
24
ContainsRecordType(const std::vector<MdnsRecord> & records,DnsType type)25 bool ContainsRecordType(const std::vector<MdnsRecord>& records, DnsType type) {
26 return std::find_if(records.begin(), records.end(),
27 [type](const MdnsRecord& record) {
28 return record.dns_type() == type;
29 }) != records.end();
30 }
31
CheckSingleNsecRecordType(const MdnsMessage & message,DnsType type)32 void CheckSingleNsecRecordType(const MdnsMessage& message, DnsType type) {
33 ASSERT_EQ(message.answers().size(), size_t{1});
34 const MdnsRecord record = message.answers()[0];
35
36 ASSERT_EQ(record.dns_type(), DnsType::kNSEC);
37 const NsecRecordRdata& rdata = absl::get<NsecRecordRdata>(record.rdata());
38
39 ASSERT_EQ(rdata.types().size(), size_t{1});
40 EXPECT_EQ(rdata.types()[0], type);
41 }
42
CheckPtrDomain(const MdnsRecord & record,const DomainName & domain)43 void CheckPtrDomain(const MdnsRecord& record, const DomainName& domain) {
44 ASSERT_EQ(record.dns_type(), DnsType::kPTR);
45 const PtrRecordRdata& rdata = absl::get<PtrRecordRdata>(record.rdata());
46
47 EXPECT_EQ(rdata.ptr_domain(), domain);
48 }
49
ExpectContainsNsecRecordType(const std::vector<MdnsRecord> & records,DnsType type)50 void ExpectContainsNsecRecordType(const std::vector<MdnsRecord>& records,
51 DnsType type) {
52 auto it = std::find_if(
53 records.begin(), records.end(), [type](const MdnsRecord& record) {
54 if (record.dns_type() != DnsType::kNSEC) {
55 return false;
56 }
57
58 const NsecRecordRdata& rdata =
59 absl::get<NsecRecordRdata>(record.rdata());
60 return rdata.types().size() == 1 && rdata.types()[0] == type;
61 });
62 EXPECT_TRUE(it != records.end());
63 }
64
65 } // namespace
66
67 using testing::_;
68 using testing::Args;
69 using testing::Invoke;
70 using testing::Return;
71 using testing::StrictMock;
72
73 class MockRecordHandler : public MdnsResponder::RecordHandler {
74 public:
AddRecord(MdnsRecord record)75 void AddRecord(MdnsRecord record) { records_.push_back(record); }
76
77 MOCK_METHOD3(HasRecords, bool(const DomainName&, DnsType, DnsClass));
78
GetRecords(const DomainName & name,DnsType type,DnsClass clazz)79 std::vector<MdnsRecord::ConstRef> GetRecords(const DomainName& name,
80 DnsType type,
81 DnsClass clazz) override {
82 std::vector<MdnsRecord::ConstRef> records;
83 for (const auto& record : records_) {
84 if (type == DnsType::kANY || record.dns_type() == type) {
85 records.push_back(record);
86 }
87 }
88
89 return records;
90 }
91
GetPtrRecords(DnsClass clazz)92 std::vector<MdnsRecord::ConstRef> GetPtrRecords(DnsClass clazz) override {
93 std::vector<MdnsRecord::ConstRef> records;
94 for (const auto& record : records_) {
95 if (record.dns_type() == DnsType::kPTR) {
96 records.push_back(record);
97 }
98 }
99
100 return records;
101 }
102
103 private:
104 std::vector<MdnsRecord> records_;
105 };
106
107 class MockMdnsSender : public MdnsSender {
108 public:
MockMdnsSender(UdpSocket * socket)109 explicit MockMdnsSender(UdpSocket* socket) : MdnsSender(socket) {}
110
111 MOCK_METHOD1(SendMulticast, Error(const MdnsMessage& message));
112 MOCK_METHOD2(SendMessage,
113 Error(const MdnsMessage& message, const IPEndpoint& endpoint));
114 };
115
116 class MockProbeManager : public MdnsProbeManager {
117 public:
118 MOCK_CONST_METHOD1(IsDomainClaimed, bool(const DomainName&));
119 MOCK_METHOD2(RespondToProbeQuery,
120 void(const MdnsMessage&, const IPEndpoint&));
121 };
122
123 class MdnsResponderTest : public testing::Test {
124 public:
MdnsResponderTest()125 MdnsResponderTest()
126 : clock_(Clock::now()),
127 task_runner_(&clock_),
128 socket_(&task_runner_),
129 sender_(&socket_),
130 receiver_(config_),
131 responder_(&record_handler_,
132 &probe_manager_,
133 &sender_,
134 &receiver_,
135 &task_runner_,
136 FakeClock::now,
137 &random_,
138 config_) {}
139
140 protected:
GetFakePtrRecord(const DomainName & target)141 MdnsRecord GetFakePtrRecord(const DomainName& target) {
142 DomainName name(++target.labels().begin(), target.labels().end());
143 PtrRecordRdata rdata(target);
144 return MdnsRecord(std::move(name), DnsType::kPTR, DnsClass::kIN,
145 RecordType::kUnique, std::chrono::seconds(0), rdata);
146 }
147
GetFakeSrvRecord(const DomainName & name)148 MdnsRecord GetFakeSrvRecord(const DomainName& name) {
149 SrvRecordRdata rdata(0, 0, 80, name);
150 return MdnsRecord(name, DnsType::kSRV, DnsClass::kIN, RecordType::kUnique,
151 std::chrono::seconds(0), rdata);
152 }
153
GetFakeTxtRecord(const DomainName & name)154 MdnsRecord GetFakeTxtRecord(const DomainName& name) {
155 TxtRecordRdata rdata;
156 return MdnsRecord(name, DnsType::kTXT, DnsClass::kIN, RecordType::kUnique,
157 std::chrono::seconds(0), rdata);
158 }
159
GetFakeARecord(const DomainName & name)160 MdnsRecord GetFakeARecord(const DomainName& name) {
161 ARecordRdata rdata(IPAddress(192, 168, 0, 0));
162 return MdnsRecord(name, DnsType::kA, DnsClass::kIN, RecordType::kUnique,
163 std::chrono::seconds(0), rdata);
164 }
165
GetFakeAAAARecord(const DomainName & name)166 MdnsRecord GetFakeAAAARecord(const DomainName& name) {
167 AAAARecordRdata rdata(IPAddress(1, 2, 3, 4, 5, 6, 7, 8));
168 return MdnsRecord(name, DnsType::kAAAA, DnsClass::kIN, RecordType::kUnique,
169 std::chrono::seconds(0), rdata);
170 }
171
OnMessageReceived(const MdnsMessage & message,const IPEndpoint & src)172 void OnMessageReceived(const MdnsMessage& message, const IPEndpoint& src) {
173 responder_.OnMessageReceived(message, src);
174 }
175
QueryForRecordTypeWhenNonePresent(DnsType type)176 void QueryForRecordTypeWhenNonePresent(DnsType type) {
177 MdnsQuestion question(domain_, type, DnsClass::kANY,
178 ResponseType::kMulticast);
179 MdnsMessage message(0, MessageType::Query);
180 message.AddQuestion(question);
181
182 EXPECT_CALL(sender_, SendMulticast(_))
183 .WillOnce([type](const MdnsMessage& msg) -> Error {
184 CheckSingleNsecRecordType(msg, type);
185 return Error::None();
186 });
187 EXPECT_CALL(probe_manager_, IsDomainClaimed(_)).WillOnce(Return(true));
188 EXPECT_CALL(record_handler_, HasRecords(_, _, _))
189 .WillRepeatedly(Return(true));
190 OnMessageReceived(message, endpoint_);
191 }
192
CreateMulticastMdnsQuery(DnsType type)193 MdnsMessage CreateMulticastMdnsQuery(DnsType type) {
194 MdnsQuestion question(domain_, type, DnsClass::kANY,
195 ResponseType::kMulticast);
196 MdnsMessage message(0, MessageType::Query);
197 message.AddQuestion(std::move(question));
198
199 return message;
200 }
201
CreateTypeEnumerationQuery()202 MdnsMessage CreateTypeEnumerationQuery() {
203 MdnsQuestion question(type_enumeration_domain_, DnsType::kPTR,
204 DnsClass::kANY, ResponseType::kMulticast);
205 MdnsMessage message(0, MessageType::Query);
206 message.AddQuestion(std::move(question));
207
208 return message;
209 }
210
211 const Config config_;
212 FakeClock clock_;
213 FakeTaskRunner task_runner_;
214 FakeUdpSocket socket_;
215 StrictMock<MockMdnsSender> sender_;
216 StrictMock<MockRecordHandler> record_handler_;
217 StrictMock<MockProbeManager> probe_manager_;
218 MdnsReceiver receiver_;
219 MdnsRandom random_;
220 MdnsResponder responder_;
221
222 DomainName domain_{"instance", "_googlecast", "_tcp", "local"};
223 DomainName type_enumeration_domain_{"_services", "_dns-sd", "_udp", "local"};
224 IPEndpoint endpoint_{IPAddress(192, 168, 0, 0), 80};
225 };
226
// Validate that responses for records this host has claimed are sent
// immediately, while responses for unclaimed (possibly shared) records, which
// other responders may also send, are delayed.
TEST_F(MdnsResponderTest, OwnedRecordsSentImmediately) {
  // domain_ is claimed, so the answer goes out as soon as the query arrives.
  MdnsMessage query = CreateMulticastMdnsQuery(DnsType::kANY);
  record_handler_.AddRecord(GetFakeSrvRecord(domain_));

  EXPECT_CALL(probe_manager_, IsDomainClaimed(_)).WillOnce(Return(true));
  EXPECT_CALL(record_handler_, HasRecords(_, _, _))
      .WillRepeatedly(Return(true));
  EXPECT_CALL(sender_, SendMulticast(_)).Times(1);
  OnMessageReceived(query, endpoint_);

  testing::Mock::VerifyAndClearExpectations(&sender_);
  testing::Mock::VerifyAndClearExpectations(&record_handler_);
  testing::Mock::VerifyAndClearExpectations(&probe_manager_);

  // No deferred copy of the response may follow once the delay window passes.
  EXPECT_CALL(sender_, SendMulticast(_)).Times(0);
  clock_.Advance(kMaximumSharedRecordResponseDelayMs);
}
245
TEST_F(MdnsResponderTest, NonOwnedRecordsDelayed) {
  // domain_ is not claimed, so nothing is sent on receipt; the response only
  // goes out after the shared-record delay.
  MdnsMessage query = CreateMulticastMdnsQuery(DnsType::kANY);
  record_handler_.AddRecord(GetFakeSrvRecord(domain_));

  EXPECT_CALL(probe_manager_, IsDomainClaimed(_)).WillOnce(Return(false));
  EXPECT_CALL(record_handler_, HasRecords(_, _, _))
      .WillRepeatedly(Return(true));
  EXPECT_CALL(sender_, SendMulticast(_)).Times(0);
  OnMessageReceived(query, endpoint_);

  testing::Mock::VerifyAndClearExpectations(&sender_);
  testing::Mock::VerifyAndClearExpectations(&record_handler_);
  testing::Mock::VerifyAndClearExpectations(&probe_manager_);

  EXPECT_CALL(sender_, SendMulticast(_)).Times(1);
  clock_.Advance(kMaximumSharedRecordResponseDelayMs);
}
262
TEST_F(MdnsResponderTest, MultipleQuestionsProcessed) {
  // Two questions in one query. IsDomainClaimed answers true and then false
  // (presumably once per question), so one response is sent immediately and
  // one after the delay.
  MdnsMessage query = CreateMulticastMdnsQuery(DnsType::kANY);
  query.AddQuestion(MdnsQuestion(domain_, DnsType::kANY, DnsClass::kANY,
                                 ResponseType::kMulticast));
  record_handler_.AddRecord(GetFakeSrvRecord(domain_));

  EXPECT_CALL(probe_manager_, IsDomainClaimed(_))
      .WillOnce(Return(true))
      .WillOnce(Return(false));
  EXPECT_CALL(record_handler_, HasRecords(_, _, _))
      .WillRepeatedly(Return(true));
  EXPECT_CALL(sender_, SendMulticast(_)).Times(1);
  OnMessageReceived(query, endpoint_);

  testing::Mock::VerifyAndClearExpectations(&sender_);
  testing::Mock::VerifyAndClearExpectations(&record_handler_);
  testing::Mock::VerifyAndClearExpectations(&probe_manager_);

  EXPECT_CALL(sender_, SendMulticast(_)).Times(1);
  clock_.Advance(kMaximumSharedRecordResponseDelayMs);
}
284
// Validate that the messaging scheme requested by the question (unicast or
// multicast) is used for the response.
TEST_F(MdnsResponderTest, UnicastMessageSentOverUnicast) {
  MdnsMessage query(0, MessageType::Query);
  query.AddQuestion(MdnsQuestion(domain_, DnsType::kANY, DnsClass::kANY,
                                 ResponseType::kUnicast));
  record_handler_.AddRecord(GetFakeSrvRecord(domain_));

  EXPECT_CALL(probe_manager_, IsDomainClaimed(_)).WillOnce(Return(true));
  EXPECT_CALL(record_handler_, HasRecords(_, _, _))
      .WillRepeatedly(Return(true));
  // The response must go straight back to the querier's endpoint.
  EXPECT_CALL(sender_, SendMessage(_, endpoint_)).Times(1);
  OnMessageReceived(query, endpoint_);
}
299
TEST_F(MdnsResponderTest, MulticastMessageSentOverMulticast) {
  MdnsMessage query = CreateMulticastMdnsQuery(DnsType::kANY);
  record_handler_.AddRecord(GetFakeSrvRecord(domain_));

  EXPECT_CALL(probe_manager_, IsDomainClaimed(_)).WillOnce(Return(true));
  EXPECT_CALL(record_handler_, HasRecords(_, _, _))
      .WillRepeatedly(Return(true));
  // sender_ is a StrictMock, so a SendMessage call here would also fail.
  EXPECT_CALL(sender_, SendMulticast(_)).Times(1);
  OnMessageReceived(query, endpoint_);
}
310
// Validate that answer records match the query type, and that additional
// records are populated as specified in RFC 6762 and RFC 6763.
TEST_F(MdnsResponderTest, AnyQueryResultsAllApplied) {
  // A kANY query returns every stored record as an answer, with no additional
  // records.
  MdnsMessage query = CreateMulticastMdnsQuery(DnsType::kANY);
  record_handler_.AddRecord(GetFakeSrvRecord(domain_));
  record_handler_.AddRecord(GetFakeTxtRecord(domain_));
  record_handler_.AddRecord(GetFakeARecord(domain_));
  record_handler_.AddRecord(GetFakeAAAARecord(domain_));

  EXPECT_CALL(probe_manager_, IsDomainClaimed(_)).WillOnce(Return(true));
  EXPECT_CALL(record_handler_, HasRecords(_, _, _))
      .WillRepeatedly(Return(true));
  EXPECT_CALL(sender_, SendMulticast(_))
      .WillOnce([](const MdnsMessage& response) -> Error {
        EXPECT_EQ(response.questions().size(), size_t{0});
        EXPECT_EQ(response.authority_records().size(), size_t{0});
        EXPECT_EQ(response.additional_records().size(), size_t{0});

        const auto& answers = response.answers();
        EXPECT_EQ(answers.size(), size_t{4});
        EXPECT_TRUE(ContainsRecordType(answers, DnsType::kSRV));
        EXPECT_TRUE(ContainsRecordType(answers, DnsType::kTXT));
        EXPECT_TRUE(ContainsRecordType(answers, DnsType::kA));
        EXPECT_TRUE(ContainsRecordType(answers, DnsType::kAAAA));
        EXPECT_FALSE(ContainsRecordType(answers, DnsType::kPTR));
        return Error::None();
      });

  OnMessageReceived(query, endpoint_);
}
340
TEST_F(MdnsResponderTest, PtrQueryResultsApplied) {
  // The PTR question targets the service type (the PTR record's owner name),
  // not the instance name held in domain_.
  DomainName ptr_domain{"_googlecast", "_tcp", "local"};
  MdnsQuestion question(ptr_domain, DnsType::kPTR, DnsClass::kANY,
                        ResponseType::kMulticast);
  MdnsMessage message(0, MessageType::Query);
  message.AddQuestion(question);

  EXPECT_CALL(probe_manager_, IsDomainClaimed(_)).WillOnce(Return(true));
  EXPECT_CALL(record_handler_, HasRecords(_, _, _))
      .WillRepeatedly(Return(true));
  record_handler_.AddRecord(GetFakePtrRecord(domain_));
  record_handler_.AddRecord(GetFakeSrvRecord(domain_));
  record_handler_.AddRecord(GetFakeTxtRecord(domain_));
  record_handler_.AddRecord(GetFakeARecord(domain_));
  record_handler_.AddRecord(GetFakeAAAARecord(domain_));
  // The single PTR answer is accompanied by the instance's SRV, TXT, A and
  // AAAA records as additional records.
  EXPECT_CALL(sender_, SendMulticast(_))
      .WillOnce([](const MdnsMessage& message) -> Error {
        EXPECT_EQ(message.questions().size(), size_t{0});
        EXPECT_EQ(message.authority_records().size(), size_t{0});

        EXPECT_EQ(message.answers().size(), size_t{1});
        EXPECT_TRUE(ContainsRecordType(message.answers(), DnsType::kPTR));

        const auto& records = message.additional_records();
        EXPECT_EQ(records.size(), size_t{4});
        EXPECT_TRUE(ContainsRecordType(records, DnsType::kSRV));
        EXPECT_TRUE(ContainsRecordType(records, DnsType::kTXT));
        EXPECT_TRUE(ContainsRecordType(records, DnsType::kA));
        EXPECT_TRUE(ContainsRecordType(records, DnsType::kAAAA));
        EXPECT_FALSE(ContainsRecordType(records, DnsType::kPTR));

        return Error::None();
      });

  OnMessageReceived(message, endpoint_);
}
378
TEST_F(MdnsResponderTest, SrvQueryResultsApplied) {
  MdnsMessage message = CreateMulticastMdnsQuery(DnsType::kSRV);

  EXPECT_CALL(probe_manager_, IsDomainClaimed(_)).WillOnce(Return(true));
  EXPECT_CALL(record_handler_, HasRecords(_, _, _))
      .WillRepeatedly(Return(true));
  record_handler_.AddRecord(GetFakePtrRecord(domain_));
  record_handler_.AddRecord(GetFakeSrvRecord(domain_));
  record_handler_.AddRecord(GetFakeTxtRecord(domain_));
  record_handler_.AddRecord(GetFakeARecord(domain_));
  record_handler_.AddRecord(GetFakeAAAARecord(domain_));
  // The SRV answer is accompanied by A and AAAA additional records only; TXT
  // and PTR are excluded.
  EXPECT_CALL(sender_, SendMulticast(_))
      .WillOnce([](const MdnsMessage& message) -> Error {
        EXPECT_EQ(message.questions().size(), size_t{0});
        EXPECT_EQ(message.authority_records().size(), size_t{0});

        EXPECT_EQ(message.answers().size(), size_t{1});
        EXPECT_TRUE(ContainsRecordType(message.answers(), DnsType::kSRV));

        const auto& records = message.additional_records();
        EXPECT_EQ(records.size(), size_t{2});
        EXPECT_TRUE(ContainsRecordType(records, DnsType::kA));
        EXPECT_TRUE(ContainsRecordType(records, DnsType::kAAAA));
        EXPECT_FALSE(ContainsRecordType(records, DnsType::kSRV));
        EXPECT_FALSE(ContainsRecordType(records, DnsType::kTXT));
        EXPECT_FALSE(ContainsRecordType(records, DnsType::kPTR));

        return Error::None();
      });

  OnMessageReceived(message, endpoint_);
}
412
TEST_F(MdnsResponderTest, AQueryResultsApplied) {
  MdnsMessage message = CreateMulticastMdnsQuery(DnsType::kA);

  EXPECT_CALL(probe_manager_, IsDomainClaimed(_)).WillOnce(Return(true));
  EXPECT_CALL(record_handler_, HasRecords(_, _, _))
      .WillRepeatedly(Return(true));
  record_handler_.AddRecord(GetFakePtrRecord(domain_));
  record_handler_.AddRecord(GetFakeSrvRecord(domain_));
  record_handler_.AddRecord(GetFakeTxtRecord(domain_));
  record_handler_.AddRecord(GetFakeARecord(domain_));
  record_handler_.AddRecord(GetFakeAAAARecord(domain_));
  // The A answer is accompanied by the AAAA record as its only additional
  // record.
  EXPECT_CALL(sender_, SendMulticast(_))
      .WillOnce([](const MdnsMessage& message) -> Error {
        EXPECT_EQ(message.questions().size(), size_t{0});
        EXPECT_EQ(message.authority_records().size(), size_t{0});

        EXPECT_EQ(message.answers().size(), size_t{1});
        EXPECT_TRUE(ContainsRecordType(message.answers(), DnsType::kA));

        const auto& records = message.additional_records();
        EXPECT_EQ(records.size(), size_t{1});
        EXPECT_TRUE(ContainsRecordType(records, DnsType::kAAAA));
        EXPECT_FALSE(ContainsRecordType(records, DnsType::kA));
        EXPECT_FALSE(ContainsRecordType(records, DnsType::kSRV));
        EXPECT_FALSE(ContainsRecordType(records, DnsType::kTXT));
        EXPECT_FALSE(ContainsRecordType(records, DnsType::kPTR));

        return Error::None();
      });

  OnMessageReceived(message, endpoint_);
}
446
TEST_F(MdnsResponderTest, AAAAQueryResultsApplied) {
  MdnsMessage message = CreateMulticastMdnsQuery(DnsType::kAAAA);

  EXPECT_CALL(probe_manager_, IsDomainClaimed(_)).WillOnce(Return(true));
  EXPECT_CALL(record_handler_, HasRecords(_, _, _))
      .WillRepeatedly(Return(true));
  record_handler_.AddRecord(GetFakePtrRecord(domain_));
  record_handler_.AddRecord(GetFakeSrvRecord(domain_));
  record_handler_.AddRecord(GetFakeTxtRecord(domain_));
  record_handler_.AddRecord(GetFakeARecord(domain_));
  record_handler_.AddRecord(GetFakeAAAARecord(domain_));
  // The AAAA answer is accompanied by the A record as its only additional
  // record.
  EXPECT_CALL(sender_, SendMulticast(_))
      .WillOnce([](const MdnsMessage& message) -> Error {
        EXPECT_EQ(message.questions().size(), size_t{0});
        EXPECT_EQ(message.authority_records().size(), size_t{0});

        EXPECT_EQ(message.answers().size(), size_t{1});
        EXPECT_TRUE(ContainsRecordType(message.answers(), DnsType::kAAAA));

        const auto& records = message.additional_records();
        EXPECT_EQ(records.size(), size_t{1});
        EXPECT_TRUE(ContainsRecordType(records, DnsType::kA));
        EXPECT_FALSE(ContainsRecordType(records, DnsType::kAAAA));
        EXPECT_FALSE(ContainsRecordType(records, DnsType::kSRV));
        EXPECT_FALSE(ContainsRecordType(records, DnsType::kTXT));
        EXPECT_FALSE(ContainsRecordType(records, DnsType::kPTR));

        return Error::None();
      });

  OnMessageReceived(message, endpoint_);
}
480
TEST_F(MdnsResponderTest, MessageOnlySentIfAnswerNotKnown) {
  // The querier already lists the AAAA record as a known answer, so no
  // response is expected. sender_ is a StrictMock with no SendMulticast
  // expectation, so any send fails the test.
  MdnsRecord known_aaaa = GetFakeAAAARecord(domain_);
  MdnsMessage query = CreateMulticastMdnsQuery(DnsType::kAAAA);
  query.AddAnswer(known_aaaa);

  record_handler_.AddRecord(GetFakePtrRecord(domain_));
  record_handler_.AddRecord(GetFakeSrvRecord(domain_));
  record_handler_.AddRecord(GetFakeTxtRecord(domain_));
  record_handler_.AddRecord(GetFakeARecord(domain_));
  record_handler_.AddRecord(known_aaaa);

  EXPECT_CALL(probe_manager_, IsDomainClaimed(_)).WillOnce(Return(true));
  EXPECT_CALL(record_handler_, HasRecords(_, _, _))
      .WillRepeatedly(Return(true));

  OnMessageReceived(query, endpoint_);
}
497
TEST_F(MdnsResponderTest, RecordOnlySentIfNotKnown) {
  // AAAA is listed as a known answer, so only the A record should be sent.
  MdnsRecord known_aaaa = GetFakeAAAARecord(domain_);
  MdnsMessage query = CreateMulticastMdnsQuery(DnsType::kANY);
  query.AddAnswer(known_aaaa);

  record_handler_.AddRecord(GetFakeARecord(domain_));
  record_handler_.AddRecord(known_aaaa);

  EXPECT_CALL(probe_manager_, IsDomainClaimed(_)).WillOnce(Return(true));
  EXPECT_CALL(record_handler_, HasRecords(_, _, _))
      .WillRepeatedly(Return(true));
  EXPECT_CALL(sender_, SendMulticast(_))
      .WillOnce([](const MdnsMessage& response) -> Error {
        EXPECT_EQ(response.questions().size(), size_t{0});
        EXPECT_EQ(response.authority_records().size(), size_t{0});
        EXPECT_EQ(response.additional_records().size(), size_t{0});

        const auto& answers = response.answers();
        EXPECT_EQ(answers.size(), size_t{1});
        EXPECT_TRUE(ContainsRecordType(answers, DnsType::kA));
        return Error::None();
      });

  OnMessageReceived(query, endpoint_);
}
521
TEST_F(MdnsResponderTest, RecordOnlySentIfNotKnownMultiplePackets) {
  // A truncated query is followed by a second packet listing AAAA as a known
  // answer. After the wait, only A should be sent.
  MdnsMessage query = CreateMulticastMdnsQuery(DnsType::kANY);
  query.set_truncated();

  MdnsRecord known_aaaa = GetFakeAAAARecord(domain_);
  MdnsMessage continuation(1, MessageType::Query);
  continuation.AddAnswer(known_aaaa);

  // Both packets arrive before any expectation is set. With StrictMocks, this
  // checks that the truncated query is not processed on receipt.
  OnMessageReceived(query, endpoint_);
  OnMessageReceived(continuation, endpoint_);

  EXPECT_CALL(probe_manager_, IsDomainClaimed(_)).WillOnce(Return(true));
  EXPECT_CALL(record_handler_, HasRecords(_, _, _))
      .WillRepeatedly(Return(true));
  record_handler_.AddRecord(GetFakeARecord(domain_));
  record_handler_.AddRecord(known_aaaa);
  EXPECT_CALL(sender_, SendMulticast(_))
      .WillOnce([](const MdnsMessage& response) -> Error {
        EXPECT_EQ(response.questions().size(), size_t{0});
        EXPECT_EQ(response.authority_records().size(), size_t{0});
        EXPECT_EQ(response.additional_records().size(), size_t{0});

        const auto& answers = response.answers();
        EXPECT_EQ(answers.size(), size_t{1});
        EXPECT_TRUE(ContainsRecordType(answers, DnsType::kA));
        return Error::None();
      });
  clock_.Advance(std::chrono::seconds(1));
}
550
TEST_F(MdnsResponderTest, RecordOnlySentIfNotKnownMultiplePacketsOutOfOrder) {
  // The known-answer packets (AAAA, then A) arrive before the truncated query
  // itself. Both known records must still be suppressed, leaving only SRV.
  MdnsMessage message = CreateMulticastMdnsQuery(DnsType::kANY);
  message.set_truncated();

  MdnsMessage message2(2, MessageType::Query);
  MdnsRecord aaaa_record = GetFakeAAAARecord(domain_);
  message2.AddAnswer(aaaa_record);
  message2.set_truncated();

  MdnsMessage message3(3, MessageType::Query);
  MdnsRecord a_record = GetFakeARecord(domain_);
  message3.AddAnswer(a_record);

  OnMessageReceived(message2, endpoint_);
  OnMessageReceived(message3, endpoint_);
  OnMessageReceived(message, endpoint_);

  EXPECT_CALL(probe_manager_, IsDomainClaimed(_)).WillOnce(Return(true));
  EXPECT_CALL(record_handler_, HasRecords(_, _, _))
      .WillRepeatedly(Return(true));
  // Each record is stored once; the AAAA record was previously added twice
  // by mistake.
  record_handler_.AddRecord(a_record);
  record_handler_.AddRecord(aaaa_record);
  record_handler_.AddRecord(GetFakeSrvRecord(domain_));
  EXPECT_CALL(sender_, SendMulticast(_))
      .WillOnce([](const MdnsMessage& message) -> Error {
        EXPECT_EQ(message.questions().size(), size_t{0});
        EXPECT_EQ(message.authority_records().size(), size_t{0});
        EXPECT_EQ(message.additional_records().size(), size_t{0});

        EXPECT_EQ(message.answers().size(), size_t{1});
        EXPECT_TRUE(ContainsRecordType(message.answers(), DnsType::kSRV));
        return Error::None();
      });
  clock_.Advance(std::chrono::seconds(1));
}
587
TEST_F(MdnsResponderTest, RecordSentForMultiPacketsSuppressionIfMoreNotFound) {
  // The truncated query carries AAAA as a known answer and no follow-up packet
  // arrives. After the wait, the response is sent without AAAA.
  MdnsRecord known_aaaa = GetFakeAAAARecord(domain_);
  MdnsMessage query = CreateMulticastMdnsQuery(DnsType::kANY);
  query.AddAnswer(known_aaaa);
  query.set_truncated();

  OnMessageReceived(query, endpoint_);

  EXPECT_CALL(probe_manager_, IsDomainClaimed(_)).WillOnce(Return(true));
  EXPECT_CALL(record_handler_, HasRecords(_, _, _))
      .WillRepeatedly(Return(true));
  record_handler_.AddRecord(GetFakeARecord(domain_));
  record_handler_.AddRecord(known_aaaa);
  EXPECT_CALL(sender_, SendMulticast(_))
      .WillOnce([](const MdnsMessage& response) -> Error {
        EXPECT_EQ(response.questions().size(), size_t{0});
        EXPECT_EQ(response.authority_records().size(), size_t{0});
        EXPECT_EQ(response.additional_records().size(), size_t{0});

        const auto& answers = response.answers();
        EXPECT_EQ(answers.size(), size_t{1});
        EXPECT_TRUE(ContainsRecordType(answers, DnsType::kA));
        return Error::None();
      });
  clock_.Advance(std::chrono::seconds(1));
}
613
TEST_F(MdnsResponderTest, RecordNotSentForMultiPacketsSuppressionIfNoQuery) {
  // A packet with known answers but no question must never trigger any mock
  // call, even after the wait. All doubles are StrictMocks with no
  // expectations.
  MdnsRecord known_aaaa = GetFakeAAAARecord(domain_);
  MdnsMessage answers_only(1, MessageType::Query);
  answers_only.AddAnswer(known_aaaa);

  OnMessageReceived(answers_only, endpoint_);
  clock_.Advance(std::chrono::seconds(1));
}
622
// Validate that NSEC records are used to report record types that do not exist
// for the queried name.
TEST_F(MdnsResponderTest, QueryForRecordTypesWhenNonePresent) {
  // With no records stored, each query type is answered with a single NSEC
  // record for that type. Types are exercised in this fixed order.
  const DnsType kQueriedTypes[] = {DnsType::kANY, DnsType::kSRV, DnsType::kTXT,
                                   DnsType::kA, DnsType::kAAAA};
  for (const DnsType queried_type : kQueriedTypes) {
    QueryForRecordTypeWhenNonePresent(queried_type);
  }
}
631
TEST_F(MdnsResponderTest, AAAAQueryGiveANsec) {
  // No A record is stored, so the AAAA answer carries an NSEC record for kA
  // where the A additional record would otherwise be.
  MdnsMessage query = CreateMulticastMdnsQuery(DnsType::kAAAA);
  record_handler_.AddRecord(GetFakePtrRecord(domain_));
  record_handler_.AddRecord(GetFakeSrvRecord(domain_));
  record_handler_.AddRecord(GetFakeTxtRecord(domain_));
  record_handler_.AddRecord(GetFakeAAAARecord(domain_));

  EXPECT_CALL(probe_manager_, IsDomainClaimed(_)).WillOnce(Return(true));
  EXPECT_CALL(record_handler_, HasRecords(_, _, _))
      .WillRepeatedly(Return(true));
  EXPECT_CALL(sender_, SendMulticast(_))
      .WillOnce([](const MdnsMessage& response) -> Error {
        EXPECT_EQ(response.questions().size(), size_t{0});
        EXPECT_EQ(response.authority_records().size(), size_t{0});

        const auto& answers = response.answers();
        EXPECT_EQ(answers.size(), size_t{1});
        EXPECT_TRUE(ContainsRecordType(answers, DnsType::kAAAA));

        const auto& additional = response.additional_records();
        EXPECT_EQ(additional.size(), size_t{1});
        ExpectContainsNsecRecordType(additional, DnsType::kA);

        return Error::None();
      });

  OnMessageReceived(query, endpoint_);
}
658
TEST_F(MdnsResponderTest, AQueryGiveAAAANsec) {
  // No AAAA record is stored, so the A answer carries an NSEC record for kAAAA
  // where the AAAA additional record would otherwise be.
  MdnsMessage query = CreateMulticastMdnsQuery(DnsType::kA);
  record_handler_.AddRecord(GetFakePtrRecord(domain_));
  record_handler_.AddRecord(GetFakeSrvRecord(domain_));
  record_handler_.AddRecord(GetFakeTxtRecord(domain_));
  record_handler_.AddRecord(GetFakeARecord(domain_));

  EXPECT_CALL(probe_manager_, IsDomainClaimed(_)).WillOnce(Return(true));
  EXPECT_CALL(record_handler_, HasRecords(_, _, _))
      .WillRepeatedly(Return(true));
  EXPECT_CALL(sender_, SendMulticast(_))
      .WillOnce([](const MdnsMessage& response) -> Error {
        EXPECT_EQ(response.questions().size(), size_t{0});
        EXPECT_EQ(response.authority_records().size(), size_t{0});

        const auto& answers = response.answers();
        EXPECT_EQ(answers.size(), size_t{1});
        EXPECT_TRUE(ContainsRecordType(answers, DnsType::kA));

        const auto& additional = response.additional_records();
        EXPECT_EQ(additional.size(), size_t{1});
        ExpectContainsNsecRecordType(additional, DnsType::kAAAA);

        return Error::None();
      });

  OnMessageReceived(query, endpoint_);
}
686
TEST_F(MdnsResponderTest, SrvQueryGiveCorrectNsecForNoAOrAAAA) {
  // Neither address record is stored, so the SRV answer carries NSEC records
  // for both kA and kAAAA.
  MdnsMessage query = CreateMulticastMdnsQuery(DnsType::kSRV);
  record_handler_.AddRecord(GetFakePtrRecord(domain_));
  record_handler_.AddRecord(GetFakeSrvRecord(domain_));
  record_handler_.AddRecord(GetFakeTxtRecord(domain_));

  EXPECT_CALL(probe_manager_, IsDomainClaimed(_)).WillOnce(Return(true));
  EXPECT_CALL(record_handler_, HasRecords(_, _, _))
      .WillRepeatedly(Return(true));
  EXPECT_CALL(sender_, SendMulticast(_))
      .WillOnce([](const MdnsMessage& response) -> Error {
        EXPECT_EQ(response.questions().size(), size_t{0});
        EXPECT_EQ(response.authority_records().size(), size_t{0});

        const auto& answers = response.answers();
        EXPECT_EQ(answers.size(), size_t{1});
        EXPECT_TRUE(ContainsRecordType(answers, DnsType::kSRV));

        const auto& additional = response.additional_records();
        EXPECT_EQ(additional.size(), size_t{2});
        ExpectContainsNsecRecordType(additional, DnsType::kA);
        ExpectContainsNsecRecordType(additional, DnsType::kAAAA);

        return Error::None();
      });
  OnMessageReceived(query, endpoint_);
}
713
TEST_F(MdnsResponderTest, SrvQueryGiveCorrectNsec) {
  // Only the A record exists, so the SRV answer carries A plus an NSEC record
  // for kAAAA.
  MdnsMessage query = CreateMulticastMdnsQuery(DnsType::kSRV);
  record_handler_.AddRecord(GetFakePtrRecord(domain_));
  record_handler_.AddRecord(GetFakeSrvRecord(domain_));
  record_handler_.AddRecord(GetFakeTxtRecord(domain_));
  record_handler_.AddRecord(GetFakeARecord(domain_));

  EXPECT_CALL(probe_manager_, IsDomainClaimed(_)).WillOnce(Return(true));
  EXPECT_CALL(record_handler_, HasRecords(_, _, _))
      .WillRepeatedly(Return(true));
  EXPECT_CALL(sender_, SendMulticast(_))
      .WillOnce([](const MdnsMessage& response) -> Error {
        EXPECT_EQ(response.questions().size(), size_t{0});
        EXPECT_EQ(response.authority_records().size(), size_t{0});

        const auto& answers = response.answers();
        EXPECT_EQ(answers.size(), size_t{1});
        EXPECT_TRUE(ContainsRecordType(answers, DnsType::kSRV));

        const auto& additional = response.additional_records();
        EXPECT_EQ(additional.size(), size_t{2});
        EXPECT_TRUE(ContainsRecordType(additional, DnsType::kA));
        ExpectContainsNsecRecordType(additional, DnsType::kAAAA);

        return Error::None();
      });
  OnMessageReceived(query, endpoint_);
}
742
TEST_F(MdnsResponderTest, PtrQueryGiveCorrectNsecForNoPtrOrSrv) {
  // Only the PTR record exists, so the PTR answer carries NSEC records for the
  // missing TXT and SRV records.
  // NOTE(review): the test name says "NoPtrOrSrv", but the missing types
  // asserted here are TXT and SRV.
  MdnsMessage query = CreateMulticastMdnsQuery(DnsType::kPTR);
  record_handler_.AddRecord(GetFakePtrRecord(domain_));

  EXPECT_CALL(probe_manager_, IsDomainClaimed(_)).WillOnce(Return(true));
  EXPECT_CALL(record_handler_, HasRecords(_, _, _))
      .WillRepeatedly(Return(true));
  EXPECT_CALL(sender_, SendMulticast(_))
      .WillOnce([](const MdnsMessage& response) -> Error {
        EXPECT_EQ(response.questions().size(), size_t{0});
        EXPECT_EQ(response.authority_records().size(), size_t{0});

        const auto& answers = response.answers();
        EXPECT_EQ(answers.size(), size_t{1});
        EXPECT_TRUE(ContainsRecordType(answers, DnsType::kPTR));

        const auto& additional = response.additional_records();
        EXPECT_EQ(additional.size(), size_t{2});
        ExpectContainsNsecRecordType(additional, DnsType::kTXT);
        ExpectContainsNsecRecordType(additional, DnsType::kSRV);

        return Error::None();
      });
  OnMessageReceived(query, endpoint_);
}
768
// A PTR query where PTR and TXT records are published but no SRV record is.
// The answer is the PTR record; the TXT record is expected in the additional
// section alongside an NSEC record for the missing SRV type.
// NOTE(review): despite the name, the records published are PTR + TXT.
TEST_F(MdnsResponderTest, PtrQueryGiveCorrectNsecForOnlyPtr) {
  MdnsMessage message = CreateMulticastMdnsQuery(DnsType::kPTR);

  EXPECT_CALL(probe_manager_, IsDomainClaimed(_)).WillOnce(Return(true));
  EXPECT_CALL(record_handler_, HasRecords(_, _, _))
      .WillRepeatedly(Return(true));

  // PTR and TXT are published; SRV is intentionally absent.
  record_handler_.AddRecord(GetFakePtrRecord(domain_));
  record_handler_.AddRecord(GetFakeTxtRecord(domain_));

  EXPECT_CALL(sender_, SendMulticast(_))
      .WillOnce([](const MdnsMessage& response) -> Error {
        // The response is expected to have empty question/authority sections.
        EXPECT_EQ(response.questions().size(), size_t{0});
        EXPECT_EQ(response.authority_records().size(), size_t{0});

        const auto& answers = response.answers();
        EXPECT_EQ(answers.size(), size_t{1});
        EXPECT_TRUE(ContainsRecordType(answers, DnsType::kPTR));

        const auto& additional = response.additional_records();
        EXPECT_EQ(additional.size(), size_t{2});
        EXPECT_TRUE(ContainsRecordType(additional, DnsType::kTXT));
        ExpectContainsNsecRecordType(additional, DnsType::kSRV);

        return Error::None();
      });

  OnMessageReceived(message, endpoint_);
}
795
// A PTR query where PTR and SRV records are published but no TXT, A, or AAAA
// records are. The answer is the PTR record; the SRV record is expected in the
// additional section together with NSEC records for TXT, A and AAAA.
TEST_F(MdnsResponderTest, PtrQueryGiveCorrectNsecForOnlySrv) {
  MdnsMessage message = CreateMulticastMdnsQuery(DnsType::kPTR);

  EXPECT_CALL(probe_manager_, IsDomainClaimed(_)).WillOnce(Return(true));
  EXPECT_CALL(record_handler_, HasRecords(_, _, _))
      .WillRepeatedly(Return(true));

  // PTR and SRV are published; TXT and address records are absent.
  record_handler_.AddRecord(GetFakePtrRecord(domain_));
  record_handler_.AddRecord(GetFakeSrvRecord(domain_));

  EXPECT_CALL(sender_, SendMulticast(_))
      .WillOnce([](const MdnsMessage& response) -> Error {
        // The response is expected to have empty question/authority sections.
        EXPECT_EQ(response.questions().size(), size_t{0});
        EXPECT_EQ(response.authority_records().size(), size_t{0});

        const auto& answers = response.answers();
        EXPECT_EQ(answers.size(), size_t{1});
        EXPECT_TRUE(ContainsRecordType(answers, DnsType::kPTR));

        const auto& additional = response.additional_records();
        EXPECT_EQ(additional.size(), size_t{4});
        EXPECT_TRUE(ContainsRecordType(additional, DnsType::kSRV));
        ExpectContainsNsecRecordType(additional, DnsType::kTXT);
        ExpectContainsNsecRecordType(additional, DnsType::kA);
        ExpectContainsNsecRecordType(additional, DnsType::kAAAA);

        return Error::None();
      });

  OnMessageReceived(message, endpoint_);
}
825
// Verifies that a service type enumeration query is answered with a single
// PTR record owned by |type_enumeration_domain_| whose target is the name of
// the published PTR record. The send expectation is registered after the
// query is processed and is satisfied while the clock is advanced by
// kMaximumSharedRecordResponseDelayMs.
TEST_F(MdnsResponderTest, EnumerateAllQuery) {
  MdnsMessage message = CreateTypeEnumerationQuery();

  EXPECT_CALL(probe_manager_, IsDomainClaimed(_)).WillOnce(Return(false));
  EXPECT_CALL(record_handler_, HasRecords(_, _, _))
      .WillRepeatedly(Return(true));
  const auto ptr = GetFakePtrRecord(domain_);
  record_handler_.AddRecord(ptr);
  record_handler_.AddRecord(GetFakeSrvRecord(domain_));
  record_handler_.AddRecord(GetFakeTxtRecord(domain_));
  record_handler_.AddRecord(GetFakeARecord(domain_));
  OnMessageReceived(message, endpoint_);

  // |ptr| is captured by reference; this is safe because the callback runs
  // synchronously inside clock_.Advance() below, while |ptr| is still alive.
  EXPECT_CALL(sender_, SendMulticast(_))
      .WillOnce([this, &ptr](const MdnsMessage& response) -> Error {
        EXPECT_EQ(response.questions().size(), size_t{0});
        EXPECT_EQ(response.authority_records().size(), size_t{0});

        const auto& answers = response.answers();
        EXPECT_EQ(answers.size(), size_t{1});
        // ASSERT_* cannot be used in a lambda returning Error, so stop here
        // explicitly: indexing answers[0] on an empty vector would be UB and
        // crash the test binary instead of reporting the failure above.
        if (answers.empty()) {
          return Error::None();
        }

        EXPECT_TRUE(ContainsRecordType(answers, DnsType::kPTR));
        EXPECT_EQ(answers[0].name(), type_enumeration_domain_);
        CheckPtrDomain(answers[0], ptr.name());
        return Error::None();
      });

  // kMaximumSharedRecordResponseDelayMs is already a Clock::duration.
  clock_.Advance(kMaximumSharedRecordResponseDelayMs);
}
852
// Verifies that a service type enumeration query yields no multicast response
// when no PTR record has been published, even after the clock is advanced by
// kMaximumSharedRecordResponseDelayMs.
TEST_F(MdnsResponderTest, EnumerateAllQueryNoResults) {
  MdnsMessage message = CreateTypeEnumerationQuery();

  EXPECT_CALL(probe_manager_, IsDomainClaimed(_)).WillOnce(Return(false));
  EXPECT_CALL(record_handler_, HasRecords(_, _, _))
      .WillRepeatedly(Return(true));

  // Deliberately publish everything except a PTR record.
  record_handler_.AddRecord(GetFakeSrvRecord(domain_));
  record_handler_.AddRecord(GetFakeTxtRecord(domain_));
  record_handler_.AddRecord(GetFakeARecord(domain_));

  // State the "no results" outcome explicitly rather than depending on how
  // the mock treats uninteresting calls.
  EXPECT_CALL(sender_, SendMulticast(_)).Times(0);

  OnMessageReceived(message, endpoint_);
  // kMaximumSharedRecordResponseDelayMs is already a Clock::duration.
  clock_.Advance(kMaximumSharedRecordResponseDelayMs);
}
866
867 } // namespace discovery
868 } // namespace openscreen
869