mirror of
https://gitlab.com/mbugroup/lti-api.git
synced 2026-05-20 05:21:57 +00:00
768 lines
23 KiB
Go
768 lines
23 KiB
Go
package controllers
|
|
|
|
import (
|
|
"context"
|
|
"crypto/sha256"
|
|
"encoding/json"
|
|
"fmt"
|
|
"io"
|
|
"net/http"
|
|
"net/url"
|
|
"strings"
|
|
"time"
|
|
|
|
"github.com/gofiber/fiber/v2"
|
|
"github.com/sirupsen/logrus"
|
|
|
|
"gitlab.com/mbugroup/lti-api.git/internal/config"
|
|
"gitlab.com/mbugroup/lti-api.git/internal/modules/sso/session"
|
|
sso "gitlab.com/mbugroup/lti-api.git/internal/modules/sso/verifier"
|
|
"gitlab.com/mbugroup/lti-api.git/internal/utils"
|
|
"gitlab.com/mbugroup/lti-api.git/internal/utils/secure"
|
|
)
|
|
|
|
// Controller manages the SSO start & callback flow using PKCE.
|
|
type Controller struct {
|
|
httpClient *http.Client
|
|
store *session.Store
|
|
revoker *session.RevocationStore
|
|
}
|
|
|
|
func NewController(client *http.Client, store *session.Store, revoker *session.RevocationStore) *Controller {
|
|
return &Controller{httpClient: client, store: store, revoker: revoker}
|
|
}
|
|
|
|
// Start handles GET /sso/start requests and redirects users to the central SSO authorize endpoint.
|
|
func (h *Controller) Start(c *fiber.Ctx) error {
|
|
requestedAlias := normalizeClientParam(c.Query("client"))
|
|
if requestedAlias == "" {
|
|
requestedAlias = normalizeClientParam(c.Query("client_id"))
|
|
}
|
|
if requestedAlias == "" {
|
|
return fiber.NewError(fiber.StatusBadRequest, "missing client")
|
|
}
|
|
|
|
alias, cfg, ok := findSSOClientConfig(requestedAlias)
|
|
if !ok || cfg.PublicID == "" {
|
|
return fiber.NewError(fiber.StatusBadRequest, "unknown client")
|
|
}
|
|
|
|
authorizeEndpoint := strings.TrimSpace(config.SSOAuthorizeURL)
|
|
if authorizeEndpoint == "" {
|
|
return fiber.NewError(fiber.StatusInternalServerError, "authorize endpoint not configured")
|
|
}
|
|
|
|
state, err := secure.RandomString(48)
|
|
if err != nil {
|
|
utils.Log.Errorf("generate state failed: %v", err)
|
|
return fiber.NewError(fiber.StatusInternalServerError, "failed to prepare authorization")
|
|
}
|
|
|
|
nonce, err := secure.RandomString(32)
|
|
if err != nil {
|
|
utils.Log.Errorf("generate nonce failed: %v", err)
|
|
return fiber.NewError(fiber.StatusInternalServerError, "failed to prepare authorization")
|
|
}
|
|
|
|
codeVerifier, err := secure.PKCECodeVerifier(96)
|
|
if err != nil {
|
|
utils.Log.Errorf("generate code verifier failed: %v", err)
|
|
return fiber.NewError(fiber.StatusInternalServerError, "failed to prepare authorization")
|
|
}
|
|
|
|
digest := sha256.Sum256([]byte(codeVerifier))
|
|
challenge := secure.Base64URLEncode(digest[:])
|
|
|
|
authorizeURL, err := url.Parse(authorizeEndpoint)
|
|
if err != nil {
|
|
return fiber.NewError(fiber.StatusInternalServerError, "invalid authorize endpoint")
|
|
}
|
|
|
|
scope := cfg.Scope
|
|
if scope == "" {
|
|
scope = "openid profile"
|
|
}
|
|
if !strings.Contains(" "+scope+" ", " openid ") {
|
|
scope = scope + " openid"
|
|
}
|
|
|
|
rawReturn := strings.TrimSpace(c.Query("return_to"))
|
|
if rawReturn == "" {
|
|
rawReturn = cfg.DefaultReturnURI
|
|
}
|
|
|
|
returnTo, err := normalizeReturnTarget(rawReturn, cfg)
|
|
if err != nil {
|
|
return fiber.NewError(fiber.StatusBadRequest, err.Error())
|
|
}
|
|
|
|
query := authorizeURL.Query()
|
|
query.Set("response_type", "code")
|
|
query.Set("client_id", cfg.PublicID)
|
|
query.Set("redirect_uri", cfg.RedirectURI)
|
|
query.Set("scope", strings.TrimSpace(scope))
|
|
query.Set("state", state)
|
|
query.Set("code_challenge", challenge)
|
|
query.Set("code_challenge_method", "S256")
|
|
query.Set("nonce", nonce)
|
|
// if prompt := strings.TrimSpace(cfg.Prompt); prompt != "" {
|
|
// query.Set("prompt", prompt)
|
|
// }
|
|
if extraPrompt := strings.TrimSpace(c.Query("prompt")); extraPrompt != "" {
|
|
query.Set("prompt", extraPrompt)
|
|
}
|
|
authorizeURL.RawQuery = query.Encode()
|
|
|
|
payload := &session.PKCESession{
|
|
CodeVerifier: codeVerifier,
|
|
Nonce: nonce,
|
|
ClientAlias: alias,
|
|
ClientID: cfg.PublicID,
|
|
RedirectURI: cfg.RedirectURI,
|
|
Scope: strings.TrimSpace(scope),
|
|
ReturnTo: returnTo,
|
|
CreatedAt: time.Now().UTC(),
|
|
}
|
|
|
|
if err := h.store.Save(c.Context(), state, payload); err != nil {
|
|
utils.Log.Errorf("store pkce session failed: %v", err)
|
|
return fiber.NewError(fiber.StatusInternalServerError, "failed to prepare authorization")
|
|
}
|
|
|
|
utils.Log.WithFields(logrus.Fields{
|
|
"client": alias,
|
|
"state": state,
|
|
"return_to": returnTo,
|
|
}).Info("sso start redirect")
|
|
|
|
return c.Redirect(authorizeURL.String(), fiber.StatusFound)
|
|
}
|
|
|
|
// Refresh exchanges the current SSO refresh token for a new access/refresh pair
|
|
// without redirecting the browser to the SSO login page.
|
|
func (h *Controller) Refresh(c *fiber.Ctx) error {
|
|
refreshName := resolveSSOCookieName(config.SSORefreshCookieName, "refresh")
|
|
refreshToken := strings.TrimSpace(c.Cookies(refreshName))
|
|
if refreshToken == "" {
|
|
if target := buildStartRedirect(defaultSSOClientAlias()); target != "" {
|
|
return c.Redirect(target, fiber.StatusFound)
|
|
}
|
|
return fiber.NewError(fiber.StatusUnauthorized, "unauthenticated")
|
|
}
|
|
|
|
tokenEndpoint := strings.TrimSpace(config.SSOTokenURL)
|
|
if tokenEndpoint == "" {
|
|
return fiber.NewError(fiber.StatusInternalServerError, "token endpoint not configured")
|
|
}
|
|
|
|
form := url.Values{}
|
|
form.Set("grant_type", "refresh_token")
|
|
form.Set("refresh_token", refreshToken)
|
|
|
|
req, err := http.NewRequestWithContext(c.Context(), http.MethodPost, tokenEndpoint, strings.NewReader(form.Encode()))
|
|
if err != nil {
|
|
return fiber.NewError(fiber.StatusInternalServerError, "failed to create refresh request")
|
|
}
|
|
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
|
|
|
resp, err := h.httpClient.Do(req)
|
|
if err != nil {
|
|
utils.Log.Errorf("token refresh request failed: %v", err)
|
|
return fiber.NewError(fiber.StatusBadGateway, "failed to refresh access token")
|
|
}
|
|
defer resp.Body.Close()
|
|
|
|
if resp.StatusCode >= 400 {
|
|
utils.Log.Warnf("token refresh response status %d", resp.StatusCode)
|
|
if resp.StatusCode == fiber.StatusTooManyRequests {
|
|
return fiber.NewError(fiber.StatusTooManyRequests, "Too many attempts, please slow down")
|
|
}
|
|
if target := buildStartRedirect(defaultSSOClientAlias()); target != "" {
|
|
return c.Redirect(target, fiber.StatusFound)
|
|
}
|
|
return fiber.NewError(fiber.StatusUnauthorized, "unauthenticated")
|
|
}
|
|
|
|
var tokenResp refreshTokenResponse
|
|
if err := json.NewDecoder(resp.Body).Decode(&tokenResp); err != nil {
|
|
return fiber.NewError(fiber.StatusBadGateway, "invalid token response")
|
|
}
|
|
if tokenResp.Error != "" {
|
|
return fiber.NewError(fiber.StatusBadGateway, tokenResp.Description)
|
|
}
|
|
if tokenResp.AccessToken == "" {
|
|
return fiber.NewError(fiber.StatusBadGateway, "missing access token")
|
|
}
|
|
|
|
verification, err := sso.VerifyAccessToken(tokenResp.AccessToken)
|
|
if err != nil {
|
|
if sso.IsSignatureError(err) {
|
|
logSignatureError("sso refresh", "sso_token", tokenResp.AccessToken, err)
|
|
} else {
|
|
utils.Log.Errorf("access token verification failed: %v", err)
|
|
}
|
|
return fiber.NewError(fiber.StatusUnauthorized, "invalid access token")
|
|
}
|
|
|
|
if err := issueCookies(c, struct {
|
|
AccessToken string `json:"access_token"`
|
|
RefreshToken string `json:"refresh_token"`
|
|
TokenType string `json:"token_type"`
|
|
ExpiresIn int `json:"expires_in"`
|
|
Scope string `json:"scope"`
|
|
IDToken string `json:"id_token"`
|
|
Error string `json:"error"`
|
|
Description string `json:"error_description"`
|
|
}{
|
|
AccessToken: tokenResp.AccessToken,
|
|
RefreshToken: tokenResp.RefreshToken,
|
|
TokenType: tokenResp.TokenType,
|
|
ExpiresIn: tokenResp.ExpiresIn,
|
|
Scope: tokenResp.Scope,
|
|
IDToken: tokenResp.IDToken,
|
|
Error: tokenResp.Error,
|
|
Description: tokenResp.Description,
|
|
}, verification); err != nil {
|
|
return err
|
|
}
|
|
|
|
utils.Log.WithFields(logrus.Fields{
|
|
"user_id": verification.UserID,
|
|
}).Info("sso refresh successful")
|
|
|
|
return c.Status(fiber.StatusOK).JSON(fiber.Map{"status": "ok"})
|
|
}
|
|
|
|
// Callback handles the redirect from SSO containing the authorization code.
|
|
func (h *Controller) Callback(c *fiber.Ctx) error {
|
|
state := strings.TrimSpace(c.Query("state"))
|
|
code := strings.TrimSpace(c.Query("code"))
|
|
if state == "" || code == "" {
|
|
return fiber.NewError(fiber.StatusBadRequest, "missing code or state")
|
|
}
|
|
|
|
sessionData, err := h.store.Get(c.Context(), state)
|
|
if err != nil {
|
|
utils.Log.Errorf("load pkce session failed: %v", err)
|
|
return fiber.NewError(fiber.StatusInternalServerError, "failed to validate authorization state")
|
|
}
|
|
if sessionData == nil {
|
|
return fiber.NewError(fiber.StatusBadRequest, "authorization state not found or expired")
|
|
}
|
|
defer func() {
|
|
if err := h.store.Delete(context.Background(), state); err != nil {
|
|
utils.Log.Warnf("failed to delete pkce session: %v", err)
|
|
}
|
|
}()
|
|
|
|
tokenEndpoint := strings.TrimSpace(config.SSOTokenURL)
|
|
if tokenEndpoint == "" {
|
|
return fiber.NewError(fiber.StatusInternalServerError, "token endpoint not configured")
|
|
}
|
|
|
|
form := url.Values{}
|
|
form.Set("grant_type", "authorization_code")
|
|
form.Set("code", code)
|
|
form.Set("code_verifier", sessionData.CodeVerifier)
|
|
form.Set("redirect_uri", sessionData.RedirectURI)
|
|
form.Set("client_id", sessionData.ClientID)
|
|
|
|
req, err := http.NewRequestWithContext(c.Context(), http.MethodPost, tokenEndpoint, strings.NewReader(form.Encode()))
|
|
if err != nil {
|
|
return fiber.NewError(fiber.StatusInternalServerError, "failed to create token request")
|
|
}
|
|
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
|
|
|
resp, err := h.httpClient.Do(req)
|
|
if err != nil {
|
|
utils.Log.Errorf("token request failed: %v", err)
|
|
return fiber.NewError(fiber.StatusBadGateway, "failed to exchange authorization code")
|
|
}
|
|
defer resp.Body.Close()
|
|
|
|
if resp.StatusCode >= 400 {
|
|
utils.Log.Warnf("token response status %d", resp.StatusCode)
|
|
return fiber.NewError(fiber.StatusBadGateway, "token exchange rejected")
|
|
}
|
|
|
|
var tokenResp struct {
|
|
AccessToken string `json:"access_token"`
|
|
RefreshToken string `json:"refresh_token"`
|
|
TokenType string `json:"token_type"`
|
|
ExpiresIn int `json:"expires_in"`
|
|
Scope string `json:"scope"`
|
|
IDToken string `json:"id_token"`
|
|
Error string `json:"error"`
|
|
Description string `json:"error_description"`
|
|
}
|
|
|
|
if err := json.NewDecoder(resp.Body).Decode(&tokenResp); err != nil {
|
|
return fiber.NewError(fiber.StatusBadGateway, "invalid token response")
|
|
}
|
|
if tokenResp.Error != "" {
|
|
return fiber.NewError(fiber.StatusBadGateway, tokenResp.Description)
|
|
}
|
|
if tokenResp.AccessToken == "" {
|
|
return fiber.NewError(fiber.StatusBadGateway, "missing access token")
|
|
}
|
|
|
|
verification, err := sso.VerifyAccessToken(tokenResp.AccessToken)
|
|
if err != nil {
|
|
if sso.IsSignatureError(err) {
|
|
logSignatureError("sso callback", "sso_token", tokenResp.AccessToken, err)
|
|
} else {
|
|
utils.Log.Errorf("access token verification failed: %v", err)
|
|
}
|
|
return fiber.NewError(fiber.StatusUnauthorized, "invalid access token")
|
|
}
|
|
|
|
// prepare cookies
|
|
if err := issueCookies(c, tokenResp, verification); err != nil {
|
|
return err
|
|
}
|
|
|
|
redirectTarget := sessionData.ReturnTo
|
|
if redirectTarget == "" {
|
|
redirectTarget = "/"
|
|
}
|
|
|
|
utils.Log.WithFields(logrus.Fields{
|
|
"client": sessionData.ClientAlias,
|
|
"user_id": verification.UserID,
|
|
"return_to": redirectTarget,
|
|
}).Info("sso callback successful")
|
|
|
|
return c.Redirect(redirectTarget, fiber.StatusFound)
|
|
}
|
|
|
|
// UserInfo proxies the user profile from the central SSO so the frontend can obtain
|
|
// enriched user metadata (roles, permissions, etc.) without exposing tokens to the browser.
|
|
func (h *Controller) UserInfo(c *fiber.Ctx) error {
|
|
accessName := config.SSOAccessCookieName
|
|
if accessName == "" {
|
|
accessName = "sso_access"
|
|
}
|
|
|
|
token := strings.TrimSpace(c.Cookies(accessName))
|
|
tokenFromCookie := token != ""
|
|
usedCookieName := accessName
|
|
|
|
if !tokenFromCookie {
|
|
for _, name := range config.SSOAccessCookieFallback {
|
|
name = strings.TrimSpace(name)
|
|
if name == "" || name == accessName {
|
|
continue
|
|
}
|
|
token = strings.TrimSpace(c.Cookies(name))
|
|
if token != "" {
|
|
tokenFromCookie = true
|
|
usedCookieName = name
|
|
break
|
|
}
|
|
}
|
|
}
|
|
|
|
if !tokenFromCookie {
|
|
authHeader := strings.TrimSpace(c.Get("Authorization"))
|
|
if strings.HasPrefix(strings.ToLower(authHeader), "bearer ") {
|
|
token = strings.TrimSpace(authHeader[7:])
|
|
}
|
|
}
|
|
|
|
if token == "" {
|
|
return fiber.NewError(fiber.StatusUnauthorized, "unauthenticated")
|
|
}
|
|
|
|
if revoker := session.GetRevocationStore(); revoker != nil {
|
|
if fingerprint := session.TokenFingerprint(token); fingerprint != "" {
|
|
revoked, err := revoker.IsRevoked(c.Context(), fingerprint)
|
|
if err != nil {
|
|
utils.Log.WithError(err).Warn("failed to check token revocation for userinfo")
|
|
return fiber.NewError(fiber.StatusUnauthorized, "unauthenticated")
|
|
}
|
|
if revoked {
|
|
return fiber.NewError(fiber.StatusUnauthorized, "unauthenticated")
|
|
}
|
|
}
|
|
}
|
|
|
|
if _, err := sso.VerifyAccessToken(token); err != nil {
|
|
if sso.IsSignatureError(err) {
|
|
logSignatureError("sso userinfo", "request", token, err)
|
|
} else {
|
|
utils.Log.WithError(err).Warn("access token verification failed for userinfo")
|
|
}
|
|
return fiber.NewError(fiber.StatusUnauthorized, "unauthenticated")
|
|
}
|
|
|
|
endpoint := strings.TrimSpace(config.SSOGetMeURL)
|
|
if endpoint == "" {
|
|
return fiber.NewError(fiber.StatusInternalServerError, "userinfo endpoint not configured")
|
|
}
|
|
|
|
req, err := http.NewRequestWithContext(c.Context(), http.MethodGet, endpoint, nil)
|
|
if err != nil {
|
|
utils.Log.Errorf("failed to build userinfo request: %v", err)
|
|
return fiber.NewError(fiber.StatusInternalServerError, "failed to prepare userinfo request")
|
|
}
|
|
req.Header.Set("Accept", "application/json")
|
|
|
|
// SSO /auth/get-me expects the access cookie; add Authorization as well for compatibility.
|
|
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token))
|
|
if tokenFromCookie {
|
|
req.Header.Set("Cookie", fmt.Sprintf("%s=%s", usedCookieName, token))
|
|
}
|
|
|
|
resp, err := h.httpClient.Do(req)
|
|
if err != nil {
|
|
utils.Log.Errorf("userinfo request failed: %v", err)
|
|
return fiber.NewError(fiber.StatusBadGateway, "failed to fetch user profile")
|
|
}
|
|
defer resp.Body.Close()
|
|
|
|
utils.Log.WithFields(logrus.Fields{"status": resp.StatusCode}).Info("sso userinfo response")
|
|
|
|
body, err := io.ReadAll(resp.Body)
|
|
if err != nil {
|
|
utils.Log.Errorf("failed to read userinfo response: %v", err)
|
|
return fiber.NewError(fiber.StatusBadGateway, "invalid user profile response")
|
|
}
|
|
|
|
if ct := resp.Header.Get("Content-Type"); ct != "" {
|
|
c.Set("Content-Type", ct)
|
|
} else {
|
|
c.Type("json")
|
|
}
|
|
|
|
return c.Status(resp.StatusCode).Send(body)
|
|
}
|
|
|
|
// Logout clears SSO cookies and removes any leftover PKCE session state.
|
|
func (h *Controller) Logout(c *fiber.Ctx) error {
|
|
alias := ""
|
|
if singleAlias, _, ok := singleSSOClient(); ok {
|
|
alias = singleAlias
|
|
}
|
|
|
|
accessName := resolveSSOCookieName(config.SSOAccessCookieName, "access")
|
|
refreshName := resolveSSOCookieName(config.SSORefreshCookieName, "refresh")
|
|
|
|
var accessToken, refreshToken string
|
|
var verification *sso.VerificationResult
|
|
if accessName != "" {
|
|
accessToken = strings.TrimSpace(c.Cookies(accessName))
|
|
}
|
|
if refreshName != "" {
|
|
refreshToken = strings.TrimSpace(c.Cookies(refreshName))
|
|
}
|
|
hadAccessCookie := accessToken != ""
|
|
hadRefreshCookie := refreshToken != ""
|
|
|
|
if !hadAccessCookie && !hadRefreshCookie {
|
|
return fiber.NewError(fiber.StatusUnauthorized, "not authenticated")
|
|
}
|
|
|
|
if hadAccessCookie {
|
|
if v, err := sso.VerifyAccessToken(accessToken); err != nil {
|
|
utils.Log.WithError(err).Warn("failed to verify access token during logout")
|
|
} else {
|
|
verification = v
|
|
if revoker := session.GetRevocationStore(); revoker != nil {
|
|
if err := revoker.MarkUserLogout(c.Context(), verification.UserID, time.Now().UTC()); err != nil {
|
|
utils.Log.WithError(err).Warn("failed to mark user logout")
|
|
}
|
|
}
|
|
h.revokeToken(c.Context(), accessToken, verification)
|
|
}
|
|
}
|
|
if refreshToken != "" {
|
|
h.revokeRefreshToken(c.Context(), refreshToken)
|
|
}
|
|
|
|
clearSSOCookie(c, accessName)
|
|
clearSSOCookie(c, refreshName)
|
|
|
|
redirectTarget := ""
|
|
if config.SSOPortalURL != "" {
|
|
redirectTarget = config.SSOPortalURL
|
|
}
|
|
|
|
utils.Log.WithFields(logrus.Fields{
|
|
"client": alias,
|
|
"redirect": redirectTarget,
|
|
}).Info("sso logout completed")
|
|
|
|
if redirectTarget != "" {
|
|
return c.Status(fiber.StatusOK).JSON(fiber.Map{
|
|
"status": "signed out",
|
|
"redirect": redirectTarget,
|
|
})
|
|
}
|
|
|
|
return c.Status(fiber.StatusOK).JSON(fiber.Map{"status": "signed out"})
|
|
}
|
|
|
|
func singleSSOClient() (string, config.SSOClientConfig, bool) {
|
|
if len(config.SSOClients) != 1 {
|
|
return "", config.SSOClientConfig{}, false
|
|
}
|
|
for alias, cfg := range config.SSOClients {
|
|
if strings.TrimSpace(alias) == "" || strings.TrimSpace(cfg.PublicID) == "" {
|
|
return "", config.SSOClientConfig{}, false
|
|
}
|
|
return alias, cfg, true
|
|
}
|
|
return "", config.SSOClientConfig{}, false
|
|
}
|
|
|
|
|
|
func defaultSSOClientAlias() string {
|
|
for alias := range config.SSOClients {
|
|
if strings.TrimSpace(alias) == "" {
|
|
continue
|
|
}
|
|
return alias
|
|
}
|
|
return ""
|
|
}
|
|
|
|
// buildStartRedirect returns the site-relative URL of the SSO start endpoint
// for the given client alias (trimmed and query-escaped), or "" when the
// alias is blank.
func buildStartRedirect(alias string) string {
	if trimmed := strings.TrimSpace(alias); trimmed != "" {
		return "/api/sso/start?client=" + url.QueryEscape(trimmed)
	}
	return ""
}
|
|
|
|
func (h *Controller) revokeToken(ctx context.Context, token string, verification *sso.VerificationResult) {
|
|
if h.revoker == nil || verification == nil || verification.Claims == nil {
|
|
return
|
|
}
|
|
fingerprint := session.TokenFingerprint(token)
|
|
if fingerprint == "" {
|
|
return
|
|
}
|
|
if verification.Claims.ExpiresAt == nil {
|
|
utils.Log.Warn("access token missing expiry claim")
|
|
return
|
|
}
|
|
ttl := time.Until(verification.Claims.ExpiresAt.Time)
|
|
if ttl <= 0 {
|
|
return
|
|
}
|
|
if ttl < time.Second {
|
|
ttl = time.Second
|
|
}
|
|
if err := h.revoker.Revoke(ctx, fingerprint, ttl); err != nil {
|
|
utils.Log.WithError(err).Warn("failed to revoke access token")
|
|
}
|
|
}
|
|
|
|
func (h *Controller) revokeRefreshToken(ctx context.Context, token string) {
|
|
if h.revoker == nil {
|
|
return
|
|
}
|
|
fingerprint := session.TokenFingerprint(token)
|
|
if fingerprint == "" {
|
|
return
|
|
}
|
|
const refreshTTL = 30 * 24 * time.Hour
|
|
if err := h.revoker.Revoke(ctx, fingerprint, refreshTTL); err != nil {
|
|
utils.Log.WithError(err).Warn("failed to revoke refresh token")
|
|
}
|
|
}
|
|
|
|
func issueCookies(c *fiber.Ctx, tokenResp struct {
|
|
AccessToken string `json:"access_token"`
|
|
RefreshToken string `json:"refresh_token"`
|
|
TokenType string `json:"token_type"`
|
|
ExpiresIn int `json:"expires_in"`
|
|
Scope string `json:"scope"`
|
|
IDToken string `json:"id_token"`
|
|
Error string `json:"error"`
|
|
Description string `json:"error_description"`
|
|
}, verification *sso.VerificationResult) error {
|
|
if revoker := session.GetRevocationStore(); revoker != nil && verification != nil {
|
|
if err := revoker.ClearUserLogout(c.Context(), verification.UserID); err != nil {
|
|
utils.Log.WithError(err).Warn("failed to clear logout marker")
|
|
}
|
|
}
|
|
|
|
if max := config.SSOAccessTokenMaxBytes; max > 0 && len(tokenResp.AccessToken) > max {
|
|
utils.Log.WithFields(logrus.Fields{
|
|
"token_len": len(tokenResp.AccessToken),
|
|
"max_len": max,
|
|
}).Warn("sso access token exceeds cookie size limit")
|
|
return fiber.NewError(fiber.StatusRequestEntityTooLarge, "access token too large")
|
|
}
|
|
|
|
accessName := resolveSSOCookieName(config.SSOAccessCookieName, "access")
|
|
refreshName := resolveSSOCookieName(config.SSORefreshCookieName, "refresh")
|
|
maxAge := tokenResp.ExpiresIn
|
|
if maxAge <= 0 {
|
|
maxAge = int(15 * time.Minute.Seconds())
|
|
}
|
|
|
|
sameSite := config.SSOCookieSameSite
|
|
if sameSite == "" {
|
|
sameSite = "Lax"
|
|
}
|
|
|
|
cookieDomain := config.SSOCookieDomain
|
|
|
|
cookieAccess := &fiber.Cookie{
|
|
Name: accessName,
|
|
Value: tokenResp.AccessToken,
|
|
Path: "/",
|
|
Domain: cookieDomain,
|
|
HTTPOnly: true,
|
|
Secure: config.SSOCookieSecure,
|
|
SameSite: sameSite,
|
|
MaxAge: maxAge,
|
|
}
|
|
c.Cookie(cookieAccess)
|
|
if tokenResp.RefreshToken != "" {
|
|
cookieRefresh := &fiber.Cookie{
|
|
Name: refreshName,
|
|
Value: tokenResp.RefreshToken,
|
|
Path: "/",
|
|
Domain: cookieDomain,
|
|
HTTPOnly: true,
|
|
Secure: config.SSOCookieSecure,
|
|
SameSite: sameSite,
|
|
MaxAge: int((time.Hour * 24 * 30).Seconds()),
|
|
}
|
|
c.Cookie(cookieRefresh)
|
|
}
|
|
|
|
// Optional: expose limited info via headers for FE debugging (avoid tokens)
|
|
c.Set("X-Auth-User", fmt.Sprintf("%d", verification.UserID))
|
|
return nil
|
|
}
|
|
|
|
func clearSSOCookie(c *fiber.Ctx, name string) {
|
|
if name == "" {
|
|
return
|
|
}
|
|
|
|
sameSite := config.SSOCookieSameSite
|
|
if sameSite == "" {
|
|
sameSite = "Lax"
|
|
}
|
|
|
|
c.Cookie(&fiber.Cookie{
|
|
Name: name,
|
|
Value: "",
|
|
Path: "/",
|
|
Domain: config.SSOCookieDomain,
|
|
HTTPOnly: true,
|
|
Secure: config.SSOCookieSecure,
|
|
SameSite: sameSite,
|
|
Expires: time.Unix(0, 0),
|
|
MaxAge: -1,
|
|
})
|
|
}
|
|
|
|
// resolveSSOCookieName returns the configured cookie name, trimmed. When that
// is blank it returns the trimmed fallback, which may itself be "".
func resolveSSOCookieName(configuredName, fallback string) string {
	for _, candidate := range []string{configuredName, fallback} {
		if name := strings.TrimSpace(candidate); name != "" {
			return name
		}
	}
	return ""
}
|
|
|
|
func logSignatureError(ctxLabel, tokenSource, token string, err error) {
|
|
info := sso.ExtractTokenInfo(token)
|
|
aud := strings.Join(info.Aud, ",")
|
|
utils.Log.Errorf(
|
|
"access token verification failed: %v | ctx=%s source=%s iss=%s kid=%s aud=%s sub=%s exp=%d iat=%d nbf=%d expected_iss=%s expected_aud=%v jwks=%s",
|
|
err,
|
|
ctxLabel,
|
|
tokenSource,
|
|
info.Iss,
|
|
info.Kid,
|
|
aud,
|
|
info.Sub,
|
|
info.Exp,
|
|
info.Iat,
|
|
info.Nbf,
|
|
config.SSOIssuer,
|
|
config.SSOAllowedAudiences,
|
|
config.SSOJWKSURL,
|
|
)
|
|
}
|
|
|
|
// normalizeClientParam canonicalises a client identifier taken from the query
// string:
//   - surrounding whitespace is trimmed;
//   - everything from the first "|" onward is discarded;
//   - the result is lower-cased.
//
// Blank input yields "".
func normalizeClientParam(raw string) string {
	head := strings.SplitN(strings.TrimSpace(raw), "|", 2)[0]
	return strings.ToLower(strings.TrimSpace(head))
}
|
|
|
|
|
|
func findSSOClientConfig(requestedAlias string) (string, config.SSOClientConfig, bool) {
|
|
if requestedAlias == "" {
|
|
return "", config.SSOClientConfig{}, false
|
|
}
|
|
if cfg, ok := config.SSOClients[requestedAlias]; ok && strings.TrimSpace(cfg.PublicID) != "" {
|
|
return requestedAlias, cfg, true
|
|
}
|
|
for alias, cfg := range config.SSOClients {
|
|
if strings.EqualFold(strings.TrimSpace(cfg.PublicID), requestedAlias) && strings.TrimSpace(cfg.PublicID) != "" {
|
|
return alias, cfg, true
|
|
}
|
|
}
|
|
return "", config.SSOClientConfig{}, false
|
|
}
|
|
|
|
func normalizeReturnTarget(returnTo string, cfg config.SSOClientConfig) (string, error) {
|
|
returnTo = strings.TrimSpace(returnTo)
|
|
if returnTo == "" {
|
|
return "", nil
|
|
}
|
|
if strings.HasPrefix(returnTo, "//") {
|
|
return "", fmt.Errorf("invalid return_to")
|
|
}
|
|
if strings.HasPrefix(returnTo, "/") {
|
|
return returnTo, nil
|
|
}
|
|
|
|
parsed, err := url.Parse(returnTo)
|
|
if err != nil {
|
|
return "", fmt.Errorf("invalid return_to")
|
|
}
|
|
if parsed.Scheme != "http" && parsed.Scheme != "https" {
|
|
return "", fmt.Errorf("invalid return_to scheme")
|
|
}
|
|
|
|
allowedOrigins := make(map[string]struct{})
|
|
if cfg.DefaultReturnURI != "" {
|
|
if u, err := url.Parse(cfg.DefaultReturnURI); err == nil && u.Host != "" {
|
|
allowedOrigins[u.Scheme+"://"+u.Host] = struct{}{}
|
|
}
|
|
}
|
|
for _, origin := range cfg.AllowedReturnOrigins {
|
|
origin = strings.TrimSpace(origin)
|
|
if origin == "" {
|
|
continue
|
|
}
|
|
if u, err := url.Parse(origin); err == nil && u.Host != "" && (u.Scheme == "http" || u.Scheme == "https") {
|
|
allowedOrigins[u.Scheme+"://"+u.Host] = struct{}{}
|
|
}
|
|
}
|
|
|
|
if len(allowedOrigins) > 0 {
|
|
origin := parsed.Scheme + "://" + parsed.Host
|
|
if _, ok := allowedOrigins[origin]; !ok {
|
|
return "", fmt.Errorf("return_to origin not allowed")
|
|
}
|
|
}
|
|
|
|
return parsed.String(), nil
|
|
}
|