feat: support CORS, better implementation (#1217)

* feat: support CORS, better implementation

* chore: refine code
This commit is contained in:
Kevin Wan
2021-11-09 20:35:57 +08:00
committed by GitHub
parent c1abe87953
commit 28409791fa
3 changed files with 78 additions and 22 deletions

View File

@@ -10,23 +10,42 @@ import (
func TestCorsHandlerWithOrigins(t *testing.T) {
tests := []struct {
name string
origins []string
expect string
name string
origins []string
reqOrigin string
expect string
}{
{
name: "allow all origins",
expect: allOrigins,
},
{
name: "allow one origin",
origins: []string{"local"},
expect: "local",
name: "allow one origin",
origins: []string{"http://local"},
reqOrigin: "http://local",
expect: "http://local",
},
{
name: "allow many origins",
origins: []string{"local", "remote"},
expect: "local",
name: "allow many origins",
origins: []string{"http://local", "http://remote"},
reqOrigin: "http://local",
expect: "http://local",
},
{
name: "allow all origins",
reqOrigin: "http://local",
expect: "*",
},
{
name: "allow many origins with all mark",
origins: []string{"http://local", "http://remote", "*"},
reqOrigin: "http://another",
expect: "http://another",
},
{
name: "not allow origin",
origins: []string{"http://local", "http://remote"},
reqOrigin: "http://another",
},
}
@@ -41,8 +60,9 @@ func TestCorsHandlerWithOrigins(t *testing.T) {
test := test
t.Run(test.name+"-handler", func(t *testing.T) {
r := httptest.NewRequest(method, "http://localhost", nil)
r.Header.Set(originHeader, test.reqOrigin)
w := httptest.NewRecorder()
handler := Handler(test.origins...)
handler := NotAllowedHandler(test.origins...)
handler.ServeHTTP(w, r)
if method == http.MethodOptions {
assert.Equal(t, http.StatusNoContent, w.Result().StatusCode)
@@ -59,6 +79,7 @@ func TestCorsHandlerWithOrigins(t *testing.T) {
test := test
t.Run(test.name+"-middleware", func(t *testing.T) {
r := httptest.NewRequest(method, "http://localhost", nil)
r.Header.Set(originHeader, test.reqOrigin)
w := httptest.NewRecorder()
handler := Middleware(test.origins...)(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)