1"""Stuff to parse AIFF-C and AIFF files.
2
3Unless explicitly stated otherwise, the description below is true
4both for AIFF-C files and AIFF files.
5
6An AIFF-C file has the following structure.
7
8  +-----------------+
9  | FORM            |
10  +-----------------+
11  | <size>          |
12  +----+------------+
13  |    | AIFC       |
14  |    +------------+
15  |    | <chunks>   |
16  |    |    .       |
17  |    |    .       |
18  |    |    .       |
19  +----+------------+
20
21An AIFF file has the string "AIFF" instead of "AIFC".
22
23A chunk consists of an identifier (4 bytes) followed by a size (4 bytes,
24big endian order), followed by the data.  The size field does not include
25the size of the 8 byte header.
26
27The following chunk types are recognized.
28
29  FVER
30      <version number of AIFF-C defining document> (AIFF-C only).
31  MARK
32      <# of markers> (2 bytes)
33      list of markers:
34          <marker ID> (2 bytes, must be > 0)
35          <position> (4 bytes)
36          <marker name> ("pstring")
37  COMM
38      <# of channels> (2 bytes)
39      <# of sound frames> (4 bytes)
40      <size of the samples> (2 bytes)
41      <sampling frequency> (10 bytes, IEEE 80-bit extended
42          floating point)
43      in AIFF-C files only:
44      <compression type> (4 bytes)
45      <human-readable version of compression type> ("pstring")
46  SSND
47      <offset> (4 bytes, not used by this program)
48      <blocksize> (4 bytes, not used by this program)
49      <sound data>
50
51A pstring consists of 1 byte length, a string of characters, and 0 or 1
52byte pad to make the total length even.
53
54Usage.
55
56Reading AIFF files:
57  f = aifc.open(file, 'r')
58where file is either the name of a file or an open file pointer.
59The open file pointer must have methods read(), seek(), and close().
60In some types of audio files, if the setpos() method is not used,
61the seek() method is not necessary.
62
63This returns an instance of a class with the following public methods:
64  getnchannels()  -- returns number of audio channels (1 for
65             mono, 2 for stereo)
66  getsampwidth()  -- returns sample width in bytes
67  getframerate()  -- returns sampling frequency
68  getnframes()    -- returns number of audio frames
69  getcomptype()   -- returns compression type ('NONE' for AIFF files)
70  getcompname()   -- returns human-readable version of
71             compression type ('not compressed' for AIFF files)
72  getparams() -- returns a namedtuple consisting of all of the
73             above in the above order
74  getmarkers()    -- get the list of marks in the audio file or None
75             if there are no marks
76  getmark(id) -- get mark with the specified id (raises an error
77             if the mark does not exist)
78  readframes(n)   -- returns at most n frames of audio
79  rewind()    -- rewind to the beginning of the audio stream
80  setpos(pos) -- seek to the specified position
81  tell()      -- return the current position
82  close()     -- close the instance (make it unusable)
83The position returned by tell(), the position given to setpos() and
84the position of marks are all compatible and have nothing to do with
85the actual position in the file.
The instance can be used in a with statement, which calls close() on
exit; otherwise call close() explicitly when done reading.
88
89Writing AIFF files:
90  f = aifc.open(file, 'w')
91where file is either the name of a file or an open file pointer.
92The open file pointer must have methods write(), tell(), seek(), and
93close().
94
95This returns an instance of a class with the following public methods:
96  aiff()      -- create an AIFF file (AIFF-C default)
97  aifc()      -- create an AIFF-C file
98  setnchannels(n) -- set the number of channels
99  setsampwidth(n) -- set the sample width
100  setframerate(n) -- set the frame rate
101  setnframes(n)   -- set the number of frames
102  setcomptype(type, name)
103          -- set the compression type and the
104             human-readable compression type
105  setparams(tuple)
106          -- set all parameters at once
107  setmark(id, pos, name)
108          -- add specified mark to the list of marks
109  tell()      -- return current position in output file (useful
110             in combination with setmark())
111  writeframesraw(data)
          -- write audio frames without patching up the
113             file header
114  writeframes(data)
115          -- write audio frames and patch up the file header
116  close()     -- patch up the file header and close the
117             output file
118You should set the parameters before the first writeframesraw or
119writeframes.  The total number of frames does not need to be set,
120but when it is set to the correct value, the header does not have to
121be patched up.
It is best to first set all parameters, optionally including the
compression type, and then write audio frames using writeframesraw.
124When all frames have been written, either call writeframes(b'') or
125close() to patch up the sizes in the header.
126Marks can be added anytime.  If there are any marks, you must call
127close() after all frames have been written.
128The close() method is called automatically when the class instance
129is destroyed.
130
131When a file is opened with the extension '.aiff', an AIFF file is
132written, otherwise an AIFF-C file is written.  This default can be
133changed by calling aiff() or aifc() before the first writeframes or
134writeframesraw.
135"""
136
137import struct
138import builtins
139import warnings
140
# Public API: the exception type and the open() function described in the
# module docstring.
__all__ = ["Error", "open"]


# Warn on import that this module is deprecated and scheduled for removal
# in Python 3.13.
warnings._deprecated(__name__, remove=(3, 13))
145
146
class Error(Exception):
    """Raised for malformed AIFF/AIFF-C data and for invalid use of a
    reader or writer (bad parameters, parameters changed too late, ...)."""
149
150_AIFC_version = 0xA2805140     # Version 1 of AIFF-C
151
152def _read_long(file):
153    try:
154        return struct.unpack('>l', file.read(4))[0]
155    except struct.error:
156        raise EOFError from None
157
158def _read_ulong(file):
159    try:
160        return struct.unpack('>L', file.read(4))[0]
161    except struct.error:
162        raise EOFError from None
163
164def _read_short(file):
165    try:
166        return struct.unpack('>h', file.read(2))[0]
167    except struct.error:
168        raise EOFError from None
169
170def _read_ushort(file):
171    try:
172        return struct.unpack('>H', file.read(2))[0]
173    except struct.error:
174        raise EOFError from None
175
176def _read_string(file):
177    length = ord(file.read(1))
178    if length == 0:
179        data = b''
180    else:
181        data = file.read(length)
182    if length & 1 == 0:
183        dummy = file.read(1)
184    return data
185
186_HUGE_VAL = 1.79769313486231e+308 # See <limits.h>
187
188def _read_float(f): # 10 bytes
189    expon = _read_short(f) # 2 bytes
190    sign = 1
191    if expon < 0:
192        sign = -1
193        expon = expon + 0x8000
194    himant = _read_ulong(f) # 4 bytes
195    lomant = _read_ulong(f) # 4 bytes
196    if expon == himant == lomant == 0:
197        f = 0.0
198    elif expon == 0x7FFF:
199        f = _HUGE_VAL
200    else:
201        expon = expon - 16383
202        f = (himant * 0x100000000 + lomant) * pow(2.0, expon - 63)
203    return sign * f
204
205def _write_short(f, x):
206    f.write(struct.pack('>h', x))
207
208def _write_ushort(f, x):
209    f.write(struct.pack('>H', x))
210
211def _write_long(f, x):
212    f.write(struct.pack('>l', x))
213
214def _write_ulong(f, x):
215    f.write(struct.pack('>L', x))
216
217def _write_string(f, s):
218    if len(s) > 255:
219        raise ValueError("string exceeds maximum pstring length")
220    f.write(struct.pack('B', len(s)))
221    f.write(s)
222    if len(s) & 1 == 0:
223        f.write(b'\x00')
224
def _write_float(f, x):
    """Write the number *x* to *f* as an 80-bit IEEE 754 extended float.

    Layout (big-endian): an unsigned 16-bit word holding the sign bit and
    the 15-bit biased exponent, then the 64-bit mantissa as two unsigned
    32-bit words.  Infinity and NaN are both written with exponent 0x7FFF
    and a zero mantissa.
    """
    import math
    sign = 0
    if x < 0:
        sign = 0x8000
        x = x * -1
    if x == 0:
        expon = himant = lomant = 0
    else:
        fmant, expon = math.frexp(x)
        if expon > 16384 or fmant >= 1 or fmant != fmant: # Infinity or NaN
            expon, himant, lomant = sign | 0x7FFF, 0, 0
        else:                   # Finite
            expon += 16382
            if expon < 0:           # denormalized
                fmant = math.ldexp(fmant, expon)
                expon = 0
            expon |= sign
            # Peel the mantissa off 32 bits at a time.
            scaled = math.ldexp(fmant, 32)
            whole = math.floor(scaled)
            himant = int(whole)
            scaled = math.ldexp(scaled - whole, 32)
            lomant = int(math.floor(scaled))
    _write_ushort(f, expon)
    _write_ulong(f, himant)
    _write_ulong(f, lomant)
257
258with warnings.catch_warnings():
259    warnings.simplefilter("ignore", DeprecationWarning)
260    from chunk import Chunk
261from collections import namedtuple
262
263_aifc_params = namedtuple('_aifc_params',
264                          'nchannels sampwidth framerate nframes comptype compname')
265
266_aifc_params.nchannels.__doc__ = 'Number of audio channels (1 for mono, 2 for stereo)'
267_aifc_params.sampwidth.__doc__ = 'Sample width in bytes'
268_aifc_params.framerate.__doc__ = 'Sampling frequency'
269_aifc_params.nframes.__doc__ = 'Number of audio frames'
270_aifc_params.comptype.__doc__ = 'Compression type ("NONE" for AIFF files)'
271_aifc_params.compname.__doc__ = ("""\
272A human-readable version of the compression type
273('not compressed' for AIFF files)""")
274
275
class Aifc_read:
    """Reader for AIFF and AIFF-C files.

    Constructing an instance parses the FORM container and the COMM, SSND,
    FVER and MARK chunks inside it.  Audio frames are then served from the
    SSND chunk by readframes(); u-law, A-law, G722 and 'sowt' sample data
    are converted to 16-bit linear samples as they are read.
    """
    # Variables used in this class:
    #
    # These variables are available to the user through appropriate
    # methods of this class:
    # _file -- the open file with methods read(), close(), and seek()
    #       set through the __init__() method
    # _nchannels -- the number of audio channels
    #       available through the getnchannels() method
    # _nframes -- the number of audio frames
    #       available through the getnframes() method
    # _sampwidth -- the number of bytes per audio sample
    #       available through the getsampwidth() method
    # _framerate -- the sampling frequency
    #       available through the getframerate() method
    # _comptype -- the AIFF-C compression type ('NONE' if AIFF)
    #       available through the getcomptype() method
    # _compname -- the human-readable AIFF-C compression type
    #       available through the getcompname() method
    # _markers -- the marks in the audio file
    #       available through the getmarkers() and getmark()
    #       methods
    # _soundpos -- the position in the audio stream
    #       available through the tell() method, set through the
    #       setpos() method
    #
    # These variables are used internally only:
    # _version -- the AIFF-C version number (from the FVER chunk)
    # _convert -- function converting compressed sample data to 16-bit
    #       linear samples, or None for uncompressed data
    # _comm_chunk_read -- 1 iff the COMM chunk has been read
    # _aifc -- 1 iff reading an AIFF-C file
    # _ssnd_seek_needed -- 1 iff the SSND chunk must be repositioned to
    #       _soundpos before readframes() reads from it
    # _ssnd_chunk -- instantiation of a chunk class for the SSND chunk
    # _framesize -- size of one frame in the file (before any conversion)

    _file = None  # Set here so close() works even if initfp() never ran

    def initfp(self, file):
        """Parse the AIFF/AIFF-C structure of the already-open *file*.

        Walks every chunk following the FORM header, recording the COMM
        parameters, the SSND chunk, the FVER version and MARK markers.

        Raises Error if the file does not start with a FORM chunk, is
        neither AIFF nor AIFF-C, or lacks a COMM or SSND chunk.
        """
        self._version = 0
        self._convert = None
        self._markers = []
        self._soundpos = 0
        self._file = file
        chunk = Chunk(file)
        if chunk.getname() != b'FORM':
            raise Error('file does not start with FORM id')
        formdata = chunk.read(4)
        if formdata == b'AIFF':
            self._aifc = 0
        elif formdata == b'AIFC':
            self._aifc = 1
        else:
            raise Error('not an AIFF or AIFF-C file')
        self._comm_chunk_read = 0
        self._ssnd_chunk = None
        while 1:
            # NOTE: the loop only ends through the EOFError break below,
            # which comes after this assignment, so _ssnd_seek_needed is
            # always 1 once parsing finishes and readframes() seeks first.
            self._ssnd_seek_needed = 1
            try:
                chunk = Chunk(self._file)
            except EOFError:
                break
            chunkname = chunk.getname()
            if chunkname == b'COMM':
                self._read_comm_chunk(chunk)
                self._comm_chunk_read = 1
            elif chunkname == b'SSND':
                self._ssnd_chunk = chunk
                # Skip the offset and blocksize fields.
                dummy = chunk.read(8)
                self._ssnd_seek_needed = 0
            elif chunkname == b'FVER':
                self._version = _read_ulong(chunk)
            elif chunkname == b'MARK':
                self._readmark(chunk)
            # Move on to the next chunk.
            chunk.skip()
        if not self._comm_chunk_read or not self._ssnd_chunk:
            raise Error('COMM chunk and/or SSND chunk missing')

    def __init__(self, f):
        """Open *f* for reading; *f* is a filename or an open binary file.

        A file opened here by name is closed again if parsing fails; a
        file object passed in is left open on failure.
        """
        if isinstance(f, str):
            file_object = builtins.open(f, 'rb')
            try:
                self.initfp(file_object)
            except:
                file_object.close()
                raise
        else:
            # assume it is an open file object already
            self.initfp(f)

    def __enter__(self):
        return self

    def __exit__(self, *args):
        # Leaving a with block closes the reader.
        self.close()

    #
    # User visible methods.
    #
    def getfp(self):
        """Return the underlying file object (None after close())."""
        return self._file

    def rewind(self):
        """Rewind to the first audio frame (the seek happens lazily)."""
        self._ssnd_seek_needed = 1
        self._soundpos = 0

    def close(self):
        """Close the underlying file; calling it again is a no-op."""
        file = self._file
        if file is not None:
            self._file = None
            file.close()

    def tell(self):
        """Return the current position in frames."""
        return self._soundpos

    def getnchannels(self):
        return self._nchannels

    def getnframes(self):
        return self._nframes

    def getsampwidth(self):
        # Bytes per sample as delivered by readframes(); 2 for any
        # supported compressed format.
        return self._sampwidth

    def getframerate(self):
        return self._framerate

    def getcomptype(self):
        return self._comptype

    def getcompname(self):
        return self._compname

##  def getversion(self):
##      return self._version

    def getparams(self):
        """Return all audio parameters as an _aifc_params namedtuple."""
        return _aifc_params(self.getnchannels(), self.getsampwidth(),
                            self.getframerate(), self.getnframes(),
                            self.getcomptype(), self.getcompname())

    def getmarkers(self):
        """Return the list of (id, pos, name) markers, or None if empty."""
        if len(self._markers) == 0:
            return None
        return self._markers

    def getmark(self, id):
        """Return the (id, pos, name) marker with *id*; raise Error if absent."""
        for marker in self._markers:
            if id == marker[0]:
                return marker
        raise Error('marker {0!r} does not exist'.format(id))

    def setpos(self, pos):
        """Move to frame *pos*, where 0 <= pos <= getnframes().

        Raises Error for an out-of-range position.  The actual seek is
        deferred to the next readframes() call.
        """
        if pos < 0 or pos > self._nframes:
            raise Error('position not in range')
        self._soundpos = pos
        self._ssnd_seek_needed = 1

    def readframes(self, nframes):
        """Read up to *nframes* frames and return them as bytes.

        Compressed data is passed through self._convert first.  The
        position advances by the number of whole frames returned.
        """
        if self._ssnd_seek_needed:
            # Reposition inside the SSND chunk: skip its 8-byte
            # offset/blocksize header, then jump to the current frame.
            self._ssnd_chunk.seek(0)
            dummy = self._ssnd_chunk.read(8)
            pos = self._soundpos * self._framesize
            if pos:
                self._ssnd_chunk.seek(pos + 8)
            self._ssnd_seek_needed = 0
        if nframes == 0:
            return b''
        data = self._ssnd_chunk.read(nframes * self._framesize)
        if self._convert and data:
            data = self._convert(data)
        # Count frames in the (possibly converted) output; _sampwidth
        # already reflects the converted sample width.
        self._soundpos = self._soundpos + len(data) // (self._nchannels
                                                        * self._sampwidth)
        return data

    #
    # Internal methods.
    #

    def _alaw2lin(self, data):
        """Convert A-law *data* to 16-bit linear samples."""
        with warnings.catch_warnings():
            # audioop is deprecated; silence the warning on import.
            warnings.simplefilter('ignore', category=DeprecationWarning)
            import audioop
        return audioop.alaw2lin(data, 2)

    def _ulaw2lin(self, data):
        """Convert u-law *data* to 16-bit linear samples."""
        with warnings.catch_warnings():
            warnings.simplefilter('ignore', category=DeprecationWarning)
            import audioop
        return audioop.ulaw2lin(data, 2)

    def _adpcm2lin(self, data):
        """Convert ADPCM *data* (comptype G722) to 16-bit linear samples.

        Decoder state is carried across calls in self._adpcmstate.
        NOTE(review): the state is not reset by setpos()/rewind(), so
        decoding after a seek may start from a stale state -- confirm.
        """
        with warnings.catch_warnings():
            warnings.simplefilter('ignore', category=DeprecationWarning)
            import audioop
        if not hasattr(self, '_adpcmstate'):
            # first time
            self._adpcmstate = None
        data, self._adpcmstate = audioop.adpcm2lin(data, 2, self._adpcmstate)
        return data

    def _sowt2lin(self, data):
        """Swap the byte order of each 16-bit sample of 'sowt' *data*."""
        with warnings.catch_warnings():
            warnings.simplefilter('ignore', category=DeprecationWarning)
            import audioop
        return audioop.byteswap(data, 2)

    def _read_comm_chunk(self, chunk):
        """Parse the COMM chunk: channels, frames, sample size and rate.

        For AIFF-C also reads the compression type and name and selects
        the matching converter.  Raises Error for a non-positive sample
        width or channel count, or an unsupported compression type.
        """
        self._nchannels = _read_short(chunk)
        self._nframes = _read_long(chunk)
        # The file stores the sample size in bits; round up to bytes.
        self._sampwidth = (_read_short(chunk) + 7) // 8
        self._framerate = int(_read_float(chunk))
        if self._sampwidth <= 0:
            raise Error('bad sample width')
        if self._nchannels <= 0:
            raise Error('bad # of channels')
        # Frame size of the data as stored in the file; computed before
        # _sampwidth is overridden for compressed formats below.
        self._framesize = self._nchannels * self._sampwidth
        if self._aifc:
            # Workaround: SGI's soundeditor writes AIFF-C COMM chunks with
            # the plain-AIFF size of 18, omitting the compression fields.
            # Enlarge the chunk to cover the 4-byte type and the pstring
            # length byte (18 + 4 + 1 = 23).
            kludge = 0
            if chunk.chunksize == 18:
                kludge = 1
                warnings.warn('Warning: bad COMM chunk size')
                chunk.chunksize = 23
            # End of workaround part 1.
            self._comptype = chunk.read(4)
            # Workaround part 2: peek at the pstring length byte directly on
            # the underlying file, extend the chunk by the string data plus
            # its pad byte, then step back so _read_string() sees the
            # length byte again.
            if kludge:
                length = ord(chunk.file.read(1))
                if length & 1 == 0:
                    length = length + 1
                chunk.chunksize = chunk.chunksize + length
                chunk.file.seek(-1, 1)
            # End of workaround.
            self._compname = _read_string(chunk)
            if self._comptype != b'NONE':
                if self._comptype == b'G722':
                    self._convert = self._adpcm2lin
                elif self._comptype in (b'ulaw', b'ULAW'):
                    self._convert = self._ulaw2lin
                elif self._comptype in (b'alaw', b'ALAW'):
                    self._convert = self._alaw2lin
                elif self._comptype in (b'sowt', b'SOWT'):
                    self._convert = self._sowt2lin
                else:
                    raise Error('unsupported compression type')
                # All converters produce 16-bit samples.
                self._sampwidth = 2
        else:
            self._comptype = b'NONE'
            self._compname = b'not compressed'

    def _readmark(self, chunk):
        """Parse a MARK chunk, appending (id, pos, name) to self._markers.

        Markers with position 0 and an empty name are skipped.  If the data
        ends before the declared number of markers (EOFError), a warning is
        issued and the markers read so far are kept.
        """
        nmarkers = _read_short(chunk)
        # Some files appear to contain invalid counts.
        # Cope with this by testing for EOF.
        try:
            for i in range(nmarkers):
                id = _read_short(chunk)
                pos = _read_long(chunk)
                name = _read_string(chunk)
                if pos or name:
                    # some files appear to have
                    # dummy markers consisting of
                    # a position 0 and name ''
                    self._markers.append((id, pos, name))
        except EOFError:
            w = ('Warning: MARK chunk contains only %s marker%s instead of %s' %
                 (len(self._markers), '' if len(self._markers) == 1 else 's',
                  nmarkers))
            warnings.warn(w)
547class Aifc_write:
548    # Variables used in this class:
549    #
550    # These variables are user settable through appropriate methods
551    # of this class:
552    # _file -- the open file with methods write(), close(), tell(), seek()
553    #       set through the __init__() method
554    # _comptype -- the AIFF-C compression type ('NONE' in AIFF)
555    #       set through the setcomptype() or setparams() method
556    # _compname -- the human-readable AIFF-C compression type
557    #       set through the setcomptype() or setparams() method
558    # _nchannels -- the number of audio channels
559    #       set through the setnchannels() or setparams() method
560    # _sampwidth -- the number of bytes per audio sample
561    #       set through the setsampwidth() or setparams() method
562    # _framerate -- the sampling frequency
563    #       set through the setframerate() or setparams() method
564    # _nframes -- the number of audio frames written to the header
565    #       set through the setnframes() or setparams() method
566    # _aifc -- whether we're writing an AIFF-C file or an AIFF file
567    #       set through the aifc() method, reset through the
568    #       aiff() method
569    #
570    # These variables are used internally only:
571    # _version -- the AIFF-C version number
    # _convert -- function that compresses 16-bit linear data for the
    #       selected compression type, or None for uncompressed output
573    # _nframeswritten -- the number of audio frames actually written
574    # _datalength -- the size of the audio samples written to the header
575    # _datawritten -- the size of the audio samples actually written
576
577    _file = None  # Set here since __del__ checks it
578
579    def __init__(self, f):
580        if isinstance(f, str):
581            file_object = builtins.open(f, 'wb')
582            try:
583                self.initfp(file_object)
584            except:
585                file_object.close()
586                raise
587
588            # treat .aiff file extensions as non-compressed audio
589            if f.endswith('.aiff'):
590                self._aifc = 0
591        else:
592            # assume it is an open file object already
593            self.initfp(f)
594
595    def initfp(self, file):
596        self._file = file
597        self._version = _AIFC_version
598        self._comptype = b'NONE'
599        self._compname = b'not compressed'
600        self._convert = None
601        self._nchannels = 0
602        self._sampwidth = 0
603        self._framerate = 0
604        self._nframes = 0
605        self._nframeswritten = 0
606        self._datawritten = 0
607        self._datalength = 0
608        self._markers = []
609        self._marklength = 0
610        self._aifc = 1      # AIFF-C is default
611
612    def __del__(self):
613        self.close()
614
615    def __enter__(self):
616        return self
617
618    def __exit__(self, *args):
619        self.close()
620
621    #
622    # User visible methods.
623    #
624    def aiff(self):
625        if self._nframeswritten:
626            raise Error('cannot change parameters after starting to write')
627        self._aifc = 0
628
629    def aifc(self):
630        if self._nframeswritten:
631            raise Error('cannot change parameters after starting to write')
632        self._aifc = 1
633
634    def setnchannels(self, nchannels):
635        if self._nframeswritten:
636            raise Error('cannot change parameters after starting to write')
637        if nchannels < 1:
638            raise Error('bad # of channels')
639        self._nchannels = nchannels
640
641    def getnchannels(self):
642        if not self._nchannels:
643            raise Error('number of channels not set')
644        return self._nchannels
645
646    def setsampwidth(self, sampwidth):
647        if self._nframeswritten:
648            raise Error('cannot change parameters after starting to write')
649        if sampwidth < 1 or sampwidth > 4:
650            raise Error('bad sample width')
651        self._sampwidth = sampwidth
652
653    def getsampwidth(self):
654        if not self._sampwidth:
655            raise Error('sample width not set')
656        return self._sampwidth
657
658    def setframerate(self, framerate):
659        if self._nframeswritten:
660            raise Error('cannot change parameters after starting to write')
661        if framerate <= 0:
662            raise Error('bad frame rate')
663        self._framerate = framerate
664
665    def getframerate(self):
666        if not self._framerate:
667            raise Error('frame rate not set')
668        return self._framerate
669
670    def setnframes(self, nframes):
671        if self._nframeswritten:
672            raise Error('cannot change parameters after starting to write')
673        self._nframes = nframes
674
675    def getnframes(self):
676        return self._nframeswritten
677
678    def setcomptype(self, comptype, compname):
679        if self._nframeswritten:
680            raise Error('cannot change parameters after starting to write')
681        if comptype not in (b'NONE', b'ulaw', b'ULAW',
682                            b'alaw', b'ALAW', b'G722', b'sowt', b'SOWT'):
683            raise Error('unsupported compression type')
684        self._comptype = comptype
685        self._compname = compname
686
687    def getcomptype(self):
688        return self._comptype
689
690    def getcompname(self):
691        return self._compname
692
693##  def setversion(self, version):
694##      if self._nframeswritten:
695##          raise Error, 'cannot change parameters after starting to write'
696##      self._version = version
697
698    def setparams(self, params):
699        nchannels, sampwidth, framerate, nframes, comptype, compname = params
700        if self._nframeswritten:
701            raise Error('cannot change parameters after starting to write')
702        if comptype not in (b'NONE', b'ulaw', b'ULAW',
703                            b'alaw', b'ALAW', b'G722', b'sowt', b'SOWT'):
704            raise Error('unsupported compression type')
705        self.setnchannels(nchannels)
706        self.setsampwidth(sampwidth)
707        self.setframerate(framerate)
708        self.setnframes(nframes)
709        self.setcomptype(comptype, compname)
710
711    def getparams(self):
712        if not self._nchannels or not self._sampwidth or not self._framerate:
713            raise Error('not all parameters set')
714        return _aifc_params(self._nchannels, self._sampwidth, self._framerate,
715                            self._nframes, self._comptype, self._compname)
716
717    def setmark(self, id, pos, name):
718        if id <= 0:
719            raise Error('marker ID must be > 0')
720        if pos < 0:
721            raise Error('marker position must be >= 0')
722        if not isinstance(name, bytes):
723            raise Error('marker name must be bytes')
724        for i in range(len(self._markers)):
725            if id == self._markers[i][0]:
726                self._markers[i] = id, pos, name
727                return
728        self._markers.append((id, pos, name))
729
730    def getmark(self, id):
731        for marker in self._markers:
732            if id == marker[0]:
733                return marker
734        raise Error('marker {0!r} does not exist'.format(id))
735
736    def getmarkers(self):
737        if len(self._markers) == 0:
738            return None
739        return self._markers
740
741    def tell(self):
742        return self._nframeswritten
743
744    def writeframesraw(self, data):
745        if not isinstance(data, (bytes, bytearray)):
746            data = memoryview(data).cast('B')
747        self._ensure_header_written(len(data))
748        nframes = len(data) // (self._sampwidth * self._nchannels)
749        if self._convert:
750            data = self._convert(data)
751        self._file.write(data)
752        self._nframeswritten = self._nframeswritten + nframes
753        self._datawritten = self._datawritten + len(data)
754
755    def writeframes(self, data):
756        self.writeframesraw(data)
757        if self._nframeswritten != self._nframes or \
758              self._datalength != self._datawritten:
759            self._patchheader()
760
761    def close(self):
762        if self._file is None:
763            return
764        try:
765            self._ensure_header_written(0)
766            if self._datawritten & 1:
767                # quick pad to even size
768                self._file.write(b'\x00')
769                self._datawritten = self._datawritten + 1
770            self._writemarkers()
771            if self._nframeswritten != self._nframes or \
772                  self._datalength != self._datawritten or \
773                  self._marklength:
774                self._patchheader()
775        finally:
776            # Prevent ref cycles
777            self._convert = None
778            f = self._file
779            self._file = None
780            f.close()
781
782    #
783    # Internal methods.
784    #
785
786    def _lin2alaw(self, data):
787        with warnings.catch_warnings():
788            warnings.simplefilter('ignore', category=DeprecationWarning)
789            import audioop
790        return audioop.lin2alaw(data, 2)
791
792    def _lin2ulaw(self, data):
793        with warnings.catch_warnings():
794            warnings.simplefilter('ignore', category=DeprecationWarning)
795            import audioop
796        return audioop.lin2ulaw(data, 2)
797
798    def _lin2adpcm(self, data):
799        with warnings.catch_warnings():
800            warnings.simplefilter('ignore', category=DeprecationWarning)
801            import audioop
802        if not hasattr(self, '_adpcmstate'):
803            self._adpcmstate = None
804        data, self._adpcmstate = audioop.lin2adpcm(data, 2, self._adpcmstate)
805        return data
806
807    def _lin2sowt(self, data):
808        with warnings.catch_warnings():
809            warnings.simplefilter('ignore', category=DeprecationWarning)
810            import audioop
811        return audioop.byteswap(data, 2)
812
813    def _ensure_header_written(self, datasize):
814        if not self._nframeswritten:
815            if self._comptype in (b'ULAW', b'ulaw',
816                b'ALAW', b'alaw', b'G722',
817                b'sowt', b'SOWT'):
818                if not self._sampwidth:
819                    self._sampwidth = 2
820                if self._sampwidth != 2:
821                    raise Error('sample width must be 2 when compressing '
822                                'with ulaw/ULAW, alaw/ALAW, sowt/SOWT '
823                                'or G7.22 (ADPCM)')
824            if not self._nchannels:
825                raise Error('# channels not specified')
826            if not self._sampwidth:
827                raise Error('sample width not specified')
828            if not self._framerate:
829                raise Error('sampling rate not specified')
830            self._write_header(datasize)
831
832    def _init_compression(self):
833        if self._comptype == b'G722':
834            self._convert = self._lin2adpcm
835        elif self._comptype in (b'ulaw', b'ULAW'):
836            self._convert = self._lin2ulaw
837        elif self._comptype in (b'alaw', b'ALAW'):
838            self._convert = self._lin2alaw
839        elif self._comptype in (b'sowt', b'SOWT'):
840            self._convert = self._lin2sowt
841
    def _write_header(self, initlength):
        """Write the FORM, FVER (AIFF-C only), COMM and SSND chunk headers.

        *initlength* is the byte length of the first data written; if no
        frame count was declared via setnframes(), it is derived from it.
        The size fields written here may be stale; writeframes() and close()
        call _patchheader() to fix them later.  When the output file supports
        tell(), the offsets of the patchable fields are recorded in
        _form_length_pos, _nframes_pos and _ssnd_length_pos (presumably
        consumed by _patchheader() -- not visible here).
        """
        if self._aifc and self._comptype != b'NONE':
            self._init_compression()
        self._file.write(b'FORM')
        if not self._nframes:
            self._nframes = initlength // (self._nchannels * self._sampwidth)
        # Size of the uncompressed sample data, padded to an even length.
        self._datalength = self._nframes * self._nchannels * self._sampwidth
        if self._datalength & 1:
            self._datalength = self._datalength + 1
        if self._aifc:
            # Compressed sizes: u-law/A-law use one byte per 16-bit sample,
            # G722 (ADPCM) four bits per sample; both padded to even.
            if self._comptype in (b'ulaw', b'ULAW', b'alaw', b'ALAW'):
                self._datalength = self._datalength // 2
                if self._datalength & 1:
                    self._datalength = self._datalength + 1
            elif self._comptype == b'G722':
                self._datalength = (self._datalength + 3) // 4
                if self._datalength & 1:
                    self._datalength = self._datalength + 1
        # Offset of the FORM size field, or None for unseekable output.
        try:
            self._form_length_pos = self._file.tell()
        except (AttributeError, OSError):
            self._form_length_pos = None
        # Writes the FORM size field; the returned value is used as the
        # COMM chunk size below.
        commlength = self._write_form_length(self._datalength)
        if self._aifc:
            self._file.write(b'AIFC')
            self._file.write(b'FVER')
            _write_ulong(self._file, 4)
            _write_ulong(self._file, self._version)
        else:
            self._file.write(b'AIFF')
        self._file.write(b'COMM')
        _write_ulong(self._file, commlength)
        _write_short(self._file, self._nchannels)
        if self._form_length_pos is not None:
            self._nframes_pos = self._file.tell()
        _write_ulong(self._file, self._nframes)
        # Sample size in bits as stored in the file.
        # NOTE(review): in AIFF mode (self._aifc false) no encoder is
        # installed, yet a u-law/A-law/G722 comptype still writes 8 bits
        # here while the data stays 16-bit -- looks inconsistent; confirm.
        if self._comptype in (b'ULAW', b'ulaw', b'ALAW', b'alaw', b'G722'):
            _write_short(self._file, 8)
        else:
            _write_short(self._file, self._sampwidth * 8)
        _write_float(self._file, self._framerate)
        if self._aifc:
            self._file.write(self._comptype)
            _write_string(self._file, self._compname)
        self._file.write(b'SSND')
        if self._form_length_pos is not None:
            self._ssnd_length_pos = self._file.tell()
        # SSND size covers the offset and blocksize fields (8 bytes) plus
        # the sample data; both fields are written as 0.
        _write_ulong(self._file, self._datalength + 8)
        _write_ulong(self._file, 0)
        _write_ulong(self._file, 0)
892
893    def _write_form_length(self, datalength):
894        if self._aifc:
895            commlength = 18 + 5 + len(self._compname)
896            if commlength & 1:
897                commlength = commlength + 1
898            verslength = 12
899        else:
900            commlength = 18
901            verslength = 0
902        _write_ulong(self._file, 4 + verslength + self._marklength + \
903                     8 + commlength + 16 + datalength)
904        return commlength
905
906    def _patchheader(self):
907        curpos = self._file.tell()
908        if self._datawritten & 1:
909            datalength = self._datawritten + 1
910            self._file.write(b'\x00')
911        else:
912            datalength = self._datawritten
913        if datalength == self._datalength and \
914              self._nframes == self._nframeswritten and \
915              self._marklength == 0:
916            self._file.seek(curpos, 0)
917            return
918        self._file.seek(self._form_length_pos, 0)
919        dummy = self._write_form_length(datalength)
920        self._file.seek(self._nframes_pos, 0)
921        _write_ulong(self._file, self._nframeswritten)
922        self._file.seek(self._ssnd_length_pos, 0)
923        _write_ulong(self._file, datalength + 8)
924        self._file.seek(curpos, 0)
925        self._nframes = self._nframeswritten
926        self._datalength = datalength
927
928    def _writemarkers(self):
929        if len(self._markers) == 0:
930            return
931        self._file.write(b'MARK')
932        length = 2
933        for marker in self._markers:
934            id, pos, name = marker
935            length = length + len(name) + 1 + 6
936            if len(name) & 1 == 0:
937                length = length + 1
938        _write_ulong(self._file, length)
939        self._marklength = length + 8
940        _write_short(self._file, len(self._markers))
941        for marker in self._markers:
942            id, pos, name = marker
943            _write_short(self._file, id)
944            _write_ulong(self._file, pos)
945            _write_string(self._file, name)
946
def open(f, mode=None):
    """Open an AIFF or AIFF-C file for reading or writing.

    f is either a file name or an open file object.  When mode is None,
    f.mode is used if f has a 'mode' attribute, and 'rb' otherwise.
    'r'/'rb' returns an Aifc_read instance and 'w'/'wb' returns an
    Aifc_write instance.  Any other mode raises Error.
    """
    if mode is None:
        mode = getattr(f, 'mode', 'rb')
    if mode in ('r', 'rb'):
        return Aifc_read(f)
    if mode in ('w', 'wb'):
        return Aifc_write(f)
    raise Error("mode must be 'r', 'rb', 'w', or 'wb'")
959
960
if __name__ == '__main__':
    # Command-line demo: print the parameters of an AIFF/AIFF-C file and,
    # when a second path is given, copy its frames into a new file.
    # Note: `open` below is this module's open(), not the builtin.
    import sys
    if len(sys.argv) < 2:
        # No input path given: fall back to a hard-coded sample file.
        sys.argv.append('/usr/demos/data/audio/bach.aiff')
    fn = sys.argv[1]
    with open(fn, 'r') as f:
        print("Reading", fn)
        for label, getter in (("nchannels =", f.getnchannels),
                              ("nframes   =", f.getnframes),
                              ("sampwidth =", f.getsampwidth),
                              ("framerate =", f.getframerate),
                              ("comptype  =", f.getcomptype),
                              ("compname  =", f.getcompname)):
            print(label, getter())
        if len(sys.argv) > 2:
            gn = sys.argv[2]
            print("Writing", gn)
            with open(gn, 'w') as g:
                g.setparams(f.getparams())
                while True:
                    data = f.readframes(1024)
                    if not data:
                        break
                    g.writeframes(data)
            print("Done.")
985