xref: /aosp_15_r20/external/tink/go/jwt/jwk_converter.go (revision e7b1675dde1b92d52ec075b0a92829627f2c52a5)
1// Copyright 2022 Google LLC
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//      http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14//
15////////////////////////////////////////////////////////////////////////////////
16
17package jwt
18
19import (
20	"bytes"
21	"fmt"
22	"math/rand"
23
24	spb "google.golang.org/protobuf/types/known/structpb"
25	"google.golang.org/protobuf/proto"
26	"github.com/google/tink/go/keyset"
27	jepb "github.com/google/tink/go/proto/jwt_ecdsa_go_proto"
28	jrsppb "github.com/google/tink/go/proto/jwt_rsa_ssa_pkcs1_go_proto"
29	jrpsspb "github.com/google/tink/go/proto/jwt_rsa_ssa_pss_go_proto"
30	tinkpb "github.com/google/tink/go/proto/tink_go_proto"
31)
32
// Type URLs of the JWT public-key protos that this file converts to and from
// JSON Web Keys.
const (
	jwtRSPublicKeyType    = "type.googleapis.com/google.crypto.tink.JwtRsaSsaPkcs1PublicKey"
	jwtPSPublicKeyType    = "type.googleapis.com/google.crypto.tink.JwtRsaSsaPssPublicKey"
	jwtECDSAPublicKeyType = "type.googleapis.com/google.crypto.tink.JwtEcdsaPublicKey"
)
37)
38
39func keysetHasID(ks *tinkpb.Keyset, keyID uint32) bool {
40	for _, k := range ks.GetKey() {
41		if k.GetKeyId() == keyID {
42			return true
43		}
44	}
45	return false
46}
47
48func generateUnusedID(ks *tinkpb.Keyset) uint32 {
49	for {
50		keyID := rand.Uint32()
51		if !keysetHasID(ks, keyID) {
52			return keyID
53		}
54	}
55}
56
57func hasItem(s *spb.Struct, name string) bool {
58	if s.GetFields() == nil {
59		return false
60	}
61	_, ok := s.Fields[name]
62	return ok
63}
64
65func stringItem(s *spb.Struct, name string) (string, error) {
66	fields := s.GetFields()
67	if fields == nil {
68		return "", fmt.Errorf("no fields")
69	}
70	val, ok := fields[name]
71	if !ok {
72		return "", fmt.Errorf("field %q not found", name)
73	}
74	r, ok := val.Kind.(*spb.Value_StringValue)
75	if !ok {
76		return "", fmt.Errorf("field %q is not a string", name)
77	}
78	return r.StringValue, nil
79}
80
81func listValue(s *spb.Struct, name string) (*spb.ListValue, error) {
82	fields := s.GetFields()
83	if fields == nil {
84		return nil, fmt.Errorf("empty set")
85	}
86	vals, ok := fields[name]
87	if !ok {
88		return nil, fmt.Errorf("%q not found", name)
89	}
90	list, ok := vals.Kind.(*spb.Value_ListValue)
91	if !ok {
92		return nil, fmt.Errorf("%q is not a list", name)
93	}
94	if list.ListValue == nil || len(list.ListValue.GetValues()) == 0 {
95		return nil, fmt.Errorf("%q list is empty", name)
96	}
97	return list.ListValue, nil
98}
99
100func expectStringItem(s *spb.Struct, name, value string) error {
101	item, err := stringItem(s, name)
102	if err != nil {
103		return err
104	}
105	if item != value {
106		return fmt.Errorf("unexpected value %q for %q", value, name)
107	}
108	return nil
109}
110
111func decodeItem(s *spb.Struct, name string) ([]byte, error) {
112	e, err := stringItem(s, name)
113	if err != nil {
114		return nil, err
115	}
116	return base64Decode(e)
117}
118
119func validateKeyOPSIsVerify(s *spb.Struct) error {
120	if !hasItem(s, "key_ops") {
121		return nil
122	}
123	keyOPSList, err := listValue(s, "key_ops")
124	if err != nil {
125		return err
126	}
127	if len(keyOPSList.GetValues()) != 1 {
128		return fmt.Errorf("key_ops size is not 1")
129	}
130	value, ok := keyOPSList.GetValues()[0].Kind.(*spb.Value_StringValue)
131	if !ok {
132		return fmt.Errorf("key_ops is not a string")
133	}
134	if value.StringValue != "verify" {
135		return fmt.Errorf("key_ops is not equal to [\"verify\"]")
136	}
137	return nil
138}
139
140func validateUseIsSig(s *spb.Struct) error {
141	if !hasItem(s, "use") {
142		return nil
143	}
144	return expectStringItem(s, "use", "sig")
145}
146
147func algorithmPrefix(s *spb.Struct) (string, error) {
148	alg, err := stringItem(s, "alg")
149	if err != nil {
150		return "", err
151	}
152	if len(alg) < 2 {
153		return "", fmt.Errorf("invalid algorithm")
154	}
155	return alg[0:2], nil
156}
157
158var psNameToAlg = map[string]jrpsspb.JwtRsaSsaPssAlgorithm{
159	"PS256": jrpsspb.JwtRsaSsaPssAlgorithm_PS256,
160	"PS384": jrpsspb.JwtRsaSsaPssAlgorithm_PS384,
161	"PS512": jrpsspb.JwtRsaSsaPssAlgorithm_PS512,
162}
163
164func psPublicKeyDataFromStruct(keyStruct *spb.Struct) (*tinkpb.KeyData, error) {
165	alg, err := stringItem(keyStruct, "alg")
166	if err != nil {
167		return nil, err
168	}
169	algorithm, ok := psNameToAlg[alg]
170	if !ok {
171		return nil, fmt.Errorf("invalid alg header: %q", alg)
172	}
173	rsaPubKey, err := rsaPubKeyFromStruct(keyStruct)
174	if err != nil {
175		return nil, err
176	}
177	jwtPubKey := &jrpsspb.JwtRsaSsaPssPublicKey{
178		Version:   jwtECDSASignerKeyVersion,
179		Algorithm: algorithm,
180		E:         rsaPubKey.exponent,
181		N:         rsaPubKey.modulus,
182	}
183	if rsaPubKey.customKID != nil {
184		jwtPubKey.CustomKid = &jrpsspb.JwtRsaSsaPssPublicKey_CustomKid{
185			Value: *rsaPubKey.customKID,
186		}
187	}
188	serializedPubKey, err := proto.Marshal(jwtPubKey)
189	if err != nil {
190		return nil, err
191	}
192	return &tinkpb.KeyData{
193		TypeUrl:         jwtPSPublicKeyType,
194		Value:           serializedPubKey,
195		KeyMaterialType: tinkpb.KeyData_ASYMMETRIC_PUBLIC,
196	}, nil
197}
198
199var rsNameToAlg = map[string]jrsppb.JwtRsaSsaPkcs1Algorithm{
200	"RS256": jrsppb.JwtRsaSsaPkcs1Algorithm_RS256,
201	"RS384": jrsppb.JwtRsaSsaPkcs1Algorithm_RS384,
202	"RS512": jrsppb.JwtRsaSsaPkcs1Algorithm_RS512,
203}
204
205func rsPublicKeyDataFromStruct(keyStruct *spb.Struct) (*tinkpb.KeyData, error) {
206	alg, err := stringItem(keyStruct, "alg")
207	if err != nil {
208		return nil, err
209	}
210	algorithm, ok := rsNameToAlg[alg]
211	if !ok {
212		return nil, fmt.Errorf("invalid alg header: %q", alg)
213	}
214	rsaPubKey, err := rsaPubKeyFromStruct(keyStruct)
215	if err != nil {
216		return nil, err
217	}
218	jwtPubKey := &jrsppb.JwtRsaSsaPkcs1PublicKey{
219		Version:   0,
220		Algorithm: algorithm,
221		E:         rsaPubKey.exponent,
222		N:         rsaPubKey.modulus,
223	}
224	if rsaPubKey.customKID != nil {
225		jwtPubKey.CustomKid = &jrsppb.JwtRsaSsaPkcs1PublicKey_CustomKid{
226			Value: *rsaPubKey.customKID,
227		}
228	}
229	serializedPubKey, err := proto.Marshal(jwtPubKey)
230	if err != nil {
231		return nil, err
232	}
233	return &tinkpb.KeyData{
234		TypeUrl:         jwtRSPublicKeyType,
235		Value:           serializedPubKey,
236		KeyMaterialType: tinkpb.KeyData_ASYMMETRIC_PUBLIC,
237	}, nil
238}
239
240type rsaPubKey struct {
241	exponent  []byte
242	modulus   []byte
243	customKID *string
244}
245
246func rsaPubKeyFromStruct(keyStruct *spb.Struct) (*rsaPubKey, error) {
247	if hasItem(keyStruct, "p") ||
248		hasItem(keyStruct, "q") ||
249		hasItem(keyStruct, "dq") ||
250		hasItem(keyStruct, "dp") ||
251		hasItem(keyStruct, "d") ||
252		hasItem(keyStruct, "qi") {
253		return nil, fmt.Errorf("private key can't be converted")
254	}
255	if err := expectStringItem(keyStruct, "kty", "RSA"); err != nil {
256		return nil, err
257	}
258	if err := validateUseIsSig(keyStruct); err != nil {
259		return nil, err
260	}
261	if err := validateKeyOPSIsVerify(keyStruct); err != nil {
262		return nil, err
263	}
264	e, err := decodeItem(keyStruct, "e")
265	if err != nil {
266		return nil, err
267	}
268	n, err := decodeItem(keyStruct, "n")
269	if err != nil {
270		return nil, err
271	}
272	var customKID *string = nil
273	if hasItem(keyStruct, "kid") {
274		kid, err := stringItem(keyStruct, "kid")
275		if err != nil {
276			return nil, err
277		}
278		customKID = &kid
279	}
280	return &rsaPubKey{
281		exponent:  e,
282		modulus:   n,
283		customKID: customKID,
284	}, nil
285}
286
287func esPublicKeyDataFromStruct(keyStruct *spb.Struct) (*tinkpb.KeyData, error) {
288	alg, err := stringItem(keyStruct, "alg")
289	if err != nil {
290		return nil, err
291	}
292	curve, err := stringItem(keyStruct, "crv")
293	if err != nil {
294		return nil, err
295	}
296	var algorithm jepb.JwtEcdsaAlgorithm = jepb.JwtEcdsaAlgorithm_ES_UNKNOWN
297	if alg == "ES256" && curve == "P-256" {
298		algorithm = jepb.JwtEcdsaAlgorithm_ES256
299	}
300	if alg == "ES384" && curve == "P-384" {
301		algorithm = jepb.JwtEcdsaAlgorithm_ES384
302	}
303	if alg == "ES512" && curve == "P-521" {
304		algorithm = jepb.JwtEcdsaAlgorithm_ES512
305	}
306	if algorithm == jepb.JwtEcdsaAlgorithm_ES_UNKNOWN {
307		return nil, fmt.Errorf("invalid algorithm %q and curve %q", alg, curve)
308	}
309	if hasItem(keyStruct, "d") {
310		return nil, fmt.Errorf("private keys cannot be converted")
311	}
312	if err := expectStringItem(keyStruct, "kty", "EC"); err != nil {
313		return nil, err
314	}
315	if err := validateUseIsSig(keyStruct); err != nil {
316		return nil, err
317	}
318	if err := validateKeyOPSIsVerify(keyStruct); err != nil {
319		return nil, err
320	}
321	x, err := decodeItem(keyStruct, "x")
322	if err != nil {
323		return nil, fmt.Errorf("failed to decode x: %v", err)
324	}
325	y, err := decodeItem(keyStruct, "y")
326	if err != nil {
327		return nil, fmt.Errorf("failed to decode y: %v", err)
328	}
329	var customKID *jepb.JwtEcdsaPublicKey_CustomKid = nil
330	if hasItem(keyStruct, "kid") {
331		kid, err := stringItem(keyStruct, "kid")
332		if err != nil {
333			return nil, err
334		}
335		customKID = &jepb.JwtEcdsaPublicKey_CustomKid{Value: kid}
336	}
337	pubKey := &jepb.JwtEcdsaPublicKey{
338		Version:   0,
339		Algorithm: algorithm,
340		X:         x,
341		Y:         y,
342		CustomKid: customKID,
343	}
344	serializedPubKey, err := proto.Marshal(pubKey)
345	if err != nil {
346		return nil, err
347	}
348	return &tinkpb.KeyData{
349		TypeUrl:         jwtECDSAPublicKeyType,
350		Value:           serializedPubKey,
351		KeyMaterialType: tinkpb.KeyData_ASYMMETRIC_PUBLIC,
352	}, nil
353}
354
355func keysetKeyFromStruct(val *spb.Value, keyID uint32) (*tinkpb.Keyset_Key, error) {
356	keyStruct := val.GetStructValue()
357	if keyStruct == nil {
358		return nil, fmt.Errorf("key is not a JSON object")
359	}
360	algPrefix, err := algorithmPrefix(keyStruct)
361	if err != nil {
362		return nil, err
363	}
364	var keyData *tinkpb.KeyData
365	switch algPrefix {
366	case "ES":
367		keyData, err = esPublicKeyDataFromStruct(keyStruct)
368	case "RS":
369		keyData, err = rsPublicKeyDataFromStruct(keyStruct)
370	case "PS":
371		keyData, err = psPublicKeyDataFromStruct(keyStruct)
372	default:
373		return nil, fmt.Errorf("unsupported algorithm prefix: %v", algPrefix)
374	}
375	if err != nil {
376		return nil, err
377	}
378	return &tinkpb.Keyset_Key{
379		KeyData:          keyData,
380		Status:           tinkpb.KeyStatusType_ENABLED,
381		OutputPrefixType: tinkpb.OutputPrefixType_RAW,
382		KeyId:            keyID,
383	}, nil
384}
385
386// JWKSetToPublicKeysetHandle converts a Json Web Key (JWK) set into a Tink KeysetHandle.
387// It requires that all keys in the set have the "alg" field set. Currently, only
388// public keys for algorithms ES256, ES384, ES512, RS256, RS384, and RS512 are supported.
389// JWK is defined in https://www.rfc-editor.org/rfc/rfc7517.txt.
390func JWKSetToPublicKeysetHandle(jwkSet []byte) (*keyset.Handle, error) {
391	jwk := &spb.Struct{}
392	if err := jwk.UnmarshalJSON(jwkSet); err != nil {
393		return nil, err
394	}
395	keyList, err := listValue(jwk, "keys")
396	if err != nil {
397		return nil, err
398	}
399
400	ks := &tinkpb.Keyset{}
401	for _, keyStruct := range keyList.GetValues() {
402		key, err := keysetKeyFromStruct(keyStruct, generateUnusedID(ks))
403		if err != nil {
404			return nil, err
405		}
406		ks.Key = append(ks.Key, key)
407	}
408	ks.PrimaryKeyId = ks.Key[len(ks.Key)-1].GetKeyId()
409	return keyset.NewHandleWithNoSecrets(ks)
410}
411
412func addKeyOPSVerify(s *spb.Struct) {
413	s.GetFields()["key_ops"] = spb.NewListValue(&spb.ListValue{Values: []*spb.Value{spb.NewStringValue("verify")}})
414}
415
416func addStringEntry(s *spb.Struct, key, val string) {
417	s.GetFields()[key] = spb.NewStringValue(val)
418}
419
420var psAlgToStr map[jrpsspb.JwtRsaSsaPssAlgorithm]string = map[jrpsspb.JwtRsaSsaPssAlgorithm]string{
421	jrpsspb.JwtRsaSsaPssAlgorithm_PS256: "PS256",
422	jrpsspb.JwtRsaSsaPssAlgorithm_PS384: "PS384",
423	jrpsspb.JwtRsaSsaPssAlgorithm_PS512: "PS512",
424}
425
426func psPublicKeyToStruct(key *tinkpb.Keyset_Key) (*spb.Struct, error) {
427	pubKey := &jrpsspb.JwtRsaSsaPssPublicKey{}
428	if err := proto.Unmarshal(key.GetKeyData().GetValue(), pubKey); err != nil {
429		return nil, err
430	}
431	alg, ok := psAlgToStr[pubKey.GetAlgorithm()]
432	if !ok {
433		return nil, fmt.Errorf("invalid algorithm")
434	}
435	outKey := &spb.Struct{
436		Fields: map[string]*spb.Value{},
437	}
438	addStringEntry(outKey, "alg", alg)
439	addStringEntry(outKey, "kty", "RSA")
440	addStringEntry(outKey, "e", base64Encode(pubKey.GetE()))
441	addStringEntry(outKey, "n", base64Encode(pubKey.GetN()))
442	addStringEntry(outKey, "use", "sig")
443	addKeyOPSVerify(outKey)
444	var customKID *string = nil
445	if pubKey.GetCustomKid() != nil {
446		ck := pubKey.GetCustomKid().GetValue()
447		customKID = &ck
448	}
449	if err := setKeyID(outKey, key, customKID); err != nil {
450		return nil, err
451	}
452	return outKey, nil
453}
454
455var rsAlgToStr map[jrsppb.JwtRsaSsaPkcs1Algorithm]string = map[jrsppb.JwtRsaSsaPkcs1Algorithm]string{
456	jrsppb.JwtRsaSsaPkcs1Algorithm_RS256: "RS256",
457	jrsppb.JwtRsaSsaPkcs1Algorithm_RS384: "RS384",
458	jrsppb.JwtRsaSsaPkcs1Algorithm_RS512: "RS512",
459}
460
461func rsPublicKeyToStruct(key *tinkpb.Keyset_Key) (*spb.Struct, error) {
462	pubKey := &jrsppb.JwtRsaSsaPkcs1PublicKey{}
463	if err := proto.Unmarshal(key.GetKeyData().GetValue(), pubKey); err != nil {
464		return nil, err
465	}
466	alg, ok := rsAlgToStr[pubKey.GetAlgorithm()]
467	if !ok {
468		return nil, fmt.Errorf("invalid algorithm")
469	}
470	outKey := &spb.Struct{
471		Fields: map[string]*spb.Value{},
472	}
473	addStringEntry(outKey, "alg", alg)
474	addStringEntry(outKey, "kty", "RSA")
475	addStringEntry(outKey, "e", base64Encode(pubKey.GetE()))
476	addStringEntry(outKey, "n", base64Encode(pubKey.GetN()))
477	addStringEntry(outKey, "use", "sig")
478	addKeyOPSVerify(outKey)
479
480	var customKID *string = nil
481	if pubKey.GetCustomKid() != nil {
482		ck := pubKey.GetCustomKid().GetValue()
483		customKID = &ck
484	}
485	if err := setKeyID(outKey, key, customKID); err != nil {
486		return nil, err
487	}
488	return outKey, nil
489}
490
491func esPublicKeyToStruct(key *tinkpb.Keyset_Key) (*spb.Struct, error) {
492	pubKey := &jepb.JwtEcdsaPublicKey{}
493	if err := proto.Unmarshal(key.GetKeyData().GetValue(), pubKey); err != nil {
494		return nil, err
495	}
496	outKey := &spb.Struct{
497		Fields: map[string]*spb.Value{},
498	}
499	var algorithm, curve string
500	switch pubKey.GetAlgorithm() {
501	case jepb.JwtEcdsaAlgorithm_ES256:
502		curve, algorithm = "P-256", "ES256"
503	case jepb.JwtEcdsaAlgorithm_ES384:
504		curve, algorithm = "P-384", "ES384"
505	case jepb.JwtEcdsaAlgorithm_ES512:
506		curve, algorithm = "P-521", "ES512"
507	default:
508		return nil, fmt.Errorf("invalid algorithm")
509	}
510	addStringEntry(outKey, "crv", curve)
511	addStringEntry(outKey, "alg", algorithm)
512	addStringEntry(outKey, "kty", "EC")
513	addStringEntry(outKey, "x", base64Encode(pubKey.GetX()))
514	addStringEntry(outKey, "y", base64Encode(pubKey.GetY()))
515	addStringEntry(outKey, "use", "sig")
516	addKeyOPSVerify(outKey)
517
518	var customKID *string = nil
519	if pubKey.GetCustomKid() != nil {
520		ck := pubKey.GetCustomKid().GetValue()
521		customKID = &ck
522	}
523	if err := setKeyID(outKey, key, customKID); err != nil {
524		return nil, err
525	}
526	return outKey, nil
527}
528
529func setKeyID(outKey *spb.Struct, key *tinkpb.Keyset_Key, customKID *string) error {
530	if key.GetOutputPrefixType() == tinkpb.OutputPrefixType_TINK {
531		if customKID != nil {
532			return fmt.Errorf("TINK keys shouldn't have custom KID")
533		}
534		kid := keyID(key.KeyId, key.GetOutputPrefixType())
535		if kid == nil {
536			return fmt.Errorf("tink KID shouldn't be nil")
537		}
538		addStringEntry(outKey, "kid", *kid)
539	} else if customKID != nil {
540		addStringEntry(outKey, "kid", *customKID)
541	}
542	return nil
543}
544
545// JWKSetFromPublicKeysetHandle converts a Tink KeysetHandle with JWT keys into a Json Web Key (JWK) set.
546// Currently only public keys for algorithms ES256, ES384, ES512, RS256, RS384, and RS512 are supported.
547// JWK is defined in https://www.rfc-editor.org/rfc/rfc7517.html.
548func JWKSetFromPublicKeysetHandle(kh *keyset.Handle) ([]byte, error) {
549	b := &bytes.Buffer{}
550	if err := kh.WriteWithNoSecrets(keyset.NewBinaryWriter(b)); err != nil {
551		return nil, err
552	}
553	ks := &tinkpb.Keyset{}
554	if err := proto.Unmarshal(b.Bytes(), ks); err != nil {
555		return nil, err
556	}
557	keyValList := []*spb.Value{}
558	for _, k := range ks.Key {
559		if k.GetStatus() != tinkpb.KeyStatusType_ENABLED {
560			continue
561		}
562		if k.GetOutputPrefixType() != tinkpb.OutputPrefixType_TINK &&
563			k.GetOutputPrefixType() != tinkpb.OutputPrefixType_RAW {
564			return nil, fmt.Errorf("unsupported output prefix type")
565		}
566		keyData := k.GetKeyData()
567		if keyData == nil {
568			return nil, fmt.Errorf("invalid key data")
569		}
570		if keyData.GetKeyMaterialType() != tinkpb.KeyData_ASYMMETRIC_PUBLIC {
571			return nil, fmt.Errorf("only asymmetric public keys are supported")
572		}
573		keyStruct := &spb.Struct{}
574		var err error
575		switch keyData.GetTypeUrl() {
576		case jwtECDSAPublicKeyType:
577			keyStruct, err = esPublicKeyToStruct(k)
578		case jwtRSPublicKeyType:
579			keyStruct, err = rsPublicKeyToStruct(k)
580		case jwtPSPublicKeyType:
581			keyStruct, err = psPublicKeyToStruct(k)
582		default:
583			return nil, fmt.Errorf("unsupported key type url")
584		}
585		if err != nil {
586			return nil, err
587		}
588		keyValList = append(keyValList, spb.NewStructValue(keyStruct))
589	}
590	output := &spb.Struct{
591		Fields: map[string]*spb.Value{
592			"keys": spb.NewListValue(&spb.ListValue{Values: keyValList}),
593		},
594	}
595	return output.MarshalJSON()
596}
597