1// Copyright 2009 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5package asn1
6
7import (
8	"bytes"
9	"errors"
10	"fmt"
11	"math/big"
12	"reflect"
13	"slices"
14	"time"
15	"unicode/utf8"
16)
17
// Shared encoders for the single-octet contents used throughout this file.
var (
	// byte00Encoder emits 0x00: boolean false, the INTEGER zero, and the
	// leading pad octet that keeps a positive big.Int from reading as negative.
	byte00Encoder encoder = byteEncoder(0x00)
	// byteFFEncoder emits 0xff: boolean true, and the leading pad octet that
	// keeps a negative big.Int negative.
	byteFFEncoder encoder = byteEncoder(0xff)
)
22
// encoder represents an ASN.1 element that is waiting to be marshaled.
//
// Encoding happens in two phases. First the final size is computed with
// Len. Then a buffer of that size is allocated and Encode fills it in (see
// MarshalWithParams). Composite encoders such as multiEncoder and
// taggedEncoder use their children's Len to place each child at the right
// offset.
type encoder interface {
	// Len returns the number of bytes needed to marshal this element.
	Len() int
	// Encode encodes this element by writing Len() bytes to dst.
	// dst must have room for at least Len() bytes, and writing starts at
	// dst[0]. With a shorter dst, some implementations panic and others
	// may silently lose output.
	Encode(dst []byte)
}
30
// byteEncoder is an encoder for exactly one literal octet.
type byteEncoder byte

// Len always reports 1, because a byteEncoder occupies a single octet.
func (v byteEncoder) Len() int { return 1 }

// Encode stores the octet at dst[0]. dst must be non-empty.
func (v byteEncoder) Encode(dst []byte) {
	octet := byte(v)
	dst[0] = octet
}
40
// bytesEncoder emits a fixed byte slice verbatim.
type bytesEncoder []byte

// Len reports how many bytes the slice holds.
func (b bytesEncoder) Len() int { return len(b) }

// Encode copies the slice to the front of dst. It panics if dst is too short
// to receive every byte, since that means Len and the buffer disagree.
func (b bytesEncoder) Encode(dst []byte) {
	written := copy(dst, b)
	if written < len(b) {
		panic("internal error")
	}
}
52
// stringEncoder emits the raw bytes of a Go string with no conversion.
type stringEncoder string

// Len reports the length of the string in bytes (not runes).
func (s stringEncoder) Len() int {
	return len(string(s))
}

// Encode copies the string's bytes to the front of dst. It panics if dst is
// too short to receive all of them.
func (s stringEncoder) Encode(dst []byte) {
	if written := copy(dst, string(s)); written < len(s) {
		panic("internal error")
	}
}
64
65type multiEncoder []encoder
66
67func (m multiEncoder) Len() int {
68	var size int
69	for _, e := range m {
70		size += e.Len()
71	}
72	return size
73}
74
75func (m multiEncoder) Encode(dst []byte) {
76	var off int
77	for _, e := range m {
78		e.Encode(dst[off:])
79		off += e.Len()
80	}
81}
82
83type setEncoder []encoder
84
85func (s setEncoder) Len() int {
86	var size int
87	for _, e := range s {
88		size += e.Len()
89	}
90	return size
91}
92
93func (s setEncoder) Encode(dst []byte) {
94	// Per X690 Section 11.6: The encodings of the component values of a
95	// set-of value shall appear in ascending order, the encodings being
96	// compared as octet strings with the shorter components being padded
97	// at their trailing end with 0-octets.
98	//
99	// First we encode each element to its TLV encoding and then use
100	// octetSort to get the ordering expected by X690 DER rules before
101	// writing the sorted encodings out to dst.
102	l := make([][]byte, len(s))
103	for i, e := range s {
104		l[i] = make([]byte, e.Len())
105		e.Encode(l[i])
106	}
107
108	// Since we are using bytes.Compare to compare TLV encodings we
109	// don't need to right pad s[i] and s[j] to the same length as
110	// suggested in X690. If len(s[i]) < len(s[j]) the length octet of
111	// s[i], which is the first determining byte, will inherently be
112	// smaller than the length octet of s[j]. This lets us skip the
113	// padding step.
114	slices.SortFunc(l, bytes.Compare)
115
116	var off int
117	for _, b := range l {
118		copy(dst[off:], b)
119		off += len(b)
120	}
121}
122
123type taggedEncoder struct {
124	// scratch contains temporary space for encoding the tag and length of
125	// an element in order to avoid extra allocations.
126	scratch [8]byte
127	tag     encoder
128	body    encoder
129}
130
131func (t *taggedEncoder) Len() int {
132	return t.tag.Len() + t.body.Len()
133}
134
135func (t *taggedEncoder) Encode(dst []byte) {
136	t.tag.Encode(dst)
137	t.body.Encode(dst[t.tag.Len():])
138}
139
// int64Encoder encodes a signed integer as the contents octets of an
// INTEGER: big-endian two's complement, using the fewest bytes that still
// preserve the sign.
type int64Encoder int64

// Len returns the number of contents octets. It starts at one and adds one
// for every 8-bit arithmetic shift needed to bring the value into the
// single-byte range [-128, 127].
func (i int64Encoder) Len() int {
	size := 1
	for v := int64(i); v > 127 || v < -128; v >>= 8 {
		size++
	}
	return size
}

// Encode writes the low Len() bytes of the value to dst, most significant
// byte first.
func (i int64Encoder) Encode(dst []byte) {
	n := i.Len()
	for pos, shift := 0, (n-1)*8; pos < n; pos, shift = pos+1, shift-8 {
		dst[pos] = byte(int64(i) >> uint(shift))
	}
}
165
// base128IntLength returns how many 7-bit groups are needed to write n in
// base-128 form. Zero still takes one group. A negative n yields 0.
func base128IntLength(n int64) int {
	if n == 0 {
		return 1
	}

	groups := 0
	for n > 0 {
		groups++
		n >>= 7
	}
	return groups
}
178
179func appendBase128Int(dst []byte, n int64) []byte {
180	l := base128IntLength(n)
181
182	for i := l - 1; i >= 0; i-- {
183		o := byte(n >> uint(i*7))
184		o &= 0x7f
185		if i != 0 {
186			o |= 0x80
187		}
188
189		dst = append(dst, o)
190	}
191
192	return dst
193}
194
195func makeBigInt(n *big.Int) (encoder, error) {
196	if n == nil {
197		return nil, StructuralError{"empty integer"}
198	}
199
200	if n.Sign() < 0 {
201		// A negative number has to be converted to two's-complement
202		// form. So we'll invert and subtract 1. If the
203		// most-significant-bit isn't set then we'll need to pad the
204		// beginning with 0xff in order to keep the number negative.
205		nMinus1 := new(big.Int).Neg(n)
206		nMinus1.Sub(nMinus1, bigOne)
207		bytes := nMinus1.Bytes()
208		for i := range bytes {
209			bytes[i] ^= 0xff
210		}
211		if len(bytes) == 0 || bytes[0]&0x80 == 0 {
212			return multiEncoder([]encoder{byteFFEncoder, bytesEncoder(bytes)}), nil
213		}
214		return bytesEncoder(bytes), nil
215	} else if n.Sign() == 0 {
216		// Zero is written as a single 0 zero rather than no bytes.
217		return byte00Encoder, nil
218	} else {
219		bytes := n.Bytes()
220		if len(bytes) > 0 && bytes[0]&0x80 != 0 {
221			// We'll have to pad this with 0x00 in order to stop it
222			// looking like a negative number.
223			return multiEncoder([]encoder{byte00Encoder, bytesEncoder(bytes)}), nil
224		}
225		return bytesEncoder(bytes), nil
226	}
227}
228
229func appendLength(dst []byte, i int) []byte {
230	n := lengthLength(i)
231
232	for ; n > 0; n-- {
233		dst = append(dst, byte(i>>uint((n-1)*8)))
234	}
235
236	return dst
237}
238
// lengthLength returns how many big-endian bytes are needed to hold i.
// The result is always at least 1, including for zero and negative i.
func lengthLength(i int) (numBytes int) {
	for numBytes = 1; i > 255; numBytes++ {
		i >>= 8
	}
	return numBytes
}
247
248func appendTagAndLength(dst []byte, t tagAndLength) []byte {
249	b := uint8(t.class) << 6
250	if t.isCompound {
251		b |= 0x20
252	}
253	if t.tag >= 31 {
254		b |= 0x1f
255		dst = append(dst, b)
256		dst = appendBase128Int(dst, int64(t.tag))
257	} else {
258		b |= uint8(t.tag)
259		dst = append(dst, b)
260	}
261
262	if t.length >= 128 {
263		l := lengthLength(t.length)
264		dst = append(dst, 0x80|byte(l))
265		dst = appendLength(dst, t.length)
266	} else {
267		dst = append(dst, byte(t.length))
268	}
269
270	return dst
271}
272
273type bitStringEncoder BitString
274
275func (b bitStringEncoder) Len() int {
276	return len(b.Bytes) + 1
277}
278
279func (b bitStringEncoder) Encode(dst []byte) {
280	dst[0] = byte((8 - b.BitLength%8) % 8)
281	if copy(dst[1:], b.Bytes) != len(b.Bytes) {
282		panic("internal error")
283	}
284}
285
286type oidEncoder []int
287
288func (oid oidEncoder) Len() int {
289	l := base128IntLength(int64(oid[0]*40 + oid[1]))
290	for i := 2; i < len(oid); i++ {
291		l += base128IntLength(int64(oid[i]))
292	}
293	return l
294}
295
296func (oid oidEncoder) Encode(dst []byte) {
297	dst = appendBase128Int(dst[:0], int64(oid[0]*40+oid[1]))
298	for i := 2; i < len(oid); i++ {
299		dst = appendBase128Int(dst, int64(oid[i]))
300	}
301}
302
303func makeObjectIdentifier(oid []int) (e encoder, err error) {
304	if len(oid) < 2 || oid[0] > 2 || (oid[0] < 2 && oid[1] >= 40) {
305		return nil, StructuralError{"invalid object identifier"}
306	}
307
308	return oidEncoder(oid), nil
309}
310
311func makePrintableString(s string) (e encoder, err error) {
312	for i := 0; i < len(s); i++ {
313		// The asterisk is often used in PrintableString, even though
314		// it is invalid. If a PrintableString was specifically
315		// requested then the asterisk is permitted by this code.
316		// Ampersand is allowed in parsing due a handful of CA
317		// certificates, however when making new certificates
318		// it is rejected.
319		if !isPrintable(s[i], allowAsterisk, rejectAmpersand) {
320			return nil, StructuralError{"PrintableString contains invalid character"}
321		}
322	}
323
324	return stringEncoder(s), nil
325}
326
327func makeIA5String(s string) (e encoder, err error) {
328	for i := 0; i < len(s); i++ {
329		if s[i] > 127 {
330			return nil, StructuralError{"IA5String contains invalid character"}
331		}
332	}
333
334	return stringEncoder(s), nil
335}
336
337func makeNumericString(s string) (e encoder, err error) {
338	for i := 0; i < len(s); i++ {
339		if !isNumeric(s[i]) {
340			return nil, StructuralError{"NumericString contains invalid character"}
341		}
342	}
343
344	return stringEncoder(s), nil
345}
346
347func makeUTF8String(s string) encoder {
348	return stringEncoder(s)
349}
350
// appendTwoDigits appends the last two decimal digits of v, zero-padded.
// For example, 7 becomes "07" and 123 becomes "23".
func appendTwoDigits(dst []byte, v int) []byte {
	tens, ones := (v/10)%10, v%10
	return append(dst, byte('0'+tens), byte('0'+ones))
}
354
// appendFourDigits appends the last four decimal digits of v, zero-padded.
// For example, 5 becomes "0005" and 12345 becomes "2345".
func appendFourDigits(dst []byte, v int) []byte {
	var digits [4]byte
	for i := len(digits) - 1; i >= 0; i-- {
		digits[i] = byte('0' + v%10)
		v /= 10
	}
	return append(dst, digits[:]...)
}
362
// outsideUTCRange reports whether t's year falls outside 1950-2049. That is
// the only range a UTCTime's two-digit year can represent.
func outsideUTCRange(t time.Time) bool {
	y := t.Year()
	representable := y >= 1950 && y < 2050
	return !representable
}
367
368func makeUTCTime(t time.Time) (e encoder, err error) {
369	dst := make([]byte, 0, 18)
370
371	dst, err = appendUTCTime(dst, t)
372	if err != nil {
373		return nil, err
374	}
375
376	return bytesEncoder(dst), nil
377}
378
379func makeGeneralizedTime(t time.Time) (e encoder, err error) {
380	dst := make([]byte, 0, 20)
381
382	dst, err = appendGeneralizedTime(dst, t)
383	if err != nil {
384		return nil, err
385	}
386
387	return bytesEncoder(dst), nil
388}
389
390func appendUTCTime(dst []byte, t time.Time) (ret []byte, err error) {
391	year := t.Year()
392
393	switch {
394	case 1950 <= year && year < 2000:
395		dst = appendTwoDigits(dst, year-1900)
396	case 2000 <= year && year < 2050:
397		dst = appendTwoDigits(dst, year-2000)
398	default:
399		return nil, StructuralError{"cannot represent time as UTCTime"}
400	}
401
402	return appendTimeCommon(dst, t), nil
403}
404
405func appendGeneralizedTime(dst []byte, t time.Time) (ret []byte, err error) {
406	year := t.Year()
407	if year < 0 || year > 9999 {
408		return nil, StructuralError{"cannot represent time as GeneralizedTime"}
409	}
410
411	dst = appendFourDigits(dst, year)
412
413	return appendTimeCommon(dst, t), nil
414}
415
416func appendTimeCommon(dst []byte, t time.Time) []byte {
417	_, month, day := t.Date()
418
419	dst = appendTwoDigits(dst, int(month))
420	dst = appendTwoDigits(dst, day)
421
422	hour, min, sec := t.Clock()
423
424	dst = appendTwoDigits(dst, hour)
425	dst = appendTwoDigits(dst, min)
426	dst = appendTwoDigits(dst, sec)
427
428	_, offset := t.Zone()
429
430	switch {
431	case offset/60 == 0:
432		return append(dst, 'Z')
433	case offset > 0:
434		dst = append(dst, '+')
435	case offset < 0:
436		dst = append(dst, '-')
437	}
438
439	offsetMinutes := offset / 60
440	if offsetMinutes < 0 {
441		offsetMinutes = -offsetMinutes
442	}
443
444	dst = appendTwoDigits(dst, offsetMinutes/60)
445	dst = appendTwoDigits(dst, offsetMinutes%60)
446
447	return dst
448}
449
450func stripTagAndLength(in []byte) []byte {
451	_, offset, err := parseTagAndLength(in, 0)
452	if err != nil {
453		return in
454	}
455	return in[offset:]
456}
457
458func makeBody(value reflect.Value, params fieldParameters) (e encoder, err error) {
459	switch value.Type() {
460	case flagType:
461		return bytesEncoder(nil), nil
462	case timeType:
463		t := value.Interface().(time.Time)
464		if params.timeType == TagGeneralizedTime || outsideUTCRange(t) {
465			return makeGeneralizedTime(t)
466		}
467		return makeUTCTime(t)
468	case bitStringType:
469		return bitStringEncoder(value.Interface().(BitString)), nil
470	case objectIdentifierType:
471		return makeObjectIdentifier(value.Interface().(ObjectIdentifier))
472	case bigIntType:
473		return makeBigInt(value.Interface().(*big.Int))
474	}
475
476	switch v := value; v.Kind() {
477	case reflect.Bool:
478		if v.Bool() {
479			return byteFFEncoder, nil
480		}
481		return byte00Encoder, nil
482	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
483		return int64Encoder(v.Int()), nil
484	case reflect.Struct:
485		t := v.Type()
486
487		for i := 0; i < t.NumField(); i++ {
488			if !t.Field(i).IsExported() {
489				return nil, StructuralError{"struct contains unexported fields"}
490			}
491		}
492
493		startingField := 0
494
495		n := t.NumField()
496		if n == 0 {
497			return bytesEncoder(nil), nil
498		}
499
500		// If the first element of the structure is a non-empty
501		// RawContents, then we don't bother serializing the rest.
502		if t.Field(0).Type == rawContentsType {
503			s := v.Field(0)
504			if s.Len() > 0 {
505				bytes := s.Bytes()
506				/* The RawContents will contain the tag and
507				 * length fields but we'll also be writing
508				 * those ourselves, so we strip them out of
509				 * bytes */
510				return bytesEncoder(stripTagAndLength(bytes)), nil
511			}
512
513			startingField = 1
514		}
515
516		switch n1 := n - startingField; n1 {
517		case 0:
518			return bytesEncoder(nil), nil
519		case 1:
520			return makeField(v.Field(startingField), parseFieldParameters(t.Field(startingField).Tag.Get("asn1")))
521		default:
522			m := make([]encoder, n1)
523			for i := 0; i < n1; i++ {
524				m[i], err = makeField(v.Field(i+startingField), parseFieldParameters(t.Field(i+startingField).Tag.Get("asn1")))
525				if err != nil {
526					return nil, err
527				}
528			}
529
530			return multiEncoder(m), nil
531		}
532	case reflect.Slice:
533		sliceType := v.Type()
534		if sliceType.Elem().Kind() == reflect.Uint8 {
535			return bytesEncoder(v.Bytes()), nil
536		}
537
538		var fp fieldParameters
539
540		switch l := v.Len(); l {
541		case 0:
542			return bytesEncoder(nil), nil
543		case 1:
544			return makeField(v.Index(0), fp)
545		default:
546			m := make([]encoder, l)
547
548			for i := 0; i < l; i++ {
549				m[i], err = makeField(v.Index(i), fp)
550				if err != nil {
551					return nil, err
552				}
553			}
554
555			if params.set {
556				return setEncoder(m), nil
557			}
558			return multiEncoder(m), nil
559		}
560	case reflect.String:
561		switch params.stringType {
562		case TagIA5String:
563			return makeIA5String(v.String())
564		case TagPrintableString:
565			return makePrintableString(v.String())
566		case TagNumericString:
567			return makeNumericString(v.String())
568		default:
569			return makeUTF8String(v.String()), nil
570		}
571	}
572
573	return nil, StructuralError{"unknown Go type"}
574}
575
// makeField returns an encoder for the complete encoding of v: identifier,
// length and contents octets. It follows the struct-tag options in params.
//
// Some cases are settled before any tag is chosen:
//   - An invalid reflect.Value (for example, a nil interface) is an error.
//   - An empty interface is unwrapped, and its dynamic value is encoded.
//   - An empty slice with params.omitEmpty encodes to nothing.
//   - An optional field equal to its default encodes to nothing. With no
//     default given, the zero value counts as the default.
//   - A RawValue is emitted from FullBytes verbatim when that is non-empty.
//     Otherwise its header is built from Class, Tag and IsCompound, and the
//     body is Bytes.
//
// Otherwise, getUniversalType supplies the universal tag, which may then be
// adjusted (string type selection, GeneralizedTime, SET). makeBody supplies
// the contents. If params.tag is set, it applies implicit or explicit
// tagging. The class is application or private when requested, and
// context-specific otherwise.
func makeField(v reflect.Value, params fieldParameters) (e encoder, err error) {
	if !v.IsValid() {
		return nil, fmt.Errorf("asn1: cannot marshal nil value")
	}
	// If the field is an interface{} then recurse into it.
	if v.Kind() == reflect.Interface && v.Type().NumMethod() == 0 {
		return makeField(v.Elem(), params)
	}

	if v.Kind() == reflect.Slice && v.Len() == 0 && params.omitEmpty {
		return bytesEncoder(nil), nil
	}

	// NOTE(review): the default is applied with SetInt, so canHaveDefaultValue
	// presumably limits this to integer kinds. Confirm in its definition.
	if params.optional && params.defaultValue != nil && canHaveDefaultValue(v.Kind()) {
		defaultValue := reflect.New(v.Type()).Elem()
		defaultValue.SetInt(*params.defaultValue)

		if reflect.DeepEqual(v.Interface(), defaultValue.Interface()) {
			return bytesEncoder(nil), nil
		}
	}

	// If no default value is given then the zero value for the type is
	// assumed to be the default value. This isn't obviously the correct
	// behavior, but it's what Go has traditionally done.
	if params.optional && params.defaultValue == nil {
		if reflect.DeepEqual(v.Interface(), reflect.Zero(v.Type()).Interface()) {
			return bytesEncoder(nil), nil
		}
	}

	// A RawValue bypasses type-based encoding entirely. A pre-encoded
	// FullBytes wins; otherwise the header comes from the RawValue's own
	// fields.
	if v.Type() == rawValueType {
		rv := v.Interface().(RawValue)
		if len(rv.FullBytes) != 0 {
			return bytesEncoder(rv.FullBytes), nil
		}

		t := new(taggedEncoder)

		t.tag = bytesEncoder(appendTagAndLength(t.scratch[:0], tagAndLength{rv.Class, rv.Tag, len(rv.Bytes), rv.IsCompound}))
		t.body = bytesEncoder(rv.Bytes)

		return t, nil
	}

	matchAny, tag, isCompound, ok := getUniversalType(v.Type())
	if !ok || matchAny {
		return nil, StructuralError{fmt.Sprintf("unknown Go type: %v", v.Type())}
	}

	// An explicit time or string type only makes sense on a time or string
	// value. Here those are the values whose universal tag is TagUTCTime or
	// TagPrintableString respectively.
	if params.timeType != 0 && tag != TagUTCTime {
		return nil, StructuralError{"explicit time type given to non-time member"}
	}

	if params.stringType != 0 && tag != TagPrintableString {
		return nil, StructuralError{"explicit string type given to non-string member"}
	}

	switch tag {
	case TagPrintableString:
		if params.stringType == 0 {
			// This is a string without an explicit string type. We'll use
			// a PrintableString if the character set in the string is
			// sufficiently limited, otherwise we'll use a UTF8String.
			for _, r := range v.String() {
				if r >= utf8.RuneSelf || !isPrintable(byte(r), rejectAsterisk, rejectAmpersand) {
					if !utf8.ValidString(v.String()) {
						return nil, errors.New("asn1: string not valid UTF-8")
					}
					tag = TagUTF8String
					// This break exits the range loop (the innermost
					// statement), not the enclosing switch.
					break
				}
			}
		} else {
			tag = params.stringType
		}
	case TagUTCTime:
		// Time values switch to GeneralizedTime when that was requested, or
		// when the year cannot be written as a two-digit UTCTime year.
		if params.timeType == TagGeneralizedTime || outsideUTCRange(v.Interface().(time.Time)) {
			tag = TagGeneralizedTime
		}
	}

	if params.set {
		if tag != TagSequence {
			return nil, StructuralError{"non sequence tagged as set"}
		}
		tag = TagSet
	}

	// makeField can be called for a slice that should be treated as a SET
	// but doesn't have params.set set, for instance when using a slice
	// with the SET type name suffix. In this case getUniversalType returns
	// TagSet, but makeBody doesn't know about that so will treat the slice
	// as a sequence. To work around this we set params.set.
	if tag == TagSet && !params.set {
		params.set = true
	}

	t := new(taggedEncoder)

	t.body, err = makeBody(v, params)
	if err != nil {
		return nil, err
	}

	bodyLen := t.body.Len()

	class := ClassUniversal
	if params.tag != nil {
		if params.application {
			class = ClassApplication
		} else if params.private {
			class = ClassPrivate
		} else {
			class = ClassContextSpecific
		}

		// Explicit tagging: encode the value with its universal tag, then
		// wrap that whole TLV in a constructed outer header that carries
		// the requested class and tag number.
		if params.explicit {
			t.tag = bytesEncoder(appendTagAndLength(t.scratch[:0], tagAndLength{ClassUniversal, tag, bodyLen, isCompound}))

			tt := new(taggedEncoder)

			tt.body = t

			tt.tag = bytesEncoder(appendTagAndLength(tt.scratch[:0], tagAndLength{
				class:      class,
				tag:        *params.tag,
				length:     bodyLen + t.tag.Len(),
				isCompound: true,
			}))

			return tt, nil
		}

		// implicit tag.
		// The requested tag number replaces the universal one. isCompound
		// still comes from the universal type.
		tag = *params.tag
	}

	t.tag = bytesEncoder(appendTagAndLength(t.scratch[:0], tagAndLength{class, tag, bodyLen, isCompound}))

	return t, nil
}
718
719// Marshal returns the ASN.1 encoding of val.
720//
721// In addition to the struct tags recognized by Unmarshal, the following can be
722// used:
723//
724//	ia5:         causes strings to be marshaled as ASN.1, IA5String values
725//	omitempty:   causes empty slices to be skipped
726//	printable:   causes strings to be marshaled as ASN.1, PrintableString values
727//	utf8:        causes strings to be marshaled as ASN.1, UTF8String values
728//	utc:         causes time.Time to be marshaled as ASN.1, UTCTime values
729//	generalized: causes time.Time to be marshaled as ASN.1, GeneralizedTime values
730func Marshal(val any) ([]byte, error) {
731	return MarshalWithParams(val, "")
732}
733
734// MarshalWithParams allows field parameters to be specified for the
735// top-level element. The form of the params is the same as the field tags.
736func MarshalWithParams(val any, params string) ([]byte, error) {
737	e, err := makeField(reflect.ValueOf(val), parseFieldParameters(params))
738	if err != nil {
739		return nil, err
740	}
741	b := make([]byte, e.Len())
742	e.Encode(b)
743	return b, nil
744}
745