mirror of
https://gitlab.com/mbugroup/lti-api.git
synced 2026-05-24 07:15:43 +00:00
.
This commit is contained in:
+239
@@ -0,0 +1,239 @@
|
||||
package keyfunc
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
var (
|
||||
// ErrJWKAlgMismatch indicates that the given JWK was found, but its "alg" parameter's value did not match that of
|
||||
// the JWT.
|
||||
ErrJWKAlgMismatch = errors.New(`the given JWK was found, but its "alg" parameter's value did not match the expected algorithm`)
|
||||
|
||||
// ErrJWKUseWhitelist indicates that the given JWK was found, but its "use" parameter's value was not whitelisted.
|
||||
ErrJWKUseWhitelist = errors.New(`the given JWK was found, but its "use" parameter's value was not whitelisted`)
|
||||
|
||||
// ErrKIDNotFound indicates that the given key ID was not found in the JWKS.
|
||||
ErrKIDNotFound = errors.New("the given key ID was not found in the JWKS")
|
||||
|
||||
// ErrMissingAssets indicates there are required assets are missing to create a public key.
|
||||
ErrMissingAssets = errors.New("required assets are missing to create a public key")
|
||||
)
|
||||
|
||||
// ErrorHandler is a function signature that consumes an error.
|
||||
type ErrorHandler func(err error)
|
||||
|
||||
const (
|
||||
// UseEncryption is a JWK "use" parameter value indicating the JSON Web Key is to be used for encryption.
|
||||
UseEncryption JWKUse = "enc"
|
||||
// UseOmitted is a JWK "use" parameter value that was not specified or was empty.
|
||||
UseOmitted JWKUse = ""
|
||||
// UseSignature is a JWK "use" parameter value indicating the JSON Web Key is to be used for signatures.
|
||||
UseSignature JWKUse = "sig"
|
||||
)
|
||||
|
||||
// JWKUse is a set of values for the "use" parameter of a JWK.
|
||||
// See https://tools.ietf.org/html/rfc7517#section-4.2.
|
||||
type JWKUse string
|
||||
|
||||
// jsonWebKey represents a JSON Web Key inside a JWKS.
|
||||
type jsonWebKey struct {
|
||||
Algorithm string `json:"alg"`
|
||||
Curve string `json:"crv"`
|
||||
Exponent string `json:"e"`
|
||||
K string `json:"k"`
|
||||
ID string `json:"kid"`
|
||||
Modulus string `json:"n"`
|
||||
Type string `json:"kty"`
|
||||
Use string `json:"use"`
|
||||
X string `json:"x"`
|
||||
Y string `json:"y"`
|
||||
}
|
||||
|
||||
// parsedJWK represents a JSON Web Key parsed with fields as the correct Go types.
|
||||
type parsedJWK struct {
|
||||
algorithm string
|
||||
public interface{}
|
||||
use JWKUse
|
||||
}
|
||||
|
||||
// JWKS represents a JSON Web Key Set (JWK Set).
|
||||
type JWKS struct {
|
||||
jwkUseWhitelist map[JWKUse]struct{}
|
||||
cancel context.CancelFunc
|
||||
client *http.Client
|
||||
ctx context.Context
|
||||
raw []byte
|
||||
givenKeys map[string]GivenKey
|
||||
givenKIDOverride bool
|
||||
jwksURL string
|
||||
keys map[string]parsedJWK
|
||||
mux sync.RWMutex
|
||||
refreshErrorHandler ErrorHandler
|
||||
refreshInterval time.Duration
|
||||
refreshRateLimit time.Duration
|
||||
refreshRequests chan refreshRequest
|
||||
refreshTimeout time.Duration
|
||||
refreshUnknownKID bool
|
||||
requestFactory func(ctx context.Context, url string) (*http.Request, error)
|
||||
responseExtractor func(ctx context.Context, resp *http.Response) (json.RawMessage, error)
|
||||
}
|
||||
|
||||
// rawJWKS represents a JWKS in JSON format.
|
||||
type rawJWKS struct {
|
||||
Keys []*jsonWebKey `json:"keys"`
|
||||
}
|
||||
|
||||
// NewJSON creates a new JWKS from a raw JSON message.
|
||||
func NewJSON(jwksBytes json.RawMessage) (jwks *JWKS, err error) {
|
||||
var rawKS rawJWKS
|
||||
err = json.Unmarshal(jwksBytes, &rawKS)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Iterate through the keys in the raw JWKS. Add them to the JWKS.
|
||||
jwks = &JWKS{
|
||||
keys: make(map[string]parsedJWK, len(rawKS.Keys)),
|
||||
}
|
||||
for _, key := range rawKS.Keys {
|
||||
var keyInter interface{}
|
||||
switch keyType := key.Type; keyType {
|
||||
case ktyEC:
|
||||
keyInter, err = key.ECDSA()
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
case ktyOKP:
|
||||
keyInter, err = key.EdDSA()
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
case ktyOct:
|
||||
keyInter, err = key.Oct()
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
case ktyRSA:
|
||||
keyInter, err = key.RSA()
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
default:
|
||||
// Ignore unknown key types silently.
|
||||
continue
|
||||
}
|
||||
|
||||
jwks.keys[key.ID] = parsedJWK{
|
||||
algorithm: key.Algorithm,
|
||||
use: JWKUse(key.Use),
|
||||
public: keyInter,
|
||||
}
|
||||
}
|
||||
|
||||
return jwks, nil
|
||||
}
|
||||
|
||||
// EndBackground ends the background goroutine to update the JWKS. It can only happen once and is only effective if the
|
||||
// JWKS has a background goroutine refreshing the JWKS keys.
|
||||
func (j *JWKS) EndBackground() {
|
||||
if j.cancel != nil {
|
||||
j.cancel()
|
||||
}
|
||||
}
|
||||
|
||||
// KIDs returns the key IDs (`kid`) for all keys in the JWKS.
|
||||
func (j *JWKS) KIDs() (kids []string) {
|
||||
j.mux.RLock()
|
||||
defer j.mux.RUnlock()
|
||||
kids = make([]string, len(j.keys))
|
||||
index := 0
|
||||
for kid := range j.keys {
|
||||
kids[index] = kid
|
||||
index++
|
||||
}
|
||||
return kids
|
||||
}
|
||||
|
||||
// Len returns the number of keys in the JWKS.
|
||||
func (j *JWKS) Len() int {
|
||||
j.mux.RLock()
|
||||
defer j.mux.RUnlock()
|
||||
return len(j.keys)
|
||||
}
|
||||
|
||||
// RawJWKS returns a copy of the raw JWKS received from the given JWKS URL.
|
||||
func (j *JWKS) RawJWKS() []byte {
|
||||
j.mux.RLock()
|
||||
defer j.mux.RUnlock()
|
||||
raw := make([]byte, len(j.raw))
|
||||
copy(raw, j.raw)
|
||||
return raw
|
||||
}
|
||||
|
||||
// ReadOnlyKeys returns a read-only copy of the mapping of key IDs (`kid`) to cryptographic keys.
|
||||
func (j *JWKS) ReadOnlyKeys() map[string]interface{} {
|
||||
keys := make(map[string]interface{})
|
||||
j.mux.Lock()
|
||||
for kid, cryptoKey := range j.keys {
|
||||
keys[kid] = cryptoKey.public
|
||||
}
|
||||
j.mux.Unlock()
|
||||
return keys
|
||||
}
|
||||
|
||||
// getKey gets the jsonWebKey from the given KID from the JWKS. It may refresh the JWKS if configured to.
|
||||
func (j *JWKS) getKey(alg, kid string) (jsonKey interface{}, err error) {
|
||||
j.mux.RLock()
|
||||
pubKey, ok := j.keys[kid]
|
||||
j.mux.RUnlock()
|
||||
|
||||
if !ok {
|
||||
if !j.refreshUnknownKID {
|
||||
return nil, ErrKIDNotFound
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(j.ctx)
|
||||
req := refreshRequest{
|
||||
cancel: cancel,
|
||||
}
|
||||
|
||||
// Refresh the JWKS.
|
||||
select {
|
||||
case <-j.ctx.Done():
|
||||
return
|
||||
case j.refreshRequests <- req:
|
||||
default:
|
||||
// If the j.refreshRequests channel is full, return the error early.
|
||||
return nil, ErrKIDNotFound
|
||||
}
|
||||
|
||||
// Wait for the JWKS refresh to finish.
|
||||
<-ctx.Done()
|
||||
|
||||
j.mux.RLock()
|
||||
defer j.mux.RUnlock()
|
||||
if pubKey, ok = j.keys[kid]; !ok {
|
||||
return nil, ErrKIDNotFound
|
||||
}
|
||||
}
|
||||
|
||||
// jwkUseWhitelist might be empty if the jwks was from keyfunc.NewJSON() or if JWKUseNoWhitelist option was true.
|
||||
if len(j.jwkUseWhitelist) > 0 {
|
||||
_, ok = j.jwkUseWhitelist[pubKey.use]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf(`%w: JWK "use" parameter value %q is not whitelisted`, ErrJWKUseWhitelist, pubKey.use)
|
||||
}
|
||||
}
|
||||
|
||||
if pubKey.algorithm != "" && pubKey.algorithm != alg {
|
||||
return nil, fmt.Errorf(`%w: JWK "alg" parameter value %q does not match token "alg" parameter value %q`, ErrJWKAlgMismatch, pubKey.algorithm, alg)
|
||||
}
|
||||
|
||||
return pubKey.public, nil
|
||||
}
|
||||
Reference in New Issue
Block a user