mirror of
https://gitlab.com/mbugroup/lti-api.git
synced 2026-05-20 13:31:56 +00:00
feat/login crud in users sync with sso
This commit is contained in:
@@ -0,0 +1,404 @@
|
||||
package controllers
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/sirupsen/logrus"
|
||||
|
||||
"gitlab.com/mbugroup/lti-api.git/internal/config"
|
||||
"gitlab.com/mbugroup/lti-api.git/internal/modules/sso/session"
|
||||
"gitlab.com/mbugroup/lti-api.git/internal/sso"
|
||||
"gitlab.com/mbugroup/lti-api.git/internal/utils"
|
||||
"gitlab.com/mbugroup/lti-api.git/internal/utils/secure"
|
||||
)
|
||||
|
||||
// Controller manages the SSO start & callback flow using PKCE.
|
||||
type Controller struct {
|
||||
httpClient *http.Client
|
||||
store *session.Store
|
||||
}
|
||||
|
||||
func NewController(client *http.Client, store *session.Store) *Controller {
|
||||
return &Controller{httpClient: client, store: store}
|
||||
}
|
||||
|
||||
// Start handles GET /sso/start requests and redirects users to the central SSO authorize endpoint.
|
||||
func (h *Controller) Start(c *fiber.Ctx) error {
|
||||
alias := strings.ToLower(strings.TrimSpace(c.Query("client")))
|
||||
if alias == "" {
|
||||
alias = strings.ToLower(strings.TrimSpace(c.Query("client_id")))
|
||||
}
|
||||
if alias == "" {
|
||||
return fiber.NewError(fiber.StatusBadRequest, "missing client")
|
||||
}
|
||||
cfg, ok := config.SSOClients[alias]
|
||||
if !ok || cfg.PublicID == "" {
|
||||
return fiber.NewError(fiber.StatusBadRequest, "unknown client")
|
||||
}
|
||||
|
||||
authorizeEndpoint := strings.TrimSpace(config.SSOAuthorizeURL)
|
||||
if authorizeEndpoint == "" {
|
||||
return fiber.NewError(fiber.StatusInternalServerError, "authorize endpoint not configured")
|
||||
}
|
||||
|
||||
state, err := secure.RandomString(48)
|
||||
if err != nil {
|
||||
utils.Log.Errorf("generate state failed: %v", err)
|
||||
return fiber.NewError(fiber.StatusInternalServerError, "failed to prepare authorization")
|
||||
}
|
||||
|
||||
nonce, err := secure.RandomString(32)
|
||||
if err != nil {
|
||||
utils.Log.Errorf("generate nonce failed: %v", err)
|
||||
return fiber.NewError(fiber.StatusInternalServerError, "failed to prepare authorization")
|
||||
}
|
||||
|
||||
codeVerifier, err := secure.PKCECodeVerifier(96)
|
||||
if err != nil {
|
||||
utils.Log.Errorf("generate code verifier failed: %v", err)
|
||||
return fiber.NewError(fiber.StatusInternalServerError, "failed to prepare authorization")
|
||||
}
|
||||
|
||||
digest := sha256.Sum256([]byte(codeVerifier))
|
||||
challenge := secure.Base64URLEncode(digest[:])
|
||||
|
||||
authorizeURL, err := url.Parse(authorizeEndpoint)
|
||||
if err != nil {
|
||||
return fiber.NewError(fiber.StatusInternalServerError, "invalid authorize endpoint")
|
||||
}
|
||||
|
||||
scope := cfg.Scope
|
||||
if scope == "" {
|
||||
scope = "openid profile"
|
||||
}
|
||||
if !strings.Contains(" "+scope+" ", " openid ") {
|
||||
scope = scope + " openid"
|
||||
}
|
||||
|
||||
rawReturn := strings.TrimSpace(c.Query("return_to"))
|
||||
if rawReturn == "" {
|
||||
rawReturn = cfg.DefaultReturnURI
|
||||
}
|
||||
|
||||
returnTo, err := normalizeReturnTarget(rawReturn, cfg)
|
||||
if err != nil {
|
||||
return fiber.NewError(fiber.StatusBadRequest, err.Error())
|
||||
}
|
||||
|
||||
query := authorizeURL.Query()
|
||||
query.Set("response_type", "code")
|
||||
query.Set("client_id", cfg.PublicID)
|
||||
query.Set("redirect_uri", cfg.RedirectURI)
|
||||
query.Set("scope", strings.TrimSpace(scope))
|
||||
query.Set("state", state)
|
||||
query.Set("code_challenge", challenge)
|
||||
query.Set("code_challenge_method", "S256")
|
||||
query.Set("nonce", nonce)
|
||||
if prompt := strings.TrimSpace(cfg.Prompt); prompt != "" {
|
||||
query.Set("prompt", prompt)
|
||||
}
|
||||
if extraPrompt := strings.TrimSpace(c.Query("prompt")); extraPrompt != "" {
|
||||
query.Set("prompt", extraPrompt)
|
||||
}
|
||||
authorizeURL.RawQuery = query.Encode()
|
||||
|
||||
payload := &session.PKCESession{
|
||||
CodeVerifier: codeVerifier,
|
||||
Nonce: nonce,
|
||||
ClientAlias: alias,
|
||||
ClientID: cfg.PublicID,
|
||||
RedirectURI: cfg.RedirectURI,
|
||||
Scope: strings.TrimSpace(scope),
|
||||
ReturnTo: returnTo,
|
||||
CreatedAt: time.Now().UTC(),
|
||||
}
|
||||
|
||||
if err := h.store.Save(c.Context(), state, payload); err != nil {
|
||||
utils.Log.Errorf("store pkce session failed: %v", err)
|
||||
return fiber.NewError(fiber.StatusInternalServerError, "failed to prepare authorization")
|
||||
}
|
||||
|
||||
utils.Log.WithFields(logrus.Fields{
|
||||
"client": alias,
|
||||
"state": state,
|
||||
"return_to": returnTo,
|
||||
}).Info("sso start redirect")
|
||||
|
||||
return c.Redirect(authorizeURL.String(), fiber.StatusFound)
|
||||
}
|
||||
|
||||
// Callback handles the redirect from SSO containing the authorization code.
|
||||
func (h *Controller) Callback(c *fiber.Ctx) error {
|
||||
state := strings.TrimSpace(c.Query("state"))
|
||||
code := strings.TrimSpace(c.Query("code"))
|
||||
if state == "" || code == "" {
|
||||
return fiber.NewError(fiber.StatusBadRequest, "missing code or state")
|
||||
}
|
||||
|
||||
sessionData, err := h.store.Get(c.Context(), state)
|
||||
if err != nil {
|
||||
utils.Log.Errorf("load pkce session failed: %v", err)
|
||||
return fiber.NewError(fiber.StatusInternalServerError, "failed to validate authorization state")
|
||||
}
|
||||
if sessionData == nil {
|
||||
return fiber.NewError(fiber.StatusBadRequest, "authorization state not found or expired")
|
||||
}
|
||||
defer func() {
|
||||
if err := h.store.Delete(context.Background(), state); err != nil {
|
||||
utils.Log.Warnf("failed to delete pkce session: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
tokenEndpoint := strings.TrimSpace(config.SSOTokenURL)
|
||||
if tokenEndpoint == "" {
|
||||
return fiber.NewError(fiber.StatusInternalServerError, "token endpoint not configured")
|
||||
}
|
||||
|
||||
form := url.Values{}
|
||||
form.Set("grant_type", "authorization_code")
|
||||
form.Set("code", code)
|
||||
form.Set("code_verifier", sessionData.CodeVerifier)
|
||||
form.Set("redirect_uri", sessionData.RedirectURI)
|
||||
form.Set("client_id", sessionData.ClientID)
|
||||
|
||||
req, err := http.NewRequestWithContext(c.Context(), http.MethodPost, tokenEndpoint, strings.NewReader(form.Encode()))
|
||||
if err != nil {
|
||||
return fiber.NewError(fiber.StatusInternalServerError, "failed to create token request")
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||
|
||||
resp, err := h.httpClient.Do(req)
|
||||
if err != nil {
|
||||
utils.Log.Errorf("token request failed: %v", err)
|
||||
return fiber.NewError(fiber.StatusBadGateway, "failed to exchange authorization code")
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode >= 400 {
|
||||
utils.Log.Warnf("token response status %d", resp.StatusCode)
|
||||
return fiber.NewError(fiber.StatusBadGateway, "token exchange rejected")
|
||||
}
|
||||
|
||||
var tokenResp struct {
|
||||
AccessToken string `json:"access_token"`
|
||||
RefreshToken string `json:"refresh_token"`
|
||||
TokenType string `json:"token_type"`
|
||||
ExpiresIn int `json:"expires_in"`
|
||||
Scope string `json:"scope"`
|
||||
IDToken string `json:"id_token"`
|
||||
Error string `json:"error"`
|
||||
Description string `json:"error_description"`
|
||||
}
|
||||
|
||||
if err := json.NewDecoder(resp.Body).Decode(&tokenResp); err != nil {
|
||||
return fiber.NewError(fiber.StatusBadGateway, "invalid token response")
|
||||
}
|
||||
if tokenResp.Error != "" {
|
||||
return fiber.NewError(fiber.StatusBadGateway, tokenResp.Description)
|
||||
}
|
||||
if tokenResp.AccessToken == "" {
|
||||
return fiber.NewError(fiber.StatusBadGateway, "missing access token")
|
||||
}
|
||||
|
||||
verification, err := sso.VerifyAccessToken(tokenResp.AccessToken)
|
||||
if err != nil {
|
||||
utils.Log.Errorf("access token verification failed: %v", err)
|
||||
return fiber.NewError(fiber.StatusUnauthorized, "invalid access token")
|
||||
}
|
||||
|
||||
// prepare cookies
|
||||
issueCookies(c, tokenResp, verification)
|
||||
|
||||
redirectTarget := sessionData.ReturnTo
|
||||
if redirectTarget == "" {
|
||||
redirectTarget = "/"
|
||||
}
|
||||
|
||||
fmt.Println(sessionData.ClientAlias,"test")
|
||||
utils.Log.WithFields(logrus.Fields{
|
||||
"client": sessionData.ClientAlias,
|
||||
"user_id": verification.UserID,
|
||||
"return_to": redirectTarget,
|
||||
}).Info("sso callback successful")
|
||||
|
||||
return c.Redirect(redirectTarget, fiber.StatusFound)
|
||||
}
|
||||
|
||||
// UserInfo proxies the user profile from the central SSO so the frontend can obtain
|
||||
// enriched user metadata (roles, permissions, etc.) without exposing tokens to the browser.
|
||||
func (h *Controller) UserInfo(c *fiber.Ctx) error {
|
||||
accessName := config.SSOAccessCookieName
|
||||
if accessName == "" {
|
||||
accessName = "sso_access"
|
||||
}
|
||||
|
||||
token := strings.TrimSpace(c.Cookies(accessName))
|
||||
tokenFromCookie := token != ""
|
||||
|
||||
if !tokenFromCookie {
|
||||
authHeader := strings.TrimSpace(c.Get("Authorization"))
|
||||
if strings.HasPrefix(strings.ToLower(authHeader), "bearer ") {
|
||||
token = strings.TrimSpace(authHeader[7:])
|
||||
}
|
||||
}
|
||||
|
||||
if token == "" {
|
||||
return fiber.NewError(fiber.StatusUnauthorized, "unauthenticated")
|
||||
}
|
||||
|
||||
endpoint := strings.TrimSpace(config.SSOGetMeURL)
|
||||
if endpoint == "" {
|
||||
return fiber.NewError(fiber.StatusInternalServerError, "userinfo endpoint not configured")
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(c.Context(), http.MethodGet, endpoint, nil)
|
||||
if err != nil {
|
||||
utils.Log.Errorf("failed to build userinfo request: %v", err)
|
||||
return fiber.NewError(fiber.StatusInternalServerError, "failed to prepare userinfo request")
|
||||
}
|
||||
req.Header.Set("Accept", "application/json")
|
||||
|
||||
// SSO /auth/get-me expects the access cookie; add Authorization as well for compatibility.
|
||||
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token))
|
||||
if tokenFromCookie {
|
||||
req.Header.Set("Cookie", fmt.Sprintf("%s=%s", accessName, token))
|
||||
}
|
||||
|
||||
resp, err := h.httpClient.Do(req)
|
||||
if err != nil {
|
||||
utils.Log.Errorf("userinfo request failed: %v", err)
|
||||
return fiber.NewError(fiber.StatusBadGateway, "failed to fetch user profile")
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
utils.Log.WithFields(logrus.Fields{"status": resp.StatusCode}).Info("sso userinfo response")
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
utils.Log.Errorf("failed to read userinfo response: %v", err)
|
||||
return fiber.NewError(fiber.StatusBadGateway, "invalid user profile response")
|
||||
}
|
||||
|
||||
if ct := resp.Header.Get("Content-Type"); ct != "" {
|
||||
c.Set("Content-Type", ct)
|
||||
} else {
|
||||
c.Type("json")
|
||||
}
|
||||
|
||||
return c.Status(resp.StatusCode).Send(body)
|
||||
}
|
||||
|
||||
func issueCookies(c *fiber.Ctx, tokenResp struct {
|
||||
AccessToken string `json:"access_token"`
|
||||
RefreshToken string `json:"refresh_token"`
|
||||
TokenType string `json:"token_type"`
|
||||
ExpiresIn int `json:"expires_in"`
|
||||
Scope string `json:"scope"`
|
||||
IDToken string `json:"id_token"`
|
||||
Error string `json:"error"`
|
||||
Description string `json:"error_description"`
|
||||
}, verification *sso.VerificationResult) {
|
||||
fmt.Println(tokenResp.AccessToken)
|
||||
accessName := config.SSOAccessCookieName
|
||||
if accessName == "" {
|
||||
accessName = "access"
|
||||
}
|
||||
refreshName := config.SSORefreshCookieName
|
||||
if refreshName == "" {
|
||||
refreshName = "refresh"
|
||||
}
|
||||
maxAge := tokenResp.ExpiresIn
|
||||
if maxAge <= 0 {
|
||||
maxAge = int(15 * time.Minute.Seconds())
|
||||
}
|
||||
|
||||
sameSite := config.SSOCookieSameSite
|
||||
if sameSite == "" {
|
||||
sameSite = "Lax"
|
||||
}
|
||||
|
||||
cookieDomain := config.SSOCookieDomain
|
||||
|
||||
cookieAccess := &fiber.Cookie{
|
||||
Name: accessName,
|
||||
Value: tokenResp.AccessToken,
|
||||
Path: "/",
|
||||
Domain: cookieDomain,
|
||||
HTTPOnly: true,
|
||||
Secure: config.SSOCookieSecure,
|
||||
SameSite: sameSite,
|
||||
MaxAge: maxAge,
|
||||
}
|
||||
c.Cookie(cookieAccess)
|
||||
if tokenResp.RefreshToken != "" {
|
||||
cookieRefresh := &fiber.Cookie{
|
||||
Name: refreshName,
|
||||
Value: tokenResp.RefreshToken,
|
||||
Path: "/",
|
||||
Domain: cookieDomain,
|
||||
HTTPOnly: true,
|
||||
Secure: config.SSOCookieSecure,
|
||||
SameSite: sameSite,
|
||||
MaxAge: int((time.Hour * 24 * 30).Seconds()),
|
||||
}
|
||||
c.Cookie(cookieRefresh)
|
||||
}
|
||||
|
||||
// Optional: expose limited info via headers for FE debugging (avoid tokens)
|
||||
c.Set("X-Auth-User", fmt.Sprintf("%d", verification.UserID))
|
||||
}
|
||||
|
||||
func normalizeReturnTarget(returnTo string, cfg config.SSOClientConfig) (string, error) {
|
||||
returnTo = strings.TrimSpace(returnTo)
|
||||
if returnTo == "" {
|
||||
return "", nil
|
||||
}
|
||||
if strings.HasPrefix(returnTo, "//") {
|
||||
return "", fmt.Errorf("invalid return_to")
|
||||
}
|
||||
if strings.HasPrefix(returnTo, "/") {
|
||||
return returnTo, nil
|
||||
}
|
||||
|
||||
parsed, err := url.Parse(returnTo)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("invalid return_to")
|
||||
}
|
||||
if parsed.Scheme != "http" && parsed.Scheme != "https" {
|
||||
return "", fmt.Errorf("invalid return_to scheme")
|
||||
}
|
||||
|
||||
allowedOrigins := make(map[string]struct{})
|
||||
if cfg.DefaultReturnURI != "" {
|
||||
if u, err := url.Parse(cfg.DefaultReturnURI); err == nil && u.Host != "" {
|
||||
allowedOrigins[u.Scheme+"://"+u.Host] = struct{}{}
|
||||
}
|
||||
}
|
||||
for _, origin := range cfg.AllowedReturnOrigins {
|
||||
origin = strings.TrimSpace(origin)
|
||||
if origin == "" {
|
||||
continue
|
||||
}
|
||||
if u, err := url.Parse(origin); err == nil && u.Host != "" && (u.Scheme == "http" || u.Scheme == "https") {
|
||||
allowedOrigins[u.Scheme+"://"+u.Host] = struct{}{}
|
||||
}
|
||||
}
|
||||
|
||||
if len(allowedOrigins) > 0 {
|
||||
origin := parsed.Scheme + "://" + parsed.Host
|
||||
if _, ok := allowedOrigins[origin]; !ok {
|
||||
return "", fmt.Errorf("return_to origin not allowed")
|
||||
}
|
||||
}
|
||||
|
||||
return parsed.String(), nil
|
||||
}
|
||||
Reference in New Issue
Block a user