Files

388 lines
9.6 KiB
Go

package sso
import (
"context"
"encoding/json"
"errors"
"fmt"
"net/http"
"strconv"
"strings"
"sync"
"time"
"github.com/redis/go-redis/v9"
"gitlab.com/mbugroup/lti-api.git/internal/cache"
"gitlab.com/mbugroup/lti-api.git/internal/config"
"gitlab.com/mbugroup/lti-api.git/internal/utils"
)
// Profile cache settings. Cache keys are profileCachePrefix followed by the
// decimal user id, and cached profiles live for profileCacheTTL both in Redis
// and in the in-process cache.
const (
	profileCachePrefix = "sso:profile:user:"
	profileCacheTTL    = time.Minute
)

var (
	// profileClient is shared by every get-me request; its Timeout caps each
	// SSO round trip at five seconds.
	profileClient = &http.Client{Timeout: 5 * time.Second}

	// profileLocalCache is the in-process profile cache, keyed by
	// profileCacheKey values and holding cachedProfile entries. In this file,
	// entries are removed only when read after expiry or by
	// InvalidateProfileCache.
	profileLocalCache sync.Map
)
// cachedProfile is one in-process cache entry: the cached profile and the
// instant from which it must no longer be served.
type cachedProfile struct {
	Profile   *UserProfile
	ExpiresAt time.Time
}

// UserProfile is the user information resolved from the central SSO get-me
// endpoint. It declares no JSON tags, so its JSON form (as stored in Redis)
// uses the Go field names.
type UserProfile struct {
	UserID      uint         // user id reported by the SSO; may be 0 if the response omitted it
	Roles       []Role       // role assignments, each with its own permissions
	Permissions []Permission // named permissions of all roles, flattened (duplicates possible)
	AreaIDs     []uint       // SSO area_ids
	LocationIDs []uint       // SSO location_ids
	AllArea     bool         // SSO all_area flag (presumably: unrestricted by area — confirm)
	AllLocation bool         // SSO all_location flag (presumably: unrestricted by location — confirm)
}

// Role is a single role assignment taken from the SSO profile response.
type Role struct {
	ID          uint
	Key         string
	Name        string
	ClientID    uint
	ClientAlias string
	ClientName  string
	Permissions []Permission
	// RawReference is excluded from JSON (and therefore from the Redis
	// cache); fetchProfileFromSSO does not populate it.
	RawReference json.RawMessage `json:"-"`
}

// Permission is a single granular permission entry from the SSO profile.
type Permission struct {
	ID          uint
	Name        string
	Action      string
	ClientID    uint
	ClientAlias string
	ClientName  string
}
// PermissionNames returns the profile's permission names in canonical form
// (trimmed and lower-cased via canonicalPermissionName), de-duplicated and in
// the order they first appear in p.Permissions. Names that are blank after
// canonicalization are skipped.
//
// It returns nil for a nil profile or a profile without permissions. If every
// permission name is blank, the result is a non-nil empty slice.
func (p *UserProfile) PermissionNames() []string {
	if p == nil || len(p.Permissions) == 0 {
		return nil
	}
	seen := make(map[string]struct{}, len(p.Permissions))
	out := make([]string, 0, len(p.Permissions))
	for _, perm := range p.Permissions {
		name := canonicalPermissionName(perm.Name)
		if name == "" {
			continue
		}
		if _, dup := seen[name]; dup {
			continue
		}
		seen[name] = struct{}{}
		// Appending in input order (rather than ranging over the set) keeps
		// the result deterministic across calls.
		out = append(out, name)
	}
	return out
}
// FetchProfile returns the SSO profile of the user identified by verification.
// It checks the in-process cache first, then Redis, and finally calls the SSO
// get-me endpoint with token. A profile loaded from Redis is copied into the
// in-process cache. A profile fetched from the SSO is written to both caches
// for profileCacheTTL.
//
// Only end-user tokens are supported. A nil verification or one with
// UserID == 0 yields an error. Errors from the SSO call are returned as-is;
// cache failures are not returned.
//
// The returned *UserProfile may be shared with other callers through the
// in-process cache and must be treated as read-only.
func FetchProfile(ctx context.Context, token string, verification *VerificationResult) (*UserProfile, error) {
	if verification == nil || verification.UserID == 0 {
		return nil, errors.New("profile only available for user tokens")
	}
	if ctx == nil {
		// Normalize once here, as InvalidateProfileCache does, so the Redis
		// helpers below never receive a nil context.
		ctx = context.Background()
	}
	key := profileCacheKey(verification.UserID)
	if profile := loadProfileFromLocalCache(key); profile != nil {
		return profile, nil
	}
	if profile := loadProfileFromRedis(ctx, key); profile != nil {
		storeProfileInLocalCache(key, profile)
		return profile, nil
	}
	profile, err := fetchProfileFromSSO(ctx, token)
	if err != nil {
		return nil, err
	}
	if profile.UserID == 0 {
		// The get-me response carried no user id. The profile is cached under
		// the verified user's key, so record that id instead of leaving 0.
		profile.UserID = verification.UserID
	}
	storeProfileInLocalCache(key, profile)
	storeProfileInRedis(ctx, key, profile)
	return profile, nil
}
// fetchProfileFromSSO calls the configured SSO get-me endpoint
// (config.SSOGetMeURL) and converts the response into a UserProfile.
//
// The token is sent as a Bearer Authorization header. When
// config.SSOAccessCookieName is non-blank, it is also sent as that cookie.
// Any non-2xx status is an error.
//
// Role and permission text fields are trimmed, and permissions without a name
// are dropped. Profile.Permissions holds every named permission of every role,
// duplicates included. The user id comes from the "user" object when present,
// otherwise from "data.id".
func fetchProfileFromSSO(ctx context.Context, token string) (*UserProfile, error) {
	endpoint := strings.TrimSpace(config.SSOGetMeURL)
	if endpoint == "" {
		return nil, errors.New("sso get-me endpoint not configured")
	}
	if strings.TrimSpace(token) == "" {
		// Without a token the SSO can only reject the call; skip the round trip.
		return nil, errors.New("fetch profile: empty token")
	}
	if ctx == nil {
		// http.NewRequestWithContext rejects a nil context.
		ctx = context.Background()
	}
	req, err := http.NewRequestWithContext(ctx, http.MethodGet, endpoint, nil)
	if err != nil {
		return nil, fmt.Errorf("build profile request: %w", err)
	}
	req.Header.Set("Accept", "application/json")
	req.Header.Set("Authorization", "Bearer "+token)
	if cookieName := strings.TrimSpace(config.SSOAccessCookieName); cookieName != "" {
		req.Header.Set("Cookie", fmt.Sprintf("%s=%s", cookieName, token))
	}
	resp, err := profileClient.Do(req)
	if err != nil {
		return nil, fmt.Errorf("fetch profile: %w", err)
	}
	defer resp.Body.Close()
	// Only a 2xx response carries a profile. Previously 1xx/3xx statuses fell
	// through to the JSON decoder.
	if resp.StatusCode < http.StatusOK || resp.StatusCode >= http.StatusMultipleChoices {
		return nil, fmt.Errorf("fetch profile: status %d", resp.StatusCode)
	}
	var envelope userInfoEnvelope
	if err := json.NewDecoder(resp.Body).Decode(&envelope); err != nil {
		return nil, fmt.Errorf("decode profile: %w", err)
	}
	roles, perms := convertProfileRoles(envelope.getRoles())
	return &UserProfile{
		UserID:      envelope.inferredUserID(),
		Roles:       roles,
		Permissions: perms,
		AreaIDs:     envelope.getAreaIDs(),
		LocationIDs: envelope.getLocationIDs(),
		AllArea:     envelope.getAllArea(),
		AllLocation: envelope.getAllLocation(),
	}, nil
}

// inferredUserID returns the user id from the "user" object, falling back to
// "data.id". It returns 0 when neither holds a positive id.
func (e *userInfoEnvelope) inferredUserID() uint {
	if e.User != nil && e.User.ID > 0 {
		return uint(e.User.ID)
	}
	if e.Data != nil && e.Data.ID > 0 {
		return uint(e.Data.ID)
	}
	return 0
}

// convertProfileRoles converts raw SSO roles into Roles. It also returns the
// flattened list of every named permission across all roles. Both slices are
// non-nil, even when raw is empty.
func convertProfileRoles(raw []userInfoRole) ([]Role, []Permission) {
	roles := make([]Role, 0, len(raw))
	all := make([]Permission, 0)
	for _, r := range raw {
		role := Role{
			ID:          ssoIDToUint(r.ID),
			Key:         strings.TrimSpace(r.Key),
			Name:        strings.TrimSpace(r.Name),
			ClientID:    ssoIDToUint(r.Client.ID),
			ClientAlias: strings.TrimSpace(r.Client.Alias),
			ClientName:  strings.TrimSpace(r.Client.Name),
			Permissions: make([]Permission, 0, len(r.Permissions)),
		}
		for _, p := range r.Permissions {
			perm := convertProfilePermission(p)
			if perm.Name == "" {
				continue
			}
			role.Permissions = append(role.Permissions, perm)
			all = append(all, perm)
		}
		roles = append(roles, role)
	}
	return roles, all
}

// convertProfilePermission maps one raw permission entry, trimming its text fields.
func convertProfilePermission(p userInfoPermRaw) Permission {
	return Permission{
		ID:          ssoIDToUint(p.ID),
		Name:        strings.TrimSpace(p.Name),
		Action:      strings.TrimSpace(p.Action),
		ClientID:    ssoIDToUint(p.Client.ID),
		ClientAlias: strings.TrimSpace(p.Client.Alias),
		ClientName:  strings.TrimSpace(p.Client.Name),
	}
}

// ssoIDToUint converts an SSO numeric id to uint. Negative values, which a
// plain conversion would wrap into huge ids, map to 0.
func ssoIDToUint(id int64) uint {
	if id < 0 {
		return 0
	}
	return uint(id)
}
// loadProfileFromLocalCache returns the in-process cached profile stored under
// key. It returns nil when there is no entry, the entry is not a cachedProfile,
// the entry has expired, or it holds a nil profile. Expired and nil-profile
// entries are also deleted from the map.
func loadProfileFromLocalCache(key string) *UserProfile {
	raw, found := profileLocalCache.Load(key)
	if !found {
		return nil
	}
	entry, isEntry := raw.(cachedProfile)
	if !isEntry {
		return nil
	}
	if entry.Profile == nil || !time.Now().Before(entry.ExpiresAt) {
		profileLocalCache.Delete(key)
		return nil
	}
	return entry.Profile
}
// loadProfileFromRedis returns the profile cached in Redis under key. It
// returns nil when Redis is not configured, the key is absent, or the lookup
// or JSON decoding fails, so callers can treat nil as a cache miss. Failures
// other than a missing key (redis.Nil) are logged as warnings, never returned.
func loadProfileFromRedis(ctx context.Context, key string) *UserProfile {
	client := cache.Redis()
	if client == nil {
		return nil
	}
	if ctx == nil {
		// Same normalization as InvalidateProfileCache: never hand the Redis
		// client a nil context.
		ctx = context.Background()
	}
	data, err := client.Get(ctx, key).Bytes()
	if err != nil {
		if !errors.Is(err, redis.Nil) {
			utils.Log.WithError(err).Warn("sso profile redis lookup failed")
		}
		return nil
	}
	var profile UserProfile
	if err := json.Unmarshal(data, &profile); err != nil {
		utils.Log.WithError(err).Warn("sso profile redis decode failed")
		return nil
	}
	return &profile
}
// storeProfileInLocalCache caches profile in-process under key until
// profileCacheTTL from now, replacing any existing entry. A nil profile is
// ignored.
func storeProfileInLocalCache(key string, profile *UserProfile) {
	if profile != nil {
		entry := cachedProfile{
			Profile:   profile,
			ExpiresAt: time.Now().Add(profileCacheTTL),
		}
		profileLocalCache.Store(key, entry)
	}
}
// storeProfileInRedis writes profile as JSON to Redis under key, expiring
// after profileCacheTTL. It is best-effort: it does nothing when Redis is not
// configured or profile is nil, and it only logs encode and store failures.
func storeProfileInRedis(ctx context.Context, key string, profile *UserProfile) {
	client := cache.Redis()
	if client == nil || profile == nil {
		return
	}
	if ctx == nil {
		// Same normalization as InvalidateProfileCache: never hand the Redis
		// client a nil context.
		ctx = context.Background()
	}
	data, err := json.Marshal(profile)
	if err != nil {
		utils.Log.WithError(err).Warn("sso profile redis encode failed")
		return
	}
	if err := client.Set(ctx, key, data, profileCacheTTL).Err(); err != nil {
		utils.Log.WithError(err).Warn("sso profile redis store failed")
	}
}
// profileCacheKey returns the cache key for userID: profileCachePrefix
// followed by the id in base 10 (e.g. "sso:profile:user:42").
func profileCacheKey(userID uint) string {
	// 20 bytes is enough for the decimal form of any uint64.
	buf := make([]byte, 0, len(profileCachePrefix)+20)
	buf = append(buf, profileCachePrefix...)
	buf = strconv.AppendUint(buf, uint64(userID), 10)
	return string(buf)
}
// InvalidateProfileCache removes the cached profile of userID from the
// in-process cache and, when a Redis client is available, from Redis. A zero
// userID is ignored. A nil ctx is replaced with context.Background(). Redis
// delete failures are logged, not returned.
func InvalidateProfileCache(ctx context.Context, userID uint) {
	if userID == 0 {
		return
	}
	key := profileCacheKey(userID)
	profileLocalCache.Delete(key)

	client := cache.Redis()
	if client == nil {
		return
	}
	if ctx == nil {
		ctx = context.Background()
	}
	err := client.Del(ctx, key).Err()
	switch {
	case err == nil, errors.Is(err, redis.Nil):
		return
	default:
		utils.Log.WithError(err).Warn("sso profile redis delete failed")
	}
}
// canonicalPermissionName normalizes a permission name for comparison. It
// removes surrounding whitespace and then lower-cases the result.
func canonicalPermissionName(name string) string {
	trimmed := strings.TrimSpace(name)
	return strings.ToLower(trimmed)
}
// userInfoEnvelope decodes the SSO get-me response. The user information may
// appear at the top level or nested under "data". The get* accessors prefer
// the top-level values and fall back to "data".
type userInfoEnvelope struct {
	Roles       []userInfoRole `json:"roles"`
	AreaIDs     []uint         `json:"area_ids"`
	LocationIDs []uint         `json:"location_ids"`
	AllArea     bool           `json:"all_area"`
	AllLocation bool           `json:"all_location"`
	// Data is the nested form of the same payload, plus the user id.
	Data *struct {
		ID          int64          `json:"id"`
		Roles       []userInfoRole `json:"roles"`
		AreaIDs     []uint         `json:"area_ids"`
		LocationIDs []uint         `json:"location_ids"`
		AllArea     bool           `json:"all_area"`
		AllLocation bool           `json:"all_location"`
	} `json:"data"`
	// User is set when the response has a "user" object. getRoles may also
	// synthesize it from data.id.
	User *struct {
		ID int64 `json:"id"`
	} `json:"user"`
}

// getRoles returns the top-level roles when non-empty, otherwise data.roles,
// otherwise nil. Side effect: when data.roles is used, no "user" object was
// decoded, and data.id is positive, e.User is set from data.id.
func (e *userInfoEnvelope) getRoles() []userInfoRole {
	switch {
	case len(e.Roles) > 0:
		return e.Roles
	case e.Data == nil || len(e.Data.Roles) == 0:
		return nil
	}
	if e.User == nil && e.Data.ID > 0 {
		e.User = &struct {
			ID int64 `json:"id"`
		}{ID: e.Data.ID}
	}
	return e.Data.Roles
}

// getAreaIDs returns the top-level area_ids when non-empty, otherwise
// data.area_ids when non-empty, otherwise nil.
func (e *userInfoEnvelope) getAreaIDs() []uint {
	if ids := e.AreaIDs; len(ids) > 0 {
		return ids
	}
	if e.Data == nil || len(e.Data.AreaIDs) == 0 {
		return nil
	}
	return e.Data.AreaIDs
}

// getLocationIDs returns the top-level location_ids when non-empty, otherwise
// data.location_ids when non-empty, otherwise nil.
func (e *userInfoEnvelope) getLocationIDs() []uint {
	if ids := e.LocationIDs; len(ids) > 0 {
		return ids
	}
	if e.Data == nil || len(e.Data.LocationIDs) == 0 {
		return nil
	}
	return e.Data.LocationIDs
}

// getAllArea reports whether all_area is true at the top level or under data.
func (e *userInfoEnvelope) getAllArea() bool {
	return e.AllArea || (e.Data != nil && e.Data.AllArea)
}

// getAllLocation reports whether all_location is true at the top level or
// under data.
func (e *userInfoEnvelope) getAllLocation() bool {
	return e.AllLocation || (e.Data != nil && e.Data.AllLocation)
}

// userInfoRole is one role entry as sent by the SSO.
type userInfoRole struct {
	ID          int64             `json:"id"`
	Key         string            `json:"key"`
	Name        string            `json:"name"`
	Client      userInfoClient    `json:"client"`
	Permissions []userInfoPermRaw `json:"permissions"`
}

// userInfoClient identifies the SSO client a role or permission belongs to
// (presumably the consuming application — confirm against the SSO API).
type userInfoClient struct {
	ID    int64  `json:"id"`
	Name  string `json:"name"`
	Alias string `json:"alias"`
}

// userInfoPermRaw is one permission entry as sent by the SSO. Details is
// decoded generically; fetchProfileFromSSO does not read it.
type userInfoPermRaw struct {
	ID      int64          `json:"id"`
	Name    string         `json:"name"`
	Action  string         `json:"action"`
	Client  userInfoClient `json:"client"`
	Details any            `json:"details"`
}