Feat(BE-69,70,71,72,73): crud and integration sso with lti, revoke_token

This commit is contained in:
ragilap
2025-10-21 20:31:10 +07:00
committed by Adnan Zahir
parent 94f4929749
commit 40665b0d8f
5 changed files with 108 additions and 16 deletions
+1 -1
View File
@@ -15,7 +15,7 @@ type SSOClientConfig struct {
PublicID string `json:"public_id"` PublicID string `json:"public_id"`
RedirectURI string `json:"redirect_uri"` RedirectURI string `json:"redirect_uri"`
Scope string `json:"scope"` Scope string `json:"scope"`
Prompt string `json:"prompt"` // Prompt string `json:"prompt"`
DefaultReturnURI string `json:"default_return_uri"` DefaultReturnURI string `json:"default_return_uri"`
AllowedReturnOrigins []string `json:"allowed_return_origins"` AllowedReturnOrigins []string `json:"allowed_return_origins"`
SyncSecret string `json:"sync_secret"` SyncSecret string `json:"sync_secret"`
@@ -105,9 +105,9 @@ func (h *Controller) Start(c *fiber.Ctx) error {
query.Set("code_challenge", challenge) query.Set("code_challenge", challenge)
query.Set("code_challenge_method", "S256") query.Set("code_challenge_method", "S256")
query.Set("nonce", nonce) query.Set("nonce", nonce)
if prompt := strings.TrimSpace(cfg.Prompt); prompt != "" { // if prompt := strings.TrimSpace(cfg.Prompt); prompt != "" {
query.Set("prompt", prompt) // query.Set("prompt", prompt)
} // }
if extraPrompt := strings.TrimSpace(c.Query("prompt")); extraPrompt != "" { if extraPrompt := strings.TrimSpace(c.Query("prompt")); extraPrompt != "" {
query.Set("prompt", extraPrompt) query.Set("prompt", extraPrompt)
} }
@@ -323,7 +323,6 @@ func (h *Controller) Logout(c *fiber.Ctx) error {
if requestedAlias == "" { if requestedAlias == "" {
requestedAlias = normalizeClientParam(c.Query("client_id")) requestedAlias = normalizeClientParam(c.Query("client_id"))
} }
var ( var (
alias string alias string
cfg config.SSOClientConfig cfg config.SSOClientConfig
@@ -343,7 +342,6 @@ func (h *Controller) Logout(c *fiber.Ctx) error {
if refreshName != "" { if refreshName != "" {
refreshToken = strings.TrimSpace(c.Cookies(refreshName)) refreshToken = strings.TrimSpace(c.Cookies(refreshName))
} }
hadAccessCookie := accessToken != "" hadAccessCookie := accessToken != ""
hadRefreshCookie := refreshToken != "" hadRefreshCookie := refreshToken != ""
@@ -362,6 +360,11 @@ func (h *Controller) Logout(c *fiber.Ctx) error {
if verification, err := sso.VerifyAccessToken(accessToken); err != nil { if verification, err := sso.VerifyAccessToken(accessToken); err != nil {
utils.Log.WithError(err).Warn("failed to verify access token during logout") utils.Log.WithError(err).Warn("failed to verify access token during logout")
} else { } else {
if revoker := session.GetRevocationStore(); revoker != nil {
if err := revoker.MarkUserLogout(c.Context(), verification.UserID, time.Now().UTC()); err != nil {
utils.Log.WithError(err).Warn("failed to mark user logout")
}
}
h.revokeToken(c.Context(), accessToken, verification) h.revokeToken(c.Context(), accessToken, verification)
} }
} }
@@ -450,6 +453,12 @@ func issueCookies(c *fiber.Ctx, tokenResp struct {
Error string `json:"error"` Error string `json:"error"`
Description string `json:"error_description"` Description string `json:"error_description"`
}, verification *sso.VerificationResult) { }, verification *sso.VerificationResult) {
if revoker := session.GetRevocationStore(); revoker != nil && verification != nil {
if err := revoker.ClearUserLogout(c.Context(), verification.UserID); err != nil {
utils.Log.WithError(err).Warn("failed to clear logout marker")
}
}
accessName := resolveSSOCookieName(config.SSOAccessCookieName, "access") accessName := resolveSSOCookieName(config.SSOAccessCookieName, "access")
refreshName := resolveSSOCookieName(config.SSORefreshCookieName, "refresh") refreshName := resolveSSOCookieName(config.SSORefreshCookieName, "refresh")
maxAge := tokenResp.ExpiresIn maxAge := tokenResp.ExpiresIn
@@ -9,18 +9,19 @@ import (
"encoding/json" "encoding/json"
"errors" "errors"
"fmt" "fmt"
"strconv"
"strings"
"sync"
"time"
"github.com/go-playground/validator/v10" "github.com/go-playground/validator/v10"
"github.com/gofiber/fiber/v2" "github.com/gofiber/fiber/v2"
"github.com/redis/go-redis/v9" "github.com/redis/go-redis/v9"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
"gorm.io/gorm" "gorm.io/gorm"
"strconv"
"strings"
"sync"
"time"
"gitlab.com/mbugroup/lti-api.git/internal/config" "gitlab.com/mbugroup/lti-api.git/internal/config"
entity "gitlab.com/mbugroup/lti-api.git/internal/entities" entity "gitlab.com/mbugroup/lti-api.git/internal/entities"
"gitlab.com/mbugroup/lti-api.git/internal/modules/sso/session"
"gitlab.com/mbugroup/lti-api.git/internal/modules/users/dto" "gitlab.com/mbugroup/lti-api.git/internal/modules/users/dto"
userRepository "gitlab.com/mbugroup/lti-api.git/internal/modules/users/repositories" userRepository "gitlab.com/mbugroup/lti-api.git/internal/modules/users/repositories"
"gitlab.com/mbugroup/lti-api.git/internal/response" "gitlab.com/mbugroup/lti-api.git/internal/response"
@@ -51,7 +52,7 @@ type UserSyncController struct {
} }
type userSyncRequest struct { type userSyncRequest struct {
Action string `json:"action" validate:"required,oneof=create update delete"` Action string `json:"action" validate:"required,oneof=create update delete logout"`
PublicID string `json:"public_id" validate:"required"` PublicID string `json:"public_id" validate:"required"`
User userSyncUser `json:"user" validate:"required"` User userSyncUser `json:"user" validate:"required"`
} }
@@ -134,7 +135,7 @@ func (h *UserSyncController) Sync(c *fiber.Ctx) error {
return fiber.NewError(fiber.StatusBadRequest, "public_id mismatch with configured client") return fiber.NewError(fiber.StatusBadRequest, "public_id mismatch with configured client")
} }
if req.Action != "delete" { if req.Action == "create" || req.Action == "update" {
if req.User.Email == "" || req.User.Name == "" { if req.User.Email == "" || req.User.Name == "" {
return fiber.NewError(fiber.StatusBadRequest, "email and name are required for create/update actions") return fiber.NewError(fiber.StatusBadRequest, "email and name are required for create/update actions")
} }
@@ -152,6 +153,8 @@ func (h *UserSyncController) Sync(c *fiber.Ctx) error {
return h.upsertUser(c, alias, req) return h.upsertUser(c, alias, req)
case "delete": case "delete":
return h.removeUser(c, alias, req) return h.removeUser(c, alias, req)
case "logout":
return h.logoutUser(c, alias, req)
default: default:
return fiber.NewError(fiber.StatusBadRequest, "unsupported action") return fiber.NewError(fiber.StatusBadRequest, "unsupported action")
} }
@@ -231,11 +234,11 @@ func (h *UserSyncController) authenticate(c *fiber.Ctx, body []byte) (string, co
} }
func (h *UserSyncController) verifyAuthorization(c *fiber.Ctx, alias string) error { func (h *UserSyncController) verifyAuthorization(c *fiber.Ctx, alias string) error {
authHeader := strings.TrimSpace(c.Get(fiber.HeaderAuthorization)) authHeader := strings.TrimSpace(c.Get(fiber.HeaderAuthorization))
if authHeader == "" { if authHeader == "" {
return fiber.NewError(fiber.StatusUnauthorized, "missing authorization header") return fiber.NewError(fiber.StatusUnauthorized, "missing authorization header")
} }
parts := strings.SplitN(authHeader, " ", 2) parts := strings.SplitN(authHeader, " ", 2)
if len(parts) != 2 || !strings.EqualFold(parts[0], "Bearer") { if len(parts) != 2 || !strings.EqualFold(parts[0], "Bearer") {
return fiber.NewError(fiber.StatusUnauthorized, "invalid authorization header") return fiber.NewError(fiber.StatusUnauthorized, "invalid authorization header")
@@ -254,7 +257,6 @@ func (h *UserSyncController) verifyAuthorization(c *fiber.Ctx, alias string) err
if verification.ServiceAlias == "" || verification.ServiceAlias != alias { if verification.ServiceAlias == "" || verification.ServiceAlias != alias {
return fiber.NewError(fiber.StatusUnauthorized, "service subject mismatch") return fiber.NewError(fiber.StatusUnauthorized, "service subject mismatch")
} }
if !containsScope(verification.Claims.Scopes(), "sync.users") { if !containsScope(verification.Claims.Scopes(), "sync.users") {
return fiber.NewError(fiber.StatusForbidden, "missing sync scope") return fiber.NewError(fiber.StatusForbidden, "missing sync scope")
} }
@@ -297,6 +299,31 @@ func (h *UserSyncController) upsertUser(c *fiber.Ctx, alias string, req *userSyn
}) })
} }
func (h *UserSyncController) logoutUser(c *fiber.Ctx, alias string, req *userSyncRequest) error {
revoker := session.GetRevocationStore()
if revoker != nil {
if err := revoker.MarkUserLogout(c.Context(), uint(req.User.ID), time.Now().UTC()); err != nil {
h.log.WithError(err).Error("sso user logout revoke failed")
return fiber.NewError(fiber.StatusInternalServerError, "failed to revoke user session")
}
} else {
h.log.Warn("sso user logout received but revocation store not configured")
}
h.log.WithFields(logrus.Fields{
"action": req.Action,
"public_id": req.PublicID,
"alias": alias,
"user_id": req.User.ID,
}).Info("sso user logout enforced")
return c.Status(fiber.StatusOK).JSON(response.Common{
Code: fiber.StatusOK,
Status: "success",
Message: "User sessions revoked successfully",
})
}
func (h *UserSyncController) removeUser(c *fiber.Ctx, alias string, req *userSyncRequest) error { func (h *UserSyncController) removeUser(c *fiber.Ctx, alias string, req *userSyncRequest) error {
if err := h.repo.SoftDeleteByIdUser(c.Context(), req.User.ID); err != nil { if err := h.repo.SoftDeleteByIdUser(c.Context(), req.User.ID); err != nil {
if err == gorm.ErrRecordNotFound { if err == gorm.ErrRecordNotFound {
@@ -5,6 +5,7 @@ import (
"crypto/sha256" "crypto/sha256"
"encoding/hex" "encoding/hex"
"errors" "errors"
"strconv"
"strings" "strings"
"sync" "sync"
"time" "time"
@@ -87,6 +88,54 @@ func (s *RevocationStore) IsRevoked(ctx context.Context, fingerprint string) (bo
return exists > 0, nil return exists > 0, nil
} }
// MarkUserLogout stores the timestamp of the last forced logout for the given user.
func (s *RevocationStore) MarkUserLogout(ctx context.Context, userID uint, at time.Time) error {
if s == nil || s.redis == nil {
return errors.New("revocation store redis client not initialised")
}
if userID == 0 {
return errors.New("invalid user id")
}
key := s.userLogoutKey(userID)
return s.redis.Set(ctx, key, at.UTC().Format(time.RFC3339Nano), 0).Err()
}
// ClearUserLogout removes any stored forced logout marker for the given user.
func (s *RevocationStore) ClearUserLogout(ctx context.Context, userID uint) error {
if s == nil || s.redis == nil {
return errors.New("revocation store redis client not initialised")
}
if userID == 0 {
return errors.New("invalid user id")
}
key := s.userLogoutKey(userID)
return s.redis.Del(ctx, key).Err()
}
// UserLogoutTime returns the timestamp of the last forced logout for the given user.
func (s *RevocationStore) UserLogoutTime(ctx context.Context, userID uint) (time.Time, error) {
var zero time.Time
if s == nil || s.redis == nil {
return zero, errors.New("revocation store redis client not initialised")
}
if userID == 0 {
return zero, errors.New("invalid user id")
}
key := s.userLogoutKey(userID)
value, err := s.redis.Get(ctx, key).Result()
if err != nil {
if errors.Is(err, redis.Nil) {
return zero, nil
}
return zero, err
}
ts, err := time.Parse(time.RFC3339Nano, value)
if err != nil {
return zero, err
}
return ts, nil
}
func (s *RevocationStore) keyFor(fingerprint string) string { func (s *RevocationStore) keyFor(fingerprint string) string {
prefix := s.prefix prefix := s.prefix
if prefix == "" { if prefix == "" {
@@ -95,6 +144,14 @@ func (s *RevocationStore) keyFor(fingerprint string) string {
return prefix + ":" + fingerprint return prefix + ":" + fingerprint
} }
func (s *RevocationStore) userLogoutKey(userID uint) string {
prefix := s.prefix
if prefix == "" {
prefix = "sso:blacklist"
}
return prefix + ":user-logout:" + strconv.FormatUint(uint64(userID), 10)
}
// TokenFingerprint hashes token material before persisting it to the blacklist. // TokenFingerprint hashes token material before persisting it to the blacklist.
func TokenFingerprint(token string) string { func TokenFingerprint(token string) string {
token = strings.TrimSpace(token) token = strings.TrimSpace(token)
+1 -2
View File
@@ -106,7 +106,7 @@ func VerifyAccessToken(token string) (*VerificationResult, error) {
jwt.WithIssuedAt(), jwt.WithIssuedAt(),
jwt.WithExpirationRequired(), jwt.WithExpirationRequired(),
) )
tok, err := parser.ParseWithClaims(token, claims, v.jwks.Keyfunc) tok, err := parser.ParseWithClaims(token, claims, v.jwks.Keyfunc)
if err != nil { if err != nil {
return nil, fmt.Errorf("parse token: %w", err) return nil, fmt.Errorf("parse token: %w", err)
@@ -138,7 +138,6 @@ func VerifyAccessToken(token string) (*VerificationResult, error) {
} }
result := &VerificationResult{Claims: claims, Subject: sub} result := &VerificationResult{Claims: claims, Subject: sub}
switch { switch {
case strings.HasPrefix(sub, "user:"): case strings.HasPrefix(sub, "user:"):
idStr := strings.TrimPrefix(sub, "user:") idStr := strings.TrimPrefix(sub, "user:")