xref: /aosp_15_r20/external/tink/testing/python/jwt_service_test.py (revision e7b1675dde1b92d52ec075b0a92829627f2c52a5)
# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS-IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
14"""Tests for tink.testing.python.jwt_service."""
15
16from absl.testing import absltest
17import grpc
18
19from tink import jwt
20
21from protos import testing_api_pb2
22import jwt_service
23import services
24
25
class DummyServicerContext(grpc.ServicerContext):
  """A do-nothing gRPC servicer context.

  The tests below call servicer methods directly, without a gRPC server, and
  pass an instance of this class as the `context` argument. Every method
  accepts the arguments of the grpc.ServicerContext method it overrides,
  ignores them, and returns None.
  """

  # Methods inherited from grpc.RpcContext.

  def is_active(self):
    """Not tracked; returns None."""

  def time_remaining(self):
    """Not tracked; returns None."""

  def cancel(self):
    """Does nothing."""

  def add_callback(self, callback):
    """Drops `callback` without ever invoking it."""

  # Methods declared by grpc.ServicerContext.

  def invocation_metadata(self):
    """No metadata is available; returns None."""

  def peer(self):
    """No peer is available; returns None."""

  def peer_identities(self):
    """No peer identities are available; returns None."""

  def peer_identity_key(self):
    """No peer identity key is available; returns None."""

  def auth_context(self):
    """No auth context is available; returns None."""

  def set_compression(self, compression):
    """Ignores `compression`."""

  def send_initial_metadata(self, initial_metadata):
    """Ignores `initial_metadata`."""

  def set_trailing_metadata(self, trailing_metadata):
    """Ignores `trailing_metadata`."""

  def abort(self, code, details):
    """Does not abort or raise; simply returns None."""

  def abort_with_status(self, status):
    """Does not abort or raise; simply returns None."""

  def set_code(self, code):
    """Ignores `code`."""

  def set_details(self, details):
    """Ignores `details`."""

  def disable_next_message_compression(self):
    """Does nothing."""
78
79
class JwtServiceTest(absltest.TestCase):
  """Tests for jwt_service.JwtServicer.

  Keysets used by these tests are produced with services.KeysetServicer and
  handed to the JWT servicer in serialized form. Servicer methods are called
  directly with a no-op DummyServicerContext.
  """

  # One shared no-op context, passed to every servicer call below.
  _ctx = DummyServicerContext()

  @classmethod
  def setUpClass(cls):
    super().setUpClass()
    # Register the JWT MAC and JWT signature key managers once for the whole
    # test case, before any keyset is generated or used.
    jwt.register_jwt_mac()
    jwt.register_jwt_signature()

  @staticmethod
  def _annotated_keyset(serialized_keyset):
    """Wraps serialized keyset bytes in an AnnotatedKeyset message."""
    return testing_api_pb2.AnnotatedKeyset(serialized_keyset=serialized_keyset)

  def _generate_keyset(self, keyset_servicer, template):
    """Generates a keyset from `template` and returns it serialized.

    Args:
      keyset_servicer: the services.KeysetServicer used to generate the keyset.
      template: a key template such as jwt.jwt_hs256_template().

    Returns:
      The serialized keyset bytes from the Generate response. Fails the test,
      including the servicer's error text, if generation did not succeed.
    """
    gen_request = testing_api_pb2.KeysetGenerateRequest(
        template=template.SerializeToString())
    gen_response = keyset_servicer.Generate(gen_request, self._ctx)
    self.assertEqual(
        gen_response.WhichOneof('result'), 'keyset', msg=gen_response.err)
    return gen_response.keyset

  def _public_keyset(self, keyset_servicer, private_keyset):
    """Returns the serialized public keyset for `private_keyset`.

    Fails the test, including the servicer's error text, if the public
    keyset could not be derived.
    """
    pub_request = testing_api_pb2.KeysetPublicRequest(
        private_keyset=private_keyset)
    pub_response = keyset_servicer.Public(pub_request, self._ctx)
    self.assertEqual(
        pub_response.WhichOneof('result'), 'public_keyset',
        msg=pub_response.err)
    return pub_response.public_keyset

  def test_create_jwt_mac(self):
    keyset_servicer = services.KeysetServicer()
    jwt_servicer = jwt_service.JwtServicer()
    keyset = self._generate_keyset(keyset_servicer, jwt.jwt_hs256_template())

    creation_request = testing_api_pb2.CreationRequest(
        annotated_keyset=self._annotated_keyset(keyset))
    creation_response = jwt_servicer.CreateJwtMac(creation_request, self._ctx)
    self.assertEmpty(creation_response.err)

  def test_create_jwt_mac_broken_keyset(self):
    jwt_servicer = jwt_service.JwtServicer()

    # b'\x80' is used as an invalid serialized keyset, so creation must fail.
    creation_request = testing_api_pb2.CreationRequest(
        annotated_keyset=self._annotated_keyset(b'\x80'))
    creation_response = jwt_servicer.CreateJwtMac(creation_request, self._ctx)
    self.assertNotEmpty(creation_response.err)

  def test_generate_compute_verify_mac(self):
    keyset_servicer = services.KeysetServicer()
    jwt_servicer = jwt_service.JwtServicer()
    keyset = self._generate_keyset(keyset_servicer, jwt.jwt_hs256_template())

    comp_request = testing_api_pb2.JwtSignRequest(
        annotated_keyset=self._annotated_keyset(keyset))
    comp_request.raw_jwt.issuer.value = 'issuer'
    comp_request.raw_jwt.subject.value = 'subject'
    comp_request.raw_jwt.custom_claims['myclaim'].bool_value = True
    comp_request.raw_jwt.expiration.seconds = 1334
    comp_request.raw_jwt.expiration.nanos = 123000000

    comp_response = jwt_servicer.ComputeMacAndEncode(comp_request, self._ctx)
    self.assertEqual(
        comp_response.WhichOneof('result'), 'signed_compact_jwt',
        msg=comp_response.err)

    verify_request = testing_api_pb2.JwtVerifyRequest(
        annotated_keyset=self._annotated_keyset(keyset),
        signed_compact_jwt=comp_response.signed_compact_jwt)
    verify_request.validator.expected_issuer.value = 'issuer'
    # Validate at time 1234, which is before the expiration (1334) set above.
    verify_request.validator.now.seconds = 1234
    verify_response = jwt_servicer.VerifyMacAndDecode(verify_request, self._ctx)
    self.assertEqual(
        verify_response.WhichOneof('result'), 'verified_jwt',
        msg=verify_response.err)
    verified_jwt = verify_response.verified_jwt
    self.assertEqual(verified_jwt.issuer.value, 'issuer')
    self.assertEqual(verified_jwt.subject.value, 'subject')
    self.assertEqual(verified_jwt.expiration.seconds, 1334)
    # The sub-second part of the expiration does not survive the round trip.
    self.assertEqual(verified_jwt.expiration.nanos, 0)

  def test_generate_compute_verify_mac_without_expiration(self):
    keyset_servicer = services.KeysetServicer()
    jwt_servicer = jwt_service.JwtServicer()
    keyset = self._generate_keyset(keyset_servicer, jwt.jwt_hs256_template())

    comp_request = testing_api_pb2.JwtSignRequest(
        annotated_keyset=self._annotated_keyset(keyset))
    comp_request.raw_jwt.issuer.value = 'issuer'

    comp_response = jwt_servicer.ComputeMacAndEncode(comp_request, self._ctx)
    self.assertEqual(
        comp_response.WhichOneof('result'), 'signed_compact_jwt',
        msg=comp_response.err)

    verify_request = testing_api_pb2.JwtVerifyRequest(
        annotated_keyset=self._annotated_keyset(keyset),
        signed_compact_jwt=comp_response.signed_compact_jwt)
    verify_request.validator.expected_issuer.value = 'issuer'
    # The token above carries no expiration, so the validator must allow that.
    verify_request.validator.allow_missing_expiration = True
    verify_response = jwt_servicer.VerifyMacAndDecode(verify_request, self._ctx)
    # The error text is reported through the assertion message rather than
    # printed unconditionally.
    self.assertEqual(
        verify_response.WhichOneof('result'), 'verified_jwt',
        msg=verify_response.err)
    self.assertEqual(verify_response.verified_jwt.issuer.value, 'issuer')

  def test_create_public_key_sign(self):
    keyset_servicer = services.KeysetServicer()
    jwt_servicer = jwt_service.JwtServicer()
    private_keyset = self._generate_keyset(
        keyset_servicer, jwt.jwt_es256_template())

    creation_request = testing_api_pb2.CreationRequest(
        annotated_keyset=self._annotated_keyset(private_keyset))
    creation_response = jwt_servicer.CreateJwtPublicKeySign(
        creation_request, self._ctx)
    self.assertEmpty(creation_response.err)

  def test_create_public_key_sign_bad_keyset(self):
    jwt_servicer = jwt_service.JwtServicer()

    # b'\x80' is used as an invalid serialized keyset, so creation must fail.
    creation_request = testing_api_pb2.CreationRequest(
        annotated_keyset=self._annotated_keyset(b'\x80'))
    creation_response = jwt_servicer.CreateJwtPublicKeySign(
        creation_request, self._ctx)
    self.assertNotEmpty(creation_response.err)

  def test_create_public_key_verify(self):
    keyset_servicer = services.KeysetServicer()
    jwt_servicer = jwt_service.JwtServicer()
    private_keyset = self._generate_keyset(
        keyset_servicer, jwt.jwt_es256_template())
    public_keyset = self._public_keyset(keyset_servicer, private_keyset)

    creation_request = testing_api_pb2.CreationRequest(
        annotated_keyset=self._annotated_keyset(public_keyset))
    creation_response = jwt_servicer.CreateJwtPublicKeyVerify(
        creation_request, self._ctx)
    self.assertEmpty(creation_response.err)

  def test_create_public_key_verify_bad_keyset(self):
    jwt_servicer = jwt_service.JwtServicer()

    # b'\x80' is used as an invalid serialized keyset, so creation must fail.
    creation_request = testing_api_pb2.CreationRequest(
        annotated_keyset=self._annotated_keyset(b'\x80'))
    creation_response = jwt_servicer.CreateJwtPublicKeyVerify(
        creation_request, self._ctx)
    self.assertNotEmpty(creation_response.err)

  def test_generate_sign_export_import_verify_signature(self):
    keyset_servicer = services.KeysetServicer()
    jwt_servicer = jwt_service.JwtServicer()
    private_keyset = self._generate_keyset(
        keyset_servicer, jwt.jwt_es256_template())

    # Sign a token with the private keyset.
    comp_request = testing_api_pb2.JwtSignRequest(
        annotated_keyset=self._annotated_keyset(private_keyset))
    comp_request.raw_jwt.issuer.value = 'issuer'
    comp_request.raw_jwt.subject.value = 'subject'
    comp_request.raw_jwt.custom_claims['myclaim'].bool_value = True
    comp_response = jwt_servicer.PublicKeySignAndEncode(comp_request, self._ctx)
    self.assertEqual(
        comp_response.WhichOneof('result'), 'signed_compact_jwt',
        msg=comp_response.err)
    signed_compact_jwt = comp_response.signed_compact_jwt

    # Export the public keyset as a JWK set, then import it back.
    public_keyset = self._public_keyset(keyset_servicer, private_keyset)
    to_jwkset_request = testing_api_pb2.JwtToJwkSetRequest(keyset=public_keyset)
    to_jwkset_response = jwt_servicer.ToJwkSet(to_jwkset_request, self._ctx)
    self.assertEqual(
        to_jwkset_response.WhichOneof('result'), 'jwk_set',
        msg=to_jwkset_response.err)
    self.assertStartsWith(to_jwkset_response.jwk_set, '{"keys":[{"')

    from_jwkset_request = testing_api_pb2.JwtFromJwkSetRequest(
        jwk_set=to_jwkset_response.jwk_set)
    from_jwkset_response = jwt_servicer.FromJwkSet(
        from_jwkset_request, self._ctx)
    self.assertEqual(
        from_jwkset_response.WhichOneof('result'), 'keyset',
        msg=from_jwkset_response.err)

    # Verify the signed token with the keyset imported from the JWK set.
    verify_request = testing_api_pb2.JwtVerifyRequest(
        annotated_keyset=self._annotated_keyset(from_jwkset_response.keyset),
        signed_compact_jwt=signed_compact_jwt)
    verify_request.validator.expected_issuer.value = 'issuer'
    verify_request.validator.allow_missing_expiration = True
    verify_response = jwt_servicer.PublicKeyVerifyAndDecode(
        verify_request, self._ctx)
    self.assertEqual(
        verify_response.WhichOneof('result'), 'verified_jwt',
        msg=verify_response.err)
    self.assertEqual(verify_response.verified_jwt.issuer.value, 'issuer')

  def test_to_jwk_set_with_invalid_keyset_fails(self):
    jwt_servicer = jwt_service.JwtServicer()

    to_jwkset_request = testing_api_pb2.JwtToJwkSetRequest(keyset=b'invalid')
    jwkset_response = jwt_servicer.ToJwkSet(to_jwkset_request, self._ctx)
    self.assertEqual(jwkset_response.WhichOneof('result'), 'err')

  def test_from_jwk_set_with_invalid_jwk_set_fails(self):
    jwt_servicer = jwt_service.JwtServicer()

    from_jwkset_request = testing_api_pb2.JwtFromJwkSetRequest(
        jwk_set='invalid')
    from_jwkset_response = jwt_servicer.FromJwkSet(from_jwkset_request,
                                                   self._ctx)
    self.assertEqual(from_jwkset_response.WhichOneof('result'), 'err')
300
301
if __name__ == '__main__':
  # Hand off to absltest, which runs the test cases defined in this module.
  absltest.main()
304