xref: /aosp_15_r20/external/cronet/net/third_party/quiche/src/quiche/quic/qbone/platform/netlink.cc (revision 6777b5387eb2ff775bb5750e3f5d96f37fb7352b)
1 // Copyright (c) 2019 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4 
5 #include "quiche/quic/qbone/platform/netlink.h"
6 
7 #include <linux/fib_rules.h>
8 
9 #include <utility>
10 
11 #include "absl/base/attributes.h"
12 #include "absl/strings/str_cat.h"
13 #include "quiche/quic/core/crypto/quic_random.h"
14 #include "quiche/quic/platform/api/quic_ip_address.h"
15 #include "quiche/quic/platform/api/quic_logging.h"
16 #include "quiche/quic/qbone/platform/rtnetlink_message.h"
17 #include "quiche/quic/qbone/qbone_constants.h"
18 
19 namespace quic {
20 
Netlink(KernelInterface * kernel)21 Netlink::Netlink(KernelInterface* kernel) : kernel_(kernel) {
22   seq_ = QuicRandom::GetInstance()->RandUint64();
23 }
24 
~Netlink()25 Netlink::~Netlink() { CloseSocket(); }
26 
ResetRecvBuf(size_t size)27 void Netlink::ResetRecvBuf(size_t size) {
28   if (size != 0) {
29     recvbuf_ = std::make_unique<char[]>(size);
30   } else {
31     recvbuf_ = nullptr;
32   }
33   recvbuf_length_ = size;
34 }
35 
OpenSocket()36 bool Netlink::OpenSocket() {
37   if (socket_fd_ >= 0) {
38     return true;
39   }
40 
41   socket_fd_ = kernel_->socket(AF_NETLINK, SOCK_RAW, NETLINK_ROUTE);
42 
43   if (socket_fd_ < 0) {
44     QUIC_PLOG(ERROR) << "can't open netlink socket";
45     return false;
46   }
47 
48   QUIC_LOG(INFO) << "Opened a new netlink socket fd = " << socket_fd_;
49 
50   // bind a local address to the socket
51   sockaddr_nl myaddr;
52   memset(&myaddr, 0, sizeof(myaddr));
53   myaddr.nl_family = AF_NETLINK;
54   if (kernel_->bind(socket_fd_, reinterpret_cast<struct sockaddr*>(&myaddr),
55                     sizeof(myaddr)) < 0) {
56     QUIC_LOG(INFO) << "can't bind address to socket";
57     CloseSocket();
58     return false;
59   }
60 
61   return true;
62 }
63 
CloseSocket()64 void Netlink::CloseSocket() {
65   if (socket_fd_ >= 0) {
66     QUIC_LOG(INFO) << "Closing netlink socket fd = " << socket_fd_;
67     kernel_->close(socket_fd_);
68   }
69   ResetRecvBuf(0);
70   socket_fd_ = -1;
71 }
72 
73 namespace {
74 
75 class LinkInfoParser : public NetlinkParserInterface {
76  public:
LinkInfoParser(std::string interface_name,Netlink::LinkInfo * link_info)77   LinkInfoParser(std::string interface_name, Netlink::LinkInfo* link_info)
78       : interface_name_(std::move(interface_name)), link_info_(link_info) {}
79 
Run(struct nlmsghdr * netlink_message)80   void Run(struct nlmsghdr* netlink_message) override {
81     if (netlink_message->nlmsg_type != RTM_NEWLINK) {
82       QUIC_LOG(INFO) << absl::StrCat(
83           "Unexpected nlmsg_type: ", netlink_message->nlmsg_type,
84           " expected: ", RTM_NEWLINK);
85       return;
86     }
87 
88     struct ifinfomsg* interface_info =
89         reinterpret_cast<struct ifinfomsg*>(NLMSG_DATA(netlink_message));
90 
91     // make sure interface_info is what we asked for.
92     if (interface_info->ifi_family != AF_UNSPEC) {
93       QUIC_LOG(INFO) << absl::StrCat(
94           "Unexpected ifi_family: ", interface_info->ifi_family,
95           " expected: ", AF_UNSPEC);
96       return;
97     }
98 
99     char hardware_address[kHwAddrSize];
100     size_t hardware_address_length = 0;
101     char broadcast_address[kHwAddrSize];
102     size_t broadcast_address_length = 0;
103     std::string name;
104 
105     // loop through the attributes
106     struct rtattr* rta;
107     int payload_length = IFLA_PAYLOAD(netlink_message);
108     for (rta = IFLA_RTA(interface_info); RTA_OK(rta, payload_length);
109          rta = RTA_NEXT(rta, payload_length)) {
110       int attribute_length;
111       switch (rta->rta_type) {
112         case IFLA_ADDRESS: {
113           attribute_length = RTA_PAYLOAD(rta);
114           if (attribute_length > kHwAddrSize) {
115             QUIC_VLOG(2) << "IFLA_ADDRESS too long: " << attribute_length;
116             break;
117           }
118           memmove(hardware_address, RTA_DATA(rta), attribute_length);
119           hardware_address_length = attribute_length;
120           break;
121         }
122         case IFLA_BROADCAST: {
123           attribute_length = RTA_PAYLOAD(rta);
124           if (attribute_length > kHwAddrSize) {
125             QUIC_VLOG(2) << "IFLA_BROADCAST too long: " << attribute_length;
126             break;
127           }
128           memmove(broadcast_address, RTA_DATA(rta), attribute_length);
129           broadcast_address_length = attribute_length;
130           break;
131         }
132         case IFLA_IFNAME: {
133           name = std::string(reinterpret_cast<char*>(RTA_DATA(rta)),
134                              RTA_PAYLOAD(rta));
135           // The name maybe a 0 terminated c string.
136           name = name.substr(0, name.find('\0'));
137           break;
138         }
139       }
140     }
141 
142     QUIC_VLOG(2) << "interface name: " << name
143                  << ", index: " << interface_info->ifi_index;
144 
145     if (name == interface_name_) {
146       link_info_->index = interface_info->ifi_index;
147       link_info_->type = interface_info->ifi_type;
148       link_info_->hardware_address_length = hardware_address_length;
149       if (hardware_address_length > 0) {
150         memmove(&link_info_->hardware_address, hardware_address,
151                 hardware_address_length);
152       }
153       link_info_->broadcast_address_length = broadcast_address_length;
154       if (broadcast_address_length > 0) {
155         memmove(&link_info_->broadcast_address, broadcast_address,
156                 broadcast_address_length);
157       }
158       found_link_ = true;
159     }
160   }
161 
found_link()162   bool found_link() { return found_link_; }
163 
164  private:
165   const std::string interface_name_;
166   Netlink::LinkInfo* const link_info_;
167   bool found_link_ = false;
168 };
169 
170 }  // namespace
171 
GetLinkInfo(const std::string & interface_name,LinkInfo * link_info)172 bool Netlink::GetLinkInfo(const std::string& interface_name,
173                           LinkInfo* link_info) {
174   auto message = LinkMessage::New(RtnetlinkMessage::Operation::GET,
175                                   NLM_F_ROOT | NLM_F_MATCH | NLM_F_REQUEST,
176                                   seq_, getpid(), nullptr);
177 
178   if (!Send(message.BuildIoVec().get(), message.IoVecSize())) {
179     QUIC_LOG(ERROR) << "send failed.";
180     return false;
181   }
182 
183   // Pass the parser to the receive routine. It may be called multiple times
184   // since there may be multiple reply packets each with multiple reply
185   // messages.
186   LinkInfoParser parser(interface_name, link_info);
187   if (!Recv(seq_++, &parser)) {
188     QUIC_LOG(ERROR) << "recv failed.";
189     return false;
190   }
191 
192   return parser.found_link();
193 }
194 
195 namespace {
196 
197 class LocalAddressParser : public NetlinkParserInterface {
198  public:
LocalAddressParser(int interface_index,uint8_t unwanted_flags,std::vector<Netlink::AddressInfo> * local_addresses,int * num_ipv6_nodad_dadfailed_addresses)199   LocalAddressParser(int interface_index, uint8_t unwanted_flags,
200                      std::vector<Netlink::AddressInfo>* local_addresses,
201                      int* num_ipv6_nodad_dadfailed_addresses)
202       : interface_index_(interface_index),
203         unwanted_flags_(unwanted_flags),
204         local_addresses_(local_addresses),
205         num_ipv6_nodad_dadfailed_addresses_(
206             num_ipv6_nodad_dadfailed_addresses) {}
207 
Run(struct nlmsghdr * netlink_message)208   void Run(struct nlmsghdr* netlink_message) override {
209     // each nlmsg contains a header and multiple address attributes.
210     if (netlink_message->nlmsg_type != RTM_NEWADDR) {
211       QUIC_LOG(INFO) << "Unexpected nlmsg_type: " << netlink_message->nlmsg_type
212                      << " expected: " << RTM_NEWADDR;
213       return;
214     }
215 
216     struct ifaddrmsg* interface_address =
217         reinterpret_cast<struct ifaddrmsg*>(NLMSG_DATA(netlink_message));
218 
219     // Make sure this is for an address family we're interested in.
220     if (interface_address->ifa_family != AF_INET &&
221         interface_address->ifa_family != AF_INET6) {
222       QUIC_VLOG(2) << absl::StrCat("uninteresting ifa family: ",
223                                    interface_address->ifa_family);
224       return;
225     }
226 
227     // Keep track of addresses with both 'nodad' and 'dadfailed', this really
228     // should't be possible and is likely a kernel bug.
229     if (num_ipv6_nodad_dadfailed_addresses_ != nullptr &&
230         (interface_address->ifa_flags & IFA_F_NODAD) &&
231         (interface_address->ifa_flags & IFA_F_DADFAILED)) {
232       ++(*num_ipv6_nodad_dadfailed_addresses_);
233     }
234 
235     uint8_t unwanted_flags = interface_address->ifa_flags & unwanted_flags_;
236     if (unwanted_flags != 0) {
237       QUIC_VLOG(2) << absl::StrCat("unwanted ifa flags: ", unwanted_flags);
238       return;
239     }
240 
241     // loop through the attributes
242     struct rtattr* rta;
243     int payload_length = IFA_PAYLOAD(netlink_message);
244     Netlink::AddressInfo address_info;
245     for (rta = IFA_RTA(interface_address); RTA_OK(rta, payload_length);
246          rta = RTA_NEXT(rta, payload_length)) {
247       // There's quite a lot of confusion in Linux over the use of IFA_LOCAL and
248       // IFA_ADDRESS (source and destination address). For broadcast links, such
249       // as Ethernet, they are identical (see <linux/if_addr.h>), but the kernel
250       // sometimes uses only one or the other. We'll return both so that the
251       // caller can decide which to use.
252       if (rta->rta_type != IFA_LOCAL && rta->rta_type != IFA_ADDRESS) {
253         QUIC_VLOG(2) << "Ignoring uninteresting rta_type: " << rta->rta_type;
254         continue;
255       }
256 
257       switch (interface_address->ifa_family) {
258         case AF_INET:
259           ABSL_FALLTHROUGH_INTENDED;
260         case AF_INET6:
261           // QuicIpAddress knows how to parse ip from raw bytes as long as they
262           // are in network byte order.
263           if (RTA_PAYLOAD(rta) == sizeof(struct in_addr) ||
264               RTA_PAYLOAD(rta) == sizeof(struct in6_addr)) {
265             auto* raw_ip = reinterpret_cast<char*>(RTA_DATA(rta));
266             if (rta->rta_type == IFA_LOCAL) {
267               address_info.local_address.FromPackedString(raw_ip,
268                                                           RTA_PAYLOAD(rta));
269             } else {
270               address_info.interface_address.FromPackedString(raw_ip,
271                                                               RTA_PAYLOAD(rta));
272             }
273           }
274           break;
275         default:
276           QUIC_LOG(ERROR) << absl::StrCat("Unknown address family: ",
277                                           interface_address->ifa_family);
278       }
279     }
280 
281     QUIC_VLOG(2) << "local_address: " << address_info.local_address.ToString()
282                  << " interface_address: "
283                  << address_info.interface_address.ToString()
284                  << " index: " << interface_address->ifa_index;
285     if (interface_address->ifa_index != interface_index_) {
286       return;
287     }
288 
289     address_info.prefix_length = interface_address->ifa_prefixlen;
290     address_info.scope = interface_address->ifa_scope;
291     if (address_info.local_address.IsInitialized() ||
292         address_info.interface_address.IsInitialized()) {
293       local_addresses_->push_back(address_info);
294     }
295   }
296 
297  private:
298   const int interface_index_;
299   const uint8_t unwanted_flags_;
300   std::vector<Netlink::AddressInfo>* const local_addresses_;
301   int* const num_ipv6_nodad_dadfailed_addresses_;
302 };
303 
304 }  // namespace
305 
GetAddresses(int interface_index,uint8_t unwanted_flags,std::vector<AddressInfo> * addresses,int * num_ipv6_nodad_dadfailed_addresses)306 bool Netlink::GetAddresses(int interface_index, uint8_t unwanted_flags,
307                            std::vector<AddressInfo>* addresses,
308                            int* num_ipv6_nodad_dadfailed_addresses) {
309   // the message doesn't contain the index, we'll have to do the filtering while
310   // parsing the reply. This is because NLM_F_MATCH, which only returns entries
311   // that matches the request criteria, is not yet implemented (see man 3
312   // netlink).
313   auto message = AddressMessage::New(RtnetlinkMessage::Operation::GET,
314                                      NLM_F_ROOT | NLM_F_MATCH | NLM_F_REQUEST,
315                                      seq_, getpid(), nullptr);
316 
317   // the send routine returns the socket to listen on.
318   if (!Send(message.BuildIoVec().get(), message.IoVecSize())) {
319     QUIC_LOG(ERROR) << "send failed.";
320     return false;
321   }
322 
323   addresses->clear();
324   if (num_ipv6_nodad_dadfailed_addresses != nullptr) {
325     *num_ipv6_nodad_dadfailed_addresses = 0;
326   }
327 
328   LocalAddressParser parser(interface_index, unwanted_flags, addresses,
329                             num_ipv6_nodad_dadfailed_addresses);
330   // Pass the parser to the receive routine. It may be called multiple times
331   // since there may be multiple reply packets each with multiple reply
332   // messages.
333   if (!Recv(seq_++, &parser)) {
334     QUIC_LOG(ERROR) << "recv failed";
335     return false;
336   }
337   return true;
338 }
339 
340 namespace {
341 
342 class UnknownParser : public NetlinkParserInterface {
343  public:
Run(struct nlmsghdr * netlink_message)344   void Run(struct nlmsghdr* netlink_message) override {
345     QUIC_LOG(INFO) << "nlmsg reply type: " << netlink_message->nlmsg_type;
346   }
347 };
348 
349 }  // namespace
350 
ChangeLocalAddress(uint32_t interface_index,Verb verb,const QuicIpAddress & address,uint8_t prefix_length,uint8_t ifa_flags,uint8_t ifa_scope,const std::vector<struct rtattr * > & additional_attributes)351 bool Netlink::ChangeLocalAddress(
352     uint32_t interface_index, Verb verb, const QuicIpAddress& address,
353     uint8_t prefix_length, uint8_t ifa_flags, uint8_t ifa_scope,
354     const std::vector<struct rtattr*>& additional_attributes) {
355   if (verb == Verb::kReplace) {
356     return false;
357   }
358   auto operation = verb == Verb::kAdd ? RtnetlinkMessage::Operation::NEW
359                                       : RtnetlinkMessage::Operation::DEL;
360   uint8_t address_family;
361   if (address.address_family() == IpAddressFamily::IP_V4) {
362     address_family = AF_INET;
363   } else if (address.address_family() == IpAddressFamily::IP_V6) {
364     address_family = AF_INET6;
365   } else {
366     return false;
367   }
368 
369   struct ifaddrmsg address_header = {address_family, prefix_length, ifa_flags,
370                                      ifa_scope, interface_index};
371 
372   auto message = AddressMessage::New(operation, NLM_F_REQUEST | NLM_F_ACK, seq_,
373                                      getpid(), &address_header);
374 
375   for (const auto& attribute : additional_attributes) {
376     if (attribute->rta_type == IFA_LOCAL) {
377       continue;
378     }
379     message.AppendAttribute(attribute->rta_type, RTA_DATA(attribute),
380                             RTA_PAYLOAD(attribute));
381   }
382 
383   message.AppendAttribute(IFA_LOCAL, address.ToPackedString().c_str(),
384                           address.ToPackedString().size());
385 
386   if (!Send(message.BuildIoVec().get(), message.IoVecSize())) {
387     QUIC_LOG(ERROR) << "send failed";
388     return false;
389   }
390 
391   UnknownParser parser;
392   if (!Recv(seq_++, &parser)) {
393     QUIC_LOG(ERROR) << "receive failed.";
394     return false;
395   }
396   return true;
397 }
398 
399 namespace {
400 
401 class RoutingRuleParser : public NetlinkParserInterface {
402  public:
RoutingRuleParser(std::vector<Netlink::RoutingRule> * routing_rules)403   explicit RoutingRuleParser(std::vector<Netlink::RoutingRule>* routing_rules)
404       : routing_rules_(routing_rules) {}
405 
Run(struct nlmsghdr * netlink_message)406   void Run(struct nlmsghdr* netlink_message) override {
407     if (netlink_message->nlmsg_type != RTM_NEWROUTE) {
408       QUIC_LOG(WARNING) << absl::StrCat(
409           "Unexpected nlmsg_type: ", netlink_message->nlmsg_type,
410           " expected: ", RTM_NEWROUTE);
411       return;
412     }
413 
414     auto* route = reinterpret_cast<struct rtmsg*>(NLMSG_DATA(netlink_message));
415     int payload_length = RTM_PAYLOAD(netlink_message);
416 
417     if (route->rtm_family != AF_INET && route->rtm_family != AF_INET6) {
418       QUIC_VLOG(2) << absl::StrCat("Uninteresting family: ", route->rtm_family);
419       return;
420     }
421 
422     Netlink::RoutingRule rule;
423     rule.scope = route->rtm_scope;
424     rule.table = route->rtm_table;
425     rule.init_cwnd = Netlink::kUnspecifiedInitCwnd;
426 
427     struct rtattr* rta;
428     for (rta = RTM_RTA(route); RTA_OK(rta, payload_length);
429          rta = RTA_NEXT(rta, payload_length)) {
430       switch (rta->rta_type) {
431         case RTA_TABLE: {
432           rule.table = *reinterpret_cast<uint32_t*>(RTA_DATA(rta));
433           break;
434         }
435         case RTA_DST: {
436           QuicIpAddress destination;
437           destination.FromPackedString(reinterpret_cast<char*> RTA_DATA(rta),
438                                        RTA_PAYLOAD(rta));
439           rule.destination_subnet = IpRange(destination, route->rtm_dst_len);
440           break;
441         }
442         case RTA_PREFSRC: {
443           QuicIpAddress preferred_source;
444           rule.preferred_source.FromPackedString(
445               reinterpret_cast<char*> RTA_DATA(rta), RTA_PAYLOAD(rta));
446           break;
447         }
448         case RTA_OIF: {
449           rule.out_interface = *reinterpret_cast<int*>(RTA_DATA(rta));
450           break;
451         }
452         case RTA_METRICS: {
453           struct rtattr* rtax;
454           int rta_payload_length = RTA_PAYLOAD(rta);
455           for (rtax = reinterpret_cast<struct rtattr*>(RTA_DATA(rta));
456                RTA_OK(rtax, rta_payload_length);
457                rtax = RTA_NEXT(rtax, rta_payload_length)) {
458             switch (rtax->rta_type) {
459               case RTAX_INITCWND: {
460                 rule.init_cwnd = *reinterpret_cast<uint32_t*>(RTA_DATA(rtax));
461                 break;
462               }
463               default: {
464                 QUIC_VLOG(2) << absl::StrCat(
465                     "Uninteresting RTA_METRICS attribute: ", rtax->rta_type);
466               }
467             }
468           }
469           break;
470         }
471         default: {
472           QUIC_VLOG(2) << absl::StrCat("Uninteresting attribute: ",
473                                        rta->rta_type);
474         }
475       }
476     }
477     routing_rules_->push_back(rule);
478   }
479 
480  private:
481   std::vector<Netlink::RoutingRule>* routing_rules_;
482 };
483 
484 }  // namespace
485 
GetRouteInfo(std::vector<Netlink::RoutingRule> * routing_rules)486 bool Netlink::GetRouteInfo(std::vector<Netlink::RoutingRule>* routing_rules) {
487   rtmsg route_message{};
488   // Only manipulate main routing table.
489   route_message.rtm_table = RT_TABLE_MAIN;
490 
491   auto message = RouteMessage::New(RtnetlinkMessage::Operation::GET,
492                                    NLM_F_REQUEST | NLM_F_ROOT | NLM_F_MATCH,
493                                    seq_, getpid(), &route_message);
494 
495   if (!Send(message.BuildIoVec().get(), message.IoVecSize())) {
496     QUIC_LOG(ERROR) << "send failed";
497     return false;
498   }
499 
500   RoutingRuleParser parser(routing_rules);
501   if (!Recv(seq_++, &parser)) {
502     QUIC_LOG(ERROR) << "recv failed";
503     return false;
504   }
505 
506   return true;
507 }
508 
ChangeRoute(Netlink::Verb verb,uint32_t table,const IpRange & destination_subnet,uint8_t scope,QuicIpAddress preferred_source,int32_t interface_index,uint32_t init_cwnd)509 bool Netlink::ChangeRoute(Netlink::Verb verb, uint32_t table,
510                           const IpRange& destination_subnet, uint8_t scope,
511                           QuicIpAddress preferred_source,
512                           int32_t interface_index, uint32_t init_cwnd) {
513   if (!destination_subnet.prefix().IsInitialized()) {
514     return false;
515   }
516   if (destination_subnet.address_family() != IpAddressFamily::IP_V4 &&
517       destination_subnet.address_family() != IpAddressFamily::IP_V6) {
518     return false;
519   }
520   if (preferred_source.IsInitialized() &&
521       preferred_source.address_family() !=
522           destination_subnet.address_family()) {
523     return false;
524   }
525 
526   RtnetlinkMessage::Operation operation;
527   uint16_t flags = NLM_F_REQUEST | NLM_F_ACK;
528   switch (verb) {
529     case Verb::kAdd:
530       operation = RtnetlinkMessage::Operation::NEW;
531       // Setting NLM_F_EXCL so that an existing entry for this subnet will fail
532       // the request. NLM_F_CREATE is necessary to indicate this is trying to
533       // create a new entry - simply having RTM_NEWROUTE is not enough even the
534       // name suggests so.
535       flags |= NLM_F_EXCL | NLM_F_CREATE;
536       break;
537     case Verb::kRemove:
538       operation = RtnetlinkMessage::Operation::DEL;
539       break;
540     case Verb::kReplace:
541       operation = RtnetlinkMessage::Operation::NEW;
542       // Setting NLM_F_REPLACE to tell the kernel that existing entry for this
543       // subnet should be replaced.
544       flags |= NLM_F_REPLACE | NLM_F_CREATE;
545       break;
546   }
547 
548   struct rtmsg route_message;
549   memset(&route_message, 0, sizeof(route_message));
550   route_message.rtm_family =
551       destination_subnet.address_family() == IpAddressFamily::IP_V4 ? AF_INET
552                                                                     : AF_INET6;
553   // rtm_dst_len and rtm_src_len are actually the subnet prefix lengths. Poor
554   // naming.
555   route_message.rtm_dst_len = destination_subnet.prefix_length();
556   // 0 means no source subnet for this rule.
557   route_message.rtm_src_len = 0;
558   // Only program the main table. Other tables are intended for the kernel to
559   // manage.
560   route_message.rtm_table = RT_TABLE_MAIN;
561   // Use RTPROT_UNSPEC to match all the different protocol. Rules added by
562   // kernel have RTPROT_KERNEL. Rules added by the root user have RTPROT_STATIC
563   // instead.
564   route_message.rtm_protocol =
565       verb == Verb::kRemove ? RTPROT_UNSPEC : RTPROT_STATIC;
566   route_message.rtm_scope = scope;
567   // Only add unicast routing rule.
568   route_message.rtm_type = RTN_UNICAST;
569   auto message =
570       RouteMessage::New(operation, flags, seq_, getpid(), &route_message);
571 
572   message.AppendAttribute(RTA_TABLE, &table, sizeof(table));
573 
574   if (init_cwnd != kUnspecifiedInitCwnd) {
575     char data[RTA_LENGTH(sizeof(uint32_t))];
576     struct rtattr* rta = reinterpret_cast<struct rtattr*>(data);
577     rta->rta_type = RTAX_INITCWND;
578     rta->rta_len = sizeof(data);
579     *reinterpret_cast<uint32_t*>(RTA_DATA(rta)) = init_cwnd;
580     message.AppendAttribute(RTA_METRICS, data, sizeof(data));
581   }
582 
583   // RTA_OIF is the target interface for this rule.
584   message.AppendAttribute(RTA_OIF, &interface_index, sizeof(interface_index));
585   // The actual destination subnet must be truncated of all the tailing zeros.
586   message.AppendAttribute(
587       RTA_DST,
588       reinterpret_cast<const void*>(
589           destination_subnet.prefix().ToPackedString().c_str()),
590       destination_subnet.prefix().ToPackedString().size());
591   // This is the source address to use in the IP packet should this routing rule
592   // is used.
593   if (preferred_source.IsInitialized()) {
594     auto src_str = preferred_source.ToPackedString();
595     message.AppendAttribute(RTA_PREFSRC,
596                             reinterpret_cast<const void*>(src_str.c_str()),
597                             src_str.size());
598   }
599 
600   if (verb != Verb::kRemove) {
601     auto gateway_str = QboneConstants::GatewayAddress()->ToPackedString();
602     message.AppendAttribute(RTA_GATEWAY,
603                             reinterpret_cast<const void*>(gateway_str.c_str()),
604                             gateway_str.size());
605   }
606 
607   if (!Send(message.BuildIoVec().get(), message.IoVecSize())) {
608     QUIC_LOG(ERROR) << "send failed";
609     return false;
610   }
611 
612   UnknownParser parser;
613   if (!Recv(seq_++, &parser)) {
614     QUIC_LOG(ERROR) << "receive failed.";
615     return false;
616   }
617   return true;
618 }
619 
620 namespace {
621 
622 class IpRuleParser : public NetlinkParserInterface {
623  public:
IpRuleParser(std::vector<Netlink::IpRule> * ip_rules)624   explicit IpRuleParser(std::vector<Netlink::IpRule>* ip_rules)
625       : ip_rules_(ip_rules) {}
626 
Run(struct nlmsghdr * netlink_message)627   void Run(struct nlmsghdr* netlink_message) override {
628     if (netlink_message->nlmsg_type != RTM_NEWRULE) {
629       QUIC_LOG(WARNING) << absl::StrCat(
630           "Unexpected nlmsg_type: ", netlink_message->nlmsg_type,
631           " expected: ", RTM_NEWRULE);
632       return;
633     }
634 
635     auto* rule = reinterpret_cast<rtmsg*>(NLMSG_DATA(netlink_message));
636     int payload_length = RTM_PAYLOAD(netlink_message);
637 
638     if (rule->rtm_family != AF_INET6) {
639       QUIC_LOG(ERROR) << absl::StrCat("Unexpected family: ", rule->rtm_family);
640       return;
641     }
642 
643     Netlink::IpRule ip_rule;
644     ip_rule.table = rule->rtm_table;
645 
646     struct rtattr* rta;
647     for (rta = RTM_RTA(rule); RTA_OK(rta, payload_length);
648          rta = RTA_NEXT(rta, payload_length)) {
649       switch (rta->rta_type) {
650         case RTA_TABLE: {
651           ip_rule.table = *reinterpret_cast<uint32_t*>(RTA_DATA(rta));
652           break;
653         }
654         case RTA_SRC: {
655           QuicIpAddress src_addr;
656           src_addr.FromPackedString(reinterpret_cast<char*>(RTA_DATA(rta)),
657                                     RTA_PAYLOAD(rta));
658           IpRange src_range(src_addr, rule->rtm_src_len);
659           ip_rule.source_range = src_range;
660           break;
661         }
662         default: {
663           QUIC_VLOG(2) << absl::StrCat("Uninteresting attribute: ",
664                                        rta->rta_type);
665         }
666       }
667     }
668     ip_rules_->emplace_back(ip_rule);
669   }
670 
671  private:
672   std::vector<Netlink::IpRule>* ip_rules_;
673 };
674 
675 }  // namespace
676 
GetRuleInfo(std::vector<Netlink::IpRule> * ip_rules)677 bool Netlink::GetRuleInfo(std::vector<Netlink::IpRule>* ip_rules) {
678   rtmsg rule_message{};
679   rule_message.rtm_family = AF_INET6;
680 
681   auto message = RuleMessage::New(RtnetlinkMessage::Operation::GET,
682                                   NLM_F_REQUEST | NLM_F_DUMP, seq_, getpid(),
683                                   &rule_message);
684 
685   if (!Send(message.BuildIoVec().get(), message.IoVecSize())) {
686     QUIC_LOG(ERROR) << "send failed";
687     return false;
688   }
689 
690   IpRuleParser parser(ip_rules);
691   if (!Recv(seq_++, &parser)) {
692     QUIC_LOG(ERROR) << "receive failed.";
693     return false;
694   }
695   return true;
696 }
697 
ChangeRule(Verb verb,uint32_t table,IpRange source_range)698 bool Netlink::ChangeRule(Verb verb, uint32_t table, IpRange source_range) {
699   RtnetlinkMessage::Operation operation;
700   uint16_t flags = NLM_F_REQUEST | NLM_F_ACK;
701 
702   rtmsg rule_message{};
703   rule_message.rtm_family = AF_INET6;
704   rule_message.rtm_protocol = RTPROT_STATIC;
705   rule_message.rtm_scope = RT_SCOPE_UNIVERSE;
706   rule_message.rtm_table = RT_TABLE_UNSPEC;
707 
708   rule_message.rtm_flags |= FIB_RULE_FIND_SADDR;
709 
710   switch (verb) {
711     case Verb::kAdd:
712       if (!source_range.IsInitialized()) {
713         QUIC_LOG(ERROR) << "Source range must be initialized.";
714         return false;
715       }
716       operation = RtnetlinkMessage::Operation::NEW;
717       flags |= NLM_F_EXCL | NLM_F_CREATE;
718       rule_message.rtm_type = FRA_DST;
719       rule_message.rtm_src_len = source_range.prefix_length();
720       break;
721     case Verb::kRemove:
722       operation = RtnetlinkMessage::Operation::DEL;
723       break;
724     case Verb::kReplace:
725       QUIC_LOG(ERROR) << "Unsupported verb: kReplace";
726       return false;
727   }
728   auto message =
729       RuleMessage::New(operation, flags, seq_, getpid(), &rule_message);
730 
731   message.AppendAttribute(RTA_TABLE, &table, sizeof(table));
732 
733   if (source_range.IsInitialized()) {
734     std::string packed_src = source_range.prefix().ToPackedString();
735     message.AppendAttribute(RTA_SRC,
736                             reinterpret_cast<const void*>(packed_src.c_str()),
737                             packed_src.size());
738   }
739 
740   if (!Send(message.BuildIoVec().get(), message.IoVecSize())) {
741     QUIC_LOG(ERROR) << "send failed";
742     return false;
743   }
744 
745   UnknownParser parser;
746   if (!Recv(seq_++, &parser)) {
747     QUIC_LOG(ERROR) << "receive failed.";
748     return false;
749   }
750   return true;
751 }
752 
Send(struct iovec * iov,size_t iovlen)753 bool Netlink::Send(struct iovec* iov, size_t iovlen) {
754   if (!OpenSocket()) {
755     QUIC_LOG(ERROR) << "can't open socket";
756     return false;
757   }
758 
759   // an address for communicating with the kernel netlink code
760   sockaddr_nl netlink_address;
761   memset(&netlink_address, 0, sizeof(netlink_address));
762   netlink_address.nl_family = AF_NETLINK;
763   netlink_address.nl_pid = 0;     // destination is kernel
764   netlink_address.nl_groups = 0;  // no multicast
765 
766   struct msghdr msg = {
767       &netlink_address, sizeof(netlink_address), iov, iovlen, nullptr, 0, 0};
768 
769   if (kernel_->sendmsg(socket_fd_, &msg, 0) < 0) {
770     QUIC_LOG(ERROR) << "sendmsg failed";
771     CloseSocket();
772     return false;
773   }
774 
775   return true;
776 }
777 
Recv(uint32_t seq,NetlinkParserInterface * parser)778 bool Netlink::Recv(uint32_t seq, NetlinkParserInterface* parser) {
779   sockaddr_nl netlink_address;
780 
781   // replies can span multiple packets
782   for (;;) {
783     socklen_t address_length = sizeof(netlink_address);
784 
785     // First, call recvfrom with buffer size of 0 and MSG_PEEK | MSG_TRUNC set
786     // so that we know the size of the incoming packet before actually receiving
787     // it.
788     int next_packet_size = kernel_->recvfrom(
789         socket_fd_, recvbuf_.get(), /* len = */ 0, MSG_PEEK | MSG_TRUNC,
790         reinterpret_cast<struct sockaddr*>(&netlink_address), &address_length);
791     if (next_packet_size < 0) {
792       QUIC_LOG(ERROR)
793           << "error recvfrom with MSG_PEEK | MSG_TRUNC to get packet length.";
794       CloseSocket();
795       return false;
796     }
797     QUIC_VLOG(3) << "netlink packet size: " << next_packet_size;
798     if (next_packet_size > recvbuf_length_) {
799       QUIC_VLOG(2) << "resizing recvbuf to " << next_packet_size;
800       ResetRecvBuf(next_packet_size);
801     }
802 
803     // Get the packet for real.
804     memset(recvbuf_.get(), 0, recvbuf_length_);
805     int len = kernel_->recvfrom(
806         socket_fd_, recvbuf_.get(), recvbuf_length_, /* flags = */ 0,
807         reinterpret_cast<struct sockaddr*>(&netlink_address), &address_length);
808     QUIC_VLOG(3) << "recvfrom returned: " << len;
809     if (len < 0) {
810       QUIC_LOG(INFO) << "can't receive netlink packet";
811       CloseSocket();
812       return false;
813     }
814 
815     // there may be multiple nlmsg's in each reply packet
816     struct nlmsghdr* netlink_message;
817     for (netlink_message = reinterpret_cast<struct nlmsghdr*>(recvbuf_.get());
818          NLMSG_OK(netlink_message, len);
819          netlink_message = NLMSG_NEXT(netlink_message, len)) {
820       QUIC_VLOG(3) << "netlink_message->nlmsg_type = "
821                    << netlink_message->nlmsg_type;
822       // make sure this is to us
823       if (netlink_message->nlmsg_seq != seq) {
824         QUIC_LOG(INFO) << "netlink_message not meant for us."
825                        << " seq: " << seq
826                        << " nlmsg_seq: " << netlink_message->nlmsg_seq;
827         continue;
828       }
829 
830       // done with this whole reply (not just this particular packet)
831       if (netlink_message->nlmsg_type == NLMSG_DONE) {
832         return true;
833       }
834       if (netlink_message->nlmsg_type == NLMSG_ERROR) {
835         struct nlmsgerr* err =
836             reinterpret_cast<struct nlmsgerr*>(NLMSG_DATA(netlink_message));
837         if (netlink_message->nlmsg_len <
838             NLMSG_LENGTH(sizeof(struct nlmsgerr))) {
839           QUIC_LOG(INFO) << "netlink_message ERROR truncated";
840         } else {
841           // an ACK
842           if (err->error == 0) {
843             QUIC_VLOG(3) << "Netlink sent an ACK";
844             return true;
845           }
846           QUIC_LOG(INFO) << "netlink_message ERROR: " << err->error;
847         }
848         return false;
849       }
850 
851       parser->Run(netlink_message);
852     }
853   }
854 }
855 
856 }  // namespace quic
857