mirror of
https://gitlab.com/mbugroup/lti-api.git
synced 2026-05-20 13:31:56 +00:00
Merge branch 'development' of https://gitlab.com/mbugroup/lti-api into refactor-to-serve/with-middleware
This commit is contained in:
@@ -0,0 +1,607 @@
|
||||
package controllers
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/sirupsen/logrus"
|
||||
|
||||
"gitlab.com/mbugroup/lti-api.git/internal/config"
|
||||
"gitlab.com/mbugroup/lti-api.git/internal/modules/sso/session"
|
||||
"gitlab.com/mbugroup/lti-api.git/internal/sso"
|
||||
"gitlab.com/mbugroup/lti-api.git/internal/utils"
|
||||
"gitlab.com/mbugroup/lti-api.git/internal/utils/secure"
|
||||
)
|
||||
|
||||
// Controller manages the SSO start & callback flow using PKCE.
|
||||
type Controller struct {
|
||||
httpClient *http.Client
|
||||
store *session.Store
|
||||
revoker *session.RevocationStore
|
||||
}
|
||||
|
||||
func NewController(client *http.Client, store *session.Store, revoker *session.RevocationStore) *Controller {
|
||||
return &Controller{httpClient: client, store: store, revoker: revoker}
|
||||
}
|
||||
|
||||
// Start handles GET /sso/start requests and redirects users to the central SSO authorize endpoint.
|
||||
func (h *Controller) Start(c *fiber.Ctx) error {
|
||||
requestedAlias := normalizeClientParam(c.Query("client"))
|
||||
if requestedAlias == "" {
|
||||
requestedAlias = normalizeClientParam(c.Query("client_id"))
|
||||
}
|
||||
if requestedAlias == "" {
|
||||
return fiber.NewError(fiber.StatusBadRequest, "missing client")
|
||||
}
|
||||
|
||||
alias, cfg, ok := findSSOClientConfig(requestedAlias)
|
||||
if !ok || cfg.PublicID == "" {
|
||||
return fiber.NewError(fiber.StatusBadRequest, "unknown client")
|
||||
}
|
||||
|
||||
authorizeEndpoint := strings.TrimSpace(config.SSOAuthorizeURL)
|
||||
if authorizeEndpoint == "" {
|
||||
return fiber.NewError(fiber.StatusInternalServerError, "authorize endpoint not configured")
|
||||
}
|
||||
|
||||
state, err := secure.RandomString(48)
|
||||
if err != nil {
|
||||
utils.Log.Errorf("generate state failed: %v", err)
|
||||
return fiber.NewError(fiber.StatusInternalServerError, "failed to prepare authorization")
|
||||
}
|
||||
|
||||
nonce, err := secure.RandomString(32)
|
||||
if err != nil {
|
||||
utils.Log.Errorf("generate nonce failed: %v", err)
|
||||
return fiber.NewError(fiber.StatusInternalServerError, "failed to prepare authorization")
|
||||
}
|
||||
|
||||
codeVerifier, err := secure.PKCECodeVerifier(96)
|
||||
if err != nil {
|
||||
utils.Log.Errorf("generate code verifier failed: %v", err)
|
||||
return fiber.NewError(fiber.StatusInternalServerError, "failed to prepare authorization")
|
||||
}
|
||||
|
||||
digest := sha256.Sum256([]byte(codeVerifier))
|
||||
challenge := secure.Base64URLEncode(digest[:])
|
||||
|
||||
authorizeURL, err := url.Parse(authorizeEndpoint)
|
||||
if err != nil {
|
||||
return fiber.NewError(fiber.StatusInternalServerError, "invalid authorize endpoint")
|
||||
}
|
||||
|
||||
scope := cfg.Scope
|
||||
if scope == "" {
|
||||
scope = "openid profile"
|
||||
}
|
||||
if !strings.Contains(" "+scope+" ", " openid ") {
|
||||
scope = scope + " openid"
|
||||
}
|
||||
|
||||
rawReturn := strings.TrimSpace(c.Query("return_to"))
|
||||
if rawReturn == "" {
|
||||
rawReturn = cfg.DefaultReturnURI
|
||||
}
|
||||
|
||||
returnTo, err := normalizeReturnTarget(rawReturn, cfg)
|
||||
if err != nil {
|
||||
return fiber.NewError(fiber.StatusBadRequest, err.Error())
|
||||
}
|
||||
|
||||
query := authorizeURL.Query()
|
||||
query.Set("response_type", "code")
|
||||
query.Set("client_id", cfg.PublicID)
|
||||
query.Set("redirect_uri", cfg.RedirectURI)
|
||||
query.Set("scope", strings.TrimSpace(scope))
|
||||
query.Set("state", state)
|
||||
query.Set("code_challenge", challenge)
|
||||
query.Set("code_challenge_method", "S256")
|
||||
query.Set("nonce", nonce)
|
||||
// if prompt := strings.TrimSpace(cfg.Prompt); prompt != "" {
|
||||
// query.Set("prompt", prompt)
|
||||
// }
|
||||
if extraPrompt := strings.TrimSpace(c.Query("prompt")); extraPrompt != "" {
|
||||
query.Set("prompt", extraPrompt)
|
||||
}
|
||||
authorizeURL.RawQuery = query.Encode()
|
||||
|
||||
payload := &session.PKCESession{
|
||||
CodeVerifier: codeVerifier,
|
||||
Nonce: nonce,
|
||||
ClientAlias: alias,
|
||||
ClientID: cfg.PublicID,
|
||||
RedirectURI: cfg.RedirectURI,
|
||||
Scope: strings.TrimSpace(scope),
|
||||
ReturnTo: returnTo,
|
||||
CreatedAt: time.Now().UTC(),
|
||||
}
|
||||
|
||||
if err := h.store.Save(c.Context(), state, payload); err != nil {
|
||||
utils.Log.Errorf("store pkce session failed: %v", err)
|
||||
return fiber.NewError(fiber.StatusInternalServerError, "failed to prepare authorization")
|
||||
}
|
||||
|
||||
utils.Log.WithFields(logrus.Fields{
|
||||
"client": alias,
|
||||
"state": state,
|
||||
"return_to": returnTo,
|
||||
}).Info("sso start redirect")
|
||||
|
||||
return c.Redirect(authorizeURL.String(), fiber.StatusFound)
|
||||
}
|
||||
|
||||
// Callback handles the redirect from SSO containing the authorization code.
|
||||
func (h *Controller) Callback(c *fiber.Ctx) error {
|
||||
state := strings.TrimSpace(c.Query("state"))
|
||||
code := strings.TrimSpace(c.Query("code"))
|
||||
if state == "" || code == "" {
|
||||
return fiber.NewError(fiber.StatusBadRequest, "missing code or state")
|
||||
}
|
||||
|
||||
sessionData, err := h.store.Get(c.Context(), state)
|
||||
if err != nil {
|
||||
utils.Log.Errorf("load pkce session failed: %v", err)
|
||||
return fiber.NewError(fiber.StatusInternalServerError, "failed to validate authorization state")
|
||||
}
|
||||
if sessionData == nil {
|
||||
return fiber.NewError(fiber.StatusBadRequest, "authorization state not found or expired")
|
||||
}
|
||||
defer func() {
|
||||
if err := h.store.Delete(context.Background(), state); err != nil {
|
||||
utils.Log.Warnf("failed to delete pkce session: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
tokenEndpoint := strings.TrimSpace(config.SSOTokenURL)
|
||||
if tokenEndpoint == "" {
|
||||
return fiber.NewError(fiber.StatusInternalServerError, "token endpoint not configured")
|
||||
}
|
||||
|
||||
form := url.Values{}
|
||||
form.Set("grant_type", "authorization_code")
|
||||
form.Set("code", code)
|
||||
form.Set("code_verifier", sessionData.CodeVerifier)
|
||||
form.Set("redirect_uri", sessionData.RedirectURI)
|
||||
form.Set("client_id", sessionData.ClientID)
|
||||
|
||||
req, err := http.NewRequestWithContext(c.Context(), http.MethodPost, tokenEndpoint, strings.NewReader(form.Encode()))
|
||||
if err != nil {
|
||||
return fiber.NewError(fiber.StatusInternalServerError, "failed to create token request")
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||
|
||||
resp, err := h.httpClient.Do(req)
|
||||
if err != nil {
|
||||
utils.Log.Errorf("token request failed: %v", err)
|
||||
return fiber.NewError(fiber.StatusBadGateway, "failed to exchange authorization code")
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode >= 400 {
|
||||
utils.Log.Warnf("token response status %d", resp.StatusCode)
|
||||
return fiber.NewError(fiber.StatusBadGateway, "token exchange rejected")
|
||||
}
|
||||
|
||||
var tokenResp struct {
|
||||
AccessToken string `json:"access_token"`
|
||||
RefreshToken string `json:"refresh_token"`
|
||||
TokenType string `json:"token_type"`
|
||||
ExpiresIn int `json:"expires_in"`
|
||||
Scope string `json:"scope"`
|
||||
IDToken string `json:"id_token"`
|
||||
Error string `json:"error"`
|
||||
Description string `json:"error_description"`
|
||||
}
|
||||
|
||||
if err := json.NewDecoder(resp.Body).Decode(&tokenResp); err != nil {
|
||||
return fiber.NewError(fiber.StatusBadGateway, "invalid token response")
|
||||
}
|
||||
if tokenResp.Error != "" {
|
||||
return fiber.NewError(fiber.StatusBadGateway, tokenResp.Description)
|
||||
}
|
||||
if tokenResp.AccessToken == "" {
|
||||
return fiber.NewError(fiber.StatusBadGateway, "missing access token")
|
||||
}
|
||||
|
||||
fmt.Println(tokenResp.AccessToken)
|
||||
verification, err := sso.VerifyAccessToken(tokenResp.AccessToken)
|
||||
if err != nil {
|
||||
utils.Log.Errorf("access token verification failed: %v", err)
|
||||
return fiber.NewError(fiber.StatusUnauthorized, "invalid access token")
|
||||
}
|
||||
|
||||
// prepare cookies
|
||||
issueCookies(c, tokenResp, verification)
|
||||
|
||||
redirectTarget := sessionData.ReturnTo
|
||||
if redirectTarget == "" {
|
||||
redirectTarget = "/"
|
||||
}
|
||||
|
||||
utils.Log.WithFields(logrus.Fields{
|
||||
"client": sessionData.ClientAlias,
|
||||
"user_id": verification.UserID,
|
||||
"return_to": redirectTarget,
|
||||
}).Info("sso callback successful")
|
||||
|
||||
return c.Redirect(redirectTarget, fiber.StatusFound)
|
||||
}
|
||||
|
||||
// UserInfo proxies the user profile from the central SSO so the frontend can obtain
|
||||
// enriched user metadata (roles, permissions, etc.) without exposing tokens to the browser.
|
||||
func (h *Controller) UserInfo(c *fiber.Ctx) error {
|
||||
accessName := config.SSOAccessCookieName
|
||||
if accessName == "" {
|
||||
accessName = "sso_access"
|
||||
}
|
||||
|
||||
token := strings.TrimSpace(c.Cookies(accessName))
|
||||
tokenFromCookie := token != ""
|
||||
|
||||
if !tokenFromCookie {
|
||||
authHeader := strings.TrimSpace(c.Get("Authorization"))
|
||||
if strings.HasPrefix(strings.ToLower(authHeader), "bearer ") {
|
||||
token = strings.TrimSpace(authHeader[7:])
|
||||
}
|
||||
}
|
||||
|
||||
if token == "" {
|
||||
return fiber.NewError(fiber.StatusUnauthorized, "unauthenticated")
|
||||
}
|
||||
|
||||
if revoker := session.GetRevocationStore(); revoker != nil {
|
||||
if fingerprint := session.TokenFingerprint(token); fingerprint != "" {
|
||||
revoked, err := revoker.IsRevoked(c.Context(), fingerprint)
|
||||
if err != nil {
|
||||
utils.Log.WithError(err).Warn("failed to check token revocation for userinfo")
|
||||
return fiber.NewError(fiber.StatusUnauthorized, "unauthenticated")
|
||||
}
|
||||
if revoked {
|
||||
return fiber.NewError(fiber.StatusUnauthorized, "unauthenticated")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if _, err := sso.VerifyAccessToken(token); err != nil {
|
||||
utils.Log.WithError(err).Warn("access token verification failed for userinfo")
|
||||
return fiber.NewError(fiber.StatusUnauthorized, "unauthenticated")
|
||||
}
|
||||
|
||||
endpoint := strings.TrimSpace(config.SSOGetMeURL)
|
||||
if endpoint == "" {
|
||||
return fiber.NewError(fiber.StatusInternalServerError, "userinfo endpoint not configured")
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(c.Context(), http.MethodGet, endpoint, nil)
|
||||
if err != nil {
|
||||
utils.Log.Errorf("failed to build userinfo request: %v", err)
|
||||
return fiber.NewError(fiber.StatusInternalServerError, "failed to prepare userinfo request")
|
||||
}
|
||||
req.Header.Set("Accept", "application/json")
|
||||
|
||||
// SSO /auth/get-me expects the access cookie; add Authorization as well for compatibility.
|
||||
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token))
|
||||
if tokenFromCookie {
|
||||
req.Header.Set("Cookie", fmt.Sprintf("%s=%s", accessName, token))
|
||||
}
|
||||
|
||||
resp, err := h.httpClient.Do(req)
|
||||
if err != nil {
|
||||
utils.Log.Errorf("userinfo request failed: %v", err)
|
||||
return fiber.NewError(fiber.StatusBadGateway, "failed to fetch user profile")
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
utils.Log.WithFields(logrus.Fields{"status": resp.StatusCode}).Info("sso userinfo response")
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
utils.Log.Errorf("failed to read userinfo response: %v", err)
|
||||
return fiber.NewError(fiber.StatusBadGateway, "invalid user profile response")
|
||||
}
|
||||
|
||||
if ct := resp.Header.Get("Content-Type"); ct != "" {
|
||||
c.Set("Content-Type", ct)
|
||||
} else {
|
||||
c.Type("json")
|
||||
}
|
||||
|
||||
return c.Status(resp.StatusCode).Send(body)
|
||||
}
|
||||
|
||||
// Logout clears SSO cookies and removes any leftover PKCE session state.
|
||||
func (h *Controller) Logout(c *fiber.Ctx) error {
|
||||
requestedAlias := normalizeClientParam(c.Query("client"))
|
||||
if requestedAlias == "" {
|
||||
requestedAlias = normalizeClientParam(c.Query("client_id"))
|
||||
}
|
||||
var (
|
||||
alias string
|
||||
cfg config.SSOClientConfig
|
||||
hasClientInfo bool
|
||||
)
|
||||
if requestedAlias != "" {
|
||||
alias, cfg, hasClientInfo = findSSOClientConfig(requestedAlias)
|
||||
}
|
||||
|
||||
accessName := resolveSSOCookieName(config.SSOAccessCookieName, "access")
|
||||
refreshName := resolveSSOCookieName(config.SSORefreshCookieName, "refresh")
|
||||
|
||||
var accessToken, refreshToken string
|
||||
if accessName != "" {
|
||||
accessToken = strings.TrimSpace(c.Cookies(accessName))
|
||||
}
|
||||
if refreshName != "" {
|
||||
refreshToken = strings.TrimSpace(c.Cookies(refreshName))
|
||||
}
|
||||
hadAccessCookie := accessToken != ""
|
||||
hadRefreshCookie := refreshToken != ""
|
||||
|
||||
state := strings.TrimSpace(c.Query("state"))
|
||||
if state != "" {
|
||||
if err := h.store.Delete(c.Context(), state); err != nil {
|
||||
utils.Log.Warnf("failed to delete pkce session during logout: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
if !hadAccessCookie && !hadRefreshCookie && state == "" {
|
||||
return fiber.NewError(fiber.StatusUnauthorized, "not authenticated")
|
||||
}
|
||||
|
||||
if hadAccessCookie {
|
||||
if verification, err := sso.VerifyAccessToken(accessToken); err != nil {
|
||||
utils.Log.WithError(err).Warn("failed to verify access token during logout")
|
||||
} else {
|
||||
if revoker := session.GetRevocationStore(); revoker != nil {
|
||||
if err := revoker.MarkUserLogout(c.Context(), verification.UserID, time.Now().UTC()); err != nil {
|
||||
utils.Log.WithError(err).Warn("failed to mark user logout")
|
||||
}
|
||||
}
|
||||
h.revokeToken(c.Context(), accessToken, verification)
|
||||
}
|
||||
}
|
||||
if refreshToken != "" {
|
||||
h.revokeRefreshToken(c.Context(), refreshToken)
|
||||
}
|
||||
|
||||
clearSSOCookie(c, accessName)
|
||||
clearSSOCookie(c, refreshName)
|
||||
|
||||
redirectTarget := ""
|
||||
rawReturn := strings.TrimSpace(c.Query("return_to"))
|
||||
if hasClientInfo {
|
||||
if rawReturn == "" {
|
||||
rawReturn = cfg.DefaultReturnURI
|
||||
}
|
||||
if normalized, err := normalizeReturnTarget(rawReturn, cfg); err == nil {
|
||||
redirectTarget = normalized
|
||||
} else if rawReturn != "" {
|
||||
utils.Log.WithError(err).Warn("invalid return_to during logout")
|
||||
}
|
||||
} else if rawReturn != "" {
|
||||
if strings.HasPrefix(rawReturn, "/") && !strings.HasPrefix(rawReturn, "//") {
|
||||
redirectTarget = rawReturn
|
||||
}
|
||||
}
|
||||
|
||||
utils.Log.WithFields(logrus.Fields{
|
||||
"client": alias,
|
||||
"state": state,
|
||||
"redirect": redirectTarget,
|
||||
}).Info("sso logout completed")
|
||||
|
||||
if redirectTarget != "" {
|
||||
return c.Redirect(redirectTarget, fiber.StatusFound)
|
||||
}
|
||||
|
||||
return c.Status(fiber.StatusOK).JSON(fiber.Map{"status": "signed out"})
|
||||
}
|
||||
|
||||
func (h *Controller) revokeToken(ctx context.Context, token string, verification *sso.VerificationResult) {
|
||||
if h.revoker == nil || verification == nil || verification.Claims == nil {
|
||||
return
|
||||
}
|
||||
fingerprint := session.TokenFingerprint(token)
|
||||
if fingerprint == "" {
|
||||
return
|
||||
}
|
||||
if verification.Claims.ExpiresAt == nil {
|
||||
utils.Log.Warn("access token missing expiry claim")
|
||||
return
|
||||
}
|
||||
ttl := time.Until(verification.Claims.ExpiresAt.Time)
|
||||
if ttl <= 0 {
|
||||
return
|
||||
}
|
||||
if ttl < time.Second {
|
||||
ttl = time.Second
|
||||
}
|
||||
if err := h.revoker.Revoke(ctx, fingerprint, ttl); err != nil {
|
||||
utils.Log.WithError(err).Warn("failed to revoke access token")
|
||||
}
|
||||
}
|
||||
|
||||
func (h *Controller) revokeRefreshToken(ctx context.Context, token string) {
|
||||
if h.revoker == nil {
|
||||
return
|
||||
}
|
||||
fingerprint := session.TokenFingerprint(token)
|
||||
if fingerprint == "" {
|
||||
return
|
||||
}
|
||||
const refreshTTL = 30 * 24 * time.Hour
|
||||
if err := h.revoker.Revoke(ctx, fingerprint, refreshTTL); err != nil {
|
||||
utils.Log.WithError(err).Warn("failed to revoke refresh token")
|
||||
}
|
||||
}
|
||||
|
||||
func issueCookies(c *fiber.Ctx, tokenResp struct {
|
||||
AccessToken string `json:"access_token"`
|
||||
RefreshToken string `json:"refresh_token"`
|
||||
TokenType string `json:"token_type"`
|
||||
ExpiresIn int `json:"expires_in"`
|
||||
Scope string `json:"scope"`
|
||||
IDToken string `json:"id_token"`
|
||||
Error string `json:"error"`
|
||||
Description string `json:"error_description"`
|
||||
}, verification *sso.VerificationResult) {
|
||||
if revoker := session.GetRevocationStore(); revoker != nil && verification != nil {
|
||||
if err := revoker.ClearUserLogout(c.Context(), verification.UserID); err != nil {
|
||||
utils.Log.WithError(err).Warn("failed to clear logout marker")
|
||||
}
|
||||
}
|
||||
|
||||
accessName := resolveSSOCookieName(config.SSOAccessCookieName, "access")
|
||||
refreshName := resolveSSOCookieName(config.SSORefreshCookieName, "refresh")
|
||||
maxAge := tokenResp.ExpiresIn
|
||||
if maxAge <= 0 {
|
||||
maxAge = int(15 * time.Minute.Seconds())
|
||||
}
|
||||
|
||||
sameSite := config.SSOCookieSameSite
|
||||
if sameSite == "" {
|
||||
sameSite = "Lax"
|
||||
}
|
||||
|
||||
cookieDomain := config.SSOCookieDomain
|
||||
|
||||
cookieAccess := &fiber.Cookie{
|
||||
Name: accessName,
|
||||
Value: tokenResp.AccessToken,
|
||||
Path: "/",
|
||||
Domain: cookieDomain,
|
||||
HTTPOnly: true,
|
||||
Secure: config.SSOCookieSecure,
|
||||
SameSite: sameSite,
|
||||
MaxAge: maxAge,
|
||||
}
|
||||
c.Cookie(cookieAccess)
|
||||
if tokenResp.RefreshToken != "" {
|
||||
cookieRefresh := &fiber.Cookie{
|
||||
Name: refreshName,
|
||||
Value: tokenResp.RefreshToken,
|
||||
Path: "/",
|
||||
Domain: cookieDomain,
|
||||
HTTPOnly: true,
|
||||
Secure: config.SSOCookieSecure,
|
||||
SameSite: sameSite,
|
||||
MaxAge: int((time.Hour * 24 * 30).Seconds()),
|
||||
}
|
||||
c.Cookie(cookieRefresh)
|
||||
}
|
||||
|
||||
// Optional: expose limited info via headers for FE debugging (avoid tokens)
|
||||
c.Set("X-Auth-User", fmt.Sprintf("%d", verification.UserID))
|
||||
}
|
||||
|
||||
func clearSSOCookie(c *fiber.Ctx, name string) {
|
||||
if name == "" {
|
||||
return
|
||||
}
|
||||
|
||||
sameSite := config.SSOCookieSameSite
|
||||
if sameSite == "" {
|
||||
sameSite = "Lax"
|
||||
}
|
||||
|
||||
c.Cookie(&fiber.Cookie{
|
||||
Name: name,
|
||||
Value: "",
|
||||
Path: "/",
|
||||
Domain: config.SSOCookieDomain,
|
||||
HTTPOnly: true,
|
||||
Secure: config.SSOCookieSecure,
|
||||
SameSite: sameSite,
|
||||
Expires: time.Unix(0, 0),
|
||||
MaxAge: -1,
|
||||
})
|
||||
}
|
||||
|
||||
func resolveSSOCookieName(configuredName, fallback string) string {
|
||||
name := strings.TrimSpace(configuredName)
|
||||
if name != "" {
|
||||
return name
|
||||
}
|
||||
return strings.TrimSpace(fallback)
|
||||
}
|
||||
|
||||
func normalizeClientParam(raw string) string {
|
||||
value := strings.TrimSpace(raw)
|
||||
if value == "" {
|
||||
return ""
|
||||
}
|
||||
if idx := strings.Index(value, "|"); idx >= 0 {
|
||||
value = value[:idx]
|
||||
}
|
||||
value = strings.TrimSpace(value)
|
||||
return strings.ToLower(value)
|
||||
}
|
||||
|
||||
func findSSOClientConfig(requestedAlias string) (string, config.SSOClientConfig, bool) {
|
||||
if requestedAlias == "" {
|
||||
return "", config.SSOClientConfig{}, false
|
||||
}
|
||||
if cfg, ok := config.SSOClients[requestedAlias]; ok && strings.TrimSpace(cfg.PublicID) != "" {
|
||||
return requestedAlias, cfg, true
|
||||
}
|
||||
for alias, cfg := range config.SSOClients {
|
||||
if strings.EqualFold(strings.TrimSpace(cfg.PublicID), requestedAlias) && strings.TrimSpace(cfg.PublicID) != "" {
|
||||
return alias, cfg, true
|
||||
}
|
||||
}
|
||||
return "", config.SSOClientConfig{}, false
|
||||
}
|
||||
|
||||
func normalizeReturnTarget(returnTo string, cfg config.SSOClientConfig) (string, error) {
|
||||
returnTo = strings.TrimSpace(returnTo)
|
||||
if returnTo == "" {
|
||||
return "", nil
|
||||
}
|
||||
if strings.HasPrefix(returnTo, "//") {
|
||||
return "", fmt.Errorf("invalid return_to")
|
||||
}
|
||||
if strings.HasPrefix(returnTo, "/") {
|
||||
return returnTo, nil
|
||||
}
|
||||
|
||||
parsed, err := url.Parse(returnTo)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("invalid return_to")
|
||||
}
|
||||
if parsed.Scheme != "http" && parsed.Scheme != "https" {
|
||||
return "", fmt.Errorf("invalid return_to scheme")
|
||||
}
|
||||
|
||||
allowedOrigins := make(map[string]struct{})
|
||||
if cfg.DefaultReturnURI != "" {
|
||||
if u, err := url.Parse(cfg.DefaultReturnURI); err == nil && u.Host != "" {
|
||||
allowedOrigins[u.Scheme+"://"+u.Host] = struct{}{}
|
||||
}
|
||||
}
|
||||
for _, origin := range cfg.AllowedReturnOrigins {
|
||||
origin = strings.TrimSpace(origin)
|
||||
if origin == "" {
|
||||
continue
|
||||
}
|
||||
if u, err := url.Parse(origin); err == nil && u.Host != "" && (u.Scheme == "http" || u.Scheme == "https") {
|
||||
allowedOrigins[u.Scheme+"://"+u.Host] = struct{}{}
|
||||
}
|
||||
}
|
||||
|
||||
if len(allowedOrigins) > 0 {
|
||||
origin := parsed.Scheme + "://" + parsed.Host
|
||||
if _, ok := allowedOrigins[origin]; !ok {
|
||||
return "", fmt.Errorf("return_to origin not allowed")
|
||||
}
|
||||
}
|
||||
|
||||
return parsed.String(), nil
|
||||
}
|
||||
@@ -0,0 +1,429 @@
|
||||
package controllers
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/hmac"
|
||||
"crypto/sha256"
|
||||
"encoding/base64"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/go-playground/validator/v10"
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/redis/go-redis/v9"
|
||||
"github.com/sirupsen/logrus"
|
||||
"gorm.io/gorm"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"gitlab.com/mbugroup/lti-api.git/internal/config"
|
||||
entity "gitlab.com/mbugroup/lti-api.git/internal/entities"
|
||||
"gitlab.com/mbugroup/lti-api.git/internal/modules/sso/session"
|
||||
"gitlab.com/mbugroup/lti-api.git/internal/modules/users/dto"
|
||||
userRepository "gitlab.com/mbugroup/lti-api.git/internal/modules/users/repositories"
|
||||
"gitlab.com/mbugroup/lti-api.git/internal/response"
|
||||
"gitlab.com/mbugroup/lti-api.git/internal/sso"
|
||||
"gitlab.com/mbugroup/lti-api.git/internal/utils"
|
||||
)
|
||||
|
||||
const (
|
||||
headerClient = "X-Sync-Client"
|
||||
headerTimestamp = "X-Sync-Timestamp"
|
||||
headerNonce = "X-Sync-Nonce"
|
||||
headerSignature = "X-Sync-Signature"
|
||||
defaultDrift = 2 * time.Minute
|
||||
defaultNonceTTL = 10 * time.Minute
|
||||
)
|
||||
|
||||
// UserSyncController handles incoming user management events from the central SSO service.
|
||||
type UserSyncController struct {
|
||||
validate *validator.Validate
|
||||
repo userRepository.UserRepository
|
||||
redis *redis.Client
|
||||
clients map[string]config.SSOClientConfig
|
||||
drift time.Duration
|
||||
nonceTTL time.Duration
|
||||
maxBodyBytes int
|
||||
log *logrus.Logger
|
||||
localNonces sync.Map
|
||||
}
|
||||
|
||||
type userSyncRequest struct {
|
||||
Action string `json:"action" validate:"required,oneof=create update delete logout"`
|
||||
PublicID string `json:"public_id" validate:"required"`
|
||||
User userSyncUser `json:"user" validate:"required"`
|
||||
}
|
||||
|
||||
type userSyncUser struct {
|
||||
ID int64 `json:"id" validate:"required"`
|
||||
Email string `json:"email"`
|
||||
Name string `json:"name"`
|
||||
}
|
||||
|
||||
func NewUserSyncController(validate *validator.Validate, repo userRepository.UserRepository, redis *redis.Client, clients map[string]config.SSOClientConfig) *UserSyncController {
|
||||
normalized := make(map[string]config.SSOClientConfig, len(clients))
|
||||
for alias, cfg := range clients {
|
||||
alias = strings.ToLower(strings.TrimSpace(alias))
|
||||
normalized[alias] = cfg
|
||||
}
|
||||
|
||||
drift := config.SSOUserSyncDrift
|
||||
if drift <= 0 {
|
||||
drift = defaultDrift
|
||||
}
|
||||
|
||||
nonceTTL := config.SSOUserSyncNonceTTL
|
||||
if nonceTTL <= 0 {
|
||||
nonceTTL = defaultNonceTTL
|
||||
}
|
||||
|
||||
maxBody := config.SSOUserSyncMaxBodyBytes
|
||||
if maxBody <= 0 {
|
||||
maxBody = 32 * 1024
|
||||
}
|
||||
|
||||
log := utils.Log
|
||||
if redis == nil {
|
||||
log.Warn("SSO user sync nonce store fallback to in-memory cache; enable Redis for replay protection")
|
||||
}
|
||||
|
||||
return &UserSyncController{
|
||||
validate: validate,
|
||||
repo: repo,
|
||||
redis: redis,
|
||||
clients: normalized,
|
||||
drift: drift,
|
||||
nonceTTL: nonceTTL,
|
||||
maxBodyBytes: maxBody,
|
||||
log: log,
|
||||
}
|
||||
}
|
||||
|
||||
func (h *UserSyncController) Sync(c *fiber.Ctx) error {
|
||||
if ct := strings.TrimSpace(c.Get(fiber.HeaderContentType)); ct != "" && !strings.HasPrefix(strings.ToLower(ct), fiber.MIMEApplicationJSON) {
|
||||
return fiber.NewError(fiber.StatusUnsupportedMediaType, "content-type must be application/json")
|
||||
}
|
||||
|
||||
body := c.Body()
|
||||
if h.maxBodyBytes > 0 && len(body) > h.maxBodyBytes {
|
||||
return fiber.NewError(fiber.StatusRequestEntityTooLarge, "request body too large")
|
||||
}
|
||||
|
||||
alias, clientCfg, err := h.authenticate(c, body)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
req := new(userSyncRequest)
|
||||
if err := json.Unmarshal(body, req); err != nil {
|
||||
return fiber.NewError(fiber.StatusBadRequest, "invalid request body")
|
||||
}
|
||||
|
||||
req.Action = strings.ToLower(strings.TrimSpace(req.Action))
|
||||
req.PublicID = strings.TrimSpace(req.PublicID)
|
||||
req.User.Email = strings.TrimSpace(req.User.Email)
|
||||
req.User.Name = strings.TrimSpace(req.User.Name)
|
||||
|
||||
if err := h.validate.Struct(req); err != nil {
|
||||
return fiber.NewError(fiber.StatusBadRequest, err.Error())
|
||||
}
|
||||
|
||||
if clientCfg.PublicID != "" && req.PublicID != clientCfg.PublicID {
|
||||
return fiber.NewError(fiber.StatusBadRequest, "public_id mismatch with configured client")
|
||||
}
|
||||
|
||||
if req.Action == "create" || req.Action == "update" {
|
||||
if req.User.Email == "" || req.User.Name == "" {
|
||||
return fiber.NewError(fiber.StatusBadRequest, "email and name are required for create/update actions")
|
||||
}
|
||||
if err := h.validate.Var(req.User.Email, "email"); err != nil {
|
||||
return fiber.NewError(fiber.StatusBadRequest, "invalid email format")
|
||||
}
|
||||
}
|
||||
|
||||
if req.User.ID <= 0 {
|
||||
return fiber.NewError(fiber.StatusBadRequest, "invalid user id")
|
||||
}
|
||||
|
||||
switch req.Action {
|
||||
case "create", "update":
|
||||
return h.upsertUser(c, alias, req)
|
||||
case "delete":
|
||||
return h.removeUser(c, alias, req)
|
||||
case "logout":
|
||||
return h.logoutUser(c, alias, req)
|
||||
default:
|
||||
return fiber.NewError(fiber.StatusBadRequest, "unsupported action")
|
||||
}
|
||||
}
|
||||
|
||||
func (h *UserSyncController) authenticate(c *fiber.Ctx, body []byte) (string, config.SSOClientConfig, error) {
|
||||
rawAlias := strings.TrimSpace(c.Get(headerClient))
|
||||
if rawAlias == "" {
|
||||
return "", config.SSOClientConfig{}, fiber.NewError(fiber.StatusUnauthorized, "missing sync client header")
|
||||
}
|
||||
|
||||
aliasKey := strings.ToLower(rawAlias)
|
||||
clientCfg, ok := h.clients[aliasKey]
|
||||
if !ok {
|
||||
return "", config.SSOClientConfig{}, fiber.NewError(fiber.StatusUnauthorized, "unknown sync client")
|
||||
}
|
||||
|
||||
if err := h.verifyAuthorization(c, aliasKey); err != nil {
|
||||
return "", config.SSOClientConfig{}, err
|
||||
}
|
||||
|
||||
secret := strings.TrimSpace(clientCfg.SyncSecret)
|
||||
if secret == "" {
|
||||
return "", config.SSOClientConfig{}, fiber.NewError(fiber.StatusUnauthorized, "sync secret not configured")
|
||||
}
|
||||
|
||||
timestamp := strings.TrimSpace(c.Get(headerTimestamp))
|
||||
nonce := strings.TrimSpace(c.Get(headerNonce))
|
||||
signature := strings.TrimSpace(c.Get(headerSignature))
|
||||
|
||||
if timestamp == "" || nonce == "" || signature == "" {
|
||||
return "", config.SSOClientConfig{}, fiber.NewError(fiber.StatusUnauthorized, "missing signature headers")
|
||||
}
|
||||
if len(nonce) < 16 {
|
||||
return "", config.SSOClientConfig{}, fiber.NewError(fiber.StatusUnauthorized, "nonce too short")
|
||||
}
|
||||
|
||||
ts, err := strconv.ParseInt(timestamp, 10, 64)
|
||||
if err != nil {
|
||||
return "", config.SSOClientConfig{}, fiber.NewError(fiber.StatusBadRequest, "invalid timestamp")
|
||||
}
|
||||
|
||||
msgTime := time.Unix(ts, 0).UTC()
|
||||
now := time.Now().UTC()
|
||||
drift := now.Sub(msgTime)
|
||||
if drift > h.drift || drift < -h.drift {
|
||||
return "", config.SSOClientConfig{}, fiber.NewError(fiber.StatusUnauthorized, "timestamp outside allowed window")
|
||||
}
|
||||
|
||||
providedSig, err := decodeSignature(signature)
|
||||
if err != nil {
|
||||
return "", config.SSOClientConfig{}, fiber.NewError(fiber.StatusUnauthorized, "invalid signature encoding")
|
||||
}
|
||||
|
||||
expectedSignature := h.calculateSignature(secret, rawAlias, timestamp, nonce, body)
|
||||
if !hmac.Equal(providedSig, expectedSignature) {
|
||||
bodyHash := sha256.Sum256(body)
|
||||
h.log.WithFields(logrus.Fields{
|
||||
"alias": rawAlias,
|
||||
"alias_key": aliasKey,
|
||||
"timestamp": timestamp,
|
||||
"nonce": nonce,
|
||||
"body_len": len(body),
|
||||
"body_sha256": hex.EncodeToString(bodyHash[:]),
|
||||
"body_base64": base64.StdEncoding.EncodeToString(body),
|
||||
"provided_hex_full": hex.EncodeToString(providedSig),
|
||||
"expected_hex_full": hex.EncodeToString(expectedSignature),
|
||||
}).Warn("sso sync signature mismatch")
|
||||
return "", config.SSOClientConfig{}, fiber.NewError(fiber.StatusUnauthorized, "invalid signature")
|
||||
}
|
||||
|
||||
if err := h.registerNonce(c.Context(), aliasKey, nonce); err != nil {
|
||||
return "", config.SSOClientConfig{}, err
|
||||
}
|
||||
|
||||
return aliasKey, clientCfg, nil
|
||||
}
|
||||
|
||||
func (h *UserSyncController) verifyAuthorization(c *fiber.Ctx, alias string) error {
|
||||
|
||||
authHeader := strings.TrimSpace(c.Get(fiber.HeaderAuthorization))
|
||||
if authHeader == "" {
|
||||
return fiber.NewError(fiber.StatusUnauthorized, "missing authorization header")
|
||||
}
|
||||
parts := strings.SplitN(authHeader, " ", 2)
|
||||
if len(parts) != 2 || !strings.EqualFold(parts[0], "Bearer") {
|
||||
return fiber.NewError(fiber.StatusUnauthorized, "invalid authorization header")
|
||||
}
|
||||
|
||||
token := strings.TrimSpace(parts[1])
|
||||
if token == "" {
|
||||
return fiber.NewError(fiber.StatusUnauthorized, "invalid authorization header")
|
||||
}
|
||||
|
||||
verification, err := sso.VerifyAccessToken(token)
|
||||
if err != nil {
|
||||
return fiber.NewError(fiber.StatusUnauthorized, "invalid access token")
|
||||
}
|
||||
|
||||
if verification.ServiceAlias == "" || verification.ServiceAlias != alias {
|
||||
return fiber.NewError(fiber.StatusUnauthorized, "service subject mismatch")
|
||||
}
|
||||
if !containsScope(verification.Claims.Scopes(), "sync.users") {
|
||||
return fiber.NewError(fiber.StatusForbidden, "missing sync scope")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h *UserSyncController) upsertUser(c *fiber.Ctx, alias string, req *userSyncRequest) error {
|
||||
entity := &entity.User{
|
||||
IdUser: req.User.ID,
|
||||
Email: req.User.Email,
|
||||
Name: req.User.Name,
|
||||
}
|
||||
|
||||
//TODO: MIGRATION TO UPSERT BASE REPOSITORY
|
||||
if err := h.repo.UpsertByIdUser(c.Context(), entity); err != nil {
|
||||
h.log.Errorf("sso user upsert failed: %v", err)
|
||||
return fiber.NewError(fiber.StatusInternalServerError, "failed to upsert user")
|
||||
}
|
||||
|
||||
user, err := h.repo.GetByIdUser(c.Context(), req.User.ID, nil)
|
||||
if err != nil {
|
||||
h.log.Errorf("sso user fetch after upsert failed: %v", err)
|
||||
return fiber.NewError(fiber.StatusInternalServerError, "failed to load user")
|
||||
}
|
||||
|
||||
h.log.WithFields(logrus.Fields{
|
||||
"action": req.Action,
|
||||
"public_id": req.PublicID,
|
||||
"alias": alias,
|
||||
"user_id": req.User.ID,
|
||||
}).Info("sso user synced")
|
||||
|
||||
msg := fmt.Sprintf("User %s successfully", req.Action)
|
||||
return c.Status(fiber.StatusOK).JSON(response.Success{
|
||||
Code: fiber.StatusOK,
|
||||
Status: "success",
|
||||
Message: msg,
|
||||
Data: dto.ToUserListDTO(*user),
|
||||
})
|
||||
}
|
||||
|
||||
func (h *UserSyncController) logoutUser(c *fiber.Ctx, alias string, req *userSyncRequest) error {
|
||||
revoker := session.GetRevocationStore()
|
||||
if revoker != nil {
|
||||
if err := revoker.MarkUserLogout(c.Context(), uint(req.User.ID), time.Now().UTC()); err != nil {
|
||||
h.log.WithError(err).Error("sso user logout revoke failed")
|
||||
return fiber.NewError(fiber.StatusInternalServerError, "failed to revoke user session")
|
||||
}
|
||||
} else {
|
||||
h.log.Warn("sso user logout received but revocation store not configured")
|
||||
}
|
||||
|
||||
h.log.WithFields(logrus.Fields{
|
||||
"action": req.Action,
|
||||
"public_id": req.PublicID,
|
||||
"alias": alias,
|
||||
"user_id": req.User.ID,
|
||||
}).Info("sso user logout enforced")
|
||||
|
||||
return c.Status(fiber.StatusOK).JSON(response.Common{
|
||||
Code: fiber.StatusOK,
|
||||
Status: "success",
|
||||
Message: "User sessions revoked successfully",
|
||||
})
|
||||
}
|
||||
|
||||
func (h *UserSyncController) removeUser(c *fiber.Ctx, alias string, req *userSyncRequest) error {
|
||||
if err := h.repo.SoftDeleteByIdUser(c.Context(), req.User.ID); err != nil {
|
||||
if err == gorm.ErrRecordNotFound {
|
||||
return fiber.NewError(fiber.StatusNotFound, "user not found")
|
||||
}
|
||||
h.log.Errorf("sso user delete failed: %v", err)
|
||||
return fiber.NewError(fiber.StatusInternalServerError, "failed to delete user")
|
||||
}
|
||||
|
||||
h.log.WithFields(logrus.Fields{
|
||||
"action": req.Action,
|
||||
"public_id": req.PublicID,
|
||||
"alias": alias,
|
||||
"user_id": req.User.ID,
|
||||
}).Info("sso user deleted")
|
||||
|
||||
return c.Status(fiber.StatusOK).JSON(response.Common{
|
||||
Code: fiber.StatusOK,
|
||||
Status: "success",
|
||||
Message: "User deleted successfully",
|
||||
})
|
||||
}
|
||||
|
||||
func (h *UserSyncController) registerNonce(ctx context.Context, alias, nonce string) error {
|
||||
ttl := h.nonceTTL
|
||||
if ttl <= 0 {
|
||||
ttl = defaultNonceTTL
|
||||
}
|
||||
|
||||
key := fmt.Sprintf("sso:sync:%s:%s", alias, nonce)
|
||||
if h.redis != nil {
|
||||
stored, err := h.redis.SetNX(ctx, key, "1", ttl).Result()
|
||||
if err == nil {
|
||||
if !stored {
|
||||
return fiber.NewError(fiber.StatusUnauthorized, "nonce already used")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
h.log.Errorf("store sync nonce failed: %v", err)
|
||||
}
|
||||
|
||||
now := time.Now().UTC()
|
||||
if expRaw, ok := h.localNonces.Load(key); ok {
|
||||
if expTime, ok := expRaw.(time.Time); ok && expTime.After(now) {
|
||||
return fiber.NewError(fiber.StatusUnauthorized, "nonce already used")
|
||||
}
|
||||
}
|
||||
h.localNonces.Store(key, now.Add(ttl))
|
||||
h.pruneLocalNonces(now)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h *UserSyncController) calculateSignature(secret, alias, timestamp, nonce string, body []byte) []byte {
|
||||
mac := hmac.New(sha256.New, []byte(secret))
|
||||
mac.Write([]byte(alias))
|
||||
mac.Write([]byte("\n"))
|
||||
mac.Write([]byte(timestamp))
|
||||
mac.Write([]byte("\n"))
|
||||
mac.Write([]byte(nonce))
|
||||
mac.Write([]byte("\n"))
|
||||
mac.Write(body)
|
||||
return mac.Sum(nil)
|
||||
}
|
||||
|
||||
func containsScope(scopes []string, target string) bool {
|
||||
target = strings.ToLower(strings.TrimSpace(target))
|
||||
if target == "" {
|
||||
return false
|
||||
}
|
||||
for _, scope := range scopes {
|
||||
if strings.ToLower(strings.TrimSpace(scope)) == target {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func decodeSignature(sig string) ([]byte, error) {
|
||||
sig = strings.TrimSpace(sig)
|
||||
if sig == "" {
|
||||
return nil, errors.New("empty signature")
|
||||
}
|
||||
if decoded, err := hex.DecodeString(sig); err == nil {
|
||||
return decoded, nil
|
||||
}
|
||||
if decoded, err := base64.StdEncoding.DecodeString(sig); err == nil {
|
||||
return decoded, nil
|
||||
}
|
||||
if decoded, err := base64.URLEncoding.DecodeString(sig); err == nil {
|
||||
return decoded, nil
|
||||
}
|
||||
return nil, errors.New("unrecognized signature encoding")
|
||||
}
|
||||
|
||||
func (h *UserSyncController) pruneLocalNonces(now time.Time) {
|
||||
h.localNonces.Range(func(key, value any) bool {
|
||||
exp, ok := value.(time.Time)
|
||||
if !ok || exp.Before(now) {
|
||||
h.localNonces.Delete(key)
|
||||
}
|
||||
return true
|
||||
})
|
||||
}
|
||||
@@ -0,0 +1,13 @@
|
||||
package sso
|
||||
|
||||
import (
|
||||
"github.com/go-playground/validator/v10"
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
type Module struct{}
|
||||
|
||||
func (Module) RegisterRoutes(router fiber.Router, db *gorm.DB, validate *validator.Validate) {
|
||||
Routes(router, db, validate)
|
||||
}
|
||||
@@ -0,0 +1,36 @@
|
||||
package sso
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/go-playground/validator/v10"
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"gorm.io/gorm"
|
||||
|
||||
"gitlab.com/mbugroup/lti-api.git/internal/cache"
|
||||
"gitlab.com/mbugroup/lti-api.git/internal/config"
|
||||
"gitlab.com/mbugroup/lti-api.git/internal/middleware"
|
||||
ssoController "gitlab.com/mbugroup/lti-api.git/internal/modules/sso/controllers"
|
||||
"gitlab.com/mbugroup/lti-api.git/internal/modules/sso/session"
|
||||
userRepository "gitlab.com/mbugroup/lti-api.git/internal/modules/users/repositories"
|
||||
)
|
||||
|
||||
func Routes(router fiber.Router, db *gorm.DB, validate *validator.Validate) {
|
||||
ttl := config.SSOPKCETTL
|
||||
if ttl <= 0 {
|
||||
ttl = 5 * time.Minute
|
||||
}
|
||||
|
||||
store := session.NewStore(cache.MustRedis(), ttl)
|
||||
ctrl := ssoController.NewController(&http.Client{Timeout: 10 * time.Second}, store, session.GetRevocationStore())
|
||||
userRepo := userRepository.NewUserRepository(db)
|
||||
syncCtrl := ssoController.NewUserSyncController(validate, userRepo, cache.Redis(), config.SSOClients)
|
||||
|
||||
group := router.Group("/sso")
|
||||
group.Get("/start", middleware.NewLimiter(30, time.Minute), ctrl.Start)
|
||||
group.Get("/callback", ctrl.Callback)
|
||||
group.Get("/userinfo", middleware.NewLimiter(60, time.Minute), ctrl.UserInfo)
|
||||
group.Post("/logout", middleware.NewLimiter(60, time.Minute), ctrl.Logout)
|
||||
group.Post("/users/sync", middleware.NewLimiter(30, time.Minute), syncCtrl.Sync)
|
||||
}
|
||||
@@ -0,0 +1,163 @@
|
||||
package session
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"errors"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/redis/go-redis/v9"
|
||||
)
|
||||
|
||||
// RevocationStore handles token blacklist / revocation entries in Redis.
|
||||
type RevocationStore struct {
|
||||
redis *redis.Client
|
||||
prefix string
|
||||
}
|
||||
|
||||
var (
|
||||
globalRevokerMu sync.RWMutex
|
||||
globalRevoker *RevocationStore
|
||||
)
|
||||
|
||||
// NewRevocationStore creates a revocation store with the given redis client and key prefix.
|
||||
func NewRevocationStore(client *redis.Client, prefix string) *RevocationStore {
|
||||
return &RevocationStore{
|
||||
redis: client,
|
||||
prefix: strings.TrimSpace(prefix),
|
||||
}
|
||||
}
|
||||
|
||||
// SetRevocationStore registers the provided revocation store for global access.
|
||||
func SetRevocationStore(store *RevocationStore) {
|
||||
globalRevokerMu.Lock()
|
||||
globalRevoker = store
|
||||
globalRevokerMu.Unlock()
|
||||
}
|
||||
|
||||
// GetRevocationStore returns the globally registered revocation store, or nil if unset.
|
||||
func GetRevocationStore() *RevocationStore {
|
||||
globalRevokerMu.RLock()
|
||||
defer globalRevokerMu.RUnlock()
|
||||
return globalRevoker
|
||||
}
|
||||
|
||||
// MustRevocationStore returns the registered revocation store or panics if none is configured.
|
||||
func MustRevocationStore() *RevocationStore {
|
||||
store := GetRevocationStore()
|
||||
if store == nil {
|
||||
panic("revocation store not initialised")
|
||||
}
|
||||
return store
|
||||
}
|
||||
|
||||
// Revoke stores the fingerprint with the provided TTL.
|
||||
func (s *RevocationStore) Revoke(ctx context.Context, fingerprint string, ttl time.Duration) error {
|
||||
if s == nil || s.redis == nil {
|
||||
return errors.New("revocation store redis client not initialised")
|
||||
}
|
||||
fingerprint = strings.TrimSpace(fingerprint)
|
||||
if fingerprint == "" {
|
||||
return nil
|
||||
}
|
||||
if ttl <= 0 {
|
||||
ttl = time.Minute
|
||||
}
|
||||
key := s.keyFor(fingerprint)
|
||||
return s.redis.Set(ctx, key, "1", ttl).Err()
|
||||
}
|
||||
|
||||
// IsRevoked returns true when the fingerprint appears in the blacklist.
|
||||
func (s *RevocationStore) IsRevoked(ctx context.Context, fingerprint string) (bool, error) {
|
||||
if s == nil || s.redis == nil {
|
||||
return false, errors.New("revocation store redis client not initialised")
|
||||
}
|
||||
fingerprint = strings.TrimSpace(fingerprint)
|
||||
if fingerprint == "" {
|
||||
return false, nil
|
||||
}
|
||||
key := s.keyFor(fingerprint)
|
||||
exists, err := s.redis.Exists(ctx, key).Result()
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
return exists > 0, nil
|
||||
}
|
||||
|
||||
// MarkUserLogout stores the timestamp of the last forced logout for the given user.
|
||||
func (s *RevocationStore) MarkUserLogout(ctx context.Context, userID uint, at time.Time) error {
|
||||
if s == nil || s.redis == nil {
|
||||
return errors.New("revocation store redis client not initialised")
|
||||
}
|
||||
if userID == 0 {
|
||||
return errors.New("invalid user id")
|
||||
}
|
||||
key := s.userLogoutKey(userID)
|
||||
return s.redis.Set(ctx, key, at.UTC().Format(time.RFC3339Nano), 0).Err()
|
||||
}
|
||||
|
||||
// ClearUserLogout removes any stored forced logout marker for the given user.
|
||||
func (s *RevocationStore) ClearUserLogout(ctx context.Context, userID uint) error {
|
||||
if s == nil || s.redis == nil {
|
||||
return errors.New("revocation store redis client not initialised")
|
||||
}
|
||||
if userID == 0 {
|
||||
return errors.New("invalid user id")
|
||||
}
|
||||
key := s.userLogoutKey(userID)
|
||||
return s.redis.Del(ctx, key).Err()
|
||||
}
|
||||
|
||||
// UserLogoutTime returns the timestamp of the last forced logout for the given user.
|
||||
func (s *RevocationStore) UserLogoutTime(ctx context.Context, userID uint) (time.Time, error) {
|
||||
var zero time.Time
|
||||
if s == nil || s.redis == nil {
|
||||
return zero, errors.New("revocation store redis client not initialised")
|
||||
}
|
||||
if userID == 0 {
|
||||
return zero, errors.New("invalid user id")
|
||||
}
|
||||
key := s.userLogoutKey(userID)
|
||||
value, err := s.redis.Get(ctx, key).Result()
|
||||
if err != nil {
|
||||
if errors.Is(err, redis.Nil) {
|
||||
return zero, nil
|
||||
}
|
||||
return zero, err
|
||||
}
|
||||
ts, err := time.Parse(time.RFC3339Nano, value)
|
||||
if err != nil {
|
||||
return zero, err
|
||||
}
|
||||
return ts, nil
|
||||
}
|
||||
|
||||
func (s *RevocationStore) keyFor(fingerprint string) string {
|
||||
prefix := s.prefix
|
||||
if prefix == "" {
|
||||
prefix = "sso:blacklist"
|
||||
}
|
||||
return prefix + ":" + fingerprint
|
||||
}
|
||||
|
||||
func (s *RevocationStore) userLogoutKey(userID uint) string {
|
||||
prefix := s.prefix
|
||||
if prefix == "" {
|
||||
prefix = "sso:blacklist"
|
||||
}
|
||||
return prefix + ":user-logout:" + strconv.FormatUint(uint64(userID), 10)
|
||||
}
|
||||
|
||||
// TokenFingerprint hashes token material before persisting it to the blacklist.
|
||||
func TokenFingerprint(token string) string {
|
||||
token = strings.TrimSpace(token)
|
||||
if token == "" {
|
||||
return ""
|
||||
}
|
||||
sum := sha256.Sum256([]byte(token))
|
||||
return hex.EncodeToString(sum[:])
|
||||
}
|
||||
@@ -0,0 +1,70 @@
|
||||
package session
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/redis/go-redis/v9"
|
||||
)
|
||||
|
||||
const keyTemplate = "sso:pkce:%s"
|
||||
|
||||
// PKCESession holds data required to complete the OAuth2 PKCE exchange.
|
||||
type PKCESession struct {
|
||||
CodeVerifier string `json:"code_verifier"`
|
||||
Nonce string `json:"nonce"`
|
||||
ClientAlias string `json:"client_alias"`
|
||||
ClientID string `json:"client_id"`
|
||||
RedirectURI string `json:"redirect_uri"`
|
||||
Scope string `json:"scope"`
|
||||
ReturnTo string `json:"return_to,omitempty"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
}
|
||||
|
||||
// Store persists pkce sessions inside Redis using a configurable TTL.
|
||||
type Store struct {
|
||||
redis *redis.Client
|
||||
ttl time.Duration
|
||||
}
|
||||
|
||||
func NewStore(client *redis.Client, ttl time.Duration) *Store {
|
||||
return &Store{redis: client, ttl: ttl}
|
||||
}
|
||||
|
||||
func (s *Store) Save(ctx context.Context, state string, payload *PKCESession) error {
|
||||
if s.redis == nil {
|
||||
return fmt.Errorf("redis client is not initialised")
|
||||
}
|
||||
bytes, err := json.Marshal(payload)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return s.redis.Set(ctx, fmt.Sprintf(keyTemplate, state), bytes, s.ttl).Err()
|
||||
}
|
||||
|
||||
func (s *Store) Get(ctx context.Context, state string) (*PKCESession, error) {
|
||||
if s.redis == nil {
|
||||
return nil, fmt.Errorf("redis client is not initialised")
|
||||
}
|
||||
raw, err := s.redis.Get(ctx, fmt.Sprintf(keyTemplate, state)).Result()
|
||||
if err != nil {
|
||||
if err == redis.Nil {
|
||||
return nil, nil
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
var payload PKCESession
|
||||
if err := json.Unmarshal([]byte(raw), &payload); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &payload, nil
|
||||
}
|
||||
|
||||
func (s *Store) Delete(ctx context.Context, state string) error {
|
||||
if s.redis == nil {
|
||||
return fmt.Errorf("redis client is not initialised")
|
||||
}
|
||||
return s.redis.Del(ctx, fmt.Sprintf(keyTemplate, state)).Err()
|
||||
}
|
||||
@@ -71,70 +71,70 @@ func (u *UserController) GetOne(c *fiber.Ctx) error {
|
||||
})
|
||||
}
|
||||
|
||||
func (u *UserController) CreateOne(c *fiber.Ctx) error {
|
||||
req := new(validation.Create)
|
||||
// func (u *UserController) CreateOne(c *fiber.Ctx) error {
|
||||
// req := new(validation.Create)
|
||||
|
||||
if err := c.BodyParser(req); err != nil {
|
||||
return fiber.NewError(fiber.StatusBadRequest, "Invalid request body")
|
||||
}
|
||||
// if err := c.BodyParser(req); err != nil {
|
||||
// return fiber.NewError(fiber.StatusBadRequest, "Invalid request body")
|
||||
// }
|
||||
|
||||
result, err := u.UserService.CreateOne(c, req)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
// result, err := u.UserService.CreateOne(c, req)
|
||||
// if err != nil {
|
||||
// return err
|
||||
// }
|
||||
|
||||
return c.Status(fiber.StatusCreated).
|
||||
JSON(response.Success{
|
||||
Code: fiber.StatusCreated,
|
||||
Status: "success",
|
||||
Message: "Create user successfully",
|
||||
Data: dto.ToUserListDTO(*result),
|
||||
})
|
||||
}
|
||||
// return c.Status(fiber.StatusCreated).
|
||||
// JSON(response.Success{
|
||||
// Code: fiber.StatusCreated,
|
||||
// Status: "success",
|
||||
// Message: "Create user successfully",
|
||||
// Data: dto.ToUserListDTO(*result),
|
||||
// })
|
||||
// }
|
||||
|
||||
func (u *UserController) UpdateOne(c *fiber.Ctx) error {
|
||||
req := new(validation.Update)
|
||||
param := c.Params("id")
|
||||
// func (u *UserController) UpdateOne(c *fiber.Ctx) error {
|
||||
// req := new(validation.Update)
|
||||
// param := c.Params("id")
|
||||
|
||||
id, err := strconv.Atoi(param)
|
||||
if err != nil {
|
||||
return fiber.NewError(fiber.StatusBadRequest, "Invalid Id")
|
||||
}
|
||||
// id, err := strconv.Atoi(param)
|
||||
// if err != nil {
|
||||
// return fiber.NewError(fiber.StatusBadRequest, "Invalid Id")
|
||||
// }
|
||||
|
||||
if err := c.BodyParser(req); err != nil {
|
||||
return fiber.NewError(fiber.StatusBadRequest, "Invalid request body")
|
||||
}
|
||||
// if err := c.BodyParser(req); err != nil {
|
||||
// return fiber.NewError(fiber.StatusBadRequest, "Invalid request body")
|
||||
// }
|
||||
|
||||
result, err := u.UserService.UpdateOne(c, req, uint(id))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
// result, err := u.UserService.UpdateOne(c, req, uint(id))
|
||||
// if err != nil {
|
||||
// return err
|
||||
// }
|
||||
|
||||
return c.Status(fiber.StatusOK).
|
||||
JSON(response.Success{
|
||||
Code: fiber.StatusOK,
|
||||
Status: "success",
|
||||
Message: "Update user successfully",
|
||||
Data: dto.ToUserListDTO(*result),
|
||||
})
|
||||
}
|
||||
// return c.Status(fiber.StatusOK).
|
||||
// JSON(response.Success{
|
||||
// Code: fiber.StatusOK,
|
||||
// Status: "success",
|
||||
// Message: "Update user successfully",
|
||||
// Data: dto.ToUserListDTO(*result),
|
||||
// })
|
||||
// }
|
||||
|
||||
func (u *UserController) DeleteOne(c *fiber.Ctx) error {
|
||||
param := c.Params("id")
|
||||
// func (u *UserController) DeleteOne(c *fiber.Ctx) error {
|
||||
// param := c.Params("id")
|
||||
|
||||
id, err := strconv.Atoi(param)
|
||||
if err != nil {
|
||||
return fiber.NewError(fiber.StatusBadRequest, "Invalid Id")
|
||||
}
|
||||
// id, err := strconv.Atoi(param)
|
||||
// if err != nil {
|
||||
// return fiber.NewError(fiber.StatusBadRequest, "Invalid Id")
|
||||
// }
|
||||
|
||||
if err := u.UserService.DeleteOne(c, uint(id)); err != nil {
|
||||
return err
|
||||
}
|
||||
// if err := u.UserService.DeleteOne(c, uint(id)); err != nil {
|
||||
// return err
|
||||
// }
|
||||
|
||||
return c.Status(fiber.StatusOK).
|
||||
JSON(response.Common{
|
||||
Code: fiber.StatusOK,
|
||||
Status: "success",
|
||||
Message: "Delete user successfully",
|
||||
})
|
||||
}
|
||||
// return c.Status(fiber.StatusOK).
|
||||
// JSON(response.Common{
|
||||
// Code: fiber.StatusOK,
|
||||
// Status: "success",
|
||||
// Message: "Delete user successfully",
|
||||
// })
|
||||
// }
|
||||
|
||||
@@ -1,21 +1,112 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"gitlab.com/mbugroup/lti-api.git/internal/common/repository"
|
||||
"context"
|
||||
"errors"
|
||||
"time"
|
||||
|
||||
"github.com/jackc/pgconn"
|
||||
commonrepo "gitlab.com/mbugroup/lti-api.git/internal/common/repository"
|
||||
entity "gitlab.com/mbugroup/lti-api.git/internal/entities"
|
||||
"gorm.io/gorm"
|
||||
"gorm.io/gorm/clause"
|
||||
)
|
||||
|
||||
type UserRepository interface {
|
||||
repository.BaseRepository[entity.User]
|
||||
commonrepo.BaseRepository[entity.User]
|
||||
GetByIdUser(ctx context.Context, idUser int64, modifier func(*gorm.DB) *gorm.DB) (*entity.User, error)
|
||||
UpsertByIdUser(ctx context.Context, user *entity.User) error
|
||||
SoftDeleteByIdUser(ctx context.Context, idUser int64) error
|
||||
}
|
||||
|
||||
type UserRepositoryImpl struct {
|
||||
*repository.BaseRepositoryImpl[entity.User]
|
||||
*commonrepo.BaseRepositoryImpl[entity.User]
|
||||
}
|
||||
|
||||
func NewUserRepository(db *gorm.DB) UserRepository {
|
||||
return &UserRepositoryImpl{
|
||||
BaseRepositoryImpl: repository.NewBaseRepository[entity.User](db),
|
||||
BaseRepositoryImpl: commonrepo.NewBaseRepository[entity.User](db),
|
||||
}
|
||||
}
|
||||
|
||||
func (r *UserRepositoryImpl) GetByIdUser(
|
||||
ctx context.Context,
|
||||
idUser int64,
|
||||
modifier func(*gorm.DB) *gorm.DB,
|
||||
) (*entity.User, error) {
|
||||
return r.BaseRepositoryImpl.First(ctx, func(db *gorm.DB) *gorm.DB {
|
||||
return db.Where("id_user = ?", idUser)
|
||||
})
|
||||
}
|
||||
|
||||
func (r *UserRepositoryImpl) UpsertByIdUser(ctx context.Context, user *entity.User) error {
|
||||
if user == nil {
|
||||
return gorm.ErrInvalidData
|
||||
}
|
||||
|
||||
return r.DB().WithContext(ctx).Transaction(func(tx *gorm.DB) error {
|
||||
now := time.Now()
|
||||
user.DeletedAt = gorm.DeletedAt{}
|
||||
user.UpdatedAt = now
|
||||
|
||||
err := tx.Clauses(clause.OnConflict{
|
||||
Columns: []clause.Column{{Name: "id_user"}},
|
||||
UpdateAll: true,
|
||||
}).Omit("id", "created_at").Create(user).Error
|
||||
if err == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
if !isUniqueViolation(err, "users_email_unique") {
|
||||
return err
|
||||
}
|
||||
|
||||
var existing entity.User
|
||||
lockQuery := tx.Clauses(clause.Locking{Strength: "UPDATE"}).Where("email = ?", user.Email)
|
||||
if err := lockQuery.First(&existing).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
user.Id = existing.Id
|
||||
|
||||
updates := map[string]any{
|
||||
"id_user": user.IdUser,
|
||||
"email": user.Email,
|
||||
"name": user.Name,
|
||||
"updated_at": now,
|
||||
"deleted_at": gorm.DeletedAt{},
|
||||
}
|
||||
|
||||
if err := tx.Model(&entity.User{}).Where("id = ?", existing.Id).Updates(updates).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
func (r *UserRepositoryImpl) SoftDeleteByIdUser(ctx context.Context, idUser int64) error {
|
||||
query := r.DB().WithContext(ctx).Where("id_user = ?", idUser)
|
||||
result := query.Delete(&entity.User{})
|
||||
if result.Error != nil {
|
||||
return result.Error
|
||||
}
|
||||
if result.RowsAffected == 0 {
|
||||
return gorm.ErrRecordNotFound
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func isUniqueViolation(err error, constraint string) bool {
|
||||
var pgErr *pgconn.PgError
|
||||
if !errors.As(err, &pgErr) {
|
||||
return false
|
||||
}
|
||||
if pgErr.Code != "23505" {
|
||||
return false
|
||||
}
|
||||
if constraint == "" {
|
||||
return true
|
||||
}
|
||||
return pgErr.ConstraintName == constraint
|
||||
}
|
||||
|
||||
@@ -1,20 +1,22 @@
|
||||
package users
|
||||
|
||||
import (
|
||||
"github.com/gofiber/fiber/v2"
|
||||
|
||||
"gitlab.com/mbugroup/lti-api.git/internal/middleware"
|
||||
controller "gitlab.com/mbugroup/lti-api.git/internal/modules/users/controllers"
|
||||
user "gitlab.com/mbugroup/lti-api.git/internal/modules/users/services"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
)
|
||||
|
||||
func UserRoutes(v1 fiber.Router, s user.UserService) {
|
||||
ctrl := controller.NewUserController(s)
|
||||
|
||||
route := v1.Group("/users")
|
||||
route.Use(middleware.Auth(s))
|
||||
|
||||
route.Get("/", ctrl.GetAll)
|
||||
route.Post("/", ctrl.CreateOne)
|
||||
// route.Post("/", ctrl.CreateOne)
|
||||
route.Get("/:id", ctrl.GetOne)
|
||||
route.Patch("/:id", ctrl.UpdateOne)
|
||||
route.Delete("/:id", ctrl.DeleteOne)
|
||||
// route.Patch("/:id", ctrl.UpdateOne)
|
||||
// route.Delete("/:id", ctrl.DeleteOne)
|
||||
}
|
||||
|
||||
@@ -20,6 +20,7 @@ type UserService interface {
|
||||
CreateOne(ctx *fiber.Ctx, req *validation.Create) (*entity.User, error)
|
||||
UpdateOne(ctx *fiber.Ctx, req *validation.Update, id uint) (*entity.User, error)
|
||||
DeleteOne(ctx *fiber.Ctx, id uint) error
|
||||
GetBySSOUserID(ctx *fiber.Ctx, ssoUserID uint) (*entity.User, error)
|
||||
}
|
||||
|
||||
type userService struct {
|
||||
@@ -68,6 +69,18 @@ func (s userService) GetOne(c *fiber.Ctx, id uint) (*entity.User, error) {
|
||||
return user, nil
|
||||
}
|
||||
|
||||
func (s userService) GetBySSOUserID(c *fiber.Ctx, ssoUserID uint) (*entity.User, error) {
|
||||
user, err := s.Repository.GetByIdUser(c.Context(), int64(ssoUserID), nil)
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, fiber.NewError(fiber.StatusNotFound, "User not found")
|
||||
}
|
||||
if err != nil {
|
||||
s.Log.Errorf("Failed get user by SSO id: %+v", err)
|
||||
return nil, err
|
||||
}
|
||||
return user, nil
|
||||
}
|
||||
|
||||
func (s *userService) CreateOne(c *fiber.Ctx, req *validation.Create) (*entity.User, error) {
|
||||
if err := s.Validate.Struct(req); err != nil {
|
||||
return nil, err
|
||||
|
||||
Reference in New Issue
Block a user