refactor
This commit is contained in:
@@ -1,129 +0,0 @@
|
||||
package httphandler
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"net/http/httputil"
|
||||
|
||||
"zero/core/httpsecurity"
|
||||
"zero/core/logx"
|
||||
|
||||
"github.com/dgrijalva/jwt-go"
|
||||
)
|
||||
|
||||
const (
|
||||
jwtAudience = "aud"
|
||||
jwtExpire = "exp"
|
||||
jwtId = "jti"
|
||||
jwtIssueAt = "iat"
|
||||
jwtIssuer = "iss"
|
||||
jwtNotBefore = "nbf"
|
||||
jwtSubject = "sub"
|
||||
)
|
||||
|
||||
type (
|
||||
AuthorizeOptions struct {
|
||||
PrevSecret string
|
||||
Callback UnauthorizedCallback
|
||||
}
|
||||
|
||||
UnauthorizedCallback func(w http.ResponseWriter, r *http.Request, err error)
|
||||
AuthorizeOption func(opts *AuthorizeOptions)
|
||||
)
|
||||
|
||||
func Authorize(secret string, opts ...AuthorizeOption) func(http.Handler) http.Handler {
|
||||
var authOpts AuthorizeOptions
|
||||
for _, opt := range opts {
|
||||
opt(&authOpts)
|
||||
}
|
||||
|
||||
parser := httpsecurity.NewTokenParser()
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
token, err := parser.ParseToken(r, secret, authOpts.PrevSecret)
|
||||
if err != nil {
|
||||
unauthorized(w, r, err, authOpts.Callback)
|
||||
return
|
||||
}
|
||||
|
||||
if !token.Valid {
|
||||
unauthorized(w, r, err, authOpts.Callback)
|
||||
return
|
||||
}
|
||||
|
||||
claims, ok := token.Claims.(jwt.MapClaims)
|
||||
if !ok {
|
||||
unauthorized(w, r, err, authOpts.Callback)
|
||||
return
|
||||
}
|
||||
|
||||
ctx := r.Context()
|
||||
for k, v := range claims {
|
||||
switch k {
|
||||
case jwtAudience, jwtExpire, jwtId, jwtIssueAt, jwtIssuer, jwtNotBefore, jwtSubject:
|
||||
// ignore the standard claims
|
||||
default:
|
||||
ctx = context.WithValue(ctx, k, v)
|
||||
}
|
||||
}
|
||||
|
||||
next.ServeHTTP(w, r.WithContext(ctx))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func WithPrevSecret(secret string) AuthorizeOption {
|
||||
return func(opts *AuthorizeOptions) {
|
||||
opts.PrevSecret = secret
|
||||
}
|
||||
}
|
||||
|
||||
func WithUnauthorizedCallback(callback UnauthorizedCallback) AuthorizeOption {
|
||||
return func(opts *AuthorizeOptions) {
|
||||
opts.Callback = callback
|
||||
}
|
||||
}
|
||||
|
||||
func detailAuthLog(r *http.Request, reason string) {
|
||||
// discard dump error, only for debug purpose
|
||||
details, _ := httputil.DumpRequest(r, true)
|
||||
logx.Errorf("authorize failed: %s\n=> %+v", reason, string(details))
|
||||
}
|
||||
|
||||
func unauthorized(w http.ResponseWriter, r *http.Request, err error, callback UnauthorizedCallback) {
|
||||
writer := newGuardedResponseWriter(w)
|
||||
|
||||
detailAuthLog(r, err.Error())
|
||||
if callback != nil {
|
||||
callback(writer, r, err)
|
||||
}
|
||||
writer.WriteHeader(http.StatusUnauthorized)
|
||||
}
|
||||
|
||||
type guardedResponseWriter struct {
|
||||
writer http.ResponseWriter
|
||||
wroteHeader bool
|
||||
}
|
||||
|
||||
func newGuardedResponseWriter(w http.ResponseWriter) *guardedResponseWriter {
|
||||
return &guardedResponseWriter{
|
||||
writer: w,
|
||||
}
|
||||
}
|
||||
|
||||
func (grw *guardedResponseWriter) Header() http.Header {
|
||||
return grw.writer.Header()
|
||||
}
|
||||
|
||||
func (grw *guardedResponseWriter) Write(body []byte) (int, error) {
|
||||
return grw.writer.Write(body)
|
||||
}
|
||||
|
||||
func (grw *guardedResponseWriter) WriteHeader(statusCode int) {
|
||||
if grw.wroteHeader {
|
||||
return
|
||||
}
|
||||
|
||||
grw.wroteHeader = true
|
||||
grw.writer.WriteHeader(statusCode)
|
||||
}
|
||||
@@ -1,91 +0,0 @@
|
||||
package httphandler
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/dgrijalva/jwt-go"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestAuthHandlerFailed(t *testing.T) {
|
||||
req := httptest.NewRequest(http.MethodGet, "http://localhost", nil)
|
||||
handler := Authorize("B63F477D-BBA3-4E52-96D3-C0034C27694A", WithUnauthorizedCallback(
|
||||
func(w http.ResponseWriter, r *http.Request, err error) {
|
||||
w.Header().Set("X-Test", "test")
|
||||
w.WriteHeader(http.StatusUnauthorized)
|
||||
_, err = w.Write([]byte("content"))
|
||||
assert.Nil(t, err)
|
||||
}))(
|
||||
http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
|
||||
resp := httptest.NewRecorder()
|
||||
handler.ServeHTTP(resp, req)
|
||||
assert.Equal(t, http.StatusUnauthorized, resp.Code)
|
||||
}
|
||||
|
||||
func TestAuthHandler(t *testing.T) {
|
||||
const key = "B63F477D-BBA3-4E52-96D3-C0034C27694A"
|
||||
req := httptest.NewRequest(http.MethodGet, "http://localhost", nil)
|
||||
token, err := buildToken(key, map[string]interface{}{
|
||||
"key": "value",
|
||||
}, 3600)
|
||||
assert.Nil(t, err)
|
||||
req.Header.Set("Authorization", "Bearer "+token)
|
||||
handler := Authorize(key)(
|
||||
http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("X-Test", "test")
|
||||
_, err := w.Write([]byte("content"))
|
||||
assert.Nil(t, err)
|
||||
}))
|
||||
|
||||
resp := httptest.NewRecorder()
|
||||
handler.ServeHTTP(resp, req)
|
||||
assert.Equal(t, http.StatusOK, resp.Code)
|
||||
assert.Equal(t, "test", resp.Header().Get("X-Test"))
|
||||
assert.Equal(t, "content", resp.Body.String())
|
||||
}
|
||||
|
||||
func TestAuthHandlerWithPrevSecret(t *testing.T) {
|
||||
const (
|
||||
key = "14F17379-EB8F-411B-8F12-6929002DCA76"
|
||||
prevKey = "B63F477D-BBA3-4E52-96D3-C0034C27694A"
|
||||
)
|
||||
req := httptest.NewRequest(http.MethodGet, "http://localhost", nil)
|
||||
token, err := buildToken(key, map[string]interface{}{
|
||||
"key": "value",
|
||||
}, 3600)
|
||||
assert.Nil(t, err)
|
||||
req.Header.Set("Authorization", "Bearer "+token)
|
||||
handler := Authorize(key, WithPrevSecret(prevKey))(
|
||||
http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("X-Test", "test")
|
||||
_, err := w.Write([]byte("content"))
|
||||
assert.Nil(t, err)
|
||||
}))
|
||||
|
||||
resp := httptest.NewRecorder()
|
||||
handler.ServeHTTP(resp, req)
|
||||
assert.Equal(t, http.StatusOK, resp.Code)
|
||||
assert.Equal(t, "test", resp.Header().Get("X-Test"))
|
||||
assert.Equal(t, "content", resp.Body.String())
|
||||
}
|
||||
|
||||
func buildToken(secretKey string, payloads map[string]interface{}, seconds int64) (string, error) {
|
||||
now := time.Now().Unix()
|
||||
claims := make(jwt.MapClaims)
|
||||
claims["exp"] = now + seconds
|
||||
claims["iat"] = now
|
||||
for k, v := range payloads {
|
||||
claims[k] = v
|
||||
}
|
||||
|
||||
token := jwt.New(jwt.SigningMethodHS256)
|
||||
token.Claims = claims
|
||||
|
||||
return token.SignedString([]byte(secretKey))
|
||||
}
|
||||
@@ -1,41 +0,0 @@
|
||||
package httphandler
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"zero/core/breaker"
|
||||
"zero/core/httphandler/internal"
|
||||
"zero/core/httpx"
|
||||
"zero/core/logx"
|
||||
"zero/core/stat"
|
||||
)
|
||||
|
||||
const breakerSeparator = "://"
|
||||
|
||||
func BreakerHandler(method, path string, metrics *stat.Metrics) func(http.Handler) http.Handler {
|
||||
brk := breaker.NewBreaker(breaker.WithName(strings.Join([]string{method, path}, breakerSeparator)))
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
promise, err := brk.Allow()
|
||||
if err != nil {
|
||||
metrics.AddDrop()
|
||||
logx.Errorf("[http] dropped, %s - %s - %s",
|
||||
r.RequestURI, httpx.GetRemoteAddr(r), r.UserAgent())
|
||||
w.WriteHeader(http.StatusServiceUnavailable)
|
||||
return
|
||||
}
|
||||
|
||||
cw := &internal.WithCodeResponseWriter{Writer: w}
|
||||
defer func() {
|
||||
if cw.Code < http.StatusInternalServerError {
|
||||
promise.Accept()
|
||||
} else {
|
||||
promise.Reject(fmt.Sprintf("%d %s", cw.Code, http.StatusText(cw.Code)))
|
||||
}
|
||||
}()
|
||||
next.ServeHTTP(cw, r)
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -1,102 +0,0 @@
|
||||
package httphandler
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"zero/core/logx"
|
||||
"zero/core/stat"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func init() {
|
||||
logx.Disable()
|
||||
stat.SetReporter(nil)
|
||||
}
|
||||
|
||||
func TestBreakerHandlerAccept(t *testing.T) {
|
||||
metrics := stat.NewMetrics("unit-test")
|
||||
breakerHandler := BreakerHandler(http.MethodGet, "/", metrics)
|
||||
handler := breakerHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("X-Test", "test")
|
||||
_, err := w.Write([]byte("content"))
|
||||
assert.Nil(t, err)
|
||||
}))
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "http://localhost", nil)
|
||||
req.Header.Set("X-Test", "test")
|
||||
resp := httptest.NewRecorder()
|
||||
handler.ServeHTTP(resp, req)
|
||||
assert.Equal(t, http.StatusOK, resp.Code)
|
||||
assert.Equal(t, "test", resp.Header().Get("X-Test"))
|
||||
assert.Equal(t, "content", resp.Body.String())
|
||||
}
|
||||
|
||||
func TestBreakerHandlerFail(t *testing.T) {
|
||||
metrics := stat.NewMetrics("unit-test")
|
||||
breakerHandler := BreakerHandler(http.MethodGet, "/", metrics)
|
||||
handler := breakerHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusBadGateway)
|
||||
}))
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "http://localhost", nil)
|
||||
resp := httptest.NewRecorder()
|
||||
handler.ServeHTTP(resp, req)
|
||||
assert.Equal(t, http.StatusBadGateway, resp.Code)
|
||||
}
|
||||
|
||||
func TestBreakerHandler_4XX(t *testing.T) {
|
||||
metrics := stat.NewMetrics("unit-test")
|
||||
breakerHandler := BreakerHandler(http.MethodGet, "/", metrics)
|
||||
handler := breakerHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
}))
|
||||
|
||||
for i := 0; i < 1000; i++ {
|
||||
req := httptest.NewRequest(http.MethodGet, "http://localhost", nil)
|
||||
resp := httptest.NewRecorder()
|
||||
handler.ServeHTTP(resp, req)
|
||||
}
|
||||
|
||||
const tries = 100
|
||||
var pass int
|
||||
for i := 0; i < tries; i++ {
|
||||
req := httptest.NewRequest(http.MethodGet, "http://localhost", nil)
|
||||
resp := httptest.NewRecorder()
|
||||
handler.ServeHTTP(resp, req)
|
||||
if resp.Code == http.StatusBadRequest {
|
||||
pass++
|
||||
}
|
||||
}
|
||||
|
||||
assert.Equal(t, tries, pass)
|
||||
}
|
||||
|
||||
func TestBreakerHandlerReject(t *testing.T) {
|
||||
metrics := stat.NewMetrics("unit-test")
|
||||
breakerHandler := BreakerHandler(http.MethodGet, "/", metrics)
|
||||
handler := breakerHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
}))
|
||||
|
||||
for i := 0; i < 1000; i++ {
|
||||
req := httptest.NewRequest(http.MethodGet, "http://localhost", nil)
|
||||
resp := httptest.NewRecorder()
|
||||
handler.ServeHTTP(resp, req)
|
||||
}
|
||||
|
||||
var drops int
|
||||
for i := 0; i < 100; i++ {
|
||||
req := httptest.NewRequest(http.MethodGet, "http://localhost", nil)
|
||||
resp := httptest.NewRecorder()
|
||||
handler.ServeHTTP(resp, req)
|
||||
if resp.Code == http.StatusServiceUnavailable {
|
||||
drops++
|
||||
}
|
||||
}
|
||||
|
||||
assert.True(t, drops >= 80, fmt.Sprintf("expected to be greater than 80, but got %d", drops))
|
||||
}
|
||||
@@ -1,61 +0,0 @@
|
||||
package httphandler
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"zero/core/codec"
|
||||
"zero/core/httphandler/internal"
|
||||
"zero/core/httpx"
|
||||
"zero/core/logx"
|
||||
)
|
||||
|
||||
const contentSecurity = "X-Content-Security"
|
||||
|
||||
type UnsignedCallback func(w http.ResponseWriter, r *http.Request, next http.Handler, strict bool, code int)
|
||||
|
||||
func ContentSecurityHandler(decrypters map[string]codec.RsaDecrypter, tolerance time.Duration,
|
||||
strict bool, callbacks ...UnsignedCallback) func(http.Handler) http.Handler {
|
||||
if len(callbacks) == 0 {
|
||||
callbacks = append(callbacks, handleVerificationFailure)
|
||||
}
|
||||
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
switch r.Method {
|
||||
case http.MethodDelete, http.MethodGet, http.MethodPost, http.MethodPut:
|
||||
header, err := internal.ParseContentSecurity(decrypters, r)
|
||||
if err != nil {
|
||||
logx.Infof("Signature parse failed, X-Content-Security: %s, error: %s",
|
||||
r.Header.Get(contentSecurity), err.Error())
|
||||
executeCallbacks(w, r, next, strict, httpx.CodeSignatureInvalidHeader, callbacks)
|
||||
} else if code := internal.VerifySignature(r, header, tolerance); code != httpx.CodeSignaturePass {
|
||||
logx.Infof("Signature verification failed, X-Content-Security: %s",
|
||||
r.Header.Get(contentSecurity))
|
||||
executeCallbacks(w, r, next, strict, code, callbacks)
|
||||
} else if r.ContentLength > 0 && header.Encrypted() {
|
||||
CryptionHandler(header.Key)(next).ServeHTTP(w, r)
|
||||
} else {
|
||||
next.ServeHTTP(w, r)
|
||||
}
|
||||
default:
|
||||
next.ServeHTTP(w, r)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func executeCallbacks(w http.ResponseWriter, r *http.Request, next http.Handler, strict bool,
|
||||
code int, callbacks []UnsignedCallback) {
|
||||
for _, callback := range callbacks {
|
||||
callback(w, r, next, strict, code)
|
||||
}
|
||||
}
|
||||
|
||||
func handleVerificationFailure(w http.ResponseWriter, r *http.Request, next http.Handler, strict bool, code int) {
|
||||
if strict {
|
||||
w.WriteHeader(http.StatusUnauthorized)
|
||||
} else {
|
||||
next.ServeHTTP(w, r)
|
||||
}
|
||||
}
|
||||
@@ -1,388 +0,0 @@
|
||||
package httphandler
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/sha256"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"log"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"os"
|
||||
"strconv"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"zero/core/codec"
|
||||
"zero/core/httpx"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
const timeDiff = time.Hour * 2 * 24
|
||||
|
||||
var (
|
||||
fingerprint = "12345"
|
||||
pubKey = []byte(`-----BEGIN PUBLIC KEY-----
|
||||
MIGfMA0GCSqGSIb3DQEBAQUAA4GNADCBiQKBgQD7bq4FLG0ctccbEFEsUBuRxkjE
|
||||
eJ5U+0CAEjJk20V9/u2Fu76i1oKoShCs7GXtAFbDb5A/ImIXkPY62nAaxTGK4KVH
|
||||
miYbRgh5Fy6336KepLCtCmV/r0PKZeCyJH9uYLs7EuE1z9Hgm5UUjmpHDhJtkAwR
|
||||
my47YlhspwszKdRP+wIDAQAB
|
||||
-----END PUBLIC KEY-----`)
|
||||
priKey = []byte(`-----BEGIN RSA PRIVATE KEY-----
|
||||
MIICXAIBAAKBgQD7bq4FLG0ctccbEFEsUBuRxkjEeJ5U+0CAEjJk20V9/u2Fu76i
|
||||
1oKoShCs7GXtAFbDb5A/ImIXkPY62nAaxTGK4KVHmiYbRgh5Fy6336KepLCtCmV/
|
||||
r0PKZeCyJH9uYLs7EuE1z9Hgm5UUjmpHDhJtkAwRmy47YlhspwszKdRP+wIDAQAB
|
||||
AoGBANs1qf7UtuSbD1ZnKX5K8V5s07CHwPMygw+lzc3k5ndtNUStZQ2vnAaBXHyH
|
||||
Nm4lJ4AI2mhQ39jQB/1TyP1uAzvpLhT60fRybEq9zgJ/81Gm9bnaEpFJ9bP2bBrY
|
||||
J0jbaTMfbzL/PJFl3J3RGMR40C76h5yRYSnOpMoMiKWnJqrhAkEA/zCOkR+34Pk0
|
||||
Yo3sIP4ranY6AAvwacgNaui4ll5xeYwv3iLOQvPlpxIxFHKXEY0klNNyjjXqgYjP
|
||||
cOenqtt6UwJBAPw7EYuteVHvHvQVuTbKAaYHcOrp4nFeZF3ndFfl0w2dwGhfzcXO
|
||||
ROyd5dNQCuCWRo8JBpjG6PFyzezayF4KLrkCQCGditoxHG7FRRJKcbVy5dMzWbaR
|
||||
3AyDLslLeK1OKZKCVffkC9mj+TeF3PM9mQrV1eDI7ckv7wE7PWA5E8wc90MCQEOV
|
||||
MCZU3OTvRUPxbicYCUkLRV4sPNhTimD+21WR5vMHCb7trJ0Ln7wmsqXkFIYIve8l
|
||||
Y/cblN7c/AAyvu0znUECQA318nPldsxR6+H8HTS3uEbkL4UJdjQJHsvTwKxAw5qc
|
||||
moKExvRlN0zmGGuArKcqS38KG7PXZMrUv3FXPdp6BDQ=
|
||||
-----END RSA PRIVATE KEY-----`)
|
||||
key = []byte("q4t7w!z%C*F-JaNdRgUjXn2r5u8x/A?D")
|
||||
)
|
||||
|
||||
type requestSettings struct {
|
||||
method string
|
||||
url string
|
||||
body io.Reader
|
||||
strict bool
|
||||
crypt bool
|
||||
requestUri string
|
||||
timestamp int64
|
||||
fingerprint string
|
||||
missHeader bool
|
||||
signature string
|
||||
}
|
||||
|
||||
func init() {
|
||||
log.SetOutput(ioutil.Discard)
|
||||
}
|
||||
|
||||
func TestContentSecurityHandler(t *testing.T) {
|
||||
tests := []struct {
|
||||
method string
|
||||
url string
|
||||
body string
|
||||
strict bool
|
||||
crypt bool
|
||||
requestUri string
|
||||
timestamp int64
|
||||
fingerprint string
|
||||
missHeader bool
|
||||
signature string
|
||||
statusCode int
|
||||
}{
|
||||
{
|
||||
method: http.MethodGet,
|
||||
url: "http://localhost/a/b?c=d&e=f",
|
||||
strict: true,
|
||||
crypt: false,
|
||||
},
|
||||
{
|
||||
method: http.MethodPost,
|
||||
url: "http://localhost/a/b?c=d&e=f",
|
||||
body: "hello",
|
||||
strict: true,
|
||||
crypt: false,
|
||||
},
|
||||
{
|
||||
method: http.MethodGet,
|
||||
url: "http://localhost/a/b?c=d&e=f",
|
||||
strict: true,
|
||||
crypt: true,
|
||||
},
|
||||
{
|
||||
method: http.MethodPost,
|
||||
url: "http://localhost/a/b?c=d&e=f",
|
||||
body: "hello",
|
||||
strict: true,
|
||||
crypt: true,
|
||||
},
|
||||
{
|
||||
method: http.MethodGet,
|
||||
url: "http://localhost/a/b?c=d&e=f",
|
||||
strict: true,
|
||||
crypt: true,
|
||||
timestamp: time.Now().Add(timeDiff).Unix(),
|
||||
statusCode: http.StatusUnauthorized,
|
||||
},
|
||||
{
|
||||
method: http.MethodPost,
|
||||
url: "http://localhost/a/b?c=d&e=f",
|
||||
body: "hello",
|
||||
strict: true,
|
||||
crypt: true,
|
||||
timestamp: time.Now().Add(-timeDiff).Unix(),
|
||||
statusCode: http.StatusUnauthorized,
|
||||
},
|
||||
{
|
||||
method: http.MethodPost,
|
||||
url: "http://remotehost/",
|
||||
body: "hello",
|
||||
strict: true,
|
||||
crypt: true,
|
||||
requestUri: "http://localhost/a/b?c=d&e=f",
|
||||
},
|
||||
{
|
||||
method: http.MethodPost,
|
||||
url: "http://localhost/a/b?c=d&e=f",
|
||||
body: "hello",
|
||||
strict: false,
|
||||
crypt: true,
|
||||
fingerprint: "badone",
|
||||
},
|
||||
{
|
||||
method: http.MethodPost,
|
||||
url: "http://localhost/a/b?c=d&e=f",
|
||||
body: "hello",
|
||||
strict: true,
|
||||
crypt: true,
|
||||
timestamp: time.Now().Add(-timeDiff).Unix(),
|
||||
fingerprint: "badone",
|
||||
statusCode: http.StatusUnauthorized,
|
||||
},
|
||||
{
|
||||
method: http.MethodPost,
|
||||
url: "http://localhost/a/b?c=d&e=f",
|
||||
body: "hello",
|
||||
strict: true,
|
||||
crypt: true,
|
||||
missHeader: true,
|
||||
statusCode: http.StatusUnauthorized,
|
||||
},
|
||||
{
|
||||
method: http.MethodHead,
|
||||
url: "http://localhost/a/b?c=d&e=f",
|
||||
strict: true,
|
||||
crypt: false,
|
||||
},
|
||||
{
|
||||
method: http.MethodGet,
|
||||
url: "http://localhost/a/b?c=d&e=f",
|
||||
strict: true,
|
||||
crypt: false,
|
||||
signature: "badone",
|
||||
statusCode: http.StatusUnauthorized,
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
t.Run(test.url, func(t *testing.T) {
|
||||
if test.statusCode == 0 {
|
||||
test.statusCode = http.StatusOK
|
||||
}
|
||||
if len(test.fingerprint) == 0 {
|
||||
test.fingerprint = fingerprint
|
||||
}
|
||||
if test.timestamp == 0 {
|
||||
test.timestamp = time.Now().Unix()
|
||||
}
|
||||
|
||||
func() {
|
||||
keyFile, err := createTempFile(priKey)
|
||||
defer os.Remove(keyFile)
|
||||
|
||||
assert.Nil(t, err)
|
||||
decrypter, err := codec.NewRsaDecrypter(keyFile)
|
||||
assert.Nil(t, err)
|
||||
contentSecurityHandler := ContentSecurityHandler(map[string]codec.RsaDecrypter{
|
||||
fingerprint: decrypter,
|
||||
}, time.Hour, test.strict)
|
||||
handler := contentSecurityHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
}))
|
||||
|
||||
var reader io.Reader
|
||||
if len(test.body) > 0 {
|
||||
reader = strings.NewReader(test.body)
|
||||
}
|
||||
setting := requestSettings{
|
||||
method: test.method,
|
||||
url: test.url,
|
||||
body: reader,
|
||||
strict: test.strict,
|
||||
crypt: test.crypt,
|
||||
requestUri: test.requestUri,
|
||||
timestamp: test.timestamp,
|
||||
fingerprint: test.fingerprint,
|
||||
missHeader: test.missHeader,
|
||||
signature: test.signature,
|
||||
}
|
||||
req, err := buildRequest(setting)
|
||||
resp := httptest.NewRecorder()
|
||||
handler.ServeHTTP(resp, req)
|
||||
assert.Equal(t, test.statusCode, resp.Code)
|
||||
}()
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestContentSecurityHandler_UnsignedCallback(t *testing.T) {
|
||||
keyFile, err := createTempFile(priKey)
|
||||
defer os.Remove(keyFile)
|
||||
|
||||
assert.Nil(t, err)
|
||||
decrypter, err := codec.NewRsaDecrypter(keyFile)
|
||||
assert.Nil(t, err)
|
||||
contentSecurityHandler := ContentSecurityHandler(
|
||||
map[string]codec.RsaDecrypter{
|
||||
fingerprint: decrypter,
|
||||
},
|
||||
time.Hour,
|
||||
true,
|
||||
func(w http.ResponseWriter, r *http.Request, next http.Handler, strict bool, code int) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
})
|
||||
handler := contentSecurityHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}))
|
||||
|
||||
setting := requestSettings{
|
||||
method: http.MethodGet,
|
||||
url: "http://localhost/a/b?c=d&e=f",
|
||||
signature: "badone",
|
||||
}
|
||||
req, err := buildRequest(setting)
|
||||
resp := httptest.NewRecorder()
|
||||
handler.ServeHTTP(resp, req)
|
||||
assert.Equal(t, http.StatusOK, resp.Code)
|
||||
}
|
||||
|
||||
func TestContentSecurityHandler_UnsignedCallback_WrongTime(t *testing.T) {
|
||||
keyFile, err := createTempFile(priKey)
|
||||
defer os.Remove(keyFile)
|
||||
|
||||
assert.Nil(t, err)
|
||||
decrypter, err := codec.NewRsaDecrypter(keyFile)
|
||||
assert.Nil(t, err)
|
||||
contentSecurityHandler := ContentSecurityHandler(
|
||||
map[string]codec.RsaDecrypter{
|
||||
fingerprint: decrypter,
|
||||
},
|
||||
time.Hour,
|
||||
true,
|
||||
func(w http.ResponseWriter, r *http.Request, next http.Handler, strict bool, code int) {
|
||||
assert.Equal(t, httpx.CodeSignatureWrongTime, code)
|
||||
w.WriteHeader(http.StatusOK)
|
||||
})
|
||||
handler := contentSecurityHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}))
|
||||
|
||||
var reader io.Reader
|
||||
reader = strings.NewReader("hello")
|
||||
setting := requestSettings{
|
||||
method: http.MethodPost,
|
||||
url: "http://localhost/a/b?c=d&e=f",
|
||||
body: reader,
|
||||
strict: true,
|
||||
crypt: true,
|
||||
timestamp: time.Now().Add(time.Hour * 24 * 365).Unix(),
|
||||
fingerprint: fingerprint,
|
||||
}
|
||||
req, err := buildRequest(setting)
|
||||
resp := httptest.NewRecorder()
|
||||
handler.ServeHTTP(resp, req)
|
||||
assert.Equal(t, http.StatusOK, resp.Code)
|
||||
}
|
||||
|
||||
func buildRequest(rs requestSettings) (*http.Request, error) {
|
||||
var bodyStr string
|
||||
var err error
|
||||
|
||||
if rs.crypt && rs.body != nil {
|
||||
var buf bytes.Buffer
|
||||
io.Copy(&buf, rs.body)
|
||||
bodyBytes, err := codec.EcbEncrypt(key, buf.Bytes())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
bodyStr = base64.StdEncoding.EncodeToString(bodyBytes)
|
||||
}
|
||||
|
||||
r := httptest.NewRequest(rs.method, rs.url, strings.NewReader(bodyStr))
|
||||
if len(rs.signature) == 0 {
|
||||
sha := sha256.New()
|
||||
sha.Write([]byte(bodyStr))
|
||||
bodySign := fmt.Sprintf("%x", sha.Sum(nil))
|
||||
var path string
|
||||
var query string
|
||||
if len(rs.requestUri) > 0 {
|
||||
if u, err := url.Parse(rs.requestUri); err != nil {
|
||||
return nil, err
|
||||
} else {
|
||||
path = u.Path
|
||||
query = u.RawQuery
|
||||
}
|
||||
} else {
|
||||
path = r.URL.Path
|
||||
query = r.URL.RawQuery
|
||||
}
|
||||
contentOfSign := strings.Join([]string{
|
||||
strconv.FormatInt(rs.timestamp, 10),
|
||||
rs.method,
|
||||
path,
|
||||
query,
|
||||
bodySign,
|
||||
}, "\n")
|
||||
rs.signature = codec.HmacBase64([]byte(key), contentOfSign)
|
||||
}
|
||||
|
||||
var mode string
|
||||
if rs.crypt {
|
||||
mode = "1"
|
||||
} else {
|
||||
mode = "0"
|
||||
}
|
||||
content := strings.Join([]string{
|
||||
"version=v1",
|
||||
"type=" + mode,
|
||||
fmt.Sprintf("key=%s", base64.StdEncoding.EncodeToString(key)),
|
||||
"time=" + strconv.FormatInt(rs.timestamp, 10),
|
||||
}, "; ")
|
||||
|
||||
encrypter, err := codec.NewRsaEncrypter([]byte(pubKey))
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
output, err := encrypter.Encrypt([]byte(content))
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
encryptedContent := base64.StdEncoding.EncodeToString(output)
|
||||
if !rs.missHeader {
|
||||
r.Header.Set(httpx.ContentSecurity, strings.Join([]string{
|
||||
fmt.Sprintf("key=%s", rs.fingerprint),
|
||||
"secret=" + encryptedContent,
|
||||
"signature=" + rs.signature,
|
||||
}, "; "))
|
||||
}
|
||||
if len(rs.requestUri) > 0 {
|
||||
r.Header.Set("X-Request-Uri", rs.requestUri)
|
||||
}
|
||||
|
||||
return r, nil
|
||||
}
|
||||
|
||||
func createTempFile(body []byte) (string, error) {
|
||||
tmpFile, err := ioutil.TempFile(os.TempDir(), "go-unit-*.tmp")
|
||||
if err != nil {
|
||||
return "", err
|
||||
} else {
|
||||
tmpFile.Close()
|
||||
}
|
||||
|
||||
err = ioutil.WriteFile(tmpFile.Name(), body, os.ModePerm)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return tmpFile.Name(), nil
|
||||
}
|
||||
@@ -1,101 +0,0 @@
|
||||
package httphandler
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/base64"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
|
||||
"zero/core/codec"
|
||||
"zero/core/logx"
|
||||
)
|
||||
|
||||
const maxBytes = 1 << 20 // 1 MiB
|
||||
|
||||
func CryptionHandler(key []byte) func(http.Handler) http.Handler {
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
cw := newCryptionResponseWriter(w)
|
||||
defer cw.flush(key)
|
||||
|
||||
if r.ContentLength <= 0 {
|
||||
next.ServeHTTP(cw, r)
|
||||
return
|
||||
}
|
||||
|
||||
if err := decryptBody(key, r); err != nil {
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
next.ServeHTTP(cw, r)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func decryptBody(key []byte, r *http.Request) error {
|
||||
content, err := ioutil.ReadAll(io.LimitReader(r.Body, maxBytes))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
content, err = base64.StdEncoding.DecodeString(string(content))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
output, err := codec.EcbDecrypt(key, content)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var buf bytes.Buffer
|
||||
buf.Write(output)
|
||||
r.Body = ioutil.NopCloser(&buf)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
type cryptionResponseWriter struct {
|
||||
http.ResponseWriter
|
||||
buf *bytes.Buffer
|
||||
}
|
||||
|
||||
func newCryptionResponseWriter(w http.ResponseWriter) *cryptionResponseWriter {
|
||||
return &cryptionResponseWriter{
|
||||
ResponseWriter: w,
|
||||
buf: new(bytes.Buffer),
|
||||
}
|
||||
}
|
||||
|
||||
func (w *cryptionResponseWriter) Header() http.Header {
|
||||
return w.ResponseWriter.Header()
|
||||
}
|
||||
|
||||
func (w *cryptionResponseWriter) Write(p []byte) (int, error) {
|
||||
return w.buf.Write(p)
|
||||
}
|
||||
|
||||
func (w *cryptionResponseWriter) WriteHeader(statusCode int) {
|
||||
w.ResponseWriter.WriteHeader(statusCode)
|
||||
}
|
||||
|
||||
func (w *cryptionResponseWriter) flush(key []byte) {
|
||||
if w.buf.Len() == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
content, err := codec.EcbEncrypt(key, w.buf.Bytes())
|
||||
if err != nil {
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
body := base64.StdEncoding.EncodeToString(content)
|
||||
if n, err := io.WriteString(w.ResponseWriter, body); err != nil {
|
||||
logx.Errorf("write response failed, error: %s", err)
|
||||
} else if n < len(content) {
|
||||
logx.Errorf("actual bytes: %d, written bytes: %d", len(content), n)
|
||||
}
|
||||
}
|
||||
@@ -1,90 +0,0 @@
|
||||
package httphandler
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/base64"
|
||||
"io/ioutil"
|
||||
"log"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"zero/core/codec"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
const (
|
||||
reqText = "ping"
|
||||
respText = "pong"
|
||||
)
|
||||
|
||||
var aesKey = []byte(`PdSgVkYp3s6v9y$B&E)H+MbQeThWmZq4`)
|
||||
|
||||
func init() {
|
||||
log.SetOutput(ioutil.Discard)
|
||||
}
|
||||
|
||||
func TestCryptionHandlerGet(t *testing.T) {
|
||||
req := httptest.NewRequest(http.MethodGet, "/any", nil)
|
||||
handler := CryptionHandler(aesKey)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
_, err := w.Write([]byte(respText))
|
||||
w.Header().Set("X-Test", "test")
|
||||
assert.Nil(t, err)
|
||||
}))
|
||||
recorder := httptest.NewRecorder()
|
||||
handler.ServeHTTP(recorder, req)
|
||||
|
||||
expect, err := codec.EcbEncrypt(aesKey, []byte(respText))
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, http.StatusOK, recorder.Code)
|
||||
assert.Equal(t, "test", recorder.Header().Get("X-Test"))
|
||||
assert.Equal(t, base64.StdEncoding.EncodeToString(expect), recorder.Body.String())
|
||||
}
|
||||
|
||||
func TestCryptionHandlerPost(t *testing.T) {
|
||||
var buf bytes.Buffer
|
||||
enc, err := codec.EcbEncrypt(aesKey, []byte(reqText))
|
||||
assert.Nil(t, err)
|
||||
buf.WriteString(base64.StdEncoding.EncodeToString(enc))
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/any", &buf)
|
||||
handler := CryptionHandler(aesKey)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
body, err := ioutil.ReadAll(r.Body)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, reqText, string(body))
|
||||
|
||||
w.Write([]byte(respText))
|
||||
}))
|
||||
recorder := httptest.NewRecorder()
|
||||
handler.ServeHTTP(recorder, req)
|
||||
|
||||
expect, err := codec.EcbEncrypt(aesKey, []byte(respText))
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, http.StatusOK, recorder.Code)
|
||||
assert.Equal(t, base64.StdEncoding.EncodeToString(expect), recorder.Body.String())
|
||||
}
|
||||
|
||||
func TestCryptionHandlerPostBadEncryption(t *testing.T) {
|
||||
var buf bytes.Buffer
|
||||
enc, err := codec.EcbEncrypt(aesKey, []byte(reqText))
|
||||
assert.Nil(t, err)
|
||||
buf.Write(enc)
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/any", &buf)
|
||||
handler := CryptionHandler(aesKey)(nil)
|
||||
recorder := httptest.NewRecorder()
|
||||
handler.ServeHTTP(recorder, req)
|
||||
|
||||
assert.Equal(t, http.StatusBadRequest, recorder.Code)
|
||||
}
|
||||
|
||||
func TestCryptionHandlerWriteHeader(t *testing.T) {
|
||||
req := httptest.NewRequest(http.MethodGet, "/any", nil)
|
||||
handler := CryptionHandler(aesKey)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusServiceUnavailable)
|
||||
}))
|
||||
recorder := httptest.NewRecorder()
|
||||
handler.ServeHTTP(recorder, req)
|
||||
assert.Equal(t, http.StatusServiceUnavailable, recorder.Code)
|
||||
}
|
||||
@@ -1,27 +0,0 @@
|
||||
package httphandler
|
||||
|
||||
import (
|
||||
"compress/gzip"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"zero/core/httpx"
|
||||
)
|
||||
|
||||
const gzipEncoding = "gzip"
|
||||
|
||||
func GunzipHandler(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if strings.Contains(r.Header.Get(httpx.ContentEncoding), gzipEncoding) {
|
||||
reader, err := gzip.NewReader(r.Body)
|
||||
if err != nil {
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
r.Body = reader
|
||||
}
|
||||
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
@@ -1,66 +0,0 @@
|
||||
package httphandler
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"zero/core/codec"
|
||||
"zero/core/httpx"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestGunzipHandler(t *testing.T) {
|
||||
const message = "hello world"
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(1)
|
||||
handler := GunzipHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
body, err := ioutil.ReadAll(r.Body)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, string(body), message)
|
||||
wg.Done()
|
||||
}))
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "http://localhost",
|
||||
bytes.NewReader(codec.Gzip([]byte(message))))
|
||||
req.Header.Set(httpx.ContentEncoding, gzipEncoding)
|
||||
resp := httptest.NewRecorder()
|
||||
handler.ServeHTTP(resp, req)
|
||||
assert.Equal(t, http.StatusOK, resp.Code)
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
func TestGunzipHandler_NoGzip(t *testing.T) {
|
||||
const message = "hello world"
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(1)
|
||||
handler := GunzipHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
body, err := ioutil.ReadAll(r.Body)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, string(body), message)
|
||||
wg.Done()
|
||||
}))
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "http://localhost",
|
||||
strings.NewReader(message))
|
||||
resp := httptest.NewRecorder()
|
||||
handler.ServeHTTP(resp, req)
|
||||
assert.Equal(t, http.StatusOK, resp.Code)
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
func TestGunzipHandler_NoGzipButTelling(t *testing.T) {
|
||||
const message = "hello world"
|
||||
handler := GunzipHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}))
|
||||
req := httptest.NewRequest(http.MethodPost, "http://localhost",
|
||||
strings.NewReader(message))
|
||||
req.Header.Set(httpx.ContentEncoding, gzipEncoding)
|
||||
resp := httptest.NewRecorder()
|
||||
handler.ServeHTTP(resp, req)
|
||||
assert.Equal(t, http.StatusBadRequest, resp.Code)
|
||||
}
|
||||
@@ -1,147 +0,0 @@
|
||||
package internal
|
||||
|
||||
import (
|
||||
"crypto/sha256"
|
||||
"encoding/base64"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"zero/core/codec"
|
||||
"zero/core/httpx"
|
||||
"zero/core/iox"
|
||||
"zero/core/logx"
|
||||
)
|
||||
|
||||
const (
|
||||
requestUriHeader = "X-Request-Uri"
|
||||
signatureField = "signature"
|
||||
timeField = "time"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrInvalidContentType = errors.New("invalid content type")
|
||||
ErrInvalidHeader = errors.New("invalid X-Content-Security header")
|
||||
ErrInvalidKey = errors.New("invalid key")
|
||||
ErrInvalidPublicKey = errors.New("invalid public key")
|
||||
ErrInvalidSecret = errors.New("invalid secret")
|
||||
)
|
||||
|
||||
type ContentSecurityHeader struct {
|
||||
Key []byte
|
||||
Timestamp string
|
||||
ContentType int
|
||||
Signature string
|
||||
}
|
||||
|
||||
func (h *ContentSecurityHeader) Encrypted() bool {
|
||||
return h.ContentType == httpx.CryptionType
|
||||
}
|
||||
|
||||
func ParseContentSecurity(decrypters map[string]codec.RsaDecrypter, r *http.Request) (
|
||||
*ContentSecurityHeader, error) {
|
||||
contentSecurity := r.Header.Get(httpx.ContentSecurity)
|
||||
attrs := httpx.ParseHeader(contentSecurity)
|
||||
fingerprint := attrs[httpx.KeyField]
|
||||
secret := attrs[httpx.SecretField]
|
||||
signature := attrs[signatureField]
|
||||
|
||||
if len(fingerprint) == 0 || len(secret) == 0 || len(signature) == 0 {
|
||||
return nil, ErrInvalidHeader
|
||||
}
|
||||
|
||||
decrypter, ok := decrypters[fingerprint]
|
||||
if !ok {
|
||||
return nil, ErrInvalidPublicKey
|
||||
}
|
||||
|
||||
decryptedSecret, err := decrypter.DecryptBase64(secret)
|
||||
if err != nil {
|
||||
return nil, ErrInvalidSecret
|
||||
}
|
||||
|
||||
attrs = httpx.ParseHeader(string(decryptedSecret))
|
||||
base64Key := attrs[httpx.KeyField]
|
||||
timestamp := attrs[timeField]
|
||||
contentType := attrs[httpx.TypeField]
|
||||
|
||||
key, err := base64.StdEncoding.DecodeString(base64Key)
|
||||
if err != nil {
|
||||
return nil, ErrInvalidKey
|
||||
}
|
||||
|
||||
cType, err := strconv.Atoi(contentType)
|
||||
if err != nil {
|
||||
return nil, ErrInvalidContentType
|
||||
}
|
||||
|
||||
return &ContentSecurityHeader{
|
||||
Key: key,
|
||||
Timestamp: timestamp,
|
||||
ContentType: cType,
|
||||
Signature: signature,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func VerifySignature(r *http.Request, securityHeader *ContentSecurityHeader, tolerance time.Duration) int {
|
||||
seconds, err := strconv.ParseInt(securityHeader.Timestamp, 10, 64)
|
||||
if err != nil {
|
||||
return httpx.CodeSignatureInvalidHeader
|
||||
}
|
||||
|
||||
now := time.Now().Unix()
|
||||
toleranceSeconds := int64(tolerance.Seconds())
|
||||
if seconds+toleranceSeconds < now || now+toleranceSeconds < seconds {
|
||||
return httpx.CodeSignatureWrongTime
|
||||
}
|
||||
|
||||
reqPath, reqQuery := getPathQuery(r)
|
||||
signContent := strings.Join([]string{
|
||||
securityHeader.Timestamp,
|
||||
r.Method,
|
||||
reqPath,
|
||||
reqQuery,
|
||||
computeBodySignature(r),
|
||||
}, "\n")
|
||||
actualSignature := codec.HmacBase64(securityHeader.Key, signContent)
|
||||
|
||||
passed := securityHeader.Signature == actualSignature
|
||||
if !passed {
|
||||
logx.Infof("signature different, expect: %s, actual: %s",
|
||||
securityHeader.Signature, actualSignature)
|
||||
}
|
||||
|
||||
if passed {
|
||||
return httpx.CodeSignaturePass
|
||||
} else {
|
||||
return httpx.CodeSignatureInvalidToken
|
||||
}
|
||||
}
|
||||
|
||||
func computeBodySignature(r *http.Request) string {
|
||||
var dup io.ReadCloser
|
||||
r.Body, dup = iox.DupReadCloser(r.Body)
|
||||
sha := sha256.New()
|
||||
io.Copy(sha, r.Body)
|
||||
r.Body = dup
|
||||
return fmt.Sprintf("%x", sha.Sum(nil))
|
||||
}
|
||||
|
||||
func getPathQuery(r *http.Request) (string, string) {
|
||||
requestUri := r.Header.Get(requestUriHeader)
|
||||
if len(requestUri) == 0 {
|
||||
return r.URL.Path, r.URL.RawQuery
|
||||
}
|
||||
|
||||
uri, err := url.Parse(requestUri)
|
||||
if err != nil {
|
||||
return r.URL.Path, r.URL.RawQuery
|
||||
}
|
||||
|
||||
return uri.Path, uri.RawQuery
|
||||
}
|
||||
@@ -1,21 +0,0 @@
|
||||
package internal
|
||||
|
||||
import "net/http"
|
||||
|
||||
type WithCodeResponseWriter struct {
|
||||
Writer http.ResponseWriter
|
||||
Code int
|
||||
}
|
||||
|
||||
func (w *WithCodeResponseWriter) Header() http.Header {
|
||||
return w.Writer.Header()
|
||||
}
|
||||
|
||||
func (w *WithCodeResponseWriter) Write(bytes []byte) (int, error) {
|
||||
return w.Writer.Write(bytes)
|
||||
}
|
||||
|
||||
func (w *WithCodeResponseWriter) WriteHeader(code int) {
|
||||
w.Writer.WriteHeader(code)
|
||||
w.Code = code
|
||||
}
|
||||
@@ -1,166 +0,0 @@
|
||||
package httphandler
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httputil"
|
||||
"time"
|
||||
|
||||
"zero/core/httplog"
|
||||
"zero/core/httpx"
|
||||
"zero/core/iox"
|
||||
"zero/core/logx"
|
||||
"zero/core/timex"
|
||||
"zero/core/utils"
|
||||
)
|
||||
|
||||
const slowThreshold = time.Millisecond * 500
|
||||
|
||||
type LoggedResponseWriter struct {
|
||||
w http.ResponseWriter
|
||||
r *http.Request
|
||||
code int
|
||||
}
|
||||
|
||||
func (w *LoggedResponseWriter) Header() http.Header {
|
||||
return w.w.Header()
|
||||
}
|
||||
|
||||
func (w *LoggedResponseWriter) Write(bytes []byte) (int, error) {
|
||||
return w.w.Write(bytes)
|
||||
}
|
||||
|
||||
func (w *LoggedResponseWriter) WriteHeader(code int) {
|
||||
w.w.WriteHeader(code)
|
||||
w.code = code
|
||||
}
|
||||
|
||||
func LogHandler(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
timer := utils.NewElapsedTimer()
|
||||
logs := new(httplog.LogCollector)
|
||||
lrw := LoggedResponseWriter{
|
||||
w: w,
|
||||
r: r,
|
||||
code: http.StatusOK,
|
||||
}
|
||||
|
||||
var dup io.ReadCloser
|
||||
r.Body, dup = iox.DupReadCloser(r.Body)
|
||||
next.ServeHTTP(&lrw, r.WithContext(context.WithValue(r.Context(), httplog.LogContext, logs)))
|
||||
r.Body = dup
|
||||
logBrief(r, lrw.code, timer, logs)
|
||||
})
|
||||
}
|
||||
|
||||
type DetailLoggedResponseWriter struct {
|
||||
writer *LoggedResponseWriter
|
||||
buf *bytes.Buffer
|
||||
}
|
||||
|
||||
func newDetailLoggedResponseWriter(writer *LoggedResponseWriter, buf *bytes.Buffer) *DetailLoggedResponseWriter {
|
||||
return &DetailLoggedResponseWriter{
|
||||
writer: writer,
|
||||
buf: buf,
|
||||
}
|
||||
}
|
||||
|
||||
func (w *DetailLoggedResponseWriter) Header() http.Header {
|
||||
return w.writer.Header()
|
||||
}
|
||||
|
||||
func (w *DetailLoggedResponseWriter) Write(bs []byte) (int, error) {
|
||||
w.buf.Write(bs)
|
||||
return w.writer.Write(bs)
|
||||
}
|
||||
|
||||
func (w *DetailLoggedResponseWriter) WriteHeader(code int) {
|
||||
w.writer.WriteHeader(code)
|
||||
}
|
||||
|
||||
func DetailedLogHandler(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
timer := utils.NewElapsedTimer()
|
||||
var buf bytes.Buffer
|
||||
lrw := newDetailLoggedResponseWriter(&LoggedResponseWriter{
|
||||
w: w,
|
||||
r: r,
|
||||
code: http.StatusOK,
|
||||
}, &buf)
|
||||
|
||||
var dup io.ReadCloser
|
||||
r.Body, dup = iox.DupReadCloser(r.Body)
|
||||
logs := new(httplog.LogCollector)
|
||||
next.ServeHTTP(lrw, r.WithContext(context.WithValue(r.Context(), httplog.LogContext, logs)))
|
||||
r.Body = dup
|
||||
logDetails(r, lrw, timer, logs)
|
||||
})
|
||||
}
|
||||
|
||||
func dumpRequest(r *http.Request) string {
|
||||
reqContent, err := httputil.DumpRequest(r, true)
|
||||
if err != nil {
|
||||
return err.Error()
|
||||
} else {
|
||||
return string(reqContent)
|
||||
}
|
||||
}
|
||||
|
||||
func logBrief(r *http.Request, code int, timer *utils.ElapsedTimer, logs *httplog.LogCollector) {
|
||||
var buf bytes.Buffer
|
||||
duration := timer.Duration()
|
||||
buf.WriteString(fmt.Sprintf("%d - %s - %s - %s - %s",
|
||||
code, r.RequestURI, httpx.GetRemoteAddr(r), r.UserAgent(), timex.ReprOfDuration(duration)))
|
||||
if duration > slowThreshold {
|
||||
logx.Slowf("[HTTP] %d - %s - %s - %s - slowcall(%s)",
|
||||
code, r.RequestURI, httpx.GetRemoteAddr(r), r.UserAgent(), timex.ReprOfDuration(duration))
|
||||
}
|
||||
|
||||
ok := isOkResponse(code)
|
||||
if !ok {
|
||||
buf.WriteString(fmt.Sprintf("\n%s", dumpRequest(r)))
|
||||
}
|
||||
|
||||
body := logs.Flush()
|
||||
if len(body) > 0 {
|
||||
buf.WriteString(fmt.Sprintf("\n%s", body))
|
||||
}
|
||||
|
||||
if ok {
|
||||
logx.Info(buf.String())
|
||||
} else {
|
||||
logx.Error(buf.String())
|
||||
}
|
||||
}
|
||||
|
||||
func logDetails(r *http.Request, response *DetailLoggedResponseWriter, timer *utils.ElapsedTimer,
|
||||
logs *httplog.LogCollector) {
|
||||
var buf bytes.Buffer
|
||||
duration := timer.Duration()
|
||||
buf.WriteString(fmt.Sprintf("%d - %s - %s\n=> %s\n",
|
||||
response.writer.code, r.RemoteAddr, timex.ReprOfDuration(duration), dumpRequest(r)))
|
||||
if duration > slowThreshold {
|
||||
logx.Slowf("[HTTP] %d - %s - slowcall(%s)\n=> %s\n",
|
||||
response.writer.code, r.RemoteAddr, timex.ReprOfDuration(duration), dumpRequest(r))
|
||||
}
|
||||
|
||||
body := logs.Flush()
|
||||
if len(body) > 0 {
|
||||
buf.WriteString(fmt.Sprintf("%s\n", body))
|
||||
}
|
||||
|
||||
respBuf := response.buf.Bytes()
|
||||
if len(respBuf) > 0 {
|
||||
buf.WriteString(fmt.Sprintf("<= %s", respBuf))
|
||||
}
|
||||
|
||||
logx.Info(buf.String())
|
||||
}
|
||||
|
||||
func isOkResponse(code int) bool {
|
||||
// not server error
|
||||
return code < http.StatusInternalServerError
|
||||
}
|
||||
@@ -1,74 +0,0 @@
|
||||
package httphandler
|
||||
|
||||
import (
|
||||
"io/ioutil"
|
||||
"log"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"zero/core/httplog"
|
||||
)
|
||||
|
||||
func init() {
|
||||
log.SetOutput(ioutil.Discard)
|
||||
}
|
||||
|
||||
func TestLogHandler(t *testing.T) {
|
||||
handlers := []func(handler http.Handler) http.Handler{
|
||||
LogHandler,
|
||||
DetailedLogHandler,
|
||||
}
|
||||
|
||||
for _, logHandler := range handlers {
|
||||
req := httptest.NewRequest(http.MethodGet, "http://localhost", nil)
|
||||
handler := logHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
r.Context().Value(httplog.LogContext).(*httplog.LogCollector).Append("anything")
|
||||
w.Header().Set("X-Test", "test")
|
||||
w.WriteHeader(http.StatusServiceUnavailable)
|
||||
_, err := w.Write([]byte("content"))
|
||||
assert.Nil(t, err)
|
||||
}))
|
||||
|
||||
resp := httptest.NewRecorder()
|
||||
handler.ServeHTTP(resp, req)
|
||||
assert.Equal(t, http.StatusServiceUnavailable, resp.Code)
|
||||
assert.Equal(t, "test", resp.Header().Get("X-Test"))
|
||||
assert.Equal(t, "content", resp.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestLogHandlerSlow(t *testing.T) {
|
||||
handlers := []func(handler http.Handler) http.Handler{
|
||||
LogHandler,
|
||||
DetailedLogHandler,
|
||||
}
|
||||
|
||||
for _, logHandler := range handlers {
|
||||
req := httptest.NewRequest(http.MethodGet, "http://localhost", nil)
|
||||
handler := logHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
time.Sleep(slowThreshold + time.Millisecond*50)
|
||||
}))
|
||||
|
||||
resp := httptest.NewRecorder()
|
||||
handler.ServeHTTP(resp, req)
|
||||
assert.Equal(t, http.StatusOK, resp.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkLogHandler(b *testing.B) {
|
||||
b.ReportAllocs()
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "http://localhost", nil)
|
||||
handler := LogHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
resp := httptest.NewRecorder()
|
||||
handler.ServeHTTP(resp, req)
|
||||
}
|
||||
}
|
||||
@@ -1,27 +0,0 @@
|
||||
package httphandler
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"zero/core/httplog"
|
||||
)
|
||||
|
||||
func MaxBytesHandler(n int64) func(http.Handler) http.Handler {
|
||||
if n <= 0 {
|
||||
return func(next http.Handler) http.Handler {
|
||||
return next
|
||||
}
|
||||
}
|
||||
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.ContentLength > n {
|
||||
httplog.Errorf(r, "request entity too large, limit is %d, but got %d, rejected with code %d",
|
||||
n, r.ContentLength, http.StatusRequestEntityTooLarge)
|
||||
w.WriteHeader(http.StatusRequestEntityTooLarge)
|
||||
} else {
|
||||
next.ServeHTTP(w, r)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -1,37 +0,0 @@
|
||||
package httphandler
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestMaxBytesHandler(t *testing.T) {
|
||||
maxb := MaxBytesHandler(10)
|
||||
handler := maxb(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}))
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "http://localhost",
|
||||
bytes.NewBufferString("123456789012345"))
|
||||
resp := httptest.NewRecorder()
|
||||
handler.ServeHTTP(resp, req)
|
||||
assert.Equal(t, http.StatusRequestEntityTooLarge, resp.Code)
|
||||
|
||||
req = httptest.NewRequest(http.MethodPost, "http://localhost", bytes.NewBufferString("12345"))
|
||||
resp = httptest.NewRecorder()
|
||||
handler.ServeHTTP(resp, req)
|
||||
assert.Equal(t, http.StatusOK, resp.Code)
|
||||
}
|
||||
|
||||
func TestMaxBytesHandlerNoLimit(t *testing.T) {
|
||||
maxb := MaxBytesHandler(-1)
|
||||
handler := maxb(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}))
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "http://localhost",
|
||||
bytes.NewBufferString("123456789012345"))
|
||||
resp := httptest.NewRecorder()
|
||||
handler.ServeHTTP(resp, req)
|
||||
assert.Equal(t, http.StatusOK, resp.Code)
|
||||
}
|
||||
@@ -1,37 +0,0 @@
|
||||
package httphandler
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"zero/core/httplog"
|
||||
"zero/core/logx"
|
||||
"zero/core/syncx"
|
||||
)
|
||||
|
||||
func MaxConns(n int) func(http.Handler) http.Handler {
|
||||
if n <= 0 {
|
||||
return func(next http.Handler) http.Handler {
|
||||
return next
|
||||
}
|
||||
}
|
||||
|
||||
return func(next http.Handler) http.Handler {
|
||||
latchLimiter := syncx.NewLimit(n)
|
||||
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if latchLimiter.TryBorrow() {
|
||||
defer func() {
|
||||
if err := latchLimiter.Return(); err != nil {
|
||||
logx.Error(err)
|
||||
}
|
||||
}()
|
||||
|
||||
next.ServeHTTP(w, r)
|
||||
} else {
|
||||
httplog.Errorf(r, "Concurrent connections over %d, rejected with code %d",
|
||||
n, http.StatusServiceUnavailable)
|
||||
w.WriteHeader(http.StatusServiceUnavailable)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -1,80 +0,0 @@
|
||||
package httphandler
|
||||
|
||||
import (
|
||||
"io/ioutil"
|
||||
"log"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"zero/core/lang"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
const conns = 4
|
||||
|
||||
func init() {
|
||||
log.SetOutput(ioutil.Discard)
|
||||
}
|
||||
|
||||
func TestMaxConnsHandler(t *testing.T) {
|
||||
var waitGroup sync.WaitGroup
|
||||
waitGroup.Add(conns)
|
||||
done := make(chan lang.PlaceholderType)
|
||||
defer close(done)
|
||||
|
||||
maxConns := MaxConns(conns)
|
||||
handler := maxConns(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
waitGroup.Done()
|
||||
<-done
|
||||
}))
|
||||
|
||||
for i := 0; i < conns; i++ {
|
||||
go func() {
|
||||
req := httptest.NewRequest(http.MethodGet, "http://localhost", nil)
|
||||
handler.ServeHTTP(httptest.NewRecorder(), req)
|
||||
}()
|
||||
}
|
||||
|
||||
waitGroup.Wait()
|
||||
req := httptest.NewRequest(http.MethodGet, "http://localhost", nil)
|
||||
resp := httptest.NewRecorder()
|
||||
handler.ServeHTTP(resp, req)
|
||||
assert.Equal(t, http.StatusServiceUnavailable, resp.Code)
|
||||
}
|
||||
|
||||
func TestWithoutMaxConnsHandler(t *testing.T) {
|
||||
const (
|
||||
key = "block"
|
||||
value = "1"
|
||||
)
|
||||
var waitGroup sync.WaitGroup
|
||||
waitGroup.Add(conns)
|
||||
done := make(chan lang.PlaceholderType)
|
||||
defer close(done)
|
||||
|
||||
maxConns := MaxConns(0)
|
||||
handler := maxConns(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
val := r.Header.Get(key)
|
||||
if val == value {
|
||||
waitGroup.Done()
|
||||
<-done
|
||||
}
|
||||
}))
|
||||
|
||||
for i := 0; i < conns; i++ {
|
||||
go func() {
|
||||
req := httptest.NewRequest(http.MethodGet, "http://localhost", nil)
|
||||
req.Header.Set(key, value)
|
||||
handler.ServeHTTP(httptest.NewRecorder(), req)
|
||||
}()
|
||||
}
|
||||
|
||||
waitGroup.Wait()
|
||||
req := httptest.NewRequest(http.MethodGet, "http://localhost", nil)
|
||||
resp := httptest.NewRecorder()
|
||||
handler.ServeHTTP(resp, req)
|
||||
assert.Equal(t, http.StatusOK, resp.Code)
|
||||
}
|
||||
@@ -1,23 +0,0 @@
|
||||
package httphandler
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"zero/core/stat"
|
||||
"zero/core/timex"
|
||||
)
|
||||
|
||||
func MetricHandler(metrics *stat.Metrics) func(http.Handler) http.Handler {
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
startTime := timex.Now()
|
||||
defer func() {
|
||||
metrics.Add(stat.Task{
|
||||
Duration: timex.Since(startTime),
|
||||
})
|
||||
}()
|
||||
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -1,24 +0,0 @@
|
||||
package httphandler
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"zero/core/stat"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestMetricHandler(t *testing.T) {
|
||||
metrics := stat.NewMetrics("unit-test")
|
||||
metricHandler := MetricHandler(metrics)
|
||||
handler := metricHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "http://localhost", nil)
|
||||
resp := httptest.NewRecorder()
|
||||
handler.ServeHTTP(resp, req)
|
||||
assert.Equal(t, http.StatusOK, resp.Code)
|
||||
}
|
||||
@@ -1,47 +0,0 @@
|
||||
package httphandler
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"zero/core/httphandler/internal"
|
||||
"zero/core/metric"
|
||||
"zero/core/timex"
|
||||
)
|
||||
|
||||
const serverNamespace = "http_server"
|
||||
|
||||
var (
|
||||
metricServerReqDur = metric.NewHistogramVec(&metric.HistogramVecOpts{
|
||||
Namespace: serverNamespace,
|
||||
Subsystem: "requests",
|
||||
Name: "duration_ms",
|
||||
Help: "http server requests duration(ms).",
|
||||
Labels: []string{"path"},
|
||||
Buckets: []float64{5, 10, 25, 50, 100, 250, 500, 1000},
|
||||
})
|
||||
|
||||
metricServerReqCodeTotal = metric.NewCounterVec(&metric.CounterVecOpts{
|
||||
Namespace: serverNamespace,
|
||||
Subsystem: "requests",
|
||||
Name: "code_total",
|
||||
Help: "http server requests error count.",
|
||||
Labels: []string{"path", "code"},
|
||||
})
|
||||
)
|
||||
|
||||
func PromMetricHandler(path string) func(http.Handler) http.Handler {
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
startTime := timex.Now()
|
||||
cw := &internal.WithCodeResponseWriter{Writer: w}
|
||||
defer func() {
|
||||
metricServerReqDur.Observe(int64(timex.Since(startTime)/time.Millisecond), path)
|
||||
metricServerReqCodeTotal.Inc(path, strconv.Itoa(cw.Code))
|
||||
}()
|
||||
|
||||
next.ServeHTTP(cw, r)
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -1,21 +0,0 @@
|
||||
package httphandler
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestPromMetricHandler(t *testing.T) {
|
||||
promMetricHandler := PromMetricHandler("/user/login")
|
||||
handler := promMetricHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "http://localhost", nil)
|
||||
resp := httptest.NewRecorder()
|
||||
handler.ServeHTTP(resp, req)
|
||||
assert.Equal(t, http.StatusOK, resp.Code)
|
||||
}
|
||||
@@ -1,22 +0,0 @@
|
||||
package httphandler
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"runtime/debug"
|
||||
|
||||
"zero/core/httplog"
|
||||
)
|
||||
|
||||
func RecoverHandler(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
defer func() {
|
||||
if result := recover(); result != nil {
|
||||
httplog.Error(r, fmt.Sprintf("%v\n%s", result, debug.Stack()))
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
}
|
||||
}()
|
||||
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
@@ -1,36 +0,0 @@
|
||||
package httphandler
|
||||
|
||||
import (
|
||||
"io/ioutil"
|
||||
"log"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func init() {
|
||||
log.SetOutput(ioutil.Discard)
|
||||
}
|
||||
|
||||
func TestWithPanic(t *testing.T) {
|
||||
handler := RecoverHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
panic("whatever")
|
||||
}))
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "http://localhost", nil)
|
||||
resp := httptest.NewRecorder()
|
||||
handler.ServeHTTP(resp, req)
|
||||
assert.Equal(t, http.StatusInternalServerError, resp.Code)
|
||||
}
|
||||
|
||||
func TestWithoutPanic(t *testing.T) {
|
||||
handler := RecoverHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
}))
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "http://localhost", nil)
|
||||
resp := httptest.NewRecorder()
|
||||
handler.ServeHTTP(resp, req)
|
||||
assert.Equal(t, http.StatusOK, resp.Code)
|
||||
}
|
||||
@@ -1,63 +0,0 @@
|
||||
package httphandler
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"sync"
|
||||
|
||||
"zero/core/httphandler/internal"
|
||||
"zero/core/httpx"
|
||||
"zero/core/load"
|
||||
"zero/core/logx"
|
||||
"zero/core/stat"
|
||||
)
|
||||
|
||||
const serviceType = "api"
|
||||
|
||||
var (
|
||||
sheddingStat *load.SheddingStat
|
||||
lock sync.Mutex
|
||||
)
|
||||
|
||||
func SheddingHandler(shedder load.Shedder, metrics *stat.Metrics) func(http.Handler) http.Handler {
|
||||
if shedder == nil {
|
||||
return func(next http.Handler) http.Handler {
|
||||
return next
|
||||
}
|
||||
}
|
||||
|
||||
ensureSheddingStat()
|
||||
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
sheddingStat.IncrementTotal()
|
||||
promise, err := shedder.Allow()
|
||||
if err != nil {
|
||||
metrics.AddDrop()
|
||||
sheddingStat.IncrementDrop()
|
||||
logx.Errorf("[http] dropped, %s - %s - %s",
|
||||
r.RequestURI, httpx.GetRemoteAddr(r), r.UserAgent())
|
||||
w.WriteHeader(http.StatusServiceUnavailable)
|
||||
return
|
||||
}
|
||||
|
||||
cw := &internal.WithCodeResponseWriter{Writer: w}
|
||||
defer func() {
|
||||
if cw.Code == http.StatusServiceUnavailable {
|
||||
promise.Fail()
|
||||
} else {
|
||||
sheddingStat.IncrementPass()
|
||||
promise.Pass()
|
||||
}
|
||||
}()
|
||||
next.ServeHTTP(cw, r)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func ensureSheddingStat() {
|
||||
lock.Lock()
|
||||
if sheddingStat == nil {
|
||||
sheddingStat = load.NewSheddingStat(serviceType)
|
||||
}
|
||||
lock.Unlock()
|
||||
}
|
||||
@@ -1,105 +0,0 @@
|
||||
package httphandler
|
||||
|
||||
import (
|
||||
"io/ioutil"
|
||||
"log"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"zero/core/load"
|
||||
"zero/core/stat"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func init() {
|
||||
log.SetOutput(ioutil.Discard)
|
||||
}
|
||||
|
||||
func TestSheddingHandlerAccept(t *testing.T) {
|
||||
metrics := stat.NewMetrics("unit-test")
|
||||
shedder := mockShedder{
|
||||
allow: true,
|
||||
}
|
||||
sheddingHandler := SheddingHandler(shedder, metrics)
|
||||
handler := sheddingHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("X-Test", "test")
|
||||
_, err := w.Write([]byte("content"))
|
||||
assert.Nil(t, err)
|
||||
}))
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "http://localhost", nil)
|
||||
req.Header.Set("X-Test", "test")
|
||||
resp := httptest.NewRecorder()
|
||||
handler.ServeHTTP(resp, req)
|
||||
assert.Equal(t, http.StatusOK, resp.Code)
|
||||
assert.Equal(t, "test", resp.Header().Get("X-Test"))
|
||||
assert.Equal(t, "content", resp.Body.String())
|
||||
}
|
||||
|
||||
func TestSheddingHandlerFail(t *testing.T) {
|
||||
metrics := stat.NewMetrics("unit-test")
|
||||
shedder := mockShedder{
|
||||
allow: true,
|
||||
}
|
||||
sheddingHandler := SheddingHandler(shedder, metrics)
|
||||
handler := sheddingHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusServiceUnavailable)
|
||||
}))
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "http://localhost", nil)
|
||||
resp := httptest.NewRecorder()
|
||||
handler.ServeHTTP(resp, req)
|
||||
assert.Equal(t, http.StatusServiceUnavailable, resp.Code)
|
||||
}
|
||||
|
||||
func TestSheddingHandlerReject(t *testing.T) {
|
||||
metrics := stat.NewMetrics("unit-test")
|
||||
shedder := mockShedder{
|
||||
allow: false,
|
||||
}
|
||||
sheddingHandler := SheddingHandler(shedder, metrics)
|
||||
handler := sheddingHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "http://localhost", nil)
|
||||
resp := httptest.NewRecorder()
|
||||
handler.ServeHTTP(resp, req)
|
||||
assert.Equal(t, http.StatusServiceUnavailable, resp.Code)
|
||||
}
|
||||
|
||||
func TestSheddingHandlerNoShedding(t *testing.T) {
|
||||
metrics := stat.NewMetrics("unit-test")
|
||||
sheddingHandler := SheddingHandler(nil, metrics)
|
||||
handler := sheddingHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "http://localhost", nil)
|
||||
resp := httptest.NewRecorder()
|
||||
handler.ServeHTTP(resp, req)
|
||||
assert.Equal(t, http.StatusOK, resp.Code)
|
||||
}
|
||||
|
||||
type mockShedder struct {
|
||||
allow bool
|
||||
}
|
||||
|
||||
func (s mockShedder) Allow() (load.Promise, error) {
|
||||
if s.allow {
|
||||
return mockPromise{}, nil
|
||||
} else {
|
||||
return nil, load.ErrServiceOverloaded
|
||||
}
|
||||
}
|
||||
|
||||
type mockPromise struct {
|
||||
}
|
||||
|
||||
func (p mockPromise) Pass() {
|
||||
}
|
||||
|
||||
func (p mockPromise) Fail() {
|
||||
}
|
||||
@@ -1,18 +0,0 @@
|
||||
package httphandler
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"time"
|
||||
)
|
||||
|
||||
const reason = "Request Timeout"
|
||||
|
||||
func TimeoutHandler(duration time.Duration) func(http.Handler) http.Handler {
|
||||
return func(next http.Handler) http.Handler {
|
||||
if duration > 0 {
|
||||
return http.TimeoutHandler(next, duration, reason)
|
||||
} else {
|
||||
return next
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,52 +0,0 @@
|
||||
package httphandler
|
||||
|
||||
import (
|
||||
"io/ioutil"
|
||||
"log"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func init() {
|
||||
log.SetOutput(ioutil.Discard)
|
||||
}
|
||||
|
||||
func TestTimeout(t *testing.T) {
|
||||
timeoutHandler := TimeoutHandler(time.Millisecond)
|
||||
handler := timeoutHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
time.Sleep(time.Minute)
|
||||
}))
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "http://localhost", nil)
|
||||
resp := httptest.NewRecorder()
|
||||
handler.ServeHTTP(resp, req)
|
||||
assert.Equal(t, http.StatusServiceUnavailable, resp.Code)
|
||||
}
|
||||
|
||||
func TestWithinTimeout(t *testing.T) {
|
||||
timeoutHandler := TimeoutHandler(time.Second)
|
||||
handler := timeoutHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
time.Sleep(time.Millisecond)
|
||||
}))
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "http://localhost", nil)
|
||||
resp := httptest.NewRecorder()
|
||||
handler.ServeHTTP(resp, req)
|
||||
assert.Equal(t, http.StatusOK, resp.Code)
|
||||
}
|
||||
|
||||
func TestWithoutTimeout(t *testing.T) {
|
||||
timeoutHandler := TimeoutHandler(0)
|
||||
handler := timeoutHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
}))
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "http://localhost", nil)
|
||||
resp := httptest.NewRecorder()
|
||||
handler.ServeHTTP(resp, req)
|
||||
assert.Equal(t, http.StatusOK, resp.Code)
|
||||
}
|
||||
@@ -1,25 +0,0 @@
|
||||
package httphandler
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"zero/core/logx"
|
||||
"zero/core/sysx"
|
||||
"zero/core/trace"
|
||||
)
|
||||
|
||||
func TracingHandler(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
carrier, err := trace.Extract(trace.HttpFormat, r.Header)
|
||||
// ErrInvalidCarrier means no trace id was set in http header
|
||||
if err != nil && err != trace.ErrInvalidCarrier {
|
||||
logx.Error(err)
|
||||
}
|
||||
|
||||
ctx, span := trace.StartServerSpan(r.Context(), carrier, sysx.Hostname(), r.RequestURI)
|
||||
defer span.Finish()
|
||||
r = r.WithContext(ctx)
|
||||
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
@@ -1,25 +0,0 @@
|
||||
package httphandler
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"zero/core/trace/tracespec"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestTracingHandler(t *testing.T) {
|
||||
req := httptest.NewRequest(http.MethodGet, "http://localhost", nil)
|
||||
req.Header.Set("X-Trace-ID", "theid")
|
||||
handler := TracingHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
span, ok := r.Context().Value(tracespec.TracingKey).(tracespec.Trace)
|
||||
assert.True(t, ok)
|
||||
assert.Equal(t, "theid", span.TraceId())
|
||||
}))
|
||||
|
||||
resp := httptest.NewRecorder()
|
||||
handler.ServeHTTP(resp, req)
|
||||
assert.Equal(t, http.StatusOK, resp.Code)
|
||||
}
|
||||
@@ -1,84 +0,0 @@
|
||||
package httplog
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"sync"
|
||||
|
||||
"zero/core/httpx"
|
||||
"zero/core/logx"
|
||||
)
|
||||
|
||||
const LogContext = "request_logs"
|
||||
|
||||
type LogCollector struct {
|
||||
Messages []string
|
||||
lock sync.Mutex
|
||||
}
|
||||
|
||||
func (lc *LogCollector) Append(msg string) {
|
||||
lc.lock.Lock()
|
||||
lc.Messages = append(lc.Messages, msg)
|
||||
lc.lock.Unlock()
|
||||
}
|
||||
|
||||
func (lc *LogCollector) Flush() string {
|
||||
var buffer bytes.Buffer
|
||||
|
||||
start := true
|
||||
for _, message := range lc.takeAll() {
|
||||
if start {
|
||||
start = false
|
||||
} else {
|
||||
buffer.WriteByte('\n')
|
||||
}
|
||||
buffer.WriteString(message)
|
||||
}
|
||||
|
||||
return buffer.String()
|
||||
}
|
||||
|
||||
func (lc *LogCollector) takeAll() []string {
|
||||
lc.lock.Lock()
|
||||
messages := lc.Messages
|
||||
lc.Messages = nil
|
||||
lc.lock.Unlock()
|
||||
|
||||
return messages
|
||||
}
|
||||
|
||||
func Error(r *http.Request, v ...interface{}) {
|
||||
logx.ErrorCaller(1, format(r, v...))
|
||||
}
|
||||
|
||||
func Errorf(r *http.Request, format string, v ...interface{}) {
|
||||
logx.ErrorCaller(1, formatf(r, format, v...))
|
||||
}
|
||||
|
||||
func Info(r *http.Request, v ...interface{}) {
|
||||
appendLog(r, format(r, v...))
|
||||
}
|
||||
|
||||
func Infof(r *http.Request, format string, v ...interface{}) {
|
||||
appendLog(r, formatf(r, format, v...))
|
||||
}
|
||||
|
||||
func appendLog(r *http.Request, message string) {
|
||||
logs := r.Context().Value(LogContext)
|
||||
if logs != nil {
|
||||
logs.(*LogCollector).Append(message)
|
||||
}
|
||||
}
|
||||
|
||||
func format(r *http.Request, v ...interface{}) string {
|
||||
return formatWithReq(r, fmt.Sprint(v...))
|
||||
}
|
||||
|
||||
func formatf(r *http.Request, format string, v ...interface{}) string {
|
||||
return formatWithReq(r, fmt.Sprintf(format, v...))
|
||||
}
|
||||
|
||||
func formatWithReq(r *http.Request, v string) string {
|
||||
return fmt.Sprintf("(%s - %s) %s", r.RequestURI, httpx.GetRemoteAddr(r), v)
|
||||
}
|
||||
@@ -1,38 +0,0 @@
|
||||
package httplog
|
||||
|
||||
import (
|
||||
"context"
|
||||
"log"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestInfo(t *testing.T) {
|
||||
collector := new(LogCollector)
|
||||
req := httptest.NewRequest(http.MethodGet, "http://localhost", nil)
|
||||
req = req.WithContext(context.WithValue(req.Context(), LogContext, collector))
|
||||
Info(req, "first")
|
||||
Infof(req, "second %s", "third")
|
||||
val := collector.Flush()
|
||||
assert.True(t, strings.Contains(val, "first"))
|
||||
assert.True(t, strings.Contains(val, "second"))
|
||||
assert.True(t, strings.Contains(val, "third"))
|
||||
assert.True(t, strings.Contains(val, "\n"))
|
||||
}
|
||||
|
||||
func TestError(t *testing.T) {
|
||||
var writer strings.Builder
|
||||
log.SetOutput(&writer)
|
||||
req := httptest.NewRequest(http.MethodGet, "http://localhost", nil)
|
||||
Error(req, "first")
|
||||
Errorf(req, "second %s", "third")
|
||||
val := writer.String()
|
||||
assert.True(t, strings.Contains(val, "first"))
|
||||
assert.True(t, strings.Contains(val, "second"))
|
||||
assert.True(t, strings.Contains(val, "third"))
|
||||
assert.True(t, strings.Contains(val, "\n"))
|
||||
}
|
||||
@@ -1,115 +0,0 @@
|
||||
package httprouter
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"path"
|
||||
"strings"
|
||||
|
||||
"zero/core/search"
|
||||
)
|
||||
|
||||
const (
|
||||
allowHeader = "Allow"
|
||||
allowMethodSeparator = ", "
|
||||
pathVars = "pathVars"
|
||||
)
|
||||
|
||||
type PatRouter struct {
|
||||
trees map[string]*search.Tree
|
||||
notFound http.Handler
|
||||
}
|
||||
|
||||
func NewPatRouter() Router {
|
||||
return &PatRouter{
|
||||
trees: make(map[string]*search.Tree),
|
||||
}
|
||||
}
|
||||
|
||||
func (pr *PatRouter) Handle(method, reqPath string, handler http.Handler) error {
|
||||
if !validMethod(method) {
|
||||
return ErrInvalidMethod
|
||||
}
|
||||
|
||||
if len(reqPath) == 0 || reqPath[0] != '/' {
|
||||
return ErrInvalidPath
|
||||
}
|
||||
|
||||
cleanPath := path.Clean(reqPath)
|
||||
if tree, ok := pr.trees[method]; ok {
|
||||
return tree.Add(cleanPath, handler)
|
||||
} else {
|
||||
tree = search.NewTree()
|
||||
pr.trees[method] = tree
|
||||
return tree.Add(cleanPath, handler)
|
||||
}
|
||||
}
|
||||
|
||||
func (pr *PatRouter) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
reqPath := path.Clean(r.URL.Path)
|
||||
if tree, ok := pr.trees[r.Method]; ok {
|
||||
if result, ok := tree.Search(reqPath); ok {
|
||||
if len(result.Params) > 0 {
|
||||
r = r.WithContext(context.WithValue(r.Context(), pathVars, result.Params))
|
||||
}
|
||||
result.Item.(http.Handler).ServeHTTP(w, r)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
if allow, ok := pr.methodNotAllowed(r.Method, reqPath); ok {
|
||||
w.Header().Set(allowHeader, allow)
|
||||
w.WriteHeader(http.StatusMethodNotAllowed)
|
||||
} else {
|
||||
pr.handleNotFound(w, r)
|
||||
}
|
||||
}
|
||||
|
||||
func (pr *PatRouter) SetNotFoundHandler(handler http.Handler) {
|
||||
pr.notFound = handler
|
||||
}
|
||||
|
||||
func (pr *PatRouter) handleNotFound(w http.ResponseWriter, r *http.Request) {
|
||||
if pr.notFound != nil {
|
||||
pr.notFound.ServeHTTP(w, r)
|
||||
} else {
|
||||
http.NotFound(w, r)
|
||||
}
|
||||
}
|
||||
|
||||
func (pr *PatRouter) methodNotAllowed(method, path string) (string, bool) {
|
||||
var allows []string
|
||||
|
||||
for treeMethod, tree := range pr.trees {
|
||||
if treeMethod == method {
|
||||
continue
|
||||
}
|
||||
|
||||
_, ok := tree.Search(path)
|
||||
if ok {
|
||||
allows = append(allows, treeMethod)
|
||||
}
|
||||
}
|
||||
|
||||
if len(allows) > 0 {
|
||||
return strings.Join(allows, allowMethodSeparator), true
|
||||
} else {
|
||||
return "", false
|
||||
}
|
||||
}
|
||||
|
||||
func Vars(r *http.Request) map[string]string {
|
||||
vars, ok := r.Context().Value(pathVars).(map[string]string)
|
||||
if ok {
|
||||
return vars
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func validMethod(method string) bool {
|
||||
return method == http.MethodDelete || method == http.MethodGet ||
|
||||
method == http.MethodHead || method == http.MethodOptions ||
|
||||
method == http.MethodPatch || method == http.MethodPost ||
|
||||
method == http.MethodPut
|
||||
}
|
||||
@@ -1,120 +0,0 @@
|
||||
package httprouter
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
type mockedResponseWriter struct {
|
||||
code int
|
||||
}
|
||||
|
||||
func (m *mockedResponseWriter) Header() http.Header {
|
||||
return http.Header{}
|
||||
}
|
||||
|
||||
func (m *mockedResponseWriter) Write(p []byte) (int, error) {
|
||||
return len(p), nil
|
||||
}
|
||||
|
||||
func (m *mockedResponseWriter) WriteHeader(code int) {
|
||||
m.code = code
|
||||
}
|
||||
|
||||
func TestPatRouterHandleErrors(t *testing.T) {
|
||||
tests := []struct {
|
||||
method string
|
||||
path string
|
||||
err error
|
||||
}{
|
||||
{"FAKE", "", ErrInvalidMethod},
|
||||
{"GET", "", ErrInvalidPath},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
t.Run(test.method, func(t *testing.T) {
|
||||
router := NewPatRouter()
|
||||
err := router.Handle(test.method, test.path, nil)
|
||||
assert.Error(t, ErrInvalidMethod, err)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestPatRouterNotFound(t *testing.T) {
|
||||
var notFound bool
|
||||
router := NewPatRouter()
|
||||
router.SetNotFoundHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
notFound = true
|
||||
}))
|
||||
router.Handle(http.MethodGet, "/a/b", nil)
|
||||
r, _ := http.NewRequest(http.MethodGet, "/b/c", nil)
|
||||
w := new(mockedResponseWriter)
|
||||
router.ServeHTTP(w, r)
|
||||
assert.True(t, notFound)
|
||||
}
|
||||
|
||||
func TestPatRouter(t *testing.T) {
|
||||
tests := []struct {
|
||||
method string
|
||||
path string
|
||||
expect bool
|
||||
code int
|
||||
err error
|
||||
}{
|
||||
// we don't explicitly set status code, framework will do it.
|
||||
{http.MethodGet, "/a/b", true, 0, nil},
|
||||
{http.MethodGet, "/a/b/", true, 0, nil},
|
||||
{http.MethodGet, "/a/b?a=b", true, 0, nil},
|
||||
{http.MethodGet, "/a/b/?a=b", true, 0, nil},
|
||||
{http.MethodGet, "/a/b/c?a=b", true, 0, nil},
|
||||
{http.MethodGet, "/b/d", false, http.StatusNotFound, nil},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
t.Run(test.method+":"+test.path, func(t *testing.T) {
|
||||
routed := false
|
||||
router := NewPatRouter()
|
||||
err := router.Handle(test.method, "/a/:b", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
routed = true
|
||||
assert.Equal(t, 1, len(Vars(r)))
|
||||
}))
|
||||
assert.Nil(t, err)
|
||||
err = router.Handle(test.method, "/a/b/c", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
routed = true
|
||||
assert.Nil(t, Vars(r))
|
||||
}))
|
||||
assert.Nil(t, err)
|
||||
err = router.Handle(test.method, "/b/c", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
routed = true
|
||||
}))
|
||||
assert.Nil(t, err)
|
||||
|
||||
w := new(mockedResponseWriter)
|
||||
r, _ := http.NewRequest(test.method, test.path, nil)
|
||||
router.ServeHTTP(w, r)
|
||||
assert.Equal(t, test.expect, routed)
|
||||
assert.Equal(t, test.code, w.code)
|
||||
|
||||
if test.code == 0 {
|
||||
r, _ = http.NewRequest(http.MethodPut, test.path, nil)
|
||||
router.ServeHTTP(w, r)
|
||||
assert.Equal(t, http.StatusMethodNotAllowed, w.code)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkPatRouter(b *testing.B) {
|
||||
b.ReportAllocs()
|
||||
|
||||
router := NewPatRouter()
|
||||
router.Handle(http.MethodGet, "/api/:user/:name", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
}))
|
||||
w := &mockedResponseWriter{}
|
||||
r, _ := http.NewRequest(http.MethodGet, "/api/a/b", nil)
|
||||
for i := 0; i < b.N; i++ {
|
||||
router.ServeHTTP(w, r)
|
||||
}
|
||||
}
|
||||
@@ -1,24 +0,0 @@
|
||||
package httprouter
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrInvalidMethod = errors.New("not a valid http method")
|
||||
ErrInvalidPath = errors.New("path must begin with '/'")
|
||||
)
|
||||
|
||||
type (
|
||||
Route struct {
|
||||
Path string
|
||||
Handler http.HandlerFunc
|
||||
}
|
||||
|
||||
Router interface {
|
||||
http.Handler
|
||||
Handle(method string, path string, handler http.Handler) error
|
||||
SetNotFoundHandler(handler http.Handler)
|
||||
}
|
||||
)
|
||||
@@ -1,122 +0,0 @@
|
||||
package httpsecurity
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"zero/core/timex"
|
||||
|
||||
"github.com/dgrijalva/jwt-go"
|
||||
"github.com/dgrijalva/jwt-go/request"
|
||||
)
|
||||
|
||||
const claimHistoryResetDuration = time.Hour * 24
|
||||
|
||||
type (
|
||||
ParseOption func(parser *TokenParser)
|
||||
|
||||
TokenParser struct {
|
||||
resetTime time.Duration
|
||||
resetDuration time.Duration
|
||||
history sync.Map
|
||||
}
|
||||
)
|
||||
|
||||
func NewTokenParser(opts ...ParseOption) *TokenParser {
|
||||
parser := &TokenParser{
|
||||
resetTime: timex.Now(),
|
||||
resetDuration: claimHistoryResetDuration,
|
||||
}
|
||||
|
||||
for _, opt := range opts {
|
||||
opt(parser)
|
||||
}
|
||||
|
||||
return parser
|
||||
}
|
||||
|
||||
func (tp *TokenParser) ParseToken(r *http.Request, secret, prevSecret string) (*jwt.Token, error) {
|
||||
var token *jwt.Token
|
||||
var err error
|
||||
|
||||
if len(prevSecret) > 0 {
|
||||
count := tp.loadCount(secret)
|
||||
prevCount := tp.loadCount(prevSecret)
|
||||
|
||||
var first, second string
|
||||
if count > prevCount {
|
||||
first = secret
|
||||
second = prevSecret
|
||||
} else {
|
||||
first = prevSecret
|
||||
second = secret
|
||||
}
|
||||
|
||||
token, err = tp.doParseToken(r, first)
|
||||
if err != nil {
|
||||
token, err = tp.doParseToken(r, second)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
} else {
|
||||
tp.incrementCount(second)
|
||||
}
|
||||
} else {
|
||||
tp.incrementCount(first)
|
||||
}
|
||||
} else {
|
||||
token, err = tp.doParseToken(r, secret)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
return token, nil
|
||||
}
|
||||
|
||||
func (tp *TokenParser) doParseToken(r *http.Request, secret string) (*jwt.Token, error) {
|
||||
return request.ParseFromRequest(r, request.AuthorizationHeaderExtractor,
|
||||
func(token *jwt.Token) (interface{}, error) {
|
||||
return []byte(secret), nil
|
||||
}, request.WithParser(newParser()))
|
||||
}
|
||||
|
||||
func (tp *TokenParser) incrementCount(secret string) {
|
||||
now := timex.Now()
|
||||
if tp.resetTime+tp.resetDuration < now {
|
||||
tp.history.Range(func(key, value interface{}) bool {
|
||||
tp.history.Delete(key)
|
||||
return true
|
||||
})
|
||||
}
|
||||
|
||||
value, ok := tp.history.Load(secret)
|
||||
if ok {
|
||||
atomic.AddUint64(value.(*uint64), 1)
|
||||
} else {
|
||||
var count uint64 = 1
|
||||
tp.history.Store(secret, &count)
|
||||
}
|
||||
}
|
||||
|
||||
func (tp *TokenParser) loadCount(secret string) uint64 {
|
||||
value, ok := tp.history.Load(secret)
|
||||
if ok {
|
||||
return *value.(*uint64)
|
||||
}
|
||||
|
||||
return 0
|
||||
}
|
||||
|
||||
func WithResetDuration(duration time.Duration) ParseOption {
|
||||
return func(parser *TokenParser) {
|
||||
parser.resetDuration = duration
|
||||
}
|
||||
}
|
||||
|
||||
func newParser() *jwt.Parser {
|
||||
return &jwt.Parser{
|
||||
UseJSONNumber: true,
|
||||
}
|
||||
}
|
||||
@@ -1,87 +0,0 @@
|
||||
package httpsecurity
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/dgrijalva/jwt-go"
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"zero/core/timex"
|
||||
)
|
||||
|
||||
func TestTokenParser(t *testing.T) {
|
||||
const (
|
||||
key = "14F17379-EB8F-411B-8F12-6929002DCA76"
|
||||
prevKey = "B63F477D-BBA3-4E52-96D3-C0034C27694A"
|
||||
)
|
||||
keys := []struct {
|
||||
key string
|
||||
prevKey string
|
||||
}{
|
||||
{
|
||||
key,
|
||||
prevKey,
|
||||
},
|
||||
{
|
||||
key,
|
||||
"",
|
||||
},
|
||||
}
|
||||
|
||||
for _, pair := range keys {
|
||||
req := httptest.NewRequest(http.MethodGet, "http://localhost", nil)
|
||||
token, err := buildToken(key, map[string]interface{}{
|
||||
"key": "value",
|
||||
}, 3600)
|
||||
assert.Nil(t, err)
|
||||
req.Header.Set("Authorization", "Bearer "+token)
|
||||
|
||||
parser := NewTokenParser(WithResetDuration(time.Minute))
|
||||
tok, err := parser.ParseToken(req, pair.key, pair.prevKey)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, "value", tok.Claims.(jwt.MapClaims)["key"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestTokenParser_Expired(t *testing.T) {
|
||||
const (
|
||||
key = "14F17379-EB8F-411B-8F12-6929002DCA76"
|
||||
prevKey = "B63F477D-BBA3-4E52-96D3-C0034C27694A"
|
||||
)
|
||||
req := httptest.NewRequest(http.MethodGet, "http://localhost", nil)
|
||||
token, err := buildToken(key, map[string]interface{}{
|
||||
"key": "value",
|
||||
}, 3600)
|
||||
assert.Nil(t, err)
|
||||
req.Header.Set("Authorization", "Bearer "+token)
|
||||
|
||||
parser := NewTokenParser(WithResetDuration(time.Second))
|
||||
tok, err := parser.ParseToken(req, key, prevKey)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, "value", tok.Claims.(jwt.MapClaims)["key"])
|
||||
tok, err = parser.ParseToken(req, key, prevKey)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, "value", tok.Claims.(jwt.MapClaims)["key"])
|
||||
parser.resetTime = timex.Now() - time.Hour
|
||||
tok, err = parser.ParseToken(req, key, prevKey)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, "value", tok.Claims.(jwt.MapClaims)["key"])
|
||||
}
|
||||
|
||||
func buildToken(secretKey string, payloads map[string]interface{}, seconds int64) (string, error) {
|
||||
now := time.Now().Unix()
|
||||
claims := make(jwt.MapClaims)
|
||||
claims["exp"] = now + seconds
|
||||
claims["iat"] = now
|
||||
for k, v := range payloads {
|
||||
claims[k] = v
|
||||
}
|
||||
|
||||
token := jwt.New(jwt.SigningMethodHS256)
|
||||
token.Claims = claims
|
||||
|
||||
return token.SignedString([]byte(secretKey))
|
||||
}
|
||||
@@ -1,40 +0,0 @@
|
||||
package httpserver
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"fmt"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
func StartHttp(host string, port int, handler http.Handler) error {
|
||||
addr := fmt.Sprintf("%s:%d", host, port)
|
||||
server := buildHttpServer(addr, handler)
|
||||
return StartServer(server)
|
||||
}
|
||||
|
||||
func StartHttps(host string, port int, certFile, keyFile string, handler http.Handler) error {
|
||||
addr := fmt.Sprintf("%s:%d", host, port)
|
||||
if server, err := buildHttpsServer(addr, handler, certFile, keyFile); err != nil {
|
||||
return err
|
||||
} else {
|
||||
return StartServer(server)
|
||||
}
|
||||
}
|
||||
|
||||
func buildHttpServer(addr string, handler http.Handler) *http.Server {
|
||||
return &http.Server{Addr: addr, Handler: handler}
|
||||
}
|
||||
|
||||
func buildHttpsServer(addr string, handler http.Handler, certFile, keyFile string) (*http.Server, error) {
|
||||
cert, err := tls.LoadX509KeyPair(certFile, keyFile)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
config := tls.Config{Certificates: []tls.Certificate{cert}}
|
||||
return &http.Server{
|
||||
Addr: addr,
|
||||
Handler: handler,
|
||||
TLSConfig: &config,
|
||||
}, nil
|
||||
}
|
||||
@@ -1,16 +0,0 @@
|
||||
package httpserver
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
|
||||
"zero/core/proc"
|
||||
)
|
||||
|
||||
func StartServer(srv *http.Server) error {
|
||||
proc.AddWrapUpListener(func() {
|
||||
srv.Shutdown(context.Background())
|
||||
})
|
||||
|
||||
return srv.ListenAndServe()
|
||||
}
|
||||
@@ -1,19 +0,0 @@
|
||||
package httpx
|
||||
|
||||
const (
|
||||
ApplicationJson = "application/json"
|
||||
ContentEncoding = "Content-Encoding"
|
||||
ContentSecurity = "X-Content-Security"
|
||||
ContentType = "Content-Type"
|
||||
KeyField = "key"
|
||||
SecretField = "secret"
|
||||
TypeField = "type"
|
||||
CryptionType = 1
|
||||
)
|
||||
|
||||
const (
|
||||
CodeSignaturePass = iota
|
||||
CodeSignatureInvalidHeader
|
||||
CodeSignatureWrongTime
|
||||
CodeSignatureInvalidToken
|
||||
)
|
||||
@@ -1,124 +0,0 @@
|
||||
package httpx
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"zero/core/httprouter"
|
||||
"zero/core/mapping"
|
||||
)
|
||||
|
||||
const (
|
||||
multipartFormData = "multipart/form-data"
|
||||
xForwardFor = "X-Forward-For"
|
||||
formKey = "form"
|
||||
pathKey = "path"
|
||||
emptyJson = "{}"
|
||||
maxMemory = 32 << 20 // 32MB
|
||||
maxBodyLen = 8 << 20 // 8MB
|
||||
separator = ";"
|
||||
tokensInAttribute = 2
|
||||
)
|
||||
|
||||
var (
|
||||
ErrBodylessRequest = errors.New("not a POST|PUT|PATCH request")
|
||||
|
||||
formUnmarshaler = mapping.NewUnmarshaler(formKey, mapping.WithStringValues())
|
||||
pathUnmarshaler = mapping.NewUnmarshaler(pathKey, mapping.WithStringValues())
|
||||
)
|
||||
|
||||
// Returns the peer address, supports X-Forward-For
|
||||
func GetRemoteAddr(r *http.Request) string {
|
||||
v := r.Header.Get(xForwardFor)
|
||||
if len(v) > 0 {
|
||||
return v
|
||||
}
|
||||
return r.RemoteAddr
|
||||
}
|
||||
|
||||
func Parse(r *http.Request, v interface{}) error {
|
||||
if err := ParsePath(r, v); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := ParseForm(r, v); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return ParseJsonBody(r, v)
|
||||
}
|
||||
|
||||
// Parses the form request.
|
||||
func ParseForm(r *http.Request, v interface{}) error {
|
||||
if strings.Index(r.Header.Get(ContentType), multipartFormData) != -1 {
|
||||
if err := r.ParseMultipartForm(maxMemory); err != nil {
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
if err := r.ParseForm(); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
params := make(map[string]interface{}, len(r.Form))
|
||||
for name := range r.Form {
|
||||
formValue := r.Form.Get(name)
|
||||
if len(formValue) > 0 {
|
||||
params[name] = formValue
|
||||
}
|
||||
}
|
||||
|
||||
return formUnmarshaler.Unmarshal(params, v)
|
||||
}
|
||||
|
||||
func ParseHeader(headerValue string) map[string]string {
|
||||
ret := make(map[string]string)
|
||||
fields := strings.Split(headerValue, separator)
|
||||
|
||||
for _, field := range fields {
|
||||
field = strings.TrimSpace(field)
|
||||
if len(field) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
kv := strings.SplitN(field, "=", tokensInAttribute)
|
||||
if len(kv) != tokensInAttribute {
|
||||
continue
|
||||
}
|
||||
|
||||
ret[kv[0]] = kv[1]
|
||||
}
|
||||
|
||||
return ret
|
||||
}
|
||||
|
||||
// Parses the post request which contains json in body.
|
||||
func ParseJsonBody(r *http.Request, v interface{}) error {
|
||||
var reader io.Reader
|
||||
|
||||
if withJsonBody(r) {
|
||||
reader = io.LimitReader(r.Body, maxBodyLen)
|
||||
} else {
|
||||
reader = strings.NewReader(emptyJson)
|
||||
}
|
||||
|
||||
return mapping.UnmarshalJsonReader(reader, v)
|
||||
}
|
||||
|
||||
// Parses the symbols reside in url path.
|
||||
// Like http://localhost/bag/:name
|
||||
func ParsePath(r *http.Request, v interface{}) error {
|
||||
vars := httprouter.Vars(r)
|
||||
m := make(map[string]interface{}, len(vars))
|
||||
for k, v := range vars {
|
||||
m[k] = v
|
||||
}
|
||||
|
||||
return pathUnmarshaler.Unmarshal(m, v)
|
||||
}
|
||||
|
||||
func withJsonBody(r *http.Request) bool {
|
||||
return r.ContentLength > 0 && strings.Index(r.Header.Get(ContentType), ApplicationJson) != -1
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,29 +0,0 @@
|
||||
package httpx
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
|
||||
"zero/core/logx"
|
||||
)
|
||||
|
||||
func OkJson(w http.ResponseWriter, v interface{}) {
|
||||
WriteJson(w, http.StatusOK, v)
|
||||
}
|
||||
|
||||
func WriteJson(w http.ResponseWriter, code int, v interface{}) {
|
||||
w.Header().Set(ContentType, ApplicationJson)
|
||||
w.WriteHeader(code)
|
||||
|
||||
if bs, err := json.Marshal(v); err != nil {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
} else if n, err := w.Write(bs); err != nil {
|
||||
// http.ErrHandlerTimeout has been handled by http.TimeoutHandler,
|
||||
// so it's ignored here.
|
||||
if err != http.ErrHandlerTimeout {
|
||||
logx.Errorf("write response failed, error: %s", err)
|
||||
}
|
||||
} else if n < len(bs) {
|
||||
logx.Errorf("actual bytes: %d, written bytes: %d", len(bs), n)
|
||||
}
|
||||
}
|
||||
@@ -1,78 +0,0 @@
|
||||
package httpx
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"zero/core/logx"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
type message struct {
|
||||
Name string `json:"name"`
|
||||
}
|
||||
|
||||
func init() {
|
||||
logx.Disable()
|
||||
}
|
||||
|
||||
func TestOkJson(t *testing.T) {
|
||||
w := tracedResponseWriter{
|
||||
headers: make(map[string][]string),
|
||||
}
|
||||
msg := message{Name: "anyone"}
|
||||
OkJson(&w, msg)
|
||||
assert.Equal(t, http.StatusOK, w.code)
|
||||
assert.Equal(t, "{\"name\":\"anyone\"}", w.builder.String())
|
||||
}
|
||||
|
||||
func TestWriteJsonTimeout(t *testing.T) {
|
||||
// only log it and ignore
|
||||
w := tracedResponseWriter{
|
||||
headers: make(map[string][]string),
|
||||
timeout: true,
|
||||
}
|
||||
msg := message{Name: "anyone"}
|
||||
WriteJson(&w, http.StatusOK, msg)
|
||||
assert.Equal(t, http.StatusOK, w.code)
|
||||
}
|
||||
|
||||
func TestWriteJsonLessWritten(t *testing.T) {
|
||||
w := tracedResponseWriter{
|
||||
headers: make(map[string][]string),
|
||||
lessWritten: true,
|
||||
}
|
||||
msg := message{Name: "anyone"}
|
||||
WriteJson(&w, http.StatusOK, msg)
|
||||
assert.Equal(t, http.StatusOK, w.code)
|
||||
}
|
||||
|
||||
type tracedResponseWriter struct {
|
||||
headers map[string][]string
|
||||
builder strings.Builder
|
||||
code int
|
||||
lessWritten bool
|
||||
timeout bool
|
||||
}
|
||||
|
||||
func (w *tracedResponseWriter) Header() http.Header {
|
||||
return w.headers
|
||||
}
|
||||
|
||||
func (w *tracedResponseWriter) Write(bytes []byte) (n int, err error) {
|
||||
if w.timeout {
|
||||
return 0, http.ErrHandlerTimeout
|
||||
}
|
||||
|
||||
n, err = w.builder.Write(bytes)
|
||||
if w.lessWritten {
|
||||
n -= 1
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (w *tracedResponseWriter) WriteHeader(code int) {
|
||||
w.code = code
|
||||
}
|
||||
Reference in New Issue
Block a user