diff --git a/internal/config/config.go b/internal/config/config.go index af723b3b..0c09ee33 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -57,6 +57,7 @@ var ( SSOPortalURL string SSOClients map[string]SSOClientConfig SSOAccessCookieName string + SSOAccessCookieFallback []string SSORefreshCookieName string SSOCookieDomain string SSOCookieSecure bool @@ -141,6 +142,7 @@ func init() { SSOGetMeURL = viper.GetString("SSO_GETME_URL") SSOPortalURL = strings.TrimSpace(viper.GetString("SSO_PORTAL_URL")) SSOAccessCookieName = defaultString(viper.GetString("SSO_ACCESS_COOKIE_NAME"), "sso_access") + SSOAccessCookieFallback = parseList("SSO_ACCESS_COOKIE_FALLBACK") SSORefreshCookieName = defaultString(viper.GetString("SSO_REFRESH_COOKIE_NAME"), "sso_refresh") SSOCookieDomain = viper.GetString("SSO_COOKIE_DOMAIN") SSOCookieSecure = viper.GetBool("SSO_COOKIE_SECURE") diff --git a/internal/middleware/auth.go b/internal/middleware/auth.go index b7229382..e7640e7b 100644 --- a/internal/middleware/auth.go +++ b/internal/middleware/auth.go @@ -19,11 +19,11 @@ const ( // AuthContext keeps authentication details captured by the middleware. type AuthContext struct { - Token string - Verification *sso.VerificationResult - User *entity.User - Roles []sso.Role - Permissions map[string]struct{} + Token string + Verification *sso.VerificationResult + User *entity.User + Roles []sso.Role + Permissions map[string]struct{} UserAreaIDs []uint UserLocationIDs []uint UserAllArea bool @@ -36,8 +36,30 @@ type AuthContext struct { func Auth(userService service.UserService, requiredScopes ...string) fiber.Handler { return func(c *fiber.Ctx) error { token := bearerToken(c) - if token == "" { - token = strings.TrimSpace(c.Cookies(config.SSOAccessCookieName)) + tokenSource := "" + if token != "" { + tokenSource = "header" + } else { + primaryName := strings.TrimSpace(config.SSOAccessCookieName) + if primaryName != "" { + token = strings.TrimSpace(c.Cookies(primaryName)) + if token != "" { + tokenSource = "cookie:" + primaryName + } + } + if token == "" { + for _, name := range config.SSOAccessCookieFallback { + name = strings.TrimSpace(name) + if name == "" || name == primaryName { + continue + } + token = strings.TrimSpace(c.Cookies(name)) + if token != "" { + tokenSource = "cookie:" + name + break + } + } + } } if token == "" { return fiber.NewError(fiber.StatusUnauthorized, "Please authenticate") @@ -45,7 +67,11 @@ func Auth(userService service.UserService, requiredScopes ...string) fiber.Handl verification, err := sso.VerifyAccessToken(token) if err != nil { - utils.Log.WithError(err).Warn("auth: token verification failed") + if sso.IsSignatureError(err) { + logSignatureError("auth", tokenSource, token, err) + } else { + utils.Log.WithError(err).Warn("auth: token verification failed") + } return fiber.NewError(fiber.StatusUnauthorized, "Please authenticate") } @@ -89,11 +115,11 @@ func Auth(userService service.UserService, requiredScopes ...string) fiber.Handl } ctx := &AuthContext{ - Token: token, - Verification: verification, - User: user, - Roles: roles, - Permissions: permissions, + Token: token, + Verification: verification, + User: user, + Roles: roles, + Permissions: permissions, UserAreaIDs: nil, UserLocationIDs: nil, UserAllArea: false, @@ -216,6 +242,26 @@ func hasAllScopes(have, required []string) bool { return true } +func logSignatureError(ctxLabel, tokenSource, token string, err error) { + info := sso.ExtractTokenInfo(token) + aud := strings.Join(info.Aud, ",") + utils.Log.Errorf( + "access token verification failed: %v | ctx=%s source=%s iss=%s kid=%s aud=%s sub=%s exp=%d iat=%d nbf=%d expected_iss=%s expected_aud=%v", + err, + ctxLabel, + tokenSource, + info.Iss, + info.Kid, + aud, + info.Sub, + info.Exp, + info.Iat, + info.Nbf, + config.SSOIssuer, + config.SSOAllowedAudiences, + ) +} + // RequirePermissions ensures the authenticated user possesses all specified permissions. func RequirePermissions(perms ...string) fiber.Handler { required := canonicalPermissions(perms) diff --git a/internal/modules/sso/controllers/sso.controller.go b/internal/modules/sso/controllers/sso.controller.go index 5e75d4a9..41ece390 100644 --- a/internal/modules/sso/controllers/sso.controller.go +++ b/internal/modules/sso/controllers/sso.controller.go @@ -196,7 +196,11 @@ func (h *Controller) Refresh(c *fiber.Ctx) error { verification, err := sso.VerifyAccessToken(tokenResp.AccessToken) if err != nil { - utils.Log.Errorf("access token verification failed: %v", err) + if sso.IsSignatureError(err) { + logSignatureError("sso refresh", "sso_token", tokenResp.AccessToken, err) + } else { + utils.Log.Errorf("access token verification failed: %v", err) + } return fiber.NewError(fiber.StatusUnauthorized, "invalid access token") } @@ -304,7 +308,11 @@ func (h *Controller) Callback(c *fiber.Ctx) error { verification, err := sso.VerifyAccessToken(tokenResp.AccessToken) if err != nil { - utils.Log.Errorf("access token verification failed: %v", err) + if sso.IsSignatureError(err) { + logSignatureError("sso callback", "sso_token", tokenResp.AccessToken, err) + } else { + utils.Log.Errorf("access token verification failed: %v", err) + } return fiber.NewError(fiber.StatusUnauthorized, "invalid access token") } @@ -337,6 +345,22 @@ func (h *Controller) UserInfo(c *fiber.Ctx) error { token := strings.TrimSpace(c.Cookies(accessName)) tokenFromCookie := token != "" + usedCookieName := accessName + + if !tokenFromCookie { + for _, name := range config.SSOAccessCookieFallback { + name = strings.TrimSpace(name) + if name == "" || name == accessName { + continue + } + token = strings.TrimSpace(c.Cookies(name)) + if token != "" { + tokenFromCookie = true + usedCookieName = name + break + } + } + } if !tokenFromCookie { authHeader := strings.TrimSpace(c.Get("Authorization")) @@ -363,7 +387,11 @@ func (h *Controller) UserInfo(c *fiber.Ctx) error { } if _, err := sso.VerifyAccessToken(token); err != nil { - utils.Log.WithError(err).Warn("access token verification failed for userinfo") + if sso.IsSignatureError(err) { + logSignatureError("sso userinfo", "request", token, err) + } else { + utils.Log.WithError(err).Warn("access token verification failed for userinfo") + } return fiber.NewError(fiber.StatusUnauthorized, "unauthenticated") } @@ -382,7 +410,7 @@ func (h *Controller) UserInfo(c *fiber.Ctx) error { // SSO /auth/get-me expects the access cookie; add Authorization as well for compatibility. req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token)) if tokenFromCookie { - req.Header.Set("Cookie", fmt.Sprintf("%s=%s", accessName, token)) + req.Header.Set("Cookie", fmt.Sprintf("%s=%s", usedCookieName, token)) } resp, err := h.httpClient.Do(req) @@ -836,6 +864,27 @@ func resolveSSOCookieName(configuredName, fallback string) string { return strings.TrimSpace(fallback) } +func logSignatureError(ctxLabel, tokenSource, token string, err error) { + info := sso.ExtractTokenInfo(token) + aud := strings.Join(info.Aud, ",") + utils.Log.Errorf( + "access token verification failed: %v | ctx=%s source=%s iss=%s kid=%s aud=%s sub=%s exp=%d iat=%d nbf=%d expected_iss=%s expected_aud=%v jwks=%s", + err, + ctxLabel, + tokenSource, + info.Iss, + info.Kid, + aud, + info.Sub, + info.Exp, + info.Iat, + info.Nbf, + config.SSOIssuer, + config.SSOAllowedAudiences, + config.SSOJWKSURL, + ) +} + func normalizeClientParam(raw string) string { value := strings.TrimSpace(raw) if value == "" { diff --git a/internal/modules/sso/verifier/verifier.go b/internal/modules/sso/verifier/verifier.go index 0c8d97e8..7d7cefbb 100644 --- a/internal/modules/sso/verifier/verifier.go +++ b/internal/modules/sso/verifier/verifier.go @@ -2,9 +2,11 @@ package sso import ( "context" + "encoding/json" "errors" "fmt" "net/http" + "os" "strconv" "strings" "sync" @@ -41,6 +43,16 @@ type VerificationResult struct { Claims *AccessTokenClaims } +type TokenInfo struct { + Kid string + Iss string + Aud []string + Sub string + Exp int64 + Iat int64 + Nbf int64 +} + var ( globalMu sync.RWMutex globalV *verifier @@ -106,10 +118,19 @@ func VerifyAccessToken(token string) (*VerificationResult, error) { jwt.WithIssuedAt(), jwt.WithExpirationRequired(), ) - + tok, err := parser.ParseWithClaims(token, claims, v.jwks.Keyfunc) if err != nil { - return nil, fmt.Errorf("parse token: %w", err) + if shouldRefreshOnVerifyError(err) { + if refreshErr := v.jwks.Refresh(context.Background(), keyfunc.RefreshOptions{IgnoreRateLimit: true}); refreshErr != nil { + utils.Log.WithError(refreshErr).Warn("sso jwks refresh after signature error failed") + } else { + tok, err = parser.ParseWithClaims(token, claims, v.jwks.Keyfunc) + } + } + if err != nil { + return nil, fmt.Errorf("parse token: %w", err) + } } if !tok.Valid { return nil, errors.New("invalid token") @@ -158,3 +179,106 @@ func VerifyAccessToken(token string) (*VerificationResult, error) { return result, nil } + +func shouldRefreshOnVerifyError(err error) bool { + if !IsSignatureError(err) { + return false + } + return !disableRefreshOnSignatureError() +} + +func IsSignatureError(err error) bool { + if err == nil { + return false + } + msg := err.Error() + return strings.Contains(msg, "verification error") || strings.Contains(msg, "token signature is invalid") +} + +func disableRefreshOnSignatureError() bool { + val := strings.TrimSpace(os.Getenv("SSO_DISABLE_JWKS_REFRESH_ON_SIG_ERROR")) + if val == "" { + return false + } + return val == "1" || strings.EqualFold(val, "true") || strings.EqualFold(val, "yes") +} + +func ExtractTokenInfo(token string) TokenInfo { + token = strings.TrimSpace(token) + if token == "" { + return TokenInfo{} + } + + claims := jwt.MapClaims{} + parser := jwt.NewParser() + tok, _, err := parser.ParseUnverified(token, claims) + if err != nil { + return TokenInfo{} + } + + info := TokenInfo{} + if kid, ok := tok.Header["kid"].(string); ok { + info.Kid = kid + } + if iss, ok := claims["iss"].(string); ok { + info.Iss = iss + } + if sub, ok := claims["sub"].(string); ok { + info.Sub = sub + } + if aud, ok := claims["aud"]; ok { + info.Aud = toStringSlice(aud) + } + info.Exp = toInt64(claims["exp"]) + info.Iat = toInt64(claims["iat"]) + info.Nbf = toInt64(claims["nbf"]) + return info +} + +func toStringSlice(v any) []string { + switch t := v.(type) { + case string: + if t == "" { + return nil + } + return []string{t} + case []string: + out := make([]string, 0, len(t)) + for _, s := range t { + if s != "" { + out = append(out, s) + } + } + return out + case []any: + out := make([]string, 0, len(t)) + for _, item := range t { + if s, ok := item.(string); ok && s != "" { + out = append(out, s) + } + } + return out + default: + return nil + } +} + +func toInt64(v any) int64 { + switch t := v.(type) { + case int64: + return t + case int: + return int64(t) + case float64: + return int64(t) + case json.Number: + if n, err := t.Int64(); err == nil { + return n + } + case string: + if n, err := strconv.ParseInt(t, 10, 64); err == nil { + return n + } + } + return 0 +}