# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS-IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""JWT testing service API implementations in Python."""

import datetime
import io
import json

from typing import Tuple

import grpc
import tink
from tink import cleartext_keyset_handle

from tink import jwt

from google.protobuf import duration_pb2
from google.protobuf import timestamp_pb2

from protos import testing_api_pb2
from protos import testing_api_pb2_grpc

# The Unix epoch as an aware UTC datetime. Conversions between datetimes and
# protobuf Timestamps use exact timedelta arithmetic relative to this value
# rather than float POSIX timestamps.
_EPOCH = datetime.datetime(1970, 1, 1, tzinfo=datetime.timezone.utc)


def _to_timestamp_tuple(t: datetime.datetime) -> Tuple[int, int]:
  """Splits an aware datetime into (seconds, nanos) relative to the epoch.

  The result follows the google.protobuf.Timestamp convention: `seconds` is
  rounded towards negative infinity and `nanos` is in [0, 999999999], also for
  datetimes before the epoch.

  Args:
    t: a timezone-aware datetime.

  Returns:
    A (seconds, nanos) tuple.

  Raises:
    ValueError: if `t` has no tzinfo.
  """
  if t.tzinfo is None:
    raise ValueError('datetime must have tzinfo')
  # timedelta normalizes to days (any sign), seconds in [0, 86399] and
  # microseconds in [0, 999999], so this is exact integer arithmetic and nanos
  # is never negative. Going through the float t.timestamp() instead would
  # lose sub-microsecond precision and, because int() truncates towards zero,
  # yield negative nanos for datetimes before the epoch.
  delta = t - _EPOCH
  seconds = delta.days * 86400 + delta.seconds
  nanos = delta.microseconds * 1000
  return (seconds, nanos)


def _from_timestamp_proto(
    timestamp: timestamp_pb2.Timestamp) -> datetime.datetime:
  """Converts a protobuf Timestamp into an aware datetime in UTC.

  datetime has microsecond resolution, so nanos are rounded to the nearest
  microsecond.
  """
  return _EPOCH + datetime.timedelta(
      seconds=timestamp.seconds, microseconds=timestamp.nanos / 1000)


def _from_duration_proto(
    duration: duration_pb2.Duration) -> datetime.timedelta:
  """Converts a protobuf Duration into a timedelta.

  Both `seconds` and `nanos` are used; nanos are rounded to the nearest
  microsecond, the finest resolution timedelta supports.
  """
  return datetime.timedelta(
      seconds=duration.seconds, microseconds=duration.nanos / 1000)


def _optional_wrapper_value(message, field_name: str):
  """Returns `message.<field_name>.value`, or None if the field is not set."""
  if message.HasField(field_name):
    return getattr(message, field_name).value
  return None


def _optional_datetime(message, field_name: str):
  """Returns the Timestamp field `field_name` as a datetime, or None if unset."""
  if message.HasField(field_name):
    return _from_timestamp_proto(getattr(message, field_name))
  return None


def _set_timestamp(field: timestamp_pb2.Timestamp,
                   t: datetime.datetime) -> None:
  """Writes the aware datetime `t` into the Timestamp sub-message `field`."""
  seconds, nanos = _to_timestamp_tuple(t)
  field.seconds = seconds
  field.nanos = nanos


def raw_jwt_from_proto(proto_raw_jwt: testing_api_pb2.JwtToken) -> jwt.RawJwt:
  """Converts a proto JwtToken into a jwt.RawJwt.

  Unset optional fields are passed as None and an empty audience list is
  passed as None. The token is created without expiration exactly when the
  proto has no `expiration` field.

  Args:
    proto_raw_jwt: the JwtToken proto to convert.

  Returns:
    The jwt.RawJwt built by jwt.new_raw_jwt. Errors raised by jwt.new_raw_jwt
    propagate unchanged.

  Raises:
    ValueError: if a custom claim has none of the known value kinds, or a JSON
      object/array claim is not valid JSON.
  """
  custom_claims = {}
  for name, claim in proto_raw_jwt.custom_claims.items():
    if claim.HasField('null_value'):
      custom_claims[name] = None
    elif claim.HasField('number_value'):
      custom_claims[name] = claim.number_value
    elif claim.HasField('string_value'):
      custom_claims[name] = claim.string_value
    elif claim.HasField('bool_value'):
      custom_claims[name] = claim.bool_value
    elif claim.HasField('json_object_value'):
      custom_claims[name] = json.loads(claim.json_object_value)
    elif claim.HasField('json_array_value'):
      custom_claims[name] = json.loads(claim.json_array_value)
    else:
      raise ValueError('claim %s has unknown type' % name)
  expiration = _optional_datetime(proto_raw_jwt, 'expiration')
  return jwt.new_raw_jwt(
      type_header=_optional_wrapper_value(proto_raw_jwt, 'type_header'),
      issuer=_optional_wrapper_value(proto_raw_jwt, 'issuer'),
      subject=_optional_wrapper_value(proto_raw_jwt, 'subject'),
      audiences=list(proto_raw_jwt.audiences) or None,
      jwt_id=_optional_wrapper_value(proto_raw_jwt, 'jwt_id'),
      expiration=expiration,
      # Compare with None explicitly rather than relying on truthiness.
      without_expiration=expiration is None,
      not_before=_optional_datetime(proto_raw_jwt, 'not_before'),
      issued_at=_optional_datetime(proto_raw_jwt, 'issued_at'),
      custom_claims=custom_claims)


def verifiedjwt_to_proto(
    verified_jwt: jwt.VerifiedJwt) -> testing_api_pb2.JwtToken:
  """Converts a jwt.VerifiedJwt into a proto JwtToken.

  Only the claims the token reports as present (has_*()) are set on the proto.

  Raises:
    ValueError: if a custom claim value is not None, bool, int, float, str,
      dict or list.
  """
  token = testing_api_pb2.JwtToken()
  if verified_jwt.has_type_header():
    token.type_header.value = verified_jwt.type_header()
  if verified_jwt.has_issuer():
    token.issuer.value = verified_jwt.issuer()
  if verified_jwt.has_subject():
    token.subject.value = verified_jwt.subject()
  if verified_jwt.has_audiences():
    token.audiences.extend(verified_jwt.audiences())
  if verified_jwt.has_jwt_id():
    token.jwt_id.value = verified_jwt.jwt_id()
  if verified_jwt.has_expiration():
    _set_timestamp(token.expiration, verified_jwt.expiration())
  if verified_jwt.has_not_before():
    _set_timestamp(token.not_before, verified_jwt.not_before())
  if verified_jwt.has_issued_at():
    _set_timestamp(token.issued_at, verified_jwt.issued_at())
  for name in verified_jwt.custom_claim_names():
    value = verified_jwt.custom_claim(name)
    claim = token.custom_claims[name]
    if value is None:
      claim.null_value = testing_api_pb2.NULL_VALUE
    # bool must be tested before int/float because bool is a subclass of int.
    elif isinstance(value, bool):
      claim.bool_value = value
    elif isinstance(value, (int, float)):
      claim.number_value = value
    elif isinstance(value, str):
      claim.string_value = value
    elif isinstance(value, dict):
      claim.json_object_value = json.dumps(value)
    elif isinstance(value, list):
      claim.json_array_value = json.dumps(value)
    else:
      raise ValueError('claim %s has unknown type' % name)
  return token


def validator_from_proto(
    proto_validator: testing_api_pb2.JwtValidator) -> jwt.JwtValidator:
  """Converts a proto JwtValidator into a jwt.JwtValidator.

  Unset optional fields are passed to jwt.new_validator as None; the boolean
  flags are forwarded as-is.
  """
  clock_skew = None
  if proto_validator.HasField('clock_skew'):
    clock_skew = _from_duration_proto(proto_validator.clock_skew)
  return jwt.new_validator(
      expected_type_header=_optional_wrapper_value(proto_validator,
                                                   'expected_type_header'),
      expected_issuer=_optional_wrapper_value(proto_validator,
                                              'expected_issuer'),
      expected_audience=_optional_wrapper_value(proto_validator,
                                                'expected_audience'),
      ignore_type_header=proto_validator.ignore_type_header,
      ignore_issuer=proto_validator.ignore_issuer,
      ignore_audiences=proto_validator.ignore_audience,
      allow_missing_expiration=proto_validator.allow_missing_expiration,
      expect_issued_in_the_past=proto_validator.expect_issued_in_the_past,
      fixed_now=_optional_datetime(proto_validator, 'now'),
      clock_skew=clock_skew)


def _read_keyset_handle(serialized_keyset: bytes) -> tink.KeysetHandle:
  """Parses a binary-serialized cleartext keyset into a KeysetHandle."""
  return cleartext_keyset_handle.read(
      tink.BinaryKeysetReader(serialized_keyset))


class JwtServicer(testing_api_pb2_grpc.JwtServicer):
  """A service for signing and verifying JWTs.

  Every method reports tink.TinkError failures through the `err` field of its
  response instead of raising; other exceptions propagate to gRPC.
  """

  def CreateJwtMac(
      self, request: testing_api_pb2.CreationRequest,
      context: grpc.ServicerContext) -> testing_api_pb2.CreationResponse:
    """Creates a JwtMac without using it."""
    try:
      keyset_handle = _read_keyset_handle(
          request.annotated_keyset.serialized_keyset)
      keyset_handle.primitive(jwt.JwtMac)
      return testing_api_pb2.CreationResponse()
    except tink.TinkError as e:
      return testing_api_pb2.CreationResponse(err=str(e))

  def CreateJwtPublicKeySign(
      self, request: testing_api_pb2.CreationRequest,
      context: grpc.ServicerContext) -> testing_api_pb2.CreationResponse:
    """Creates a JwtPublicKeySign without using it."""
    try:
      keyset_handle = _read_keyset_handle(
          request.annotated_keyset.serialized_keyset)
      keyset_handle.primitive(jwt.JwtPublicKeySign)
      return testing_api_pb2.CreationResponse()
    except tink.TinkError as e:
      return testing_api_pb2.CreationResponse(err=str(e))

  def CreateJwtPublicKeyVerify(
      self, request: testing_api_pb2.CreationRequest,
      context: grpc.ServicerContext) -> testing_api_pb2.CreationResponse:
    """Creates a JwtPublicKeyVerify without using it."""
    try:
      keyset_handle = _read_keyset_handle(
          request.annotated_keyset.serialized_keyset)
      keyset_handle.primitive(jwt.JwtPublicKeyVerify)
      return testing_api_pb2.CreationResponse()
    except tink.TinkError as e:
      return testing_api_pb2.CreationResponse(err=str(e))

  def ComputeMacAndEncode(
      self, request: testing_api_pb2.JwtSignRequest,
      context: grpc.ServicerContext) -> testing_api_pb2.JwtSignResponse:
    """Computes a MACed compact JWT."""
    try:
      keyset_handle = _read_keyset_handle(
          request.annotated_keyset.serialized_keyset)
      p = keyset_handle.primitive(jwt.JwtMac)
      raw_jwt = raw_jwt_from_proto(request.raw_jwt)
      signed_compact_jwt = p.compute_mac_and_encode(raw_jwt)
      return testing_api_pb2.JwtSignResponse(
          signed_compact_jwt=signed_compact_jwt)
    except tink.TinkError as e:
      return testing_api_pb2.JwtSignResponse(err=str(e))

  def VerifyMacAndDecode(
      self, request: testing_api_pb2.JwtVerifyRequest,
      context: grpc.ServicerContext) -> testing_api_pb2.JwtVerifyResponse:
    """Verifies a MACed compact JWT and returns its decoded claims."""
    try:
      keyset_handle = _read_keyset_handle(
          request.annotated_keyset.serialized_keyset)
      validator = validator_from_proto(request.validator)
      p = keyset_handle.primitive(jwt.JwtMac)
      verified_jwt = p.verify_mac_and_decode(request.signed_compact_jwt,
                                             validator)
      return testing_api_pb2.JwtVerifyResponse(
          verified_jwt=verifiedjwt_to_proto(verified_jwt))
    except tink.TinkError as e:
      return testing_api_pb2.JwtVerifyResponse(err=str(e))

  def PublicKeySignAndEncode(
      self, request: testing_api_pb2.JwtSignRequest,
      context: grpc.ServicerContext) -> testing_api_pb2.JwtSignResponse:
    """Computes a signed compact JWT token."""
    try:
      keyset_handle = _read_keyset_handle(
          request.annotated_keyset.serialized_keyset)
      p = keyset_handle.primitive(jwt.JwtPublicKeySign)
      raw_jwt = raw_jwt_from_proto(request.raw_jwt)
      signed_compact_jwt = p.sign_and_encode(raw_jwt)
      return testing_api_pb2.JwtSignResponse(
          signed_compact_jwt=signed_compact_jwt)
    except tink.TinkError as e:
      return testing_api_pb2.JwtSignResponse(err=str(e))

  def PublicKeyVerifyAndDecode(
      self, request: testing_api_pb2.JwtVerifyRequest,
      context: grpc.ServicerContext) -> testing_api_pb2.JwtVerifyResponse:
    """Verifies the validity of the signed compact JWT token."""
    try:
      keyset_handle = _read_keyset_handle(
          request.annotated_keyset.serialized_keyset)
      validator = validator_from_proto(request.validator)
      p = keyset_handle.primitive(jwt.JwtPublicKeyVerify)
      verified_jwt = p.verify_and_decode(request.signed_compact_jwt, validator)
      return testing_api_pb2.JwtVerifyResponse(
          verified_jwt=verifiedjwt_to_proto(verified_jwt))
    except tink.TinkError as e:
      return testing_api_pb2.JwtVerifyResponse(err=str(e))

  def ToJwkSet(
      self, request: testing_api_pb2.JwtToJwkSetRequest,
      context: grpc.ServicerContext) -> testing_api_pb2.JwtToJwkSetResponse:
    """Converts a Tink Keyset with JWT keys into a JWK set."""
    try:
      keyset_handle = _read_keyset_handle(request.keyset)
      jwk_set = jwt.jwk_set_from_public_keyset_handle(keyset_handle)
      return testing_api_pb2.JwtToJwkSetResponse(jwk_set=jwk_set)
    except tink.TinkError as e:
      return testing_api_pb2.JwtToJwkSetResponse(err=str(e))

  def FromJwkSet(
      self, request: testing_api_pb2.JwtFromJwkSetRequest,
      context: grpc.ServicerContext) -> testing_api_pb2.JwtFromJwkSetResponse:
    """Converts a JWK set into a binary-serialized cleartext Tink Keyset."""
    try:
      keyset_handle = jwt.jwk_set_to_public_keyset_handle(request.jwk_set)
      keyset = io.BytesIO()
      cleartext_keyset_handle.write(
          tink.BinaryKeysetWriter(keyset), keyset_handle)
      return testing_api_pb2.JwtFromJwkSetResponse(keyset=keyset.getvalue())
    except tink.TinkError as e:
      return testing_api_pb2.JwtFromJwkSetResponse(err=str(e))