mirror of
https://gitlab.com/mbugroup/lti-api.git
synced 2026-05-20 13:31:56 +00:00
Compare commits
1 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 2cf4ab03ad |
@@ -42,6 +42,8 @@ Copy .env.example to .env and adjust the variables (e.g. DATABASE_URL, JWT secre
|
|||||||
cp .env.example .env
|
cp .env.example .env
|
||||||
```
|
```
|
||||||
|
|
||||||
|
Catatan: isi `SSO_HS_SECRET` jika ingin verifikasi token HS256 tanpa JWKS.
|
||||||
|
|
||||||
### 5. Setup Docker
|
### 5. Setup Docker
|
||||||
|
|
||||||
Run initial docker.
|
Run initial docker.
|
||||||
|
|||||||
+1
-1
@@ -69,7 +69,7 @@ func setupSSO(ctx context.Context, rdb *redis.Client) {
|
|||||||
|
|
||||||
var lastErr error
|
var lastErr error
|
||||||
for attempt := 1; attempt <= maxAttempts; attempt++ {
|
for attempt := 1; attempt <= maxAttempts; attempt++ {
|
||||||
if err := sso.Init(ctx, config.SSOJWKSURL, config.SSOIssuer, config.SSOAllowedAudiences); err != nil {
|
if err := sso.Init(ctx, config.SSOJWKSURL, config.SSOIssuer, config.SSOAllowedAudiences, config.SSOHMACSecret); err != nil {
|
||||||
lastErr = err
|
lastErr = err
|
||||||
utils.Log.WithError(err).Warnf("SSO initialization attempt %d/%d failed", attempt, maxAttempts)
|
utils.Log.WithError(err).Warnf("SSO initialization attempt %d/%d failed", attempt, maxAttempts)
|
||||||
select {
|
select {
|
||||||
|
|||||||
@@ -50,6 +50,7 @@ var (
|
|||||||
CORSMaxAge int
|
CORSMaxAge int
|
||||||
SSOIssuer string
|
SSOIssuer string
|
||||||
SSOJWKSURL string
|
SSOJWKSURL string
|
||||||
|
SSOHMACSecret string
|
||||||
SSOAllowedAudiences []string
|
SSOAllowedAudiences []string
|
||||||
SSOAuthorizeURL string
|
SSOAuthorizeURL string
|
||||||
SSOTokenURL string
|
SSOTokenURL string
|
||||||
@@ -136,6 +137,7 @@ func init() {
|
|||||||
// SSO integration
|
// SSO integration
|
||||||
SSOIssuer = viper.GetString("SSO_ISSUER")
|
SSOIssuer = viper.GetString("SSO_ISSUER")
|
||||||
SSOJWKSURL = viper.GetString("SSO_JWKS_URL")
|
SSOJWKSURL = viper.GetString("SSO_JWKS_URL")
|
||||||
|
SSOHMACSecret = viper.GetString("SSO_HS_SECRET")
|
||||||
SSOAllowedAudiences = parseList("SSO_ALLOWED_AUDIENCES")
|
SSOAllowedAudiences = parseList("SSO_ALLOWED_AUDIENCES")
|
||||||
SSOAuthorizeURL = viper.GetString("SSO_AUTHORIZE_URL")
|
SSOAuthorizeURL = viper.GetString("SSO_AUTHORIZE_URL")
|
||||||
SSOTokenURL = viper.GetString("SSO_TOKEN_URL")
|
SSOTokenURL = viper.GetString("SSO_TOKEN_URL")
|
||||||
@@ -270,6 +272,9 @@ func ensureProdConfig() {
|
|||||||
if SSOAuthorizeURL == "" || !strings.HasPrefix(SSOAuthorizeURL, "https://") {
|
if SSOAuthorizeURL == "" || !strings.HasPrefix(SSOAuthorizeURL, "https://") {
|
||||||
panic("SSO_AUTHORIZE_URL must be https in production")
|
panic("SSO_AUTHORIZE_URL must be https in production")
|
||||||
}
|
}
|
||||||
|
if strings.TrimSpace(SSOHMACSecret) == "" && strings.TrimSpace(SSOJWKSURL) == "" {
|
||||||
|
panic("SSO_JWKS_URL or SSO_HS_SECRET must be configured in production")
|
||||||
|
}
|
||||||
if SSOTokenURL == "" || !strings.HasPrefix(SSOTokenURL, "https://") {
|
if SSOTokenURL == "" || !strings.HasPrefix(SSOTokenURL, "https://") {
|
||||||
panic("SSO_TOKEN_URL must be https in production")
|
panic("SSO_TOKEN_URL must be https in production")
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -19,9 +19,10 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
type verifier struct {
|
type verifier struct {
|
||||||
jwks *keyfunc.JWKS
|
jwks *keyfunc.JWKS
|
||||||
issuer string
|
issuer string
|
||||||
audiences map[string]struct{}
|
audiences map[string]struct{}
|
||||||
|
hmacSecret []byte
|
||||||
}
|
}
|
||||||
|
|
||||||
type AccessTokenClaims struct {
|
type AccessTokenClaims struct {
|
||||||
@@ -58,13 +59,39 @@ var (
|
|||||||
globalV *verifier
|
globalV *verifier
|
||||||
)
|
)
|
||||||
|
|
||||||
func Init(ctx context.Context, jwksURL, issuer string, audiences []string) error {
|
func Init(ctx context.Context, jwksURL, issuer string, audiences []string, hmacSecret string) error {
|
||||||
jwksURL = strings.TrimSpace(jwksURL)
|
jwksURL = strings.TrimSpace(jwksURL)
|
||||||
issuer = strings.TrimSpace(issuer)
|
issuer = strings.TrimSpace(issuer)
|
||||||
if jwksURL == "" || issuer == "" {
|
hmacSecret = strings.TrimSpace(hmacSecret)
|
||||||
return errors.New("missing SSO JWKS or issuer configuration")
|
if issuer == "" {
|
||||||
|
return errors.New("missing SSO issuer configuration")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
audienceMap := make(map[string]struct{}, len(audiences))
|
||||||
|
for _, aud := range audiences {
|
||||||
|
aud = strings.TrimSpace(aud)
|
||||||
|
if aud == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
audienceMap[aud] = struct{}{}
|
||||||
|
}
|
||||||
|
|
||||||
|
globalMu.Lock()
|
||||||
|
if hmacSecret != "" {
|
||||||
|
globalV = &verifier{
|
||||||
|
jwks: nil,
|
||||||
|
issuer: issuer,
|
||||||
|
audiences: audienceMap,
|
||||||
|
hmacSecret: []byte(hmacSecret),
|
||||||
|
}
|
||||||
|
globalMu.Unlock()
|
||||||
|
utils.Log.Infof("sso verifier initialized for issuer %s (hmac)", issuer)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if jwksURL == "" {
|
||||||
|
globalMu.Unlock()
|
||||||
|
return errors.New("missing SSO JWKS configuration")
|
||||||
|
}
|
||||||
client := &http.Client{Timeout: 5 * time.Second}
|
client := &http.Client{Timeout: 5 * time.Second}
|
||||||
options := keyfunc.Options{
|
options := keyfunc.Options{
|
||||||
Ctx: ctx,
|
Ctx: ctx,
|
||||||
@@ -79,19 +106,9 @@ func Init(ctx context.Context, jwksURL, issuer string, audiences []string) error
|
|||||||
|
|
||||||
jwks, err := keyfunc.Get(jwksURL, options)
|
jwks, err := keyfunc.Get(jwksURL, options)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
globalMu.Unlock()
|
||||||
return fmt.Errorf("load jwks: %w", err)
|
return fmt.Errorf("load jwks: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
audienceMap := make(map[string]struct{}, len(audiences))
|
|
||||||
for _, aud := range audiences {
|
|
||||||
aud = strings.TrimSpace(aud)
|
|
||||||
if aud == "" {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
audienceMap[aud] = struct{}{}
|
|
||||||
}
|
|
||||||
|
|
||||||
globalMu.Lock()
|
|
||||||
globalV = &verifier{jwks: jwks, issuer: issuer, audiences: audienceMap}
|
globalV = &verifier{jwks: jwks, issuer: issuer, audiences: audienceMap}
|
||||||
globalMu.Unlock()
|
globalMu.Unlock()
|
||||||
|
|
||||||
@@ -113,27 +130,47 @@ func VerifyAccessToken(token string) (*VerificationResult, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
claims := &AccessTokenClaims{}
|
claims := &AccessTokenClaims{}
|
||||||
parser := jwt.NewParser(
|
if len(v.hmacSecret) > 0 {
|
||||||
jwt.WithValidMethods([]string{jwt.SigningMethodRS256.Alg()}),
|
parser := jwt.NewParser(
|
||||||
jwt.WithIssuedAt(),
|
jwt.WithValidMethods([]string{jwt.SigningMethodHS256.Alg()}),
|
||||||
jwt.WithExpirationRequired(),
|
jwt.WithIssuedAt(),
|
||||||
)
|
jwt.WithExpirationRequired(),
|
||||||
|
)
|
||||||
tok, err := parser.ParseWithClaims(token, claims, v.jwks.Keyfunc)
|
tok, err := parser.ParseWithClaims(token, claims, func(t *jwt.Token) (any, error) {
|
||||||
if err != nil {
|
if _, ok := t.Method.(*jwt.SigningMethodHMAC); !ok {
|
||||||
if shouldRefreshOnVerifyError(err) {
|
return nil, errors.New("invalid token signing method")
|
||||||
if refreshErr := v.jwks.Refresh(context.Background(), keyfunc.RefreshOptions{IgnoreRateLimit: true}); refreshErr != nil {
|
|
||||||
utils.Log.WithError(refreshErr).Warn("sso jwks refresh after signature error failed")
|
|
||||||
} else {
|
|
||||||
tok, err = parser.ParseWithClaims(token, claims, v.jwks.Keyfunc)
|
|
||||||
}
|
}
|
||||||
}
|
return v.hmacSecret, nil
|
||||||
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("parse token: %w", err)
|
return nil, fmt.Errorf("parse token: %w", err)
|
||||||
}
|
}
|
||||||
}
|
if !tok.Valid {
|
||||||
if !tok.Valid {
|
return nil, errors.New("invalid token")
|
||||||
return nil, errors.New("invalid token")
|
}
|
||||||
|
} else {
|
||||||
|
parser := jwt.NewParser(
|
||||||
|
jwt.WithValidMethods([]string{jwt.SigningMethodRS256.Alg()}),
|
||||||
|
jwt.WithIssuedAt(),
|
||||||
|
jwt.WithExpirationRequired(),
|
||||||
|
)
|
||||||
|
|
||||||
|
tok, err := parser.ParseWithClaims(token, claims, v.jwks.Keyfunc)
|
||||||
|
if err != nil {
|
||||||
|
if shouldRefreshOnVerifyError(err) {
|
||||||
|
if refreshErr := v.jwks.Refresh(context.Background(), keyfunc.RefreshOptions{IgnoreRateLimit: true}); refreshErr != nil {
|
||||||
|
utils.Log.WithError(refreshErr).Warn("sso jwks refresh after signature error failed")
|
||||||
|
} else {
|
||||||
|
tok, err = parser.ParseWithClaims(token, claims, v.jwks.Keyfunc)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("parse token: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !tok.Valid {
|
||||||
|
return nil, errors.New("invalid token")
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if claims.Issuer != v.issuer {
|
if claims.Issuer != v.issuer {
|
||||||
|
|||||||
@@ -0,0 +1,42 @@
|
|||||||
|
package sso
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/golang-jwt/jwt/v5"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestVerifyAccessTokenHMAC(t *testing.T) {
|
||||||
|
secret := "test-secret-123"
|
||||||
|
issuer := "http://localhost:8080"
|
||||||
|
aud := []string{"client:1"}
|
||||||
|
|
||||||
|
if err := Init(context.Background(), "", issuer, aud, secret); err != nil {
|
||||||
|
t.Fatalf("Init error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
claims := &AccessTokenClaims{
|
||||||
|
Scope: "openid profile",
|
||||||
|
RegisteredClaims: jwt.RegisteredClaims{
|
||||||
|
Issuer: issuer,
|
||||||
|
Subject: "user:1",
|
||||||
|
Audience: jwt.ClaimStrings(aud),
|
||||||
|
IssuedAt: jwt.NewNumericDate(time.Now().Add(-1 * time.Minute)),
|
||||||
|
ExpiresAt: jwt.NewNumericDate(time.Now().Add(5 * time.Minute)),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
token, err := jwt.NewWithClaims(jwt.SigningMethodHS256, claims).SignedString([]byte(secret))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("sign token error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
result, err := VerifyAccessToken(token)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("VerifyAccessToken error: %v", err)
|
||||||
|
}
|
||||||
|
if result.UserID != 1 {
|
||||||
|
t.Fatalf("unexpected user id: %d", result.UserID)
|
||||||
|
}
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user