1// Copyright 2023 The Go Authors. All rights reserved. 2// Use of this source code is governed by a BSD-style 3// license that can be found in the LICENSE file. 4 5package x509 6 7import ( 8 "bytes" 9 "encoding/asn1" 10 "errors" 11 "math" 12 "math/big" 13 "math/bits" 14 "strconv" 15 "strings" 16) 17 18var ( 19 errInvalidOID = errors.New("invalid oid") 20) 21 22// An OID represents an ASN.1 OBJECT IDENTIFIER. 23type OID struct { 24 der []byte 25} 26 27// ParseOID parses a Object Identifier string, represented by ASCII numbers separated by dots. 28func ParseOID(oid string) (OID, error) { 29 var o OID 30 return o, o.unmarshalOIDText(oid) 31} 32 33func newOIDFromDER(der []byte) (OID, bool) { 34 if len(der) == 0 || der[len(der)-1]&0x80 != 0 { 35 return OID{}, false 36 } 37 38 start := 0 39 for i, v := range der { 40 // ITU-T X.690, section 8.19.2: 41 // The subidentifier shall be encoded in the fewest possible octets, 42 // that is, the leading octet of the subidentifier shall not have the value 0x80. 43 if i == start && v == 0x80 { 44 return OID{}, false 45 } 46 if v&0x80 == 0 { 47 start = i + 1 48 } 49 } 50 51 return OID{der}, true 52} 53 54// OIDFromInts creates a new OID using ints, each integer is a separate component. 55func OIDFromInts(oid []uint64) (OID, error) { 56 if len(oid) < 2 || oid[0] > 2 || (oid[0] < 2 && oid[1] >= 40) { 57 return OID{}, errInvalidOID 58 } 59 60 length := base128IntLength(oid[0]*40 + oid[1]) 61 for _, v := range oid[2:] { 62 length += base128IntLength(v) 63 } 64 65 der := make([]byte, 0, length) 66 der = appendBase128Int(der, oid[0]*40+oid[1]) 67 for _, v := range oid[2:] { 68 der = appendBase128Int(der, v) 69 } 70 return OID{der}, nil 71} 72 73func base128IntLength(n uint64) int { 74 if n == 0 { 75 return 1 76 } 77 return (bits.Len64(n) + 6) / 7 78} 79 80func appendBase128Int(dst []byte, n uint64) []byte { 81 for i := base128IntLength(n) - 1; i >= 0; i-- { 82 o := byte(n >> uint(i*7)) 83 o &= 0x7f 84 if i != 0 { 85 o |= 0x80 86 } 87 dst = append(dst, o) 88 } 89 return dst 90} 91 92func base128BigIntLength(n *big.Int) int { 93 if n.Cmp(big.NewInt(0)) == 0 { 94 return 1 95 } 96 return (n.BitLen() + 6) / 7 97} 98 99func appendBase128BigInt(dst []byte, n *big.Int) []byte { 100 if n.Cmp(big.NewInt(0)) == 0 { 101 return append(dst, 0) 102 } 103 104 for i := base128BigIntLength(n) - 1; i >= 0; i-- { 105 o := byte(big.NewInt(0).Rsh(n, uint(i)*7).Bits()[0]) 106 o &= 0x7f 107 if i != 0 { 108 o |= 0x80 109 } 110 dst = append(dst, o) 111 } 112 return dst 113} 114 115// MarshalText implements [encoding.TextMarshaler] 116func (o OID) MarshalText() ([]byte, error) { 117 return []byte(o.String()), nil 118} 119 120// UnmarshalText implements [encoding.TextUnmarshaler] 121func (o *OID) UnmarshalText(text []byte) error { 122 return o.unmarshalOIDText(string(text)) 123} 124 125func (o *OID) unmarshalOIDText(oid string) error { 126 // (*big.Int).SetString allows +/- signs, but we don't want 127 // to allow them in the string representation of Object Identifier, so 128 // reject such encodings. 129 for _, c := range oid { 130 isDigit := c >= '0' && c <= '9' 131 if !isDigit && c != '.' { 132 return errInvalidOID 133 } 134 } 135 136 var ( 137 firstNum string 138 secondNum string 139 ) 140 141 var nextComponentExists bool 142 firstNum, oid, nextComponentExists = strings.Cut(oid, ".") 143 if !nextComponentExists { 144 return errInvalidOID 145 } 146 secondNum, oid, nextComponentExists = strings.Cut(oid, ".") 147 148 var ( 149 first = big.NewInt(0) 150 second = big.NewInt(0) 151 ) 152 153 if _, ok := first.SetString(firstNum, 10); !ok { 154 return errInvalidOID 155 } 156 if _, ok := second.SetString(secondNum, 10); !ok { 157 return errInvalidOID 158 } 159 160 if first.Cmp(big.NewInt(2)) > 0 || (first.Cmp(big.NewInt(2)) < 0 && second.Cmp(big.NewInt(40)) >= 0) { 161 return errInvalidOID 162 } 163 164 firstComponent := first.Mul(first, big.NewInt(40)) 165 firstComponent.Add(firstComponent, second) 166 167 der := appendBase128BigInt(make([]byte, 0, 32), firstComponent) 168 169 for nextComponentExists { 170 var strNum string 171 strNum, oid, nextComponentExists = strings.Cut(oid, ".") 172 b, ok := big.NewInt(0).SetString(strNum, 10) 173 if !ok { 174 return errInvalidOID 175 } 176 der = appendBase128BigInt(der, b) 177 } 178 179 o.der = der 180 return nil 181} 182 183// MarshalBinary implements [encoding.BinaryMarshaler] 184func (o OID) MarshalBinary() ([]byte, error) { 185 return bytes.Clone(o.der), nil 186} 187 188// UnmarshalBinary implements [encoding.BinaryUnmarshaler] 189func (o *OID) UnmarshalBinary(b []byte) error { 190 oid, ok := newOIDFromDER(bytes.Clone(b)) 191 if !ok { 192 return errInvalidOID 193 } 194 *o = oid 195 return nil 196} 197 198// Equal returns true when oid and other represents the same Object Identifier. 199func (oid OID) Equal(other OID) bool { 200 // There is only one possible DER encoding of 201 // each unique Object Identifier. 202 return bytes.Equal(oid.der, other.der) 203} 204 205func parseBase128Int(bytes []byte, initOffset int) (ret, offset int, failed bool) { 206 offset = initOffset 207 var ret64 int64 208 for shifted := 0; offset < len(bytes); shifted++ { 209 // 5 * 7 bits per byte == 35 bits of data 210 // Thus the representation is either non-minimal or too large for an int32 211 if shifted == 5 { 212 failed = true 213 return 214 } 215 ret64 <<= 7 216 b := bytes[offset] 217 // integers should be minimally encoded, so the leading octet should 218 // never be 0x80 219 if shifted == 0 && b == 0x80 { 220 failed = true 221 return 222 } 223 ret64 |= int64(b & 0x7f) 224 offset++ 225 if b&0x80 == 0 { 226 ret = int(ret64) 227 // Ensure that the returned value fits in an int on all platforms 228 if ret64 > math.MaxInt32 { 229 failed = true 230 } 231 return 232 } 233 } 234 failed = true 235 return 236} 237 238// EqualASN1OID returns whether an OID equals an asn1.ObjectIdentifier. If 239// asn1.ObjectIdentifier cannot represent the OID specified by oid, because 240// a component of OID requires more than 31 bits, it returns false. 241func (oid OID) EqualASN1OID(other asn1.ObjectIdentifier) bool { 242 if len(other) < 2 { 243 return false 244 } 245 v, offset, failed := parseBase128Int(oid.der, 0) 246 if failed { 247 // This should never happen, since we've already parsed the OID, 248 // but just in case. 249 return false 250 } 251 if v < 80 { 252 a, b := v/40, v%40 253 if other[0] != a || other[1] != b { 254 return false 255 } 256 } else { 257 a, b := 2, v-80 258 if other[0] != a || other[1] != b { 259 return false 260 } 261 } 262 263 i := 2 264 for ; offset < len(oid.der); i++ { 265 v, offset, failed = parseBase128Int(oid.der, offset) 266 if failed { 267 // Again, shouldn't happen, since we've already parsed 268 // the OID, but better safe than sorry. 269 return false 270 } 271 if i >= len(other) || v != other[i] { 272 return false 273 } 274 } 275 276 return i == len(other) 277} 278 279// Strings returns the string representation of the Object Identifier. 280func (oid OID) String() string { 281 var b strings.Builder 282 b.Grow(32) 283 const ( 284 valSize = 64 // size in bits of val. 285 bitsPerByte = 7 286 maxValSafeShift = (1 << (valSize - bitsPerByte)) - 1 287 ) 288 var ( 289 start = 0 290 val = uint64(0) 291 numBuf = make([]byte, 0, 21) 292 bigVal *big.Int 293 overflow bool 294 ) 295 for i, v := range oid.der { 296 curVal := v & 0x7F 297 valEnd := v&0x80 == 0 298 if valEnd { 299 if start != 0 { 300 b.WriteByte('.') 301 } 302 } 303 if !overflow && val > maxValSafeShift { 304 if bigVal == nil { 305 bigVal = new(big.Int) 306 } 307 bigVal = bigVal.SetUint64(val) 308 overflow = true 309 } 310 if overflow { 311 bigVal = bigVal.Lsh(bigVal, bitsPerByte).Or(bigVal, big.NewInt(int64(curVal))) 312 if valEnd { 313 if start == 0 { 314 b.WriteString("2.") 315 bigVal = bigVal.Sub(bigVal, big.NewInt(80)) 316 } 317 numBuf = bigVal.Append(numBuf, 10) 318 b.Write(numBuf) 319 numBuf = numBuf[:0] 320 val = 0 321 start = i + 1 322 overflow = false 323 } 324 continue 325 } 326 val <<= bitsPerByte 327 val |= uint64(curVal) 328 if valEnd { 329 if start == 0 { 330 if val < 80 { 331 b.Write(strconv.AppendUint(numBuf, val/40, 10)) 332 b.WriteByte('.') 333 b.Write(strconv.AppendUint(numBuf, val%40, 10)) 334 } else { 335 b.WriteString("2.") 336 b.Write(strconv.AppendUint(numBuf, val-80, 10)) 337 } 338 } else { 339 b.Write(strconv.AppendUint(numBuf, val, 10)) 340 } 341 val = 0 342 start = i + 1 343 } 344 } 345 return b.String() 346} 347 348func (oid OID) toASN1OID() (asn1.ObjectIdentifier, bool) { 349 out := make([]int, 0, len(oid.der)+1) 350 351 const ( 352 valSize = 31 // amount of usable bits of val for OIDs. 353 bitsPerByte = 7 354 maxValSafeShift = (1 << (valSize - bitsPerByte)) - 1 355 ) 356 357 val := 0 358 359 for _, v := range oid.der { 360 if val > maxValSafeShift { 361 return nil, false 362 } 363 364 val <<= bitsPerByte 365 val |= int(v & 0x7F) 366 367 if v&0x80 == 0 { 368 if len(out) == 0 { 369 if val < 80 { 370 out = append(out, val/40) 371 out = append(out, val%40) 372 } else { 373 out = append(out, 2) 374 out = append(out, val-80) 375 } 376 val = 0 377 continue 378 } 379 out = append(out, val) 380 val = 0 381 } 382 } 383 384 return out, true 385} 386