# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for tink.python.tink.jwt._jwt_signature_key_manager."""

import datetime

from typing import cast

from absl.testing import absltest
from absl.testing import parameterized

from tink.proto import jwt_ecdsa_pb2
from tink.proto import tink_pb2
import tink
from tink import jwt
from tink.jwt import _jwt_format

from tink.jwt import _jwt_signature_key_manager
from tink.jwt import _jwt_signature_wrappers

DATETIME_1970 = datetime.datetime.fromtimestamp(12345, datetime.timezone.utc)
DATETIME_2011 = datetime.datetime.fromtimestamp(1300819380,
                                                datetime.timezone.utc)
DATETIME_2020 = datetime.datetime.fromtimestamp(1582230020,
                                                datetime.timezone.utc)


def setUpModule():
  jwt.register_jwt_signature()


def gen_compact(json_header: str, json_payload: str, raw_sign) -> str:
  """Builds a signed compact JWT from raw JSON header and payload strings.

  The header and payload are encoded verbatim, without any validation. This
  lets tests sign tokens that a JwtPublicKeySign would never produce, for
  example tokens with malformed JSON or unexpected headers.

  Args:
    json_header: JSON text of the header. It may be invalid JSON.
    json_payload: JSON text of the payload. It may be invalid JSON.
    raw_sign: object whose `sign(bytes)` method returns the signature over the
      unsigned compact bytes. This is presumably a PublicKeySign primitive;
      TODO confirm the exact type.

  Returns:
    The signed compact serialization produced by
    `_jwt_format.create_signed_compact`.
  """
  unsigned_compact = (
      _jwt_format.encode_header(json_header) + b'.' +
      _jwt_format.encode_payload(json_payload))
  signature = raw_sign.sign(unsigned_compact)
  return _jwt_format.create_signed_compact(unsigned_compact, signature)


class JwtSignatureKeyManagerTest(parameterized.TestCase):
  """Tests JWT ECDSA sign/verify primitives created via the key manager."""

  def test_create_sign_verify(self):
    """Round-trips a token and checks that tampering or bad claims fail."""
    handle = tink.new_keyset_handle(jwt.jwt_es256_template())
    sign = handle.primitive(jwt.JwtPublicKeySign)
    verify = handle.public_keyset_handle().primitive(jwt.JwtPublicKeyVerify)
    raw_jwt = jwt.new_raw_jwt(
        issuer='joe',
        expiration=DATETIME_2011,
        custom_claims={'http://example.com/is_root': True})
    signed_compact = sign.sign_and_encode(raw_jwt)

    validator = jwt.new_validator(
        expected_issuer='joe', fixed_now=DATETIME_1970)
    verified_jwt = verify.verify_and_decode(signed_compact, validator)
    self.assertEqual(verified_jwt.issuer(), 'joe')
    self.assertEqual(verified_jwt.expiration().year, 2011)
    self.assertCountEqual(verified_jwt.custom_claim_names(),
                          ['http://example.com/is_root'])
    self.assertTrue(verified_jwt.custom_claim('http://example.com/is_root'))

    # fails because it is expired
    with self.assertRaises(tink.TinkError):
      verify.verify_and_decode(
          signed_compact,
          jwt.new_validator(expected_issuer='joe', fixed_now=DATETIME_2020))

    # wrong issuer
    with self.assertRaises(tink.TinkError):
      verify.verify_and_decode(
          signed_compact,
          jwt.new_validator(expected_issuer='jane', fixed_now=DATETIME_1970))

    # invalid format
    with self.assertRaises(tink.TinkError):
      verify.verify_and_decode(signed_compact + '.123', validator)

    # invalid character
    with self.assertRaises(tink.TinkError):
      verify.verify_and_decode(signed_compact + '?', validator)

    # modified signature
    with self.assertRaises(tink.TinkError):
      verify.verify_and_decode(signed_compact + 'a', validator)

    # modified header
    with self.assertRaises(tink.TinkError):
      verify.verify_and_decode('a' + signed_compact, validator)

  def test_create_sign_verify_with_type_header(self):
    """Checks that a 'typ' header set at signing is returned on verify."""
    handle = tink.new_keyset_handle(jwt.jwt_es256_template())
    sign = handle.primitive(jwt.JwtPublicKeySign)
    verify = handle.public_keyset_handle().primitive(jwt.JwtPublicKeyVerify)
    raw_jwt = jwt.new_raw_jwt(
        type_header='typeHeader', issuer='joe', without_expiration=True)
    signed_compact = sign.sign_and_encode(raw_jwt)

    validator = jwt.new_validator(
        expected_type_header='typeHeader',
        expected_issuer='joe',
        allow_missing_expiration=True)
    verified_jwt = verify.verify_and_decode(signed_compact, validator)
    self.assertEqual(verified_jwt.type_header(), 'typeHeader')

  def test_verify_with_other_key_fails(self):
    """A token signed with one keyset must not verify under another."""
    handle = tink.new_keyset_handle(jwt.jwt_es256_template())
    sign = handle.primitive(jwt.JwtPublicKeySign)
    raw_jwt = jwt.new_raw_jwt(issuer='issuer', without_expiration=True)
    compact = sign.sign_and_encode(raw_jwt)

    other_handle = tink.new_keyset_handle(jwt.jwt_es256_template())
    other_verify = other_handle.public_keyset_handle().primitive(
        jwt.JwtPublicKeyVerify)
    with self.assertRaises(tink.TinkError):
      other_verify.verify_and_decode(
          compact,
          jwt.new_validator(
              expected_issuer='issuer', allow_missing_expiration=True))

  def test_weird_tokens_with_valid_signatures(self):
    """Verifies unusual but validly signed tokens are accepted or rejected."""
    handle = tink.new_keyset_handle(jwt.raw_jwt_es256_template())
    sign = handle.primitive(jwt.JwtPublicKeySign)
    # Get the internal PublicKeySign primitive to create valid signatures.
    wrapped = cast(_jwt_signature_wrappers._WrappedJwtPublicKeySign, sign)
    raw_sign = cast(_jwt_signature_key_manager._JwtPublicKeySign,
                    wrapped._primitive_set.primary().primitive)._public_key_sign

    verify = handle.public_keyset_handle().primitive(jwt.JwtPublicKeyVerify)
    validator = jwt.new_validator(
        expected_issuer='issuer', allow_missing_expiration=True)

    # Normal token.
    valid = gen_compact('{"alg":"ES256"}', '{"iss":"issuer"}', raw_sign)
    verified_jwt = verify.verify_and_decode(valid, validator)
    self.assertEqual(verified_jwt.issuer(), 'issuer')

    # Token with unknown header is valid.
    unknown_header = gen_compact('{"alg":"ES256","unknown_header":"abc"} \n ',
                                 '{"iss":"issuer" }', raw_sign)
    verified_jwt = verify.verify_and_decode(unknown_header, validator)
    self.assertEqual(verified_jwt.issuer(), 'issuer')

    # Token with unknown kid is valid, since primitives with output prefix type
    # RAW ignore kid headers.
    unknown_kid = gen_compact('{"alg":"ES256","kid":"unknown"} \n ',
                              '{"iss":"issuer" }', raw_sign)
    verified_jwt = verify.verify_and_decode(unknown_kid, validator)
    self.assertEqual(verified_jwt.issuer(), 'issuer')

    # Token with invalid alg header
    alg_invalid = gen_compact('{"alg":"ES384"}', '{"iss":"issuer"}', raw_sign)
    with self.assertRaises(tink.TinkError):
      verify.verify_and_decode(alg_invalid, validator)

    # Token with empty header
    empty_header = gen_compact('{}', '{"iss":"issuer"}', raw_sign)
    with self.assertRaises(tink.TinkError):
      verify.verify_and_decode(empty_header, validator)

    # Token header is not valid JSON
    header_invalid = gen_compact('{"alg":"ES256"', '{"iss":"issuer"}', raw_sign)
    with self.assertRaises(tink.TinkError):
      verify.verify_and_decode(header_invalid, validator)

    # Token payload is not valid JSON
    payload_invalid = gen_compact('{"alg":"ES256"}', '{"iss":"issuer"',
                                  raw_sign)
    with self.assertRaises(tink.TinkError):
      verify.verify_and_decode(payload_invalid, validator)

    # Token with whitespace in header JSON string is valid.
    whitespace_in_header = gen_compact(' {"alg": \n "ES256"} \n ',
                                       '{"iss":"issuer" }', raw_sign)
    verified_jwt = verify.verify_and_decode(whitespace_in_header, validator)
    self.assertEqual(verified_jwt.issuer(), 'issuer')

    # Token with whitespace in payload JSON string is valid.
    whitespace_in_payload = gen_compact('{"alg":"ES256"}',
                                        ' {"iss": \n"issuer" } \n', raw_sign)
    verified_jwt = verify.verify_and_decode(whitespace_in_payload, validator)
    self.assertEqual(verified_jwt.issuer(), 'issuer')

    # Token with whitespace in base64-encoded header is invalid.
    with_whitespace = (
        _jwt_format.encode_header('{"alg":"ES256"}') + b' .' +
        _jwt_format.encode_payload('{"iss":"issuer"}'))
    token_with_whitespace = _jwt_format.create_signed_compact(
        with_whitespace, raw_sign.sign(with_whitespace))
    with self.assertRaises(tink.TinkError):
      verify.verify_and_decode(token_with_whitespace, validator)

    # Token with invalid character is invalid.
    with_invalid_char = (
        _jwt_format.encode_header('{"alg":"ES256"}') + b'.?' +
        _jwt_format.encode_payload('{"iss":"issuer"}'))
    token_with_invalid_char = _jwt_format.create_signed_compact(
        with_invalid_char, raw_sign.sign(with_invalid_char))
    with self.assertRaises(tink.TinkError):
      verify.verify_and_decode(token_with_invalid_char, validator)

    # Token with additional '.' is invalid.
    with_dot = (
        _jwt_format.encode_header('{"alg":"ES256"}') + b'.' +
        _jwt_format.encode_payload('{"iss":"issuer"}') + b'.')
    token_with_dot = _jwt_format.create_signed_compact(
        with_dot, raw_sign.sign(with_dot))
    with self.assertRaises(tink.TinkError):
      verify.verify_and_decode(token_with_dot, validator)

    # num_recursions has been chosen such that parsing of this token fails
    # in all languages. We want to make sure that the algorithm does not
    # hang or crash in this case, but only returns a parsing error.
    num_recursions = 10000
    rec_payload = ('{"a":' * num_recursions) + '""' + ('}' * num_recursions)
    rec_token = gen_compact('{"alg":"ES256"}', rec_payload, raw_sign)
    with self.assertRaises(tink.TinkError):
      verify.verify_and_decode(
          rec_token, validator=jwt.new_validator(allow_missing_expiration=True))

    # test wrong types
    with self.assertRaises(tink.TinkError):
      verify.verify_and_decode(cast(str, None), validator)
    with self.assertRaises(tink.TinkError):
      verify.verify_and_decode(cast(str, 123), validator)
    # Encode outside the assertRaises block so that only the verify call is
    # expected to raise.
    valid_bytes = valid.encode('utf8')
    with self.assertRaises(tink.TinkError):
      verify.verify_and_decode(cast(str, valid_bytes), validator)

  def test_create_ecdsa_handle_with_invalid_algorithm_fails(self):
    """Key generation from a template with ES_UNKNOWN must fail."""
    key_format = jwt_ecdsa_pb2.JwtEcdsaKeyFormat(
        algorithm=jwt_ecdsa_pb2.ES_UNKNOWN)
    template = tink_pb2.KeyTemplate(
        type_url='type.googleapis.com/google.crypto.tink.JwtEcdsaPrivateKey',
        value=key_format.SerializeToString(),
        output_prefix_type=tink_pb2.RAW)
    with self.assertRaises(tink.TinkError):
      tink.new_keyset_handle(template)

  def test_create_sign_primitive_with_invalid_algorithm_fails(self):
    """A private key whose algorithm is ES_UNKNOWN yields no sign primitive."""
    handle = tink.new_keyset_handle(jwt.jwt_es256_template())
    # Rewrite the serialized private key in the (private) keyset in place.
    key = jwt_ecdsa_pb2.JwtEcdsaPrivateKey.FromString(
        handle._keyset.key[0].key_data.value)
    key.public_key.algorithm = jwt_ecdsa_pb2.ES_UNKNOWN
    handle._keyset.key[0].key_data.value = key.SerializeToString()
    with self.assertRaises(tink.TinkError):
      handle.primitive(jwt.JwtPublicKeySign)

  def test_create_verify_primitive_with_invalid_algorithm_fails(self):
    """A public key whose algorithm is ES_UNKNOWN yields no verify primitive."""
    private_handle = tink.new_keyset_handle(jwt.jwt_es256_template())
    handle = private_handle.public_keyset_handle()
    # Rewrite the serialized public key in the (private) keyset in place.
    key = jwt_ecdsa_pb2.JwtEcdsaPublicKey.FromString(
        handle._keyset.key[0].key_data.value)
    key.algorithm = jwt_ecdsa_pb2.ES_UNKNOWN
    handle._keyset.key[0].key_data.value = key.SerializeToString()
    with self.assertRaises(tink.TinkError):
      handle.primitive(jwt.JwtPublicKeyVerify)


if __name__ == '__main__':
  absltest.main()