[FIX/BE-US]add feature restrict by location and areas in roles

This commit is contained in:
ragilap
2026-01-14 13:27:52 +07:00
parent 1847e5590a
commit dff9e73ab1
30 changed files with 1258 additions and 37 deletions
+280
View File
@@ -0,0 +1,280 @@
package middleware
import (
"errors"
"strings"
"github.com/gofiber/fiber/v2"
"gorm.io/gorm"
"gitlab.com/mbugroup/lti-api.git/internal/config"
entity "gitlab.com/mbugroup/lti-api.git/internal/entities"
)
type ScopeFilter struct {
IDs []uint
Restrict bool
}
type roleScope struct {
allArea bool
allLocation bool
areaIDs []uint
locationIDs []uint
hasAnyScopes bool
}
func ResolveAreaScope(c *fiber.Ctx, db *gorm.DB) (ScopeFilter, error) {
scope, err := collectRoleScope(c)
if err != nil || !scope.hasAnyScopes {
return ScopeFilter{}, err
}
if scope.allArea || scope.allLocation {
return ScopeFilter{}, nil
}
allowed := uniqueUint(scope.areaIDs)
if len(scope.locationIDs) > 0 {
derived, err := areaIDsByLocationIDs(db, scope.locationIDs)
if err != nil {
return ScopeFilter{}, err
}
allowed = uniqueUint(append(allowed, derived...))
}
if len(allowed) == 0 {
return ScopeFilter{Restrict: true}, nil
}
return ScopeFilter{IDs: allowed, Restrict: true}, nil
}
func ResolveLocationScope(c *fiber.Ctx, db *gorm.DB) (ScopeFilter, error) {
scope, err := collectRoleScope(c)
if err != nil || !scope.hasAnyScopes {
return ScopeFilter{}, err
}
if scope.allLocation || scope.allArea {
return ScopeFilter{}, nil
}
areaIDs := uniqueUint(scope.areaIDs)
locationIDs := uniqueUint(scope.locationIDs)
switch {
case len(locationIDs) > 0 && len(areaIDs) > 0:
filtered, err := filterLocationIDsByAreaIDs(db, locationIDs, areaIDs)
if err != nil {
return ScopeFilter{}, err
}
locationIDs = filtered
case len(locationIDs) == 0 && len(areaIDs) > 0:
derived, err := locationIDsByAreaIDs(db, areaIDs)
if err != nil {
return ScopeFilter{}, err
}
locationIDs = derived
}
locationIDs = uniqueUint(locationIDs)
if len(locationIDs) == 0 {
return ScopeFilter{Restrict: true}, nil
}
return ScopeFilter{IDs: locationIDs, Restrict: true}, nil
}
func collectRoleScope(c *fiber.Ctx) (roleScope, error) {
ctx, ok := AuthDetails(c)
if !ok || ctx == nil || len(ctx.Roles) == 0 {
return roleScope{}, nil
}
clientAlias := resolveClientAlias(ctx)
scope := roleScope{}
areaSet := make(map[uint]struct{})
locationSet := make(map[uint]struct{})
for _, role := range ctx.Roles {
if clientAlias != "" && !strings.EqualFold(strings.TrimSpace(role.ClientAlias), clientAlias) {
continue
}
scope.hasAnyScopes = true
if role.AllArea {
scope.allArea = true
}
if role.AllLocation {
scope.allLocation = true
}
for _, id := range role.AreaIDs {
if id == 0 {
continue
}
areaSet[id] = struct{}{}
}
for _, id := range role.LocationIDs {
if id == 0 {
continue
}
locationSet[id] = struct{}{}
}
}
scope.areaIDs = keysUint(areaSet)
scope.locationIDs = keysUint(locationSet)
scope.hasAnyScopes = scope.hasAnyScopes &&
(scope.allArea || scope.allLocation || len(scope.areaIDs) > 0 || len(scope.locationIDs) > 0)
return scope, nil
}
func areaIDsByLocationIDs(db *gorm.DB, locationIDs []uint) ([]uint, error) {
if db == nil {
return nil, errors.New("database not configured")
}
if len(locationIDs) == 0 {
return nil, nil
}
var areaIDs []uint
if err := db.Model(&entity.Location{}).
Where("deleted_at IS NULL").
Where("id IN ?", locationIDs).
Distinct("area_id").
Pluck("area_id", &areaIDs).Error; err != nil {
return nil, err
}
return areaIDs, nil
}
func locationIDsByAreaIDs(db *gorm.DB, areaIDs []uint) ([]uint, error) {
if db == nil {
return nil, errors.New("database not configured")
}
if len(areaIDs) == 0 {
return nil, nil
}
var locationIDs []uint
if err := db.Model(&entity.Location{}).
Where("deleted_at IS NULL").
Where("area_id IN ?", areaIDs).
Distinct("id").
Pluck("id", &locationIDs).Error; err != nil {
return nil, err
}
return locationIDs, nil
}
func filterLocationIDsByAreaIDs(db *gorm.DB, locationIDs, areaIDs []uint) ([]uint, error) {
if db == nil {
return nil, errors.New("database not configured")
}
if len(locationIDs) == 0 || len(areaIDs) == 0 {
return nil, nil
}
var filtered []uint
if err := db.Model(&entity.Location{}).
Where("deleted_at IS NULL").
Where("id IN ?", locationIDs).
Where("area_id IN ?", areaIDs).
Distinct("id").
Pluck("id", &filtered).Error; err != nil {
return nil, err
}
return filtered, nil
}
func uniqueUint(ids []uint) []uint {
if len(ids) == 0 {
return nil
}
seen := make(map[uint]struct{}, len(ids))
result := make([]uint, 0, len(ids))
for _, id := range ids {
if id == 0 {
continue
}
if _, ok := seen[id]; ok {
continue
}
seen[id] = struct{}{}
result = append(result, id)
}
return result
}
func keysUint(set map[uint]struct{}) []uint {
if len(set) == 0 {
return nil
}
out := make([]uint, 0, len(set))
for id := range set {
out = append(out, id)
}
return out
}
func resolveClientAlias(ctx *AuthContext) string {
if ctx == nil || ctx.Verification == nil || ctx.Verification.Claims == nil {
return ""
}
scopes := ctx.Verification.Claims.Scopes()
if len(scopes) == 0 {
return ""
}
seen := make(map[string]struct{})
for _, scope := range scopes {
scope = strings.ToLower(strings.TrimSpace(scope))
if scope == "" {
continue
}
prefix := scope
if idx := strings.IndexAny(prefix, ".:"); idx > 0 {
prefix = prefix[:idx]
}
prefix = strings.TrimSpace(prefix)
if prefix == "" {
continue
}
if alias := matchAlias(prefix); alias != "" {
seen[alias] = struct{}{}
}
}
if len(seen) != 1 {
return ""
}
for alias := range seen {
return alias
}
return ""
}
func matchAlias(alias string) string {
alias = strings.ToLower(strings.TrimSpace(alias))
if alias == "" {
return ""
}
if _, ok := config.SSOClients[alias]; ok {
return alias
}
for key := range config.SSOClients {
if strings.EqualFold(key, alias) {
return strings.ToLower(strings.TrimSpace(key))
}
}
return ""
}
func ApplyScopeFilter(db *gorm.DB, scope ScopeFilter, column string) *gorm.DB {
if db == nil || !scope.Restrict {
return db
}
if len(scope.IDs) == 0 {
return db.Where("1 = 0")
}
return db.Where(column+" IN ?", scope.IDs)
}