// Copyright 2021 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
////////////////////////////////////////////////////////////////////////////////

package jwt

import (
	"fmt"
	"time"
	"unicode/utf8"

	spb "google.golang.org/protobuf/types/known/structpb"
)

const (
	// Names of the registered claims (RFC 7519, section 4.1) as they appear
	// as keys in the JSON payload.
	claimIssuer     = "iss"
	claimSubject    = "sub"
	claimAudience   = "aud"
	claimExpiration = "exp"
	claimNotBefore  = "nbf"
	claimIssuedAt   = "iat"
	claimJWTID      = "jti"

	// Inclusive bounds, in seconds since the Unix epoch, that validateTimeClaim
	// accepts for the 'exp', 'nbf' and 'iat' claims. The upper bound
	// corresponds to 9999-12-31T23:59:59Z.
	jwtTimestampMax = 253402300799
	jwtTimestampMin = 0
)

// RawJWTOptions represent an unsigned JSON Web Token (JWT), https://tools.ietf.org/html/rfc7519.
//
// It contains all payload claims and a subset of the headers. It does not
// contain any headers that depend on the key, such as "alg" or "kid", because
// these headers are chosen when the token is signed and encoded, and should not
// be chosen by the user. This ensures that the key can be changed without any
// changes to the user code.
//
// Pointer fields that are nil are omitted from the resulting payload.
type RawJWTOptions struct {
	// Audiences is written to the 'aud' claim as a JSON array. It is mutually
	// exclusive with Audience.
	Audiences []string
	// Audience is written to the 'aud' claim as a single JSON string. It is
	// mutually exclusive with Audiences.
	Audience *string
	// Subject is written to the 'sub' claim.
	Subject *string
	// Issuer is written to the 'iss' claim.
	Issuer *string
	// JWTID is written to the 'jti' claim.
	JWTID *string
	// IssuedAt is written to the 'iat' claim as Unix seconds.
	IssuedAt *time.Time
	// ExpiresAt is written to the 'exp' claim as Unix seconds. It must be set
	// unless WithoutExpiration is true.
	ExpiresAt *time.Time
	// NotBefore is written to the 'nbf' claim as Unix seconds.
	NotBefore *time.Time
	// CustomClaims holds additional claims. Keys must not be registered claim
	// names; values are converted with structpb.NewValue.
	CustomClaims map[string]interface{}

	// TypeHeader is stored on the RawJWT as its type header.
	TypeHeader *string
	// WithoutExpiration must be true when ExpiresAt is nil and must be false
	// when ExpiresAt is set.
	WithoutExpiration bool
}

// RawJWT is an unsigned JSON Web Token (JWT), https://tools.ietf.org/html/rfc7519.
type RawJWT struct {
	// jsonpb holds the payload claims as a protobuf Struct.
	jsonpb *spb.Struct
	// typeHeader is the optional type header; nil means it is absent.
	typeHeader *string
}

// NewRawJWT constructs a new RawJWT token based on the RawJWTOptions provided.
// It returns an error if opts is nil or if the payload cannot be created or
// fails validation.
func NewRawJWT(opts *RawJWTOptions) (*RawJWT, error) { if opts == nil { return nil, fmt.Errorf("jwt options can't be nil") } payload, err := createPayload(opts) if err != nil { return nil, err } if err := validatePayload(payload); err != nil { return nil, err } return &RawJWT{ jsonpb: payload, typeHeader: opts.TypeHeader, }, nil } // NewRawJWTFromJSON builds a RawJWT from a marshaled JSON. // Users shouldn't call this function and instead use NewRawJWT. func NewRawJWTFromJSON(typeHeader *string, jsonPayload []byte) (*RawJWT, error) { payload := &spb.Struct{} if err := payload.UnmarshalJSON(jsonPayload); err != nil { return nil, err } if err := validatePayload(payload); err != nil { return nil, err } return &RawJWT{ jsonpb: payload, typeHeader: typeHeader, }, nil } // JSONPayload marshals a RawJWT payload to JSON. func (r *RawJWT) JSONPayload() ([]byte, error) { return r.jsonpb.MarshalJSON() } // HasTypeHeader returns whether a RawJWT contains a type header. func (r *RawJWT) HasTypeHeader() bool { return r.typeHeader != nil } // TypeHeader returns the JWT type header. func (r *RawJWT) TypeHeader() (string, error) { if !r.HasTypeHeader() { return "", fmt.Errorf("no type header present") } return *r.typeHeader, nil } // HasAudiences checks whether a JWT contains the audience claim ('aud'). func (r *RawJWT) HasAudiences() bool { return r.hasField(claimAudience) } // Audiences returns a list of audiences from the 'aud' claim. If the 'aud' claim is a single string, it is converted into a list with a single entry. 
func (r *RawJWT) Audiences() ([]string, error) { aud, ok := r.field(claimAudience) if !ok { return nil, fmt.Errorf("no audience claim found") } if err := validateAudienceClaim(aud); err != nil { return nil, err } if val, isString := aud.GetKind().(*spb.Value_StringValue); isString { return []string{val.StringValue}, nil } s := make([]string, 0, len(aud.GetListValue().GetValues())) for _, a := range aud.GetListValue().GetValues() { s = append(s, a.GetStringValue()) } return s, nil } // HasSubject checks whether a JWT contains an issuer claim ('sub'). func (r *RawJWT) HasSubject() bool { return r.hasField(claimSubject) } // Subject returns the subject claim ('sub') or an error if no claim is present. func (r *RawJWT) Subject() (string, error) { return r.stringClaim(claimSubject) } // HasIssuer checks whether a JWT contains an issuer claim ('iss'). func (r *RawJWT) HasIssuer() bool { return r.hasField(claimIssuer) } // Issuer returns the issuer claim ('iss') or an error if no claim is present. func (r *RawJWT) Issuer() (string, error) { return r.stringClaim(claimIssuer) } // HasJWTID checks whether a JWT contains an JWT ID claim ('jti'). func (r *RawJWT) HasJWTID() bool { return r.hasField(claimJWTID) } // JWTID returns the JWT ID claim ('jti') or an error if no claim is present. func (r *RawJWT) JWTID() (string, error) { return r.stringClaim(claimJWTID) } // HasIssuedAt checks whether a JWT contains an issued at claim ('iat'). func (r *RawJWT) HasIssuedAt() bool { return r.hasField(claimIssuedAt) } // IssuedAt returns the issued at claim ('iat') or an error if no claim is present. func (r *RawJWT) IssuedAt() (time.Time, error) { return r.timeClaim(claimIssuedAt) } // HasExpiration checks whether a JWT contains an expiration time claim ('exp'). func (r *RawJWT) HasExpiration() bool { return r.hasField(claimExpiration) } // ExpiresAt returns the expiration claim ('exp') or an error if no claim is present. 
func (r *RawJWT) ExpiresAt() (time.Time, error) {
	return r.timeClaim(claimExpiration)
}

// HasNotBefore checks whether a JWT contains a not before claim ('nbf').
func (r *RawJWT) HasNotBefore() bool {
	return r.hasField(claimNotBefore)
}

// NotBefore returns the not before claim ('nbf') or an error if the claim is
// absent or is not a number.
func (r *RawJWT) NotBefore() (time.Time, error) {
	return r.timeClaim(claimNotBefore)
}

// HasStringClaim checks whether a custom claim of type string is present.
// Registered claim names always yield false.
func (r *RawJWT) HasStringClaim(name string) bool {
	return !isRegisteredClaim(name) && r.hasClaimOfKind(name, &spb.Value{Kind: &spb.Value_StringValue{}})
}

// StringClaim returns a custom string claim. It returns an error if name is a
// registered claim, if the claim is absent, or if its value is not a valid
// UTF-8 string.
func (r *RawJWT) StringClaim(name string) (string, error) {
	if isRegisteredClaim(name) {
		return "", fmt.Errorf("claim '%q' is a registered claim", name)
	}
	return r.stringClaim(name)
}

// HasNumberClaim checks whether a custom claim of type number is present.
// Registered claim names always yield false.
func (r *RawJWT) HasNumberClaim(name string) bool {
	return !isRegisteredClaim(name) && r.hasClaimOfKind(name, &spb.Value{Kind: &spb.Value_NumberValue{}})
}

// NumberClaim returns a custom number claim. It returns an error if name is a
// registered claim, if the claim is absent, or if its value is not a number.
func (r *RawJWT) NumberClaim(name string) (float64, error) {
	if isRegisteredClaim(name) {
		return 0, fmt.Errorf("claim '%q' is a registered claim", name)
	}
	return r.numberClaim(name)
}

// HasBooleanClaim checks whether a custom claim of type boolean is present.
// Registered claim names always yield false, which keeps this method
// consistent with HasStringClaim and HasNumberClaim and with BooleanClaim,
// which refuses registered names.
func (r *RawJWT) HasBooleanClaim(name string) bool {
	return !isRegisteredClaim(name) && r.hasClaimOfKind(name, &spb.Value{Kind: &spb.Value_BoolValue{}})
}

// BooleanClaim returns a custom bool claim. It returns an error if name is a
// registered claim, if the claim is absent, or if its value is not a boolean.
func (r *RawJWT) BooleanClaim(name string) (bool, error) {
	val, err := r.customClaim(name)
	if err != nil {
		return false, err
	}
	// GetKind is nil-safe, so a nil entry in the payload is reported as
	// "not a boolean" instead of causing a nil dereference.
	b, ok := val.GetKind().(*spb.Value_BoolValue)
	if !ok {
		return false, fmt.Errorf("claim '%q' is not a boolean", name)
	}
	return b.BoolValue, nil
}

// HasNullClaim checks whether a claim of type null is present.
// Registered claim names always yield false, consistent with HasArrayClaim.
func (r *RawJWT) HasNullClaim(name string) bool {
	return !isRegisteredClaim(name) && r.hasClaimOfKind(name, &spb.Value{Kind: &spb.Value_NullValue{}})
}

// HasArrayClaim checks whether a custom claim of type list is present.
// Registered claim names (including 'aud', which may itself be a list) always
// yield false.
func (r *RawJWT) HasArrayClaim(name string) bool {
	return !isRegisteredClaim(name) && r.hasClaimOfKind(name, &spb.Value{Kind: &spb.Value_ListValue{}})
}

// ArrayClaim returns a slice representing a JSON array for a custom claim. It
// returns an error if name is a registered claim, if the claim is absent, or
// if its value is not a list.
func (r *RawJWT) ArrayClaim(name string) ([]interface{}, error) {
	val, err := r.customClaim(name)
	if err != nil {
		return nil, err
	}
	list := val.GetListValue()
	if list == nil {
		return nil, fmt.Errorf("claim '%q' is not a list", name)
	}
	return list.AsSlice(), nil
}

// HasObjectClaim checks whether a custom claim of type JSON object is present.
// Registered claim names always yield false, consistent with HasArrayClaim and
// with ObjectClaim, which refuses registered names.
func (r *RawJWT) HasObjectClaim(name string) bool {
	return !isRegisteredClaim(name) && r.hasClaimOfKind(name, &spb.Value{Kind: &spb.Value_StructValue{}})
}

// ObjectClaim returns a map representing a JSON object for a custom claim. It
// returns an error if name is a registered claim, if the claim is absent, or
// if its value is not a JSON object.
func (r *RawJWT) ObjectClaim(name string) (map[string]interface{}, error) {
	val, err := r.customClaim(name)
	if err != nil {
		return nil, err
	}
	obj := val.GetStructValue()
	if obj == nil {
		return nil, fmt.Errorf("claim '%q' is not a JSON object", name)
	}
	// Return a literal nil on success rather than re-using the err variable.
	return obj.AsMap(), nil
}

// CustomClaimNames returns a list with the name of custom claims in a RawJWT.
func (r *RawJWT) CustomClaimNames() []string { names := []string{} for key := range r.jsonpb.GetFields() { if !isRegisteredClaim(key) { names = append(names, key) } } return names } func (r *RawJWT) timeClaim(name string) (time.Time, error) { n, err := r.numberClaim(name) if err != nil { return time.Time{}, err } return time.Unix(int64(n), 0), err } func (r *RawJWT) numberClaim(name string) (float64, error) { val, ok := r.field(name) if !ok { return 0, fmt.Errorf("no '%q' claim found", name) } s, ok := val.Kind.(*spb.Value_NumberValue) if !ok { return 0, fmt.Errorf("claim '%q' is not a number", name) } return s.NumberValue, nil } func (r *RawJWT) stringClaim(name string) (string, error) { val, ok := r.field(name) if !ok { return "", fmt.Errorf("no '%q' claim found", name) } s, ok := val.Kind.(*spb.Value_StringValue) if !ok { return "", fmt.Errorf("claim '%q' is not a string", name) } if !utf8.ValidString(s.StringValue) { return "", fmt.Errorf("claim '%q' is not a valid utf-8 encoded string", name) } return s.StringValue, nil } func (r *RawJWT) hasClaimOfKind(name string, exp *spb.Value) bool { val, exist := r.field(name) if !exist || exp == nil { return false } var isKind bool switch exp.GetKind().(type) { case *spb.Value_StructValue: _, isKind = val.GetKind().(*spb.Value_StructValue) case *spb.Value_NullValue: _, isKind = val.GetKind().(*spb.Value_NullValue) case *spb.Value_BoolValue: _, isKind = val.GetKind().(*spb.Value_BoolValue) case *spb.Value_ListValue: _, isKind = val.GetKind().(*spb.Value_ListValue) case *spb.Value_StringValue: _, isKind = val.GetKind().(*spb.Value_StringValue) case *spb.Value_NumberValue: _, isKind = val.GetKind().(*spb.Value_NumberValue) default: isKind = false } return isKind } func (r *RawJWT) customClaim(name string) (*spb.Value, error) { if isRegisteredClaim(name) { return nil, fmt.Errorf("'%q' is a registered claim", name) } val, ok := r.field(name) if !ok { return nil, fmt.Errorf("claim '%q' not found", name) } return val, nil } 
func (r *RawJWT) hasField(name string) bool { _, ok := r.field(name) return ok } func (r *RawJWT) field(name string) (*spb.Value, bool) { val, ok := r.jsonpb.GetFields()[name] return val, ok } // createPayload creates a JSON payload from JWT options. func createPayload(opts *RawJWTOptions) (*spb.Struct, error) { if err := validateCustomClaims(opts.CustomClaims); err != nil { return nil, err } if opts.ExpiresAt == nil && !opts.WithoutExpiration { return nil, fmt.Errorf("jwt options must contain an expiration or must be marked WithoutExpiration") } if opts.ExpiresAt != nil && opts.WithoutExpiration { return nil, fmt.Errorf("jwt options can't be marked WithoutExpiration when expiration is specified") } if opts.Audience != nil && opts.Audiences != nil { return nil, fmt.Errorf("jwt options can either contain a single Audience or a list of Audiences but not both") } payload := &spb.Struct{ Fields: map[string]*spb.Value{}, } setStringValue(payload, claimJWTID, opts.JWTID) setStringValue(payload, claimIssuer, opts.Issuer) setStringValue(payload, claimSubject, opts.Subject) setStringValue(payload, claimAudience, opts.Audience) setTimeValue(payload, claimIssuedAt, opts.IssuedAt) setTimeValue(payload, claimNotBefore, opts.NotBefore) setTimeValue(payload, claimExpiration, opts.ExpiresAt) setAudiences(payload, claimAudience, opts.Audiences) for k, v := range opts.CustomClaims { val, err := spb.NewValue(v) if err != nil { return nil, err } setValue(payload, k, val) } return payload, nil } func validatePayload(payload *spb.Struct) error { if payload.Fields == nil || len(payload.Fields) == 0 { return nil } if err := validateAudienceClaim(payload.Fields[claimAudience]); err != nil { return err } for claim, val := range payload.GetFields() { if isRegisteredTimeClaim(claim) { if err := validateTimeClaim(claim, val); err != nil { return err } } if isRegisteredStringClaim(claim) { if err := validateStringClaim(claim, val); err != nil { return err } } } return nil } func 
validateStringClaim(claim string, val *spb.Value) error { v, ok := val.Kind.(*spb.Value_StringValue) if !ok { return fmt.Errorf("claim: '%q' MUST be a string", claim) } if !utf8.ValidString(v.StringValue) { return fmt.Errorf("claim: '%q' isn't a valid UTF-8 string", claim) } return nil } func validateTimeClaim(claim string, val *spb.Value) error { if _, ok := val.Kind.(*spb.Value_NumberValue); !ok { return fmt.Errorf("claim %q MUST be a numeric value, ", claim) } t := int64(val.GetNumberValue()) if t > jwtTimestampMax || t < jwtTimestampMin { return fmt.Errorf("invalid timestamp: '%d' for claim: %q", t, claim) } return nil } func validateAudienceClaim(val *spb.Value) error { if val == nil { return nil } _, isString := val.Kind.(*spb.Value_StringValue) l, isList := val.Kind.(*spb.Value_ListValue) if !isList && !isString { return fmt.Errorf("audience claim MUST be a list with at least one string or a single string value") } if isString { return validateStringClaim(claimAudience, val) } if l.ListValue != nil && len(l.ListValue.Values) == 0 { return fmt.Errorf("there MUST be at least one value present in the audience claim") } for _, aud := range l.ListValue.Values { v, ok := aud.Kind.(*spb.Value_StringValue) if !ok { return fmt.Errorf("audience value is not a string") } if !utf8.ValidString(v.StringValue) { return fmt.Errorf("audience value is not a valid UTF-8 string") } } return nil } func validateCustomClaims(cc map[string]interface{}) error { if cc == nil { return nil } for key := range cc { if isRegisteredClaim(key) { return fmt.Errorf("claim '%q' is a registered claim, it can't be declared as a custom claim", key) } } return nil } func setTimeValue(p *spb.Struct, claim string, val *time.Time) { if val == nil { return } setValue(p, claim, spb.NewNumberValue(float64(val.Unix()))) } func setStringValue(p *spb.Struct, claim string, val *string) { if val == nil { return } setValue(p, claim, spb.NewStringValue(*val)) } func setAudiences(p *spb.Struct, claim string, 
vals []string) { if vals == nil { return } audList := &spb.ListValue{ Values: make([]*spb.Value, 0, len(vals)), } for _, aud := range vals { audList.Values = append(audList.Values, spb.NewStringValue(aud)) } setValue(p, claim, spb.NewListValue(audList)) } func setValue(p *spb.Struct, claim string, val *spb.Value) { if p.GetFields() == nil { p.Fields = make(map[string]*spb.Value) } p.GetFields()[claim] = val } func isRegisteredClaim(c string) bool { return isRegisteredStringClaim(c) || isRegisteredTimeClaim(c) || c == claimAudience } func isRegisteredStringClaim(c string) bool { return c == claimIssuer || c == claimSubject || c == claimJWTID } func isRegisteredTimeClaim(c string) bool { return c == claimExpiration || c == claimNotBefore || c == claimIssuedAt }