1// Copyright 2009 The Go Authors. All rights reserved. 2// Use of this source code is governed by a BSD-style 3// license that can be found in the LICENSE file. 4 5package asn1 6 7import ( 8 "bytes" 9 "errors" 10 "fmt" 11 "math/big" 12 "reflect" 13 "slices" 14 "time" 15 "unicode/utf8" 16) 17 18var ( 19 byte00Encoder encoder = byteEncoder(0x00) 20 byteFFEncoder encoder = byteEncoder(0xff) 21) 22 23// encoder represents an ASN.1 element that is waiting to be marshaled. 24type encoder interface { 25 // Len returns the number of bytes needed to marshal this element. 26 Len() int 27 // Encode encodes this element by writing Len() bytes to dst. 28 Encode(dst []byte) 29} 30 31type byteEncoder byte 32 33func (c byteEncoder) Len() int { 34 return 1 35} 36 37func (c byteEncoder) Encode(dst []byte) { 38 dst[0] = byte(c) 39} 40 41type bytesEncoder []byte 42 43func (b bytesEncoder) Len() int { 44 return len(b) 45} 46 47func (b bytesEncoder) Encode(dst []byte) { 48 if copy(dst, b) != len(b) { 49 panic("internal error") 50 } 51} 52 53type stringEncoder string 54 55func (s stringEncoder) Len() int { 56 return len(s) 57} 58 59func (s stringEncoder) Encode(dst []byte) { 60 if copy(dst, s) != len(s) { 61 panic("internal error") 62 } 63} 64 65type multiEncoder []encoder 66 67func (m multiEncoder) Len() int { 68 var size int 69 for _, e := range m { 70 size += e.Len() 71 } 72 return size 73} 74 75func (m multiEncoder) Encode(dst []byte) { 76 var off int 77 for _, e := range m { 78 e.Encode(dst[off:]) 79 off += e.Len() 80 } 81} 82 83type setEncoder []encoder 84 85func (s setEncoder) Len() int { 86 var size int 87 for _, e := range s { 88 size += e.Len() 89 } 90 return size 91} 92 93func (s setEncoder) Encode(dst []byte) { 94 // Per X690 Section 11.6: The encodings of the component values of a 95 // set-of value shall appear in ascending order, the encodings being 96 // compared as octet strings with the shorter components being padded 97 // at their trailing end with 0-octets. 98 // 99 // First we encode each element to its TLV encoding and then use 100 // octetSort to get the ordering expected by X690 DER rules before 101 // writing the sorted encodings out to dst. 102 l := make([][]byte, len(s)) 103 for i, e := range s { 104 l[i] = make([]byte, e.Len()) 105 e.Encode(l[i]) 106 } 107 108 // Since we are using bytes.Compare to compare TLV encodings we 109 // don't need to right pad s[i] and s[j] to the same length as 110 // suggested in X690. If len(s[i]) < len(s[j]) the length octet of 111 // s[i], which is the first determining byte, will inherently be 112 // smaller than the length octet of s[j]. This lets us skip the 113 // padding step. 114 slices.SortFunc(l, bytes.Compare) 115 116 var off int 117 for _, b := range l { 118 copy(dst[off:], b) 119 off += len(b) 120 } 121} 122 123type taggedEncoder struct { 124 // scratch contains temporary space for encoding the tag and length of 125 // an element in order to avoid extra allocations. 126 scratch [8]byte 127 tag encoder 128 body encoder 129} 130 131func (t *taggedEncoder) Len() int { 132 return t.tag.Len() + t.body.Len() 133} 134 135func (t *taggedEncoder) Encode(dst []byte) { 136 t.tag.Encode(dst) 137 t.body.Encode(dst[t.tag.Len():]) 138} 139 140type int64Encoder int64 141 142func (i int64Encoder) Len() int { 143 n := 1 144 145 for i > 127 { 146 n++ 147 i >>= 8 148 } 149 150 for i < -128 { 151 n++ 152 i >>= 8 153 } 154 155 return n 156} 157 158func (i int64Encoder) Encode(dst []byte) { 159 n := i.Len() 160 161 for j := 0; j < n; j++ { 162 dst[j] = byte(i >> uint((n-1-j)*8)) 163 } 164} 165 166func base128IntLength(n int64) int { 167 if n == 0 { 168 return 1 169 } 170 171 l := 0 172 for i := n; i > 0; i >>= 7 { 173 l++ 174 } 175 176 return l 177} 178 179func appendBase128Int(dst []byte, n int64) []byte { 180 l := base128IntLength(n) 181 182 for i := l - 1; i >= 0; i-- { 183 o := byte(n >> uint(i*7)) 184 o &= 0x7f 185 if i != 0 { 186 o |= 0x80 187 } 188 189 dst = append(dst, o) 190 } 191 192 return dst 193} 194 195func makeBigInt(n *big.Int) (encoder, error) { 196 if n == nil { 197 return nil, StructuralError{"empty integer"} 198 } 199 200 if n.Sign() < 0 { 201 // A negative number has to be converted to two's-complement 202 // form. So we'll invert and subtract 1. If the 203 // most-significant-bit isn't set then we'll need to pad the 204 // beginning with 0xff in order to keep the number negative. 205 nMinus1 := new(big.Int).Neg(n) 206 nMinus1.Sub(nMinus1, bigOne) 207 bytes := nMinus1.Bytes() 208 for i := range bytes { 209 bytes[i] ^= 0xff 210 } 211 if len(bytes) == 0 || bytes[0]&0x80 == 0 { 212 return multiEncoder([]encoder{byteFFEncoder, bytesEncoder(bytes)}), nil 213 } 214 return bytesEncoder(bytes), nil 215 } else if n.Sign() == 0 { 216 // Zero is written as a single 0 zero rather than no bytes. 217 return byte00Encoder, nil 218 } else { 219 bytes := n.Bytes() 220 if len(bytes) > 0 && bytes[0]&0x80 != 0 { 221 // We'll have to pad this with 0x00 in order to stop it 222 // looking like a negative number. 223 return multiEncoder([]encoder{byte00Encoder, bytesEncoder(bytes)}), nil 224 } 225 return bytesEncoder(bytes), nil 226 } 227} 228 229func appendLength(dst []byte, i int) []byte { 230 n := lengthLength(i) 231 232 for ; n > 0; n-- { 233 dst = append(dst, byte(i>>uint((n-1)*8))) 234 } 235 236 return dst 237} 238 239func lengthLength(i int) (numBytes int) { 240 numBytes = 1 241 for i > 255 { 242 numBytes++ 243 i >>= 8 244 } 245 return 246} 247 248func appendTagAndLength(dst []byte, t tagAndLength) []byte { 249 b := uint8(t.class) << 6 250 if t.isCompound { 251 b |= 0x20 252 } 253 if t.tag >= 31 { 254 b |= 0x1f 255 dst = append(dst, b) 256 dst = appendBase128Int(dst, int64(t.tag)) 257 } else { 258 b |= uint8(t.tag) 259 dst = append(dst, b) 260 } 261 262 if t.length >= 128 { 263 l := lengthLength(t.length) 264 dst = append(dst, 0x80|byte(l)) 265 dst = appendLength(dst, t.length) 266 } else { 267 dst = append(dst, byte(t.length)) 268 } 269 270 return dst 271} 272 273type bitStringEncoder BitString 274 275func (b bitStringEncoder) Len() int { 276 return len(b.Bytes) + 1 277} 278 279func (b bitStringEncoder) Encode(dst []byte) { 280 dst[0] = byte((8 - b.BitLength%8) % 8) 281 if copy(dst[1:], b.Bytes) != len(b.Bytes) { 282 panic("internal error") 283 } 284} 285 286type oidEncoder []int 287 288func (oid oidEncoder) Len() int { 289 l := base128IntLength(int64(oid[0]*40 + oid[1])) 290 for i := 2; i < len(oid); i++ { 291 l += base128IntLength(int64(oid[i])) 292 } 293 return l 294} 295 296func (oid oidEncoder) Encode(dst []byte) { 297 dst = appendBase128Int(dst[:0], int64(oid[0]*40+oid[1])) 298 for i := 2; i < len(oid); i++ { 299 dst = appendBase128Int(dst, int64(oid[i])) 300 } 301} 302 303func makeObjectIdentifier(oid []int) (e encoder, err error) { 304 if len(oid) < 2 || oid[0] > 2 || (oid[0] < 2 && oid[1] >= 40) { 305 return nil, StructuralError{"invalid object identifier"} 306 } 307 308 return oidEncoder(oid), nil 309} 310 311func makePrintableString(s string) (e encoder, err error) { 312 for i := 0; i < len(s); i++ { 313 // The asterisk is often used in PrintableString, even though 314 // it is invalid. If a PrintableString was specifically 315 // requested then the asterisk is permitted by this code. 316 // Ampersand is allowed in parsing due a handful of CA 317 // certificates, however when making new certificates 318 // it is rejected. 319 if !isPrintable(s[i], allowAsterisk, rejectAmpersand) { 320 return nil, StructuralError{"PrintableString contains invalid character"} 321 } 322 } 323 324 return stringEncoder(s), nil 325} 326 327func makeIA5String(s string) (e encoder, err error) { 328 for i := 0; i < len(s); i++ { 329 if s[i] > 127 { 330 return nil, StructuralError{"IA5String contains invalid character"} 331 } 332 } 333 334 return stringEncoder(s), nil 335} 336 337func makeNumericString(s string) (e encoder, err error) { 338 for i := 0; i < len(s); i++ { 339 if !isNumeric(s[i]) { 340 return nil, StructuralError{"NumericString contains invalid character"} 341 } 342 } 343 344 return stringEncoder(s), nil 345} 346 347func makeUTF8String(s string) encoder { 348 return stringEncoder(s) 349} 350 351func appendTwoDigits(dst []byte, v int) []byte { 352 return append(dst, byte('0'+(v/10)%10), byte('0'+v%10)) 353} 354 355func appendFourDigits(dst []byte, v int) []byte { 356 return append(dst, 357 byte('0'+(v/1000)%10), 358 byte('0'+(v/100)%10), 359 byte('0'+(v/10)%10), 360 byte('0'+v%10)) 361} 362 363func outsideUTCRange(t time.Time) bool { 364 year := t.Year() 365 return year < 1950 || year >= 2050 366} 367 368func makeUTCTime(t time.Time) (e encoder, err error) { 369 dst := make([]byte, 0, 18) 370 371 dst, err = appendUTCTime(dst, t) 372 if err != nil { 373 return nil, err 374 } 375 376 return bytesEncoder(dst), nil 377} 378 379func makeGeneralizedTime(t time.Time) (e encoder, err error) { 380 dst := make([]byte, 0, 20) 381 382 dst, err = appendGeneralizedTime(dst, t) 383 if err != nil { 384 return nil, err 385 } 386 387 return bytesEncoder(dst), nil 388} 389 390func appendUTCTime(dst []byte, t time.Time) (ret []byte, err error) { 391 year := t.Year() 392 393 switch { 394 case 1950 <= year && year < 2000: 395 dst = appendTwoDigits(dst, year-1900) 396 case 2000 <= year && year < 2050: 397 dst = appendTwoDigits(dst, year-2000) 398 default: 399 return nil, StructuralError{"cannot represent time as UTCTime"} 400 } 401 402 return appendTimeCommon(dst, t), nil 403} 404 405func appendGeneralizedTime(dst []byte, t time.Time) (ret []byte, err error) { 406 year := t.Year() 407 if year < 0 || year > 9999 { 408 return nil, StructuralError{"cannot represent time as GeneralizedTime"} 409 } 410 411 dst = appendFourDigits(dst, year) 412 413 return appendTimeCommon(dst, t), nil 414} 415 416func appendTimeCommon(dst []byte, t time.Time) []byte { 417 _, month, day := t.Date() 418 419 dst = appendTwoDigits(dst, int(month)) 420 dst = appendTwoDigits(dst, day) 421 422 hour, min, sec := t.Clock() 423 424 dst = appendTwoDigits(dst, hour) 425 dst = appendTwoDigits(dst, min) 426 dst = appendTwoDigits(dst, sec) 427 428 _, offset := t.Zone() 429 430 switch { 431 case offset/60 == 0: 432 return append(dst, 'Z') 433 case offset > 0: 434 dst = append(dst, '+') 435 case offset < 0: 436 dst = append(dst, '-') 437 } 438 439 offsetMinutes := offset / 60 440 if offsetMinutes < 0 { 441 offsetMinutes = -offsetMinutes 442 } 443 444 dst = appendTwoDigits(dst, offsetMinutes/60) 445 dst = appendTwoDigits(dst, offsetMinutes%60) 446 447 return dst 448} 449 450func stripTagAndLength(in []byte) []byte { 451 _, offset, err := parseTagAndLength(in, 0) 452 if err != nil { 453 return in 454 } 455 return in[offset:] 456} 457 458func makeBody(value reflect.Value, params fieldParameters) (e encoder, err error) { 459 switch value.Type() { 460 case flagType: 461 return bytesEncoder(nil), nil 462 case timeType: 463 t := value.Interface().(time.Time) 464 if params.timeType == TagGeneralizedTime || outsideUTCRange(t) { 465 return makeGeneralizedTime(t) 466 } 467 return makeUTCTime(t) 468 case bitStringType: 469 return bitStringEncoder(value.Interface().(BitString)), nil 470 case objectIdentifierType: 471 return makeObjectIdentifier(value.Interface().(ObjectIdentifier)) 472 case bigIntType: 473 return makeBigInt(value.Interface().(*big.Int)) 474 } 475 476 switch v := value; v.Kind() { 477 case reflect.Bool: 478 if v.Bool() { 479 return byteFFEncoder, nil 480 } 481 return byte00Encoder, nil 482 case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: 483 return int64Encoder(v.Int()), nil 484 case reflect.Struct: 485 t := v.Type() 486 487 for i := 0; i < t.NumField(); i++ { 488 if !t.Field(i).IsExported() { 489 return nil, StructuralError{"struct contains unexported fields"} 490 } 491 } 492 493 startingField := 0 494 495 n := t.NumField() 496 if n == 0 { 497 return bytesEncoder(nil), nil 498 } 499 500 // If the first element of the structure is a non-empty 501 // RawContents, then we don't bother serializing the rest. 502 if t.Field(0).Type == rawContentsType { 503 s := v.Field(0) 504 if s.Len() > 0 { 505 bytes := s.Bytes() 506 /* The RawContents will contain the tag and 507 * length fields but we'll also be writing 508 * those ourselves, so we strip them out of 509 * bytes */ 510 return bytesEncoder(stripTagAndLength(bytes)), nil 511 } 512 513 startingField = 1 514 } 515 516 switch n1 := n - startingField; n1 { 517 case 0: 518 return bytesEncoder(nil), nil 519 case 1: 520 return makeField(v.Field(startingField), parseFieldParameters(t.Field(startingField).Tag.Get("asn1"))) 521 default: 522 m := make([]encoder, n1) 523 for i := 0; i < n1; i++ { 524 m[i], err = makeField(v.Field(i+startingField), parseFieldParameters(t.Field(i+startingField).Tag.Get("asn1"))) 525 if err != nil { 526 return nil, err 527 } 528 } 529 530 return multiEncoder(m), nil 531 } 532 case reflect.Slice: 533 sliceType := v.Type() 534 if sliceType.Elem().Kind() == reflect.Uint8 { 535 return bytesEncoder(v.Bytes()), nil 536 } 537 538 var fp fieldParameters 539 540 switch l := v.Len(); l { 541 case 0: 542 return bytesEncoder(nil), nil 543 case 1: 544 return makeField(v.Index(0), fp) 545 default: 546 m := make([]encoder, l) 547 548 for i := 0; i < l; i++ { 549 m[i], err = makeField(v.Index(i), fp) 550 if err != nil { 551 return nil, err 552 } 553 } 554 555 if params.set { 556 return setEncoder(m), nil 557 } 558 return multiEncoder(m), nil 559 } 560 case reflect.String: 561 switch params.stringType { 562 case TagIA5String: 563 return makeIA5String(v.String()) 564 case TagPrintableString: 565 return makePrintableString(v.String()) 566 case TagNumericString: 567 return makeNumericString(v.String()) 568 default: 569 return makeUTF8String(v.String()), nil 570 } 571 } 572 573 return nil, StructuralError{"unknown Go type"} 574} 575 576func makeField(v reflect.Value, params fieldParameters) (e encoder, err error) { 577 if !v.IsValid() { 578 return nil, fmt.Errorf("asn1: cannot marshal nil value") 579 } 580 // If the field is an interface{} then recurse into it. 581 if v.Kind() == reflect.Interface && v.Type().NumMethod() == 0 { 582 return makeField(v.Elem(), params) 583 } 584 585 if v.Kind() == reflect.Slice && v.Len() == 0 && params.omitEmpty { 586 return bytesEncoder(nil), nil 587 } 588 589 if params.optional && params.defaultValue != nil && canHaveDefaultValue(v.Kind()) { 590 defaultValue := reflect.New(v.Type()).Elem() 591 defaultValue.SetInt(*params.defaultValue) 592 593 if reflect.DeepEqual(v.Interface(), defaultValue.Interface()) { 594 return bytesEncoder(nil), nil 595 } 596 } 597 598 // If no default value is given then the zero value for the type is 599 // assumed to be the default value. This isn't obviously the correct 600 // behavior, but it's what Go has traditionally done. 601 if params.optional && params.defaultValue == nil { 602 if reflect.DeepEqual(v.Interface(), reflect.Zero(v.Type()).Interface()) { 603 return bytesEncoder(nil), nil 604 } 605 } 606 607 if v.Type() == rawValueType { 608 rv := v.Interface().(RawValue) 609 if len(rv.FullBytes) != 0 { 610 return bytesEncoder(rv.FullBytes), nil 611 } 612 613 t := new(taggedEncoder) 614 615 t.tag = bytesEncoder(appendTagAndLength(t.scratch[:0], tagAndLength{rv.Class, rv.Tag, len(rv.Bytes), rv.IsCompound})) 616 t.body = bytesEncoder(rv.Bytes) 617 618 return t, nil 619 } 620 621 matchAny, tag, isCompound, ok := getUniversalType(v.Type()) 622 if !ok || matchAny { 623 return nil, StructuralError{fmt.Sprintf("unknown Go type: %v", v.Type())} 624 } 625 626 if params.timeType != 0 && tag != TagUTCTime { 627 return nil, StructuralError{"explicit time type given to non-time member"} 628 } 629 630 if params.stringType != 0 && tag != TagPrintableString { 631 return nil, StructuralError{"explicit string type given to non-string member"} 632 } 633 634 switch tag { 635 case TagPrintableString: 636 if params.stringType == 0 { 637 // This is a string without an explicit string type. We'll use 638 // a PrintableString if the character set in the string is 639 // sufficiently limited, otherwise we'll use a UTF8String. 640 for _, r := range v.String() { 641 if r >= utf8.RuneSelf || !isPrintable(byte(r), rejectAsterisk, rejectAmpersand) { 642 if !utf8.ValidString(v.String()) { 643 return nil, errors.New("asn1: string not valid UTF-8") 644 } 645 tag = TagUTF8String 646 break 647 } 648 } 649 } else { 650 tag = params.stringType 651 } 652 case TagUTCTime: 653 if params.timeType == TagGeneralizedTime || outsideUTCRange(v.Interface().(time.Time)) { 654 tag = TagGeneralizedTime 655 } 656 } 657 658 if params.set { 659 if tag != TagSequence { 660 return nil, StructuralError{"non sequence tagged as set"} 661 } 662 tag = TagSet 663 } 664 665 // makeField can be called for a slice that should be treated as a SET 666 // but doesn't have params.set set, for instance when using a slice 667 // with the SET type name suffix. In this case getUniversalType returns 668 // TagSet, but makeBody doesn't know about that so will treat the slice 669 // as a sequence. To work around this we set params.set. 670 if tag == TagSet && !params.set { 671 params.set = true 672 } 673 674 t := new(taggedEncoder) 675 676 t.body, err = makeBody(v, params) 677 if err != nil { 678 return nil, err 679 } 680 681 bodyLen := t.body.Len() 682 683 class := ClassUniversal 684 if params.tag != nil { 685 if params.application { 686 class = ClassApplication 687 } else if params.private { 688 class = ClassPrivate 689 } else { 690 class = ClassContextSpecific 691 } 692 693 if params.explicit { 694 t.tag = bytesEncoder(appendTagAndLength(t.scratch[:0], tagAndLength{ClassUniversal, tag, bodyLen, isCompound})) 695 696 tt := new(taggedEncoder) 697 698 tt.body = t 699 700 tt.tag = bytesEncoder(appendTagAndLength(tt.scratch[:0], tagAndLength{ 701 class: class, 702 tag: *params.tag, 703 length: bodyLen + t.tag.Len(), 704 isCompound: true, 705 })) 706 707 return tt, nil 708 } 709 710 // implicit tag. 711 tag = *params.tag 712 } 713 714 t.tag = bytesEncoder(appendTagAndLength(t.scratch[:0], tagAndLength{class, tag, bodyLen, isCompound})) 715 716 return t, nil 717} 718 719// Marshal returns the ASN.1 encoding of val. 720// 721// In addition to the struct tags recognized by Unmarshal, the following can be 722// used: 723// 724// ia5: causes strings to be marshaled as ASN.1, IA5String values 725// omitempty: causes empty slices to be skipped 726// printable: causes strings to be marshaled as ASN.1, PrintableString values 727// utf8: causes strings to be marshaled as ASN.1, UTF8String values 728// utc: causes time.Time to be marshaled as ASN.1, UTCTime values 729// generalized: causes time.Time to be marshaled as ASN.1, GeneralizedTime values 730func Marshal(val any) ([]byte, error) { 731 return MarshalWithParams(val, "") 732} 733 734// MarshalWithParams allows field parameters to be specified for the 735// top-level element. The form of the params is the same as the field tags. 736func MarshalWithParams(val any, params string) ([]byte, error) { 737 e, err := makeField(reflect.ValueOf(val), parseFieldParameters(params)) 738 if err != nil { 739 return nil, err 740 } 741 b := make([]byte, e.Len()) 742 e.Encode(b) 743 return b, nil 744} 745