mirror of
https://gitlab.com/mbugroup/lti-api.git
synced 2026-05-20 13:31:56 +00:00
feat/login crud in users sync with sso
This commit is contained in:
@@ -32,3 +32,23 @@ CORS_MAX_AGE=600
|
|||||||
# Redis
|
# Redis
|
||||||
REDIS_URL=redis://redis:6379/0
|
REDIS_URL=redis://redis:6379/0
|
||||||
REDIS_PORT_HOST=6381
|
REDIS_PORT_HOST=6381
|
||||||
|
|
||||||
|
# SSO Integration
|
||||||
|
SSO_ISSUER=http://localhost:8080/api
|
||||||
|
SSO_JWKS_URL=http://localhost:8080/api/.well-known/jwks.json
|
||||||
|
SSO_ALLOWED_AUDIENCES=client:lti-api
|
||||||
|
SSO_AUTHORIZE_URL=http://localhost:8080/sso/authorize
|
||||||
|
SSO_TOKEN_URL=http://localhost:8080/sso/token
|
||||||
|
SSO_GETME_URL=http://localhost:8080/api/auth/get-me
|
||||||
|
SSO_ACCESS_COOKIE_NAME=sso_access
|
||||||
|
SSO_REFRESH_COOKIE_NAME=sso_refresh
|
||||||
|
SSO_COOKIE_DOMAIN=
|
||||||
|
SSO_COOKIE_SECURE=false
|
||||||
|
SSO_COOKIE_SAMESITE=Lax
|
||||||
|
SSO_PKCE_TTL_SECONDS=300
|
||||||
|
# Security window and payload limits for SSO user sync webhook
|
||||||
|
SSO_USER_SYNC_SIGNATURE_DRIFT_SECONDS=120
|
||||||
|
SSO_USER_SYNC_NONCE_TTL_SECONDS=600
|
||||||
|
SSO_USER_SYNC_MAX_BODY_BYTES=32768
|
||||||
|
# Example JSON (single-line) of client configs (each client requires a unique sync_secret)
|
||||||
|
SSO_CLIENTS={"lti":{"public_id":"client:lti","redirect_uri":"http://localhost:8081/api/sso/callback","scope":"openid profile","default_return_uri":"http://localhost:3000","allowed_return_origins":["http://localhost:3000"],"sync_secret":"changeme"}}
|
||||||
|
|||||||
@@ -9,10 +9,12 @@ import (
|
|||||||
"syscall"
|
"syscall"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"gitlab.com/mbugroup/lti-api.git/internal/cache"
|
||||||
"gitlab.com/mbugroup/lti-api.git/internal/config"
|
"gitlab.com/mbugroup/lti-api.git/internal/config"
|
||||||
"gitlab.com/mbugroup/lti-api.git/internal/database"
|
"gitlab.com/mbugroup/lti-api.git/internal/database"
|
||||||
"gitlab.com/mbugroup/lti-api.git/internal/middleware"
|
"gitlab.com/mbugroup/lti-api.git/internal/middleware"
|
||||||
"gitlab.com/mbugroup/lti-api.git/internal/route"
|
"gitlab.com/mbugroup/lti-api.git/internal/route"
|
||||||
|
"gitlab.com/mbugroup/lti-api.git/internal/sso"
|
||||||
"gitlab.com/mbugroup/lti-api.git/internal/utils"
|
"gitlab.com/mbugroup/lti-api.git/internal/utils"
|
||||||
|
|
||||||
"github.com/gofiber/fiber/v2"
|
"github.com/gofiber/fiber/v2"
|
||||||
@@ -33,6 +35,7 @@ func main() {
|
|||||||
defer closeDatabase(db)
|
defer closeDatabase(db)
|
||||||
rdb := setupRedis()
|
rdb := setupRedis()
|
||||||
defer rdb.Close()
|
defer rdb.Close()
|
||||||
|
setupSSO(ctx)
|
||||||
setupRoutes(app, db, rdb)
|
setupRoutes(app, db, rdb)
|
||||||
|
|
||||||
address := fmt.Sprintf("%s:%d", config.AppHost, config.AppPort)
|
address := fmt.Sprintf("%s:%d", config.AppHost, config.AppPort)
|
||||||
@@ -52,10 +55,17 @@ func setupRedis() *redis.Client {
|
|||||||
if err := rdb.Ping(context.Background()).Err(); err != nil {
|
if err := rdb.Ping(context.Background()).Err(); err != nil {
|
||||||
utils.Log.Fatalf("Redis ping failed: %v", err)
|
utils.Log.Fatalf("Redis ping failed: %v", err)
|
||||||
}
|
}
|
||||||
|
cache.SetRedis(rdb)
|
||||||
utils.Log.Infof("Redis connected: %s", config.RedisURL)
|
utils.Log.Infof("Redis connected: %s", config.RedisURL)
|
||||||
return rdb
|
return rdb
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func setupSSO(ctx context.Context) {
|
||||||
|
if err := sso.Init(ctx, config.SSOJWKSURL, config.SSOIssuer, config.SSOAllowedAudiences); err != nil {
|
||||||
|
utils.Log.Fatalf("SSO initialization failed: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func setupFiberApp() *fiber.App {
|
func setupFiberApp() *fiber.App {
|
||||||
app := fiber.New(config.FiberConfig())
|
app := fiber.New(config.FiberConfig())
|
||||||
|
|
||||||
|
|||||||
Vendored
+38
@@ -0,0 +1,38 @@
|
|||||||
|
package cache
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"sync"
|
||||||
|
|
||||||
|
"github.com/redis/go-redis/v9"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
redisClient *redis.Client
|
||||||
|
mu sync.RWMutex
|
||||||
|
)
|
||||||
|
|
||||||
|
// SetRedis assigns the global redis client used across the application.
|
||||||
|
func SetRedis(client *redis.Client) {
|
||||||
|
mu.Lock()
|
||||||
|
defer mu.Unlock()
|
||||||
|
redisClient = client
|
||||||
|
}
|
||||||
|
|
||||||
|
// Redis returns the configured redis client. It may be nil if not yet initialised.
|
||||||
|
func Redis() *redis.Client {
|
||||||
|
mu.RLock()
|
||||||
|
defer mu.RUnlock()
|
||||||
|
return redisClient
|
||||||
|
}
|
||||||
|
|
||||||
|
// MustRedis returns the redis client or panics if it has not been set.
|
||||||
|
func MustRedis() *redis.Client {
|
||||||
|
mu.RLock()
|
||||||
|
client := redisClient
|
||||||
|
mu.RUnlock()
|
||||||
|
if client == nil {
|
||||||
|
panic(errors.New("redis client not initialised"))
|
||||||
|
}
|
||||||
|
return client
|
||||||
|
}
|
||||||
@@ -11,6 +11,7 @@ type BaseRepository[T any] interface {
|
|||||||
GetAll(ctx context.Context, offset, limit int, modifier func(*gorm.DB) *gorm.DB) ([]T, int64, error)
|
GetAll(ctx context.Context, offset, limit int, modifier func(*gorm.DB) *gorm.DB) ([]T, int64, error)
|
||||||
GetByID(ctx context.Context, id uint, modifier func(*gorm.DB) *gorm.DB) (*T, error)
|
GetByID(ctx context.Context, id uint, modifier func(*gorm.DB) *gorm.DB) (*T, error)
|
||||||
GetByIDs(ctx context.Context, ids []uint, modifier func(*gorm.DB) *gorm.DB) ([]T, error)
|
GetByIDs(ctx context.Context, ids []uint, modifier func(*gorm.DB) *gorm.DB) ([]T, error)
|
||||||
|
First(ctx context.Context, modifier func(*gorm.DB) *gorm.DB) (*T, error)
|
||||||
|
|
||||||
CreateOne(ctx context.Context, entity *T, modifier func(*gorm.DB) *gorm.DB) error
|
CreateOne(ctx context.Context, entity *T, modifier func(*gorm.DB) *gorm.DB) error
|
||||||
CreateMany(ctx context.Context, entities []*T, modifier func(*gorm.DB) *gorm.DB) error
|
CreateMany(ctx context.Context, entities []*T, modifier func(*gorm.DB) *gorm.DB) error
|
||||||
@@ -96,6 +97,21 @@ func (r *BaseRepositoryImpl[T]) GetByIDs(
|
|||||||
return entities, nil
|
return entities, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (r *BaseRepositoryImpl[T]) First(
|
||||||
|
ctx context.Context,
|
||||||
|
modifier func(*gorm.DB) *gorm.DB,
|
||||||
|
) (*T, error) {
|
||||||
|
entity := new(T)
|
||||||
|
q := r.db.WithContext(ctx)
|
||||||
|
if modifier != nil {
|
||||||
|
q = modifier(q)
|
||||||
|
}
|
||||||
|
if err := q.First(entity).Error; err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return entity, nil
|
||||||
|
}
|
||||||
|
|
||||||
// ---- CREATE ----
|
// ---- CREATE ----
|
||||||
func (r *BaseRepositoryImpl[T]) CreateOne(
|
func (r *BaseRepositoryImpl[T]) CreateOne(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
|
|||||||
+154
-22
@@ -2,36 +2,64 @@ package config
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
"strings"
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
"gitlab.com/mbugroup/lti-api.git/internal/utils"
|
"gitlab.com/mbugroup/lti-api.git/internal/utils"
|
||||||
|
|
||||||
"github.com/spf13/viper"
|
"github.com/spf13/viper"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
type SSOClientConfig struct {
|
||||||
|
PublicID string `json:"public_id"`
|
||||||
|
RedirectURI string `json:"redirect_uri"`
|
||||||
|
Scope string `json:"scope"`
|
||||||
|
Prompt string `json:"prompt"`
|
||||||
|
DefaultReturnURI string `json:"default_return_uri"`
|
||||||
|
AllowedReturnOrigins []string `json:"allowed_return_origins"`
|
||||||
|
SyncSecret string `json:"sync_secret"`
|
||||||
|
}
|
||||||
|
|
||||||
var (
|
var (
|
||||||
IsProd bool
|
IsProd bool
|
||||||
AppHost string
|
AppHost string
|
||||||
Version string
|
Version string
|
||||||
LogLevel string
|
LogLevel string
|
||||||
AppPort int
|
AppPort int
|
||||||
DBHost string
|
DBHost string
|
||||||
DBUser string
|
DBUser string
|
||||||
DBPassword string
|
DBPassword string
|
||||||
DBName string
|
DBName string
|
||||||
DBPort int
|
DBPort int
|
||||||
JWTSecret string
|
JWTSecret string
|
||||||
JWTAccessExp int
|
JWTAccessExp int
|
||||||
JWTRefreshExp int
|
JWTRefreshExp int
|
||||||
JWTResetPasswordExp int
|
JWTResetPasswordExp int
|
||||||
JWTVerifyEmailExp int
|
JWTVerifyEmailExp int
|
||||||
RedisURL string
|
RedisURL string
|
||||||
CORSAllowOrigins []string
|
CORSAllowOrigins []string
|
||||||
CORSAllowMethods []string
|
CORSAllowMethods []string
|
||||||
CORSAllowHeaders []string
|
CORSAllowHeaders []string
|
||||||
CORSExposeHeaders []string
|
CORSExposeHeaders []string
|
||||||
CORSAllowCredentials bool
|
CORSAllowCredentials bool
|
||||||
CORSMaxAge int
|
CORSMaxAge int
|
||||||
|
SSOIssuer string
|
||||||
|
SSOJWKSURL string
|
||||||
|
SSOAllowedAudiences []string
|
||||||
|
SSOAuthorizeURL string
|
||||||
|
SSOTokenURL string
|
||||||
|
SSOGetMeURL string
|
||||||
|
SSOClients map[string]SSOClientConfig
|
||||||
|
SSOAccessCookieName string
|
||||||
|
SSORefreshCookieName string
|
||||||
|
SSOCookieDomain string
|
||||||
|
SSOCookieSecure bool
|
||||||
|
SSOCookieSameSite string
|
||||||
|
SSOPKCETTL time.Duration
|
||||||
|
SSOUserSyncDrift time.Duration
|
||||||
|
SSOUserSyncNonceTTL time.Duration
|
||||||
|
SSOUserSyncMaxBodyBytes int
|
||||||
)
|
)
|
||||||
|
|
||||||
func init() {
|
func init() {
|
||||||
@@ -68,6 +96,43 @@ func init() {
|
|||||||
|
|
||||||
// Redis
|
// Redis
|
||||||
RedisURL = viper.GetString("REDIS_URL")
|
RedisURL = viper.GetString("REDIS_URL")
|
||||||
|
|
||||||
|
// SSO integration
|
||||||
|
SSOIssuer = viper.GetString("SSO_ISSUER")
|
||||||
|
SSOJWKSURL = viper.GetString("SSO_JWKS_URL")
|
||||||
|
SSOAllowedAudiences = parseList("SSO_ALLOWED_AUDIENCES")
|
||||||
|
SSOAuthorizeURL = viper.GetString("SSO_AUTHORIZE_URL")
|
||||||
|
SSOTokenURL = viper.GetString("SSO_TOKEN_URL")
|
||||||
|
SSOGetMeURL = viper.GetString("SSO_GETME_URL")
|
||||||
|
SSOAccessCookieName = defaultString(viper.GetString("SSO_ACCESS_COOKIE_NAME"), "sso_access")
|
||||||
|
SSORefreshCookieName = defaultString(viper.GetString("SSO_REFRESH_COOKIE_NAME"), "sso_refresh")
|
||||||
|
SSOCookieDomain = viper.GetString("SSO_COOKIE_DOMAIN")
|
||||||
|
SSOCookieSecure = viper.GetBool("SSO_COOKIE_SECURE")
|
||||||
|
SSOCookieSameSite = defaultString(viper.GetString("SSO_COOKIE_SAMESITE"), "Lax")
|
||||||
|
if ttl := viper.GetInt("SSO_PKCE_TTL_SECONDS"); ttl > 0 {
|
||||||
|
SSOPKCETTL = time.Duration(ttl) * time.Second
|
||||||
|
} else {
|
||||||
|
SSOPKCETTL = 5 * time.Minute
|
||||||
|
}
|
||||||
|
SSOClients = loadSSOClients("SSO_CLIENTS")
|
||||||
|
if drift := viper.GetInt("SSO_USER_SYNC_SIGNATURE_DRIFT_SECONDS"); drift > 0 {
|
||||||
|
SSOUserSyncDrift = time.Duration(drift) * time.Second
|
||||||
|
} else {
|
||||||
|
SSOUserSyncDrift = 2 * time.Minute
|
||||||
|
}
|
||||||
|
if ttl := viper.GetInt("SSO_USER_SYNC_NONCE_TTL_SECONDS"); ttl > 0 {
|
||||||
|
SSOUserSyncNonceTTL = time.Duration(ttl) * time.Second
|
||||||
|
} else {
|
||||||
|
SSOUserSyncNonceTTL = 10 * time.Minute
|
||||||
|
}
|
||||||
|
SSOUserSyncMaxBodyBytes = viper.GetInt("SSO_USER_SYNC_MAX_BODY_BYTES")
|
||||||
|
if SSOUserSyncMaxBodyBytes <= 0 {
|
||||||
|
SSOUserSyncMaxBodyBytes = 32 * 1024
|
||||||
|
}
|
||||||
|
|
||||||
|
if IsProd {
|
||||||
|
ensureProdConfig()
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func loadConfig() {
|
func loadConfig() {
|
||||||
@@ -117,3 +182,70 @@ func parseListWithDefault(key, def string) []string {
|
|||||||
}
|
}
|
||||||
return parts
|
return parts
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func loadSSOClients(key string) map[string]SSOClientConfig {
|
||||||
|
clients := make(map[string]SSOClientConfig)
|
||||||
|
raw := strings.TrimSpace(viper.GetString(key))
|
||||||
|
if raw == "" {
|
||||||
|
return clients
|
||||||
|
}
|
||||||
|
if err := json.Unmarshal([]byte(raw), &clients); err != nil {
|
||||||
|
utils.Log.Errorf("Failed to parse %s: %v", key, err)
|
||||||
|
return make(map[string]SSOClientConfig)
|
||||||
|
}
|
||||||
|
result := make(map[string]SSOClientConfig, len(clients))
|
||||||
|
for alias, cfg := range clients {
|
||||||
|
alias = strings.ToLower(strings.TrimSpace(alias))
|
||||||
|
for i, origin := range cfg.AllowedReturnOrigins {
|
||||||
|
cfg.AllowedReturnOrigins[i] = strings.TrimSpace(origin)
|
||||||
|
}
|
||||||
|
cfg.SyncSecret = strings.TrimSpace(cfg.SyncSecret)
|
||||||
|
result[alias] = cfg
|
||||||
|
}
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
|
||||||
|
func defaultString(v, def string) string {
|
||||||
|
if strings.TrimSpace(v) == "" {
|
||||||
|
return def
|
||||||
|
}
|
||||||
|
return v
|
||||||
|
}
|
||||||
|
|
||||||
|
func ensureProdConfig() {
|
||||||
|
if SSOAuthorizeURL == "" || !strings.HasPrefix(SSOAuthorizeURL, "https://") {
|
||||||
|
panic("SSO_AUTHORIZE_URL must be https in production")
|
||||||
|
}
|
||||||
|
if SSOTokenURL == "" || !strings.HasPrefix(SSOTokenURL, "https://") {
|
||||||
|
panic("SSO_TOKEN_URL must be https in production")
|
||||||
|
}
|
||||||
|
if SSOGetMeURL == "" || !strings.HasPrefix(SSOGetMeURL, "https://") {
|
||||||
|
panic("SSO_GETME_URL must be https in production")
|
||||||
|
}
|
||||||
|
if !SSOCookieSecure {
|
||||||
|
panic("SSO_COOKIE_SECURE must be true in production")
|
||||||
|
}
|
||||||
|
if SSOCookieDomain == "" {
|
||||||
|
panic("SSO_COOKIE_DOMAIN must be configured in production")
|
||||||
|
}
|
||||||
|
if len(SSOAllowedAudiences) == 0 {
|
||||||
|
panic("SSO_ALLOWED_AUDIENCES must contain at least one audience in production")
|
||||||
|
}
|
||||||
|
for alias, cfg := range SSOClients {
|
||||||
|
if strings.TrimSpace(cfg.SyncSecret) == "" {
|
||||||
|
panic(fmt.Sprintf("SSO_CLIENTS[%s].sync_secret must be configured in production", alias))
|
||||||
|
}
|
||||||
|
if len(cfg.SyncSecret) < 16 {
|
||||||
|
panic(fmt.Sprintf("SSO_CLIENTS[%s].sync_secret must be at least 16 characters", alias))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if SSOUserSyncDrift <= 0 {
|
||||||
|
panic("SSO_USER_SYNC_SIGNATURE_DRIFT_SECONDS must be greater than zero in production")
|
||||||
|
}
|
||||||
|
if SSOUserSyncNonceTTL <= 0 {
|
||||||
|
panic("SSO_USER_SYNC_NONCE_TTL_SECONDS must be greater than zero in production")
|
||||||
|
}
|
||||||
|
if SSOUserSyncMaxBodyBytes <= 0 {
|
||||||
|
panic("SSO_USER_SYNC_MAX_BODY_BYTES must be greater than zero in production")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -10,6 +10,7 @@ CREATE TABLE users (
|
|||||||
);
|
);
|
||||||
|
|
||||||
CREATE UNIQUE INDEX users_id_user_unique ON users (id_user) WHERE deleted_at IS NULL;
|
CREATE UNIQUE INDEX users_id_user_unique ON users (id_user) WHERE deleted_at IS NULL;
|
||||||
|
|
||||||
CREATE UNIQUE INDEX users_email_unique ON users (email) WHERE deleted_at IS NULL;
|
CREATE UNIQUE INDEX users_email_unique ON users (email) WHERE deleted_at IS NULL;
|
||||||
|
|
||||||
-- FLAGS
|
-- FLAGS
|
||||||
|
|||||||
@@ -5,7 +5,7 @@ import (
|
|||||||
|
|
||||||
"gitlab.com/mbugroup/lti-api.git/internal/config"
|
"gitlab.com/mbugroup/lti-api.git/internal/config"
|
||||||
service "gitlab.com/mbugroup/lti-api.git/internal/modules/users/services"
|
service "gitlab.com/mbugroup/lti-api.git/internal/modules/users/services"
|
||||||
"gitlab.com/mbugroup/lti-api.git/internal/utils"
|
"gitlab.com/mbugroup/lti-api.git/internal/sso"
|
||||||
|
|
||||||
"github.com/gofiber/fiber/v2"
|
"github.com/gofiber/fiber/v2"
|
||||||
)
|
)
|
||||||
@@ -15,21 +15,50 @@ func Auth(userService service.UserService, requiredRights ...string) fiber.Handl
|
|||||||
authHeader := c.Get("Authorization")
|
authHeader := c.Get("Authorization")
|
||||||
token := strings.TrimSpace(strings.TrimPrefix(authHeader, "Bearer "))
|
token := strings.TrimSpace(strings.TrimPrefix(authHeader, "Bearer "))
|
||||||
|
|
||||||
|
if token == "" {
|
||||||
|
cookieName := config.SSOAccessCookieName
|
||||||
|
if cookieName == "" {
|
||||||
|
cookieName = "access"
|
||||||
|
}
|
||||||
|
token = strings.TrimSpace(c.Cookies(cookieName))
|
||||||
|
}
|
||||||
|
|
||||||
if token == "" {
|
if token == "" {
|
||||||
return fiber.NewError(fiber.StatusUnauthorized, "Please authenticate")
|
return fiber.NewError(fiber.StatusUnauthorized, "Please authenticate")
|
||||||
}
|
}
|
||||||
|
|
||||||
userID, err := utils.VerifyToken(token, config.JWTSecret, config.TokenTypeAccess)
|
verification, err := sso.VerifyAccessToken(token)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fiber.NewError(fiber.StatusUnauthorized, "Please authenticate")
|
return fiber.NewError(fiber.StatusUnauthorized, "Please authenticate")
|
||||||
}
|
}
|
||||||
|
|
||||||
user, err := userService.GetOne(c, userID)
|
if len(config.SSOAllowedAudiences) > 0 {
|
||||||
|
allowed := make(map[string]struct{}, len(config.SSOAllowedAudiences))
|
||||||
|
for _, aud := range config.SSOAllowedAudiences {
|
||||||
|
aud = strings.TrimSpace(aud)
|
||||||
|
if aud != "" {
|
||||||
|
allowed[aud] = struct{}{}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
audienceValid := false
|
||||||
|
for _, aud := range verification.Claims.Audience {
|
||||||
|
if _, ok := allowed[aud]; ok {
|
||||||
|
audienceValid = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !audienceValid {
|
||||||
|
return fiber.NewError(fiber.StatusUnauthorized, "invalid audience")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
user, err := userService.GetBySSOUserID(c, verification.UserID)
|
||||||
if err != nil || user == nil {
|
if err != nil || user == nil {
|
||||||
return fiber.NewError(fiber.StatusUnauthorized, "Please authenticate")
|
return fiber.NewError(fiber.StatusUnauthorized, "Please authenticate")
|
||||||
}
|
}
|
||||||
|
|
||||||
c.Locals("user", user)
|
c.Locals("user", user)
|
||||||
|
c.Locals("token_claims", verification.Claims)
|
||||||
|
|
||||||
// if len(requiredRights) > 0 {
|
// if len(requiredRights) > 0 {
|
||||||
// userRights, hasRights := config.RoleRights[user.Role]
|
// userRights, hasRights := config.RoleRights[user.Role]
|
||||||
|
|||||||
@@ -24,3 +24,24 @@ func LimiterConfig() fiber.Handler {
|
|||||||
SkipSuccessfulRequests: true,
|
SkipSuccessfulRequests: true,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func NewLimiter(max int, expiration time.Duration) fiber.Handler {
|
||||||
|
if max <= 0 {
|
||||||
|
max = 10
|
||||||
|
}
|
||||||
|
if expiration <= 0 {
|
||||||
|
expiration = time.Minute
|
||||||
|
}
|
||||||
|
return limiter.New(limiter.Config{
|
||||||
|
Max: max,
|
||||||
|
Expiration: expiration,
|
||||||
|
LimitReached: func(c *fiber.Ctx) error {
|
||||||
|
return c.Status(fiber.StatusTooManyRequests).
|
||||||
|
JSON(response.Common{
|
||||||
|
Code: fiber.StatusTooManyRequests,
|
||||||
|
Status: "error",
|
||||||
|
Message: "Too many requests, please try again later",
|
||||||
|
})
|
||||||
|
},
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|||||||
@@ -0,0 +1,404 @@
|
|||||||
|
package controllers
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"crypto/sha256"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"net/url"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/gofiber/fiber/v2"
|
||||||
|
"github.com/sirupsen/logrus"
|
||||||
|
|
||||||
|
"gitlab.com/mbugroup/lti-api.git/internal/config"
|
||||||
|
"gitlab.com/mbugroup/lti-api.git/internal/modules/sso/session"
|
||||||
|
"gitlab.com/mbugroup/lti-api.git/internal/sso"
|
||||||
|
"gitlab.com/mbugroup/lti-api.git/internal/utils"
|
||||||
|
"gitlab.com/mbugroup/lti-api.git/internal/utils/secure"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Controller manages the SSO start & callback flow using PKCE.
|
||||||
|
type Controller struct {
|
||||||
|
httpClient *http.Client
|
||||||
|
store *session.Store
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewController(client *http.Client, store *session.Store) *Controller {
|
||||||
|
return &Controller{httpClient: client, store: store}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Start handles GET /sso/start requests and redirects users to the central SSO authorize endpoint.
|
||||||
|
func (h *Controller) Start(c *fiber.Ctx) error {
|
||||||
|
alias := strings.ToLower(strings.TrimSpace(c.Query("client")))
|
||||||
|
if alias == "" {
|
||||||
|
alias = strings.ToLower(strings.TrimSpace(c.Query("client_id")))
|
||||||
|
}
|
||||||
|
if alias == "" {
|
||||||
|
return fiber.NewError(fiber.StatusBadRequest, "missing client")
|
||||||
|
}
|
||||||
|
cfg, ok := config.SSOClients[alias]
|
||||||
|
if !ok || cfg.PublicID == "" {
|
||||||
|
return fiber.NewError(fiber.StatusBadRequest, "unknown client")
|
||||||
|
}
|
||||||
|
|
||||||
|
authorizeEndpoint := strings.TrimSpace(config.SSOAuthorizeURL)
|
||||||
|
if authorizeEndpoint == "" {
|
||||||
|
return fiber.NewError(fiber.StatusInternalServerError, "authorize endpoint not configured")
|
||||||
|
}
|
||||||
|
|
||||||
|
state, err := secure.RandomString(48)
|
||||||
|
if err != nil {
|
||||||
|
utils.Log.Errorf("generate state failed: %v", err)
|
||||||
|
return fiber.NewError(fiber.StatusInternalServerError, "failed to prepare authorization")
|
||||||
|
}
|
||||||
|
|
||||||
|
nonce, err := secure.RandomString(32)
|
||||||
|
if err != nil {
|
||||||
|
utils.Log.Errorf("generate nonce failed: %v", err)
|
||||||
|
return fiber.NewError(fiber.StatusInternalServerError, "failed to prepare authorization")
|
||||||
|
}
|
||||||
|
|
||||||
|
codeVerifier, err := secure.PKCECodeVerifier(96)
|
||||||
|
if err != nil {
|
||||||
|
utils.Log.Errorf("generate code verifier failed: %v", err)
|
||||||
|
return fiber.NewError(fiber.StatusInternalServerError, "failed to prepare authorization")
|
||||||
|
}
|
||||||
|
|
||||||
|
digest := sha256.Sum256([]byte(codeVerifier))
|
||||||
|
challenge := secure.Base64URLEncode(digest[:])
|
||||||
|
|
||||||
|
authorizeURL, err := url.Parse(authorizeEndpoint)
|
||||||
|
if err != nil {
|
||||||
|
return fiber.NewError(fiber.StatusInternalServerError, "invalid authorize endpoint")
|
||||||
|
}
|
||||||
|
|
||||||
|
scope := cfg.Scope
|
||||||
|
if scope == "" {
|
||||||
|
scope = "openid profile"
|
||||||
|
}
|
||||||
|
if !strings.Contains(" "+scope+" ", " openid ") {
|
||||||
|
scope = scope + " openid"
|
||||||
|
}
|
||||||
|
|
||||||
|
rawReturn := strings.TrimSpace(c.Query("return_to"))
|
||||||
|
if rawReturn == "" {
|
||||||
|
rawReturn = cfg.DefaultReturnURI
|
||||||
|
}
|
||||||
|
|
||||||
|
returnTo, err := normalizeReturnTarget(rawReturn, cfg)
|
||||||
|
if err != nil {
|
||||||
|
return fiber.NewError(fiber.StatusBadRequest, err.Error())
|
||||||
|
}
|
||||||
|
|
||||||
|
query := authorizeURL.Query()
|
||||||
|
query.Set("response_type", "code")
|
||||||
|
query.Set("client_id", cfg.PublicID)
|
||||||
|
query.Set("redirect_uri", cfg.RedirectURI)
|
||||||
|
query.Set("scope", strings.TrimSpace(scope))
|
||||||
|
query.Set("state", state)
|
||||||
|
query.Set("code_challenge", challenge)
|
||||||
|
query.Set("code_challenge_method", "S256")
|
||||||
|
query.Set("nonce", nonce)
|
||||||
|
if prompt := strings.TrimSpace(cfg.Prompt); prompt != "" {
|
||||||
|
query.Set("prompt", prompt)
|
||||||
|
}
|
||||||
|
if extraPrompt := strings.TrimSpace(c.Query("prompt")); extraPrompt != "" {
|
||||||
|
query.Set("prompt", extraPrompt)
|
||||||
|
}
|
||||||
|
authorizeURL.RawQuery = query.Encode()
|
||||||
|
|
||||||
|
payload := &session.PKCESession{
|
||||||
|
CodeVerifier: codeVerifier,
|
||||||
|
Nonce: nonce,
|
||||||
|
ClientAlias: alias,
|
||||||
|
ClientID: cfg.PublicID,
|
||||||
|
RedirectURI: cfg.RedirectURI,
|
||||||
|
Scope: strings.TrimSpace(scope),
|
||||||
|
ReturnTo: returnTo,
|
||||||
|
CreatedAt: time.Now().UTC(),
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := h.store.Save(c.Context(), state, payload); err != nil {
|
||||||
|
utils.Log.Errorf("store pkce session failed: %v", err)
|
||||||
|
return fiber.NewError(fiber.StatusInternalServerError, "failed to prepare authorization")
|
||||||
|
}
|
||||||
|
|
||||||
|
utils.Log.WithFields(logrus.Fields{
|
||||||
|
"client": alias,
|
||||||
|
"state": state,
|
||||||
|
"return_to": returnTo,
|
||||||
|
}).Info("sso start redirect")
|
||||||
|
|
||||||
|
return c.Redirect(authorizeURL.String(), fiber.StatusFound)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Callback handles the redirect from SSO containing the authorization code.
|
||||||
|
func (h *Controller) Callback(c *fiber.Ctx) error {
|
||||||
|
state := strings.TrimSpace(c.Query("state"))
|
||||||
|
code := strings.TrimSpace(c.Query("code"))
|
||||||
|
if state == "" || code == "" {
|
||||||
|
return fiber.NewError(fiber.StatusBadRequest, "missing code or state")
|
||||||
|
}
|
||||||
|
|
||||||
|
sessionData, err := h.store.Get(c.Context(), state)
|
||||||
|
if err != nil {
|
||||||
|
utils.Log.Errorf("load pkce session failed: %v", err)
|
||||||
|
return fiber.NewError(fiber.StatusInternalServerError, "failed to validate authorization state")
|
||||||
|
}
|
||||||
|
if sessionData == nil {
|
||||||
|
return fiber.NewError(fiber.StatusBadRequest, "authorization state not found or expired")
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
if err := h.store.Delete(context.Background(), state); err != nil {
|
||||||
|
utils.Log.Warnf("failed to delete pkce session: %v", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
tokenEndpoint := strings.TrimSpace(config.SSOTokenURL)
|
||||||
|
if tokenEndpoint == "" {
|
||||||
|
return fiber.NewError(fiber.StatusInternalServerError, "token endpoint not configured")
|
||||||
|
}
|
||||||
|
|
||||||
|
form := url.Values{}
|
||||||
|
form.Set("grant_type", "authorization_code")
|
||||||
|
form.Set("code", code)
|
||||||
|
form.Set("code_verifier", sessionData.CodeVerifier)
|
||||||
|
form.Set("redirect_uri", sessionData.RedirectURI)
|
||||||
|
form.Set("client_id", sessionData.ClientID)
|
||||||
|
|
||||||
|
req, err := http.NewRequestWithContext(c.Context(), http.MethodPost, tokenEndpoint, strings.NewReader(form.Encode()))
|
||||||
|
if err != nil {
|
||||||
|
return fiber.NewError(fiber.StatusInternalServerError, "failed to create token request")
|
||||||
|
}
|
||||||
|
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||||
|
|
||||||
|
resp, err := h.httpClient.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
utils.Log.Errorf("token request failed: %v", err)
|
||||||
|
return fiber.NewError(fiber.StatusBadGateway, "failed to exchange authorization code")
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
if resp.StatusCode >= 400 {
|
||||||
|
utils.Log.Warnf("token response status %d", resp.StatusCode)
|
||||||
|
return fiber.NewError(fiber.StatusBadGateway, "token exchange rejected")
|
||||||
|
}
|
||||||
|
|
||||||
|
var tokenResp struct {
|
||||||
|
AccessToken string `json:"access_token"`
|
||||||
|
RefreshToken string `json:"refresh_token"`
|
||||||
|
TokenType string `json:"token_type"`
|
||||||
|
ExpiresIn int `json:"expires_in"`
|
||||||
|
Scope string `json:"scope"`
|
||||||
|
IDToken string `json:"id_token"`
|
||||||
|
Error string `json:"error"`
|
||||||
|
Description string `json:"error_description"`
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := json.NewDecoder(resp.Body).Decode(&tokenResp); err != nil {
|
||||||
|
return fiber.NewError(fiber.StatusBadGateway, "invalid token response")
|
||||||
|
}
|
||||||
|
if tokenResp.Error != "" {
|
||||||
|
return fiber.NewError(fiber.StatusBadGateway, tokenResp.Description)
|
||||||
|
}
|
||||||
|
if tokenResp.AccessToken == "" {
|
||||||
|
return fiber.NewError(fiber.StatusBadGateway, "missing access token")
|
||||||
|
}
|
||||||
|
|
||||||
|
verification, err := sso.VerifyAccessToken(tokenResp.AccessToken)
|
||||||
|
if err != nil {
|
||||||
|
utils.Log.Errorf("access token verification failed: %v", err)
|
||||||
|
return fiber.NewError(fiber.StatusUnauthorized, "invalid access token")
|
||||||
|
}
|
||||||
|
|
||||||
|
// prepare cookies
|
||||||
|
issueCookies(c, tokenResp, verification)
|
||||||
|
|
||||||
|
redirectTarget := sessionData.ReturnTo
|
||||||
|
if redirectTarget == "" {
|
||||||
|
redirectTarget = "/"
|
||||||
|
}
|
||||||
|
|
||||||
|
fmt.Println(sessionData.ClientAlias,"test")
|
||||||
|
utils.Log.WithFields(logrus.Fields{
|
||||||
|
"client": sessionData.ClientAlias,
|
||||||
|
"user_id": verification.UserID,
|
||||||
|
"return_to": redirectTarget,
|
||||||
|
}).Info("sso callback successful")
|
||||||
|
|
||||||
|
return c.Redirect(redirectTarget, fiber.StatusFound)
|
||||||
|
}
|
||||||
|
|
||||||
|
// UserInfo proxies the user profile from the central SSO so the frontend can obtain
|
||||||
|
// enriched user metadata (roles, permissions, etc.) without exposing tokens to the browser.
|
||||||
|
func (h *Controller) UserInfo(c *fiber.Ctx) error {
|
||||||
|
accessName := config.SSOAccessCookieName
|
||||||
|
if accessName == "" {
|
||||||
|
accessName = "sso_access"
|
||||||
|
}
|
||||||
|
|
||||||
|
token := strings.TrimSpace(c.Cookies(accessName))
|
||||||
|
tokenFromCookie := token != ""
|
||||||
|
|
||||||
|
if !tokenFromCookie {
|
||||||
|
authHeader := strings.TrimSpace(c.Get("Authorization"))
|
||||||
|
if strings.HasPrefix(strings.ToLower(authHeader), "bearer ") {
|
||||||
|
token = strings.TrimSpace(authHeader[7:])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if token == "" {
|
||||||
|
return fiber.NewError(fiber.StatusUnauthorized, "unauthenticated")
|
||||||
|
}
|
||||||
|
|
||||||
|
endpoint := strings.TrimSpace(config.SSOGetMeURL)
|
||||||
|
if endpoint == "" {
|
||||||
|
return fiber.NewError(fiber.StatusInternalServerError, "userinfo endpoint not configured")
|
||||||
|
}
|
||||||
|
|
||||||
|
req, err := http.NewRequestWithContext(c.Context(), http.MethodGet, endpoint, nil)
|
||||||
|
if err != nil {
|
||||||
|
utils.Log.Errorf("failed to build userinfo request: %v", err)
|
||||||
|
return fiber.NewError(fiber.StatusInternalServerError, "failed to prepare userinfo request")
|
||||||
|
}
|
||||||
|
req.Header.Set("Accept", "application/json")
|
||||||
|
|
||||||
|
// SSO /auth/get-me expects the access cookie; add Authorization as well for compatibility.
|
||||||
|
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token))
|
||||||
|
if tokenFromCookie {
|
||||||
|
req.Header.Set("Cookie", fmt.Sprintf("%s=%s", accessName, token))
|
||||||
|
}
|
||||||
|
|
||||||
|
resp, err := h.httpClient.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
utils.Log.Errorf("userinfo request failed: %v", err)
|
||||||
|
return fiber.NewError(fiber.StatusBadGateway, "failed to fetch user profile")
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
utils.Log.WithFields(logrus.Fields{"status": resp.StatusCode}).Info("sso userinfo response")
|
||||||
|
|
||||||
|
body, err := io.ReadAll(resp.Body)
|
||||||
|
if err != nil {
|
||||||
|
utils.Log.Errorf("failed to read userinfo response: %v", err)
|
||||||
|
return fiber.NewError(fiber.StatusBadGateway, "invalid user profile response")
|
||||||
|
}
|
||||||
|
|
||||||
|
if ct := resp.Header.Get("Content-Type"); ct != "" {
|
||||||
|
c.Set("Content-Type", ct)
|
||||||
|
} else {
|
||||||
|
c.Type("json")
|
||||||
|
}
|
||||||
|
|
||||||
|
return c.Status(resp.StatusCode).Send(body)
|
||||||
|
}
|
||||||
|
|
||||||
|
func issueCookies(c *fiber.Ctx, tokenResp struct {
|
||||||
|
AccessToken string `json:"access_token"`
|
||||||
|
RefreshToken string `json:"refresh_token"`
|
||||||
|
TokenType string `json:"token_type"`
|
||||||
|
ExpiresIn int `json:"expires_in"`
|
||||||
|
Scope string `json:"scope"`
|
||||||
|
IDToken string `json:"id_token"`
|
||||||
|
Error string `json:"error"`
|
||||||
|
Description string `json:"error_description"`
|
||||||
|
}, verification *sso.VerificationResult) {
|
||||||
|
fmt.Println(tokenResp.AccessToken)
|
||||||
|
accessName := config.SSOAccessCookieName
|
||||||
|
if accessName == "" {
|
||||||
|
accessName = "access"
|
||||||
|
}
|
||||||
|
refreshName := config.SSORefreshCookieName
|
||||||
|
if refreshName == "" {
|
||||||
|
refreshName = "refresh"
|
||||||
|
}
|
||||||
|
maxAge := tokenResp.ExpiresIn
|
||||||
|
if maxAge <= 0 {
|
||||||
|
maxAge = int(15 * time.Minute.Seconds())
|
||||||
|
}
|
||||||
|
|
||||||
|
sameSite := config.SSOCookieSameSite
|
||||||
|
if sameSite == "" {
|
||||||
|
sameSite = "Lax"
|
||||||
|
}
|
||||||
|
|
||||||
|
cookieDomain := config.SSOCookieDomain
|
||||||
|
|
||||||
|
cookieAccess := &fiber.Cookie{
|
||||||
|
Name: accessName,
|
||||||
|
Value: tokenResp.AccessToken,
|
||||||
|
Path: "/",
|
||||||
|
Domain: cookieDomain,
|
||||||
|
HTTPOnly: true,
|
||||||
|
Secure: config.SSOCookieSecure,
|
||||||
|
SameSite: sameSite,
|
||||||
|
MaxAge: maxAge,
|
||||||
|
}
|
||||||
|
c.Cookie(cookieAccess)
|
||||||
|
if tokenResp.RefreshToken != "" {
|
||||||
|
cookieRefresh := &fiber.Cookie{
|
||||||
|
Name: refreshName,
|
||||||
|
Value: tokenResp.RefreshToken,
|
||||||
|
Path: "/",
|
||||||
|
Domain: cookieDomain,
|
||||||
|
HTTPOnly: true,
|
||||||
|
Secure: config.SSOCookieSecure,
|
||||||
|
SameSite: sameSite,
|
||||||
|
MaxAge: int((time.Hour * 24 * 30).Seconds()),
|
||||||
|
}
|
||||||
|
c.Cookie(cookieRefresh)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Optional: expose limited info via headers for FE debugging (avoid tokens)
|
||||||
|
c.Set("X-Auth-User", fmt.Sprintf("%d", verification.UserID))
|
||||||
|
}
|
||||||
|
|
||||||
|
func normalizeReturnTarget(returnTo string, cfg config.SSOClientConfig) (string, error) {
|
||||||
|
returnTo = strings.TrimSpace(returnTo)
|
||||||
|
if returnTo == "" {
|
||||||
|
return "", nil
|
||||||
|
}
|
||||||
|
if strings.HasPrefix(returnTo, "//") {
|
||||||
|
return "", fmt.Errorf("invalid return_to")
|
||||||
|
}
|
||||||
|
if strings.HasPrefix(returnTo, "/") {
|
||||||
|
return returnTo, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
parsed, err := url.Parse(returnTo)
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("invalid return_to")
|
||||||
|
}
|
||||||
|
if parsed.Scheme != "http" && parsed.Scheme != "https" {
|
||||||
|
return "", fmt.Errorf("invalid return_to scheme")
|
||||||
|
}
|
||||||
|
|
||||||
|
allowedOrigins := make(map[string]struct{})
|
||||||
|
if cfg.DefaultReturnURI != "" {
|
||||||
|
if u, err := url.Parse(cfg.DefaultReturnURI); err == nil && u.Host != "" {
|
||||||
|
allowedOrigins[u.Scheme+"://"+u.Host] = struct{}{}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for _, origin := range cfg.AllowedReturnOrigins {
|
||||||
|
origin = strings.TrimSpace(origin)
|
||||||
|
if origin == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if u, err := url.Parse(origin); err == nil && u.Host != "" && (u.Scheme == "http" || u.Scheme == "https") {
|
||||||
|
allowedOrigins[u.Scheme+"://"+u.Host] = struct{}{}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(allowedOrigins) > 0 {
|
||||||
|
origin := parsed.Scheme + "://" + parsed.Host
|
||||||
|
if _, ok := allowedOrigins[origin]; !ok {
|
||||||
|
return "", fmt.Errorf("return_to origin not allowed")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return parsed.String(), nil
|
||||||
|
}
|
||||||
@@ -0,0 +1,391 @@
|
|||||||
|
package controllers
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"crypto/hmac"
|
||||||
|
"crypto/sha256"
|
||||||
|
"encoding/base64"
|
||||||
|
"encoding/hex"
|
||||||
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
"github.com/go-playground/validator/v10"
|
||||||
|
"github.com/gofiber/fiber/v2"
|
||||||
|
"github.com/redis/go-redis/v9"
|
||||||
|
"github.com/sirupsen/logrus"
|
||||||
|
"gorm.io/gorm"
|
||||||
|
|
||||||
|
"gitlab.com/mbugroup/lti-api.git/internal/config"
|
||||||
|
"gitlab.com/mbugroup/lti-api.git/internal/modules/users/dto"
|
||||||
|
entity "gitlab.com/mbugroup/lti-api.git/internal/entities"
|
||||||
|
userRepository "gitlab.com/mbugroup/lti-api.git/internal/modules/users/repositories"
|
||||||
|
"gitlab.com/mbugroup/lti-api.git/internal/response"
|
||||||
|
"gitlab.com/mbugroup/lti-api.git/internal/sso"
|
||||||
|
"gitlab.com/mbugroup/lti-api.git/internal/utils"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
const (
|
||||||
|
headerClient = "X-Sync-Client"
|
||||||
|
headerTimestamp = "X-Sync-Timestamp"
|
||||||
|
headerNonce = "X-Sync-Nonce"
|
||||||
|
headerSignature = "X-Sync-Signature"
|
||||||
|
defaultDrift = 2 * time.Minute
|
||||||
|
defaultNonceTTL = 10 * time.Minute
|
||||||
|
)
|
||||||
|
|
||||||
|
// UserSyncController handles incoming user management events from the central SSO service.
|
||||||
|
type UserSyncController struct {
|
||||||
|
validate *validator.Validate
|
||||||
|
repo userRepository.UserRepository
|
||||||
|
redis *redis.Client
|
||||||
|
clients map[string]config.SSOClientConfig
|
||||||
|
drift time.Duration
|
||||||
|
nonceTTL time.Duration
|
||||||
|
maxBodyBytes int
|
||||||
|
log *logrus.Logger
|
||||||
|
localNonces sync.Map
|
||||||
|
}
|
||||||
|
|
||||||
|
type userSyncRequest struct {
|
||||||
|
Action string `json:"action" validate:"required,oneof=create update delete"`
|
||||||
|
PublicID string `json:"public_id" validate:"required"`
|
||||||
|
User userSyncUser `json:"user" validate:"required"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type userSyncUser struct {
|
||||||
|
ID int64 `json:"id" validate:"required"`
|
||||||
|
Email string `json:"email"`
|
||||||
|
Name string `json:"name"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewUserSyncController(validate *validator.Validate, repo userRepository.UserRepository, redis *redis.Client, clients map[string]config.SSOClientConfig) *UserSyncController {
|
||||||
|
normalized := make(map[string]config.SSOClientConfig, len(clients))
|
||||||
|
for alias, cfg := range clients {
|
||||||
|
alias = strings.ToLower(strings.TrimSpace(alias))
|
||||||
|
normalized[alias] = cfg
|
||||||
|
}
|
||||||
|
|
||||||
|
drift := config.SSOUserSyncDrift
|
||||||
|
if drift <= 0 {
|
||||||
|
drift = defaultDrift
|
||||||
|
}
|
||||||
|
|
||||||
|
nonceTTL := config.SSOUserSyncNonceTTL
|
||||||
|
if nonceTTL <= 0 {
|
||||||
|
nonceTTL = defaultNonceTTL
|
||||||
|
}
|
||||||
|
|
||||||
|
maxBody := config.SSOUserSyncMaxBodyBytes
|
||||||
|
if maxBody <= 0 {
|
||||||
|
maxBody = 32 * 1024
|
||||||
|
}
|
||||||
|
|
||||||
|
log := utils.Log
|
||||||
|
if redis == nil {
|
||||||
|
log.Warn("SSO user sync nonce store fallback to in-memory cache; enable Redis for replay protection")
|
||||||
|
}
|
||||||
|
|
||||||
|
return &UserSyncController{
|
||||||
|
validate: validate,
|
||||||
|
repo: repo,
|
||||||
|
redis: redis,
|
||||||
|
clients: normalized,
|
||||||
|
drift: drift,
|
||||||
|
nonceTTL: nonceTTL,
|
||||||
|
maxBodyBytes: maxBody,
|
||||||
|
log: log,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *UserSyncController) Sync(c *fiber.Ctx) error {
|
||||||
|
if ct := strings.TrimSpace(c.Get(fiber.HeaderContentType)); ct != "" && !strings.HasPrefix(strings.ToLower(ct), fiber.MIMEApplicationJSON) {
|
||||||
|
return fiber.NewError(fiber.StatusUnsupportedMediaType, "content-type must be application/json")
|
||||||
|
}
|
||||||
|
|
||||||
|
body := c.Body()
|
||||||
|
if h.maxBodyBytes > 0 && len(body) > h.maxBodyBytes {
|
||||||
|
return fiber.NewError(fiber.StatusRequestEntityTooLarge, "request body too large")
|
||||||
|
}
|
||||||
|
|
||||||
|
alias, clientCfg, err := h.authenticate(c, body)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
req := new(userSyncRequest)
|
||||||
|
if err := json.Unmarshal(body, req); err != nil {
|
||||||
|
return fiber.NewError(fiber.StatusBadRequest, "invalid request body")
|
||||||
|
}
|
||||||
|
|
||||||
|
req.Action = strings.ToLower(strings.TrimSpace(req.Action))
|
||||||
|
req.PublicID = strings.TrimSpace(req.PublicID)
|
||||||
|
req.User.Email = strings.TrimSpace(req.User.Email)
|
||||||
|
req.User.Name = strings.TrimSpace(req.User.Name)
|
||||||
|
|
||||||
|
if err := h.validate.Struct(req); err != nil {
|
||||||
|
return fiber.NewError(fiber.StatusBadRequest, err.Error())
|
||||||
|
}
|
||||||
|
|
||||||
|
if clientCfg.PublicID != "" && req.PublicID != clientCfg.PublicID {
|
||||||
|
return fiber.NewError(fiber.StatusBadRequest, "public_id mismatch with configured client")
|
||||||
|
}
|
||||||
|
|
||||||
|
if req.Action != "delete" {
|
||||||
|
if req.User.Email == "" || req.User.Name == "" {
|
||||||
|
return fiber.NewError(fiber.StatusBadRequest, "email and name are required for create/update actions")
|
||||||
|
}
|
||||||
|
if err := h.validate.Var(req.User.Email, "email"); err != nil {
|
||||||
|
return fiber.NewError(fiber.StatusBadRequest, "invalid email format")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if req.User.ID <= 0 {
|
||||||
|
return fiber.NewError(fiber.StatusBadRequest, "invalid user id")
|
||||||
|
}
|
||||||
|
|
||||||
|
switch req.Action {
|
||||||
|
case "create", "update":
|
||||||
|
return h.upsertUser(c, alias, req)
|
||||||
|
case "delete":
|
||||||
|
return h.removeUser(c, alias, req)
|
||||||
|
default:
|
||||||
|
return fiber.NewError(fiber.StatusBadRequest, "unsupported action")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *UserSyncController) authenticate(c *fiber.Ctx, body []byte) (string, config.SSOClientConfig, error) {
|
||||||
|
rawAlias := strings.TrimSpace(c.Get(headerClient))
|
||||||
|
if rawAlias == "" {
|
||||||
|
return "", config.SSOClientConfig{}, fiber.NewError(fiber.StatusUnauthorized, "missing sync client header")
|
||||||
|
}
|
||||||
|
|
||||||
|
aliasKey := strings.ToLower(rawAlias)
|
||||||
|
clientCfg, ok := h.clients[aliasKey]
|
||||||
|
if !ok {
|
||||||
|
return "", config.SSOClientConfig{}, fiber.NewError(fiber.StatusUnauthorized, "unknown sync client")
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := h.verifyAuthorization(c, aliasKey); err != nil {
|
||||||
|
return "", config.SSOClientConfig{}, err
|
||||||
|
}
|
||||||
|
|
||||||
|
secret := strings.TrimSpace(clientCfg.SyncSecret)
|
||||||
|
if secret == "" {
|
||||||
|
return "", config.SSOClientConfig{}, fiber.NewError(fiber.StatusUnauthorized, "sync secret not configured")
|
||||||
|
}
|
||||||
|
|
||||||
|
timestamp := strings.TrimSpace(c.Get(headerTimestamp))
|
||||||
|
nonce := strings.TrimSpace(c.Get(headerNonce))
|
||||||
|
signature := strings.TrimSpace(c.Get(headerSignature))
|
||||||
|
|
||||||
|
if timestamp == "" || nonce == "" || signature == "" {
|
||||||
|
return "", config.SSOClientConfig{}, fiber.NewError(fiber.StatusUnauthorized, "missing signature headers")
|
||||||
|
}
|
||||||
|
if len(nonce) < 16 {
|
||||||
|
return "", config.SSOClientConfig{}, fiber.NewError(fiber.StatusUnauthorized, "nonce too short")
|
||||||
|
}
|
||||||
|
|
||||||
|
ts, err := strconv.ParseInt(timestamp, 10, 64)
|
||||||
|
if err != nil {
|
||||||
|
return "", config.SSOClientConfig{}, fiber.NewError(fiber.StatusBadRequest, "invalid timestamp")
|
||||||
|
}
|
||||||
|
|
||||||
|
msgTime := time.Unix(ts, 0).UTC()
|
||||||
|
now := time.Now().UTC()
|
||||||
|
drift := now.Sub(msgTime)
|
||||||
|
if drift > h.drift || drift < -h.drift {
|
||||||
|
return "", config.SSOClientConfig{}, fiber.NewError(fiber.StatusUnauthorized, "timestamp outside allowed window")
|
||||||
|
}
|
||||||
|
|
||||||
|
providedSig, err := decodeSignature(signature)
|
||||||
|
if err != nil {
|
||||||
|
return "", config.SSOClientConfig{}, fiber.NewError(fiber.StatusUnauthorized, "invalid signature encoding")
|
||||||
|
}
|
||||||
|
|
||||||
|
expectedSignature := h.calculateSignature(secret, rawAlias, timestamp, nonce, body)
|
||||||
|
if !hmac.Equal(providedSig, expectedSignature) {
|
||||||
|
return "", config.SSOClientConfig{}, fiber.NewError(fiber.StatusUnauthorized, "invalid signature")
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := h.registerNonce(c.Context(), aliasKey, nonce); err != nil {
|
||||||
|
return "", config.SSOClientConfig{}, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return aliasKey, clientCfg, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *UserSyncController) verifyAuthorization(c *fiber.Ctx, alias string) error {
|
||||||
|
authHeader := strings.TrimSpace(c.Get(fiber.HeaderAuthorization))
|
||||||
|
if authHeader == "" {
|
||||||
|
return fiber.NewError(fiber.StatusUnauthorized, "missing authorization header")
|
||||||
|
}
|
||||||
|
|
||||||
|
parts := strings.SplitN(authHeader, " ", 2)
|
||||||
|
if len(parts) != 2 || !strings.EqualFold(parts[0], "Bearer") {
|
||||||
|
return fiber.NewError(fiber.StatusUnauthorized, "invalid authorization header")
|
||||||
|
}
|
||||||
|
|
||||||
|
token := strings.TrimSpace(parts[1])
|
||||||
|
if token == "" {
|
||||||
|
return fiber.NewError(fiber.StatusUnauthorized, "invalid authorization header")
|
||||||
|
}
|
||||||
|
|
||||||
|
verification, err := sso.VerifyAccessToken(token)
|
||||||
|
if err != nil {
|
||||||
|
return fiber.NewError(fiber.StatusUnauthorized, "invalid access token")
|
||||||
|
}
|
||||||
|
|
||||||
|
if verification.ServiceAlias == "" || verification.ServiceAlias != alias {
|
||||||
|
return fiber.NewError(fiber.StatusUnauthorized, "service subject mismatch")
|
||||||
|
}
|
||||||
|
|
||||||
|
if !containsScope(verification.Claims.Scopes(), "sync.users") {
|
||||||
|
return fiber.NewError(fiber.StatusForbidden, "missing sync scope")
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *UserSyncController) upsertUser(c *fiber.Ctx, alias string, req *userSyncRequest) error {
|
||||||
|
entity := &entity.User{
|
||||||
|
IdUser: req.User.ID,
|
||||||
|
Email: req.User.Email,
|
||||||
|
Name: req.User.Name,
|
||||||
|
}
|
||||||
|
|
||||||
|
//TODO: MIGRATION TO UPSERT BASE REPOSITORY
|
||||||
|
if err := h.repo.UpsertByIdUser(c.Context(), entity); err != nil {
|
||||||
|
h.log.Errorf("sso user upsert failed: %v", err)
|
||||||
|
return fiber.NewError(fiber.StatusInternalServerError, "failed to upsert user")
|
||||||
|
}
|
||||||
|
|
||||||
|
user, err := h.repo.GetByIdUser(c.Context(), req.User.ID, nil)
|
||||||
|
if err != nil {
|
||||||
|
h.log.Errorf("sso user fetch after upsert failed: %v", err)
|
||||||
|
return fiber.NewError(fiber.StatusInternalServerError, "failed to load user")
|
||||||
|
}
|
||||||
|
|
||||||
|
h.log.WithFields(logrus.Fields{
|
||||||
|
"action": req.Action,
|
||||||
|
"public_id": req.PublicID,
|
||||||
|
"alias": alias,
|
||||||
|
"user_id": req.User.ID,
|
||||||
|
}).Info("sso user synced")
|
||||||
|
|
||||||
|
msg := fmt.Sprintf("User %s successfully", req.Action)
|
||||||
|
return c.Status(fiber.StatusOK).JSON(response.Success{
|
||||||
|
Code: fiber.StatusOK,
|
||||||
|
Status: "success",
|
||||||
|
Message: msg,
|
||||||
|
Data: dto.ToUserListDTO(*user),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *UserSyncController) removeUser(c *fiber.Ctx, alias string, req *userSyncRequest) error {
|
||||||
|
if err := h.repo.SoftDeleteByIdUser(c.Context(), req.User.ID); err != nil {
|
||||||
|
if err == gorm.ErrRecordNotFound {
|
||||||
|
return fiber.NewError(fiber.StatusNotFound, "user not found")
|
||||||
|
}
|
||||||
|
h.log.Errorf("sso user delete failed: %v", err)
|
||||||
|
return fiber.NewError(fiber.StatusInternalServerError, "failed to delete user")
|
||||||
|
}
|
||||||
|
|
||||||
|
h.log.WithFields(logrus.Fields{
|
||||||
|
"action": req.Action,
|
||||||
|
"public_id": req.PublicID,
|
||||||
|
"alias": alias,
|
||||||
|
"user_id": req.User.ID,
|
||||||
|
}).Info("sso user deleted")
|
||||||
|
|
||||||
|
return c.Status(fiber.StatusOK).JSON(response.Common{
|
||||||
|
Code: fiber.StatusOK,
|
||||||
|
Status: "success",
|
||||||
|
Message: "User deleted successfully",
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *UserSyncController) registerNonce(ctx context.Context, alias, nonce string) error {
|
||||||
|
ttl := h.nonceTTL
|
||||||
|
if ttl <= 0 {
|
||||||
|
ttl = defaultNonceTTL
|
||||||
|
}
|
||||||
|
|
||||||
|
key := fmt.Sprintf("sso:sync:%s:%s", alias, nonce)
|
||||||
|
if h.redis != nil {
|
||||||
|
stored, err := h.redis.SetNX(ctx, key, "1", ttl).Result()
|
||||||
|
if err == nil {
|
||||||
|
if !stored {
|
||||||
|
return fiber.NewError(fiber.StatusUnauthorized, "nonce already used")
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
h.log.Errorf("store sync nonce failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
now := time.Now().UTC()
|
||||||
|
if expRaw, ok := h.localNonces.Load(key); ok {
|
||||||
|
if expTime, ok := expRaw.(time.Time); ok && expTime.After(now) {
|
||||||
|
return fiber.NewError(fiber.StatusUnauthorized, "nonce already used")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
h.localNonces.Store(key, now.Add(ttl))
|
||||||
|
h.pruneLocalNonces(now)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *UserSyncController) calculateSignature(secret, alias, timestamp, nonce string, body []byte) []byte {
|
||||||
|
mac := hmac.New(sha256.New, []byte(secret))
|
||||||
|
mac.Write([]byte(alias))
|
||||||
|
mac.Write([]byte("\n"))
|
||||||
|
mac.Write([]byte(timestamp))
|
||||||
|
mac.Write([]byte("\n"))
|
||||||
|
mac.Write([]byte(nonce))
|
||||||
|
mac.Write([]byte("\n"))
|
||||||
|
mac.Write(body)
|
||||||
|
return mac.Sum(nil)
|
||||||
|
}
|
||||||
|
|
||||||
|
func containsScope(scopes []string, target string) bool {
|
||||||
|
target = strings.ToLower(strings.TrimSpace(target))
|
||||||
|
if target == "" {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
for _, scope := range scopes {
|
||||||
|
if strings.ToLower(strings.TrimSpace(scope)) == target {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
func decodeSignature(sig string) ([]byte, error) {
|
||||||
|
sig = strings.TrimSpace(sig)
|
||||||
|
if sig == "" {
|
||||||
|
return nil, errors.New("empty signature")
|
||||||
|
}
|
||||||
|
if decoded, err := hex.DecodeString(sig); err == nil {
|
||||||
|
return decoded, nil
|
||||||
|
}
|
||||||
|
if decoded, err := base64.StdEncoding.DecodeString(sig); err == nil {
|
||||||
|
return decoded, nil
|
||||||
|
}
|
||||||
|
if decoded, err := base64.URLEncoding.DecodeString(sig); err == nil {
|
||||||
|
return decoded, nil
|
||||||
|
}
|
||||||
|
return nil, errors.New("unrecognized signature encoding")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *UserSyncController) pruneLocalNonces(now time.Time) {
|
||||||
|
h.localNonces.Range(func(key, value any) bool {
|
||||||
|
exp, ok := value.(time.Time)
|
||||||
|
if !ok || exp.Before(now) {
|
||||||
|
h.localNonces.Delete(key)
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
})
|
||||||
|
}
|
||||||
@@ -0,0 +1,13 @@
|
|||||||
|
package sso
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/go-playground/validator/v10"
|
||||||
|
"github.com/gofiber/fiber/v2"
|
||||||
|
"gorm.io/gorm"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Module struct{}
|
||||||
|
|
||||||
|
func (Module) RegisterRoutes(router fiber.Router, db *gorm.DB, validate *validator.Validate) {
|
||||||
|
Routes(router, db, validate)
|
||||||
|
}
|
||||||
@@ -0,0 +1,35 @@
|
|||||||
|
package sso
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/go-playground/validator/v10"
|
||||||
|
"github.com/gofiber/fiber/v2"
|
||||||
|
"gorm.io/gorm"
|
||||||
|
|
||||||
|
"gitlab.com/mbugroup/lti-api.git/internal/cache"
|
||||||
|
"gitlab.com/mbugroup/lti-api.git/internal/config"
|
||||||
|
"gitlab.com/mbugroup/lti-api.git/internal/middleware"
|
||||||
|
ssoController "gitlab.com/mbugroup/lti-api.git/internal/modules/sso/controllers"
|
||||||
|
"gitlab.com/mbugroup/lti-api.git/internal/modules/sso/session"
|
||||||
|
userRepository "gitlab.com/mbugroup/lti-api.git/internal/modules/users/repositories"
|
||||||
|
)
|
||||||
|
|
||||||
|
func Routes(router fiber.Router, db *gorm.DB, validate *validator.Validate) {
|
||||||
|
ttl := config.SSOPKCETTL
|
||||||
|
if ttl <= 0 {
|
||||||
|
ttl = 5 * time.Minute
|
||||||
|
}
|
||||||
|
|
||||||
|
store := session.NewStore(cache.MustRedis(), ttl)
|
||||||
|
ctrl := ssoController.NewController(&http.Client{Timeout: 10 * time.Second}, store)
|
||||||
|
userRepo := userRepository.NewUserRepository(db)
|
||||||
|
syncCtrl := ssoController.NewUserSyncController(validate, userRepo, cache.Redis(), config.SSOClients)
|
||||||
|
|
||||||
|
group := router.Group("/sso")
|
||||||
|
group.Get("/start", middleware.NewLimiter(30, time.Minute), ctrl.Start)
|
||||||
|
group.Get("/callback", ctrl.Callback)
|
||||||
|
group.Get("/userinfo", middleware.NewLimiter(60, time.Minute), ctrl.UserInfo)
|
||||||
|
group.Post("/users/sync", middleware.NewLimiter(30, time.Minute), syncCtrl.Sync)
|
||||||
|
}
|
||||||
@@ -0,0 +1,70 @@
|
|||||||
|
package session
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/redis/go-redis/v9"
|
||||||
|
)
|
||||||
|
|
||||||
|
const keyTemplate = "sso:pkce:%s"
|
||||||
|
|
||||||
|
// PKCESession holds data required to complete the OAuth2 PKCE exchange.
|
||||||
|
type PKCESession struct {
|
||||||
|
CodeVerifier string `json:"code_verifier"`
|
||||||
|
Nonce string `json:"nonce"`
|
||||||
|
ClientAlias string `json:"client_alias"`
|
||||||
|
ClientID string `json:"client_id"`
|
||||||
|
RedirectURI string `json:"redirect_uri"`
|
||||||
|
Scope string `json:"scope"`
|
||||||
|
ReturnTo string `json:"return_to,omitempty"`
|
||||||
|
CreatedAt time.Time `json:"created_at"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// Store persists pkce sessions inside Redis using a configurable TTL.
|
||||||
|
type Store struct {
|
||||||
|
redis *redis.Client
|
||||||
|
ttl time.Duration
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewStore(client *redis.Client, ttl time.Duration) *Store {
|
||||||
|
return &Store{redis: client, ttl: ttl}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Store) Save(ctx context.Context, state string, payload *PKCESession) error {
|
||||||
|
if s.redis == nil {
|
||||||
|
return fmt.Errorf("redis client is not initialised")
|
||||||
|
}
|
||||||
|
bytes, err := json.Marshal(payload)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return s.redis.Set(ctx, fmt.Sprintf(keyTemplate, state), bytes, s.ttl).Err()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Store) Get(ctx context.Context, state string) (*PKCESession, error) {
|
||||||
|
if s.redis == nil {
|
||||||
|
return nil, fmt.Errorf("redis client is not initialised")
|
||||||
|
}
|
||||||
|
raw, err := s.redis.Get(ctx, fmt.Sprintf(keyTemplate, state)).Result()
|
||||||
|
if err != nil {
|
||||||
|
if err == redis.Nil {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
var payload PKCESession
|
||||||
|
if err := json.Unmarshal([]byte(raw), &payload); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return &payload, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Store) Delete(ctx context.Context, state string) error {
|
||||||
|
if s.redis == nil {
|
||||||
|
return fmt.Errorf("redis client is not initialised")
|
||||||
|
}
|
||||||
|
return s.redis.Del(ctx, fmt.Sprintf(keyTemplate, state)).Err()
|
||||||
|
}
|
||||||
@@ -71,70 +71,70 @@ func (u *UserController) GetOne(c *fiber.Ctx) error {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func (u *UserController) CreateOne(c *fiber.Ctx) error {
|
// func (u *UserController) CreateOne(c *fiber.Ctx) error {
|
||||||
req := new(validation.Create)
|
// req := new(validation.Create)
|
||||||
|
|
||||||
if err := c.BodyParser(req); err != nil {
|
// if err := c.BodyParser(req); err != nil {
|
||||||
return fiber.NewError(fiber.StatusBadRequest, "Invalid request body")
|
// return fiber.NewError(fiber.StatusBadRequest, "Invalid request body")
|
||||||
}
|
// }
|
||||||
|
|
||||||
result, err := u.UserService.CreateOne(c, req)
|
// result, err := u.UserService.CreateOne(c, req)
|
||||||
if err != nil {
|
// if err != nil {
|
||||||
return err
|
// return err
|
||||||
}
|
// }
|
||||||
|
|
||||||
return c.Status(fiber.StatusCreated).
|
// return c.Status(fiber.StatusCreated).
|
||||||
JSON(response.Success{
|
// JSON(response.Success{
|
||||||
Code: fiber.StatusCreated,
|
// Code: fiber.StatusCreated,
|
||||||
Status: "success",
|
// Status: "success",
|
||||||
Message: "Create user successfully",
|
// Message: "Create user successfully",
|
||||||
Data: dto.ToUserListDTO(*result),
|
// Data: dto.ToUserListDTO(*result),
|
||||||
})
|
// })
|
||||||
}
|
// }
|
||||||
|
|
||||||
func (u *UserController) UpdateOne(c *fiber.Ctx) error {
|
// func (u *UserController) UpdateOne(c *fiber.Ctx) error {
|
||||||
req := new(validation.Update)
|
// req := new(validation.Update)
|
||||||
param := c.Params("id")
|
// param := c.Params("id")
|
||||||
|
|
||||||
id, err := strconv.Atoi(param)
|
// id, err := strconv.Atoi(param)
|
||||||
if err != nil {
|
// if err != nil {
|
||||||
return fiber.NewError(fiber.StatusBadRequest, "Invalid Id")
|
// return fiber.NewError(fiber.StatusBadRequest, "Invalid Id")
|
||||||
}
|
// }
|
||||||
|
|
||||||
if err := c.BodyParser(req); err != nil {
|
// if err := c.BodyParser(req); err != nil {
|
||||||
return fiber.NewError(fiber.StatusBadRequest, "Invalid request body")
|
// return fiber.NewError(fiber.StatusBadRequest, "Invalid request body")
|
||||||
}
|
// }
|
||||||
|
|
||||||
result, err := u.UserService.UpdateOne(c, req, uint(id))
|
// result, err := u.UserService.UpdateOne(c, req, uint(id))
|
||||||
if err != nil {
|
// if err != nil {
|
||||||
return err
|
// return err
|
||||||
}
|
// }
|
||||||
|
|
||||||
return c.Status(fiber.StatusOK).
|
// return c.Status(fiber.StatusOK).
|
||||||
JSON(response.Success{
|
// JSON(response.Success{
|
||||||
Code: fiber.StatusOK,
|
// Code: fiber.StatusOK,
|
||||||
Status: "success",
|
// Status: "success",
|
||||||
Message: "Update user successfully",
|
// Message: "Update user successfully",
|
||||||
Data: dto.ToUserListDTO(*result),
|
// Data: dto.ToUserListDTO(*result),
|
||||||
})
|
// })
|
||||||
}
|
// }
|
||||||
|
|
||||||
func (u *UserController) DeleteOne(c *fiber.Ctx) error {
|
// func (u *UserController) DeleteOne(c *fiber.Ctx) error {
|
||||||
param := c.Params("id")
|
// param := c.Params("id")
|
||||||
|
|
||||||
id, err := strconv.Atoi(param)
|
// id, err := strconv.Atoi(param)
|
||||||
if err != nil {
|
// if err != nil {
|
||||||
return fiber.NewError(fiber.StatusBadRequest, "Invalid Id")
|
// return fiber.NewError(fiber.StatusBadRequest, "Invalid Id")
|
||||||
}
|
// }
|
||||||
|
|
||||||
if err := u.UserService.DeleteOne(c, uint(id)); err != nil {
|
// if err := u.UserService.DeleteOne(c, uint(id)); err != nil {
|
||||||
return err
|
// return err
|
||||||
}
|
// }
|
||||||
|
|
||||||
return c.Status(fiber.StatusOK).
|
// return c.Status(fiber.StatusOK).
|
||||||
JSON(response.Common{
|
// JSON(response.Common{
|
||||||
Code: fiber.StatusOK,
|
// Code: fiber.StatusOK,
|
||||||
Status: "success",
|
// Status: "success",
|
||||||
Message: "Delete user successfully",
|
// Message: "Delete user successfully",
|
||||||
})
|
// })
|
||||||
}
|
// }
|
||||||
|
|||||||
@@ -1,21 +1,64 @@
|
|||||||
package repository
|
package repository
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"gitlab.com/mbugroup/lti-api.git/internal/common/repository"
|
"context"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
commonrepo "gitlab.com/mbugroup/lti-api.git/internal/common/repository"
|
||||||
entity "gitlab.com/mbugroup/lti-api.git/internal/entities"
|
entity "gitlab.com/mbugroup/lti-api.git/internal/entities"
|
||||||
"gorm.io/gorm"
|
"gorm.io/gorm"
|
||||||
|
"gorm.io/gorm/clause"
|
||||||
)
|
)
|
||||||
|
|
||||||
type UserRepository interface {
|
type UserRepository interface {
|
||||||
repository.BaseRepository[entity.User]
|
commonrepo.BaseRepository[entity.User]
|
||||||
|
GetByIdUser(ctx context.Context, idUser int64, modifier func(*gorm.DB) *gorm.DB) (*entity.User, error)
|
||||||
|
UpsertByIdUser(ctx context.Context, user *entity.User) error
|
||||||
|
SoftDeleteByIdUser(ctx context.Context, idUser int64) error
|
||||||
}
|
}
|
||||||
|
|
||||||
type UserRepositoryImpl struct {
|
type UserRepositoryImpl struct {
|
||||||
*repository.BaseRepositoryImpl[entity.User]
|
*commonrepo.BaseRepositoryImpl[entity.User]
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewUserRepository(db *gorm.DB) UserRepository {
|
func NewUserRepository(db *gorm.DB) UserRepository {
|
||||||
return &UserRepositoryImpl{
|
return &UserRepositoryImpl{
|
||||||
BaseRepositoryImpl: repository.NewBaseRepository[entity.User](db),
|
BaseRepositoryImpl: commonrepo.NewBaseRepository[entity.User](db),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (r *UserRepositoryImpl) GetByIdUser(
|
||||||
|
ctx context.Context,
|
||||||
|
idUser int64,
|
||||||
|
modifier func(*gorm.DB) *gorm.DB,
|
||||||
|
) (*entity.User, error) {
|
||||||
|
return r.BaseRepositoryImpl.First(ctx, func(db *gorm.DB) *gorm.DB {
|
||||||
|
return db.Where("id_user = ?", idUser)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *UserRepositoryImpl) UpsertByIdUser(ctx context.Context, user *entity.User) error {
|
||||||
|
if user == nil {
|
||||||
|
return gorm.ErrInvalidData
|
||||||
|
}
|
||||||
|
|
||||||
|
conflict := []clause.Column{{Name: "id_user"}}
|
||||||
|
user.DeletedAt = gorm.DeletedAt{}
|
||||||
|
user.UpdatedAt = time.Now()
|
||||||
|
|
||||||
|
return r.BaseRepositoryImpl.Upsert(ctx, user, conflict, func(db *gorm.DB) *gorm.DB {
|
||||||
|
return db.Omit("id", "created_at")
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *UserRepositoryImpl) SoftDeleteByIdUser(ctx context.Context, idUser int64) error {
|
||||||
|
query := r.DB().WithContext(ctx).Where("id_user = ?", idUser)
|
||||||
|
result := query.Delete(&entity.User{})
|
||||||
|
if result.Error != nil {
|
||||||
|
return result.Error
|
||||||
|
}
|
||||||
|
if result.RowsAffected == 0 {
|
||||||
|
return gorm.ErrRecordNotFound
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|||||||
@@ -1,20 +1,22 @@
|
|||||||
package users
|
package users
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"github.com/gofiber/fiber/v2"
|
||||||
|
|
||||||
|
"gitlab.com/mbugroup/lti-api.git/internal/middleware"
|
||||||
controller "gitlab.com/mbugroup/lti-api.git/internal/modules/users/controllers"
|
controller "gitlab.com/mbugroup/lti-api.git/internal/modules/users/controllers"
|
||||||
user "gitlab.com/mbugroup/lti-api.git/internal/modules/users/services"
|
user "gitlab.com/mbugroup/lti-api.git/internal/modules/users/services"
|
||||||
|
|
||||||
"github.com/gofiber/fiber/v2"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
func UserRoutes(v1 fiber.Router, s user.UserService) {
|
func UserRoutes(v1 fiber.Router, s user.UserService) {
|
||||||
ctrl := controller.NewUserController(s)
|
ctrl := controller.NewUserController(s)
|
||||||
|
|
||||||
route := v1.Group("/users")
|
route := v1.Group("/users")
|
||||||
|
route.Use(middleware.Auth(s))
|
||||||
|
|
||||||
route.Get("/", ctrl.GetAll)
|
route.Get("/", ctrl.GetAll)
|
||||||
route.Post("/", ctrl.CreateOne)
|
// route.Post("/", ctrl.CreateOne)
|
||||||
route.Get("/:id", ctrl.GetOne)
|
route.Get("/:id", ctrl.GetOne)
|
||||||
route.Patch("/:id", ctrl.UpdateOne)
|
// route.Patch("/:id", ctrl.UpdateOne)
|
||||||
route.Delete("/:id", ctrl.DeleteOne)
|
// route.Delete("/:id", ctrl.DeleteOne)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -20,6 +20,7 @@ type UserService interface {
|
|||||||
CreateOne(ctx *fiber.Ctx, req *validation.Create) (*entity.User, error)
|
CreateOne(ctx *fiber.Ctx, req *validation.Create) (*entity.User, error)
|
||||||
UpdateOne(ctx *fiber.Ctx, req *validation.Update, id uint) (*entity.User, error)
|
UpdateOne(ctx *fiber.Ctx, req *validation.Update, id uint) (*entity.User, error)
|
||||||
DeleteOne(ctx *fiber.Ctx, id uint) error
|
DeleteOne(ctx *fiber.Ctx, id uint) error
|
||||||
|
GetBySSOUserID(ctx *fiber.Ctx, ssoUserID uint) (*entity.User, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
type userService struct {
|
type userService struct {
|
||||||
@@ -68,6 +69,18 @@ func (s userService) GetOne(c *fiber.Ctx, id uint) (*entity.User, error) {
|
|||||||
return user, nil
|
return user, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s userService) GetBySSOUserID(c *fiber.Ctx, ssoUserID uint) (*entity.User, error) {
|
||||||
|
user, err := s.Repository.GetByIdUser(c.Context(), int64(ssoUserID), nil)
|
||||||
|
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||||
|
return nil, fiber.NewError(fiber.StatusNotFound, "User not found")
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
s.Log.Errorf("Failed get user by SSO id: %+v", err)
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return user, nil
|
||||||
|
}
|
||||||
|
|
||||||
func (s *userService) CreateOne(c *fiber.Ctx, req *validation.Create) (*entity.User, error) {
|
func (s *userService) CreateOne(c *fiber.Ctx, req *validation.Create) (*entity.User, error) {
|
||||||
if err := s.Validate.Struct(req); err != nil {
|
if err := s.Validate.Struct(req); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
|||||||
@@ -8,9 +8,10 @@ import (
|
|||||||
"github.com/gofiber/fiber/v2"
|
"github.com/gofiber/fiber/v2"
|
||||||
"gorm.io/gorm"
|
"gorm.io/gorm"
|
||||||
|
|
||||||
|
constants "gitlab.com/mbugroup/lti-api.git/internal/modules/constants"
|
||||||
master "gitlab.com/mbugroup/lti-api.git/internal/modules/master"
|
master "gitlab.com/mbugroup/lti-api.git/internal/modules/master"
|
||||||
|
ssoModule "gitlab.com/mbugroup/lti-api.git/internal/modules/sso"
|
||||||
users "gitlab.com/mbugroup/lti-api.git/internal/modules/users"
|
users "gitlab.com/mbugroup/lti-api.git/internal/modules/users"
|
||||||
constants "gitlab.com/mbugroup/lti-api.git/internal/modules/constants"
|
|
||||||
// MODULE IMPORTS
|
// MODULE IMPORTS
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -23,7 +24,8 @@ func Routes(app *fiber.App, db *gorm.DB) {
|
|||||||
allModules := []modules.Module{
|
allModules := []modules.Module{
|
||||||
users.UserModule{},
|
users.UserModule{},
|
||||||
master.MasterModule{},
|
master.MasterModule{},
|
||||||
constants.ConstantModule{},
|
constants.ConstantModule{},
|
||||||
|
ssoModule.Module{},
|
||||||
// MODULE REGISTRY
|
// MODULE REGISTRY
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -0,0 +1,161 @@
|
|||||||
|
package sso
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/MicahParks/keyfunc/v2"
|
||||||
|
"github.com/golang-jwt/jwt/v5"
|
||||||
|
|
||||||
|
"gitlab.com/mbugroup/lti-api.git/internal/utils"
|
||||||
|
)
|
||||||
|
|
||||||
|
type verifier struct {
|
||||||
|
jwks *keyfunc.JWKS
|
||||||
|
issuer string
|
||||||
|
audiences map[string]struct{}
|
||||||
|
}
|
||||||
|
|
||||||
|
type AccessTokenClaims struct {
|
||||||
|
Scope string `json:"scope"`
|
||||||
|
jwt.RegisteredClaims
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c AccessTokenClaims) Scopes() []string {
|
||||||
|
if c.Scope == "" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return strings.Fields(c.Scope)
|
||||||
|
}
|
||||||
|
|
||||||
|
type VerificationResult struct {
|
||||||
|
UserID uint
|
||||||
|
ServiceAlias string
|
||||||
|
Subject string
|
||||||
|
Claims *AccessTokenClaims
|
||||||
|
}
|
||||||
|
|
||||||
|
var (
|
||||||
|
globalMu sync.RWMutex
|
||||||
|
globalV *verifier
|
||||||
|
)
|
||||||
|
|
||||||
|
func Init(ctx context.Context, jwksURL, issuer string, audiences []string) error {
|
||||||
|
jwksURL = strings.TrimSpace(jwksURL)
|
||||||
|
issuer = strings.TrimSpace(issuer)
|
||||||
|
if jwksURL == "" || issuer == "" {
|
||||||
|
return errors.New("missing SSO JWKS or issuer configuration")
|
||||||
|
}
|
||||||
|
|
||||||
|
client := &http.Client{Timeout: 5 * time.Second}
|
||||||
|
options := keyfunc.Options{
|
||||||
|
Ctx: ctx,
|
||||||
|
Client: client,
|
||||||
|
RefreshTimeout: 10 * time.Second,
|
||||||
|
RefreshInterval: time.Hour,
|
||||||
|
RefreshUnknownKID: true,
|
||||||
|
RefreshErrorHandler: func(err error) {
|
||||||
|
utils.Log.Errorf("sso jwks refresh failed: %v", err)
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
jwks, err := keyfunc.Get(jwksURL, options)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("load jwks: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
audienceMap := make(map[string]struct{}, len(audiences))
|
||||||
|
for _, aud := range audiences {
|
||||||
|
aud = strings.TrimSpace(aud)
|
||||||
|
if aud == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
audienceMap[aud] = struct{}{}
|
||||||
|
}
|
||||||
|
|
||||||
|
globalMu.Lock()
|
||||||
|
globalV = &verifier{jwks: jwks, issuer: issuer, audiences: audienceMap}
|
||||||
|
globalMu.Unlock()
|
||||||
|
|
||||||
|
utils.Log.Infof("sso verifier initialized for issuer %s (%d keys)", issuer, len(jwks.KIDs()))
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func VerifyAccessToken(token string) (*VerificationResult, error) {
|
||||||
|
token = strings.TrimSpace(token)
|
||||||
|
if token == "" {
|
||||||
|
return nil, errors.New("empty token")
|
||||||
|
}
|
||||||
|
|
||||||
|
globalMu.RLock()
|
||||||
|
v := globalV
|
||||||
|
globalMu.RUnlock()
|
||||||
|
if v == nil {
|
||||||
|
return nil, errors.New("sso verifier not initialized")
|
||||||
|
}
|
||||||
|
|
||||||
|
claims := &AccessTokenClaims{}
|
||||||
|
parser := jwt.NewParser(
|
||||||
|
jwt.WithValidMethods([]string{jwt.SigningMethodRS256.Alg()}),
|
||||||
|
jwt.WithIssuedAt(),
|
||||||
|
jwt.WithExpirationRequired(),
|
||||||
|
)
|
||||||
|
|
||||||
|
tok, err := parser.ParseWithClaims(token, claims, v.jwks.Keyfunc)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("parse token: %w", err)
|
||||||
|
}
|
||||||
|
if !tok.Valid {
|
||||||
|
return nil, errors.New("invalid token")
|
||||||
|
}
|
||||||
|
|
||||||
|
if claims.Issuer != v.issuer {
|
||||||
|
return nil, errors.New("unexpected token issuer")
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(v.audiences) > 0 {
|
||||||
|
validAud := false
|
||||||
|
for _, aud := range claims.Audience {
|
||||||
|
if _, ok := v.audiences[aud]; ok {
|
||||||
|
validAud = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !validAud {
|
||||||
|
return nil, errors.New("unexpected token audience")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
sub := strings.TrimSpace(claims.Subject)
|
||||||
|
if sub == "" {
|
||||||
|
return nil, errors.New("missing subject")
|
||||||
|
}
|
||||||
|
|
||||||
|
result := &VerificationResult{Claims: claims, Subject: sub}
|
||||||
|
|
||||||
|
switch {
|
||||||
|
case strings.HasPrefix(sub, "user:"):
|
||||||
|
idStr := strings.TrimPrefix(sub, "user:")
|
||||||
|
id, err := strconv.ParseUint(idStr, 10, 64)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("invalid subject: %w", err)
|
||||||
|
}
|
||||||
|
result.UserID = uint(id)
|
||||||
|
case strings.HasPrefix(sub, "service:"):
|
||||||
|
alias := strings.TrimSpace(strings.TrimPrefix(sub, "service:"))
|
||||||
|
if alias == "" {
|
||||||
|
return nil, errors.New("invalid service subject")
|
||||||
|
}
|
||||||
|
result.ServiceAlias = strings.ToLower(alias)
|
||||||
|
default:
|
||||||
|
return nil, errors.New("unsupported subject type")
|
||||||
|
}
|
||||||
|
|
||||||
|
return result, nil
|
||||||
|
}
|
||||||
@@ -0,0 +1,84 @@
|
|||||||
|
package secure
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/rand"
|
||||||
|
"encoding/base64"
|
||||||
|
"fmt"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
const pkceCharset = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-._~"
|
||||||
|
|
||||||
|
// RandomBytes returns securely generated random bytes of given length.
|
||||||
|
func RandomBytes(length int) ([]byte, error) {
|
||||||
|
if length <= 0 {
|
||||||
|
return nil, fmt.Errorf("length must be positive")
|
||||||
|
}
|
||||||
|
b := make([]byte, length)
|
||||||
|
if _, err := rand.Read(b); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return b, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// RandomString returns a base64url encoded random string of approximately the requested length.
|
||||||
|
func RandomString(length int) (string, error) {
|
||||||
|
if length <= 0 {
|
||||||
|
return "", fmt.Errorf("length must be positive")
|
||||||
|
}
|
||||||
|
// Generate ceil(length * 6/8) bytes to have enough entropy for base64 url encoding.
|
||||||
|
byteLen := (length*6 + 7) / 8
|
||||||
|
bytes, err := RandomBytes(byteLen)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
s := base64.RawURLEncoding.EncodeToString(bytes)
|
||||||
|
if len(s) > length {
|
||||||
|
return s[:length], nil
|
||||||
|
}
|
||||||
|
// If encoded string shorter, pad using charset
|
||||||
|
if len(s) < length {
|
||||||
|
sb := strings.Builder{}
|
||||||
|
sb.WriteString(s)
|
||||||
|
extraNeeded := length - len(s)
|
||||||
|
more, err := randomFromCharset(extraNeeded)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
sb.WriteString(more)
|
||||||
|
return sb.String(), nil
|
||||||
|
}
|
||||||
|
return s, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// PKCECodeVerifier generates a random string compliant with RFC 7636.
|
||||||
|
func PKCECodeVerifier(length int) (string, error) {
|
||||||
|
if length < 43 {
|
||||||
|
length = 43
|
||||||
|
}
|
||||||
|
if length > 128 {
|
||||||
|
length = 128
|
||||||
|
}
|
||||||
|
return randomFromCharset(length)
|
||||||
|
}
|
||||||
|
|
||||||
|
// randomFromCharset returns a random string of given length using pkceCharset.
|
||||||
|
func randomFromCharset(length int) (string, error) {
|
||||||
|
if length <= 0 {
|
||||||
|
return "", fmt.Errorf("length must be positive")
|
||||||
|
}
|
||||||
|
bytes, err := RandomBytes(length)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
out := make([]byte, length)
|
||||||
|
for i, b := range bytes {
|
||||||
|
out[i] = pkceCharset[int(b)%len(pkceCharset)]
|
||||||
|
}
|
||||||
|
return string(out), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Base64URLEncode encodes the input bytes using base64 URL encoding without padding.
|
||||||
|
func Base64URLEncode(data []byte) string {
|
||||||
|
return base64.RawURLEncoding.EncodeToString(data)
|
||||||
|
}
|
||||||
@@ -1,45 +0,0 @@
|
|||||||
package utils
|
|
||||||
|
|
||||||
import (
|
|
||||||
"errors"
|
|
||||||
"strconv"
|
|
||||||
|
|
||||||
"github.com/golang-jwt/jwt/v5"
|
|
||||||
)
|
|
||||||
|
|
||||||
func VerifyToken(tokenStr, secret, tokenType string) (uint, error) {
|
|
||||||
token, err := jwt.Parse(tokenStr, func(_ *jwt.Token) (interface{}, error) {
|
|
||||||
return []byte(secret), nil
|
|
||||||
})
|
|
||||||
if err != nil || !token.Valid {
|
|
||||||
return 0, err
|
|
||||||
}
|
|
||||||
|
|
||||||
claims, ok := token.Claims.(jwt.MapClaims)
|
|
||||||
if !ok {
|
|
||||||
return 0, errors.New("invalid token claims")
|
|
||||||
}
|
|
||||||
|
|
||||||
jwtType, ok := claims["type"].(string)
|
|
||||||
if !ok || jwtType != tokenType {
|
|
||||||
return 0, errors.New("invalid token type")
|
|
||||||
}
|
|
||||||
|
|
||||||
sub, ok := claims["sub"]
|
|
||||||
if !ok {
|
|
||||||
return 0, errors.New("invalid token sub")
|
|
||||||
}
|
|
||||||
|
|
||||||
switch v := sub.(type) {
|
|
||||||
case float64:
|
|
||||||
return uint(v), nil
|
|
||||||
case string:
|
|
||||||
id, err := strconv.Atoi(v)
|
|
||||||
if err != nil {
|
|
||||||
return 0, errors.New("invalid sub format")
|
|
||||||
}
|
|
||||||
return uint(id), nil
|
|
||||||
default:
|
|
||||||
return 0, errors.New("unsupported sub type")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
Reference in New Issue
Block a user