xref: /aosp_15_r20/external/tink/testing/python/jwt_service.py (revision e7b1675dde1b92d52ec075b0a92829627f2c52a5)
1# Copyright 2021 Google LLC
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#      http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS-IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14"""JWT testing service API implementations in Python."""
15
16import datetime
17import io
18import json
19
20from typing import Tuple
21
22import grpc
23import tink
24from tink import cleartext_keyset_handle
25
26from tink import jwt
27
28from google.protobuf import duration_pb2
29from google.protobuf import timestamp_pb2
30
31from protos import testing_api_pb2
32from protos import testing_api_pb2_grpc
33
34
35def _to_timestamp_tuple(t: datetime.datetime) -> Tuple[int, int]:
36  if not t.tzinfo:
37    raise ValueError('datetime must have tzinfo')
38  seconds = int(t.timestamp())
39  nanos = int((t.timestamp() - seconds) * 1e9)
40  return (seconds, nanos)
41
42
43def _from_timestamp_proto(
44    timestamp: timestamp_pb2.Timestamp) -> datetime.datetime:
45  t = timestamp.seconds + (timestamp.nanos / 1e9)
46  return datetime.datetime.fromtimestamp(t, datetime.timezone.utc)
47
48
49def _from_duration_proto(
50    duration: duration_pb2.Duration) -> datetime.timedelta:
51  return datetime.timedelta(seconds=duration.seconds)
52
53
def raw_jwt_from_proto(proto_raw_jwt: testing_api_pb2.JwtToken) -> jwt.RawJwt:
  """Builds a jwt.RawJwt from a testing-API JwtToken message.

  Optional string claims (type_header, issuer, subject, jwt_id) are read from
  their wrapper's `value` when set and passed as None otherwise. An empty
  audience list is passed as None. `without_expiration` is True exactly when
  the message carries no expiration. Custom claims are decoded according to
  which of their value fields is set; JSON object/array claims are parsed with
  json.loads.

  Raises:
    ValueError: if a custom claim has none of the known value fields set.
  """

  def wrapped_value(field_name):
    # Wrapper-typed fields keep their payload in `.value`; unset means None.
    if not proto_raw_jwt.HasField(field_name):
      return None
    return getattr(proto_raw_jwt, field_name).value

  def timestamp_value(field_name):
    # Timestamp fields are converted to aware datetimes; unset means None.
    if not proto_raw_jwt.HasField(field_name):
      return None
    return _from_timestamp_proto(getattr(proto_raw_jwt, field_name))

  def decode_claim(claim_name, claim):
    # Value fields are checked in this fixed order.
    if claim.HasField('null_value'):
      return None
    if claim.HasField('number_value'):
      return claim.number_value
    if claim.HasField('string_value'):
      return claim.string_value
    if claim.HasField('bool_value'):
      return claim.bool_value
    if claim.HasField('json_object_value'):
      return json.loads(claim.json_object_value)
    if claim.HasField('json_array_value'):
      return json.loads(claim.json_array_value)
    raise ValueError('claim %s has unknown type' % claim_name)

  type_header = wrapped_value('type_header')
  issuer = wrapped_value('issuer')
  subject = wrapped_value('subject')
  audiences = list(proto_raw_jwt.audiences) or None
  jwt_id = wrapped_value('jwt_id')
  custom_claims = {
      claim_name: decode_claim(claim_name, claim)
      for claim_name, claim in proto_raw_jwt.custom_claims.items()
  }
  expiration = timestamp_value('expiration')
  not_before = timestamp_value('not_before')
  issued_at = timestamp_value('issued_at')
  return jwt.new_raw_jwt(
      type_header=type_header,
      issuer=issuer,
      subject=subject,
      audiences=audiences,
      jwt_id=jwt_id,
      expiration=expiration,
      without_expiration=expiration is None,
      not_before=not_before,
      issued_at=issued_at,
      custom_claims=custom_claims)
108
109
def verifiedjwt_to_proto(
    verified_jwt: jwt.VerifiedJwt) -> testing_api_pb2.JwtToken:
  """Serializes a jwt.VerifiedJwt into a testing-API JwtToken message.

  Only the claims the token reports via its `has_*()` methods are copied.
  Timestamps are split with _to_timestamp_tuple. Each custom claim goes into
  the value field matching its Python type: None, bool, int/float, str, or
  dict/list (the latter two as JSON text).

  Raises:
    ValueError: if a custom claim has a type with no matching field.
  """
  token = testing_api_pb2.JwtToken()

  # For each claim `name`, VerifiedJwt exposes `has_<name>()` and `<name>()`,
  # and the proto field of the same name is the copy target.
  def copy_string(name):
    if getattr(verified_jwt, 'has_' + name)():
      getattr(token, name).value = getattr(verified_jwt, name)()

  def copy_timestamp(name):
    if getattr(verified_jwt, 'has_' + name)():
      target = getattr(token, name)
      target.seconds, target.nanos = _to_timestamp_tuple(
          getattr(verified_jwt, name)())

  copy_string('type_header')
  copy_string('issuer')
  copy_string('subject')
  if verified_jwt.has_audiences():
    token.audiences.extend(verified_jwt.audiences())
  copy_string('jwt_id')
  copy_timestamp('expiration')
  copy_timestamp('not_before')
  copy_timestamp('issued_at')

  for claim_name in verified_jwt.custom_claim_names():
    claim_value = verified_jwt.custom_claim(claim_name)
    claim_proto = token.custom_claims[claim_name]
    # bool is tested before int/float because bool is a subclass of int.
    if claim_value is None:
      claim_proto.null_value = testing_api_pb2.NULL_VALUE
    elif isinstance(claim_value, bool):
      claim_proto.bool_value = claim_value
    elif isinstance(claim_value, (int, float)):
      claim_proto.number_value = claim_value
    elif isinstance(claim_value, str):
      claim_proto.string_value = claim_value
    elif isinstance(claim_value, dict):
      claim_proto.json_object_value = json.dumps(claim_value)
    elif isinstance(claim_value, list):
      claim_proto.json_array_value = json.dumps(claim_value)
    else:
      raise ValueError('claim %s has unknown type' % claim_name)
  return token
153
154
def validator_from_proto(
    proto_validator: testing_api_pb2.JwtValidator) -> jwt.JwtValidator:
  """Builds a jwt.JwtValidator from a testing-API JwtValidator message.

  The expected type header, issuer and audience are taken from their
  wrapper's `value` when set and passed as None otherwise. `now` becomes
  `fixed_now` and `clock_skew` becomes a timedelta, each only when present.
  The boolean switches are forwarded unchanged.
  """

  def optional(field_name, convert):
    # None when the field is unset, otherwise the converted field message.
    if not proto_validator.HasField(field_name):
      return None
    return convert(getattr(proto_validator, field_name))

  def unwrap(wrapper):
    return wrapper.value

  expected_type_header = optional('expected_type_header', unwrap)
  expected_issuer = optional('expected_issuer', unwrap)
  expected_audience = optional('expected_audience', unwrap)
  fixed_now = optional('now', _from_timestamp_proto)
  clock_skew = optional('clock_skew', _from_duration_proto)
  return jwt.new_validator(
      expected_type_header=expected_type_header,
      expected_issuer=expected_issuer,
      expected_audience=expected_audience,
      ignore_type_header=proto_validator.ignore_type_header,
      ignore_issuer=proto_validator.ignore_issuer,
      ignore_audiences=proto_validator.ignore_audience,
      allow_missing_expiration=proto_validator.allow_missing_expiration,
      expect_issued_in_the_past=proto_validator.expect_issued_in_the_past,
      fixed_now=fixed_now,
      clock_skew=clock_skew)
184
185
class JwtServicer(testing_api_pb2_grpc.JwtServicer):
  """gRPC servicer that runs Tink's Python JWT primitives for the test harness.

  Every RPC reports a tink.TinkError through the `err` field of its response
  message instead of failing the call. Any other exception propagates to gRPC.
  """

  @staticmethod
  def _read_keyset(serialized_keyset):
    # Parses a binary keyset through the cleartext (unencrypted) reader.
    return cleartext_keyset_handle.read(
        tink.BinaryKeysetReader(serialized_keyset))

  def _try_create(self, request, primitive_class):
    """Shared body of the Create* RPCs: build the primitive, then drop it."""
    try:
      handle = self._read_keyset(request.annotated_keyset.serialized_keyset)
      handle.primitive(primitive_class)
      return testing_api_pb2.CreationResponse()
    except tink.TinkError as e:
      return testing_api_pb2.CreationResponse(err=str(e))

  def CreateJwtMac(
      self, request: testing_api_pb2.CreationRequest,
      context: grpc.ServicerContext) -> testing_api_pb2.CreationResponse:
    """Checks that a JwtMac can be created from the keyset; does not use it."""
    return self._try_create(request, jwt.JwtMac)

  def CreateJwtPublicKeySign(
      self, request: testing_api_pb2.CreationRequest,
      context: grpc.ServicerContext) -> testing_api_pb2.CreationResponse:
    """Checks that a JwtPublicKeySign can be created; does not use it."""
    return self._try_create(request, jwt.JwtPublicKeySign)

  def CreateJwtPublicKeyVerify(
      self, request: testing_api_pb2.CreationRequest,
      context: grpc.ServicerContext) -> testing_api_pb2.CreationResponse:
    """Checks that a JwtPublicKeyVerify can be created; does not use it."""
    return self._try_create(request, jwt.JwtPublicKeyVerify)

  def ComputeMacAndEncode(
      self, request: testing_api_pb2.JwtSignRequest,
      context: grpc.ServicerContext) -> testing_api_pb2.JwtSignResponse:
    """MACs the request's raw JWT and returns the compact encoding."""
    try:
      handle = self._read_keyset(request.annotated_keyset.serialized_keyset)
      mac = handle.primitive(jwt.JwtMac)
      compact = mac.compute_mac_and_encode(raw_jwt_from_proto(request.raw_jwt))
      return testing_api_pb2.JwtSignResponse(signed_compact_jwt=compact)
    except tink.TinkError as e:
      return testing_api_pb2.JwtSignResponse(err=str(e))

  def VerifyMacAndDecode(
      self, request: testing_api_pb2.JwtVerifyRequest,
      context: grpc.ServicerContext) -> testing_api_pb2.JwtVerifyResponse:
    """Verifies a MACed compact JWT and returns its decoded claims."""
    try:
      handle = self._read_keyset(request.annotated_keyset.serialized_keyset)
      # The validator is built before the primitive, as the RPC always has.
      validator = validator_from_proto(request.validator)
      mac = handle.primitive(jwt.JwtMac)
      verified = mac.verify_mac_and_decode(request.signed_compact_jwt,
                                           validator)
      return testing_api_pb2.JwtVerifyResponse(
          verified_jwt=verifiedjwt_to_proto(verified))
    except tink.TinkError as e:
      return testing_api_pb2.JwtVerifyResponse(err=str(e))

  def PublicKeySignAndEncode(
      self, request: testing_api_pb2.JwtSignRequest,
      context: grpc.ServicerContext) -> testing_api_pb2.JwtSignResponse:
    """Signs the request's raw JWT and returns the compact encoding."""
    try:
      handle = self._read_keyset(request.annotated_keyset.serialized_keyset)
      signer = handle.primitive(jwt.JwtPublicKeySign)
      compact = signer.sign_and_encode(raw_jwt_from_proto(request.raw_jwt))
      return testing_api_pb2.JwtSignResponse(signed_compact_jwt=compact)
    except tink.TinkError as e:
      return testing_api_pb2.JwtSignResponse(err=str(e))

  def PublicKeyVerifyAndDecode(
      self, request: testing_api_pb2.JwtVerifyRequest,
      context: grpc.ServicerContext) -> testing_api_pb2.JwtVerifyResponse:
    """Verifies a signed compact JWT and returns its decoded claims."""
    try:
      handle = self._read_keyset(request.annotated_keyset.serialized_keyset)
      validator = validator_from_proto(request.validator)
      verifier = handle.primitive(jwt.JwtPublicKeyVerify)
      verified = verifier.verify_and_decode(request.signed_compact_jwt,
                                            validator)
      return testing_api_pb2.JwtVerifyResponse(
          verified_jwt=verifiedjwt_to_proto(verified))
    except tink.TinkError as e:
      return testing_api_pb2.JwtVerifyResponse(err=str(e))

  def ToJwkSet(
      self, request: testing_api_pb2.JwtToJwkSetRequest,
      context: grpc.ServicerContext) -> testing_api_pb2.JwtToJwkSetResponse:
    """Exports a Tink keyset of public JWT keys as a JWK set."""
    try:
      handle = self._read_keyset(request.keyset)
      return testing_api_pb2.JwtToJwkSetResponse(
          jwk_set=jwt.jwk_set_from_public_keyset_handle(handle))
    except tink.TinkError as e:
      return testing_api_pb2.JwtToJwkSetResponse(err=str(e))

  def FromJwkSet(
      self, request: testing_api_pb2.JwtFromJwkSetRequest,
      context: grpc.ServicerContext) -> testing_api_pb2.JwtFromJwkSetResponse:
    """Imports a JWK set and returns it as a binary cleartext Tink keyset."""
    try:
      handle = jwt.jwk_set_to_public_keyset_handle(request.jwk_set)
      buffer = io.BytesIO()
      cleartext_keyset_handle.write(tink.BinaryKeysetWriter(buffer), handle)
      return testing_api_pb2.JwtFromJwkSetResponse(keyset=buffer.getvalue())
    except tink.TinkError as e:
      return testing_api_pb2.JwtFromJwkSetResponse(err=str(e))
310