mirror of
https://gitlab.com/mbugroup/lti-api.git
synced 2026-05-20 21:41:55 +00:00
[FEAT/BE] resolve jwks
This commit is contained in:
@@ -2,9 +2,11 @@ package sso
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"os"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
@@ -41,6 +43,16 @@ type VerificationResult struct {
|
||||
Claims *AccessTokenClaims
|
||||
}
|
||||
|
||||
type TokenInfo struct {
|
||||
Kid string
|
||||
Iss string
|
||||
Aud []string
|
||||
Sub string
|
||||
Exp int64
|
||||
Iat int64
|
||||
Nbf int64
|
||||
}
|
||||
|
||||
var (
|
||||
globalMu sync.RWMutex
|
||||
globalV *verifier
|
||||
@@ -106,10 +118,19 @@ func VerifyAccessToken(token string) (*VerificationResult, error) {
|
||||
jwt.WithIssuedAt(),
|
||||
jwt.WithExpirationRequired(),
|
||||
)
|
||||
|
||||
|
||||
tok, err := parser.ParseWithClaims(token, claims, v.jwks.Keyfunc)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("parse token: %w", err)
|
||||
if shouldRefreshOnVerifyError(err) {
|
||||
if refreshErr := v.jwks.Refresh(context.Background(), keyfunc.RefreshOptions{IgnoreRateLimit: true}); refreshErr != nil {
|
||||
utils.Log.WithError(refreshErr).Warn("sso jwks refresh after signature error failed")
|
||||
} else {
|
||||
tok, err = parser.ParseWithClaims(token, claims, v.jwks.Keyfunc)
|
||||
}
|
||||
}
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("parse token: %w", err)
|
||||
}
|
||||
}
|
||||
if !tok.Valid {
|
||||
return nil, errors.New("invalid token")
|
||||
@@ -158,3 +179,106 @@ func VerifyAccessToken(token string) (*VerificationResult, error) {
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func shouldRefreshOnVerifyError(err error) bool {
|
||||
if !IsSignatureError(err) {
|
||||
return false
|
||||
}
|
||||
return !disableRefreshOnSignatureError()
|
||||
}
|
||||
|
||||
func IsSignatureError(err error) bool {
|
||||
if err == nil {
|
||||
return false
|
||||
}
|
||||
msg := err.Error()
|
||||
return strings.Contains(msg, "verification error") || strings.Contains(msg, "token signature is invalid")
|
||||
}
|
||||
|
||||
func disableRefreshOnSignatureError() bool {
|
||||
val := strings.TrimSpace(os.Getenv("SSO_DISABLE_JWKS_REFRESH_ON_SIG_ERROR"))
|
||||
if val == "" {
|
||||
return false
|
||||
}
|
||||
return val == "1" || strings.EqualFold(val, "true") || strings.EqualFold(val, "yes")
|
||||
}
|
||||
|
||||
func ExtractTokenInfo(token string) TokenInfo {
|
||||
token = strings.TrimSpace(token)
|
||||
if token == "" {
|
||||
return TokenInfo{}
|
||||
}
|
||||
|
||||
claims := jwt.MapClaims{}
|
||||
parser := jwt.NewParser()
|
||||
tok, _, err := parser.ParseUnverified(token, claims)
|
||||
if err != nil {
|
||||
return TokenInfo{}
|
||||
}
|
||||
|
||||
info := TokenInfo{}
|
||||
if kid, ok := tok.Header["kid"].(string); ok {
|
||||
info.Kid = kid
|
||||
}
|
||||
if iss, ok := claims["iss"].(string); ok {
|
||||
info.Iss = iss
|
||||
}
|
||||
if sub, ok := claims["sub"].(string); ok {
|
||||
info.Sub = sub
|
||||
}
|
||||
if aud, ok := claims["aud"]; ok {
|
||||
info.Aud = toStringSlice(aud)
|
||||
}
|
||||
info.Exp = toInt64(claims["exp"])
|
||||
info.Iat = toInt64(claims["iat"])
|
||||
info.Nbf = toInt64(claims["nbf"])
|
||||
return info
|
||||
}
|
||||
|
||||
func toStringSlice(v any) []string {
|
||||
switch t := v.(type) {
|
||||
case string:
|
||||
if t == "" {
|
||||
return nil
|
||||
}
|
||||
return []string{t}
|
||||
case []string:
|
||||
out := make([]string, 0, len(t))
|
||||
for _, s := range t {
|
||||
if s != "" {
|
||||
out = append(out, s)
|
||||
}
|
||||
}
|
||||
return out
|
||||
case []any:
|
||||
out := make([]string, 0, len(t))
|
||||
for _, item := range t {
|
||||
if s, ok := item.(string); ok && s != "" {
|
||||
out = append(out, s)
|
||||
}
|
||||
}
|
||||
return out
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func toInt64(v any) int64 {
|
||||
switch t := v.(type) {
|
||||
case int64:
|
||||
return t
|
||||
case int:
|
||||
return int64(t)
|
||||
case float64:
|
||||
return int64(t)
|
||||
case json.Number:
|
||||
if n, err := t.Int64(); err == nil {
|
||||
return n
|
||||
}
|
||||
case string:
|
||||
if n, err := strconv.ParseInt(t, 10, 64); err == nil {
|
||||
return n
|
||||
}
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user