rename ngin to rest

This commit is contained in:
kevin
2020-07-31 11:14:48 +08:00
parent e133ffd820
commit 0897f60c5d
78 changed files with 118 additions and 111 deletions

141
rest/handler/authhandler.go Normal file
View File

@@ -0,0 +1,141 @@
package handler
import (
"context"
"errors"
"net/http"
"net/http/httputil"
"zero/core/logx"
"zero/rest/internal"
"github.com/dgrijalva/jwt-go"
)
const (
jwtAudience = "aud"
jwtExpire = "exp"
jwtId = "jti"
jwtIssueAt = "iat"
jwtIssuer = "iss"
jwtNotBefore = "nbf"
jwtSubject = "sub"
noDetailReason = "no detail reason"
)
var (
errInvalidToken = errors.New("invalid auth token")
errNoClaims = errors.New("no auth params")
)
type (
AuthorizeOptions struct {
PrevSecret string
Callback UnauthorizedCallback
}
UnauthorizedCallback func(w http.ResponseWriter, r *http.Request, err error)
AuthorizeOption func(opts *AuthorizeOptions)
)
func Authorize(secret string, opts ...AuthorizeOption) func(http.Handler) http.Handler {
var authOpts AuthorizeOptions
for _, opt := range opts {
opt(&authOpts)
}
parser := internal.NewTokenParser()
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
token, err := parser.ParseToken(r, secret, authOpts.PrevSecret)
if err != nil {
unauthorized(w, r, err, authOpts.Callback)
return
}
if !token.Valid {
unauthorized(w, r, errInvalidToken, authOpts.Callback)
return
}
claims, ok := token.Claims.(jwt.MapClaims)
if !ok {
unauthorized(w, r, errNoClaims, authOpts.Callback)
return
}
ctx := r.Context()
for k, v := range claims {
switch k {
case jwtAudience, jwtExpire, jwtId, jwtIssueAt, jwtIssuer, jwtNotBefore, jwtSubject:
// ignore the standard claims
default:
ctx = context.WithValue(ctx, k, v)
}
}
next.ServeHTTP(w, r.WithContext(ctx))
})
}
}
func WithPrevSecret(secret string) AuthorizeOption {
return func(opts *AuthorizeOptions) {
opts.PrevSecret = secret
}
}
func WithUnauthorizedCallback(callback UnauthorizedCallback) AuthorizeOption {
return func(opts *AuthorizeOptions) {
opts.Callback = callback
}
}
func detailAuthLog(r *http.Request, reason string) {
// discard dump error, only for debug purpose
details, _ := httputil.DumpRequest(r, true)
logx.Errorf("authorize failed: %s\n=> %+v", reason, string(details))
}
func unauthorized(w http.ResponseWriter, r *http.Request, err error, callback UnauthorizedCallback) {
writer := newGuardedResponseWriter(w)
if err != nil {
detailAuthLog(r, err.Error())
} else {
detailAuthLog(r, noDetailReason)
}
if callback != nil {
callback(writer, r, err)
}
writer.WriteHeader(http.StatusUnauthorized)
}
type guardedResponseWriter struct {
writer http.ResponseWriter
wroteHeader bool
}
func newGuardedResponseWriter(w http.ResponseWriter) *guardedResponseWriter {
return &guardedResponseWriter{
writer: w,
}
}
func (grw *guardedResponseWriter) Header() http.Header {
return grw.writer.Header()
}
func (grw *guardedResponseWriter) Write(body []byte) (int, error) {
return grw.writer.Write(body)
}
func (grw *guardedResponseWriter) WriteHeader(statusCode int) {
if grw.wroteHeader {
return
}
grw.wroteHeader = true
grw.writer.WriteHeader(statusCode)
}

View File

@@ -0,0 +1,99 @@
package handler
import (
"net/http"
"net/http/httptest"
"testing"
"time"
"github.com/dgrijalva/jwt-go"
"github.com/stretchr/testify/assert"
)
func TestAuthHandlerFailed(t *testing.T) {
req := httptest.NewRequest(http.MethodGet, "http://localhost", nil)
handler := Authorize("B63F477D-BBA3-4E52-96D3-C0034C27694A", WithUnauthorizedCallback(
func(w http.ResponseWriter, r *http.Request, err error) {
w.Header().Set("X-Test", "test")
w.WriteHeader(http.StatusUnauthorized)
_, err = w.Write([]byte("content"))
assert.Nil(t, err)
}))(
http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
}))
resp := httptest.NewRecorder()
handler.ServeHTTP(resp, req)
assert.Equal(t, http.StatusUnauthorized, resp.Code)
}
func TestAuthHandler(t *testing.T) {
const key = "B63F477D-BBA3-4E52-96D3-C0034C27694A"
req := httptest.NewRequest(http.MethodGet, "http://localhost", nil)
token, err := buildToken(key, map[string]interface{}{
"key": "value",
}, 3600)
assert.Nil(t, err)
req.Header.Set("Authorization", "Bearer "+token)
handler := Authorize(key)(
http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("X-Test", "test")
_, err := w.Write([]byte("content"))
assert.Nil(t, err)
}))
resp := httptest.NewRecorder()
handler.ServeHTTP(resp, req)
assert.Equal(t, http.StatusOK, resp.Code)
assert.Equal(t, "test", resp.Header().Get("X-Test"))
assert.Equal(t, "content", resp.Body.String())
}
func TestAuthHandlerWithPrevSecret(t *testing.T) {
const (
key = "14F17379-EB8F-411B-8F12-6929002DCA76"
prevKey = "B63F477D-BBA3-4E52-96D3-C0034C27694A"
)
req := httptest.NewRequest(http.MethodGet, "http://localhost", nil)
token, err := buildToken(key, map[string]interface{}{
"key": "value",
}, 3600)
assert.Nil(t, err)
req.Header.Set("Authorization", "Bearer "+token)
handler := Authorize(key, WithPrevSecret(prevKey))(
http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("X-Test", "test")
_, err := w.Write([]byte("content"))
assert.Nil(t, err)
}))
resp := httptest.NewRecorder()
handler.ServeHTTP(resp, req)
assert.Equal(t, http.StatusOK, resp.Code)
assert.Equal(t, "test", resp.Header().Get("X-Test"))
assert.Equal(t, "content", resp.Body.String())
}
func TestAuthHandler_NilError(t *testing.T) {
req := httptest.NewRequest(http.MethodGet, "http://localhost", nil)
resp := httptest.NewRecorder()
assert.NotPanics(t, func() {
unauthorized(resp, req, nil, nil)
})
}
func buildToken(secretKey string, payloads map[string]interface{}, seconds int64) (string, error) {
now := time.Now().Unix()
claims := make(jwt.MapClaims)
claims["exp"] = now + seconds
claims["iat"] = now
for k, v := range payloads {
claims[k] = v
}
token := jwt.New(jwt.SigningMethodHS256)
token.Claims = claims
return token.SignedString([]byte(secretKey))
}

View File

@@ -0,0 +1,41 @@
package handler
import (
"fmt"
"net/http"
"strings"
"zero/core/breaker"
"zero/core/logx"
"zero/core/stat"
"zero/rest/internal"
"zero/rest/internal/security"
)
const breakerSeparator = "://"
func BreakerHandler(method, path string, metrics *stat.Metrics) func(http.Handler) http.Handler {
brk := breaker.NewBreaker(breaker.WithName(strings.Join([]string{method, path}, breakerSeparator)))
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
promise, err := brk.Allow()
if err != nil {
metrics.AddDrop()
logx.Errorf("[http] dropped, %s - %s - %s",
r.RequestURI, internal.GetRemoteAddr(r), r.UserAgent())
w.WriteHeader(http.StatusServiceUnavailable)
return
}
cw := &security.WithCodeResponseWriter{Writer: w}
defer func() {
if cw.Code < http.StatusInternalServerError {
promise.Accept()
} else {
promise.Reject(fmt.Sprintf("%d %s", cw.Code, http.StatusText(cw.Code)))
}
}()
next.ServeHTTP(cw, r)
})
}
}

View File

@@ -0,0 +1,102 @@
package handler
import (
"fmt"
"net/http"
"net/http/httptest"
"testing"
"zero/core/logx"
"zero/core/stat"
"github.com/stretchr/testify/assert"
)
func init() {
logx.Disable()
stat.SetReporter(nil)
}
func TestBreakerHandlerAccept(t *testing.T) {
metrics := stat.NewMetrics("unit-test")
breakerHandler := BreakerHandler(http.MethodGet, "/", metrics)
handler := breakerHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("X-Test", "test")
_, err := w.Write([]byte("content"))
assert.Nil(t, err)
}))
req := httptest.NewRequest(http.MethodGet, "http://localhost", nil)
req.Header.Set("X-Test", "test")
resp := httptest.NewRecorder()
handler.ServeHTTP(resp, req)
assert.Equal(t, http.StatusOK, resp.Code)
assert.Equal(t, "test", resp.Header().Get("X-Test"))
assert.Equal(t, "content", resp.Body.String())
}
func TestBreakerHandlerFail(t *testing.T) {
metrics := stat.NewMetrics("unit-test")
breakerHandler := BreakerHandler(http.MethodGet, "/", metrics)
handler := breakerHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusBadGateway)
}))
req := httptest.NewRequest(http.MethodGet, "http://localhost", nil)
resp := httptest.NewRecorder()
handler.ServeHTTP(resp, req)
assert.Equal(t, http.StatusBadGateway, resp.Code)
}
func TestBreakerHandler_4XX(t *testing.T) {
metrics := stat.NewMetrics("unit-test")
breakerHandler := BreakerHandler(http.MethodGet, "/", metrics)
handler := breakerHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusBadRequest)
}))
for i := 0; i < 1000; i++ {
req := httptest.NewRequest(http.MethodGet, "http://localhost", nil)
resp := httptest.NewRecorder()
handler.ServeHTTP(resp, req)
}
const tries = 100
var pass int
for i := 0; i < tries; i++ {
req := httptest.NewRequest(http.MethodGet, "http://localhost", nil)
resp := httptest.NewRecorder()
handler.ServeHTTP(resp, req)
if resp.Code == http.StatusBadRequest {
pass++
}
}
assert.Equal(t, tries, pass)
}
func TestBreakerHandlerReject(t *testing.T) {
metrics := stat.NewMetrics("unit-test")
breakerHandler := BreakerHandler(http.MethodGet, "/", metrics)
handler := breakerHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusInternalServerError)
}))
for i := 0; i < 1000; i++ {
req := httptest.NewRequest(http.MethodGet, "http://localhost", nil)
resp := httptest.NewRecorder()
handler.ServeHTTP(resp, req)
}
var drops int
for i := 0; i < 100; i++ {
req := httptest.NewRequest(http.MethodGet, "http://localhost", nil)
resp := httptest.NewRecorder()
handler.ServeHTTP(resp, req)
if resp.Code == http.StatusServiceUnavailable {
drops++
}
}
assert.True(t, drops >= 80, fmt.Sprintf("expected to be greater than 80, but got %d", drops))
}

View File

@@ -0,0 +1,61 @@
package handler
import (
"net/http"
"time"
"zero/core/codec"
"zero/core/logx"
"zero/rest/httpx"
"zero/rest/internal/security"
)
const contentSecurity = "X-Content-Security"
type UnsignedCallback func(w http.ResponseWriter, r *http.Request, next http.Handler, strict bool, code int)
func ContentSecurityHandler(decrypters map[string]codec.RsaDecrypter, tolerance time.Duration,
strict bool, callbacks ...UnsignedCallback) func(http.Handler) http.Handler {
if len(callbacks) == 0 {
callbacks = append(callbacks, handleVerificationFailure)
}
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch r.Method {
case http.MethodDelete, http.MethodGet, http.MethodPost, http.MethodPut:
header, err := security.ParseContentSecurity(decrypters, r)
if err != nil {
logx.Infof("Signature parse failed, X-Content-Security: %s, error: %s",
r.Header.Get(contentSecurity), err.Error())
executeCallbacks(w, r, next, strict, httpx.CodeSignatureInvalidHeader, callbacks)
} else if code := security.VerifySignature(r, header, tolerance); code != httpx.CodeSignaturePass {
logx.Infof("Signature verification failed, X-Content-Security: %s",
r.Header.Get(contentSecurity))
executeCallbacks(w, r, next, strict, code, callbacks)
} else if r.ContentLength > 0 && header.Encrypted() {
CryptionHandler(header.Key)(next).ServeHTTP(w, r)
} else {
next.ServeHTTP(w, r)
}
default:
next.ServeHTTP(w, r)
}
})
}
}
func executeCallbacks(w http.ResponseWriter, r *http.Request, next http.Handler, strict bool,
code int, callbacks []UnsignedCallback) {
for _, callback := range callbacks {
callback(w, r, next, strict, code)
}
}
func handleVerificationFailure(w http.ResponseWriter, r *http.Request, next http.Handler, strict bool, code int) {
if strict {
w.WriteHeader(http.StatusUnauthorized)
} else {
next.ServeHTTP(w, r)
}
}

View File

@@ -0,0 +1,388 @@
package handler
import (
"bytes"
"crypto/sha256"
"encoding/base64"
"fmt"
"io"
"io/ioutil"
"log"
"net/http"
"net/http/httptest"
"net/url"
"os"
"strconv"
"strings"
"testing"
"time"
"zero/core/codec"
"zero/rest/httpx"
"github.com/stretchr/testify/assert"
)
const timeDiff = time.Hour * 2 * 24
var (
fingerprint = "12345"
pubKey = []byte(`-----BEGIN PUBLIC KEY-----
MIGfMA0GCSqGSIb3DQEBAQUAA4GNADCBiQKBgQD7bq4FLG0ctccbEFEsUBuRxkjE
eJ5U+0CAEjJk20V9/u2Fu76i1oKoShCs7GXtAFbDb5A/ImIXkPY62nAaxTGK4KVH
miYbRgh5Fy6336KepLCtCmV/r0PKZeCyJH9uYLs7EuE1z9Hgm5UUjmpHDhJtkAwR
my47YlhspwszKdRP+wIDAQAB
-----END PUBLIC KEY-----`)
priKey = []byte(`-----BEGIN RSA PRIVATE KEY-----
MIICXAIBAAKBgQD7bq4FLG0ctccbEFEsUBuRxkjEeJ5U+0CAEjJk20V9/u2Fu76i
1oKoShCs7GXtAFbDb5A/ImIXkPY62nAaxTGK4KVHmiYbRgh5Fy6336KepLCtCmV/
r0PKZeCyJH9uYLs7EuE1z9Hgm5UUjmpHDhJtkAwRmy47YlhspwszKdRP+wIDAQAB
AoGBANs1qf7UtuSbD1ZnKX5K8V5s07CHwPMygw+lzc3k5ndtNUStZQ2vnAaBXHyH
Nm4lJ4AI2mhQ39jQB/1TyP1uAzvpLhT60fRybEq9zgJ/81Gm9bnaEpFJ9bP2bBrY
J0jbaTMfbzL/PJFl3J3RGMR40C76h5yRYSnOpMoMiKWnJqrhAkEA/zCOkR+34Pk0
Yo3sIP4ranY6AAvwacgNaui4ll5xeYwv3iLOQvPlpxIxFHKXEY0klNNyjjXqgYjP
cOenqtt6UwJBAPw7EYuteVHvHvQVuTbKAaYHcOrp4nFeZF3ndFfl0w2dwGhfzcXO
ROyd5dNQCuCWRo8JBpjG6PFyzezayF4KLrkCQCGditoxHG7FRRJKcbVy5dMzWbaR
3AyDLslLeK1OKZKCVffkC9mj+TeF3PM9mQrV1eDI7ckv7wE7PWA5E8wc90MCQEOV
MCZU3OTvRUPxbicYCUkLRV4sPNhTimD+21WR5vMHCb7trJ0Ln7wmsqXkFIYIve8l
Y/cblN7c/AAyvu0znUECQA318nPldsxR6+H8HTS3uEbkL4UJdjQJHsvTwKxAw5qc
moKExvRlN0zmGGuArKcqS38KG7PXZMrUv3FXPdp6BDQ=
-----END RSA PRIVATE KEY-----`)
key = []byte("q4t7w!z%C*F-JaNdRgUjXn2r5u8x/A?D")
)
type requestSettings struct {
method string
url string
body io.Reader
strict bool
crypt bool
requestUri string
timestamp int64
fingerprint string
missHeader bool
signature string
}
func init() {
log.SetOutput(ioutil.Discard)
}
func TestContentSecurityHandler(t *testing.T) {
tests := []struct {
method string
url string
body string
strict bool
crypt bool
requestUri string
timestamp int64
fingerprint string
missHeader bool
signature string
statusCode int
}{
{
method: http.MethodGet,
url: "http://localhost/a/b?c=d&e=f",
strict: true,
crypt: false,
},
{
method: http.MethodPost,
url: "http://localhost/a/b?c=d&e=f",
body: "hello",
strict: true,
crypt: false,
},
{
method: http.MethodGet,
url: "http://localhost/a/b?c=d&e=f",
strict: true,
crypt: true,
},
{
method: http.MethodPost,
url: "http://localhost/a/b?c=d&e=f",
body: "hello",
strict: true,
crypt: true,
},
{
method: http.MethodGet,
url: "http://localhost/a/b?c=d&e=f",
strict: true,
crypt: true,
timestamp: time.Now().Add(timeDiff).Unix(),
statusCode: http.StatusUnauthorized,
},
{
method: http.MethodPost,
url: "http://localhost/a/b?c=d&e=f",
body: "hello",
strict: true,
crypt: true,
timestamp: time.Now().Add(-timeDiff).Unix(),
statusCode: http.StatusUnauthorized,
},
{
method: http.MethodPost,
url: "http://remotehost/",
body: "hello",
strict: true,
crypt: true,
requestUri: "http://localhost/a/b?c=d&e=f",
},
{
method: http.MethodPost,
url: "http://localhost/a/b?c=d&e=f",
body: "hello",
strict: false,
crypt: true,
fingerprint: "badone",
},
{
method: http.MethodPost,
url: "http://localhost/a/b?c=d&e=f",
body: "hello",
strict: true,
crypt: true,
timestamp: time.Now().Add(-timeDiff).Unix(),
fingerprint: "badone",
statusCode: http.StatusUnauthorized,
},
{
method: http.MethodPost,
url: "http://localhost/a/b?c=d&e=f",
body: "hello",
strict: true,
crypt: true,
missHeader: true,
statusCode: http.StatusUnauthorized,
},
{
method: http.MethodHead,
url: "http://localhost/a/b?c=d&e=f",
strict: true,
crypt: false,
},
{
method: http.MethodGet,
url: "http://localhost/a/b?c=d&e=f",
strict: true,
crypt: false,
signature: "badone",
statusCode: http.StatusUnauthorized,
},
}
for _, test := range tests {
t.Run(test.url, func(t *testing.T) {
if test.statusCode == 0 {
test.statusCode = http.StatusOK
}
if len(test.fingerprint) == 0 {
test.fingerprint = fingerprint
}
if test.timestamp == 0 {
test.timestamp = time.Now().Unix()
}
func() {
keyFile, err := createTempFile(priKey)
defer os.Remove(keyFile)
assert.Nil(t, err)
decrypter, err := codec.NewRsaDecrypter(keyFile)
assert.Nil(t, err)
contentSecurityHandler := ContentSecurityHandler(map[string]codec.RsaDecrypter{
fingerprint: decrypter,
}, time.Hour, test.strict)
handler := contentSecurityHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
}))
var reader io.Reader
if len(test.body) > 0 {
reader = strings.NewReader(test.body)
}
setting := requestSettings{
method: test.method,
url: test.url,
body: reader,
strict: test.strict,
crypt: test.crypt,
requestUri: test.requestUri,
timestamp: test.timestamp,
fingerprint: test.fingerprint,
missHeader: test.missHeader,
signature: test.signature,
}
req, err := buildRequest(setting)
resp := httptest.NewRecorder()
handler.ServeHTTP(resp, req)
assert.Equal(t, test.statusCode, resp.Code)
}()
})
}
}
func TestContentSecurityHandler_UnsignedCallback(t *testing.T) {
keyFile, err := createTempFile(priKey)
defer os.Remove(keyFile)
assert.Nil(t, err)
decrypter, err := codec.NewRsaDecrypter(keyFile)
assert.Nil(t, err)
contentSecurityHandler := ContentSecurityHandler(
map[string]codec.RsaDecrypter{
fingerprint: decrypter,
},
time.Hour,
true,
func(w http.ResponseWriter, r *http.Request, next http.Handler, strict bool, code int) {
w.WriteHeader(http.StatusOK)
})
handler := contentSecurityHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}))
setting := requestSettings{
method: http.MethodGet,
url: "http://localhost/a/b?c=d&e=f",
signature: "badone",
}
req, err := buildRequest(setting)
resp := httptest.NewRecorder()
handler.ServeHTTP(resp, req)
assert.Equal(t, http.StatusOK, resp.Code)
}
func TestContentSecurityHandler_UnsignedCallback_WrongTime(t *testing.T) {
keyFile, err := createTempFile(priKey)
defer os.Remove(keyFile)
assert.Nil(t, err)
decrypter, err := codec.NewRsaDecrypter(keyFile)
assert.Nil(t, err)
contentSecurityHandler := ContentSecurityHandler(
map[string]codec.RsaDecrypter{
fingerprint: decrypter,
},
time.Hour,
true,
func(w http.ResponseWriter, r *http.Request, next http.Handler, strict bool, code int) {
assert.Equal(t, httpx.CodeSignatureWrongTime, code)
w.WriteHeader(http.StatusOK)
})
handler := contentSecurityHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}))
var reader io.Reader
reader = strings.NewReader("hello")
setting := requestSettings{
method: http.MethodPost,
url: "http://localhost/a/b?c=d&e=f",
body: reader,
strict: true,
crypt: true,
timestamp: time.Now().Add(time.Hour * 24 * 365).Unix(),
fingerprint: fingerprint,
}
req, err := buildRequest(setting)
resp := httptest.NewRecorder()
handler.ServeHTTP(resp, req)
assert.Equal(t, http.StatusOK, resp.Code)
}
func buildRequest(rs requestSettings) (*http.Request, error) {
var bodyStr string
var err error
if rs.crypt && rs.body != nil {
var buf bytes.Buffer
io.Copy(&buf, rs.body)
bodyBytes, err := codec.EcbEncrypt(key, buf.Bytes())
if err != nil {
return nil, err
}
bodyStr = base64.StdEncoding.EncodeToString(bodyBytes)
}
r := httptest.NewRequest(rs.method, rs.url, strings.NewReader(bodyStr))
if len(rs.signature) == 0 {
sha := sha256.New()
sha.Write([]byte(bodyStr))
bodySign := fmt.Sprintf("%x", sha.Sum(nil))
var path string
var query string
if len(rs.requestUri) > 0 {
if u, err := url.Parse(rs.requestUri); err != nil {
return nil, err
} else {
path = u.Path
query = u.RawQuery
}
} else {
path = r.URL.Path
query = r.URL.RawQuery
}
contentOfSign := strings.Join([]string{
strconv.FormatInt(rs.timestamp, 10),
rs.method,
path,
query,
bodySign,
}, "\n")
rs.signature = codec.HmacBase64([]byte(key), contentOfSign)
}
var mode string
if rs.crypt {
mode = "1"
} else {
mode = "0"
}
content := strings.Join([]string{
"version=v1",
"type=" + mode,
fmt.Sprintf("key=%s", base64.StdEncoding.EncodeToString(key)),
"time=" + strconv.FormatInt(rs.timestamp, 10),
}, "; ")
encrypter, err := codec.NewRsaEncrypter([]byte(pubKey))
if err != nil {
log.Fatal(err)
}
output, err := encrypter.Encrypt([]byte(content))
if err != nil {
log.Fatal(err)
}
encryptedContent := base64.StdEncoding.EncodeToString(output)
if !rs.missHeader {
r.Header.Set(httpx.ContentSecurity, strings.Join([]string{
fmt.Sprintf("key=%s", rs.fingerprint),
"secret=" + encryptedContent,
"signature=" + rs.signature,
}, "; "))
}
if len(rs.requestUri) > 0 {
r.Header.Set("X-Request-Uri", rs.requestUri)
}
return r, nil
}
func createTempFile(body []byte) (string, error) {
tmpFile, err := ioutil.TempFile(os.TempDir(), "go-unit-*.tmp")
if err != nil {
return "", err
} else {
tmpFile.Close()
}
err = ioutil.WriteFile(tmpFile.Name(), body, os.ModePerm)
if err != nil {
return "", err
}
return tmpFile.Name(), nil
}

View File

@@ -0,0 +1,101 @@
package handler
import (
"bytes"
"encoding/base64"
"io"
"io/ioutil"
"net/http"
"zero/core/codec"
"zero/core/logx"
)
const maxBytes = 1 << 20 // 1 MiB
func CryptionHandler(key []byte) func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
cw := newCryptionResponseWriter(w)
defer cw.flush(key)
if r.ContentLength <= 0 {
next.ServeHTTP(cw, r)
return
}
if err := decryptBody(key, r); err != nil {
w.WriteHeader(http.StatusBadRequest)
return
}
next.ServeHTTP(cw, r)
})
}
}
func decryptBody(key []byte, r *http.Request) error {
content, err := ioutil.ReadAll(io.LimitReader(r.Body, maxBytes))
if err != nil {
return err
}
content, err = base64.StdEncoding.DecodeString(string(content))
if err != nil {
return err
}
output, err := codec.EcbDecrypt(key, content)
if err != nil {
return err
}
var buf bytes.Buffer
buf.Write(output)
r.Body = ioutil.NopCloser(&buf)
return nil
}
type cryptionResponseWriter struct {
http.ResponseWriter
buf *bytes.Buffer
}
func newCryptionResponseWriter(w http.ResponseWriter) *cryptionResponseWriter {
return &cryptionResponseWriter{
ResponseWriter: w,
buf: new(bytes.Buffer),
}
}
func (w *cryptionResponseWriter) Header() http.Header {
return w.ResponseWriter.Header()
}
func (w *cryptionResponseWriter) Write(p []byte) (int, error) {
return w.buf.Write(p)
}
func (w *cryptionResponseWriter) WriteHeader(statusCode int) {
w.ResponseWriter.WriteHeader(statusCode)
}
func (w *cryptionResponseWriter) flush(key []byte) {
if w.buf.Len() == 0 {
return
}
content, err := codec.EcbEncrypt(key, w.buf.Bytes())
if err != nil {
w.WriteHeader(http.StatusInternalServerError)
return
}
body := base64.StdEncoding.EncodeToString(content)
if n, err := io.WriteString(w.ResponseWriter, body); err != nil {
logx.Errorf("write response failed, error: %s", err)
} else if n < len(content) {
logx.Errorf("actual bytes: %d, written bytes: %d", len(content), n)
}
}

View File

@@ -0,0 +1,90 @@
package handler
import (
"bytes"
"encoding/base64"
"io/ioutil"
"log"
"net/http"
"net/http/httptest"
"testing"
"zero/core/codec"
"github.com/stretchr/testify/assert"
)
const (
reqText = "ping"
respText = "pong"
)
var aesKey = []byte(`PdSgVkYp3s6v9y$B&E)H+MbQeThWmZq4`)
func init() {
log.SetOutput(ioutil.Discard)
}
func TestCryptionHandlerGet(t *testing.T) {
req := httptest.NewRequest(http.MethodGet, "/any", nil)
handler := CryptionHandler(aesKey)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
_, err := w.Write([]byte(respText))
w.Header().Set("X-Test", "test")
assert.Nil(t, err)
}))
recorder := httptest.NewRecorder()
handler.ServeHTTP(recorder, req)
expect, err := codec.EcbEncrypt(aesKey, []byte(respText))
assert.Nil(t, err)
assert.Equal(t, http.StatusOK, recorder.Code)
assert.Equal(t, "test", recorder.Header().Get("X-Test"))
assert.Equal(t, base64.StdEncoding.EncodeToString(expect), recorder.Body.String())
}
func TestCryptionHandlerPost(t *testing.T) {
var buf bytes.Buffer
enc, err := codec.EcbEncrypt(aesKey, []byte(reqText))
assert.Nil(t, err)
buf.WriteString(base64.StdEncoding.EncodeToString(enc))
req := httptest.NewRequest(http.MethodPost, "/any", &buf)
handler := CryptionHandler(aesKey)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
body, err := ioutil.ReadAll(r.Body)
assert.Nil(t, err)
assert.Equal(t, reqText, string(body))
w.Write([]byte(respText))
}))
recorder := httptest.NewRecorder()
handler.ServeHTTP(recorder, req)
expect, err := codec.EcbEncrypt(aesKey, []byte(respText))
assert.Nil(t, err)
assert.Equal(t, http.StatusOK, recorder.Code)
assert.Equal(t, base64.StdEncoding.EncodeToString(expect), recorder.Body.String())
}
func TestCryptionHandlerPostBadEncryption(t *testing.T) {
var buf bytes.Buffer
enc, err := codec.EcbEncrypt(aesKey, []byte(reqText))
assert.Nil(t, err)
buf.Write(enc)
req := httptest.NewRequest(http.MethodPost, "/any", &buf)
handler := CryptionHandler(aesKey)(nil)
recorder := httptest.NewRecorder()
handler.ServeHTTP(recorder, req)
assert.Equal(t, http.StatusBadRequest, recorder.Code)
}
func TestCryptionHandlerWriteHeader(t *testing.T) {
req := httptest.NewRequest(http.MethodGet, "/any", nil)
handler := CryptionHandler(aesKey)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusServiceUnavailable)
}))
recorder := httptest.NewRecorder()
handler.ServeHTTP(recorder, req)
assert.Equal(t, http.StatusServiceUnavailable, recorder.Code)
}

View File

@@ -0,0 +1,27 @@
package handler
import (
"compress/gzip"
"net/http"
"strings"
"zero/rest/httpx"
)
const gzipEncoding = "gzip"
func GunzipHandler(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if strings.Contains(r.Header.Get(httpx.ContentEncoding), gzipEncoding) {
reader, err := gzip.NewReader(r.Body)
if err != nil {
w.WriteHeader(http.StatusBadRequest)
return
}
r.Body = reader
}
next.ServeHTTP(w, r)
})
}

View File

@@ -0,0 +1,66 @@
package handler
import (
"bytes"
"io/ioutil"
"net/http"
"net/http/httptest"
"strings"
"sync"
"testing"
"zero/core/codec"
"zero/rest/httpx"
"github.com/stretchr/testify/assert"
)
func TestGunzipHandler(t *testing.T) {
const message = "hello world"
var wg sync.WaitGroup
wg.Add(1)
handler := GunzipHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
body, err := ioutil.ReadAll(r.Body)
assert.Nil(t, err)
assert.Equal(t, string(body), message)
wg.Done()
}))
req := httptest.NewRequest(http.MethodPost, "http://localhost",
bytes.NewReader(codec.Gzip([]byte(message))))
req.Header.Set(httpx.ContentEncoding, gzipEncoding)
resp := httptest.NewRecorder()
handler.ServeHTTP(resp, req)
assert.Equal(t, http.StatusOK, resp.Code)
wg.Wait()
}
func TestGunzipHandler_NoGzip(t *testing.T) {
const message = "hello world"
var wg sync.WaitGroup
wg.Add(1)
handler := GunzipHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
body, err := ioutil.ReadAll(r.Body)
assert.Nil(t, err)
assert.Equal(t, string(body), message)
wg.Done()
}))
req := httptest.NewRequest(http.MethodPost, "http://localhost",
strings.NewReader(message))
resp := httptest.NewRecorder()
handler.ServeHTTP(resp, req)
assert.Equal(t, http.StatusOK, resp.Code)
wg.Wait()
}
func TestGunzipHandler_NoGzipButTelling(t *testing.T) {
const message = "hello world"
handler := GunzipHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}))
req := httptest.NewRequest(http.MethodPost, "http://localhost",
strings.NewReader(message))
req.Header.Set(httpx.ContentEncoding, gzipEncoding)
resp := httptest.NewRecorder()
handler.ServeHTTP(resp, req)
assert.Equal(t, http.StatusBadRequest, resp.Code)
}

165
rest/handler/loghandler.go Normal file
View File

@@ -0,0 +1,165 @@
package handler
import (
"bytes"
"context"
"fmt"
"io"
"net/http"
"net/http/httputil"
"time"
"zero/core/iox"
"zero/core/logx"
"zero/core/timex"
"zero/core/utils"
"zero/rest/internal"
)
const slowThreshold = time.Millisecond * 500
type LoggedResponseWriter struct {
w http.ResponseWriter
r *http.Request
code int
}
func (w *LoggedResponseWriter) Header() http.Header {
return w.w.Header()
}
func (w *LoggedResponseWriter) Write(bytes []byte) (int, error) {
return w.w.Write(bytes)
}
func (w *LoggedResponseWriter) WriteHeader(code int) {
w.w.WriteHeader(code)
w.code = code
}
func LogHandler(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
timer := utils.NewElapsedTimer()
logs := new(internal.LogCollector)
lrw := LoggedResponseWriter{
w: w,
r: r,
code: http.StatusOK,
}
var dup io.ReadCloser
r.Body, dup = iox.DupReadCloser(r.Body)
next.ServeHTTP(&lrw, r.WithContext(context.WithValue(r.Context(), internal.LogContext, logs)))
r.Body = dup
logBrief(r, lrw.code, timer, logs)
})
}
type DetailLoggedResponseWriter struct {
writer *LoggedResponseWriter
buf *bytes.Buffer
}
func newDetailLoggedResponseWriter(writer *LoggedResponseWriter, buf *bytes.Buffer) *DetailLoggedResponseWriter {
return &DetailLoggedResponseWriter{
writer: writer,
buf: buf,
}
}
func (w *DetailLoggedResponseWriter) Header() http.Header {
return w.writer.Header()
}
func (w *DetailLoggedResponseWriter) Write(bs []byte) (int, error) {
w.buf.Write(bs)
return w.writer.Write(bs)
}
func (w *DetailLoggedResponseWriter) WriteHeader(code int) {
w.writer.WriteHeader(code)
}
func DetailedLogHandler(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
timer := utils.NewElapsedTimer()
var buf bytes.Buffer
lrw := newDetailLoggedResponseWriter(&LoggedResponseWriter{
w: w,
r: r,
code: http.StatusOK,
}, &buf)
var dup io.ReadCloser
r.Body, dup = iox.DupReadCloser(r.Body)
logs := new(internal.LogCollector)
next.ServeHTTP(lrw, r.WithContext(context.WithValue(r.Context(), internal.LogContext, logs)))
r.Body = dup
logDetails(r, lrw, timer, logs)
})
}
func dumpRequest(r *http.Request) string {
reqContent, err := httputil.DumpRequest(r, true)
if err != nil {
return err.Error()
} else {
return string(reqContent)
}
}
func logBrief(r *http.Request, code int, timer *utils.ElapsedTimer, logs *internal.LogCollector) {
var buf bytes.Buffer
duration := timer.Duration()
buf.WriteString(fmt.Sprintf("%d - %s - %s - %s - %s",
code, r.RequestURI, internal.GetRemoteAddr(r), r.UserAgent(), timex.ReprOfDuration(duration)))
if duration > slowThreshold {
logx.Slowf("[HTTP] %d - %s - %s - %s - slowcall(%s)",
code, r.RequestURI, internal.GetRemoteAddr(r), r.UserAgent(), timex.ReprOfDuration(duration))
}
ok := isOkResponse(code)
if !ok {
buf.WriteString(fmt.Sprintf("\n%s", dumpRequest(r)))
}
body := logs.Flush()
if len(body) > 0 {
buf.WriteString(fmt.Sprintf("\n%s", body))
}
if ok {
logx.Info(buf.String())
} else {
logx.Error(buf.String())
}
}
func logDetails(r *http.Request, response *DetailLoggedResponseWriter, timer *utils.ElapsedTimer,
logs *internal.LogCollector) {
var buf bytes.Buffer
duration := timer.Duration()
buf.WriteString(fmt.Sprintf("%d - %s - %s\n=> %s\n",
response.writer.code, r.RemoteAddr, timex.ReprOfDuration(duration), dumpRequest(r)))
if duration > slowThreshold {
logx.Slowf("[HTTP] %d - %s - slowcall(%s)\n=> %s\n",
response.writer.code, r.RemoteAddr, timex.ReprOfDuration(duration), dumpRequest(r))
}
body := logs.Flush()
if len(body) > 0 {
buf.WriteString(fmt.Sprintf("%s\n", body))
}
respBuf := response.buf.Bytes()
if len(respBuf) > 0 {
buf.WriteString(fmt.Sprintf("<= %s", respBuf))
}
logx.Info(buf.String())
}
func isOkResponse(code int) bool {
// not server error
return code < http.StatusInternalServerError
}

View File

@@ -0,0 +1,74 @@
package handler
import (
"io/ioutil"
"log"
"net/http"
"net/http/httptest"
"testing"
"time"
"zero/rest/internal"
"github.com/stretchr/testify/assert"
)
func init() {
log.SetOutput(ioutil.Discard)
}
func TestLogHandler(t *testing.T) {
handlers := []func(handler http.Handler) http.Handler{
LogHandler,
DetailedLogHandler,
}
for _, logHandler := range handlers {
req := httptest.NewRequest(http.MethodGet, "http://localhost", nil)
handler := logHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
r.Context().Value(internal.LogContext).(*internal.LogCollector).Append("anything")
w.Header().Set("X-Test", "test")
w.WriteHeader(http.StatusServiceUnavailable)
_, err := w.Write([]byte("content"))
assert.Nil(t, err)
}))
resp := httptest.NewRecorder()
handler.ServeHTTP(resp, req)
assert.Equal(t, http.StatusServiceUnavailable, resp.Code)
assert.Equal(t, "test", resp.Header().Get("X-Test"))
assert.Equal(t, "content", resp.Body.String())
}
}
func TestLogHandlerSlow(t *testing.T) {
handlers := []func(handler http.Handler) http.Handler{
LogHandler,
DetailedLogHandler,
}
for _, logHandler := range handlers {
req := httptest.NewRequest(http.MethodGet, "http://localhost", nil)
handler := logHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
time.Sleep(slowThreshold + time.Millisecond*50)
}))
resp := httptest.NewRecorder()
handler.ServeHTTP(resp, req)
assert.Equal(t, http.StatusOK, resp.Code)
}
}
func BenchmarkLogHandler(b *testing.B) {
b.ReportAllocs()
req := httptest.NewRequest(http.MethodGet, "http://localhost", nil)
handler := LogHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
}))
for i := 0; i < b.N; i++ {
resp := httptest.NewRecorder()
handler.ServeHTTP(resp, req)
}
}

View File

@@ -0,0 +1,27 @@
package handler
import (
"net/http"
"zero/rest/internal"
)
func MaxBytesHandler(n int64) func(http.Handler) http.Handler {
if n <= 0 {
return func(next http.Handler) http.Handler {
return next
}
}
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.ContentLength > n {
internal.Errorf(r, "request entity too large, limit is %d, but got %d, rejected with code %d",
n, r.ContentLength, http.StatusRequestEntityTooLarge)
w.WriteHeader(http.StatusRequestEntityTooLarge)
} else {
next.ServeHTTP(w, r)
}
})
}
}

View File

@@ -0,0 +1,37 @@
package handler
import (
"bytes"
"net/http"
"net/http/httptest"
"testing"
"github.com/stretchr/testify/assert"
)
func TestMaxBytesHandler(t *testing.T) {
maxb := MaxBytesHandler(10)
handler := maxb(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}))
req := httptest.NewRequest(http.MethodPost, "http://localhost",
bytes.NewBufferString("123456789012345"))
resp := httptest.NewRecorder()
handler.ServeHTTP(resp, req)
assert.Equal(t, http.StatusRequestEntityTooLarge, resp.Code)
req = httptest.NewRequest(http.MethodPost, "http://localhost", bytes.NewBufferString("12345"))
resp = httptest.NewRecorder()
handler.ServeHTTP(resp, req)
assert.Equal(t, http.StatusOK, resp.Code)
}
func TestMaxBytesHandlerNoLimit(t *testing.T) {
maxb := MaxBytesHandler(-1)
handler := maxb(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}))
req := httptest.NewRequest(http.MethodPost, "http://localhost",
bytes.NewBufferString("123456789012345"))
resp := httptest.NewRecorder()
handler.ServeHTTP(resp, req)
assert.Equal(t, http.StatusOK, resp.Code)
}

View File

@@ -0,0 +1,37 @@
package handler
import (
"net/http"
"zero/core/logx"
"zero/core/syncx"
"zero/rest/internal"
)
func MaxConns(n int) func(http.Handler) http.Handler {
if n <= 0 {
return func(next http.Handler) http.Handler {
return next
}
}
return func(next http.Handler) http.Handler {
latchLimiter := syncx.NewLimit(n)
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if latchLimiter.TryBorrow() {
defer func() {
if err := latchLimiter.Return(); err != nil {
logx.Error(err)
}
}()
next.ServeHTTP(w, r)
} else {
internal.Errorf(r, "Concurrent connections over %d, rejected with code %d",
n, http.StatusServiceUnavailable)
w.WriteHeader(http.StatusServiceUnavailable)
}
})
}
}

View File

@@ -0,0 +1,80 @@
package handler
import (
"io/ioutil"
"log"
"net/http"
"net/http/httptest"
"sync"
"testing"
"zero/core/lang"
"github.com/stretchr/testify/assert"
)
const conns = 4
func init() {
log.SetOutput(ioutil.Discard)
}
func TestMaxConnsHandler(t *testing.T) {
var waitGroup sync.WaitGroup
waitGroup.Add(conns)
done := make(chan lang.PlaceholderType)
defer close(done)
maxConns := MaxConns(conns)
handler := maxConns(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
waitGroup.Done()
<-done
}))
for i := 0; i < conns; i++ {
go func() {
req := httptest.NewRequest(http.MethodGet, "http://localhost", nil)
handler.ServeHTTP(httptest.NewRecorder(), req)
}()
}
waitGroup.Wait()
req := httptest.NewRequest(http.MethodGet, "http://localhost", nil)
resp := httptest.NewRecorder()
handler.ServeHTTP(resp, req)
assert.Equal(t, http.StatusServiceUnavailable, resp.Code)
}
func TestWithoutMaxConnsHandler(t *testing.T) {
const (
key = "block"
value = "1"
)
var waitGroup sync.WaitGroup
waitGroup.Add(conns)
done := make(chan lang.PlaceholderType)
defer close(done)
maxConns := MaxConns(0)
handler := maxConns(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
val := r.Header.Get(key)
if val == value {
waitGroup.Done()
<-done
}
}))
for i := 0; i < conns; i++ {
go func() {
req := httptest.NewRequest(http.MethodGet, "http://localhost", nil)
req.Header.Set(key, value)
handler.ServeHTTP(httptest.NewRecorder(), req)
}()
}
waitGroup.Wait()
req := httptest.NewRequest(http.MethodGet, "http://localhost", nil)
resp := httptest.NewRecorder()
handler.ServeHTTP(resp, req)
assert.Equal(t, http.StatusOK, resp.Code)
}

View File

@@ -0,0 +1,23 @@
package handler
import (
"net/http"
"zero/core/stat"
"zero/core/timex"
)
func MetricHandler(metrics *stat.Metrics) func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
startTime := timex.Now()
defer func() {
metrics.Add(stat.Task{
Duration: timex.Since(startTime),
})
}()
next.ServeHTTP(w, r)
})
}
}

View File

@@ -0,0 +1,24 @@
package handler
import (
"net/http"
"net/http/httptest"
"testing"
"zero/core/stat"
"github.com/stretchr/testify/assert"
)
func TestMetricHandler(t *testing.T) {
metrics := stat.NewMetrics("unit-test")
metricHandler := MetricHandler(metrics)
handler := metricHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
}))
req := httptest.NewRequest(http.MethodGet, "http://localhost", nil)
resp := httptest.NewRecorder()
handler.ServeHTTP(resp, req)
assert.Equal(t, http.StatusOK, resp.Code)
}

View File

@@ -0,0 +1,47 @@
package handler
import (
"net/http"
"strconv"
"time"
"zero/core/metric"
"zero/core/timex"
"zero/rest/internal/security"
)
const serverNamespace = "http_server"
var (
metricServerReqDur = metric.NewHistogramVec(&metric.HistogramVecOpts{
Namespace: serverNamespace,
Subsystem: "requests",
Name: "duration_ms",
Help: "http server requests duration(ms).",
Labels: []string{"path"},
Buckets: []float64{5, 10, 25, 50, 100, 250, 500, 1000},
})
metricServerReqCodeTotal = metric.NewCounterVec(&metric.CounterVecOpts{
Namespace: serverNamespace,
Subsystem: "requests",
Name: "code_total",
Help: "http server requests error count.",
Labels: []string{"path", "code"},
})
)
func PromMetricHandler(path string) func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
startTime := timex.Now()
cw := &security.WithCodeResponseWriter{Writer: w}
defer func() {
metricServerReqDur.Observe(int64(timex.Since(startTime)/time.Millisecond), path)
metricServerReqCodeTotal.Inc(path, strconv.Itoa(cw.Code))
}()
next.ServeHTTP(cw, r)
})
}
}

View File

@@ -0,0 +1,21 @@
package handler
import (
"net/http"
"net/http/httptest"
"testing"
"github.com/stretchr/testify/assert"
)
func TestPromMetricHandler(t *testing.T) {
promMetricHandler := PromMetricHandler("/user/login")
handler := promMetricHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
}))
req := httptest.NewRequest(http.MethodGet, "http://localhost", nil)
resp := httptest.NewRecorder()
handler.ServeHTTP(resp, req)
assert.Equal(t, http.StatusOK, resp.Code)
}

View File

@@ -0,0 +1,22 @@
package handler
import (
"fmt"
"net/http"
"runtime/debug"
"zero/rest/internal"
)
func RecoverHandler(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
defer func() {
if result := recover(); result != nil {
internal.Error(r, fmt.Sprintf("%v\n%s", result, debug.Stack()))
w.WriteHeader(http.StatusInternalServerError)
}
}()
next.ServeHTTP(w, r)
})
}

View File

@@ -0,0 +1,36 @@
package handler
import (
"io/ioutil"
"log"
"net/http"
"net/http/httptest"
"testing"
"github.com/stretchr/testify/assert"
)
func init() {
log.SetOutput(ioutil.Discard)
}
func TestWithPanic(t *testing.T) {
handler := RecoverHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
panic("whatever")
}))
req := httptest.NewRequest(http.MethodGet, "http://localhost", nil)
resp := httptest.NewRecorder()
handler.ServeHTTP(resp, req)
assert.Equal(t, http.StatusInternalServerError, resp.Code)
}
func TestWithoutPanic(t *testing.T) {
handler := RecoverHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
}))
req := httptest.NewRequest(http.MethodGet, "http://localhost", nil)
resp := httptest.NewRecorder()
handler.ServeHTTP(resp, req)
assert.Equal(t, http.StatusOK, resp.Code)
}

View File

@@ -0,0 +1,63 @@
package handler
import (
"net/http"
"sync"
"zero/core/load"
"zero/core/logx"
"zero/core/stat"
"zero/rest/internal"
"zero/rest/internal/security"
)
const serviceType = "api"
var (
sheddingStat *load.SheddingStat
lock sync.Mutex
)
func SheddingHandler(shedder load.Shedder, metrics *stat.Metrics) func(http.Handler) http.Handler {
if shedder == nil {
return func(next http.Handler) http.Handler {
return next
}
}
ensureSheddingStat()
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
sheddingStat.IncrementTotal()
promise, err := shedder.Allow()
if err != nil {
metrics.AddDrop()
sheddingStat.IncrementDrop()
logx.Errorf("[http] dropped, %s - %s - %s",
r.RequestURI, internal.GetRemoteAddr(r), r.UserAgent())
w.WriteHeader(http.StatusServiceUnavailable)
return
}
cw := &security.WithCodeResponseWriter{Writer: w}
defer func() {
if cw.Code == http.StatusServiceUnavailable {
promise.Fail()
} else {
sheddingStat.IncrementPass()
promise.Pass()
}
}()
next.ServeHTTP(cw, r)
})
}
}
func ensureSheddingStat() {
lock.Lock()
if sheddingStat == nil {
sheddingStat = load.NewSheddingStat(serviceType)
}
lock.Unlock()
}

View File

@@ -0,0 +1,105 @@
package handler
import (
"io/ioutil"
"log"
"net/http"
"net/http/httptest"
"testing"
"zero/core/load"
"zero/core/stat"
"github.com/stretchr/testify/assert"
)
func init() {
log.SetOutput(ioutil.Discard)
}
func TestSheddingHandlerAccept(t *testing.T) {
metrics := stat.NewMetrics("unit-test")
shedder := mockShedder{
allow: true,
}
sheddingHandler := SheddingHandler(shedder, metrics)
handler := sheddingHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("X-Test", "test")
_, err := w.Write([]byte("content"))
assert.Nil(t, err)
}))
req := httptest.NewRequest(http.MethodGet, "http://localhost", nil)
req.Header.Set("X-Test", "test")
resp := httptest.NewRecorder()
handler.ServeHTTP(resp, req)
assert.Equal(t, http.StatusOK, resp.Code)
assert.Equal(t, "test", resp.Header().Get("X-Test"))
assert.Equal(t, "content", resp.Body.String())
}
func TestSheddingHandlerFail(t *testing.T) {
metrics := stat.NewMetrics("unit-test")
shedder := mockShedder{
allow: true,
}
sheddingHandler := SheddingHandler(shedder, metrics)
handler := sheddingHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusServiceUnavailable)
}))
req := httptest.NewRequest(http.MethodGet, "http://localhost", nil)
resp := httptest.NewRecorder()
handler.ServeHTTP(resp, req)
assert.Equal(t, http.StatusServiceUnavailable, resp.Code)
}
func TestSheddingHandlerReject(t *testing.T) {
metrics := stat.NewMetrics("unit-test")
shedder := mockShedder{
allow: false,
}
sheddingHandler := SheddingHandler(shedder, metrics)
handler := sheddingHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
}))
req := httptest.NewRequest(http.MethodGet, "http://localhost", nil)
resp := httptest.NewRecorder()
handler.ServeHTTP(resp, req)
assert.Equal(t, http.StatusServiceUnavailable, resp.Code)
}
func TestSheddingHandlerNoShedding(t *testing.T) {
metrics := stat.NewMetrics("unit-test")
sheddingHandler := SheddingHandler(nil, metrics)
handler := sheddingHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
}))
req := httptest.NewRequest(http.MethodGet, "http://localhost", nil)
resp := httptest.NewRecorder()
handler.ServeHTTP(resp, req)
assert.Equal(t, http.StatusOK, resp.Code)
}
type mockShedder struct {
allow bool
}
func (s mockShedder) Allow() (load.Promise, error) {
if s.allow {
return mockPromise{}, nil
} else {
return nil, load.ErrServiceOverloaded
}
}
type mockPromise struct {
}
func (p mockPromise) Pass() {
}
func (p mockPromise) Fail() {
}

View File

@@ -0,0 +1,18 @@
package handler
import (
"net/http"
"time"
)
const reason = "Request Timeout"
func TimeoutHandler(duration time.Duration) func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
if duration > 0 {
return http.TimeoutHandler(next, duration, reason)
} else {
return next
}
}
}

View File

@@ -0,0 +1,52 @@
package handler
import (
"io/ioutil"
"log"
"net/http"
"net/http/httptest"
"testing"
"time"
"github.com/stretchr/testify/assert"
)
func init() {
log.SetOutput(ioutil.Discard)
}
func TestTimeout(t *testing.T) {
timeoutHandler := TimeoutHandler(time.Millisecond)
handler := timeoutHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
time.Sleep(time.Minute)
}))
req := httptest.NewRequest(http.MethodGet, "http://localhost", nil)
resp := httptest.NewRecorder()
handler.ServeHTTP(resp, req)
assert.Equal(t, http.StatusServiceUnavailable, resp.Code)
}
func TestWithinTimeout(t *testing.T) {
timeoutHandler := TimeoutHandler(time.Second)
handler := timeoutHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
time.Sleep(time.Millisecond)
}))
req := httptest.NewRequest(http.MethodGet, "http://localhost", nil)
resp := httptest.NewRecorder()
handler.ServeHTTP(resp, req)
assert.Equal(t, http.StatusOK, resp.Code)
}
func TestWithoutTimeout(t *testing.T) {
timeoutHandler := TimeoutHandler(0)
handler := timeoutHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
time.Sleep(100 * time.Millisecond)
}))
req := httptest.NewRequest(http.MethodGet, "http://localhost", nil)
resp := httptest.NewRecorder()
handler.ServeHTTP(resp, req)
assert.Equal(t, http.StatusOK, resp.Code)
}

View File

@@ -0,0 +1,25 @@
package handler
import (
"net/http"
"zero/core/logx"
"zero/core/sysx"
"zero/core/trace"
)
func TracingHandler(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
carrier, err := trace.Extract(trace.HttpFormat, r.Header)
// ErrInvalidCarrier means no trace id was set in http header
if err != nil && err != trace.ErrInvalidCarrier {
logx.Error(err)
}
ctx, span := trace.StartServerSpan(r.Context(), carrier, sysx.Hostname(), r.RequestURI)
defer span.Finish()
r = r.WithContext(ctx)
next.ServeHTTP(w, r)
})
}

View File

@@ -0,0 +1,25 @@
package handler
import (
"net/http"
"net/http/httptest"
"testing"
"zero/core/trace/tracespec"
"github.com/stretchr/testify/assert"
)
func TestTracingHandler(t *testing.T) {
req := httptest.NewRequest(http.MethodGet, "http://localhost", nil)
req.Header.Set("X-Trace-ID", "theid")
handler := TracingHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
span, ok := r.Context().Value(tracespec.TracingKey).(tracespec.Trace)
assert.True(t, ok)
assert.Equal(t, "theid", span.TraceId())
}))
resp := httptest.NewRecorder()
handler.ServeHTTP(resp, req)
assert.Equal(t, http.StatusOK, resp.Code)
}