initial import

This commit is contained in:
kevin
2020-07-26 17:09:05 +08:00
commit 7e3a369a8f
647 changed files with 54754 additions and 0 deletions

View File

@@ -0,0 +1,115 @@
package httprouter
import (
"context"
"net/http"
"path"
"strings"
"zero/core/search"
)
const (
allowHeader = "Allow"
allowMethodSeparator = ", "
pathVars = "pathVars"
)
type PatRouter struct {
trees map[string]*search.Tree
notFound http.Handler
}
func NewPatRouter() Router {
return &PatRouter{
trees: make(map[string]*search.Tree),
}
}
func (pr *PatRouter) Handle(method, reqPath string, handler http.Handler) error {
if !validMethod(method) {
return ErrInvalidMethod
}
if len(reqPath) == 0 || reqPath[0] != '/' {
return ErrInvalidPath
}
cleanPath := path.Clean(reqPath)
if tree, ok := pr.trees[method]; ok {
return tree.Add(cleanPath, handler)
} else {
tree = search.NewTree()
pr.trees[method] = tree
return tree.Add(cleanPath, handler)
}
}
func (pr *PatRouter) ServeHTTP(w http.ResponseWriter, r *http.Request) {
reqPath := path.Clean(r.URL.Path)
if tree, ok := pr.trees[r.Method]; ok {
if result, ok := tree.Search(reqPath); ok {
if len(result.Params) > 0 {
r = r.WithContext(context.WithValue(r.Context(), pathVars, result.Params))
}
result.Item.(http.Handler).ServeHTTP(w, r)
return
}
}
if allow, ok := pr.methodNotAllowed(r.Method, reqPath); ok {
w.Header().Set(allowHeader, allow)
w.WriteHeader(http.StatusMethodNotAllowed)
} else {
pr.handleNotFound(w, r)
}
}
func (pr *PatRouter) SetNotFoundHandler(handler http.Handler) {
pr.notFound = handler
}
func (pr *PatRouter) handleNotFound(w http.ResponseWriter, r *http.Request) {
if pr.notFound != nil {
pr.notFound.ServeHTTP(w, r)
} else {
http.NotFound(w, r)
}
}
func (pr *PatRouter) methodNotAllowed(method, path string) (string, bool) {
var allows []string
for treeMethod, tree := range pr.trees {
if treeMethod == method {
continue
}
_, ok := tree.Search(path)
if ok {
allows = append(allows, treeMethod)
}
}
if len(allows) > 0 {
return strings.Join(allows, allowMethodSeparator), true
} else {
return "", false
}
}
func Vars(r *http.Request) map[string]string {
vars, ok := r.Context().Value(pathVars).(map[string]string)
if ok {
return vars
}
return nil
}
func validMethod(method string) bool {
return method == http.MethodDelete || method == http.MethodGet ||
method == http.MethodHead || method == http.MethodOptions ||
method == http.MethodPatch || method == http.MethodPost ||
method == http.MethodPut
}

View File

@@ -0,0 +1,120 @@
package httprouter
import (
"net/http"
"testing"
"github.com/stretchr/testify/assert"
)
type mockedResponseWriter struct {
code int
}
func (m *mockedResponseWriter) Header() http.Header {
return http.Header{}
}
func (m *mockedResponseWriter) Write(p []byte) (int, error) {
return len(p), nil
}
func (m *mockedResponseWriter) WriteHeader(code int) {
m.code = code
}
func TestPatRouterHandleErrors(t *testing.T) {
tests := []struct {
method string
path string
err error
}{
{"FAKE", "", ErrInvalidMethod},
{"GET", "", ErrInvalidPath},
}
for _, test := range tests {
t.Run(test.method, func(t *testing.T) {
router := NewPatRouter()
err := router.Handle(test.method, test.path, nil)
assert.Error(t, ErrInvalidMethod, err)
})
}
}
func TestPatRouterNotFound(t *testing.T) {
var notFound bool
router := NewPatRouter()
router.SetNotFoundHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
notFound = true
}))
router.Handle(http.MethodGet, "/a/b", nil)
r, _ := http.NewRequest(http.MethodGet, "/b/c", nil)
w := new(mockedResponseWriter)
router.ServeHTTP(w, r)
assert.True(t, notFound)
}
func TestPatRouter(t *testing.T) {
tests := []struct {
method string
path string
expect bool
code int
err error
}{
// we don't explicitly set status code, framework will do it.
{http.MethodGet, "/a/b", true, 0, nil},
{http.MethodGet, "/a/b/", true, 0, nil},
{http.MethodGet, "/a/b?a=b", true, 0, nil},
{http.MethodGet, "/a/b/?a=b", true, 0, nil},
{http.MethodGet, "/a/b/c?a=b", true, 0, nil},
{http.MethodGet, "/b/d", false, http.StatusNotFound, nil},
}
for _, test := range tests {
t.Run(test.method+":"+test.path, func(t *testing.T) {
routed := false
router := NewPatRouter()
err := router.Handle(test.method, "/a/:b", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
routed = true
assert.Equal(t, 1, len(Vars(r)))
}))
assert.Nil(t, err)
err = router.Handle(test.method, "/a/b/c", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
routed = true
assert.Nil(t, Vars(r))
}))
assert.Nil(t, err)
err = router.Handle(test.method, "/b/c", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
routed = true
}))
assert.Nil(t, err)
w := new(mockedResponseWriter)
r, _ := http.NewRequest(test.method, test.path, nil)
router.ServeHTTP(w, r)
assert.Equal(t, test.expect, routed)
assert.Equal(t, test.code, w.code)
if test.code == 0 {
r, _ = http.NewRequest(http.MethodPut, test.path, nil)
router.ServeHTTP(w, r)
assert.Equal(t, http.StatusMethodNotAllowed, w.code)
}
})
}
}
func BenchmarkPatRouter(b *testing.B) {
b.ReportAllocs()
router := NewPatRouter()
router.Handle(http.MethodGet, "/api/:user/:name", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
}))
w := &mockedResponseWriter{}
r, _ := http.NewRequest(http.MethodGet, "/api/a/b", nil)
for i := 0; i < b.N; i++ {
router.ServeHTTP(w, r)
}
}

24
core/httprouter/router.go Normal file
View File

@@ -0,0 +1,24 @@
package httprouter
import (
"errors"
"net/http"
)
var (
ErrInvalidMethod = errors.New("not a valid http method")
ErrInvalidPath = errors.New("path must begin with '/'")
)
type (
Route struct {
Path string
Handler http.HandlerFunc
}
Router interface {
http.Handler
Handle(method string, path string, handler http.Handler) error
SetNotFoundHandler(handler http.Handler)
}
)