diff --git a/core/conf/config.go b/core/conf/config.go index 605f218c..bdaf7061 100644 --- a/core/conf/config.go +++ b/core/conf/config.go @@ -5,6 +5,7 @@ import ( "log" "os" "path" + "reflect" "strings" "github.com/zeromicro/go-zero/core/jsonx" @@ -21,6 +22,12 @@ var loaders = map[string]func([]byte, interface{}) error{ ".yml": LoadFromYamlBytes, } +type fieldInfo struct { + name string + kind reflect.Kind + children map[string]fieldInfo +} + // Load loads config into v from file, .json, .yaml and .yml are acceptable. func Load(file string, v interface{}, opts ...Option) error { content, err := os.ReadFile(file) @@ -58,7 +65,10 @@ func LoadFromJsonBytes(content []byte, v interface{}) error { return err } - return mapping.UnmarshalJsonMap(toCamelCaseKeyMap(m), v, mapping.WithCanonicalKeyFunc(toCamelCase)) + finfo := buildFieldsInfo(reflect.TypeOf(v)) + camelCaseKeyMap := toCamelCaseKeyMap(m, finfo) + + return mapping.UnmarshalJsonMap(camelCaseKeyMap, v, mapping.WithCanonicalKeyFunc(toCamelCase)) } // LoadConfigFromJsonBytes loads config into v from content json bytes. @@ -100,6 +110,64 @@ func MustLoad(path string, v interface{}, opts ...Option) { } } +func buildFieldsInfo(tp reflect.Type) map[string]fieldInfo { + tp = mapping.Deref(tp) + + switch tp.Kind() { + case reflect.Struct: + return buildStructFieldsInfo(tp) + case reflect.Array, reflect.Slice: + return buildFieldsInfo(mapping.Deref(tp.Elem())) + default: + return nil + } +} + +func buildStructFieldsInfo(tp reflect.Type) map[string]fieldInfo { + info := make(map[string]fieldInfo) + + for i := 0; i < tp.NumField(); i++ { + field := tp.Field(i) + name := field.Name + ccName := toCamelCase(name) + ft := mapping.Deref(field.Type) + + // flatten anonymous fields + if field.Anonymous { + if ft.Kind() == reflect.Struct { + fields := buildFieldsInfo(ft) + for k, v := range fields { + info[k] = v + } + } else { + info[ccName] = fieldInfo{ + name: name, + kind: ft.Kind(), + } + } + continue + } + + var fields map[string]fieldInfo + switch ft.Kind() { + case reflect.Struct: + fields = buildFieldsInfo(ft) + case reflect.Array, reflect.Slice: + fields = buildFieldsInfo(ft.Elem()) + case reflect.Map: + fields = buildFieldsInfo(ft.Elem()) + } + + info[ccName] = fieldInfo{ + name: name, + kind: ft.Kind(), + children: fields, + } + } + + return info +} + func toCamelCase(s string) string { var buf strings.Builder buf.Grow(len(s)) @@ -123,14 +191,19 @@ func toCamelCase(s string) string { if isCap || isLow { buf.WriteRune(v) capNext = false - } else if v == ' ' || v == '\t' { + continue + } + + switch v { + // '.' is used for chained keys, e.g. "grand.parent.child" + case ' ', '.', '\t': buf.WriteRune(v) capNext = false boundary = true - } else if v == '_' { + case '_': capNext = true boundary = true - } else { + default: buf.WriteRune(v) capNext = true } @@ -139,14 +212,14 @@ func toCamelCase(s string) string { return buf.String() } -func toCamelCaseInterface(v interface{}) interface{} { +func toCamelCaseInterface(v interface{}, info map[string]fieldInfo) interface{} { switch vv := v.(type) { case map[string]interface{}: - return toCamelCaseKeyMap(vv) + return toCamelCaseKeyMap(vv, info) case []interface{}: var arr []interface{} for _, vvv := range vv { - arr = append(arr, toCamelCaseInterface(vvv)) + arr = append(arr, toCamelCaseInterface(vvv, info)) } return arr default: @@ -154,10 +227,22 @@ func toCamelCaseInterface(v interface{}) interface{} { } } -func toCamelCaseKeyMap(m map[string]interface{}) map[string]interface{} { +func toCamelCaseKeyMap(m map[string]interface{}, info map[string]fieldInfo) map[string]interface{} { res := make(map[string]interface{}) + for k, v := range m { - res[toCamelCase(k)] = toCamelCaseInterface(v) + ti, ok := info[k] + if ok { + res[k] = toCamelCaseInterface(v, ti.children) + continue + } + + cck := toCamelCase(k) + if ti, ok = info[cck]; ok { + res[toCamelCase(k)] = toCamelCaseInterface(v, ti.children) + } else { + res[k] = v + } } return res diff --git a/core/conf/config_test.go b/core/conf/config_test.go index 6485ef46..547b4745 100644 --- a/core/conf/config_test.go +++ b/core/conf/config_test.go @@ -283,6 +283,10 @@ func TestToCamelCase(t *testing.T) { input: "Hello World Foo_Bar", expect: "hello world fooBar", }, + { + input: "Hello.World Foo_Bar", + expect: "hello.world fooBar", + }, { input: "你好 World Foo_Bar", expect: "你好 world fooBar", @@ -328,6 +332,84 @@ func TestLoadFromYamlBytes(t *testing.T) { assert.Equal(t, "foo", val.Layer1.Layer2.Layer3) } +func TestLoadFromYamlBytesLayers(t *testing.T) { + input := []byte(`layer1: + layer2: + layer3: foo`) + var val struct { + Value string `json:"Layer1.Layer2.Layer3"` + } + + assert.NoError(t, LoadFromYamlBytes(input, &val)) + assert.Equal(t, "foo", val.Value) +} + +func TestUnmarshalJsonBytesMap(t *testing.T) { + input := []byte(`{"foo":{"/mtproto.RPCTos": "bff.bff","bar":"baz"}}`) + + var val struct { + Foo map[string]string + } + + assert.NoError(t, LoadFromJsonBytes(input, &val)) + assert.Equal(t, "bff.bff", val.Foo["/mtproto.RPCTos"]) + assert.Equal(t, "baz", val.Foo["bar"]) +} + +func TestUnmarshalJsonBytesMapWithSliceElements(t *testing.T) { + input := []byte(`{"foo":{"/mtproto.RPCTos": ["bff.bff", "any"],"bar":["baz", "qux"]}}`) + + var val struct { + Foo map[string][]string + } + + assert.NoError(t, LoadFromJsonBytes(input, &val)) + assert.EqualValues(t, []string{"bff.bff", "any"}, val.Foo["/mtproto.RPCTos"]) + assert.EqualValues(t, []string{"baz", "qux"}, val.Foo["bar"]) +} + +func TestUnmarshalJsonBytesMapWithSliceOfStructs(t *testing.T) { + input := []byte(`{"foo":{ + "/mtproto.RPCTos": [{"bar": "any"}], + "bar":[{"bar": "qux"}, {"bar": "ever"}]}}`) + + var val struct { + Foo map[string][]struct { + Bar string + } + } + + assert.NoError(t, LoadFromJsonBytes(input, &val)) + assert.Equal(t, 1, len(val.Foo["/mtproto.RPCTos"])) + assert.Equal(t, "any", val.Foo["/mtproto.RPCTos"][0].Bar) + assert.Equal(t, 2, len(val.Foo["bar"])) + assert.Equal(t, "qux", val.Foo["bar"][0].Bar) + assert.Equal(t, "ever", val.Foo["bar"][1].Bar) +} + +func TestUnmarshalJsonBytesWithAnonymousField(t *testing.T) { + type ( + Int int + + InnerConf struct { + Name string + } + + Conf struct { + Int + InnerConf + } + ) + + var ( + input = []byte(`{"Name": "hello", "int": 3}`) + c Conf + ) + assert.NoError(t, LoadFromJsonBytes(input, &c)) + assert.Equal(t, "hello", c.Name) + assert.Equal(t, Int(3), c.Int) +} + func createTempFile(ext, text string) (string, error) { tmpfile, err := os.CreateTemp(os.TempDir(), hash.Md5Hex([]byte(text))+"*"+ext) if err != nil { diff --git a/core/mapping/unmarshaler.go b/core/mapping/unmarshaler.go index 59977b44..767b2903 100644 --- a/core/mapping/unmarshaler.go +++ b/core/mapping/unmarshaler.go @@ -376,19 +376,51 @@ func (u *Unmarshaler) processAnonymousField(field reflect.StructField, value ref return err } - if _, hasValue := getValue(m, key); hasValue { - return fmt.Errorf("fields of %s can't be wrapped inside, because it's anonymous", key) - } - if options.optional() { - return u.processAnonymousFieldOptional(field.Type, value, key, m, fullName) + return u.processAnonymousFieldOptional(field, value, key, m, fullName) } - return u.processAnonymousFieldRequired(field.Type, value, m, fullName) + return u.processAnonymousFieldRequired(field, value, m, fullName) } -func (u *Unmarshaler) processAnonymousFieldOptional(fieldType reflect.Type, value reflect.Value, +func (u *Unmarshaler) processAnonymousFieldOptional(field reflect.StructField, value reflect.Value, key string, m valuerWithParent, fullName string) error { + derefedFieldType := Deref(field.Type) + + switch derefedFieldType.Kind() { + case reflect.Struct: + return u.processAnonymousStructFieldOptional(field.Type, value, key, m, fullName) + default: + return u.processNamedField(field, value, m, fullName) + } +} + +func (u *Unmarshaler) processAnonymousFieldRequired(field reflect.StructField, value reflect.Value, + m valuerWithParent, fullName string) error { + fieldType := field.Type + maybeNewValue(fieldType, value) + derefedFieldType := Deref(fieldType) + indirectValue := reflect.Indirect(value) + + switch derefedFieldType.Kind() { + case reflect.Struct: + for i := 0; i < derefedFieldType.NumField(); i++ { + if err := u.processField(derefedFieldType.Field(i), indirectValue.Field(i), + m, fullName); err != nil { + return err + } + } + default: + if err := u.processNamedField(field, indirectValue, m, fullName); err != nil { + return err + } + } + + return nil +} + +func (u *Unmarshaler) processAnonymousStructFieldOptional(fieldType reflect.Type, + value reflect.Value, key string, m valuerWithParent, fullName string) error { var filled bool var required int var requiredFilled int @@ -428,21 +460,6 @@ func (u *Unmarshaler) processAnonymousFieldOptional(fieldType reflect.Type, valu return nil } -func (u *Unmarshaler) processAnonymousFieldRequired(fieldType reflect.Type, value reflect.Value, - m valuerWithParent, fullName string) error { - maybeNewValue(fieldType, value) - derefedFieldType := Deref(fieldType) - indirectValue := reflect.Indirect(value) - - for i := 0; i < derefedFieldType.NumField(); i++ { - if err := u.processField(derefedFieldType.Field(i), indirectValue.Field(i), m, fullName); err != nil { - return err - } - } - - return nil -} - func (u *Unmarshaler) processField(field reflect.StructField, value reflect.Value, m valuerWithParent, fullName string) error { if usingDifferentKeys(u.key, field) { diff --git a/core/mapping/unmarshaler_test.go b/core/mapping/unmarshaler_test.go index 356cbb3a..c45b0600 100644 --- a/core/mapping/unmarshaler_test.go +++ b/core/mapping/unmarshaler_test.go @@ -212,6 +212,24 @@ func TestUnmarshalIntPtr(t *testing.T) { assert.Equal(t, 1, *in.Int) } +func TestUnmarshalIntSliceOfPtr(t *testing.T) { + type inner struct { + Ints []*int `key:"ints"` + } + m := map[string]interface{}{ + "ints": []int{1, 2, 3}, + } + + var in inner + assert.NoError(t, UnmarshalKey(m, &in)) + assert.NotEmpty(t, in.Ints) + var ints []int + for _, i := range in.Ints { + ints = append(ints, *i) + } + assert.EqualValues(t, []int{1, 2, 3}, ints) +} + func TestUnmarshalIntWithDefault(t *testing.T) { type inner struct { Int int `key:"int,default=5"` @@ -3665,6 +3683,7 @@ func TestUnmarshalJsonBytesSliceOfMaps(t *testing.T) { Name string `json:"name"` ActualAmount int `json:"actual_amount"` } + OrderApplyRefundReq struct { OrderId string `json:"order_id"` RefundReason RefundReasonData `json:"refund_reason,optional"` @@ -3676,6 +3695,130 @@ func TestUnmarshalJsonBytesSliceOfMaps(t *testing.T) { assert.NoError(t, UnmarshalJsonBytes(input, &req)) } +func TestUnmarshalJsonBytesWithAnonymousField(t *testing.T) { + type ( + Int int + + InnerConf struct { + Name string + } + + Conf struct { + Int + InnerConf + } + ) + + var ( + input = []byte(`{"Name": "hello", "Int": 3}`) + c Conf + ) + assert.NoError(t, UnmarshalJsonBytes(input, &c)) + assert.Equal(t, "hello", c.Name) + assert.Equal(t, Int(3), c.Int) +} + +func TestUnmarshalJsonBytesWithAnonymousFieldOptional(t *testing.T) { + type ( + Int int + + InnerConf struct { + Name string + } + + Conf struct { + Int `json:",optional"` + InnerConf + } + ) + + var ( + input = []byte(`{"Name": "hello", "Int": 3}`) + c Conf + ) + assert.NoError(t, UnmarshalJsonBytes(input, &c)) + assert.Equal(t, "hello", c.Name) + assert.Equal(t, Int(3), c.Int) +} + +func TestUnmarshalJsonBytesWithAnonymousFieldBadTag(t *testing.T) { + type ( + Int int + + InnerConf struct { + Name string + } + + Conf struct { + Int `json:",optional=123"` + InnerConf + } + ) + + var ( + input = []byte(`{"Name": "hello", "Int": 3}`) + c Conf + ) + assert.Error(t, UnmarshalJsonBytes(input, &c)) +} + +func TestUnmarshalJsonBytesWithAnonymousFieldBadValue(t *testing.T) { + type ( + Int int + + InnerConf struct { + Name string + } + + Conf struct { + Int + InnerConf + } + ) + + var ( + input = []byte(`{"Name": "hello", "Int": "3"}`) + c Conf + ) + assert.Error(t, UnmarshalJsonBytes(input, &c)) +} + +func TestUnmarshalJsonBytesWithAnonymousFieldBadTagInStruct(t *testing.T) { + type ( + InnerConf struct { + Name string `json:",optional=123"` + } + + Conf struct { + InnerConf `json:",optional"` + } + ) + + var ( + input = []byte(`{"Name": "hello"}`) + c Conf + ) + assert.Error(t, UnmarshalJsonBytes(input, &c)) +} + +func TestUnmarshalJsonBytesWithAnonymousFieldNotInOptions(t *testing.T) { + type ( + InnerConf struct { + Name string `json:",options=[a,b]"` + } + + Conf struct { + InnerConf `json:",optional"` + } + ) + + var ( + input = []byte(`{"Name": "hello"}`) + c Conf + ) + assert.Error(t, UnmarshalJsonBytes(input, &c)) +} + func BenchmarkDefaultValue(b *testing.B) { for i := 0; i < b.N; i++ { var a struct { diff --git a/core/mapping/valuer_test.go b/core/mapping/valuer_test.go index 01ee56f5..f154ac1b 100644 --- a/core/mapping/valuer_test.go +++ b/core/mapping/valuer_test.go @@ -31,3 +31,27 @@ func TestMapValuerWithInherit_Value(t *testing.T) { assert.Equal(t, "localhost", m["host"]) assert.Equal(t, 8080, m["port"]) } + +func TestRecursiveValuer_Value(t *testing.T) { + input := map[string]interface{}{ + "component": map[string]interface{}{ + "name": "test", + "foo": map[string]interface{}{ + "bar": "baz", + }, + }, + "foo": "value", + } + valuer := recursiveValuer{ + current: mapValuer(input["component"].(map[string]interface{})), + parent: simpleValuer{ + current: mapValuer(input), + }, + } + + val, ok := valuer.Value("foo") + assert.True(t, ok) + assert.EqualValues(t, map[string]interface{}{ + "bar": "baz", + }, val) +}