Files
lti-api/internal/modules/sso/controllers/sso.controller.go
T

991 lines
29 KiB
Go

package controllers
import (
"context"
"crypto/sha256"
"encoding/json"
"fmt"
"io"
"net/http"
"net/url"
"strings"
"time"
"github.com/gofiber/fiber/v2"
"github.com/sirupsen/logrus"
"gitlab.com/mbugroup/lti-api.git/internal/config"
"gitlab.com/mbugroup/lti-api.git/internal/modules/sso/session"
"gitlab.com/mbugroup/lti-api.git/internal/sso"
"gitlab.com/mbugroup/lti-api.git/internal/utils"
"gitlab.com/mbugroup/lti-api.git/internal/utils/secure"
)
// Controller serves the SSO endpoints (start, callback, refresh, userinfo,
// logout). It runs the PKCE authorization-code flow against the central SSO
// and keeps per-login state in a session store.
type Controller struct {
	// httpClient performs every outbound call to the central SSO.
	httpClient *http.Client
	// store holds PKCE session data keyed by the OAuth state value.
	store *session.Store
	// revoker records fingerprints of tokens revoked during logout.
	revoker *session.RevocationStore
}
// NewController builds a Controller from the HTTP client used for SSO calls,
// the PKCE session store and the token revocation store.
func NewController(client *http.Client, store *session.Store, revoker *session.RevocationStore) *Controller {
	ctrl := new(Controller)
	ctrl.httpClient = client
	ctrl.store = store
	ctrl.revoker = revoker
	return ctrl
}
// Start handles GET /sso/start.
//
// It resolves the client from ?client= (or ?client_id=) and generates a state
// value, a nonce and a PKCE code verifier. It saves them in the session store
// keyed by state, then redirects (302) to the central SSO authorize endpoint
// with an S256 code challenge.
func (h *Controller) Start(c *fiber.Ctx) error {
	const prepareFailed = "failed to prepare authorization"

	requestedAlias := normalizeClientParam(c.Query("client"))
	if requestedAlias == "" {
		requestedAlias = normalizeClientParam(c.Query("client_id"))
	}
	if requestedAlias == "" {
		return fiber.NewError(fiber.StatusBadRequest, "missing client")
	}

	alias, cfg, found := findSSOClientConfig(requestedAlias)
	if !found || cfg.PublicID == "" {
		return fiber.NewError(fiber.StatusBadRequest, "unknown client")
	}

	endpoint := strings.TrimSpace(config.SSOAuthorizeURL)
	if endpoint == "" {
		return fiber.NewError(fiber.StatusInternalServerError, "authorize endpoint not configured")
	}

	state, err := secure.RandomString(48)
	if err != nil {
		utils.Log.Errorf("generate state failed: %v", err)
		return fiber.NewError(fiber.StatusInternalServerError, prepareFailed)
	}
	nonce, err := secure.RandomString(32)
	if err != nil {
		utils.Log.Errorf("generate nonce failed: %v", err)
		return fiber.NewError(fiber.StatusInternalServerError, prepareFailed)
	}
	verifier, err := secure.PKCECodeVerifier(96)
	if err != nil {
		utils.Log.Errorf("generate code verifier failed: %v", err)
		return fiber.NewError(fiber.StatusInternalServerError, prepareFailed)
	}
	// S256 challenge: base64url(SHA-256(verifier)).
	sum := sha256.Sum256([]byte(verifier))
	challenge := secure.Base64URLEncode(sum[:])

	target, err := url.Parse(endpoint)
	if err != nil {
		return fiber.NewError(fiber.StatusInternalServerError, "invalid authorize endpoint")
	}

	// Default to "openid profile" and always make sure "openid" is requested.
	scope := cfg.Scope
	if scope == "" {
		scope = "openid profile"
	}
	if !strings.Contains(" "+scope+" ", " openid ") {
		scope += " openid"
	}
	scope = strings.TrimSpace(scope)

	requestedReturn := strings.TrimSpace(c.Query("return_to"))
	if requestedReturn == "" {
		requestedReturn = cfg.DefaultReturnURI
	}
	returnTo, err := normalizeReturnTarget(requestedReturn, cfg)
	if err != nil {
		return fiber.NewError(fiber.StatusBadRequest, err.Error())
	}

	// Keep any query parameters already present on the configured endpoint.
	params := target.Query()
	for key, value := range map[string]string{
		"response_type":         "code",
		"client_id":             cfg.PublicID,
		"redirect_uri":          cfg.RedirectURI,
		"scope":                 scope,
		"state":                 state,
		"code_challenge":        challenge,
		"code_challenge_method": "S256",
		"nonce":                 nonce,
	} {
		params.Set(key, value)
	}
	// Only a caller-supplied ?prompt= is forwarded to the authorize endpoint.
	if prompt := strings.TrimSpace(c.Query("prompt")); prompt != "" {
		params.Set("prompt", prompt)
	}
	target.RawQuery = params.Encode()

	pkce := &session.PKCESession{
		CodeVerifier: verifier,
		Nonce:        nonce,
		ClientAlias:  alias,
		ClientID:     cfg.PublicID,
		RedirectURI:  cfg.RedirectURI,
		Scope:        scope,
		ReturnTo:     returnTo,
		CreatedAt:    time.Now().UTC(),
	}
	if err := h.store.Save(c.Context(), state, pkce); err != nil {
		utils.Log.Errorf("store pkce session failed: %v", err)
		return fiber.NewError(fiber.StatusInternalServerError, prepareFailed)
	}

	utils.Log.WithFields(logrus.Fields{
		"client":    alias,
		"state":     state,
		"return_to": returnTo,
	}).Info("sso start redirect")
	return c.Redirect(target.String(), fiber.StatusFound)
}
// Refresh exchanges the SSO refresh-token cookie for a new access/refresh pair
// without sending the browser through the SSO login page. On success it
// responds with {"status":"ok"}.
//
// The browser is sent back to the start endpoint of the default client when no
// usable refresh token exists: the cookie is missing, the token was revoked
// locally at logout, or the token is rejected upstream. If no client alias is
// configured, a 401 is returned instead. An upstream 429 is passed through as 429.
func (h *Controller) Refresh(c *fiber.Ctx) error {
	refreshName := resolveSSOCookieName(config.SSORefreshCookieName, "refresh")
	refreshToken := strings.TrimSpace(c.Cookies(refreshName))

	// restartLogin sends the browser back through the SSO login flow.
	restartLogin := func() error {
		if target := buildStartRedirect(defaultSSOClientAlias()); target != "" {
			return c.Redirect(target, fiber.StatusFound)
		}
		return fiber.NewError(fiber.StatusUnauthorized, "unauthenticated")
	}

	if refreshToken == "" {
		return restartLogin()
	}

	// Logout stores refresh-token fingerprints in the revocation store. Check
	// them here, otherwise a refresh cookie captured before logout could still
	// mint new tokens (and issueCookies would clear the user's logout marker).
	// Like UserInfo, fail closed when the store cannot be queried.
	if h.revoker != nil {
		if fingerprint := session.TokenFingerprint(refreshToken); fingerprint != "" {
			revoked, err := h.revoker.IsRevoked(c.Context(), fingerprint)
			if err != nil {
				utils.Log.WithError(err).Warn("failed to check refresh token revocation")
				return fiber.NewError(fiber.StatusUnauthorized, "unauthenticated")
			}
			if revoked {
				clearSSOCookie(c, refreshName)
				return restartLogin()
			}
		}
	}

	tokenEndpoint := strings.TrimSpace(config.SSOTokenURL)
	if tokenEndpoint == "" {
		return fiber.NewError(fiber.StatusInternalServerError, "token endpoint not configured")
	}

	form := url.Values{}
	form.Set("grant_type", "refresh_token")
	form.Set("refresh_token", refreshToken)
	req, err := http.NewRequestWithContext(c.Context(), http.MethodPost, tokenEndpoint, strings.NewReader(form.Encode()))
	if err != nil {
		return fiber.NewError(fiber.StatusInternalServerError, "failed to create refresh request")
	}
	req.Header.Set("Content-Type", "application/x-www-form-urlencoded")

	resp, err := h.httpClient.Do(req)
	if err != nil {
		utils.Log.Errorf("token refresh request failed: %v", err)
		return fiber.NewError(fiber.StatusBadGateway, "failed to refresh access token")
	}
	defer resp.Body.Close()

	if resp.StatusCode >= 400 {
		utils.Log.Warnf("token refresh response status %d", resp.StatusCode)
		if resp.StatusCode == fiber.StatusTooManyRequests {
			return fiber.NewError(fiber.StatusTooManyRequests, "Too many attempts, please slow down")
		}
		return restartLogin()
	}

	var tokenResp refreshTokenResponse
	if err := json.NewDecoder(resp.Body).Decode(&tokenResp); err != nil {
		return fiber.NewError(fiber.StatusBadGateway, "invalid token response")
	}
	if tokenResp.Error != "" {
		// error_description is optional in OAuth error responses; fall back to
		// the error code so the client never receives an empty message.
		message := tokenResp.Description
		if message == "" {
			message = tokenResp.Error
		}
		return fiber.NewError(fiber.StatusBadGateway, message)
	}
	if tokenResp.AccessToken == "" {
		return fiber.NewError(fiber.StatusBadGateway, "missing access token")
	}

	verification, err := sso.VerifyAccessToken(tokenResp.AccessToken)
	if err != nil {
		utils.Log.Errorf("access token verification failed: %v", err)
		return fiber.NewError(fiber.StatusUnauthorized, "invalid access token")
	}

	// issueCookies takes an anonymous struct, so copy the response field by field.
	issueCookies(c, struct {
		AccessToken  string `json:"access_token"`
		RefreshToken string `json:"refresh_token"`
		TokenType    string `json:"token_type"`
		ExpiresIn    int    `json:"expires_in"`
		Scope        string `json:"scope"`
		IDToken      string `json:"id_token"`
		Error        string `json:"error"`
		Description  string `json:"error_description"`
	}{
		AccessToken:  tokenResp.AccessToken,
		RefreshToken: tokenResp.RefreshToken,
		TokenType:    tokenResp.TokenType,
		ExpiresIn:    tokenResp.ExpiresIn,
		Scope:        tokenResp.Scope,
		IDToken:      tokenResp.IDToken,
		Error:        tokenResp.Error,
		Description:  tokenResp.Description,
	}, verification)

	utils.Log.WithFields(logrus.Fields{
		"user_id": verification.UserID,
	}).Info("sso refresh successful")
	return c.Status(fiber.StatusOK).JSON(fiber.Map{"status": "ok"})
}
// Callback handles the redirect back from the central SSO.
//
// It loads the PKCE session stored under ?state= and deletes it when the
// handler returns. It exchanges ?code= plus the stored code verifier for
// tokens, verifies the access token and sets the SSO cookies. It then
// redirects (302) to the return target saved by Start, or "/" when none was saved.
func (h *Controller) Callback(c *fiber.Ctx) error {
	state := strings.TrimSpace(c.Query("state"))
	code := strings.TrimSpace(c.Query("code"))

	// RFC 6749 §4.1.2.1: a denied or failed authorization comes back with an
	// "error" parameter instead of a code. Report it as such rather than as a
	// malformed callback.
	if code == "" {
		if providerErr := strings.TrimSpace(c.Query("error")); providerErr != "" {
			utils.Log.WithFields(logrus.Fields{
				"error":             providerErr,
				"error_description": strings.TrimSpace(c.Query("error_description")),
			}).Warn("sso authorization returned an error")
			return fiber.NewError(fiber.StatusBadRequest, "authorization failed")
		}
	}
	if state == "" || code == "" {
		return fiber.NewError(fiber.StatusBadRequest, "missing code or state")
	}

	sessionData, err := h.store.Get(c.Context(), state)
	if err != nil {
		utils.Log.Errorf("load pkce session failed: %v", err)
		return fiber.NewError(fiber.StatusInternalServerError, "failed to validate authorization state")
	}
	if sessionData == nil {
		return fiber.NewError(fiber.StatusBadRequest, "authorization state not found or expired")
	}
	// The state is single-use. The delete uses a background context so it does
	// not depend on the request context.
	defer func() {
		if err := h.store.Delete(context.Background(), state); err != nil {
			utils.Log.Warnf("failed to delete pkce session: %v", err)
		}
	}()

	tokenEndpoint := strings.TrimSpace(config.SSOTokenURL)
	if tokenEndpoint == "" {
		return fiber.NewError(fiber.StatusInternalServerError, "token endpoint not configured")
	}

	form := url.Values{}
	form.Set("grant_type", "authorization_code")
	form.Set("code", code)
	form.Set("code_verifier", sessionData.CodeVerifier)
	form.Set("redirect_uri", sessionData.RedirectURI)
	form.Set("client_id", sessionData.ClientID)
	req, err := http.NewRequestWithContext(c.Context(), http.MethodPost, tokenEndpoint, strings.NewReader(form.Encode()))
	if err != nil {
		return fiber.NewError(fiber.StatusInternalServerError, "failed to create token request")
	}
	req.Header.Set("Content-Type", "application/x-www-form-urlencoded")

	resp, err := h.httpClient.Do(req)
	if err != nil {
		utils.Log.Errorf("token request failed: %v", err)
		return fiber.NewError(fiber.StatusBadGateway, "failed to exchange authorization code")
	}
	defer resp.Body.Close()

	if resp.StatusCode >= 400 {
		utils.Log.Warnf("token response status %d", resp.StatusCode)
		return fiber.NewError(fiber.StatusBadGateway, "token exchange rejected")
	}

	var tokenResp struct {
		AccessToken  string `json:"access_token"`
		RefreshToken string `json:"refresh_token"`
		TokenType    string `json:"token_type"`
		ExpiresIn    int    `json:"expires_in"`
		Scope        string `json:"scope"`
		IDToken      string `json:"id_token"`
		Error        string `json:"error"`
		Description  string `json:"error_description"`
	}
	if err := json.NewDecoder(resp.Body).Decode(&tokenResp); err != nil {
		return fiber.NewError(fiber.StatusBadGateway, "invalid token response")
	}
	if tokenResp.Error != "" {
		// error_description is optional; never return an empty message.
		message := tokenResp.Description
		if message == "" {
			message = tokenResp.Error
		}
		return fiber.NewError(fiber.StatusBadGateway, message)
	}
	if tokenResp.AccessToken == "" {
		return fiber.NewError(fiber.StatusBadGateway, "missing access token")
	}

	// NOTE(review): sessionData.Nonce is stored by Start but never compared with
	// the nonce claim of tokenResp.IDToken here. Confirm whether the id_token is
	// validated elsewhere before relying on it.
	verification, err := sso.VerifyAccessToken(tokenResp.AccessToken)
	if err != nil {
		utils.Log.Errorf("access token verification failed: %v", err)
		return fiber.NewError(fiber.StatusUnauthorized, "invalid access token")
	}

	issueCookies(c, tokenResp, verification)

	redirectTarget := sessionData.ReturnTo
	if redirectTarget == "" {
		redirectTarget = "/"
	}
	utils.Log.WithFields(logrus.Fields{
		"client":    sessionData.ClientAlias,
		"user_id":   verification.UserID,
		"return_to": redirectTarget,
	}).Info("sso callback successful")
	return c.Redirect(redirectTarget, fiber.StatusFound)
}
// UserInfo proxies the user profile from the central SSO (config.SSOGetMeURL)
// so the frontend can obtain user metadata (roles, permissions, etc.) without
// handling tokens itself.
//
// The access token comes from the SSO access cookie or, when that is absent,
// from an "Authorization: Bearer" header. The token must not be in the
// revocation store and must pass sso.VerifyAccessToken, otherwise 401 is
// returned. The upstream status code, Content-Type and body are relayed unchanged.
func (h *Controller) UserInfo(c *fiber.Ctx) error {
	// Read the cookie under the same name issueCookies writes it. When no name
	// is configured, the previous fallback name "sso_access" is still accepted.
	accessName := resolveSSOCookieName(config.SSOAccessCookieName, "access")
	token := strings.TrimSpace(c.Cookies(accessName))
	if token == "" && strings.TrimSpace(config.SSOAccessCookieName) == "" {
		if legacy := strings.TrimSpace(c.Cookies("sso_access")); legacy != "" {
			accessName, token = "sso_access", legacy
		}
	}
	tokenFromCookie := token != ""
	if !tokenFromCookie {
		const bearerPrefix = "bearer "
		authHeader := strings.TrimSpace(c.Get("Authorization"))
		if len(authHeader) > len(bearerPrefix) && strings.EqualFold(authHeader[:len(bearerPrefix)], bearerPrefix) {
			token = strings.TrimSpace(authHeader[len(bearerPrefix):])
		}
	}
	if token == "" {
		return fiber.NewError(fiber.StatusUnauthorized, "unauthenticated")
	}

	// Fail closed: a revocation-store error is treated as unauthenticated.
	if revoker := session.GetRevocationStore(); revoker != nil {
		if fingerprint := session.TokenFingerprint(token); fingerprint != "" {
			revoked, err := revoker.IsRevoked(c.Context(), fingerprint)
			if err != nil {
				utils.Log.WithError(err).Warn("failed to check token revocation for userinfo")
				return fiber.NewError(fiber.StatusUnauthorized, "unauthenticated")
			}
			if revoked {
				return fiber.NewError(fiber.StatusUnauthorized, "unauthenticated")
			}
		}
	}
	if _, err := sso.VerifyAccessToken(token); err != nil {
		utils.Log.WithError(err).Warn("access token verification failed for userinfo")
		return fiber.NewError(fiber.StatusUnauthorized, "unauthenticated")
	}

	endpoint := strings.TrimSpace(config.SSOGetMeURL)
	if endpoint == "" {
		return fiber.NewError(fiber.StatusInternalServerError, "userinfo endpoint not configured")
	}
	req, err := http.NewRequestWithContext(c.Context(), http.MethodGet, endpoint, nil)
	if err != nil {
		utils.Log.Errorf("failed to build userinfo request: %v", err)
		return fiber.NewError(fiber.StatusInternalServerError, "failed to prepare userinfo request")
	}
	req.Header.Set("Accept", "application/json")
	// SSO /auth/get-me expects the access cookie; add Authorization as well for compatibility.
	req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token))
	if tokenFromCookie {
		req.Header.Set("Cookie", fmt.Sprintf("%s=%s", accessName, token))
	}

	resp, err := h.httpClient.Do(req)
	if err != nil {
		utils.Log.Errorf("userinfo request failed: %v", err)
		return fiber.NewError(fiber.StatusBadGateway, "failed to fetch user profile")
	}
	defer resp.Body.Close()
	utils.Log.WithFields(logrus.Fields{"status": resp.StatusCode}).Info("sso userinfo response")

	body, err := io.ReadAll(resp.Body)
	if err != nil {
		utils.Log.Errorf("failed to read userinfo response: %v", err)
		return fiber.NewError(fiber.StatusBadGateway, "invalid user profile response")
	}
	// Payload sanitizing is currently disabled; the upstream body is relayed as-is.
	// if sanitized, perms, ok := sanitizeUserInfoPayload(body); ok {
	// 	if caps := capabilities.FromPermissions(perms); len(caps) > 0 {
	// 		injectCapabilities(sanitized, caps)
	// 	}
	// 	return c.Status(resp.StatusCode).JSON(sanitized)
	// }
	if ct := resp.Header.Get("Content-Type"); ct != "" {
		c.Set("Content-Type", ct)
	} else {
		c.Type("json")
	}
	return c.Status(resp.StatusCode).Send(body)
}
// Logout signs the user out locally.
//
// It deletes any PKCE session named by ?state=, marks the verified user as
// logged out, revokes the presented access and refresh tokens and clears both
// SSO cookies. A request with no SSO cookie and no state is rejected with 401.
//
// It then redirects when a target can be determined: the requested client's
// return_to or default return URI, a client inferred from the token or the
// SSO profile, the SSO portal, or a same-origin path. Otherwise it responds
// with {"status":"signed out"}.
func (h *Controller) Logout(c *fiber.Ctx) error {
	requestedAlias := normalizeClientParam(c.Query("client"))
	if requestedAlias == "" {
		requestedAlias = normalizeClientParam(c.Query("client_id"))
	}
	var (
		alias         string
		cfg           config.SSOClientConfig
		hasClientInfo bool
	)
	if requestedAlias != "" {
		alias, cfg, hasClientInfo = findSSOClientConfig(requestedAlias)
	}

	accessName := resolveSSOCookieName(config.SSOAccessCookieName, "access")
	refreshName := resolveSSOCookieName(config.SSORefreshCookieName, "refresh")
	var accessToken, refreshToken string
	if accessName != "" {
		accessToken = strings.TrimSpace(c.Cookies(accessName))
	}
	if refreshName != "" {
		refreshToken = strings.TrimSpace(c.Cookies(refreshName))
	}

	state := strings.TrimSpace(c.Query("state"))
	if state != "" {
		if err := h.store.Delete(c.Context(), state); err != nil {
			utils.Log.Warnf("failed to delete pkce session during logout: %v", err)
		}
	}
	if accessToken == "" && refreshToken == "" && state == "" {
		return fiber.NewError(fiber.StatusUnauthorized, "not authenticated")
	}

	var verification *sso.VerificationResult
	if accessToken != "" {
		if v, err := sso.VerifyAccessToken(accessToken); err != nil {
			utils.Log.WithError(err).Warn("failed to verify access token during logout")
		} else {
			verification = v
			if revoker := session.GetRevocationStore(); revoker != nil {
				if err := revoker.MarkUserLogout(c.Context(), verification.UserID, time.Now().UTC()); err != nil {
					utils.Log.WithError(err).Warn("failed to mark user logout")
				}
			}
			h.revokeToken(c.Context(), accessToken, verification)
		}
	}
	if refreshToken != "" {
		h.revokeRefreshToken(c.Context(), refreshToken)
	}
	clearSSOCookie(c, accessName)
	clearSSOCookie(c, refreshName)

	redirectTarget := ""
	rawReturn := strings.TrimSpace(c.Query("return_to"))
	switch {
	case hasClientInfo:
		if rawReturn == "" {
			rawReturn = cfg.DefaultReturnURI
		}
		if normalized, err := normalizeReturnTarget(rawReturn, cfg); err == nil {
			redirectTarget = normalized
		} else if rawReturn != "" {
			utils.Log.WithError(err).Warn("invalid return_to during logout")
		}
	case rawReturn == "" && config.SSOPortalURL != "":
		// No client named: use the portal unless exactly one client can be
		// inferred from the token scopes or, failing that, from the SSO profile.
		// Distinct names avoid shadowing the outer alias, which is logged below.
		redirectTarget = config.SSOPortalURL
		inferredAlias, inferredCfg, ok := singleClientFromToken(verification)
		if !ok && accessToken != "" {
			inferredAlias, inferredCfg, ok = h.singleClientFromSSO(c.Context(), accessToken)
		}
		if ok {
			if normalized, err := normalizeReturnTarget(inferredCfg.DefaultReturnURI, inferredCfg); err == nil && normalized != "" {
				redirectTarget = normalized
				alias = inferredAlias
			}
		}
	case rawReturn != "":
		// Without a client config only same-origin paths are accepted. Reject
		// "//host", backslashes and tab/CR/LF: browsers treat "\" as "/" and
		// strip tab/CR/LF, so "/\evil.example" would leave the origin.
		if strings.HasPrefix(rawReturn, "/") && !strings.HasPrefix(rawReturn, "//") && !strings.ContainsAny(rawReturn, "\\\t\r\n") {
			redirectTarget = rawReturn
		}
	}

	utils.Log.WithFields(logrus.Fields{
		"client":   alias,
		"state":    state,
		"redirect": redirectTarget,
	}).Info("sso logout completed")
	if redirectTarget != "" {
		return c.Redirect(redirectTarget, fiber.StatusFound)
	}
	return c.Status(fiber.StatusOK).JSON(fiber.Map{"status": "signed out"})
}
// singleSSOClient returns the only configured SSO client. It reports false
// unless exactly one client is configured and that client has both a
// non-blank alias and a non-blank public ID.
func singleSSOClient() (string, config.SSOClientConfig, bool) {
	if len(config.SSOClients) == 1 {
		for key, client := range config.SSOClients {
			valid := strings.TrimSpace(key) != "" && strings.TrimSpace(client.PublicID) != ""
			if valid {
				return key, client, true
			}
		}
	}
	return "", config.SSOClientConfig{}, false
}
// singleClientFromToken infers the client from the scopes of a verified token.
// A nil verification, or one without claims, yields false.
func singleClientFromToken(verification *sso.VerificationResult) (string, config.SSOClientConfig, bool) {
	if verification != nil && verification.Claims != nil {
		return singleClientFromScopes(verification.Claims.Scopes())
	}
	return "", config.SSOClientConfig{}, false
}
// singleClientFromSSO asks the central SSO profile endpoint (config.SSOGetMeURL)
// which clients the token holder has roles in. It returns that client only when
// the roles reference exactly one distinct (lower-cased) alias and that alias
// is configured locally. Request, status and decoding failures are logged and
// yield false.
func (h *Controller) singleClientFromSSO(ctx context.Context, accessToken string) (string, config.SSOClientConfig, bool) {
	none := func() (string, config.SSOClientConfig, bool) {
		return "", config.SSOClientConfig{}, false
	}

	token := strings.TrimSpace(accessToken)
	meURL := strings.TrimSpace(config.SSOGetMeURL)
	if token == "" || meURL == "" {
		return none()
	}

	req, err := http.NewRequestWithContext(ctx, http.MethodGet, meURL, nil)
	if err != nil {
		utils.Log.WithError(err).Warn("failed to build SSO getme request")
		return none()
	}
	req.Header.Set("Authorization", "Bearer "+token)

	resp, err := h.httpClient.Do(req)
	if err != nil {
		utils.Log.WithError(err).Warn("SSO getme request failed")
		return none()
	}
	defer resp.Body.Close()

	if resp.StatusCode < http.StatusOK || resp.StatusCode >= http.StatusMultipleChoices {
		utils.Log.WithField("status", resp.StatusCode).Warn("SSO getme responded with error")
		return none()
	}

	var profile struct {
		Data struct {
			Roles []struct {
				Client *struct {
					Alias string `json:"alias"`
				} `json:"client"`
			} `json:"roles"`
		} `json:"data"`
	}
	if err := json.NewDecoder(resp.Body).Decode(&profile); err != nil {
		utils.Log.WithError(err).Warn("failed to decode SSO getme response")
		return none()
	}

	only := ""
	for _, role := range profile.Data.Roles {
		if role.Client == nil {
			continue
		}
		candidate := strings.ToLower(strings.TrimSpace(role.Client.Alias))
		switch {
		case candidate == "", candidate == only:
			continue
		case only != "":
			// A second distinct client alias makes the answer ambiguous.
			return none()
		}
		only = candidate
	}
	if only == "" {
		return none()
	}
	return findClientAlias(only)
}
// singleClientFromScopes maps each scope to a configured client alias. It
// returns that client only when every recognised scope points at the same one.
// Empty input, unrecognised scopes only, or two different clients yield false.
func singleClientFromScopes(scopes []string) (string, config.SSOClientConfig, bool) {
	match := ""
	for _, scope := range scopes {
		alias, ok := matchClientAliasFromScope(scope)
		if !ok {
			continue
		}
		if match != "" && alias != match {
			return "", config.SSOClientConfig{}, false
		}
		match = alias
	}
	if match == "" {
		return "", config.SSOClientConfig{}, false
	}
	return findClientAlias(match)
}
// matchClientAliasFromScope takes the lower-cased part of a scope before its
// first "." or ":" (e.g. "lti.read" -> "lti") and resolves it to a configured
// client alias. "user-management" and "umgmt" are tried as synonyms when the
// literal prefix is not itself configured.
func matchClientAliasFromScope(scope string) (string, bool) {
	prefix := strings.ToLower(strings.TrimSpace(scope))
	if idx := strings.IndexAny(prefix, ".:"); idx > 0 {
		prefix = prefix[:idx]
	}
	if prefix == "" {
		return "", false
	}

	candidates := []string{prefix}
	switch prefix {
	case "user-management":
		candidates = append(candidates, "umgmt")
	case "umgmt":
		candidates = append(candidates, "user-management")
	}
	for _, candidate := range candidates {
		if alias, _, ok := findClientAlias(candidate); ok {
			return alias, true
		}
	}
	return "", false
}
// findClientAlias resolves alias (trimmed) to a configured client that has a
// public ID. An exact key match is tried first, then a case-insensitive one.
// The returned alias is the configuration key.
func findClientAlias(alias string) (string, config.SSOClientConfig, bool) {
	wanted := strings.TrimSpace(alias)
	if wanted != "" {
		hasID := func(client config.SSOClientConfig) bool {
			return strings.TrimSpace(client.PublicID) != ""
		}
		if client, ok := config.SSOClients[wanted]; ok && hasID(client) {
			return wanted, client, true
		}
		for key, client := range config.SSOClients {
			if hasID(client) && strings.EqualFold(key, wanted) {
				return key, client, true
			}
		}
	}
	return "", config.SSOClientConfig{}, false
}
// defaultSSOClientAlias picks the client that Refresh sends the browser to
// when the login flow has to be restarted. It returns "" when no non-blank
// alias is configured.
//
// Map iteration order is random, so the choice is made deterministic: the
// lexicographically smallest alias wins. Clients with a public ID are
// preferred, because Start rejects clients without one.
func defaultSSOClientAlias() string {
	best, bestHasID := "", false
	for alias, cfg := range config.SSOClients {
		if strings.TrimSpace(alias) == "" {
			continue
		}
		hasID := strings.TrimSpace(cfg.PublicID) != ""
		better := best == "" ||
			(hasID && !bestHasID) ||
			(hasID == bestHasID && alias < best)
		if better {
			best, bestHasID = alias, hasID
		}
	}
	return best
}
// buildStartRedirect returns the relative URL of the SSO start endpoint for
// the given client alias (trimmed and query-escaped), or "" for a blank alias.
func buildStartRedirect(alias string) string {
	trimmed := strings.TrimSpace(alias)
	if trimmed == "" {
		return ""
	}
	query := url.Values{}
	query.Set("client", trimmed)
	return "/api/sso/start?" + query.Encode()
}
// revokeToken records the access token's fingerprint in the revocation store
// for the token's remaining lifetime, with a minimum of one second.
//
// It does nothing when there is no revocation store, no verified claims, no
// fingerprint, or the token has already expired. A token without an expiry
// claim is logged and skipped.
func (h *Controller) revokeToken(ctx context.Context, token string, verification *sso.VerificationResult) {
	if h.revoker == nil || verification == nil || verification.Claims == nil {
		return
	}
	fp := session.TokenFingerprint(token)
	if fp == "" {
		return
	}

	expiry := verification.Claims.ExpiresAt
	if expiry == nil {
		utils.Log.Warn("access token missing expiry claim")
		return
	}
	remaining := time.Until(expiry.Time)
	switch {
	case remaining <= 0:
		return
	case remaining < time.Second:
		remaining = time.Second
	}

	if err := h.revoker.Revoke(ctx, fp, remaining); err != nil {
		utils.Log.WithError(err).Warn("failed to revoke access token")
	}
}
// revokeRefreshToken records the refresh token's fingerprint in the revocation
// store for 30 days. It is a no-op without a store or fingerprint, and store
// failures are only logged.
func (h *Controller) revokeRefreshToken(ctx context.Context, token string) {
	const refreshTTL = 30 * 24 * time.Hour

	if h.revoker == nil {
		return
	}
	fp := session.TokenFingerprint(token)
	if fp == "" {
		return
	}
	if err := h.revoker.Revoke(ctx, fp, refreshTTL); err != nil {
		utils.Log.WithError(err).Warn("failed to revoke refresh token")
	}
}
// issueCookies writes the access token, and the refresh token when one was
// returned, into HTTP-only SSO cookies. It also clears any logout marker
// recorded for the verified user in the revocation store.
//
// Cookie lifetimes and attributes:
//   - Access cookie MaxAge is tokenResp.ExpiresIn, or 15 minutes when that is
//     not positive.
//   - Refresh cookie MaxAge is 30 days.
//   - SameSite defaults to "Lax" when not configured.
//
// The user id is exposed in the X-Auth-User header when verification is non-nil.
func issueCookies(c *fiber.Ctx, tokenResp struct {
	AccessToken  string `json:"access_token"`
	RefreshToken string `json:"refresh_token"`
	TokenType    string `json:"token_type"`
	ExpiresIn    int    `json:"expires_in"`
	Scope        string `json:"scope"`
	IDToken      string `json:"id_token"`
	Error        string `json:"error"`
	Description  string `json:"error_description"`
}, verification *sso.VerificationResult) {
	if revoker := session.GetRevocationStore(); revoker != nil && verification != nil {
		if err := revoker.ClearUserLogout(c.Context(), verification.UserID); err != nil {
			utils.Log.WithError(err).Warn("failed to clear logout marker")
		}
	}

	sameSite := config.SSOCookieSameSite
	if sameSite == "" {
		sameSite = "Lax"
	}
	// authCookie builds an HTTP-only cookie carrying the shared SSO attributes.
	authCookie := func(name, value string, maxAge int) *fiber.Cookie {
		return &fiber.Cookie{
			Name:     name,
			Value:    value,
			Path:     "/",
			Domain:   config.SSOCookieDomain,
			HTTPOnly: true,
			Secure:   config.SSOCookieSecure,
			SameSite: sameSite,
			MaxAge:   maxAge,
		}
	}

	accessMaxAge := tokenResp.ExpiresIn
	if accessMaxAge <= 0 {
		accessMaxAge = int(15 * time.Minute.Seconds())
	}
	accessName := resolveSSOCookieName(config.SSOAccessCookieName, "access")
	c.Cookie(authCookie(accessName, tokenResp.AccessToken, accessMaxAge))

	if tokenResp.RefreshToken != "" {
		const refreshLifetime = 30 * 24 * time.Hour
		refreshName := resolveSSOCookieName(config.SSORefreshCookieName, "refresh")
		c.Cookie(authCookie(refreshName, tokenResp.RefreshToken, int(refreshLifetime.Seconds())))
	}

	// Debug aid for the frontend; never expose the tokens themselves. Guarded
	// because verification is treated as optional above; dereferencing a nil
	// value here would panic.
	if verification != nil {
		c.Set("X-Auth-User", fmt.Sprintf("%d", verification.UserID))
	}
}
// clearSSOCookie expires the named cookie by writing an empty value with
// MaxAge -1 and an epoch expiry. It uses path "/", the configured domain,
// Secure flag and SameSite (default "Lax"). A blank name is ignored.
func clearSSOCookie(c *fiber.Ctx, name string) {
	if name == "" {
		return
	}
	sameSite := config.SSOCookieSameSite
	if sameSite == "" {
		sameSite = "Lax"
	}
	expired := fiber.Cookie{
		Name:     name,
		Value:    "",
		Path:     "/",
		Domain:   config.SSOCookieDomain,
		HTTPOnly: true,
		Secure:   config.SSOCookieSecure,
		SameSite: sameSite,
		Expires:  time.Unix(0, 0),
		MaxAge:   -1,
	}
	c.Cookie(&expired)
}
// resolveSSOCookieName returns the trimmed configured cookie name. When that is
// blank it returns the trimmed fallback, which may itself be "".
func resolveSSOCookieName(configuredName, fallback string) string {
	for _, candidate := range [...]string{configuredName, fallback} {
		if name := strings.TrimSpace(candidate); name != "" {
			return name
		}
	}
	return ""
}
// normalizeClientParam turns a ?client= / ?client_id= value into a lookup key.
// Everything from the first "|" onward is dropped, surrounding whitespace is
// trimmed, and the result is lower-cased.
func normalizeClientParam(raw string) string {
	head, _, _ := strings.Cut(raw, "|")
	return strings.ToLower(strings.TrimSpace(head))
}
// sanitizeUserInfoPayload parses a JSON userinfo body and collects the
// permission names it contains. It then strips every "roles" and "permissions"
// key, case-insensitively and at any depth.
//
// A non-object top-level value is wrapped as {"data": value}. An empty body
// yields an empty map, and invalid JSON yields ok=false.
func sanitizeUserInfoPayload(body []byte) (map[string]any, []string, bool) {
	if len(body) == 0 {
		return map[string]any{}, nil, true
	}
	var decoded any
	if err := json.Unmarshal(body, &decoded); err != nil {
		return nil, nil, false
	}

	// Permission names must be read before scrubbing removes them.
	perms := collectPermissionNames(decoded)
	cleaned := scrubSensitiveKeys(decoded, map[string]struct{}{
		"roles":       {},
		"permissions": {},
	})

	if obj, isObject := cleaned.(map[string]any); isObject {
		return obj, perms, true
	}
	return map[string]any{"data": cleaned}, perms, true
}
// scrubSensitiveKeys recursively deletes, in place, every map key whose
// lower-cased form is in sensitive. Maps and slices are modified and returned;
// any other value is returned unchanged.
func scrubSensitiveKeys(value any, sensitive map[string]struct{}) any {
	if obj, ok := value.(map[string]any); ok {
		for key := range obj {
			if _, drop := sensitive[strings.ToLower(key)]; drop {
				delete(obj, key)
			} else {
				obj[key] = scrubSensitiveKeys(obj[key], sensitive)
			}
		}
		return obj
	}
	if list, ok := value.([]any); ok {
		for i := range list {
			list[i] = scrubSensitiveKeys(list[i], sensitive)
		}
		return list
	}
	return value
}
// collectPermissionNames returns the distinct, lower-cased permission names
// found anywhere in decoded JSON. The order of the result is unspecified.
func collectPermissionNames(value any) []string {
	seen := map[string]struct{}{}
	collectPermissionRec(value, seen)

	names := make([]string, 0, len(seen))
	for name := range seen {
		names = append(names, name)
	}
	return names
}
// collectPermissionRec walks decoded JSON (maps and slices). For every key
// equal to "permissions" (case-insensitive) whose value is an array, it adds
// the trimmed, lower-cased "name" string of each object element to acc.
// Values under a "permissions" key are not searched any further.
func collectPermissionRec(value any, acc map[string]struct{}) {
	if list, ok := value.([]any); ok {
		for _, item := range list {
			collectPermissionRec(item, acc)
		}
		return
	}
	obj, ok := value.(map[string]any)
	if !ok {
		return
	}
	for key, val := range obj {
		if !strings.EqualFold(key, "permissions") {
			collectPermissionRec(val, acc)
			continue
		}
		entries, _ := val.([]any)
		for _, entry := range entries {
			perm, isObj := entry.(map[string]any)
			if !isObj {
				continue
			}
			name, isStr := perm["name"].(string)
			if !isStr {
				continue
			}
			if name = strings.TrimSpace(name); name != "" {
				acc[strings.ToLower(name)] = struct{}{}
			}
		}
	}
}
// injectCapabilities stores caps under "capabilities". The key goes into
// payload["data"] when that is an object, otherwise into the top level of
// payload. Empty caps leave payload untouched.
func injectCapabilities(payload map[string]any, caps map[string]bool) {
	if len(caps) == 0 {
		return
	}
	target := payload
	if data, ok := payload["data"].(map[string]any); ok {
		target = data
	}
	target["capabilities"] = caps
}
// findSSOClientConfig resolves a ?client= value to a configured SSO client
// that has a public ID. The returned string is the configuration key.
//
// Lookup order:
//  1. exact configuration key;
//  2. public ID, case-insensitive;
//  3. configuration key, case-insensitive.
//
// Step 3 is needed because normalizeClientParam lower-cases the request
// value; without it, keys containing upper-case letters could never be
// selected by alias.
func findSSOClientConfig(requestedAlias string) (string, config.SSOClientConfig, bool) {
	if requestedAlias == "" {
		return "", config.SSOClientConfig{}, false
	}
	if cfg, ok := config.SSOClients[requestedAlias]; ok && strings.TrimSpace(cfg.PublicID) != "" {
		return requestedAlias, cfg, true
	}
	for alias, cfg := range config.SSOClients {
		if publicID := strings.TrimSpace(cfg.PublicID); publicID != "" && strings.EqualFold(publicID, requestedAlias) {
			return alias, cfg, true
		}
	}
	for alias, cfg := range config.SSOClients {
		if strings.TrimSpace(cfg.PublicID) != "" && strings.EqualFold(alias, requestedAlias) {
			return alias, cfg, true
		}
	}
	return "", config.SSOClientConfig{}, false
}
// normalizeReturnTarget validates a post-login/logout redirect target for cfg.
//
// Rules:
//   - An empty value yields "".
//   - A same-origin path ("/x") is returned unchanged, as long as a browser
//     cannot reinterpret it as a protocol-relative URL. "//host", any
//     backslash, and tab/CR/LF are rejected.
//   - An absolute URL must be http(s) and name a host.
//   - When cfg declares allowed origins (the origin of DefaultReturnURI plus
//     AllowedReturnOrigins), the URL's scheme://host must be one of them.
//     Hosts are compared case-insensitively.
func normalizeReturnTarget(returnTo string, cfg config.SSOClientConfig) (string, error) {
	returnTo = strings.TrimSpace(returnTo)
	if returnTo == "" {
		return "", nil
	}
	// Browsers treat "\" like "/" and silently drop tab/CR/LF, so values such
	// as "/\evil.example" or "/\t/evil.example" would become "//evil.example".
	// That is a protocol-relative URL pointing at another origin.
	if strings.HasPrefix(returnTo, "//") || strings.ContainsAny(returnTo, "\\\t\r\n") {
		return "", fmt.Errorf("invalid return_to")
	}
	if strings.HasPrefix(returnTo, "/") {
		return returnTo, nil
	}

	parsed, err := url.Parse(returnTo)
	if err != nil {
		return "", fmt.Errorf("invalid return_to")
	}
	if parsed.Scheme != "http" && parsed.Scheme != "https" {
		return "", fmt.Errorf("invalid return_to scheme")
	}
	// Absolute targets must name a host explicitly (rejects e.g. "https:evil.example").
	if parsed.Host == "" {
		return "", fmt.Errorf("invalid return_to")
	}

	allowedOrigins := make(map[string]struct{})
	addOrigin := func(u *url.URL) {
		allowedOrigins[u.Scheme+"://"+strings.ToLower(u.Host)] = struct{}{}
	}
	if cfg.DefaultReturnURI != "" {
		if u, err := url.Parse(cfg.DefaultReturnURI); err == nil && u.Host != "" {
			addOrigin(u)
		}
	}
	for _, origin := range cfg.AllowedReturnOrigins {
		origin = strings.TrimSpace(origin)
		if origin == "" {
			continue
		}
		if u, err := url.Parse(origin); err == nil && u.Host != "" && (u.Scheme == "http" || u.Scheme == "https") {
			addOrigin(u)
		}
	}
	if len(allowedOrigins) > 0 {
		origin := parsed.Scheme + "://" + strings.ToLower(parsed.Host)
		if _, ok := allowedOrigins[origin]; !ok {
			return "", fmt.Errorf("return_to origin not allowed")
		}
	}
	// NOTE(review): when the client declares no allowed origins, any http(s)
	// URL is accepted here; consider failing closed instead.
	return parsed.String(), nil
}