mirror of
https://gitlab.com/mbugroup/lti-api.git
synced 2026-05-20 13:31:56 +00:00
Revert "[FIX/BE-US]add feature restrict by location and areas in roles"
This reverts commit dff9e73ab1.
This commit is contained in:
@@ -1,331 +0,0 @@
|
||||
package controllers
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/hmac"
|
||||
"crypto/sha256"
|
||||
"encoding/base64"
|
||||
"encoding/hex"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/redis/go-redis/v9"
|
||||
"gorm.io/gorm"
|
||||
|
||||
"gitlab.com/mbugroup/lti-api.git/internal/config"
|
||||
entity "gitlab.com/mbugroup/lti-api.git/internal/entities"
|
||||
sso "gitlab.com/mbugroup/lti-api.git/internal/modules/sso/verifier"
|
||||
"gitlab.com/mbugroup/lti-api.git/internal/response"
|
||||
"gitlab.com/mbugroup/lti-api.git/internal/utils"
|
||||
)
|
||||
|
||||
type MasterDataController struct {
|
||||
db *gorm.DB
|
||||
redis *redis.Client
|
||||
clients map[string]config.SSOClientConfig
|
||||
drift time.Duration
|
||||
nonceTTL time.Duration
|
||||
localNonce sync.Map
|
||||
}
|
||||
|
||||
type masterArea struct {
|
||||
ID uint `json:"id"`
|
||||
Name string `json:"name"`
|
||||
}
|
||||
|
||||
type masterLocation struct {
|
||||
ID uint `json:"id"`
|
||||
Name string `json:"name"`
|
||||
AreaID uint `json:"area_id"`
|
||||
}
|
||||
|
||||
func NewMasterDataController(db *gorm.DB, redis *redis.Client, clients map[string]config.SSOClientConfig) *MasterDataController {
|
||||
normalized := make(map[string]config.SSOClientConfig, len(clients))
|
||||
for alias, cfg := range clients {
|
||||
alias = strings.ToLower(strings.TrimSpace(alias))
|
||||
normalized[alias] = cfg
|
||||
}
|
||||
|
||||
drift := config.SSOUserSyncDrift
|
||||
if drift <= 0 {
|
||||
drift = 2 * time.Minute
|
||||
}
|
||||
|
||||
nonceTTL := config.SSOUserSyncNonceTTL
|
||||
if nonceTTL <= 0 {
|
||||
nonceTTL = 10 * time.Minute
|
||||
}
|
||||
|
||||
return &MasterDataController{
|
||||
db: db,
|
||||
redis: redis,
|
||||
clients: normalized,
|
||||
drift: drift,
|
||||
nonceTTL: nonceTTL,
|
||||
}
|
||||
}
|
||||
|
||||
func (h *MasterDataController) GetAreas(c *fiber.Ctx) error {
|
||||
if _, _, err := h.authenticate(c, nil); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
search := strings.TrimSpace(c.Query("search", ""))
|
||||
ids := parseUintList(c.Query("ids", ""))
|
||||
|
||||
query := h.db.WithContext(c.Context()).
|
||||
Model(&entity.Area{}).
|
||||
Where("deleted_at IS NULL")
|
||||
if search != "" {
|
||||
query = query.Where("name ILIKE ?", "%"+search+"%")
|
||||
}
|
||||
if len(ids) > 0 {
|
||||
query = query.Where("id IN ?", ids)
|
||||
}
|
||||
|
||||
var areas []masterArea
|
||||
if err := query.Order("name ASC").Find(&areas).Error; err != nil {
|
||||
utils.Log.WithError(err).Error("failed to fetch areas for master data")
|
||||
return fiber.NewError(fiber.StatusInternalServerError, "failed to fetch areas")
|
||||
}
|
||||
|
||||
return c.Status(fiber.StatusOK).JSON(response.Success{
|
||||
Code: fiber.StatusOK,
|
||||
Status: "success",
|
||||
Message: "Get areas successfully",
|
||||
Data: areas,
|
||||
})
|
||||
}
|
||||
|
||||
func (h *MasterDataController) GetLocations(c *fiber.Ctx) error {
|
||||
if _, _, err := h.authenticate(c, nil); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
search := strings.TrimSpace(c.Query("search", ""))
|
||||
areaIDs := parseUintList(c.Query("area_ids", ""))
|
||||
ids := parseUintList(c.Query("ids", ""))
|
||||
|
||||
query := h.db.WithContext(c.Context()).
|
||||
Model(&entity.Location{}).
|
||||
Where("deleted_at IS NULL")
|
||||
if search != "" {
|
||||
query = query.Where("name ILIKE ?", "%"+search+"%")
|
||||
}
|
||||
if len(areaIDs) > 0 {
|
||||
query = query.Where("area_id IN ?", areaIDs)
|
||||
}
|
||||
if len(ids) > 0 {
|
||||
query = query.Where("id IN ?", ids)
|
||||
}
|
||||
|
||||
var locations []masterLocation
|
||||
if err := query.Order("name ASC").Find(&locations).Error; err != nil {
|
||||
utils.Log.WithError(err).Error("failed to fetch locations for master data")
|
||||
return fiber.NewError(fiber.StatusInternalServerError, "failed to fetch locations")
|
||||
}
|
||||
|
||||
return c.Status(fiber.StatusOK).JSON(response.Success{
|
||||
Code: fiber.StatusOK,
|
||||
Status: "success",
|
||||
Message: "Get locations successfully",
|
||||
Data: locations,
|
||||
})
|
||||
}
|
||||
|
||||
func (h *MasterDataController) authenticate(c *fiber.Ctx, body []byte) (string, config.SSOClientConfig, error) {
|
||||
rawAlias := strings.TrimSpace(c.Get("X-Sync-Client"))
|
||||
if rawAlias == "" {
|
||||
return "", config.SSOClientConfig{}, fiber.NewError(fiber.StatusUnauthorized, "missing sync client header")
|
||||
}
|
||||
|
||||
aliasKey := strings.ToLower(rawAlias)
|
||||
clientCfg, ok := h.clients[aliasKey]
|
||||
if !ok {
|
||||
return "", config.SSOClientConfig{}, fiber.NewError(fiber.StatusUnauthorized, "unknown sync client")
|
||||
}
|
||||
|
||||
if err := h.verifyAuthorization(c, aliasKey); err != nil {
|
||||
return "", config.SSOClientConfig{}, err
|
||||
}
|
||||
|
||||
secret := strings.TrimSpace(clientCfg.SyncSecret)
|
||||
if secret == "" {
|
||||
return "", config.SSOClientConfig{}, fiber.NewError(fiber.StatusUnauthorized, "sync secret not configured")
|
||||
}
|
||||
|
||||
timestamp := strings.TrimSpace(c.Get("X-Sync-Timestamp"))
|
||||
nonce := strings.TrimSpace(c.Get("X-Sync-Nonce"))
|
||||
signature := strings.TrimSpace(c.Get("X-Sync-Signature"))
|
||||
if timestamp == "" || nonce == "" || signature == "" {
|
||||
return "", config.SSOClientConfig{}, fiber.NewError(fiber.StatusUnauthorized, "missing signature headers")
|
||||
}
|
||||
if len(nonce) < 16 {
|
||||
return "", config.SSOClientConfig{}, fiber.NewError(fiber.StatusUnauthorized, "nonce too short")
|
||||
}
|
||||
|
||||
ts, err := strconv.ParseInt(timestamp, 10, 64)
|
||||
if err != nil {
|
||||
return "", config.SSOClientConfig{}, fiber.NewError(fiber.StatusBadRequest, "invalid timestamp")
|
||||
}
|
||||
|
||||
msgTime := time.Unix(ts, 0).UTC()
|
||||
now := time.Now().UTC()
|
||||
drift := now.Sub(msgTime)
|
||||
if drift > h.drift || drift < -h.drift {
|
||||
return "", config.SSOClientConfig{}, fiber.NewError(fiber.StatusUnauthorized, "timestamp outside allowed window")
|
||||
}
|
||||
|
||||
providedSig, err := decodeMasterSignature(signature)
|
||||
if err != nil {
|
||||
return "", config.SSOClientConfig{}, fiber.NewError(fiber.StatusUnauthorized, "invalid signature encoding")
|
||||
}
|
||||
|
||||
expectedSignature := calculateSignature(secret, rawAlias, timestamp, nonce, body)
|
||||
if !hmac.Equal(providedSig, expectedSignature) {
|
||||
return "", config.SSOClientConfig{}, fiber.NewError(fiber.StatusUnauthorized, "invalid signature")
|
||||
}
|
||||
|
||||
if err := h.registerNonce(c.Context(), aliasKey, nonce); err != nil {
|
||||
return "", config.SSOClientConfig{}, err
|
||||
}
|
||||
|
||||
return aliasKey, clientCfg, nil
|
||||
}
|
||||
|
||||
func (h *MasterDataController) verifyAuthorization(c *fiber.Ctx, alias string) error {
|
||||
authHeader := strings.TrimSpace(c.Get(fiber.HeaderAuthorization))
|
||||
if authHeader == "" {
|
||||
return fiber.NewError(fiber.StatusUnauthorized, "missing authorization header")
|
||||
}
|
||||
parts := strings.SplitN(authHeader, " ", 2)
|
||||
if len(parts) != 2 || !strings.EqualFold(parts[0], "Bearer") {
|
||||
return fiber.NewError(fiber.StatusUnauthorized, "invalid authorization header")
|
||||
}
|
||||
|
||||
token := strings.TrimSpace(parts[1])
|
||||
if token == "" {
|
||||
return fiber.NewError(fiber.StatusUnauthorized, "invalid authorization header")
|
||||
}
|
||||
|
||||
verification, err := sso.VerifyAccessToken(token)
|
||||
if err != nil {
|
||||
return fiber.NewError(fiber.StatusUnauthorized, "invalid access token")
|
||||
}
|
||||
|
||||
if verification.ServiceAlias == "" || verification.ServiceAlias != alias {
|
||||
return fiber.NewError(fiber.StatusUnauthorized, "service subject mismatch")
|
||||
}
|
||||
if !hasAnyScope(verification.Claims.Scopes(), []string{"sync.master", "sync.users"}) {
|
||||
return fiber.NewError(fiber.StatusForbidden, "missing sync scope")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h *MasterDataController) registerNonce(ctx context.Context, alias, nonce string) error {
|
||||
ttl := h.nonceTTL
|
||||
if ttl <= 0 {
|
||||
ttl = 10 * time.Minute
|
||||
}
|
||||
|
||||
key := fmt.Sprintf("sso:sync:%s:%s", alias, nonce)
|
||||
if h.redis != nil {
|
||||
stored, err := h.redis.SetNX(ctx, key, "1", ttl).Result()
|
||||
if err == nil {
|
||||
if !stored {
|
||||
return fiber.NewError(fiber.StatusUnauthorized, "nonce already used")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
utils.Log.WithError(err).Warn("store sync nonce failed")
|
||||
}
|
||||
|
||||
now := time.Now().UTC()
|
||||
if expRaw, ok := h.localNonce.Load(key); ok {
|
||||
if expTime, ok := expRaw.(time.Time); ok && expTime.After(now) {
|
||||
return fiber.NewError(fiber.StatusUnauthorized, "nonce already used")
|
||||
}
|
||||
}
|
||||
h.localNonce.Store(key, now.Add(ttl))
|
||||
return nil
|
||||
}
|
||||
|
||||
func calculateSignature(secret, alias, timestamp, nonce string, body []byte) []byte {
|
||||
mac := hmac.New(sha256.New, []byte(secret))
|
||||
mac.Write([]byte(alias))
|
||||
mac.Write([]byte("\n"))
|
||||
mac.Write([]byte(timestamp))
|
||||
mac.Write([]byte("\n"))
|
||||
mac.Write([]byte(nonce))
|
||||
mac.Write([]byte("\n"))
|
||||
if len(body) > 0 {
|
||||
mac.Write(body)
|
||||
}
|
||||
return mac.Sum(nil)
|
||||
}
|
||||
|
||||
func decodeMasterSignature(sig string) ([]byte, error) {
|
||||
sig = strings.TrimSpace(sig)
|
||||
if sig == "" {
|
||||
return nil, errors.New("empty signature")
|
||||
}
|
||||
if decoded, err := hex.DecodeString(sig); err == nil {
|
||||
return decoded, nil
|
||||
}
|
||||
if decoded, err := base64.StdEncoding.DecodeString(sig); err == nil {
|
||||
return decoded, nil
|
||||
}
|
||||
if decoded, err := base64.URLEncoding.DecodeString(sig); err == nil {
|
||||
return decoded, nil
|
||||
}
|
||||
return nil, errors.New("unrecognized signature encoding")
|
||||
}
|
||||
|
||||
func parseUintList(raw string) []uint {
|
||||
raw = strings.TrimSpace(raw)
|
||||
if raw == "" {
|
||||
return nil
|
||||
}
|
||||
parts := strings.Split(raw, ",")
|
||||
out := make([]uint, 0, len(parts))
|
||||
seen := make(map[uint]struct{}, len(parts))
|
||||
for _, part := range parts {
|
||||
part = strings.TrimSpace(part)
|
||||
if part == "" {
|
||||
continue
|
||||
}
|
||||
val, err := strconv.ParseUint(part, 10, 64)
|
||||
if err != nil || val == 0 {
|
||||
continue
|
||||
}
|
||||
if _, ok := seen[uint(val)]; ok {
|
||||
continue
|
||||
}
|
||||
seen[uint(val)] = struct{}{}
|
||||
out = append(out, uint(val))
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func hasAnyScope(scopes []string, targets []string) bool {
|
||||
if len(scopes) == 0 || len(targets) == 0 {
|
||||
return false
|
||||
}
|
||||
for _, scope := range scopes {
|
||||
scope = strings.ToLower(strings.TrimSpace(scope))
|
||||
if scope == "" {
|
||||
continue
|
||||
}
|
||||
for _, target := range targets {
|
||||
if scope == strings.ToLower(strings.TrimSpace(target)) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
@@ -16,7 +16,7 @@ import (
|
||||
|
||||
"gitlab.com/mbugroup/lti-api.git/internal/config"
|
||||
"gitlab.com/mbugroup/lti-api.git/internal/modules/sso/session"
|
||||
sso "gitlab.com/mbugroup/lti-api.git/internal/modules/sso/verifier"
|
||||
"gitlab.com/mbugroup/lti-api.git/internal/sso"
|
||||
"gitlab.com/mbugroup/lti-api.git/internal/utils"
|
||||
"gitlab.com/mbugroup/lti-api.git/internal/utils/secure"
|
||||
)
|
||||
|
||||
@@ -9,24 +9,23 @@ import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/go-playground/validator/v10"
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/redis/go-redis/v9"
|
||||
"github.com/sirupsen/logrus"
|
||||
"gorm.io/gorm"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"gitlab.com/mbugroup/lti-api.git/internal/config"
|
||||
entity "gitlab.com/mbugroup/lti-api.git/internal/entities"
|
||||
"gitlab.com/mbugroup/lti-api.git/internal/modules/sso/session"
|
||||
sso "gitlab.com/mbugroup/lti-api.git/internal/modules/sso/verifier"
|
||||
"gitlab.com/mbugroup/lti-api.git/internal/modules/users/dto"
|
||||
userRepository "gitlab.com/mbugroup/lti-api.git/internal/modules/users/repositories"
|
||||
"gitlab.com/mbugroup/lti-api.git/internal/response"
|
||||
"gitlab.com/mbugroup/lti-api.git/internal/sso"
|
||||
"gitlab.com/mbugroup/lti-api.git/internal/utils"
|
||||
)
|
||||
|
||||
|
||||
@@ -26,7 +26,6 @@ func Routes(router fiber.Router, db *gorm.DB, validate *validator.Validate) {
|
||||
ctrl := ssoController.NewController(&http.Client{Timeout: 10 * time.Second}, store, session.GetRevocationStore())
|
||||
userRepo := userRepository.NewUserRepository(db)
|
||||
syncCtrl := ssoController.NewUserSyncController(validate, userRepo, cache.Redis(), config.SSOClients)
|
||||
masterCtrl := ssoController.NewMasterDataController(db, cache.Redis(), config.SSOClients)
|
||||
|
||||
group := router.Group("/sso")
|
||||
group.Get("/start", middleware.NewLimiter(30, time.Minute), ctrl.Start)
|
||||
@@ -35,6 +34,4 @@ func Routes(router fiber.Router, db *gorm.DB, validate *validator.Validate) {
|
||||
group.Post("/refresh", middleware.NewLimiter(60, time.Minute), ctrl.Refresh)
|
||||
group.Post("/logout", middleware.NewLimiter(60, time.Minute), ctrl.Logout)
|
||||
group.Post("/users/sync", middleware.NewLimiter(30, time.Minute), syncCtrl.Sync)
|
||||
group.Get("/master/areas", middleware.NewLimiter(60, time.Minute), masterCtrl.GetAreas)
|
||||
group.Get("/master/locations", middleware.NewLimiter(60, time.Minute), masterCtrl.GetLocations)
|
||||
}
|
||||
|
||||
@@ -1,319 +0,0 @@
|
||||
package sso
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/redis/go-redis/v9"
|
||||
|
||||
"gitlab.com/mbugroup/lti-api.git/internal/cache"
|
||||
"gitlab.com/mbugroup/lti-api.git/internal/config"
|
||||
"gitlab.com/mbugroup/lti-api.git/internal/utils"
|
||||
)
|
||||
|
||||
const (
|
||||
profileCachePrefix = "sso:profile:user:"
|
||||
profileCacheTTL = time.Minute
|
||||
)
|
||||
|
||||
var (
|
||||
profileClient = &http.Client{Timeout: 5 * time.Second}
|
||||
|
||||
profileLocalCache sync.Map // map[string]cachedProfile
|
||||
)
|
||||
|
||||
type cachedProfile struct {
|
||||
Profile *UserProfile
|
||||
ExpiresAt time.Time
|
||||
}
|
||||
|
||||
// UserProfile represents the enriched user information returned by the central SSO.
|
||||
type UserProfile struct {
|
||||
UserID uint
|
||||
Roles []Role
|
||||
Permissions []Permission
|
||||
}
|
||||
|
||||
// Role describes a role assignment from the SSO profile response.
|
||||
type Role struct {
|
||||
ID uint
|
||||
Key string
|
||||
Name string
|
||||
ClientID uint
|
||||
ClientAlias string
|
||||
ClientName string
|
||||
AllArea bool
|
||||
AllLocation bool
|
||||
AreaIDs []uint
|
||||
LocationIDs []uint
|
||||
Permissions []Permission
|
||||
RawReference json.RawMessage `json:"-"`
|
||||
}
|
||||
|
||||
// Permission describes a granular permission entry from the SSO profile.
|
||||
type Permission struct {
|
||||
ID uint
|
||||
Name string
|
||||
Action string
|
||||
ClientID uint
|
||||
ClientAlias string
|
||||
ClientName string
|
||||
}
|
||||
|
||||
// PermissionNames returns a de-duplicated slice of permission identifiers in canonical form.
|
||||
func (p *UserProfile) PermissionNames() []string {
|
||||
if p == nil || len(p.Permissions) == 0 {
|
||||
return nil
|
||||
}
|
||||
set := make(map[string]struct{}, len(p.Permissions))
|
||||
for _, perm := range p.Permissions {
|
||||
name := canonicalPermissionName(perm.Name)
|
||||
if name != "" {
|
||||
set[name] = struct{}{}
|
||||
}
|
||||
}
|
||||
out := make([]string, 0, len(set))
|
||||
for name := range set {
|
||||
out = append(out, name)
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
// FetchProfile retrieves the SSO profile for the authenticated user, using Redis/in-memory
|
||||
// caching to reduce load on the SSO service. Only end-user tokens (subject user:ID) are supported.
|
||||
func FetchProfile(ctx context.Context, token string, verification *VerificationResult) (*UserProfile, error) {
|
||||
if verification == nil || verification.UserID == 0 {
|
||||
return nil, errors.New("profile only available for user tokens")
|
||||
}
|
||||
key := profileCacheKey(verification.UserID)
|
||||
|
||||
if profile := loadProfileFromLocalCache(key); profile != nil {
|
||||
return profile, nil
|
||||
}
|
||||
|
||||
if profile := loadProfileFromRedis(ctx, key); profile != nil {
|
||||
storeProfileInLocalCache(key, profile)
|
||||
return profile, nil
|
||||
}
|
||||
|
||||
profile, err := fetchProfileFromSSO(ctx, token)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
storeProfileInLocalCache(key, profile)
|
||||
storeProfileInRedis(ctx, key, profile)
|
||||
return profile, nil
|
||||
}
|
||||
|
||||
func fetchProfileFromSSO(ctx context.Context, token string) (*UserProfile, error) {
|
||||
endpoint := strings.TrimSpace(config.SSOGetMeURL)
|
||||
if endpoint == "" {
|
||||
return nil, errors.New("sso get-me endpoint not configured")
|
||||
}
|
||||
|
||||
if ctx == nil {
|
||||
ctx = context.Background()
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, endpoint, nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("build profile request: %w", err)
|
||||
}
|
||||
req.Header.Set("Accept", "application/json")
|
||||
req.Header.Set("Authorization", "Bearer "+token)
|
||||
if cookieName := strings.TrimSpace(config.SSOAccessCookieName); cookieName != "" {
|
||||
req.Header.Set("Cookie", fmt.Sprintf("%s=%s", cookieName, token))
|
||||
}
|
||||
|
||||
resp, err := profileClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("fetch profile: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode >= 400 {
|
||||
return nil, fmt.Errorf("fetch profile: status %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
var envelope userInfoEnvelope
|
||||
if err := json.NewDecoder(resp.Body).Decode(&envelope); err != nil {
|
||||
return nil, fmt.Errorf("decode profile: %w", err)
|
||||
}
|
||||
|
||||
roles := envelope.getRoles()
|
||||
profile := &UserProfile{}
|
||||
|
||||
// Attempt to infer user id if provided.
|
||||
if envelope.User != nil && envelope.User.ID > 0 {
|
||||
profile.UserID = uint(envelope.User.ID)
|
||||
}
|
||||
|
||||
perms := make([]Permission, 0)
|
||||
convertedRoles := make([]Role, 0, len(roles))
|
||||
for _, r := range roles {
|
||||
role := Role{
|
||||
ID: uint(r.ID),
|
||||
Key: strings.TrimSpace(r.Key),
|
||||
Name: strings.TrimSpace(r.Name),
|
||||
ClientAlias: strings.TrimSpace(r.Client.Alias),
|
||||
ClientName: strings.TrimSpace(r.Client.Name),
|
||||
ClientID: uint(r.Client.ID),
|
||||
AllArea: r.AllArea,
|
||||
AllLocation: r.AllLocation,
|
||||
AreaIDs: r.AreaIDs,
|
||||
LocationIDs: r.LocationIDs,
|
||||
}
|
||||
rolePerms := make([]Permission, 0, len(r.Permissions))
|
||||
for _, p := range r.Permissions {
|
||||
perm := Permission{
|
||||
ID: uint(p.ID),
|
||||
Name: strings.TrimSpace(p.Name),
|
||||
Action: strings.TrimSpace(p.Action),
|
||||
ClientAlias: strings.TrimSpace(p.Client.Alias),
|
||||
ClientName: strings.TrimSpace(p.Client.Name),
|
||||
ClientID: uint(p.Client.ID),
|
||||
}
|
||||
if perm.Name != "" {
|
||||
rolePerms = append(rolePerms, perm)
|
||||
perms = append(perms, perm)
|
||||
}
|
||||
}
|
||||
role.Permissions = rolePerms
|
||||
convertedRoles = append(convertedRoles, role)
|
||||
}
|
||||
profile.Roles = convertedRoles
|
||||
profile.Permissions = perms
|
||||
|
||||
return profile, nil
|
||||
}
|
||||
|
||||
func loadProfileFromLocalCache(key string) *UserProfile {
|
||||
if value, ok := profileLocalCache.Load(key); ok {
|
||||
if cached, ok := value.(cachedProfile); ok {
|
||||
if time.Now().Before(cached.ExpiresAt) && cached.Profile != nil {
|
||||
return cached.Profile
|
||||
}
|
||||
profileLocalCache.Delete(key)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func loadProfileFromRedis(ctx context.Context, key string) *UserProfile {
|
||||
client := cache.Redis()
|
||||
if client == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
data, err := client.Get(ctx, key).Bytes()
|
||||
if err != nil {
|
||||
if !errors.Is(err, redis.Nil) {
|
||||
utils.Log.WithError(err).Warn("sso profile redis lookup failed")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
var profile UserProfile
|
||||
if err := json.Unmarshal(data, &profile); err != nil {
|
||||
utils.Log.WithError(err).Warn("sso profile redis decode failed")
|
||||
return nil
|
||||
}
|
||||
|
||||
return &profile
|
||||
}
|
||||
|
||||
func storeProfileInLocalCache(key string, profile *UserProfile) {
|
||||
if profile == nil {
|
||||
return
|
||||
}
|
||||
profileLocalCache.Store(key, cachedProfile{
|
||||
Profile: profile,
|
||||
ExpiresAt: time.Now().Add(profileCacheTTL),
|
||||
})
|
||||
}
|
||||
|
||||
func storeProfileInRedis(ctx context.Context, key string, profile *UserProfile) {
|
||||
client := cache.Redis()
|
||||
if client == nil || profile == nil {
|
||||
return
|
||||
}
|
||||
|
||||
data, err := json.Marshal(profile)
|
||||
if err != nil {
|
||||
utils.Log.WithError(err).Warn("sso profile redis encode failed")
|
||||
return
|
||||
}
|
||||
|
||||
if err := client.Set(ctx, key, data, profileCacheTTL).Err(); err != nil {
|
||||
utils.Log.WithError(err).Warn("sso profile redis store failed")
|
||||
}
|
||||
}
|
||||
|
||||
func profileCacheKey(userID uint) string {
|
||||
return profileCachePrefix + strconv.FormatUint(uint64(userID), 10)
|
||||
}
|
||||
|
||||
func canonicalPermissionName(name string) string {
|
||||
return strings.ToLower(strings.TrimSpace(name))
|
||||
}
|
||||
|
||||
// userInfoEnvelope handles the varying shapes returned by the SSO userinfo endpoint.
|
||||
type userInfoEnvelope struct {
|
||||
Roles []userInfoRole `json:"roles"`
|
||||
Data *struct {
|
||||
ID int64 `json:"id"`
|
||||
Roles []userInfoRole `json:"roles"`
|
||||
} `json:"data"`
|
||||
User *struct {
|
||||
ID int64 `json:"id"`
|
||||
} `json:"user"`
|
||||
}
|
||||
|
||||
func (e *userInfoEnvelope) getRoles() []userInfoRole {
|
||||
if len(e.Roles) > 0 {
|
||||
return e.Roles
|
||||
}
|
||||
if e.Data != nil && len(e.Data.Roles) > 0 {
|
||||
if e.User == nil && e.Data.ID > 0 {
|
||||
e.User = &struct {
|
||||
ID int64 `json:"id"`
|
||||
}{ID: e.Data.ID}
|
||||
}
|
||||
return e.Data.Roles
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
type userInfoRole struct {
|
||||
ID int64 `json:"id"`
|
||||
Key string `json:"key"`
|
||||
Name string `json:"name"`
|
||||
AllArea bool `json:"all_area"`
|
||||
AllLocation bool `json:"all_location"`
|
||||
AreaIDs []uint `json:"area_ids"`
|
||||
LocationIDs []uint `json:"location_ids"`
|
||||
Client userInfoClient `json:"client"`
|
||||
Permissions []userInfoPermRaw `json:"permissions"`
|
||||
}
|
||||
|
||||
type userInfoClient struct {
|
||||
ID int64 `json:"id"`
|
||||
Name string `json:"name"`
|
||||
Alias string `json:"alias"`
|
||||
}
|
||||
|
||||
type userInfoPermRaw struct {
|
||||
ID int64 `json:"id"`
|
||||
Name string `json:"name"`
|
||||
Action string `json:"action"`
|
||||
Client userInfoClient `json:"client"`
|
||||
Details any `json:"details"`
|
||||
}
|
||||
@@ -1,160 +0,0 @@
|
||||
package sso
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/MicahParks/keyfunc/v2"
|
||||
"github.com/golang-jwt/jwt/v5"
|
||||
|
||||
"gitlab.com/mbugroup/lti-api.git/internal/utils"
|
||||
)
|
||||
|
||||
type verifier struct {
|
||||
jwks *keyfunc.JWKS
|
||||
issuer string
|
||||
audiences map[string]struct{}
|
||||
}
|
||||
|
||||
type AccessTokenClaims struct {
|
||||
Scope string `json:"scope"`
|
||||
jwt.RegisteredClaims
|
||||
}
|
||||
|
||||
func (c AccessTokenClaims) Scopes() []string {
|
||||
if c.Scope == "" {
|
||||
return nil
|
||||
}
|
||||
return strings.Fields(c.Scope)
|
||||
}
|
||||
|
||||
type VerificationResult struct {
|
||||
UserID uint
|
||||
ServiceAlias string
|
||||
Subject string
|
||||
Claims *AccessTokenClaims
|
||||
}
|
||||
|
||||
var (
|
||||
globalMu sync.RWMutex
|
||||
globalV *verifier
|
||||
)
|
||||
|
||||
func Init(ctx context.Context, jwksURL, issuer string, audiences []string) error {
|
||||
jwksURL = strings.TrimSpace(jwksURL)
|
||||
issuer = strings.TrimSpace(issuer)
|
||||
if jwksURL == "" || issuer == "" {
|
||||
return errors.New("missing SSO JWKS or issuer configuration")
|
||||
}
|
||||
|
||||
client := &http.Client{Timeout: 5 * time.Second}
|
||||
options := keyfunc.Options{
|
||||
Ctx: ctx,
|
||||
Client: client,
|
||||
RefreshTimeout: 10 * time.Second,
|
||||
RefreshInterval: time.Hour,
|
||||
RefreshUnknownKID: true,
|
||||
RefreshErrorHandler: func(err error) {
|
||||
utils.Log.Errorf("sso jwks refresh failed: %v", err)
|
||||
},
|
||||
}
|
||||
|
||||
jwks, err := keyfunc.Get(jwksURL, options)
|
||||
if err != nil {
|
||||
return fmt.Errorf("load jwks: %w", err)
|
||||
}
|
||||
|
||||
audienceMap := make(map[string]struct{}, len(audiences))
|
||||
for _, aud := range audiences {
|
||||
aud = strings.TrimSpace(aud)
|
||||
if aud == "" {
|
||||
continue
|
||||
}
|
||||
audienceMap[aud] = struct{}{}
|
||||
}
|
||||
|
||||
globalMu.Lock()
|
||||
globalV = &verifier{jwks: jwks, issuer: issuer, audiences: audienceMap}
|
||||
globalMu.Unlock()
|
||||
|
||||
utils.Log.Infof("sso verifier initialized for issuer %s (%d keys)", issuer, len(jwks.KIDs()))
|
||||
return nil
|
||||
}
|
||||
|
||||
func VerifyAccessToken(token string) (*VerificationResult, error) {
|
||||
token = strings.TrimSpace(token)
|
||||
if token == "" {
|
||||
return nil, errors.New("empty token")
|
||||
}
|
||||
|
||||
globalMu.RLock()
|
||||
v := globalV
|
||||
globalMu.RUnlock()
|
||||
if v == nil {
|
||||
return nil, errors.New("sso verifier not initialized")
|
||||
}
|
||||
|
||||
claims := &AccessTokenClaims{}
|
||||
parser := jwt.NewParser(
|
||||
jwt.WithValidMethods([]string{jwt.SigningMethodRS256.Alg()}),
|
||||
jwt.WithIssuedAt(),
|
||||
jwt.WithExpirationRequired(),
|
||||
)
|
||||
|
||||
tok, err := parser.ParseWithClaims(token, claims, v.jwks.Keyfunc)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("parse token: %w", err)
|
||||
}
|
||||
if !tok.Valid {
|
||||
return nil, errors.New("invalid token")
|
||||
}
|
||||
|
||||
if claims.Issuer != v.issuer {
|
||||
return nil, errors.New("unexpected token issuer")
|
||||
}
|
||||
|
||||
if len(v.audiences) > 0 {
|
||||
validAud := false
|
||||
for _, aud := range claims.Audience {
|
||||
if _, ok := v.audiences[aud]; ok {
|
||||
validAud = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !validAud {
|
||||
return nil, errors.New("unexpected token audience")
|
||||
}
|
||||
}
|
||||
|
||||
sub := strings.TrimSpace(claims.Subject)
|
||||
if sub == "" {
|
||||
return nil, errors.New("missing subject")
|
||||
}
|
||||
|
||||
result := &VerificationResult{Claims: claims, Subject: sub}
|
||||
switch {
|
||||
case strings.HasPrefix(sub, "user:"):
|
||||
idStr := strings.TrimPrefix(sub, "user:")
|
||||
id, err := strconv.ParseUint(idStr, 10, 64)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid subject: %w", err)
|
||||
}
|
||||
result.UserID = uint(id)
|
||||
case strings.HasPrefix(sub, "service:"):
|
||||
alias := strings.TrimSpace(strings.TrimPrefix(sub, "service:"))
|
||||
if alias == "" {
|
||||
return nil, errors.New("invalid service subject")
|
||||
}
|
||||
result.ServiceAlias = strings.ToLower(alias)
|
||||
default:
|
||||
return nil, errors.New("unsupported subject type")
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
Reference in New Issue
Block a user