xref: /aosp_15_r20/external/tink/python/tink/jwt/_jwt_signature_key_manager_test.py (revision e7b1675dde1b92d52ec075b0a92829627f2c52a5)
1# Copyright 2021 Google LLC
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#      http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14"""Tests for tink.python.tink.jwt._jwt_signature_key_manager."""
15
16import datetime
17
18from typing import cast
19
20from absl.testing import absltest
21from absl.testing import parameterized
22
23from tink.proto import jwt_ecdsa_pb2
24from tink.proto import tink_pb2
25import tink
26from tink import jwt
27from tink.jwt import _jwt_format
28
29from tink.jwt import _jwt_signature_key_manager
30from tink.jwt import _jwt_signature_wrappers
31
# Fixed, timezone-aware UTC instants used as `fixed_now` / `expiration` values.
# 1970-01-01T03:25:45Z
DATETIME_1970 = datetime.datetime.fromtimestamp(
    12345, tz=datetime.timezone.utc)
# 2011-03-22T18:43:00Z
DATETIME_2011 = datetime.datetime.fromtimestamp(
    1300819380, tz=datetime.timezone.utc)
# 2020-02-20T20:20:20Z
DATETIME_2020 = datetime.datetime.fromtimestamp(
    1582230020, tz=datetime.timezone.utc)
37
38
def setUpModule():
  """Module-level fixture: runs once before any test in this module.

  Calls `jwt.register_jwt_signature()` so that keyset handles built from the
  JWT signature templates can produce JwtPublicKeySign/JwtPublicKeyVerify.
  """
  jwt.register_jwt_signature()
41
42
def gen_compact(json_header: str, json_payload: str, raw_sign) -> str:
  """Builds a signed compact token from verbatim header and payload text.

  The header and payload strings are encoded as given. Nothing here checks
  that they are well-formed JSON, so tests can craft unusual or malformed
  tokens that still carry a correct signature.

  Args:
    json_header: header text, passed to `_jwt_format.encode_header`.
    json_payload: payload text, passed to `_jwt_format.encode_payload`.
    raw_sign: any object with a `sign(bytes) -> bytes` method; it signs the
      `<header>.<payload>` bytes.

  Returns:
    The token produced by `_jwt_format.create_signed_compact`.
  """
  encoded_header = _jwt_format.encode_header(json_header)
  encoded_payload = _jwt_format.encode_payload(json_payload)
  # The signing input is the two encoded segments joined by a single '.'.
  signing_input = b'.'.join((encoded_header, encoded_payload))
  return _jwt_format.create_signed_compact(signing_input,
                                           raw_sign.sign(signing_input))
49
50
class JwtSignatureKeyManagerTest(parameterized.TestCase):
  """Tests for the JWT ECDSA signature key manager.

  Covers sign/verify round trips through keyset handles and rejection of
  expired, mis-issued, tampered and malformed tokens. Also covers rejection
  of keys whose algorithm is ES_UNKNOWN.
  """

  def test_create_sign_verify(self):
    """Round-trips a token; expired, mis-issued or tampered tokens fail."""
    handle = tink.new_keyset_handle(jwt.jwt_es256_template())
    sign = handle.primitive(jwt.JwtPublicKeySign)
    verify = handle.public_keyset_handle().primitive(jwt.JwtPublicKeyVerify)
    raw_jwt = jwt.new_raw_jwt(
        issuer='joe',
        expiration=DATETIME_2011,
        custom_claims={'http://example.com/is_root': True})
    signed_compact = sign.sign_and_encode(raw_jwt)

    # Validate "as of" 1970, i.e. well before the 2011 expiration.
    validator = jwt.new_validator(
        expected_issuer='joe', fixed_now=DATETIME_1970)
    verified_jwt = verify.verify_and_decode(signed_compact, validator)
    self.assertEqual(verified_jwt.issuer(), 'joe')
    self.assertEqual(verified_jwt.expiration().year, 2011)
    self.assertCountEqual(verified_jwt.custom_claim_names(),
                          ['http://example.com/is_root'])
    self.assertTrue(verified_jwt.custom_claim('http://example.com/is_root'))

    # fails because it is expired (validated "as of" 2020)
    with self.assertRaises(tink.TinkError):
      verify.verify_and_decode(
          signed_compact,
          jwt.new_validator(expected_issuer='joe', fixed_now=DATETIME_2020))

    # wrong issuer
    with self.assertRaises(tink.TinkError):
      verify.verify_and_decode(
          signed_compact,
          jwt.new_validator(expected_issuer='jane', fixed_now=DATETIME_1970))

    # invalid format (an extra '.'-separated segment)
    with self.assertRaises(tink.TinkError):
      verify.verify_and_decode(signed_compact + '.123', validator)

    # invalid character
    with self.assertRaises(tink.TinkError):
      verify.verify_and_decode(signed_compact + '?', validator)

    # modified signature
    with self.assertRaises(tink.TinkError):
      verify.verify_and_decode(signed_compact + 'a', validator)

    # modified header
    with self.assertRaises(tink.TinkError):
      verify.verify_and_decode('a' + signed_compact, validator)

  def test_create_sign_verify_with_type_header(self):
    """A type header set when signing is returned after verification."""
    handle = tink.new_keyset_handle(jwt.jwt_es256_template())
    sign = handle.primitive(jwt.JwtPublicKeySign)
    verify = handle.public_keyset_handle().primitive(jwt.JwtPublicKeyVerify)
    raw_jwt = jwt.new_raw_jwt(
        type_header='typeHeader', issuer='joe', without_expiration=True)
    signed_compact = sign.sign_and_encode(raw_jwt)

    validator = jwt.new_validator(
        expected_type_header='typeHeader',
        expected_issuer='joe',
        allow_missing_expiration=True)
    verified_jwt = verify.verify_and_decode(signed_compact, validator)
    self.assertEqual(verified_jwt.type_header(), 'typeHeader')

  def test_verify_with_other_key_fails(self):
    """A token signed by one keyset does not verify under an unrelated one."""
    handle = tink.new_keyset_handle(jwt.jwt_es256_template())
    sign = handle.primitive(jwt.JwtPublicKeySign)
    raw_jwt = jwt.new_raw_jwt(issuer='issuer', without_expiration=True)
    compact = sign.sign_and_encode(raw_jwt)

    other_handle = tink.new_keyset_handle(jwt.jwt_es256_template())
    other_verify = other_handle.public_keyset_handle().primitive(
        jwt.JwtPublicKeyVerify)
    with self.assertRaises(tink.TinkError):
      other_verify.verify_and_decode(
          compact,
          jwt.new_validator(
              expected_issuer='issuer', allow_missing_expiration=True))

  def test_weird_tokens_with_valid_signatures(self):
    """Checks how unusual or malformed but correctly signed tokens are handled.

    Most tokens are signed with the raw signer underneath the JWT primitive,
    over exactly the bytes that precede the signature. Their acceptance or
    rejection therefore exercises header/payload parsing and validation
    rather than a signature mismatch. The final checks pass non-str tokens.
    """
    handle = tink.new_keyset_handle(jwt.raw_jwt_es256_template())
    sign = handle.primitive(jwt.JwtPublicKeySign)
    # Get the internal PublicKeySign primitive to create valid signatures.
    wrapped = cast(_jwt_signature_wrappers._WrappedJwtPublicKeySign, sign)
    raw_sign = cast(_jwt_signature_key_manager._JwtPublicKeySign,
                    wrapped._primitive_set.primary().primitive)._public_key_sign

    verify = handle.public_keyset_handle().primitive(jwt.JwtPublicKeyVerify)
    validator = jwt.new_validator(
        expected_issuer='issuer', allow_missing_expiration=True)

    # Normal token.
    valid = gen_compact('{"alg":"ES256"}', '{"iss":"issuer"}', raw_sign)
    verified_jwt = verify.verify_and_decode(valid, validator)
    self.assertEqual(verified_jwt.issuer(), 'issuer')

    # Token with unknown header is valid.
    unknown_header = gen_compact('{"alg":"ES256","unknown_header":"abc"} \n ',
                                 '{"iss":"issuer" }', raw_sign)
    verified_jwt = verify.verify_and_decode(unknown_header, validator)
    self.assertEqual(verified_jwt.issuer(), 'issuer')

    # Token with unknown kid is valid, since primitives with output prefix type
    # RAW ignore kid headers.
    unknown_kid = gen_compact('{"alg":"ES256","kid":"unknown"} \n ',
                              '{"iss":"issuer" }', raw_sign)
    verified_jwt = verify.verify_and_decode(unknown_kid, validator)
    self.assertEqual(verified_jwt.issuer(), 'issuer')

    # Token with invalid alg header (does not match the ES256 key)
    alg_invalid = gen_compact('{"alg":"ES384"}', '{"iss":"issuer"}', raw_sign)
    with self.assertRaises(tink.TinkError):
      verify.verify_and_decode(alg_invalid, validator)

    # Token with empty header
    empty_header = gen_compact('{}', '{"iss":"issuer"}', raw_sign)
    with self.assertRaises(tink.TinkError):
      verify.verify_and_decode(empty_header, validator)

    # Token header is not valid JSON
    header_invalid = gen_compact('{"alg":"ES256"', '{"iss":"issuer"}', raw_sign)
    with self.assertRaises(tink.TinkError):
      verify.verify_and_decode(header_invalid, validator)

    # Token payload is not valid JSON
    payload_invalid = gen_compact('{"alg":"ES256"}', '{"iss":"issuer"',
                                  raw_sign)
    with self.assertRaises(tink.TinkError):
      verify.verify_and_decode(payload_invalid, validator)

    # Token with whitespace in header JSON string is valid.
    whitespace_in_header = gen_compact(' {"alg":   \n  "ES256"} \n ',
                                       '{"iss":"issuer" }', raw_sign)
    verified_jwt = verify.verify_and_decode(whitespace_in_header, validator)
    self.assertEqual(verified_jwt.issuer(), 'issuer')

    # Token with whitespace in payload JSON string is valid.
    whitespace_in_payload = gen_compact('{"alg":"ES256"}',
                                        ' {"iss": \n"issuer" } \n', raw_sign)
    verified_jwt = verify.verify_and_decode(whitespace_in_payload, validator)
    self.assertEqual(verified_jwt.issuer(), 'issuer')

    # Token with whitespace in base64-encoded header is invalid.
    with_whitespace = (
        _jwt_format.encode_header('{"alg":"ES256"}') + b' .' +
        _jwt_format.encode_payload('{"iss":"issuer"}'))
    token_with_whitespace = _jwt_format.create_signed_compact(
        with_whitespace, raw_sign.sign(with_whitespace))
    with self.assertRaises(tink.TinkError):
      verify.verify_and_decode(token_with_whitespace, validator)

    # Token with invalid character is invalid.
    with_invalid_char = (
        _jwt_format.encode_header('{"alg":"ES256"}') + b'.?' +
        _jwt_format.encode_payload('{"iss":"issuer"}'))
    token_with_invalid_char = _jwt_format.create_signed_compact(
        with_invalid_char, raw_sign.sign(with_invalid_char))
    with self.assertRaises(tink.TinkError):
      verify.verify_and_decode(token_with_invalid_char, validator)

    # Token with additional '.' is invalid.
    with_dot = (
        _jwt_format.encode_header('{"alg":"ES256"}') + b'.' +
        _jwt_format.encode_payload('{"iss":"issuer"}') + b'.')
    token_with_dot = _jwt_format.create_signed_compact(
        with_dot, raw_sign.sign(with_dot))
    with self.assertRaises(tink.TinkError):
      verify.verify_and_decode(token_with_dot, validator)

    # num_recursions has been chosen such that parsing of this token fails
    # in all languages. We want to make sure that the algorithm does not
    # hang or crash in this case, but only returns a parsing error.
    num_recursions = 10000
    rec_payload = ('{"a":' * num_recursions) + '""' + ('}' * num_recursions)
    rec_token = gen_compact('{"alg":"ES256"}', rec_payload, raw_sign)
    with self.assertRaises(tink.TinkError):
      verify.verify_and_decode(
          rec_token, validator=jwt.new_validator(allow_missing_expiration=True))

    # test wrong types: non-str tokens must be rejected with TinkError.
    with self.assertRaises(tink.TinkError):
      verify.verify_and_decode(cast(str, None), validator)
    with self.assertRaises(tink.TinkError):
      verify.verify_and_decode(cast(str, 123), validator)
    # Build the bytes outside assertRaises so that only the verify call is
    # covered by the expectation.
    valid_bytes = valid.encode('utf8')
    with self.assertRaises(tink.TinkError):
      verify.verify_and_decode(cast(str, valid_bytes), validator)

  def test_create_ecdsa_handle_with_invalid_algorithm_fails(self):
    """Generating a key from a template with ES_UNKNOWN raises TinkError."""
    key_format = jwt_ecdsa_pb2.JwtEcdsaKeyFormat(
        algorithm=jwt_ecdsa_pb2.ES_UNKNOWN)
    template = tink_pb2.KeyTemplate(
        type_url='type.googleapis.com/google.crypto.tink.JwtEcdsaPrivateKey',
        value=key_format.SerializeToString(),
        output_prefix_type=tink_pb2.RAW)
    with self.assertRaises(tink.TinkError):
      tink.new_keyset_handle(template)

  def test_create_sign_primitive_with_invalid_algorithm_fails(self):
    """A private key tampered to ES_UNKNOWN cannot produce a signer."""
    handle = tink.new_keyset_handle(jwt.jwt_es256_template())
    # Rewrite the serialized private key in place via the handle's internal
    # keyset, switching its public key's algorithm to ES_UNKNOWN.
    key = jwt_ecdsa_pb2.JwtEcdsaPrivateKey.FromString(
        handle._keyset.key[0].key_data.value)
    key.public_key.algorithm = jwt_ecdsa_pb2.ES_UNKNOWN
    handle._keyset.key[0].key_data.value = key.SerializeToString()
    with self.assertRaises(tink.TinkError):
      handle.primitive(jwt.JwtPublicKeySign)

  def test_create_verify_primitive_with_invalid_algorithm_fails(self):
    """A public key tampered to ES_UNKNOWN cannot produce a verifier."""
    private_handle = tink.new_keyset_handle(jwt.jwt_es256_template())
    handle = private_handle.public_keyset_handle()
    # Rewrite the serialized public key in place via the handle's internal
    # keyset, switching its algorithm to ES_UNKNOWN.
    key = jwt_ecdsa_pb2.JwtEcdsaPublicKey.FromString(
        handle._keyset.key[0].key_data.value)
    key.algorithm = jwt_ecdsa_pb2.ES_UNKNOWN
    handle._keyset.key[0].key_data.value = key.SerializeToString()
    with self.assertRaises(tink.TinkError):
      handle.primitive(jwt.JwtPublicKeyVerify)
267
268
if __name__ == '__main__':
  # Allow running this test module directly as a script via absltest.
  absltest.main()
271