mirror of
https://gitlab.com/mbugroup/lti-api.git
synced 2026-05-20 21:41:55 +00:00
[FEAT/BE] down to hs256 without rotate key
This commit is contained in:
@@ -19,9 +19,10 @@ import (
|
||||
)
|
||||
|
||||
type verifier struct {
|
||||
jwks *keyfunc.JWKS
|
||||
issuer string
|
||||
audiences map[string]struct{}
|
||||
jwks *keyfunc.JWKS
|
||||
issuer string
|
||||
audiences map[string]struct{}
|
||||
hmacSecret []byte
|
||||
}
|
||||
|
||||
type AccessTokenClaims struct {
|
||||
@@ -58,13 +59,39 @@ var (
|
||||
globalV *verifier
|
||||
)
|
||||
|
||||
func Init(ctx context.Context, jwksURL, issuer string, audiences []string) error {
|
||||
func Init(ctx context.Context, jwksURL, issuer string, audiences []string, hmacSecret string) error {
|
||||
jwksURL = strings.TrimSpace(jwksURL)
|
||||
issuer = strings.TrimSpace(issuer)
|
||||
if jwksURL == "" || issuer == "" {
|
||||
return errors.New("missing SSO JWKS or issuer configuration")
|
||||
hmacSecret = strings.TrimSpace(hmacSecret)
|
||||
if issuer == "" {
|
||||
return errors.New("missing SSO issuer configuration")
|
||||
}
|
||||
|
||||
audienceMap := make(map[string]struct{}, len(audiences))
|
||||
for _, aud := range audiences {
|
||||
aud = strings.TrimSpace(aud)
|
||||
if aud == "" {
|
||||
continue
|
||||
}
|
||||
audienceMap[aud] = struct{}{}
|
||||
}
|
||||
|
||||
globalMu.Lock()
|
||||
if hmacSecret != "" {
|
||||
globalV = &verifier{
|
||||
jwks: nil,
|
||||
issuer: issuer,
|
||||
audiences: audienceMap,
|
||||
hmacSecret: []byte(hmacSecret),
|
||||
}
|
||||
globalMu.Unlock()
|
||||
utils.Log.Infof("sso verifier initialized for issuer %s (hmac)", issuer)
|
||||
return nil
|
||||
}
|
||||
if jwksURL == "" {
|
||||
globalMu.Unlock()
|
||||
return errors.New("missing SSO JWKS configuration")
|
||||
}
|
||||
client := &http.Client{Timeout: 5 * time.Second}
|
||||
options := keyfunc.Options{
|
||||
Ctx: ctx,
|
||||
@@ -79,19 +106,9 @@ func Init(ctx context.Context, jwksURL, issuer string, audiences []string) error
|
||||
|
||||
jwks, err := keyfunc.Get(jwksURL, options)
|
||||
if err != nil {
|
||||
globalMu.Unlock()
|
||||
return fmt.Errorf("load jwks: %w", err)
|
||||
}
|
||||
|
||||
audienceMap := make(map[string]struct{}, len(audiences))
|
||||
for _, aud := range audiences {
|
||||
aud = strings.TrimSpace(aud)
|
||||
if aud == "" {
|
||||
continue
|
||||
}
|
||||
audienceMap[aud] = struct{}{}
|
||||
}
|
||||
|
||||
globalMu.Lock()
|
||||
globalV = &verifier{jwks: jwks, issuer: issuer, audiences: audienceMap}
|
||||
globalMu.Unlock()
|
||||
|
||||
@@ -113,27 +130,47 @@ func VerifyAccessToken(token string) (*VerificationResult, error) {
|
||||
}
|
||||
|
||||
claims := &AccessTokenClaims{}
|
||||
parser := jwt.NewParser(
|
||||
jwt.WithValidMethods([]string{jwt.SigningMethodRS256.Alg()}),
|
||||
jwt.WithIssuedAt(),
|
||||
jwt.WithExpirationRequired(),
|
||||
)
|
||||
|
||||
tok, err := parser.ParseWithClaims(token, claims, v.jwks.Keyfunc)
|
||||
if err != nil {
|
||||
if shouldRefreshOnVerifyError(err) {
|
||||
if refreshErr := v.jwks.Refresh(context.Background(), keyfunc.RefreshOptions{IgnoreRateLimit: true}); refreshErr != nil {
|
||||
utils.Log.WithError(refreshErr).Warn("sso jwks refresh after signature error failed")
|
||||
} else {
|
||||
tok, err = parser.ParseWithClaims(token, claims, v.jwks.Keyfunc)
|
||||
if len(v.hmacSecret) > 0 {
|
||||
parser := jwt.NewParser(
|
||||
jwt.WithValidMethods([]string{jwt.SigningMethodHS256.Alg()}),
|
||||
jwt.WithIssuedAt(),
|
||||
jwt.WithExpirationRequired(),
|
||||
)
|
||||
tok, err := parser.ParseWithClaims(token, claims, func(t *jwt.Token) (any, error) {
|
||||
if _, ok := t.Method.(*jwt.SigningMethodHMAC); !ok {
|
||||
return nil, errors.New("invalid token signing method")
|
||||
}
|
||||
}
|
||||
return v.hmacSecret, nil
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("parse token: %w", err)
|
||||
}
|
||||
}
|
||||
if !tok.Valid {
|
||||
return nil, errors.New("invalid token")
|
||||
if !tok.Valid {
|
||||
return nil, errors.New("invalid token")
|
||||
}
|
||||
} else {
|
||||
parser := jwt.NewParser(
|
||||
jwt.WithValidMethods([]string{jwt.SigningMethodRS256.Alg()}),
|
||||
jwt.WithIssuedAt(),
|
||||
jwt.WithExpirationRequired(),
|
||||
)
|
||||
|
||||
tok, err := parser.ParseWithClaims(token, claims, v.jwks.Keyfunc)
|
||||
if err != nil {
|
||||
if shouldRefreshOnVerifyError(err) {
|
||||
if refreshErr := v.jwks.Refresh(context.Background(), keyfunc.RefreshOptions{IgnoreRateLimit: true}); refreshErr != nil {
|
||||
utils.Log.WithError(refreshErr).Warn("sso jwks refresh after signature error failed")
|
||||
} else {
|
||||
tok, err = parser.ParseWithClaims(token, claims, v.jwks.Keyfunc)
|
||||
}
|
||||
}
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("parse token: %w", err)
|
||||
}
|
||||
}
|
||||
if !tok.Valid {
|
||||
return nil, errors.New("invalid token")
|
||||
}
|
||||
}
|
||||
|
||||
if claims.Issuer != v.issuer {
|
||||
|
||||
@@ -0,0 +1,42 @@
|
||||
package sso
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/golang-jwt/jwt/v5"
|
||||
)
|
||||
|
||||
func TestVerifyAccessTokenHMAC(t *testing.T) {
|
||||
secret := "test-secret-123"
|
||||
issuer := "http://localhost:8080"
|
||||
aud := []string{"client:1"}
|
||||
|
||||
if err := Init(context.Background(), "", issuer, aud, secret); err != nil {
|
||||
t.Fatalf("Init error: %v", err)
|
||||
}
|
||||
|
||||
claims := &AccessTokenClaims{
|
||||
Scope: "openid profile",
|
||||
RegisteredClaims: jwt.RegisteredClaims{
|
||||
Issuer: issuer,
|
||||
Subject: "user:1",
|
||||
Audience: jwt.ClaimStrings(aud),
|
||||
IssuedAt: jwt.NewNumericDate(time.Now().Add(-1 * time.Minute)),
|
||||
ExpiresAt: jwt.NewNumericDate(time.Now().Add(5 * time.Minute)),
|
||||
},
|
||||
}
|
||||
token, err := jwt.NewWithClaims(jwt.SigningMethodHS256, claims).SignedString([]byte(secret))
|
||||
if err != nil {
|
||||
t.Fatalf("sign token error: %v", err)
|
||||
}
|
||||
|
||||
result, err := VerifyAccessToken(token)
|
||||
if err != nil {
|
||||
t.Fatalf("VerifyAccessToken error: %v", err)
|
||||
}
|
||||
if result.UserID != 1 {
|
||||
t.Fatalf("unexpected user id: %d", result.UserID)
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user