1 // Copyright 2024 Google LLC
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // https://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14
15 use anyhow::anyhow;
16 use bytes::Bytes;
17 use log::{debug, warn};
18 use socket2::{Protocol, Socket};
19 use std::mem::MaybeUninit;
20 use std::net::{Ipv4Addr, SocketAddr, SocketAddrV4};
21 use std::sync::mpsc;
22
/// IPv4 multicast group address used by Multicast DNS (RFC 6762).
const MDNS_IP: Ipv4Addr = Ipv4Addr::new(224, 0, 0, 251);
/// UDP port used by Multicast DNS (RFC 6762).
const MDNS_PORT: u16 = 5353;
25
/// A 48-bit MAC address kept in the low six bytes of a `u64`.
///
/// The address is stored little-endian: octet 0 (the first octet on the wire) is the
/// least-significant byte of the `u64`, and the top two bytes are unused.
struct MacAddress(u64);

impl MacAddress {
    /// Returns the six address octets in wire order.
    ///
    /// The name mirrors the integer `to_be_bytes` methods so a `MacAddress` can be passed
    /// straight to the `be_vec!` macro. The bytes come from a little-endian decode because
    /// that is how `From<&[u8; 6]>` packed them.
    fn to_be_bytes(&self) -> [u8; 6] {
        let mut octets = [0u8; 6];
        octets.copy_from_slice(&self.0.to_le_bytes()[..6]);
        octets
    }
}

impl From<MacAddress> for [u8; 6] {
    /// Unpacks the address into its six wire-order octets.
    fn from(mac: MacAddress) -> Self {
        mac.to_be_bytes()
    }
}

impl From<&[u8; 6]> for MacAddress {
    /// Packs six wire-order octets into a `MacAddress`, leaving the top two bytes zero.
    fn from(octets: &[u8; 6]) -> Self {
        let mut widened = [0u8; 8];
        widened[..6].copy_from_slice(octets);
        MacAddress(u64::from_le_bytes(widened))
    }
}
47
/// Fixed 20-byte IPv4 header (no options), declared field-for-field in wire order.
///
/// Multi-byte integer fields hold host-order values; `to_be_bytes` converts them to
/// network byte order when serializing.
#[repr(C, packed)]
struct Ipv4Header {
    /// High nibble: version; low nibble: Internet Header Length (in 32-bit words).
    version_ihl: u8,
    /// High 6 bits: Differentiated Services Code Point; low 2 bits: Explicit Congestion
    /// Notification.
    dscp_ecn: u8,
    total_length: u16,
    identification: u16,
    /// High 3 bits: flags; low 13 bits: fragment offset.
    flags_fragment_offset: u16,
    time_to_live: u8,
    protocol: u8,
    header_checksum: u16,
    source_ip: [u8; 4],
    destination_ip: [u8; 4],
}

/// Concatenates the `to_be_bytes()` output of each argument, in order, into a `Vec<u8>`.
macro_rules! be_vec {
    ( $( $x:expr ),* ) => {{
        let mut bytes = Vec::<u8>::new();
        $( bytes.extend_from_slice(&$x.to_be_bytes()); )*
        bytes
    }};
}

impl Ipv4Header {
    /// Computes the Internet checksum (one's complement of the one's-complement sum of
    /// 16-bit words) over the serialized 20-byte header.
    ///
    /// The current `header_checksum` value is included in the sum. Zero it first to get a
    /// fresh checksum; on a header whose checksum is already correct the result is 0.
    fn calculate_checksum(&self) -> u16 {
        let total: u32 = self
            .to_be_bytes()
            .chunks_exact(2)
            .map(|word| u32::from(u16::from_be_bytes([word[0], word[1]])))
            .sum();

        // End-around carry: fold anything above 16 bits back into the low word.
        let mut folded = total;
        while folded > 0xFFFF {
            folded = (folded & 0xFFFF) + (folded >> 16);
        }

        !(folded as u16)
    }

    /// Recomputes `header_checksum` in place from the other header fields.
    fn update_checksum(&mut self) {
        // The checksum field must count as zero while the checksum is computed.
        self.header_checksum = 0;
        let checksum = self.calculate_checksum();
        self.header_checksum = checksum;
    }

    /// Serializes the header into its 20-byte, network-byte-order wire form.
    fn to_be_bytes(&self) -> [u8; 20] {
        let mut wire = [0u8; 20];
        wire[0] = self.version_ihl;
        wire[1] = self.dscp_ecn;
        wire[2..4].copy_from_slice(&self.total_length.to_be_bytes());
        wire[4..6].copy_from_slice(&self.identification.to_be_bytes());
        wire[6..8].copy_from_slice(&self.flags_fragment_offset.to_be_bytes());
        wire[8] = self.time_to_live;
        wire[9] = self.protocol;
        wire[10..12].copy_from_slice(&self.header_checksum.to_be_bytes());
        // Copy the addresses out of the packed struct before slicing them.
        let (source, destination) = (self.source_ip, self.destination_ip);
        wire[12..16].copy_from_slice(&source);
        wire[16..20].copy_from_slice(&destination);
        wire
    }
}
111
/// Fixed 8-byte UDP header, declared in wire order.
///
/// Fields hold host-order values; `to_be_bytes` produces the network-byte-order form.
#[repr(C, packed)]
struct UdpHeader {
    source_port: u16,
    destination_port: u16,
    /// Length in bytes of the UDP header plus its payload.
    length: u16,
    /// For UDP over IPv4, 0 means no checksum was computed (RFC 768).
    checksum: u16,
}

impl UdpHeader {
    /// Serializes the header into its 8-byte wire form: four big-endian 16-bit words.
    fn to_be_bytes(&self) -> [u8; 8] {
        let words = [self.source_port, self.destination_port, self.length, self.checksum];
        let mut wire = [0u8; 8];
        for (slot, word) in wire.chunks_exact_mut(2).zip(words.iter()) {
            slot.copy_from_slice(&word.to_be_bytes());
        }
        wire
    }
}
127
128 /* 10Mb/s ethernet header */
129
130 #[repr(C, packed)]
131 struct EtherHeader {
132 ether_dhost: [u8; 6],
133 ether_shost: [u8; 6],
134 ether_type: u16,
135 }
136
137 /* Ethernet protocol ID's */
138 const ETHER_TYPE_IP: u16 = 0x0800;
139
140 impl EtherHeader {
to_be_bytes(&self) -> [u8; 14]141 fn to_be_bytes(&self) -> [u8; 14] {
142 let v: Vec<u8> = be_vec![
143 MacAddress::from(&self.ether_dhost),
144 MacAddress::from(&self.ether_shost),
145 self.ether_type
146 ];
147 v.try_into().unwrap()
148 }
149 }
150
151 // Define constants for header sizes (bytes)
152 const UDP_HEADER_LEN: usize = std::mem::size_of::<UdpHeader>();
153 const IPV4_HEADER_LEN: usize = std::mem::size_of::<Ipv4Header>();
154 const ETHER_HEADER_LEN: usize = std::mem::size_of::<EtherHeader>();
155
156 /// Creates a new UDP socket to bind to `port` with REUSEPORT option.
157 /// `non_block` indicates whether to set O_NONBLOCK for the socket.
new_socket(addr: SocketAddr, non_block: bool) -> anyhow::Result<Socket>158 fn new_socket(addr: SocketAddr, non_block: bool) -> anyhow::Result<Socket> {
159 let domain = match addr {
160 SocketAddr::V4(_) => socket2::Domain::IPV4,
161 SocketAddr::V6(_) => socket2::Domain::IPV6,
162 };
163
164 let socket = Socket::new(domain, socket2::Type::DGRAM, Some(Protocol::UDP))
165 .map_err(|e| anyhow!("create socket failed: {:?}", e))?;
166
167 socket.set_reuse_address(true).map_err(|e| anyhow!("set ReuseAddr failed: {:?}", e))?;
168 #[cfg(not(windows))]
169 socket.set_reuse_port(true)?;
170
171 #[cfg(unix)] // this is currently restricted to Unix's in socket2
172 socket.set_reuse_port(true).map_err(|e| anyhow!("set ReusePort failed: {:?}", e))?;
173
174 if non_block {
175 socket.set_nonblocking(true).map_err(|e| anyhow!("set O_NONBLOCK: {:?}", e))?;
176 }
177
178 socket.join_multicast_v4(&MDNS_IP, &Ipv4Addr::UNSPECIFIED)?;
179 socket.set_multicast_loop_v4(false).expect("set_multicast_loop_v4 call failed");
180
181 socket.bind(&addr.into()).map_err(|e| anyhow!("socket bind to {} failed: {:?}", &addr, e))?;
182
183 Ok(socket)
184 }
185
create_ethernet_frame(packet: &[u8], ip_addr: &Ipv4Addr) -> anyhow::Result<Vec<u8>>186 fn create_ethernet_frame(packet: &[u8], ip_addr: &Ipv4Addr) -> anyhow::Result<Vec<u8>> {
187 // TODO: Use the etherparse crate
188 let ether_header = EtherHeader {
189 // mDNS multicast IP address
190 ether_dhost: [0x01, 0x00, 0x5e, 0x00, 0x00, 0xfb],
191 ether_shost: [0x01, 0x00, 0x5e, 0x00, 0x00, 0xfb],
192 ether_type: ETHER_TYPE_IP,
193 };
194
195 // Create UDP Header
196 let udp_header = UdpHeader {
197 source_port: MDNS_PORT,
198 destination_port: MDNS_PORT,
199 length: (packet.len() + UDP_HEADER_LEN) as u16,
200 // Usually 0 for mDNS
201 checksum: 0,
202 };
203
204 // Create IPv4 Header
205 let mut ipv4_header = Ipv4Header {
206 version_ihl: 0x45,
207 dscp_ecn: 0,
208 total_length: (packet.len() + UDP_HEADER_LEN + IPV4_HEADER_LEN) as u16,
209 identification: 0,
210 flags_fragment_offset: 0,
211 time_to_live: 64,
212 protocol: 17,
213 header_checksum: 0,
214 source_ip: ip_addr.octets(),
215 // mDNS multicast
216 destination_ip: MDNS_IP.octets(),
217 };
218 ipv4_header.update_checksum();
219
220 // Combine Headers and Payload (Safely using Vec)
221 let mut response_packet =
222 Vec::with_capacity(ETHER_HEADER_LEN + IPV4_HEADER_LEN + UDP_HEADER_LEN + packet.len());
223 response_packet.extend_from_slice(ðer_header.to_be_bytes());
224 response_packet.extend_from_slice(&ipv4_header.to_be_bytes());
225 response_packet.extend_from_slice(&udp_header.to_be_bytes());
226 response_packet.extend_from_slice(packet);
227
228 Ok(response_packet)
229 }
230
run_mdns_forwarder(tx: mpsc::Sender<Bytes>) -> anyhow::Result<()>231 pub fn run_mdns_forwarder(tx: mpsc::Sender<Bytes>) -> anyhow::Result<()> {
232 let addr = SocketAddrV4::new(Ipv4Addr::new(0, 0, 0, 0), MDNS_PORT);
233 let socket = new_socket(addr.into(), false)?;
234
235 // Typical max mDNS packet size
236 let mut buf: [MaybeUninit<u8>; 1500] = [MaybeUninit::new(0 as u8); 1500];
237 loop {
238 let (size, src_addr) = socket.recv_from(&mut buf[..])?;
239 // SAFETY: `recv_from` implementation promises not to write uninitialized bytes to `buf`.
240 // Documentation: https://docs.rs/socket2/latest/socket2/struct.Socket.html#method.recv_from
241 let packet = unsafe { &*(&buf[..size] as *const [MaybeUninit<u8>] as *const [u8]) };
242 if let Some(socket_addr_v4) = src_addr.as_socket_ipv4() {
243 debug!("Received {} bytes from {:?}", packet.len(), socket_addr_v4);
244 match create_ethernet_frame(packet, socket_addr_v4.ip()) {
245 Ok(ethernet_frame) => {
246 if let Err(e) = tx.send(ethernet_frame.into()) {
247 warn!("Failed to send packet: {e}");
248 }
249 }
250 Err(e) => warn!("Failed to create ethernet frame from UDP payload: {}", e),
251 };
252 } else {
253 warn!("Forwarding mDNS from IPv6 is not supported: {:?}", src_addr);
254 }
255 }
256 }
257