Feat(BE-69,70,71,72,73): crud and integration sso with lti, revoke_token

This commit is contained in:
ragilap
2025-10-08 15:25:17 +07:00
committed by Adnan Zahir
parent 501b6f8440
commit 22e4728738
13 changed files with 517 additions and 31 deletions
@@ -25,22 +25,24 @@ import (
type Controller struct {
httpClient *http.Client
store *session.Store
revoker *session.RevocationStore
}
func NewController(client *http.Client, store *session.Store) *Controller {
return &Controller{httpClient: client, store: store}
func NewController(client *http.Client, store *session.Store, revoker *session.RevocationStore) *Controller {
return &Controller{httpClient: client, store: store, revoker: revoker}
}
// Start handles GET /sso/start requests and redirects users to the central SSO authorize endpoint.
func (h *Controller) Start(c *fiber.Ctx) error {
alias := strings.ToLower(strings.TrimSpace(c.Query("client")))
if alias == "" {
alias = strings.ToLower(strings.TrimSpace(c.Query("client_id")))
requestedAlias := normalizeClientParam(c.Query("client"))
if requestedAlias == "" {
requestedAlias = normalizeClientParam(c.Query("client_id"))
}
if alias == "" {
if requestedAlias == "" {
return fiber.NewError(fiber.StatusBadRequest, "missing client")
}
cfg, ok := config.SSOClients[alias]
alias, cfg, ok := findSSOClientConfig(requestedAlias)
if !ok || cfg.PublicID == "" {
return fiber.NewError(fiber.StatusBadRequest, "unknown client")
}
@@ -209,6 +211,7 @@ func (h *Controller) Callback(c *fiber.Ctx) error {
return fiber.NewError(fiber.StatusBadGateway, "missing access token")
}
fmt.Println(tokenResp.AccessToken)
verification, err := sso.VerifyAccessToken(tokenResp.AccessToken)
if err != nil {
utils.Log.Errorf("access token verification failed: %v", err)
@@ -223,7 +226,6 @@ func (h *Controller) Callback(c *fiber.Ctx) error {
redirectTarget = "/"
}
fmt.Println(sessionData.ClientAlias,"test")
utils.Log.WithFields(logrus.Fields{
"client": sessionData.ClientAlias,
"user_id": verification.UserID,
@@ -255,6 +257,24 @@ func (h *Controller) UserInfo(c *fiber.Ctx) error {
return fiber.NewError(fiber.StatusUnauthorized, "unauthenticated")
}
if revoker := session.GetRevocationStore(); revoker != nil {
if fingerprint := session.TokenFingerprint(token); fingerprint != "" {
revoked, err := revoker.IsRevoked(c.Context(), fingerprint)
if err != nil {
utils.Log.WithError(err).Warn("failed to check token revocation for userinfo")
return fiber.NewError(fiber.StatusUnauthorized, "unauthenticated")
}
if revoked {
return fiber.NewError(fiber.StatusUnauthorized, "unauthenticated")
}
}
}
if _, err := sso.VerifyAccessToken(token); err != nil {
utils.Log.WithError(err).Warn("access token verification failed for userinfo")
return fiber.NewError(fiber.StatusUnauthorized, "unauthenticated")
}
endpoint := strings.TrimSpace(config.SSOGetMeURL)
if endpoint == "" {
return fiber.NewError(fiber.StatusInternalServerError, "userinfo endpoint not configured")
@@ -297,6 +317,129 @@ func (h *Controller) UserInfo(c *fiber.Ctx) error {
return c.Status(resp.StatusCode).Send(body)
}
// Logout clears SSO cookies and removes any leftover PKCE session state.
func (h *Controller) Logout(c *fiber.Ctx) error {
requestedAlias := normalizeClientParam(c.Query("client"))
if requestedAlias == "" {
requestedAlias = normalizeClientParam(c.Query("client_id"))
}
var (
alias string
cfg config.SSOClientConfig
hasClientInfo bool
)
if requestedAlias != "" {
alias, cfg, hasClientInfo = findSSOClientConfig(requestedAlias)
}
accessName := resolveSSOCookieName(config.SSOAccessCookieName, "access")
refreshName := resolveSSOCookieName(config.SSORefreshCookieName, "refresh")
var accessToken, refreshToken string
if accessName != "" {
accessToken = strings.TrimSpace(c.Cookies(accessName))
}
if refreshName != "" {
refreshToken = strings.TrimSpace(c.Cookies(refreshName))
}
hadAccessCookie := accessToken != ""
hadRefreshCookie := refreshToken != ""
state := strings.TrimSpace(c.Query("state"))
if state != "" {
if err := h.store.Delete(c.Context(), state); err != nil {
utils.Log.Warnf("failed to delete pkce session during logout: %v", err)
}
}
if !hadAccessCookie && !hadRefreshCookie && state == "" {
return fiber.NewError(fiber.StatusUnauthorized, "not authenticated")
}
if hadAccessCookie {
if verification, err := sso.VerifyAccessToken(accessToken); err != nil {
utils.Log.WithError(err).Warn("failed to verify access token during logout")
} else {
h.revokeToken(c.Context(), accessToken, verification)
}
}
if refreshToken != "" {
h.revokeRefreshToken(c.Context(), refreshToken)
}
clearSSOCookie(c, accessName)
clearSSOCookie(c, refreshName)
redirectTarget := ""
rawReturn := strings.TrimSpace(c.Query("return_to"))
if hasClientInfo {
if rawReturn == "" {
rawReturn = cfg.DefaultReturnURI
}
if normalized, err := normalizeReturnTarget(rawReturn, cfg); err == nil {
redirectTarget = normalized
} else if rawReturn != "" {
utils.Log.WithError(err).Warn("invalid return_to during logout")
}
} else if rawReturn != "" {
if strings.HasPrefix(rawReturn, "/") && !strings.HasPrefix(rawReturn, "//") {
redirectTarget = rawReturn
}
}
utils.Log.WithFields(logrus.Fields{
"client": alias,
"state": state,
"redirect": redirectTarget,
}).Info("sso logout completed")
if redirectTarget != "" {
return c.Redirect(redirectTarget, fiber.StatusFound)
}
return c.Status(fiber.StatusOK).JSON(fiber.Map{"status": "signed out"})
}
func (h *Controller) revokeToken(ctx context.Context, token string, verification *sso.VerificationResult) {
if h.revoker == nil || verification == nil || verification.Claims == nil {
return
}
fingerprint := session.TokenFingerprint(token)
if fingerprint == "" {
return
}
if verification.Claims.ExpiresAt == nil {
utils.Log.Warn("access token missing expiry claim")
return
}
ttl := time.Until(verification.Claims.ExpiresAt.Time)
if ttl <= 0 {
return
}
if ttl < time.Second {
ttl = time.Second
}
if err := h.revoker.Revoke(ctx, fingerprint, ttl); err != nil {
utils.Log.WithError(err).Warn("failed to revoke access token")
}
}
func (h *Controller) revokeRefreshToken(ctx context.Context, token string) {
if h.revoker == nil {
return
}
fingerprint := session.TokenFingerprint(token)
if fingerprint == "" {
return
}
const refreshTTL = 30 * 24 * time.Hour
if err := h.revoker.Revoke(ctx, fingerprint, refreshTTL); err != nil {
utils.Log.WithError(err).Warn("failed to revoke refresh token")
}
}
func issueCookies(c *fiber.Ctx, tokenResp struct {
AccessToken string `json:"access_token"`
RefreshToken string `json:"refresh_token"`
@@ -307,15 +450,8 @@ func issueCookies(c *fiber.Ctx, tokenResp struct {
Error string `json:"error"`
Description string `json:"error_description"`
}, verification *sso.VerificationResult) {
fmt.Println(tokenResp.AccessToken)
accessName := config.SSOAccessCookieName
if accessName == "" {
accessName = "access"
}
refreshName := config.SSORefreshCookieName
if refreshName == "" {
refreshName = "refresh"
}
accessName := resolveSSOCookieName(config.SSOAccessCookieName, "access")
refreshName := resolveSSOCookieName(config.SSORefreshCookieName, "refresh")
maxAge := tokenResp.ExpiresIn
if maxAge <= 0 {
maxAge = int(15 * time.Minute.Seconds())
@@ -357,6 +493,64 @@ func issueCookies(c *fiber.Ctx, tokenResp struct {
c.Set("X-Auth-User", fmt.Sprintf("%d", verification.UserID))
}
func clearSSOCookie(c *fiber.Ctx, name string) {
if name == "" {
return
}
sameSite := config.SSOCookieSameSite
if sameSite == "" {
sameSite = "Lax"
}
c.Cookie(&fiber.Cookie{
Name: name,
Value: "",
Path: "/",
Domain: config.SSOCookieDomain,
HTTPOnly: true,
Secure: config.SSOCookieSecure,
SameSite: sameSite,
Expires: time.Unix(0, 0),
MaxAge: -1,
})
}
func resolveSSOCookieName(configuredName, fallback string) string {
name := strings.TrimSpace(configuredName)
if name != "" {
return name
}
return strings.TrimSpace(fallback)
}
func normalizeClientParam(raw string) string {
value := strings.TrimSpace(raw)
if value == "" {
return ""
}
if idx := strings.Index(value, "|"); idx >= 0 {
value = value[:idx]
}
value = strings.TrimSpace(value)
return strings.ToLower(value)
}
func findSSOClientConfig(requestedAlias string) (string, config.SSOClientConfig, bool) {
if requestedAlias == "" {
return "", config.SSOClientConfig{}, false
}
if cfg, ok := config.SSOClients[requestedAlias]; ok && strings.TrimSpace(cfg.PublicID) != "" {
return requestedAlias, cfg, true
}
for alias, cfg := range config.SSOClients {
if strings.EqualFold(strings.TrimSpace(cfg.PublicID), requestedAlias) && strings.TrimSpace(cfg.PublicID) != "" {
return alias, cfg, true
}
}
return "", config.SSOClientConfig{}, false
}
func normalizeReturnTarget(returnTo string, cfg config.SSOClientConfig) (string, error) {
returnTo = strings.TrimSpace(returnTo)
if returnTo == "" {
@@ -13,22 +13,21 @@ import (
"strings"
"sync"
"time"
"github.com/go-playground/validator/v10"
"github.com/go-playground/validator/v10"
"github.com/gofiber/fiber/v2"
"github.com/redis/go-redis/v9"
"github.com/sirupsen/logrus"
"gorm.io/gorm"
"gitlab.com/mbugroup/lti-api.git/internal/config"
"gitlab.com/mbugroup/lti-api.git/internal/modules/users/dto"
entity "gitlab.com/mbugroup/lti-api.git/internal/entities"
"gitlab.com/mbugroup/lti-api.git/internal/modules/users/dto"
userRepository "gitlab.com/mbugroup/lti-api.git/internal/modules/users/repositories"
"gitlab.com/mbugroup/lti-api.git/internal/response"
"gitlab.com/mbugroup/lti-api.git/internal/sso"
"gitlab.com/mbugroup/lti-api.git/internal/utils"
)
const (
headerClient = "X-Sync-Client"
headerTimestamp = "X-Sync-Timestamp"
@@ -209,6 +208,18 @@ func (h *UserSyncController) authenticate(c *fiber.Ctx, body []byte) (string, co
expectedSignature := h.calculateSignature(secret, rawAlias, timestamp, nonce, body)
if !hmac.Equal(providedSig, expectedSignature) {
bodyHash := sha256.Sum256(body)
h.log.WithFields(logrus.Fields{
"alias": rawAlias,
"alias_key": aliasKey,
"timestamp": timestamp,
"nonce": nonce,
"body_len": len(body),
"body_sha256": hex.EncodeToString(bodyHash[:]),
"body_base64": base64.StdEncoding.EncodeToString(body),
"provided_hex_full": hex.EncodeToString(providedSig),
"expected_hex_full": hex.EncodeToString(expectedSignature),
}).Warn("sso sync signature mismatch")
return "", config.SSOClientConfig{}, fiber.NewError(fiber.StatusUnauthorized, "invalid signature")
}
+2 -1
View File
@@ -23,7 +23,7 @@ func Routes(router fiber.Router, db *gorm.DB, validate *validator.Validate) {
}
store := session.NewStore(cache.MustRedis(), ttl)
ctrl := ssoController.NewController(&http.Client{Timeout: 10 * time.Second}, store)
ctrl := ssoController.NewController(&http.Client{Timeout: 10 * time.Second}, store, session.GetRevocationStore())
userRepo := userRepository.NewUserRepository(db)
syncCtrl := ssoController.NewUserSyncController(validate, userRepo, cache.Redis(), config.SSOClients)
@@ -31,5 +31,6 @@ func Routes(router fiber.Router, db *gorm.DB, validate *validator.Validate) {
group.Get("/start", middleware.NewLimiter(30, time.Minute), ctrl.Start)
group.Get("/callback", ctrl.Callback)
group.Get("/userinfo", middleware.NewLimiter(60, time.Minute), ctrl.UserInfo)
group.Post("/logout", middleware.NewLimiter(60, time.Minute), ctrl.Logout)
group.Post("/users/sync", middleware.NewLimiter(30, time.Minute), syncCtrl.Sync)
}
+106
View File
@@ -0,0 +1,106 @@
package session
import (
"context"
"crypto/sha256"
"encoding/hex"
"errors"
"strings"
"sync"
"time"
"github.com/redis/go-redis/v9"
)
// RevocationStore handles token blacklist / revocation entries in Redis.
type RevocationStore struct {
redis *redis.Client
prefix string
}
var (
globalRevokerMu sync.RWMutex
globalRevoker *RevocationStore
)
// NewRevocationStore creates a revocation store with the given redis client and key prefix.
func NewRevocationStore(client *redis.Client, prefix string) *RevocationStore {
return &RevocationStore{
redis: client,
prefix: strings.TrimSpace(prefix),
}
}
// SetRevocationStore registers the provided revocation store for global access.
func SetRevocationStore(store *RevocationStore) {
globalRevokerMu.Lock()
globalRevoker = store
globalRevokerMu.Unlock()
}
// GetRevocationStore returns the globally registered revocation store, or nil if unset.
func GetRevocationStore() *RevocationStore {
globalRevokerMu.RLock()
defer globalRevokerMu.RUnlock()
return globalRevoker
}
// MustRevocationStore returns the registered revocation store or panics if none is configured.
func MustRevocationStore() *RevocationStore {
store := GetRevocationStore()
if store == nil {
panic("revocation store not initialised")
}
return store
}
// Revoke stores the fingerprint with the provided TTL.
func (s *RevocationStore) Revoke(ctx context.Context, fingerprint string, ttl time.Duration) error {
if s == nil || s.redis == nil {
return errors.New("revocation store redis client not initialised")
}
fingerprint = strings.TrimSpace(fingerprint)
if fingerprint == "" {
return nil
}
if ttl <= 0 {
ttl = time.Minute
}
key := s.keyFor(fingerprint)
return s.redis.Set(ctx, key, "1", ttl).Err()
}
// IsRevoked returns true when the fingerprint appears in the blacklist.
func (s *RevocationStore) IsRevoked(ctx context.Context, fingerprint string) (bool, error) {
if s == nil || s.redis == nil {
return false, errors.New("revocation store redis client not initialised")
}
fingerprint = strings.TrimSpace(fingerprint)
if fingerprint == "" {
return false, nil
}
key := s.keyFor(fingerprint)
exists, err := s.redis.Exists(ctx, key).Result()
if err != nil {
return false, err
}
return exists > 0, nil
}
func (s *RevocationStore) keyFor(fingerprint string) string {
prefix := s.prefix
if prefix == "" {
prefix = "sso:blacklist"
}
return prefix + ":" + fingerprint
}
// TokenFingerprint hashes token material before persisting it to the blacklist.
func TokenFingerprint(token string) string {
token = strings.TrimSpace(token)
if token == "" {
return ""
}
sum := sha256.Sum256([]byte(token))
return hex.EncodeToString(sum[:])
}
@@ -2,8 +2,10 @@ package repository
import (
"context"
"errors"
"time"
"github.com/jackc/pgconn"
commonrepo "gitlab.com/mbugroup/lti-api.git/internal/common/repository"
entity "gitlab.com/mbugroup/lti-api.git/internal/entities"
"gorm.io/gorm"
@@ -42,12 +44,44 @@ func (r *UserRepositoryImpl) UpsertByIdUser(ctx context.Context, user *entity.Us
return gorm.ErrInvalidData
}
conflict := []clause.Column{{Name: "id_user"}}
user.DeletedAt = gorm.DeletedAt{}
user.UpdatedAt = time.Now()
return r.DB().WithContext(ctx).Transaction(func(tx *gorm.DB) error {
now := time.Now()
user.DeletedAt = gorm.DeletedAt{}
user.UpdatedAt = now
return r.BaseRepositoryImpl.Upsert(ctx, user, conflict, func(db *gorm.DB) *gorm.DB {
return db.Omit("id", "created_at")
err := tx.Clauses(clause.OnConflict{
Columns: []clause.Column{{Name: "id_user"}},
UpdateAll: true,
}).Omit("id", "created_at").Create(user).Error
if err == nil {
return nil
}
if !isUniqueViolation(err, "users_email_unique") {
return err
}
var existing entity.User
lockQuery := tx.Clauses(clause.Locking{Strength: "UPDATE"}).Where("email = ?", user.Email)
if err := lockQuery.First(&existing).Error; err != nil {
return err
}
user.Id = existing.Id
updates := map[string]any{
"id_user": user.IdUser,
"email": user.Email,
"name": user.Name,
"updated_at": now,
"deleted_at": gorm.DeletedAt{},
}
if err := tx.Model(&entity.User{}).Where("id = ?", existing.Id).Updates(updates).Error; err != nil {
return err
}
return nil
})
}
@@ -62,3 +96,17 @@ func (r *UserRepositoryImpl) SoftDeleteByIdUser(ctx context.Context, idUser int6
}
return nil
}
func isUniqueViolation(err error, constraint string) bool {
var pgErr *pgconn.PgError
if !errors.As(err, &pgErr) {
return false
}
if pgErr.Code != "23505" {
return false
}
if constraint == "" {
return true
}
return pgErr.ConstraintName == constraint
}