1// Copyright 2022 Google LLC 2// 3// Licensed under the Apache License, Version 2.0 (the "License"); 4// you may not use this file except in compliance with the License. 5// You may obtain a copy of the License at 6// 7// http://www.apache.org/licenses/LICENSE-2.0 8// 9// Unless required by applicable law or agreed to in writing, software 10// distributed under the License is distributed on an "AS IS" BASIS, 11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12// See the License for the specific language governing permissions and 13// limitations under the License. 14// 15//////////////////////////////////////////////////////////////////////////////// 16 17package jwt 18 19import ( 20 "bytes" 21 "fmt" 22 "math/rand" 23 24 spb "google.golang.org/protobuf/types/known/structpb" 25 "google.golang.org/protobuf/proto" 26 "github.com/google/tink/go/keyset" 27 jepb "github.com/google/tink/go/proto/jwt_ecdsa_go_proto" 28 jrsppb "github.com/google/tink/go/proto/jwt_rsa_ssa_pkcs1_go_proto" 29 jrpsspb "github.com/google/tink/go/proto/jwt_rsa_ssa_pss_go_proto" 30 tinkpb "github.com/google/tink/go/proto/tink_go_proto" 31) 32 33const ( 34 jwtECDSAPublicKeyType = "type.googleapis.com/google.crypto.tink.JwtEcdsaPublicKey" 35 jwtRSPublicKeyType = "type.googleapis.com/google.crypto.tink.JwtRsaSsaPkcs1PublicKey" 36 jwtPSPublicKeyType = "type.googleapis.com/google.crypto.tink.JwtRsaSsaPssPublicKey" 37) 38 39func keysetHasID(ks *tinkpb.Keyset, keyID uint32) bool { 40 for _, k := range ks.GetKey() { 41 if k.GetKeyId() == keyID { 42 return true 43 } 44 } 45 return false 46} 47 48func generateUnusedID(ks *tinkpb.Keyset) uint32 { 49 for { 50 keyID := rand.Uint32() 51 if !keysetHasID(ks, keyID) { 52 return keyID 53 } 54 } 55} 56 57func hasItem(s *spb.Struct, name string) bool { 58 if s.GetFields() == nil { 59 return false 60 } 61 _, ok := s.Fields[name] 62 return ok 63} 64 65func stringItem(s *spb.Struct, name string) (string, error) { 66 fields := s.GetFields() 67 if fields == nil { 68 return "", fmt.Errorf("no fields") 69 } 70 val, ok := fields[name] 71 if !ok { 72 return "", fmt.Errorf("field %q not found", name) 73 } 74 r, ok := val.Kind.(*spb.Value_StringValue) 75 if !ok { 76 return "", fmt.Errorf("field %q is not a string", name) 77 } 78 return r.StringValue, nil 79} 80 81func listValue(s *spb.Struct, name string) (*spb.ListValue, error) { 82 fields := s.GetFields() 83 if fields == nil { 84 return nil, fmt.Errorf("empty set") 85 } 86 vals, ok := fields[name] 87 if !ok { 88 return nil, fmt.Errorf("%q not found", name) 89 } 90 list, ok := vals.Kind.(*spb.Value_ListValue) 91 if !ok { 92 return nil, fmt.Errorf("%q is not a list", name) 93 } 94 if list.ListValue == nil || len(list.ListValue.GetValues()) == 0 { 95 return nil, fmt.Errorf("%q list is empty", name) 96 } 97 return list.ListValue, nil 98} 99 100func expectStringItem(s *spb.Struct, name, value string) error { 101 item, err := stringItem(s, name) 102 if err != nil { 103 return err 104 } 105 if item != value { 106 return fmt.Errorf("unexpected value %q for %q", value, name) 107 } 108 return nil 109} 110 111func decodeItem(s *spb.Struct, name string) ([]byte, error) { 112 e, err := stringItem(s, name) 113 if err != nil { 114 return nil, err 115 } 116 return base64Decode(e) 117} 118 119func validateKeyOPSIsVerify(s *spb.Struct) error { 120 if !hasItem(s, "key_ops") { 121 return nil 122 } 123 keyOPSList, err := listValue(s, "key_ops") 124 if err != nil { 125 return err 126 } 127 if len(keyOPSList.GetValues()) != 1 { 128 return fmt.Errorf("key_ops size is not 1") 129 } 130 value, ok := keyOPSList.GetValues()[0].Kind.(*spb.Value_StringValue) 131 if !ok { 132 return fmt.Errorf("key_ops is not a string") 133 } 134 if value.StringValue != "verify" { 135 return fmt.Errorf("key_ops is not equal to [\"verify\"]") 136 } 137 return nil 138} 139 140func validateUseIsSig(s *spb.Struct) error { 141 if !hasItem(s, "use") { 142 return nil 143 } 144 return expectStringItem(s, "use", "sig") 145} 146 147func algorithmPrefix(s *spb.Struct) (string, error) { 148 alg, err := stringItem(s, "alg") 149 if err != nil { 150 return "", err 151 } 152 if len(alg) < 2 { 153 return "", fmt.Errorf("invalid algorithm") 154 } 155 return alg[0:2], nil 156} 157 158var psNameToAlg = map[string]jrpsspb.JwtRsaSsaPssAlgorithm{ 159 "PS256": jrpsspb.JwtRsaSsaPssAlgorithm_PS256, 160 "PS384": jrpsspb.JwtRsaSsaPssAlgorithm_PS384, 161 "PS512": jrpsspb.JwtRsaSsaPssAlgorithm_PS512, 162} 163 164func psPublicKeyDataFromStruct(keyStruct *spb.Struct) (*tinkpb.KeyData, error) { 165 alg, err := stringItem(keyStruct, "alg") 166 if err != nil { 167 return nil, err 168 } 169 algorithm, ok := psNameToAlg[alg] 170 if !ok { 171 return nil, fmt.Errorf("invalid alg header: %q", alg) 172 } 173 rsaPubKey, err := rsaPubKeyFromStruct(keyStruct) 174 if err != nil { 175 return nil, err 176 } 177 jwtPubKey := &jrpsspb.JwtRsaSsaPssPublicKey{ 178 Version: jwtECDSASignerKeyVersion, 179 Algorithm: algorithm, 180 E: rsaPubKey.exponent, 181 N: rsaPubKey.modulus, 182 } 183 if rsaPubKey.customKID != nil { 184 jwtPubKey.CustomKid = &jrpsspb.JwtRsaSsaPssPublicKey_CustomKid{ 185 Value: *rsaPubKey.customKID, 186 } 187 } 188 serializedPubKey, err := proto.Marshal(jwtPubKey) 189 if err != nil { 190 return nil, err 191 } 192 return &tinkpb.KeyData{ 193 TypeUrl: jwtPSPublicKeyType, 194 Value: serializedPubKey, 195 KeyMaterialType: tinkpb.KeyData_ASYMMETRIC_PUBLIC, 196 }, nil 197} 198 199var rsNameToAlg = map[string]jrsppb.JwtRsaSsaPkcs1Algorithm{ 200 "RS256": jrsppb.JwtRsaSsaPkcs1Algorithm_RS256, 201 "RS384": jrsppb.JwtRsaSsaPkcs1Algorithm_RS384, 202 "RS512": jrsppb.JwtRsaSsaPkcs1Algorithm_RS512, 203} 204 205func rsPublicKeyDataFromStruct(keyStruct *spb.Struct) (*tinkpb.KeyData, error) { 206 alg, err := stringItem(keyStruct, "alg") 207 if err != nil { 208 return nil, err 209 } 210 algorithm, ok := rsNameToAlg[alg] 211 if !ok { 212 return nil, fmt.Errorf("invalid alg header: %q", alg) 213 } 214 rsaPubKey, err := rsaPubKeyFromStruct(keyStruct) 215 if err != nil { 216 return nil, err 217 } 218 jwtPubKey := &jrsppb.JwtRsaSsaPkcs1PublicKey{ 219 Version: 0, 220 Algorithm: algorithm, 221 E: rsaPubKey.exponent, 222 N: rsaPubKey.modulus, 223 } 224 if rsaPubKey.customKID != nil { 225 jwtPubKey.CustomKid = &jrsppb.JwtRsaSsaPkcs1PublicKey_CustomKid{ 226 Value: *rsaPubKey.customKID, 227 } 228 } 229 serializedPubKey, err := proto.Marshal(jwtPubKey) 230 if err != nil { 231 return nil, err 232 } 233 return &tinkpb.KeyData{ 234 TypeUrl: jwtRSPublicKeyType, 235 Value: serializedPubKey, 236 KeyMaterialType: tinkpb.KeyData_ASYMMETRIC_PUBLIC, 237 }, nil 238} 239 240type rsaPubKey struct { 241 exponent []byte 242 modulus []byte 243 customKID *string 244} 245 246func rsaPubKeyFromStruct(keyStruct *spb.Struct) (*rsaPubKey, error) { 247 if hasItem(keyStruct, "p") || 248 hasItem(keyStruct, "q") || 249 hasItem(keyStruct, "dq") || 250 hasItem(keyStruct, "dp") || 251 hasItem(keyStruct, "d") || 252 hasItem(keyStruct, "qi") { 253 return nil, fmt.Errorf("private key can't be converted") 254 } 255 if err := expectStringItem(keyStruct, "kty", "RSA"); err != nil { 256 return nil, err 257 } 258 if err := validateUseIsSig(keyStruct); err != nil { 259 return nil, err 260 } 261 if err := validateKeyOPSIsVerify(keyStruct); err != nil { 262 return nil, err 263 } 264 e, err := decodeItem(keyStruct, "e") 265 if err != nil { 266 return nil, err 267 } 268 n, err := decodeItem(keyStruct, "n") 269 if err != nil { 270 return nil, err 271 } 272 var customKID *string = nil 273 if hasItem(keyStruct, "kid") { 274 kid, err := stringItem(keyStruct, "kid") 275 if err != nil { 276 return nil, err 277 } 278 customKID = &kid 279 } 280 return &rsaPubKey{ 281 exponent: e, 282 modulus: n, 283 customKID: customKID, 284 }, nil 285} 286 287func esPublicKeyDataFromStruct(keyStruct *spb.Struct) (*tinkpb.KeyData, error) { 288 alg, err := stringItem(keyStruct, "alg") 289 if err != nil { 290 return nil, err 291 } 292 curve, err := stringItem(keyStruct, "crv") 293 if err != nil { 294 return nil, err 295 } 296 var algorithm jepb.JwtEcdsaAlgorithm = jepb.JwtEcdsaAlgorithm_ES_UNKNOWN 297 if alg == "ES256" && curve == "P-256" { 298 algorithm = jepb.JwtEcdsaAlgorithm_ES256 299 } 300 if alg == "ES384" && curve == "P-384" { 301 algorithm = jepb.JwtEcdsaAlgorithm_ES384 302 } 303 if alg == "ES512" && curve == "P-521" { 304 algorithm = jepb.JwtEcdsaAlgorithm_ES512 305 } 306 if algorithm == jepb.JwtEcdsaAlgorithm_ES_UNKNOWN { 307 return nil, fmt.Errorf("invalid algorithm %q and curve %q", alg, curve) 308 } 309 if hasItem(keyStruct, "d") { 310 return nil, fmt.Errorf("private keys cannot be converted") 311 } 312 if err := expectStringItem(keyStruct, "kty", "EC"); err != nil { 313 return nil, err 314 } 315 if err := validateUseIsSig(keyStruct); err != nil { 316 return nil, err 317 } 318 if err := validateKeyOPSIsVerify(keyStruct); err != nil { 319 return nil, err 320 } 321 x, err := decodeItem(keyStruct, "x") 322 if err != nil { 323 return nil, fmt.Errorf("failed to decode x: %v", err) 324 } 325 y, err := decodeItem(keyStruct, "y") 326 if err != nil { 327 return nil, fmt.Errorf("failed to decode y: %v", err) 328 } 329 var customKID *jepb.JwtEcdsaPublicKey_CustomKid = nil 330 if hasItem(keyStruct, "kid") { 331 kid, err := stringItem(keyStruct, "kid") 332 if err != nil { 333 return nil, err 334 } 335 customKID = &jepb.JwtEcdsaPublicKey_CustomKid{Value: kid} 336 } 337 pubKey := &jepb.JwtEcdsaPublicKey{ 338 Version: 0, 339 Algorithm: algorithm, 340 X: x, 341 Y: y, 342 CustomKid: customKID, 343 } 344 serializedPubKey, err := proto.Marshal(pubKey) 345 if err != nil { 346 return nil, err 347 } 348 return &tinkpb.KeyData{ 349 TypeUrl: jwtECDSAPublicKeyType, 350 Value: serializedPubKey, 351 KeyMaterialType: tinkpb.KeyData_ASYMMETRIC_PUBLIC, 352 }, nil 353} 354 355func keysetKeyFromStruct(val *spb.Value, keyID uint32) (*tinkpb.Keyset_Key, error) { 356 keyStruct := val.GetStructValue() 357 if keyStruct == nil { 358 return nil, fmt.Errorf("key is not a JSON object") 359 } 360 algPrefix, err := algorithmPrefix(keyStruct) 361 if err != nil { 362 return nil, err 363 } 364 var keyData *tinkpb.KeyData 365 switch algPrefix { 366 case "ES": 367 keyData, err = esPublicKeyDataFromStruct(keyStruct) 368 case "RS": 369 keyData, err = rsPublicKeyDataFromStruct(keyStruct) 370 case "PS": 371 keyData, err = psPublicKeyDataFromStruct(keyStruct) 372 default: 373 return nil, fmt.Errorf("unsupported algorithm prefix: %v", algPrefix) 374 } 375 if err != nil { 376 return nil, err 377 } 378 return &tinkpb.Keyset_Key{ 379 KeyData: keyData, 380 Status: tinkpb.KeyStatusType_ENABLED, 381 OutputPrefixType: tinkpb.OutputPrefixType_RAW, 382 KeyId: keyID, 383 }, nil 384} 385 386// JWKSetToPublicKeysetHandle converts a Json Web Key (JWK) set into a Tink KeysetHandle. 387// It requires that all keys in the set have the "alg" field set. Currently, only 388// public keys for algorithms ES256, ES384, ES512, RS256, RS384, and RS512 are supported. 389// JWK is defined in https://www.rfc-editor.org/rfc/rfc7517.txt. 390func JWKSetToPublicKeysetHandle(jwkSet []byte) (*keyset.Handle, error) { 391 jwk := &spb.Struct{} 392 if err := jwk.UnmarshalJSON(jwkSet); err != nil { 393 return nil, err 394 } 395 keyList, err := listValue(jwk, "keys") 396 if err != nil { 397 return nil, err 398 } 399 400 ks := &tinkpb.Keyset{} 401 for _, keyStruct := range keyList.GetValues() { 402 key, err := keysetKeyFromStruct(keyStruct, generateUnusedID(ks)) 403 if err != nil { 404 return nil, err 405 } 406 ks.Key = append(ks.Key, key) 407 } 408 ks.PrimaryKeyId = ks.Key[len(ks.Key)-1].GetKeyId() 409 return keyset.NewHandleWithNoSecrets(ks) 410} 411 412func addKeyOPSVerify(s *spb.Struct) { 413 s.GetFields()["key_ops"] = spb.NewListValue(&spb.ListValue{Values: []*spb.Value{spb.NewStringValue("verify")}}) 414} 415 416func addStringEntry(s *spb.Struct, key, val string) { 417 s.GetFields()[key] = spb.NewStringValue(val) 418} 419 420var psAlgToStr map[jrpsspb.JwtRsaSsaPssAlgorithm]string = map[jrpsspb.JwtRsaSsaPssAlgorithm]string{ 421 jrpsspb.JwtRsaSsaPssAlgorithm_PS256: "PS256", 422 jrpsspb.JwtRsaSsaPssAlgorithm_PS384: "PS384", 423 jrpsspb.JwtRsaSsaPssAlgorithm_PS512: "PS512", 424} 425 426func psPublicKeyToStruct(key *tinkpb.Keyset_Key) (*spb.Struct, error) { 427 pubKey := &jrpsspb.JwtRsaSsaPssPublicKey{} 428 if err := proto.Unmarshal(key.GetKeyData().GetValue(), pubKey); err != nil { 429 return nil, err 430 } 431 alg, ok := psAlgToStr[pubKey.GetAlgorithm()] 432 if !ok { 433 return nil, fmt.Errorf("invalid algorithm") 434 } 435 outKey := &spb.Struct{ 436 Fields: map[string]*spb.Value{}, 437 } 438 addStringEntry(outKey, "alg", alg) 439 addStringEntry(outKey, "kty", "RSA") 440 addStringEntry(outKey, "e", base64Encode(pubKey.GetE())) 441 addStringEntry(outKey, "n", base64Encode(pubKey.GetN())) 442 addStringEntry(outKey, "use", "sig") 443 addKeyOPSVerify(outKey) 444 var customKID *string = nil 445 if pubKey.GetCustomKid() != nil { 446 ck := pubKey.GetCustomKid().GetValue() 447 customKID = &ck 448 } 449 if err := setKeyID(outKey, key, customKID); err != nil { 450 return nil, err 451 } 452 return outKey, nil 453} 454 455var rsAlgToStr map[jrsppb.JwtRsaSsaPkcs1Algorithm]string = map[jrsppb.JwtRsaSsaPkcs1Algorithm]string{ 456 jrsppb.JwtRsaSsaPkcs1Algorithm_RS256: "RS256", 457 jrsppb.JwtRsaSsaPkcs1Algorithm_RS384: "RS384", 458 jrsppb.JwtRsaSsaPkcs1Algorithm_RS512: "RS512", 459} 460 461func rsPublicKeyToStruct(key *tinkpb.Keyset_Key) (*spb.Struct, error) { 462 pubKey := &jrsppb.JwtRsaSsaPkcs1PublicKey{} 463 if err := proto.Unmarshal(key.GetKeyData().GetValue(), pubKey); err != nil { 464 return nil, err 465 } 466 alg, ok := rsAlgToStr[pubKey.GetAlgorithm()] 467 if !ok { 468 return nil, fmt.Errorf("invalid algorithm") 469 } 470 outKey := &spb.Struct{ 471 Fields: map[string]*spb.Value{}, 472 } 473 addStringEntry(outKey, "alg", alg) 474 addStringEntry(outKey, "kty", "RSA") 475 addStringEntry(outKey, "e", base64Encode(pubKey.GetE())) 476 addStringEntry(outKey, "n", base64Encode(pubKey.GetN())) 477 addStringEntry(outKey, "use", "sig") 478 addKeyOPSVerify(outKey) 479 480 var customKID *string = nil 481 if pubKey.GetCustomKid() != nil { 482 ck := pubKey.GetCustomKid().GetValue() 483 customKID = &ck 484 } 485 if err := setKeyID(outKey, key, customKID); err != nil { 486 return nil, err 487 } 488 return outKey, nil 489} 490 491func esPublicKeyToStruct(key *tinkpb.Keyset_Key) (*spb.Struct, error) { 492 pubKey := &jepb.JwtEcdsaPublicKey{} 493 if err := proto.Unmarshal(key.GetKeyData().GetValue(), pubKey); err != nil { 494 return nil, err 495 } 496 outKey := &spb.Struct{ 497 Fields: map[string]*spb.Value{}, 498 } 499 var algorithm, curve string 500 switch pubKey.GetAlgorithm() { 501 case jepb.JwtEcdsaAlgorithm_ES256: 502 curve, algorithm = "P-256", "ES256" 503 case jepb.JwtEcdsaAlgorithm_ES384: 504 curve, algorithm = "P-384", "ES384" 505 case jepb.JwtEcdsaAlgorithm_ES512: 506 curve, algorithm = "P-521", "ES512" 507 default: 508 return nil, fmt.Errorf("invalid algorithm") 509 } 510 addStringEntry(outKey, "crv", curve) 511 addStringEntry(outKey, "alg", algorithm) 512 addStringEntry(outKey, "kty", "EC") 513 addStringEntry(outKey, "x", base64Encode(pubKey.GetX())) 514 addStringEntry(outKey, "y", base64Encode(pubKey.GetY())) 515 addStringEntry(outKey, "use", "sig") 516 addKeyOPSVerify(outKey) 517 518 var customKID *string = nil 519 if pubKey.GetCustomKid() != nil { 520 ck := pubKey.GetCustomKid().GetValue() 521 customKID = &ck 522 } 523 if err := setKeyID(outKey, key, customKID); err != nil { 524 return nil, err 525 } 526 return outKey, nil 527} 528 529func setKeyID(outKey *spb.Struct, key *tinkpb.Keyset_Key, customKID *string) error { 530 if key.GetOutputPrefixType() == tinkpb.OutputPrefixType_TINK { 531 if customKID != nil { 532 return fmt.Errorf("TINK keys shouldn't have custom KID") 533 } 534 kid := keyID(key.KeyId, key.GetOutputPrefixType()) 535 if kid == nil { 536 return fmt.Errorf("tink KID shouldn't be nil") 537 } 538 addStringEntry(outKey, "kid", *kid) 539 } else if customKID != nil { 540 addStringEntry(outKey, "kid", *customKID) 541 } 542 return nil 543} 544 545// JWKSetFromPublicKeysetHandle converts a Tink KeysetHandle with JWT keys into a Json Web Key (JWK) set. 546// Currently only public keys for algorithms ES256, ES384, ES512, RS256, RS384, and RS512 are supported. 547// JWK is defined in https://www.rfc-editor.org/rfc/rfc7517.html. 548func JWKSetFromPublicKeysetHandle(kh *keyset.Handle) ([]byte, error) { 549 b := &bytes.Buffer{} 550 if err := kh.WriteWithNoSecrets(keyset.NewBinaryWriter(b)); err != nil { 551 return nil, err 552 } 553 ks := &tinkpb.Keyset{} 554 if err := proto.Unmarshal(b.Bytes(), ks); err != nil { 555 return nil, err 556 } 557 keyValList := []*spb.Value{} 558 for _, k := range ks.Key { 559 if k.GetStatus() != tinkpb.KeyStatusType_ENABLED { 560 continue 561 } 562 if k.GetOutputPrefixType() != tinkpb.OutputPrefixType_TINK && 563 k.GetOutputPrefixType() != tinkpb.OutputPrefixType_RAW { 564 return nil, fmt.Errorf("unsupported output prefix type") 565 } 566 keyData := k.GetKeyData() 567 if keyData == nil { 568 return nil, fmt.Errorf("invalid key data") 569 } 570 if keyData.GetKeyMaterialType() != tinkpb.KeyData_ASYMMETRIC_PUBLIC { 571 return nil, fmt.Errorf("only asymmetric public keys are supported") 572 } 573 keyStruct := &spb.Struct{} 574 var err error 575 switch keyData.GetTypeUrl() { 576 case jwtECDSAPublicKeyType: 577 keyStruct, err = esPublicKeyToStruct(k) 578 case jwtRSPublicKeyType: 579 keyStruct, err = rsPublicKeyToStruct(k) 580 case jwtPSPublicKeyType: 581 keyStruct, err = psPublicKeyToStruct(k) 582 default: 583 return nil, fmt.Errorf("unsupported key type url") 584 } 585 if err != nil { 586 return nil, err 587 } 588 keyValList = append(keyValList, spb.NewStructValue(keyStruct)) 589 } 590 output := &spb.Struct{ 591 Fields: map[string]*spb.Value{ 592 "keys": spb.NewListValue(&spb.ListValue{Values: keyValList}), 593 }, 594 } 595 return output.MarshalJSON() 596} 597