mirror of
https://gitlab.com/mbugroup/lti-api.git
synced 2026-05-20 13:31:56 +00:00
311 lines
9.9 KiB
Go
311 lines
9.9 KiB
Go
package config
|
|
|
|
import (
|
|
"encoding/json"
|
|
"fmt"
|
|
"strings"
|
|
"time"
|
|
|
|
"gitlab.com/mbugroup/lti-api.git/internal/utils"
|
|
|
|
"github.com/spf13/viper"
|
|
)
|
|
|
|
// SSOClientConfig is one entry of the SSO_CLIENTS JSON object; loadSSOClients
// decodes that object into a map of these values keyed by client alias.
//
// Field notes:
//   - loadSSOClients whitespace-trims SyncSecret and every AllowedReturnOrigins
//     entry.
//   - In production, ensureProdConfig requires SyncSecret to be non-blank and
//     at least 16 bytes long.
//   - There is no Prompt field (see the disabled line below), so a "prompt" key
//     in the JSON is ignored when decoding.
type SSOClientConfig struct {
	PublicID             string   `json:"public_id"`
	RedirectURI          string   `json:"redirect_uri"`
	Scope                string   `json:"scope"`
	DefaultReturnURI     string   `json:"default_return_uri"`
	AllowedReturnOrigins []string `json:"allowed_return_origins"`
	SyncSecret           string   `json:"sync_secret"`

	// Prompt string `json:"prompt"` (disabled)
}
|
|
|
|
var (
|
|
IsProd bool
|
|
AppHost string
|
|
Version string
|
|
LogLevel string
|
|
AppPort int
|
|
DBHost string
|
|
DBUser string
|
|
DBPassword string
|
|
DBName string
|
|
DBPort int
|
|
DBSSLMode string
|
|
DBSSLRootCert string
|
|
DBSSLCert string
|
|
DBSSLKey string
|
|
JWTSecret string
|
|
JWTAccessExp int
|
|
JWTRefreshExp int
|
|
JWTResetPasswordExp int
|
|
JWTVerifyEmailExp int
|
|
RedisURL string
|
|
CORSAllowOrigins []string
|
|
CORSAllowMethods []string
|
|
CORSAllowHeaders []string
|
|
CORSExposeHeaders []string
|
|
CORSAllowCredentials bool
|
|
CORSMaxAge int
|
|
SSOIssuer string
|
|
SSOJWKSURL string
|
|
SSOHMACSecret string
|
|
SSOAllowedAudiences []string
|
|
SSOAuthorizeURL string
|
|
SSOTokenURL string
|
|
SSOGetMeURL string
|
|
SSOPortalURL string
|
|
SSOClients map[string]SSOClientConfig
|
|
SSOAccessCookieName string
|
|
SSOAccessCookieFallback []string
|
|
SSORefreshCookieName string
|
|
SSOCookieDomain string
|
|
SSOCookieSecure bool
|
|
SSOCookieSameSite string
|
|
SSOAccessTokenMaxBytes int
|
|
SSOTokenBlacklistPrefix string
|
|
SSOPKCETTL time.Duration
|
|
SSOUserSyncDrift time.Duration
|
|
SSOUserSyncNonceTTL time.Duration
|
|
SSOUserSyncMaxBodyBytes int
|
|
S3Endpoint string
|
|
S3Region string
|
|
S3Bucket string
|
|
S3AccessKey string
|
|
S3SecretKey string
|
|
S3ForcePathStyle bool
|
|
S3PublicBaseURL string
|
|
S3EnvPrefix string
|
|
S3DocumentKeyPrefix string
|
|
)
|
|
|
|
func init() {
|
|
loadConfig()
|
|
|
|
// server configuration
|
|
IsProd = viper.GetString("APP_ENV") == "prod"
|
|
AppHost = viper.GetString("APP_HOST")
|
|
AppPort = viper.GetInt("APP_PORT")
|
|
Version = viper.GetString("VERSION")
|
|
LogLevel = viper.GetString("LOG_LEVEL")
|
|
|
|
// database configuration
|
|
DBHost = viper.GetString("DB_HOST")
|
|
DBUser = viper.GetString("DB_USER")
|
|
DBPassword = viper.GetString("DB_PASSWORD")
|
|
DBName = viper.GetString("DB_NAME")
|
|
DBPort = viper.GetInt("DB_PORT")
|
|
DBSSLMode = defaultString(viper.GetString("DB_SSLMODE"), "disable")
|
|
DBSSLRootCert = strings.TrimSpace(viper.GetString("DB_SSLROOTCERT"))
|
|
DBSSLCert = strings.TrimSpace(viper.GetString("DB_SSLCERT"))
|
|
DBSSLKey = strings.TrimSpace(viper.GetString("DB_SSLKEY"))
|
|
|
|
// jwt configuration
|
|
JWTSecret = viper.GetString("JWT_SECRET")
|
|
JWTAccessExp = viper.GetInt("JWT_ACCESS_EXP_MINUTES")
|
|
JWTRefreshExp = viper.GetInt("JWT_REFRESH_EXP_DAYS")
|
|
JWTResetPasswordExp = viper.GetInt("JWT_RESET_PASSWORD_EXP_MINUTES")
|
|
JWTVerifyEmailExp = viper.GetInt("JWT_VERIFY_EMAIL_EXP_MINUTES")
|
|
|
|
//Cors
|
|
CORSAllowOrigins = parseList("CORS_ALLOW_ORIGINS")
|
|
CORSAllowMethods = parseListWithDefault("CORS_ALLOW_METHODS", "GET,POST,PUT,PATCH,DELETE,OPTIONS")
|
|
CORSAllowHeaders = parseListWithDefault("CORS_ALLOW_HEADERS", "Content-Type,Authorization,X-Requested-With")
|
|
CORSExposeHeaders = parseList("CORS_EXPOSE_HEADERS")
|
|
CORSAllowCredentials = viper.GetBool("CORS_ALLOW_CREDENTIALS")
|
|
CORSMaxAge = viper.GetInt("CORS_MAX_AGE")
|
|
|
|
// Redis
|
|
RedisURL = viper.GetString("REDIS_URL")
|
|
|
|
// Object storage
|
|
S3Endpoint = strings.TrimSpace(viper.GetString("S3_ENDPOINT"))
|
|
S3Region = strings.TrimSpace(viper.GetString("S3_REGION"))
|
|
S3Bucket = strings.TrimSpace(viper.GetString("S3_BUCKET"))
|
|
S3AccessKey = strings.TrimSpace(viper.GetString("S3_ACCESS_KEY"))
|
|
S3SecretKey = strings.TrimSpace(viper.GetString("S3_SECRET_KEY"))
|
|
S3ForcePathStyle = viper.GetBool("S3_FORCE_PATH_STYLE")
|
|
S3PublicBaseURL = strings.TrimSuffix(strings.TrimSpace(viper.GetString("S3_PUBLIC_BASE_URL")), "/")
|
|
S3EnvPrefix = defaultString(strings.Trim(strings.TrimSpace(viper.GetString("S3_ENV_PREFIX")), "/"), "local")
|
|
docPrefix := strings.Trim(strings.TrimSpace(viper.GetString("S3_DOCUMENT_PREFIX")), "/")
|
|
if docPrefix == "" {
|
|
docPrefix = "docs"
|
|
}
|
|
S3DocumentKeyPrefix = joinPath(S3EnvPrefix, docPrefix)
|
|
|
|
// SSO integration
|
|
SSOIssuer = viper.GetString("SSO_ISSUER")
|
|
SSOJWKSURL = viper.GetString("SSO_JWKS_URL")
|
|
SSOHMACSecret = viper.GetString("SSO_HS_SECRET")
|
|
SSOAllowedAudiences = parseList("SSO_ALLOWED_AUDIENCES")
|
|
SSOAuthorizeURL = viper.GetString("SSO_AUTHORIZE_URL")
|
|
SSOTokenURL = viper.GetString("SSO_TOKEN_URL")
|
|
SSOGetMeURL = viper.GetString("SSO_GETME_URL")
|
|
SSOPortalURL = strings.TrimSpace(viper.GetString("SSO_PORTAL_URL"))
|
|
SSOAccessCookieName = defaultString(viper.GetString("SSO_ACCESS_COOKIE_NAME"), "sso_access")
|
|
SSOAccessCookieFallback = parseList("SSO_ACCESS_COOKIE_FALLBACK")
|
|
SSORefreshCookieName = defaultString(viper.GetString("SSO_REFRESH_COOKIE_NAME"), "sso_refresh")
|
|
SSOCookieDomain = viper.GetString("SSO_COOKIE_DOMAIN")
|
|
SSOCookieSecure = viper.GetBool("SSO_COOKIE_SECURE")
|
|
SSOCookieSameSite = defaultString(viper.GetString("SSO_COOKIE_SAMESITE"), "Lax")
|
|
SSOAccessTokenMaxBytes = viper.GetInt("SSO_ACCESS_TOKEN_MAX_BYTES")
|
|
if SSOAccessTokenMaxBytes <= 0 {
|
|
SSOAccessTokenMaxBytes = 4096
|
|
}
|
|
SSOTokenBlacklistPrefix = defaultString(viper.GetString("SSO_TOKEN_BLACKLIST_PREFIX"), "sso:blacklist")
|
|
if ttl := viper.GetInt("SSO_PKCE_TTL_SECONDS"); ttl > 0 {
|
|
SSOPKCETTL = time.Duration(ttl) * time.Second
|
|
} else {
|
|
SSOPKCETTL = 5 * time.Minute
|
|
}
|
|
SSOClients = loadSSOClients("SSO_CLIENTS")
|
|
if drift := viper.GetInt("SSO_USER_SYNC_SIGNATURE_DRIFT_SECONDS"); drift > 0 {
|
|
SSOUserSyncDrift = time.Duration(drift) * time.Second
|
|
} else {
|
|
SSOUserSyncDrift = 2 * time.Minute
|
|
}
|
|
if ttl := viper.GetInt("SSO_USER_SYNC_NONCE_TTL_SECONDS"); ttl > 0 {
|
|
SSOUserSyncNonceTTL = time.Duration(ttl) * time.Second
|
|
} else {
|
|
SSOUserSyncNonceTTL = 10 * time.Minute
|
|
}
|
|
SSOUserSyncMaxBodyBytes = viper.GetInt("SSO_USER_SYNC_MAX_BODY_BYTES")
|
|
if SSOUserSyncMaxBodyBytes <= 0 {
|
|
SSOUserSyncMaxBodyBytes = 32 * 1024
|
|
}
|
|
|
|
if IsProd {
|
|
ensureProdConfig()
|
|
}
|
|
}
|
|
|
|
func loadConfig() {
|
|
viper.AutomaticEnv()
|
|
|
|
viper.SetConfigFile(".env")
|
|
if err := viper.ReadInConfig(); err == nil {
|
|
utils.Log.Info("Config file loaded from .env")
|
|
} else {
|
|
utils.Log.Warn("No .env file found, using environment variables only")
|
|
}
|
|
}
|
|
|
|
func parseList(key string) []string {
|
|
raw := strings.TrimSpace(viper.GetString(key))
|
|
if raw == "" {
|
|
return nil
|
|
}
|
|
if strings.HasPrefix(raw, "[") {
|
|
var arr []string
|
|
if json.Unmarshal([]byte(raw), &arr) == nil {
|
|
for i := range arr {
|
|
arr[i] = strings.TrimSpace(arr[i])
|
|
}
|
|
return arr
|
|
}
|
|
}
|
|
parts := strings.Split(raw, ",")
|
|
out := make([]string, 0, len(parts))
|
|
for _, p := range parts {
|
|
p = strings.TrimSpace(p)
|
|
if p != "" {
|
|
out = append(out, p)
|
|
}
|
|
}
|
|
return out
|
|
}
|
|
|
|
func parseListWithDefault(key, def string) []string {
|
|
if v := parseList(key); len(v) > 0 {
|
|
return v
|
|
}
|
|
// fallback ke default CSV
|
|
parts := strings.Split(def, ",")
|
|
for i := range parts {
|
|
parts[i] = strings.TrimSpace(parts[i])
|
|
}
|
|
return parts
|
|
}
|
|
|
|
func loadSSOClients(key string) map[string]SSOClientConfig {
|
|
clients := make(map[string]SSOClientConfig)
|
|
raw := strings.TrimSpace(viper.GetString(key))
|
|
if raw == "" {
|
|
return clients
|
|
}
|
|
if err := json.Unmarshal([]byte(raw), &clients); err != nil {
|
|
utils.Log.Errorf("Failed to parse %s: %v", key, err)
|
|
return make(map[string]SSOClientConfig)
|
|
}
|
|
result := make(map[string]SSOClientConfig, len(clients))
|
|
for alias, cfg := range clients {
|
|
alias = strings.ToLower(strings.TrimSpace(alias))
|
|
for i, origin := range cfg.AllowedReturnOrigins {
|
|
cfg.AllowedReturnOrigins[i] = strings.TrimSpace(origin)
|
|
}
|
|
cfg.SyncSecret = strings.TrimSpace(cfg.SyncSecret)
|
|
result[alias] = cfg
|
|
}
|
|
return result
|
|
}
|
|
|
|
// defaultString returns v with surrounding whitespace removed, or def when v
// is empty or whitespace-only.
//
// The trimmed value is returned (rather than v as given) so that input such as
// " Lax " or "disable\n" is not passed through with stray whitespace, matching
// how this function already treats whitespace-only input as blank.
func defaultString(v, def string) string {
	if trimmed := strings.TrimSpace(v); trimmed != "" {
		return trimmed
	}
	return def
}
|
|
|
|
// joinPath strips leading and trailing slashes from each part, skips parts
// that end up empty, and joins the rest with single "/" separators. Slashes
// inside a part are left untouched.
func joinPath(parts ...string) string {
	var b strings.Builder
	for _, segment := range parts {
		segment = strings.Trim(segment, "/")
		if segment == "" {
			continue
		}
		if b.Len() > 0 {
			b.WriteByte('/')
		}
		b.WriteString(segment)
	}
	return b.String()
}
|
|
|
|
func ensureProdConfig() {
|
|
if SSOAuthorizeURL == "" || !strings.HasPrefix(SSOAuthorizeURL, "https://") {
|
|
panic("SSO_AUTHORIZE_URL must be https in production")
|
|
}
|
|
if strings.TrimSpace(SSOHMACSecret) == "" && strings.TrimSpace(SSOJWKSURL) == "" {
|
|
panic("SSO_JWKS_URL or SSO_HS_SECRET must be configured in production")
|
|
}
|
|
if SSOTokenURL == "" || !strings.HasPrefix(SSOTokenURL, "https://") {
|
|
panic("SSO_TOKEN_URL must be https in production")
|
|
}
|
|
if SSOGetMeURL == "" || !strings.HasPrefix(SSOGetMeURL, "https://") {
|
|
panic("SSO_GETME_URL must be https in production")
|
|
}
|
|
if !SSOCookieSecure {
|
|
panic("SSO_COOKIE_SECURE must be true in production")
|
|
}
|
|
if SSOCookieDomain == "" {
|
|
panic("SSO_COOKIE_DOMAIN must be configured in production")
|
|
}
|
|
if len(SSOAllowedAudiences) == 0 {
|
|
panic("SSO_ALLOWED_AUDIENCES must contain at least one audience in production")
|
|
}
|
|
for alias, cfg := range SSOClients {
|
|
if strings.TrimSpace(cfg.SyncSecret) == "" {
|
|
panic(fmt.Sprintf("SSO_CLIENTS[%s].sync_secret must be configured in production", alias))
|
|
}
|
|
if len(cfg.SyncSecret) < 16 {
|
|
panic(fmt.Sprintf("SSO_CLIENTS[%s].sync_secret must be at least 16 characters", alias))
|
|
}
|
|
}
|
|
if SSOUserSyncDrift <= 0 {
|
|
panic("SSO_USER_SYNC_SIGNATURE_DRIFT_SECONDS must be greater than zero in production")
|
|
}
|
|
if SSOUserSyncNonceTTL <= 0 {
|
|
panic("SSO_USER_SYNC_NONCE_TTL_SECONDS must be greater than zero in production")
|
|
}
|
|
if SSOUserSyncMaxBodyBytes <= 0 {
|
|
panic("SSO_USER_SYNC_MAX_BODY_BYTES must be greater than zero in production")
|
|
}
|
|
}
|