From 23deaf50e68c39e278f8b59845dbb3dc5a8df3f5 Mon Sep 17 00:00:00 2001 From: Kevin Wan Date: Wed, 29 Dec 2021 17:37:36 +0800 Subject: [PATCH] feat: support array in default and options tags (#1386) * feat: support array in default and options tags * feat: ignore spaces in tags * test: add more tests --- core/mapping/unmarshaler.go | 73 +++++++++++++-------- core/mapping/unmarshaler_test.go | 105 +++++++++++++++++++++++++++++++ core/mapping/utils.go | 97 +++++++++++++++++++++++++--- core/mapping/utils_test.go | 76 ++++++++++++++++++++++ 4 files changed, 315 insertions(+), 36 deletions(-) diff --git a/core/mapping/unmarshaler.go b/core/mapping/unmarshaler.go index c6fd842b..4fe701b4 100644 --- a/core/mapping/unmarshaler.go +++ b/core/mapping/unmarshaler.go @@ -7,7 +7,6 @@ import ( "reflect" "strings" "sync" - "sync/atomic" "time" "github.com/tal-tech/go-zero/core/jsonx" @@ -25,15 +24,17 @@ var ( errValueNotSettable = errors.New("value is not settable") errValueNotStruct = errors.New("value type is not struct") keyUnmarshaler = NewUnmarshaler(defaultKeyName) - cacheKeys atomic.Value - cacheKeysLock sync.Mutex durationType = reflect.TypeOf(time.Duration(0)) + cacheKeys map[string][]string + cacheKeysLock sync.Mutex + defaultCache map[string]interface{} + defaultCacheLock sync.Mutex emptyMap = map[string]interface{}{} emptyValue = reflect.ValueOf(lang.Placeholder) ) type ( - // A Unmarshaler is used to unmarshal with given tag key. + // Unmarshaler is used to unmarshal with given tag key. Unmarshaler struct { key string opts unmarshalOptions @@ -46,12 +47,11 @@ type ( fromString bool canonicalKey func(key string) string } - - keyCache map[string][]string ) func init() { - cacheKeys.Store(make(keyCache)) + cacheKeys = make(map[string][]string) + defaultCache = make(map[string]interface{}) } // NewUnmarshaler returns a Unmarshaler. @@ -388,7 +388,13 @@ func (u *Unmarshaler) processNamedFieldWithoutValue(field reflect.StructField, v if derefedType == durationType { return fillDurationValue(fieldKind, value, defaultValue) } - return setValue(fieldKind, value, defaultValue) + + switch fieldKind { + case reflect.Array, reflect.Slice: + return u.fillSliceWithDefault(derefedType, value, defaultValue) + default: + return setValue(fieldKind, value, defaultValue) + } } switch fieldKind { @@ -502,7 +508,8 @@ func (u *Unmarshaler) fillSliceFromString(fieldType reflect.Type, value reflect. return nil } -func (u *Unmarshaler) fillSliceValue(slice reflect.Value, index int, baseKind reflect.Kind, value interface{}) error { +func (u *Unmarshaler) fillSliceValue(slice reflect.Value, index int, + baseKind reflect.Kind, value interface{}) error { ithVal := slice.Index(index) switch v := value.(type) { case json.Number: @@ -531,6 +538,28 @@ func (u *Unmarshaler) fillSliceValue(slice reflect.Value, index int, baseKind re } } +func (u *Unmarshaler) fillSliceWithDefault(derefedType reflect.Type, value reflect.Value, + defaultValue string) error { + baseFieldType := Deref(derefedType.Elem()) + baseFieldKind := baseFieldType.Kind() + defaultCacheLock.Lock() + slice, ok := defaultCache[defaultValue] + defaultCacheLock.Unlock() + if !ok { + if baseFieldKind == reflect.String { + slice = parseGroupedSegments(defaultValue) + } else if err := jsonx.UnmarshalFromString(defaultValue, &slice); err != nil { + return err + } + + defaultCacheLock.Lock() + defaultCache[defaultValue] = slice + defaultCacheLock.Unlock() + } + + return u.fillSlice(derefedType, value, slice) +} + func (u *Unmarshaler) generateMap(keyType, elemType reflect.Type, mapValue interface{}) (reflect.Value, error) { mapType := reflect.MapOf(keyType, elemType) valueType := reflect.TypeOf(mapValue) @@ -724,20 +753,6 @@ func getValueWithChainedKeys(m Valuer, keys []string) (interface{}, bool) { return nil, false } -func insertKeys(key string, cache []string) { - cacheKeysLock.Lock() - defer cacheKeysLock.Unlock() - - keys := cacheKeys.Load().(keyCache) - // copy the contents into the new map, to guarantee the old map is immutable - newKeys := make(keyCache) - for k, v := range keys { - newKeys[k] = v - } - newKeys[key] = cache - cacheKeys.Store(newKeys) -} - func join(elem ...string) string { var builder strings.Builder @@ -768,15 +783,19 @@ func newTypeMismatchError(name string) error { } func readKeys(key string) []string { - cache := cacheKeys.Load().(keyCache) - if keys, ok := cache[key]; ok { + cacheKeysLock.Lock() + keys, ok := cacheKeys[key] + cacheKeysLock.Unlock() + if ok { return keys } - keys := strings.FieldsFunc(key, func(c rune) bool { + keys = strings.FieldsFunc(key, func(c rune) bool { return c == delimiter }) - insertKeys(key, keys) + cacheKeysLock.Lock() + cacheKeys[key] = keys + cacheKeysLock.Unlock() return keys } diff --git a/core/mapping/unmarshaler_test.go b/core/mapping/unmarshaler_test.go index 143333a0..6c261887 100644 --- a/core/mapping/unmarshaler_test.go +++ b/core/mapping/unmarshaler_test.go @@ -198,6 +198,66 @@ func TestUnmarshalIntWithDefault(t *testing.T) { assert.Equal(t, 1, in.Int) } +func TestUnmarshalBoolSliceWithDefault(t *testing.T) { + type inner struct { + Bools []bool `key:"bools,default=[true,false]"` + } + + var in inner + assert.Nil(t, UnmarshalKey(nil, &in)) + assert.ElementsMatch(t, []bool{true, false}, in.Bools) +} + +func TestUnmarshalIntSliceWithDefault(t *testing.T) { + type inner struct { + Ints []int `key:"ints,default=[1,2,3]"` + } + + var in inner + assert.Nil(t, UnmarshalKey(nil, &in)) + assert.ElementsMatch(t, []int{1, 2, 3}, in.Ints) +} + +func TestUnmarshalIntSliceWithDefaultHasSpaces(t *testing.T) { + type inner struct { + Ints []int `key:"ints,default=[1, 2, 3]"` + } + + var in inner + assert.Nil(t, UnmarshalKey(nil, &in)) + assert.ElementsMatch(t, []int{1, 2, 3}, in.Ints) +} + +func TestUnmarshalFloatSliceWithDefault(t *testing.T) { + type inner struct { + Floats []float32 `key:"floats,default=[1.1,2.2,3.3]"` + } + + var in inner + assert.Nil(t, UnmarshalKey(nil, &in)) + assert.ElementsMatch(t, []float32{1.1, 2.2, 3.3}, in.Floats) +} + +func TestUnmarshalStringSliceWithDefault(t *testing.T) { + type inner struct { + Strs []string `key:"strs,default=[foo,bar,woo]"` + } + + var in inner + assert.Nil(t, UnmarshalKey(nil, &in)) + assert.ElementsMatch(t, []string{"foo", "bar", "woo"}, in.Strs) +} + +func TestUnmarshalStringSliceWithDefaultHasSpaces(t *testing.T) { + type inner struct { + Strs []string `key:"strs,default=[foo, bar, woo]"` + } + + var in inner + assert.Nil(t, UnmarshalKey(nil, &in)) + assert.ElementsMatch(t, []string{"foo", "bar", "woo"}, in.Strs) +} + func TestUnmarshalUint(t *testing.T) { type inner struct { Uint uint `key:"uint"` @@ -861,10 +921,12 @@ func TestUnmarshalSliceOfStruct(t *testing.T) { func TestUnmarshalWithStringOptionsCorrect(t *testing.T) { type inner struct { Value string `key:"value,options=first|second"` + Foo string `key:"foo,options=[bar,baz]"` Correct string `key:"correct,options=1|2"` } m := map[string]interface{}{ "value": "first", + "foo": "bar", "correct": "2", } @@ -872,6 +934,7 @@ func TestUnmarshalWithStringOptionsCorrect(t *testing.T) { ast := assert.New(t) ast.Nil(UnmarshalKey(m, &in)) ast.Equal("first", in.Value) + ast.Equal("bar", in.Foo) ast.Equal("2", in.Correct) } @@ -943,6 +1006,22 @@ func TestUnmarshalStringOptionsWithStringOptionsIncorrect(t *testing.T) { ast.NotNil(unmarshaler.Unmarshal(m, &in)) } +func TestUnmarshalStringOptionsWithStringOptionsIncorrectGrouped(t *testing.T) { + type inner struct { + Value string `key:"value,options=[first,second]"` + Correct string `key:"correct,options=1|2"` + } + m := map[string]interface{}{ + "value": "third", + "correct": "2", + } + + var in inner + unmarshaler := NewUnmarshaler(defaultKeyName, WithStringValues()) + ast := assert.New(t) + ast.NotNil(unmarshaler.Unmarshal(m, &in)) +} + func TestUnmarshalWithStringOptionsIncorrect(t *testing.T) { type inner struct { Value string `key:"value,options=first|second"` @@ -2518,3 +2597,29 @@ func TestUnmarshalJsonReaderPtrArray(t *testing.T) { assert.Nil(t, err) assert.Equal(t, 3, len(res.B)) } + +func TestUnmarshalJsonWithoutKey(t *testing.T) { + payload := `{"A": "1", "B": "2"}` + var res struct { + A string `json:""` + B string `json:","` + } + reader := strings.NewReader(payload) + err := UnmarshalJsonReader(reader, &res) + assert.Nil(t, err) + assert.Equal(t, "1", res.A) + assert.Equal(t, "2", res.B) +} + +func BenchmarkDefaultValue(b *testing.B) { + for i := 0; i < b.N; i++ { + var a struct { + Ints []int `json:"ints,default=[1,2,3]"` + Strs []string `json:"strs,default=[foo,bar,baz]"` + } + _ = UnmarshalJsonMap(nil, &a) + if len(a.Strs) != 3 || len(a.Ints) != 3 { + b.Fatal("failed") + } + } +} diff --git a/core/mapping/utils.go b/core/mapping/utils.go index e8ff102e..35a8a38a 100644 --- a/core/mapping/utils.go +++ b/core/mapping/utils.go @@ -14,13 +14,19 @@ import ( ) const ( - defaultOption = "default" - stringOption = "string" - optionalOption = "optional" - optionsOption = "options" - rangeOption = "range" - optionSeparator = "|" - equalToken = "=" + defaultOption = "default" + stringOption = "string" + optionalOption = "optional" + optionsOption = "options" + rangeOption = "range" + optionSeparator = "|" + equalToken = "=" + escapeChar = '\\' + leftBracket = '(' + rightBracket = ')' + leftSquareBracket = '[' + rightSquareBracket = ']' + segmentSeparator = ',' ) var ( @@ -118,7 +124,7 @@ func convertType(kind reflect.Kind, str string) (interface{}, error) { } func doParseKeyAndOptions(field reflect.StructField, value string) (string, *fieldOptions, error) { - segments := strings.Split(value, ",") + segments := parseSegments(value) key := strings.TrimSpace(segments[0]) options := segments[1:] @@ -198,6 +204,16 @@ func maybeNewValue(field reflect.StructField, value reflect.Value) { } } +func parseGroupedSegments(val string) []string { + val = strings.TrimLeftFunc(val, func(r rune) bool { + return r == leftBracket || r == leftSquareBracket + }) + val = strings.TrimRightFunc(val, func(r rune) bool { + return r == rightBracket || r == rightSquareBracket + }) + return parseSegments(val) +} + // don't modify returned fieldOptions, it's cached and shared among different calls. func parseKeyAndOptions(tagName string, field reflect.StructField) (string, *fieldOptions, error) { value := field.Tag.Get(tagName) @@ -309,7 +325,7 @@ func parseOption(fieldOpts *fieldOptions, fieldName, option string) error { return fmt.Errorf("field %s has wrong options", fieldName) } - fieldOpts.Options = strings.Split(segs[1], optionSeparator) + fieldOpts.Options = parseOptions(segs[1]) case strings.HasPrefix(option, defaultOption): segs := strings.Split(option, equalToken) if len(segs) != 2 { @@ -334,6 +350,69 @@ func parseOption(fieldOpts *fieldOptions, fieldName, option string) error { return nil } +// parseOptions parses the given options in tag. +// for example: `json:"name,options=foo|bar"` or `json:"name,options=[foo,bar]"` +func parseOptions(val string) []string { + if len(val) == 0 { + return nil + } + + if val[0] == leftSquareBracket { + return parseGroupedSegments(val) + } + + return strings.Split(val, optionSeparator) +} + +func parseSegments(val string) []string { + var segments []string + var escaped, grouped bool + var buf strings.Builder + + for _, ch := range val { + if escaped { + buf.WriteRune(ch) + escaped = false + continue + } + + switch ch { + case segmentSeparator: + if grouped { + buf.WriteRune(ch) + } else { + // need to trim spaces, but we cannot ignore empty string, + // because the first segment stands for the key might be empty. + // if ignored, the later tag will be used as the key. + segments = append(segments, strings.TrimSpace(buf.String())) + buf.Reset() + } + case escapeChar: + if grouped { + buf.WriteRune(ch) + } else { + escaped = true + } + case leftBracket, leftSquareBracket: + buf.WriteRune(ch) + grouped = true + case rightBracket, rightSquareBracket: + buf.WriteRune(ch) + grouped = false + default: + buf.WriteRune(ch) + } + } + + last := strings.TrimSpace(buf.String()) + // ignore last empty string + if len(last) > 0 { + segments = append(segments, last) + } + + return segments +} + func reprOfValue(val reflect.Value) string { switch vt := val.Interface().(type) { case bool: diff --git a/core/mapping/utils_test.go b/core/mapping/utils_test.go index 85b27c42..fd5194cd 100644 --- a/core/mapping/utils_test.go +++ b/core/mapping/utils_test.go @@ -90,6 +90,82 @@ func TestParseKeyAndOptionWithTagAndOption(t *testing.T) { assert.True(t, options.FromString) } +func TestParseSegments(t *testing.T) { + tests := []struct { + input string + expect []string + }{ + { + input: "", + expect: []string{}, + }, + { + input: ",", + expect: []string{""}, + }, + { + input: "foo,", + expect: []string{"foo"}, + }, + { + input: ",foo", + // the first empty string cannot be ignored, it's the key. + expect: []string{"", "foo"}, + }, + { + input: "foo", + expect: []string{"foo"}, + }, + { + input: "foo,bar", + expect: []string{"foo", "bar"}, + }, + { + input: "foo,bar,baz", + expect: []string{"foo", "bar", "baz"}, + }, + { + input: "foo,options=a|b", + expect: []string{"foo", "options=a|b"}, + }, + { + input: "foo,bar,default=[baz,qux]", + expect: []string{"foo", "bar", "default=[baz,qux]"}, + }, + { + input: "foo,bar,options=[baz,qux]", + expect: []string{"foo", "bar", "options=[baz,qux]"}, + }, + { + input: `foo\,bar,options=[baz,qux]`, + expect: []string{`foo,bar`, "options=[baz,qux]"}, + }, + { + input: `foo,bar,options=\[baz,qux]`, + expect: []string{"foo", "bar", "options=[baz", "qux]"}, + }, + { + input: `foo,bar,options=[baz\,qux]`, + expect: []string{"foo", "bar", `options=[baz\,qux]`}, + }, + { + input: `foo\,bar,options=[baz,qux],default=baz`, + expect: []string{`foo,bar`, "options=[baz,qux]", "default=baz"}, + }, + { + input: `foo\,bar,options=[baz,qux, quux],default=[qux, baz]`, + expect: []string{`foo,bar`, "options=[baz,qux, quux]", "default=[qux, baz]"}, + }, + } + + for _, test := range tests { + test := test + t.Run(test.input, func(t *testing.T) { + assert.ElementsMatch(t, test.expect, parseSegments(test.input)) + }) + } +} + func TestValidatePtrWithNonPtr(t *testing.T) { var foo string rve := reflect.ValueOf(foo)