[FIX/BE-US]add feature restrict by location and areas in roles

This commit is contained in:
ragilap
2026-01-14 13:27:52 +07:00
parent 1847e5590a
commit dff9e73ab1
30 changed files with 1258 additions and 37 deletions
@@ -0,0 +1,331 @@
package controllers
import (
"context"
"crypto/hmac"
"crypto/sha256"
"encoding/base64"
"encoding/hex"
"errors"
"fmt"
"strconv"
"strings"
"sync"
"time"
"github.com/gofiber/fiber/v2"
"github.com/redis/go-redis/v9"
"gorm.io/gorm"
"gitlab.com/mbugroup/lti-api.git/internal/config"
entity "gitlab.com/mbugroup/lti-api.git/internal/entities"
sso "gitlab.com/mbugroup/lti-api.git/internal/modules/sso/verifier"
"gitlab.com/mbugroup/lti-api.git/internal/response"
"gitlab.com/mbugroup/lti-api.git/internal/utils"
)
type MasterDataController struct {
db *gorm.DB
redis *redis.Client
clients map[string]config.SSOClientConfig
drift time.Duration
nonceTTL time.Duration
localNonce sync.Map
}
type masterArea struct {
ID uint `json:"id"`
Name string `json:"name"`
}
type masterLocation struct {
ID uint `json:"id"`
Name string `json:"name"`
AreaID uint `json:"area_id"`
}
func NewMasterDataController(db *gorm.DB, redis *redis.Client, clients map[string]config.SSOClientConfig) *MasterDataController {
normalized := make(map[string]config.SSOClientConfig, len(clients))
for alias, cfg := range clients {
alias = strings.ToLower(strings.TrimSpace(alias))
normalized[alias] = cfg
}
drift := config.SSOUserSyncDrift
if drift <= 0 {
drift = 2 * time.Minute
}
nonceTTL := config.SSOUserSyncNonceTTL
if nonceTTL <= 0 {
nonceTTL = 10 * time.Minute
}
return &MasterDataController{
db: db,
redis: redis,
clients: normalized,
drift: drift,
nonceTTL: nonceTTL,
}
}
func (h *MasterDataController) GetAreas(c *fiber.Ctx) error {
if _, _, err := h.authenticate(c, nil); err != nil {
return err
}
search := strings.TrimSpace(c.Query("search", ""))
ids := parseUintList(c.Query("ids", ""))
query := h.db.WithContext(c.Context()).
Model(&entity.Area{}).
Where("deleted_at IS NULL")
if search != "" {
query = query.Where("name ILIKE ?", "%"+search+"%")
}
if len(ids) > 0 {
query = query.Where("id IN ?", ids)
}
var areas []masterArea
if err := query.Order("name ASC").Find(&areas).Error; err != nil {
utils.Log.WithError(err).Error("failed to fetch areas for master data")
return fiber.NewError(fiber.StatusInternalServerError, "failed to fetch areas")
}
return c.Status(fiber.StatusOK).JSON(response.Success{
Code: fiber.StatusOK,
Status: "success",
Message: "Get areas successfully",
Data: areas,
})
}
func (h *MasterDataController) GetLocations(c *fiber.Ctx) error {
if _, _, err := h.authenticate(c, nil); err != nil {
return err
}
search := strings.TrimSpace(c.Query("search", ""))
areaIDs := parseUintList(c.Query("area_ids", ""))
ids := parseUintList(c.Query("ids", ""))
query := h.db.WithContext(c.Context()).
Model(&entity.Location{}).
Where("deleted_at IS NULL")
if search != "" {
query = query.Where("name ILIKE ?", "%"+search+"%")
}
if len(areaIDs) > 0 {
query = query.Where("area_id IN ?", areaIDs)
}
if len(ids) > 0 {
query = query.Where("id IN ?", ids)
}
var locations []masterLocation
if err := query.Order("name ASC").Find(&locations).Error; err != nil {
utils.Log.WithError(err).Error("failed to fetch locations for master data")
return fiber.NewError(fiber.StatusInternalServerError, "failed to fetch locations")
}
return c.Status(fiber.StatusOK).JSON(response.Success{
Code: fiber.StatusOK,
Status: "success",
Message: "Get locations successfully",
Data: locations,
})
}
func (h *MasterDataController) authenticate(c *fiber.Ctx, body []byte) (string, config.SSOClientConfig, error) {
rawAlias := strings.TrimSpace(c.Get("X-Sync-Client"))
if rawAlias == "" {
return "", config.SSOClientConfig{}, fiber.NewError(fiber.StatusUnauthorized, "missing sync client header")
}
aliasKey := strings.ToLower(rawAlias)
clientCfg, ok := h.clients[aliasKey]
if !ok {
return "", config.SSOClientConfig{}, fiber.NewError(fiber.StatusUnauthorized, "unknown sync client")
}
if err := h.verifyAuthorization(c, aliasKey); err != nil {
return "", config.SSOClientConfig{}, err
}
secret := strings.TrimSpace(clientCfg.SyncSecret)
if secret == "" {
return "", config.SSOClientConfig{}, fiber.NewError(fiber.StatusUnauthorized, "sync secret not configured")
}
timestamp := strings.TrimSpace(c.Get("X-Sync-Timestamp"))
nonce := strings.TrimSpace(c.Get("X-Sync-Nonce"))
signature := strings.TrimSpace(c.Get("X-Sync-Signature"))
if timestamp == "" || nonce == "" || signature == "" {
return "", config.SSOClientConfig{}, fiber.NewError(fiber.StatusUnauthorized, "missing signature headers")
}
if len(nonce) < 16 {
return "", config.SSOClientConfig{}, fiber.NewError(fiber.StatusUnauthorized, "nonce too short")
}
ts, err := strconv.ParseInt(timestamp, 10, 64)
if err != nil {
return "", config.SSOClientConfig{}, fiber.NewError(fiber.StatusBadRequest, "invalid timestamp")
}
msgTime := time.Unix(ts, 0).UTC()
now := time.Now().UTC()
drift := now.Sub(msgTime)
if drift > h.drift || drift < -h.drift {
return "", config.SSOClientConfig{}, fiber.NewError(fiber.StatusUnauthorized, "timestamp outside allowed window")
}
providedSig, err := decodeMasterSignature(signature)
if err != nil {
return "", config.SSOClientConfig{}, fiber.NewError(fiber.StatusUnauthorized, "invalid signature encoding")
}
expectedSignature := calculateSignature(secret, rawAlias, timestamp, nonce, body)
if !hmac.Equal(providedSig, expectedSignature) {
return "", config.SSOClientConfig{}, fiber.NewError(fiber.StatusUnauthorized, "invalid signature")
}
if err := h.registerNonce(c.Context(), aliasKey, nonce); err != nil {
return "", config.SSOClientConfig{}, err
}
return aliasKey, clientCfg, nil
}
func (h *MasterDataController) verifyAuthorization(c *fiber.Ctx, alias string) error {
authHeader := strings.TrimSpace(c.Get(fiber.HeaderAuthorization))
if authHeader == "" {
return fiber.NewError(fiber.StatusUnauthorized, "missing authorization header")
}
parts := strings.SplitN(authHeader, " ", 2)
if len(parts) != 2 || !strings.EqualFold(parts[0], "Bearer") {
return fiber.NewError(fiber.StatusUnauthorized, "invalid authorization header")
}
token := strings.TrimSpace(parts[1])
if token == "" {
return fiber.NewError(fiber.StatusUnauthorized, "invalid authorization header")
}
verification, err := sso.VerifyAccessToken(token)
if err != nil {
return fiber.NewError(fiber.StatusUnauthorized, "invalid access token")
}
if verification.ServiceAlias == "" || verification.ServiceAlias != alias {
return fiber.NewError(fiber.StatusUnauthorized, "service subject mismatch")
}
if !hasAnyScope(verification.Claims.Scopes(), []string{"sync.master", "sync.users"}) {
return fiber.NewError(fiber.StatusForbidden, "missing sync scope")
}
return nil
}
func (h *MasterDataController) registerNonce(ctx context.Context, alias, nonce string) error {
ttl := h.nonceTTL
if ttl <= 0 {
ttl = 10 * time.Minute
}
key := fmt.Sprintf("sso:sync:%s:%s", alias, nonce)
if h.redis != nil {
stored, err := h.redis.SetNX(ctx, key, "1", ttl).Result()
if err == nil {
if !stored {
return fiber.NewError(fiber.StatusUnauthorized, "nonce already used")
}
return nil
}
utils.Log.WithError(err).Warn("store sync nonce failed")
}
now := time.Now().UTC()
if expRaw, ok := h.localNonce.Load(key); ok {
if expTime, ok := expRaw.(time.Time); ok && expTime.After(now) {
return fiber.NewError(fiber.StatusUnauthorized, "nonce already used")
}
}
h.localNonce.Store(key, now.Add(ttl))
return nil
}
func calculateSignature(secret, alias, timestamp, nonce string, body []byte) []byte {
mac := hmac.New(sha256.New, []byte(secret))
mac.Write([]byte(alias))
mac.Write([]byte("\n"))
mac.Write([]byte(timestamp))
mac.Write([]byte("\n"))
mac.Write([]byte(nonce))
mac.Write([]byte("\n"))
if len(body) > 0 {
mac.Write(body)
}
return mac.Sum(nil)
}
func decodeMasterSignature(sig string) ([]byte, error) {
sig = strings.TrimSpace(sig)
if sig == "" {
return nil, errors.New("empty signature")
}
if decoded, err := hex.DecodeString(sig); err == nil {
return decoded, nil
}
if decoded, err := base64.StdEncoding.DecodeString(sig); err == nil {
return decoded, nil
}
if decoded, err := base64.URLEncoding.DecodeString(sig); err == nil {
return decoded, nil
}
return nil, errors.New("unrecognized signature encoding")
}
func parseUintList(raw string) []uint {
raw = strings.TrimSpace(raw)
if raw == "" {
return nil
}
parts := strings.Split(raw, ",")
out := make([]uint, 0, len(parts))
seen := make(map[uint]struct{}, len(parts))
for _, part := range parts {
part = strings.TrimSpace(part)
if part == "" {
continue
}
val, err := strconv.ParseUint(part, 10, 64)
if err != nil || val == 0 {
continue
}
if _, ok := seen[uint(val)]; ok {
continue
}
seen[uint(val)] = struct{}{}
out = append(out, uint(val))
}
return out
}
func hasAnyScope(scopes []string, targets []string) bool {
if len(scopes) == 0 || len(targets) == 0 {
return false
}
for _, scope := range scopes {
scope = strings.ToLower(strings.TrimSpace(scope))
if scope == "" {
continue
}
for _, target := range targets {
if scope == strings.ToLower(strings.TrimSpace(target)) {
return true
}
}
}
return false
}
@@ -16,7 +16,7 @@ import (
"gitlab.com/mbugroup/lti-api.git/internal/config"
"gitlab.com/mbugroup/lti-api.git/internal/modules/sso/session"
"gitlab.com/mbugroup/lti-api.git/internal/sso"
sso "gitlab.com/mbugroup/lti-api.git/internal/modules/sso/verifier"
"gitlab.com/mbugroup/lti-api.git/internal/utils"
"gitlab.com/mbugroup/lti-api.git/internal/utils/secure"
)
@@ -9,23 +9,24 @@ import (
"encoding/json"
"errors"
"fmt"
"github.com/go-playground/validator/v10"
"github.com/gofiber/fiber/v2"
"github.com/redis/go-redis/v9"
"github.com/sirupsen/logrus"
"gorm.io/gorm"
"strconv"
"strings"
"sync"
"time"
"github.com/go-playground/validator/v10"
"github.com/gofiber/fiber/v2"
"github.com/redis/go-redis/v9"
"github.com/sirupsen/logrus"
"gorm.io/gorm"
"gitlab.com/mbugroup/lti-api.git/internal/config"
entity "gitlab.com/mbugroup/lti-api.git/internal/entities"
"gitlab.com/mbugroup/lti-api.git/internal/modules/sso/session"
sso "gitlab.com/mbugroup/lti-api.git/internal/modules/sso/verifier"
"gitlab.com/mbugroup/lti-api.git/internal/modules/users/dto"
userRepository "gitlab.com/mbugroup/lti-api.git/internal/modules/users/repositories"
"gitlab.com/mbugroup/lti-api.git/internal/response"
"gitlab.com/mbugroup/lti-api.git/internal/sso"
"gitlab.com/mbugroup/lti-api.git/internal/utils"
)