mirror of
https://gitlab.com/mbugroup/lti-api.git
synced 2026-05-20 13:31:56 +00:00
32772a63c8
This reverts commit 26bf7f165e.
991 lines
29 KiB
Go
991 lines
29 KiB
Go
package controllers
|
|
|
|
import (
|
|
"context"
|
|
"crypto/sha256"
|
|
"encoding/json"
|
|
"fmt"
|
|
"io"
|
|
"net/http"
|
|
"net/url"
|
|
"strings"
|
|
"time"
|
|
|
|
"github.com/gofiber/fiber/v2"
|
|
"github.com/sirupsen/logrus"
|
|
|
|
"gitlab.com/mbugroup/lti-api.git/internal/config"
|
|
"gitlab.com/mbugroup/lti-api.git/internal/modules/sso/session"
|
|
sso "gitlab.com/mbugroup/lti-api.git/internal/modules/sso/verifier"
|
|
"gitlab.com/mbugroup/lti-api.git/internal/utils"
|
|
"gitlab.com/mbugroup/lti-api.git/internal/utils/secure"
|
|
)
|
|
|
|
// Controller manages the SSO start & callback flow using PKCE.
|
|
type Controller struct {
|
|
httpClient *http.Client
|
|
store *session.Store
|
|
revoker *session.RevocationStore
|
|
}
|
|
|
|
func NewController(client *http.Client, store *session.Store, revoker *session.RevocationStore) *Controller {
|
|
return &Controller{httpClient: client, store: store, revoker: revoker}
|
|
}
|
|
|
|
// Start handles GET /sso/start requests and redirects users to the central SSO authorize endpoint.
|
|
func (h *Controller) Start(c *fiber.Ctx) error {
|
|
requestedAlias := normalizeClientParam(c.Query("client"))
|
|
if requestedAlias == "" {
|
|
requestedAlias = normalizeClientParam(c.Query("client_id"))
|
|
}
|
|
if requestedAlias == "" {
|
|
return fiber.NewError(fiber.StatusBadRequest, "missing client")
|
|
}
|
|
|
|
alias, cfg, ok := findSSOClientConfig(requestedAlias)
|
|
if !ok || cfg.PublicID == "" {
|
|
return fiber.NewError(fiber.StatusBadRequest, "unknown client")
|
|
}
|
|
|
|
authorizeEndpoint := strings.TrimSpace(config.SSOAuthorizeURL)
|
|
if authorizeEndpoint == "" {
|
|
return fiber.NewError(fiber.StatusInternalServerError, "authorize endpoint not configured")
|
|
}
|
|
|
|
state, err := secure.RandomString(48)
|
|
if err != nil {
|
|
utils.Log.Errorf("generate state failed: %v", err)
|
|
return fiber.NewError(fiber.StatusInternalServerError, "failed to prepare authorization")
|
|
}
|
|
|
|
nonce, err := secure.RandomString(32)
|
|
if err != nil {
|
|
utils.Log.Errorf("generate nonce failed: %v", err)
|
|
return fiber.NewError(fiber.StatusInternalServerError, "failed to prepare authorization")
|
|
}
|
|
|
|
codeVerifier, err := secure.PKCECodeVerifier(96)
|
|
if err != nil {
|
|
utils.Log.Errorf("generate code verifier failed: %v", err)
|
|
return fiber.NewError(fiber.StatusInternalServerError, "failed to prepare authorization")
|
|
}
|
|
|
|
digest := sha256.Sum256([]byte(codeVerifier))
|
|
challenge := secure.Base64URLEncode(digest[:])
|
|
|
|
authorizeURL, err := url.Parse(authorizeEndpoint)
|
|
if err != nil {
|
|
return fiber.NewError(fiber.StatusInternalServerError, "invalid authorize endpoint")
|
|
}
|
|
|
|
scope := cfg.Scope
|
|
if scope == "" {
|
|
scope = "openid profile"
|
|
}
|
|
if !strings.Contains(" "+scope+" ", " openid ") {
|
|
scope = scope + " openid"
|
|
}
|
|
|
|
rawReturn := strings.TrimSpace(c.Query("return_to"))
|
|
if rawReturn == "" {
|
|
rawReturn = cfg.DefaultReturnURI
|
|
}
|
|
|
|
returnTo, err := normalizeReturnTarget(rawReturn, cfg)
|
|
if err != nil {
|
|
return fiber.NewError(fiber.StatusBadRequest, err.Error())
|
|
}
|
|
|
|
query := authorizeURL.Query()
|
|
query.Set("response_type", "code")
|
|
query.Set("client_id", cfg.PublicID)
|
|
query.Set("redirect_uri", cfg.RedirectURI)
|
|
query.Set("scope", strings.TrimSpace(scope))
|
|
query.Set("state", state)
|
|
query.Set("code_challenge", challenge)
|
|
query.Set("code_challenge_method", "S256")
|
|
query.Set("nonce", nonce)
|
|
// if prompt := strings.TrimSpace(cfg.Prompt); prompt != "" {
|
|
// query.Set("prompt", prompt)
|
|
// }
|
|
if extraPrompt := strings.TrimSpace(c.Query("prompt")); extraPrompt != "" {
|
|
query.Set("prompt", extraPrompt)
|
|
}
|
|
authorizeURL.RawQuery = query.Encode()
|
|
|
|
payload := &session.PKCESession{
|
|
CodeVerifier: codeVerifier,
|
|
Nonce: nonce,
|
|
ClientAlias: alias,
|
|
ClientID: cfg.PublicID,
|
|
RedirectURI: cfg.RedirectURI,
|
|
Scope: strings.TrimSpace(scope),
|
|
ReturnTo: returnTo,
|
|
CreatedAt: time.Now().UTC(),
|
|
}
|
|
|
|
if err := h.store.Save(c.Context(), state, payload); err != nil {
|
|
utils.Log.Errorf("store pkce session failed: %v", err)
|
|
return fiber.NewError(fiber.StatusInternalServerError, "failed to prepare authorization")
|
|
}
|
|
|
|
utils.Log.WithFields(logrus.Fields{
|
|
"client": alias,
|
|
"state": state,
|
|
"return_to": returnTo,
|
|
}).Info("sso start redirect")
|
|
|
|
return c.Redirect(authorizeURL.String(), fiber.StatusFound)
|
|
}
|
|
|
|
// Refresh exchanges the current SSO refresh token for a new access/refresh pair
|
|
// without redirecting the browser to the SSO login page.
|
|
func (h *Controller) Refresh(c *fiber.Ctx) error {
|
|
refreshName := resolveSSOCookieName(config.SSORefreshCookieName, "refresh")
|
|
refreshToken := strings.TrimSpace(c.Cookies(refreshName))
|
|
if refreshToken == "" {
|
|
if target := buildStartRedirect(defaultSSOClientAlias()); target != "" {
|
|
return c.Redirect(target, fiber.StatusFound)
|
|
}
|
|
return fiber.NewError(fiber.StatusUnauthorized, "unauthenticated")
|
|
}
|
|
|
|
tokenEndpoint := strings.TrimSpace(config.SSOTokenURL)
|
|
if tokenEndpoint == "" {
|
|
return fiber.NewError(fiber.StatusInternalServerError, "token endpoint not configured")
|
|
}
|
|
|
|
form := url.Values{}
|
|
form.Set("grant_type", "refresh_token")
|
|
form.Set("refresh_token", refreshToken)
|
|
|
|
req, err := http.NewRequestWithContext(c.Context(), http.MethodPost, tokenEndpoint, strings.NewReader(form.Encode()))
|
|
if err != nil {
|
|
return fiber.NewError(fiber.StatusInternalServerError, "failed to create refresh request")
|
|
}
|
|
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
|
|
|
resp, err := h.httpClient.Do(req)
|
|
if err != nil {
|
|
utils.Log.Errorf("token refresh request failed: %v", err)
|
|
return fiber.NewError(fiber.StatusBadGateway, "failed to refresh access token")
|
|
}
|
|
defer resp.Body.Close()
|
|
|
|
if resp.StatusCode >= 400 {
|
|
utils.Log.Warnf("token refresh response status %d", resp.StatusCode)
|
|
if resp.StatusCode == fiber.StatusTooManyRequests {
|
|
return fiber.NewError(fiber.StatusTooManyRequests, "Too many attempts, please slow down")
|
|
}
|
|
if target := buildStartRedirect(defaultSSOClientAlias()); target != "" {
|
|
return c.Redirect(target, fiber.StatusFound)
|
|
}
|
|
return fiber.NewError(fiber.StatusUnauthorized, "unauthenticated")
|
|
}
|
|
|
|
var tokenResp refreshTokenResponse
|
|
if err := json.NewDecoder(resp.Body).Decode(&tokenResp); err != nil {
|
|
return fiber.NewError(fiber.StatusBadGateway, "invalid token response")
|
|
}
|
|
if tokenResp.Error != "" {
|
|
return fiber.NewError(fiber.StatusBadGateway, tokenResp.Description)
|
|
}
|
|
if tokenResp.AccessToken == "" {
|
|
return fiber.NewError(fiber.StatusBadGateway, "missing access token")
|
|
}
|
|
|
|
verification, err := sso.VerifyAccessToken(tokenResp.AccessToken)
|
|
if err != nil {
|
|
utils.Log.Errorf("access token verification failed: %v", err)
|
|
return fiber.NewError(fiber.StatusUnauthorized, "invalid access token")
|
|
}
|
|
|
|
issueCookies(c, struct {
|
|
AccessToken string `json:"access_token"`
|
|
RefreshToken string `json:"refresh_token"`
|
|
TokenType string `json:"token_type"`
|
|
ExpiresIn int `json:"expires_in"`
|
|
Scope string `json:"scope"`
|
|
IDToken string `json:"id_token"`
|
|
Error string `json:"error"`
|
|
Description string `json:"error_description"`
|
|
}{
|
|
AccessToken: tokenResp.AccessToken,
|
|
RefreshToken: tokenResp.RefreshToken,
|
|
TokenType: tokenResp.TokenType,
|
|
ExpiresIn: tokenResp.ExpiresIn,
|
|
Scope: tokenResp.Scope,
|
|
IDToken: tokenResp.IDToken,
|
|
Error: tokenResp.Error,
|
|
Description: tokenResp.Description,
|
|
}, verification)
|
|
|
|
utils.Log.WithFields(logrus.Fields{
|
|
"user_id": verification.UserID,
|
|
}).Info("sso refresh successful")
|
|
|
|
return c.Status(fiber.StatusOK).JSON(fiber.Map{"status": "ok"})
|
|
}
|
|
|
|
// Callback handles the redirect from SSO containing the authorization code.
|
|
func (h *Controller) Callback(c *fiber.Ctx) error {
|
|
state := strings.TrimSpace(c.Query("state"))
|
|
code := strings.TrimSpace(c.Query("code"))
|
|
if state == "" || code == "" {
|
|
return fiber.NewError(fiber.StatusBadRequest, "missing code or state")
|
|
}
|
|
|
|
sessionData, err := h.store.Get(c.Context(), state)
|
|
if err != nil {
|
|
utils.Log.Errorf("load pkce session failed: %v", err)
|
|
return fiber.NewError(fiber.StatusInternalServerError, "failed to validate authorization state")
|
|
}
|
|
if sessionData == nil {
|
|
return fiber.NewError(fiber.StatusBadRequest, "authorization state not found or expired")
|
|
}
|
|
defer func() {
|
|
if err := h.store.Delete(context.Background(), state); err != nil {
|
|
utils.Log.Warnf("failed to delete pkce session: %v", err)
|
|
}
|
|
}()
|
|
|
|
tokenEndpoint := strings.TrimSpace(config.SSOTokenURL)
|
|
if tokenEndpoint == "" {
|
|
return fiber.NewError(fiber.StatusInternalServerError, "token endpoint not configured")
|
|
}
|
|
|
|
form := url.Values{}
|
|
form.Set("grant_type", "authorization_code")
|
|
form.Set("code", code)
|
|
form.Set("code_verifier", sessionData.CodeVerifier)
|
|
form.Set("redirect_uri", sessionData.RedirectURI)
|
|
form.Set("client_id", sessionData.ClientID)
|
|
|
|
req, err := http.NewRequestWithContext(c.Context(), http.MethodPost, tokenEndpoint, strings.NewReader(form.Encode()))
|
|
if err != nil {
|
|
return fiber.NewError(fiber.StatusInternalServerError, "failed to create token request")
|
|
}
|
|
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
|
|
|
resp, err := h.httpClient.Do(req)
|
|
if err != nil {
|
|
utils.Log.Errorf("token request failed: %v", err)
|
|
return fiber.NewError(fiber.StatusBadGateway, "failed to exchange authorization code")
|
|
}
|
|
defer resp.Body.Close()
|
|
|
|
if resp.StatusCode >= 400 {
|
|
utils.Log.Warnf("token response status %d", resp.StatusCode)
|
|
return fiber.NewError(fiber.StatusBadGateway, "token exchange rejected")
|
|
}
|
|
|
|
var tokenResp struct {
|
|
AccessToken string `json:"access_token"`
|
|
RefreshToken string `json:"refresh_token"`
|
|
TokenType string `json:"token_type"`
|
|
ExpiresIn int `json:"expires_in"`
|
|
Scope string `json:"scope"`
|
|
IDToken string `json:"id_token"`
|
|
Error string `json:"error"`
|
|
Description string `json:"error_description"`
|
|
}
|
|
|
|
if err := json.NewDecoder(resp.Body).Decode(&tokenResp); err != nil {
|
|
return fiber.NewError(fiber.StatusBadGateway, "invalid token response")
|
|
}
|
|
if tokenResp.Error != "" {
|
|
return fiber.NewError(fiber.StatusBadGateway, tokenResp.Description)
|
|
}
|
|
if tokenResp.AccessToken == "" {
|
|
return fiber.NewError(fiber.StatusBadGateway, "missing access token")
|
|
}
|
|
|
|
verification, err := sso.VerifyAccessToken(tokenResp.AccessToken)
|
|
if err != nil {
|
|
utils.Log.Errorf("access token verification failed: %v", err)
|
|
return fiber.NewError(fiber.StatusUnauthorized, "invalid access token")
|
|
}
|
|
|
|
// prepare cookies
|
|
issueCookies(c, tokenResp, verification)
|
|
|
|
redirectTarget := sessionData.ReturnTo
|
|
if redirectTarget == "" {
|
|
redirectTarget = "/"
|
|
}
|
|
|
|
utils.Log.WithFields(logrus.Fields{
|
|
"client": sessionData.ClientAlias,
|
|
"user_id": verification.UserID,
|
|
"return_to": redirectTarget,
|
|
}).Info("sso callback successful")
|
|
|
|
return c.Redirect(redirectTarget, fiber.StatusFound)
|
|
}
|
|
|
|
// UserInfo proxies the user profile from the central SSO so the frontend can obtain
|
|
// enriched user metadata (roles, permissions, etc.) without exposing tokens to the browser.
|
|
func (h *Controller) UserInfo(c *fiber.Ctx) error {
|
|
accessName := config.SSOAccessCookieName
|
|
if accessName == "" {
|
|
accessName = "sso_access"
|
|
}
|
|
|
|
token := strings.TrimSpace(c.Cookies(accessName))
|
|
tokenFromCookie := token != ""
|
|
|
|
if !tokenFromCookie {
|
|
authHeader := strings.TrimSpace(c.Get("Authorization"))
|
|
if strings.HasPrefix(strings.ToLower(authHeader), "bearer ") {
|
|
token = strings.TrimSpace(authHeader[7:])
|
|
}
|
|
}
|
|
|
|
if token == "" {
|
|
return fiber.NewError(fiber.StatusUnauthorized, "unauthenticated")
|
|
}
|
|
|
|
if revoker := session.GetRevocationStore(); revoker != nil {
|
|
if fingerprint := session.TokenFingerprint(token); fingerprint != "" {
|
|
revoked, err := revoker.IsRevoked(c.Context(), fingerprint)
|
|
if err != nil {
|
|
utils.Log.WithError(err).Warn("failed to check token revocation for userinfo")
|
|
return fiber.NewError(fiber.StatusUnauthorized, "unauthenticated")
|
|
}
|
|
if revoked {
|
|
return fiber.NewError(fiber.StatusUnauthorized, "unauthenticated")
|
|
}
|
|
}
|
|
}
|
|
|
|
if _, err := sso.VerifyAccessToken(token); err != nil {
|
|
utils.Log.WithError(err).Warn("access token verification failed for userinfo")
|
|
return fiber.NewError(fiber.StatusUnauthorized, "unauthenticated")
|
|
}
|
|
|
|
endpoint := strings.TrimSpace(config.SSOGetMeURL)
|
|
if endpoint == "" {
|
|
return fiber.NewError(fiber.StatusInternalServerError, "userinfo endpoint not configured")
|
|
}
|
|
|
|
req, err := http.NewRequestWithContext(c.Context(), http.MethodGet, endpoint, nil)
|
|
if err != nil {
|
|
utils.Log.Errorf("failed to build userinfo request: %v", err)
|
|
return fiber.NewError(fiber.StatusInternalServerError, "failed to prepare userinfo request")
|
|
}
|
|
req.Header.Set("Accept", "application/json")
|
|
|
|
// SSO /auth/get-me expects the access cookie; add Authorization as well for compatibility.
|
|
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token))
|
|
if tokenFromCookie {
|
|
req.Header.Set("Cookie", fmt.Sprintf("%s=%s", accessName, token))
|
|
}
|
|
|
|
resp, err := h.httpClient.Do(req)
|
|
if err != nil {
|
|
utils.Log.Errorf("userinfo request failed: %v", err)
|
|
return fiber.NewError(fiber.StatusBadGateway, "failed to fetch user profile")
|
|
}
|
|
defer resp.Body.Close()
|
|
|
|
utils.Log.WithFields(logrus.Fields{"status": resp.StatusCode}).Info("sso userinfo response")
|
|
|
|
body, err := io.ReadAll(resp.Body)
|
|
if err != nil {
|
|
utils.Log.Errorf("failed to read userinfo response: %v", err)
|
|
return fiber.NewError(fiber.StatusBadGateway, "invalid user profile response")
|
|
}
|
|
|
|
// if sanitized, perms, ok := sanitizeUserInfoPayload(body); ok {
|
|
// if caps := capabilities.FromPermissions(perms); len(caps) > 0 {
|
|
// injectCapabilities(sanitized, caps)
|
|
// }
|
|
// return c.Status(resp.StatusCode).JSON(sanitized)
|
|
// }
|
|
|
|
if ct := resp.Header.Get("Content-Type"); ct != "" {
|
|
c.Set("Content-Type", ct)
|
|
} else {
|
|
c.Type("json")
|
|
}
|
|
|
|
return c.Status(resp.StatusCode).Send(body)
|
|
}
|
|
|
|
// Logout clears SSO cookies and removes any leftover PKCE session state.
|
|
func (h *Controller) Logout(c *fiber.Ctx) error {
|
|
requestedAlias := normalizeClientParam(c.Query("client"))
|
|
if requestedAlias == "" {
|
|
requestedAlias = normalizeClientParam(c.Query("client_id"))
|
|
}
|
|
var (
|
|
alias string
|
|
cfg config.SSOClientConfig
|
|
hasClientInfo bool
|
|
)
|
|
if requestedAlias != "" {
|
|
alias, cfg, hasClientInfo = findSSOClientConfig(requestedAlias)
|
|
}
|
|
|
|
accessName := resolveSSOCookieName(config.SSOAccessCookieName, "access")
|
|
refreshName := resolveSSOCookieName(config.SSORefreshCookieName, "refresh")
|
|
|
|
var accessToken, refreshToken string
|
|
var verification *sso.VerificationResult
|
|
if accessName != "" {
|
|
accessToken = strings.TrimSpace(c.Cookies(accessName))
|
|
}
|
|
if refreshName != "" {
|
|
refreshToken = strings.TrimSpace(c.Cookies(refreshName))
|
|
}
|
|
hadAccessCookie := accessToken != ""
|
|
hadRefreshCookie := refreshToken != ""
|
|
|
|
state := strings.TrimSpace(c.Query("state"))
|
|
if state != "" {
|
|
if err := h.store.Delete(c.Context(), state); err != nil {
|
|
utils.Log.Warnf("failed to delete pkce session during logout: %v", err)
|
|
}
|
|
}
|
|
|
|
if !hadAccessCookie && !hadRefreshCookie && state == "" {
|
|
return fiber.NewError(fiber.StatusUnauthorized, "not authenticated")
|
|
}
|
|
|
|
if hadAccessCookie {
|
|
if v, err := sso.VerifyAccessToken(accessToken); err != nil {
|
|
utils.Log.WithError(err).Warn("failed to verify access token during logout")
|
|
} else {
|
|
verification = v
|
|
if revoker := session.GetRevocationStore(); revoker != nil {
|
|
if err := revoker.MarkUserLogout(c.Context(), verification.UserID, time.Now().UTC()); err != nil {
|
|
utils.Log.WithError(err).Warn("failed to mark user logout")
|
|
}
|
|
}
|
|
h.revokeToken(c.Context(), accessToken, verification)
|
|
}
|
|
}
|
|
if refreshToken != "" {
|
|
h.revokeRefreshToken(c.Context(), refreshToken)
|
|
}
|
|
|
|
clearSSOCookie(c, accessName)
|
|
clearSSOCookie(c, refreshName)
|
|
|
|
redirectTarget := ""
|
|
rawReturn := strings.TrimSpace(c.Query("return_to"))
|
|
if hasClientInfo {
|
|
if rawReturn == "" {
|
|
rawReturn = cfg.DefaultReturnURI
|
|
}
|
|
if normalized, err := normalizeReturnTarget(rawReturn, cfg); err == nil {
|
|
redirectTarget = normalized
|
|
} else if rawReturn != "" {
|
|
utils.Log.WithError(err).Warn("invalid return_to during logout")
|
|
}
|
|
} else if rawReturn == "" && config.SSOPortalURL != "" {
|
|
if alias, singleCfg, ok := singleClientFromToken(verification); ok {
|
|
if normalized, err := normalizeReturnTarget(singleCfg.DefaultReturnURI, singleCfg); err == nil && normalized != "" {
|
|
redirectTarget = normalized
|
|
alias, cfg, hasClientInfo = alias, singleCfg, true
|
|
} else {
|
|
redirectTarget = config.SSOPortalURL
|
|
}
|
|
} else if accessToken != "" {
|
|
if alias, singleCfg, ok := h.singleClientFromSSO(c.Context(), accessToken); ok {
|
|
if normalized, err := normalizeReturnTarget(singleCfg.DefaultReturnURI, singleCfg); err == nil && normalized != "" {
|
|
redirectTarget = normalized
|
|
alias, cfg, hasClientInfo = alias, singleCfg, true
|
|
} else {
|
|
redirectTarget = config.SSOPortalURL
|
|
}
|
|
} else {
|
|
redirectTarget = config.SSOPortalURL
|
|
}
|
|
} else {
|
|
redirectTarget = config.SSOPortalURL
|
|
}
|
|
} else if rawReturn != "" {
|
|
if strings.HasPrefix(rawReturn, "/") && !strings.HasPrefix(rawReturn, "//") {
|
|
redirectTarget = rawReturn
|
|
}
|
|
}
|
|
|
|
utils.Log.WithFields(logrus.Fields{
|
|
"client": alias,
|
|
"state": state,
|
|
"redirect": redirectTarget,
|
|
}).Info("sso logout completed")
|
|
|
|
if redirectTarget != "" {
|
|
return c.Redirect(redirectTarget, fiber.StatusFound)
|
|
}
|
|
|
|
return c.Status(fiber.StatusOK).JSON(fiber.Map{"status": "signed out"})
|
|
}
|
|
|
|
func singleSSOClient() (string, config.SSOClientConfig, bool) {
|
|
if len(config.SSOClients) != 1 {
|
|
return "", config.SSOClientConfig{}, false
|
|
}
|
|
for alias, cfg := range config.SSOClients {
|
|
if strings.TrimSpace(alias) == "" || strings.TrimSpace(cfg.PublicID) == "" {
|
|
return "", config.SSOClientConfig{}, false
|
|
}
|
|
return alias, cfg, true
|
|
}
|
|
return "", config.SSOClientConfig{}, false
|
|
}
|
|
|
|
func singleClientFromToken(verification *sso.VerificationResult) (string, config.SSOClientConfig, bool) {
|
|
if verification == nil || verification.Claims == nil {
|
|
return "", config.SSOClientConfig{}, false
|
|
}
|
|
return singleClientFromScopes(verification.Claims.Scopes())
|
|
}
|
|
|
|
func (h *Controller) singleClientFromSSO(ctx context.Context, accessToken string) (string, config.SSOClientConfig, bool) {
|
|
accessToken = strings.TrimSpace(accessToken)
|
|
if accessToken == "" {
|
|
return "", config.SSOClientConfig{}, false
|
|
}
|
|
meURL := strings.TrimSpace(config.SSOGetMeURL)
|
|
if meURL == "" {
|
|
return "", config.SSOClientConfig{}, false
|
|
}
|
|
|
|
req, err := http.NewRequestWithContext(ctx, http.MethodGet, meURL, nil)
|
|
if err != nil {
|
|
utils.Log.WithError(err).Warn("failed to build SSO getme request")
|
|
return "", config.SSOClientConfig{}, false
|
|
}
|
|
req.Header.Set("Authorization", "Bearer "+accessToken)
|
|
|
|
resp, err := h.httpClient.Do(req)
|
|
if err != nil {
|
|
utils.Log.WithError(err).Warn("SSO getme request failed")
|
|
return "", config.SSOClientConfig{}, false
|
|
}
|
|
defer resp.Body.Close()
|
|
|
|
if resp.StatusCode < http.StatusOK || resp.StatusCode >= http.StatusMultipleChoices {
|
|
utils.Log.WithField("status", resp.StatusCode).Warn("SSO getme responded with error")
|
|
return "", config.SSOClientConfig{}, false
|
|
}
|
|
|
|
var payload struct {
|
|
Data struct {
|
|
Roles []struct {
|
|
Client *struct {
|
|
Alias string `json:"alias"`
|
|
} `json:"client"`
|
|
} `json:"roles"`
|
|
} `json:"data"`
|
|
}
|
|
if err := json.NewDecoder(resp.Body).Decode(&payload); err != nil {
|
|
utils.Log.WithError(err).Warn("failed to decode SSO getme response")
|
|
return "", config.SSOClientConfig{}, false
|
|
}
|
|
|
|
aliases := make(map[string]struct{})
|
|
for _, role := range payload.Data.Roles {
|
|
if role.Client == nil {
|
|
continue
|
|
}
|
|
alias := strings.ToLower(strings.TrimSpace(role.Client.Alias))
|
|
if alias != "" {
|
|
aliases[alias] = struct{}{}
|
|
}
|
|
}
|
|
if len(aliases) != 1 {
|
|
return "", config.SSOClientConfig{}, false
|
|
}
|
|
for alias := range aliases {
|
|
if normalized, cfg, ok := findClientAlias(alias); ok {
|
|
return normalized, cfg, true
|
|
}
|
|
return "", config.SSOClientConfig{}, false
|
|
}
|
|
return "", config.SSOClientConfig{}, false
|
|
}
|
|
|
|
func singleClientFromScopes(scopes []string) (string, config.SSOClientConfig, bool) {
|
|
if len(scopes) == 0 {
|
|
return "", config.SSOClientConfig{}, false
|
|
}
|
|
seen := make(map[string]struct{})
|
|
for _, scope := range scopes {
|
|
if alias, ok := matchClientAliasFromScope(scope); ok {
|
|
seen[alias] = struct{}{}
|
|
}
|
|
if len(seen) > 1 {
|
|
return "", config.SSOClientConfig{}, false
|
|
}
|
|
}
|
|
if len(seen) != 1 {
|
|
return "", config.SSOClientConfig{}, false
|
|
}
|
|
for alias := range seen {
|
|
if normalized, cfg, ok := findClientAlias(alias); ok {
|
|
return normalized, cfg, true
|
|
}
|
|
}
|
|
return "", config.SSOClientConfig{}, false
|
|
}
|
|
|
|
func matchClientAliasFromScope(scope string) (string, bool) {
|
|
scope = strings.ToLower(strings.TrimSpace(scope))
|
|
if scope == "" {
|
|
return "", false
|
|
}
|
|
prefix := scope
|
|
if idx := strings.IndexAny(prefix, ".:"); idx > 0 {
|
|
prefix = prefix[:idx]
|
|
}
|
|
if prefix == "" {
|
|
return "", false
|
|
}
|
|
if alias, _, ok := findClientAlias(prefix); ok {
|
|
return alias, true
|
|
}
|
|
if prefix == "user-management" {
|
|
if alias, _, ok := findClientAlias("umgmt"); ok {
|
|
return alias, true
|
|
}
|
|
}
|
|
if prefix == "umgmt" {
|
|
if alias, _, ok := findClientAlias("user-management"); ok {
|
|
return alias, true
|
|
}
|
|
}
|
|
return "", false
|
|
}
|
|
|
|
func findClientAlias(alias string) (string, config.SSOClientConfig, bool) {
|
|
alias = strings.TrimSpace(alias)
|
|
if alias == "" {
|
|
return "", config.SSOClientConfig{}, false
|
|
}
|
|
if cfg, ok := config.SSOClients[alias]; ok && strings.TrimSpace(cfg.PublicID) != "" {
|
|
return alias, cfg, true
|
|
}
|
|
for key, cfg := range config.SSOClients {
|
|
if strings.EqualFold(key, alias) && strings.TrimSpace(cfg.PublicID) != "" {
|
|
return key, cfg, true
|
|
}
|
|
}
|
|
return "", config.SSOClientConfig{}, false
|
|
}
|
|
|
|
func defaultSSOClientAlias() string {
|
|
for alias := range config.SSOClients {
|
|
if strings.TrimSpace(alias) == "" {
|
|
continue
|
|
}
|
|
return alias
|
|
}
|
|
return ""
|
|
}
|
|
|
|
// buildStartRedirect returns the relative URL that starts the SSO login for
// alias (query-escaped), or "" when alias is blank.
func buildStartRedirect(alias string) string {
	if trimmed := strings.TrimSpace(alias); trimmed != "" {
		return "/api/sso/start?client=" + url.QueryEscape(trimmed)
	}
	return ""
}
|
|
|
|
func (h *Controller) revokeToken(ctx context.Context, token string, verification *sso.VerificationResult) {
|
|
if h.revoker == nil || verification == nil || verification.Claims == nil {
|
|
return
|
|
}
|
|
fingerprint := session.TokenFingerprint(token)
|
|
if fingerprint == "" {
|
|
return
|
|
}
|
|
if verification.Claims.ExpiresAt == nil {
|
|
utils.Log.Warn("access token missing expiry claim")
|
|
return
|
|
}
|
|
ttl := time.Until(verification.Claims.ExpiresAt.Time)
|
|
if ttl <= 0 {
|
|
return
|
|
}
|
|
if ttl < time.Second {
|
|
ttl = time.Second
|
|
}
|
|
if err := h.revoker.Revoke(ctx, fingerprint, ttl); err != nil {
|
|
utils.Log.WithError(err).Warn("failed to revoke access token")
|
|
}
|
|
}
|
|
|
|
func (h *Controller) revokeRefreshToken(ctx context.Context, token string) {
|
|
if h.revoker == nil {
|
|
return
|
|
}
|
|
fingerprint := session.TokenFingerprint(token)
|
|
if fingerprint == "" {
|
|
return
|
|
}
|
|
const refreshTTL = 30 * 24 * time.Hour
|
|
if err := h.revoker.Revoke(ctx, fingerprint, refreshTTL); err != nil {
|
|
utils.Log.WithError(err).Warn("failed to revoke refresh token")
|
|
}
|
|
}
|
|
|
|
func issueCookies(c *fiber.Ctx, tokenResp struct {
|
|
AccessToken string `json:"access_token"`
|
|
RefreshToken string `json:"refresh_token"`
|
|
TokenType string `json:"token_type"`
|
|
ExpiresIn int `json:"expires_in"`
|
|
Scope string `json:"scope"`
|
|
IDToken string `json:"id_token"`
|
|
Error string `json:"error"`
|
|
Description string `json:"error_description"`
|
|
}, verification *sso.VerificationResult) {
|
|
if revoker := session.GetRevocationStore(); revoker != nil && verification != nil {
|
|
if err := revoker.ClearUserLogout(c.Context(), verification.UserID); err != nil {
|
|
utils.Log.WithError(err).Warn("failed to clear logout marker")
|
|
}
|
|
}
|
|
|
|
accessName := resolveSSOCookieName(config.SSOAccessCookieName, "access")
|
|
refreshName := resolveSSOCookieName(config.SSORefreshCookieName, "refresh")
|
|
maxAge := tokenResp.ExpiresIn
|
|
if maxAge <= 0 {
|
|
maxAge = int(15 * time.Minute.Seconds())
|
|
}
|
|
|
|
sameSite := config.SSOCookieSameSite
|
|
if sameSite == "" {
|
|
sameSite = "Lax"
|
|
}
|
|
|
|
cookieDomain := config.SSOCookieDomain
|
|
|
|
cookieAccess := &fiber.Cookie{
|
|
Name: accessName,
|
|
Value: tokenResp.AccessToken,
|
|
Path: "/",
|
|
Domain: cookieDomain,
|
|
HTTPOnly: true,
|
|
Secure: config.SSOCookieSecure,
|
|
SameSite: sameSite,
|
|
MaxAge: maxAge,
|
|
}
|
|
c.Cookie(cookieAccess)
|
|
if tokenResp.RefreshToken != "" {
|
|
cookieRefresh := &fiber.Cookie{
|
|
Name: refreshName,
|
|
Value: tokenResp.RefreshToken,
|
|
Path: "/",
|
|
Domain: cookieDomain,
|
|
HTTPOnly: true,
|
|
Secure: config.SSOCookieSecure,
|
|
SameSite: sameSite,
|
|
MaxAge: int((time.Hour * 24 * 30).Seconds()),
|
|
}
|
|
c.Cookie(cookieRefresh)
|
|
}
|
|
|
|
// Optional: expose limited info via headers for FE debugging (avoid tokens)
|
|
c.Set("X-Auth-User", fmt.Sprintf("%d", verification.UserID))
|
|
}
|
|
|
|
func clearSSOCookie(c *fiber.Ctx, name string) {
|
|
if name == "" {
|
|
return
|
|
}
|
|
|
|
sameSite := config.SSOCookieSameSite
|
|
if sameSite == "" {
|
|
sameSite = "Lax"
|
|
}
|
|
|
|
c.Cookie(&fiber.Cookie{
|
|
Name: name,
|
|
Value: "",
|
|
Path: "/",
|
|
Domain: config.SSOCookieDomain,
|
|
HTTPOnly: true,
|
|
Secure: config.SSOCookieSecure,
|
|
SameSite: sameSite,
|
|
Expires: time.Unix(0, 0),
|
|
MaxAge: -1,
|
|
})
|
|
}
|
|
|
|
// resolveSSOCookieName returns the trimmed configured cookie name, or the
// trimmed fallback when the configured name is blank.
func resolveSSOCookieName(configuredName, fallback string) string {
	if name := strings.TrimSpace(configuredName); name != "" {
		return name
	}
	return strings.TrimSpace(fallback)
}
|
|
|
|
// normalizeClientParam keeps only the part of a client query value before the
// first "|", then trims it and lowercases it ("  LTI|x " -> "lti").
func normalizeClientParam(raw string) string {
	head, _, _ := strings.Cut(raw, "|")
	return strings.ToLower(strings.TrimSpace(head))
}
|
|
|
|
func sanitizeUserInfoPayload(body []byte) (map[string]any, []string, bool) {
|
|
if len(body) == 0 {
|
|
return map[string]any{}, nil, true
|
|
}
|
|
|
|
var payload any
|
|
if err := json.Unmarshal(body, &payload); err != nil {
|
|
return nil, nil, false
|
|
}
|
|
|
|
perms := collectPermissionNames(payload)
|
|
|
|
sensitive := map[string]struct{}{
|
|
"roles": {},
|
|
"permissions": {},
|
|
}
|
|
payload = scrubSensitiveKeys(payload, sensitive)
|
|
|
|
sanitized, ok := payload.(map[string]any)
|
|
if !ok {
|
|
sanitized = map[string]any{"data": payload}
|
|
}
|
|
|
|
return sanitized, perms, true
|
|
}
|
|
|
|
// scrubSensitiveKeys walks decoded JSON (maps and slices) in place and deletes
// every object key whose lowercased form is in sensitive. It returns the same
// (mutated) value; scalars are returned untouched.
func scrubSensitiveKeys(value any, sensitive map[string]struct{}) any {
	if obj, isObj := value.(map[string]any); isObj {
		for key, child := range obj {
			if _, drop := sensitive[strings.ToLower(key)]; drop {
				delete(obj, key)
				continue
			}
			obj[key] = scrubSensitiveKeys(child, sensitive)
		}
		return obj
	}
	if list, isList := value.([]any); isList {
		for idx := range list {
			list[idx] = scrubSensitiveKeys(list[idx], sensitive)
		}
		return list
	}
	return value
}
|
|
|
|
// collectPermissionNames returns the distinct lowercased, trimmed "name" values
// of every object inside any "permissions" array (key matched
// case-insensitively) anywhere in decoded JSON. Order is unspecified.
func collectPermissionNames(value any) []string {
	seen := make(map[string]struct{})
	collectPermissionRec(value, seen)

	names := make([]string, 0, len(seen))
	for name := range seen {
		names = append(names, name)
	}
	return names
}

// collectPermissionRec adds permission names found under value to acc. It does
// not descend further into a "permissions" value itself.
func collectPermissionRec(value any, acc map[string]struct{}) {
	switch node := value.(type) {
	case []any:
		for _, child := range node {
			collectPermissionRec(child, acc)
		}
	case map[string]any:
		for key, child := range node {
			if !strings.EqualFold(key, "permissions") {
				collectPermissionRec(child, acc)
				continue
			}
			entries, isList := child.([]any)
			if !isList {
				continue
			}
			for _, entry := range entries {
				perm, isObj := entry.(map[string]any)
				if !isObj {
					continue
				}
				name, isStr := perm["name"].(string)
				if !isStr {
					continue
				}
				if trimmed := strings.TrimSpace(name); trimmed != "" {
					acc[strings.ToLower(trimmed)] = struct{}{}
				}
			}
		}
	}
}
|
|
|
|
// injectCapabilities stores caps under "capabilities" in payload["data"] when
// that is an object, otherwise at the top level. Empty caps are a no-op.
func injectCapabilities(payload map[string]any, caps map[string]bool) {
	if len(caps) == 0 {
		return
	}
	target := payload
	if nested, isObject := payload["data"].(map[string]any); isObject {
		target = nested
	}
	target["capabilities"] = caps
}
|
|
|
|
func findSSOClientConfig(requestedAlias string) (string, config.SSOClientConfig, bool) {
|
|
if requestedAlias == "" {
|
|
return "", config.SSOClientConfig{}, false
|
|
}
|
|
if cfg, ok := config.SSOClients[requestedAlias]; ok && strings.TrimSpace(cfg.PublicID) != "" {
|
|
return requestedAlias, cfg, true
|
|
}
|
|
for alias, cfg := range config.SSOClients {
|
|
if strings.EqualFold(strings.TrimSpace(cfg.PublicID), requestedAlias) && strings.TrimSpace(cfg.PublicID) != "" {
|
|
return alias, cfg, true
|
|
}
|
|
}
|
|
return "", config.SSOClientConfig{}, false
|
|
}
|
|
|
|
func normalizeReturnTarget(returnTo string, cfg config.SSOClientConfig) (string, error) {
|
|
returnTo = strings.TrimSpace(returnTo)
|
|
if returnTo == "" {
|
|
return "", nil
|
|
}
|
|
if strings.HasPrefix(returnTo, "//") {
|
|
return "", fmt.Errorf("invalid return_to")
|
|
}
|
|
if strings.HasPrefix(returnTo, "/") {
|
|
return returnTo, nil
|
|
}
|
|
|
|
parsed, err := url.Parse(returnTo)
|
|
if err != nil {
|
|
return "", fmt.Errorf("invalid return_to")
|
|
}
|
|
if parsed.Scheme != "http" && parsed.Scheme != "https" {
|
|
return "", fmt.Errorf("invalid return_to scheme")
|
|
}
|
|
|
|
allowedOrigins := make(map[string]struct{})
|
|
if cfg.DefaultReturnURI != "" {
|
|
if u, err := url.Parse(cfg.DefaultReturnURI); err == nil && u.Host != "" {
|
|
allowedOrigins[u.Scheme+"://"+u.Host] = struct{}{}
|
|
}
|
|
}
|
|
for _, origin := range cfg.AllowedReturnOrigins {
|
|
origin = strings.TrimSpace(origin)
|
|
if origin == "" {
|
|
continue
|
|
}
|
|
if u, err := url.Parse(origin); err == nil && u.Host != "" && (u.Scheme == "http" || u.Scheme == "https") {
|
|
allowedOrigins[u.Scheme+"://"+u.Host] = struct{}{}
|
|
}
|
|
}
|
|
|
|
if len(allowedOrigins) > 0 {
|
|
origin := parsed.Scheme + "://" + parsed.Host
|
|
if _, ok := allowedOrigins[origin]; !ok {
|
|
return "", fmt.Errorf("return_to origin not allowed")
|
|
}
|
|
}
|
|
|
|
return parsed.String(), nil
|
|
}
|