package session import ( "context" "crypto/sha256" "encoding/hex" "errors" "strconv" "strings" "sync" "time" "github.com/redis/go-redis/v9" ) // RevocationStore handles token blacklist / revocation entries in Redis. type RevocationStore struct { redis *redis.Client prefix string } var ( globalRevokerMu sync.RWMutex globalRevoker *RevocationStore ) // NewRevocationStore creates a revocation store with the given redis client and key prefix. func NewRevocationStore(client *redis.Client, prefix string) *RevocationStore { return &RevocationStore{ redis: client, prefix: strings.TrimSpace(prefix), } } // SetRevocationStore registers the provided revocation store for global access. func SetRevocationStore(store *RevocationStore) { globalRevokerMu.Lock() globalRevoker = store globalRevokerMu.Unlock() } // GetRevocationStore returns the globally registered revocation store, or nil if unset. func GetRevocationStore() *RevocationStore { globalRevokerMu.RLock() defer globalRevokerMu.RUnlock() return globalRevoker } // MustRevocationStore returns the registered revocation store or panics if none is configured. func MustRevocationStore() *RevocationStore { store := GetRevocationStore() if store == nil { panic("revocation store not initialised") } return store } // Revoke stores the fingerprint with the provided TTL. func (s *RevocationStore) Revoke(ctx context.Context, fingerprint string, ttl time.Duration) error { if s == nil || s.redis == nil { return errors.New("revocation store redis client not initialised") } fingerprint = strings.TrimSpace(fingerprint) if fingerprint == "" { return nil } if ttl <= 0 { ttl = time.Minute } key := s.keyFor(fingerprint) return s.redis.Set(ctx, key, "1", ttl).Err() } // IsRevoked returns true when the fingerprint appears in the blacklist. func (s *RevocationStore) IsRevoked(ctx context.Context, fingerprint string) (bool, error) { if s == nil || s.redis == nil { return false, errors.New("revocation store redis client not initialised") } fingerprint = strings.TrimSpace(fingerprint) if fingerprint == "" { return false, nil } key := s.keyFor(fingerprint) exists, err := s.redis.Exists(ctx, key).Result() if err != nil { return false, err } return exists > 0, nil } // MarkUserLogout stores the timestamp of the last forced logout for the given user. func (s *RevocationStore) MarkUserLogout(ctx context.Context, userID uint, at time.Time) error { if s == nil || s.redis == nil { return errors.New("revocation store redis client not initialised") } if userID == 0 { return errors.New("invalid user id") } key := s.userLogoutKey(userID) return s.redis.Set(ctx, key, at.UTC().Format(time.RFC3339Nano), 0).Err() } // ClearUserLogout removes any stored forced logout marker for the given user. func (s *RevocationStore) ClearUserLogout(ctx context.Context, userID uint) error { if s == nil || s.redis == nil { return errors.New("revocation store redis client not initialised") } if userID == 0 { return errors.New("invalid user id") } key := s.userLogoutKey(userID) return s.redis.Del(ctx, key).Err() } // UserLogoutTime returns the timestamp of the last forced logout for the given user. func (s *RevocationStore) UserLogoutTime(ctx context.Context, userID uint) (time.Time, error) { var zero time.Time if s == nil || s.redis == nil { return zero, errors.New("revocation store redis client not initialised") } if userID == 0 { return zero, errors.New("invalid user id") } key := s.userLogoutKey(userID) value, err := s.redis.Get(ctx, key).Result() if err != nil { if errors.Is(err, redis.Nil) { return zero, nil } return zero, err } ts, err := time.Parse(time.RFC3339Nano, value) if err != nil { return zero, err } return ts, nil } func (s *RevocationStore) keyFor(fingerprint string) string { prefix := s.prefix if prefix == "" { prefix = "sso:blacklist" } return prefix + ":" + fingerprint } func (s *RevocationStore) userLogoutKey(userID uint) string { prefix := s.prefix if prefix == "" { prefix = "sso:blacklist" } return prefix + ":user-logout:" + strconv.FormatUint(uint64(userID), 10) } // TokenFingerprint hashes token material before persisting it to the blacklist. func TokenFingerprint(token string) string { token = strings.TrimSpace(token) if token == "" { return "" } sum := sha256.Sum256([]byte(token)) return hex.EncodeToString(sum[:]) }