mirror of
https://gitlab.com/mbugroup/lti-api.git
synced 2026-05-20 05:21:57 +00:00
[FEAT/BE] resolve jwks
This commit is contained in:
@@ -57,6 +57,7 @@ var (
|
||||
SSOPortalURL string
|
||||
SSOClients map[string]SSOClientConfig
|
||||
SSOAccessCookieName string
|
||||
SSOAccessCookieFallback []string
|
||||
SSORefreshCookieName string
|
||||
SSOCookieDomain string
|
||||
SSOCookieSecure bool
|
||||
@@ -141,6 +142,7 @@ func init() {
|
||||
SSOGetMeURL = viper.GetString("SSO_GETME_URL")
|
||||
SSOPortalURL = strings.TrimSpace(viper.GetString("SSO_PORTAL_URL"))
|
||||
SSOAccessCookieName = defaultString(viper.GetString("SSO_ACCESS_COOKIE_NAME"), "sso_access")
|
||||
SSOAccessCookieFallback = parseList("SSO_ACCESS_COOKIE_FALLBACK")
|
||||
SSORefreshCookieName = defaultString(viper.GetString("SSO_REFRESH_COOKIE_NAME"), "sso_refresh")
|
||||
SSOCookieDomain = viper.GetString("SSO_COOKIE_DOMAIN")
|
||||
SSOCookieSecure = viper.GetBool("SSO_COOKIE_SECURE")
|
||||
|
||||
+59
-13
@@ -19,11 +19,11 @@ const (
|
||||
|
||||
// AuthContext keeps authentication details captured by the middleware.
|
||||
type AuthContext struct {
|
||||
Token string
|
||||
Verification *sso.VerificationResult
|
||||
User *entity.User
|
||||
Roles []sso.Role
|
||||
Permissions map[string]struct{}
|
||||
Token string
|
||||
Verification *sso.VerificationResult
|
||||
User *entity.User
|
||||
Roles []sso.Role
|
||||
Permissions map[string]struct{}
|
||||
UserAreaIDs []uint
|
||||
UserLocationIDs []uint
|
||||
UserAllArea bool
|
||||
@@ -36,8 +36,30 @@ type AuthContext struct {
|
||||
func Auth(userService service.UserService, requiredScopes ...string) fiber.Handler {
|
||||
return func(c *fiber.Ctx) error {
|
||||
token := bearerToken(c)
|
||||
if token == "" {
|
||||
token = strings.TrimSpace(c.Cookies(config.SSOAccessCookieName))
|
||||
tokenSource := ""
|
||||
if token != "" {
|
||||
tokenSource = "header"
|
||||
} else {
|
||||
primaryName := strings.TrimSpace(config.SSOAccessCookieName)
|
||||
if primaryName != "" {
|
||||
token = strings.TrimSpace(c.Cookies(primaryName))
|
||||
if token != "" {
|
||||
tokenSource = "cookie:" + primaryName
|
||||
}
|
||||
}
|
||||
if token == "" {
|
||||
for _, name := range config.SSOAccessCookieFallback {
|
||||
name = strings.TrimSpace(name)
|
||||
if name == "" || name == primaryName {
|
||||
continue
|
||||
}
|
||||
token = strings.TrimSpace(c.Cookies(name))
|
||||
if token != "" {
|
||||
tokenSource = "cookie:" + name
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
if token == "" {
|
||||
return fiber.NewError(fiber.StatusUnauthorized, "Please authenticate")
|
||||
@@ -45,7 +67,11 @@ func Auth(userService service.UserService, requiredScopes ...string) fiber.Handl
|
||||
|
||||
verification, err := sso.VerifyAccessToken(token)
|
||||
if err != nil {
|
||||
utils.Log.WithError(err).Warn("auth: token verification failed")
|
||||
if sso.IsSignatureError(err) {
|
||||
logSignatureError("auth", tokenSource, token, err)
|
||||
} else {
|
||||
utils.Log.WithError(err).Warn("auth: token verification failed")
|
||||
}
|
||||
return fiber.NewError(fiber.StatusUnauthorized, "Please authenticate")
|
||||
}
|
||||
|
||||
@@ -89,11 +115,11 @@ func Auth(userService service.UserService, requiredScopes ...string) fiber.Handl
|
||||
}
|
||||
|
||||
ctx := &AuthContext{
|
||||
Token: token,
|
||||
Verification: verification,
|
||||
User: user,
|
||||
Roles: roles,
|
||||
Permissions: permissions,
|
||||
Token: token,
|
||||
Verification: verification,
|
||||
User: user,
|
||||
Roles: roles,
|
||||
Permissions: permissions,
|
||||
UserAreaIDs: nil,
|
||||
UserLocationIDs: nil,
|
||||
UserAllArea: false,
|
||||
@@ -216,6 +242,26 @@ func hasAllScopes(have, required []string) bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func logSignatureError(ctxLabel, tokenSource, token string, err error) {
|
||||
info := sso.ExtractTokenInfo(token)
|
||||
aud := strings.Join(info.Aud, ",")
|
||||
utils.Log.Errorf(
|
||||
"access token verification failed: %v | ctx=%s source=%s iss=%s kid=%s aud=%s sub=%s exp=%d iat=%d nbf=%d expected_iss=%s expected_aud=%v",
|
||||
err,
|
||||
ctxLabel,
|
||||
tokenSource,
|
||||
info.Iss,
|
||||
info.Kid,
|
||||
aud,
|
||||
info.Sub,
|
||||
info.Exp,
|
||||
info.Iat,
|
||||
info.Nbf,
|
||||
config.SSOIssuer,
|
||||
config.SSOAllowedAudiences,
|
||||
)
|
||||
}
|
||||
|
||||
// RequirePermissions ensures the authenticated user possesses all specified permissions.
|
||||
func RequirePermissions(perms ...string) fiber.Handler {
|
||||
required := canonicalPermissions(perms)
|
||||
|
||||
@@ -196,7 +196,11 @@ func (h *Controller) Refresh(c *fiber.Ctx) error {
|
||||
|
||||
verification, err := sso.VerifyAccessToken(tokenResp.AccessToken)
|
||||
if err != nil {
|
||||
utils.Log.Errorf("access token verification failed: %v", err)
|
||||
if sso.IsSignatureError(err) {
|
||||
logSignatureError("sso refresh", "sso_token", tokenResp.AccessToken, err)
|
||||
} else {
|
||||
utils.Log.Errorf("access token verification failed: %v", err)
|
||||
}
|
||||
return fiber.NewError(fiber.StatusUnauthorized, "invalid access token")
|
||||
}
|
||||
|
||||
@@ -304,7 +308,11 @@ func (h *Controller) Callback(c *fiber.Ctx) error {
|
||||
|
||||
verification, err := sso.VerifyAccessToken(tokenResp.AccessToken)
|
||||
if err != nil {
|
||||
utils.Log.Errorf("access token verification failed: %v", err)
|
||||
if sso.IsSignatureError(err) {
|
||||
logSignatureError("sso callback", "sso_token", tokenResp.AccessToken, err)
|
||||
} else {
|
||||
utils.Log.Errorf("access token verification failed: %v", err)
|
||||
}
|
||||
return fiber.NewError(fiber.StatusUnauthorized, "invalid access token")
|
||||
}
|
||||
|
||||
@@ -337,6 +345,22 @@ func (h *Controller) UserInfo(c *fiber.Ctx) error {
|
||||
|
||||
token := strings.TrimSpace(c.Cookies(accessName))
|
||||
tokenFromCookie := token != ""
|
||||
usedCookieName := accessName
|
||||
|
||||
if !tokenFromCookie {
|
||||
for _, name := range config.SSOAccessCookieFallback {
|
||||
name = strings.TrimSpace(name)
|
||||
if name == "" || name == accessName {
|
||||
continue
|
||||
}
|
||||
token = strings.TrimSpace(c.Cookies(name))
|
||||
if token != "" {
|
||||
tokenFromCookie = true
|
||||
usedCookieName = name
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if !tokenFromCookie {
|
||||
authHeader := strings.TrimSpace(c.Get("Authorization"))
|
||||
@@ -363,7 +387,11 @@ func (h *Controller) UserInfo(c *fiber.Ctx) error {
|
||||
}
|
||||
|
||||
if _, err := sso.VerifyAccessToken(token); err != nil {
|
||||
utils.Log.WithError(err).Warn("access token verification failed for userinfo")
|
||||
if sso.IsSignatureError(err) {
|
||||
logSignatureError("sso userinfo", "request", token, err)
|
||||
} else {
|
||||
utils.Log.WithError(err).Warn("access token verification failed for userinfo")
|
||||
}
|
||||
return fiber.NewError(fiber.StatusUnauthorized, "unauthenticated")
|
||||
}
|
||||
|
||||
@@ -382,7 +410,7 @@ func (h *Controller) UserInfo(c *fiber.Ctx) error {
|
||||
// SSO /auth/get-me expects the access cookie; add Authorization as well for compatibility.
|
||||
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token))
|
||||
if tokenFromCookie {
|
||||
req.Header.Set("Cookie", fmt.Sprintf("%s=%s", accessName, token))
|
||||
req.Header.Set("Cookie", fmt.Sprintf("%s=%s", usedCookieName, token))
|
||||
}
|
||||
|
||||
resp, err := h.httpClient.Do(req)
|
||||
@@ -836,6 +864,27 @@ func resolveSSOCookieName(configuredName, fallback string) string {
|
||||
return strings.TrimSpace(fallback)
|
||||
}
|
||||
|
||||
func logSignatureError(ctxLabel, tokenSource, token string, err error) {
|
||||
info := sso.ExtractTokenInfo(token)
|
||||
aud := strings.Join(info.Aud, ",")
|
||||
utils.Log.Errorf(
|
||||
"access token verification failed: %v | ctx=%s source=%s iss=%s kid=%s aud=%s sub=%s exp=%d iat=%d nbf=%d expected_iss=%s expected_aud=%v jwks=%s",
|
||||
err,
|
||||
ctxLabel,
|
||||
tokenSource,
|
||||
info.Iss,
|
||||
info.Kid,
|
||||
aud,
|
||||
info.Sub,
|
||||
info.Exp,
|
||||
info.Iat,
|
||||
info.Nbf,
|
||||
config.SSOIssuer,
|
||||
config.SSOAllowedAudiences,
|
||||
config.SSOJWKSURL,
|
||||
)
|
||||
}
|
||||
|
||||
func normalizeClientParam(raw string) string {
|
||||
value := strings.TrimSpace(raw)
|
||||
if value == "" {
|
||||
|
||||
@@ -2,9 +2,11 @@ package sso
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"os"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
@@ -41,6 +43,16 @@ type VerificationResult struct {
|
||||
Claims *AccessTokenClaims
|
||||
}
|
||||
|
||||
type TokenInfo struct {
|
||||
Kid string
|
||||
Iss string
|
||||
Aud []string
|
||||
Sub string
|
||||
Exp int64
|
||||
Iat int64
|
||||
Nbf int64
|
||||
}
|
||||
|
||||
var (
|
||||
globalMu sync.RWMutex
|
||||
globalV *verifier
|
||||
@@ -106,10 +118,19 @@ func VerifyAccessToken(token string) (*VerificationResult, error) {
|
||||
jwt.WithIssuedAt(),
|
||||
jwt.WithExpirationRequired(),
|
||||
)
|
||||
|
||||
|
||||
tok, err := parser.ParseWithClaims(token, claims, v.jwks.Keyfunc)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("parse token: %w", err)
|
||||
if shouldRefreshOnVerifyError(err) {
|
||||
if refreshErr := v.jwks.Refresh(context.Background(), keyfunc.RefreshOptions{IgnoreRateLimit: true}); refreshErr != nil {
|
||||
utils.Log.WithError(refreshErr).Warn("sso jwks refresh after signature error failed")
|
||||
} else {
|
||||
tok, err = parser.ParseWithClaims(token, claims, v.jwks.Keyfunc)
|
||||
}
|
||||
}
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("parse token: %w", err)
|
||||
}
|
||||
}
|
||||
if !tok.Valid {
|
||||
return nil, errors.New("invalid token")
|
||||
@@ -158,3 +179,106 @@ func VerifyAccessToken(token string) (*VerificationResult, error) {
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func shouldRefreshOnVerifyError(err error) bool {
|
||||
if !IsSignatureError(err) {
|
||||
return false
|
||||
}
|
||||
return !disableRefreshOnSignatureError()
|
||||
}
|
||||
|
||||
func IsSignatureError(err error) bool {
|
||||
if err == nil {
|
||||
return false
|
||||
}
|
||||
msg := err.Error()
|
||||
return strings.Contains(msg, "verification error") || strings.Contains(msg, "token signature is invalid")
|
||||
}
|
||||
|
||||
func disableRefreshOnSignatureError() bool {
|
||||
val := strings.TrimSpace(os.Getenv("SSO_DISABLE_JWKS_REFRESH_ON_SIG_ERROR"))
|
||||
if val == "" {
|
||||
return false
|
||||
}
|
||||
return val == "1" || strings.EqualFold(val, "true") || strings.EqualFold(val, "yes")
|
||||
}
|
||||
|
||||
func ExtractTokenInfo(token string) TokenInfo {
|
||||
token = strings.TrimSpace(token)
|
||||
if token == "" {
|
||||
return TokenInfo{}
|
||||
}
|
||||
|
||||
claims := jwt.MapClaims{}
|
||||
parser := jwt.NewParser()
|
||||
tok, _, err := parser.ParseUnverified(token, claims)
|
||||
if err != nil {
|
||||
return TokenInfo{}
|
||||
}
|
||||
|
||||
info := TokenInfo{}
|
||||
if kid, ok := tok.Header["kid"].(string); ok {
|
||||
info.Kid = kid
|
||||
}
|
||||
if iss, ok := claims["iss"].(string); ok {
|
||||
info.Iss = iss
|
||||
}
|
||||
if sub, ok := claims["sub"].(string); ok {
|
||||
info.Sub = sub
|
||||
}
|
||||
if aud, ok := claims["aud"]; ok {
|
||||
info.Aud = toStringSlice(aud)
|
||||
}
|
||||
info.Exp = toInt64(claims["exp"])
|
||||
info.Iat = toInt64(claims["iat"])
|
||||
info.Nbf = toInt64(claims["nbf"])
|
||||
return info
|
||||
}
|
||||
|
||||
func toStringSlice(v any) []string {
|
||||
switch t := v.(type) {
|
||||
case string:
|
||||
if t == "" {
|
||||
return nil
|
||||
}
|
||||
return []string{t}
|
||||
case []string:
|
||||
out := make([]string, 0, len(t))
|
||||
for _, s := range t {
|
||||
if s != "" {
|
||||
out = append(out, s)
|
||||
}
|
||||
}
|
||||
return out
|
||||
case []any:
|
||||
out := make([]string, 0, len(t))
|
||||
for _, item := range t {
|
||||
if s, ok := item.(string); ok && s != "" {
|
||||
out = append(out, s)
|
||||
}
|
||||
}
|
||||
return out
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func toInt64(v any) int64 {
|
||||
switch t := v.(type) {
|
||||
case int64:
|
||||
return t
|
||||
case int:
|
||||
return int64(t)
|
||||
case float64:
|
||||
return int64(t)
|
||||
case json.Number:
|
||||
if n, err := t.Int64(); err == nil {
|
||||
return n
|
||||
}
|
||||
case string:
|
||||
if n, err := strconv.ParseInt(t, 10, 64); err == nil {
|
||||
return n
|
||||
}
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user