diff --git a/core/logx/customlogger.go b/core/logx/durationlogger.go similarity index 51% rename from core/logx/customlogger.go rename to core/logx/durationlogger.go index 4bb1d7d2..45fe5524 100644 --- a/core/logx/customlogger.go +++ b/core/logx/durationlogger.go @@ -8,55 +8,60 @@ import ( "github.com/tal-tech/go-zero/core/timex" ) -const customCallerDepth = 3 +const durationCallerDepth = 3 -type customLog logEntry +type durationLogger logEntry func WithDuration(d time.Duration) Logger { - return customLog{ + return &durationLogger{ Duration: timex.ReprOfDuration(d), } } -func (l customLog) Error(v ...interface{}) { +func (l *durationLogger) Error(v ...interface{}) { if shouldLog(ErrorLevel) { - l.write(errorLog, levelError, formatWithCaller(fmt.Sprint(v...), customCallerDepth)) + l.write(errorLog, levelError, formatWithCaller(fmt.Sprint(v...), durationCallerDepth)) } } -func (l customLog) Errorf(format string, v ...interface{}) { +func (l *durationLogger) Errorf(format string, v ...interface{}) { if shouldLog(ErrorLevel) { - l.write(errorLog, levelError, formatWithCaller(fmt.Sprintf(format, v...), customCallerDepth)) + l.write(errorLog, levelError, formatWithCaller(fmt.Sprintf(format, v...), durationCallerDepth)) } } -func (l customLog) Info(v ...interface{}) { +func (l *durationLogger) Info(v ...interface{}) { if shouldLog(InfoLevel) { l.write(infoLog, levelInfo, fmt.Sprint(v...)) } } -func (l customLog) Infof(format string, v ...interface{}) { +func (l *durationLogger) Infof(format string, v ...interface{}) { if shouldLog(InfoLevel) { l.write(infoLog, levelInfo, fmt.Sprintf(format, v...)) } } -func (l customLog) Slow(v ...interface{}) { +func (l *durationLogger) Slow(v ...interface{}) { if shouldLog(ErrorLevel) { l.write(slowLog, levelSlow, fmt.Sprint(v...)) } } -func (l customLog) Slowf(format string, v ...interface{}) { +func (l *durationLogger) Slowf(format string, v ...interface{}) { if shouldLog(ErrorLevel) { l.write(slowLog, levelSlow, fmt.Sprintf(format, v...)) } } -func (l customLog) write(writer io.Writer, level, content string) { +func (l *durationLogger) WithDuration(duration time.Duration) Logger { + l.Duration = timex.ReprOfDuration(duration) + return l +} + +func (l *durationLogger) write(writer io.Writer, level, content string) { l.Timestamp = getTimestamp() l.Level = level l.Content = content - outputJson(writer, logEntry(l)) + outputJson(writer, logEntry(*l)) } diff --git a/core/logx/durationlogger_test.go b/core/logx/durationlogger_test.go new file mode 100644 index 00000000..f5b6eb77 --- /dev/null +++ b/core/logx/durationlogger_test.go @@ -0,0 +1,52 @@ +package logx + +import ( + "log" + "strings" + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +func TestWithDurationError(t *testing.T) { + var builder strings.Builder + log.SetOutput(&builder) + WithDuration(time.Second).Error("foo") + assert.True(t, strings.Contains(builder.String(), "duration"), builder.String()) +} + +func TestWithDurationErrorf(t *testing.T) { + var builder strings.Builder + log.SetOutput(&builder) + WithDuration(time.Second).Errorf("foo") + assert.True(t, strings.Contains(builder.String(), "duration"), builder.String()) +} + +func TestWithDurationInfo(t *testing.T) { + var builder strings.Builder + log.SetOutput(&builder) + WithDuration(time.Second).Info("foo") + assert.True(t, strings.Contains(builder.String(), "duration"), builder.String()) +} + +func TestWithDurationInfof(t *testing.T) { + var builder strings.Builder + log.SetOutput(&builder) + WithDuration(time.Second).Infof("foo") + assert.True(t, strings.Contains(builder.String(), "duration"), builder.String()) +} + +func TestWithDurationSlow(t *testing.T) { + var builder strings.Builder + log.SetOutput(&builder) + WithDuration(time.Second).Slow("foo") + assert.True(t, strings.Contains(builder.String(), "duration"), builder.String()) +} + +func TestWithDurationSlowf(t *testing.T) { + var builder strings.Builder + log.SetOutput(&builder) + WithDuration(time.Second).WithDuration(time.Hour).Slowf("foo") + assert.True(t, strings.Contains(builder.String(), "duration"), builder.String()) +} diff --git a/core/logx/logs.go b/core/logx/logs.go index f328abf9..3347d333 100644 --- a/core/logx/logs.go +++ b/core/logx/logs.go @@ -15,6 +15,7 @@ import ( "strings" "sync" "sync/atomic" + "time" "github.com/tal-tech/go-zero/core/iox" "github.com/tal-tech/go-zero/core/sysx" @@ -96,6 +97,7 @@ type ( Infof(string, ...interface{}) Slow(...interface{}) Slowf(string, ...interface{}) + WithDuration(time.Duration) Logger } ) diff --git a/core/logx/tracelog.go b/core/logx/tracelogger.go similarity index 65% rename from core/logx/tracelog.go rename to core/logx/tracelogger.go index f5309d6e..b07a22ee 100644 --- a/core/logx/tracelog.go +++ b/core/logx/tracelogger.go @@ -4,54 +4,61 @@ import ( "context" "fmt" "io" + "time" + "github.com/tal-tech/go-zero/core/timex" "github.com/tal-tech/go-zero/core/trace/tracespec" ) -type tracingEntry struct { +type traceLogger struct { logEntry Trace string `json:"trace,omitempty"` Span string `json:"span,omitempty"` ctx context.Context } -func (l tracingEntry) Error(v ...interface{}) { +func (l *traceLogger) Error(v ...interface{}) { if shouldLog(ErrorLevel) { - l.write(errorLog, levelError, formatWithCaller(fmt.Sprint(v...), customCallerDepth)) + l.write(errorLog, levelError, formatWithCaller(fmt.Sprint(v...), durationCallerDepth)) } } -func (l tracingEntry) Errorf(format string, v ...interface{}) { +func (l *traceLogger) Errorf(format string, v ...interface{}) { if shouldLog(ErrorLevel) { - l.write(errorLog, levelError, formatWithCaller(fmt.Sprintf(format, v...), customCallerDepth)) + l.write(errorLog, levelError, formatWithCaller(fmt.Sprintf(format, v...), durationCallerDepth)) } } -func (l tracingEntry) Info(v ...interface{}) { +func (l *traceLogger) Info(v ...interface{}) { if shouldLog(InfoLevel) { l.write(infoLog, levelInfo, fmt.Sprint(v...)) } } -func (l tracingEntry) Infof(format string, v ...interface{}) { +func (l *traceLogger) Infof(format string, v ...interface{}) { if shouldLog(InfoLevel) { l.write(infoLog, levelInfo, fmt.Sprintf(format, v...)) } } -func (l tracingEntry) Slow(v ...interface{}) { +func (l *traceLogger) Slow(v ...interface{}) { if shouldLog(ErrorLevel) { l.write(slowLog, levelSlow, fmt.Sprint(v...)) } } -func (l tracingEntry) Slowf(format string, v ...interface{}) { +func (l *traceLogger) Slowf(format string, v ...interface{}) { if shouldLog(ErrorLevel) { l.write(slowLog, levelSlow, fmt.Sprintf(format, v...)) } } -func (l tracingEntry) write(writer io.Writer, level, content string) { +func (l *traceLogger) WithDuration(duration time.Duration) Logger { + l.Duration = timex.ReprOfDuration(duration) + return l +} + +func (l *traceLogger) write(writer io.Writer, level, content string) { l.Timestamp = getTimestamp() l.Level = level l.Content = content @@ -61,7 +68,7 @@ func (l tracingEntry) write(writer io.Writer, level, content string) { } func WithContext(ctx context.Context) Logger { - return tracingEntry{ + return &traceLogger{ ctx: ctx, } } diff --git a/core/logx/tracelog_test.go b/core/logx/tracelogger_test.go similarity index 94% rename from core/logx/tracelog_test.go rename to core/logx/tracelogger_test.go index 96223d38..f1e76000 100644 --- a/core/logx/tracelog_test.go +++ b/core/logx/tracelogger_test.go @@ -19,7 +19,7 @@ var mock tracespec.Trace = new(mockTrace) func TestTraceLog(t *testing.T) { var buf strings.Builder ctx := context.WithValue(context.Background(), tracespec.TracingKey, mock) - WithContext(ctx).(tracingEntry).write(&buf, levelInfo, testlog) + WithContext(ctx).(*traceLogger).write(&buf, levelInfo, testlog) assert.True(t, strings.Contains(buf.String(), mockTraceId)) assert.True(t, strings.Contains(buf.String(), mockSpanId)) } diff --git a/doc/images/wechat.jpg b/doc/images/wechat.jpg index 59a9f4e8..1767b4d5 100644 Binary files a/doc/images/wechat.jpg and b/doc/images/wechat.jpg differ diff --git a/doc/jwt.md b/doc/jwt.md index 77b5ef11..fca8b070 100644 --- a/doc/jwt.md +++ b/doc/jwt.md @@ -1,4 +1,4 @@ -### 基于go-zero实现JWT认证 +# 基于go-zero实现JWT认证 关于JWT是什么,大家可以看看[官网](https://jwt.io/),一句话介绍下:是可以实现服务器无状态的鉴权认证方案,也是目前最流行的跨域认证解决方案。 @@ -7,7 +7,7 @@ * 客户端获取JWT token。 * 服务器对客户端带来的JWT token认证。 -### 1. 客户端获取JWT Token +## 1. 客户端获取JWT Token 我们定义一个协议供客户端调用获取JWT token,我们新建一个目录jwt然后在目录中执行 `goctl api -o jwt.api`,将生成的jwt.api改成如下: @@ -61,7 +61,11 @@ func (l *JwtLogic) Jwt(req types.JwtTokenRequest) (*types.JwtTokenResponse, erro return nil, err } - return &types.JwtTokenResponse{AccessToken: accessToken, AccessExpire: now + accessExpire, RefreshAfter: now + accessExpire/2}, nil + return &types.JwtTokenResponse{ + AccessToken: accessToken, + AccessExpire: now + accessExpire, + RefreshAfter: now + accessExpire/2, + }, nil } func (l *JwtLogic) GenToken(iat int64, secretKey string, payloads map[string]interface{}, seconds int64) (string, error) { @@ -91,13 +95,11 @@ JwtAuth: 启动服务器,然后测试下获取到的token。 ```sh -➜ jwt curl --location --request POST '127.0.0.1:8888/user/token' +➜ curl --location --request POST '127.0.0.1:8888/user/token' {"access_token":"eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJleHAiOjE2MDEyNjE0MjksImlhdCI6MTYwMDY1NjYyOX0.6u_hpE_4m5gcI90taJLZtvfekwUmjrbNJ-5saaDGeQc","access_expire":1601261429,"refresh_after":1600959029} ``` - - -### 2 服务器验证JWT token +## 2. 服务器验证JWT token 1. 在api文件中通过`jwt: JwtAuth`标记的service表示激活了jwt认证。 2. 可以阅读rest/handler/authhandler.go文件了解服务器jwt实现。 @@ -112,7 +114,7 @@ func (l *GetUserLogic) GetUser(req types.GetUserRequest) (*types.GetUserResponse * 我们先不带JWT Authorization header请求头测试下,返回http status code是401,符合预期。 ```sh -➜ jwt curl -w "\nhttp: %{http_code} \n" --location --request POST '127.0.0.1:8888/user/info' \ +➜ curl -w "\nhttp: %{http_code} \n" --location --request POST '127.0.0.1:8888/user/info' \ --header 'Content-Type: application/json' \ --data-raw '{ "userId": "a" @@ -124,7 +126,7 @@ http: 401 * 加上Authorization header请求头测试。 ```sh -➜ jwt curl -w "\nhttp: %{http_code} \n" --location --request POST '127.0.0.1:8888/user/info' \ +➜ curl -w "\nhttp: %{http_code} \n" --location --request POST '127.0.0.1:8888/user/info' \ --header 'Authorization: eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJleHAiOjE2MDEyNjE0MjksImlhdCI6MTYwMDY1NjYyOX0.6u_hpE_4m5gcI90taJLZtvfekwUmjrbNJ-5saaDGeQc' \ --header 'Content-Type: application/json' \ --data-raw '{ @@ -134,7 +136,5 @@ http: 401 http: 200 ``` - - 综上所述:基于go-zero的JWT认证完成,在真实生产环境部署时候,AccessSecret, AccessExpire, RefreshAfter根据业务场景通过配置文件配置,RefreshAfter 是告诉客户端什么时候该刷新JWT token了,一般都需要设置过期时间前几天。 diff --git a/rest/httpx/requests_test.go b/rest/httpx/requests_test.go index 49ecaa7e..be930b17 100644 --- a/rest/httpx/requests_test.go +++ b/rest/httpx/requests_test.go @@ -109,6 +109,18 @@ func TestParseRequired(t *testing.T) { assert.NotNil(t, err) } +func TestParseOptions(t *testing.T) { + v := struct { + Position int8 `form:"pos,options=1|2"` + }{} + + r, err := http.NewRequest(http.MethodGet, "http://hello.com/a?pos=4", nil) + assert.Nil(t, err) + + err = Parse(r, &v) + assert.NotNil(t, err) +} + func BenchmarkParseRaw(b *testing.B) { r, err := http.NewRequest(http.MethodGet, "http://hello.com/a?name=hello&age=18&percent=3.4", nil) if err != nil { diff --git a/zrpc/internal/client.go b/zrpc/internal/client.go index 3ae36605..45ea36eb 100644 --- a/zrpc/internal/client.go +++ b/zrpc/internal/client.go @@ -66,11 +66,11 @@ func buildDialOptions(opts ...ClientOption) []grpc.DialOption { grpc.WithInsecure(), grpc.WithBlock(), WithUnaryClientInterceptors( - clientinterceptors.BreakerInterceptor, + clientinterceptors.TracingInterceptor, clientinterceptors.DurationInterceptor, + clientinterceptors.BreakerInterceptor, clientinterceptors.PromMetricInterceptor, clientinterceptors.TimeoutInterceptor(clientOptions.Timeout), - clientinterceptors.TracingInterceptor, ), } diff --git a/zrpc/internal/clientinterceptors/durationinterceptor.go b/zrpc/internal/clientinterceptors/durationinterceptor.go index 12ed96bc..b5ace685 100644 --- a/zrpc/internal/clientinterceptors/durationinterceptor.go +++ b/zrpc/internal/clientinterceptors/durationinterceptor.go @@ -18,11 +18,13 @@ func DurationInterceptor(ctx context.Context, method string, req, reply interfac start := timex.Now() err := invoker(ctx, method, req, reply, cc, opts...) if err != nil { - logx.WithDuration(timex.Since(start)).Infof("fail - %s - %v - %s", serverName, req, err.Error()) + logx.WithContext(ctx).WithDuration(timex.Since(start)).Infof("fail - %s - %v - %s", + serverName, req, err.Error()) } else { elapsed := timex.Since(start) if elapsed > slowThreshold { - logx.WithDuration(elapsed).Slowf("[RPC] ok - slowcall - %s - %v - %v", serverName, req, reply) + logx.WithContext(ctx).WithDuration(elapsed).Slowf("[RPC] ok - slowcall - %s - %v - %v", + serverName, req, reply) } } diff --git a/zrpc/internal/rpcserver.go b/zrpc/internal/rpcserver.go index 15976c4c..c9919c74 100644 --- a/zrpc/internal/rpcserver.go +++ b/zrpc/internal/rpcserver.go @@ -17,6 +17,7 @@ type ( } rpcServer struct { + name string *baseRpcServer } ) @@ -40,6 +41,7 @@ func NewRpcServer(address string, opts ...ServerOption) Server { } func (s *rpcServer) SetName(name string) { + s.name = name s.baseRpcServer.SetName(name) } @@ -50,6 +52,7 @@ func (s *rpcServer) Start(register RegisterFn) error { } unaryInterceptors := []grpc.UnaryServerInterceptor{ + serverinterceptors.UnaryTracingInterceptor(s.name), serverinterceptors.UnaryCrashInterceptor(), serverinterceptors.UnaryStatInterceptor(s.metrics), serverinterceptors.UnaryPromMetricInterceptor(), diff --git a/zrpc/internal/serverinterceptors/statinterceptor.go b/zrpc/internal/serverinterceptors/statinterceptor.go index e0900099..3c653c62 100644 --- a/zrpc/internal/serverinterceptors/statinterceptor.go +++ b/zrpc/internal/serverinterceptors/statinterceptor.go @@ -42,10 +42,11 @@ func logDuration(ctx context.Context, method string, req interface{}, duration t } content, err := json.Marshal(req) if err != nil { - logx.Errorf("%s - %s", addr, err.Error()) + logx.WithContext(ctx).Errorf("%s - %s", addr, err.Error()) } else if duration > serverSlowThreshold { - logx.WithDuration(duration).Slowf("[RPC] slowcall - %s - %s - %s", addr, method, string(content)) + logx.WithContext(ctx).WithDuration(duration).Slowf("[RPC] slowcall - %s - %s - %s", + addr, method, string(content)) } else { - logx.WithDuration(duration).Infof("%s - %s - %s", addr, method, string(content)) + logx.WithContext(ctx).WithDuration(duration).Infof("%s - %s - %s", addr, method, string(content)) } } diff --git a/zrpc/server.go b/zrpc/server.go index 08582df7..4de020da 100644 --- a/zrpc/server.go +++ b/zrpc/server.go @@ -109,8 +109,6 @@ func setupInterceptors(server internal.Server, c RpcServerConf, metrics *stat.Me time.Duration(c.Timeout) * time.Millisecond)) } - server.AddUnaryInterceptors(serverinterceptors.UnaryTracingInterceptor(c.Name)) - if c.Auth { authenticator, err := auth.NewAuthenticator(c.Redis.NewRedis(), c.Redis.Key, c.StrictControl) if err != nil {