Files
lti-api/internal/middleware/role_scope.go
T
2026-01-27 10:34:25 +07:00

637 lines
16 KiB
Go

package middleware
import (
"errors"
"github.com/gofiber/fiber/v2"
"gorm.io/gorm"
entity "gitlab.com/mbugroup/lti-api.git/internal/entities"
)
// ScopeFilter describes an ID-based restriction to apply to a query.
//
// When Restrict is false the filter is a no-op and IDs is ignored. When
// Restrict is true only rows whose scoped column is contained in IDs are
// permitted; Restrict with an empty IDs list means nothing is permitted
// (ApplyScopeFilter turns that into an always-false condition).
type ScopeFilter struct {
// IDs lists the permitted identifiers when Restrict is true.
IDs []uint
// Restrict reports whether the filter must be enforced at all.
Restrict bool
}
// roleScope is the raw area/location scope of the authenticated user as
// gathered by collectRoleScope, before any database lookups are applied.
type roleScope struct {
	// allArea and allLocation mirror the user's "all areas" / "all
	// locations" flags; either one makes the resolved filters unrestricted.
	allArea, allLocation bool
	// areaIDs and locationIDs hold the explicitly assigned area and
	// location identifiers.
	areaIDs, locationIDs []uint
	// hasAnyScopes is true when at least one flag is set or at least one
	// identifier is present.
	hasAnyScopes bool
}
// ResolveAreaScope computes the area filter for the current request.
//
// The result is unrestricted (zero ScopeFilter) when the caller has no scope
// information at all, or when either the "all areas" or "all locations" flag
// is set. Otherwise the permitted areas are the union of the explicitly
// assigned area IDs and the areas that own the assigned locations (looked up
// through db). A restricted filter with no IDs is returned when that union is
// empty. Lookup errors are returned unchanged together with a zero filter.
func ResolveAreaScope(c *fiber.Ctx, db *gorm.DB) (ScopeFilter, error) {
	var unrestricted ScopeFilter

	role, err := collectRoleScope(c)
	switch {
	case err != nil:
		return unrestricted, err
	case !role.hasAnyScopes, role.allArea, role.allLocation:
		return unrestricted, nil
	}

	areaIDs := uniqueUint(role.areaIDs)
	if len(role.locationIDs) > 0 {
		// Locations implicitly grant access to the areas they belong to.
		fromLocations, lookupErr := areaIDsByLocationIDs(db, role.locationIDs)
		if lookupErr != nil {
			return unrestricted, lookupErr
		}
		areaIDs = uniqueUint(append(areaIDs, fromLocations...))
	}

	// IDs stays nil when nothing is permitted, so the filter matches no rows.
	filter := ScopeFilter{Restrict: true}
	if len(areaIDs) > 0 {
		filter.IDs = areaIDs
	}
	return filter, nil
}
// ResolveLocationScope computes the location filter for the current request.
//
// The result is unrestricted (zero ScopeFilter) when the caller has no scope
// information at all, or when either the "all locations" or "all areas" flag
// is set. Otherwise:
//   - with both location and area IDs, the locations are narrowed to those
//     that belong to one of the areas;
//   - with only area IDs, every location of those areas is permitted;
//   - with only location IDs, they are used as-is.
//
// A restricted filter with no IDs is returned when the outcome is empty.
// Lookup errors are returned unchanged together with a zero filter.
func ResolveLocationScope(c *fiber.Ctx, db *gorm.DB) (ScopeFilter, error) {
	role, err := collectRoleScope(c)
	if err != nil {
		return ScopeFilter{}, err
	}
	if !role.hasAnyScopes || role.allLocation || role.allArea {
		return ScopeFilter{}, nil
	}

	areas := uniqueUint(role.areaIDs)
	locations := uniqueUint(role.locationIDs)

	// Area IDs either narrow an explicit location list or expand into one.
	if len(areas) > 0 {
		var lookupErr error
		if len(locations) > 0 {
			locations, lookupErr = filterLocationIDsByAreaIDs(db, locations, areas)
		} else {
			locations, lookupErr = locationIDsByAreaIDs(db, areas)
		}
		if lookupErr != nil {
			return ScopeFilter{}, lookupErr
		}
	}

	locations = uniqueUint(locations)
	if len(locations) == 0 {
		return ScopeFilter{Restrict: true}, nil
	}
	return ScopeFilter{IDs: locations, Restrict: true}, nil
}
// ResolveLocationAreaScopes resolves the location filter and then the area
// filter for the current request and returns them in that order. If either
// resolution fails, both returned filters are zero values and the error is
// passed through.
func ResolveLocationAreaScopes(c *fiber.Ctx, db *gorm.DB) (ScopeFilter, ScopeFilter, error) {
	var none ScopeFilter

	locations, locErr := ResolveLocationScope(c, db)
	if locErr != nil {
		return none, none, locErr
	}

	areas, areaErr := ResolveAreaScope(c, db)
	if areaErr != nil {
		return none, none, areaErr
	}

	return locations, areas, nil
}
// collectRoleScope builds the caller's roleScope from the auth details
// attached to the request (via AuthDetails).
//
// A zero roleScope is returned when no auth details are available or when the
// user has neither an "all" flag nor any area/location ID. The error result
// is currently always nil.
//
// NOTE(review): a zero roleScope makes the Resolve* functions return an
// unrestricted filter — confirm that this is the intended behavior for
// callers without auth details.
func collectRoleScope(c *fiber.Ctx) (roleScope, error) {
	details, ok := AuthDetails(c)
	if !ok || details == nil {
		return roleScope{}, nil
	}

	areas := uniqueUint(details.UserAreaIDs)
	locations := uniqueUint(details.UserLocationIDs)

	scoped := details.UserAllArea || details.UserAllLocation || len(areas) > 0 || len(locations) > 0
	if !scoped {
		return roleScope{}, nil
	}

	return roleScope{
		allArea:      details.UserAllArea,
		allLocation:  details.UserAllLocation,
		areaIDs:      areas,
		locationIDs:  locations,
		hasAnyScopes: true,
	}, nil
}
// areaIDsByLocationIDs returns the distinct area_id values of the
// non-deleted locations whose id is in locationIDs.
//
// It returns an error when db is nil, and (nil, nil) without querying when
// locationIDs is empty. Query errors are returned unchanged.
func areaIDsByLocationIDs(db *gorm.DB, locationIDs []uint) ([]uint, error) {
	switch {
	case db == nil:
		return nil, errors.New("database not configured")
	case len(locationIDs) == 0:
		return nil, nil
	}

	query := db.Model(&entity.Location{}).
		Where("deleted_at IS NULL").
		Where("id IN ?", locationIDs).
		Distinct("area_id")

	var found []uint
	if err := query.Pluck("area_id", &found).Error; err != nil {
		return nil, err
	}
	return found, nil
}
// locationIDsByAreaIDs returns the distinct ids of the non-deleted locations
// whose area_id is in areaIDs.
//
// It returns an error when db is nil, and (nil, nil) without querying when
// areaIDs is empty. Query errors are returned unchanged.
func locationIDsByAreaIDs(db *gorm.DB, areaIDs []uint) ([]uint, error) {
	switch {
	case db == nil:
		return nil, errors.New("database not configured")
	case len(areaIDs) == 0:
		return nil, nil
	}

	query := db.Model(&entity.Location{}).
		Where("deleted_at IS NULL").
		Where("area_id IN ?", areaIDs).
		Distinct("id")

	var found []uint
	if err := query.Pluck("id", &found).Error; err != nil {
		return nil, err
	}
	return found, nil
}
// filterLocationIDsByAreaIDs returns the subset of locationIDs that refer to
// non-deleted locations whose area_id is in areaIDs.
//
// It returns an error when db is nil, and (nil, nil) without querying when
// either list is empty. Query errors are returned unchanged.
func filterLocationIDsByAreaIDs(db *gorm.DB, locationIDs, areaIDs []uint) ([]uint, error) {
	switch {
	case db == nil:
		return nil, errors.New("database not configured")
	case len(locationIDs) == 0, len(areaIDs) == 0:
		return nil, nil
	}

	query := db.Model(&entity.Location{}).
		Where("deleted_at IS NULL").
		Where("id IN ?", locationIDs).
		Where("area_id IN ?", areaIDs).
		Distinct("id")

	var kept []uint
	if err := query.Pluck("id", &kept).Error; err != nil {
		return nil, err
	}
	return kept, nil
}
// uniqueUint returns the values of ids with zeros and duplicates removed,
// keeping the first occurrence of each value in its original order.
//
// A nil or empty input yields nil. A non-empty input that contains only
// zeros yields an empty, non-nil slice. The input slice is not modified.
func uniqueUint(ids []uint) []uint {
	if len(ids) == 0 {
		return nil
	}

	visited := make(map[uint]bool, len(ids))
	out := make([]uint, 0, len(ids))
	for i := range ids {
		v := ids[i]
		// Zero is treated as "no id"; repeated values are dropped.
		if v == 0 || visited[v] {
			continue
		}
		visited[v] = true
		out = append(out, v)
	}
	return out
}
// ApplyScopeFilter adds the given scope to db as a condition on column.
//
// db is returned untouched when it is nil or when scope is not restricting.
// A restricting scope without IDs adds an always-false condition so the query
// matches nothing; otherwise "<column> IN (ids)" is added. column is
// concatenated into the SQL text and must therefore be a trusted identifier.
func ApplyScopeFilter(db *gorm.DB, scope ScopeFilter, column string) *gorm.DB {
	switch {
	case db == nil, !scope.Restrict:
		return db
	case len(scope.IDs) == 0:
		return db.Where("1 = 0")
	default:
		return db.Where(column+" IN ?", scope.IDs)
	}
}
// ApplyLocationScope restricts db to the caller's permitted locations by
// filtering on column.
//
// The scope itself is resolved on a fresh session (NewDB) derived from db so
// that the lookup queries do not inherit conditions already attached to db.
// On a resolution error the original db is returned together with the error.
func ApplyLocationScope(c *fiber.Ctx, db *gorm.DB, column string) (*gorm.DB, error) {
	var lookupDB *gorm.DB
	if db != nil {
		lookupDB = db.Session(&gorm.Session{NewDB: true})
	}

	filter, err := ResolveLocationScope(c, lookupDB)
	if err != nil {
		return db, err
	}
	return ApplyScopeFilter(db, filter, column), nil
}
// ApplyAreaScope restricts db to the caller's permitted areas by filtering on
// column.
//
// The scope itself is resolved on a fresh session (NewDB) derived from db so
// that the lookup queries do not inherit conditions already attached to db.
// On a resolution error the original db is returned together with the error.
func ApplyAreaScope(c *fiber.Ctx, db *gorm.DB, column string) (*gorm.DB, error) {
	var lookupDB *gorm.DB
	if db != nil {
		lookupDB = db.Session(&gorm.Session{NewDB: true})
	}

	filter, err := ResolveAreaScope(c, lookupDB)
	if err != nil {
		return db, err
	}
	return ApplyScopeFilter(db, filter, column), nil
}
// ApplyLocationAreaScope restricts db by the caller's location scope (on
// locationColumn) and area scope (on areaColumn). An empty column name skips
// the corresponding filter.
//
// Both scopes are resolved on a fresh session (NewDB) derived from the
// incoming db so the lookup queries do not inherit its conditions. When a
// resolution fails, the query built so far is returned with the error; this
// means a failure of the area step returns a db that already carries the
// location filter.
func ApplyLocationAreaScope(c *fiber.Ctx, db *gorm.DB, locationColumn, areaColumn string) (*gorm.DB, error) {
	var lookupDB *gorm.DB
	if db != nil {
		lookupDB = db.Session(&gorm.Session{NewDB: true})
	}

	scoped := db
	if locationColumn != "" {
		filter, err := ResolveLocationScope(c, lookupDB)
		if err != nil {
			return scoped, err
		}
		scoped = ApplyScopeFilter(scoped, filter, locationColumn)
	}
	if areaColumn != "" {
		filter, err := ResolveAreaScope(c, lookupDB)
		if err != nil {
			return scoped, err
		}
		scoped = ApplyScopeFilter(scoped, filter, areaColumn)
	}
	return scoped, nil
}
// EnsureWarehouseAccess verifies that the caller may access the warehouse
// with the given id, based on the warehouse's location_id.
//
// It returns a 400 fiber error for a zero id, a 500 fiber error for a nil db,
// nil when the caller's location scope is unrestricted, and a 404 fiber error
// when the scope permits no locations or no matching warehouse lies in a
// permitted location. Scope-resolution and query errors are returned as-is.
func EnsureWarehouseAccess(c *fiber.Ctx, db *gorm.DB, warehouseID uint) error {
	const notFoundMsg = "Warehouse not found"

	switch {
	case warehouseID == 0:
		return fiber.NewError(fiber.StatusBadRequest, "Invalid warehouse id")
	case db == nil:
		return fiber.NewError(fiber.StatusInternalServerError, "Database not configured")
	}

	filter, err := ResolveLocationScope(c, db)
	if err != nil {
		return err
	}
	if !filter.Restrict {
		return nil
	}
	if len(filter.IDs) == 0 {
		return fiber.NewError(fiber.StatusNotFound, notFoundMsg)
	}

	query := db.WithContext(c.Context()).
		Model(&entity.Warehouse{}).
		Where("id = ?", warehouseID)

	var matches int64
	if err := ApplyScopeFilter(query, filter, "warehouses.location_id").Count(&matches).Error; err != nil {
		return err
	}
	if matches == 0 {
		return fiber.NewError(fiber.StatusNotFound, notFoundMsg)
	}
	return nil
}
// EnsureAreaAccess verifies that the caller may access the area with the
// given id, based on the caller's area scope.
//
// It returns a 400 fiber error for a zero id, a 500 fiber error for a nil db,
// nil when the caller's area scope is unrestricted, and a 404 fiber error
// when the scope permits no areas or the area is not among the permitted
// ones. Scope-resolution and query errors are returned as-is.
func EnsureAreaAccess(c *fiber.Ctx, db *gorm.DB, areaID uint) error {
	const notFoundMsg = "Area not found"

	switch {
	case areaID == 0:
		return fiber.NewError(fiber.StatusBadRequest, "Invalid area id")
	case db == nil:
		return fiber.NewError(fiber.StatusInternalServerError, "Database not configured")
	}

	filter, err := ResolveAreaScope(c, db)
	if err != nil {
		return err
	}
	if !filter.Restrict {
		return nil
	}
	if len(filter.IDs) == 0 {
		return fiber.NewError(fiber.StatusNotFound, notFoundMsg)
	}

	query := db.WithContext(c.Context()).
		Model(&entity.Area{}).
		Where("id = ?", areaID)

	var matches int64
	if err := ApplyScopeFilter(query, filter, "areas.id").Count(&matches).Error; err != nil {
		return err
	}
	if matches == 0 {
		return fiber.NewError(fiber.StatusNotFound, notFoundMsg)
	}
	return nil
}
// EnsureLocationAccess verifies that the caller may access the location with
// the given id, based on the caller's location scope.
//
// It returns a 400 fiber error for a zero id, a 500 fiber error for a nil db,
// nil when the caller's location scope is unrestricted, and a 404 fiber error
// when the scope permits no locations or the location is not among the
// permitted ones. Scope-resolution and query errors are returned as-is.
func EnsureLocationAccess(c *fiber.Ctx, db *gorm.DB, locationID uint) error {
	const notFoundMsg = "Location not found"

	switch {
	case locationID == 0:
		return fiber.NewError(fiber.StatusBadRequest, "Invalid location id")
	case db == nil:
		return fiber.NewError(fiber.StatusInternalServerError, "Database not configured")
	}

	filter, err := ResolveLocationScope(c, db)
	if err != nil {
		return err
	}
	if !filter.Restrict {
		return nil
	}
	if len(filter.IDs) == 0 {
		return fiber.NewError(fiber.StatusNotFound, notFoundMsg)
	}

	query := db.WithContext(c.Context()).
		Model(&entity.Location{}).
		Where("id = ?", locationID)

	var matches int64
	if err := ApplyScopeFilter(query, filter, "locations.id").Count(&matches).Error; err != nil {
		return err
	}
	if matches == 0 {
		return fiber.NewError(fiber.StatusNotFound, notFoundMsg)
	}
	return nil
}
// EnsureKandangAccess verifies that the caller may access the kandang with
// the given id, based on the kandang's location_id.
//
// It returns a 400 fiber error for a zero id, a 500 fiber error for a nil db,
// nil when the caller's location scope is unrestricted, and a 404 fiber error
// when the scope permits no locations or no matching kandang lies in a
// permitted location. Scope-resolution and query errors are returned as-is.
func EnsureKandangAccess(c *fiber.Ctx, db *gorm.DB, kandangID uint) error {
	const notFoundMsg = "Kandang not found"

	switch {
	case kandangID == 0:
		return fiber.NewError(fiber.StatusBadRequest, "Invalid kandang id")
	case db == nil:
		return fiber.NewError(fiber.StatusInternalServerError, "Database not configured")
	}

	filter, err := ResolveLocationScope(c, db)
	if err != nil {
		return err
	}
	if !filter.Restrict {
		return nil
	}
	if len(filter.IDs) == 0 {
		return fiber.NewError(fiber.StatusNotFound, notFoundMsg)
	}

	query := db.WithContext(c.Context()).
		Model(&entity.Kandang{}).
		Where("id = ?", kandangID)

	var matches int64
	if err := ApplyScopeFilter(query, filter, "kandangs.location_id").Count(&matches).Error; err != nil {
		return err
	}
	if matches == 0 {
		return fiber.NewError(fiber.StatusNotFound, notFoundMsg)
	}
	return nil
}
// EnsureProductWarehouseAccess verifies that the caller may access the
// product-warehouse row with the given id. The row is joined to its warehouse
// and checked against the warehouse's location_id.
//
// It returns a 400 fiber error for a zero id, a 500 fiber error for a nil db,
// nil when the caller's location scope is unrestricted, and a 404 fiber error
// when the scope permits no locations or no matching row lies in a permitted
// location. Scope-resolution and query errors are returned as-is.
func EnsureProductWarehouseAccess(c *fiber.Ctx, db *gorm.DB, productWarehouseID uint) error {
	const notFoundMsg = "Product warehouse not found"

	switch {
	case productWarehouseID == 0:
		return fiber.NewError(fiber.StatusBadRequest, "Invalid product warehouse id")
	case db == nil:
		return fiber.NewError(fiber.StatusInternalServerError, "Database not configured")
	}

	filter, err := ResolveLocationScope(c, db)
	if err != nil {
		return err
	}
	if !filter.Restrict {
		return nil
	}
	if len(filter.IDs) == 0 {
		return fiber.NewError(fiber.StatusNotFound, notFoundMsg)
	}

	var matches int64
	err = ApplyScopeFilter(
		db.WithContext(c.Context()).
			Table("product_warehouses pw").
			Joins("JOIN warehouses w ON w.id = pw.warehouse_id").
			Where("pw.id = ?", productWarehouseID),
		filter,
		"w.location_id",
	).Count(&matches).Error
	if err != nil {
		return err
	}
	if matches == 0 {
		return fiber.NewError(fiber.StatusNotFound, notFoundMsg)
	}
	return nil
}
// EnsureStockLogAccess verifies that the caller may access the stock log with
// the given id. The log is joined through its product-warehouse row to the
// warehouse and checked against the warehouse's location_id.
//
// It returns a 400 fiber error for a zero id, a 500 fiber error for a nil db,
// nil when the caller's location scope is unrestricted, and a 404 fiber error
// when the scope permits no locations or no matching row lies in a permitted
// location. Scope-resolution and query errors are returned as-is.
func EnsureStockLogAccess(c *fiber.Ctx, db *gorm.DB, stockLogID uint) error {
	const notFoundMsg = "Stock log not found"

	switch {
	case stockLogID == 0:
		return fiber.NewError(fiber.StatusBadRequest, "Invalid stock log id")
	case db == nil:
		return fiber.NewError(fiber.StatusInternalServerError, "Database not configured")
	}

	filter, err := ResolveLocationScope(c, db)
	if err != nil {
		return err
	}
	if !filter.Restrict {
		return nil
	}
	if len(filter.IDs) == 0 {
		return fiber.NewError(fiber.StatusNotFound, notFoundMsg)
	}

	var matches int64
	err = ApplyScopeFilter(
		db.WithContext(c.Context()).
			Table("stock_logs sl").
			Joins("JOIN product_warehouses pw ON pw.id = sl.product_warehouse_id").
			Joins("JOIN warehouses w ON w.id = pw.warehouse_id").
			Where("sl.id = ?", stockLogID),
		filter,
		"w.location_id",
	).Count(&matches).Error
	if err != nil {
		return err
	}
	if matches == 0 {
		return fiber.NewError(fiber.StatusNotFound, notFoundMsg)
	}
	return nil
}
// EnsureMarketingAccess verifies that the caller may access the marketing
// record with the given id. Access is granted when at least one of the
// record's marketing_products rows points, via its product-warehouse row, to
// a warehouse in a permitted location.
//
// It returns a 400 fiber error for a zero id, a 500 fiber error for a nil db,
// nil when the caller's location scope is unrestricted, and a 404 fiber error
// when the scope permits no locations or no joined row lies in a permitted
// location. Scope-resolution and query errors are returned as-is.
func EnsureMarketingAccess(c *fiber.Ctx, db *gorm.DB, marketingID uint) error {
	const notFoundMsg = "Marketing not found"

	switch {
	case marketingID == 0:
		return fiber.NewError(fiber.StatusBadRequest, "Invalid marketing id")
	case db == nil:
		return fiber.NewError(fiber.StatusInternalServerError, "Database not configured")
	}

	filter, err := ResolveLocationScope(c, db)
	if err != nil {
		return err
	}
	if !filter.Restrict {
		return nil
	}
	if len(filter.IDs) == 0 {
		return fiber.NewError(fiber.StatusNotFound, notFoundMsg)
	}

	var matches int64
	err = ApplyScopeFilter(
		db.WithContext(c.Context()).
			Table("marketings m").
			Joins("JOIN marketing_products mp ON mp.marketing_id = m.id").
			Joins("JOIN product_warehouses pw ON pw.id = mp.product_warehouse_id").
			Joins("JOIN warehouses w ON w.id = pw.warehouse_id").
			Where("m.id = ?", marketingID),
		filter,
		"w.location_id",
	).Count(&matches).Error
	if err != nil {
		return err
	}
	if matches == 0 {
		return fiber.NewError(fiber.StatusNotFound, notFoundMsg)
	}
	return nil
}
// EnsureRecordingAccess verifies that the caller may access the recording
// with the given id. The recording is joined through its project-flock
// kandang to the project flock and checked against the flock's location_id.
//
// It returns a 400 fiber error for a zero id, a 500 fiber error for a nil db,
// nil when the caller's location scope is unrestricted, and a 404 fiber error
// when the scope permits no locations or no matching row lies in a permitted
// location. Scope-resolution and query errors are returned as-is.
func EnsureRecordingAccess(c *fiber.Ctx, db *gorm.DB, recordingID uint) error {
	const notFoundMsg = "Recording not found"

	switch {
	case recordingID == 0:
		return fiber.NewError(fiber.StatusBadRequest, "Invalid recording id")
	case db == nil:
		return fiber.NewError(fiber.StatusInternalServerError, "Database not configured")
	}

	filter, err := ResolveLocationScope(c, db)
	if err != nil {
		return err
	}
	if !filter.Restrict {
		return nil
	}
	if len(filter.IDs) == 0 {
		return fiber.NewError(fiber.StatusNotFound, notFoundMsg)
	}

	var matches int64
	err = ApplyScopeFilter(
		db.WithContext(c.Context()).
			Table("recordings r").
			Joins("JOIN project_flock_kandangs pfk ON pfk.id = r.project_flock_kandangs_id").
			Joins("JOIN project_flocks pf ON pf.id = pfk.project_flock_id").
			Where("r.id = ?", recordingID),
		filter,
		"pf.location_id",
	).Count(&matches).Error
	if err != nil {
		return err
	}
	if matches == 0 {
		return fiber.NewError(fiber.StatusNotFound, notFoundMsg)
	}
	return nil
}
// EnsureUniformityAccess verifies that the caller may access the uniformity
// record with the given id. The record is joined through its project-flock
// kandang to the project flock and checked against the flock's location_id.
//
// It returns a 400 fiber error for a zero id, a 500 fiber error for a nil db,
// nil when the caller's location scope is unrestricted, and a 404 fiber error
// when the scope permits no locations or no matching row lies in a permitted
// location. Scope-resolution and query errors are returned as-is.
func EnsureUniformityAccess(c *fiber.Ctx, db *gorm.DB, uniformityID uint) error {
	const notFoundMsg = "Uniformity not found"

	switch {
	case uniformityID == 0:
		return fiber.NewError(fiber.StatusBadRequest, "Invalid uniformity id")
	case db == nil:
		return fiber.NewError(fiber.StatusInternalServerError, "Database not configured")
	}

	filter, err := ResolveLocationScope(c, db)
	if err != nil {
		return err
	}
	if !filter.Restrict {
		return nil
	}
	if len(filter.IDs) == 0 {
		return fiber.NewError(fiber.StatusNotFound, notFoundMsg)
	}

	var matches int64
	err = ApplyScopeFilter(
		db.WithContext(c.Context()).
			Table("project_flock_kandang_uniformity u").
			Joins("JOIN project_flock_kandangs pfk ON pfk.id = u.project_flock_kandang_id").
			Joins("JOIN project_flocks pf ON pf.id = pfk.project_flock_id").
			Where("u.id = ?", uniformityID),
		filter,
		"pf.location_id",
	).Count(&matches).Error
	if err != nil {
		return err
	}
	if matches == 0 {
		return fiber.NewError(fiber.StatusNotFound, notFoundMsg)
	}
	return nil
}
// EnsureLayingTransferAccess verifies that the caller may access the laying
// transfer with the given id. Access is granted when either the source or the
// destination project flock of the transfer lies in a permitted location.
//
// It returns a 400 fiber error for a zero id, a 500 fiber error for a nil db,
// nil when the caller's location scope is unrestricted, and a 404 fiber error
// when the scope permits no locations or neither flock lies in a permitted
// location. Scope-resolution and query errors are returned as-is.
func EnsureLayingTransferAccess(c *fiber.Ctx, db *gorm.DB, transferID uint) error {
	const notFoundMsg = "Transfer not found"

	switch {
	case transferID == 0:
		return fiber.NewError(fiber.StatusBadRequest, "Invalid transfer id")
	case db == nil:
		return fiber.NewError(fiber.StatusInternalServerError, "Database not configured")
	}

	filter, err := ResolveLocationScope(c, db)
	if err != nil {
		return err
	}
	if !filter.Restrict {
		return nil
	}
	if len(filter.IDs) == 0 {
		return fiber.NewError(fiber.StatusNotFound, notFoundMsg)
	}

	// The scope is applied by hand here because it must match either side of
	// the transfer; filter.IDs is known to be non-empty at this point.
	permitted := filter.IDs
	var matches int64
	err = db.WithContext(c.Context()).
		Table("laying_transfers lt").
		Joins("JOIN project_flocks pf_from ON pf_from.id = lt.from_project_flock_id").
		Joins("JOIN project_flocks pf_to ON pf_to.id = lt.to_project_flock_id").
		Where("lt.id = ?", transferID).
		Where("(pf_from.location_id IN ? OR pf_to.location_id IN ?)", permitted, permitted).
		Count(&matches).Error
	if err != nil {
		return err
	}
	if matches == 0 {
		return fiber.NewError(fiber.StatusNotFound, notFoundMsg)
	}
	return nil
}
// EnsureProjectFlockAccess verifies that the caller may access the project
// flock with the given id, based on the flock's location_id.
//
// It returns a 400 fiber error for a zero id, a 500 fiber error for a nil db,
// nil when the caller's location scope is unrestricted, and a 404 fiber error
// when the scope permits no locations or no matching flock lies in a
// permitted location. Scope-resolution and query errors are returned as-is.
func EnsureProjectFlockAccess(c *fiber.Ctx, db *gorm.DB, projectFlockID uint) error {
	const notFoundMsg = "Project Flock not found"

	switch {
	case projectFlockID == 0:
		return fiber.NewError(fiber.StatusBadRequest, "Invalid project flock id")
	case db == nil:
		return fiber.NewError(fiber.StatusInternalServerError, "Database not configured")
	}

	filter, err := ResolveLocationScope(c, db)
	if err != nil {
		return err
	}
	if !filter.Restrict {
		return nil
	}
	if len(filter.IDs) == 0 {
		return fiber.NewError(fiber.StatusNotFound, notFoundMsg)
	}

	query := db.WithContext(c.Context()).
		Model(&entity.ProjectFlock{}).
		Where("id = ?", projectFlockID)

	var matches int64
	if err := ApplyScopeFilter(query, filter, "project_flocks.location_id").Count(&matches).Error; err != nil {
		return err
	}
	if matches == 0 {
		return fiber.NewError(fiber.StatusNotFound, notFoundMsg)
	}
	return nil
}
// EnsureProjectFlockKandangAccess verifies that the caller may access the
// project-flock kandang with the given id, based on the location_id of the
// project flock it belongs to. When projectFlockID is non-zero the kandang
// must additionally belong to that project flock; zero skips that check.
//
// It returns a 400 fiber error for a zero projectFlockKandangID, a 500 fiber
// error for a nil db, nil when the caller's location scope is unrestricted,
// and a 404 fiber error when the scope permits no locations or no matching
// row lies in a permitted location. Scope-resolution and query errors are
// returned as-is.
func EnsureProjectFlockKandangAccess(c *fiber.Ctx, db *gorm.DB, projectFlockID, projectFlockKandangID uint) error {
	const notFoundMsg = "Project Flock Kandang not found"

	switch {
	case projectFlockKandangID == 0:
		return fiber.NewError(fiber.StatusBadRequest, "Invalid project flock kandang id")
	case db == nil:
		return fiber.NewError(fiber.StatusInternalServerError, "Database not configured")
	}

	filter, err := ResolveLocationScope(c, db)
	if err != nil {
		return err
	}
	if !filter.Restrict {
		return nil
	}
	if len(filter.IDs) == 0 {
		return fiber.NewError(fiber.StatusNotFound, notFoundMsg)
	}

	query := db.WithContext(c.Context()).
		Table("project_flock_kandangs").
		Joins("JOIN project_flocks ON project_flocks.id = project_flock_kandangs.project_flock_id").
		Where("project_flock_kandangs.id = ?", projectFlockKandangID)
	if projectFlockID > 0 {
		// Optional parent check: only applied when the caller supplies it.
		query = query.Where("project_flock_kandangs.project_flock_id = ?", projectFlockID)
	}

	var matches int64
	if err := ApplyScopeFilter(query, filter, "project_flocks.location_id").Count(&matches).Error; err != nil {
		return err
	}
	if matches == 0 {
		return fiber.NewError(fiber.StatusNotFound, notFoundMsg)
	}
	return nil
}