mirror of
https://gitlab.com/mbugroup/lti-api.git
synced 2026-05-20 13:31:56 +00:00
392 lines
12 KiB
Go
392 lines
12 KiB
Go
package controllers
|
|
|
|
import (
|
|
"context"
|
|
"crypto/hmac"
|
|
"crypto/sha256"
|
|
"encoding/base64"
|
|
"encoding/hex"
|
|
"encoding/json"
|
|
"errors"
|
|
"fmt"
|
|
"strconv"
|
|
"strings"
|
|
"sync"
|
|
"time"
|
|
"github.com/go-playground/validator/v10"
|
|
"github.com/gofiber/fiber/v2"
|
|
"github.com/redis/go-redis/v9"
|
|
"github.com/sirupsen/logrus"
|
|
"gorm.io/gorm"
|
|
|
|
"gitlab.com/mbugroup/lti-api.git/internal/config"
|
|
"gitlab.com/mbugroup/lti-api.git/internal/modules/users/dto"
|
|
entity "gitlab.com/mbugroup/lti-api.git/internal/entities"
|
|
userRepository "gitlab.com/mbugroup/lti-api.git/internal/modules/users/repositories"
|
|
"gitlab.com/mbugroup/lti-api.git/internal/response"
|
|
"gitlab.com/mbugroup/lti-api.git/internal/sso"
|
|
"gitlab.com/mbugroup/lti-api.git/internal/utils"
|
|
)
|
|
|
|
|
|
// Request headers carrying the sync client alias and the HMAC signature material.
const (
	headerClient    = "X-Sync-Client"
	headerTimestamp = "X-Sync-Timestamp"
	headerNonce     = "X-Sync-Nonce"
	headerSignature = "X-Sync-Signature"
)

// Fallbacks used when the corresponding configuration values are not positive.
const (
	// defaultDrift bounds the accepted skew between X-Sync-Timestamp and the server clock.
	defaultDrift = 2 * time.Minute
	// defaultNonceTTL is how long a used nonce is remembered to reject replays.
	defaultNonceTTL = 10 * time.Minute
)
|
|
|
|
// UserSyncController handles incoming user management events from the central SSO service.
//
// Each request must carry an SSO bearer token with the "sync.users" scope and
// an HMAC-SHA256 signature (keyed by the client's SyncSecret) over the client
// alias, timestamp, nonce and raw body; nonces are remembered to reject replays.
type UserSyncController struct {
	// validate checks the decoded userSyncRequest and the email format.
	validate *validator.Validate
	// repo persists users (upsert, fetch and soft delete by SSO user ID).
	repo userRepository.UserRepository
	// redis is the shared nonce store; nil selects the in-memory fallback.
	redis *redis.Client
	// clients maps trimmed, lower-cased client aliases to their configuration.
	clients map[string]config.SSOClientConfig
	// drift is the maximum allowed difference, in either direction, between
	// the request timestamp and the server clock.
	drift time.Duration
	// nonceTTL is how long a used nonce is remembered.
	nonceTTL time.Duration
	// maxBodyBytes caps the request body size; values <= 0 disable the check.
	maxBodyBytes int
	log *logrus.Logger
	// localNonces maps "sso:sync:<alias>:<nonce>" keys to their expiry
	// (time.Time); used when redis is nil or a Redis write fails.
	localNonces sync.Map
}
|
|
|
|
// userSyncRequest is the JSON payload of a sync event. Sync lower-cases
// Action and trims the string fields before validating it.
type userSyncRequest struct {
	// Action is one of "create", "update" or "delete".
	Action string `json:"action" validate:"required,oneof=create update delete"`
	// PublicID must equal the client's configured PublicID when one is set.
	PublicID string `json:"public_id" validate:"required"`
	User userSyncUser `json:"user" validate:"required"`
}
|
|
|
|
// userSyncUser identifies the user being synced. Sync requires Email and Name
// only for create/update actions, and requires ID to be positive.
type userSyncUser struct {
	// ID is the SSO user ID; it is stored locally as entity.User.IdUser.
	ID int64 `json:"id" validate:"required"`
	Email string `json:"email"`
	Name string `json:"name"`
}
|
|
|
|
func NewUserSyncController(validate *validator.Validate, repo userRepository.UserRepository, redis *redis.Client, clients map[string]config.SSOClientConfig) *UserSyncController {
|
|
normalized := make(map[string]config.SSOClientConfig, len(clients))
|
|
for alias, cfg := range clients {
|
|
alias = strings.ToLower(strings.TrimSpace(alias))
|
|
normalized[alias] = cfg
|
|
}
|
|
|
|
drift := config.SSOUserSyncDrift
|
|
if drift <= 0 {
|
|
drift = defaultDrift
|
|
}
|
|
|
|
nonceTTL := config.SSOUserSyncNonceTTL
|
|
if nonceTTL <= 0 {
|
|
nonceTTL = defaultNonceTTL
|
|
}
|
|
|
|
maxBody := config.SSOUserSyncMaxBodyBytes
|
|
if maxBody <= 0 {
|
|
maxBody = 32 * 1024
|
|
}
|
|
|
|
log := utils.Log
|
|
if redis == nil {
|
|
log.Warn("SSO user sync nonce store fallback to in-memory cache; enable Redis for replay protection")
|
|
}
|
|
|
|
return &UserSyncController{
|
|
validate: validate,
|
|
repo: repo,
|
|
redis: redis,
|
|
clients: normalized,
|
|
drift: drift,
|
|
nonceTTL: nonceTTL,
|
|
maxBodyBytes: maxBody,
|
|
log: log,
|
|
}
|
|
}
|
|
|
|
func (h *UserSyncController) Sync(c *fiber.Ctx) error {
|
|
if ct := strings.TrimSpace(c.Get(fiber.HeaderContentType)); ct != "" && !strings.HasPrefix(strings.ToLower(ct), fiber.MIMEApplicationJSON) {
|
|
return fiber.NewError(fiber.StatusUnsupportedMediaType, "content-type must be application/json")
|
|
}
|
|
|
|
body := c.Body()
|
|
if h.maxBodyBytes > 0 && len(body) > h.maxBodyBytes {
|
|
return fiber.NewError(fiber.StatusRequestEntityTooLarge, "request body too large")
|
|
}
|
|
|
|
alias, clientCfg, err := h.authenticate(c, body)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
req := new(userSyncRequest)
|
|
if err := json.Unmarshal(body, req); err != nil {
|
|
return fiber.NewError(fiber.StatusBadRequest, "invalid request body")
|
|
}
|
|
|
|
req.Action = strings.ToLower(strings.TrimSpace(req.Action))
|
|
req.PublicID = strings.TrimSpace(req.PublicID)
|
|
req.User.Email = strings.TrimSpace(req.User.Email)
|
|
req.User.Name = strings.TrimSpace(req.User.Name)
|
|
|
|
if err := h.validate.Struct(req); err != nil {
|
|
return fiber.NewError(fiber.StatusBadRequest, err.Error())
|
|
}
|
|
|
|
if clientCfg.PublicID != "" && req.PublicID != clientCfg.PublicID {
|
|
return fiber.NewError(fiber.StatusBadRequest, "public_id mismatch with configured client")
|
|
}
|
|
|
|
if req.Action != "delete" {
|
|
if req.User.Email == "" || req.User.Name == "" {
|
|
return fiber.NewError(fiber.StatusBadRequest, "email and name are required for create/update actions")
|
|
}
|
|
if err := h.validate.Var(req.User.Email, "email"); err != nil {
|
|
return fiber.NewError(fiber.StatusBadRequest, "invalid email format")
|
|
}
|
|
}
|
|
|
|
if req.User.ID <= 0 {
|
|
return fiber.NewError(fiber.StatusBadRequest, "invalid user id")
|
|
}
|
|
|
|
switch req.Action {
|
|
case "create", "update":
|
|
return h.upsertUser(c, alias, req)
|
|
case "delete":
|
|
return h.removeUser(c, alias, req)
|
|
default:
|
|
return fiber.NewError(fiber.StatusBadRequest, "unsupported action")
|
|
}
|
|
}
|
|
|
|
func (h *UserSyncController) authenticate(c *fiber.Ctx, body []byte) (string, config.SSOClientConfig, error) {
|
|
rawAlias := strings.TrimSpace(c.Get(headerClient))
|
|
if rawAlias == "" {
|
|
return "", config.SSOClientConfig{}, fiber.NewError(fiber.StatusUnauthorized, "missing sync client header")
|
|
}
|
|
|
|
aliasKey := strings.ToLower(rawAlias)
|
|
clientCfg, ok := h.clients[aliasKey]
|
|
if !ok {
|
|
return "", config.SSOClientConfig{}, fiber.NewError(fiber.StatusUnauthorized, "unknown sync client")
|
|
}
|
|
|
|
if err := h.verifyAuthorization(c, aliasKey); err != nil {
|
|
return "", config.SSOClientConfig{}, err
|
|
}
|
|
|
|
secret := strings.TrimSpace(clientCfg.SyncSecret)
|
|
if secret == "" {
|
|
return "", config.SSOClientConfig{}, fiber.NewError(fiber.StatusUnauthorized, "sync secret not configured")
|
|
}
|
|
|
|
timestamp := strings.TrimSpace(c.Get(headerTimestamp))
|
|
nonce := strings.TrimSpace(c.Get(headerNonce))
|
|
signature := strings.TrimSpace(c.Get(headerSignature))
|
|
|
|
if timestamp == "" || nonce == "" || signature == "" {
|
|
return "", config.SSOClientConfig{}, fiber.NewError(fiber.StatusUnauthorized, "missing signature headers")
|
|
}
|
|
if len(nonce) < 16 {
|
|
return "", config.SSOClientConfig{}, fiber.NewError(fiber.StatusUnauthorized, "nonce too short")
|
|
}
|
|
|
|
ts, err := strconv.ParseInt(timestamp, 10, 64)
|
|
if err != nil {
|
|
return "", config.SSOClientConfig{}, fiber.NewError(fiber.StatusBadRequest, "invalid timestamp")
|
|
}
|
|
|
|
msgTime := time.Unix(ts, 0).UTC()
|
|
now := time.Now().UTC()
|
|
drift := now.Sub(msgTime)
|
|
if drift > h.drift || drift < -h.drift {
|
|
return "", config.SSOClientConfig{}, fiber.NewError(fiber.StatusUnauthorized, "timestamp outside allowed window")
|
|
}
|
|
|
|
providedSig, err := decodeSignature(signature)
|
|
if err != nil {
|
|
return "", config.SSOClientConfig{}, fiber.NewError(fiber.StatusUnauthorized, "invalid signature encoding")
|
|
}
|
|
|
|
expectedSignature := h.calculateSignature(secret, rawAlias, timestamp, nonce, body)
|
|
if !hmac.Equal(providedSig, expectedSignature) {
|
|
return "", config.SSOClientConfig{}, fiber.NewError(fiber.StatusUnauthorized, "invalid signature")
|
|
}
|
|
|
|
if err := h.registerNonce(c.Context(), aliasKey, nonce); err != nil {
|
|
return "", config.SSOClientConfig{}, err
|
|
}
|
|
|
|
return aliasKey, clientCfg, nil
|
|
}
|
|
|
|
func (h *UserSyncController) verifyAuthorization(c *fiber.Ctx, alias string) error {
|
|
authHeader := strings.TrimSpace(c.Get(fiber.HeaderAuthorization))
|
|
if authHeader == "" {
|
|
return fiber.NewError(fiber.StatusUnauthorized, "missing authorization header")
|
|
}
|
|
|
|
parts := strings.SplitN(authHeader, " ", 2)
|
|
if len(parts) != 2 || !strings.EqualFold(parts[0], "Bearer") {
|
|
return fiber.NewError(fiber.StatusUnauthorized, "invalid authorization header")
|
|
}
|
|
|
|
token := strings.TrimSpace(parts[1])
|
|
if token == "" {
|
|
return fiber.NewError(fiber.StatusUnauthorized, "invalid authorization header")
|
|
}
|
|
|
|
verification, err := sso.VerifyAccessToken(token)
|
|
if err != nil {
|
|
return fiber.NewError(fiber.StatusUnauthorized, "invalid access token")
|
|
}
|
|
|
|
if verification.ServiceAlias == "" || verification.ServiceAlias != alias {
|
|
return fiber.NewError(fiber.StatusUnauthorized, "service subject mismatch")
|
|
}
|
|
|
|
if !containsScope(verification.Claims.Scopes(), "sync.users") {
|
|
return fiber.NewError(fiber.StatusForbidden, "missing sync scope")
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (h *UserSyncController) upsertUser(c *fiber.Ctx, alias string, req *userSyncRequest) error {
|
|
entity := &entity.User{
|
|
IdUser: req.User.ID,
|
|
Email: req.User.Email,
|
|
Name: req.User.Name,
|
|
}
|
|
|
|
//TODO: MIGRATION TO UPSERT BASE REPOSITORY
|
|
if err := h.repo.UpsertByIdUser(c.Context(), entity); err != nil {
|
|
h.log.Errorf("sso user upsert failed: %v", err)
|
|
return fiber.NewError(fiber.StatusInternalServerError, "failed to upsert user")
|
|
}
|
|
|
|
user, err := h.repo.GetByIdUser(c.Context(), req.User.ID, nil)
|
|
if err != nil {
|
|
h.log.Errorf("sso user fetch after upsert failed: %v", err)
|
|
return fiber.NewError(fiber.StatusInternalServerError, "failed to load user")
|
|
}
|
|
|
|
h.log.WithFields(logrus.Fields{
|
|
"action": req.Action,
|
|
"public_id": req.PublicID,
|
|
"alias": alias,
|
|
"user_id": req.User.ID,
|
|
}).Info("sso user synced")
|
|
|
|
msg := fmt.Sprintf("User %s successfully", req.Action)
|
|
return c.Status(fiber.StatusOK).JSON(response.Success{
|
|
Code: fiber.StatusOK,
|
|
Status: "success",
|
|
Message: msg,
|
|
Data: dto.ToUserListDTO(*user),
|
|
})
|
|
}
|
|
|
|
func (h *UserSyncController) removeUser(c *fiber.Ctx, alias string, req *userSyncRequest) error {
|
|
if err := h.repo.SoftDeleteByIdUser(c.Context(), req.User.ID); err != nil {
|
|
if err == gorm.ErrRecordNotFound {
|
|
return fiber.NewError(fiber.StatusNotFound, "user not found")
|
|
}
|
|
h.log.Errorf("sso user delete failed: %v", err)
|
|
return fiber.NewError(fiber.StatusInternalServerError, "failed to delete user")
|
|
}
|
|
|
|
h.log.WithFields(logrus.Fields{
|
|
"action": req.Action,
|
|
"public_id": req.PublicID,
|
|
"alias": alias,
|
|
"user_id": req.User.ID,
|
|
}).Info("sso user deleted")
|
|
|
|
return c.Status(fiber.StatusOK).JSON(response.Common{
|
|
Code: fiber.StatusOK,
|
|
Status: "success",
|
|
Message: "User deleted successfully",
|
|
})
|
|
}
|
|
|
|
func (h *UserSyncController) registerNonce(ctx context.Context, alias, nonce string) error {
|
|
ttl := h.nonceTTL
|
|
if ttl <= 0 {
|
|
ttl = defaultNonceTTL
|
|
}
|
|
|
|
key := fmt.Sprintf("sso:sync:%s:%s", alias, nonce)
|
|
if h.redis != nil {
|
|
stored, err := h.redis.SetNX(ctx, key, "1", ttl).Result()
|
|
if err == nil {
|
|
if !stored {
|
|
return fiber.NewError(fiber.StatusUnauthorized, "nonce already used")
|
|
}
|
|
return nil
|
|
}
|
|
h.log.Errorf("store sync nonce failed: %v", err)
|
|
}
|
|
|
|
now := time.Now().UTC()
|
|
if expRaw, ok := h.localNonces.Load(key); ok {
|
|
if expTime, ok := expRaw.(time.Time); ok && expTime.After(now) {
|
|
return fiber.NewError(fiber.StatusUnauthorized, "nonce already used")
|
|
}
|
|
}
|
|
h.localNonces.Store(key, now.Add(ttl))
|
|
h.pruneLocalNonces(now)
|
|
return nil
|
|
}
|
|
|
|
func (h *UserSyncController) calculateSignature(secret, alias, timestamp, nonce string, body []byte) []byte {
|
|
mac := hmac.New(sha256.New, []byte(secret))
|
|
mac.Write([]byte(alias))
|
|
mac.Write([]byte("\n"))
|
|
mac.Write([]byte(timestamp))
|
|
mac.Write([]byte("\n"))
|
|
mac.Write([]byte(nonce))
|
|
mac.Write([]byte("\n"))
|
|
mac.Write(body)
|
|
return mac.Sum(nil)
|
|
}
|
|
|
|
// containsScope reports whether target is among scopes. Both sides are
// trimmed and lower-cased before comparing; a blank target never matches.
func containsScope(scopes []string, target string) bool {
	normalize := func(s string) string {
		return strings.ToLower(strings.TrimSpace(s))
	}

	want := normalize(target)
	if want == "" {
		return false
	}
	for i := range scopes {
		if normalize(scopes[i]) == want {
			return true
		}
	}
	return false
}
|
|
|
|
// decodeSignature decodes an X-Sync-Signature header value into raw bytes.
//
// Surrounding whitespace is ignored. Encodings are tried in this order:
//   - hex;
//   - padded standard base64;
//   - padded URL-safe base64;
//   - unpadded ("raw") standard base64;
//   - unpadded URL-safe base64.
//
// A string valid in several encodings takes the first match, so hex wins.
// Returns an error for a blank value or when no encoding accepts it.
func decodeSignature(sig string) ([]byte, error) {
	sig = strings.TrimSpace(sig)
	if sig == "" {
		return nil, errors.New("empty signature")
	}
	if decoded, err := hex.DecodeString(sig); err == nil {
		return decoded, nil
	}
	// The unpadded variants come last, so every input accepted before still
	// decodes the same way. They add support for signers that strip '='
	// padding, which is common for base64url.
	for _, enc := range []*base64.Encoding{
		base64.StdEncoding,
		base64.URLEncoding,
		base64.RawStdEncoding,
		base64.RawURLEncoding,
	} {
		if decoded, err := enc.DecodeString(sig); err == nil {
			return decoded, nil
		}
	}
	return nil, errors.New("unrecognized signature encoding")
}
|
|
|
|
func (h *UserSyncController) pruneLocalNonces(now time.Time) {
|
|
h.localNonces.Range(func(key, value any) bool {
|
|
exp, ok := value.(time.Time)
|
|
if !ok || exp.Before(now) {
|
|
h.localNonces.Delete(key)
|
|
}
|
|
return true
|
|
})
|
|
}
|