package controllers import ( "context" "crypto/sha256" "encoding/json" "fmt" "io" "net/http" "net/url" "strings" "time" "github.com/gofiber/fiber/v2" "github.com/sirupsen/logrus" "gitlab.com/mbugroup/lti-api.git/internal/config" "gitlab.com/mbugroup/lti-api.git/internal/modules/sso/session" "gitlab.com/mbugroup/lti-api.git/internal/sso" "gitlab.com/mbugroup/lti-api.git/internal/utils" "gitlab.com/mbugroup/lti-api.git/internal/utils/secure" ) // Controller manages the SSO start & callback flow using PKCE. type Controller struct { httpClient *http.Client store *session.Store } func NewController(client *http.Client, store *session.Store) *Controller { return &Controller{httpClient: client, store: store} } // Start handles GET /sso/start requests and redirects users to the central SSO authorize endpoint. func (h *Controller) Start(c *fiber.Ctx) error { alias := strings.ToLower(strings.TrimSpace(c.Query("client"))) if alias == "" { alias = strings.ToLower(strings.TrimSpace(c.Query("client_id"))) } if alias == "" { return fiber.NewError(fiber.StatusBadRequest, "missing client") } cfg, ok := config.SSOClients[alias] if !ok || cfg.PublicID == "" { return fiber.NewError(fiber.StatusBadRequest, "unknown client") } authorizeEndpoint := strings.TrimSpace(config.SSOAuthorizeURL) if authorizeEndpoint == "" { return fiber.NewError(fiber.StatusInternalServerError, "authorize endpoint not configured") } state, err := secure.RandomString(48) if err != nil { utils.Log.Errorf("generate state failed: %v", err) return fiber.NewError(fiber.StatusInternalServerError, "failed to prepare authorization") } nonce, err := secure.RandomString(32) if err != nil { utils.Log.Errorf("generate nonce failed: %v", err) return fiber.NewError(fiber.StatusInternalServerError, "failed to prepare authorization") } codeVerifier, err := secure.PKCECodeVerifier(96) if err != nil { utils.Log.Errorf("generate code verifier failed: %v", err) return fiber.NewError(fiber.StatusInternalServerError, "failed to prepare authorization") } digest := sha256.Sum256([]byte(codeVerifier)) challenge := secure.Base64URLEncode(digest[:]) authorizeURL, err := url.Parse(authorizeEndpoint) if err != nil { return fiber.NewError(fiber.StatusInternalServerError, "invalid authorize endpoint") } scope := cfg.Scope if scope == "" { scope = "openid profile" } if !strings.Contains(" "+scope+" ", " openid ") { scope = scope + " openid" } rawReturn := strings.TrimSpace(c.Query("return_to")) if rawReturn == "" { rawReturn = cfg.DefaultReturnURI } returnTo, err := normalizeReturnTarget(rawReturn, cfg) if err != nil { return fiber.NewError(fiber.StatusBadRequest, err.Error()) } query := authorizeURL.Query() query.Set("response_type", "code") query.Set("client_id", cfg.PublicID) query.Set("redirect_uri", cfg.RedirectURI) query.Set("scope", strings.TrimSpace(scope)) query.Set("state", state) query.Set("code_challenge", challenge) query.Set("code_challenge_method", "S256") query.Set("nonce", nonce) if prompt := strings.TrimSpace(cfg.Prompt); prompt != "" { query.Set("prompt", prompt) } if extraPrompt := strings.TrimSpace(c.Query("prompt")); extraPrompt != "" { query.Set("prompt", extraPrompt) } authorizeURL.RawQuery = query.Encode() payload := &session.PKCESession{ CodeVerifier: codeVerifier, Nonce: nonce, ClientAlias: alias, ClientID: cfg.PublicID, RedirectURI: cfg.RedirectURI, Scope: strings.TrimSpace(scope), ReturnTo: returnTo, CreatedAt: time.Now().UTC(), } if err := h.store.Save(c.Context(), state, payload); err != nil { utils.Log.Errorf("store pkce session failed: %v", err) return fiber.NewError(fiber.StatusInternalServerError, "failed to prepare authorization") } utils.Log.WithFields(logrus.Fields{ "client": alias, "state": state, "return_to": returnTo, }).Info("sso start redirect") return c.Redirect(authorizeURL.String(), fiber.StatusFound) } // Callback handles the redirect from SSO containing the authorization code. func (h *Controller) Callback(c *fiber.Ctx) error { state := strings.TrimSpace(c.Query("state")) code := strings.TrimSpace(c.Query("code")) if state == "" || code == "" { return fiber.NewError(fiber.StatusBadRequest, "missing code or state") } sessionData, err := h.store.Get(c.Context(), state) if err != nil { utils.Log.Errorf("load pkce session failed: %v", err) return fiber.NewError(fiber.StatusInternalServerError, "failed to validate authorization state") } if sessionData == nil { return fiber.NewError(fiber.StatusBadRequest, "authorization state not found or expired") } defer func() { if err := h.store.Delete(context.Background(), state); err != nil { utils.Log.Warnf("failed to delete pkce session: %v", err) } }() tokenEndpoint := strings.TrimSpace(config.SSOTokenURL) if tokenEndpoint == "" { return fiber.NewError(fiber.StatusInternalServerError, "token endpoint not configured") } form := url.Values{} form.Set("grant_type", "authorization_code") form.Set("code", code) form.Set("code_verifier", sessionData.CodeVerifier) form.Set("redirect_uri", sessionData.RedirectURI) form.Set("client_id", sessionData.ClientID) req, err := http.NewRequestWithContext(c.Context(), http.MethodPost, tokenEndpoint, strings.NewReader(form.Encode())) if err != nil { return fiber.NewError(fiber.StatusInternalServerError, "failed to create token request") } req.Header.Set("Content-Type", "application/x-www-form-urlencoded") resp, err := h.httpClient.Do(req) if err != nil { utils.Log.Errorf("token request failed: %v", err) return fiber.NewError(fiber.StatusBadGateway, "failed to exchange authorization code") } defer resp.Body.Close() if resp.StatusCode >= 400 { utils.Log.Warnf("token response status %d", resp.StatusCode) return fiber.NewError(fiber.StatusBadGateway, "token exchange rejected") } var tokenResp struct { AccessToken string `json:"access_token"` RefreshToken string `json:"refresh_token"` TokenType string `json:"token_type"` ExpiresIn int `json:"expires_in"` Scope string `json:"scope"` IDToken string `json:"id_token"` Error string `json:"error"` Description string `json:"error_description"` } if err := json.NewDecoder(resp.Body).Decode(&tokenResp); err != nil { return fiber.NewError(fiber.StatusBadGateway, "invalid token response") } if tokenResp.Error != "" { return fiber.NewError(fiber.StatusBadGateway, tokenResp.Description) } if tokenResp.AccessToken == "" { return fiber.NewError(fiber.StatusBadGateway, "missing access token") } verification, err := sso.VerifyAccessToken(tokenResp.AccessToken) if err != nil { utils.Log.Errorf("access token verification failed: %v", err) return fiber.NewError(fiber.StatusUnauthorized, "invalid access token") } // prepare cookies issueCookies(c, tokenResp, verification) redirectTarget := sessionData.ReturnTo if redirectTarget == "" { redirectTarget = "/" } fmt.Println(sessionData.ClientAlias,"test") utils.Log.WithFields(logrus.Fields{ "client": sessionData.ClientAlias, "user_id": verification.UserID, "return_to": redirectTarget, }).Info("sso callback successful") return c.Redirect(redirectTarget, fiber.StatusFound) } // UserInfo proxies the user profile from the central SSO so the frontend can obtain // enriched user metadata (roles, permissions, etc.) without exposing tokens to the browser. func (h *Controller) UserInfo(c *fiber.Ctx) error { accessName := config.SSOAccessCookieName if accessName == "" { accessName = "sso_access" } token := strings.TrimSpace(c.Cookies(accessName)) tokenFromCookie := token != "" if !tokenFromCookie { authHeader := strings.TrimSpace(c.Get("Authorization")) if strings.HasPrefix(strings.ToLower(authHeader), "bearer ") { token = strings.TrimSpace(authHeader[7:]) } } if token == "" { return fiber.NewError(fiber.StatusUnauthorized, "unauthenticated") } endpoint := strings.TrimSpace(config.SSOGetMeURL) if endpoint == "" { return fiber.NewError(fiber.StatusInternalServerError, "userinfo endpoint not configured") } req, err := http.NewRequestWithContext(c.Context(), http.MethodGet, endpoint, nil) if err != nil { utils.Log.Errorf("failed to build userinfo request: %v", err) return fiber.NewError(fiber.StatusInternalServerError, "failed to prepare userinfo request") } req.Header.Set("Accept", "application/json") // SSO /auth/get-me expects the access cookie; add Authorization as well for compatibility. req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token)) if tokenFromCookie { req.Header.Set("Cookie", fmt.Sprintf("%s=%s", accessName, token)) } resp, err := h.httpClient.Do(req) if err != nil { utils.Log.Errorf("userinfo request failed: %v", err) return fiber.NewError(fiber.StatusBadGateway, "failed to fetch user profile") } defer resp.Body.Close() utils.Log.WithFields(logrus.Fields{"status": resp.StatusCode}).Info("sso userinfo response") body, err := io.ReadAll(resp.Body) if err != nil { utils.Log.Errorf("failed to read userinfo response: %v", err) return fiber.NewError(fiber.StatusBadGateway, "invalid user profile response") } if ct := resp.Header.Get("Content-Type"); ct != "" { c.Set("Content-Type", ct) } else { c.Type("json") } return c.Status(resp.StatusCode).Send(body) } func issueCookies(c *fiber.Ctx, tokenResp struct { AccessToken string `json:"access_token"` RefreshToken string `json:"refresh_token"` TokenType string `json:"token_type"` ExpiresIn int `json:"expires_in"` Scope string `json:"scope"` IDToken string `json:"id_token"` Error string `json:"error"` Description string `json:"error_description"` }, verification *sso.VerificationResult) { fmt.Println(tokenResp.AccessToken) accessName := config.SSOAccessCookieName if accessName == "" { accessName = "access" } refreshName := config.SSORefreshCookieName if refreshName == "" { refreshName = "refresh" } maxAge := tokenResp.ExpiresIn if maxAge <= 0 { maxAge = int(15 * time.Minute.Seconds()) } sameSite := config.SSOCookieSameSite if sameSite == "" { sameSite = "Lax" } cookieDomain := config.SSOCookieDomain cookieAccess := &fiber.Cookie{ Name: accessName, Value: tokenResp.AccessToken, Path: "/", Domain: cookieDomain, HTTPOnly: true, Secure: config.SSOCookieSecure, SameSite: sameSite, MaxAge: maxAge, } c.Cookie(cookieAccess) if tokenResp.RefreshToken != "" { cookieRefresh := &fiber.Cookie{ Name: refreshName, Value: tokenResp.RefreshToken, Path: "/", Domain: cookieDomain, HTTPOnly: true, Secure: config.SSOCookieSecure, SameSite: sameSite, MaxAge: int((time.Hour * 24 * 30).Seconds()), } c.Cookie(cookieRefresh) } // Optional: expose limited info via headers for FE debugging (avoid tokens) c.Set("X-Auth-User", fmt.Sprintf("%d", verification.UserID)) } func normalizeReturnTarget(returnTo string, cfg config.SSOClientConfig) (string, error) { returnTo = strings.TrimSpace(returnTo) if returnTo == "" { return "", nil } if strings.HasPrefix(returnTo, "//") { return "", fmt.Errorf("invalid return_to") } if strings.HasPrefix(returnTo, "/") { return returnTo, nil } parsed, err := url.Parse(returnTo) if err != nil { return "", fmt.Errorf("invalid return_to") } if parsed.Scheme != "http" && parsed.Scheme != "https" { return "", fmt.Errorf("invalid return_to scheme") } allowedOrigins := make(map[string]struct{}) if cfg.DefaultReturnURI != "" { if u, err := url.Parse(cfg.DefaultReturnURI); err == nil && u.Host != "" { allowedOrigins[u.Scheme+"://"+u.Host] = struct{}{} } } for _, origin := range cfg.AllowedReturnOrigins { origin = strings.TrimSpace(origin) if origin == "" { continue } if u, err := url.Parse(origin); err == nil && u.Host != "" && (u.Scheme == "http" || u.Scheme == "https") { allowedOrigins[u.Scheme+"://"+u.Host] = struct{}{} } } if len(allowedOrigins) > 0 { origin := parsed.Scheme + "://" + parsed.Host if _, ok := allowedOrigins[origin]; !ok { return "", fmt.Errorf("return_to origin not allowed") } } return parsed.String(), nil }