feat: add rest.WithCustomCors to let caller customize the response (#1274)

This commit is contained in:
Kevin Wan
2021-11-25 23:03:37 +08:00
committed by GitHub
parent 86f9f63b46
commit 0395ba1816
4 changed files with 72 additions and 6 deletions

View File

@@ -62,7 +62,7 @@ func TestCorsHandlerWithOrigins(t *testing.T) {
r := httptest.NewRequest(method, "http://localhost", nil)
r.Header.Set(originHeader, test.reqOrigin)
w := httptest.NewRecorder()
handler := NotAllowedHandler(test.origins...)
handler := NotAllowedHandler(nil, test.origins...)
handler.ServeHTTP(w, r)
if method == http.MethodOptions {
assert.Equal(t, http.StatusNoContent, w.Result().StatusCode)
@@ -71,6 +71,22 @@ func TestCorsHandlerWithOrigins(t *testing.T) {
}
assert.Equal(t, test.expect, w.Header().Get(allowOrigin))
})
t.Run(test.name+"-handler-custom", func(t *testing.T) {
r := httptest.NewRequest(method, "http://localhost", nil)
r.Header.Set(originHeader, test.reqOrigin)
w := httptest.NewRecorder()
handler := NotAllowedHandler(func(w http.ResponseWriter) {
w.Header().Set("foo", "bar")
}, test.origins...)
handler.ServeHTTP(w, r)
if method == http.MethodOptions {
assert.Equal(t, http.StatusNoContent, w.Result().StatusCode)
} else {
assert.Equal(t, http.StatusNotFound, w.Result().StatusCode)
}
assert.Equal(t, test.expect, w.Header().Get(allowOrigin))
assert.Equal(t, "bar", w.Header().Get("foo"))
})
}
}
@@ -81,7 +97,7 @@ func TestCorsHandlerWithOrigins(t *testing.T) {
r := httptest.NewRequest(method, "http://localhost", nil)
r.Header.Set(originHeader, test.reqOrigin)
w := httptest.NewRecorder()
handler := Middleware(test.origins...)(func(w http.ResponseWriter, r *http.Request) {
handler := Middleware(nil, test.origins...)(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
})
handler.ServeHTTP(w, r)
@@ -92,6 +108,24 @@ func TestCorsHandlerWithOrigins(t *testing.T) {
}
assert.Equal(t, test.expect, w.Header().Get(allowOrigin))
})
t.Run(test.name+"-middleware-custom", func(t *testing.T) {
r := httptest.NewRequest(method, "http://localhost", nil)
r.Header.Set(originHeader, test.reqOrigin)
w := httptest.NewRecorder()
handler := Middleware(func(w http.ResponseWriter) {
w.Header().Set("foo", "bar")
}, test.origins...)(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
})
handler.ServeHTTP(w, r)
if method == http.MethodOptions {
assert.Equal(t, http.StatusNoContent, w.Result().StatusCode)
} else {
assert.Equal(t, http.StatusOK, w.Result().StatusCode)
}
assert.Equal(t, test.expect, w.Header().Get(allowOrigin))
assert.Equal(t, "bar", w.Header().Get("foo"))
})
}
}
}