chore: only allow cors middleware to change headers (#1276)
This commit is contained in:
@@ -45,12 +45,12 @@ func NotAllowedHandler(fn func(w http.ResponseWriter), origins ...string) http.H
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Middleware returns a middleware that adds CORS headers to the response.
|
// Middleware returns a middleware that adds CORS headers to the response.
|
||||||
func Middleware(fn func(w http.ResponseWriter), origins ...string) func(http.HandlerFunc) http.HandlerFunc {
|
func Middleware(fn func(w http.Header), origins ...string) func(http.HandlerFunc) http.HandlerFunc {
|
||||||
return func(next http.HandlerFunc) http.HandlerFunc {
|
return func(next http.HandlerFunc) http.HandlerFunc {
|
||||||
return func(w http.ResponseWriter, r *http.Request) {
|
return func(w http.ResponseWriter, r *http.Request) {
|
||||||
checkAndSetHeaders(w, r, origins)
|
checkAndSetHeaders(w, r, origins)
|
||||||
if fn != nil {
|
if fn != nil {
|
||||||
fn(w)
|
fn(w.Header())
|
||||||
}
|
}
|
||||||
|
|
||||||
if r.Method == http.MethodOptions {
|
if r.Method == http.MethodOptions {
|
||||||
|
|||||||
@@ -114,8 +114,8 @@ func TestCorsHandlerWithOrigins(t *testing.T) {
|
|||||||
r := httptest.NewRequest(method, "http://localhost", nil)
|
r := httptest.NewRequest(method, "http://localhost", nil)
|
||||||
r.Header.Set(originHeader, test.reqOrigin)
|
r.Header.Set(originHeader, test.reqOrigin)
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
handler := Middleware(func(w http.ResponseWriter) {
|
handler := Middleware(func(header http.Header) {
|
||||||
w.Header().Set("foo", "bar")
|
header.Set("foo", "bar")
|
||||||
}, test.origins...)(func(w http.ResponseWriter, r *http.Request) {
|
}, test.origins...)(func(w http.ResponseWriter, r *http.Request) {
|
||||||
w.WriteHeader(http.StatusOK)
|
w.WriteHeader(http.StatusOK)
|
||||||
})
|
})
|
||||||
|
|||||||
@@ -106,10 +106,11 @@ func WithCors(origin ...string) RunOption {
|
|||||||
|
|
||||||
// WithCustomCors returns a func to enable CORS for given origin, or default to all origins (*),
|
// WithCustomCors returns a func to enable CORS for given origin, or default to all origins (*),
|
||||||
// fn lets caller customizing the response.
|
// fn lets caller customizing the response.
|
||||||
func WithCustomCors(fn func(http.ResponseWriter), origin ...string) RunOption {
|
func WithCustomCors(middlewareFn func(header http.Header), notAllowedFn func(http.ResponseWriter),
|
||||||
|
origin ...string) RunOption {
|
||||||
return func(server *Server) {
|
return func(server *Server) {
|
||||||
server.router.SetNotAllowedHandler(cors.NotAllowedHandler(fn, origin...))
|
server.router.SetNotAllowedHandler(cors.NotAllowedHandler(notAllowedFn, origin...))
|
||||||
server.Use(cors.Middleware(fn, origin...))
|
server.Use(cors.Middleware(middlewareFn, origin...))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -322,8 +322,10 @@ Port: 54321
|
|||||||
srv, err := NewServer(cnf, WithRouter(rt))
|
srv, err := NewServer(cnf, WithRouter(rt))
|
||||||
assert.Nil(t, err)
|
assert.Nil(t, err)
|
||||||
|
|
||||||
opt := WithCustomCors(func(w http.ResponseWriter) {
|
opt := WithCustomCors(func(header http.Header) {
|
||||||
w.Header().Set("foo", "bar")
|
header.Set("foo", "bar")
|
||||||
|
}, func(w http.ResponseWriter) {
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
}, "local")
|
}, "local")
|
||||||
opt(srv)
|
opt(srv)
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user