package middleware import ( "context" "errors" "strings" "sync" "github.com/gofiber/fiber/v2" "gitlab.com/mbugroup/lti-api.git/internal/apikeys" "gitlab.com/mbugroup/lti-api.git/internal/config" entity "gitlab.com/mbugroup/lti-api.git/internal/entities" "gitlab.com/mbugroup/lti-api.git/internal/modules/sso/session" sso "gitlab.com/mbugroup/lti-api.git/internal/modules/sso/verifier" service "gitlab.com/mbugroup/lti-api.git/internal/modules/users/services" "gitlab.com/mbugroup/lti-api.git/internal/utils" ) const ( authContextLocalsKey = "auth.context" authUserLocalsKey = "auth.user" ) var ( verifyAccessTokenFunc = sso.VerifyAccessToken fetchProfileFunc = sso.FetchProfile apiKeyAuthMu sync.RWMutex apiKeyAuthenticator apikeys.Authenticator ) // AuthContext keeps authentication details captured by the middleware. type AuthContext struct { Token string Verification *sso.VerificationResult User *entity.User PrincipalType string PrincipalName string Roles []sso.Role Permissions map[string]struct{} UserAreaIDs []uint UserLocationIDs []uint UserAllArea bool UserAllLocation bool } func SetAPIKeyAuthenticator(authenticator apikeys.Authenticator) { apiKeyAuthMu.Lock() defer apiKeyAuthMu.Unlock() apiKeyAuthenticator = authenticator } // Auth validates the incoming request against the central SSO access token and // loads the corresponding local user. Optional scopes can be provided to enforce // fine-grained authorization using the SSO access token scopes. func Auth(userService service.UserService, requiredScopes ...string) fiber.Handler { return func(c *fiber.Ctx) error { token := bearerToken(c) tokenSource := "" if token != "" { tokenSource = "header" } else { primaryName := strings.TrimSpace(config.SSOAccessCookieName) if primaryName != "" { token = strings.TrimSpace(c.Cookies(primaryName)) if token != "" { tokenSource = "cookie:" + primaryName } } if token == "" { for _, name := range config.SSOAccessCookieFallback { name = strings.TrimSpace(name) if name == "" || name == primaryName { continue } token = strings.TrimSpace(c.Cookies(name)) if token != "" { tokenSource = "cookie:" + name break } } } } if token == "" { if c.Method() == fiber.MethodGet { if err := authenticateAPIKey(c); err == nil { if len(requiredScopes) > 0 { return fiber.NewError(fiber.StatusForbidden, "Insufficient scope") } return c.Next() } else if err != nil && !errors.Is(err, apikeys.ErrInvalidAPIKey) && !errors.Is(err, apikeys.ErrInactiveKey) { return err } } return fiber.NewError(fiber.StatusUnauthorized, "Please authenticate") } verification, err := verifyAccessTokenFunc(token) if err != nil { if sso.IsSignatureError(err) { logSignatureError("auth", tokenSource, token, err) } else { utils.Log.WithError(err).Warn("auth: token verification failed") } return fiber.NewError(fiber.StatusUnauthorized, "Please authenticate") } if verification.UserID == 0 { return fiber.NewError(fiber.StatusForbidden, "Service authentication is not permitted for this endpoint") } if err := ensureNotRevoked(c, token, verification); err != nil { return err } user, err := userService.GetBySSOUserID(c, verification.UserID) if err != nil || user == nil { utils.Log.WithError(err).Warn("auth: failed to resolve user from repository") return fiber.NewError(fiber.StatusUnauthorized, "Please authenticate") } if len(requiredScopes) > 0 { if verification.Claims == nil || !hasAllScopes(verification.Claims.Scopes(), requiredScopes) { return fiber.NewError(fiber.StatusForbidden, "Insufficient scope") } } var roles []sso.Role permissions := make(map[string]struct{}) var profile *sso.UserProfile if verification.UserID != 0 { if p, err := fetchProfileFunc(c.Context(), token, verification); err != nil { utils.Log.WithError(err).Warn("auth: failed to fetch sso profile") } else { profile = p } } if profile != nil { roles = profile.Roles for _, perm := range profile.PermissionNames() { if perm != "" { permissions[perm] = struct{}{} } } } ctx := &AuthContext{ Token: token, Verification: verification, User: user, PrincipalType: "user", PrincipalName: user.Name, Roles: roles, Permissions: permissions, UserAreaIDs: nil, UserLocationIDs: nil, UserAllArea: false, UserAllLocation: false, } if profile != nil { ctx.UserAreaIDs = profile.AreaIDs ctx.UserLocationIDs = profile.LocationIDs ctx.UserAllArea = profile.AllArea ctx.UserAllLocation = profile.AllLocation } c.Locals(authContextLocalsKey, ctx) c.Locals(authUserLocalsKey, user) return c.Next() } } // AuthenticatedUser returns the authenticated user populated by Auth. func AuthenticatedUser(c *fiber.Ctx) (*entity.User, bool) { value := c.Locals(authUserLocalsKey) if user, ok := value.(*entity.User); ok && user != nil { return user, true } return nil, false } func ActorIDFromContext(c *fiber.Ctx) (uint, error) { user, ok := AuthenticatedUser(c) if !ok || user == nil || user.Id == 0 { return 0, fiber.NewError(fiber.StatusUnauthorized, "Please authenticate") } return user.Id, nil } // AuthDetails returns the full authentication context (token, claims, user). func AuthDetails(c *fiber.Ctx) (*AuthContext, bool) { value := c.Locals(authContextLocalsKey) if ctx, ok := value.(*AuthContext); ok && ctx != nil { return ctx, true } return nil, false } // ensureNotRevoked ensures the token is not revoked or superseded by a forced logout. func ensureNotRevoked(c *fiber.Ctx, token string, verification *sso.VerificationResult) error { revoker := session.GetRevocationStore() if revoker == nil { return nil } if fingerprint := session.TokenFingerprint(token); fingerprint != "" { revoked, err := revoker.IsRevoked(c.Context(), fingerprint) if err != nil { utils.Log.WithError(err).Warn("auth: token revocation check failed") return fiber.NewError(fiber.StatusUnauthorized, "Please authenticate") } if revoked { return fiber.NewError(fiber.StatusUnauthorized, "Please authenticate") } } if verification.UserID == 0 { return nil } logoutAt, err := revoker.UserLogoutTime(c.Context(), verification.UserID) if err != nil { utils.Log.WithError(err).Warn("auth: failed to load user logout marker") return fiber.NewError(fiber.StatusUnauthorized, "Please authenticate") } if logoutAt.IsZero() { return nil } claims := verification.Claims if claims == nil || claims.IssuedAt == nil { return fiber.NewError(fiber.StatusUnauthorized, "Please authenticate") } issuedAt := claims.IssuedAt.Time // Treat tokens issued at or before the forced logout timestamp as invalid. if !issuedAt.After(logoutAt) { return fiber.NewError(fiber.StatusUnauthorized, "Please authenticate") } return nil } // bearerToken extracts a Bearer token from the Authorization header using // case-insensitive scheme matching and tolerant whitespace handling. func bearerToken(c *fiber.Ctx) string { parts := strings.Fields(c.Get("Authorization")) if len(parts) == 2 && strings.EqualFold(parts[0], "Bearer") { return strings.TrimSpace(parts[1]) } return "" } func authenticateAPIKey(c *fiber.Ctx) error { rawKey := strings.TrimSpace(c.Get("X-API-Key")) if rawKey == "" { return apikeys.ErrInvalidAPIKey } authenticator := currentAPIKeyAuthenticator() if authenticator == nil { return apikeys.ErrInvalidAPIKey } principal, err := authenticator.Authenticate(context.Background(), rawKey, c.IP()) if err != nil { if errors.Is(err, apikeys.ErrInvalidAPIKey) || errors.Is(err, apikeys.ErrInactiveKey) { return apikeys.ErrInvalidAPIKey } utils.Log.WithError(err).Warn("auth: api key authentication failed") return fiber.NewError(fiber.StatusInternalServerError, "Failed to authenticate request") } permissions := make(map[string]struct{}, len(principal.Permissions)) for _, perm := range principal.Permissions { if canonical := canonicalPermission(perm); canonical != "" { permissions[canonical] = struct{}{} } } c.Locals(authContextLocalsKey, &AuthContext{ Token: "", Verification: nil, User: nil, PrincipalType: "api_key", PrincipalName: principal.Name, Roles: nil, Permissions: permissions, UserAreaIDs: principal.AreaIDs, UserLocationIDs: principal.LocationIDs, UserAllArea: principal.AllArea, UserAllLocation: principal.AllLocation, }) c.Locals(authUserLocalsKey, nil) return nil } func currentAPIKeyAuthenticator() apikeys.Authenticator { apiKeyAuthMu.RLock() defer apiKeyAuthMu.RUnlock() return apiKeyAuthenticator } func hasAllScopes(have, required []string) bool { if len(required) == 0 { return true } set := make(map[string]struct{}, len(have)) for _, s := range have { s = strings.ToLower(strings.TrimSpace(s)) if s != "" { set[s] = struct{}{} } } for _, r := range required { r = strings.ToLower(strings.TrimSpace(r)) if r == "" { continue } if _, ok := set[r]; !ok { return false } } return true } func logSignatureError(ctxLabel, tokenSource, token string, err error) { info := sso.ExtractTokenInfo(token) aud := strings.Join(info.Aud, ",") utils.Log.Errorf( "access token verification failed: %v | ctx=%s source=%s iss=%s kid=%s aud=%s sub=%s exp=%d iat=%d nbf=%d expected_iss=%s expected_aud=%v", err, ctxLabel, tokenSource, info.Iss, info.Kid, aud, info.Sub, info.Exp, info.Iat, info.Nbf, config.SSOIssuer, config.SSOAllowedAudiences, ) } // RequirePermissions ensures the authenticated user possesses all specified permissions. func RequirePermissions(perms ...string) fiber.Handler { required := canonicalPermissions(perms) return func(c *fiber.Ctx) error { if len(required) == 0 { return c.Next() } ctx, ok := AuthDetails(c) if !ok || ctx == nil { return fiber.NewError(fiber.StatusUnauthorized, "Please authenticate") } userPerms := ctx.permissionSet() if len(userPerms) == 0 { return fiber.NewError(fiber.StatusForbidden, "Insufficient permission") } for _, perm := range required { if _, has := userPerms[perm]; !has { return fiber.NewError(fiber.StatusForbidden, "Insufficient permission") } } return c.Next() } } // HasPermission reports whether the current request context includes the given permission. func HasPermission(c *fiber.Ctx, perm string) bool { ctx, ok := AuthDetails(c) if !ok || ctx == nil { return false } perm = canonicalPermission(perm) if perm == "" { return false } _, has := ctx.permissionSet()[perm] return has } func (a *AuthContext) permissionSet() map[string]struct{} { if a == nil || a.Permissions == nil { return nil } return a.Permissions } func canonicalPermissions(perms []string) []string { out := make([]string, 0, len(perms)) seen := make(map[string]struct{}, len(perms)) for _, perm := range perms { if canonical := canonicalPermission(perm); canonical != "" { if _, ok := seen[canonical]; ok { continue } seen[canonical] = struct{}{} out = append(out, canonical) } } return out } func canonicalPermission(perm string) string { return strings.ToLower(strings.TrimSpace(perm)) }