xref: /aosp_15_r20/external/tink/go/jwt/raw_jwt.go (revision e7b1675dde1b92d52ec075b0a92829627f2c52a5)
1*e7b1675dSTing-Kang Chang// Copyright 2021 Google LLC
2*e7b1675dSTing-Kang Chang//
3*e7b1675dSTing-Kang Chang// Licensed under the Apache License, Version 2.0 (the "License");
4*e7b1675dSTing-Kang Chang// you may not use this file except in compliance with the License.
5*e7b1675dSTing-Kang Chang// You may obtain a copy of the License at
6*e7b1675dSTing-Kang Chang//
7*e7b1675dSTing-Kang Chang//      http://www.apache.org/licenses/LICENSE-2.0
8*e7b1675dSTing-Kang Chang//
9*e7b1675dSTing-Kang Chang// Unless required by applicable law or agreed to in writing, software
10*e7b1675dSTing-Kang Chang// distributed under the License is distributed on an "AS IS" BASIS,
11*e7b1675dSTing-Kang Chang// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12*e7b1675dSTing-Kang Chang// See the License for the specific language governing permissions and
13*e7b1675dSTing-Kang Chang// limitations under the License.
14*e7b1675dSTing-Kang Chang//
15*e7b1675dSTing-Kang Chang////////////////////////////////////////////////////////////////////////////////
16*e7b1675dSTing-Kang Chang
17*e7b1675dSTing-Kang Changpackage jwt
18*e7b1675dSTing-Kang Chang
19*e7b1675dSTing-Kang Changimport (
20*e7b1675dSTing-Kang Chang	"fmt"
21*e7b1675dSTing-Kang Chang	"time"
22*e7b1675dSTing-Kang Chang	"unicode/utf8"
23*e7b1675dSTing-Kang Chang
24*e7b1675dSTing-Kang Chang	spb "google.golang.org/protobuf/types/known/structpb"
25*e7b1675dSTing-Kang Chang)
26*e7b1675dSTing-Kang Chang
// Names of the registered JWT claims (RFC 7519, section 4.1) handled by this
// package.
const (
	claimIssuer     = "iss"
	claimSubject    = "sub"
	claimAudience   = "aud"
	claimExpiration = "exp"
	claimNotBefore  = "nbf"
	claimIssuedAt   = "iat"
	claimJWTID      = "jti"
)

// Inclusive bounds, in seconds since the Unix epoch, accepted for the time
// claims ("exp", "nbf", "iat"). jwtTimestampMax is 9999-12-31T23:59:59Z.
const (
	jwtTimestampMin = 0
	jwtTimestampMax = 253402300799
)
39*e7b1675dSTing-Kang Chang
// RawJWTOptions holds the values NewRawJWT uses to build an unsigned JSON Web
// Token (JWT), https://tools.ietf.org/html/rfc7519.
//
// It covers all payload claims and a subset of the headers. It does not
// contain any headers that depend on the key, such as "alg" or "kid", because
// these headers are chosen when the token is signed and encoded, and should not
// be chosen by the user. This ensures that the key can be changed without any
// changes to the user code.
//
// Nil pointer fields are left out of the token. Exactly one of ExpiresAt and
// WithoutExpiration must be set, and at most one of Audience and Audiences may
// be non-nil.
type RawJWTOptions struct {
	// Audiences is stored as a JSON array in the "aud" claim. If non-nil it
	// must contain at least one entry.
	Audiences []string

	// Audience is stored as a single JSON string in the "aud" claim.
	Audience *string

	// Subject is stored in the "sub" claim.
	Subject *string

	// Issuer is stored in the "iss" claim.
	Issuer *string

	// JWTID is stored in the "jti" claim.
	JWTID *string

	// IssuedAt is stored in the "iat" claim as whole seconds since the epoch.
	IssuedAt *time.Time

	// ExpiresAt is stored in the "exp" claim as whole seconds since the epoch.
	ExpiresAt *time.Time

	// NotBefore is stored in the "nbf" claim as whole seconds since the epoch.
	NotBefore *time.Time

	// CustomClaims holds additional claims. Names must not collide with the
	// registered claims, and values must be accepted by structpb.NewValue.
	CustomClaims map[string]interface{}

	// TypeHeader, if non-nil, is the token's type header.
	TypeHeader *string

	// WithoutExpiration must be true to build a token with no "exp" claim.
	WithoutExpiration bool
}
61*e7b1675dSTing-Kang Chang
62*e7b1675dSTing-Kang Chang// RawJWT is an unsigned JSON Web Token (JWT), https://tools.ietf.org/html/rfc7519.
63*e7b1675dSTing-Kang Changtype RawJWT struct {
64*e7b1675dSTing-Kang Chang	jsonpb     *spb.Struct
65*e7b1675dSTing-Kang Chang	typeHeader *string
66*e7b1675dSTing-Kang Chang}
67*e7b1675dSTing-Kang Chang
68*e7b1675dSTing-Kang Chang// NewRawJWT constructs a new RawJWT token based on the RawJwtOptions provided.
69*e7b1675dSTing-Kang Changfunc NewRawJWT(opts *RawJWTOptions) (*RawJWT, error) {
70*e7b1675dSTing-Kang Chang	if opts == nil {
71*e7b1675dSTing-Kang Chang		return nil, fmt.Errorf("jwt options can't be nil")
72*e7b1675dSTing-Kang Chang	}
73*e7b1675dSTing-Kang Chang	payload, err := createPayload(opts)
74*e7b1675dSTing-Kang Chang	if err != nil {
75*e7b1675dSTing-Kang Chang		return nil, err
76*e7b1675dSTing-Kang Chang	}
77*e7b1675dSTing-Kang Chang	if err := validatePayload(payload); err != nil {
78*e7b1675dSTing-Kang Chang		return nil, err
79*e7b1675dSTing-Kang Chang	}
80*e7b1675dSTing-Kang Chang	return &RawJWT{
81*e7b1675dSTing-Kang Chang		jsonpb:     payload,
82*e7b1675dSTing-Kang Chang		typeHeader: opts.TypeHeader,
83*e7b1675dSTing-Kang Chang	}, nil
84*e7b1675dSTing-Kang Chang}
85*e7b1675dSTing-Kang Chang
86*e7b1675dSTing-Kang Chang// NewRawJWTFromJSON builds a RawJWT from a marshaled JSON.
87*e7b1675dSTing-Kang Chang// Users shouldn't call this function and instead use NewRawJWT.
88*e7b1675dSTing-Kang Changfunc NewRawJWTFromJSON(typeHeader *string, jsonPayload []byte) (*RawJWT, error) {
89*e7b1675dSTing-Kang Chang	payload := &spb.Struct{}
90*e7b1675dSTing-Kang Chang	if err := payload.UnmarshalJSON(jsonPayload); err != nil {
91*e7b1675dSTing-Kang Chang		return nil, err
92*e7b1675dSTing-Kang Chang	}
93*e7b1675dSTing-Kang Chang	if err := validatePayload(payload); err != nil {
94*e7b1675dSTing-Kang Chang		return nil, err
95*e7b1675dSTing-Kang Chang	}
96*e7b1675dSTing-Kang Chang	return &RawJWT{
97*e7b1675dSTing-Kang Chang		jsonpb:     payload,
98*e7b1675dSTing-Kang Chang		typeHeader: typeHeader,
99*e7b1675dSTing-Kang Chang	}, nil
100*e7b1675dSTing-Kang Chang}
101*e7b1675dSTing-Kang Chang
102*e7b1675dSTing-Kang Chang// JSONPayload marshals a RawJWT payload to JSON.
103*e7b1675dSTing-Kang Changfunc (r *RawJWT) JSONPayload() ([]byte, error) {
104*e7b1675dSTing-Kang Chang	return r.jsonpb.MarshalJSON()
105*e7b1675dSTing-Kang Chang}
106*e7b1675dSTing-Kang Chang
107*e7b1675dSTing-Kang Chang// HasTypeHeader returns whether a RawJWT contains a type header.
108*e7b1675dSTing-Kang Changfunc (r *RawJWT) HasTypeHeader() bool {
109*e7b1675dSTing-Kang Chang	return r.typeHeader != nil
110*e7b1675dSTing-Kang Chang}
111*e7b1675dSTing-Kang Chang
112*e7b1675dSTing-Kang Chang// TypeHeader returns the JWT type header.
113*e7b1675dSTing-Kang Changfunc (r *RawJWT) TypeHeader() (string, error) {
114*e7b1675dSTing-Kang Chang	if !r.HasTypeHeader() {
115*e7b1675dSTing-Kang Chang		return "", fmt.Errorf("no type header present")
116*e7b1675dSTing-Kang Chang	}
117*e7b1675dSTing-Kang Chang	return *r.typeHeader, nil
118*e7b1675dSTing-Kang Chang}
119*e7b1675dSTing-Kang Chang
120*e7b1675dSTing-Kang Chang// HasAudiences checks whether a JWT contains the audience claim ('aud').
121*e7b1675dSTing-Kang Changfunc (r *RawJWT) HasAudiences() bool {
122*e7b1675dSTing-Kang Chang	return r.hasField(claimAudience)
123*e7b1675dSTing-Kang Chang}
124*e7b1675dSTing-Kang Chang
125*e7b1675dSTing-Kang Chang// Audiences returns a list of audiences from the 'aud' claim. If the 'aud' claim is a single string, it is converted into a list with a single entry.
126*e7b1675dSTing-Kang Changfunc (r *RawJWT) Audiences() ([]string, error) {
127*e7b1675dSTing-Kang Chang	aud, ok := r.field(claimAudience)
128*e7b1675dSTing-Kang Chang	if !ok {
129*e7b1675dSTing-Kang Chang		return nil, fmt.Errorf("no audience claim found")
130*e7b1675dSTing-Kang Chang	}
131*e7b1675dSTing-Kang Chang	if err := validateAudienceClaim(aud); err != nil {
132*e7b1675dSTing-Kang Chang		return nil, err
133*e7b1675dSTing-Kang Chang	}
134*e7b1675dSTing-Kang Chang	if val, isString := aud.GetKind().(*spb.Value_StringValue); isString {
135*e7b1675dSTing-Kang Chang		return []string{val.StringValue}, nil
136*e7b1675dSTing-Kang Chang	}
137*e7b1675dSTing-Kang Chang	s := make([]string, 0, len(aud.GetListValue().GetValues()))
138*e7b1675dSTing-Kang Chang	for _, a := range aud.GetListValue().GetValues() {
139*e7b1675dSTing-Kang Chang		s = append(s, a.GetStringValue())
140*e7b1675dSTing-Kang Chang	}
141*e7b1675dSTing-Kang Chang	return s, nil
142*e7b1675dSTing-Kang Chang}
143*e7b1675dSTing-Kang Chang
144*e7b1675dSTing-Kang Chang// HasSubject checks whether a JWT contains an issuer claim ('sub').
145*e7b1675dSTing-Kang Changfunc (r *RawJWT) HasSubject() bool {
146*e7b1675dSTing-Kang Chang	return r.hasField(claimSubject)
147*e7b1675dSTing-Kang Chang}
148*e7b1675dSTing-Kang Chang
149*e7b1675dSTing-Kang Chang// Subject returns the subject claim ('sub') or an error if no claim is present.
150*e7b1675dSTing-Kang Changfunc (r *RawJWT) Subject() (string, error) {
151*e7b1675dSTing-Kang Chang	return r.stringClaim(claimSubject)
152*e7b1675dSTing-Kang Chang}
153*e7b1675dSTing-Kang Chang
154*e7b1675dSTing-Kang Chang// HasIssuer checks whether a JWT contains an issuer claim ('iss').
155*e7b1675dSTing-Kang Changfunc (r *RawJWT) HasIssuer() bool {
156*e7b1675dSTing-Kang Chang	return r.hasField(claimIssuer)
157*e7b1675dSTing-Kang Chang}
158*e7b1675dSTing-Kang Chang
159*e7b1675dSTing-Kang Chang// Issuer returns the issuer claim ('iss') or an error if no claim is present.
160*e7b1675dSTing-Kang Changfunc (r *RawJWT) Issuer() (string, error) {
161*e7b1675dSTing-Kang Chang	return r.stringClaim(claimIssuer)
162*e7b1675dSTing-Kang Chang}
163*e7b1675dSTing-Kang Chang
164*e7b1675dSTing-Kang Chang// HasJWTID checks whether a JWT contains an JWT ID claim ('jti').
165*e7b1675dSTing-Kang Changfunc (r *RawJWT) HasJWTID() bool {
166*e7b1675dSTing-Kang Chang	return r.hasField(claimJWTID)
167*e7b1675dSTing-Kang Chang}
168*e7b1675dSTing-Kang Chang
169*e7b1675dSTing-Kang Chang// JWTID returns the JWT ID claim ('jti') or an error if no claim is present.
170*e7b1675dSTing-Kang Changfunc (r *RawJWT) JWTID() (string, error) {
171*e7b1675dSTing-Kang Chang	return r.stringClaim(claimJWTID)
172*e7b1675dSTing-Kang Chang}
173*e7b1675dSTing-Kang Chang
174*e7b1675dSTing-Kang Chang// HasIssuedAt checks whether a JWT contains an issued at claim ('iat').
175*e7b1675dSTing-Kang Changfunc (r *RawJWT) HasIssuedAt() bool {
176*e7b1675dSTing-Kang Chang	return r.hasField(claimIssuedAt)
177*e7b1675dSTing-Kang Chang}
178*e7b1675dSTing-Kang Chang
179*e7b1675dSTing-Kang Chang// IssuedAt returns the issued at claim ('iat') or an error if no claim is present.
180*e7b1675dSTing-Kang Changfunc (r *RawJWT) IssuedAt() (time.Time, error) {
181*e7b1675dSTing-Kang Chang	return r.timeClaim(claimIssuedAt)
182*e7b1675dSTing-Kang Chang}
183*e7b1675dSTing-Kang Chang
184*e7b1675dSTing-Kang Chang// HasExpiration checks whether a JWT contains an expiration time claim ('exp').
185*e7b1675dSTing-Kang Changfunc (r *RawJWT) HasExpiration() bool {
186*e7b1675dSTing-Kang Chang	return r.hasField(claimExpiration)
187*e7b1675dSTing-Kang Chang}
188*e7b1675dSTing-Kang Chang
189*e7b1675dSTing-Kang Chang// ExpiresAt returns the expiration claim ('exp') or an error if no claim is present.
190*e7b1675dSTing-Kang Changfunc (r *RawJWT) ExpiresAt() (time.Time, error) {
191*e7b1675dSTing-Kang Chang	return r.timeClaim(claimExpiration)
192*e7b1675dSTing-Kang Chang}
193*e7b1675dSTing-Kang Chang
194*e7b1675dSTing-Kang Chang// HasNotBefore checks whether a JWT contains a not before claim ('nbf').
195*e7b1675dSTing-Kang Changfunc (r *RawJWT) HasNotBefore() bool {
196*e7b1675dSTing-Kang Chang	return r.hasField(claimNotBefore)
197*e7b1675dSTing-Kang Chang}
198*e7b1675dSTing-Kang Chang
199*e7b1675dSTing-Kang Chang// NotBefore returns the not before claim ('nbf') or an error if no claim is present.
200*e7b1675dSTing-Kang Changfunc (r *RawJWT) NotBefore() (time.Time, error) {
201*e7b1675dSTing-Kang Chang	return r.timeClaim(claimNotBefore)
202*e7b1675dSTing-Kang Chang}
203*e7b1675dSTing-Kang Chang
204*e7b1675dSTing-Kang Chang// HasStringClaim checks whether a claim of type string is present.
205*e7b1675dSTing-Kang Changfunc (r *RawJWT) HasStringClaim(name string) bool {
206*e7b1675dSTing-Kang Chang	return !isRegisteredClaim(name) && r.hasClaimOfKind(name, &spb.Value{Kind: &spb.Value_StringValue{}})
207*e7b1675dSTing-Kang Chang}
208*e7b1675dSTing-Kang Chang
209*e7b1675dSTing-Kang Chang// StringClaim returns a custom string claim or an error if no claim is present.
210*e7b1675dSTing-Kang Changfunc (r *RawJWT) StringClaim(name string) (string, error) {
211*e7b1675dSTing-Kang Chang	if isRegisteredClaim(name) {
212*e7b1675dSTing-Kang Chang		return "", fmt.Errorf("claim '%q' is a registered claim", name)
213*e7b1675dSTing-Kang Chang	}
214*e7b1675dSTing-Kang Chang	return r.stringClaim(name)
215*e7b1675dSTing-Kang Chang}
216*e7b1675dSTing-Kang Chang
217*e7b1675dSTing-Kang Chang// HasNumberClaim checks whether a claim of type number is present.
218*e7b1675dSTing-Kang Changfunc (r *RawJWT) HasNumberClaim(name string) bool {
219*e7b1675dSTing-Kang Chang	return !isRegisteredClaim(name) && r.hasClaimOfKind(name, &spb.Value{Kind: &spb.Value_NumberValue{}})
220*e7b1675dSTing-Kang Chang}
221*e7b1675dSTing-Kang Chang
222*e7b1675dSTing-Kang Chang// NumberClaim returns a custom number claim or an error if no claim is present.
223*e7b1675dSTing-Kang Changfunc (r *RawJWT) NumberClaim(name string) (float64, error) {
224*e7b1675dSTing-Kang Chang	if isRegisteredClaim(name) {
225*e7b1675dSTing-Kang Chang		return 0, fmt.Errorf("claim '%q' is a registered claim", name)
226*e7b1675dSTing-Kang Chang	}
227*e7b1675dSTing-Kang Chang	return r.numberClaim(name)
228*e7b1675dSTing-Kang Chang}
229*e7b1675dSTing-Kang Chang
230*e7b1675dSTing-Kang Chang// HasBooleanClaim checks whether a claim of type boolean is present.
231*e7b1675dSTing-Kang Changfunc (r *RawJWT) HasBooleanClaim(name string) bool {
232*e7b1675dSTing-Kang Chang	return r.hasClaimOfKind(name, &spb.Value{Kind: &spb.Value_BoolValue{}})
233*e7b1675dSTing-Kang Chang}
234*e7b1675dSTing-Kang Chang
235*e7b1675dSTing-Kang Chang// BooleanClaim returns a custom bool claim or an error if no claim is present.
236*e7b1675dSTing-Kang Changfunc (r *RawJWT) BooleanClaim(name string) (bool, error) {
237*e7b1675dSTing-Kang Chang	val, err := r.customClaim(name)
238*e7b1675dSTing-Kang Chang	if err != nil {
239*e7b1675dSTing-Kang Chang		return false, err
240*e7b1675dSTing-Kang Chang	}
241*e7b1675dSTing-Kang Chang	b, ok := val.Kind.(*spb.Value_BoolValue)
242*e7b1675dSTing-Kang Chang	if !ok {
243*e7b1675dSTing-Kang Chang		return false, fmt.Errorf("claim '%q' is not a boolean", name)
244*e7b1675dSTing-Kang Chang	}
245*e7b1675dSTing-Kang Chang	return b.BoolValue, nil
246*e7b1675dSTing-Kang Chang}
247*e7b1675dSTing-Kang Chang
248*e7b1675dSTing-Kang Chang// HasNullClaim checks whether a claim of type null is present.
249*e7b1675dSTing-Kang Changfunc (r *RawJWT) HasNullClaim(name string) bool {
250*e7b1675dSTing-Kang Chang	return r.hasClaimOfKind(name, &spb.Value{Kind: &spb.Value_NullValue{}})
251*e7b1675dSTing-Kang Chang}
252*e7b1675dSTing-Kang Chang
253*e7b1675dSTing-Kang Chang// HasArrayClaim checks whether a claim of type list is present.
254*e7b1675dSTing-Kang Changfunc (r *RawJWT) HasArrayClaim(name string) bool {
255*e7b1675dSTing-Kang Chang	return !isRegisteredClaim(name) && r.hasClaimOfKind(name, &spb.Value{Kind: &spb.Value_ListValue{}})
256*e7b1675dSTing-Kang Chang}
257*e7b1675dSTing-Kang Chang
258*e7b1675dSTing-Kang Chang// ArrayClaim returns a slice representing a JSON array for a claim or an error if the claim is empty.
259*e7b1675dSTing-Kang Changfunc (r *RawJWT) ArrayClaim(name string) ([]interface{}, error) {
260*e7b1675dSTing-Kang Chang	val, err := r.customClaim(name)
261*e7b1675dSTing-Kang Chang	if err != nil {
262*e7b1675dSTing-Kang Chang		return nil, err
263*e7b1675dSTing-Kang Chang	}
264*e7b1675dSTing-Kang Chang	if val.GetListValue() == nil {
265*e7b1675dSTing-Kang Chang		return nil, fmt.Errorf("claim '%q' is not a list", name)
266*e7b1675dSTing-Kang Chang	}
267*e7b1675dSTing-Kang Chang	return val.GetListValue().AsSlice(), nil
268*e7b1675dSTing-Kang Chang}
269*e7b1675dSTing-Kang Chang
270*e7b1675dSTing-Kang Chang// HasObjectClaim checks whether a claim of type JSON object is present.
271*e7b1675dSTing-Kang Changfunc (r *RawJWT) HasObjectClaim(name string) bool {
272*e7b1675dSTing-Kang Chang	return r.hasClaimOfKind(name, &spb.Value{Kind: &spb.Value_StructValue{}})
273*e7b1675dSTing-Kang Chang}
274*e7b1675dSTing-Kang Chang
275*e7b1675dSTing-Kang Chang// ObjectClaim returns a map representing a JSON object for a claim or an error if the claim is empty.
276*e7b1675dSTing-Kang Changfunc (r *RawJWT) ObjectClaim(name string) (map[string]interface{}, error) {
277*e7b1675dSTing-Kang Chang	val, err := r.customClaim(name)
278*e7b1675dSTing-Kang Chang	if err != nil {
279*e7b1675dSTing-Kang Chang		return nil, err
280*e7b1675dSTing-Kang Chang	}
281*e7b1675dSTing-Kang Chang	if val.GetStructValue() == nil {
282*e7b1675dSTing-Kang Chang		return nil, fmt.Errorf("claim '%q' is not a JSON object", name)
283*e7b1675dSTing-Kang Chang	}
284*e7b1675dSTing-Kang Chang	return val.GetStructValue().AsMap(), err
285*e7b1675dSTing-Kang Chang}
286*e7b1675dSTing-Kang Chang
287*e7b1675dSTing-Kang Chang// CustomClaimNames returns a list with the name of custom claims in a RawJWT.
288*e7b1675dSTing-Kang Changfunc (r *RawJWT) CustomClaimNames() []string {
289*e7b1675dSTing-Kang Chang	names := []string{}
290*e7b1675dSTing-Kang Chang	for key := range r.jsonpb.GetFields() {
291*e7b1675dSTing-Kang Chang		if !isRegisteredClaim(key) {
292*e7b1675dSTing-Kang Chang			names = append(names, key)
293*e7b1675dSTing-Kang Chang		}
294*e7b1675dSTing-Kang Chang	}
295*e7b1675dSTing-Kang Chang	return names
296*e7b1675dSTing-Kang Chang}
297*e7b1675dSTing-Kang Chang
298*e7b1675dSTing-Kang Changfunc (r *RawJWT) timeClaim(name string) (time.Time, error) {
299*e7b1675dSTing-Kang Chang	n, err := r.numberClaim(name)
300*e7b1675dSTing-Kang Chang	if err != nil {
301*e7b1675dSTing-Kang Chang		return time.Time{}, err
302*e7b1675dSTing-Kang Chang	}
303*e7b1675dSTing-Kang Chang	return time.Unix(int64(n), 0), err
304*e7b1675dSTing-Kang Chang}
305*e7b1675dSTing-Kang Chang
306*e7b1675dSTing-Kang Changfunc (r *RawJWT) numberClaim(name string) (float64, error) {
307*e7b1675dSTing-Kang Chang	val, ok := r.field(name)
308*e7b1675dSTing-Kang Chang	if !ok {
309*e7b1675dSTing-Kang Chang		return 0, fmt.Errorf("no '%q' claim found", name)
310*e7b1675dSTing-Kang Chang	}
311*e7b1675dSTing-Kang Chang	s, ok := val.Kind.(*spb.Value_NumberValue)
312*e7b1675dSTing-Kang Chang	if !ok {
313*e7b1675dSTing-Kang Chang		return 0, fmt.Errorf("claim '%q' is not a number", name)
314*e7b1675dSTing-Kang Chang	}
315*e7b1675dSTing-Kang Chang	return s.NumberValue, nil
316*e7b1675dSTing-Kang Chang}
317*e7b1675dSTing-Kang Chang
318*e7b1675dSTing-Kang Changfunc (r *RawJWT) stringClaim(name string) (string, error) {
319*e7b1675dSTing-Kang Chang	val, ok := r.field(name)
320*e7b1675dSTing-Kang Chang	if !ok {
321*e7b1675dSTing-Kang Chang		return "", fmt.Errorf("no '%q' claim found", name)
322*e7b1675dSTing-Kang Chang	}
323*e7b1675dSTing-Kang Chang	s, ok := val.Kind.(*spb.Value_StringValue)
324*e7b1675dSTing-Kang Chang	if !ok {
325*e7b1675dSTing-Kang Chang		return "", fmt.Errorf("claim '%q' is not a string", name)
326*e7b1675dSTing-Kang Chang	}
327*e7b1675dSTing-Kang Chang	if !utf8.ValidString(s.StringValue) {
328*e7b1675dSTing-Kang Chang		return "", fmt.Errorf("claim '%q' is not a valid utf-8 encoded string", name)
329*e7b1675dSTing-Kang Chang	}
330*e7b1675dSTing-Kang Chang	return s.StringValue, nil
331*e7b1675dSTing-Kang Chang}
332*e7b1675dSTing-Kang Chang
333*e7b1675dSTing-Kang Changfunc (r *RawJWT) hasClaimOfKind(name string, exp *spb.Value) bool {
334*e7b1675dSTing-Kang Chang	val, exist := r.field(name)
335*e7b1675dSTing-Kang Chang	if !exist || exp == nil {
336*e7b1675dSTing-Kang Chang		return false
337*e7b1675dSTing-Kang Chang	}
338*e7b1675dSTing-Kang Chang	var isKind bool
339*e7b1675dSTing-Kang Chang	switch exp.GetKind().(type) {
340*e7b1675dSTing-Kang Chang	case *spb.Value_StructValue:
341*e7b1675dSTing-Kang Chang		_, isKind = val.GetKind().(*spb.Value_StructValue)
342*e7b1675dSTing-Kang Chang	case *spb.Value_NullValue:
343*e7b1675dSTing-Kang Chang		_, isKind = val.GetKind().(*spb.Value_NullValue)
344*e7b1675dSTing-Kang Chang	case *spb.Value_BoolValue:
345*e7b1675dSTing-Kang Chang		_, isKind = val.GetKind().(*spb.Value_BoolValue)
346*e7b1675dSTing-Kang Chang	case *spb.Value_ListValue:
347*e7b1675dSTing-Kang Chang		_, isKind = val.GetKind().(*spb.Value_ListValue)
348*e7b1675dSTing-Kang Chang	case *spb.Value_StringValue:
349*e7b1675dSTing-Kang Chang		_, isKind = val.GetKind().(*spb.Value_StringValue)
350*e7b1675dSTing-Kang Chang	case *spb.Value_NumberValue:
351*e7b1675dSTing-Kang Chang		_, isKind = val.GetKind().(*spb.Value_NumberValue)
352*e7b1675dSTing-Kang Chang	default:
353*e7b1675dSTing-Kang Chang		isKind = false
354*e7b1675dSTing-Kang Chang	}
355*e7b1675dSTing-Kang Chang	return isKind
356*e7b1675dSTing-Kang Chang}
357*e7b1675dSTing-Kang Chang
358*e7b1675dSTing-Kang Changfunc (r *RawJWT) customClaim(name string) (*spb.Value, error) {
359*e7b1675dSTing-Kang Chang	if isRegisteredClaim(name) {
360*e7b1675dSTing-Kang Chang		return nil, fmt.Errorf("'%q' is a registered claim", name)
361*e7b1675dSTing-Kang Chang	}
362*e7b1675dSTing-Kang Chang	val, ok := r.field(name)
363*e7b1675dSTing-Kang Chang	if !ok {
364*e7b1675dSTing-Kang Chang		return nil, fmt.Errorf("claim '%q' not found", name)
365*e7b1675dSTing-Kang Chang	}
366*e7b1675dSTing-Kang Chang	return val, nil
367*e7b1675dSTing-Kang Chang}
368*e7b1675dSTing-Kang Chang
369*e7b1675dSTing-Kang Changfunc (r *RawJWT) hasField(name string) bool {
370*e7b1675dSTing-Kang Chang	_, ok := r.field(name)
371*e7b1675dSTing-Kang Chang	return ok
372*e7b1675dSTing-Kang Chang}
373*e7b1675dSTing-Kang Chang
374*e7b1675dSTing-Kang Changfunc (r *RawJWT) field(name string) (*spb.Value, bool) {
375*e7b1675dSTing-Kang Chang	val, ok := r.jsonpb.GetFields()[name]
376*e7b1675dSTing-Kang Chang	return val, ok
377*e7b1675dSTing-Kang Chang}
378*e7b1675dSTing-Kang Chang
379*e7b1675dSTing-Kang Chang// createPayload creates a JSON payload from JWT options.
380*e7b1675dSTing-Kang Changfunc createPayload(opts *RawJWTOptions) (*spb.Struct, error) {
381*e7b1675dSTing-Kang Chang	if err := validateCustomClaims(opts.CustomClaims); err != nil {
382*e7b1675dSTing-Kang Chang		return nil, err
383*e7b1675dSTing-Kang Chang	}
384*e7b1675dSTing-Kang Chang	if opts.ExpiresAt == nil && !opts.WithoutExpiration {
385*e7b1675dSTing-Kang Chang		return nil, fmt.Errorf("jwt options must contain an expiration or must be marked WithoutExpiration")
386*e7b1675dSTing-Kang Chang	}
387*e7b1675dSTing-Kang Chang	if opts.ExpiresAt != nil && opts.WithoutExpiration {
388*e7b1675dSTing-Kang Chang		return nil, fmt.Errorf("jwt options can't be marked WithoutExpiration when expiration is specified")
389*e7b1675dSTing-Kang Chang	}
390*e7b1675dSTing-Kang Chang	if opts.Audience != nil && opts.Audiences != nil {
391*e7b1675dSTing-Kang Chang		return nil, fmt.Errorf("jwt options can either contain a single Audience or a list of Audiences but not both")
392*e7b1675dSTing-Kang Chang	}
393*e7b1675dSTing-Kang Chang
394*e7b1675dSTing-Kang Chang	payload := &spb.Struct{
395*e7b1675dSTing-Kang Chang		Fields: map[string]*spb.Value{},
396*e7b1675dSTing-Kang Chang	}
397*e7b1675dSTing-Kang Chang	setStringValue(payload, claimJWTID, opts.JWTID)
398*e7b1675dSTing-Kang Chang	setStringValue(payload, claimIssuer, opts.Issuer)
399*e7b1675dSTing-Kang Chang	setStringValue(payload, claimSubject, opts.Subject)
400*e7b1675dSTing-Kang Chang	setStringValue(payload, claimAudience, opts.Audience)
401*e7b1675dSTing-Kang Chang	setTimeValue(payload, claimIssuedAt, opts.IssuedAt)
402*e7b1675dSTing-Kang Chang	setTimeValue(payload, claimNotBefore, opts.NotBefore)
403*e7b1675dSTing-Kang Chang	setTimeValue(payload, claimExpiration, opts.ExpiresAt)
404*e7b1675dSTing-Kang Chang	setAudiences(payload, claimAudience, opts.Audiences)
405*e7b1675dSTing-Kang Chang
406*e7b1675dSTing-Kang Chang	for k, v := range opts.CustomClaims {
407*e7b1675dSTing-Kang Chang		val, err := spb.NewValue(v)
408*e7b1675dSTing-Kang Chang		if err != nil {
409*e7b1675dSTing-Kang Chang			return nil, err
410*e7b1675dSTing-Kang Chang		}
411*e7b1675dSTing-Kang Chang		setValue(payload, k, val)
412*e7b1675dSTing-Kang Chang	}
413*e7b1675dSTing-Kang Chang	return payload, nil
414*e7b1675dSTing-Kang Chang}
415*e7b1675dSTing-Kang Chang
416*e7b1675dSTing-Kang Changfunc validatePayload(payload *spb.Struct) error {
417*e7b1675dSTing-Kang Chang	if payload.Fields == nil || len(payload.Fields) == 0 {
418*e7b1675dSTing-Kang Chang		return nil
419*e7b1675dSTing-Kang Chang	}
420*e7b1675dSTing-Kang Chang	if err := validateAudienceClaim(payload.Fields[claimAudience]); err != nil {
421*e7b1675dSTing-Kang Chang		return err
422*e7b1675dSTing-Kang Chang	}
423*e7b1675dSTing-Kang Chang	for claim, val := range payload.GetFields() {
424*e7b1675dSTing-Kang Chang		if isRegisteredTimeClaim(claim) {
425*e7b1675dSTing-Kang Chang			if err := validateTimeClaim(claim, val); err != nil {
426*e7b1675dSTing-Kang Chang				return err
427*e7b1675dSTing-Kang Chang			}
428*e7b1675dSTing-Kang Chang		}
429*e7b1675dSTing-Kang Chang
430*e7b1675dSTing-Kang Chang		if isRegisteredStringClaim(claim) {
431*e7b1675dSTing-Kang Chang			if err := validateStringClaim(claim, val); err != nil {
432*e7b1675dSTing-Kang Chang				return err
433*e7b1675dSTing-Kang Chang			}
434*e7b1675dSTing-Kang Chang		}
435*e7b1675dSTing-Kang Chang	}
436*e7b1675dSTing-Kang Chang	return nil
437*e7b1675dSTing-Kang Chang}
438*e7b1675dSTing-Kang Chang
439*e7b1675dSTing-Kang Changfunc validateStringClaim(claim string, val *spb.Value) error {
440*e7b1675dSTing-Kang Chang	v, ok := val.Kind.(*spb.Value_StringValue)
441*e7b1675dSTing-Kang Chang	if !ok {
442*e7b1675dSTing-Kang Chang		return fmt.Errorf("claim: '%q' MUST be a string", claim)
443*e7b1675dSTing-Kang Chang	}
444*e7b1675dSTing-Kang Chang	if !utf8.ValidString(v.StringValue) {
445*e7b1675dSTing-Kang Chang		return fmt.Errorf("claim: '%q' isn't a valid UTF-8 string", claim)
446*e7b1675dSTing-Kang Chang	}
447*e7b1675dSTing-Kang Chang	return nil
448*e7b1675dSTing-Kang Chang}
449*e7b1675dSTing-Kang Chang
450*e7b1675dSTing-Kang Changfunc validateTimeClaim(claim string, val *spb.Value) error {
451*e7b1675dSTing-Kang Chang	if _, ok := val.Kind.(*spb.Value_NumberValue); !ok {
452*e7b1675dSTing-Kang Chang		return fmt.Errorf("claim %q MUST be a numeric value, ", claim)
453*e7b1675dSTing-Kang Chang	}
454*e7b1675dSTing-Kang Chang	t := int64(val.GetNumberValue())
455*e7b1675dSTing-Kang Chang	if t > jwtTimestampMax || t < jwtTimestampMin {
456*e7b1675dSTing-Kang Chang		return fmt.Errorf("invalid timestamp: '%d' for claim: %q", t, claim)
457*e7b1675dSTing-Kang Chang	}
458*e7b1675dSTing-Kang Chang	return nil
459*e7b1675dSTing-Kang Chang}
460*e7b1675dSTing-Kang Chang
461*e7b1675dSTing-Kang Changfunc validateAudienceClaim(val *spb.Value) error {
462*e7b1675dSTing-Kang Chang	if val == nil {
463*e7b1675dSTing-Kang Chang		return nil
464*e7b1675dSTing-Kang Chang	}
465*e7b1675dSTing-Kang Chang	_, isString := val.Kind.(*spb.Value_StringValue)
466*e7b1675dSTing-Kang Chang	l, isList := val.Kind.(*spb.Value_ListValue)
467*e7b1675dSTing-Kang Chang	if !isList && !isString {
468*e7b1675dSTing-Kang Chang		return fmt.Errorf("audience claim MUST be a list with at least one string or a single string value")
469*e7b1675dSTing-Kang Chang	}
470*e7b1675dSTing-Kang Chang	if isString {
471*e7b1675dSTing-Kang Chang		return validateStringClaim(claimAudience, val)
472*e7b1675dSTing-Kang Chang	}
473*e7b1675dSTing-Kang Chang	if l.ListValue != nil && len(l.ListValue.Values) == 0 {
474*e7b1675dSTing-Kang Chang		return fmt.Errorf("there MUST be at least one value present in the audience claim")
475*e7b1675dSTing-Kang Chang	}
476*e7b1675dSTing-Kang Chang	for _, aud := range l.ListValue.Values {
477*e7b1675dSTing-Kang Chang		v, ok := aud.Kind.(*spb.Value_StringValue)
478*e7b1675dSTing-Kang Chang		if !ok {
479*e7b1675dSTing-Kang Chang			return fmt.Errorf("audience value is not a string")
480*e7b1675dSTing-Kang Chang		}
481*e7b1675dSTing-Kang Chang		if !utf8.ValidString(v.StringValue) {
482*e7b1675dSTing-Kang Chang			return fmt.Errorf("audience value is not a valid UTF-8 string")
483*e7b1675dSTing-Kang Chang		}
484*e7b1675dSTing-Kang Chang	}
485*e7b1675dSTing-Kang Chang	return nil
486*e7b1675dSTing-Kang Chang}
487*e7b1675dSTing-Kang Chang
488*e7b1675dSTing-Kang Changfunc validateCustomClaims(cc map[string]interface{}) error {
489*e7b1675dSTing-Kang Chang	if cc == nil {
490*e7b1675dSTing-Kang Chang		return nil
491*e7b1675dSTing-Kang Chang	}
492*e7b1675dSTing-Kang Chang	for key := range cc {
493*e7b1675dSTing-Kang Chang		if isRegisteredClaim(key) {
494*e7b1675dSTing-Kang Chang			return fmt.Errorf("claim '%q' is a registered claim, it can't be declared as a custom claim", key)
495*e7b1675dSTing-Kang Chang		}
496*e7b1675dSTing-Kang Chang	}
497*e7b1675dSTing-Kang Chang	return nil
498*e7b1675dSTing-Kang Chang}
499*e7b1675dSTing-Kang Chang
500*e7b1675dSTing-Kang Changfunc setTimeValue(p *spb.Struct, claim string, val *time.Time) {
501*e7b1675dSTing-Kang Chang	if val == nil {
502*e7b1675dSTing-Kang Chang		return
503*e7b1675dSTing-Kang Chang	}
504*e7b1675dSTing-Kang Chang	setValue(p, claim, spb.NewNumberValue(float64(val.Unix())))
505*e7b1675dSTing-Kang Chang}
506*e7b1675dSTing-Kang Chang
507*e7b1675dSTing-Kang Changfunc setStringValue(p *spb.Struct, claim string, val *string) {
508*e7b1675dSTing-Kang Chang	if val == nil {
509*e7b1675dSTing-Kang Chang		return
510*e7b1675dSTing-Kang Chang	}
511*e7b1675dSTing-Kang Chang	setValue(p, claim, spb.NewStringValue(*val))
512*e7b1675dSTing-Kang Chang}
513*e7b1675dSTing-Kang Chang
514*e7b1675dSTing-Kang Changfunc setAudiences(p *spb.Struct, claim string, vals []string) {
515*e7b1675dSTing-Kang Chang	if vals == nil {
516*e7b1675dSTing-Kang Chang		return
517*e7b1675dSTing-Kang Chang	}
518*e7b1675dSTing-Kang Chang	audList := &spb.ListValue{
519*e7b1675dSTing-Kang Chang		Values: make([]*spb.Value, 0, len(vals)),
520*e7b1675dSTing-Kang Chang	}
521*e7b1675dSTing-Kang Chang	for _, aud := range vals {
522*e7b1675dSTing-Kang Chang		audList.Values = append(audList.Values, spb.NewStringValue(aud))
523*e7b1675dSTing-Kang Chang	}
524*e7b1675dSTing-Kang Chang	setValue(p, claim, spb.NewListValue(audList))
525*e7b1675dSTing-Kang Chang}
526*e7b1675dSTing-Kang Chang
527*e7b1675dSTing-Kang Changfunc setValue(p *spb.Struct, claim string, val *spb.Value) {
528*e7b1675dSTing-Kang Chang	if p.GetFields() == nil {
529*e7b1675dSTing-Kang Chang		p.Fields = make(map[string]*spb.Value)
530*e7b1675dSTing-Kang Chang	}
531*e7b1675dSTing-Kang Chang	p.GetFields()[claim] = val
532*e7b1675dSTing-Kang Chang}
533*e7b1675dSTing-Kang Chang
534*e7b1675dSTing-Kang Changfunc isRegisteredClaim(c string) bool {
535*e7b1675dSTing-Kang Chang	return isRegisteredStringClaim(c) || isRegisteredTimeClaim(c) || c == claimAudience
536*e7b1675dSTing-Kang Chang}
537*e7b1675dSTing-Kang Chang
538*e7b1675dSTing-Kang Changfunc isRegisteredStringClaim(c string) bool {
539*e7b1675dSTing-Kang Chang	return c == claimIssuer || c == claimSubject || c == claimJWTID
540*e7b1675dSTing-Kang Chang}
541*e7b1675dSTing-Kang Chang
542*e7b1675dSTing-Kang Changfunc isRegisteredTimeClaim(c string) bool {
543*e7b1675dSTing-Kang Chang	return c == claimExpiration || c == claimNotBefore || c == claimIssuedAt
544*e7b1675dSTing-Kang Chang}
545