This commit is contained in:
GitLab Deploy Bot
2025-10-21 23:45:13 +07:00
parent 6c387b420c
commit bb60e987e5
3548 changed files with 4952576 additions and 116 deletions
+21
View File
@@ -0,0 +1,21 @@
MIT License
Copyright (c) 2021 Fiber
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
+300
View File
@@ -0,0 +1,300 @@
---
id: jwt
---
# JWT
![Release](https://img.shields.io/github/v/tag/gofiber/contrib?filter=jwt*)
[![Discord](https://img.shields.io/discord/704680098577514527?style=flat&label=%F0%9F%92%AC%20discord&color=00ACD7)](https://gofiber.io/discord)
![Test](https://github.com/gofiber/contrib/workflows/Tests/badge.svg)
![Security](https://github.com/gofiber/contrib/workflows/Security/badge.svg)
![Linter](https://github.com/gofiber/contrib/workflows/Linter/badge.svg)
JWT returns a JSON Web Token (JWT) auth middleware.
For valid token, it sets the user in Ctx.Locals and calls next handler.
For invalid token, it returns "401 - Unauthorized" error.
For missing token, it returns "400 - Bad Request" error.
Special thanks and credits to [Echo](https://echo.labstack.com/middleware/jwt)
**Note: Requires Go 1.19 and above**
## Install
This middleware supports Fiber v1 & v2, install accordingly.
```
go get -u github.com/gofiber/fiber/v2
go get -u github.com/gofiber/contrib/jwt
go get -u github.com/golang-jwt/jwt/v5
```
## Signature
```go
jwtware.New(config ...jwtware.Config) func(*fiber.Ctx) error
```
## Config
| Property | Type | Description | Default |
|:---------------|:--------------------------------|:--------------------------------------------------------------------------------------------------------------------------------------------------------|:-----------------------------|
| Filter | `func(*fiber.Ctx) bool` | Defines a function to skip middleware | `nil` |
| SuccessHandler | `func(*fiber.Ctx) error` | SuccessHandler defines a function which is executed for a valid token. | `nil` |
| ErrorHandler | `func(*fiber.Ctx, error) error` | ErrorHandler defines a function which is executed for an invalid token. | `401 Invalid or expired JWT` |
| SigningKey | `interface{}` | Signing key to validate token. Used as fallback if SigningKeys has length 0. | `nil` |
| SigningKeys | `map[string]interface{}` | Map of signing keys to validate token with kid field usage. | `nil` |
| ContextKey | `string` | Context key to store user information from the token into context. | `"user"` |
| Claims | `jwt.Claim` | Claims are extendable claims data defining token content. | `jwt.MapClaims{}` |
| TokenLookup | `string` | TokenLookup is a string in the form of `<source>:<name>` that is used | `"header:Authorization"` |
| AuthScheme | `string` | AuthScheme to be used in the Authorization header. The default value (`"Bearer"`) will only be used in conjuction with the default `TokenLookup` value. | `"Bearer"` |
| KeyFunc | `func() jwt.Keyfunc` | KeyFunc defines a user-defined function that supplies the public key for a token validation. | `jwtKeyFunc` |
| JWKSetURLs | `[]string` | A slice of unique JSON Web Key (JWK) Set URLs to used to parse JWTs. | `nil` |
## HS256 Example
```go
package main
import (
"time"
"github.com/gofiber/fiber/v2"
jwtware "github.com/gofiber/contrib/jwt"
"github.com/golang-jwt/jwt/v5"
)
func main() {
app := fiber.New()
// Login route
app.Post("/login", login)
// Unauthenticated route
app.Get("/", accessible)
// JWT Middleware
app.Use(jwtware.New(jwtware.Config{
SigningKey: jwtware.SigningKey{Key: []byte("secret")},
}))
// Restricted Routes
app.Get("/restricted", restricted)
app.Listen(":3000")
}
func login(c *fiber.Ctx) error {
user := c.FormValue("user")
pass := c.FormValue("pass")
// Throws Unauthorized error
if user != "john" || pass != "doe" {
return c.SendStatus(fiber.StatusUnauthorized)
}
// Create the Claims
claims := jwt.MapClaims{
"name": "John Doe",
"admin": true,
"exp": time.Now().Add(time.Hour * 72).Unix(),
}
// Create token
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
// Generate encoded token and send it as response.
t, err := token.SignedString([]byte("secret"))
if err != nil {
return c.SendStatus(fiber.StatusInternalServerError)
}
return c.JSON(fiber.Map{"token": t})
}
func accessible(c *fiber.Ctx) error {
return c.SendString("Accessible")
}
func restricted(c *fiber.Ctx) error {
user := c.Locals("user").(*jwt.Token)
claims := user.Claims.(jwt.MapClaims)
name := claims["name"].(string)
return c.SendString("Welcome " + name)
}
```
## HS256 Test
_Login using username and password to retrieve a token._
```
curl --data "user=john&pass=doe" http://localhost:3000/login
```
_Response_
```json
{
"token": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJleHAiOjE0NjE5NTcxMzZ9.RB3arc4-OyzASAaUhC2W3ReWaXAt_z2Fd3BN4aWTgEY"
}
```
_Request a restricted resource using the token in Authorization request header._
```
curl localhost:3000/restricted -H "Authorization: Bearer eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJleHAiOjE0NjE5NTcxMzZ9.RB3arc4-OyzASAaUhC2W3ReWaXAt_z2Fd3BN4aWTgEY"
```
_Response_
```
Welcome John Doe
```
## RS256 Example
```go
package main
import (
"crypto/rand"
"crypto/rsa"
"log"
"time"
"github.com/gofiber/fiber/v2"
"github.com/golang-jwt/jwt/v5"
jwtware "github.com/gofiber/contrib/jwt"
)
var (
// Obviously, this is just a test example. Do not do this in production.
// In production, you would have the private key and public key pair generated
// in advance. NEVER add a private key to any GitHub repo.
privateKey *rsa.PrivateKey
)
func main() {
app := fiber.New()
// Just as a demo, generate a new private/public key pair on each run. See note above.
rng := rand.Reader
var err error
privateKey, err = rsa.GenerateKey(rng, 2048)
if err != nil {
log.Fatalf("rsa.GenerateKey: %v", err)
}
// Login route
app.Post("/login", login)
// Unauthenticated route
app.Get("/", accessible)
// JWT Middleware
app.Use(jwtware.New(jwtware.Config{
SigningKey: jwtware.SigningKey{
JWTAlg: jwtware.RS256,
Key: privateKey.Public(),
},
}))
// Restricted Routes
app.Get("/restricted", restricted)
app.Listen(":3000")
}
func login(c *fiber.Ctx) error {
user := c.FormValue("user")
pass := c.FormValue("pass")
// Throws Unauthorized error
if user != "john" || pass != "doe" {
return c.SendStatus(fiber.StatusUnauthorized)
}
// Create the Claims
claims := jwt.MapClaims{
"name": "John Doe",
"admin": true,
"exp": time.Now().Add(time.Hour * 72).Unix(),
}
// Create token
token := jwt.NewWithClaims(jwt.SigningMethodRS256, claims)
// Generate encoded token and send it as response.
t, err := token.SignedString(privateKey)
if err != nil {
log.Printf("token.SignedString: %v", err)
return c.SendStatus(fiber.StatusInternalServerError)
}
return c.JSON(fiber.Map{"token": t})
}
func accessible(c *fiber.Ctx) error {
return c.SendString("Accessible")
}
func restricted(c *fiber.Ctx) error {
user := c.Locals("user").(*jwt.Token)
claims := user.Claims.(jwt.MapClaims)
name := claims["name"].(string)
return c.SendString("Welcome " + name)
}
```
## RS256 Test
The RS256 is actually identical to the HS256 test above.
## JWK Set Test
The tests are identical to basic `JWT` tests above, with exception that `JWKSetURLs` to valid public keys collection in JSON Web Key (JWK) Set format should be supplied. See [RFC 7517](https://www.rfc-editor.org/rfc/rfc7517).
## Custom KeyFunc example
KeyFunc defines a user-defined function that supplies the public key for a token validation.
The function shall take care of verifying the signing algorithm and selecting the proper key.
A user-defined KeyFunc can be useful if tokens are issued by an external party.
When a user-defined KeyFunc is provided, SigningKey, SigningKeys, and SigningMethod are ignored.
This is one of the three options to provide a token validation key.
The order of precedence is a user-defined KeyFunc, SigningKeys and SigningKey.
Required if neither SigningKeys nor SigningKey is provided.
Default to an internal implementation verifying the signing algorithm and selecting the proper key.
```go
package main
import (
"fmt"
"github.com/gofiber/fiber/v2"
jwtware "github.com/gofiber/contrib/jwt"
"github.com/golang-jwt/jwt/v5"
)
func main() {
app := fiber.New()
app.Use(jwtware.New(jwtware.Config{
KeyFunc: customKeyFunc(),
}))
app.Get("/ok", func(c *fiber.Ctx) error {
return c.SendString("OK")
})
}
func customKeyFunc() jwt.Keyfunc {
return func(t *jwt.Token) (interface{}, error) {
// Always check the signing method
if t.Method.Alg() != jwtware.HS256 {
return nil, fmt.Errorf("Unexpected jwt signing method=%v", t.Header["alg"])
}
// TODO custom implementation of loading signing key like from a database
signingKey := "secret"
return []byte(signingKey), nil
}
}
```
+232
View File
@@ -0,0 +1,232 @@
package jwtware
import (
"errors"
"fmt"
"log"
"strings"
"time"
"github.com/MicahParks/keyfunc/v2"
"github.com/gofiber/fiber/v2"
"github.com/golang-jwt/jwt/v5"
)
var (
// ErrJWTAlg is returned when the JWT header did not contain the expected algorithm.
ErrJWTAlg = errors.New("the JWT header did not contain the expected algorithm")
)
// Config defines the config for JWT middleware
type Config struct {
// Filter defines a function to skip middleware.
// Optional. Default: nil
Filter func(*fiber.Ctx) bool
// SuccessHandler defines a function which is executed for a valid token.
// Optional. Default: nil
SuccessHandler fiber.Handler
// ErrorHandler defines a function which is executed for an invalid token.
// It may be used to define a custom JWT error.
// Optional. Default: 401 Invalid or expired JWT
ErrorHandler fiber.ErrorHandler
// Signing key to validate token. Used as fallback if SigningKeys has length 0.
// At least one of the following is required: KeyFunc, JWKSetURLs, SigningKeys, or SigningKey.
// The order of precedence is: KeyFunc, JWKSetURLs, SigningKeys, SigningKey.
SigningKey SigningKey
// Map of signing keys to validate token with kid field usage.
// At least one of the following is required: KeyFunc, JWKSetURLs, SigningKeys, or SigningKey.
// The order of precedence is: KeyFunc, JWKSetURLs, SigningKeys, SigningKey.
SigningKeys map[string]SigningKey
// Context key to store user information from the token into context.
// Optional. Default: "user".
ContextKey string
// Claims are extendable claims data defining token content.
// Optional. Default value jwt.MapClaims
Claims jwt.Claims
// TokenLookup is a string in the form of "<source>:<name>" that is used
// to extract token from the request.
// Optional. Default value "header:Authorization".
// Possible values:
// - "header:<name>"
// - "query:<name>"
// - "param:<name>"
// - "cookie:<name>"
TokenLookup string
// AuthScheme to be used in the Authorization header.
// Optional. Default: "Bearer".
AuthScheme string
// KeyFunc is a function that supplies the public key for JWT cryptographic verification.
// The function shall take care of verifying the signing algorithm and selecting the proper key.
// Internally, github.com/MicahParks/keyfunc/v2 package is used project defaults. If you need more customization,
// you can provide a jwt.Keyfunc using that package or make your own implementation.
//
// At least one of the following is required: KeyFunc, JWKSetURLs, SigningKeys, or SigningKey.
// The order of precedence is: KeyFunc, JWKSetURLs, SigningKeys, SigningKey.
KeyFunc jwt.Keyfunc
// JWKSetURLs is a slice of HTTP URLs that contain the JSON Web Key Set (JWKS) used to verify the signatures of
// JWTs. Use of HTTPS is recommended. The presence of the "kid" field in the JWT header and JWKs is mandatory for
// this feature.
//
// By default, all JWK Sets in this slice will:
// * Refresh every hour.
// * Refresh automatically if a new "kid" is seen in a JWT being verified.
// * Rate limit refreshes to once every 5 minutes.
// * Timeout refreshes after 10 seconds.
//
// At least one of the following is required: KeyFunc, JWKSetURLs, SigningKeys, or SigningKey.
// The order of precedence is: KeyFunc, JWKSetURLs, SigningKeys, SigningKey.
JWKSetURLs []string
}
// SigningKey holds information about the recognized cryptographic keys used to sign JWTs by this program.
type SigningKey struct {
// JWTAlg is the algorithm used to sign JWTs. If this value is a non-empty string, this will be checked against the
// "alg" value in the JWT header.
//
// https://www.rfc-editor.org/rfc/rfc7518#section-3.1
JWTAlg string
// Key is the cryptographic key used to sign JWTs. For supported types, please see
// https://github.com/golang-jwt/jwt.
Key interface{}
}
// makeCfg function will check correctness of supplied configuration
// and will complement it with default values instead of missing ones
func makeCfg(config []Config) (cfg Config) {
if len(config) > 0 {
cfg = config[0]
}
if cfg.SuccessHandler == nil {
cfg.SuccessHandler = func(c *fiber.Ctx) error {
return c.Next()
}
}
if cfg.ErrorHandler == nil {
cfg.ErrorHandler = func(c *fiber.Ctx, err error) error {
if err.Error() == ErrJWTMissingOrMalformed.Error() {
return c.Status(fiber.StatusBadRequest).SendString(ErrJWTMissingOrMalformed.Error())
}
return c.Status(fiber.StatusUnauthorized).SendString("Invalid or expired JWT")
}
}
if cfg.SigningKey.Key == nil && len(cfg.SigningKeys) == 0 && len(cfg.JWKSetURLs) == 0 && cfg.KeyFunc == nil {
panic("Fiber: JWT middleware configuration: At least one of the following is required: KeyFunc, JWKSetURLs, SigningKeys, or SigningKey.")
}
if cfg.ContextKey == "" {
cfg.ContextKey = "user"
}
if cfg.Claims == nil {
cfg.Claims = jwt.MapClaims{}
}
if cfg.TokenLookup == "" {
cfg.TokenLookup = defaultTokenLookup
// set AuthScheme as "Bearer" only if TokenLookup is set to default.
if cfg.AuthScheme == "" {
cfg.AuthScheme = "Bearer"
}
}
if cfg.KeyFunc == nil {
if len(cfg.SigningKeys) > 0 || len(cfg.JWKSetURLs) > 0 {
var givenKeys map[string]keyfunc.GivenKey
if cfg.SigningKeys != nil {
givenKeys = make(map[string]keyfunc.GivenKey, len(cfg.SigningKeys))
for kid, key := range cfg.SigningKeys {
givenKeys[kid] = keyfunc.NewGivenCustom(key.Key, keyfunc.GivenKeyOptions{
Algorithm: key.JWTAlg,
})
}
}
if len(cfg.JWKSetURLs) > 0 {
var err error
cfg.KeyFunc, err = multiKeyfunc(givenKeys, cfg.JWKSetURLs)
if err != nil {
panic("Failed to create keyfunc from JWK Set URL: " + err.Error())
}
} else {
cfg.KeyFunc = keyfunc.NewGiven(givenKeys).Keyfunc
}
} else {
cfg.KeyFunc = signingKeyFunc(cfg.SigningKey)
}
}
return cfg
}
func multiKeyfunc(givenKeys map[string]keyfunc.GivenKey, jwkSetURLs []string) (jwt.Keyfunc, error) {
opts := keyfuncOptions(givenKeys)
multiple := make(map[string]keyfunc.Options, len(jwkSetURLs))
for _, url := range jwkSetURLs {
multiple[url] = opts
}
multiOpts := keyfunc.MultipleOptions{
KeySelector: keyfunc.KeySelectorFirst,
}
multi, err := keyfunc.GetMultiple(multiple, multiOpts)
if err != nil {
return nil, fmt.Errorf("failed to get multiple JWK Set URLs: %w", err)
}
return multi.Keyfunc, nil
}
func keyfuncOptions(givenKeys map[string]keyfunc.GivenKey) keyfunc.Options {
return keyfunc.Options{
GivenKeys: givenKeys,
RefreshErrorHandler: func(err error) {
log.Printf("Failed to perform background refresh of JWK Set: %s.", err)
},
RefreshInterval: time.Hour,
RefreshRateLimit: time.Minute * 5,
RefreshTimeout: time.Second * 10,
RefreshUnknownKID: true,
}
}
// getExtractors function will create a slice of functions which will be used
// for token sarch and will perform extraction of the value
func (cfg *Config) getExtractors() []jwtExtractor {
// Initialize
extractors := make([]jwtExtractor, 0)
rootParts := strings.Split(cfg.TokenLookup, ",")
for _, rootPart := range rootParts {
parts := strings.Split(strings.TrimSpace(rootPart), ":")
switch parts[0] {
case "header":
extractors = append(extractors, jwtFromHeader(parts[1], cfg.AuthScheme))
case "query":
extractors = append(extractors, jwtFromQuery(parts[1]))
case "param":
extractors = append(extractors, jwtFromParam(parts[1]))
case "cookie":
extractors = append(extractors, jwtFromCookie(parts[1]))
}
}
return extractors
}
func signingKeyFunc(key SigningKey) jwt.Keyfunc {
return func(token *jwt.Token) (interface{}, error) {
if key.JWTAlg != "" {
alg, ok := token.Header["alg"].(string)
if !ok {
return nil, fmt.Errorf("unexpected jwt signing method: expected: %q: got: missing or unexpected JSON type", key.JWTAlg)
}
if alg != key.JWTAlg {
return nil, fmt.Errorf("unexpected jwt signing method: expected: %q: got: %q", key.JWTAlg, alg)
}
}
return key.Key, nil
}
}
+48
View File
@@ -0,0 +1,48 @@
package jwtware
const (
// HS256 represents a public cryptography key generated by a 256 bit HMAC algorithm.
HS256 = "HS256"
// HS384 represents a public cryptography key generated by a 384 bit HMAC algorithm.
HS384 = "HS384"
// HS512 represents a public cryptography key generated by a 512 bit HMAC algorithm.
HS512 = "HS512"
// ES256 represents a public cryptography key generated by a 256 bit ECDSA algorithm.
ES256 = "ES256"
// ES384 represents a public cryptography key generated by a 384 bit ECDSA algorithm.
ES384 = "ES384"
// ES512 represents a public cryptography key generated by a 512 bit ECDSA algorithm.
ES512 = "ES512"
// P256 represents a cryptographic elliptical curve type.
P256 = "P-256"
// P384 represents a cryptographic elliptical curve type.
P384 = "P-384"
// P521 represents a cryptographic elliptical curve type.
P521 = "P-521"
// RS256 represents a public cryptography key generated by a 256 bit RSA algorithm.
RS256 = "RS256"
// RS384 represents a public cryptography key generated by a 384 bit RSA algorithm.
RS384 = "RS384"
// RS512 represents a public cryptography key generated by a 512 bit RSA algorithm.
RS512 = "RS512"
// PS256 represents a public cryptography key generated by a 256 bit RSA algorithm.
PS256 = "PS256"
// PS384 represents a public cryptography key generated by a 384 bit RSA algorithm.
PS384 = "PS384"
// PS512 represents a public cryptography key generated by a 512 bit RSA algorithm.
PS512 = "PS512"
)
+59
View File
@@ -0,0 +1,59 @@
// 🚀 Fiber is an Express inspired web framework written in Go with 💖
// 📌 API Documentation: https://fiber.wiki
// 📝 Github Repository: https://github.com/gofiber/fiber
// Special thanks to Echo: https://github.com/labstack/echo/blob/master/middleware/jwt.go
package jwtware
import (
"reflect"
"github.com/gofiber/fiber/v2"
"github.com/golang-jwt/jwt/v5"
)
var (
defaultTokenLookup = "header:" + fiber.HeaderAuthorization
)
// New ...
func New(config ...Config) fiber.Handler {
cfg := makeCfg(config)
extractors := cfg.getExtractors()
// Return middleware handler
return func(c *fiber.Ctx) error {
// Filter request to skip middleware
if cfg.Filter != nil && cfg.Filter(c) {
return c.Next()
}
var auth string
var err error
for _, extractor := range extractors {
auth, err = extractor(c)
if auth != "" && err == nil {
break
}
}
if err != nil {
return cfg.ErrorHandler(c, err)
}
var token *jwt.Token
if _, ok := cfg.Claims.(jwt.MapClaims); ok {
token, err = jwt.Parse(auth, cfg.KeyFunc)
} else {
t := reflect.ValueOf(cfg.Claims).Type().Elem()
claims := reflect.New(t).Interface().(jwt.Claims)
token, err = jwt.ParseWithClaims(auth, claims, cfg.KeyFunc)
}
if err == nil && token.Valid {
// Store user information from token into context.
c.Locals(cfg.ContextKey, token)
return cfg.SuccessHandler(c)
}
return cfg.ErrorHandler(c, err)
}
}
+60
View File
@@ -0,0 +1,60 @@
package jwtware
import (
"errors"
"strings"
"github.com/gofiber/fiber/v2"
)
var (
// ErrJWTMissingOrMalformed is returned when the JWT is missing or malformed.
ErrJWTMissingOrMalformed = errors.New("missing or malformed JWT")
)
type jwtExtractor func(c *fiber.Ctx) (string, error)
// jwtFromHeader returns a function that extracts token from the request header.
func jwtFromHeader(header string, authScheme string) func(c *fiber.Ctx) (string, error) {
return func(c *fiber.Ctx) (string, error) {
auth := c.Get(header)
l := len(authScheme)
if len(auth) > l+1 && strings.EqualFold(auth[:l], authScheme) {
return strings.TrimSpace(auth[l:]), nil
}
return "", ErrJWTMissingOrMalformed
}
}
// jwtFromQuery returns a function that extracts token from the query string.
func jwtFromQuery(param string) func(c *fiber.Ctx) (string, error) {
return func(c *fiber.Ctx) (string, error) {
token := c.Query(param)
if token == "" {
return "", ErrJWTMissingOrMalformed
}
return token, nil
}
}
// jwtFromParam returns a function that extracts token from the url param string.
func jwtFromParam(param string) func(c *fiber.Ctx) (string, error) {
return func(c *fiber.Ctx) (string, error) {
token := c.Params(param)
if token == "" {
return "", ErrJWTMissingOrMalformed
}
return token, nil
}
}
// jwtFromCookie returns a function that extracts token from the named cookie.
func jwtFromCookie(name string) func(c *fiber.Ctx) (string, error) {
return func(c *fiber.Ctx) (string, error) {
token := c.Cookies(name)
if token == "" {
return "", ErrJWTMissingOrMalformed
}
return token, nil
}
}
+8
View File
@@ -0,0 +1,8 @@
; This file is for unifying the coding style for different editors and IDEs.
; More information at http://editorconfig.org
; This style originates from https://github.com/fewagency/best-practices
root = true
[*]
charset = utf-8
end_of_line = lf
+12
View File
@@ -0,0 +1,12 @@
# Handle line endings automatically for files detected as text
# and leave all files detected as binary untouched.
* text=auto eol=lf
# Force batch scripts to always use CRLF line endings so that if a repo is accessed
# in Windows via a file share from Linux, the scripts will work.
*.{cmd,[cC][mM][dD]} text eol=crlf
*.{bat,[bB][aA][tT]} text eol=crlf
# Force bash scripts to always use LF line endings so that if a repo is accessed
# in Unix via a file share from Windows, the scripts will work.
*.sh text eol=lf
+30
View File
@@ -0,0 +1,30 @@
# Binaries for programs and plugins
*.exe
*.exe~
*.dll
*.so
*.dylib
# Test binary, built with `go test -c`
*.test
*.tmp
# Output of the go coverage tool, specifically when used with LiteIDE
*.out
# IDE files
.vscode
.DS_Store
.idea
# Misc
*.fiber.gz
*.fasthttp.gz
*.pprof
*.workspace
# Dependencies
/vendor/
vendor/
vendor
/Godeps/
+197
View File
@@ -0,0 +1,197 @@
# Created based on v1.51.0
# NOTE: Keep this in sync with the version in .github/workflows/linter.yml
run:
modules-download-mode: readonly
skip-dirs-use-default: false
skip-dirs:
- internal
output:
sort-results: true
linters-settings:
errcheck:
check-type-assertions: true
check-blank: true
disable-default-exclusions: true
errchkjson:
report-no-exported: true
exhaustive:
default-signifies-exhaustive: true
forbidigo:
forbid:
- ^(fmt\.Print(|f|ln)|print|println)$
- 'http\.Default(Client|Transport)'
# TODO: Eventually enable these patterns
# - 'time\.Sleep'
# - 'panic'
gocritic:
disabled-checks:
- ifElseChain
gofumpt:
module-path: github.com/gofiber/fiber
extra-rules: true
gosec:
config:
global:
audit: true
govet:
check-shadowing: true
enable-all: true
disable:
- shadow
- fieldalignment
- loopclosure
grouper:
import-require-single-import: true
import-require-grouping: true
misspell:
locale: US
nolintlint:
require-explanation: true
require-specific: true
nonamedreturns:
report-error-in-defer: true
predeclared:
q: true
promlinter:
strict: true
revive:
enable-all-rules: true
rules:
# Provided by gomnd linter
- name: add-constant
disabled: true
- name: argument-limit
disabled: true
# Provided by bidichk
- name: banned-characters
disabled: true
- name: cognitive-complexity
disabled: true
- name: cyclomatic
disabled: true
- name: early-return
severity: warning
disabled: true
- name: exported
disabled: true
- name: file-header
disabled: true
- name: function-result-limit
disabled: true
- name: function-length
disabled: true
- name: line-length-limit
disabled: true
- name: max-public-structs
disabled: true
- name: modifies-parameter
disabled: true
- name: nested-structs
disabled: true
- name: package-comments
disabled: true
stylecheck:
checks:
- all
- -ST1000
- -ST1020
- -ST1021
- -ST1022
tagliatelle:
case:
rules:
json: snake
#tenv:
# all: true
#unparam:
# check-exported: true
wrapcheck:
ignorePackageGlobs:
- github.com/gofiber/fiber/*
- github.com/valyala/fasthttp
issues:
exclude-use-default: false
linters:
enable:
- asasalint
- asciicheck
- bidichk
- bodyclose
- containedctx
- contextcheck
- depguard
- dogsled
- durationcheck
- errcheck
- errchkjson
- errname
- errorlint
- execinquery
- exhaustive
- exportloopref
- forbidigo
- forcetypeassert
- goconst
- gocritic
- gofmt
- gofumpt
- goimports
- gomoddirectives
- goprintffuncname
- gosec
- gosimple
- govet
- grouper
- loggercheck
- misspell
- nakedret
- nilerr
- nilnil
- noctx
- nolintlint
- nonamedreturns
- nosprintfhostport
- predeclared
- promlinter
- reassign
- revive
- rowserrcheck
- sqlclosecheck
- staticcheck
- stylecheck
- tagliatelle
# - testpackage # TODO: Enable once https://github.com/gofiber/fiber/issues/2252 is implemented
- thelper
# - tparallel # TODO: Enable once https://github.com/gofiber/fiber/issues/2254 is implemented
- typecheck
- unconvert
- unparam
- unused
- usestdlibvars
- wastedassign
- whitespace
- wrapcheck
+21
View File
@@ -0,0 +1,21 @@
MIT License
Copyright (c) 2019-present Fenny and Contributors
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
+1120
View File
File diff suppressed because it is too large Load Diff
+1021
View File
File diff suppressed because it is too large Load Diff
+107
View File
@@ -0,0 +1,107 @@
// ⚡️ Fiber is an Express inspired web framework written in Go with ☕️
// 🤖 Github Repository: https://github.com/gofiber/fiber
// 📌 API Documentation: https://docs.gofiber.io
package fiber
// Colors is a struct to define custom colors for Fiber app and middlewares.
type Colors struct {
// Black color.
//
// Optional. Default: "\u001b[90m"
Black string
// Red color.
//
// Optional. Default: "\u001b[91m"
Red string
// Green color.
//
// Optional. Default: "\u001b[92m"
Green string
// Yellow color.
//
// Optional. Default: "\u001b[93m"
Yellow string
// Blue color.
//
// Optional. Default: "\u001b[94m"
Blue string
// Magenta color.
//
// Optional. Default: "\u001b[95m"
Magenta string
// Cyan color.
//
// Optional. Default: "\u001b[96m"
Cyan string
// White color.
//
// Optional. Default: "\u001b[97m"
White string
// Reset color.
//
// Optional. Default: "\u001b[0m"
Reset string
}
// DefaultColors Default color codes
var DefaultColors = Colors{
Black: "\u001b[90m",
Red: "\u001b[91m",
Green: "\u001b[92m",
Yellow: "\u001b[93m",
Blue: "\u001b[94m",
Magenta: "\u001b[95m",
Cyan: "\u001b[96m",
White: "\u001b[97m",
Reset: "\u001b[0m",
}
// defaultColors is a function to override default colors to config
func defaultColors(colors Colors) Colors {
if colors.Black == "" {
colors.Black = DefaultColors.Black
}
if colors.Red == "" {
colors.Red = DefaultColors.Red
}
if colors.Green == "" {
colors.Green = DefaultColors.Green
}
if colors.Yellow == "" {
colors.Yellow = DefaultColors.Yellow
}
if colors.Blue == "" {
colors.Blue = DefaultColors.Blue
}
if colors.Magenta == "" {
colors.Magenta = DefaultColors.Magenta
}
if colors.Cyan == "" {
colors.Cyan = DefaultColors.Cyan
}
if colors.White == "" {
colors.White = DefaultColors.White
}
if colors.Reset == "" {
colors.Reset = DefaultColors.Reset
}
return colors
}
+1993
View File
File diff suppressed because it is too large Load Diff
+40
View File
@@ -0,0 +1,40 @@
package fiber
import (
errors "encoding/json"
"github.com/gofiber/fiber/v2/internal/schema"
)
type (
// ConversionError Conversion error exposes the internal schema.ConversionError for public use.
ConversionError = schema.ConversionError
// UnknownKeyError error exposes the internal schema.UnknownKeyError for public use.
UnknownKeyError = schema.UnknownKeyError
// EmptyFieldError error exposes the internal schema.EmptyFieldError for public use.
EmptyFieldError = schema.EmptyFieldError
// MultiError error exposes the internal schema.MultiError for public use.
MultiError = schema.MultiError
)
type (
// An InvalidUnmarshalError describes an invalid argument passed to Unmarshal.
// (The argument to Unmarshal must be a non-nil pointer.)
InvalidUnmarshalError = errors.InvalidUnmarshalError
// A MarshalerError represents an error from calling a MarshalJSON or MarshalText method.
MarshalerError = errors.MarshalerError
// A SyntaxError is a description of a JSON syntax error.
SyntaxError = errors.SyntaxError
// An UnmarshalTypeError describes a JSON value that was
// not appropriate for a value of a specific Go type.
UnmarshalTypeError = errors.UnmarshalTypeError
// An UnsupportedTypeError is returned by Marshal when attempting
// to encode an unsupported value type.
UnsupportedTypeError = errors.UnsupportedTypeError
UnsupportedValueError = errors.UnsupportedValueError
)
+209
View File
@@ -0,0 +1,209 @@
// ⚡️ Fiber is an Express inspired web framework written in Go with ☕️
// 🤖 Github Repository: https://github.com/gofiber/fiber
// 📌 API Documentation: https://docs.gofiber.io
package fiber
import (
"fmt"
"reflect"
)
// Group struct
type Group struct {
app *App
parentGroup *Group
name string
anyRouteDefined bool
Prefix string
}
// Name Assign name to specific route or group itself.
//
// If this method is used before any route added to group, it'll set group name and OnGroupNameHook will be used.
// Otherwise, it'll set route name and OnName hook will be used.
func (grp *Group) Name(name string) Router {
if grp.anyRouteDefined {
grp.app.Name(name)
return grp
}
grp.app.mutex.Lock()
if grp.parentGroup != nil {
grp.name = grp.parentGroup.name + name
} else {
grp.name = name
}
if err := grp.app.hooks.executeOnGroupNameHooks(*grp); err != nil {
panic(err)
}
grp.app.mutex.Unlock()
return grp
}
// Use registers a middleware route that will match requests
// with the provided prefix (which is optional and defaults to "/").
//
// app.Use(func(c *fiber.Ctx) error {
// return c.Next()
// })
// app.Use("/api", func(c *fiber.Ctx) error {
// return c.Next()
// })
// app.Use("/api", handler, func(c *fiber.Ctx) error {
// return c.Next()
// })
//
// This method will match all HTTP verbs: GET, POST, PUT, HEAD etc...
func (grp *Group) Use(args ...interface{}) Router {
var prefix string
var prefixes []string
var handlers []Handler
for i := 0; i < len(args); i++ {
switch arg := args[i].(type) {
case string:
prefix = arg
case []string:
prefixes = arg
case Handler:
handlers = append(handlers, arg)
default:
panic(fmt.Sprintf("use: invalid handler %v\n", reflect.TypeOf(arg)))
}
}
if len(prefixes) == 0 {
prefixes = append(prefixes, prefix)
}
for _, prefix := range prefixes {
grp.app.register(methodUse, getGroupPath(grp.Prefix, prefix), grp, handlers...)
}
if !grp.anyRouteDefined {
grp.anyRouteDefined = true
}
return grp
}
// Get registers a route for GET methods that requests a representation
// of the specified resource. Requests using GET should only retrieve data.
func (grp *Group) Get(path string, handlers ...Handler) Router {
grp.Add(MethodHead, path, handlers...)
return grp.Add(MethodGet, path, handlers...)
}
// Head registers a route for HEAD methods that asks for a response identical
// to that of a GET request, but without the response body.
func (grp *Group) Head(path string, handlers ...Handler) Router {
return grp.Add(MethodHead, path, handlers...)
}
// Post registers a route for POST methods that is used to submit an entity to the
// specified resource, often causing a change in state or side effects on the server.
func (grp *Group) Post(path string, handlers ...Handler) Router {
return grp.Add(MethodPost, path, handlers...)
}
// Put registers a route for PUT methods that replaces all current representations
// of the target resource with the request payload.
func (grp *Group) Put(path string, handlers ...Handler) Router {
return grp.Add(MethodPut, path, handlers...)
}
// Delete registers a route for DELETE methods that deletes the specified resource.
func (grp *Group) Delete(path string, handlers ...Handler) Router {
return grp.Add(MethodDelete, path, handlers...)
}
// Connect registers a route for CONNECT methods that establishes a tunnel to the
// server identified by the target resource.
func (grp *Group) Connect(path string, handlers ...Handler) Router {
return grp.Add(MethodConnect, path, handlers...)
}
// Options registers a route for OPTIONS methods that is used to describe the
// communication options for the target resource.
func (grp *Group) Options(path string, handlers ...Handler) Router {
return grp.Add(MethodOptions, path, handlers...)
}
// Trace registers a route for TRACE methods that performs a message loop-back
// test along the path to the target resource.
func (grp *Group) Trace(path string, handlers ...Handler) Router {
return grp.Add(MethodTrace, path, handlers...)
}
// Patch registers a route for PATCH methods that is used to apply partial
// modifications to a resource.
func (grp *Group) Patch(path string, handlers ...Handler) Router {
return grp.Add(MethodPatch, path, handlers...)
}
// Add allows you to specify a HTTP method to register a route
func (grp *Group) Add(method, path string, handlers ...Handler) Router {
grp.app.register(method, getGroupPath(grp.Prefix, path), grp, handlers...)
if !grp.anyRouteDefined {
grp.anyRouteDefined = true
}
return grp
}
// Static will create a file server serving static files
func (grp *Group) Static(prefix, root string, config ...Static) Router {
grp.app.registerStatic(getGroupPath(grp.Prefix, prefix), root, config...)
if !grp.anyRouteDefined {
grp.anyRouteDefined = true
}
return grp
}
// All will register the handler on all HTTP methods
func (grp *Group) All(path string, handlers ...Handler) Router {
for _, method := range grp.app.config.RequestMethods {
_ = grp.Add(method, path, handlers...)
}
return grp
}
// Group is used for Routes with common prefix to define a new sub-router with optional middleware.
//
// api := app.Group("/api")
// api.Get("/users", handler)
func (grp *Group) Group(prefix string, handlers ...Handler) Router {
prefix = getGroupPath(grp.Prefix, prefix)
if len(handlers) > 0 {
grp.app.register(methodUse, prefix, grp, handlers...)
}
// Create new group
newGrp := &Group{Prefix: prefix, app: grp.app, parentGroup: grp}
if err := grp.app.hooks.executeOnGroupHooks(*newGrp); err != nil {
panic(err)
}
return newGrp
}
// Route is used to define routes with a common prefix inside the common function.
// Uses Group method to define new sub-router.
func (grp *Group) Route(prefix string, fn func(router Router), name ...string) Router {
// Create new group
group := grp.Group(prefix)
if len(name) > 0 {
group.Name(name[0])
}
// Define routes
fn(group)
return group
}
+1153
View File
File diff suppressed because it is too large Load Diff
+218
View File
@@ -0,0 +1,218 @@
package fiber
import (
"github.com/gofiber/fiber/v2/log"
)
// OnRouteHandler Handlers define a function to create hooks for Fiber.
type (
OnRouteHandler = func(Route) error
OnNameHandler = OnRouteHandler
OnGroupHandler = func(Group) error
OnGroupNameHandler = OnGroupHandler
OnListenHandler = func(ListenData) error
OnShutdownHandler = func() error
OnForkHandler = func(int) error
OnMountHandler = func(*App) error
)
// Hooks is a struct to use it with App.
type Hooks struct {
// Embed app
app *App
// Hooks
onRoute []OnRouteHandler
onName []OnNameHandler
onGroup []OnGroupHandler
onGroupName []OnGroupNameHandler
onListen []OnListenHandler
onShutdown []OnShutdownHandler
onFork []OnForkHandler
onMount []OnMountHandler
}
// ListenData is a struct to use it with OnListenHandler
type ListenData struct {
Host string
Port string
TLS bool
}
func newHooks(app *App) *Hooks {
return &Hooks{
app: app,
onRoute: make([]OnRouteHandler, 0),
onGroup: make([]OnGroupHandler, 0),
onGroupName: make([]OnGroupNameHandler, 0),
onName: make([]OnNameHandler, 0),
onListen: make([]OnListenHandler, 0),
onShutdown: make([]OnShutdownHandler, 0),
onFork: make([]OnForkHandler, 0),
onMount: make([]OnMountHandler, 0),
}
}
// OnRoute is a hook to execute user functions on each route registeration.
// Also you can get route properties by route parameter.
func (h *Hooks) OnRoute(handler ...OnRouteHandler) {
h.app.mutex.Lock()
h.onRoute = append(h.onRoute, handler...)
h.app.mutex.Unlock()
}
// OnName is a hook to execute user functions on each route naming.
// Also you can get route properties by route parameter.
//
// WARN: OnName only works with naming routes, not groups.
func (h *Hooks) OnName(handler ...OnNameHandler) {
h.app.mutex.Lock()
h.onName = append(h.onName, handler...)
h.app.mutex.Unlock()
}
// OnGroup is a hook to execute user functions on each group registeration.
// Also you can get group properties by group parameter.
func (h *Hooks) OnGroup(handler ...OnGroupHandler) {
h.app.mutex.Lock()
h.onGroup = append(h.onGroup, handler...)
h.app.mutex.Unlock()
}
// OnGroupName is a hook to execute user functions on each group naming.
// Also you can get group properties by group parameter.
//
// WARN: OnGroupName only works with naming groups, not routes.
func (h *Hooks) OnGroupName(handler ...OnGroupNameHandler) {
h.app.mutex.Lock()
h.onGroupName = append(h.onGroupName, handler...)
h.app.mutex.Unlock()
}
// OnListen is a hook to execute user functions on Listen, ListenTLS, Listener.
func (h *Hooks) OnListen(handler ...OnListenHandler) {
h.app.mutex.Lock()
h.onListen = append(h.onListen, handler...)
h.app.mutex.Unlock()
}
// OnShutdown is a hook to execute user functions after Shutdown.
func (h *Hooks) OnShutdown(handler ...OnShutdownHandler) {
h.app.mutex.Lock()
h.onShutdown = append(h.onShutdown, handler...)
h.app.mutex.Unlock()
}
// OnFork is a hook to execute user function after fork process.
func (h *Hooks) OnFork(handler ...OnForkHandler) {
h.app.mutex.Lock()
h.onFork = append(h.onFork, handler...)
h.app.mutex.Unlock()
}
// OnMount is a hook to execute user function after mounting process.
// The mount event is fired when sub-app is mounted on a parent app. The parent app is passed as a parameter.
// It works for app and group mounting.
func (h *Hooks) OnMount(handler ...OnMountHandler) {
h.app.mutex.Lock()
h.onMount = append(h.onMount, handler...)
h.app.mutex.Unlock()
}
func (h *Hooks) executeOnRouteHooks(route Route) error {
// Check mounting
if h.app.mountFields.mountPath != "" {
route.path = h.app.mountFields.mountPath + route.path
route.Path = route.path
}
for _, v := range h.onRoute {
if err := v(route); err != nil {
return err
}
}
return nil
}
func (h *Hooks) executeOnNameHooks(route Route) error {
// Check mounting
if h.app.mountFields.mountPath != "" {
route.path = h.app.mountFields.mountPath + route.path
route.Path = route.path
}
for _, v := range h.onName {
if err := v(route); err != nil {
return err
}
}
return nil
}
func (h *Hooks) executeOnGroupHooks(group Group) error {
// Check mounting
if h.app.mountFields.mountPath != "" {
group.Prefix = h.app.mountFields.mountPath + group.Prefix
}
for _, v := range h.onGroup {
if err := v(group); err != nil {
return err
}
}
return nil
}
func (h *Hooks) executeOnGroupNameHooks(group Group) error {
// Check mounting
if h.app.mountFields.mountPath != "" {
group.Prefix = h.app.mountFields.mountPath + group.Prefix
}
for _, v := range h.onGroupName {
if err := v(group); err != nil {
return err
}
}
return nil
}
func (h *Hooks) executeOnListenHooks(listenData ListenData) error {
for _, v := range h.onListen {
if err := v(listenData); err != nil {
return err
}
}
return nil
}
func (h *Hooks) executeOnShutdownHooks() {
for _, v := range h.onShutdown {
if err := v(); err != nil {
log.Errorf("failed to call shutdown hook: %v", err)
}
}
}
func (h *Hooks) executeOnForkHooks(pid int) {
for _, v := range h.onFork {
if err := v(pid); err != nil {
log.Errorf("failed to call fork hook: %v", err)
}
}
}
func (h *Hooks) executeOnMountHooks(app *App) error {
for _, v := range h.onMount {
if err := v(app); err != nil {
return err
}
}
return nil
}
+97
View File
@@ -0,0 +1,97 @@
// Package memory Is a slight copy of the memory storage, but far from the storage interface it can not only work with bytes
// but directly store any kind of data without having to encode it each time, which gives a huge speed advantage
package memory
import (
"sync"
"sync/atomic"
"time"
"github.com/gofiber/fiber/v2/utils"
)
type Storage struct {
sync.RWMutex
data map[string]item // data
}
type item struct {
// max value is 4294967295 -> Sun Feb 07 2106 06:28:15 GMT+0000
e uint32 // exp
v interface{} // val
}
func New() *Storage {
store := &Storage{
data: make(map[string]item),
}
utils.StartTimeStampUpdater()
go store.gc(1 * time.Second)
return store
}
// Get value by key
func (s *Storage) Get(key string) interface{} {
s.RLock()
v, ok := s.data[key]
s.RUnlock()
if !ok || v.e != 0 && v.e <= atomic.LoadUint32(&utils.Timestamp) {
return nil
}
return v.v
}
// Set key with value
func (s *Storage) Set(key string, val interface{}, ttl time.Duration) {
var exp uint32
if ttl > 0 {
exp = uint32(ttl.Seconds()) + atomic.LoadUint32(&utils.Timestamp)
}
i := item{exp, val}
s.Lock()
s.data[key] = i
s.Unlock()
}
// Delete key by key
func (s *Storage) Delete(key string) {
s.Lock()
delete(s.data, key)
s.Unlock()
}
// Reset all keys
func (s *Storage) Reset() {
nd := make(map[string]item)
s.Lock()
s.data = nd
s.Unlock()
}
func (s *Storage) gc(sleep time.Duration) {
ticker := time.NewTicker(sleep)
defer ticker.Stop()
var expired []string
for range ticker.C {
ts := atomic.LoadUint32(&utils.Timestamp)
expired = expired[:0]
s.RLock()
for key, v := range s.data {
if v.e != 0 && v.e <= ts {
expired = append(expired, key)
}
}
s.RUnlock()
s.Lock()
// Double-checked locking.
// We might have replaced the item in the meantime.
for i := range expired {
v := s.data[expired[i]]
if v.e != 0 && v.e <= ts {
delete(s.data, expired[i])
}
}
s.Unlock()
}
}
+27
View File
@@ -0,0 +1,27 @@
Copyright (c) 2012 Rodrigo Moraes. All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are
met:
* Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
* Redistributions in binary form must reproduce the above
copyright notice, this list of conditions and the following disclaimer
in the documentation and/or other materials provided with the
distribution.
* Neither the name of Google Inc. nor the names of its
contributors may be used to endorse or promote products derived from
this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+305
View File
@@ -0,0 +1,305 @@
// Copyright 2012 The Gorilla Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package schema
import (
"errors"
"reflect"
"strconv"
"strings"
"sync"
)
var errInvalidPath = errors.New("schema: invalid path")
// newCache returns a new cache.
func newCache() *cache {
c := cache{
m: make(map[reflect.Type]*structInfo),
regconv: make(map[reflect.Type]Converter),
tag: "schema",
}
return &c
}
// cache caches meta-data about a struct.
type cache struct {
l sync.RWMutex
m map[reflect.Type]*structInfo
regconv map[reflect.Type]Converter
tag string
}
// registerConverter registers a converter function for a custom type.
func (c *cache) registerConverter(value interface{}, converterFunc Converter) {
c.regconv[reflect.TypeOf(value)] = converterFunc
}
// parsePath parses a path in dotted notation verifying that it is a valid
// path to a struct field.
//
// It returns "path parts" which contain indices to fields to be used by
// reflect.Value.FieldByString(). Multiple parts are required for slices of
// structs.
func (c *cache) parsePath(p string, t reflect.Type) ([]pathPart, error) {
var struc *structInfo
var field *fieldInfo
var index64 int64
var err error
parts := make([]pathPart, 0)
path := make([]string, 0)
keys := strings.Split(p, ".")
for i := 0; i < len(keys); i++ {
if t.Kind() != reflect.Struct {
return nil, errInvalidPath
}
if struc = c.get(t); struc == nil {
return nil, errInvalidPath
}
if field = struc.get(keys[i]); field == nil {
return nil, errInvalidPath
}
// Valid field. Append index.
path = append(path, field.name)
if field.isSliceOfStructs && (!field.unmarshalerInfo.IsValid || (field.unmarshalerInfo.IsValid && field.unmarshalerInfo.IsSliceElement)) {
// Parse a special case: slices of structs.
// i+1 must be the slice index.
//
// Now that struct can implements TextUnmarshaler interface,
// we don't need to force the struct's fields to appear in the path.
// So checking i+2 is not necessary anymore.
i++
if i+1 > len(keys) {
return nil, errInvalidPath
}
if index64, err = strconv.ParseInt(keys[i], 10, 0); err != nil {
return nil, errInvalidPath
}
parts = append(parts, pathPart{
path: path,
field: field,
index: int(index64),
})
path = make([]string, 0)
// Get the next struct type, dropping ptrs.
if field.typ.Kind() == reflect.Ptr {
t = field.typ.Elem()
} else {
t = field.typ
}
if t.Kind() == reflect.Slice {
t = t.Elem()
if t.Kind() == reflect.Ptr {
t = t.Elem()
}
}
} else if field.typ.Kind() == reflect.Ptr {
t = field.typ.Elem()
} else {
t = field.typ
}
}
// Add the remaining.
parts = append(parts, pathPart{
path: path,
field: field,
index: -1,
})
return parts, nil
}
// get returns a cached structInfo, creating it if necessary.
func (c *cache) get(t reflect.Type) *structInfo {
c.l.RLock()
info := c.m[t]
c.l.RUnlock()
if info == nil {
info = c.create(t, "")
c.l.Lock()
c.m[t] = info
c.l.Unlock()
}
return info
}
// create creates a structInfo with meta-data about a struct.
func (c *cache) create(t reflect.Type, parentAlias string) *structInfo {
info := &structInfo{}
var anonymousInfos []*structInfo
for i := 0; i < t.NumField(); i++ {
if f := c.createField(t.Field(i), parentAlias); f != nil {
info.fields = append(info.fields, f)
if ft := indirectType(f.typ); ft.Kind() == reflect.Struct && f.isAnonymous {
anonymousInfos = append(anonymousInfos, c.create(ft, f.canonicalAlias))
}
}
}
for i, a := range anonymousInfos {
others := []*structInfo{info}
others = append(others, anonymousInfos[:i]...)
others = append(others, anonymousInfos[i+1:]...)
for _, f := range a.fields {
if !containsAlias(others, f.alias) {
info.fields = append(info.fields, f)
}
}
}
return info
}
// createField creates a fieldInfo for the given field.
func (c *cache) createField(field reflect.StructField, parentAlias string) *fieldInfo {
alias, options := fieldAlias(field, c.tag)
if alias == "-" {
// Ignore this field.
return nil
}
canonicalAlias := alias
if parentAlias != "" {
canonicalAlias = parentAlias + "." + alias
}
// Check if the type is supported and don't cache it if not.
// First let's get the basic type.
isSlice, isStruct := false, false
ft := field.Type
m := isTextUnmarshaler(reflect.Zero(ft))
if ft.Kind() == reflect.Ptr {
ft = ft.Elem()
}
if isSlice = ft.Kind() == reflect.Slice; isSlice {
ft = ft.Elem()
if ft.Kind() == reflect.Ptr {
ft = ft.Elem()
}
}
if ft.Kind() == reflect.Array {
ft = ft.Elem()
if ft.Kind() == reflect.Ptr {
ft = ft.Elem()
}
}
if isStruct = ft.Kind() == reflect.Struct; !isStruct {
if c.converter(ft) == nil && builtinConverters[ft.Kind()] == nil {
// Type is not supported.
return nil
}
}
return &fieldInfo{
typ: field.Type,
name: field.Name,
alias: alias,
canonicalAlias: canonicalAlias,
unmarshalerInfo: m,
isSliceOfStructs: isSlice && isStruct,
isAnonymous: field.Anonymous,
isRequired: options.Contains("required"),
}
}
// converter returns the converter for a type.
func (c *cache) converter(t reflect.Type) Converter {
return c.regconv[t]
}
// ----------------------------------------------------------------------------
type structInfo struct {
fields []*fieldInfo
}
func (i *structInfo) get(alias string) *fieldInfo {
for _, field := range i.fields {
if strings.EqualFold(field.alias, alias) {
return field
}
}
return nil
}
func containsAlias(infos []*structInfo, alias string) bool {
for _, info := range infos {
if info.get(alias) != nil {
return true
}
}
return false
}
type fieldInfo struct {
typ reflect.Type
// name is the field name in the struct.
name string
alias string
// canonicalAlias is almost the same as the alias, but is prefixed with
// an embedded struct field alias in dotted notation if this field is
// promoted from the struct.
// For instance, if the alias is "N" and this field is an embedded field
// in a struct "X", canonicalAlias will be "X.N".
canonicalAlias string
// unmarshalerInfo contains information regarding the
// encoding.TextUnmarshaler implementation of the field type.
unmarshalerInfo unmarshaler
// isSliceOfStructs indicates if the field type is a slice of structs.
isSliceOfStructs bool
// isAnonymous indicates whether the field is embedded in the struct.
isAnonymous bool
isRequired bool
}
func (f *fieldInfo) paths(prefix string) []string {
if f.alias == f.canonicalAlias {
return []string{prefix + f.alias}
}
return []string{prefix + f.alias, prefix + f.canonicalAlias}
}
type pathPart struct {
field *fieldInfo
path []string // path to the field: walks structs using field names.
index int // struct index in slices of structs.
}
// ----------------------------------------------------------------------------
func indirectType(typ reflect.Type) reflect.Type {
if typ.Kind() == reflect.Ptr {
return typ.Elem()
}
return typ
}
// fieldAlias parses a field tag to get a field alias.
func fieldAlias(field reflect.StructField, tagName string) (alias string, options tagOptions) {
if tag := field.Tag.Get(tagName); tag != "" {
alias, options = parseTag(tag)
}
if alias == "" {
alias = field.Name
}
return alias, options
}
// tagOptions is the string following a comma in a struct field's tag, or
// the empty string. It does not include the leading comma.
type tagOptions []string
// parseTag splits a struct field's url tag into its name and comma-separated
// options.
func parseTag(tag string) (string, tagOptions) {
s := strings.Split(tag, ",")
return s[0], s[1:]
}
// Contains checks whether the tagOptions contains the specified option.
func (o tagOptions) Contains(option string) bool {
for _, s := range o {
if s == option {
return true
}
}
return false
}
+145
View File
@@ -0,0 +1,145 @@
// Copyright 2012 The Gorilla Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package schema
import (
"reflect"
"strconv"
)
type Converter func(string) reflect.Value
var (
invalidValue = reflect.Value{}
boolType = reflect.Bool
float32Type = reflect.Float32
float64Type = reflect.Float64
intType = reflect.Int
int8Type = reflect.Int8
int16Type = reflect.Int16
int32Type = reflect.Int32
int64Type = reflect.Int64
stringType = reflect.String
uintType = reflect.Uint
uint8Type = reflect.Uint8
uint16Type = reflect.Uint16
uint32Type = reflect.Uint32
uint64Type = reflect.Uint64
)
// Default converters for basic types.
var builtinConverters = map[reflect.Kind]Converter{
boolType: convertBool,
float32Type: convertFloat32,
float64Type: convertFloat64,
intType: convertInt,
int8Type: convertInt8,
int16Type: convertInt16,
int32Type: convertInt32,
int64Type: convertInt64,
stringType: convertString,
uintType: convertUint,
uint8Type: convertUint8,
uint16Type: convertUint16,
uint32Type: convertUint32,
uint64Type: convertUint64,
}
func convertBool(value string) reflect.Value {
if value == "on" {
return reflect.ValueOf(true)
} else if v, err := strconv.ParseBool(value); err == nil {
return reflect.ValueOf(v)
}
return invalidValue
}
func convertFloat32(value string) reflect.Value {
if v, err := strconv.ParseFloat(value, 32); err == nil {
return reflect.ValueOf(float32(v))
}
return invalidValue
}
func convertFloat64(value string) reflect.Value {
if v, err := strconv.ParseFloat(value, 64); err == nil {
return reflect.ValueOf(v)
}
return invalidValue
}
func convertInt(value string) reflect.Value {
if v, err := strconv.ParseInt(value, 10, 0); err == nil {
return reflect.ValueOf(int(v))
}
return invalidValue
}
func convertInt8(value string) reflect.Value {
if v, err := strconv.ParseInt(value, 10, 8); err == nil {
return reflect.ValueOf(int8(v))
}
return invalidValue
}
func convertInt16(value string) reflect.Value {
if v, err := strconv.ParseInt(value, 10, 16); err == nil {
return reflect.ValueOf(int16(v))
}
return invalidValue
}
func convertInt32(value string) reflect.Value {
if v, err := strconv.ParseInt(value, 10, 32); err == nil {
return reflect.ValueOf(int32(v))
}
return invalidValue
}
func convertInt64(value string) reflect.Value {
if v, err := strconv.ParseInt(value, 10, 64); err == nil {
return reflect.ValueOf(v)
}
return invalidValue
}
func convertString(value string) reflect.Value {
return reflect.ValueOf(value)
}
func convertUint(value string) reflect.Value {
if v, err := strconv.ParseUint(value, 10, 0); err == nil {
return reflect.ValueOf(uint(v))
}
return invalidValue
}
func convertUint8(value string) reflect.Value {
if v, err := strconv.ParseUint(value, 10, 8); err == nil {
return reflect.ValueOf(uint8(v))
}
return invalidValue
}
func convertUint16(value string) reflect.Value {
if v, err := strconv.ParseUint(value, 10, 16); err == nil {
return reflect.ValueOf(uint16(v))
}
return invalidValue
}
func convertUint32(value string) reflect.Value {
if v, err := strconv.ParseUint(value, 10, 32); err == nil {
return reflect.ValueOf(uint32(v))
}
return invalidValue
}
func convertUint64(value string) reflect.Value {
if v, err := strconv.ParseUint(value, 10, 64); err == nil {
return reflect.ValueOf(v)
}
return invalidValue
}
+534
View File
@@ -0,0 +1,534 @@
// Copyright 2012 The Gorilla Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package schema
import (
"encoding"
"errors"
"fmt"
"reflect"
"strings"
)
// NewDecoder returns a new Decoder.
func NewDecoder() *Decoder {
return &Decoder{cache: newCache()}
}
// Decoder decodes values from a map[string][]string to a struct.
type Decoder struct {
cache *cache
zeroEmpty bool
ignoreUnknownKeys bool
}
// SetAliasTag changes the tag used to locate custom field aliases.
// The default tag is "schema".
func (d *Decoder) SetAliasTag(tag string) {
d.cache.tag = tag
}
// ZeroEmpty controls the behaviour when the decoder encounters empty values
// in a map.
// If z is true and a key in the map has the empty string as a value
// then the corresponding struct field is set to the zero value.
// If z is false then empty strings are ignored.
//
// The default value is false, that is empty values do not change
// the value of the struct field.
func (d *Decoder) ZeroEmpty(z bool) {
d.zeroEmpty = z
}
// IgnoreUnknownKeys controls the behaviour when the decoder encounters unknown
// keys in the map.
// If i is true and an unknown field is encountered, it is ignored. This is
// similar to how unknown keys are handled by encoding/json.
// If i is false then Decode will return an error. Note that any valid keys
// will still be decoded in to the target struct.
//
// To preserve backwards compatibility, the default value is false.
func (d *Decoder) IgnoreUnknownKeys(i bool) {
d.ignoreUnknownKeys = i
}
// RegisterConverter registers a converter function for a custom type.
func (d *Decoder) RegisterConverter(value interface{}, converterFunc Converter) {
d.cache.registerConverter(value, converterFunc)
}
// Decode decodes a map[string][]string to a struct.
//
// The first parameter must be a pointer to a struct.
//
// The second parameter is a map, typically url.Values from an HTTP request.
// Keys are "paths" in dotted notation to the struct fields and nested structs.
//
// See the package documentation for a full explanation of the mechanics.
func (d *Decoder) Decode(dst interface{}, src map[string][]string) error {
v := reflect.ValueOf(dst)
if v.Kind() != reflect.Ptr || v.Elem().Kind() != reflect.Struct {
return errors.New("schema: interface must be a pointer to struct")
}
v = v.Elem()
t := v.Type()
multiError := MultiError{}
for path, values := range src {
if parts, err := d.cache.parsePath(path, t); err == nil {
if err = d.decode(v, path, parts, values); err != nil {
multiError[path] = err
}
} else if !d.ignoreUnknownKeys {
multiError[path] = UnknownKeyError{Key: path}
}
}
multiError.merge(d.checkRequired(t, src))
if len(multiError) > 0 {
return multiError
}
return nil
}
// checkRequired checks whether required fields are empty
//
// check type t recursively if t has struct fields.
//
// src is the source map for decoding, we use it here to see if those required fields are included in src
func (d *Decoder) checkRequired(t reflect.Type, src map[string][]string) MultiError {
m, errs := d.findRequiredFields(t, "", "")
for key, fields := range m {
if isEmptyFields(fields, src) {
errs[key] = EmptyFieldError{Key: key}
}
}
return errs
}
// findRequiredFields recursively searches the struct type t for required fields.
//
// canonicalPrefix and searchPrefix are used to resolve full paths in dotted notation
// for nested struct fields. canonicalPrefix is a complete path which never omits
// any embedded struct fields. searchPrefix is a user-friendly path which may omit
// some embedded struct fields to point promoted fields.
func (d *Decoder) findRequiredFields(t reflect.Type, canonicalPrefix, searchPrefix string) (map[string][]fieldWithPrefix, MultiError) {
struc := d.cache.get(t)
if struc == nil {
// unexpect, cache.get never return nil
return nil, MultiError{canonicalPrefix + "*": errors.New("cache fail")}
}
m := map[string][]fieldWithPrefix{}
errs := MultiError{}
for _, f := range struc.fields {
if f.typ.Kind() == reflect.Struct {
fcprefix := canonicalPrefix + f.canonicalAlias + "."
for _, fspath := range f.paths(searchPrefix) {
fm, ferrs := d.findRequiredFields(f.typ, fcprefix, fspath+".")
for key, fields := range fm {
m[key] = append(m[key], fields...)
}
errs.merge(ferrs)
}
}
if f.isRequired {
key := canonicalPrefix + f.canonicalAlias
m[key] = append(m[key], fieldWithPrefix{
fieldInfo: f,
prefix: searchPrefix,
})
}
}
return m, errs
}
type fieldWithPrefix struct {
*fieldInfo
prefix string
}
// isEmptyFields returns true if all of specified fields are empty.
func isEmptyFields(fields []fieldWithPrefix, src map[string][]string) bool {
for _, f := range fields {
for _, path := range f.paths(f.prefix) {
v, ok := src[path]
if ok && !isEmpty(f.typ, v) {
return false
}
for key := range src {
// issue references:
// https://github.com/gofiber/fiber/issues/1414
// https://github.com/gorilla/schema/issues/176
nested := strings.IndexByte(key, '.') != -1
// for non required nested structs
c1 := strings.HasSuffix(f.prefix, ".") && key == path
// for required nested structs
c2 := f.prefix == "" && nested && strings.HasPrefix(key, path)
// for non nested fields
c3 := f.prefix == "" && !nested && key == path
if !isEmpty(f.typ, src[key]) && (c1 || c2 || c3) {
return false
}
}
}
}
return true
}
// isEmpty returns true if value is empty for specific type
func isEmpty(t reflect.Type, value []string) bool {
if len(value) == 0 {
return true
}
switch t.Kind() {
case boolType, float32Type, float64Type, intType, int8Type, int32Type, int64Type, stringType, uint8Type, uint16Type, uint32Type, uint64Type:
return len(value[0]) == 0
}
return false
}
// decode fills a struct field using a parsed path.
func (d *Decoder) decode(v reflect.Value, path string, parts []pathPart, values []string) error {
// Get the field walking the struct fields by index.
for _, name := range parts[0].path {
if v.Type().Kind() == reflect.Ptr {
if v.IsNil() {
v.Set(reflect.New(v.Type().Elem()))
}
v = v.Elem()
}
// alloc embedded structs
if v.Type().Kind() == reflect.Struct {
for i := 0; i < v.NumField(); i++ {
field := v.Field(i)
if field.Type().Kind() == reflect.Ptr && field.IsNil() && v.Type().Field(i).Anonymous {
field.Set(reflect.New(field.Type().Elem()))
}
}
}
v = v.FieldByName(name)
}
// Don't even bother for unexported fields.
if !v.CanSet() {
return nil
}
// Dereference if needed.
t := v.Type()
if t.Kind() == reflect.Ptr {
t = t.Elem()
if v.IsNil() {
v.Set(reflect.New(t))
}
v = v.Elem()
}
// Slice of structs. Let's go recursive.
if len(parts) > 1 {
idx := parts[0].index
if v.IsNil() || v.Len() < idx+1 {
value := reflect.MakeSlice(t, idx+1, idx+1)
if v.Len() < idx+1 {
// Resize it.
reflect.Copy(value, v)
}
v.Set(value)
}
return d.decode(v.Index(idx), path, parts[1:], values)
}
// Get the converter early in case there is one for a slice type.
conv := d.cache.converter(t)
m := isTextUnmarshaler(v)
if conv == nil && t.Kind() == reflect.Slice && m.IsSliceElement {
var items []reflect.Value
elemT := t.Elem()
isPtrElem := elemT.Kind() == reflect.Ptr
if isPtrElem {
elemT = elemT.Elem()
}
// Try to get a converter for the element type.
conv := d.cache.converter(elemT)
if conv == nil {
conv = builtinConverters[elemT.Kind()]
if conv == nil {
// As we are not dealing with slice of structs here, we don't need to check if the type
// implements TextUnmarshaler interface
return fmt.Errorf("schema: converter not found for %v", elemT)
}
}
for key, value := range values {
if value == "" {
if d.zeroEmpty {
items = append(items, reflect.Zero(elemT))
}
} else if m.IsValid {
u := reflect.New(elemT)
if m.IsSliceElementPtr {
u = reflect.New(reflect.PtrTo(elemT).Elem())
}
if err := u.Interface().(encoding.TextUnmarshaler).UnmarshalText([]byte(value)); err != nil {
return ConversionError{
Key: path,
Type: t,
Index: key,
Err: err,
}
}
if m.IsSliceElementPtr {
items = append(items, u.Elem().Addr())
} else if u.Kind() == reflect.Ptr {
items = append(items, u.Elem())
} else {
items = append(items, u)
}
} else if item := conv(value); item.IsValid() {
if isPtrElem {
ptr := reflect.New(elemT)
ptr.Elem().Set(item)
item = ptr
}
if item.Type() != elemT && !isPtrElem {
item = item.Convert(elemT)
}
items = append(items, item)
} else {
if strings.Contains(value, ",") {
values := strings.Split(value, ",")
for _, value := range values {
if value == "" {
if d.zeroEmpty {
items = append(items, reflect.Zero(elemT))
}
} else if item := conv(value); item.IsValid() {
if isPtrElem {
ptr := reflect.New(elemT)
ptr.Elem().Set(item)
item = ptr
}
if item.Type() != elemT && !isPtrElem {
item = item.Convert(elemT)
}
items = append(items, item)
} else {
return ConversionError{
Key: path,
Type: elemT,
Index: key,
}
}
}
} else {
return ConversionError{
Key: path,
Type: elemT,
Index: key,
}
}
}
}
value := reflect.Append(reflect.MakeSlice(t, 0, 0), items...)
v.Set(value)
} else {
val := ""
// Use the last value provided if any values were provided
if len(values) > 0 {
val = values[len(values)-1]
}
if conv != nil {
if value := conv(val); value.IsValid() {
v.Set(value.Convert(t))
} else {
return ConversionError{
Key: path,
Type: t,
Index: -1,
}
}
} else if m.IsValid {
if m.IsPtr {
u := reflect.New(v.Type())
if err := u.Interface().(encoding.TextUnmarshaler).UnmarshalText([]byte(val)); err != nil {
return ConversionError{
Key: path,
Type: t,
Index: -1,
Err: err,
}
}
v.Set(reflect.Indirect(u))
} else {
// If the value implements the encoding.TextUnmarshaler interface
// apply UnmarshalText as the converter
if err := m.Unmarshaler.UnmarshalText([]byte(val)); err != nil {
return ConversionError{
Key: path,
Type: t,
Index: -1,
Err: err,
}
}
}
} else if val == "" {
if d.zeroEmpty {
v.Set(reflect.Zero(t))
}
} else if conv := builtinConverters[t.Kind()]; conv != nil {
if value := conv(val); value.IsValid() {
v.Set(value.Convert(t))
} else {
return ConversionError{
Key: path,
Type: t,
Index: -1,
}
}
} else {
return fmt.Errorf("schema: converter not found for %v", t)
}
}
return nil
}
func isTextUnmarshaler(v reflect.Value) unmarshaler {
// Create a new unmarshaller instance
m := unmarshaler{}
if m.Unmarshaler, m.IsValid = v.Interface().(encoding.TextUnmarshaler); m.IsValid {
return m
}
// As the UnmarshalText function should be applied to the pointer of the
// type, we check that type to see if it implements the necessary
// method.
if m.Unmarshaler, m.IsValid = reflect.New(v.Type()).Interface().(encoding.TextUnmarshaler); m.IsValid {
m.IsPtr = true
return m
}
// if v is []T or *[]T create new T
t := v.Type()
if t.Kind() == reflect.Ptr {
t = t.Elem()
}
if t.Kind() == reflect.Slice {
// Check if the slice implements encoding.TextUnmarshaller
if m.Unmarshaler, m.IsValid = v.Interface().(encoding.TextUnmarshaler); m.IsValid {
return m
}
// If t is a pointer slice, check if its elements implement
// encoding.TextUnmarshaler
m.IsSliceElement = true
if t = t.Elem(); t.Kind() == reflect.Ptr {
t = reflect.PtrTo(t.Elem())
v = reflect.Zero(t)
m.IsSliceElementPtr = true
m.Unmarshaler, m.IsValid = v.Interface().(encoding.TextUnmarshaler)
return m
}
}
v = reflect.New(t)
m.Unmarshaler, m.IsValid = v.Interface().(encoding.TextUnmarshaler)
return m
}
// TextUnmarshaler helpers ----------------------------------------------------
// unmarshaller contains information about a TextUnmarshaler type
type unmarshaler struct {
Unmarshaler encoding.TextUnmarshaler
// IsValid indicates whether the resolved type indicated by the other
// flags implements the encoding.TextUnmarshaler interface.
IsValid bool
// IsPtr indicates that the resolved type is the pointer of the original
// type.
IsPtr bool
// IsSliceElement indicates that the resolved type is a slice element of
// the original type.
IsSliceElement bool
// IsSliceElementPtr indicates that the resolved type is a pointer to a
// slice element of the original type.
IsSliceElementPtr bool
}
// Errors ---------------------------------------------------------------------
// ConversionError stores information about a failed conversion.
type ConversionError struct {
Key string // key from the source map.
Type reflect.Type // expected type of elem
Index int // index for multi-value fields; -1 for single-value fields.
Err error // low-level error (when it exists)
}
func (e ConversionError) Error() string {
var output string
if e.Index < 0 {
output = fmt.Sprintf("schema: error converting value for %q", e.Key)
} else {
output = fmt.Sprintf("schema: error converting value for index %d of %q",
e.Index, e.Key)
}
if e.Err != nil {
output = fmt.Sprintf("%s. Details: %s", output, e.Err)
}
return output
}
// UnknownKeyError stores information about an unknown key in the source map.
type UnknownKeyError struct {
Key string // key from the source map.
}
func (e UnknownKeyError) Error() string {
return fmt.Sprintf("schema: invalid path %q", e.Key)
}
// EmptyFieldError stores information about an empty required field.
type EmptyFieldError struct {
Key string // required key in the source map.
}
func (e EmptyFieldError) Error() string {
return fmt.Sprintf("%v is empty", e.Key)
}
// MultiError stores multiple decoding errors.
//
// Borrowed from the App Engine SDK.
type MultiError map[string]error
func (e MultiError) Error() string {
s := ""
for _, err := range e {
s = err.Error()
break
}
switch len(e) {
case 0:
return "(0 errors)"
case 1:
return s
case 2:
return s + " (and 1 other error)"
}
return fmt.Sprintf("%s (and %d other errors)", s, len(e)-1)
}
func (e MultiError) merge(errors MultiError) {
for key, err := range errors {
if e[key] == nil {
e[key] = err
}
}
}
+148
View File
@@ -0,0 +1,148 @@
// Copyright 2012 The Gorilla Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
/*
Package gorilla/schema fills a struct with form values.
The basic usage is really simple. Given this struct:
type Person struct {
Name string
Phone string
}
...we can fill it passing a map to the Decode() function:
values := map[string][]string{
"Name": {"John"},
"Phone": {"999-999-999"},
}
person := new(Person)
decoder := schema.NewDecoder()
decoder.Decode(person, values)
This is just a simple example and it doesn't make a lot of sense to create
the map manually. Typically it will come from a http.Request object and
will be of type url.Values, http.Request.Form, or http.Request.MultipartForm:
func MyHandler(w http.ResponseWriter, r *http.Request) {
err := r.ParseForm()
if err != nil {
// Handle error
}
decoder := schema.NewDecoder()
// r.PostForm is a map of our POST form values
err := decoder.Decode(person, r.PostForm)
if err != nil {
// Handle error
}
// Do something with person.Name or person.Phone
}
Note: it is a good idea to set a Decoder instance as a package global,
because it caches meta-data about structs, and an instance can be shared safely:
var decoder = schema.NewDecoder()
To define custom names for fields, use a struct tag "schema". To not populate
certain fields, use a dash for the name and it will be ignored:
type Person struct {
Name string `schema:"name"` // custom name
Phone string `schema:"phone"` // custom name
Admin bool `schema:"-"` // this field is never set
}
The supported field types in the destination struct are:
- bool
- float variants (float32, float64)
- int variants (int, int8, int16, int32, int64)
- string
- uint variants (uint, uint8, uint16, uint32, uint64)
- struct
- a pointer to one of the above types
- a slice or a pointer to a slice of one of the above types
Non-supported types are simply ignored, however custom types can be registered
to be converted.
To fill nested structs, keys must use a dotted notation as the "path" for the
field. So for example, to fill the struct Person below:
type Phone struct {
Label string
Number string
}
type Person struct {
Name string
Phone Phone
}
...the source map must have the keys "Name", "Phone.Label" and "Phone.Number".
This means that an HTML form to fill a Person struct must look like this:
<form>
<input type="text" name="Name">
<input type="text" name="Phone.Label">
<input type="text" name="Phone.Number">
</form>
Single values are filled using the first value for a key from the source map.
Slices are filled using all values for a key from the source map. So to fill
a Person with multiple Phone values, like:
type Person struct {
Name string
Phones []Phone
}
...an HTML form that accepts three Phone values would look like this:
<form>
<input type="text" name="Name">
<input type="text" name="Phones.0.Label">
<input type="text" name="Phones.0.Number">
<input type="text" name="Phones.1.Label">
<input type="text" name="Phones.1.Number">
<input type="text" name="Phones.2.Label">
<input type="text" name="Phones.2.Number">
</form>
Notice that only for slices of structs the slice index is required.
This is needed for disambiguation: if the nested struct also had a slice
field, we could not translate multiple values to it if we did not use an
index for the parent struct.
There's also the possibility to create a custom type that implements the
TextUnmarshaler interface, and in this case there's no need to register
a converter, like:
type Person struct {
Emails []Email
}
type Email struct {
*mail.Address
}
func (e *Email) UnmarshalText(text []byte) (err error) {
e.Address, err = mail.ParseAddress(string(text))
return
}
...an HTML form that accepts three Email values would look like this:
<form>
<input type="email" name="Emails.0">
<input type="email" name="Emails.1">
<input type="email" name="Emails.2">
</form>
*/
package schema
+202
View File
@@ -0,0 +1,202 @@
package schema
import (
"errors"
"fmt"
"reflect"
"strconv"
)
type encoderFunc func(reflect.Value) string
// Encoder encodes values from a struct into url.Values.
type Encoder struct {
cache *cache
regenc map[reflect.Type]encoderFunc
}
// NewEncoder returns a new Encoder with defaults.
func NewEncoder() *Encoder {
return &Encoder{cache: newCache(), regenc: make(map[reflect.Type]encoderFunc)}
}
// Encode encodes a struct into map[string][]string.
//
// Intended for use with url.Values.
func (e *Encoder) Encode(src interface{}, dst map[string][]string) error {
v := reflect.ValueOf(src)
return e.encode(v, dst)
}
// RegisterEncoder registers a converter for encoding a custom type.
func (e *Encoder) RegisterEncoder(value interface{}, encoder func(reflect.Value) string) {
e.regenc[reflect.TypeOf(value)] = encoder
}
// SetAliasTag changes the tag used to locate custom field aliases.
// The default tag is "schema".
func (e *Encoder) SetAliasTag(tag string) {
e.cache.tag = tag
}
// isValidStructPointer test if input value is a valid struct pointer.
func isValidStructPointer(v reflect.Value) bool {
return v.Type().Kind() == reflect.Ptr && v.Elem().IsValid() && v.Elem().Type().Kind() == reflect.Struct
}
func isZero(v reflect.Value) bool {
switch v.Kind() {
case reflect.Func:
case reflect.Map, reflect.Slice:
return v.IsNil() || v.Len() == 0
case reflect.Array:
z := true
for i := 0; i < v.Len(); i++ {
z = z && isZero(v.Index(i))
}
return z
case reflect.Struct:
type zero interface {
IsZero() bool
}
if v.Type().Implements(reflect.TypeOf((*zero)(nil)).Elem()) {
iz := v.MethodByName("IsZero").Call([]reflect.Value{})[0]
return iz.Interface().(bool)
}
z := true
for i := 0; i < v.NumField(); i++ {
z = z && isZero(v.Field(i))
}
return z
}
// Compare other types directly:
z := reflect.Zero(v.Type())
return v.Interface() == z.Interface()
}
func (e *Encoder) encode(v reflect.Value, dst map[string][]string) error {
if v.Kind() == reflect.Ptr {
v = v.Elem()
}
if v.Kind() != reflect.Struct {
return errors.New("schema: interface must be a struct")
}
t := v.Type()
errors := MultiError{}
for i := 0; i < v.NumField(); i++ {
name, opts := fieldAlias(t.Field(i), e.cache.tag)
if name == "-" {
continue
}
// Encode struct pointer types if the field is a valid pointer and a struct.
if isValidStructPointer(v.Field(i)) {
_ = e.encode(v.Field(i).Elem(), dst)
continue
}
encFunc := typeEncoder(v.Field(i).Type(), e.regenc)
// Encode non-slice types and custom implementations immediately.
if encFunc != nil {
value := encFunc(v.Field(i))
if opts.Contains("omitempty") && isZero(v.Field(i)) {
continue
}
dst[name] = append(dst[name], value)
continue
}
if v.Field(i).Type().Kind() == reflect.Struct {
_ = e.encode(v.Field(i), dst)
continue
}
if v.Field(i).Type().Kind() == reflect.Slice {
encFunc = typeEncoder(v.Field(i).Type().Elem(), e.regenc)
}
if encFunc == nil {
errors[v.Field(i).Type().String()] = fmt.Errorf("schema: encoder not found for %v", v.Field(i))
continue
}
// Encode a slice.
if v.Field(i).Len() == 0 && opts.Contains("omitempty") {
continue
}
dst[name] = []string{}
for j := 0; j < v.Field(i).Len(); j++ {
dst[name] = append(dst[name], encFunc(v.Field(i).Index(j)))
}
}
if len(errors) > 0 {
return errors
}
return nil
}
func typeEncoder(t reflect.Type, reg map[reflect.Type]encoderFunc) encoderFunc {
if f, ok := reg[t]; ok {
return f
}
switch t.Kind() {
case reflect.Bool:
return encodeBool
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
return encodeInt
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
return encodeUint
case reflect.Float32:
return encodeFloat32
case reflect.Float64:
return encodeFloat64
case reflect.Ptr:
f := typeEncoder(t.Elem(), reg)
return func(v reflect.Value) string {
if v.IsNil() {
return "null"
}
return f(v.Elem())
}
case reflect.String:
return encodeString
default:
return nil
}
}
func encodeBool(v reflect.Value) string {
return strconv.FormatBool(v.Bool())
}
func encodeInt(v reflect.Value) string {
return strconv.FormatInt(int64(v.Int()), 10)
}
func encodeUint(v reflect.Value) string {
return strconv.FormatUint(uint64(v.Uint()), 10)
}
func encodeFloat(v reflect.Value, bits int) string {
return strconv.FormatFloat(v.Float(), 'f', 6, bits)
}
func encodeFloat32(v reflect.Value) string {
return encodeFloat(v, 32)
}
func encodeFloat64(v reflect.Value) string {
return encodeFloat(v, 64)
}
func encodeString(v reflect.Value) string {
return v.String()
}
+502
View File
@@ -0,0 +1,502 @@
// ⚡️ Fiber is an Express inspired web framework written in Go with ☕️
// 🤖 Github Repository: https://github.com/gofiber/fiber
// 📌 API Documentation: https://docs.gofiber.io
package fiber
import (
"crypto/tls"
"crypto/x509"
"errors"
"fmt"
"net"
"os"
"path/filepath"
"reflect"
"runtime"
"sort"
"strconv"
"strings"
"text/tabwriter"
"github.com/mattn/go-colorable"
"github.com/mattn/go-isatty"
"github.com/mattn/go-runewidth"
"github.com/gofiber/fiber/v2/log"
)
const (
globalIpv4Addr = "0.0.0.0"
)
// Listener can be used to pass a custom listener.
func (app *App) Listener(ln net.Listener) error {
// prepare the server for the start
app.startupProcess()
// run hooks
app.runOnListenHooks(app.prepareListenData(ln.Addr().String(), getTLSConfig(ln) != nil))
// Print startup message
if !app.config.DisableStartupMessage {
app.startupMessage(ln.Addr().String(), getTLSConfig(ln) != nil, "")
}
// Print routes
if app.config.EnablePrintRoutes {
app.printRoutesMessage()
}
// Prefork is not supported for custom listeners
if app.config.Prefork {
log.Warn("Prefork isn't supported for custom listeners.")
}
// Start listening
return app.server.Serve(ln)
}
// Listen serves HTTP requests from the given addr.
//
// app.Listen(":8080")
// app.Listen("127.0.0.1:8080")
func (app *App) Listen(addr string) error {
// Start prefork
if app.config.Prefork {
return app.prefork(app.config.Network, addr, nil)
}
// Setup listener
ln, err := net.Listen(app.config.Network, addr)
if err != nil {
return fmt.Errorf("failed to listen: %w", err)
}
// prepare the server for the start
app.startupProcess()
// run hooks
app.runOnListenHooks(app.prepareListenData(ln.Addr().String(), false))
// Print startup message
if !app.config.DisableStartupMessage {
app.startupMessage(ln.Addr().String(), false, "")
}
// Print routes
if app.config.EnablePrintRoutes {
app.printRoutesMessage()
}
// Start listening
return app.server.Serve(ln)
}
// ListenTLS serves HTTPS requests from the given addr.
// certFile and keyFile are the paths to TLS certificate and key file:
//
// app.ListenTLS(":8080", "./cert.pem", "./cert.key")
func (app *App) ListenTLS(addr, certFile, keyFile string) error {
// Check for valid cert/key path
if len(certFile) == 0 || len(keyFile) == 0 {
return errors.New("tls: provide a valid cert or key path")
}
// Set TLS config with handler
cert, err := tls.LoadX509KeyPair(certFile, keyFile)
if err != nil {
return fmt.Errorf("tls: cannot load TLS key pair from certFile=%q and keyFile=%q: %w", certFile, keyFile, err)
}
return app.ListenTLSWithCertificate(addr, cert)
}
// ListenTLS serves HTTPS requests from the given addr.
// cert is a tls.Certificate
//
// app.ListenTLSWithCertificate(":8080", cert)
func (app *App) ListenTLSWithCertificate(addr string, cert tls.Certificate) error {
tlsHandler := &TLSHandler{}
config := &tls.Config{
MinVersion: tls.VersionTLS12,
Certificates: []tls.Certificate{
cert,
},
GetCertificate: tlsHandler.GetClientInfo,
}
// Prefork is supported
if app.config.Prefork {
return app.prefork(app.config.Network, addr, config)
}
// Setup listener
ln, err := net.Listen(app.config.Network, addr)
ln = tls.NewListener(ln, config)
if err != nil {
return fmt.Errorf("failed to listen: %w", err)
}
// prepare the server for the start
app.startupProcess()
// run hooks
app.runOnListenHooks(app.prepareListenData(ln.Addr().String(), getTLSConfig(ln) != nil))
// Print startup message
if !app.config.DisableStartupMessage {
app.startupMessage(ln.Addr().String(), true, "")
}
// Print routes
if app.config.EnablePrintRoutes {
app.printRoutesMessage()
}
// Attach the tlsHandler to the config
app.SetTLSHandler(tlsHandler)
// Start listening
return app.server.Serve(ln)
}
// ListenMutualTLS serves HTTPS requests from the given addr.
// certFile, keyFile and clientCertFile are the paths to TLS certificate and key file:
//
// app.ListenMutualTLS(":8080", "./cert.pem", "./cert.key", "./client.pem")
func (app *App) ListenMutualTLS(addr, certFile, keyFile, clientCertFile string) error {
// Check for valid cert/key path
if len(certFile) == 0 || len(keyFile) == 0 {
return errors.New("tls: provide a valid cert or key path")
}
cert, err := tls.LoadX509KeyPair(certFile, keyFile)
if err != nil {
return fmt.Errorf("tls: cannot load TLS key pair from certFile=%q and keyFile=%q: %w", certFile, keyFile, err)
}
clientCACert, err := os.ReadFile(filepath.Clean(clientCertFile))
if err != nil {
return fmt.Errorf("failed to read file: %w", err)
}
clientCertPool := x509.NewCertPool()
clientCertPool.AppendCertsFromPEM(clientCACert)
return app.ListenMutualTLSWithCertificate(addr, cert, clientCertPool)
}
// ListenMutualTLSWithCertificate serves HTTPS requests from the given addr.
// cert is a tls.Certificate and clientCertPool is a *x509.CertPool:
//
// app.ListenMutualTLS(":8080", cert, clientCertPool)
func (app *App) ListenMutualTLSWithCertificate(addr string, cert tls.Certificate, clientCertPool *x509.CertPool) error {
tlsHandler := &TLSHandler{}
config := &tls.Config{
MinVersion: tls.VersionTLS12,
ClientAuth: tls.RequireAndVerifyClientCert,
ClientCAs: clientCertPool,
Certificates: []tls.Certificate{
cert,
},
GetCertificate: tlsHandler.GetClientInfo,
}
// Prefork is supported
if app.config.Prefork {
return app.prefork(app.config.Network, addr, config)
}
// Setup listener
ln, err := tls.Listen(app.config.Network, addr, config)
if err != nil {
return fmt.Errorf("failed to listen: %w", err)
}
// prepare the server for the start
app.startupProcess()
// run hooks
app.runOnListenHooks(app.prepareListenData(ln.Addr().String(), getTLSConfig(ln) != nil))
// Print startup message
if !app.config.DisableStartupMessage {
app.startupMessage(ln.Addr().String(), true, "")
}
// Print routes
if app.config.EnablePrintRoutes {
app.printRoutesMessage()
}
// Attach the tlsHandler to the config
app.SetTLSHandler(tlsHandler)
// Start listening
return app.server.Serve(ln)
}
// prepareListenData create an slice of ListenData
func (app *App) prepareListenData(addr string, isTLS bool) ListenData { //revive:disable-line:flag-parameter // Accepting a bool param named isTLS if fine here
host, port := parseAddr(addr)
if host == "" {
if app.config.Network == NetworkTCP6 {
host = "[::1]"
} else {
host = globalIpv4Addr
}
}
return ListenData{
Host: host,
Port: port,
TLS: isTLS,
}
}
// startupMessage prepares the startup message with the handler number, port, address and other information
func (app *App) startupMessage(addr string, isTLS bool, pids string) { //nolint: revive // Accepting a bool param named isTLS if fine here
// ignore child processes
if IsChild() {
return
}
// Alias colors
colors := app.config.ColorScheme
value := func(s string, width int) string {
pad := width - len(s)
str := ""
for i := 0; i < pad; i++ {
str += "."
}
if s == "Disabled" {
str += " " + s
} else {
str += fmt.Sprintf(" %s%s%s", colors.Cyan, s, colors.Black)
}
return str
}
center := func(s string, width int) string {
const padDiv = 2
pad := strconv.Itoa((width - len(s)) / padDiv)
str := fmt.Sprintf("%"+pad+"s", " ")
str += s
str += fmt.Sprintf("%"+pad+"s", " ")
if len(str) < width {
str += " "
}
return str
}
centerValue := func(s string, width int) string {
const padDiv = 2
pad := strconv.Itoa((width - runewidth.StringWidth(s)) / padDiv)
str := fmt.Sprintf("%"+pad+"s", " ")
str += fmt.Sprintf("%s%s%s", colors.Cyan, s, colors.Black)
str += fmt.Sprintf("%"+pad+"s", " ")
if runewidth.StringWidth(s)-10 < width && runewidth.StringWidth(s)%2 == 0 {
// add an ending space if the length of str is even and str is not too long
str += " "
}
return str
}
pad := func(s string, width int) string {
toAdd := width - len(s)
str := s
for i := 0; i < toAdd; i++ {
str += " "
}
return str
}
host, port := parseAddr(addr)
if host == "" {
if app.config.Network == NetworkTCP6 {
host = "[::1]"
} else {
host = globalIpv4Addr
}
}
scheme := schemeHTTP
if isTLS {
scheme = schemeHTTPS
}
isPrefork := "Disabled"
if app.config.Prefork {
isPrefork = "Enabled"
}
procs := strconv.Itoa(runtime.GOMAXPROCS(0))
if !app.config.Prefork {
procs = "1"
}
const lineLen = 49
mainLogo := colors.Black + " ┌───────────────────────────────────────────────────┐\n"
if app.config.AppName != "" {
mainLogo += " │ " + centerValue(app.config.AppName, lineLen) + " │\n"
}
mainLogo += " │ " + centerValue("Fiber v"+Version, lineLen) + " │\n"
if host == globalIpv4Addr {
mainLogo += " │ " + center(fmt.Sprintf("%s://127.0.0.1:%s", scheme, port), lineLen) + " │\n" +
" │ " + center(fmt.Sprintf("(bound on host 0.0.0.0 and port %s)", port), lineLen) + " │\n"
} else {
mainLogo += " │ " + center(fmt.Sprintf("%s://%s:%s", scheme, host, port), lineLen) + " │\n"
}
mainLogo += fmt.Sprintf(
" │ │\n"+
" │ Handlers %s Processes %s │\n"+
" │ Prefork .%s PID ....%s │\n"+
" └───────────────────────────────────────────────────┘"+
colors.Reset,
value(strconv.Itoa(int(app.handlersCount)), 14), value(procs, 12),
value(isPrefork, 14), value(strconv.Itoa(os.Getpid()), 14),
)
var childPidsLogo string
if app.config.Prefork {
var childPidsTemplate string
childPidsTemplate += "%s"
childPidsTemplate += " ┌───────────────────────────────────────────────────┐\n%s"
childPidsTemplate += " └───────────────────────────────────────────────────┘"
childPidsTemplate += "%s"
newLine := " │ %s%s%s │"
// Turn the `pids` variable (in the form ",a,b,c,d,e,f,etc") into a slice of PIDs
var pidSlice []string
for _, v := range strings.Split(pids, ",") {
if v != "" {
pidSlice = append(pidSlice, v)
}
}
var lines []string
thisLine := "Child PIDs ... "
var itemsOnThisLine []string
const maxLineLen = 49
addLine := func() {
lines = append(lines,
fmt.Sprintf(
newLine,
colors.Black,
thisLine+colors.Cyan+pad(strings.Join(itemsOnThisLine, ", "), maxLineLen-len(thisLine)),
colors.Black,
),
)
}
for _, pid := range pidSlice {
if len(thisLine+strings.Join(append(itemsOnThisLine, pid), ", ")) > maxLineLen {
addLine()
thisLine = ""
itemsOnThisLine = []string{pid}
} else {
itemsOnThisLine = append(itemsOnThisLine, pid)
}
}
// Add left over items to their own line
if len(itemsOnThisLine) != 0 {
addLine()
}
// Form logo
childPidsLogo = fmt.Sprintf(childPidsTemplate,
colors.Black,
strings.Join(lines, "\n")+"\n",
colors.Reset,
)
}
// Combine both the child PID logo and the main Fiber logo
// Pad the shorter logo to the length of the longer one
splitMainLogo := strings.Split(mainLogo, "\n")
splitChildPidsLogo := strings.Split(childPidsLogo, "\n")
mainLen := len(splitMainLogo)
childLen := len(splitChildPidsLogo)
if mainLen > childLen {
diff := mainLen - childLen
for i := 0; i < diff; i++ {
splitChildPidsLogo = append(splitChildPidsLogo, "")
}
} else {
diff := childLen - mainLen
for i := 0; i < diff; i++ {
splitMainLogo = append(splitMainLogo, "")
}
}
// Combine the two logos, line by line
output := "\n"
for i := range splitMainLogo {
output += colors.Black + splitMainLogo[i] + " " + splitChildPidsLogo[i] + "\n"
}
out := colorable.NewColorableStdout()
if os.Getenv("TERM") == "dumb" || os.Getenv("NO_COLOR") == "1" || (!isatty.IsTerminal(os.Stdout.Fd()) && !isatty.IsCygwinTerminal(os.Stdout.Fd())) {
out = colorable.NewNonColorable(os.Stdout)
}
_, _ = fmt.Fprintln(out, output)
}
// printRoutesMessage print all routes with method, path, name and handlers
// in a format of table, like this:
// method | path | name | handlers
// GET | / | routeName | github.com/gofiber/fiber/v2.emptyHandler
// HEAD | / | | github.com/gofiber/fiber/v2.emptyHandler
func (app *App) printRoutesMessage() {
// ignore child processes
if IsChild() {
return
}
// Alias colors
colors := app.config.ColorScheme
var routes []RouteMessage
for _, routeStack := range app.stack {
for _, route := range routeStack {
var newRoute RouteMessage
newRoute.name = route.Name
newRoute.method = route.Method
newRoute.path = route.Path
for _, handler := range route.Handlers {
newRoute.handlers += runtime.FuncForPC(reflect.ValueOf(handler).Pointer()).Name() + " "
}
routes = append(routes, newRoute)
}
}
out := colorable.NewColorableStdout()
if os.Getenv("TERM") == "dumb" || os.Getenv("NO_COLOR") == "1" || (!isatty.IsTerminal(os.Stdout.Fd()) && !isatty.IsCygwinTerminal(os.Stdout.Fd())) {
out = colorable.NewNonColorable(os.Stdout)
}
w := tabwriter.NewWriter(out, 1, 1, 1, ' ', 0)
// Sort routes by path
sort.Slice(routes, func(i, j int) bool {
return routes[i].path < routes[j].path
})
_, _ = fmt.Fprintf(w, "%smethod\t%s| %spath\t%s| %sname\t%s| %shandlers\t%s\n", colors.Blue, colors.White, colors.Green, colors.White, colors.Cyan, colors.White, colors.Yellow, colors.Reset)
_, _ = fmt.Fprintf(w, "%s------\t%s| %s----\t%s| %s----\t%s| %s--------\t%s\n", colors.Blue, colors.White, colors.Green, colors.White, colors.Cyan, colors.White, colors.Yellow, colors.Reset)
for _, route := range routes {
_, _ = fmt.Fprintf(w, "%s%s\t%s| %s%s\t%s| %s%s\t%s| %s%s%s\n", colors.Blue, route.method, colors.White, colors.Green, route.path, colors.White, colors.Cyan, route.name, colors.White, colors.Yellow, route.handlers, colors.Reset)
}
_ = w.Flush() //nolint:errcheck // It is fine to ignore the error here
}
+209
View File
@@ -0,0 +1,209 @@
package log
import (
"context"
"fmt"
"io"
"log"
"os"
"sync"
"github.com/valyala/bytebufferpool"
)
var _ AllLogger = (*defaultLogger)(nil)
type defaultLogger struct {
stdlog *log.Logger
level Level
depth int
}
// privateLog logs a message at a given level log the default logger.
// when the level is fatal, it will exit the program.
func (l *defaultLogger) privateLog(lv Level, fmtArgs []interface{}) {
if l.level > lv {
return
}
level := lv.toString()
buf := bytebufferpool.Get()
_, _ = buf.WriteString(level) //nolint:errcheck // It is fine to ignore the error
_, _ = buf.WriteString(fmt.Sprint(fmtArgs...)) //nolint:errcheck // It is fine to ignore the error
_ = l.stdlog.Output(l.depth, buf.String()) //nolint:errcheck // It is fine to ignore the error
buf.Reset()
bytebufferpool.Put(buf)
if lv == LevelFatal {
os.Exit(1) //nolint:revive // we want to exit the program when Fatal is called
}
}
// privateLog logs a message at a given level log the default logger.
// when the level is fatal, it will exit the program.
func (l *defaultLogger) privateLogf(lv Level, format string, fmtArgs []interface{}) {
if l.level > lv {
return
}
level := lv.toString()
buf := bytebufferpool.Get()
_, _ = buf.WriteString(level) //nolint:errcheck // It is fine to ignore the error
if len(fmtArgs) > 0 {
_, _ = fmt.Fprintf(buf, format, fmtArgs...)
} else {
_, _ = fmt.Fprint(buf, fmtArgs...)
}
_ = l.stdlog.Output(l.depth, buf.String()) //nolint:errcheck // It is fine to ignore the error
buf.Reset()
bytebufferpool.Put(buf)
if lv == LevelFatal {
os.Exit(1) //nolint:revive // we want to exit the program when Fatal is called
}
}
// privateLogw logs a message at a given level log the default logger.
// when the level is fatal, it will exit the program.
func (l *defaultLogger) privateLogw(lv Level, format string, keysAndValues []interface{}) {
if l.level > lv {
return
}
level := lv.toString()
buf := bytebufferpool.Get()
_, _ = buf.WriteString(level) //nolint:errcheck // It is fine to ignore the error
// Write format privateLog buffer
if format != "" {
_, _ = buf.WriteString(format) //nolint:errcheck // It is fine to ignore the error
}
var once sync.Once
isFirst := true
// Write keys and values privateLog buffer
if len(keysAndValues) > 0 {
if (len(keysAndValues) & 1) == 1 {
keysAndValues = append(keysAndValues, "KEYVALS UNPAIRED")
}
for i := 0; i < len(keysAndValues); i += 2 {
if format == "" && isFirst {
once.Do(func() {
_, _ = fmt.Fprintf(buf, "%s=%v", keysAndValues[i], keysAndValues[i+1])
isFirst = false
})
continue
}
_, _ = fmt.Fprintf(buf, " %s=%v", keysAndValues[i], keysAndValues[i+1])
}
}
_ = l.stdlog.Output(l.depth, buf.String()) //nolint:errcheck // It is fine to ignore the error
buf.Reset()
bytebufferpool.Put(buf)
if lv == LevelFatal {
os.Exit(1) //nolint:revive // we want to exit the program when Fatal is called
}
}
func (l *defaultLogger) Trace(v ...interface{}) {
l.privateLog(LevelTrace, v)
}
func (l *defaultLogger) Debug(v ...interface{}) {
l.privateLog(LevelDebug, v)
}
func (l *defaultLogger) Info(v ...interface{}) {
l.privateLog(LevelInfo, v)
}
func (l *defaultLogger) Warn(v ...interface{}) {
l.privateLog(LevelWarn, v)
}
func (l *defaultLogger) Error(v ...interface{}) {
l.privateLog(LevelError, v)
}
func (l *defaultLogger) Fatal(v ...interface{}) {
l.privateLog(LevelFatal, v)
}
func (l *defaultLogger) Panic(v ...interface{}) {
l.privateLog(LevelPanic, v)
}
func (l *defaultLogger) Tracef(format string, v ...interface{}) {
l.privateLogf(LevelTrace, format, v)
}
func (l *defaultLogger) Debugf(format string, v ...interface{}) {
l.privateLogf(LevelDebug, format, v)
}
func (l *defaultLogger) Infof(format string, v ...interface{}) {
l.privateLogf(LevelInfo, format, v)
}
func (l *defaultLogger) Warnf(format string, v ...interface{}) {
l.privateLogf(LevelWarn, format, v)
}
func (l *defaultLogger) Errorf(format string, v ...interface{}) {
l.privateLogf(LevelError, format, v)
}
func (l *defaultLogger) Fatalf(format string, v ...interface{}) {
l.privateLogf(LevelFatal, format, v)
}
func (l *defaultLogger) Panicf(format string, v ...interface{}) {
l.privateLogf(LevelPanic, format, v)
}
func (l *defaultLogger) Tracew(msg string, keysAndValues ...interface{}) {
l.privateLogw(LevelTrace, msg, keysAndValues)
}
func (l *defaultLogger) Debugw(msg string, keysAndValues ...interface{}) {
l.privateLogw(LevelDebug, msg, keysAndValues)
}
func (l *defaultLogger) Infow(msg string, keysAndValues ...interface{}) {
l.privateLogw(LevelInfo, msg, keysAndValues)
}
func (l *defaultLogger) Warnw(msg string, keysAndValues ...interface{}) {
l.privateLogw(LevelWarn, msg, keysAndValues)
}
func (l *defaultLogger) Errorw(msg string, keysAndValues ...interface{}) {
l.privateLogw(LevelError, msg, keysAndValues)
}
func (l *defaultLogger) Fatalw(msg string, keysAndValues ...interface{}) {
l.privateLogw(LevelFatal, msg, keysAndValues)
}
func (l *defaultLogger) Panicw(msg string, keysAndValues ...interface{}) {
l.privateLogw(LevelPanic, msg, keysAndValues)
}
func (l *defaultLogger) WithContext(_ context.Context) CommonLogger {
return &defaultLogger{
stdlog: l.stdlog,
level: l.level,
depth: l.depth - 1,
}
}
func (l *defaultLogger) SetLevel(level Level) {
l.level = level
}
func (l *defaultLogger) SetOutput(writer io.Writer) {
l.stdlog.SetOutput(writer)
}
// DefaultLogger returns the default logger.
func DefaultLogger() AllLogger {
return logger
}
+141
View File
@@ -0,0 +1,141 @@
package log
import (
"context"
"io"
)
// Fatal calls the default logger's Fatal method and then os.Exit(1).
func Fatal(v ...interface{}) {
logger.Fatal(v...)
}
// Error calls the default logger's Error method.
func Error(v ...interface{}) {
logger.Error(v...)
}
// Warn calls the default logger's Warn method.
func Warn(v ...interface{}) {
logger.Warn(v...)
}
// Info calls the default logger's Info method.
func Info(v ...interface{}) {
logger.Info(v...)
}
// Debug calls the default logger's Debug method.
func Debug(v ...interface{}) {
logger.Debug(v...)
}
// Trace calls the default logger's Trace method.
func Trace(v ...interface{}) {
logger.Trace(v...)
}
// Panic calls the default logger's Panic method.
func Panic(v ...interface{}) {
logger.Panic(v...)
}
// Fatalf calls the default logger's Fatalf method and then os.Exit(1).
func Fatalf(format string, v ...interface{}) {
logger.Fatalf(format, v...)
}
// Errorf calls the default logger's Errorf method.
func Errorf(format string, v ...interface{}) {
logger.Errorf(format, v...)
}
// Warnf calls the default logger's Warnf method.
func Warnf(format string, v ...interface{}) {
logger.Warnf(format, v...)
}
// Infof calls the default logger's Infof method.
func Infof(format string, v ...interface{}) {
logger.Infof(format, v...)
}
// Debugf calls the default logger's Debugf method.
func Debugf(format string, v ...interface{}) {
logger.Debugf(format, v...)
}
// Tracef calls the default logger's Tracef method.
func Tracef(format string, v ...interface{}) {
logger.Tracef(format, v...)
}
// Panicf calls the default logger's Tracef method.
func Panicf(format string, v ...interface{}) {
logger.Panicf(format, v...)
}
// Tracew logs a message with some additional context. The variadic key-value
// pairs are treated as they are privateLog With.
func Tracew(msg string, keysAndValues ...interface{}) {
logger.Tracew(msg, keysAndValues...)
}
// Debugw logs a message with some additional context. The variadic key-value
// pairs are treated as they are privateLog With.
func Debugw(msg string, keysAndValues ...interface{}) {
logger.Debugw(msg, keysAndValues...)
}
// Infow logs a message with some additional context. The variadic key-value
// pairs are treated as they are privateLog With.
func Infow(msg string, keysAndValues ...interface{}) {
logger.Infow(msg, keysAndValues...)
}
// Warnw logs a message with some additional context. The variadic key-value
// pairs are treated as they are privateLog With.
func Warnw(msg string, keysAndValues ...interface{}) {
logger.Warnw(msg, keysAndValues...)
}
// Errorw logs a message with some additional context. The variadic key-value
// pairs are treated as they are privateLog With.
func Errorw(msg string, keysAndValues ...interface{}) {
logger.Errorw(msg, keysAndValues...)
}
// Fatalw logs a message with some additional context. The variadic key-value
// pairs are treated as they are privateLog With.
func Fatalw(msg string, keysAndValues ...interface{}) {
logger.Fatalw(msg, keysAndValues...)
}
// Panicw logs a message with some additional context. The variadic key-value
// pairs are treated as they are privateLog With.
func Panicw(msg string, keysAndValues ...interface{}) {
logger.Panicw(msg, keysAndValues...)
}
func WithContext(ctx context.Context) CommonLogger {
return logger.WithContext(ctx)
}
// SetLogger sets the default logger and the system logger.
// Note that this method is not concurrent-safe and must not be called
// after the use of DefaultLogger and global functions privateLog this package.
func SetLogger(v AllLogger) {
logger = v
}
// SetOutput sets the output of default logger and system logger. By default, it is stderr.
func SetOutput(w io.Writer) {
logger.SetOutput(w)
}
// SetLevel sets the level of logs below which logs will not be output.
// The default logger is LevelTrace.
// Note that this method is not concurrent-safe.
func SetLevel(lv Level) {
logger.SetLevel(lv)
}
+100
View File
@@ -0,0 +1,100 @@
package log
import (
"context"
"fmt"
"io"
"log"
"os"
)
var logger AllLogger = &defaultLogger{
stdlog: log.New(os.Stderr, "", log.LstdFlags|log.Lshortfile|log.Lmicroseconds),
depth: 4,
}
// Logger is a logger interface that provides logging function with levels.
type Logger interface {
Trace(v ...interface{})
Debug(v ...interface{})
Info(v ...interface{})
Warn(v ...interface{})
Error(v ...interface{})
Fatal(v ...interface{})
Panic(v ...interface{})
}
// FormatLogger is a logger interface that output logs with a format.
type FormatLogger interface {
Tracef(format string, v ...interface{})
Debugf(format string, v ...interface{})
Infof(format string, v ...interface{})
Warnf(format string, v ...interface{})
Errorf(format string, v ...interface{})
Fatalf(format string, v ...interface{})
Panicf(format string, v ...interface{})
}
// WithLogger is a logger interface that output logs with a message and key-value pairs.
type WithLogger interface {
Tracew(msg string, keysAndValues ...interface{})
Debugw(msg string, keysAndValues ...interface{})
Infow(msg string, keysAndValues ...interface{})
Warnw(msg string, keysAndValues ...interface{})
Errorw(msg string, keysAndValues ...interface{})
Fatalw(msg string, keysAndValues ...interface{})
Panicw(msg string, keysAndValues ...interface{})
}
type CommonLogger interface {
Logger
FormatLogger
WithLogger
}
// ControlLogger provides methods to config a logger.
type ControlLogger interface {
SetLevel(Level)
SetOutput(io.Writer)
}
// AllLogger is the combination of Logger, FormatLogger, CtxLogger and ControlLogger.
// Custom extensions can be made through AllLogger
type AllLogger interface {
CommonLogger
ControlLogger
WithContext(ctx context.Context) CommonLogger
}
// Level defines the priority of a log message.
// When a logger is configured with a level, any log message with a lower
// log level (smaller by integer comparison) will not be output.
type Level int
// The levels of logs.
const (
LevelTrace Level = iota
LevelDebug
LevelInfo
LevelWarn
LevelError
LevelFatal
LevelPanic
)
var strs = []string{
"[Trace] ",
"[Debug] ",
"[Info] ",
"[Warn] ",
"[Error] ",
"[Fatal] ",
"[Panic] ",
}
func (lv Level) toString() string {
if lv >= LevelTrace && lv <= LevelPanic {
return strs[lv]
}
return fmt.Sprintf("[?%d] ", lv)
}
+65
View File
@@ -0,0 +1,65 @@
package compress
import (
"github.com/gofiber/fiber/v2"
"github.com/valyala/fasthttp"
)
// New creates a new middleware handler
func New(config ...Config) fiber.Handler {
// Set default config
cfg := configDefault(config...)
// Setup request handlers
var (
fctx = func(c *fasthttp.RequestCtx) {}
compressor fasthttp.RequestHandler
)
// Setup compression algorithm
switch cfg.Level {
case LevelDefault:
// LevelDefault
compressor = fasthttp.CompressHandlerBrotliLevel(fctx,
fasthttp.CompressBrotliDefaultCompression,
fasthttp.CompressDefaultCompression,
)
case LevelBestSpeed:
// LevelBestSpeed
compressor = fasthttp.CompressHandlerBrotliLevel(fctx,
fasthttp.CompressBrotliBestSpeed,
fasthttp.CompressBestSpeed,
)
case LevelBestCompression:
// LevelBestCompression
compressor = fasthttp.CompressHandlerBrotliLevel(fctx,
fasthttp.CompressBrotliBestCompression,
fasthttp.CompressBestCompression,
)
default:
// LevelDisabled
return func(c *fiber.Ctx) error {
return c.Next()
}
}
// Return new handler
return func(c *fiber.Ctx) error {
// Don't execute middleware if Next returns true
if cfg.Next != nil && cfg.Next(c) {
return c.Next()
}
// Continue stack
if err := c.Next(); err != nil {
return err
}
// Compress response
compressor(c.Context())
// Return from handler
return nil
}
}
+56
View File
@@ -0,0 +1,56 @@
package compress
import (
"github.com/gofiber/fiber/v2"
)
// Config defines the config for middleware.
type Config struct {
// Next defines a function to skip this middleware when returned true.
//
// Optional. Default: nil
Next func(c *fiber.Ctx) bool
// Level determines the compression algorithm
//
// Optional. Default: LevelDefault
// LevelDisabled: -1
// LevelDefault: 0
// LevelBestSpeed: 1
// LevelBestCompression: 2
Level Level
}
// Level is numeric representation of compression level
type Level int
// Represents compression level that will be used in the middleware
const (
LevelDisabled Level = -1
LevelDefault Level = 0
LevelBestSpeed Level = 1
LevelBestCompression Level = 2
)
// ConfigDefault is the default config
var ConfigDefault = Config{
Next: nil,
Level: LevelDefault,
}
// Helper function to set default values
func configDefault(config ...Config) Config {
// Return default config if nothing provided
if len(config) < 1 {
return ConfigDefault
}
// Override default config
cfg := config[0]
// Set default values
if cfg.Level < LevelDisabled || cfg.Level > LevelBestCompression {
cfg.Level = ConfigDefault.Level
}
return cfg
}
+289
View File
@@ -0,0 +1,289 @@
package cors
import (
"strconv"
"strings"
"github.com/gofiber/fiber/v2"
"github.com/gofiber/fiber/v2/log"
)
// Config defines the config for middleware.
type Config struct {
// Next defines a function to skip this middleware when returned true.
//
// Optional. Default: nil
Next func(c *fiber.Ctx) bool
// AllowOriginsFunc defines a function that will set the 'Access-Control-Allow-Origin'
// response header to the 'origin' request header when returned true. This allows for
// dynamic evaluation of allowed origins. Note if AllowCredentials is true, wildcard origins
// will be not have the 'Access-Control-Allow-Credentials' header set to 'true'.
//
// Optional. Default: nil
AllowOriginsFunc func(origin string) bool
// AllowOrigin defines a comma separated list of origins that may access the resource.
//
// Optional. Default value "*"
AllowOrigins string
// AllowMethods defines a list methods allowed when accessing the resource.
// This is used in response to a preflight request.
//
// Optional. Default value "GET,POST,HEAD,PUT,DELETE,PATCH"
AllowMethods string
// AllowHeaders defines a list of request headers that can be used when
// making the actual request. This is in response to a preflight request.
//
// Optional. Default value "".
AllowHeaders string
// AllowCredentials indicates whether or not the response to the request
// can be exposed when the credentials flag is true. When used as part of
// a response to a preflight request, this indicates whether or not the
// actual request can be made using credentials. Note: If true, AllowOrigins
// cannot be set to a wildcard ("*") to prevent security vulnerabilities.
//
// Optional. Default value false.
AllowCredentials bool
// ExposeHeaders defines a whitelist headers that clients are allowed to
// access.
//
// Optional. Default value "".
ExposeHeaders string
// MaxAge indicates how long (in seconds) the results of a preflight request
// can be cached.
// If you pass MaxAge 0, Access-Control-Max-Age header will not be added and
// browser will use 5 seconds by default.
// To disable caching completely, pass MaxAge value negative. It will set the Access-Control-Max-Age header 0.
//
// Optional. Default value 0.
MaxAge int
}
// ConfigDefault is the default config
var ConfigDefault = Config{
Next: nil,
AllowOriginsFunc: nil,
AllowOrigins: "*",
AllowMethods: strings.Join([]string{
fiber.MethodGet,
fiber.MethodPost,
fiber.MethodHead,
fiber.MethodPut,
fiber.MethodDelete,
fiber.MethodPatch,
}, ","),
AllowHeaders: "",
AllowCredentials: false,
ExposeHeaders: "",
MaxAge: 0,
}
// New creates a new middleware handler
func New(config ...Config) fiber.Handler {
// Set default config
cfg := ConfigDefault
// Override config if provided
if len(config) > 0 {
cfg = config[0]
// Set default values
if cfg.AllowMethods == "" {
cfg.AllowMethods = ConfigDefault.AllowMethods
}
// When none of the AllowOrigins or AllowOriginsFunc config was defined, set the default AllowOrigins value with "*"
if cfg.AllowOrigins == "" && cfg.AllowOriginsFunc == nil {
cfg.AllowOrigins = ConfigDefault.AllowOrigins
}
}
// Warning logs if both AllowOrigins and AllowOriginsFunc are set
if cfg.AllowOrigins != "" && cfg.AllowOriginsFunc != nil {
log.Warn("[CORS] Both 'AllowOrigins' and 'AllowOriginsFunc' have been defined.")
}
// Validate CORS credentials configuration
if cfg.AllowCredentials && cfg.AllowOrigins == "*" {
panic("[CORS] Insecure setup, 'AllowCredentials' is set to true, and 'AllowOrigins' is set to a wildcard.")
}
// allowOrigins is a slice of strings that contains the allowed origins
// defined in the 'AllowOrigins' configuration.
allowOrigins := []string{}
allowSOrigins := []subdomain{}
allowAllOrigins := false
// Validate and normalize static AllowOrigins
if cfg.AllowOrigins != "" && cfg.AllowOrigins != "*" {
origins := strings.Split(cfg.AllowOrigins, ",")
for _, origin := range origins {
if i := strings.Index(origin, "://*."); i != -1 {
trimmedOrigin := strings.TrimSpace(origin[:i+3] + origin[i+4:])
isValid, normalizedOrigin := normalizeOrigin(trimmedOrigin)
if !isValid {
panic("[CORS] Invalid origin format in configuration: " + trimmedOrigin)
}
sd := subdomain{prefix: normalizedOrigin[:i+3], suffix: normalizedOrigin[i+3:]}
allowSOrigins = append(allowSOrigins, sd)
} else {
trimmedOrigin := strings.TrimSpace(origin)
isValid, normalizedOrigin := normalizeOrigin(trimmedOrigin)
if !isValid {
panic("[CORS] Invalid origin format in configuration: " + trimmedOrigin)
}
allowOrigins = append(allowOrigins, normalizedOrigin)
}
}
} else if cfg.AllowOrigins == "*" {
allowAllOrigins = true
}
// Strip white spaces
allowMethods := strings.ReplaceAll(cfg.AllowMethods, " ", "")
allowHeaders := strings.ReplaceAll(cfg.AllowHeaders, " ", "")
exposeHeaders := strings.ReplaceAll(cfg.ExposeHeaders, " ", "")
// Convert int to string
maxAge := strconv.Itoa(cfg.MaxAge)
// Return new handler
return func(c *fiber.Ctx) error {
// Don't execute middleware if Next returns true
if cfg.Next != nil && cfg.Next(c) {
return c.Next()
}
// Get originHeader header
originHeader := strings.ToLower(c.Get(fiber.HeaderOrigin))
// If the request does not have Origin header, the request is outside the scope of CORS
if originHeader == "" {
// See https://fetch.spec.whatwg.org/#cors-protocol-and-http-caches
// Unless all origins are allowed, we include the Vary header to cache the response correctly
if !allowAllOrigins {
c.Vary(fiber.HeaderOrigin)
}
return c.Next()
}
// If it's a preflight request and doesn't have Access-Control-Request-Method header, it's outside the scope of CORS
if c.Method() == fiber.MethodOptions && c.Get(fiber.HeaderAccessControlRequestMethod) == "" {
// Response to OPTIONS request should not be cached but,
// some caching can be configured to cache such responses.
// To Avoid poisoning the cache, we include the Vary header
// for non-CORS OPTIONS requests:
c.Vary(fiber.HeaderOrigin)
return c.Next()
}
// Set default allowOrigin to empty string
allowOrigin := ""
// Check allowed origins
if allowAllOrigins {
allowOrigin = "*"
} else {
// Check if the origin is in the list of allowed origins
for _, origin := range allowOrigins {
if origin == originHeader {
allowOrigin = originHeader
break
}
}
// Check if the origin is in the list of allowed subdomains
if allowOrigin == "" {
for _, sOrigin := range allowSOrigins {
if sOrigin.match(originHeader) {
allowOrigin = originHeader
break
}
}
}
}
// Run AllowOriginsFunc if the logic for
// handling the value in 'AllowOrigins' does
// not result in allowOrigin being set.
if allowOrigin == "" && cfg.AllowOriginsFunc != nil && cfg.AllowOriginsFunc(originHeader) {
allowOrigin = originHeader
}
// Simple request
// Ommit allowMethods and allowHeaders, only used for pre-flight requests
if c.Method() != fiber.MethodOptions {
if !allowAllOrigins {
// See https://fetch.spec.whatwg.org/#cors-protocol-and-http-caches
c.Vary(fiber.HeaderOrigin)
}
setCORSHeaders(c, allowOrigin, "", "", exposeHeaders, maxAge, cfg)
return c.Next()
}
// Pre-flight request
// Response to OPTIONS request should not be cached but,
// some caching can be configured to cache such responses.
// To Avoid poisoning the cache, we include the Vary header
// of preflight responses:
c.Vary(fiber.HeaderAccessControlRequestMethod)
c.Vary(fiber.HeaderAccessControlRequestHeaders)
c.Vary(fiber.HeaderOrigin)
setCORSHeaders(c, allowOrigin, allowMethods, allowHeaders, exposeHeaders, maxAge, cfg)
// Send 204 No Content
return c.SendStatus(fiber.StatusNoContent)
}
}
// Function to set CORS headers
func setCORSHeaders(c *fiber.Ctx, allowOrigin, allowMethods, allowHeaders, exposeHeaders, maxAge string, cfg Config) {
if cfg.AllowCredentials {
// When AllowCredentials is true, set the Access-Control-Allow-Origin to the specific origin instead of '*'
if allowOrigin == "*" {
c.Set(fiber.HeaderAccessControlAllowOrigin, allowOrigin)
log.Warn("[CORS] 'AllowCredentials' is true, but 'AllowOrigins' cannot be set to '*'.")
} else if allowOrigin != "" {
c.Set(fiber.HeaderAccessControlAllowOrigin, allowOrigin)
c.Set(fiber.HeaderAccessControlAllowCredentials, "true")
}
} else if allowOrigin != "" {
// For non-credential requests, it's safe to set to '*' or specific origins
c.Set(fiber.HeaderAccessControlAllowOrigin, allowOrigin)
}
// Set Allow-Methods if not empty
if allowMethods != "" {
c.Set(fiber.HeaderAccessControlAllowMethods, allowMethods)
}
// Set Allow-Headers if not empty
if allowHeaders != "" {
c.Set(fiber.HeaderAccessControlAllowHeaders, allowHeaders)
} else {
h := c.Get(fiber.HeaderAccessControlRequestHeaders)
if h != "" {
c.Set(fiber.HeaderAccessControlAllowHeaders, h)
}
}
// Set MaxAge if set
if cfg.MaxAge > 0 {
c.Set(fiber.HeaderAccessControlMaxAge, maxAge)
} else if cfg.MaxAge < 0 {
c.Set(fiber.HeaderAccessControlMaxAge, "0")
}
// Set Expose-Headers if not empty
if exposeHeaders != "" {
c.Set(fiber.HeaderAccessControlExposeHeaders, exposeHeaders)
}
}
+71
View File
@@ -0,0 +1,71 @@
package cors
import (
"net/url"
"strings"
)
// matchScheme compares the scheme of the domain and pattern
func matchScheme(domain, pattern string) bool {
didx := strings.Index(domain, ":")
pidx := strings.Index(pattern, ":")
return didx != -1 && pidx != -1 && domain[:didx] == pattern[:pidx]
}
// normalizeDomain removes the scheme and port from the input domain
func normalizeDomain(input string) string {
// Remove scheme
input = strings.TrimPrefix(strings.TrimPrefix(input, "http://"), "https://")
// Find and remove port, if present
if len(input) > 0 && input[0] != '[' {
if portIndex := strings.Index(input, ":"); portIndex != -1 {
input = input[:portIndex]
}
}
return input
}
// normalizeOrigin checks if the provided origin is in a correct format
// and normalizes it by removing any path or trailing slash.
// It returns a boolean indicating whether the origin is valid
// and the normalized origin.
func normalizeOrigin(origin string) (bool, string) {
parsedOrigin, err := url.Parse(origin)
if err != nil {
return false, ""
}
// Validate the scheme is either http or https
if parsedOrigin.Scheme != "http" && parsedOrigin.Scheme != "https" {
return false, ""
}
// Don't allow a wildcard with a protocol
// wildcards cannot be used within any other value. For example, the following header is not valid:
// Access-Control-Allow-Origin: https://*
if strings.Contains(parsedOrigin.Host, "*") {
return false, ""
}
// Validate there is a host present. The presence of a path, query, or fragment components
// is checked, but a trailing "/" (indicative of the root) is allowed for the path and will be normalized
if parsedOrigin.Host == "" || (parsedOrigin.Path != "" && parsedOrigin.Path != "/") || parsedOrigin.RawQuery != "" || parsedOrigin.Fragment != "" {
return false, ""
}
// Normalize the origin by constructing it from the scheme and host.
// The path or trailing slash is not included in the normalized origin.
return true, strings.ToLower(parsedOrigin.Scheme) + "://" + strings.ToLower(parsedOrigin.Host)
}
type subdomain struct {
// The wildcard pattern
prefix string
suffix string
}
func (s subdomain) match(o string) bool {
return len(o) >= len(s.prefix)+len(s.suffix) && strings.HasPrefix(o, s.prefix) && strings.HasSuffix(o, s.suffix)
}
+154
View File
@@ -0,0 +1,154 @@
package helmet
import (
"github.com/gofiber/fiber/v2"
)
// Config defines the config for middleware.
type Config struct {
// Next defines a function to skip middleware.
// Optional. Default: nil
Next func(*fiber.Ctx) bool
// XSSProtection
// Optional. Default value "0".
XSSProtection string
// ContentTypeNosniff
// Optional. Default value "nosniff".
ContentTypeNosniff string
// XFrameOptions
// Optional. Default value "SAMEORIGIN".
// Possible values: "SAMEORIGIN", "DENY", "ALLOW-FROM uri"
XFrameOptions string
// HSTSMaxAge
// Optional. Default value 0.
HSTSMaxAge int
// HSTSExcludeSubdomains
// Optional. Default value false.
HSTSExcludeSubdomains bool
// ContentSecurityPolicy
// Optional. Default value "".
ContentSecurityPolicy string
// CSPReportOnly
// Optional. Default value false.
CSPReportOnly bool
// HSTSPreloadEnabled
// Optional. Default value false.
HSTSPreloadEnabled bool
// ReferrerPolicy
// Optional. Default value "ReferrerPolicy".
ReferrerPolicy string
// Permissions-Policy
// Optional. Default value "".
PermissionPolicy string
// Cross-Origin-Embedder-Policy
// Optional. Default value "require-corp".
CrossOriginEmbedderPolicy string
// Cross-Origin-Opener-Policy
// Optional. Default value "same-origin".
CrossOriginOpenerPolicy string
// Cross-Origin-Resource-Policy
// Optional. Default value "same-origin".
CrossOriginResourcePolicy string
// Origin-Agent-Cluster
// Optional. Default value "?1".
OriginAgentCluster string
// X-DNS-Prefetch-Control
// Optional. Default value "off".
XDNSPrefetchControl string
// X-Download-Options
// Optional. Default value "noopen".
XDownloadOptions string
// X-Permitted-Cross-Domain-Policies
// Optional. Default value "none".
XPermittedCrossDomain string
}
// ConfigDefault is the default config
var ConfigDefault = Config{
XSSProtection: "0",
ContentTypeNosniff: "nosniff",
XFrameOptions: "SAMEORIGIN",
ReferrerPolicy: "no-referrer",
CrossOriginEmbedderPolicy: "require-corp",
CrossOriginOpenerPolicy: "same-origin",
CrossOriginResourcePolicy: "same-origin",
OriginAgentCluster: "?1",
XDNSPrefetchControl: "off",
XDownloadOptions: "noopen",
XPermittedCrossDomain: "none",
}
// Helper function to set default values
func configDefault(config ...Config) Config {
// Return default config if nothing provided
if len(config) < 1 {
return ConfigDefault
}
// Override default config
cfg := config[0]
// Set default values
if cfg.XSSProtection == "" {
cfg.XSSProtection = ConfigDefault.XSSProtection
}
if cfg.ContentTypeNosniff == "" {
cfg.ContentTypeNosniff = ConfigDefault.ContentTypeNosniff
}
if cfg.XFrameOptions == "" {
cfg.XFrameOptions = ConfigDefault.XFrameOptions
}
if cfg.ReferrerPolicy == "" {
cfg.ReferrerPolicy = ConfigDefault.ReferrerPolicy
}
if cfg.CrossOriginEmbedderPolicy == "" {
cfg.CrossOriginEmbedderPolicy = ConfigDefault.CrossOriginEmbedderPolicy
}
if cfg.CrossOriginOpenerPolicy == "" {
cfg.CrossOriginOpenerPolicy = ConfigDefault.CrossOriginOpenerPolicy
}
if cfg.CrossOriginResourcePolicy == "" {
cfg.CrossOriginResourcePolicy = ConfigDefault.CrossOriginResourcePolicy
}
if cfg.OriginAgentCluster == "" {
cfg.OriginAgentCluster = ConfigDefault.OriginAgentCluster
}
if cfg.XDNSPrefetchControl == "" {
cfg.XDNSPrefetchControl = ConfigDefault.XDNSPrefetchControl
}
if cfg.XDownloadOptions == "" {
cfg.XDownloadOptions = ConfigDefault.XDownloadOptions
}
if cfg.XPermittedCrossDomain == "" {
cfg.XPermittedCrossDomain = ConfigDefault.XPermittedCrossDomain
}
return cfg
}
+94
View File
@@ -0,0 +1,94 @@
package helmet
import (
"fmt"
"github.com/gofiber/fiber/v2"
)
// New creates a new middleware handler
func New(config ...Config) fiber.Handler {
// Init config
cfg := configDefault(config...)
// Return middleware handler
return func(c *fiber.Ctx) error {
// Next request to skip middleware
if cfg.Next != nil && cfg.Next(c) {
return c.Next()
}
// Set headers
if cfg.XSSProtection != "" {
c.Set(fiber.HeaderXXSSProtection, cfg.XSSProtection)
}
if cfg.ContentTypeNosniff != "" {
c.Set(fiber.HeaderXContentTypeOptions, cfg.ContentTypeNosniff)
}
if cfg.XFrameOptions != "" {
c.Set(fiber.HeaderXFrameOptions, cfg.XFrameOptions)
}
if cfg.CrossOriginEmbedderPolicy != "" {
c.Set("Cross-Origin-Embedder-Policy", cfg.CrossOriginEmbedderPolicy)
}
if cfg.CrossOriginOpenerPolicy != "" {
c.Set("Cross-Origin-Opener-Policy", cfg.CrossOriginOpenerPolicy)
}
if cfg.CrossOriginResourcePolicy != "" {
c.Set("Cross-Origin-Resource-Policy", cfg.CrossOriginResourcePolicy)
}
if cfg.OriginAgentCluster != "" {
c.Set("Origin-Agent-Cluster", cfg.OriginAgentCluster)
}
if cfg.ReferrerPolicy != "" {
c.Set("Referrer-Policy", cfg.ReferrerPolicy)
}
if cfg.XDNSPrefetchControl != "" {
c.Set("X-DNS-Prefetch-Control", cfg.XDNSPrefetchControl)
}
if cfg.XDownloadOptions != "" {
c.Set("X-Download-Options", cfg.XDownloadOptions)
}
if cfg.XPermittedCrossDomain != "" {
c.Set("X-Permitted-Cross-Domain-Policies", cfg.XPermittedCrossDomain)
}
// Handle HSTS headers
if c.Protocol() == "https" && cfg.HSTSMaxAge != 0 {
subdomains := ""
if !cfg.HSTSExcludeSubdomains {
subdomains = "; includeSubDomains"
}
if cfg.HSTSPreloadEnabled {
subdomains = fmt.Sprintf("%s; preload", subdomains)
}
c.Set(fiber.HeaderStrictTransportSecurity, fmt.Sprintf("max-age=%d%s", cfg.HSTSMaxAge, subdomains))
}
// Handle Content-Security-Policy headers
if cfg.ContentSecurityPolicy != "" {
if cfg.CSPReportOnly {
c.Set(fiber.HeaderContentSecurityPolicyReportOnly, cfg.ContentSecurityPolicy)
} else {
c.Set(fiber.HeaderContentSecurityPolicy, cfg.ContentSecurityPolicy)
}
}
// Handle Permissions-Policy headers
if cfg.PermissionPolicy != "" {
c.Set(fiber.HeaderPermissionsPolicy, cfg.PermissionPolicy)
}
return c.Next()
}
}
+128
View File
@@ -0,0 +1,128 @@
package limiter
import (
"time"
"github.com/gofiber/fiber/v2"
"github.com/gofiber/fiber/v2/log"
)
// Config defines the config for middleware.
type Config struct {
// Next defines a function to skip this middleware when returned true.
//
// Optional. Default: nil
Next func(c *fiber.Ctx) bool
// Max number of recent connections during `Expiration` seconds before sending a 429 response
//
// Default: 5
Max int
// KeyGenerator allows you to generate custom keys, by default c.IP() is used
//
// Default: func(c *fiber.Ctx) string {
// return c.IP()
// }
KeyGenerator func(*fiber.Ctx) string
// Expiration is the time on how long to keep records of requests in memory
//
// Default: 1 * time.Minute
Expiration time.Duration
// LimitReached is called when a request hits the limit
//
// Default: func(c *fiber.Ctx) error {
// return c.SendStatus(fiber.StatusTooManyRequests)
// }
LimitReached fiber.Handler
// When set to true, requests with StatusCode >= 400 won't be counted.
//
// Default: false
SkipFailedRequests bool
// When set to true, requests with StatusCode < 400 won't be counted.
//
// Default: false
SkipSuccessfulRequests bool
// Store is used to store the state of the middleware
//
// Default: an in memory store for this process only
Storage fiber.Storage
// LimiterMiddleware is the struct that implements a limiter middleware.
//
// Default: a new Fixed Window Rate Limiter
LimiterMiddleware LimiterHandler
// Deprecated: Use Expiration instead
Duration time.Duration
// Deprecated: Use Storage instead
Store fiber.Storage
// Deprecated: Use KeyGenerator instead
Key func(*fiber.Ctx) string
}
// ConfigDefault is the default config
var ConfigDefault = Config{
Max: 5,
Expiration: 1 * time.Minute,
KeyGenerator: func(c *fiber.Ctx) string {
return c.IP()
},
LimitReached: func(c *fiber.Ctx) error {
return c.SendStatus(fiber.StatusTooManyRequests)
},
SkipFailedRequests: false,
SkipSuccessfulRequests: false,
LimiterMiddleware: FixedWindow{},
}
// Helper function to set default values
func configDefault(config ...Config) Config {
// Return default config if nothing provided
if len(config) < 1 {
return ConfigDefault
}
// Override default config
cfg := config[0]
// Set default values
if int(cfg.Duration.Seconds()) > 0 {
log.Warn("[LIMITER] Duration is deprecated, please use Expiration")
cfg.Expiration = cfg.Duration
}
if cfg.Key != nil {
log.Warn("[LIMITER] Key is deprecated, please us KeyGenerator")
cfg.KeyGenerator = cfg.Key
}
if cfg.Store != nil {
log.Warn("[LIMITER] Store is deprecated, please use Storage")
cfg.Storage = cfg.Store
}
if cfg.Next == nil {
cfg.Next = ConfigDefault.Next
}
if cfg.Max <= 0 {
cfg.Max = ConfigDefault.Max
}
if int(cfg.Expiration.Seconds()) <= 0 {
cfg.Expiration = ConfigDefault.Expiration
}
if cfg.KeyGenerator == nil {
cfg.KeyGenerator = ConfigDefault.KeyGenerator
}
if cfg.LimitReached == nil {
cfg.LimitReached = ConfigDefault.LimitReached
}
if cfg.LimiterMiddleware == nil {
cfg.LimiterMiddleware = ConfigDefault.LimiterMiddleware
}
return cfg
}
+25
View File
@@ -0,0 +1,25 @@
package limiter
import (
"github.com/gofiber/fiber/v2"
)
const (
// X-RateLimit-* headers
xRateLimitLimit = "X-RateLimit-Limit"
xRateLimitRemaining = "X-RateLimit-Remaining"
xRateLimitReset = "X-RateLimit-Reset"
)
type LimiterHandler interface {
New(config Config) fiber.Handler
}
// New creates a new middleware handler
func New(config ...Config) fiber.Handler {
// Set default config
cfg := configDefault(config...)
// Return the specified middleware handler.
return cfg.LimiterMiddleware.New(cfg)
}
+106
View File
@@ -0,0 +1,106 @@
package limiter
import (
"strconv"
"sync"
"sync/atomic"
"github.com/gofiber/fiber/v2"
"github.com/gofiber/fiber/v2/utils"
)
type FixedWindow struct{}
// New creates a new fixed window middleware handler
func (FixedWindow) New(cfg Config) fiber.Handler {
var (
// Limiter variables
mux = &sync.RWMutex{}
max = strconv.Itoa(cfg.Max)
expiration = uint64(cfg.Expiration.Seconds())
)
// Create manager to simplify storage operations ( see manager.go )
manager := newManager(cfg.Storage)
// Update timestamp every second
utils.StartTimeStampUpdater()
// Return new handler
return func(c *fiber.Ctx) error {
// Don't execute middleware if Next returns true
if cfg.Next != nil && cfg.Next(c) {
return c.Next()
}
// Get key from request
key := cfg.KeyGenerator(c)
// Lock entry
mux.Lock()
// Get entry from pool and release when finished
e := manager.get(key)
// Get timestamp
ts := uint64(atomic.LoadUint32(&utils.Timestamp))
// Set expiration if entry does not exist
if e.exp == 0 {
e.exp = ts + expiration
} else if ts >= e.exp {
// Check if entry is expired
e.currHits = 0
e.exp = ts + expiration
}
// Increment hits
e.currHits++
// Calculate when it resets in seconds
resetInSec := e.exp - ts
// Set how many hits we have left
remaining := cfg.Max - e.currHits
// Update storage
manager.set(key, e, cfg.Expiration)
// Unlock entry
mux.Unlock()
// Check if hits exceed the cfg.Max
if remaining < 0 {
// Return response with Retry-After header
// https://tools.ietf.org/html/rfc6584
c.Set(fiber.HeaderRetryAfter, strconv.FormatUint(resetInSec, 10))
// Call LimitReached handler
return cfg.LimitReached(c)
}
// Continue stack for reaching c.Response().StatusCode()
// Store err for returning
err := c.Next()
// Check for SkipFailedRequests and SkipSuccessfulRequests
if (cfg.SkipSuccessfulRequests && c.Response().StatusCode() < fiber.StatusBadRequest) ||
(cfg.SkipFailedRequests && c.Response().StatusCode() >= fiber.StatusBadRequest) {
// Lock entry
mux.Lock()
e = manager.get(key)
e.currHits--
remaining++
manager.set(key, e, cfg.Expiration)
// Unlock entry
mux.Unlock()
}
// We can continue, update RateLimit headers
c.Set(xRateLimitLimit, max)
c.Set(xRateLimitRemaining, strconv.Itoa(remaining))
c.Set(xRateLimitReset, strconv.FormatUint(resetInSec, 10))
return err
}
}
@@ -0,0 +1,137 @@
package limiter
import (
"strconv"
"sync"
"sync/atomic"
"time"
"github.com/gofiber/fiber/v2"
"github.com/gofiber/fiber/v2/utils"
)
type SlidingWindow struct{}
// New creates a new sliding window middleware handler
func (SlidingWindow) New(cfg Config) fiber.Handler {
var (
// Limiter variables
mux = &sync.RWMutex{}
max = strconv.Itoa(cfg.Max)
expiration = uint64(cfg.Expiration.Seconds())
)
// Create manager to simplify storage operations ( see manager.go )
manager := newManager(cfg.Storage)
// Update timestamp every second
utils.StartTimeStampUpdater()
// Return new handler
return func(c *fiber.Ctx) error {
// Don't execute middleware if Next returns true
if cfg.Next != nil && cfg.Next(c) {
return c.Next()
}
// Get key from request
key := cfg.KeyGenerator(c)
// Lock entry
mux.Lock()
// Get entry from pool and release when finished
e := manager.get(key)
// Get timestamp
ts := uint64(atomic.LoadUint32(&utils.Timestamp))
// Set expiration if entry does not exist
if e.exp == 0 {
e.exp = ts + expiration
} else if ts >= e.exp {
// The entry has expired, handle the expiration.
// Set the prevHits to the current hits and reset the hits to 0.
e.prevHits = e.currHits
// Reset the current hits to 0.
e.currHits = 0
// Check how much into the current window it currently is and sets the
// expiry based on that, otherwise this would only reset on
// the next request and not show the correct expiry.
elapsed := ts - e.exp
if elapsed >= expiration {
e.exp = ts + expiration
} else {
e.exp = ts + expiration - elapsed
}
}
// Increment hits
e.currHits++
// Calculate when it resets in seconds
resetInSec := e.exp - ts
// weight = time until current window reset / total window length
weight := float64(resetInSec) / float64(expiration)
// rate = request count in previous window - weight + request count in current window
rate := int(float64(e.prevHits)*weight) + e.currHits
// Calculate how many hits can be made based on the current rate
remaining := cfg.Max - rate
// Update storage. Garbage collect when the next window ends.
// |--------------------------|--------------------------|
// ^ ^ ^ ^
// ts e.exp End sample window End next window
// <------------>
// resetInSec
// resetInSec = e.exp - ts - time until end of current window.
// duration + expiration = end of next window.
// Because we don't want to garbage collect in the middle of a window
// we add the expiration to the duration.
// Otherwise after the end of "sample window", attackers could launch
// a new request with the full window length.
manager.set(key, e, time.Duration(resetInSec+expiration)*time.Second)
// Unlock entry
mux.Unlock()
// Check if hits exceed the cfg.Max
if remaining < 0 {
// Return response with Retry-After header
// https://tools.ietf.org/html/rfc6584
c.Set(fiber.HeaderRetryAfter, strconv.FormatUint(resetInSec, 10))
// Call LimitReached handler
return cfg.LimitReached(c)
}
// Continue stack for reaching c.Response().StatusCode()
// Store err for returning
err := c.Next()
// Check for SkipFailedRequests and SkipSuccessfulRequests
if (cfg.SkipSuccessfulRequests && c.Response().StatusCode() < fiber.StatusBadRequest) ||
(cfg.SkipFailedRequests && c.Response().StatusCode() >= fiber.StatusBadRequest) {
// Lock entry
mux.Lock()
e = manager.get(key)
e.currHits--
remaining++
manager.set(key, e, cfg.Expiration)
// Unlock entry
mux.Unlock()
}
// We can continue, update RateLimit headers
c.Set(xRateLimitLimit, max)
c.Set(xRateLimitRemaining, strconv.Itoa(remaining))
c.Set(xRateLimitReset, strconv.FormatUint(resetInSec, 10))
return err
}
}
+92
View File
@@ -0,0 +1,92 @@
package limiter
import (
"sync"
"time"
"github.com/gofiber/fiber/v2"
"github.com/gofiber/fiber/v2/internal/memory"
)
// go:generate msgp
// msgp -file="manager.go" -o="manager_msgp.go" -tests=false -unexported
type item struct {
currHits int
prevHits int
exp uint64
}
//msgp:ignore manager
type manager struct {
pool sync.Pool
memory *memory.Storage
storage fiber.Storage
}
func newManager(storage fiber.Storage) *manager {
// Create new storage handler
manager := &manager{
pool: sync.Pool{
New: func() interface{} {
return new(item)
},
},
}
if storage != nil {
// Use provided storage if provided
manager.storage = storage
} else {
// Fallback too memory storage
manager.memory = memory.New()
}
return manager
}
// acquire returns an *entry from the sync.Pool
func (m *manager) acquire() *item {
return m.pool.Get().(*item) //nolint:forcetypeassert // We store nothing else in the pool
}
// release and reset *entry to sync.Pool
func (m *manager) release(e *item) {
e.prevHits = 0
e.currHits = 0
e.exp = 0
m.pool.Put(e)
}
// get data from storage or memory
func (m *manager) get(key string) *item {
var it *item
if m.storage != nil {
it = m.acquire()
raw, err := m.storage.Get(key)
if err != nil {
return it
}
if raw != nil {
if _, err := it.UnmarshalMsg(raw); err != nil {
return it
}
}
return it
}
if it, _ = m.memory.Get(key).(*item); it == nil { //nolint:errcheck // We store nothing else in the pool
it = m.acquire()
return it
}
return it
}
// set data to storage or memory
func (m *manager) set(key string, it *item, exp time.Duration) {
if m.storage != nil {
if raw, err := it.MarshalMsg(nil); err == nil {
_ = m.storage.Set(key, raw, exp) //nolint:errcheck // TODO: Handle error here
}
// we can release data because it's serialized to database
m.release(it)
} else {
m.memory.Set(key, it, exp)
}
}
+160
View File
@@ -0,0 +1,160 @@
package limiter
// Code generated by github.com/tinylib/msgp DO NOT EDIT.
import (
"github.com/tinylib/msgp/msgp"
)
// DecodeMsg implements msgp.Decodable
func (z *item) DecodeMsg(dc *msgp.Reader) (err error) {
var field []byte
_ = field
var zb0001 uint32
zb0001, err = dc.ReadMapHeader()
if err != nil {
err = msgp.WrapError(err)
return
}
for zb0001 > 0 {
zb0001--
field, err = dc.ReadMapKeyPtr()
if err != nil {
err = msgp.WrapError(err)
return
}
switch msgp.UnsafeString(field) {
case "currHits":
z.currHits, err = dc.ReadInt()
if err != nil {
err = msgp.WrapError(err, "currHits")
return
}
case "prevHits":
z.prevHits, err = dc.ReadInt()
if err != nil {
err = msgp.WrapError(err, "prevHits")
return
}
case "exp":
z.exp, err = dc.ReadUint64()
if err != nil {
err = msgp.WrapError(err, "exp")
return
}
default:
err = dc.Skip()
if err != nil {
err = msgp.WrapError(err)
return
}
}
}
return
}
// EncodeMsg implements msgp.Encodable
func (z item) EncodeMsg(en *msgp.Writer) (err error) {
// map header, size 3
// write "currHits"
err = en.Append(0x83, 0xa8, 0x63, 0x75, 0x72, 0x72, 0x48, 0x69, 0x74, 0x73)
if err != nil {
return
}
err = en.WriteInt(z.currHits)
if err != nil {
err = msgp.WrapError(err, "currHits")
return
}
// write "prevHits"
err = en.Append(0xa8, 0x70, 0x72, 0x65, 0x76, 0x48, 0x69, 0x74, 0x73)
if err != nil {
return
}
err = en.WriteInt(z.prevHits)
if err != nil {
err = msgp.WrapError(err, "prevHits")
return
}
// write "exp"
err = en.Append(0xa3, 0x65, 0x78, 0x70)
if err != nil {
return
}
err = en.WriteUint64(z.exp)
if err != nil {
err = msgp.WrapError(err, "exp")
return
}
return
}
// MarshalMsg implements msgp.Marshaler
func (z item) MarshalMsg(b []byte) (o []byte, err error) {
o = msgp.Require(b, z.Msgsize())
// map header, size 3
// string "currHits"
o = append(o, 0x83, 0xa8, 0x63, 0x75, 0x72, 0x72, 0x48, 0x69, 0x74, 0x73)
o = msgp.AppendInt(o, z.currHits)
// string "prevHits"
o = append(o, 0xa8, 0x70, 0x72, 0x65, 0x76, 0x48, 0x69, 0x74, 0x73)
o = msgp.AppendInt(o, z.prevHits)
// string "exp"
o = append(o, 0xa3, 0x65, 0x78, 0x70)
o = msgp.AppendUint64(o, z.exp)
return
}
// UnmarshalMsg implements msgp.Unmarshaler
func (z *item) UnmarshalMsg(bts []byte) (o []byte, err error) {
var field []byte
_ = field
var zb0001 uint32
zb0001, bts, err = msgp.ReadMapHeaderBytes(bts)
if err != nil {
err = msgp.WrapError(err)
return
}
for zb0001 > 0 {
zb0001--
field, bts, err = msgp.ReadMapKeyZC(bts)
if err != nil {
err = msgp.WrapError(err)
return
}
switch msgp.UnsafeString(field) {
case "currHits":
z.currHits, bts, err = msgp.ReadIntBytes(bts)
if err != nil {
err = msgp.WrapError(err, "currHits")
return
}
case "prevHits":
z.prevHits, bts, err = msgp.ReadIntBytes(bts)
if err != nil {
err = msgp.WrapError(err, "prevHits")
return
}
case "exp":
z.exp, bts, err = msgp.ReadUint64Bytes(bts)
if err != nil {
err = msgp.WrapError(err, "exp")
return
}
default:
bts, err = msgp.Skip(bts)
if err != nil {
err = msgp.WrapError(err)
return
}
}
}
o = bts
return
}
// Msgsize returns an upper bound estimate of the number of bytes occupied by the serialized message
func (z item) Msgsize() (s int) {
s = 1 + 9 + msgp.IntSize + 9 + msgp.IntSize + 4 + msgp.Uint64Size
return
}
+136
View File
@@ -0,0 +1,136 @@
package logger
import (
"io"
"os"
"time"
"github.com/gofiber/fiber/v2"
)
// Config defines the config for middleware.
type Config struct {
// Next defines a function to skip this middleware when returned true.
//
// Optional. Default: nil
Next func(c *fiber.Ctx) bool
// Done is a function that is called after the log string for a request is written to Output,
// and pass the log string as parameter.
//
// Optional. Default: nil
Done func(c *fiber.Ctx, logString []byte)
// tagFunctions defines the custom tag action
//
// Optional. Default: map[string]LogFunc
CustomTags map[string]LogFunc
// Format defines the logging tags
//
// Optional. Default: ${time} | ${status} | ${latency} | ${ip} | ${method} | ${path} | ${error}\n
Format string
// TimeFormat https://programming.guide/go/format-parse-string-time-date-example.html
//
// Optional. Default: 15:04:05
TimeFormat string
// TimeZone can be specified, such as "UTC" and "America/New_York" and "Asia/Chongqing", etc
//
// Optional. Default: "Local"
TimeZone string
// TimeInterval is the delay before the timestamp is updated
//
// Optional. Default: 500 * time.Millisecond
TimeInterval time.Duration
// Output is a writer where logs are written
//
// Default: os.Stdout
Output io.Writer
// DisableColors defines if the logs output should be colorized
//
// Default: false
DisableColors bool
enableColors bool
enableLatency bool
timeZoneLocation *time.Location
}
const (
startTag = "${"
endTag = "}"
paramSeparator = ":"
)
type Buffer interface {
Len() int
ReadFrom(r io.Reader) (int64, error)
WriteTo(w io.Writer) (int64, error)
Bytes() []byte
Write(p []byte) (int, error)
WriteByte(c byte) error
WriteString(s string) (int, error)
Set(p []byte)
SetString(s string)
String() string
}
type LogFunc func(output Buffer, c *fiber.Ctx, data *Data, extraParam string) (int, error)
// ConfigDefault is the default config
var ConfigDefault = Config{
Next: nil,
Done: nil,
Format: "${time} | ${status} | ${latency} | ${ip} | ${method} | ${path} | ${error}\n",
TimeFormat: "15:04:05",
TimeZone: "Local",
TimeInterval: 500 * time.Millisecond,
Output: os.Stdout,
DisableColors: false,
enableColors: true,
}
// Helper function to set default values
func configDefault(config ...Config) Config {
// Return default config if nothing provided
if len(config) < 1 {
return ConfigDefault
}
// Override default config
cfg := config[0]
// Set default values
if cfg.Next == nil {
cfg.Next = ConfigDefault.Next
}
if cfg.Done == nil {
cfg.Done = ConfigDefault.Done
}
if cfg.Format == "" {
cfg.Format = ConfigDefault.Format
}
if cfg.TimeZone == "" {
cfg.TimeZone = ConfigDefault.TimeZone
}
if cfg.TimeFormat == "" {
cfg.TimeFormat = ConfigDefault.TimeFormat
}
if int(cfg.TimeInterval) <= 0 {
cfg.TimeInterval = ConfigDefault.TimeInterval
}
if cfg.Output == nil {
cfg.Output = ConfigDefault.Output
}
if !cfg.DisableColors && cfg.Output == ConfigDefault.Output {
cfg.enableColors = true
}
return cfg
}
+16
View File
@@ -0,0 +1,16 @@
package logger
import (
"sync/atomic"
"time"
)
// Data is a struct to define some variables to use in custom logger function.
type Data struct {
Pid string
ErrPaddingStr string
ChainErr error
Start time.Time
Stop time.Time
Timestamp atomic.Value
}
+182
View File
@@ -0,0 +1,182 @@
package logger
import (
"fmt"
"os"
"strconv"
"strings"
"sync"
"sync/atomic"
"time"
"github.com/gofiber/fiber/v2"
"github.com/gofiber/fiber/v2/utils"
"github.com/mattn/go-colorable"
"github.com/mattn/go-isatty"
"github.com/valyala/bytebufferpool"
"github.com/valyala/fasthttp"
)
// New creates a new middleware handler
func New(config ...Config) fiber.Handler {
// Set default config
cfg := configDefault(config...)
// Get timezone location
tz, err := time.LoadLocation(cfg.TimeZone)
if err != nil || tz == nil {
cfg.timeZoneLocation = time.Local
} else {
cfg.timeZoneLocation = tz
}
// Check if format contains latency
cfg.enableLatency = strings.Contains(cfg.Format, "${"+TagLatency+"}")
var timestamp atomic.Value
// Create correct timeformat
timestamp.Store(time.Now().In(cfg.timeZoneLocation).Format(cfg.TimeFormat))
// Update date/time every 500 milliseconds in a separate go routine
if strings.Contains(cfg.Format, "${"+TagTime+"}") {
go func() {
for {
time.Sleep(cfg.TimeInterval)
timestamp.Store(time.Now().In(cfg.timeZoneLocation).Format(cfg.TimeFormat))
}
}()
}
// Set PID once
pid := strconv.Itoa(os.Getpid())
// Set variables
var (
once sync.Once
mu sync.Mutex
errHandler fiber.ErrorHandler
dataPool = sync.Pool{New: func() interface{} { return new(Data) }}
)
// If colors are enabled, check terminal compatibility
if cfg.enableColors {
cfg.Output = colorable.NewColorableStdout()
if os.Getenv("TERM") == "dumb" || os.Getenv("NO_COLOR") == "1" || (!isatty.IsTerminal(os.Stdout.Fd()) && !isatty.IsCygwinTerminal(os.Stdout.Fd())) {
cfg.Output = colorable.NewNonColorable(os.Stdout)
}
}
errPadding := 15
errPaddingStr := strconv.Itoa(errPadding)
// instead of analyzing the template inside(handler) each time, this is done once before
// and we create several slices of the same length with the functions to be executed and fixed parts.
templateChain, logFunChain, err := buildLogFuncChain(&cfg, createTagMap(&cfg))
if err != nil {
panic(err)
}
// Return new handler
return func(c *fiber.Ctx) error {
// Don't execute middleware if Next returns true
if cfg.Next != nil && cfg.Next(c) {
return c.Next()
}
// Set error handler once
once.Do(func() {
// get longested possible path
stack := c.App().Stack()
for m := range stack {
for r := range stack[m] {
if len(stack[m][r].Path) > errPadding {
errPadding = len(stack[m][r].Path)
errPaddingStr = strconv.Itoa(errPadding)
}
}
}
// override error handler
errHandler = c.App().ErrorHandler
})
// Logger data
data := dataPool.Get().(*Data) //nolint:forcetypeassert,errcheck // We store nothing else in the pool
// no need for a reset, as long as we always override everything
data.Pid = pid
data.ErrPaddingStr = errPaddingStr
data.Timestamp = timestamp
// put data back in the pool
defer dataPool.Put(data)
// Set latency start time
if cfg.enableLatency {
data.Start = time.Now()
}
// Handle request, store err for logging
chainErr := c.Next()
data.ChainErr = chainErr
// Manually call error handler
if chainErr != nil {
if err := errHandler(c, chainErr); err != nil {
_ = c.SendStatus(fiber.StatusInternalServerError) //nolint:errcheck // TODO: Explain why we ignore the error here
}
}
// Set latency stop time
if cfg.enableLatency {
data.Stop = time.Now()
}
// Get new buffer
buf := bytebufferpool.Get()
var err error
// Loop over template parts execute dynamic parts and add fixed parts to the buffer
for i, logFunc := range logFunChain {
if logFunc == nil {
_, _ = buf.Write(templateChain[i]) //nolint:errcheck // This will never fail
} else if templateChain[i] == nil {
_, err = logFunc(buf, c, data, "")
} else {
_, err = logFunc(buf, c, data, utils.UnsafeString(templateChain[i]))
}
if err != nil {
break
}
}
// Also write errors to the buffer
if err != nil {
_, _ = buf.WriteString(err.Error()) //nolint:errcheck // This will never fail
}
mu.Lock()
// Write buffer to output
if _, err := cfg.Output.Write(buf.Bytes()); err != nil {
// Write error to output
if _, err := cfg.Output.Write([]byte(err.Error())); err != nil {
// There is something wrong with the given io.Writer
_, _ = fmt.Fprintf(os.Stderr, "Failed to write to log, %v\n", err)
}
}
mu.Unlock()
if cfg.Done != nil {
cfg.Done(c, buf.Bytes())
}
// Put buffer back to pool
bytebufferpool.Put(buf)
return nil
}
}
func appendInt(output Buffer, v int) (int, error) {
old := output.Len()
output.Set(fasthttp.AppendUint(output.Bytes(), v))
return output.Len() - old, nil
}
+211
View File
@@ -0,0 +1,211 @@
package logger
import (
"fmt"
"strings"
"github.com/gofiber/fiber/v2"
)
// Logger variables
const (
TagPid = "pid"
TagTime = "time"
TagReferer = "referer"
TagProtocol = "protocol"
TagPort = "port"
TagIP = "ip"
TagIPs = "ips"
TagHost = "host"
TagMethod = "method"
TagPath = "path"
TagURL = "url"
TagUA = "ua"
TagLatency = "latency"
TagStatus = "status"
TagResBody = "resBody"
TagReqHeaders = "reqHeaders"
TagQueryStringParams = "queryParams"
TagBody = "body"
TagBytesSent = "bytesSent"
TagBytesReceived = "bytesReceived"
TagRoute = "route"
TagError = "error"
// Deprecated: Use TagReqHeader instead
TagHeader = "header:"
TagReqHeader = "reqHeader:"
TagRespHeader = "respHeader:"
TagLocals = "locals:"
TagQuery = "query:"
TagForm = "form:"
TagCookie = "cookie:"
TagBlack = "black"
TagRed = "red"
TagGreen = "green"
TagYellow = "yellow"
TagBlue = "blue"
TagMagenta = "magenta"
TagCyan = "cyan"
TagWhite = "white"
TagReset = "reset"
)
// createTagMap function merged the default with the custom tags
func createTagMap(cfg *Config) map[string]LogFunc {
// Set default tags
tagFunctions := map[string]LogFunc{
TagReferer: func(output Buffer, c *fiber.Ctx, data *Data, extraParam string) (int, error) {
return output.WriteString(c.Get(fiber.HeaderReferer))
},
TagProtocol: func(output Buffer, c *fiber.Ctx, data *Data, extraParam string) (int, error) {
return output.WriteString(c.Protocol())
},
TagPort: func(output Buffer, c *fiber.Ctx, data *Data, extraParam string) (int, error) {
return output.WriteString(c.Port())
},
TagIP: func(output Buffer, c *fiber.Ctx, data *Data, extraParam string) (int, error) {
return output.WriteString(c.IP())
},
TagIPs: func(output Buffer, c *fiber.Ctx, data *Data, extraParam string) (int, error) {
return output.WriteString(c.Get(fiber.HeaderXForwardedFor))
},
TagHost: func(output Buffer, c *fiber.Ctx, data *Data, extraParam string) (int, error) {
return output.WriteString(c.Hostname())
},
TagPath: func(output Buffer, c *fiber.Ctx, data *Data, extraParam string) (int, error) {
return output.WriteString(c.Path())
},
TagURL: func(output Buffer, c *fiber.Ctx, data *Data, extraParam string) (int, error) {
return output.WriteString(c.OriginalURL())
},
TagUA: func(output Buffer, c *fiber.Ctx, data *Data, extraParam string) (int, error) {
return output.WriteString(c.Get(fiber.HeaderUserAgent))
},
TagBody: func(output Buffer, c *fiber.Ctx, data *Data, extraParam string) (int, error) {
return output.Write(c.Body())
},
TagBytesReceived: func(output Buffer, c *fiber.Ctx, data *Data, extraParam string) (int, error) {
return appendInt(output, len(c.Request().Body()))
},
TagBytesSent: func(output Buffer, c *fiber.Ctx, data *Data, extraParam string) (int, error) {
if c.Response().Header.ContentLength() < 0 {
return appendInt(output, 0)
}
return appendInt(output, len(c.Response().Body()))
},
TagRoute: func(output Buffer, c *fiber.Ctx, data *Data, extraParam string) (int, error) {
return output.WriteString(c.Route().Path)
},
TagResBody: func(output Buffer, c *fiber.Ctx, data *Data, extraParam string) (int, error) {
return output.Write(c.Response().Body())
},
TagReqHeaders: func(output Buffer, c *fiber.Ctx, data *Data, extraParam string) (int, error) {
reqHeaders := make([]string, 0)
for k, v := range c.GetReqHeaders() {
reqHeaders = append(reqHeaders, k+"="+strings.Join(v, ","))
}
return output.Write([]byte(strings.Join(reqHeaders, "&")))
},
TagQueryStringParams: func(output Buffer, c *fiber.Ctx, data *Data, extraParam string) (int, error) {
return output.WriteString(c.Request().URI().QueryArgs().String())
},
TagBlack: func(output Buffer, c *fiber.Ctx, data *Data, extraParam string) (int, error) {
return output.WriteString(c.App().Config().ColorScheme.Black)
},
TagRed: func(output Buffer, c *fiber.Ctx, data *Data, extraParam string) (int, error) {
return output.WriteString(c.App().Config().ColorScheme.Red)
},
TagGreen: func(output Buffer, c *fiber.Ctx, data *Data, extraParam string) (int, error) {
return output.WriteString(c.App().Config().ColorScheme.Green)
},
TagYellow: func(output Buffer, c *fiber.Ctx, data *Data, extraParam string) (int, error) {
return output.WriteString(c.App().Config().ColorScheme.Yellow)
},
TagBlue: func(output Buffer, c *fiber.Ctx, data *Data, extraParam string) (int, error) {
return output.WriteString(c.App().Config().ColorScheme.Blue)
},
TagMagenta: func(output Buffer, c *fiber.Ctx, data *Data, extraParam string) (int, error) {
return output.WriteString(c.App().Config().ColorScheme.Magenta)
},
TagCyan: func(output Buffer, c *fiber.Ctx, data *Data, extraParam string) (int, error) {
return output.WriteString(c.App().Config().ColorScheme.Cyan)
},
TagWhite: func(output Buffer, c *fiber.Ctx, data *Data, extraParam string) (int, error) {
return output.WriteString(c.App().Config().ColorScheme.White)
},
TagReset: func(output Buffer, c *fiber.Ctx, data *Data, extraParam string) (int, error) {
return output.WriteString(c.App().Config().ColorScheme.Reset)
},
TagError: func(output Buffer, c *fiber.Ctx, data *Data, extraParam string) (int, error) {
if data.ChainErr != nil {
if cfg.enableColors {
colors := c.App().Config().ColorScheme
return output.WriteString(fmt.Sprintf("%s%s%s", colors.Red, data.ChainErr.Error(), colors.Reset))
}
return output.WriteString(data.ChainErr.Error())
}
return output.WriteString("-")
},
TagReqHeader: func(output Buffer, c *fiber.Ctx, data *Data, extraParam string) (int, error) {
return output.WriteString(c.Get(extraParam))
},
TagHeader: func(output Buffer, c *fiber.Ctx, data *Data, extraParam string) (int, error) {
return output.WriteString(c.Get(extraParam))
},
TagRespHeader: func(output Buffer, c *fiber.Ctx, data *Data, extraParam string) (int, error) {
return output.WriteString(c.GetRespHeader(extraParam))
},
TagQuery: func(output Buffer, c *fiber.Ctx, data *Data, extraParam string) (int, error) {
return output.WriteString(c.Query(extraParam))
},
TagForm: func(output Buffer, c *fiber.Ctx, data *Data, extraParam string) (int, error) {
return output.WriteString(c.FormValue(extraParam))
},
TagCookie: func(output Buffer, c *fiber.Ctx, data *Data, extraParam string) (int, error) {
return output.WriteString(c.Cookies(extraParam))
},
TagLocals: func(output Buffer, c *fiber.Ctx, data *Data, extraParam string) (int, error) {
switch v := c.Locals(extraParam).(type) {
case []byte:
return output.Write(v)
case string:
return output.WriteString(v)
case nil:
return 0, nil
default:
return output.WriteString(fmt.Sprintf("%v", v))
}
},
TagStatus: func(output Buffer, c *fiber.Ctx, data *Data, extraParam string) (int, error) {
if cfg.enableColors {
colors := c.App().Config().ColorScheme
return output.WriteString(fmt.Sprintf("%s%3d%s", statusColor(c.Response().StatusCode(), colors), c.Response().StatusCode(), colors.Reset))
}
return appendInt(output, c.Response().StatusCode())
},
TagMethod: func(output Buffer, c *fiber.Ctx, data *Data, extraParam string) (int, error) {
if cfg.enableColors {
colors := c.App().Config().ColorScheme
return output.WriteString(fmt.Sprintf("%s%s%s", methodColor(c.Method(), colors), c.Method(), colors.Reset))
}
return output.WriteString(c.Method())
},
TagPid: func(output Buffer, c *fiber.Ctx, data *Data, extraParam string) (int, error) {
return output.WriteString(data.Pid)
},
TagLatency: func(output Buffer, c *fiber.Ctx, data *Data, extraParam string) (int, error) {
latency := data.Stop.Sub(data.Start)
return output.WriteString(fmt.Sprintf("%13v", latency))
},
TagTime: func(output Buffer, c *fiber.Ctx, data *Data, extraParam string) (int, error) {
return output.WriteString(data.Timestamp.Load().(string)) //nolint:forcetypeassert // We always store a string in here
},
}
// merge with custom tags from user
for k, v := range cfg.CustomTags {
tagFunctions[k] = v
}
return tagFunctions
}
+70
View File
@@ -0,0 +1,70 @@
package logger
import (
"bytes"
"errors"
"github.com/gofiber/fiber/v2/utils"
)
// buildLogFuncChain analyzes the template and creates slices with the functions for execution and
// slices with the fixed parts of the template and the parameters
//
// fixParts contains the fixed parts of the template or parameters if a function is stored in the funcChain at this position
// funcChain contains for the parts which exist the functions for the dynamic parts
// funcChain and fixParts always have the same length and contain nil for the parts where no data is required in the chain,
// if a function exists for the part, a parameter for it can also exist in the fixParts slice
func buildLogFuncChain(cfg *Config, tagFunctions map[string]LogFunc) ([][]byte, []LogFunc, error) {
// process flow is copied from the fasttemplate flow https://github.com/valyala/fasttemplate/blob/2a2d1afadadf9715bfa19683cdaeac8347e5d9f9/template.go#L23-L62
templateB := utils.UnsafeBytes(cfg.Format)
startTagB := utils.UnsafeBytes(startTag)
endTagB := utils.UnsafeBytes(endTag)
paramSeparatorB := utils.UnsafeBytes(paramSeparator)
var fixParts [][]byte
var funcChain []LogFunc
for {
currentPos := bytes.Index(templateB, startTagB)
if currentPos < 0 {
// no starting tag found in the existing template part
break
}
// add fixed part
funcChain = append(funcChain, nil)
fixParts = append(fixParts, templateB[:currentPos])
templateB = templateB[currentPos+len(startTagB):]
currentPos = bytes.Index(templateB, endTagB)
if currentPos < 0 {
// cannot find end tag - just write it to the output.
funcChain = append(funcChain, nil)
fixParts = append(fixParts, startTagB)
break
}
// ## function block ##
// first check for tags with parameters
if index := bytes.Index(templateB[:currentPos], paramSeparatorB); index != -1 {
logFunc, ok := tagFunctions[utils.UnsafeString(templateB[:index+1])]
if !ok {
return nil, nil, errors.New("No parameter found in \"" + utils.UnsafeString(templateB[:currentPos]) + "\"")
}
funcChain = append(funcChain, logFunc)
// add param to the fixParts
fixParts = append(fixParts, templateB[index+1:currentPos])
} else if logFunc, ok := tagFunctions[utils.UnsafeString(templateB[:currentPos])]; ok {
// add functions without parameter
funcChain = append(funcChain, logFunc)
fixParts = append(fixParts, nil)
}
// ## function block end ##
// reduce the template string
templateB = templateB[currentPos+len(endTagB):]
}
// set the rest
funcChain = append(funcChain, nil)
fixParts = append(fixParts, templateB)
return fixParts, funcChain, nil
}
+39
View File
@@ -0,0 +1,39 @@
package logger
import (
"github.com/gofiber/fiber/v2"
)
func methodColor(method string, colors fiber.Colors) string {
switch method {
case fiber.MethodGet:
return colors.Cyan
case fiber.MethodPost:
return colors.Green
case fiber.MethodPut:
return colors.Yellow
case fiber.MethodDelete:
return colors.Red
case fiber.MethodPatch:
return colors.White
case fiber.MethodHead:
return colors.Magenta
case fiber.MethodOptions:
return colors.Blue
default:
return colors.Reset
}
}
func statusColor(code int, colors fiber.Colors) string {
switch {
case code >= fiber.StatusOK && code < fiber.StatusMultipleChoices:
return colors.Green
case code >= fiber.StatusMultipleChoices && code < fiber.StatusBadRequest:
return colors.Blue
case code >= fiber.StatusBadRequest && code < fiber.StatusInternalServerError:
return colors.Yellow
default:
return colors.Red
}
}
+47
View File
@@ -0,0 +1,47 @@
package recover //nolint:predeclared // TODO: Rename to some non-builtin
import (
"github.com/gofiber/fiber/v2"
)
// Config defines the config for middleware.
type Config struct {
// Next defines a function to skip this middleware when returned true.
//
// Optional. Default: nil
Next func(c *fiber.Ctx) bool
// EnableStackTrace enables handling stack trace
//
// Optional. Default: false
EnableStackTrace bool
// StackTraceHandler defines a function to handle stack trace
//
// Optional. Default: defaultStackTraceHandler
StackTraceHandler func(c *fiber.Ctx, e interface{})
}
// ConfigDefault is the default config
var ConfigDefault = Config{
Next: nil,
EnableStackTrace: false,
StackTraceHandler: defaultStackTraceHandler,
}
// Helper function to set default values
func configDefault(config ...Config) Config {
// Return default config if nothing provided
if len(config) < 1 {
return ConfigDefault
}
// Override default config
cfg := config[0]
if cfg.EnableStackTrace && cfg.StackTraceHandler == nil {
cfg.StackTraceHandler = defaultStackTraceHandler
}
return cfg
}
+45
View File
@@ -0,0 +1,45 @@
package recover //nolint:predeclared // TODO: Rename to some non-builtin
import (
"fmt"
"os"
"runtime/debug"
"github.com/gofiber/fiber/v2"
)
func defaultStackTraceHandler(_ *fiber.Ctx, e interface{}) {
_, _ = os.Stderr.WriteString(fmt.Sprintf("panic: %v\n%s\n", e, debug.Stack())) //nolint:errcheck // This will never fail
}
// New creates a new middleware handler
func New(config ...Config) fiber.Handler {
// Set default config
cfg := configDefault(config...)
// Return new handler
return func(c *fiber.Ctx) (err error) { //nolint:nonamedreturns // Uses recover() to overwrite the error
// Don't execute middleware if Next returns true
if cfg.Next != nil && cfg.Next(c) {
return c.Next()
}
// Catch panics
defer func() {
if r := recover(); r != nil {
if cfg.EnableStackTrace {
cfg.StackTraceHandler(c, r)
}
var ok bool
if err, ok = r.(error); !ok {
// Set error that will call the global error handler
err = fmt.Errorf("%v", r)
}
}
}()
// Return err if exist, else move to next handler
return c.Next()
}
}
+230
View File
@@ -0,0 +1,230 @@
// ⚡️ Fiber is an Express inspired web framework written in Go with ☕️
// 🤖 Github Repository: https://github.com/gofiber/fiber
// 📌 API Documentation: https://docs.gofiber.io
package fiber
import (
"sort"
"strings"
"sync"
"sync/atomic"
)
// Put fields related to mounting.
type mountFields struct {
// Mounted and main apps
appList map[string]*App
// Ordered keys of apps (sorted by key length for Render)
appListKeys []string
// check added routes of sub-apps
subAppsRoutesAdded sync.Once
// check mounted sub-apps
subAppsProcessed sync.Once
// Prefix of app if it was mounted
mountPath string
}
// Create empty mountFields instance
func newMountFields(app *App) *mountFields {
return &mountFields{
appList: map[string]*App{"": app},
appListKeys: make([]string, 0),
}
}
// Mount attaches another app instance as a sub-router along a routing path.
// It's very useful to split up a large API as many independent routers and
// compose them as a single service using Mount. The fiber's error handler and
// any of the fiber's sub apps are added to the application's error handlers
// to be invoked on errors that happen within the prefix route.
func (app *App) Mount(prefix string, subApp *App) Router {
prefix = strings.TrimRight(prefix, "/")
if prefix == "" {
prefix = "/"
}
// Support for configs of mounted-apps and sub-mounted-apps
for mountedPrefixes, subApp := range subApp.mountFields.appList {
path := getGroupPath(prefix, mountedPrefixes)
subApp.mountFields.mountPath = path
app.mountFields.appList[path] = subApp
}
// register mounted group
mountGroup := &Group{Prefix: prefix, app: subApp}
app.register(methodUse, prefix, mountGroup)
// Execute onMount hooks
if err := subApp.hooks.executeOnMountHooks(app); err != nil {
panic(err)
}
return app
}
// Mount attaches another app instance as a sub-router along a routing path.
// It's very useful to split up a large API as many independent routers and
// compose them as a single service using Mount.
func (grp *Group) Mount(prefix string, subApp *App) Router {
groupPath := getGroupPath(grp.Prefix, prefix)
groupPath = strings.TrimRight(groupPath, "/")
if groupPath == "" {
groupPath = "/"
}
// Support for configs of mounted-apps and sub-mounted-apps
for mountedPrefixes, subApp := range subApp.mountFields.appList {
path := getGroupPath(groupPath, mountedPrefixes)
subApp.mountFields.mountPath = path
grp.app.mountFields.appList[path] = subApp
}
// register mounted group
mountGroup := &Group{Prefix: groupPath, app: subApp}
grp.app.register(methodUse, groupPath, mountGroup)
// Execute onMount hooks
if err := subApp.hooks.executeOnMountHooks(grp.app); err != nil {
panic(err)
}
return grp
}
// The MountPath property contains one or more path patterns on which a sub-app was mounted.
func (app *App) MountPath() string {
return app.mountFields.mountPath
}
// hasMountedApps Checks if there are any mounted apps in the current application.
func (app *App) hasMountedApps() bool {
return len(app.mountFields.appList) > 1
}
// mountStartupProcess Handles the startup process of mounted apps by appending sub-app routes, generating app list keys, and processing sub-app routes.
func (app *App) mountStartupProcess() {
if app.hasMountedApps() {
// add routes of sub-apps
app.mountFields.subAppsProcessed.Do(func() {
app.appendSubAppLists(app.mountFields.appList)
app.generateAppListKeys()
})
// adds the routes of the sub-apps to the current application.
app.mountFields.subAppsRoutesAdded.Do(func() {
app.processSubAppsRoutes()
})
}
}
// generateAppListKeys generates app list keys for Render, should work after appendSubAppLists
func (app *App) generateAppListKeys() {
for key := range app.mountFields.appList {
app.mountFields.appListKeys = append(app.mountFields.appListKeys, key)
}
sort.Slice(app.mountFields.appListKeys, func(i, j int) bool {
return len(app.mountFields.appListKeys[i]) < len(app.mountFields.appListKeys[j])
})
}
// appendSubAppLists supports nested for sub apps
func (app *App) appendSubAppLists(appList map[string]*App, parent ...string) {
// Optimize: Cache parent prefix
parentPrefix := ""
if len(parent) > 0 {
parentPrefix = parent[0]
}
for prefix, subApp := range appList {
// skip real app
if prefix == "" {
continue
}
if parentPrefix != "" {
prefix = getGroupPath(parentPrefix, prefix)
}
if _, ok := app.mountFields.appList[prefix]; !ok {
app.mountFields.appList[prefix] = subApp
}
// The first element of appList is always the app itself. If there are no other sub apps, we should skip appending nested apps.
if len(subApp.mountFields.appList) > 1 {
app.appendSubAppLists(subApp.mountFields.appList, prefix)
}
}
}
// processSubAppsRoutes adds routes of sub-apps recursively when the server is started
func (app *App) processSubAppsRoutes() {
for prefix, subApp := range app.mountFields.appList {
// skip real app
if prefix == "" {
continue
}
// process the inner routes
if subApp.hasMountedApps() {
subApp.mountFields.subAppsRoutesAdded.Do(func() {
subApp.processSubAppsRoutes()
})
}
}
var handlersCount uint32
var routePos uint32
// Iterate over the stack of the parent app
for m := range app.stack {
// Iterate over each route in the stack
stackLen := len(app.stack[m])
for i := 0; i < stackLen; i++ {
route := app.stack[m][i]
// Check if the route has a mounted app
if !route.mount {
routePos++
// If not, update the route's position and continue
route.pos = routePos
if !route.use || (route.use && m == 0) {
handlersCount += uint32(len(route.Handlers))
}
continue
}
// Create a slice to hold the sub-app's routes
subRoutes := make([]*Route, len(route.group.app.stack[m]))
// Iterate over the sub-app's routes
for j, subAppRoute := range route.group.app.stack[m] {
// Clone the sub-app's route
subAppRouteClone := app.copyRoute(subAppRoute)
// Add the parent route's path as a prefix to the sub-app's route
app.addPrefixToRoute(route.path, subAppRouteClone)
// Add the cloned sub-app's route to the slice of sub-app routes
subRoutes[j] = subAppRouteClone
}
// Insert the sub-app's routes into the parent app's stack
newStack := make([]*Route, len(app.stack[m])+len(subRoutes)-1)
copy(newStack[:i], app.stack[m][:i])
copy(newStack[i:i+len(subRoutes)], subRoutes)
copy(newStack[i+len(subRoutes):], app.stack[m][i+1:])
app.stack[m] = newStack
// Decrease the parent app's route count to account for the mounted app's original route
atomic.AddUint32(&app.routesCount, ^uint32(0))
i--
// Increase the parent app's route count to account for the sub-app's routes
atomic.AddUint32(&app.routesCount, uint32(len(subRoutes)))
// Mark the parent app's routes as refreshed
app.routesRefreshed = true
// update stackLen after appending subRoutes to app.stack[m]
stackLen = len(app.stack[m])
}
}
atomic.StoreUint32(&app.handlersCount, handlersCount)
}
+740
View File
@@ -0,0 +1,740 @@
// ⚡️ Fiber is an Express inspired web framework written in Go with ☕️
// 📄 Github Repository: https://github.com/gofiber/fiber
// 📌 API Documentation: https://docs.gofiber.io
// ⚠️ This path parser was inspired by ucarion/urlpath (MIT License).
// 💖 Maintained and modified for Fiber by @renewerner87
package fiber
import (
"regexp"
"strconv"
"strings"
"time"
"unicode"
"github.com/google/uuid"
"github.com/gofiber/fiber/v2/utils"
)
// routeParser holds the path segments and param names
type routeParser struct {
segs []*routeSegment // the parsed segments of the route
params []string // that parameter names the parsed route
wildCardCount int // number of wildcard parameters, used internally to give the wildcard parameter its number
plusCount int // number of plus parameters, used internally to give the plus parameter its number
}
// paramsSeg holds the segment metadata
type routeSegment struct {
// const information
Const string // constant part of the route
// parameter information
IsParam bool // Truth value that indicates whether it is a parameter or a constant part
ParamName string // name of the parameter for access to it, for wildcards and plus parameters access iterators starting with 1 are added
ComparePart string // search part to find the end of the parameter
PartCount int // how often is the search part contained in the non-param segments? -> necessary for greedy search
IsGreedy bool // indicates whether the parameter is greedy or not, is used with wildcard and plus
IsOptional bool // indicates whether the parameter is optional or not
// common information
IsLast bool // shows if the segment is the last one for the route
HasOptionalSlash bool // segment has the possibility of an optional slash
Constraints []*Constraint // Constraint type if segment is a parameter, if not it will be set to noConstraint by default
Length int // length of the parameter for segment, when its 0 then the length is undetermined
// future TODO: add support for optional groups "/abc(/def)?"
}
// different special routing signs
const (
wildcardParam byte = '*' // indicates an optional greedy parameter
plusParam byte = '+' // indicates a required greedy parameter
optionalParam byte = '?' // concludes a parameter by name and makes it optional
paramStarterChar byte = ':' // start character for a parameter with name
slashDelimiter byte = '/' // separator for the route, unlike the other delimiters this character at the end can be optional
escapeChar byte = '\\' // escape character
paramConstraintStart byte = '<' // start of type constraint for a parameter
paramConstraintEnd byte = '>' // end of type constraint for a parameter
paramConstraintSeparator byte = ';' // separator of type constraints for a parameter
paramConstraintDataStart byte = '(' // start of data of type constraint for a parameter
paramConstraintDataEnd byte = ')' // end of data of type constraint for a parameter
paramConstraintDataSeparator byte = ',' // separator of datas of type constraint for a parameter
)
// TypeConstraint parameter constraint types
type TypeConstraint int16
type Constraint struct {
ID TypeConstraint
RegexCompiler *regexp.Regexp
Data []string
}
const (
noConstraint TypeConstraint = iota + 1
intConstraint
boolConstraint
floatConstraint
alphaConstraint
datetimeConstraint
guidConstraint
minLenConstraint
maxLenConstraint
lenConstraint
betweenLenConstraint
minConstraint
maxConstraint
rangeConstraint
regexConstraint
)
// list of possible parameter and segment delimiter
var (
// slash has a special role, unlike the other parameters it must not be interpreted as a parameter
routeDelimiter = []byte{slashDelimiter, '-', '.'}
// list of greedy parameters
greedyParameters = []byte{wildcardParam, plusParam}
// list of chars for the parameter recognizing
parameterStartChars = []byte{wildcardParam, plusParam, paramStarterChar}
// list of chars of delimiters and the starting parameter name char
parameterDelimiterChars = append([]byte{paramStarterChar, escapeChar}, routeDelimiter...)
// list of chars to find the end of a parameter
parameterEndChars = append([]byte{optionalParam}, parameterDelimiterChars...)
// list of parameter constraint start
parameterConstraintStartChars = []byte{paramConstraintStart}
// list of parameter constraint end
parameterConstraintEndChars = []byte{paramConstraintEnd}
// list of parameter separator
parameterConstraintSeparatorChars = []byte{paramConstraintSeparator}
// list of parameter constraint data start
parameterConstraintDataStartChars = []byte{paramConstraintDataStart}
// list of parameter constraint data end
parameterConstraintDataEndChars = []byte{paramConstraintDataEnd}
// list of parameter constraint data separator
parameterConstraintDataSeparatorChars = []byte{paramConstraintDataSeparator}
)
// RoutePatternMatch checks if a given path matches a Fiber route pattern.
func RoutePatternMatch(path, pattern string, cfg ...Config) bool {
// See logic in (*Route).match and (*App).register
var ctxParams [maxParams]string
config := Config{}
if len(cfg) > 0 {
config = cfg[0]
}
if path == "" {
path = "/"
}
// Cannot have an empty pattern
if pattern == "" {
pattern = "/"
}
// Pattern always start with a '/'
if pattern[0] != '/' {
pattern = "/" + pattern
}
patternPretty := pattern
// Case-sensitive routing, all to lowercase
if !config.CaseSensitive {
patternPretty = utils.ToLower(patternPretty)
path = utils.ToLower(path)
}
// Strict routing, remove trailing slashes
if !config.StrictRouting && len(patternPretty) > 1 {
patternPretty = utils.TrimRight(patternPretty, '/')
}
parser := parseRoute(patternPretty)
if patternPretty == "/" && path == "/" {
return true
// '*' wildcard matches any path
} else if patternPretty == "/*" {
return true
}
// Does this route have parameters
if len(parser.params) > 0 {
if match := parser.getMatch(path, path, &ctxParams, false); match {
return true
}
}
// Check for a simple match
patternPretty = RemoveEscapeChar(patternPretty)
if len(patternPretty) == len(path) && patternPretty == path {
return true
}
// No match
return false
}
// parseRoute analyzes the route and divides it into segments for constant areas and parameters,
// this information is needed later when assigning the requests to the declared routes
func parseRoute(pattern string) routeParser {
parser := routeParser{}
part := ""
for len(pattern) > 0 {
nextParamPosition := findNextParamPosition(pattern)
// handle the parameter part
if nextParamPosition == 0 {
processedPart, seg := parser.analyseParameterPart(pattern)
parser.params, parser.segs, part = append(parser.params, seg.ParamName), append(parser.segs, seg), processedPart
} else {
processedPart, seg := parser.analyseConstantPart(pattern, nextParamPosition)
parser.segs, part = append(parser.segs, seg), processedPart
}
// reduce the pattern by the processed parts
if len(part) == len(pattern) {
break
}
pattern = pattern[len(part):]
}
// mark last segment
if len(parser.segs) > 0 {
parser.segs[len(parser.segs)-1].IsLast = true
}
parser.segs = addParameterMetaInfo(parser.segs)
return parser
}
// addParameterMetaInfo add important meta information to the parameter segments
// to simplify the search for the end of the parameter
func addParameterMetaInfo(segs []*routeSegment) []*routeSegment {
var comparePart string
segLen := len(segs)
// loop from end to begin
for i := segLen - 1; i >= 0; i-- {
// set the compare part for the parameter
if segs[i].IsParam {
// important for finding the end of the parameter
segs[i].ComparePart = RemoveEscapeChar(comparePart)
} else {
comparePart = segs[i].Const
if len(comparePart) > 1 {
comparePart = utils.TrimRight(comparePart, slashDelimiter)
}
}
}
// loop from begin to end
for i := 0; i < segLen; i++ {
// check how often the compare part is in the following const parts
if segs[i].IsParam {
// check if parameter segments are directly after each other and if one of them is greedy
// in case the next parameter or the current parameter is not a wildcard it's not greedy, we only want one character
if segLen > i+1 && !segs[i].IsGreedy && segs[i+1].IsParam && !segs[i+1].IsGreedy {
segs[i].Length = 1
}
if segs[i].ComparePart == "" {
continue
}
for j := i + 1; j <= len(segs)-1; j++ {
if !segs[j].IsParam {
// count is important for the greedy match
segs[i].PartCount += strings.Count(segs[j].Const, segs[i].ComparePart)
}
}
// check if the end of the segment is a optional slash and then if the segement is optional or the last one
} else if segs[i].Const[len(segs[i].Const)-1] == slashDelimiter && (segs[i].IsLast || (segLen > i+1 && segs[i+1].IsOptional)) {
segs[i].HasOptionalSlash = true
}
}
return segs
}
// findNextParamPosition search for the next possible parameter start position
func findNextParamPosition(pattern string) int {
nextParamPosition := findNextNonEscapedCharsetPosition(pattern, parameterStartChars)
if nextParamPosition != -1 && len(pattern) > nextParamPosition && pattern[nextParamPosition] != wildcardParam {
// search for parameter characters for the found parameter start,
// if there are more, move the parameter start to the last parameter char
for found := findNextNonEscapedCharsetPosition(pattern[nextParamPosition+1:], parameterStartChars); found == 0; {
nextParamPosition++
if len(pattern) > nextParamPosition {
break
}
}
}
return nextParamPosition
}
// analyseConstantPart find the end of the constant part and create the route segment
func (*routeParser) analyseConstantPart(pattern string, nextParamPosition int) (string, *routeSegment) {
// handle the constant part
processedPart := pattern
if nextParamPosition != -1 {
// remove the constant part until the parameter
processedPart = pattern[:nextParamPosition]
}
constPart := RemoveEscapeChar(processedPart)
return processedPart, &routeSegment{
Const: constPart,
Length: len(constPart),
}
}
// analyseParameterPart find the parameter end and create the route segment
func (routeParser *routeParser) analyseParameterPart(pattern string) (string, *routeSegment) {
isWildCard := pattern[0] == wildcardParam
isPlusParam := pattern[0] == plusParam
var parameterEndPosition int
if strings.ContainsRune(pattern, rune(paramConstraintStart)) && strings.ContainsRune(pattern, rune(paramConstraintEnd)) {
parameterEndPosition = findNextCharsetPositionConstraint(pattern[1:], parameterEndChars)
} else {
parameterEndPosition = findNextNonEscapedCharsetPosition(pattern[1:], parameterEndChars)
}
parameterConstraintStart := -1
parameterConstraintEnd := -1
// handle wildcard end
switch {
case isWildCard, isPlusParam:
parameterEndPosition = 0
case parameterEndPosition == -1:
parameterEndPosition = len(pattern) - 1
case !isInCharset(pattern[parameterEndPosition+1], parameterDelimiterChars):
parameterEndPosition++
}
// find constraint part if exists in the parameter part and remove it
if parameterEndPosition > 0 {
parameterConstraintStart = findNextNonEscapedCharsetPosition(pattern[0:parameterEndPosition], parameterConstraintStartChars)
parameterConstraintEnd = findLastCharsetPosition(pattern[0:parameterEndPosition+1], parameterConstraintEndChars)
}
// cut params part
processedPart := pattern[0 : parameterEndPosition+1]
paramName := RemoveEscapeChar(GetTrimmedParam(processedPart))
// Check has constraint
var constraints []*Constraint
if hasConstraint := parameterConstraintStart != -1 && parameterConstraintEnd != -1; hasConstraint {
constraintString := pattern[parameterConstraintStart+1 : parameterConstraintEnd]
userConstraints := splitNonEscaped(constraintString, string(parameterConstraintSeparatorChars))
constraints = make([]*Constraint, 0, len(userConstraints))
for _, c := range userConstraints {
start := findNextNonEscapedCharsetPosition(c, parameterConstraintDataStartChars)
end := findLastCharsetPosition(c, parameterConstraintDataEndChars)
// Assign constraint
if start != -1 && end != -1 {
constraint := &Constraint{
ID: getParamConstraintType(c[:start]),
}
// remove escapes from data
if constraint.ID != regexConstraint {
constraint.Data = splitNonEscaped(c[start+1:end], string(parameterConstraintDataSeparatorChars))
if len(constraint.Data) == 1 {
constraint.Data[0] = RemoveEscapeChar(constraint.Data[0])
} else if len(constraint.Data) == 2 { // This is fine, we simply expect two parts
constraint.Data[0] = RemoveEscapeChar(constraint.Data[0])
constraint.Data[1] = RemoveEscapeChar(constraint.Data[1])
}
}
// Precompile regex if has regex constraint
if constraint.ID == regexConstraint {
constraint.Data = []string{c[start+1 : end]}
constraint.RegexCompiler = regexp.MustCompile(constraint.Data[0])
}
constraints = append(constraints, constraint)
} else {
constraints = append(constraints, &Constraint{
ID: getParamConstraintType(c),
Data: []string{},
})
}
}
paramName = RemoveEscapeChar(GetTrimmedParam(pattern[0:parameterConstraintStart]))
}
// add access iterator to wildcard and plus
if isWildCard {
routeParser.wildCardCount++
paramName += strconv.Itoa(routeParser.wildCardCount)
} else if isPlusParam {
routeParser.plusCount++
paramName += strconv.Itoa(routeParser.plusCount)
}
segment := &routeSegment{
ParamName: paramName,
IsParam: true,
IsOptional: isWildCard || pattern[parameterEndPosition] == optionalParam,
IsGreedy: isWildCard || isPlusParam,
}
if len(constraints) > 0 {
segment.Constraints = constraints
}
return processedPart, segment
}
// isInCharset check is the given character in the charset list
func isInCharset(searchChar byte, charset []byte) bool {
for _, char := range charset {
if char == searchChar {
return true
}
}
return false
}
// findNextCharsetPosition search the next char position from the charset
func findNextCharsetPosition(search string, charset []byte) int {
nextPosition := -1
for _, char := range charset {
if pos := strings.IndexByte(search, char); pos != -1 && (pos < nextPosition || nextPosition == -1) {
nextPosition = pos
}
}
return nextPosition
}
// findNextCharsetPosition search the last char position from the charset
func findLastCharsetPosition(search string, charset []byte) int {
lastPosition := -1
for _, char := range charset {
if pos := strings.LastIndexByte(search, char); pos != -1 && (pos < lastPosition || lastPosition == -1) {
lastPosition = pos
}
}
return lastPosition
}
// findNextCharsetPositionConstraint search the next char position from the charset
// unlike findNextCharsetPosition, it takes care of constraint start-end chars to parse route pattern
func findNextCharsetPositionConstraint(search string, charset []byte) int {
constraintStart := findNextNonEscapedCharsetPosition(search, parameterConstraintStartChars)
constraintEnd := findNextNonEscapedCharsetPosition(search, parameterConstraintEndChars)
nextPosition := -1
for _, char := range charset {
pos := strings.IndexByte(search, char)
if pos != -1 && (pos < nextPosition || nextPosition == -1) {
if (pos > constraintStart && pos > constraintEnd) || (pos < constraintStart && pos < constraintEnd) {
nextPosition = pos
}
}
}
return nextPosition
}
// findNextNonEscapedCharsetPosition search the next char position from the charset and skip the escaped characters
func findNextNonEscapedCharsetPosition(search string, charset []byte) int {
pos := findNextCharsetPosition(search, charset)
for pos > 0 && search[pos-1] == escapeChar {
if len(search) == pos+1 {
// escaped character is at the end
return -1
}
nextPossiblePos := findNextCharsetPosition(search[pos+1:], charset)
if nextPossiblePos == -1 {
return -1
}
// the previous character is taken into consideration
pos = nextPossiblePos + pos + 1
}
return pos
}
// splitNonEscaped slices s into all substrings separated by sep and returns a slice of the substrings between those separators
// This function also takes a care of escape char when splitting.
func splitNonEscaped(s, sep string) []string {
var result []string
i := findNextNonEscapedCharsetPosition(s, []byte(sep))
for i > -1 {
result = append(result, s[:i])
s = s[i+len(sep):]
i = findNextNonEscapedCharsetPosition(s, []byte(sep))
}
return append(result, s)
}
// getMatch parses the passed url and tries to match it against the route segments and determine the parameter positions
func (routeParser *routeParser) getMatch(detectionPath, path string, params *[maxParams]string, partialCheck bool) bool { //nolint: revive // Accepting a bool param is fine here
var i, paramsIterator, partLen int
for _, segment := range routeParser.segs {
partLen = len(detectionPath)
// check const segment
if !segment.IsParam {
i = segment.Length
// is optional part or the const part must match with the given string
// check if the end of the segment is an optional slash
if segment.HasOptionalSlash && partLen == i-1 && detectionPath == segment.Const[:i-1] {
i--
} else if !(i <= partLen && detectionPath[:i] == segment.Const) {
return false
}
} else {
// determine parameter length
i = findParamLen(detectionPath, segment)
if !segment.IsOptional && i == 0 {
return false
}
// take over the params positions
params[paramsIterator] = path[:i]
if !(segment.IsOptional && i == 0) {
// check constraint
for _, c := range segment.Constraints {
if matched := c.CheckConstraint(params[paramsIterator]); !matched {
return false
}
}
}
paramsIterator++
}
// reduce founded part from the string
if partLen > 0 {
detectionPath, path = detectionPath[i:], path[i:]
}
}
if detectionPath != "" && !partialCheck {
return false
}
return true
}
// findParamLen for the expressjs wildcard behavior (right to left greedy)
// look at the other segments and take what is left for the wildcard from right to left
func findParamLen(s string, segment *routeSegment) int {
if segment.IsLast {
return findParamLenForLastSegment(s, segment)
}
if segment.Length != 0 && len(s) >= segment.Length {
return segment.Length
} else if segment.IsGreedy {
// Search the parameters until the next constant part
// special logic for greedy params
searchCount := strings.Count(s, segment.ComparePart)
if searchCount > 1 {
return findGreedyParamLen(s, searchCount, segment)
}
}
if len(segment.ComparePart) == 1 {
if constPosition := strings.IndexByte(s, segment.ComparePart[0]); constPosition != -1 {
return constPosition
}
} else if constPosition := strings.Index(s, segment.ComparePart); constPosition != -1 {
// if the compare part was found, but contains a slash although this part is not greedy, then it must not match
// example: /api/:param/fixedEnd -> path: /api/123/456/fixedEnd = no match , /api/123/fixedEnd = match
if !segment.IsGreedy && strings.IndexByte(s[:constPosition], slashDelimiter) != -1 {
return 0
}
return constPosition
}
return len(s)
}
// findParamLenForLastSegment get the length of the parameter if it is the last segment
func findParamLenForLastSegment(s string, seg *routeSegment) int {
if !seg.IsGreedy {
if i := strings.IndexByte(s, slashDelimiter); i != -1 {
return i
}
}
return len(s)
}
// findGreedyParamLen get the length of the parameter for greedy segments from right to left
func findGreedyParamLen(s string, searchCount int, segment *routeSegment) int {
// check all from right to left segments
for i := segment.PartCount; i > 0 && searchCount > 0; i-- {
searchCount--
if constPosition := strings.LastIndex(s, segment.ComparePart); constPosition != -1 {
s = s[:constPosition]
} else {
break
}
}
return len(s)
}
// GetTrimmedParam trims the ':' & '?' from a string
func GetTrimmedParam(param string) string {
start := 0
end := len(param)
if end == 0 || param[start] != paramStarterChar { // is not a param
return param
}
start++
if param[end-1] == optionalParam { // is ?
end--
}
return param[start:end]
}
// RemoveEscapeChar remove escape characters
func RemoveEscapeChar(word string) string {
if strings.IndexByte(word, escapeChar) != -1 {
return strings.ReplaceAll(word, string(escapeChar), "")
}
return word
}
func getParamConstraintType(constraintPart string) TypeConstraint {
switch constraintPart {
case ConstraintInt:
return intConstraint
case ConstraintBool:
return boolConstraint
case ConstraintFloat:
return floatConstraint
case ConstraintAlpha:
return alphaConstraint
case ConstraintGuid:
return guidConstraint
case ConstraintMinLen, ConstraintMinLenLower:
return minLenConstraint
case ConstraintMaxLen, ConstraintMaxLenLower:
return maxLenConstraint
case ConstraintLen:
return lenConstraint
case ConstraintBetweenLen, ConstraintBetweenLenLower:
return betweenLenConstraint
case ConstraintMin:
return minConstraint
case ConstraintMax:
return maxConstraint
case ConstraintRange:
return rangeConstraint
case ConstraintDatetime:
return datetimeConstraint
case ConstraintRegex:
return regexConstraint
default:
return noConstraint
}
}
//nolint:errcheck // TODO: Properly check _all_ errors in here, log them & immediately return
func (c *Constraint) CheckConstraint(param string) bool {
var err error
var num int
// check data exists
needOneData := []TypeConstraint{minLenConstraint, maxLenConstraint, lenConstraint, minConstraint, maxConstraint, datetimeConstraint, regexConstraint}
needTwoData := []TypeConstraint{betweenLenConstraint, rangeConstraint}
for _, data := range needOneData {
if c.ID == data && len(c.Data) == 0 {
return false
}
}
for _, data := range needTwoData {
if c.ID == data && len(c.Data) < 2 {
return false
}
}
// check constraints
switch c.ID {
case noConstraint:
// Nothing to check
case intConstraint:
_, err = strconv.Atoi(param)
case boolConstraint:
_, err = strconv.ParseBool(param)
case floatConstraint:
_, err = strconv.ParseFloat(param, 32)
case alphaConstraint:
for _, r := range param {
if !unicode.IsLetter(r) {
return false
}
}
case guidConstraint:
_, err = uuid.Parse(param)
case minLenConstraint:
data, _ := strconv.Atoi(c.Data[0])
if len(param) < data {
return false
}
case maxLenConstraint:
data, _ := strconv.Atoi(c.Data[0])
if len(param) > data {
return false
}
case lenConstraint:
data, _ := strconv.Atoi(c.Data[0])
if len(param) != data {
return false
}
case betweenLenConstraint:
data, _ := strconv.Atoi(c.Data[0])
data2, _ := strconv.Atoi(c.Data[1])
length := len(param)
if length < data || length > data2 {
return false
}
case minConstraint:
data, _ := strconv.Atoi(c.Data[0])
num, err = strconv.Atoi(param)
if num < data {
return false
}
case maxConstraint:
data, _ := strconv.Atoi(c.Data[0])
num, err = strconv.Atoi(param)
if num > data {
return false
}
case rangeConstraint:
data, _ := strconv.Atoi(c.Data[0])
data2, _ := strconv.Atoi(c.Data[1])
num, err = strconv.Atoi(param)
if num < data || num > data2 {
return false
}
case datetimeConstraint:
_, err = time.Parse(c.Data[0], param)
case regexConstraint:
if match := c.RegexCompiler.MatchString(param); !match {
return false
}
}
return err == nil
}
+179
View File
@@ -0,0 +1,179 @@
package fiber
import (
"crypto/tls"
"errors"
"fmt"
"os"
"os/exec"
"runtime"
"strconv"
"strings"
"sync/atomic"
"time"
"github.com/valyala/fasthttp/reuseport"
"github.com/gofiber/fiber/v2/log"
)
const (
envPreforkChildKey = "FIBER_PREFORK_CHILD"
envPreforkChildVal = "1"
)
var (
testPreforkMaster = false
testOnPrefork = false
)
// IsChild determines if the current process is a child of Prefork
func IsChild() bool {
return os.Getenv(envPreforkChildKey) == envPreforkChildVal
}
// prefork manages child processes to make use of the OS REUSEPORT or REUSEADDR feature
func (app *App) prefork(network, addr string, tlsConfig *tls.Config) error {
// 👶 child process 👶
if IsChild() {
// use 1 cpu core per child process
runtime.GOMAXPROCS(1)
// Linux will use SO_REUSEPORT and Windows falls back to SO_REUSEADDR
// Only tcp4 or tcp6 is supported when preforking, both are not supported
ln, err := reuseport.Listen(network, addr)
if err != nil {
if !app.config.DisableStartupMessage {
const sleepDuration = 100 * time.Millisecond
time.Sleep(sleepDuration) // avoid colliding with startup message
}
return fmt.Errorf("prefork: %w", err)
}
// wrap a tls config around the listener if provided
if tlsConfig != nil {
ln = tls.NewListener(ln, tlsConfig)
}
// kill current child proc when master exits
go watchMaster()
// prepare the server for the start
app.startupProcess()
// listen for incoming connections
return app.server.Serve(ln)
}
// 👮 master process 👮
type child struct {
pid int
err error
}
// create variables
max := runtime.GOMAXPROCS(0)
childs := make(map[int]*exec.Cmd)
channel := make(chan child, max)
// kill child procs when master exits
defer func() {
for _, proc := range childs {
if err := proc.Process.Kill(); err != nil {
if !errors.Is(err, os.ErrProcessDone) {
log.Errorf("prefork: failed to kill child: %v", err)
}
}
}
}()
// collect child pids
var pids []string
// launch child procs
for i := 0; i < max; i++ {
cmd := exec.Command(os.Args[0], os.Args[1:]...) //nolint:gosec // It's fine to launch the same process again
if testPreforkMaster {
// When test prefork master,
// just start the child process with a dummy cmd,
// which will exit soon
cmd = dummyCmd()
}
cmd.Stdout = os.Stdout
cmd.Stderr = os.Stderr
// add fiber prefork child flag into child proc env
cmd.Env = append(os.Environ(),
fmt.Sprintf("%s=%s", envPreforkChildKey, envPreforkChildVal),
)
if err := cmd.Start(); err != nil {
return fmt.Errorf("failed to start a child prefork process, error: %w", err)
}
// store child process
pid := cmd.Process.Pid
childs[pid] = cmd
pids = append(pids, strconv.Itoa(pid))
// execute fork hook
if app.hooks != nil {
if testOnPrefork {
app.hooks.executeOnForkHooks(dummyPid)
} else {
app.hooks.executeOnForkHooks(pid)
}
}
// notify master if child crashes
go func() {
channel <- child{pid, cmd.Wait()}
}()
}
// Run onListen hooks
// Hooks have to be run here as different as non-prefork mode due to they should run as child or master
app.runOnListenHooks(app.prepareListenData(addr, tlsConfig != nil))
// Print startup message
if !app.config.DisableStartupMessage {
app.startupMessage(addr, tlsConfig != nil, ","+strings.Join(pids, ","))
}
// return error if child crashes
return (<-channel).err
}
// watchMaster watches child procs
func watchMaster() {
if runtime.GOOS == "windows" {
// finds parent process,
// and waits for it to exit
p, err := os.FindProcess(os.Getppid())
if err == nil {
_, _ = p.Wait() //nolint:errcheck // It is fine to ignore the error here
}
os.Exit(1) //nolint:revive // Calling os.Exit is fine here in the prefork
}
// if it is equal to 1 (init process ID),
// it indicates that the master process has exited
const watchInterval = 500 * time.Millisecond
for range time.NewTicker(watchInterval).C {
if os.Getppid() == 1 {
os.Exit(1) //nolint:revive // Calling os.Exit is fine here in the prefork
}
}
}
var (
dummyPid = 1
dummyChildCmd atomic.Value
)
// dummyCmd is for internal prefork testing
func dummyCmd() *exec.Cmd {
command := "go"
if storeCommand := dummyChildCmd.Load(); storeCommand != nil && storeCommand != "" {
command = storeCommand.(string) //nolint:forcetypeassert,errcheck // We always store a string in here
}
if runtime.GOOS == "windows" {
return exec.Command("cmd", "/C", command, "version")
}
return exec.Command(command, "version")
}
+518
View File
@@ -0,0 +1,518 @@
// ⚡️ Fiber is an Express inspired web framework written in Go with ☕️
// 🤖 Github Repository: https://github.com/gofiber/fiber
// 📌 API Documentation: https://docs.gofiber.io
package fiber
import (
"fmt"
"html"
"sort"
"strconv"
"strings"
"sync/atomic"
"time"
"github.com/gofiber/fiber/v2/utils"
"github.com/valyala/fasthttp"
)
// Router defines all router handle interface, including app and group router.
type Router interface {
Use(args ...interface{}) Router
Get(path string, handlers ...Handler) Router
Head(path string, handlers ...Handler) Router
Post(path string, handlers ...Handler) Router
Put(path string, handlers ...Handler) Router
Delete(path string, handlers ...Handler) Router
Connect(path string, handlers ...Handler) Router
Options(path string, handlers ...Handler) Router
Trace(path string, handlers ...Handler) Router
Patch(path string, handlers ...Handler) Router
Add(method, path string, handlers ...Handler) Router
Static(prefix, root string, config ...Static) Router
All(path string, handlers ...Handler) Router
Group(prefix string, handlers ...Handler) Router
Route(prefix string, fn func(router Router), name ...string) Router
Mount(prefix string, fiber *App) Router
Name(name string) Router
}
// Route is a struct that holds all metadata for each registered handler.
type Route struct {
// ### important: always keep in sync with the copy method "app.copyRoute" ###
// Data for routing
pos uint32 // Position in stack -> important for the sort of the matched routes
use bool // USE matches path prefixes
mount bool // Indicated a mounted app on a specific route
star bool // Path equals '*'
root bool // Path equals '/'
path string // Prettified path
routeParser routeParser // Parameter parser
group *Group // Group instance. used for routes in groups
// Public fields
Method string `json:"method"` // HTTP method
Name string `json:"name"` // Route's name
//nolint:revive // Having both a Path (uppercase) and a path (lowercase) is fine
Path string `json:"path"` // Original registered route path
Params []string `json:"params"` // Case sensitive param keys
Handlers []Handler `json:"-"` // Ctx handlers
}
func (r *Route) match(detectionPath, path string, params *[maxParams]string) bool {
// root detectionPath check
if r.root && detectionPath == "/" {
return true
// '*' wildcard matches any detectionPath
} else if r.star {
if len(path) > 1 {
params[0] = path[1:]
} else {
params[0] = ""
}
return true
}
// Does this route have parameters
if len(r.Params) > 0 {
// Match params
if match := r.routeParser.getMatch(detectionPath, path, params, r.use); match {
// Get params from the path detectionPath
return match
}
}
// Is this route a Middleware?
if r.use {
// Single slash will match or detectionPath prefix
if r.root || strings.HasPrefix(detectionPath, r.path) {
return true
}
// Check for a simple detectionPath match
} else if len(r.path) == len(detectionPath) && r.path == detectionPath {
return true
}
// No match
return false
}
func (app *App) next(c *Ctx) (bool, error) {
// Get stack length
tree, ok := app.treeStack[c.methodINT][c.treePath]
if !ok {
tree = app.treeStack[c.methodINT][""]
}
lenTree := len(tree) - 1
// Loop over the route stack starting from previous index
for c.indexRoute < lenTree {
// Increment route index
c.indexRoute++
// Get *Route
route := tree[c.indexRoute]
var match bool
var err error
// skip for mounted apps
if route.mount {
continue
}
// Check if it matches the request path
match = route.match(c.detectionPath, c.path, &c.values)
if !match {
// No match, next route
continue
}
// Pass route reference and param values
c.route = route
// Non use handler matched
if !c.matched && !route.use {
c.matched = true
}
// Execute first handler of route
c.indexHandler = 0
if len(route.Handlers) > 0 {
err = route.Handlers[0](c)
}
return match, err // Stop scanning the stack
}
// If c.Next() does not match, return 404
err := NewError(StatusNotFound, "Cannot "+c.method+" "+html.EscapeString(c.pathOriginal))
if !c.matched && app.methodExist(c) {
// If no match, scan stack again if other methods match the request
// Moved from app.handler because middleware may break the route chain
err = ErrMethodNotAllowed
}
return false, err
}
func (app *App) handler(rctx *fasthttp.RequestCtx) { //revive:disable-line:confusing-naming // Having both a Handler() (uppercase) and a handler() (lowercase) is fine. TODO: Use nolint:revive directive instead. See https://github.com/golangci/golangci-lint/issues/3476
// Acquire Ctx with fasthttp request from pool
c := app.AcquireCtx(rctx)
defer app.ReleaseCtx(c)
// handle invalid http method directly
if c.methodINT == -1 {
_ = c.Status(StatusBadRequest).SendString("Invalid http method") //nolint:errcheck // It is fine to ignore the error here
return
}
// Find match in stack
match, err := app.next(c)
if err != nil {
if catch := c.app.ErrorHandler(c, err); catch != nil {
_ = c.SendStatus(StatusInternalServerError) //nolint:errcheck // It is fine to ignore the error here
}
// TODO: Do we need to return here?
}
// Generate ETag if enabled
if match && app.config.ETag {
setETag(c, false)
}
}
func (app *App) addPrefixToRoute(prefix string, route *Route) *Route {
prefixedPath := getGroupPath(prefix, route.Path)
prettyPath := prefixedPath
// Case-sensitive routing, all to lowercase
if !app.config.CaseSensitive {
prettyPath = utils.ToLower(prettyPath)
}
// Strict routing, remove trailing slashes
if !app.config.StrictRouting && len(prettyPath) > 1 {
prettyPath = utils.TrimRight(prettyPath, '/')
}
route.Path = prefixedPath
route.path = RemoveEscapeChar(prettyPath)
route.routeParser = parseRoute(prettyPath)
route.root = false
route.star = false
return route
}
func (*App) copyRoute(route *Route) *Route {
return &Route{
// Router booleans
use: route.use,
mount: route.mount,
star: route.star,
root: route.root,
// Path data
path: route.path,
routeParser: route.routeParser,
// misc
pos: route.pos,
// Public data
Path: route.Path,
Params: route.Params,
Name: route.Name,
Method: route.Method,
Handlers: route.Handlers,
}
}
func (app *App) register(method, pathRaw string, group *Group, handlers ...Handler) {
// Uppercase HTTP methods
method = utils.ToUpper(method)
// Check if the HTTP method is valid unless it's USE
if method != methodUse && app.methodInt(method) == -1 {
panic(fmt.Sprintf("add: invalid http method %s\n", method))
}
// is mounted app
isMount := group != nil && group.app != app
// A route requires atleast one ctx handler
if len(handlers) == 0 && !isMount {
panic(fmt.Sprintf("missing handler in route: %s\n", pathRaw))
}
// Cannot have an empty path
if pathRaw == "" {
pathRaw = "/"
}
// Path always start with a '/'
if pathRaw[0] != '/' {
pathRaw = "/" + pathRaw
}
// Create a stripped path in-case sensitive / trailing slashes
pathPretty := pathRaw
// Case-sensitive routing, all to lowercase
if !app.config.CaseSensitive {
pathPretty = utils.ToLower(pathPretty)
}
// Strict routing, remove trailing slashes
if !app.config.StrictRouting && len(pathPretty) > 1 {
pathPretty = utils.TrimRight(pathPretty, '/')
}
// Is layer a middleware?
isUse := method == methodUse
// Is path a direct wildcard?
isStar := pathPretty == "/*"
// Is path a root slash?
isRoot := pathPretty == "/"
// Parse path parameters
parsedRaw := parseRoute(pathRaw)
parsedPretty := parseRoute(pathPretty)
// Create route metadata without pointer
route := Route{
// Router booleans
use: isUse,
mount: isMount,
star: isStar,
root: isRoot,
// Path data
path: RemoveEscapeChar(pathPretty),
routeParser: parsedPretty,
Params: parsedRaw.params,
// Group data
group: group,
// Public data
Path: pathRaw,
Method: method,
Handlers: handlers,
}
// Increment global handler count
atomic.AddUint32(&app.handlersCount, uint32(len(handlers)))
// Middleware route matches all HTTP methods
if isUse {
// Add route to all HTTP methods stack
for _, m := range app.config.RequestMethods {
// Create a route copy to avoid duplicates during compression
r := route
app.addRoute(m, &r, isMount)
}
} else {
// Add route to stack
app.addRoute(method, &route, isMount)
}
}
func (app *App) registerStatic(prefix, root string, config ...Static) {
// For security, we want to restrict to the current work directory.
if root == "" {
root = "."
}
// Cannot have an empty prefix
if prefix == "" {
prefix = "/"
}
// Prefix always start with a '/' or '*'
if prefix[0] != '/' {
prefix = "/" + prefix
}
// in case-sensitive routing, all to lowercase
if !app.config.CaseSensitive {
prefix = utils.ToLower(prefix)
}
// Strip trailing slashes from the root path
if len(root) > 0 && root[len(root)-1] == '/' {
root = root[:len(root)-1]
}
// Is prefix a direct wildcard?
isStar := prefix == "/*"
// Is prefix a root slash?
isRoot := prefix == "/"
// Is prefix a partial wildcard?
if strings.Contains(prefix, "*") {
// /john* -> /john
isStar = true
prefix = strings.Split(prefix, "*")[0]
// Fix this later
}
prefixLen := len(prefix)
if prefixLen > 1 && prefix[prefixLen-1:] == "/" {
// /john/ -> /john
prefixLen--
prefix = prefix[:prefixLen]
}
const cacheDuration = 10 * time.Second
// Fileserver settings
fs := &fasthttp.FS{
Root: root,
AllowEmptyRoot: true,
GenerateIndexPages: false,
AcceptByteRange: false,
Compress: false,
CompressedFileSuffix: app.config.CompressedFileSuffix,
CacheDuration: cacheDuration,
IndexNames: []string{"index.html"},
PathRewrite: func(fctx *fasthttp.RequestCtx) []byte {
path := fctx.Path()
if len(path) >= prefixLen {
if isStar && app.getString(path[0:prefixLen]) == prefix {
path = append(path[0:0], '/')
} else {
path = path[prefixLen:]
if len(path) == 0 || path[len(path)-1] != '/' {
path = append(path, '/')
}
}
}
if len(path) > 0 && path[0] != '/' {
path = append([]byte("/"), path...)
}
return path
},
PathNotFound: func(fctx *fasthttp.RequestCtx) {
fctx.Response.SetStatusCode(StatusNotFound)
},
}
// Set config if provided
var cacheControlValue string
var modifyResponse Handler
if len(config) > 0 {
maxAge := config[0].MaxAge
if maxAge > 0 {
cacheControlValue = "public, max-age=" + strconv.Itoa(maxAge)
}
fs.CacheDuration = config[0].CacheDuration
fs.Compress = config[0].Compress
fs.AcceptByteRange = config[0].ByteRange
fs.GenerateIndexPages = config[0].Browse
if config[0].Index != "" {
fs.IndexNames = []string{config[0].Index}
}
modifyResponse = config[0].ModifyResponse
}
fileHandler := fs.NewRequestHandler()
handler := func(c *Ctx) error {
// Don't execute middleware if Next returns true
if len(config) != 0 && config[0].Next != nil && config[0].Next(c) {
return c.Next()
}
// Serve file
fileHandler(c.fasthttp)
// Sets the response Content-Disposition header to attachment if the Download option is true
if len(config) > 0 && config[0].Download {
c.Attachment()
}
// Return request if found and not forbidden
status := c.fasthttp.Response.StatusCode()
if status != StatusNotFound && status != StatusForbidden {
if len(cacheControlValue) > 0 {
c.fasthttp.Response.Header.Set(HeaderCacheControl, cacheControlValue)
}
if modifyResponse != nil {
return modifyResponse(c)
}
return nil
}
// Reset response to default
c.fasthttp.SetContentType("") // Issue #420
c.fasthttp.Response.SetStatusCode(StatusOK)
c.fasthttp.Response.SetBodyString("")
// Next middleware
return c.Next()
}
// Create route metadata without pointer
route := Route{
// Router booleans
use: true,
root: isRoot,
path: prefix,
// Public data
Method: MethodGet,
Path: prefix,
Handlers: []Handler{handler},
}
// Increment global handler count
atomic.AddUint32(&app.handlersCount, 1)
// Add route to stack
app.addRoute(MethodGet, &route)
// Add HEAD route
app.addRoute(MethodHead, &route)
}
func (app *App) addRoute(method string, route *Route, isMounted ...bool) {
// Check mounted routes
var mounted bool
if len(isMounted) > 0 {
mounted = isMounted[0]
}
// Get unique HTTP method identifier
m := app.methodInt(method)
// prevent identically route registration
l := len(app.stack[m])
if l > 0 && app.stack[m][l-1].Path == route.Path && route.use == app.stack[m][l-1].use && !route.mount && !app.stack[m][l-1].mount {
preRoute := app.stack[m][l-1]
preRoute.Handlers = append(preRoute.Handlers, route.Handlers...)
} else {
// Increment global route position
route.pos = atomic.AddUint32(&app.routesCount, 1)
route.Method = method
// Add route to the stack
app.stack[m] = append(app.stack[m], route)
app.routesRefreshed = true
}
// Execute onRoute hooks & change latestRoute if not adding mounted route
if !mounted {
app.mutex.Lock()
app.latestRoute = route
if err := app.hooks.executeOnRouteHooks(*route); err != nil {
panic(err)
}
app.mutex.Unlock()
}
}
// buildTree build the prefix tree from the previously registered routes
func (app *App) buildTree() *App {
if !app.routesRefreshed {
return app
}
// loop all the methods and stacks and create the prefix tree
for m := range app.config.RequestMethods {
tsMap := make(map[string][]*Route)
for _, route := range app.stack[m] {
treePath := ""
if len(route.routeParser.segs) > 0 && len(route.routeParser.segs[0].Const) >= 3 {
treePath = route.routeParser.segs[0].Const[:3]
}
// create tree stack
tsMap[treePath] = append(tsMap[treePath], route)
}
app.treeStack[m] = tsMap
}
// loop the methods and tree stacks and add global stack and sort everything
for m := range app.config.RequestMethods {
tsMap := app.treeStack[m]
for treePart := range tsMap {
if treePart != "" {
// merge global tree routes in current tree stack
tsMap[treePart] = uniqueRouteStack(append(tsMap[treePart], tsMap[""]...))
}
// sort tree slices with the positions
slc := tsMap[treePart]
sort.Slice(slc, func(i, j int) bool { return slc[i].pos < slc[j].pos })
}
}
app.routesRefreshed = false
return app
}
+90
View File
@@ -0,0 +1,90 @@
A collection of common functions but with better performance, less allocations and no dependencies created for [Fiber](https://github.com/gofiber/fiber).
```go
// go test -benchmem -run=^$ -bench=Benchmark_ -count=2
Benchmark_ToLowerBytes/fiber-16 42847654 25.7 ns/op 0 B/op 0 allocs/op
Benchmark_ToLowerBytes/fiber-16 46143196 25.7 ns/op 0 B/op 0 allocs/op
Benchmark_ToLowerBytes/default-16 17387322 67.4 ns/op 48 B/op 1 allocs/op
Benchmark_ToLowerBytes/default-16 17906491 67.4 ns/op 48 B/op 1 allocs/op
Benchmark_ToUpperBytes/fiber-16 46143729 25.7 ns/op 0 B/op 0 allocs/op
Benchmark_ToUpperBytes/fiber-16 47989250 25.6 ns/op 0 B/op 0 allocs/op
Benchmark_ToUpperBytes/default-16 15580854 76.7 ns/op 48 B/op 1 allocs/op
Benchmark_ToUpperBytes/default-16 15381202 76.9 ns/op 48 B/op 1 allocs/op
Benchmark_TrimRightBytes/fiber-16 70572459 16.3 ns/op 8 B/op 1 allocs/op
Benchmark_TrimRightBytes/fiber-16 74983597 16.3 ns/op 8 B/op 1 allocs/op
Benchmark_TrimRightBytes/default-16 16212578 74.1 ns/op 40 B/op 2 allocs/op
Benchmark_TrimRightBytes/default-16 16434686 74.1 ns/op 40 B/op 2 allocs/op
Benchmark_TrimLeftBytes/fiber-16 74983128 16.3 ns/op 8 B/op 1 allocs/op
Benchmark_TrimLeftBytes/fiber-16 74985002 16.3 ns/op 8 B/op 1 allocs/op
Benchmark_TrimLeftBytes/default-16 21047868 56.5 ns/op 40 B/op 2 allocs/op
Benchmark_TrimLeftBytes/default-16 21048015 56.5 ns/op 40 B/op 2 allocs/op
Benchmark_TrimBytes/fiber-16 54533307 21.9 ns/op 16 B/op 1 allocs/op
Benchmark_TrimBytes/fiber-16 54532812 21.9 ns/op 16 B/op 1 allocs/op
Benchmark_TrimBytes/default-16 14282517 84.6 ns/op 48 B/op 2 allocs/op
Benchmark_TrimBytes/default-16 14114508 84.7 ns/op 48 B/op 2 allocs/op
Benchmark_EqualFolds/fiber-16 36355153 32.6 ns/op 0 B/op 0 allocs/op
Benchmark_EqualFolds/fiber-16 36355593 32.6 ns/op 0 B/op 0 allocs/op
Benchmark_EqualFolds/default-16 15186220 78.1 ns/op 0 B/op 0 allocs/op
Benchmark_EqualFolds/default-16 15186412 78.3 ns/op 0 B/op 0 allocs/op
Benchmark_UUID/fiber-16 23994625 49.8 ns/op 48 B/op 1 allocs/op
Benchmark_UUID/fiber-16 23994768 50.1 ns/op 48 B/op 1 allocs/op
Benchmark_UUID/default-16 3233772 371 ns/op 208 B/op 6 allocs/op
Benchmark_UUID/default-16 3251295 370 ns/op 208 B/op 6 allocs/op
Benchmark_GetString/unsafe-16 1000000000 0.709 ns/op 0 B/op 0 allocs/op
Benchmark_GetString/unsafe-16 1000000000 0.713 ns/op 0 B/op 0 allocs/op
Benchmark_GetString/default-16 59986202 19.0 ns/op 16 B/op 1 allocs/op
Benchmark_GetString/default-16 63142939 19.0 ns/op 16 B/op 1 allocs/op
Benchmark_GetBytes/unsafe-16 508360195 2.36 ns/op 0 B/op 0 allocs/op
Benchmark_GetBytes/unsafe-16 508359979 2.35 ns/op 0 B/op 0 allocs/op
Benchmark_GetBytes/default-16 46143019 25.7 ns/op 16 B/op 1 allocs/op
Benchmark_GetBytes/default-16 44434734 25.6 ns/op 16 B/op 1 allocs/op
Benchmark_GetMIME/fiber-16 21423750 56.3 ns/op 0 B/op 0 allocs/op
Benchmark_GetMIME/fiber-16 21423559 55.4 ns/op 0 B/op 0 allocs/op
Benchmark_GetMIME/default-16 6735282 173 ns/op 0 B/op 0 allocs/op
Benchmark_GetMIME/default-16 6895002 172 ns/op 0 B/op 0 allocs/op
Benchmark_StatusMessage/fiber-16 1000000000 0.766 ns/op 0 B/op 0 allocs/op
Benchmark_StatusMessage/fiber-16 1000000000 0.767 ns/op 0 B/op 0 allocs/op
Benchmark_StatusMessage/default-16 159538528 7.50 ns/op 0 B/op 0 allocs/op
Benchmark_StatusMessage/default-16 159750830 7.51 ns/op 0 B/op 0 allocs/op
Benchmark_ToUpper/fiber-16 22217408 53.3 ns/op 48 B/op 1 allocs/op
Benchmark_ToUpper/fiber-16 22636554 53.2 ns/op 48 B/op 1 allocs/op
Benchmark_ToUpper/default-16 11108600 108 ns/op 48 B/op 1 allocs/op
Benchmark_ToUpper/default-16 11108580 108 ns/op 48 B/op 1 allocs/op
Benchmark_ToLower/fiber-16 23994720 49.8 ns/op 48 B/op 1 allocs/op
Benchmark_ToLower/fiber-16 23994768 50.1 ns/op 48 B/op 1 allocs/op
Benchmark_ToLower/default-16 10808376 110 ns/op 48 B/op 1 allocs/op
Benchmark_ToLower/default-16 10617034 110 ns/op 48 B/op 1 allocs/op
Benchmark_TrimRight/fiber-16 413699521 2.94 ns/op 0 B/op 0 allocs/op
Benchmark_TrimRight/fiber-16 415131687 2.91 ns/op 0 B/op 0 allocs/op
Benchmark_TrimRight/default-16 23994577 49.1 ns/op 32 B/op 1 allocs/op
Benchmark_TrimRight/default-16 24484249 49.4 ns/op 32 B/op 1 allocs/op
Benchmark_TrimLeft/fiber-16 379661170 3.13 ns/op 0 B/op 0 allocs/op
Benchmark_TrimLeft/fiber-16 382079941 3.16 ns/op 0 B/op 0 allocs/op
Benchmark_TrimLeft/default-16 27900877 41.9 ns/op 32 B/op 1 allocs/op
Benchmark_TrimLeft/default-16 28564898 42.0 ns/op 32 B/op 1 allocs/op
Benchmark_Trim/fiber-16 236632856 4.96 ns/op 0 B/op 0 allocs/op
Benchmark_Trim/fiber-16 237570085 4.93 ns/op 0 B/op 0 allocs/op
Benchmark_Trim/default-16 18457221 66.0 ns/op 32 B/op 1 allocs/op
Benchmark_Trim/default-16 18177328 65.9 ns/op 32 B/op 1 allocs/op
Benchmark_Trim/default.trimspace-16 188933770 6.33 ns/op 0 B/op 0 allocs/op
Benchmark_Trim/default.trimspace-16 184007649 6.42 ns/op 0 B/op 0 allocs/op
Benchmark_ConvertToBytes/fiber-8 43773547 24.43 ns/op 0 B/op 0 allocs/op
Benchmark_ConvertToBytes/fiber-8 45849477 25.33 ns/op 0 B/op 0 allocs/op
```
+68
View File
@@ -0,0 +1,68 @@
// ⚡️ Fiber is an Express inspired web framework written in Go with ☕️
// 🤖 Github Repository: https://github.com/gofiber/fiber
// 📌 API Documentation: https://docs.gofiber.io
package utils
import (
"bytes"
"fmt"
"log"
"path/filepath"
"reflect"
"runtime"
"testing"
"text/tabwriter"
)
// AssertEqual checks if values are equal
func AssertEqual(tb testing.TB, expected, actual interface{}, description ...string) { //nolint:thelper // TODO: Verify if tb can be nil
if tb != nil {
tb.Helper()
}
if reflect.DeepEqual(expected, actual) {
return
}
aType := "<nil>"
bType := "<nil>"
if expected != nil {
aType = reflect.TypeOf(expected).String()
}
if actual != nil {
bType = reflect.TypeOf(actual).String()
}
testName := "AssertEqual"
if tb != nil {
testName = tb.Name()
}
_, file, line, _ := runtime.Caller(1)
var buf bytes.Buffer
const pad = 5
w := tabwriter.NewWriter(&buf, 0, 0, pad, ' ', 0)
_, _ = fmt.Fprintf(w, "\nTest:\t%s", testName)
_, _ = fmt.Fprintf(w, "\nTrace:\t%s:%d", filepath.Base(file), line)
if len(description) > 0 {
_, _ = fmt.Fprintf(w, "\nDescription:\t%s", description[0])
}
_, _ = fmt.Fprintf(w, "\nExpect:\t%v\t(%s)", expected, aType)
_, _ = fmt.Fprintf(w, "\nResult:\t%v\t(%s)", actual, bType)
var result string
if err := w.Flush(); err != nil {
result = err.Error()
} else {
result = buf.String()
}
if tb != nil {
tb.Fatal(result)
} else {
log.Fatal(result) //nolint:revive // tb might be nil, so we need a fallback
}
}
+69
View File
@@ -0,0 +1,69 @@
// ⚡️ Fiber is an Express inspired web framework written in Go with ☕️
// 🤖 Github Repository: https://github.com/gofiber/fiber
// 📌 API Documentation: https://docs.gofiber.io
package utils
// ToLowerBytes converts ascii slice to lower-case in-place.
func ToLowerBytes(b []byte) []byte {
for i := 0; i < len(b); i++ {
b[i] = toLowerTable[b[i]]
}
return b
}
// ToUpperBytes converts ascii slice to upper-case in-place.
func ToUpperBytes(b []byte) []byte {
for i := 0; i < len(b); i++ {
b[i] = toUpperTable[b[i]]
}
return b
}
// TrimRightBytes is the equivalent of bytes.TrimRight
func TrimRightBytes(b []byte, cutset byte) []byte {
lenStr := len(b)
for lenStr > 0 && b[lenStr-1] == cutset {
lenStr--
}
return b[:lenStr]
}
// TrimLeftBytes is the equivalent of bytes.TrimLeft
func TrimLeftBytes(b []byte, cutset byte) []byte {
lenStr, start := len(b), 0
for start < lenStr && b[start] == cutset {
start++
}
return b[start:]
}
// TrimBytes is the equivalent of bytes.Trim
func TrimBytes(b []byte, cutset byte) []byte {
i, j := 0, len(b)-1
for ; i <= j; i++ {
if b[i] != cutset {
break
}
}
for ; i < j; j-- {
if b[j] != cutset {
break
}
}
return b[i : j+1]
}
// EqualFoldBytes tests ascii slices for equality case-insensitively
func EqualFoldBytes(b, s []byte) bool {
if len(b) != len(s) {
return false
}
for i := len(b) - 1; i >= 0; i-- {
if toUpperTable[b[i]] != toUpperTable[s[i]] {
return false
}
}
return true
}
+160
View File
@@ -0,0 +1,160 @@
// ⚡️ Fiber is an Express inspired web framework written in Go with ☕️
// 🤖 Github Repository: https://github.com/gofiber/fiber
// 📌 API Documentation: https://docs.gofiber.io
package utils
import (
"bytes"
"crypto/rand"
"encoding/binary"
"encoding/hex"
"math"
"net"
"os"
"reflect"
"runtime"
"strconv"
"sync"
"sync/atomic"
"unicode"
googleuuid "github.com/google/uuid"
)
const (
toLowerTable = "\x00\x01\x02\x03\x04\x05\x06\a\b\t\n\v\f\r\x0e\x0f\x10\x11\x12\x13\x14\x15\x16\x17\x18\x19\x1a\x1b\x1c\x1d\x1e\x1f !\"#$%&'()*+,-./0123456789:;<=>?@abcdefghijklmnopqrstuvwxyz[\\]^_`abcdefghijklmnopqrstuvwxyz{|}~\u007f\x80\x81\x82\x83\x84\x85\x86\x87\x88\x89\x8a\x8b\x8c\x8d\x8e\x8f\x90\x91\x92\x93\x94\x95\x96\x97\x98\x99\x9a\x9b\x9c\x9d\x9e\x9f\xa0\xa1\xa2\xa3\xa4\xa5\xa6\xa7\xa8\xa9\xaa\xab\xac\xad\xae\xaf\xb0\xb1\xb2\xb3\xb4\xb5\xb6\xb7\xb8\xb9\xba\xbb\xbc\xbd\xbe\xbf\xc0\xc1\xc2\xc3\xc4\xc5\xc6\xc7\xc8\xc9\xca\xcb\xcc\xcd\xce\xcf\xd0\xd1\xd2\xd3\xd4\xd5\xd6\xd7\xd8\xd9\xda\xdb\xdc\xdd\xde\xdf\xe0\xe1\xe2\xe3\xe4\xe5\xe6\xe7\xe8\xe9\xea\xeb\xec\xed\xee\xef\xf0\xf1\xf2\xf3\xf4\xf5\xf6\xf7\xf8\xf9\xfa\xfb\xfc\xfd\xfe\xff"
toUpperTable = "\x00\x01\x02\x03\x04\x05\x06\a\b\t\n\v\f\r\x0e\x0f\x10\x11\x12\x13\x14\x15\x16\x17\x18\x19\x1a\x1b\x1c\x1d\x1e\x1f !\"#$%&'()*+,-./0123456789:;<=>?@ABCDEFGHIJKLMNOPQRSTUVWXYZ[\\]^_`ABCDEFGHIJKLMNOPQRSTUVWXYZ{|}~\u007f\x80\x81\x82\x83\x84\x85\x86\x87\x88\x89\x8a\x8b\x8c\x8d\x8e\x8f\x90\x91\x92\x93\x94\x95\x96\x97\x98\x99\x9a\x9b\x9c\x9d\x9e\x9f\xa0\xa1\xa2\xa3\xa4\xa5\xa6\xa7\xa8\xa9\xaa\xab\xac\xad\xae\xaf\xb0\xb1\xb2\xb3\xb4\xb5\xb6\xb7\xb8\xb9\xba\xbb\xbc\xbd\xbe\xbf\xc0\xc1\xc2\xc3\xc4\xc5\xc6\xc7\xc8\xc9\xca\xcb\xcc\xcd\xce\xcf\xd0\xd1\xd2\xd3\xd4\xd5\xd6\xd7\xd8\xd9\xda\xdb\xdc\xdd\xde\xdf\xe0\xe1\xe2\xe3\xe4\xe5\xe6\xe7\xe8\xe9\xea\xeb\xec\xed\xee\xef\xf0\xf1\xf2\xf3\xf4\xf5\xf6\xf7\xf8\xf9\xfa\xfb\xfc\xfd\xfe\xff"
)
// Copyright © 2014, Roger Peppe
// github.com/rogpeppe/fastuuid
// All rights reserved.
const (
emptyUUID = "00000000-0000-0000-0000-000000000000"
)
var (
uuidSeed [24]byte
uuidCounter uint64
uuidSetup sync.Once
unitsSlice = []byte("kmgtp")
)
// UUID generates an universally unique identifier (UUID)
func UUID() string {
// Setup seed & counter once
uuidSetup.Do(func() {
if _, err := rand.Read(uuidSeed[:]); err != nil {
return
}
uuidCounter = binary.LittleEndian.Uint64(uuidSeed[:8])
})
if atomic.LoadUint64(&uuidCounter) <= 0 {
return emptyUUID
}
// first 8 bytes differ, taking a slice of the first 16 bytes
x := atomic.AddUint64(&uuidCounter, 1)
uuid := uuidSeed
binary.LittleEndian.PutUint64(uuid[:8], x)
uuid[6], uuid[9] = uuid[9], uuid[6]
// RFC4122 v4
uuid[6] = (uuid[6] & 0x0f) | 0x40
uuid[8] = uuid[8]&0x3f | 0x80
// create UUID representation of the first 128 bits
b := make([]byte, 36)
hex.Encode(b[0:8], uuid[0:4])
b[8] = '-'
hex.Encode(b[9:13], uuid[4:6])
b[13] = '-'
hex.Encode(b[14:18], uuid[6:8])
b[18] = '-'
hex.Encode(b[19:23], uuid[8:10])
b[23] = '-'
hex.Encode(b[24:], uuid[10:16])
return UnsafeString(b)
}
// UUIDv4 returns a Random (Version 4) UUID.
// The strength of the UUIDs is based on the strength of the crypto/rand package.
func UUIDv4() string {
token, err := googleuuid.NewRandom()
if err != nil {
return UUID()
}
return token.String()
}
// FunctionName returns function name
func FunctionName(fn interface{}) string {
t := reflect.ValueOf(fn).Type()
if t.Kind() == reflect.Func {
return runtime.FuncForPC(reflect.ValueOf(fn).Pointer()).Name()
}
return t.String()
}
// GetArgument check if key is in arguments
func GetArgument(arg string) bool {
for i := range os.Args[1:] {
if os.Args[1:][i] == arg {
return true
}
}
return false
}
// IncrementIPRange Find available next IP address
func IncrementIPRange(ip net.IP) {
for j := len(ip) - 1; j >= 0; j-- {
ip[j]++
if ip[j] > 0 {
break
}
}
}
// ConvertToBytes returns integer size of bytes from human-readable string, ex. 42kb, 42M
// Returns 0 if string is unrecognized
func ConvertToBytes(humanReadableString string) int {
strLen := len(humanReadableString)
if strLen == 0 {
return 0
}
var unitPrefixPos, lastNumberPos int
// loop the string
for i := strLen - 1; i >= 0; i-- {
// check if the char is a number
if unicode.IsDigit(rune(humanReadableString[i])) {
lastNumberPos = i
break
} else if humanReadableString[i] != ' ' {
unitPrefixPos = i
}
}
if lastNumberPos < 0 {
return 0
}
// fetch the number part and parse it to float
size, err := strconv.ParseFloat(humanReadableString[:lastNumberPos+1], 64)
if err != nil {
return 0
}
// check the multiplier from the string and use it
if unitPrefixPos > 0 {
// convert multiplier char to lowercase and check if exists in units slice
index := bytes.IndexByte(unitsSlice, toLowerTable[humanReadableString[unitPrefixPos]])
if index != -1 {
const bytesPerKB = 1000
size *= math.Pow(bytesPerKB, float64(index+1))
}
}
return int(size)
}
+117
View File
@@ -0,0 +1,117 @@
// ⚡️ Fiber is an Express inspired web framework written in Go with ☕️
// 🤖 Github Repository: https://github.com/gofiber/fiber
// 📌 API Documentation: https://docs.gofiber.io
package utils
import (
"fmt"
"reflect"
"strconv"
"strings"
"time"
)
// CopyString copies a string to make it immutable
func CopyString(s string) string {
return string(UnsafeBytes(s))
}
// CopyBytes copies a slice to make it immutable
func CopyBytes(b []byte) []byte {
tmp := make([]byte, len(b))
copy(tmp, b)
return tmp
}
const (
uByte = 1 << (10 * iota) // 1 << 10 == 1024
uKilobyte
uMegabyte
uGigabyte
uTerabyte
uPetabyte
uExabyte
)
// ByteSize returns a human-readable byte string of the form 10M, 12.5K, and so forth.
// The unit that results in the smallest number greater than or equal to 1 is always chosen.
func ByteSize(bytes uint64) string {
unit := ""
value := float64(bytes)
switch {
case bytes >= uExabyte:
unit = "EB"
value /= uExabyte
case bytes >= uPetabyte:
unit = "PB"
value /= uPetabyte
case bytes >= uTerabyte:
unit = "TB"
value /= uTerabyte
case bytes >= uGigabyte:
unit = "GB"
value /= uGigabyte
case bytes >= uMegabyte:
unit = "MB"
value /= uMegabyte
case bytes >= uKilobyte:
unit = "KB"
value /= uKilobyte
case bytes >= uByte:
unit = "B"
default:
return "0B"
}
result := strconv.FormatFloat(value, 'f', 1, 64)
result = strings.TrimSuffix(result, ".0")
return result + unit
}
// ToString Change arg to string
func ToString(arg interface{}, timeFormat ...string) string {
tmp := reflect.Indirect(reflect.ValueOf(arg)).Interface()
switch v := tmp.(type) {
case int:
return strconv.Itoa(v)
case int8:
return strconv.FormatInt(int64(v), 10)
case int16:
return strconv.FormatInt(int64(v), 10)
case int32:
return strconv.FormatInt(int64(v), 10)
case int64:
return strconv.FormatInt(v, 10)
case uint:
return strconv.Itoa(int(v))
case uint8:
return strconv.FormatInt(int64(v), 10)
case uint16:
return strconv.FormatInt(int64(v), 10)
case uint32:
return strconv.FormatInt(int64(v), 10)
case uint64:
return strconv.FormatInt(int64(v), 10)
case string:
return v
case []byte:
return string(v)
case bool:
return strconv.FormatBool(v)
case float32:
return strconv.FormatFloat(float64(v), 'f', -1, 32)
case float64:
return strconv.FormatFloat(v, 'f', -1, 64)
case time.Time:
if len(timeFormat) > 0 {
return v.Format(timeFormat[0])
}
return v.Format("2006-01-02 15:04:05")
case reflect.Value:
return ToString(v.Interface(), timeFormat...)
case fmt.Stringer:
return v.String()
default:
return ""
}
}
+12
View File
@@ -0,0 +1,12 @@
//go:build go1.20
package utils
import (
"unsafe"
)
// UnsafeString returns a string pointer without allocation
func UnsafeString(b []byte) string {
return unsafe.String(unsafe.SliceData(b), len(b))
}
+14
View File
@@ -0,0 +1,14 @@
//go:build !go1.20
package utils
import (
"unsafe"
)
// UnsafeString returns a string pointer without allocation
//
//nolint:gosec // unsafe is used for better performance here
func UnsafeString(b []byte) string {
return *(*string)(unsafe.Pointer(&b))
}
+12
View File
@@ -0,0 +1,12 @@
//go:build go1.20
package utils
import (
"unsafe"
)
// UnsafeBytes returns a byte pointer without allocation.
func UnsafeBytes(s string) []byte {
return unsafe.Slice(unsafe.StringData(s), len(s))
}
+24
View File
@@ -0,0 +1,24 @@
//go:build !go1.20
package utils
import (
"reflect"
"unsafe"
)
const MaxStringLen = 0x7fff0000 // Maximum string length for UnsafeBytes. (decimal: 2147418112)
// UnsafeBytes returns a byte pointer without allocation.
// String length shouldn't be more than 2147418112.
//
//nolint:gosec // unsafe is used for better performance here
func UnsafeBytes(s string) []byte {
if s == "" {
return nil
}
return (*[MaxStringLen]byte)(unsafe.Pointer(
(*reflect.StringHeader)(unsafe.Pointer(&s)).Data),
)[:len(s):len(s)]
}
+16
View File
@@ -0,0 +1,16 @@
package utils
// Deprecated: Please use UnsafeString instead
func GetString(b []byte) string {
return UnsafeString(b)
}
// Deprecated: Please use UnsafeBytes instead
func GetBytes(s string) []byte {
return UnsafeBytes(s)
}
// Deprecated: Please use CopyString instead
func ImmutableString(s string) string {
return CopyString(s)
}
+267
View File
@@ -0,0 +1,267 @@
// ⚡️ Fiber is an Express inspired web framework written in Go with ☕️
// 🤖 Github Repository: https://github.com/gofiber/fiber
// 📌 API Documentation: https://docs.gofiber.io
package utils
import (
"mime"
"strings"
)
const MIMEOctetStream = "application/octet-stream"
// GetMIME returns the content-type of a file extension
func GetMIME(extension string) string {
if len(extension) == 0 {
return ""
}
var foundMime string
if extension[0] == '.' {
foundMime = mimeExtensions[extension[1:]]
} else {
foundMime = mimeExtensions[extension]
}
if len(foundMime) == 0 {
if extension[0] != '.' {
foundMime = mime.TypeByExtension("." + extension)
} else {
foundMime = mime.TypeByExtension(extension)
}
if foundMime == "" {
return MIMEOctetStream
}
}
return foundMime
}
// ParseVendorSpecificContentType check if content type is vendor specific and
// if it is parsable to any known types. If its not vendor specific then returns
// the original content type.
func ParseVendorSpecificContentType(cType string) string {
plusIndex := strings.Index(cType, "+")
if plusIndex == -1 {
return cType
}
var parsableType string
if semiColonIndex := strings.Index(cType, ";"); semiColonIndex == -1 {
parsableType = cType[plusIndex+1:]
} else if plusIndex < semiColonIndex {
parsableType = cType[plusIndex+1 : semiColonIndex]
} else {
return cType[:semiColonIndex]
}
slashIndex := strings.Index(cType, "/")
if slashIndex == -1 {
return cType
}
return cType[0:slashIndex+1] + parsableType
}
// limits for HTTP statuscodes
const (
statusMessageMin = 100
statusMessageMax = 511
)
// StatusMessage returns the correct message for the provided HTTP statuscode
func StatusMessage(status int) string {
if status < statusMessageMin || status > statusMessageMax {
return ""
}
return statusMessage[status]
}
// NOTE: Keep this in sync with the status code list
var statusMessage = []string{
100: "Continue", // StatusContinue
101: "Switching Protocols", // StatusSwitchingProtocols
102: "Processing", // StatusProcessing
103: "Early Hints", // StatusEarlyHints
200: "OK", // StatusOK
201: "Created", // StatusCreated
202: "Accepted", // StatusAccepted
203: "Non-Authoritative Information", // StatusNonAuthoritativeInformation
204: "No Content", // StatusNoContent
205: "Reset Content", // StatusResetContent
206: "Partial Content", // StatusPartialContent
207: "Multi-Status", // StatusMultiStatus
208: "Already Reported", // StatusAlreadyReported
226: "IM Used", // StatusIMUsed
300: "Multiple Choices", // StatusMultipleChoices
301: "Moved Permanently", // StatusMovedPermanently
302: "Found", // StatusFound
303: "See Other", // StatusSeeOther
304: "Not Modified", // StatusNotModified
305: "Use Proxy", // StatusUseProxy
306: "Switch Proxy", // StatusSwitchProxy
307: "Temporary Redirect", // StatusTemporaryRedirect
308: "Permanent Redirect", // StatusPermanentRedirect
400: "Bad Request", // StatusBadRequest
401: "Unauthorized", // StatusUnauthorized
402: "Payment Required", // StatusPaymentRequired
403: "Forbidden", // StatusForbidden
404: "Not Found", // StatusNotFound
405: "Method Not Allowed", // StatusMethodNotAllowed
406: "Not Acceptable", // StatusNotAcceptable
407: "Proxy Authentication Required", // StatusProxyAuthRequired
408: "Request Timeout", // StatusRequestTimeout
409: "Conflict", // StatusConflict
410: "Gone", // StatusGone
411: "Length Required", // StatusLengthRequired
412: "Precondition Failed", // StatusPreconditionFailed
413: "Request Entity Too Large", // StatusRequestEntityTooLarge
414: "Request URI Too Long", // StatusRequestURITooLong
415: "Unsupported Media Type", // StatusUnsupportedMediaType
416: "Requested Range Not Satisfiable", // StatusRequestedRangeNotSatisfiable
417: "Expectation Failed", // StatusExpectationFailed
418: "I'm a teapot", // StatusTeapot
421: "Misdirected Request", // StatusMisdirectedRequest
422: "Unprocessable Entity", // StatusUnprocessableEntity
423: "Locked", // StatusLocked
424: "Failed Dependency", // StatusFailedDependency
425: "Too Early", // StatusTooEarly
426: "Upgrade Required", // StatusUpgradeRequired
428: "Precondition Required", // StatusPreconditionRequired
429: "Too Many Requests", // StatusTooManyRequests
431: "Request Header Fields Too Large", // StatusRequestHeaderFieldsTooLarge
451: "Unavailable For Legal Reasons", // StatusUnavailableForLegalReasons
500: "Internal Server Error", // StatusInternalServerError
501: "Not Implemented", // StatusNotImplemented
502: "Bad Gateway", // StatusBadGateway
503: "Service Unavailable", // StatusServiceUnavailable
504: "Gateway Timeout", // StatusGatewayTimeout
505: "HTTP Version Not Supported", // StatusHTTPVersionNotSupported
506: "Variant Also Negotiates", // StatusVariantAlsoNegotiates
507: "Insufficient Storage", // StatusInsufficientStorage
508: "Loop Detected", // StatusLoopDetected
510: "Not Extended", // StatusNotExtended
511: "Network Authentication Required", // StatusNetworkAuthenticationRequired
}
// MIME types were copied from https://github.com/nginx/nginx/blob/67d2a9541826ecd5db97d604f23460210fd3e517/conf/mime.types with the following updates:
// - Use "application/xml" instead of "text/xml" as recommended per https://datatracker.ietf.org/doc/html/rfc7303#section-4.1
// - Use "text/javascript" instead of "application/javascript" as recommended per https://www.rfc-editor.org/rfc/rfc9239#name-text-javascript
var mimeExtensions = map[string]string{
"html": "text/html",
"htm": "text/html",
"shtml": "text/html",
"css": "text/css",
"xml": "application/xml",
"gif": "image/gif",
"jpeg": "image/jpeg",
"jpg": "image/jpeg",
"js": "text/javascript",
"atom": "application/atom+xml",
"rss": "application/rss+xml",
"mml": "text/mathml",
"txt": "text/plain",
"jad": "text/vnd.sun.j2me.app-descriptor",
"wml": "text/vnd.wap.wml",
"htc": "text/x-component",
"avif": "image/avif",
"png": "image/png",
"svg": "image/svg+xml",
"svgz": "image/svg+xml",
"tif": "image/tiff",
"tiff": "image/tiff",
"wbmp": "image/vnd.wap.wbmp",
"webp": "image/webp",
"ico": "image/x-icon",
"jng": "image/x-jng",
"bmp": "image/x-ms-bmp",
"woff": "font/woff",
"woff2": "font/woff2",
"jar": "application/java-archive",
"war": "application/java-archive",
"ear": "application/java-archive",
"json": "application/json",
"hqx": "application/mac-binhex40",
"doc": "application/msword",
"pdf": "application/pdf",
"ps": "application/postscript",
"eps": "application/postscript",
"ai": "application/postscript",
"rtf": "application/rtf",
"m3u8": "application/vnd.apple.mpegurl",
"kml": "application/vnd.google-earth.kml+xml",
"kmz": "application/vnd.google-earth.kmz",
"xls": "application/vnd.ms-excel",
"eot": "application/vnd.ms-fontobject",
"ppt": "application/vnd.ms-powerpoint",
"odg": "application/vnd.oasis.opendocument.graphics",
"odp": "application/vnd.oasis.opendocument.presentation",
"ods": "application/vnd.oasis.opendocument.spreadsheet",
"odt": "application/vnd.oasis.opendocument.text",
"pptx": "application/vnd.openxmlformats-officedocument.presentationml.presentation",
"xlsx": "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet",
"docx": "application/vnd.openxmlformats-officedocument.wordprocessingml.document",
"wmlc": "application/vnd.wap.wmlc",
"wasm": "application/wasm",
"7z": "application/x-7z-compressed",
"cco": "application/x-cocoa",
"jardiff": "application/x-java-archive-diff",
"jnlp": "application/x-java-jnlp-file",
"run": "application/x-makeself",
"pl": "application/x-perl",
"pm": "application/x-perl",
"prc": "application/x-pilot",
"pdb": "application/x-pilot",
"rar": "application/x-rar-compressed",
"rpm": "application/x-redhat-package-manager",
"sea": "application/x-sea",
"swf": "application/x-shockwave-flash",
"sit": "application/x-stuffit",
"tcl": "application/x-tcl",
"tk": "application/x-tcl",
"der": "application/x-x509-ca-cert",
"pem": "application/x-x509-ca-cert",
"crt": "application/x-x509-ca-cert",
"xpi": "application/x-xpinstall",
"xhtml": "application/xhtml+xml",
"xspf": "application/xspf+xml",
"zip": "application/zip",
"bin": "application/octet-stream",
"exe": "application/octet-stream",
"dll": "application/octet-stream",
"deb": "application/octet-stream",
"dmg": "application/octet-stream",
"iso": "application/octet-stream",
"img": "application/octet-stream",
"msi": "application/octet-stream",
"msp": "application/octet-stream",
"msm": "application/octet-stream",
"mid": "audio/midi",
"midi": "audio/midi",
"kar": "audio/midi",
"mp3": "audio/mpeg",
"ogg": "audio/ogg",
"m4a": "audio/x-m4a",
"ra": "audio/x-realaudio",
"3gpp": "video/3gpp",
"3gp": "video/3gpp",
"ts": "video/mp2t",
"mp4": "video/mp4",
"mpeg": "video/mpeg",
"mpg": "video/mpeg",
"mov": "video/quicktime",
"webm": "video/webm",
"flv": "video/x-flv",
"m4v": "video/x-m4v",
"mng": "video/x-mng",
"asx": "video/x-ms-asf",
"asf": "video/x-ms-asf",
"wmv": "video/x-ms-wmv",
"avi": "video/x-msvideo",
}
+143
View File
@@ -0,0 +1,143 @@
package utils
import (
"net"
)
// IsIPv4 works the same way as net.ParseIP,
// but without check for IPv6 case and without returning net.IP slice, whereby IsIPv4 makes no allocations.
func IsIPv4(s string) bool {
for i := 0; i < net.IPv4len; i++ {
if len(s) == 0 {
return false
}
if i > 0 {
if s[0] != '.' {
return false
}
s = s[1:]
}
n, ci := 0, 0
for ci = 0; ci < len(s) && '0' <= s[ci] && s[ci] <= '9'; ci++ {
n = n*10 + int(s[ci]-'0')
if n > 0xFF {
return false
}
}
if ci == 0 || (ci > 1 && s[0] == '0') {
return false
}
s = s[ci:]
}
return len(s) == 0
}
// IsIPv6 works the same way as net.ParseIP,
// but without check for IPv4 case and without returning net.IP slice, whereby IsIPv6 makes no allocations.
func IsIPv6(s string) bool {
ellipsis := -1 // position of ellipsis in ip
// Might have leading ellipsis
if len(s) >= 2 && s[0] == ':' && s[1] == ':' {
ellipsis = 0
s = s[2:]
// Might be only ellipsis
if len(s) == 0 {
return true
}
}
// Loop, parsing hex numbers followed by colon.
i := 0
for i < net.IPv6len {
// Hex number.
n, ci := 0, 0
for ci = 0; ci < len(s); ci++ {
if '0' <= s[ci] && s[ci] <= '9' {
n *= 16
n += int(s[ci] - '0')
} else if 'a' <= s[ci] && s[ci] <= 'f' {
n *= 16
n += int(s[ci]-'a') + 10
} else if 'A' <= s[ci] && s[ci] <= 'F' {
n *= 16
n += int(s[ci]-'A') + 10
} else {
break
}
if n > 0xFFFF {
return false
}
}
if ci == 0 || n > 0xFFFF {
return false
}
if ci < len(s) && s[ci] == '.' {
if ellipsis < 0 && i != net.IPv6len-net.IPv4len {
return false
}
if i+net.IPv4len > net.IPv6len {
return false
}
if !IsIPv4(s) {
return false
}
s = ""
i += net.IPv4len
break
}
// Save this 16-bit chunk.
i += 2
// Stop at end of string.
s = s[ci:]
if len(s) == 0 {
break
}
// Otherwise must be followed by colon and more.
if s[0] != ':' || len(s) == 1 {
return false
}
s = s[1:]
// Look for ellipsis.
if s[0] == ':' {
if ellipsis >= 0 { // already have one
return false
}
ellipsis = i
s = s[1:]
if len(s) == 0 { // can be at end
break
}
}
}
// Must have used entire string.
if len(s) != 0 {
return false
}
// If didn't parse enough, expand ellipsis.
if i < net.IPv6len {
if ellipsis < 0 {
return false
}
} else if ellipsis >= 0 {
// Ellipsis must represent at least one 0 group.
return false
}
return true
}
+9
View File
@@ -0,0 +1,9 @@
package utils
// JSONMarshal returns the JSON encoding of v.
type JSONMarshal func(v interface{}) ([]byte, error)
// JSONUnmarshal parses the JSON-encoded data and stores the result
// in the value pointed to by v. If v is nil or not a pointer,
// Unmarshal returns an InvalidUnmarshalError.
type JSONUnmarshal func(data []byte, v interface{}) error
+75
View File
@@ -0,0 +1,75 @@
// ⚡️ Fiber is an Express inspired web framework written in Go with ☕️
// 🤖 Github Repository: https://github.com/gofiber/fiber
// 📌 API Documentation: https://docs.gofiber.io
package utils
// ToLower converts ascii string to lower-case
func ToLower(b string) string {
res := make([]byte, len(b))
copy(res, b)
for i := 0; i < len(res); i++ {
res[i] = toLowerTable[res[i]]
}
return UnsafeString(res)
}
// ToUpper converts ascii string to upper-case
func ToUpper(b string) string {
res := make([]byte, len(b))
copy(res, b)
for i := 0; i < len(res); i++ {
res[i] = toUpperTable[res[i]]
}
return UnsafeString(res)
}
// TrimLeft is the equivalent of strings.TrimLeft
func TrimLeft(s string, cutset byte) string {
lenStr, start := len(s), 0
for start < lenStr && s[start] == cutset {
start++
}
return s[start:]
}
// Trim is the equivalent of strings.Trim
func Trim(s string, cutset byte) string {
i, j := 0, len(s)-1
for ; i <= j; i++ {
if s[i] != cutset {
break
}
}
for ; i < j; j-- {
if s[j] != cutset {
break
}
}
return s[i : j+1]
}
// TrimRight is the equivalent of strings.TrimRight
func TrimRight(s string, cutset byte) string {
lenStr := len(s)
for lenStr > 0 && s[lenStr-1] == cutset {
lenStr--
}
return s[:lenStr]
}
// EqualFold tests ascii strings for equality case-insensitively
func EqualFold(b, s string) bool {
if len(b) != len(s) {
return false
}
for i := len(b) - 1; i >= 0; i-- {
if toUpperTable[b[i]] != toUpperTable[s[i]] {
return false
}
}
return true
}
+32
View File
@@ -0,0 +1,32 @@
package utils
import (
"sync"
"sync/atomic"
"time"
)
var (
timestampTimer sync.Once
// Timestamp please start the timer function before you use this value
// please load the value with atomic `atomic.LoadUint32(&utils.Timestamp)`
Timestamp uint32
)
// StartTimeStampUpdater starts a concurrent function which stores the timestamp to an atomic value per second,
// which is much better for performance than determining it at runtime each time
func StartTimeStampUpdater() {
timestampTimer.Do(func() {
// set initial value
atomic.StoreUint32(&Timestamp, uint32(time.Now().Unix()))
go func(sleep time.Duration) {
ticker := time.NewTicker(sleep)
defer ticker.Stop()
for t := range ticker.C {
// update timestamp
atomic.StoreUint32(&Timestamp, uint32(t.Unix()))
}
}(1 * time.Second) // duration
})
}
+4
View File
@@ -0,0 +1,4 @@
package utils
// XMLMarshal returns the XML encoding of v.
type XMLMarshal func(v interface{}) ([]byte, error)