package sso

import (
	"context"
	"encoding/json"
	"errors"
	"fmt"
	"net/http"
	"os"
	"strconv"
	"strings"
	"sync"
	"time"

	"github.com/MicahParks/keyfunc/v2"
	"github.com/golang-jwt/jwt/v5"
	"gitlab.com/mbugroup/lti-api.git/internal/utils"
)

// verifier holds the key set and expectations used to validate access tokens.
type verifier struct {
	jwks      *keyfunc.JWKS
	issuer    string
	audiences map[string]struct{} // empty map means "accept any audience"
}

// AccessTokenClaims are the claims decoded from an SSO access token: the
// registered JWT claims plus a space-separated "scope" claim.
type AccessTokenClaims struct {
	Scope string `json:"scope"`
	jwt.RegisteredClaims
}

// Scopes splits the "scope" claim on whitespace. It returns nil when the
// claim is empty.
func (c AccessTokenClaims) Scopes() []string {
	if c.Scope == "" {
		return nil
	}
	return strings.Fields(c.Scope)
}

// VerificationResult describes a successfully verified access token.
// Exactly one of UserID (subject "user:<id>") or ServiceAlias (subject
// "service:<alias>", lower-cased) is populated.
type VerificationResult struct {
	UserID       uint
	ServiceAlias string
	Subject      string
	Claims       *AccessTokenClaims
}

// TokenInfo is a best-effort, UNVERIFIED summary of a token's header and
// claims, intended for diagnostics/logging only. Zero values mean the field
// was absent or could not be decoded.
type TokenInfo struct {
	Kid string
	Iss string
	Aud []string
	Sub string
	Exp int64
	Iat int64
	Nbf int64
}

var (
	globalMu sync.RWMutex // guards globalV
	globalV  *verifier    // nil until Init succeeds
)

// Init loads the JWKS at jwksURL and installs the process-wide verifier used
// by VerifyAccessToken. Tokens must be RS256-signed and issued by issuer.
// When audiences contains at least one non-blank entry, a token must carry
// one of them in its "aud" claim; blank entries are ignored.
//
// The key set is refreshed in the background every hour and whenever an
// unknown key ID is seen, for as long as ctx is alive. Calling Init again
// replaces the active verifier and stops the previous one's background
// refresh.
func Init(ctx context.Context, jwksURL, issuer string, audiences []string) error {
	jwksURL = strings.TrimSpace(jwksURL)
	issuer = strings.TrimSpace(issuer)
	if jwksURL == "" || issuer == "" {
		return errors.New("missing SSO JWKS or issuer configuration")
	}

	options := keyfunc.Options{
		Ctx:               ctx,
		Client:            &http.Client{Timeout: 5 * time.Second},
		RefreshTimeout:    10 * time.Second,
		RefreshInterval:   time.Hour,
		RefreshUnknownKID: true,
		RefreshErrorHandler: func(err error) {
			utils.Log.Errorf("sso jwks refresh failed: %v", err)
		},
	}
	jwks, err := keyfunc.Get(jwksURL, options)
	if err != nil {
		return fmt.Errorf("load jwks: %w", err)
	}

	audienceMap := make(map[string]struct{}, len(audiences))
	for _, aud := range audiences {
		if aud = strings.TrimSpace(aud); aud != "" {
			audienceMap[aud] = struct{}{}
		}
	}

	globalMu.Lock()
	prev := globalV
	globalV = &verifier{jwks: jwks, issuer: issuer, audiences: audienceMap}
	globalMu.Unlock()

	// Stop the replaced verifier's refresh goroutine so repeated Init calls
	// do not leak goroutines. Verifications still holding prev keep working
	// from its cached keys.
	if prev != nil && prev.jwks != nil {
		prev.jwks.EndBackground()
	}

	utils.Log.Infof("sso verifier initialized for issuer %s (%d keys)", issuer, len(jwks.KIDs()))
	return nil
}

// VerifyAccessToken validates token against the verifier installed by Init.
//
// It checks the signature, the time-based claims (exp is required), the
// issuer, and the audience (when audiences were configured). It then
// classifies the subject: "user:<id>" populates UserID, and
// "service:<alias>" populates ServiceAlias. Any other subject is rejected.
// An error is returned if Init has not been called successfully.
func VerifyAccessToken(token string) (*VerificationResult, error) {
	token = strings.TrimSpace(token)
	if token == "" {
		return nil, errors.New("empty token")
	}

	globalMu.RLock()
	v := globalV
	globalMu.RUnlock()
	if v == nil {
		return nil, errors.New("sso verifier not initialized")
	}

	claims, err := v.parse(token)
	if err != nil {
		return nil, err
	}
	if claims.Issuer != v.issuer {
		return nil, errors.New("unexpected token issuer")
	}
	if !v.audienceAllowed(claims.Audience) {
		return nil, errors.New("unexpected token audience")
	}

	sub := strings.TrimSpace(claims.Subject)
	if sub == "" {
		return nil, errors.New("missing subject")
	}

	result := &VerificationResult{Claims: claims, Subject: sub}
	switch {
	case strings.HasPrefix(sub, "user:"):
		// Parse at the width of uint so an oversized ID is rejected rather
		// than silently truncated on 32-bit platforms.
		id, err := strconv.ParseUint(strings.TrimPrefix(sub, "user:"), 10, strconv.IntSize)
		if err != nil {
			return nil, fmt.Errorf("invalid subject: %w", err)
		}
		result.UserID = uint(id)
	case strings.HasPrefix(sub, "service:"):
		alias := strings.TrimSpace(strings.TrimPrefix(sub, "service:"))
		if alias == "" {
			return nil, errors.New("invalid service subject")
		}
		result.ServiceAlias = strings.ToLower(alias)
	default:
		return nil, errors.New("unsupported subject type")
	}
	return result, nil
}

// parse verifies token's RS256 signature and time-based claims (exp is
// required, iat is validated) and returns the decoded claims.
//
// If the first attempt fails with a signature error, the JWKS is
// force-refreshed once and the token re-parsed. This tolerates a key rotated
// under an unchanged key ID. The environment variable
// SSO_DISABLE_JWKS_REFRESH_ON_SIG_ERROR turns this retry off.
func (v *verifier) parse(token string) (*AccessTokenClaims, error) {
	// Upper bound on how long a request may wait for the forced refresh.
	const sigErrorRefreshTimeout = 10 * time.Second

	parser := jwt.NewParser(
		jwt.WithValidMethods([]string{jwt.SigningMethodRS256.Alg()}),
		jwt.WithIssuedAt(),
		jwt.WithExpirationRequired(),
	)
	claims := &AccessTokenClaims{}
	tok, err := parser.ParseWithClaims(token, claims, v.jwks.Keyfunc)

	if err != nil && shouldRefreshOnVerifyError(err) {
		// Refresh waits on the JWKS background goroutine. Bound that wait so
		// a stopped goroutine (cancelled Init ctx, replaced verifier) cannot
		// hang the caller forever.
		// NOTE(review): IgnoreRateLimit lets every bad-signature token trigger
		// an outbound JWKS fetch. Consider rate limiting if tokens are
		// attacker-supplied; SSO_DISABLE_JWKS_REFRESH_ON_SIG_ERROR disables it.
		ctx, cancel := context.WithTimeout(context.Background(), sigErrorRefreshTimeout)
		refreshErr := v.jwks.Refresh(ctx, keyfunc.RefreshOptions{IgnoreRateLimit: true})
		cancel()
		if refreshErr != nil {
			utils.Log.WithError(refreshErr).Warn("sso jwks refresh after signature error failed")
		} else {
			tok, err = parser.ParseWithClaims(token, claims, v.jwks.Keyfunc)
		}
	}
	if err != nil {
		return nil, fmt.Errorf("parse token: %w", err)
	}
	if !tok.Valid {
		return nil, errors.New("invalid token")
	}
	return claims, nil
}

// audienceAllowed reports whether tokenAud contains at least one configured
// audience. It always returns true when no audiences were configured.
func (v *verifier) audienceAllowed(tokenAud []string) bool {
	if len(v.audiences) == 0 {
		return true
	}
	for _, aud := range tokenAud {
		if _, ok := v.audiences[aud]; ok {
			return true
		}
	}
	return false
}

// shouldRefreshOnVerifyError reports whether err warrants a forced JWKS
// refresh. That is true for a signature error, unless the retry has been
// disabled through the environment.
func shouldRefreshOnVerifyError(err error) bool {
	if !IsSignatureError(err) {
		return false
	}
	return !disableRefreshOnSignatureError()
}

// IsSignatureError reports whether err stems from a failed token signature
// check. It matches jwt.ErrTokenSignatureInvalid anywhere in the wrap chain.
// It falls back to matching the known error text, for errors whose chain
// was not preserved.
func IsSignatureError(err error) bool {
	if err == nil {
		return false
	}
	if errors.Is(err, jwt.ErrTokenSignatureInvalid) {
		return true
	}
	msg := err.Error()
	return strings.Contains(msg, "verification error") ||
		strings.Contains(msg, "token signature is invalid")
}

// disableRefreshOnSignatureError reports whether
// SSO_DISABLE_JWKS_REFRESH_ON_SIG_ERROR is set to "1", or to "true"/"yes"
// (case-insensitive, surrounding whitespace ignored).
func disableRefreshOnSignatureError() bool {
	val := strings.TrimSpace(os.Getenv("SSO_DISABLE_JWKS_REFRESH_ON_SIG_ERROR"))
	if val == "" {
		return false
	}
	return val == "1" || strings.EqualFold(val, "true") || strings.EqualFold(val, "yes")
}

// ExtractTokenInfo decodes token WITHOUT verifying its signature or claims
// and returns selected header/claim values for diagnostics. It returns a
// zero TokenInfo for an empty or malformed token. Never use the result for
// authorization decisions.
func ExtractTokenInfo(token string) TokenInfo {
	token = strings.TrimSpace(token)
	if token == "" {
		return TokenInfo{}
	}

	claims := jwt.MapClaims{}
	tok, _, err := jwt.NewParser().ParseUnverified(token, claims)
	if err != nil {
		return TokenInfo{}
	}

	info := TokenInfo{}
	if kid, ok := tok.Header["kid"].(string); ok {
		info.Kid = kid
	}
	if iss, ok := claims["iss"].(string); ok {
		info.Iss = iss
	}
	if sub, ok := claims["sub"].(string); ok {
		info.Sub = sub
	}
	if aud, ok := claims["aud"]; ok {
		info.Aud = toStringSlice(aud)
	}
	info.Exp = toInt64(claims["exp"])
	info.Iat = toInt64(claims["iat"])
	info.Nbf = toInt64(claims["nbf"])
	return info
}

// toStringSlice converts a decoded "aud"-style claim to non-empty strings.
// It accepts a single string, []string, or []any (non-string items are
// skipped). Any other type yields nil.
func toStringSlice(v any) []string {
	switch t := v.(type) {
	case string:
		if t == "" {
			return nil
		}
		return []string{t}
	case []string:
		out := make([]string, 0, len(t))
		for _, s := range t {
			if s != "" {
				out = append(out, s)
			}
		}
		return out
	case []any:
		out := make([]string, 0, len(t))
		for _, item := range t {
			if s, ok := item.(string); ok && s != "" {
				out = append(out, s)
			}
		}
		return out
	default:
		return nil
	}
}

// toInt64 converts a decoded numeric claim to int64. It handles int64, int,
// float64 (truncated), json.Number and base-10 strings. Anything else, or
// an unparsable value, yields 0.
func toInt64(v any) int64 {
	switch t := v.(type) {
	case int64:
		return t
	case int:
		return int64(t)
	case float64:
		return int64(t)
	case json.Number:
		if n, err := t.Int64(); err == nil {
			return n
		}
	case string:
		if n, err := strconv.ParseInt(t, 10, 64); err == nil {
			return n
		}
	}
	return 0
}