1// Copyright 2011 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5// Package lzw implements the Lempel-Ziv-Welch compressed data format,
6// described in T. A. Welch, “A Technique for High-Performance Data
7// Compression”, Computer, 17(6) (June 1984), pp 8-19.
8//
9// In particular, it implements LZW as used by the GIF and PDF file
10// formats, which means variable-width codes up to 12 bits and the first
11// two non-literal codes are a clear code and an EOF code.
12//
13// The TIFF file format uses a similar but incompatible version of the LZW
14// algorithm. See the golang.org/x/image/tiff/lzw package for an
15// implementation.
16package lzw
17
18// TODO(nigeltao): check that PDF uses LZW in the same way as GIF,
19// modulo LSB/MSB packing order.
20
21import (
22	"bufio"
23	"errors"
24	"fmt"
25	"io"
26)
27
// Order specifies the bit ordering in an LZW data stream, that is, whether
// codes are unpacked starting from the low-order or the high-order bits of
// the input bytes.
type Order int

const (
	// LSB means Least Significant Bits first, as used in the GIF file format.
	// It is the zero value of Order.
	LSB Order = iota
	// MSB means Most Significant Bits first, as used in the TIFF and PDF
	// file formats.
	MSB
)
38
// maxWidth is the largest code width, in bits, that the decoder grows to;
// it also sizes the prefix/suffix tables and the output buffer.
//
// decoderInvalidCode is a sentinel stored in Reader.last to mean "there is
// no previous code", for example right after a clear code.
//
// flushBuffer is the number of pending output bytes at which decode stops
// and hands the buffered output over to Read.
const (
	maxWidth           = 12
	decoderInvalidCode = 0xffff
	flushBuffer        = 1 << maxWidth
)
44
// Reader is an io.Reader which can be used to read compressed data in the
// LZW format.
type Reader struct {
	// r is the source of compressed bytes. bits buffers input bits that have
	// been fetched from r but not yet consumed as codes, and nBits counts how
	// many of those bits are valid. width is the current code width in bits.
	// err is the sticky error that Read returns once toRead has been drained.
	r        io.ByteReader
	bits     uint32
	nBits    uint
	width    uint
	read     func(*Reader) (uint16, error) // readLSB or readMSB
	litWidth int                           // width in bits of literal codes
	err      error

	// The first 1<<litWidth codes are literal codes.
	// The next two codes mean clear and EOF.
	// Other valid codes are in the range [lo, hi] where lo := clear + 2,
	// with the upper bound incrementing on each code seen.
	//
	// overflow is the code at which hi overflows the code width. It always
	// equals 1 << width.
	//
	// last is the most recently seen code, or decoderInvalidCode.
	//
	// An invariant is that hi < overflow.
	clear, eof, hi, overflow, last uint16

	// Each code c in [lo, hi] expands to two or more bytes. For c != hi:
	//   suffix[c] is the last of these bytes.
	//   prefix[c] is the code for all but the last byte.
	//   This code can either be a literal code or another code in [lo, c).
	// The c == hi case is a special case.
	suffix [1 << maxWidth]uint8
	prefix [1 << maxWidth]uint16

	// output is the temporary output buffer.
	// Literal codes are accumulated from the start of the buffer.
	// Non-literal codes decode to a sequence of suffixes that are first
	// written right-to-left from the end of the buffer before being copied
	// to the start of the buffer.
	// It is flushed when it contains >= 1<<maxWidth bytes,
	// so that there is always room to decode an entire code.
	output [2 * 1 << maxWidth]byte
	o      int    // write index into output
	toRead []byte // bytes to return from Read
}
88
89// readLSB returns the next code for "Least Significant Bits first" data.
90func (r *Reader) readLSB() (uint16, error) {
91	for r.nBits < r.width {
92		x, err := r.r.ReadByte()
93		if err != nil {
94			return 0, err
95		}
96		r.bits |= uint32(x) << r.nBits
97		r.nBits += 8
98	}
99	code := uint16(r.bits & (1<<r.width - 1))
100	r.bits >>= r.width
101	r.nBits -= r.width
102	return code, nil
103}
104
105// readMSB returns the next code for "Most Significant Bits first" data.
106func (r *Reader) readMSB() (uint16, error) {
107	for r.nBits < r.width {
108		x, err := r.r.ReadByte()
109		if err != nil {
110			return 0, err
111		}
112		r.bits |= uint32(x) << (24 - r.nBits)
113		r.nBits += 8
114	}
115	code := uint16(r.bits >> (32 - r.width))
116	r.bits <<= r.width
117	r.nBits -= r.width
118	return code, nil
119}
120
121// Read implements io.Reader, reading uncompressed bytes from its underlying [Reader].
122func (r *Reader) Read(b []byte) (int, error) {
123	for {
124		if len(r.toRead) > 0 {
125			n := copy(b, r.toRead)
126			r.toRead = r.toRead[n:]
127			return n, nil
128		}
129		if r.err != nil {
130			return 0, r.err
131		}
132		r.decode()
133	}
134}
135
// decode decompresses codes obtained through r.read and leaves the resulting
// bytes in r.toRead.
//
// It keeps decoding until at least flushBuffer bytes are pending, an EOF code
// is seen, or an error occurs. On termination r.err is set accordingly
// (io.EOF for a clean end of stream, io.ErrUnexpectedEOF if the input ends
// before an EOF code), and whatever output was produced so far is still
// made available through r.toRead.
func (r *Reader) decode() {
	// Loop over the code stream, converting codes into decompressed bytes.
loop:
	for {
		code, err := r.read(r)
		if err != nil {
			// Running out of input before an explicit EOF code means the
			// stream was truncated.
			if err == io.EOF {
				err = io.ErrUnexpectedEOF
			}
			r.err = err
			break
		}
		switch {
		case code < r.clear:
			// We have a literal code.
			r.output[r.o] = uint8(code)
			r.o++
			if r.last != decoderInvalidCode {
				// Save what the hi code expands to.
				r.suffix[r.hi] = uint8(code)
				r.prefix[r.hi] = r.last
			}
		case code == r.clear:
			// Reset the dictionary state to what init established and skip
			// the bookkeeping at the bottom of the loop.
			r.width = 1 + uint(r.litWidth)
			r.hi = r.eof
			r.overflow = 1 << r.width
			r.last = decoderInvalidCode
			continue
		case code == r.eof:
			r.err = io.EOF
			break loop
		case code <= r.hi:
			// The expansion is built right-to-left starting at the last byte
			// of output.
			c, i := code, len(r.output)-1
			if code == r.hi && r.last != decoderInvalidCode {
				// code == hi is a special case which expands to the last expansion
				// followed by the head of the last expansion. To find the head, we walk
				// the prefix chain until we find a literal code.
				c = r.last
				for c >= r.clear {
					c = r.prefix[c]
				}
				r.output[i] = uint8(c)
				i--
				c = r.last
			}
			// Copy the suffix chain into the tail of output, then move the
			// expansion down to the write index r.o.
			for c >= r.clear {
				r.output[i] = r.suffix[c]
				i--
				c = r.prefix[c]
			}
			r.output[i] = uint8(c)
			r.o += copy(r.output[r.o:], r.output[i:])
			if r.last != decoderInvalidCode {
				// Save what the hi code expands to.
				r.suffix[r.hi] = uint8(c)
				r.prefix[r.hi] = r.last
			}
		default:
			r.err = errors.New("lzw: invalid code")
			break loop
		}
		r.last, r.hi = code, r.hi+1
		if r.hi >= r.overflow {
			if r.hi > r.overflow {
				panic("unreachable")
			}
			if r.width == maxWidth {
				r.last = decoderInvalidCode
				// Undo the r.hi++ a few lines above, so that (1) we maintain
				// the invariant that r.hi < r.overflow, and (2) r.hi does not
				// eventually overflow a uint16.
				r.hi--
			} else {
				r.width++
				r.overflow = 1 << r.width
			}
		}
		// Stop once enough output is pending so Read can drain it.
		if r.o >= flushBuffer {
			break
		}
	}
	// Flush pending output.
	r.toRead = r.output[:r.o]
	r.o = 0
}
225
226var errClosed = errors.New("lzw: reader/writer is closed")
227
228// Close closes the [Reader] and returns an error for any future read operation.
229// It does not close the underlying [io.Reader].
230func (r *Reader) Close() error {
231	r.err = errClosed // in case any Reads come along
232	return nil
233}
234
// Reset clears the [Reader]'s state and allows it to be reused again
// as a new [Reader].
//
// All previous state, including any pending output and any earlier error,
// is discarded before the Reader is re-initialized with src, order and
// litWidth. An invalid order or litWidth is not reported here; it is
// recorded and returned by the next Read.
func (r *Reader) Reset(src io.Reader, order Order, litWidth int) {
	*r = Reader{}
	r.init(src, order, litWidth)
}
241
// NewReader creates a new [io.ReadCloser].
// Reads from the returned [io.ReadCloser] read and decompress data from r.
// If r does not also implement [io.ByteReader],
// the decompressor may read more data than necessary from r.
// It is the caller's responsibility to call Close on the ReadCloser when
// finished reading.
// The number of bits to use for literal codes, litWidth, must be in the
// range [2,8] and is typically 8. It must equal the litWidth
// used during compression.
//
// NewReader does not return an error: an unknown order or an out-of-range
// litWidth is recorded and reported by the first call to Read.
//
// It is guaranteed that the underlying type of the returned [io.ReadCloser]
// is a *[Reader].
func NewReader(r io.Reader, order Order, litWidth int) io.ReadCloser {
	return newReader(r, order, litWidth)
}
257
258func newReader(src io.Reader, order Order, litWidth int) *Reader {
259	r := new(Reader)
260	r.init(src, order, litWidth)
261	return r
262}
263
264func (r *Reader) init(src io.Reader, order Order, litWidth int) {
265	switch order {
266	case LSB:
267		r.read = (*Reader).readLSB
268	case MSB:
269		r.read = (*Reader).readMSB
270	default:
271		r.err = errors.New("lzw: unknown order")
272		return
273	}
274	if litWidth < 2 || 8 < litWidth {
275		r.err = fmt.Errorf("lzw: litWidth %d out of range", litWidth)
276		return
277	}
278
279	br, ok := src.(io.ByteReader)
280	if !ok && src != nil {
281		br = bufio.NewReader(src)
282	}
283	r.r = br
284	r.litWidth = litWidth
285	r.width = 1 + uint(litWidth)
286	r.clear = uint16(1) << uint(litWidth)
287	r.eof, r.hi = r.clear+1, r.clear+1
288	r.overflow = uint16(1) << r.width
289	r.last = decoderInvalidCode
290}
291