# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS-IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for tink.testing.python.jwt_service."""

from absl.testing import absltest
import grpc

from tink import jwt

from protos import testing_api_pb2
import jwt_service
import services


class DummyServicerContext(grpc.ServicerContext):
  """Stand-in grpc.ServicerContext whose methods all do nothing.

  The servicers under test are called directly (not through a gRPC server),
  so they only need an object of the right type to pass as the context.
  Every method below is a no-op that returns None.
  """

  def is_active(self):
    pass

  def time_remaining(self):
    pass

  def cancel(self):
    pass

  def add_callback(self, callback):
    pass

  def invocation_metadata(self):
    pass

  def peer(self):
    pass

  def peer_identities(self):
    pass

  def peer_identity_key(self):
    pass

  def auth_context(self):
    pass

  def set_compression(self, compression):
    pass

  def send_initial_metadata(self, initial_metadata):
    pass

  def set_trailing_metadata(self, trailing_metadata):
    pass

  def abort(self, code, details):
    pass

  def abort_with_status(self, status):
    pass

  def set_code(self, code):
    pass

  def set_details(self, details):
    pass

  def disable_next_message_compression(self):
    pass


class JwtServiceTest(absltest.TestCase):
  """Exercises jwt_service.JwtServicer by calling its RPC methods directly."""

  # Shared no-op context passed to every servicer call.
  _ctx = DummyServicerContext()

  @classmethod
  def setUpClass(cls):
    """Registers the JWT MAC and JWT signature key types once per class."""
    super().setUpClass()
    jwt.register_jwt_mac()
    jwt.register_jwt_signature()

  def _generate_keyset(self, template):
    """Generates a keyset for `template` and returns it in serialized form.

    Args:
      template: a key template object exposing SerializeToString(), e.g. the
        result of jwt.jwt_hs256_template().

    Returns:
      The `keyset` field of the KeysetServicer.Generate response.
    """
    # NOTE(review): a fresh KeysetServicer is created per call; this assumes
    # the servicer keeps no state between calls - confirm against services.py.
    gen_request = testing_api_pb2.KeysetGenerateRequest(
        template=template.SerializeToString())
    gen_response = services.KeysetServicer().Generate(gen_request, self._ctx)
    self.assertEqual(gen_response.WhichOneof('result'), 'keyset')
    return gen_response.keyset

  def _public_keyset(self, private_keyset):
    """Returns the serialized public keyset derived from `private_keyset`.

    Args:
      private_keyset: serialized private keyset, as returned by
        _generate_keyset().

    Returns:
      The `public_keyset` field of the KeysetServicer.Public response.
    """
    pub_request = testing_api_pb2.KeysetPublicRequest(
        private_keyset=private_keyset)
    pub_response = services.KeysetServicer().Public(pub_request, self._ctx)
    self.assertEqual(pub_response.WhichOneof('result'), 'public_keyset')
    return pub_response.public_keyset

  def test_create_jwt_mac(self):
    """CreateJwtMac reports no error for a freshly generated HS256 keyset."""
    jwt_servicer = jwt_service.JwtServicer()
    keyset = self._generate_keyset(jwt.jwt_hs256_template())

    creation_request = testing_api_pb2.CreationRequest(
        annotated_keyset=testing_api_pb2.AnnotatedKeyset(
            serialized_keyset=keyset))
    creation_response = jwt_servicer.CreateJwtMac(creation_request, self._ctx)
    self.assertEmpty(creation_response.err)

  def test_create_jwt_mac_broken_keyset(self):
    """CreateJwtMac reports an error for an unparsable keyset."""
    jwt_servicer = jwt_service.JwtServicer()

    creation_request = testing_api_pb2.CreationRequest(
        annotated_keyset=testing_api_pb2.AnnotatedKeyset(
            serialized_keyset=b'\x80'))
    creation_response = jwt_servicer.CreateJwtMac(creation_request, self._ctx)
    self.assertNotEmpty(creation_response.err)

  def test_generate_compute_verify_mac(self):
    """A MAC-protected JWT round-trips through compute and verify."""
    jwt_servicer = jwt_service.JwtServicer()
    keyset = self._generate_keyset(jwt.jwt_hs256_template())

    comp_request = testing_api_pb2.JwtSignRequest(
        annotated_keyset=testing_api_pb2.AnnotatedKeyset(
            serialized_keyset=keyset))
    comp_request.raw_jwt.issuer.value = 'issuer'
    comp_request.raw_jwt.subject.value = 'subject'
    comp_request.raw_jwt.custom_claims['myclaim'].bool_value = True
    comp_request.raw_jwt.expiration.seconds = 1334
    comp_request.raw_jwt.expiration.nanos = 123000000

    comp_response = jwt_servicer.ComputeMacAndEncode(comp_request, self._ctx)
    self.assertEqual(comp_response.WhichOneof('result'), 'signed_compact_jwt')
    signed_compact_jwt = comp_response.signed_compact_jwt

    verify_request = testing_api_pb2.JwtVerifyRequest(
        annotated_keyset=testing_api_pb2.AnnotatedKeyset(
            serialized_keyset=keyset),
        signed_compact_jwt=signed_compact_jwt)
    verify_request.validator.expected_issuer.value = 'issuer'
    # "now" (1234) is before the token's expiration (1334).
    verify_request.validator.now.seconds = 1234
    verify_response = jwt_servicer.VerifyMacAndDecode(verify_request, self._ctx)
    self.assertEqual(verify_response.WhichOneof('result'), 'verified_jwt')
    self.assertEqual(verify_response.verified_jwt.issuer.value, 'issuer')
    self.assertEqual(verify_response.verified_jwt.subject.value, 'subject')
    self.assertEqual(verify_response.verified_jwt.expiration.seconds, 1334)
    # The sub-second part set on the request is expected to come back as 0.
    self.assertEqual(verify_response.verified_jwt.expiration.nanos, 0)

  def test_generate_compute_verify_mac_without_expiration(self):
    """A JWT without expiration verifies when allow_missing_expiration is set."""
    jwt_servicer = jwt_service.JwtServicer()
    keyset = self._generate_keyset(jwt.jwt_hs256_template())

    comp_request = testing_api_pb2.JwtSignRequest(
        annotated_keyset=testing_api_pb2.AnnotatedKeyset(
            serialized_keyset=keyset))
    comp_request.raw_jwt.issuer.value = 'issuer'

    comp_response = jwt_servicer.ComputeMacAndEncode(comp_request, self._ctx)
    self.assertEqual(comp_response.WhichOneof('result'), 'signed_compact_jwt')
    signed_compact_jwt = comp_response.signed_compact_jwt

    verify_request = testing_api_pb2.JwtVerifyRequest(
        annotated_keyset=testing_api_pb2.AnnotatedKeyset(
            serialized_keyset=keyset),
        signed_compact_jwt=signed_compact_jwt)
    verify_request.validator.expected_issuer.value = 'issuer'
    verify_request.validator.allow_missing_expiration = True
    verify_response = jwt_servicer.VerifyMacAndDecode(verify_request, self._ctx)
    # Surface the servicer's error text only when the assertion fails, instead
    # of printing it unconditionally on every run.
    self.assertEqual(
        verify_response.WhichOneof('result'), 'verified_jwt',
        msg=verify_response.err)
    self.assertEqual(verify_response.verified_jwt.issuer.value, 'issuer')

  def test_create_public_key_sign(self):
    """CreateJwtPublicKeySign reports no error for a generated ES256 keyset."""
    jwt_servicer = jwt_service.JwtServicer()
    keyset = self._generate_keyset(jwt.jwt_es256_template())

    creation_request = testing_api_pb2.CreationRequest(
        annotated_keyset=testing_api_pb2.AnnotatedKeyset(
            serialized_keyset=keyset))
    creation_response = jwt_servicer.CreateJwtPublicKeySign(
        creation_request, self._ctx)
    self.assertEmpty(creation_response.err)

  def test_create_public_key_sign_bad_keyset(self):
    """CreateJwtPublicKeySign reports an error for an unparsable keyset."""
    jwt_servicer = jwt_service.JwtServicer()

    creation_request = testing_api_pb2.CreationRequest(
        annotated_keyset=testing_api_pb2.AnnotatedKeyset(
            serialized_keyset=b'\x80'))
    creation_response = jwt_servicer.CreateJwtPublicKeySign(
        creation_request, self._ctx)
    self.assertNotEmpty(creation_response.err)

  def test_create_public_key_verify(self):
    """CreateJwtPublicKeyVerify reports no error for an ES256 public keyset."""
    jwt_servicer = jwt_service.JwtServicer()
    private_keyset = self._generate_keyset(jwt.jwt_es256_template())
    public_keyset = self._public_keyset(private_keyset)

    creation_request = testing_api_pb2.CreationRequest(
        annotated_keyset=testing_api_pb2.AnnotatedKeyset(
            serialized_keyset=public_keyset))
    creation_response = jwt_servicer.CreateJwtPublicKeyVerify(
        creation_request, self._ctx)
    self.assertEmpty(creation_response.err)

  def test_create_public_key_verify_bad_keyset(self):
    """CreateJwtPublicKeyVerify reports an error for an unparsable keyset."""
    jwt_servicer = jwt_service.JwtServicer()

    creation_request = testing_api_pb2.CreationRequest(
        annotated_keyset=testing_api_pb2.AnnotatedKeyset(
            serialized_keyset=b'\x80'))
    creation_response = jwt_servicer.CreateJwtPublicKeyVerify(
        creation_request, self._ctx)
    self.assertNotEmpty(creation_response.err)

  def test_generate_sign_export_import_verify_signature(self):
    """Signs a JWT, round-trips the public keyset via JWK set, then verifies."""
    jwt_servicer = jwt_service.JwtServicer()
    private_keyset = self._generate_keyset(jwt.jwt_es256_template())

    comp_request = testing_api_pb2.JwtSignRequest(
        annotated_keyset=testing_api_pb2.AnnotatedKeyset(
            serialized_keyset=private_keyset))
    comp_request.raw_jwt.issuer.value = 'issuer'
    comp_request.raw_jwt.subject.value = 'subject'
    comp_request.raw_jwt.custom_claims['myclaim'].bool_value = True
    comp_response = jwt_servicer.PublicKeySignAndEncode(comp_request, self._ctx)
    self.assertEqual(comp_response.WhichOneof('result'), 'signed_compact_jwt')
    signed_compact_jwt = comp_response.signed_compact_jwt

    public_keyset = self._public_keyset(private_keyset)

    # Export the public keyset as a JWK set ...
    to_jwkset_request = testing_api_pb2.JwtToJwkSetRequest(keyset=public_keyset)
    to_jwkset_response = jwt_servicer.ToJwkSet(to_jwkset_request, self._ctx)
    self.assertEqual(to_jwkset_response.WhichOneof('result'), 'jwk_set')

    self.assertStartsWith(to_jwkset_response.jwk_set, '{"keys":[{"')

    # ... and import it back, so verification uses the re-imported keyset.
    from_jwkset_request = testing_api_pb2.JwtFromJwkSetRequest(
        jwk_set=to_jwkset_response.jwk_set)
    from_jwkset_response = jwt_servicer.FromJwkSet(
        from_jwkset_request, self._ctx)
    self.assertEqual(from_jwkset_response.WhichOneof('result'), 'keyset')

    verify_request = testing_api_pb2.JwtVerifyRequest(
        annotated_keyset=testing_api_pb2.AnnotatedKeyset(
            serialized_keyset=from_jwkset_response.keyset),
        signed_compact_jwt=signed_compact_jwt)
    verify_request.validator.expected_issuer.value = 'issuer'
    verify_request.validator.allow_missing_expiration = True
    verify_response = jwt_servicer.PublicKeyVerifyAndDecode(
        verify_request, self._ctx)
    self.assertEqual(verify_response.WhichOneof('result'), 'verified_jwt')
    self.assertEqual(verify_response.verified_jwt.issuer.value, 'issuer')

  def test_to_jwk_set_with_invalid_keyset_fails(self):
    """ToJwkSet returns an `err` result for bytes that are not a keyset."""
    jwt_servicer = jwt_service.JwtServicer()

    to_jwkset_request = testing_api_pb2.JwtToJwkSetRequest(keyset=b'invalid')
    jwkset_response = jwt_servicer.ToJwkSet(to_jwkset_request, self._ctx)
    self.assertEqual(jwkset_response.WhichOneof('result'), 'err')

  def test_from_jwk_set_with_invalid_jwk_set_fails(self):
    """FromJwkSet returns an `err` result for a string that is not a JWK set."""
    jwt_servicer = jwt_service.JwtServicer()

    from_jwkset_request = testing_api_pb2.JwtFromJwkSetRequest(
        jwk_set='invalid')
    from_jwkset_response = jwt_servicer.FromJwkSet(from_jwkset_request,
                                                   self._ctx)
    self.assertEqual(from_jwkset_response.WhichOneof('result'), 'err')


if __name__ == '__main__':
  absltest.main()