From 6bddbbf9d92c465ed0282aab965c7480bb4ecabc Mon Sep 17 00:00:00 2001 From: ragilap Date: Mon, 6 Oct 2025 12:31:54 +0700 Subject: [PATCH] feat/login crud in users sync with sso --- .env.example | 20 + cmd/api/main.go | 10 + internal/cache/cache.go | 38 ++ internal/common/repository/repository.go | 16 + internal/config/config.go | 176 +++++++- ...20250925040409_create_master_tables.up.sql | 1 + internal/middleware/auth.go | 35 +- internal/middleware/limiter.go | 21 + .../modules/sso/controllers/sso.controller.go | 404 ++++++++++++++++++ .../sso/controllers/user_sync.controller.go | 391 +++++++++++++++++ internal/modules/sso/module.go | 13 + internal/modules/sso/route.go | 35 ++ internal/modules/sso/session/store.go | 70 +++ .../users/controllers/user.controller.go | 110 ++--- .../users/repositories/user.repository.go | 51 ++- internal/modules/users/route.go | 12 +- .../modules/users/services/user.service.go | 13 + internal/route/route.go | 6 +- internal/sso/verifier.go | 161 +++++++ internal/utils/secure/random.go | 84 ++++ internal/utils/verify.go | 45 -- 21 files changed, 1576 insertions(+), 136 deletions(-) create mode 100644 internal/cache/cache.go create mode 100644 internal/modules/sso/controllers/sso.controller.go create mode 100644 internal/modules/sso/controllers/user_sync.controller.go create mode 100644 internal/modules/sso/module.go create mode 100644 internal/modules/sso/route.go create mode 100644 internal/modules/sso/session/store.go create mode 100644 internal/sso/verifier.go create mode 100644 internal/utils/secure/random.go delete mode 100644 internal/utils/verify.go diff --git a/.env.example b/.env.example index 02810734..1ba4e23b 100644 --- a/.env.example +++ b/.env.example @@ -32,3 +32,23 @@ CORS_MAX_AGE=600 # Redis REDIS_URL=redis://redis:6379/0 REDIS_PORT_HOST=6381 + +# SSO Integration +SSO_ISSUER=http://localhost:8080/api +SSO_JWKS_URL=http://localhost:8080/api/.well-known/jwks.json +SSO_ALLOWED_AUDIENCES=client:lti-api +SSO_AUTHORIZE_URL=http://localhost:8080/sso/authorize +SSO_TOKEN_URL=http://localhost:8080/sso/token +SSO_GETME_URL=http://localhost:8080/api/auth/get-me +SSO_ACCESS_COOKIE_NAME=sso_access +SSO_REFRESH_COOKIE_NAME=sso_refresh +SSO_COOKIE_DOMAIN= +SSO_COOKIE_SECURE=false +SSO_COOKIE_SAMESITE=Lax +SSO_PKCE_TTL_SECONDS=300 +# Security window and payload limits for SSO user sync webhook +SSO_USER_SYNC_SIGNATURE_DRIFT_SECONDS=120 +SSO_USER_SYNC_NONCE_TTL_SECONDS=600 +SSO_USER_SYNC_MAX_BODY_BYTES=32768 +# Example JSON (single-line) of client configs (each client requires a unique sync_secret) +SSO_CLIENTS={"lti":{"public_id":"client:lti","redirect_uri":"http://localhost:8081/api/sso/callback","scope":"openid profile","default_return_uri":"http://localhost:3000","allowed_return_origins":["http://localhost:3000"],"sync_secret":"changeme"}} diff --git a/cmd/api/main.go b/cmd/api/main.go index 0bcbaa86..2c120aa9 100644 --- a/cmd/api/main.go +++ b/cmd/api/main.go @@ -9,10 +9,12 @@ import ( "syscall" "time" + "gitlab.com/mbugroup/lti-api.git/internal/cache" "gitlab.com/mbugroup/lti-api.git/internal/config" "gitlab.com/mbugroup/lti-api.git/internal/database" "gitlab.com/mbugroup/lti-api.git/internal/middleware" "gitlab.com/mbugroup/lti-api.git/internal/route" + "gitlab.com/mbugroup/lti-api.git/internal/sso" "gitlab.com/mbugroup/lti-api.git/internal/utils" "github.com/gofiber/fiber/v2" @@ -33,6 +35,7 @@ func main() { defer closeDatabase(db) rdb := setupRedis() defer rdb.Close() + setupSSO(ctx) setupRoutes(app, db, rdb) address := fmt.Sprintf("%s:%d", config.AppHost, config.AppPort) @@ -52,10 +55,17 @@ func setupRedis() *redis.Client { if err := rdb.Ping(context.Background()).Err(); err != nil { utils.Log.Fatalf("Redis ping failed: %v", err) } + cache.SetRedis(rdb) utils.Log.Infof("Redis connected: %s", config.RedisURL) return rdb } +func setupSSO(ctx context.Context) { + if err := sso.Init(ctx, config.SSOJWKSURL, config.SSOIssuer, config.SSOAllowedAudiences); err != nil { + utils.Log.Fatalf("SSO initialization failed: %v", err) + } +} + func setupFiberApp() *fiber.App { app := fiber.New(config.FiberConfig()) diff --git a/internal/cache/cache.go b/internal/cache/cache.go new file mode 100644 index 00000000..9474d3d7 --- /dev/null +++ b/internal/cache/cache.go @@ -0,0 +1,38 @@ +package cache + +import ( + "errors" + "sync" + + "github.com/redis/go-redis/v9" +) + +var ( + redisClient *redis.Client + mu sync.RWMutex +) + +// SetRedis assigns the global redis client used across the application. +func SetRedis(client *redis.Client) { + mu.Lock() + defer mu.Unlock() + redisClient = client +} + +// Redis returns the configured redis client. It may be nil if not yet initialised. +func Redis() *redis.Client { + mu.RLock() + defer mu.RUnlock() + return redisClient +} + +// MustRedis returns the redis client or panics if it has not been set. +func MustRedis() *redis.Client { + mu.RLock() + client := redisClient + mu.RUnlock() + if client == nil { + panic(errors.New("redis client not initialised")) + } + return client +} diff --git a/internal/common/repository/repository.go b/internal/common/repository/repository.go index 6605f95f..fa58fcd7 100644 --- a/internal/common/repository/repository.go +++ b/internal/common/repository/repository.go @@ -11,6 +11,7 @@ type BaseRepository[T any] interface { GetAll(ctx context.Context, offset, limit int, modifier func(*gorm.DB) *gorm.DB) ([]T, int64, error) GetByID(ctx context.Context, id uint, modifier func(*gorm.DB) *gorm.DB) (*T, error) GetByIDs(ctx context.Context, ids []uint, modifier func(*gorm.DB) *gorm.DB) ([]T, error) + First(ctx context.Context, modifier func(*gorm.DB) *gorm.DB) (*T, error) CreateOne(ctx context.Context, entity *T, modifier func(*gorm.DB) *gorm.DB) error CreateMany(ctx context.Context, entities []*T, modifier func(*gorm.DB) *gorm.DB) error @@ -96,6 +97,21 @@ func (r *BaseRepositoryImpl[T]) GetByIDs( return entities, nil } +func (r *BaseRepositoryImpl[T]) First( + ctx context.Context, + modifier func(*gorm.DB) *gorm.DB, +) (*T, error) { + entity := new(T) + q := r.db.WithContext(ctx) + if modifier != nil { + q = modifier(q) + } + if err := q.First(entity).Error; err != nil { + return nil, err + } + return entity, nil +} + // ---- CREATE ---- func (r *BaseRepositoryImpl[T]) CreateOne( ctx context.Context, diff --git a/internal/config/config.go b/internal/config/config.go index 2cd4987e..ce17722f 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -2,36 +2,64 @@ package config import ( "encoding/json" + "fmt" "strings" + "time" "gitlab.com/mbugroup/lti-api.git/internal/utils" "github.com/spf13/viper" ) +type SSOClientConfig struct { + PublicID string `json:"public_id"` + RedirectURI string `json:"redirect_uri"` + Scope string `json:"scope"` + Prompt string `json:"prompt"` + DefaultReturnURI string `json:"default_return_uri"` + AllowedReturnOrigins []string `json:"allowed_return_origins"` + SyncSecret string `json:"sync_secret"` +} + var ( - IsProd bool - AppHost string - Version string - LogLevel string - AppPort int - DBHost string - DBUser string - DBPassword string - DBName string - DBPort int - JWTSecret string - JWTAccessExp int - JWTRefreshExp int - JWTResetPasswordExp int - JWTVerifyEmailExp int - RedisURL string - CORSAllowOrigins []string - CORSAllowMethods []string - CORSAllowHeaders []string - CORSExposeHeaders []string - CORSAllowCredentials bool - CORSMaxAge int + IsProd bool + AppHost string + Version string + LogLevel string + AppPort int + DBHost string + DBUser string + DBPassword string + DBName string + DBPort int + JWTSecret string + JWTAccessExp int + JWTRefreshExp int + JWTResetPasswordExp int + JWTVerifyEmailExp int + RedisURL string + CORSAllowOrigins []string + CORSAllowMethods []string + CORSAllowHeaders []string + CORSExposeHeaders []string + CORSAllowCredentials bool + CORSMaxAge int + SSOIssuer string + SSOJWKSURL string + SSOAllowedAudiences []string + SSOAuthorizeURL string + SSOTokenURL string + SSOGetMeURL string + SSOClients map[string]SSOClientConfig + SSOAccessCookieName string + SSORefreshCookieName string + SSOCookieDomain string + SSOCookieSecure bool + SSOCookieSameSite string + SSOPKCETTL time.Duration + SSOUserSyncDrift time.Duration + SSOUserSyncNonceTTL time.Duration + SSOUserSyncMaxBodyBytes int ) func init() { @@ -68,6 +96,43 @@ func init() { // Redis RedisURL = viper.GetString("REDIS_URL") + + // SSO integration + SSOIssuer = viper.GetString("SSO_ISSUER") + SSOJWKSURL = viper.GetString("SSO_JWKS_URL") + SSOAllowedAudiences = parseList("SSO_ALLOWED_AUDIENCES") + SSOAuthorizeURL = viper.GetString("SSO_AUTHORIZE_URL") + SSOTokenURL = viper.GetString("SSO_TOKEN_URL") + SSOGetMeURL = viper.GetString("SSO_GETME_URL") + SSOAccessCookieName = defaultString(viper.GetString("SSO_ACCESS_COOKIE_NAME"), "sso_access") + SSORefreshCookieName = defaultString(viper.GetString("SSO_REFRESH_COOKIE_NAME"), "sso_refresh") + SSOCookieDomain = viper.GetString("SSO_COOKIE_DOMAIN") + SSOCookieSecure = viper.GetBool("SSO_COOKIE_SECURE") + SSOCookieSameSite = defaultString(viper.GetString("SSO_COOKIE_SAMESITE"), "Lax") + if ttl := viper.GetInt("SSO_PKCE_TTL_SECONDS"); ttl > 0 { + SSOPKCETTL = time.Duration(ttl) * time.Second + } else { + SSOPKCETTL = 5 * time.Minute + } + SSOClients = loadSSOClients("SSO_CLIENTS") + if drift := viper.GetInt("SSO_USER_SYNC_SIGNATURE_DRIFT_SECONDS"); drift > 0 { + SSOUserSyncDrift = time.Duration(drift) * time.Second + } else { + SSOUserSyncDrift = 2 * time.Minute + } + if ttl := viper.GetInt("SSO_USER_SYNC_NONCE_TTL_SECONDS"); ttl > 0 { + SSOUserSyncNonceTTL = time.Duration(ttl) * time.Second + } else { + SSOUserSyncNonceTTL = 10 * time.Minute + } + SSOUserSyncMaxBodyBytes = viper.GetInt("SSO_USER_SYNC_MAX_BODY_BYTES") + if SSOUserSyncMaxBodyBytes <= 0 { + SSOUserSyncMaxBodyBytes = 32 * 1024 + } + + if IsProd { + ensureProdConfig() + } } func loadConfig() { @@ -117,3 +182,70 @@ func parseListWithDefault(key, def string) []string { } return parts } + +func loadSSOClients(key string) map[string]SSOClientConfig { + clients := make(map[string]SSOClientConfig) + raw := strings.TrimSpace(viper.GetString(key)) + if raw == "" { + return clients + } + if err := json.Unmarshal([]byte(raw), &clients); err != nil { + utils.Log.Errorf("Failed to parse %s: %v", key, err) + return make(map[string]SSOClientConfig) + } + result := make(map[string]SSOClientConfig, len(clients)) + for alias, cfg := range clients { + alias = strings.ToLower(strings.TrimSpace(alias)) + for i, origin := range cfg.AllowedReturnOrigins { + cfg.AllowedReturnOrigins[i] = strings.TrimSpace(origin) + } + cfg.SyncSecret = strings.TrimSpace(cfg.SyncSecret) + result[alias] = cfg + } + return result +} + +func defaultString(v, def string) string { + if strings.TrimSpace(v) == "" { + return def + } + return v +} + +func ensureProdConfig() { + if SSOAuthorizeURL == "" || !strings.HasPrefix(SSOAuthorizeURL, "https://") { + panic("SSO_AUTHORIZE_URL must be https in production") + } + if SSOTokenURL == "" || !strings.HasPrefix(SSOTokenURL, "https://") { + panic("SSO_TOKEN_URL must be https in production") + } + if SSOGetMeURL == "" || !strings.HasPrefix(SSOGetMeURL, "https://") { + panic("SSO_GETME_URL must be https in production") + } + if !SSOCookieSecure { + panic("SSO_COOKIE_SECURE must be true in production") + } + if SSOCookieDomain == "" { + panic("SSO_COOKIE_DOMAIN must be configured in production") + } + if len(SSOAllowedAudiences) == 0 { + panic("SSO_ALLOWED_AUDIENCES must contain at least one audience in production") + } + for alias, cfg := range SSOClients { + if strings.TrimSpace(cfg.SyncSecret) == "" { + panic(fmt.Sprintf("SSO_CLIENTS[%s].sync_secret must be configured in production", alias)) + } + if len(cfg.SyncSecret) < 16 { + panic(fmt.Sprintf("SSO_CLIENTS[%s].sync_secret must be at least 16 characters", alias)) + } + } + if SSOUserSyncDrift <= 0 { + panic("SSO_USER_SYNC_SIGNATURE_DRIFT_SECONDS must be greater than zero in production") + } + if SSOUserSyncNonceTTL <= 0 { + panic("SSO_USER_SYNC_NONCE_TTL_SECONDS must be greater than zero in production") + } + if SSOUserSyncMaxBodyBytes <= 0 { + panic("SSO_USER_SYNC_MAX_BODY_BYTES must be greater than zero in production") + } +} diff --git a/internal/database/migrations/20250925040409_create_master_tables.up.sql b/internal/database/migrations/20250925040409_create_master_tables.up.sql index 6dcd914a..0725afbf 100644 --- a/internal/database/migrations/20250925040409_create_master_tables.up.sql +++ b/internal/database/migrations/20250925040409_create_master_tables.up.sql @@ -10,6 +10,7 @@ CREATE TABLE users ( ); CREATE UNIQUE INDEX users_id_user_unique ON users (id_user) WHERE deleted_at IS NULL; + CREATE UNIQUE INDEX users_email_unique ON users (email) WHERE deleted_at IS NULL; -- FLAGS diff --git a/internal/middleware/auth.go b/internal/middleware/auth.go index 14a64337..fb959989 100644 --- a/internal/middleware/auth.go +++ b/internal/middleware/auth.go @@ -5,7 +5,7 @@ import ( "gitlab.com/mbugroup/lti-api.git/internal/config" service "gitlab.com/mbugroup/lti-api.git/internal/modules/users/services" - "gitlab.com/mbugroup/lti-api.git/internal/utils" + "gitlab.com/mbugroup/lti-api.git/internal/sso" "github.com/gofiber/fiber/v2" ) @@ -15,21 +15,50 @@ func Auth(userService service.UserService, requiredRights ...string) fiber.Handl authHeader := c.Get("Authorization") token := strings.TrimSpace(strings.TrimPrefix(authHeader, "Bearer ")) + if token == "" { + cookieName := config.SSOAccessCookieName + if cookieName == "" { + cookieName = "access" + } + token = strings.TrimSpace(c.Cookies(cookieName)) + } + if token == "" { return fiber.NewError(fiber.StatusUnauthorized, "Please authenticate") } - userID, err := utils.VerifyToken(token, config.JWTSecret, config.TokenTypeAccess) + verification, err := sso.VerifyAccessToken(token) if err != nil { return fiber.NewError(fiber.StatusUnauthorized, "Please authenticate") } - user, err := userService.GetOne(c, userID) + if len(config.SSOAllowedAudiences) > 0 { + allowed := make(map[string]struct{}, len(config.SSOAllowedAudiences)) + for _, aud := range config.SSOAllowedAudiences { + aud = strings.TrimSpace(aud) + if aud != "" { + allowed[aud] = struct{}{} + } + } + audienceValid := false + for _, aud := range verification.Claims.Audience { + if _, ok := allowed[aud]; ok { + audienceValid = true + break + } + } + if !audienceValid { + return fiber.NewError(fiber.StatusUnauthorized, "invalid audience") + } + } + + user, err := userService.GetBySSOUserID(c, verification.UserID) if err != nil || user == nil { return fiber.NewError(fiber.StatusUnauthorized, "Please authenticate") } c.Locals("user", user) + c.Locals("token_claims", verification.Claims) // if len(requiredRights) > 0 { // userRights, hasRights := config.RoleRights[user.Role] diff --git a/internal/middleware/limiter.go b/internal/middleware/limiter.go index 205facd1..2b9471ce 100644 --- a/internal/middleware/limiter.go +++ b/internal/middleware/limiter.go @@ -24,3 +24,24 @@ func LimiterConfig() fiber.Handler { SkipSuccessfulRequests: true, }) } + +func NewLimiter(max int, expiration time.Duration) fiber.Handler { + if max <= 0 { + max = 10 + } + if expiration <= 0 { + expiration = time.Minute + } + return limiter.New(limiter.Config{ + Max: max, + Expiration: expiration, + LimitReached: func(c *fiber.Ctx) error { + return c.Status(fiber.StatusTooManyRequests). + JSON(response.Common{ + Code: fiber.StatusTooManyRequests, + Status: "error", + Message: "Too many requests, please try again later", + }) + }, + }) +} diff --git a/internal/modules/sso/controllers/sso.controller.go b/internal/modules/sso/controllers/sso.controller.go new file mode 100644 index 00000000..e8b5b6ef --- /dev/null +++ b/internal/modules/sso/controllers/sso.controller.go @@ -0,0 +1,404 @@ +package controllers + +import ( + "context" + "crypto/sha256" + "encoding/json" + "fmt" + "io" + "net/http" + "net/url" + "strings" + "time" + + "github.com/gofiber/fiber/v2" + "github.com/sirupsen/logrus" + + "gitlab.com/mbugroup/lti-api.git/internal/config" + "gitlab.com/mbugroup/lti-api.git/internal/modules/sso/session" + "gitlab.com/mbugroup/lti-api.git/internal/sso" + "gitlab.com/mbugroup/lti-api.git/internal/utils" + "gitlab.com/mbugroup/lti-api.git/internal/utils/secure" +) + +// Controller manages the SSO start & callback flow using PKCE. +type Controller struct { + httpClient *http.Client + store *session.Store +} + +func NewController(client *http.Client, store *session.Store) *Controller { + return &Controller{httpClient: client, store: store} +} + +// Start handles GET /sso/start requests and redirects users to the central SSO authorize endpoint. +func (h *Controller) Start(c *fiber.Ctx) error { + alias := strings.ToLower(strings.TrimSpace(c.Query("client"))) + if alias == "" { + alias = strings.ToLower(strings.TrimSpace(c.Query("client_id"))) + } + if alias == "" { + return fiber.NewError(fiber.StatusBadRequest, "missing client") + } + cfg, ok := config.SSOClients[alias] + if !ok || cfg.PublicID == "" { + return fiber.NewError(fiber.StatusBadRequest, "unknown client") + } + + authorizeEndpoint := strings.TrimSpace(config.SSOAuthorizeURL) + if authorizeEndpoint == "" { + return fiber.NewError(fiber.StatusInternalServerError, "authorize endpoint not configured") + } + + state, err := secure.RandomString(48) + if err != nil { + utils.Log.Errorf("generate state failed: %v", err) + return fiber.NewError(fiber.StatusInternalServerError, "failed to prepare authorization") + } + + nonce, err := secure.RandomString(32) + if err != nil { + utils.Log.Errorf("generate nonce failed: %v", err) + return fiber.NewError(fiber.StatusInternalServerError, "failed to prepare authorization") + } + + codeVerifier, err := secure.PKCECodeVerifier(96) + if err != nil { + utils.Log.Errorf("generate code verifier failed: %v", err) + return fiber.NewError(fiber.StatusInternalServerError, "failed to prepare authorization") + } + + digest := sha256.Sum256([]byte(codeVerifier)) + challenge := secure.Base64URLEncode(digest[:]) + + authorizeURL, err := url.Parse(authorizeEndpoint) + if err != nil { + return fiber.NewError(fiber.StatusInternalServerError, "invalid authorize endpoint") + } + + scope := cfg.Scope + if scope == "" { + scope = "openid profile" + } + if !strings.Contains(" "+scope+" ", " openid ") { + scope = scope + " openid" + } + + rawReturn := strings.TrimSpace(c.Query("return_to")) + if rawReturn == "" { + rawReturn = cfg.DefaultReturnURI + } + + returnTo, err := normalizeReturnTarget(rawReturn, cfg) + if err != nil { + return fiber.NewError(fiber.StatusBadRequest, err.Error()) + } + + query := authorizeURL.Query() + query.Set("response_type", "code") + query.Set("client_id", cfg.PublicID) + query.Set("redirect_uri", cfg.RedirectURI) + query.Set("scope", strings.TrimSpace(scope)) + query.Set("state", state) + query.Set("code_challenge", challenge) + query.Set("code_challenge_method", "S256") + query.Set("nonce", nonce) + if prompt := strings.TrimSpace(cfg.Prompt); prompt != "" { + query.Set("prompt", prompt) + } + if extraPrompt := strings.TrimSpace(c.Query("prompt")); extraPrompt != "" { + query.Set("prompt", extraPrompt) + } + authorizeURL.RawQuery = query.Encode() + + payload := &session.PKCESession{ + CodeVerifier: codeVerifier, + Nonce: nonce, + ClientAlias: alias, + ClientID: cfg.PublicID, + RedirectURI: cfg.RedirectURI, + Scope: strings.TrimSpace(scope), + ReturnTo: returnTo, + CreatedAt: time.Now().UTC(), + } + + if err := h.store.Save(c.Context(), state, payload); err != nil { + utils.Log.Errorf("store pkce session failed: %v", err) + return fiber.NewError(fiber.StatusInternalServerError, "failed to prepare authorization") + } + + utils.Log.WithFields(logrus.Fields{ + "client": alias, + "state": state, + "return_to": returnTo, + }).Info("sso start redirect") + + return c.Redirect(authorizeURL.String(), fiber.StatusFound) +} + +// Callback handles the redirect from SSO containing the authorization code. +func (h *Controller) Callback(c *fiber.Ctx) error { + state := strings.TrimSpace(c.Query("state")) + code := strings.TrimSpace(c.Query("code")) + if state == "" || code == "" { + return fiber.NewError(fiber.StatusBadRequest, "missing code or state") + } + + sessionData, err := h.store.Get(c.Context(), state) + if err != nil { + utils.Log.Errorf("load pkce session failed: %v", err) + return fiber.NewError(fiber.StatusInternalServerError, "failed to validate authorization state") + } + if sessionData == nil { + return fiber.NewError(fiber.StatusBadRequest, "authorization state not found or expired") + } + defer func() { + if err := h.store.Delete(context.Background(), state); err != nil { + utils.Log.Warnf("failed to delete pkce session: %v", err) + } + }() + + tokenEndpoint := strings.TrimSpace(config.SSOTokenURL) + if tokenEndpoint == "" { + return fiber.NewError(fiber.StatusInternalServerError, "token endpoint not configured") + } + + form := url.Values{} + form.Set("grant_type", "authorization_code") + form.Set("code", code) + form.Set("code_verifier", sessionData.CodeVerifier) + form.Set("redirect_uri", sessionData.RedirectURI) + form.Set("client_id", sessionData.ClientID) + + req, err := http.NewRequestWithContext(c.Context(), http.MethodPost, tokenEndpoint, strings.NewReader(form.Encode())) + if err != nil { + return fiber.NewError(fiber.StatusInternalServerError, "failed to create token request") + } + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + + resp, err := h.httpClient.Do(req) + if err != nil { + utils.Log.Errorf("token request failed: %v", err) + return fiber.NewError(fiber.StatusBadGateway, "failed to exchange authorization code") + } + defer resp.Body.Close() + + if resp.StatusCode >= 400 { + utils.Log.Warnf("token response status %d", resp.StatusCode) + return fiber.NewError(fiber.StatusBadGateway, "token exchange rejected") + } + + var tokenResp struct { + AccessToken string `json:"access_token"` + RefreshToken string `json:"refresh_token"` + TokenType string `json:"token_type"` + ExpiresIn int `json:"expires_in"` + Scope string `json:"scope"` + IDToken string `json:"id_token"` + Error string `json:"error"` + Description string `json:"error_description"` + } + + if err := json.NewDecoder(resp.Body).Decode(&tokenResp); err != nil { + return fiber.NewError(fiber.StatusBadGateway, "invalid token response") + } + if tokenResp.Error != "" { + return fiber.NewError(fiber.StatusBadGateway, tokenResp.Description) + } + if tokenResp.AccessToken == "" { + return fiber.NewError(fiber.StatusBadGateway, "missing access token") + } + + verification, err := sso.VerifyAccessToken(tokenResp.AccessToken) + if err != nil { + utils.Log.Errorf("access token verification failed: %v", err) + return fiber.NewError(fiber.StatusUnauthorized, "invalid access token") + } + + // prepare cookies + issueCookies(c, tokenResp, verification) + + redirectTarget := sessionData.ReturnTo + if redirectTarget == "" { + redirectTarget = "/" + } + + fmt.Println(sessionData.ClientAlias,"test") + utils.Log.WithFields(logrus.Fields{ + "client": sessionData.ClientAlias, + "user_id": verification.UserID, + "return_to": redirectTarget, + }).Info("sso callback successful") + + return c.Redirect(redirectTarget, fiber.StatusFound) +} + +// UserInfo proxies the user profile from the central SSO so the frontend can obtain +// enriched user metadata (roles, permissions, etc.) without exposing tokens to the browser. +func (h *Controller) UserInfo(c *fiber.Ctx) error { + accessName := config.SSOAccessCookieName + if accessName == "" { + accessName = "sso_access" + } + + token := strings.TrimSpace(c.Cookies(accessName)) + tokenFromCookie := token != "" + + if !tokenFromCookie { + authHeader := strings.TrimSpace(c.Get("Authorization")) + if strings.HasPrefix(strings.ToLower(authHeader), "bearer ") { + token = strings.TrimSpace(authHeader[7:]) + } + } + + if token == "" { + return fiber.NewError(fiber.StatusUnauthorized, "unauthenticated") + } + + endpoint := strings.TrimSpace(config.SSOGetMeURL) + if endpoint == "" { + return fiber.NewError(fiber.StatusInternalServerError, "userinfo endpoint not configured") + } + + req, err := http.NewRequestWithContext(c.Context(), http.MethodGet, endpoint, nil) + if err != nil { + utils.Log.Errorf("failed to build userinfo request: %v", err) + return fiber.NewError(fiber.StatusInternalServerError, "failed to prepare userinfo request") + } + req.Header.Set("Accept", "application/json") + + // SSO /auth/get-me expects the access cookie; add Authorization as well for compatibility. + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token)) + if tokenFromCookie { + req.Header.Set("Cookie", fmt.Sprintf("%s=%s", accessName, token)) + } + + resp, err := h.httpClient.Do(req) + if err != nil { + utils.Log.Errorf("userinfo request failed: %v", err) + return fiber.NewError(fiber.StatusBadGateway, "failed to fetch user profile") + } + defer resp.Body.Close() + + utils.Log.WithFields(logrus.Fields{"status": resp.StatusCode}).Info("sso userinfo response") + + body, err := io.ReadAll(resp.Body) + if err != nil { + utils.Log.Errorf("failed to read userinfo response: %v", err) + return fiber.NewError(fiber.StatusBadGateway, "invalid user profile response") + } + + if ct := resp.Header.Get("Content-Type"); ct != "" { + c.Set("Content-Type", ct) + } else { + c.Type("json") + } + + return c.Status(resp.StatusCode).Send(body) +} + +func issueCookies(c *fiber.Ctx, tokenResp struct { + AccessToken string `json:"access_token"` + RefreshToken string `json:"refresh_token"` + TokenType string `json:"token_type"` + ExpiresIn int `json:"expires_in"` + Scope string `json:"scope"` + IDToken string `json:"id_token"` + Error string `json:"error"` + Description string `json:"error_description"` +}, verification *sso.VerificationResult) { + fmt.Println(tokenResp.AccessToken) + accessName := config.SSOAccessCookieName + if accessName == "" { + accessName = "access" + } + refreshName := config.SSORefreshCookieName + if refreshName == "" { + refreshName = "refresh" + } + maxAge := tokenResp.ExpiresIn + if maxAge <= 0 { + maxAge = int(15 * time.Minute.Seconds()) + } + + sameSite := config.SSOCookieSameSite + if sameSite == "" { + sameSite = "Lax" + } + + cookieDomain := config.SSOCookieDomain + + cookieAccess := &fiber.Cookie{ + Name: accessName, + Value: tokenResp.AccessToken, + Path: "/", + Domain: cookieDomain, + HTTPOnly: true, + Secure: config.SSOCookieSecure, + SameSite: sameSite, + MaxAge: maxAge, + } + c.Cookie(cookieAccess) + if tokenResp.RefreshToken != "" { + cookieRefresh := &fiber.Cookie{ + Name: refreshName, + Value: tokenResp.RefreshToken, + Path: "/", + Domain: cookieDomain, + HTTPOnly: true, + Secure: config.SSOCookieSecure, + SameSite: sameSite, + MaxAge: int((time.Hour * 24 * 30).Seconds()), + } + c.Cookie(cookieRefresh) + } + + // Optional: expose limited info via headers for FE debugging (avoid tokens) + c.Set("X-Auth-User", fmt.Sprintf("%d", verification.UserID)) +} + +func normalizeReturnTarget(returnTo string, cfg config.SSOClientConfig) (string, error) { + returnTo = strings.TrimSpace(returnTo) + if returnTo == "" { + return "", nil + } + if strings.HasPrefix(returnTo, "//") { + return "", fmt.Errorf("invalid return_to") + } + if strings.HasPrefix(returnTo, "/") { + return returnTo, nil + } + + parsed, err := url.Parse(returnTo) + if err != nil { + return "", fmt.Errorf("invalid return_to") + } + if parsed.Scheme != "http" && parsed.Scheme != "https" { + return "", fmt.Errorf("invalid return_to scheme") + } + + allowedOrigins := make(map[string]struct{}) + if cfg.DefaultReturnURI != "" { + if u, err := url.Parse(cfg.DefaultReturnURI); err == nil && u.Host != "" { + allowedOrigins[u.Scheme+"://"+u.Host] = struct{}{} + } + } + for _, origin := range cfg.AllowedReturnOrigins { + origin = strings.TrimSpace(origin) + if origin == "" { + continue + } + if u, err := url.Parse(origin); err == nil && u.Host != "" && (u.Scheme == "http" || u.Scheme == "https") { + allowedOrigins[u.Scheme+"://"+u.Host] = struct{}{} + } + } + + if len(allowedOrigins) > 0 { + origin := parsed.Scheme + "://" + parsed.Host + if _, ok := allowedOrigins[origin]; !ok { + return "", fmt.Errorf("return_to origin not allowed") + } + } + + return parsed.String(), nil +} diff --git a/internal/modules/sso/controllers/user_sync.controller.go b/internal/modules/sso/controllers/user_sync.controller.go new file mode 100644 index 00000000..2e02c2fd --- /dev/null +++ b/internal/modules/sso/controllers/user_sync.controller.go @@ -0,0 +1,391 @@ +package controllers + +import ( + "context" + "crypto/hmac" + "crypto/sha256" + "encoding/base64" + "encoding/hex" + "encoding/json" + "errors" + "fmt" + "strconv" + "strings" + "sync" + "time" + "github.com/go-playground/validator/v10" + "github.com/gofiber/fiber/v2" + "github.com/redis/go-redis/v9" + "github.com/sirupsen/logrus" + "gorm.io/gorm" + + "gitlab.com/mbugroup/lti-api.git/internal/config" + "gitlab.com/mbugroup/lti-api.git/internal/modules/users/dto" + entity "gitlab.com/mbugroup/lti-api.git/internal/entities" + userRepository "gitlab.com/mbugroup/lti-api.git/internal/modules/users/repositories" + "gitlab.com/mbugroup/lti-api.git/internal/response" + "gitlab.com/mbugroup/lti-api.git/internal/sso" + "gitlab.com/mbugroup/lti-api.git/internal/utils" +) + + +const ( + headerClient = "X-Sync-Client" + headerTimestamp = "X-Sync-Timestamp" + headerNonce = "X-Sync-Nonce" + headerSignature = "X-Sync-Signature" + defaultDrift = 2 * time.Minute + defaultNonceTTL = 10 * time.Minute +) + +// UserSyncController handles incoming user management events from the central SSO service. +type UserSyncController struct { + validate *validator.Validate + repo userRepository.UserRepository + redis *redis.Client + clients map[string]config.SSOClientConfig + drift time.Duration + nonceTTL time.Duration + maxBodyBytes int + log *logrus.Logger + localNonces sync.Map +} + +type userSyncRequest struct { + Action string `json:"action" validate:"required,oneof=create update delete"` + PublicID string `json:"public_id" validate:"required"` + User userSyncUser `json:"user" validate:"required"` +} + +type userSyncUser struct { + ID int64 `json:"id" validate:"required"` + Email string `json:"email"` + Name string `json:"name"` +} + +func NewUserSyncController(validate *validator.Validate, repo userRepository.UserRepository, redis *redis.Client, clients map[string]config.SSOClientConfig) *UserSyncController { + normalized := make(map[string]config.SSOClientConfig, len(clients)) + for alias, cfg := range clients { + alias = strings.ToLower(strings.TrimSpace(alias)) + normalized[alias] = cfg + } + + drift := config.SSOUserSyncDrift + if drift <= 0 { + drift = defaultDrift + } + + nonceTTL := config.SSOUserSyncNonceTTL + if nonceTTL <= 0 { + nonceTTL = defaultNonceTTL + } + + maxBody := config.SSOUserSyncMaxBodyBytes + if maxBody <= 0 { + maxBody = 32 * 1024 + } + + log := utils.Log + if redis == nil { + log.Warn("SSO user sync nonce store fallback to in-memory cache; enable Redis for replay protection") + } + + return &UserSyncController{ + validate: validate, + repo: repo, + redis: redis, + clients: normalized, + drift: drift, + nonceTTL: nonceTTL, + maxBodyBytes: maxBody, + log: log, + } +} + +func (h *UserSyncController) Sync(c *fiber.Ctx) error { + if ct := strings.TrimSpace(c.Get(fiber.HeaderContentType)); ct != "" && !strings.HasPrefix(strings.ToLower(ct), fiber.MIMEApplicationJSON) { + return fiber.NewError(fiber.StatusUnsupportedMediaType, "content-type must be application/json") + } + + body := c.Body() + if h.maxBodyBytes > 0 && len(body) > h.maxBodyBytes { + return fiber.NewError(fiber.StatusRequestEntityTooLarge, "request body too large") + } + + alias, clientCfg, err := h.authenticate(c, body) + if err != nil { + return err + } + + req := new(userSyncRequest) + if err := json.Unmarshal(body, req); err != nil { + return fiber.NewError(fiber.StatusBadRequest, "invalid request body") + } + + req.Action = strings.ToLower(strings.TrimSpace(req.Action)) + req.PublicID = strings.TrimSpace(req.PublicID) + req.User.Email = strings.TrimSpace(req.User.Email) + req.User.Name = strings.TrimSpace(req.User.Name) + + if err := h.validate.Struct(req); err != nil { + return fiber.NewError(fiber.StatusBadRequest, err.Error()) + } + + if clientCfg.PublicID != "" && req.PublicID != clientCfg.PublicID { + return fiber.NewError(fiber.StatusBadRequest, "public_id mismatch with configured client") + } + + if req.Action != "delete" { + if req.User.Email == "" || req.User.Name == "" { + return fiber.NewError(fiber.StatusBadRequest, "email and name are required for create/update actions") + } + if err := h.validate.Var(req.User.Email, "email"); err != nil { + return fiber.NewError(fiber.StatusBadRequest, "invalid email format") + } + } + + if req.User.ID <= 0 { + return fiber.NewError(fiber.StatusBadRequest, "invalid user id") + } + + switch req.Action { + case "create", "update": + return h.upsertUser(c, alias, req) + case "delete": + return h.removeUser(c, alias, req) + default: + return fiber.NewError(fiber.StatusBadRequest, "unsupported action") + } +} + +func (h *UserSyncController) authenticate(c *fiber.Ctx, body []byte) (string, config.SSOClientConfig, error) { + rawAlias := strings.TrimSpace(c.Get(headerClient)) + if rawAlias == "" { + return "", config.SSOClientConfig{}, fiber.NewError(fiber.StatusUnauthorized, "missing sync client header") + } + + aliasKey := strings.ToLower(rawAlias) + clientCfg, ok := h.clients[aliasKey] + if !ok { + return "", config.SSOClientConfig{}, fiber.NewError(fiber.StatusUnauthorized, "unknown sync client") + } + + if err := h.verifyAuthorization(c, aliasKey); err != nil { + return "", config.SSOClientConfig{}, err + } + + secret := strings.TrimSpace(clientCfg.SyncSecret) + if secret == "" { + return "", config.SSOClientConfig{}, fiber.NewError(fiber.StatusUnauthorized, "sync secret not configured") + } + + timestamp := strings.TrimSpace(c.Get(headerTimestamp)) + nonce := strings.TrimSpace(c.Get(headerNonce)) + signature := strings.TrimSpace(c.Get(headerSignature)) + + if timestamp == "" || nonce == "" || signature == "" { + return "", config.SSOClientConfig{}, fiber.NewError(fiber.StatusUnauthorized, "missing signature headers") + } + if len(nonce) < 16 { + return "", config.SSOClientConfig{}, fiber.NewError(fiber.StatusUnauthorized, "nonce too short") + } + + ts, err := strconv.ParseInt(timestamp, 10, 64) + if err != nil { + return "", config.SSOClientConfig{}, fiber.NewError(fiber.StatusBadRequest, "invalid timestamp") + } + + msgTime := time.Unix(ts, 0).UTC() + now := time.Now().UTC() + drift := now.Sub(msgTime) + if drift > h.drift || drift < -h.drift { + return "", config.SSOClientConfig{}, fiber.NewError(fiber.StatusUnauthorized, "timestamp outside allowed window") + } + + providedSig, err := decodeSignature(signature) + if err != nil { + return "", config.SSOClientConfig{}, fiber.NewError(fiber.StatusUnauthorized, "invalid signature encoding") + } + + expectedSignature := h.calculateSignature(secret, rawAlias, timestamp, nonce, body) + if !hmac.Equal(providedSig, expectedSignature) { + return "", config.SSOClientConfig{}, fiber.NewError(fiber.StatusUnauthorized, "invalid signature") + } + + if err := h.registerNonce(c.Context(), aliasKey, nonce); err != nil { + return "", config.SSOClientConfig{}, err + } + + return aliasKey, clientCfg, nil +} + +func (h *UserSyncController) verifyAuthorization(c *fiber.Ctx, alias string) error { + authHeader := strings.TrimSpace(c.Get(fiber.HeaderAuthorization)) + if authHeader == "" { + return fiber.NewError(fiber.StatusUnauthorized, "missing authorization header") + } + + parts := strings.SplitN(authHeader, " ", 2) + if len(parts) != 2 || !strings.EqualFold(parts[0], "Bearer") { + return fiber.NewError(fiber.StatusUnauthorized, "invalid authorization header") + } + + token := strings.TrimSpace(parts[1]) + if token == "" { + return fiber.NewError(fiber.StatusUnauthorized, "invalid authorization header") + } + + verification, err := sso.VerifyAccessToken(token) + if err != nil { + return fiber.NewError(fiber.StatusUnauthorized, "invalid access token") + } + + if verification.ServiceAlias == "" || verification.ServiceAlias != alias { + return fiber.NewError(fiber.StatusUnauthorized, "service subject mismatch") + } + + if !containsScope(verification.Claims.Scopes(), "sync.users") { + return fiber.NewError(fiber.StatusForbidden, "missing sync scope") + } + + return nil +} + +func (h *UserSyncController) upsertUser(c *fiber.Ctx, alias string, req *userSyncRequest) error { + entity := &entity.User{ + IdUser: req.User.ID, + Email: req.User.Email, + Name: req.User.Name, + } + + //TODO: MIGRATION TO UPSERT BASE REPOSITORY + if err := h.repo.UpsertByIdUser(c.Context(), entity); err != nil { + h.log.Errorf("sso user upsert failed: %v", err) + return fiber.NewError(fiber.StatusInternalServerError, "failed to upsert user") + } + + user, err := h.repo.GetByIdUser(c.Context(), req.User.ID, nil) + if err != nil { + h.log.Errorf("sso user fetch after upsert failed: %v", err) + return fiber.NewError(fiber.StatusInternalServerError, "failed to load user") + } + + h.log.WithFields(logrus.Fields{ + "action": req.Action, + "public_id": req.PublicID, + "alias": alias, + "user_id": req.User.ID, + }).Info("sso user synced") + + msg := fmt.Sprintf("User %s successfully", req.Action) + return c.Status(fiber.StatusOK).JSON(response.Success{ + Code: fiber.StatusOK, + Status: "success", + Message: msg, + Data: dto.ToUserListDTO(*user), + }) +} + +func (h *UserSyncController) removeUser(c *fiber.Ctx, alias string, req *userSyncRequest) error { + if err := h.repo.SoftDeleteByIdUser(c.Context(), req.User.ID); err != nil { + if err == gorm.ErrRecordNotFound { + return fiber.NewError(fiber.StatusNotFound, "user not found") + } + h.log.Errorf("sso user delete failed: %v", err) + return fiber.NewError(fiber.StatusInternalServerError, "failed to delete user") + } + + h.log.WithFields(logrus.Fields{ + "action": req.Action, + "public_id": req.PublicID, + "alias": alias, + "user_id": req.User.ID, + }).Info("sso user deleted") + + return c.Status(fiber.StatusOK).JSON(response.Common{ + Code: fiber.StatusOK, + Status: "success", + Message: "User deleted successfully", + }) +} + +func (h *UserSyncController) registerNonce(ctx context.Context, alias, nonce string) error { + ttl := h.nonceTTL + if ttl <= 0 { + ttl = defaultNonceTTL + } + + key := fmt.Sprintf("sso:sync:%s:%s", alias, nonce) + if h.redis != nil { + stored, err := h.redis.SetNX(ctx, key, "1", ttl).Result() + if err == nil { + if !stored { + return fiber.NewError(fiber.StatusUnauthorized, "nonce already used") + } + return nil + } + h.log.Errorf("store sync nonce failed: %v", err) + } + + now := time.Now().UTC() + if expRaw, ok := h.localNonces.Load(key); ok { + if expTime, ok := expRaw.(time.Time); ok && expTime.After(now) { + return fiber.NewError(fiber.StatusUnauthorized, "nonce already used") + } + } + h.localNonces.Store(key, now.Add(ttl)) + h.pruneLocalNonces(now) + return nil +} + +func (h *UserSyncController) calculateSignature(secret, alias, timestamp, nonce string, body []byte) []byte { + mac := hmac.New(sha256.New, []byte(secret)) + mac.Write([]byte(alias)) + mac.Write([]byte("\n")) + mac.Write([]byte(timestamp)) + mac.Write([]byte("\n")) + mac.Write([]byte(nonce)) + mac.Write([]byte("\n")) + mac.Write(body) + return mac.Sum(nil) +} + +func containsScope(scopes []string, target string) bool { + target = strings.ToLower(strings.TrimSpace(target)) + if target == "" { + return false + } + for _, scope := range scopes { + if strings.ToLower(strings.TrimSpace(scope)) == target { + return true + } + } + return false +} + +func decodeSignature(sig string) ([]byte, error) { + sig = strings.TrimSpace(sig) + if sig == "" { + return nil, errors.New("empty signature") + } + if decoded, err := hex.DecodeString(sig); err == nil { + return decoded, nil + } + if decoded, err := base64.StdEncoding.DecodeString(sig); err == nil { + return decoded, nil + } + if decoded, err := base64.URLEncoding.DecodeString(sig); err == nil { + return decoded, nil + } + return nil, errors.New("unrecognized signature encoding") +} + +func (h *UserSyncController) pruneLocalNonces(now time.Time) { + h.localNonces.Range(func(key, value any) bool { + exp, ok := value.(time.Time) + if !ok || exp.Before(now) { + h.localNonces.Delete(key) + } + return true + }) +} diff --git a/internal/modules/sso/module.go b/internal/modules/sso/module.go new file mode 100644 index 00000000..4924f071 --- /dev/null +++ b/internal/modules/sso/module.go @@ -0,0 +1,13 @@ +package sso + +import ( + "github.com/go-playground/validator/v10" + "github.com/gofiber/fiber/v2" + "gorm.io/gorm" +) + +type Module struct{} + +func (Module) RegisterRoutes(router fiber.Router, db *gorm.DB, validate *validator.Validate) { + Routes(router, db, validate) +} diff --git a/internal/modules/sso/route.go b/internal/modules/sso/route.go new file mode 100644 index 00000000..1c55830e --- /dev/null +++ b/internal/modules/sso/route.go @@ -0,0 +1,35 @@ +package sso + +import ( + "net/http" + "time" + + "github.com/go-playground/validator/v10" + "github.com/gofiber/fiber/v2" + "gorm.io/gorm" + + "gitlab.com/mbugroup/lti-api.git/internal/cache" + "gitlab.com/mbugroup/lti-api.git/internal/config" + "gitlab.com/mbugroup/lti-api.git/internal/middleware" + ssoController "gitlab.com/mbugroup/lti-api.git/internal/modules/sso/controllers" + "gitlab.com/mbugroup/lti-api.git/internal/modules/sso/session" + userRepository "gitlab.com/mbugroup/lti-api.git/internal/modules/users/repositories" +) + +func Routes(router fiber.Router, db *gorm.DB, validate *validator.Validate) { + ttl := config.SSOPKCETTL + if ttl <= 0 { + ttl = 5 * time.Minute + } + + store := session.NewStore(cache.MustRedis(), ttl) + ctrl := ssoController.NewController(&http.Client{Timeout: 10 * time.Second}, store) + userRepo := userRepository.NewUserRepository(db) + syncCtrl := ssoController.NewUserSyncController(validate, userRepo, cache.Redis(), config.SSOClients) + + group := router.Group("/sso") + group.Get("/start", middleware.NewLimiter(30, time.Minute), ctrl.Start) + group.Get("/callback", ctrl.Callback) + group.Get("/userinfo", middleware.NewLimiter(60, time.Minute), ctrl.UserInfo) + group.Post("/users/sync", middleware.NewLimiter(30, time.Minute), syncCtrl.Sync) +} diff --git a/internal/modules/sso/session/store.go b/internal/modules/sso/session/store.go new file mode 100644 index 00000000..f906ae22 --- /dev/null +++ b/internal/modules/sso/session/store.go @@ -0,0 +1,70 @@ +package session + +import ( + "context" + "encoding/json" + "fmt" + "time" + + "github.com/redis/go-redis/v9" +) + +const keyTemplate = "sso:pkce:%s" + +// PKCESession holds data required to complete the OAuth2 PKCE exchange. +type PKCESession struct { + CodeVerifier string `json:"code_verifier"` + Nonce string `json:"nonce"` + ClientAlias string `json:"client_alias"` + ClientID string `json:"client_id"` + RedirectURI string `json:"redirect_uri"` + Scope string `json:"scope"` + ReturnTo string `json:"return_to,omitempty"` + CreatedAt time.Time `json:"created_at"` +} + +// Store persists pkce sessions inside Redis using a configurable TTL. +type Store struct { + redis *redis.Client + ttl time.Duration +} + +func NewStore(client *redis.Client, ttl time.Duration) *Store { + return &Store{redis: client, ttl: ttl} +} + +func (s *Store) Save(ctx context.Context, state string, payload *PKCESession) error { + if s.redis == nil { + return fmt.Errorf("redis client is not initialised") + } + bytes, err := json.Marshal(payload) + if err != nil { + return err + } + return s.redis.Set(ctx, fmt.Sprintf(keyTemplate, state), bytes, s.ttl).Err() +} + +func (s *Store) Get(ctx context.Context, state string) (*PKCESession, error) { + if s.redis == nil { + return nil, fmt.Errorf("redis client is not initialised") + } + raw, err := s.redis.Get(ctx, fmt.Sprintf(keyTemplate, state)).Result() + if err != nil { + if err == redis.Nil { + return nil, nil + } + return nil, err + } + var payload PKCESession + if err := json.Unmarshal([]byte(raw), &payload); err != nil { + return nil, err + } + return &payload, nil +} + +func (s *Store) Delete(ctx context.Context, state string) error { + if s.redis == nil { + return fmt.Errorf("redis client is not initialised") + } + return s.redis.Del(ctx, fmt.Sprintf(keyTemplate, state)).Err() +} diff --git a/internal/modules/users/controllers/user.controller.go b/internal/modules/users/controllers/user.controller.go index 88361557..f51dfb10 100644 --- a/internal/modules/users/controllers/user.controller.go +++ b/internal/modules/users/controllers/user.controller.go @@ -71,70 +71,70 @@ func (u *UserController) GetOne(c *fiber.Ctx) error { }) } -func (u *UserController) CreateOne(c *fiber.Ctx) error { - req := new(validation.Create) +// func (u *UserController) CreateOne(c *fiber.Ctx) error { +// req := new(validation.Create) - if err := c.BodyParser(req); err != nil { - return fiber.NewError(fiber.StatusBadRequest, "Invalid request body") - } +// if err := c.BodyParser(req); err != nil { +// return fiber.NewError(fiber.StatusBadRequest, "Invalid request body") +// } - result, err := u.UserService.CreateOne(c, req) - if err != nil { - return err - } +// result, err := u.UserService.CreateOne(c, req) +// if err != nil { +// return err +// } - return c.Status(fiber.StatusCreated). - JSON(response.Success{ - Code: fiber.StatusCreated, - Status: "success", - Message: "Create user successfully", - Data: dto.ToUserListDTO(*result), - }) -} +// return c.Status(fiber.StatusCreated). +// JSON(response.Success{ +// Code: fiber.StatusCreated, +// Status: "success", +// Message: "Create user successfully", +// Data: dto.ToUserListDTO(*result), +// }) +// } -func (u *UserController) UpdateOne(c *fiber.Ctx) error { - req := new(validation.Update) - param := c.Params("id") +// func (u *UserController) UpdateOne(c *fiber.Ctx) error { +// req := new(validation.Update) +// param := c.Params("id") - id, err := strconv.Atoi(param) - if err != nil { - return fiber.NewError(fiber.StatusBadRequest, "Invalid Id") - } +// id, err := strconv.Atoi(param) +// if err != nil { +// return fiber.NewError(fiber.StatusBadRequest, "Invalid Id") +// } - if err := c.BodyParser(req); err != nil { - return fiber.NewError(fiber.StatusBadRequest, "Invalid request body") - } +// if err := c.BodyParser(req); err != nil { +// return fiber.NewError(fiber.StatusBadRequest, "Invalid request body") +// } - result, err := u.UserService.UpdateOne(c, req, uint(id)) - if err != nil { - return err - } +// result, err := u.UserService.UpdateOne(c, req, uint(id)) +// if err != nil { +// return err +// } - return c.Status(fiber.StatusOK). - JSON(response.Success{ - Code: fiber.StatusOK, - Status: "success", - Message: "Update user successfully", - Data: dto.ToUserListDTO(*result), - }) -} +// return c.Status(fiber.StatusOK). +// JSON(response.Success{ +// Code: fiber.StatusOK, +// Status: "success", +// Message: "Update user successfully", +// Data: dto.ToUserListDTO(*result), +// }) +// } -func (u *UserController) DeleteOne(c *fiber.Ctx) error { - param := c.Params("id") +// func (u *UserController) DeleteOne(c *fiber.Ctx) error { +// param := c.Params("id") - id, err := strconv.Atoi(param) - if err != nil { - return fiber.NewError(fiber.StatusBadRequest, "Invalid Id") - } +// id, err := strconv.Atoi(param) +// if err != nil { +// return fiber.NewError(fiber.StatusBadRequest, "Invalid Id") +// } - if err := u.UserService.DeleteOne(c, uint(id)); err != nil { - return err - } +// if err := u.UserService.DeleteOne(c, uint(id)); err != nil { +// return err +// } - return c.Status(fiber.StatusOK). - JSON(response.Common{ - Code: fiber.StatusOK, - Status: "success", - Message: "Delete user successfully", - }) -} +// return c.Status(fiber.StatusOK). +// JSON(response.Common{ +// Code: fiber.StatusOK, +// Status: "success", +// Message: "Delete user successfully", +// }) +// } diff --git a/internal/modules/users/repositories/user.repository.go b/internal/modules/users/repositories/user.repository.go index 8472db13..855284ac 100644 --- a/internal/modules/users/repositories/user.repository.go +++ b/internal/modules/users/repositories/user.repository.go @@ -1,21 +1,64 @@ package repository import ( - "gitlab.com/mbugroup/lti-api.git/internal/common/repository" + "context" + "time" + + commonrepo "gitlab.com/mbugroup/lti-api.git/internal/common/repository" entity "gitlab.com/mbugroup/lti-api.git/internal/entities" "gorm.io/gorm" + "gorm.io/gorm/clause" ) type UserRepository interface { - repository.BaseRepository[entity.User] + commonrepo.BaseRepository[entity.User] + GetByIdUser(ctx context.Context, idUser int64, modifier func(*gorm.DB) *gorm.DB) (*entity.User, error) + UpsertByIdUser(ctx context.Context, user *entity.User) error + SoftDeleteByIdUser(ctx context.Context, idUser int64) error } type UserRepositoryImpl struct { - *repository.BaseRepositoryImpl[entity.User] + *commonrepo.BaseRepositoryImpl[entity.User] } func NewUserRepository(db *gorm.DB) UserRepository { return &UserRepositoryImpl{ - BaseRepositoryImpl: repository.NewBaseRepository[entity.User](db), + BaseRepositoryImpl: commonrepo.NewBaseRepository[entity.User](db), } } + +func (r *UserRepositoryImpl) GetByIdUser( + ctx context.Context, + idUser int64, + modifier func(*gorm.DB) *gorm.DB, +) (*entity.User, error) { + return r.BaseRepositoryImpl.First(ctx, func(db *gorm.DB) *gorm.DB { + return db.Where("id_user = ?", idUser) + }) +} + +func (r *UserRepositoryImpl) UpsertByIdUser(ctx context.Context, user *entity.User) error { + if user == nil { + return gorm.ErrInvalidData + } + + conflict := []clause.Column{{Name: "id_user"}} + user.DeletedAt = gorm.DeletedAt{} + user.UpdatedAt = time.Now() + + return r.BaseRepositoryImpl.Upsert(ctx, user, conflict, func(db *gorm.DB) *gorm.DB { + return db.Omit("id", "created_at") + }) +} + +func (r *UserRepositoryImpl) SoftDeleteByIdUser(ctx context.Context, idUser int64) error { + query := r.DB().WithContext(ctx).Where("id_user = ?", idUser) + result := query.Delete(&entity.User{}) + if result.Error != nil { + return result.Error + } + if result.RowsAffected == 0 { + return gorm.ErrRecordNotFound + } + return nil +} diff --git a/internal/modules/users/route.go b/internal/modules/users/route.go index 2c428f3a..9ba6bfb3 100644 --- a/internal/modules/users/route.go +++ b/internal/modules/users/route.go @@ -1,20 +1,22 @@ package users import ( + "github.com/gofiber/fiber/v2" + + "gitlab.com/mbugroup/lti-api.git/internal/middleware" controller "gitlab.com/mbugroup/lti-api.git/internal/modules/users/controllers" user "gitlab.com/mbugroup/lti-api.git/internal/modules/users/services" - - "github.com/gofiber/fiber/v2" ) func UserRoutes(v1 fiber.Router, s user.UserService) { ctrl := controller.NewUserController(s) route := v1.Group("/users") + route.Use(middleware.Auth(s)) route.Get("/", ctrl.GetAll) - route.Post("/", ctrl.CreateOne) + // route.Post("/", ctrl.CreateOne) route.Get("/:id", ctrl.GetOne) - route.Patch("/:id", ctrl.UpdateOne) - route.Delete("/:id", ctrl.DeleteOne) + // route.Patch("/:id", ctrl.UpdateOne) + // route.Delete("/:id", ctrl.DeleteOne) } diff --git a/internal/modules/users/services/user.service.go b/internal/modules/users/services/user.service.go index f8e053e4..3b28197e 100644 --- a/internal/modules/users/services/user.service.go +++ b/internal/modules/users/services/user.service.go @@ -20,6 +20,7 @@ type UserService interface { CreateOne(ctx *fiber.Ctx, req *validation.Create) (*entity.User, error) UpdateOne(ctx *fiber.Ctx, req *validation.Update, id uint) (*entity.User, error) DeleteOne(ctx *fiber.Ctx, id uint) error + GetBySSOUserID(ctx *fiber.Ctx, ssoUserID uint) (*entity.User, error) } type userService struct { @@ -68,6 +69,18 @@ func (s userService) GetOne(c *fiber.Ctx, id uint) (*entity.User, error) { return user, nil } +func (s userService) GetBySSOUserID(c *fiber.Ctx, ssoUserID uint) (*entity.User, error) { + user, err := s.Repository.GetByIdUser(c.Context(), int64(ssoUserID), nil) + if errors.Is(err, gorm.ErrRecordNotFound) { + return nil, fiber.NewError(fiber.StatusNotFound, "User not found") + } + if err != nil { + s.Log.Errorf("Failed get user by SSO id: %+v", err) + return nil, err + } + return user, nil +} + func (s *userService) CreateOne(c *fiber.Ctx, req *validation.Create) (*entity.User, error) { if err := s.Validate.Struct(req); err != nil { return nil, err diff --git a/internal/route/route.go b/internal/route/route.go index c4bfa4b0..1f595f42 100644 --- a/internal/route/route.go +++ b/internal/route/route.go @@ -8,9 +8,10 @@ import ( "github.com/gofiber/fiber/v2" "gorm.io/gorm" + constants "gitlab.com/mbugroup/lti-api.git/internal/modules/constants" master "gitlab.com/mbugroup/lti-api.git/internal/modules/master" + ssoModule "gitlab.com/mbugroup/lti-api.git/internal/modules/sso" users "gitlab.com/mbugroup/lti-api.git/internal/modules/users" - constants "gitlab.com/mbugroup/lti-api.git/internal/modules/constants" // MODULE IMPORTS ) @@ -23,7 +24,8 @@ func Routes(app *fiber.App, db *gorm.DB) { allModules := []modules.Module{ users.UserModule{}, master.MasterModule{}, - constants.ConstantModule{}, + constants.ConstantModule{}, + ssoModule.Module{}, // MODULE REGISTRY } diff --git a/internal/sso/verifier.go b/internal/sso/verifier.go new file mode 100644 index 00000000..dc9ea111 --- /dev/null +++ b/internal/sso/verifier.go @@ -0,0 +1,161 @@ +package sso + +import ( + "context" + "errors" + "fmt" + "net/http" + "strconv" + "strings" + "sync" + "time" + + "github.com/MicahParks/keyfunc/v2" + "github.com/golang-jwt/jwt/v5" + + "gitlab.com/mbugroup/lti-api.git/internal/utils" +) + +type verifier struct { + jwks *keyfunc.JWKS + issuer string + audiences map[string]struct{} +} + +type AccessTokenClaims struct { + Scope string `json:"scope"` + jwt.RegisteredClaims +} + +func (c AccessTokenClaims) Scopes() []string { + if c.Scope == "" { + return nil + } + return strings.Fields(c.Scope) +} + +type VerificationResult struct { + UserID uint + ServiceAlias string + Subject string + Claims *AccessTokenClaims +} + +var ( + globalMu sync.RWMutex + globalV *verifier +) + +func Init(ctx context.Context, jwksURL, issuer string, audiences []string) error { + jwksURL = strings.TrimSpace(jwksURL) + issuer = strings.TrimSpace(issuer) + if jwksURL == "" || issuer == "" { + return errors.New("missing SSO JWKS or issuer configuration") + } + + client := &http.Client{Timeout: 5 * time.Second} + options := keyfunc.Options{ + Ctx: ctx, + Client: client, + RefreshTimeout: 10 * time.Second, + RefreshInterval: time.Hour, + RefreshUnknownKID: true, + RefreshErrorHandler: func(err error) { + utils.Log.Errorf("sso jwks refresh failed: %v", err) + }, + } + + jwks, err := keyfunc.Get(jwksURL, options) + if err != nil { + return fmt.Errorf("load jwks: %w", err) + } + + audienceMap := make(map[string]struct{}, len(audiences)) + for _, aud := range audiences { + aud = strings.TrimSpace(aud) + if aud == "" { + continue + } + audienceMap[aud] = struct{}{} + } + + globalMu.Lock() + globalV = &verifier{jwks: jwks, issuer: issuer, audiences: audienceMap} + globalMu.Unlock() + + utils.Log.Infof("sso verifier initialized for issuer %s (%d keys)", issuer, len(jwks.KIDs())) + return nil +} + +func VerifyAccessToken(token string) (*VerificationResult, error) { + token = strings.TrimSpace(token) + if token == "" { + return nil, errors.New("empty token") + } + + globalMu.RLock() + v := globalV + globalMu.RUnlock() + if v == nil { + return nil, errors.New("sso verifier not initialized") + } + + claims := &AccessTokenClaims{} + parser := jwt.NewParser( + jwt.WithValidMethods([]string{jwt.SigningMethodRS256.Alg()}), + jwt.WithIssuedAt(), + jwt.WithExpirationRequired(), + ) + + tok, err := parser.ParseWithClaims(token, claims, v.jwks.Keyfunc) + if err != nil { + return nil, fmt.Errorf("parse token: %w", err) + } + if !tok.Valid { + return nil, errors.New("invalid token") + } + + if claims.Issuer != v.issuer { + return nil, errors.New("unexpected token issuer") + } + + if len(v.audiences) > 0 { + validAud := false + for _, aud := range claims.Audience { + if _, ok := v.audiences[aud]; ok { + validAud = true + break + } + } + if !validAud { + return nil, errors.New("unexpected token audience") + } + } + + sub := strings.TrimSpace(claims.Subject) + if sub == "" { + return nil, errors.New("missing subject") + } + + result := &VerificationResult{Claims: claims, Subject: sub} + + switch { + case strings.HasPrefix(sub, "user:"): + idStr := strings.TrimPrefix(sub, "user:") + id, err := strconv.ParseUint(idStr, 10, 64) + if err != nil { + return nil, fmt.Errorf("invalid subject: %w", err) + } + result.UserID = uint(id) + case strings.HasPrefix(sub, "service:"): + alias := strings.TrimSpace(strings.TrimPrefix(sub, "service:")) + if alias == "" { + return nil, errors.New("invalid service subject") + } + result.ServiceAlias = strings.ToLower(alias) + default: + return nil, errors.New("unsupported subject type") + } + + return result, nil +} diff --git a/internal/utils/secure/random.go b/internal/utils/secure/random.go new file mode 100644 index 00000000..152fd2f9 --- /dev/null +++ b/internal/utils/secure/random.go @@ -0,0 +1,84 @@ +package secure + +import ( + "crypto/rand" + "encoding/base64" + "fmt" + "strings" +) + +const pkceCharset = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-._~" + +// RandomBytes returns securely generated random bytes of given length. +func RandomBytes(length int) ([]byte, error) { + if length <= 0 { + return nil, fmt.Errorf("length must be positive") + } + b := make([]byte, length) + if _, err := rand.Read(b); err != nil { + return nil, err + } + return b, nil +} + +// RandomString returns a base64url encoded random string of approximately the requested length. +func RandomString(length int) (string, error) { + if length <= 0 { + return "", fmt.Errorf("length must be positive") + } + // Generate ceil(length * 6/8) bytes to have enough entropy for base64 url encoding. + byteLen := (length*6 + 7) / 8 + bytes, err := RandomBytes(byteLen) + if err != nil { + return "", err + } + s := base64.RawURLEncoding.EncodeToString(bytes) + if len(s) > length { + return s[:length], nil + } + // If encoded string shorter, pad using charset + if len(s) < length { + sb := strings.Builder{} + sb.WriteString(s) + extraNeeded := length - len(s) + more, err := randomFromCharset(extraNeeded) + if err != nil { + return "", err + } + sb.WriteString(more) + return sb.String(), nil + } + return s, nil +} + +// PKCECodeVerifier generates a random string compliant with RFC 7636. +func PKCECodeVerifier(length int) (string, error) { + if length < 43 { + length = 43 + } + if length > 128 { + length = 128 + } + return randomFromCharset(length) +} + +// randomFromCharset returns a random string of given length using pkceCharset. +func randomFromCharset(length int) (string, error) { + if length <= 0 { + return "", fmt.Errorf("length must be positive") + } + bytes, err := RandomBytes(length) + if err != nil { + return "", err + } + out := make([]byte, length) + for i, b := range bytes { + out[i] = pkceCharset[int(b)%len(pkceCharset)] + } + return string(out), nil +} + +// Base64URLEncode encodes the input bytes using base64 URL encoding without padding. +func Base64URLEncode(data []byte) string { + return base64.RawURLEncoding.EncodeToString(data) +} diff --git a/internal/utils/verify.go b/internal/utils/verify.go deleted file mode 100644 index e8b3a850..00000000 --- a/internal/utils/verify.go +++ /dev/null @@ -1,45 +0,0 @@ -package utils - -import ( - "errors" - "strconv" - - "github.com/golang-jwt/jwt/v5" -) - -func VerifyToken(tokenStr, secret, tokenType string) (uint, error) { - token, err := jwt.Parse(tokenStr, func(_ *jwt.Token) (interface{}, error) { - return []byte(secret), nil - }) - if err != nil || !token.Valid { - return 0, err - } - - claims, ok := token.Claims.(jwt.MapClaims) - if !ok { - return 0, errors.New("invalid token claims") - } - - jwtType, ok := claims["type"].(string) - if !ok || jwtType != tokenType { - return 0, errors.New("invalid token type") - } - - sub, ok := claims["sub"] - if !ok { - return 0, errors.New("invalid token sub") - } - - switch v := sub.(type) { - case float64: - return uint(v), nil - case string: - id, err := strconv.Atoi(v) - if err != nil { - return 0, errors.New("invalid sub format") - } - return uint(id), nil - default: - return 0, errors.New("unsupported sub type") - } -}