mirror of
https://gitlab.com/mbugroup/lti-api.git
synced 2026-05-21 13:55:43 +00:00
.
This commit is contained in:
+21
@@ -0,0 +1,21 @@
|
||||
MIT License
|
||||
|
||||
Copyright (c) 2021 Fiber
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
in the Software without restriction, including without limitation the rights
|
||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
copies of the Software, and to permit persons to whom the Software is
|
||||
furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in all
|
||||
copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
SOFTWARE.
|
||||
+300
@@ -0,0 +1,300 @@
|
||||
---
|
||||
id: jwt
|
||||
---
|
||||
|
||||
# JWT
|
||||
|
||||

|
||||
[](https://gofiber.io/discord)
|
||||

|
||||

|
||||

|
||||
|
||||
JWT returns a JSON Web Token (JWT) auth middleware.
|
||||
For valid token, it sets the user in Ctx.Locals and calls next handler.
|
||||
For invalid token, it returns "401 - Unauthorized" error.
|
||||
For missing token, it returns "400 - Bad Request" error.
|
||||
|
||||
Special thanks and credits to [Echo](https://echo.labstack.com/middleware/jwt)
|
||||
|
||||
**Note: Requires Go 1.19 and above**
|
||||
|
||||
## Install
|
||||
|
||||
This middleware supports Fiber v1 & v2, install accordingly.
|
||||
|
||||
```
|
||||
go get -u github.com/gofiber/fiber/v2
|
||||
go get -u github.com/gofiber/contrib/jwt
|
||||
go get -u github.com/golang-jwt/jwt/v5
|
||||
```
|
||||
|
||||
## Signature
|
||||
```go
|
||||
jwtware.New(config ...jwtware.Config) func(*fiber.Ctx) error
|
||||
```
|
||||
|
||||
## Config
|
||||
|
||||
| Property | Type | Description | Default |
|
||||
|:---------------|:--------------------------------|:--------------------------------------------------------------------------------------------------------------------------------------------------------|:-----------------------------|
|
||||
| Filter | `func(*fiber.Ctx) bool` | Defines a function to skip middleware | `nil` |
|
||||
| SuccessHandler | `func(*fiber.Ctx) error` | SuccessHandler defines a function which is executed for a valid token. | `nil` |
|
||||
| ErrorHandler | `func(*fiber.Ctx, error) error` | ErrorHandler defines a function which is executed for an invalid token. | `401 Invalid or expired JWT` |
|
||||
| SigningKey | `interface{}` | Signing key to validate token. Used as fallback if SigningKeys has length 0. | `nil` |
|
||||
| SigningKeys | `map[string]interface{}` | Map of signing keys to validate token with kid field usage. | `nil` |
|
||||
| ContextKey | `string` | Context key to store user information from the token into context. | `"user"` |
|
||||
| Claims | `jwt.Claim` | Claims are extendable claims data defining token content. | `jwt.MapClaims{}` |
|
||||
| TokenLookup | `string` | TokenLookup is a string in the form of `<source>:<name>` that is used | `"header:Authorization"` |
|
||||
| AuthScheme | `string` | AuthScheme to be used in the Authorization header. The default value (`"Bearer"`) will only be used in conjuction with the default `TokenLookup` value. | `"Bearer"` |
|
||||
| KeyFunc | `func() jwt.Keyfunc` | KeyFunc defines a user-defined function that supplies the public key for a token validation. | `jwtKeyFunc` |
|
||||
| JWKSetURLs | `[]string` | A slice of unique JSON Web Key (JWK) Set URLs to used to parse JWTs. | `nil` |
|
||||
|
||||
|
||||
## HS256 Example
|
||||
```go
|
||||
package main
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
|
||||
jwtware "github.com/gofiber/contrib/jwt"
|
||||
"github.com/golang-jwt/jwt/v5"
|
||||
)
|
||||
|
||||
func main() {
|
||||
app := fiber.New()
|
||||
|
||||
// Login route
|
||||
app.Post("/login", login)
|
||||
|
||||
// Unauthenticated route
|
||||
app.Get("/", accessible)
|
||||
|
||||
// JWT Middleware
|
||||
app.Use(jwtware.New(jwtware.Config{
|
||||
SigningKey: jwtware.SigningKey{Key: []byte("secret")},
|
||||
}))
|
||||
|
||||
// Restricted Routes
|
||||
app.Get("/restricted", restricted)
|
||||
|
||||
app.Listen(":3000")
|
||||
}
|
||||
|
||||
func login(c *fiber.Ctx) error {
|
||||
user := c.FormValue("user")
|
||||
pass := c.FormValue("pass")
|
||||
|
||||
// Throws Unauthorized error
|
||||
if user != "john" || pass != "doe" {
|
||||
return c.SendStatus(fiber.StatusUnauthorized)
|
||||
}
|
||||
|
||||
// Create the Claims
|
||||
claims := jwt.MapClaims{
|
||||
"name": "John Doe",
|
||||
"admin": true,
|
||||
"exp": time.Now().Add(time.Hour * 72).Unix(),
|
||||
}
|
||||
|
||||
// Create token
|
||||
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
|
||||
|
||||
// Generate encoded token and send it as response.
|
||||
t, err := token.SignedString([]byte("secret"))
|
||||
if err != nil {
|
||||
return c.SendStatus(fiber.StatusInternalServerError)
|
||||
}
|
||||
|
||||
return c.JSON(fiber.Map{"token": t})
|
||||
}
|
||||
|
||||
func accessible(c *fiber.Ctx) error {
|
||||
return c.SendString("Accessible")
|
||||
}
|
||||
|
||||
func restricted(c *fiber.Ctx) error {
|
||||
user := c.Locals("user").(*jwt.Token)
|
||||
claims := user.Claims.(jwt.MapClaims)
|
||||
name := claims["name"].(string)
|
||||
return c.SendString("Welcome " + name)
|
||||
}
|
||||
|
||||
```
|
||||
|
||||
## HS256 Test
|
||||
_Login using username and password to retrieve a token._
|
||||
```
|
||||
curl --data "user=john&pass=doe" http://localhost:3000/login
|
||||
```
|
||||
_Response_
|
||||
```json
|
||||
{
|
||||
"token": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJleHAiOjE0NjE5NTcxMzZ9.RB3arc4-OyzASAaUhC2W3ReWaXAt_z2Fd3BN4aWTgEY"
|
||||
}
|
||||
```
|
||||
|
||||
|
||||
_Request a restricted resource using the token in Authorization request header._
|
||||
```
|
||||
curl localhost:3000/restricted -H "Authorization: Bearer eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJleHAiOjE0NjE5NTcxMzZ9.RB3arc4-OyzASAaUhC2W3ReWaXAt_z2Fd3BN4aWTgEY"
|
||||
```
|
||||
_Response_
|
||||
```
|
||||
Welcome John Doe
|
||||
```
|
||||
|
||||
## RS256 Example
|
||||
```go
|
||||
package main
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"crypto/rsa"
|
||||
"log"
|
||||
"time"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
|
||||
"github.com/golang-jwt/jwt/v5"
|
||||
|
||||
jwtware "github.com/gofiber/contrib/jwt"
|
||||
)
|
||||
|
||||
var (
|
||||
// Obviously, this is just a test example. Do not do this in production.
|
||||
// In production, you would have the private key and public key pair generated
|
||||
// in advance. NEVER add a private key to any GitHub repo.
|
||||
privateKey *rsa.PrivateKey
|
||||
)
|
||||
|
||||
func main() {
|
||||
app := fiber.New()
|
||||
|
||||
// Just as a demo, generate a new private/public key pair on each run. See note above.
|
||||
rng := rand.Reader
|
||||
var err error
|
||||
privateKey, err = rsa.GenerateKey(rng, 2048)
|
||||
if err != nil {
|
||||
log.Fatalf("rsa.GenerateKey: %v", err)
|
||||
}
|
||||
|
||||
// Login route
|
||||
app.Post("/login", login)
|
||||
|
||||
// Unauthenticated route
|
||||
app.Get("/", accessible)
|
||||
|
||||
// JWT Middleware
|
||||
app.Use(jwtware.New(jwtware.Config{
|
||||
SigningKey: jwtware.SigningKey{
|
||||
JWTAlg: jwtware.RS256,
|
||||
Key: privateKey.Public(),
|
||||
},
|
||||
}))
|
||||
|
||||
// Restricted Routes
|
||||
app.Get("/restricted", restricted)
|
||||
|
||||
app.Listen(":3000")
|
||||
}
|
||||
|
||||
func login(c *fiber.Ctx) error {
|
||||
user := c.FormValue("user")
|
||||
pass := c.FormValue("pass")
|
||||
|
||||
// Throws Unauthorized error
|
||||
if user != "john" || pass != "doe" {
|
||||
return c.SendStatus(fiber.StatusUnauthorized)
|
||||
}
|
||||
|
||||
// Create the Claims
|
||||
claims := jwt.MapClaims{
|
||||
"name": "John Doe",
|
||||
"admin": true,
|
||||
"exp": time.Now().Add(time.Hour * 72).Unix(),
|
||||
}
|
||||
|
||||
// Create token
|
||||
token := jwt.NewWithClaims(jwt.SigningMethodRS256, claims)
|
||||
|
||||
// Generate encoded token and send it as response.
|
||||
t, err := token.SignedString(privateKey)
|
||||
if err != nil {
|
||||
log.Printf("token.SignedString: %v", err)
|
||||
return c.SendStatus(fiber.StatusInternalServerError)
|
||||
}
|
||||
|
||||
return c.JSON(fiber.Map{"token": t})
|
||||
}
|
||||
|
||||
func accessible(c *fiber.Ctx) error {
|
||||
return c.SendString("Accessible")
|
||||
}
|
||||
|
||||
func restricted(c *fiber.Ctx) error {
|
||||
user := c.Locals("user").(*jwt.Token)
|
||||
claims := user.Claims.(jwt.MapClaims)
|
||||
name := claims["name"].(string)
|
||||
return c.SendString("Welcome " + name)
|
||||
}
|
||||
```
|
||||
|
||||
## RS256 Test
|
||||
The RS256 is actually identical to the HS256 test above.
|
||||
|
||||
## JWK Set Test
|
||||
The tests are identical to basic `JWT` tests above, with exception that `JWKSetURLs` to valid public keys collection in JSON Web Key (JWK) Set format should be supplied. See [RFC 7517](https://www.rfc-editor.org/rfc/rfc7517).
|
||||
|
||||
## Custom KeyFunc example
|
||||
|
||||
KeyFunc defines a user-defined function that supplies the public key for a token validation.
|
||||
The function shall take care of verifying the signing algorithm and selecting the proper key.
|
||||
A user-defined KeyFunc can be useful if tokens are issued by an external party.
|
||||
|
||||
When a user-defined KeyFunc is provided, SigningKey, SigningKeys, and SigningMethod are ignored.
|
||||
This is one of the three options to provide a token validation key.
|
||||
The order of precedence is a user-defined KeyFunc, SigningKeys and SigningKey.
|
||||
Required if neither SigningKeys nor SigningKey is provided.
|
||||
Default to an internal implementation verifying the signing algorithm and selecting the proper key.
|
||||
|
||||
```go
|
||||
package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"github.com/gofiber/fiber/v2"
|
||||
|
||||
jwtware "github.com/gofiber/contrib/jwt"
|
||||
"github.com/golang-jwt/jwt/v5"
|
||||
)
|
||||
|
||||
func main() {
|
||||
app := fiber.New()
|
||||
|
||||
app.Use(jwtware.New(jwtware.Config{
|
||||
KeyFunc: customKeyFunc(),
|
||||
}))
|
||||
|
||||
app.Get("/ok", func(c *fiber.Ctx) error {
|
||||
return c.SendString("OK")
|
||||
})
|
||||
}
|
||||
|
||||
func customKeyFunc() jwt.Keyfunc {
|
||||
return func(t *jwt.Token) (interface{}, error) {
|
||||
// Always check the signing method
|
||||
if t.Method.Alg() != jwtware.HS256 {
|
||||
return nil, fmt.Errorf("Unexpected jwt signing method=%v", t.Header["alg"])
|
||||
}
|
||||
|
||||
// TODO custom implementation of loading signing key like from a database
|
||||
signingKey := "secret"
|
||||
|
||||
return []byte(signingKey), nil
|
||||
}
|
||||
}
|
||||
```
|
||||
+232
@@ -0,0 +1,232 @@
|
||||
package jwtware
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"log"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/MicahParks/keyfunc/v2"
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/golang-jwt/jwt/v5"
|
||||
)
|
||||
|
||||
var (
|
||||
// ErrJWTAlg is returned when the JWT header did not contain the expected algorithm.
|
||||
ErrJWTAlg = errors.New("the JWT header did not contain the expected algorithm")
|
||||
)
|
||||
|
||||
// Config defines the config for JWT middleware
|
||||
type Config struct {
|
||||
// Filter defines a function to skip middleware.
|
||||
// Optional. Default: nil
|
||||
Filter func(*fiber.Ctx) bool
|
||||
|
||||
// SuccessHandler defines a function which is executed for a valid token.
|
||||
// Optional. Default: nil
|
||||
SuccessHandler fiber.Handler
|
||||
|
||||
// ErrorHandler defines a function which is executed for an invalid token.
|
||||
// It may be used to define a custom JWT error.
|
||||
// Optional. Default: 401 Invalid or expired JWT
|
||||
ErrorHandler fiber.ErrorHandler
|
||||
|
||||
// Signing key to validate token. Used as fallback if SigningKeys has length 0.
|
||||
// At least one of the following is required: KeyFunc, JWKSetURLs, SigningKeys, or SigningKey.
|
||||
// The order of precedence is: KeyFunc, JWKSetURLs, SigningKeys, SigningKey.
|
||||
SigningKey SigningKey
|
||||
|
||||
// Map of signing keys to validate token with kid field usage.
|
||||
// At least one of the following is required: KeyFunc, JWKSetURLs, SigningKeys, or SigningKey.
|
||||
// The order of precedence is: KeyFunc, JWKSetURLs, SigningKeys, SigningKey.
|
||||
SigningKeys map[string]SigningKey
|
||||
|
||||
// Context key to store user information from the token into context.
|
||||
// Optional. Default: "user".
|
||||
ContextKey string
|
||||
|
||||
// Claims are extendable claims data defining token content.
|
||||
// Optional. Default value jwt.MapClaims
|
||||
Claims jwt.Claims
|
||||
|
||||
// TokenLookup is a string in the form of "<source>:<name>" that is used
|
||||
// to extract token from the request.
|
||||
// Optional. Default value "header:Authorization".
|
||||
// Possible values:
|
||||
// - "header:<name>"
|
||||
// - "query:<name>"
|
||||
// - "param:<name>"
|
||||
// - "cookie:<name>"
|
||||
TokenLookup string
|
||||
|
||||
// AuthScheme to be used in the Authorization header.
|
||||
// Optional. Default: "Bearer".
|
||||
AuthScheme string
|
||||
|
||||
// KeyFunc is a function that supplies the public key for JWT cryptographic verification.
|
||||
// The function shall take care of verifying the signing algorithm and selecting the proper key.
|
||||
// Internally, github.com/MicahParks/keyfunc/v2 package is used project defaults. If you need more customization,
|
||||
// you can provide a jwt.Keyfunc using that package or make your own implementation.
|
||||
//
|
||||
// At least one of the following is required: KeyFunc, JWKSetURLs, SigningKeys, or SigningKey.
|
||||
// The order of precedence is: KeyFunc, JWKSetURLs, SigningKeys, SigningKey.
|
||||
KeyFunc jwt.Keyfunc
|
||||
|
||||
// JWKSetURLs is a slice of HTTP URLs that contain the JSON Web Key Set (JWKS) used to verify the signatures of
|
||||
// JWTs. Use of HTTPS is recommended. The presence of the "kid" field in the JWT header and JWKs is mandatory for
|
||||
// this feature.
|
||||
//
|
||||
// By default, all JWK Sets in this slice will:
|
||||
// * Refresh every hour.
|
||||
// * Refresh automatically if a new "kid" is seen in a JWT being verified.
|
||||
// * Rate limit refreshes to once every 5 minutes.
|
||||
// * Timeout refreshes after 10 seconds.
|
||||
//
|
||||
// At least one of the following is required: KeyFunc, JWKSetURLs, SigningKeys, or SigningKey.
|
||||
// The order of precedence is: KeyFunc, JWKSetURLs, SigningKeys, SigningKey.
|
||||
JWKSetURLs []string
|
||||
}
|
||||
|
||||
// SigningKey holds information about the recognized cryptographic keys used to sign JWTs by this program.
|
||||
type SigningKey struct {
|
||||
// JWTAlg is the algorithm used to sign JWTs. If this value is a non-empty string, this will be checked against the
|
||||
// "alg" value in the JWT header.
|
||||
//
|
||||
// https://www.rfc-editor.org/rfc/rfc7518#section-3.1
|
||||
JWTAlg string
|
||||
// Key is the cryptographic key used to sign JWTs. For supported types, please see
|
||||
// https://github.com/golang-jwt/jwt.
|
||||
Key interface{}
|
||||
}
|
||||
|
||||
// makeCfg function will check correctness of supplied configuration
|
||||
// and will complement it with default values instead of missing ones
|
||||
func makeCfg(config []Config) (cfg Config) {
|
||||
if len(config) > 0 {
|
||||
cfg = config[0]
|
||||
}
|
||||
if cfg.SuccessHandler == nil {
|
||||
cfg.SuccessHandler = func(c *fiber.Ctx) error {
|
||||
return c.Next()
|
||||
}
|
||||
}
|
||||
if cfg.ErrorHandler == nil {
|
||||
cfg.ErrorHandler = func(c *fiber.Ctx, err error) error {
|
||||
if err.Error() == ErrJWTMissingOrMalformed.Error() {
|
||||
return c.Status(fiber.StatusBadRequest).SendString(ErrJWTMissingOrMalformed.Error())
|
||||
}
|
||||
return c.Status(fiber.StatusUnauthorized).SendString("Invalid or expired JWT")
|
||||
}
|
||||
}
|
||||
if cfg.SigningKey.Key == nil && len(cfg.SigningKeys) == 0 && len(cfg.JWKSetURLs) == 0 && cfg.KeyFunc == nil {
|
||||
panic("Fiber: JWT middleware configuration: At least one of the following is required: KeyFunc, JWKSetURLs, SigningKeys, or SigningKey.")
|
||||
}
|
||||
if cfg.ContextKey == "" {
|
||||
cfg.ContextKey = "user"
|
||||
}
|
||||
if cfg.Claims == nil {
|
||||
cfg.Claims = jwt.MapClaims{}
|
||||
}
|
||||
if cfg.TokenLookup == "" {
|
||||
cfg.TokenLookup = defaultTokenLookup
|
||||
// set AuthScheme as "Bearer" only if TokenLookup is set to default.
|
||||
if cfg.AuthScheme == "" {
|
||||
cfg.AuthScheme = "Bearer"
|
||||
}
|
||||
}
|
||||
|
||||
if cfg.KeyFunc == nil {
|
||||
if len(cfg.SigningKeys) > 0 || len(cfg.JWKSetURLs) > 0 {
|
||||
var givenKeys map[string]keyfunc.GivenKey
|
||||
if cfg.SigningKeys != nil {
|
||||
givenKeys = make(map[string]keyfunc.GivenKey, len(cfg.SigningKeys))
|
||||
for kid, key := range cfg.SigningKeys {
|
||||
givenKeys[kid] = keyfunc.NewGivenCustom(key.Key, keyfunc.GivenKeyOptions{
|
||||
Algorithm: key.JWTAlg,
|
||||
})
|
||||
}
|
||||
}
|
||||
if len(cfg.JWKSetURLs) > 0 {
|
||||
var err error
|
||||
cfg.KeyFunc, err = multiKeyfunc(givenKeys, cfg.JWKSetURLs)
|
||||
if err != nil {
|
||||
panic("Failed to create keyfunc from JWK Set URL: " + err.Error())
|
||||
}
|
||||
} else {
|
||||
cfg.KeyFunc = keyfunc.NewGiven(givenKeys).Keyfunc
|
||||
}
|
||||
} else {
|
||||
cfg.KeyFunc = signingKeyFunc(cfg.SigningKey)
|
||||
}
|
||||
}
|
||||
|
||||
return cfg
|
||||
}
|
||||
|
||||
func multiKeyfunc(givenKeys map[string]keyfunc.GivenKey, jwkSetURLs []string) (jwt.Keyfunc, error) {
|
||||
opts := keyfuncOptions(givenKeys)
|
||||
multiple := make(map[string]keyfunc.Options, len(jwkSetURLs))
|
||||
for _, url := range jwkSetURLs {
|
||||
multiple[url] = opts
|
||||
}
|
||||
multiOpts := keyfunc.MultipleOptions{
|
||||
KeySelector: keyfunc.KeySelectorFirst,
|
||||
}
|
||||
multi, err := keyfunc.GetMultiple(multiple, multiOpts)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get multiple JWK Set URLs: %w", err)
|
||||
}
|
||||
return multi.Keyfunc, nil
|
||||
}
|
||||
|
||||
func keyfuncOptions(givenKeys map[string]keyfunc.GivenKey) keyfunc.Options {
|
||||
return keyfunc.Options{
|
||||
GivenKeys: givenKeys,
|
||||
RefreshErrorHandler: func(err error) {
|
||||
log.Printf("Failed to perform background refresh of JWK Set: %s.", err)
|
||||
},
|
||||
RefreshInterval: time.Hour,
|
||||
RefreshRateLimit: time.Minute * 5,
|
||||
RefreshTimeout: time.Second * 10,
|
||||
RefreshUnknownKID: true,
|
||||
}
|
||||
}
|
||||
|
||||
// getExtractors function will create a slice of functions which will be used
|
||||
// for token sarch and will perform extraction of the value
|
||||
func (cfg *Config) getExtractors() []jwtExtractor {
|
||||
// Initialize
|
||||
extractors := make([]jwtExtractor, 0)
|
||||
rootParts := strings.Split(cfg.TokenLookup, ",")
|
||||
for _, rootPart := range rootParts {
|
||||
parts := strings.Split(strings.TrimSpace(rootPart), ":")
|
||||
|
||||
switch parts[0] {
|
||||
case "header":
|
||||
extractors = append(extractors, jwtFromHeader(parts[1], cfg.AuthScheme))
|
||||
case "query":
|
||||
extractors = append(extractors, jwtFromQuery(parts[1]))
|
||||
case "param":
|
||||
extractors = append(extractors, jwtFromParam(parts[1]))
|
||||
case "cookie":
|
||||
extractors = append(extractors, jwtFromCookie(parts[1]))
|
||||
}
|
||||
}
|
||||
return extractors
|
||||
}
|
||||
|
||||
func signingKeyFunc(key SigningKey) jwt.Keyfunc {
|
||||
return func(token *jwt.Token) (interface{}, error) {
|
||||
if key.JWTAlg != "" {
|
||||
alg, ok := token.Header["alg"].(string)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("unexpected jwt signing method: expected: %q: got: missing or unexpected JSON type", key.JWTAlg)
|
||||
}
|
||||
if alg != key.JWTAlg {
|
||||
return nil, fmt.Errorf("unexpected jwt signing method: expected: %q: got: %q", key.JWTAlg, alg)
|
||||
}
|
||||
}
|
||||
return key.Key, nil
|
||||
}
|
||||
}
|
||||
+48
@@ -0,0 +1,48 @@
|
||||
package jwtware
|
||||
|
||||
const (
|
||||
// HS256 represents a public cryptography key generated by a 256 bit HMAC algorithm.
|
||||
HS256 = "HS256"
|
||||
|
||||
// HS384 represents a public cryptography key generated by a 384 bit HMAC algorithm.
|
||||
HS384 = "HS384"
|
||||
|
||||
// HS512 represents a public cryptography key generated by a 512 bit HMAC algorithm.
|
||||
HS512 = "HS512"
|
||||
|
||||
// ES256 represents a public cryptography key generated by a 256 bit ECDSA algorithm.
|
||||
ES256 = "ES256"
|
||||
|
||||
// ES384 represents a public cryptography key generated by a 384 bit ECDSA algorithm.
|
||||
ES384 = "ES384"
|
||||
|
||||
// ES512 represents a public cryptography key generated by a 512 bit ECDSA algorithm.
|
||||
ES512 = "ES512"
|
||||
|
||||
// P256 represents a cryptographic elliptical curve type.
|
||||
P256 = "P-256"
|
||||
|
||||
// P384 represents a cryptographic elliptical curve type.
|
||||
P384 = "P-384"
|
||||
|
||||
// P521 represents a cryptographic elliptical curve type.
|
||||
P521 = "P-521"
|
||||
|
||||
// RS256 represents a public cryptography key generated by a 256 bit RSA algorithm.
|
||||
RS256 = "RS256"
|
||||
|
||||
// RS384 represents a public cryptography key generated by a 384 bit RSA algorithm.
|
||||
RS384 = "RS384"
|
||||
|
||||
// RS512 represents a public cryptography key generated by a 512 bit RSA algorithm.
|
||||
RS512 = "RS512"
|
||||
|
||||
// PS256 represents a public cryptography key generated by a 256 bit RSA algorithm.
|
||||
PS256 = "PS256"
|
||||
|
||||
// PS384 represents a public cryptography key generated by a 384 bit RSA algorithm.
|
||||
PS384 = "PS384"
|
||||
|
||||
// PS512 represents a public cryptography key generated by a 512 bit RSA algorithm.
|
||||
PS512 = "PS512"
|
||||
)
|
||||
+59
@@ -0,0 +1,59 @@
|
||||
// 🚀 Fiber is an Express inspired web framework written in Go with 💖
|
||||
// 📌 API Documentation: https://fiber.wiki
|
||||
// 📝 Github Repository: https://github.com/gofiber/fiber
|
||||
// Special thanks to Echo: https://github.com/labstack/echo/blob/master/middleware/jwt.go
|
||||
|
||||
package jwtware
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/golang-jwt/jwt/v5"
|
||||
)
|
||||
|
||||
var (
|
||||
defaultTokenLookup = "header:" + fiber.HeaderAuthorization
|
||||
)
|
||||
|
||||
// New ...
|
||||
func New(config ...Config) fiber.Handler {
|
||||
cfg := makeCfg(config)
|
||||
|
||||
extractors := cfg.getExtractors()
|
||||
|
||||
// Return middleware handler
|
||||
return func(c *fiber.Ctx) error {
|
||||
// Filter request to skip middleware
|
||||
if cfg.Filter != nil && cfg.Filter(c) {
|
||||
return c.Next()
|
||||
}
|
||||
var auth string
|
||||
var err error
|
||||
|
||||
for _, extractor := range extractors {
|
||||
auth, err = extractor(c)
|
||||
if auth != "" && err == nil {
|
||||
break
|
||||
}
|
||||
}
|
||||
if err != nil {
|
||||
return cfg.ErrorHandler(c, err)
|
||||
}
|
||||
var token *jwt.Token
|
||||
|
||||
if _, ok := cfg.Claims.(jwt.MapClaims); ok {
|
||||
token, err = jwt.Parse(auth, cfg.KeyFunc)
|
||||
} else {
|
||||
t := reflect.ValueOf(cfg.Claims).Type().Elem()
|
||||
claims := reflect.New(t).Interface().(jwt.Claims)
|
||||
token, err = jwt.ParseWithClaims(auth, claims, cfg.KeyFunc)
|
||||
}
|
||||
if err == nil && token.Valid {
|
||||
// Store user information from token into context.
|
||||
c.Locals(cfg.ContextKey, token)
|
||||
return cfg.SuccessHandler(c)
|
||||
}
|
||||
return cfg.ErrorHandler(c, err)
|
||||
}
|
||||
}
|
||||
+60
@@ -0,0 +1,60 @@
|
||||
package jwtware
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"strings"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
)
|
||||
|
||||
var (
|
||||
// ErrJWTMissingOrMalformed is returned when the JWT is missing or malformed.
|
||||
ErrJWTMissingOrMalformed = errors.New("missing or malformed JWT")
|
||||
)
|
||||
|
||||
type jwtExtractor func(c *fiber.Ctx) (string, error)
|
||||
|
||||
// jwtFromHeader returns a function that extracts token from the request header.
|
||||
func jwtFromHeader(header string, authScheme string) func(c *fiber.Ctx) (string, error) {
|
||||
return func(c *fiber.Ctx) (string, error) {
|
||||
auth := c.Get(header)
|
||||
l := len(authScheme)
|
||||
if len(auth) > l+1 && strings.EqualFold(auth[:l], authScheme) {
|
||||
return strings.TrimSpace(auth[l:]), nil
|
||||
}
|
||||
return "", ErrJWTMissingOrMalformed
|
||||
}
|
||||
}
|
||||
|
||||
// jwtFromQuery returns a function that extracts token from the query string.
|
||||
func jwtFromQuery(param string) func(c *fiber.Ctx) (string, error) {
|
||||
return func(c *fiber.Ctx) (string, error) {
|
||||
token := c.Query(param)
|
||||
if token == "" {
|
||||
return "", ErrJWTMissingOrMalformed
|
||||
}
|
||||
return token, nil
|
||||
}
|
||||
}
|
||||
|
||||
// jwtFromParam returns a function that extracts token from the url param string.
|
||||
func jwtFromParam(param string) func(c *fiber.Ctx) (string, error) {
|
||||
return func(c *fiber.Ctx) (string, error) {
|
||||
token := c.Params(param)
|
||||
if token == "" {
|
||||
return "", ErrJWTMissingOrMalformed
|
||||
}
|
||||
return token, nil
|
||||
}
|
||||
}
|
||||
|
||||
// jwtFromCookie returns a function that extracts token from the named cookie.
|
||||
func jwtFromCookie(name string) func(c *fiber.Ctx) (string, error) {
|
||||
return func(c *fiber.Ctx) (string, error) {
|
||||
token := c.Cookies(name)
|
||||
if token == "" {
|
||||
return "", ErrJWTMissingOrMalformed
|
||||
}
|
||||
return token, nil
|
||||
}
|
||||
}
|
||||
+8
@@ -0,0 +1,8 @@
|
||||
; This file is for unifying the coding style for different editors and IDEs.
|
||||
; More information at http://editorconfig.org
|
||||
; This style originates from https://github.com/fewagency/best-practices
|
||||
root = true
|
||||
|
||||
[*]
|
||||
charset = utf-8
|
||||
end_of_line = lf
|
||||
+12
@@ -0,0 +1,12 @@
|
||||
# Handle line endings automatically for files detected as text
|
||||
# and leave all files detected as binary untouched.
|
||||
* text=auto eol=lf
|
||||
|
||||
# Force batch scripts to always use CRLF line endings so that if a repo is accessed
|
||||
# in Windows via a file share from Linux, the scripts will work.
|
||||
*.{cmd,[cC][mM][dD]} text eol=crlf
|
||||
*.{bat,[bB][aA][tT]} text eol=crlf
|
||||
|
||||
# Force bash scripts to always use LF line endings so that if a repo is accessed
|
||||
# in Unix via a file share from Windows, the scripts will work.
|
||||
*.sh text eol=lf
|
||||
+30
@@ -0,0 +1,30 @@
|
||||
# Binaries for programs and plugins
|
||||
*.exe
|
||||
*.exe~
|
||||
*.dll
|
||||
*.so
|
||||
*.dylib
|
||||
|
||||
# Test binary, built with `go test -c`
|
||||
*.test
|
||||
*.tmp
|
||||
|
||||
# Output of the go coverage tool, specifically when used with LiteIDE
|
||||
*.out
|
||||
|
||||
# IDE files
|
||||
.vscode
|
||||
.DS_Store
|
||||
.idea
|
||||
|
||||
# Misc
|
||||
*.fiber.gz
|
||||
*.fasthttp.gz
|
||||
*.pprof
|
||||
*.workspace
|
||||
|
||||
# Dependencies
|
||||
/vendor/
|
||||
vendor/
|
||||
vendor
|
||||
/Godeps/
|
||||
+197
@@ -0,0 +1,197 @@
|
||||
# Created based on v1.51.0
|
||||
# NOTE: Keep this in sync with the version in .github/workflows/linter.yml
|
||||
|
||||
run:
|
||||
modules-download-mode: readonly
|
||||
skip-dirs-use-default: false
|
||||
skip-dirs:
|
||||
- internal
|
||||
|
||||
output:
|
||||
sort-results: true
|
||||
|
||||
linters-settings:
|
||||
errcheck:
|
||||
check-type-assertions: true
|
||||
check-blank: true
|
||||
disable-default-exclusions: true
|
||||
|
||||
errchkjson:
|
||||
report-no-exported: true
|
||||
|
||||
exhaustive:
|
||||
default-signifies-exhaustive: true
|
||||
|
||||
forbidigo:
|
||||
forbid:
|
||||
- ^(fmt\.Print(|f|ln)|print|println)$
|
||||
- 'http\.Default(Client|Transport)'
|
||||
# TODO: Eventually enable these patterns
|
||||
# - 'time\.Sleep'
|
||||
# - 'panic'
|
||||
|
||||
gocritic:
|
||||
disabled-checks:
|
||||
- ifElseChain
|
||||
|
||||
gofumpt:
|
||||
module-path: github.com/gofiber/fiber
|
||||
extra-rules: true
|
||||
|
||||
gosec:
|
||||
config:
|
||||
global:
|
||||
audit: true
|
||||
|
||||
govet:
|
||||
check-shadowing: true
|
||||
enable-all: true
|
||||
disable:
|
||||
- shadow
|
||||
- fieldalignment
|
||||
- loopclosure
|
||||
|
||||
grouper:
|
||||
import-require-single-import: true
|
||||
import-require-grouping: true
|
||||
|
||||
misspell:
|
||||
locale: US
|
||||
|
||||
nolintlint:
|
||||
require-explanation: true
|
||||
require-specific: true
|
||||
|
||||
nonamedreturns:
|
||||
report-error-in-defer: true
|
||||
|
||||
predeclared:
|
||||
q: true
|
||||
|
||||
promlinter:
|
||||
strict: true
|
||||
|
||||
revive:
|
||||
enable-all-rules: true
|
||||
rules:
|
||||
# Provided by gomnd linter
|
||||
- name: add-constant
|
||||
disabled: true
|
||||
- name: argument-limit
|
||||
disabled: true
|
||||
# Provided by bidichk
|
||||
- name: banned-characters
|
||||
disabled: true
|
||||
- name: cognitive-complexity
|
||||
disabled: true
|
||||
- name: cyclomatic
|
||||
disabled: true
|
||||
- name: early-return
|
||||
severity: warning
|
||||
disabled: true
|
||||
- name: exported
|
||||
disabled: true
|
||||
- name: file-header
|
||||
disabled: true
|
||||
- name: function-result-limit
|
||||
disabled: true
|
||||
- name: function-length
|
||||
disabled: true
|
||||
- name: line-length-limit
|
||||
disabled: true
|
||||
- name: max-public-structs
|
||||
disabled: true
|
||||
- name: modifies-parameter
|
||||
disabled: true
|
||||
- name: nested-structs
|
||||
disabled: true
|
||||
- name: package-comments
|
||||
disabled: true
|
||||
|
||||
stylecheck:
|
||||
checks:
|
||||
- all
|
||||
- -ST1000
|
||||
- -ST1020
|
||||
- -ST1021
|
||||
- -ST1022
|
||||
|
||||
tagliatelle:
|
||||
case:
|
||||
rules:
|
||||
json: snake
|
||||
|
||||
#tenv:
|
||||
# all: true
|
||||
|
||||
#unparam:
|
||||
# check-exported: true
|
||||
|
||||
wrapcheck:
|
||||
ignorePackageGlobs:
|
||||
- github.com/gofiber/fiber/*
|
||||
- github.com/valyala/fasthttp
|
||||
|
||||
issues:
|
||||
exclude-use-default: false
|
||||
|
||||
linters:
|
||||
enable:
|
||||
- asasalint
|
||||
- asciicheck
|
||||
- bidichk
|
||||
- bodyclose
|
||||
- containedctx
|
||||
- contextcheck
|
||||
- depguard
|
||||
- dogsled
|
||||
- durationcheck
|
||||
- errcheck
|
||||
- errchkjson
|
||||
- errname
|
||||
- errorlint
|
||||
- execinquery
|
||||
- exhaustive
|
||||
- exportloopref
|
||||
- forbidigo
|
||||
- forcetypeassert
|
||||
- goconst
|
||||
- gocritic
|
||||
- gofmt
|
||||
- gofumpt
|
||||
- goimports
|
||||
- gomoddirectives
|
||||
- goprintffuncname
|
||||
- gosec
|
||||
- gosimple
|
||||
- govet
|
||||
- grouper
|
||||
- loggercheck
|
||||
- misspell
|
||||
- nakedret
|
||||
- nilerr
|
||||
- nilnil
|
||||
- noctx
|
||||
- nolintlint
|
||||
- nonamedreturns
|
||||
- nosprintfhostport
|
||||
- predeclared
|
||||
- promlinter
|
||||
- reassign
|
||||
- revive
|
||||
- rowserrcheck
|
||||
- sqlclosecheck
|
||||
- staticcheck
|
||||
- stylecheck
|
||||
- tagliatelle
|
||||
# - testpackage # TODO: Enable once https://github.com/gofiber/fiber/issues/2252 is implemented
|
||||
- thelper
|
||||
# - tparallel # TODO: Enable once https://github.com/gofiber/fiber/issues/2254 is implemented
|
||||
- typecheck
|
||||
- unconvert
|
||||
- unparam
|
||||
- unused
|
||||
- usestdlibvars
|
||||
- wastedassign
|
||||
- whitespace
|
||||
- wrapcheck
|
||||
+21
@@ -0,0 +1,21 @@
|
||||
MIT License
|
||||
|
||||
Copyright (c) 2019-present Fenny and Contributors
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
in the Software without restriction, including without limitation the rights
|
||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
copies of the Software, and to permit persons to whom the Software is
|
||||
furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in all
|
||||
copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
SOFTWARE.
|
||||
+1120
File diff suppressed because it is too large
Load Diff
+1021
File diff suppressed because it is too large
Load Diff
+107
@@ -0,0 +1,107 @@
|
||||
// ⚡️ Fiber is an Express inspired web framework written in Go with ☕️
|
||||
// 🤖 Github Repository: https://github.com/gofiber/fiber
|
||||
// 📌 API Documentation: https://docs.gofiber.io
|
||||
|
||||
package fiber
|
||||
|
||||
// Colors is a struct to define custom colors for Fiber app and middlewares.
|
||||
type Colors struct {
|
||||
// Black color.
|
||||
//
|
||||
// Optional. Default: "\u001b[90m"
|
||||
Black string
|
||||
|
||||
// Red color.
|
||||
//
|
||||
// Optional. Default: "\u001b[91m"
|
||||
Red string
|
||||
|
||||
// Green color.
|
||||
//
|
||||
// Optional. Default: "\u001b[92m"
|
||||
Green string
|
||||
|
||||
// Yellow color.
|
||||
//
|
||||
// Optional. Default: "\u001b[93m"
|
||||
Yellow string
|
||||
|
||||
// Blue color.
|
||||
//
|
||||
// Optional. Default: "\u001b[94m"
|
||||
Blue string
|
||||
|
||||
// Magenta color.
|
||||
//
|
||||
// Optional. Default: "\u001b[95m"
|
||||
Magenta string
|
||||
|
||||
// Cyan color.
|
||||
//
|
||||
// Optional. Default: "\u001b[96m"
|
||||
Cyan string
|
||||
|
||||
// White color.
|
||||
//
|
||||
// Optional. Default: "\u001b[97m"
|
||||
White string
|
||||
|
||||
// Reset color.
|
||||
//
|
||||
// Optional. Default: "\u001b[0m"
|
||||
Reset string
|
||||
}
|
||||
|
||||
// DefaultColors Default color codes
|
||||
var DefaultColors = Colors{
|
||||
Black: "\u001b[90m",
|
||||
Red: "\u001b[91m",
|
||||
Green: "\u001b[92m",
|
||||
Yellow: "\u001b[93m",
|
||||
Blue: "\u001b[94m",
|
||||
Magenta: "\u001b[95m",
|
||||
Cyan: "\u001b[96m",
|
||||
White: "\u001b[97m",
|
||||
Reset: "\u001b[0m",
|
||||
}
|
||||
|
||||
// defaultColors is a function to override default colors to config
|
||||
func defaultColors(colors Colors) Colors {
|
||||
if colors.Black == "" {
|
||||
colors.Black = DefaultColors.Black
|
||||
}
|
||||
|
||||
if colors.Red == "" {
|
||||
colors.Red = DefaultColors.Red
|
||||
}
|
||||
|
||||
if colors.Green == "" {
|
||||
colors.Green = DefaultColors.Green
|
||||
}
|
||||
|
||||
if colors.Yellow == "" {
|
||||
colors.Yellow = DefaultColors.Yellow
|
||||
}
|
||||
|
||||
if colors.Blue == "" {
|
||||
colors.Blue = DefaultColors.Blue
|
||||
}
|
||||
|
||||
if colors.Magenta == "" {
|
||||
colors.Magenta = DefaultColors.Magenta
|
||||
}
|
||||
|
||||
if colors.Cyan == "" {
|
||||
colors.Cyan = DefaultColors.Cyan
|
||||
}
|
||||
|
||||
if colors.White == "" {
|
||||
colors.White = DefaultColors.White
|
||||
}
|
||||
|
||||
if colors.Reset == "" {
|
||||
colors.Reset = DefaultColors.Reset
|
||||
}
|
||||
|
||||
return colors
|
||||
}
|
||||
+1993
File diff suppressed because it is too large
Load Diff
+40
@@ -0,0 +1,40 @@
|
||||
package fiber
|
||||
|
||||
import (
|
||||
errors "encoding/json"
|
||||
|
||||
"github.com/gofiber/fiber/v2/internal/schema"
|
||||
)
|
||||
|
||||
type (
|
||||
// ConversionError Conversion error exposes the internal schema.ConversionError for public use.
|
||||
ConversionError = schema.ConversionError
|
||||
// UnknownKeyError error exposes the internal schema.UnknownKeyError for public use.
|
||||
UnknownKeyError = schema.UnknownKeyError
|
||||
// EmptyFieldError error exposes the internal schema.EmptyFieldError for public use.
|
||||
EmptyFieldError = schema.EmptyFieldError
|
||||
// MultiError error exposes the internal schema.MultiError for public use.
|
||||
MultiError = schema.MultiError
|
||||
)
|
||||
|
||||
type (
|
||||
// An InvalidUnmarshalError describes an invalid argument passed to Unmarshal.
|
||||
// (The argument to Unmarshal must be a non-nil pointer.)
|
||||
InvalidUnmarshalError = errors.InvalidUnmarshalError
|
||||
|
||||
// A MarshalerError represents an error from calling a MarshalJSON or MarshalText method.
|
||||
MarshalerError = errors.MarshalerError
|
||||
|
||||
// A SyntaxError is a description of a JSON syntax error.
|
||||
SyntaxError = errors.SyntaxError
|
||||
|
||||
// An UnmarshalTypeError describes a JSON value that was
|
||||
// not appropriate for a value of a specific Go type.
|
||||
UnmarshalTypeError = errors.UnmarshalTypeError
|
||||
|
||||
// An UnsupportedTypeError is returned by Marshal when attempting
|
||||
// to encode an unsupported value type.
|
||||
UnsupportedTypeError = errors.UnsupportedTypeError
|
||||
|
||||
UnsupportedValueError = errors.UnsupportedValueError
|
||||
)
|
||||
+209
@@ -0,0 +1,209 @@
|
||||
// ⚡️ Fiber is an Express inspired web framework written in Go with ☕️
|
||||
// 🤖 Github Repository: https://github.com/gofiber/fiber
|
||||
// 📌 API Documentation: https://docs.gofiber.io
|
||||
|
||||
package fiber
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"reflect"
|
||||
)
|
||||
|
||||
// Group struct
|
||||
type Group struct {
|
||||
app *App
|
||||
parentGroup *Group
|
||||
name string
|
||||
anyRouteDefined bool
|
||||
|
||||
Prefix string
|
||||
}
|
||||
|
||||
// Name Assign name to specific route or group itself.
|
||||
//
|
||||
// If this method is used before any route added to group, it'll set group name and OnGroupNameHook will be used.
|
||||
// Otherwise, it'll set route name and OnName hook will be used.
|
||||
func (grp *Group) Name(name string) Router {
|
||||
if grp.anyRouteDefined {
|
||||
grp.app.Name(name)
|
||||
|
||||
return grp
|
||||
}
|
||||
|
||||
grp.app.mutex.Lock()
|
||||
if grp.parentGroup != nil {
|
||||
grp.name = grp.parentGroup.name + name
|
||||
} else {
|
||||
grp.name = name
|
||||
}
|
||||
|
||||
if err := grp.app.hooks.executeOnGroupNameHooks(*grp); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
grp.app.mutex.Unlock()
|
||||
|
||||
return grp
|
||||
}
|
||||
|
||||
// Use registers a middleware route that will match requests
|
||||
// with the provided prefix (which is optional and defaults to "/").
|
||||
//
|
||||
// app.Use(func(c *fiber.Ctx) error {
|
||||
// return c.Next()
|
||||
// })
|
||||
// app.Use("/api", func(c *fiber.Ctx) error {
|
||||
// return c.Next()
|
||||
// })
|
||||
// app.Use("/api", handler, func(c *fiber.Ctx) error {
|
||||
// return c.Next()
|
||||
// })
|
||||
//
|
||||
// This method will match all HTTP verbs: GET, POST, PUT, HEAD etc...
|
||||
func (grp *Group) Use(args ...interface{}) Router {
|
||||
var prefix string
|
||||
var prefixes []string
|
||||
var handlers []Handler
|
||||
|
||||
for i := 0; i < len(args); i++ {
|
||||
switch arg := args[i].(type) {
|
||||
case string:
|
||||
prefix = arg
|
||||
case []string:
|
||||
prefixes = arg
|
||||
case Handler:
|
||||
handlers = append(handlers, arg)
|
||||
default:
|
||||
panic(fmt.Sprintf("use: invalid handler %v\n", reflect.TypeOf(arg)))
|
||||
}
|
||||
}
|
||||
|
||||
if len(prefixes) == 0 {
|
||||
prefixes = append(prefixes, prefix)
|
||||
}
|
||||
|
||||
for _, prefix := range prefixes {
|
||||
grp.app.register(methodUse, getGroupPath(grp.Prefix, prefix), grp, handlers...)
|
||||
}
|
||||
|
||||
if !grp.anyRouteDefined {
|
||||
grp.anyRouteDefined = true
|
||||
}
|
||||
|
||||
return grp
|
||||
}
|
||||
|
||||
// Get registers a route for GET methods that requests a representation
|
||||
// of the specified resource. Requests using GET should only retrieve data.
|
||||
func (grp *Group) Get(path string, handlers ...Handler) Router {
|
||||
grp.Add(MethodHead, path, handlers...)
|
||||
return grp.Add(MethodGet, path, handlers...)
|
||||
}
|
||||
|
||||
// Head registers a route for HEAD methods that asks for a response identical
|
||||
// to that of a GET request, but without the response body.
|
||||
func (grp *Group) Head(path string, handlers ...Handler) Router {
|
||||
return grp.Add(MethodHead, path, handlers...)
|
||||
}
|
||||
|
||||
// Post registers a route for POST methods that is used to submit an entity to the
|
||||
// specified resource, often causing a change in state or side effects on the server.
|
||||
func (grp *Group) Post(path string, handlers ...Handler) Router {
|
||||
return grp.Add(MethodPost, path, handlers...)
|
||||
}
|
||||
|
||||
// Put registers a route for PUT methods that replaces all current representations
|
||||
// of the target resource with the request payload.
|
||||
func (grp *Group) Put(path string, handlers ...Handler) Router {
|
||||
return grp.Add(MethodPut, path, handlers...)
|
||||
}
|
||||
|
||||
// Delete registers a route for DELETE methods that deletes the specified resource.
|
||||
func (grp *Group) Delete(path string, handlers ...Handler) Router {
|
||||
return grp.Add(MethodDelete, path, handlers...)
|
||||
}
|
||||
|
||||
// Connect registers a route for CONNECT methods that establishes a tunnel to the
|
||||
// server identified by the target resource.
|
||||
func (grp *Group) Connect(path string, handlers ...Handler) Router {
|
||||
return grp.Add(MethodConnect, path, handlers...)
|
||||
}
|
||||
|
||||
// Options registers a route for OPTIONS methods that is used to describe the
|
||||
// communication options for the target resource.
|
||||
func (grp *Group) Options(path string, handlers ...Handler) Router {
|
||||
return grp.Add(MethodOptions, path, handlers...)
|
||||
}
|
||||
|
||||
// Trace registers a route for TRACE methods that performs a message loop-back
|
||||
// test along the path to the target resource.
|
||||
func (grp *Group) Trace(path string, handlers ...Handler) Router {
|
||||
return grp.Add(MethodTrace, path, handlers...)
|
||||
}
|
||||
|
||||
// Patch registers a route for PATCH methods that is used to apply partial
|
||||
// modifications to a resource.
|
||||
func (grp *Group) Patch(path string, handlers ...Handler) Router {
|
||||
return grp.Add(MethodPatch, path, handlers...)
|
||||
}
|
||||
|
||||
// Add allows you to specify a HTTP method to register a route
|
||||
func (grp *Group) Add(method, path string, handlers ...Handler) Router {
|
||||
grp.app.register(method, getGroupPath(grp.Prefix, path), grp, handlers...)
|
||||
if !grp.anyRouteDefined {
|
||||
grp.anyRouteDefined = true
|
||||
}
|
||||
|
||||
return grp
|
||||
}
|
||||
|
||||
// Static will create a file server serving static files
|
||||
func (grp *Group) Static(prefix, root string, config ...Static) Router {
|
||||
grp.app.registerStatic(getGroupPath(grp.Prefix, prefix), root, config...)
|
||||
if !grp.anyRouteDefined {
|
||||
grp.anyRouteDefined = true
|
||||
}
|
||||
|
||||
return grp
|
||||
}
|
||||
|
||||
// All will register the handler on all HTTP methods
|
||||
func (grp *Group) All(path string, handlers ...Handler) Router {
|
||||
for _, method := range grp.app.config.RequestMethods {
|
||||
_ = grp.Add(method, path, handlers...)
|
||||
}
|
||||
return grp
|
||||
}
|
||||
|
||||
// Group is used for Routes with common prefix to define a new sub-router with optional middleware.
|
||||
//
|
||||
// api := app.Group("/api")
|
||||
// api.Get("/users", handler)
|
||||
func (grp *Group) Group(prefix string, handlers ...Handler) Router {
|
||||
prefix = getGroupPath(grp.Prefix, prefix)
|
||||
if len(handlers) > 0 {
|
||||
grp.app.register(methodUse, prefix, grp, handlers...)
|
||||
}
|
||||
|
||||
// Create new group
|
||||
newGrp := &Group{Prefix: prefix, app: grp.app, parentGroup: grp}
|
||||
if err := grp.app.hooks.executeOnGroupHooks(*newGrp); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
return newGrp
|
||||
}
|
||||
|
||||
// Route is used to define routes with a common prefix inside the common function.
|
||||
// Uses Group method to define new sub-router.
|
||||
func (grp *Group) Route(prefix string, fn func(router Router), name ...string) Router {
|
||||
// Create new group
|
||||
group := grp.Group(prefix)
|
||||
if len(name) > 0 {
|
||||
group.Name(name[0])
|
||||
}
|
||||
|
||||
// Define routes
|
||||
fn(group)
|
||||
|
||||
return group
|
||||
}
|
||||
+1153
File diff suppressed because it is too large
Load Diff
+218
@@ -0,0 +1,218 @@
|
||||
package fiber
|
||||
|
||||
import (
|
||||
"github.com/gofiber/fiber/v2/log"
|
||||
)
|
||||
|
||||
// OnRouteHandler Handlers define a function to create hooks for Fiber.
|
||||
type (
|
||||
OnRouteHandler = func(Route) error
|
||||
OnNameHandler = OnRouteHandler
|
||||
OnGroupHandler = func(Group) error
|
||||
OnGroupNameHandler = OnGroupHandler
|
||||
OnListenHandler = func(ListenData) error
|
||||
OnShutdownHandler = func() error
|
||||
OnForkHandler = func(int) error
|
||||
OnMountHandler = func(*App) error
|
||||
)
|
||||
|
||||
// Hooks is a struct to use it with App.
|
||||
type Hooks struct {
|
||||
// Embed app
|
||||
app *App
|
||||
|
||||
// Hooks
|
||||
onRoute []OnRouteHandler
|
||||
onName []OnNameHandler
|
||||
onGroup []OnGroupHandler
|
||||
onGroupName []OnGroupNameHandler
|
||||
onListen []OnListenHandler
|
||||
onShutdown []OnShutdownHandler
|
||||
onFork []OnForkHandler
|
||||
onMount []OnMountHandler
|
||||
}
|
||||
|
||||
// ListenData is a struct to use it with OnListenHandler
|
||||
type ListenData struct {
|
||||
Host string
|
||||
Port string
|
||||
TLS bool
|
||||
}
|
||||
|
||||
func newHooks(app *App) *Hooks {
|
||||
return &Hooks{
|
||||
app: app,
|
||||
onRoute: make([]OnRouteHandler, 0),
|
||||
onGroup: make([]OnGroupHandler, 0),
|
||||
onGroupName: make([]OnGroupNameHandler, 0),
|
||||
onName: make([]OnNameHandler, 0),
|
||||
onListen: make([]OnListenHandler, 0),
|
||||
onShutdown: make([]OnShutdownHandler, 0),
|
||||
onFork: make([]OnForkHandler, 0),
|
||||
onMount: make([]OnMountHandler, 0),
|
||||
}
|
||||
}
|
||||
|
||||
// OnRoute is a hook to execute user functions on each route registeration.
|
||||
// Also you can get route properties by route parameter.
|
||||
func (h *Hooks) OnRoute(handler ...OnRouteHandler) {
|
||||
h.app.mutex.Lock()
|
||||
h.onRoute = append(h.onRoute, handler...)
|
||||
h.app.mutex.Unlock()
|
||||
}
|
||||
|
||||
// OnName is a hook to execute user functions on each route naming.
|
||||
// Also you can get route properties by route parameter.
|
||||
//
|
||||
// WARN: OnName only works with naming routes, not groups.
|
||||
func (h *Hooks) OnName(handler ...OnNameHandler) {
|
||||
h.app.mutex.Lock()
|
||||
h.onName = append(h.onName, handler...)
|
||||
h.app.mutex.Unlock()
|
||||
}
|
||||
|
||||
// OnGroup is a hook to execute user functions on each group registeration.
|
||||
// Also you can get group properties by group parameter.
|
||||
func (h *Hooks) OnGroup(handler ...OnGroupHandler) {
|
||||
h.app.mutex.Lock()
|
||||
h.onGroup = append(h.onGroup, handler...)
|
||||
h.app.mutex.Unlock()
|
||||
}
|
||||
|
||||
// OnGroupName is a hook to execute user functions on each group naming.
|
||||
// Also you can get group properties by group parameter.
|
||||
//
|
||||
// WARN: OnGroupName only works with naming groups, not routes.
|
||||
func (h *Hooks) OnGroupName(handler ...OnGroupNameHandler) {
|
||||
h.app.mutex.Lock()
|
||||
h.onGroupName = append(h.onGroupName, handler...)
|
||||
h.app.mutex.Unlock()
|
||||
}
|
||||
|
||||
// OnListen is a hook to execute user functions on Listen, ListenTLS, Listener.
|
||||
func (h *Hooks) OnListen(handler ...OnListenHandler) {
|
||||
h.app.mutex.Lock()
|
||||
h.onListen = append(h.onListen, handler...)
|
||||
h.app.mutex.Unlock()
|
||||
}
|
||||
|
||||
// OnShutdown is a hook to execute user functions after Shutdown.
|
||||
func (h *Hooks) OnShutdown(handler ...OnShutdownHandler) {
|
||||
h.app.mutex.Lock()
|
||||
h.onShutdown = append(h.onShutdown, handler...)
|
||||
h.app.mutex.Unlock()
|
||||
}
|
||||
|
||||
// OnFork is a hook to execute user function after fork process.
|
||||
func (h *Hooks) OnFork(handler ...OnForkHandler) {
|
||||
h.app.mutex.Lock()
|
||||
h.onFork = append(h.onFork, handler...)
|
||||
h.app.mutex.Unlock()
|
||||
}
|
||||
|
||||
// OnMount is a hook to execute user function after mounting process.
|
||||
// The mount event is fired when sub-app is mounted on a parent app. The parent app is passed as a parameter.
|
||||
// It works for app and group mounting.
|
||||
func (h *Hooks) OnMount(handler ...OnMountHandler) {
|
||||
h.app.mutex.Lock()
|
||||
h.onMount = append(h.onMount, handler...)
|
||||
h.app.mutex.Unlock()
|
||||
}
|
||||
|
||||
func (h *Hooks) executeOnRouteHooks(route Route) error {
|
||||
// Check mounting
|
||||
if h.app.mountFields.mountPath != "" {
|
||||
route.path = h.app.mountFields.mountPath + route.path
|
||||
route.Path = route.path
|
||||
}
|
||||
|
||||
for _, v := range h.onRoute {
|
||||
if err := v(route); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h *Hooks) executeOnNameHooks(route Route) error {
|
||||
// Check mounting
|
||||
if h.app.mountFields.mountPath != "" {
|
||||
route.path = h.app.mountFields.mountPath + route.path
|
||||
route.Path = route.path
|
||||
}
|
||||
|
||||
for _, v := range h.onName {
|
||||
if err := v(route); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h *Hooks) executeOnGroupHooks(group Group) error {
|
||||
// Check mounting
|
||||
if h.app.mountFields.mountPath != "" {
|
||||
group.Prefix = h.app.mountFields.mountPath + group.Prefix
|
||||
}
|
||||
|
||||
for _, v := range h.onGroup {
|
||||
if err := v(group); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h *Hooks) executeOnGroupNameHooks(group Group) error {
|
||||
// Check mounting
|
||||
if h.app.mountFields.mountPath != "" {
|
||||
group.Prefix = h.app.mountFields.mountPath + group.Prefix
|
||||
}
|
||||
|
||||
for _, v := range h.onGroupName {
|
||||
if err := v(group); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h *Hooks) executeOnListenHooks(listenData ListenData) error {
|
||||
for _, v := range h.onListen {
|
||||
if err := v(listenData); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h *Hooks) executeOnShutdownHooks() {
|
||||
for _, v := range h.onShutdown {
|
||||
if err := v(); err != nil {
|
||||
log.Errorf("failed to call shutdown hook: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (h *Hooks) executeOnForkHooks(pid int) {
|
||||
for _, v := range h.onFork {
|
||||
if err := v(pid); err != nil {
|
||||
log.Errorf("failed to call fork hook: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (h *Hooks) executeOnMountHooks(app *App) error {
|
||||
for _, v := range h.onMount {
|
||||
if err := v(app); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
+97
@@ -0,0 +1,97 @@
|
||||
// Package memory Is a slight copy of the memory storage, but far from the storage interface it can not only work with bytes
|
||||
// but directly store any kind of data without having to encode it each time, which gives a huge speed advantage
|
||||
package memory
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/gofiber/fiber/v2/utils"
|
||||
)
|
||||
|
||||
type Storage struct {
|
||||
sync.RWMutex
|
||||
data map[string]item // data
|
||||
}
|
||||
|
||||
type item struct {
|
||||
// max value is 4294967295 -> Sun Feb 07 2106 06:28:15 GMT+0000
|
||||
e uint32 // exp
|
||||
v interface{} // val
|
||||
}
|
||||
|
||||
func New() *Storage {
|
||||
store := &Storage{
|
||||
data: make(map[string]item),
|
||||
}
|
||||
utils.StartTimeStampUpdater()
|
||||
go store.gc(1 * time.Second)
|
||||
return store
|
||||
}
|
||||
|
||||
// Get value by key
|
||||
func (s *Storage) Get(key string) interface{} {
|
||||
s.RLock()
|
||||
v, ok := s.data[key]
|
||||
s.RUnlock()
|
||||
if !ok || v.e != 0 && v.e <= atomic.LoadUint32(&utils.Timestamp) {
|
||||
return nil
|
||||
}
|
||||
return v.v
|
||||
}
|
||||
|
||||
// Set key with value
|
||||
func (s *Storage) Set(key string, val interface{}, ttl time.Duration) {
|
||||
var exp uint32
|
||||
if ttl > 0 {
|
||||
exp = uint32(ttl.Seconds()) + atomic.LoadUint32(&utils.Timestamp)
|
||||
}
|
||||
i := item{exp, val}
|
||||
s.Lock()
|
||||
s.data[key] = i
|
||||
s.Unlock()
|
||||
}
|
||||
|
||||
// Delete key by key
|
||||
func (s *Storage) Delete(key string) {
|
||||
s.Lock()
|
||||
delete(s.data, key)
|
||||
s.Unlock()
|
||||
}
|
||||
|
||||
// Reset all keys
|
||||
func (s *Storage) Reset() {
|
||||
nd := make(map[string]item)
|
||||
s.Lock()
|
||||
s.data = nd
|
||||
s.Unlock()
|
||||
}
|
||||
|
||||
func (s *Storage) gc(sleep time.Duration) {
|
||||
ticker := time.NewTicker(sleep)
|
||||
defer ticker.Stop()
|
||||
var expired []string
|
||||
|
||||
for range ticker.C {
|
||||
ts := atomic.LoadUint32(&utils.Timestamp)
|
||||
expired = expired[:0]
|
||||
s.RLock()
|
||||
for key, v := range s.data {
|
||||
if v.e != 0 && v.e <= ts {
|
||||
expired = append(expired, key)
|
||||
}
|
||||
}
|
||||
s.RUnlock()
|
||||
s.Lock()
|
||||
// Double-checked locking.
|
||||
// We might have replaced the item in the meantime.
|
||||
for i := range expired {
|
||||
v := s.data[expired[i]]
|
||||
if v.e != 0 && v.e <= ts {
|
||||
delete(s.data, expired[i])
|
||||
}
|
||||
}
|
||||
s.Unlock()
|
||||
}
|
||||
}
|
||||
+27
@@ -0,0 +1,27 @@
|
||||
Copyright (c) 2012 Rodrigo Moraes. All rights reserved.
|
||||
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions are
|
||||
met:
|
||||
|
||||
* Redistributions of source code must retain the above copyright
|
||||
notice, this list of conditions and the following disclaimer.
|
||||
* Redistributions in binary form must reproduce the above
|
||||
copyright notice, this list of conditions and the following disclaimer
|
||||
in the documentation and/or other materials provided with the
|
||||
distribution.
|
||||
* Neither the name of Google Inc. nor the names of its
|
||||
contributors may be used to endorse or promote products derived from
|
||||
this software without specific prior written permission.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
|
||||
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
|
||||
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
|
||||
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
|
||||
OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
|
||||
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
|
||||
LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
|
||||
DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
|
||||
THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
||||
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
+305
@@ -0,0 +1,305 @@
|
||||
// Copyright 2012 The Gorilla Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package schema
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"reflect"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
)
|
||||
|
||||
var errInvalidPath = errors.New("schema: invalid path")
|
||||
|
||||
// newCache returns a new cache.
|
||||
func newCache() *cache {
|
||||
c := cache{
|
||||
m: make(map[reflect.Type]*structInfo),
|
||||
regconv: make(map[reflect.Type]Converter),
|
||||
tag: "schema",
|
||||
}
|
||||
return &c
|
||||
}
|
||||
|
||||
// cache caches meta-data about a struct.
|
||||
type cache struct {
|
||||
l sync.RWMutex
|
||||
m map[reflect.Type]*structInfo
|
||||
regconv map[reflect.Type]Converter
|
||||
tag string
|
||||
}
|
||||
|
||||
// registerConverter registers a converter function for a custom type.
|
||||
func (c *cache) registerConverter(value interface{}, converterFunc Converter) {
|
||||
c.regconv[reflect.TypeOf(value)] = converterFunc
|
||||
}
|
||||
|
||||
// parsePath parses a path in dotted notation verifying that it is a valid
|
||||
// path to a struct field.
|
||||
//
|
||||
// It returns "path parts" which contain indices to fields to be used by
|
||||
// reflect.Value.FieldByString(). Multiple parts are required for slices of
|
||||
// structs.
|
||||
func (c *cache) parsePath(p string, t reflect.Type) ([]pathPart, error) {
|
||||
var struc *structInfo
|
||||
var field *fieldInfo
|
||||
var index64 int64
|
||||
var err error
|
||||
parts := make([]pathPart, 0)
|
||||
path := make([]string, 0)
|
||||
keys := strings.Split(p, ".")
|
||||
for i := 0; i < len(keys); i++ {
|
||||
if t.Kind() != reflect.Struct {
|
||||
return nil, errInvalidPath
|
||||
}
|
||||
if struc = c.get(t); struc == nil {
|
||||
return nil, errInvalidPath
|
||||
}
|
||||
if field = struc.get(keys[i]); field == nil {
|
||||
return nil, errInvalidPath
|
||||
}
|
||||
// Valid field. Append index.
|
||||
path = append(path, field.name)
|
||||
if field.isSliceOfStructs && (!field.unmarshalerInfo.IsValid || (field.unmarshalerInfo.IsValid && field.unmarshalerInfo.IsSliceElement)) {
|
||||
// Parse a special case: slices of structs.
|
||||
// i+1 must be the slice index.
|
||||
//
|
||||
// Now that struct can implements TextUnmarshaler interface,
|
||||
// we don't need to force the struct's fields to appear in the path.
|
||||
// So checking i+2 is not necessary anymore.
|
||||
i++
|
||||
if i+1 > len(keys) {
|
||||
return nil, errInvalidPath
|
||||
}
|
||||
if index64, err = strconv.ParseInt(keys[i], 10, 0); err != nil {
|
||||
return nil, errInvalidPath
|
||||
}
|
||||
parts = append(parts, pathPart{
|
||||
path: path,
|
||||
field: field,
|
||||
index: int(index64),
|
||||
})
|
||||
path = make([]string, 0)
|
||||
|
||||
// Get the next struct type, dropping ptrs.
|
||||
if field.typ.Kind() == reflect.Ptr {
|
||||
t = field.typ.Elem()
|
||||
} else {
|
||||
t = field.typ
|
||||
}
|
||||
if t.Kind() == reflect.Slice {
|
||||
t = t.Elem()
|
||||
if t.Kind() == reflect.Ptr {
|
||||
t = t.Elem()
|
||||
}
|
||||
}
|
||||
} else if field.typ.Kind() == reflect.Ptr {
|
||||
t = field.typ.Elem()
|
||||
} else {
|
||||
t = field.typ
|
||||
}
|
||||
}
|
||||
// Add the remaining.
|
||||
parts = append(parts, pathPart{
|
||||
path: path,
|
||||
field: field,
|
||||
index: -1,
|
||||
})
|
||||
return parts, nil
|
||||
}
|
||||
|
||||
// get returns a cached structInfo, creating it if necessary.
|
||||
func (c *cache) get(t reflect.Type) *structInfo {
|
||||
c.l.RLock()
|
||||
info := c.m[t]
|
||||
c.l.RUnlock()
|
||||
if info == nil {
|
||||
info = c.create(t, "")
|
||||
c.l.Lock()
|
||||
c.m[t] = info
|
||||
c.l.Unlock()
|
||||
}
|
||||
return info
|
||||
}
|
||||
|
||||
// create creates a structInfo with meta-data about a struct.
|
||||
func (c *cache) create(t reflect.Type, parentAlias string) *structInfo {
|
||||
info := &structInfo{}
|
||||
var anonymousInfos []*structInfo
|
||||
for i := 0; i < t.NumField(); i++ {
|
||||
if f := c.createField(t.Field(i), parentAlias); f != nil {
|
||||
info.fields = append(info.fields, f)
|
||||
if ft := indirectType(f.typ); ft.Kind() == reflect.Struct && f.isAnonymous {
|
||||
anonymousInfos = append(anonymousInfos, c.create(ft, f.canonicalAlias))
|
||||
}
|
||||
}
|
||||
}
|
||||
for i, a := range anonymousInfos {
|
||||
others := []*structInfo{info}
|
||||
others = append(others, anonymousInfos[:i]...)
|
||||
others = append(others, anonymousInfos[i+1:]...)
|
||||
for _, f := range a.fields {
|
||||
if !containsAlias(others, f.alias) {
|
||||
info.fields = append(info.fields, f)
|
||||
}
|
||||
}
|
||||
}
|
||||
return info
|
||||
}
|
||||
|
||||
// createField creates a fieldInfo for the given field.
|
||||
func (c *cache) createField(field reflect.StructField, parentAlias string) *fieldInfo {
|
||||
alias, options := fieldAlias(field, c.tag)
|
||||
if alias == "-" {
|
||||
// Ignore this field.
|
||||
return nil
|
||||
}
|
||||
canonicalAlias := alias
|
||||
if parentAlias != "" {
|
||||
canonicalAlias = parentAlias + "." + alias
|
||||
}
|
||||
// Check if the type is supported and don't cache it if not.
|
||||
// First let's get the basic type.
|
||||
isSlice, isStruct := false, false
|
||||
ft := field.Type
|
||||
m := isTextUnmarshaler(reflect.Zero(ft))
|
||||
if ft.Kind() == reflect.Ptr {
|
||||
ft = ft.Elem()
|
||||
}
|
||||
if isSlice = ft.Kind() == reflect.Slice; isSlice {
|
||||
ft = ft.Elem()
|
||||
if ft.Kind() == reflect.Ptr {
|
||||
ft = ft.Elem()
|
||||
}
|
||||
}
|
||||
if ft.Kind() == reflect.Array {
|
||||
ft = ft.Elem()
|
||||
if ft.Kind() == reflect.Ptr {
|
||||
ft = ft.Elem()
|
||||
}
|
||||
}
|
||||
if isStruct = ft.Kind() == reflect.Struct; !isStruct {
|
||||
if c.converter(ft) == nil && builtinConverters[ft.Kind()] == nil {
|
||||
// Type is not supported.
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
return &fieldInfo{
|
||||
typ: field.Type,
|
||||
name: field.Name,
|
||||
alias: alias,
|
||||
canonicalAlias: canonicalAlias,
|
||||
unmarshalerInfo: m,
|
||||
isSliceOfStructs: isSlice && isStruct,
|
||||
isAnonymous: field.Anonymous,
|
||||
isRequired: options.Contains("required"),
|
||||
}
|
||||
}
|
||||
|
||||
// converter returns the converter for a type.
|
||||
func (c *cache) converter(t reflect.Type) Converter {
|
||||
return c.regconv[t]
|
||||
}
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
|
||||
type structInfo struct {
|
||||
fields []*fieldInfo
|
||||
}
|
||||
|
||||
func (i *structInfo) get(alias string) *fieldInfo {
|
||||
for _, field := range i.fields {
|
||||
if strings.EqualFold(field.alias, alias) {
|
||||
return field
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func containsAlias(infos []*structInfo, alias string) bool {
|
||||
for _, info := range infos {
|
||||
if info.get(alias) != nil {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
type fieldInfo struct {
|
||||
typ reflect.Type
|
||||
// name is the field name in the struct.
|
||||
name string
|
||||
alias string
|
||||
// canonicalAlias is almost the same as the alias, but is prefixed with
|
||||
// an embedded struct field alias in dotted notation if this field is
|
||||
// promoted from the struct.
|
||||
// For instance, if the alias is "N" and this field is an embedded field
|
||||
// in a struct "X", canonicalAlias will be "X.N".
|
||||
canonicalAlias string
|
||||
// unmarshalerInfo contains information regarding the
|
||||
// encoding.TextUnmarshaler implementation of the field type.
|
||||
unmarshalerInfo unmarshaler
|
||||
// isSliceOfStructs indicates if the field type is a slice of structs.
|
||||
isSliceOfStructs bool
|
||||
// isAnonymous indicates whether the field is embedded in the struct.
|
||||
isAnonymous bool
|
||||
isRequired bool
|
||||
}
|
||||
|
||||
func (f *fieldInfo) paths(prefix string) []string {
|
||||
if f.alias == f.canonicalAlias {
|
||||
return []string{prefix + f.alias}
|
||||
}
|
||||
return []string{prefix + f.alias, prefix + f.canonicalAlias}
|
||||
}
|
||||
|
||||
type pathPart struct {
|
||||
field *fieldInfo
|
||||
path []string // path to the field: walks structs using field names.
|
||||
index int // struct index in slices of structs.
|
||||
}
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
|
||||
func indirectType(typ reflect.Type) reflect.Type {
|
||||
if typ.Kind() == reflect.Ptr {
|
||||
return typ.Elem()
|
||||
}
|
||||
return typ
|
||||
}
|
||||
|
||||
// fieldAlias parses a field tag to get a field alias.
|
||||
func fieldAlias(field reflect.StructField, tagName string) (alias string, options tagOptions) {
|
||||
if tag := field.Tag.Get(tagName); tag != "" {
|
||||
alias, options = parseTag(tag)
|
||||
}
|
||||
if alias == "" {
|
||||
alias = field.Name
|
||||
}
|
||||
return alias, options
|
||||
}
|
||||
|
||||
// tagOptions is the string following a comma in a struct field's tag, or
|
||||
// the empty string. It does not include the leading comma.
|
||||
type tagOptions []string
|
||||
|
||||
// parseTag splits a struct field's url tag into its name and comma-separated
|
||||
// options.
|
||||
func parseTag(tag string) (string, tagOptions) {
|
||||
s := strings.Split(tag, ",")
|
||||
return s[0], s[1:]
|
||||
}
|
||||
|
||||
// Contains checks whether the tagOptions contains the specified option.
|
||||
func (o tagOptions) Contains(option string) bool {
|
||||
for _, s := range o {
|
||||
if s == option {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
+145
@@ -0,0 +1,145 @@
|
||||
// Copyright 2012 The Gorilla Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package schema
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"strconv"
|
||||
)
|
||||
|
||||
type Converter func(string) reflect.Value
|
||||
|
||||
var (
|
||||
invalidValue = reflect.Value{}
|
||||
boolType = reflect.Bool
|
||||
float32Type = reflect.Float32
|
||||
float64Type = reflect.Float64
|
||||
intType = reflect.Int
|
||||
int8Type = reflect.Int8
|
||||
int16Type = reflect.Int16
|
||||
int32Type = reflect.Int32
|
||||
int64Type = reflect.Int64
|
||||
stringType = reflect.String
|
||||
uintType = reflect.Uint
|
||||
uint8Type = reflect.Uint8
|
||||
uint16Type = reflect.Uint16
|
||||
uint32Type = reflect.Uint32
|
||||
uint64Type = reflect.Uint64
|
||||
)
|
||||
|
||||
// Default converters for basic types.
|
||||
var builtinConverters = map[reflect.Kind]Converter{
|
||||
boolType: convertBool,
|
||||
float32Type: convertFloat32,
|
||||
float64Type: convertFloat64,
|
||||
intType: convertInt,
|
||||
int8Type: convertInt8,
|
||||
int16Type: convertInt16,
|
||||
int32Type: convertInt32,
|
||||
int64Type: convertInt64,
|
||||
stringType: convertString,
|
||||
uintType: convertUint,
|
||||
uint8Type: convertUint8,
|
||||
uint16Type: convertUint16,
|
||||
uint32Type: convertUint32,
|
||||
uint64Type: convertUint64,
|
||||
}
|
||||
|
||||
func convertBool(value string) reflect.Value {
|
||||
if value == "on" {
|
||||
return reflect.ValueOf(true)
|
||||
} else if v, err := strconv.ParseBool(value); err == nil {
|
||||
return reflect.ValueOf(v)
|
||||
}
|
||||
return invalidValue
|
||||
}
|
||||
|
||||
func convertFloat32(value string) reflect.Value {
|
||||
if v, err := strconv.ParseFloat(value, 32); err == nil {
|
||||
return reflect.ValueOf(float32(v))
|
||||
}
|
||||
return invalidValue
|
||||
}
|
||||
|
||||
func convertFloat64(value string) reflect.Value {
|
||||
if v, err := strconv.ParseFloat(value, 64); err == nil {
|
||||
return reflect.ValueOf(v)
|
||||
}
|
||||
return invalidValue
|
||||
}
|
||||
|
||||
func convertInt(value string) reflect.Value {
|
||||
if v, err := strconv.ParseInt(value, 10, 0); err == nil {
|
||||
return reflect.ValueOf(int(v))
|
||||
}
|
||||
return invalidValue
|
||||
}
|
||||
|
||||
func convertInt8(value string) reflect.Value {
|
||||
if v, err := strconv.ParseInt(value, 10, 8); err == nil {
|
||||
return reflect.ValueOf(int8(v))
|
||||
}
|
||||
return invalidValue
|
||||
}
|
||||
|
||||
func convertInt16(value string) reflect.Value {
|
||||
if v, err := strconv.ParseInt(value, 10, 16); err == nil {
|
||||
return reflect.ValueOf(int16(v))
|
||||
}
|
||||
return invalidValue
|
||||
}
|
||||
|
||||
func convertInt32(value string) reflect.Value {
|
||||
if v, err := strconv.ParseInt(value, 10, 32); err == nil {
|
||||
return reflect.ValueOf(int32(v))
|
||||
}
|
||||
return invalidValue
|
||||
}
|
||||
|
||||
func convertInt64(value string) reflect.Value {
|
||||
if v, err := strconv.ParseInt(value, 10, 64); err == nil {
|
||||
return reflect.ValueOf(v)
|
||||
}
|
||||
return invalidValue
|
||||
}
|
||||
|
||||
func convertString(value string) reflect.Value {
|
||||
return reflect.ValueOf(value)
|
||||
}
|
||||
|
||||
func convertUint(value string) reflect.Value {
|
||||
if v, err := strconv.ParseUint(value, 10, 0); err == nil {
|
||||
return reflect.ValueOf(uint(v))
|
||||
}
|
||||
return invalidValue
|
||||
}
|
||||
|
||||
func convertUint8(value string) reflect.Value {
|
||||
if v, err := strconv.ParseUint(value, 10, 8); err == nil {
|
||||
return reflect.ValueOf(uint8(v))
|
||||
}
|
||||
return invalidValue
|
||||
}
|
||||
|
||||
func convertUint16(value string) reflect.Value {
|
||||
if v, err := strconv.ParseUint(value, 10, 16); err == nil {
|
||||
return reflect.ValueOf(uint16(v))
|
||||
}
|
||||
return invalidValue
|
||||
}
|
||||
|
||||
func convertUint32(value string) reflect.Value {
|
||||
if v, err := strconv.ParseUint(value, 10, 32); err == nil {
|
||||
return reflect.ValueOf(uint32(v))
|
||||
}
|
||||
return invalidValue
|
||||
}
|
||||
|
||||
func convertUint64(value string) reflect.Value {
|
||||
if v, err := strconv.ParseUint(value, 10, 64); err == nil {
|
||||
return reflect.ValueOf(v)
|
||||
}
|
||||
return invalidValue
|
||||
}
|
||||
+534
@@ -0,0 +1,534 @@
|
||||
// Copyright 2012 The Gorilla Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package schema
|
||||
|
||||
import (
|
||||
"encoding"
|
||||
"errors"
|
||||
"fmt"
|
||||
"reflect"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// NewDecoder returns a new Decoder.
|
||||
func NewDecoder() *Decoder {
|
||||
return &Decoder{cache: newCache()}
|
||||
}
|
||||
|
||||
// Decoder decodes values from a map[string][]string to a struct.
|
||||
type Decoder struct {
|
||||
cache *cache
|
||||
zeroEmpty bool
|
||||
ignoreUnknownKeys bool
|
||||
}
|
||||
|
||||
// SetAliasTag changes the tag used to locate custom field aliases.
|
||||
// The default tag is "schema".
|
||||
func (d *Decoder) SetAliasTag(tag string) {
|
||||
d.cache.tag = tag
|
||||
}
|
||||
|
||||
// ZeroEmpty controls the behaviour when the decoder encounters empty values
|
||||
// in a map.
|
||||
// If z is true and a key in the map has the empty string as a value
|
||||
// then the corresponding struct field is set to the zero value.
|
||||
// If z is false then empty strings are ignored.
|
||||
//
|
||||
// The default value is false, that is empty values do not change
|
||||
// the value of the struct field.
|
||||
func (d *Decoder) ZeroEmpty(z bool) {
|
||||
d.zeroEmpty = z
|
||||
}
|
||||
|
||||
// IgnoreUnknownKeys controls the behaviour when the decoder encounters unknown
|
||||
// keys in the map.
|
||||
// If i is true and an unknown field is encountered, it is ignored. This is
|
||||
// similar to how unknown keys are handled by encoding/json.
|
||||
// If i is false then Decode will return an error. Note that any valid keys
|
||||
// will still be decoded in to the target struct.
|
||||
//
|
||||
// To preserve backwards compatibility, the default value is false.
|
||||
func (d *Decoder) IgnoreUnknownKeys(i bool) {
|
||||
d.ignoreUnknownKeys = i
|
||||
}
|
||||
|
||||
// RegisterConverter registers a converter function for a custom type.
|
||||
func (d *Decoder) RegisterConverter(value interface{}, converterFunc Converter) {
|
||||
d.cache.registerConverter(value, converterFunc)
|
||||
}
|
||||
|
||||
// Decode decodes a map[string][]string to a struct.
|
||||
//
|
||||
// The first parameter must be a pointer to a struct.
|
||||
//
|
||||
// The second parameter is a map, typically url.Values from an HTTP request.
|
||||
// Keys are "paths" in dotted notation to the struct fields and nested structs.
|
||||
//
|
||||
// See the package documentation for a full explanation of the mechanics.
|
||||
func (d *Decoder) Decode(dst interface{}, src map[string][]string) error {
|
||||
v := reflect.ValueOf(dst)
|
||||
if v.Kind() != reflect.Ptr || v.Elem().Kind() != reflect.Struct {
|
||||
return errors.New("schema: interface must be a pointer to struct")
|
||||
}
|
||||
v = v.Elem()
|
||||
t := v.Type()
|
||||
multiError := MultiError{}
|
||||
for path, values := range src {
|
||||
if parts, err := d.cache.parsePath(path, t); err == nil {
|
||||
if err = d.decode(v, path, parts, values); err != nil {
|
||||
multiError[path] = err
|
||||
}
|
||||
} else if !d.ignoreUnknownKeys {
|
||||
multiError[path] = UnknownKeyError{Key: path}
|
||||
}
|
||||
}
|
||||
multiError.merge(d.checkRequired(t, src))
|
||||
if len(multiError) > 0 {
|
||||
return multiError
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// checkRequired checks whether required fields are empty
|
||||
//
|
||||
// check type t recursively if t has struct fields.
|
||||
//
|
||||
// src is the source map for decoding, we use it here to see if those required fields are included in src
|
||||
func (d *Decoder) checkRequired(t reflect.Type, src map[string][]string) MultiError {
|
||||
m, errs := d.findRequiredFields(t, "", "")
|
||||
for key, fields := range m {
|
||||
if isEmptyFields(fields, src) {
|
||||
errs[key] = EmptyFieldError{Key: key}
|
||||
}
|
||||
}
|
||||
return errs
|
||||
}
|
||||
|
||||
// findRequiredFields recursively searches the struct type t for required fields.
|
||||
//
|
||||
// canonicalPrefix and searchPrefix are used to resolve full paths in dotted notation
|
||||
// for nested struct fields. canonicalPrefix is a complete path which never omits
|
||||
// any embedded struct fields. searchPrefix is a user-friendly path which may omit
|
||||
// some embedded struct fields to point promoted fields.
|
||||
func (d *Decoder) findRequiredFields(t reflect.Type, canonicalPrefix, searchPrefix string) (map[string][]fieldWithPrefix, MultiError) {
|
||||
struc := d.cache.get(t)
|
||||
if struc == nil {
|
||||
// unexpect, cache.get never return nil
|
||||
return nil, MultiError{canonicalPrefix + "*": errors.New("cache fail")}
|
||||
}
|
||||
|
||||
m := map[string][]fieldWithPrefix{}
|
||||
errs := MultiError{}
|
||||
for _, f := range struc.fields {
|
||||
if f.typ.Kind() == reflect.Struct {
|
||||
fcprefix := canonicalPrefix + f.canonicalAlias + "."
|
||||
for _, fspath := range f.paths(searchPrefix) {
|
||||
fm, ferrs := d.findRequiredFields(f.typ, fcprefix, fspath+".")
|
||||
for key, fields := range fm {
|
||||
m[key] = append(m[key], fields...)
|
||||
}
|
||||
errs.merge(ferrs)
|
||||
}
|
||||
}
|
||||
if f.isRequired {
|
||||
key := canonicalPrefix + f.canonicalAlias
|
||||
m[key] = append(m[key], fieldWithPrefix{
|
||||
fieldInfo: f,
|
||||
prefix: searchPrefix,
|
||||
})
|
||||
}
|
||||
}
|
||||
return m, errs
|
||||
}
|
||||
|
||||
type fieldWithPrefix struct {
|
||||
*fieldInfo
|
||||
prefix string
|
||||
}
|
||||
|
||||
// isEmptyFields returns true if all of specified fields are empty.
|
||||
func isEmptyFields(fields []fieldWithPrefix, src map[string][]string) bool {
|
||||
for _, f := range fields {
|
||||
for _, path := range f.paths(f.prefix) {
|
||||
v, ok := src[path]
|
||||
if ok && !isEmpty(f.typ, v) {
|
||||
return false
|
||||
}
|
||||
for key := range src {
|
||||
// issue references:
|
||||
// https://github.com/gofiber/fiber/issues/1414
|
||||
// https://github.com/gorilla/schema/issues/176
|
||||
nested := strings.IndexByte(key, '.') != -1
|
||||
|
||||
// for non required nested structs
|
||||
c1 := strings.HasSuffix(f.prefix, ".") && key == path
|
||||
|
||||
// for required nested structs
|
||||
c2 := f.prefix == "" && nested && strings.HasPrefix(key, path)
|
||||
|
||||
// for non nested fields
|
||||
c3 := f.prefix == "" && !nested && key == path
|
||||
if !isEmpty(f.typ, src[key]) && (c1 || c2 || c3) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// isEmpty returns true if value is empty for specific type
|
||||
func isEmpty(t reflect.Type, value []string) bool {
|
||||
if len(value) == 0 {
|
||||
return true
|
||||
}
|
||||
switch t.Kind() {
|
||||
case boolType, float32Type, float64Type, intType, int8Type, int32Type, int64Type, stringType, uint8Type, uint16Type, uint32Type, uint64Type:
|
||||
return len(value[0]) == 0
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// decode fills a struct field using a parsed path.
|
||||
func (d *Decoder) decode(v reflect.Value, path string, parts []pathPart, values []string) error {
|
||||
// Get the field walking the struct fields by index.
|
||||
for _, name := range parts[0].path {
|
||||
if v.Type().Kind() == reflect.Ptr {
|
||||
if v.IsNil() {
|
||||
v.Set(reflect.New(v.Type().Elem()))
|
||||
}
|
||||
v = v.Elem()
|
||||
}
|
||||
|
||||
// alloc embedded structs
|
||||
if v.Type().Kind() == reflect.Struct {
|
||||
for i := 0; i < v.NumField(); i++ {
|
||||
field := v.Field(i)
|
||||
if field.Type().Kind() == reflect.Ptr && field.IsNil() && v.Type().Field(i).Anonymous {
|
||||
field.Set(reflect.New(field.Type().Elem()))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
v = v.FieldByName(name)
|
||||
}
|
||||
// Don't even bother for unexported fields.
|
||||
if !v.CanSet() {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Dereference if needed.
|
||||
t := v.Type()
|
||||
if t.Kind() == reflect.Ptr {
|
||||
t = t.Elem()
|
||||
if v.IsNil() {
|
||||
v.Set(reflect.New(t))
|
||||
}
|
||||
v = v.Elem()
|
||||
}
|
||||
|
||||
// Slice of structs. Let's go recursive.
|
||||
if len(parts) > 1 {
|
||||
idx := parts[0].index
|
||||
if v.IsNil() || v.Len() < idx+1 {
|
||||
value := reflect.MakeSlice(t, idx+1, idx+1)
|
||||
if v.Len() < idx+1 {
|
||||
// Resize it.
|
||||
reflect.Copy(value, v)
|
||||
}
|
||||
v.Set(value)
|
||||
}
|
||||
return d.decode(v.Index(idx), path, parts[1:], values)
|
||||
}
|
||||
|
||||
// Get the converter early in case there is one for a slice type.
|
||||
conv := d.cache.converter(t)
|
||||
m := isTextUnmarshaler(v)
|
||||
if conv == nil && t.Kind() == reflect.Slice && m.IsSliceElement {
|
||||
var items []reflect.Value
|
||||
elemT := t.Elem()
|
||||
isPtrElem := elemT.Kind() == reflect.Ptr
|
||||
if isPtrElem {
|
||||
elemT = elemT.Elem()
|
||||
}
|
||||
|
||||
// Try to get a converter for the element type.
|
||||
conv := d.cache.converter(elemT)
|
||||
if conv == nil {
|
||||
conv = builtinConverters[elemT.Kind()]
|
||||
if conv == nil {
|
||||
// As we are not dealing with slice of structs here, we don't need to check if the type
|
||||
// implements TextUnmarshaler interface
|
||||
return fmt.Errorf("schema: converter not found for %v", elemT)
|
||||
}
|
||||
}
|
||||
|
||||
for key, value := range values {
|
||||
if value == "" {
|
||||
if d.zeroEmpty {
|
||||
items = append(items, reflect.Zero(elemT))
|
||||
}
|
||||
} else if m.IsValid {
|
||||
u := reflect.New(elemT)
|
||||
if m.IsSliceElementPtr {
|
||||
u = reflect.New(reflect.PtrTo(elemT).Elem())
|
||||
}
|
||||
if err := u.Interface().(encoding.TextUnmarshaler).UnmarshalText([]byte(value)); err != nil {
|
||||
return ConversionError{
|
||||
Key: path,
|
||||
Type: t,
|
||||
Index: key,
|
||||
Err: err,
|
||||
}
|
||||
}
|
||||
if m.IsSliceElementPtr {
|
||||
items = append(items, u.Elem().Addr())
|
||||
} else if u.Kind() == reflect.Ptr {
|
||||
items = append(items, u.Elem())
|
||||
} else {
|
||||
items = append(items, u)
|
||||
}
|
||||
} else if item := conv(value); item.IsValid() {
|
||||
if isPtrElem {
|
||||
ptr := reflect.New(elemT)
|
||||
ptr.Elem().Set(item)
|
||||
item = ptr
|
||||
}
|
||||
if item.Type() != elemT && !isPtrElem {
|
||||
item = item.Convert(elemT)
|
||||
}
|
||||
items = append(items, item)
|
||||
} else {
|
||||
if strings.Contains(value, ",") {
|
||||
values := strings.Split(value, ",")
|
||||
for _, value := range values {
|
||||
if value == "" {
|
||||
if d.zeroEmpty {
|
||||
items = append(items, reflect.Zero(elemT))
|
||||
}
|
||||
} else if item := conv(value); item.IsValid() {
|
||||
if isPtrElem {
|
||||
ptr := reflect.New(elemT)
|
||||
ptr.Elem().Set(item)
|
||||
item = ptr
|
||||
}
|
||||
if item.Type() != elemT && !isPtrElem {
|
||||
item = item.Convert(elemT)
|
||||
}
|
||||
items = append(items, item)
|
||||
} else {
|
||||
return ConversionError{
|
||||
Key: path,
|
||||
Type: elemT,
|
||||
Index: key,
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
return ConversionError{
|
||||
Key: path,
|
||||
Type: elemT,
|
||||
Index: key,
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
value := reflect.Append(reflect.MakeSlice(t, 0, 0), items...)
|
||||
v.Set(value)
|
||||
} else {
|
||||
val := ""
|
||||
// Use the last value provided if any values were provided
|
||||
if len(values) > 0 {
|
||||
val = values[len(values)-1]
|
||||
}
|
||||
|
||||
if conv != nil {
|
||||
if value := conv(val); value.IsValid() {
|
||||
v.Set(value.Convert(t))
|
||||
} else {
|
||||
return ConversionError{
|
||||
Key: path,
|
||||
Type: t,
|
||||
Index: -1,
|
||||
}
|
||||
}
|
||||
} else if m.IsValid {
|
||||
if m.IsPtr {
|
||||
u := reflect.New(v.Type())
|
||||
if err := u.Interface().(encoding.TextUnmarshaler).UnmarshalText([]byte(val)); err != nil {
|
||||
return ConversionError{
|
||||
Key: path,
|
||||
Type: t,
|
||||
Index: -1,
|
||||
Err: err,
|
||||
}
|
||||
}
|
||||
v.Set(reflect.Indirect(u))
|
||||
} else {
|
||||
// If the value implements the encoding.TextUnmarshaler interface
|
||||
// apply UnmarshalText as the converter
|
||||
if err := m.Unmarshaler.UnmarshalText([]byte(val)); err != nil {
|
||||
return ConversionError{
|
||||
Key: path,
|
||||
Type: t,
|
||||
Index: -1,
|
||||
Err: err,
|
||||
}
|
||||
}
|
||||
}
|
||||
} else if val == "" {
|
||||
if d.zeroEmpty {
|
||||
v.Set(reflect.Zero(t))
|
||||
}
|
||||
} else if conv := builtinConverters[t.Kind()]; conv != nil {
|
||||
if value := conv(val); value.IsValid() {
|
||||
v.Set(value.Convert(t))
|
||||
} else {
|
||||
return ConversionError{
|
||||
Key: path,
|
||||
Type: t,
|
||||
Index: -1,
|
||||
}
|
||||
}
|
||||
} else {
|
||||
return fmt.Errorf("schema: converter not found for %v", t)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func isTextUnmarshaler(v reflect.Value) unmarshaler {
|
||||
// Create a new unmarshaller instance
|
||||
m := unmarshaler{}
|
||||
if m.Unmarshaler, m.IsValid = v.Interface().(encoding.TextUnmarshaler); m.IsValid {
|
||||
return m
|
||||
}
|
||||
// As the UnmarshalText function should be applied to the pointer of the
|
||||
// type, we check that type to see if it implements the necessary
|
||||
// method.
|
||||
if m.Unmarshaler, m.IsValid = reflect.New(v.Type()).Interface().(encoding.TextUnmarshaler); m.IsValid {
|
||||
m.IsPtr = true
|
||||
return m
|
||||
}
|
||||
|
||||
// if v is []T or *[]T create new T
|
||||
t := v.Type()
|
||||
if t.Kind() == reflect.Ptr {
|
||||
t = t.Elem()
|
||||
}
|
||||
if t.Kind() == reflect.Slice {
|
||||
// Check if the slice implements encoding.TextUnmarshaller
|
||||
if m.Unmarshaler, m.IsValid = v.Interface().(encoding.TextUnmarshaler); m.IsValid {
|
||||
return m
|
||||
}
|
||||
// If t is a pointer slice, check if its elements implement
|
||||
// encoding.TextUnmarshaler
|
||||
m.IsSliceElement = true
|
||||
if t = t.Elem(); t.Kind() == reflect.Ptr {
|
||||
t = reflect.PtrTo(t.Elem())
|
||||
v = reflect.Zero(t)
|
||||
m.IsSliceElementPtr = true
|
||||
m.Unmarshaler, m.IsValid = v.Interface().(encoding.TextUnmarshaler)
|
||||
return m
|
||||
}
|
||||
}
|
||||
|
||||
v = reflect.New(t)
|
||||
m.Unmarshaler, m.IsValid = v.Interface().(encoding.TextUnmarshaler)
|
||||
return m
|
||||
}
|
||||
|
||||
// TextUnmarshaler helpers ----------------------------------------------------
|
||||
// unmarshaller contains information about a TextUnmarshaler type
|
||||
type unmarshaler struct {
|
||||
Unmarshaler encoding.TextUnmarshaler
|
||||
// IsValid indicates whether the resolved type indicated by the other
|
||||
// flags implements the encoding.TextUnmarshaler interface.
|
||||
IsValid bool
|
||||
// IsPtr indicates that the resolved type is the pointer of the original
|
||||
// type.
|
||||
IsPtr bool
|
||||
// IsSliceElement indicates that the resolved type is a slice element of
|
||||
// the original type.
|
||||
IsSliceElement bool
|
||||
// IsSliceElementPtr indicates that the resolved type is a pointer to a
|
||||
// slice element of the original type.
|
||||
IsSliceElementPtr bool
|
||||
}
|
||||
|
||||
// Errors ---------------------------------------------------------------------
|
||||
|
||||
// ConversionError stores information about a failed conversion.
|
||||
type ConversionError struct {
|
||||
Key string // key from the source map.
|
||||
Type reflect.Type // expected type of elem
|
||||
Index int // index for multi-value fields; -1 for single-value fields.
|
||||
Err error // low-level error (when it exists)
|
||||
}
|
||||
|
||||
func (e ConversionError) Error() string {
|
||||
var output string
|
||||
|
||||
if e.Index < 0 {
|
||||
output = fmt.Sprintf("schema: error converting value for %q", e.Key)
|
||||
} else {
|
||||
output = fmt.Sprintf("schema: error converting value for index %d of %q",
|
||||
e.Index, e.Key)
|
||||
}
|
||||
|
||||
if e.Err != nil {
|
||||
output = fmt.Sprintf("%s. Details: %s", output, e.Err)
|
||||
}
|
||||
|
||||
return output
|
||||
}
|
||||
|
||||
// UnknownKeyError stores information about an unknown key in the source map.
|
||||
type UnknownKeyError struct {
|
||||
Key string // key from the source map.
|
||||
}
|
||||
|
||||
func (e UnknownKeyError) Error() string {
|
||||
return fmt.Sprintf("schema: invalid path %q", e.Key)
|
||||
}
|
||||
|
||||
// EmptyFieldError stores information about an empty required field.
|
||||
type EmptyFieldError struct {
|
||||
Key string // required key in the source map.
|
||||
}
|
||||
|
||||
func (e EmptyFieldError) Error() string {
|
||||
return fmt.Sprintf("%v is empty", e.Key)
|
||||
}
|
||||
|
||||
// MultiError stores multiple decoding errors.
|
||||
//
|
||||
// Borrowed from the App Engine SDK.
|
||||
type MultiError map[string]error
|
||||
|
||||
func (e MultiError) Error() string {
|
||||
s := ""
|
||||
for _, err := range e {
|
||||
s = err.Error()
|
||||
break
|
||||
}
|
||||
switch len(e) {
|
||||
case 0:
|
||||
return "(0 errors)"
|
||||
case 1:
|
||||
return s
|
||||
case 2:
|
||||
return s + " (and 1 other error)"
|
||||
}
|
||||
return fmt.Sprintf("%s (and %d other errors)", s, len(e)-1)
|
||||
}
|
||||
|
||||
func (e MultiError) merge(errors MultiError) {
|
||||
for key, err := range errors {
|
||||
if e[key] == nil {
|
||||
e[key] = err
|
||||
}
|
||||
}
|
||||
}
|
||||
+148
@@ -0,0 +1,148 @@
|
||||
// Copyright 2012 The Gorilla Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
/*
|
||||
Package gorilla/schema fills a struct with form values.
|
||||
|
||||
The basic usage is really simple. Given this struct:
|
||||
|
||||
type Person struct {
|
||||
Name string
|
||||
Phone string
|
||||
}
|
||||
|
||||
...we can fill it passing a map to the Decode() function:
|
||||
|
||||
values := map[string][]string{
|
||||
"Name": {"John"},
|
||||
"Phone": {"999-999-999"},
|
||||
}
|
||||
person := new(Person)
|
||||
decoder := schema.NewDecoder()
|
||||
decoder.Decode(person, values)
|
||||
|
||||
This is just a simple example and it doesn't make a lot of sense to create
|
||||
the map manually. Typically it will come from a http.Request object and
|
||||
will be of type url.Values, http.Request.Form, or http.Request.MultipartForm:
|
||||
|
||||
func MyHandler(w http.ResponseWriter, r *http.Request) {
|
||||
err := r.ParseForm()
|
||||
|
||||
if err != nil {
|
||||
// Handle error
|
||||
}
|
||||
|
||||
decoder := schema.NewDecoder()
|
||||
// r.PostForm is a map of our POST form values
|
||||
err := decoder.Decode(person, r.PostForm)
|
||||
|
||||
if err != nil {
|
||||
// Handle error
|
||||
}
|
||||
|
||||
// Do something with person.Name or person.Phone
|
||||
}
|
||||
|
||||
Note: it is a good idea to set a Decoder instance as a package global,
|
||||
because it caches meta-data about structs, and an instance can be shared safely:
|
||||
|
||||
var decoder = schema.NewDecoder()
|
||||
|
||||
To define custom names for fields, use a struct tag "schema". To not populate
|
||||
certain fields, use a dash for the name and it will be ignored:
|
||||
|
||||
type Person struct {
|
||||
Name string `schema:"name"` // custom name
|
||||
Phone string `schema:"phone"` // custom name
|
||||
Admin bool `schema:"-"` // this field is never set
|
||||
}
|
||||
|
||||
The supported field types in the destination struct are:
|
||||
|
||||
- bool
|
||||
- float variants (float32, float64)
|
||||
- int variants (int, int8, int16, int32, int64)
|
||||
- string
|
||||
- uint variants (uint, uint8, uint16, uint32, uint64)
|
||||
- struct
|
||||
- a pointer to one of the above types
|
||||
- a slice or a pointer to a slice of one of the above types
|
||||
|
||||
Non-supported types are simply ignored, however custom types can be registered
|
||||
to be converted.
|
||||
|
||||
To fill nested structs, keys must use a dotted notation as the "path" for the
|
||||
field. So for example, to fill the struct Person below:
|
||||
|
||||
type Phone struct {
|
||||
Label string
|
||||
Number string
|
||||
}
|
||||
|
||||
type Person struct {
|
||||
Name string
|
||||
Phone Phone
|
||||
}
|
||||
|
||||
...the source map must have the keys "Name", "Phone.Label" and "Phone.Number".
|
||||
This means that an HTML form to fill a Person struct must look like this:
|
||||
|
||||
<form>
|
||||
<input type="text" name="Name">
|
||||
<input type="text" name="Phone.Label">
|
||||
<input type="text" name="Phone.Number">
|
||||
</form>
|
||||
|
||||
Single values are filled using the first value for a key from the source map.
|
||||
Slices are filled using all values for a key from the source map. So to fill
|
||||
a Person with multiple Phone values, like:
|
||||
|
||||
type Person struct {
|
||||
Name string
|
||||
Phones []Phone
|
||||
}
|
||||
|
||||
...an HTML form that accepts three Phone values would look like this:
|
||||
|
||||
<form>
|
||||
<input type="text" name="Name">
|
||||
<input type="text" name="Phones.0.Label">
|
||||
<input type="text" name="Phones.0.Number">
|
||||
<input type="text" name="Phones.1.Label">
|
||||
<input type="text" name="Phones.1.Number">
|
||||
<input type="text" name="Phones.2.Label">
|
||||
<input type="text" name="Phones.2.Number">
|
||||
</form>
|
||||
|
||||
Notice that only for slices of structs the slice index is required.
|
||||
This is needed for disambiguation: if the nested struct also had a slice
|
||||
field, we could not translate multiple values to it if we did not use an
|
||||
index for the parent struct.
|
||||
|
||||
There's also the possibility to create a custom type that implements the
|
||||
TextUnmarshaler interface, and in this case there's no need to register
|
||||
a converter, like:
|
||||
|
||||
type Person struct {
|
||||
Emails []Email
|
||||
}
|
||||
|
||||
type Email struct {
|
||||
*mail.Address
|
||||
}
|
||||
|
||||
func (e *Email) UnmarshalText(text []byte) (err error) {
|
||||
e.Address, err = mail.ParseAddress(string(text))
|
||||
return
|
||||
}
|
||||
|
||||
...an HTML form that accepts three Email values would look like this:
|
||||
|
||||
<form>
|
||||
<input type="email" name="Emails.0">
|
||||
<input type="email" name="Emails.1">
|
||||
<input type="email" name="Emails.2">
|
||||
</form>
|
||||
*/
|
||||
package schema
|
||||
+202
@@ -0,0 +1,202 @@
|
||||
package schema
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"reflect"
|
||||
"strconv"
|
||||
)
|
||||
|
||||
type encoderFunc func(reflect.Value) string
|
||||
|
||||
// Encoder encodes values from a struct into url.Values.
|
||||
type Encoder struct {
|
||||
cache *cache
|
||||
regenc map[reflect.Type]encoderFunc
|
||||
}
|
||||
|
||||
// NewEncoder returns a new Encoder with defaults.
|
||||
func NewEncoder() *Encoder {
|
||||
return &Encoder{cache: newCache(), regenc: make(map[reflect.Type]encoderFunc)}
|
||||
}
|
||||
|
||||
// Encode encodes a struct into map[string][]string.
|
||||
//
|
||||
// Intended for use with url.Values.
|
||||
func (e *Encoder) Encode(src interface{}, dst map[string][]string) error {
|
||||
v := reflect.ValueOf(src)
|
||||
|
||||
return e.encode(v, dst)
|
||||
}
|
||||
|
||||
// RegisterEncoder registers a converter for encoding a custom type.
|
||||
func (e *Encoder) RegisterEncoder(value interface{}, encoder func(reflect.Value) string) {
|
||||
e.regenc[reflect.TypeOf(value)] = encoder
|
||||
}
|
||||
|
||||
// SetAliasTag changes the tag used to locate custom field aliases.
|
||||
// The default tag is "schema".
|
||||
func (e *Encoder) SetAliasTag(tag string) {
|
||||
e.cache.tag = tag
|
||||
}
|
||||
|
||||
// isValidStructPointer test if input value is a valid struct pointer.
|
||||
func isValidStructPointer(v reflect.Value) bool {
|
||||
return v.Type().Kind() == reflect.Ptr && v.Elem().IsValid() && v.Elem().Type().Kind() == reflect.Struct
|
||||
}
|
||||
|
||||
func isZero(v reflect.Value) bool {
|
||||
switch v.Kind() {
|
||||
case reflect.Func:
|
||||
case reflect.Map, reflect.Slice:
|
||||
return v.IsNil() || v.Len() == 0
|
||||
case reflect.Array:
|
||||
z := true
|
||||
for i := 0; i < v.Len(); i++ {
|
||||
z = z && isZero(v.Index(i))
|
||||
}
|
||||
return z
|
||||
case reflect.Struct:
|
||||
type zero interface {
|
||||
IsZero() bool
|
||||
}
|
||||
if v.Type().Implements(reflect.TypeOf((*zero)(nil)).Elem()) {
|
||||
iz := v.MethodByName("IsZero").Call([]reflect.Value{})[0]
|
||||
return iz.Interface().(bool)
|
||||
}
|
||||
z := true
|
||||
for i := 0; i < v.NumField(); i++ {
|
||||
z = z && isZero(v.Field(i))
|
||||
}
|
||||
return z
|
||||
}
|
||||
// Compare other types directly:
|
||||
z := reflect.Zero(v.Type())
|
||||
return v.Interface() == z.Interface()
|
||||
}
|
||||
|
||||
func (e *Encoder) encode(v reflect.Value, dst map[string][]string) error {
|
||||
if v.Kind() == reflect.Ptr {
|
||||
v = v.Elem()
|
||||
}
|
||||
if v.Kind() != reflect.Struct {
|
||||
return errors.New("schema: interface must be a struct")
|
||||
}
|
||||
t := v.Type()
|
||||
|
||||
errors := MultiError{}
|
||||
|
||||
for i := 0; i < v.NumField(); i++ {
|
||||
name, opts := fieldAlias(t.Field(i), e.cache.tag)
|
||||
if name == "-" {
|
||||
continue
|
||||
}
|
||||
|
||||
// Encode struct pointer types if the field is a valid pointer and a struct.
|
||||
if isValidStructPointer(v.Field(i)) {
|
||||
_ = e.encode(v.Field(i).Elem(), dst)
|
||||
continue
|
||||
}
|
||||
|
||||
encFunc := typeEncoder(v.Field(i).Type(), e.regenc)
|
||||
|
||||
// Encode non-slice types and custom implementations immediately.
|
||||
if encFunc != nil {
|
||||
value := encFunc(v.Field(i))
|
||||
if opts.Contains("omitempty") && isZero(v.Field(i)) {
|
||||
continue
|
||||
}
|
||||
|
||||
dst[name] = append(dst[name], value)
|
||||
continue
|
||||
}
|
||||
|
||||
if v.Field(i).Type().Kind() == reflect.Struct {
|
||||
_ = e.encode(v.Field(i), dst)
|
||||
continue
|
||||
}
|
||||
|
||||
if v.Field(i).Type().Kind() == reflect.Slice {
|
||||
encFunc = typeEncoder(v.Field(i).Type().Elem(), e.regenc)
|
||||
}
|
||||
|
||||
if encFunc == nil {
|
||||
errors[v.Field(i).Type().String()] = fmt.Errorf("schema: encoder not found for %v", v.Field(i))
|
||||
continue
|
||||
}
|
||||
|
||||
// Encode a slice.
|
||||
if v.Field(i).Len() == 0 && opts.Contains("omitempty") {
|
||||
continue
|
||||
}
|
||||
|
||||
dst[name] = []string{}
|
||||
for j := 0; j < v.Field(i).Len(); j++ {
|
||||
dst[name] = append(dst[name], encFunc(v.Field(i).Index(j)))
|
||||
}
|
||||
}
|
||||
|
||||
if len(errors) > 0 {
|
||||
return errors
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func typeEncoder(t reflect.Type, reg map[reflect.Type]encoderFunc) encoderFunc {
|
||||
if f, ok := reg[t]; ok {
|
||||
return f
|
||||
}
|
||||
|
||||
switch t.Kind() {
|
||||
case reflect.Bool:
|
||||
return encodeBool
|
||||
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
|
||||
return encodeInt
|
||||
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
|
||||
return encodeUint
|
||||
case reflect.Float32:
|
||||
return encodeFloat32
|
||||
case reflect.Float64:
|
||||
return encodeFloat64
|
||||
case reflect.Ptr:
|
||||
f := typeEncoder(t.Elem(), reg)
|
||||
return func(v reflect.Value) string {
|
||||
if v.IsNil() {
|
||||
return "null"
|
||||
}
|
||||
return f(v.Elem())
|
||||
}
|
||||
case reflect.String:
|
||||
return encodeString
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func encodeBool(v reflect.Value) string {
|
||||
return strconv.FormatBool(v.Bool())
|
||||
}
|
||||
|
||||
func encodeInt(v reflect.Value) string {
|
||||
return strconv.FormatInt(int64(v.Int()), 10)
|
||||
}
|
||||
|
||||
func encodeUint(v reflect.Value) string {
|
||||
return strconv.FormatUint(uint64(v.Uint()), 10)
|
||||
}
|
||||
|
||||
func encodeFloat(v reflect.Value, bits int) string {
|
||||
return strconv.FormatFloat(v.Float(), 'f', 6, bits)
|
||||
}
|
||||
|
||||
func encodeFloat32(v reflect.Value) string {
|
||||
return encodeFloat(v, 32)
|
||||
}
|
||||
|
||||
func encodeFloat64(v reflect.Value) string {
|
||||
return encodeFloat(v, 64)
|
||||
}
|
||||
|
||||
func encodeString(v reflect.Value) string {
|
||||
return v.String()
|
||||
}
|
||||
+502
@@ -0,0 +1,502 @@
|
||||
// ⚡️ Fiber is an Express inspired web framework written in Go with ☕️
|
||||
// 🤖 Github Repository: https://github.com/gofiber/fiber
|
||||
// 📌 API Documentation: https://docs.gofiber.io
|
||||
|
||||
package fiber
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"reflect"
|
||||
"runtime"
|
||||
"sort"
|
||||
"strconv"
|
||||
"strings"
|
||||
"text/tabwriter"
|
||||
|
||||
"github.com/mattn/go-colorable"
|
||||
"github.com/mattn/go-isatty"
|
||||
"github.com/mattn/go-runewidth"
|
||||
|
||||
"github.com/gofiber/fiber/v2/log"
|
||||
)
|
||||
|
||||
const (
|
||||
globalIpv4Addr = "0.0.0.0"
|
||||
)
|
||||
|
||||
// Listener can be used to pass a custom listener.
|
||||
func (app *App) Listener(ln net.Listener) error {
|
||||
// prepare the server for the start
|
||||
app.startupProcess()
|
||||
|
||||
// run hooks
|
||||
app.runOnListenHooks(app.prepareListenData(ln.Addr().String(), getTLSConfig(ln) != nil))
|
||||
|
||||
// Print startup message
|
||||
if !app.config.DisableStartupMessage {
|
||||
app.startupMessage(ln.Addr().String(), getTLSConfig(ln) != nil, "")
|
||||
}
|
||||
|
||||
// Print routes
|
||||
if app.config.EnablePrintRoutes {
|
||||
app.printRoutesMessage()
|
||||
}
|
||||
|
||||
// Prefork is not supported for custom listeners
|
||||
if app.config.Prefork {
|
||||
log.Warn("Prefork isn't supported for custom listeners.")
|
||||
}
|
||||
|
||||
// Start listening
|
||||
return app.server.Serve(ln)
|
||||
}
|
||||
|
||||
// Listen serves HTTP requests from the given addr.
|
||||
//
|
||||
// app.Listen(":8080")
|
||||
// app.Listen("127.0.0.1:8080")
|
||||
func (app *App) Listen(addr string) error {
|
||||
// Start prefork
|
||||
if app.config.Prefork {
|
||||
return app.prefork(app.config.Network, addr, nil)
|
||||
}
|
||||
|
||||
// Setup listener
|
||||
ln, err := net.Listen(app.config.Network, addr)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to listen: %w", err)
|
||||
}
|
||||
|
||||
// prepare the server for the start
|
||||
app.startupProcess()
|
||||
|
||||
// run hooks
|
||||
app.runOnListenHooks(app.prepareListenData(ln.Addr().String(), false))
|
||||
|
||||
// Print startup message
|
||||
if !app.config.DisableStartupMessage {
|
||||
app.startupMessage(ln.Addr().String(), false, "")
|
||||
}
|
||||
|
||||
// Print routes
|
||||
if app.config.EnablePrintRoutes {
|
||||
app.printRoutesMessage()
|
||||
}
|
||||
|
||||
// Start listening
|
||||
return app.server.Serve(ln)
|
||||
}
|
||||
|
||||
// ListenTLS serves HTTPS requests from the given addr.
|
||||
// certFile and keyFile are the paths to TLS certificate and key file:
|
||||
//
|
||||
// app.ListenTLS(":8080", "./cert.pem", "./cert.key")
|
||||
func (app *App) ListenTLS(addr, certFile, keyFile string) error {
|
||||
// Check for valid cert/key path
|
||||
if len(certFile) == 0 || len(keyFile) == 0 {
|
||||
return errors.New("tls: provide a valid cert or key path")
|
||||
}
|
||||
|
||||
// Set TLS config with handler
|
||||
cert, err := tls.LoadX509KeyPair(certFile, keyFile)
|
||||
if err != nil {
|
||||
return fmt.Errorf("tls: cannot load TLS key pair from certFile=%q and keyFile=%q: %w", certFile, keyFile, err)
|
||||
}
|
||||
|
||||
return app.ListenTLSWithCertificate(addr, cert)
|
||||
}
|
||||
|
||||
// ListenTLS serves HTTPS requests from the given addr.
|
||||
// cert is a tls.Certificate
|
||||
//
|
||||
// app.ListenTLSWithCertificate(":8080", cert)
|
||||
func (app *App) ListenTLSWithCertificate(addr string, cert tls.Certificate) error {
|
||||
tlsHandler := &TLSHandler{}
|
||||
config := &tls.Config{
|
||||
MinVersion: tls.VersionTLS12,
|
||||
Certificates: []tls.Certificate{
|
||||
cert,
|
||||
},
|
||||
GetCertificate: tlsHandler.GetClientInfo,
|
||||
}
|
||||
|
||||
// Prefork is supported
|
||||
if app.config.Prefork {
|
||||
return app.prefork(app.config.Network, addr, config)
|
||||
}
|
||||
|
||||
// Setup listener
|
||||
ln, err := net.Listen(app.config.Network, addr)
|
||||
ln = tls.NewListener(ln, config)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to listen: %w", err)
|
||||
}
|
||||
|
||||
// prepare the server for the start
|
||||
app.startupProcess()
|
||||
|
||||
// run hooks
|
||||
app.runOnListenHooks(app.prepareListenData(ln.Addr().String(), getTLSConfig(ln) != nil))
|
||||
|
||||
// Print startup message
|
||||
if !app.config.DisableStartupMessage {
|
||||
app.startupMessage(ln.Addr().String(), true, "")
|
||||
}
|
||||
|
||||
// Print routes
|
||||
if app.config.EnablePrintRoutes {
|
||||
app.printRoutesMessage()
|
||||
}
|
||||
|
||||
// Attach the tlsHandler to the config
|
||||
app.SetTLSHandler(tlsHandler)
|
||||
|
||||
// Start listening
|
||||
return app.server.Serve(ln)
|
||||
}
|
||||
|
||||
// ListenMutualTLS serves HTTPS requests from the given addr.
|
||||
// certFile, keyFile and clientCertFile are the paths to TLS certificate and key file:
|
||||
//
|
||||
// app.ListenMutualTLS(":8080", "./cert.pem", "./cert.key", "./client.pem")
|
||||
func (app *App) ListenMutualTLS(addr, certFile, keyFile, clientCertFile string) error {
|
||||
// Check for valid cert/key path
|
||||
if len(certFile) == 0 || len(keyFile) == 0 {
|
||||
return errors.New("tls: provide a valid cert or key path")
|
||||
}
|
||||
|
||||
cert, err := tls.LoadX509KeyPair(certFile, keyFile)
|
||||
if err != nil {
|
||||
return fmt.Errorf("tls: cannot load TLS key pair from certFile=%q and keyFile=%q: %w", certFile, keyFile, err)
|
||||
}
|
||||
|
||||
clientCACert, err := os.ReadFile(filepath.Clean(clientCertFile))
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to read file: %w", err)
|
||||
}
|
||||
clientCertPool := x509.NewCertPool()
|
||||
clientCertPool.AppendCertsFromPEM(clientCACert)
|
||||
|
||||
return app.ListenMutualTLSWithCertificate(addr, cert, clientCertPool)
|
||||
}
|
||||
|
||||
// ListenMutualTLSWithCertificate serves HTTPS requests from the given addr.
|
||||
// cert is a tls.Certificate and clientCertPool is a *x509.CertPool:
|
||||
//
|
||||
// app.ListenMutualTLS(":8080", cert, clientCertPool)
|
||||
func (app *App) ListenMutualTLSWithCertificate(addr string, cert tls.Certificate, clientCertPool *x509.CertPool) error {
|
||||
tlsHandler := &TLSHandler{}
|
||||
config := &tls.Config{
|
||||
MinVersion: tls.VersionTLS12,
|
||||
ClientAuth: tls.RequireAndVerifyClientCert,
|
||||
ClientCAs: clientCertPool,
|
||||
Certificates: []tls.Certificate{
|
||||
cert,
|
||||
},
|
||||
GetCertificate: tlsHandler.GetClientInfo,
|
||||
}
|
||||
|
||||
// Prefork is supported
|
||||
if app.config.Prefork {
|
||||
return app.prefork(app.config.Network, addr, config)
|
||||
}
|
||||
|
||||
// Setup listener
|
||||
ln, err := tls.Listen(app.config.Network, addr, config)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to listen: %w", err)
|
||||
}
|
||||
|
||||
// prepare the server for the start
|
||||
app.startupProcess()
|
||||
|
||||
// run hooks
|
||||
app.runOnListenHooks(app.prepareListenData(ln.Addr().String(), getTLSConfig(ln) != nil))
|
||||
|
||||
// Print startup message
|
||||
if !app.config.DisableStartupMessage {
|
||||
app.startupMessage(ln.Addr().String(), true, "")
|
||||
}
|
||||
|
||||
// Print routes
|
||||
if app.config.EnablePrintRoutes {
|
||||
app.printRoutesMessage()
|
||||
}
|
||||
|
||||
// Attach the tlsHandler to the config
|
||||
app.SetTLSHandler(tlsHandler)
|
||||
|
||||
// Start listening
|
||||
return app.server.Serve(ln)
|
||||
}
|
||||
|
||||
// prepareListenData create an slice of ListenData
|
||||
func (app *App) prepareListenData(addr string, isTLS bool) ListenData { //revive:disable-line:flag-parameter // Accepting a bool param named isTLS if fine here
|
||||
host, port := parseAddr(addr)
|
||||
if host == "" {
|
||||
if app.config.Network == NetworkTCP6 {
|
||||
host = "[::1]"
|
||||
} else {
|
||||
host = globalIpv4Addr
|
||||
}
|
||||
}
|
||||
|
||||
return ListenData{
|
||||
Host: host,
|
||||
Port: port,
|
||||
TLS: isTLS,
|
||||
}
|
||||
}
|
||||
|
||||
// startupMessage prepares the startup message with the handler number, port, address and other information
|
||||
func (app *App) startupMessage(addr string, isTLS bool, pids string) { //nolint: revive // Accepting a bool param named isTLS if fine here
|
||||
// ignore child processes
|
||||
if IsChild() {
|
||||
return
|
||||
}
|
||||
|
||||
// Alias colors
|
||||
colors := app.config.ColorScheme
|
||||
|
||||
value := func(s string, width int) string {
|
||||
pad := width - len(s)
|
||||
str := ""
|
||||
for i := 0; i < pad; i++ {
|
||||
str += "."
|
||||
}
|
||||
if s == "Disabled" {
|
||||
str += " " + s
|
||||
} else {
|
||||
str += fmt.Sprintf(" %s%s%s", colors.Cyan, s, colors.Black)
|
||||
}
|
||||
return str
|
||||
}
|
||||
|
||||
center := func(s string, width int) string {
|
||||
const padDiv = 2
|
||||
pad := strconv.Itoa((width - len(s)) / padDiv)
|
||||
str := fmt.Sprintf("%"+pad+"s", " ")
|
||||
str += s
|
||||
str += fmt.Sprintf("%"+pad+"s", " ")
|
||||
if len(str) < width {
|
||||
str += " "
|
||||
}
|
||||
return str
|
||||
}
|
||||
|
||||
centerValue := func(s string, width int) string {
|
||||
const padDiv = 2
|
||||
pad := strconv.Itoa((width - runewidth.StringWidth(s)) / padDiv)
|
||||
str := fmt.Sprintf("%"+pad+"s", " ")
|
||||
str += fmt.Sprintf("%s%s%s", colors.Cyan, s, colors.Black)
|
||||
str += fmt.Sprintf("%"+pad+"s", " ")
|
||||
if runewidth.StringWidth(s)-10 < width && runewidth.StringWidth(s)%2 == 0 {
|
||||
// add an ending space if the length of str is even and str is not too long
|
||||
str += " "
|
||||
}
|
||||
return str
|
||||
}
|
||||
|
||||
pad := func(s string, width int) string {
|
||||
toAdd := width - len(s)
|
||||
str := s
|
||||
for i := 0; i < toAdd; i++ {
|
||||
str += " "
|
||||
}
|
||||
return str
|
||||
}
|
||||
|
||||
host, port := parseAddr(addr)
|
||||
if host == "" {
|
||||
if app.config.Network == NetworkTCP6 {
|
||||
host = "[::1]"
|
||||
} else {
|
||||
host = globalIpv4Addr
|
||||
}
|
||||
}
|
||||
|
||||
scheme := schemeHTTP
|
||||
if isTLS {
|
||||
scheme = schemeHTTPS
|
||||
}
|
||||
|
||||
isPrefork := "Disabled"
|
||||
if app.config.Prefork {
|
||||
isPrefork = "Enabled"
|
||||
}
|
||||
|
||||
procs := strconv.Itoa(runtime.GOMAXPROCS(0))
|
||||
if !app.config.Prefork {
|
||||
procs = "1"
|
||||
}
|
||||
|
||||
const lineLen = 49
|
||||
mainLogo := colors.Black + " ┌───────────────────────────────────────────────────┐\n"
|
||||
if app.config.AppName != "" {
|
||||
mainLogo += " │ " + centerValue(app.config.AppName, lineLen) + " │\n"
|
||||
}
|
||||
mainLogo += " │ " + centerValue("Fiber v"+Version, lineLen) + " │\n"
|
||||
|
||||
if host == globalIpv4Addr {
|
||||
mainLogo += " │ " + center(fmt.Sprintf("%s://127.0.0.1:%s", scheme, port), lineLen) + " │\n" +
|
||||
" │ " + center(fmt.Sprintf("(bound on host 0.0.0.0 and port %s)", port), lineLen) + " │\n"
|
||||
} else {
|
||||
mainLogo += " │ " + center(fmt.Sprintf("%s://%s:%s", scheme, host, port), lineLen) + " │\n"
|
||||
}
|
||||
|
||||
mainLogo += fmt.Sprintf(
|
||||
" │ │\n"+
|
||||
" │ Handlers %s Processes %s │\n"+
|
||||
" │ Prefork .%s PID ....%s │\n"+
|
||||
" └───────────────────────────────────────────────────┘"+
|
||||
colors.Reset,
|
||||
value(strconv.Itoa(int(app.handlersCount)), 14), value(procs, 12),
|
||||
value(isPrefork, 14), value(strconv.Itoa(os.Getpid()), 14),
|
||||
)
|
||||
|
||||
var childPidsLogo string
|
||||
if app.config.Prefork {
|
||||
var childPidsTemplate string
|
||||
childPidsTemplate += "%s"
|
||||
childPidsTemplate += " ┌───────────────────────────────────────────────────┐\n%s"
|
||||
childPidsTemplate += " └───────────────────────────────────────────────────┘"
|
||||
childPidsTemplate += "%s"
|
||||
|
||||
newLine := " │ %s%s%s │"
|
||||
|
||||
// Turn the `pids` variable (in the form ",a,b,c,d,e,f,etc") into a slice of PIDs
|
||||
var pidSlice []string
|
||||
for _, v := range strings.Split(pids, ",") {
|
||||
if v != "" {
|
||||
pidSlice = append(pidSlice, v)
|
||||
}
|
||||
}
|
||||
|
||||
var lines []string
|
||||
thisLine := "Child PIDs ... "
|
||||
var itemsOnThisLine []string
|
||||
|
||||
const maxLineLen = 49
|
||||
|
||||
addLine := func() {
|
||||
lines = append(lines,
|
||||
fmt.Sprintf(
|
||||
newLine,
|
||||
colors.Black,
|
||||
thisLine+colors.Cyan+pad(strings.Join(itemsOnThisLine, ", "), maxLineLen-len(thisLine)),
|
||||
colors.Black,
|
||||
),
|
||||
)
|
||||
}
|
||||
|
||||
for _, pid := range pidSlice {
|
||||
if len(thisLine+strings.Join(append(itemsOnThisLine, pid), ", ")) > maxLineLen {
|
||||
addLine()
|
||||
thisLine = ""
|
||||
itemsOnThisLine = []string{pid}
|
||||
} else {
|
||||
itemsOnThisLine = append(itemsOnThisLine, pid)
|
||||
}
|
||||
}
|
||||
|
||||
// Add left over items to their own line
|
||||
if len(itemsOnThisLine) != 0 {
|
||||
addLine()
|
||||
}
|
||||
|
||||
// Form logo
|
||||
childPidsLogo = fmt.Sprintf(childPidsTemplate,
|
||||
colors.Black,
|
||||
strings.Join(lines, "\n")+"\n",
|
||||
colors.Reset,
|
||||
)
|
||||
}
|
||||
|
||||
// Combine both the child PID logo and the main Fiber logo
|
||||
|
||||
// Pad the shorter logo to the length of the longer one
|
||||
splitMainLogo := strings.Split(mainLogo, "\n")
|
||||
splitChildPidsLogo := strings.Split(childPidsLogo, "\n")
|
||||
|
||||
mainLen := len(splitMainLogo)
|
||||
childLen := len(splitChildPidsLogo)
|
||||
|
||||
if mainLen > childLen {
|
||||
diff := mainLen - childLen
|
||||
for i := 0; i < diff; i++ {
|
||||
splitChildPidsLogo = append(splitChildPidsLogo, "")
|
||||
}
|
||||
} else {
|
||||
diff := childLen - mainLen
|
||||
for i := 0; i < diff; i++ {
|
||||
splitMainLogo = append(splitMainLogo, "")
|
||||
}
|
||||
}
|
||||
|
||||
// Combine the two logos, line by line
|
||||
output := "\n"
|
||||
for i := range splitMainLogo {
|
||||
output += colors.Black + splitMainLogo[i] + " " + splitChildPidsLogo[i] + "\n"
|
||||
}
|
||||
|
||||
out := colorable.NewColorableStdout()
|
||||
if os.Getenv("TERM") == "dumb" || os.Getenv("NO_COLOR") == "1" || (!isatty.IsTerminal(os.Stdout.Fd()) && !isatty.IsCygwinTerminal(os.Stdout.Fd())) {
|
||||
out = colorable.NewNonColorable(os.Stdout)
|
||||
}
|
||||
|
||||
_, _ = fmt.Fprintln(out, output)
|
||||
}
|
||||
|
||||
// printRoutesMessage print all routes with method, path, name and handlers
|
||||
// in a format of table, like this:
|
||||
// method | path | name | handlers
|
||||
// GET | / | routeName | github.com/gofiber/fiber/v2.emptyHandler
|
||||
// HEAD | / | | github.com/gofiber/fiber/v2.emptyHandler
|
||||
func (app *App) printRoutesMessage() {
|
||||
// ignore child processes
|
||||
if IsChild() {
|
||||
return
|
||||
}
|
||||
|
||||
// Alias colors
|
||||
colors := app.config.ColorScheme
|
||||
|
||||
var routes []RouteMessage
|
||||
for _, routeStack := range app.stack {
|
||||
for _, route := range routeStack {
|
||||
var newRoute RouteMessage
|
||||
newRoute.name = route.Name
|
||||
newRoute.method = route.Method
|
||||
newRoute.path = route.Path
|
||||
for _, handler := range route.Handlers {
|
||||
newRoute.handlers += runtime.FuncForPC(reflect.ValueOf(handler).Pointer()).Name() + " "
|
||||
}
|
||||
routes = append(routes, newRoute)
|
||||
}
|
||||
}
|
||||
|
||||
out := colorable.NewColorableStdout()
|
||||
if os.Getenv("TERM") == "dumb" || os.Getenv("NO_COLOR") == "1" || (!isatty.IsTerminal(os.Stdout.Fd()) && !isatty.IsCygwinTerminal(os.Stdout.Fd())) {
|
||||
out = colorable.NewNonColorable(os.Stdout)
|
||||
}
|
||||
|
||||
w := tabwriter.NewWriter(out, 1, 1, 1, ' ', 0)
|
||||
// Sort routes by path
|
||||
sort.Slice(routes, func(i, j int) bool {
|
||||
return routes[i].path < routes[j].path
|
||||
})
|
||||
|
||||
_, _ = fmt.Fprintf(w, "%smethod\t%s| %spath\t%s| %sname\t%s| %shandlers\t%s\n", colors.Blue, colors.White, colors.Green, colors.White, colors.Cyan, colors.White, colors.Yellow, colors.Reset)
|
||||
_, _ = fmt.Fprintf(w, "%s------\t%s| %s----\t%s| %s----\t%s| %s--------\t%s\n", colors.Blue, colors.White, colors.Green, colors.White, colors.Cyan, colors.White, colors.Yellow, colors.Reset)
|
||||
for _, route := range routes {
|
||||
_, _ = fmt.Fprintf(w, "%s%s\t%s| %s%s\t%s| %s%s\t%s| %s%s%s\n", colors.Blue, route.method, colors.White, colors.Green, route.path, colors.White, colors.Cyan, route.name, colors.White, colors.Yellow, route.handlers, colors.Reset)
|
||||
}
|
||||
|
||||
_ = w.Flush() //nolint:errcheck // It is fine to ignore the error here
|
||||
}
|
||||
+209
@@ -0,0 +1,209 @@
|
||||
package log
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"os"
|
||||
"sync"
|
||||
|
||||
"github.com/valyala/bytebufferpool"
|
||||
)
|
||||
|
||||
var _ AllLogger = (*defaultLogger)(nil)
|
||||
|
||||
type defaultLogger struct {
|
||||
stdlog *log.Logger
|
||||
level Level
|
||||
depth int
|
||||
}
|
||||
|
||||
// privateLog logs a message at a given level log the default logger.
|
||||
// when the level is fatal, it will exit the program.
|
||||
func (l *defaultLogger) privateLog(lv Level, fmtArgs []interface{}) {
|
||||
if l.level > lv {
|
||||
return
|
||||
}
|
||||
level := lv.toString()
|
||||
buf := bytebufferpool.Get()
|
||||
_, _ = buf.WriteString(level) //nolint:errcheck // It is fine to ignore the error
|
||||
_, _ = buf.WriteString(fmt.Sprint(fmtArgs...)) //nolint:errcheck // It is fine to ignore the error
|
||||
|
||||
_ = l.stdlog.Output(l.depth, buf.String()) //nolint:errcheck // It is fine to ignore the error
|
||||
buf.Reset()
|
||||
bytebufferpool.Put(buf)
|
||||
if lv == LevelFatal {
|
||||
os.Exit(1) //nolint:revive // we want to exit the program when Fatal is called
|
||||
}
|
||||
}
|
||||
|
||||
// privateLog logs a message at a given level log the default logger.
|
||||
// when the level is fatal, it will exit the program.
|
||||
func (l *defaultLogger) privateLogf(lv Level, format string, fmtArgs []interface{}) {
|
||||
if l.level > lv {
|
||||
return
|
||||
}
|
||||
level := lv.toString()
|
||||
buf := bytebufferpool.Get()
|
||||
_, _ = buf.WriteString(level) //nolint:errcheck // It is fine to ignore the error
|
||||
|
||||
if len(fmtArgs) > 0 {
|
||||
_, _ = fmt.Fprintf(buf, format, fmtArgs...)
|
||||
} else {
|
||||
_, _ = fmt.Fprint(buf, fmtArgs...)
|
||||
}
|
||||
_ = l.stdlog.Output(l.depth, buf.String()) //nolint:errcheck // It is fine to ignore the error
|
||||
buf.Reset()
|
||||
bytebufferpool.Put(buf)
|
||||
if lv == LevelFatal {
|
||||
os.Exit(1) //nolint:revive // we want to exit the program when Fatal is called
|
||||
}
|
||||
}
|
||||
|
||||
// privateLogw logs a message at a given level log the default logger.
|
||||
// when the level is fatal, it will exit the program.
|
||||
func (l *defaultLogger) privateLogw(lv Level, format string, keysAndValues []interface{}) {
|
||||
if l.level > lv {
|
||||
return
|
||||
}
|
||||
level := lv.toString()
|
||||
buf := bytebufferpool.Get()
|
||||
_, _ = buf.WriteString(level) //nolint:errcheck // It is fine to ignore the error
|
||||
|
||||
// Write format privateLog buffer
|
||||
if format != "" {
|
||||
_, _ = buf.WriteString(format) //nolint:errcheck // It is fine to ignore the error
|
||||
}
|
||||
var once sync.Once
|
||||
isFirst := true
|
||||
// Write keys and values privateLog buffer
|
||||
if len(keysAndValues) > 0 {
|
||||
if (len(keysAndValues) & 1) == 1 {
|
||||
keysAndValues = append(keysAndValues, "KEYVALS UNPAIRED")
|
||||
}
|
||||
|
||||
for i := 0; i < len(keysAndValues); i += 2 {
|
||||
if format == "" && isFirst {
|
||||
once.Do(func() {
|
||||
_, _ = fmt.Fprintf(buf, "%s=%v", keysAndValues[i], keysAndValues[i+1])
|
||||
isFirst = false
|
||||
})
|
||||
continue
|
||||
}
|
||||
_, _ = fmt.Fprintf(buf, " %s=%v", keysAndValues[i], keysAndValues[i+1])
|
||||
}
|
||||
}
|
||||
|
||||
_ = l.stdlog.Output(l.depth, buf.String()) //nolint:errcheck // It is fine to ignore the error
|
||||
buf.Reset()
|
||||
bytebufferpool.Put(buf)
|
||||
if lv == LevelFatal {
|
||||
os.Exit(1) //nolint:revive // we want to exit the program when Fatal is called
|
||||
}
|
||||
}
|
||||
|
||||
func (l *defaultLogger) Trace(v ...interface{}) {
|
||||
l.privateLog(LevelTrace, v)
|
||||
}
|
||||
|
||||
func (l *defaultLogger) Debug(v ...interface{}) {
|
||||
l.privateLog(LevelDebug, v)
|
||||
}
|
||||
|
||||
func (l *defaultLogger) Info(v ...interface{}) {
|
||||
l.privateLog(LevelInfo, v)
|
||||
}
|
||||
|
||||
func (l *defaultLogger) Warn(v ...interface{}) {
|
||||
l.privateLog(LevelWarn, v)
|
||||
}
|
||||
|
||||
func (l *defaultLogger) Error(v ...interface{}) {
|
||||
l.privateLog(LevelError, v)
|
||||
}
|
||||
|
||||
func (l *defaultLogger) Fatal(v ...interface{}) {
|
||||
l.privateLog(LevelFatal, v)
|
||||
}
|
||||
|
||||
func (l *defaultLogger) Panic(v ...interface{}) {
|
||||
l.privateLog(LevelPanic, v)
|
||||
}
|
||||
|
||||
func (l *defaultLogger) Tracef(format string, v ...interface{}) {
|
||||
l.privateLogf(LevelTrace, format, v)
|
||||
}
|
||||
|
||||
func (l *defaultLogger) Debugf(format string, v ...interface{}) {
|
||||
l.privateLogf(LevelDebug, format, v)
|
||||
}
|
||||
|
||||
func (l *defaultLogger) Infof(format string, v ...interface{}) {
|
||||
l.privateLogf(LevelInfo, format, v)
|
||||
}
|
||||
|
||||
func (l *defaultLogger) Warnf(format string, v ...interface{}) {
|
||||
l.privateLogf(LevelWarn, format, v)
|
||||
}
|
||||
|
||||
func (l *defaultLogger) Errorf(format string, v ...interface{}) {
|
||||
l.privateLogf(LevelError, format, v)
|
||||
}
|
||||
|
||||
func (l *defaultLogger) Fatalf(format string, v ...interface{}) {
|
||||
l.privateLogf(LevelFatal, format, v)
|
||||
}
|
||||
|
||||
func (l *defaultLogger) Panicf(format string, v ...interface{}) {
|
||||
l.privateLogf(LevelPanic, format, v)
|
||||
}
|
||||
|
||||
func (l *defaultLogger) Tracew(msg string, keysAndValues ...interface{}) {
|
||||
l.privateLogw(LevelTrace, msg, keysAndValues)
|
||||
}
|
||||
|
||||
func (l *defaultLogger) Debugw(msg string, keysAndValues ...interface{}) {
|
||||
l.privateLogw(LevelDebug, msg, keysAndValues)
|
||||
}
|
||||
|
||||
func (l *defaultLogger) Infow(msg string, keysAndValues ...interface{}) {
|
||||
l.privateLogw(LevelInfo, msg, keysAndValues)
|
||||
}
|
||||
|
||||
func (l *defaultLogger) Warnw(msg string, keysAndValues ...interface{}) {
|
||||
l.privateLogw(LevelWarn, msg, keysAndValues)
|
||||
}
|
||||
|
||||
func (l *defaultLogger) Errorw(msg string, keysAndValues ...interface{}) {
|
||||
l.privateLogw(LevelError, msg, keysAndValues)
|
||||
}
|
||||
|
||||
func (l *defaultLogger) Fatalw(msg string, keysAndValues ...interface{}) {
|
||||
l.privateLogw(LevelFatal, msg, keysAndValues)
|
||||
}
|
||||
|
||||
func (l *defaultLogger) Panicw(msg string, keysAndValues ...interface{}) {
|
||||
l.privateLogw(LevelPanic, msg, keysAndValues)
|
||||
}
|
||||
|
||||
func (l *defaultLogger) WithContext(_ context.Context) CommonLogger {
|
||||
return &defaultLogger{
|
||||
stdlog: l.stdlog,
|
||||
level: l.level,
|
||||
depth: l.depth - 1,
|
||||
}
|
||||
}
|
||||
|
||||
func (l *defaultLogger) SetLevel(level Level) {
|
||||
l.level = level
|
||||
}
|
||||
|
||||
func (l *defaultLogger) SetOutput(writer io.Writer) {
|
||||
l.stdlog.SetOutput(writer)
|
||||
}
|
||||
|
||||
// DefaultLogger returns the default logger.
|
||||
func DefaultLogger() AllLogger {
|
||||
return logger
|
||||
}
|
||||
+141
@@ -0,0 +1,141 @@
|
||||
package log
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
)
|
||||
|
||||
// Fatal calls the default logger's Fatal method and then os.Exit(1).
|
||||
func Fatal(v ...interface{}) {
|
||||
logger.Fatal(v...)
|
||||
}
|
||||
|
||||
// Error calls the default logger's Error method.
|
||||
func Error(v ...interface{}) {
|
||||
logger.Error(v...)
|
||||
}
|
||||
|
||||
// Warn calls the default logger's Warn method.
|
||||
func Warn(v ...interface{}) {
|
||||
logger.Warn(v...)
|
||||
}
|
||||
|
||||
// Info calls the default logger's Info method.
|
||||
func Info(v ...interface{}) {
|
||||
logger.Info(v...)
|
||||
}
|
||||
|
||||
// Debug calls the default logger's Debug method.
|
||||
func Debug(v ...interface{}) {
|
||||
logger.Debug(v...)
|
||||
}
|
||||
|
||||
// Trace calls the default logger's Trace method.
|
||||
func Trace(v ...interface{}) {
|
||||
logger.Trace(v...)
|
||||
}
|
||||
|
||||
// Panic calls the default logger's Panic method.
|
||||
func Panic(v ...interface{}) {
|
||||
logger.Panic(v...)
|
||||
}
|
||||
|
||||
// Fatalf calls the default logger's Fatalf method and then os.Exit(1).
|
||||
func Fatalf(format string, v ...interface{}) {
|
||||
logger.Fatalf(format, v...)
|
||||
}
|
||||
|
||||
// Errorf calls the default logger's Errorf method.
|
||||
func Errorf(format string, v ...interface{}) {
|
||||
logger.Errorf(format, v...)
|
||||
}
|
||||
|
||||
// Warnf calls the default logger's Warnf method.
|
||||
func Warnf(format string, v ...interface{}) {
|
||||
logger.Warnf(format, v...)
|
||||
}
|
||||
|
||||
// Infof calls the default logger's Infof method.
|
||||
func Infof(format string, v ...interface{}) {
|
||||
logger.Infof(format, v...)
|
||||
}
|
||||
|
||||
// Debugf calls the default logger's Debugf method.
|
||||
func Debugf(format string, v ...interface{}) {
|
||||
logger.Debugf(format, v...)
|
||||
}
|
||||
|
||||
// Tracef calls the default logger's Tracef method.
|
||||
func Tracef(format string, v ...interface{}) {
|
||||
logger.Tracef(format, v...)
|
||||
}
|
||||
|
||||
// Panicf calls the default logger's Tracef method.
|
||||
func Panicf(format string, v ...interface{}) {
|
||||
logger.Panicf(format, v...)
|
||||
}
|
||||
|
||||
// Tracew logs a message with some additional context. The variadic key-value
|
||||
// pairs are treated as they are privateLog With.
|
||||
func Tracew(msg string, keysAndValues ...interface{}) {
|
||||
logger.Tracew(msg, keysAndValues...)
|
||||
}
|
||||
|
||||
// Debugw logs a message with some additional context. The variadic key-value
|
||||
// pairs are treated as they are privateLog With.
|
||||
func Debugw(msg string, keysAndValues ...interface{}) {
|
||||
logger.Debugw(msg, keysAndValues...)
|
||||
}
|
||||
|
||||
// Infow logs a message with some additional context. The variadic key-value
|
||||
// pairs are treated as they are privateLog With.
|
||||
func Infow(msg string, keysAndValues ...interface{}) {
|
||||
logger.Infow(msg, keysAndValues...)
|
||||
}
|
||||
|
||||
// Warnw logs a message with some additional context. The variadic key-value
|
||||
// pairs are treated as they are privateLog With.
|
||||
func Warnw(msg string, keysAndValues ...interface{}) {
|
||||
logger.Warnw(msg, keysAndValues...)
|
||||
}
|
||||
|
||||
// Errorw logs a message with some additional context. The variadic key-value
|
||||
// pairs are treated as they are privateLog With.
|
||||
func Errorw(msg string, keysAndValues ...interface{}) {
|
||||
logger.Errorw(msg, keysAndValues...)
|
||||
}
|
||||
|
||||
// Fatalw logs a message with some additional context. The variadic key-value
|
||||
// pairs are treated as they are privateLog With.
|
||||
func Fatalw(msg string, keysAndValues ...interface{}) {
|
||||
logger.Fatalw(msg, keysAndValues...)
|
||||
}
|
||||
|
||||
// Panicw logs a message with some additional context. The variadic key-value
|
||||
// pairs are treated as they are privateLog With.
|
||||
func Panicw(msg string, keysAndValues ...interface{}) {
|
||||
logger.Panicw(msg, keysAndValues...)
|
||||
}
|
||||
|
||||
func WithContext(ctx context.Context) CommonLogger {
|
||||
return logger.WithContext(ctx)
|
||||
}
|
||||
|
||||
// SetLogger sets the default logger and the system logger.
|
||||
// Note that this method is not concurrent-safe and must not be called
|
||||
// after the use of DefaultLogger and global functions privateLog this package.
|
||||
func SetLogger(v AllLogger) {
|
||||
logger = v
|
||||
}
|
||||
|
||||
// SetOutput sets the output of default logger and system logger. By default, it is stderr.
|
||||
func SetOutput(w io.Writer) {
|
||||
logger.SetOutput(w)
|
||||
}
|
||||
|
||||
// SetLevel sets the level of logs below which logs will not be output.
|
||||
// The default logger is LevelTrace.
|
||||
// Note that this method is not concurrent-safe.
|
||||
func SetLevel(lv Level) {
|
||||
logger.SetLevel(lv)
|
||||
}
|
||||
+100
@@ -0,0 +1,100 @@
|
||||
package log
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"os"
|
||||
)
|
||||
|
||||
var logger AllLogger = &defaultLogger{
|
||||
stdlog: log.New(os.Stderr, "", log.LstdFlags|log.Lshortfile|log.Lmicroseconds),
|
||||
depth: 4,
|
||||
}
|
||||
|
||||
// Logger is a logger interface that provides logging function with levels.
|
||||
type Logger interface {
|
||||
Trace(v ...interface{})
|
||||
Debug(v ...interface{})
|
||||
Info(v ...interface{})
|
||||
Warn(v ...interface{})
|
||||
Error(v ...interface{})
|
||||
Fatal(v ...interface{})
|
||||
Panic(v ...interface{})
|
||||
}
|
||||
|
||||
// FormatLogger is a logger interface that output logs with a format.
|
||||
type FormatLogger interface {
|
||||
Tracef(format string, v ...interface{})
|
||||
Debugf(format string, v ...interface{})
|
||||
Infof(format string, v ...interface{})
|
||||
Warnf(format string, v ...interface{})
|
||||
Errorf(format string, v ...interface{})
|
||||
Fatalf(format string, v ...interface{})
|
||||
Panicf(format string, v ...interface{})
|
||||
}
|
||||
|
||||
// WithLogger is a logger interface that output logs with a message and key-value pairs.
|
||||
type WithLogger interface {
|
||||
Tracew(msg string, keysAndValues ...interface{})
|
||||
Debugw(msg string, keysAndValues ...interface{})
|
||||
Infow(msg string, keysAndValues ...interface{})
|
||||
Warnw(msg string, keysAndValues ...interface{})
|
||||
Errorw(msg string, keysAndValues ...interface{})
|
||||
Fatalw(msg string, keysAndValues ...interface{})
|
||||
Panicw(msg string, keysAndValues ...interface{})
|
||||
}
|
||||
|
||||
type CommonLogger interface {
|
||||
Logger
|
||||
FormatLogger
|
||||
WithLogger
|
||||
}
|
||||
|
||||
// ControlLogger provides methods to config a logger.
|
||||
type ControlLogger interface {
|
||||
SetLevel(Level)
|
||||
SetOutput(io.Writer)
|
||||
}
|
||||
|
||||
// AllLogger is the combination of Logger, FormatLogger, CtxLogger and ControlLogger.
|
||||
// Custom extensions can be made through AllLogger
|
||||
type AllLogger interface {
|
||||
CommonLogger
|
||||
ControlLogger
|
||||
WithContext(ctx context.Context) CommonLogger
|
||||
}
|
||||
|
||||
// Level defines the priority of a log message.
|
||||
// When a logger is configured with a level, any log message with a lower
|
||||
// log level (smaller by integer comparison) will not be output.
|
||||
type Level int
|
||||
|
||||
// The levels of logs.
|
||||
const (
|
||||
LevelTrace Level = iota
|
||||
LevelDebug
|
||||
LevelInfo
|
||||
LevelWarn
|
||||
LevelError
|
||||
LevelFatal
|
||||
LevelPanic
|
||||
)
|
||||
|
||||
var strs = []string{
|
||||
"[Trace] ",
|
||||
"[Debug] ",
|
||||
"[Info] ",
|
||||
"[Warn] ",
|
||||
"[Error] ",
|
||||
"[Fatal] ",
|
||||
"[Panic] ",
|
||||
}
|
||||
|
||||
func (lv Level) toString() string {
|
||||
if lv >= LevelTrace && lv <= LevelPanic {
|
||||
return strs[lv]
|
||||
}
|
||||
return fmt.Sprintf("[?%d] ", lv)
|
||||
}
|
||||
+65
@@ -0,0 +1,65 @@
|
||||
package compress
|
||||
|
||||
import (
|
||||
"github.com/gofiber/fiber/v2"
|
||||
|
||||
"github.com/valyala/fasthttp"
|
||||
)
|
||||
|
||||
// New creates a new middleware handler
|
||||
func New(config ...Config) fiber.Handler {
|
||||
// Set default config
|
||||
cfg := configDefault(config...)
|
||||
|
||||
// Setup request handlers
|
||||
var (
|
||||
fctx = func(c *fasthttp.RequestCtx) {}
|
||||
compressor fasthttp.RequestHandler
|
||||
)
|
||||
|
||||
// Setup compression algorithm
|
||||
switch cfg.Level {
|
||||
case LevelDefault:
|
||||
// LevelDefault
|
||||
compressor = fasthttp.CompressHandlerBrotliLevel(fctx,
|
||||
fasthttp.CompressBrotliDefaultCompression,
|
||||
fasthttp.CompressDefaultCompression,
|
||||
)
|
||||
case LevelBestSpeed:
|
||||
// LevelBestSpeed
|
||||
compressor = fasthttp.CompressHandlerBrotliLevel(fctx,
|
||||
fasthttp.CompressBrotliBestSpeed,
|
||||
fasthttp.CompressBestSpeed,
|
||||
)
|
||||
case LevelBestCompression:
|
||||
// LevelBestCompression
|
||||
compressor = fasthttp.CompressHandlerBrotliLevel(fctx,
|
||||
fasthttp.CompressBrotliBestCompression,
|
||||
fasthttp.CompressBestCompression,
|
||||
)
|
||||
default:
|
||||
// LevelDisabled
|
||||
return func(c *fiber.Ctx) error {
|
||||
return c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
// Return new handler
|
||||
return func(c *fiber.Ctx) error {
|
||||
// Don't execute middleware if Next returns true
|
||||
if cfg.Next != nil && cfg.Next(c) {
|
||||
return c.Next()
|
||||
}
|
||||
|
||||
// Continue stack
|
||||
if err := c.Next(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Compress response
|
||||
compressor(c.Context())
|
||||
|
||||
// Return from handler
|
||||
return nil
|
||||
}
|
||||
}
|
||||
+56
@@ -0,0 +1,56 @@
|
||||
package compress
|
||||
|
||||
import (
|
||||
"github.com/gofiber/fiber/v2"
|
||||
)
|
||||
|
||||
// Config defines the config for middleware.
|
||||
type Config struct {
|
||||
// Next defines a function to skip this middleware when returned true.
|
||||
//
|
||||
// Optional. Default: nil
|
||||
Next func(c *fiber.Ctx) bool
|
||||
|
||||
// Level determines the compression algorithm
|
||||
//
|
||||
// Optional. Default: LevelDefault
|
||||
// LevelDisabled: -1
|
||||
// LevelDefault: 0
|
||||
// LevelBestSpeed: 1
|
||||
// LevelBestCompression: 2
|
||||
Level Level
|
||||
}
|
||||
|
||||
// Level is numeric representation of compression level
|
||||
type Level int
|
||||
|
||||
// Represents compression level that will be used in the middleware
|
||||
const (
|
||||
LevelDisabled Level = -1
|
||||
LevelDefault Level = 0
|
||||
LevelBestSpeed Level = 1
|
||||
LevelBestCompression Level = 2
|
||||
)
|
||||
|
||||
// ConfigDefault is the default config
|
||||
var ConfigDefault = Config{
|
||||
Next: nil,
|
||||
Level: LevelDefault,
|
||||
}
|
||||
|
||||
// Helper function to set default values
|
||||
func configDefault(config ...Config) Config {
|
||||
// Return default config if nothing provided
|
||||
if len(config) < 1 {
|
||||
return ConfigDefault
|
||||
}
|
||||
|
||||
// Override default config
|
||||
cfg := config[0]
|
||||
|
||||
// Set default values
|
||||
if cfg.Level < LevelDisabled || cfg.Level > LevelBestCompression {
|
||||
cfg.Level = ConfigDefault.Level
|
||||
}
|
||||
return cfg
|
||||
}
|
||||
+289
@@ -0,0 +1,289 @@
|
||||
package cors
|
||||
|
||||
import (
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/gofiber/fiber/v2/log"
|
||||
)
|
||||
|
||||
// Config defines the config for middleware.
|
||||
type Config struct {
|
||||
// Next defines a function to skip this middleware when returned true.
|
||||
//
|
||||
// Optional. Default: nil
|
||||
Next func(c *fiber.Ctx) bool
|
||||
|
||||
// AllowOriginsFunc defines a function that will set the 'Access-Control-Allow-Origin'
|
||||
// response header to the 'origin' request header when returned true. This allows for
|
||||
// dynamic evaluation of allowed origins. Note if AllowCredentials is true, wildcard origins
|
||||
// will be not have the 'Access-Control-Allow-Credentials' header set to 'true'.
|
||||
//
|
||||
// Optional. Default: nil
|
||||
AllowOriginsFunc func(origin string) bool
|
||||
|
||||
// AllowOrigin defines a comma separated list of origins that may access the resource.
|
||||
//
|
||||
// Optional. Default value "*"
|
||||
AllowOrigins string
|
||||
|
||||
// AllowMethods defines a list methods allowed when accessing the resource.
|
||||
// This is used in response to a preflight request.
|
||||
//
|
||||
// Optional. Default value "GET,POST,HEAD,PUT,DELETE,PATCH"
|
||||
AllowMethods string
|
||||
|
||||
// AllowHeaders defines a list of request headers that can be used when
|
||||
// making the actual request. This is in response to a preflight request.
|
||||
//
|
||||
// Optional. Default value "".
|
||||
AllowHeaders string
|
||||
|
||||
// AllowCredentials indicates whether or not the response to the request
|
||||
// can be exposed when the credentials flag is true. When used as part of
|
||||
// a response to a preflight request, this indicates whether or not the
|
||||
// actual request can be made using credentials. Note: If true, AllowOrigins
|
||||
// cannot be set to a wildcard ("*") to prevent security vulnerabilities.
|
||||
//
|
||||
// Optional. Default value false.
|
||||
AllowCredentials bool
|
||||
|
||||
// ExposeHeaders defines a whitelist headers that clients are allowed to
|
||||
// access.
|
||||
//
|
||||
// Optional. Default value "".
|
||||
ExposeHeaders string
|
||||
|
||||
// MaxAge indicates how long (in seconds) the results of a preflight request
|
||||
// can be cached.
|
||||
// If you pass MaxAge 0, Access-Control-Max-Age header will not be added and
|
||||
// browser will use 5 seconds by default.
|
||||
// To disable caching completely, pass MaxAge value negative. It will set the Access-Control-Max-Age header 0.
|
||||
//
|
||||
// Optional. Default value 0.
|
||||
MaxAge int
|
||||
}
|
||||
|
||||
// ConfigDefault is the default config
|
||||
var ConfigDefault = Config{
|
||||
Next: nil,
|
||||
AllowOriginsFunc: nil,
|
||||
AllowOrigins: "*",
|
||||
AllowMethods: strings.Join([]string{
|
||||
fiber.MethodGet,
|
||||
fiber.MethodPost,
|
||||
fiber.MethodHead,
|
||||
fiber.MethodPut,
|
||||
fiber.MethodDelete,
|
||||
fiber.MethodPatch,
|
||||
}, ","),
|
||||
AllowHeaders: "",
|
||||
AllowCredentials: false,
|
||||
ExposeHeaders: "",
|
||||
MaxAge: 0,
|
||||
}
|
||||
|
||||
// New creates a new middleware handler
|
||||
func New(config ...Config) fiber.Handler {
|
||||
// Set default config
|
||||
cfg := ConfigDefault
|
||||
|
||||
// Override config if provided
|
||||
if len(config) > 0 {
|
||||
cfg = config[0]
|
||||
|
||||
// Set default values
|
||||
if cfg.AllowMethods == "" {
|
||||
cfg.AllowMethods = ConfigDefault.AllowMethods
|
||||
}
|
||||
// When none of the AllowOrigins or AllowOriginsFunc config was defined, set the default AllowOrigins value with "*"
|
||||
if cfg.AllowOrigins == "" && cfg.AllowOriginsFunc == nil {
|
||||
cfg.AllowOrigins = ConfigDefault.AllowOrigins
|
||||
}
|
||||
}
|
||||
|
||||
// Warning logs if both AllowOrigins and AllowOriginsFunc are set
|
||||
if cfg.AllowOrigins != "" && cfg.AllowOriginsFunc != nil {
|
||||
log.Warn("[CORS] Both 'AllowOrigins' and 'AllowOriginsFunc' have been defined.")
|
||||
}
|
||||
|
||||
// Validate CORS credentials configuration
|
||||
if cfg.AllowCredentials && cfg.AllowOrigins == "*" {
|
||||
panic("[CORS] Insecure setup, 'AllowCredentials' is set to true, and 'AllowOrigins' is set to a wildcard.")
|
||||
}
|
||||
|
||||
// allowOrigins is a slice of strings that contains the allowed origins
|
||||
// defined in the 'AllowOrigins' configuration.
|
||||
allowOrigins := []string{}
|
||||
allowSOrigins := []subdomain{}
|
||||
allowAllOrigins := false
|
||||
|
||||
// Validate and normalize static AllowOrigins
|
||||
if cfg.AllowOrigins != "" && cfg.AllowOrigins != "*" {
|
||||
origins := strings.Split(cfg.AllowOrigins, ",")
|
||||
for _, origin := range origins {
|
||||
if i := strings.Index(origin, "://*."); i != -1 {
|
||||
trimmedOrigin := strings.TrimSpace(origin[:i+3] + origin[i+4:])
|
||||
isValid, normalizedOrigin := normalizeOrigin(trimmedOrigin)
|
||||
if !isValid {
|
||||
panic("[CORS] Invalid origin format in configuration: " + trimmedOrigin)
|
||||
}
|
||||
sd := subdomain{prefix: normalizedOrigin[:i+3], suffix: normalizedOrigin[i+3:]}
|
||||
allowSOrigins = append(allowSOrigins, sd)
|
||||
} else {
|
||||
trimmedOrigin := strings.TrimSpace(origin)
|
||||
isValid, normalizedOrigin := normalizeOrigin(trimmedOrigin)
|
||||
if !isValid {
|
||||
panic("[CORS] Invalid origin format in configuration: " + trimmedOrigin)
|
||||
}
|
||||
allowOrigins = append(allowOrigins, normalizedOrigin)
|
||||
}
|
||||
}
|
||||
} else if cfg.AllowOrigins == "*" {
|
||||
allowAllOrigins = true
|
||||
}
|
||||
|
||||
// Strip white spaces
|
||||
allowMethods := strings.ReplaceAll(cfg.AllowMethods, " ", "")
|
||||
allowHeaders := strings.ReplaceAll(cfg.AllowHeaders, " ", "")
|
||||
exposeHeaders := strings.ReplaceAll(cfg.ExposeHeaders, " ", "")
|
||||
|
||||
// Convert int to string
|
||||
maxAge := strconv.Itoa(cfg.MaxAge)
|
||||
|
||||
// Return new handler
|
||||
return func(c *fiber.Ctx) error {
|
||||
// Don't execute middleware if Next returns true
|
||||
if cfg.Next != nil && cfg.Next(c) {
|
||||
return c.Next()
|
||||
}
|
||||
|
||||
// Get originHeader header
|
||||
originHeader := strings.ToLower(c.Get(fiber.HeaderOrigin))
|
||||
|
||||
// If the request does not have Origin header, the request is outside the scope of CORS
|
||||
if originHeader == "" {
|
||||
// See https://fetch.spec.whatwg.org/#cors-protocol-and-http-caches
|
||||
// Unless all origins are allowed, we include the Vary header to cache the response correctly
|
||||
if !allowAllOrigins {
|
||||
c.Vary(fiber.HeaderOrigin)
|
||||
}
|
||||
|
||||
return c.Next()
|
||||
}
|
||||
|
||||
// If it's a preflight request and doesn't have Access-Control-Request-Method header, it's outside the scope of CORS
|
||||
if c.Method() == fiber.MethodOptions && c.Get(fiber.HeaderAccessControlRequestMethod) == "" {
|
||||
// Response to OPTIONS request should not be cached but,
|
||||
// some caching can be configured to cache such responses.
|
||||
// To Avoid poisoning the cache, we include the Vary header
|
||||
// for non-CORS OPTIONS requests:
|
||||
c.Vary(fiber.HeaderOrigin)
|
||||
return c.Next()
|
||||
}
|
||||
|
||||
// Set default allowOrigin to empty string
|
||||
allowOrigin := ""
|
||||
|
||||
// Check allowed origins
|
||||
if allowAllOrigins {
|
||||
allowOrigin = "*"
|
||||
} else {
|
||||
// Check if the origin is in the list of allowed origins
|
||||
for _, origin := range allowOrigins {
|
||||
if origin == originHeader {
|
||||
allowOrigin = originHeader
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
// Check if the origin is in the list of allowed subdomains
|
||||
if allowOrigin == "" {
|
||||
for _, sOrigin := range allowSOrigins {
|
||||
if sOrigin.match(originHeader) {
|
||||
allowOrigin = originHeader
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Run AllowOriginsFunc if the logic for
|
||||
// handling the value in 'AllowOrigins' does
|
||||
// not result in allowOrigin being set.
|
||||
if allowOrigin == "" && cfg.AllowOriginsFunc != nil && cfg.AllowOriginsFunc(originHeader) {
|
||||
allowOrigin = originHeader
|
||||
}
|
||||
|
||||
// Simple request
|
||||
// Ommit allowMethods and allowHeaders, only used for pre-flight requests
|
||||
if c.Method() != fiber.MethodOptions {
|
||||
if !allowAllOrigins {
|
||||
// See https://fetch.spec.whatwg.org/#cors-protocol-and-http-caches
|
||||
c.Vary(fiber.HeaderOrigin)
|
||||
}
|
||||
setCORSHeaders(c, allowOrigin, "", "", exposeHeaders, maxAge, cfg)
|
||||
return c.Next()
|
||||
}
|
||||
|
||||
// Pre-flight request
|
||||
|
||||
// Response to OPTIONS request should not be cached but,
|
||||
// some caching can be configured to cache such responses.
|
||||
// To Avoid poisoning the cache, we include the Vary header
|
||||
// of preflight responses:
|
||||
c.Vary(fiber.HeaderAccessControlRequestMethod)
|
||||
c.Vary(fiber.HeaderAccessControlRequestHeaders)
|
||||
c.Vary(fiber.HeaderOrigin)
|
||||
|
||||
setCORSHeaders(c, allowOrigin, allowMethods, allowHeaders, exposeHeaders, maxAge, cfg)
|
||||
|
||||
// Send 204 No Content
|
||||
return c.SendStatus(fiber.StatusNoContent)
|
||||
}
|
||||
}
|
||||
|
||||
// Function to set CORS headers
|
||||
func setCORSHeaders(c *fiber.Ctx, allowOrigin, allowMethods, allowHeaders, exposeHeaders, maxAge string, cfg Config) {
|
||||
if cfg.AllowCredentials {
|
||||
// When AllowCredentials is true, set the Access-Control-Allow-Origin to the specific origin instead of '*'
|
||||
if allowOrigin == "*" {
|
||||
c.Set(fiber.HeaderAccessControlAllowOrigin, allowOrigin)
|
||||
log.Warn("[CORS] 'AllowCredentials' is true, but 'AllowOrigins' cannot be set to '*'.")
|
||||
} else if allowOrigin != "" {
|
||||
c.Set(fiber.HeaderAccessControlAllowOrigin, allowOrigin)
|
||||
c.Set(fiber.HeaderAccessControlAllowCredentials, "true")
|
||||
}
|
||||
} else if allowOrigin != "" {
|
||||
// For non-credential requests, it's safe to set to '*' or specific origins
|
||||
c.Set(fiber.HeaderAccessControlAllowOrigin, allowOrigin)
|
||||
}
|
||||
|
||||
// Set Allow-Methods if not empty
|
||||
if allowMethods != "" {
|
||||
c.Set(fiber.HeaderAccessControlAllowMethods, allowMethods)
|
||||
}
|
||||
|
||||
// Set Allow-Headers if not empty
|
||||
if allowHeaders != "" {
|
||||
c.Set(fiber.HeaderAccessControlAllowHeaders, allowHeaders)
|
||||
} else {
|
||||
h := c.Get(fiber.HeaderAccessControlRequestHeaders)
|
||||
if h != "" {
|
||||
c.Set(fiber.HeaderAccessControlAllowHeaders, h)
|
||||
}
|
||||
}
|
||||
|
||||
// Set MaxAge if set
|
||||
if cfg.MaxAge > 0 {
|
||||
c.Set(fiber.HeaderAccessControlMaxAge, maxAge)
|
||||
} else if cfg.MaxAge < 0 {
|
||||
c.Set(fiber.HeaderAccessControlMaxAge, "0")
|
||||
}
|
||||
|
||||
// Set Expose-Headers if not empty
|
||||
if exposeHeaders != "" {
|
||||
c.Set(fiber.HeaderAccessControlExposeHeaders, exposeHeaders)
|
||||
}
|
||||
}
|
||||
+71
@@ -0,0 +1,71 @@
|
||||
package cors
|
||||
|
||||
import (
|
||||
"net/url"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// matchScheme compares the scheme of the domain and pattern
|
||||
func matchScheme(domain, pattern string) bool {
|
||||
didx := strings.Index(domain, ":")
|
||||
pidx := strings.Index(pattern, ":")
|
||||
return didx != -1 && pidx != -1 && domain[:didx] == pattern[:pidx]
|
||||
}
|
||||
|
||||
// normalizeDomain removes the scheme and port from the input domain
|
||||
func normalizeDomain(input string) string {
|
||||
// Remove scheme
|
||||
input = strings.TrimPrefix(strings.TrimPrefix(input, "http://"), "https://")
|
||||
|
||||
// Find and remove port, if present
|
||||
if len(input) > 0 && input[0] != '[' {
|
||||
if portIndex := strings.Index(input, ":"); portIndex != -1 {
|
||||
input = input[:portIndex]
|
||||
}
|
||||
}
|
||||
|
||||
return input
|
||||
}
|
||||
|
||||
// normalizeOrigin checks if the provided origin is in a correct format
|
||||
// and normalizes it by removing any path or trailing slash.
|
||||
// It returns a boolean indicating whether the origin is valid
|
||||
// and the normalized origin.
|
||||
func normalizeOrigin(origin string) (bool, string) {
|
||||
parsedOrigin, err := url.Parse(origin)
|
||||
if err != nil {
|
||||
return false, ""
|
||||
}
|
||||
|
||||
// Validate the scheme is either http or https
|
||||
if parsedOrigin.Scheme != "http" && parsedOrigin.Scheme != "https" {
|
||||
return false, ""
|
||||
}
|
||||
|
||||
// Don't allow a wildcard with a protocol
|
||||
// wildcards cannot be used within any other value. For example, the following header is not valid:
|
||||
// Access-Control-Allow-Origin: https://*
|
||||
if strings.Contains(parsedOrigin.Host, "*") {
|
||||
return false, ""
|
||||
}
|
||||
|
||||
// Validate there is a host present. The presence of a path, query, or fragment components
|
||||
// is checked, but a trailing "/" (indicative of the root) is allowed for the path and will be normalized
|
||||
if parsedOrigin.Host == "" || (parsedOrigin.Path != "" && parsedOrigin.Path != "/") || parsedOrigin.RawQuery != "" || parsedOrigin.Fragment != "" {
|
||||
return false, ""
|
||||
}
|
||||
|
||||
// Normalize the origin by constructing it from the scheme and host.
|
||||
// The path or trailing slash is not included in the normalized origin.
|
||||
return true, strings.ToLower(parsedOrigin.Scheme) + "://" + strings.ToLower(parsedOrigin.Host)
|
||||
}
|
||||
|
||||
type subdomain struct {
|
||||
// The wildcard pattern
|
||||
prefix string
|
||||
suffix string
|
||||
}
|
||||
|
||||
func (s subdomain) match(o string) bool {
|
||||
return len(o) >= len(s.prefix)+len(s.suffix) && strings.HasPrefix(o, s.prefix) && strings.HasSuffix(o, s.suffix)
|
||||
}
|
||||
+154
@@ -0,0 +1,154 @@
|
||||
package helmet
|
||||
|
||||
import (
|
||||
"github.com/gofiber/fiber/v2"
|
||||
)
|
||||
|
||||
// Config defines the config for middleware.
|
||||
type Config struct {
|
||||
// Next defines a function to skip middleware.
|
||||
// Optional. Default: nil
|
||||
Next func(*fiber.Ctx) bool
|
||||
|
||||
// XSSProtection
|
||||
// Optional. Default value "0".
|
||||
XSSProtection string
|
||||
|
||||
// ContentTypeNosniff
|
||||
// Optional. Default value "nosniff".
|
||||
ContentTypeNosniff string
|
||||
|
||||
// XFrameOptions
|
||||
// Optional. Default value "SAMEORIGIN".
|
||||
// Possible values: "SAMEORIGIN", "DENY", "ALLOW-FROM uri"
|
||||
XFrameOptions string
|
||||
|
||||
// HSTSMaxAge
|
||||
// Optional. Default value 0.
|
||||
HSTSMaxAge int
|
||||
|
||||
// HSTSExcludeSubdomains
|
||||
// Optional. Default value false.
|
||||
HSTSExcludeSubdomains bool
|
||||
|
||||
// ContentSecurityPolicy
|
||||
// Optional. Default value "".
|
||||
ContentSecurityPolicy string
|
||||
|
||||
// CSPReportOnly
|
||||
// Optional. Default value false.
|
||||
CSPReportOnly bool
|
||||
|
||||
// HSTSPreloadEnabled
|
||||
// Optional. Default value false.
|
||||
HSTSPreloadEnabled bool
|
||||
|
||||
// ReferrerPolicy
|
||||
// Optional. Default value "ReferrerPolicy".
|
||||
ReferrerPolicy string
|
||||
|
||||
// Permissions-Policy
|
||||
// Optional. Default value "".
|
||||
PermissionPolicy string
|
||||
|
||||
// Cross-Origin-Embedder-Policy
|
||||
// Optional. Default value "require-corp".
|
||||
CrossOriginEmbedderPolicy string
|
||||
|
||||
// Cross-Origin-Opener-Policy
|
||||
// Optional. Default value "same-origin".
|
||||
CrossOriginOpenerPolicy string
|
||||
|
||||
// Cross-Origin-Resource-Policy
|
||||
// Optional. Default value "same-origin".
|
||||
CrossOriginResourcePolicy string
|
||||
|
||||
// Origin-Agent-Cluster
|
||||
// Optional. Default value "?1".
|
||||
OriginAgentCluster string
|
||||
|
||||
// X-DNS-Prefetch-Control
|
||||
// Optional. Default value "off".
|
||||
XDNSPrefetchControl string
|
||||
|
||||
// X-Download-Options
|
||||
// Optional. Default value "noopen".
|
||||
XDownloadOptions string
|
||||
|
||||
// X-Permitted-Cross-Domain-Policies
|
||||
// Optional. Default value "none".
|
||||
XPermittedCrossDomain string
|
||||
}
|
||||
|
||||
// ConfigDefault is the default config
|
||||
var ConfigDefault = Config{
|
||||
XSSProtection: "0",
|
||||
ContentTypeNosniff: "nosniff",
|
||||
XFrameOptions: "SAMEORIGIN",
|
||||
ReferrerPolicy: "no-referrer",
|
||||
CrossOriginEmbedderPolicy: "require-corp",
|
||||
CrossOriginOpenerPolicy: "same-origin",
|
||||
CrossOriginResourcePolicy: "same-origin",
|
||||
OriginAgentCluster: "?1",
|
||||
XDNSPrefetchControl: "off",
|
||||
XDownloadOptions: "noopen",
|
||||
XPermittedCrossDomain: "none",
|
||||
}
|
||||
|
||||
// Helper function to set default values
|
||||
func configDefault(config ...Config) Config {
|
||||
// Return default config if nothing provided
|
||||
if len(config) < 1 {
|
||||
return ConfigDefault
|
||||
}
|
||||
|
||||
// Override default config
|
||||
cfg := config[0]
|
||||
|
||||
// Set default values
|
||||
if cfg.XSSProtection == "" {
|
||||
cfg.XSSProtection = ConfigDefault.XSSProtection
|
||||
}
|
||||
|
||||
if cfg.ContentTypeNosniff == "" {
|
||||
cfg.ContentTypeNosniff = ConfigDefault.ContentTypeNosniff
|
||||
}
|
||||
|
||||
if cfg.XFrameOptions == "" {
|
||||
cfg.XFrameOptions = ConfigDefault.XFrameOptions
|
||||
}
|
||||
|
||||
if cfg.ReferrerPolicy == "" {
|
||||
cfg.ReferrerPolicy = ConfigDefault.ReferrerPolicy
|
||||
}
|
||||
|
||||
if cfg.CrossOriginEmbedderPolicy == "" {
|
||||
cfg.CrossOriginEmbedderPolicy = ConfigDefault.CrossOriginEmbedderPolicy
|
||||
}
|
||||
|
||||
if cfg.CrossOriginOpenerPolicy == "" {
|
||||
cfg.CrossOriginOpenerPolicy = ConfigDefault.CrossOriginOpenerPolicy
|
||||
}
|
||||
|
||||
if cfg.CrossOriginResourcePolicy == "" {
|
||||
cfg.CrossOriginResourcePolicy = ConfigDefault.CrossOriginResourcePolicy
|
||||
}
|
||||
|
||||
if cfg.OriginAgentCluster == "" {
|
||||
cfg.OriginAgentCluster = ConfigDefault.OriginAgentCluster
|
||||
}
|
||||
|
||||
if cfg.XDNSPrefetchControl == "" {
|
||||
cfg.XDNSPrefetchControl = ConfigDefault.XDNSPrefetchControl
|
||||
}
|
||||
|
||||
if cfg.XDownloadOptions == "" {
|
||||
cfg.XDownloadOptions = ConfigDefault.XDownloadOptions
|
||||
}
|
||||
|
||||
if cfg.XPermittedCrossDomain == "" {
|
||||
cfg.XPermittedCrossDomain = ConfigDefault.XPermittedCrossDomain
|
||||
}
|
||||
|
||||
return cfg
|
||||
}
|
||||
+94
@@ -0,0 +1,94 @@
|
||||
package helmet
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
)
|
||||
|
||||
// New creates a new middleware handler
|
||||
func New(config ...Config) fiber.Handler {
|
||||
// Init config
|
||||
cfg := configDefault(config...)
|
||||
|
||||
// Return middleware handler
|
||||
return func(c *fiber.Ctx) error {
|
||||
// Next request to skip middleware
|
||||
if cfg.Next != nil && cfg.Next(c) {
|
||||
return c.Next()
|
||||
}
|
||||
|
||||
// Set headers
|
||||
if cfg.XSSProtection != "" {
|
||||
c.Set(fiber.HeaderXXSSProtection, cfg.XSSProtection)
|
||||
}
|
||||
|
||||
if cfg.ContentTypeNosniff != "" {
|
||||
c.Set(fiber.HeaderXContentTypeOptions, cfg.ContentTypeNosniff)
|
||||
}
|
||||
|
||||
if cfg.XFrameOptions != "" {
|
||||
c.Set(fiber.HeaderXFrameOptions, cfg.XFrameOptions)
|
||||
}
|
||||
|
||||
if cfg.CrossOriginEmbedderPolicy != "" {
|
||||
c.Set("Cross-Origin-Embedder-Policy", cfg.CrossOriginEmbedderPolicy)
|
||||
}
|
||||
|
||||
if cfg.CrossOriginOpenerPolicy != "" {
|
||||
c.Set("Cross-Origin-Opener-Policy", cfg.CrossOriginOpenerPolicy)
|
||||
}
|
||||
|
||||
if cfg.CrossOriginResourcePolicy != "" {
|
||||
c.Set("Cross-Origin-Resource-Policy", cfg.CrossOriginResourcePolicy)
|
||||
}
|
||||
|
||||
if cfg.OriginAgentCluster != "" {
|
||||
c.Set("Origin-Agent-Cluster", cfg.OriginAgentCluster)
|
||||
}
|
||||
|
||||
if cfg.ReferrerPolicy != "" {
|
||||
c.Set("Referrer-Policy", cfg.ReferrerPolicy)
|
||||
}
|
||||
|
||||
if cfg.XDNSPrefetchControl != "" {
|
||||
c.Set("X-DNS-Prefetch-Control", cfg.XDNSPrefetchControl)
|
||||
}
|
||||
|
||||
if cfg.XDownloadOptions != "" {
|
||||
c.Set("X-Download-Options", cfg.XDownloadOptions)
|
||||
}
|
||||
|
||||
if cfg.XPermittedCrossDomain != "" {
|
||||
c.Set("X-Permitted-Cross-Domain-Policies", cfg.XPermittedCrossDomain)
|
||||
}
|
||||
|
||||
// Handle HSTS headers
|
||||
if c.Protocol() == "https" && cfg.HSTSMaxAge != 0 {
|
||||
subdomains := ""
|
||||
if !cfg.HSTSExcludeSubdomains {
|
||||
subdomains = "; includeSubDomains"
|
||||
}
|
||||
if cfg.HSTSPreloadEnabled {
|
||||
subdomains = fmt.Sprintf("%s; preload", subdomains)
|
||||
}
|
||||
c.Set(fiber.HeaderStrictTransportSecurity, fmt.Sprintf("max-age=%d%s", cfg.HSTSMaxAge, subdomains))
|
||||
}
|
||||
|
||||
// Handle Content-Security-Policy headers
|
||||
if cfg.ContentSecurityPolicy != "" {
|
||||
if cfg.CSPReportOnly {
|
||||
c.Set(fiber.HeaderContentSecurityPolicyReportOnly, cfg.ContentSecurityPolicy)
|
||||
} else {
|
||||
c.Set(fiber.HeaderContentSecurityPolicy, cfg.ContentSecurityPolicy)
|
||||
}
|
||||
}
|
||||
|
||||
// Handle Permissions-Policy headers
|
||||
if cfg.PermissionPolicy != "" {
|
||||
c.Set(fiber.HeaderPermissionsPolicy, cfg.PermissionPolicy)
|
||||
}
|
||||
|
||||
return c.Next()
|
||||
}
|
||||
}
|
||||
+128
@@ -0,0 +1,128 @@
|
||||
package limiter
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/gofiber/fiber/v2/log"
|
||||
)
|
||||
|
||||
// Config defines the config for middleware.
|
||||
type Config struct {
|
||||
// Next defines a function to skip this middleware when returned true.
|
||||
//
|
||||
// Optional. Default: nil
|
||||
Next func(c *fiber.Ctx) bool
|
||||
|
||||
// Max number of recent connections during `Expiration` seconds before sending a 429 response
|
||||
//
|
||||
// Default: 5
|
||||
Max int
|
||||
|
||||
// KeyGenerator allows you to generate custom keys, by default c.IP() is used
|
||||
//
|
||||
// Default: func(c *fiber.Ctx) string {
|
||||
// return c.IP()
|
||||
// }
|
||||
KeyGenerator func(*fiber.Ctx) string
|
||||
|
||||
// Expiration is the time on how long to keep records of requests in memory
|
||||
//
|
||||
// Default: 1 * time.Minute
|
||||
Expiration time.Duration
|
||||
|
||||
// LimitReached is called when a request hits the limit
|
||||
//
|
||||
// Default: func(c *fiber.Ctx) error {
|
||||
// return c.SendStatus(fiber.StatusTooManyRequests)
|
||||
// }
|
||||
LimitReached fiber.Handler
|
||||
|
||||
// When set to true, requests with StatusCode >= 400 won't be counted.
|
||||
//
|
||||
// Default: false
|
||||
SkipFailedRequests bool
|
||||
|
||||
// When set to true, requests with StatusCode < 400 won't be counted.
|
||||
//
|
||||
// Default: false
|
||||
SkipSuccessfulRequests bool
|
||||
|
||||
// Store is used to store the state of the middleware
|
||||
//
|
||||
// Default: an in memory store for this process only
|
||||
Storage fiber.Storage
|
||||
|
||||
// LimiterMiddleware is the struct that implements a limiter middleware.
|
||||
//
|
||||
// Default: a new Fixed Window Rate Limiter
|
||||
LimiterMiddleware LimiterHandler
|
||||
|
||||
// Deprecated: Use Expiration instead
|
||||
Duration time.Duration
|
||||
|
||||
// Deprecated: Use Storage instead
|
||||
Store fiber.Storage
|
||||
|
||||
// Deprecated: Use KeyGenerator instead
|
||||
Key func(*fiber.Ctx) string
|
||||
}
|
||||
|
||||
// ConfigDefault is the default config
|
||||
var ConfigDefault = Config{
|
||||
Max: 5,
|
||||
Expiration: 1 * time.Minute,
|
||||
KeyGenerator: func(c *fiber.Ctx) string {
|
||||
return c.IP()
|
||||
},
|
||||
LimitReached: func(c *fiber.Ctx) error {
|
||||
return c.SendStatus(fiber.StatusTooManyRequests)
|
||||
},
|
||||
SkipFailedRequests: false,
|
||||
SkipSuccessfulRequests: false,
|
||||
LimiterMiddleware: FixedWindow{},
|
||||
}
|
||||
|
||||
// Helper function to set default values
|
||||
func configDefault(config ...Config) Config {
|
||||
// Return default config if nothing provided
|
||||
if len(config) < 1 {
|
||||
return ConfigDefault
|
||||
}
|
||||
|
||||
// Override default config
|
||||
cfg := config[0]
|
||||
|
||||
// Set default values
|
||||
if int(cfg.Duration.Seconds()) > 0 {
|
||||
log.Warn("[LIMITER] Duration is deprecated, please use Expiration")
|
||||
cfg.Expiration = cfg.Duration
|
||||
}
|
||||
if cfg.Key != nil {
|
||||
log.Warn("[LIMITER] Key is deprecated, please us KeyGenerator")
|
||||
cfg.KeyGenerator = cfg.Key
|
||||
}
|
||||
if cfg.Store != nil {
|
||||
log.Warn("[LIMITER] Store is deprecated, please use Storage")
|
||||
cfg.Storage = cfg.Store
|
||||
}
|
||||
if cfg.Next == nil {
|
||||
cfg.Next = ConfigDefault.Next
|
||||
}
|
||||
if cfg.Max <= 0 {
|
||||
cfg.Max = ConfigDefault.Max
|
||||
}
|
||||
if int(cfg.Expiration.Seconds()) <= 0 {
|
||||
cfg.Expiration = ConfigDefault.Expiration
|
||||
}
|
||||
if cfg.KeyGenerator == nil {
|
||||
cfg.KeyGenerator = ConfigDefault.KeyGenerator
|
||||
}
|
||||
if cfg.LimitReached == nil {
|
||||
cfg.LimitReached = ConfigDefault.LimitReached
|
||||
}
|
||||
if cfg.LimiterMiddleware == nil {
|
||||
cfg.LimiterMiddleware = ConfigDefault.LimiterMiddleware
|
||||
}
|
||||
return cfg
|
||||
}
|
||||
+25
@@ -0,0 +1,25 @@
|
||||
package limiter
|
||||
|
||||
import (
|
||||
"github.com/gofiber/fiber/v2"
|
||||
)
|
||||
|
||||
const (
|
||||
// X-RateLimit-* headers
|
||||
xRateLimitLimit = "X-RateLimit-Limit"
|
||||
xRateLimitRemaining = "X-RateLimit-Remaining"
|
||||
xRateLimitReset = "X-RateLimit-Reset"
|
||||
)
|
||||
|
||||
type LimiterHandler interface {
|
||||
New(config Config) fiber.Handler
|
||||
}
|
||||
|
||||
// New creates a new middleware handler
|
||||
func New(config ...Config) fiber.Handler {
|
||||
// Set default config
|
||||
cfg := configDefault(config...)
|
||||
|
||||
// Return the specified middleware handler.
|
||||
return cfg.LimiterMiddleware.New(cfg)
|
||||
}
|
||||
+106
@@ -0,0 +1,106 @@
|
||||
package limiter
|
||||
|
||||
import (
|
||||
"strconv"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/gofiber/fiber/v2/utils"
|
||||
)
|
||||
|
||||
type FixedWindow struct{}
|
||||
|
||||
// New creates a new fixed window middleware handler
|
||||
func (FixedWindow) New(cfg Config) fiber.Handler {
|
||||
var (
|
||||
// Limiter variables
|
||||
mux = &sync.RWMutex{}
|
||||
max = strconv.Itoa(cfg.Max)
|
||||
expiration = uint64(cfg.Expiration.Seconds())
|
||||
)
|
||||
|
||||
// Create manager to simplify storage operations ( see manager.go )
|
||||
manager := newManager(cfg.Storage)
|
||||
|
||||
// Update timestamp every second
|
||||
utils.StartTimeStampUpdater()
|
||||
|
||||
// Return new handler
|
||||
return func(c *fiber.Ctx) error {
|
||||
// Don't execute middleware if Next returns true
|
||||
if cfg.Next != nil && cfg.Next(c) {
|
||||
return c.Next()
|
||||
}
|
||||
|
||||
// Get key from request
|
||||
key := cfg.KeyGenerator(c)
|
||||
|
||||
// Lock entry
|
||||
mux.Lock()
|
||||
|
||||
// Get entry from pool and release when finished
|
||||
e := manager.get(key)
|
||||
|
||||
// Get timestamp
|
||||
ts := uint64(atomic.LoadUint32(&utils.Timestamp))
|
||||
|
||||
// Set expiration if entry does not exist
|
||||
if e.exp == 0 {
|
||||
e.exp = ts + expiration
|
||||
} else if ts >= e.exp {
|
||||
// Check if entry is expired
|
||||
e.currHits = 0
|
||||
e.exp = ts + expiration
|
||||
}
|
||||
|
||||
// Increment hits
|
||||
e.currHits++
|
||||
|
||||
// Calculate when it resets in seconds
|
||||
resetInSec := e.exp - ts
|
||||
|
||||
// Set how many hits we have left
|
||||
remaining := cfg.Max - e.currHits
|
||||
|
||||
// Update storage
|
||||
manager.set(key, e, cfg.Expiration)
|
||||
|
||||
// Unlock entry
|
||||
mux.Unlock()
|
||||
|
||||
// Check if hits exceed the cfg.Max
|
||||
if remaining < 0 {
|
||||
// Return response with Retry-After header
|
||||
// https://tools.ietf.org/html/rfc6584
|
||||
c.Set(fiber.HeaderRetryAfter, strconv.FormatUint(resetInSec, 10))
|
||||
|
||||
// Call LimitReached handler
|
||||
return cfg.LimitReached(c)
|
||||
}
|
||||
|
||||
// Continue stack for reaching c.Response().StatusCode()
|
||||
// Store err for returning
|
||||
err := c.Next()
|
||||
|
||||
// Check for SkipFailedRequests and SkipSuccessfulRequests
|
||||
if (cfg.SkipSuccessfulRequests && c.Response().StatusCode() < fiber.StatusBadRequest) ||
|
||||
(cfg.SkipFailedRequests && c.Response().StatusCode() >= fiber.StatusBadRequest) {
|
||||
// Lock entry
|
||||
mux.Lock()
|
||||
e = manager.get(key)
|
||||
e.currHits--
|
||||
remaining++
|
||||
manager.set(key, e, cfg.Expiration)
|
||||
// Unlock entry
|
||||
mux.Unlock()
|
||||
}
|
||||
|
||||
// We can continue, update RateLimit headers
|
||||
c.Set(xRateLimitLimit, max)
|
||||
c.Set(xRateLimitRemaining, strconv.Itoa(remaining))
|
||||
c.Set(xRateLimitReset, strconv.FormatUint(resetInSec, 10))
|
||||
|
||||
return err
|
||||
}
|
||||
}
|
||||
+137
@@ -0,0 +1,137 @@
|
||||
package limiter
|
||||
|
||||
import (
|
||||
"strconv"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/gofiber/fiber/v2/utils"
|
||||
)
|
||||
|
||||
type SlidingWindow struct{}
|
||||
|
||||
// New creates a new sliding window middleware handler
|
||||
func (SlidingWindow) New(cfg Config) fiber.Handler {
|
||||
var (
|
||||
// Limiter variables
|
||||
mux = &sync.RWMutex{}
|
||||
max = strconv.Itoa(cfg.Max)
|
||||
expiration = uint64(cfg.Expiration.Seconds())
|
||||
)
|
||||
|
||||
// Create manager to simplify storage operations ( see manager.go )
|
||||
manager := newManager(cfg.Storage)
|
||||
|
||||
// Update timestamp every second
|
||||
utils.StartTimeStampUpdater()
|
||||
|
||||
// Return new handler
|
||||
return func(c *fiber.Ctx) error {
|
||||
// Don't execute middleware if Next returns true
|
||||
if cfg.Next != nil && cfg.Next(c) {
|
||||
return c.Next()
|
||||
}
|
||||
|
||||
// Get key from request
|
||||
key := cfg.KeyGenerator(c)
|
||||
|
||||
// Lock entry
|
||||
mux.Lock()
|
||||
|
||||
// Get entry from pool and release when finished
|
||||
e := manager.get(key)
|
||||
|
||||
// Get timestamp
|
||||
ts := uint64(atomic.LoadUint32(&utils.Timestamp))
|
||||
|
||||
// Set expiration if entry does not exist
|
||||
if e.exp == 0 {
|
||||
e.exp = ts + expiration
|
||||
} else if ts >= e.exp {
|
||||
// The entry has expired, handle the expiration.
|
||||
// Set the prevHits to the current hits and reset the hits to 0.
|
||||
e.prevHits = e.currHits
|
||||
|
||||
// Reset the current hits to 0.
|
||||
e.currHits = 0
|
||||
|
||||
// Check how much into the current window it currently is and sets the
|
||||
// expiry based on that, otherwise this would only reset on
|
||||
// the next request and not show the correct expiry.
|
||||
elapsed := ts - e.exp
|
||||
if elapsed >= expiration {
|
||||
e.exp = ts + expiration
|
||||
} else {
|
||||
e.exp = ts + expiration - elapsed
|
||||
}
|
||||
}
|
||||
|
||||
// Increment hits
|
||||
e.currHits++
|
||||
|
||||
// Calculate when it resets in seconds
|
||||
resetInSec := e.exp - ts
|
||||
|
||||
// weight = time until current window reset / total window length
|
||||
weight := float64(resetInSec) / float64(expiration)
|
||||
|
||||
// rate = request count in previous window - weight + request count in current window
|
||||
rate := int(float64(e.prevHits)*weight) + e.currHits
|
||||
|
||||
// Calculate how many hits can be made based on the current rate
|
||||
remaining := cfg.Max - rate
|
||||
|
||||
// Update storage. Garbage collect when the next window ends.
|
||||
// |--------------------------|--------------------------|
|
||||
// ^ ^ ^ ^
|
||||
// ts e.exp End sample window End next window
|
||||
// <------------>
|
||||
// resetInSec
|
||||
// resetInSec = e.exp - ts - time until end of current window.
|
||||
// duration + expiration = end of next window.
|
||||
// Because we don't want to garbage collect in the middle of a window
|
||||
// we add the expiration to the duration.
|
||||
// Otherwise after the end of "sample window", attackers could launch
|
||||
// a new request with the full window length.
|
||||
manager.set(key, e, time.Duration(resetInSec+expiration)*time.Second)
|
||||
|
||||
// Unlock entry
|
||||
mux.Unlock()
|
||||
|
||||
// Check if hits exceed the cfg.Max
|
||||
if remaining < 0 {
|
||||
// Return response with Retry-After header
|
||||
// https://tools.ietf.org/html/rfc6584
|
||||
c.Set(fiber.HeaderRetryAfter, strconv.FormatUint(resetInSec, 10))
|
||||
|
||||
// Call LimitReached handler
|
||||
return cfg.LimitReached(c)
|
||||
}
|
||||
|
||||
// Continue stack for reaching c.Response().StatusCode()
|
||||
// Store err for returning
|
||||
err := c.Next()
|
||||
|
||||
// Check for SkipFailedRequests and SkipSuccessfulRequests
|
||||
if (cfg.SkipSuccessfulRequests && c.Response().StatusCode() < fiber.StatusBadRequest) ||
|
||||
(cfg.SkipFailedRequests && c.Response().StatusCode() >= fiber.StatusBadRequest) {
|
||||
// Lock entry
|
||||
mux.Lock()
|
||||
e = manager.get(key)
|
||||
e.currHits--
|
||||
remaining++
|
||||
manager.set(key, e, cfg.Expiration)
|
||||
// Unlock entry
|
||||
mux.Unlock()
|
||||
}
|
||||
|
||||
// We can continue, update RateLimit headers
|
||||
c.Set(xRateLimitLimit, max)
|
||||
c.Set(xRateLimitRemaining, strconv.Itoa(remaining))
|
||||
c.Set(xRateLimitReset, strconv.FormatUint(resetInSec, 10))
|
||||
|
||||
return err
|
||||
}
|
||||
}
|
||||
+92
@@ -0,0 +1,92 @@
|
||||
package limiter
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/gofiber/fiber/v2/internal/memory"
|
||||
)
|
||||
|
||||
// go:generate msgp
|
||||
// msgp -file="manager.go" -o="manager_msgp.go" -tests=false -unexported
|
||||
type item struct {
|
||||
currHits int
|
||||
prevHits int
|
||||
exp uint64
|
||||
}
|
||||
|
||||
//msgp:ignore manager
|
||||
type manager struct {
|
||||
pool sync.Pool
|
||||
memory *memory.Storage
|
||||
storage fiber.Storage
|
||||
}
|
||||
|
||||
func newManager(storage fiber.Storage) *manager {
|
||||
// Create new storage handler
|
||||
manager := &manager{
|
||||
pool: sync.Pool{
|
||||
New: func() interface{} {
|
||||
return new(item)
|
||||
},
|
||||
},
|
||||
}
|
||||
if storage != nil {
|
||||
// Use provided storage if provided
|
||||
manager.storage = storage
|
||||
} else {
|
||||
// Fallback too memory storage
|
||||
manager.memory = memory.New()
|
||||
}
|
||||
return manager
|
||||
}
|
||||
|
||||
// acquire returns an *entry from the sync.Pool
|
||||
func (m *manager) acquire() *item {
|
||||
return m.pool.Get().(*item) //nolint:forcetypeassert // We store nothing else in the pool
|
||||
}
|
||||
|
||||
// release and reset *entry to sync.Pool
|
||||
func (m *manager) release(e *item) {
|
||||
e.prevHits = 0
|
||||
e.currHits = 0
|
||||
e.exp = 0
|
||||
m.pool.Put(e)
|
||||
}
|
||||
|
||||
// get data from storage or memory
|
||||
func (m *manager) get(key string) *item {
|
||||
var it *item
|
||||
if m.storage != nil {
|
||||
it = m.acquire()
|
||||
raw, err := m.storage.Get(key)
|
||||
if err != nil {
|
||||
return it
|
||||
}
|
||||
if raw != nil {
|
||||
if _, err := it.UnmarshalMsg(raw); err != nil {
|
||||
return it
|
||||
}
|
||||
}
|
||||
return it
|
||||
}
|
||||
if it, _ = m.memory.Get(key).(*item); it == nil { //nolint:errcheck // We store nothing else in the pool
|
||||
it = m.acquire()
|
||||
return it
|
||||
}
|
||||
return it
|
||||
}
|
||||
|
||||
// set data to storage or memory
|
||||
func (m *manager) set(key string, it *item, exp time.Duration) {
|
||||
if m.storage != nil {
|
||||
if raw, err := it.MarshalMsg(nil); err == nil {
|
||||
_ = m.storage.Set(key, raw, exp) //nolint:errcheck // TODO: Handle error here
|
||||
}
|
||||
// we can release data because it's serialized to database
|
||||
m.release(it)
|
||||
} else {
|
||||
m.memory.Set(key, it, exp)
|
||||
}
|
||||
}
|
||||
+160
@@ -0,0 +1,160 @@
|
||||
package limiter
|
||||
|
||||
// Code generated by github.com/tinylib/msgp DO NOT EDIT.
|
||||
|
||||
import (
|
||||
"github.com/tinylib/msgp/msgp"
|
||||
)
|
||||
|
||||
// DecodeMsg implements msgp.Decodable
|
||||
func (z *item) DecodeMsg(dc *msgp.Reader) (err error) {
|
||||
var field []byte
|
||||
_ = field
|
||||
var zb0001 uint32
|
||||
zb0001, err = dc.ReadMapHeader()
|
||||
if err != nil {
|
||||
err = msgp.WrapError(err)
|
||||
return
|
||||
}
|
||||
for zb0001 > 0 {
|
||||
zb0001--
|
||||
field, err = dc.ReadMapKeyPtr()
|
||||
if err != nil {
|
||||
err = msgp.WrapError(err)
|
||||
return
|
||||
}
|
||||
switch msgp.UnsafeString(field) {
|
||||
case "currHits":
|
||||
z.currHits, err = dc.ReadInt()
|
||||
if err != nil {
|
||||
err = msgp.WrapError(err, "currHits")
|
||||
return
|
||||
}
|
||||
case "prevHits":
|
||||
z.prevHits, err = dc.ReadInt()
|
||||
if err != nil {
|
||||
err = msgp.WrapError(err, "prevHits")
|
||||
return
|
||||
}
|
||||
case "exp":
|
||||
z.exp, err = dc.ReadUint64()
|
||||
if err != nil {
|
||||
err = msgp.WrapError(err, "exp")
|
||||
return
|
||||
}
|
||||
default:
|
||||
err = dc.Skip()
|
||||
if err != nil {
|
||||
err = msgp.WrapError(err)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// EncodeMsg implements msgp.Encodable
|
||||
func (z item) EncodeMsg(en *msgp.Writer) (err error) {
|
||||
// map header, size 3
|
||||
// write "currHits"
|
||||
err = en.Append(0x83, 0xa8, 0x63, 0x75, 0x72, 0x72, 0x48, 0x69, 0x74, 0x73)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
err = en.WriteInt(z.currHits)
|
||||
if err != nil {
|
||||
err = msgp.WrapError(err, "currHits")
|
||||
return
|
||||
}
|
||||
// write "prevHits"
|
||||
err = en.Append(0xa8, 0x70, 0x72, 0x65, 0x76, 0x48, 0x69, 0x74, 0x73)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
err = en.WriteInt(z.prevHits)
|
||||
if err != nil {
|
||||
err = msgp.WrapError(err, "prevHits")
|
||||
return
|
||||
}
|
||||
// write "exp"
|
||||
err = en.Append(0xa3, 0x65, 0x78, 0x70)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
err = en.WriteUint64(z.exp)
|
||||
if err != nil {
|
||||
err = msgp.WrapError(err, "exp")
|
||||
return
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// MarshalMsg implements msgp.Marshaler
|
||||
func (z item) MarshalMsg(b []byte) (o []byte, err error) {
|
||||
o = msgp.Require(b, z.Msgsize())
|
||||
// map header, size 3
|
||||
// string "currHits"
|
||||
o = append(o, 0x83, 0xa8, 0x63, 0x75, 0x72, 0x72, 0x48, 0x69, 0x74, 0x73)
|
||||
o = msgp.AppendInt(o, z.currHits)
|
||||
// string "prevHits"
|
||||
o = append(o, 0xa8, 0x70, 0x72, 0x65, 0x76, 0x48, 0x69, 0x74, 0x73)
|
||||
o = msgp.AppendInt(o, z.prevHits)
|
||||
// string "exp"
|
||||
o = append(o, 0xa3, 0x65, 0x78, 0x70)
|
||||
o = msgp.AppendUint64(o, z.exp)
|
||||
return
|
||||
}
|
||||
|
||||
// UnmarshalMsg implements msgp.Unmarshaler
|
||||
func (z *item) UnmarshalMsg(bts []byte) (o []byte, err error) {
|
||||
var field []byte
|
||||
_ = field
|
||||
var zb0001 uint32
|
||||
zb0001, bts, err = msgp.ReadMapHeaderBytes(bts)
|
||||
if err != nil {
|
||||
err = msgp.WrapError(err)
|
||||
return
|
||||
}
|
||||
for zb0001 > 0 {
|
||||
zb0001--
|
||||
field, bts, err = msgp.ReadMapKeyZC(bts)
|
||||
if err != nil {
|
||||
err = msgp.WrapError(err)
|
||||
return
|
||||
}
|
||||
switch msgp.UnsafeString(field) {
|
||||
case "currHits":
|
||||
z.currHits, bts, err = msgp.ReadIntBytes(bts)
|
||||
if err != nil {
|
||||
err = msgp.WrapError(err, "currHits")
|
||||
return
|
||||
}
|
||||
case "prevHits":
|
||||
z.prevHits, bts, err = msgp.ReadIntBytes(bts)
|
||||
if err != nil {
|
||||
err = msgp.WrapError(err, "prevHits")
|
||||
return
|
||||
}
|
||||
case "exp":
|
||||
z.exp, bts, err = msgp.ReadUint64Bytes(bts)
|
||||
if err != nil {
|
||||
err = msgp.WrapError(err, "exp")
|
||||
return
|
||||
}
|
||||
default:
|
||||
bts, err = msgp.Skip(bts)
|
||||
if err != nil {
|
||||
err = msgp.WrapError(err)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
o = bts
|
||||
return
|
||||
}
|
||||
|
||||
// Msgsize returns an upper bound estimate of the number of bytes occupied by the serialized message
|
||||
func (z item) Msgsize() (s int) {
|
||||
s = 1 + 9 + msgp.IntSize + 9 + msgp.IntSize + 4 + msgp.Uint64Size
|
||||
return
|
||||
}
|
||||
+136
@@ -0,0 +1,136 @@
|
||||
package logger
|
||||
|
||||
import (
|
||||
"io"
|
||||
"os"
|
||||
"time"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
)
|
||||
|
||||
// Config defines the config for middleware.
|
||||
type Config struct {
|
||||
// Next defines a function to skip this middleware when returned true.
|
||||
//
|
||||
// Optional. Default: nil
|
||||
Next func(c *fiber.Ctx) bool
|
||||
|
||||
// Done is a function that is called after the log string for a request is written to Output,
|
||||
// and pass the log string as parameter.
|
||||
//
|
||||
// Optional. Default: nil
|
||||
Done func(c *fiber.Ctx, logString []byte)
|
||||
|
||||
// tagFunctions defines the custom tag action
|
||||
//
|
||||
// Optional. Default: map[string]LogFunc
|
||||
CustomTags map[string]LogFunc
|
||||
|
||||
// Format defines the logging tags
|
||||
//
|
||||
// Optional. Default: ${time} | ${status} | ${latency} | ${ip} | ${method} | ${path} | ${error}\n
|
||||
Format string
|
||||
|
||||
// TimeFormat https://programming.guide/go/format-parse-string-time-date-example.html
|
||||
//
|
||||
// Optional. Default: 15:04:05
|
||||
TimeFormat string
|
||||
|
||||
// TimeZone can be specified, such as "UTC" and "America/New_York" and "Asia/Chongqing", etc
|
||||
//
|
||||
// Optional. Default: "Local"
|
||||
TimeZone string
|
||||
|
||||
// TimeInterval is the delay before the timestamp is updated
|
||||
//
|
||||
// Optional. Default: 500 * time.Millisecond
|
||||
TimeInterval time.Duration
|
||||
|
||||
// Output is a writer where logs are written
|
||||
//
|
||||
// Default: os.Stdout
|
||||
Output io.Writer
|
||||
|
||||
// DisableColors defines if the logs output should be colorized
|
||||
//
|
||||
// Default: false
|
||||
DisableColors bool
|
||||
|
||||
enableColors bool
|
||||
enableLatency bool
|
||||
timeZoneLocation *time.Location
|
||||
}
|
||||
|
||||
const (
|
||||
startTag = "${"
|
||||
endTag = "}"
|
||||
paramSeparator = ":"
|
||||
)
|
||||
|
||||
type Buffer interface {
|
||||
Len() int
|
||||
ReadFrom(r io.Reader) (int64, error)
|
||||
WriteTo(w io.Writer) (int64, error)
|
||||
Bytes() []byte
|
||||
Write(p []byte) (int, error)
|
||||
WriteByte(c byte) error
|
||||
WriteString(s string) (int, error)
|
||||
Set(p []byte)
|
||||
SetString(s string)
|
||||
String() string
|
||||
}
|
||||
|
||||
type LogFunc func(output Buffer, c *fiber.Ctx, data *Data, extraParam string) (int, error)
|
||||
|
||||
// ConfigDefault is the default config
|
||||
var ConfigDefault = Config{
|
||||
Next: nil,
|
||||
Done: nil,
|
||||
Format: "${time} | ${status} | ${latency} | ${ip} | ${method} | ${path} | ${error}\n",
|
||||
TimeFormat: "15:04:05",
|
||||
TimeZone: "Local",
|
||||
TimeInterval: 500 * time.Millisecond,
|
||||
Output: os.Stdout,
|
||||
DisableColors: false,
|
||||
enableColors: true,
|
||||
}
|
||||
|
||||
// Helper function to set default values
|
||||
func configDefault(config ...Config) Config {
|
||||
// Return default config if nothing provided
|
||||
if len(config) < 1 {
|
||||
return ConfigDefault
|
||||
}
|
||||
|
||||
// Override default config
|
||||
cfg := config[0]
|
||||
|
||||
// Set default values
|
||||
if cfg.Next == nil {
|
||||
cfg.Next = ConfigDefault.Next
|
||||
}
|
||||
if cfg.Done == nil {
|
||||
cfg.Done = ConfigDefault.Done
|
||||
}
|
||||
if cfg.Format == "" {
|
||||
cfg.Format = ConfigDefault.Format
|
||||
}
|
||||
if cfg.TimeZone == "" {
|
||||
cfg.TimeZone = ConfigDefault.TimeZone
|
||||
}
|
||||
if cfg.TimeFormat == "" {
|
||||
cfg.TimeFormat = ConfigDefault.TimeFormat
|
||||
}
|
||||
if int(cfg.TimeInterval) <= 0 {
|
||||
cfg.TimeInterval = ConfigDefault.TimeInterval
|
||||
}
|
||||
if cfg.Output == nil {
|
||||
cfg.Output = ConfigDefault.Output
|
||||
}
|
||||
|
||||
if !cfg.DisableColors && cfg.Output == ConfigDefault.Output {
|
||||
cfg.enableColors = true
|
||||
}
|
||||
|
||||
return cfg
|
||||
}
|
||||
+16
@@ -0,0 +1,16 @@
|
||||
package logger
|
||||
|
||||
import (
|
||||
"sync/atomic"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Data is a struct to define some variables to use in custom logger function.
|
||||
type Data struct {
|
||||
Pid string
|
||||
ErrPaddingStr string
|
||||
ChainErr error
|
||||
Start time.Time
|
||||
Stop time.Time
|
||||
Timestamp atomic.Value
|
||||
}
|
||||
+182
@@ -0,0 +1,182 @@
|
||||
package logger
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/gofiber/fiber/v2/utils"
|
||||
|
||||
"github.com/mattn/go-colorable"
|
||||
"github.com/mattn/go-isatty"
|
||||
"github.com/valyala/bytebufferpool"
|
||||
"github.com/valyala/fasthttp"
|
||||
)
|
||||
|
||||
// New creates a new middleware handler
|
||||
func New(config ...Config) fiber.Handler {
|
||||
// Set default config
|
||||
cfg := configDefault(config...)
|
||||
|
||||
// Get timezone location
|
||||
tz, err := time.LoadLocation(cfg.TimeZone)
|
||||
if err != nil || tz == nil {
|
||||
cfg.timeZoneLocation = time.Local
|
||||
} else {
|
||||
cfg.timeZoneLocation = tz
|
||||
}
|
||||
|
||||
// Check if format contains latency
|
||||
cfg.enableLatency = strings.Contains(cfg.Format, "${"+TagLatency+"}")
|
||||
|
||||
var timestamp atomic.Value
|
||||
// Create correct timeformat
|
||||
timestamp.Store(time.Now().In(cfg.timeZoneLocation).Format(cfg.TimeFormat))
|
||||
|
||||
// Update date/time every 500 milliseconds in a separate go routine
|
||||
if strings.Contains(cfg.Format, "${"+TagTime+"}") {
|
||||
go func() {
|
||||
for {
|
||||
time.Sleep(cfg.TimeInterval)
|
||||
timestamp.Store(time.Now().In(cfg.timeZoneLocation).Format(cfg.TimeFormat))
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
// Set PID once
|
||||
pid := strconv.Itoa(os.Getpid())
|
||||
|
||||
// Set variables
|
||||
var (
|
||||
once sync.Once
|
||||
mu sync.Mutex
|
||||
errHandler fiber.ErrorHandler
|
||||
|
||||
dataPool = sync.Pool{New: func() interface{} { return new(Data) }}
|
||||
)
|
||||
|
||||
// If colors are enabled, check terminal compatibility
|
||||
if cfg.enableColors {
|
||||
cfg.Output = colorable.NewColorableStdout()
|
||||
if os.Getenv("TERM") == "dumb" || os.Getenv("NO_COLOR") == "1" || (!isatty.IsTerminal(os.Stdout.Fd()) && !isatty.IsCygwinTerminal(os.Stdout.Fd())) {
|
||||
cfg.Output = colorable.NewNonColorable(os.Stdout)
|
||||
}
|
||||
}
|
||||
|
||||
errPadding := 15
|
||||
errPaddingStr := strconv.Itoa(errPadding)
|
||||
|
||||
// instead of analyzing the template inside(handler) each time, this is done once before
|
||||
// and we create several slices of the same length with the functions to be executed and fixed parts.
|
||||
templateChain, logFunChain, err := buildLogFuncChain(&cfg, createTagMap(&cfg))
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
// Return new handler
|
||||
return func(c *fiber.Ctx) error {
|
||||
// Don't execute middleware if Next returns true
|
||||
if cfg.Next != nil && cfg.Next(c) {
|
||||
return c.Next()
|
||||
}
|
||||
|
||||
// Set error handler once
|
||||
once.Do(func() {
|
||||
// get longested possible path
|
||||
stack := c.App().Stack()
|
||||
for m := range stack {
|
||||
for r := range stack[m] {
|
||||
if len(stack[m][r].Path) > errPadding {
|
||||
errPadding = len(stack[m][r].Path)
|
||||
errPaddingStr = strconv.Itoa(errPadding)
|
||||
}
|
||||
}
|
||||
}
|
||||
// override error handler
|
||||
errHandler = c.App().ErrorHandler
|
||||
})
|
||||
|
||||
// Logger data
|
||||
data := dataPool.Get().(*Data) //nolint:forcetypeassert,errcheck // We store nothing else in the pool
|
||||
// no need for a reset, as long as we always override everything
|
||||
data.Pid = pid
|
||||
data.ErrPaddingStr = errPaddingStr
|
||||
data.Timestamp = timestamp
|
||||
// put data back in the pool
|
||||
defer dataPool.Put(data)
|
||||
|
||||
// Set latency start time
|
||||
if cfg.enableLatency {
|
||||
data.Start = time.Now()
|
||||
}
|
||||
|
||||
// Handle request, store err for logging
|
||||
chainErr := c.Next()
|
||||
|
||||
data.ChainErr = chainErr
|
||||
// Manually call error handler
|
||||
if chainErr != nil {
|
||||
if err := errHandler(c, chainErr); err != nil {
|
||||
_ = c.SendStatus(fiber.StatusInternalServerError) //nolint:errcheck // TODO: Explain why we ignore the error here
|
||||
}
|
||||
}
|
||||
|
||||
// Set latency stop time
|
||||
if cfg.enableLatency {
|
||||
data.Stop = time.Now()
|
||||
}
|
||||
|
||||
// Get new buffer
|
||||
buf := bytebufferpool.Get()
|
||||
|
||||
var err error
|
||||
// Loop over template parts execute dynamic parts and add fixed parts to the buffer
|
||||
for i, logFunc := range logFunChain {
|
||||
if logFunc == nil {
|
||||
_, _ = buf.Write(templateChain[i]) //nolint:errcheck // This will never fail
|
||||
} else if templateChain[i] == nil {
|
||||
_, err = logFunc(buf, c, data, "")
|
||||
} else {
|
||||
_, err = logFunc(buf, c, data, utils.UnsafeString(templateChain[i]))
|
||||
}
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
// Also write errors to the buffer
|
||||
if err != nil {
|
||||
_, _ = buf.WriteString(err.Error()) //nolint:errcheck // This will never fail
|
||||
}
|
||||
mu.Lock()
|
||||
// Write buffer to output
|
||||
if _, err := cfg.Output.Write(buf.Bytes()); err != nil {
|
||||
// Write error to output
|
||||
if _, err := cfg.Output.Write([]byte(err.Error())); err != nil {
|
||||
// There is something wrong with the given io.Writer
|
||||
_, _ = fmt.Fprintf(os.Stderr, "Failed to write to log, %v\n", err)
|
||||
}
|
||||
}
|
||||
mu.Unlock()
|
||||
|
||||
if cfg.Done != nil {
|
||||
cfg.Done(c, buf.Bytes())
|
||||
}
|
||||
|
||||
// Put buffer back to pool
|
||||
bytebufferpool.Put(buf)
|
||||
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func appendInt(output Buffer, v int) (int, error) {
|
||||
old := output.Len()
|
||||
output.Set(fasthttp.AppendUint(output.Bytes(), v))
|
||||
return output.Len() - old, nil
|
||||
}
|
||||
+211
@@ -0,0 +1,211 @@
|
||||
package logger
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
)
|
||||
|
||||
// Logger variables
|
||||
const (
|
||||
TagPid = "pid"
|
||||
TagTime = "time"
|
||||
TagReferer = "referer"
|
||||
TagProtocol = "protocol"
|
||||
TagPort = "port"
|
||||
TagIP = "ip"
|
||||
TagIPs = "ips"
|
||||
TagHost = "host"
|
||||
TagMethod = "method"
|
||||
TagPath = "path"
|
||||
TagURL = "url"
|
||||
TagUA = "ua"
|
||||
TagLatency = "latency"
|
||||
TagStatus = "status"
|
||||
TagResBody = "resBody"
|
||||
TagReqHeaders = "reqHeaders"
|
||||
TagQueryStringParams = "queryParams"
|
||||
TagBody = "body"
|
||||
TagBytesSent = "bytesSent"
|
||||
TagBytesReceived = "bytesReceived"
|
||||
TagRoute = "route"
|
||||
TagError = "error"
|
||||
// Deprecated: Use TagReqHeader instead
|
||||
TagHeader = "header:"
|
||||
TagReqHeader = "reqHeader:"
|
||||
TagRespHeader = "respHeader:"
|
||||
TagLocals = "locals:"
|
||||
TagQuery = "query:"
|
||||
TagForm = "form:"
|
||||
TagCookie = "cookie:"
|
||||
TagBlack = "black"
|
||||
TagRed = "red"
|
||||
TagGreen = "green"
|
||||
TagYellow = "yellow"
|
||||
TagBlue = "blue"
|
||||
TagMagenta = "magenta"
|
||||
TagCyan = "cyan"
|
||||
TagWhite = "white"
|
||||
TagReset = "reset"
|
||||
)
|
||||
|
||||
// createTagMap function merged the default with the custom tags
|
||||
func createTagMap(cfg *Config) map[string]LogFunc {
|
||||
// Set default tags
|
||||
tagFunctions := map[string]LogFunc{
|
||||
TagReferer: func(output Buffer, c *fiber.Ctx, data *Data, extraParam string) (int, error) {
|
||||
return output.WriteString(c.Get(fiber.HeaderReferer))
|
||||
},
|
||||
TagProtocol: func(output Buffer, c *fiber.Ctx, data *Data, extraParam string) (int, error) {
|
||||
return output.WriteString(c.Protocol())
|
||||
},
|
||||
TagPort: func(output Buffer, c *fiber.Ctx, data *Data, extraParam string) (int, error) {
|
||||
return output.WriteString(c.Port())
|
||||
},
|
||||
TagIP: func(output Buffer, c *fiber.Ctx, data *Data, extraParam string) (int, error) {
|
||||
return output.WriteString(c.IP())
|
||||
},
|
||||
TagIPs: func(output Buffer, c *fiber.Ctx, data *Data, extraParam string) (int, error) {
|
||||
return output.WriteString(c.Get(fiber.HeaderXForwardedFor))
|
||||
},
|
||||
TagHost: func(output Buffer, c *fiber.Ctx, data *Data, extraParam string) (int, error) {
|
||||
return output.WriteString(c.Hostname())
|
||||
},
|
||||
TagPath: func(output Buffer, c *fiber.Ctx, data *Data, extraParam string) (int, error) {
|
||||
return output.WriteString(c.Path())
|
||||
},
|
||||
TagURL: func(output Buffer, c *fiber.Ctx, data *Data, extraParam string) (int, error) {
|
||||
return output.WriteString(c.OriginalURL())
|
||||
},
|
||||
TagUA: func(output Buffer, c *fiber.Ctx, data *Data, extraParam string) (int, error) {
|
||||
return output.WriteString(c.Get(fiber.HeaderUserAgent))
|
||||
},
|
||||
TagBody: func(output Buffer, c *fiber.Ctx, data *Data, extraParam string) (int, error) {
|
||||
return output.Write(c.Body())
|
||||
},
|
||||
TagBytesReceived: func(output Buffer, c *fiber.Ctx, data *Data, extraParam string) (int, error) {
|
||||
return appendInt(output, len(c.Request().Body()))
|
||||
},
|
||||
TagBytesSent: func(output Buffer, c *fiber.Ctx, data *Data, extraParam string) (int, error) {
|
||||
if c.Response().Header.ContentLength() < 0 {
|
||||
return appendInt(output, 0)
|
||||
}
|
||||
return appendInt(output, len(c.Response().Body()))
|
||||
},
|
||||
TagRoute: func(output Buffer, c *fiber.Ctx, data *Data, extraParam string) (int, error) {
|
||||
return output.WriteString(c.Route().Path)
|
||||
},
|
||||
TagResBody: func(output Buffer, c *fiber.Ctx, data *Data, extraParam string) (int, error) {
|
||||
return output.Write(c.Response().Body())
|
||||
},
|
||||
TagReqHeaders: func(output Buffer, c *fiber.Ctx, data *Data, extraParam string) (int, error) {
|
||||
reqHeaders := make([]string, 0)
|
||||
for k, v := range c.GetReqHeaders() {
|
||||
reqHeaders = append(reqHeaders, k+"="+strings.Join(v, ","))
|
||||
}
|
||||
return output.Write([]byte(strings.Join(reqHeaders, "&")))
|
||||
},
|
||||
TagQueryStringParams: func(output Buffer, c *fiber.Ctx, data *Data, extraParam string) (int, error) {
|
||||
return output.WriteString(c.Request().URI().QueryArgs().String())
|
||||
},
|
||||
|
||||
TagBlack: func(output Buffer, c *fiber.Ctx, data *Data, extraParam string) (int, error) {
|
||||
return output.WriteString(c.App().Config().ColorScheme.Black)
|
||||
},
|
||||
TagRed: func(output Buffer, c *fiber.Ctx, data *Data, extraParam string) (int, error) {
|
||||
return output.WriteString(c.App().Config().ColorScheme.Red)
|
||||
},
|
||||
TagGreen: func(output Buffer, c *fiber.Ctx, data *Data, extraParam string) (int, error) {
|
||||
return output.WriteString(c.App().Config().ColorScheme.Green)
|
||||
},
|
||||
TagYellow: func(output Buffer, c *fiber.Ctx, data *Data, extraParam string) (int, error) {
|
||||
return output.WriteString(c.App().Config().ColorScheme.Yellow)
|
||||
},
|
||||
TagBlue: func(output Buffer, c *fiber.Ctx, data *Data, extraParam string) (int, error) {
|
||||
return output.WriteString(c.App().Config().ColorScheme.Blue)
|
||||
},
|
||||
TagMagenta: func(output Buffer, c *fiber.Ctx, data *Data, extraParam string) (int, error) {
|
||||
return output.WriteString(c.App().Config().ColorScheme.Magenta)
|
||||
},
|
||||
TagCyan: func(output Buffer, c *fiber.Ctx, data *Data, extraParam string) (int, error) {
|
||||
return output.WriteString(c.App().Config().ColorScheme.Cyan)
|
||||
},
|
||||
TagWhite: func(output Buffer, c *fiber.Ctx, data *Data, extraParam string) (int, error) {
|
||||
return output.WriteString(c.App().Config().ColorScheme.White)
|
||||
},
|
||||
TagReset: func(output Buffer, c *fiber.Ctx, data *Data, extraParam string) (int, error) {
|
||||
return output.WriteString(c.App().Config().ColorScheme.Reset)
|
||||
},
|
||||
TagError: func(output Buffer, c *fiber.Ctx, data *Data, extraParam string) (int, error) {
|
||||
if data.ChainErr != nil {
|
||||
if cfg.enableColors {
|
||||
colors := c.App().Config().ColorScheme
|
||||
return output.WriteString(fmt.Sprintf("%s%s%s", colors.Red, data.ChainErr.Error(), colors.Reset))
|
||||
}
|
||||
return output.WriteString(data.ChainErr.Error())
|
||||
}
|
||||
return output.WriteString("-")
|
||||
},
|
||||
TagReqHeader: func(output Buffer, c *fiber.Ctx, data *Data, extraParam string) (int, error) {
|
||||
return output.WriteString(c.Get(extraParam))
|
||||
},
|
||||
TagHeader: func(output Buffer, c *fiber.Ctx, data *Data, extraParam string) (int, error) {
|
||||
return output.WriteString(c.Get(extraParam))
|
||||
},
|
||||
TagRespHeader: func(output Buffer, c *fiber.Ctx, data *Data, extraParam string) (int, error) {
|
||||
return output.WriteString(c.GetRespHeader(extraParam))
|
||||
},
|
||||
TagQuery: func(output Buffer, c *fiber.Ctx, data *Data, extraParam string) (int, error) {
|
||||
return output.WriteString(c.Query(extraParam))
|
||||
},
|
||||
TagForm: func(output Buffer, c *fiber.Ctx, data *Data, extraParam string) (int, error) {
|
||||
return output.WriteString(c.FormValue(extraParam))
|
||||
},
|
||||
TagCookie: func(output Buffer, c *fiber.Ctx, data *Data, extraParam string) (int, error) {
|
||||
return output.WriteString(c.Cookies(extraParam))
|
||||
},
|
||||
TagLocals: func(output Buffer, c *fiber.Ctx, data *Data, extraParam string) (int, error) {
|
||||
switch v := c.Locals(extraParam).(type) {
|
||||
case []byte:
|
||||
return output.Write(v)
|
||||
case string:
|
||||
return output.WriteString(v)
|
||||
case nil:
|
||||
return 0, nil
|
||||
default:
|
||||
return output.WriteString(fmt.Sprintf("%v", v))
|
||||
}
|
||||
},
|
||||
TagStatus: func(output Buffer, c *fiber.Ctx, data *Data, extraParam string) (int, error) {
|
||||
if cfg.enableColors {
|
||||
colors := c.App().Config().ColorScheme
|
||||
return output.WriteString(fmt.Sprintf("%s%3d%s", statusColor(c.Response().StatusCode(), colors), c.Response().StatusCode(), colors.Reset))
|
||||
}
|
||||
return appendInt(output, c.Response().StatusCode())
|
||||
},
|
||||
TagMethod: func(output Buffer, c *fiber.Ctx, data *Data, extraParam string) (int, error) {
|
||||
if cfg.enableColors {
|
||||
colors := c.App().Config().ColorScheme
|
||||
return output.WriteString(fmt.Sprintf("%s%s%s", methodColor(c.Method(), colors), c.Method(), colors.Reset))
|
||||
}
|
||||
return output.WriteString(c.Method())
|
||||
},
|
||||
TagPid: func(output Buffer, c *fiber.Ctx, data *Data, extraParam string) (int, error) {
|
||||
return output.WriteString(data.Pid)
|
||||
},
|
||||
TagLatency: func(output Buffer, c *fiber.Ctx, data *Data, extraParam string) (int, error) {
|
||||
latency := data.Stop.Sub(data.Start)
|
||||
return output.WriteString(fmt.Sprintf("%13v", latency))
|
||||
},
|
||||
TagTime: func(output Buffer, c *fiber.Ctx, data *Data, extraParam string) (int, error) {
|
||||
return output.WriteString(data.Timestamp.Load().(string)) //nolint:forcetypeassert // We always store a string in here
|
||||
},
|
||||
}
|
||||
// merge with custom tags from user
|
||||
for k, v := range cfg.CustomTags {
|
||||
tagFunctions[k] = v
|
||||
}
|
||||
|
||||
return tagFunctions
|
||||
}
|
||||
+70
@@ -0,0 +1,70 @@
|
||||
package logger
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
|
||||
"github.com/gofiber/fiber/v2/utils"
|
||||
)
|
||||
|
||||
// buildLogFuncChain analyzes the template and creates slices with the functions for execution and
|
||||
// slices with the fixed parts of the template and the parameters
|
||||
//
|
||||
// fixParts contains the fixed parts of the template or parameters if a function is stored in the funcChain at this position
|
||||
// funcChain contains for the parts which exist the functions for the dynamic parts
|
||||
// funcChain and fixParts always have the same length and contain nil for the parts where no data is required in the chain,
|
||||
// if a function exists for the part, a parameter for it can also exist in the fixParts slice
|
||||
func buildLogFuncChain(cfg *Config, tagFunctions map[string]LogFunc) ([][]byte, []LogFunc, error) {
|
||||
// process flow is copied from the fasttemplate flow https://github.com/valyala/fasttemplate/blob/2a2d1afadadf9715bfa19683cdaeac8347e5d9f9/template.go#L23-L62
|
||||
templateB := utils.UnsafeBytes(cfg.Format)
|
||||
startTagB := utils.UnsafeBytes(startTag)
|
||||
endTagB := utils.UnsafeBytes(endTag)
|
||||
paramSeparatorB := utils.UnsafeBytes(paramSeparator)
|
||||
|
||||
var fixParts [][]byte
|
||||
var funcChain []LogFunc
|
||||
|
||||
for {
|
||||
currentPos := bytes.Index(templateB, startTagB)
|
||||
if currentPos < 0 {
|
||||
// no starting tag found in the existing template part
|
||||
break
|
||||
}
|
||||
// add fixed part
|
||||
funcChain = append(funcChain, nil)
|
||||
fixParts = append(fixParts, templateB[:currentPos])
|
||||
|
||||
templateB = templateB[currentPos+len(startTagB):]
|
||||
currentPos = bytes.Index(templateB, endTagB)
|
||||
if currentPos < 0 {
|
||||
// cannot find end tag - just write it to the output.
|
||||
funcChain = append(funcChain, nil)
|
||||
fixParts = append(fixParts, startTagB)
|
||||
break
|
||||
}
|
||||
// ## function block ##
|
||||
// first check for tags with parameters
|
||||
if index := bytes.Index(templateB[:currentPos], paramSeparatorB); index != -1 {
|
||||
logFunc, ok := tagFunctions[utils.UnsafeString(templateB[:index+1])]
|
||||
if !ok {
|
||||
return nil, nil, errors.New("No parameter found in \"" + utils.UnsafeString(templateB[:currentPos]) + "\"")
|
||||
}
|
||||
funcChain = append(funcChain, logFunc)
|
||||
// add param to the fixParts
|
||||
fixParts = append(fixParts, templateB[index+1:currentPos])
|
||||
} else if logFunc, ok := tagFunctions[utils.UnsafeString(templateB[:currentPos])]; ok {
|
||||
// add functions without parameter
|
||||
funcChain = append(funcChain, logFunc)
|
||||
fixParts = append(fixParts, nil)
|
||||
}
|
||||
// ## function block end ##
|
||||
|
||||
// reduce the template string
|
||||
templateB = templateB[currentPos+len(endTagB):]
|
||||
}
|
||||
// set the rest
|
||||
funcChain = append(funcChain, nil)
|
||||
fixParts = append(fixParts, templateB)
|
||||
|
||||
return fixParts, funcChain, nil
|
||||
}
|
||||
+39
@@ -0,0 +1,39 @@
|
||||
package logger
|
||||
|
||||
import (
|
||||
"github.com/gofiber/fiber/v2"
|
||||
)
|
||||
|
||||
func methodColor(method string, colors fiber.Colors) string {
|
||||
switch method {
|
||||
case fiber.MethodGet:
|
||||
return colors.Cyan
|
||||
case fiber.MethodPost:
|
||||
return colors.Green
|
||||
case fiber.MethodPut:
|
||||
return colors.Yellow
|
||||
case fiber.MethodDelete:
|
||||
return colors.Red
|
||||
case fiber.MethodPatch:
|
||||
return colors.White
|
||||
case fiber.MethodHead:
|
||||
return colors.Magenta
|
||||
case fiber.MethodOptions:
|
||||
return colors.Blue
|
||||
default:
|
||||
return colors.Reset
|
||||
}
|
||||
}
|
||||
|
||||
func statusColor(code int, colors fiber.Colors) string {
|
||||
switch {
|
||||
case code >= fiber.StatusOK && code < fiber.StatusMultipleChoices:
|
||||
return colors.Green
|
||||
case code >= fiber.StatusMultipleChoices && code < fiber.StatusBadRequest:
|
||||
return colors.Blue
|
||||
case code >= fiber.StatusBadRequest && code < fiber.StatusInternalServerError:
|
||||
return colors.Yellow
|
||||
default:
|
||||
return colors.Red
|
||||
}
|
||||
}
|
||||
+47
@@ -0,0 +1,47 @@
|
||||
package recover //nolint:predeclared // TODO: Rename to some non-builtin
|
||||
|
||||
import (
|
||||
"github.com/gofiber/fiber/v2"
|
||||
)
|
||||
|
||||
// Config defines the config for middleware.
|
||||
type Config struct {
|
||||
// Next defines a function to skip this middleware when returned true.
|
||||
//
|
||||
// Optional. Default: nil
|
||||
Next func(c *fiber.Ctx) bool
|
||||
|
||||
// EnableStackTrace enables handling stack trace
|
||||
//
|
||||
// Optional. Default: false
|
||||
EnableStackTrace bool
|
||||
|
||||
// StackTraceHandler defines a function to handle stack trace
|
||||
//
|
||||
// Optional. Default: defaultStackTraceHandler
|
||||
StackTraceHandler func(c *fiber.Ctx, e interface{})
|
||||
}
|
||||
|
||||
// ConfigDefault is the default config
|
||||
var ConfigDefault = Config{
|
||||
Next: nil,
|
||||
EnableStackTrace: false,
|
||||
StackTraceHandler: defaultStackTraceHandler,
|
||||
}
|
||||
|
||||
// Helper function to set default values
|
||||
func configDefault(config ...Config) Config {
|
||||
// Return default config if nothing provided
|
||||
if len(config) < 1 {
|
||||
return ConfigDefault
|
||||
}
|
||||
|
||||
// Override default config
|
||||
cfg := config[0]
|
||||
|
||||
if cfg.EnableStackTrace && cfg.StackTraceHandler == nil {
|
||||
cfg.StackTraceHandler = defaultStackTraceHandler
|
||||
}
|
||||
|
||||
return cfg
|
||||
}
|
||||
+45
@@ -0,0 +1,45 @@
|
||||
package recover //nolint:predeclared // TODO: Rename to some non-builtin
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"runtime/debug"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
)
|
||||
|
||||
func defaultStackTraceHandler(_ *fiber.Ctx, e interface{}) {
|
||||
_, _ = os.Stderr.WriteString(fmt.Sprintf("panic: %v\n%s\n", e, debug.Stack())) //nolint:errcheck // This will never fail
|
||||
}
|
||||
|
||||
// New creates a new middleware handler
|
||||
func New(config ...Config) fiber.Handler {
|
||||
// Set default config
|
||||
cfg := configDefault(config...)
|
||||
|
||||
// Return new handler
|
||||
return func(c *fiber.Ctx) (err error) { //nolint:nonamedreturns // Uses recover() to overwrite the error
|
||||
// Don't execute middleware if Next returns true
|
||||
if cfg.Next != nil && cfg.Next(c) {
|
||||
return c.Next()
|
||||
}
|
||||
|
||||
// Catch panics
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
if cfg.EnableStackTrace {
|
||||
cfg.StackTraceHandler(c, r)
|
||||
}
|
||||
|
||||
var ok bool
|
||||
if err, ok = r.(error); !ok {
|
||||
// Set error that will call the global error handler
|
||||
err = fmt.Errorf("%v", r)
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
// Return err if exist, else move to next handler
|
||||
return c.Next()
|
||||
}
|
||||
}
|
||||
+230
@@ -0,0 +1,230 @@
|
||||
// ⚡️ Fiber is an Express inspired web framework written in Go with ☕️
|
||||
// 🤖 Github Repository: https://github.com/gofiber/fiber
|
||||
// 📌 API Documentation: https://docs.gofiber.io
|
||||
|
||||
package fiber
|
||||
|
||||
import (
|
||||
"sort"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
)
|
||||
|
||||
// Put fields related to mounting.
|
||||
type mountFields struct {
|
||||
// Mounted and main apps
|
||||
appList map[string]*App
|
||||
// Ordered keys of apps (sorted by key length for Render)
|
||||
appListKeys []string
|
||||
// check added routes of sub-apps
|
||||
subAppsRoutesAdded sync.Once
|
||||
// check mounted sub-apps
|
||||
subAppsProcessed sync.Once
|
||||
// Prefix of app if it was mounted
|
||||
mountPath string
|
||||
}
|
||||
|
||||
// Create empty mountFields instance
|
||||
func newMountFields(app *App) *mountFields {
|
||||
return &mountFields{
|
||||
appList: map[string]*App{"": app},
|
||||
appListKeys: make([]string, 0),
|
||||
}
|
||||
}
|
||||
|
||||
// Mount attaches another app instance as a sub-router along a routing path.
|
||||
// It's very useful to split up a large API as many independent routers and
|
||||
// compose them as a single service using Mount. The fiber's error handler and
|
||||
// any of the fiber's sub apps are added to the application's error handlers
|
||||
// to be invoked on errors that happen within the prefix route.
|
||||
func (app *App) Mount(prefix string, subApp *App) Router {
|
||||
prefix = strings.TrimRight(prefix, "/")
|
||||
if prefix == "" {
|
||||
prefix = "/"
|
||||
}
|
||||
|
||||
// Support for configs of mounted-apps and sub-mounted-apps
|
||||
for mountedPrefixes, subApp := range subApp.mountFields.appList {
|
||||
path := getGroupPath(prefix, mountedPrefixes)
|
||||
|
||||
subApp.mountFields.mountPath = path
|
||||
app.mountFields.appList[path] = subApp
|
||||
}
|
||||
|
||||
// register mounted group
|
||||
mountGroup := &Group{Prefix: prefix, app: subApp}
|
||||
app.register(methodUse, prefix, mountGroup)
|
||||
|
||||
// Execute onMount hooks
|
||||
if err := subApp.hooks.executeOnMountHooks(app); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
return app
|
||||
}
|
||||
|
||||
// Mount attaches another app instance as a sub-router along a routing path.
|
||||
// It's very useful to split up a large API as many independent routers and
|
||||
// compose them as a single service using Mount.
|
||||
func (grp *Group) Mount(prefix string, subApp *App) Router {
|
||||
groupPath := getGroupPath(grp.Prefix, prefix)
|
||||
groupPath = strings.TrimRight(groupPath, "/")
|
||||
if groupPath == "" {
|
||||
groupPath = "/"
|
||||
}
|
||||
|
||||
// Support for configs of mounted-apps and sub-mounted-apps
|
||||
for mountedPrefixes, subApp := range subApp.mountFields.appList {
|
||||
path := getGroupPath(groupPath, mountedPrefixes)
|
||||
|
||||
subApp.mountFields.mountPath = path
|
||||
grp.app.mountFields.appList[path] = subApp
|
||||
}
|
||||
|
||||
// register mounted group
|
||||
mountGroup := &Group{Prefix: groupPath, app: subApp}
|
||||
grp.app.register(methodUse, groupPath, mountGroup)
|
||||
|
||||
// Execute onMount hooks
|
||||
if err := subApp.hooks.executeOnMountHooks(grp.app); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
return grp
|
||||
}
|
||||
|
||||
// The MountPath property contains one or more path patterns on which a sub-app was mounted.
|
||||
func (app *App) MountPath() string {
|
||||
return app.mountFields.mountPath
|
||||
}
|
||||
|
||||
// hasMountedApps Checks if there are any mounted apps in the current application.
|
||||
func (app *App) hasMountedApps() bool {
|
||||
return len(app.mountFields.appList) > 1
|
||||
}
|
||||
|
||||
// mountStartupProcess Handles the startup process of mounted apps by appending sub-app routes, generating app list keys, and processing sub-app routes.
|
||||
func (app *App) mountStartupProcess() {
|
||||
if app.hasMountedApps() {
|
||||
// add routes of sub-apps
|
||||
app.mountFields.subAppsProcessed.Do(func() {
|
||||
app.appendSubAppLists(app.mountFields.appList)
|
||||
app.generateAppListKeys()
|
||||
})
|
||||
// adds the routes of the sub-apps to the current application.
|
||||
app.mountFields.subAppsRoutesAdded.Do(func() {
|
||||
app.processSubAppsRoutes()
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// generateAppListKeys generates app list keys for Render, should work after appendSubAppLists
|
||||
func (app *App) generateAppListKeys() {
|
||||
for key := range app.mountFields.appList {
|
||||
app.mountFields.appListKeys = append(app.mountFields.appListKeys, key)
|
||||
}
|
||||
|
||||
sort.Slice(app.mountFields.appListKeys, func(i, j int) bool {
|
||||
return len(app.mountFields.appListKeys[i]) < len(app.mountFields.appListKeys[j])
|
||||
})
|
||||
}
|
||||
|
||||
// appendSubAppLists supports nested for sub apps
|
||||
func (app *App) appendSubAppLists(appList map[string]*App, parent ...string) {
|
||||
// Optimize: Cache parent prefix
|
||||
parentPrefix := ""
|
||||
if len(parent) > 0 {
|
||||
parentPrefix = parent[0]
|
||||
}
|
||||
|
||||
for prefix, subApp := range appList {
|
||||
// skip real app
|
||||
if prefix == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
if parentPrefix != "" {
|
||||
prefix = getGroupPath(parentPrefix, prefix)
|
||||
}
|
||||
|
||||
if _, ok := app.mountFields.appList[prefix]; !ok {
|
||||
app.mountFields.appList[prefix] = subApp
|
||||
}
|
||||
|
||||
// The first element of appList is always the app itself. If there are no other sub apps, we should skip appending nested apps.
|
||||
if len(subApp.mountFields.appList) > 1 {
|
||||
app.appendSubAppLists(subApp.mountFields.appList, prefix)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// processSubAppsRoutes adds routes of sub-apps recursively when the server is started
|
||||
func (app *App) processSubAppsRoutes() {
|
||||
for prefix, subApp := range app.mountFields.appList {
|
||||
// skip real app
|
||||
if prefix == "" {
|
||||
continue
|
||||
}
|
||||
// process the inner routes
|
||||
if subApp.hasMountedApps() {
|
||||
subApp.mountFields.subAppsRoutesAdded.Do(func() {
|
||||
subApp.processSubAppsRoutes()
|
||||
})
|
||||
}
|
||||
}
|
||||
var handlersCount uint32
|
||||
var routePos uint32
|
||||
// Iterate over the stack of the parent app
|
||||
for m := range app.stack {
|
||||
// Iterate over each route in the stack
|
||||
stackLen := len(app.stack[m])
|
||||
for i := 0; i < stackLen; i++ {
|
||||
route := app.stack[m][i]
|
||||
// Check if the route has a mounted app
|
||||
if !route.mount {
|
||||
routePos++
|
||||
// If not, update the route's position and continue
|
||||
route.pos = routePos
|
||||
if !route.use || (route.use && m == 0) {
|
||||
handlersCount += uint32(len(route.Handlers))
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
// Create a slice to hold the sub-app's routes
|
||||
subRoutes := make([]*Route, len(route.group.app.stack[m]))
|
||||
|
||||
// Iterate over the sub-app's routes
|
||||
for j, subAppRoute := range route.group.app.stack[m] {
|
||||
// Clone the sub-app's route
|
||||
subAppRouteClone := app.copyRoute(subAppRoute)
|
||||
|
||||
// Add the parent route's path as a prefix to the sub-app's route
|
||||
app.addPrefixToRoute(route.path, subAppRouteClone)
|
||||
|
||||
// Add the cloned sub-app's route to the slice of sub-app routes
|
||||
subRoutes[j] = subAppRouteClone
|
||||
}
|
||||
|
||||
// Insert the sub-app's routes into the parent app's stack
|
||||
newStack := make([]*Route, len(app.stack[m])+len(subRoutes)-1)
|
||||
copy(newStack[:i], app.stack[m][:i])
|
||||
copy(newStack[i:i+len(subRoutes)], subRoutes)
|
||||
copy(newStack[i+len(subRoutes):], app.stack[m][i+1:])
|
||||
app.stack[m] = newStack
|
||||
|
||||
// Decrease the parent app's route count to account for the mounted app's original route
|
||||
atomic.AddUint32(&app.routesCount, ^uint32(0))
|
||||
i--
|
||||
// Increase the parent app's route count to account for the sub-app's routes
|
||||
atomic.AddUint32(&app.routesCount, uint32(len(subRoutes)))
|
||||
|
||||
// Mark the parent app's routes as refreshed
|
||||
app.routesRefreshed = true
|
||||
// update stackLen after appending subRoutes to app.stack[m]
|
||||
stackLen = len(app.stack[m])
|
||||
}
|
||||
}
|
||||
atomic.StoreUint32(&app.handlersCount, handlersCount)
|
||||
}
|
||||
+740
@@ -0,0 +1,740 @@
|
||||
// ⚡️ Fiber is an Express inspired web framework written in Go with ☕️
|
||||
// 📄 Github Repository: https://github.com/gofiber/fiber
|
||||
// 📌 API Documentation: https://docs.gofiber.io
|
||||
// ⚠️ This path parser was inspired by ucarion/urlpath (MIT License).
|
||||
// 💖 Maintained and modified for Fiber by @renewerner87
|
||||
|
||||
package fiber
|
||||
|
||||
import (
|
||||
"regexp"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
"unicode"
|
||||
|
||||
"github.com/google/uuid"
|
||||
|
||||
"github.com/gofiber/fiber/v2/utils"
|
||||
)
|
||||
|
||||
// routeParser holds the path segments and param names
|
||||
type routeParser struct {
|
||||
segs []*routeSegment // the parsed segments of the route
|
||||
params []string // that parameter names the parsed route
|
||||
wildCardCount int // number of wildcard parameters, used internally to give the wildcard parameter its number
|
||||
plusCount int // number of plus parameters, used internally to give the plus parameter its number
|
||||
}
|
||||
|
||||
// paramsSeg holds the segment metadata
|
||||
type routeSegment struct {
|
||||
// const information
|
||||
Const string // constant part of the route
|
||||
// parameter information
|
||||
IsParam bool // Truth value that indicates whether it is a parameter or a constant part
|
||||
ParamName string // name of the parameter for access to it, for wildcards and plus parameters access iterators starting with 1 are added
|
||||
ComparePart string // search part to find the end of the parameter
|
||||
PartCount int // how often is the search part contained in the non-param segments? -> necessary for greedy search
|
||||
IsGreedy bool // indicates whether the parameter is greedy or not, is used with wildcard and plus
|
||||
IsOptional bool // indicates whether the parameter is optional or not
|
||||
// common information
|
||||
IsLast bool // shows if the segment is the last one for the route
|
||||
HasOptionalSlash bool // segment has the possibility of an optional slash
|
||||
Constraints []*Constraint // Constraint type if segment is a parameter, if not it will be set to noConstraint by default
|
||||
Length int // length of the parameter for segment, when its 0 then the length is undetermined
|
||||
// future TODO: add support for optional groups "/abc(/def)?"
|
||||
}
|
||||
|
||||
// different special routing signs
|
||||
const (
|
||||
wildcardParam byte = '*' // indicates an optional greedy parameter
|
||||
plusParam byte = '+' // indicates a required greedy parameter
|
||||
optionalParam byte = '?' // concludes a parameter by name and makes it optional
|
||||
paramStarterChar byte = ':' // start character for a parameter with name
|
||||
slashDelimiter byte = '/' // separator for the route, unlike the other delimiters this character at the end can be optional
|
||||
escapeChar byte = '\\' // escape character
|
||||
paramConstraintStart byte = '<' // start of type constraint for a parameter
|
||||
paramConstraintEnd byte = '>' // end of type constraint for a parameter
|
||||
paramConstraintSeparator byte = ';' // separator of type constraints for a parameter
|
||||
paramConstraintDataStart byte = '(' // start of data of type constraint for a parameter
|
||||
paramConstraintDataEnd byte = ')' // end of data of type constraint for a parameter
|
||||
paramConstraintDataSeparator byte = ',' // separator of datas of type constraint for a parameter
|
||||
)
|
||||
|
||||
// TypeConstraint parameter constraint types
|
||||
type TypeConstraint int16
|
||||
|
||||
type Constraint struct {
|
||||
ID TypeConstraint
|
||||
RegexCompiler *regexp.Regexp
|
||||
Data []string
|
||||
}
|
||||
|
||||
const (
|
||||
noConstraint TypeConstraint = iota + 1
|
||||
intConstraint
|
||||
boolConstraint
|
||||
floatConstraint
|
||||
alphaConstraint
|
||||
datetimeConstraint
|
||||
guidConstraint
|
||||
minLenConstraint
|
||||
maxLenConstraint
|
||||
lenConstraint
|
||||
betweenLenConstraint
|
||||
minConstraint
|
||||
maxConstraint
|
||||
rangeConstraint
|
||||
regexConstraint
|
||||
)
|
||||
|
||||
// list of possible parameter and segment delimiter
|
||||
var (
|
||||
// slash has a special role, unlike the other parameters it must not be interpreted as a parameter
|
||||
routeDelimiter = []byte{slashDelimiter, '-', '.'}
|
||||
// list of greedy parameters
|
||||
greedyParameters = []byte{wildcardParam, plusParam}
|
||||
// list of chars for the parameter recognizing
|
||||
parameterStartChars = []byte{wildcardParam, plusParam, paramStarterChar}
|
||||
// list of chars of delimiters and the starting parameter name char
|
||||
parameterDelimiterChars = append([]byte{paramStarterChar, escapeChar}, routeDelimiter...)
|
||||
// list of chars to find the end of a parameter
|
||||
parameterEndChars = append([]byte{optionalParam}, parameterDelimiterChars...)
|
||||
// list of parameter constraint start
|
||||
parameterConstraintStartChars = []byte{paramConstraintStart}
|
||||
// list of parameter constraint end
|
||||
parameterConstraintEndChars = []byte{paramConstraintEnd}
|
||||
// list of parameter separator
|
||||
parameterConstraintSeparatorChars = []byte{paramConstraintSeparator}
|
||||
// list of parameter constraint data start
|
||||
parameterConstraintDataStartChars = []byte{paramConstraintDataStart}
|
||||
// list of parameter constraint data end
|
||||
parameterConstraintDataEndChars = []byte{paramConstraintDataEnd}
|
||||
// list of parameter constraint data separator
|
||||
parameterConstraintDataSeparatorChars = []byte{paramConstraintDataSeparator}
|
||||
)
|
||||
|
||||
// RoutePatternMatch checks if a given path matches a Fiber route pattern.
|
||||
func RoutePatternMatch(path, pattern string, cfg ...Config) bool {
|
||||
// See logic in (*Route).match and (*App).register
|
||||
var ctxParams [maxParams]string
|
||||
|
||||
config := Config{}
|
||||
if len(cfg) > 0 {
|
||||
config = cfg[0]
|
||||
}
|
||||
|
||||
if path == "" {
|
||||
path = "/"
|
||||
}
|
||||
|
||||
// Cannot have an empty pattern
|
||||
if pattern == "" {
|
||||
pattern = "/"
|
||||
}
|
||||
// Pattern always start with a '/'
|
||||
if pattern[0] != '/' {
|
||||
pattern = "/" + pattern
|
||||
}
|
||||
|
||||
patternPretty := pattern
|
||||
|
||||
// Case-sensitive routing, all to lowercase
|
||||
if !config.CaseSensitive {
|
||||
patternPretty = utils.ToLower(patternPretty)
|
||||
path = utils.ToLower(path)
|
||||
}
|
||||
// Strict routing, remove trailing slashes
|
||||
if !config.StrictRouting && len(patternPretty) > 1 {
|
||||
patternPretty = utils.TrimRight(patternPretty, '/')
|
||||
}
|
||||
|
||||
parser := parseRoute(patternPretty)
|
||||
|
||||
if patternPretty == "/" && path == "/" {
|
||||
return true
|
||||
// '*' wildcard matches any path
|
||||
} else if patternPretty == "/*" {
|
||||
return true
|
||||
}
|
||||
|
||||
// Does this route have parameters
|
||||
if len(parser.params) > 0 {
|
||||
if match := parser.getMatch(path, path, &ctxParams, false); match {
|
||||
return true
|
||||
}
|
||||
}
|
||||
// Check for a simple match
|
||||
patternPretty = RemoveEscapeChar(patternPretty)
|
||||
if len(patternPretty) == len(path) && patternPretty == path {
|
||||
return true
|
||||
}
|
||||
// No match
|
||||
return false
|
||||
}
|
||||
|
||||
// parseRoute analyzes the route and divides it into segments for constant areas and parameters,
|
||||
// this information is needed later when assigning the requests to the declared routes
|
||||
func parseRoute(pattern string) routeParser {
|
||||
parser := routeParser{}
|
||||
|
||||
part := ""
|
||||
for len(pattern) > 0 {
|
||||
nextParamPosition := findNextParamPosition(pattern)
|
||||
// handle the parameter part
|
||||
if nextParamPosition == 0 {
|
||||
processedPart, seg := parser.analyseParameterPart(pattern)
|
||||
parser.params, parser.segs, part = append(parser.params, seg.ParamName), append(parser.segs, seg), processedPart
|
||||
} else {
|
||||
processedPart, seg := parser.analyseConstantPart(pattern, nextParamPosition)
|
||||
parser.segs, part = append(parser.segs, seg), processedPart
|
||||
}
|
||||
|
||||
// reduce the pattern by the processed parts
|
||||
if len(part) == len(pattern) {
|
||||
break
|
||||
}
|
||||
pattern = pattern[len(part):]
|
||||
}
|
||||
// mark last segment
|
||||
if len(parser.segs) > 0 {
|
||||
parser.segs[len(parser.segs)-1].IsLast = true
|
||||
}
|
||||
parser.segs = addParameterMetaInfo(parser.segs)
|
||||
|
||||
return parser
|
||||
}
|
||||
|
||||
// addParameterMetaInfo add important meta information to the parameter segments
|
||||
// to simplify the search for the end of the parameter
|
||||
func addParameterMetaInfo(segs []*routeSegment) []*routeSegment {
|
||||
var comparePart string
|
||||
segLen := len(segs)
|
||||
// loop from end to begin
|
||||
for i := segLen - 1; i >= 0; i-- {
|
||||
// set the compare part for the parameter
|
||||
if segs[i].IsParam {
|
||||
// important for finding the end of the parameter
|
||||
segs[i].ComparePart = RemoveEscapeChar(comparePart)
|
||||
} else {
|
||||
comparePart = segs[i].Const
|
||||
if len(comparePart) > 1 {
|
||||
comparePart = utils.TrimRight(comparePart, slashDelimiter)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// loop from begin to end
|
||||
for i := 0; i < segLen; i++ {
|
||||
// check how often the compare part is in the following const parts
|
||||
if segs[i].IsParam {
|
||||
// check if parameter segments are directly after each other and if one of them is greedy
|
||||
// in case the next parameter or the current parameter is not a wildcard it's not greedy, we only want one character
|
||||
if segLen > i+1 && !segs[i].IsGreedy && segs[i+1].IsParam && !segs[i+1].IsGreedy {
|
||||
segs[i].Length = 1
|
||||
}
|
||||
if segs[i].ComparePart == "" {
|
||||
continue
|
||||
}
|
||||
for j := i + 1; j <= len(segs)-1; j++ {
|
||||
if !segs[j].IsParam {
|
||||
// count is important for the greedy match
|
||||
segs[i].PartCount += strings.Count(segs[j].Const, segs[i].ComparePart)
|
||||
}
|
||||
}
|
||||
// check if the end of the segment is a optional slash and then if the segement is optional or the last one
|
||||
} else if segs[i].Const[len(segs[i].Const)-1] == slashDelimiter && (segs[i].IsLast || (segLen > i+1 && segs[i+1].IsOptional)) {
|
||||
segs[i].HasOptionalSlash = true
|
||||
}
|
||||
}
|
||||
|
||||
return segs
|
||||
}
|
||||
|
||||
// findNextParamPosition search for the next possible parameter start position
|
||||
func findNextParamPosition(pattern string) int {
|
||||
nextParamPosition := findNextNonEscapedCharsetPosition(pattern, parameterStartChars)
|
||||
if nextParamPosition != -1 && len(pattern) > nextParamPosition && pattern[nextParamPosition] != wildcardParam {
|
||||
// search for parameter characters for the found parameter start,
|
||||
// if there are more, move the parameter start to the last parameter char
|
||||
for found := findNextNonEscapedCharsetPosition(pattern[nextParamPosition+1:], parameterStartChars); found == 0; {
|
||||
nextParamPosition++
|
||||
if len(pattern) > nextParamPosition {
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nextParamPosition
|
||||
}
|
||||
|
||||
// analyseConstantPart find the end of the constant part and create the route segment
|
||||
func (*routeParser) analyseConstantPart(pattern string, nextParamPosition int) (string, *routeSegment) {
|
||||
// handle the constant part
|
||||
processedPart := pattern
|
||||
if nextParamPosition != -1 {
|
||||
// remove the constant part until the parameter
|
||||
processedPart = pattern[:nextParamPosition]
|
||||
}
|
||||
constPart := RemoveEscapeChar(processedPart)
|
||||
return processedPart, &routeSegment{
|
||||
Const: constPart,
|
||||
Length: len(constPart),
|
||||
}
|
||||
}
|
||||
|
||||
// analyseParameterPart find the parameter end and create the route segment
|
||||
func (routeParser *routeParser) analyseParameterPart(pattern string) (string, *routeSegment) {
|
||||
isWildCard := pattern[0] == wildcardParam
|
||||
isPlusParam := pattern[0] == plusParam
|
||||
|
||||
var parameterEndPosition int
|
||||
if strings.ContainsRune(pattern, rune(paramConstraintStart)) && strings.ContainsRune(pattern, rune(paramConstraintEnd)) {
|
||||
parameterEndPosition = findNextCharsetPositionConstraint(pattern[1:], parameterEndChars)
|
||||
} else {
|
||||
parameterEndPosition = findNextNonEscapedCharsetPosition(pattern[1:], parameterEndChars)
|
||||
}
|
||||
|
||||
parameterConstraintStart := -1
|
||||
parameterConstraintEnd := -1
|
||||
// handle wildcard end
|
||||
switch {
|
||||
case isWildCard, isPlusParam:
|
||||
parameterEndPosition = 0
|
||||
case parameterEndPosition == -1:
|
||||
parameterEndPosition = len(pattern) - 1
|
||||
case !isInCharset(pattern[parameterEndPosition+1], parameterDelimiterChars):
|
||||
parameterEndPosition++
|
||||
}
|
||||
|
||||
// find constraint part if exists in the parameter part and remove it
|
||||
if parameterEndPosition > 0 {
|
||||
parameterConstraintStart = findNextNonEscapedCharsetPosition(pattern[0:parameterEndPosition], parameterConstraintStartChars)
|
||||
parameterConstraintEnd = findLastCharsetPosition(pattern[0:parameterEndPosition+1], parameterConstraintEndChars)
|
||||
}
|
||||
|
||||
// cut params part
|
||||
processedPart := pattern[0 : parameterEndPosition+1]
|
||||
paramName := RemoveEscapeChar(GetTrimmedParam(processedPart))
|
||||
|
||||
// Check has constraint
|
||||
var constraints []*Constraint
|
||||
|
||||
if hasConstraint := parameterConstraintStart != -1 && parameterConstraintEnd != -1; hasConstraint {
|
||||
constraintString := pattern[parameterConstraintStart+1 : parameterConstraintEnd]
|
||||
userConstraints := splitNonEscaped(constraintString, string(parameterConstraintSeparatorChars))
|
||||
constraints = make([]*Constraint, 0, len(userConstraints))
|
||||
|
||||
for _, c := range userConstraints {
|
||||
start := findNextNonEscapedCharsetPosition(c, parameterConstraintDataStartChars)
|
||||
end := findLastCharsetPosition(c, parameterConstraintDataEndChars)
|
||||
|
||||
// Assign constraint
|
||||
if start != -1 && end != -1 {
|
||||
constraint := &Constraint{
|
||||
ID: getParamConstraintType(c[:start]),
|
||||
}
|
||||
|
||||
// remove escapes from data
|
||||
if constraint.ID != regexConstraint {
|
||||
constraint.Data = splitNonEscaped(c[start+1:end], string(parameterConstraintDataSeparatorChars))
|
||||
if len(constraint.Data) == 1 {
|
||||
constraint.Data[0] = RemoveEscapeChar(constraint.Data[0])
|
||||
} else if len(constraint.Data) == 2 { // This is fine, we simply expect two parts
|
||||
constraint.Data[0] = RemoveEscapeChar(constraint.Data[0])
|
||||
constraint.Data[1] = RemoveEscapeChar(constraint.Data[1])
|
||||
}
|
||||
}
|
||||
|
||||
// Precompile regex if has regex constraint
|
||||
if constraint.ID == regexConstraint {
|
||||
constraint.Data = []string{c[start+1 : end]}
|
||||
constraint.RegexCompiler = regexp.MustCompile(constraint.Data[0])
|
||||
}
|
||||
|
||||
constraints = append(constraints, constraint)
|
||||
} else {
|
||||
constraints = append(constraints, &Constraint{
|
||||
ID: getParamConstraintType(c),
|
||||
Data: []string{},
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
paramName = RemoveEscapeChar(GetTrimmedParam(pattern[0:parameterConstraintStart]))
|
||||
}
|
||||
|
||||
// add access iterator to wildcard and plus
|
||||
if isWildCard {
|
||||
routeParser.wildCardCount++
|
||||
paramName += strconv.Itoa(routeParser.wildCardCount)
|
||||
} else if isPlusParam {
|
||||
routeParser.plusCount++
|
||||
paramName += strconv.Itoa(routeParser.plusCount)
|
||||
}
|
||||
|
||||
segment := &routeSegment{
|
||||
ParamName: paramName,
|
||||
IsParam: true,
|
||||
IsOptional: isWildCard || pattern[parameterEndPosition] == optionalParam,
|
||||
IsGreedy: isWildCard || isPlusParam,
|
||||
}
|
||||
|
||||
if len(constraints) > 0 {
|
||||
segment.Constraints = constraints
|
||||
}
|
||||
|
||||
return processedPart, segment
|
||||
}
|
||||
|
||||
// isInCharset check is the given character in the charset list
|
||||
func isInCharset(searchChar byte, charset []byte) bool {
|
||||
for _, char := range charset {
|
||||
if char == searchChar {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// findNextCharsetPosition search the next char position from the charset
|
||||
func findNextCharsetPosition(search string, charset []byte) int {
|
||||
nextPosition := -1
|
||||
for _, char := range charset {
|
||||
if pos := strings.IndexByte(search, char); pos != -1 && (pos < nextPosition || nextPosition == -1) {
|
||||
nextPosition = pos
|
||||
}
|
||||
}
|
||||
|
||||
return nextPosition
|
||||
}
|
||||
|
||||
// findNextCharsetPosition search the last char position from the charset
|
||||
func findLastCharsetPosition(search string, charset []byte) int {
|
||||
lastPosition := -1
|
||||
for _, char := range charset {
|
||||
if pos := strings.LastIndexByte(search, char); pos != -1 && (pos < lastPosition || lastPosition == -1) {
|
||||
lastPosition = pos
|
||||
}
|
||||
}
|
||||
|
||||
return lastPosition
|
||||
}
|
||||
|
||||
// findNextCharsetPositionConstraint search the next char position from the charset
|
||||
// unlike findNextCharsetPosition, it takes care of constraint start-end chars to parse route pattern
|
||||
func findNextCharsetPositionConstraint(search string, charset []byte) int {
|
||||
constraintStart := findNextNonEscapedCharsetPosition(search, parameterConstraintStartChars)
|
||||
constraintEnd := findNextNonEscapedCharsetPosition(search, parameterConstraintEndChars)
|
||||
nextPosition := -1
|
||||
|
||||
for _, char := range charset {
|
||||
pos := strings.IndexByte(search, char)
|
||||
|
||||
if pos != -1 && (pos < nextPosition || nextPosition == -1) {
|
||||
if (pos > constraintStart && pos > constraintEnd) || (pos < constraintStart && pos < constraintEnd) {
|
||||
nextPosition = pos
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nextPosition
|
||||
}
|
||||
|
||||
// findNextNonEscapedCharsetPosition search the next char position from the charset and skip the escaped characters
|
||||
func findNextNonEscapedCharsetPosition(search string, charset []byte) int {
|
||||
pos := findNextCharsetPosition(search, charset)
|
||||
for pos > 0 && search[pos-1] == escapeChar {
|
||||
if len(search) == pos+1 {
|
||||
// escaped character is at the end
|
||||
return -1
|
||||
}
|
||||
nextPossiblePos := findNextCharsetPosition(search[pos+1:], charset)
|
||||
if nextPossiblePos == -1 {
|
||||
return -1
|
||||
}
|
||||
// the previous character is taken into consideration
|
||||
pos = nextPossiblePos + pos + 1
|
||||
}
|
||||
|
||||
return pos
|
||||
}
|
||||
|
||||
// splitNonEscaped slices s into all substrings separated by sep and returns a slice of the substrings between those separators
|
||||
// This function also takes a care of escape char when splitting.
|
||||
func splitNonEscaped(s, sep string) []string {
|
||||
var result []string
|
||||
i := findNextNonEscapedCharsetPosition(s, []byte(sep))
|
||||
|
||||
for i > -1 {
|
||||
result = append(result, s[:i])
|
||||
s = s[i+len(sep):]
|
||||
i = findNextNonEscapedCharsetPosition(s, []byte(sep))
|
||||
}
|
||||
|
||||
return append(result, s)
|
||||
}
|
||||
|
||||
// getMatch parses the passed url and tries to match it against the route segments and determine the parameter positions
|
||||
func (routeParser *routeParser) getMatch(detectionPath, path string, params *[maxParams]string, partialCheck bool) bool { //nolint: revive // Accepting a bool param is fine here
|
||||
var i, paramsIterator, partLen int
|
||||
for _, segment := range routeParser.segs {
|
||||
partLen = len(detectionPath)
|
||||
// check const segment
|
||||
if !segment.IsParam {
|
||||
i = segment.Length
|
||||
// is optional part or the const part must match with the given string
|
||||
// check if the end of the segment is an optional slash
|
||||
if segment.HasOptionalSlash && partLen == i-1 && detectionPath == segment.Const[:i-1] {
|
||||
i--
|
||||
} else if !(i <= partLen && detectionPath[:i] == segment.Const) {
|
||||
return false
|
||||
}
|
||||
} else {
|
||||
// determine parameter length
|
||||
i = findParamLen(detectionPath, segment)
|
||||
if !segment.IsOptional && i == 0 {
|
||||
return false
|
||||
}
|
||||
// take over the params positions
|
||||
params[paramsIterator] = path[:i]
|
||||
|
||||
if !(segment.IsOptional && i == 0) {
|
||||
// check constraint
|
||||
for _, c := range segment.Constraints {
|
||||
if matched := c.CheckConstraint(params[paramsIterator]); !matched {
|
||||
return false
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
paramsIterator++
|
||||
}
|
||||
|
||||
// reduce founded part from the string
|
||||
if partLen > 0 {
|
||||
detectionPath, path = detectionPath[i:], path[i:]
|
||||
}
|
||||
}
|
||||
if detectionPath != "" && !partialCheck {
|
||||
return false
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
// findParamLen for the expressjs wildcard behavior (right to left greedy)
|
||||
// look at the other segments and take what is left for the wildcard from right to left
|
||||
func findParamLen(s string, segment *routeSegment) int {
|
||||
if segment.IsLast {
|
||||
return findParamLenForLastSegment(s, segment)
|
||||
}
|
||||
|
||||
if segment.Length != 0 && len(s) >= segment.Length {
|
||||
return segment.Length
|
||||
} else if segment.IsGreedy {
|
||||
// Search the parameters until the next constant part
|
||||
// special logic for greedy params
|
||||
searchCount := strings.Count(s, segment.ComparePart)
|
||||
if searchCount > 1 {
|
||||
return findGreedyParamLen(s, searchCount, segment)
|
||||
}
|
||||
}
|
||||
|
||||
if len(segment.ComparePart) == 1 {
|
||||
if constPosition := strings.IndexByte(s, segment.ComparePart[0]); constPosition != -1 {
|
||||
return constPosition
|
||||
}
|
||||
} else if constPosition := strings.Index(s, segment.ComparePart); constPosition != -1 {
|
||||
// if the compare part was found, but contains a slash although this part is not greedy, then it must not match
|
||||
// example: /api/:param/fixedEnd -> path: /api/123/456/fixedEnd = no match , /api/123/fixedEnd = match
|
||||
if !segment.IsGreedy && strings.IndexByte(s[:constPosition], slashDelimiter) != -1 {
|
||||
return 0
|
||||
}
|
||||
return constPosition
|
||||
}
|
||||
|
||||
return len(s)
|
||||
}
|
||||
|
||||
// findParamLenForLastSegment get the length of the parameter if it is the last segment
|
||||
func findParamLenForLastSegment(s string, seg *routeSegment) int {
|
||||
if !seg.IsGreedy {
|
||||
if i := strings.IndexByte(s, slashDelimiter); i != -1 {
|
||||
return i
|
||||
}
|
||||
}
|
||||
|
||||
return len(s)
|
||||
}
|
||||
|
||||
// findGreedyParamLen get the length of the parameter for greedy segments from right to left
|
||||
func findGreedyParamLen(s string, searchCount int, segment *routeSegment) int {
|
||||
// check all from right to left segments
|
||||
for i := segment.PartCount; i > 0 && searchCount > 0; i-- {
|
||||
searchCount--
|
||||
if constPosition := strings.LastIndex(s, segment.ComparePart); constPosition != -1 {
|
||||
s = s[:constPosition]
|
||||
} else {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
return len(s)
|
||||
}
|
||||
|
||||
// GetTrimmedParam trims the ':' & '?' from a string
|
||||
func GetTrimmedParam(param string) string {
|
||||
start := 0
|
||||
end := len(param)
|
||||
|
||||
if end == 0 || param[start] != paramStarterChar { // is not a param
|
||||
return param
|
||||
}
|
||||
start++
|
||||
if param[end-1] == optionalParam { // is ?
|
||||
end--
|
||||
}
|
||||
|
||||
return param[start:end]
|
||||
}
|
||||
|
||||
// RemoveEscapeChar remove escape characters
|
||||
func RemoveEscapeChar(word string) string {
|
||||
if strings.IndexByte(word, escapeChar) != -1 {
|
||||
return strings.ReplaceAll(word, string(escapeChar), "")
|
||||
}
|
||||
return word
|
||||
}
|
||||
|
||||
func getParamConstraintType(constraintPart string) TypeConstraint {
|
||||
switch constraintPart {
|
||||
case ConstraintInt:
|
||||
return intConstraint
|
||||
case ConstraintBool:
|
||||
return boolConstraint
|
||||
case ConstraintFloat:
|
||||
return floatConstraint
|
||||
case ConstraintAlpha:
|
||||
return alphaConstraint
|
||||
case ConstraintGuid:
|
||||
return guidConstraint
|
||||
case ConstraintMinLen, ConstraintMinLenLower:
|
||||
return minLenConstraint
|
||||
case ConstraintMaxLen, ConstraintMaxLenLower:
|
||||
return maxLenConstraint
|
||||
case ConstraintLen:
|
||||
return lenConstraint
|
||||
case ConstraintBetweenLen, ConstraintBetweenLenLower:
|
||||
return betweenLenConstraint
|
||||
case ConstraintMin:
|
||||
return minConstraint
|
||||
case ConstraintMax:
|
||||
return maxConstraint
|
||||
case ConstraintRange:
|
||||
return rangeConstraint
|
||||
case ConstraintDatetime:
|
||||
return datetimeConstraint
|
||||
case ConstraintRegex:
|
||||
return regexConstraint
|
||||
default:
|
||||
return noConstraint
|
||||
}
|
||||
}
|
||||
|
||||
//nolint:errcheck // TODO: Properly check _all_ errors in here, log them & immediately return
|
||||
func (c *Constraint) CheckConstraint(param string) bool {
|
||||
var err error
|
||||
var num int
|
||||
|
||||
// check data exists
|
||||
needOneData := []TypeConstraint{minLenConstraint, maxLenConstraint, lenConstraint, minConstraint, maxConstraint, datetimeConstraint, regexConstraint}
|
||||
needTwoData := []TypeConstraint{betweenLenConstraint, rangeConstraint}
|
||||
|
||||
for _, data := range needOneData {
|
||||
if c.ID == data && len(c.Data) == 0 {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
for _, data := range needTwoData {
|
||||
if c.ID == data && len(c.Data) < 2 {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// check constraints
|
||||
switch c.ID {
|
||||
case noConstraint:
|
||||
// Nothing to check
|
||||
case intConstraint:
|
||||
_, err = strconv.Atoi(param)
|
||||
case boolConstraint:
|
||||
_, err = strconv.ParseBool(param)
|
||||
case floatConstraint:
|
||||
_, err = strconv.ParseFloat(param, 32)
|
||||
case alphaConstraint:
|
||||
for _, r := range param {
|
||||
if !unicode.IsLetter(r) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
case guidConstraint:
|
||||
_, err = uuid.Parse(param)
|
||||
case minLenConstraint:
|
||||
data, _ := strconv.Atoi(c.Data[0])
|
||||
|
||||
if len(param) < data {
|
||||
return false
|
||||
}
|
||||
case maxLenConstraint:
|
||||
data, _ := strconv.Atoi(c.Data[0])
|
||||
|
||||
if len(param) > data {
|
||||
return false
|
||||
}
|
||||
case lenConstraint:
|
||||
data, _ := strconv.Atoi(c.Data[0])
|
||||
|
||||
if len(param) != data {
|
||||
return false
|
||||
}
|
||||
case betweenLenConstraint:
|
||||
data, _ := strconv.Atoi(c.Data[0])
|
||||
data2, _ := strconv.Atoi(c.Data[1])
|
||||
length := len(param)
|
||||
if length < data || length > data2 {
|
||||
return false
|
||||
}
|
||||
case minConstraint:
|
||||
data, _ := strconv.Atoi(c.Data[0])
|
||||
num, err = strconv.Atoi(param)
|
||||
|
||||
if num < data {
|
||||
return false
|
||||
}
|
||||
case maxConstraint:
|
||||
data, _ := strconv.Atoi(c.Data[0])
|
||||
num, err = strconv.Atoi(param)
|
||||
|
||||
if num > data {
|
||||
return false
|
||||
}
|
||||
case rangeConstraint:
|
||||
data, _ := strconv.Atoi(c.Data[0])
|
||||
data2, _ := strconv.Atoi(c.Data[1])
|
||||
num, err = strconv.Atoi(param)
|
||||
|
||||
if num < data || num > data2 {
|
||||
return false
|
||||
}
|
||||
case datetimeConstraint:
|
||||
_, err = time.Parse(c.Data[0], param)
|
||||
case regexConstraint:
|
||||
if match := c.RegexCompiler.MatchString(param); !match {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
return err == nil
|
||||
}
|
||||
+179
@@ -0,0 +1,179 @@
|
||||
package fiber
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
"os/exec"
|
||||
"runtime"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/valyala/fasthttp/reuseport"
|
||||
|
||||
"github.com/gofiber/fiber/v2/log"
|
||||
)
|
||||
|
||||
const (
|
||||
envPreforkChildKey = "FIBER_PREFORK_CHILD"
|
||||
envPreforkChildVal = "1"
|
||||
)
|
||||
|
||||
var (
|
||||
testPreforkMaster = false
|
||||
testOnPrefork = false
|
||||
)
|
||||
|
||||
// IsChild determines if the current process is a child of Prefork
|
||||
func IsChild() bool {
|
||||
return os.Getenv(envPreforkChildKey) == envPreforkChildVal
|
||||
}
|
||||
|
||||
// prefork manages child processes to make use of the OS REUSEPORT or REUSEADDR feature
|
||||
func (app *App) prefork(network, addr string, tlsConfig *tls.Config) error {
|
||||
// 👶 child process 👶
|
||||
if IsChild() {
|
||||
// use 1 cpu core per child process
|
||||
runtime.GOMAXPROCS(1)
|
||||
// Linux will use SO_REUSEPORT and Windows falls back to SO_REUSEADDR
|
||||
// Only tcp4 or tcp6 is supported when preforking, both are not supported
|
||||
ln, err := reuseport.Listen(network, addr)
|
||||
if err != nil {
|
||||
if !app.config.DisableStartupMessage {
|
||||
const sleepDuration = 100 * time.Millisecond
|
||||
time.Sleep(sleepDuration) // avoid colliding with startup message
|
||||
}
|
||||
return fmt.Errorf("prefork: %w", err)
|
||||
}
|
||||
// wrap a tls config around the listener if provided
|
||||
if tlsConfig != nil {
|
||||
ln = tls.NewListener(ln, tlsConfig)
|
||||
}
|
||||
|
||||
// kill current child proc when master exits
|
||||
go watchMaster()
|
||||
|
||||
// prepare the server for the start
|
||||
app.startupProcess()
|
||||
|
||||
// listen for incoming connections
|
||||
return app.server.Serve(ln)
|
||||
}
|
||||
|
||||
// 👮 master process 👮
|
||||
type child struct {
|
||||
pid int
|
||||
err error
|
||||
}
|
||||
// create variables
|
||||
max := runtime.GOMAXPROCS(0)
|
||||
childs := make(map[int]*exec.Cmd)
|
||||
channel := make(chan child, max)
|
||||
|
||||
// kill child procs when master exits
|
||||
defer func() {
|
||||
for _, proc := range childs {
|
||||
if err := proc.Process.Kill(); err != nil {
|
||||
if !errors.Is(err, os.ErrProcessDone) {
|
||||
log.Errorf("prefork: failed to kill child: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
// collect child pids
|
||||
var pids []string
|
||||
|
||||
// launch child procs
|
||||
for i := 0; i < max; i++ {
|
||||
cmd := exec.Command(os.Args[0], os.Args[1:]...) //nolint:gosec // It's fine to launch the same process again
|
||||
if testPreforkMaster {
|
||||
// When test prefork master,
|
||||
// just start the child process with a dummy cmd,
|
||||
// which will exit soon
|
||||
cmd = dummyCmd()
|
||||
}
|
||||
cmd.Stdout = os.Stdout
|
||||
cmd.Stderr = os.Stderr
|
||||
|
||||
// add fiber prefork child flag into child proc env
|
||||
cmd.Env = append(os.Environ(),
|
||||
fmt.Sprintf("%s=%s", envPreforkChildKey, envPreforkChildVal),
|
||||
)
|
||||
if err := cmd.Start(); err != nil {
|
||||
return fmt.Errorf("failed to start a child prefork process, error: %w", err)
|
||||
}
|
||||
|
||||
// store child process
|
||||
pid := cmd.Process.Pid
|
||||
childs[pid] = cmd
|
||||
pids = append(pids, strconv.Itoa(pid))
|
||||
|
||||
// execute fork hook
|
||||
if app.hooks != nil {
|
||||
if testOnPrefork {
|
||||
app.hooks.executeOnForkHooks(dummyPid)
|
||||
} else {
|
||||
app.hooks.executeOnForkHooks(pid)
|
||||
}
|
||||
}
|
||||
|
||||
// notify master if child crashes
|
||||
go func() {
|
||||
channel <- child{pid, cmd.Wait()}
|
||||
}()
|
||||
}
|
||||
|
||||
// Run onListen hooks
|
||||
// Hooks have to be run here as different as non-prefork mode due to they should run as child or master
|
||||
app.runOnListenHooks(app.prepareListenData(addr, tlsConfig != nil))
|
||||
|
||||
// Print startup message
|
||||
if !app.config.DisableStartupMessage {
|
||||
app.startupMessage(addr, tlsConfig != nil, ","+strings.Join(pids, ","))
|
||||
}
|
||||
|
||||
// return error if child crashes
|
||||
return (<-channel).err
|
||||
}
|
||||
|
||||
// watchMaster watches child procs
|
||||
func watchMaster() {
|
||||
if runtime.GOOS == "windows" {
|
||||
// finds parent process,
|
||||
// and waits for it to exit
|
||||
p, err := os.FindProcess(os.Getppid())
|
||||
if err == nil {
|
||||
_, _ = p.Wait() //nolint:errcheck // It is fine to ignore the error here
|
||||
}
|
||||
os.Exit(1) //nolint:revive // Calling os.Exit is fine here in the prefork
|
||||
}
|
||||
// if it is equal to 1 (init process ID),
|
||||
// it indicates that the master process has exited
|
||||
const watchInterval = 500 * time.Millisecond
|
||||
for range time.NewTicker(watchInterval).C {
|
||||
if os.Getppid() == 1 {
|
||||
os.Exit(1) //nolint:revive // Calling os.Exit is fine here in the prefork
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
var (
|
||||
dummyPid = 1
|
||||
dummyChildCmd atomic.Value
|
||||
)
|
||||
|
||||
// dummyCmd is for internal prefork testing
|
||||
func dummyCmd() *exec.Cmd {
|
||||
command := "go"
|
||||
if storeCommand := dummyChildCmd.Load(); storeCommand != nil && storeCommand != "" {
|
||||
command = storeCommand.(string) //nolint:forcetypeassert,errcheck // We always store a string in here
|
||||
}
|
||||
if runtime.GOOS == "windows" {
|
||||
return exec.Command("cmd", "/C", command, "version")
|
||||
}
|
||||
return exec.Command(command, "version")
|
||||
}
|
||||
+518
@@ -0,0 +1,518 @@
|
||||
// ⚡️ Fiber is an Express inspired web framework written in Go with ☕️
|
||||
// 🤖 Github Repository: https://github.com/gofiber/fiber
|
||||
// 📌 API Documentation: https://docs.gofiber.io
|
||||
|
||||
package fiber
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"html"
|
||||
"sort"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/gofiber/fiber/v2/utils"
|
||||
|
||||
"github.com/valyala/fasthttp"
|
||||
)
|
||||
|
||||
// Router defines all router handle interface, including app and group router.
|
||||
type Router interface {
|
||||
Use(args ...interface{}) Router
|
||||
|
||||
Get(path string, handlers ...Handler) Router
|
||||
Head(path string, handlers ...Handler) Router
|
||||
Post(path string, handlers ...Handler) Router
|
||||
Put(path string, handlers ...Handler) Router
|
||||
Delete(path string, handlers ...Handler) Router
|
||||
Connect(path string, handlers ...Handler) Router
|
||||
Options(path string, handlers ...Handler) Router
|
||||
Trace(path string, handlers ...Handler) Router
|
||||
Patch(path string, handlers ...Handler) Router
|
||||
|
||||
Add(method, path string, handlers ...Handler) Router
|
||||
Static(prefix, root string, config ...Static) Router
|
||||
All(path string, handlers ...Handler) Router
|
||||
|
||||
Group(prefix string, handlers ...Handler) Router
|
||||
|
||||
Route(prefix string, fn func(router Router), name ...string) Router
|
||||
|
||||
Mount(prefix string, fiber *App) Router
|
||||
|
||||
Name(name string) Router
|
||||
}
|
||||
|
||||
// Route is a struct that holds all metadata for each registered handler.
|
||||
type Route struct {
|
||||
// ### important: always keep in sync with the copy method "app.copyRoute" ###
|
||||
// Data for routing
|
||||
pos uint32 // Position in stack -> important for the sort of the matched routes
|
||||
use bool // USE matches path prefixes
|
||||
mount bool // Indicated a mounted app on a specific route
|
||||
star bool // Path equals '*'
|
||||
root bool // Path equals '/'
|
||||
path string // Prettified path
|
||||
routeParser routeParser // Parameter parser
|
||||
group *Group // Group instance. used for routes in groups
|
||||
|
||||
// Public fields
|
||||
Method string `json:"method"` // HTTP method
|
||||
Name string `json:"name"` // Route's name
|
||||
//nolint:revive // Having both a Path (uppercase) and a path (lowercase) is fine
|
||||
Path string `json:"path"` // Original registered route path
|
||||
Params []string `json:"params"` // Case sensitive param keys
|
||||
Handlers []Handler `json:"-"` // Ctx handlers
|
||||
}
|
||||
|
||||
func (r *Route) match(detectionPath, path string, params *[maxParams]string) bool {
|
||||
// root detectionPath check
|
||||
if r.root && detectionPath == "/" {
|
||||
return true
|
||||
// '*' wildcard matches any detectionPath
|
||||
} else if r.star {
|
||||
if len(path) > 1 {
|
||||
params[0] = path[1:]
|
||||
} else {
|
||||
params[0] = ""
|
||||
}
|
||||
return true
|
||||
}
|
||||
// Does this route have parameters
|
||||
if len(r.Params) > 0 {
|
||||
// Match params
|
||||
if match := r.routeParser.getMatch(detectionPath, path, params, r.use); match {
|
||||
// Get params from the path detectionPath
|
||||
return match
|
||||
}
|
||||
}
|
||||
// Is this route a Middleware?
|
||||
if r.use {
|
||||
// Single slash will match or detectionPath prefix
|
||||
if r.root || strings.HasPrefix(detectionPath, r.path) {
|
||||
return true
|
||||
}
|
||||
// Check for a simple detectionPath match
|
||||
} else if len(r.path) == len(detectionPath) && r.path == detectionPath {
|
||||
return true
|
||||
}
|
||||
// No match
|
||||
return false
|
||||
}
|
||||
|
||||
func (app *App) next(c *Ctx) (bool, error) {
|
||||
// Get stack length
|
||||
tree, ok := app.treeStack[c.methodINT][c.treePath]
|
||||
if !ok {
|
||||
tree = app.treeStack[c.methodINT][""]
|
||||
}
|
||||
lenTree := len(tree) - 1
|
||||
|
||||
// Loop over the route stack starting from previous index
|
||||
for c.indexRoute < lenTree {
|
||||
// Increment route index
|
||||
c.indexRoute++
|
||||
|
||||
// Get *Route
|
||||
route := tree[c.indexRoute]
|
||||
|
||||
var match bool
|
||||
var err error
|
||||
// skip for mounted apps
|
||||
if route.mount {
|
||||
continue
|
||||
}
|
||||
|
||||
// Check if it matches the request path
|
||||
match = route.match(c.detectionPath, c.path, &c.values)
|
||||
if !match {
|
||||
// No match, next route
|
||||
continue
|
||||
}
|
||||
// Pass route reference and param values
|
||||
c.route = route
|
||||
|
||||
// Non use handler matched
|
||||
if !c.matched && !route.use {
|
||||
c.matched = true
|
||||
}
|
||||
|
||||
// Execute first handler of route
|
||||
c.indexHandler = 0
|
||||
if len(route.Handlers) > 0 {
|
||||
err = route.Handlers[0](c)
|
||||
}
|
||||
return match, err // Stop scanning the stack
|
||||
}
|
||||
|
||||
// If c.Next() does not match, return 404
|
||||
err := NewError(StatusNotFound, "Cannot "+c.method+" "+html.EscapeString(c.pathOriginal))
|
||||
if !c.matched && app.methodExist(c) {
|
||||
// If no match, scan stack again if other methods match the request
|
||||
// Moved from app.handler because middleware may break the route chain
|
||||
err = ErrMethodNotAllowed
|
||||
}
|
||||
return false, err
|
||||
}
|
||||
|
||||
func (app *App) handler(rctx *fasthttp.RequestCtx) { //revive:disable-line:confusing-naming // Having both a Handler() (uppercase) and a handler() (lowercase) is fine. TODO: Use nolint:revive directive instead. See https://github.com/golangci/golangci-lint/issues/3476
|
||||
// Acquire Ctx with fasthttp request from pool
|
||||
c := app.AcquireCtx(rctx)
|
||||
defer app.ReleaseCtx(c)
|
||||
|
||||
// handle invalid http method directly
|
||||
if c.methodINT == -1 {
|
||||
_ = c.Status(StatusBadRequest).SendString("Invalid http method") //nolint:errcheck // It is fine to ignore the error here
|
||||
return
|
||||
}
|
||||
|
||||
// Find match in stack
|
||||
match, err := app.next(c)
|
||||
if err != nil {
|
||||
if catch := c.app.ErrorHandler(c, err); catch != nil {
|
||||
_ = c.SendStatus(StatusInternalServerError) //nolint:errcheck // It is fine to ignore the error here
|
||||
}
|
||||
// TODO: Do we need to return here?
|
||||
}
|
||||
// Generate ETag if enabled
|
||||
if match && app.config.ETag {
|
||||
setETag(c, false)
|
||||
}
|
||||
}
|
||||
|
||||
func (app *App) addPrefixToRoute(prefix string, route *Route) *Route {
|
||||
prefixedPath := getGroupPath(prefix, route.Path)
|
||||
prettyPath := prefixedPath
|
||||
// Case-sensitive routing, all to lowercase
|
||||
if !app.config.CaseSensitive {
|
||||
prettyPath = utils.ToLower(prettyPath)
|
||||
}
|
||||
// Strict routing, remove trailing slashes
|
||||
if !app.config.StrictRouting && len(prettyPath) > 1 {
|
||||
prettyPath = utils.TrimRight(prettyPath, '/')
|
||||
}
|
||||
|
||||
route.Path = prefixedPath
|
||||
route.path = RemoveEscapeChar(prettyPath)
|
||||
route.routeParser = parseRoute(prettyPath)
|
||||
route.root = false
|
||||
route.star = false
|
||||
|
||||
return route
|
||||
}
|
||||
|
||||
func (*App) copyRoute(route *Route) *Route {
|
||||
return &Route{
|
||||
// Router booleans
|
||||
use: route.use,
|
||||
mount: route.mount,
|
||||
star: route.star,
|
||||
root: route.root,
|
||||
|
||||
// Path data
|
||||
path: route.path,
|
||||
routeParser: route.routeParser,
|
||||
|
||||
// misc
|
||||
pos: route.pos,
|
||||
|
||||
// Public data
|
||||
Path: route.Path,
|
||||
Params: route.Params,
|
||||
Name: route.Name,
|
||||
Method: route.Method,
|
||||
Handlers: route.Handlers,
|
||||
}
|
||||
}
|
||||
|
||||
func (app *App) register(method, pathRaw string, group *Group, handlers ...Handler) {
|
||||
// Uppercase HTTP methods
|
||||
method = utils.ToUpper(method)
|
||||
// Check if the HTTP method is valid unless it's USE
|
||||
if method != methodUse && app.methodInt(method) == -1 {
|
||||
panic(fmt.Sprintf("add: invalid http method %s\n", method))
|
||||
}
|
||||
// is mounted app
|
||||
isMount := group != nil && group.app != app
|
||||
// A route requires atleast one ctx handler
|
||||
if len(handlers) == 0 && !isMount {
|
||||
panic(fmt.Sprintf("missing handler in route: %s\n", pathRaw))
|
||||
}
|
||||
// Cannot have an empty path
|
||||
if pathRaw == "" {
|
||||
pathRaw = "/"
|
||||
}
|
||||
// Path always start with a '/'
|
||||
if pathRaw[0] != '/' {
|
||||
pathRaw = "/" + pathRaw
|
||||
}
|
||||
// Create a stripped path in-case sensitive / trailing slashes
|
||||
pathPretty := pathRaw
|
||||
// Case-sensitive routing, all to lowercase
|
||||
if !app.config.CaseSensitive {
|
||||
pathPretty = utils.ToLower(pathPretty)
|
||||
}
|
||||
// Strict routing, remove trailing slashes
|
||||
if !app.config.StrictRouting && len(pathPretty) > 1 {
|
||||
pathPretty = utils.TrimRight(pathPretty, '/')
|
||||
}
|
||||
// Is layer a middleware?
|
||||
isUse := method == methodUse
|
||||
// Is path a direct wildcard?
|
||||
isStar := pathPretty == "/*"
|
||||
// Is path a root slash?
|
||||
isRoot := pathPretty == "/"
|
||||
// Parse path parameters
|
||||
parsedRaw := parseRoute(pathRaw)
|
||||
parsedPretty := parseRoute(pathPretty)
|
||||
|
||||
// Create route metadata without pointer
|
||||
route := Route{
|
||||
// Router booleans
|
||||
use: isUse,
|
||||
mount: isMount,
|
||||
star: isStar,
|
||||
root: isRoot,
|
||||
|
||||
// Path data
|
||||
path: RemoveEscapeChar(pathPretty),
|
||||
routeParser: parsedPretty,
|
||||
Params: parsedRaw.params,
|
||||
|
||||
// Group data
|
||||
group: group,
|
||||
|
||||
// Public data
|
||||
Path: pathRaw,
|
||||
Method: method,
|
||||
Handlers: handlers,
|
||||
}
|
||||
// Increment global handler count
|
||||
atomic.AddUint32(&app.handlersCount, uint32(len(handlers)))
|
||||
|
||||
// Middleware route matches all HTTP methods
|
||||
if isUse {
|
||||
// Add route to all HTTP methods stack
|
||||
for _, m := range app.config.RequestMethods {
|
||||
// Create a route copy to avoid duplicates during compression
|
||||
r := route
|
||||
app.addRoute(m, &r, isMount)
|
||||
}
|
||||
} else {
|
||||
// Add route to stack
|
||||
app.addRoute(method, &route, isMount)
|
||||
}
|
||||
}
|
||||
|
||||
func (app *App) registerStatic(prefix, root string, config ...Static) {
|
||||
// For security, we want to restrict to the current work directory.
|
||||
if root == "" {
|
||||
root = "."
|
||||
}
|
||||
// Cannot have an empty prefix
|
||||
if prefix == "" {
|
||||
prefix = "/"
|
||||
}
|
||||
// Prefix always start with a '/' or '*'
|
||||
if prefix[0] != '/' {
|
||||
prefix = "/" + prefix
|
||||
}
|
||||
// in case-sensitive routing, all to lowercase
|
||||
if !app.config.CaseSensitive {
|
||||
prefix = utils.ToLower(prefix)
|
||||
}
|
||||
// Strip trailing slashes from the root path
|
||||
if len(root) > 0 && root[len(root)-1] == '/' {
|
||||
root = root[:len(root)-1]
|
||||
}
|
||||
// Is prefix a direct wildcard?
|
||||
isStar := prefix == "/*"
|
||||
// Is prefix a root slash?
|
||||
isRoot := prefix == "/"
|
||||
// Is prefix a partial wildcard?
|
||||
if strings.Contains(prefix, "*") {
|
||||
// /john* -> /john
|
||||
isStar = true
|
||||
prefix = strings.Split(prefix, "*")[0]
|
||||
// Fix this later
|
||||
}
|
||||
prefixLen := len(prefix)
|
||||
if prefixLen > 1 && prefix[prefixLen-1:] == "/" {
|
||||
// /john/ -> /john
|
||||
prefixLen--
|
||||
prefix = prefix[:prefixLen]
|
||||
}
|
||||
const cacheDuration = 10 * time.Second
|
||||
// Fileserver settings
|
||||
fs := &fasthttp.FS{
|
||||
Root: root,
|
||||
AllowEmptyRoot: true,
|
||||
GenerateIndexPages: false,
|
||||
AcceptByteRange: false,
|
||||
Compress: false,
|
||||
CompressedFileSuffix: app.config.CompressedFileSuffix,
|
||||
CacheDuration: cacheDuration,
|
||||
IndexNames: []string{"index.html"},
|
||||
PathRewrite: func(fctx *fasthttp.RequestCtx) []byte {
|
||||
path := fctx.Path()
|
||||
if len(path) >= prefixLen {
|
||||
if isStar && app.getString(path[0:prefixLen]) == prefix {
|
||||
path = append(path[0:0], '/')
|
||||
} else {
|
||||
path = path[prefixLen:]
|
||||
if len(path) == 0 || path[len(path)-1] != '/' {
|
||||
path = append(path, '/')
|
||||
}
|
||||
}
|
||||
}
|
||||
if len(path) > 0 && path[0] != '/' {
|
||||
path = append([]byte("/"), path...)
|
||||
}
|
||||
return path
|
||||
},
|
||||
PathNotFound: func(fctx *fasthttp.RequestCtx) {
|
||||
fctx.Response.SetStatusCode(StatusNotFound)
|
||||
},
|
||||
}
|
||||
|
||||
// Set config if provided
|
||||
var cacheControlValue string
|
||||
var modifyResponse Handler
|
||||
if len(config) > 0 {
|
||||
maxAge := config[0].MaxAge
|
||||
if maxAge > 0 {
|
||||
cacheControlValue = "public, max-age=" + strconv.Itoa(maxAge)
|
||||
}
|
||||
fs.CacheDuration = config[0].CacheDuration
|
||||
fs.Compress = config[0].Compress
|
||||
fs.AcceptByteRange = config[0].ByteRange
|
||||
fs.GenerateIndexPages = config[0].Browse
|
||||
if config[0].Index != "" {
|
||||
fs.IndexNames = []string{config[0].Index}
|
||||
}
|
||||
modifyResponse = config[0].ModifyResponse
|
||||
}
|
||||
fileHandler := fs.NewRequestHandler()
|
||||
handler := func(c *Ctx) error {
|
||||
// Don't execute middleware if Next returns true
|
||||
if len(config) != 0 && config[0].Next != nil && config[0].Next(c) {
|
||||
return c.Next()
|
||||
}
|
||||
// Serve file
|
||||
fileHandler(c.fasthttp)
|
||||
// Sets the response Content-Disposition header to attachment if the Download option is true
|
||||
if len(config) > 0 && config[0].Download {
|
||||
c.Attachment()
|
||||
}
|
||||
// Return request if found and not forbidden
|
||||
status := c.fasthttp.Response.StatusCode()
|
||||
if status != StatusNotFound && status != StatusForbidden {
|
||||
if len(cacheControlValue) > 0 {
|
||||
c.fasthttp.Response.Header.Set(HeaderCacheControl, cacheControlValue)
|
||||
}
|
||||
if modifyResponse != nil {
|
||||
return modifyResponse(c)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
// Reset response to default
|
||||
c.fasthttp.SetContentType("") // Issue #420
|
||||
c.fasthttp.Response.SetStatusCode(StatusOK)
|
||||
c.fasthttp.Response.SetBodyString("")
|
||||
// Next middleware
|
||||
return c.Next()
|
||||
}
|
||||
|
||||
// Create route metadata without pointer
|
||||
route := Route{
|
||||
// Router booleans
|
||||
use: true,
|
||||
root: isRoot,
|
||||
path: prefix,
|
||||
// Public data
|
||||
Method: MethodGet,
|
||||
Path: prefix,
|
||||
Handlers: []Handler{handler},
|
||||
}
|
||||
// Increment global handler count
|
||||
atomic.AddUint32(&app.handlersCount, 1)
|
||||
// Add route to stack
|
||||
app.addRoute(MethodGet, &route)
|
||||
// Add HEAD route
|
||||
app.addRoute(MethodHead, &route)
|
||||
}
|
||||
|
||||
func (app *App) addRoute(method string, route *Route, isMounted ...bool) {
|
||||
// Check mounted routes
|
||||
var mounted bool
|
||||
if len(isMounted) > 0 {
|
||||
mounted = isMounted[0]
|
||||
}
|
||||
|
||||
// Get unique HTTP method identifier
|
||||
m := app.methodInt(method)
|
||||
|
||||
// prevent identically route registration
|
||||
l := len(app.stack[m])
|
||||
if l > 0 && app.stack[m][l-1].Path == route.Path && route.use == app.stack[m][l-1].use && !route.mount && !app.stack[m][l-1].mount {
|
||||
preRoute := app.stack[m][l-1]
|
||||
preRoute.Handlers = append(preRoute.Handlers, route.Handlers...)
|
||||
} else {
|
||||
// Increment global route position
|
||||
route.pos = atomic.AddUint32(&app.routesCount, 1)
|
||||
route.Method = method
|
||||
// Add route to the stack
|
||||
app.stack[m] = append(app.stack[m], route)
|
||||
app.routesRefreshed = true
|
||||
}
|
||||
|
||||
// Execute onRoute hooks & change latestRoute if not adding mounted route
|
||||
if !mounted {
|
||||
app.mutex.Lock()
|
||||
app.latestRoute = route
|
||||
if err := app.hooks.executeOnRouteHooks(*route); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
app.mutex.Unlock()
|
||||
}
|
||||
}
|
||||
|
||||
// buildTree build the prefix tree from the previously registered routes
|
||||
func (app *App) buildTree() *App {
|
||||
if !app.routesRefreshed {
|
||||
return app
|
||||
}
|
||||
|
||||
// loop all the methods and stacks and create the prefix tree
|
||||
for m := range app.config.RequestMethods {
|
||||
tsMap := make(map[string][]*Route)
|
||||
for _, route := range app.stack[m] {
|
||||
treePath := ""
|
||||
if len(route.routeParser.segs) > 0 && len(route.routeParser.segs[0].Const) >= 3 {
|
||||
treePath = route.routeParser.segs[0].Const[:3]
|
||||
}
|
||||
// create tree stack
|
||||
tsMap[treePath] = append(tsMap[treePath], route)
|
||||
}
|
||||
app.treeStack[m] = tsMap
|
||||
}
|
||||
|
||||
// loop the methods and tree stacks and add global stack and sort everything
|
||||
for m := range app.config.RequestMethods {
|
||||
tsMap := app.treeStack[m]
|
||||
for treePart := range tsMap {
|
||||
if treePart != "" {
|
||||
// merge global tree routes in current tree stack
|
||||
tsMap[treePart] = uniqueRouteStack(append(tsMap[treePart], tsMap[""]...))
|
||||
}
|
||||
// sort tree slices with the positions
|
||||
slc := tsMap[treePart]
|
||||
sort.Slice(slc, func(i, j int) bool { return slc[i].pos < slc[j].pos })
|
||||
}
|
||||
}
|
||||
app.routesRefreshed = false
|
||||
|
||||
return app
|
||||
}
|
||||
+90
@@ -0,0 +1,90 @@
|
||||
A collection of common functions but with better performance, less allocations and no dependencies created for [Fiber](https://github.com/gofiber/fiber).
|
||||
|
||||
```go
|
||||
// go test -benchmem -run=^$ -bench=Benchmark_ -count=2
|
||||
|
||||
Benchmark_ToLowerBytes/fiber-16 42847654 25.7 ns/op 0 B/op 0 allocs/op
|
||||
Benchmark_ToLowerBytes/fiber-16 46143196 25.7 ns/op 0 B/op 0 allocs/op
|
||||
Benchmark_ToLowerBytes/default-16 17387322 67.4 ns/op 48 B/op 1 allocs/op
|
||||
Benchmark_ToLowerBytes/default-16 17906491 67.4 ns/op 48 B/op 1 allocs/op
|
||||
|
||||
Benchmark_ToUpperBytes/fiber-16 46143729 25.7 ns/op 0 B/op 0 allocs/op
|
||||
Benchmark_ToUpperBytes/fiber-16 47989250 25.6 ns/op 0 B/op 0 allocs/op
|
||||
Benchmark_ToUpperBytes/default-16 15580854 76.7 ns/op 48 B/op 1 allocs/op
|
||||
Benchmark_ToUpperBytes/default-16 15381202 76.9 ns/op 48 B/op 1 allocs/op
|
||||
|
||||
Benchmark_TrimRightBytes/fiber-16 70572459 16.3 ns/op 8 B/op 1 allocs/op
|
||||
Benchmark_TrimRightBytes/fiber-16 74983597 16.3 ns/op 8 B/op 1 allocs/op
|
||||
Benchmark_TrimRightBytes/default-16 16212578 74.1 ns/op 40 B/op 2 allocs/op
|
||||
Benchmark_TrimRightBytes/default-16 16434686 74.1 ns/op 40 B/op 2 allocs/op
|
||||
|
||||
Benchmark_TrimLeftBytes/fiber-16 74983128 16.3 ns/op 8 B/op 1 allocs/op
|
||||
Benchmark_TrimLeftBytes/fiber-16 74985002 16.3 ns/op 8 B/op 1 allocs/op
|
||||
Benchmark_TrimLeftBytes/default-16 21047868 56.5 ns/op 40 B/op 2 allocs/op
|
||||
Benchmark_TrimLeftBytes/default-16 21048015 56.5 ns/op 40 B/op 2 allocs/op
|
||||
|
||||
Benchmark_TrimBytes/fiber-16 54533307 21.9 ns/op 16 B/op 1 allocs/op
|
||||
Benchmark_TrimBytes/fiber-16 54532812 21.9 ns/op 16 B/op 1 allocs/op
|
||||
Benchmark_TrimBytes/default-16 14282517 84.6 ns/op 48 B/op 2 allocs/op
|
||||
Benchmark_TrimBytes/default-16 14114508 84.7 ns/op 48 B/op 2 allocs/op
|
||||
|
||||
Benchmark_EqualFolds/fiber-16 36355153 32.6 ns/op 0 B/op 0 allocs/op
|
||||
Benchmark_EqualFolds/fiber-16 36355593 32.6 ns/op 0 B/op 0 allocs/op
|
||||
Benchmark_EqualFolds/default-16 15186220 78.1 ns/op 0 B/op 0 allocs/op
|
||||
Benchmark_EqualFolds/default-16 15186412 78.3 ns/op 0 B/op 0 allocs/op
|
||||
|
||||
Benchmark_UUID/fiber-16 23994625 49.8 ns/op 48 B/op 1 allocs/op
|
||||
Benchmark_UUID/fiber-16 23994768 50.1 ns/op 48 B/op 1 allocs/op
|
||||
Benchmark_UUID/default-16 3233772 371 ns/op 208 B/op 6 allocs/op
|
||||
Benchmark_UUID/default-16 3251295 370 ns/op 208 B/op 6 allocs/op
|
||||
|
||||
Benchmark_GetString/unsafe-16 1000000000 0.709 ns/op 0 B/op 0 allocs/op
|
||||
Benchmark_GetString/unsafe-16 1000000000 0.713 ns/op 0 B/op 0 allocs/op
|
||||
Benchmark_GetString/default-16 59986202 19.0 ns/op 16 B/op 1 allocs/op
|
||||
Benchmark_GetString/default-16 63142939 19.0 ns/op 16 B/op 1 allocs/op
|
||||
|
||||
Benchmark_GetBytes/unsafe-16 508360195 2.36 ns/op 0 B/op 0 allocs/op
|
||||
Benchmark_GetBytes/unsafe-16 508359979 2.35 ns/op 0 B/op 0 allocs/op
|
||||
Benchmark_GetBytes/default-16 46143019 25.7 ns/op 16 B/op 1 allocs/op
|
||||
Benchmark_GetBytes/default-16 44434734 25.6 ns/op 16 B/op 1 allocs/op
|
||||
|
||||
Benchmark_GetMIME/fiber-16 21423750 56.3 ns/op 0 B/op 0 allocs/op
|
||||
Benchmark_GetMIME/fiber-16 21423559 55.4 ns/op 0 B/op 0 allocs/op
|
||||
Benchmark_GetMIME/default-16 6735282 173 ns/op 0 B/op 0 allocs/op
|
||||
Benchmark_GetMIME/default-16 6895002 172 ns/op 0 B/op 0 allocs/op
|
||||
|
||||
Benchmark_StatusMessage/fiber-16 1000000000 0.766 ns/op 0 B/op 0 allocs/op
|
||||
Benchmark_StatusMessage/fiber-16 1000000000 0.767 ns/op 0 B/op 0 allocs/op
|
||||
Benchmark_StatusMessage/default-16 159538528 7.50 ns/op 0 B/op 0 allocs/op
|
||||
Benchmark_StatusMessage/default-16 159750830 7.51 ns/op 0 B/op 0 allocs/op
|
||||
|
||||
Benchmark_ToUpper/fiber-16 22217408 53.3 ns/op 48 B/op 1 allocs/op
|
||||
Benchmark_ToUpper/fiber-16 22636554 53.2 ns/op 48 B/op 1 allocs/op
|
||||
Benchmark_ToUpper/default-16 11108600 108 ns/op 48 B/op 1 allocs/op
|
||||
Benchmark_ToUpper/default-16 11108580 108 ns/op 48 B/op 1 allocs/op
|
||||
|
||||
Benchmark_ToLower/fiber-16 23994720 49.8 ns/op 48 B/op 1 allocs/op
|
||||
Benchmark_ToLower/fiber-16 23994768 50.1 ns/op 48 B/op 1 allocs/op
|
||||
Benchmark_ToLower/default-16 10808376 110 ns/op 48 B/op 1 allocs/op
|
||||
Benchmark_ToLower/default-16 10617034 110 ns/op 48 B/op 1 allocs/op
|
||||
|
||||
Benchmark_TrimRight/fiber-16 413699521 2.94 ns/op 0 B/op 0 allocs/op
|
||||
Benchmark_TrimRight/fiber-16 415131687 2.91 ns/op 0 B/op 0 allocs/op
|
||||
Benchmark_TrimRight/default-16 23994577 49.1 ns/op 32 B/op 1 allocs/op
|
||||
Benchmark_TrimRight/default-16 24484249 49.4 ns/op 32 B/op 1 allocs/op
|
||||
|
||||
Benchmark_TrimLeft/fiber-16 379661170 3.13 ns/op 0 B/op 0 allocs/op
|
||||
Benchmark_TrimLeft/fiber-16 382079941 3.16 ns/op 0 B/op 0 allocs/op
|
||||
Benchmark_TrimLeft/default-16 27900877 41.9 ns/op 32 B/op 1 allocs/op
|
||||
Benchmark_TrimLeft/default-16 28564898 42.0 ns/op 32 B/op 1 allocs/op
|
||||
|
||||
Benchmark_Trim/fiber-16 236632856 4.96 ns/op 0 B/op 0 allocs/op
|
||||
Benchmark_Trim/fiber-16 237570085 4.93 ns/op 0 B/op 0 allocs/op
|
||||
Benchmark_Trim/default-16 18457221 66.0 ns/op 32 B/op 1 allocs/op
|
||||
Benchmark_Trim/default-16 18177328 65.9 ns/op 32 B/op 1 allocs/op
|
||||
Benchmark_Trim/default.trimspace-16 188933770 6.33 ns/op 0 B/op 0 allocs/op
|
||||
Benchmark_Trim/default.trimspace-16 184007649 6.42 ns/op 0 B/op 0 allocs/op
|
||||
|
||||
Benchmark_ConvertToBytes/fiber-8 43773547 24.43 ns/op 0 B/op 0 allocs/op
|
||||
Benchmark_ConvertToBytes/fiber-8 45849477 25.33 ns/op 0 B/op 0 allocs/op
|
||||
```
|
||||
+68
@@ -0,0 +1,68 @@
|
||||
// ⚡️ Fiber is an Express inspired web framework written in Go with ☕️
|
||||
// 🤖 Github Repository: https://github.com/gofiber/fiber
|
||||
// 📌 API Documentation: https://docs.gofiber.io
|
||||
|
||||
package utils
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"log"
|
||||
"path/filepath"
|
||||
"reflect"
|
||||
"runtime"
|
||||
"testing"
|
||||
"text/tabwriter"
|
||||
)
|
||||
|
||||
// AssertEqual checks if values are equal
|
||||
func AssertEqual(tb testing.TB, expected, actual interface{}, description ...string) { //nolint:thelper // TODO: Verify if tb can be nil
|
||||
if tb != nil {
|
||||
tb.Helper()
|
||||
}
|
||||
|
||||
if reflect.DeepEqual(expected, actual) {
|
||||
return
|
||||
}
|
||||
|
||||
aType := "<nil>"
|
||||
bType := "<nil>"
|
||||
|
||||
if expected != nil {
|
||||
aType = reflect.TypeOf(expected).String()
|
||||
}
|
||||
if actual != nil {
|
||||
bType = reflect.TypeOf(actual).String()
|
||||
}
|
||||
|
||||
testName := "AssertEqual"
|
||||
if tb != nil {
|
||||
testName = tb.Name()
|
||||
}
|
||||
|
||||
_, file, line, _ := runtime.Caller(1)
|
||||
|
||||
var buf bytes.Buffer
|
||||
const pad = 5
|
||||
w := tabwriter.NewWriter(&buf, 0, 0, pad, ' ', 0)
|
||||
_, _ = fmt.Fprintf(w, "\nTest:\t%s", testName)
|
||||
_, _ = fmt.Fprintf(w, "\nTrace:\t%s:%d", filepath.Base(file), line)
|
||||
if len(description) > 0 {
|
||||
_, _ = fmt.Fprintf(w, "\nDescription:\t%s", description[0])
|
||||
}
|
||||
_, _ = fmt.Fprintf(w, "\nExpect:\t%v\t(%s)", expected, aType)
|
||||
_, _ = fmt.Fprintf(w, "\nResult:\t%v\t(%s)", actual, bType)
|
||||
|
||||
var result string
|
||||
if err := w.Flush(); err != nil {
|
||||
result = err.Error()
|
||||
} else {
|
||||
result = buf.String()
|
||||
}
|
||||
|
||||
if tb != nil {
|
||||
tb.Fatal(result)
|
||||
} else {
|
||||
log.Fatal(result) //nolint:revive // tb might be nil, so we need a fallback
|
||||
}
|
||||
}
|
||||
+69
@@ -0,0 +1,69 @@
|
||||
// ⚡️ Fiber is an Express inspired web framework written in Go with ☕️
|
||||
// 🤖 Github Repository: https://github.com/gofiber/fiber
|
||||
// 📌 API Documentation: https://docs.gofiber.io
|
||||
|
||||
package utils
|
||||
|
||||
// ToLowerBytes converts ascii slice to lower-case in-place.
|
||||
func ToLowerBytes(b []byte) []byte {
|
||||
for i := 0; i < len(b); i++ {
|
||||
b[i] = toLowerTable[b[i]]
|
||||
}
|
||||
return b
|
||||
}
|
||||
|
||||
// ToUpperBytes converts ascii slice to upper-case in-place.
|
||||
func ToUpperBytes(b []byte) []byte {
|
||||
for i := 0; i < len(b); i++ {
|
||||
b[i] = toUpperTable[b[i]]
|
||||
}
|
||||
return b
|
||||
}
|
||||
|
||||
// TrimRightBytes is the equivalent of bytes.TrimRight
|
||||
func TrimRightBytes(b []byte, cutset byte) []byte {
|
||||
lenStr := len(b)
|
||||
for lenStr > 0 && b[lenStr-1] == cutset {
|
||||
lenStr--
|
||||
}
|
||||
return b[:lenStr]
|
||||
}
|
||||
|
||||
// TrimLeftBytes is the equivalent of bytes.TrimLeft
|
||||
func TrimLeftBytes(b []byte, cutset byte) []byte {
|
||||
lenStr, start := len(b), 0
|
||||
for start < lenStr && b[start] == cutset {
|
||||
start++
|
||||
}
|
||||
return b[start:]
|
||||
}
|
||||
|
||||
// TrimBytes is the equivalent of bytes.Trim
|
||||
func TrimBytes(b []byte, cutset byte) []byte {
|
||||
i, j := 0, len(b)-1
|
||||
for ; i <= j; i++ {
|
||||
if b[i] != cutset {
|
||||
break
|
||||
}
|
||||
}
|
||||
for ; i < j; j-- {
|
||||
if b[j] != cutset {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
return b[i : j+1]
|
||||
}
|
||||
|
||||
// EqualFoldBytes tests ascii slices for equality case-insensitively
|
||||
func EqualFoldBytes(b, s []byte) bool {
|
||||
if len(b) != len(s) {
|
||||
return false
|
||||
}
|
||||
for i := len(b) - 1; i >= 0; i-- {
|
||||
if toUpperTable[b[i]] != toUpperTable[s[i]] {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
+160
@@ -0,0 +1,160 @@
|
||||
// ⚡️ Fiber is an Express inspired web framework written in Go with ☕️
|
||||
// 🤖 Github Repository: https://github.com/gofiber/fiber
|
||||
// 📌 API Documentation: https://docs.gofiber.io
|
||||
|
||||
package utils
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/rand"
|
||||
"encoding/binary"
|
||||
"encoding/hex"
|
||||
"math"
|
||||
"net"
|
||||
"os"
|
||||
"reflect"
|
||||
"runtime"
|
||||
"strconv"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"unicode"
|
||||
|
||||
googleuuid "github.com/google/uuid"
|
||||
)
|
||||
|
||||
const (
|
||||
toLowerTable = "\x00\x01\x02\x03\x04\x05\x06\a\b\t\n\v\f\r\x0e\x0f\x10\x11\x12\x13\x14\x15\x16\x17\x18\x19\x1a\x1b\x1c\x1d\x1e\x1f !\"#$%&'()*+,-./0123456789:;<=>?@abcdefghijklmnopqrstuvwxyz[\\]^_`abcdefghijklmnopqrstuvwxyz{|}~\u007f\x80\x81\x82\x83\x84\x85\x86\x87\x88\x89\x8a\x8b\x8c\x8d\x8e\x8f\x90\x91\x92\x93\x94\x95\x96\x97\x98\x99\x9a\x9b\x9c\x9d\x9e\x9f\xa0\xa1\xa2\xa3\xa4\xa5\xa6\xa7\xa8\xa9\xaa\xab\xac\xad\xae\xaf\xb0\xb1\xb2\xb3\xb4\xb5\xb6\xb7\xb8\xb9\xba\xbb\xbc\xbd\xbe\xbf\xc0\xc1\xc2\xc3\xc4\xc5\xc6\xc7\xc8\xc9\xca\xcb\xcc\xcd\xce\xcf\xd0\xd1\xd2\xd3\xd4\xd5\xd6\xd7\xd8\xd9\xda\xdb\xdc\xdd\xde\xdf\xe0\xe1\xe2\xe3\xe4\xe5\xe6\xe7\xe8\xe9\xea\xeb\xec\xed\xee\xef\xf0\xf1\xf2\xf3\xf4\xf5\xf6\xf7\xf8\xf9\xfa\xfb\xfc\xfd\xfe\xff"
|
||||
toUpperTable = "\x00\x01\x02\x03\x04\x05\x06\a\b\t\n\v\f\r\x0e\x0f\x10\x11\x12\x13\x14\x15\x16\x17\x18\x19\x1a\x1b\x1c\x1d\x1e\x1f !\"#$%&'()*+,-./0123456789:;<=>?@ABCDEFGHIJKLMNOPQRSTUVWXYZ[\\]^_`ABCDEFGHIJKLMNOPQRSTUVWXYZ{|}~\u007f\x80\x81\x82\x83\x84\x85\x86\x87\x88\x89\x8a\x8b\x8c\x8d\x8e\x8f\x90\x91\x92\x93\x94\x95\x96\x97\x98\x99\x9a\x9b\x9c\x9d\x9e\x9f\xa0\xa1\xa2\xa3\xa4\xa5\xa6\xa7\xa8\xa9\xaa\xab\xac\xad\xae\xaf\xb0\xb1\xb2\xb3\xb4\xb5\xb6\xb7\xb8\xb9\xba\xbb\xbc\xbd\xbe\xbf\xc0\xc1\xc2\xc3\xc4\xc5\xc6\xc7\xc8\xc9\xca\xcb\xcc\xcd\xce\xcf\xd0\xd1\xd2\xd3\xd4\xd5\xd6\xd7\xd8\xd9\xda\xdb\xdc\xdd\xde\xdf\xe0\xe1\xe2\xe3\xe4\xe5\xe6\xe7\xe8\xe9\xea\xeb\xec\xed\xee\xef\xf0\xf1\xf2\xf3\xf4\xf5\xf6\xf7\xf8\xf9\xfa\xfb\xfc\xfd\xfe\xff"
|
||||
)
|
||||
|
||||
// Copyright © 2014, Roger Peppe
|
||||
// github.com/rogpeppe/fastuuid
|
||||
// All rights reserved.
|
||||
|
||||
const (
|
||||
emptyUUID = "00000000-0000-0000-0000-000000000000"
|
||||
)
|
||||
|
||||
var (
|
||||
uuidSeed [24]byte
|
||||
uuidCounter uint64
|
||||
uuidSetup sync.Once
|
||||
unitsSlice = []byte("kmgtp")
|
||||
)
|
||||
|
||||
// UUID generates an universally unique identifier (UUID)
|
||||
func UUID() string {
|
||||
// Setup seed & counter once
|
||||
uuidSetup.Do(func() {
|
||||
if _, err := rand.Read(uuidSeed[:]); err != nil {
|
||||
return
|
||||
}
|
||||
uuidCounter = binary.LittleEndian.Uint64(uuidSeed[:8])
|
||||
})
|
||||
if atomic.LoadUint64(&uuidCounter) <= 0 {
|
||||
return emptyUUID
|
||||
}
|
||||
// first 8 bytes differ, taking a slice of the first 16 bytes
|
||||
x := atomic.AddUint64(&uuidCounter, 1)
|
||||
uuid := uuidSeed
|
||||
binary.LittleEndian.PutUint64(uuid[:8], x)
|
||||
uuid[6], uuid[9] = uuid[9], uuid[6]
|
||||
|
||||
// RFC4122 v4
|
||||
uuid[6] = (uuid[6] & 0x0f) | 0x40
|
||||
uuid[8] = uuid[8]&0x3f | 0x80
|
||||
|
||||
// create UUID representation of the first 128 bits
|
||||
b := make([]byte, 36)
|
||||
hex.Encode(b[0:8], uuid[0:4])
|
||||
b[8] = '-'
|
||||
hex.Encode(b[9:13], uuid[4:6])
|
||||
b[13] = '-'
|
||||
hex.Encode(b[14:18], uuid[6:8])
|
||||
b[18] = '-'
|
||||
hex.Encode(b[19:23], uuid[8:10])
|
||||
b[23] = '-'
|
||||
hex.Encode(b[24:], uuid[10:16])
|
||||
|
||||
return UnsafeString(b)
|
||||
}
|
||||
|
||||
// UUIDv4 returns a Random (Version 4) UUID.
|
||||
// The strength of the UUIDs is based on the strength of the crypto/rand package.
|
||||
func UUIDv4() string {
|
||||
token, err := googleuuid.NewRandom()
|
||||
if err != nil {
|
||||
return UUID()
|
||||
}
|
||||
return token.String()
|
||||
}
|
||||
|
||||
// FunctionName returns function name
|
||||
func FunctionName(fn interface{}) string {
|
||||
t := reflect.ValueOf(fn).Type()
|
||||
if t.Kind() == reflect.Func {
|
||||
return runtime.FuncForPC(reflect.ValueOf(fn).Pointer()).Name()
|
||||
}
|
||||
return t.String()
|
||||
}
|
||||
|
||||
// GetArgument check if key is in arguments
|
||||
func GetArgument(arg string) bool {
|
||||
for i := range os.Args[1:] {
|
||||
if os.Args[1:][i] == arg {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// IncrementIPRange Find available next IP address
|
||||
func IncrementIPRange(ip net.IP) {
|
||||
for j := len(ip) - 1; j >= 0; j-- {
|
||||
ip[j]++
|
||||
if ip[j] > 0 {
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ConvertToBytes returns integer size of bytes from human-readable string, ex. 42kb, 42M
|
||||
// Returns 0 if string is unrecognized
|
||||
func ConvertToBytes(humanReadableString string) int {
|
||||
strLen := len(humanReadableString)
|
||||
if strLen == 0 {
|
||||
return 0
|
||||
}
|
||||
var unitPrefixPos, lastNumberPos int
|
||||
// loop the string
|
||||
for i := strLen - 1; i >= 0; i-- {
|
||||
// check if the char is a number
|
||||
if unicode.IsDigit(rune(humanReadableString[i])) {
|
||||
lastNumberPos = i
|
||||
break
|
||||
} else if humanReadableString[i] != ' ' {
|
||||
unitPrefixPos = i
|
||||
}
|
||||
}
|
||||
|
||||
if lastNumberPos < 0 {
|
||||
return 0
|
||||
}
|
||||
// fetch the number part and parse it to float
|
||||
size, err := strconv.ParseFloat(humanReadableString[:lastNumberPos+1], 64)
|
||||
if err != nil {
|
||||
return 0
|
||||
}
|
||||
|
||||
// check the multiplier from the string and use it
|
||||
if unitPrefixPos > 0 {
|
||||
// convert multiplier char to lowercase and check if exists in units slice
|
||||
index := bytes.IndexByte(unitsSlice, toLowerTable[humanReadableString[unitPrefixPos]])
|
||||
if index != -1 {
|
||||
const bytesPerKB = 1000
|
||||
size *= math.Pow(bytesPerKB, float64(index+1))
|
||||
}
|
||||
}
|
||||
|
||||
return int(size)
|
||||
}
|
||||
+117
@@ -0,0 +1,117 @@
|
||||
// ⚡️ Fiber is an Express inspired web framework written in Go with ☕️
|
||||
// 🤖 Github Repository: https://github.com/gofiber/fiber
|
||||
// 📌 API Documentation: https://docs.gofiber.io
|
||||
|
||||
package utils
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"reflect"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
// CopyString copies a string to make it immutable
|
||||
func CopyString(s string) string {
|
||||
return string(UnsafeBytes(s))
|
||||
}
|
||||
|
||||
// CopyBytes copies a slice to make it immutable
|
||||
func CopyBytes(b []byte) []byte {
|
||||
tmp := make([]byte, len(b))
|
||||
copy(tmp, b)
|
||||
return tmp
|
||||
}
|
||||
|
||||
const (
|
||||
uByte = 1 << (10 * iota) // 1 << 10 == 1024
|
||||
uKilobyte
|
||||
uMegabyte
|
||||
uGigabyte
|
||||
uTerabyte
|
||||
uPetabyte
|
||||
uExabyte
|
||||
)
|
||||
|
||||
// ByteSize returns a human-readable byte string of the form 10M, 12.5K, and so forth.
|
||||
// The unit that results in the smallest number greater than or equal to 1 is always chosen.
|
||||
func ByteSize(bytes uint64) string {
|
||||
unit := ""
|
||||
value := float64(bytes)
|
||||
switch {
|
||||
case bytes >= uExabyte:
|
||||
unit = "EB"
|
||||
value /= uExabyte
|
||||
case bytes >= uPetabyte:
|
||||
unit = "PB"
|
||||
value /= uPetabyte
|
||||
case bytes >= uTerabyte:
|
||||
unit = "TB"
|
||||
value /= uTerabyte
|
||||
case bytes >= uGigabyte:
|
||||
unit = "GB"
|
||||
value /= uGigabyte
|
||||
case bytes >= uMegabyte:
|
||||
unit = "MB"
|
||||
value /= uMegabyte
|
||||
case bytes >= uKilobyte:
|
||||
unit = "KB"
|
||||
value /= uKilobyte
|
||||
case bytes >= uByte:
|
||||
unit = "B"
|
||||
default:
|
||||
return "0B"
|
||||
}
|
||||
result := strconv.FormatFloat(value, 'f', 1, 64)
|
||||
result = strings.TrimSuffix(result, ".0")
|
||||
return result + unit
|
||||
}
|
||||
|
||||
// ToString Change arg to string
|
||||
func ToString(arg interface{}, timeFormat ...string) string {
|
||||
tmp := reflect.Indirect(reflect.ValueOf(arg)).Interface()
|
||||
switch v := tmp.(type) {
|
||||
case int:
|
||||
return strconv.Itoa(v)
|
||||
case int8:
|
||||
return strconv.FormatInt(int64(v), 10)
|
||||
case int16:
|
||||
return strconv.FormatInt(int64(v), 10)
|
||||
case int32:
|
||||
return strconv.FormatInt(int64(v), 10)
|
||||
case int64:
|
||||
return strconv.FormatInt(v, 10)
|
||||
case uint:
|
||||
return strconv.Itoa(int(v))
|
||||
case uint8:
|
||||
return strconv.FormatInt(int64(v), 10)
|
||||
case uint16:
|
||||
return strconv.FormatInt(int64(v), 10)
|
||||
case uint32:
|
||||
return strconv.FormatInt(int64(v), 10)
|
||||
case uint64:
|
||||
return strconv.FormatInt(int64(v), 10)
|
||||
case string:
|
||||
return v
|
||||
case []byte:
|
||||
return string(v)
|
||||
case bool:
|
||||
return strconv.FormatBool(v)
|
||||
case float32:
|
||||
return strconv.FormatFloat(float64(v), 'f', -1, 32)
|
||||
case float64:
|
||||
return strconv.FormatFloat(v, 'f', -1, 64)
|
||||
case time.Time:
|
||||
if len(timeFormat) > 0 {
|
||||
return v.Format(timeFormat[0])
|
||||
}
|
||||
return v.Format("2006-01-02 15:04:05")
|
||||
case reflect.Value:
|
||||
return ToString(v.Interface(), timeFormat...)
|
||||
case fmt.Stringer:
|
||||
return v.String()
|
||||
default:
|
||||
return ""
|
||||
}
|
||||
}
|
||||
+12
@@ -0,0 +1,12 @@
|
||||
//go:build go1.20
|
||||
|
||||
package utils
|
||||
|
||||
import (
|
||||
"unsafe"
|
||||
)
|
||||
|
||||
// UnsafeString returns a string pointer without allocation
|
||||
func UnsafeString(b []byte) string {
|
||||
return unsafe.String(unsafe.SliceData(b), len(b))
|
||||
}
|
||||
+14
@@ -0,0 +1,14 @@
|
||||
//go:build !go1.20
|
||||
|
||||
package utils
|
||||
|
||||
import (
|
||||
"unsafe"
|
||||
)
|
||||
|
||||
// UnsafeString returns a string pointer without allocation
|
||||
//
|
||||
//nolint:gosec // unsafe is used for better performance here
|
||||
func UnsafeString(b []byte) string {
|
||||
return *(*string)(unsafe.Pointer(&b))
|
||||
}
|
||||
+12
@@ -0,0 +1,12 @@
|
||||
//go:build go1.20
|
||||
|
||||
package utils
|
||||
|
||||
import (
|
||||
"unsafe"
|
||||
)
|
||||
|
||||
// UnsafeBytes returns a byte pointer without allocation.
|
||||
func UnsafeBytes(s string) []byte {
|
||||
return unsafe.Slice(unsafe.StringData(s), len(s))
|
||||
}
|
||||
+24
@@ -0,0 +1,24 @@
|
||||
//go:build !go1.20
|
||||
|
||||
package utils
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"unsafe"
|
||||
)
|
||||
|
||||
const MaxStringLen = 0x7fff0000 // Maximum string length for UnsafeBytes. (decimal: 2147418112)
|
||||
|
||||
// UnsafeBytes returns a byte pointer without allocation.
|
||||
// String length shouldn't be more than 2147418112.
|
||||
//
|
||||
//nolint:gosec // unsafe is used for better performance here
|
||||
func UnsafeBytes(s string) []byte {
|
||||
if s == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
return (*[MaxStringLen]byte)(unsafe.Pointer(
|
||||
(*reflect.StringHeader)(unsafe.Pointer(&s)).Data),
|
||||
)[:len(s):len(s)]
|
||||
}
|
||||
+16
@@ -0,0 +1,16 @@
|
||||
package utils
|
||||
|
||||
// Deprecated: Please use UnsafeString instead
|
||||
func GetString(b []byte) string {
|
||||
return UnsafeString(b)
|
||||
}
|
||||
|
||||
// Deprecated: Please use UnsafeBytes instead
|
||||
func GetBytes(s string) []byte {
|
||||
return UnsafeBytes(s)
|
||||
}
|
||||
|
||||
// Deprecated: Please use CopyString instead
|
||||
func ImmutableString(s string) string {
|
||||
return CopyString(s)
|
||||
}
|
||||
+267
@@ -0,0 +1,267 @@
|
||||
// ⚡️ Fiber is an Express inspired web framework written in Go with ☕️
|
||||
// 🤖 Github Repository: https://github.com/gofiber/fiber
|
||||
// 📌 API Documentation: https://docs.gofiber.io
|
||||
|
||||
package utils
|
||||
|
||||
import (
|
||||
"mime"
|
||||
"strings"
|
||||
)
|
||||
|
||||
const MIMEOctetStream = "application/octet-stream"
|
||||
|
||||
// GetMIME returns the content-type of a file extension
|
||||
func GetMIME(extension string) string {
|
||||
if len(extension) == 0 {
|
||||
return ""
|
||||
}
|
||||
var foundMime string
|
||||
if extension[0] == '.' {
|
||||
foundMime = mimeExtensions[extension[1:]]
|
||||
} else {
|
||||
foundMime = mimeExtensions[extension]
|
||||
}
|
||||
|
||||
if len(foundMime) == 0 {
|
||||
if extension[0] != '.' {
|
||||
foundMime = mime.TypeByExtension("." + extension)
|
||||
} else {
|
||||
foundMime = mime.TypeByExtension(extension)
|
||||
}
|
||||
|
||||
if foundMime == "" {
|
||||
return MIMEOctetStream
|
||||
}
|
||||
}
|
||||
return foundMime
|
||||
}
|
||||
|
||||
// ParseVendorSpecificContentType check if content type is vendor specific and
|
||||
// if it is parsable to any known types. If its not vendor specific then returns
|
||||
// the original content type.
|
||||
func ParseVendorSpecificContentType(cType string) string {
|
||||
plusIndex := strings.Index(cType, "+")
|
||||
|
||||
if plusIndex == -1 {
|
||||
return cType
|
||||
}
|
||||
|
||||
var parsableType string
|
||||
if semiColonIndex := strings.Index(cType, ";"); semiColonIndex == -1 {
|
||||
parsableType = cType[plusIndex+1:]
|
||||
} else if plusIndex < semiColonIndex {
|
||||
parsableType = cType[plusIndex+1 : semiColonIndex]
|
||||
} else {
|
||||
return cType[:semiColonIndex]
|
||||
}
|
||||
|
||||
slashIndex := strings.Index(cType, "/")
|
||||
|
||||
if slashIndex == -1 {
|
||||
return cType
|
||||
}
|
||||
|
||||
return cType[0:slashIndex+1] + parsableType
|
||||
}
|
||||
|
||||
// limits for HTTP statuscodes
|
||||
const (
|
||||
statusMessageMin = 100
|
||||
statusMessageMax = 511
|
||||
)
|
||||
|
||||
// StatusMessage returns the correct message for the provided HTTP statuscode
|
||||
func StatusMessage(status int) string {
|
||||
if status < statusMessageMin || status > statusMessageMax {
|
||||
return ""
|
||||
}
|
||||
return statusMessage[status]
|
||||
}
|
||||
|
||||
// NOTE: Keep this in sync with the status code list
|
||||
var statusMessage = []string{
|
||||
100: "Continue", // StatusContinue
|
||||
101: "Switching Protocols", // StatusSwitchingProtocols
|
||||
102: "Processing", // StatusProcessing
|
||||
103: "Early Hints", // StatusEarlyHints
|
||||
|
||||
200: "OK", // StatusOK
|
||||
201: "Created", // StatusCreated
|
||||
202: "Accepted", // StatusAccepted
|
||||
203: "Non-Authoritative Information", // StatusNonAuthoritativeInformation
|
||||
204: "No Content", // StatusNoContent
|
||||
205: "Reset Content", // StatusResetContent
|
||||
206: "Partial Content", // StatusPartialContent
|
||||
207: "Multi-Status", // StatusMultiStatus
|
||||
208: "Already Reported", // StatusAlreadyReported
|
||||
226: "IM Used", // StatusIMUsed
|
||||
|
||||
300: "Multiple Choices", // StatusMultipleChoices
|
||||
301: "Moved Permanently", // StatusMovedPermanently
|
||||
302: "Found", // StatusFound
|
||||
303: "See Other", // StatusSeeOther
|
||||
304: "Not Modified", // StatusNotModified
|
||||
305: "Use Proxy", // StatusUseProxy
|
||||
306: "Switch Proxy", // StatusSwitchProxy
|
||||
307: "Temporary Redirect", // StatusTemporaryRedirect
|
||||
308: "Permanent Redirect", // StatusPermanentRedirect
|
||||
|
||||
400: "Bad Request", // StatusBadRequest
|
||||
401: "Unauthorized", // StatusUnauthorized
|
||||
402: "Payment Required", // StatusPaymentRequired
|
||||
403: "Forbidden", // StatusForbidden
|
||||
404: "Not Found", // StatusNotFound
|
||||
405: "Method Not Allowed", // StatusMethodNotAllowed
|
||||
406: "Not Acceptable", // StatusNotAcceptable
|
||||
407: "Proxy Authentication Required", // StatusProxyAuthRequired
|
||||
408: "Request Timeout", // StatusRequestTimeout
|
||||
409: "Conflict", // StatusConflict
|
||||
410: "Gone", // StatusGone
|
||||
411: "Length Required", // StatusLengthRequired
|
||||
412: "Precondition Failed", // StatusPreconditionFailed
|
||||
413: "Request Entity Too Large", // StatusRequestEntityTooLarge
|
||||
414: "Request URI Too Long", // StatusRequestURITooLong
|
||||
415: "Unsupported Media Type", // StatusUnsupportedMediaType
|
||||
416: "Requested Range Not Satisfiable", // StatusRequestedRangeNotSatisfiable
|
||||
417: "Expectation Failed", // StatusExpectationFailed
|
||||
418: "I'm a teapot", // StatusTeapot
|
||||
421: "Misdirected Request", // StatusMisdirectedRequest
|
||||
422: "Unprocessable Entity", // StatusUnprocessableEntity
|
||||
423: "Locked", // StatusLocked
|
||||
424: "Failed Dependency", // StatusFailedDependency
|
||||
425: "Too Early", // StatusTooEarly
|
||||
426: "Upgrade Required", // StatusUpgradeRequired
|
||||
428: "Precondition Required", // StatusPreconditionRequired
|
||||
429: "Too Many Requests", // StatusTooManyRequests
|
||||
431: "Request Header Fields Too Large", // StatusRequestHeaderFieldsTooLarge
|
||||
451: "Unavailable For Legal Reasons", // StatusUnavailableForLegalReasons
|
||||
|
||||
500: "Internal Server Error", // StatusInternalServerError
|
||||
501: "Not Implemented", // StatusNotImplemented
|
||||
502: "Bad Gateway", // StatusBadGateway
|
||||
503: "Service Unavailable", // StatusServiceUnavailable
|
||||
504: "Gateway Timeout", // StatusGatewayTimeout
|
||||
505: "HTTP Version Not Supported", // StatusHTTPVersionNotSupported
|
||||
506: "Variant Also Negotiates", // StatusVariantAlsoNegotiates
|
||||
507: "Insufficient Storage", // StatusInsufficientStorage
|
||||
508: "Loop Detected", // StatusLoopDetected
|
||||
510: "Not Extended", // StatusNotExtended
|
||||
511: "Network Authentication Required", // StatusNetworkAuthenticationRequired
|
||||
}
|
||||
|
||||
// MIME types were copied from https://github.com/nginx/nginx/blob/67d2a9541826ecd5db97d604f23460210fd3e517/conf/mime.types with the following updates:
|
||||
// - Use "application/xml" instead of "text/xml" as recommended per https://datatracker.ietf.org/doc/html/rfc7303#section-4.1
|
||||
// - Use "text/javascript" instead of "application/javascript" as recommended per https://www.rfc-editor.org/rfc/rfc9239#name-text-javascript
|
||||
var mimeExtensions = map[string]string{
|
||||
"html": "text/html",
|
||||
"htm": "text/html",
|
||||
"shtml": "text/html",
|
||||
"css": "text/css",
|
||||
"xml": "application/xml",
|
||||
"gif": "image/gif",
|
||||
"jpeg": "image/jpeg",
|
||||
"jpg": "image/jpeg",
|
||||
"js": "text/javascript",
|
||||
"atom": "application/atom+xml",
|
||||
"rss": "application/rss+xml",
|
||||
"mml": "text/mathml",
|
||||
"txt": "text/plain",
|
||||
"jad": "text/vnd.sun.j2me.app-descriptor",
|
||||
"wml": "text/vnd.wap.wml",
|
||||
"htc": "text/x-component",
|
||||
"avif": "image/avif",
|
||||
"png": "image/png",
|
||||
"svg": "image/svg+xml",
|
||||
"svgz": "image/svg+xml",
|
||||
"tif": "image/tiff",
|
||||
"tiff": "image/tiff",
|
||||
"wbmp": "image/vnd.wap.wbmp",
|
||||
"webp": "image/webp",
|
||||
"ico": "image/x-icon",
|
||||
"jng": "image/x-jng",
|
||||
"bmp": "image/x-ms-bmp",
|
||||
"woff": "font/woff",
|
||||
"woff2": "font/woff2",
|
||||
"jar": "application/java-archive",
|
||||
"war": "application/java-archive",
|
||||
"ear": "application/java-archive",
|
||||
"json": "application/json",
|
||||
"hqx": "application/mac-binhex40",
|
||||
"doc": "application/msword",
|
||||
"pdf": "application/pdf",
|
||||
"ps": "application/postscript",
|
||||
"eps": "application/postscript",
|
||||
"ai": "application/postscript",
|
||||
"rtf": "application/rtf",
|
||||
"m3u8": "application/vnd.apple.mpegurl",
|
||||
"kml": "application/vnd.google-earth.kml+xml",
|
||||
"kmz": "application/vnd.google-earth.kmz",
|
||||
"xls": "application/vnd.ms-excel",
|
||||
"eot": "application/vnd.ms-fontobject",
|
||||
"ppt": "application/vnd.ms-powerpoint",
|
||||
"odg": "application/vnd.oasis.opendocument.graphics",
|
||||
"odp": "application/vnd.oasis.opendocument.presentation",
|
||||
"ods": "application/vnd.oasis.opendocument.spreadsheet",
|
||||
"odt": "application/vnd.oasis.opendocument.text",
|
||||
"pptx": "application/vnd.openxmlformats-officedocument.presentationml.presentation",
|
||||
"xlsx": "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet",
|
||||
"docx": "application/vnd.openxmlformats-officedocument.wordprocessingml.document",
|
||||
"wmlc": "application/vnd.wap.wmlc",
|
||||
"wasm": "application/wasm",
|
||||
"7z": "application/x-7z-compressed",
|
||||
"cco": "application/x-cocoa",
|
||||
"jardiff": "application/x-java-archive-diff",
|
||||
"jnlp": "application/x-java-jnlp-file",
|
||||
"run": "application/x-makeself",
|
||||
"pl": "application/x-perl",
|
||||
"pm": "application/x-perl",
|
||||
"prc": "application/x-pilot",
|
||||
"pdb": "application/x-pilot",
|
||||
"rar": "application/x-rar-compressed",
|
||||
"rpm": "application/x-redhat-package-manager",
|
||||
"sea": "application/x-sea",
|
||||
"swf": "application/x-shockwave-flash",
|
||||
"sit": "application/x-stuffit",
|
||||
"tcl": "application/x-tcl",
|
||||
"tk": "application/x-tcl",
|
||||
"der": "application/x-x509-ca-cert",
|
||||
"pem": "application/x-x509-ca-cert",
|
||||
"crt": "application/x-x509-ca-cert",
|
||||
"xpi": "application/x-xpinstall",
|
||||
"xhtml": "application/xhtml+xml",
|
||||
"xspf": "application/xspf+xml",
|
||||
"zip": "application/zip",
|
||||
"bin": "application/octet-stream",
|
||||
"exe": "application/octet-stream",
|
||||
"dll": "application/octet-stream",
|
||||
"deb": "application/octet-stream",
|
||||
"dmg": "application/octet-stream",
|
||||
"iso": "application/octet-stream",
|
||||
"img": "application/octet-stream",
|
||||
"msi": "application/octet-stream",
|
||||
"msp": "application/octet-stream",
|
||||
"msm": "application/octet-stream",
|
||||
"mid": "audio/midi",
|
||||
"midi": "audio/midi",
|
||||
"kar": "audio/midi",
|
||||
"mp3": "audio/mpeg",
|
||||
"ogg": "audio/ogg",
|
||||
"m4a": "audio/x-m4a",
|
||||
"ra": "audio/x-realaudio",
|
||||
"3gpp": "video/3gpp",
|
||||
"3gp": "video/3gpp",
|
||||
"ts": "video/mp2t",
|
||||
"mp4": "video/mp4",
|
||||
"mpeg": "video/mpeg",
|
||||
"mpg": "video/mpeg",
|
||||
"mov": "video/quicktime",
|
||||
"webm": "video/webm",
|
||||
"flv": "video/x-flv",
|
||||
"m4v": "video/x-m4v",
|
||||
"mng": "video/x-mng",
|
||||
"asx": "video/x-ms-asf",
|
||||
"asf": "video/x-ms-asf",
|
||||
"wmv": "video/x-ms-wmv",
|
||||
"avi": "video/x-msvideo",
|
||||
}
|
||||
+143
@@ -0,0 +1,143 @@
|
||||
package utils
|
||||
|
||||
import (
|
||||
"net"
|
||||
)
|
||||
|
||||
// IsIPv4 works the same way as net.ParseIP,
|
||||
// but without check for IPv6 case and without returning net.IP slice, whereby IsIPv4 makes no allocations.
|
||||
func IsIPv4(s string) bool {
|
||||
for i := 0; i < net.IPv4len; i++ {
|
||||
if len(s) == 0 {
|
||||
return false
|
||||
}
|
||||
|
||||
if i > 0 {
|
||||
if s[0] != '.' {
|
||||
return false
|
||||
}
|
||||
s = s[1:]
|
||||
}
|
||||
|
||||
n, ci := 0, 0
|
||||
|
||||
for ci = 0; ci < len(s) && '0' <= s[ci] && s[ci] <= '9'; ci++ {
|
||||
n = n*10 + int(s[ci]-'0')
|
||||
if n > 0xFF {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
if ci == 0 || (ci > 1 && s[0] == '0') {
|
||||
return false
|
||||
}
|
||||
|
||||
s = s[ci:]
|
||||
}
|
||||
|
||||
return len(s) == 0
|
||||
}
|
||||
|
||||
// IsIPv6 works the same way as net.ParseIP,
|
||||
// but without check for IPv4 case and without returning net.IP slice, whereby IsIPv6 makes no allocations.
|
||||
func IsIPv6(s string) bool {
|
||||
ellipsis := -1 // position of ellipsis in ip
|
||||
|
||||
// Might have leading ellipsis
|
||||
if len(s) >= 2 && s[0] == ':' && s[1] == ':' {
|
||||
ellipsis = 0
|
||||
s = s[2:]
|
||||
// Might be only ellipsis
|
||||
if len(s) == 0 {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
// Loop, parsing hex numbers followed by colon.
|
||||
i := 0
|
||||
for i < net.IPv6len {
|
||||
// Hex number.
|
||||
n, ci := 0, 0
|
||||
|
||||
for ci = 0; ci < len(s); ci++ {
|
||||
if '0' <= s[ci] && s[ci] <= '9' {
|
||||
n *= 16
|
||||
n += int(s[ci] - '0')
|
||||
} else if 'a' <= s[ci] && s[ci] <= 'f' {
|
||||
n *= 16
|
||||
n += int(s[ci]-'a') + 10
|
||||
} else if 'A' <= s[ci] && s[ci] <= 'F' {
|
||||
n *= 16
|
||||
n += int(s[ci]-'A') + 10
|
||||
} else {
|
||||
break
|
||||
}
|
||||
if n > 0xFFFF {
|
||||
return false
|
||||
}
|
||||
}
|
||||
if ci == 0 || n > 0xFFFF {
|
||||
return false
|
||||
}
|
||||
|
||||
if ci < len(s) && s[ci] == '.' {
|
||||
if ellipsis < 0 && i != net.IPv6len-net.IPv4len {
|
||||
return false
|
||||
}
|
||||
if i+net.IPv4len > net.IPv6len {
|
||||
return false
|
||||
}
|
||||
|
||||
if !IsIPv4(s) {
|
||||
return false
|
||||
}
|
||||
|
||||
s = ""
|
||||
i += net.IPv4len
|
||||
break
|
||||
}
|
||||
|
||||
// Save this 16-bit chunk.
|
||||
i += 2
|
||||
|
||||
// Stop at end of string.
|
||||
s = s[ci:]
|
||||
if len(s) == 0 {
|
||||
break
|
||||
}
|
||||
|
||||
// Otherwise must be followed by colon and more.
|
||||
if s[0] != ':' || len(s) == 1 {
|
||||
return false
|
||||
}
|
||||
s = s[1:]
|
||||
|
||||
// Look for ellipsis.
|
||||
if s[0] == ':' {
|
||||
if ellipsis >= 0 { // already have one
|
||||
return false
|
||||
}
|
||||
ellipsis = i
|
||||
s = s[1:]
|
||||
if len(s) == 0 { // can be at end
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Must have used entire string.
|
||||
if len(s) != 0 {
|
||||
return false
|
||||
}
|
||||
|
||||
// If didn't parse enough, expand ellipsis.
|
||||
if i < net.IPv6len {
|
||||
if ellipsis < 0 {
|
||||
return false
|
||||
}
|
||||
} else if ellipsis >= 0 {
|
||||
// Ellipsis must represent at least one 0 group.
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
+9
@@ -0,0 +1,9 @@
|
||||
package utils
|
||||
|
||||
// JSONMarshal returns the JSON encoding of v.
|
||||
type JSONMarshal func(v interface{}) ([]byte, error)
|
||||
|
||||
// JSONUnmarshal parses the JSON-encoded data and stores the result
|
||||
// in the value pointed to by v. If v is nil or not a pointer,
|
||||
// Unmarshal returns an InvalidUnmarshalError.
|
||||
type JSONUnmarshal func(data []byte, v interface{}) error
|
||||
+75
@@ -0,0 +1,75 @@
|
||||
// ⚡️ Fiber is an Express inspired web framework written in Go with ☕️
|
||||
// 🤖 Github Repository: https://github.com/gofiber/fiber
|
||||
// 📌 API Documentation: https://docs.gofiber.io
|
||||
|
||||
package utils
|
||||
|
||||
// ToLower converts ascii string to lower-case
|
||||
func ToLower(b string) string {
|
||||
res := make([]byte, len(b))
|
||||
copy(res, b)
|
||||
for i := 0; i < len(res); i++ {
|
||||
res[i] = toLowerTable[res[i]]
|
||||
}
|
||||
|
||||
return UnsafeString(res)
|
||||
}
|
||||
|
||||
// ToUpper converts ascii string to upper-case
|
||||
func ToUpper(b string) string {
|
||||
res := make([]byte, len(b))
|
||||
copy(res, b)
|
||||
for i := 0; i < len(res); i++ {
|
||||
res[i] = toUpperTable[res[i]]
|
||||
}
|
||||
|
||||
return UnsafeString(res)
|
||||
}
|
||||
|
||||
// TrimLeft is the equivalent of strings.TrimLeft
|
||||
func TrimLeft(s string, cutset byte) string {
|
||||
lenStr, start := len(s), 0
|
||||
for start < lenStr && s[start] == cutset {
|
||||
start++
|
||||
}
|
||||
return s[start:]
|
||||
}
|
||||
|
||||
// Trim is the equivalent of strings.Trim
|
||||
func Trim(s string, cutset byte) string {
|
||||
i, j := 0, len(s)-1
|
||||
for ; i <= j; i++ {
|
||||
if s[i] != cutset {
|
||||
break
|
||||
}
|
||||
}
|
||||
for ; i < j; j-- {
|
||||
if s[j] != cutset {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
return s[i : j+1]
|
||||
}
|
||||
|
||||
// TrimRight is the equivalent of strings.TrimRight
|
||||
func TrimRight(s string, cutset byte) string {
|
||||
lenStr := len(s)
|
||||
for lenStr > 0 && s[lenStr-1] == cutset {
|
||||
lenStr--
|
||||
}
|
||||
return s[:lenStr]
|
||||
}
|
||||
|
||||
// EqualFold tests ascii strings for equality case-insensitively
|
||||
func EqualFold(b, s string) bool {
|
||||
if len(b) != len(s) {
|
||||
return false
|
||||
}
|
||||
for i := len(b) - 1; i >= 0; i-- {
|
||||
if toUpperTable[b[i]] != toUpperTable[s[i]] {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
+32
@@ -0,0 +1,32 @@
|
||||
package utils
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
)
|
||||
|
||||
var (
|
||||
timestampTimer sync.Once
|
||||
// Timestamp please start the timer function before you use this value
|
||||
// please load the value with atomic `atomic.LoadUint32(&utils.Timestamp)`
|
||||
Timestamp uint32
|
||||
)
|
||||
|
||||
// StartTimeStampUpdater starts a concurrent function which stores the timestamp to an atomic value per second,
|
||||
// which is much better for performance than determining it at runtime each time
|
||||
func StartTimeStampUpdater() {
|
||||
timestampTimer.Do(func() {
|
||||
// set initial value
|
||||
atomic.StoreUint32(&Timestamp, uint32(time.Now().Unix()))
|
||||
go func(sleep time.Duration) {
|
||||
ticker := time.NewTicker(sleep)
|
||||
defer ticker.Stop()
|
||||
|
||||
for t := range ticker.C {
|
||||
// update timestamp
|
||||
atomic.StoreUint32(&Timestamp, uint32(t.Unix()))
|
||||
}
|
||||
}(1 * time.Second) // duration
|
||||
})
|
||||
}
|
||||
+4
@@ -0,0 +1,4 @@
|
||||
package utils
|
||||
|
||||
// XMLMarshal returns the XML encoding of v.
|
||||
type XMLMarshal func(v interface{}) ([]byte, error)
|
||||
Reference in New Issue
Block a user