Merge branch 'development' of https://gitlab.com/mbugroup/lti-api into dev/teguh

This commit is contained in:
aguhh18
2025-11-17 13:52:13 +07:00
87 changed files with 3284 additions and 582 deletions
+38
View File
@@ -0,0 +1,38 @@
package cache
import (
"errors"
"sync"
"github.com/redis/go-redis/v9"
)
var (
redisClient *redis.Client
mu sync.RWMutex
)
// SetRedis assigns the global redis client used across the application.
func SetRedis(client *redis.Client) {
mu.Lock()
defer mu.Unlock()
redisClient = client
}
// Redis returns the configured redis client. It may be nil if not yet initialised.
func Redis() *redis.Client {
mu.RLock()
defer mu.RUnlock()
return redisClient
}
// MustRedis returns the redis client or panics if it has not been set.
func MustRedis() *redis.Client {
mu.RLock()
client := redisClient
mu.RUnlock()
if client == nil {
panic(errors.New("redis client not initialised"))
}
return client
}
+44
View File
@@ -0,0 +1,44 @@
package capabilities
import (
"strings"
recordings "gitlab.com/mbugroup/lti-api.git/internal/modules/production/recordings"
)
// FromPermissions returns a filtered map of capabilities that the frontend can use
// to toggle features. Only permissions recognized by the application are exposed.
func FromPermissions(perms []string) map[string]bool {
if len(perms) == 0 {
return nil
}
out := make(map[string]bool)
for _, perm := range perms {
if key, ok := normalizeAndAllow(perm); ok {
out[key] = true
}
}
if len(out) == 0 {
return nil
}
return out
}
func normalizeAndAllow(perm string) (string, bool) {
perm = strings.ToLower(strings.TrimSpace(perm))
if perm == "" {
return "", false
}
if _, ok := allowed[perm]; !ok {
return "", false
}
return perm, true
}
var allowed = map[string]struct{}{
recordings.PermissionRecordingRead: {},
recordings.PermissionRecordingCreate: {},
recordings.PermissionRecordingUpdate: {},
recordings.PermissionRecordingDelete: {},
}
@@ -11,6 +11,7 @@ type BaseRepository[T any] interface {
GetAll(ctx context.Context, offset, limit int, modifier func(*gorm.DB) *gorm.DB) ([]T, int64, error)
GetByID(ctx context.Context, id uint, modifier func(*gorm.DB) *gorm.DB) (*T, error)
GetByIDs(ctx context.Context, ids []uint, modifier func(*gorm.DB) *gorm.DB) ([]T, error)
First(ctx context.Context, modifier func(*gorm.DB) *gorm.DB) (*T, error)
CreateOne(ctx context.Context, entity *T, modifier func(*gorm.DB) *gorm.DB) error
CreateMany(ctx context.Context, entities []*T, modifier func(*gorm.DB) *gorm.DB) error
@@ -96,6 +97,21 @@ func (r *BaseRepositoryImpl[T]) GetByIDs(
return entities, nil
}
func (r *BaseRepositoryImpl[T]) First(
ctx context.Context,
modifier func(*gorm.DB) *gorm.DB,
) (*T, error) {
entity := new(T)
q := r.db.WithContext(ctx)
if modifier != nil {
q = modifier(q)
}
if err := q.First(entity).Error; err != nil {
return nil, err
}
return entity, nil
}
// ---- CREATE ----
func (r *BaseRepositoryImpl[T]) CreateOne(
ctx context.Context,
+40 -8
View File
@@ -3,6 +3,8 @@ package validation
import (
"errors"
"fmt"
"reflect"
"strings"
"github.com/go-playground/validator/v10"
)
@@ -21,34 +23,41 @@ var customMessages = map[string]string{
"alphanum": "Field %s must contain only alphanumeric characters",
"oneof": "Invalid value for field %s",
"password": "Field %s must be at least 8 characters, contain uppercase, lowercase, number, and special character",
"gt": "Invalid %s, must be greater than %s",
}
func CustomErrorMessages(err error) map[string]string {
func CustomErrorMessages(err error) (string, map[string]string) {
var validationErrors validator.ValidationErrors
if errors.As(err, &validationErrors) {
return generateErrorMessages(validationErrors)
}
return nil
return "", nil
}
func generateErrorMessages(validationErrors validator.ValidationErrors) map[string]string {
func generateErrorMessages(validationErrors validator.ValidationErrors) (string, map[string]string) {
errorsMap := make(map[string]string)
for _, err := range validationErrors {
var firstMessage string
for i, err := range validationErrors {
fieldName := err.StructNamespace()
tag := err.Tag()
customMessage := customMessages[tag]
var msg string
if customMessage != "" {
errorsMap[fieldName] = formatErrorMessage(customMessage, err, tag)
msg = formatErrorMessage(customMessage, err, tag)
} else {
errorsMap[fieldName] = defaultErrorMessage(err)
msg = defaultErrorMessage(err)
}
errorsMap[fieldName] = msg
if i == 0 {
firstMessage = msg
}
}
return errorsMap
return firstMessage, errorsMap
}
func formatErrorMessage(customMessage string, err validator.FieldError, tag string) string {
if tag == "min" || tag == "max" || tag == "len" {
if tag == "min" || tag == "max" || tag == "len" || tag == "gt" {
return fmt.Sprintf(customMessage, err.Field(), err.Param())
}
return fmt.Sprintf(customMessage, err.Field())
@@ -61,6 +70,16 @@ func defaultErrorMessage(err validator.FieldError) string {
func Validator() *validator.Validate {
validate := validator.New()
validate.RegisterTagNameFunc(func(fld reflect.StructField) string {
if jsonTag := getTagName(fld, "json"); jsonTag != "" {
return jsonTag
}
if queryTag := getTagName(fld, "query"); queryTag != "" {
return queryTag
}
return fld.Name
})
if err := validate.RegisterValidation("password", Password); err != nil {
return nil
}
@@ -72,3 +91,16 @@ func Validator() *validator.Validate {
}
return validate
}
func getTagName(fld reflect.StructField, tag string) string {
value, ok := fld.Tag.Lookup(tag)
if !ok || value == "-" {
return ""
}
name := strings.Split(value, ",")[0]
if name == "" || name == "-" {
return ""
}
return name
}
BIN
View File
Binary file not shown.
+164 -22
View File
@@ -2,36 +2,69 @@ package config
import (
"encoding/json"
"fmt"
"strings"
"time"
"gitlab.com/mbugroup/lti-api.git/internal/utils"
"github.com/spf13/viper"
)
type SSOClientConfig struct {
PublicID string `json:"public_id"`
RedirectURI string `json:"redirect_uri"`
Scope string `json:"scope"`
// Prompt string `json:"prompt"`
DefaultReturnURI string `json:"default_return_uri"`
AllowedReturnOrigins []string `json:"allowed_return_origins"`
SyncSecret string `json:"sync_secret"`
}
var (
IsProd bool
AppHost string
Version string
LogLevel string
AppPort int
DBHost string
DBUser string
DBPassword string
DBName string
DBPort int
JWTSecret string
JWTAccessExp int
JWTRefreshExp int
JWTResetPasswordExp int
JWTVerifyEmailExp int
RedisURL string
CORSAllowOrigins []string
CORSAllowMethods []string
CORSAllowHeaders []string
CORSExposeHeaders []string
CORSAllowCredentials bool
CORSMaxAge int
IsProd bool
AppHost string
Version string
LogLevel string
AppPort int
DBHost string
DBUser string
DBPassword string
DBName string
DBPort int
DBSSLMode string
DBSSLRootCert string
DBSSLCert string
DBSSLKey string
JWTSecret string
JWTAccessExp int
JWTRefreshExp int
JWTResetPasswordExp int
JWTVerifyEmailExp int
RedisURL string
CORSAllowOrigins []string
CORSAllowMethods []string
CORSAllowHeaders []string
CORSExposeHeaders []string
CORSAllowCredentials bool
CORSMaxAge int
SSOIssuer string
SSOJWKSURL string
SSOAllowedAudiences []string
SSOAuthorizeURL string
SSOTokenURL string
SSOGetMeURL string
SSOClients map[string]SSOClientConfig
SSOAccessCookieName string
SSORefreshCookieName string
SSOCookieDomain string
SSOCookieSecure bool
SSOCookieSameSite string
SSOTokenBlacklistPrefix string
SSOPKCETTL time.Duration
SSOUserSyncDrift time.Duration
SSOUserSyncNonceTTL time.Duration
SSOUserSyncMaxBodyBytes int
)
func init() {
@@ -50,6 +83,10 @@ func init() {
DBPassword = viper.GetString("DB_PASSWORD")
DBName = viper.GetString("DB_NAME")
DBPort = viper.GetInt("DB_PORT")
DBSSLMode = defaultString(viper.GetString("DB_SSLMODE"), "disable")
DBSSLRootCert = strings.TrimSpace(viper.GetString("DB_SSLROOTCERT"))
DBSSLCert = strings.TrimSpace(viper.GetString("DB_SSLCERT"))
DBSSLKey = strings.TrimSpace(viper.GetString("DB_SSLKEY"))
// jwt configuration
JWTSecret = viper.GetString("JWT_SECRET")
@@ -68,6 +105,44 @@ func init() {
// Redis
RedisURL = viper.GetString("REDIS_URL")
// SSO integration
SSOIssuer = viper.GetString("SSO_ISSUER")
SSOJWKSURL = viper.GetString("SSO_JWKS_URL")
SSOAllowedAudiences = parseList("SSO_ALLOWED_AUDIENCES")
SSOAuthorizeURL = viper.GetString("SSO_AUTHORIZE_URL")
SSOTokenURL = viper.GetString("SSO_TOKEN_URL")
SSOGetMeURL = viper.GetString("SSO_GETME_URL")
SSOAccessCookieName = defaultString(viper.GetString("SSO_ACCESS_COOKIE_NAME"), "sso_access")
SSORefreshCookieName = defaultString(viper.GetString("SSO_REFRESH_COOKIE_NAME"), "sso_refresh")
SSOCookieDomain = viper.GetString("SSO_COOKIE_DOMAIN")
SSOCookieSecure = viper.GetBool("SSO_COOKIE_SECURE")
SSOCookieSameSite = defaultString(viper.GetString("SSO_COOKIE_SAMESITE"), "Lax")
SSOTokenBlacklistPrefix = defaultString(viper.GetString("SSO_TOKEN_BLACKLIST_PREFIX"), "sso:blacklist")
if ttl := viper.GetInt("SSO_PKCE_TTL_SECONDS"); ttl > 0 {
SSOPKCETTL = time.Duration(ttl) * time.Second
} else {
SSOPKCETTL = 5 * time.Minute
}
SSOClients = loadSSOClients("SSO_CLIENTS")
if drift := viper.GetInt("SSO_USER_SYNC_SIGNATURE_DRIFT_SECONDS"); drift > 0 {
SSOUserSyncDrift = time.Duration(drift) * time.Second
} else {
SSOUserSyncDrift = 2 * time.Minute
}
if ttl := viper.GetInt("SSO_USER_SYNC_NONCE_TTL_SECONDS"); ttl > 0 {
SSOUserSyncNonceTTL = time.Duration(ttl) * time.Second
} else {
SSOUserSyncNonceTTL = 10 * time.Minute
}
SSOUserSyncMaxBodyBytes = viper.GetInt("SSO_USER_SYNC_MAX_BODY_BYTES")
if SSOUserSyncMaxBodyBytes <= 0 {
SSOUserSyncMaxBodyBytes = 32 * 1024
}
if IsProd {
ensureProdConfig()
}
}
func loadConfig() {
@@ -117,3 +192,70 @@ func parseListWithDefault(key, def string) []string {
}
return parts
}
func loadSSOClients(key string) map[string]SSOClientConfig {
clients := make(map[string]SSOClientConfig)
raw := strings.TrimSpace(viper.GetString(key))
if raw == "" {
return clients
}
if err := json.Unmarshal([]byte(raw), &clients); err != nil {
utils.Log.Errorf("Failed to parse %s: %v", key, err)
return make(map[string]SSOClientConfig)
}
result := make(map[string]SSOClientConfig, len(clients))
for alias, cfg := range clients {
alias = strings.ToLower(strings.TrimSpace(alias))
for i, origin := range cfg.AllowedReturnOrigins {
cfg.AllowedReturnOrigins[i] = strings.TrimSpace(origin)
}
cfg.SyncSecret = strings.TrimSpace(cfg.SyncSecret)
result[alias] = cfg
}
return result
}
func defaultString(v, def string) string {
if strings.TrimSpace(v) == "" {
return def
}
return v
}
func ensureProdConfig() {
if SSOAuthorizeURL == "" || !strings.HasPrefix(SSOAuthorizeURL, "https://") {
panic("SSO_AUTHORIZE_URL must be https in production")
}
if SSOTokenURL == "" || !strings.HasPrefix(SSOTokenURL, "https://") {
panic("SSO_TOKEN_URL must be https in production")
}
if SSOGetMeURL == "" || !strings.HasPrefix(SSOGetMeURL, "https://") {
panic("SSO_GETME_URL must be https in production")
}
if !SSOCookieSecure {
panic("SSO_COOKIE_SECURE must be true in production")
}
if SSOCookieDomain == "" {
panic("SSO_COOKIE_DOMAIN must be configured in production")
}
if len(SSOAllowedAudiences) == 0 {
panic("SSO_ALLOWED_AUDIENCES must contain at least one audience in production")
}
for alias, cfg := range SSOClients {
if strings.TrimSpace(cfg.SyncSecret) == "" {
panic(fmt.Sprintf("SSO_CLIENTS[%s].sync_secret must be configured in production", alias))
}
if len(cfg.SyncSecret) < 16 {
panic(fmt.Sprintf("SSO_CLIENTS[%s].sync_secret must be at least 16 characters", alias))
}
}
if SSOUserSyncDrift <= 0 {
panic("SSO_USER_SYNC_SIGNATURE_DRIFT_SECONDS must be greater than zero in production")
}
if SSOUserSyncNonceTTL <= 0 {
panic("SSO_USER_SYNC_NONCE_TTL_SECONDS must be greater than zero in production")
}
if SSOUserSyncMaxBodyBytes <= 0 {
panic("SSO_USER_SYNC_MAX_BODY_BYTES must be greater than zero in production")
}
}
+20 -4
View File
@@ -2,6 +2,7 @@ package database
import (
"fmt"
"strings"
"time"
"gitlab.com/mbugroup/lti-api.git/internal/config"
@@ -13,10 +14,25 @@ import (
)
func Connect(dbHost, dbName string) *gorm.DB {
dsn := fmt.Sprintf(
"host=%s user=%s password=%s dbname=%s port=%d sslmode=disable TimeZone=Asia/Shanghai",
dbHost, config.DBUser, config.DBPassword, dbName, config.DBPort,
)
parts := []string{
fmt.Sprintf("host=%s", dbHost),
fmt.Sprintf("user=%s", config.DBUser),
fmt.Sprintf("password=%s", config.DBPassword),
fmt.Sprintf("dbname=%s", dbName),
fmt.Sprintf("port=%d", config.DBPort),
fmt.Sprintf("sslmode=%s", config.DBSSLMode),
"TimeZone=Asia/Shanghai",
}
if config.DBSSLRootCert != "" {
parts = append(parts, fmt.Sprintf("sslrootcert=%s", config.DBSSLRootCert))
}
if config.DBSSLCert != "" {
parts = append(parts, fmt.Sprintf("sslcert=%s", config.DBSSLCert))
}
if config.DBSSLKey != "" {
parts = append(parts, fmt.Sprintf("sslkey=%s", config.DBSSLKey))
}
dsn := strings.Join(parts, " ")
db, err := gorm.Open(postgres.Open(dsn), &gorm.Config{
Logger: logger.Default.LogMode(logger.Info),
@@ -9,13 +9,9 @@ CREATE TABLE users (
deleted_at TIMESTAMPTZ
);
CREATE UNIQUE INDEX users_id_user_unique ON users (id_user)
WHERE
deleted_at IS NULL;
CREATE UNIQUE INDEX users_id_user_unique ON users (id_user) WHERE deleted_at IS NULL;
CREATE UNIQUE INDEX users_email_unique ON users (email)
WHERE
deleted_at IS NULL;
CREATE UNIQUE INDEX users_email_unique ON users (email) WHERE deleted_at IS NULL;
-- FLAGS
CREATE TABLE flags (
@@ -334,4 +330,4 @@ CREATE INDEX stock_logs_created_by_idx ON stock_logs (created_by);
CREATE INDEX stock_logs_created_at_idx ON stock_logs (created_at);
CREATE INDEX stock_logs_deleted_at_idx ON stock_logs (deleted_at);
CREATE INDEX stock_logs_deleted_at_idx ON stock_logs (deleted_at);
@@ -0,0 +1,2 @@
ALTER TABLE users
DROP CONSTRAINT IF EXISTS users_id_user_key;
@@ -0,0 +1,2 @@
ALTER TABLE users
ADD CONSTRAINT users_id_user_key UNIQUE (id_user);
@@ -0,0 +1,2 @@
ALTER TABLE kandangs
DROP COLUMN IF EXISTS capacity;
@@ -0,0 +1,2 @@
ALTER TABLE kandangs
ADD COLUMN capacity NUMERIC(15,3) NOT NULL;
+5 -4
View File
@@ -235,13 +235,14 @@ func seedKandangs(tx *gorm.DB, createdBy uint, locations map[string]uint, users
seeds := []struct {
Name string
Status utils.KandangStatus
Capacity float64
Location string
PicKey string
}{
{Name: "Singaparna 1", Status: utils.KandangStatusNonActive, Location: "Singaparna", PicKey: "admin"},
{Name: "Singaparna 2", Status: utils.KandangStatusNonActive, Location: "Singaparna", PicKey: "admin"},
{Name: "Cikaum 1", Status: utils.KandangStatusNonActive, Location: "Cikaum", PicKey: "admin"},
{Name: "Cikaum 2", Status: utils.KandangStatusNonActive, Location: "Cikaum", PicKey: "admin"},
{Name: "Singaparna 1", Status: utils.KandangStatusNonActive, Capacity: 50000, Location: "Singaparna", PicKey: "admin"},
{Name: "Singaparna 2", Status: utils.KandangStatusNonActive, Capacity: 50000, Location: "Singaparna", PicKey: "admin"},
{Name: "Cikaum 1", Status: utils.KandangStatusNonActive, Capacity: 50000, Location: "Cikaum", PicKey: "admin"},
{Name: "Cikaum 2", Status: utils.KandangStatusNonActive, Capacity: 50000, Location: "Cikaum", PicKey: "admin"},
}
result := make(map[string]uint, len(seeds))
+13 -12
View File
@@ -7,17 +7,18 @@ import (
)
type Kandang struct {
Id uint `gorm:"primaryKey"`
Name string `gorm:"not null;uniqueIndex:kandangs_name_unique,where:deleted_at IS NULL"`
Status string `gorm:"type:varchar(50);not null"`
LocationId uint `gorm:"not null"`
PicId uint `gorm:"not null"`
CreatedBy uint `gorm:"not null"`
CreatedAt time.Time `gorm:"autoCreateTime"`
UpdatedAt time.Time `gorm:"autoUpdateTime"`
DeletedAt gorm.DeletedAt `gorm:"index" json:"-"`
CreatedUser User `gorm:"foreignKey:CreatedBy;references:Id"`
Location Location `gorm:"foreignKey:LocationId;references:Id"`
Pic User `gorm:"foreignKey:PicId;references:Id"`
Id uint `gorm:"primaryKey"`
Name string `gorm:"not null;uniqueIndex:kandangs_name_unique,where:deleted_at IS NULL"`
Status string `gorm:"type:varchar(50);not null"`
LocationId uint `gorm:"not null"`
Capacity float64 `gorm:"not null"`
PicId uint `gorm:"not null"`
CreatedBy uint `gorm:"not null"`
CreatedAt time.Time `gorm:"autoCreateTime"`
UpdatedAt time.Time `gorm:"autoUpdateTime"`
DeletedAt gorm.DeletedAt `gorm:"index" json:"-"`
CreatedUser User `gorm:"foreignKey:CreatedBy;references:Id"`
Location Location `gorm:"foreignKey:LocationId;references:Id"`
Pic User `gorm:"foreignKey:PicId;references:Id"`
ProjectFlockKandangs []ProjectFlockKandang `gorm:"foreignKey:KandangId;references:Id" json:"-"`
}
+177 -85
View File
@@ -1,101 +1,193 @@
package middleware
// import (
// "strings"
import (
"strings"
// "gitlab.com/mbugroup/lti-api.git/internal/config"
// service "gitlab.com/mbugroup/lti-api.git/internal/modules/users/services"
// "gitlab.com/mbugroup/lti-api.git/internal/utils"
"gitlab.com/mbugroup/lti-api.git/internal/config"
entity "gitlab.com/mbugroup/lti-api.git/internal/entities"
"gitlab.com/mbugroup/lti-api.git/internal/modules/sso/session"
service "gitlab.com/mbugroup/lti-api.git/internal/modules/users/services"
"gitlab.com/mbugroup/lti-api.git/internal/sso"
"gitlab.com/mbugroup/lti-api.git/internal/utils"
// "github.com/gofiber/fiber/v2"
// )
"github.com/gofiber/fiber/v2"
)
// func Auth(userService service.UserService, requiredRights ...string) fiber.Handler {
// return func(c *fiber.Ctx) error {
// authHeader := c.Get("Authorization")
// token := strings.TrimSpace(strings.TrimPrefix(authHeader, "Bearer "))
const (
authContextLocalsKey = "auth.context"
authUserLocalsKey = "auth.user"
)
// if token == "" {
// return fiber.NewError(fiber.StatusUnauthorized, "Please authenticate")
// }
// AuthContext keeps authentication details captured by the middleware.
type AuthContext struct {
Token string
Verification *sso.VerificationResult
User *entity.User
Roles []sso.Role
Permissions map[string]struct{}
}
// userID, err := utils.VerifyToken(token, config.JWTSecret, config.TokenTypeAccess)
// if err != nil {
// return fiber.NewError(fiber.StatusUnauthorized, "Please authenticate")
// }
// Auth validates the incoming request against the central SSO access token and
// loads the corresponding local user. Optional scopes can be provided to enforce
// fine-grained authorization using the SSO access token scopes.
func Auth(userService service.UserService, requiredScopes ...string) fiber.Handler {
return func(c *fiber.Ctx) error {
token := bearerToken(c)
if token == "" {
token = strings.TrimSpace(c.Cookies(config.SSOAccessCookieName))
}
if token == "" {
return fiber.NewError(fiber.StatusUnauthorized, "Please authenticate")
}
// // Only end-user subjects are allowed by this middleware. Service tokens
// if verification.UserID == 0 {
// return fiber.NewError(fiber.StatusUnauthorized, "Please authenticate")
// }
verification, err := sso.VerifyAccessToken(token)
if err != nil {
utils.Log.WithError(err).Warn("auth: token verification failed")
return fiber.NewError(fiber.StatusUnauthorized, "Please authenticate")
}
// // Fail-closed on revocation check errors for stricter security posture.
// if revoker := session.GetRevocationStore(); revoker != nil {
// if fingerprint := session.TokenFingerprint(token); fingerprint != "" {
// revoked, err := revoker.IsRevoked(c.Context(), fingerprint)
// if err != nil {
// utils.Log.WithError(err).Warn("failed to check token revocation")
// return fiber.NewError(fiber.StatusUnauthorized, "Please authenticate")
// }
// if revoked {
// return fiber.NewError(fiber.StatusUnauthorized, "Please authenticate")
// }
// }
// }
if verification.UserID == 0 {
return fiber.NewError(fiber.StatusForbidden, "Service authentication is not permitted for this endpoint")
}
// user, err := userService.GetBySSOUserID(c, verification.UserID)
// if err != nil || user == nil {
// return fiber.NewError(fiber.StatusUnauthorized, "Please authenticate")
// }
if err := ensureNotRevoked(c, token, verification); err != nil {
return err
}
// if len(requiredRights) > 0 && verification.Claims != nil {
// if !hasAllScopes(verification.Claims.Scopes(), requiredRights) {
// return fiber.NewError(fiber.StatusForbidden, "Insufficient scope")
// }
// }
user, err := userService.GetBySSOUserID(c, verification.UserID)
if err != nil || user == nil {
utils.Log.WithError(err).Warn("auth: failed to resolve user from repository")
return fiber.NewError(fiber.StatusUnauthorized, "Please authenticate")
}
// c.Locals("user", user)
if len(requiredScopes) > 0 {
if verification.Claims == nil || !hasAllScopes(verification.Claims.Scopes(), requiredScopes) {
return fiber.NewError(fiber.StatusForbidden, "Insufficient scope")
}
}
// // if len(requiredRights) > 0 {
// // userRights, hasRights := config.RoleRights[user.Role]
// // if (!hasRights || !hasAllRights(userRights, requiredRights)) && c.Params("userId") != userID {
// // return fiber.NewError(fiber.StatusForbidden, "You don't have permission to access this resource")
// // }
// // }
var roles []sso.Role
permissions := make(map[string]struct{})
if verification.UserID != 0 {
if profile, err := sso.FetchProfile(c.Context(), token, verification); err != nil {
utils.Log.WithError(err).Warn("auth: failed to fetch sso profile")
} else if profile != nil {
roles = profile.Roles
for _, perm := range profile.PermissionNames() {
if perm != "" {
permissions[perm] = struct{}{}
}
}
}
}
// return c.Next()
// }
// }
ctx := &AuthContext{
Token: token,
Verification: verification,
User: user,
Roles: roles,
Permissions: permissions,
}
// // bearerToken extracts a Bearer token from the Authorization header using
// // case-insensitive scheme matching and tolerant whitespace handling.
// func bearerToken(c *fiber.Ctx) string {
// parts := strings.Fields(c.Get("Authorization"))
// if len(parts) == 2 && strings.EqualFold(parts[0], "Bearer") {
// return strings.TrimSpace(parts[1])
// }
// return ""
// }
c.Locals(authContextLocalsKey, ctx)
c.Locals(authUserLocalsKey, user)
// func hasAllScopes(have, required []string) bool {
// if len(required) == 0 {
// return true
// }
// set := make(map[string]struct{}, len(have))
// for _, s := range have {
// s = strings.ToLower(strings.TrimSpace(s))
// if s != "" {
// set[s] = struct{}{}
// }
// }
// for _, r := range required {
// r = strings.ToLower(strings.TrimSpace(r))
// if r == "" {
// continue
// }
// if _, ok := set[r]; !ok {
// return false
// }
// }
// return true
// }
return c.Next()
}
}
// AuthenticatedUser returns the authenticated user populated by Auth.
func AuthenticatedUser(c *fiber.Ctx) (*entity.User, bool) {
value := c.Locals(authUserLocalsKey)
if user, ok := value.(*entity.User); ok && user != nil {
return user, true
}
return nil, false
}
// AuthDetails returns the full authentication context (token, claims, user).
func AuthDetails(c *fiber.Ctx) (*AuthContext, bool) {
value := c.Locals(authContextLocalsKey)
if ctx, ok := value.(*AuthContext); ok && ctx != nil {
return ctx, true
}
return nil, false
}
// ensureNotRevoked ensures the token is not revoked or superseded by a forced logout.
func ensureNotRevoked(c *fiber.Ctx, token string, verification *sso.VerificationResult) error {
revoker := session.GetRevocationStore()
if revoker == nil {
return nil
}
if fingerprint := session.TokenFingerprint(token); fingerprint != "" {
revoked, err := revoker.IsRevoked(c.Context(), fingerprint)
if err != nil {
utils.Log.WithError(err).Warn("auth: token revocation check failed")
return fiber.NewError(fiber.StatusUnauthorized, "Please authenticate")
}
if revoked {
return fiber.NewError(fiber.StatusUnauthorized, "Please authenticate")
}
}
if verification.UserID == 0 {
return nil
}
logoutAt, err := revoker.UserLogoutTime(c.Context(), verification.UserID)
if err != nil {
utils.Log.WithError(err).Warn("auth: failed to load user logout marker")
return fiber.NewError(fiber.StatusUnauthorized, "Please authenticate")
}
if logoutAt.IsZero() {
return nil
}
claims := verification.Claims
if claims == nil || claims.IssuedAt == nil {
return fiber.NewError(fiber.StatusUnauthorized, "Please authenticate")
}
issuedAt := claims.IssuedAt.Time
// Treat tokens issued at or before the forced logout timestamp as invalid.
if !issuedAt.After(logoutAt) {
return fiber.NewError(fiber.StatusUnauthorized, "Please authenticate")
}
return nil
}
// bearerToken extracts a Bearer token from the Authorization header using
// case-insensitive scheme matching and tolerant whitespace handling.
func bearerToken(c *fiber.Ctx) string {
parts := strings.Fields(c.Get("Authorization"))
if len(parts) == 2 && strings.EqualFold(parts[0], "Bearer") {
return strings.TrimSpace(parts[1])
}
return ""
}
func hasAllScopes(have, required []string) bool {
if len(required) == 0 {
return true
}
set := make(map[string]struct{}, len(have))
for _, s := range have {
s = strings.ToLower(strings.TrimSpace(s))
if s != "" {
set[s] = struct{}{}
}
}
for _, r := range required {
r = strings.ToLower(strings.TrimSpace(r))
if r == "" {
continue
}
if _, ok := set[r]; !ok {
return false
}
}
return true
}
+21
View File
@@ -24,3 +24,24 @@ func LimiterConfig() fiber.Handler {
SkipSuccessfulRequests: true,
})
}
func NewLimiter(max int, expiration time.Duration) fiber.Handler {
if max <= 0 {
max = 10
}
if expiration <= 0 {
expiration = time.Minute
}
return limiter.New(limiter.Config{
Max: max,
Expiration: expiration,
LimitReached: func(c *fiber.Ctx) error {
return c.Status(fiber.StatusTooManyRequests).
JSON(response.Common{
Code: fiber.StatusTooManyRequests,
Status: "error",
Message: "Too many requests, please try again later",
})
},
})
}
+75
View File
@@ -0,0 +1,75 @@
package middleware
import (
"strings"
"github.com/gofiber/fiber/v2"
)
// RequirePermissions ensures the authenticated user possesses all specified permissions.
func RequirePermissions(perms ...string) fiber.Handler {
required := canonicalPermissions(perms)
return func(c *fiber.Ctx) error {
if len(required) == 0 {
return c.Next()
}
ctx, ok := AuthDetails(c)
if !ok || ctx == nil {
return fiber.NewError(fiber.StatusUnauthorized, "Please authenticate")
}
userPerms := ctx.permissionSet()
if len(userPerms) == 0 {
return fiber.NewError(fiber.StatusForbidden, "Insufficient permission")
}
for _, perm := range required {
if _, has := userPerms[perm]; !has {
return fiber.NewError(fiber.StatusForbidden, "Insufficient permission")
}
}
return c.Next()
}
}
// HasPermission reports whether the current request context includes the given permission.
func HasPermission(c *fiber.Ctx, perm string) bool {
ctx, ok := AuthDetails(c)
if !ok || ctx == nil {
return false
}
perm = canonicalPermission(perm)
if perm == "" {
return false
}
_, has := ctx.permissionSet()[perm]
return has
}
func (a *AuthContext) permissionSet() map[string]struct{} {
if a == nil || a.Permissions == nil {
return nil
}
return a.Permissions
}
func canonicalPermissions(perms []string) []string {
out := make([]string, 0, len(perms))
seen := make(map[string]struct{}, len(perms))
for _, perm := range perms {
if canonical := canonicalPermission(perm); canonical != "" {
if _, ok := seen[canonical]; ok {
continue
}
seen[canonical] = struct{}{}
out = append(out, canonical)
}
}
return out
}
func canonicalPermission(perm string) string {
return strings.ToLower(strings.TrimSpace(perm))
}
+4
View File
@@ -16,6 +16,10 @@ func JSONBody() fiber.Handler {
return c.Next()
}
if strings.EqualFold(c.Path(), "/api/sso/users/sync") {
return c.Next()
}
body := c.Body()
if len(body) == 0 {
return c.Next()
@@ -1,7 +1,7 @@
package productWarehouses
import (
// m "gitlab.com/mbugroup/lti-api.git/internal/middleware"
m "gitlab.com/mbugroup/lti-api.git/internal/middleware"
controller "gitlab.com/mbugroup/lti-api.git/internal/modules/inventory/product-warehouses/controllers"
productWarehouse "gitlab.com/mbugroup/lti-api.git/internal/modules/inventory/product-warehouses/services"
user "gitlab.com/mbugroup/lti-api.git/internal/modules/users/services"
@@ -13,12 +13,7 @@ func ProductWarehouseRoutes(v1 fiber.Router, u user.UserService, s productWareho
ctrl := controller.NewProductWarehouseController(s)
route := v1.Group("/product-warehouses")
// route.Get("/", m.Auth(u), ctrl.GetAll)
// route.Post("/", m.Auth(u), ctrl.CreateOne)
// route.Get("/:id", m.Auth(u), ctrl.GetOne)
// route.Patch("/:id", m.Auth(u), ctrl.UpdateOne)
// route.Delete("/:id", m.Auth(u), ctrl.DeleteOne)
route.Use(m.Auth(u))
route.Get("/", ctrl.GetAll)
route.Get("/:id", ctrl.GetOne)
+1 -1
View File
@@ -7,8 +7,8 @@ import (
"github.com/gofiber/fiber/v2"
"gorm.io/gorm"
productWarehouses "gitlab.com/mbugroup/lti-api.git/internal/modules/inventory/product-warehouses"
adjustments "gitlab.com/mbugroup/lti-api.git/internal/modules/inventory/adjustments"
productWarehouses "gitlab.com/mbugroup/lti-api.git/internal/modules/inventory/product-warehouses"
transfers "gitlab.com/mbugroup/lti-api.git/internal/modules/inventory/transfers"
// MODULE IMPORTS
)
@@ -1,7 +1,7 @@
package transfers
import (
// m "gitlab.com/mbugroup/lti-api.git/internal/middleware"
m "gitlab.com/mbugroup/lti-api.git/internal/middleware"
controller "gitlab.com/mbugroup/lti-api.git/internal/modules/inventory/transfers/controllers"
transfer "gitlab.com/mbugroup/lti-api.git/internal/modules/inventory/transfers/services"
user "gitlab.com/mbugroup/lti-api.git/internal/modules/users/services"
@@ -13,12 +13,7 @@ func TransferRoutes(v1 fiber.Router, u user.UserService, s transfer.TransferServ
ctrl := controller.NewTransferController(s)
route := v1.Group("/transfers")
// route.Get("/", m.Auth(u), ctrl.GetAll)
// route.Post("/", m.Auth(u), ctrl.CreateOne)
// route.Get("/:id", m.Auth(u), ctrl.GetOne)
// route.Patch("/:id", m.Auth(u), ctrl.UpdateOne)
// route.Delete("/:id", m.Auth(u), ctrl.DeleteOne)
route.Use(m.Auth(u))
route.Get("/", ctrl.GetAll)
route.Post("/", ctrl.CreateOne)
@@ -328,7 +328,6 @@ func groupDeliveryProducts(products []MarketingDeliveryProductDTO, soNumber stri
return groups
}
// getVehicleNumber mengambil vehicle number dari DeliveryProduct jika ada
func getVehicleNumber(e entity.MarketingProduct) string {
if e.DeliveryProduct != nil && e.DeliveryProduct.VehicleNumber != "" {
return e.DeliveryProduct.VehicleNumber
-1
View File
@@ -23,4 +23,3 @@ func (AreaModule) RegisterRoutes(router fiber.Router, db *gorm.DB, validate *val
AreaRoutes(router, userService, areaService)
}
+2 -7
View File
@@ -1,7 +1,7 @@
package areas
import (
// m "gitlab.com/mbugroup/lti-api.git/internal/middleware"
m "gitlab.com/mbugroup/lti-api.git/internal/middleware"
controller "gitlab.com/mbugroup/lti-api.git/internal/modules/master/areas/controllers"
area "gitlab.com/mbugroup/lti-api.git/internal/modules/master/areas/services"
user "gitlab.com/mbugroup/lti-api.git/internal/modules/users/services"
@@ -13,12 +13,7 @@ func AreaRoutes(v1 fiber.Router, u user.UserService, s area.AreaService) {
ctrl := controller.NewAreaController(s)
route := v1.Group("/areas")
// route.Get("/", m.Auth(u), ctrl.GetAll)
// route.Post("/", m.Auth(u), ctrl.CreateOne)
// route.Get("/:id", m.Auth(u), ctrl.GetOne)
// route.Patch("/:id", m.Auth(u), ctrl.UpdateOne)
// route.Delete("/:id", m.Auth(u), ctrl.DeleteOne)
route.Use(m.Auth(u))
route.Get("/", ctrl.GetAll)
route.Post("/", ctrl.CreateOne)
-1
View File
@@ -23,4 +23,3 @@ func (BankModule) RegisterRoutes(router fiber.Router, db *gorm.DB, validate *val
BankRoutes(router, userService, bankService)
}
+2 -7
View File
@@ -1,7 +1,7 @@
package banks
import (
// m "gitlab.com/mbugroup/lti-api.git/internal/middleware"
m "gitlab.com/mbugroup/lti-api.git/internal/middleware"
controller "gitlab.com/mbugroup/lti-api.git/internal/modules/master/banks/controllers"
bank "gitlab.com/mbugroup/lti-api.git/internal/modules/master/banks/services"
user "gitlab.com/mbugroup/lti-api.git/internal/modules/users/services"
@@ -13,12 +13,7 @@ func BankRoutes(v1 fiber.Router, u user.UserService, s bank.BankService) {
ctrl := controller.NewBankController(s)
route := v1.Group("/banks")
// route.Get("/", m.Auth(u), ctrl.GetAll)
// route.Post("/", m.Auth(u), ctrl.CreateOne)
// route.Get("/:id", m.Auth(u), ctrl.GetOne)
// route.Patch("/:id", m.Auth(u), ctrl.UpdateOne)
// route.Delete("/:id", m.Auth(u), ctrl.DeleteOne)
route.Use(m.Auth(u))
route.Get("/", ctrl.GetAll)
route.Post("/", ctrl.CreateOne)
@@ -23,4 +23,3 @@ func (CustomerModule) RegisterRoutes(router fiber.Router, db *gorm.DB, validate
CustomerRoutes(router, userService, customerService)
}
+2 -7
View File
@@ -1,7 +1,7 @@
package customers
import (
// m "gitlab.com/mbugroup/lti-api.git/internal/middleware"
m "gitlab.com/mbugroup/lti-api.git/internal/middleware"
controller "gitlab.com/mbugroup/lti-api.git/internal/modules/master/customers/controllers"
customer "gitlab.com/mbugroup/lti-api.git/internal/modules/master/customers/services"
user "gitlab.com/mbugroup/lti-api.git/internal/modules/users/services"
@@ -13,12 +13,7 @@ func CustomerRoutes(v1 fiber.Router, u user.UserService, s customer.CustomerServ
ctrl := controller.NewCustomerController(s)
route := v1.Group("/customers")
// route.Get("/", m.Auth(u), ctrl.GetAll)
// route.Post("/", m.Auth(u), ctrl.CreateOne)
// route.Get("/:id", m.Auth(u), ctrl.GetOne)
// route.Patch("/:id", m.Auth(u), ctrl.UpdateOne)
// route.Delete("/:id", m.Auth(u), ctrl.DeleteOne)
route.Use(m.Auth(u))
route.Get("/", ctrl.GetAll)
route.Post("/", ctrl.CreateOne)
+2 -7
View File
@@ -1,7 +1,7 @@
package fcrs
import (
// m "gitlab.com/mbugroup/lti-api.git/internal/middleware"
m "gitlab.com/mbugroup/lti-api.git/internal/middleware"
controller "gitlab.com/mbugroup/lti-api.git/internal/modules/master/fcrs/controllers"
fcr "gitlab.com/mbugroup/lti-api.git/internal/modules/master/fcrs/services"
user "gitlab.com/mbugroup/lti-api.git/internal/modules/users/services"
@@ -13,12 +13,7 @@ func FcrRoutes(v1 fiber.Router, u user.UserService, s fcr.FcrService) {
ctrl := controller.NewFcrController(s)
route := v1.Group("/fcrs")
// route.Get("/", m.Auth(u), ctrl.GetAll)
// route.Post("/", m.Auth(u), ctrl.CreateOne)
// route.Get("/:id", m.Auth(u), ctrl.GetOne)
// route.Patch("/:id", m.Auth(u), ctrl.UpdateOne)
// route.Delete("/:id", m.Auth(u), ctrl.DeleteOne)
route.Use(m.Auth(u))
route.Get("/", ctrl.GetAll)
route.Post("/", ctrl.CreateOne)
@@ -43,9 +43,9 @@ func ToFlockListDTO(e entity.Flock) FlockListDTO {
return FlockListDTO{
FlockBaseDTO: ToFlockBaseDTO(e),
CreatedAt: e.CreatedAt,
UpdatedAt: e.UpdatedAt,
CreatedUser: createdUser,
CreatedAt: e.CreatedAt,
UpdatedAt: e.UpdatedAt,
CreatedUser: createdUser,
}
}
+2 -7
View File
@@ -1,7 +1,7 @@
package flocks
import (
// m "gitlab.com/mbugroup/lti-api.git/internal/middleware"
m "gitlab.com/mbugroup/lti-api.git/internal/middleware"
controller "gitlab.com/mbugroup/lti-api.git/internal/modules/master/flocks/controllers"
flock "gitlab.com/mbugroup/lti-api.git/internal/modules/master/flocks/services"
user "gitlab.com/mbugroup/lti-api.git/internal/modules/users/services"
@@ -13,12 +13,7 @@ func FlockRoutes(v1 fiber.Router, u user.UserService, s flock.FlockService) {
ctrl := controller.NewFlockController(s)
route := v1.Group("/flocks")
// route.Get("/", m.Auth(u), ctrl.GetAll)
// route.Post("/", m.Auth(u), ctrl.CreateOne)
// route.Get("/:id", m.Auth(u), ctrl.GetOne)
// route.Patch("/:id", m.Auth(u), ctrl.UpdateOne)
// route.Delete("/:id", m.Auth(u), ctrl.DeleteOne)
route.Use(m.Auth(u))
route.Get("/", ctrl.GetAll)
route.Post("/", ctrl.CreateOne)
@@ -1,11 +1,11 @@
package validation
type Create struct {
Name string `json:"name" validate:"required_strict,min=3"`
Name string `json:"name" validate:"required_strict,min=3"`
}
type Update struct {
Name *string `json:"name,omitempty" validate:"omitempty"`
Name *string `json:"name,omitempty" validate:"omitempty"`
}
type Query struct {
@@ -14,6 +14,7 @@ type KandangBaseDTO struct {
Id uint `json:"id"`
Name string `json:"name"`
Status string `json:"status"`
Capacity float64 `json:"capacity"`
Location *locationDTO.LocationBaseDTO `json:"location"`
Pic *userDTO.UserBaseDTO `json:"pic"`
}
@@ -48,6 +49,7 @@ func ToKandangBaseDTO(e entity.Kandang) KandangBaseDTO {
Id: e.Id,
Name: e.Name,
Status: e.Status,
Capacity: e.Capacity,
Location: location,
Pic: pic,
}
@@ -23,4 +23,3 @@ func (KandangModule) RegisterRoutes(router fiber.Router, db *gorm.DB, validate *
KandangRoutes(router, userService, kandangService)
}
+1 -6
View File
@@ -13,12 +13,7 @@ func KandangRoutes(v1 fiber.Router, u user.UserService, s kandang.KandangService
ctrl := controller.NewKandangController(s)
route := v1.Group("/kandangs")
// route.Get("/", m.Auth(u), ctrl.GetAll)
// route.Post("/", m.Auth(u), ctrl.CreateOne)
// route.Get("/:id", m.Auth(u), ctrl.GetOne)
// route.Patch("/:id", m.Auth(u), ctrl.UpdateOne)
// route.Delete("/:id", m.Auth(u), ctrl.DeleteOne)
// route.Use(m.Auth(u))
route.Get("/", ctrl.GetAll)
route.Post("/", ctrl.CreateOne)
@@ -41,7 +41,7 @@ func NewKandangService(repo repository.KandangRepository, validate *validator.Va
func (s kandangService) withRelations(db *gorm.DB) *gorm.DB {
return db.Preload("CreatedUser").Preload("Location").Preload("Pic").Preload("ProjectFlockKandangs.ProjectFlock")
}
func (s kandangService) GetAll(c *fiber.Ctx, params *validation.Query) ([]entity.Kandang, int64, error) {
@@ -132,11 +132,12 @@ func (s *kandangService) CreateOne(c *fiber.Ctx, req *validation.Create) (*entit
//TODO: created by dummy
createBody := &entity.Kandang{
Name: req.Name,
LocationId: req.LocationId,
Status: status,
PicId: req.PicId,
CreatedBy: 1,
Name: req.Name,
LocationId: req.LocationId,
Capacity: req.Capacity,
Status: status,
PicId: req.PicId,
CreatedBy: 1,
}
if err := s.Repository.CreateOne(c.Context(), createBody, nil); err != nil {
@@ -194,6 +195,10 @@ func (s kandangService) UpdateOne(c *fiber.Ctx, req *validation.Update, id uint)
updateBody["pic_id"] = *req.PicId
}
if req.Capacity != nil {
updateBody["capacity"] = *req.Capacity
}
finalStatus := strings.ToUpper(existing.Status)
if req.Status != nil {
status := strings.ToUpper(*req.Status)
@@ -1,19 +1,21 @@
package validation
type Create struct {
Name string `json:"name" validate:"required_strict,min=3"`
Status string `json:"status,omitempty" validate:"omitempty,min=3"`
LocationId uint `json:"location_id" validate:"required_strict,number,gt=0"`
PicId uint `json:"pic_id" validate:"required_strict,number,gt=0"`
ProjectFlockId *uint `json:"project_flock_id" validate:"omitempty,number,gt=0"`
Name string `json:"name" validate:"required_strict,min=3"`
Status string `json:"status,omitempty" validate:"omitempty,min=3"`
Capacity float64 `json:"capacity" validate:"required_strict,gt=0"`
LocationId uint `json:"location_id" validate:"required_strict,number,gt=0"`
PicId uint `json:"pic_id" validate:"required_strict,number,gt=0"`
ProjectFlockId *uint `json:"project_flock_id" validate:"omitempty,number,gt=0"`
}
type Update struct {
Name *string `json:"name,omitempty" validate:"omitempty"`
Status *string `json:"status,omitempty" validate:"omitempty,min=3"`
LocationId *uint `json:"location_id,omitempty" validate:"omitempty,number,gt=0"`
PicId *uint `json:"pic_id,omitempty" validate:"omitempty,number,gt=0"`
ProjectFlockId *uint `json:"project_flock_id,omitempty" validate:"omitempty,number,gt=0"`
Name *string `json:"name,omitempty" validate:"omitempty"`
Status *string `json:"status,omitempty" validate:"omitempty,min=3"`
Capacity *float64 `json:"capacity" validate:"omitempty,gt=0"`
LocationId *uint `json:"location_id,omitempty" validate:"omitempty,number,gt=0"`
PicId *uint `json:"pic_id,omitempty" validate:"omitempty,number,gt=0"`
ProjectFlockId *uint `json:"project_flock_id,omitempty" validate:"omitempty,number,gt=0"`
}
type Query struct {
@@ -23,4 +23,3 @@ func (LocationModule) RegisterRoutes(router fiber.Router, db *gorm.DB, validate
LocationRoutes(router, userService, locationService)
}
+2 -7
View File
@@ -1,7 +1,7 @@
package locations
import (
// m "gitlab.com/mbugroup/lti-api.git/internal/middleware"
m "gitlab.com/mbugroup/lti-api.git/internal/middleware"
controller "gitlab.com/mbugroup/lti-api.git/internal/modules/master/locations/controllers"
location "gitlab.com/mbugroup/lti-api.git/internal/modules/master/locations/services"
user "gitlab.com/mbugroup/lti-api.git/internal/modules/users/services"
@@ -13,12 +13,7 @@ func LocationRoutes(v1 fiber.Router, u user.UserService, s location.LocationServ
ctrl := controller.NewLocationController(s)
route := v1.Group("/locations")
// route.Get("/", m.Auth(u), ctrl.GetAll)
// route.Post("/", m.Auth(u), ctrl.CreateOne)
// route.Get("/:id", m.Auth(u), ctrl.GetOne)
// route.Patch("/:id", m.Auth(u), ctrl.UpdateOne)
// route.Delete("/:id", m.Auth(u), ctrl.DeleteOne)
route.Use(m.Auth(u))
route.Get("/", ctrl.GetAll)
route.Post("/", ctrl.CreateOne)
@@ -23,4 +23,3 @@ func (NonstockModule) RegisterRoutes(router fiber.Router, db *gorm.DB, validate
NonstockRoutes(router, userService, nonstockService)
}
+2 -7
View File
@@ -1,7 +1,7 @@
package nonstocks
import (
// m "gitlab.com/mbugroup/lti-api.git/internal/middleware"
m "gitlab.com/mbugroup/lti-api.git/internal/middleware"
controller "gitlab.com/mbugroup/lti-api.git/internal/modules/master/nonstocks/controllers"
nonstock "gitlab.com/mbugroup/lti-api.git/internal/modules/master/nonstocks/services"
user "gitlab.com/mbugroup/lti-api.git/internal/modules/users/services"
@@ -13,12 +13,7 @@ func NonstockRoutes(v1 fiber.Router, u user.UserService, s nonstock.NonstockServ
ctrl := controller.NewNonstockController(s)
route := v1.Group("/nonstocks")
// route.Get("/", m.Auth(u), ctrl.GetAll)
// route.Post("/", m.Auth(u), ctrl.CreateOne)
// route.Get("/:id", m.Auth(u), ctrl.GetOne)
// route.Patch("/:id", m.Auth(u), ctrl.UpdateOne)
// route.Delete("/:id", m.Auth(u), ctrl.DeleteOne)
route.Use(m.Auth(u))
route.Get("/", ctrl.GetAll)
route.Post("/", ctrl.CreateOne)
@@ -1,7 +1,7 @@
package productcategories
import (
// m "gitlab.com/mbugroup/lti-api.git/internal/middleware"
m "gitlab.com/mbugroup/lti-api.git/internal/middleware"
controller "gitlab.com/mbugroup/lti-api.git/internal/modules/master/product-categories/controllers"
productCategory "gitlab.com/mbugroup/lti-api.git/internal/modules/master/product-categories/services"
user "gitlab.com/mbugroup/lti-api.git/internal/modules/users/services"
@@ -13,12 +13,7 @@ func ProductCategoryRoutes(v1 fiber.Router, u user.UserService, s productCategor
ctrl := controller.NewProductCategoryController(s)
route := v1.Group("/product-categories")
// route.Get("/", m.Auth(u), ctrl.GetAll)
// route.Post("/", m.Auth(u), ctrl.CreateOne)
// route.Get("/:id", m.Auth(u), ctrl.GetOne)
// route.Patch("/:id", m.Auth(u), ctrl.UpdateOne)
// route.Delete("/:id", m.Auth(u), ctrl.DeleteOne)
route.Use(m.Auth(u))
route.Get("/", ctrl.GetAll)
route.Post("/", ctrl.CreateOne)
@@ -23,4 +23,3 @@ func (ProductModule) RegisterRoutes(router fiber.Router, db *gorm.DB, validate *
ProductRoutes(router, userService, productService)
}
+2 -7
View File
@@ -1,7 +1,7 @@
package products
import (
// m "gitlab.com/mbugroup/lti-api.git/internal/middleware"
m "gitlab.com/mbugroup/lti-api.git/internal/middleware"
controller "gitlab.com/mbugroup/lti-api.git/internal/modules/master/products/controllers"
product "gitlab.com/mbugroup/lti-api.git/internal/modules/master/products/services"
user "gitlab.com/mbugroup/lti-api.git/internal/modules/users/services"
@@ -13,12 +13,7 @@ func ProductRoutes(v1 fiber.Router, u user.UserService, s product.ProductService
ctrl := controller.NewProductController(s)
route := v1.Group("/products")
// route.Get("/", m.Auth(u), ctrl.GetAll)
// route.Post("/", m.Auth(u), ctrl.CreateOne)
// route.Get("/:id", m.Auth(u), ctrl.GetOne)
// route.Patch("/:id", m.Auth(u), ctrl.UpdateOne)
// route.Delete("/:id", m.Auth(u), ctrl.DeleteOne)
route.Use(m.Auth(u))
route.Get("/", ctrl.GetAll)
route.Post("/", ctrl.CreateOne)
+1 -1
View File
@@ -11,6 +11,7 @@ import (
banks "gitlab.com/mbugroup/lti-api.git/internal/modules/master/banks"
customers "gitlab.com/mbugroup/lti-api.git/internal/modules/master/customers"
fcrs "gitlab.com/mbugroup/lti-api.git/internal/modules/master/fcrs"
flocks "gitlab.com/mbugroup/lti-api.git/internal/modules/master/flocks"
kandangs "gitlab.com/mbugroup/lti-api.git/internal/modules/master/kandangs"
locations "gitlab.com/mbugroup/lti-api.git/internal/modules/master/locations"
nonstocks "gitlab.com/mbugroup/lti-api.git/internal/modules/master/nonstocks"
@@ -19,7 +20,6 @@ import (
suppliers "gitlab.com/mbugroup/lti-api.git/internal/modules/master/suppliers"
uoms "gitlab.com/mbugroup/lti-api.git/internal/modules/master/uoms"
warehouses "gitlab.com/mbugroup/lti-api.git/internal/modules/master/warehouses"
flocks "gitlab.com/mbugroup/lti-api.git/internal/modules/master/flocks"
// MODULE IMPORTS
)
@@ -23,4 +23,3 @@ func (SupplierModule) RegisterRoutes(router fiber.Router, db *gorm.DB, validate
SupplierRoutes(router, userService, supplierService)
}
+2 -7
View File
@@ -1,7 +1,7 @@
package suppliers
import (
// m "gitlab.com/mbugroup/lti-api.git/internal/middleware"
m "gitlab.com/mbugroup/lti-api.git/internal/middleware"
controller "gitlab.com/mbugroup/lti-api.git/internal/modules/master/suppliers/controllers"
supplier "gitlab.com/mbugroup/lti-api.git/internal/modules/master/suppliers/services"
user "gitlab.com/mbugroup/lti-api.git/internal/modules/users/services"
@@ -13,12 +13,7 @@ func SupplierRoutes(v1 fiber.Router, u user.UserService, s supplier.SupplierServ
ctrl := controller.NewSupplierController(s)
route := v1.Group("/suppliers")
// route.Get("/", m.Auth(u), ctrl.GetAll)
// route.Post("/", m.Auth(u), ctrl.CreateOne)
// route.Get("/:id", m.Auth(u), ctrl.GetOne)
// route.Patch("/:id", m.Auth(u), ctrl.UpdateOne)
// route.Delete("/:id", m.Auth(u), ctrl.DeleteOne)
route.Use(m.Auth(u))
route.Get("/", ctrl.GetAll)
route.Post("/", ctrl.CreateOne)
+4 -2
View File
@@ -15,7 +15,8 @@ type UomBaseDTO struct {
}
type UomListDTO struct {
UomBaseDTO
Id uint `json:"id"`
Name string `json:"name"`
CreatedUser *userDTO.UserBaseDTO `json:"created_user"`
CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"`
@@ -42,7 +43,8 @@ func ToUomListDTO(e entity.Uom) UomListDTO {
}
return UomListDTO{
UomBaseDTO: ToUomBaseDTO(e),
Id: e.Id,
Name: e.Name,
CreatedAt: e.CreatedAt,
UpdatedAt: e.UpdatedAt,
CreatedUser: createdUser,
-1
View File
@@ -23,4 +23,3 @@ func (UomModule) RegisterRoutes(router fiber.Router, db *gorm.DB, validate *vali
UomRoutes(router, userService, uomService)
}
+2 -7
View File
@@ -1,7 +1,7 @@
package uoms
import (
// m "gitlab.com/mbugroup/lti-api.git/internal/middleware"
m "gitlab.com/mbugroup/lti-api.git/internal/middleware"
controller "gitlab.com/mbugroup/lti-api.git/internal/modules/master/uoms/controllers"
uom "gitlab.com/mbugroup/lti-api.git/internal/modules/master/uoms/services"
user "gitlab.com/mbugroup/lti-api.git/internal/modules/users/services"
@@ -13,12 +13,7 @@ func UomRoutes(v1 fiber.Router, u user.UserService, s uom.UomService) {
ctrl := controller.NewUomController(s)
route := v1.Group("/uoms")
// route.Get("/", m.Auth(u), ctrl.GetAll)
// route.Post("/", m.Auth(u), ctrl.CreateOne)
// route.Get("/:id", m.Auth(u), ctrl.GetOne)
// route.Patch("/:id", m.Auth(u), ctrl.UpdateOne)
// route.Delete("/:id", m.Auth(u), ctrl.DeleteOne)
route.Use(m.Auth(u))
route.Get("/", ctrl.GetAll)
route.Post("/", ctrl.CreateOne)
@@ -23,4 +23,3 @@ func (WarehouseModule) RegisterRoutes(router fiber.Router, db *gorm.DB, validate
WarehouseRoutes(router, userService, warehouseService)
}
+2 -7
View File
@@ -1,7 +1,7 @@
package warehouses
import (
// m "gitlab.com/mbugroup/lti-api.git/internal/middleware"
m "gitlab.com/mbugroup/lti-api.git/internal/middleware"
controller "gitlab.com/mbugroup/lti-api.git/internal/modules/master/warehouses/controllers"
warehouse "gitlab.com/mbugroup/lti-api.git/internal/modules/master/warehouses/services"
user "gitlab.com/mbugroup/lti-api.git/internal/modules/users/services"
@@ -13,12 +13,7 @@ func WarehouseRoutes(v1 fiber.Router, u user.UserService, s warehouse.WarehouseS
ctrl := controller.NewWarehouseController(s)
route := v1.Group("/warehouses")
// route.Get("/", m.Auth(u), ctrl.GetAll)
// route.Post("/", m.Auth(u), ctrl.CreateOne)
// route.Get("/:id", m.Auth(u), ctrl.GetOne)
// route.Patch("/:id", m.Auth(u), ctrl.UpdateOne)
// route.Delete("/:id", m.Auth(u), ctrl.DeleteOne)
route.Use(m.Auth(u))
route.Get("/", ctrl.GetAll)
route.Post("/", ctrl.CreateOne)
@@ -1,7 +1,7 @@
package chickins
import (
// m "gitlab.com/mbugroup/lti-api.git/internal/middleware"
m "gitlab.com/mbugroup/lti-api.git/internal/middleware"
controller "gitlab.com/mbugroup/lti-api.git/internal/modules/production/chickins/controllers"
chickin "gitlab.com/mbugroup/lti-api.git/internal/modules/production/chickins/services"
user "gitlab.com/mbugroup/lti-api.git/internal/modules/users/services"
@@ -13,12 +13,7 @@ func ChickinRoutes(v1 fiber.Router, u user.UserService, s chickin.ChickinService
ctrl := controller.NewChickinController(s)
route := v1.Group("/chickins")
// route.Get("/", m.Auth(u), ctrl.GetAll)
// route.Post("/", m.Auth(u), ctrl.CreateOne)
// route.Get("/:id", m.Auth(u), ctrl.GetOne)
// route.Patch("/:id", m.Auth(u), ctrl.UpdateOne)
// route.Delete("/:id", m.Auth(u), ctrl.DeleteOne)
route.Use(m.Auth(u))
// route.Get("/", ctrl.GetAll)
route.Post("/", ctrl.CreateOne)
@@ -1,7 +1,7 @@
package project_flocks
import (
// m "gitlab.com/mbugroup/lti-api.git/internal/middleware"
m "gitlab.com/mbugroup/lti-api.git/internal/middleware"
controller "gitlab.com/mbugroup/lti-api.git/internal/modules/production/project_flocks/controllers"
projectflock "gitlab.com/mbugroup/lti-api.git/internal/modules/production/project_flocks/services"
user "gitlab.com/mbugroup/lti-api.git/internal/modules/users/services"
@@ -19,6 +19,7 @@ func ProjectflockRoutes(v1 fiber.Router, u user.UserService, s projectflock.Proj
// route.Get("/:id", m.Auth(u), ctrl.GetOne)
// route.Patch("/:id", m.Auth(u), ctrl.UpdateOne)
// route.Delete("/:id", m.Auth(u), ctrl.DeleteOne)
route.Use(m.Auth(u))
route.Get("/", ctrl.GetAll)
route.Post("/", ctrl.CreateOne)
@@ -28,5 +29,6 @@ func ProjectflockRoutes(v1 fiber.Router, u user.UserService, s projectflock.Proj
route.Get("/kandangs/:project_flock_kandang_id/periods", ctrl.GetFlockPeriodSummary)
route.Get("/kandangs/lookup", ctrl.LookupProjectFlockKandang)
route.Post("/approvals", ctrl.Approval)
route.Get("/kandangs/:project_flock_kandang_id/periods", ctrl.GetFlockPeriodSummary)
}
@@ -10,6 +10,7 @@ import (
commonRepo "gitlab.com/mbugroup/lti-api.git/internal/common/repository"
commonSvc "gitlab.com/mbugroup/lti-api.git/internal/common/service"
entity "gitlab.com/mbugroup/lti-api.git/internal/entities"
authmiddleware "gitlab.com/mbugroup/lti-api.git/internal/middleware"
productWarehouseRepository "gitlab.com/mbugroup/lti-api.git/internal/modules/inventory/product-warehouses/repositories"
flockDTO "gitlab.com/mbugroup/lti-api.git/internal/modules/master/flocks/dto"
flockRepository "gitlab.com/mbugroup/lti-api.git/internal/modules/master/flocks/repositories"
@@ -235,6 +236,11 @@ func (s *projectflockService) CreateOne(c *fiber.Ctx, req *validation.Create) (*
return nil, err
}
actorID, err := actorIDFromContext(c)
if err != nil {
return nil, err
}
cat := strings.ToUpper(req.Category)
if !utils.IsValidProjectFlockCategory(cat) {
return nil, fiber.NewError(fiber.StatusBadRequest, "Invalid category")
@@ -259,7 +265,7 @@ func (s *projectflockService) CreateOne(c *fiber.Ctx, req *validation.Create) (*
canonicalBase := baseName
if s.FlockRepo != nil {
baseFlock, err := s.ensureFlockByName(c.Context(), baseName)
baseFlock, err := s.ensureFlockByName(c.Context(), actorID, baseName)
if err != nil {
return nil, err
}
@@ -289,7 +295,7 @@ func (s *projectflockService) CreateOne(c *fiber.Ctx, req *validation.Create) (*
Category: cat,
FcrId: req.FcrId,
LocationId: req.LocationId,
CreatedBy: 1,
CreatedBy: actorID,
}
err = s.Repository.DB().WithContext(c.Context()).Transaction(func(dbTransaction *gorm.DB) error {
@@ -314,7 +320,6 @@ func (s *projectflockService) CreateOne(c *fiber.Ctx, req *validation.Create) (*
return err
}
actorID := uint(1) //TODO: Change From Auth
action := entity.ApprovalActionCreated
approvalSvcTx := commonSvc.NewApprovalService(commonRepo.NewApprovalRepository(dbTransaction))
_, err = approvalSvcTx.CreateApproval(
@@ -348,6 +353,11 @@ func (s projectflockService) UpdateOne(c *fiber.Ctx, req *validation.Update, id
return nil, err
}
actorID, err := actorIDFromContext(c)
if err != nil {
return nil, err
}
existing, err := s.Repository.GetByID(c.Context(), id, s.Repository.WithDefaultRelations())
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, fiber.NewError(fiber.StatusNotFound, "Projectflock not found")
@@ -370,7 +380,7 @@ func (s projectflockService) UpdateOne(c *fiber.Ctx, req *validation.Update, id
}
canonicalBase := trimmed
if s.FlockRepo != nil {
flockEntity, err := s.ensureFlockByName(c.Context(), trimmed)
flockEntity, err := s.ensureFlockByName(c.Context(), actorID, trimmed)
if err != nil {
return nil, err
}
@@ -529,7 +539,6 @@ func (s projectflockService) UpdateOne(c *fiber.Ctx, req *validation.Update, id
}
if hasChanges {
actorID := uint(1) //TODO: Change From Auth
approvalSvc := commonSvc.NewApprovalService(commonRepo.NewApprovalRepository(dbTransaction))
if approvalSvc != nil {
latestBeforeReset, err := approvalSvc.LatestByTarget(c.Context(), s.approvalWorkflow, id, nil)
@@ -583,7 +592,11 @@ func (s projectflockService) Approval(c *fiber.Ctx, req *validation.Approve) ([]
return nil, err
}
actorID := uint(1) // TODO: change from auth context
actorID, err := actorIDFromContext(c)
if err != nil {
return nil, err
}
var action entity.ApprovalAction
switch strings.ToUpper(strings.TrimSpace(req.Action)) {
case string(entity.ApprovalActionRejected):
@@ -604,7 +617,7 @@ func (s projectflockService) Approval(c *fiber.Ctx, req *validation.Approve) ([]
step = utils.ProjectFlockStepAktif
}
err := s.Repository.DB().WithContext(c.Context()).Transaction(func(dbTransaction *gorm.DB) error {
err = s.Repository.DB().WithContext(c.Context()).Transaction(func(dbTransaction *gorm.DB) error {
approvalSvc := commonSvc.NewApprovalService(commonRepo.NewApprovalRepository(dbTransaction))
kandangRepoTx := kandangRepository.NewKandangRepository(dbTransaction)
projectRepoTx := repository.NewProjectflockRepository(dbTransaction)
@@ -891,7 +904,7 @@ func (s projectflockService) generateSequentialFlockName(ctx context.Context, re
}
}
func (s projectflockService) ensureFlockByName(ctx context.Context, name string) (*entity.Flock, error) {
func (s projectflockService) ensureFlockByName(ctx context.Context, actorID uint, name string) (*entity.Flock, error) {
trimmed := strings.TrimSpace(name)
if trimmed == "" {
return nil, fiber.NewError(fiber.StatusBadRequest, "Flock name cannot be empty")
@@ -908,7 +921,7 @@ func (s projectflockService) ensureFlockByName(ctx context.Context, name string)
newFlock := &entity.Flock{
Name: trimmed,
CreatedBy: 1, // TODO: replace with authenticated user
CreatedBy: actorID,
}
if err := s.FlockRepo.CreateOne(ctx, newFlock, nil); err != nil {
if errors.Is(err, gorm.ErrDuplicatedKey) {
@@ -1027,3 +1040,11 @@ func (s projectflockService) kandangRepoWithTx(tx *gorm.DB) kandangRepository.Ka
}
return kandangRepository.NewKandangRepository(s.Repository.DB())
}
func actorIDFromContext(c *fiber.Ctx) (uint, error) {
user, ok := authmiddleware.AuthenticatedUser(c)
if !ok || user == nil || user.Id == 0 {
return 0, fiber.NewError(fiber.StatusUnauthorized, "Please authenticate")
}
return user.Id, nil
}
@@ -69,12 +69,21 @@ type RecordingStockDTO struct {
}
type RecordingEggDTO struct {
Id uint `json:"id"`
ProductWarehouseId uint `json:"product_warehouse_id"`
Qty int `json:"qty"`
ProductWarehouse productWarehouseDTO.ProductWarehouseDTO `json:"product_warehouse"`
Gradings []RecordingEggGradingDTO `json:"gradings,omitempty"`
}
type RecordingProductWarehouseDTO struct {
Id uint `json:"id"`
ProductId uint `json:"product_id"`
ProductName string `json:"product_name"`
WarehouseId uint `json:"warehouse_id"`
WarehouseName string `json:"warehouse_name"`
}
type RecordingEggGradingDTO struct {
Grade string `json:"grade,omitempty"`
Qty float64 `json:"qty"`
@@ -241,6 +250,7 @@ func ToRecordingEggDTOs(eggs []entity.RecordingEgg) []RecordingEggDTO {
result := make([]RecordingEggDTO, len(eggs))
for i, egg := range eggs {
result[i] = RecordingEggDTO{
Id: egg.Id,
ProductWarehouseId: egg.ProductWarehouseId,
Qty: egg.Qty,
ProductWarehouse: mapProductWarehouseDTO(&egg.ProductWarehouse),
@@ -0,0 +1,8 @@
package recordings
const (
PermissionRecordingRead = "recording.read"
PermissionRecordingCreate = "recording.write"
PermissionRecordingUpdate = "recording.update"
PermissionRecordingDelete = "recording.delete"
)
@@ -1,7 +1,7 @@
package recordings
import (
// m "gitlab.com/mbugroup/lti-api.git/internal/middleware"
m "gitlab.com/mbugroup/lti-api.git/internal/middleware"
controller "gitlab.com/mbugroup/lti-api.git/internal/modules/production/recordings/controllers"
recording "gitlab.com/mbugroup/lti-api.git/internal/modules/production/recordings/services"
user "gitlab.com/mbugroup/lti-api.git/internal/modules/users/services"
@@ -13,12 +13,7 @@ func RecordingRoutes(v1 fiber.Router, u user.UserService, s recording.RecordingS
ctrl := controller.NewRecordingController(s)
route := v1.Group("/recordings")
// route.Get("/", m.Auth(u), ctrl.GetAll)
// route.Post("/", m.Auth(u), ctrl.CreateOne)
// route.Get("/:id", m.Auth(u), ctrl.GetOne)
// route.Patch("/:id", m.Auth(u), ctrl.UpdateOne)
// route.Delete("/:id", m.Auth(u), ctrl.DeleteOne)
route.Use(m.Auth(u))
route.Get("/", ctrl.GetAll)
route.Get("/next-day", ctrl.GetNextDay)
@@ -0,0 +1,706 @@
package controllers
import (
"context"
"crypto/sha256"
"encoding/json"
"fmt"
"io"
"net/http"
"net/url"
"strings"
"time"
"github.com/gofiber/fiber/v2"
"github.com/sirupsen/logrus"
"gitlab.com/mbugroup/lti-api.git/internal/config"
"gitlab.com/mbugroup/lti-api.git/internal/modules/sso/session"
"gitlab.com/mbugroup/lti-api.git/internal/sso"
"gitlab.com/mbugroup/lti-api.git/internal/utils"
"gitlab.com/mbugroup/lti-api.git/internal/utils/secure"
)
// Controller manages the SSO start & callback flow using PKCE.
type Controller struct {
httpClient *http.Client
store *session.Store
revoker *session.RevocationStore
}
func NewController(client *http.Client, store *session.Store, revoker *session.RevocationStore) *Controller {
return &Controller{httpClient: client, store: store, revoker: revoker}
}
// Start handles GET /sso/start requests and redirects users to the central SSO authorize endpoint.
func (h *Controller) Start(c *fiber.Ctx) error {
requestedAlias := normalizeClientParam(c.Query("client"))
if requestedAlias == "" {
requestedAlias = normalizeClientParam(c.Query("client_id"))
}
if requestedAlias == "" {
return fiber.NewError(fiber.StatusBadRequest, "missing client")
}
alias, cfg, ok := findSSOClientConfig(requestedAlias)
if !ok || cfg.PublicID == "" {
return fiber.NewError(fiber.StatusBadRequest, "unknown client")
}
authorizeEndpoint := strings.TrimSpace(config.SSOAuthorizeURL)
if authorizeEndpoint == "" {
return fiber.NewError(fiber.StatusInternalServerError, "authorize endpoint not configured")
}
state, err := secure.RandomString(48)
if err != nil {
utils.Log.Errorf("generate state failed: %v", err)
return fiber.NewError(fiber.StatusInternalServerError, "failed to prepare authorization")
}
nonce, err := secure.RandomString(32)
if err != nil {
utils.Log.Errorf("generate nonce failed: %v", err)
return fiber.NewError(fiber.StatusInternalServerError, "failed to prepare authorization")
}
codeVerifier, err := secure.PKCECodeVerifier(96)
if err != nil {
utils.Log.Errorf("generate code verifier failed: %v", err)
return fiber.NewError(fiber.StatusInternalServerError, "failed to prepare authorization")
}
digest := sha256.Sum256([]byte(codeVerifier))
challenge := secure.Base64URLEncode(digest[:])
authorizeURL, err := url.Parse(authorizeEndpoint)
if err != nil {
return fiber.NewError(fiber.StatusInternalServerError, "invalid authorize endpoint")
}
scope := cfg.Scope
if scope == "" {
scope = "openid profile"
}
if !strings.Contains(" "+scope+" ", " openid ") {
scope = scope + " openid"
}
rawReturn := strings.TrimSpace(c.Query("return_to"))
if rawReturn == "" {
rawReturn = cfg.DefaultReturnURI
}
returnTo, err := normalizeReturnTarget(rawReturn, cfg)
if err != nil {
return fiber.NewError(fiber.StatusBadRequest, err.Error())
}
query := authorizeURL.Query()
query.Set("response_type", "code")
query.Set("client_id", cfg.PublicID)
query.Set("redirect_uri", cfg.RedirectURI)
query.Set("scope", strings.TrimSpace(scope))
query.Set("state", state)
query.Set("code_challenge", challenge)
query.Set("code_challenge_method", "S256")
query.Set("nonce", nonce)
// if prompt := strings.TrimSpace(cfg.Prompt); prompt != "" {
// query.Set("prompt", prompt)
// }
if extraPrompt := strings.TrimSpace(c.Query("prompt")); extraPrompt != "" {
query.Set("prompt", extraPrompt)
}
authorizeURL.RawQuery = query.Encode()
payload := &session.PKCESession{
CodeVerifier: codeVerifier,
Nonce: nonce,
ClientAlias: alias,
ClientID: cfg.PublicID,
RedirectURI: cfg.RedirectURI,
Scope: strings.TrimSpace(scope),
ReturnTo: returnTo,
CreatedAt: time.Now().UTC(),
}
if err := h.store.Save(c.Context(), state, payload); err != nil {
utils.Log.Errorf("store pkce session failed: %v", err)
return fiber.NewError(fiber.StatusInternalServerError, "failed to prepare authorization")
}
utils.Log.WithFields(logrus.Fields{
"client": alias,
"state": state,
"return_to": returnTo,
}).Info("sso start redirect")
return c.Redirect(authorizeURL.String(), fiber.StatusFound)
}
// Callback handles the redirect from SSO containing the authorization code.
func (h *Controller) Callback(c *fiber.Ctx) error {
state := strings.TrimSpace(c.Query("state"))
code := strings.TrimSpace(c.Query("code"))
if state == "" || code == "" {
return fiber.NewError(fiber.StatusBadRequest, "missing code or state")
}
sessionData, err := h.store.Get(c.Context(), state)
if err != nil {
utils.Log.Errorf("load pkce session failed: %v", err)
return fiber.NewError(fiber.StatusInternalServerError, "failed to validate authorization state")
}
if sessionData == nil {
return fiber.NewError(fiber.StatusBadRequest, "authorization state not found or expired")
}
defer func() {
if err := h.store.Delete(context.Background(), state); err != nil {
utils.Log.Warnf("failed to delete pkce session: %v", err)
}
}()
tokenEndpoint := strings.TrimSpace(config.SSOTokenURL)
if tokenEndpoint == "" {
return fiber.NewError(fiber.StatusInternalServerError, "token endpoint not configured")
}
form := url.Values{}
form.Set("grant_type", "authorization_code")
form.Set("code", code)
form.Set("code_verifier", sessionData.CodeVerifier)
form.Set("redirect_uri", sessionData.RedirectURI)
form.Set("client_id", sessionData.ClientID)
req, err := http.NewRequestWithContext(c.Context(), http.MethodPost, tokenEndpoint, strings.NewReader(form.Encode()))
if err != nil {
return fiber.NewError(fiber.StatusInternalServerError, "failed to create token request")
}
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
resp, err := h.httpClient.Do(req)
if err != nil {
utils.Log.Errorf("token request failed: %v", err)
return fiber.NewError(fiber.StatusBadGateway, "failed to exchange authorization code")
}
defer resp.Body.Close()
if resp.StatusCode >= 400 {
utils.Log.Warnf("token response status %d", resp.StatusCode)
return fiber.NewError(fiber.StatusBadGateway, "token exchange rejected")
}
var tokenResp struct {
AccessToken string `json:"access_token"`
RefreshToken string `json:"refresh_token"`
TokenType string `json:"token_type"`
ExpiresIn int `json:"expires_in"`
Scope string `json:"scope"`
IDToken string `json:"id_token"`
Error string `json:"error"`
Description string `json:"error_description"`
}
if err := json.NewDecoder(resp.Body).Decode(&tokenResp); err != nil {
return fiber.NewError(fiber.StatusBadGateway, "invalid token response")
}
if tokenResp.Error != "" {
return fiber.NewError(fiber.StatusBadGateway, tokenResp.Description)
}
if tokenResp.AccessToken == "" {
return fiber.NewError(fiber.StatusBadGateway, "missing access token")
}
verification, err := sso.VerifyAccessToken(tokenResp.AccessToken)
if err != nil {
utils.Log.Errorf("access token verification failed: %v", err)
return fiber.NewError(fiber.StatusUnauthorized, "invalid access token")
}
// prepare cookies
issueCookies(c, tokenResp, verification)
redirectTarget := sessionData.ReturnTo
if redirectTarget == "" {
redirectTarget = "/"
}
utils.Log.WithFields(logrus.Fields{
"client": sessionData.ClientAlias,
"user_id": verification.UserID,
"return_to": redirectTarget,
}).Info("sso callback successful")
return c.Redirect(redirectTarget, fiber.StatusFound)
}
// UserInfo proxies the user profile from the central SSO so the frontend can obtain
// enriched user metadata (roles, permissions, etc.) without exposing tokens to the browser.
func (h *Controller) UserInfo(c *fiber.Ctx) error {
accessName := config.SSOAccessCookieName
if accessName == "" {
accessName = "sso_access"
}
token := strings.TrimSpace(c.Cookies(accessName))
tokenFromCookie := token != ""
if !tokenFromCookie {
authHeader := strings.TrimSpace(c.Get("Authorization"))
if strings.HasPrefix(strings.ToLower(authHeader), "bearer ") {
token = strings.TrimSpace(authHeader[7:])
}
}
if token == "" {
return fiber.NewError(fiber.StatusUnauthorized, "unauthenticated")
}
if revoker := session.GetRevocationStore(); revoker != nil {
if fingerprint := session.TokenFingerprint(token); fingerprint != "" {
revoked, err := revoker.IsRevoked(c.Context(), fingerprint)
if err != nil {
utils.Log.WithError(err).Warn("failed to check token revocation for userinfo")
return fiber.NewError(fiber.StatusUnauthorized, "unauthenticated")
}
if revoked {
return fiber.NewError(fiber.StatusUnauthorized, "unauthenticated")
}
}
}
if _, err := sso.VerifyAccessToken(token); err != nil {
utils.Log.WithError(err).Warn("access token verification failed for userinfo")
return fiber.NewError(fiber.StatusUnauthorized, "unauthenticated")
}
endpoint := strings.TrimSpace(config.SSOGetMeURL)
if endpoint == "" {
return fiber.NewError(fiber.StatusInternalServerError, "userinfo endpoint not configured")
}
req, err := http.NewRequestWithContext(c.Context(), http.MethodGet, endpoint, nil)
if err != nil {
utils.Log.Errorf("failed to build userinfo request: %v", err)
return fiber.NewError(fiber.StatusInternalServerError, "failed to prepare userinfo request")
}
req.Header.Set("Accept", "application/json")
// SSO /auth/get-me expects the access cookie; add Authorization as well for compatibility.
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token))
if tokenFromCookie {
req.Header.Set("Cookie", fmt.Sprintf("%s=%s", accessName, token))
}
resp, err := h.httpClient.Do(req)
if err != nil {
utils.Log.Errorf("userinfo request failed: %v", err)
return fiber.NewError(fiber.StatusBadGateway, "failed to fetch user profile")
}
defer resp.Body.Close()
utils.Log.WithFields(logrus.Fields{"status": resp.StatusCode}).Info("sso userinfo response")
body, err := io.ReadAll(resp.Body)
if err != nil {
utils.Log.Errorf("failed to read userinfo response: %v", err)
return fiber.NewError(fiber.StatusBadGateway, "invalid user profile response")
}
// if sanitized, perms, ok := sanitizeUserInfoPayload(body); ok {
// if caps := capabilities.FromPermissions(perms); len(caps) > 0 {
// injectCapabilities(sanitized, caps)
// }
// return c.Status(resp.StatusCode).JSON(sanitized)
// }
if ct := resp.Header.Get("Content-Type"); ct != "" {
c.Set("Content-Type", ct)
} else {
c.Type("json")
}
return c.Status(resp.StatusCode).Send(body)
}
// Logout clears SSO cookies and removes any leftover PKCE session state.
func (h *Controller) Logout(c *fiber.Ctx) error {
requestedAlias := normalizeClientParam(c.Query("client"))
if requestedAlias == "" {
requestedAlias = normalizeClientParam(c.Query("client_id"))
}
var (
alias string
cfg config.SSOClientConfig
hasClientInfo bool
)
if requestedAlias != "" {
alias, cfg, hasClientInfo = findSSOClientConfig(requestedAlias)
}
accessName := resolveSSOCookieName(config.SSOAccessCookieName, "access")
refreshName := resolveSSOCookieName(config.SSORefreshCookieName, "refresh")
var accessToken, refreshToken string
if accessName != "" {
accessToken = strings.TrimSpace(c.Cookies(accessName))
}
if refreshName != "" {
refreshToken = strings.TrimSpace(c.Cookies(refreshName))
}
hadAccessCookie := accessToken != ""
hadRefreshCookie := refreshToken != ""
state := strings.TrimSpace(c.Query("state"))
if state != "" {
if err := h.store.Delete(c.Context(), state); err != nil {
utils.Log.Warnf("failed to delete pkce session during logout: %v", err)
}
}
if !hadAccessCookie && !hadRefreshCookie && state == "" {
return fiber.NewError(fiber.StatusUnauthorized, "not authenticated")
}
if hadAccessCookie {
if verification, err := sso.VerifyAccessToken(accessToken); err != nil {
utils.Log.WithError(err).Warn("failed to verify access token during logout")
} else {
if revoker := session.GetRevocationStore(); revoker != nil {
if err := revoker.MarkUserLogout(c.Context(), verification.UserID, time.Now().UTC()); err != nil {
utils.Log.WithError(err).Warn("failed to mark user logout")
}
}
h.revokeToken(c.Context(), accessToken, verification)
}
}
if refreshToken != "" {
h.revokeRefreshToken(c.Context(), refreshToken)
}
clearSSOCookie(c, accessName)
clearSSOCookie(c, refreshName)
redirectTarget := ""
rawReturn := strings.TrimSpace(c.Query("return_to"))
if hasClientInfo {
if rawReturn == "" {
rawReturn = cfg.DefaultReturnURI
}
if normalized, err := normalizeReturnTarget(rawReturn, cfg); err == nil {
redirectTarget = normalized
} else if rawReturn != "" {
utils.Log.WithError(err).Warn("invalid return_to during logout")
}
} else if rawReturn != "" {
if strings.HasPrefix(rawReturn, "/") && !strings.HasPrefix(rawReturn, "//") {
redirectTarget = rawReturn
}
}
utils.Log.WithFields(logrus.Fields{
"client": alias,
"state": state,
"redirect": redirectTarget,
}).Info("sso logout completed")
if redirectTarget != "" {
return c.Redirect(redirectTarget, fiber.StatusFound)
}
return c.Status(fiber.StatusOK).JSON(fiber.Map{"status": "signed out"})
}
func (h *Controller) revokeToken(ctx context.Context, token string, verification *sso.VerificationResult) {
if h.revoker == nil || verification == nil || verification.Claims == nil {
return
}
fingerprint := session.TokenFingerprint(token)
if fingerprint == "" {
return
}
if verification.Claims.ExpiresAt == nil {
utils.Log.Warn("access token missing expiry claim")
return
}
ttl := time.Until(verification.Claims.ExpiresAt.Time)
if ttl <= 0 {
return
}
if ttl < time.Second {
ttl = time.Second
}
if err := h.revoker.Revoke(ctx, fingerprint, ttl); err != nil {
utils.Log.WithError(err).Warn("failed to revoke access token")
}
}
func (h *Controller) revokeRefreshToken(ctx context.Context, token string) {
if h.revoker == nil {
return
}
fingerprint := session.TokenFingerprint(token)
if fingerprint == "" {
return
}
const refreshTTL = 30 * 24 * time.Hour
if err := h.revoker.Revoke(ctx, fingerprint, refreshTTL); err != nil {
utils.Log.WithError(err).Warn("failed to revoke refresh token")
}
}
func issueCookies(c *fiber.Ctx, tokenResp struct {
AccessToken string `json:"access_token"`
RefreshToken string `json:"refresh_token"`
TokenType string `json:"token_type"`
ExpiresIn int `json:"expires_in"`
Scope string `json:"scope"`
IDToken string `json:"id_token"`
Error string `json:"error"`
Description string `json:"error_description"`
}, verification *sso.VerificationResult) {
if revoker := session.GetRevocationStore(); revoker != nil && verification != nil {
if err := revoker.ClearUserLogout(c.Context(), verification.UserID); err != nil {
utils.Log.WithError(err).Warn("failed to clear logout marker")
}
}
accessName := resolveSSOCookieName(config.SSOAccessCookieName, "access")
refreshName := resolveSSOCookieName(config.SSORefreshCookieName, "refresh")
maxAge := tokenResp.ExpiresIn
if maxAge <= 0 {
maxAge = int(15 * time.Minute.Seconds())
}
sameSite := config.SSOCookieSameSite
if sameSite == "" {
sameSite = "Lax"
}
cookieDomain := config.SSOCookieDomain
cookieAccess := &fiber.Cookie{
Name: accessName,
Value: tokenResp.AccessToken,
Path: "/",
Domain: cookieDomain,
HTTPOnly: true,
Secure: config.SSOCookieSecure,
SameSite: sameSite,
MaxAge: maxAge,
}
c.Cookie(cookieAccess)
if tokenResp.RefreshToken != "" {
cookieRefresh := &fiber.Cookie{
Name: refreshName,
Value: tokenResp.RefreshToken,
Path: "/",
Domain: cookieDomain,
HTTPOnly: true,
Secure: config.SSOCookieSecure,
SameSite: sameSite,
MaxAge: int((time.Hour * 24 * 30).Seconds()),
}
c.Cookie(cookieRefresh)
}
// Optional: expose limited info via headers for FE debugging (avoid tokens)
c.Set("X-Auth-User", fmt.Sprintf("%d", verification.UserID))
}
func clearSSOCookie(c *fiber.Ctx, name string) {
if name == "" {
return
}
sameSite := config.SSOCookieSameSite
if sameSite == "" {
sameSite = "Lax"
}
c.Cookie(&fiber.Cookie{
Name: name,
Value: "",
Path: "/",
Domain: config.SSOCookieDomain,
HTTPOnly: true,
Secure: config.SSOCookieSecure,
SameSite: sameSite,
Expires: time.Unix(0, 0),
MaxAge: -1,
})
}
func resolveSSOCookieName(configuredName, fallback string) string {
name := strings.TrimSpace(configuredName)
if name != "" {
return name
}
return strings.TrimSpace(fallback)
}
func normalizeClientParam(raw string) string {
value := strings.TrimSpace(raw)
if value == "" {
return ""
}
if idx := strings.Index(value, "|"); idx >= 0 {
value = value[:idx]
}
value = strings.TrimSpace(value)
return strings.ToLower(value)
}
func sanitizeUserInfoPayload(body []byte) (map[string]any, []string, bool) {
if len(body) == 0 {
return map[string]any{}, nil, true
}
var payload any
if err := json.Unmarshal(body, &payload); err != nil {
return nil, nil, false
}
perms := collectPermissionNames(payload)
sensitive := map[string]struct{}{
"roles": {},
"permissions": {},
}
payload = scrubSensitiveKeys(payload, sensitive)
sanitized, ok := payload.(map[string]any)
if !ok {
sanitized = map[string]any{"data": payload}
}
return sanitized, perms, true
}
func scrubSensitiveKeys(value any, sensitive map[string]struct{}) any {
switch v := value.(type) {
case map[string]any:
for key, val := range v {
if _, ok := sensitive[strings.ToLower(key)]; ok {
delete(v, key)
continue
}
v[key] = scrubSensitiveKeys(val, sensitive)
}
return v
case []any:
for i, item := range v {
v[i] = scrubSensitiveKeys(item, sensitive)
}
return v
default:
return value
}
}
func collectPermissionNames(value any) []string {
names := make(map[string]struct{})
collectPermissionRec(value, names)
out := make([]string, 0, len(names))
for name := range names {
out = append(out, name)
}
return out
}
func collectPermissionRec(value any, acc map[string]struct{}) {
switch v := value.(type) {
case map[string]any:
for key, val := range v {
if strings.EqualFold(key, "permissions") {
if arr, ok := val.([]any); ok {
for _, item := range arr {
if perm, ok := item.(map[string]any); ok {
if name, ok := perm["name"].(string); ok && strings.TrimSpace(name) != "" {
acc[strings.ToLower(strings.TrimSpace(name))] = struct{}{}
}
}
}
}
} else {
collectPermissionRec(val, acc)
}
}
case []any:
for _, item := range v {
collectPermissionRec(item, acc)
}
}
}
func injectCapabilities(payload map[string]any, caps map[string]bool) {
if len(caps) == 0 {
return
}
if data, ok := payload["data"].(map[string]any); ok {
data["capabilities"] = caps
return
}
payload["capabilities"] = caps
}
func findSSOClientConfig(requestedAlias string) (string, config.SSOClientConfig, bool) {
if requestedAlias == "" {
return "", config.SSOClientConfig{}, false
}
if cfg, ok := config.SSOClients[requestedAlias]; ok && strings.TrimSpace(cfg.PublicID) != "" {
return requestedAlias, cfg, true
}
for alias, cfg := range config.SSOClients {
if strings.EqualFold(strings.TrimSpace(cfg.PublicID), requestedAlias) && strings.TrimSpace(cfg.PublicID) != "" {
return alias, cfg, true
}
}
return "", config.SSOClientConfig{}, false
}
func normalizeReturnTarget(returnTo string, cfg config.SSOClientConfig) (string, error) {
returnTo = strings.TrimSpace(returnTo)
if returnTo == "" {
return "", nil
}
if strings.HasPrefix(returnTo, "//") {
return "", fmt.Errorf("invalid return_to")
}
if strings.HasPrefix(returnTo, "/") {
return returnTo, nil
}
parsed, err := url.Parse(returnTo)
if err != nil {
return "", fmt.Errorf("invalid return_to")
}
if parsed.Scheme != "http" && parsed.Scheme != "https" {
return "", fmt.Errorf("invalid return_to scheme")
}
allowedOrigins := make(map[string]struct{})
if cfg.DefaultReturnURI != "" {
if u, err := url.Parse(cfg.DefaultReturnURI); err == nil && u.Host != "" {
allowedOrigins[u.Scheme+"://"+u.Host] = struct{}{}
}
}
for _, origin := range cfg.AllowedReturnOrigins {
origin = strings.TrimSpace(origin)
if origin == "" {
continue
}
if u, err := url.Parse(origin); err == nil && u.Host != "" && (u.Scheme == "http" || u.Scheme == "https") {
allowedOrigins[u.Scheme+"://"+u.Host] = struct{}{}
}
}
if len(allowedOrigins) > 0 {
origin := parsed.Scheme + "://" + parsed.Host
if _, ok := allowedOrigins[origin]; !ok {
return "", fmt.Errorf("return_to origin not allowed")
}
}
return parsed.String(), nil
}
@@ -0,0 +1,429 @@
package controllers
import (
"context"
"crypto/hmac"
"crypto/sha256"
"encoding/base64"
"encoding/hex"
"encoding/json"
"errors"
"fmt"
"github.com/go-playground/validator/v10"
"github.com/gofiber/fiber/v2"
"github.com/redis/go-redis/v9"
"github.com/sirupsen/logrus"
"gorm.io/gorm"
"strconv"
"strings"
"sync"
"time"
"gitlab.com/mbugroup/lti-api.git/internal/config"
entity "gitlab.com/mbugroup/lti-api.git/internal/entities"
"gitlab.com/mbugroup/lti-api.git/internal/modules/sso/session"
"gitlab.com/mbugroup/lti-api.git/internal/modules/users/dto"
userRepository "gitlab.com/mbugroup/lti-api.git/internal/modules/users/repositories"
"gitlab.com/mbugroup/lti-api.git/internal/response"
"gitlab.com/mbugroup/lti-api.git/internal/sso"
"gitlab.com/mbugroup/lti-api.git/internal/utils"
)
const (
headerClient = "X-Sync-Client"
headerTimestamp = "X-Sync-Timestamp"
headerNonce = "X-Sync-Nonce"
headerSignature = "X-Sync-Signature"
defaultDrift = 2 * time.Minute
defaultNonceTTL = 10 * time.Minute
)
// UserSyncController handles incoming user management events from the central SSO service.
type UserSyncController struct {
validate *validator.Validate
repo userRepository.UserRepository
redis *redis.Client
clients map[string]config.SSOClientConfig
drift time.Duration
nonceTTL time.Duration
maxBodyBytes int
log *logrus.Logger
localNonces sync.Map
}
type userSyncRequest struct {
Action string `json:"action" validate:"required,oneof=create update delete logout"`
PublicID string `json:"public_id" validate:"required"`
User userSyncUser `json:"user" validate:"required"`
}
type userSyncUser struct {
ID int64 `json:"id" validate:"required"`
Email string `json:"email"`
Name string `json:"name"`
}
func NewUserSyncController(validate *validator.Validate, repo userRepository.UserRepository, redis *redis.Client, clients map[string]config.SSOClientConfig) *UserSyncController {
normalized := make(map[string]config.SSOClientConfig, len(clients))
for alias, cfg := range clients {
alias = strings.ToLower(strings.TrimSpace(alias))
normalized[alias] = cfg
}
drift := config.SSOUserSyncDrift
if drift <= 0 {
drift = defaultDrift
}
nonceTTL := config.SSOUserSyncNonceTTL
if nonceTTL <= 0 {
nonceTTL = defaultNonceTTL
}
maxBody := config.SSOUserSyncMaxBodyBytes
if maxBody <= 0 {
maxBody = 32 * 1024
}
log := utils.Log
if redis == nil {
log.Warn("SSO user sync nonce store fallback to in-memory cache; enable Redis for replay protection")
}
return &UserSyncController{
validate: validate,
repo: repo,
redis: redis,
clients: normalized,
drift: drift,
nonceTTL: nonceTTL,
maxBodyBytes: maxBody,
log: log,
}
}
func (h *UserSyncController) Sync(c *fiber.Ctx) error {
if ct := strings.TrimSpace(c.Get(fiber.HeaderContentType)); ct != "" && !strings.HasPrefix(strings.ToLower(ct), fiber.MIMEApplicationJSON) {
return fiber.NewError(fiber.StatusUnsupportedMediaType, "content-type must be application/json")
}
body := c.Body()
if h.maxBodyBytes > 0 && len(body) > h.maxBodyBytes {
return fiber.NewError(fiber.StatusRequestEntityTooLarge, "request body too large")
}
alias, clientCfg, err := h.authenticate(c, body)
if err != nil {
return err
}
req := new(userSyncRequest)
if err := json.Unmarshal(body, req); err != nil {
return fiber.NewError(fiber.StatusBadRequest, "invalid request body")
}
req.Action = strings.ToLower(strings.TrimSpace(req.Action))
req.PublicID = strings.TrimSpace(req.PublicID)
req.User.Email = strings.TrimSpace(req.User.Email)
req.User.Name = strings.TrimSpace(req.User.Name)
if err := h.validate.Struct(req); err != nil {
return fiber.NewError(fiber.StatusBadRequest, err.Error())
}
if clientCfg.PublicID != "" && req.PublicID != clientCfg.PublicID {
return fiber.NewError(fiber.StatusBadRequest, "public_id mismatch with configured client")
}
if req.Action == "create" || req.Action == "update" {
if req.User.Email == "" || req.User.Name == "" {
return fiber.NewError(fiber.StatusBadRequest, "email and name are required for create/update actions")
}
if err := h.validate.Var(req.User.Email, "email"); err != nil {
return fiber.NewError(fiber.StatusBadRequest, "invalid email format")
}
}
if req.User.ID <= 0 {
return fiber.NewError(fiber.StatusBadRequest, "invalid user id")
}
switch req.Action {
case "create", "update":
return h.upsertUser(c, alias, req)
case "delete":
return h.removeUser(c, alias, req)
case "logout":
return h.logoutUser(c, alias, req)
default:
return fiber.NewError(fiber.StatusBadRequest, "unsupported action")
}
}
func (h *UserSyncController) authenticate(c *fiber.Ctx, body []byte) (string, config.SSOClientConfig, error) {
rawAlias := strings.TrimSpace(c.Get(headerClient))
if rawAlias == "" {
return "", config.SSOClientConfig{}, fiber.NewError(fiber.StatusUnauthorized, "missing sync client header")
}
aliasKey := strings.ToLower(rawAlias)
clientCfg, ok := h.clients[aliasKey]
if !ok {
return "", config.SSOClientConfig{}, fiber.NewError(fiber.StatusUnauthorized, "unknown sync client")
}
if err := h.verifyAuthorization(c, aliasKey); err != nil {
return "", config.SSOClientConfig{}, err
}
secret := strings.TrimSpace(clientCfg.SyncSecret)
if secret == "" {
return "", config.SSOClientConfig{}, fiber.NewError(fiber.StatusUnauthorized, "sync secret not configured")
}
timestamp := strings.TrimSpace(c.Get(headerTimestamp))
nonce := strings.TrimSpace(c.Get(headerNonce))
signature := strings.TrimSpace(c.Get(headerSignature))
if timestamp == "" || nonce == "" || signature == "" {
return "", config.SSOClientConfig{}, fiber.NewError(fiber.StatusUnauthorized, "missing signature headers")
}
if len(nonce) < 16 {
return "", config.SSOClientConfig{}, fiber.NewError(fiber.StatusUnauthorized, "nonce too short")
}
ts, err := strconv.ParseInt(timestamp, 10, 64)
if err != nil {
return "", config.SSOClientConfig{}, fiber.NewError(fiber.StatusBadRequest, "invalid timestamp")
}
msgTime := time.Unix(ts, 0).UTC()
now := time.Now().UTC()
drift := now.Sub(msgTime)
if drift > h.drift || drift < -h.drift {
return "", config.SSOClientConfig{}, fiber.NewError(fiber.StatusUnauthorized, "timestamp outside allowed window")
}
providedSig, err := decodeSignature(signature)
if err != nil {
return "", config.SSOClientConfig{}, fiber.NewError(fiber.StatusUnauthorized, "invalid signature encoding")
}
expectedSignature := h.calculateSignature(secret, rawAlias, timestamp, nonce, body)
if !hmac.Equal(providedSig, expectedSignature) {
bodyHash := sha256.Sum256(body)
h.log.WithFields(logrus.Fields{
"alias": rawAlias,
"alias_key": aliasKey,
"timestamp": timestamp,
"nonce": nonce,
"body_len": len(body),
"body_sha256": hex.EncodeToString(bodyHash[:]),
"body_base64": base64.StdEncoding.EncodeToString(body),
"provided_hex_full": hex.EncodeToString(providedSig),
"expected_hex_full": hex.EncodeToString(expectedSignature),
}).Warn("sso sync signature mismatch")
return "", config.SSOClientConfig{}, fiber.NewError(fiber.StatusUnauthorized, "invalid signature")
}
if err := h.registerNonce(c.Context(), aliasKey, nonce); err != nil {
return "", config.SSOClientConfig{}, err
}
return aliasKey, clientCfg, nil
}
func (h *UserSyncController) verifyAuthorization(c *fiber.Ctx, alias string) error {
authHeader := strings.TrimSpace(c.Get(fiber.HeaderAuthorization))
if authHeader == "" {
return fiber.NewError(fiber.StatusUnauthorized, "missing authorization header")
}
parts := strings.SplitN(authHeader, " ", 2)
if len(parts) != 2 || !strings.EqualFold(parts[0], "Bearer") {
return fiber.NewError(fiber.StatusUnauthorized, "invalid authorization header")
}
token := strings.TrimSpace(parts[1])
if token == "" {
return fiber.NewError(fiber.StatusUnauthorized, "invalid authorization header")
}
verification, err := sso.VerifyAccessToken(token)
if err != nil {
return fiber.NewError(fiber.StatusUnauthorized, "invalid access token")
}
if verification.ServiceAlias == "" || verification.ServiceAlias != alias {
return fiber.NewError(fiber.StatusUnauthorized, "service subject mismatch")
}
if !containsScope(verification.Claims.Scopes(), "sync.users") {
return fiber.NewError(fiber.StatusForbidden, "missing sync scope")
}
return nil
}
func (h *UserSyncController) upsertUser(c *fiber.Ctx, alias string, req *userSyncRequest) error {
entity := &entity.User{
IdUser: req.User.ID,
Email: req.User.Email,
Name: req.User.Name,
}
//TODO: MIGRATION TO UPSERT BASE REPOSITORY
if err := h.repo.UpsertByIdUser(c.Context(), entity); err != nil {
h.log.Errorf("sso user upsert failed: %v", err)
return fiber.NewError(fiber.StatusInternalServerError, "failed to upsert user")
}
user, err := h.repo.GetByIdUser(c.Context(), req.User.ID, nil)
if err != nil {
h.log.Errorf("sso user fetch after upsert failed: %v", err)
return fiber.NewError(fiber.StatusInternalServerError, "failed to load user")
}
h.log.WithFields(logrus.Fields{
"action": req.Action,
"public_id": req.PublicID,
"alias": alias,
"user_id": req.User.ID,
}).Info("sso user synced")
msg := fmt.Sprintf("User %s successfully", req.Action)
return c.Status(fiber.StatusOK).JSON(response.Success{
Code: fiber.StatusOK,
Status: "success",
Message: msg,
Data: dto.ToUserListDTO(*user),
})
}
func (h *UserSyncController) logoutUser(c *fiber.Ctx, alias string, req *userSyncRequest) error {
revoker := session.GetRevocationStore()
if revoker != nil {
if err := revoker.MarkUserLogout(c.Context(), uint(req.User.ID), time.Now().UTC()); err != nil {
h.log.WithError(err).Error("sso user logout revoke failed")
return fiber.NewError(fiber.StatusInternalServerError, "failed to revoke user session")
}
} else {
h.log.Warn("sso user logout received but revocation store not configured")
}
h.log.WithFields(logrus.Fields{
"action": req.Action,
"public_id": req.PublicID,
"alias": alias,
"user_id": req.User.ID,
}).Info("sso user logout enforced")
return c.Status(fiber.StatusOK).JSON(response.Common{
Code: fiber.StatusOK,
Status: "success",
Message: "User sessions revoked successfully",
})
}
func (h *UserSyncController) removeUser(c *fiber.Ctx, alias string, req *userSyncRequest) error {
if err := h.repo.SoftDeleteByIdUser(c.Context(), req.User.ID); err != nil {
if err == gorm.ErrRecordNotFound {
return fiber.NewError(fiber.StatusNotFound, "user not found")
}
h.log.Errorf("sso user delete failed: %v", err)
return fiber.NewError(fiber.StatusInternalServerError, "failed to delete user")
}
h.log.WithFields(logrus.Fields{
"action": req.Action,
"public_id": req.PublicID,
"alias": alias,
"user_id": req.User.ID,
}).Info("sso user deleted")
return c.Status(fiber.StatusOK).JSON(response.Common{
Code: fiber.StatusOK,
Status: "success",
Message: "User deleted successfully",
})
}
func (h *UserSyncController) registerNonce(ctx context.Context, alias, nonce string) error {
ttl := h.nonceTTL
if ttl <= 0 {
ttl = defaultNonceTTL
}
key := fmt.Sprintf("sso:sync:%s:%s", alias, nonce)
if h.redis != nil {
stored, err := h.redis.SetNX(ctx, key, "1", ttl).Result()
if err == nil {
if !stored {
return fiber.NewError(fiber.StatusUnauthorized, "nonce already used")
}
return nil
}
h.log.Errorf("store sync nonce failed: %v", err)
}
now := time.Now().UTC()
if expRaw, ok := h.localNonces.Load(key); ok {
if expTime, ok := expRaw.(time.Time); ok && expTime.After(now) {
return fiber.NewError(fiber.StatusUnauthorized, "nonce already used")
}
}
h.localNonces.Store(key, now.Add(ttl))
h.pruneLocalNonces(now)
return nil
}
func (h *UserSyncController) calculateSignature(secret, alias, timestamp, nonce string, body []byte) []byte {
mac := hmac.New(sha256.New, []byte(secret))
mac.Write([]byte(alias))
mac.Write([]byte("\n"))
mac.Write([]byte(timestamp))
mac.Write([]byte("\n"))
mac.Write([]byte(nonce))
mac.Write([]byte("\n"))
mac.Write(body)
return mac.Sum(nil)
}
func containsScope(scopes []string, target string) bool {
target = strings.ToLower(strings.TrimSpace(target))
if target == "" {
return false
}
for _, scope := range scopes {
if strings.ToLower(strings.TrimSpace(scope)) == target {
return true
}
}
return false
}
func decodeSignature(sig string) ([]byte, error) {
sig = strings.TrimSpace(sig)
if sig == "" {
return nil, errors.New("empty signature")
}
if decoded, err := hex.DecodeString(sig); err == nil {
return decoded, nil
}
if decoded, err := base64.StdEncoding.DecodeString(sig); err == nil {
return decoded, nil
}
if decoded, err := base64.URLEncoding.DecodeString(sig); err == nil {
return decoded, nil
}
return nil, errors.New("unrecognized signature encoding")
}
func (h *UserSyncController) pruneLocalNonces(now time.Time) {
h.localNonces.Range(func(key, value any) bool {
exp, ok := value.(time.Time)
if !ok || exp.Before(now) {
h.localNonces.Delete(key)
}
return true
})
}
+13
View File
@@ -0,0 +1,13 @@
package sso
import (
"github.com/go-playground/validator/v10"
"github.com/gofiber/fiber/v2"
"gorm.io/gorm"
)
type Module struct{}
func (Module) RegisterRoutes(router fiber.Router, db *gorm.DB, validate *validator.Validate) {
Routes(router, db, validate)
}
+36
View File
@@ -0,0 +1,36 @@
package sso
import (
"net/http"
"time"
"github.com/go-playground/validator/v10"
"github.com/gofiber/fiber/v2"
"gorm.io/gorm"
"gitlab.com/mbugroup/lti-api.git/internal/cache"
"gitlab.com/mbugroup/lti-api.git/internal/config"
"gitlab.com/mbugroup/lti-api.git/internal/middleware"
ssoController "gitlab.com/mbugroup/lti-api.git/internal/modules/sso/controllers"
"gitlab.com/mbugroup/lti-api.git/internal/modules/sso/session"
userRepository "gitlab.com/mbugroup/lti-api.git/internal/modules/users/repositories"
)
func Routes(router fiber.Router, db *gorm.DB, validate *validator.Validate) {
ttl := config.SSOPKCETTL
if ttl <= 0 {
ttl = 5 * time.Minute
}
store := session.NewStore(cache.MustRedis(), ttl)
ctrl := ssoController.NewController(&http.Client{Timeout: 10 * time.Second}, store, session.GetRevocationStore())
userRepo := userRepository.NewUserRepository(db)
syncCtrl := ssoController.NewUserSyncController(validate, userRepo, cache.Redis(), config.SSOClients)
group := router.Group("/sso")
group.Get("/start", middleware.NewLimiter(30, time.Minute), ctrl.Start)
group.Get("/callback", ctrl.Callback)
group.Get("/userinfo", middleware.NewLimiter(60, time.Minute), ctrl.UserInfo)
group.Post("/logout", middleware.NewLimiter(60, time.Minute), ctrl.Logout)
group.Post("/users/sync", middleware.NewLimiter(30, time.Minute), syncCtrl.Sync)
}
+163
View File
@@ -0,0 +1,163 @@
package session
import (
"context"
"crypto/sha256"
"encoding/hex"
"errors"
"strconv"
"strings"
"sync"
"time"
"github.com/redis/go-redis/v9"
)
// RevocationStore handles token blacklist / revocation entries in Redis.
type RevocationStore struct {
redis *redis.Client
prefix string
}
var (
globalRevokerMu sync.RWMutex
globalRevoker *RevocationStore
)
// NewRevocationStore creates a revocation store with the given redis client and key prefix.
func NewRevocationStore(client *redis.Client, prefix string) *RevocationStore {
return &RevocationStore{
redis: client,
prefix: strings.TrimSpace(prefix),
}
}
// SetRevocationStore registers the provided revocation store for global access.
func SetRevocationStore(store *RevocationStore) {
globalRevokerMu.Lock()
globalRevoker = store
globalRevokerMu.Unlock()
}
// GetRevocationStore returns the globally registered revocation store, or nil if unset.
func GetRevocationStore() *RevocationStore {
globalRevokerMu.RLock()
defer globalRevokerMu.RUnlock()
return globalRevoker
}
// MustRevocationStore returns the registered revocation store or panics if none is configured.
func MustRevocationStore() *RevocationStore {
store := GetRevocationStore()
if store == nil {
panic("revocation store not initialised")
}
return store
}
// Revoke stores the fingerprint with the provided TTL.
func (s *RevocationStore) Revoke(ctx context.Context, fingerprint string, ttl time.Duration) error {
if s == nil || s.redis == nil {
return errors.New("revocation store redis client not initialised")
}
fingerprint = strings.TrimSpace(fingerprint)
if fingerprint == "" {
return nil
}
if ttl <= 0 {
ttl = time.Minute
}
key := s.keyFor(fingerprint)
return s.redis.Set(ctx, key, "1", ttl).Err()
}
// IsRevoked returns true when the fingerprint appears in the blacklist.
func (s *RevocationStore) IsRevoked(ctx context.Context, fingerprint string) (bool, error) {
if s == nil || s.redis == nil {
return false, errors.New("revocation store redis client not initialised")
}
fingerprint = strings.TrimSpace(fingerprint)
if fingerprint == "" {
return false, nil
}
key := s.keyFor(fingerprint)
exists, err := s.redis.Exists(ctx, key).Result()
if err != nil {
return false, err
}
return exists > 0, nil
}
// MarkUserLogout stores the timestamp of the last forced logout for the given user.
func (s *RevocationStore) MarkUserLogout(ctx context.Context, userID uint, at time.Time) error {
if s == nil || s.redis == nil {
return errors.New("revocation store redis client not initialised")
}
if userID == 0 {
return errors.New("invalid user id")
}
key := s.userLogoutKey(userID)
return s.redis.Set(ctx, key, at.UTC().Format(time.RFC3339Nano), 0).Err()
}
// ClearUserLogout removes any stored forced logout marker for the given user.
func (s *RevocationStore) ClearUserLogout(ctx context.Context, userID uint) error {
if s == nil || s.redis == nil {
return errors.New("revocation store redis client not initialised")
}
if userID == 0 {
return errors.New("invalid user id")
}
key := s.userLogoutKey(userID)
return s.redis.Del(ctx, key).Err()
}
// UserLogoutTime returns the timestamp of the last forced logout for the given user.
func (s *RevocationStore) UserLogoutTime(ctx context.Context, userID uint) (time.Time, error) {
var zero time.Time
if s == nil || s.redis == nil {
return zero, errors.New("revocation store redis client not initialised")
}
if userID == 0 {
return zero, errors.New("invalid user id")
}
key := s.userLogoutKey(userID)
value, err := s.redis.Get(ctx, key).Result()
if err != nil {
if errors.Is(err, redis.Nil) {
return zero, nil
}
return zero, err
}
ts, err := time.Parse(time.RFC3339Nano, value)
if err != nil {
return zero, err
}
return ts, nil
}
func (s *RevocationStore) keyFor(fingerprint string) string {
prefix := s.prefix
if prefix == "" {
prefix = "sso:blacklist"
}
return prefix + ":" + fingerprint
}
func (s *RevocationStore) userLogoutKey(userID uint) string {
prefix := s.prefix
if prefix == "" {
prefix = "sso:blacklist"
}
return prefix + ":user-logout:" + strconv.FormatUint(uint64(userID), 10)
}
// TokenFingerprint hashes token material before persisting it to the blacklist.
func TokenFingerprint(token string) string {
token = strings.TrimSpace(token)
if token == "" {
return ""
}
sum := sha256.Sum256([]byte(token))
return hex.EncodeToString(sum[:])
}
+70
View File
@@ -0,0 +1,70 @@
package session
import (
"context"
"encoding/json"
"fmt"
"time"
"github.com/redis/go-redis/v9"
)
const keyTemplate = "sso:pkce:%s"
// PKCESession holds data required to complete the OAuth2 PKCE exchange.
type PKCESession struct {
CodeVerifier string `json:"code_verifier"`
Nonce string `json:"nonce"`
ClientAlias string `json:"client_alias"`
ClientID string `json:"client_id"`
RedirectURI string `json:"redirect_uri"`
Scope string `json:"scope"`
ReturnTo string `json:"return_to,omitempty"`
CreatedAt time.Time `json:"created_at"`
}
// Store persists pkce sessions inside Redis using a configurable TTL.
type Store struct {
redis *redis.Client
ttl time.Duration
}
func NewStore(client *redis.Client, ttl time.Duration) *Store {
return &Store{redis: client, ttl: ttl}
}
func (s *Store) Save(ctx context.Context, state string, payload *PKCESession) error {
if s.redis == nil {
return fmt.Errorf("redis client is not initialised")
}
bytes, err := json.Marshal(payload)
if err != nil {
return err
}
return s.redis.Set(ctx, fmt.Sprintf(keyTemplate, state), bytes, s.ttl).Err()
}
func (s *Store) Get(ctx context.Context, state string) (*PKCESession, error) {
if s.redis == nil {
return nil, fmt.Errorf("redis client is not initialised")
}
raw, err := s.redis.Get(ctx, fmt.Sprintf(keyTemplate, state)).Result()
if err != nil {
if err == redis.Nil {
return nil, nil
}
return nil, err
}
var payload PKCESession
if err := json.Unmarshal([]byte(raw), &payload); err != nil {
return nil, err
}
return &payload, nil
}
func (s *Store) Delete(ctx context.Context, state string) error {
if s.redis == nil {
return fmt.Errorf("redis client is not initialised")
}
return s.redis.Del(ctx, fmt.Sprintf(keyTemplate, state)).Err()
}
@@ -71,70 +71,70 @@ func (u *UserController) GetOne(c *fiber.Ctx) error {
})
}
func (u *UserController) CreateOne(c *fiber.Ctx) error {
req := new(validation.Create)
// func (u *UserController) CreateOne(c *fiber.Ctx) error {
// req := new(validation.Create)
if err := c.BodyParser(req); err != nil {
return fiber.NewError(fiber.StatusBadRequest, "Invalid request body")
}
// if err := c.BodyParser(req); err != nil {
// return fiber.NewError(fiber.StatusBadRequest, "Invalid request body")
// }
result, err := u.UserService.CreateOne(c, req)
if err != nil {
return err
}
// result, err := u.UserService.CreateOne(c, req)
// if err != nil {
// return err
// }
return c.Status(fiber.StatusCreated).
JSON(response.Success{
Code: fiber.StatusCreated,
Status: "success",
Message: "Create user successfully",
Data: dto.ToUserListDTO(*result),
})
}
// return c.Status(fiber.StatusCreated).
// JSON(response.Success{
// Code: fiber.StatusCreated,
// Status: "success",
// Message: "Create user successfully",
// Data: dto.ToUserListDTO(*result),
// })
// }
func (u *UserController) UpdateOne(c *fiber.Ctx) error {
req := new(validation.Update)
param := c.Params("id")
// func (u *UserController) UpdateOne(c *fiber.Ctx) error {
// req := new(validation.Update)
// param := c.Params("id")
id, err := strconv.Atoi(param)
if err != nil {
return fiber.NewError(fiber.StatusBadRequest, "Invalid Id")
}
// id, err := strconv.Atoi(param)
// if err != nil {
// return fiber.NewError(fiber.StatusBadRequest, "Invalid Id")
// }
if err := c.BodyParser(req); err != nil {
return fiber.NewError(fiber.StatusBadRequest, "Invalid request body")
}
// if err := c.BodyParser(req); err != nil {
// return fiber.NewError(fiber.StatusBadRequest, "Invalid request body")
// }
result, err := u.UserService.UpdateOne(c, req, uint(id))
if err != nil {
return err
}
// result, err := u.UserService.UpdateOne(c, req, uint(id))
// if err != nil {
// return err
// }
return c.Status(fiber.StatusOK).
JSON(response.Success{
Code: fiber.StatusOK,
Status: "success",
Message: "Update user successfully",
Data: dto.ToUserListDTO(*result),
})
}
// return c.Status(fiber.StatusOK).
// JSON(response.Success{
// Code: fiber.StatusOK,
// Status: "success",
// Message: "Update user successfully",
// Data: dto.ToUserListDTO(*result),
// })
// }
func (u *UserController) DeleteOne(c *fiber.Ctx) error {
param := c.Params("id")
// func (u *UserController) DeleteOne(c *fiber.Ctx) error {
// param := c.Params("id")
id, err := strconv.Atoi(param)
if err != nil {
return fiber.NewError(fiber.StatusBadRequest, "Invalid Id")
}
// id, err := strconv.Atoi(param)
// if err != nil {
// return fiber.NewError(fiber.StatusBadRequest, "Invalid Id")
// }
if err := u.UserService.DeleteOne(c, uint(id)); err != nil {
return err
}
// if err := u.UserService.DeleteOne(c, uint(id)); err != nil {
// return err
// }
return c.Status(fiber.StatusOK).
JSON(response.Common{
Code: fiber.StatusOK,
Status: "success",
Message: "Delete user successfully",
})
}
// return c.Status(fiber.StatusOK).
// JSON(response.Common{
// Code: fiber.StatusOK,
// Status: "success",
// Message: "Delete user successfully",
// })
// }
@@ -2,29 +2,118 @@ package repository
import (
"context"
"errors"
"time"
"gitlab.com/mbugroup/lti-api.git/internal/common/repository"
"github.com/jackc/pgconn"
commonrepo "gitlab.com/mbugroup/lti-api.git/internal/common/repository"
entity "gitlab.com/mbugroup/lti-api.git/internal/entities"
"gorm.io/gorm"
"gorm.io/gorm/clause"
)
type UserRepository interface {
repository.BaseRepository[entity.User]
commonrepo.BaseRepository[entity.User]
IdExists(ctx context.Context, id uint) (bool, error)
GetByIdUser(ctx context.Context, idUser int64, modifier func(*gorm.DB) *gorm.DB) (*entity.User, error)
UpsertByIdUser(ctx context.Context, user *entity.User) error
SoftDeleteByIdUser(ctx context.Context, idUser int64) error
}
type UserRepositoryImpl struct {
*repository.BaseRepositoryImpl[entity.User]
*commonrepo.BaseRepositoryImpl[entity.User]
db *gorm.DB
}
func NewUserRepository(db *gorm.DB) UserRepository {
return &UserRepositoryImpl{
BaseRepositoryImpl: repository.NewBaseRepository[entity.User](db),
BaseRepositoryImpl: commonrepo.NewBaseRepository[entity.User](db),
db: db,
}
}
func (r *UserRepositoryImpl) IdExists(ctx context.Context, id uint) (bool, error) {
return repository.Exists[entity.User](ctx, r.db, id)
return commonrepo.Exists[entity.User](ctx, r.db, id)
}
func (r *UserRepositoryImpl) GetByIdUser(
ctx context.Context,
idUser int64,
modifier func(*gorm.DB) *gorm.DB,
) (*entity.User, error) {
return r.BaseRepositoryImpl.First(ctx, func(db *gorm.DB) *gorm.DB {
return db.Where("id_user = ?", idUser)
})
}
func (r *UserRepositoryImpl) UpsertByIdUser(ctx context.Context, user *entity.User) error {
if user == nil {
return gorm.ErrInvalidData
}
return r.DB().WithContext(ctx).Transaction(func(tx *gorm.DB) error {
now := time.Now()
user.DeletedAt = gorm.DeletedAt{}
user.UpdatedAt = now
err := tx.Clauses(clause.OnConflict{
Columns: []clause.Column{{Name: "id_user"}},
UpdateAll: true,
}).Omit("id", "created_at").Create(user).Error
if err == nil {
return nil
}
if !isUniqueViolation(err, "users_email_unique") {
return err
}
var existing entity.User
lockQuery := tx.Clauses(clause.Locking{Strength: "UPDATE"}).Where("email = ?", user.Email)
if err := lockQuery.First(&existing).Error; err != nil {
return err
}
user.Id = existing.Id
updates := map[string]any{
"id_user": user.IdUser,
"email": user.Email,
"name": user.Name,
"updated_at": now,
"deleted_at": gorm.DeletedAt{},
}
if err := tx.Model(&entity.User{}).Where("id = ?", existing.Id).Updates(updates).Error; err != nil {
return err
}
return nil
})
}
func (r *UserRepositoryImpl) SoftDeleteByIdUser(ctx context.Context, idUser int64) error {
query := r.DB().WithContext(ctx).Where("id_user = ?", idUser)
result := query.Delete(&entity.User{})
if result.Error != nil {
return result.Error
}
if result.RowsAffected == 0 {
return gorm.ErrRecordNotFound
}
return nil
}
func isUniqueViolation(err error, constraint string) bool {
var pgErr *pgconn.PgError
if !errors.As(err, &pgErr) {
return false
}
if pgErr.Code != "23505" {
return false
}
if constraint == "" {
return true
}
return pgErr.ConstraintName == constraint
}
+7 -5
View File
@@ -1,20 +1,22 @@
package users
import (
"github.com/gofiber/fiber/v2"
"gitlab.com/mbugroup/lti-api.git/internal/middleware"
controller "gitlab.com/mbugroup/lti-api.git/internal/modules/users/controllers"
user "gitlab.com/mbugroup/lti-api.git/internal/modules/users/services"
"github.com/gofiber/fiber/v2"
)
func UserRoutes(v1 fiber.Router, s user.UserService) {
ctrl := controller.NewUserController(s)
route := v1.Group("/users")
route.Use(middleware.Auth(s))
route.Get("/", ctrl.GetAll)
route.Post("/", ctrl.CreateOne)
// route.Post("/", ctrl.CreateOne)
route.Get("/:id", ctrl.GetOne)
route.Patch("/:id", ctrl.UpdateOne)
route.Delete("/:id", ctrl.DeleteOne)
// route.Patch("/:id", ctrl.UpdateOne)
// route.Delete("/:id", ctrl.DeleteOne)
}
@@ -20,6 +20,7 @@ type UserService interface {
CreateOne(ctx *fiber.Ctx, req *validation.Create) (*entity.User, error)
UpdateOne(ctx *fiber.Ctx, req *validation.Update, id uint) (*entity.User, error)
DeleteOne(ctx *fiber.Ctx, id uint) error
GetBySSOUserID(ctx *fiber.Ctx, ssoUserID uint) (*entity.User, error)
}
type userService struct {
@@ -68,6 +69,18 @@ func (s userService) GetOne(c *fiber.Ctx, id uint) (*entity.User, error) {
return user, nil
}
func (s userService) GetBySSOUserID(c *fiber.Ctx, ssoUserID uint) (*entity.User, error) {
user, err := s.Repository.GetByIdUser(c.Context(), int64(ssoUserID), nil)
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, fiber.NewError(fiber.StatusNotFound, "User not found")
}
if err != nil {
s.Log.Errorf("Failed get user by SSO id: %+v", err)
return nil, err
}
return user, nil
}
func (s *userService) CreateOne(c *fiber.Ctx, req *validation.Create) (*entity.User, error) {
if err := s.Validate.Struct(req); err != nil {
return nil, err
+4 -2
View File
@@ -11,11 +11,12 @@ import (
approvals "gitlab.com/mbugroup/lti-api.git/internal/modules/approvals"
constants "gitlab.com/mbugroup/lti-api.git/internal/modules/constants"
inventory "gitlab.com/mbugroup/lti-api.git/internal/modules/inventory"
marketing "gitlab.com/mbugroup/lti-api.git/internal/modules/marketing"
master "gitlab.com/mbugroup/lti-api.git/internal/modules/master"
production "gitlab.com/mbugroup/lti-api.git/internal/modules/production"
purchases "gitlab.com/mbugroup/lti-api.git/internal/modules/purchases"
ssoModule "gitlab.com/mbugroup/lti-api.git/internal/modules/sso"
users "gitlab.com/mbugroup/lti-api.git/internal/modules/users"
marketing "gitlab.com/mbugroup/lti-api.git/internal/modules/marketing"
// MODULE IMPORTS
)
@@ -33,7 +34,8 @@ func Routes(app *fiber.App, db *gorm.DB) {
production.ProductionModule{},
approvals.ApprovalModule{},
purchases.PurchaseModule{},
marketing.MarketingModule{},
marketing.MarketingModule{},
ssoModule.Module{},
// MODULE REGISTRY
}
+307
View File
@@ -0,0 +1,307 @@
package sso
import (
"context"
"encoding/json"
"errors"
"fmt"
"net/http"
"strconv"
"strings"
"sync"
"time"
"github.com/redis/go-redis/v9"
"gitlab.com/mbugroup/lti-api.git/internal/cache"
"gitlab.com/mbugroup/lti-api.git/internal/config"
"gitlab.com/mbugroup/lti-api.git/internal/utils"
)
const (
profileCachePrefix = "sso:profile:user:"
profileCacheTTL = time.Minute
)
var (
profileClient = &http.Client{Timeout: 5 * time.Second}
profileLocalCache sync.Map // map[string]cachedProfile
)
type cachedProfile struct {
Profile *UserProfile
ExpiresAt time.Time
}
// UserProfile represents the enriched user information returned by the central SSO.
type UserProfile struct {
UserID uint
Roles []Role
Permissions []Permission
}
// Role describes a role assignment from the SSO profile response.
type Role struct {
ID uint
Key string
Name string
ClientID uint
ClientAlias string
ClientName string
Permissions []Permission
RawReference json.RawMessage `json:"-"`
}
// Permission describes a granular permission entry from the SSO profile.
type Permission struct {
ID uint
Name string
Action string
ClientID uint
ClientAlias string
ClientName string
}
// PermissionNames returns a de-duplicated slice of permission identifiers in canonical form.
func (p *UserProfile) PermissionNames() []string {
if p == nil || len(p.Permissions) == 0 {
return nil
}
set := make(map[string]struct{}, len(p.Permissions))
for _, perm := range p.Permissions {
name := canonicalPermissionName(perm.Name)
if name != "" {
set[name] = struct{}{}
}
}
out := make([]string, 0, len(set))
for name := range set {
out = append(out, name)
}
return out
}
// FetchProfile retrieves the SSO profile for the authenticated user, using Redis/in-memory
// caching to reduce load on the SSO service. Only end-user tokens (subject user:ID) are supported.
func FetchProfile(ctx context.Context, token string, verification *VerificationResult) (*UserProfile, error) {
if verification == nil || verification.UserID == 0 {
return nil, errors.New("profile only available for user tokens")
}
key := profileCacheKey(verification.UserID)
if profile := loadProfileFromLocalCache(key); profile != nil {
return profile, nil
}
if profile := loadProfileFromRedis(ctx, key); profile != nil {
storeProfileInLocalCache(key, profile)
return profile, nil
}
profile, err := fetchProfileFromSSO(ctx, token)
if err != nil {
return nil, err
}
storeProfileInLocalCache(key, profile)
storeProfileInRedis(ctx, key, profile)
return profile, nil
}
func fetchProfileFromSSO(ctx context.Context, token string) (*UserProfile, error) {
endpoint := strings.TrimSpace(config.SSOGetMeURL)
if endpoint == "" {
return nil, errors.New("sso get-me endpoint not configured")
}
if ctx == nil {
ctx = context.Background()
}
req, err := http.NewRequestWithContext(ctx, http.MethodGet, endpoint, nil)
if err != nil {
return nil, fmt.Errorf("build profile request: %w", err)
}
req.Header.Set("Accept", "application/json")
req.Header.Set("Authorization", "Bearer "+token)
if cookieName := strings.TrimSpace(config.SSOAccessCookieName); cookieName != "" {
req.Header.Set("Cookie", fmt.Sprintf("%s=%s", cookieName, token))
}
resp, err := profileClient.Do(req)
if err != nil {
return nil, fmt.Errorf("fetch profile: %w", err)
}
defer resp.Body.Close()
if resp.StatusCode >= 400 {
return nil, fmt.Errorf("fetch profile: status %d", resp.StatusCode)
}
var envelope userInfoEnvelope
if err := json.NewDecoder(resp.Body).Decode(&envelope); err != nil {
return nil, fmt.Errorf("decode profile: %w", err)
}
roles := envelope.getRoles()
profile := &UserProfile{}
// Attempt to infer user id if provided.
if envelope.User != nil && envelope.User.ID > 0 {
profile.UserID = uint(envelope.User.ID)
}
perms := make([]Permission, 0)
convertedRoles := make([]Role, 0, len(roles))
for _, r := range roles {
role := Role{
ID: uint(r.ID),
Key: strings.TrimSpace(r.Key),
Name: strings.TrimSpace(r.Name),
ClientAlias: strings.TrimSpace(r.Client.Alias),
ClientName: strings.TrimSpace(r.Client.Name),
ClientID: uint(r.Client.ID),
}
rolePerms := make([]Permission, 0, len(r.Permissions))
for _, p := range r.Permissions {
perm := Permission{
ID: uint(p.ID),
Name: strings.TrimSpace(p.Name),
Action: strings.TrimSpace(p.Action),
ClientAlias: strings.TrimSpace(p.Client.Alias),
ClientName: strings.TrimSpace(p.Client.Name),
ClientID: uint(p.Client.ID),
}
if perm.Name != "" {
rolePerms = append(rolePerms, perm)
perms = append(perms, perm)
}
}
role.Permissions = rolePerms
convertedRoles = append(convertedRoles, role)
}
profile.Roles = convertedRoles
profile.Permissions = perms
return profile, nil
}
func loadProfileFromLocalCache(key string) *UserProfile {
if value, ok := profileLocalCache.Load(key); ok {
if cached, ok := value.(cachedProfile); ok {
if time.Now().Before(cached.ExpiresAt) && cached.Profile != nil {
return cached.Profile
}
profileLocalCache.Delete(key)
}
}
return nil
}
func loadProfileFromRedis(ctx context.Context, key string) *UserProfile {
client := cache.Redis()
if client == nil {
return nil
}
data, err := client.Get(ctx, key).Bytes()
if err != nil {
if !errors.Is(err, redis.Nil) {
utils.Log.WithError(err).Warn("sso profile redis lookup failed")
}
return nil
}
var profile UserProfile
if err := json.Unmarshal(data, &profile); err != nil {
utils.Log.WithError(err).Warn("sso profile redis decode failed")
return nil
}
return &profile
}
func storeProfileInLocalCache(key string, profile *UserProfile) {
if profile == nil {
return
}
profileLocalCache.Store(key, cachedProfile{
Profile: profile,
ExpiresAt: time.Now().Add(profileCacheTTL),
})
}
func storeProfileInRedis(ctx context.Context, key string, profile *UserProfile) {
client := cache.Redis()
if client == nil || profile == nil {
return
}
data, err := json.Marshal(profile)
if err != nil {
utils.Log.WithError(err).Warn("sso profile redis encode failed")
return
}
if err := client.Set(ctx, key, data, profileCacheTTL).Err(); err != nil {
utils.Log.WithError(err).Warn("sso profile redis store failed")
}
}
func profileCacheKey(userID uint) string {
return profileCachePrefix + strconv.FormatUint(uint64(userID), 10)
}
func canonicalPermissionName(name string) string {
return strings.ToLower(strings.TrimSpace(name))
}
// userInfoEnvelope handles the varying shapes returned by the SSO userinfo endpoint.
type userInfoEnvelope struct {
Roles []userInfoRole `json:"roles"`
Data *struct {
ID int64 `json:"id"`
Roles []userInfoRole `json:"roles"`
} `json:"data"`
User *struct {
ID int64 `json:"id"`
} `json:"user"`
}
func (e *userInfoEnvelope) getRoles() []userInfoRole {
if len(e.Roles) > 0 {
return e.Roles
}
if e.Data != nil && len(e.Data.Roles) > 0 {
if e.User == nil && e.Data.ID > 0 {
e.User = &struct {
ID int64 `json:"id"`
}{ID: e.Data.ID}
}
return e.Data.Roles
}
return nil
}
type userInfoRole struct {
ID int64 `json:"id"`
Key string `json:"key"`
Name string `json:"name"`
Client userInfoClient `json:"client"`
Permissions []userInfoPermRaw `json:"permissions"`
}
type userInfoClient struct {
ID int64 `json:"id"`
Name string `json:"name"`
Alias string `json:"alias"`
}
type userInfoPermRaw struct {
ID int64 `json:"id"`
Name string `json:"name"`
Action string `json:"action"`
Client userInfoClient `json:"client"`
Details any `json:"details"`
}
+160
View File
@@ -0,0 +1,160 @@
package sso
import (
"context"
"errors"
"fmt"
"net/http"
"strconv"
"strings"
"sync"
"time"
"github.com/MicahParks/keyfunc/v2"
"github.com/golang-jwt/jwt/v5"
"gitlab.com/mbugroup/lti-api.git/internal/utils"
)
type verifier struct {
jwks *keyfunc.JWKS
issuer string
audiences map[string]struct{}
}
type AccessTokenClaims struct {
Scope string `json:"scope"`
jwt.RegisteredClaims
}
func (c AccessTokenClaims) Scopes() []string {
if c.Scope == "" {
return nil
}
return strings.Fields(c.Scope)
}
type VerificationResult struct {
UserID uint
ServiceAlias string
Subject string
Claims *AccessTokenClaims
}
var (
globalMu sync.RWMutex
globalV *verifier
)
func Init(ctx context.Context, jwksURL, issuer string, audiences []string) error {
jwksURL = strings.TrimSpace(jwksURL)
issuer = strings.TrimSpace(issuer)
if jwksURL == "" || issuer == "" {
return errors.New("missing SSO JWKS or issuer configuration")
}
client := &http.Client{Timeout: 5 * time.Second}
options := keyfunc.Options{
Ctx: ctx,
Client: client,
RefreshTimeout: 10 * time.Second,
RefreshInterval: time.Hour,
RefreshUnknownKID: true,
RefreshErrorHandler: func(err error) {
utils.Log.Errorf("sso jwks refresh failed: %v", err)
},
}
jwks, err := keyfunc.Get(jwksURL, options)
if err != nil {
return fmt.Errorf("load jwks: %w", err)
}
audienceMap := make(map[string]struct{}, len(audiences))
for _, aud := range audiences {
aud = strings.TrimSpace(aud)
if aud == "" {
continue
}
audienceMap[aud] = struct{}{}
}
globalMu.Lock()
globalV = &verifier{jwks: jwks, issuer: issuer, audiences: audienceMap}
globalMu.Unlock()
utils.Log.Infof("sso verifier initialized for issuer %s (%d keys)", issuer, len(jwks.KIDs()))
return nil
}
func VerifyAccessToken(token string) (*VerificationResult, error) {
token = strings.TrimSpace(token)
if token == "" {
return nil, errors.New("empty token")
}
globalMu.RLock()
v := globalV
globalMu.RUnlock()
if v == nil {
return nil, errors.New("sso verifier not initialized")
}
claims := &AccessTokenClaims{}
parser := jwt.NewParser(
jwt.WithValidMethods([]string{jwt.SigningMethodRS256.Alg()}),
jwt.WithIssuedAt(),
jwt.WithExpirationRequired(),
)
tok, err := parser.ParseWithClaims(token, claims, v.jwks.Keyfunc)
if err != nil {
return nil, fmt.Errorf("parse token: %w", err)
}
if !tok.Valid {
return nil, errors.New("invalid token")
}
if claims.Issuer != v.issuer {
return nil, errors.New("unexpected token issuer")
}
if len(v.audiences) > 0 {
validAud := false
for _, aud := range claims.Audience {
if _, ok := v.audiences[aud]; ok {
validAud = true
break
}
}
if !validAud {
return nil, errors.New("unexpected token audience")
}
}
sub := strings.TrimSpace(claims.Subject)
if sub == "" {
return nil, errors.New("missing subject")
}
result := &VerificationResult{Claims: claims, Subject: sub}
switch {
case strings.HasPrefix(sub, "user:"):
idStr := strings.TrimPrefix(sub, "user:")
id, err := strconv.ParseUint(idStr, 10, 64)
if err != nil {
return nil, fmt.Errorf("invalid subject: %w", err)
}
result.UserID = uint(id)
case strings.HasPrefix(sub, "service:"):
alias := strings.TrimSpace(strings.TrimPrefix(sub, "service:"))
if alias == "" {
return nil, errors.New("invalid service subject")
}
result.ServiceAlias = strings.ToLower(alias)
default:
return nil, errors.New("unsupported subject type")
}
return result, nil
}
+2 -2
View File
@@ -10,8 +10,8 @@ import (
)
func ErrorHandler(c *fiber.Ctx, err error) error {
if errorsMap := validation.CustomErrorMessages(err); len(errorsMap) > 0 {
return response.Error(c, fiber.StatusBadRequest, "Bad Request", errorsMap)
if message, errorsMap := validation.CustomErrorMessages(err); len(errorsMap) > 0 {
return response.Error(c, fiber.StatusBadRequest, message, nil)
}
var fiberErr *fiber.Error
+84
View File
@@ -0,0 +1,84 @@
package secure
import (
"crypto/rand"
"encoding/base64"
"fmt"
"strings"
)
const pkceCharset = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-._~"
// RandomBytes returns securely generated random bytes of given length.
func RandomBytes(length int) ([]byte, error) {
if length <= 0 {
return nil, fmt.Errorf("length must be positive")
}
b := make([]byte, length)
if _, err := rand.Read(b); err != nil {
return nil, err
}
return b, nil
}
// RandomString returns a base64url encoded random string of approximately the requested length.
func RandomString(length int) (string, error) {
if length <= 0 {
return "", fmt.Errorf("length must be positive")
}
// Generate ceil(length * 6/8) bytes to have enough entropy for base64 url encoding.
byteLen := (length*6 + 7) / 8
bytes, err := RandomBytes(byteLen)
if err != nil {
return "", err
}
s := base64.RawURLEncoding.EncodeToString(bytes)
if len(s) > length {
return s[:length], nil
}
// If encoded string shorter, pad using charset
if len(s) < length {
sb := strings.Builder{}
sb.WriteString(s)
extraNeeded := length - len(s)
more, err := randomFromCharset(extraNeeded)
if err != nil {
return "", err
}
sb.WriteString(more)
return sb.String(), nil
}
return s, nil
}
// PKCECodeVerifier generates a random string compliant with RFC 7636.
func PKCECodeVerifier(length int) (string, error) {
if length < 43 {
length = 43
}
if length > 128 {
length = 128
}
return randomFromCharset(length)
}
// randomFromCharset returns a random string of given length using pkceCharset.
func randomFromCharset(length int) (string, error) {
if length <= 0 {
return "", fmt.Errorf("length must be positive")
}
bytes, err := RandomBytes(length)
if err != nil {
return "", err
}
out := make([]byte, length)
for i, b := range bytes {
out[i] = pkceCharset[int(b)%len(pkceCharset)]
}
return string(out), nil
}
// Base64URLEncode encodes the input bytes using base64 URL encoding without padding.
func Base64URLEncode(data []byte) string {
return base64.RawURLEncoding.EncodeToString(data)
}
-45
View File
@@ -1,45 +0,0 @@
package utils
import (
"errors"
"strconv"
"github.com/golang-jwt/jwt/v5"
)
func VerifyToken(tokenStr, secret, tokenType string) (uint, error) {
token, err := jwt.Parse(tokenStr, func(_ *jwt.Token) (interface{}, error) {
return []byte(secret), nil
})
if err != nil || !token.Valid {
return 0, err
}
claims, ok := token.Claims.(jwt.MapClaims)
if !ok {
return 0, errors.New("invalid token claims")
}
jwtType, ok := claims["type"].(string)
if !ok || jwtType != tokenType {
return 0, errors.New("invalid token type")
}
sub, ok := claims["sub"]
if !ok {
return 0, errors.New("invalid token sub")
}
switch v := sub.(type) {
case float64:
return uint(v), nil
case string:
id, err := strconv.Atoi(v)
if err != nil {
return 0, errors.New("invalid sub format")
}
return uint(id), nil
default:
return 0, errors.New("unsupported sub type")
}
}