package controllers import ( "context" "crypto/sha256" "encoding/json" "fmt" "io" "net/http" "net/url" "strings" "time" "github.com/gofiber/fiber/v2" "github.com/sirupsen/logrus" "gitlab.com/mbugroup/lti-api.git/internal/config" "gitlab.com/mbugroup/lti-api.git/internal/modules/sso/session" sso "gitlab.com/mbugroup/lti-api.git/internal/modules/sso/verifier" "gitlab.com/mbugroup/lti-api.git/internal/utils" "gitlab.com/mbugroup/lti-api.git/internal/utils/secure" ) // Controller manages the SSO start & callback flow using PKCE. type Controller struct { httpClient *http.Client store *session.Store revoker *session.RevocationStore } func NewController(client *http.Client, store *session.Store, revoker *session.RevocationStore) *Controller { return &Controller{httpClient: client, store: store, revoker: revoker} } // Start handles GET /sso/start requests and redirects users to the central SSO authorize endpoint. func (h *Controller) Start(c *fiber.Ctx) error { requestedAlias := normalizeClientParam(c.Query("client")) if requestedAlias == "" { requestedAlias = normalizeClientParam(c.Query("client_id")) } if requestedAlias == "" { return fiber.NewError(fiber.StatusBadRequest, "missing client") } alias, cfg, ok := findSSOClientConfig(requestedAlias) if !ok || cfg.PublicID == "" { return fiber.NewError(fiber.StatusBadRequest, "unknown client") } authorizeEndpoint := strings.TrimSpace(config.SSOAuthorizeURL) if authorizeEndpoint == "" { return fiber.NewError(fiber.StatusInternalServerError, "authorize endpoint not configured") } state, err := secure.RandomString(48) if err != nil { utils.Log.Errorf("generate state failed: %v", err) return fiber.NewError(fiber.StatusInternalServerError, "failed to prepare authorization") } nonce, err := secure.RandomString(32) if err != nil { utils.Log.Errorf("generate nonce failed: %v", err) return fiber.NewError(fiber.StatusInternalServerError, "failed to prepare authorization") } codeVerifier, err := secure.PKCECodeVerifier(96) if err != nil { utils.Log.Errorf("generate code verifier failed: %v", err) return fiber.NewError(fiber.StatusInternalServerError, "failed to prepare authorization") } digest := sha256.Sum256([]byte(codeVerifier)) challenge := secure.Base64URLEncode(digest[:]) authorizeURL, err := url.Parse(authorizeEndpoint) if err != nil { return fiber.NewError(fiber.StatusInternalServerError, "invalid authorize endpoint") } scope := cfg.Scope if scope == "" { scope = "openid profile" } if !strings.Contains(" "+scope+" ", " openid ") { scope = scope + " openid" } rawReturn := strings.TrimSpace(c.Query("return_to")) if rawReturn == "" { rawReturn = cfg.DefaultReturnURI } returnTo, err := normalizeReturnTarget(rawReturn, cfg) if err != nil { return fiber.NewError(fiber.StatusBadRequest, err.Error()) } query := authorizeURL.Query() query.Set("response_type", "code") query.Set("client_id", cfg.PublicID) query.Set("redirect_uri", cfg.RedirectURI) query.Set("scope", strings.TrimSpace(scope)) query.Set("state", state) query.Set("code_challenge", challenge) query.Set("code_challenge_method", "S256") query.Set("nonce", nonce) // if prompt := strings.TrimSpace(cfg.Prompt); prompt != "" { // query.Set("prompt", prompt) // } if extraPrompt := strings.TrimSpace(c.Query("prompt")); extraPrompt != "" { query.Set("prompt", extraPrompt) } authorizeURL.RawQuery = query.Encode() payload := &session.PKCESession{ CodeVerifier: codeVerifier, Nonce: nonce, ClientAlias: alias, ClientID: cfg.PublicID, RedirectURI: cfg.RedirectURI, Scope: strings.TrimSpace(scope), ReturnTo: returnTo, CreatedAt: time.Now().UTC(), } if err := h.store.Save(c.Context(), state, payload); err != nil { utils.Log.Errorf("store pkce session failed: %v", err) return fiber.NewError(fiber.StatusInternalServerError, "failed to prepare authorization") } utils.Log.WithFields(logrus.Fields{ "client": alias, "state": state, "return_to": returnTo, }).Info("sso start redirect") return c.Redirect(authorizeURL.String(), fiber.StatusFound) } // Refresh exchanges the current SSO refresh token for a new access/refresh pair // without redirecting the browser to the SSO login page. func (h *Controller) Refresh(c *fiber.Ctx) error { refreshName := resolveSSOCookieName(config.SSORefreshCookieName, "refresh") refreshToken := strings.TrimSpace(c.Cookies(refreshName)) if refreshToken == "" { if target := buildStartRedirect(defaultSSOClientAlias()); target != "" { return c.Redirect(target, fiber.StatusFound) } return fiber.NewError(fiber.StatusUnauthorized, "unauthenticated") } tokenEndpoint := strings.TrimSpace(config.SSOTokenURL) if tokenEndpoint == "" { return fiber.NewError(fiber.StatusInternalServerError, "token endpoint not configured") } form := url.Values{} form.Set("grant_type", "refresh_token") form.Set("refresh_token", refreshToken) req, err := http.NewRequestWithContext(c.Context(), http.MethodPost, tokenEndpoint, strings.NewReader(form.Encode())) if err != nil { return fiber.NewError(fiber.StatusInternalServerError, "failed to create refresh request") } req.Header.Set("Content-Type", "application/x-www-form-urlencoded") resp, err := h.httpClient.Do(req) if err != nil { utils.Log.Errorf("token refresh request failed: %v", err) return fiber.NewError(fiber.StatusBadGateway, "failed to refresh access token") } defer resp.Body.Close() if resp.StatusCode >= 400 { utils.Log.Warnf("token refresh response status %d", resp.StatusCode) if resp.StatusCode == fiber.StatusTooManyRequests { return fiber.NewError(fiber.StatusTooManyRequests, "Too many attempts, please slow down") } if target := buildStartRedirect(defaultSSOClientAlias()); target != "" { return c.Redirect(target, fiber.StatusFound) } return fiber.NewError(fiber.StatusUnauthorized, "unauthenticated") } var tokenResp refreshTokenResponse if err := json.NewDecoder(resp.Body).Decode(&tokenResp); err != nil { return fiber.NewError(fiber.StatusBadGateway, "invalid token response") } if tokenResp.Error != "" { return fiber.NewError(fiber.StatusBadGateway, tokenResp.Description) } if tokenResp.AccessToken == "" { return fiber.NewError(fiber.StatusBadGateway, "missing access token") } verification, err := sso.VerifyAccessToken(tokenResp.AccessToken) if err != nil { if sso.IsSignatureError(err) { logSignatureError("sso refresh", "sso_token", tokenResp.AccessToken, err) } else { utils.Log.Errorf("access token verification failed: %v", err) } return fiber.NewError(fiber.StatusUnauthorized, "invalid access token") } if err := issueCookies(c, struct { AccessToken string `json:"access_token"` RefreshToken string `json:"refresh_token"` TokenType string `json:"token_type"` ExpiresIn int `json:"expires_in"` Scope string `json:"scope"` IDToken string `json:"id_token"` Error string `json:"error"` Description string `json:"error_description"` }{ AccessToken: tokenResp.AccessToken, RefreshToken: tokenResp.RefreshToken, TokenType: tokenResp.TokenType, ExpiresIn: tokenResp.ExpiresIn, Scope: tokenResp.Scope, IDToken: tokenResp.IDToken, Error: tokenResp.Error, Description: tokenResp.Description, }, verification); err != nil { return err } utils.Log.WithFields(logrus.Fields{ "user_id": verification.UserID, }).Info("sso refresh successful") return c.Status(fiber.StatusOK).JSON(fiber.Map{"status": "ok"}) } // Callback handles the redirect from SSO containing the authorization code. func (h *Controller) Callback(c *fiber.Ctx) error { state := strings.TrimSpace(c.Query("state")) code := strings.TrimSpace(c.Query("code")) if state == "" || code == "" { return fiber.NewError(fiber.StatusBadRequest, "missing code or state") } sessionData, err := h.store.Get(c.Context(), state) if err != nil { utils.Log.Errorf("load pkce session failed: %v", err) return fiber.NewError(fiber.StatusInternalServerError, "failed to validate authorization state") } if sessionData == nil { return fiber.NewError(fiber.StatusBadRequest, "authorization state not found or expired") } defer func() { if err := h.store.Delete(context.Background(), state); err != nil { utils.Log.Warnf("failed to delete pkce session: %v", err) } }() tokenEndpoint := strings.TrimSpace(config.SSOTokenURL) if tokenEndpoint == "" { return fiber.NewError(fiber.StatusInternalServerError, "token endpoint not configured") } form := url.Values{} form.Set("grant_type", "authorization_code") form.Set("code", code) form.Set("code_verifier", sessionData.CodeVerifier) form.Set("redirect_uri", sessionData.RedirectURI) form.Set("client_id", sessionData.ClientID) req, err := http.NewRequestWithContext(c.Context(), http.MethodPost, tokenEndpoint, strings.NewReader(form.Encode())) if err != nil { return fiber.NewError(fiber.StatusInternalServerError, "failed to create token request") } req.Header.Set("Content-Type", "application/x-www-form-urlencoded") resp, err := h.httpClient.Do(req) if err != nil { utils.Log.Errorf("token request failed: %v", err) return fiber.NewError(fiber.StatusBadGateway, "failed to exchange authorization code") } defer resp.Body.Close() if resp.StatusCode >= 400 { utils.Log.Warnf("token response status %d", resp.StatusCode) return fiber.NewError(fiber.StatusBadGateway, "token exchange rejected") } var tokenResp struct { AccessToken string `json:"access_token"` RefreshToken string `json:"refresh_token"` TokenType string `json:"token_type"` ExpiresIn int `json:"expires_in"` Scope string `json:"scope"` IDToken string `json:"id_token"` Error string `json:"error"` Description string `json:"error_description"` } if err := json.NewDecoder(resp.Body).Decode(&tokenResp); err != nil { return fiber.NewError(fiber.StatusBadGateway, "invalid token response") } if tokenResp.Error != "" { return fiber.NewError(fiber.StatusBadGateway, tokenResp.Description) } if tokenResp.AccessToken == "" { return fiber.NewError(fiber.StatusBadGateway, "missing access token") } verification, err := sso.VerifyAccessToken(tokenResp.AccessToken) if err != nil { if sso.IsSignatureError(err) { logSignatureError("sso callback", "sso_token", tokenResp.AccessToken, err) } else { utils.Log.Errorf("access token verification failed: %v", err) } return fiber.NewError(fiber.StatusUnauthorized, "invalid access token") } // prepare cookies if err := issueCookies(c, tokenResp, verification); err != nil { return err } redirectTarget := sessionData.ReturnTo if redirectTarget == "" { redirectTarget = "/" } utils.Log.WithFields(logrus.Fields{ "client": sessionData.ClientAlias, "user_id": verification.UserID, "return_to": redirectTarget, }).Info("sso callback successful") return c.Redirect(redirectTarget, fiber.StatusFound) } // UserInfo proxies the user profile from the central SSO so the frontend can obtain // enriched user metadata (roles, permissions, etc.) without exposing tokens to the browser. func (h *Controller) UserInfo(c *fiber.Ctx) error { accessName := config.SSOAccessCookieName if accessName == "" { accessName = "sso_access" } token := strings.TrimSpace(c.Cookies(accessName)) tokenFromCookie := token != "" usedCookieName := accessName if !tokenFromCookie { for _, name := range config.SSOAccessCookieFallback { name = strings.TrimSpace(name) if name == "" || name == accessName { continue } token = strings.TrimSpace(c.Cookies(name)) if token != "" { tokenFromCookie = true usedCookieName = name break } } } if !tokenFromCookie { authHeader := strings.TrimSpace(c.Get("Authorization")) if strings.HasPrefix(strings.ToLower(authHeader), "bearer ") { token = strings.TrimSpace(authHeader[7:]) } } if token == "" { return fiber.NewError(fiber.StatusUnauthorized, "unauthenticated") } if revoker := session.GetRevocationStore(); revoker != nil { if fingerprint := session.TokenFingerprint(token); fingerprint != "" { revoked, err := revoker.IsRevoked(c.Context(), fingerprint) if err != nil { utils.Log.WithError(err).Warn("failed to check token revocation for userinfo") return fiber.NewError(fiber.StatusUnauthorized, "unauthenticated") } if revoked { return fiber.NewError(fiber.StatusUnauthorized, "unauthenticated") } } } if _, err := sso.VerifyAccessToken(token); err != nil { if sso.IsSignatureError(err) { logSignatureError("sso userinfo", "request", token, err) } else { utils.Log.WithError(err).Warn("access token verification failed for userinfo") } return fiber.NewError(fiber.StatusUnauthorized, "unauthenticated") } endpoint := strings.TrimSpace(config.SSOGetMeURL) if endpoint == "" { return fiber.NewError(fiber.StatusInternalServerError, "userinfo endpoint not configured") } req, err := http.NewRequestWithContext(c.Context(), http.MethodGet, endpoint, nil) if err != nil { utils.Log.Errorf("failed to build userinfo request: %v", err) return fiber.NewError(fiber.StatusInternalServerError, "failed to prepare userinfo request") } req.Header.Set("Accept", "application/json") // SSO /auth/get-me expects the access cookie; add Authorization as well for compatibility. req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token)) if tokenFromCookie { req.Header.Set("Cookie", fmt.Sprintf("%s=%s", usedCookieName, token)) } resp, err := h.httpClient.Do(req) if err != nil { utils.Log.Errorf("userinfo request failed: %v", err) return fiber.NewError(fiber.StatusBadGateway, "failed to fetch user profile") } defer resp.Body.Close() utils.Log.WithFields(logrus.Fields{"status": resp.StatusCode}).Info("sso userinfo response") body, err := io.ReadAll(resp.Body) if err != nil { utils.Log.Errorf("failed to read userinfo response: %v", err) return fiber.NewError(fiber.StatusBadGateway, "invalid user profile response") } if ct := resp.Header.Get("Content-Type"); ct != "" { c.Set("Content-Type", ct) } else { c.Type("json") } return c.Status(resp.StatusCode).Send(body) } // Logout clears SSO cookies and removes any leftover PKCE session state. func (h *Controller) Logout(c *fiber.Ctx) error { alias := "" if singleAlias, _, ok := singleSSOClient(); ok { alias = singleAlias } accessName := resolveSSOCookieName(config.SSOAccessCookieName, "access") refreshName := resolveSSOCookieName(config.SSORefreshCookieName, "refresh") var accessToken, refreshToken string var verification *sso.VerificationResult if accessName != "" { accessToken = strings.TrimSpace(c.Cookies(accessName)) } if refreshName != "" { refreshToken = strings.TrimSpace(c.Cookies(refreshName)) } hadAccessCookie := accessToken != "" hadRefreshCookie := refreshToken != "" if !hadAccessCookie && !hadRefreshCookie { return fiber.NewError(fiber.StatusUnauthorized, "not authenticated") } if hadAccessCookie { if v, err := sso.VerifyAccessToken(accessToken); err != nil { utils.Log.WithError(err).Warn("failed to verify access token during logout") } else { verification = v if revoker := session.GetRevocationStore(); revoker != nil { if err := revoker.MarkUserLogout(c.Context(), verification.UserID, time.Now().UTC()); err != nil { utils.Log.WithError(err).Warn("failed to mark user logout") } } h.revokeToken(c.Context(), accessToken, verification) } } if refreshToken != "" { h.revokeRefreshToken(c.Context(), refreshToken) } clearSSOCookie(c, accessName) clearSSOCookie(c, refreshName) redirectTarget := "" if config.SSOPortalURL != "" { redirectTarget = config.SSOPortalURL } utils.Log.WithFields(logrus.Fields{ "client": alias, "redirect": redirectTarget, }).Info("sso logout completed") if redirectTarget != "" { return c.Status(fiber.StatusOK).JSON(fiber.Map{ "status": "signed out", "redirect": redirectTarget, }) } return c.Status(fiber.StatusOK).JSON(fiber.Map{"status": "signed out"}) } func singleSSOClient() (string, config.SSOClientConfig, bool) { if len(config.SSOClients) != 1 { return "", config.SSOClientConfig{}, false } for alias, cfg := range config.SSOClients { if strings.TrimSpace(alias) == "" || strings.TrimSpace(cfg.PublicID) == "" { return "", config.SSOClientConfig{}, false } return alias, cfg, true } return "", config.SSOClientConfig{}, false } func defaultSSOClientAlias() string { for alias := range config.SSOClients { if strings.TrimSpace(alias) == "" { continue } return alias } return "" } func buildStartRedirect(alias string) string { alias = strings.TrimSpace(alias) if alias == "" { return "" } return "/api/sso/start?client=" + url.QueryEscape(alias) } func (h *Controller) revokeToken(ctx context.Context, token string, verification *sso.VerificationResult) { if h.revoker == nil || verification == nil || verification.Claims == nil { return } fingerprint := session.TokenFingerprint(token) if fingerprint == "" { return } if verification.Claims.ExpiresAt == nil { utils.Log.Warn("access token missing expiry claim") return } ttl := time.Until(verification.Claims.ExpiresAt.Time) if ttl <= 0 { return } if ttl < time.Second { ttl = time.Second } if err := h.revoker.Revoke(ctx, fingerprint, ttl); err != nil { utils.Log.WithError(err).Warn("failed to revoke access token") } } func (h *Controller) revokeRefreshToken(ctx context.Context, token string) { if h.revoker == nil { return } fingerprint := session.TokenFingerprint(token) if fingerprint == "" { return } const refreshTTL = 30 * 24 * time.Hour if err := h.revoker.Revoke(ctx, fingerprint, refreshTTL); err != nil { utils.Log.WithError(err).Warn("failed to revoke refresh token") } } func issueCookies(c *fiber.Ctx, tokenResp struct { AccessToken string `json:"access_token"` RefreshToken string `json:"refresh_token"` TokenType string `json:"token_type"` ExpiresIn int `json:"expires_in"` Scope string `json:"scope"` IDToken string `json:"id_token"` Error string `json:"error"` Description string `json:"error_description"` }, verification *sso.VerificationResult) error { if revoker := session.GetRevocationStore(); revoker != nil && verification != nil { if err := revoker.ClearUserLogout(c.Context(), verification.UserID); err != nil { utils.Log.WithError(err).Warn("failed to clear logout marker") } } if max := config.SSOAccessTokenMaxBytes; max > 0 && len(tokenResp.AccessToken) > max { utils.Log.WithFields(logrus.Fields{ "token_len": len(tokenResp.AccessToken), "max_len": max, }).Warn("sso access token exceeds cookie size limit") return fiber.NewError(fiber.StatusRequestEntityTooLarge, "access token too large") } accessName := resolveSSOCookieName(config.SSOAccessCookieName, "access") refreshName := resolveSSOCookieName(config.SSORefreshCookieName, "refresh") maxAge := tokenResp.ExpiresIn if maxAge <= 0 { maxAge = int(15 * time.Minute.Seconds()) } sameSite := config.SSOCookieSameSite if sameSite == "" { sameSite = "Lax" } cookieDomain := config.SSOCookieDomain cookieAccess := &fiber.Cookie{ Name: accessName, Value: tokenResp.AccessToken, Path: "/", Domain: cookieDomain, HTTPOnly: true, Secure: config.SSOCookieSecure, SameSite: sameSite, MaxAge: maxAge, } c.Cookie(cookieAccess) if tokenResp.RefreshToken != "" { cookieRefresh := &fiber.Cookie{ Name: refreshName, Value: tokenResp.RefreshToken, Path: "/", Domain: cookieDomain, HTTPOnly: true, Secure: config.SSOCookieSecure, SameSite: sameSite, MaxAge: int((time.Hour * 24 * 30).Seconds()), } c.Cookie(cookieRefresh) } // Optional: expose limited info via headers for FE debugging (avoid tokens) c.Set("X-Auth-User", fmt.Sprintf("%d", verification.UserID)) return nil } func clearSSOCookie(c *fiber.Ctx, name string) { if name == "" { return } sameSite := config.SSOCookieSameSite if sameSite == "" { sameSite = "Lax" } c.Cookie(&fiber.Cookie{ Name: name, Value: "", Path: "/", Domain: config.SSOCookieDomain, HTTPOnly: true, Secure: config.SSOCookieSecure, SameSite: sameSite, Expires: time.Unix(0, 0), MaxAge: -1, }) } func resolveSSOCookieName(configuredName, fallback string) string { name := strings.TrimSpace(configuredName) if name != "" { return name } return strings.TrimSpace(fallback) } func logSignatureError(ctxLabel, tokenSource, token string, err error) { info := sso.ExtractTokenInfo(token) aud := strings.Join(info.Aud, ",") utils.Log.Errorf( "access token verification failed: %v | ctx=%s source=%s iss=%s kid=%s aud=%s sub=%s exp=%d iat=%d nbf=%d expected_iss=%s expected_aud=%v jwks=%s", err, ctxLabel, tokenSource, info.Iss, info.Kid, aud, info.Sub, info.Exp, info.Iat, info.Nbf, config.SSOIssuer, config.SSOAllowedAudiences, config.SSOJWKSURL, ) } func normalizeClientParam(raw string) string { value := strings.TrimSpace(raw) if value == "" { return "" } if idx := strings.Index(value, "|"); idx >= 0 { value = value[:idx] } value = strings.TrimSpace(value) return strings.ToLower(value) } func findSSOClientConfig(requestedAlias string) (string, config.SSOClientConfig, bool) { if requestedAlias == "" { return "", config.SSOClientConfig{}, false } if cfg, ok := config.SSOClients[requestedAlias]; ok && strings.TrimSpace(cfg.PublicID) != "" { return requestedAlias, cfg, true } for alias, cfg := range config.SSOClients { if strings.EqualFold(strings.TrimSpace(cfg.PublicID), requestedAlias) && strings.TrimSpace(cfg.PublicID) != "" { return alias, cfg, true } } return "", config.SSOClientConfig{}, false } func normalizeReturnTarget(returnTo string, cfg config.SSOClientConfig) (string, error) { returnTo = strings.TrimSpace(returnTo) if returnTo == "" { return "", nil } if strings.HasPrefix(returnTo, "//") { return "", fmt.Errorf("invalid return_to") } if strings.HasPrefix(returnTo, "/") { return returnTo, nil } parsed, err := url.Parse(returnTo) if err != nil { return "", fmt.Errorf("invalid return_to") } if parsed.Scheme != "http" && parsed.Scheme != "https" { return "", fmt.Errorf("invalid return_to scheme") } allowedOrigins := make(map[string]struct{}) if cfg.DefaultReturnURI != "" { if u, err := url.Parse(cfg.DefaultReturnURI); err == nil && u.Host != "" { allowedOrigins[u.Scheme+"://"+u.Host] = struct{}{} } } for _, origin := range cfg.AllowedReturnOrigins { origin = strings.TrimSpace(origin) if origin == "" { continue } if u, err := url.Parse(origin); err == nil && u.Host != "" && (u.Scheme == "http" || u.Scheme == "https") { allowedOrigins[u.Scheme+"://"+u.Host] = struct{}{} } } if len(allowedOrigins) > 0 { origin := parsed.Scheme + "://" + parsed.Host if _, ok := allowedOrigins[origin]; !ok { return "", fmt.Errorf("return_to origin not allowed") } } return parsed.String(), nil }