1"""Stuff to parse WAVE files.
2
3Usage.
4
5Reading WAVE files:
6      f = wave.open(file, 'r')
7where file is either the name of a file or an open file pointer.
8The open file pointer must have methods read(), seek(), and close().
When the setpos() and rewind() methods are not used, the seek()
method is not necessary.
11
12This returns an instance of a class with the following public methods:
13      getnchannels()  -- returns number of audio channels (1 for
14                         mono, 2 for stereo)
15      getsampwidth()  -- returns sample width in bytes
16      getframerate()  -- returns sampling frequency
17      getnframes()    -- returns number of audio frames
18      getcomptype()   -- returns compression type ('NONE' for linear samples)
19      getcompname()   -- returns human-readable version of
                         compression type ('not compressed' for linear samples)
21      getparams()     -- returns a namedtuple consisting of all of the
22                         above in the above order
23      getmarkers()    -- returns None (for compatibility with the
24                         aifc module)
25      getmark(id)     -- raises an error since the mark does not
26                         exist (for compatibility with the aifc module)
27      readframes(n)   -- returns at most n frames of audio
28      rewind()        -- rewind to the beginning of the audio stream
29      setpos(pos)     -- seek to the specified position
30      tell()          -- return the current position
31      close()         -- close the instance (make it unusable)
32The position returned by tell() and the position given to setpos()
33are compatible and have nothing to do with the actual position in the
34file.
35The close() method is called automatically when the class instance
36is destroyed.
37
38Writing WAVE files:
39      f = wave.open(file, 'w')
40where file is either the name of a file or an open file pointer.
41The open file pointer must have methods write(), tell(), seek(), and
42close().
43
44This returns an instance of a class with the following public methods:
45      setnchannels(n) -- set the number of channels
46      setsampwidth(n) -- set the sample width
47      setframerate(n) -- set the frame rate
48      setnframes(n)   -- set the number of frames
49      setcomptype(type, name)
50                      -- set the compression type and the
51                         human-readable compression type
52      setparams(tuple)
53                      -- set all parameters at once
54      tell()          -- return current position in output file
55      writeframesraw(data)
56                      -- write audio frames without patching up the
57                         file header
58      writeframes(data)
59                      -- write audio frames and patch up the file header
60      close()         -- patch up the file header and close the
61                         output file
62You should set the parameters before the first writeframesraw or
63writeframes.  The total number of frames does not need to be set,
64but when it is set to the correct value, the header does not have to
65be patched up.
It is best to first set all parameters, possibly including the
compression type, and then write audio frames using writeframesraw.
68When all frames have been written, either call writeframes(b'') or
69close() to patch up the sizes in the header.
70The close() method is called automatically when the class instance
71is destroyed.
72"""
73
from collections import namedtuple
import builtins
import os
import struct
import sys
78
79
80__all__ = ["open", "Error", "Wave_read", "Wave_write"]
81
class Error(Exception):
    """Exception raised by this module for bad WAVE data or invalid API use."""
84
# Format tag written to, and required in, the 'fmt ' chunk.
WAVE_FORMAT_PCM = 0x0001

# NOTE(review): presumably array-module type codes indexed by sample width
# in bytes; nothing in the visible part of this file reads it -- confirm.
_array_fmts = (None, 'b', 'h', None, 'i')

# Record type returned by the getparams() methods.
_wave_params = namedtuple(
    '_wave_params',
    ['nchannels', 'sampwidth', 'framerate', 'nframes', 'comptype', 'compname'],
)
91
92
def _byteswap(data, width):
    """Return a copy of data with the bytes of every sample reversed.

    data is a bytes-like object holding consecutive samples of width bytes
    each.  The result is a new bytes object of the same length.

    Raises IndexError when len(data) is not a multiple of width.
    """
    if len(data) % width:
        # Same exception type the per-byte implementation raised when it ran
        # into a trailing partial sample.
        raise IndexError('data length is not a multiple of the sample width')

    swapped_data = bytearray(len(data))

    # One strided slice copy per byte position inside a sample instead of one
    # Python-level assignment per byte of data: byte j of every sample moves
    # to position width - 1 - j of the same sample.
    for j in range(width):
        swapped_data[width - 1 - j::width] = data[j::width]

    return bytes(swapped_data)
101
102
class _Chunk:
    """Read one size-prefixed chunk from a file object.

    A chunk starts with a 4-byte name followed by a 4-byte unsigned size.
    The chunk data come next; when align is true, odd-sized data are
    followed by one pad byte.  Positions used by seek() and tell() are
    relative to the start of the chunk data.
    """

    def __init__(self, file, align=True, bigendian=True, inclheader=False):
        """Read the chunk header from file.

        file       -- object with a read() method; tell() and seek() are
                      used when available
        align      -- whether odd-sized chunk data are followed by a pad byte
        bigendian  -- whether the size field is stored big-endian
        inclheader -- whether the stored size counts the 8 header bytes

        Raises EOFError when the header cannot be read completely.
        """
        self.closed = False
        self.align = align      # whether to align to word (2-byte) boundaries
        self.file = file
        size_format = '>L' if bigendian else '<L'
        self.chunkname = file.read(4)
        if len(self.chunkname) < 4:
            raise EOFError
        try:
            (self.chunksize,) = struct.unpack_from(size_format, file.read(4))
        except struct.error:
            raise EOFError from None
        if inclheader:
            self.chunksize -= 8     # subtract header
        self.size_read = 0
        # The file counts as seekable when it can report its position; that
        # position is where the chunk data start.
        self.seekable = True
        try:
            self.offset = self.file.tell()
        except (AttributeError, OSError):
            self.seekable = False

    def _check_open(self):
        """Raise ValueError if the chunk has been closed."""
        if self.closed:
            raise ValueError("I/O operation on closed file")

    def getname(self):
        """Return the name (ID) of the current chunk."""
        return self.chunkname

    def close(self):
        """Skip to the end of the chunk and mark it closed.

        The underlying file is not closed.
        """
        if self.closed:
            return
        try:
            self.skip()
        finally:
            self.closed = True

    def seek(self, pos, whence=0):
        """Move to position pos inside the chunk data.

        whence selects the reference point: 0 is the start of the chunk
        (the default), 1 the current position, 2 the end of the chunk.
        Raises OSError when the file is not seekable and RuntimeError when
        the resulting position lies outside the chunk.
        """
        self._check_open()
        if not self.seekable:
            raise OSError("cannot seek")
        target = pos
        if whence == 1:
            target = target + self.size_read
        elif whence == 2:
            target = target + self.chunksize
        if target < 0 or target > self.chunksize:
            raise RuntimeError
        self.file.seek(self.offset + target, 0)
        self.size_read = target

    def tell(self):
        """Return the current position inside the chunk."""
        self._check_open()
        return self.size_read

    def read(self, size=-1):
        """Return at most size bytes of chunk data.

        With size omitted or negative everything up to the end of the
        chunk is read.  b'' is returned once the chunk is exhausted.
        """
        self._check_open()
        remaining = self.chunksize - self.size_read
        if remaining <= 0:
            return b''
        if size < 0 or size > remaining:
            size = remaining
        payload = self.file.read(size)
        self.size_read += len(payload)
        reached_end = self.size_read == self.chunksize
        if reached_end and self.align and self.chunksize % 2:
            # Consume the pad byte so the file is left at the next chunk.
            pad = self.file.read(1)
            self.size_read += len(pad)
        return payload

    def skip(self):
        """Advance the file to the start of the next chunk.

        Call this when the rest of the chunk is of no interest.  Raises
        EOFError when the file ends before the chunk does and the data
        have to be consumed by reading.
        """
        self._check_open()
        if self.seekable:
            distance = self.chunksize - self.size_read
            # maybe fix alignment
            if self.align and self.chunksize % 2:
                distance += 1
            try:
                self.file.seek(distance, 1)
            except OSError:
                # Fall back to reading the remaining data below.
                pass
            else:
                self.size_read += distance
                return
        while self.size_read < self.chunksize:
            wanted = min(8192, self.chunksize - self.size_read)
            if not self.read(wanted):
                raise EOFError
212
213
class Wave_read:
    """Reader for uncompressed (PCM) WAVE files.

    Variables used in this class:

    These variables are available to the user through appropriate
    methods of this class:
    _file -- the outermost RIFF chunk read from the open file
              set through the __init__() method, returned by getfp()
    _nchannels -- the number of audio channels
              available through the getnchannels() method
    _nframes -- the number of audio frames
              available through the getnframes() method
    _sampwidth -- the number of bytes per audio sample
              available through the getsampwidth() method
    _framerate -- the sampling frequency
              available through the getframerate() method
    _comptype -- the compression type (always 'NONE')
              available through the getcomptype() method
    _compname -- the human-readable compression type
              (always 'not compressed')
              available through the getcompname() method
    _soundpos -- the position in the audio stream, in frames
              available through the tell() method, set through the
              setpos() method

    These variables are used internally only:
    _fmt_chunk_read -- 1 iff the FMT chunk has been read
    _data_seek_needed -- 1 iff the data chunk has to be repositioned
              to _soundpos before the next read in readframes()
    _data_chunk -- instantiation of a chunk class for the DATA chunk
    _framesize -- size of one frame in the file
    _i_opened_the_file -- the file object opened by __init__() when it
              was given a path, otherwise None
    _convert -- optional callable applied by readframes() to the data
              it returns; initfp() sets it to None
    """

    def initfp(self, file):
        """Parse the headers of the open file up to the audio data.

        Chunks are read until the 'data' chunk is found; a 'fmt ' chunk
        must precede it and every other chunk is skipped.

        Raises Error when the file is not a RIFF/WAVE file, when the
        'data' chunk comes before the 'fmt ' chunk, or when either chunk
        is missing.  EOFError is raised when the RIFF header or the
        'fmt ' chunk is truncated.
        """
        self._convert = None
        self._soundpos = 0
        self._file = _Chunk(file, bigendian=False)
        if self._file.getname() != b'RIFF':
            raise Error('file does not start with RIFF id')
        if self._file.read(4) != b'WAVE':
            raise Error('not a WAVE file')
        self._fmt_chunk_read = 0
        self._data_chunk = None
        while True:
            self._data_seek_needed = 1
            try:
                # Sub-chunks are read through the RIFF chunk so they cannot
                # run past its end.
                chunk = _Chunk(self._file, bigendian=False)
            except EOFError:
                break
            chunkname = chunk.getname()
            if chunkname == b'fmt ':
                self._read_fmt_chunk(chunk)
                self._fmt_chunk_read = 1
            elif chunkname == b'data':
                if not self._fmt_chunk_read:
                    raise Error('data chunk before fmt chunk')
                self._data_chunk = chunk
                self._nframes = chunk.chunksize // self._framesize
                # The file is now positioned on the first frame.
                self._data_seek_needed = 0
                break
            chunk.skip()
        if not self._fmt_chunk_read or not self._data_chunk:
            raise Error('fmt chunk and/or data chunk missing')

    def __init__(self, f):
        """Open f for reading and parse its headers.

        f is either a path (str, bytes or os.PathLike), which is opened in
        'rb' mode and closed again by close(), or an already open binary
        file object, which close() leaves open.
        """
        self._i_opened_the_file = None
        if isinstance(f, (str, bytes, os.PathLike)):
            f = builtins.open(f, 'rb')
            self._i_opened_the_file = f
        # else, assume it is an open file object already
        try:
            self.initfp(f)
        except:
            # Do not leak a file opened here; a file supplied by the caller
            # stays open.  The exception is always re-raised.
            if self._i_opened_the_file:
                f.close()
            raise

    def __del__(self):
        self.close()

    def __enter__(self):
        return self

    def __exit__(self, *args):
        self.close()

    #
    # User visible methods.
    #
    def getfp(self):
        """Return the RIFF chunk being read, or None after close()."""
        return self._file

    def rewind(self):
        """Move the read position back to the first frame."""
        self._data_seek_needed = 1
        self._soundpos = 0

    def close(self):
        """Make the instance unusable and close the file if __init__() opened it."""
        self._file = None
        file = self._i_opened_the_file
        if file:
            self._i_opened_the_file = None
            file.close()

    def tell(self):
        """Return the current read position, in frames."""
        return self._soundpos

    def getnchannels(self):
        """Return the number of audio channels."""
        return self._nchannels

    def getnframes(self):
        """Return the number of audio frames in the data chunk."""
        return self._nframes

    def getsampwidth(self):
        """Return the sample width in bytes."""
        return self._sampwidth

    def getframerate(self):
        """Return the sampling frequency."""
        return self._framerate

    def getcomptype(self):
        """Return the compression type ('NONE')."""
        return self._comptype

    def getcompname(self):
        """Return the human-readable compression type ('not compressed')."""
        return self._compname

    def getparams(self):
        """Return a namedtuple (nchannels, sampwidth, framerate, nframes,
        comptype, compname)."""
        return _wave_params(self.getnchannels(), self.getsampwidth(),
                       self.getframerate(), self.getnframes(),
                       self.getcomptype(), self.getcompname())

    def getmarkers(self):
        """Return None; markers are not supported."""
        return None

    def getmark(self, id):
        """Raise Error; markers are not supported."""
        raise Error('no marks')

    def setpos(self, pos):
        """Set the read position to frame pos.

        Raises Error when pos is negative or greater than the number of
        frames.  The actual seek is deferred to the next readframes().
        """
        if pos < 0 or pos > self._nframes:
            raise Error('position not in range')
        self._soundpos = pos
        self._data_seek_needed = 1

    def readframes(self, nframes):
        """Read and return at most nframes frames of audio as bytes.

        Fewer frames are returned when the end of the data chunk is
        reached.  On big-endian machines samples wider than one byte are
        byte-swapped, since the file stores them little-endian.
        """
        if self._data_seek_needed:
            self._data_chunk.seek(0, 0)
            pos = self._soundpos * self._framesize
            if pos:
                self._data_chunk.seek(pos, 0)
            self._data_seek_needed = 0
        if nframes == 0:
            return b''
        data = self._data_chunk.read(nframes * self._framesize)
        if self._sampwidth != 1 and sys.byteorder == 'big':
            data = _byteswap(data, self._sampwidth)
        if self._convert and data:
            data = self._convert(data)
        self._soundpos = self._soundpos + len(data) // (self._nchannels * self._sampwidth)
        return data

    #
    # Internal methods.
    #

    def _read_fmt_chunk(self, chunk):
        """Read the audio parameters from the 'fmt ' chunk.

        Sets _nchannels, _framerate, _sampwidth, _framesize, _comptype and
        _compname.  Raises EOFError when the chunk is too short and Error
        when the format is not PCM, the sample width is zero or the
        number of channels is zero.
        """
        try:
            wFormatTag, self._nchannels, self._framerate, dwAvgBytesPerSec, wBlockAlign = struct.unpack_from('<HHLLH', chunk.read(14))
        except struct.error:
            raise EOFError from None
        if wFormatTag == WAVE_FORMAT_PCM:
            try:
                sampwidth = struct.unpack_from('<H', chunk.read(2))[0]
            except struct.error:
                raise EOFError from None
            # The file stores bits per sample; round up to whole bytes.
            self._sampwidth = (sampwidth + 7) // 8
            if not self._sampwidth:
                raise Error('bad sample width')
        else:
            raise Error('unknown format: %r' % (wFormatTag,))
        if not self._nchannels:
            raise Error('bad # of channels')
        self._framesize = self._nchannels * self._sampwidth
        self._comptype = 'NONE'
        self._compname = 'not compressed'
394
395
class Wave_write:
    """Writer for uncompressed (PCM) WAVE files.

    Variables used in this class:

    These variables are user settable through appropriate methods
    of this class:
    _file -- the open file with methods write(), close(), tell(), seek()
              set through the __init__() method
    _comptype -- the compression type (only 'NONE' is accepted)
              set through the setcomptype() or setparams() method
    _compname -- the human-readable compression type
              set through the setcomptype() or setparams() method
    _nchannels -- the number of audio channels
              set through the setnchannels() or setparams() method
    _sampwidth -- the number of bytes per audio sample
              set through the setsampwidth() or setparams() method
    _framerate -- the sampling frequency
              set through the setframerate() or setparams() method
    _nframes -- the number of audio frames written to the header
              set through the setnframes() or setparams() method

    These variables are used internally only:
    _datalength -- the size of the audio samples written to the header
    _nframeswritten -- the number of frames actually written
    _datawritten -- the size of the audio samples actually written
    _headerwritten -- whether the file header has been written
    _i_opened_the_file -- the file object opened by __init__() when it
              was given a path, otherwise None
    _convert -- optional callable applied by writeframesraw() to the
              data before writing; initfp() sets it to None
    """

    # Class-level default so that close(), and therefore __del__(), works on
    # an instance whose __init__() failed before initfp() assigned _file.
    _file = None

    def __init__(self, f):
        """Prepare f for writing.

        f is either a path (str, bytes or os.PathLike), which is opened in
        'wb' mode and closed again by close(), or an already open binary
        file object, which close() flushes but leaves open.
        """
        self._i_opened_the_file = None
        if isinstance(f, (str, bytes, os.PathLike)):
            f = builtins.open(f, 'wb')
            self._i_opened_the_file = f
        try:
            self.initfp(f)
        except:
            # Do not leak a file opened here; a file supplied by the caller
            # stays open.  The exception is always re-raised.
            if self._i_opened_the_file:
                f.close()
            raise

    def initfp(self, file):
        """Reset all writer state and attach the open file."""
        self._file = file
        self._convert = None
        self._nchannels = 0
        self._sampwidth = 0
        self._framerate = 0
        self._nframes = 0
        self._nframeswritten = 0
        self._datawritten = 0
        self._datalength = 0
        self._headerwritten = False
        # 'NONE' is the only compression type setcomptype() accepts, so use
        # it as the default; otherwise getcomptype(), getcompname() and
        # getparams() fail with AttributeError until setcomptype() is called.
        self._comptype = 'NONE'
        self._compname = 'not compressed'

    def __del__(self):
        self.close()

    def __enter__(self):
        return self

    def __exit__(self, *args):
        self.close()

    #
    # User visible methods.
    #
    def setnchannels(self, nchannels):
        """Set the number of channels; must be at least 1."""
        if self._datawritten:
            raise Error('cannot change parameters after starting to write')
        if nchannels < 1:
            raise Error('bad # of channels')
        self._nchannels = nchannels

    def getnchannels(self):
        """Return the number of channels; raise Error if it is not set."""
        if not self._nchannels:
            raise Error('number of channels not set')
        return self._nchannels

    def setsampwidth(self, sampwidth):
        """Set the sample width in bytes; must be between 1 and 4."""
        if self._datawritten:
            raise Error('cannot change parameters after starting to write')
        if sampwidth < 1 or sampwidth > 4:
            raise Error('bad sample width')
        self._sampwidth = sampwidth

    def getsampwidth(self):
        """Return the sample width in bytes; raise Error if it is not set."""
        if not self._sampwidth:
            raise Error('sample width not set')
        return self._sampwidth

    def setframerate(self, framerate):
        """Set the frame rate; must be positive and is rounded to an integer."""
        if self._datawritten:
            raise Error('cannot change parameters after starting to write')
        if framerate <= 0:
            raise Error('bad frame rate')
        self._framerate = int(round(framerate))

    def getframerate(self):
        """Return the frame rate; raise Error if it is not set."""
        if not self._framerate:
            raise Error('frame rate not set')
        return self._framerate

    def setnframes(self, nframes):
        """Set the number of frames announced in the header."""
        if self._datawritten:
            raise Error('cannot change parameters after starting to write')
        self._nframes = nframes

    def getnframes(self):
        """Return the number of frames written so far."""
        return self._nframeswritten

    def setcomptype(self, comptype, compname):
        """Set the compression type and its human-readable name.

        Only the compression type 'NONE' is supported.
        """
        if self._datawritten:
            raise Error('cannot change parameters after starting to write')
        if comptype not in ('NONE',):
            raise Error('unsupported compression type')
        self._comptype = comptype
        self._compname = compname

    def getcomptype(self):
        """Return the compression type ('NONE' unless changed)."""
        return self._comptype

    def getcompname(self):
        """Return the human-readable compression type."""
        return self._compname

    def setparams(self, params):
        """Set all parameters at once.

        params is a sequence (nchannels, sampwidth, framerate, nframes,
        comptype, compname).
        """
        nchannels, sampwidth, framerate, nframes, comptype, compname = params
        if self._datawritten:
            raise Error('cannot change parameters after starting to write')
        self.setnchannels(nchannels)
        self.setsampwidth(sampwidth)
        self.setframerate(framerate)
        self.setnframes(nframes)
        self.setcomptype(comptype, compname)

    def getparams(self):
        """Return a namedtuple (nchannels, sampwidth, framerate, nframes,
        comptype, compname).

        Raises Error unless channels, sample width and frame rate are set.
        """
        if not self._nchannels or not self._sampwidth or not self._framerate:
            raise Error('not all parameters set')
        return _wave_params(self._nchannels, self._sampwidth, self._framerate,
              self._nframes, self._comptype, self._compname)

    def setmark(self, id, pos, name):
        """Raise Error; markers are not supported."""
        raise Error('setmark() not supported')

    def getmark(self, id):
        """Raise Error; markers are not supported."""
        raise Error('no marks')

    def getmarkers(self):
        """Return None; markers are not supported."""
        return None

    def tell(self):
        """Return the current position in the output, in frames."""
        return self._nframeswritten

    def writeframesraw(self, data):
        """Write audio frames without patching up the file header.

        data is a bytes-like object.  The header is written first if that
        has not happened yet.  On big-endian machines samples wider than
        one byte are byte-swapped, since the file stores them
        little-endian.
        """
        if not isinstance(data, (bytes, bytearray)):
            # View any other buffer as plain bytes so len() counts bytes.
            data = memoryview(data).cast('B')
        self._ensure_header_written(len(data))
        nframes = len(data) // (self._sampwidth * self._nchannels)
        if self._convert:
            data = self._convert(data)
        if self._sampwidth != 1 and sys.byteorder == 'big':
            data = _byteswap(data, self._sampwidth)
        self._file.write(data)
        self._datawritten += len(data)
        self._nframeswritten = self._nframeswritten + nframes

    def writeframes(self, data):
        """Write audio frames and patch up the header sizes if needed."""
        self.writeframesraw(data)
        if self._datalength != self._datawritten:
            self._patchheader()

    def close(self):
        """Patch up the file header, flush, and release the file.

        The file is closed only if __init__() opened it.  Calling close()
        more than once is allowed.
        """
        try:
            if self._file:
                self._ensure_header_written(0)
                if self._datalength != self._datawritten:
                    self._patchheader()
                self._file.flush()
        finally:
            self._file = None
            file = self._i_opened_the_file
            if file:
                self._i_opened_the_file = None
                file.close()

    #
    # Internal methods.
    #

    def _ensure_header_written(self, datasize):
        """Write the header once, before the first audio data.

        datasize is the size in bytes of the data about to be written.
        Raises Error when channels, sample width or frame rate are unset.
        """
        if not self._headerwritten:
            if not self._nchannels:
                raise Error('# channels not specified')
            if not self._sampwidth:
                raise Error('sample width not specified')
            if not self._framerate:
                raise Error('sampling rate not specified')
            self._write_header(datasize)

    def _write_header(self, initlength):
        """Write the RIFF, 'fmt ' and 'data' chunk headers.

        When the number of frames was not set it is derived from
        initlength, the size in bytes of the first block of data.  The
        positions of the two size fields are remembered for
        _patchheader() when the file can report its position.
        """
        assert not self._headerwritten
        self._file.write(b'RIFF')
        if not self._nframes:
            self._nframes = initlength // (self._nchannels * self._sampwidth)
        self._datalength = self._nframes * self._nchannels * self._sampwidth
        try:
            self._form_length_pos = self._file.tell()
        except (AttributeError, OSError):
            self._form_length_pos = None
        # RIFF size, 'WAVE', 'fmt ' id and size (16), format tag, channels,
        # frame rate, bytes per second, bytes per frame, bits per sample,
        # 'data' id.
        self._file.write(struct.pack('<L4s4sLHHLLHH4s',
            36 + self._datalength, b'WAVE', b'fmt ', 16,
            WAVE_FORMAT_PCM, self._nchannels, self._framerate,
            self._nchannels * self._framerate * self._sampwidth,
            self._nchannels * self._sampwidth,
            self._sampwidth * 8, b'data'))
        if self._form_length_pos is not None:
            self._data_length_pos = self._file.tell()
        self._file.write(struct.pack('<L', self._datalength))
        self._headerwritten = True

    def _patchheader(self):
        """Rewrite the two size fields to match the data actually written.

        Uses tell() and seek() on the file and restores the write
        position afterwards.
        """
        assert self._headerwritten
        if self._datawritten == self._datalength:
            return
        curpos = self._file.tell()
        self._file.seek(self._form_length_pos, 0)
        self._file.write(struct.pack('<L', 36 + self._datawritten))
        self._file.seek(self._data_length_pos, 0)
        self._file.write(struct.pack('<L', self._datawritten))
        self._file.seek(curpos, 0)
        self._datalength = self._datawritten
622
623
def open(f, mode=None):
    """Return a Wave_read or Wave_write object for f.

    mode is 'r' or 'rb' for reading and 'w' or 'wb' for writing.  When
    mode is omitted, f.mode is used if f has that attribute, otherwise
    'rb'.  f is passed unchanged to Wave_read or Wave_write.

    Raises Error for any other mode.
    """
    if mode is None:
        mode = getattr(f, 'mode', 'rb')
    if mode in ('r', 'rb'):
        return Wave_read(f)
    if mode in ('w', 'wb'):
        return Wave_write(f)
    raise Error("mode must be 'r', 'rb', 'w', or 'wb'")
636