rename ngin to rest
This commit is contained in:
141
rest/handler/authhandler.go
Normal file
141
rest/handler/authhandler.go
Normal file
@@ -0,0 +1,141 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"net/http"
|
||||
"net/http/httputil"
|
||||
|
||||
"zero/core/logx"
|
||||
"zero/rest/internal"
|
||||
|
||||
"github.com/dgrijalva/jwt-go"
|
||||
)
|
||||
|
||||
const (
|
||||
jwtAudience = "aud"
|
||||
jwtExpire = "exp"
|
||||
jwtId = "jti"
|
||||
jwtIssueAt = "iat"
|
||||
jwtIssuer = "iss"
|
||||
jwtNotBefore = "nbf"
|
||||
jwtSubject = "sub"
|
||||
noDetailReason = "no detail reason"
|
||||
)
|
||||
|
||||
var (
|
||||
errInvalidToken = errors.New("invalid auth token")
|
||||
errNoClaims = errors.New("no auth params")
|
||||
)
|
||||
|
||||
type (
|
||||
AuthorizeOptions struct {
|
||||
PrevSecret string
|
||||
Callback UnauthorizedCallback
|
||||
}
|
||||
|
||||
UnauthorizedCallback func(w http.ResponseWriter, r *http.Request, err error)
|
||||
AuthorizeOption func(opts *AuthorizeOptions)
|
||||
)
|
||||
|
||||
func Authorize(secret string, opts ...AuthorizeOption) func(http.Handler) http.Handler {
|
||||
var authOpts AuthorizeOptions
|
||||
for _, opt := range opts {
|
||||
opt(&authOpts)
|
||||
}
|
||||
|
||||
parser := internal.NewTokenParser()
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
token, err := parser.ParseToken(r, secret, authOpts.PrevSecret)
|
||||
if err != nil {
|
||||
unauthorized(w, r, err, authOpts.Callback)
|
||||
return
|
||||
}
|
||||
|
||||
if !token.Valid {
|
||||
unauthorized(w, r, errInvalidToken, authOpts.Callback)
|
||||
return
|
||||
}
|
||||
|
||||
claims, ok := token.Claims.(jwt.MapClaims)
|
||||
if !ok {
|
||||
unauthorized(w, r, errNoClaims, authOpts.Callback)
|
||||
return
|
||||
}
|
||||
|
||||
ctx := r.Context()
|
||||
for k, v := range claims {
|
||||
switch k {
|
||||
case jwtAudience, jwtExpire, jwtId, jwtIssueAt, jwtIssuer, jwtNotBefore, jwtSubject:
|
||||
// ignore the standard claims
|
||||
default:
|
||||
ctx = context.WithValue(ctx, k, v)
|
||||
}
|
||||
}
|
||||
|
||||
next.ServeHTTP(w, r.WithContext(ctx))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func WithPrevSecret(secret string) AuthorizeOption {
|
||||
return func(opts *AuthorizeOptions) {
|
||||
opts.PrevSecret = secret
|
||||
}
|
||||
}
|
||||
|
||||
func WithUnauthorizedCallback(callback UnauthorizedCallback) AuthorizeOption {
|
||||
return func(opts *AuthorizeOptions) {
|
||||
opts.Callback = callback
|
||||
}
|
||||
}
|
||||
|
||||
func detailAuthLog(r *http.Request, reason string) {
|
||||
// discard dump error, only for debug purpose
|
||||
details, _ := httputil.DumpRequest(r, true)
|
||||
logx.Errorf("authorize failed: %s\n=> %+v", reason, string(details))
|
||||
}
|
||||
|
||||
func unauthorized(w http.ResponseWriter, r *http.Request, err error, callback UnauthorizedCallback) {
|
||||
writer := newGuardedResponseWriter(w)
|
||||
|
||||
if err != nil {
|
||||
detailAuthLog(r, err.Error())
|
||||
} else {
|
||||
detailAuthLog(r, noDetailReason)
|
||||
}
|
||||
if callback != nil {
|
||||
callback(writer, r, err)
|
||||
}
|
||||
|
||||
writer.WriteHeader(http.StatusUnauthorized)
|
||||
}
|
||||
|
||||
type guardedResponseWriter struct {
|
||||
writer http.ResponseWriter
|
||||
wroteHeader bool
|
||||
}
|
||||
|
||||
func newGuardedResponseWriter(w http.ResponseWriter) *guardedResponseWriter {
|
||||
return &guardedResponseWriter{
|
||||
writer: w,
|
||||
}
|
||||
}
|
||||
|
||||
func (grw *guardedResponseWriter) Header() http.Header {
|
||||
return grw.writer.Header()
|
||||
}
|
||||
|
||||
func (grw *guardedResponseWriter) Write(body []byte) (int, error) {
|
||||
return grw.writer.Write(body)
|
||||
}
|
||||
|
||||
func (grw *guardedResponseWriter) WriteHeader(statusCode int) {
|
||||
if grw.wroteHeader {
|
||||
return
|
||||
}
|
||||
|
||||
grw.wroteHeader = true
|
||||
grw.writer.WriteHeader(statusCode)
|
||||
}
|
||||
99
rest/handler/authhandler_test.go
Normal file
99
rest/handler/authhandler_test.go
Normal file
@@ -0,0 +1,99 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/dgrijalva/jwt-go"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestAuthHandlerFailed(t *testing.T) {
|
||||
req := httptest.NewRequest(http.MethodGet, "http://localhost", nil)
|
||||
handler := Authorize("B63F477D-BBA3-4E52-96D3-C0034C27694A", WithUnauthorizedCallback(
|
||||
func(w http.ResponseWriter, r *http.Request, err error) {
|
||||
w.Header().Set("X-Test", "test")
|
||||
w.WriteHeader(http.StatusUnauthorized)
|
||||
_, err = w.Write([]byte("content"))
|
||||
assert.Nil(t, err)
|
||||
}))(
|
||||
http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
|
||||
resp := httptest.NewRecorder()
|
||||
handler.ServeHTTP(resp, req)
|
||||
assert.Equal(t, http.StatusUnauthorized, resp.Code)
|
||||
}
|
||||
|
||||
func TestAuthHandler(t *testing.T) {
|
||||
const key = "B63F477D-BBA3-4E52-96D3-C0034C27694A"
|
||||
req := httptest.NewRequest(http.MethodGet, "http://localhost", nil)
|
||||
token, err := buildToken(key, map[string]interface{}{
|
||||
"key": "value",
|
||||
}, 3600)
|
||||
assert.Nil(t, err)
|
||||
req.Header.Set("Authorization", "Bearer "+token)
|
||||
handler := Authorize(key)(
|
||||
http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("X-Test", "test")
|
||||
_, err := w.Write([]byte("content"))
|
||||
assert.Nil(t, err)
|
||||
}))
|
||||
|
||||
resp := httptest.NewRecorder()
|
||||
handler.ServeHTTP(resp, req)
|
||||
assert.Equal(t, http.StatusOK, resp.Code)
|
||||
assert.Equal(t, "test", resp.Header().Get("X-Test"))
|
||||
assert.Equal(t, "content", resp.Body.String())
|
||||
}
|
||||
|
||||
func TestAuthHandlerWithPrevSecret(t *testing.T) {
|
||||
const (
|
||||
key = "14F17379-EB8F-411B-8F12-6929002DCA76"
|
||||
prevKey = "B63F477D-BBA3-4E52-96D3-C0034C27694A"
|
||||
)
|
||||
req := httptest.NewRequest(http.MethodGet, "http://localhost", nil)
|
||||
token, err := buildToken(key, map[string]interface{}{
|
||||
"key": "value",
|
||||
}, 3600)
|
||||
assert.Nil(t, err)
|
||||
req.Header.Set("Authorization", "Bearer "+token)
|
||||
handler := Authorize(key, WithPrevSecret(prevKey))(
|
||||
http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("X-Test", "test")
|
||||
_, err := w.Write([]byte("content"))
|
||||
assert.Nil(t, err)
|
||||
}))
|
||||
|
||||
resp := httptest.NewRecorder()
|
||||
handler.ServeHTTP(resp, req)
|
||||
assert.Equal(t, http.StatusOK, resp.Code)
|
||||
assert.Equal(t, "test", resp.Header().Get("X-Test"))
|
||||
assert.Equal(t, "content", resp.Body.String())
|
||||
}
|
||||
|
||||
func TestAuthHandler_NilError(t *testing.T) {
|
||||
req := httptest.NewRequest(http.MethodGet, "http://localhost", nil)
|
||||
resp := httptest.NewRecorder()
|
||||
assert.NotPanics(t, func() {
|
||||
unauthorized(resp, req, nil, nil)
|
||||
})
|
||||
}
|
||||
|
||||
func buildToken(secretKey string, payloads map[string]interface{}, seconds int64) (string, error) {
|
||||
now := time.Now().Unix()
|
||||
claims := make(jwt.MapClaims)
|
||||
claims["exp"] = now + seconds
|
||||
claims["iat"] = now
|
||||
for k, v := range payloads {
|
||||
claims[k] = v
|
||||
}
|
||||
|
||||
token := jwt.New(jwt.SigningMethodHS256)
|
||||
token.Claims = claims
|
||||
|
||||
return token.SignedString([]byte(secretKey))
|
||||
}
|
||||
41
rest/handler/breakerhandler.go
Normal file
41
rest/handler/breakerhandler.go
Normal file
@@ -0,0 +1,41 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"zero/core/breaker"
|
||||
"zero/core/logx"
|
||||
"zero/core/stat"
|
||||
"zero/rest/internal"
|
||||
"zero/rest/internal/security"
|
||||
)
|
||||
|
||||
const breakerSeparator = "://"
|
||||
|
||||
func BreakerHandler(method, path string, metrics *stat.Metrics) func(http.Handler) http.Handler {
|
||||
brk := breaker.NewBreaker(breaker.WithName(strings.Join([]string{method, path}, breakerSeparator)))
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
promise, err := brk.Allow()
|
||||
if err != nil {
|
||||
metrics.AddDrop()
|
||||
logx.Errorf("[http] dropped, %s - %s - %s",
|
||||
r.RequestURI, internal.GetRemoteAddr(r), r.UserAgent())
|
||||
w.WriteHeader(http.StatusServiceUnavailable)
|
||||
return
|
||||
}
|
||||
|
||||
cw := &security.WithCodeResponseWriter{Writer: w}
|
||||
defer func() {
|
||||
if cw.Code < http.StatusInternalServerError {
|
||||
promise.Accept()
|
||||
} else {
|
||||
promise.Reject(fmt.Sprintf("%d %s", cw.Code, http.StatusText(cw.Code)))
|
||||
}
|
||||
}()
|
||||
next.ServeHTTP(cw, r)
|
||||
})
|
||||
}
|
||||
}
|
||||
102
rest/handler/breakerhandler_test.go
Normal file
102
rest/handler/breakerhandler_test.go
Normal file
@@ -0,0 +1,102 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"zero/core/logx"
|
||||
"zero/core/stat"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func init() {
|
||||
logx.Disable()
|
||||
stat.SetReporter(nil)
|
||||
}
|
||||
|
||||
func TestBreakerHandlerAccept(t *testing.T) {
|
||||
metrics := stat.NewMetrics("unit-test")
|
||||
breakerHandler := BreakerHandler(http.MethodGet, "/", metrics)
|
||||
handler := breakerHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("X-Test", "test")
|
||||
_, err := w.Write([]byte("content"))
|
||||
assert.Nil(t, err)
|
||||
}))
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "http://localhost", nil)
|
||||
req.Header.Set("X-Test", "test")
|
||||
resp := httptest.NewRecorder()
|
||||
handler.ServeHTTP(resp, req)
|
||||
assert.Equal(t, http.StatusOK, resp.Code)
|
||||
assert.Equal(t, "test", resp.Header().Get("X-Test"))
|
||||
assert.Equal(t, "content", resp.Body.String())
|
||||
}
|
||||
|
||||
func TestBreakerHandlerFail(t *testing.T) {
|
||||
metrics := stat.NewMetrics("unit-test")
|
||||
breakerHandler := BreakerHandler(http.MethodGet, "/", metrics)
|
||||
handler := breakerHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusBadGateway)
|
||||
}))
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "http://localhost", nil)
|
||||
resp := httptest.NewRecorder()
|
||||
handler.ServeHTTP(resp, req)
|
||||
assert.Equal(t, http.StatusBadGateway, resp.Code)
|
||||
}
|
||||
|
||||
func TestBreakerHandler_4XX(t *testing.T) {
|
||||
metrics := stat.NewMetrics("unit-test")
|
||||
breakerHandler := BreakerHandler(http.MethodGet, "/", metrics)
|
||||
handler := breakerHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
}))
|
||||
|
||||
for i := 0; i < 1000; i++ {
|
||||
req := httptest.NewRequest(http.MethodGet, "http://localhost", nil)
|
||||
resp := httptest.NewRecorder()
|
||||
handler.ServeHTTP(resp, req)
|
||||
}
|
||||
|
||||
const tries = 100
|
||||
var pass int
|
||||
for i := 0; i < tries; i++ {
|
||||
req := httptest.NewRequest(http.MethodGet, "http://localhost", nil)
|
||||
resp := httptest.NewRecorder()
|
||||
handler.ServeHTTP(resp, req)
|
||||
if resp.Code == http.StatusBadRequest {
|
||||
pass++
|
||||
}
|
||||
}
|
||||
|
||||
assert.Equal(t, tries, pass)
|
||||
}
|
||||
|
||||
func TestBreakerHandlerReject(t *testing.T) {
|
||||
metrics := stat.NewMetrics("unit-test")
|
||||
breakerHandler := BreakerHandler(http.MethodGet, "/", metrics)
|
||||
handler := breakerHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
}))
|
||||
|
||||
for i := 0; i < 1000; i++ {
|
||||
req := httptest.NewRequest(http.MethodGet, "http://localhost", nil)
|
||||
resp := httptest.NewRecorder()
|
||||
handler.ServeHTTP(resp, req)
|
||||
}
|
||||
|
||||
var drops int
|
||||
for i := 0; i < 100; i++ {
|
||||
req := httptest.NewRequest(http.MethodGet, "http://localhost", nil)
|
||||
resp := httptest.NewRecorder()
|
||||
handler.ServeHTTP(resp, req)
|
||||
if resp.Code == http.StatusServiceUnavailable {
|
||||
drops++
|
||||
}
|
||||
}
|
||||
|
||||
assert.True(t, drops >= 80, fmt.Sprintf("expected to be greater than 80, but got %d", drops))
|
||||
}
|
||||
61
rest/handler/contentsecurityhandler.go
Normal file
61
rest/handler/contentsecurityhandler.go
Normal file
@@ -0,0 +1,61 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"zero/core/codec"
|
||||
"zero/core/logx"
|
||||
"zero/rest/httpx"
|
||||
"zero/rest/internal/security"
|
||||
)
|
||||
|
||||
const contentSecurity = "X-Content-Security"
|
||||
|
||||
type UnsignedCallback func(w http.ResponseWriter, r *http.Request, next http.Handler, strict bool, code int)
|
||||
|
||||
func ContentSecurityHandler(decrypters map[string]codec.RsaDecrypter, tolerance time.Duration,
|
||||
strict bool, callbacks ...UnsignedCallback) func(http.Handler) http.Handler {
|
||||
if len(callbacks) == 0 {
|
||||
callbacks = append(callbacks, handleVerificationFailure)
|
||||
}
|
||||
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
switch r.Method {
|
||||
case http.MethodDelete, http.MethodGet, http.MethodPost, http.MethodPut:
|
||||
header, err := security.ParseContentSecurity(decrypters, r)
|
||||
if err != nil {
|
||||
logx.Infof("Signature parse failed, X-Content-Security: %s, error: %s",
|
||||
r.Header.Get(contentSecurity), err.Error())
|
||||
executeCallbacks(w, r, next, strict, httpx.CodeSignatureInvalidHeader, callbacks)
|
||||
} else if code := security.VerifySignature(r, header, tolerance); code != httpx.CodeSignaturePass {
|
||||
logx.Infof("Signature verification failed, X-Content-Security: %s",
|
||||
r.Header.Get(contentSecurity))
|
||||
executeCallbacks(w, r, next, strict, code, callbacks)
|
||||
} else if r.ContentLength > 0 && header.Encrypted() {
|
||||
CryptionHandler(header.Key)(next).ServeHTTP(w, r)
|
||||
} else {
|
||||
next.ServeHTTP(w, r)
|
||||
}
|
||||
default:
|
||||
next.ServeHTTP(w, r)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func executeCallbacks(w http.ResponseWriter, r *http.Request, next http.Handler, strict bool,
|
||||
code int, callbacks []UnsignedCallback) {
|
||||
for _, callback := range callbacks {
|
||||
callback(w, r, next, strict, code)
|
||||
}
|
||||
}
|
||||
|
||||
func handleVerificationFailure(w http.ResponseWriter, r *http.Request, next http.Handler, strict bool, code int) {
|
||||
if strict {
|
||||
w.WriteHeader(http.StatusUnauthorized)
|
||||
} else {
|
||||
next.ServeHTTP(w, r)
|
||||
}
|
||||
}
|
||||
388
rest/handler/contentsecurityhandler_test.go
Normal file
388
rest/handler/contentsecurityhandler_test.go
Normal file
@@ -0,0 +1,388 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/sha256"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"log"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"os"
|
||||
"strconv"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"zero/core/codec"
|
||||
"zero/rest/httpx"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
const timeDiff = time.Hour * 2 * 24
|
||||
|
||||
var (
|
||||
fingerprint = "12345"
|
||||
pubKey = []byte(`-----BEGIN PUBLIC KEY-----
|
||||
MIGfMA0GCSqGSIb3DQEBAQUAA4GNADCBiQKBgQD7bq4FLG0ctccbEFEsUBuRxkjE
|
||||
eJ5U+0CAEjJk20V9/u2Fu76i1oKoShCs7GXtAFbDb5A/ImIXkPY62nAaxTGK4KVH
|
||||
miYbRgh5Fy6336KepLCtCmV/r0PKZeCyJH9uYLs7EuE1z9Hgm5UUjmpHDhJtkAwR
|
||||
my47YlhspwszKdRP+wIDAQAB
|
||||
-----END PUBLIC KEY-----`)
|
||||
priKey = []byte(`-----BEGIN RSA PRIVATE KEY-----
|
||||
MIICXAIBAAKBgQD7bq4FLG0ctccbEFEsUBuRxkjEeJ5U+0CAEjJk20V9/u2Fu76i
|
||||
1oKoShCs7GXtAFbDb5A/ImIXkPY62nAaxTGK4KVHmiYbRgh5Fy6336KepLCtCmV/
|
||||
r0PKZeCyJH9uYLs7EuE1z9Hgm5UUjmpHDhJtkAwRmy47YlhspwszKdRP+wIDAQAB
|
||||
AoGBANs1qf7UtuSbD1ZnKX5K8V5s07CHwPMygw+lzc3k5ndtNUStZQ2vnAaBXHyH
|
||||
Nm4lJ4AI2mhQ39jQB/1TyP1uAzvpLhT60fRybEq9zgJ/81Gm9bnaEpFJ9bP2bBrY
|
||||
J0jbaTMfbzL/PJFl3J3RGMR40C76h5yRYSnOpMoMiKWnJqrhAkEA/zCOkR+34Pk0
|
||||
Yo3sIP4ranY6AAvwacgNaui4ll5xeYwv3iLOQvPlpxIxFHKXEY0klNNyjjXqgYjP
|
||||
cOenqtt6UwJBAPw7EYuteVHvHvQVuTbKAaYHcOrp4nFeZF3ndFfl0w2dwGhfzcXO
|
||||
ROyd5dNQCuCWRo8JBpjG6PFyzezayF4KLrkCQCGditoxHG7FRRJKcbVy5dMzWbaR
|
||||
3AyDLslLeK1OKZKCVffkC9mj+TeF3PM9mQrV1eDI7ckv7wE7PWA5E8wc90MCQEOV
|
||||
MCZU3OTvRUPxbicYCUkLRV4sPNhTimD+21WR5vMHCb7trJ0Ln7wmsqXkFIYIve8l
|
||||
Y/cblN7c/AAyvu0znUECQA318nPldsxR6+H8HTS3uEbkL4UJdjQJHsvTwKxAw5qc
|
||||
moKExvRlN0zmGGuArKcqS38KG7PXZMrUv3FXPdp6BDQ=
|
||||
-----END RSA PRIVATE KEY-----`)
|
||||
key = []byte("q4t7w!z%C*F-JaNdRgUjXn2r5u8x/A?D")
|
||||
)
|
||||
|
||||
type requestSettings struct {
|
||||
method string
|
||||
url string
|
||||
body io.Reader
|
||||
strict bool
|
||||
crypt bool
|
||||
requestUri string
|
||||
timestamp int64
|
||||
fingerprint string
|
||||
missHeader bool
|
||||
signature string
|
||||
}
|
||||
|
||||
func init() {
|
||||
log.SetOutput(ioutil.Discard)
|
||||
}
|
||||
|
||||
func TestContentSecurityHandler(t *testing.T) {
|
||||
tests := []struct {
|
||||
method string
|
||||
url string
|
||||
body string
|
||||
strict bool
|
||||
crypt bool
|
||||
requestUri string
|
||||
timestamp int64
|
||||
fingerprint string
|
||||
missHeader bool
|
||||
signature string
|
||||
statusCode int
|
||||
}{
|
||||
{
|
||||
method: http.MethodGet,
|
||||
url: "http://localhost/a/b?c=d&e=f",
|
||||
strict: true,
|
||||
crypt: false,
|
||||
},
|
||||
{
|
||||
method: http.MethodPost,
|
||||
url: "http://localhost/a/b?c=d&e=f",
|
||||
body: "hello",
|
||||
strict: true,
|
||||
crypt: false,
|
||||
},
|
||||
{
|
||||
method: http.MethodGet,
|
||||
url: "http://localhost/a/b?c=d&e=f",
|
||||
strict: true,
|
||||
crypt: true,
|
||||
},
|
||||
{
|
||||
method: http.MethodPost,
|
||||
url: "http://localhost/a/b?c=d&e=f",
|
||||
body: "hello",
|
||||
strict: true,
|
||||
crypt: true,
|
||||
},
|
||||
{
|
||||
method: http.MethodGet,
|
||||
url: "http://localhost/a/b?c=d&e=f",
|
||||
strict: true,
|
||||
crypt: true,
|
||||
timestamp: time.Now().Add(timeDiff).Unix(),
|
||||
statusCode: http.StatusUnauthorized,
|
||||
},
|
||||
{
|
||||
method: http.MethodPost,
|
||||
url: "http://localhost/a/b?c=d&e=f",
|
||||
body: "hello",
|
||||
strict: true,
|
||||
crypt: true,
|
||||
timestamp: time.Now().Add(-timeDiff).Unix(),
|
||||
statusCode: http.StatusUnauthorized,
|
||||
},
|
||||
{
|
||||
method: http.MethodPost,
|
||||
url: "http://remotehost/",
|
||||
body: "hello",
|
||||
strict: true,
|
||||
crypt: true,
|
||||
requestUri: "http://localhost/a/b?c=d&e=f",
|
||||
},
|
||||
{
|
||||
method: http.MethodPost,
|
||||
url: "http://localhost/a/b?c=d&e=f",
|
||||
body: "hello",
|
||||
strict: false,
|
||||
crypt: true,
|
||||
fingerprint: "badone",
|
||||
},
|
||||
{
|
||||
method: http.MethodPost,
|
||||
url: "http://localhost/a/b?c=d&e=f",
|
||||
body: "hello",
|
||||
strict: true,
|
||||
crypt: true,
|
||||
timestamp: time.Now().Add(-timeDiff).Unix(),
|
||||
fingerprint: "badone",
|
||||
statusCode: http.StatusUnauthorized,
|
||||
},
|
||||
{
|
||||
method: http.MethodPost,
|
||||
url: "http://localhost/a/b?c=d&e=f",
|
||||
body: "hello",
|
||||
strict: true,
|
||||
crypt: true,
|
||||
missHeader: true,
|
||||
statusCode: http.StatusUnauthorized,
|
||||
},
|
||||
{
|
||||
method: http.MethodHead,
|
||||
url: "http://localhost/a/b?c=d&e=f",
|
||||
strict: true,
|
||||
crypt: false,
|
||||
},
|
||||
{
|
||||
method: http.MethodGet,
|
||||
url: "http://localhost/a/b?c=d&e=f",
|
||||
strict: true,
|
||||
crypt: false,
|
||||
signature: "badone",
|
||||
statusCode: http.StatusUnauthorized,
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
t.Run(test.url, func(t *testing.T) {
|
||||
if test.statusCode == 0 {
|
||||
test.statusCode = http.StatusOK
|
||||
}
|
||||
if len(test.fingerprint) == 0 {
|
||||
test.fingerprint = fingerprint
|
||||
}
|
||||
if test.timestamp == 0 {
|
||||
test.timestamp = time.Now().Unix()
|
||||
}
|
||||
|
||||
func() {
|
||||
keyFile, err := createTempFile(priKey)
|
||||
defer os.Remove(keyFile)
|
||||
|
||||
assert.Nil(t, err)
|
||||
decrypter, err := codec.NewRsaDecrypter(keyFile)
|
||||
assert.Nil(t, err)
|
||||
contentSecurityHandler := ContentSecurityHandler(map[string]codec.RsaDecrypter{
|
||||
fingerprint: decrypter,
|
||||
}, time.Hour, test.strict)
|
||||
handler := contentSecurityHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
}))
|
||||
|
||||
var reader io.Reader
|
||||
if len(test.body) > 0 {
|
||||
reader = strings.NewReader(test.body)
|
||||
}
|
||||
setting := requestSettings{
|
||||
method: test.method,
|
||||
url: test.url,
|
||||
body: reader,
|
||||
strict: test.strict,
|
||||
crypt: test.crypt,
|
||||
requestUri: test.requestUri,
|
||||
timestamp: test.timestamp,
|
||||
fingerprint: test.fingerprint,
|
||||
missHeader: test.missHeader,
|
||||
signature: test.signature,
|
||||
}
|
||||
req, err := buildRequest(setting)
|
||||
resp := httptest.NewRecorder()
|
||||
handler.ServeHTTP(resp, req)
|
||||
assert.Equal(t, test.statusCode, resp.Code)
|
||||
}()
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestContentSecurityHandler_UnsignedCallback(t *testing.T) {
|
||||
keyFile, err := createTempFile(priKey)
|
||||
defer os.Remove(keyFile)
|
||||
|
||||
assert.Nil(t, err)
|
||||
decrypter, err := codec.NewRsaDecrypter(keyFile)
|
||||
assert.Nil(t, err)
|
||||
contentSecurityHandler := ContentSecurityHandler(
|
||||
map[string]codec.RsaDecrypter{
|
||||
fingerprint: decrypter,
|
||||
},
|
||||
time.Hour,
|
||||
true,
|
||||
func(w http.ResponseWriter, r *http.Request, next http.Handler, strict bool, code int) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
})
|
||||
handler := contentSecurityHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}))
|
||||
|
||||
setting := requestSettings{
|
||||
method: http.MethodGet,
|
||||
url: "http://localhost/a/b?c=d&e=f",
|
||||
signature: "badone",
|
||||
}
|
||||
req, err := buildRequest(setting)
|
||||
resp := httptest.NewRecorder()
|
||||
handler.ServeHTTP(resp, req)
|
||||
assert.Equal(t, http.StatusOK, resp.Code)
|
||||
}
|
||||
|
||||
func TestContentSecurityHandler_UnsignedCallback_WrongTime(t *testing.T) {
|
||||
keyFile, err := createTempFile(priKey)
|
||||
defer os.Remove(keyFile)
|
||||
|
||||
assert.Nil(t, err)
|
||||
decrypter, err := codec.NewRsaDecrypter(keyFile)
|
||||
assert.Nil(t, err)
|
||||
contentSecurityHandler := ContentSecurityHandler(
|
||||
map[string]codec.RsaDecrypter{
|
||||
fingerprint: decrypter,
|
||||
},
|
||||
time.Hour,
|
||||
true,
|
||||
func(w http.ResponseWriter, r *http.Request, next http.Handler, strict bool, code int) {
|
||||
assert.Equal(t, httpx.CodeSignatureWrongTime, code)
|
||||
w.WriteHeader(http.StatusOK)
|
||||
})
|
||||
handler := contentSecurityHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}))
|
||||
|
||||
var reader io.Reader
|
||||
reader = strings.NewReader("hello")
|
||||
setting := requestSettings{
|
||||
method: http.MethodPost,
|
||||
url: "http://localhost/a/b?c=d&e=f",
|
||||
body: reader,
|
||||
strict: true,
|
||||
crypt: true,
|
||||
timestamp: time.Now().Add(time.Hour * 24 * 365).Unix(),
|
||||
fingerprint: fingerprint,
|
||||
}
|
||||
req, err := buildRequest(setting)
|
||||
resp := httptest.NewRecorder()
|
||||
handler.ServeHTTP(resp, req)
|
||||
assert.Equal(t, http.StatusOK, resp.Code)
|
||||
}
|
||||
|
||||
func buildRequest(rs requestSettings) (*http.Request, error) {
|
||||
var bodyStr string
|
||||
var err error
|
||||
|
||||
if rs.crypt && rs.body != nil {
|
||||
var buf bytes.Buffer
|
||||
io.Copy(&buf, rs.body)
|
||||
bodyBytes, err := codec.EcbEncrypt(key, buf.Bytes())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
bodyStr = base64.StdEncoding.EncodeToString(bodyBytes)
|
||||
}
|
||||
|
||||
r := httptest.NewRequest(rs.method, rs.url, strings.NewReader(bodyStr))
|
||||
if len(rs.signature) == 0 {
|
||||
sha := sha256.New()
|
||||
sha.Write([]byte(bodyStr))
|
||||
bodySign := fmt.Sprintf("%x", sha.Sum(nil))
|
||||
var path string
|
||||
var query string
|
||||
if len(rs.requestUri) > 0 {
|
||||
if u, err := url.Parse(rs.requestUri); err != nil {
|
||||
return nil, err
|
||||
} else {
|
||||
path = u.Path
|
||||
query = u.RawQuery
|
||||
}
|
||||
} else {
|
||||
path = r.URL.Path
|
||||
query = r.URL.RawQuery
|
||||
}
|
||||
contentOfSign := strings.Join([]string{
|
||||
strconv.FormatInt(rs.timestamp, 10),
|
||||
rs.method,
|
||||
path,
|
||||
query,
|
||||
bodySign,
|
||||
}, "\n")
|
||||
rs.signature = codec.HmacBase64([]byte(key), contentOfSign)
|
||||
}
|
||||
|
||||
var mode string
|
||||
if rs.crypt {
|
||||
mode = "1"
|
||||
} else {
|
||||
mode = "0"
|
||||
}
|
||||
content := strings.Join([]string{
|
||||
"version=v1",
|
||||
"type=" + mode,
|
||||
fmt.Sprintf("key=%s", base64.StdEncoding.EncodeToString(key)),
|
||||
"time=" + strconv.FormatInt(rs.timestamp, 10),
|
||||
}, "; ")
|
||||
|
||||
encrypter, err := codec.NewRsaEncrypter([]byte(pubKey))
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
output, err := encrypter.Encrypt([]byte(content))
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
encryptedContent := base64.StdEncoding.EncodeToString(output)
|
||||
if !rs.missHeader {
|
||||
r.Header.Set(httpx.ContentSecurity, strings.Join([]string{
|
||||
fmt.Sprintf("key=%s", rs.fingerprint),
|
||||
"secret=" + encryptedContent,
|
||||
"signature=" + rs.signature,
|
||||
}, "; "))
|
||||
}
|
||||
if len(rs.requestUri) > 0 {
|
||||
r.Header.Set("X-Request-Uri", rs.requestUri)
|
||||
}
|
||||
|
||||
return r, nil
|
||||
}
|
||||
|
||||
func createTempFile(body []byte) (string, error) {
|
||||
tmpFile, err := ioutil.TempFile(os.TempDir(), "go-unit-*.tmp")
|
||||
if err != nil {
|
||||
return "", err
|
||||
} else {
|
||||
tmpFile.Close()
|
||||
}
|
||||
|
||||
err = ioutil.WriteFile(tmpFile.Name(), body, os.ModePerm)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return tmpFile.Name(), nil
|
||||
}
|
||||
101
rest/handler/cryptionhandler.go
Normal file
101
rest/handler/cryptionhandler.go
Normal file
@@ -0,0 +1,101 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/base64"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
|
||||
"zero/core/codec"
|
||||
"zero/core/logx"
|
||||
)
|
||||
|
||||
const maxBytes = 1 << 20 // 1 MiB
|
||||
|
||||
func CryptionHandler(key []byte) func(http.Handler) http.Handler {
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
cw := newCryptionResponseWriter(w)
|
||||
defer cw.flush(key)
|
||||
|
||||
if r.ContentLength <= 0 {
|
||||
next.ServeHTTP(cw, r)
|
||||
return
|
||||
}
|
||||
|
||||
if err := decryptBody(key, r); err != nil {
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
next.ServeHTTP(cw, r)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func decryptBody(key []byte, r *http.Request) error {
|
||||
content, err := ioutil.ReadAll(io.LimitReader(r.Body, maxBytes))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
content, err = base64.StdEncoding.DecodeString(string(content))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
output, err := codec.EcbDecrypt(key, content)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var buf bytes.Buffer
|
||||
buf.Write(output)
|
||||
r.Body = ioutil.NopCloser(&buf)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
type cryptionResponseWriter struct {
|
||||
http.ResponseWriter
|
||||
buf *bytes.Buffer
|
||||
}
|
||||
|
||||
func newCryptionResponseWriter(w http.ResponseWriter) *cryptionResponseWriter {
|
||||
return &cryptionResponseWriter{
|
||||
ResponseWriter: w,
|
||||
buf: new(bytes.Buffer),
|
||||
}
|
||||
}
|
||||
|
||||
func (w *cryptionResponseWriter) Header() http.Header {
|
||||
return w.ResponseWriter.Header()
|
||||
}
|
||||
|
||||
func (w *cryptionResponseWriter) Write(p []byte) (int, error) {
|
||||
return w.buf.Write(p)
|
||||
}
|
||||
|
||||
func (w *cryptionResponseWriter) WriteHeader(statusCode int) {
|
||||
w.ResponseWriter.WriteHeader(statusCode)
|
||||
}
|
||||
|
||||
func (w *cryptionResponseWriter) flush(key []byte) {
|
||||
if w.buf.Len() == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
content, err := codec.EcbEncrypt(key, w.buf.Bytes())
|
||||
if err != nil {
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
body := base64.StdEncoding.EncodeToString(content)
|
||||
if n, err := io.WriteString(w.ResponseWriter, body); err != nil {
|
||||
logx.Errorf("write response failed, error: %s", err)
|
||||
} else if n < len(content) {
|
||||
logx.Errorf("actual bytes: %d, written bytes: %d", len(content), n)
|
||||
}
|
||||
}
|
||||
90
rest/handler/cryptionhandler_test.go
Normal file
90
rest/handler/cryptionhandler_test.go
Normal file
@@ -0,0 +1,90 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/base64"
|
||||
"io/ioutil"
|
||||
"log"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"zero/core/codec"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
const (
|
||||
reqText = "ping"
|
||||
respText = "pong"
|
||||
)
|
||||
|
||||
var aesKey = []byte(`PdSgVkYp3s6v9y$B&E)H+MbQeThWmZq4`)
|
||||
|
||||
func init() {
|
||||
log.SetOutput(ioutil.Discard)
|
||||
}
|
||||
|
||||
func TestCryptionHandlerGet(t *testing.T) {
|
||||
req := httptest.NewRequest(http.MethodGet, "/any", nil)
|
||||
handler := CryptionHandler(aesKey)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
_, err := w.Write([]byte(respText))
|
||||
w.Header().Set("X-Test", "test")
|
||||
assert.Nil(t, err)
|
||||
}))
|
||||
recorder := httptest.NewRecorder()
|
||||
handler.ServeHTTP(recorder, req)
|
||||
|
||||
expect, err := codec.EcbEncrypt(aesKey, []byte(respText))
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, http.StatusOK, recorder.Code)
|
||||
assert.Equal(t, "test", recorder.Header().Get("X-Test"))
|
||||
assert.Equal(t, base64.StdEncoding.EncodeToString(expect), recorder.Body.String())
|
||||
}
|
||||
|
||||
func TestCryptionHandlerPost(t *testing.T) {
|
||||
var buf bytes.Buffer
|
||||
enc, err := codec.EcbEncrypt(aesKey, []byte(reqText))
|
||||
assert.Nil(t, err)
|
||||
buf.WriteString(base64.StdEncoding.EncodeToString(enc))
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/any", &buf)
|
||||
handler := CryptionHandler(aesKey)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
body, err := ioutil.ReadAll(r.Body)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, reqText, string(body))
|
||||
|
||||
w.Write([]byte(respText))
|
||||
}))
|
||||
recorder := httptest.NewRecorder()
|
||||
handler.ServeHTTP(recorder, req)
|
||||
|
||||
expect, err := codec.EcbEncrypt(aesKey, []byte(respText))
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, http.StatusOK, recorder.Code)
|
||||
assert.Equal(t, base64.StdEncoding.EncodeToString(expect), recorder.Body.String())
|
||||
}
|
||||
|
||||
func TestCryptionHandlerPostBadEncryption(t *testing.T) {
|
||||
var buf bytes.Buffer
|
||||
enc, err := codec.EcbEncrypt(aesKey, []byte(reqText))
|
||||
assert.Nil(t, err)
|
||||
buf.Write(enc)
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/any", &buf)
|
||||
handler := CryptionHandler(aesKey)(nil)
|
||||
recorder := httptest.NewRecorder()
|
||||
handler.ServeHTTP(recorder, req)
|
||||
|
||||
assert.Equal(t, http.StatusBadRequest, recorder.Code)
|
||||
}
|
||||
|
||||
func TestCryptionHandlerWriteHeader(t *testing.T) {
|
||||
req := httptest.NewRequest(http.MethodGet, "/any", nil)
|
||||
handler := CryptionHandler(aesKey)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusServiceUnavailable)
|
||||
}))
|
||||
recorder := httptest.NewRecorder()
|
||||
handler.ServeHTTP(recorder, req)
|
||||
assert.Equal(t, http.StatusServiceUnavailable, recorder.Code)
|
||||
}
|
||||
27
rest/handler/gunziphandler.go
Normal file
27
rest/handler/gunziphandler.go
Normal file
@@ -0,0 +1,27 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"compress/gzip"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"zero/rest/httpx"
|
||||
)
|
||||
|
||||
const gzipEncoding = "gzip"
|
||||
|
||||
func GunzipHandler(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if strings.Contains(r.Header.Get(httpx.ContentEncoding), gzipEncoding) {
|
||||
reader, err := gzip.NewReader(r.Body)
|
||||
if err != nil {
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
r.Body = reader
|
||||
}
|
||||
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
66
rest/handler/gunziphandler_test.go
Normal file
66
rest/handler/gunziphandler_test.go
Normal file
@@ -0,0 +1,66 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"zero/core/codec"
|
||||
"zero/rest/httpx"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestGunzipHandler(t *testing.T) {
|
||||
const message = "hello world"
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(1)
|
||||
handler := GunzipHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
body, err := ioutil.ReadAll(r.Body)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, string(body), message)
|
||||
wg.Done()
|
||||
}))
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "http://localhost",
|
||||
bytes.NewReader(codec.Gzip([]byte(message))))
|
||||
req.Header.Set(httpx.ContentEncoding, gzipEncoding)
|
||||
resp := httptest.NewRecorder()
|
||||
handler.ServeHTTP(resp, req)
|
||||
assert.Equal(t, http.StatusOK, resp.Code)
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
func TestGunzipHandler_NoGzip(t *testing.T) {
|
||||
const message = "hello world"
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(1)
|
||||
handler := GunzipHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
body, err := ioutil.ReadAll(r.Body)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, string(body), message)
|
||||
wg.Done()
|
||||
}))
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "http://localhost",
|
||||
strings.NewReader(message))
|
||||
resp := httptest.NewRecorder()
|
||||
handler.ServeHTTP(resp, req)
|
||||
assert.Equal(t, http.StatusOK, resp.Code)
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
func TestGunzipHandler_NoGzipButTelling(t *testing.T) {
|
||||
const message = "hello world"
|
||||
handler := GunzipHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}))
|
||||
req := httptest.NewRequest(http.MethodPost, "http://localhost",
|
||||
strings.NewReader(message))
|
||||
req.Header.Set(httpx.ContentEncoding, gzipEncoding)
|
||||
resp := httptest.NewRecorder()
|
||||
handler.ServeHTTP(resp, req)
|
||||
assert.Equal(t, http.StatusBadRequest, resp.Code)
|
||||
}
|
||||
165
rest/handler/loghandler.go
Normal file
165
rest/handler/loghandler.go
Normal file
@@ -0,0 +1,165 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httputil"
|
||||
"time"
|
||||
|
||||
"zero/core/iox"
|
||||
"zero/core/logx"
|
||||
"zero/core/timex"
|
||||
"zero/core/utils"
|
||||
"zero/rest/internal"
|
||||
)
|
||||
|
||||
const slowThreshold = time.Millisecond * 500
|
||||
|
||||
type LoggedResponseWriter struct {
|
||||
w http.ResponseWriter
|
||||
r *http.Request
|
||||
code int
|
||||
}
|
||||
|
||||
func (w *LoggedResponseWriter) Header() http.Header {
|
||||
return w.w.Header()
|
||||
}
|
||||
|
||||
func (w *LoggedResponseWriter) Write(bytes []byte) (int, error) {
|
||||
return w.w.Write(bytes)
|
||||
}
|
||||
|
||||
func (w *LoggedResponseWriter) WriteHeader(code int) {
|
||||
w.w.WriteHeader(code)
|
||||
w.code = code
|
||||
}
|
||||
|
||||
func LogHandler(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
timer := utils.NewElapsedTimer()
|
||||
logs := new(internal.LogCollector)
|
||||
lrw := LoggedResponseWriter{
|
||||
w: w,
|
||||
r: r,
|
||||
code: http.StatusOK,
|
||||
}
|
||||
|
||||
var dup io.ReadCloser
|
||||
r.Body, dup = iox.DupReadCloser(r.Body)
|
||||
next.ServeHTTP(&lrw, r.WithContext(context.WithValue(r.Context(), internal.LogContext, logs)))
|
||||
r.Body = dup
|
||||
logBrief(r, lrw.code, timer, logs)
|
||||
})
|
||||
}
|
||||
|
||||
type DetailLoggedResponseWriter struct {
|
||||
writer *LoggedResponseWriter
|
||||
buf *bytes.Buffer
|
||||
}
|
||||
|
||||
func newDetailLoggedResponseWriter(writer *LoggedResponseWriter, buf *bytes.Buffer) *DetailLoggedResponseWriter {
|
||||
return &DetailLoggedResponseWriter{
|
||||
writer: writer,
|
||||
buf: buf,
|
||||
}
|
||||
}
|
||||
|
||||
func (w *DetailLoggedResponseWriter) Header() http.Header {
|
||||
return w.writer.Header()
|
||||
}
|
||||
|
||||
func (w *DetailLoggedResponseWriter) Write(bs []byte) (int, error) {
|
||||
w.buf.Write(bs)
|
||||
return w.writer.Write(bs)
|
||||
}
|
||||
|
||||
func (w *DetailLoggedResponseWriter) WriteHeader(code int) {
|
||||
w.writer.WriteHeader(code)
|
||||
}
|
||||
|
||||
func DetailedLogHandler(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
timer := utils.NewElapsedTimer()
|
||||
var buf bytes.Buffer
|
||||
lrw := newDetailLoggedResponseWriter(&LoggedResponseWriter{
|
||||
w: w,
|
||||
r: r,
|
||||
code: http.StatusOK,
|
||||
}, &buf)
|
||||
|
||||
var dup io.ReadCloser
|
||||
r.Body, dup = iox.DupReadCloser(r.Body)
|
||||
logs := new(internal.LogCollector)
|
||||
next.ServeHTTP(lrw, r.WithContext(context.WithValue(r.Context(), internal.LogContext, logs)))
|
||||
r.Body = dup
|
||||
logDetails(r, lrw, timer, logs)
|
||||
})
|
||||
}
|
||||
|
||||
func dumpRequest(r *http.Request) string {
|
||||
reqContent, err := httputil.DumpRequest(r, true)
|
||||
if err != nil {
|
||||
return err.Error()
|
||||
} else {
|
||||
return string(reqContent)
|
||||
}
|
||||
}
|
||||
|
||||
func logBrief(r *http.Request, code int, timer *utils.ElapsedTimer, logs *internal.LogCollector) {
|
||||
var buf bytes.Buffer
|
||||
duration := timer.Duration()
|
||||
buf.WriteString(fmt.Sprintf("%d - %s - %s - %s - %s",
|
||||
code, r.RequestURI, internal.GetRemoteAddr(r), r.UserAgent(), timex.ReprOfDuration(duration)))
|
||||
if duration > slowThreshold {
|
||||
logx.Slowf("[HTTP] %d - %s - %s - %s - slowcall(%s)",
|
||||
code, r.RequestURI, internal.GetRemoteAddr(r), r.UserAgent(), timex.ReprOfDuration(duration))
|
||||
}
|
||||
|
||||
ok := isOkResponse(code)
|
||||
if !ok {
|
||||
buf.WriteString(fmt.Sprintf("\n%s", dumpRequest(r)))
|
||||
}
|
||||
|
||||
body := logs.Flush()
|
||||
if len(body) > 0 {
|
||||
buf.WriteString(fmt.Sprintf("\n%s", body))
|
||||
}
|
||||
|
||||
if ok {
|
||||
logx.Info(buf.String())
|
||||
} else {
|
||||
logx.Error(buf.String())
|
||||
}
|
||||
}
|
||||
|
||||
func logDetails(r *http.Request, response *DetailLoggedResponseWriter, timer *utils.ElapsedTimer,
|
||||
logs *internal.LogCollector) {
|
||||
var buf bytes.Buffer
|
||||
duration := timer.Duration()
|
||||
buf.WriteString(fmt.Sprintf("%d - %s - %s\n=> %s\n",
|
||||
response.writer.code, r.RemoteAddr, timex.ReprOfDuration(duration), dumpRequest(r)))
|
||||
if duration > slowThreshold {
|
||||
logx.Slowf("[HTTP] %d - %s - slowcall(%s)\n=> %s\n",
|
||||
response.writer.code, r.RemoteAddr, timex.ReprOfDuration(duration), dumpRequest(r))
|
||||
}
|
||||
|
||||
body := logs.Flush()
|
||||
if len(body) > 0 {
|
||||
buf.WriteString(fmt.Sprintf("%s\n", body))
|
||||
}
|
||||
|
||||
respBuf := response.buf.Bytes()
|
||||
if len(respBuf) > 0 {
|
||||
buf.WriteString(fmt.Sprintf("<= %s", respBuf))
|
||||
}
|
||||
|
||||
logx.Info(buf.String())
|
||||
}
|
||||
|
||||
func isOkResponse(code int) bool {
|
||||
// not server error
|
||||
return code < http.StatusInternalServerError
|
||||
}
|
||||
74
rest/handler/loghandler_test.go
Normal file
74
rest/handler/loghandler_test.go
Normal file
@@ -0,0 +1,74 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"io/ioutil"
|
||||
"log"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"zero/rest/internal"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func init() {
|
||||
log.SetOutput(ioutil.Discard)
|
||||
}
|
||||
|
||||
func TestLogHandler(t *testing.T) {
|
||||
handlers := []func(handler http.Handler) http.Handler{
|
||||
LogHandler,
|
||||
DetailedLogHandler,
|
||||
}
|
||||
|
||||
for _, logHandler := range handlers {
|
||||
req := httptest.NewRequest(http.MethodGet, "http://localhost", nil)
|
||||
handler := logHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
r.Context().Value(internal.LogContext).(*internal.LogCollector).Append("anything")
|
||||
w.Header().Set("X-Test", "test")
|
||||
w.WriteHeader(http.StatusServiceUnavailable)
|
||||
_, err := w.Write([]byte("content"))
|
||||
assert.Nil(t, err)
|
||||
}))
|
||||
|
||||
resp := httptest.NewRecorder()
|
||||
handler.ServeHTTP(resp, req)
|
||||
assert.Equal(t, http.StatusServiceUnavailable, resp.Code)
|
||||
assert.Equal(t, "test", resp.Header().Get("X-Test"))
|
||||
assert.Equal(t, "content", resp.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestLogHandlerSlow(t *testing.T) {
|
||||
handlers := []func(handler http.Handler) http.Handler{
|
||||
LogHandler,
|
||||
DetailedLogHandler,
|
||||
}
|
||||
|
||||
for _, logHandler := range handlers {
|
||||
req := httptest.NewRequest(http.MethodGet, "http://localhost", nil)
|
||||
handler := logHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
time.Sleep(slowThreshold + time.Millisecond*50)
|
||||
}))
|
||||
|
||||
resp := httptest.NewRecorder()
|
||||
handler.ServeHTTP(resp, req)
|
||||
assert.Equal(t, http.StatusOK, resp.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkLogHandler(b *testing.B) {
|
||||
b.ReportAllocs()
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "http://localhost", nil)
|
||||
handler := LogHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
resp := httptest.NewRecorder()
|
||||
handler.ServeHTTP(resp, req)
|
||||
}
|
||||
}
|
||||
27
rest/handler/maxbyteshandler.go
Normal file
27
rest/handler/maxbyteshandler.go
Normal file
@@ -0,0 +1,27 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"zero/rest/internal"
|
||||
)
|
||||
|
||||
func MaxBytesHandler(n int64) func(http.Handler) http.Handler {
|
||||
if n <= 0 {
|
||||
return func(next http.Handler) http.Handler {
|
||||
return next
|
||||
}
|
||||
}
|
||||
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.ContentLength > n {
|
||||
internal.Errorf(r, "request entity too large, limit is %d, but got %d, rejected with code %d",
|
||||
n, r.ContentLength, http.StatusRequestEntityTooLarge)
|
||||
w.WriteHeader(http.StatusRequestEntityTooLarge)
|
||||
} else {
|
||||
next.ServeHTTP(w, r)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
37
rest/handler/maxbyteshandler_test.go
Normal file
37
rest/handler/maxbyteshandler_test.go
Normal file
@@ -0,0 +1,37 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestMaxBytesHandler(t *testing.T) {
|
||||
maxb := MaxBytesHandler(10)
|
||||
handler := maxb(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}))
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "http://localhost",
|
||||
bytes.NewBufferString("123456789012345"))
|
||||
resp := httptest.NewRecorder()
|
||||
handler.ServeHTTP(resp, req)
|
||||
assert.Equal(t, http.StatusRequestEntityTooLarge, resp.Code)
|
||||
|
||||
req = httptest.NewRequest(http.MethodPost, "http://localhost", bytes.NewBufferString("12345"))
|
||||
resp = httptest.NewRecorder()
|
||||
handler.ServeHTTP(resp, req)
|
||||
assert.Equal(t, http.StatusOK, resp.Code)
|
||||
}
|
||||
|
||||
func TestMaxBytesHandlerNoLimit(t *testing.T) {
|
||||
maxb := MaxBytesHandler(-1)
|
||||
handler := maxb(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}))
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "http://localhost",
|
||||
bytes.NewBufferString("123456789012345"))
|
||||
resp := httptest.NewRecorder()
|
||||
handler.ServeHTTP(resp, req)
|
||||
assert.Equal(t, http.StatusOK, resp.Code)
|
||||
}
|
||||
37
rest/handler/maxconnshandler.go
Normal file
37
rest/handler/maxconnshandler.go
Normal file
@@ -0,0 +1,37 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"zero/core/logx"
|
||||
"zero/core/syncx"
|
||||
"zero/rest/internal"
|
||||
)
|
||||
|
||||
func MaxConns(n int) func(http.Handler) http.Handler {
|
||||
if n <= 0 {
|
||||
return func(next http.Handler) http.Handler {
|
||||
return next
|
||||
}
|
||||
}
|
||||
|
||||
return func(next http.Handler) http.Handler {
|
||||
latchLimiter := syncx.NewLimit(n)
|
||||
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if latchLimiter.TryBorrow() {
|
||||
defer func() {
|
||||
if err := latchLimiter.Return(); err != nil {
|
||||
logx.Error(err)
|
||||
}
|
||||
}()
|
||||
|
||||
next.ServeHTTP(w, r)
|
||||
} else {
|
||||
internal.Errorf(r, "Concurrent connections over %d, rejected with code %d",
|
||||
n, http.StatusServiceUnavailable)
|
||||
w.WriteHeader(http.StatusServiceUnavailable)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
80
rest/handler/maxconnshandler_test.go
Normal file
80
rest/handler/maxconnshandler_test.go
Normal file
@@ -0,0 +1,80 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"io/ioutil"
|
||||
"log"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"zero/core/lang"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
const conns = 4
|
||||
|
||||
func init() {
|
||||
log.SetOutput(ioutil.Discard)
|
||||
}
|
||||
|
||||
func TestMaxConnsHandler(t *testing.T) {
|
||||
var waitGroup sync.WaitGroup
|
||||
waitGroup.Add(conns)
|
||||
done := make(chan lang.PlaceholderType)
|
||||
defer close(done)
|
||||
|
||||
maxConns := MaxConns(conns)
|
||||
handler := maxConns(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
waitGroup.Done()
|
||||
<-done
|
||||
}))
|
||||
|
||||
for i := 0; i < conns; i++ {
|
||||
go func() {
|
||||
req := httptest.NewRequest(http.MethodGet, "http://localhost", nil)
|
||||
handler.ServeHTTP(httptest.NewRecorder(), req)
|
||||
}()
|
||||
}
|
||||
|
||||
waitGroup.Wait()
|
||||
req := httptest.NewRequest(http.MethodGet, "http://localhost", nil)
|
||||
resp := httptest.NewRecorder()
|
||||
handler.ServeHTTP(resp, req)
|
||||
assert.Equal(t, http.StatusServiceUnavailable, resp.Code)
|
||||
}
|
||||
|
||||
func TestWithoutMaxConnsHandler(t *testing.T) {
|
||||
const (
|
||||
key = "block"
|
||||
value = "1"
|
||||
)
|
||||
var waitGroup sync.WaitGroup
|
||||
waitGroup.Add(conns)
|
||||
done := make(chan lang.PlaceholderType)
|
||||
defer close(done)
|
||||
|
||||
maxConns := MaxConns(0)
|
||||
handler := maxConns(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
val := r.Header.Get(key)
|
||||
if val == value {
|
||||
waitGroup.Done()
|
||||
<-done
|
||||
}
|
||||
}))
|
||||
|
||||
for i := 0; i < conns; i++ {
|
||||
go func() {
|
||||
req := httptest.NewRequest(http.MethodGet, "http://localhost", nil)
|
||||
req.Header.Set(key, value)
|
||||
handler.ServeHTTP(httptest.NewRecorder(), req)
|
||||
}()
|
||||
}
|
||||
|
||||
waitGroup.Wait()
|
||||
req := httptest.NewRequest(http.MethodGet, "http://localhost", nil)
|
||||
resp := httptest.NewRecorder()
|
||||
handler.ServeHTTP(resp, req)
|
||||
assert.Equal(t, http.StatusOK, resp.Code)
|
||||
}
|
||||
23
rest/handler/metrichandler.go
Normal file
23
rest/handler/metrichandler.go
Normal file
@@ -0,0 +1,23 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"zero/core/stat"
|
||||
"zero/core/timex"
|
||||
)
|
||||
|
||||
func MetricHandler(metrics *stat.Metrics) func(http.Handler) http.Handler {
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
startTime := timex.Now()
|
||||
defer func() {
|
||||
metrics.Add(stat.Task{
|
||||
Duration: timex.Since(startTime),
|
||||
})
|
||||
}()
|
||||
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
}
|
||||
24
rest/handler/metrichandler_test.go
Normal file
24
rest/handler/metrichandler_test.go
Normal file
@@ -0,0 +1,24 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"zero/core/stat"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestMetricHandler(t *testing.T) {
|
||||
metrics := stat.NewMetrics("unit-test")
|
||||
metricHandler := MetricHandler(metrics)
|
||||
handler := metricHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "http://localhost", nil)
|
||||
resp := httptest.NewRecorder()
|
||||
handler.ServeHTTP(resp, req)
|
||||
assert.Equal(t, http.StatusOK, resp.Code)
|
||||
}
|
||||
47
rest/handler/prommetrichandler.go
Normal file
47
rest/handler/prommetrichandler.go
Normal file
@@ -0,0 +1,47 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"zero/core/metric"
|
||||
"zero/core/timex"
|
||||
"zero/rest/internal/security"
|
||||
)
|
||||
|
||||
const serverNamespace = "http_server"
|
||||
|
||||
var (
|
||||
metricServerReqDur = metric.NewHistogramVec(&metric.HistogramVecOpts{
|
||||
Namespace: serverNamespace,
|
||||
Subsystem: "requests",
|
||||
Name: "duration_ms",
|
||||
Help: "http server requests duration(ms).",
|
||||
Labels: []string{"path"},
|
||||
Buckets: []float64{5, 10, 25, 50, 100, 250, 500, 1000},
|
||||
})
|
||||
|
||||
metricServerReqCodeTotal = metric.NewCounterVec(&metric.CounterVecOpts{
|
||||
Namespace: serverNamespace,
|
||||
Subsystem: "requests",
|
||||
Name: "code_total",
|
||||
Help: "http server requests error count.",
|
||||
Labels: []string{"path", "code"},
|
||||
})
|
||||
)
|
||||
|
||||
func PromMetricHandler(path string) func(http.Handler) http.Handler {
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
startTime := timex.Now()
|
||||
cw := &security.WithCodeResponseWriter{Writer: w}
|
||||
defer func() {
|
||||
metricServerReqDur.Observe(int64(timex.Since(startTime)/time.Millisecond), path)
|
||||
metricServerReqCodeTotal.Inc(path, strconv.Itoa(cw.Code))
|
||||
}()
|
||||
|
||||
next.ServeHTTP(cw, r)
|
||||
})
|
||||
}
|
||||
}
|
||||
21
rest/handler/prommetrichandler_test.go
Normal file
21
rest/handler/prommetrichandler_test.go
Normal file
@@ -0,0 +1,21 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestPromMetricHandler(t *testing.T) {
|
||||
promMetricHandler := PromMetricHandler("/user/login")
|
||||
handler := promMetricHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "http://localhost", nil)
|
||||
resp := httptest.NewRecorder()
|
||||
handler.ServeHTTP(resp, req)
|
||||
assert.Equal(t, http.StatusOK, resp.Code)
|
||||
}
|
||||
22
rest/handler/recoverhandler.go
Normal file
22
rest/handler/recoverhandler.go
Normal file
@@ -0,0 +1,22 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"runtime/debug"
|
||||
|
||||
"zero/rest/internal"
|
||||
)
|
||||
|
||||
func RecoverHandler(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
defer func() {
|
||||
if result := recover(); result != nil {
|
||||
internal.Error(r, fmt.Sprintf("%v\n%s", result, debug.Stack()))
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
}
|
||||
}()
|
||||
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
36
rest/handler/recoverhandler_test.go
Normal file
36
rest/handler/recoverhandler_test.go
Normal file
@@ -0,0 +1,36 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"io/ioutil"
|
||||
"log"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func init() {
|
||||
log.SetOutput(ioutil.Discard)
|
||||
}
|
||||
|
||||
func TestWithPanic(t *testing.T) {
|
||||
handler := RecoverHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
panic("whatever")
|
||||
}))
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "http://localhost", nil)
|
||||
resp := httptest.NewRecorder()
|
||||
handler.ServeHTTP(resp, req)
|
||||
assert.Equal(t, http.StatusInternalServerError, resp.Code)
|
||||
}
|
||||
|
||||
func TestWithoutPanic(t *testing.T) {
|
||||
handler := RecoverHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
}))
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "http://localhost", nil)
|
||||
resp := httptest.NewRecorder()
|
||||
handler.ServeHTTP(resp, req)
|
||||
assert.Equal(t, http.StatusOK, resp.Code)
|
||||
}
|
||||
63
rest/handler/sheddinghandler.go
Normal file
63
rest/handler/sheddinghandler.go
Normal file
@@ -0,0 +1,63 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"sync"
|
||||
|
||||
"zero/core/load"
|
||||
"zero/core/logx"
|
||||
"zero/core/stat"
|
||||
"zero/rest/internal"
|
||||
"zero/rest/internal/security"
|
||||
)
|
||||
|
||||
const serviceType = "api"
|
||||
|
||||
var (
|
||||
sheddingStat *load.SheddingStat
|
||||
lock sync.Mutex
|
||||
)
|
||||
|
||||
func SheddingHandler(shedder load.Shedder, metrics *stat.Metrics) func(http.Handler) http.Handler {
|
||||
if shedder == nil {
|
||||
return func(next http.Handler) http.Handler {
|
||||
return next
|
||||
}
|
||||
}
|
||||
|
||||
ensureSheddingStat()
|
||||
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
sheddingStat.IncrementTotal()
|
||||
promise, err := shedder.Allow()
|
||||
if err != nil {
|
||||
metrics.AddDrop()
|
||||
sheddingStat.IncrementDrop()
|
||||
logx.Errorf("[http] dropped, %s - %s - %s",
|
||||
r.RequestURI, internal.GetRemoteAddr(r), r.UserAgent())
|
||||
w.WriteHeader(http.StatusServiceUnavailable)
|
||||
return
|
||||
}
|
||||
|
||||
cw := &security.WithCodeResponseWriter{Writer: w}
|
||||
defer func() {
|
||||
if cw.Code == http.StatusServiceUnavailable {
|
||||
promise.Fail()
|
||||
} else {
|
||||
sheddingStat.IncrementPass()
|
||||
promise.Pass()
|
||||
}
|
||||
}()
|
||||
next.ServeHTTP(cw, r)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func ensureSheddingStat() {
|
||||
lock.Lock()
|
||||
if sheddingStat == nil {
|
||||
sheddingStat = load.NewSheddingStat(serviceType)
|
||||
}
|
||||
lock.Unlock()
|
||||
}
|
||||
105
rest/handler/sheddinghandler_test.go
Normal file
105
rest/handler/sheddinghandler_test.go
Normal file
@@ -0,0 +1,105 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"io/ioutil"
|
||||
"log"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"zero/core/load"
|
||||
"zero/core/stat"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func init() {
|
||||
log.SetOutput(ioutil.Discard)
|
||||
}
|
||||
|
||||
func TestSheddingHandlerAccept(t *testing.T) {
|
||||
metrics := stat.NewMetrics("unit-test")
|
||||
shedder := mockShedder{
|
||||
allow: true,
|
||||
}
|
||||
sheddingHandler := SheddingHandler(shedder, metrics)
|
||||
handler := sheddingHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("X-Test", "test")
|
||||
_, err := w.Write([]byte("content"))
|
||||
assert.Nil(t, err)
|
||||
}))
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "http://localhost", nil)
|
||||
req.Header.Set("X-Test", "test")
|
||||
resp := httptest.NewRecorder()
|
||||
handler.ServeHTTP(resp, req)
|
||||
assert.Equal(t, http.StatusOK, resp.Code)
|
||||
assert.Equal(t, "test", resp.Header().Get("X-Test"))
|
||||
assert.Equal(t, "content", resp.Body.String())
|
||||
}
|
||||
|
||||
func TestSheddingHandlerFail(t *testing.T) {
|
||||
metrics := stat.NewMetrics("unit-test")
|
||||
shedder := mockShedder{
|
||||
allow: true,
|
||||
}
|
||||
sheddingHandler := SheddingHandler(shedder, metrics)
|
||||
handler := sheddingHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusServiceUnavailable)
|
||||
}))
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "http://localhost", nil)
|
||||
resp := httptest.NewRecorder()
|
||||
handler.ServeHTTP(resp, req)
|
||||
assert.Equal(t, http.StatusServiceUnavailable, resp.Code)
|
||||
}
|
||||
|
||||
func TestSheddingHandlerReject(t *testing.T) {
|
||||
metrics := stat.NewMetrics("unit-test")
|
||||
shedder := mockShedder{
|
||||
allow: false,
|
||||
}
|
||||
sheddingHandler := SheddingHandler(shedder, metrics)
|
||||
handler := sheddingHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "http://localhost", nil)
|
||||
resp := httptest.NewRecorder()
|
||||
handler.ServeHTTP(resp, req)
|
||||
assert.Equal(t, http.StatusServiceUnavailable, resp.Code)
|
||||
}
|
||||
|
||||
func TestSheddingHandlerNoShedding(t *testing.T) {
|
||||
metrics := stat.NewMetrics("unit-test")
|
||||
sheddingHandler := SheddingHandler(nil, metrics)
|
||||
handler := sheddingHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "http://localhost", nil)
|
||||
resp := httptest.NewRecorder()
|
||||
handler.ServeHTTP(resp, req)
|
||||
assert.Equal(t, http.StatusOK, resp.Code)
|
||||
}
|
||||
|
||||
type mockShedder struct {
|
||||
allow bool
|
||||
}
|
||||
|
||||
func (s mockShedder) Allow() (load.Promise, error) {
|
||||
if s.allow {
|
||||
return mockPromise{}, nil
|
||||
} else {
|
||||
return nil, load.ErrServiceOverloaded
|
||||
}
|
||||
}
|
||||
|
||||
type mockPromise struct {
|
||||
}
|
||||
|
||||
func (p mockPromise) Pass() {
|
||||
}
|
||||
|
||||
func (p mockPromise) Fail() {
|
||||
}
|
||||
18
rest/handler/timeouthandler.go
Normal file
18
rest/handler/timeouthandler.go
Normal file
@@ -0,0 +1,18 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"time"
|
||||
)
|
||||
|
||||
const reason = "Request Timeout"
|
||||
|
||||
func TimeoutHandler(duration time.Duration) func(http.Handler) http.Handler {
|
||||
return func(next http.Handler) http.Handler {
|
||||
if duration > 0 {
|
||||
return http.TimeoutHandler(next, duration, reason)
|
||||
} else {
|
||||
return next
|
||||
}
|
||||
}
|
||||
}
|
||||
52
rest/handler/timeouthandler_test.go
Normal file
52
rest/handler/timeouthandler_test.go
Normal file
@@ -0,0 +1,52 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"io/ioutil"
|
||||
"log"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func init() {
|
||||
log.SetOutput(ioutil.Discard)
|
||||
}
|
||||
|
||||
func TestTimeout(t *testing.T) {
|
||||
timeoutHandler := TimeoutHandler(time.Millisecond)
|
||||
handler := timeoutHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
time.Sleep(time.Minute)
|
||||
}))
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "http://localhost", nil)
|
||||
resp := httptest.NewRecorder()
|
||||
handler.ServeHTTP(resp, req)
|
||||
assert.Equal(t, http.StatusServiceUnavailable, resp.Code)
|
||||
}
|
||||
|
||||
func TestWithinTimeout(t *testing.T) {
|
||||
timeoutHandler := TimeoutHandler(time.Second)
|
||||
handler := timeoutHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
time.Sleep(time.Millisecond)
|
||||
}))
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "http://localhost", nil)
|
||||
resp := httptest.NewRecorder()
|
||||
handler.ServeHTTP(resp, req)
|
||||
assert.Equal(t, http.StatusOK, resp.Code)
|
||||
}
|
||||
|
||||
func TestWithoutTimeout(t *testing.T) {
|
||||
timeoutHandler := TimeoutHandler(0)
|
||||
handler := timeoutHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
}))
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "http://localhost", nil)
|
||||
resp := httptest.NewRecorder()
|
||||
handler.ServeHTTP(resp, req)
|
||||
assert.Equal(t, http.StatusOK, resp.Code)
|
||||
}
|
||||
25
rest/handler/tracinghandler.go
Normal file
25
rest/handler/tracinghandler.go
Normal file
@@ -0,0 +1,25 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"zero/core/logx"
|
||||
"zero/core/sysx"
|
||||
"zero/core/trace"
|
||||
)
|
||||
|
||||
func TracingHandler(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
carrier, err := trace.Extract(trace.HttpFormat, r.Header)
|
||||
// ErrInvalidCarrier means no trace id was set in http header
|
||||
if err != nil && err != trace.ErrInvalidCarrier {
|
||||
logx.Error(err)
|
||||
}
|
||||
|
||||
ctx, span := trace.StartServerSpan(r.Context(), carrier, sysx.Hostname(), r.RequestURI)
|
||||
defer span.Finish()
|
||||
r = r.WithContext(ctx)
|
||||
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
25
rest/handler/tracinghandler_test.go
Normal file
25
rest/handler/tracinghandler_test.go
Normal file
@@ -0,0 +1,25 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"zero/core/trace/tracespec"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestTracingHandler(t *testing.T) {
|
||||
req := httptest.NewRequest(http.MethodGet, "http://localhost", nil)
|
||||
req.Header.Set("X-Trace-ID", "theid")
|
||||
handler := TracingHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
span, ok := r.Context().Value(tracespec.TracingKey).(tracespec.Trace)
|
||||
assert.True(t, ok)
|
||||
assert.Equal(t, "theid", span.TraceId())
|
||||
}))
|
||||
|
||||
resp := httptest.NewRecorder()
|
||||
handler.ServeHTTP(resp, req)
|
||||
assert.Equal(t, http.StatusOK, resp.Code)
|
||||
}
|
||||
Reference in New Issue
Block a user