[FEAT/BE] resolve jwks

This commit is contained in:
ragilap
2026-02-24 15:16:09 +07:00
parent 5fb7a78a5a
commit f6f4cc5a10
4 changed files with 240 additions and 19 deletions
@@ -196,7 +196,11 @@ func (h *Controller) Refresh(c *fiber.Ctx) error {
verification, err := sso.VerifyAccessToken(tokenResp.AccessToken)
if err != nil {
utils.Log.Errorf("access token verification failed: %v", err)
if sso.IsSignatureError(err) {
logSignatureError("sso refresh", "sso_token", tokenResp.AccessToken, err)
} else {
utils.Log.Errorf("access token verification failed: %v", err)
}
return fiber.NewError(fiber.StatusUnauthorized, "invalid access token")
}
@@ -304,7 +308,11 @@ func (h *Controller) Callback(c *fiber.Ctx) error {
verification, err := sso.VerifyAccessToken(tokenResp.AccessToken)
if err != nil {
utils.Log.Errorf("access token verification failed: %v", err)
if sso.IsSignatureError(err) {
logSignatureError("sso callback", "sso_token", tokenResp.AccessToken, err)
} else {
utils.Log.Errorf("access token verification failed: %v", err)
}
return fiber.NewError(fiber.StatusUnauthorized, "invalid access token")
}
@@ -337,6 +345,22 @@ func (h *Controller) UserInfo(c *fiber.Ctx) error {
token := strings.TrimSpace(c.Cookies(accessName))
tokenFromCookie := token != ""
usedCookieName := accessName
if !tokenFromCookie {
for _, name := range config.SSOAccessCookieFallback {
name = strings.TrimSpace(name)
if name == "" || name == accessName {
continue
}
token = strings.TrimSpace(c.Cookies(name))
if token != "" {
tokenFromCookie = true
usedCookieName = name
break
}
}
}
if !tokenFromCookie {
authHeader := strings.TrimSpace(c.Get("Authorization"))
@@ -363,7 +387,11 @@ func (h *Controller) UserInfo(c *fiber.Ctx) error {
}
if _, err := sso.VerifyAccessToken(token); err != nil {
utils.Log.WithError(err).Warn("access token verification failed for userinfo")
if sso.IsSignatureError(err) {
logSignatureError("sso userinfo", "request", token, err)
} else {
utils.Log.WithError(err).Warn("access token verification failed for userinfo")
}
return fiber.NewError(fiber.StatusUnauthorized, "unauthenticated")
}
@@ -382,7 +410,7 @@ func (h *Controller) UserInfo(c *fiber.Ctx) error {
// SSO /auth/get-me expects the access cookie; add Authorization as well for compatibility.
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token))
if tokenFromCookie {
req.Header.Set("Cookie", fmt.Sprintf("%s=%s", accessName, token))
req.Header.Set("Cookie", fmt.Sprintf("%s=%s", usedCookieName, token))
}
resp, err := h.httpClient.Do(req)
@@ -836,6 +864,27 @@ func resolveSSOCookieName(configuredName, fallback string) string {
return strings.TrimSpace(fallback)
}
func logSignatureError(ctxLabel, tokenSource, token string, err error) {
info := sso.ExtractTokenInfo(token)
aud := strings.Join(info.Aud, ",")
utils.Log.Errorf(
"access token verification failed: %v | ctx=%s source=%s iss=%s kid=%s aud=%s sub=%s exp=%d iat=%d nbf=%d expected_iss=%s expected_aud=%v jwks=%s",
err,
ctxLabel,
tokenSource,
info.Iss,
info.Kid,
aud,
info.Sub,
info.Exp,
info.Iat,
info.Nbf,
config.SSOIssuer,
config.SSOAllowedAudiences,
config.SSOJWKSURL,
)
}
func normalizeClientParam(raw string) string {
value := strings.TrimSpace(raw)
if value == "" {
+126 -2
View File
@@ -2,9 +2,11 @@ package sso
import (
"context"
"encoding/json"
"errors"
"fmt"
"net/http"
"os"
"strconv"
"strings"
"sync"
@@ -41,6 +43,16 @@ type VerificationResult struct {
Claims *AccessTokenClaims
}
type TokenInfo struct {
Kid string
Iss string
Aud []string
Sub string
Exp int64
Iat int64
Nbf int64
}
var (
globalMu sync.RWMutex
globalV *verifier
@@ -106,10 +118,19 @@ func VerifyAccessToken(token string) (*VerificationResult, error) {
jwt.WithIssuedAt(),
jwt.WithExpirationRequired(),
)
tok, err := parser.ParseWithClaims(token, claims, v.jwks.Keyfunc)
if err != nil {
return nil, fmt.Errorf("parse token: %w", err)
if shouldRefreshOnVerifyError(err) {
if refreshErr := v.jwks.Refresh(context.Background(), keyfunc.RefreshOptions{IgnoreRateLimit: true}); refreshErr != nil {
utils.Log.WithError(refreshErr).Warn("sso jwks refresh after signature error failed")
} else {
tok, err = parser.ParseWithClaims(token, claims, v.jwks.Keyfunc)
}
}
if err != nil {
return nil, fmt.Errorf("parse token: %w", err)
}
}
if !tok.Valid {
return nil, errors.New("invalid token")
@@ -158,3 +179,106 @@ func VerifyAccessToken(token string) (*VerificationResult, error) {
return result, nil
}
func shouldRefreshOnVerifyError(err error) bool {
if !IsSignatureError(err) {
return false
}
return !disableRefreshOnSignatureError()
}
func IsSignatureError(err error) bool {
if err == nil {
return false
}
msg := err.Error()
return strings.Contains(msg, "verification error") || strings.Contains(msg, "token signature is invalid")
}
func disableRefreshOnSignatureError() bool {
val := strings.TrimSpace(os.Getenv("SSO_DISABLE_JWKS_REFRESH_ON_SIG_ERROR"))
if val == "" {
return false
}
return val == "1" || strings.EqualFold(val, "true") || strings.EqualFold(val, "yes")
}
func ExtractTokenInfo(token string) TokenInfo {
token = strings.TrimSpace(token)
if token == "" {
return TokenInfo{}
}
claims := jwt.MapClaims{}
parser := jwt.NewParser()
tok, _, err := parser.ParseUnverified(token, claims)
if err != nil {
return TokenInfo{}
}
info := TokenInfo{}
if kid, ok := tok.Header["kid"].(string); ok {
info.Kid = kid
}
if iss, ok := claims["iss"].(string); ok {
info.Iss = iss
}
if sub, ok := claims["sub"].(string); ok {
info.Sub = sub
}
if aud, ok := claims["aud"]; ok {
info.Aud = toStringSlice(aud)
}
info.Exp = toInt64(claims["exp"])
info.Iat = toInt64(claims["iat"])
info.Nbf = toInt64(claims["nbf"])
return info
}
func toStringSlice(v any) []string {
switch t := v.(type) {
case string:
if t == "" {
return nil
}
return []string{t}
case []string:
out := make([]string, 0, len(t))
for _, s := range t {
if s != "" {
out = append(out, s)
}
}
return out
case []any:
out := make([]string, 0, len(t))
for _, item := range t {
if s, ok := item.(string); ok && s != "" {
out = append(out, s)
}
}
return out
default:
return nil
}
}
func toInt64(v any) int64 {
switch t := v.(type) {
case int64:
return t
case int:
return int64(t)
case float64:
return int64(t)
case json.Number:
if n, err := t.Int64(); err == nil {
return n
}
case string:
if n, err := strconv.ParseInt(t, 10, 64); err == nil {
return n
}
}
return 0
}