1 //! The BSD sockets API requires us to read the `ss_family` field before we can
2 //! interpret the rest of a `sockaddr` produced by the kernel.
3 #![allow(unsafe_code)]
4 
5 use crate::backend::c;
6 use crate::io;
7 #[cfg(target_os = "linux")]
8 use crate::net::xdp::{SockaddrXdpFlags, SocketAddrXdp};
9 use crate::net::{Ipv4Addr, Ipv6Addr, SocketAddrAny, SocketAddrUnix, SocketAddrV4, SocketAddrV6};
10 use core::mem::size_of;
11 use core::slice;
12 
// The leading family tag shared by every `sockaddr` variant.
//
// This must match the header of `sockaddr`: a single `u16` `ss_family` field
// at offset 0. It is the only part of a socket address we can safely access
// before knowing which concrete `sockaddr_*` layout follows.
#[repr(C)]
struct sockaddr_header {
    ss_family: u16,
}
18 
19 /// Read the `ss_family` field from a socket address returned from the OS.
20 ///
21 /// # Safety
22 ///
23 /// `storage` must point to a valid socket address returned from the OS.
24 #[inline]
read_ss_family(storage: *const c::sockaddr) -> u1625 unsafe fn read_ss_family(storage: *const c::sockaddr) -> u16 {
26     // Assert that we know the layout of `sockaddr`.
27     let _ = c::sockaddr {
28         __storage: c::sockaddr_storage {
29             __bindgen_anon_1: linux_raw_sys::net::__kernel_sockaddr_storage__bindgen_ty_1 {
30                 __bindgen_anon_1:
31                     linux_raw_sys::net::__kernel_sockaddr_storage__bindgen_ty_1__bindgen_ty_1 {
32                         ss_family: 0_u16,
33                         __data: [0; 126_usize],
34                     },
35             },
36         },
37     };
38 
39     (*storage.cast::<sockaddr_header>()).ss_family
40 }
41 
42 /// Set the `ss_family` field of a socket address to `AF_UNSPEC`, so that we
43 /// can test for `AF_UNSPEC` to test whether it was stored to.
44 #[inline]
initialize_family_to_unspec(storage: *mut c::sockaddr)45 pub(crate) unsafe fn initialize_family_to_unspec(storage: *mut c::sockaddr) {
46     (*storage.cast::<sockaddr_header>()).ss_family = c::AF_UNSPEC as _;
47 }
48 
49 /// Read a socket address encoded in a platform-specific format.
50 ///
51 /// # Safety
52 ///
53 /// `storage` must point to valid socket address storage.
read_sockaddr( storage: *const c::sockaddr, len: usize, ) -> io::Result<SocketAddrAny>54 pub(crate) unsafe fn read_sockaddr(
55     storage: *const c::sockaddr,
56     len: usize,
57 ) -> io::Result<SocketAddrAny> {
58     let offsetof_sun_path = super::addr::offsetof_sun_path();
59 
60     if len < size_of::<c::sa_family_t>() {
61         return Err(io::Errno::INVAL);
62     }
63     match read_ss_family(storage).into() {
64         c::AF_INET => {
65             if len < size_of::<c::sockaddr_in>() {
66                 return Err(io::Errno::INVAL);
67             }
68             let decode = &*storage.cast::<c::sockaddr_in>();
69             Ok(SocketAddrAny::V4(SocketAddrV4::new(
70                 Ipv4Addr::from(u32::from_be(decode.sin_addr.s_addr)),
71                 u16::from_be(decode.sin_port),
72             )))
73         }
74         c::AF_INET6 => {
75             if len < size_of::<c::sockaddr_in6>() {
76                 return Err(io::Errno::INVAL);
77             }
78             let decode = &*storage.cast::<c::sockaddr_in6>();
79             Ok(SocketAddrAny::V6(SocketAddrV6::new(
80                 Ipv6Addr::from(decode.sin6_addr.in6_u.u6_addr8),
81                 u16::from_be(decode.sin6_port),
82                 u32::from_be(decode.sin6_flowinfo),
83                 decode.sin6_scope_id,
84             )))
85         }
86         c::AF_UNIX => {
87             if len < offsetof_sun_path {
88                 return Err(io::Errno::INVAL);
89             }
90             if len == offsetof_sun_path {
91                 Ok(SocketAddrAny::Unix(SocketAddrUnix::new(&[][..])?))
92             } else {
93                 let decode = &*storage.cast::<c::sockaddr_un>();
94 
95                 // On Linux check for Linux's [abstract namespace].
96                 //
97                 // [abstract namespace]: https://man7.org/linux/man-pages/man7/unix.7.html
98                 if decode.sun_path[0] == 0 {
99                     let bytes = &decode.sun_path[1..len - offsetof_sun_path];
100 
101                     // SAFETY: Convert `&[c_char]` to `&[u8]`.
102                     let bytes = slice::from_raw_parts(bytes.as_ptr().cast::<u8>(), bytes.len());
103 
104                     return SocketAddrUnix::new_abstract_name(bytes).map(SocketAddrAny::Unix);
105                 }
106 
107                 // Otherwise we expect a NUL-terminated filesystem path.
108                 let bytes = &decode.sun_path[..len - 1 - offsetof_sun_path];
109 
110                 // SAFETY: Convert `&[c_char]` to `&[u8]`.
111                 let bytes = slice::from_raw_parts(bytes.as_ptr().cast::<u8>(), bytes.len());
112 
113                 assert_eq!(decode.sun_path[len - 1 - offsetof_sun_path], 0);
114                 Ok(SocketAddrAny::Unix(SocketAddrUnix::new(bytes)?))
115             }
116         }
117         #[cfg(target_os = "linux")]
118         c::AF_XDP => {
119             if len < size_of::<c::sockaddr_xdp>() {
120                 return Err(io::Errno::INVAL);
121             }
122             let decode = &*storage.cast::<c::sockaddr_xdp>();
123             Ok(SocketAddrAny::Xdp(SocketAddrXdp::new(
124                 SockaddrXdpFlags::from_bits_retain(decode.sxdp_flags),
125                 u32::from_be(decode.sxdp_ifindex),
126                 u32::from_be(decode.sxdp_queue_id),
127                 u32::from_be(decode.sxdp_shared_umem_fd),
128             )))
129         }
130         _ => Err(io::Errno::NOTSUP),
131     }
132 }
133 
134 /// Read an optional socket address returned from the OS.
135 ///
136 /// # Safety
137 ///
138 /// `storage` must point to a valid socket address returned from the OS.
maybe_read_sockaddr_os( storage: *const c::sockaddr, len: usize, ) -> Option<SocketAddrAny>139 pub(crate) unsafe fn maybe_read_sockaddr_os(
140     storage: *const c::sockaddr,
141     len: usize,
142 ) -> Option<SocketAddrAny> {
143     if len == 0 {
144         None
145     } else {
146         Some(read_sockaddr_os(storage, len))
147     }
148 }
149 
150 /// Read a socket address returned from the OS.
151 ///
152 /// # Safety
153 ///
154 /// `storage` must point to a valid socket address returned from the OS.
read_sockaddr_os(storage: *const c::sockaddr, len: usize) -> SocketAddrAny155 pub(crate) unsafe fn read_sockaddr_os(storage: *const c::sockaddr, len: usize) -> SocketAddrAny {
156     let offsetof_sun_path = super::addr::offsetof_sun_path();
157 
158     assert!(len >= size_of::<c::sa_family_t>());
159     match read_ss_family(storage).into() {
160         c::AF_INET => {
161             assert!(len >= size_of::<c::sockaddr_in>());
162             let decode = &*storage.cast::<c::sockaddr_in>();
163             SocketAddrAny::V4(SocketAddrV4::new(
164                 Ipv4Addr::from(u32::from_be(decode.sin_addr.s_addr)),
165                 u16::from_be(decode.sin_port),
166             ))
167         }
168         c::AF_INET6 => {
169             assert!(len >= size_of::<c::sockaddr_in6>());
170             let decode = &*storage.cast::<c::sockaddr_in6>();
171             SocketAddrAny::V6(SocketAddrV6::new(
172                 Ipv6Addr::from(decode.sin6_addr.in6_u.u6_addr8),
173                 u16::from_be(decode.sin6_port),
174                 u32::from_be(decode.sin6_flowinfo),
175                 decode.sin6_scope_id,
176             ))
177         }
178         c::AF_UNIX => {
179             assert!(len >= offsetof_sun_path);
180             if len == offsetof_sun_path {
181                 SocketAddrAny::Unix(SocketAddrUnix::new(&[][..]).unwrap())
182             } else {
183                 let decode = &*storage.cast::<c::sockaddr_un>();
184 
185                 // On Linux check for Linux's [abstract namespace].
186                 //
187                 // [abstract namespace]: https://man7.org/linux/man-pages/man7/unix.7.html
188                 if decode.sun_path[0] == 0 {
189                     let bytes = &decode.sun_path[1..len - offsetof_sun_path];
190 
191                     // SAFETY: Convert `&[c_char]` to `&[u8]`.
192                     let bytes = slice::from_raw_parts(bytes.as_ptr().cast::<u8>(), bytes.len());
193 
194                     return SocketAddrAny::Unix(SocketAddrUnix::new_abstract_name(bytes).unwrap());
195                 }
196 
197                 // Otherwise we expect a NUL-terminated filesystem path.
198                 assert_eq!(decode.sun_path[len - 1 - offsetof_sun_path], 0);
199 
200                 let bytes = &decode.sun_path[..len - 1 - offsetof_sun_path];
201 
202                 // SAFETY: Convert `&[c_char]` to `&[u8]`.
203                 let bytes = slice::from_raw_parts(bytes.as_ptr().cast::<u8>(), bytes.len());
204 
205                 SocketAddrAny::Unix(SocketAddrUnix::new(bytes).unwrap())
206             }
207         }
208         #[cfg(target_os = "linux")]
209         c::AF_XDP => {
210             assert!(len >= size_of::<c::sockaddr_xdp>());
211             let decode = &*storage.cast::<c::sockaddr_xdp>();
212             SocketAddrAny::Xdp(SocketAddrXdp::new(
213                 SockaddrXdpFlags::from_bits_retain(decode.sxdp_flags),
214                 u32::from_be(decode.sxdp_ifindex),
215                 u32::from_be(decode.sxdp_queue_id),
216                 u32::from_be(decode.sxdp_shared_umem_fd),
217             ))
218         }
219         other => unimplemented!("{:?}", other),
220     }
221 }
222