feat: add rest.WithCustomCors to let caller customize the response (#1274)
This commit is contained in:
@@ -62,7 +62,7 @@ func TestCorsHandlerWithOrigins(t *testing.T) {
|
||||
r := httptest.NewRequest(method, "http://localhost", nil)
|
||||
r.Header.Set(originHeader, test.reqOrigin)
|
||||
w := httptest.NewRecorder()
|
||||
handler := NotAllowedHandler(test.origins...)
|
||||
handler := NotAllowedHandler(nil, test.origins...)
|
||||
handler.ServeHTTP(w, r)
|
||||
if method == http.MethodOptions {
|
||||
assert.Equal(t, http.StatusNoContent, w.Result().StatusCode)
|
||||
@@ -71,6 +71,22 @@ func TestCorsHandlerWithOrigins(t *testing.T) {
|
||||
}
|
||||
assert.Equal(t, test.expect, w.Header().Get(allowOrigin))
|
||||
})
|
||||
t.Run(test.name+"-handler-custom", func(t *testing.T) {
|
||||
r := httptest.NewRequest(method, "http://localhost", nil)
|
||||
r.Header.Set(originHeader, test.reqOrigin)
|
||||
w := httptest.NewRecorder()
|
||||
handler := NotAllowedHandler(func(w http.ResponseWriter) {
|
||||
w.Header().Set("foo", "bar")
|
||||
}, test.origins...)
|
||||
handler.ServeHTTP(w, r)
|
||||
if method == http.MethodOptions {
|
||||
assert.Equal(t, http.StatusNoContent, w.Result().StatusCode)
|
||||
} else {
|
||||
assert.Equal(t, http.StatusNotFound, w.Result().StatusCode)
|
||||
}
|
||||
assert.Equal(t, test.expect, w.Header().Get(allowOrigin))
|
||||
assert.Equal(t, "bar", w.Header().Get("foo"))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -81,7 +97,7 @@ func TestCorsHandlerWithOrigins(t *testing.T) {
|
||||
r := httptest.NewRequest(method, "http://localhost", nil)
|
||||
r.Header.Set(originHeader, test.reqOrigin)
|
||||
w := httptest.NewRecorder()
|
||||
handler := Middleware(test.origins...)(func(w http.ResponseWriter, r *http.Request) {
|
||||
handler := Middleware(nil, test.origins...)(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
})
|
||||
handler.ServeHTTP(w, r)
|
||||
@@ -92,6 +108,24 @@ func TestCorsHandlerWithOrigins(t *testing.T) {
|
||||
}
|
||||
assert.Equal(t, test.expect, w.Header().Get(allowOrigin))
|
||||
})
|
||||
t.Run(test.name+"-middleware-custom", func(t *testing.T) {
|
||||
r := httptest.NewRequest(method, "http://localhost", nil)
|
||||
r.Header.Set(originHeader, test.reqOrigin)
|
||||
w := httptest.NewRecorder()
|
||||
handler := Middleware(func(w http.ResponseWriter) {
|
||||
w.Header().Set("foo", "bar")
|
||||
}, test.origins...)(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
})
|
||||
handler.ServeHTTP(w, r)
|
||||
if method == http.MethodOptions {
|
||||
assert.Equal(t, http.StatusNoContent, w.Result().StatusCode)
|
||||
} else {
|
||||
assert.Equal(t, http.StatusOK, w.Result().StatusCode)
|
||||
}
|
||||
assert.Equal(t, test.expect, w.Header().Get(allowOrigin))
|
||||
assert.Equal(t, "bar", w.Header().Get("foo"))
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user