mirror of
https://gitlab.com/mbugroup/lti-api.git
synced 2026-05-20 13:31:56 +00:00
Feat(BE-69,70,71,72,73): crud and integration sso with lti, revoke_token
This commit is contained in:
@@ -25,22 +25,24 @@ import (
|
||||
type Controller struct {
|
||||
httpClient *http.Client
|
||||
store *session.Store
|
||||
revoker *session.RevocationStore
|
||||
}
|
||||
|
||||
func NewController(client *http.Client, store *session.Store) *Controller {
|
||||
return &Controller{httpClient: client, store: store}
|
||||
func NewController(client *http.Client, store *session.Store, revoker *session.RevocationStore) *Controller {
|
||||
return &Controller{httpClient: client, store: store, revoker: revoker}
|
||||
}
|
||||
|
||||
// Start handles GET /sso/start requests and redirects users to the central SSO authorize endpoint.
|
||||
func (h *Controller) Start(c *fiber.Ctx) error {
|
||||
alias := strings.ToLower(strings.TrimSpace(c.Query("client")))
|
||||
if alias == "" {
|
||||
alias = strings.ToLower(strings.TrimSpace(c.Query("client_id")))
|
||||
requestedAlias := normalizeClientParam(c.Query("client"))
|
||||
if requestedAlias == "" {
|
||||
requestedAlias = normalizeClientParam(c.Query("client_id"))
|
||||
}
|
||||
if alias == "" {
|
||||
if requestedAlias == "" {
|
||||
return fiber.NewError(fiber.StatusBadRequest, "missing client")
|
||||
}
|
||||
cfg, ok := config.SSOClients[alias]
|
||||
|
||||
alias, cfg, ok := findSSOClientConfig(requestedAlias)
|
||||
if !ok || cfg.PublicID == "" {
|
||||
return fiber.NewError(fiber.StatusBadRequest, "unknown client")
|
||||
}
|
||||
@@ -209,6 +211,7 @@ func (h *Controller) Callback(c *fiber.Ctx) error {
|
||||
return fiber.NewError(fiber.StatusBadGateway, "missing access token")
|
||||
}
|
||||
|
||||
fmt.Println(tokenResp.AccessToken)
|
||||
verification, err := sso.VerifyAccessToken(tokenResp.AccessToken)
|
||||
if err != nil {
|
||||
utils.Log.Errorf("access token verification failed: %v", err)
|
||||
@@ -223,7 +226,6 @@ func (h *Controller) Callback(c *fiber.Ctx) error {
|
||||
redirectTarget = "/"
|
||||
}
|
||||
|
||||
fmt.Println(sessionData.ClientAlias,"test")
|
||||
utils.Log.WithFields(logrus.Fields{
|
||||
"client": sessionData.ClientAlias,
|
||||
"user_id": verification.UserID,
|
||||
@@ -255,6 +257,24 @@ func (h *Controller) UserInfo(c *fiber.Ctx) error {
|
||||
return fiber.NewError(fiber.StatusUnauthorized, "unauthenticated")
|
||||
}
|
||||
|
||||
if revoker := session.GetRevocationStore(); revoker != nil {
|
||||
if fingerprint := session.TokenFingerprint(token); fingerprint != "" {
|
||||
revoked, err := revoker.IsRevoked(c.Context(), fingerprint)
|
||||
if err != nil {
|
||||
utils.Log.WithError(err).Warn("failed to check token revocation for userinfo")
|
||||
return fiber.NewError(fiber.StatusUnauthorized, "unauthenticated")
|
||||
}
|
||||
if revoked {
|
||||
return fiber.NewError(fiber.StatusUnauthorized, "unauthenticated")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if _, err := sso.VerifyAccessToken(token); err != nil {
|
||||
utils.Log.WithError(err).Warn("access token verification failed for userinfo")
|
||||
return fiber.NewError(fiber.StatusUnauthorized, "unauthenticated")
|
||||
}
|
||||
|
||||
endpoint := strings.TrimSpace(config.SSOGetMeURL)
|
||||
if endpoint == "" {
|
||||
return fiber.NewError(fiber.StatusInternalServerError, "userinfo endpoint not configured")
|
||||
@@ -297,6 +317,129 @@ func (h *Controller) UserInfo(c *fiber.Ctx) error {
|
||||
return c.Status(resp.StatusCode).Send(body)
|
||||
}
|
||||
|
||||
// Logout clears SSO cookies and removes any leftover PKCE session state.
|
||||
func (h *Controller) Logout(c *fiber.Ctx) error {
|
||||
requestedAlias := normalizeClientParam(c.Query("client"))
|
||||
if requestedAlias == "" {
|
||||
requestedAlias = normalizeClientParam(c.Query("client_id"))
|
||||
}
|
||||
|
||||
var (
|
||||
alias string
|
||||
cfg config.SSOClientConfig
|
||||
hasClientInfo bool
|
||||
)
|
||||
if requestedAlias != "" {
|
||||
alias, cfg, hasClientInfo = findSSOClientConfig(requestedAlias)
|
||||
}
|
||||
|
||||
accessName := resolveSSOCookieName(config.SSOAccessCookieName, "access")
|
||||
refreshName := resolveSSOCookieName(config.SSORefreshCookieName, "refresh")
|
||||
|
||||
var accessToken, refreshToken string
|
||||
if accessName != "" {
|
||||
accessToken = strings.TrimSpace(c.Cookies(accessName))
|
||||
}
|
||||
if refreshName != "" {
|
||||
refreshToken = strings.TrimSpace(c.Cookies(refreshName))
|
||||
}
|
||||
|
||||
hadAccessCookie := accessToken != ""
|
||||
hadRefreshCookie := refreshToken != ""
|
||||
|
||||
state := strings.TrimSpace(c.Query("state"))
|
||||
if state != "" {
|
||||
if err := h.store.Delete(c.Context(), state); err != nil {
|
||||
utils.Log.Warnf("failed to delete pkce session during logout: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
if !hadAccessCookie && !hadRefreshCookie && state == "" {
|
||||
return fiber.NewError(fiber.StatusUnauthorized, "not authenticated")
|
||||
}
|
||||
|
||||
if hadAccessCookie {
|
||||
if verification, err := sso.VerifyAccessToken(accessToken); err != nil {
|
||||
utils.Log.WithError(err).Warn("failed to verify access token during logout")
|
||||
} else {
|
||||
h.revokeToken(c.Context(), accessToken, verification)
|
||||
}
|
||||
}
|
||||
if refreshToken != "" {
|
||||
h.revokeRefreshToken(c.Context(), refreshToken)
|
||||
}
|
||||
|
||||
clearSSOCookie(c, accessName)
|
||||
clearSSOCookie(c, refreshName)
|
||||
|
||||
redirectTarget := ""
|
||||
rawReturn := strings.TrimSpace(c.Query("return_to"))
|
||||
if hasClientInfo {
|
||||
if rawReturn == "" {
|
||||
rawReturn = cfg.DefaultReturnURI
|
||||
}
|
||||
if normalized, err := normalizeReturnTarget(rawReturn, cfg); err == nil {
|
||||
redirectTarget = normalized
|
||||
} else if rawReturn != "" {
|
||||
utils.Log.WithError(err).Warn("invalid return_to during logout")
|
||||
}
|
||||
} else if rawReturn != "" {
|
||||
if strings.HasPrefix(rawReturn, "/") && !strings.HasPrefix(rawReturn, "//") {
|
||||
redirectTarget = rawReturn
|
||||
}
|
||||
}
|
||||
|
||||
utils.Log.WithFields(logrus.Fields{
|
||||
"client": alias,
|
||||
"state": state,
|
||||
"redirect": redirectTarget,
|
||||
}).Info("sso logout completed")
|
||||
|
||||
if redirectTarget != "" {
|
||||
return c.Redirect(redirectTarget, fiber.StatusFound)
|
||||
}
|
||||
|
||||
return c.Status(fiber.StatusOK).JSON(fiber.Map{"status": "signed out"})
|
||||
}
|
||||
|
||||
func (h *Controller) revokeToken(ctx context.Context, token string, verification *sso.VerificationResult) {
|
||||
if h.revoker == nil || verification == nil || verification.Claims == nil {
|
||||
return
|
||||
}
|
||||
fingerprint := session.TokenFingerprint(token)
|
||||
if fingerprint == "" {
|
||||
return
|
||||
}
|
||||
if verification.Claims.ExpiresAt == nil {
|
||||
utils.Log.Warn("access token missing expiry claim")
|
||||
return
|
||||
}
|
||||
ttl := time.Until(verification.Claims.ExpiresAt.Time)
|
||||
if ttl <= 0 {
|
||||
return
|
||||
}
|
||||
if ttl < time.Second {
|
||||
ttl = time.Second
|
||||
}
|
||||
if err := h.revoker.Revoke(ctx, fingerprint, ttl); err != nil {
|
||||
utils.Log.WithError(err).Warn("failed to revoke access token")
|
||||
}
|
||||
}
|
||||
|
||||
func (h *Controller) revokeRefreshToken(ctx context.Context, token string) {
|
||||
if h.revoker == nil {
|
||||
return
|
||||
}
|
||||
fingerprint := session.TokenFingerprint(token)
|
||||
if fingerprint == "" {
|
||||
return
|
||||
}
|
||||
const refreshTTL = 30 * 24 * time.Hour
|
||||
if err := h.revoker.Revoke(ctx, fingerprint, refreshTTL); err != nil {
|
||||
utils.Log.WithError(err).Warn("failed to revoke refresh token")
|
||||
}
|
||||
}
|
||||
|
||||
func issueCookies(c *fiber.Ctx, tokenResp struct {
|
||||
AccessToken string `json:"access_token"`
|
||||
RefreshToken string `json:"refresh_token"`
|
||||
@@ -307,15 +450,8 @@ func issueCookies(c *fiber.Ctx, tokenResp struct {
|
||||
Error string `json:"error"`
|
||||
Description string `json:"error_description"`
|
||||
}, verification *sso.VerificationResult) {
|
||||
fmt.Println(tokenResp.AccessToken)
|
||||
accessName := config.SSOAccessCookieName
|
||||
if accessName == "" {
|
||||
accessName = "access"
|
||||
}
|
||||
refreshName := config.SSORefreshCookieName
|
||||
if refreshName == "" {
|
||||
refreshName = "refresh"
|
||||
}
|
||||
accessName := resolveSSOCookieName(config.SSOAccessCookieName, "access")
|
||||
refreshName := resolveSSOCookieName(config.SSORefreshCookieName, "refresh")
|
||||
maxAge := tokenResp.ExpiresIn
|
||||
if maxAge <= 0 {
|
||||
maxAge = int(15 * time.Minute.Seconds())
|
||||
@@ -357,6 +493,64 @@ func issueCookies(c *fiber.Ctx, tokenResp struct {
|
||||
c.Set("X-Auth-User", fmt.Sprintf("%d", verification.UserID))
|
||||
}
|
||||
|
||||
func clearSSOCookie(c *fiber.Ctx, name string) {
|
||||
if name == "" {
|
||||
return
|
||||
}
|
||||
|
||||
sameSite := config.SSOCookieSameSite
|
||||
if sameSite == "" {
|
||||
sameSite = "Lax"
|
||||
}
|
||||
|
||||
c.Cookie(&fiber.Cookie{
|
||||
Name: name,
|
||||
Value: "",
|
||||
Path: "/",
|
||||
Domain: config.SSOCookieDomain,
|
||||
HTTPOnly: true,
|
||||
Secure: config.SSOCookieSecure,
|
||||
SameSite: sameSite,
|
||||
Expires: time.Unix(0, 0),
|
||||
MaxAge: -1,
|
||||
})
|
||||
}
|
||||
|
||||
func resolveSSOCookieName(configuredName, fallback string) string {
|
||||
name := strings.TrimSpace(configuredName)
|
||||
if name != "" {
|
||||
return name
|
||||
}
|
||||
return strings.TrimSpace(fallback)
|
||||
}
|
||||
|
||||
func normalizeClientParam(raw string) string {
|
||||
value := strings.TrimSpace(raw)
|
||||
if value == "" {
|
||||
return ""
|
||||
}
|
||||
if idx := strings.Index(value, "|"); idx >= 0 {
|
||||
value = value[:idx]
|
||||
}
|
||||
value = strings.TrimSpace(value)
|
||||
return strings.ToLower(value)
|
||||
}
|
||||
|
||||
func findSSOClientConfig(requestedAlias string) (string, config.SSOClientConfig, bool) {
|
||||
if requestedAlias == "" {
|
||||
return "", config.SSOClientConfig{}, false
|
||||
}
|
||||
if cfg, ok := config.SSOClients[requestedAlias]; ok && strings.TrimSpace(cfg.PublicID) != "" {
|
||||
return requestedAlias, cfg, true
|
||||
}
|
||||
for alias, cfg := range config.SSOClients {
|
||||
if strings.EqualFold(strings.TrimSpace(cfg.PublicID), requestedAlias) && strings.TrimSpace(cfg.PublicID) != "" {
|
||||
return alias, cfg, true
|
||||
}
|
||||
}
|
||||
return "", config.SSOClientConfig{}, false
|
||||
}
|
||||
|
||||
func normalizeReturnTarget(returnTo string, cfg config.SSOClientConfig) (string, error) {
|
||||
returnTo = strings.TrimSpace(returnTo)
|
||||
if returnTo == "" {
|
||||
|
||||
@@ -13,22 +13,21 @@ import (
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
"github.com/go-playground/validator/v10"
|
||||
"github.com/go-playground/validator/v10"
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/redis/go-redis/v9"
|
||||
"github.com/sirupsen/logrus"
|
||||
"gorm.io/gorm"
|
||||
|
||||
"gitlab.com/mbugroup/lti-api.git/internal/config"
|
||||
"gitlab.com/mbugroup/lti-api.git/internal/modules/users/dto"
|
||||
entity "gitlab.com/mbugroup/lti-api.git/internal/entities"
|
||||
"gitlab.com/mbugroup/lti-api.git/internal/modules/users/dto"
|
||||
userRepository "gitlab.com/mbugroup/lti-api.git/internal/modules/users/repositories"
|
||||
"gitlab.com/mbugroup/lti-api.git/internal/response"
|
||||
"gitlab.com/mbugroup/lti-api.git/internal/sso"
|
||||
"gitlab.com/mbugroup/lti-api.git/internal/utils"
|
||||
)
|
||||
|
||||
|
||||
const (
|
||||
headerClient = "X-Sync-Client"
|
||||
headerTimestamp = "X-Sync-Timestamp"
|
||||
@@ -209,6 +208,18 @@ func (h *UserSyncController) authenticate(c *fiber.Ctx, body []byte) (string, co
|
||||
|
||||
expectedSignature := h.calculateSignature(secret, rawAlias, timestamp, nonce, body)
|
||||
if !hmac.Equal(providedSig, expectedSignature) {
|
||||
bodyHash := sha256.Sum256(body)
|
||||
h.log.WithFields(logrus.Fields{
|
||||
"alias": rawAlias,
|
||||
"alias_key": aliasKey,
|
||||
"timestamp": timestamp,
|
||||
"nonce": nonce,
|
||||
"body_len": len(body),
|
||||
"body_sha256": hex.EncodeToString(bodyHash[:]),
|
||||
"body_base64": base64.StdEncoding.EncodeToString(body),
|
||||
"provided_hex_full": hex.EncodeToString(providedSig),
|
||||
"expected_hex_full": hex.EncodeToString(expectedSignature),
|
||||
}).Warn("sso sync signature mismatch")
|
||||
return "", config.SSOClientConfig{}, fiber.NewError(fiber.StatusUnauthorized, "invalid signature")
|
||||
}
|
||||
|
||||
|
||||
@@ -23,7 +23,7 @@ func Routes(router fiber.Router, db *gorm.DB, validate *validator.Validate) {
|
||||
}
|
||||
|
||||
store := session.NewStore(cache.MustRedis(), ttl)
|
||||
ctrl := ssoController.NewController(&http.Client{Timeout: 10 * time.Second}, store)
|
||||
ctrl := ssoController.NewController(&http.Client{Timeout: 10 * time.Second}, store, session.GetRevocationStore())
|
||||
userRepo := userRepository.NewUserRepository(db)
|
||||
syncCtrl := ssoController.NewUserSyncController(validate, userRepo, cache.Redis(), config.SSOClients)
|
||||
|
||||
@@ -31,5 +31,6 @@ func Routes(router fiber.Router, db *gorm.DB, validate *validator.Validate) {
|
||||
group.Get("/start", middleware.NewLimiter(30, time.Minute), ctrl.Start)
|
||||
group.Get("/callback", ctrl.Callback)
|
||||
group.Get("/userinfo", middleware.NewLimiter(60, time.Minute), ctrl.UserInfo)
|
||||
group.Post("/logout", middleware.NewLimiter(60, time.Minute), ctrl.Logout)
|
||||
group.Post("/users/sync", middleware.NewLimiter(30, time.Minute), syncCtrl.Sync)
|
||||
}
|
||||
|
||||
@@ -0,0 +1,106 @@
|
||||
package session
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"errors"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/redis/go-redis/v9"
|
||||
)
|
||||
|
||||
// RevocationStore handles token blacklist / revocation entries in Redis.
|
||||
type RevocationStore struct {
|
||||
redis *redis.Client
|
||||
prefix string
|
||||
}
|
||||
|
||||
var (
|
||||
globalRevokerMu sync.RWMutex
|
||||
globalRevoker *RevocationStore
|
||||
)
|
||||
|
||||
// NewRevocationStore creates a revocation store with the given redis client and key prefix.
|
||||
func NewRevocationStore(client *redis.Client, prefix string) *RevocationStore {
|
||||
return &RevocationStore{
|
||||
redis: client,
|
||||
prefix: strings.TrimSpace(prefix),
|
||||
}
|
||||
}
|
||||
|
||||
// SetRevocationStore registers the provided revocation store for global access.
|
||||
func SetRevocationStore(store *RevocationStore) {
|
||||
globalRevokerMu.Lock()
|
||||
globalRevoker = store
|
||||
globalRevokerMu.Unlock()
|
||||
}
|
||||
|
||||
// GetRevocationStore returns the globally registered revocation store, or nil if unset.
|
||||
func GetRevocationStore() *RevocationStore {
|
||||
globalRevokerMu.RLock()
|
||||
defer globalRevokerMu.RUnlock()
|
||||
return globalRevoker
|
||||
}
|
||||
|
||||
// MustRevocationStore returns the registered revocation store or panics if none is configured.
|
||||
func MustRevocationStore() *RevocationStore {
|
||||
store := GetRevocationStore()
|
||||
if store == nil {
|
||||
panic("revocation store not initialised")
|
||||
}
|
||||
return store
|
||||
}
|
||||
|
||||
// Revoke stores the fingerprint with the provided TTL.
|
||||
func (s *RevocationStore) Revoke(ctx context.Context, fingerprint string, ttl time.Duration) error {
|
||||
if s == nil || s.redis == nil {
|
||||
return errors.New("revocation store redis client not initialised")
|
||||
}
|
||||
fingerprint = strings.TrimSpace(fingerprint)
|
||||
if fingerprint == "" {
|
||||
return nil
|
||||
}
|
||||
if ttl <= 0 {
|
||||
ttl = time.Minute
|
||||
}
|
||||
key := s.keyFor(fingerprint)
|
||||
return s.redis.Set(ctx, key, "1", ttl).Err()
|
||||
}
|
||||
|
||||
// IsRevoked returns true when the fingerprint appears in the blacklist.
|
||||
func (s *RevocationStore) IsRevoked(ctx context.Context, fingerprint string) (bool, error) {
|
||||
if s == nil || s.redis == nil {
|
||||
return false, errors.New("revocation store redis client not initialised")
|
||||
}
|
||||
fingerprint = strings.TrimSpace(fingerprint)
|
||||
if fingerprint == "" {
|
||||
return false, nil
|
||||
}
|
||||
key := s.keyFor(fingerprint)
|
||||
exists, err := s.redis.Exists(ctx, key).Result()
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
return exists > 0, nil
|
||||
}
|
||||
|
||||
func (s *RevocationStore) keyFor(fingerprint string) string {
|
||||
prefix := s.prefix
|
||||
if prefix == "" {
|
||||
prefix = "sso:blacklist"
|
||||
}
|
||||
return prefix + ":" + fingerprint
|
||||
}
|
||||
|
||||
// TokenFingerprint hashes token material before persisting it to the blacklist.
|
||||
func TokenFingerprint(token string) string {
|
||||
token = strings.TrimSpace(token)
|
||||
if token == "" {
|
||||
return ""
|
||||
}
|
||||
sum := sha256.Sum256([]byte(token))
|
||||
return hex.EncodeToString(sum[:])
|
||||
}
|
||||
@@ -2,8 +2,10 @@ package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"time"
|
||||
|
||||
"github.com/jackc/pgconn"
|
||||
commonrepo "gitlab.com/mbugroup/lti-api.git/internal/common/repository"
|
||||
entity "gitlab.com/mbugroup/lti-api.git/internal/entities"
|
||||
"gorm.io/gorm"
|
||||
@@ -42,12 +44,44 @@ func (r *UserRepositoryImpl) UpsertByIdUser(ctx context.Context, user *entity.Us
|
||||
return gorm.ErrInvalidData
|
||||
}
|
||||
|
||||
conflict := []clause.Column{{Name: "id_user"}}
|
||||
user.DeletedAt = gorm.DeletedAt{}
|
||||
user.UpdatedAt = time.Now()
|
||||
return r.DB().WithContext(ctx).Transaction(func(tx *gorm.DB) error {
|
||||
now := time.Now()
|
||||
user.DeletedAt = gorm.DeletedAt{}
|
||||
user.UpdatedAt = now
|
||||
|
||||
return r.BaseRepositoryImpl.Upsert(ctx, user, conflict, func(db *gorm.DB) *gorm.DB {
|
||||
return db.Omit("id", "created_at")
|
||||
err := tx.Clauses(clause.OnConflict{
|
||||
Columns: []clause.Column{{Name: "id_user"}},
|
||||
UpdateAll: true,
|
||||
}).Omit("id", "created_at").Create(user).Error
|
||||
if err == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
if !isUniqueViolation(err, "users_email_unique") {
|
||||
return err
|
||||
}
|
||||
|
||||
var existing entity.User
|
||||
lockQuery := tx.Clauses(clause.Locking{Strength: "UPDATE"}).Where("email = ?", user.Email)
|
||||
if err := lockQuery.First(&existing).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
user.Id = existing.Id
|
||||
|
||||
updates := map[string]any{
|
||||
"id_user": user.IdUser,
|
||||
"email": user.Email,
|
||||
"name": user.Name,
|
||||
"updated_at": now,
|
||||
"deleted_at": gorm.DeletedAt{},
|
||||
}
|
||||
|
||||
if err := tx.Model(&entity.User{}).Where("id = ?", existing.Id).Updates(updates).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
@@ -62,3 +96,17 @@ func (r *UserRepositoryImpl) SoftDeleteByIdUser(ctx context.Context, idUser int6
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func isUniqueViolation(err error, constraint string) bool {
|
||||
var pgErr *pgconn.PgError
|
||||
if !errors.As(err, &pgErr) {
|
||||
return false
|
||||
}
|
||||
if pgErr.Code != "23505" {
|
||||
return false
|
||||
}
|
||||
if constraint == "" {
|
||||
return true
|
||||
}
|
||||
return pgErr.ConstraintName == constraint
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user