# xref: /aosp_15_r20/external/tink/testing/cross_language/util/testing_servers_test.py (revision e7b1675dde1b92d52ec075b0a92829627f2c52a5)
# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS-IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
14"""Tests for tink.testing.cross_language.util.testing_server."""
15
16import datetime
17import io
18import textwrap
19from typing import Iterable, Tuple
20
21from absl import flags
22from absl.testing import absltest
23from absl.testing import parameterized
24import tink
25from tink import aead
26from tink import daead
27from tink import hybrid
28from tink import jwt
29from tink import mac
30from tink import prf
31from tink import signature
32from tink import streaming_aead
33
34from tink.proto import tink_pb2
35from util import key_util
36from util import test_keys
37from util import testing_servers
38
# Maps each primitive name (e.g. 'aead') to the languages whose testing servers
# support it; the per-primitive tests in this module are parameterized with it.
_SUPPORTED_LANGUAGES = testing_servers.SUPPORTED_LANGUAGES_BY_PRIMITIVE

# Hex-encoded serialized KeyTemplate that test_new_keyset creates a keyset
# from. Defaults to AES256_GCM; override it to generate a key for another
# template.
_HEX_TEMPLATE = flags.DEFINE_string(
    'hex_template',
    aead.aead_key_templates.AES256_GCM.SerializeToString().hex(),
    'The template in hex format to use in the test_new_keyset test.'
)

# When set, test_new_keyset deliberately fails with _MESSAGE_TEMPLATE filled in
# with the template and the generated key, ready to paste into _test_keys_db.py.
_FORCE_FAILURE_FOR_ADDING_KEY_TO_DB = flags.DEFINE_boolean(
    'force_failure_for_adding_key_to_db', False,
    'Set to force a message which helps to add a new key to the DB.')

# Failure message emitted by test_new_keyset. It is filled via str.format with
# the placeholders {template_text_format} and {key_text_format}.
_MESSAGE_TEMPLATE = '''
Please add the following to _test_keys_db.py:
COPY PASTE START ===============================================================
db.add_key(
    template=r"""
{template_text_format}""",
    key=r"""
{key_text_format}""")
COPY PASTE END =================================================================
'''
61
62
def setUpModule():
  """Registers the Tink primitives exercised by the tests in this module.

  unittest/absltest invokes this module-level fixture once, before any test
  in the module runs.
  """
  registrations = (
      aead.register,
      daead.register,
      hybrid.register,
      jwt.register_jwt_mac,
      jwt.register_jwt_signature,
      mac.register,
      prf.register,
      signature.register,
      streaming_aead.register,
  )
  for register_fn in registrations:
    register_fn()
73
74
class TestingServersConfigTest(absltest.TestCase):
  """Consistency checks between the configuration tables of testing_servers.

  These tests only read module-level tables; they do not start any server.
  """

  def test_primitives(self):
    """_PRIMITIVE_STUBS and SUPPORTED_LANGUAGES_BY_PRIMITIVE share their keys."""
    self.assertEqual(
        testing_servers._PRIMITIVE_STUBS.keys(),
        _SUPPORTED_LANGUAGES.keys(),
        msg=(
            'The primitives specified as keys in '
            'testing_servers._PRIMITIVE_STUBS must match the primitives '
            'specified as keys in '
            'testing_servers.SUPPORTED_LANGUAGES_BY_PRIMITIVE.'
        ))

  def test_languages(self):
    """Every per-primitive language list only names known languages."""
    # The set of all languages does not depend on the primitive; build it once.
    languages = set(testing_servers.LANGUAGES)
    for primitive, supported in _SUPPORTED_LANGUAGES.items():
      # subTest makes a failure report which primitive is misconfigured.
      with self.subTest(primitive=primitive):
        self.assertContainsSubset(set(supported), languages, msg=(
            'The languages specified in '
            'testing_servers.SUPPORTED_LANGUAGES_BY_PRIMITIVE must be a subset '
            'of the languages specified in testing_servers.LANGUAGES.'
        ))
97
98
def encrypted_keyset_test_cases() -> Iterable[Tuple[str, str, str]]:
  """Yields (language, reader_type, writer_type) test-case triples.

  Every language in testing_servers.LANGUAGES is paired with every
  (reader, writer) entry of testing_servers.KEYSET_READER_WRITER_TYPES,
  iterating languages in the outer loop.
  """
  for language in testing_servers.LANGUAGES:
    yield from (
        (language, reader, writer)
        for reader, writer in testing_servers.KEYSET_READER_WRITER_TYPES)
103
104
class TestingServersTest(parameterized.TestCase):
  """End-to-end tests of the testing_servers helpers against live servers.

  setUpClass starts the testing servers once for the whole class and
  tearDownClass stops them; each test drives the servers through the
  testing_servers helpers (new_keyset, public_keyset, remote_primitive, ...).
  """

  # Primitive interfaces used by the remote_primitive creation tests below.
  # This must stay a list: absl's parameterized.parameters() expands a single
  # non-tuple iterable into one test case per element, whereas a single tuple
  # argument would be taken as the argument tuple of ONE test case.
  _ALL_PRIMITIVES = [
      aead.Aead, daead.DeterministicAead, streaming_aead.StreamingAead,
      hybrid.HybridDecrypt, hybrid.HybridEncrypt, mac.Mac,
      signature.PublicKeySign, signature.PublicKeyVerify, prf.PrfSet,
      jwt.JwtMac, jwt.JwtPublicKeySign, jwt.JwtPublicKeyVerify
  ]

  @classmethod
  def setUpClass(cls):
    super(TestingServersTest, cls).setUpClass()
    testing_servers.start('testing_server')

  @classmethod
  def tearDownClass(cls):
    testing_servers.stop()
    super(TestingServersTest, cls).tearDownClass()

  def _new_test_token(self, now):
    """Returns the raw JWT shared by the JWT tests; it expires at now + 10s."""
    return jwt.new_raw_jwt(
        issuer='issuer',
        subject='subject',
        audiences=['audience1', 'audience2'],
        jwt_id='jwt_id',
        expiration=now + datetime.timedelta(seconds=10),
        custom_claims={'switch': True, 'pi': 3.14159})

  def _assert_has_test_claims(self, verified_jwt):
    """Asserts that verified_jwt carries the claims set by _new_test_token."""
    self.assertEqual(verified_jwt.issuer(), 'issuer')
    self.assertEqual(verified_jwt.subject(), 'subject')
    self.assertEqual(verified_jwt.jwt_id(), 'jwt_id')
    self.assertEqual(verified_jwt.custom_claim('switch'), True)
    self.assertEqual(verified_jwt.custom_claim('pi'), 3.14159)

  @parameterized.parameters(testing_servers.LANGUAGES)
  def test_get_template(self, lang):
    """Each language's AES128_GCM template has the AesGcmKey type URL."""
    template = testing_servers.key_template(lang, 'AES128_GCM')
    self.assertEqual(template.type_url,
                     'type.googleapis.com/google.crypto.tink.AesGcmKey')

  @parameterized.parameters(testing_servers.LANGUAGES)
  def test_new_keyset(self, lang):
    """Tests that we can create a new keyset in each language.

    This test also serves to add new keys to the _test_keys_db -- see the
    comments there. Running it with --force_failure_for_adding_key_to_db
    (optionally together with --hex_template) makes it fail with a
    ready-to-paste db.add_key(...) entry in the failure message.

    Args:
      lang: language to use for the test
    """
    # FromString is called on the message class; no throwaway instance needed.
    template = tink_pb2.KeyTemplate.FromString(
        bytes.fromhex(_HEX_TEMPLATE.value))
    keyset = testing_servers.new_keyset(lang, template)
    parsed_keyset = tink_pb2.Keyset.FromString(keyset)
    self.assertLen(parsed_keyset.key, 1)
    if _FORCE_FAILURE_FOR_ADDING_KEY_TO_DB.value:
      self.fail(
          _MESSAGE_TEMPLATE.format(
              template_text_format=textwrap.indent(
                  key_util.text_format(template), ' ' * 6),
              key_text_format=textwrap.indent(
                  key_util.text_format(parsed_keyset.key[0]), ' ' * 6)))

  @parameterized.parameters(_ALL_PRIMITIVES)
  def test_create_with_correct_keyset(self, primitive):
    """remote_primitive accepts a keyset that matches the primitive."""
    keyset = test_keys.some_keyset_for_primitive(primitive)
    # Passing means remote_primitive did not raise.
    _ = testing_servers.remote_primitive('python', keyset, primitive)

  @parameterized.parameters(_ALL_PRIMITIVES)
  def test_create_with_incorrect_keyset(self, primitive):
    """remote_primitive rejects a keyset made for a different primitive."""
    # Use a Mac keyset for every primitive except Mac, which gets an Aead one.
    wrong_primitive = aead.Aead if primitive is mac.Mac else mac.Mac
    keyset = test_keys.some_keyset_for_primitive(wrong_primitive)
    with self.assertRaises(tink.TinkError):
      testing_servers.remote_primitive('python', keyset, primitive)

  @parameterized.parameters(encrypted_keyset_test_cases())
  def test_read_write_encrypted_keyset(self, lang, keyset_reader_type,
                                       keyset_writer_type):
    """Round-trips an encrypted keyset and checks the failure cases.

    The keyset is written encrypted under a master AEAD keyset and read back.
    Reading must fail for wrong associated data, a garbage encrypted keyset or
    an invalid master keyset; writing must fail for an invalid master keyset.
    """
    associated_data = b'associated_data'
    keyset = testing_servers.new_keyset(lang,
                                        aead.aead_key_templates.AES128_GCM)
    master_keyset = testing_servers.new_keyset(
        lang, aead.aead_key_templates.AES128_GCM)
    encrypted_keyset = testing_servers.keyset_write_encrypted(
        lang, keyset, master_keyset, associated_data, keyset_writer_type)
    output_keyset = testing_servers.keyset_read_encrypted(
        lang, encrypted_keyset, master_keyset, associated_data,
        keyset_reader_type)
    self.assertEqual(output_keyset, keyset)

    with self.assertRaises(tink.TinkError):
      testing_servers.keyset_read_encrypted(lang, encrypted_keyset,
                                            master_keyset,
                                            b'invalid_associated_data',
                                            keyset_reader_type)
    with self.assertRaises(tink.TinkError):
      testing_servers.keyset_read_encrypted(lang, b'invalid_encrypted_keyset',
                                            master_keyset, associated_data,
                                            keyset_reader_type)
    with self.assertRaises(tink.TinkError):
      testing_servers.keyset_read_encrypted(lang, encrypted_keyset,
                                            b'invalid_master_keyset',
                                            associated_data,
                                            keyset_reader_type)
    with self.assertRaises(tink.TinkError):
      testing_servers.keyset_write_encrypted(lang, keyset,
                                             b'invalid_master_keyset',
                                             associated_data,
                                             keyset_writer_type)

  @parameterized.parameters(_SUPPORTED_LANGUAGES['aead'])
  def test_aead(self, lang):
    """AEAD encrypt/decrypt round trip; decrypting garbage fails."""
    keyset = testing_servers.new_keyset(lang,
                                        aead.aead_key_templates.AES128_GCM)
    plaintext = b'The quick brown fox jumps over the lazy dog'
    associated_data = b'associated_data'
    aead_primitive = testing_servers.remote_primitive(lang, keyset, aead.Aead)
    ciphertext = aead_primitive.encrypt(plaintext, associated_data)
    output = aead_primitive.decrypt(ciphertext, associated_data)
    self.assertEqual(output, plaintext)

    with self.assertRaises(tink.TinkError):
      aead_primitive.decrypt(b'foo', associated_data)

  @parameterized.parameters(_SUPPORTED_LANGUAGES['daead'])
  def test_daead(self, lang):
    """Deterministic AEAD round trip; decrypting garbage fails."""
    keyset = testing_servers.new_keyset(
        lang, daead.deterministic_aead_key_templates.AES256_SIV)
    plaintext = b'The quick brown fox jumps over the lazy dog'
    associated_data = b'associated_data'
    daead_primitive = testing_servers.remote_primitive(lang, keyset,
                                                       daead.DeterministicAead)
    ciphertext = daead_primitive.encrypt_deterministically(
        plaintext, associated_data)
    output = daead_primitive.decrypt_deterministically(
        ciphertext, associated_data)
    self.assertEqual(output, plaintext)

    with self.assertRaises(tink.TinkError):
      daead_primitive.decrypt_deterministically(b'foo', associated_data)

  @parameterized.parameters(_SUPPORTED_LANGUAGES['streaming_aead'])
  def test_streaming_aead(self, lang):
    """Streaming AEAD round trip; decrypting a garbage stream fails."""
    keyset = testing_servers.new_keyset(
        lang, streaming_aead.streaming_aead_key_templates.AES128_GCM_HKDF_4KB)
    plaintext = b'The quick brown fox jumps over the lazy dog'
    plaintext_stream = io.BytesIO(plaintext)
    associated_data = b'associated_data'
    streaming_aead_primitive = testing_servers.remote_primitive(
        lang, keyset, streaming_aead.StreamingAead)
    ciphertext_stream = streaming_aead_primitive.new_encrypting_stream(
        plaintext_stream, associated_data)
    output_stream = streaming_aead_primitive.new_decrypting_stream(
        ciphertext_stream, associated_data)
    self.assertEqual(output_stream.read(), plaintext)

    with self.assertRaises(tink.TinkError):
      streaming_aead_primitive.new_decrypting_stream(io.BytesIO(b'foo'),
                                                     associated_data)

  @parameterized.parameters(_SUPPORTED_LANGUAGES['mac'])
  def test_mac(self, lang):
    """A computed MAC verifies; a garbage MAC does not."""
    keyset = testing_servers.new_keyset(
        lang, mac.mac_key_templates.HMAC_SHA256_128BITTAG)
    data = b'The quick brown fox jumps over the lazy dog'
    mac_primitive = testing_servers.remote_primitive(lang, keyset, mac.Mac)
    mac_value = mac_primitive.compute_mac(data)
    mac_primitive.verify_mac(mac_value, data)

    with self.assertRaises(tink.TinkError):
      mac_primitive.verify_mac(b'foo', data)

  @parameterized.parameters(_SUPPORTED_LANGUAGES['hybrid'])
  def test_hybrid(self, lang):
    """Hybrid encrypt with the public keyset, decrypt with the private one."""
    private_keyset = testing_servers.new_keyset(
        lang,
        hybrid.hybrid_key_templates.ECIES_P256_HKDF_HMAC_SHA256_AES128_GCM)
    public_keyset = testing_servers.public_keyset(lang, private_keyset)
    enc_primitive = testing_servers.remote_primitive(lang, public_keyset,
                                                     hybrid.HybridEncrypt)
    data = b'The quick brown fox jumps over the lazy dog'
    context_info = b'context'
    ciphertext = enc_primitive.encrypt(data, context_info)
    dec_primitive = testing_servers.remote_primitive(lang, private_keyset,
                                                     hybrid.HybridDecrypt)
    output = dec_primitive.decrypt(ciphertext, context_info)
    self.assertEqual(output, data)

    with self.assertRaises(tink.TinkError):
      dec_primitive.decrypt(b'foo', context_info)

  @parameterized.parameters(_SUPPORTED_LANGUAGES['signature'])
  def test_signature(self, lang):
    """Sign with the private keyset, verify with the public one."""
    private_keyset = testing_servers.new_keyset(
        lang, signature.signature_key_templates.ED25519)
    public_keyset = testing_servers.public_keyset(lang, private_keyset)
    sign_primitive = testing_servers.remote_primitive(lang, private_keyset,
                                                      signature.PublicKeySign)
    data = b'The quick brown fox jumps over the lazy dog'
    signature_value = sign_primitive.sign(data)
    verify_primitive = testing_servers.remote_primitive(
        lang, public_keyset, signature.PublicKeyVerify)
    verify_primitive.verify(signature_value, data)

    with self.assertRaises(tink.TinkError):
      verify_primitive.verify(b'foo', data)

  @parameterized.parameters(_SUPPORTED_LANGUAGES['prf'])
  def test_prf(self, lang):
    """The primary PRF honours output_length and rejects an oversized one."""
    keyset = testing_servers.new_keyset(lang,
                                        prf.prf_key_templates.HMAC_SHA256)
    input_data = b'The quick brown fox jumps over the lazy dog'
    prf_set_primitive = testing_servers.remote_primitive(
        lang, keyset, prf.PrfSet)
    output = prf_set_primitive.primary().compute(input_data, output_length=15)
    self.assertLen(output, 15)

    with self.assertRaises(tink.TinkError):
      prf_set_primitive.primary().compute(input_data, output_length=123456)

  @parameterized.parameters(_SUPPORTED_LANGUAGES['jwt'])
  def test_jwt_mac(self, lang):
    """JWT MAC round trip; a validator expecting another audience fails."""
    keyset = testing_servers.new_keyset(lang, jwt.jwt_hs256_template())
    jwt_mac_primitive = testing_servers.remote_primitive(
        lang, keyset, jwt.JwtMac)

    now = datetime.datetime.now(tz=datetime.timezone.utc)
    compact = jwt_mac_primitive.compute_mac_and_encode(
        self._new_test_token(now))
    validator = jwt.new_validator(
        expected_issuer='issuer',
        expected_audience='audience1',
        fixed_now=now)
    verified_jwt = jwt_mac_primitive.verify_mac_and_decode(compact, validator)
    self._assert_has_test_claims(verified_jwt)

    wrong_audience_validator = jwt.new_validator(
        expected_audience='wrong_audience', fixed_now=now)
    with self.assertRaises(tink.TinkError):
      jwt_mac_primitive.verify_mac_and_decode(compact,
                                              wrong_audience_validator)

  @parameterized.parameters(_SUPPORTED_LANGUAGES['jwt'])
  def test_jwt_public_key_sign_verify(self, lang):
    """JWT signature round trip; a validator expecting another audience fails."""
    private_keyset = testing_servers.new_keyset(lang, jwt.jwt_es256_template())
    public_keyset = testing_servers.public_keyset(lang, private_keyset)

    signer = testing_servers.remote_primitive(lang, private_keyset,
                                              jwt.JwtPublicKeySign)
    verifier = testing_servers.remote_primitive(lang, public_keyset,
                                                jwt.JwtPublicKeyVerify)

    now = datetime.datetime.now(tz=datetime.timezone.utc)
    compact = signer.sign_and_encode(self._new_test_token(now))
    validator = jwt.new_validator(
        expected_issuer='issuer',
        expected_audience='audience1',
        fixed_now=now)
    verified_jwt = verifier.verify_and_decode(compact, validator)
    self._assert_has_test_claims(verified_jwt)

    wrong_audience_validator = jwt.new_validator(
        expected_audience='wrong_audience', fixed_now=now)
    with self.assertRaises(tink.TinkError):
      verifier.verify_and_decode(compact, wrong_audience_validator)

  @parameterized.parameters(['java'])
  def test_jwt_public_key_sign_export_import_verify(self, lang):
    """A token verifies with a public keyset re-imported from its JWK set."""
    private_keyset = testing_servers.new_keyset(lang, jwt.jwt_es256_template())
    public_keyset = testing_servers.public_keyset(lang, private_keyset)

    # sign and export public key
    signer = testing_servers.remote_primitive(lang, private_keyset,
                                              jwt.JwtPublicKeySign)
    now = datetime.datetime.now(tz=datetime.timezone.utc)
    token = jwt.new_raw_jwt(
        jwt_id='jwt_id', expiration=now + datetime.timedelta(seconds=100))
    compact = signer.sign_and_encode(token)
    public_jwk_set = testing_servers.jwk_set_from_keyset(lang, public_keyset)

    # verify using public_jwk_set
    imported_public_keyset = testing_servers.jwk_set_to_keyset(
        lang, public_jwk_set)

    verifier = testing_servers.remote_primitive(lang, imported_public_keyset,
                                                jwt.JwtPublicKeyVerify)
    validator = jwt.new_validator(fixed_now=now)
    verified_jwt = verifier.verify_and_decode(compact, validator)
    self.assertEqual(verified_jwt.jwt_id(), 'jwt_id')
402
403
if __name__ == '__main__':
  # Run via absltest so the flags defined in this module are parsed first.
  absltest.main()
406