support cors in rest server
This commit is contained in:
@@ -22,8 +22,9 @@ var (
|
||||
)
|
||||
|
||||
type patRouter struct {
|
||||
trees map[string]*search.Tree
|
||||
notFound http.Handler
|
||||
trees map[string]*search.Tree
|
||||
notFound http.Handler
|
||||
notAllowed http.Handler
|
||||
}
|
||||
|
||||
func NewRouter() httpx.Router {
|
||||
@@ -63,11 +64,17 @@ func (pr *patRouter) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
}
|
||||
|
||||
if allow, ok := pr.methodNotAllowed(r.Method, reqPath); ok {
|
||||
allow, ok := pr.methodNotAllowed(r.Method, reqPath)
|
||||
if !ok {
|
||||
pr.handleNotFound(w, r)
|
||||
return
|
||||
}
|
||||
|
||||
if pr.notAllowed != nil {
|
||||
pr.notAllowed.ServeHTTP(w, r)
|
||||
} else {
|
||||
w.Header().Set(allowHeader, allow)
|
||||
w.WriteHeader(http.StatusMethodNotAllowed)
|
||||
} else {
|
||||
pr.handleNotFound(w, r)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -75,6 +82,10 @@ func (pr *patRouter) SetNotFoundHandler(handler http.Handler) {
|
||||
pr.notFound = handler
|
||||
}
|
||||
|
||||
func (pr *patRouter) SetNotAllowedHandler(handler http.Handler) {
|
||||
pr.notAllowed = handler
|
||||
}
|
||||
|
||||
func (pr *patRouter) handleNotFound(w http.ResponseWriter, r *http.Request) {
|
||||
if pr.notFound != nil {
|
||||
pr.notFound.ServeHTTP(w, r)
|
||||
|
||||
@@ -60,13 +60,30 @@ func TestPatRouterNotFound(t *testing.T) {
|
||||
router.SetNotFoundHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
notFound = true
|
||||
}))
|
||||
router.Handle(http.MethodGet, "/a/b", nil)
|
||||
err := router.Handle(http.MethodGet, "/a/b",
|
||||
http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}))
|
||||
assert.Nil(t, err)
|
||||
r, _ := http.NewRequest(http.MethodGet, "/b/c", nil)
|
||||
w := new(mockedResponseWriter)
|
||||
router.ServeHTTP(w, r)
|
||||
assert.True(t, notFound)
|
||||
}
|
||||
|
||||
func TestPatRouterNotAllowed(t *testing.T) {
|
||||
var notAllowed bool
|
||||
router := NewRouter()
|
||||
router.SetNotAllowedHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
notAllowed = true
|
||||
}))
|
||||
err := router.Handle(http.MethodGet, "/a/b",
|
||||
http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}))
|
||||
assert.Nil(t, err)
|
||||
r, _ := http.NewRequest(http.MethodPost, "/a/b", nil)
|
||||
w := new(mockedResponseWriter)
|
||||
router.ServeHTTP(w, r)
|
||||
assert.True(t, notAllowed)
|
||||
}
|
||||
|
||||
func TestPatRouter(t *testing.T) {
|
||||
tests := []struct {
|
||||
method string
|
||||
|
||||
Reference in New Issue
Block a user