rename ngin to rest
This commit is contained in:
21
rest/internal/context/params.go
Normal file
21
rest/internal/context/params.go
Normal file
@@ -0,0 +1,21 @@
|
||||
package context
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
const pathVars = "pathVars"
|
||||
|
||||
func Vars(r *http.Request) map[string]string {
|
||||
vars, ok := r.Context().Value(pathVars).(map[string]string)
|
||||
if ok {
|
||||
return vars
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func WithPathVars(r *http.Request, params map[string]string) *http.Request {
|
||||
return r.WithContext(context.WithValue(r.Context(), pathVars, params))
|
||||
}
|
||||
83
rest/internal/log.go
Normal file
83
rest/internal/log.go
Normal file
@@ -0,0 +1,83 @@
|
||||
package internal
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"sync"
|
||||
|
||||
"zero/core/logx"
|
||||
)
|
||||
|
||||
const LogContext = "request_logs"
|
||||
|
||||
type LogCollector struct {
|
||||
Messages []string
|
||||
lock sync.Mutex
|
||||
}
|
||||
|
||||
func (lc *LogCollector) Append(msg string) {
|
||||
lc.lock.Lock()
|
||||
lc.Messages = append(lc.Messages, msg)
|
||||
lc.lock.Unlock()
|
||||
}
|
||||
|
||||
func (lc *LogCollector) Flush() string {
|
||||
var buffer bytes.Buffer
|
||||
|
||||
start := true
|
||||
for _, message := range lc.takeAll() {
|
||||
if start {
|
||||
start = false
|
||||
} else {
|
||||
buffer.WriteByte('\n')
|
||||
}
|
||||
buffer.WriteString(message)
|
||||
}
|
||||
|
||||
return buffer.String()
|
||||
}
|
||||
|
||||
func (lc *LogCollector) takeAll() []string {
|
||||
lc.lock.Lock()
|
||||
messages := lc.Messages
|
||||
lc.Messages = nil
|
||||
lc.lock.Unlock()
|
||||
|
||||
return messages
|
||||
}
|
||||
|
||||
func Error(r *http.Request, v ...interface{}) {
|
||||
logx.ErrorCaller(1, format(r, v...))
|
||||
}
|
||||
|
||||
func Errorf(r *http.Request, format string, v ...interface{}) {
|
||||
logx.ErrorCaller(1, formatf(r, format, v...))
|
||||
}
|
||||
|
||||
func Info(r *http.Request, v ...interface{}) {
|
||||
appendLog(r, format(r, v...))
|
||||
}
|
||||
|
||||
func Infof(r *http.Request, format string, v ...interface{}) {
|
||||
appendLog(r, formatf(r, format, v...))
|
||||
}
|
||||
|
||||
func appendLog(r *http.Request, message string) {
|
||||
logs := r.Context().Value(LogContext)
|
||||
if logs != nil {
|
||||
logs.(*LogCollector).Append(message)
|
||||
}
|
||||
}
|
||||
|
||||
func format(r *http.Request, v ...interface{}) string {
|
||||
return formatWithReq(r, fmt.Sprint(v...))
|
||||
}
|
||||
|
||||
func formatf(r *http.Request, format string, v ...interface{}) string {
|
||||
return formatWithReq(r, fmt.Sprintf(format, v...))
|
||||
}
|
||||
|
||||
func formatWithReq(r *http.Request, v string) string {
|
||||
return fmt.Sprintf("(%s - %s) %s", r.RequestURI, GetRemoteAddr(r), v)
|
||||
}
|
||||
38
rest/internal/log_test.go
Normal file
38
rest/internal/log_test.go
Normal file
@@ -0,0 +1,38 @@
|
||||
package internal
|
||||
|
||||
import (
|
||||
"context"
|
||||
"log"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestInfo(t *testing.T) {
|
||||
collector := new(LogCollector)
|
||||
req := httptest.NewRequest(http.MethodGet, "http://localhost", nil)
|
||||
req = req.WithContext(context.WithValue(req.Context(), LogContext, collector))
|
||||
Info(req, "first")
|
||||
Infof(req, "second %s", "third")
|
||||
val := collector.Flush()
|
||||
assert.True(t, strings.Contains(val, "first"))
|
||||
assert.True(t, strings.Contains(val, "second"))
|
||||
assert.True(t, strings.Contains(val, "third"))
|
||||
assert.True(t, strings.Contains(val, "\n"))
|
||||
}
|
||||
|
||||
func TestError(t *testing.T) {
|
||||
var writer strings.Builder
|
||||
log.SetOutput(&writer)
|
||||
req := httptest.NewRequest(http.MethodGet, "http://localhost", nil)
|
||||
Error(req, "first")
|
||||
Errorf(req, "second %s", "third")
|
||||
val := writer.String()
|
||||
assert.True(t, strings.Contains(val, "first"))
|
||||
assert.True(t, strings.Contains(val, "second"))
|
||||
assert.True(t, strings.Contains(val, "third"))
|
||||
assert.True(t, strings.Contains(val, "\n"))
|
||||
}
|
||||
105
rest/internal/router/patrouter.go
Normal file
105
rest/internal/router/patrouter.go
Normal file
@@ -0,0 +1,105 @@
|
||||
package router
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"path"
|
||||
"strings"
|
||||
|
||||
"zero/core/search"
|
||||
"zero/rest/internal/context"
|
||||
)
|
||||
|
||||
const (
|
||||
allowHeader = "Allow"
|
||||
allowMethodSeparator = ", "
|
||||
)
|
||||
|
||||
type PatRouter struct {
|
||||
trees map[string]*search.Tree
|
||||
notFound http.Handler
|
||||
}
|
||||
|
||||
func NewPatRouter() Router {
|
||||
return &PatRouter{
|
||||
trees: make(map[string]*search.Tree),
|
||||
}
|
||||
}
|
||||
|
||||
func (pr *PatRouter) Handle(method, reqPath string, handler http.Handler) error {
|
||||
if !validMethod(method) {
|
||||
return ErrInvalidMethod
|
||||
}
|
||||
|
||||
if len(reqPath) == 0 || reqPath[0] != '/' {
|
||||
return ErrInvalidPath
|
||||
}
|
||||
|
||||
cleanPath := path.Clean(reqPath)
|
||||
if tree, ok := pr.trees[method]; ok {
|
||||
return tree.Add(cleanPath, handler)
|
||||
} else {
|
||||
tree = search.NewTree()
|
||||
pr.trees[method] = tree
|
||||
return tree.Add(cleanPath, handler)
|
||||
}
|
||||
}
|
||||
|
||||
func (pr *PatRouter) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
reqPath := path.Clean(r.URL.Path)
|
||||
if tree, ok := pr.trees[r.Method]; ok {
|
||||
if result, ok := tree.Search(reqPath); ok {
|
||||
if len(result.Params) > 0 {
|
||||
r = context.WithPathVars(r, result.Params)
|
||||
}
|
||||
result.Item.(http.Handler).ServeHTTP(w, r)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
if allow, ok := pr.methodNotAllowed(r.Method, reqPath); ok {
|
||||
w.Header().Set(allowHeader, allow)
|
||||
w.WriteHeader(http.StatusMethodNotAllowed)
|
||||
} else {
|
||||
pr.handleNotFound(w, r)
|
||||
}
|
||||
}
|
||||
|
||||
func (pr *PatRouter) SetNotFoundHandler(handler http.Handler) {
|
||||
pr.notFound = handler
|
||||
}
|
||||
|
||||
func (pr *PatRouter) handleNotFound(w http.ResponseWriter, r *http.Request) {
|
||||
if pr.notFound != nil {
|
||||
pr.notFound.ServeHTTP(w, r)
|
||||
} else {
|
||||
http.NotFound(w, r)
|
||||
}
|
||||
}
|
||||
|
||||
func (pr *PatRouter) methodNotAllowed(method, path string) (string, bool) {
|
||||
var allows []string
|
||||
|
||||
for treeMethod, tree := range pr.trees {
|
||||
if treeMethod == method {
|
||||
continue
|
||||
}
|
||||
|
||||
_, ok := tree.Search(path)
|
||||
if ok {
|
||||
allows = append(allows, treeMethod)
|
||||
}
|
||||
}
|
||||
|
||||
if len(allows) > 0 {
|
||||
return strings.Join(allows, allowMethodSeparator), true
|
||||
} else {
|
||||
return "", false
|
||||
}
|
||||
}
|
||||
|
||||
func validMethod(method string) bool {
|
||||
return method == http.MethodDelete || method == http.MethodGet ||
|
||||
method == http.MethodHead || method == http.MethodOptions ||
|
||||
method == http.MethodPatch || method == http.MethodPost ||
|
||||
method == http.MethodPut
|
||||
}
|
||||
122
rest/internal/router/patrouter_test.go
Normal file
122
rest/internal/router/patrouter_test.go
Normal file
@@ -0,0 +1,122 @@
|
||||
package router
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"zero/rest/internal/context"
|
||||
)
|
||||
|
||||
type mockedResponseWriter struct {
|
||||
code int
|
||||
}
|
||||
|
||||
func (m *mockedResponseWriter) Header() http.Header {
|
||||
return http.Header{}
|
||||
}
|
||||
|
||||
func (m *mockedResponseWriter) Write(p []byte) (int, error) {
|
||||
return len(p), nil
|
||||
}
|
||||
|
||||
func (m *mockedResponseWriter) WriteHeader(code int) {
|
||||
m.code = code
|
||||
}
|
||||
|
||||
func TestPatRouterHandleErrors(t *testing.T) {
|
||||
tests := []struct {
|
||||
method string
|
||||
path string
|
||||
err error
|
||||
}{
|
||||
{"FAKE", "", ErrInvalidMethod},
|
||||
{"GET", "", ErrInvalidPath},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
t.Run(test.method, func(t *testing.T) {
|
||||
router := NewPatRouter()
|
||||
err := router.Handle(test.method, test.path, nil)
|
||||
assert.Error(t, ErrInvalidMethod, err)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestPatRouterNotFound(t *testing.T) {
|
||||
var notFound bool
|
||||
router := NewPatRouter()
|
||||
router.SetNotFoundHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
notFound = true
|
||||
}))
|
||||
router.Handle(http.MethodGet, "/a/b", nil)
|
||||
r, _ := http.NewRequest(http.MethodGet, "/b/c", nil)
|
||||
w := new(mockedResponseWriter)
|
||||
router.ServeHTTP(w, r)
|
||||
assert.True(t, notFound)
|
||||
}
|
||||
|
||||
func TestPatRouter(t *testing.T) {
|
||||
tests := []struct {
|
||||
method string
|
||||
path string
|
||||
expect bool
|
||||
code int
|
||||
err error
|
||||
}{
|
||||
// we don't explicitly set status code, framework will do it.
|
||||
{http.MethodGet, "/a/b", true, 0, nil},
|
||||
{http.MethodGet, "/a/b/", true, 0, nil},
|
||||
{http.MethodGet, "/a/b?a=b", true, 0, nil},
|
||||
{http.MethodGet, "/a/b/?a=b", true, 0, nil},
|
||||
{http.MethodGet, "/a/b/c?a=b", true, 0, nil},
|
||||
{http.MethodGet, "/b/d", false, http.StatusNotFound, nil},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
t.Run(test.method+":"+test.path, func(t *testing.T) {
|
||||
routed := false
|
||||
router := NewPatRouter()
|
||||
err := router.Handle(test.method, "/a/:b", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
routed = true
|
||||
assert.Equal(t, 1, len(context.Vars(r)))
|
||||
}))
|
||||
assert.Nil(t, err)
|
||||
err = router.Handle(test.method, "/a/b/c", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
routed = true
|
||||
assert.Nil(t, context.Vars(r))
|
||||
}))
|
||||
assert.Nil(t, err)
|
||||
err = router.Handle(test.method, "/b/c", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
routed = true
|
||||
}))
|
||||
assert.Nil(t, err)
|
||||
|
||||
w := new(mockedResponseWriter)
|
||||
r, _ := http.NewRequest(test.method, test.path, nil)
|
||||
router.ServeHTTP(w, r)
|
||||
assert.Equal(t, test.expect, routed)
|
||||
assert.Equal(t, test.code, w.code)
|
||||
|
||||
if test.code == 0 {
|
||||
r, _ = http.NewRequest(http.MethodPut, test.path, nil)
|
||||
router.ServeHTTP(w, r)
|
||||
assert.Equal(t, http.StatusMethodNotAllowed, w.code)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkPatRouter(b *testing.B) {
|
||||
b.ReportAllocs()
|
||||
|
||||
router := NewPatRouter()
|
||||
router.Handle(http.MethodGet, "/api/:user/:name", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
}))
|
||||
w := &mockedResponseWriter{}
|
||||
r, _ := http.NewRequest(http.MethodGet, "/api/a/b", nil)
|
||||
for i := 0; i < b.N; i++ {
|
||||
router.ServeHTTP(w, r)
|
||||
}
|
||||
}
|
||||
24
rest/internal/router/router.go
Normal file
24
rest/internal/router/router.go
Normal file
@@ -0,0 +1,24 @@
|
||||
package router
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrInvalidMethod = errors.New("not a valid http method")
|
||||
ErrInvalidPath = errors.New("path must begin with '/'")
|
||||
)
|
||||
|
||||
type (
|
||||
Route struct {
|
||||
Path string
|
||||
Handler http.HandlerFunc
|
||||
}
|
||||
|
||||
Router interface {
|
||||
http.Handler
|
||||
Handle(method string, path string, handler http.Handler) error
|
||||
SetNotFoundHandler(handler http.Handler)
|
||||
}
|
||||
)
|
||||
147
rest/internal/security/contentsecurity.go
Normal file
147
rest/internal/security/contentsecurity.go
Normal file
@@ -0,0 +1,147 @@
|
||||
package security
|
||||
|
||||
import (
|
||||
"crypto/sha256"
|
||||
"encoding/base64"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"zero/core/codec"
|
||||
"zero/core/iox"
|
||||
"zero/core/logx"
|
||||
"zero/rest/httpx"
|
||||
)
|
||||
|
||||
const (
|
||||
requestUriHeader = "X-Request-Uri"
|
||||
signatureField = "signature"
|
||||
timeField = "time"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrInvalidContentType = errors.New("invalid content type")
|
||||
ErrInvalidHeader = errors.New("invalid X-Content-Security header")
|
||||
ErrInvalidKey = errors.New("invalid key")
|
||||
ErrInvalidPublicKey = errors.New("invalid public key")
|
||||
ErrInvalidSecret = errors.New("invalid secret")
|
||||
)
|
||||
|
||||
type ContentSecurityHeader struct {
|
||||
Key []byte
|
||||
Timestamp string
|
||||
ContentType int
|
||||
Signature string
|
||||
}
|
||||
|
||||
func (h *ContentSecurityHeader) Encrypted() bool {
|
||||
return h.ContentType == httpx.CryptionType
|
||||
}
|
||||
|
||||
func ParseContentSecurity(decrypters map[string]codec.RsaDecrypter, r *http.Request) (
|
||||
*ContentSecurityHeader, error) {
|
||||
contentSecurity := r.Header.Get(httpx.ContentSecurity)
|
||||
attrs := httpx.ParseHeader(contentSecurity)
|
||||
fingerprint := attrs[httpx.KeyField]
|
||||
secret := attrs[httpx.SecretField]
|
||||
signature := attrs[signatureField]
|
||||
|
||||
if len(fingerprint) == 0 || len(secret) == 0 || len(signature) == 0 {
|
||||
return nil, ErrInvalidHeader
|
||||
}
|
||||
|
||||
decrypter, ok := decrypters[fingerprint]
|
||||
if !ok {
|
||||
return nil, ErrInvalidPublicKey
|
||||
}
|
||||
|
||||
decryptedSecret, err := decrypter.DecryptBase64(secret)
|
||||
if err != nil {
|
||||
return nil, ErrInvalidSecret
|
||||
}
|
||||
|
||||
attrs = httpx.ParseHeader(string(decryptedSecret))
|
||||
base64Key := attrs[httpx.KeyField]
|
||||
timestamp := attrs[timeField]
|
||||
contentType := attrs[httpx.TypeField]
|
||||
|
||||
key, err := base64.StdEncoding.DecodeString(base64Key)
|
||||
if err != nil {
|
||||
return nil, ErrInvalidKey
|
||||
}
|
||||
|
||||
cType, err := strconv.Atoi(contentType)
|
||||
if err != nil {
|
||||
return nil, ErrInvalidContentType
|
||||
}
|
||||
|
||||
return &ContentSecurityHeader{
|
||||
Key: key,
|
||||
Timestamp: timestamp,
|
||||
ContentType: cType,
|
||||
Signature: signature,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func VerifySignature(r *http.Request, securityHeader *ContentSecurityHeader, tolerance time.Duration) int {
|
||||
seconds, err := strconv.ParseInt(securityHeader.Timestamp, 10, 64)
|
||||
if err != nil {
|
||||
return httpx.CodeSignatureInvalidHeader
|
||||
}
|
||||
|
||||
now := time.Now().Unix()
|
||||
toleranceSeconds := int64(tolerance.Seconds())
|
||||
if seconds+toleranceSeconds < now || now+toleranceSeconds < seconds {
|
||||
return httpx.CodeSignatureWrongTime
|
||||
}
|
||||
|
||||
reqPath, reqQuery := getPathQuery(r)
|
||||
signContent := strings.Join([]string{
|
||||
securityHeader.Timestamp,
|
||||
r.Method,
|
||||
reqPath,
|
||||
reqQuery,
|
||||
computeBodySignature(r),
|
||||
}, "\n")
|
||||
actualSignature := codec.HmacBase64(securityHeader.Key, signContent)
|
||||
|
||||
passed := securityHeader.Signature == actualSignature
|
||||
if !passed {
|
||||
logx.Infof("signature different, expect: %s, actual: %s",
|
||||
securityHeader.Signature, actualSignature)
|
||||
}
|
||||
|
||||
if passed {
|
||||
return httpx.CodeSignaturePass
|
||||
} else {
|
||||
return httpx.CodeSignatureInvalidToken
|
||||
}
|
||||
}
|
||||
|
||||
func computeBodySignature(r *http.Request) string {
|
||||
var dup io.ReadCloser
|
||||
r.Body, dup = iox.DupReadCloser(r.Body)
|
||||
sha := sha256.New()
|
||||
io.Copy(sha, r.Body)
|
||||
r.Body = dup
|
||||
return fmt.Sprintf("%x", sha.Sum(nil))
|
||||
}
|
||||
|
||||
func getPathQuery(r *http.Request) (string, string) {
|
||||
requestUri := r.Header.Get(requestUriHeader)
|
||||
if len(requestUri) == 0 {
|
||||
return r.URL.Path, r.URL.RawQuery
|
||||
}
|
||||
|
||||
uri, err := url.Parse(requestUri)
|
||||
if err != nil {
|
||||
return r.URL.Path, r.URL.RawQuery
|
||||
}
|
||||
|
||||
return uri.Path, uri.RawQuery
|
||||
}
|
||||
21
rest/internal/security/withcoderesponsewriter.go
Normal file
21
rest/internal/security/withcoderesponsewriter.go
Normal file
@@ -0,0 +1,21 @@
|
||||
package security
|
||||
|
||||
import "net/http"
|
||||
|
||||
type WithCodeResponseWriter struct {
|
||||
Writer http.ResponseWriter
|
||||
Code int
|
||||
}
|
||||
|
||||
func (w *WithCodeResponseWriter) Header() http.Header {
|
||||
return w.Writer.Header()
|
||||
}
|
||||
|
||||
func (w *WithCodeResponseWriter) Write(bytes []byte) (int, error) {
|
||||
return w.Writer.Write(bytes)
|
||||
}
|
||||
|
||||
func (w *WithCodeResponseWriter) WriteHeader(code int) {
|
||||
w.Writer.WriteHeader(code)
|
||||
w.Code = code
|
||||
}
|
||||
40
rest/internal/server.go
Normal file
40
rest/internal/server.go
Normal file
@@ -0,0 +1,40 @@
|
||||
package internal
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"fmt"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
func StartHttp(host string, port int, handler http.Handler) error {
|
||||
addr := fmt.Sprintf("%s:%d", host, port)
|
||||
server := buildHttpServer(addr, handler)
|
||||
return StartServer(server)
|
||||
}
|
||||
|
||||
func StartHttps(host string, port int, certFile, keyFile string, handler http.Handler) error {
|
||||
addr := fmt.Sprintf("%s:%d", host, port)
|
||||
if server, err := buildHttpsServer(addr, handler, certFile, keyFile); err != nil {
|
||||
return err
|
||||
} else {
|
||||
return StartServer(server)
|
||||
}
|
||||
}
|
||||
|
||||
func buildHttpServer(addr string, handler http.Handler) *http.Server {
|
||||
return &http.Server{Addr: addr, Handler: handler}
|
||||
}
|
||||
|
||||
func buildHttpsServer(addr string, handler http.Handler, certFile, keyFile string) (*http.Server, error) {
|
||||
cert, err := tls.LoadX509KeyPair(certFile, keyFile)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
config := tls.Config{Certificates: []tls.Certificate{cert}}
|
||||
return &http.Server{
|
||||
Addr: addr,
|
||||
Handler: handler,
|
||||
TLSConfig: &config,
|
||||
}, nil
|
||||
}
|
||||
16
rest/internal/starter.go
Normal file
16
rest/internal/starter.go
Normal file
@@ -0,0 +1,16 @@
|
||||
package internal
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
|
||||
"zero/core/proc"
|
||||
)
|
||||
|
||||
func StartServer(srv *http.Server) error {
|
||||
proc.AddWrapUpListener(func() {
|
||||
srv.Shutdown(context.Background())
|
||||
})
|
||||
|
||||
return srv.ListenAndServe()
|
||||
}
|
||||
122
rest/internal/tokenparser.go
Normal file
122
rest/internal/tokenparser.go
Normal file
@@ -0,0 +1,122 @@
|
||||
package internal
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"zero/core/timex"
|
||||
|
||||
"github.com/dgrijalva/jwt-go"
|
||||
"github.com/dgrijalva/jwt-go/request"
|
||||
)
|
||||
|
||||
const claimHistoryResetDuration = time.Hour * 24
|
||||
|
||||
type (
|
||||
ParseOption func(parser *TokenParser)
|
||||
|
||||
TokenParser struct {
|
||||
resetTime time.Duration
|
||||
resetDuration time.Duration
|
||||
history sync.Map
|
||||
}
|
||||
)
|
||||
|
||||
func NewTokenParser(opts ...ParseOption) *TokenParser {
|
||||
parser := &TokenParser{
|
||||
resetTime: timex.Now(),
|
||||
resetDuration: claimHistoryResetDuration,
|
||||
}
|
||||
|
||||
for _, opt := range opts {
|
||||
opt(parser)
|
||||
}
|
||||
|
||||
return parser
|
||||
}
|
||||
|
||||
func (tp *TokenParser) ParseToken(r *http.Request, secret, prevSecret string) (*jwt.Token, error) {
|
||||
var token *jwt.Token
|
||||
var err error
|
||||
|
||||
if len(prevSecret) > 0 {
|
||||
count := tp.loadCount(secret)
|
||||
prevCount := tp.loadCount(prevSecret)
|
||||
|
||||
var first, second string
|
||||
if count > prevCount {
|
||||
first = secret
|
||||
second = prevSecret
|
||||
} else {
|
||||
first = prevSecret
|
||||
second = secret
|
||||
}
|
||||
|
||||
token, err = tp.doParseToken(r, first)
|
||||
if err != nil {
|
||||
token, err = tp.doParseToken(r, second)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
} else {
|
||||
tp.incrementCount(second)
|
||||
}
|
||||
} else {
|
||||
tp.incrementCount(first)
|
||||
}
|
||||
} else {
|
||||
token, err = tp.doParseToken(r, secret)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
return token, nil
|
||||
}
|
||||
|
||||
func (tp *TokenParser) doParseToken(r *http.Request, secret string) (*jwt.Token, error) {
|
||||
return request.ParseFromRequest(r, request.AuthorizationHeaderExtractor,
|
||||
func(token *jwt.Token) (interface{}, error) {
|
||||
return []byte(secret), nil
|
||||
}, request.WithParser(newParser()))
|
||||
}
|
||||
|
||||
func (tp *TokenParser) incrementCount(secret string) {
|
||||
now := timex.Now()
|
||||
if tp.resetTime+tp.resetDuration < now {
|
||||
tp.history.Range(func(key, value interface{}) bool {
|
||||
tp.history.Delete(key)
|
||||
return true
|
||||
})
|
||||
}
|
||||
|
||||
value, ok := tp.history.Load(secret)
|
||||
if ok {
|
||||
atomic.AddUint64(value.(*uint64), 1)
|
||||
} else {
|
||||
var count uint64 = 1
|
||||
tp.history.Store(secret, &count)
|
||||
}
|
||||
}
|
||||
|
||||
func (tp *TokenParser) loadCount(secret string) uint64 {
|
||||
value, ok := tp.history.Load(secret)
|
||||
if ok {
|
||||
return *value.(*uint64)
|
||||
}
|
||||
|
||||
return 0
|
||||
}
|
||||
|
||||
func WithResetDuration(duration time.Duration) ParseOption {
|
||||
return func(parser *TokenParser) {
|
||||
parser.resetDuration = duration
|
||||
}
|
||||
}
|
||||
|
||||
func newParser() *jwt.Parser {
|
||||
return &jwt.Parser{
|
||||
UseJSONNumber: true,
|
||||
}
|
||||
}
|
||||
87
rest/internal/tokenparser_test.go
Normal file
87
rest/internal/tokenparser_test.go
Normal file
@@ -0,0 +1,87 @@
|
||||
package internal
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/dgrijalva/jwt-go"
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"zero/core/timex"
|
||||
)
|
||||
|
||||
func TestTokenParser(t *testing.T) {
|
||||
const (
|
||||
key = "14F17379-EB8F-411B-8F12-6929002DCA76"
|
||||
prevKey = "B63F477D-BBA3-4E52-96D3-C0034C27694A"
|
||||
)
|
||||
keys := []struct {
|
||||
key string
|
||||
prevKey string
|
||||
}{
|
||||
{
|
||||
key,
|
||||
prevKey,
|
||||
},
|
||||
{
|
||||
key,
|
||||
"",
|
||||
},
|
||||
}
|
||||
|
||||
for _, pair := range keys {
|
||||
req := httptest.NewRequest(http.MethodGet, "http://localhost", nil)
|
||||
token, err := buildToken(key, map[string]interface{}{
|
||||
"key": "value",
|
||||
}, 3600)
|
||||
assert.Nil(t, err)
|
||||
req.Header.Set("Authorization", "Bearer "+token)
|
||||
|
||||
parser := NewTokenParser(WithResetDuration(time.Minute))
|
||||
tok, err := parser.ParseToken(req, pair.key, pair.prevKey)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, "value", tok.Claims.(jwt.MapClaims)["key"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestTokenParser_Expired(t *testing.T) {
|
||||
const (
|
||||
key = "14F17379-EB8F-411B-8F12-6929002DCA76"
|
||||
prevKey = "B63F477D-BBA3-4E52-96D3-C0034C27694A"
|
||||
)
|
||||
req := httptest.NewRequest(http.MethodGet, "http://localhost", nil)
|
||||
token, err := buildToken(key, map[string]interface{}{
|
||||
"key": "value",
|
||||
}, 3600)
|
||||
assert.Nil(t, err)
|
||||
req.Header.Set("Authorization", "Bearer "+token)
|
||||
|
||||
parser := NewTokenParser(WithResetDuration(time.Second))
|
||||
tok, err := parser.ParseToken(req, key, prevKey)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, "value", tok.Claims.(jwt.MapClaims)["key"])
|
||||
tok, err = parser.ParseToken(req, key, prevKey)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, "value", tok.Claims.(jwt.MapClaims)["key"])
|
||||
parser.resetTime = timex.Now() - time.Hour
|
||||
tok, err = parser.ParseToken(req, key, prevKey)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, "value", tok.Claims.(jwt.MapClaims)["key"])
|
||||
}
|
||||
|
||||
func buildToken(secretKey string, payloads map[string]interface{}, seconds int64) (string, error) {
|
||||
now := time.Now().Unix()
|
||||
claims := make(jwt.MapClaims)
|
||||
claims["exp"] = now + seconds
|
||||
claims["iat"] = now
|
||||
for k, v := range payloads {
|
||||
claims[k] = v
|
||||
}
|
||||
|
||||
token := jwt.New(jwt.SigningMethodHS256)
|
||||
token.Claims = claims
|
||||
|
||||
return token.SignedString([]byte(secretKey))
|
||||
}
|
||||
14
rest/internal/util.go
Normal file
14
rest/internal/util.go
Normal file
@@ -0,0 +1,14 @@
|
||||
package internal
|
||||
|
||||
import "net/http"
|
||||
|
||||
const xForwardFor = "X-Forward-For"
|
||||
|
||||
// Returns the peer address, supports X-Forward-For
|
||||
func GetRemoteAddr(r *http.Request) string {
|
||||
v := r.Header.Get(xForwardFor)
|
||||
if len(v) > 0 {
|
||||
return v
|
||||
}
|
||||
return r.RemoteAddr
|
||||
}
|
||||
19
rest/internal/util_test.go
Normal file
19
rest/internal/util_test.go
Normal file
@@ -0,0 +1,19 @@
|
||||
package internal
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestGetRemoteAddr(t *testing.T) {
|
||||
host := "8.8.8.8"
|
||||
r, err := http.NewRequest(http.MethodGet, "/", strings.NewReader(""))
|
||||
assert.Nil(t, err)
|
||||
|
||||
r.Header.Set(xForwardFor, host)
|
||||
assert.Equal(t, host, GetRemoteAddr(r))
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user