1// Copyright 2010 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5package mime
6
7import (
8	"errors"
9	"fmt"
10	"slices"
11	"strings"
12	"unicode"
13)
14
15// FormatMediaType serializes mediatype t and the parameters
16// param as a media type conforming to RFC 2045 and RFC 2616.
17// The type and parameter names are written in lower-case.
18// When any of the arguments result in a standard violation then
19// FormatMediaType returns the empty string.
20func FormatMediaType(t string, param map[string]string) string {
21	var b strings.Builder
22	if major, sub, ok := strings.Cut(t, "/"); !ok {
23		if !isToken(t) {
24			return ""
25		}
26		b.WriteString(strings.ToLower(t))
27	} else {
28		if !isToken(major) || !isToken(sub) {
29			return ""
30		}
31		b.WriteString(strings.ToLower(major))
32		b.WriteByte('/')
33		b.WriteString(strings.ToLower(sub))
34	}
35
36	attrs := make([]string, 0, len(param))
37	for a := range param {
38		attrs = append(attrs, a)
39	}
40	slices.Sort(attrs)
41
42	for _, attribute := range attrs {
43		value := param[attribute]
44		b.WriteByte(';')
45		b.WriteByte(' ')
46		if !isToken(attribute) {
47			return ""
48		}
49		b.WriteString(strings.ToLower(attribute))
50
51		needEnc := needsEncoding(value)
52		if needEnc {
53			// RFC 2231 section 4
54			b.WriteByte('*')
55		}
56		b.WriteByte('=')
57
58		if needEnc {
59			b.WriteString("utf-8''")
60
61			offset := 0
62			for index := 0; index < len(value); index++ {
63				ch := value[index]
64				// {RFC 2231 section 7}
65				// attribute-char := <any (US-ASCII) CHAR except SPACE, CTLs, "*", "'", "%", or tspecials>
66				if ch <= ' ' || ch >= 0x7F ||
67					ch == '*' || ch == '\'' || ch == '%' ||
68					isTSpecial(rune(ch)) {
69
70					b.WriteString(value[offset:index])
71					offset = index + 1
72
73					b.WriteByte('%')
74					b.WriteByte(upperhex[ch>>4])
75					b.WriteByte(upperhex[ch&0x0F])
76				}
77			}
78			b.WriteString(value[offset:])
79			continue
80		}
81
82		if isToken(value) {
83			b.WriteString(value)
84			continue
85		}
86
87		b.WriteByte('"')
88		offset := 0
89		for index := 0; index < len(value); index++ {
90			character := value[index]
91			if character == '"' || character == '\\' {
92				b.WriteString(value[offset:index])
93				offset = index
94				b.WriteByte('\\')
95			}
96		}
97		b.WriteString(value[offset:])
98		b.WriteByte('"')
99	}
100	return b.String()
101}
102
103func checkMediaTypeDisposition(s string) error {
104	typ, rest := consumeToken(s)
105	if typ == "" {
106		return errors.New("mime: no media type")
107	}
108	if rest == "" {
109		return nil
110	}
111	if !strings.HasPrefix(rest, "/") {
112		return errors.New("mime: expected slash after first token")
113	}
114	subtype, rest := consumeToken(rest[1:])
115	if subtype == "" {
116		return errors.New("mime: expected token after slash")
117	}
118	if rest != "" {
119		return errors.New("mime: unexpected content after media subtype")
120	}
121	return nil
122}
123
// ErrInvalidMediaParameter is returned by [ParseMediaType] if
// the media type value was found but there was an error parsing
// the optional parameters.
var ErrInvalidMediaParameter = errors.New("mime: invalid media parameter")
128
129// ParseMediaType parses a media type value and any optional
130// parameters, per RFC 1521.  Media types are the values in
131// Content-Type and Content-Disposition headers (RFC 2183).
132// On success, ParseMediaType returns the media type converted
133// to lowercase and trimmed of white space and a non-nil map.
134// If there is an error parsing the optional parameter,
135// the media type will be returned along with the error
136// [ErrInvalidMediaParameter].
137// The returned map, params, maps from the lowercase
138// attribute to the attribute value with its case preserved.
139func ParseMediaType(v string) (mediatype string, params map[string]string, err error) {
140	base, _, _ := strings.Cut(v, ";")
141	mediatype = strings.TrimSpace(strings.ToLower(base))
142
143	err = checkMediaTypeDisposition(mediatype)
144	if err != nil {
145		return "", nil, err
146	}
147
148	params = make(map[string]string)
149
150	// Map of base parameter name -> parameter name -> value
151	// for parameters containing a '*' character.
152	// Lazily initialized.
153	var continuation map[string]map[string]string
154
155	v = v[len(base):]
156	for len(v) > 0 {
157		v = strings.TrimLeftFunc(v, unicode.IsSpace)
158		if len(v) == 0 {
159			break
160		}
161		key, value, rest := consumeMediaParam(v)
162		if key == "" {
163			if strings.TrimSpace(rest) == ";" {
164				// Ignore trailing semicolons.
165				// Not an error.
166				break
167			}
168			// Parse error.
169			return mediatype, nil, ErrInvalidMediaParameter
170		}
171
172		pmap := params
173		if baseName, _, ok := strings.Cut(key, "*"); ok {
174			if continuation == nil {
175				continuation = make(map[string]map[string]string)
176			}
177			var ok bool
178			if pmap, ok = continuation[baseName]; !ok {
179				continuation[baseName] = make(map[string]string)
180				pmap = continuation[baseName]
181			}
182		}
183		if v, exists := pmap[key]; exists && v != value {
184			// Duplicate parameter names are incorrect, but we allow them if they are equal.
185			return "", nil, errors.New("mime: duplicate parameter name")
186		}
187		pmap[key] = value
188		v = rest
189	}
190
191	// Stitch together any continuations or things with stars
192	// (i.e. RFC 2231 things with stars: "foo*0" or "foo*")
193	var buf strings.Builder
194	for key, pieceMap := range continuation {
195		singlePartKey := key + "*"
196		if v, ok := pieceMap[singlePartKey]; ok {
197			if decv, ok := decode2231Enc(v); ok {
198				params[key] = decv
199			}
200			continue
201		}
202
203		buf.Reset()
204		valid := false
205		for n := 0; ; n++ {
206			simplePart := fmt.Sprintf("%s*%d", key, n)
207			if v, ok := pieceMap[simplePart]; ok {
208				valid = true
209				buf.WriteString(v)
210				continue
211			}
212			encodedPart := simplePart + "*"
213			v, ok := pieceMap[encodedPart]
214			if !ok {
215				break
216			}
217			valid = true
218			if n == 0 {
219				if decv, ok := decode2231Enc(v); ok {
220					buf.WriteString(decv)
221				}
222			} else {
223				decv, _ := percentHexUnescape(v)
224				buf.WriteString(decv)
225			}
226		}
227		if valid {
228			params[key] = buf.String()
229		}
230	}
231
232	return
233}
234
235func decode2231Enc(v string) (string, bool) {
236	sv := strings.SplitN(v, "'", 3)
237	if len(sv) != 3 {
238		return "", false
239	}
240	// TODO: ignoring lang in sv[1] for now. If anybody needs it we'll
241	// need to decide how to expose it in the API. But I'm not sure
242	// anybody uses it in practice.
243	charset := strings.ToLower(sv[0])
244	if len(charset) == 0 {
245		return "", false
246	}
247	if charset != "us-ascii" && charset != "utf-8" {
248		// TODO: unsupported encoding
249		return "", false
250	}
251	encv, err := percentHexUnescape(sv[2])
252	if err != nil {
253		return "", false
254	}
255	return encv, true
256}
257
// isNotTokenChar reports whether r is rejected by isTokenChar.
// It is the negated predicate that consumeToken hands to
// strings.IndexFunc to locate the end of a token.
func isNotTokenChar(r rune) bool {
	return !isTokenChar(r)
}
261
262// consumeToken consumes a token from the beginning of provided
263// string, per RFC 2045 section 5.1 (referenced from 2183), and return
264// the token consumed and the rest of the string. Returns ("", v) on
265// failure to consume at least one character.
266func consumeToken(v string) (token, rest string) {
267	notPos := strings.IndexFunc(v, isNotTokenChar)
268	if notPos == -1 {
269		return v, ""
270	}
271	if notPos == 0 {
272		return "", v
273	}
274	return v[0:notPos], v[notPos:]
275}
276
277// consumeValue consumes a "value" per RFC 2045, where a value is
278// either a 'token' or a 'quoted-string'.  On success, consumeValue
279// returns the value consumed (and de-quoted/escaped, if a
280// quoted-string) and the rest of the string. On failure, returns
281// ("", v).
282func consumeValue(v string) (value, rest string) {
283	if v == "" {
284		return
285	}
286	if v[0] != '"' {
287		return consumeToken(v)
288	}
289
290	// parse a quoted-string
291	buffer := new(strings.Builder)
292	for i := 1; i < len(v); i++ {
293		r := v[i]
294		if r == '"' {
295			return buffer.String(), v[i+1:]
296		}
297		// When MSIE sends a full file path (in "intranet mode"), it does not
298		// escape backslashes: "C:\dev\go\foo.txt", not "C:\\dev\\go\\foo.txt".
299		//
300		// No known MIME generators emit unnecessary backslash escapes
301		// for simple token characters like numbers and letters.
302		//
303		// If we see an unnecessary backslash escape, assume it is from MSIE
304		// and intended as a literal backslash. This makes Go servers deal better
305		// with MSIE without affecting the way they handle conforming MIME
306		// generators.
307		if r == '\\' && i+1 < len(v) && isTSpecial(rune(v[i+1])) {
308			buffer.WriteByte(v[i+1])
309			i++
310			continue
311		}
312		if r == '\r' || r == '\n' {
313			return "", v
314		}
315		buffer.WriteByte(v[i])
316	}
317	// Did not find end quote.
318	return "", v
319}
320
321func consumeMediaParam(v string) (param, value, rest string) {
322	rest = strings.TrimLeftFunc(v, unicode.IsSpace)
323	if !strings.HasPrefix(rest, ";") {
324		return "", "", v
325	}
326
327	rest = rest[1:] // consume semicolon
328	rest = strings.TrimLeftFunc(rest, unicode.IsSpace)
329	param, rest = consumeToken(rest)
330	param = strings.ToLower(param)
331	if param == "" {
332		return "", "", v
333	}
334
335	rest = strings.TrimLeftFunc(rest, unicode.IsSpace)
336	if !strings.HasPrefix(rest, "=") {
337		return "", "", v
338	}
339	rest = rest[1:] // consume equals sign
340	rest = strings.TrimLeftFunc(rest, unicode.IsSpace)
341	value, rest2 := consumeValue(rest)
342	if value == "" && rest2 == rest {
343		return "", "", v
344	}
345	rest = rest2
346	return param, value, rest
347}
348
349func percentHexUnescape(s string) (string, error) {
350	// Count %, check that they're well-formed.
351	percents := 0
352	for i := 0; i < len(s); {
353		if s[i] != '%' {
354			i++
355			continue
356		}
357		percents++
358		if i+2 >= len(s) || !ishex(s[i+1]) || !ishex(s[i+2]) {
359			s = s[i:]
360			if len(s) > 3 {
361				s = s[0:3]
362			}
363			return "", fmt.Errorf("mime: bogus characters after %%: %q", s)
364		}
365		i += 3
366	}
367	if percents == 0 {
368		return s, nil
369	}
370
371	t := make([]byte, len(s)-2*percents)
372	j := 0
373	for i := 0; i < len(s); {
374		switch s[i] {
375		case '%':
376			t[j] = unhex(s[i+1])<<4 | unhex(s[i+2])
377			j++
378			i += 3
379		default:
380			t[j] = s[i]
381			j++
382			i++
383		}
384	}
385	return string(t), nil
386}
387
// ishex reports whether c is an ASCII hexadecimal digit:
// '0'-'9', 'a'-'f' or 'A'-'F'.
func ishex(c byte) bool {
	return ('0' <= c && c <= '9') ||
		('a' <= c && c <= 'f') ||
		('A' <= c && c <= 'F')
}
399
// unhex returns the numeric value (0-15) of the ASCII hexadecimal digit
// c, accepting both letter cases. Any other byte yields 0.
func unhex(c byte) byte {
	// Setting bit 0x20 maps 'A'-'F' onto 'a'-'f' and maps no other
	// byte into that range, so one comparison covers both letter cases.
	switch lower := c | 0x20; {
	case '0' <= c && c <= '9':
		return c - '0'
	case 'a' <= lower && lower <= 'f':
		return lower - 'a' + 10
	}
	return 0
}
411