xref: /aosp_15_r20/external/tink/go/jwt/jwt_rsa_ssa_pkcs1_signer_key_manager.go (revision e7b1675dde1b92d52ec075b0a92829627f2c52a5)
// Copyright 2022 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
////////////////////////////////////////////////////////////////////////////////
16
17package jwt
18
19import (
20	"crypto/rand"
21	"crypto/rsa"
22	"errors"
23	"fmt"
24	"math/big"
25
26	"google.golang.org/protobuf/proto"
27	"github.com/google/tink/go/core/registry"
28	internal "github.com/google/tink/go/internal/signature"
29	"github.com/google/tink/go/keyset"
30	jrsppb "github.com/google/tink/go/proto/jwt_rsa_ssa_pkcs1_go_proto"
31	tinkpb "github.com/google/tink/go/proto/tink_go_proto"
32)
33
const (
	// jwtRSSignerTypeURL is the type URL of the JwtRsaSsaPkcs1PrivateKey keys
	// handled by jwtRSSignerKeyManager.
	jwtRSSignerTypeURL = "type.googleapis.com/google.crypto.tink.JwtRsaSsaPkcs1PrivateKey"

	// jwtRSSignerKeyVersion is the maximum key and key-format version passed
	// to keyset.ValidateKeyVersion, and the version stamped on new keys.
	jwtRSSignerKeyVersion = 0
)

var (
	// errRSInvalidKeyFormat is the sentinel for an unusable serialized
	// JwtRsaSsaPkcs1KeyFormat.
	errRSInvalidKeyFormat = errors.New("invalid RSA SSA PKCS1 key format")

	// errRSInvalidPrivateKey is the sentinel for an unusable
	// JwtRsaSsaPkcs1PrivateKey.
	errRSInvalidPrivateKey = errors.New("invalid JwtRsaSsaPkcs1PrivateKey")
)
43
44// jwtRSSignerKeyManager implements the KeyManager interface
45// for JWT Signing using the 'RS256', 'RS384', and 'RS512' JWA algorithm.
46type jwtRSSignerKeyManager struct{}
47
48var _ registry.PrivateKeyManager = (*jwtRSSignerKeyManager)(nil)
49
// bytesToBigInt decodes v as an unsigned big-endian integer into a freshly
// allocated *big.Int. A nil or empty slice decodes to zero.
func bytesToBigInt(v []byte) *big.Int {
	n := new(big.Int)
	n.SetBytes(v)
	return n
}
53
54func (km *jwtRSSignerKeyManager) Primitive(serializedKey []byte) (interface{}, error) {
55	if serializedKey == nil {
56		return nil, fmt.Errorf("invalid JwtRsaSsaPkcs1PrivateKey")
57	}
58	privKey := &jrsppb.JwtRsaSsaPkcs1PrivateKey{}
59	if err := proto.Unmarshal(serializedKey, privKey); err != nil {
60		return nil, fmt.Errorf("failed to unmarshal RsaSsaPkcs1PrivateKey: %v", err)
61	}
62	if err := validateRSPrivateKey(privKey); err != nil {
63		return nil, err
64	}
65	rsaPrivKey := &rsa.PrivateKey{
66		PublicKey: rsa.PublicKey{
67			N: bytesToBigInt(privKey.GetPublicKey().GetN()),
68			E: int(bytesToBigInt(privKey.GetPublicKey().GetE()).Int64()),
69		},
70		D: bytesToBigInt(privKey.GetD()),
71		Primes: []*big.Int{
72			bytesToBigInt(privKey.GetP()),
73			bytesToBigInt(privKey.GetQ()),
74		},
75		Precomputed: rsa.PrecomputedValues{
76			Dp: bytesToBigInt(privKey.GetDp()),
77			Dq: bytesToBigInt(privKey.GetDq()),
78			// in crypto/rsa `GetCrt()` returns the "Chinese Remainder Theorem
79			// coefficient q^(-1) mod p. Which is `Qinv` in the tink proto and not
80			// the `CRTValues`.
81			Qinv: bytesToBigInt(privKey.GetCrt()),
82		},
83	}
84	alg := privKey.GetPublicKey().GetAlgorithm()
85	if err := internal.Validate_RSA_SSA_PKCS1(validRSAlgToHash[alg], rsaPrivKey); err != nil {
86		return nil, err
87	}
88	signer, err := internal.New_RSA_SSA_PKCS1_Signer(validRSAlgToHash[alg], rsaPrivKey)
89	if err != nil {
90		return nil, err
91	}
92	return newSignerWithKID(signer, alg.String(), rsCustomKID(privKey.GetPublicKey()))
93}
94
95func validateRSPrivateKey(privKey *jrsppb.JwtRsaSsaPkcs1PrivateKey) error {
96	if err := keyset.ValidateKeyVersion(privKey.Version, jwtRSSignerKeyVersion); err != nil {
97		return err
98	}
99	if privKey.GetD() == nil ||
100		len(privKey.GetPublicKey().GetN()) == 0 ||
101		len(privKey.GetPublicKey().GetE()) == 0 ||
102		privKey.GetP() == nil ||
103		privKey.GetQ() == nil ||
104		privKey.GetDp() == nil ||
105		privKey.GetDq() == nil ||
106		privKey.GetCrt() == nil {
107		return fmt.Errorf("invalid private key")
108	}
109	if err := validateRSPublicKey(privKey.GetPublicKey()); err != nil {
110		return err
111	}
112	return nil
113}
114
115func (km *jwtRSSignerKeyManager) NewKey(serializedKeyFormat []byte) (proto.Message, error) {
116	if len(serializedKeyFormat) == 0 {
117		return nil, errRSInvalidKeyFormat
118	}
119	keyFormat := &jrsppb.JwtRsaSsaPkcs1KeyFormat{}
120	if err := proto.Unmarshal(serializedKeyFormat, keyFormat); err != nil {
121		return nil, fmt.Errorf("failed to unmarshal JwtRsaSsaPkcs1KeyFormat: %v", err)
122	}
123	if err := keyset.ValidateKeyVersion(keyFormat.GetVersion(), jwtRSSignerKeyVersion); err != nil {
124		return nil, err
125	}
126	if keyFormat.GetVersion() != jwtRSSignerKeyVersion {
127		return nil, fmt.Errorf("invalid key format version: %d", keyFormat.GetVersion())
128	}
129	rsaKey, err := rsa.GenerateKey(rand.Reader, int(keyFormat.GetModulusSizeInBits()))
130	if err != nil {
131		return nil, err
132	}
133	privKey := &jrsppb.JwtRsaSsaPkcs1PrivateKey{
134		Version: jwtRSSignerKeyVersion,
135		PublicKey: &jrsppb.JwtRsaSsaPkcs1PublicKey{
136			Version:   jwtRSSignerKeyVersion,
137			Algorithm: keyFormat.GetAlgorithm(),
138			N:         rsaKey.PublicKey.N.Bytes(),
139			E:         keyFormat.GetPublicExponent(),
140		},
141		D:  rsaKey.D.Bytes(),
142		P:  rsaKey.Primes[0].Bytes(),
143		Q:  rsaKey.Primes[1].Bytes(),
144		Dp: rsaKey.Precomputed.Dp.Bytes(),
145		Dq: rsaKey.Precomputed.Dq.Bytes(),
146		// in crypto/rsa `GetCrt()` returns the "Chinese Remainder Theorem
147		// coefficient q^(-1) mod p. Which is `Qinv` in the tink proto and not
148		// the `CRTValues`.
149		Crt: rsaKey.Precomputed.Qinv.Bytes(),
150	}
151	if err := validateRSPrivateKey(privKey); err != nil {
152		return nil, err
153	}
154	return privKey, nil
155}
156
157func (km *jwtRSSignerKeyManager) NewKeyData(serializedKeyFormat []byte) (*tinkpb.KeyData, error) {
158	key, err := km.NewKey(serializedKeyFormat)
159	if err != nil {
160		return nil, err
161	}
162	serializedKey, err := proto.Marshal(key)
163	if err != nil {
164		return nil, err
165	}
166	return &tinkpb.KeyData{
167		TypeUrl:         jwtRSSignerTypeURL,
168		Value:           serializedKey,
169		KeyMaterialType: tinkpb.KeyData_ASYMMETRIC_PRIVATE,
170	}, nil
171}
172
173func (km *jwtRSSignerKeyManager) PublicKeyData(serializedPrivKey []byte) (*tinkpb.KeyData, error) {
174	if serializedPrivKey == nil {
175		return nil, errRSInvalidKeyFormat
176	}
177	privKey := &jrsppb.JwtRsaSsaPkcs1PrivateKey{}
178	if err := proto.Unmarshal(serializedPrivKey, privKey); err != nil {
179		return nil, fmt.Errorf("failed to unmarshal JwtRsaSsaPkcs1PrivateKey: %v", err)
180	}
181	if err := validateRSPrivateKey(privKey); err != nil {
182		return nil, err
183	}
184	serializedPubKey, err := proto.Marshal(privKey.GetPublicKey())
185	if err != nil {
186		return nil, err
187	}
188	return &tinkpb.KeyData{
189		TypeUrl:         jwtRSVerifierTypeURL,
190		Value:           serializedPubKey,
191		KeyMaterialType: tinkpb.KeyData_ASYMMETRIC_PUBLIC,
192	}, nil
193}
194
195func (km *jwtRSSignerKeyManager) DoesSupport(typeURL string) bool {
196	return jwtRSSignerTypeURL == typeURL
197}
198
199func (km *jwtRSSignerKeyManager) TypeURL() string {
200	return jwtRSSignerTypeURL
201}
202