mirror of
https://gitlab.com/mbugroup/lti-api.git
synced 2026-05-20 05:21:57 +00:00
285 lines
6.0 KiB
Go
285 lines
6.0 KiB
Go
package sso
|
|
|
|
import (
|
|
"context"
|
|
"encoding/json"
|
|
"errors"
|
|
"fmt"
|
|
"net/http"
|
|
"os"
|
|
"strconv"
|
|
"strings"
|
|
"sync"
|
|
"time"
|
|
|
|
"github.com/MicahParks/keyfunc/v2"
|
|
"github.com/golang-jwt/jwt/v5"
|
|
|
|
"gitlab.com/mbugroup/lti-api.git/internal/utils"
|
|
)
|
|
|
|
type verifier struct {
|
|
jwks *keyfunc.JWKS
|
|
issuer string
|
|
audiences map[string]struct{}
|
|
}
|
|
|
|
type AccessTokenClaims struct {
|
|
Scope string `json:"scope"`
|
|
jwt.RegisteredClaims
|
|
}
|
|
|
|
func (c AccessTokenClaims) Scopes() []string {
|
|
if c.Scope == "" {
|
|
return nil
|
|
}
|
|
return strings.Fields(c.Scope)
|
|
}
|
|
|
|
type VerificationResult struct {
|
|
UserID uint
|
|
ServiceAlias string
|
|
Subject string
|
|
Claims *AccessTokenClaims
|
|
}
|
|
|
|
// TokenInfo summarises a JWT's key id and registered claims as decoded by
// ExtractTokenInfo. The signature is not verified, so treat every field as
// untrusted.
type TokenInfo struct {
	// Kid is the "kid" header; Iss is the "iss" claim.
	Kid, Iss string
	// Aud is the "aud" claim as a slice (a lone string becomes one element);
	// empty entries are dropped.
	Aud []string
	// Sub is the "sub" claim.
	Sub string
	// Exp, Iat and Nbf are the "exp", "iat" and "nbf" claims converted to
	// integers, or 0 when a claim is missing or not numeric.
	Exp, Iat, Nbf int64
}
|
|
|
|
var (
|
|
globalMu sync.RWMutex
|
|
globalV *verifier
|
|
)
|
|
|
|
func Init(ctx context.Context, jwksURL, issuer string, audiences []string) error {
|
|
jwksURL = strings.TrimSpace(jwksURL)
|
|
issuer = strings.TrimSpace(issuer)
|
|
if jwksURL == "" || issuer == "" {
|
|
return errors.New("missing SSO JWKS or issuer configuration")
|
|
}
|
|
|
|
client := &http.Client{Timeout: 5 * time.Second}
|
|
options := keyfunc.Options{
|
|
Ctx: ctx,
|
|
Client: client,
|
|
RefreshTimeout: 10 * time.Second,
|
|
RefreshInterval: time.Hour,
|
|
RefreshUnknownKID: true,
|
|
RefreshErrorHandler: func(err error) {
|
|
utils.Log.Errorf("sso jwks refresh failed: %v", err)
|
|
},
|
|
}
|
|
|
|
jwks, err := keyfunc.Get(jwksURL, options)
|
|
if err != nil {
|
|
return fmt.Errorf("load jwks: %w", err)
|
|
}
|
|
|
|
audienceMap := make(map[string]struct{}, len(audiences))
|
|
for _, aud := range audiences {
|
|
aud = strings.TrimSpace(aud)
|
|
if aud == "" {
|
|
continue
|
|
}
|
|
audienceMap[aud] = struct{}{}
|
|
}
|
|
|
|
globalMu.Lock()
|
|
globalV = &verifier{jwks: jwks, issuer: issuer, audiences: audienceMap}
|
|
globalMu.Unlock()
|
|
|
|
utils.Log.Infof("sso verifier initialized for issuer %s (%d keys)", issuer, len(jwks.KIDs()))
|
|
return nil
|
|
}
|
|
|
|
func VerifyAccessToken(token string) (*VerificationResult, error) {
|
|
token = strings.TrimSpace(token)
|
|
if token == "" {
|
|
return nil, errors.New("empty token")
|
|
}
|
|
|
|
globalMu.RLock()
|
|
v := globalV
|
|
globalMu.RUnlock()
|
|
if v == nil {
|
|
return nil, errors.New("sso verifier not initialized")
|
|
}
|
|
|
|
claims := &AccessTokenClaims{}
|
|
parser := jwt.NewParser(
|
|
jwt.WithValidMethods([]string{jwt.SigningMethodRS256.Alg()}),
|
|
jwt.WithIssuedAt(),
|
|
jwt.WithExpirationRequired(),
|
|
)
|
|
|
|
tok, err := parser.ParseWithClaims(token, claims, v.jwks.Keyfunc)
|
|
if err != nil {
|
|
if shouldRefreshOnVerifyError(err) {
|
|
if refreshErr := v.jwks.Refresh(context.Background(), keyfunc.RefreshOptions{IgnoreRateLimit: true}); refreshErr != nil {
|
|
utils.Log.WithError(refreshErr).Warn("sso jwks refresh after signature error failed")
|
|
} else {
|
|
tok, err = parser.ParseWithClaims(token, claims, v.jwks.Keyfunc)
|
|
}
|
|
}
|
|
if err != nil {
|
|
return nil, fmt.Errorf("parse token: %w", err)
|
|
}
|
|
}
|
|
if !tok.Valid {
|
|
return nil, errors.New("invalid token")
|
|
}
|
|
|
|
if claims.Issuer != v.issuer {
|
|
return nil, errors.New("unexpected token issuer")
|
|
}
|
|
|
|
if len(v.audiences) > 0 {
|
|
validAud := false
|
|
for _, aud := range claims.Audience {
|
|
if _, ok := v.audiences[aud]; ok {
|
|
validAud = true
|
|
break
|
|
}
|
|
}
|
|
if !validAud {
|
|
return nil, errors.New("unexpected token audience")
|
|
}
|
|
}
|
|
|
|
sub := strings.TrimSpace(claims.Subject)
|
|
if sub == "" {
|
|
return nil, errors.New("missing subject")
|
|
}
|
|
|
|
result := &VerificationResult{Claims: claims, Subject: sub}
|
|
switch {
|
|
case strings.HasPrefix(sub, "user:"):
|
|
idStr := strings.TrimPrefix(sub, "user:")
|
|
id, err := strconv.ParseUint(idStr, 10, 64)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("invalid subject: %w", err)
|
|
}
|
|
result.UserID = uint(id)
|
|
case strings.HasPrefix(sub, "service:"):
|
|
alias := strings.TrimSpace(strings.TrimPrefix(sub, "service:"))
|
|
if alias == "" {
|
|
return nil, errors.New("invalid service subject")
|
|
}
|
|
result.ServiceAlias = strings.ToLower(alias)
|
|
default:
|
|
return nil, errors.New("unsupported subject type")
|
|
}
|
|
|
|
return result, nil
|
|
}
|
|
|
|
func shouldRefreshOnVerifyError(err error) bool {
|
|
if !IsSignatureError(err) {
|
|
return false
|
|
}
|
|
return !disableRefreshOnSignatureError()
|
|
}
|
|
|
|
// IsSignatureError reports whether err looks like a signature-verification
// failure. It matches the error text rather than error identity: a non-nil
// error qualifies if its message contains "verification error" or
// "token signature is invalid". A nil error is never a signature error.
func IsSignatureError(err error) bool {
	if err == nil {
		return false
	}
	switch msg := err.Error(); {
	case strings.Contains(msg, "verification error"),
		strings.Contains(msg, "token signature is invalid"):
		return true
	default:
		return false
	}
}
|
|
|
|
// disableRefreshOnSignatureError reports whether the environment variable
// SSO_DISABLE_JWKS_REFRESH_ON_SIG_ERROR turns off the refresh-and-retry that
// follows a signature error. The values "1", "true" and "yes" enable the
// opt-out. The comparison ignores surrounding whitespace, and the words are
// matched case-insensitively. Any other value, including unset, leaves the
// retry on. The variable is re-read on every call.
func disableRefreshOnSignatureError() bool {
	setting := strings.TrimSpace(os.Getenv("SSO_DISABLE_JWKS_REFRESH_ON_SIG_ERROR"))
	switch {
	case setting == "1",
		strings.EqualFold(setting, "true"),
		strings.EqualFold(setting, "yes"):
		return true
	}
	return false
}
|
|
|
|
func ExtractTokenInfo(token string) TokenInfo {
|
|
token = strings.TrimSpace(token)
|
|
if token == "" {
|
|
return TokenInfo{}
|
|
}
|
|
|
|
claims := jwt.MapClaims{}
|
|
parser := jwt.NewParser()
|
|
tok, _, err := parser.ParseUnverified(token, claims)
|
|
if err != nil {
|
|
return TokenInfo{}
|
|
}
|
|
|
|
info := TokenInfo{}
|
|
if kid, ok := tok.Header["kid"].(string); ok {
|
|
info.Kid = kid
|
|
}
|
|
if iss, ok := claims["iss"].(string); ok {
|
|
info.Iss = iss
|
|
}
|
|
if sub, ok := claims["sub"].(string); ok {
|
|
info.Sub = sub
|
|
}
|
|
if aud, ok := claims["aud"]; ok {
|
|
info.Aud = toStringSlice(aud)
|
|
}
|
|
info.Exp = toInt64(claims["exp"])
|
|
info.Iat = toInt64(claims["iat"])
|
|
info.Nbf = toInt64(claims["nbf"])
|
|
return info
|
|
}
|
|
|
|
// toStringSlice normalises a decoded claim value, such as "aud", into a
// []string and drops empty strings. The results are:
//   - a non-empty string: a one-element slice;
//   - "": nil;
//   - []string or []any: a non-nil slice of the non-empty string elements,
//     with non-string elements of []any skipped;
//   - anything else: nil.
func toStringSlice(v any) []string {
	keep := func(dst []string, s string) []string {
		if s == "" {
			return dst
		}
		return append(dst, s)
	}

	switch val := v.(type) {
	case string:
		if val != "" {
			return []string{val}
		}
		return nil
	case []string:
		res := make([]string, 0, len(val))
		for i := range val {
			res = keep(res, val[i])
		}
		return res
	case []any:
		res := make([]string, 0, len(val))
		for i := range val {
			if s, isStr := val[i].(string); isStr {
				res = keep(res, s)
			}
		}
		return res
	}
	return nil
}
|
|
|
|
// toInt64 converts a decoded JSON claim value (e.g. "exp") to an int64.
// Supported inputs are int64, int, float64, json.Number and base-10 integer
// strings. Fractional numbers are truncated toward zero. It returns 0 for
// missing, non-numeric or out-of-range values.
func toInt64(v any) int64 {
	switch t := v.(type) {
	case int64:
		return t
	case int:
		return int64(t)
	case float64:
		return floatToInt64(t)
	case json.Number:
		if n, err := t.Int64(); err == nil {
			return n
		}
		// RFC 7519 NumericDate values may be non-integer, e.g. "1.7e9" or
		// "1700000000.5". Int64 rejects those, so fall back to Float64.
		if f, err := t.Float64(); err == nil {
			return floatToInt64(f)
		}
	case string:
		if n, err := strconv.ParseInt(t, 10, 64); err == nil {
			return n
		}
	}
	return 0
}

// floatToInt64 truncates f toward zero. It returns 0 for NaN, ±Inf and
// values outside the int64 range, where a plain int64(f) conversion is
// implementation-defined in Go.
func floatToInt64(f float64) int64 {
	const limit = 1 << 63 // 2^63 is exactly representable as a float64
	// NaN fails both comparisons, so it also lands on the zero return.
	if f >= -limit && f < limit {
		return int64(f)
	}
	return 0
}
|