mirror of
https://gitlab.com/mbugroup/lti-api.git
synced 2026-05-23 14:55:42 +00:00
[FEAT/BE] resolve jwks
This commit is contained in:
@@ -57,6 +57,7 @@ var (
|
|||||||
SSOPortalURL string
|
SSOPortalURL string
|
||||||
SSOClients map[string]SSOClientConfig
|
SSOClients map[string]SSOClientConfig
|
||||||
SSOAccessCookieName string
|
SSOAccessCookieName string
|
||||||
|
SSOAccessCookieFallback []string
|
||||||
SSORefreshCookieName string
|
SSORefreshCookieName string
|
||||||
SSOCookieDomain string
|
SSOCookieDomain string
|
||||||
SSOCookieSecure bool
|
SSOCookieSecure bool
|
||||||
@@ -141,6 +142,7 @@ func init() {
|
|||||||
SSOGetMeURL = viper.GetString("SSO_GETME_URL")
|
SSOGetMeURL = viper.GetString("SSO_GETME_URL")
|
||||||
SSOPortalURL = strings.TrimSpace(viper.GetString("SSO_PORTAL_URL"))
|
SSOPortalURL = strings.TrimSpace(viper.GetString("SSO_PORTAL_URL"))
|
||||||
SSOAccessCookieName = defaultString(viper.GetString("SSO_ACCESS_COOKIE_NAME"), "sso_access")
|
SSOAccessCookieName = defaultString(viper.GetString("SSO_ACCESS_COOKIE_NAME"), "sso_access")
|
||||||
|
SSOAccessCookieFallback = parseList("SSO_ACCESS_COOKIE_FALLBACK")
|
||||||
SSORefreshCookieName = defaultString(viper.GetString("SSO_REFRESH_COOKIE_NAME"), "sso_refresh")
|
SSORefreshCookieName = defaultString(viper.GetString("SSO_REFRESH_COOKIE_NAME"), "sso_refresh")
|
||||||
SSOCookieDomain = viper.GetString("SSO_COOKIE_DOMAIN")
|
SSOCookieDomain = viper.GetString("SSO_COOKIE_DOMAIN")
|
||||||
SSOCookieSecure = viper.GetBool("SSO_COOKIE_SECURE")
|
SSOCookieSecure = viper.GetBool("SSO_COOKIE_SECURE")
|
||||||
|
|||||||
+59
-13
@@ -19,11 +19,11 @@ const (
|
|||||||
|
|
||||||
// AuthContext keeps authentication details captured by the middleware.
|
// AuthContext keeps authentication details captured by the middleware.
|
||||||
type AuthContext struct {
|
type AuthContext struct {
|
||||||
Token string
|
Token string
|
||||||
Verification *sso.VerificationResult
|
Verification *sso.VerificationResult
|
||||||
User *entity.User
|
User *entity.User
|
||||||
Roles []sso.Role
|
Roles []sso.Role
|
||||||
Permissions map[string]struct{}
|
Permissions map[string]struct{}
|
||||||
UserAreaIDs []uint
|
UserAreaIDs []uint
|
||||||
UserLocationIDs []uint
|
UserLocationIDs []uint
|
||||||
UserAllArea bool
|
UserAllArea bool
|
||||||
@@ -36,8 +36,30 @@ type AuthContext struct {
|
|||||||
func Auth(userService service.UserService, requiredScopes ...string) fiber.Handler {
|
func Auth(userService service.UserService, requiredScopes ...string) fiber.Handler {
|
||||||
return func(c *fiber.Ctx) error {
|
return func(c *fiber.Ctx) error {
|
||||||
token := bearerToken(c)
|
token := bearerToken(c)
|
||||||
if token == "" {
|
tokenSource := ""
|
||||||
token = strings.TrimSpace(c.Cookies(config.SSOAccessCookieName))
|
if token != "" {
|
||||||
|
tokenSource = "header"
|
||||||
|
} else {
|
||||||
|
primaryName := strings.TrimSpace(config.SSOAccessCookieName)
|
||||||
|
if primaryName != "" {
|
||||||
|
token = strings.TrimSpace(c.Cookies(primaryName))
|
||||||
|
if token != "" {
|
||||||
|
tokenSource = "cookie:" + primaryName
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if token == "" {
|
||||||
|
for _, name := range config.SSOAccessCookieFallback {
|
||||||
|
name = strings.TrimSpace(name)
|
||||||
|
if name == "" || name == primaryName {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
token = strings.TrimSpace(c.Cookies(name))
|
||||||
|
if token != "" {
|
||||||
|
tokenSource = "cookie:" + name
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
if token == "" {
|
if token == "" {
|
||||||
return fiber.NewError(fiber.StatusUnauthorized, "Please authenticate")
|
return fiber.NewError(fiber.StatusUnauthorized, "Please authenticate")
|
||||||
@@ -45,7 +67,11 @@ func Auth(userService service.UserService, requiredScopes ...string) fiber.Handl
|
|||||||
|
|
||||||
verification, err := sso.VerifyAccessToken(token)
|
verification, err := sso.VerifyAccessToken(token)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
utils.Log.WithError(err).Warn("auth: token verification failed")
|
if sso.IsSignatureError(err) {
|
||||||
|
logSignatureError("auth", tokenSource, token, err)
|
||||||
|
} else {
|
||||||
|
utils.Log.WithError(err).Warn("auth: token verification failed")
|
||||||
|
}
|
||||||
return fiber.NewError(fiber.StatusUnauthorized, "Please authenticate")
|
return fiber.NewError(fiber.StatusUnauthorized, "Please authenticate")
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -89,11 +115,11 @@ func Auth(userService service.UserService, requiredScopes ...string) fiber.Handl
|
|||||||
}
|
}
|
||||||
|
|
||||||
ctx := &AuthContext{
|
ctx := &AuthContext{
|
||||||
Token: token,
|
Token: token,
|
||||||
Verification: verification,
|
Verification: verification,
|
||||||
User: user,
|
User: user,
|
||||||
Roles: roles,
|
Roles: roles,
|
||||||
Permissions: permissions,
|
Permissions: permissions,
|
||||||
UserAreaIDs: nil,
|
UserAreaIDs: nil,
|
||||||
UserLocationIDs: nil,
|
UserLocationIDs: nil,
|
||||||
UserAllArea: false,
|
UserAllArea: false,
|
||||||
@@ -216,6 +242,26 @@ func hasAllScopes(have, required []string) bool {
|
|||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func logSignatureError(ctxLabel, tokenSource, token string, err error) {
|
||||||
|
info := sso.ExtractTokenInfo(token)
|
||||||
|
aud := strings.Join(info.Aud, ",")
|
||||||
|
utils.Log.Errorf(
|
||||||
|
"access token verification failed: %v | ctx=%s source=%s iss=%s kid=%s aud=%s sub=%s exp=%d iat=%d nbf=%d expected_iss=%s expected_aud=%v",
|
||||||
|
err,
|
||||||
|
ctxLabel,
|
||||||
|
tokenSource,
|
||||||
|
info.Iss,
|
||||||
|
info.Kid,
|
||||||
|
aud,
|
||||||
|
info.Sub,
|
||||||
|
info.Exp,
|
||||||
|
info.Iat,
|
||||||
|
info.Nbf,
|
||||||
|
config.SSOIssuer,
|
||||||
|
config.SSOAllowedAudiences,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
// RequirePermissions ensures the authenticated user possesses all specified permissions.
|
// RequirePermissions ensures the authenticated user possesses all specified permissions.
|
||||||
func RequirePermissions(perms ...string) fiber.Handler {
|
func RequirePermissions(perms ...string) fiber.Handler {
|
||||||
required := canonicalPermissions(perms)
|
required := canonicalPermissions(perms)
|
||||||
|
|||||||
@@ -196,7 +196,11 @@ func (h *Controller) Refresh(c *fiber.Ctx) error {
|
|||||||
|
|
||||||
verification, err := sso.VerifyAccessToken(tokenResp.AccessToken)
|
verification, err := sso.VerifyAccessToken(tokenResp.AccessToken)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
utils.Log.Errorf("access token verification failed: %v", err)
|
if sso.IsSignatureError(err) {
|
||||||
|
logSignatureError("sso refresh", "sso_token", tokenResp.AccessToken, err)
|
||||||
|
} else {
|
||||||
|
utils.Log.Errorf("access token verification failed: %v", err)
|
||||||
|
}
|
||||||
return fiber.NewError(fiber.StatusUnauthorized, "invalid access token")
|
return fiber.NewError(fiber.StatusUnauthorized, "invalid access token")
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -304,7 +308,11 @@ func (h *Controller) Callback(c *fiber.Ctx) error {
|
|||||||
|
|
||||||
verification, err := sso.VerifyAccessToken(tokenResp.AccessToken)
|
verification, err := sso.VerifyAccessToken(tokenResp.AccessToken)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
utils.Log.Errorf("access token verification failed: %v", err)
|
if sso.IsSignatureError(err) {
|
||||||
|
logSignatureError("sso callback", "sso_token", tokenResp.AccessToken, err)
|
||||||
|
} else {
|
||||||
|
utils.Log.Errorf("access token verification failed: %v", err)
|
||||||
|
}
|
||||||
return fiber.NewError(fiber.StatusUnauthorized, "invalid access token")
|
return fiber.NewError(fiber.StatusUnauthorized, "invalid access token")
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -337,6 +345,22 @@ func (h *Controller) UserInfo(c *fiber.Ctx) error {
|
|||||||
|
|
||||||
token := strings.TrimSpace(c.Cookies(accessName))
|
token := strings.TrimSpace(c.Cookies(accessName))
|
||||||
tokenFromCookie := token != ""
|
tokenFromCookie := token != ""
|
||||||
|
usedCookieName := accessName
|
||||||
|
|
||||||
|
if !tokenFromCookie {
|
||||||
|
for _, name := range config.SSOAccessCookieFallback {
|
||||||
|
name = strings.TrimSpace(name)
|
||||||
|
if name == "" || name == accessName {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
token = strings.TrimSpace(c.Cookies(name))
|
||||||
|
if token != "" {
|
||||||
|
tokenFromCookie = true
|
||||||
|
usedCookieName = name
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
if !tokenFromCookie {
|
if !tokenFromCookie {
|
||||||
authHeader := strings.TrimSpace(c.Get("Authorization"))
|
authHeader := strings.TrimSpace(c.Get("Authorization"))
|
||||||
@@ -363,7 +387,11 @@ func (h *Controller) UserInfo(c *fiber.Ctx) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if _, err := sso.VerifyAccessToken(token); err != nil {
|
if _, err := sso.VerifyAccessToken(token); err != nil {
|
||||||
utils.Log.WithError(err).Warn("access token verification failed for userinfo")
|
if sso.IsSignatureError(err) {
|
||||||
|
logSignatureError("sso userinfo", "request", token, err)
|
||||||
|
} else {
|
||||||
|
utils.Log.WithError(err).Warn("access token verification failed for userinfo")
|
||||||
|
}
|
||||||
return fiber.NewError(fiber.StatusUnauthorized, "unauthenticated")
|
return fiber.NewError(fiber.StatusUnauthorized, "unauthenticated")
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -382,7 +410,7 @@ func (h *Controller) UserInfo(c *fiber.Ctx) error {
|
|||||||
// SSO /auth/get-me expects the access cookie; add Authorization as well for compatibility.
|
// SSO /auth/get-me expects the access cookie; add Authorization as well for compatibility.
|
||||||
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token))
|
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token))
|
||||||
if tokenFromCookie {
|
if tokenFromCookie {
|
||||||
req.Header.Set("Cookie", fmt.Sprintf("%s=%s", accessName, token))
|
req.Header.Set("Cookie", fmt.Sprintf("%s=%s", usedCookieName, token))
|
||||||
}
|
}
|
||||||
|
|
||||||
resp, err := h.httpClient.Do(req)
|
resp, err := h.httpClient.Do(req)
|
||||||
@@ -836,6 +864,27 @@ func resolveSSOCookieName(configuredName, fallback string) string {
|
|||||||
return strings.TrimSpace(fallback)
|
return strings.TrimSpace(fallback)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func logSignatureError(ctxLabel, tokenSource, token string, err error) {
|
||||||
|
info := sso.ExtractTokenInfo(token)
|
||||||
|
aud := strings.Join(info.Aud, ",")
|
||||||
|
utils.Log.Errorf(
|
||||||
|
"access token verification failed: %v | ctx=%s source=%s iss=%s kid=%s aud=%s sub=%s exp=%d iat=%d nbf=%d expected_iss=%s expected_aud=%v jwks=%s",
|
||||||
|
err,
|
||||||
|
ctxLabel,
|
||||||
|
tokenSource,
|
||||||
|
info.Iss,
|
||||||
|
info.Kid,
|
||||||
|
aud,
|
||||||
|
info.Sub,
|
||||||
|
info.Exp,
|
||||||
|
info.Iat,
|
||||||
|
info.Nbf,
|
||||||
|
config.SSOIssuer,
|
||||||
|
config.SSOAllowedAudiences,
|
||||||
|
config.SSOJWKSURL,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
func normalizeClientParam(raw string) string {
|
func normalizeClientParam(raw string) string {
|
||||||
value := strings.TrimSpace(raw)
|
value := strings.TrimSpace(raw)
|
||||||
if value == "" {
|
if value == "" {
|
||||||
|
|||||||
@@ -2,9 +2,11 @@ package sso
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"os"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
@@ -41,6 +43,16 @@ type VerificationResult struct {
|
|||||||
Claims *AccessTokenClaims
|
Claims *AccessTokenClaims
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type TokenInfo struct {
|
||||||
|
Kid string
|
||||||
|
Iss string
|
||||||
|
Aud []string
|
||||||
|
Sub string
|
||||||
|
Exp int64
|
||||||
|
Iat int64
|
||||||
|
Nbf int64
|
||||||
|
}
|
||||||
|
|
||||||
var (
|
var (
|
||||||
globalMu sync.RWMutex
|
globalMu sync.RWMutex
|
||||||
globalV *verifier
|
globalV *verifier
|
||||||
@@ -106,10 +118,19 @@ func VerifyAccessToken(token string) (*VerificationResult, error) {
|
|||||||
jwt.WithIssuedAt(),
|
jwt.WithIssuedAt(),
|
||||||
jwt.WithExpirationRequired(),
|
jwt.WithExpirationRequired(),
|
||||||
)
|
)
|
||||||
|
|
||||||
tok, err := parser.ParseWithClaims(token, claims, v.jwks.Keyfunc)
|
tok, err := parser.ParseWithClaims(token, claims, v.jwks.Keyfunc)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("parse token: %w", err)
|
if shouldRefreshOnVerifyError(err) {
|
||||||
|
if refreshErr := v.jwks.Refresh(context.Background(), keyfunc.RefreshOptions{IgnoreRateLimit: true}); refreshErr != nil {
|
||||||
|
utils.Log.WithError(refreshErr).Warn("sso jwks refresh after signature error failed")
|
||||||
|
} else {
|
||||||
|
tok, err = parser.ParseWithClaims(token, claims, v.jwks.Keyfunc)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("parse token: %w", err)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
if !tok.Valid {
|
if !tok.Valid {
|
||||||
return nil, errors.New("invalid token")
|
return nil, errors.New("invalid token")
|
||||||
@@ -158,3 +179,106 @@ func VerifyAccessToken(token string) (*VerificationResult, error) {
|
|||||||
|
|
||||||
return result, nil
|
return result, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func shouldRefreshOnVerifyError(err error) bool {
|
||||||
|
if !IsSignatureError(err) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return !disableRefreshOnSignatureError()
|
||||||
|
}
|
||||||
|
|
||||||
|
func IsSignatureError(err error) bool {
|
||||||
|
if err == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
msg := err.Error()
|
||||||
|
return strings.Contains(msg, "verification error") || strings.Contains(msg, "token signature is invalid")
|
||||||
|
}
|
||||||
|
|
||||||
|
func disableRefreshOnSignatureError() bool {
|
||||||
|
val := strings.TrimSpace(os.Getenv("SSO_DISABLE_JWKS_REFRESH_ON_SIG_ERROR"))
|
||||||
|
if val == "" {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return val == "1" || strings.EqualFold(val, "true") || strings.EqualFold(val, "yes")
|
||||||
|
}
|
||||||
|
|
||||||
|
func ExtractTokenInfo(token string) TokenInfo {
|
||||||
|
token = strings.TrimSpace(token)
|
||||||
|
if token == "" {
|
||||||
|
return TokenInfo{}
|
||||||
|
}
|
||||||
|
|
||||||
|
claims := jwt.MapClaims{}
|
||||||
|
parser := jwt.NewParser()
|
||||||
|
tok, _, err := parser.ParseUnverified(token, claims)
|
||||||
|
if err != nil {
|
||||||
|
return TokenInfo{}
|
||||||
|
}
|
||||||
|
|
||||||
|
info := TokenInfo{}
|
||||||
|
if kid, ok := tok.Header["kid"].(string); ok {
|
||||||
|
info.Kid = kid
|
||||||
|
}
|
||||||
|
if iss, ok := claims["iss"].(string); ok {
|
||||||
|
info.Iss = iss
|
||||||
|
}
|
||||||
|
if sub, ok := claims["sub"].(string); ok {
|
||||||
|
info.Sub = sub
|
||||||
|
}
|
||||||
|
if aud, ok := claims["aud"]; ok {
|
||||||
|
info.Aud = toStringSlice(aud)
|
||||||
|
}
|
||||||
|
info.Exp = toInt64(claims["exp"])
|
||||||
|
info.Iat = toInt64(claims["iat"])
|
||||||
|
info.Nbf = toInt64(claims["nbf"])
|
||||||
|
return info
|
||||||
|
}
|
||||||
|
|
||||||
|
func toStringSlice(v any) []string {
|
||||||
|
switch t := v.(type) {
|
||||||
|
case string:
|
||||||
|
if t == "" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return []string{t}
|
||||||
|
case []string:
|
||||||
|
out := make([]string, 0, len(t))
|
||||||
|
for _, s := range t {
|
||||||
|
if s != "" {
|
||||||
|
out = append(out, s)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return out
|
||||||
|
case []any:
|
||||||
|
out := make([]string, 0, len(t))
|
||||||
|
for _, item := range t {
|
||||||
|
if s, ok := item.(string); ok && s != "" {
|
||||||
|
out = append(out, s)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return out
|
||||||
|
default:
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func toInt64(v any) int64 {
|
||||||
|
switch t := v.(type) {
|
||||||
|
case int64:
|
||||||
|
return t
|
||||||
|
case int:
|
||||||
|
return int64(t)
|
||||||
|
case float64:
|
||||||
|
return int64(t)
|
||||||
|
case json.Number:
|
||||||
|
if n, err := t.Int64(); err == nil {
|
||||||
|
return n
|
||||||
|
}
|
||||||
|
case string:
|
||||||
|
if n, err := strconv.ParseInt(t, 10, 64); err == nil {
|
||||||
|
return n
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user