diff --git a/rest/server_test.go b/rest/server_test.go index 413325ad..2c79884d 100644 --- a/rest/server_test.go +++ b/rest/server_test.go @@ -8,18 +8,84 @@ import ( "testing" "github.com/stretchr/testify/assert" + "github.com/tal-tech/go-zero/core/conf" "github.com/tal-tech/go-zero/rest/httpx" "github.com/tal-tech/go-zero/rest/router" ) func TestNewServer(t *testing.T) { - _, err := NewServer(RestConf{}, WithNotFoundHandler(nil), WithNotAllowedHandler(nil)) - assert.NotNil(t, err) + const configYaml = ` +Name: foo +Port: 54321 +` + var cnf RestConf + assert.Nil(t, conf.LoadConfigFromYamlBytes([]byte(configYaml), &cnf)) + failStart := func(server *Server) { + server.opts.start = func(e *engine) error { + return http.ErrServerClosed + } + } + + tests := []struct { + c RestConf + opts []RunOption + fail bool + }{ + { + c: RestConf{}, + opts: []RunOption{failStart}, + fail: true, + }, + { + c: cnf, + opts: []RunOption{failStart}, + }, + { + c: cnf, + opts: []RunOption{WithNotAllowedHandler(nil), failStart}, + }, + { + c: cnf, + opts: []RunOption{WithNotFoundHandler(nil), failStart}, + }, + { + c: cnf, + opts: []RunOption{WithUnauthorizedCallback(nil), failStart}, + }, + { + c: cnf, + opts: []RunOption{WithUnsignedCallback(nil), failStart}, + }, + } + + for _, test := range tests { + srv, err := NewServer(test.c, test.opts...) + if test.fail { + assert.NotNil(t, err) + } + if err != nil { + continue + } + + srv.Use(ToMiddleware(func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + next.ServeHTTP(w, r) + }) + })) + srv.AddRoute(Route{ + Method: http.MethodGet, + Path: "/", + Handler: nil, + }, WithJwt("thesecret"), WithSignature(SignatureConf{}), + WithJwtTransition("preivous", "thenewone")) + srv.Start() + srv.Stop() + } } func TestWithMiddleware(t *testing.T) { m := make(map[string]string) - router := router.NewRouter() + rt := router.NewRouter() handler := func(w http.ResponseWriter, r *http.Request) { var v struct { Nickname string `form:"nickname"` @@ -56,14 +122,14 @@ func TestWithMiddleware(t *testing.T) { "http://hello.com/second/wan/2020?nickname=whatever&zipcode=200000", } for _, route := range rs { - assert.Nil(t, router.Handle(route.Method, route.Path, route.Handler)) + assert.Nil(t, rt.Handle(route.Method, route.Path, route.Handler)) } for _, url := range urls { r, err := http.NewRequest(http.MethodGet, url, nil) assert.Nil(t, err) rr := httptest.NewRecorder() - router.ServeHTTP(rr, r) + rt.ServeHTTP(rr, r) assert.Equal(t, "whatever:200000", rr.Body.String()) } @@ -76,7 +142,7 @@ func TestWithMiddleware(t *testing.T) { func TestMultiMiddlewares(t *testing.T) { m := make(map[string]string) - router := router.NewRouter() + rt := router.NewRouter() handler := func(w http.ResponseWriter, r *http.Request) { var v struct { Nickname string `form:"nickname"` @@ -127,14 +193,14 @@ func TestMultiMiddlewares(t *testing.T) { "http://hello.com/second/wan/2020?nickname=whatever&zipcode=200000", } for _, route := range rs { - assert.Nil(t, router.Handle(route.Method, route.Path, route.Handler)) + assert.Nil(t, rt.Handle(route.Method, route.Path, route.Handler)) } for _, url := range urls { r, err := http.NewRequest(http.MethodGet, url, nil) assert.Nil(t, err) rr := httptest.NewRecorder() - router.ServeHTTP(rr, r) + rt.ServeHTTP(rr, r) assert.Equal(t, "whatever:200000200000", rr.Body.String()) }