xref: /aosp_15_r20/tools/netsim/rust/daemon/src/wifi/mdns_forwarder.rs (revision cf78ab8cffb8fc9207af348f23af247fb04370a6)
1 // Copyright 2024 Google LLC
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     https://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 use anyhow::anyhow;
16 use bytes::Bytes;
17 use log::{debug, warn};
18 use socket2::{Protocol, Socket};
19 use std::mem::MaybeUninit;
20 use std::net::{Ipv4Addr, SocketAddr, SocketAddrV4};
21 use std::sync::mpsc;
22 
/// IPv4 multicast group address for mDNS (224.0.0.251).
const MDNS_IP: Ipv4Addr = Ipv4Addr::new(224, 0, 0, 251);
/// mDNS UDP port; used as the bind port and as both UDP ports of forwarded frames.
const MDNS_PORT: u16 = 5353;
25 
/// A 48-bit MAC address held in the low six bytes of a `u64`.
///
/// Byte `i` of the `[u8; 6]` form lives at bits `8 * i .. 8 * i + 8`, i.e. the integer
/// stores the address in little-endian order.
struct MacAddress(u64);

impl MacAddress {
    /// Returns the six address bytes in the same order as the `[u8; 6]` the address was
    /// built from.
    ///
    /// NOTE: despite the name, the `u64` is read little-endian, because that is how the
    /// address is stored; any bits above the low 48 are ignored.
    fn to_be_bytes(&self) -> [u8; 6] {
        <[u8; 6]>::from(MacAddress(self.0))
    }
}

impl From<MacAddress> for [u8; 6] {
    /// Extracts the low six little-endian bytes of the stored integer.
    fn from(mac: MacAddress) -> Self {
        let le = mac.0.to_le_bytes();
        [le[0], le[1], le[2], le[3], le[4], le[5]]
    }
}

impl From<&[u8; 6]> for MacAddress {
    /// Packs the six bytes little-endian into a `u64`, leaving the top two bytes zero.
    fn from(bytes: &[u8; 6]) -> Self {
        let mut wide = [0u8; 8];
        wide[..6].copy_from_slice(bytes);
        Self(u64::from_le_bytes(wide))
    }
}
47 
/// Fixed-size IPv4 header without options, 20 bytes on the wire.
///
/// Multi-byte fields hold host-order values and are converted to network byte order only
/// when the header is serialised. `packed` keeps `size_of::<Ipv4Header>()` equal to the
/// on-wire size.
#[repr(C, packed)]
struct Ipv4Header {
    /// Version (high 4 bits) and Internet Header Length (low 4 bits).
    version_ihl: u8,
    /// Differentiated Services Code Point (high 6 bits) and Explicit Congestion
    /// Notification (low 2 bits).
    dscp_ecn: u8,
    /// Length of the whole IP packet (header plus payload) in bytes.
    total_length: u16,
    identification: u16,
    /// Flags (high 3 bits) and Fragment Offset (low 13 bits).
    flags_fragment_offset: u16,
    time_to_live: u8,
    /// IP protocol number of the payload.
    protocol: u8,
    header_checksum: u16,
    source_ip: [u8; 4],
    destination_ip: [u8; 4],
}
61 
/// Concatenates the `to_be_bytes()` output of each argument, in order, into a collection
/// of `u8` (a `Vec<u8>` at every call site in this file).
///
/// Each argument must have a `to_be_bytes()` method returning something iterable over `u8`.
macro_rules! be_vec {
    ( $( $x:expr ),* ) => {
        ::std::iter::empty::<u8>()
            $( .chain($x.to_be_bytes()) )*
            .collect()
    };
}
69 
70 impl Ipv4Header {
calculate_checksum(&self) -> u1671     fn calculate_checksum(&self) -> u16 {
72         let mut sum: u32 = 0;
73 
74         // Process fixed-size fields (first 20 bytes)
75         let fixed_bytes: [u8; 20] = self.to_be_bytes();
76         for i in 0..10 {
77             let word = ((fixed_bytes[i * 2] as u16) << 8) | (fixed_bytes[i * 2 + 1] as u16);
78             sum += word as u32;
79         }
80 
81         // Handle carries (fold the carry into the sum)
82         while (sum >> 16) > 0 {
83             sum = (sum & 0xFFFF) + (sum >> 16);
84         }
85 
86         // One's complement
87         !sum as u16
88     }
89 
update_checksum(&mut self)90     fn update_checksum(&mut self) {
91         self.header_checksum = 0; // Reset checksum before calculation
92         self.header_checksum = self.calculate_checksum();
93     }
94 
to_be_bytes(&self) -> [u8; 20]95     fn to_be_bytes(&self) -> [u8; 20] {
96         let mut v: Vec<u8> = be_vec![
97             self.version_ihl,
98             self.dscp_ecn,
99             self.total_length,
100             self.identification,
101             self.flags_fragment_offset,
102             self.time_to_live,
103             self.protocol,
104             self.header_checksum
105         ];
106         v.extend(Ipv4Addr::from(self.source_ip).octets());
107         v.extend(Ipv4Addr::from(self.destination_ip).octets());
108         v.try_into().unwrap()
109     }
110 }
111 
/// UDP header, 8 bytes on the wire.
///
/// Fields hold host-order values; `to_be_bytes` emits them in network byte order.
#[repr(C, packed)]
struct UdpHeader {
    source_port: u16,
    destination_port: u16,
    /// Length of the UDP header plus payload, in bytes.
    length: u16,
    checksum: u16,
}

impl UdpHeader {
    /// Serialises the header as four big-endian 16-bit words.
    fn to_be_bytes(&self) -> [u8; 8] {
        let words = [self.source_port, self.destination_port, self.length, self.checksum];
        let mut out = [0u8; 8];
        for (slot, word) in out.chunks_exact_mut(2).zip(words) {
            slot.copy_from_slice(&word.to_be_bytes());
        }
        out
    }
}
127 
/// Ethernet frame header: destination MAC, source MAC, EtherType (14 bytes on the wire).
///
/// `packed` keeps `size_of::<EtherHeader>()` equal to the on-wire size.
#[repr(C, packed)]
struct EtherHeader {
    /// Destination MAC address bytes.
    ether_dhost: [u8; 6],
    /// Source MAC address bytes.
    ether_shost: [u8; 6],
    /// EtherType of the payload, as a host-order value.
    ether_type: u16,
}
136 
/// EtherType (Ethernet protocol ID) for an IPv4 payload.
const ETHER_TYPE_IP: u16 = 0x0800;
139 
140 impl EtherHeader {
to_be_bytes(&self) -> [u8; 14]141     fn to_be_bytes(&self) -> [u8; 14] {
142         let v: Vec<u8> = be_vec![
143             MacAddress::from(&self.ether_dhost),
144             MacAddress::from(&self.ether_shost),
145             self.ether_type
146         ];
147         v.try_into().unwrap()
148     }
149 }
150 
151 // Define constants for header sizes (bytes)
152 const UDP_HEADER_LEN: usize = std::mem::size_of::<UdpHeader>();
153 const IPV4_HEADER_LEN: usize = std::mem::size_of::<Ipv4Header>();
154 const ETHER_HEADER_LEN: usize = std::mem::size_of::<EtherHeader>();
155 
156 /// Creates a new UDP socket to bind to `port` with REUSEPORT option.
157 /// `non_block` indicates whether to set O_NONBLOCK for the socket.
new_socket(addr: SocketAddr, non_block: bool) -> anyhow::Result<Socket>158 fn new_socket(addr: SocketAddr, non_block: bool) -> anyhow::Result<Socket> {
159     let domain = match addr {
160         SocketAddr::V4(_) => socket2::Domain::IPV4,
161         SocketAddr::V6(_) => socket2::Domain::IPV6,
162     };
163 
164     let socket = Socket::new(domain, socket2::Type::DGRAM, Some(Protocol::UDP))
165         .map_err(|e| anyhow!("create socket failed: {:?}", e))?;
166 
167     socket.set_reuse_address(true).map_err(|e| anyhow!("set ReuseAddr failed: {:?}", e))?;
168     #[cfg(not(windows))]
169     socket.set_reuse_port(true)?;
170 
171     #[cfg(unix)] // this is currently restricted to Unix's in socket2
172     socket.set_reuse_port(true).map_err(|e| anyhow!("set ReusePort failed: {:?}", e))?;
173 
174     if non_block {
175         socket.set_nonblocking(true).map_err(|e| anyhow!("set O_NONBLOCK: {:?}", e))?;
176     }
177 
178     socket.join_multicast_v4(&MDNS_IP, &Ipv4Addr::UNSPECIFIED)?;
179     socket.set_multicast_loop_v4(false).expect("set_multicast_loop_v4 call failed");
180 
181     socket.bind(&addr.into()).map_err(|e| anyhow!("socket bind to {} failed: {:?}", &addr, e))?;
182 
183     Ok(socket)
184 }
185 
create_ethernet_frame(packet: &[u8], ip_addr: &Ipv4Addr) -> anyhow::Result<Vec<u8>>186 fn create_ethernet_frame(packet: &[u8], ip_addr: &Ipv4Addr) -> anyhow::Result<Vec<u8>> {
187     // TODO: Use the etherparse crate
188     let ether_header = EtherHeader {
189         // mDNS multicast IP address
190         ether_dhost: [0x01, 0x00, 0x5e, 0x00, 0x00, 0xfb],
191         ether_shost: [0x01, 0x00, 0x5e, 0x00, 0x00, 0xfb],
192         ether_type: ETHER_TYPE_IP,
193     };
194 
195     // Create UDP Header
196     let udp_header = UdpHeader {
197         source_port: MDNS_PORT,
198         destination_port: MDNS_PORT,
199         length: (packet.len() + UDP_HEADER_LEN) as u16,
200         // Usually 0 for mDNS
201         checksum: 0,
202     };
203 
204     // Create IPv4 Header
205     let mut ipv4_header = Ipv4Header {
206         version_ihl: 0x45,
207         dscp_ecn: 0,
208         total_length: (packet.len() + UDP_HEADER_LEN + IPV4_HEADER_LEN) as u16,
209         identification: 0,
210         flags_fragment_offset: 0,
211         time_to_live: 64,
212         protocol: 17,
213         header_checksum: 0,
214         source_ip: ip_addr.octets(),
215         // mDNS multicast
216         destination_ip: MDNS_IP.octets(),
217     };
218     ipv4_header.update_checksum();
219 
220     // Combine Headers and Payload (Safely using Vec)
221     let mut response_packet =
222         Vec::with_capacity(ETHER_HEADER_LEN + IPV4_HEADER_LEN + UDP_HEADER_LEN + packet.len());
223     response_packet.extend_from_slice(&ether_header.to_be_bytes());
224     response_packet.extend_from_slice(&ipv4_header.to_be_bytes());
225     response_packet.extend_from_slice(&udp_header.to_be_bytes());
226     response_packet.extend_from_slice(packet);
227 
228     Ok(response_packet)
229 }
230 
run_mdns_forwarder(tx: mpsc::Sender<Bytes>) -> anyhow::Result<()>231 pub fn run_mdns_forwarder(tx: mpsc::Sender<Bytes>) -> anyhow::Result<()> {
232     let addr = SocketAddrV4::new(Ipv4Addr::new(0, 0, 0, 0), MDNS_PORT);
233     let socket = new_socket(addr.into(), false)?;
234 
235     // Typical max mDNS packet size
236     let mut buf: [MaybeUninit<u8>; 1500] = [MaybeUninit::new(0 as u8); 1500];
237     loop {
238         let (size, src_addr) = socket.recv_from(&mut buf[..])?;
239         // SAFETY: `recv_from` implementation promises not to write uninitialized bytes to `buf`.
240         // Documentation: https://docs.rs/socket2/latest/socket2/struct.Socket.html#method.recv_from
241         let packet = unsafe { &*(&buf[..size] as *const [MaybeUninit<u8>] as *const [u8]) };
242         if let Some(socket_addr_v4) = src_addr.as_socket_ipv4() {
243             debug!("Received {} bytes from {:?}", packet.len(), socket_addr_v4);
244             match create_ethernet_frame(packet, socket_addr_v4.ip()) {
245                 Ok(ethernet_frame) => {
246                     if let Err(e) = tx.send(ethernet_frame.into()) {
247                         warn!("Failed to send packet: {e}");
248                     }
249                 }
250                 Err(e) => warn!("Failed to create ethernet frame from UDP payload: {}", e),
251             };
252         } else {
253             warn!("Forwarding mDNS from IPv6 is not supported: {:?}", src_addr);
254         }
255     }
256 }
257