rename ngin to rest

This commit is contained in:
kevin
2020-07-31 11:14:48 +08:00
parent e133ffd820
commit 0897f60c5d
78 changed files with 118 additions and 111 deletions

View File

@@ -0,0 +1,21 @@
package context
import (
"context"
"net/http"
)
const pathVars = "pathVars"
func Vars(r *http.Request) map[string]string {
vars, ok := r.Context().Value(pathVars).(map[string]string)
if ok {
return vars
}
return nil
}
func WithPathVars(r *http.Request, params map[string]string) *http.Request {
return r.WithContext(context.WithValue(r.Context(), pathVars, params))
}

83
rest/internal/log.go Normal file
View File

@@ -0,0 +1,83 @@
package internal
import (
"bytes"
"fmt"
"net/http"
"sync"
"zero/core/logx"
)
const LogContext = "request_logs"
type LogCollector struct {
Messages []string
lock sync.Mutex
}
func (lc *LogCollector) Append(msg string) {
lc.lock.Lock()
lc.Messages = append(lc.Messages, msg)
lc.lock.Unlock()
}
func (lc *LogCollector) Flush() string {
var buffer bytes.Buffer
start := true
for _, message := range lc.takeAll() {
if start {
start = false
} else {
buffer.WriteByte('\n')
}
buffer.WriteString(message)
}
return buffer.String()
}
func (lc *LogCollector) takeAll() []string {
lc.lock.Lock()
messages := lc.Messages
lc.Messages = nil
lc.lock.Unlock()
return messages
}
func Error(r *http.Request, v ...interface{}) {
logx.ErrorCaller(1, format(r, v...))
}
func Errorf(r *http.Request, format string, v ...interface{}) {
logx.ErrorCaller(1, formatf(r, format, v...))
}
func Info(r *http.Request, v ...interface{}) {
appendLog(r, format(r, v...))
}
func Infof(r *http.Request, format string, v ...interface{}) {
appendLog(r, formatf(r, format, v...))
}
func appendLog(r *http.Request, message string) {
logs := r.Context().Value(LogContext)
if logs != nil {
logs.(*LogCollector).Append(message)
}
}
func format(r *http.Request, v ...interface{}) string {
return formatWithReq(r, fmt.Sprint(v...))
}
func formatf(r *http.Request, format string, v ...interface{}) string {
return formatWithReq(r, fmt.Sprintf(format, v...))
}
func formatWithReq(r *http.Request, v string) string {
return fmt.Sprintf("(%s - %s) %s", r.RequestURI, GetRemoteAddr(r), v)
}

38
rest/internal/log_test.go Normal file
View File

@@ -0,0 +1,38 @@
package internal
import (
"context"
"log"
"net/http"
"net/http/httptest"
"strings"
"testing"
"github.com/stretchr/testify/assert"
)
func TestInfo(t *testing.T) {
collector := new(LogCollector)
req := httptest.NewRequest(http.MethodGet, "http://localhost", nil)
req = req.WithContext(context.WithValue(req.Context(), LogContext, collector))
Info(req, "first")
Infof(req, "second %s", "third")
val := collector.Flush()
assert.True(t, strings.Contains(val, "first"))
assert.True(t, strings.Contains(val, "second"))
assert.True(t, strings.Contains(val, "third"))
assert.True(t, strings.Contains(val, "\n"))
}
func TestError(t *testing.T) {
var writer strings.Builder
log.SetOutput(&writer)
req := httptest.NewRequest(http.MethodGet, "http://localhost", nil)
Error(req, "first")
Errorf(req, "second %s", "third")
val := writer.String()
assert.True(t, strings.Contains(val, "first"))
assert.True(t, strings.Contains(val, "second"))
assert.True(t, strings.Contains(val, "third"))
assert.True(t, strings.Contains(val, "\n"))
}

View File

@@ -0,0 +1,105 @@
package router
import (
"net/http"
"path"
"strings"
"zero/core/search"
"zero/rest/internal/context"
)
const (
allowHeader = "Allow"
allowMethodSeparator = ", "
)
type PatRouter struct {
trees map[string]*search.Tree
notFound http.Handler
}
func NewPatRouter() Router {
return &PatRouter{
trees: make(map[string]*search.Tree),
}
}
func (pr *PatRouter) Handle(method, reqPath string, handler http.Handler) error {
if !validMethod(method) {
return ErrInvalidMethod
}
if len(reqPath) == 0 || reqPath[0] != '/' {
return ErrInvalidPath
}
cleanPath := path.Clean(reqPath)
if tree, ok := pr.trees[method]; ok {
return tree.Add(cleanPath, handler)
} else {
tree = search.NewTree()
pr.trees[method] = tree
return tree.Add(cleanPath, handler)
}
}
func (pr *PatRouter) ServeHTTP(w http.ResponseWriter, r *http.Request) {
reqPath := path.Clean(r.URL.Path)
if tree, ok := pr.trees[r.Method]; ok {
if result, ok := tree.Search(reqPath); ok {
if len(result.Params) > 0 {
r = context.WithPathVars(r, result.Params)
}
result.Item.(http.Handler).ServeHTTP(w, r)
return
}
}
if allow, ok := pr.methodNotAllowed(r.Method, reqPath); ok {
w.Header().Set(allowHeader, allow)
w.WriteHeader(http.StatusMethodNotAllowed)
} else {
pr.handleNotFound(w, r)
}
}
func (pr *PatRouter) SetNotFoundHandler(handler http.Handler) {
pr.notFound = handler
}
func (pr *PatRouter) handleNotFound(w http.ResponseWriter, r *http.Request) {
if pr.notFound != nil {
pr.notFound.ServeHTTP(w, r)
} else {
http.NotFound(w, r)
}
}
func (pr *PatRouter) methodNotAllowed(method, path string) (string, bool) {
var allows []string
for treeMethod, tree := range pr.trees {
if treeMethod == method {
continue
}
_, ok := tree.Search(path)
if ok {
allows = append(allows, treeMethod)
}
}
if len(allows) > 0 {
return strings.Join(allows, allowMethodSeparator), true
} else {
return "", false
}
}
func validMethod(method string) bool {
return method == http.MethodDelete || method == http.MethodGet ||
method == http.MethodHead || method == http.MethodOptions ||
method == http.MethodPatch || method == http.MethodPost ||
method == http.MethodPut
}

View File

@@ -0,0 +1,122 @@
package router
import (
"net/http"
"testing"
"github.com/stretchr/testify/assert"
"zero/rest/internal/context"
)
type mockedResponseWriter struct {
code int
}
func (m *mockedResponseWriter) Header() http.Header {
return http.Header{}
}
func (m *mockedResponseWriter) Write(p []byte) (int, error) {
return len(p), nil
}
func (m *mockedResponseWriter) WriteHeader(code int) {
m.code = code
}
func TestPatRouterHandleErrors(t *testing.T) {
tests := []struct {
method string
path string
err error
}{
{"FAKE", "", ErrInvalidMethod},
{"GET", "", ErrInvalidPath},
}
for _, test := range tests {
t.Run(test.method, func(t *testing.T) {
router := NewPatRouter()
err := router.Handle(test.method, test.path, nil)
assert.Error(t, ErrInvalidMethod, err)
})
}
}
func TestPatRouterNotFound(t *testing.T) {
var notFound bool
router := NewPatRouter()
router.SetNotFoundHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
notFound = true
}))
router.Handle(http.MethodGet, "/a/b", nil)
r, _ := http.NewRequest(http.MethodGet, "/b/c", nil)
w := new(mockedResponseWriter)
router.ServeHTTP(w, r)
assert.True(t, notFound)
}
func TestPatRouter(t *testing.T) {
tests := []struct {
method string
path string
expect bool
code int
err error
}{
// we don't explicitly set status code, framework will do it.
{http.MethodGet, "/a/b", true, 0, nil},
{http.MethodGet, "/a/b/", true, 0, nil},
{http.MethodGet, "/a/b?a=b", true, 0, nil},
{http.MethodGet, "/a/b/?a=b", true, 0, nil},
{http.MethodGet, "/a/b/c?a=b", true, 0, nil},
{http.MethodGet, "/b/d", false, http.StatusNotFound, nil},
}
for _, test := range tests {
t.Run(test.method+":"+test.path, func(t *testing.T) {
routed := false
router := NewPatRouter()
err := router.Handle(test.method, "/a/:b", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
routed = true
assert.Equal(t, 1, len(context.Vars(r)))
}))
assert.Nil(t, err)
err = router.Handle(test.method, "/a/b/c", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
routed = true
assert.Nil(t, context.Vars(r))
}))
assert.Nil(t, err)
err = router.Handle(test.method, "/b/c", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
routed = true
}))
assert.Nil(t, err)
w := new(mockedResponseWriter)
r, _ := http.NewRequest(test.method, test.path, nil)
router.ServeHTTP(w, r)
assert.Equal(t, test.expect, routed)
assert.Equal(t, test.code, w.code)
if test.code == 0 {
r, _ = http.NewRequest(http.MethodPut, test.path, nil)
router.ServeHTTP(w, r)
assert.Equal(t, http.StatusMethodNotAllowed, w.code)
}
})
}
}
func BenchmarkPatRouter(b *testing.B) {
b.ReportAllocs()
router := NewPatRouter()
router.Handle(http.MethodGet, "/api/:user/:name", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
}))
w := &mockedResponseWriter{}
r, _ := http.NewRequest(http.MethodGet, "/api/a/b", nil)
for i := 0; i < b.N; i++ {
router.ServeHTTP(w, r)
}
}

View File

@@ -0,0 +1,24 @@
package router
import (
"errors"
"net/http"
)
var (
ErrInvalidMethod = errors.New("not a valid http method")
ErrInvalidPath = errors.New("path must begin with '/'")
)
type (
Route struct {
Path string
Handler http.HandlerFunc
}
Router interface {
http.Handler
Handle(method string, path string, handler http.Handler) error
SetNotFoundHandler(handler http.Handler)
}
)

View File

@@ -0,0 +1,147 @@
package security
import (
"crypto/sha256"
"encoding/base64"
"errors"
"fmt"
"io"
"net/http"
"net/url"
"strconv"
"strings"
"time"
"zero/core/codec"
"zero/core/iox"
"zero/core/logx"
"zero/rest/httpx"
)
const (
requestUriHeader = "X-Request-Uri"
signatureField = "signature"
timeField = "time"
)
var (
ErrInvalidContentType = errors.New("invalid content type")
ErrInvalidHeader = errors.New("invalid X-Content-Security header")
ErrInvalidKey = errors.New("invalid key")
ErrInvalidPublicKey = errors.New("invalid public key")
ErrInvalidSecret = errors.New("invalid secret")
)
type ContentSecurityHeader struct {
Key []byte
Timestamp string
ContentType int
Signature string
}
func (h *ContentSecurityHeader) Encrypted() bool {
return h.ContentType == httpx.CryptionType
}
func ParseContentSecurity(decrypters map[string]codec.RsaDecrypter, r *http.Request) (
*ContentSecurityHeader, error) {
contentSecurity := r.Header.Get(httpx.ContentSecurity)
attrs := httpx.ParseHeader(contentSecurity)
fingerprint := attrs[httpx.KeyField]
secret := attrs[httpx.SecretField]
signature := attrs[signatureField]
if len(fingerprint) == 0 || len(secret) == 0 || len(signature) == 0 {
return nil, ErrInvalidHeader
}
decrypter, ok := decrypters[fingerprint]
if !ok {
return nil, ErrInvalidPublicKey
}
decryptedSecret, err := decrypter.DecryptBase64(secret)
if err != nil {
return nil, ErrInvalidSecret
}
attrs = httpx.ParseHeader(string(decryptedSecret))
base64Key := attrs[httpx.KeyField]
timestamp := attrs[timeField]
contentType := attrs[httpx.TypeField]
key, err := base64.StdEncoding.DecodeString(base64Key)
if err != nil {
return nil, ErrInvalidKey
}
cType, err := strconv.Atoi(contentType)
if err != nil {
return nil, ErrInvalidContentType
}
return &ContentSecurityHeader{
Key: key,
Timestamp: timestamp,
ContentType: cType,
Signature: signature,
}, nil
}
func VerifySignature(r *http.Request, securityHeader *ContentSecurityHeader, tolerance time.Duration) int {
seconds, err := strconv.ParseInt(securityHeader.Timestamp, 10, 64)
if err != nil {
return httpx.CodeSignatureInvalidHeader
}
now := time.Now().Unix()
toleranceSeconds := int64(tolerance.Seconds())
if seconds+toleranceSeconds < now || now+toleranceSeconds < seconds {
return httpx.CodeSignatureWrongTime
}
reqPath, reqQuery := getPathQuery(r)
signContent := strings.Join([]string{
securityHeader.Timestamp,
r.Method,
reqPath,
reqQuery,
computeBodySignature(r),
}, "\n")
actualSignature := codec.HmacBase64(securityHeader.Key, signContent)
passed := securityHeader.Signature == actualSignature
if !passed {
logx.Infof("signature different, expect: %s, actual: %s",
securityHeader.Signature, actualSignature)
}
if passed {
return httpx.CodeSignaturePass
} else {
return httpx.CodeSignatureInvalidToken
}
}
func computeBodySignature(r *http.Request) string {
var dup io.ReadCloser
r.Body, dup = iox.DupReadCloser(r.Body)
sha := sha256.New()
io.Copy(sha, r.Body)
r.Body = dup
return fmt.Sprintf("%x", sha.Sum(nil))
}
func getPathQuery(r *http.Request) (string, string) {
requestUri := r.Header.Get(requestUriHeader)
if len(requestUri) == 0 {
return r.URL.Path, r.URL.RawQuery
}
uri, err := url.Parse(requestUri)
if err != nil {
return r.URL.Path, r.URL.RawQuery
}
return uri.Path, uri.RawQuery
}

View File

@@ -0,0 +1,21 @@
package security
import "net/http"
type WithCodeResponseWriter struct {
Writer http.ResponseWriter
Code int
}
func (w *WithCodeResponseWriter) Header() http.Header {
return w.Writer.Header()
}
func (w *WithCodeResponseWriter) Write(bytes []byte) (int, error) {
return w.Writer.Write(bytes)
}
func (w *WithCodeResponseWriter) WriteHeader(code int) {
w.Writer.WriteHeader(code)
w.Code = code
}

40
rest/internal/server.go Normal file
View File

@@ -0,0 +1,40 @@
package internal
import (
"crypto/tls"
"fmt"
"net/http"
)
func StartHttp(host string, port int, handler http.Handler) error {
addr := fmt.Sprintf("%s:%d", host, port)
server := buildHttpServer(addr, handler)
return StartServer(server)
}
func StartHttps(host string, port int, certFile, keyFile string, handler http.Handler) error {
addr := fmt.Sprintf("%s:%d", host, port)
if server, err := buildHttpsServer(addr, handler, certFile, keyFile); err != nil {
return err
} else {
return StartServer(server)
}
}
func buildHttpServer(addr string, handler http.Handler) *http.Server {
return &http.Server{Addr: addr, Handler: handler}
}
func buildHttpsServer(addr string, handler http.Handler, certFile, keyFile string) (*http.Server, error) {
cert, err := tls.LoadX509KeyPair(certFile, keyFile)
if err != nil {
return nil, err
}
config := tls.Config{Certificates: []tls.Certificate{cert}}
return &http.Server{
Addr: addr,
Handler: handler,
TLSConfig: &config,
}, nil
}

16
rest/internal/starter.go Normal file
View File

@@ -0,0 +1,16 @@
package internal
import (
"context"
"net/http"
"zero/core/proc"
)
func StartServer(srv *http.Server) error {
proc.AddWrapUpListener(func() {
srv.Shutdown(context.Background())
})
return srv.ListenAndServe()
}

View File

@@ -0,0 +1,122 @@
package internal
import (
"net/http"
"sync"
"sync/atomic"
"time"
"zero/core/timex"
"github.com/dgrijalva/jwt-go"
"github.com/dgrijalva/jwt-go/request"
)
const claimHistoryResetDuration = time.Hour * 24
type (
ParseOption func(parser *TokenParser)
TokenParser struct {
resetTime time.Duration
resetDuration time.Duration
history sync.Map
}
)
func NewTokenParser(opts ...ParseOption) *TokenParser {
parser := &TokenParser{
resetTime: timex.Now(),
resetDuration: claimHistoryResetDuration,
}
for _, opt := range opts {
opt(parser)
}
return parser
}
func (tp *TokenParser) ParseToken(r *http.Request, secret, prevSecret string) (*jwt.Token, error) {
var token *jwt.Token
var err error
if len(prevSecret) > 0 {
count := tp.loadCount(secret)
prevCount := tp.loadCount(prevSecret)
var first, second string
if count > prevCount {
first = secret
second = prevSecret
} else {
first = prevSecret
second = secret
}
token, err = tp.doParseToken(r, first)
if err != nil {
token, err = tp.doParseToken(r, second)
if err != nil {
return nil, err
} else {
tp.incrementCount(second)
}
} else {
tp.incrementCount(first)
}
} else {
token, err = tp.doParseToken(r, secret)
if err != nil {
return nil, err
}
}
return token, nil
}
func (tp *TokenParser) doParseToken(r *http.Request, secret string) (*jwt.Token, error) {
return request.ParseFromRequest(r, request.AuthorizationHeaderExtractor,
func(token *jwt.Token) (interface{}, error) {
return []byte(secret), nil
}, request.WithParser(newParser()))
}
func (tp *TokenParser) incrementCount(secret string) {
now := timex.Now()
if tp.resetTime+tp.resetDuration < now {
tp.history.Range(func(key, value interface{}) bool {
tp.history.Delete(key)
return true
})
}
value, ok := tp.history.Load(secret)
if ok {
atomic.AddUint64(value.(*uint64), 1)
} else {
var count uint64 = 1
tp.history.Store(secret, &count)
}
}
func (tp *TokenParser) loadCount(secret string) uint64 {
value, ok := tp.history.Load(secret)
if ok {
return *value.(*uint64)
}
return 0
}
func WithResetDuration(duration time.Duration) ParseOption {
return func(parser *TokenParser) {
parser.resetDuration = duration
}
}
func newParser() *jwt.Parser {
return &jwt.Parser{
UseJSONNumber: true,
}
}

View File

@@ -0,0 +1,87 @@
package internal
import (
"net/http"
"net/http/httptest"
"testing"
"time"
"github.com/dgrijalva/jwt-go"
"github.com/stretchr/testify/assert"
"zero/core/timex"
)
func TestTokenParser(t *testing.T) {
const (
key = "14F17379-EB8F-411B-8F12-6929002DCA76"
prevKey = "B63F477D-BBA3-4E52-96D3-C0034C27694A"
)
keys := []struct {
key string
prevKey string
}{
{
key,
prevKey,
},
{
key,
"",
},
}
for _, pair := range keys {
req := httptest.NewRequest(http.MethodGet, "http://localhost", nil)
token, err := buildToken(key, map[string]interface{}{
"key": "value",
}, 3600)
assert.Nil(t, err)
req.Header.Set("Authorization", "Bearer "+token)
parser := NewTokenParser(WithResetDuration(time.Minute))
tok, err := parser.ParseToken(req, pair.key, pair.prevKey)
assert.Nil(t, err)
assert.Equal(t, "value", tok.Claims.(jwt.MapClaims)["key"])
}
}
func TestTokenParser_Expired(t *testing.T) {
const (
key = "14F17379-EB8F-411B-8F12-6929002DCA76"
prevKey = "B63F477D-BBA3-4E52-96D3-C0034C27694A"
)
req := httptest.NewRequest(http.MethodGet, "http://localhost", nil)
token, err := buildToken(key, map[string]interface{}{
"key": "value",
}, 3600)
assert.Nil(t, err)
req.Header.Set("Authorization", "Bearer "+token)
parser := NewTokenParser(WithResetDuration(time.Second))
tok, err := parser.ParseToken(req, key, prevKey)
assert.Nil(t, err)
assert.Equal(t, "value", tok.Claims.(jwt.MapClaims)["key"])
tok, err = parser.ParseToken(req, key, prevKey)
assert.Nil(t, err)
assert.Equal(t, "value", tok.Claims.(jwt.MapClaims)["key"])
parser.resetTime = timex.Now() - time.Hour
tok, err = parser.ParseToken(req, key, prevKey)
assert.Nil(t, err)
assert.Equal(t, "value", tok.Claims.(jwt.MapClaims)["key"])
}
func buildToken(secretKey string, payloads map[string]interface{}, seconds int64) (string, error) {
now := time.Now().Unix()
claims := make(jwt.MapClaims)
claims["exp"] = now + seconds
claims["iat"] = now
for k, v := range payloads {
claims[k] = v
}
token := jwt.New(jwt.SigningMethodHS256)
token.Claims = claims
return token.SignedString([]byte(secretKey))
}

14
rest/internal/util.go Normal file
View File

@@ -0,0 +1,14 @@
package internal
import "net/http"
const xForwardFor = "X-Forward-For"
// Returns the peer address, supports X-Forward-For
func GetRemoteAddr(r *http.Request) string {
v := r.Header.Get(xForwardFor)
if len(v) > 0 {
return v
}
return r.RemoteAddr
}

View File

@@ -0,0 +1,19 @@
package internal
import (
"net/http"
"strings"
"testing"
"github.com/stretchr/testify/assert"
)
func TestGetRemoteAddr(t *testing.T) {
host := "8.8.8.8"
r, err := http.NewRequest(http.MethodGet, "/", strings.NewReader(""))
assert.Nil(t, err)
r.Header.Set(xForwardFor, host)
assert.Equal(t, host, GetRemoteAddr(r))
}