xref: /aosp_15_r20/external/crosvm/win_audio/src/win_audio_impl/wave_format.rs (revision bb4ee6a4ae7042d18b07a98463b9c8b875e44b39)
1 // Copyright 2022 The ChromiumOS Authors
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4 
5 use std::convert::TryInto;
6 use std::fmt;
7 use std::fmt::Debug;
8 use std::fmt::Formatter;
9 
10 use base::error;
11 use base::info;
12 use base::warn;
13 use base::Error;
14 use metrics::sys::WaveFormat as WaveFormatMetric;
15 use metrics::sys::WaveFormatDetails as WaveFormatDetailsMetric;
16 use metrics::sys::WaveFormatSubFormat as WaveFormatSubFormatMetric;
17 use metrics::MetricEventType;
18 use winapi::shared::guiddef::IsEqualGUID;
19 use winapi::shared::guiddef::GUID;
20 use winapi::shared::ksmedia::KSDATAFORMAT_SUBTYPE_ADPCM;
21 use winapi::shared::ksmedia::KSDATAFORMAT_SUBTYPE_ALAW;
22 use winapi::shared::ksmedia::KSDATAFORMAT_SUBTYPE_ANALOG;
23 use winapi::shared::ksmedia::KSDATAFORMAT_SUBTYPE_DRM;
24 use winapi::shared::ksmedia::KSDATAFORMAT_SUBTYPE_IEEE_FLOAT;
25 use winapi::shared::ksmedia::KSDATAFORMAT_SUBTYPE_MPEG;
26 use winapi::shared::ksmedia::KSDATAFORMAT_SUBTYPE_MULAW;
27 use winapi::shared::ksmedia::KSDATAFORMAT_SUBTYPE_PCM;
28 use winapi::shared::mmreg::SPEAKER_FRONT_CENTER;
29 use winapi::shared::mmreg::SPEAKER_FRONT_LEFT;
30 use winapi::shared::mmreg::SPEAKER_FRONT_RIGHT;
31 use winapi::shared::mmreg::WAVEFORMATEX;
32 use winapi::shared::mmreg::WAVEFORMATEXTENSIBLE;
33 use winapi::shared::mmreg::WAVE_FORMAT_EXTENSIBLE;
34 use winapi::shared::mmreg::WAVE_FORMAT_IEEE_FLOAT;
35 use winapi::shared::winerror::S_FALSE;
36 use winapi::shared::winerror::S_OK;
37 use winapi::um::audioclient::IAudioClient;
38 use winapi::um::audiosessiontypes::AUDCLNT_SHAREMODE_SHARED;
39 #[cfg(not(test))]
40 use winapi::um::combaseapi::CoTaskMemFree;
41 use wio::com::ComPtr;
42 
43 use crate::AudioSharedFormat;
44 use crate::WinAudioError;
45 use crate::MONO_CHANNEL_COUNT;
46 use crate::STEREO_CHANNEL_COUNT;
47 
48 /// Wrapper around `WAVEFORMATEX` and `WAVEFORMATEXTENSIBLE` to hide some of the unsafe calls
49 /// that could be made.
50 pub enum WaveAudioFormat {
51     /// Format where channels are capped at 2.
52     WaveFormat(WAVEFORMATEX),
53     /// Format where channels can be >2. (It can still have <=2 channels)
54     WaveFormatExtensible(WAVEFORMATEXTENSIBLE),
55 }
56 
/// Outcome of negotiating an audio format with the audio client, reported through metrics.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub(crate) enum AudioFormatEventType {
    /// The mix format from `GetMixFormat` was used unmodified and accepted.
    RequestOk,
    /// The mix format had to be modified before it was accepted.
    ModifiedOk,
    /// `IsFormatSupported` did not return `S_OK` for the format.
    Failed,
}
62 
63 impl WaveAudioFormat {
64     /// Wraps a WAVEFORMATEX pointer to make it's use more safe.
65     ///
66     /// # Safety
67     /// Unsafe if `wave_format_ptr` is pointing to null. This function will assume it's not null
68     /// and dereference it.
69     /// Also `format_ptr` will be deallocated after this function completes, so it cannot be used.
70     #[allow(clippy::let_and_return)]
new(format_ptr: *mut WAVEFORMATEX) -> Self71     pub unsafe fn new(format_ptr: *mut WAVEFORMATEX) -> Self {
72         let format_tag = { (*format_ptr).wFormatTag };
73         let result = if format_tag != WAVE_FORMAT_EXTENSIBLE {
74             warn!(
75                 "Default Mix Format does not have format_tag WAVE_FORMAT_EXTENSIBLE. It is: {}",
76                 format_tag
77             );
78             WaveAudioFormat::WaveFormat(*format_ptr)
79         } else {
80             WaveAudioFormat::WaveFormatExtensible(*(format_ptr as *const WAVEFORMATEXTENSIBLE))
81         };
82 
83         // WAVEFORMATEX and WAVEFORMATEXTENSIBLE both implement the Copy trait, so they have been
84         // copied to the WaveAudioFormat enum. Therefore, it is safe to free the memory
85         // `format_ptr` is pointing to.
86         // In a test, WAVEFORMATEX is initiated by us, not by Windows, so calling this function
87         // could cause a STATUS_HEAP_CORRUPTION exception.
88         #[cfg(not(test))]
89         CoTaskMemFree(format_ptr as *mut std::ffi::c_void);
90 
91         result
92     }
93 
get_num_channels(&self) -> u1694     pub fn get_num_channels(&self) -> u16 {
95         match self {
96             WaveAudioFormat::WaveFormat(wave_format) => wave_format.nChannels,
97             WaveAudioFormat::WaveFormatExtensible(wave_format_extensible) => {
98                 wave_format_extensible.Format.nChannels
99             }
100         }
101     }
102 
103     // Modifies `WAVEFORMATEXTENSIBLE` to have the values passed into the function params.
104     // Currently it should only modify the bit_depth if it's != 32 and the data format if it's not
105     // float.
modify_mix_format(&mut self, target_bit_depth: usize, ks_data_format: GUID)106     pub fn modify_mix_format(&mut self, target_bit_depth: usize, ks_data_format: GUID) {
107         let default_num_channels = self.get_num_channels();
108 
109         fn calc_avg_bytes_per_sec(num_channels: u16, bit_depth: u16, samples_per_sec: u32) -> u32 {
110             num_channels as u32 * (bit_depth as u32 / 8) * samples_per_sec
111         }
112 
113         fn calc_block_align(num_channels: u16, bit_depth: u16) -> u16 {
114             (bit_depth / 8) * num_channels
115         }
116 
117         match self {
118             WaveAudioFormat::WaveFormat(wave_format) => {
119                 if default_num_channels > STEREO_CHANNEL_COUNT {
120                     warn!("WAVEFORMATEX shouldn't have >2 channels.");
121                 }
122 
123                 // Force the format to be the only supported format (32 bit float)
124                 if wave_format.wBitsPerSample != target_bit_depth as u16
125                     || wave_format.wFormatTag != WAVE_FORMAT_IEEE_FLOAT
126                 {
127                     wave_format.wFormatTag = WAVE_FORMAT_IEEE_FLOAT;
128                     wave_format.nChannels =
129                         std::cmp::min(STEREO_CHANNEL_COUNT, default_num_channels);
130                     wave_format.wBitsPerSample = target_bit_depth as u16;
131                     wave_format.nAvgBytesPerSec = calc_avg_bytes_per_sec(
132                         wave_format.nChannels,
133                         wave_format.wBitsPerSample,
134                         wave_format.nSamplesPerSec,
135                     );
136                     wave_format.nBlockAlign =
137                         calc_block_align(wave_format.nChannels, wave_format.wBitsPerSample);
138                 }
139             }
140             WaveAudioFormat::WaveFormatExtensible(wave_format_extensible) => {
141                 // WAVE_FORMAT_EXTENSIBLE uses the #[repr(packed)] flag so the compiler might
142                 // unalign the fields. Thus, the fields will be copied to a local variable to
143                 // prevent segfaults. For more information:
144                 // https://github.com/rust-lang/rust/issues/46043
145                 let sub_format = wave_format_extensible.SubFormat;
146 
147                 if wave_format_extensible.Format.wBitsPerSample != target_bit_depth as u16
148                     || !IsEqualGUID(&sub_format, &ks_data_format)
149                 {
150                     // wFormatTag won't be changed
151                     wave_format_extensible.Format.nChannels = default_num_channels;
152                     wave_format_extensible.Format.wBitsPerSample = target_bit_depth as u16;
153                     // nSamplesPerSec should stay the same
154                     // Calculated with a bit depth of 32bits
155                     wave_format_extensible.Format.nAvgBytesPerSec = calc_avg_bytes_per_sec(
156                         wave_format_extensible.Format.nChannels,
157                         wave_format_extensible.Format.wBitsPerSample,
158                         wave_format_extensible.Format.nSamplesPerSec,
159                     );
160                     wave_format_extensible.Format.nBlockAlign = calc_block_align(
161                         wave_format_extensible.Format.nChannels,
162                         wave_format_extensible.Format.wBitsPerSample,
163                     );
164                     // 22 is the size typically used when the format tag is WAVE_FORMAT_EXTENSIBLE.
165                     // Since the `Initialize` syscall takes in a WAVEFORMATEX, this tells Windows
166                     // how many bytes are left after the `Format` field
167                     // (ie. Samples, dwChannelMask, SubFormat) so that it can cast to
168                     // WAVEFORMATEXTENSIBLE safely.
169                     wave_format_extensible.Format.cbSize = 22;
170                     wave_format_extensible.Samples = target_bit_depth as u16;
171                     let n_channels = wave_format_extensible.Format.nChannels;
172                     // The channel masks are defined here:
173                     // https://docs.microsoft.com/en-us/windows/win32/api/mmreg/ns-mmreg-waveformatextensible#remarks
174                     wave_format_extensible.dwChannelMask = match n_channels {
175                         STEREO_CHANNEL_COUNT => SPEAKER_FRONT_LEFT | SPEAKER_FRONT_RIGHT,
176                         MONO_CHANNEL_COUNT => SPEAKER_FRONT_CENTER,
177                         _ => {
178                             // Don't change channel mask if it's >2 channels.
179                             wave_format_extensible.dwChannelMask
180                         }
181                     };
182                     wave_format_extensible.SubFormat = ks_data_format;
183                 }
184             }
185         }
186     }
187 
as_ptr(&self) -> *const WAVEFORMATEX188     pub fn as_ptr(&self) -> *const WAVEFORMATEX {
189         match self {
190             WaveAudioFormat::WaveFormat(wave_format) => wave_format as *const WAVEFORMATEX,
191             WaveAudioFormat::WaveFormatExtensible(wave_format_extensible) => {
192                 wave_format_extensible as *const _ as *const WAVEFORMATEX
193             }
194         }
195     }
196 
get_shared_audio_engine_period_in_frames( &self, shared_default_size_in_100nanoseconds: f64, ) -> usize197     pub fn get_shared_audio_engine_period_in_frames(
198         &self,
199         shared_default_size_in_100nanoseconds: f64,
200     ) -> usize {
201         // a 100 nanosecond unit is 1 * 10^-7 seconds
202         //
203         // To convert a 100nanoseconds value to # of frames in a period, we multiple by the
204         // frame rate (nSamplesPerSec. Sample rate == Frame rate) and then divide by 10000000
205         // in order to convert 100nanoseconds to seconds.
206         let samples_per_sec = match self {
207             WaveAudioFormat::WaveFormat(wave_format) => wave_format.nSamplesPerSec,
208             WaveAudioFormat::WaveFormatExtensible(wave_format_extensible) => {
209                 wave_format_extensible.Format.nSamplesPerSec
210             }
211         };
212 
213         ((samples_per_sec as f64 * shared_default_size_in_100nanoseconds) / 10000000.0).ceil()
214             as usize
215     }
216 
create_audio_shared_format( &self, shared_audio_engine_period_in_frames: usize, ) -> AudioSharedFormat217     pub fn create_audio_shared_format(
218         &self,
219         shared_audio_engine_period_in_frames: usize,
220     ) -> AudioSharedFormat {
221         match self {
222             WaveAudioFormat::WaveFormat(wave_format) => AudioSharedFormat {
223                 bit_depth: wave_format.wBitsPerSample as usize,
224                 frame_rate: wave_format.nSamplesPerSec as usize,
225                 shared_audio_engine_period_in_frames,
226                 channels: wave_format.nChannels as usize,
227                 channel_mask: None,
228             },
229             WaveAudioFormat::WaveFormatExtensible(wave_format_extensible) => AudioSharedFormat {
230                 bit_depth: wave_format_extensible.Format.wBitsPerSample as usize,
231                 frame_rate: wave_format_extensible.Format.nSamplesPerSec as usize,
232                 shared_audio_engine_period_in_frames,
233                 channels: wave_format_extensible.Format.nChannels as usize,
234                 channel_mask: Some(wave_format_extensible.dwChannelMask),
235             },
236         }
237     }
238 
239     #[cfg(test)]
take_waveformatex(self) -> WAVEFORMATEX240     fn take_waveformatex(self) -> WAVEFORMATEX {
241         match self {
242             WaveAudioFormat::WaveFormat(wave_format) => wave_format,
243             WaveAudioFormat::WaveFormatExtensible(wave_format_extensible) => {
244                 // SAFETY: `wave_format_extensible` can't be a null pointer, otherwise the
245                 // constructor would've failed. This will also give the caller ownership of this
246                 // struct.
247                 unsafe { *(&wave_format_extensible as *const _ as *const WAVEFORMATEX) }
248             }
249         }
250     }
251 
252     #[cfg(test)]
take_waveformatextensible(self) -> WAVEFORMATEXTENSIBLE253     fn take_waveformatextensible(self) -> WAVEFORMATEXTENSIBLE {
254         match self {
255             WaveAudioFormat::WaveFormat(_wave_format) => {
256                 panic!("Format is WAVEFORMATEX. Can't convert to WAVEFORMATEXTENSBILE.")
257             }
258             WaveAudioFormat::WaveFormatExtensible(wave_format_extensible) => wave_format_extensible,
259         }
260     }
261 }
262 
263 impl Debug for WaveAudioFormat {
fmt(&self, f: &mut Formatter) -> fmt::Result264     fn fmt(&self, f: &mut Formatter) -> fmt::Result {
265         let res = match self {
266             WaveAudioFormat::WaveFormat(wave_format) => {
267                 format!(
268                     "wFormatTag: {}, \nnChannels: {}, \nnSamplesPerSec: {}, \nnAvgBytesPerSec: \
269                     {}, \nnBlockAlign: {}, \nwBitsPerSample: {}, \ncbSize: {}",
270                     { wave_format.wFormatTag },
271                     { wave_format.nChannels },
272                     { wave_format.nSamplesPerSec },
273                     { wave_format.nAvgBytesPerSec },
274                     { wave_format.nBlockAlign },
275                     { wave_format.wBitsPerSample },
276                     { wave_format.cbSize },
277                 )
278             }
279             WaveAudioFormat::WaveFormatExtensible(wave_format_extensible) => {
280                 let audio_engine_format = format!(
281                     "wFormatTag: {}, \nnChannels: {}, \nnSamplesPerSec: {}, \nnAvgBytesPerSec: \
282                     {}, \nnBlockAlign: {}, \nwBitsPerSample: {}, \ncbSize: {}",
283                     { wave_format_extensible.Format.wFormatTag },
284                     { wave_format_extensible.Format.nChannels },
285                     { wave_format_extensible.Format.nSamplesPerSec },
286                     { wave_format_extensible.Format.nAvgBytesPerSec },
287                     { wave_format_extensible.Format.nBlockAlign },
288                     { wave_format_extensible.Format.wBitsPerSample },
289                     { wave_format_extensible.Format.cbSize },
290                 );
291 
292                 let subformat = wave_format_extensible.SubFormat;
293 
294                 // TODO(b/240186720): Passing in `KSDATAFORMAT_SUBTYPE_PCM` will cause a
295                 // freeze. IsEqualGUID is unsafe even though it isn't marked as such. Look into
296                 // fixing or possibily write our own, that works.
297                 //
298                 // This check would be a nice to have, but not necessary. Right now, the subformat
299                 // used will always be `IEEE_FLOAT`, so this check will be useless if nothing
300                 // changes.
301                 //
302                 // if !IsEqualGUID(
303                 //     &{ wave_format_extensible.SubFormat },
304                 //     &KSDATAFORMAT_SUBTYPE_IEEE_FLOAT,
305                 // ) {
306                 //     warn!("Audio Engine format is NOT IEEE FLOAT");
307                 // }
308 
309                 let audio_engine_extensible_format = format!(
310                     "\nSamples: {}, \ndwChannelMask: {}, \nSubFormat: {}-{}-{}-{:?}",
311                     { wave_format_extensible.Samples },
312                     { wave_format_extensible.dwChannelMask },
313                     subformat.Data1,
314                     subformat.Data2,
315                     subformat.Data3,
316                     subformat.Data4,
317                 );
318 
319                 format!("{}{}", audio_engine_format, audio_engine_extensible_format)
320             }
321         };
322         write!(f, "{}", res)
323     }
324 }
325 
#[cfg(test)]
impl PartialEq for WaveAudioFormat {
    /// Two formats are equal when they are the same variant, have the same `cbSize`, and the
    /// `WAVEFORMATEX` header plus the `cbSize` bytes that follow it are byte-for-byte equal.
    ///
    /// The compared length is clamped to the size of the struct actually stored in the
    /// variant. For example, a `WaveFormat` holds only the header even if its `cbSize` claims
    /// extra bytes, and reading past the struct would be undefined behavior.
    fn eq(&self, other: &Self) -> bool {
        /// Returns the leading bytes of `value` covered by the header plus `cb_size`, never more
        /// than `size_of::<T>()`.
        ///
        /// Only used with the `#[repr(packed)]` winapi format structs, which have no padding, so
        /// every byte viewed is initialized.
        fn format_bytes<T>(value: &T, cb_size: u16) -> &[u8] {
            let len = std::cmp::min(
                std::mem::size_of::<WAVEFORMATEX>() + usize::from(cb_size),
                std::mem::size_of::<T>(),
            );
            // SAFETY: `value` is a valid reference, so its first `size_of::<T>()` bytes are
            // readable for the lifetime of the returned slice, and `len` never exceeds that.
            unsafe { std::slice::from_raw_parts(value as *const T as *const u8, len) }
        }

        match (self, other) {
            (WaveAudioFormat::WaveFormat(lhs), WaveAudioFormat::WaveFormat(rhs)) => {
                let cb_size = lhs.cbSize;
                cb_size == rhs.cbSize && format_bytes(lhs, cb_size) == format_bytes(rhs, cb_size)
            }
            (
                WaveAudioFormat::WaveFormatExtensible(lhs),
                WaveAudioFormat::WaveFormatExtensible(rhs),
            ) => {
                let cb_size = lhs.Format.cbSize;
                cb_size == rhs.Format.cbSize
                    && format_bytes(lhs, cb_size) == format_bytes(rhs, cb_size)
            }
            // Different variants are never equal.
            _ => false,
        }
    }
}
392 
393 impl From<&WaveAudioFormat> for WaveFormatMetric {
from(format: &WaveAudioFormat) -> WaveFormatMetric394     fn from(format: &WaveAudioFormat) -> WaveFormatMetric {
395         match format {
396             WaveAudioFormat::WaveFormat(wave_format) => WaveFormatMetric {
397                 format_tag: wave_format.wFormatTag.into(),
398                 channels: wave_format.nChannels.into(),
399                 samples_per_sec: wave_format
400                     .nSamplesPerSec
401                     .try_into()
402                     .expect("Failed to cast nSamplesPerSec to i32"),
403                 avg_bytes_per_sec: wave_format
404                     .nAvgBytesPerSec
405                     .try_into()
406                     .expect("Failed to cast nAvgBytesPerSec"),
407                 block_align: wave_format.nBlockAlign.into(),
408                 bits_per_sample: wave_format.wBitsPerSample.into(),
409                 size_bytes: wave_format.cbSize.into(),
410                 samples: None,
411                 channel_mask: None,
412                 sub_format: None,
413             },
414             WaveAudioFormat::WaveFormatExtensible(wave_format_extensible) => {
415                 let sub_format = wave_format_extensible.SubFormat;
416                 WaveFormatMetric {
417                     format_tag: wave_format_extensible.Format.wFormatTag.into(),
418                     channels: wave_format_extensible.Format.nChannels.into(),
419                     samples_per_sec: wave_format_extensible
420                         .Format
421                         .nSamplesPerSec
422                         .try_into()
423                         .expect("Failed to cast nSamplesPerSec to i32"),
424                     avg_bytes_per_sec: wave_format_extensible
425                         .Format
426                         .nAvgBytesPerSec
427                         .try_into()
428                         .expect("Failed to cast nAvgBytesPerSec"),
429                     block_align: wave_format_extensible.Format.nBlockAlign.into(),
430                     bits_per_sample: wave_format_extensible.Format.wBitsPerSample.into(),
431                     size_bytes: wave_format_extensible.Format.cbSize.into(),
432                     samples: Some(wave_format_extensible.Samples.into()),
433                     channel_mask: Some(wave_format_extensible.dwChannelMask.into()),
434                     sub_format: Some(GuidWrapper(&sub_format).into()),
435                 }
436             }
437         }
438     }
439 }
440 
441 /// Get an audio format that will be accepted by the audio client. In terms of bit depth, the goal
442 /// is to always get a 32bit float format.
get_valid_mix_format( audio_client: &ComPtr<IAudioClient>, ) -> Result<WaveAudioFormat, WinAudioError>443 pub(crate) fn get_valid_mix_format(
444     audio_client: &ComPtr<IAudioClient>,
445 ) -> Result<WaveAudioFormat, WinAudioError> {
446     // SAFETY: `format_ptr` is owned by this unsafe block. `format_ptr` is guarenteed to
447     // be not null by the time it reached `WaveAudioFormat::new` (check_hresult! should make
448     // sure of that), which is also release the pointer passed in.
449     let mut format = unsafe {
450         let mut format_ptr: *mut WAVEFORMATEX = std::ptr::null_mut();
451         let hr = audio_client.GetMixFormat(&mut format_ptr);
452         check_hresult!(
453             hr,
454             WinAudioError::from(hr),
455             "Failed to retrieve audio engine's shared format"
456         )?;
457 
458         WaveAudioFormat::new(format_ptr)
459     };
460 
461     let mut wave_format_details = WaveFormatDetailsMetric::default();
462     let mut event_code = AudioFormatEventType::RequestOk;
463     wave_format_details.requested = Some(WaveFormatMetric::from(&format));
464 
465     info!("Printing mix format from `GetMixFormat`:\n{:?}", format);
466     const BIT_DEPTH: usize = 32;
467     format.modify_mix_format(BIT_DEPTH, KSDATAFORMAT_SUBTYPE_IEEE_FLOAT);
468 
469     let modified_wave_format = Some(WaveFormatMetric::from(&format));
470     if modified_wave_format != wave_format_details.requested {
471         wave_format_details.modified = modified_wave_format;
472         event_code = AudioFormatEventType::ModifiedOk;
473     }
474 
475     info!("Audio Engine Mix Format Used: \n{:?}", format);
476     check_format(audio_client, &format, wave_format_details, event_code)?;
477 
478     Ok(format)
479 }
480 
481 /// Checks to see if `format` is accepted by the audio client.
482 ///
483 /// Exposed as crate public for testing.
check_format( audio_client: &IAudioClient, format: &WaveAudioFormat, mut wave_format_details: WaveFormatDetailsMetric, event_code: AudioFormatEventType, ) -> Result<(), WinAudioError>484 pub(crate) fn check_format(
485     audio_client: &IAudioClient,
486     format: &WaveAudioFormat,
487     mut wave_format_details: WaveFormatDetailsMetric,
488     event_code: AudioFormatEventType,
489 ) -> Result<(), WinAudioError> {
490     let mut closest_match_format: *mut WAVEFORMATEX = std::ptr::null_mut();
491     // SAFETY: All values passed into `IsFormatSupport` is owned by us and we will
492     // guarentee they won't be dropped and are valid.
493     let hr = unsafe {
494         audio_client.IsFormatSupported(
495             AUDCLNT_SHAREMODE_SHARED,
496             format.as_ptr(),
497             &mut closest_match_format,
498         )
499     };
500 
501     // If the audio engine does not support the format.
502     if hr != S_OK {
503         if hr == S_FALSE {
504             // SAFETY: If the `hr` value is `S_FALSE`, then `IsFormatSupported` must've
505             // given us a closest match.
506             let closest_match_enum = unsafe { WaveAudioFormat::new(closest_match_format) };
507             wave_format_details.closest_matched = Some(WaveFormatMetric::from(&closest_match_enum));
508 
509             error!(
510                 "Current audio format not supported, the closest format is:\n{:?}",
511                 closest_match_enum
512             );
513         } else {
514             error!("IsFormatSupported failed with hr: {}", hr);
515         }
516 
517         // Get last error here just incase `upload_metrics` causes an error.
518         let last_error = Error::last();
519         // TODO:(b/253509368): Only upload for audio rendering, since these metrics can't
520         // differentiate between rendering and capture.
521         upload_metrics(wave_format_details, AudioFormatEventType::Failed);
522 
523         Err(WinAudioError::WindowsError(hr, last_error))
524     } else {
525         upload_metrics(wave_format_details, event_code);
526 
527         Ok(())
528     }
529 }
530 
upload_metrics(details: WaveFormatDetailsMetric, event_type: AudioFormatEventType)531 fn upload_metrics(details: WaveFormatDetailsMetric, event_type: AudioFormatEventType) {
532     let event = match event_type {
533         AudioFormatEventType::RequestOk => MetricEventType::AudioFormatRequestOk(details),
534         AudioFormatEventType::ModifiedOk => MetricEventType::AudioFormatModifiedOk(details),
535         AudioFormatEventType::Failed => MetricEventType::AudioFormatFailed(details),
536     };
537     metrics::log_event(event);
538 }
539 
540 struct GuidWrapper<'a>(&'a GUID);
541 
542 impl<'a> From<GuidWrapper<'a>> for WaveFormatSubFormatMetric {
from(guid: GuidWrapper) -> WaveFormatSubFormatMetric543     fn from(guid: GuidWrapper) -> WaveFormatSubFormatMetric {
544         let guid = guid.0;
545         if IsEqualGUID(guid, &KSDATAFORMAT_SUBTYPE_ANALOG) {
546             WaveFormatSubFormatMetric::Analog
547         } else if IsEqualGUID(guid, &KSDATAFORMAT_SUBTYPE_PCM) {
548             WaveFormatSubFormatMetric::Pcm
549         } else if IsEqualGUID(guid, &KSDATAFORMAT_SUBTYPE_IEEE_FLOAT) {
550             WaveFormatSubFormatMetric::IeeeFloat
551         } else if IsEqualGUID(guid, &KSDATAFORMAT_SUBTYPE_DRM) {
552             WaveFormatSubFormatMetric::Drm
553         } else if IsEqualGUID(guid, &KSDATAFORMAT_SUBTYPE_ALAW) {
554             WaveFormatSubFormatMetric::ALaw
555         } else if IsEqualGUID(guid, &KSDATAFORMAT_SUBTYPE_MULAW) {
556             WaveFormatSubFormatMetric::MuLaw
557         } else if IsEqualGUID(guid, &KSDATAFORMAT_SUBTYPE_ADPCM) {
558             WaveFormatSubFormatMetric::Adpcm
559         } else if IsEqualGUID(guid, &KSDATAFORMAT_SUBTYPE_MPEG) {
560             WaveFormatSubFormatMetric::Mpeg
561         } else {
562             WaveFormatSubFormatMetric::Invalid
563         }
564     }
565 }
566 
567 #[cfg(test)]
568 mod tests {
569     use winapi::shared::ksmedia::KSDATAFORMAT_SUBTYPE_PCM;
570     use winapi::shared::mmreg::SPEAKER_BACK_LEFT;
571     use winapi::shared::mmreg::SPEAKER_BACK_RIGHT;
572     use winapi::shared::mmreg::SPEAKER_LOW_FREQUENCY;
573     use winapi::shared::mmreg::SPEAKER_SIDE_LEFT;
574     use winapi::shared::mmreg::SPEAKER_SIDE_RIGHT;
575     use winapi::shared::mmreg::WAVE_FORMAT_PCM;
576 
577     use super::*;
#[test]
fn test_modify_mix_format() {
    // A typical 7.1 surround sound channel mask.
    const CHANNEL_MASK_7_1: u32 = SPEAKER_FRONT_LEFT
        | SPEAKER_FRONT_RIGHT
        | SPEAKER_FRONT_CENTER
        | SPEAKER_LOW_FREQUENCY
        | SPEAKER_BACK_LEFT
        | SPEAKER_BACK_RIGHT
        | SPEAKER_SIDE_LEFT
        | SPEAKER_SIDE_RIGHT;

    // 8-channel, 32-bit PCM at 44.1 kHz: only the sub-format differs from the target.
    let surround_sound_format = WAVEFORMATEXTENSIBLE {
        Format: WAVEFORMATEX {
            wFormatTag: WAVE_FORMAT_EXTENSIBLE,
            nChannels: 8,
            nSamplesPerSec: 44100,
            nAvgBytesPerSec: 1411200,
            nBlockAlign: 32,
            wBitsPerSample: 32,
            cbSize: 22,
        },
        Samples: 32,
        dwChannelMask: CHANNEL_MASK_7_1,
        SubFormat: KSDATAFORMAT_SUBTYPE_PCM,
    };

    // SAFETY: the pointer refers to the live, fully-initialized `WAVEFORMATEXTENSIBLE` above,
    // tagged `WAVE_FORMAT_EXTENSIBLE`, which is exactly the layout `new` reads for that tag.
    // `new` only reads through it, and it does not free it under `cfg(test)`.
    let mut format = unsafe {
        WaveAudioFormat::new((&surround_sound_format) as *const _ as *mut WAVEFORMATEX)
    };

    format.modify_mix_format(
        /* bit_depth= */ 32,
        /* ks_data_format= */ KSDATAFORMAT_SUBTYPE_IEEE_FLOAT,
    );

    // The format was built from a `WAVEFORMATEXTENSIBLE`, so this does not panic.
    let surround_sound_format = format.take_waveformatextensible();

    // WAVE_FORMAT_EXTENSIBLE uses the #[repr(packed)] flag so the compiler might unalign
    // the fields. Thus, the fields will be copied to a local variable to prevent segfaults.
    // For more information: https://github.com/rust-lang/rust/issues/46043
    let format_tag = surround_sound_format.Format.wFormatTag;
    // We expect `SubFormat` to be IEEE float instead of PCM.
    // Everything else should remain the same.
    assert_eq!(format_tag, WAVE_FORMAT_EXTENSIBLE);
    let channels = surround_sound_format.Format.nChannels;
    assert_eq!(channels, 8);
    let samples_per_sec = surround_sound_format.Format.nSamplesPerSec;
    assert_eq!(samples_per_sec, 44100);
    let avg_bytes_per_sec = surround_sound_format.Format.nAvgBytesPerSec;
    assert_eq!(avg_bytes_per_sec, 1411200);
    let block_align = surround_sound_format.Format.nBlockAlign;
    assert_eq!(block_align, 32);
    let bits_per_samples = surround_sound_format.Format.wBitsPerSample;
    assert_eq!(bits_per_samples, 32);
    let size = surround_sound_format.Format.cbSize;
    assert_eq!(size, 22);
    let samples = surround_sound_format.Samples;
    assert_eq!(samples, 32);
    let channel_mask = surround_sound_format.dwChannelMask;
    assert_eq!(channel_mask, CHANNEL_MASK_7_1);
    let sub_format = surround_sound_format.SubFormat;
    assert!(IsEqualGUID(&sub_format, &KSDATAFORMAT_SUBTYPE_IEEE_FLOAT));
}
647 
648     #[test]
test_waveformatex_ieee_modify_same_format()649     fn test_waveformatex_ieee_modify_same_format() {
650         let format = WAVEFORMATEX {
651             wFormatTag: WAVE_FORMAT_IEEE_FLOAT,
652             nChannels: 2,
653             nSamplesPerSec: 48000,
654             nAvgBytesPerSec: 384000,
655             nBlockAlign: 8,
656             wBitsPerSample: 32,
657             cbSize: 0,
658         };
659 
660         let mut format =
661             // SAFETY: We can convert a struct to a pointer declared above. Also that means the
662             // pointer can be safely deferenced.
663             unsafe { WaveAudioFormat::new((&format) as *const WAVEFORMATEX as *mut WAVEFORMATEX) };
664 
665         format.modify_mix_format(
666             /* bit_depth= */ 32,
667             /* ks_data_format= */ KSDATAFORMAT_SUBTYPE_IEEE_FLOAT,
668         );
669 
670         let result_format = format.take_waveformatex();
671 
672         assert_waveformatex_ieee(&result_format);
673     }
674 
675     #[test]
test_waveformatex_ieee_modify_different_format()676     fn test_waveformatex_ieee_modify_different_format() {
677         // I don't expect this format to show up ever, but it's possible so it's good to test.
678         let format = WAVEFORMATEX {
679             wFormatTag: WAVE_FORMAT_IEEE_FLOAT,
680             nChannels: 2,
681             nSamplesPerSec: 48000,
682             nAvgBytesPerSec: 192000,
683             nBlockAlign: 4,
684             wBitsPerSample: 16,
685             cbSize: 0,
686         };
687 
688         let mut format =
689             // SAFETY: We can convert a struct to a pointer declared above. Also that means the
690             // pointer can be safely deferenced.
691             unsafe { WaveAudioFormat::new((&format) as *const WAVEFORMATEX as *mut WAVEFORMATEX) };
692 
693         format.modify_mix_format(
694             /* bit_depth= */ 32,
695             /* ks_data_format= */ KSDATAFORMAT_SUBTYPE_IEEE_FLOAT,
696         );
697 
698         let result_format = format.take_waveformatex();
699 
700         assert_waveformatex_ieee(&result_format);
701     }
702 
703     #[test]
test_format_comparison_waveformatex_pass()704     fn test_format_comparison_waveformatex_pass() {
705         let format = WAVEFORMATEX {
706             wFormatTag: WAVE_FORMAT_PCM,
707             nChannels: 1,
708             nSamplesPerSec: 48000,
709             nAvgBytesPerSec: 4 * 48000,
710             nBlockAlign: 4,
711             wBitsPerSample: 16,
712             cbSize: 0,
713         };
714 
715         let format =
716             // SAFETY: We can convert a struct to a pointer declared above. Also that means the
717             // pointer can be safely deferenced.
718             unsafe { WaveAudioFormat::new((&format) as *const WAVEFORMATEX as *mut WAVEFORMATEX) };
719 
720         let expected = WAVEFORMATEX {
721             wFormatTag: WAVE_FORMAT_PCM,
722             nChannels: 1,
723             nSamplesPerSec: 48000,
724             nAvgBytesPerSec: 4 * 48000,
725             nBlockAlign: 4,
726             wBitsPerSample: 16,
727             cbSize: 0,
728         };
729 
730         // SAFETY: We can convert a struct to a pointer declared above. Also that means the
731         // pointer can be safely deferenced.
732         let expected = unsafe {
733             WaveAudioFormat::new((&expected) as *const WAVEFORMATEX as *mut WAVEFORMATEX)
734         };
735 
736         assert_eq!(expected, format);
737     }
738 
739     #[test]
test_format_comparison_waveformatextensible_pass()740     fn test_format_comparison_waveformatextensible_pass() {
741         let format = WAVEFORMATEXTENSIBLE {
742             Format: WAVEFORMATEX {
743                 wFormatTag: WAVE_FORMAT_EXTENSIBLE,
744                 nChannels: 1,
745                 nSamplesPerSec: 48000,
746                 nAvgBytesPerSec: 4 * 48000,
747                 nBlockAlign: 4,
748                 wBitsPerSample: 16,
749                 cbSize: 22,
750             },
751             Samples: 16,
752             dwChannelMask: SPEAKER_FRONT_LEFT | SPEAKER_FRONT_RIGHT,
753             SubFormat: KSDATAFORMAT_SUBTYPE_IEEE_FLOAT,
754         };
755 
756         // SAFETY: We can convert a struct to a pointer declared above. Also that means the
757         // pointer can be safely deferenced.
758         let format = unsafe {
759             WaveAudioFormat::new((&format) as *const WAVEFORMATEXTENSIBLE as *mut WAVEFORMATEX)
760         };
761 
762         let expected = WAVEFORMATEXTENSIBLE {
763             Format: WAVEFORMATEX {
764                 wFormatTag: WAVE_FORMAT_EXTENSIBLE,
765                 nChannels: 1,
766                 nSamplesPerSec: 48000,
767                 nAvgBytesPerSec: 4 * 48000,
768                 nBlockAlign: 4,
769                 wBitsPerSample: 16,
770                 cbSize: 22,
771             },
772             Samples: 16,
773             dwChannelMask: SPEAKER_FRONT_LEFT | SPEAKER_FRONT_RIGHT,
774             SubFormat: KSDATAFORMAT_SUBTYPE_IEEE_FLOAT,
775         };
776 
777         // SAFETY: We can convert a struct to a pointer declared above. Also that means the
778         // pointer can be safely deferenced.
779         let expected = unsafe {
780             WaveAudioFormat::new((&expected) as *const WAVEFORMATEXTENSIBLE as *mut WAVEFORMATEX)
781         };
782 
783         assert_eq!(expected, format);
784     }
785 
786     #[test]
test_format_comparison_waveformatex_fail()787     fn test_format_comparison_waveformatex_fail() {
788         let format = WAVEFORMATEX {
789             wFormatTag: WAVE_FORMAT_PCM,
790             nChannels: 1,
791             nSamplesPerSec: 48000,
792             nAvgBytesPerSec: 4 * 48000,
793             nBlockAlign: 4,
794             wBitsPerSample: 16,
795             cbSize: 0,
796         };
797 
798         let format =
799             // SAFETY: We can convert a struct to a pointer declared above. Also that means the
800             // pointer can be safely deferenced.
801             unsafe { WaveAudioFormat::new((&format) as *const WAVEFORMATEX as *mut WAVEFORMATEX) };
802 
803         let expected = WAVEFORMATEX {
804             wFormatTag: WAVE_FORMAT_PCM,
805             // The field below is the difference
806             nChannels: 6,
807             nSamplesPerSec: 48000,
808             nAvgBytesPerSec: 4 * 48000,
809             nBlockAlign: 4,
810             wBitsPerSample: 16,
811             cbSize: 0,
812         };
813 
814         // SAFETY: We can convert a struct to a pointer declared above. Also that means the
815         // pointer can be safely deferenced.
816         let expected = unsafe {
817             WaveAudioFormat::new((&expected) as *const WAVEFORMATEX as *mut WAVEFORMATEX)
818         };
819 
820         assert_ne!(expected, format);
821     }
822 
823     #[test]
test_format_comparison_waveformatextensible_fail()824     fn test_format_comparison_waveformatextensible_fail() {
825         let format = WAVEFORMATEXTENSIBLE {
826             Format: WAVEFORMATEX {
827                 wFormatTag: WAVE_FORMAT_EXTENSIBLE,
828                 nChannels: 1,
829                 nSamplesPerSec: 48000,
830                 nAvgBytesPerSec: 4 * 48000,
831                 nBlockAlign: 4,
832                 wBitsPerSample: 16,
833                 cbSize: 22,
834             },
835             Samples: 16,
836             dwChannelMask: SPEAKER_FRONT_LEFT | SPEAKER_FRONT_RIGHT,
837             SubFormat: KSDATAFORMAT_SUBTYPE_IEEE_FLOAT,
838         };
839 
840         // SAFETY: We can convert a struct to a pointer declared above. Also that means the
841         // pointer can be safely deferenced.
842         let format = unsafe {
843             WaveAudioFormat::new((&format) as *const WAVEFORMATEXTENSIBLE as *mut WAVEFORMATEX)
844         };
845 
846         let expected = WAVEFORMATEXTENSIBLE {
847             Format: WAVEFORMATEX {
848                 wFormatTag: WAVE_FORMAT_EXTENSIBLE,
849                 nChannels: 1,
850                 nSamplesPerSec: 48000,
851                 nAvgBytesPerSec: 4 * 48000,
852                 nBlockAlign: 4,
853                 wBitsPerSample: 16,
854                 cbSize: 22,
855             },
856             Samples: 16,
857             // The field below is the difference.
858             dwChannelMask: SPEAKER_FRONT_CENTER | SPEAKER_BACK_LEFT,
859             SubFormat: KSDATAFORMAT_SUBTYPE_IEEE_FLOAT,
860         };
861 
862         // SAFETY: We can convert a struct to a pointer declared above. Also that means the
863         // pointer can be safely deferenced.
864         let expected = unsafe {
865             WaveAudioFormat::new((&expected) as *const WAVEFORMATEXTENSIBLE as *mut WAVEFORMATEX)
866         };
867 
868         assert_ne!(expected, format);
869     }
870 
871     #[test]
test_modify_mix_mono_channel_different_bit_depth_wave_format_extensible()872     fn test_modify_mix_mono_channel_different_bit_depth_wave_format_extensible() {
873         // Start with a mono channel and 16 bit depth format.
874         let format = WAVEFORMATEXTENSIBLE {
875             Format: WAVEFORMATEX {
876                 wFormatTag: WAVE_FORMAT_EXTENSIBLE,
877                 nChannels: 1,
878                 nSamplesPerSec: 48000,
879                 nAvgBytesPerSec: 2 * 48000,
880                 nBlockAlign: 2,
881                 wBitsPerSample: 16,
882                 cbSize: 22,
883             },
884             Samples: 16,
885             // Probably will never see a mask like this for two channels, but this is just testing
886             // that it will get changed.
887             dwChannelMask: SPEAKER_FRONT_LEFT | SPEAKER_FRONT_RIGHT,
888             SubFormat: KSDATAFORMAT_SUBTYPE_IEEE_FLOAT,
889         };
890 
891         // SAFETY: We can convert a struct to a pointer declared above. Also that means the
892         // pointer can be safely deferenced.
893         let mut format = unsafe {
894             WaveAudioFormat::new((&format) as *const WAVEFORMATEXTENSIBLE as *mut WAVEFORMATEX)
895         };
896 
897         format.modify_mix_format(
898             /* bit_depth= */ 32,
899             /* ks_data_format= */ KSDATAFORMAT_SUBTYPE_IEEE_FLOAT,
900         );
901 
902         // The format should be converted to 32 bit depth and retain mono channel.
903         let expected = WAVEFORMATEXTENSIBLE {
904             Format: WAVEFORMATEX {
905                 wFormatTag: WAVE_FORMAT_EXTENSIBLE,
906                 nChannels: 1,
907                 nSamplesPerSec: 48000,
908                 nAvgBytesPerSec: 4 * 48000, // Changed
909                 nBlockAlign: 4,             // Changed
910                 wBitsPerSample: 32,         // Changed
911                 cbSize: 22,
912             },
913             Samples: 32,
914             dwChannelMask: SPEAKER_FRONT_CENTER, // Changed
915             SubFormat: KSDATAFORMAT_SUBTYPE_IEEE_FLOAT,
916         };
917 
918         // SAFETY: We can convert a struct to a pointer declared above. Also that means the
919         // pointer can be safely deferenced.
920         let expected = unsafe {
921             WaveAudioFormat::new((&expected) as *const WAVEFORMATEXTENSIBLE as *mut WAVEFORMATEX)
922         };
923 
924         assert_eq!(format, expected);
925     }
926 
927     #[test]
test_modify_mix_mono_channel_different_bit_depth_wave_format()928     fn test_modify_mix_mono_channel_different_bit_depth_wave_format() {
929         // Start with a mono channel and 16 bit depth format.
930         let format = WAVEFORMATEX {
931             wFormatTag: WAVE_FORMAT_PCM,
932             nChannels: 1,
933             nSamplesPerSec: 48000,
934             nAvgBytesPerSec: 2 * 48000,
935             nBlockAlign: 2,
936             wBitsPerSample: 16,
937             cbSize: 0,
938         };
939 
940         let mut format =
941             // SAFETY: We can convert a struct to a pointer declared above. Also that means the
942             // pointer can be safely deferenced.
943             unsafe { WaveAudioFormat::new((&format) as *const WAVEFORMATEX as *mut WAVEFORMATEX) };
944 
945         format.modify_mix_format(
946             /* bit_depth= */ 32,
947             /* ks_data_format= */ KSDATAFORMAT_SUBTYPE_IEEE_FLOAT,
948         );
949 
950         // The format should be converted to 32 bit depth and retain mono channel.
951         let expected = WAVEFORMATEX {
952             wFormatTag: WAVE_FORMAT_IEEE_FLOAT, // Changed
953             nChannels: 1,
954             nSamplesPerSec: 48000,
955             nAvgBytesPerSec: 4 * 48000, // Changed
956             nBlockAlign: 4,             // Changed
957             wBitsPerSample: 32,         // Changed
958             cbSize: 0,
959         };
960 
961         // SAFETY: We can convert a struct to a pointer declared above. Also that means the
962         // pointer can be safely deferenced.
963         let expected = unsafe {
964             WaveAudioFormat::new((&expected) as *const WAVEFORMATEX as *mut WAVEFORMATEX)
965         };
966 
967         assert_eq!(format, expected);
968     }
969 
970     #[test]
test_waveformatex_non_ieee_modify_format()971     fn test_waveformatex_non_ieee_modify_format() {
972         let format = WAVEFORMATEX {
973             wFormatTag: WAVE_FORMAT_PCM,
974             nChannels: 2,
975             nSamplesPerSec: 48000,
976             nAvgBytesPerSec: 192000,
977             nBlockAlign: 4,
978             wBitsPerSample: 16,
979             cbSize: 0,
980         };
981 
982         let mut format =
983             // SAFETY: We can convert a struct to a pointer declared above. Also that means the
984             // pointer can be safely deferenced.
985             unsafe { WaveAudioFormat::new((&format) as *const WAVEFORMATEX as *mut WAVEFORMATEX) };
986 
987         format.modify_mix_format(
988             /* bit_depth= */ 32,
989             /* ks_data_format= */ KSDATAFORMAT_SUBTYPE_IEEE_FLOAT,
990         );
991 
992         let result_format = format.take_waveformatex();
993 
994         assert_waveformatex_ieee(&result_format);
995     }
996 
997     #[test]
test_waveformatex_non_ieee_32_bit_modify_format()998     fn test_waveformatex_non_ieee_32_bit_modify_format() {
999         let format = WAVEFORMATEX {
1000             wFormatTag: WAVE_FORMAT_PCM,
1001             nChannels: 2,
1002             nSamplesPerSec: 48000,
1003             nAvgBytesPerSec: 384000,
1004             nBlockAlign: 8,
1005             wBitsPerSample: 32,
1006             cbSize: 0,
1007         };
1008 
1009         let mut format =
1010             // SAFETY: We can convert a struct to a pointer declared above. Also that means the
1011             // pointer can be safely deferenced.
1012             unsafe { WaveAudioFormat::new((&format) as *const WAVEFORMATEX as *mut WAVEFORMATEX) };
1013 
1014         format.modify_mix_format(
1015             /* bit_depth= */ 32,
1016             /* ks_data_format= */ KSDATAFORMAT_SUBTYPE_IEEE_FLOAT,
1017         );
1018 
1019         let result_format = format.take_waveformatex();
1020 
1021         assert_waveformatex_ieee(&result_format);
1022     }
1023 
assert_waveformatex_ieee(result_format: &WAVEFORMATEX)1024     fn assert_waveformatex_ieee(result_format: &WAVEFORMATEX) {
1025         let format_tag = result_format.wFormatTag;
1026         assert_eq!(format_tag, WAVE_FORMAT_IEEE_FLOAT);
1027         let channels = result_format.nChannels;
1028         assert_eq!(channels, 2);
1029         let samples_per_sec = result_format.nSamplesPerSec;
1030         assert_eq!(samples_per_sec, 48000);
1031         let avg_bytes_per_sec = result_format.nAvgBytesPerSec;
1032         assert_eq!(avg_bytes_per_sec, 384000);
1033         let block_align = result_format.nBlockAlign;
1034         assert_eq!(block_align, 8);
1035         let bits_per_samples = result_format.wBitsPerSample;
1036         assert_eq!(bits_per_samples, 32);
1037         let size = result_format.cbSize;
1038         assert_eq!(size, 0);
1039     }
1040 
1041     #[test]
test_create_audio_shared_format_wave_format_ex()1042     fn test_create_audio_shared_format_wave_format_ex() {
1043         let wave_format = WAVEFORMATEX {
1044             wFormatTag: WAVE_FORMAT_PCM,
1045             nChannels: 2,
1046             nSamplesPerSec: 48000,
1047             nAvgBytesPerSec: 192000,
1048             nBlockAlign: 4,
1049             wBitsPerSample: 16,
1050             cbSize: 0,
1051         };
1052 
1053         // SAFETY: We can convert a struct to a pointer declared above. Also that means the
1054         // pointer can be safely deferenced.
1055         let format = unsafe {
1056             WaveAudioFormat::new((&wave_format) as *const WAVEFORMATEX as *mut WAVEFORMATEX)
1057         };
1058 
1059         // The period will most likely never be 123, but this is ok for testing.
1060         let audio_shared_format =
1061             format.create_audio_shared_format(/* shared_audio_engine_period_in_frames= */ 123);
1062 
1063         assert_eq!(
1064             audio_shared_format.bit_depth,
1065             wave_format.wBitsPerSample as usize
1066         );
1067         assert_eq!(audio_shared_format.channels, wave_format.nChannels as usize);
1068         assert_eq!(
1069             audio_shared_format.frame_rate,
1070             wave_format.nSamplesPerSec as usize
1071         );
1072         assert_eq!(
1073             audio_shared_format.shared_audio_engine_period_in_frames,
1074             123
1075         );
1076     }
1077 
1078     #[test]
test_create_audio_shared_format_wave_format_extensible()1079     fn test_create_audio_shared_format_wave_format_extensible() {
1080         let wave_format_extensible = WAVEFORMATEXTENSIBLE {
1081             Format: WAVEFORMATEX {
1082                 wFormatTag: WAVE_FORMAT_EXTENSIBLE,
1083                 nChannels: 2,
1084                 nSamplesPerSec: 48000,
1085                 nAvgBytesPerSec: 8 * 48000,
1086                 nBlockAlign: 8,
1087                 wBitsPerSample: 32,
1088                 cbSize: 22,
1089             },
1090             Samples: 32,
1091             dwChannelMask: SPEAKER_FRONT_LEFT | SPEAKER_FRONT_RIGHT,
1092             SubFormat: KSDATAFORMAT_SUBTYPE_IEEE_FLOAT,
1093         };
1094 
1095         // SAFETY: We can convert a struct to a pointer declared above. Also that means the
1096         // pointer can be safely deferenced.
1097         let format = unsafe {
1098             WaveAudioFormat::new((&wave_format_extensible) as *const _ as *mut WAVEFORMATEX)
1099         };
1100 
1101         // The period will most likely never be 123, but this is ok for testing.
1102         let audio_shared_format =
1103             format.create_audio_shared_format(/* shared_audio_engine_period_in_frames= */ 123);
1104 
1105         assert_eq!(
1106             audio_shared_format.bit_depth,
1107             wave_format_extensible.Format.wBitsPerSample as usize
1108         );
1109         assert_eq!(
1110             audio_shared_format.channels,
1111             wave_format_extensible.Format.nChannels as usize
1112         );
1113         assert_eq!(
1114             audio_shared_format.frame_rate,
1115             wave_format_extensible.Format.nSamplesPerSec as usize
1116         );
1117         assert_eq!(
1118             audio_shared_format.shared_audio_engine_period_in_frames,
1119             123
1120         );
1121         assert_eq!(
1122             audio_shared_format.channel_mask,
1123             Some(SPEAKER_FRONT_LEFT | SPEAKER_FRONT_RIGHT)
1124         );
1125     }
1126 
1127     #[test]
test_wave_format_to_proto_convertion()1128     fn test_wave_format_to_proto_convertion() {
1129         let wave_format = WAVEFORMATEX {
1130             wFormatTag: WAVE_FORMAT_PCM,
1131             nChannels: 2,
1132             nSamplesPerSec: 48000,
1133             nAvgBytesPerSec: 192000,
1134             nBlockAlign: 4,
1135             wBitsPerSample: 16,
1136             cbSize: 0,
1137         };
1138 
1139         let wave_audio_format =
1140             // SAFETY: We can convert a struct to a pointer declared above. Also that means the
1141             // pointer can be safely deferenced.
1142             unsafe { WaveAudioFormat::new((&wave_format) as *const _ as *mut WAVEFORMATEX) };
1143 
1144         // Testing the `into`.
1145         let wave_format_metric = WaveFormatMetric::from(&wave_audio_format);
1146 
1147         let expected = WaveFormatMetric {
1148             format_tag: WAVE_FORMAT_PCM.into(),
1149             channels: 2,
1150             samples_per_sec: 48000,
1151             avg_bytes_per_sec: 192000,
1152             block_align: 4,
1153             bits_per_sample: 16,
1154             size_bytes: 0,
1155             samples: None,
1156             channel_mask: None,
1157             sub_format: None,
1158         };
1159 
1160         assert_eq!(wave_format_metric, expected);
1161     }
1162 
1163     #[test]
test_wave_format_extensible_to_proto_convertion()1164     fn test_wave_format_extensible_to_proto_convertion() {
1165         let wave_format_extensible = WAVEFORMATEXTENSIBLE {
1166             Format: WAVEFORMATEX {
1167                 wFormatTag: WAVE_FORMAT_EXTENSIBLE,
1168                 nChannels: 2,
1169                 nSamplesPerSec: 48000,
1170                 nAvgBytesPerSec: 8 * 48000,
1171                 nBlockAlign: 8,
1172                 wBitsPerSample: 32,
1173                 cbSize: 22,
1174             },
1175             Samples: 32,
1176             dwChannelMask: SPEAKER_FRONT_LEFT | SPEAKER_FRONT_RIGHT,
1177             SubFormat: KSDATAFORMAT_SUBTYPE_IEEE_FLOAT,
1178         };
1179 
1180         // SAFETY: We can convert a struct to a pointer declared above. Also that means the pointer
1181         // can be safely deferenced.
1182         let wave_audio_format = unsafe {
1183             WaveAudioFormat::new((&wave_format_extensible) as *const _ as *mut WAVEFORMATEX)
1184         };
1185 
1186         // Testing the `into`.
1187         let wave_format_metric = WaveFormatMetric::from(&wave_audio_format);
1188 
1189         let expected = WaveFormatMetric {
1190             format_tag: WAVE_FORMAT_EXTENSIBLE.into(),
1191             channels: 2,
1192             samples_per_sec: 48000,
1193             avg_bytes_per_sec: 8 * 48000,
1194             block_align: 8,
1195             bits_per_sample: 32,
1196             size_bytes: 22,
1197             samples: Some(32),
1198             channel_mask: Some((SPEAKER_FRONT_LEFT | SPEAKER_FRONT_RIGHT) as i64),
1199             sub_format: Some(WaveFormatSubFormatMetric::IeeeFloat),
1200         };
1201 
1202         assert_eq!(wave_format_metric, expected);
1203     }
1204 }
1205