feat/login crud in users sync with sso

This commit is contained in:
ragilap
2025-10-06 12:31:54 +07:00
parent 1684d69fae
commit 6bddbbf9d9
21 changed files with 1576 additions and 136 deletions
+20
View File
@@ -32,3 +32,23 @@ CORS_MAX_AGE=600
# Redis
REDIS_URL=redis://redis:6379/0
REDIS_PORT_HOST=6381
# SSO Integration
SSO_ISSUER=http://localhost:8080/api
SSO_JWKS_URL=http://localhost:8080/api/.well-known/jwks.json
SSO_ALLOWED_AUDIENCES=client:lti-api
SSO_AUTHORIZE_URL=http://localhost:8080/sso/authorize
SSO_TOKEN_URL=http://localhost:8080/sso/token
SSO_GETME_URL=http://localhost:8080/api/auth/get-me
SSO_ACCESS_COOKIE_NAME=sso_access
SSO_REFRESH_COOKIE_NAME=sso_refresh
SSO_COOKIE_DOMAIN=
SSO_COOKIE_SECURE=false
SSO_COOKIE_SAMESITE=Lax
SSO_PKCE_TTL_SECONDS=300
# Security window and payload limits for SSO user sync webhook
SSO_USER_SYNC_SIGNATURE_DRIFT_SECONDS=120
SSO_USER_SYNC_NONCE_TTL_SECONDS=600
SSO_USER_SYNC_MAX_BODY_BYTES=32768
# Example JSON (single-line) of client configs (each client requires a unique sync_secret)
SSO_CLIENTS={"lti":{"public_id":"client:lti","redirect_uri":"http://localhost:8081/api/sso/callback","scope":"openid profile","default_return_uri":"http://localhost:3000","allowed_return_origins":["http://localhost:3000"],"sync_secret":"changeme"}}
+10
View File
@@ -9,10 +9,12 @@ import (
"syscall"
"time"
"gitlab.com/mbugroup/lti-api.git/internal/cache"
"gitlab.com/mbugroup/lti-api.git/internal/config"
"gitlab.com/mbugroup/lti-api.git/internal/database"
"gitlab.com/mbugroup/lti-api.git/internal/middleware"
"gitlab.com/mbugroup/lti-api.git/internal/route"
"gitlab.com/mbugroup/lti-api.git/internal/sso"
"gitlab.com/mbugroup/lti-api.git/internal/utils"
"github.com/gofiber/fiber/v2"
@@ -33,6 +35,7 @@ func main() {
defer closeDatabase(db)
rdb := setupRedis()
defer rdb.Close()
setupSSO(ctx)
setupRoutes(app, db, rdb)
address := fmt.Sprintf("%s:%d", config.AppHost, config.AppPort)
@@ -52,10 +55,17 @@ func setupRedis() *redis.Client {
if err := rdb.Ping(context.Background()).Err(); err != nil {
utils.Log.Fatalf("Redis ping failed: %v", err)
}
cache.SetRedis(rdb)
utils.Log.Infof("Redis connected: %s", config.RedisURL)
return rdb
}
func setupSSO(ctx context.Context) {
if err := sso.Init(ctx, config.SSOJWKSURL, config.SSOIssuer, config.SSOAllowedAudiences); err != nil {
utils.Log.Fatalf("SSO initialization failed: %v", err)
}
}
func setupFiberApp() *fiber.App {
app := fiber.New(config.FiberConfig())
+38
View File
@@ -0,0 +1,38 @@
package cache
import (
"errors"
"sync"
"github.com/redis/go-redis/v9"
)
var (
redisClient *redis.Client
mu sync.RWMutex
)
// SetRedis assigns the global redis client used across the application.
func SetRedis(client *redis.Client) {
mu.Lock()
defer mu.Unlock()
redisClient = client
}
// Redis returns the configured redis client. It may be nil if not yet initialised.
func Redis() *redis.Client {
mu.RLock()
defer mu.RUnlock()
return redisClient
}
// MustRedis returns the redis client or panics if it has not been set.
func MustRedis() *redis.Client {
mu.RLock()
client := redisClient
mu.RUnlock()
if client == nil {
panic(errors.New("redis client not initialised"))
}
return client
}
+16
View File
@@ -11,6 +11,7 @@ type BaseRepository[T any] interface {
GetAll(ctx context.Context, offset, limit int, modifier func(*gorm.DB) *gorm.DB) ([]T, int64, error)
GetByID(ctx context.Context, id uint, modifier func(*gorm.DB) *gorm.DB) (*T, error)
GetByIDs(ctx context.Context, ids []uint, modifier func(*gorm.DB) *gorm.DB) ([]T, error)
First(ctx context.Context, modifier func(*gorm.DB) *gorm.DB) (*T, error)
CreateOne(ctx context.Context, entity *T, modifier func(*gorm.DB) *gorm.DB) error
CreateMany(ctx context.Context, entities []*T, modifier func(*gorm.DB) *gorm.DB) error
@@ -96,6 +97,21 @@ func (r *BaseRepositoryImpl[T]) GetByIDs(
return entities, nil
}
func (r *BaseRepositoryImpl[T]) First(
ctx context.Context,
modifier func(*gorm.DB) *gorm.DB,
) (*T, error) {
entity := new(T)
q := r.db.WithContext(ctx)
if modifier != nil {
q = modifier(q)
}
if err := q.First(entity).Error; err != nil {
return nil, err
}
return entity, nil
}
// ---- CREATE ----
func (r *BaseRepositoryImpl[T]) CreateOne(
ctx context.Context,
+154 -22
View File
@@ -2,36 +2,64 @@ package config
import (
"encoding/json"
"fmt"
"strings"
"time"
"gitlab.com/mbugroup/lti-api.git/internal/utils"
"github.com/spf13/viper"
)
type SSOClientConfig struct {
PublicID string `json:"public_id"`
RedirectURI string `json:"redirect_uri"`
Scope string `json:"scope"`
Prompt string `json:"prompt"`
DefaultReturnURI string `json:"default_return_uri"`
AllowedReturnOrigins []string `json:"allowed_return_origins"`
SyncSecret string `json:"sync_secret"`
}
var (
IsProd bool
AppHost string
Version string
LogLevel string
AppPort int
DBHost string
DBUser string
DBPassword string
DBName string
DBPort int
JWTSecret string
JWTAccessExp int
JWTRefreshExp int
JWTResetPasswordExp int
JWTVerifyEmailExp int
RedisURL string
CORSAllowOrigins []string
CORSAllowMethods []string
CORSAllowHeaders []string
CORSExposeHeaders []string
CORSAllowCredentials bool
CORSMaxAge int
IsProd bool
AppHost string
Version string
LogLevel string
AppPort int
DBHost string
DBUser string
DBPassword string
DBName string
DBPort int
JWTSecret string
JWTAccessExp int
JWTRefreshExp int
JWTResetPasswordExp int
JWTVerifyEmailExp int
RedisURL string
CORSAllowOrigins []string
CORSAllowMethods []string
CORSAllowHeaders []string
CORSExposeHeaders []string
CORSAllowCredentials bool
CORSMaxAge int
SSOIssuer string
SSOJWKSURL string
SSOAllowedAudiences []string
SSOAuthorizeURL string
SSOTokenURL string
SSOGetMeURL string
SSOClients map[string]SSOClientConfig
SSOAccessCookieName string
SSORefreshCookieName string
SSOCookieDomain string
SSOCookieSecure bool
SSOCookieSameSite string
SSOPKCETTL time.Duration
SSOUserSyncDrift time.Duration
SSOUserSyncNonceTTL time.Duration
SSOUserSyncMaxBodyBytes int
)
func init() {
@@ -68,6 +96,43 @@ func init() {
// Redis
RedisURL = viper.GetString("REDIS_URL")
// SSO integration
SSOIssuer = viper.GetString("SSO_ISSUER")
SSOJWKSURL = viper.GetString("SSO_JWKS_URL")
SSOAllowedAudiences = parseList("SSO_ALLOWED_AUDIENCES")
SSOAuthorizeURL = viper.GetString("SSO_AUTHORIZE_URL")
SSOTokenURL = viper.GetString("SSO_TOKEN_URL")
SSOGetMeURL = viper.GetString("SSO_GETME_URL")
SSOAccessCookieName = defaultString(viper.GetString("SSO_ACCESS_COOKIE_NAME"), "sso_access")
SSORefreshCookieName = defaultString(viper.GetString("SSO_REFRESH_COOKIE_NAME"), "sso_refresh")
SSOCookieDomain = viper.GetString("SSO_COOKIE_DOMAIN")
SSOCookieSecure = viper.GetBool("SSO_COOKIE_SECURE")
SSOCookieSameSite = defaultString(viper.GetString("SSO_COOKIE_SAMESITE"), "Lax")
if ttl := viper.GetInt("SSO_PKCE_TTL_SECONDS"); ttl > 0 {
SSOPKCETTL = time.Duration(ttl) * time.Second
} else {
SSOPKCETTL = 5 * time.Minute
}
SSOClients = loadSSOClients("SSO_CLIENTS")
if drift := viper.GetInt("SSO_USER_SYNC_SIGNATURE_DRIFT_SECONDS"); drift > 0 {
SSOUserSyncDrift = time.Duration(drift) * time.Second
} else {
SSOUserSyncDrift = 2 * time.Minute
}
if ttl := viper.GetInt("SSO_USER_SYNC_NONCE_TTL_SECONDS"); ttl > 0 {
SSOUserSyncNonceTTL = time.Duration(ttl) * time.Second
} else {
SSOUserSyncNonceTTL = 10 * time.Minute
}
SSOUserSyncMaxBodyBytes = viper.GetInt("SSO_USER_SYNC_MAX_BODY_BYTES")
if SSOUserSyncMaxBodyBytes <= 0 {
SSOUserSyncMaxBodyBytes = 32 * 1024
}
if IsProd {
ensureProdConfig()
}
}
func loadConfig() {
@@ -117,3 +182,70 @@ func parseListWithDefault(key, def string) []string {
}
return parts
}
func loadSSOClients(key string) map[string]SSOClientConfig {
clients := make(map[string]SSOClientConfig)
raw := strings.TrimSpace(viper.GetString(key))
if raw == "" {
return clients
}
if err := json.Unmarshal([]byte(raw), &clients); err != nil {
utils.Log.Errorf("Failed to parse %s: %v", key, err)
return make(map[string]SSOClientConfig)
}
result := make(map[string]SSOClientConfig, len(clients))
for alias, cfg := range clients {
alias = strings.ToLower(strings.TrimSpace(alias))
for i, origin := range cfg.AllowedReturnOrigins {
cfg.AllowedReturnOrigins[i] = strings.TrimSpace(origin)
}
cfg.SyncSecret = strings.TrimSpace(cfg.SyncSecret)
result[alias] = cfg
}
return result
}
func defaultString(v, def string) string {
if strings.TrimSpace(v) == "" {
return def
}
return v
}
func ensureProdConfig() {
if SSOAuthorizeURL == "" || !strings.HasPrefix(SSOAuthorizeURL, "https://") {
panic("SSO_AUTHORIZE_URL must be https in production")
}
if SSOTokenURL == "" || !strings.HasPrefix(SSOTokenURL, "https://") {
panic("SSO_TOKEN_URL must be https in production")
}
if SSOGetMeURL == "" || !strings.HasPrefix(SSOGetMeURL, "https://") {
panic("SSO_GETME_URL must be https in production")
}
if !SSOCookieSecure {
panic("SSO_COOKIE_SECURE must be true in production")
}
if SSOCookieDomain == "" {
panic("SSO_COOKIE_DOMAIN must be configured in production")
}
if len(SSOAllowedAudiences) == 0 {
panic("SSO_ALLOWED_AUDIENCES must contain at least one audience in production")
}
for alias, cfg := range SSOClients {
if strings.TrimSpace(cfg.SyncSecret) == "" {
panic(fmt.Sprintf("SSO_CLIENTS[%s].sync_secret must be configured in production", alias))
}
if len(cfg.SyncSecret) < 16 {
panic(fmt.Sprintf("SSO_CLIENTS[%s].sync_secret must be at least 16 characters", alias))
}
}
if SSOUserSyncDrift <= 0 {
panic("SSO_USER_SYNC_SIGNATURE_DRIFT_SECONDS must be greater than zero in production")
}
if SSOUserSyncNonceTTL <= 0 {
panic("SSO_USER_SYNC_NONCE_TTL_SECONDS must be greater than zero in production")
}
if SSOUserSyncMaxBodyBytes <= 0 {
panic("SSO_USER_SYNC_MAX_BODY_BYTES must be greater than zero in production")
}
}
@@ -10,6 +10,7 @@ CREATE TABLE users (
);
CREATE UNIQUE INDEX users_id_user_unique ON users (id_user) WHERE deleted_at IS NULL;
CREATE UNIQUE INDEX users_email_unique ON users (email) WHERE deleted_at IS NULL;
-- FLAGS
+32 -3
View File
@@ -5,7 +5,7 @@ import (
"gitlab.com/mbugroup/lti-api.git/internal/config"
service "gitlab.com/mbugroup/lti-api.git/internal/modules/users/services"
"gitlab.com/mbugroup/lti-api.git/internal/utils"
"gitlab.com/mbugroup/lti-api.git/internal/sso"
"github.com/gofiber/fiber/v2"
)
@@ -15,21 +15,50 @@ func Auth(userService service.UserService, requiredRights ...string) fiber.Handl
authHeader := c.Get("Authorization")
token := strings.TrimSpace(strings.TrimPrefix(authHeader, "Bearer "))
if token == "" {
cookieName := config.SSOAccessCookieName
if cookieName == "" {
cookieName = "access"
}
token = strings.TrimSpace(c.Cookies(cookieName))
}
if token == "" {
return fiber.NewError(fiber.StatusUnauthorized, "Please authenticate")
}
userID, err := utils.VerifyToken(token, config.JWTSecret, config.TokenTypeAccess)
verification, err := sso.VerifyAccessToken(token)
if err != nil {
return fiber.NewError(fiber.StatusUnauthorized, "Please authenticate")
}
user, err := userService.GetOne(c, userID)
if len(config.SSOAllowedAudiences) > 0 {
allowed := make(map[string]struct{}, len(config.SSOAllowedAudiences))
for _, aud := range config.SSOAllowedAudiences {
aud = strings.TrimSpace(aud)
if aud != "" {
allowed[aud] = struct{}{}
}
}
audienceValid := false
for _, aud := range verification.Claims.Audience {
if _, ok := allowed[aud]; ok {
audienceValid = true
break
}
}
if !audienceValid {
return fiber.NewError(fiber.StatusUnauthorized, "invalid audience")
}
}
user, err := userService.GetBySSOUserID(c, verification.UserID)
if err != nil || user == nil {
return fiber.NewError(fiber.StatusUnauthorized, "Please authenticate")
}
c.Locals("user", user)
c.Locals("token_claims", verification.Claims)
// if len(requiredRights) > 0 {
// userRights, hasRights := config.RoleRights[user.Role]
+21
View File
@@ -24,3 +24,24 @@ func LimiterConfig() fiber.Handler {
SkipSuccessfulRequests: true,
})
}
func NewLimiter(max int, expiration time.Duration) fiber.Handler {
if max <= 0 {
max = 10
}
if expiration <= 0 {
expiration = time.Minute
}
return limiter.New(limiter.Config{
Max: max,
Expiration: expiration,
LimitReached: func(c *fiber.Ctx) error {
return c.Status(fiber.StatusTooManyRequests).
JSON(response.Common{
Code: fiber.StatusTooManyRequests,
Status: "error",
Message: "Too many requests, please try again later",
})
},
})
}
@@ -0,0 +1,404 @@
package controllers
import (
"context"
"crypto/sha256"
"encoding/json"
"fmt"
"io"
"net/http"
"net/url"
"strings"
"time"
"github.com/gofiber/fiber/v2"
"github.com/sirupsen/logrus"
"gitlab.com/mbugroup/lti-api.git/internal/config"
"gitlab.com/mbugroup/lti-api.git/internal/modules/sso/session"
"gitlab.com/mbugroup/lti-api.git/internal/sso"
"gitlab.com/mbugroup/lti-api.git/internal/utils"
"gitlab.com/mbugroup/lti-api.git/internal/utils/secure"
)
// Controller manages the SSO start & callback flow using PKCE.
type Controller struct {
httpClient *http.Client
store *session.Store
}
func NewController(client *http.Client, store *session.Store) *Controller {
return &Controller{httpClient: client, store: store}
}
// Start handles GET /sso/start requests and redirects users to the central SSO authorize endpoint.
func (h *Controller) Start(c *fiber.Ctx) error {
alias := strings.ToLower(strings.TrimSpace(c.Query("client")))
if alias == "" {
alias = strings.ToLower(strings.TrimSpace(c.Query("client_id")))
}
if alias == "" {
return fiber.NewError(fiber.StatusBadRequest, "missing client")
}
cfg, ok := config.SSOClients[alias]
if !ok || cfg.PublicID == "" {
return fiber.NewError(fiber.StatusBadRequest, "unknown client")
}
authorizeEndpoint := strings.TrimSpace(config.SSOAuthorizeURL)
if authorizeEndpoint == "" {
return fiber.NewError(fiber.StatusInternalServerError, "authorize endpoint not configured")
}
state, err := secure.RandomString(48)
if err != nil {
utils.Log.Errorf("generate state failed: %v", err)
return fiber.NewError(fiber.StatusInternalServerError, "failed to prepare authorization")
}
nonce, err := secure.RandomString(32)
if err != nil {
utils.Log.Errorf("generate nonce failed: %v", err)
return fiber.NewError(fiber.StatusInternalServerError, "failed to prepare authorization")
}
codeVerifier, err := secure.PKCECodeVerifier(96)
if err != nil {
utils.Log.Errorf("generate code verifier failed: %v", err)
return fiber.NewError(fiber.StatusInternalServerError, "failed to prepare authorization")
}
digest := sha256.Sum256([]byte(codeVerifier))
challenge := secure.Base64URLEncode(digest[:])
authorizeURL, err := url.Parse(authorizeEndpoint)
if err != nil {
return fiber.NewError(fiber.StatusInternalServerError, "invalid authorize endpoint")
}
scope := cfg.Scope
if scope == "" {
scope = "openid profile"
}
if !strings.Contains(" "+scope+" ", " openid ") {
scope = scope + " openid"
}
rawReturn := strings.TrimSpace(c.Query("return_to"))
if rawReturn == "" {
rawReturn = cfg.DefaultReturnURI
}
returnTo, err := normalizeReturnTarget(rawReturn, cfg)
if err != nil {
return fiber.NewError(fiber.StatusBadRequest, err.Error())
}
query := authorizeURL.Query()
query.Set("response_type", "code")
query.Set("client_id", cfg.PublicID)
query.Set("redirect_uri", cfg.RedirectURI)
query.Set("scope", strings.TrimSpace(scope))
query.Set("state", state)
query.Set("code_challenge", challenge)
query.Set("code_challenge_method", "S256")
query.Set("nonce", nonce)
if prompt := strings.TrimSpace(cfg.Prompt); prompt != "" {
query.Set("prompt", prompt)
}
if extraPrompt := strings.TrimSpace(c.Query("prompt")); extraPrompt != "" {
query.Set("prompt", extraPrompt)
}
authorizeURL.RawQuery = query.Encode()
payload := &session.PKCESession{
CodeVerifier: codeVerifier,
Nonce: nonce,
ClientAlias: alias,
ClientID: cfg.PublicID,
RedirectURI: cfg.RedirectURI,
Scope: strings.TrimSpace(scope),
ReturnTo: returnTo,
CreatedAt: time.Now().UTC(),
}
if err := h.store.Save(c.Context(), state, payload); err != nil {
utils.Log.Errorf("store pkce session failed: %v", err)
return fiber.NewError(fiber.StatusInternalServerError, "failed to prepare authorization")
}
utils.Log.WithFields(logrus.Fields{
"client": alias,
"state": state,
"return_to": returnTo,
}).Info("sso start redirect")
return c.Redirect(authorizeURL.String(), fiber.StatusFound)
}
// Callback handles the redirect from SSO containing the authorization code.
func (h *Controller) Callback(c *fiber.Ctx) error {
state := strings.TrimSpace(c.Query("state"))
code := strings.TrimSpace(c.Query("code"))
if state == "" || code == "" {
return fiber.NewError(fiber.StatusBadRequest, "missing code or state")
}
sessionData, err := h.store.Get(c.Context(), state)
if err != nil {
utils.Log.Errorf("load pkce session failed: %v", err)
return fiber.NewError(fiber.StatusInternalServerError, "failed to validate authorization state")
}
if sessionData == nil {
return fiber.NewError(fiber.StatusBadRequest, "authorization state not found or expired")
}
defer func() {
if err := h.store.Delete(context.Background(), state); err != nil {
utils.Log.Warnf("failed to delete pkce session: %v", err)
}
}()
tokenEndpoint := strings.TrimSpace(config.SSOTokenURL)
if tokenEndpoint == "" {
return fiber.NewError(fiber.StatusInternalServerError, "token endpoint not configured")
}
form := url.Values{}
form.Set("grant_type", "authorization_code")
form.Set("code", code)
form.Set("code_verifier", sessionData.CodeVerifier)
form.Set("redirect_uri", sessionData.RedirectURI)
form.Set("client_id", sessionData.ClientID)
req, err := http.NewRequestWithContext(c.Context(), http.MethodPost, tokenEndpoint, strings.NewReader(form.Encode()))
if err != nil {
return fiber.NewError(fiber.StatusInternalServerError, "failed to create token request")
}
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
resp, err := h.httpClient.Do(req)
if err != nil {
utils.Log.Errorf("token request failed: %v", err)
return fiber.NewError(fiber.StatusBadGateway, "failed to exchange authorization code")
}
defer resp.Body.Close()
if resp.StatusCode >= 400 {
utils.Log.Warnf("token response status %d", resp.StatusCode)
return fiber.NewError(fiber.StatusBadGateway, "token exchange rejected")
}
var tokenResp struct {
AccessToken string `json:"access_token"`
RefreshToken string `json:"refresh_token"`
TokenType string `json:"token_type"`
ExpiresIn int `json:"expires_in"`
Scope string `json:"scope"`
IDToken string `json:"id_token"`
Error string `json:"error"`
Description string `json:"error_description"`
}
if err := json.NewDecoder(resp.Body).Decode(&tokenResp); err != nil {
return fiber.NewError(fiber.StatusBadGateway, "invalid token response")
}
if tokenResp.Error != "" {
return fiber.NewError(fiber.StatusBadGateway, tokenResp.Description)
}
if tokenResp.AccessToken == "" {
return fiber.NewError(fiber.StatusBadGateway, "missing access token")
}
verification, err := sso.VerifyAccessToken(tokenResp.AccessToken)
if err != nil {
utils.Log.Errorf("access token verification failed: %v", err)
return fiber.NewError(fiber.StatusUnauthorized, "invalid access token")
}
// prepare cookies
issueCookies(c, tokenResp, verification)
redirectTarget := sessionData.ReturnTo
if redirectTarget == "" {
redirectTarget = "/"
}
fmt.Println(sessionData.ClientAlias,"test")
utils.Log.WithFields(logrus.Fields{
"client": sessionData.ClientAlias,
"user_id": verification.UserID,
"return_to": redirectTarget,
}).Info("sso callback successful")
return c.Redirect(redirectTarget, fiber.StatusFound)
}
// UserInfo proxies the user profile from the central SSO so the frontend can obtain
// enriched user metadata (roles, permissions, etc.) without exposing tokens to the browser.
func (h *Controller) UserInfo(c *fiber.Ctx) error {
accessName := config.SSOAccessCookieName
if accessName == "" {
accessName = "sso_access"
}
token := strings.TrimSpace(c.Cookies(accessName))
tokenFromCookie := token != ""
if !tokenFromCookie {
authHeader := strings.TrimSpace(c.Get("Authorization"))
if strings.HasPrefix(strings.ToLower(authHeader), "bearer ") {
token = strings.TrimSpace(authHeader[7:])
}
}
if token == "" {
return fiber.NewError(fiber.StatusUnauthorized, "unauthenticated")
}
endpoint := strings.TrimSpace(config.SSOGetMeURL)
if endpoint == "" {
return fiber.NewError(fiber.StatusInternalServerError, "userinfo endpoint not configured")
}
req, err := http.NewRequestWithContext(c.Context(), http.MethodGet, endpoint, nil)
if err != nil {
utils.Log.Errorf("failed to build userinfo request: %v", err)
return fiber.NewError(fiber.StatusInternalServerError, "failed to prepare userinfo request")
}
req.Header.Set("Accept", "application/json")
// SSO /auth/get-me expects the access cookie; add Authorization as well for compatibility.
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token))
if tokenFromCookie {
req.Header.Set("Cookie", fmt.Sprintf("%s=%s", accessName, token))
}
resp, err := h.httpClient.Do(req)
if err != nil {
utils.Log.Errorf("userinfo request failed: %v", err)
return fiber.NewError(fiber.StatusBadGateway, "failed to fetch user profile")
}
defer resp.Body.Close()
utils.Log.WithFields(logrus.Fields{"status": resp.StatusCode}).Info("sso userinfo response")
body, err := io.ReadAll(resp.Body)
if err != nil {
utils.Log.Errorf("failed to read userinfo response: %v", err)
return fiber.NewError(fiber.StatusBadGateway, "invalid user profile response")
}
if ct := resp.Header.Get("Content-Type"); ct != "" {
c.Set("Content-Type", ct)
} else {
c.Type("json")
}
return c.Status(resp.StatusCode).Send(body)
}
func issueCookies(c *fiber.Ctx, tokenResp struct {
AccessToken string `json:"access_token"`
RefreshToken string `json:"refresh_token"`
TokenType string `json:"token_type"`
ExpiresIn int `json:"expires_in"`
Scope string `json:"scope"`
IDToken string `json:"id_token"`
Error string `json:"error"`
Description string `json:"error_description"`
}, verification *sso.VerificationResult) {
fmt.Println(tokenResp.AccessToken)
accessName := config.SSOAccessCookieName
if accessName == "" {
accessName = "access"
}
refreshName := config.SSORefreshCookieName
if refreshName == "" {
refreshName = "refresh"
}
maxAge := tokenResp.ExpiresIn
if maxAge <= 0 {
maxAge = int(15 * time.Minute.Seconds())
}
sameSite := config.SSOCookieSameSite
if sameSite == "" {
sameSite = "Lax"
}
cookieDomain := config.SSOCookieDomain
cookieAccess := &fiber.Cookie{
Name: accessName,
Value: tokenResp.AccessToken,
Path: "/",
Domain: cookieDomain,
HTTPOnly: true,
Secure: config.SSOCookieSecure,
SameSite: sameSite,
MaxAge: maxAge,
}
c.Cookie(cookieAccess)
if tokenResp.RefreshToken != "" {
cookieRefresh := &fiber.Cookie{
Name: refreshName,
Value: tokenResp.RefreshToken,
Path: "/",
Domain: cookieDomain,
HTTPOnly: true,
Secure: config.SSOCookieSecure,
SameSite: sameSite,
MaxAge: int((time.Hour * 24 * 30).Seconds()),
}
c.Cookie(cookieRefresh)
}
// Optional: expose limited info via headers for FE debugging (avoid tokens)
c.Set("X-Auth-User", fmt.Sprintf("%d", verification.UserID))
}
func normalizeReturnTarget(returnTo string, cfg config.SSOClientConfig) (string, error) {
returnTo = strings.TrimSpace(returnTo)
if returnTo == "" {
return "", nil
}
if strings.HasPrefix(returnTo, "//") {
return "", fmt.Errorf("invalid return_to")
}
if strings.HasPrefix(returnTo, "/") {
return returnTo, nil
}
parsed, err := url.Parse(returnTo)
if err != nil {
return "", fmt.Errorf("invalid return_to")
}
if parsed.Scheme != "http" && parsed.Scheme != "https" {
return "", fmt.Errorf("invalid return_to scheme")
}
allowedOrigins := make(map[string]struct{})
if cfg.DefaultReturnURI != "" {
if u, err := url.Parse(cfg.DefaultReturnURI); err == nil && u.Host != "" {
allowedOrigins[u.Scheme+"://"+u.Host] = struct{}{}
}
}
for _, origin := range cfg.AllowedReturnOrigins {
origin = strings.TrimSpace(origin)
if origin == "" {
continue
}
if u, err := url.Parse(origin); err == nil && u.Host != "" && (u.Scheme == "http" || u.Scheme == "https") {
allowedOrigins[u.Scheme+"://"+u.Host] = struct{}{}
}
}
if len(allowedOrigins) > 0 {
origin := parsed.Scheme + "://" + parsed.Host
if _, ok := allowedOrigins[origin]; !ok {
return "", fmt.Errorf("return_to origin not allowed")
}
}
return parsed.String(), nil
}
@@ -0,0 +1,391 @@
package controllers
import (
"context"
"crypto/hmac"
"crypto/sha256"
"encoding/base64"
"encoding/hex"
"encoding/json"
"errors"
"fmt"
"strconv"
"strings"
"sync"
"time"
"github.com/go-playground/validator/v10"
"github.com/gofiber/fiber/v2"
"github.com/redis/go-redis/v9"
"github.com/sirupsen/logrus"
"gorm.io/gorm"
"gitlab.com/mbugroup/lti-api.git/internal/config"
"gitlab.com/mbugroup/lti-api.git/internal/modules/users/dto"
entity "gitlab.com/mbugroup/lti-api.git/internal/entities"
userRepository "gitlab.com/mbugroup/lti-api.git/internal/modules/users/repositories"
"gitlab.com/mbugroup/lti-api.git/internal/response"
"gitlab.com/mbugroup/lti-api.git/internal/sso"
"gitlab.com/mbugroup/lti-api.git/internal/utils"
)
const (
headerClient = "X-Sync-Client"
headerTimestamp = "X-Sync-Timestamp"
headerNonce = "X-Sync-Nonce"
headerSignature = "X-Sync-Signature"
defaultDrift = 2 * time.Minute
defaultNonceTTL = 10 * time.Minute
)
// UserSyncController handles incoming user management events from the central SSO service.
type UserSyncController struct {
validate *validator.Validate
repo userRepository.UserRepository
redis *redis.Client
clients map[string]config.SSOClientConfig
drift time.Duration
nonceTTL time.Duration
maxBodyBytes int
log *logrus.Logger
localNonces sync.Map
}
type userSyncRequest struct {
Action string `json:"action" validate:"required,oneof=create update delete"`
PublicID string `json:"public_id" validate:"required"`
User userSyncUser `json:"user" validate:"required"`
}
type userSyncUser struct {
ID int64 `json:"id" validate:"required"`
Email string `json:"email"`
Name string `json:"name"`
}
func NewUserSyncController(validate *validator.Validate, repo userRepository.UserRepository, redis *redis.Client, clients map[string]config.SSOClientConfig) *UserSyncController {
normalized := make(map[string]config.SSOClientConfig, len(clients))
for alias, cfg := range clients {
alias = strings.ToLower(strings.TrimSpace(alias))
normalized[alias] = cfg
}
drift := config.SSOUserSyncDrift
if drift <= 0 {
drift = defaultDrift
}
nonceTTL := config.SSOUserSyncNonceTTL
if nonceTTL <= 0 {
nonceTTL = defaultNonceTTL
}
maxBody := config.SSOUserSyncMaxBodyBytes
if maxBody <= 0 {
maxBody = 32 * 1024
}
log := utils.Log
if redis == nil {
log.Warn("SSO user sync nonce store fallback to in-memory cache; enable Redis for replay protection")
}
return &UserSyncController{
validate: validate,
repo: repo,
redis: redis,
clients: normalized,
drift: drift,
nonceTTL: nonceTTL,
maxBodyBytes: maxBody,
log: log,
}
}
func (h *UserSyncController) Sync(c *fiber.Ctx) error {
if ct := strings.TrimSpace(c.Get(fiber.HeaderContentType)); ct != "" && !strings.HasPrefix(strings.ToLower(ct), fiber.MIMEApplicationJSON) {
return fiber.NewError(fiber.StatusUnsupportedMediaType, "content-type must be application/json")
}
body := c.Body()
if h.maxBodyBytes > 0 && len(body) > h.maxBodyBytes {
return fiber.NewError(fiber.StatusRequestEntityTooLarge, "request body too large")
}
alias, clientCfg, err := h.authenticate(c, body)
if err != nil {
return err
}
req := new(userSyncRequest)
if err := json.Unmarshal(body, req); err != nil {
return fiber.NewError(fiber.StatusBadRequest, "invalid request body")
}
req.Action = strings.ToLower(strings.TrimSpace(req.Action))
req.PublicID = strings.TrimSpace(req.PublicID)
req.User.Email = strings.TrimSpace(req.User.Email)
req.User.Name = strings.TrimSpace(req.User.Name)
if err := h.validate.Struct(req); err != nil {
return fiber.NewError(fiber.StatusBadRequest, err.Error())
}
if clientCfg.PublicID != "" && req.PublicID != clientCfg.PublicID {
return fiber.NewError(fiber.StatusBadRequest, "public_id mismatch with configured client")
}
if req.Action != "delete" {
if req.User.Email == "" || req.User.Name == "" {
return fiber.NewError(fiber.StatusBadRequest, "email and name are required for create/update actions")
}
if err := h.validate.Var(req.User.Email, "email"); err != nil {
return fiber.NewError(fiber.StatusBadRequest, "invalid email format")
}
}
if req.User.ID <= 0 {
return fiber.NewError(fiber.StatusBadRequest, "invalid user id")
}
switch req.Action {
case "create", "update":
return h.upsertUser(c, alias, req)
case "delete":
return h.removeUser(c, alias, req)
default:
return fiber.NewError(fiber.StatusBadRequest, "unsupported action")
}
}
func (h *UserSyncController) authenticate(c *fiber.Ctx, body []byte) (string, config.SSOClientConfig, error) {
rawAlias := strings.TrimSpace(c.Get(headerClient))
if rawAlias == "" {
return "", config.SSOClientConfig{}, fiber.NewError(fiber.StatusUnauthorized, "missing sync client header")
}
aliasKey := strings.ToLower(rawAlias)
clientCfg, ok := h.clients[aliasKey]
if !ok {
return "", config.SSOClientConfig{}, fiber.NewError(fiber.StatusUnauthorized, "unknown sync client")
}
if err := h.verifyAuthorization(c, aliasKey); err != nil {
return "", config.SSOClientConfig{}, err
}
secret := strings.TrimSpace(clientCfg.SyncSecret)
if secret == "" {
return "", config.SSOClientConfig{}, fiber.NewError(fiber.StatusUnauthorized, "sync secret not configured")
}
timestamp := strings.TrimSpace(c.Get(headerTimestamp))
nonce := strings.TrimSpace(c.Get(headerNonce))
signature := strings.TrimSpace(c.Get(headerSignature))
if timestamp == "" || nonce == "" || signature == "" {
return "", config.SSOClientConfig{}, fiber.NewError(fiber.StatusUnauthorized, "missing signature headers")
}
if len(nonce) < 16 {
return "", config.SSOClientConfig{}, fiber.NewError(fiber.StatusUnauthorized, "nonce too short")
}
ts, err := strconv.ParseInt(timestamp, 10, 64)
if err != nil {
return "", config.SSOClientConfig{}, fiber.NewError(fiber.StatusBadRequest, "invalid timestamp")
}
msgTime := time.Unix(ts, 0).UTC()
now := time.Now().UTC()
drift := now.Sub(msgTime)
if drift > h.drift || drift < -h.drift {
return "", config.SSOClientConfig{}, fiber.NewError(fiber.StatusUnauthorized, "timestamp outside allowed window")
}
providedSig, err := decodeSignature(signature)
if err != nil {
return "", config.SSOClientConfig{}, fiber.NewError(fiber.StatusUnauthorized, "invalid signature encoding")
}
expectedSignature := h.calculateSignature(secret, rawAlias, timestamp, nonce, body)
if !hmac.Equal(providedSig, expectedSignature) {
return "", config.SSOClientConfig{}, fiber.NewError(fiber.StatusUnauthorized, "invalid signature")
}
if err := h.registerNonce(c.Context(), aliasKey, nonce); err != nil {
return "", config.SSOClientConfig{}, err
}
return aliasKey, clientCfg, nil
}
func (h *UserSyncController) verifyAuthorization(c *fiber.Ctx, alias string) error {
authHeader := strings.TrimSpace(c.Get(fiber.HeaderAuthorization))
if authHeader == "" {
return fiber.NewError(fiber.StatusUnauthorized, "missing authorization header")
}
parts := strings.SplitN(authHeader, " ", 2)
if len(parts) != 2 || !strings.EqualFold(parts[0], "Bearer") {
return fiber.NewError(fiber.StatusUnauthorized, "invalid authorization header")
}
token := strings.TrimSpace(parts[1])
if token == "" {
return fiber.NewError(fiber.StatusUnauthorized, "invalid authorization header")
}
verification, err := sso.VerifyAccessToken(token)
if err != nil {
return fiber.NewError(fiber.StatusUnauthorized, "invalid access token")
}
if verification.ServiceAlias == "" || verification.ServiceAlias != alias {
return fiber.NewError(fiber.StatusUnauthorized, "service subject mismatch")
}
if !containsScope(verification.Claims.Scopes(), "sync.users") {
return fiber.NewError(fiber.StatusForbidden, "missing sync scope")
}
return nil
}
func (h *UserSyncController) upsertUser(c *fiber.Ctx, alias string, req *userSyncRequest) error {
entity := &entity.User{
IdUser: req.User.ID,
Email: req.User.Email,
Name: req.User.Name,
}
//TODO: MIGRATION TO UPSERT BASE REPOSITORY
if err := h.repo.UpsertByIdUser(c.Context(), entity); err != nil {
h.log.Errorf("sso user upsert failed: %v", err)
return fiber.NewError(fiber.StatusInternalServerError, "failed to upsert user")
}
user, err := h.repo.GetByIdUser(c.Context(), req.User.ID, nil)
if err != nil {
h.log.Errorf("sso user fetch after upsert failed: %v", err)
return fiber.NewError(fiber.StatusInternalServerError, "failed to load user")
}
h.log.WithFields(logrus.Fields{
"action": req.Action,
"public_id": req.PublicID,
"alias": alias,
"user_id": req.User.ID,
}).Info("sso user synced")
msg := fmt.Sprintf("User %s successfully", req.Action)
return c.Status(fiber.StatusOK).JSON(response.Success{
Code: fiber.StatusOK,
Status: "success",
Message: msg,
Data: dto.ToUserListDTO(*user),
})
}
func (h *UserSyncController) removeUser(c *fiber.Ctx, alias string, req *userSyncRequest) error {
if err := h.repo.SoftDeleteByIdUser(c.Context(), req.User.ID); err != nil {
if err == gorm.ErrRecordNotFound {
return fiber.NewError(fiber.StatusNotFound, "user not found")
}
h.log.Errorf("sso user delete failed: %v", err)
return fiber.NewError(fiber.StatusInternalServerError, "failed to delete user")
}
h.log.WithFields(logrus.Fields{
"action": req.Action,
"public_id": req.PublicID,
"alias": alias,
"user_id": req.User.ID,
}).Info("sso user deleted")
return c.Status(fiber.StatusOK).JSON(response.Common{
Code: fiber.StatusOK,
Status: "success",
Message: "User deleted successfully",
})
}
func (h *UserSyncController) registerNonce(ctx context.Context, alias, nonce string) error {
ttl := h.nonceTTL
if ttl <= 0 {
ttl = defaultNonceTTL
}
key := fmt.Sprintf("sso:sync:%s:%s", alias, nonce)
if h.redis != nil {
stored, err := h.redis.SetNX(ctx, key, "1", ttl).Result()
if err == nil {
if !stored {
return fiber.NewError(fiber.StatusUnauthorized, "nonce already used")
}
return nil
}
h.log.Errorf("store sync nonce failed: %v", err)
}
now := time.Now().UTC()
if expRaw, ok := h.localNonces.Load(key); ok {
if expTime, ok := expRaw.(time.Time); ok && expTime.After(now) {
return fiber.NewError(fiber.StatusUnauthorized, "nonce already used")
}
}
h.localNonces.Store(key, now.Add(ttl))
h.pruneLocalNonces(now)
return nil
}
func (h *UserSyncController) calculateSignature(secret, alias, timestamp, nonce string, body []byte) []byte {
mac := hmac.New(sha256.New, []byte(secret))
mac.Write([]byte(alias))
mac.Write([]byte("\n"))
mac.Write([]byte(timestamp))
mac.Write([]byte("\n"))
mac.Write([]byte(nonce))
mac.Write([]byte("\n"))
mac.Write(body)
return mac.Sum(nil)
}
func containsScope(scopes []string, target string) bool {
target = strings.ToLower(strings.TrimSpace(target))
if target == "" {
return false
}
for _, scope := range scopes {
if strings.ToLower(strings.TrimSpace(scope)) == target {
return true
}
}
return false
}
func decodeSignature(sig string) ([]byte, error) {
sig = strings.TrimSpace(sig)
if sig == "" {
return nil, errors.New("empty signature")
}
if decoded, err := hex.DecodeString(sig); err == nil {
return decoded, nil
}
if decoded, err := base64.StdEncoding.DecodeString(sig); err == nil {
return decoded, nil
}
if decoded, err := base64.URLEncoding.DecodeString(sig); err == nil {
return decoded, nil
}
return nil, errors.New("unrecognized signature encoding")
}
func (h *UserSyncController) pruneLocalNonces(now time.Time) {
h.localNonces.Range(func(key, value any) bool {
exp, ok := value.(time.Time)
if !ok || exp.Before(now) {
h.localNonces.Delete(key)
}
return true
})
}
+13
View File
@@ -0,0 +1,13 @@
package sso
import (
"github.com/go-playground/validator/v10"
"github.com/gofiber/fiber/v2"
"gorm.io/gorm"
)
type Module struct{}
func (Module) RegisterRoutes(router fiber.Router, db *gorm.DB, validate *validator.Validate) {
Routes(router, db, validate)
}
+35
View File
@@ -0,0 +1,35 @@
package sso
import (
"net/http"
"time"
"github.com/go-playground/validator/v10"
"github.com/gofiber/fiber/v2"
"gorm.io/gorm"
"gitlab.com/mbugroup/lti-api.git/internal/cache"
"gitlab.com/mbugroup/lti-api.git/internal/config"
"gitlab.com/mbugroup/lti-api.git/internal/middleware"
ssoController "gitlab.com/mbugroup/lti-api.git/internal/modules/sso/controllers"
"gitlab.com/mbugroup/lti-api.git/internal/modules/sso/session"
userRepository "gitlab.com/mbugroup/lti-api.git/internal/modules/users/repositories"
)
func Routes(router fiber.Router, db *gorm.DB, validate *validator.Validate) {
ttl := config.SSOPKCETTL
if ttl <= 0 {
ttl = 5 * time.Minute
}
store := session.NewStore(cache.MustRedis(), ttl)
ctrl := ssoController.NewController(&http.Client{Timeout: 10 * time.Second}, store)
userRepo := userRepository.NewUserRepository(db)
syncCtrl := ssoController.NewUserSyncController(validate, userRepo, cache.Redis(), config.SSOClients)
group := router.Group("/sso")
group.Get("/start", middleware.NewLimiter(30, time.Minute), ctrl.Start)
group.Get("/callback", ctrl.Callback)
group.Get("/userinfo", middleware.NewLimiter(60, time.Minute), ctrl.UserInfo)
group.Post("/users/sync", middleware.NewLimiter(30, time.Minute), syncCtrl.Sync)
}
+70
View File
@@ -0,0 +1,70 @@
package session
import (
"context"
"encoding/json"
"fmt"
"time"
"github.com/redis/go-redis/v9"
)
const keyTemplate = "sso:pkce:%s"
// PKCESession holds data required to complete the OAuth2 PKCE exchange.
type PKCESession struct {
CodeVerifier string `json:"code_verifier"`
Nonce string `json:"nonce"`
ClientAlias string `json:"client_alias"`
ClientID string `json:"client_id"`
RedirectURI string `json:"redirect_uri"`
Scope string `json:"scope"`
ReturnTo string `json:"return_to,omitempty"`
CreatedAt time.Time `json:"created_at"`
}
// Store persists pkce sessions inside Redis using a configurable TTL.
type Store struct {
redis *redis.Client
ttl time.Duration
}
func NewStore(client *redis.Client, ttl time.Duration) *Store {
return &Store{redis: client, ttl: ttl}
}
func (s *Store) Save(ctx context.Context, state string, payload *PKCESession) error {
if s.redis == nil {
return fmt.Errorf("redis client is not initialised")
}
bytes, err := json.Marshal(payload)
if err != nil {
return err
}
return s.redis.Set(ctx, fmt.Sprintf(keyTemplate, state), bytes, s.ttl).Err()
}
func (s *Store) Get(ctx context.Context, state string) (*PKCESession, error) {
if s.redis == nil {
return nil, fmt.Errorf("redis client is not initialised")
}
raw, err := s.redis.Get(ctx, fmt.Sprintf(keyTemplate, state)).Result()
if err != nil {
if err == redis.Nil {
return nil, nil
}
return nil, err
}
var payload PKCESession
if err := json.Unmarshal([]byte(raw), &payload); err != nil {
return nil, err
}
return &payload, nil
}
func (s *Store) Delete(ctx context.Context, state string) error {
if s.redis == nil {
return fmt.Errorf("redis client is not initialised")
}
return s.redis.Del(ctx, fmt.Sprintf(keyTemplate, state)).Err()
}
@@ -71,70 +71,70 @@ func (u *UserController) GetOne(c *fiber.Ctx) error {
})
}
func (u *UserController) CreateOne(c *fiber.Ctx) error {
req := new(validation.Create)
// func (u *UserController) CreateOne(c *fiber.Ctx) error {
// req := new(validation.Create)
if err := c.BodyParser(req); err != nil {
return fiber.NewError(fiber.StatusBadRequest, "Invalid request body")
}
// if err := c.BodyParser(req); err != nil {
// return fiber.NewError(fiber.StatusBadRequest, "Invalid request body")
// }
result, err := u.UserService.CreateOne(c, req)
if err != nil {
return err
}
// result, err := u.UserService.CreateOne(c, req)
// if err != nil {
// return err
// }
return c.Status(fiber.StatusCreated).
JSON(response.Success{
Code: fiber.StatusCreated,
Status: "success",
Message: "Create user successfully",
Data: dto.ToUserListDTO(*result),
})
}
// return c.Status(fiber.StatusCreated).
// JSON(response.Success{
// Code: fiber.StatusCreated,
// Status: "success",
// Message: "Create user successfully",
// Data: dto.ToUserListDTO(*result),
// })
// }
func (u *UserController) UpdateOne(c *fiber.Ctx) error {
req := new(validation.Update)
param := c.Params("id")
// func (u *UserController) UpdateOne(c *fiber.Ctx) error {
// req := new(validation.Update)
// param := c.Params("id")
id, err := strconv.Atoi(param)
if err != nil {
return fiber.NewError(fiber.StatusBadRequest, "Invalid Id")
}
// id, err := strconv.Atoi(param)
// if err != nil {
// return fiber.NewError(fiber.StatusBadRequest, "Invalid Id")
// }
if err := c.BodyParser(req); err != nil {
return fiber.NewError(fiber.StatusBadRequest, "Invalid request body")
}
// if err := c.BodyParser(req); err != nil {
// return fiber.NewError(fiber.StatusBadRequest, "Invalid request body")
// }
result, err := u.UserService.UpdateOne(c, req, uint(id))
if err != nil {
return err
}
// result, err := u.UserService.UpdateOne(c, req, uint(id))
// if err != nil {
// return err
// }
return c.Status(fiber.StatusOK).
JSON(response.Success{
Code: fiber.StatusOK,
Status: "success",
Message: "Update user successfully",
Data: dto.ToUserListDTO(*result),
})
}
// return c.Status(fiber.StatusOK).
// JSON(response.Success{
// Code: fiber.StatusOK,
// Status: "success",
// Message: "Update user successfully",
// Data: dto.ToUserListDTO(*result),
// })
// }
func (u *UserController) DeleteOne(c *fiber.Ctx) error {
param := c.Params("id")
// func (u *UserController) DeleteOne(c *fiber.Ctx) error {
// param := c.Params("id")
id, err := strconv.Atoi(param)
if err != nil {
return fiber.NewError(fiber.StatusBadRequest, "Invalid Id")
}
// id, err := strconv.Atoi(param)
// if err != nil {
// return fiber.NewError(fiber.StatusBadRequest, "Invalid Id")
// }
if err := u.UserService.DeleteOne(c, uint(id)); err != nil {
return err
}
// if err := u.UserService.DeleteOne(c, uint(id)); err != nil {
// return err
// }
return c.Status(fiber.StatusOK).
JSON(response.Common{
Code: fiber.StatusOK,
Status: "success",
Message: "Delete user successfully",
})
}
// return c.Status(fiber.StatusOK).
// JSON(response.Common{
// Code: fiber.StatusOK,
// Status: "success",
// Message: "Delete user successfully",
// })
// }
@@ -1,21 +1,64 @@
package repository
import (
"gitlab.com/mbugroup/lti-api.git/internal/common/repository"
"context"
"time"
commonrepo "gitlab.com/mbugroup/lti-api.git/internal/common/repository"
entity "gitlab.com/mbugroup/lti-api.git/internal/entities"
"gorm.io/gorm"
"gorm.io/gorm/clause"
)
type UserRepository interface {
repository.BaseRepository[entity.User]
commonrepo.BaseRepository[entity.User]
GetByIdUser(ctx context.Context, idUser int64, modifier func(*gorm.DB) *gorm.DB) (*entity.User, error)
UpsertByIdUser(ctx context.Context, user *entity.User) error
SoftDeleteByIdUser(ctx context.Context, idUser int64) error
}
type UserRepositoryImpl struct {
*repository.BaseRepositoryImpl[entity.User]
*commonrepo.BaseRepositoryImpl[entity.User]
}
func NewUserRepository(db *gorm.DB) UserRepository {
return &UserRepositoryImpl{
BaseRepositoryImpl: repository.NewBaseRepository[entity.User](db),
BaseRepositoryImpl: commonrepo.NewBaseRepository[entity.User](db),
}
}
func (r *UserRepositoryImpl) GetByIdUser(
ctx context.Context,
idUser int64,
modifier func(*gorm.DB) *gorm.DB,
) (*entity.User, error) {
return r.BaseRepositoryImpl.First(ctx, func(db *gorm.DB) *gorm.DB {
return db.Where("id_user = ?", idUser)
})
}
func (r *UserRepositoryImpl) UpsertByIdUser(ctx context.Context, user *entity.User) error {
if user == nil {
return gorm.ErrInvalidData
}
conflict := []clause.Column{{Name: "id_user"}}
user.DeletedAt = gorm.DeletedAt{}
user.UpdatedAt = time.Now()
return r.BaseRepositoryImpl.Upsert(ctx, user, conflict, func(db *gorm.DB) *gorm.DB {
return db.Omit("id", "created_at")
})
}
func (r *UserRepositoryImpl) SoftDeleteByIdUser(ctx context.Context, idUser int64) error {
query := r.DB().WithContext(ctx).Where("id_user = ?", idUser)
result := query.Delete(&entity.User{})
if result.Error != nil {
return result.Error
}
if result.RowsAffected == 0 {
return gorm.ErrRecordNotFound
}
return nil
}
+7 -5
View File
@@ -1,20 +1,22 @@
package users
import (
"github.com/gofiber/fiber/v2"
"gitlab.com/mbugroup/lti-api.git/internal/middleware"
controller "gitlab.com/mbugroup/lti-api.git/internal/modules/users/controllers"
user "gitlab.com/mbugroup/lti-api.git/internal/modules/users/services"
"github.com/gofiber/fiber/v2"
)
func UserRoutes(v1 fiber.Router, s user.UserService) {
ctrl := controller.NewUserController(s)
route := v1.Group("/users")
route.Use(middleware.Auth(s))
route.Get("/", ctrl.GetAll)
route.Post("/", ctrl.CreateOne)
// route.Post("/", ctrl.CreateOne)
route.Get("/:id", ctrl.GetOne)
route.Patch("/:id", ctrl.UpdateOne)
route.Delete("/:id", ctrl.DeleteOne)
// route.Patch("/:id", ctrl.UpdateOne)
// route.Delete("/:id", ctrl.DeleteOne)
}
@@ -20,6 +20,7 @@ type UserService interface {
CreateOne(ctx *fiber.Ctx, req *validation.Create) (*entity.User, error)
UpdateOne(ctx *fiber.Ctx, req *validation.Update, id uint) (*entity.User, error)
DeleteOne(ctx *fiber.Ctx, id uint) error
GetBySSOUserID(ctx *fiber.Ctx, ssoUserID uint) (*entity.User, error)
}
type userService struct {
@@ -68,6 +69,18 @@ func (s userService) GetOne(c *fiber.Ctx, id uint) (*entity.User, error) {
return user, nil
}
func (s userService) GetBySSOUserID(c *fiber.Ctx, ssoUserID uint) (*entity.User, error) {
user, err := s.Repository.GetByIdUser(c.Context(), int64(ssoUserID), nil)
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, fiber.NewError(fiber.StatusNotFound, "User not found")
}
if err != nil {
s.Log.Errorf("Failed get user by SSO id: %+v", err)
return nil, err
}
return user, nil
}
func (s *userService) CreateOne(c *fiber.Ctx, req *validation.Create) (*entity.User, error) {
if err := s.Validate.Struct(req); err != nil {
return nil, err
+4 -2
View File
@@ -8,9 +8,10 @@ import (
"github.com/gofiber/fiber/v2"
"gorm.io/gorm"
constants "gitlab.com/mbugroup/lti-api.git/internal/modules/constants"
master "gitlab.com/mbugroup/lti-api.git/internal/modules/master"
ssoModule "gitlab.com/mbugroup/lti-api.git/internal/modules/sso"
users "gitlab.com/mbugroup/lti-api.git/internal/modules/users"
constants "gitlab.com/mbugroup/lti-api.git/internal/modules/constants"
// MODULE IMPORTS
)
@@ -23,7 +24,8 @@ func Routes(app *fiber.App, db *gorm.DB) {
allModules := []modules.Module{
users.UserModule{},
master.MasterModule{},
constants.ConstantModule{},
constants.ConstantModule{},
ssoModule.Module{},
// MODULE REGISTRY
}
+161
View File
@@ -0,0 +1,161 @@
package sso
import (
"context"
"errors"
"fmt"
"net/http"
"strconv"
"strings"
"sync"
"time"
"github.com/MicahParks/keyfunc/v2"
"github.com/golang-jwt/jwt/v5"
"gitlab.com/mbugroup/lti-api.git/internal/utils"
)
type verifier struct {
jwks *keyfunc.JWKS
issuer string
audiences map[string]struct{}
}
type AccessTokenClaims struct {
Scope string `json:"scope"`
jwt.RegisteredClaims
}
func (c AccessTokenClaims) Scopes() []string {
if c.Scope == "" {
return nil
}
return strings.Fields(c.Scope)
}
type VerificationResult struct {
UserID uint
ServiceAlias string
Subject string
Claims *AccessTokenClaims
}
var (
globalMu sync.RWMutex
globalV *verifier
)
func Init(ctx context.Context, jwksURL, issuer string, audiences []string) error {
jwksURL = strings.TrimSpace(jwksURL)
issuer = strings.TrimSpace(issuer)
if jwksURL == "" || issuer == "" {
return errors.New("missing SSO JWKS or issuer configuration")
}
client := &http.Client{Timeout: 5 * time.Second}
options := keyfunc.Options{
Ctx: ctx,
Client: client,
RefreshTimeout: 10 * time.Second,
RefreshInterval: time.Hour,
RefreshUnknownKID: true,
RefreshErrorHandler: func(err error) {
utils.Log.Errorf("sso jwks refresh failed: %v", err)
},
}
jwks, err := keyfunc.Get(jwksURL, options)
if err != nil {
return fmt.Errorf("load jwks: %w", err)
}
audienceMap := make(map[string]struct{}, len(audiences))
for _, aud := range audiences {
aud = strings.TrimSpace(aud)
if aud == "" {
continue
}
audienceMap[aud] = struct{}{}
}
globalMu.Lock()
globalV = &verifier{jwks: jwks, issuer: issuer, audiences: audienceMap}
globalMu.Unlock()
utils.Log.Infof("sso verifier initialized for issuer %s (%d keys)", issuer, len(jwks.KIDs()))
return nil
}
func VerifyAccessToken(token string) (*VerificationResult, error) {
token = strings.TrimSpace(token)
if token == "" {
return nil, errors.New("empty token")
}
globalMu.RLock()
v := globalV
globalMu.RUnlock()
if v == nil {
return nil, errors.New("sso verifier not initialized")
}
claims := &AccessTokenClaims{}
parser := jwt.NewParser(
jwt.WithValidMethods([]string{jwt.SigningMethodRS256.Alg()}),
jwt.WithIssuedAt(),
jwt.WithExpirationRequired(),
)
tok, err := parser.ParseWithClaims(token, claims, v.jwks.Keyfunc)
if err != nil {
return nil, fmt.Errorf("parse token: %w", err)
}
if !tok.Valid {
return nil, errors.New("invalid token")
}
if claims.Issuer != v.issuer {
return nil, errors.New("unexpected token issuer")
}
if len(v.audiences) > 0 {
validAud := false
for _, aud := range claims.Audience {
if _, ok := v.audiences[aud]; ok {
validAud = true
break
}
}
if !validAud {
return nil, errors.New("unexpected token audience")
}
}
sub := strings.TrimSpace(claims.Subject)
if sub == "" {
return nil, errors.New("missing subject")
}
result := &VerificationResult{Claims: claims, Subject: sub}
switch {
case strings.HasPrefix(sub, "user:"):
idStr := strings.TrimPrefix(sub, "user:")
id, err := strconv.ParseUint(idStr, 10, 64)
if err != nil {
return nil, fmt.Errorf("invalid subject: %w", err)
}
result.UserID = uint(id)
case strings.HasPrefix(sub, "service:"):
alias := strings.TrimSpace(strings.TrimPrefix(sub, "service:"))
if alias == "" {
return nil, errors.New("invalid service subject")
}
result.ServiceAlias = strings.ToLower(alias)
default:
return nil, errors.New("unsupported subject type")
}
return result, nil
}
+84
View File
@@ -0,0 +1,84 @@
package secure
import (
"crypto/rand"
"encoding/base64"
"fmt"
"strings"
)
const pkceCharset = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-._~"
// RandomBytes returns securely generated random bytes of given length.
func RandomBytes(length int) ([]byte, error) {
if length <= 0 {
return nil, fmt.Errorf("length must be positive")
}
b := make([]byte, length)
if _, err := rand.Read(b); err != nil {
return nil, err
}
return b, nil
}
// RandomString returns a base64url encoded random string of approximately the requested length.
func RandomString(length int) (string, error) {
if length <= 0 {
return "", fmt.Errorf("length must be positive")
}
// Generate ceil(length * 6/8) bytes to have enough entropy for base64 url encoding.
byteLen := (length*6 + 7) / 8
bytes, err := RandomBytes(byteLen)
if err != nil {
return "", err
}
s := base64.RawURLEncoding.EncodeToString(bytes)
if len(s) > length {
return s[:length], nil
}
// If encoded string shorter, pad using charset
if len(s) < length {
sb := strings.Builder{}
sb.WriteString(s)
extraNeeded := length - len(s)
more, err := randomFromCharset(extraNeeded)
if err != nil {
return "", err
}
sb.WriteString(more)
return sb.String(), nil
}
return s, nil
}
// PKCECodeVerifier generates a random string compliant with RFC 7636.
func PKCECodeVerifier(length int) (string, error) {
if length < 43 {
length = 43
}
if length > 128 {
length = 128
}
return randomFromCharset(length)
}
// randomFromCharset returns a random string of given length using pkceCharset.
func randomFromCharset(length int) (string, error) {
if length <= 0 {
return "", fmt.Errorf("length must be positive")
}
bytes, err := RandomBytes(length)
if err != nil {
return "", err
}
out := make([]byte, length)
for i, b := range bytes {
out[i] = pkceCharset[int(b)%len(pkceCharset)]
}
return string(out), nil
}
// Base64URLEncode encodes the input bytes using base64 URL encoding without padding.
func Base64URLEncode(data []byte) string {
return base64.RawURLEncoding.EncodeToString(data)
}
-45
View File
@@ -1,45 +0,0 @@
package utils
import (
"errors"
"strconv"
"github.com/golang-jwt/jwt/v5"
)
func VerifyToken(tokenStr, secret, tokenType string) (uint, error) {
token, err := jwt.Parse(tokenStr, func(_ *jwt.Token) (interface{}, error) {
return []byte(secret), nil
})
if err != nil || !token.Valid {
return 0, err
}
claims, ok := token.Claims.(jwt.MapClaims)
if !ok {
return 0, errors.New("invalid token claims")
}
jwtType, ok := claims["type"].(string)
if !ok || jwtType != tokenType {
return 0, errors.New("invalid token type")
}
sub, ok := claims["sub"]
if !ok {
return 0, errors.New("invalid token sub")
}
switch v := sub.(type) {
case float64:
return uint(v), nil
case string:
id, err := strconv.Atoi(v)
if err != nil {
return 0, errors.New("invalid sub format")
}
return uint(id), nil
default:
return 0, errors.New("unsupported sub type")
}
}