xref: /aosp_15_r20/external/openscreen/discovery/mdns/mdns_responder_unittest.cc (revision 3f982cf4871df8771c9d4abe6e9a6f8d829b2736)
1 // Copyright 2019 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4 
5 #include "discovery/mdns/mdns_responder.h"
6 
7 #include <utility>
8 
9 #include "discovery/common/config.h"
10 #include "discovery/mdns/mdns_probe_manager.h"
11 #include "discovery/mdns/mdns_random.h"
12 #include "discovery/mdns/mdns_receiver.h"
13 #include "discovery/mdns/mdns_records.h"
14 #include "discovery/mdns/mdns_sender.h"
15 #include "platform/test/fake_clock.h"
16 #include "platform/test/fake_task_runner.h"
17 #include "platform/test/fake_udp_socket.h"
18 
19 namespace openscreen {
20 namespace discovery {
21 namespace {
22 
23 constexpr Clock::duration kMaximumSharedRecordResponseDelayMs(120 * 1000);
24 
ContainsRecordType(const std::vector<MdnsRecord> & records,DnsType type)25 bool ContainsRecordType(const std::vector<MdnsRecord>& records, DnsType type) {
26   return std::find_if(records.begin(), records.end(),
27                       [type](const MdnsRecord& record) {
28                         return record.dns_type() == type;
29                       }) != records.end();
30 }
31 
CheckSingleNsecRecordType(const MdnsMessage & message,DnsType type)32 void CheckSingleNsecRecordType(const MdnsMessage& message, DnsType type) {
33   ASSERT_EQ(message.answers().size(), size_t{1});
34   const MdnsRecord record = message.answers()[0];
35 
36   ASSERT_EQ(record.dns_type(), DnsType::kNSEC);
37   const NsecRecordRdata& rdata = absl::get<NsecRecordRdata>(record.rdata());
38 
39   ASSERT_EQ(rdata.types().size(), size_t{1});
40   EXPECT_EQ(rdata.types()[0], type);
41 }
42 
CheckPtrDomain(const MdnsRecord & record,const DomainName & domain)43 void CheckPtrDomain(const MdnsRecord& record, const DomainName& domain) {
44   ASSERT_EQ(record.dns_type(), DnsType::kPTR);
45   const PtrRecordRdata& rdata = absl::get<PtrRecordRdata>(record.rdata());
46 
47   EXPECT_EQ(rdata.ptr_domain(), domain);
48 }
49 
ExpectContainsNsecRecordType(const std::vector<MdnsRecord> & records,DnsType type)50 void ExpectContainsNsecRecordType(const std::vector<MdnsRecord>& records,
51                                   DnsType type) {
52   auto it = std::find_if(
53       records.begin(), records.end(), [type](const MdnsRecord& record) {
54         if (record.dns_type() != DnsType::kNSEC) {
55           return false;
56         }
57 
58         const NsecRecordRdata& rdata =
59             absl::get<NsecRecordRdata>(record.rdata());
60         return rdata.types().size() == 1 && rdata.types()[0] == type;
61       });
62   EXPECT_TRUE(it != records.end());
63 }
64 
65 }  // namespace
66 
67 using testing::_;
68 using testing::Args;
69 using testing::Invoke;
70 using testing::Return;
71 using testing::StrictMock;
72 
73 class MockRecordHandler : public MdnsResponder::RecordHandler {
74  public:
AddRecord(MdnsRecord record)75   void AddRecord(MdnsRecord record) { records_.push_back(record); }
76 
77   MOCK_METHOD3(HasRecords, bool(const DomainName&, DnsType, DnsClass));
78 
GetRecords(const DomainName & name,DnsType type,DnsClass clazz)79   std::vector<MdnsRecord::ConstRef> GetRecords(const DomainName& name,
80                                                DnsType type,
81                                                DnsClass clazz) override {
82     std::vector<MdnsRecord::ConstRef> records;
83     for (const auto& record : records_) {
84       if (type == DnsType::kANY || record.dns_type() == type) {
85         records.push_back(record);
86       }
87     }
88 
89     return records;
90   }
91 
GetPtrRecords(DnsClass clazz)92   std::vector<MdnsRecord::ConstRef> GetPtrRecords(DnsClass clazz) override {
93     std::vector<MdnsRecord::ConstRef> records;
94     for (const auto& record : records_) {
95       if (record.dns_type() == DnsType::kPTR) {
96         records.push_back(record);
97       }
98     }
99 
100     return records;
101   }
102 
103  private:
104   std::vector<MdnsRecord> records_;
105 };
106 
107 class MockMdnsSender : public MdnsSender {
108  public:
MockMdnsSender(UdpSocket * socket)109   explicit MockMdnsSender(UdpSocket* socket) : MdnsSender(socket) {}
110 
111   MOCK_METHOD1(SendMulticast, Error(const MdnsMessage& message));
112   MOCK_METHOD2(SendMessage,
113                Error(const MdnsMessage& message, const IPEndpoint& endpoint));
114 };
115 
116 class MockProbeManager : public MdnsProbeManager {
117  public:
118   MOCK_CONST_METHOD1(IsDomainClaimed, bool(const DomainName&));
119   MOCK_METHOD2(RespondToProbeQuery,
120                void(const MdnsMessage&, const IPEndpoint&));
121 };
122 
123 class MdnsResponderTest : public testing::Test {
124  public:
MdnsResponderTest()125   MdnsResponderTest()
126       : clock_(Clock::now()),
127         task_runner_(&clock_),
128         socket_(&task_runner_),
129         sender_(&socket_),
130         receiver_(config_),
131         responder_(&record_handler_,
132                    &probe_manager_,
133                    &sender_,
134                    &receiver_,
135                    &task_runner_,
136                    FakeClock::now,
137                    &random_,
138                    config_) {}
139 
140  protected:
GetFakePtrRecord(const DomainName & target)141   MdnsRecord GetFakePtrRecord(const DomainName& target) {
142     DomainName name(++target.labels().begin(), target.labels().end());
143     PtrRecordRdata rdata(target);
144     return MdnsRecord(std::move(name), DnsType::kPTR, DnsClass::kIN,
145                       RecordType::kUnique, std::chrono::seconds(0), rdata);
146   }
147 
GetFakeSrvRecord(const DomainName & name)148   MdnsRecord GetFakeSrvRecord(const DomainName& name) {
149     SrvRecordRdata rdata(0, 0, 80, name);
150     return MdnsRecord(name, DnsType::kSRV, DnsClass::kIN, RecordType::kUnique,
151                       std::chrono::seconds(0), rdata);
152   }
153 
GetFakeTxtRecord(const DomainName & name)154   MdnsRecord GetFakeTxtRecord(const DomainName& name) {
155     TxtRecordRdata rdata;
156     return MdnsRecord(name, DnsType::kTXT, DnsClass::kIN, RecordType::kUnique,
157                       std::chrono::seconds(0), rdata);
158   }
159 
GetFakeARecord(const DomainName & name)160   MdnsRecord GetFakeARecord(const DomainName& name) {
161     ARecordRdata rdata(IPAddress(192, 168, 0, 0));
162     return MdnsRecord(name, DnsType::kA, DnsClass::kIN, RecordType::kUnique,
163                       std::chrono::seconds(0), rdata);
164   }
165 
GetFakeAAAARecord(const DomainName & name)166   MdnsRecord GetFakeAAAARecord(const DomainName& name) {
167     AAAARecordRdata rdata(IPAddress(1, 2, 3, 4, 5, 6, 7, 8));
168     return MdnsRecord(name, DnsType::kAAAA, DnsClass::kIN, RecordType::kUnique,
169                       std::chrono::seconds(0), rdata);
170   }
171 
OnMessageReceived(const MdnsMessage & message,const IPEndpoint & src)172   void OnMessageReceived(const MdnsMessage& message, const IPEndpoint& src) {
173     responder_.OnMessageReceived(message, src);
174   }
175 
QueryForRecordTypeWhenNonePresent(DnsType type)176   void QueryForRecordTypeWhenNonePresent(DnsType type) {
177     MdnsQuestion question(domain_, type, DnsClass::kANY,
178                           ResponseType::kMulticast);
179     MdnsMessage message(0, MessageType::Query);
180     message.AddQuestion(question);
181 
182     EXPECT_CALL(sender_, SendMulticast(_))
183         .WillOnce([type](const MdnsMessage& msg) -> Error {
184           CheckSingleNsecRecordType(msg, type);
185           return Error::None();
186         });
187     EXPECT_CALL(probe_manager_, IsDomainClaimed(_)).WillOnce(Return(true));
188     EXPECT_CALL(record_handler_, HasRecords(_, _, _))
189         .WillRepeatedly(Return(true));
190     OnMessageReceived(message, endpoint_);
191   }
192 
CreateMulticastMdnsQuery(DnsType type)193   MdnsMessage CreateMulticastMdnsQuery(DnsType type) {
194     MdnsQuestion question(domain_, type, DnsClass::kANY,
195                           ResponseType::kMulticast);
196     MdnsMessage message(0, MessageType::Query);
197     message.AddQuestion(std::move(question));
198 
199     return message;
200   }
201 
CreateTypeEnumerationQuery()202   MdnsMessage CreateTypeEnumerationQuery() {
203     MdnsQuestion question(type_enumeration_domain_, DnsType::kPTR,
204                           DnsClass::kANY, ResponseType::kMulticast);
205     MdnsMessage message(0, MessageType::Query);
206     message.AddQuestion(std::move(question));
207 
208     return message;
209   }
210 
211   const Config config_;
212   FakeClock clock_;
213   FakeTaskRunner task_runner_;
214   FakeUdpSocket socket_;
215   StrictMock<MockMdnsSender> sender_;
216   StrictMock<MockRecordHandler> record_handler_;
217   StrictMock<MockProbeManager> probe_manager_;
218   MdnsReceiver receiver_;
219   MdnsRandom random_;
220   MdnsResponder responder_;
221 
222   DomainName domain_{"instance", "_googlecast", "_tcp", "local"};
223   DomainName type_enumeration_domain_{"_services", "_dns-sd", "_udp", "local"};
224   IPEndpoint endpoint_{IPAddress(192, 168, 0, 0), 80};
225 };
226 
227 // Validate that when records may be sent from multiple receivers, the broadcast
228 // is delayed and it is not delayed otherwise.
TEST_F(MdnsResponderTest, OwnedRecordsSentImmediately) {
  record_handler_.AddRecord(GetFakeSrvRecord(domain_));
  const MdnsMessage query = CreateMulticastMdnsQuery(DnsType::kANY);

  // The domain is reported as claimed, so the single multicast reply has to be
  // sent while the query is being handled.
  EXPECT_CALL(record_handler_, HasRecords(_, _, _))
      .WillRepeatedly(Return(true));
  EXPECT_CALL(probe_manager_, IsDomainClaimed(_)).WillOnce(Return(true));
  EXPECT_CALL(sender_, SendMulticast(_)).Times(1);
  OnMessageReceived(query, endpoint_);
  testing::Mock::VerifyAndClearExpectations(&probe_manager_);
  testing::Mock::VerifyAndClearExpectations(&record_handler_);
  testing::Mock::VerifyAndClearExpectations(&sender_);

  // Advancing the clock afterwards must not trigger another send.
  EXPECT_CALL(sender_, SendMulticast(_)).Times(0);
  clock_.Advance(kMaximumSharedRecordResponseDelayMs);
}
245 
TEST_F(MdnsResponderTest, NonOwnedRecordsDelayed) {
  record_handler_.AddRecord(GetFakeSrvRecord(domain_));
  const MdnsMessage query = CreateMulticastMdnsQuery(DnsType::kANY);

  // The domain is reported as not claimed, so nothing may be sent while the
  // query itself is being handled.
  EXPECT_CALL(record_handler_, HasRecords(_, _, _))
      .WillRepeatedly(Return(true));
  EXPECT_CALL(probe_manager_, IsDomainClaimed(_)).WillOnce(Return(false));
  EXPECT_CALL(sender_, SendMulticast(_)).Times(0);
  OnMessageReceived(query, endpoint_);
  testing::Mock::VerifyAndClearExpectations(&probe_manager_);
  testing::Mock::VerifyAndClearExpectations(&record_handler_);
  testing::Mock::VerifyAndClearExpectations(&sender_);

  // The reply is expected once the clock has been advanced.
  EXPECT_CALL(sender_, SendMulticast(_)).Times(1);
  clock_.Advance(kMaximumSharedRecordResponseDelayMs);
}
262 
TEST_F(MdnsResponderTest, MultipleQuestionsProcessed) {
  record_handler_.AddRecord(GetFakeSrvRecord(domain_));

  // Two identical questions; the first is reported as claimed, the second as
  // not claimed.
  MdnsMessage query = CreateMulticastMdnsQuery(DnsType::kANY);
  query.AddQuestion(MdnsQuestion(domain_, DnsType::kANY, DnsClass::kANY,
                                 ResponseType::kMulticast));

  EXPECT_CALL(record_handler_, HasRecords(_, _, _))
      .WillRepeatedly(Return(true));
  EXPECT_CALL(probe_manager_, IsDomainClaimed(_))
      .WillOnce(Return(true))
      .WillOnce(Return(false));

  // One reply is expected while the query is handled...
  EXPECT_CALL(sender_, SendMulticast(_)).Times(1);
  OnMessageReceived(query, endpoint_);
  testing::Mock::VerifyAndClearExpectations(&probe_manager_);
  testing::Mock::VerifyAndClearExpectations(&record_handler_);
  testing::Mock::VerifyAndClearExpectations(&sender_);

  // ...and one more after the clock has been advanced.
  EXPECT_CALL(sender_, SendMulticast(_)).Times(1);
  clock_.Advance(kMaximumSharedRecordResponseDelayMs);
}
284 
285 // Validate that the correct messaging scheme (unicast vs multicast) is used.
TEST_F(MdnsResponderTest, UnicastMessageSentOverUnicast) {
  record_handler_.AddRecord(GetFakeSrvRecord(domain_));

  MdnsMessage query(0, MessageType::Query);
  query.AddQuestion(MdnsQuestion(domain_, DnsType::kANY, DnsClass::kANY,
                                 ResponseType::kUnicast));

  EXPECT_CALL(record_handler_, HasRecords(_, _, _))
      .WillRepeatedly(Return(true));
  EXPECT_CALL(probe_manager_, IsDomainClaimed(_)).WillOnce(Return(true));
  // The reply must be addressed to the endpoint the query came from.
  EXPECT_CALL(sender_, SendMessage(_, endpoint_)).Times(1);
  OnMessageReceived(query, endpoint_);
}
299 
TEST_F(MdnsResponderTest, MulticastMessageSentOverMulticast) {
  record_handler_.AddRecord(GetFakeSrvRecord(domain_));
  const MdnsMessage query = CreateMulticastMdnsQuery(DnsType::kANY);

  EXPECT_CALL(record_handler_, HasRecords(_, _, _))
      .WillRepeatedly(Return(true));
  EXPECT_CALL(probe_manager_, IsDomainClaimed(_)).WillOnce(Return(true));
  // A multicast question must be answered through SendMulticast().
  EXPECT_CALL(sender_, SendMulticast(_)).Times(1);
  OnMessageReceived(query, endpoint_);
}
310 
311 // Validate that records are added as expected based on the query type, and that
312 // additional records are populated as specified in RFC 6762 and 6763.
TEST_F(MdnsResponderTest, AnyQueryResultsAllApplied) {
  record_handler_.AddRecord(GetFakeSrvRecord(domain_));
  record_handler_.AddRecord(GetFakeTxtRecord(domain_));
  record_handler_.AddRecord(GetFakeARecord(domain_));
  record_handler_.AddRecord(GetFakeAAAARecord(domain_));
  const MdnsMessage query = CreateMulticastMdnsQuery(DnsType::kANY);

  // All four registered records must come back as answers and nothing else.
  const auto verify_response = [](const MdnsMessage& response) -> Error {
    EXPECT_EQ(response.questions().size(), 0u);
    EXPECT_EQ(response.authority_records().size(), 0u);
    EXPECT_EQ(response.additional_records().size(), 0u);

    const auto& answers = response.answers();
    EXPECT_EQ(answers.size(), 4u);
    EXPECT_TRUE(ContainsRecordType(answers, DnsType::kSRV));
    EXPECT_TRUE(ContainsRecordType(answers, DnsType::kTXT));
    EXPECT_TRUE(ContainsRecordType(answers, DnsType::kA));
    EXPECT_TRUE(ContainsRecordType(answers, DnsType::kAAAA));
    EXPECT_FALSE(ContainsRecordType(answers, DnsType::kPTR));
    return Error::None();
  };

  EXPECT_CALL(record_handler_, HasRecords(_, _, _))
      .WillRepeatedly(Return(true));
  EXPECT_CALL(probe_manager_, IsDomainClaimed(_)).WillOnce(Return(true));
  EXPECT_CALL(sender_, SendMulticast(_)).WillOnce(verify_response);

  OnMessageReceived(query, endpoint_);
}
340 
TEST_F(MdnsResponderTest, PtrQueryResultsApplied) {
  record_handler_.AddRecord(GetFakePtrRecord(domain_));
  record_handler_.AddRecord(GetFakeSrvRecord(domain_));
  record_handler_.AddRecord(GetFakeTxtRecord(domain_));
  record_handler_.AddRecord(GetFakeARecord(domain_));
  record_handler_.AddRecord(GetFakeAAAARecord(domain_));

  const DomainName service_type{"_googlecast", "_tcp", "local"};
  MdnsMessage query(0, MessageType::Query);
  query.AddQuestion(MdnsQuestion(service_type, DnsType::kPTR, DnsClass::kANY,
                                 ResponseType::kMulticast));

  // The PTR record is the only answer; SRV, TXT, A and AAAA are expected as
  // additional records.
  const auto verify_response = [](const MdnsMessage& response) -> Error {
    EXPECT_EQ(response.questions().size(), 0u);
    EXPECT_EQ(response.authority_records().size(), 0u);
    EXPECT_EQ(response.additional_records().size(), 4u);

    EXPECT_EQ(response.answers().size(), 1u);
    EXPECT_TRUE(ContainsRecordType(response.answers(), DnsType::kPTR));

    const auto& additional = response.additional_records();
    EXPECT_EQ(additional.size(), 4u);
    EXPECT_TRUE(ContainsRecordType(additional, DnsType::kSRV));
    EXPECT_TRUE(ContainsRecordType(additional, DnsType::kTXT));
    EXPECT_TRUE(ContainsRecordType(additional, DnsType::kA));
    EXPECT_TRUE(ContainsRecordType(additional, DnsType::kAAAA));
    EXPECT_FALSE(ContainsRecordType(additional, DnsType::kPTR));

    return Error::None();
  };

  EXPECT_CALL(record_handler_, HasRecords(_, _, _))
      .WillRepeatedly(Return(true));
  EXPECT_CALL(probe_manager_, IsDomainClaimed(_)).WillOnce(Return(true));
  EXPECT_CALL(sender_, SendMulticast(_)).WillOnce(verify_response);

  OnMessageReceived(query, endpoint_);
}
378 
TEST_F(MdnsResponderTest, SrvQueryResultsApplied) {
  record_handler_.AddRecord(GetFakePtrRecord(domain_));
  record_handler_.AddRecord(GetFakeSrvRecord(domain_));
  record_handler_.AddRecord(GetFakeTxtRecord(domain_));
  record_handler_.AddRecord(GetFakeARecord(domain_));
  record_handler_.AddRecord(GetFakeAAAARecord(domain_));
  const MdnsMessage query = CreateMulticastMdnsQuery(DnsType::kSRV);

  // The SRV record is the only answer; A and AAAA are expected as additional
  // records.
  const auto verify_response = [](const MdnsMessage& response) -> Error {
    EXPECT_EQ(response.questions().size(), 0u);
    EXPECT_EQ(response.authority_records().size(), 0u);
    EXPECT_EQ(response.additional_records().size(), 2u);

    EXPECT_EQ(response.answers().size(), 1u);
    EXPECT_TRUE(ContainsRecordType(response.answers(), DnsType::kSRV));

    const auto& additional = response.additional_records();
    EXPECT_EQ(additional.size(), 2u);
    EXPECT_TRUE(ContainsRecordType(additional, DnsType::kA));
    EXPECT_TRUE(ContainsRecordType(additional, DnsType::kAAAA));
    EXPECT_FALSE(ContainsRecordType(additional, DnsType::kSRV));
    EXPECT_FALSE(ContainsRecordType(additional, DnsType::kTXT));
    EXPECT_FALSE(ContainsRecordType(additional, DnsType::kPTR));

    return Error::None();
  };

  EXPECT_CALL(record_handler_, HasRecords(_, _, _))
      .WillRepeatedly(Return(true));
  EXPECT_CALL(probe_manager_, IsDomainClaimed(_)).WillOnce(Return(true));
  EXPECT_CALL(sender_, SendMulticast(_)).WillOnce(verify_response);

  OnMessageReceived(query, endpoint_);
}
412 
TEST_F(MdnsResponderTest, AQueryResultsApplied) {
  record_handler_.AddRecord(GetFakePtrRecord(domain_));
  record_handler_.AddRecord(GetFakeSrvRecord(domain_));
  record_handler_.AddRecord(GetFakeTxtRecord(domain_));
  record_handler_.AddRecord(GetFakeARecord(domain_));
  record_handler_.AddRecord(GetFakeAAAARecord(domain_));
  const MdnsMessage query = CreateMulticastMdnsQuery(DnsType::kA);

  // The A record is the only answer; AAAA is the only additional record.
  const auto verify_response = [](const MdnsMessage& response) -> Error {
    EXPECT_EQ(response.questions().size(), 0u);
    EXPECT_EQ(response.authority_records().size(), 0u);
    EXPECT_EQ(response.additional_records().size(), 1u);

    EXPECT_EQ(response.answers().size(), 1u);
    EXPECT_TRUE(ContainsRecordType(response.answers(), DnsType::kA));

    const auto& additional = response.additional_records();
    EXPECT_EQ(additional.size(), 1u);
    EXPECT_TRUE(ContainsRecordType(additional, DnsType::kAAAA));
    EXPECT_FALSE(ContainsRecordType(additional, DnsType::kA));
    EXPECT_FALSE(ContainsRecordType(additional, DnsType::kSRV));
    EXPECT_FALSE(ContainsRecordType(additional, DnsType::kTXT));
    EXPECT_FALSE(ContainsRecordType(additional, DnsType::kPTR));

    return Error::None();
  };

  EXPECT_CALL(record_handler_, HasRecords(_, _, _))
      .WillRepeatedly(Return(true));
  EXPECT_CALL(probe_manager_, IsDomainClaimed(_)).WillOnce(Return(true));
  EXPECT_CALL(sender_, SendMulticast(_)).WillOnce(verify_response);

  OnMessageReceived(query, endpoint_);
}
446 
TEST_F(MdnsResponderTest, AAAAQueryResultsApplied) {
  record_handler_.AddRecord(GetFakePtrRecord(domain_));
  record_handler_.AddRecord(GetFakeSrvRecord(domain_));
  record_handler_.AddRecord(GetFakeTxtRecord(domain_));
  record_handler_.AddRecord(GetFakeARecord(domain_));
  record_handler_.AddRecord(GetFakeAAAARecord(domain_));
  const MdnsMessage query = CreateMulticastMdnsQuery(DnsType::kAAAA);

  // The AAAA record is the only answer; A is the only additional record.
  const auto verify_response = [](const MdnsMessage& response) -> Error {
    EXPECT_EQ(response.questions().size(), 0u);
    EXPECT_EQ(response.authority_records().size(), 0u);
    EXPECT_EQ(response.additional_records().size(), 1u);

    EXPECT_EQ(response.answers().size(), 1u);
    EXPECT_TRUE(ContainsRecordType(response.answers(), DnsType::kAAAA));

    const auto& additional = response.additional_records();
    EXPECT_EQ(additional.size(), 1u);
    EXPECT_TRUE(ContainsRecordType(additional, DnsType::kA));
    EXPECT_FALSE(ContainsRecordType(additional, DnsType::kAAAA));
    EXPECT_FALSE(ContainsRecordType(additional, DnsType::kSRV));
    EXPECT_FALSE(ContainsRecordType(additional, DnsType::kTXT));
    EXPECT_FALSE(ContainsRecordType(additional, DnsType::kPTR));

    return Error::None();
  };

  EXPECT_CALL(record_handler_, HasRecords(_, _, _))
      .WillRepeatedly(Return(true));
  EXPECT_CALL(probe_manager_, IsDomainClaimed(_)).WillOnce(Return(true));
  EXPECT_CALL(sender_, SendMulticast(_)).WillOnce(verify_response);

  OnMessageReceived(query, endpoint_);
}
480 
TEST_F(MdnsResponderTest, MessageOnlySentIfAnswerNotKnown) {
  const MdnsRecord known_aaaa = GetFakeAAAARecord(domain_);

  record_handler_.AddRecord(GetFakePtrRecord(domain_));
  record_handler_.AddRecord(GetFakeSrvRecord(domain_));
  record_handler_.AddRecord(GetFakeTxtRecord(domain_));
  record_handler_.AddRecord(GetFakeARecord(domain_));
  record_handler_.AddRecord(known_aaaa);

  // The query already lists the AAAA record in its answer section.
  MdnsMessage query = CreateMulticastMdnsQuery(DnsType::kAAAA);
  query.AddAnswer(known_aaaa);

  EXPECT_CALL(record_handler_, HasRecords(_, _, _))
      .WillRepeatedly(Return(true));
  EXPECT_CALL(probe_manager_, IsDomainClaimed(_)).WillOnce(Return(true));

  // No expectation is set on |sender_| (a StrictMock), so any attempt to send
  // a reply fails the test.
  OnMessageReceived(query, endpoint_);
}
497 
TEST_F(MdnsResponderTest, RecordOnlySentIfNotKnown) {
  const MdnsRecord known_aaaa = GetFakeAAAARecord(domain_);

  record_handler_.AddRecord(GetFakeARecord(domain_));
  record_handler_.AddRecord(known_aaaa);

  // The query already lists the AAAA record in its answer section.
  MdnsMessage query = CreateMulticastMdnsQuery(DnsType::kANY);
  query.AddAnswer(known_aaaa);

  // Only the A record is expected in the reply.
  const auto verify_response = [](const MdnsMessage& response) -> Error {
    EXPECT_EQ(response.questions().size(), 0u);
    EXPECT_EQ(response.authority_records().size(), 0u);
    EXPECT_EQ(response.additional_records().size(), 0u);

    EXPECT_EQ(response.answers().size(), 1u);
    EXPECT_TRUE(ContainsRecordType(response.answers(), DnsType::kA));
    return Error::None();
  };

  EXPECT_CALL(record_handler_, HasRecords(_, _, _))
      .WillRepeatedly(Return(true));
  EXPECT_CALL(probe_manager_, IsDomainClaimed(_)).WillOnce(Return(true));
  EXPECT_CALL(sender_, SendMulticast(_)).WillOnce(verify_response);

  OnMessageReceived(query, endpoint_);
}
521 
TEST_F(MdnsResponderTest, RecordOnlySentIfNotKnownMultiplePackets) {
  // First packet: the question, flagged as truncated.
  MdnsMessage question_packet = CreateMulticastMdnsQuery(DnsType::kANY);
  question_packet.set_truncated();

  // Second packet: carries the AAAA record as an answer and no question.
  const MdnsRecord known_aaaa = GetFakeAAAARecord(domain_);
  MdnsMessage answer_packet(1, MessageType::Query);
  answer_packet.AddAnswer(known_aaaa);

  // Both packets are delivered before any expectation is set on the strict
  // mocks; the mocked calls are expected during clock_.Advance() below.
  OnMessageReceived(question_packet, endpoint_);
  OnMessageReceived(answer_packet, endpoint_);

  // Only the A record is expected in the reply.
  const auto verify_response = [](const MdnsMessage& response) -> Error {
    EXPECT_EQ(response.questions().size(), 0u);
    EXPECT_EQ(response.authority_records().size(), 0u);
    EXPECT_EQ(response.additional_records().size(), 0u);

    EXPECT_EQ(response.answers().size(), 1u);
    EXPECT_TRUE(ContainsRecordType(response.answers(), DnsType::kA));
    return Error::None();
  };

  EXPECT_CALL(probe_manager_, IsDomainClaimed(_)).WillOnce(Return(true));
  EXPECT_CALL(record_handler_, HasRecords(_, _, _))
      .WillRepeatedly(Return(true));
  record_handler_.AddRecord(GetFakeARecord(domain_));
  record_handler_.AddRecord(known_aaaa);
  EXPECT_CALL(sender_, SendMulticast(_)).WillOnce(verify_response);
  clock_.Advance(std::chrono::seconds(1));
}
550 
TEST_F(MdnsResponderTest, RecordOnlySentIfNotKnownMultiplePacketsOutOfOrder) {
  // Packet carrying the question, flagged as truncated.
  MdnsMessage question_packet = CreateMulticastMdnsQuery(DnsType::kANY);
  question_packet.set_truncated();

  // Truncated packet carrying the AAAA record as an answer.
  const MdnsRecord known_aaaa = GetFakeAAAARecord(domain_);
  MdnsMessage aaaa_packet(2, MessageType::Query);
  aaaa_packet.AddAnswer(known_aaaa);
  aaaa_packet.set_truncated();

  // Non-truncated packet carrying the A record as an answer.
  const MdnsRecord known_a = GetFakeARecord(domain_);
  MdnsMessage a_packet(3, MessageType::Query);
  a_packet.AddAnswer(known_a);

  // The answer-only packets are delivered before the question packet.
  OnMessageReceived(aaaa_packet, endpoint_);
  OnMessageReceived(a_packet, endpoint_);
  OnMessageReceived(question_packet, endpoint_);

  // Only the SRV record is expected in the reply.
  const auto verify_response = [](const MdnsMessage& response) -> Error {
    EXPECT_EQ(response.questions().size(), 0u);
    EXPECT_EQ(response.authority_records().size(), 0u);
    EXPECT_EQ(response.additional_records().size(), 0u);

    EXPECT_EQ(response.answers().size(), 1u);
    EXPECT_TRUE(ContainsRecordType(response.answers(), DnsType::kSRV));
    return Error::None();
  };

  EXPECT_CALL(probe_manager_, IsDomainClaimed(_)).WillOnce(Return(true));
  EXPECT_CALL(record_handler_, HasRecords(_, _, _))
      .WillRepeatedly(Return(true));
  record_handler_.AddRecord(known_a);
  // NOTE(review): the AAAA record is registered twice -- confirm whether the
  // duplicate is intentional.
  record_handler_.AddRecord(known_aaaa);
  record_handler_.AddRecord(known_aaaa);
  record_handler_.AddRecord(GetFakeSrvRecord(domain_));
  EXPECT_CALL(sender_, SendMulticast(_)).WillOnce(verify_response);
  clock_.Advance(std::chrono::seconds(1));
}
587 
TEST_F(MdnsResponderTest, RecordSentForMultiPacketsSuppressionIfMoreNotFound) {
  // A truncated query that lists the AAAA record as an answer; no follow-up
  // packet is delivered.
  const MdnsRecord known_aaaa = GetFakeAAAARecord(domain_);
  MdnsMessage query = CreateMulticastMdnsQuery(DnsType::kANY);
  query.AddAnswer(known_aaaa);
  query.set_truncated();

  OnMessageReceived(query, endpoint_);

  // Only the A record is expected in the reply.
  const auto verify_response = [](const MdnsMessage& response) -> Error {
    EXPECT_EQ(response.questions().size(), 0u);
    EXPECT_EQ(response.authority_records().size(), 0u);
    EXPECT_EQ(response.additional_records().size(), 0u);

    EXPECT_EQ(response.answers().size(), 1u);
    EXPECT_TRUE(ContainsRecordType(response.answers(), DnsType::kA));
    return Error::None();
  };

  EXPECT_CALL(probe_manager_, IsDomainClaimed(_)).WillOnce(Return(true));
  EXPECT_CALL(record_handler_, HasRecords(_, _, _))
      .WillRepeatedly(Return(true));
  record_handler_.AddRecord(GetFakeARecord(domain_));
  record_handler_.AddRecord(known_aaaa);
  EXPECT_CALL(sender_, SendMulticast(_)).WillOnce(verify_response);
  clock_.Advance(std::chrono::seconds(1));
}
613 
TEST_F(MdnsResponderTest, RecordNotSentForMultiPacketsSuppressionIfNoQuery) {
  // A message that carries only an answer and no question.
  MdnsMessage answer_only(1, MessageType::Query);
  answer_only.AddAnswer(GetFakeAAAARecord(domain_));

  // No expectations are set on the strict mocks, so any mocked call made
  // while handling the message or advancing the clock fails the test.
  OnMessageReceived(answer_only, endpoint_);
  clock_.Advance(std::chrono::seconds(1));
}
622 
623 // Validate NSEC records are used correctly.
TEST_F(MdnsResponderTest, QueryForRecordTypesWhenNonePresent) {
  // Run the same NSEC check for each query type, in this order.
  for (const DnsType type : {DnsType::kANY, DnsType::kSRV, DnsType::kTXT,
                             DnsType::kA, DnsType::kAAAA}) {
    QueryForRecordTypeWhenNonePresent(type);
  }
}
631 
TEST_F(MdnsResponderTest, AAAAQueryGiveANsec) {
  // No A record is registered.
  record_handler_.AddRecord(GetFakePtrRecord(domain_));
  record_handler_.AddRecord(GetFakeSrvRecord(domain_));
  record_handler_.AddRecord(GetFakeTxtRecord(domain_));
  record_handler_.AddRecord(GetFakeAAAARecord(domain_));
  const MdnsMessage query = CreateMulticastMdnsQuery(DnsType::kAAAA);

  // The AAAA record is the answer; the single additional record must be an
  // NSEC record for type A.
  const auto verify_response = [](const MdnsMessage& response) -> Error {
    EXPECT_EQ(response.questions().size(), 0u);
    EXPECT_EQ(response.authority_records().size(), 0u);

    EXPECT_EQ(response.answers().size(), 1u);
    EXPECT_TRUE(ContainsRecordType(response.answers(), DnsType::kAAAA));

    const auto& additional = response.additional_records();
    EXPECT_EQ(additional.size(), 1u);
    ExpectContainsNsecRecordType(additional, DnsType::kA);

    return Error::None();
  };

  EXPECT_CALL(record_handler_, HasRecords(_, _, _))
      .WillRepeatedly(Return(true));
  EXPECT_CALL(probe_manager_, IsDomainClaimed(_)).WillOnce(Return(true));
  EXPECT_CALL(sender_, SendMulticast(_)).WillOnce(verify_response);

  OnMessageReceived(query, endpoint_);
}
658 
TEST_F(MdnsResponderTest, AQueryGiveAAAANsec) {
  // No AAAA record is registered.
  record_handler_.AddRecord(GetFakePtrRecord(domain_));
  record_handler_.AddRecord(GetFakeSrvRecord(domain_));
  record_handler_.AddRecord(GetFakeTxtRecord(domain_));
  record_handler_.AddRecord(GetFakeARecord(domain_));
  const MdnsMessage query = CreateMulticastMdnsQuery(DnsType::kA);

  // The A record is the answer; the single additional record must be an NSEC
  // record for type AAAA.
  const auto verify_response = [](const MdnsMessage& response) -> Error {
    EXPECT_EQ(response.questions().size(), 0u);
    EXPECT_EQ(response.authority_records().size(), 0u);

    EXPECT_EQ(response.answers().size(), 1u);
    EXPECT_TRUE(ContainsRecordType(response.answers(), DnsType::kA));

    const auto& additional = response.additional_records();
    EXPECT_EQ(additional.size(), 1u);
    ExpectContainsNsecRecordType(additional, DnsType::kAAAA);

    return Error::None();
  };

  EXPECT_CALL(record_handler_, HasRecords(_, _, _))
      .WillRepeatedly(Return(true));
  EXPECT_CALL(probe_manager_, IsDomainClaimed(_)).WillOnce(Return(true));
  EXPECT_CALL(sender_, SendMulticast(_)).WillOnce(verify_response);

  OnMessageReceived(query, endpoint_);
}
686 
TEST_F(MdnsResponderTest, SrvQueryGiveCorrectNsecForNoAOrAAAA) {
  // Neither an A nor an AAAA record is registered.
  record_handler_.AddRecord(GetFakePtrRecord(domain_));
  record_handler_.AddRecord(GetFakeSrvRecord(domain_));
  record_handler_.AddRecord(GetFakeTxtRecord(domain_));
  const MdnsMessage query = CreateMulticastMdnsQuery(DnsType::kSRV);

  // The SRV record is the answer; the two additional records must be NSEC
  // records for types A and AAAA.
  const auto verify_response = [](const MdnsMessage& response) -> Error {
    EXPECT_EQ(response.questions().size(), 0u);
    EXPECT_EQ(response.authority_records().size(), 0u);

    EXPECT_EQ(response.answers().size(), 1u);
    EXPECT_TRUE(ContainsRecordType(response.answers(), DnsType::kSRV));

    const auto& additional = response.additional_records();
    EXPECT_EQ(additional.size(), 2u);
    ExpectContainsNsecRecordType(additional, DnsType::kA);
    ExpectContainsNsecRecordType(additional, DnsType::kAAAA);

    return Error::None();
  };

  EXPECT_CALL(record_handler_, HasRecords(_, _, _))
      .WillRepeatedly(Return(true));
  EXPECT_CALL(probe_manager_, IsDomainClaimed(_)).WillOnce(Return(true));
  EXPECT_CALL(sender_, SendMulticast(_)).WillOnce(verify_response);
  OnMessageReceived(query, endpoint_);
}
713 
// Scenario: an SRV query is received for a domain reported as claimed, with
// PTR, SRV, TXT and A records registered but no AAAA record. The response
// must answer with the SRV record; its two additional records must include
// the A record and pass the NSEC check for the AAAA type.
TEST_F(MdnsResponderTest, SrvQueryGiveCorrectNsec) {
  MdnsMessage query = CreateMulticastMdnsQuery(DnsType::kSRV);

  EXPECT_CALL(probe_manager_, IsDomainClaimed(_)).WillOnce(Return(true));
  EXPECT_CALL(record_handler_, HasRecords(_, _, _))
      .WillRepeatedly(Return(true));
  record_handler_.AddRecord(GetFakePtrRecord(domain_));
  record_handler_.AddRecord(GetFakeSrvRecord(domain_));
  record_handler_.AddRecord(GetFakeTxtRecord(domain_));
  record_handler_.AddRecord(GetFakeARecord(domain_));

  // Validates the single response handed to the sender.
  const auto verify_response = [](const MdnsMessage& response) -> Error {
    const auto& answers = response.answers();
    const auto& additional = response.additional_records();

    // Answer section: just the requested SRV record.
    EXPECT_EQ(answers.size(), size_t{1});
    EXPECT_TRUE(ContainsRecordType(answers, DnsType::kSRV));

    // Additional section: the registered A record plus the AAAA NSEC check.
    EXPECT_EQ(additional.size(), size_t{2});
    EXPECT_TRUE(ContainsRecordType(additional, DnsType::kA));
    ExpectContainsNsecRecordType(additional, DnsType::kAAAA);

    // A response carries neither questions nor authority records.
    EXPECT_EQ(response.questions().size(), size_t{0});
    EXPECT_EQ(response.authority_records().size(), size_t{0});

    return Error::None();
  };
  EXPECT_CALL(sender_, SendMulticast(_)).WillOnce(verify_response);

  OnMessageReceived(query, endpoint_);
}
742 
// Scenario: a PTR query is received for a domain reported as claimed, with
// only a PTR record registered. The response must answer with that PTR record
// and carry two additional records that pass the NSEC checks for the TXT and
// SRV types. (Note: the records absent in this setup are TXT and SRV, which
// is what the assertions below verify.)
TEST_F(MdnsResponderTest, PtrQueryGiveCorrectNsecForNoPtrOrSrv) {
  MdnsMessage query = CreateMulticastMdnsQuery(DnsType::kPTR);

  EXPECT_CALL(probe_manager_, IsDomainClaimed(_)).WillOnce(Return(true));
  EXPECT_CALL(record_handler_, HasRecords(_, _, _))
      .WillRepeatedly(Return(true));
  record_handler_.AddRecord(GetFakePtrRecord(domain_));

  // Validates the single response handed to the sender.
  const auto verify_response = [](const MdnsMessage& response) -> Error {
    const auto& answers = response.answers();
    const auto& additional = response.additional_records();

    // Answer section: just the requested PTR record.
    EXPECT_EQ(answers.size(), size_t{1});
    EXPECT_TRUE(ContainsRecordType(answers, DnsType::kPTR));

    // Additional section: NSEC checks for the two unregistered types.
    EXPECT_EQ(additional.size(), size_t{2});
    ExpectContainsNsecRecordType(additional, DnsType::kTXT);
    ExpectContainsNsecRecordType(additional, DnsType::kSRV);

    // A response carries neither questions nor authority records.
    EXPECT_EQ(response.questions().size(), size_t{0});
    EXPECT_EQ(response.authority_records().size(), size_t{0});

    return Error::None();
  };
  EXPECT_CALL(sender_, SendMulticast(_)).WillOnce(verify_response);

  OnMessageReceived(query, endpoint_);
}
768 
// Scenario: a PTR query is received for a domain reported as claimed, with
// PTR and TXT records registered but no SRV record. The response must answer
// with the PTR record; its two additional records must include the TXT record
// and pass the NSEC check for the SRV type.
TEST_F(MdnsResponderTest, PtrQueryGiveCorrectNsecForOnlyPtr) {
  MdnsMessage query = CreateMulticastMdnsQuery(DnsType::kPTR);

  EXPECT_CALL(probe_manager_, IsDomainClaimed(_)).WillOnce(Return(true));
  EXPECT_CALL(record_handler_, HasRecords(_, _, _))
      .WillRepeatedly(Return(true));
  record_handler_.AddRecord(GetFakePtrRecord(domain_));
  record_handler_.AddRecord(GetFakeTxtRecord(domain_));

  // Validates the single response handed to the sender.
  const auto verify_response = [](const MdnsMessage& response) -> Error {
    const auto& answers = response.answers();
    const auto& additional = response.additional_records();

    // Answer section: just the requested PTR record.
    EXPECT_EQ(answers.size(), size_t{1});
    EXPECT_TRUE(ContainsRecordType(answers, DnsType::kPTR));

    // Additional section: the registered TXT record plus the SRV NSEC check.
    EXPECT_EQ(additional.size(), size_t{2});
    EXPECT_TRUE(ContainsRecordType(additional, DnsType::kTXT));
    ExpectContainsNsecRecordType(additional, DnsType::kSRV);

    // A response carries neither questions nor authority records.
    EXPECT_EQ(response.questions().size(), size_t{0});
    EXPECT_EQ(response.authority_records().size(), size_t{0});

    return Error::None();
  };
  EXPECT_CALL(sender_, SendMulticast(_)).WillOnce(verify_response);

  OnMessageReceived(query, endpoint_);
}
795 
// Scenario: a PTR query is received for a domain reported as claimed, with
// PTR and SRV records registered but no TXT, A or AAAA record. The response
// must answer with the PTR record; its four additional records must include
// the SRV record and pass the NSEC checks for the TXT, A and AAAA types.
TEST_F(MdnsResponderTest, PtrQueryGiveCorrectNsecForOnlySrv) {
  MdnsMessage query = CreateMulticastMdnsQuery(DnsType::kPTR);

  EXPECT_CALL(probe_manager_, IsDomainClaimed(_)).WillOnce(Return(true));
  EXPECT_CALL(record_handler_, HasRecords(_, _, _))
      .WillRepeatedly(Return(true));
  record_handler_.AddRecord(GetFakePtrRecord(domain_));
  record_handler_.AddRecord(GetFakeSrvRecord(domain_));

  // Validates the single response handed to the sender.
  const auto verify_response = [](const MdnsMessage& response) -> Error {
    const auto& answers = response.answers();
    const auto& additional = response.additional_records();

    // Answer section: just the requested PTR record.
    EXPECT_EQ(answers.size(), size_t{1});
    EXPECT_TRUE(ContainsRecordType(answers, DnsType::kPTR));

    // Additional section: the registered SRV record plus one NSEC check per
    // unregistered type.
    EXPECT_EQ(additional.size(), size_t{4});
    EXPECT_TRUE(ContainsRecordType(additional, DnsType::kSRV));
    ExpectContainsNsecRecordType(additional, DnsType::kTXT);
    ExpectContainsNsecRecordType(additional, DnsType::kA);
    ExpectContainsNsecRecordType(additional, DnsType::kAAAA);

    // A response carries neither questions nor authority records.
    EXPECT_EQ(response.questions().size(), size_t{0});
    EXPECT_EQ(response.authority_records().size(), size_t{0});

    return Error::None();
  };
  EXPECT_CALL(sender_, SendMulticast(_)).WillOnce(verify_response);

  OnMessageReceived(query, endpoint_);
}
825 
// Scenario: a type enumeration query (built by CreateTypeEnumerationQuery) is
// received while the probe manager reports the domain as not claimed and PTR,
// SRV, TXT and A records are registered. The test expects exactly one
// multicast response whose only answer is a PTR record named
// type_enumeration_domain_ and which passes CheckPtrDomain against the name of
// the registered PTR record.
TEST_F(MdnsResponderTest, EnumerateAllQuery) {
  MdnsMessage message = CreateTypeEnumerationQuery();

  EXPECT_CALL(probe_manager_, IsDomainClaimed(_)).WillOnce(Return(false));
  EXPECT_CALL(record_handler_, HasRecords(_, _, _))
      .WillRepeatedly(Return(true));
  const auto ptr = GetFakePtrRecord(domain_);
  record_handler_.AddRecord(ptr);
  record_handler_.AddRecord(GetFakeSrvRecord(domain_));
  record_handler_.AddRecord(GetFakeTxtRecord(domain_));
  record_handler_.AddRecord(GetFakeARecord(domain_));
  OnMessageReceived(message, endpoint_);

  // The expectation is registered only after the query has been delivered;
  // the response is presumably sent once the fake clock is advanced below.
  // NOTE(review): confirm against the responder's delayed-response logic.
  EXPECT_CALL(sender_, SendMulticast(_))
      .WillOnce([this, &ptr](const MdnsMessage& response) -> Error {
        EXPECT_EQ(response.questions().size(), size_t{0});
        EXPECT_EQ(response.authority_records().size(), size_t{0});

        const auto& answers = response.answers();
        EXPECT_EQ(answers.size(), size_t{1});
        // ASSERT_* cannot be used in a lambda returning non-void, so bail out
        // by hand: the size mismatch has already been recorded above, and
        // indexing an empty answer list would be out of bounds.
        if (answers.empty()) {
          return Error::None();
        }

        EXPECT_TRUE(ContainsRecordType(answers, DnsType::kPTR));
        EXPECT_EQ(answers[0].name(), type_enumeration_domain_);
        CheckPtrDomain(answers[0], ptr.name());
        return Error::None();
      });
  // kMaximumSharedRecordResponseDelayMs is already a Clock::duration.
  clock_.Advance(kMaximumSharedRecordResponseDelayMs);
}
852 
// Scenario: a type enumeration query (built by CreateTypeEnumerationQuery) is
// received while the probe manager reports the domain as not claimed and only
// SRV, TXT and A records are registered (no PTR record). No multicast
// response is expected, even after the fake clock has been advanced by
// kMaximumSharedRecordResponseDelayMs.
TEST_F(MdnsResponderTest, EnumerateAllQueryNoResults) {
  MdnsMessage message = CreateTypeEnumerationQuery();

  EXPECT_CALL(probe_manager_, IsDomainClaimed(_)).WillOnce(Return(false));
  EXPECT_CALL(record_handler_, HasRecords(_, _, _))
      .WillRepeatedly(Return(true));
  // State the "no results" expectation explicitly so it does not depend on
  // how strict the sender mock is configured in the fixture.
  EXPECT_CALL(sender_, SendMulticast(_)).Times(0);

  record_handler_.AddRecord(GetFakeSrvRecord(domain_));
  record_handler_.AddRecord(GetFakeTxtRecord(domain_));
  record_handler_.AddRecord(GetFakeARecord(domain_));

  OnMessageReceived(message, endpoint_);
  // kMaximumSharedRecordResponseDelayMs is already a Clock::duration.
  clock_.Advance(kMaximumSharedRecordResponseDelayMs);
}
866 
867 }  // namespace discovery
868 }  // namespace openscreen
869