# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for tink.python.tink.jwt._jwt_mac_wrapper."""

import io

from absl.testing import absltest
from absl.testing import parameterized


from tink.proto import jwt_hmac_pb2
from tink.proto import tink_pb2
import tink
from tink import cleartext_keyset_handle
from tink import jwt
from tink.jwt import _json_util
from tink.jwt import _jwt_format
from tink.testing import keyset_builder


def setUpModule():
  """Registers JWT MAC with tink once, before the tests of this module run."""
  jwt.register_jwt_mac()


def _to_keyset_proto(keyset_handle: tink.KeysetHandle) -> tink_pb2.Keyset:
  """Serializes a handle in cleartext and parses it back as a Keyset proto."""
  stream = io.BytesIO()
  writer = tink.BinaryKeysetWriter(stream)
  cleartext_keyset_handle.write(writer, keyset_handle)
  return tink_pb2.Keyset.FromString(stream.getvalue())


def _decoded_header(compact):
  """Splits a signed compact token and returns its header parsed from JSON."""
  _, json_header, _, _ = _jwt_format.split_signed_compact(compact)
  return _json_util.json_loads(json_header)


def _set_custom_kid(keyset_handle: tink.KeysetHandle,
                    custom_kid: str) -> tink.KeysetHandle:
  """Returns a new handle whose first key has custom_kid set.

  The first key's key data is parsed as a JwtHmacKey, so this helper is only
  meant for keysets whose first key is a JWT HMAC key.
  """
  keyset = _to_keyset_proto(keyset_handle)
  first_key_data = keyset.key[0].key_data
  jwt_hmac_key = jwt_hmac_pb2.JwtHmacKey.FromString(first_key_data.value)
  jwt_hmac_key.custom_kid.value = custom_kid
  first_key_data.value = jwt_hmac_key.SerializeToString()
  return cleartext_keyset_handle.from_keyset(keyset)


def _change_key_id(keyset_handle: tink.KeysetHandle) -> tink.KeysetHandle:
  """Returns a new handle whose first key has a new id and is primary."""
  keyset = _to_keyset_proto(keyset_handle)
  first_key = keyset.key[0]
  # XOR with an arbitrary non-zero 32-bit constant, so that the resulting id
  # always differs from the current one.
  replacement_id = first_key.key_id ^ 0xdeadbeef
  first_key.key_id = replacement_id
  keyset.primary_key_id = replacement_id
  return cleartext_keyset_handle.from_keyset(keyset)


def _change_output_prefix_to_tink(
    keyset_handle: tink.KeysetHandle) -> tink.KeysetHandle:
  """Returns a new handle whose first key has output prefix type TINK."""
  keyset = _to_keyset_proto(keyset_handle)
  keyset.key[0].output_prefix_type = tink_pb2.TINK
  return cleartext_keyset_handle.from_keyset(keyset)


class JwtMacWrapperTest(parameterized.TestCase):
  """Tests of the JwtMac primitive obtained from (modified) keyset handles."""

  @parameterized.parameters([
      (jwt.raw_jwt_hs256_template(), jwt.raw_jwt_hs256_template()),
      (jwt.raw_jwt_hs256_template(), jwt.jwt_hs256_template()),
      (jwt.jwt_hs256_template(), jwt.raw_jwt_hs256_template()),
      (jwt.jwt_hs256_template(), jwt.jwt_hs256_template()),
  ])
  def test_key_rotation(self, old_key_tmpl, new_key_tmpl):
    """Checks who can verify whose tokens at each stage of a key rotation."""
    builder = keyset_builder.new_keyset_builder()

    # Stage 1: only the older key exists, and it is primary.
    older_key_id = builder.add_new_key(old_key_tmpl)
    builder.set_primary_key(older_key_id)
    jwtmac1 = builder.keyset_handle().primitive(jwt.JwtMac)

    # Stage 2: the newer key is added, the older key is still primary.
    newer_key_id = builder.add_new_key(new_key_tmpl)
    jwtmac2 = builder.keyset_handle().primitive(jwt.JwtMac)

    # Stage 3: the newer key becomes primary.
    builder.set_primary_key(newer_key_id)
    jwtmac3 = builder.keyset_handle().primitive(jwt.JwtMac)

    # Stage 4: the older key is disabled.
    builder.disable_key(older_key_id)
    jwtmac4 = builder.keyset_handle().primitive(jwt.JwtMac)

    raw_jwt = jwt.new_raw_jwt(issuer='a', without_expiration=True)
    validator = jwt.new_validator(
        expected_issuer='a', allow_missing_expiration=True)

    self.assertNotEqual(older_key_id, newer_key_id)

    # 1 and 2 use the older key. So 1, 2 and 3 can verify the mac, but not 4.
    for signer in (jwtmac1, jwtmac2):
      compact = signer.compute_mac_and_encode(raw_jwt)
      for verifier in (jwtmac1, jwtmac2, jwtmac3):
        verified = verifier.verify_mac_and_decode(compact, validator)
        self.assertEqual(verified.issuer(), 'a')
      with self.assertRaises(tink.TinkError):
        jwtmac4.verify_mac_and_decode(compact, validator)

    # 3 and 4 use the newer key. So 2, 3 and 4 can verify the mac, but not 1.
    for signer in (jwtmac3, jwtmac4):
      compact = signer.compute_mac_and_encode(raw_jwt)
      with self.assertRaises(tink.TinkError):
        jwtmac1.verify_mac_and_decode(compact, validator)
      for verifier in (jwtmac2, jwtmac3, jwtmac4):
        verified = verifier.verify_mac_and_decode(compact, validator)
        self.assertEqual(verified.issuer(), 'a')

  def test_only_tink_output_prefix_type_encodes_a_kid_header(self):
    """Compares the kid header handling of a RAW key and its TINK variant."""
    raw_handle = tink.new_keyset_handle(jwt.raw_jwt_hs256_template())
    raw_mac = raw_handle.primitive(jwt.JwtMac)

    # Same key, but with output prefix type TINK.
    tink_handle = _change_output_prefix_to_tink(raw_handle)
    tink_mac = tink_handle.primitive(jwt.JwtMac)

    raw_jwt = jwt.new_raw_jwt(issuer='issuer', without_expiration=True)

    token = raw_mac.compute_mac_and_encode(raw_jwt)
    token_with_kid = tink_mac.compute_mac_and_encode(raw_jwt)

    self.assertNotIn('kid', _decoded_header(token))
    self.assertIn('kid', _decoded_header(token_with_kid))

    validator = jwt.new_validator(
        expected_issuer='issuer', allow_missing_expiration=True)
    # Each primitive accepts the token it produced itself.
    raw_mac.verify_mac_and_decode(token, validator)
    tink_mac.verify_mac_and_decode(token_with_kid, validator)

    # With output prefix type RAW, a kid header is ignored.
    raw_mac.verify_mac_and_decode(token_with_kid, validator)
    # With output prefix type TINK, a kid header is required.
    with self.assertRaises(tink.TinkError):
      tink_mac.verify_mac_and_decode(token, validator)

    # A token with a wrong kid is rejected, even if the signature is ok.
    mac_with_other_id = _change_key_id(tink_handle).primitive(jwt.JwtMac)
    with self.assertRaises(tink.TinkError):
      mac_with_other_id.verify_mac_and_decode(token_with_kid, validator)

  def test_raw_output_prefix_type_encodes_a_custom_kid_header(self):
    """Checks the kid header handling of RAW keys with custom_kid set."""
    # Normal HMAC jwt_mac with output prefix RAW.
    handle = tink.new_keyset_handle(jwt.raw_jwt_hs256_template())
    raw_jwt = jwt.new_raw_jwt(issuer='issuer', without_expiration=True)
    validator = jwt.new_validator(
        expected_issuer='issuer', allow_missing_expiration=True)

    plain_mac = handle.primitive(jwt.JwtMac)
    token = plain_mac.compute_mac_and_encode(raw_jwt)
    plain_mac.verify_mac_and_decode(token, validator)
    self.assertNotIn('kid', _decoded_header(token))

    # HMAC jwt_mac with a custom_kid set.
    custom_kid_handle = _set_custom_kid(handle, custom_kid='my kid')
    custom_kid_mac = custom_kid_handle.primitive(jwt.JwtMac)
    token_with_kid = custom_kid_mac.compute_mac_and_encode(raw_jwt)
    custom_kid_mac.verify_mac_and_decode(token_with_kid, validator)
    self.assertEqual(_decoded_header(token_with_kid)['kid'], 'my kid')

    # Even when custom_kid is set, it's not required to be set in the header.
    custom_kid_mac.verify_mac_and_decode(token, validator)
    # An additional kid header is ignored.
    plain_mac.verify_mac_and_decode(token_with_kid, validator)

    # The custom_kid does not match the kid header.
    other_kid_handle = _set_custom_kid(handle, custom_kid='other kid')
    other_kid_mac = other_kid_handle.primitive(jwt.JwtMac)
    with self.assertRaises(tink.TinkError):
      other_kid_mac.verify_mac_and_decode(token_with_kid, validator)

    # Having custom_kid set with output prefix TINK is not allowed.
    tink_handle = _change_output_prefix_to_tink(custom_kid_handle)
    tink_mac = tink_handle.primitive(jwt.JwtMac)
    with self.assertRaises(tink.TinkError):
      tink_mac.compute_mac_and_encode(raw_jwt)
    with self.assertRaises(tink.TinkError):
      tink_mac.verify_mac_and_decode(token, validator)
    with self.assertRaises(tink.TinkError):
      tink_mac.verify_mac_and_decode(token_with_kid, validator)

  def test_legacy_key_fails(self):
    """Getting a JwtMac from a keyset with a legacy primary key fails."""
    legacy_template = keyset_builder.legacy_template(
        jwt.raw_jwt_hs256_template())
    builder = keyset_builder.new_keyset_builder()
    builder.set_primary_key(builder.add_new_key(legacy_template))
    legacy_handle = builder.keyset_handle()
    with self.assertRaises(tink.TinkError):
      legacy_handle.primitive(jwt.JwtMac)

  def test_legacy_non_primary_key_fails(self):
    """Getting a JwtMac fails even if the legacy key is not the primary."""
    builder = keyset_builder.new_keyset_builder()
    # The legacy key is added first; its id is not needed.
    builder.add_new_key(
        keyset_builder.legacy_template(jwt.raw_jwt_hs256_template()))
    primary_key_id = builder.add_new_key(jwt.jwt_hs256_template())
    builder.set_primary_key(primary_key_id)
    mixed_handle = builder.keyset_handle()
    with self.assertRaises(tink.TinkError):
      mixed_handle.primitive(jwt.JwtMac)

  def test_jwt_mac_from_keyset_without_primary_fails(self):
    """Building a handle from a keyset with no primary key set fails."""
    builder = keyset_builder.new_keyset_builder()
    builder.add_new_key(jwt.raw_jwt_hs256_template())
    with self.assertRaises(tink.TinkError):
      builder.keyset_handle()


if __name__ == '__main__':
  absltest.main()