Files
lti-api/internal/modules/sso/verifier/verifier.go
T
2026-02-24 15:16:09 +07:00

285 lines
6.0 KiB
Go

package sso
import (
"context"
"encoding/json"
"errors"
"fmt"
"net/http"
"os"
"strconv"
"strings"
"sync"
"time"
"github.com/MicahParks/keyfunc/v2"
"github.com/golang-jwt/jwt/v5"
"gitlab.com/mbugroup/lti-api.git/internal/utils"
)
// verifier bundles the configuration VerifyAccessToken checks tokens against.
type verifier struct {
	// issuer is the exact "iss" value a token must carry.
	issuer string
	// audiences is the set of accepted "aud" values; an empty set disables
	// the audience check.
	audiences map[string]struct{}
	// jwks resolves the public keys used to check RS256 token signatures.
	jwks *keyfunc.JWKS
}
// AccessTokenClaims is the claim set decoded from an SSO access token: the
// registered JWT claims (iss, sub, aud, exp, ...) plus the "scope" claim.
type AccessTokenClaims struct {
	// Scope is the raw "scope" claim; Scopes splits it on whitespace.
	Scope string `json:"scope"`
	jwt.RegisteredClaims
}
// Scopes returns the whitespace-separated entries of the "scope" claim,
// or nil when the claim is absent or empty.
func (c AccessTokenClaims) Scopes() []string {
	if len(c.Scope) > 0 {
		return strings.Fields(c.Scope)
	}
	return nil
}
// VerificationResult describes the caller identified by a token that passed
// VerifyAccessToken. VerifyAccessToken sets UserID for "user:<id>" subjects
// and ServiceAlias for "service:<alias>" subjects.
type VerificationResult struct {
	// UserID is the numeric ID parsed from a "user:<id>" subject.
	UserID uint
	// ServiceAlias is the trimmed, lower-cased alias from a "service:<alias>"
	// subject.
	ServiceAlias string
	// Subject is the whitespace-trimmed "sub" claim.
	Subject string
	// Claims is the full decoded claim set of the token.
	Claims *AccessTokenClaims
}
// TokenInfo is a best-effort summary of a token's header and claims, filled in
// by ExtractTokenInfo from an UNVERIFIED parse (the signature is not checked).
// Use it for logging and diagnostics only, never for authorization decisions.
type TokenInfo struct {
	// Kid is the "kid" header value when it is a string.
	Kid string
	// Iss is the "iss" claim when it is a string.
	Iss string
	// Aud holds the non-empty "aud" values (the claim may be a string or a list).
	Aud []string
	// Sub is the "sub" claim when it is a string.
	Sub string
	// Exp, Iat and Nbf are the numeric "exp", "iat" and "nbf" claims converted
	// to int64, or 0 when absent or unparsable. Per RFC 7519 these are
	// presumably seconds since the Unix epoch.
	Exp int64
	Iat int64
	Nbf int64
}
// globalMu guards reads and writes of globalV.
var globalMu sync.RWMutex

// globalV is the verifier installed by Init; it stays nil until Init succeeds.
var globalV *verifier
// Init fetches the SSO signing keys from jwksURL and installs the package-wide
// verifier used by VerifyAccessToken.
//
// ctx is handed to keyfunc as the lifetime of its background refresher. The
// key set is re-fetched hourly and, on demand, when a token names a "kid" that
// is not in the cached set. On-demand refreshes are rate limited (see below).
//
// issuer and every entry of audiences are whitespace-trimmed. Blank audiences
// are dropped; if none remain, VerifyAccessToken skips the audience check.
//
// It returns an error when jwksURL or issuer is blank, or when the initial
// JWKS fetch fails. A successful call replaces any previously installed
// verifier. NOTE(review): the replaced verifier's keyfunc refresher is not
// stopped here; presumably it runs until the ctx given to its Init ends.
func Init(ctx context.Context, jwksURL, issuer string, audiences []string) error {
	jwksURL = strings.TrimSpace(jwksURL)
	issuer = strings.TrimSpace(issuer)
	if jwksURL == "" || issuer == "" {
		return errors.New("missing SSO JWKS or issuer configuration")
	}

	// RefreshUnknownKID lets anyone trigger a JWKS download simply by sending a
	// token with an invented "kid". Without a rate limit that is one outbound
	// request per bogus token, so cap how often keyfunc may refresh.
	const refreshRateLimit = time.Minute

	options := keyfunc.Options{
		Ctx:               ctx,
		Client:            &http.Client{Timeout: 5 * time.Second},
		RefreshTimeout:    10 * time.Second,
		RefreshInterval:   time.Hour,
		RefreshRateLimit:  refreshRateLimit,
		RefreshUnknownKID: true,
		RefreshErrorHandler: func(err error) {
			utils.Log.Errorf("sso jwks refresh failed: %v", err)
		},
	}
	jwks, err := keyfunc.Get(jwksURL, options)
	if err != nil {
		return fmt.Errorf("load jwks: %w", err)
	}

	audienceMap := make(map[string]struct{}, len(audiences))
	for _, aud := range audiences {
		aud = strings.TrimSpace(aud)
		if aud == "" {
			continue
		}
		audienceMap[aud] = struct{}{}
	}

	globalMu.Lock()
	globalV = &verifier{jwks: jwks, issuer: issuer, audiences: audienceMap}
	globalMu.Unlock()

	utils.Log.Infof("sso verifier initialized for issuer %s (%d keys)", issuer, len(jwks.KIDs()))
	return nil
}
// VerifyAccessToken validates an SSO access token with the verifier installed
// by Init and classifies its subject.
//
// The token must be RS256-signed by a key in the JWKS and must carry "exp".
// The parser also checks "iat" (jwt.WithIssuedAt). "iss" must equal the
// configured issuer. When audiences were configured, at least one "aud" value
// must be among them.
//
// The "sub" claim must be "user:<decimal id>" (which sets UserID) or
// "service:<alias>" (which sets ServiceAlias, trimmed and lower-cased).
//
// If parsing fails with a signature error and refresh is not disabled via
// SSO_DISABLE_JWKS_REFRESH_ON_SIG_ERROR, the JWKS is refreshed once and the
// token is re-parsed. Any failure returns a nil result and an error.
func VerifyAccessToken(token string) (*VerificationResult, error) {
	token = strings.TrimSpace(token)
	if token == "" {
		return nil, errors.New("empty token")
	}

	globalMu.RLock()
	v := globalV
	globalMu.RUnlock()
	if v == nil {
		return nil, errors.New("sso verifier not initialized")
	}

	parser := jwt.NewParser(
		jwt.WithValidMethods([]string{jwt.SigningMethodRS256.Alg()}),
		jwt.WithIssuedAt(),
		jwt.WithExpirationRequired(),
	)

	claims := &AccessTokenClaims{}
	tok, err := parser.ParseWithClaims(token, claims, v.jwks.Keyfunc)
	if err != nil && shouldRefreshOnVerifyError(err) {
		// The signer may have replaced a key under an existing kid, so refresh
		// the JWKS once and retry.
		//
		// The token is untrusted input. We therefore honour keyfunc's rate limit
		// (when one is configured) so forged signatures cannot force an outbound
		// fetch per request. We also bound the wait so a stalled or stopped
		// background refresher cannot hang this call. The wait deliberately
		// exceeds the 10s RefreshTimeout that Init configures.
		const refreshWait = 15 * time.Second
		refreshCtx, cancel := context.WithTimeout(context.Background(), refreshWait)
		refreshErr := v.jwks.Refresh(refreshCtx, keyfunc.RefreshOptions{IgnoreRateLimit: false})
		cancel()
		if refreshErr != nil {
			utils.Log.WithError(refreshErr).Warn("sso jwks refresh after signature error failed")
		} else {
			// Decode into fresh claims so nothing from the failed attempt lingers.
			claims = &AccessTokenClaims{}
			tok, err = parser.ParseWithClaims(token, claims, v.jwks.Keyfunc)
		}
	}
	if err != nil {
		return nil, fmt.Errorf("parse token: %w", err)
	}
	if !tok.Valid {
		return nil, errors.New("invalid token")
	}

	if claims.Issuer != v.issuer {
		return nil, errors.New("unexpected token issuer")
	}
	if len(v.audiences) > 0 {
		validAud := false
		for _, aud := range claims.Audience {
			if _, ok := v.audiences[aud]; ok {
				validAud = true
				break
			}
		}
		if !validAud {
			return nil, errors.New("unexpected token audience")
		}
	}

	sub := strings.TrimSpace(claims.Subject)
	if sub == "" {
		return nil, errors.New("missing subject")
	}

	result := &VerificationResult{Claims: claims, Subject: sub}
	switch {
	case strings.HasPrefix(sub, "user:"):
		// Parse at the width of uint so an ID that does not fit is rejected
		// instead of being silently truncated by uint() on 32-bit builds.
		id, err := strconv.ParseUint(strings.TrimPrefix(sub, "user:"), 10, strconv.IntSize)
		if err != nil {
			return nil, fmt.Errorf("invalid subject: %w", err)
		}
		result.UserID = uint(id)
	case strings.HasPrefix(sub, "service:"):
		alias := strings.TrimSpace(strings.TrimPrefix(sub, "service:"))
		if alias == "" {
			return nil, errors.New("invalid service subject")
		}
		result.ServiceAlias = strings.ToLower(alias)
	default:
		return nil, errors.New("unsupported subject type")
	}
	return result, nil
}
// shouldRefreshOnVerifyError reports whether a parse failure warrants a JWKS
// refresh and retry. That is the case when err looks like a signature failure
// and the refresh has not been disabled through the environment. The
// environment is consulted only for signature errors.
func shouldRefreshOnVerifyError(err error) bool {
	return IsSignatureError(err) && !disableRefreshOnSignatureError()
}
// IsSignatureError reports whether err's message indicates a failed signature
// check. It matches the text "verification error" or "token signature is
// invalid" anywhere in err.Error(). A nil error is never a signature error.
func IsSignatureError(err error) bool {
	if err == nil {
		return false
	}
	text := err.Error()
	for _, marker := range []string{"verification error", "token signature is invalid"} {
		if strings.Contains(text, marker) {
			return true
		}
	}
	return false
}
// disableRefreshOnSignatureError reports whether the environment variable
// SSO_DISABLE_JWKS_REFRESH_ON_SIG_ERROR opts out of the refresh-and-retry
// on signature errors. After trimming whitespace, "1", "true" or "yes"
// (case-insensitive) disables it. Anything else, including unset, leaves it on.
func disableRefreshOnSignatureError() bool {
	raw := strings.TrimSpace(os.Getenv("SSO_DISABLE_JWKS_REFRESH_ON_SIG_ERROR"))
	for _, accepted := range []string{"1", "true", "yes"} {
		if strings.EqualFold(raw, accepted) {
			return true
		}
	}
	return false
}
// ExtractTokenInfo decodes token WITHOUT verifying its signature and returns
// its "kid" header plus the iss, aud, sub, exp, iat and nbf claims for
// diagnostics. Fields that are missing or of an unexpected type are left at
// their zero value. A blank or unparsable token yields an empty TokenInfo.
// The result must never be trusted for authorization.
func ExtractTokenInfo(token string) TokenInfo {
	var info TokenInfo

	raw := strings.TrimSpace(token)
	if raw == "" {
		return info
	}

	claims := jwt.MapClaims{}
	parsed, _, err := jwt.NewParser().ParseUnverified(raw, claims)
	if err != nil {
		return info
	}

	// A failed type assertion yields "", which is already the zero value.
	info.Kid, _ = parsed.Header["kid"].(string)
	info.Iss, _ = claims["iss"].(string)
	info.Sub, _ = claims["sub"].(string)
	if aud, present := claims["aud"]; present {
		info.Aud = toStringSlice(aud)
	}
	info.Exp = toInt64(claims["exp"])
	info.Iat = toInt64(claims["iat"])
	info.Nbf = toInt64(claims["nbf"])
	return info
}
// toStringSlice normalises a decoded "aud"-style value into a slice of its
// non-empty strings.
//
// A non-empty string becomes a one-element slice; an empty string gives nil.
// For []string and []any, empty strings (and, for []any, non-strings) are
// skipped; the result is non-nil even when nothing survives. Any other type
// gives nil.
func toStringSlice(v any) []string {
	switch val := v.(type) {
	case string:
		if val != "" {
			return []string{val}
		}
		return nil
	case []string:
		kept := make([]string, 0, len(val))
		for i := range val {
			if len(val[i]) > 0 {
				kept = append(kept, val[i])
			}
		}
		return kept
	case []any:
		kept := make([]string, 0, len(val))
		for i := range val {
			if s, isStr := val[i].(string); isStr && len(s) > 0 {
				kept = append(kept, s)
			}
		}
		return kept
	}
	return nil
}
// toInt64 converts a decoded numeric claim value to int64 for diagnostics.
//
// It accepts int64, int, float64, json.Number and decimal strings. Fractional
// or exponent forms ("1.7e9", "1700000000.5"), which RFC 7519 NumericDate
// values may use, are truncated toward zero. It returns 0 when v is nil, of
// another type, unparsable, NaN, or outside the int64 range.
func toInt64(v any) int64 {
	// fromFloat truncates f toward zero. NaN and values outside the int64
	// range, whose conversion Go leaves implementation-specific, map to 0.
	// Both bounds are exactly representable as float64.
	fromFloat := func(f float64) int64 {
		if f >= -(1<<63) && f < 1<<63 {
			return int64(f)
		}
		return 0
	}

	switch t := v.(type) {
	case int64:
		return t
	case int:
		return int64(t)
	case float64:
		return fromFloat(t)
	case json.Number:
		if n, err := t.Int64(); err == nil {
			return n
		}
		// Int64 rejects fractional/exponent forms; fall back to float parsing.
		if f, err := t.Float64(); err == nil {
			return fromFloat(f)
		}
	case string:
		if n, err := strconv.ParseInt(t, 10, 64); err == nil {
			return n
		}
		if f, err := strconv.ParseFloat(t, 64); err == nil {
			return fromFloat(f)
		}
	}
	return 0
}