package middleware

import (
	"errors"
	"sort"
	"strings"

	"github.com/gofiber/fiber/v2"
	"gorm.io/gorm"

	"gitlab.com/mbugroup/lti-api.git/internal/config"
	entity "gitlab.com/mbugroup/lti-api.git/internal/entities"
)

// ScopeFilter describes which IDs a request is allowed to see.
//
// When Restrict is false the caller should not filter at all. When Restrict
// is true, only rows whose ID is in IDs are visible; Restrict with an empty
// IDs means nothing is visible (ApplyScopeFilter turns that into "1 = 0").
type ScopeFilter struct {
	IDs      []uint
	Restrict bool
}

// roleScope is the merged area/location scope of every role that applies to
// the current request.
type roleScope struct {
	allArea      bool   // some applicable role has AllArea set
	allLocation  bool   // some applicable role has AllLocation set
	areaIDs      []uint // union of non-zero AreaIDs over applicable roles
	locationIDs  []uint // union of non-zero LocationIDs over applicable roles
	hasAnyScopes bool   // an applicable role exists and carries a flag or an ID
}

// ResolveAreaScope returns the area IDs the authenticated user may access.
//
// It returns an unrestricted (zero) ScopeFilter when no applicable role
// carries scope data, or when any applicable role grants all areas or all
// locations. Otherwise the allowed set is the union of the explicitly granted
// area IDs and the area_id of every explicitly granted, non-deleted location.
// The IDs are returned in ascending order. A restricted filter with no IDs
// means the user may see no areas.
//
// An error is returned when the location lookup fails, including when db is
// nil and a location lookup is needed.
//
// NOTE(review): a user with no applicable scoped role gets an unrestricted
// filter (fail-open). Confirm such users are rejected elsewhere.
func ResolveAreaScope(c *fiber.Ctx, db *gorm.DB) (ScopeFilter, error) {
	scope, err := collectRoleScope(c)
	if err != nil || !scope.hasAnyScopes {
		return ScopeFilter{}, err
	}
	if scope.allArea || scope.allLocation {
		return ScopeFilter{}, nil
	}

	allowed := uniqueUint(scope.areaIDs)
	if len(scope.locationIDs) > 0 {
		derived, err := areaIDsByLocationIDs(db, scope.locationIDs)
		if err != nil {
			return ScopeFilter{}, err
		}
		allowed = uniqueUint(append(allowed, derived...))
	}

	if len(allowed) == 0 {
		return ScopeFilter{Restrict: true}, nil
	}
	// Derived IDs come from an unordered query; sort so the result is stable.
	sort.Slice(allowed, func(i, j int) bool { return allowed[i] < allowed[j] })
	return ScopeFilter{IDs: allowed, Restrict: true}, nil
}

// ResolveLocationScope returns the location IDs the authenticated user may
// access.
//
// It returns an unrestricted (zero) ScopeFilter when no applicable role
// carries scope data, or when any applicable role grants all locations or all
// areas. Otherwise:
//   - locations and areas granted: the granted locations that are non-deleted
//     and belong to one of the granted areas (intersection);
//   - only areas granted: every non-deleted location in those areas;
//   - only locations granted: the granted locations as-is.
//
// The IDs are returned in ascending order. A restricted filter with no IDs
// means the user may see no locations. An error is returned when a lookup
// fails, including when db is nil and a lookup is needed.
//
// NOTE(review): ResolveAreaScope combines areas and locations as a union,
// whereas this function intersects them when both are present, and both
// operate on scopes already merged across all roles. Confirm the asymmetry
// is intended.
func ResolveLocationScope(c *fiber.Ctx, db *gorm.DB) (ScopeFilter, error) {
	scope, err := collectRoleScope(c)
	if err != nil || !scope.hasAnyScopes {
		return ScopeFilter{}, err
	}
	if scope.allLocation || scope.allArea {
		return ScopeFilter{}, nil
	}

	areaIDs := uniqueUint(scope.areaIDs)
	locationIDs := uniqueUint(scope.locationIDs)

	switch {
	case len(locationIDs) > 0 && len(areaIDs) > 0:
		filtered, err := filterLocationIDsByAreaIDs(db, locationIDs, areaIDs)
		if err != nil {
			return ScopeFilter{}, err
		}
		locationIDs = filtered
	case len(locationIDs) == 0 && len(areaIDs) > 0:
		derived, err := locationIDsByAreaIDs(db, areaIDs)
		if err != nil {
			return ScopeFilter{}, err
		}
		locationIDs = derived
	}

	locationIDs = uniqueUint(locationIDs)
	if len(locationIDs) == 0 {
		return ScopeFilter{Restrict: true}, nil
	}
	// Query results are unordered; sort so the result is stable.
	sort.Slice(locationIDs, func(i, j int) bool { return locationIDs[i] < locationIDs[j] })
	return ScopeFilter{IDs: locationIDs, Restrict: true}, nil
}

// collectRoleScope merges the area/location scope of the request's roles.
//
// It returns a zero roleScope when AuthDetails reports no auth details or the
// user has no roles. If resolveClientAlias yields an alias, only roles whose
// ClientAlias (trimmed, case-insensitive) equals it are considered; otherwise
// every role is considered. Zero IDs are ignored.
//
// hasAnyScopes is true only when at least one role was considered AND the
// merged result has a flag set or at least one ID. The error result is
// currently always nil.
//
// NOTE(review): when roles exist but none match the client alias,
// hasAnyScopes is false and callers treat the user as unrestricted. Confirm
// this fail-open behavior is intended.
func collectRoleScope(c *fiber.Ctx) (roleScope, error) {
	auth, ok := AuthDetails(c)
	if !ok || auth == nil || len(auth.Roles) == 0 {
		return roleScope{}, nil
	}

	clientAlias := resolveClientAlias(auth)

	scope := roleScope{}
	areaSet := make(map[uint]struct{})
	locationSet := make(map[uint]struct{})
	for _, role := range auth.Roles {
		// When the token identifies a single client, skip roles of other clients.
		if clientAlias != "" && !strings.EqualFold(strings.TrimSpace(role.ClientAlias), clientAlias) {
			continue
		}
		scope.hasAnyScopes = true
		if role.AllArea {
			scope.allArea = true
		}
		if role.AllLocation {
			scope.allLocation = true
		}
		for _, id := range role.AreaIDs {
			if id == 0 {
				continue
			}
			areaSet[id] = struct{}{}
		}
		for _, id := range role.LocationIDs {
			if id == 0 {
				continue
			}
			locationSet[id] = struct{}{}
		}
	}

	scope.areaIDs = keysUint(areaSet)
	scope.locationIDs = keysUint(locationSet)
	// A considered role with no flags and no IDs does not count as a scope.
	scope.hasAnyScopes = scope.hasAnyScopes &&
		(scope.allArea || scope.allLocation || len(scope.areaIDs) > 0 || len(scope.locationIDs) > 0)
	return scope, nil
}

// areaIDsByLocationIDs returns the distinct area_id values of the
// non-deleted locations whose id is in locationIDs.
//
// It returns an error when db is nil or the query fails, and (nil, nil) for
// empty input. Result order is whatever the database returns (no ORDER BY).
func areaIDsByLocationIDs(db *gorm.DB, locationIDs []uint) ([]uint, error) {
	if db == nil {
		return nil, errors.New("database not configured")
	}
	if len(locationIDs) == 0 {
		return nil, nil
	}

	var areaIDs []uint
	if err := db.Model(&entity.Location{}).
		Where("deleted_at IS NULL").
		Where("id IN ?", locationIDs).
		Distinct("area_id").
		Pluck("area_id", &areaIDs).Error; err != nil {
		return nil, err
	}
	return areaIDs, nil
}

// locationIDsByAreaIDs returns the ids of all non-deleted locations whose
// area_id is in areaIDs.
//
// It returns an error when db is nil or the query fails, and (nil, nil) for
// empty input. Result order is whatever the database returns (no ORDER BY).
func locationIDsByAreaIDs(db *gorm.DB, areaIDs []uint) ([]uint, error) {
	if db == nil {
		return nil, errors.New("database not configured")
	}
	if len(areaIDs) == 0 {
		return nil, nil
	}

	var locationIDs []uint
	if err := db.Model(&entity.Location{}).
		Where("deleted_at IS NULL").
		Where("area_id IN ?", areaIDs).
		Distinct("id").
		Pluck("id", &locationIDs).Error; err != nil {
		return nil, err
	}
	return locationIDs, nil
}

// filterLocationIDsByAreaIDs returns the subset of locationIDs that are
// non-deleted and whose area_id is in areaIDs.
//
// It returns an error when db is nil or the query fails, and (nil, nil) when
// either input is empty. Result order is whatever the database returns.
func filterLocationIDsByAreaIDs(db *gorm.DB, locationIDs, areaIDs []uint) ([]uint, error) {
	if db == nil {
		return nil, errors.New("database not configured")
	}
	if len(locationIDs) == 0 || len(areaIDs) == 0 {
		return nil, nil
	}

	var filtered []uint
	if err := db.Model(&entity.Location{}).
		Where("deleted_at IS NULL").
		Where("id IN ?", locationIDs).
		Where("area_id IN ?", areaIDs).
		Distinct("id").
		Pluck("id", &filtered).Error; err != nil {
		return nil, err
	}
	return filtered, nil
}

// uniqueUint returns ids with zeros and duplicates removed, keeping the order
// of first occurrence. It returns nil for empty input and always returns a
// fresh slice, so appending to the result never aliases ids.
func uniqueUint(ids []uint) []uint {
	if len(ids) == 0 {
		return nil
	}

	seen := make(map[uint]struct{}, len(ids))
	result := make([]uint, 0, len(ids))
	for _, id := range ids {
		if id == 0 {
			continue
		}
		if _, ok := seen[id]; ok {
			continue
		}
		seen[id] = struct{}{}
		result = append(result, id)
	}
	return result
}

// keysUint returns the keys of set in ascending order, or nil when set is
// empty. Sorting makes the output independent of Go's randomized map
// iteration order.
func keysUint(set map[uint]struct{}) []uint {
	if len(set) == 0 {
		return nil
	}

	out := make([]uint, 0, len(set))
	for id := range set {
		out = append(out, id)
	}
	sort.Slice(out, func(i, j int) bool { return out[i] < out[j] })
	return out
}

// resolveClientAlias derives the SSO client alias from the token's scopes.
//
// The prefix of each scope is taken after trimming and lowercasing: the text
// before the first '.' or ':' (when that separator is not the first
// character), otherwise the whole scope. Each prefix is matched against
// config.SSOClients via matchAlias. The alias is returned only when exactly
// one distinct alias matched; otherwise (none, several, or missing claims) it
// returns "".
func resolveClientAlias(auth *AuthContext) string {
	if auth == nil || auth.Verification == nil || auth.Verification.Claims == nil {
		return ""
	}

	scopes := auth.Verification.Claims.Scopes()
	if len(scopes) == 0 {
		return ""
	}

	seen := make(map[string]struct{})
	for _, scope := range scopes {
		scope = strings.ToLower(strings.TrimSpace(scope))
		if scope == "" {
			continue
		}

		prefix := scope
		if idx := strings.IndexAny(prefix, ".:"); idx > 0 {
			prefix = prefix[:idx]
		}
		prefix = strings.TrimSpace(prefix)
		if prefix == "" {
			continue
		}

		if alias := matchAlias(prefix); alias != "" {
			seen[alias] = struct{}{}
		}
	}

	// Ambiguous (several clients) or unknown: do not filter roles by client.
	if len(seen) != 1 {
		return ""
	}
	for alias := range seen {
		return alias
	}
	return ""
}

// matchAlias returns the lowercased, trimmed config.SSOClients key that
// case-insensitively equals alias (after trimming), or "" when none does.
func matchAlias(alias string) string {
	alias = strings.ToLower(strings.TrimSpace(alias))
	if alias == "" {
		return ""
	}

	if _, ok := config.SSOClients[alias]; ok {
		return alias
	}
	// Fall back to a case-insensitive scan for keys not stored in lowercase.
	for key := range config.SSOClients {
		if strings.EqualFold(key, alias) {
			return strings.ToLower(strings.TrimSpace(key))
		}
	}
	return ""
}

// ApplyScopeFilter restricts db to rows whose column is in scope.IDs.
//
// It returns db unchanged when db is nil or scope.Restrict is false. It adds
// an always-false "1 = 0" condition when the scope is restricted but empty.
//
// column is concatenated into the SQL verbatim, so it must be a trusted,
// developer-supplied identifier (e.g. "locations.id"), never user input.
func ApplyScopeFilter(db *gorm.DB, scope ScopeFilter, column string) *gorm.DB {
	if db == nil || !scope.Restrict {
		return db
	}
	if len(scope.IDs) == 0 {
		return db.Where("1 = 0")
	}
	return db.Where(column+" IN ?", scope.IDs)
}