mirror of
https://gitlab.com/mbugroup/lti-api.git
synced 2026-05-20 13:31:56 +00:00
164 lines
4.3 KiB
Go
164 lines
4.3 KiB
Go
package session
|
|
|
|
import (
|
|
"context"
|
|
"crypto/sha256"
|
|
"encoding/hex"
|
|
"errors"
|
|
"strconv"
|
|
"strings"
|
|
"sync"
|
|
"time"
|
|
|
|
"github.com/redis/go-redis/v9"
|
|
)
|
|
|
|
// RevocationStore handles token blacklist / revocation entries in Redis.
|
|
type RevocationStore struct {
|
|
redis *redis.Client
|
|
prefix string
|
|
}
|
|
|
|
var (
|
|
globalRevokerMu sync.RWMutex
|
|
globalRevoker *RevocationStore
|
|
)
|
|
|
|
// NewRevocationStore creates a revocation store with the given redis client and key prefix.
|
|
func NewRevocationStore(client *redis.Client, prefix string) *RevocationStore {
|
|
return &RevocationStore{
|
|
redis: client,
|
|
prefix: strings.TrimSpace(prefix),
|
|
}
|
|
}
|
|
|
|
// SetRevocationStore registers the provided revocation store for global access.
|
|
func SetRevocationStore(store *RevocationStore) {
|
|
globalRevokerMu.Lock()
|
|
globalRevoker = store
|
|
globalRevokerMu.Unlock()
|
|
}
|
|
|
|
// GetRevocationStore returns the globally registered revocation store, or nil if unset.
|
|
func GetRevocationStore() *RevocationStore {
|
|
globalRevokerMu.RLock()
|
|
defer globalRevokerMu.RUnlock()
|
|
return globalRevoker
|
|
}
|
|
|
|
// MustRevocationStore returns the registered revocation store or panics if none is configured.
|
|
func MustRevocationStore() *RevocationStore {
|
|
store := GetRevocationStore()
|
|
if store == nil {
|
|
panic("revocation store not initialised")
|
|
}
|
|
return store
|
|
}
|
|
|
|
// Revoke stores the fingerprint with the provided TTL.
|
|
func (s *RevocationStore) Revoke(ctx context.Context, fingerprint string, ttl time.Duration) error {
|
|
if s == nil || s.redis == nil {
|
|
return errors.New("revocation store redis client not initialised")
|
|
}
|
|
fingerprint = strings.TrimSpace(fingerprint)
|
|
if fingerprint == "" {
|
|
return nil
|
|
}
|
|
if ttl <= 0 {
|
|
ttl = time.Minute
|
|
}
|
|
key := s.keyFor(fingerprint)
|
|
return s.redis.Set(ctx, key, "1", ttl).Err()
|
|
}
|
|
|
|
// IsRevoked returns true when the fingerprint appears in the blacklist.
|
|
func (s *RevocationStore) IsRevoked(ctx context.Context, fingerprint string) (bool, error) {
|
|
if s == nil || s.redis == nil {
|
|
return false, errors.New("revocation store redis client not initialised")
|
|
}
|
|
fingerprint = strings.TrimSpace(fingerprint)
|
|
if fingerprint == "" {
|
|
return false, nil
|
|
}
|
|
key := s.keyFor(fingerprint)
|
|
exists, err := s.redis.Exists(ctx, key).Result()
|
|
if err != nil {
|
|
return false, err
|
|
}
|
|
return exists > 0, nil
|
|
}
|
|
|
|
// MarkUserLogout stores the timestamp of the last forced logout for the given user.
|
|
func (s *RevocationStore) MarkUserLogout(ctx context.Context, userID uint, at time.Time) error {
|
|
if s == nil || s.redis == nil {
|
|
return errors.New("revocation store redis client not initialised")
|
|
}
|
|
if userID == 0 {
|
|
return errors.New("invalid user id")
|
|
}
|
|
key := s.userLogoutKey(userID)
|
|
return s.redis.Set(ctx, key, at.UTC().Format(time.RFC3339Nano), 0).Err()
|
|
}
|
|
|
|
// ClearUserLogout removes any stored forced logout marker for the given user.
|
|
func (s *RevocationStore) ClearUserLogout(ctx context.Context, userID uint) error {
|
|
if s == nil || s.redis == nil {
|
|
return errors.New("revocation store redis client not initialised")
|
|
}
|
|
if userID == 0 {
|
|
return errors.New("invalid user id")
|
|
}
|
|
key := s.userLogoutKey(userID)
|
|
return s.redis.Del(ctx, key).Err()
|
|
}
|
|
|
|
// UserLogoutTime returns the timestamp of the last forced logout for the given user.
|
|
func (s *RevocationStore) UserLogoutTime(ctx context.Context, userID uint) (time.Time, error) {
|
|
var zero time.Time
|
|
if s == nil || s.redis == nil {
|
|
return zero, errors.New("revocation store redis client not initialised")
|
|
}
|
|
if userID == 0 {
|
|
return zero, errors.New("invalid user id")
|
|
}
|
|
key := s.userLogoutKey(userID)
|
|
value, err := s.redis.Get(ctx, key).Result()
|
|
if err != nil {
|
|
if errors.Is(err, redis.Nil) {
|
|
return zero, nil
|
|
}
|
|
return zero, err
|
|
}
|
|
ts, err := time.Parse(time.RFC3339Nano, value)
|
|
if err != nil {
|
|
return zero, err
|
|
}
|
|
return ts, nil
|
|
}
|
|
|
|
func (s *RevocationStore) keyFor(fingerprint string) string {
|
|
prefix := s.prefix
|
|
if prefix == "" {
|
|
prefix = "sso:blacklist"
|
|
}
|
|
return prefix + ":" + fingerprint
|
|
}
|
|
|
|
func (s *RevocationStore) userLogoutKey(userID uint) string {
|
|
prefix := s.prefix
|
|
if prefix == "" {
|
|
prefix = "sso:blacklist"
|
|
}
|
|
return prefix + ":user-logout:" + strconv.FormatUint(uint64(userID), 10)
|
|
}
|
|
|
|
// TokenFingerprint reduces raw token material to a lowercase hex SHA-256
// digest, so the token itself is never persisted in the blacklist.
// Surrounding whitespace is ignored, and a blank token yields "".
func TokenFingerprint(token string) string {
	material := strings.TrimSpace(token)
	if len(material) == 0 {
		return ""
	}
	hasher := sha256.New()
	// hash.Hash's Write is documented never to return an error.
	hasher.Write([]byte(material))
	return hex.EncodeToString(hasher.Sum(nil))
}
|