refactor(rest): keep rest log collector context key private (#3407)
This commit is contained in:
@@ -2,6 +2,7 @@ package internal
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"sync"
|
||||
@@ -10,13 +11,32 @@ import (
|
||||
"github.com/zeromicro/go-zero/rest/httpx"
|
||||
)
|
||||
|
||||
// LogContext is a context key.
|
||||
var LogContext = contextKey("request_logs")
|
||||
// logContextKey is a context key.
|
||||
var logContextKey = contextKey("request_logs")
|
||||
|
||||
// A LogCollector is used to collect logs.
|
||||
type LogCollector struct {
|
||||
Messages []string
|
||||
lock sync.Mutex
|
||||
type (
|
||||
// LogCollector is used to collect logs.
|
||||
LogCollector struct {
|
||||
Messages []string
|
||||
lock sync.Mutex
|
||||
}
|
||||
|
||||
contextKey string
|
||||
)
|
||||
|
||||
// WithLogCollector returns a new context with LogCollector.
|
||||
func WithLogCollector(ctx context.Context, lc *LogCollector) context.Context {
|
||||
return context.WithValue(ctx, logContextKey, lc)
|
||||
}
|
||||
|
||||
// LogCollectorFromContext returns LogCollector from ctx.
|
||||
func LogCollectorFromContext(ctx context.Context) *LogCollector {
|
||||
val := ctx.Value(logContextKey)
|
||||
if val == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
return val.(*LogCollector)
|
||||
}
|
||||
|
||||
// Append appends msg into log context.
|
||||
@@ -73,9 +93,9 @@ func Infof(r *http.Request, format string, v ...any) {
|
||||
}
|
||||
|
||||
func appendLog(r *http.Request, message string) {
|
||||
logs := r.Context().Value(LogContext)
|
||||
logs := LogCollectorFromContext(r.Context())
|
||||
if logs != nil {
|
||||
logs.(*LogCollector).Append(message)
|
||||
logs.Append(message)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -90,9 +110,3 @@ func formatf(r *http.Request, format string, v ...any) string {
|
||||
func formatWithReq(r *http.Request, v string) string {
|
||||
return fmt.Sprintf("(%s - %s) %s", r.RequestURI, httpx.GetRemoteAddr(r), v)
|
||||
}
|
||||
|
||||
type contextKey string
|
||||
|
||||
func (c contextKey) String() string {
|
||||
return "rest/internal context key " + string(c)
|
||||
}
|
||||
|
||||
@@ -14,7 +14,7 @@ import (
|
||||
func TestInfo(t *testing.T) {
|
||||
collector := new(LogCollector)
|
||||
req := httptest.NewRequest(http.MethodGet, "http://localhost", http.NoBody)
|
||||
req = req.WithContext(context.WithValue(req.Context(), LogContext, collector))
|
||||
req = req.WithContext(WithLogCollector(req.Context(), collector))
|
||||
Info(req, "first")
|
||||
Infof(req, "second %s", "third")
|
||||
val := collector.Flush()
|
||||
@@ -35,7 +35,10 @@ func TestError(t *testing.T) {
|
||||
assert.True(t, strings.Contains(val, "third"))
|
||||
}
|
||||
|
||||
func TestContextKey_String(t *testing.T) {
|
||||
val := contextKey("foo")
|
||||
assert.True(t, strings.Contains(val.String(), "foo"))
|
||||
func TestLogCollectorContext(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
assert.Nil(t, LogCollectorFromContext(ctx))
|
||||
collector := new(LogCollector)
|
||||
ctx = WithLogCollector(ctx, collector)
|
||||
assert.Equal(t, collector, LogCollectorFromContext(ctx))
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user