This commit is contained in:
kevin
2020-07-29 18:00:04 +08:00
parent 121323b8c3
commit ca3934582a
58 changed files with 222 additions and 200 deletions

View File

@@ -1,129 +0,0 @@
package httphandler
import (
"context"
"net/http"
"net/http/httputil"
"zero/core/httpsecurity"
"zero/core/logx"
"github.com/dgrijalva/jwt-go"
)
const (
jwtAudience = "aud"
jwtExpire = "exp"
jwtId = "jti"
jwtIssueAt = "iat"
jwtIssuer = "iss"
jwtNotBefore = "nbf"
jwtSubject = "sub"
)
type (
AuthorizeOptions struct {
PrevSecret string
Callback UnauthorizedCallback
}
UnauthorizedCallback func(w http.ResponseWriter, r *http.Request, err error)
AuthorizeOption func(opts *AuthorizeOptions)
)
func Authorize(secret string, opts ...AuthorizeOption) func(http.Handler) http.Handler {
var authOpts AuthorizeOptions
for _, opt := range opts {
opt(&authOpts)
}
parser := httpsecurity.NewTokenParser()
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
token, err := parser.ParseToken(r, secret, authOpts.PrevSecret)
if err != nil {
unauthorized(w, r, err, authOpts.Callback)
return
}
if !token.Valid {
unauthorized(w, r, err, authOpts.Callback)
return
}
claims, ok := token.Claims.(jwt.MapClaims)
if !ok {
unauthorized(w, r, err, authOpts.Callback)
return
}
ctx := r.Context()
for k, v := range claims {
switch k {
case jwtAudience, jwtExpire, jwtId, jwtIssueAt, jwtIssuer, jwtNotBefore, jwtSubject:
// ignore the standard claims
default:
ctx = context.WithValue(ctx, k, v)
}
}
next.ServeHTTP(w, r.WithContext(ctx))
})
}
}
func WithPrevSecret(secret string) AuthorizeOption {
return func(opts *AuthorizeOptions) {
opts.PrevSecret = secret
}
}
func WithUnauthorizedCallback(callback UnauthorizedCallback) AuthorizeOption {
return func(opts *AuthorizeOptions) {
opts.Callback = callback
}
}
func detailAuthLog(r *http.Request, reason string) {
// discard dump error, only for debug purpose
details, _ := httputil.DumpRequest(r, true)
logx.Errorf("authorize failed: %s\n=> %+v", reason, string(details))
}
func unauthorized(w http.ResponseWriter, r *http.Request, err error, callback UnauthorizedCallback) {
writer := newGuardedResponseWriter(w)
detailAuthLog(r, err.Error())
if callback != nil {
callback(writer, r, err)
}
writer.WriteHeader(http.StatusUnauthorized)
}
type guardedResponseWriter struct {
writer http.ResponseWriter
wroteHeader bool
}
func newGuardedResponseWriter(w http.ResponseWriter) *guardedResponseWriter {
return &guardedResponseWriter{
writer: w,
}
}
func (grw *guardedResponseWriter) Header() http.Header {
return grw.writer.Header()
}
func (grw *guardedResponseWriter) Write(body []byte) (int, error) {
return grw.writer.Write(body)
}
func (grw *guardedResponseWriter) WriteHeader(statusCode int) {
if grw.wroteHeader {
return
}
grw.wroteHeader = true
grw.writer.WriteHeader(statusCode)
}

View File

@@ -1,91 +0,0 @@
package httphandler
import (
"net/http"
"net/http/httptest"
"testing"
"time"
"github.com/dgrijalva/jwt-go"
"github.com/stretchr/testify/assert"
)
func TestAuthHandlerFailed(t *testing.T) {
req := httptest.NewRequest(http.MethodGet, "http://localhost", nil)
handler := Authorize("B63F477D-BBA3-4E52-96D3-C0034C27694A", WithUnauthorizedCallback(
func(w http.ResponseWriter, r *http.Request, err error) {
w.Header().Set("X-Test", "test")
w.WriteHeader(http.StatusUnauthorized)
_, err = w.Write([]byte("content"))
assert.Nil(t, err)
}))(
http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
}))
resp := httptest.NewRecorder()
handler.ServeHTTP(resp, req)
assert.Equal(t, http.StatusUnauthorized, resp.Code)
}
func TestAuthHandler(t *testing.T) {
const key = "B63F477D-BBA3-4E52-96D3-C0034C27694A"
req := httptest.NewRequest(http.MethodGet, "http://localhost", nil)
token, err := buildToken(key, map[string]interface{}{
"key": "value",
}, 3600)
assert.Nil(t, err)
req.Header.Set("Authorization", "Bearer "+token)
handler := Authorize(key)(
http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("X-Test", "test")
_, err := w.Write([]byte("content"))
assert.Nil(t, err)
}))
resp := httptest.NewRecorder()
handler.ServeHTTP(resp, req)
assert.Equal(t, http.StatusOK, resp.Code)
assert.Equal(t, "test", resp.Header().Get("X-Test"))
assert.Equal(t, "content", resp.Body.String())
}
func TestAuthHandlerWithPrevSecret(t *testing.T) {
const (
key = "14F17379-EB8F-411B-8F12-6929002DCA76"
prevKey = "B63F477D-BBA3-4E52-96D3-C0034C27694A"
)
req := httptest.NewRequest(http.MethodGet, "http://localhost", nil)
token, err := buildToken(key, map[string]interface{}{
"key": "value",
}, 3600)
assert.Nil(t, err)
req.Header.Set("Authorization", "Bearer "+token)
handler := Authorize(key, WithPrevSecret(prevKey))(
http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("X-Test", "test")
_, err := w.Write([]byte("content"))
assert.Nil(t, err)
}))
resp := httptest.NewRecorder()
handler.ServeHTTP(resp, req)
assert.Equal(t, http.StatusOK, resp.Code)
assert.Equal(t, "test", resp.Header().Get("X-Test"))
assert.Equal(t, "content", resp.Body.String())
}
func buildToken(secretKey string, payloads map[string]interface{}, seconds int64) (string, error) {
now := time.Now().Unix()
claims := make(jwt.MapClaims)
claims["exp"] = now + seconds
claims["iat"] = now
for k, v := range payloads {
claims[k] = v
}
token := jwt.New(jwt.SigningMethodHS256)
token.Claims = claims
return token.SignedString([]byte(secretKey))
}

View File

@@ -1,41 +0,0 @@
package httphandler
import (
"fmt"
"net/http"
"strings"
"zero/core/breaker"
"zero/core/httphandler/internal"
"zero/core/httpx"
"zero/core/logx"
"zero/core/stat"
)
const breakerSeparator = "://"
func BreakerHandler(method, path string, metrics *stat.Metrics) func(http.Handler) http.Handler {
brk := breaker.NewBreaker(breaker.WithName(strings.Join([]string{method, path}, breakerSeparator)))
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
promise, err := brk.Allow()
if err != nil {
metrics.AddDrop()
logx.Errorf("[http] dropped, %s - %s - %s",
r.RequestURI, httpx.GetRemoteAddr(r), r.UserAgent())
w.WriteHeader(http.StatusServiceUnavailable)
return
}
cw := &internal.WithCodeResponseWriter{Writer: w}
defer func() {
if cw.Code < http.StatusInternalServerError {
promise.Accept()
} else {
promise.Reject(fmt.Sprintf("%d %s", cw.Code, http.StatusText(cw.Code)))
}
}()
next.ServeHTTP(cw, r)
})
}
}

View File

@@ -1,102 +0,0 @@
package httphandler
import (
"fmt"
"net/http"
"net/http/httptest"
"testing"
"zero/core/logx"
"zero/core/stat"
"github.com/stretchr/testify/assert"
)
func init() {
logx.Disable()
stat.SetReporter(nil)
}
func TestBreakerHandlerAccept(t *testing.T) {
metrics := stat.NewMetrics("unit-test")
breakerHandler := BreakerHandler(http.MethodGet, "/", metrics)
handler := breakerHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("X-Test", "test")
_, err := w.Write([]byte("content"))
assert.Nil(t, err)
}))
req := httptest.NewRequest(http.MethodGet, "http://localhost", nil)
req.Header.Set("X-Test", "test")
resp := httptest.NewRecorder()
handler.ServeHTTP(resp, req)
assert.Equal(t, http.StatusOK, resp.Code)
assert.Equal(t, "test", resp.Header().Get("X-Test"))
assert.Equal(t, "content", resp.Body.String())
}
func TestBreakerHandlerFail(t *testing.T) {
metrics := stat.NewMetrics("unit-test")
breakerHandler := BreakerHandler(http.MethodGet, "/", metrics)
handler := breakerHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusBadGateway)
}))
req := httptest.NewRequest(http.MethodGet, "http://localhost", nil)
resp := httptest.NewRecorder()
handler.ServeHTTP(resp, req)
assert.Equal(t, http.StatusBadGateway, resp.Code)
}
func TestBreakerHandler_4XX(t *testing.T) {
metrics := stat.NewMetrics("unit-test")
breakerHandler := BreakerHandler(http.MethodGet, "/", metrics)
handler := breakerHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusBadRequest)
}))
for i := 0; i < 1000; i++ {
req := httptest.NewRequest(http.MethodGet, "http://localhost", nil)
resp := httptest.NewRecorder()
handler.ServeHTTP(resp, req)
}
const tries = 100
var pass int
for i := 0; i < tries; i++ {
req := httptest.NewRequest(http.MethodGet, "http://localhost", nil)
resp := httptest.NewRecorder()
handler.ServeHTTP(resp, req)
if resp.Code == http.StatusBadRequest {
pass++
}
}
assert.Equal(t, tries, pass)
}
func TestBreakerHandlerReject(t *testing.T) {
metrics := stat.NewMetrics("unit-test")
breakerHandler := BreakerHandler(http.MethodGet, "/", metrics)
handler := breakerHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusInternalServerError)
}))
for i := 0; i < 1000; i++ {
req := httptest.NewRequest(http.MethodGet, "http://localhost", nil)
resp := httptest.NewRecorder()
handler.ServeHTTP(resp, req)
}
var drops int
for i := 0; i < 100; i++ {
req := httptest.NewRequest(http.MethodGet, "http://localhost", nil)
resp := httptest.NewRecorder()
handler.ServeHTTP(resp, req)
if resp.Code == http.StatusServiceUnavailable {
drops++
}
}
assert.True(t, drops >= 80, fmt.Sprintf("expected to be greater than 80, but got %d", drops))
}

View File

@@ -1,61 +0,0 @@
package httphandler
import (
"net/http"
"time"
"zero/core/codec"
"zero/core/httphandler/internal"
"zero/core/httpx"
"zero/core/logx"
)
const contentSecurity = "X-Content-Security"
type UnsignedCallback func(w http.ResponseWriter, r *http.Request, next http.Handler, strict bool, code int)
func ContentSecurityHandler(decrypters map[string]codec.RsaDecrypter, tolerance time.Duration,
strict bool, callbacks ...UnsignedCallback) func(http.Handler) http.Handler {
if len(callbacks) == 0 {
callbacks = append(callbacks, handleVerificationFailure)
}
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch r.Method {
case http.MethodDelete, http.MethodGet, http.MethodPost, http.MethodPut:
header, err := internal.ParseContentSecurity(decrypters, r)
if err != nil {
logx.Infof("Signature parse failed, X-Content-Security: %s, error: %s",
r.Header.Get(contentSecurity), err.Error())
executeCallbacks(w, r, next, strict, httpx.CodeSignatureInvalidHeader, callbacks)
} else if code := internal.VerifySignature(r, header, tolerance); code != httpx.CodeSignaturePass {
logx.Infof("Signature verification failed, X-Content-Security: %s",
r.Header.Get(contentSecurity))
executeCallbacks(w, r, next, strict, code, callbacks)
} else if r.ContentLength > 0 && header.Encrypted() {
CryptionHandler(header.Key)(next).ServeHTTP(w, r)
} else {
next.ServeHTTP(w, r)
}
default:
next.ServeHTTP(w, r)
}
})
}
}
func executeCallbacks(w http.ResponseWriter, r *http.Request, next http.Handler, strict bool,
code int, callbacks []UnsignedCallback) {
for _, callback := range callbacks {
callback(w, r, next, strict, code)
}
}
func handleVerificationFailure(w http.ResponseWriter, r *http.Request, next http.Handler, strict bool, code int) {
if strict {
w.WriteHeader(http.StatusUnauthorized)
} else {
next.ServeHTTP(w, r)
}
}

View File

@@ -1,388 +0,0 @@
package httphandler
import (
"bytes"
"crypto/sha256"
"encoding/base64"
"fmt"
"io"
"io/ioutil"
"log"
"net/http"
"net/http/httptest"
"net/url"
"os"
"strconv"
"strings"
"testing"
"time"
"zero/core/codec"
"zero/core/httpx"
"github.com/stretchr/testify/assert"
)
const timeDiff = time.Hour * 2 * 24
var (
fingerprint = "12345"
pubKey = []byte(`-----BEGIN PUBLIC KEY-----
MIGfMA0GCSqGSIb3DQEBAQUAA4GNADCBiQKBgQD7bq4FLG0ctccbEFEsUBuRxkjE
eJ5U+0CAEjJk20V9/u2Fu76i1oKoShCs7GXtAFbDb5A/ImIXkPY62nAaxTGK4KVH
miYbRgh5Fy6336KepLCtCmV/r0PKZeCyJH9uYLs7EuE1z9Hgm5UUjmpHDhJtkAwR
my47YlhspwszKdRP+wIDAQAB
-----END PUBLIC KEY-----`)
priKey = []byte(`-----BEGIN RSA PRIVATE KEY-----
MIICXAIBAAKBgQD7bq4FLG0ctccbEFEsUBuRxkjEeJ5U+0CAEjJk20V9/u2Fu76i
1oKoShCs7GXtAFbDb5A/ImIXkPY62nAaxTGK4KVHmiYbRgh5Fy6336KepLCtCmV/
r0PKZeCyJH9uYLs7EuE1z9Hgm5UUjmpHDhJtkAwRmy47YlhspwszKdRP+wIDAQAB
AoGBANs1qf7UtuSbD1ZnKX5K8V5s07CHwPMygw+lzc3k5ndtNUStZQ2vnAaBXHyH
Nm4lJ4AI2mhQ39jQB/1TyP1uAzvpLhT60fRybEq9zgJ/81Gm9bnaEpFJ9bP2bBrY
J0jbaTMfbzL/PJFl3J3RGMR40C76h5yRYSnOpMoMiKWnJqrhAkEA/zCOkR+34Pk0
Yo3sIP4ranY6AAvwacgNaui4ll5xeYwv3iLOQvPlpxIxFHKXEY0klNNyjjXqgYjP
cOenqtt6UwJBAPw7EYuteVHvHvQVuTbKAaYHcOrp4nFeZF3ndFfl0w2dwGhfzcXO
ROyd5dNQCuCWRo8JBpjG6PFyzezayF4KLrkCQCGditoxHG7FRRJKcbVy5dMzWbaR
3AyDLslLeK1OKZKCVffkC9mj+TeF3PM9mQrV1eDI7ckv7wE7PWA5E8wc90MCQEOV
MCZU3OTvRUPxbicYCUkLRV4sPNhTimD+21WR5vMHCb7trJ0Ln7wmsqXkFIYIve8l
Y/cblN7c/AAyvu0znUECQA318nPldsxR6+H8HTS3uEbkL4UJdjQJHsvTwKxAw5qc
moKExvRlN0zmGGuArKcqS38KG7PXZMrUv3FXPdp6BDQ=
-----END RSA PRIVATE KEY-----`)
key = []byte("q4t7w!z%C*F-JaNdRgUjXn2r5u8x/A?D")
)
type requestSettings struct {
method string
url string
body io.Reader
strict bool
crypt bool
requestUri string
timestamp int64
fingerprint string
missHeader bool
signature string
}
func init() {
log.SetOutput(ioutil.Discard)
}
func TestContentSecurityHandler(t *testing.T) {
tests := []struct {
method string
url string
body string
strict bool
crypt bool
requestUri string
timestamp int64
fingerprint string
missHeader bool
signature string
statusCode int
}{
{
method: http.MethodGet,
url: "http://localhost/a/b?c=d&e=f",
strict: true,
crypt: false,
},
{
method: http.MethodPost,
url: "http://localhost/a/b?c=d&e=f",
body: "hello",
strict: true,
crypt: false,
},
{
method: http.MethodGet,
url: "http://localhost/a/b?c=d&e=f",
strict: true,
crypt: true,
},
{
method: http.MethodPost,
url: "http://localhost/a/b?c=d&e=f",
body: "hello",
strict: true,
crypt: true,
},
{
method: http.MethodGet,
url: "http://localhost/a/b?c=d&e=f",
strict: true,
crypt: true,
timestamp: time.Now().Add(timeDiff).Unix(),
statusCode: http.StatusUnauthorized,
},
{
method: http.MethodPost,
url: "http://localhost/a/b?c=d&e=f",
body: "hello",
strict: true,
crypt: true,
timestamp: time.Now().Add(-timeDiff).Unix(),
statusCode: http.StatusUnauthorized,
},
{
method: http.MethodPost,
url: "http://remotehost/",
body: "hello",
strict: true,
crypt: true,
requestUri: "http://localhost/a/b?c=d&e=f",
},
{
method: http.MethodPost,
url: "http://localhost/a/b?c=d&e=f",
body: "hello",
strict: false,
crypt: true,
fingerprint: "badone",
},
{
method: http.MethodPost,
url: "http://localhost/a/b?c=d&e=f",
body: "hello",
strict: true,
crypt: true,
timestamp: time.Now().Add(-timeDiff).Unix(),
fingerprint: "badone",
statusCode: http.StatusUnauthorized,
},
{
method: http.MethodPost,
url: "http://localhost/a/b?c=d&e=f",
body: "hello",
strict: true,
crypt: true,
missHeader: true,
statusCode: http.StatusUnauthorized,
},
{
method: http.MethodHead,
url: "http://localhost/a/b?c=d&e=f",
strict: true,
crypt: false,
},
{
method: http.MethodGet,
url: "http://localhost/a/b?c=d&e=f",
strict: true,
crypt: false,
signature: "badone",
statusCode: http.StatusUnauthorized,
},
}
for _, test := range tests {
t.Run(test.url, func(t *testing.T) {
if test.statusCode == 0 {
test.statusCode = http.StatusOK
}
if len(test.fingerprint) == 0 {
test.fingerprint = fingerprint
}
if test.timestamp == 0 {
test.timestamp = time.Now().Unix()
}
func() {
keyFile, err := createTempFile(priKey)
defer os.Remove(keyFile)
assert.Nil(t, err)
decrypter, err := codec.NewRsaDecrypter(keyFile)
assert.Nil(t, err)
contentSecurityHandler := ContentSecurityHandler(map[string]codec.RsaDecrypter{
fingerprint: decrypter,
}, time.Hour, test.strict)
handler := contentSecurityHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
}))
var reader io.Reader
if len(test.body) > 0 {
reader = strings.NewReader(test.body)
}
setting := requestSettings{
method: test.method,
url: test.url,
body: reader,
strict: test.strict,
crypt: test.crypt,
requestUri: test.requestUri,
timestamp: test.timestamp,
fingerprint: test.fingerprint,
missHeader: test.missHeader,
signature: test.signature,
}
req, err := buildRequest(setting)
resp := httptest.NewRecorder()
handler.ServeHTTP(resp, req)
assert.Equal(t, test.statusCode, resp.Code)
}()
})
}
}
func TestContentSecurityHandler_UnsignedCallback(t *testing.T) {
keyFile, err := createTempFile(priKey)
defer os.Remove(keyFile)
assert.Nil(t, err)
decrypter, err := codec.NewRsaDecrypter(keyFile)
assert.Nil(t, err)
contentSecurityHandler := ContentSecurityHandler(
map[string]codec.RsaDecrypter{
fingerprint: decrypter,
},
time.Hour,
true,
func(w http.ResponseWriter, r *http.Request, next http.Handler, strict bool, code int) {
w.WriteHeader(http.StatusOK)
})
handler := contentSecurityHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}))
setting := requestSettings{
method: http.MethodGet,
url: "http://localhost/a/b?c=d&e=f",
signature: "badone",
}
req, err := buildRequest(setting)
resp := httptest.NewRecorder()
handler.ServeHTTP(resp, req)
assert.Equal(t, http.StatusOK, resp.Code)
}
func TestContentSecurityHandler_UnsignedCallback_WrongTime(t *testing.T) {
keyFile, err := createTempFile(priKey)
defer os.Remove(keyFile)
assert.Nil(t, err)
decrypter, err := codec.NewRsaDecrypter(keyFile)
assert.Nil(t, err)
contentSecurityHandler := ContentSecurityHandler(
map[string]codec.RsaDecrypter{
fingerprint: decrypter,
},
time.Hour,
true,
func(w http.ResponseWriter, r *http.Request, next http.Handler, strict bool, code int) {
assert.Equal(t, httpx.CodeSignatureWrongTime, code)
w.WriteHeader(http.StatusOK)
})
handler := contentSecurityHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}))
var reader io.Reader
reader = strings.NewReader("hello")
setting := requestSettings{
method: http.MethodPost,
url: "http://localhost/a/b?c=d&e=f",
body: reader,
strict: true,
crypt: true,
timestamp: time.Now().Add(time.Hour * 24 * 365).Unix(),
fingerprint: fingerprint,
}
req, err := buildRequest(setting)
resp := httptest.NewRecorder()
handler.ServeHTTP(resp, req)
assert.Equal(t, http.StatusOK, resp.Code)
}
func buildRequest(rs requestSettings) (*http.Request, error) {
var bodyStr string
var err error
if rs.crypt && rs.body != nil {
var buf bytes.Buffer
io.Copy(&buf, rs.body)
bodyBytes, err := codec.EcbEncrypt(key, buf.Bytes())
if err != nil {
return nil, err
}
bodyStr = base64.StdEncoding.EncodeToString(bodyBytes)
}
r := httptest.NewRequest(rs.method, rs.url, strings.NewReader(bodyStr))
if len(rs.signature) == 0 {
sha := sha256.New()
sha.Write([]byte(bodyStr))
bodySign := fmt.Sprintf("%x", sha.Sum(nil))
var path string
var query string
if len(rs.requestUri) > 0 {
if u, err := url.Parse(rs.requestUri); err != nil {
return nil, err
} else {
path = u.Path
query = u.RawQuery
}
} else {
path = r.URL.Path
query = r.URL.RawQuery
}
contentOfSign := strings.Join([]string{
strconv.FormatInt(rs.timestamp, 10),
rs.method,
path,
query,
bodySign,
}, "\n")
rs.signature = codec.HmacBase64([]byte(key), contentOfSign)
}
var mode string
if rs.crypt {
mode = "1"
} else {
mode = "0"
}
content := strings.Join([]string{
"version=v1",
"type=" + mode,
fmt.Sprintf("key=%s", base64.StdEncoding.EncodeToString(key)),
"time=" + strconv.FormatInt(rs.timestamp, 10),
}, "; ")
encrypter, err := codec.NewRsaEncrypter([]byte(pubKey))
if err != nil {
log.Fatal(err)
}
output, err := encrypter.Encrypt([]byte(content))
if err != nil {
log.Fatal(err)
}
encryptedContent := base64.StdEncoding.EncodeToString(output)
if !rs.missHeader {
r.Header.Set(httpx.ContentSecurity, strings.Join([]string{
fmt.Sprintf("key=%s", rs.fingerprint),
"secret=" + encryptedContent,
"signature=" + rs.signature,
}, "; "))
}
if len(rs.requestUri) > 0 {
r.Header.Set("X-Request-Uri", rs.requestUri)
}
return r, nil
}
func createTempFile(body []byte) (string, error) {
tmpFile, err := ioutil.TempFile(os.TempDir(), "go-unit-*.tmp")
if err != nil {
return "", err
} else {
tmpFile.Close()
}
err = ioutil.WriteFile(tmpFile.Name(), body, os.ModePerm)
if err != nil {
return "", err
}
return tmpFile.Name(), nil
}

View File

@@ -1,101 +0,0 @@
package httphandler
import (
"bytes"
"encoding/base64"
"io"
"io/ioutil"
"net/http"
"zero/core/codec"
"zero/core/logx"
)
const maxBytes = 1 << 20 // 1 MiB
func CryptionHandler(key []byte) func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
cw := newCryptionResponseWriter(w)
defer cw.flush(key)
if r.ContentLength <= 0 {
next.ServeHTTP(cw, r)
return
}
if err := decryptBody(key, r); err != nil {
w.WriteHeader(http.StatusBadRequest)
return
}
next.ServeHTTP(cw, r)
})
}
}
func decryptBody(key []byte, r *http.Request) error {
content, err := ioutil.ReadAll(io.LimitReader(r.Body, maxBytes))
if err != nil {
return err
}
content, err = base64.StdEncoding.DecodeString(string(content))
if err != nil {
return err
}
output, err := codec.EcbDecrypt(key, content)
if err != nil {
return err
}
var buf bytes.Buffer
buf.Write(output)
r.Body = ioutil.NopCloser(&buf)
return nil
}
type cryptionResponseWriter struct {
http.ResponseWriter
buf *bytes.Buffer
}
func newCryptionResponseWriter(w http.ResponseWriter) *cryptionResponseWriter {
return &cryptionResponseWriter{
ResponseWriter: w,
buf: new(bytes.Buffer),
}
}
func (w *cryptionResponseWriter) Header() http.Header {
return w.ResponseWriter.Header()
}
func (w *cryptionResponseWriter) Write(p []byte) (int, error) {
return w.buf.Write(p)
}
func (w *cryptionResponseWriter) WriteHeader(statusCode int) {
w.ResponseWriter.WriteHeader(statusCode)
}
func (w *cryptionResponseWriter) flush(key []byte) {
if w.buf.Len() == 0 {
return
}
content, err := codec.EcbEncrypt(key, w.buf.Bytes())
if err != nil {
w.WriteHeader(http.StatusInternalServerError)
return
}
body := base64.StdEncoding.EncodeToString(content)
if n, err := io.WriteString(w.ResponseWriter, body); err != nil {
logx.Errorf("write response failed, error: %s", err)
} else if n < len(content) {
logx.Errorf("actual bytes: %d, written bytes: %d", len(content), n)
}
}

View File

@@ -1,90 +0,0 @@
package httphandler
import (
"bytes"
"encoding/base64"
"io/ioutil"
"log"
"net/http"
"net/http/httptest"
"testing"
"zero/core/codec"
"github.com/stretchr/testify/assert"
)
const (
reqText = "ping"
respText = "pong"
)
var aesKey = []byte(`PdSgVkYp3s6v9y$B&E)H+MbQeThWmZq4`)
func init() {
log.SetOutput(ioutil.Discard)
}
func TestCryptionHandlerGet(t *testing.T) {
req := httptest.NewRequest(http.MethodGet, "/any", nil)
handler := CryptionHandler(aesKey)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
_, err := w.Write([]byte(respText))
w.Header().Set("X-Test", "test")
assert.Nil(t, err)
}))
recorder := httptest.NewRecorder()
handler.ServeHTTP(recorder, req)
expect, err := codec.EcbEncrypt(aesKey, []byte(respText))
assert.Nil(t, err)
assert.Equal(t, http.StatusOK, recorder.Code)
assert.Equal(t, "test", recorder.Header().Get("X-Test"))
assert.Equal(t, base64.StdEncoding.EncodeToString(expect), recorder.Body.String())
}
func TestCryptionHandlerPost(t *testing.T) {
var buf bytes.Buffer
enc, err := codec.EcbEncrypt(aesKey, []byte(reqText))
assert.Nil(t, err)
buf.WriteString(base64.StdEncoding.EncodeToString(enc))
req := httptest.NewRequest(http.MethodPost, "/any", &buf)
handler := CryptionHandler(aesKey)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
body, err := ioutil.ReadAll(r.Body)
assert.Nil(t, err)
assert.Equal(t, reqText, string(body))
w.Write([]byte(respText))
}))
recorder := httptest.NewRecorder()
handler.ServeHTTP(recorder, req)
expect, err := codec.EcbEncrypt(aesKey, []byte(respText))
assert.Nil(t, err)
assert.Equal(t, http.StatusOK, recorder.Code)
assert.Equal(t, base64.StdEncoding.EncodeToString(expect), recorder.Body.String())
}
func TestCryptionHandlerPostBadEncryption(t *testing.T) {
var buf bytes.Buffer
enc, err := codec.EcbEncrypt(aesKey, []byte(reqText))
assert.Nil(t, err)
buf.Write(enc)
req := httptest.NewRequest(http.MethodPost, "/any", &buf)
handler := CryptionHandler(aesKey)(nil)
recorder := httptest.NewRecorder()
handler.ServeHTTP(recorder, req)
assert.Equal(t, http.StatusBadRequest, recorder.Code)
}
func TestCryptionHandlerWriteHeader(t *testing.T) {
req := httptest.NewRequest(http.MethodGet, "/any", nil)
handler := CryptionHandler(aesKey)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusServiceUnavailable)
}))
recorder := httptest.NewRecorder()
handler.ServeHTTP(recorder, req)
assert.Equal(t, http.StatusServiceUnavailable, recorder.Code)
}

View File

@@ -1,27 +0,0 @@
package httphandler
import (
"compress/gzip"
"net/http"
"strings"
"zero/core/httpx"
)
const gzipEncoding = "gzip"
func GunzipHandler(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if strings.Contains(r.Header.Get(httpx.ContentEncoding), gzipEncoding) {
reader, err := gzip.NewReader(r.Body)
if err != nil {
w.WriteHeader(http.StatusBadRequest)
return
}
r.Body = reader
}
next.ServeHTTP(w, r)
})
}

View File

@@ -1,66 +0,0 @@
package httphandler
import (
"bytes"
"io/ioutil"
"net/http"
"net/http/httptest"
"strings"
"sync"
"testing"
"zero/core/codec"
"zero/core/httpx"
"github.com/stretchr/testify/assert"
)
func TestGunzipHandler(t *testing.T) {
const message = "hello world"
var wg sync.WaitGroup
wg.Add(1)
handler := GunzipHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
body, err := ioutil.ReadAll(r.Body)
assert.Nil(t, err)
assert.Equal(t, string(body), message)
wg.Done()
}))
req := httptest.NewRequest(http.MethodPost, "http://localhost",
bytes.NewReader(codec.Gzip([]byte(message))))
req.Header.Set(httpx.ContentEncoding, gzipEncoding)
resp := httptest.NewRecorder()
handler.ServeHTTP(resp, req)
assert.Equal(t, http.StatusOK, resp.Code)
wg.Wait()
}
func TestGunzipHandler_NoGzip(t *testing.T) {
const message = "hello world"
var wg sync.WaitGroup
wg.Add(1)
handler := GunzipHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
body, err := ioutil.ReadAll(r.Body)
assert.Nil(t, err)
assert.Equal(t, string(body), message)
wg.Done()
}))
req := httptest.NewRequest(http.MethodPost, "http://localhost",
strings.NewReader(message))
resp := httptest.NewRecorder()
handler.ServeHTTP(resp, req)
assert.Equal(t, http.StatusOK, resp.Code)
wg.Wait()
}
func TestGunzipHandler_NoGzipButTelling(t *testing.T) {
const message = "hello world"
handler := GunzipHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}))
req := httptest.NewRequest(http.MethodPost, "http://localhost",
strings.NewReader(message))
req.Header.Set(httpx.ContentEncoding, gzipEncoding)
resp := httptest.NewRecorder()
handler.ServeHTTP(resp, req)
assert.Equal(t, http.StatusBadRequest, resp.Code)
}

View File

@@ -1,147 +0,0 @@
package internal
import (
"crypto/sha256"
"encoding/base64"
"errors"
"fmt"
"io"
"net/http"
"net/url"
"strconv"
"strings"
"time"
"zero/core/codec"
"zero/core/httpx"
"zero/core/iox"
"zero/core/logx"
)
const (
requestUriHeader = "X-Request-Uri"
signatureField = "signature"
timeField = "time"
)
var (
ErrInvalidContentType = errors.New("invalid content type")
ErrInvalidHeader = errors.New("invalid X-Content-Security header")
ErrInvalidKey = errors.New("invalid key")
ErrInvalidPublicKey = errors.New("invalid public key")
ErrInvalidSecret = errors.New("invalid secret")
)
type ContentSecurityHeader struct {
Key []byte
Timestamp string
ContentType int
Signature string
}
func (h *ContentSecurityHeader) Encrypted() bool {
return h.ContentType == httpx.CryptionType
}
func ParseContentSecurity(decrypters map[string]codec.RsaDecrypter, r *http.Request) (
*ContentSecurityHeader, error) {
contentSecurity := r.Header.Get(httpx.ContentSecurity)
attrs := httpx.ParseHeader(contentSecurity)
fingerprint := attrs[httpx.KeyField]
secret := attrs[httpx.SecretField]
signature := attrs[signatureField]
if len(fingerprint) == 0 || len(secret) == 0 || len(signature) == 0 {
return nil, ErrInvalidHeader
}
decrypter, ok := decrypters[fingerprint]
if !ok {
return nil, ErrInvalidPublicKey
}
decryptedSecret, err := decrypter.DecryptBase64(secret)
if err != nil {
return nil, ErrInvalidSecret
}
attrs = httpx.ParseHeader(string(decryptedSecret))
base64Key := attrs[httpx.KeyField]
timestamp := attrs[timeField]
contentType := attrs[httpx.TypeField]
key, err := base64.StdEncoding.DecodeString(base64Key)
if err != nil {
return nil, ErrInvalidKey
}
cType, err := strconv.Atoi(contentType)
if err != nil {
return nil, ErrInvalidContentType
}
return &ContentSecurityHeader{
Key: key,
Timestamp: timestamp,
ContentType: cType,
Signature: signature,
}, nil
}
func VerifySignature(r *http.Request, securityHeader *ContentSecurityHeader, tolerance time.Duration) int {
seconds, err := strconv.ParseInt(securityHeader.Timestamp, 10, 64)
if err != nil {
return httpx.CodeSignatureInvalidHeader
}
now := time.Now().Unix()
toleranceSeconds := int64(tolerance.Seconds())
if seconds+toleranceSeconds < now || now+toleranceSeconds < seconds {
return httpx.CodeSignatureWrongTime
}
reqPath, reqQuery := getPathQuery(r)
signContent := strings.Join([]string{
securityHeader.Timestamp,
r.Method,
reqPath,
reqQuery,
computeBodySignature(r),
}, "\n")
actualSignature := codec.HmacBase64(securityHeader.Key, signContent)
passed := securityHeader.Signature == actualSignature
if !passed {
logx.Infof("signature different, expect: %s, actual: %s",
securityHeader.Signature, actualSignature)
}
if passed {
return httpx.CodeSignaturePass
} else {
return httpx.CodeSignatureInvalidToken
}
}
func computeBodySignature(r *http.Request) string {
var dup io.ReadCloser
r.Body, dup = iox.DupReadCloser(r.Body)
sha := sha256.New()
io.Copy(sha, r.Body)
r.Body = dup
return fmt.Sprintf("%x", sha.Sum(nil))
}
func getPathQuery(r *http.Request) (string, string) {
requestUri := r.Header.Get(requestUriHeader)
if len(requestUri) == 0 {
return r.URL.Path, r.URL.RawQuery
}
uri, err := url.Parse(requestUri)
if err != nil {
return r.URL.Path, r.URL.RawQuery
}
return uri.Path, uri.RawQuery
}

View File

@@ -1,21 +0,0 @@
package internal
import "net/http"
type WithCodeResponseWriter struct {
Writer http.ResponseWriter
Code int
}
func (w *WithCodeResponseWriter) Header() http.Header {
return w.Writer.Header()
}
func (w *WithCodeResponseWriter) Write(bytes []byte) (int, error) {
return w.Writer.Write(bytes)
}
func (w *WithCodeResponseWriter) WriteHeader(code int) {
w.Writer.WriteHeader(code)
w.Code = code
}

View File

@@ -1,166 +0,0 @@
package httphandler
import (
"bytes"
"context"
"fmt"
"io"
"net/http"
"net/http/httputil"
"time"
"zero/core/httplog"
"zero/core/httpx"
"zero/core/iox"
"zero/core/logx"
"zero/core/timex"
"zero/core/utils"
)
const slowThreshold = time.Millisecond * 500
type LoggedResponseWriter struct {
w http.ResponseWriter
r *http.Request
code int
}
func (w *LoggedResponseWriter) Header() http.Header {
return w.w.Header()
}
func (w *LoggedResponseWriter) Write(bytes []byte) (int, error) {
return w.w.Write(bytes)
}
func (w *LoggedResponseWriter) WriteHeader(code int) {
w.w.WriteHeader(code)
w.code = code
}
func LogHandler(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
timer := utils.NewElapsedTimer()
logs := new(httplog.LogCollector)
lrw := LoggedResponseWriter{
w: w,
r: r,
code: http.StatusOK,
}
var dup io.ReadCloser
r.Body, dup = iox.DupReadCloser(r.Body)
next.ServeHTTP(&lrw, r.WithContext(context.WithValue(r.Context(), httplog.LogContext, logs)))
r.Body = dup
logBrief(r, lrw.code, timer, logs)
})
}
type DetailLoggedResponseWriter struct {
writer *LoggedResponseWriter
buf *bytes.Buffer
}
func newDetailLoggedResponseWriter(writer *LoggedResponseWriter, buf *bytes.Buffer) *DetailLoggedResponseWriter {
return &DetailLoggedResponseWriter{
writer: writer,
buf: buf,
}
}
func (w *DetailLoggedResponseWriter) Header() http.Header {
return w.writer.Header()
}
func (w *DetailLoggedResponseWriter) Write(bs []byte) (int, error) {
w.buf.Write(bs)
return w.writer.Write(bs)
}
func (w *DetailLoggedResponseWriter) WriteHeader(code int) {
w.writer.WriteHeader(code)
}
func DetailedLogHandler(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
timer := utils.NewElapsedTimer()
var buf bytes.Buffer
lrw := newDetailLoggedResponseWriter(&LoggedResponseWriter{
w: w,
r: r,
code: http.StatusOK,
}, &buf)
var dup io.ReadCloser
r.Body, dup = iox.DupReadCloser(r.Body)
logs := new(httplog.LogCollector)
next.ServeHTTP(lrw, r.WithContext(context.WithValue(r.Context(), httplog.LogContext, logs)))
r.Body = dup
logDetails(r, lrw, timer, logs)
})
}
func dumpRequest(r *http.Request) string {
reqContent, err := httputil.DumpRequest(r, true)
if err != nil {
return err.Error()
} else {
return string(reqContent)
}
}
func logBrief(r *http.Request, code int, timer *utils.ElapsedTimer, logs *httplog.LogCollector) {
var buf bytes.Buffer
duration := timer.Duration()
buf.WriteString(fmt.Sprintf("%d - %s - %s - %s - %s",
code, r.RequestURI, httpx.GetRemoteAddr(r), r.UserAgent(), timex.ReprOfDuration(duration)))
if duration > slowThreshold {
logx.Slowf("[HTTP] %d - %s - %s - %s - slowcall(%s)",
code, r.RequestURI, httpx.GetRemoteAddr(r), r.UserAgent(), timex.ReprOfDuration(duration))
}
ok := isOkResponse(code)
if !ok {
buf.WriteString(fmt.Sprintf("\n%s", dumpRequest(r)))
}
body := logs.Flush()
if len(body) > 0 {
buf.WriteString(fmt.Sprintf("\n%s", body))
}
if ok {
logx.Info(buf.String())
} else {
logx.Error(buf.String())
}
}
func logDetails(r *http.Request, response *DetailLoggedResponseWriter, timer *utils.ElapsedTimer,
logs *httplog.LogCollector) {
var buf bytes.Buffer
duration := timer.Duration()
buf.WriteString(fmt.Sprintf("%d - %s - %s\n=> %s\n",
response.writer.code, r.RemoteAddr, timex.ReprOfDuration(duration), dumpRequest(r)))
if duration > slowThreshold {
logx.Slowf("[HTTP] %d - %s - slowcall(%s)\n=> %s\n",
response.writer.code, r.RemoteAddr, timex.ReprOfDuration(duration), dumpRequest(r))
}
body := logs.Flush()
if len(body) > 0 {
buf.WriteString(fmt.Sprintf("%s\n", body))
}
respBuf := response.buf.Bytes()
if len(respBuf) > 0 {
buf.WriteString(fmt.Sprintf("<= %s", respBuf))
}
logx.Info(buf.String())
}
func isOkResponse(code int) bool {
// not server error
return code < http.StatusInternalServerError
}

View File

@@ -1,74 +0,0 @@
package httphandler
import (
"io/ioutil"
"log"
"net/http"
"net/http/httptest"
"testing"
"time"
"github.com/stretchr/testify/assert"
"zero/core/httplog"
)
func init() {
log.SetOutput(ioutil.Discard)
}
func TestLogHandler(t *testing.T) {
handlers := []func(handler http.Handler) http.Handler{
LogHandler,
DetailedLogHandler,
}
for _, logHandler := range handlers {
req := httptest.NewRequest(http.MethodGet, "http://localhost", nil)
handler := logHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
r.Context().Value(httplog.LogContext).(*httplog.LogCollector).Append("anything")
w.Header().Set("X-Test", "test")
w.WriteHeader(http.StatusServiceUnavailable)
_, err := w.Write([]byte("content"))
assert.Nil(t, err)
}))
resp := httptest.NewRecorder()
handler.ServeHTTP(resp, req)
assert.Equal(t, http.StatusServiceUnavailable, resp.Code)
assert.Equal(t, "test", resp.Header().Get("X-Test"))
assert.Equal(t, "content", resp.Body.String())
}
}
func TestLogHandlerSlow(t *testing.T) {
handlers := []func(handler http.Handler) http.Handler{
LogHandler,
DetailedLogHandler,
}
for _, logHandler := range handlers {
req := httptest.NewRequest(http.MethodGet, "http://localhost", nil)
handler := logHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
time.Sleep(slowThreshold + time.Millisecond*50)
}))
resp := httptest.NewRecorder()
handler.ServeHTTP(resp, req)
assert.Equal(t, http.StatusOK, resp.Code)
}
}
func BenchmarkLogHandler(b *testing.B) {
b.ReportAllocs()
req := httptest.NewRequest(http.MethodGet, "http://localhost", nil)
handler := LogHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
}))
for i := 0; i < b.N; i++ {
resp := httptest.NewRecorder()
handler.ServeHTTP(resp, req)
}
}

View File

@@ -1,27 +0,0 @@
package httphandler
import (
"net/http"
"zero/core/httplog"
)
func MaxBytesHandler(n int64) func(http.Handler) http.Handler {
if n <= 0 {
return func(next http.Handler) http.Handler {
return next
}
}
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.ContentLength > n {
httplog.Errorf(r, "request entity too large, limit is %d, but got %d, rejected with code %d",
n, r.ContentLength, http.StatusRequestEntityTooLarge)
w.WriteHeader(http.StatusRequestEntityTooLarge)
} else {
next.ServeHTTP(w, r)
}
})
}
}

View File

@@ -1,37 +0,0 @@
package httphandler
import (
"bytes"
"net/http"
"net/http/httptest"
"testing"
"github.com/stretchr/testify/assert"
)
func TestMaxBytesHandler(t *testing.T) {
maxb := MaxBytesHandler(10)
handler := maxb(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}))
req := httptest.NewRequest(http.MethodPost, "http://localhost",
bytes.NewBufferString("123456789012345"))
resp := httptest.NewRecorder()
handler.ServeHTTP(resp, req)
assert.Equal(t, http.StatusRequestEntityTooLarge, resp.Code)
req = httptest.NewRequest(http.MethodPost, "http://localhost", bytes.NewBufferString("12345"))
resp = httptest.NewRecorder()
handler.ServeHTTP(resp, req)
assert.Equal(t, http.StatusOK, resp.Code)
}
func TestMaxBytesHandlerNoLimit(t *testing.T) {
maxb := MaxBytesHandler(-1)
handler := maxb(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}))
req := httptest.NewRequest(http.MethodPost, "http://localhost",
bytes.NewBufferString("123456789012345"))
resp := httptest.NewRecorder()
handler.ServeHTTP(resp, req)
assert.Equal(t, http.StatusOK, resp.Code)
}

View File

@@ -1,37 +0,0 @@
package httphandler
import (
"net/http"
"zero/core/httplog"
"zero/core/logx"
"zero/core/syncx"
)
func MaxConns(n int) func(http.Handler) http.Handler {
if n <= 0 {
return func(next http.Handler) http.Handler {
return next
}
}
return func(next http.Handler) http.Handler {
latchLimiter := syncx.NewLimit(n)
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if latchLimiter.TryBorrow() {
defer func() {
if err := latchLimiter.Return(); err != nil {
logx.Error(err)
}
}()
next.ServeHTTP(w, r)
} else {
httplog.Errorf(r, "Concurrent connections over %d, rejected with code %d",
n, http.StatusServiceUnavailable)
w.WriteHeader(http.StatusServiceUnavailable)
}
})
}
}

View File

@@ -1,80 +0,0 @@
package httphandler
import (
"io/ioutil"
"log"
"net/http"
"net/http/httptest"
"sync"
"testing"
"zero/core/lang"
"github.com/stretchr/testify/assert"
)
const conns = 4
func init() {
log.SetOutput(ioutil.Discard)
}
func TestMaxConnsHandler(t *testing.T) {
var waitGroup sync.WaitGroup
waitGroup.Add(conns)
done := make(chan lang.PlaceholderType)
defer close(done)
maxConns := MaxConns(conns)
handler := maxConns(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
waitGroup.Done()
<-done
}))
for i := 0; i < conns; i++ {
go func() {
req := httptest.NewRequest(http.MethodGet, "http://localhost", nil)
handler.ServeHTTP(httptest.NewRecorder(), req)
}()
}
waitGroup.Wait()
req := httptest.NewRequest(http.MethodGet, "http://localhost", nil)
resp := httptest.NewRecorder()
handler.ServeHTTP(resp, req)
assert.Equal(t, http.StatusServiceUnavailable, resp.Code)
}
func TestWithoutMaxConnsHandler(t *testing.T) {
const (
key = "block"
value = "1"
)
var waitGroup sync.WaitGroup
waitGroup.Add(conns)
done := make(chan lang.PlaceholderType)
defer close(done)
maxConns := MaxConns(0)
handler := maxConns(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
val := r.Header.Get(key)
if val == value {
waitGroup.Done()
<-done
}
}))
for i := 0; i < conns; i++ {
go func() {
req := httptest.NewRequest(http.MethodGet, "http://localhost", nil)
req.Header.Set(key, value)
handler.ServeHTTP(httptest.NewRecorder(), req)
}()
}
waitGroup.Wait()
req := httptest.NewRequest(http.MethodGet, "http://localhost", nil)
resp := httptest.NewRecorder()
handler.ServeHTTP(resp, req)
assert.Equal(t, http.StatusOK, resp.Code)
}

View File

@@ -1,23 +0,0 @@
package httphandler
import (
"net/http"
"zero/core/stat"
"zero/core/timex"
)
func MetricHandler(metrics *stat.Metrics) func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
startTime := timex.Now()
defer func() {
metrics.Add(stat.Task{
Duration: timex.Since(startTime),
})
}()
next.ServeHTTP(w, r)
})
}
}

View File

@@ -1,24 +0,0 @@
package httphandler
import (
"net/http"
"net/http/httptest"
"testing"
"zero/core/stat"
"github.com/stretchr/testify/assert"
)
func TestMetricHandler(t *testing.T) {
metrics := stat.NewMetrics("unit-test")
metricHandler := MetricHandler(metrics)
handler := metricHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
}))
req := httptest.NewRequest(http.MethodGet, "http://localhost", nil)
resp := httptest.NewRecorder()
handler.ServeHTTP(resp, req)
assert.Equal(t, http.StatusOK, resp.Code)
}

View File

@@ -1,47 +0,0 @@
package httphandler
import (
"net/http"
"strconv"
"time"
"zero/core/httphandler/internal"
"zero/core/metric"
"zero/core/timex"
)
const serverNamespace = "http_server"
var (
metricServerReqDur = metric.NewHistogramVec(&metric.HistogramVecOpts{
Namespace: serverNamespace,
Subsystem: "requests",
Name: "duration_ms",
Help: "http server requests duration(ms).",
Labels: []string{"path"},
Buckets: []float64{5, 10, 25, 50, 100, 250, 500, 1000},
})
metricServerReqCodeTotal = metric.NewCounterVec(&metric.CounterVecOpts{
Namespace: serverNamespace,
Subsystem: "requests",
Name: "code_total",
Help: "http server requests error count.",
Labels: []string{"path", "code"},
})
)
func PromMetricHandler(path string) func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
startTime := timex.Now()
cw := &internal.WithCodeResponseWriter{Writer: w}
defer func() {
metricServerReqDur.Observe(int64(timex.Since(startTime)/time.Millisecond), path)
metricServerReqCodeTotal.Inc(path, strconv.Itoa(cw.Code))
}()
next.ServeHTTP(cw, r)
})
}
}

View File

@@ -1,21 +0,0 @@
package httphandler
import (
"net/http"
"net/http/httptest"
"testing"
"github.com/stretchr/testify/assert"
)
func TestPromMetricHandler(t *testing.T) {
promMetricHandler := PromMetricHandler("/user/login")
handler := promMetricHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
}))
req := httptest.NewRequest(http.MethodGet, "http://localhost", nil)
resp := httptest.NewRecorder()
handler.ServeHTTP(resp, req)
assert.Equal(t, http.StatusOK, resp.Code)
}

View File

@@ -1,22 +0,0 @@
package httphandler
import (
"fmt"
"net/http"
"runtime/debug"
"zero/core/httplog"
)
func RecoverHandler(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
defer func() {
if result := recover(); result != nil {
httplog.Error(r, fmt.Sprintf("%v\n%s", result, debug.Stack()))
w.WriteHeader(http.StatusInternalServerError)
}
}()
next.ServeHTTP(w, r)
})
}

View File

@@ -1,36 +0,0 @@
package httphandler
import (
"io/ioutil"
"log"
"net/http"
"net/http/httptest"
"testing"
"github.com/stretchr/testify/assert"
)
func init() {
log.SetOutput(ioutil.Discard)
}
func TestWithPanic(t *testing.T) {
handler := RecoverHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
panic("whatever")
}))
req := httptest.NewRequest(http.MethodGet, "http://localhost", nil)
resp := httptest.NewRecorder()
handler.ServeHTTP(resp, req)
assert.Equal(t, http.StatusInternalServerError, resp.Code)
}
func TestWithoutPanic(t *testing.T) {
handler := RecoverHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
}))
req := httptest.NewRequest(http.MethodGet, "http://localhost", nil)
resp := httptest.NewRecorder()
handler.ServeHTTP(resp, req)
assert.Equal(t, http.StatusOK, resp.Code)
}

View File

@@ -1,63 +0,0 @@
package httphandler
import (
"net/http"
"sync"
"zero/core/httphandler/internal"
"zero/core/httpx"
"zero/core/load"
"zero/core/logx"
"zero/core/stat"
)
const serviceType = "api"
var (
sheddingStat *load.SheddingStat
lock sync.Mutex
)
func SheddingHandler(shedder load.Shedder, metrics *stat.Metrics) func(http.Handler) http.Handler {
if shedder == nil {
return func(next http.Handler) http.Handler {
return next
}
}
ensureSheddingStat()
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
sheddingStat.IncrementTotal()
promise, err := shedder.Allow()
if err != nil {
metrics.AddDrop()
sheddingStat.IncrementDrop()
logx.Errorf("[http] dropped, %s - %s - %s",
r.RequestURI, httpx.GetRemoteAddr(r), r.UserAgent())
w.WriteHeader(http.StatusServiceUnavailable)
return
}
cw := &internal.WithCodeResponseWriter{Writer: w}
defer func() {
if cw.Code == http.StatusServiceUnavailable {
promise.Fail()
} else {
sheddingStat.IncrementPass()
promise.Pass()
}
}()
next.ServeHTTP(cw, r)
})
}
}
func ensureSheddingStat() {
lock.Lock()
if sheddingStat == nil {
sheddingStat = load.NewSheddingStat(serviceType)
}
lock.Unlock()
}

View File

@@ -1,105 +0,0 @@
package httphandler
import (
"io/ioutil"
"log"
"net/http"
"net/http/httptest"
"testing"
"zero/core/load"
"zero/core/stat"
"github.com/stretchr/testify/assert"
)
func init() {
log.SetOutput(ioutil.Discard)
}
func TestSheddingHandlerAccept(t *testing.T) {
metrics := stat.NewMetrics("unit-test")
shedder := mockShedder{
allow: true,
}
sheddingHandler := SheddingHandler(shedder, metrics)
handler := sheddingHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("X-Test", "test")
_, err := w.Write([]byte("content"))
assert.Nil(t, err)
}))
req := httptest.NewRequest(http.MethodGet, "http://localhost", nil)
req.Header.Set("X-Test", "test")
resp := httptest.NewRecorder()
handler.ServeHTTP(resp, req)
assert.Equal(t, http.StatusOK, resp.Code)
assert.Equal(t, "test", resp.Header().Get("X-Test"))
assert.Equal(t, "content", resp.Body.String())
}
func TestSheddingHandlerFail(t *testing.T) {
metrics := stat.NewMetrics("unit-test")
shedder := mockShedder{
allow: true,
}
sheddingHandler := SheddingHandler(shedder, metrics)
handler := sheddingHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusServiceUnavailable)
}))
req := httptest.NewRequest(http.MethodGet, "http://localhost", nil)
resp := httptest.NewRecorder()
handler.ServeHTTP(resp, req)
assert.Equal(t, http.StatusServiceUnavailable, resp.Code)
}
func TestSheddingHandlerReject(t *testing.T) {
metrics := stat.NewMetrics("unit-test")
shedder := mockShedder{
allow: false,
}
sheddingHandler := SheddingHandler(shedder, metrics)
handler := sheddingHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
}))
req := httptest.NewRequest(http.MethodGet, "http://localhost", nil)
resp := httptest.NewRecorder()
handler.ServeHTTP(resp, req)
assert.Equal(t, http.StatusServiceUnavailable, resp.Code)
}
func TestSheddingHandlerNoShedding(t *testing.T) {
metrics := stat.NewMetrics("unit-test")
sheddingHandler := SheddingHandler(nil, metrics)
handler := sheddingHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
}))
req := httptest.NewRequest(http.MethodGet, "http://localhost", nil)
resp := httptest.NewRecorder()
handler.ServeHTTP(resp, req)
assert.Equal(t, http.StatusOK, resp.Code)
}
type mockShedder struct {
allow bool
}
func (s mockShedder) Allow() (load.Promise, error) {
if s.allow {
return mockPromise{}, nil
} else {
return nil, load.ErrServiceOverloaded
}
}
type mockPromise struct {
}
func (p mockPromise) Pass() {
}
func (p mockPromise) Fail() {
}

View File

@@ -1,18 +0,0 @@
package httphandler
import (
"net/http"
"time"
)
const reason = "Request Timeout"
func TimeoutHandler(duration time.Duration) func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
if duration > 0 {
return http.TimeoutHandler(next, duration, reason)
} else {
return next
}
}
}

View File

@@ -1,52 +0,0 @@
package httphandler
import (
"io/ioutil"
"log"
"net/http"
"net/http/httptest"
"testing"
"time"
"github.com/stretchr/testify/assert"
)
func init() {
log.SetOutput(ioutil.Discard)
}
func TestTimeout(t *testing.T) {
timeoutHandler := TimeoutHandler(time.Millisecond)
handler := timeoutHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
time.Sleep(time.Minute)
}))
req := httptest.NewRequest(http.MethodGet, "http://localhost", nil)
resp := httptest.NewRecorder()
handler.ServeHTTP(resp, req)
assert.Equal(t, http.StatusServiceUnavailable, resp.Code)
}
func TestWithinTimeout(t *testing.T) {
timeoutHandler := TimeoutHandler(time.Second)
handler := timeoutHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
time.Sleep(time.Millisecond)
}))
req := httptest.NewRequest(http.MethodGet, "http://localhost", nil)
resp := httptest.NewRecorder()
handler.ServeHTTP(resp, req)
assert.Equal(t, http.StatusOK, resp.Code)
}
func TestWithoutTimeout(t *testing.T) {
timeoutHandler := TimeoutHandler(0)
handler := timeoutHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
time.Sleep(100 * time.Millisecond)
}))
req := httptest.NewRequest(http.MethodGet, "http://localhost", nil)
resp := httptest.NewRecorder()
handler.ServeHTTP(resp, req)
assert.Equal(t, http.StatusOK, resp.Code)
}

View File

@@ -1,25 +0,0 @@
package httphandler
import (
"net/http"
"zero/core/logx"
"zero/core/sysx"
"zero/core/trace"
)
func TracingHandler(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
carrier, err := trace.Extract(trace.HttpFormat, r.Header)
// ErrInvalidCarrier means no trace id was set in http header
if err != nil && err != trace.ErrInvalidCarrier {
logx.Error(err)
}
ctx, span := trace.StartServerSpan(r.Context(), carrier, sysx.Hostname(), r.RequestURI)
defer span.Finish()
r = r.WithContext(ctx)
next.ServeHTTP(w, r)
})
}

View File

@@ -1,25 +0,0 @@
package httphandler
import (
"net/http"
"net/http/httptest"
"testing"
"zero/core/trace/tracespec"
"github.com/stretchr/testify/assert"
)
func TestTracingHandler(t *testing.T) {
req := httptest.NewRequest(http.MethodGet, "http://localhost", nil)
req.Header.Set("X-Trace-ID", "theid")
handler := TracingHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
span, ok := r.Context().Value(tracespec.TracingKey).(tracespec.Trace)
assert.True(t, ok)
assert.Equal(t, "theid", span.TraceId())
}))
resp := httptest.NewRecorder()
handler.ServeHTTP(resp, req)
assert.Equal(t, http.StatusOK, resp.Code)
}

View File

@@ -1,84 +0,0 @@
package httplog
import (
"bytes"
"fmt"
"net/http"
"sync"
"zero/core/httpx"
"zero/core/logx"
)
const LogContext = "request_logs"
type LogCollector struct {
Messages []string
lock sync.Mutex
}
func (lc *LogCollector) Append(msg string) {
lc.lock.Lock()
lc.Messages = append(lc.Messages, msg)
lc.lock.Unlock()
}
func (lc *LogCollector) Flush() string {
var buffer bytes.Buffer
start := true
for _, message := range lc.takeAll() {
if start {
start = false
} else {
buffer.WriteByte('\n')
}
buffer.WriteString(message)
}
return buffer.String()
}
func (lc *LogCollector) takeAll() []string {
lc.lock.Lock()
messages := lc.Messages
lc.Messages = nil
lc.lock.Unlock()
return messages
}
func Error(r *http.Request, v ...interface{}) {
logx.ErrorCaller(1, format(r, v...))
}
func Errorf(r *http.Request, format string, v ...interface{}) {
logx.ErrorCaller(1, formatf(r, format, v...))
}
func Info(r *http.Request, v ...interface{}) {
appendLog(r, format(r, v...))
}
func Infof(r *http.Request, format string, v ...interface{}) {
appendLog(r, formatf(r, format, v...))
}
func appendLog(r *http.Request, message string) {
logs := r.Context().Value(LogContext)
if logs != nil {
logs.(*LogCollector).Append(message)
}
}
func format(r *http.Request, v ...interface{}) string {
return formatWithReq(r, fmt.Sprint(v...))
}
func formatf(r *http.Request, format string, v ...interface{}) string {
return formatWithReq(r, fmt.Sprintf(format, v...))
}
func formatWithReq(r *http.Request, v string) string {
return fmt.Sprintf("(%s - %s) %s", r.RequestURI, httpx.GetRemoteAddr(r), v)
}

View File

@@ -1,38 +0,0 @@
package httplog
import (
"context"
"log"
"net/http"
"net/http/httptest"
"strings"
"testing"
"github.com/stretchr/testify/assert"
)
func TestInfo(t *testing.T) {
collector := new(LogCollector)
req := httptest.NewRequest(http.MethodGet, "http://localhost", nil)
req = req.WithContext(context.WithValue(req.Context(), LogContext, collector))
Info(req, "first")
Infof(req, "second %s", "third")
val := collector.Flush()
assert.True(t, strings.Contains(val, "first"))
assert.True(t, strings.Contains(val, "second"))
assert.True(t, strings.Contains(val, "third"))
assert.True(t, strings.Contains(val, "\n"))
}
func TestError(t *testing.T) {
var writer strings.Builder
log.SetOutput(&writer)
req := httptest.NewRequest(http.MethodGet, "http://localhost", nil)
Error(req, "first")
Errorf(req, "second %s", "third")
val := writer.String()
assert.True(t, strings.Contains(val, "first"))
assert.True(t, strings.Contains(val, "second"))
assert.True(t, strings.Contains(val, "third"))
assert.True(t, strings.Contains(val, "\n"))
}

View File

@@ -1,115 +0,0 @@
package httprouter
import (
"context"
"net/http"
"path"
"strings"
"zero/core/search"
)
const (
allowHeader = "Allow"
allowMethodSeparator = ", "
pathVars = "pathVars"
)
type PatRouter struct {
trees map[string]*search.Tree
notFound http.Handler
}
func NewPatRouter() Router {
return &PatRouter{
trees: make(map[string]*search.Tree),
}
}
func (pr *PatRouter) Handle(method, reqPath string, handler http.Handler) error {
if !validMethod(method) {
return ErrInvalidMethod
}
if len(reqPath) == 0 || reqPath[0] != '/' {
return ErrInvalidPath
}
cleanPath := path.Clean(reqPath)
if tree, ok := pr.trees[method]; ok {
return tree.Add(cleanPath, handler)
} else {
tree = search.NewTree()
pr.trees[method] = tree
return tree.Add(cleanPath, handler)
}
}
func (pr *PatRouter) ServeHTTP(w http.ResponseWriter, r *http.Request) {
reqPath := path.Clean(r.URL.Path)
if tree, ok := pr.trees[r.Method]; ok {
if result, ok := tree.Search(reqPath); ok {
if len(result.Params) > 0 {
r = r.WithContext(context.WithValue(r.Context(), pathVars, result.Params))
}
result.Item.(http.Handler).ServeHTTP(w, r)
return
}
}
if allow, ok := pr.methodNotAllowed(r.Method, reqPath); ok {
w.Header().Set(allowHeader, allow)
w.WriteHeader(http.StatusMethodNotAllowed)
} else {
pr.handleNotFound(w, r)
}
}
func (pr *PatRouter) SetNotFoundHandler(handler http.Handler) {
pr.notFound = handler
}
func (pr *PatRouter) handleNotFound(w http.ResponseWriter, r *http.Request) {
if pr.notFound != nil {
pr.notFound.ServeHTTP(w, r)
} else {
http.NotFound(w, r)
}
}
func (pr *PatRouter) methodNotAllowed(method, path string) (string, bool) {
var allows []string
for treeMethod, tree := range pr.trees {
if treeMethod == method {
continue
}
_, ok := tree.Search(path)
if ok {
allows = append(allows, treeMethod)
}
}
if len(allows) > 0 {
return strings.Join(allows, allowMethodSeparator), true
} else {
return "", false
}
}
func Vars(r *http.Request) map[string]string {
vars, ok := r.Context().Value(pathVars).(map[string]string)
if ok {
return vars
}
return nil
}
func validMethod(method string) bool {
return method == http.MethodDelete || method == http.MethodGet ||
method == http.MethodHead || method == http.MethodOptions ||
method == http.MethodPatch || method == http.MethodPost ||
method == http.MethodPut
}

View File

@@ -1,120 +0,0 @@
package httprouter
import (
"net/http"
"testing"
"github.com/stretchr/testify/assert"
)
type mockedResponseWriter struct {
code int
}
func (m *mockedResponseWriter) Header() http.Header {
return http.Header{}
}
func (m *mockedResponseWriter) Write(p []byte) (int, error) {
return len(p), nil
}
func (m *mockedResponseWriter) WriteHeader(code int) {
m.code = code
}
func TestPatRouterHandleErrors(t *testing.T) {
tests := []struct {
method string
path string
err error
}{
{"FAKE", "", ErrInvalidMethod},
{"GET", "", ErrInvalidPath},
}
for _, test := range tests {
t.Run(test.method, func(t *testing.T) {
router := NewPatRouter()
err := router.Handle(test.method, test.path, nil)
assert.Error(t, ErrInvalidMethod, err)
})
}
}
func TestPatRouterNotFound(t *testing.T) {
var notFound bool
router := NewPatRouter()
router.SetNotFoundHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
notFound = true
}))
router.Handle(http.MethodGet, "/a/b", nil)
r, _ := http.NewRequest(http.MethodGet, "/b/c", nil)
w := new(mockedResponseWriter)
router.ServeHTTP(w, r)
assert.True(t, notFound)
}
func TestPatRouter(t *testing.T) {
tests := []struct {
method string
path string
expect bool
code int
err error
}{
// we don't explicitly set status code, framework will do it.
{http.MethodGet, "/a/b", true, 0, nil},
{http.MethodGet, "/a/b/", true, 0, nil},
{http.MethodGet, "/a/b?a=b", true, 0, nil},
{http.MethodGet, "/a/b/?a=b", true, 0, nil},
{http.MethodGet, "/a/b/c?a=b", true, 0, nil},
{http.MethodGet, "/b/d", false, http.StatusNotFound, nil},
}
for _, test := range tests {
t.Run(test.method+":"+test.path, func(t *testing.T) {
routed := false
router := NewPatRouter()
err := router.Handle(test.method, "/a/:b", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
routed = true
assert.Equal(t, 1, len(Vars(r)))
}))
assert.Nil(t, err)
err = router.Handle(test.method, "/a/b/c", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
routed = true
assert.Nil(t, Vars(r))
}))
assert.Nil(t, err)
err = router.Handle(test.method, "/b/c", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
routed = true
}))
assert.Nil(t, err)
w := new(mockedResponseWriter)
r, _ := http.NewRequest(test.method, test.path, nil)
router.ServeHTTP(w, r)
assert.Equal(t, test.expect, routed)
assert.Equal(t, test.code, w.code)
if test.code == 0 {
r, _ = http.NewRequest(http.MethodPut, test.path, nil)
router.ServeHTTP(w, r)
assert.Equal(t, http.StatusMethodNotAllowed, w.code)
}
})
}
}
func BenchmarkPatRouter(b *testing.B) {
b.ReportAllocs()
router := NewPatRouter()
router.Handle(http.MethodGet, "/api/:user/:name", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
}))
w := &mockedResponseWriter{}
r, _ := http.NewRequest(http.MethodGet, "/api/a/b", nil)
for i := 0; i < b.N; i++ {
router.ServeHTTP(w, r)
}
}

View File

@@ -1,24 +0,0 @@
package httprouter
import (
"errors"
"net/http"
)
var (
ErrInvalidMethod = errors.New("not a valid http method")
ErrInvalidPath = errors.New("path must begin with '/'")
)
type (
Route struct {
Path string
Handler http.HandlerFunc
}
Router interface {
http.Handler
Handle(method string, path string, handler http.Handler) error
SetNotFoundHandler(handler http.Handler)
}
)

View File

@@ -1,122 +0,0 @@
package httpsecurity
import (
"net/http"
"sync"
"sync/atomic"
"time"
"zero/core/timex"
"github.com/dgrijalva/jwt-go"
"github.com/dgrijalva/jwt-go/request"
)
const claimHistoryResetDuration = time.Hour * 24
type (
ParseOption func(parser *TokenParser)
TokenParser struct {
resetTime time.Duration
resetDuration time.Duration
history sync.Map
}
)
func NewTokenParser(opts ...ParseOption) *TokenParser {
parser := &TokenParser{
resetTime: timex.Now(),
resetDuration: claimHistoryResetDuration,
}
for _, opt := range opts {
opt(parser)
}
return parser
}
func (tp *TokenParser) ParseToken(r *http.Request, secret, prevSecret string) (*jwt.Token, error) {
var token *jwt.Token
var err error
if len(prevSecret) > 0 {
count := tp.loadCount(secret)
prevCount := tp.loadCount(prevSecret)
var first, second string
if count > prevCount {
first = secret
second = prevSecret
} else {
first = prevSecret
second = secret
}
token, err = tp.doParseToken(r, first)
if err != nil {
token, err = tp.doParseToken(r, second)
if err != nil {
return nil, err
} else {
tp.incrementCount(second)
}
} else {
tp.incrementCount(first)
}
} else {
token, err = tp.doParseToken(r, secret)
if err != nil {
return nil, err
}
}
return token, nil
}
func (tp *TokenParser) doParseToken(r *http.Request, secret string) (*jwt.Token, error) {
return request.ParseFromRequest(r, request.AuthorizationHeaderExtractor,
func(token *jwt.Token) (interface{}, error) {
return []byte(secret), nil
}, request.WithParser(newParser()))
}
func (tp *TokenParser) incrementCount(secret string) {
now := timex.Now()
if tp.resetTime+tp.resetDuration < now {
tp.history.Range(func(key, value interface{}) bool {
tp.history.Delete(key)
return true
})
}
value, ok := tp.history.Load(secret)
if ok {
atomic.AddUint64(value.(*uint64), 1)
} else {
var count uint64 = 1
tp.history.Store(secret, &count)
}
}
func (tp *TokenParser) loadCount(secret string) uint64 {
value, ok := tp.history.Load(secret)
if ok {
return *value.(*uint64)
}
return 0
}
func WithResetDuration(duration time.Duration) ParseOption {
return func(parser *TokenParser) {
parser.resetDuration = duration
}
}
func newParser() *jwt.Parser {
return &jwt.Parser{
UseJSONNumber: true,
}
}

View File

@@ -1,87 +0,0 @@
package httpsecurity
import (
"net/http"
"net/http/httptest"
"testing"
"time"
"github.com/dgrijalva/jwt-go"
"github.com/stretchr/testify/assert"
"zero/core/timex"
)
func TestTokenParser(t *testing.T) {
const (
key = "14F17379-EB8F-411B-8F12-6929002DCA76"
prevKey = "B63F477D-BBA3-4E52-96D3-C0034C27694A"
)
keys := []struct {
key string
prevKey string
}{
{
key,
prevKey,
},
{
key,
"",
},
}
for _, pair := range keys {
req := httptest.NewRequest(http.MethodGet, "http://localhost", nil)
token, err := buildToken(key, map[string]interface{}{
"key": "value",
}, 3600)
assert.Nil(t, err)
req.Header.Set("Authorization", "Bearer "+token)
parser := NewTokenParser(WithResetDuration(time.Minute))
tok, err := parser.ParseToken(req, pair.key, pair.prevKey)
assert.Nil(t, err)
assert.Equal(t, "value", tok.Claims.(jwt.MapClaims)["key"])
}
}
func TestTokenParser_Expired(t *testing.T) {
const (
key = "14F17379-EB8F-411B-8F12-6929002DCA76"
prevKey = "B63F477D-BBA3-4E52-96D3-C0034C27694A"
)
req := httptest.NewRequest(http.MethodGet, "http://localhost", nil)
token, err := buildToken(key, map[string]interface{}{
"key": "value",
}, 3600)
assert.Nil(t, err)
req.Header.Set("Authorization", "Bearer "+token)
parser := NewTokenParser(WithResetDuration(time.Second))
tok, err := parser.ParseToken(req, key, prevKey)
assert.Nil(t, err)
assert.Equal(t, "value", tok.Claims.(jwt.MapClaims)["key"])
tok, err = parser.ParseToken(req, key, prevKey)
assert.Nil(t, err)
assert.Equal(t, "value", tok.Claims.(jwt.MapClaims)["key"])
parser.resetTime = timex.Now() - time.Hour
tok, err = parser.ParseToken(req, key, prevKey)
assert.Nil(t, err)
assert.Equal(t, "value", tok.Claims.(jwt.MapClaims)["key"])
}
func buildToken(secretKey string, payloads map[string]interface{}, seconds int64) (string, error) {
now := time.Now().Unix()
claims := make(jwt.MapClaims)
claims["exp"] = now + seconds
claims["iat"] = now
for k, v := range payloads {
claims[k] = v
}
token := jwt.New(jwt.SigningMethodHS256)
token.Claims = claims
return token.SignedString([]byte(secretKey))
}

View File

@@ -1,40 +0,0 @@
package httpserver
import (
"crypto/tls"
"fmt"
"net/http"
)
func StartHttp(host string, port int, handler http.Handler) error {
addr := fmt.Sprintf("%s:%d", host, port)
server := buildHttpServer(addr, handler)
return StartServer(server)
}
func StartHttps(host string, port int, certFile, keyFile string, handler http.Handler) error {
addr := fmt.Sprintf("%s:%d", host, port)
if server, err := buildHttpsServer(addr, handler, certFile, keyFile); err != nil {
return err
} else {
return StartServer(server)
}
}
func buildHttpServer(addr string, handler http.Handler) *http.Server {
return &http.Server{Addr: addr, Handler: handler}
}
func buildHttpsServer(addr string, handler http.Handler, certFile, keyFile string) (*http.Server, error) {
cert, err := tls.LoadX509KeyPair(certFile, keyFile)
if err != nil {
return nil, err
}
config := tls.Config{Certificates: []tls.Certificate{cert}}
return &http.Server{
Addr: addr,
Handler: handler,
TLSConfig: &config,
}, nil
}

View File

@@ -1,16 +0,0 @@
package httpserver
import (
"context"
"net/http"
"zero/core/proc"
)
func StartServer(srv *http.Server) error {
proc.AddWrapUpListener(func() {
srv.Shutdown(context.Background())
})
return srv.ListenAndServe()
}

View File

@@ -1,19 +0,0 @@
package httpx
const (
ApplicationJson = "application/json"
ContentEncoding = "Content-Encoding"
ContentSecurity = "X-Content-Security"
ContentType = "Content-Type"
KeyField = "key"
SecretField = "secret"
TypeField = "type"
CryptionType = 1
)
const (
CodeSignaturePass = iota
CodeSignatureInvalidHeader
CodeSignatureWrongTime
CodeSignatureInvalidToken
)

View File

@@ -1,124 +0,0 @@
package httpx
import (
"errors"
"io"
"net/http"
"strings"
"zero/core/httprouter"
"zero/core/mapping"
)
const (
multipartFormData = "multipart/form-data"
xForwardFor = "X-Forward-For"
formKey = "form"
pathKey = "path"
emptyJson = "{}"
maxMemory = 32 << 20 // 32MB
maxBodyLen = 8 << 20 // 8MB
separator = ";"
tokensInAttribute = 2
)
var (
ErrBodylessRequest = errors.New("not a POST|PUT|PATCH request")
formUnmarshaler = mapping.NewUnmarshaler(formKey, mapping.WithStringValues())
pathUnmarshaler = mapping.NewUnmarshaler(pathKey, mapping.WithStringValues())
)
// Returns the peer address, supports X-Forward-For
func GetRemoteAddr(r *http.Request) string {
v := r.Header.Get(xForwardFor)
if len(v) > 0 {
return v
}
return r.RemoteAddr
}
func Parse(r *http.Request, v interface{}) error {
if err := ParsePath(r, v); err != nil {
return err
}
if err := ParseForm(r, v); err != nil {
return err
}
return ParseJsonBody(r, v)
}
// Parses the form request.
func ParseForm(r *http.Request, v interface{}) error {
if strings.Index(r.Header.Get(ContentType), multipartFormData) != -1 {
if err := r.ParseMultipartForm(maxMemory); err != nil {
return err
}
} else {
if err := r.ParseForm(); err != nil {
return err
}
}
params := make(map[string]interface{}, len(r.Form))
for name := range r.Form {
formValue := r.Form.Get(name)
if len(formValue) > 0 {
params[name] = formValue
}
}
return formUnmarshaler.Unmarshal(params, v)
}
func ParseHeader(headerValue string) map[string]string {
ret := make(map[string]string)
fields := strings.Split(headerValue, separator)
for _, field := range fields {
field = strings.TrimSpace(field)
if len(field) == 0 {
continue
}
kv := strings.SplitN(field, "=", tokensInAttribute)
if len(kv) != tokensInAttribute {
continue
}
ret[kv[0]] = kv[1]
}
return ret
}
// Parses the post request which contains json in body.
func ParseJsonBody(r *http.Request, v interface{}) error {
var reader io.Reader
if withJsonBody(r) {
reader = io.LimitReader(r.Body, maxBodyLen)
} else {
reader = strings.NewReader(emptyJson)
}
return mapping.UnmarshalJsonReader(reader, v)
}
// Parses the symbols reside in url path.
// Like http://localhost/bag/:name
func ParsePath(r *http.Request, v interface{}) error {
vars := httprouter.Vars(r)
m := make(map[string]interface{}, len(vars))
for k, v := range vars {
m[k] = v
}
return pathUnmarshaler.Unmarshal(m, v)
}
func withJsonBody(r *http.Request) bool {
return r.ContentLength > 0 && strings.Index(r.Header.Get(ContentType), ApplicationJson) != -1
}

File diff suppressed because it is too large Load Diff

View File

@@ -1,29 +0,0 @@
package httpx
import (
"encoding/json"
"net/http"
"zero/core/logx"
)
func OkJson(w http.ResponseWriter, v interface{}) {
WriteJson(w, http.StatusOK, v)
}
func WriteJson(w http.ResponseWriter, code int, v interface{}) {
w.Header().Set(ContentType, ApplicationJson)
w.WriteHeader(code)
if bs, err := json.Marshal(v); err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
} else if n, err := w.Write(bs); err != nil {
// http.ErrHandlerTimeout has been handled by http.TimeoutHandler,
// so it's ignored here.
if err != http.ErrHandlerTimeout {
logx.Errorf("write response failed, error: %s", err)
}
} else if n < len(bs) {
logx.Errorf("actual bytes: %d, written bytes: %d", len(bs), n)
}
}

View File

@@ -1,78 +0,0 @@
package httpx
import (
"net/http"
"strings"
"testing"
"zero/core/logx"
"github.com/stretchr/testify/assert"
)
type message struct {
Name string `json:"name"`
}
func init() {
logx.Disable()
}
func TestOkJson(t *testing.T) {
w := tracedResponseWriter{
headers: make(map[string][]string),
}
msg := message{Name: "anyone"}
OkJson(&w, msg)
assert.Equal(t, http.StatusOK, w.code)
assert.Equal(t, "{\"name\":\"anyone\"}", w.builder.String())
}
func TestWriteJsonTimeout(t *testing.T) {
// only log it and ignore
w := tracedResponseWriter{
headers: make(map[string][]string),
timeout: true,
}
msg := message{Name: "anyone"}
WriteJson(&w, http.StatusOK, msg)
assert.Equal(t, http.StatusOK, w.code)
}
func TestWriteJsonLessWritten(t *testing.T) {
w := tracedResponseWriter{
headers: make(map[string][]string),
lessWritten: true,
}
msg := message{Name: "anyone"}
WriteJson(&w, http.StatusOK, msg)
assert.Equal(t, http.StatusOK, w.code)
}
type tracedResponseWriter struct {
headers map[string][]string
builder strings.Builder
code int
lessWritten bool
timeout bool
}
func (w *tracedResponseWriter) Header() http.Header {
return w.headers
}
func (w *tracedResponseWriter) Write(bytes []byte) (n int, err error) {
if w.timeout {
return 0, http.ErrHandlerTimeout
}
n, err = w.builder.Write(bytes)
if w.lessWritten {
n -= 1
}
return
}
func (w *tracedResponseWriter) WriteHeader(code int) {
w.code = code
}