1#!/usr/bin/env python3
2#
3#   Copyright 2017 - The Android Open Source Project
4#
5#   Licensed under the Apache License, Version 2.0 (the "License");
6#   you may not use this file except in compliance with the License.
7#   You may obtain a copy of the License at
8#
9#       http://www.apache.org/licenses/LICENSE-2.0
10#
11#   Unless required by applicable law or agreed to in writing, software
12#   distributed under the License is distributed on an "AS IS" BASIS,
13#   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14#   See the License for the specific language governing permissions and
15#   limitations under the License.
16"""Audio Analysis tool to analyze wave file and detect artifacts."""
17
18import collections
19import json
20import logging
21import numpy
22import pprint
23import subprocess
24import tempfile
25import wave
26
27import acts_contrib.test_utils.audio_analysis_lib.audio_analysis as audio_analysis
28import acts_contrib.test_utils.audio_analysis_lib.audio_data as audio_data
29import acts_contrib.test_utils.audio_analysis_lib.audio_quality_measurement as \
30 audio_quality_measurement
31
# Bundle of the tunable parameters that are forwarded to
# audio_quality_measurement.quality_measurement().
QualityParams = collections.namedtuple(
    'QualityParams',
    (
        'block_size_secs',
        'frequency_error_threshold',
        'delay_amplitude_threshold',
        'noise_amplitude_threshold',
        'burst_amplitude_threshold',
    ),
)

# Defaults used by quality_analysis() for the matching QualityParams fields,
# listed in the same order as the fields above.
DEFAULT_QUALITY_BLOCK_SIZE_SECS = 0.0015
DEFAULT_FREQUENCY_ERROR_THRESHOLD = 0.5
DEFAULT_DELAY_AMPLITUDE_THRESHOLD = 0.6
DEFAULT_NOISE_AMPLITUDE_THRESHOLD = 0.5
DEFAULT_BURST_AMPLITUDE_THRESHOLD = 1.4
44
45
class WaveFileException(Exception):
    """Raised when a wave file is compressed or otherwise unsupported."""
48
49
class WaveFormatExtensibleException(Exception):
    """Raised for WAVE_FORMAT_EXTENSIBLE files, which cannot be read as-is."""
52
53
class WaveFile(object):
    """Class which handles wave file reading.

    The whole file is read during construction; the underlying reader is
    closed again before the constructor returns.

    Properties:
        raw_data: audio_data.AudioRawData object for data in wave file.
        rate: sampling rate.

    """

    def __init__(self, filename):
        """Inits a wave file.

        The file is first read directly. If it turns out to be in
        WAVE_FORMAT_EXTENSIBLE format, it is converted to WAVE_FORMAT_PCM
        with sox and the converted copy is read instead.

        Args:
            filename: file name of the wave file.

        Raises:
            WaveFileException: Wave file format is not supported.
            RuntimeError: conversion is needed but sox is not usable.
            subprocess.CalledProcessError: the sox conversion failed.

        """
        self.raw_data = None
        self.rate = None

        self._wave_reader = None
        self._n_channels = None
        self._sample_width_bits = None
        self._n_frames = None
        self._binary = None

        try:
            self._read_wave_file(filename)
        except WaveFormatExtensibleException:
            logging.warning(
                'WAVE_FORMAT_EXTENSIBLE is not supproted. '
                'Try command "sox in.wav -t wavpcm out.wav" to convert '
                'the file to WAVE_FORMAT_PCM format.')
            self._convert_and_read_wav_file(filename)

    def _convert_and_read_wav_file(self, filename):
        """Converts the wav file and read it.

        Converts the file into WAVE_FORMAT_PCM format using sox command and
        reads its content.

        Args:
            filename: The wave file to be read.

        Raises:
            RuntimeError: sox is not installed or cannot be executed.
            subprocess.CalledProcessError: the sox conversion command
                                           exited with a non-zero status.

        """
        # Checks if sox is installed. Only the failures that actually mean
        # "sox is unusable" are translated: OSError for a missing or
        # non-executable binary, CalledProcessError for a non-zero exit.
        # Anything else (e.g. KeyboardInterrupt) propagates untouched.
        try:
            subprocess.check_output(['sox', '--version'])
        except (OSError, subprocess.CalledProcessError) as error:
            raise RuntimeError('sox command is not installed. '
                               'Try sudo apt-get install sox') from error

        # The temporary file only provides a unique output path for sox; it
        # is removed automatically when the with-block exits.
        with tempfile.NamedTemporaryFile(suffix='.wav') as converted_file:
            command = ['sox', filename, '-t', 'wavpcm', converted_file.name]
            logging.debug('Convert the file using sox: %s', command)
            subprocess.check_call(command)
            self._read_wave_file(converted_file.name)

    def _read_wave_file(self, filename):
        """Reads wave file header and samples.

        Args:
            filename: The wave file to be read.

        Raises:
            WaveFormatExtensibleException: Wave file is in
                                           WAVE_FORMAT_EXTENSIBLE format.
            WaveFileException: Wave file format is not supported.

        """
        # Forget any reader left over from a previous attempt so that the
        # finally clause below only ever closes the reader opened here.
        self._wave_reader = None
        try:
            self._wave_reader = wave.open(filename, 'r')
            self._read_wave_header()
            self._read_wave_binary()
        except wave.Error as e:
            # The wave module reports WAVE_FORMAT_EXTENSIBLE files with this
            # message; that case is recoverable by converting the file.
            if 'unknown format: 65534' in str(e):
                raise WaveFormatExtensibleException() from e
            logging.exception('Unsupported wave format')
            raise WaveFileException(
                'Unsupported wave format: %s' % e) from e
        finally:
            if self._wave_reader is not None:
                self._wave_reader.close()

    def _read_wave_header(self):
        """Reads wave file header.

        Fills in the channel count, sample width (in bits), sampling rate
        and frame count from the already opened reader.

        Raises:
            WaveFileException: wave file is compressed.

        """
        # getparams() returns a named tuple of
        # (nchannels, sampwidth, framerate, nframes, comptype, compname).
        header = self._wave_reader.getparams()
        logging.debug('Wave header: %s', header)

        self._n_channels = header.nchannels
        # sampwidth is in bytes.
        self._sample_width_bits = header.sampwidth * 8
        self.rate = header.framerate
        self._n_frames = header.nframes

        if header.comptype != 'NONE' or header.compname != 'not compressed':
            raise WaveFileException('Can not support compressed wav file.')

    def _read_wave_binary(self):
        """Reads in samples in wave file.

        Must be called after _read_wave_header(), which provides the frame
        count, channel count and sample width used here.

        """
        self._binary = self._wave_reader.readframes(self._n_frames)
        format_str = 'S%d_LE' % self._sample_width_bits
        self.raw_data = audio_data.AudioRawData(binary=self._binary,
                                                channel=self._n_channels,
                                                sample_format=format_str)
167
168
class QualityCheckerError(Exception):
    """Base class for the errors raised by QualityChecker."""
171
172
class CompareFailure(QualityCheckerError):
    """Raised when a channel's dominant frequency check fails."""
175
176
class QualityFailure(QualityCheckerError):
    """Raised when the quality measurement found artifacts."""
179
180
class QualityChecker(object):
    """Quality checker controls the flow of checking quality of raw data.

    do_spectral_analysis() fills in the per-channel results, dump() writes
    them to a file, and check_freqs() / check_quality() raise when the
    collected results are not acceptable.

    """

    # Artifact keys looked up in each quality measurement result, paired with
    # the message reported when the corresponding value is non-empty.
    _ARTIFACT_MESSAGES = (
        ('noise_before_playback', 'Found noise before playback: %s'),
        ('noise_after_playback', 'Found noise after playback: %s'),
        ('delay_during_playback', 'Found delay during playback: %s'),
        ('burst_during_playback', 'Found burst during playback: %s'),
    )

    def __init__(self, raw_data, rate):
        """Inits a quality checker.

        Args:
            raw_data: An audio_data.AudioRawData object.
            rate: Sampling rate in samples per second. Example inputs: 44100,
            48000

        """
        self._raw_data = raw_data
        self._rate = rate
        self._spectrals = []
        self._quality_result = []

    def do_spectral_analysis(self, ignore_high_freq, check_quality,
                             quality_params):
        """Gets the spectral_analysis result.

        Results are appended to the internal spectral list and, when
        check_quality is set, to the internal quality result list. Channels
        whose samples are all zero are skipped, and a channel whose analysis
        raises is logged and skipped as well, so neither contributes an
        entry.

        Args:
            ignore_high_freq: Ignore high frequencies above this threshold.
            check_quality: Check quality of each channel.
            quality_params: A QualityParams object for quality measurement.

        Raises:
            QualityCheckerError: if data or rate is not set yet.

        """
        self.has_data()
        for channel_idx in range(self._raw_data.channel):
            signal = self._raw_data.channel_data[channel_idx]
            # Let numpy do the reduction instead of iterating over the
            # samples one by one with the builtin max().
            max_abs = numpy.max(numpy.abs(signal))
            logging.debug('Channel %d max abs signal: %f', channel_idx,
                          max_abs)
            if max_abs == 0:
                logging.info('No data on channel %d, skip this channel',
                             channel_idx)
                continue

            saturate_value = audio_data.get_maximum_value_from_sample_format(
                self._raw_data.sample_format)
            normalized_signal = audio_analysis.normalize_signal(
                signal, saturate_value)
            logging.debug('saturate_value: %f', saturate_value)
            logging.debug('max signal after normalized: %f',
                          numpy.max(normalized_signal))
            spectral = audio_analysis.spectral_analysis(
                normalized_signal, self._rate)

            logging.debug('Channel %d spectral:\n%s', channel_idx,
                          pprint.pformat(spectral))

            # Ignore high frequencies above the threshold.
            spectral = [(f, c) for (f, c) in spectral if f < ignore_high_freq]

            logging.info(
                'Channel %d spectral after ignoring high frequencies '
                'above %f:\n%s', channel_idx, ignore_high_freq,
                pprint.pformat(spectral))

            # Best effort per channel: a failure here (including an empty
            # spectral when the dominant frequency is looked up) is logged
            # and the remaining channels are still analyzed.
            try:
                if check_quality:
                    quality = audio_quality_measurement.quality_measurement(
                        signal=normalized_signal,
                        rate=self._rate,
                        dominant_frequency=spectral[0][0],
                        block_size_secs=quality_params.block_size_secs,
                        frequency_error_threshold=quality_params.
                        frequency_error_threshold,
                        delay_amplitude_threshold=quality_params.
                        delay_amplitude_threshold,
                        noise_amplitude_threshold=quality_params.
                        noise_amplitude_threshold,
                        burst_amplitude_threshold=quality_params.
                        burst_amplitude_threshold)

                    logging.debug('Channel %d quality:\n%s', channel_idx,
                                  pprint.pformat(quality))
                    self._quality_result.append(quality)
                self._spectrals.append(spectral)
            except Exception as error:
                logging.warning(
                    'Failed to analyze channel %s with error: %s',
                    channel_idx, error)

    def has_data(self):
        """Checks if data has been set.

        Raises:
            QualityCheckerError: if data or rate is not set yet.

        """
        if not self._raw_data or not self._rate:
            raise QualityCheckerError('Data and rate is not set yet')

    def check_freqs(self, expected_freqs, freq_threshold):
        """Checks the dominant frequencies in the channels.

        Args:
            expected_freqs: A list of frequencies. If frequency is 0, it
                            means this channel should be ignored.
            freq_threshold: The difference threshold to compare two
                            frequencies.

        Raises:
            CompareFailure: a checked entry has no dominant frequency, or
                            its dominant frequency differs from the expected
                            one by more than freq_threshold.

        """
        logging.debug('expected_freqs: %s', expected_freqs)
        # NOTE(review): do_spectral_analysis() appends nothing for skipped
        # channels, so position idx here is the idx-th analyzed channel,
        # which may differ from channel idx when an earlier channel was
        # skipped. Confirm with callers before changing the alignment.
        for idx, expected_freq in enumerate(expected_freqs):
            if expected_freq == 0:
                continue
            # A missing entry is reported like an empty spectral rather than
            # surfacing as a bare IndexError.
            if idx >= len(self._spectrals) or not self._spectrals[idx]:
                raise CompareFailure(
                    'Failed at channel %d: no dominant frequency' % idx)
            dominant_freq = self._spectrals[idx][0][0]
            if abs(dominant_freq - expected_freq) > freq_threshold:
                raise CompareFailure(
                    'Failed at channel %d: %f is too far away from %f' %
                    (idx, dominant_freq, expected_freq))

    def check_quality(self):
        """Checks the quality measurement results on each channel.

        Raises:
            QualityFailure: when there is artifact. The message lists every
                            artifact found, one per line.

        """
        error_msgs = []

        for quality_res in self._quality_result:
            artifacts = quality_res['artifacts']
            for key, message in self._ARTIFACT_MESSAGES:
                found = artifacts[key]
                if found:
                    # The one-element tuple keeps %-formatting correct even
                    # if the artifact value itself is a tuple.
                    error_msgs.append(message % (found,))
        if error_msgs:
            raise QualityFailure('Found bad quality: %s' %
                                 '\n'.join(error_msgs))

    def dump(self, output_file):
        """Dumps the result into a file in json format.

        Args:
            output_file: A file path to dump spectral and quality
                         measurement result of each channel.

        """
        dump_dict = {
            'spectrals': self._spectrals,
            'quality_result': self._quality_result
        }
        with open(output_file, 'w') as f:
            json.dump(dump_dict, f)
414
415
class CheckQualityError(Exception):
    """Raised by the module-level helpers, e.g. for unsupported file types."""
418
419
def read_audio_file(filename, channel, bit_width, rate):
    """Loads the samples of a wav or raw audio file.

    For a wav file the channel count, sample width and sampling rate come
    from the file itself, and the rate argument is ignored. For a raw file
    they are taken from the arguments.

    Args:
        filename: The wav or raw file to check. The type is decided by the
                  '.wav' / '.raw' file name suffix.
        channel: For raw file. Number of channels.
        bit_width: For raw file. Bit width of a sample.
        rate: Sampling rate in samples per second. Example inputs: 44100,
              48000

    Returns:
        A tuple (raw_data, rate) where raw_data is audio_data.AudioRawData, rate
            is sampling rate.

    Raises:
        CheckQualityError: the file name has neither suffix.

    """
    if filename.endswith('.wav'):
        wave_file = WaveFile(filename)
        return wave_file.raw_data, wave_file.rate

    if filename.endswith('.raw'):
        with open(filename, 'rb') as raw_file:
            samples = raw_file.read()
        raw_samples = audio_data.AudioRawData(
            binary=samples,
            channel=channel,
            sample_format='S%d_LE' % bit_width)
        return raw_samples, rate

    raise CheckQualityError('File format for %s is not supported' % filename)
452
453
def get_quality_params(quality_block_size_secs,
                       quality_frequency_error_threshold,
                       quality_delay_amplitude_threshold,
                       quality_noise_amplitude_threshold,
                       quality_burst_amplitude_threshold):
    """Packs the individual quality arguments into a QualityParams.

    Args:
        quality_block_size_secs: Input block size in seconds.
        quality_frequency_error_threshold: Input the frequency error
        threshold.
        quality_delay_amplitude_threshold: Input the delay amplitude
        threshold.
        quality_noise_amplitude_threshold: Input the noise amplitude
        threshold.
        quality_burst_amplitude_threshold: Input the burst amplitude
        threshold.

    Returns:
        A QualityParams object.

    """
    # Each argument maps onto the field of the same name minus the
    # 'quality_' prefix.
    fields = {
        'block_size_secs': quality_block_size_secs,
        'frequency_error_threshold': quality_frequency_error_threshold,
        'delay_amplitude_threshold': quality_delay_amplitude_threshold,
        'noise_amplitude_threshold': quality_noise_amplitude_threshold,
        'burst_amplitude_threshold': quality_burst_amplitude_threshold,
    }
    return QualityParams(**fields)
484
485
def quality_analysis(
        filename,
        output_file,
        bit_width,
        rate,
        channel,
        freqs=None,
        freq_threshold=5,
        ignore_high_freq=5000,
        spectral_only=False,
        quality_block_size_secs=DEFAULT_QUALITY_BLOCK_SIZE_SECS,
        quality_burst_amplitude_threshold=DEFAULT_BURST_AMPLITUDE_THRESHOLD,
        quality_delay_amplitude_threshold=DEFAULT_DELAY_AMPLITUDE_THRESHOLD,
        quality_frequency_error_threshold=DEFAULT_FREQUENCY_ERROR_THRESHOLD,
        quality_noise_amplitude_threshold=DEFAULT_NOISE_AMPLITUDE_THRESHOLD,
):
    """Runs the audio quality measurements selected by the arguments.

    The file is read and analyzed, the analysis result is dumped to
    output_file, and only then are the frequency and quality checks run, so
    the dump is written even when one of the checks raises.

    Args:
        filename: The wav or raw file to check.
        output_file: Output file to dump analysis result in JSON format.
        bit_width: For raw file. Bit width of a sample.
        rate: Sampling rate in samples per second. Example inputs: 44100,
        48000
        channel: For raw file. Number of channels.
        freqs: Expected frequencies in the channels. The frequency check is
        skipped when this is empty or None.
        freq_threshold: Frequency difference threshold in Hz.
        ignore_high_freq: Frequency threshold in Hz to be ignored for high
        frequency. Default is 5KHz
        spectral_only: Only do spectral analysis on each channel.
        quality_block_size_secs: Input block size in seconds.
        quality_frequency_error_threshold: Input the frequency error
        threshold.
        quality_delay_amplitude_threshold: Input the delay amplitude
        threshold.
        quality_noise_amplitude_threshold: Input the noise amplitude
        threshold.
        quality_burst_amplitude_threshold: Input the burst amplitude
        threshold.
    """
    measure_quality = not spectral_only

    params = get_quality_params(
        quality_block_size_secs=quality_block_size_secs,
        quality_frequency_error_threshold=quality_frequency_error_threshold,
        quality_delay_amplitude_threshold=quality_delay_amplitude_threshold,
        quality_noise_amplitude_threshold=quality_noise_amplitude_threshold,
        quality_burst_amplitude_threshold=quality_burst_amplitude_threshold)

    samples, sampling_rate = read_audio_file(filename, channel, bit_width,
                                             rate)
    checker = QualityChecker(samples, sampling_rate)

    checker.do_spectral_analysis(ignore_high_freq=ignore_high_freq,
                                 check_quality=measure_quality,
                                 quality_params=params)

    # Persist the results before running the checks that may raise.
    checker.dump(output_file)

    if freqs:
        checker.check_freqs(freqs, freq_threshold)

    if measure_quality:
        checker.check_quality()
    logging.debug("Audio analysis completed.")
549