mirror of
https://gitlab.com/mbugroup/lti-api.git
synced 2026-05-22 22:35:43 +00:00
.
This commit is contained in:
+65
@@ -0,0 +1,65 @@
|
||||
package compress
|
||||
|
||||
import (
|
||||
"github.com/gofiber/fiber/v2"
|
||||
|
||||
"github.com/valyala/fasthttp"
|
||||
)
|
||||
|
||||
// New creates a new middleware handler
|
||||
func New(config ...Config) fiber.Handler {
|
||||
// Set default config
|
||||
cfg := configDefault(config...)
|
||||
|
||||
// Setup request handlers
|
||||
var (
|
||||
fctx = func(c *fasthttp.RequestCtx) {}
|
||||
compressor fasthttp.RequestHandler
|
||||
)
|
||||
|
||||
// Setup compression algorithm
|
||||
switch cfg.Level {
|
||||
case LevelDefault:
|
||||
// LevelDefault
|
||||
compressor = fasthttp.CompressHandlerBrotliLevel(fctx,
|
||||
fasthttp.CompressBrotliDefaultCompression,
|
||||
fasthttp.CompressDefaultCompression,
|
||||
)
|
||||
case LevelBestSpeed:
|
||||
// LevelBestSpeed
|
||||
compressor = fasthttp.CompressHandlerBrotliLevel(fctx,
|
||||
fasthttp.CompressBrotliBestSpeed,
|
||||
fasthttp.CompressBestSpeed,
|
||||
)
|
||||
case LevelBestCompression:
|
||||
// LevelBestCompression
|
||||
compressor = fasthttp.CompressHandlerBrotliLevel(fctx,
|
||||
fasthttp.CompressBrotliBestCompression,
|
||||
fasthttp.CompressBestCompression,
|
||||
)
|
||||
default:
|
||||
// LevelDisabled
|
||||
return func(c *fiber.Ctx) error {
|
||||
return c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
// Return new handler
|
||||
return func(c *fiber.Ctx) error {
|
||||
// Don't execute middleware if Next returns true
|
||||
if cfg.Next != nil && cfg.Next(c) {
|
||||
return c.Next()
|
||||
}
|
||||
|
||||
// Continue stack
|
||||
if err := c.Next(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Compress response
|
||||
compressor(c.Context())
|
||||
|
||||
// Return from handler
|
||||
return nil
|
||||
}
|
||||
}
|
||||
+56
@@ -0,0 +1,56 @@
|
||||
package compress
|
||||
|
||||
import (
|
||||
"github.com/gofiber/fiber/v2"
|
||||
)
|
||||
|
||||
// Config defines the config for middleware.
|
||||
type Config struct {
|
||||
// Next defines a function to skip this middleware when returned true.
|
||||
//
|
||||
// Optional. Default: nil
|
||||
Next func(c *fiber.Ctx) bool
|
||||
|
||||
// Level determines the compression algorithm
|
||||
//
|
||||
// Optional. Default: LevelDefault
|
||||
// LevelDisabled: -1
|
||||
// LevelDefault: 0
|
||||
// LevelBestSpeed: 1
|
||||
// LevelBestCompression: 2
|
||||
Level Level
|
||||
}
|
||||
|
||||
// Level is numeric representation of compression level
|
||||
type Level int
|
||||
|
||||
// Represents compression level that will be used in the middleware
|
||||
const (
|
||||
LevelDisabled Level = -1
|
||||
LevelDefault Level = 0
|
||||
LevelBestSpeed Level = 1
|
||||
LevelBestCompression Level = 2
|
||||
)
|
||||
|
||||
// ConfigDefault is the default config
|
||||
var ConfigDefault = Config{
|
||||
Next: nil,
|
||||
Level: LevelDefault,
|
||||
}
|
||||
|
||||
// Helper function to set default values
|
||||
func configDefault(config ...Config) Config {
|
||||
// Return default config if nothing provided
|
||||
if len(config) < 1 {
|
||||
return ConfigDefault
|
||||
}
|
||||
|
||||
// Override default config
|
||||
cfg := config[0]
|
||||
|
||||
// Set default values
|
||||
if cfg.Level < LevelDisabled || cfg.Level > LevelBestCompression {
|
||||
cfg.Level = ConfigDefault.Level
|
||||
}
|
||||
return cfg
|
||||
}
|
||||
+289
@@ -0,0 +1,289 @@
|
||||
package cors
|
||||
|
||||
import (
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/gofiber/fiber/v2/log"
|
||||
)
|
||||
|
||||
// Config defines the config for middleware.
|
||||
type Config struct {
|
||||
// Next defines a function to skip this middleware when returned true.
|
||||
//
|
||||
// Optional. Default: nil
|
||||
Next func(c *fiber.Ctx) bool
|
||||
|
||||
// AllowOriginsFunc defines a function that will set the 'Access-Control-Allow-Origin'
|
||||
// response header to the 'origin' request header when returned true. This allows for
|
||||
// dynamic evaluation of allowed origins. Note if AllowCredentials is true, wildcard origins
|
||||
// will be not have the 'Access-Control-Allow-Credentials' header set to 'true'.
|
||||
//
|
||||
// Optional. Default: nil
|
||||
AllowOriginsFunc func(origin string) bool
|
||||
|
||||
// AllowOrigin defines a comma separated list of origins that may access the resource.
|
||||
//
|
||||
// Optional. Default value "*"
|
||||
AllowOrigins string
|
||||
|
||||
// AllowMethods defines a list methods allowed when accessing the resource.
|
||||
// This is used in response to a preflight request.
|
||||
//
|
||||
// Optional. Default value "GET,POST,HEAD,PUT,DELETE,PATCH"
|
||||
AllowMethods string
|
||||
|
||||
// AllowHeaders defines a list of request headers that can be used when
|
||||
// making the actual request. This is in response to a preflight request.
|
||||
//
|
||||
// Optional. Default value "".
|
||||
AllowHeaders string
|
||||
|
||||
// AllowCredentials indicates whether or not the response to the request
|
||||
// can be exposed when the credentials flag is true. When used as part of
|
||||
// a response to a preflight request, this indicates whether or not the
|
||||
// actual request can be made using credentials. Note: If true, AllowOrigins
|
||||
// cannot be set to a wildcard ("*") to prevent security vulnerabilities.
|
||||
//
|
||||
// Optional. Default value false.
|
||||
AllowCredentials bool
|
||||
|
||||
// ExposeHeaders defines a whitelist headers that clients are allowed to
|
||||
// access.
|
||||
//
|
||||
// Optional. Default value "".
|
||||
ExposeHeaders string
|
||||
|
||||
// MaxAge indicates how long (in seconds) the results of a preflight request
|
||||
// can be cached.
|
||||
// If you pass MaxAge 0, Access-Control-Max-Age header will not be added and
|
||||
// browser will use 5 seconds by default.
|
||||
// To disable caching completely, pass MaxAge value negative. It will set the Access-Control-Max-Age header 0.
|
||||
//
|
||||
// Optional. Default value 0.
|
||||
MaxAge int
|
||||
}
|
||||
|
||||
// ConfigDefault is the default config
|
||||
var ConfigDefault = Config{
|
||||
Next: nil,
|
||||
AllowOriginsFunc: nil,
|
||||
AllowOrigins: "*",
|
||||
AllowMethods: strings.Join([]string{
|
||||
fiber.MethodGet,
|
||||
fiber.MethodPost,
|
||||
fiber.MethodHead,
|
||||
fiber.MethodPut,
|
||||
fiber.MethodDelete,
|
||||
fiber.MethodPatch,
|
||||
}, ","),
|
||||
AllowHeaders: "",
|
||||
AllowCredentials: false,
|
||||
ExposeHeaders: "",
|
||||
MaxAge: 0,
|
||||
}
|
||||
|
||||
// New creates a new middleware handler
|
||||
func New(config ...Config) fiber.Handler {
|
||||
// Set default config
|
||||
cfg := ConfigDefault
|
||||
|
||||
// Override config if provided
|
||||
if len(config) > 0 {
|
||||
cfg = config[0]
|
||||
|
||||
// Set default values
|
||||
if cfg.AllowMethods == "" {
|
||||
cfg.AllowMethods = ConfigDefault.AllowMethods
|
||||
}
|
||||
// When none of the AllowOrigins or AllowOriginsFunc config was defined, set the default AllowOrigins value with "*"
|
||||
if cfg.AllowOrigins == "" && cfg.AllowOriginsFunc == nil {
|
||||
cfg.AllowOrigins = ConfigDefault.AllowOrigins
|
||||
}
|
||||
}
|
||||
|
||||
// Warning logs if both AllowOrigins and AllowOriginsFunc are set
|
||||
if cfg.AllowOrigins != "" && cfg.AllowOriginsFunc != nil {
|
||||
log.Warn("[CORS] Both 'AllowOrigins' and 'AllowOriginsFunc' have been defined.")
|
||||
}
|
||||
|
||||
// Validate CORS credentials configuration
|
||||
if cfg.AllowCredentials && cfg.AllowOrigins == "*" {
|
||||
panic("[CORS] Insecure setup, 'AllowCredentials' is set to true, and 'AllowOrigins' is set to a wildcard.")
|
||||
}
|
||||
|
||||
// allowOrigins is a slice of strings that contains the allowed origins
|
||||
// defined in the 'AllowOrigins' configuration.
|
||||
allowOrigins := []string{}
|
||||
allowSOrigins := []subdomain{}
|
||||
allowAllOrigins := false
|
||||
|
||||
// Validate and normalize static AllowOrigins
|
||||
if cfg.AllowOrigins != "" && cfg.AllowOrigins != "*" {
|
||||
origins := strings.Split(cfg.AllowOrigins, ",")
|
||||
for _, origin := range origins {
|
||||
if i := strings.Index(origin, "://*."); i != -1 {
|
||||
trimmedOrigin := strings.TrimSpace(origin[:i+3] + origin[i+4:])
|
||||
isValid, normalizedOrigin := normalizeOrigin(trimmedOrigin)
|
||||
if !isValid {
|
||||
panic("[CORS] Invalid origin format in configuration: " + trimmedOrigin)
|
||||
}
|
||||
sd := subdomain{prefix: normalizedOrigin[:i+3], suffix: normalizedOrigin[i+3:]}
|
||||
allowSOrigins = append(allowSOrigins, sd)
|
||||
} else {
|
||||
trimmedOrigin := strings.TrimSpace(origin)
|
||||
isValid, normalizedOrigin := normalizeOrigin(trimmedOrigin)
|
||||
if !isValid {
|
||||
panic("[CORS] Invalid origin format in configuration: " + trimmedOrigin)
|
||||
}
|
||||
allowOrigins = append(allowOrigins, normalizedOrigin)
|
||||
}
|
||||
}
|
||||
} else if cfg.AllowOrigins == "*" {
|
||||
allowAllOrigins = true
|
||||
}
|
||||
|
||||
// Strip white spaces
|
||||
allowMethods := strings.ReplaceAll(cfg.AllowMethods, " ", "")
|
||||
allowHeaders := strings.ReplaceAll(cfg.AllowHeaders, " ", "")
|
||||
exposeHeaders := strings.ReplaceAll(cfg.ExposeHeaders, " ", "")
|
||||
|
||||
// Convert int to string
|
||||
maxAge := strconv.Itoa(cfg.MaxAge)
|
||||
|
||||
// Return new handler
|
||||
return func(c *fiber.Ctx) error {
|
||||
// Don't execute middleware if Next returns true
|
||||
if cfg.Next != nil && cfg.Next(c) {
|
||||
return c.Next()
|
||||
}
|
||||
|
||||
// Get originHeader header
|
||||
originHeader := strings.ToLower(c.Get(fiber.HeaderOrigin))
|
||||
|
||||
// If the request does not have Origin header, the request is outside the scope of CORS
|
||||
if originHeader == "" {
|
||||
// See https://fetch.spec.whatwg.org/#cors-protocol-and-http-caches
|
||||
// Unless all origins are allowed, we include the Vary header to cache the response correctly
|
||||
if !allowAllOrigins {
|
||||
c.Vary(fiber.HeaderOrigin)
|
||||
}
|
||||
|
||||
return c.Next()
|
||||
}
|
||||
|
||||
// If it's a preflight request and doesn't have Access-Control-Request-Method header, it's outside the scope of CORS
|
||||
if c.Method() == fiber.MethodOptions && c.Get(fiber.HeaderAccessControlRequestMethod) == "" {
|
||||
// Response to OPTIONS request should not be cached but,
|
||||
// some caching can be configured to cache such responses.
|
||||
// To Avoid poisoning the cache, we include the Vary header
|
||||
// for non-CORS OPTIONS requests:
|
||||
c.Vary(fiber.HeaderOrigin)
|
||||
return c.Next()
|
||||
}
|
||||
|
||||
// Set default allowOrigin to empty string
|
||||
allowOrigin := ""
|
||||
|
||||
// Check allowed origins
|
||||
if allowAllOrigins {
|
||||
allowOrigin = "*"
|
||||
} else {
|
||||
// Check if the origin is in the list of allowed origins
|
||||
for _, origin := range allowOrigins {
|
||||
if origin == originHeader {
|
||||
allowOrigin = originHeader
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
// Check if the origin is in the list of allowed subdomains
|
||||
if allowOrigin == "" {
|
||||
for _, sOrigin := range allowSOrigins {
|
||||
if sOrigin.match(originHeader) {
|
||||
allowOrigin = originHeader
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Run AllowOriginsFunc if the logic for
|
||||
// handling the value in 'AllowOrigins' does
|
||||
// not result in allowOrigin being set.
|
||||
if allowOrigin == "" && cfg.AllowOriginsFunc != nil && cfg.AllowOriginsFunc(originHeader) {
|
||||
allowOrigin = originHeader
|
||||
}
|
||||
|
||||
// Simple request
|
||||
// Ommit allowMethods and allowHeaders, only used for pre-flight requests
|
||||
if c.Method() != fiber.MethodOptions {
|
||||
if !allowAllOrigins {
|
||||
// See https://fetch.spec.whatwg.org/#cors-protocol-and-http-caches
|
||||
c.Vary(fiber.HeaderOrigin)
|
||||
}
|
||||
setCORSHeaders(c, allowOrigin, "", "", exposeHeaders, maxAge, cfg)
|
||||
return c.Next()
|
||||
}
|
||||
|
||||
// Pre-flight request
|
||||
|
||||
// Response to OPTIONS request should not be cached but,
|
||||
// some caching can be configured to cache such responses.
|
||||
// To Avoid poisoning the cache, we include the Vary header
|
||||
// of preflight responses:
|
||||
c.Vary(fiber.HeaderAccessControlRequestMethod)
|
||||
c.Vary(fiber.HeaderAccessControlRequestHeaders)
|
||||
c.Vary(fiber.HeaderOrigin)
|
||||
|
||||
setCORSHeaders(c, allowOrigin, allowMethods, allowHeaders, exposeHeaders, maxAge, cfg)
|
||||
|
||||
// Send 204 No Content
|
||||
return c.SendStatus(fiber.StatusNoContent)
|
||||
}
|
||||
}
|
||||
|
||||
// Function to set CORS headers
|
||||
func setCORSHeaders(c *fiber.Ctx, allowOrigin, allowMethods, allowHeaders, exposeHeaders, maxAge string, cfg Config) {
|
||||
if cfg.AllowCredentials {
|
||||
// When AllowCredentials is true, set the Access-Control-Allow-Origin to the specific origin instead of '*'
|
||||
if allowOrigin == "*" {
|
||||
c.Set(fiber.HeaderAccessControlAllowOrigin, allowOrigin)
|
||||
log.Warn("[CORS] 'AllowCredentials' is true, but 'AllowOrigins' cannot be set to '*'.")
|
||||
} else if allowOrigin != "" {
|
||||
c.Set(fiber.HeaderAccessControlAllowOrigin, allowOrigin)
|
||||
c.Set(fiber.HeaderAccessControlAllowCredentials, "true")
|
||||
}
|
||||
} else if allowOrigin != "" {
|
||||
// For non-credential requests, it's safe to set to '*' or specific origins
|
||||
c.Set(fiber.HeaderAccessControlAllowOrigin, allowOrigin)
|
||||
}
|
||||
|
||||
// Set Allow-Methods if not empty
|
||||
if allowMethods != "" {
|
||||
c.Set(fiber.HeaderAccessControlAllowMethods, allowMethods)
|
||||
}
|
||||
|
||||
// Set Allow-Headers if not empty
|
||||
if allowHeaders != "" {
|
||||
c.Set(fiber.HeaderAccessControlAllowHeaders, allowHeaders)
|
||||
} else {
|
||||
h := c.Get(fiber.HeaderAccessControlRequestHeaders)
|
||||
if h != "" {
|
||||
c.Set(fiber.HeaderAccessControlAllowHeaders, h)
|
||||
}
|
||||
}
|
||||
|
||||
// Set MaxAge if set
|
||||
if cfg.MaxAge > 0 {
|
||||
c.Set(fiber.HeaderAccessControlMaxAge, maxAge)
|
||||
} else if cfg.MaxAge < 0 {
|
||||
c.Set(fiber.HeaderAccessControlMaxAge, "0")
|
||||
}
|
||||
|
||||
// Set Expose-Headers if not empty
|
||||
if exposeHeaders != "" {
|
||||
c.Set(fiber.HeaderAccessControlExposeHeaders, exposeHeaders)
|
||||
}
|
||||
}
|
||||
+71
@@ -0,0 +1,71 @@
|
||||
package cors
|
||||
|
||||
import (
|
||||
"net/url"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// matchScheme compares the scheme of the domain and pattern
|
||||
func matchScheme(domain, pattern string) bool {
|
||||
didx := strings.Index(domain, ":")
|
||||
pidx := strings.Index(pattern, ":")
|
||||
return didx != -1 && pidx != -1 && domain[:didx] == pattern[:pidx]
|
||||
}
|
||||
|
||||
// normalizeDomain removes the scheme and port from the input domain
|
||||
func normalizeDomain(input string) string {
|
||||
// Remove scheme
|
||||
input = strings.TrimPrefix(strings.TrimPrefix(input, "http://"), "https://")
|
||||
|
||||
// Find and remove port, if present
|
||||
if len(input) > 0 && input[0] != '[' {
|
||||
if portIndex := strings.Index(input, ":"); portIndex != -1 {
|
||||
input = input[:portIndex]
|
||||
}
|
||||
}
|
||||
|
||||
return input
|
||||
}
|
||||
|
||||
// normalizeOrigin checks if the provided origin is in a correct format
|
||||
// and normalizes it by removing any path or trailing slash.
|
||||
// It returns a boolean indicating whether the origin is valid
|
||||
// and the normalized origin.
|
||||
func normalizeOrigin(origin string) (bool, string) {
|
||||
parsedOrigin, err := url.Parse(origin)
|
||||
if err != nil {
|
||||
return false, ""
|
||||
}
|
||||
|
||||
// Validate the scheme is either http or https
|
||||
if parsedOrigin.Scheme != "http" && parsedOrigin.Scheme != "https" {
|
||||
return false, ""
|
||||
}
|
||||
|
||||
// Don't allow a wildcard with a protocol
|
||||
// wildcards cannot be used within any other value. For example, the following header is not valid:
|
||||
// Access-Control-Allow-Origin: https://*
|
||||
if strings.Contains(parsedOrigin.Host, "*") {
|
||||
return false, ""
|
||||
}
|
||||
|
||||
// Validate there is a host present. The presence of a path, query, or fragment components
|
||||
// is checked, but a trailing "/" (indicative of the root) is allowed for the path and will be normalized
|
||||
if parsedOrigin.Host == "" || (parsedOrigin.Path != "" && parsedOrigin.Path != "/") || parsedOrigin.RawQuery != "" || parsedOrigin.Fragment != "" {
|
||||
return false, ""
|
||||
}
|
||||
|
||||
// Normalize the origin by constructing it from the scheme and host.
|
||||
// The path or trailing slash is not included in the normalized origin.
|
||||
return true, strings.ToLower(parsedOrigin.Scheme) + "://" + strings.ToLower(parsedOrigin.Host)
|
||||
}
|
||||
|
||||
type subdomain struct {
|
||||
// The wildcard pattern
|
||||
prefix string
|
||||
suffix string
|
||||
}
|
||||
|
||||
func (s subdomain) match(o string) bool {
|
||||
return len(o) >= len(s.prefix)+len(s.suffix) && strings.HasPrefix(o, s.prefix) && strings.HasSuffix(o, s.suffix)
|
||||
}
|
||||
+154
@@ -0,0 +1,154 @@
|
||||
package helmet
|
||||
|
||||
import (
|
||||
"github.com/gofiber/fiber/v2"
|
||||
)
|
||||
|
||||
// Config defines the config for middleware.
|
||||
type Config struct {
|
||||
// Next defines a function to skip middleware.
|
||||
// Optional. Default: nil
|
||||
Next func(*fiber.Ctx) bool
|
||||
|
||||
// XSSProtection
|
||||
// Optional. Default value "0".
|
||||
XSSProtection string
|
||||
|
||||
// ContentTypeNosniff
|
||||
// Optional. Default value "nosniff".
|
||||
ContentTypeNosniff string
|
||||
|
||||
// XFrameOptions
|
||||
// Optional. Default value "SAMEORIGIN".
|
||||
// Possible values: "SAMEORIGIN", "DENY", "ALLOW-FROM uri"
|
||||
XFrameOptions string
|
||||
|
||||
// HSTSMaxAge
|
||||
// Optional. Default value 0.
|
||||
HSTSMaxAge int
|
||||
|
||||
// HSTSExcludeSubdomains
|
||||
// Optional. Default value false.
|
||||
HSTSExcludeSubdomains bool
|
||||
|
||||
// ContentSecurityPolicy
|
||||
// Optional. Default value "".
|
||||
ContentSecurityPolicy string
|
||||
|
||||
// CSPReportOnly
|
||||
// Optional. Default value false.
|
||||
CSPReportOnly bool
|
||||
|
||||
// HSTSPreloadEnabled
|
||||
// Optional. Default value false.
|
||||
HSTSPreloadEnabled bool
|
||||
|
||||
// ReferrerPolicy
|
||||
// Optional. Default value "ReferrerPolicy".
|
||||
ReferrerPolicy string
|
||||
|
||||
// Permissions-Policy
|
||||
// Optional. Default value "".
|
||||
PermissionPolicy string
|
||||
|
||||
// Cross-Origin-Embedder-Policy
|
||||
// Optional. Default value "require-corp".
|
||||
CrossOriginEmbedderPolicy string
|
||||
|
||||
// Cross-Origin-Opener-Policy
|
||||
// Optional. Default value "same-origin".
|
||||
CrossOriginOpenerPolicy string
|
||||
|
||||
// Cross-Origin-Resource-Policy
|
||||
// Optional. Default value "same-origin".
|
||||
CrossOriginResourcePolicy string
|
||||
|
||||
// Origin-Agent-Cluster
|
||||
// Optional. Default value "?1".
|
||||
OriginAgentCluster string
|
||||
|
||||
// X-DNS-Prefetch-Control
|
||||
// Optional. Default value "off".
|
||||
XDNSPrefetchControl string
|
||||
|
||||
// X-Download-Options
|
||||
// Optional. Default value "noopen".
|
||||
XDownloadOptions string
|
||||
|
||||
// X-Permitted-Cross-Domain-Policies
|
||||
// Optional. Default value "none".
|
||||
XPermittedCrossDomain string
|
||||
}
|
||||
|
||||
// ConfigDefault is the default config
|
||||
var ConfigDefault = Config{
|
||||
XSSProtection: "0",
|
||||
ContentTypeNosniff: "nosniff",
|
||||
XFrameOptions: "SAMEORIGIN",
|
||||
ReferrerPolicy: "no-referrer",
|
||||
CrossOriginEmbedderPolicy: "require-corp",
|
||||
CrossOriginOpenerPolicy: "same-origin",
|
||||
CrossOriginResourcePolicy: "same-origin",
|
||||
OriginAgentCluster: "?1",
|
||||
XDNSPrefetchControl: "off",
|
||||
XDownloadOptions: "noopen",
|
||||
XPermittedCrossDomain: "none",
|
||||
}
|
||||
|
||||
// Helper function to set default values
|
||||
func configDefault(config ...Config) Config {
|
||||
// Return default config if nothing provided
|
||||
if len(config) < 1 {
|
||||
return ConfigDefault
|
||||
}
|
||||
|
||||
// Override default config
|
||||
cfg := config[0]
|
||||
|
||||
// Set default values
|
||||
if cfg.XSSProtection == "" {
|
||||
cfg.XSSProtection = ConfigDefault.XSSProtection
|
||||
}
|
||||
|
||||
if cfg.ContentTypeNosniff == "" {
|
||||
cfg.ContentTypeNosniff = ConfigDefault.ContentTypeNosniff
|
||||
}
|
||||
|
||||
if cfg.XFrameOptions == "" {
|
||||
cfg.XFrameOptions = ConfigDefault.XFrameOptions
|
||||
}
|
||||
|
||||
if cfg.ReferrerPolicy == "" {
|
||||
cfg.ReferrerPolicy = ConfigDefault.ReferrerPolicy
|
||||
}
|
||||
|
||||
if cfg.CrossOriginEmbedderPolicy == "" {
|
||||
cfg.CrossOriginEmbedderPolicy = ConfigDefault.CrossOriginEmbedderPolicy
|
||||
}
|
||||
|
||||
if cfg.CrossOriginOpenerPolicy == "" {
|
||||
cfg.CrossOriginOpenerPolicy = ConfigDefault.CrossOriginOpenerPolicy
|
||||
}
|
||||
|
||||
if cfg.CrossOriginResourcePolicy == "" {
|
||||
cfg.CrossOriginResourcePolicy = ConfigDefault.CrossOriginResourcePolicy
|
||||
}
|
||||
|
||||
if cfg.OriginAgentCluster == "" {
|
||||
cfg.OriginAgentCluster = ConfigDefault.OriginAgentCluster
|
||||
}
|
||||
|
||||
if cfg.XDNSPrefetchControl == "" {
|
||||
cfg.XDNSPrefetchControl = ConfigDefault.XDNSPrefetchControl
|
||||
}
|
||||
|
||||
if cfg.XDownloadOptions == "" {
|
||||
cfg.XDownloadOptions = ConfigDefault.XDownloadOptions
|
||||
}
|
||||
|
||||
if cfg.XPermittedCrossDomain == "" {
|
||||
cfg.XPermittedCrossDomain = ConfigDefault.XPermittedCrossDomain
|
||||
}
|
||||
|
||||
return cfg
|
||||
}
|
||||
+94
@@ -0,0 +1,94 @@
|
||||
package helmet
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
)
|
||||
|
||||
// New creates a new middleware handler
|
||||
func New(config ...Config) fiber.Handler {
|
||||
// Init config
|
||||
cfg := configDefault(config...)
|
||||
|
||||
// Return middleware handler
|
||||
return func(c *fiber.Ctx) error {
|
||||
// Next request to skip middleware
|
||||
if cfg.Next != nil && cfg.Next(c) {
|
||||
return c.Next()
|
||||
}
|
||||
|
||||
// Set headers
|
||||
if cfg.XSSProtection != "" {
|
||||
c.Set(fiber.HeaderXXSSProtection, cfg.XSSProtection)
|
||||
}
|
||||
|
||||
if cfg.ContentTypeNosniff != "" {
|
||||
c.Set(fiber.HeaderXContentTypeOptions, cfg.ContentTypeNosniff)
|
||||
}
|
||||
|
||||
if cfg.XFrameOptions != "" {
|
||||
c.Set(fiber.HeaderXFrameOptions, cfg.XFrameOptions)
|
||||
}
|
||||
|
||||
if cfg.CrossOriginEmbedderPolicy != "" {
|
||||
c.Set("Cross-Origin-Embedder-Policy", cfg.CrossOriginEmbedderPolicy)
|
||||
}
|
||||
|
||||
if cfg.CrossOriginOpenerPolicy != "" {
|
||||
c.Set("Cross-Origin-Opener-Policy", cfg.CrossOriginOpenerPolicy)
|
||||
}
|
||||
|
||||
if cfg.CrossOriginResourcePolicy != "" {
|
||||
c.Set("Cross-Origin-Resource-Policy", cfg.CrossOriginResourcePolicy)
|
||||
}
|
||||
|
||||
if cfg.OriginAgentCluster != "" {
|
||||
c.Set("Origin-Agent-Cluster", cfg.OriginAgentCluster)
|
||||
}
|
||||
|
||||
if cfg.ReferrerPolicy != "" {
|
||||
c.Set("Referrer-Policy", cfg.ReferrerPolicy)
|
||||
}
|
||||
|
||||
if cfg.XDNSPrefetchControl != "" {
|
||||
c.Set("X-DNS-Prefetch-Control", cfg.XDNSPrefetchControl)
|
||||
}
|
||||
|
||||
if cfg.XDownloadOptions != "" {
|
||||
c.Set("X-Download-Options", cfg.XDownloadOptions)
|
||||
}
|
||||
|
||||
if cfg.XPermittedCrossDomain != "" {
|
||||
c.Set("X-Permitted-Cross-Domain-Policies", cfg.XPermittedCrossDomain)
|
||||
}
|
||||
|
||||
// Handle HSTS headers
|
||||
if c.Protocol() == "https" && cfg.HSTSMaxAge != 0 {
|
||||
subdomains := ""
|
||||
if !cfg.HSTSExcludeSubdomains {
|
||||
subdomains = "; includeSubDomains"
|
||||
}
|
||||
if cfg.HSTSPreloadEnabled {
|
||||
subdomains = fmt.Sprintf("%s; preload", subdomains)
|
||||
}
|
||||
c.Set(fiber.HeaderStrictTransportSecurity, fmt.Sprintf("max-age=%d%s", cfg.HSTSMaxAge, subdomains))
|
||||
}
|
||||
|
||||
// Handle Content-Security-Policy headers
|
||||
if cfg.ContentSecurityPolicy != "" {
|
||||
if cfg.CSPReportOnly {
|
||||
c.Set(fiber.HeaderContentSecurityPolicyReportOnly, cfg.ContentSecurityPolicy)
|
||||
} else {
|
||||
c.Set(fiber.HeaderContentSecurityPolicy, cfg.ContentSecurityPolicy)
|
||||
}
|
||||
}
|
||||
|
||||
// Handle Permissions-Policy headers
|
||||
if cfg.PermissionPolicy != "" {
|
||||
c.Set(fiber.HeaderPermissionsPolicy, cfg.PermissionPolicy)
|
||||
}
|
||||
|
||||
return c.Next()
|
||||
}
|
||||
}
|
||||
+128
@@ -0,0 +1,128 @@
|
||||
package limiter
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/gofiber/fiber/v2/log"
|
||||
)
|
||||
|
||||
// Config defines the config for middleware.
|
||||
type Config struct {
|
||||
// Next defines a function to skip this middleware when returned true.
|
||||
//
|
||||
// Optional. Default: nil
|
||||
Next func(c *fiber.Ctx) bool
|
||||
|
||||
// Max number of recent connections during `Expiration` seconds before sending a 429 response
|
||||
//
|
||||
// Default: 5
|
||||
Max int
|
||||
|
||||
// KeyGenerator allows you to generate custom keys, by default c.IP() is used
|
||||
//
|
||||
// Default: func(c *fiber.Ctx) string {
|
||||
// return c.IP()
|
||||
// }
|
||||
KeyGenerator func(*fiber.Ctx) string
|
||||
|
||||
// Expiration is the time on how long to keep records of requests in memory
|
||||
//
|
||||
// Default: 1 * time.Minute
|
||||
Expiration time.Duration
|
||||
|
||||
// LimitReached is called when a request hits the limit
|
||||
//
|
||||
// Default: func(c *fiber.Ctx) error {
|
||||
// return c.SendStatus(fiber.StatusTooManyRequests)
|
||||
// }
|
||||
LimitReached fiber.Handler
|
||||
|
||||
// When set to true, requests with StatusCode >= 400 won't be counted.
|
||||
//
|
||||
// Default: false
|
||||
SkipFailedRequests bool
|
||||
|
||||
// When set to true, requests with StatusCode < 400 won't be counted.
|
||||
//
|
||||
// Default: false
|
||||
SkipSuccessfulRequests bool
|
||||
|
||||
// Store is used to store the state of the middleware
|
||||
//
|
||||
// Default: an in memory store for this process only
|
||||
Storage fiber.Storage
|
||||
|
||||
// LimiterMiddleware is the struct that implements a limiter middleware.
|
||||
//
|
||||
// Default: a new Fixed Window Rate Limiter
|
||||
LimiterMiddleware LimiterHandler
|
||||
|
||||
// Deprecated: Use Expiration instead
|
||||
Duration time.Duration
|
||||
|
||||
// Deprecated: Use Storage instead
|
||||
Store fiber.Storage
|
||||
|
||||
// Deprecated: Use KeyGenerator instead
|
||||
Key func(*fiber.Ctx) string
|
||||
}
|
||||
|
||||
// ConfigDefault is the default config
|
||||
var ConfigDefault = Config{
|
||||
Max: 5,
|
||||
Expiration: 1 * time.Minute,
|
||||
KeyGenerator: func(c *fiber.Ctx) string {
|
||||
return c.IP()
|
||||
},
|
||||
LimitReached: func(c *fiber.Ctx) error {
|
||||
return c.SendStatus(fiber.StatusTooManyRequests)
|
||||
},
|
||||
SkipFailedRequests: false,
|
||||
SkipSuccessfulRequests: false,
|
||||
LimiterMiddleware: FixedWindow{},
|
||||
}
|
||||
|
||||
// Helper function to set default values
|
||||
func configDefault(config ...Config) Config {
|
||||
// Return default config if nothing provided
|
||||
if len(config) < 1 {
|
||||
return ConfigDefault
|
||||
}
|
||||
|
||||
// Override default config
|
||||
cfg := config[0]
|
||||
|
||||
// Set default values
|
||||
if int(cfg.Duration.Seconds()) > 0 {
|
||||
log.Warn("[LIMITER] Duration is deprecated, please use Expiration")
|
||||
cfg.Expiration = cfg.Duration
|
||||
}
|
||||
if cfg.Key != nil {
|
||||
log.Warn("[LIMITER] Key is deprecated, please us KeyGenerator")
|
||||
cfg.KeyGenerator = cfg.Key
|
||||
}
|
||||
if cfg.Store != nil {
|
||||
log.Warn("[LIMITER] Store is deprecated, please use Storage")
|
||||
cfg.Storage = cfg.Store
|
||||
}
|
||||
if cfg.Next == nil {
|
||||
cfg.Next = ConfigDefault.Next
|
||||
}
|
||||
if cfg.Max <= 0 {
|
||||
cfg.Max = ConfigDefault.Max
|
||||
}
|
||||
if int(cfg.Expiration.Seconds()) <= 0 {
|
||||
cfg.Expiration = ConfigDefault.Expiration
|
||||
}
|
||||
if cfg.KeyGenerator == nil {
|
||||
cfg.KeyGenerator = ConfigDefault.KeyGenerator
|
||||
}
|
||||
if cfg.LimitReached == nil {
|
||||
cfg.LimitReached = ConfigDefault.LimitReached
|
||||
}
|
||||
if cfg.LimiterMiddleware == nil {
|
||||
cfg.LimiterMiddleware = ConfigDefault.LimiterMiddleware
|
||||
}
|
||||
return cfg
|
||||
}
|
||||
+25
@@ -0,0 +1,25 @@
|
||||
package limiter
|
||||
|
||||
import (
|
||||
"github.com/gofiber/fiber/v2"
|
||||
)
|
||||
|
||||
const (
|
||||
// X-RateLimit-* headers
|
||||
xRateLimitLimit = "X-RateLimit-Limit"
|
||||
xRateLimitRemaining = "X-RateLimit-Remaining"
|
||||
xRateLimitReset = "X-RateLimit-Reset"
|
||||
)
|
||||
|
||||
type LimiterHandler interface {
|
||||
New(config Config) fiber.Handler
|
||||
}
|
||||
|
||||
// New creates a new middleware handler
|
||||
func New(config ...Config) fiber.Handler {
|
||||
// Set default config
|
||||
cfg := configDefault(config...)
|
||||
|
||||
// Return the specified middleware handler.
|
||||
return cfg.LimiterMiddleware.New(cfg)
|
||||
}
|
||||
+106
@@ -0,0 +1,106 @@
|
||||
package limiter
|
||||
|
||||
import (
|
||||
"strconv"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/gofiber/fiber/v2/utils"
|
||||
)
|
||||
|
||||
type FixedWindow struct{}
|
||||
|
||||
// New creates a new fixed window middleware handler
|
||||
func (FixedWindow) New(cfg Config) fiber.Handler {
|
||||
var (
|
||||
// Limiter variables
|
||||
mux = &sync.RWMutex{}
|
||||
max = strconv.Itoa(cfg.Max)
|
||||
expiration = uint64(cfg.Expiration.Seconds())
|
||||
)
|
||||
|
||||
// Create manager to simplify storage operations ( see manager.go )
|
||||
manager := newManager(cfg.Storage)
|
||||
|
||||
// Update timestamp every second
|
||||
utils.StartTimeStampUpdater()
|
||||
|
||||
// Return new handler
|
||||
return func(c *fiber.Ctx) error {
|
||||
// Don't execute middleware if Next returns true
|
||||
if cfg.Next != nil && cfg.Next(c) {
|
||||
return c.Next()
|
||||
}
|
||||
|
||||
// Get key from request
|
||||
key := cfg.KeyGenerator(c)
|
||||
|
||||
// Lock entry
|
||||
mux.Lock()
|
||||
|
||||
// Get entry from pool and release when finished
|
||||
e := manager.get(key)
|
||||
|
||||
// Get timestamp
|
||||
ts := uint64(atomic.LoadUint32(&utils.Timestamp))
|
||||
|
||||
// Set expiration if entry does not exist
|
||||
if e.exp == 0 {
|
||||
e.exp = ts + expiration
|
||||
} else if ts >= e.exp {
|
||||
// Check if entry is expired
|
||||
e.currHits = 0
|
||||
e.exp = ts + expiration
|
||||
}
|
||||
|
||||
// Increment hits
|
||||
e.currHits++
|
||||
|
||||
// Calculate when it resets in seconds
|
||||
resetInSec := e.exp - ts
|
||||
|
||||
// Set how many hits we have left
|
||||
remaining := cfg.Max - e.currHits
|
||||
|
||||
// Update storage
|
||||
manager.set(key, e, cfg.Expiration)
|
||||
|
||||
// Unlock entry
|
||||
mux.Unlock()
|
||||
|
||||
// Check if hits exceed the cfg.Max
|
||||
if remaining < 0 {
|
||||
// Return response with Retry-After header
|
||||
// https://tools.ietf.org/html/rfc6584
|
||||
c.Set(fiber.HeaderRetryAfter, strconv.FormatUint(resetInSec, 10))
|
||||
|
||||
// Call LimitReached handler
|
||||
return cfg.LimitReached(c)
|
||||
}
|
||||
|
||||
// Continue stack for reaching c.Response().StatusCode()
|
||||
// Store err for returning
|
||||
err := c.Next()
|
||||
|
||||
// Check for SkipFailedRequests and SkipSuccessfulRequests
|
||||
if (cfg.SkipSuccessfulRequests && c.Response().StatusCode() < fiber.StatusBadRequest) ||
|
||||
(cfg.SkipFailedRequests && c.Response().StatusCode() >= fiber.StatusBadRequest) {
|
||||
// Lock entry
|
||||
mux.Lock()
|
||||
e = manager.get(key)
|
||||
e.currHits--
|
||||
remaining++
|
||||
manager.set(key, e, cfg.Expiration)
|
||||
// Unlock entry
|
||||
mux.Unlock()
|
||||
}
|
||||
|
||||
// We can continue, update RateLimit headers
|
||||
c.Set(xRateLimitLimit, max)
|
||||
c.Set(xRateLimitRemaining, strconv.Itoa(remaining))
|
||||
c.Set(xRateLimitReset, strconv.FormatUint(resetInSec, 10))
|
||||
|
||||
return err
|
||||
}
|
||||
}
|
||||
+137
@@ -0,0 +1,137 @@
|
||||
package limiter
|
||||
|
||||
import (
|
||||
"strconv"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/gofiber/fiber/v2/utils"
|
||||
)
|
||||
|
||||
type SlidingWindow struct{}
|
||||
|
||||
// New creates a new sliding window middleware handler
|
||||
func (SlidingWindow) New(cfg Config) fiber.Handler {
|
||||
var (
|
||||
// Limiter variables
|
||||
mux = &sync.RWMutex{}
|
||||
max = strconv.Itoa(cfg.Max)
|
||||
expiration = uint64(cfg.Expiration.Seconds())
|
||||
)
|
||||
|
||||
// Create manager to simplify storage operations ( see manager.go )
|
||||
manager := newManager(cfg.Storage)
|
||||
|
||||
// Update timestamp every second
|
||||
utils.StartTimeStampUpdater()
|
||||
|
||||
// Return new handler
|
||||
return func(c *fiber.Ctx) error {
|
||||
// Don't execute middleware if Next returns true
|
||||
if cfg.Next != nil && cfg.Next(c) {
|
||||
return c.Next()
|
||||
}
|
||||
|
||||
// Get key from request
|
||||
key := cfg.KeyGenerator(c)
|
||||
|
||||
// Lock entry
|
||||
mux.Lock()
|
||||
|
||||
// Get entry from pool and release when finished
|
||||
e := manager.get(key)
|
||||
|
||||
// Get timestamp
|
||||
ts := uint64(atomic.LoadUint32(&utils.Timestamp))
|
||||
|
||||
// Set expiration if entry does not exist
|
||||
if e.exp == 0 {
|
||||
e.exp = ts + expiration
|
||||
} else if ts >= e.exp {
|
||||
// The entry has expired, handle the expiration.
|
||||
// Set the prevHits to the current hits and reset the hits to 0.
|
||||
e.prevHits = e.currHits
|
||||
|
||||
// Reset the current hits to 0.
|
||||
e.currHits = 0
|
||||
|
||||
// Check how much into the current window it currently is and sets the
|
||||
// expiry based on that, otherwise this would only reset on
|
||||
// the next request and not show the correct expiry.
|
||||
elapsed := ts - e.exp
|
||||
if elapsed >= expiration {
|
||||
e.exp = ts + expiration
|
||||
} else {
|
||||
e.exp = ts + expiration - elapsed
|
||||
}
|
||||
}
|
||||
|
||||
// Increment hits
|
||||
e.currHits++
|
||||
|
||||
// Calculate when it resets in seconds
|
||||
resetInSec := e.exp - ts
|
||||
|
||||
// weight = time until current window reset / total window length
|
||||
weight := float64(resetInSec) / float64(expiration)
|
||||
|
||||
// rate = request count in previous window - weight + request count in current window
|
||||
rate := int(float64(e.prevHits)*weight) + e.currHits
|
||||
|
||||
// Calculate how many hits can be made based on the current rate
|
||||
remaining := cfg.Max - rate
|
||||
|
||||
// Update storage. Garbage collect when the next window ends.
|
||||
// |--------------------------|--------------------------|
|
||||
// ^ ^ ^ ^
|
||||
// ts e.exp End sample window End next window
|
||||
// <------------>
|
||||
// resetInSec
|
||||
// resetInSec = e.exp - ts - time until end of current window.
|
||||
// duration + expiration = end of next window.
|
||||
// Because we don't want to garbage collect in the middle of a window
|
||||
// we add the expiration to the duration.
|
||||
// Otherwise after the end of "sample window", attackers could launch
|
||||
// a new request with the full window length.
|
||||
manager.set(key, e, time.Duration(resetInSec+expiration)*time.Second)
|
||||
|
||||
// Unlock entry
|
||||
mux.Unlock()
|
||||
|
||||
// Check if hits exceed the cfg.Max
|
||||
if remaining < 0 {
|
||||
// Return response with Retry-After header
|
||||
// https://tools.ietf.org/html/rfc6584
|
||||
c.Set(fiber.HeaderRetryAfter, strconv.FormatUint(resetInSec, 10))
|
||||
|
||||
// Call LimitReached handler
|
||||
return cfg.LimitReached(c)
|
||||
}
|
||||
|
||||
// Continue stack for reaching c.Response().StatusCode()
|
||||
// Store err for returning
|
||||
err := c.Next()
|
||||
|
||||
// Check for SkipFailedRequests and SkipSuccessfulRequests
|
||||
if (cfg.SkipSuccessfulRequests && c.Response().StatusCode() < fiber.StatusBadRequest) ||
|
||||
(cfg.SkipFailedRequests && c.Response().StatusCode() >= fiber.StatusBadRequest) {
|
||||
// Lock entry
|
||||
mux.Lock()
|
||||
e = manager.get(key)
|
||||
e.currHits--
|
||||
remaining++
|
||||
manager.set(key, e, cfg.Expiration)
|
||||
// Unlock entry
|
||||
mux.Unlock()
|
||||
}
|
||||
|
||||
// We can continue, update RateLimit headers
|
||||
c.Set(xRateLimitLimit, max)
|
||||
c.Set(xRateLimitRemaining, strconv.Itoa(remaining))
|
||||
c.Set(xRateLimitReset, strconv.FormatUint(resetInSec, 10))
|
||||
|
||||
return err
|
||||
}
|
||||
}
|
||||
+92
@@ -0,0 +1,92 @@
|
||||
package limiter
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/gofiber/fiber/v2/internal/memory"
|
||||
)
|
||||
|
||||
// go:generate msgp
|
||||
// msgp -file="manager.go" -o="manager_msgp.go" -tests=false -unexported
|
||||
type item struct {
|
||||
currHits int
|
||||
prevHits int
|
||||
exp uint64
|
||||
}
|
||||
|
||||
//msgp:ignore manager
|
||||
type manager struct {
|
||||
pool sync.Pool
|
||||
memory *memory.Storage
|
||||
storage fiber.Storage
|
||||
}
|
||||
|
||||
func newManager(storage fiber.Storage) *manager {
|
||||
// Create new storage handler
|
||||
manager := &manager{
|
||||
pool: sync.Pool{
|
||||
New: func() interface{} {
|
||||
return new(item)
|
||||
},
|
||||
},
|
||||
}
|
||||
if storage != nil {
|
||||
// Use provided storage if provided
|
||||
manager.storage = storage
|
||||
} else {
|
||||
// Fallback too memory storage
|
||||
manager.memory = memory.New()
|
||||
}
|
||||
return manager
|
||||
}
|
||||
|
||||
// acquire returns an *entry from the sync.Pool
|
||||
func (m *manager) acquire() *item {
|
||||
return m.pool.Get().(*item) //nolint:forcetypeassert // We store nothing else in the pool
|
||||
}
|
||||
|
||||
// release and reset *entry to sync.Pool
|
||||
func (m *manager) release(e *item) {
|
||||
e.prevHits = 0
|
||||
e.currHits = 0
|
||||
e.exp = 0
|
||||
m.pool.Put(e)
|
||||
}
|
||||
|
||||
// get data from storage or memory
|
||||
func (m *manager) get(key string) *item {
|
||||
var it *item
|
||||
if m.storage != nil {
|
||||
it = m.acquire()
|
||||
raw, err := m.storage.Get(key)
|
||||
if err != nil {
|
||||
return it
|
||||
}
|
||||
if raw != nil {
|
||||
if _, err := it.UnmarshalMsg(raw); err != nil {
|
||||
return it
|
||||
}
|
||||
}
|
||||
return it
|
||||
}
|
||||
if it, _ = m.memory.Get(key).(*item); it == nil { //nolint:errcheck // We store nothing else in the pool
|
||||
it = m.acquire()
|
||||
return it
|
||||
}
|
||||
return it
|
||||
}
|
||||
|
||||
// set data to storage or memory
|
||||
func (m *manager) set(key string, it *item, exp time.Duration) {
|
||||
if m.storage != nil {
|
||||
if raw, err := it.MarshalMsg(nil); err == nil {
|
||||
_ = m.storage.Set(key, raw, exp) //nolint:errcheck // TODO: Handle error here
|
||||
}
|
||||
// we can release data because it's serialized to database
|
||||
m.release(it)
|
||||
} else {
|
||||
m.memory.Set(key, it, exp)
|
||||
}
|
||||
}
|
||||
+160
@@ -0,0 +1,160 @@
|
||||
package limiter
|
||||
|
||||
// Code generated by github.com/tinylib/msgp DO NOT EDIT.
|
||||
|
||||
import (
|
||||
"github.com/tinylib/msgp/msgp"
|
||||
)
|
||||
|
||||
// DecodeMsg implements msgp.Decodable
|
||||
func (z *item) DecodeMsg(dc *msgp.Reader) (err error) {
|
||||
var field []byte
|
||||
_ = field
|
||||
var zb0001 uint32
|
||||
zb0001, err = dc.ReadMapHeader()
|
||||
if err != nil {
|
||||
err = msgp.WrapError(err)
|
||||
return
|
||||
}
|
||||
for zb0001 > 0 {
|
||||
zb0001--
|
||||
field, err = dc.ReadMapKeyPtr()
|
||||
if err != nil {
|
||||
err = msgp.WrapError(err)
|
||||
return
|
||||
}
|
||||
switch msgp.UnsafeString(field) {
|
||||
case "currHits":
|
||||
z.currHits, err = dc.ReadInt()
|
||||
if err != nil {
|
||||
err = msgp.WrapError(err, "currHits")
|
||||
return
|
||||
}
|
||||
case "prevHits":
|
||||
z.prevHits, err = dc.ReadInt()
|
||||
if err != nil {
|
||||
err = msgp.WrapError(err, "prevHits")
|
||||
return
|
||||
}
|
||||
case "exp":
|
||||
z.exp, err = dc.ReadUint64()
|
||||
if err != nil {
|
||||
err = msgp.WrapError(err, "exp")
|
||||
return
|
||||
}
|
||||
default:
|
||||
err = dc.Skip()
|
||||
if err != nil {
|
||||
err = msgp.WrapError(err)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// EncodeMsg implements msgp.Encodable
|
||||
func (z item) EncodeMsg(en *msgp.Writer) (err error) {
|
||||
// map header, size 3
|
||||
// write "currHits"
|
||||
err = en.Append(0x83, 0xa8, 0x63, 0x75, 0x72, 0x72, 0x48, 0x69, 0x74, 0x73)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
err = en.WriteInt(z.currHits)
|
||||
if err != nil {
|
||||
err = msgp.WrapError(err, "currHits")
|
||||
return
|
||||
}
|
||||
// write "prevHits"
|
||||
err = en.Append(0xa8, 0x70, 0x72, 0x65, 0x76, 0x48, 0x69, 0x74, 0x73)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
err = en.WriteInt(z.prevHits)
|
||||
if err != nil {
|
||||
err = msgp.WrapError(err, "prevHits")
|
||||
return
|
||||
}
|
||||
// write "exp"
|
||||
err = en.Append(0xa3, 0x65, 0x78, 0x70)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
err = en.WriteUint64(z.exp)
|
||||
if err != nil {
|
||||
err = msgp.WrapError(err, "exp")
|
||||
return
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// MarshalMsg implements msgp.Marshaler
|
||||
func (z item) MarshalMsg(b []byte) (o []byte, err error) {
|
||||
o = msgp.Require(b, z.Msgsize())
|
||||
// map header, size 3
|
||||
// string "currHits"
|
||||
o = append(o, 0x83, 0xa8, 0x63, 0x75, 0x72, 0x72, 0x48, 0x69, 0x74, 0x73)
|
||||
o = msgp.AppendInt(o, z.currHits)
|
||||
// string "prevHits"
|
||||
o = append(o, 0xa8, 0x70, 0x72, 0x65, 0x76, 0x48, 0x69, 0x74, 0x73)
|
||||
o = msgp.AppendInt(o, z.prevHits)
|
||||
// string "exp"
|
||||
o = append(o, 0xa3, 0x65, 0x78, 0x70)
|
||||
o = msgp.AppendUint64(o, z.exp)
|
||||
return
|
||||
}
|
||||
|
||||
// UnmarshalMsg implements msgp.Unmarshaler
|
||||
func (z *item) UnmarshalMsg(bts []byte) (o []byte, err error) {
|
||||
var field []byte
|
||||
_ = field
|
||||
var zb0001 uint32
|
||||
zb0001, bts, err = msgp.ReadMapHeaderBytes(bts)
|
||||
if err != nil {
|
||||
err = msgp.WrapError(err)
|
||||
return
|
||||
}
|
||||
for zb0001 > 0 {
|
||||
zb0001--
|
||||
field, bts, err = msgp.ReadMapKeyZC(bts)
|
||||
if err != nil {
|
||||
err = msgp.WrapError(err)
|
||||
return
|
||||
}
|
||||
switch msgp.UnsafeString(field) {
|
||||
case "currHits":
|
||||
z.currHits, bts, err = msgp.ReadIntBytes(bts)
|
||||
if err != nil {
|
||||
err = msgp.WrapError(err, "currHits")
|
||||
return
|
||||
}
|
||||
case "prevHits":
|
||||
z.prevHits, bts, err = msgp.ReadIntBytes(bts)
|
||||
if err != nil {
|
||||
err = msgp.WrapError(err, "prevHits")
|
||||
return
|
||||
}
|
||||
case "exp":
|
||||
z.exp, bts, err = msgp.ReadUint64Bytes(bts)
|
||||
if err != nil {
|
||||
err = msgp.WrapError(err, "exp")
|
||||
return
|
||||
}
|
||||
default:
|
||||
bts, err = msgp.Skip(bts)
|
||||
if err != nil {
|
||||
err = msgp.WrapError(err)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
o = bts
|
||||
return
|
||||
}
|
||||
|
||||
// Msgsize returns an upper bound estimate of the number of bytes occupied by the serialized message
|
||||
func (z item) Msgsize() (s int) {
|
||||
s = 1 + 9 + msgp.IntSize + 9 + msgp.IntSize + 4 + msgp.Uint64Size
|
||||
return
|
||||
}
|
||||
+136
@@ -0,0 +1,136 @@
|
||||
package logger
|
||||
|
||||
import (
|
||||
"io"
|
||||
"os"
|
||||
"time"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
)
|
||||
|
||||
// Config defines the config for middleware.
|
||||
type Config struct {
|
||||
// Next defines a function to skip this middleware when returned true.
|
||||
//
|
||||
// Optional. Default: nil
|
||||
Next func(c *fiber.Ctx) bool
|
||||
|
||||
// Done is a function that is called after the log string for a request is written to Output,
|
||||
// and pass the log string as parameter.
|
||||
//
|
||||
// Optional. Default: nil
|
||||
Done func(c *fiber.Ctx, logString []byte)
|
||||
|
||||
// tagFunctions defines the custom tag action
|
||||
//
|
||||
// Optional. Default: map[string]LogFunc
|
||||
CustomTags map[string]LogFunc
|
||||
|
||||
// Format defines the logging tags
|
||||
//
|
||||
// Optional. Default: ${time} | ${status} | ${latency} | ${ip} | ${method} | ${path} | ${error}\n
|
||||
Format string
|
||||
|
||||
// TimeFormat https://programming.guide/go/format-parse-string-time-date-example.html
|
||||
//
|
||||
// Optional. Default: 15:04:05
|
||||
TimeFormat string
|
||||
|
||||
// TimeZone can be specified, such as "UTC" and "America/New_York" and "Asia/Chongqing", etc
|
||||
//
|
||||
// Optional. Default: "Local"
|
||||
TimeZone string
|
||||
|
||||
// TimeInterval is the delay before the timestamp is updated
|
||||
//
|
||||
// Optional. Default: 500 * time.Millisecond
|
||||
TimeInterval time.Duration
|
||||
|
||||
// Output is a writer where logs are written
|
||||
//
|
||||
// Default: os.Stdout
|
||||
Output io.Writer
|
||||
|
||||
// DisableColors defines if the logs output should be colorized
|
||||
//
|
||||
// Default: false
|
||||
DisableColors bool
|
||||
|
||||
enableColors bool
|
||||
enableLatency bool
|
||||
timeZoneLocation *time.Location
|
||||
}
|
||||
|
||||
const (
|
||||
startTag = "${"
|
||||
endTag = "}"
|
||||
paramSeparator = ":"
|
||||
)
|
||||
|
||||
type Buffer interface {
|
||||
Len() int
|
||||
ReadFrom(r io.Reader) (int64, error)
|
||||
WriteTo(w io.Writer) (int64, error)
|
||||
Bytes() []byte
|
||||
Write(p []byte) (int, error)
|
||||
WriteByte(c byte) error
|
||||
WriteString(s string) (int, error)
|
||||
Set(p []byte)
|
||||
SetString(s string)
|
||||
String() string
|
||||
}
|
||||
|
||||
type LogFunc func(output Buffer, c *fiber.Ctx, data *Data, extraParam string) (int, error)
|
||||
|
||||
// ConfigDefault is the default config
|
||||
var ConfigDefault = Config{
|
||||
Next: nil,
|
||||
Done: nil,
|
||||
Format: "${time} | ${status} | ${latency} | ${ip} | ${method} | ${path} | ${error}\n",
|
||||
TimeFormat: "15:04:05",
|
||||
TimeZone: "Local",
|
||||
TimeInterval: 500 * time.Millisecond,
|
||||
Output: os.Stdout,
|
||||
DisableColors: false,
|
||||
enableColors: true,
|
||||
}
|
||||
|
||||
// Helper function to set default values
|
||||
func configDefault(config ...Config) Config {
|
||||
// Return default config if nothing provided
|
||||
if len(config) < 1 {
|
||||
return ConfigDefault
|
||||
}
|
||||
|
||||
// Override default config
|
||||
cfg := config[0]
|
||||
|
||||
// Set default values
|
||||
if cfg.Next == nil {
|
||||
cfg.Next = ConfigDefault.Next
|
||||
}
|
||||
if cfg.Done == nil {
|
||||
cfg.Done = ConfigDefault.Done
|
||||
}
|
||||
if cfg.Format == "" {
|
||||
cfg.Format = ConfigDefault.Format
|
||||
}
|
||||
if cfg.TimeZone == "" {
|
||||
cfg.TimeZone = ConfigDefault.TimeZone
|
||||
}
|
||||
if cfg.TimeFormat == "" {
|
||||
cfg.TimeFormat = ConfigDefault.TimeFormat
|
||||
}
|
||||
if int(cfg.TimeInterval) <= 0 {
|
||||
cfg.TimeInterval = ConfigDefault.TimeInterval
|
||||
}
|
||||
if cfg.Output == nil {
|
||||
cfg.Output = ConfigDefault.Output
|
||||
}
|
||||
|
||||
if !cfg.DisableColors && cfg.Output == ConfigDefault.Output {
|
||||
cfg.enableColors = true
|
||||
}
|
||||
|
||||
return cfg
|
||||
}
|
||||
+16
@@ -0,0 +1,16 @@
|
||||
package logger
|
||||
|
||||
import (
|
||||
"sync/atomic"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Data is a struct to define some variables to use in custom logger function.
|
||||
type Data struct {
|
||||
Pid string
|
||||
ErrPaddingStr string
|
||||
ChainErr error
|
||||
Start time.Time
|
||||
Stop time.Time
|
||||
Timestamp atomic.Value
|
||||
}
|
||||
+182
@@ -0,0 +1,182 @@
|
||||
package logger
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/gofiber/fiber/v2/utils"
|
||||
|
||||
"github.com/mattn/go-colorable"
|
||||
"github.com/mattn/go-isatty"
|
||||
"github.com/valyala/bytebufferpool"
|
||||
"github.com/valyala/fasthttp"
|
||||
)
|
||||
|
||||
// New creates a new middleware handler
|
||||
func New(config ...Config) fiber.Handler {
|
||||
// Set default config
|
||||
cfg := configDefault(config...)
|
||||
|
||||
// Get timezone location
|
||||
tz, err := time.LoadLocation(cfg.TimeZone)
|
||||
if err != nil || tz == nil {
|
||||
cfg.timeZoneLocation = time.Local
|
||||
} else {
|
||||
cfg.timeZoneLocation = tz
|
||||
}
|
||||
|
||||
// Check if format contains latency
|
||||
cfg.enableLatency = strings.Contains(cfg.Format, "${"+TagLatency+"}")
|
||||
|
||||
var timestamp atomic.Value
|
||||
// Create correct timeformat
|
||||
timestamp.Store(time.Now().In(cfg.timeZoneLocation).Format(cfg.TimeFormat))
|
||||
|
||||
// Update date/time every 500 milliseconds in a separate go routine
|
||||
if strings.Contains(cfg.Format, "${"+TagTime+"}") {
|
||||
go func() {
|
||||
for {
|
||||
time.Sleep(cfg.TimeInterval)
|
||||
timestamp.Store(time.Now().In(cfg.timeZoneLocation).Format(cfg.TimeFormat))
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
// Set PID once
|
||||
pid := strconv.Itoa(os.Getpid())
|
||||
|
||||
// Set variables
|
||||
var (
|
||||
once sync.Once
|
||||
mu sync.Mutex
|
||||
errHandler fiber.ErrorHandler
|
||||
|
||||
dataPool = sync.Pool{New: func() interface{} { return new(Data) }}
|
||||
)
|
||||
|
||||
// If colors are enabled, check terminal compatibility
|
||||
if cfg.enableColors {
|
||||
cfg.Output = colorable.NewColorableStdout()
|
||||
if os.Getenv("TERM") == "dumb" || os.Getenv("NO_COLOR") == "1" || (!isatty.IsTerminal(os.Stdout.Fd()) && !isatty.IsCygwinTerminal(os.Stdout.Fd())) {
|
||||
cfg.Output = colorable.NewNonColorable(os.Stdout)
|
||||
}
|
||||
}
|
||||
|
||||
errPadding := 15
|
||||
errPaddingStr := strconv.Itoa(errPadding)
|
||||
|
||||
// instead of analyzing the template inside(handler) each time, this is done once before
|
||||
// and we create several slices of the same length with the functions to be executed and fixed parts.
|
||||
templateChain, logFunChain, err := buildLogFuncChain(&cfg, createTagMap(&cfg))
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
// Return new handler
|
||||
return func(c *fiber.Ctx) error {
|
||||
// Don't execute middleware if Next returns true
|
||||
if cfg.Next != nil && cfg.Next(c) {
|
||||
return c.Next()
|
||||
}
|
||||
|
||||
// Set error handler once
|
||||
once.Do(func() {
|
||||
// get longested possible path
|
||||
stack := c.App().Stack()
|
||||
for m := range stack {
|
||||
for r := range stack[m] {
|
||||
if len(stack[m][r].Path) > errPadding {
|
||||
errPadding = len(stack[m][r].Path)
|
||||
errPaddingStr = strconv.Itoa(errPadding)
|
||||
}
|
||||
}
|
||||
}
|
||||
// override error handler
|
||||
errHandler = c.App().ErrorHandler
|
||||
})
|
||||
|
||||
// Logger data
|
||||
data := dataPool.Get().(*Data) //nolint:forcetypeassert,errcheck // We store nothing else in the pool
|
||||
// no need for a reset, as long as we always override everything
|
||||
data.Pid = pid
|
||||
data.ErrPaddingStr = errPaddingStr
|
||||
data.Timestamp = timestamp
|
||||
// put data back in the pool
|
||||
defer dataPool.Put(data)
|
||||
|
||||
// Set latency start time
|
||||
if cfg.enableLatency {
|
||||
data.Start = time.Now()
|
||||
}
|
||||
|
||||
// Handle request, store err for logging
|
||||
chainErr := c.Next()
|
||||
|
||||
data.ChainErr = chainErr
|
||||
// Manually call error handler
|
||||
if chainErr != nil {
|
||||
if err := errHandler(c, chainErr); err != nil {
|
||||
_ = c.SendStatus(fiber.StatusInternalServerError) //nolint:errcheck // TODO: Explain why we ignore the error here
|
||||
}
|
||||
}
|
||||
|
||||
// Set latency stop time
|
||||
if cfg.enableLatency {
|
||||
data.Stop = time.Now()
|
||||
}
|
||||
|
||||
// Get new buffer
|
||||
buf := bytebufferpool.Get()
|
||||
|
||||
var err error
|
||||
// Loop over template parts execute dynamic parts and add fixed parts to the buffer
|
||||
for i, logFunc := range logFunChain {
|
||||
if logFunc == nil {
|
||||
_, _ = buf.Write(templateChain[i]) //nolint:errcheck // This will never fail
|
||||
} else if templateChain[i] == nil {
|
||||
_, err = logFunc(buf, c, data, "")
|
||||
} else {
|
||||
_, err = logFunc(buf, c, data, utils.UnsafeString(templateChain[i]))
|
||||
}
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
// Also write errors to the buffer
|
||||
if err != nil {
|
||||
_, _ = buf.WriteString(err.Error()) //nolint:errcheck // This will never fail
|
||||
}
|
||||
mu.Lock()
|
||||
// Write buffer to output
|
||||
if _, err := cfg.Output.Write(buf.Bytes()); err != nil {
|
||||
// Write error to output
|
||||
if _, err := cfg.Output.Write([]byte(err.Error())); err != nil {
|
||||
// There is something wrong with the given io.Writer
|
||||
_, _ = fmt.Fprintf(os.Stderr, "Failed to write to log, %v\n", err)
|
||||
}
|
||||
}
|
||||
mu.Unlock()
|
||||
|
||||
if cfg.Done != nil {
|
||||
cfg.Done(c, buf.Bytes())
|
||||
}
|
||||
|
||||
// Put buffer back to pool
|
||||
bytebufferpool.Put(buf)
|
||||
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func appendInt(output Buffer, v int) (int, error) {
|
||||
old := output.Len()
|
||||
output.Set(fasthttp.AppendUint(output.Bytes(), v))
|
||||
return output.Len() - old, nil
|
||||
}
|
||||
+211
@@ -0,0 +1,211 @@
|
||||
package logger
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
)
|
||||
|
||||
// Logger variables
|
||||
const (
|
||||
TagPid = "pid"
|
||||
TagTime = "time"
|
||||
TagReferer = "referer"
|
||||
TagProtocol = "protocol"
|
||||
TagPort = "port"
|
||||
TagIP = "ip"
|
||||
TagIPs = "ips"
|
||||
TagHost = "host"
|
||||
TagMethod = "method"
|
||||
TagPath = "path"
|
||||
TagURL = "url"
|
||||
TagUA = "ua"
|
||||
TagLatency = "latency"
|
||||
TagStatus = "status"
|
||||
TagResBody = "resBody"
|
||||
TagReqHeaders = "reqHeaders"
|
||||
TagQueryStringParams = "queryParams"
|
||||
TagBody = "body"
|
||||
TagBytesSent = "bytesSent"
|
||||
TagBytesReceived = "bytesReceived"
|
||||
TagRoute = "route"
|
||||
TagError = "error"
|
||||
// Deprecated: Use TagReqHeader instead
|
||||
TagHeader = "header:"
|
||||
TagReqHeader = "reqHeader:"
|
||||
TagRespHeader = "respHeader:"
|
||||
TagLocals = "locals:"
|
||||
TagQuery = "query:"
|
||||
TagForm = "form:"
|
||||
TagCookie = "cookie:"
|
||||
TagBlack = "black"
|
||||
TagRed = "red"
|
||||
TagGreen = "green"
|
||||
TagYellow = "yellow"
|
||||
TagBlue = "blue"
|
||||
TagMagenta = "magenta"
|
||||
TagCyan = "cyan"
|
||||
TagWhite = "white"
|
||||
TagReset = "reset"
|
||||
)
|
||||
|
||||
// createTagMap function merged the default with the custom tags
|
||||
func createTagMap(cfg *Config) map[string]LogFunc {
|
||||
// Set default tags
|
||||
tagFunctions := map[string]LogFunc{
|
||||
TagReferer: func(output Buffer, c *fiber.Ctx, data *Data, extraParam string) (int, error) {
|
||||
return output.WriteString(c.Get(fiber.HeaderReferer))
|
||||
},
|
||||
TagProtocol: func(output Buffer, c *fiber.Ctx, data *Data, extraParam string) (int, error) {
|
||||
return output.WriteString(c.Protocol())
|
||||
},
|
||||
TagPort: func(output Buffer, c *fiber.Ctx, data *Data, extraParam string) (int, error) {
|
||||
return output.WriteString(c.Port())
|
||||
},
|
||||
TagIP: func(output Buffer, c *fiber.Ctx, data *Data, extraParam string) (int, error) {
|
||||
return output.WriteString(c.IP())
|
||||
},
|
||||
TagIPs: func(output Buffer, c *fiber.Ctx, data *Data, extraParam string) (int, error) {
|
||||
return output.WriteString(c.Get(fiber.HeaderXForwardedFor))
|
||||
},
|
||||
TagHost: func(output Buffer, c *fiber.Ctx, data *Data, extraParam string) (int, error) {
|
||||
return output.WriteString(c.Hostname())
|
||||
},
|
||||
TagPath: func(output Buffer, c *fiber.Ctx, data *Data, extraParam string) (int, error) {
|
||||
return output.WriteString(c.Path())
|
||||
},
|
||||
TagURL: func(output Buffer, c *fiber.Ctx, data *Data, extraParam string) (int, error) {
|
||||
return output.WriteString(c.OriginalURL())
|
||||
},
|
||||
TagUA: func(output Buffer, c *fiber.Ctx, data *Data, extraParam string) (int, error) {
|
||||
return output.WriteString(c.Get(fiber.HeaderUserAgent))
|
||||
},
|
||||
TagBody: func(output Buffer, c *fiber.Ctx, data *Data, extraParam string) (int, error) {
|
||||
return output.Write(c.Body())
|
||||
},
|
||||
TagBytesReceived: func(output Buffer, c *fiber.Ctx, data *Data, extraParam string) (int, error) {
|
||||
return appendInt(output, len(c.Request().Body()))
|
||||
},
|
||||
TagBytesSent: func(output Buffer, c *fiber.Ctx, data *Data, extraParam string) (int, error) {
|
||||
if c.Response().Header.ContentLength() < 0 {
|
||||
return appendInt(output, 0)
|
||||
}
|
||||
return appendInt(output, len(c.Response().Body()))
|
||||
},
|
||||
TagRoute: func(output Buffer, c *fiber.Ctx, data *Data, extraParam string) (int, error) {
|
||||
return output.WriteString(c.Route().Path)
|
||||
},
|
||||
TagResBody: func(output Buffer, c *fiber.Ctx, data *Data, extraParam string) (int, error) {
|
||||
return output.Write(c.Response().Body())
|
||||
},
|
||||
TagReqHeaders: func(output Buffer, c *fiber.Ctx, data *Data, extraParam string) (int, error) {
|
||||
reqHeaders := make([]string, 0)
|
||||
for k, v := range c.GetReqHeaders() {
|
||||
reqHeaders = append(reqHeaders, k+"="+strings.Join(v, ","))
|
||||
}
|
||||
return output.Write([]byte(strings.Join(reqHeaders, "&")))
|
||||
},
|
||||
TagQueryStringParams: func(output Buffer, c *fiber.Ctx, data *Data, extraParam string) (int, error) {
|
||||
return output.WriteString(c.Request().URI().QueryArgs().String())
|
||||
},
|
||||
|
||||
TagBlack: func(output Buffer, c *fiber.Ctx, data *Data, extraParam string) (int, error) {
|
||||
return output.WriteString(c.App().Config().ColorScheme.Black)
|
||||
},
|
||||
TagRed: func(output Buffer, c *fiber.Ctx, data *Data, extraParam string) (int, error) {
|
||||
return output.WriteString(c.App().Config().ColorScheme.Red)
|
||||
},
|
||||
TagGreen: func(output Buffer, c *fiber.Ctx, data *Data, extraParam string) (int, error) {
|
||||
return output.WriteString(c.App().Config().ColorScheme.Green)
|
||||
},
|
||||
TagYellow: func(output Buffer, c *fiber.Ctx, data *Data, extraParam string) (int, error) {
|
||||
return output.WriteString(c.App().Config().ColorScheme.Yellow)
|
||||
},
|
||||
TagBlue: func(output Buffer, c *fiber.Ctx, data *Data, extraParam string) (int, error) {
|
||||
return output.WriteString(c.App().Config().ColorScheme.Blue)
|
||||
},
|
||||
TagMagenta: func(output Buffer, c *fiber.Ctx, data *Data, extraParam string) (int, error) {
|
||||
return output.WriteString(c.App().Config().ColorScheme.Magenta)
|
||||
},
|
||||
TagCyan: func(output Buffer, c *fiber.Ctx, data *Data, extraParam string) (int, error) {
|
||||
return output.WriteString(c.App().Config().ColorScheme.Cyan)
|
||||
},
|
||||
TagWhite: func(output Buffer, c *fiber.Ctx, data *Data, extraParam string) (int, error) {
|
||||
return output.WriteString(c.App().Config().ColorScheme.White)
|
||||
},
|
||||
TagReset: func(output Buffer, c *fiber.Ctx, data *Data, extraParam string) (int, error) {
|
||||
return output.WriteString(c.App().Config().ColorScheme.Reset)
|
||||
},
|
||||
TagError: func(output Buffer, c *fiber.Ctx, data *Data, extraParam string) (int, error) {
|
||||
if data.ChainErr != nil {
|
||||
if cfg.enableColors {
|
||||
colors := c.App().Config().ColorScheme
|
||||
return output.WriteString(fmt.Sprintf("%s%s%s", colors.Red, data.ChainErr.Error(), colors.Reset))
|
||||
}
|
||||
return output.WriteString(data.ChainErr.Error())
|
||||
}
|
||||
return output.WriteString("-")
|
||||
},
|
||||
TagReqHeader: func(output Buffer, c *fiber.Ctx, data *Data, extraParam string) (int, error) {
|
||||
return output.WriteString(c.Get(extraParam))
|
||||
},
|
||||
TagHeader: func(output Buffer, c *fiber.Ctx, data *Data, extraParam string) (int, error) {
|
||||
return output.WriteString(c.Get(extraParam))
|
||||
},
|
||||
TagRespHeader: func(output Buffer, c *fiber.Ctx, data *Data, extraParam string) (int, error) {
|
||||
return output.WriteString(c.GetRespHeader(extraParam))
|
||||
},
|
||||
TagQuery: func(output Buffer, c *fiber.Ctx, data *Data, extraParam string) (int, error) {
|
||||
return output.WriteString(c.Query(extraParam))
|
||||
},
|
||||
TagForm: func(output Buffer, c *fiber.Ctx, data *Data, extraParam string) (int, error) {
|
||||
return output.WriteString(c.FormValue(extraParam))
|
||||
},
|
||||
TagCookie: func(output Buffer, c *fiber.Ctx, data *Data, extraParam string) (int, error) {
|
||||
return output.WriteString(c.Cookies(extraParam))
|
||||
},
|
||||
TagLocals: func(output Buffer, c *fiber.Ctx, data *Data, extraParam string) (int, error) {
|
||||
switch v := c.Locals(extraParam).(type) {
|
||||
case []byte:
|
||||
return output.Write(v)
|
||||
case string:
|
||||
return output.WriteString(v)
|
||||
case nil:
|
||||
return 0, nil
|
||||
default:
|
||||
return output.WriteString(fmt.Sprintf("%v", v))
|
||||
}
|
||||
},
|
||||
TagStatus: func(output Buffer, c *fiber.Ctx, data *Data, extraParam string) (int, error) {
|
||||
if cfg.enableColors {
|
||||
colors := c.App().Config().ColorScheme
|
||||
return output.WriteString(fmt.Sprintf("%s%3d%s", statusColor(c.Response().StatusCode(), colors), c.Response().StatusCode(), colors.Reset))
|
||||
}
|
||||
return appendInt(output, c.Response().StatusCode())
|
||||
},
|
||||
TagMethod: func(output Buffer, c *fiber.Ctx, data *Data, extraParam string) (int, error) {
|
||||
if cfg.enableColors {
|
||||
colors := c.App().Config().ColorScheme
|
||||
return output.WriteString(fmt.Sprintf("%s%s%s", methodColor(c.Method(), colors), c.Method(), colors.Reset))
|
||||
}
|
||||
return output.WriteString(c.Method())
|
||||
},
|
||||
TagPid: func(output Buffer, c *fiber.Ctx, data *Data, extraParam string) (int, error) {
|
||||
return output.WriteString(data.Pid)
|
||||
},
|
||||
TagLatency: func(output Buffer, c *fiber.Ctx, data *Data, extraParam string) (int, error) {
|
||||
latency := data.Stop.Sub(data.Start)
|
||||
return output.WriteString(fmt.Sprintf("%13v", latency))
|
||||
},
|
||||
TagTime: func(output Buffer, c *fiber.Ctx, data *Data, extraParam string) (int, error) {
|
||||
return output.WriteString(data.Timestamp.Load().(string)) //nolint:forcetypeassert // We always store a string in here
|
||||
},
|
||||
}
|
||||
// merge with custom tags from user
|
||||
for k, v := range cfg.CustomTags {
|
||||
tagFunctions[k] = v
|
||||
}
|
||||
|
||||
return tagFunctions
|
||||
}
|
||||
+70
@@ -0,0 +1,70 @@
|
||||
package logger
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
|
||||
"github.com/gofiber/fiber/v2/utils"
|
||||
)
|
||||
|
||||
// buildLogFuncChain analyzes the template and creates slices with the functions for execution and
|
||||
// slices with the fixed parts of the template and the parameters
|
||||
//
|
||||
// fixParts contains the fixed parts of the template or parameters if a function is stored in the funcChain at this position
|
||||
// funcChain contains for the parts which exist the functions for the dynamic parts
|
||||
// funcChain and fixParts always have the same length and contain nil for the parts where no data is required in the chain,
|
||||
// if a function exists for the part, a parameter for it can also exist in the fixParts slice
|
||||
func buildLogFuncChain(cfg *Config, tagFunctions map[string]LogFunc) ([][]byte, []LogFunc, error) {
|
||||
// process flow is copied from the fasttemplate flow https://github.com/valyala/fasttemplate/blob/2a2d1afadadf9715bfa19683cdaeac8347e5d9f9/template.go#L23-L62
|
||||
templateB := utils.UnsafeBytes(cfg.Format)
|
||||
startTagB := utils.UnsafeBytes(startTag)
|
||||
endTagB := utils.UnsafeBytes(endTag)
|
||||
paramSeparatorB := utils.UnsafeBytes(paramSeparator)
|
||||
|
||||
var fixParts [][]byte
|
||||
var funcChain []LogFunc
|
||||
|
||||
for {
|
||||
currentPos := bytes.Index(templateB, startTagB)
|
||||
if currentPos < 0 {
|
||||
// no starting tag found in the existing template part
|
||||
break
|
||||
}
|
||||
// add fixed part
|
||||
funcChain = append(funcChain, nil)
|
||||
fixParts = append(fixParts, templateB[:currentPos])
|
||||
|
||||
templateB = templateB[currentPos+len(startTagB):]
|
||||
currentPos = bytes.Index(templateB, endTagB)
|
||||
if currentPos < 0 {
|
||||
// cannot find end tag - just write it to the output.
|
||||
funcChain = append(funcChain, nil)
|
||||
fixParts = append(fixParts, startTagB)
|
||||
break
|
||||
}
|
||||
// ## function block ##
|
||||
// first check for tags with parameters
|
||||
if index := bytes.Index(templateB[:currentPos], paramSeparatorB); index != -1 {
|
||||
logFunc, ok := tagFunctions[utils.UnsafeString(templateB[:index+1])]
|
||||
if !ok {
|
||||
return nil, nil, errors.New("No parameter found in \"" + utils.UnsafeString(templateB[:currentPos]) + "\"")
|
||||
}
|
||||
funcChain = append(funcChain, logFunc)
|
||||
// add param to the fixParts
|
||||
fixParts = append(fixParts, templateB[index+1:currentPos])
|
||||
} else if logFunc, ok := tagFunctions[utils.UnsafeString(templateB[:currentPos])]; ok {
|
||||
// add functions without parameter
|
||||
funcChain = append(funcChain, logFunc)
|
||||
fixParts = append(fixParts, nil)
|
||||
}
|
||||
// ## function block end ##
|
||||
|
||||
// reduce the template string
|
||||
templateB = templateB[currentPos+len(endTagB):]
|
||||
}
|
||||
// set the rest
|
||||
funcChain = append(funcChain, nil)
|
||||
fixParts = append(fixParts, templateB)
|
||||
|
||||
return fixParts, funcChain, nil
|
||||
}
|
||||
+39
@@ -0,0 +1,39 @@
|
||||
package logger
|
||||
|
||||
import (
|
||||
"github.com/gofiber/fiber/v2"
|
||||
)
|
||||
|
||||
func methodColor(method string, colors fiber.Colors) string {
|
||||
switch method {
|
||||
case fiber.MethodGet:
|
||||
return colors.Cyan
|
||||
case fiber.MethodPost:
|
||||
return colors.Green
|
||||
case fiber.MethodPut:
|
||||
return colors.Yellow
|
||||
case fiber.MethodDelete:
|
||||
return colors.Red
|
||||
case fiber.MethodPatch:
|
||||
return colors.White
|
||||
case fiber.MethodHead:
|
||||
return colors.Magenta
|
||||
case fiber.MethodOptions:
|
||||
return colors.Blue
|
||||
default:
|
||||
return colors.Reset
|
||||
}
|
||||
}
|
||||
|
||||
func statusColor(code int, colors fiber.Colors) string {
|
||||
switch {
|
||||
case code >= fiber.StatusOK && code < fiber.StatusMultipleChoices:
|
||||
return colors.Green
|
||||
case code >= fiber.StatusMultipleChoices && code < fiber.StatusBadRequest:
|
||||
return colors.Blue
|
||||
case code >= fiber.StatusBadRequest && code < fiber.StatusInternalServerError:
|
||||
return colors.Yellow
|
||||
default:
|
||||
return colors.Red
|
||||
}
|
||||
}
|
||||
+47
@@ -0,0 +1,47 @@
|
||||
package recover //nolint:predeclared // TODO: Rename to some non-builtin
|
||||
|
||||
import (
|
||||
"github.com/gofiber/fiber/v2"
|
||||
)
|
||||
|
||||
// Config defines the config for middleware.
|
||||
type Config struct {
|
||||
// Next defines a function to skip this middleware when returned true.
|
||||
//
|
||||
// Optional. Default: nil
|
||||
Next func(c *fiber.Ctx) bool
|
||||
|
||||
// EnableStackTrace enables handling stack trace
|
||||
//
|
||||
// Optional. Default: false
|
||||
EnableStackTrace bool
|
||||
|
||||
// StackTraceHandler defines a function to handle stack trace
|
||||
//
|
||||
// Optional. Default: defaultStackTraceHandler
|
||||
StackTraceHandler func(c *fiber.Ctx, e interface{})
|
||||
}
|
||||
|
||||
// ConfigDefault is the default config
|
||||
var ConfigDefault = Config{
|
||||
Next: nil,
|
||||
EnableStackTrace: false,
|
||||
StackTraceHandler: defaultStackTraceHandler,
|
||||
}
|
||||
|
||||
// Helper function to set default values
|
||||
func configDefault(config ...Config) Config {
|
||||
// Return default config if nothing provided
|
||||
if len(config) < 1 {
|
||||
return ConfigDefault
|
||||
}
|
||||
|
||||
// Override default config
|
||||
cfg := config[0]
|
||||
|
||||
if cfg.EnableStackTrace && cfg.StackTraceHandler == nil {
|
||||
cfg.StackTraceHandler = defaultStackTraceHandler
|
||||
}
|
||||
|
||||
return cfg
|
||||
}
|
||||
+45
@@ -0,0 +1,45 @@
|
||||
package recover //nolint:predeclared // TODO: Rename to some non-builtin
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"runtime/debug"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
)
|
||||
|
||||
func defaultStackTraceHandler(_ *fiber.Ctx, e interface{}) {
|
||||
_, _ = os.Stderr.WriteString(fmt.Sprintf("panic: %v\n%s\n", e, debug.Stack())) //nolint:errcheck // This will never fail
|
||||
}
|
||||
|
||||
// New creates a new middleware handler
|
||||
func New(config ...Config) fiber.Handler {
|
||||
// Set default config
|
||||
cfg := configDefault(config...)
|
||||
|
||||
// Return new handler
|
||||
return func(c *fiber.Ctx) (err error) { //nolint:nonamedreturns // Uses recover() to overwrite the error
|
||||
// Don't execute middleware if Next returns true
|
||||
if cfg.Next != nil && cfg.Next(c) {
|
||||
return c.Next()
|
||||
}
|
||||
|
||||
// Catch panics
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
if cfg.EnableStackTrace {
|
||||
cfg.StackTraceHandler(c, r)
|
||||
}
|
||||
|
||||
var ok bool
|
||||
if err, ok = r.(error); !ok {
|
||||
// Set error that will call the global error handler
|
||||
err = fmt.Errorf("%v", r)
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
// Return err if exist, else move to next handler
|
||||
return c.Next()
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user