diff --git a/core/mapping/unmarshaler.go b/core/mapping/unmarshaler.go index 4184c009..e4293714 100644 --- a/core/mapping/unmarshaler.go +++ b/core/mapping/unmarshaler.go @@ -77,7 +77,7 @@ func (u *Unmarshaler) Unmarshal(i interface{}, v interface{}) error { return errValueNotSettable } - elemType := valueType.Elem() + elemType := Deref(valueType) switch iv := i.(type) { case map[string]interface{}: if elemType.Kind() != reflect.Struct { @@ -818,15 +818,22 @@ func (u *Unmarshaler) unmarshalWithFullName(m valuerWithParent, v interface{}, f return err } - rte := reflect.TypeOf(v).Elem() - if rte.Kind() != reflect.Struct { + valueType := reflect.TypeOf(v) + baseType := Deref(valueType) + if baseType.Kind() != reflect.Struct { return errValueNotStruct } - rve := rv.Elem() - numFields := rte.NumField() + valElem := rv.Elem() + if valElem.Kind() == reflect.Ptr { + target := reflect.New(baseType).Elem() + SetValue(valueType.Elem(), valElem, target) + valElem = target + } + + numFields := baseType.NumField() for i := 0; i < numFields; i++ { - if err := u.processField(rte.Field(i), rve.Field(i), m, fullName); err != nil { + if err := u.processField(baseType.Field(i), valElem.Field(i), m, fullName); err != nil { return err } } diff --git a/core/mapping/unmarshaler_test.go b/core/mapping/unmarshaler_test.go index 148dae50..fc4d2cd4 100644 --- a/core/mapping/unmarshaler_test.go +++ b/core/mapping/unmarshaler_test.go @@ -3,6 +3,7 @@ package mapping import ( "encoding/json" "fmt" + "os" "strconv" "strings" "testing" @@ -3388,7 +3389,8 @@ func TestUnmarshal_EnvString(t *testing.T) { envName = "TEST_NAME_STRING" envVal = "this is a name" ) - t.Setenv(envName, envVal) + os.Setenv(envName, envVal) + defer os.Unsetenv(envName) var v Value if assert.NoError(t, UnmarshalKey(emptyMap, &v)) { @@ -3405,7 +3407,8 @@ func TestUnmarshal_EnvStringOverwrite(t *testing.T) { envName = "TEST_NAME_STRING" envVal = "this is a name" ) - t.Setenv(envName, envVal) + os.Setenv(envName, envVal) + defer os.Unsetenv(envName) var v Value if assert.NoError(t, UnmarshalKey(map[string]interface{}{ @@ -3420,8 +3423,12 @@ func TestUnmarshal_EnvInt(t *testing.T) { Age int `key:"age,env=TEST_NAME_INT"` } - const envName = "TEST_NAME_INT" - t.Setenv(envName, "123") + const ( + envName = "TEST_NAME_INT" + envVal = "123" + ) + os.Setenv(envName, envVal) + defer os.Unsetenv(envName) var v Value if assert.NoError(t, UnmarshalKey(emptyMap, &v)) { @@ -3434,8 +3441,12 @@ func TestUnmarshal_EnvIntOverwrite(t *testing.T) { Age int `key:"age,env=TEST_NAME_INT"` } - const envName = "TEST_NAME_INT" - t.Setenv(envName, "123") + const ( + envName = "TEST_NAME_INT" + envVal = "123" + ) + os.Setenv(envName, envVal) + defer os.Unsetenv(envName) var v Value if assert.NoError(t, UnmarshalKey(map[string]interface{}{ @@ -3450,8 +3461,12 @@ func TestUnmarshal_EnvFloat(t *testing.T) { Age float32 `key:"name,env=TEST_NAME_FLOAT"` } - const envName = "TEST_NAME_FLOAT" - t.Setenv(envName, "123.45") + const ( + envName = "TEST_NAME_FLOAT" + envVal = "123.45" + ) + os.Setenv(envName, envVal) + defer os.Unsetenv(envName) var v Value if assert.NoError(t, UnmarshalKey(emptyMap, &v)) { @@ -3464,8 +3479,12 @@ func TestUnmarshal_EnvFloatOverwrite(t *testing.T) { Age float32 `key:"age,env=TEST_NAME_FLOAT"` } - const envName = "TEST_NAME_FLOAT" - t.Setenv(envName, "123.45") + const ( + envName = "TEST_NAME_FLOAT" + envVal = "123.45" + ) + os.Setenv(envName, envVal) + defer os.Unsetenv(envName) var v Value if assert.NoError(t, UnmarshalKey(map[string]interface{}{ @@ -3480,8 +3499,12 @@ func TestUnmarshal_EnvBoolTrue(t *testing.T) { Enable bool `key:"enable,env=TEST_NAME_BOOL_TRUE"` } - const envName = "TEST_NAME_BOOL_TRUE" - t.Setenv(envName, "true") + const ( + envName = "TEST_NAME_BOOL_TRUE" + envVal = "true" + ) + os.Setenv(envName, envVal) + defer os.Unsetenv(envName) var v Value if assert.NoError(t, UnmarshalKey(emptyMap, &v)) { @@ -3494,8 +3517,12 @@ func TestUnmarshal_EnvBoolFalse(t *testing.T) { Enable bool `key:"enable,env=TEST_NAME_BOOL_FALSE"` } - const envName = "TEST_NAME_BOOL_FALSE" - t.Setenv(envName, "false") + const ( + envName = "TEST_NAME_BOOL_FALSE" + envVal = "false" + ) + os.Setenv(envName, envVal) + defer os.Unsetenv(envName) var v Value if assert.NoError(t, UnmarshalKey(emptyMap, &v)) { @@ -3508,8 +3535,12 @@ func TestUnmarshal_EnvBoolBad(t *testing.T) { Enable bool `key:"enable,env=TEST_NAME_BOOL_BAD"` } - const envName = "TEST_NAME_BOOL_BAD" - t.Setenv(envName, "bad") + const ( + envName = "TEST_NAME_BOOL_BAD" + envVal = "bad" + ) + os.Setenv(envName, envVal) + defer os.Unsetenv(envName) var v Value assert.Error(t, UnmarshalKey(emptyMap, &v)) @@ -3520,8 +3551,12 @@ func TestUnmarshal_EnvDuration(t *testing.T) { Duration time.Duration `key:"duration,env=TEST_NAME_DURATION"` } - const envName = "TEST_NAME_DURATION" - t.Setenv(envName, "1s") + const ( + envName = "TEST_NAME_DURATION" + envVal = "1s" + ) + os.Setenv(envName, envVal) + defer os.Unsetenv(envName) var v Value if assert.NoError(t, UnmarshalKey(emptyMap, &v)) { @@ -3534,8 +3569,12 @@ func TestUnmarshal_EnvDurationBadValue(t *testing.T) { Duration time.Duration `key:"duration,env=TEST_NAME_BAD_DURATION"` } - const envName = "TEST_NAME_BAD_DURATION" - t.Setenv(envName, "bad") + const ( + envName = "TEST_NAME_BAD_DURATION" + envVal = "bad" + ) + os.Setenv(envName, envVal) + defer os.Unsetenv(envName) var v Value assert.Error(t, UnmarshalKey(emptyMap, &v)) @@ -3550,7 +3589,8 @@ func TestUnmarshal_EnvWithOptions(t *testing.T) { envName = "TEST_NAME_ENV_OPTIONS_MATCH" envVal = "123" ) - t.Setenv(envName, envVal) + os.Setenv(envName, envVal) + defer os.Unsetenv(envName) var v Value if assert.NoError(t, UnmarshalKey(emptyMap, &v)) { @@ -3567,7 +3607,8 @@ func TestUnmarshal_EnvWithOptionsWrongValueBool(t *testing.T) { envName = "TEST_NAME_ENV_OPTIONS_BOOL" envVal = "false" ) - t.Setenv(envName, envVal) + os.Setenv(envName, envVal) + defer os.Unsetenv(envName) var v Value assert.Error(t, UnmarshalKey(emptyMap, &v)) @@ -3582,7 +3623,8 @@ func TestUnmarshal_EnvWithOptionsWrongValueDuration(t *testing.T) { envName = "TEST_NAME_ENV_OPTIONS_DURATION" envVal = "4s" ) - t.Setenv(envName, envVal) + os.Setenv(envName, envVal) + defer os.Unsetenv(envName) var v Value assert.Error(t, UnmarshalKey(emptyMap, &v)) @@ -3597,7 +3639,8 @@ func TestUnmarshal_EnvWithOptionsWrongValueNumber(t *testing.T) { envName = "TEST_NAME_ENV_OPTIONS_AGE" envVal = "30" ) - t.Setenv(envName, envVal) + os.Setenv(envName, envVal) + defer os.Unsetenv(envName) var v Value assert.Error(t, UnmarshalKey(emptyMap, &v)) @@ -3612,7 +3655,8 @@ func TestUnmarshal_EnvWithOptionsWrongValueString(t *testing.T) { envName = "TEST_NAME_ENV_OPTIONS_STRING" envVal = "this is a name" ) - t.Setenv(envName, envVal) + os.Setenv(envName, envVal) + defer os.Unsetenv(envName) var v Value assert.Error(t, UnmarshalKey(emptyMap, &v)) @@ -4115,6 +4159,20 @@ func TestUnmarshalNestedPtr(t *testing.T) { } } +func TestUnmarshalStructPtrOfPtr(t *testing.T) { + type inner struct { + Int int `key:"int"` + } + m := map[string]interface{}{ + "int": 1, + } + + in := new(inner) + if assert.NoError(t, UnmarshalKey(m, &in)) { + assert.Equal(t, 1, in.Int) + } +} + func BenchmarkDefaultValue(b *testing.B) { for i := 0; i < b.N; i++ { var a struct { diff --git a/rest/engine.go b/rest/engine.go index 6f12ee23..11e1af74 100644 --- a/rest/engine.go +++ b/rest/engine.go @@ -118,7 +118,7 @@ func (ng *engine) buildChainWithNativeMiddlewares(fr featuredRoutes, route Route chn := chain.New() if ng.conf.Middlewares.Trace { - chn = chn.Append(handler.TracingHandler(ng.conf.Name, + chn = chn.Append(handler.TraceHandler(ng.conf.Name, route.Path, handler.WithTraceIgnorePaths(ng.conf.TraceIgnorePaths))) } @@ -204,7 +204,7 @@ func (ng *engine) getShedder(priority bool) load.Shedder { func (ng *engine) notFoundHandler(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { chn := chain.New( - handler.TracingHandler(ng.conf.Name, + handler.TraceHandler(ng.conf.Name, "", handler.WithTraceIgnorePaths(ng.conf.TraceIgnorePaths)), ng.getLogHandler(), diff --git a/rest/handler/tracehandler.go b/rest/handler/tracehandler.go new file mode 100644 index 00000000..bc98e73b --- /dev/null +++ b/rest/handler/tracehandler.go @@ -0,0 +1,78 @@ +package handler + +import ( + "net/http" + + "github.com/zeromicro/go-zero/core/collection" + "github.com/zeromicro/go-zero/core/trace" + "github.com/zeromicro/go-zero/rest/internal/response" + "go.opentelemetry.io/otel" + "go.opentelemetry.io/otel/propagation" + semconv "go.opentelemetry.io/otel/semconv/v1.4.0" + oteltrace "go.opentelemetry.io/otel/trace" +) + +type ( + // TraceOption defines the method to customize an traceOptions. + TraceOption func(options *traceOptions) + + // traceOptions is TraceHandler options. + traceOptions struct { + traceIgnorePaths []string + } +) + +// TraceHandler return a middleware that process the opentelemetry. +func TraceHandler(serviceName, path string, opts ...TraceOption) func(http.Handler) http.Handler { + var options traceOptions + for _, opt := range opts { + opt(&options) + } + + ignorePaths := collection.NewSet() + ignorePaths.AddStr(options.traceIgnorePaths...) + + return func(next http.Handler) http.Handler { + tracer := otel.Tracer(trace.TraceName) + propagator := otel.GetTextMapPropagator() + + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + spanName := path + if len(spanName) == 0 { + spanName = r.URL.Path + } + + if ignorePaths.Contains(spanName) { + next.ServeHTTP(w, r) + return + } + + ctx := propagator.Extract(r.Context(), propagation.HeaderCarrier(r.Header)) + spanCtx, span := tracer.Start( + ctx, + spanName, + oteltrace.WithSpanKind(oteltrace.SpanKindServer), + oteltrace.WithAttributes(semconv.HTTPServerAttributesFromHTTPRequest( + serviceName, spanName, r)...), + ) + defer span.End() + + // convenient for tracking error messages + propagator.Inject(spanCtx, propagation.HeaderCarrier(w.Header())) + + trw := &response.WithCodeResponseWriter{Writer: w, Code: http.StatusOK} + next.ServeHTTP(trw, r.WithContext(spanCtx)) + + span.SetAttributes(semconv.HTTPAttributesFromHTTPStatusCode(trw.Code)...) + span.SetStatus(semconv.SpanStatusFromHTTPStatusCodeAndSpanKind( + trw.Code, oteltrace.SpanKindServer)) + }) + } +} + +// WithTraceIgnorePaths specifies the traceIgnorePaths option for TraceHandler. +func WithTraceIgnorePaths(traceIgnorePaths []string) TraceOption { + return func(options *traceOptions) { + options.traceIgnorePaths = append(options.traceIgnorePaths, traceIgnorePaths...) + } +} diff --git a/rest/handler/tracinghandler_test.go b/rest/handler/tracehandler_test.go similarity index 95% rename from rest/handler/tracinghandler_test.go rename to rest/handler/tracehandler_test.go index 0e110010..9a24ffb1 100644 --- a/rest/handler/tracinghandler_test.go +++ b/rest/handler/tracehandler_test.go @@ -27,7 +27,7 @@ func TestOtelHandler(t *testing.T) { for _, test := range []string{"", "bar"} { t.Run(test, func(t *testing.T) { - h := chain.New(TracingHandler("foo", test)).Then( + h := chain.New(TraceHandler("foo", test)).Then( http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { span := trace.SpanFromContext(r.Context()) assert.True(t, span.SpanContext().IsValid()) @@ -65,7 +65,7 @@ func TestDontTracingSpan(t *testing.T) { for _, test := range []string{"", "bar", "foo"} { t.Run(test, func(t *testing.T) { - h := chain.New(TracingHandler("foo", test, WithTraceIgnorePaths([]string{"bar"}))).Then( + h := chain.New(TraceHandler("foo", test, WithTraceIgnorePaths([]string{"bar"}))).Then( http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { span := trace.SpanFromContext(r.Context()) spanCtx := span.SpanContext() @@ -110,7 +110,7 @@ func TestTraceResponseWriter(t *testing.T) { for _, test := range []int{0, 200, 300, 400, 401, 500, 503} { t.Run(strconv.Itoa(test), func(t *testing.T) { - h := chain.New(TracingHandler("foo", "bar")).Then( + h := chain.New(TraceHandler("foo", "bar")).Then( http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { span := trace.SpanFromContext(r.Context()) spanCtx := span.SpanContext() diff --git a/rest/handler/tracinghandler.go b/rest/handler/tracinghandler.go deleted file mode 100644 index 6de3f164..00000000 --- a/rest/handler/tracinghandler.go +++ /dev/null @@ -1,80 +0,0 @@ -package handler - -import ( - "net/http" - - "github.com/zeromicro/go-zero/core/collection" - "github.com/zeromicro/go-zero/core/trace" - "github.com/zeromicro/go-zero/rest/internal/response" - "go.opentelemetry.io/otel" - "go.opentelemetry.io/otel/propagation" - semconv "go.opentelemetry.io/otel/semconv/v1.4.0" - oteltrace "go.opentelemetry.io/otel/trace" -) - -type ( - // TracingOption defines the method to customize an tracingOptions. - TracingOption func(options *tracingOptions) - - // tracingOptions is TracingHandler options. - tracingOptions struct { - traceIgnorePaths []string - } -) - -// TracingHandler return a middleware that process the opentelemetry. -func TracingHandler(serviceName, path string, opts ...TracingOption) func(http.Handler) http.Handler { - var tracingOpts tracingOptions - for _, opt := range opts { - opt(&tracingOpts) - } - - ignorePaths := collection.NewSet() - ignorePaths.AddStr(tracingOpts.traceIgnorePaths...) - traceHandler := func(checkIgnore bool) func(http.Handler) http.Handler { - return func(next http.Handler) http.Handler { - tracer := otel.Tracer(trace.TraceName) - propagator := otel.GetTextMapPropagator() - - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - spanName := path - if len(spanName) == 0 { - spanName = r.URL.Path - } - - if checkIgnore && ignorePaths.Contains(spanName) { - next.ServeHTTP(w, r) - return - } - - ctx := propagator.Extract(r.Context(), propagation.HeaderCarrier(r.Header)) - spanCtx, span := tracer.Start( - ctx, - spanName, - oteltrace.WithSpanKind(oteltrace.SpanKindServer), - oteltrace.WithAttributes(semconv.HTTPServerAttributesFromHTTPRequest( - serviceName, spanName, r)...), - ) - defer span.End() - - // convenient for tracking error messages - propagator.Inject(spanCtx, propagation.HeaderCarrier(w.Header())) - - trw := &response.WithCodeResponseWriter{Writer: w, Code: http.StatusOK} - next.ServeHTTP(trw, r.WithContext(spanCtx)) - - span.SetAttributes(semconv.HTTPAttributesFromHTTPStatusCode(trw.Code)...) - span.SetStatus(semconv.SpanStatusFromHTTPStatusCodeAndSpanKind(trw.Code, oteltrace.SpanKindServer)) - }) - } - } - checkIgnore := ignorePaths.Count() > 0 - return traceHandler(checkIgnore) -} - -// WithTraceIgnorePaths specifies the traceIgnorePaths option for TracingHandler. -func WithTraceIgnorePaths(traceIgnorePaths []string) TracingOption { - return func(options *tracingOptions) { - options.traceIgnorePaths = append(options.traceIgnorePaths, traceIgnorePaths...) - } -}