diff --git a/core/logx/logs.go b/core/logx/logs.go index 1f5e120b..c1de018f 100644 --- a/core/logx/logs.go +++ b/core/logx/logs.go @@ -197,7 +197,12 @@ func Must(err error) { msg := err.Error() log.Print(msg) getWriter().Severe(msg) - os.Exit(1) + + if ExitOnFatal.True() { + os.Exit(1) + } else { + panic(msg) + } } // MustSetup sets up logging with given config c. It exits on error. diff --git a/core/logx/logs_test.go b/core/logx/logs_test.go index 541da29c..0b37960f 100644 --- a/core/logx/logs_test.go +++ b/core/logx/logs_test.go @@ -24,6 +24,10 @@ var ( _ Writer = (*mockWriter)(nil) ) +func init() { + ExitOnFatal.Set(false) +} + type mockWriter struct { lock sync.Mutex builder strings.Builder @@ -208,6 +212,12 @@ func TestFileLineConsoleMode(t *testing.T) { assert.True(t, w.Contains(fmt.Sprintf("%s:%d", file, line+1))) } +func TestMust(t *testing.T) { + assert.Panics(t, func() { + Must(errors.New("foo")) + }) +} + func TestStructedLogAlert(t *testing.T) { w := new(mockWriter) old := writer.Swap(w) @@ -574,26 +584,38 @@ func TestSetup(t *testing.T) { atomic.StoreUint32(&encoding, jsonEncodingType) }() + setupOnce = sync.Once{} + MustSetup(LogConf{ + ServiceName: "any", + Mode: "console", + Encoding: "json", + TimeFormat: timeFormat, + }) + setupOnce = sync.Once{} MustSetup(LogConf{ ServiceName: "any", Mode: "console", TimeFormat: timeFormat, }) + setupOnce = sync.Once{} MustSetup(LogConf{ ServiceName: "any", Mode: "file", Path: os.TempDir(), }) + setupOnce = sync.Once{} MustSetup(LogConf{ ServiceName: "any", Mode: "volume", Path: os.TempDir(), }) + setupOnce = sync.Once{} MustSetup(LogConf{ ServiceName: "any", Mode: "console", TimeFormat: timeFormat, }) + setupOnce = sync.Once{} MustSetup(LogConf{ ServiceName: "any", Mode: "console", diff --git a/core/logx/rotatelogger.go b/core/logx/rotatelogger.go index 926bb020..42d84b1c 100644 --- a/core/logx/rotatelogger.go +++ b/core/logx/rotatelogger.go @@ -237,7 +237,7 @@ func NewLogger(filename string, rule RotateRule, compress bool) (*RotateLogger, rule: rule, compress: compress, } - if err := l.init(); err != nil { + if err := l.initialize(); err != nil { return nil, err } @@ -281,7 +281,7 @@ func (l *RotateLogger) getBackupFilename() string { return l.backup } -func (l *RotateLogger) init() error { +func (l *RotateLogger) initialize() error { l.backup = l.rule.BackupFileName() if fileInfo, err := os.Stat(l.filename); err != nil { diff --git a/core/logx/vars.go b/core/logx/vars.go index 441bf973..59688e87 100644 --- a/core/logx/vars.go +++ b/core/logx/vars.go @@ -1,6 +1,10 @@ package logx -import "errors" +import ( + "errors" + + "github.com/zeromicro/go-zero/core/syncx" +) const ( // DebugLevel logs everything @@ -61,6 +65,8 @@ var ( ErrLogPathNotSet = errors.New("log path must be set") // ErrLogServiceNameNotSet is an error that indicates that the service name is not set. ErrLogServiceNameNotSet = errors.New("log service name must be set") + // ExitOnFatal defines whether to exit on fatal errors, defined here to make it easier to test. + ExitOnFatal = syncx.ForAtomicBool(true) truncatedField = Field(truncatedKey, true) ) diff --git a/rest/engine.go b/rest/engine.go index cf5d0538..68a946fa 100644 --- a/rest/engine.go +++ b/rest/engine.go @@ -301,22 +301,26 @@ func (ng *engine) signatureVerifier(signature signatureSetting) (func(chain.Chai }, nil } -func (ng *engine) start(router httpx.Router, opts ...internal.StartOption) error { +func (ng *engine) start(router httpx.Router, opts ...StartOption) error { if err := ng.bindRoutes(router); err != nil { return err } - opts = append(opts, ng.withTimeout()) + // make sure user defined options overwrite default options + opts = append([]StartOption{ng.withTimeout()}, opts...) if len(ng.conf.CertFile) == 0 && len(ng.conf.KeyFile) == 0 { return internal.StartHttp(ng.conf.Host, ng.conf.Port, router, opts...) } - opts = append(opts, func(svr *http.Server) { - if ng.tlsConfig != nil { - svr.TLSConfig = ng.tlsConfig - } - }) + // make sure user defined options overwrite default options + opts = append([]StartOption{ + func(svr *http.Server) { + if ng.tlsConfig != nil { + svr.TLSConfig = ng.tlsConfig + } + }, + }, opts...) return internal.StartHttps(ng.conf.Host, ng.conf.Port, ng.conf.CertFile, ng.conf.KeyFile, router, opts...) diff --git a/rest/engine_test.go b/rest/engine_test.go index 2c676d55..ece1d0d8 100644 --- a/rest/engine_test.go +++ b/rest/engine_test.go @@ -3,6 +3,7 @@ package rest import ( "context" "errors" + "fmt" "net/http" "net/http/httptest" "sync/atomic" @@ -17,18 +18,21 @@ import ( func TestNewEngine(t *testing.T) { yamls := []string{ `Name: foo -Port: 54321 +Host: localhost +Port: 0 Middlewares: Log: false `, `Name: foo -Port: 54321 +Host: localhost +Port: 0 CpuThreshold: 500 Middlewares: Log: false `, `Name: foo -Port: 54321 +Host: localhost +Port: 0 CpuThreshold: 500 Verbose: true `, @@ -150,22 +154,29 @@ Verbose: true } for _, yaml := range yamls { + yaml := yaml for _, route := range routes { - var cnf RestConf - assert.Nil(t, conf.LoadFromYamlBytes([]byte(yaml), &cnf)) - ng := newEngine(cnf) - ng.addRoutes(route) - ng.use(func(next http.HandlerFunc) http.HandlerFunc { - return func(w http.ResponseWriter, r *http.Request) { - next.ServeHTTP(w, r) + route := route + t.Run(fmt.Sprintf("%s-%v", yaml, route.routes), func(t *testing.T) { + var cnf RestConf + assert.Nil(t, conf.LoadFromYamlBytes([]byte(yaml), &cnf)) + ng := newEngine(cnf) + ng.addRoutes(route) + ng.use(func(next http.HandlerFunc) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + next.ServeHTTP(w, r) + } + }) + + assert.NotNil(t, ng.start(mockedRouter{}, func(svr *http.Server) { + })) + + timeout := time.Second * 3 + if route.timeout > timeout { + timeout = route.timeout } + assert.Equal(t, timeout, ng.timeout) }) - assert.NotNil(t, ng.start(mockedRouter{})) - timeout := time.Second * 3 - if route.timeout > timeout { - timeout = route.timeout - } - assert.Equal(t, timeout, ng.timeout) } } } @@ -340,7 +351,8 @@ func TestEngine_withTimeout(t *testing.T) { } } -type mockedRouter struct{} +type mockedRouter struct { +} func (m mockedRouter) ServeHTTP(_ http.ResponseWriter, _ *http.Request) { } diff --git a/rest/server.go b/rest/server.go index 95447718..9583ea52 100644 --- a/rest/server.go +++ b/rest/server.go @@ -2,7 +2,6 @@ package rest import ( "crypto/tls" - "log" "net/http" "path" "time" @@ -21,7 +20,7 @@ type ( RunOption func(*Server) // StartOption defines the method to customize http server. - StartOption func(svr *http.Server) + StartOption = internal.StartOption // A Server is a http server. Server struct { @@ -36,7 +35,7 @@ type ( func MustNewServer(c RestConf, opts ...RunOption) *Server { server, err := NewServer(c, opts...) if err != nil { - log.Fatal(err) + logx.Must(err) } return server @@ -116,12 +115,15 @@ func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) { // Start starts the Server. // Graceful shutdown is enabled by default. // Use proc.SetTimeToForceQuit to customize the graceful shutdown period. -func (s *Server) Start(opts ...StartOption) { - var startOption []internal.StartOption - for _, opt := range opts { - startOption = append(startOption, internal.StartOption(opt)) - } - handleError(s.ngin.start(s.router, startOption...)) +func (s *Server) Start() { + handleError(s.ngin.start(s.router)) +} + +// StartWithOpts starts the Server. +// Graceful shutdown is enabled by default. +// Use proc.SetTimeToForceQuit to customize the graceful shutdown period. +func (s *Server) StartWithOpts(opts ...StartOption) { + handleError(s.ngin.start(s.router, opts...)) } // Stop stops the Server. diff --git a/rest/server_test.go b/rest/server_test.go index bd215fae..bc3f6577 100644 --- a/rest/server_test.go +++ b/rest/server_test.go @@ -28,7 +28,8 @@ func TestNewServer(t *testing.T) { const configYaml = ` Name: foo -Port: 54321 +Host: localhost +Port: 0 ` var cnf RestConf assert.Nil(t, conf.LoadFromYamlBytes([]byte(configYaml), &cnf)) @@ -101,6 +102,23 @@ Port: 54321 svr.Start() svr.Stop() }() + + func() { + defer func() { + p := recover() + switch v := p.(type) { + case error: + assert.Equal(t, "foo", v.Error()) + default: + t.Fail() + } + }() + + svr.StartWithOpts(func(svr *http.Server) { + svr.RegisterOnShutdown(func() {}) + }) + svr.Stop() + }() } } @@ -569,7 +587,6 @@ Port: 54321 Method: http.MethodGet, Path: "/user/:name", Handler: func(writer http.ResponseWriter, request *http.Request) { - var userInfo struct { Name string `path:"name"` }