xref: /aosp_15_r20/external/tink/python/tink/jwt/_jwt_mac_wrapper_test.py (revision e7b1675dde1b92d52ec075b0a92829627f2c52a5)
1# Copyright 2021 Google LLC
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#      http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14"""Tests for tink.python.tink.jwt._jwt_mac_wrapper."""
15
16import io
17
18from absl.testing import absltest
19from absl.testing import parameterized
20
21
22from tink.proto import jwt_hmac_pb2
23from tink.proto import tink_pb2
24import tink
25from tink import cleartext_keyset_handle
26from tink import jwt
27from tink.jwt import _json_util
28from tink.jwt import _jwt_format
29from tink.testing import keyset_builder
30
31
def setUpModule():
  """Module fixture: registers the JWT MAC key managers before the tests."""
  jwt.register_jwt_mac()
34
35
def _set_custom_kid(keyset_handle: tink.KeysetHandle,
                    custom_kid: str) -> tink.KeysetHandle:
  """Returns a new handle whose first key has its custom_kid set.

  The keyset is round-tripped through its cleartext binary serialization. The
  first key's key_data value is parsed as a JwtHmacKey, so this only applies
  to keysets whose first key is a JWT HMAC key. The input handle is left
  untouched.

  Args:
    keyset_handle: handle to copy and modify.
    custom_kid: value stored in the first key's custom_kid field.

  Returns:
    A handle built from the modified keyset.
  """
  out = io.BytesIO()
  cleartext_keyset_handle.write(tink.BinaryKeysetWriter(out), keyset_handle)
  serialized_keyset = out.getvalue()

  keyset = tink_pb2.Keyset()
  keyset.ParseFromString(serialized_keyset)

  # Mutating this element in place updates the repeated field inside `keyset`.
  first_key = keyset.key[0]
  hmac_key = jwt_hmac_pb2.JwtHmacKey.FromString(first_key.key_data.value)
  hmac_key.custom_kid.value = custom_kid
  first_key.key_data.value = hmac_key.SerializeToString()

  return cleartext_keyset_handle.from_keyset(keyset)
47
48
def _change_key_id(keyset_handle: tink.KeysetHandle) -> tink.KeysetHandle:
  """Returns a new handle whose first key has a different id and is primary.

  The new id is the old id XOR-ed with the fixed 32-bit mask 0xdeadbeef.
  Because the mask is non-zero, the new id always differs from the old one.
  The keyset's primary_key_id is then pointed at that new id.
  """
  stream = io.BytesIO()
  writer = tink.BinaryKeysetWriter(stream)
  cleartext_keyset_handle.write(writer, keyset_handle)
  keyset = tink_pb2.Keyset.FromString(stream.getvalue())

  first_key = keyset.key[0]
  first_key.key_id ^= 0xdeadbeef
  keyset.primary_key_id = first_key.key_id

  return cleartext_keyset_handle.from_keyset(keyset)
60
61
def _change_output_prefix_to_tink(
    keyset_handle: tink.KeysetHandle) -> tink.KeysetHandle:
  """Returns a new handle whose first key uses the TINK output prefix type.

  Only the output_prefix_type of the first key is changed; key material, ids
  and the primary key are kept as they were in `keyset_handle`.
  """
  sink = io.BytesIO()
  cleartext_keyset_handle.write(tink.BinaryKeysetWriter(sink), keyset_handle)
  serialized_keyset = sink.getvalue()

  keyset = tink_pb2.Keyset()
  keyset.ParseFromString(serialized_keyset)
  first_key = keyset.key[0]
  first_key.output_prefix_type = tink_pb2.TINK

  return cleartext_keyset_handle.from_keyset(keyset)
71
72
class JwtMacWrapperTest(parameterized.TestCase):
  """Tests for JwtMac primitives obtained from keyset handles.

  Covers key rotation across several keys, the `kid` header behaviour of the
  TINK and RAW output prefix types (including a key with custom_kid), and
  rejection of keysets containing legacy-template keys or lacking a primary.
  """

  @parameterized.parameters([
      (jwt.raw_jwt_hs256_template(), jwt.raw_jwt_hs256_template()),
      (jwt.raw_jwt_hs256_template(), jwt.jwt_hs256_template()),
      (jwt.jwt_hs256_template(), jwt.raw_jwt_hs256_template()),
      (jwt.jwt_hs256_template(), jwt.jwt_hs256_template()),
  ])
  def test_key_rotation(self, old_key_tmpl, new_key_tmpl):
    """Tokens verify exactly with the primitives that contain their key.

    Four primitives are snapshotted while rotating from an older to a newer
    key:
      1: only the older key, which is primary.
      2: both keys, older key primary.
      3: both keys, newer key primary.
      4: older key disabled, newer key primary.
    """
    builder = keyset_builder.new_keyset_builder()
    older_key_id = builder.add_new_key(old_key_tmpl)

    builder.set_primary_key(older_key_id)
    jwtmac1 = builder.keyset_handle().primitive(jwt.JwtMac)

    newer_key_id = builder.add_new_key(new_key_tmpl)
    jwtmac2 = builder.keyset_handle().primitive(jwt.JwtMac)

    builder.set_primary_key(newer_key_id)
    jwtmac3 = builder.keyset_handle().primitive(jwt.JwtMac)

    builder.disable_key(older_key_id)
    jwtmac4 = builder.keyset_handle().primitive(jwt.JwtMac)

    raw_jwt = jwt.new_raw_jwt(issuer='a', without_expiration=True)
    validator = jwt.new_validator(
        expected_issuer='a', allow_missing_expiration=True)

    self.assertNotEqual(older_key_id, newer_key_id)

    jwt_macs = {1: jwtmac1, 2: jwtmac2, 3: jwtmac3, 4: jwtmac4}
    # 1 and 2 use the older key, so 1, 2 and 3 can verify the mac, but not 4.
    # 3 and 4 use the newer key, so 2, 3 and 4 can verify the mac, but not 1.
    verifiers_by_signer = {
        1: {1, 2, 3},
        2: {1, 2, 3},
        3: {2, 3, 4},
        4: {2, 3, 4},
    }
    for signer, accepting in verifiers_by_signer.items():
      compact = jwt_macs[signer].compute_mac_and_encode(raw_jwt)
      for verifier, jwt_mac in jwt_macs.items():
        # subTest reports the exact (signer, verifier) pair on failure.
        with self.subTest(signer=signer, verifier=verifier):
          if verifier in accepting:
            verified_jwt = jwt_mac.verify_mac_and_decode(compact, validator)
            self.assertEqual(verified_jwt.issuer(), 'a')
          else:
            with self.assertRaises(tink.TinkError):
              jwt_mac.verify_mac_and_decode(compact, validator)

  def test_only_tink_output_prefix_type_encodes_a_kid_header(self):
    """Only TINK keys write a kid header; TINK keys also require a matching one."""
    handle = tink.new_keyset_handle(jwt.raw_jwt_hs256_template())
    jwt_mac = handle.primitive(jwt.JwtMac)

    tink_handle = _change_output_prefix_to_tink(handle)
    tink_jwt_mac = tink_handle.primitive(jwt.JwtMac)

    raw_jwt = jwt.new_raw_jwt(issuer='issuer', without_expiration=True)

    token = jwt_mac.compute_mac_and_encode(raw_jwt)
    token_with_kid = tink_jwt_mac.compute_mac_and_encode(raw_jwt)

    _, header, _, _ = _jwt_format.split_signed_compact(token)
    self.assertNotIn('kid', _json_util.json_loads(header))

    _, header_with_kid, _, _ = _jwt_format.split_signed_compact(token_with_kid)
    self.assertIn('kid', _json_util.json_loads(header_with_kid))

    validator = jwt.new_validator(
        expected_issuer='issuer', allow_missing_expiration=True)
    jwt_mac.verify_mac_and_decode(token, validator)
    tink_jwt_mac.verify_mac_and_decode(token_with_kid, validator)

    # With output prefix type RAW, a kid header is ignored.
    jwt_mac.verify_mac_and_decode(token_with_kid, validator)
    # With output prefix type TINK, a kid header is required.
    with self.assertRaises(tink.TinkError):
      tink_jwt_mac.verify_mac_and_decode(token, validator)

    other_handle = _change_key_id(tink_handle)
    other_jwt_mac = other_handle.primitive(jwt.JwtMac)
    # A token with a wrong kid is rejected, even if the MAC itself is valid.
    with self.assertRaises(tink.TinkError):
      other_jwt_mac.verify_mac_and_decode(token_with_kid, validator)

  def test_raw_output_prefix_type_encodes_a_custom_kid_header(self):
    """A RAW key with custom_kid writes it as the kid header and checks it."""
    # Normal HMAC jwt_mac with output prefix RAW.
    handle = tink.new_keyset_handle(jwt.raw_jwt_hs256_template())
    raw_jwt = jwt.new_raw_jwt(issuer='issuer', without_expiration=True)
    validator = jwt.new_validator(
        expected_issuer='issuer', allow_missing_expiration=True)

    jwt_mac = handle.primitive(jwt.JwtMac)
    token = jwt_mac.compute_mac_and_encode(raw_jwt)
    jwt_mac.verify_mac_and_decode(token, validator)

    _, json_header, _, _ = _jwt_format.split_signed_compact(token)
    self.assertNotIn('kid', _json_util.json_loads(json_header))

    # HMAC jwt_mac with a custom_kid set.
    custom_kid_handle = _set_custom_kid(handle, custom_kid='my kid')
    custom_kid_jwt_mac = custom_kid_handle.primitive(jwt.JwtMac)
    token_with_kid = custom_kid_jwt_mac.compute_mac_and_encode(raw_jwt)
    custom_kid_jwt_mac.verify_mac_and_decode(token_with_kid, validator)

    _, header_with_kid, _, _ = _jwt_format.split_signed_compact(token_with_kid)
    self.assertEqual(_json_util.json_loads(header_with_kid)['kid'], 'my kid')

    # Even when custom_kid is set, it's not required to be set in the header.
    custom_kid_jwt_mac.verify_mac_and_decode(token, validator)
    # An additional kid header is ignored.
    jwt_mac.verify_mac_and_decode(token_with_kid, validator)

    other_handle = _set_custom_kid(handle, custom_kid='other kid')
    other_jwt_mac = other_handle.primitive(jwt.JwtMac)
    with self.assertRaises(tink.TinkError):
      # The custom_kid does not match the kid header.
      other_jwt_mac.verify_mac_and_decode(
          token_with_kid, validator)

    tink_handle = _change_output_prefix_to_tink(custom_kid_handle)
    tink_jwt_mac = tink_handle.primitive(jwt.JwtMac)
    # Having custom_kid set with output prefix TINK is not allowed.
    with self.assertRaises(tink.TinkError):
      tink_jwt_mac.compute_mac_and_encode(raw_jwt)
    with self.assertRaises(tink.TinkError):
      tink_jwt_mac.verify_mac_and_decode(token, validator)
    with self.assertRaises(tink.TinkError):
      tink_jwt_mac.verify_mac_and_decode(token_with_kid, validator)

  def test_legacy_key_fails(self):
    """Creating a JwtMac from a keyset whose primary is a legacy key fails."""
    template = keyset_builder.legacy_template(jwt.raw_jwt_hs256_template())
    builder = keyset_builder.new_keyset_builder()
    key_id = builder.add_new_key(template)
    builder.set_primary_key(key_id)
    handle = builder.keyset_handle()
    with self.assertRaises(tink.TinkError):
      handle.primitive(jwt.JwtMac)

  def test_legacy_non_primary_key_fails(self):
    """A legacy key is rejected even when it is not the primary key."""
    builder = keyset_builder.new_keyset_builder()
    old_template = keyset_builder.legacy_template(jwt.raw_jwt_hs256_template())
    _ = builder.add_new_key(old_template)
    current_key_id = builder.add_new_key(jwt.jwt_hs256_template())
    builder.set_primary_key(current_key_id)
    handle = builder.keyset_handle()
    with self.assertRaises(tink.TinkError):
      handle.primitive(jwt.JwtMac)

  def test_jwt_mac_from_keyset_without_primary_fails(self):
    """A keyset without a primary key cannot be used to build a JwtMac.

    The TinkError is already raised when the keyset handle is built, before
    any primitive is requested.
    """
    builder = keyset_builder.new_keyset_builder()
    builder.add_new_key(jwt.raw_jwt_hs256_template())
    with self.assertRaises(tink.TinkError):
      builder.keyset_handle()
250
251
if __name__ == '__main__':
  # Delegate test discovery and execution to absl's test runner.
  absltest.main()
254