package controllers import ( "context" "crypto/hmac" "crypto/sha256" "encoding/base64" "encoding/hex" "errors" "fmt" "strconv" "strings" "sync" "time" "github.com/gofiber/fiber/v2" "github.com/redis/go-redis/v9" "gorm.io/gorm" "gitlab.com/mbugroup/lti-api.git/internal/config" entity "gitlab.com/mbugroup/lti-api.git/internal/entities" sso "gitlab.com/mbugroup/lti-api.git/internal/modules/sso/verifier" "gitlab.com/mbugroup/lti-api.git/internal/response" "gitlab.com/mbugroup/lti-api.git/internal/utils" ) type MasterDataController struct { db *gorm.DB redis *redis.Client clients map[string]config.SSOClientConfig drift time.Duration nonceTTL time.Duration localNonce sync.Map } type masterArea struct { ID uint `json:"id"` Name string `json:"name"` } type masterLocation struct { ID uint `json:"id"` Name string `json:"name"` AreaID uint `json:"area_id"` } func NewMasterDataController(db *gorm.DB, redis *redis.Client, clients map[string]config.SSOClientConfig) *MasterDataController { normalized := make(map[string]config.SSOClientConfig, len(clients)) for alias, cfg := range clients { alias = strings.ToLower(strings.TrimSpace(alias)) normalized[alias] = cfg } drift := config.SSOUserSyncDrift if drift <= 0 { drift = 2 * time.Minute } nonceTTL := config.SSOUserSyncNonceTTL if nonceTTL <= 0 { nonceTTL = 10 * time.Minute } return &MasterDataController{ db: db, redis: redis, clients: normalized, drift: drift, nonceTTL: nonceTTL, } } func (h *MasterDataController) GetAreas(c *fiber.Ctx) error { if _, _, err := h.authenticate(c, nil); err != nil { return err } search := strings.TrimSpace(c.Query("search", "")) ids := parseUintList(c.Query("ids", "")) query := h.db.WithContext(c.Context()). Model(&entity.Area{}). Where("deleted_at IS NULL") if search != "" { query = query.Where("name ILIKE ?", "%"+search+"%") } if len(ids) > 0 { query = query.Where("id IN ?", ids) } var areas []masterArea if err := query.Order("name ASC").Find(&areas).Error; err != nil { utils.Log.WithError(err).Error("failed to fetch areas for master data") return fiber.NewError(fiber.StatusInternalServerError, "failed to fetch areas") } return c.Status(fiber.StatusOK).JSON(response.Success{ Code: fiber.StatusOK, Status: "success", Message: "Get areas successfully", Data: areas, }) } func (h *MasterDataController) GetLocations(c *fiber.Ctx) error { if _, _, err := h.authenticate(c, nil); err != nil { return err } search := strings.TrimSpace(c.Query("search", "")) areaIDs := parseUintList(c.Query("area_ids", "")) ids := parseUintList(c.Query("ids", "")) query := h.db.WithContext(c.Context()). Model(&entity.Location{}). Where("deleted_at IS NULL") if search != "" { query = query.Where("name ILIKE ?", "%"+search+"%") } if len(areaIDs) > 0 { query = query.Where("area_id IN ?", areaIDs) } if len(ids) > 0 { query = query.Where("id IN ?", ids) } var locations []masterLocation if err := query.Order("name ASC").Find(&locations).Error; err != nil { utils.Log.WithError(err).Error("failed to fetch locations for master data") return fiber.NewError(fiber.StatusInternalServerError, "failed to fetch locations") } return c.Status(fiber.StatusOK).JSON(response.Success{ Code: fiber.StatusOK, Status: "success", Message: "Get locations successfully", Data: locations, }) } func (h *MasterDataController) authenticate(c *fiber.Ctx, body []byte) (string, config.SSOClientConfig, error) { rawAlias := strings.TrimSpace(c.Get("X-Sync-Client")) if rawAlias == "" { return "", config.SSOClientConfig{}, fiber.NewError(fiber.StatusUnauthorized, "missing sync client header") } aliasKey := strings.ToLower(rawAlias) clientCfg, ok := h.clients[aliasKey] if !ok { return "", config.SSOClientConfig{}, fiber.NewError(fiber.StatusUnauthorized, "unknown sync client") } if err := h.verifyAuthorization(c, aliasKey); err != nil { return "", config.SSOClientConfig{}, err } secret := strings.TrimSpace(clientCfg.SyncSecret) if secret == "" { return "", config.SSOClientConfig{}, fiber.NewError(fiber.StatusUnauthorized, "sync secret not configured") } timestamp := strings.TrimSpace(c.Get("X-Sync-Timestamp")) nonce := strings.TrimSpace(c.Get("X-Sync-Nonce")) signature := strings.TrimSpace(c.Get("X-Sync-Signature")) if timestamp == "" || nonce == "" || signature == "" { return "", config.SSOClientConfig{}, fiber.NewError(fiber.StatusUnauthorized, "missing signature headers") } if len(nonce) < 16 { return "", config.SSOClientConfig{}, fiber.NewError(fiber.StatusUnauthorized, "nonce too short") } ts, err := strconv.ParseInt(timestamp, 10, 64) if err != nil { return "", config.SSOClientConfig{}, fiber.NewError(fiber.StatusBadRequest, "invalid timestamp") } msgTime := time.Unix(ts, 0).UTC() now := time.Now().UTC() drift := now.Sub(msgTime) if drift > h.drift || drift < -h.drift { return "", config.SSOClientConfig{}, fiber.NewError(fiber.StatusUnauthorized, "timestamp outside allowed window") } providedSig, err := decodeMasterSignature(signature) if err != nil { return "", config.SSOClientConfig{}, fiber.NewError(fiber.StatusUnauthorized, "invalid signature encoding") } expectedSignature := calculateSignature(secret, rawAlias, timestamp, nonce, body) if !hmac.Equal(providedSig, expectedSignature) { return "", config.SSOClientConfig{}, fiber.NewError(fiber.StatusUnauthorized, "invalid signature") } if err := h.registerNonce(c.Context(), aliasKey, nonce); err != nil { return "", config.SSOClientConfig{}, err } return aliasKey, clientCfg, nil } func (h *MasterDataController) verifyAuthorization(c *fiber.Ctx, alias string) error { authHeader := strings.TrimSpace(c.Get(fiber.HeaderAuthorization)) if authHeader == "" { return fiber.NewError(fiber.StatusUnauthorized, "missing authorization header") } parts := strings.SplitN(authHeader, " ", 2) if len(parts) != 2 || !strings.EqualFold(parts[0], "Bearer") { return fiber.NewError(fiber.StatusUnauthorized, "invalid authorization header") } token := strings.TrimSpace(parts[1]) if token == "" { return fiber.NewError(fiber.StatusUnauthorized, "invalid authorization header") } verification, err := sso.VerifyAccessToken(token) if err != nil { return fiber.NewError(fiber.StatusUnauthorized, "invalid access token") } if verification.ServiceAlias == "" || verification.ServiceAlias != alias { return fiber.NewError(fiber.StatusUnauthorized, "service subject mismatch") } if !hasAnyScope(verification.Claims.Scopes(), []string{"sync.master", "sync.users"}) { return fiber.NewError(fiber.StatusForbidden, "missing sync scope") } return nil } func (h *MasterDataController) registerNonce(ctx context.Context, alias, nonce string) error { ttl := h.nonceTTL if ttl <= 0 { ttl = 10 * time.Minute } key := fmt.Sprintf("sso:sync:%s:%s", alias, nonce) if h.redis != nil { stored, err := h.redis.SetNX(ctx, key, "1", ttl).Result() if err == nil { if !stored { return fiber.NewError(fiber.StatusUnauthorized, "nonce already used") } return nil } utils.Log.WithError(err).Warn("store sync nonce failed") } now := time.Now().UTC() if expRaw, ok := h.localNonce.Load(key); ok { if expTime, ok := expRaw.(time.Time); ok && expTime.After(now) { return fiber.NewError(fiber.StatusUnauthorized, "nonce already used") } } h.localNonce.Store(key, now.Add(ttl)) return nil } func calculateSignature(secret, alias, timestamp, nonce string, body []byte) []byte { mac := hmac.New(sha256.New, []byte(secret)) mac.Write([]byte(alias)) mac.Write([]byte("\n")) mac.Write([]byte(timestamp)) mac.Write([]byte("\n")) mac.Write([]byte(nonce)) mac.Write([]byte("\n")) if len(body) > 0 { mac.Write(body) } return mac.Sum(nil) } func decodeMasterSignature(sig string) ([]byte, error) { sig = strings.TrimSpace(sig) if sig == "" { return nil, errors.New("empty signature") } if decoded, err := hex.DecodeString(sig); err == nil { return decoded, nil } if decoded, err := base64.StdEncoding.DecodeString(sig); err == nil { return decoded, nil } if decoded, err := base64.URLEncoding.DecodeString(sig); err == nil { return decoded, nil } return nil, errors.New("unrecognized signature encoding") } func parseUintList(raw string) []uint { raw = strings.TrimSpace(raw) if raw == "" { return nil } parts := strings.Split(raw, ",") out := make([]uint, 0, len(parts)) seen := make(map[uint]struct{}, len(parts)) for _, part := range parts { part = strings.TrimSpace(part) if part == "" { continue } val, err := strconv.ParseUint(part, 10, 64) if err != nil || val == 0 { continue } if _, ok := seen[uint(val)]; ok { continue } seen[uint(val)] = struct{}{} out = append(out, uint(val)) } return out } func hasAnyScope(scopes []string, targets []string) bool { if len(scopes) == 0 || len(targets) == 0 { return false } for _, scope := range scopes { scope = strings.ToLower(strings.TrimSpace(scope)) if scope == "" { continue } for _, target := range targets { if scope == strings.ToLower(strings.TrimSpace(target)) { return true } } } return false }