Compare commits
72 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
2ea0a843f8 | ||
|
|
9e0e01b2bc | ||
|
|
af50a80d01 | ||
|
|
703fb8d970 | ||
|
|
e964e530e1 | ||
|
|
52265087d1 | ||
|
|
b4c2677eb9 | ||
|
|
30296fb1ca | ||
|
|
356c80defd | ||
|
|
8c31525378 | ||
|
|
2cf09f3c36 | ||
|
|
d41e542c92 | ||
|
|
265a24ac6d | ||
|
|
7d88fc39dc | ||
|
|
6957b6a344 | ||
|
|
bca6a230c8 | ||
|
|
cc8413d683 | ||
|
|
3842283fa8 | ||
|
|
fe13a533f5 | ||
|
|
7a327ccda4 | ||
|
|
06e4507406 | ||
|
|
8794d5b753 | ||
|
|
9bfa63d995 | ||
|
|
a432b121fb | ||
|
|
b61c94bb66 | ||
|
|
93fcf899dc | ||
|
|
9f4b3bae92 | ||
|
|
805cb87d98 | ||
|
|
366131640e | ||
|
|
956884a3ff | ||
|
|
f571cb8af2 | ||
|
|
cc5acf3b90 | ||
|
|
e1aa665443 | ||
|
|
cd357d9484 | ||
|
|
6d4d7cbd6b | ||
|
|
c593b5b531 | ||
|
|
fd5b38b07c | ||
|
|
41efb48f55 | ||
|
|
0ef3626839 | ||
|
|
77a72b16e9 | ||
|
|
21566f1b7a | ||
|
|
b2646e228b | ||
|
|
588b883710 | ||
|
|
033910bbd8 | ||
|
|
530dd79e3f | ||
|
|
cd5263ac75 | ||
|
|
ea3302a468 | ||
|
|
abf15b373c | ||
|
|
a865e9ee29 | ||
|
|
f8292198cf | ||
|
|
016d965f56 | ||
|
|
95d7c73409 | ||
|
|
939ef2a181 | ||
|
|
f0b8dd45fe | ||
|
|
0ba9335b04 | ||
|
|
04f181f0b4 | ||
|
|
89f841c126 | ||
|
|
d785c8c377 | ||
|
|
687a1d15da | ||
|
|
aaa974e1ad | ||
|
|
2779568ccf | ||
|
|
f7d50ae626 | ||
|
|
33594ea350 | ||
|
|
ee2ec974c4 | ||
|
|
fd2f2f0f54 | ||
|
|
86a2429d7d | ||
|
|
e5fe5dcc50 | ||
|
|
b510e7c242 | ||
|
|
dfe92e709f | ||
|
|
cb649cf627 | ||
|
|
ce19a5ade6 | ||
|
|
6dc56de714 |
@@ -213,23 +213,23 @@ func (s *Set) validate(i interface{}) {
|
||||
switch i.(type) {
|
||||
case int:
|
||||
if s.tp != intType {
|
||||
logx.Errorf("Error: element is int, but set contains elements with type %d", s.tp)
|
||||
logx.Errorf("element is int, but set contains elements with type %d", s.tp)
|
||||
}
|
||||
case int64:
|
||||
if s.tp != int64Type {
|
||||
logx.Errorf("Error: element is int64, but set contains elements with type %d", s.tp)
|
||||
logx.Errorf("element is int64, but set contains elements with type %d", s.tp)
|
||||
}
|
||||
case uint:
|
||||
if s.tp != uintType {
|
||||
logx.Errorf("Error: element is uint, but set contains elements with type %d", s.tp)
|
||||
logx.Errorf("element is uint, but set contains elements with type %d", s.tp)
|
||||
}
|
||||
case uint64:
|
||||
if s.tp != uint64Type {
|
||||
logx.Errorf("Error: element is uint64, but set contains elements with type %d", s.tp)
|
||||
logx.Errorf("element is uint64, but set contains elements with type %d", s.tp)
|
||||
}
|
||||
case string:
|
||||
if s.tp != stringType {
|
||||
logx.Errorf("Error: element is string, but set contains elements with type %d", s.tp)
|
||||
logx.Errorf("element is string, but set contains elements with type %d", s.tp)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -13,17 +13,29 @@ import (
|
||||
"github.com/zeromicro/go-zero/internal/encoding"
|
||||
)
|
||||
|
||||
var loaders = map[string]func([]byte, interface{}) error{
|
||||
".json": LoadFromJsonBytes,
|
||||
".toml": LoadFromTomlBytes,
|
||||
".yaml": LoadFromYamlBytes,
|
||||
".yml": LoadFromYamlBytes,
|
||||
const jsonTagKey = "json"
|
||||
|
||||
var (
|
||||
fillDefaultUnmarshaler = mapping.NewUnmarshaler(jsonTagKey, mapping.WithDefault())
|
||||
loaders = map[string]func([]byte, interface{}) error{
|
||||
".json": LoadFromJsonBytes,
|
||||
".toml": LoadFromTomlBytes,
|
||||
".yaml": LoadFromYamlBytes,
|
||||
".yml": LoadFromYamlBytes,
|
||||
}
|
||||
)
|
||||
|
||||
// children and mapField should not be both filled.
|
||||
// named fields and map cannot be bound to the same field name.
|
||||
type fieldInfo struct {
|
||||
children map[string]*fieldInfo
|
||||
mapField *fieldInfo
|
||||
}
|
||||
|
||||
type fieldInfo struct {
|
||||
name string
|
||||
kind reflect.Kind
|
||||
children map[string]fieldInfo
|
||||
// FillDefault fills the default values for the given v,
|
||||
// and the premise is that the value of v must be guaranteed to be empty.
|
||||
func FillDefault(v interface{}) error {
|
||||
return fillDefaultUnmarshaler.Unmarshal(map[string]interface{}{}, v)
|
||||
}
|
||||
|
||||
// Load loads config into v from file, .json, .yaml and .yml are acceptable.
|
||||
@@ -58,13 +70,17 @@ func LoadConfig(file string, v interface{}, opts ...Option) error {
|
||||
|
||||
// LoadFromJsonBytes loads config into v from content json bytes.
|
||||
func LoadFromJsonBytes(content []byte, v interface{}) error {
|
||||
info, err := buildFieldsInfo(reflect.TypeOf(v))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var m map[string]interface{}
|
||||
if err := jsonx.Unmarshal(content, &m); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
finfo := buildFieldsInfo(reflect.TypeOf(v))
|
||||
lowerCaseKeyMap := toLowerCaseKeyMap(m, finfo)
|
||||
lowerCaseKeyMap := toLowerCaseKeyMap(m, info)
|
||||
|
||||
return mapping.UnmarshalJsonMap(lowerCaseKeyMap, v, mapping.WithCanonicalKeyFunc(toLowerCase))
|
||||
}
|
||||
@@ -108,7 +124,63 @@ func MustLoad(path string, v interface{}, opts ...Option) {
|
||||
}
|
||||
}
|
||||
|
||||
func buildFieldsInfo(tp reflect.Type) map[string]fieldInfo {
|
||||
func addOrMergeFields(info *fieldInfo, key string, child *fieldInfo) error {
|
||||
if prev, ok := info.children[key]; ok {
|
||||
if child.mapField != nil {
|
||||
return newDupKeyError(key)
|
||||
}
|
||||
|
||||
if err := mergeFields(prev, key, child.children); err != nil {
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
info.children[key] = child
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func buildAnonymousFieldInfo(info *fieldInfo, lowerCaseName string, ft reflect.Type) error {
|
||||
switch ft.Kind() {
|
||||
case reflect.Struct:
|
||||
fields, err := buildFieldsInfo(ft)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for k, v := range fields.children {
|
||||
if err = addOrMergeFields(info, k, v); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
case reflect.Map:
|
||||
elemField, err := buildFieldsInfo(mapping.Deref(ft.Elem()))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if _, ok := info.children[lowerCaseName]; ok {
|
||||
return newDupKeyError(lowerCaseName)
|
||||
}
|
||||
|
||||
info.children[lowerCaseName] = &fieldInfo{
|
||||
children: make(map[string]*fieldInfo),
|
||||
mapField: elemField,
|
||||
}
|
||||
default:
|
||||
if _, ok := info.children[lowerCaseName]; ok {
|
||||
return newDupKeyError(lowerCaseName)
|
||||
}
|
||||
|
||||
info.children[lowerCaseName] = &fieldInfo{
|
||||
children: make(map[string]*fieldInfo),
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func buildFieldsInfo(tp reflect.Type) (*fieldInfo, error) {
|
||||
tp = mapping.Deref(tp)
|
||||
|
||||
switch tp.Kind() {
|
||||
@@ -116,61 +188,95 @@ func buildFieldsInfo(tp reflect.Type) map[string]fieldInfo {
|
||||
return buildStructFieldsInfo(tp)
|
||||
case reflect.Array, reflect.Slice:
|
||||
return buildFieldsInfo(mapping.Deref(tp.Elem()))
|
||||
case reflect.Chan, reflect.Func:
|
||||
return nil, fmt.Errorf("unsupported type: %s", tp.Kind())
|
||||
default:
|
||||
return nil
|
||||
return &fieldInfo{
|
||||
children: make(map[string]*fieldInfo),
|
||||
}, nil
|
||||
}
|
||||
}
|
||||
|
||||
func buildStructFieldsInfo(tp reflect.Type) map[string]fieldInfo {
|
||||
info := make(map[string]fieldInfo)
|
||||
func buildNamedFieldInfo(info *fieldInfo, lowerCaseName string, ft reflect.Type) error {
|
||||
var finfo *fieldInfo
|
||||
var err error
|
||||
|
||||
switch ft.Kind() {
|
||||
case reflect.Struct:
|
||||
finfo, err = buildFieldsInfo(ft)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
case reflect.Array, reflect.Slice:
|
||||
finfo, err = buildFieldsInfo(ft.Elem())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
case reflect.Map:
|
||||
elemInfo, err := buildFieldsInfo(mapping.Deref(ft.Elem()))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
finfo = &fieldInfo{
|
||||
children: make(map[string]*fieldInfo),
|
||||
mapField: elemInfo,
|
||||
}
|
||||
default:
|
||||
finfo, err = buildFieldsInfo(ft)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return addOrMergeFields(info, lowerCaseName, finfo)
|
||||
}
|
||||
|
||||
func buildStructFieldsInfo(tp reflect.Type) (*fieldInfo, error) {
|
||||
info := &fieldInfo{
|
||||
children: make(map[string]*fieldInfo),
|
||||
}
|
||||
|
||||
for i := 0; i < tp.NumField(); i++ {
|
||||
field := tp.Field(i)
|
||||
name := field.Name
|
||||
lowerCaseName := toLowerCase(name)
|
||||
ft := mapping.Deref(field.Type)
|
||||
|
||||
// flatten anonymous fields
|
||||
if field.Anonymous {
|
||||
if ft.Kind() == reflect.Struct {
|
||||
fields := buildFieldsInfo(ft)
|
||||
for k, v := range fields {
|
||||
info[k] = v
|
||||
}
|
||||
} else {
|
||||
info[lowerCaseName] = fieldInfo{
|
||||
name: name,
|
||||
kind: ft.Kind(),
|
||||
}
|
||||
if err := buildAnonymousFieldInfo(info, lowerCaseName, ft); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
var fields map[string]fieldInfo
|
||||
switch ft.Kind() {
|
||||
case reflect.Struct:
|
||||
fields = buildFieldsInfo(ft)
|
||||
case reflect.Array, reflect.Slice:
|
||||
fields = buildFieldsInfo(ft.Elem())
|
||||
case reflect.Map:
|
||||
fields = buildFieldsInfo(ft.Elem())
|
||||
}
|
||||
|
||||
info[lowerCaseName] = fieldInfo{
|
||||
name: name,
|
||||
kind: ft.Kind(),
|
||||
children: fields,
|
||||
} else if err := buildNamedFieldInfo(info, lowerCaseName, ft); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
return info
|
||||
return info, nil
|
||||
}
|
||||
|
||||
func mergeFields(prev *fieldInfo, key string, children map[string]*fieldInfo) error {
|
||||
if len(prev.children) == 0 || len(children) == 0 {
|
||||
return newDupKeyError(key)
|
||||
}
|
||||
|
||||
// merge fields
|
||||
for k, v := range children {
|
||||
if _, ok := prev.children[k]; ok {
|
||||
return newDupKeyError(k)
|
||||
}
|
||||
|
||||
prev.children[k] = v
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func toLowerCase(s string) string {
|
||||
return strings.ToLower(s)
|
||||
}
|
||||
|
||||
func toLowerCaseInterface(v interface{}, info map[string]fieldInfo) interface{} {
|
||||
func toLowerCaseInterface(v interface{}, info *fieldInfo) interface{} {
|
||||
switch vv := v.(type) {
|
||||
case map[string]interface{}:
|
||||
return toLowerCaseKeyMap(vv, info)
|
||||
@@ -185,19 +291,21 @@ func toLowerCaseInterface(v interface{}, info map[string]fieldInfo) interface{}
|
||||
}
|
||||
}
|
||||
|
||||
func toLowerCaseKeyMap(m map[string]interface{}, info map[string]fieldInfo) map[string]interface{} {
|
||||
func toLowerCaseKeyMap(m map[string]interface{}, info *fieldInfo) map[string]interface{} {
|
||||
res := make(map[string]interface{})
|
||||
|
||||
for k, v := range m {
|
||||
ti, ok := info[k]
|
||||
ti, ok := info.children[k]
|
||||
if ok {
|
||||
res[k] = toLowerCaseInterface(v, ti.children)
|
||||
res[k] = toLowerCaseInterface(v, ti)
|
||||
continue
|
||||
}
|
||||
|
||||
lk := toLowerCase(k)
|
||||
if ti, ok = info[lk]; ok {
|
||||
res[lk] = toLowerCaseInterface(v, ti.children)
|
||||
if ti, ok = info.children[lk]; ok {
|
||||
res[lk] = toLowerCaseInterface(v, ti)
|
||||
} else if info.mapField != nil {
|
||||
res[k] = toLowerCaseInterface(v, info.mapField)
|
||||
} else {
|
||||
res[k] = v
|
||||
}
|
||||
@@ -205,3 +313,15 @@ func toLowerCaseKeyMap(m map[string]interface{}, info map[string]fieldInfo) map[
|
||||
|
||||
return res
|
||||
}
|
||||
|
||||
type dupKeyError struct {
|
||||
key string
|
||||
}
|
||||
|
||||
func newDupKeyError(key string) dupKeyError {
|
||||
return dupKeyError{key: key}
|
||||
}
|
||||
|
||||
func (e dupKeyError) Error() string {
|
||||
return fmt.Sprintf("duplicated key %s", e.key)
|
||||
}
|
||||
|
||||
@@ -9,6 +9,8 @@ import (
|
||||
"github.com/zeromicro/go-zero/core/hash"
|
||||
)
|
||||
|
||||
var dupErr dupKeyError
|
||||
|
||||
func TestLoadConfig_notExists(t *testing.T) {
|
||||
assert.NotNil(t, Load("not_a_file", nil))
|
||||
}
|
||||
@@ -17,7 +19,7 @@ func TestLoadConfig_notRecogFile(t *testing.T) {
|
||||
filename, err := fs.TempFilenameWithText("hello")
|
||||
assert.Nil(t, err)
|
||||
defer os.Remove(filename)
|
||||
assert.NotNil(t, Load(filename, nil))
|
||||
assert.NotNil(t, LoadConfig(filename, nil))
|
||||
}
|
||||
|
||||
func TestConfigJson(t *testing.T) {
|
||||
@@ -64,7 +66,7 @@ func TestLoadFromJsonBytesArray(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
assert.NoError(t, LoadFromJsonBytes(input, &val))
|
||||
assert.NoError(t, LoadConfigFromJsonBytes(input, &val))
|
||||
var expect []string
|
||||
for _, user := range val.Users {
|
||||
expect = append(expect, user.Name)
|
||||
@@ -172,7 +174,7 @@ B: bar`)
|
||||
A string
|
||||
B string
|
||||
}
|
||||
assert.NoError(t, LoadFromYamlBytes(text, &val1))
|
||||
assert.NoError(t, LoadConfigFromYamlBytes(text, &val1))
|
||||
assert.Equal(t, "foo", val1.A)
|
||||
assert.Equal(t, "bar", val1.B)
|
||||
assert.NoError(t, LoadFromYamlBytes(text, &val2))
|
||||
@@ -384,6 +386,102 @@ func TestLoadFromYamlBytesLayers(t *testing.T) {
|
||||
assert.Equal(t, "foo", val.Value)
|
||||
}
|
||||
|
||||
func TestLoadFromYamlItemOverlay(t *testing.T) {
|
||||
type (
|
||||
Redis struct {
|
||||
Host string
|
||||
Port int
|
||||
}
|
||||
|
||||
RedisKey struct {
|
||||
Redis
|
||||
Key string
|
||||
}
|
||||
|
||||
Server struct {
|
||||
Redis RedisKey
|
||||
}
|
||||
|
||||
TestConfig struct {
|
||||
Server
|
||||
Redis Redis
|
||||
}
|
||||
)
|
||||
|
||||
input := []byte(`Redis:
|
||||
Host: localhost
|
||||
Port: 6379
|
||||
Key: test
|
||||
`)
|
||||
|
||||
var c TestConfig
|
||||
assert.ErrorAs(t, LoadFromYamlBytes(input, &c), &dupErr)
|
||||
}
|
||||
|
||||
func TestLoadFromYamlItemOverlayReverse(t *testing.T) {
|
||||
type (
|
||||
Redis struct {
|
||||
Host string
|
||||
Port int
|
||||
}
|
||||
|
||||
RedisKey struct {
|
||||
Redis
|
||||
Key string
|
||||
}
|
||||
|
||||
Server struct {
|
||||
Redis Redis
|
||||
}
|
||||
|
||||
TestConfig struct {
|
||||
Redis RedisKey
|
||||
Server
|
||||
}
|
||||
)
|
||||
|
||||
input := []byte(`Redis:
|
||||
Host: localhost
|
||||
Port: 6379
|
||||
Key: test
|
||||
`)
|
||||
|
||||
var c TestConfig
|
||||
assert.ErrorAs(t, LoadFromYamlBytes(input, &c), &dupErr)
|
||||
}
|
||||
|
||||
func TestLoadFromYamlItemOverlayWithMap(t *testing.T) {
|
||||
type (
|
||||
Redis struct {
|
||||
Host string
|
||||
Port int
|
||||
}
|
||||
|
||||
RedisKey struct {
|
||||
Redis
|
||||
Key string
|
||||
}
|
||||
|
||||
Server struct {
|
||||
Redis RedisKey
|
||||
}
|
||||
|
||||
TestConfig struct {
|
||||
Server
|
||||
Redis map[string]interface{}
|
||||
}
|
||||
)
|
||||
|
||||
input := []byte(`Redis:
|
||||
Host: localhost
|
||||
Port: 6379
|
||||
Key: test
|
||||
`)
|
||||
|
||||
var c TestConfig
|
||||
assert.ErrorAs(t, LoadFromYamlBytes(input, &c), &dupErr)
|
||||
}
|
||||
|
||||
func TestUnmarshalJsonBytesMap(t *testing.T) {
|
||||
input := []byte(`{"foo":{"/mtproto.RPCTos": "bff.bff","bar":"baz"}}`)
|
||||
|
||||
@@ -450,6 +548,480 @@ func TestUnmarshalJsonBytesWithAnonymousField(t *testing.T) {
|
||||
assert.Equal(t, Int(3), c.Int)
|
||||
}
|
||||
|
||||
func TestUnmarshalJsonBytesWithMapValueOfStruct(t *testing.T) {
|
||||
type (
|
||||
Value struct {
|
||||
Name string
|
||||
}
|
||||
|
||||
Config struct {
|
||||
Items map[string]Value
|
||||
}
|
||||
)
|
||||
|
||||
var inputs = [][]byte{
|
||||
[]byte(`{"Items": {"Key":{"Name": "foo"}}}`),
|
||||
[]byte(`{"Items": {"Key":{"Name": "foo"}}}`),
|
||||
[]byte(`{"items": {"key":{"name": "foo"}}}`),
|
||||
[]byte(`{"items": {"key":{"name": "foo"}}}`),
|
||||
}
|
||||
for _, input := range inputs {
|
||||
var c Config
|
||||
if assert.NoError(t, LoadFromJsonBytes(input, &c)) {
|
||||
assert.Equal(t, 1, len(c.Items))
|
||||
for _, v := range c.Items {
|
||||
assert.Equal(t, "foo", v.Name)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestUnmarshalJsonBytesWithMapTypeValueOfStruct(t *testing.T) {
|
||||
type (
|
||||
Value struct {
|
||||
Name string
|
||||
}
|
||||
|
||||
Map map[string]Value
|
||||
|
||||
Config struct {
|
||||
Map
|
||||
}
|
||||
)
|
||||
|
||||
var inputs = [][]byte{
|
||||
[]byte(`{"Map": {"Key":{"Name": "foo"}}}`),
|
||||
[]byte(`{"Map": {"Key":{"Name": "foo"}}}`),
|
||||
[]byte(`{"map": {"key":{"name": "foo"}}}`),
|
||||
[]byte(`{"map": {"key":{"name": "foo"}}}`),
|
||||
}
|
||||
for _, input := range inputs {
|
||||
var c Config
|
||||
if assert.NoError(t, LoadFromJsonBytes(input, &c)) {
|
||||
assert.Equal(t, 1, len(c.Map))
|
||||
for _, v := range c.Map {
|
||||
assert.Equal(t, "foo", v.Name)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func Test_FieldOverwrite(t *testing.T) {
|
||||
t.Run("normal", func(t *testing.T) {
|
||||
type Base struct {
|
||||
Name string
|
||||
}
|
||||
|
||||
type St1 struct {
|
||||
Base
|
||||
Name2 string
|
||||
}
|
||||
|
||||
type St2 struct {
|
||||
Base
|
||||
Name2 string
|
||||
}
|
||||
|
||||
type St3 struct {
|
||||
*Base
|
||||
Name2 string
|
||||
}
|
||||
|
||||
type St4 struct {
|
||||
*Base
|
||||
Name2 *string
|
||||
}
|
||||
|
||||
validate := func(val interface{}) {
|
||||
input := []byte(`{"Name": "hello", "Name2": "world"}`)
|
||||
assert.NoError(t, LoadFromJsonBytes(input, val))
|
||||
}
|
||||
|
||||
validate(&St1{})
|
||||
validate(&St2{})
|
||||
validate(&St3{})
|
||||
validate(&St4{})
|
||||
})
|
||||
|
||||
t.Run("Inherit Override", func(t *testing.T) {
|
||||
type Base struct {
|
||||
Name string
|
||||
}
|
||||
|
||||
type St1 struct {
|
||||
Base
|
||||
Name string
|
||||
}
|
||||
|
||||
type St2 struct {
|
||||
Base
|
||||
Name int
|
||||
}
|
||||
|
||||
type St3 struct {
|
||||
*Base
|
||||
Name int
|
||||
}
|
||||
|
||||
type St4 struct {
|
||||
*Base
|
||||
Name *string
|
||||
}
|
||||
|
||||
validate := func(val interface{}) {
|
||||
input := []byte(`{"Name": "hello"}`)
|
||||
err := LoadFromJsonBytes(input, val)
|
||||
assert.ErrorAs(t, err, &dupErr)
|
||||
assert.Equal(t, newDupKeyError("name").Error(), err.Error())
|
||||
}
|
||||
|
||||
validate(&St1{})
|
||||
validate(&St2{})
|
||||
validate(&St3{})
|
||||
validate(&St4{})
|
||||
})
|
||||
|
||||
t.Run("Inherit more", func(t *testing.T) {
|
||||
type Base1 struct {
|
||||
Name string
|
||||
}
|
||||
|
||||
type St0 struct {
|
||||
Base1
|
||||
Name string
|
||||
}
|
||||
|
||||
type St1 struct {
|
||||
St0
|
||||
Name string
|
||||
}
|
||||
|
||||
type St2 struct {
|
||||
St0
|
||||
Name int
|
||||
}
|
||||
|
||||
type St3 struct {
|
||||
*St0
|
||||
Name int
|
||||
}
|
||||
|
||||
type St4 struct {
|
||||
*St0
|
||||
Name *int
|
||||
}
|
||||
|
||||
validate := func(val interface{}) {
|
||||
input := []byte(`{"Name": "hello"}`)
|
||||
err := LoadFromJsonBytes(input, val)
|
||||
assert.ErrorAs(t, err, &dupErr)
|
||||
assert.Equal(t, newDupKeyError("name").Error(), err.Error())
|
||||
}
|
||||
|
||||
validate(&St0{})
|
||||
validate(&St1{})
|
||||
validate(&St2{})
|
||||
validate(&St3{})
|
||||
validate(&St4{})
|
||||
})
|
||||
}
|
||||
|
||||
func TestFieldOverwriteComplicated(t *testing.T) {
|
||||
t.Run("double maps", func(t *testing.T) {
|
||||
type (
|
||||
Base1 struct {
|
||||
Values map[string]string
|
||||
}
|
||||
Base2 struct {
|
||||
Values map[string]string
|
||||
}
|
||||
Config struct {
|
||||
Base1
|
||||
Base2
|
||||
}
|
||||
)
|
||||
|
||||
var c Config
|
||||
input := []byte(`{"Values": {"Key": "Value"}}`)
|
||||
assert.ErrorAs(t, LoadFromJsonBytes(input, &c), &dupErr)
|
||||
})
|
||||
|
||||
t.Run("merge children", func(t *testing.T) {
|
||||
type (
|
||||
Inner1 struct {
|
||||
Name string
|
||||
}
|
||||
Inner2 struct {
|
||||
Age int
|
||||
}
|
||||
Base1 struct {
|
||||
Inner Inner1
|
||||
}
|
||||
Base2 struct {
|
||||
Inner Inner2
|
||||
}
|
||||
Config struct {
|
||||
Base1
|
||||
Base2
|
||||
}
|
||||
)
|
||||
|
||||
var c Config
|
||||
input := []byte(`{"Inner": {"Name": "foo", "Age": 10}}`)
|
||||
if assert.NoError(t, LoadFromJsonBytes(input, &c)) {
|
||||
assert.Equal(t, "foo", c.Base1.Inner.Name)
|
||||
assert.Equal(t, 10, c.Base2.Inner.Age)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("overwritten maps", func(t *testing.T) {
|
||||
type (
|
||||
Inner struct {
|
||||
Map map[string]string
|
||||
}
|
||||
Config struct {
|
||||
Map map[string]string
|
||||
Inner
|
||||
}
|
||||
)
|
||||
|
||||
var c Config
|
||||
input := []byte(`{"Inner": {"Map": {"Key": "Value"}}}`)
|
||||
assert.ErrorAs(t, LoadFromJsonBytes(input, &c), &dupErr)
|
||||
})
|
||||
|
||||
t.Run("overwritten nested maps", func(t *testing.T) {
|
||||
type (
|
||||
Inner struct {
|
||||
Map map[string]string
|
||||
}
|
||||
Middle1 struct {
|
||||
Map map[string]string
|
||||
Inner
|
||||
}
|
||||
Middle2 struct {
|
||||
Map map[string]string
|
||||
Inner
|
||||
}
|
||||
Config struct {
|
||||
Middle1
|
||||
Middle2
|
||||
}
|
||||
)
|
||||
|
||||
var c Config
|
||||
input := []byte(`{"Middle1": {"Inner": {"Map": {"Key": "Value"}}}}`)
|
||||
assert.ErrorAs(t, LoadFromJsonBytes(input, &c), &dupErr)
|
||||
})
|
||||
|
||||
t.Run("overwritten outer/inner maps", func(t *testing.T) {
|
||||
type (
|
||||
Inner struct {
|
||||
Map map[string]string
|
||||
}
|
||||
Middle struct {
|
||||
Inner
|
||||
Map map[string]string
|
||||
}
|
||||
Config struct {
|
||||
Middle
|
||||
}
|
||||
)
|
||||
|
||||
var c Config
|
||||
input := []byte(`{"Middle": {"Inner": {"Map": {"Key": "Value"}}}}`)
|
||||
assert.ErrorAs(t, LoadFromJsonBytes(input, &c), &dupErr)
|
||||
})
|
||||
|
||||
t.Run("overwritten anonymous maps", func(t *testing.T) {
|
||||
type (
|
||||
Inner struct {
|
||||
Map map[string]string
|
||||
}
|
||||
Middle struct {
|
||||
Inner
|
||||
Map map[string]string
|
||||
}
|
||||
Elem map[string]Middle
|
||||
Config struct {
|
||||
Elem
|
||||
}
|
||||
)
|
||||
|
||||
var c Config
|
||||
input := []byte(`{"Elem": {"Key": {"Inner": {"Map": {"Key": "Value"}}}}}`)
|
||||
assert.ErrorAs(t, LoadFromJsonBytes(input, &c), &dupErr)
|
||||
})
|
||||
|
||||
t.Run("overwritten primitive and map", func(t *testing.T) {
|
||||
type (
|
||||
Inner struct {
|
||||
Value string
|
||||
}
|
||||
Elem map[string]Inner
|
||||
Named struct {
|
||||
Elem string
|
||||
}
|
||||
Config struct {
|
||||
Named
|
||||
Elem
|
||||
}
|
||||
)
|
||||
|
||||
var c Config
|
||||
input := []byte(`{"Elem": {"Key": {"Value": "Value"}}}`)
|
||||
assert.ErrorAs(t, LoadFromJsonBytes(input, &c), &dupErr)
|
||||
})
|
||||
|
||||
t.Run("overwritten map and slice", func(t *testing.T) {
|
||||
type (
|
||||
Inner struct {
|
||||
Value string
|
||||
}
|
||||
Elem []Inner
|
||||
Named struct {
|
||||
Elem string
|
||||
}
|
||||
Config struct {
|
||||
Named
|
||||
Elem
|
||||
}
|
||||
)
|
||||
|
||||
var c Config
|
||||
input := []byte(`{"Elem": {"Key": {"Value": "Value"}}}`)
|
||||
assert.ErrorAs(t, LoadFromJsonBytes(input, &c), &dupErr)
|
||||
})
|
||||
|
||||
t.Run("overwritten map and string", func(t *testing.T) {
|
||||
type (
|
||||
Elem string
|
||||
Named struct {
|
||||
Elem string
|
||||
}
|
||||
Config struct {
|
||||
Named
|
||||
Elem
|
||||
}
|
||||
)
|
||||
|
||||
var c Config
|
||||
input := []byte(`{"Elem": {"Key": {"Value": "Value"}}}`)
|
||||
assert.ErrorAs(t, LoadFromJsonBytes(input, &c), &dupErr)
|
||||
})
|
||||
}
|
||||
|
||||
func TestLoadNamedFieldOverwritten(t *testing.T) {
|
||||
t.Run("overwritten named struct", func(t *testing.T) {
|
||||
type (
|
||||
Elem string
|
||||
Named struct {
|
||||
Elem string
|
||||
}
|
||||
Base struct {
|
||||
Named
|
||||
Elem
|
||||
}
|
||||
Config struct {
|
||||
Val Base
|
||||
}
|
||||
)
|
||||
|
||||
var c Config
|
||||
input := []byte(`{"Val": {"Elem": {"Key": {"Value": "Value"}}}}`)
|
||||
assert.ErrorAs(t, LoadFromJsonBytes(input, &c), &dupErr)
|
||||
})
|
||||
|
||||
t.Run("overwritten named []struct", func(t *testing.T) {
|
||||
type (
|
||||
Elem string
|
||||
Named struct {
|
||||
Elem string
|
||||
}
|
||||
Base struct {
|
||||
Named
|
||||
Elem
|
||||
}
|
||||
Config struct {
|
||||
Vals []Base
|
||||
}
|
||||
)
|
||||
|
||||
var c Config
|
||||
input := []byte(`{"Vals": [{"Elem": {"Key": {"Value": "Value"}}}]}`)
|
||||
assert.ErrorAs(t, LoadFromJsonBytes(input, &c), &dupErr)
|
||||
})
|
||||
|
||||
t.Run("overwritten named map[string]struct", func(t *testing.T) {
|
||||
type (
|
||||
Elem string
|
||||
Named struct {
|
||||
Elem string
|
||||
}
|
||||
Base struct {
|
||||
Named
|
||||
Elem
|
||||
}
|
||||
Config struct {
|
||||
Vals map[string]Base
|
||||
}
|
||||
)
|
||||
|
||||
var c Config
|
||||
input := []byte(`{"Vals": {"Key": {"Elem": {"Key": {"Value": "Value"}}}}}`)
|
||||
assert.ErrorAs(t, LoadFromJsonBytes(input, &c), &dupErr)
|
||||
})
|
||||
|
||||
t.Run("overwritten named *struct", func(t *testing.T) {
|
||||
type (
|
||||
Elem string
|
||||
Named struct {
|
||||
Elem string
|
||||
}
|
||||
Base struct {
|
||||
Named
|
||||
Elem
|
||||
}
|
||||
Config struct {
|
||||
Vals *Base
|
||||
}
|
||||
)
|
||||
|
||||
var c Config
|
||||
input := []byte(`{"Vals": [{"Elem": {"Key": {"Value": "Value"}}}]}`)
|
||||
assert.ErrorAs(t, LoadFromJsonBytes(input, &c), &dupErr)
|
||||
})
|
||||
|
||||
t.Run("overwritten named struct", func(t *testing.T) {
|
||||
type (
|
||||
Named struct {
|
||||
Elem string
|
||||
}
|
||||
Base struct {
|
||||
Named
|
||||
Elem Named
|
||||
}
|
||||
Config struct {
|
||||
Val Base
|
||||
}
|
||||
)
|
||||
|
||||
var c Config
|
||||
input := []byte(`{"Val": {"Elem": "Value"}}`)
|
||||
assert.ErrorAs(t, LoadFromJsonBytes(input, &c), &dupErr)
|
||||
})
|
||||
|
||||
t.Run("overwritten named struct", func(t *testing.T) {
|
||||
type Config struct {
|
||||
Val chan int
|
||||
}
|
||||
|
||||
var c Config
|
||||
input := []byte(`{"Val": 1}`)
|
||||
assert.Error(t, LoadFromJsonBytes(input, &c))
|
||||
})
|
||||
}
|
||||
|
||||
func createTempFile(ext, text string) (string, error) {
|
||||
tmpfile, err := os.CreateTemp(os.TempDir(), hash.Md5Hex([]byte(text))+"*"+ext)
|
||||
if err != nil {
|
||||
@@ -467,3 +1039,55 @@ func createTempFile(ext, text string) (string, error) {
|
||||
|
||||
return filename, nil
|
||||
}
|
||||
|
||||
func TestFillDefaultUnmarshal(t *testing.T) {
|
||||
t.Run("nil", func(t *testing.T) {
|
||||
type St struct{}
|
||||
err := FillDefault(St{})
|
||||
assert.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("not nil", func(t *testing.T) {
|
||||
type St struct{}
|
||||
err := FillDefault(&St{})
|
||||
assert.NoError(t, err)
|
||||
})
|
||||
|
||||
t.Run("default", func(t *testing.T) {
|
||||
type St struct {
|
||||
A string `json:",default=a"`
|
||||
B string
|
||||
}
|
||||
var st St
|
||||
err := FillDefault(&st)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, st.A, "a")
|
||||
})
|
||||
|
||||
t.Run("env", func(t *testing.T) {
|
||||
type St struct {
|
||||
A string `json:",default=a"`
|
||||
B string
|
||||
C string `json:",env=TEST_C"`
|
||||
}
|
||||
t.Setenv("TEST_C", "c")
|
||||
|
||||
var st St
|
||||
err := FillDefault(&st)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, st.A, "a")
|
||||
assert.Equal(t, st.C, "c")
|
||||
})
|
||||
|
||||
t.Run("has vaue", func(t *testing.T) {
|
||||
type St struct {
|
||||
A string `json:",default=a"`
|
||||
B string
|
||||
}
|
||||
var st = St{
|
||||
A: "b",
|
||||
}
|
||||
err := FillDefault(&st)
|
||||
assert.Error(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
@@ -12,7 +12,6 @@ import (
|
||||
|
||||
// PropertyError represents a configuration error message.
|
||||
type PropertyError struct {
|
||||
error
|
||||
message string
|
||||
}
|
||||
|
||||
|
||||
@@ -4,6 +4,7 @@
|
||||
|
||||
```go
|
||||
type RestfulConf struct {
|
||||
ServiceName string `json:",env=SERVICE_NAME"` // read from env automatically
|
||||
Host string `json:",default=0.0.0.0"`
|
||||
Port int
|
||||
LogMode string `json:",options=[file,console]"`
|
||||
@@ -21,20 +22,20 @@ type RestfulConf struct {
|
||||
|
||||
```yaml
|
||||
# most fields are optional or have default values
|
||||
Port: 8080
|
||||
LogMode: console
|
||||
port: 8080
|
||||
logMode: console
|
||||
# you can use env settings
|
||||
MaxBytes: ${MAX_BYTES}
|
||||
maxBytes: ${MAX_BYTES}
|
||||
```
|
||||
|
||||
- toml example
|
||||
|
||||
```toml
|
||||
# most fields are optional or have default values
|
||||
Port = 8_080
|
||||
LogMode = "console"
|
||||
port = 8_080
|
||||
logMode = "console"
|
||||
# you can use env settings
|
||||
MaxBytes = "${MAX_BYTES}"
|
||||
maxBytes = "${MAX_BYTES}"
|
||||
```
|
||||
|
||||
3. Load the config from a file:
|
||||
|
||||
@@ -53,10 +53,11 @@ func TestChunkExecutorFlushInterval(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestChunkExecutorEmpty(t *testing.T) {
|
||||
NewChunkExecutor(func(items []interface{}) {
|
||||
executor := NewChunkExecutor(func(items []interface{}) {
|
||||
assert.Fail(t, "should not called")
|
||||
}, WithChunkBytes(10), WithFlushInterval(time.Millisecond))
|
||||
time.Sleep(time.Millisecond * 100)
|
||||
executor.Wait()
|
||||
}
|
||||
|
||||
func TestChunkExecutorFlush(t *testing.T) {
|
||||
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/zeromicro/go-zero/core/proc"
|
||||
"github.com/zeromicro/go-zero/core/timex"
|
||||
)
|
||||
|
||||
@@ -67,6 +68,7 @@ func TestPeriodicalExecutor_QuitGoroutine(t *testing.T) {
|
||||
ticker.Tick()
|
||||
ticker.Wait(time.Millisecond * idleRound)
|
||||
assert.Equal(t, routines, runtime.NumGoroutine())
|
||||
proc.Shutdown()
|
||||
}
|
||||
|
||||
func TestPeriodicalExecutor_Bulk(t *testing.T) {
|
||||
|
||||
@@ -27,6 +27,26 @@ func Close() error {
|
||||
return logx.Close()
|
||||
}
|
||||
|
||||
// Debug writes v into access log.
|
||||
func Debug(ctx context.Context, v ...interface{}) {
|
||||
getLogger(ctx).Debug(v...)
|
||||
}
|
||||
|
||||
// Debugf writes v with format into access log.
|
||||
func Debugf(ctx context.Context, format string, v ...interface{}) {
|
||||
getLogger(ctx).Debugf(format, v...)
|
||||
}
|
||||
|
||||
// Debugv writes v into access log with json content.
|
||||
func Debugv(ctx context.Context, v interface{}) {
|
||||
getLogger(ctx).Debugv(v)
|
||||
}
|
||||
|
||||
// Debugw writes msg along with fields into access log.
|
||||
func Debugw(ctx context.Context, msg string, fields ...LogField) {
|
||||
getLogger(ctx).Debugw(msg, fields...)
|
||||
}
|
||||
|
||||
// Error writes v into error log.
|
||||
func Error(ctx context.Context, v ...interface{}) {
|
||||
getLogger(ctx).Error(v...)
|
||||
|
||||
@@ -140,6 +140,54 @@ func TestInfow(t *testing.T) {
|
||||
assert.True(t, strings.Contains(buf.String(), fmt.Sprintf("%s:%d", file, line+1)))
|
||||
}
|
||||
|
||||
func TestDebug(t *testing.T) {
|
||||
var buf strings.Builder
|
||||
writer := logx.NewWriter(&buf)
|
||||
old := logx.Reset()
|
||||
logx.SetWriter(writer)
|
||||
defer logx.SetWriter(old)
|
||||
|
||||
file, line := getFileLine()
|
||||
Debug(context.Background(), "foo")
|
||||
assert.True(t, strings.Contains(buf.String(), fmt.Sprintf("%s:%d", file, line+1)))
|
||||
}
|
||||
|
||||
func TestDebugf(t *testing.T) {
|
||||
var buf strings.Builder
|
||||
writer := logx.NewWriter(&buf)
|
||||
old := logx.Reset()
|
||||
logx.SetWriter(writer)
|
||||
defer logx.SetWriter(old)
|
||||
|
||||
file, line := getFileLine()
|
||||
Debugf(context.Background(), "foo %s", "bar")
|
||||
assert.True(t, strings.Contains(buf.String(), fmt.Sprintf("%s:%d", file, line+1)))
|
||||
}
|
||||
|
||||
func TestDebugv(t *testing.T) {
|
||||
var buf strings.Builder
|
||||
writer := logx.NewWriter(&buf)
|
||||
old := logx.Reset()
|
||||
logx.SetWriter(writer)
|
||||
defer logx.SetWriter(old)
|
||||
|
||||
file, line := getFileLine()
|
||||
Debugv(context.Background(), "foo")
|
||||
assert.True(t, strings.Contains(buf.String(), fmt.Sprintf("%s:%d", file, line+1)))
|
||||
}
|
||||
|
||||
func TestDebugw(t *testing.T) {
|
||||
var buf strings.Builder
|
||||
writer := logx.NewWriter(&buf)
|
||||
old := logx.Reset()
|
||||
logx.SetWriter(writer)
|
||||
defer logx.SetWriter(old)
|
||||
|
||||
file, line := getFileLine()
|
||||
Debugw(context.Background(), "foo", Field("a", "b"))
|
||||
assert.True(t, strings.Contains(buf.String(), fmt.Sprintf("%s:%d", file, line+1)))
|
||||
}
|
||||
|
||||
func TestMust(t *testing.T) {
|
||||
assert.NotPanics(t, func() {
|
||||
Must(nil)
|
||||
|
||||
@@ -2,17 +2,34 @@ package logx
|
||||
|
||||
// A LogConf is a logging config.
|
||||
type LogConf struct {
|
||||
ServiceName string `json:",optional"`
|
||||
Mode string `json:",default=console,options=[console,file,volume]"`
|
||||
Encoding string `json:",default=json,options=[json,plain]"`
|
||||
TimeFormat string `json:",optional"`
|
||||
Path string `json:",default=logs"`
|
||||
Level string `json:",default=info,options=[debug,info,error,severe]"`
|
||||
MaxContentLength uint32 `json:",optional"`
|
||||
Compress bool `json:",optional"`
|
||||
Stat bool `json:",default=true"`
|
||||
KeepDays int `json:",optional"`
|
||||
StackCooldownMillis int `json:",default=100"`
|
||||
// ServiceName represents the service name.
|
||||
ServiceName string `json:",optional"`
|
||||
// Mode represents the logging mode, default is `console`.
|
||||
// console: log to console.
|
||||
// file: log to file.
|
||||
// volume: used in k8s, prepend the hostname to the log file name.
|
||||
Mode string `json:",default=console,options=[console,file,volume]"`
|
||||
// Encoding represents the encoding type, default is `json`.
|
||||
// json: json encoding.
|
||||
// plain: plain text encoding, typically used in development.
|
||||
Encoding string `json:",default=json,options=[json,plain]"`
|
||||
// TimeFormat represents the time format, default is `2006-01-02T15:04:05.000Z07:00`.
|
||||
TimeFormat string `json:",optional"`
|
||||
// Path represents the log file path, default is `logs`.
|
||||
Path string `json:",default=logs"`
|
||||
// Level represents the log level, default is `info`.
|
||||
Level string `json:",default=info,options=[debug,info,error,severe]"`
|
||||
// MaxContentLength represents the max content bytes, default is no limit.
|
||||
MaxContentLength uint32 `json:",optional"`
|
||||
// Compress represents whether to compress the log file, default is `false`.
|
||||
Compress bool `json:",optional"`
|
||||
// Stdout represents whether to log statistics, default is `true`.
|
||||
Stat bool `json:",default=true"`
|
||||
// KeepDays represents how many days the log files will be kept. Default to keep all files.
|
||||
// Only take effect when Mode is `file` or `volume`, both work when Rotation is `daily` or `size`.
|
||||
KeepDays int `json:",optional"`
|
||||
// StackCooldownMillis represents the cooldown time for stack logging, default is 100ms.
|
||||
StackCooldownMillis int `json:",default=100"`
|
||||
// MaxBackups represents how many backup log files will be kept. 0 means all files will be kept forever.
|
||||
// Only take effect when RotationRuleType is `size`.
|
||||
// Even thougth `MaxBackups` sets 0, log files will still be removed
|
||||
|
||||
@@ -108,7 +108,7 @@ func TestNopWriter(t *testing.T) {
|
||||
w.Stack("foo")
|
||||
w.Stat("foo")
|
||||
w.Slow("foo")
|
||||
w.Close()
|
||||
_ = w.Close()
|
||||
})
|
||||
}
|
||||
|
||||
|
||||
@@ -47,6 +47,7 @@ type (
|
||||
UnmarshalOption func(*unmarshalOptions)
|
||||
|
||||
unmarshalOptions struct {
|
||||
fillDefault bool
|
||||
fromString bool
|
||||
canonicalKey func(key string) string
|
||||
}
|
||||
@@ -710,7 +711,14 @@ func (u *Unmarshaler) processNamedField(field reflect.StructField, value reflect
|
||||
|
||||
valuer := createValuer(m, opts)
|
||||
mapValue, hasValue := getValue(valuer, canonicalKey)
|
||||
if !hasValue {
|
||||
|
||||
// When fillDefault is used, m is a null value, hasValue must be false, all priority judgments fillDefault,
|
||||
if u.opts.fillDefault {
|
||||
if !value.IsZero() {
|
||||
return fmt.Errorf("set the default value, %s must be zero", fullName)
|
||||
}
|
||||
return u.processNamedFieldWithoutValue(field.Type, value, opts, fullName)
|
||||
} else if !hasValue {
|
||||
return u.processNamedFieldWithoutValue(field.Type, value, opts, fullName)
|
||||
}
|
||||
|
||||
@@ -801,6 +809,10 @@ func (u *Unmarshaler) processNamedFieldWithoutValue(fieldType reflect.Type, valu
|
||||
}
|
||||
}
|
||||
|
||||
if u.opts.fillDefault {
|
||||
return nil
|
||||
}
|
||||
|
||||
switch fieldKind {
|
||||
case reflect.Array, reflect.Map, reflect.Slice:
|
||||
if !opts.optional() {
|
||||
@@ -853,7 +865,12 @@ func (u *Unmarshaler) unmarshalWithFullName(m valuerWithParent, v interface{}, f
|
||||
|
||||
numFields := baseType.NumField()
|
||||
for i := 0; i < numFields; i++ {
|
||||
if err := u.processField(baseType.Field(i), valElem.Field(i), m, fullName); err != nil {
|
||||
field := baseType.Field(i)
|
||||
if !field.IsExported() {
|
||||
continue
|
||||
}
|
||||
|
||||
if err := u.processField(field, valElem.Field(i), m, fullName); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
@@ -868,13 +885,20 @@ func WithStringValues() UnmarshalOption {
|
||||
}
|
||||
}
|
||||
|
||||
// WithCanonicalKeyFunc customizes an Unmarshaler with Canonical Key func
|
||||
// WithCanonicalKeyFunc customizes an Unmarshaler with Canonical Key func.
|
||||
func WithCanonicalKeyFunc(f func(string) string) UnmarshalOption {
|
||||
return func(opt *unmarshalOptions) {
|
||||
opt.canonicalKey = f
|
||||
}
|
||||
}
|
||||
|
||||
// WithDefault customizes an Unmarshaler with fill default values.
|
||||
func WithDefault() UnmarshalOption {
|
||||
return func(opt *unmarshalOptions) {
|
||||
opt.fillDefault = true
|
||||
}
|
||||
}
|
||||
|
||||
func createValuer(v valuerWithParent, opts *fieldOptionsWithContext) valuerWithParent {
|
||||
if opts.inherit() {
|
||||
return recursiveValuer{
|
||||
@@ -1004,7 +1028,7 @@ func newInitError(name string) error {
|
||||
}
|
||||
|
||||
func newTypeMismatchError(name string) error {
|
||||
return fmt.Errorf("error: type mismatch for field %s", name)
|
||||
return fmt.Errorf("type mismatch for field %s", name)
|
||||
}
|
||||
|
||||
func readKeys(key string) []string {
|
||||
|
||||
@@ -12,6 +12,7 @@ import (
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/zeromicro/go-zero/core/stringx"
|
||||
)
|
||||
|
||||
@@ -793,7 +794,9 @@ func TestUnmarshalStringMapFromNotSettableValue(t *testing.T) {
|
||||
}
|
||||
|
||||
ast := assert.New(t)
|
||||
ast.Error(UnmarshalKey(m, &v))
|
||||
ast.NoError(UnmarshalKey(m, &v))
|
||||
assert.Empty(t, v.sort)
|
||||
assert.Nil(t, v.psort)
|
||||
}
|
||||
|
||||
func TestUnmarshalStringMapFromString(t *testing.T) {
|
||||
@@ -4265,6 +4268,24 @@ func TestUnmarshalStructPtrOfPtr(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestUnmarshalOnlyPublicVariables(t *testing.T) {
|
||||
type demo struct {
|
||||
age int `key:"age"`
|
||||
Name string `key:"name"`
|
||||
}
|
||||
|
||||
m := map[string]interface{}{
|
||||
"age": 3,
|
||||
"name": "go-zero",
|
||||
}
|
||||
|
||||
var in demo
|
||||
if assert.NoError(t, UnmarshalKey(m, &in)) {
|
||||
assert.Equal(t, 0, in.age)
|
||||
assert.Equal(t, "go-zero", in.Name)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkDefaultValue(b *testing.B) {
|
||||
for i := 0; i < b.N; i++ {
|
||||
var a struct {
|
||||
@@ -4364,3 +4385,56 @@ func BenchmarkUnmarshal(b *testing.B) {
|
||||
UnmarshalKey(data, &an)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFillDefaultUnmarshal(t *testing.T) {
|
||||
fillDefaultUnmarshal := NewUnmarshaler(jsonTagKey, WithDefault())
|
||||
t.Run("nil", func(t *testing.T) {
|
||||
type St struct{}
|
||||
err := fillDefaultUnmarshal.Unmarshal(map[string]interface{}{}, St{})
|
||||
assert.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("not nil", func(t *testing.T) {
|
||||
type St struct{}
|
||||
err := fillDefaultUnmarshal.Unmarshal(map[string]interface{}{}, &St{})
|
||||
assert.NoError(t, err)
|
||||
})
|
||||
|
||||
t.Run("default", func(t *testing.T) {
|
||||
type St struct {
|
||||
A string `json:",default=a"`
|
||||
B string
|
||||
}
|
||||
var st St
|
||||
err := fillDefaultUnmarshal.Unmarshal(map[string]interface{}{}, &st)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, st.A, "a")
|
||||
})
|
||||
|
||||
t.Run("env", func(t *testing.T) {
|
||||
type St struct {
|
||||
A string `json:",default=a"`
|
||||
B string
|
||||
C string `json:",env=TEST_C"`
|
||||
}
|
||||
t.Setenv("TEST_C", "c")
|
||||
|
||||
var st St
|
||||
err := fillDefaultUnmarshal.Unmarshal(map[string]interface{}{}, &st)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, st.A, "a")
|
||||
assert.Equal(t, st.C, "c")
|
||||
})
|
||||
|
||||
t.Run("has value", func(t *testing.T) {
|
||||
type St struct {
|
||||
A string `json:",default=a"`
|
||||
B string
|
||||
}
|
||||
var st = St{
|
||||
A: "b",
|
||||
}
|
||||
err := fillDefaultUnmarshal.Unmarshal(map[string]interface{}{}, &st)
|
||||
assert.Error(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
|
||||
"github.com/prometheus/client_golang/prometheus/testutil"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/zeromicro/go-zero/core/proc"
|
||||
"github.com/zeromicro/go-zero/core/prometheus"
|
||||
)
|
||||
|
||||
@@ -17,6 +18,9 @@ func TestNewCounterVec(t *testing.T) {
|
||||
})
|
||||
defer counterVec.close()
|
||||
counterVecNil := NewCounterVec(nil)
|
||||
counterVec.Inc("path", "code")
|
||||
counterVec.Add(1, "path", "code")
|
||||
proc.Shutdown()
|
||||
assert.NotNil(t, counterVec)
|
||||
assert.Nil(t, counterVecNil)
|
||||
}
|
||||
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
|
||||
"github.com/prometheus/client_golang/prometheus/testutil"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/zeromicro/go-zero/core/proc"
|
||||
)
|
||||
|
||||
func TestNewGaugeVec(t *testing.T) {
|
||||
@@ -18,6 +19,8 @@ func TestNewGaugeVec(t *testing.T) {
|
||||
gaugeVecNil := NewGaugeVec(nil)
|
||||
assert.NotNil(t, gaugeVec)
|
||||
assert.Nil(t, gaugeVecNil)
|
||||
|
||||
proc.Shutdown()
|
||||
}
|
||||
|
||||
func TestGaugeInc(t *testing.T) {
|
||||
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
|
||||
"github.com/prometheus/client_golang/prometheus/testutil"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/zeromicro/go-zero/core/proc"
|
||||
)
|
||||
|
||||
func TestNewHistogramVec(t *testing.T) {
|
||||
@@ -47,4 +48,6 @@ func TestHistogramObserve(t *testing.T) {
|
||||
|
||||
err := testutil.CollectAndCompare(hv.histogram, strings.NewReader(metadata+val))
|
||||
assert.Nil(t, err)
|
||||
|
||||
proc.Shutdown()
|
||||
}
|
||||
|
||||
@@ -15,5 +15,14 @@ func AddWrapUpListener(fn func()) func() {
|
||||
return fn
|
||||
}
|
||||
|
||||
// SetTimeToForceQuit does nothing on windows.
|
||||
func SetTimeToForceQuit(duration time.Duration) {
|
||||
}
|
||||
|
||||
// Shutdown does nothing on windows.
|
||||
func Shutdown() {
|
||||
}
|
||||
|
||||
// WrapUp does nothing on windows.
|
||||
func WrapUp() {
|
||||
}
|
||||
|
||||
@@ -43,6 +43,16 @@ func SetTimeToForceQuit(duration time.Duration) {
|
||||
delayTimeBeforeForceQuit = duration
|
||||
}
|
||||
|
||||
// Shutdown calls the registered shutdown listeners, only for test purpose.
|
||||
func Shutdown() {
|
||||
shutdownListeners.notifyListeners()
|
||||
}
|
||||
|
||||
// WrapUp wraps up the process, only for test purpose.
|
||||
func WrapUp() {
|
||||
wrapUpListeners.notifyListeners()
|
||||
}
|
||||
|
||||
func gracefulStop(signals chan os.Signal) {
|
||||
signal.Stop(signals)
|
||||
|
||||
|
||||
@@ -18,14 +18,14 @@ func TestShutdown(t *testing.T) {
|
||||
called := AddWrapUpListener(func() {
|
||||
val++
|
||||
})
|
||||
wrapUpListeners.notifyListeners()
|
||||
WrapUp()
|
||||
called()
|
||||
assert.Equal(t, 1, val)
|
||||
|
||||
called = AddShutdownListener(func() {
|
||||
val += 2
|
||||
})
|
||||
shutdownListeners.notifyListeners()
|
||||
Shutdown()
|
||||
called()
|
||||
assert.Equal(t, 3, val)
|
||||
}
|
||||
|
||||
@@ -3,6 +3,7 @@ package service
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/zeromicro/go-zero/core/logx"
|
||||
)
|
||||
|
||||
@@ -16,3 +17,15 @@ func TestServiceConf(t *testing.T) {
|
||||
}
|
||||
c.MustSetUp()
|
||||
}
|
||||
|
||||
func TestServiceConfWithMetricsUrl(t *testing.T) {
|
||||
c := ServiceConf{
|
||||
Name: "foo",
|
||||
Log: logx.LogConf{
|
||||
Mode: "volume",
|
||||
},
|
||||
Mode: "dev",
|
||||
MetricsUrl: "http://localhost:8080",
|
||||
}
|
||||
assert.NoError(t, c.SetUp())
|
||||
}
|
||||
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/zeromicro/go-zero/core/proc"
|
||||
)
|
||||
|
||||
var (
|
||||
@@ -55,6 +56,7 @@ func TestServiceGroup(t *testing.T) {
|
||||
}
|
||||
|
||||
group.Stop()
|
||||
proc.Shutdown()
|
||||
|
||||
mutex.Lock()
|
||||
defer mutex.Unlock()
|
||||
|
||||
5
core/stores/cache/cache.go
vendored
5
core/stores/cache/cache.go
vendored
@@ -9,6 +9,7 @@ import (
|
||||
|
||||
"github.com/zeromicro/go-zero/core/errorx"
|
||||
"github.com/zeromicro/go-zero/core/hash"
|
||||
"github.com/zeromicro/go-zero/core/stores/redis"
|
||||
"github.com/zeromicro/go-zero/core/syncx"
|
||||
)
|
||||
|
||||
@@ -62,12 +63,12 @@ func New(c ClusterConf, barrier syncx.SingleFlight, st *Stat, errNotFound error,
|
||||
}
|
||||
|
||||
if len(c) == 1 {
|
||||
return NewNode(c[0].NewRedis(), barrier, st, errNotFound, opts...)
|
||||
return NewNode(redis.MustNewRedis(c[0].RedisConf), barrier, st, errNotFound, opts...)
|
||||
}
|
||||
|
||||
dispatcher := hash.NewConsistentHash()
|
||||
for _, node := range c {
|
||||
cn := NewNode(node.NewRedis(), barrier, st, errNotFound, opts...)
|
||||
cn := NewNode(redis.MustNewRedis(node.RedisConf), barrier, st, errNotFound, opts...)
|
||||
dispatcher.AddWithWeight(cn, node.Weight)
|
||||
}
|
||||
|
||||
|
||||
4
core/stores/cache/cache_test.go
vendored
4
core/stores/cache/cache_test.go
vendored
@@ -163,12 +163,10 @@ func TestCache_SetDel(t *testing.T) {
|
||||
r1, err := miniredis.Run()
|
||||
assert.NoError(t, err)
|
||||
defer r1.Close()
|
||||
r1.SetError("mock error")
|
||||
|
||||
r2, err := miniredis.Run()
|
||||
assert.NoError(t, err)
|
||||
defer r2.Close()
|
||||
r2.SetError("mock error")
|
||||
|
||||
conf := ClusterConf{
|
||||
{
|
||||
@@ -187,6 +185,8 @@ func TestCache_SetDel(t *testing.T) {
|
||||
},
|
||||
}
|
||||
c := New(conf, syncx.NewSingleFlight(), NewStat("mock"), errPlaceholder)
|
||||
r1.SetError("mock error")
|
||||
r2.SetError("mock error")
|
||||
assert.NoError(t, c.Del("a", "b", "c"))
|
||||
})
|
||||
}
|
||||
|
||||
3
core/stores/cache/cachenode.go
vendored
3
core/stores/cache/cachenode.go
vendored
@@ -277,5 +277,6 @@ func (c cacheNode) processCache(ctx context.Context, key, data string, v interfa
|
||||
|
||||
func (c cacheNode) setCacheWithNotFound(ctx context.Context, key string) error {
|
||||
seconds := int(math.Ceil(c.aroundDuration(c.notFoundExpiry).Seconds()))
|
||||
return c.rds.SetexCtx(ctx, key, notFoundPlaceholder, seconds)
|
||||
_, err := c.rds.SetnxExCtx(ctx, key, notFoundPlaceholder, seconds)
|
||||
return err
|
||||
}
|
||||
|
||||
29
core/stores/cache/cachenode_test.go
vendored
29
core/stores/cache/cachenode_test.go
vendored
@@ -209,6 +209,35 @@ func TestCacheNode_TakeNotFound(t *testing.T) {
|
||||
assert.Equal(t, errDummy, err)
|
||||
}
|
||||
|
||||
func TestCacheNode_TakeNotFoundButChangedByOthers(t *testing.T) {
|
||||
store, clean, err := redistest.CreateRedis()
|
||||
assert.NoError(t, err)
|
||||
defer clean()
|
||||
|
||||
cn := cacheNode{
|
||||
rds: store,
|
||||
r: rand.New(rand.NewSource(time.Now().UnixNano())),
|
||||
barrier: syncx.NewSingleFlight(),
|
||||
lock: new(sync.Mutex),
|
||||
unstableExpiry: mathx.NewUnstable(expiryDeviation),
|
||||
stat: NewStat("any"),
|
||||
errNotFound: errTestNotFound,
|
||||
}
|
||||
|
||||
var str string
|
||||
err = cn.Take(&str, "any", func(v interface{}) error {
|
||||
store.Set("any", "foo")
|
||||
return errTestNotFound
|
||||
})
|
||||
assert.True(t, cn.IsNotFound(err))
|
||||
|
||||
val, err := store.Get("any")
|
||||
if assert.NoError(t, err) {
|
||||
assert.Equal(t, "foo", val)
|
||||
}
|
||||
assert.True(t, cn.IsNotFound(cn.Get("any", &str)))
|
||||
}
|
||||
|
||||
func TestCacheNode_TakeWithExpire(t *testing.T) {
|
||||
store, clean, err := redistest.CreateRedis()
|
||||
assert.Nil(t, err)
|
||||
|
||||
2
core/stores/cache/cleaner_test.go
vendored
2
core/stores/cache/cleaner_test.go
vendored
@@ -5,6 +5,7 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/zeromicro/go-zero/core/proc"
|
||||
)
|
||||
|
||||
func TestNextDelay(t *testing.T) {
|
||||
@@ -51,6 +52,7 @@ func TestNextDelay(t *testing.T) {
|
||||
next, ok := nextDelay(test.input)
|
||||
assert.Equal(t, test.ok, ok)
|
||||
assert.Equal(t, test.output, next)
|
||||
proc.Shutdown()
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -164,7 +164,7 @@ func NewStore(c KvConf) Store {
|
||||
// because Store and redis.Redis has different methods.
|
||||
dispatcher := hash.NewConsistentHash()
|
||||
for _, node := range c {
|
||||
cn := node.NewRedis()
|
||||
cn := redis.MustNewRedis(node.RedisConf)
|
||||
dispatcher.AddWithWeight(cn, node.Weight)
|
||||
}
|
||||
|
||||
|
||||
@@ -5,7 +5,6 @@ import (
|
||||
|
||||
"github.com/zeromicro/go-zero/core/trace"
|
||||
"go.mongodb.org/mongo-driver/mongo"
|
||||
"go.opentelemetry.io/otel"
|
||||
"go.opentelemetry.io/otel/attribute"
|
||||
"go.opentelemetry.io/otel/codes"
|
||||
oteltrace "go.opentelemetry.io/otel/trace"
|
||||
@@ -14,11 +13,8 @@ import (
|
||||
var mongoCmdAttributeKey = attribute.Key("mongo.cmd")
|
||||
|
||||
func startSpan(ctx context.Context, cmd string) (context.Context, oteltrace.Span) {
|
||||
tracer := otel.Tracer(trace.TraceName)
|
||||
ctx, span := tracer.Start(ctx,
|
||||
spanName,
|
||||
oteltrace.WithSpanKind(oteltrace.SpanKindClient),
|
||||
)
|
||||
tracer := trace.TracerFromContext(ctx)
|
||||
ctx, span := tracer.Start(ctx, spanName, oteltrace.WithSpanKind(oteltrace.SpanKindClient))
|
||||
span.SetAttributes(mongoCmdAttributeKey.String(cmd))
|
||||
|
||||
return ctx, span
|
||||
|
||||
@@ -9,6 +9,8 @@ var (
|
||||
ErrEmptyType = errors.New("empty redis type")
|
||||
// ErrEmptyKey is an error that indicates no redis key is set.
|
||||
ErrEmptyKey = errors.New("empty redis key")
|
||||
// ErrPing is an error that indicates ping failed.
|
||||
ErrPing = errors.New("ping redis failed")
|
||||
)
|
||||
|
||||
type (
|
||||
@@ -28,6 +30,7 @@ type (
|
||||
)
|
||||
|
||||
// NewRedis returns a Redis.
|
||||
// Deprecated: use MustNewRedis or NewRedis instead.
|
||||
func (rc RedisConf) NewRedis() *Redis {
|
||||
var opts []Option
|
||||
if rc.Type == ClusterType {
|
||||
|
||||
@@ -14,7 +14,6 @@ import (
|
||||
"github.com/zeromicro/go-zero/core/mapping"
|
||||
"github.com/zeromicro/go-zero/core/timex"
|
||||
"github.com/zeromicro/go-zero/core/trace"
|
||||
"go.opentelemetry.io/otel"
|
||||
"go.opentelemetry.io/otel/attribute"
|
||||
"go.opentelemetry.io/otel/codes"
|
||||
oteltrace "go.opentelemetry.io/otel/trace"
|
||||
@@ -25,15 +24,13 @@ const spanName = "redis"
|
||||
|
||||
var (
|
||||
startTimeKey = contextKey("startTime")
|
||||
durationHook = hook{tracer: otel.Tracer(trace.TraceName)}
|
||||
durationHook = hook{}
|
||||
redisCmdsAttributeKey = attribute.Key("redis.cmds")
|
||||
)
|
||||
|
||||
type (
|
||||
contextKey string
|
||||
hook struct {
|
||||
tracer oteltrace.Tracer
|
||||
}
|
||||
hook struct{}
|
||||
)
|
||||
|
||||
func (h hook) BeforeProcess(ctx context.Context, cmd red.Cmder) (context.Context, error) {
|
||||
@@ -155,7 +152,9 @@ func logDuration(ctx context.Context, cmds []red.Cmder, duration time.Duration)
|
||||
}
|
||||
|
||||
func (h hook) startSpan(ctx context.Context, cmds ...red.Cmder) context.Context {
|
||||
ctx, span := h.tracer.Start(ctx,
|
||||
tracer := trace.TracerFromContext(ctx)
|
||||
|
||||
ctx, span := tracer.Start(ctx,
|
||||
spanName,
|
||||
oteltrace.WithSpanKind(oteltrace.SpanKindClient),
|
||||
)
|
||||
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
@@ -86,7 +87,46 @@ type (
|
||||
)
|
||||
|
||||
// New returns a Redis with given options.
|
||||
// Deprecated: use MustNewRedis or NewRedis instead.
|
||||
func New(addr string, opts ...Option) *Redis {
|
||||
return newRedis(addr, opts...)
|
||||
}
|
||||
|
||||
// MustNewRedis returns a Redis with given options.
|
||||
func MustNewRedis(conf RedisConf, opts ...Option) *Redis {
|
||||
rds, err := NewRedis(conf, opts...)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
return rds
|
||||
}
|
||||
|
||||
// NewRedis returns a Redis with given options.
|
||||
func NewRedis(conf RedisConf, opts ...Option) (*Redis, error) {
|
||||
if err := conf.Validate(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if conf.Type == ClusterType {
|
||||
opts = append([]Option{Cluster()}, opts...)
|
||||
}
|
||||
if len(conf.Pass) > 0 {
|
||||
opts = append([]Option{WithPass(conf.Pass)}, opts...)
|
||||
}
|
||||
if conf.Tls {
|
||||
opts = append([]Option{WithTLS()}, opts...)
|
||||
}
|
||||
|
||||
rds := newRedis(conf.Host, opts...)
|
||||
if !rds.Ping() {
|
||||
return nil, ErrPing
|
||||
}
|
||||
|
||||
return rds, nil
|
||||
}
|
||||
|
||||
func newRedis(addr string, opts ...Option) *Redis {
|
||||
r := &Redis{
|
||||
Addr: addr,
|
||||
Type: NodeType,
|
||||
|
||||
@@ -16,6 +16,116 @@ import (
|
||||
"github.com/zeromicro/go-zero/core/stringx"
|
||||
)
|
||||
|
||||
func TestNewRedis(t *testing.T) {
|
||||
r1, err := miniredis.Run()
|
||||
assert.NoError(t, err)
|
||||
defer r1.Close()
|
||||
|
||||
r2, err := miniredis.Run()
|
||||
assert.NoError(t, err)
|
||||
defer r2.Close()
|
||||
r2.SetError("mock")
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
RedisConf
|
||||
ok bool
|
||||
redisErr bool
|
||||
}{
|
||||
{
|
||||
name: "missing host",
|
||||
RedisConf: RedisConf{
|
||||
Host: "",
|
||||
Type: NodeType,
|
||||
Pass: "",
|
||||
},
|
||||
ok: false,
|
||||
},
|
||||
{
|
||||
name: "missing type",
|
||||
RedisConf: RedisConf{
|
||||
Host: "localhost:6379",
|
||||
Type: "",
|
||||
Pass: "",
|
||||
},
|
||||
ok: false,
|
||||
},
|
||||
{
|
||||
name: "ok",
|
||||
RedisConf: RedisConf{
|
||||
Host: r1.Addr(),
|
||||
Type: NodeType,
|
||||
Pass: "",
|
||||
},
|
||||
ok: true,
|
||||
},
|
||||
{
|
||||
name: "ok",
|
||||
RedisConf: RedisConf{
|
||||
Host: r1.Addr(),
|
||||
Type: ClusterType,
|
||||
Pass: "",
|
||||
},
|
||||
ok: true,
|
||||
},
|
||||
{
|
||||
name: "password",
|
||||
RedisConf: RedisConf{
|
||||
Host: r1.Addr(),
|
||||
Type: NodeType,
|
||||
Pass: "pw",
|
||||
},
|
||||
ok: true,
|
||||
},
|
||||
{
|
||||
name: "tls",
|
||||
RedisConf: RedisConf{
|
||||
Host: r1.Addr(),
|
||||
Type: NodeType,
|
||||
Tls: true,
|
||||
},
|
||||
ok: true,
|
||||
},
|
||||
{
|
||||
name: "node error",
|
||||
RedisConf: RedisConf{
|
||||
Host: r2.Addr(),
|
||||
Type: NodeType,
|
||||
Pass: "",
|
||||
},
|
||||
ok: true,
|
||||
redisErr: true,
|
||||
},
|
||||
{
|
||||
name: "cluster error",
|
||||
RedisConf: RedisConf{
|
||||
Host: r2.Addr(),
|
||||
Type: ClusterType,
|
||||
Pass: "",
|
||||
},
|
||||
ok: true,
|
||||
redisErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
t.Run(stringx.RandId(), func(t *testing.T) {
|
||||
rds, err := NewRedis(test.RedisConf)
|
||||
if test.ok {
|
||||
if test.redisErr {
|
||||
assert.Error(t, err)
|
||||
assert.Nil(t, rds)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, rds)
|
||||
}
|
||||
} else {
|
||||
assert.Error(t, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRedis_Decr(t *testing.T) {
|
||||
runOnRedis(t, func(client *Redis) {
|
||||
_, err := New(client.Addr, badType()).Decr("a")
|
||||
@@ -1651,42 +1761,17 @@ func TestRedis_WithPass(t *testing.T) {
|
||||
func runOnRedis(t *testing.T, fn func(client *Redis)) {
|
||||
logx.Disable()
|
||||
|
||||
s, err := miniredis.Run()
|
||||
assert.Nil(t, err)
|
||||
defer func() {
|
||||
client, err := clientManager.GetResource(s.Addr(), func() (io.Closer, error) {
|
||||
return nil, errors.New("should already exist")
|
||||
})
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
if client != nil {
|
||||
_ = client.Close()
|
||||
}
|
||||
}()
|
||||
|
||||
fn(New(s.Addr()))
|
||||
s := miniredis.RunT(t)
|
||||
fn(MustNewRedis(RedisConf{
|
||||
Host: s.Addr(),
|
||||
Type: NodeType,
|
||||
}))
|
||||
}
|
||||
|
||||
func runOnRedisWithError(t *testing.T, fn func(client *Redis)) {
|
||||
logx.Disable()
|
||||
|
||||
s, err := miniredis.Run()
|
||||
assert.NoError(t, err)
|
||||
defer func() {
|
||||
client, err := clientManager.GetResource(s.Addr(), func() (io.Closer, error) {
|
||||
return nil, errors.New("should already exist")
|
||||
})
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
if client != nil {
|
||||
_ = client.Close()
|
||||
}
|
||||
}()
|
||||
|
||||
s := miniredis.RunT(t)
|
||||
s.SetError("mock error")
|
||||
fn(New(s.Addr()))
|
||||
}
|
||||
|
||||
@@ -52,14 +52,11 @@ func TestSqlConn(t *testing.T) {
|
||||
}
|
||||
|
||||
func buildConn() (mock sqlmock.Sqlmock, err error) {
|
||||
_, err = connManager.GetResource(mockedDatasource, func() (io.Closer, error) {
|
||||
connManager.GetResource(mockedDatasource, func() (io.Closer, error) {
|
||||
var db *sql.DB
|
||||
var err error
|
||||
db, mock, err = sqlmock.New()
|
||||
return &pingedDB{
|
||||
DB: db,
|
||||
}, err
|
||||
return db, err
|
||||
})
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
@@ -3,7 +3,6 @@ package sqlx
|
||||
import (
|
||||
"database/sql"
|
||||
"io"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/zeromicro/go-zero/core/syncx"
|
||||
@@ -17,43 +16,29 @@ const (
|
||||
|
||||
var connManager = syncx.NewResourceManager()
|
||||
|
||||
type pingedDB struct {
|
||||
*sql.DB
|
||||
once sync.Once
|
||||
}
|
||||
|
||||
func getCachedSqlConn(driverName, server string) (*pingedDB, error) {
|
||||
func getCachedSqlConn(driverName, server string) (*sql.DB, error) {
|
||||
val, err := connManager.GetResource(server, func() (io.Closer, error) {
|
||||
conn, err := newDBConnection(driverName, server)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &pingedDB{
|
||||
DB: conn,
|
||||
}, nil
|
||||
return conn, nil
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return val.(*pingedDB), nil
|
||||
return val.(*sql.DB), nil
|
||||
}
|
||||
|
||||
func getSqlConn(driverName, server string) (*sql.DB, error) {
|
||||
pdb, err := getCachedSqlConn(driverName, server)
|
||||
conn, err := getCachedSqlConn(driverName, server)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
pdb.once.Do(func() {
|
||||
err = pdb.Ping()
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return pdb.DB, nil
|
||||
return conn, nil
|
||||
}
|
||||
|
||||
func newDBConnection(driverName, datasource string) (*sql.DB, error) {
|
||||
@@ -70,5 +55,10 @@ func newDBConnection(driverName, datasource string) (*sql.DB, error) {
|
||||
conn.SetMaxOpenConns(maxOpenConns)
|
||||
conn.SetConnMaxLifetime(maxLifetime)
|
||||
|
||||
if err := conn.Ping(); err != nil {
|
||||
_ = conn.Close()
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return conn, nil
|
||||
}
|
||||
|
||||
@@ -5,7 +5,6 @@ import (
|
||||
"database/sql"
|
||||
|
||||
"github.com/zeromicro/go-zero/core/trace"
|
||||
"go.opentelemetry.io/otel"
|
||||
"go.opentelemetry.io/otel/attribute"
|
||||
"go.opentelemetry.io/otel/codes"
|
||||
oteltrace "go.opentelemetry.io/otel/trace"
|
||||
@@ -14,11 +13,8 @@ import (
|
||||
var sqlAttributeKey = attribute.Key("sql.method")
|
||||
|
||||
func startSpan(ctx context.Context, method string) (context.Context, oteltrace.Span) {
|
||||
tracer := otel.Tracer(trace.TraceName)
|
||||
start, span := tracer.Start(ctx,
|
||||
spanName,
|
||||
oteltrace.WithSpanKind(oteltrace.SpanKindClient),
|
||||
)
|
||||
tracer := trace.TracerFromContext(ctx)
|
||||
start, span := tracer.Start(ctx, spanName, oteltrace.WithSpanKind(oteltrace.SpanKindClient))
|
||||
span.SetAttributes(sqlAttributeKey.String(method))
|
||||
|
||||
return start, span
|
||||
|
||||
@@ -66,7 +66,7 @@ func format(query string, args ...interface{}) (string, error) {
|
||||
switch ch {
|
||||
case '?':
|
||||
if argIndex >= numArgs {
|
||||
return "", fmt.Errorf("error: %d ? in sql, but less arguments provided", argIndex)
|
||||
return "", fmt.Errorf("%d ? in sql, but less arguments provided", argIndex)
|
||||
}
|
||||
|
||||
writeValue(&b, args[argIndex])
|
||||
@@ -93,7 +93,7 @@ func format(query string, args ...interface{}) (string, error) {
|
||||
|
||||
index--
|
||||
if index < 0 || numArgs <= index {
|
||||
return "", fmt.Errorf("error: wrong index %d in sql", index)
|
||||
return "", fmt.Errorf("wrong index %d in sql", index)
|
||||
}
|
||||
|
||||
writeValue(&b, args[index])
|
||||
@@ -124,7 +124,7 @@ func format(query string, args ...interface{}) (string, error) {
|
||||
}
|
||||
|
||||
if argIndex < numArgs {
|
||||
return "", fmt.Errorf("error: %d arguments provided, not matching sql", argIndex)
|
||||
return "", fmt.Errorf("%d arguments provided, not matching sql", argIndex)
|
||||
}
|
||||
|
||||
return b.String(), nil
|
||||
|
||||
@@ -14,7 +14,6 @@ func (n *node) add(word string) {
|
||||
}
|
||||
|
||||
nd := n
|
||||
var depth int
|
||||
for i, char := range chars {
|
||||
if nd.children == nil {
|
||||
child := new(node)
|
||||
@@ -23,7 +22,6 @@ func (n *node) add(word string) {
|
||||
nd = child
|
||||
} else if child, ok := nd.children[char]; ok {
|
||||
nd = child
|
||||
depth++
|
||||
} else {
|
||||
child := new(node)
|
||||
child.depth = i + 1
|
||||
|
||||
@@ -1,6 +1,13 @@
|
||||
package stringx
|
||||
|
||||
import "strings"
|
||||
import (
|
||||
"sort"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// replace more than once to avoid overlapped keywords after replace.
|
||||
// only try 2 times to avoid too many or infinite loops.
|
||||
const replaceTimes = 2
|
||||
|
||||
type (
|
||||
// Replacer interface wraps the Replace method.
|
||||
@@ -30,68 +37,48 @@ func NewReplacer(mapping map[string]string) Replacer {
|
||||
|
||||
// Replace replaces text with given substitutes.
|
||||
func (r *replacer) Replace(text string) string {
|
||||
var builder strings.Builder
|
||||
var start int
|
||||
chars := []rune(text)
|
||||
size := len(chars)
|
||||
|
||||
for start < size {
|
||||
cur := r.node
|
||||
|
||||
if start > 0 {
|
||||
builder.WriteString(string(chars[:start]))
|
||||
}
|
||||
|
||||
for i := start; i < size; i++ {
|
||||
child, ok := cur.children[chars[i]]
|
||||
if ok {
|
||||
cur = child
|
||||
} else if cur == r.node {
|
||||
builder.WriteRune(chars[i])
|
||||
// cur already points to root, set start only
|
||||
start = i + 1
|
||||
continue
|
||||
} else {
|
||||
curDepth := cur.depth
|
||||
cur = cur.fail
|
||||
child, ok = cur.children[chars[i]]
|
||||
if !ok {
|
||||
// write this path
|
||||
builder.WriteString(string(chars[i-curDepth : i+1]))
|
||||
// go to root
|
||||
cur = r.node
|
||||
start = i + 1
|
||||
continue
|
||||
}
|
||||
|
||||
failDepth := cur.depth
|
||||
// write path before jump
|
||||
builder.WriteString(string(chars[start : start+curDepth-failDepth]))
|
||||
start += curDepth - failDepth
|
||||
cur = child
|
||||
}
|
||||
|
||||
if cur.end {
|
||||
val := string(chars[i+1-cur.depth : i+1])
|
||||
builder.WriteString(r.mapping[val])
|
||||
builder.WriteString(string(chars[i+1:]))
|
||||
// only matching this path, all previous paths are done
|
||||
if start >= i+1-cur.depth && i+1 >= size {
|
||||
return builder.String()
|
||||
}
|
||||
|
||||
chars = []rune(builder.String())
|
||||
size = len(chars)
|
||||
builder.Reset()
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if !cur.end {
|
||||
builder.WriteString(string(chars[start:]))
|
||||
return builder.String()
|
||||
for i := 0; i < replaceTimes; i++ {
|
||||
var replaced bool
|
||||
if text, replaced = r.doReplace(text); !replaced {
|
||||
return text
|
||||
}
|
||||
}
|
||||
|
||||
return string(chars)
|
||||
return text
|
||||
}
|
||||
|
||||
func (r *replacer) doReplace(text string) (string, bool) {
|
||||
chars := []rune(text)
|
||||
scopes := r.find(chars)
|
||||
if len(scopes) == 0 {
|
||||
return text, false
|
||||
}
|
||||
|
||||
sort.Slice(scopes, func(i, j int) bool {
|
||||
if scopes[i].start < scopes[j].start {
|
||||
return true
|
||||
}
|
||||
if scopes[i].start == scopes[j].start {
|
||||
return scopes[i].stop > scopes[j].stop
|
||||
}
|
||||
return false
|
||||
})
|
||||
|
||||
var buf strings.Builder
|
||||
var index int
|
||||
for i := 0; i < len(scopes); i++ {
|
||||
scp := &scopes[i]
|
||||
if scp.start < index {
|
||||
continue
|
||||
}
|
||||
|
||||
buf.WriteString(string(chars[index:scp.start]))
|
||||
buf.WriteString(r.mapping[string(chars[scp.start:scp.stop])])
|
||||
index = scp.stop
|
||||
}
|
||||
if index < len(chars) {
|
||||
buf.WriteString(string(chars[index:]))
|
||||
}
|
||||
|
||||
return buf.String(), true
|
||||
}
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
//go:build go1.18
|
||||
// +build go1.18
|
||||
|
||||
package stringx
|
||||
|
||||
|
||||
@@ -15,6 +15,15 @@ func TestReplacer_Replace(t *testing.T) {
|
||||
assert.Equal(t, "零1234五", NewReplacer(mapping).Replace("零一二三四五"))
|
||||
}
|
||||
|
||||
func TestReplacer_ReplaceJumpMatch(t *testing.T) {
|
||||
mapping := map[string]string{
|
||||
"abcdeg": "ABCDEG",
|
||||
"cdef": "CDEF",
|
||||
"cde": "CDE",
|
||||
}
|
||||
assert.Equal(t, "abCDEF", NewReplacer(mapping).Replace("abcdef"))
|
||||
}
|
||||
|
||||
func TestReplacer_ReplaceOverlap(t *testing.T) {
|
||||
mapping := map[string]string{
|
||||
"3d": "34",
|
||||
@@ -44,6 +53,14 @@ func TestReplacer_ReplacePartialMatch(t *testing.T) {
|
||||
assert.Equal(t, "零一二三四五", NewReplacer(mapping).Replace("零一二三四五"))
|
||||
}
|
||||
|
||||
func TestReplacer_ReplacePartialMatchEnds(t *testing.T) {
|
||||
mapping := map[string]string{
|
||||
"二三四七": "2347",
|
||||
"三四": "34",
|
||||
}
|
||||
assert.Equal(t, "零一二34", NewReplacer(mapping).Replace("零一二三四"))
|
||||
}
|
||||
|
||||
func TestReplacer_ReplaceMultiMatches(t *testing.T) {
|
||||
mapping := map[string]string{
|
||||
"二三": "23",
|
||||
@@ -51,6 +68,54 @@ func TestReplacer_ReplaceMultiMatches(t *testing.T) {
|
||||
assert.Equal(t, "零一23四五一23四五", NewReplacer(mapping).Replace("零一二三四五一二三四五"))
|
||||
}
|
||||
|
||||
func TestReplacer_ReplaceLongestMatching(t *testing.T) {
|
||||
keywords := map[string]string{
|
||||
"日本": "japan",
|
||||
"日本的首都": "东京",
|
||||
}
|
||||
replacer := NewReplacer(keywords)
|
||||
assert.Equal(t, "东京在japan", replacer.Replace("日本的首都在日本"))
|
||||
}
|
||||
|
||||
func TestReplacer_ReplaceSuffixMatch(t *testing.T) {
|
||||
// case1
|
||||
{
|
||||
keywords := map[string]string{
|
||||
"abcde": "ABCDE",
|
||||
"bcde": "BCDE",
|
||||
"bcd": "BCD",
|
||||
}
|
||||
assert.Equal(t, "aBCDf", NewReplacer(keywords).Replace("abcdf"))
|
||||
}
|
||||
// case2
|
||||
{
|
||||
keywords := map[string]string{
|
||||
"abcde": "ABCDE",
|
||||
"bcde": "BCDE",
|
||||
"cde": "CDE",
|
||||
"c": "C",
|
||||
"cd": "CD",
|
||||
}
|
||||
assert.Equal(t, "abCDf", NewReplacer(keywords).Replace("abcdf"))
|
||||
}
|
||||
}
|
||||
|
||||
func TestReplacer_ReplaceLongestOverlap(t *testing.T) {
|
||||
keywords := map[string]string{
|
||||
"456": "def",
|
||||
"abcd": "1234",
|
||||
}
|
||||
replacer := NewReplacer(keywords)
|
||||
assert.Equal(t, "123def7", replacer.Replace("abcd567"))
|
||||
}
|
||||
|
||||
func TestReplacer_ReplaceLongestLonger(t *testing.T) {
|
||||
mapping := map[string]string{
|
||||
"c": "3",
|
||||
}
|
||||
assert.Equal(t, "3d", NewReplacer(mapping).Replace("cd"))
|
||||
}
|
||||
|
||||
func TestReplacer_ReplaceJumpToFail(t *testing.T) {
|
||||
mapping := map[string]string{
|
||||
"bcdf": "1235",
|
||||
@@ -146,3 +211,21 @@ func TestFuzzReplacerCase2(t *testing.T) {
|
||||
t.Errorf("result: %s, match: %v", val, keys)
|
||||
}
|
||||
}
|
||||
|
||||
func TestReplacer_ReplaceLongestMatch(t *testing.T) {
|
||||
replacer := NewReplacer(map[string]string{
|
||||
"日本的首都": "东京",
|
||||
"日本": "本日",
|
||||
})
|
||||
assert.Equal(t, "东京是东京", replacer.Replace("日本的首都是东京"))
|
||||
}
|
||||
|
||||
func TestReplacer_ReplaceIndefinitely(t *testing.T) {
|
||||
mapping := map[string]string{
|
||||
"日本的首都": "东京",
|
||||
"东京": "日本的首都",
|
||||
}
|
||||
assert.NotPanics(t, func() {
|
||||
NewReplacer(mapping).Replace("日本的首都是东京")
|
||||
})
|
||||
}
|
||||
|
||||
@@ -3,6 +3,7 @@ package trace
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/url"
|
||||
"sync"
|
||||
|
||||
"github.com/zeromicro/go-zero/core/lang"
|
||||
@@ -57,6 +58,10 @@ func createExporter(c Config) (sdktrace.SpanExporter, error) {
|
||||
// Just support jaeger and zipkin now, more for later
|
||||
switch c.Batcher {
|
||||
case kindJaeger:
|
||||
u, _ := url.Parse(c.Endpoint)
|
||||
if u.Scheme == "udp" {
|
||||
return jaeger.New(jaeger.WithAgentEndpoint(jaeger.WithAgentHost(u.Hostname()), jaeger.WithAgentPort(u.Port())))
|
||||
}
|
||||
return jaeger.New(jaeger.WithCollectorEndpoint(jaeger.WithEndpoint(c.Endpoint)))
|
||||
case kindZipkin:
|
||||
return zipkin.New(c.Endpoint)
|
||||
|
||||
@@ -15,6 +15,7 @@ func TestStartAgent(t *testing.T) {
|
||||
endpoint2 = "remotehost:1234"
|
||||
endpoint3 = "localhost:1235"
|
||||
endpoint4 = "localhost:1236"
|
||||
endpoint5 = "udp://localhost:6831"
|
||||
)
|
||||
c1 := Config{
|
||||
Name: "foo",
|
||||
@@ -44,6 +45,11 @@ func TestStartAgent(t *testing.T) {
|
||||
Endpoint: endpoint4,
|
||||
Batcher: kindOtlpHttp,
|
||||
}
|
||||
c7 := Config{
|
||||
Name: "UDP",
|
||||
Endpoint: endpoint5,
|
||||
Batcher: kindJaeger,
|
||||
}
|
||||
|
||||
StartAgent(c1)
|
||||
StartAgent(c1)
|
||||
@@ -52,16 +58,19 @@ func TestStartAgent(t *testing.T) {
|
||||
StartAgent(c4)
|
||||
StartAgent(c5)
|
||||
StartAgent(c6)
|
||||
StartAgent(c7)
|
||||
|
||||
lock.Lock()
|
||||
defer lock.Unlock()
|
||||
|
||||
// because remotehost cannot be resolved
|
||||
assert.Equal(t, 4, len(agents))
|
||||
assert.Equal(t, 5, len(agents))
|
||||
_, ok := agents[""]
|
||||
assert.True(t, ok)
|
||||
_, ok = agents[endpoint1]
|
||||
assert.True(t, ok)
|
||||
_, ok = agents[endpoint2]
|
||||
assert.False(t, ok)
|
||||
_, ok = agents[endpoint5]
|
||||
assert.True(t, ok)
|
||||
}
|
||||
|
||||
73
core/trace/message_test.go
Normal file
73
core/trace/message_test.go
Normal file
@@ -0,0 +1,73 @@
|
||||
package trace
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"go.opentelemetry.io/otel/attribute"
|
||||
"go.opentelemetry.io/otel/codes"
|
||||
"go.opentelemetry.io/otel/trace"
|
||||
"google.golang.org/protobuf/reflect/protoreflect"
|
||||
"google.golang.org/protobuf/types/dynamicpb"
|
||||
)
|
||||
|
||||
func TestMessageType_Event(t *testing.T) {
|
||||
var span mockSpan
|
||||
ctx := trace.ContextWithSpan(context.Background(), &span)
|
||||
MessageReceived.Event(ctx, 1, "foo")
|
||||
assert.Equal(t, messageEvent, span.name)
|
||||
assert.NotEmpty(t, span.options)
|
||||
}
|
||||
|
||||
func TestMessageType_EventProtoMessage(t *testing.T) {
|
||||
var span mockSpan
|
||||
var message mockMessage
|
||||
ctx := trace.ContextWithSpan(context.Background(), &span)
|
||||
MessageReceived.Event(ctx, 1, message)
|
||||
assert.Equal(t, messageEvent, span.name)
|
||||
assert.NotEmpty(t, span.options)
|
||||
}
|
||||
|
||||
type mockSpan struct {
|
||||
name string
|
||||
options []trace.EventOption
|
||||
}
|
||||
|
||||
func (m *mockSpan) End(options ...trace.SpanEndOption) {
|
||||
}
|
||||
|
||||
func (m *mockSpan) AddEvent(name string, options ...trace.EventOption) {
|
||||
m.name = name
|
||||
m.options = options
|
||||
}
|
||||
|
||||
func (m *mockSpan) IsRecording() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func (m *mockSpan) RecordError(err error, options ...trace.EventOption) {
|
||||
}
|
||||
|
||||
func (m *mockSpan) SpanContext() trace.SpanContext {
|
||||
panic("implement me")
|
||||
}
|
||||
|
||||
func (m *mockSpan) SetStatus(code codes.Code, description string) {
|
||||
}
|
||||
|
||||
func (m *mockSpan) SetName(name string) {
|
||||
}
|
||||
|
||||
func (m *mockSpan) SetAttributes(kv ...attribute.KeyValue) {
|
||||
}
|
||||
|
||||
func (m *mockSpan) TracerProvider() trace.TracerProvider {
|
||||
return nil
|
||||
}
|
||||
|
||||
type mockMessage struct{}
|
||||
|
||||
func (m mockMessage) ProtoReflect() protoreflect.Message {
|
||||
return new(dynamicpb.Message)
|
||||
}
|
||||
@@ -6,8 +6,10 @@ import (
|
||||
"strings"
|
||||
|
||||
ztrace "github.com/zeromicro/go-zero/internal/trace"
|
||||
"go.opentelemetry.io/otel"
|
||||
"go.opentelemetry.io/otel/attribute"
|
||||
semconv "go.opentelemetry.io/otel/semconv/v1.4.0"
|
||||
"go.opentelemetry.io/otel/trace"
|
||||
"google.golang.org/grpc/peer"
|
||||
)
|
||||
|
||||
@@ -20,25 +22,6 @@ var (
|
||||
TraceIDFromContext = ztrace.TraceIDFromContext
|
||||
)
|
||||
|
||||
// PeerFromCtx returns the peer from ctx.
|
||||
func PeerFromCtx(ctx context.Context) string {
|
||||
p, ok := peer.FromContext(ctx)
|
||||
if !ok || p == nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
return p.Addr.String()
|
||||
}
|
||||
|
||||
// SpanInfo returns the span info.
|
||||
func SpanInfo(fullMethod, peerAddress string) (string, []attribute.KeyValue) {
|
||||
attrs := []attribute.KeyValue{RPCSystemGRPC}
|
||||
name, mAttrs := ParseFullMethod(fullMethod)
|
||||
attrs = append(attrs, mAttrs...)
|
||||
attrs = append(attrs, PeerAttr(peerAddress)...)
|
||||
return name, attrs
|
||||
}
|
||||
|
||||
// ParseFullMethod returns the method name and attributes.
|
||||
func ParseFullMethod(fullMethod string) (string, []attribute.KeyValue) {
|
||||
name := strings.TrimLeft(fullMethod, "/")
|
||||
@@ -75,3 +58,33 @@ func PeerAttr(addr string) []attribute.KeyValue {
|
||||
semconv.NetPeerPortKey.String(port),
|
||||
}
|
||||
}
|
||||
|
||||
// PeerFromCtx returns the peer from ctx.
|
||||
func PeerFromCtx(ctx context.Context) string {
|
||||
p, ok := peer.FromContext(ctx)
|
||||
if !ok || p == nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
return p.Addr.String()
|
||||
}
|
||||
|
||||
// SpanInfo returns the span info.
|
||||
func SpanInfo(fullMethod, peerAddress string) (string, []attribute.KeyValue) {
|
||||
attrs := []attribute.KeyValue{RPCSystemGRPC}
|
||||
name, mAttrs := ParseFullMethod(fullMethod)
|
||||
attrs = append(attrs, mAttrs...)
|
||||
attrs = append(attrs, PeerAttr(peerAddress)...)
|
||||
return name, attrs
|
||||
}
|
||||
|
||||
// TracerFromContext returns a tracer in ctx, otherwise returns a global tracer.
|
||||
func TracerFromContext(ctx context.Context) (tracer trace.Tracer) {
|
||||
if span := trace.SpanFromContext(ctx); span.SpanContext().IsValid() {
|
||||
tracer = span.TracerProvider().Tracer(TraceName)
|
||||
} else {
|
||||
tracer = otel.Tracer(TraceName)
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
@@ -6,8 +6,12 @@ import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"go.opentelemetry.io/otel"
|
||||
"go.opentelemetry.io/otel/attribute"
|
||||
"go.opentelemetry.io/otel/sdk/resource"
|
||||
sdktrace "go.opentelemetry.io/otel/sdk/trace"
|
||||
semconv "go.opentelemetry.io/otel/semconv/v1.4.0"
|
||||
"go.opentelemetry.io/otel/trace"
|
||||
"google.golang.org/grpc/peer"
|
||||
)
|
||||
|
||||
@@ -151,3 +155,50 @@ func TestPeerAttr(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestTracerFromContext(t *testing.T) {
|
||||
traceFn := func(ctx context.Context, hasTraceId bool) {
|
||||
spanContext := trace.SpanContextFromContext(ctx)
|
||||
assert.Equal(t, spanContext.IsValid(), hasTraceId)
|
||||
parentTraceId := spanContext.TraceID().String()
|
||||
|
||||
tracer := TracerFromContext(ctx)
|
||||
_, span := tracer.Start(ctx, "b")
|
||||
defer span.End()
|
||||
|
||||
spanContext = span.SpanContext()
|
||||
assert.True(t, spanContext.IsValid())
|
||||
if hasTraceId {
|
||||
assert.Equal(t, parentTraceId, spanContext.TraceID().String())
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
t.Run("context", func(t *testing.T) {
|
||||
opts := []sdktrace.TracerProviderOption{
|
||||
// Set the sampling rate based on the parent span to 100%
|
||||
sdktrace.WithSampler(sdktrace.ParentBased(sdktrace.TraceIDRatioBased(1))),
|
||||
// Record information about this application in a Resource.
|
||||
sdktrace.WithResource(resource.NewSchemaless(semconv.ServiceNameKey.String("test"))),
|
||||
}
|
||||
tp = sdktrace.NewTracerProvider(opts...)
|
||||
otel.SetTracerProvider(tp)
|
||||
ctx, span := tp.Tracer(TraceName).Start(context.Background(), "a")
|
||||
|
||||
defer span.End()
|
||||
traceFn(ctx, true)
|
||||
})
|
||||
|
||||
t.Run("global", func(t *testing.T) {
|
||||
opts := []sdktrace.TracerProviderOption{
|
||||
// Set the sampling rate based on the parent span to 100%
|
||||
sdktrace.WithSampler(sdktrace.ParentBased(sdktrace.TraceIDRatioBased(1))),
|
||||
// Record information about this application in a Resource.
|
||||
sdktrace.WithResource(resource.NewSchemaless(semconv.ServiceNameKey.String("test"))),
|
||||
}
|
||||
tp = sdktrace.NewTracerProvider(opts...)
|
||||
otel.SetTracerProvider(tp)
|
||||
|
||||
traceFn(context.Background(), false)
|
||||
})
|
||||
}
|
||||
|
||||
15
readme-cn.md
15
readme-cn.md
@@ -20,9 +20,9 @@
|
||||
> ***注意:***
|
||||
>
|
||||
> 从 v1.3.0 之前版本升级请执行以下命令:
|
||||
>
|
||||
>
|
||||
> `GOPROXY=https://goproxy.cn/,direct go install github.com/zeromicro/go-zero/tools/goctl@latest`
|
||||
>
|
||||
>
|
||||
> `goctl migrate —verbose —version v1.4.3`
|
||||
|
||||
## 0. go-zero 介绍
|
||||
@@ -121,10 +121,10 @@ GO111MODULE=on GOPROXY=https://goproxy.cn/,direct go get -u github.com/zeromicro
|
||||
```shell
|
||||
# Go 1.15 及之前版本
|
||||
GO111MODULE=on GOPROXY=https://goproxy.cn/,direct go get -u github.com/zeromicro/go-zero/tools/goctl@latest
|
||||
|
||||
|
||||
# Go 1.16 及以后版本
|
||||
GOPROXY=https://goproxy.cn/,direct go install github.com/zeromicro/go-zero/tools/goctl@latest
|
||||
|
||||
|
||||
# For Mac
|
||||
brew install goctl
|
||||
|
||||
@@ -200,7 +200,7 @@ GO111MODULE=on GOPROXY=https://goproxy.cn/,direct go get -u github.com/zeromicro
|
||||
* [快速构建高并发微服务 - 多 RPC 版](https://github.com/zeromicro/zero-doc/blob/main/docs/zero/bookstore.md)
|
||||
* [goctl 使用帮助](https://github.com/zeromicro/zero-doc/blob/main/doc/goctl.md)
|
||||
* [Examples](https://github.com/zeromicro/zero-examples)
|
||||
|
||||
|
||||
* 精选 `goctl` 插件
|
||||
|
||||
| 插件 | 用途 |
|
||||
@@ -296,6 +296,11 @@ go-zero 已被许多公司用于生产部署,接入场景如在线教育、电
|
||||
>81. 广州机智云物联网科技有限公司
|
||||
>82. 厦门亿联网络技术股份有限公司
|
||||
>83. 北京麦芽田网络科技有限公司
|
||||
>84. 佛山市振联科技有限公司
|
||||
>85. 苏州智言信息科技有限公司
|
||||
>86. 中国移动上海产业研究院
|
||||
>87. 天枢数链(浙江)科技有限公司
|
||||
>88. 北京娱人共享智能科技有限公司
|
||||
|
||||
如果贵公司也已使用 go-zero,欢迎在 [登记地址](https://github.com/zeromicro/go-zero/issues/602) 登记,仅仅为了推广,不做其它用途。
|
||||
|
||||
|
||||
12
readme.md
12
readme.md
@@ -129,10 +129,10 @@ goctl migrate —verbose —version v1.4.3
|
||||
```shell
|
||||
# for Go 1.15 and earlier
|
||||
GO111MODULE=on go get -u github.com/zeromicro/go-zero/tools/goctl@latest
|
||||
|
||||
|
||||
# for Go 1.16 and later
|
||||
go install github.com/zeromicro/go-zero/tools/goctl@latest
|
||||
|
||||
|
||||
# For Mac
|
||||
brew install goctl
|
||||
|
||||
@@ -156,24 +156,24 @@ goctl migrate —verbose —version v1.4.3
|
||||
Request {
|
||||
Name string `path:"name,options=[you,me]"` // parameters are auto validated
|
||||
}
|
||||
|
||||
|
||||
Response {
|
||||
Message string `json:"message"`
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
service greet-api {
|
||||
@handler GreetHandler
|
||||
get /greet/from/:name(Request) returns (Response)
|
||||
}
|
||||
```
|
||||
|
||||
|
||||
the .api files also can be generated by goctl, like below:
|
||||
|
||||
```shell
|
||||
goctl api -o greet.api
|
||||
```
|
||||
|
||||
|
||||
4. generate the go server-side code
|
||||
|
||||
```shell
|
||||
|
||||
@@ -25,8 +25,10 @@ const topCpuUsage = 1000
|
||||
var ErrSignatureConfig = errors.New("bad config for Signature")
|
||||
|
||||
type engine struct {
|
||||
conf RestConf
|
||||
routes []featuredRoutes
|
||||
conf RestConf
|
||||
routes []featuredRoutes
|
||||
// timeout is the max timeout of all routes
|
||||
timeout time.Duration
|
||||
unauthorizedCallback handler.UnauthorizedCallback
|
||||
unsignedCallback handler.UnsignedCallback
|
||||
chain chain.Chain
|
||||
@@ -38,8 +40,10 @@ type engine struct {
|
||||
|
||||
func newEngine(c RestConf) *engine {
|
||||
svr := &engine{
|
||||
conf: c,
|
||||
conf: c,
|
||||
timeout: time.Duration(c.Timeout) * time.Millisecond,
|
||||
}
|
||||
|
||||
if c.CpuThreshold > 0 {
|
||||
svr.shedder = load.NewAdaptiveShedder(load.WithCpuThreshold(c.CpuThreshold))
|
||||
svr.priorityShedder = load.NewAdaptiveShedder(load.WithCpuThreshold(
|
||||
@@ -51,6 +55,12 @@ func newEngine(c RestConf) *engine {
|
||||
|
||||
func (ng *engine) addRoutes(r featuredRoutes) {
|
||||
ng.routes = append(ng.routes, r)
|
||||
|
||||
// need to guarantee the timeout is the max of all routes
|
||||
// otherwise impossible to set http.Server.ReadTimeout & WriteTimeout
|
||||
if r.timeout > ng.timeout {
|
||||
ng.timeout = r.timeout
|
||||
}
|
||||
}
|
||||
|
||||
func (ng *engine) appendAuthHandler(fr featuredRoutes, chn chain.Chain,
|
||||
@@ -314,15 +324,15 @@ func (ng *engine) use(middleware Middleware) {
|
||||
|
||||
func (ng *engine) withTimeout() internal.StartOption {
|
||||
return func(svr *http.Server) {
|
||||
timeout := ng.conf.Timeout
|
||||
timeout := ng.timeout
|
||||
if timeout > 0 {
|
||||
// factor 0.8, to avoid clients send longer content-length than the actual content,
|
||||
// without this timeout setting, the server will time out and respond 503 Service Unavailable,
|
||||
// which triggers the circuit breaker.
|
||||
svr.ReadTimeout = 4 * time.Duration(timeout) * time.Millisecond / 5
|
||||
svr.ReadTimeout = 4 * timeout / 5
|
||||
// factor 1.1, to avoid servers don't have enough time to write responses.
|
||||
// setting the factor less than 1.0 may lead clients not receiving the responses.
|
||||
svr.WriteTimeout = 11 * time.Duration(timeout) * time.Millisecond / 10
|
||||
svr.WriteTimeout = 11 * timeout / 10
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -43,6 +43,7 @@ Verbose: true
|
||||
Path: "/",
|
||||
Handler: func(w http.ResponseWriter, r *http.Request) {},
|
||||
}},
|
||||
timeout: time.Minute,
|
||||
},
|
||||
{
|
||||
priority: true,
|
||||
@@ -53,6 +54,7 @@ Verbose: true
|
||||
Path: "/",
|
||||
Handler: func(w http.ResponseWriter, r *http.Request) {},
|
||||
}},
|
||||
timeout: time.Second,
|
||||
},
|
||||
{
|
||||
priority: true,
|
||||
@@ -159,6 +161,11 @@ Verbose: true
|
||||
}
|
||||
})
|
||||
assert.NotNil(t, ng.start(mockedRouter{}))
|
||||
timeout := time.Second * 3
|
||||
if route.timeout > timeout {
|
||||
timeout = route.timeout
|
||||
}
|
||||
assert.Equal(t, timeout, ng.timeout)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,11 +1,13 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"path"
|
||||
"runtime"
|
||||
@@ -125,6 +127,14 @@ type timeoutWriter struct {
|
||||
|
||||
var _ http.Pusher = (*timeoutWriter)(nil)
|
||||
|
||||
func (tw *timeoutWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
|
||||
if hijacked, ok := tw.w.(http.Hijacker); ok {
|
||||
return hijacked.Hijack()
|
||||
}
|
||||
|
||||
return nil, nil, errors.New("server doesn't support hijacking")
|
||||
}
|
||||
|
||||
// Header returns the underline temporary http.Header.
|
||||
func (tw *timeoutWriter) Header() http.Header { return tw.h }
|
||||
|
||||
|
||||
@@ -10,6 +10,7 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/zeromicro/go-zero/rest/internal/response"
|
||||
)
|
||||
|
||||
func init() {
|
||||
@@ -134,6 +135,30 @@ func TestTimeoutClientClosed(t *testing.T) {
|
||||
assert.Equal(t, statusClientClosedRequest, resp.Code)
|
||||
}
|
||||
|
||||
func TestTimeoutHijack(t *testing.T) {
|
||||
resp := httptest.NewRecorder()
|
||||
|
||||
writer := &timeoutWriter{
|
||||
w: &response.WithCodeResponseWriter{
|
||||
Writer: resp,
|
||||
},
|
||||
}
|
||||
|
||||
assert.NotPanics(t, func() {
|
||||
writer.Hijack()
|
||||
})
|
||||
|
||||
writer = &timeoutWriter{
|
||||
w: &response.WithCodeResponseWriter{
|
||||
Writer: mockedHijackable{resp},
|
||||
},
|
||||
}
|
||||
|
||||
assert.NotPanics(t, func() {
|
||||
writer.Hijack()
|
||||
})
|
||||
}
|
||||
|
||||
func TestTimeoutPusher(t *testing.T) {
|
||||
handler := &timeoutWriter{
|
||||
w: mockedPusher{},
|
||||
|
||||
@@ -156,12 +156,13 @@ func fillPath(u *nurl.URL, val map[string]interface{}) error {
|
||||
}
|
||||
|
||||
func request(r *http.Request, cli client) (*http.Response, error) {
|
||||
tracer := otel.Tracer(trace.TraceName)
|
||||
ctx := r.Context()
|
||||
tracer := trace.TracerFromContext(ctx)
|
||||
propagator := otel.GetTextMapPropagator()
|
||||
|
||||
spanName := r.URL.Path
|
||||
ctx, span := tracer.Start(
|
||||
r.Context(),
|
||||
ctx,
|
||||
spanName,
|
||||
oteltrace.WithSpanKind(oteltrace.SpanKindClient),
|
||||
oteltrace.WithAttributes(semconv.HTTPClientAttributesFromHTTPRequest(r)...),
|
||||
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
|
||||
"github.com/zeromicro/go-zero/core/mapping"
|
||||
"github.com/zeromicro/go-zero/rest/internal/encoding"
|
||||
@@ -23,8 +24,15 @@ const (
|
||||
var (
|
||||
formUnmarshaler = mapping.NewUnmarshaler(formKey, mapping.WithStringValues())
|
||||
pathUnmarshaler = mapping.NewUnmarshaler(pathKey, mapping.WithStringValues())
|
||||
validator atomic.Value
|
||||
)
|
||||
|
||||
// Validator defines the interface for validating the request.
|
||||
type Validator interface {
|
||||
// Validate validates the request and parsed data.
|
||||
Validate(r *http.Request, data interface{}) error
|
||||
}
|
||||
|
||||
// Parse parses the request.
|
||||
func Parse(r *http.Request, v interface{}) error {
|
||||
if err := ParsePath(r, v); err != nil {
|
||||
@@ -39,7 +47,15 @@ func Parse(r *http.Request, v interface{}) error {
|
||||
return err
|
||||
}
|
||||
|
||||
return ParseJsonBody(r, v)
|
||||
if err := ParseJsonBody(r, v); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if val := validator.Load(); val != nil {
|
||||
return val.(Validator).Validate(r, v)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ParseHeaders parses the headers request.
|
||||
@@ -101,6 +117,13 @@ func ParsePath(r *http.Request, v interface{}) error {
|
||||
return pathUnmarshaler.Unmarshal(m, v)
|
||||
}
|
||||
|
||||
// SetValidator sets the validator.
|
||||
// The validator is used to validate the request, only called in Parse,
|
||||
// not in ParseHeaders, ParseForm, ParseHeader, ParseJsonBody, ParsePath.
|
||||
func SetValidator(val Validator) {
|
||||
validator.Store(val)
|
||||
}
|
||||
|
||||
func withJsonBody(r *http.Request) bool {
|
||||
return r.ContentLength > 0 && strings.Contains(r.Header.Get(header.ContentType), header.ApplicationJson)
|
||||
}
|
||||
|
||||
@@ -1,8 +1,10 @@
|
||||
package httpx
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"reflect"
|
||||
"strconv"
|
||||
"strings"
|
||||
"testing"
|
||||
@@ -207,9 +209,23 @@ func TestParseJsonBody(t *testing.T) {
|
||||
r := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(body))
|
||||
r.Header.Set(ContentType, header.JsonContentType)
|
||||
|
||||
assert.Nil(t, Parse(r, &v))
|
||||
assert.Equal(t, "kevin", v.Name)
|
||||
assert.Equal(t, 18, v.Age)
|
||||
if assert.NoError(t, Parse(r, &v)) {
|
||||
assert.Equal(t, "kevin", v.Name)
|
||||
assert.Equal(t, 18, v.Age)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("bad body", func(t *testing.T) {
|
||||
var v struct {
|
||||
Name string `json:"name"`
|
||||
Age int `json:"age"`
|
||||
}
|
||||
|
||||
body := `{"name":"kevin", "ag": 18}`
|
||||
r := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(body))
|
||||
r.Header.Set(ContentType, header.JsonContentType)
|
||||
|
||||
assert.Error(t, Parse(r, &v))
|
||||
})
|
||||
|
||||
t.Run("hasn't body", func(t *testing.T) {
|
||||
@@ -308,6 +324,36 @@ func TestParseHeaders_Error(t *testing.T) {
|
||||
assert.NotNil(t, Parse(r, &v))
|
||||
}
|
||||
|
||||
func TestParseWithValidator(t *testing.T) {
|
||||
SetValidator(mockValidator{})
|
||||
var v struct {
|
||||
Name string `form:"name"`
|
||||
Age int `form:"age"`
|
||||
Percent float64 `form:"percent,optional"`
|
||||
}
|
||||
|
||||
r, err := http.NewRequest(http.MethodGet, "/a?name=hello&age=18&percent=3.4", http.NoBody)
|
||||
assert.Nil(t, err)
|
||||
if assert.NoError(t, Parse(r, &v)) {
|
||||
assert.Equal(t, "hello", v.Name)
|
||||
assert.Equal(t, 18, v.Age)
|
||||
assert.Equal(t, 3.4, v.Percent)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseWithValidatorWithError(t *testing.T) {
|
||||
SetValidator(mockValidator{})
|
||||
var v struct {
|
||||
Name string `form:"name"`
|
||||
Age int `form:"age"`
|
||||
Percent float64 `form:"percent,optional"`
|
||||
}
|
||||
|
||||
r, err := http.NewRequest(http.MethodGet, "/a?name=world&age=18&percent=3.4", http.NoBody)
|
||||
assert.Nil(t, err)
|
||||
assert.Error(t, Parse(r, &v))
|
||||
}
|
||||
|
||||
func BenchmarkParseRaw(b *testing.B) {
|
||||
r, err := http.NewRequest(http.MethodGet, "http://hello.com/a?name=hello&age=18&percent=3.4", http.NoBody)
|
||||
if err != nil {
|
||||
@@ -351,3 +397,16 @@ func BenchmarkParseAuto(b *testing.B) {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
type mockValidator struct{}
|
||||
|
||||
func (m mockValidator) Validate(r *http.Request, data interface{}) error {
|
||||
if r.URL.Path == "/a" {
|
||||
val := reflect.ValueOf(data).Elem().FieldByName("Name").String()
|
||||
if val != "hello" {
|
||||
return errors.New("name is not hello")
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/zeromicro/go-zero/core/proc"
|
||||
)
|
||||
|
||||
func TestStartHttp(t *testing.T) {
|
||||
@@ -19,6 +20,7 @@ func TestStartHttp(t *testing.T) {
|
||||
svr.IdleTimeout = 0
|
||||
})
|
||||
assert.NotNil(t, err)
|
||||
proc.WrapUp()
|
||||
}
|
||||
|
||||
func TestStartHttps(t *testing.T) {
|
||||
@@ -30,4 +32,5 @@ func TestStartHttps(t *testing.T) {
|
||||
svr.IdleTimeout = 0
|
||||
})
|
||||
assert.NotNil(t, err)
|
||||
proc.WrapUp()
|
||||
}
|
||||
|
||||
@@ -12,6 +12,7 @@ import (
|
||||
"github.com/zeromicro/go-zero/tools/goctl/api/new"
|
||||
"github.com/zeromicro/go-zero/tools/goctl/api/tsgen"
|
||||
"github.com/zeromicro/go-zero/tools/goctl/api/validate"
|
||||
"github.com/zeromicro/go-zero/tools/goctl/config"
|
||||
"github.com/zeromicro/go-zero/tools/goctl/plugin"
|
||||
)
|
||||
|
||||
@@ -127,7 +128,7 @@ func init() {
|
||||
"https://github.com/zeromicro/go-zero-template directory structure")
|
||||
goCmd.Flags().StringVar(&gogen.VarStringBranch, "branch", "", "The branch of "+
|
||||
"the remote repo, it does work with --remote")
|
||||
goCmd.Flags().StringVar(&gogen.VarStringStyle, "style", "gozero", "The file naming format,"+
|
||||
goCmd.Flags().StringVar(&gogen.VarStringStyle, "style", config.DefaultFormat, "The file naming format,"+
|
||||
" see [https://github.com/zeromicro/go-zero/blob/master/tools/goctl/config/readme.md]")
|
||||
|
||||
javaCmd.Flags().StringVar(&javagen.VarStringDir, "dir", "", "The target dir")
|
||||
@@ -146,7 +147,7 @@ func init() {
|
||||
"https://github.com/zeromicro/go-zero-template directory structure")
|
||||
newCmd.Flags().StringVar(&new.VarStringBranch, "branch", "", "The branch of "+
|
||||
"the remote repo, it does work with --remote")
|
||||
newCmd.Flags().StringVar(&new.VarStringStyle, "style", "gozero", "The file naming format,"+
|
||||
newCmd.Flags().StringVar(&new.VarStringStyle, "style", config.DefaultFormat, "The file naming format,"+
|
||||
" see [https://github.com/zeromicro/go-zero/blob/master/tools/goctl/config/readme.md]")
|
||||
|
||||
pluginCmd.Flags().StringVarP(&plugin.VarStringPlugin, "plugin", "p", "", "The plugin file")
|
||||
@@ -157,7 +158,6 @@ func init() {
|
||||
|
||||
tsCmd.Flags().StringVar(&tsgen.VarStringDir, "dir", "", "The target dir")
|
||||
tsCmd.Flags().StringVar(&tsgen.VarStringAPI, "api", "", "The api file")
|
||||
tsCmd.Flags().StringVar(&tsgen.VarStringWebAPI, "webapi", "", "The web api file path")
|
||||
tsCmd.Flags().StringVar(&tsgen.VarStringCaller, "caller", "", "The web api caller")
|
||||
tsCmd.Flags().BoolVar(&tsgen.VarBoolUnWrap, "unwrap", false, "Unwrap the webapi caller for import")
|
||||
|
||||
|
||||
@@ -30,19 +30,21 @@ Future {{pathToFuncName .Path}}( {{if ne .Method "get"}}{{with .RequestType}}{{.
|
||||
{{end}}`
|
||||
|
||||
const apiTemplateV2 = `import 'api.dart';
|
||||
import '../data/{{with .Info}}{{getBaseName .Title}}{{end}}.dart';
|
||||
import '../data/{{with .Service}}{{.Name}}{{end}}.dart';
|
||||
{{with .Service}}
|
||||
/// {{.Name}}
|
||||
{{range .Routes}}
|
||||
{{range $i, $Route := .Routes}}
|
||||
/// --{{.Path}}--
|
||||
///
|
||||
/// request: {{with .RequestType}}{{.Name}}{{end}}
|
||||
/// response: {{with .ResponseType}}{{.Name}}{{end}}
|
||||
Future {{pathToFuncName .Path}}( {{if ne .Method "get"}}{{with .RequestType}}{{.Name}} request,{{end}}{{end}}
|
||||
Future {{normalizeHandlerName .Handler}}(
|
||||
{{if hasUrlPathParams $Route}}{{extractPositionalParamsFromPath $Route}},{{end}}
|
||||
{{if ne .Method "get"}}{{with .RequestType}}{{.Name}} request,{{end}}{{end}}
|
||||
{Function({{with .ResponseType}}{{.Name}}{{end}})? ok,
|
||||
Function(String)? fail,
|
||||
Function? eventually}) async {
|
||||
await api{{if eq .Method "get"}}Get{{else}}Post{{end}}('{{.Path}}',{{if ne .Method "get"}}request,{{end}}
|
||||
await api{{if eq .Method "get"}}Get{{else}}Post{{end}}({{makeDartRequestUrlPath $Route}},{{if ne .Method "get"}}request,{{end}}
|
||||
ok: (data) {
|
||||
if (ok != null) ok({{with .ResponseType}}{{.Name}}.fromJson(data){{end}});
|
||||
}, fail: fail, eventually: eventually);
|
||||
|
||||
@@ -31,7 +31,7 @@ Future<Tokens> getTokens() async {
|
||||
try {
|
||||
var sp = await SharedPreferences.getInstance();
|
||||
var str = sp.getString('tokens');
|
||||
if (str.isEmpty) {
|
||||
if (str == null || str.isEmpty) {
|
||||
return null;
|
||||
}
|
||||
return Tokens.fromJson(jsonDecode(str));
|
||||
@@ -65,7 +65,7 @@ Future<Tokens?> getTokens() async {
|
||||
try {
|
||||
var sp = await SharedPreferences.getInstance();
|
||||
var str = sp.getString('tokens');
|
||||
if (str.isEmpty) {
|
||||
if (str == null || str.isEmpty) {
|
||||
return null;
|
||||
}
|
||||
return Tokens.fromJson(jsonDecode(str));
|
||||
|
||||
@@ -11,6 +11,18 @@ import (
|
||||
"github.com/zeromicro/go-zero/tools/goctl/api/util"
|
||||
)
|
||||
|
||||
const (
|
||||
formTagKey = "form"
|
||||
pathTagKey = "path"
|
||||
headerTagKey = "header"
|
||||
)
|
||||
|
||||
func normalizeHandlerName(handlerName string) string {
|
||||
handler := strings.Replace(handlerName, "Handler", "", 1)
|
||||
handler = lowCamelCase(handler)
|
||||
return handler
|
||||
}
|
||||
|
||||
func lowCamelCase(s string) string {
|
||||
if len(s) < 1 {
|
||||
return ""
|
||||
@@ -20,21 +32,6 @@ func lowCamelCase(s string) string {
|
||||
return util.ToLower(s[:1]) + s[1:]
|
||||
}
|
||||
|
||||
func pathToFuncName(path string) string {
|
||||
if !strings.HasPrefix(path, "/") {
|
||||
path = "/" + path
|
||||
}
|
||||
if !strings.HasPrefix(path, "/api") {
|
||||
path = "/api" + path
|
||||
}
|
||||
|
||||
path = strings.Replace(path, "/", "_", -1)
|
||||
path = strings.Replace(path, "-", "_", -1)
|
||||
|
||||
camel := util.ToCamelCase(path)
|
||||
return util.ToLower(camel[:1]) + camel[1:]
|
||||
}
|
||||
|
||||
func getBaseName(str string) string {
|
||||
return path.Base(str)
|
||||
}
|
||||
@@ -170,3 +167,46 @@ func primitiveType(tp string) (string, bool) {
|
||||
|
||||
return "", false
|
||||
}
|
||||
|
||||
func hasUrlPathParams(route spec.Route) bool {
|
||||
ds, ok := route.RequestType.(spec.DefineStruct)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
|
||||
return len(route.RequestTypeName()) > 0 && len(ds.GetTagMembers(pathTagKey)) > 0
|
||||
}
|
||||
|
||||
func extractPositionalParamsFromPath(route spec.Route) string {
|
||||
ds, ok := route.RequestType.(spec.DefineStruct)
|
||||
if !ok {
|
||||
return ""
|
||||
}
|
||||
|
||||
var params []string
|
||||
for _, member := range ds.GetTagMembers(pathTagKey) {
|
||||
dartType := member.Type.Name()
|
||||
params = append(params, fmt.Sprintf("%s %s", dartType, getPropertyFromMember(member)))
|
||||
}
|
||||
|
||||
return strings.Join(params, ", ")
|
||||
}
|
||||
|
||||
func makeDartRequestUrlPath(route spec.Route) string {
|
||||
path := route.Path
|
||||
if route.RequestType == nil {
|
||||
return `"` + path + `"`
|
||||
}
|
||||
|
||||
ds, ok := route.RequestType.(spec.DefineStruct)
|
||||
if !ok {
|
||||
return path
|
||||
}
|
||||
|
||||
for _, member := range ds.GetTagMembers(pathTagKey) {
|
||||
paramName := member.Tags()[0].Name
|
||||
path = strings.ReplaceAll(path, ":"+paramName, "${"+getPropertyFromMember(member)+"}")
|
||||
}
|
||||
|
||||
return `"` + path + `"`
|
||||
}
|
||||
|
||||
@@ -3,13 +3,16 @@ package dartgen
|
||||
import "text/template"
|
||||
|
||||
var funcMap = template.FuncMap{
|
||||
"getBaseName": getBaseName,
|
||||
"getPropertyFromMember": getPropertyFromMember,
|
||||
"isDirectType": isDirectType,
|
||||
"isClassListType": isClassListType,
|
||||
"getCoreType": getCoreType,
|
||||
"pathToFuncName": pathToFuncName,
|
||||
"lowCamelCase": lowCamelCase,
|
||||
"getBaseName": getBaseName,
|
||||
"getPropertyFromMember": getPropertyFromMember,
|
||||
"isDirectType": isDirectType,
|
||||
"isClassListType": isClassListType,
|
||||
"getCoreType": getCoreType,
|
||||
"lowCamelCase": lowCamelCase,
|
||||
"normalizeHandlerName": normalizeHandlerName,
|
||||
"hasUrlPathParams": hasUrlPathParams,
|
||||
"extractPositionalParamsFromPath": extractPositionalParamsFromPath,
|
||||
"makeDartRequestUrlPath": makeDartRequestUrlPath,
|
||||
}
|
||||
|
||||
const (
|
||||
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
"os"
|
||||
"path"
|
||||
"sort"
|
||||
"strconv"
|
||||
"strings"
|
||||
"text/template"
|
||||
"time"
|
||||
@@ -36,7 +37,7 @@ func RegisterHandlers(server *rest.Server, serverCtx *svc.ServiceContext) {
|
||||
`
|
||||
routesAdditionTemplate = `
|
||||
server.AddRoutes(
|
||||
{{.routes}} {{.jwt}}{{.signature}} {{.prefix}} {{.timeout}}
|
||||
{{.routes}} {{.jwt}}{{.signature}} {{.prefix}} {{.timeout}} {{.maxBytes}}
|
||||
)
|
||||
`
|
||||
timeoutThreshold = time.Millisecond
|
||||
@@ -64,6 +65,7 @@ type (
|
||||
middlewares []string
|
||||
prefix string
|
||||
jwtTrans string
|
||||
maxBytes string
|
||||
}
|
||||
route struct {
|
||||
method string
|
||||
@@ -127,10 +129,20 @@ rest.WithPrefix("%s"),`, g.prefix)
|
||||
return fmt.Errorf("timeout should not less than 1ms, now %v", duration)
|
||||
}
|
||||
|
||||
timeout = fmt.Sprintf("rest.WithTimeout(%d * time.Millisecond),", duration/time.Millisecond)
|
||||
timeout = fmt.Sprintf("\n rest.WithTimeout(%d * time.Millisecond),", duration/time.Millisecond)
|
||||
hasTimeout = true
|
||||
}
|
||||
|
||||
var maxBytes string
|
||||
if len(g.maxBytes) > 0 {
|
||||
_, err := strconv.ParseInt(g.maxBytes, 10, 64)
|
||||
if err != nil {
|
||||
return fmt.Errorf("maxBytes %s parse error,it is an invalid number", g.maxBytes)
|
||||
}
|
||||
|
||||
maxBytes = fmt.Sprintf("\n rest.WithMaxBytes(%s),", g.maxBytes)
|
||||
}
|
||||
|
||||
var routes string
|
||||
if len(g.middlewares) > 0 {
|
||||
gbuilder.WriteString("\n}...,")
|
||||
@@ -152,6 +164,7 @@ rest.WithPrefix("%s"),`, g.prefix)
|
||||
"signature": signature,
|
||||
"prefix": prefix,
|
||||
"timeout": timeout,
|
||||
"maxBytes": maxBytes,
|
||||
}); err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -230,6 +243,7 @@ func getRoutes(api *spec.ApiSpec) ([]group, error) {
|
||||
}
|
||||
|
||||
groupedRoutes.timeout = g.GetAnnotation("timeout")
|
||||
groupedRoutes.maxBytes = g.GetAnnotation("maxBytes")
|
||||
|
||||
jwt := g.GetAnnotation("jwt")
|
||||
if len(jwt) > 0 {
|
||||
|
||||
@@ -176,7 +176,7 @@ func (v *ApiVisitor) VisitAtHandler(ctx *api.AtHandlerContext) interface{} {
|
||||
return &atHandler
|
||||
}
|
||||
|
||||
// serVisitRoute implements from api.BaseApiParserVisitor
|
||||
// VisitRoute implements from api.BaseApiParserVisitor
|
||||
func (v *ApiVisitor) VisitRoute(ctx *api.RouteContext) interface{} {
|
||||
var route Route
|
||||
path := ctx.Path()
|
||||
|
||||
@@ -39,6 +39,10 @@ func TsCommand(_ *cobra.Command, _ []string) error {
|
||||
return errors.New("missing -dir")
|
||||
}
|
||||
|
||||
if len(webAPI) == 0 {
|
||||
webAPI = "."
|
||||
}
|
||||
|
||||
api, err := parser.Parse(apiFile)
|
||||
if err != nil {
|
||||
fmt.Println(aurora.Red("Failed"))
|
||||
@@ -51,6 +55,7 @@ func TsCommand(_ *cobra.Command, _ []string) error {
|
||||
|
||||
api.Service = api.Service.JoinPrefix()
|
||||
logx.Must(pathx.MkdirIfNotExist(dir))
|
||||
logx.Must(genRequest(dir))
|
||||
logx.Must(genHandler(dir, webAPI, caller, api, unwrapAPI))
|
||||
logx.Must(genComponents(dir, api))
|
||||
|
||||
|
||||
@@ -39,7 +39,7 @@ func genHandler(dir, webAPI, caller string, api *spec.ApiSpec, unwrapAPI bool) e
|
||||
importCaller = "{ " + importCaller + " }"
|
||||
}
|
||||
if len(webAPI) > 0 {
|
||||
imports += `import ` + importCaller + ` from ` + "\"" + webAPI + "\""
|
||||
imports += `import ` + importCaller + ` from ` + `"./gocliRequest"`
|
||||
}
|
||||
|
||||
if len(api.Types) != 0 {
|
||||
|
||||
26
tools/goctl/api/tsgen/genrequest.go
Normal file
26
tools/goctl/api/tsgen/genrequest.go
Normal file
@@ -0,0 +1,26 @@
|
||||
package tsgen
|
||||
|
||||
import (
|
||||
_ "embed"
|
||||
"os"
|
||||
"path/filepath"
|
||||
|
||||
"github.com/zeromicro/go-zero/tools/goctl/util/pathx"
|
||||
)
|
||||
|
||||
//go:embed request.ts
|
||||
var requestTemplate string
|
||||
|
||||
func genRequest(dir string) error {
|
||||
abs, err := filepath.Abs(dir)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
filename := filepath.Join(abs, "gocliRequest.ts")
|
||||
if pathx.FileExists(filename) {
|
||||
return nil
|
||||
}
|
||||
|
||||
return os.WriteFile(filename, []byte(requestTemplate), 0644)
|
||||
}
|
||||
126
tools/goctl/api/tsgen/request.ts
Normal file
126
tools/goctl/api/tsgen/request.ts
Normal file
@@ -0,0 +1,126 @@
|
||||
export type Method =
|
||||
| 'get'
|
||||
| 'GET'
|
||||
| 'delete'
|
||||
| 'DELETE'
|
||||
| 'head'
|
||||
| 'HEAD'
|
||||
| 'options'
|
||||
| 'OPTIONS'
|
||||
| 'post'
|
||||
| 'POST'
|
||||
| 'put'
|
||||
| 'PUT'
|
||||
| 'patch'
|
||||
| 'PATCH';
|
||||
|
||||
/**
|
||||
* Parse route parameters for responseType
|
||||
*/
|
||||
const reg = /:[a-z|A-Z]+/g;
|
||||
|
||||
export function parseParams(url: string): Array<string> {
|
||||
const ps = url.match(reg);
|
||||
if (!ps) {
|
||||
return [];
|
||||
}
|
||||
return ps.map((k) => k.replace(/:/, ''));
|
||||
}
|
||||
|
||||
/**
|
||||
* Generate url and parameters
|
||||
* @param url
|
||||
* @param params
|
||||
*/
|
||||
export function genUrl(url: string, params: unknown) {
|
||||
if (!params) {
|
||||
return url;
|
||||
}
|
||||
|
||||
const ps = parseParams(url);
|
||||
ps.forEach((k) => {
|
||||
const reg = new RegExp(`:${k}`);
|
||||
url = url.replace(reg, params[k]);
|
||||
});
|
||||
|
||||
const path: Array<string> = [];
|
||||
for (const key of Object.keys(params)) {
|
||||
if (!ps.find((k) => k === key)) {
|
||||
path.push(`${key}=${params[key]}`);
|
||||
}
|
||||
}
|
||||
|
||||
return url + (path.length > 0 ? `?${path.join('&')}` : '');
|
||||
}
|
||||
|
||||
export async function request({
|
||||
method,
|
||||
url,
|
||||
data,
|
||||
config = {}
|
||||
}: {
|
||||
method: Method;
|
||||
url: string;
|
||||
data?: unknown;
|
||||
config?: unknown;
|
||||
}) {
|
||||
const response = await fetch(url, {
|
||||
method: method.toLocaleUpperCase(),
|
||||
credentials: 'include',
|
||||
headers: {
|
||||
'Content-Type': 'application/json'
|
||||
},
|
||||
body: data ? JSON.stringify(data) : undefined,
|
||||
// @ts-ignore
|
||||
...config
|
||||
});
|
||||
|
||||
return response.json();
|
||||
}
|
||||
|
||||
function api<T>(
|
||||
method: Method = 'get',
|
||||
url: string,
|
||||
req: any,
|
||||
config?: unknown
|
||||
): Promise<T> {
|
||||
if (url.match(/:/) || method.match(/get|delete/i)) {
|
||||
url = genUrl(url, req.params || req.forms);
|
||||
}
|
||||
method = method.toLocaleLowerCase() as Method;
|
||||
|
||||
switch (method) {
|
||||
case 'get':
|
||||
return request({method: 'get', url, data: req, config});
|
||||
case 'delete':
|
||||
return request({method: 'delete', url, data: req, config});
|
||||
case 'put':
|
||||
return request({method: 'put', url, data: req, config});
|
||||
case 'post':
|
||||
return request({method: 'post', url, data: req, config});
|
||||
case 'patch':
|
||||
return request({method: 'patch', url, data: req, config});
|
||||
default:
|
||||
return request({method: 'post', url, data: req, config});
|
||||
}
|
||||
}
|
||||
|
||||
export const webapi = {
|
||||
get<T>(url: string, req: unknown, config?: unknown): Promise<T> {
|
||||
return api<T>('get', url, req, config);
|
||||
},
|
||||
delete<T>(url: string, req: unknown, config?: unknown): Promise<T> {
|
||||
return api<T>('delete', url, req, config);
|
||||
},
|
||||
put<T>(url: string, req: unknown, config?: unknown): Promise<T> {
|
||||
return api<T>('get', url, req, config);
|
||||
},
|
||||
post<T>(url: string, req: unknown, config?: unknown): Promise<T> {
|
||||
return api<T>('post', url, req, config);
|
||||
},
|
||||
patch<T>(url: string, req: unknown, config?: unknown): Promise<T> {
|
||||
return api<T>('patch', url, req, config);
|
||||
}
|
||||
};
|
||||
|
||||
export default webapi
|
||||
@@ -70,7 +70,7 @@ spec:
|
||||
|
||||
---
|
||||
|
||||
apiVersion: autoscaling/v2beta1
|
||||
apiVersion: autoscaling/v2beta2
|
||||
kind: HorizontalPodAutoscaler
|
||||
metadata:
|
||||
name: {{.Name}}-hpa-c
|
||||
@@ -88,11 +88,13 @@ spec:
|
||||
- type: Resource
|
||||
resource:
|
||||
name: cpu
|
||||
targetAverageUtilization: 80
|
||||
target:
|
||||
type: Utilization
|
||||
averageUtilization: 80
|
||||
|
||||
---
|
||||
|
||||
apiVersion: autoscaling/v2beta1
|
||||
apiVersion: autoscaling/v2beta2
|
||||
kind: HorizontalPodAutoscaler
|
||||
metadata:
|
||||
name: {{.Name}}-hpa-m
|
||||
@@ -110,4 +112,6 @@ spec:
|
||||
- type: Resource
|
||||
resource:
|
||||
name: memory
|
||||
targetAverageUtilization: 80
|
||||
target:
|
||||
type: Utilization
|
||||
averageUtilization: 80
|
||||
|
||||
@@ -53,7 +53,7 @@ func format(query string, args ...interface{}) (string, error) {
|
||||
for _, ch := range query {
|
||||
if ch == '?' {
|
||||
if argIndex >= numArgs {
|
||||
return "", fmt.Errorf("error: %d ? in sql, but less arguments provided", argIndex)
|
||||
return "", fmt.Errorf("%d ? in sql, but less arguments provided", argIndex)
|
||||
}
|
||||
|
||||
arg := args[argIndex]
|
||||
@@ -79,7 +79,7 @@ func format(query string, args ...interface{}) (string, error) {
|
||||
}
|
||||
|
||||
if argIndex < numArgs {
|
||||
return "", fmt.Errorf("error: %d ? in sql, but more arguments provided", argIndex)
|
||||
return "", fmt.Errorf("%d ? in sql, but more arguments provided", argIndex)
|
||||
}
|
||||
|
||||
return b.String(), nil
|
||||
|
||||
@@ -2,6 +2,7 @@ package rpc
|
||||
|
||||
import (
|
||||
"github.com/spf13/cobra"
|
||||
"github.com/zeromicro/go-zero/tools/goctl/config"
|
||||
"github.com/zeromicro/go-zero/tools/goctl/rpc/cli"
|
||||
)
|
||||
|
||||
@@ -53,7 +54,7 @@ func init() {
|
||||
|
||||
newCmd.Flags().StringSliceVar(&cli.VarStringSliceGoOpt, "go_opt", nil, "")
|
||||
newCmd.Flags().StringSliceVar(&cli.VarStringSliceGoGRPCOpt, "go-grpc_opt", nil, "")
|
||||
newCmd.Flags().StringVar(&cli.VarStringStyle, "style", "gozero", "The file "+
|
||||
newCmd.Flags().StringVar(&cli.VarStringStyle, "style", config.DefaultFormat, "The file "+
|
||||
"naming format, see [https://github.com/zeromicro/go-zero/tree/master/tools/goctl/config/readme.md]")
|
||||
newCmd.Flags().BoolVar(&cli.VarBoolIdea, "idea", false, "Whether the command "+
|
||||
"execution environment is from idea plugin.")
|
||||
@@ -79,7 +80,7 @@ func init() {
|
||||
protocCmd.Flags().StringSliceVar(&cli.VarStringSlicePlugin, "plugin", nil, "")
|
||||
protocCmd.Flags().StringSliceVarP(&cli.VarStringSliceProtoPath, "proto_path", "I", nil, "")
|
||||
protocCmd.Flags().StringVar(&cli.VarStringZRPCOut, "zrpc_out", "", "The zrpc output directory")
|
||||
protocCmd.Flags().StringVar(&cli.VarStringStyle, "style", "gozero", "The file "+
|
||||
protocCmd.Flags().StringVar(&cli.VarStringStyle, "style", config.DefaultFormat, "The file "+
|
||||
"naming format, see [https://github.com/zeromicro/go-zero/tree/master/tools/goctl/config/readme.md]")
|
||||
protocCmd.Flags().StringVar(&cli.VarStringHome, "home", "", "The goctl home "+
|
||||
"path of the template, --home and --remote cannot be set at the same time, if they are, "+
|
||||
|
||||
@@ -63,7 +63,24 @@ func (g *Generator) genCallGroup(ctx DirContext, proto parser.Proto, cfg *conf.C
|
||||
isCallPkgSameToPbPkg := childDir == ctx.GetProtoGo().Filename
|
||||
isCallPkgSameToGrpcPkg := childDir == ctx.GetProtoGo().Filename
|
||||
|
||||
functions, err := g.genFunction(proto.PbPackage, service, isCallPkgSameToGrpcPkg)
|
||||
serviceName := stringx.From(service.Name).ToCamel()
|
||||
alias := collection.NewSet()
|
||||
var hasSameNameBetweenMessageAndService bool
|
||||
for _, item := range proto.Message {
|
||||
msgName := getMessageName(*item.Message)
|
||||
if serviceName == msgName {
|
||||
hasSameNameBetweenMessageAndService = true
|
||||
}
|
||||
if !isCallPkgSameToPbPkg {
|
||||
alias.AddStr(fmt.Sprintf("%s = %s", parser.CamelCase(msgName),
|
||||
fmt.Sprintf("%s.%s", proto.PbPackage, parser.CamelCase(msgName))))
|
||||
}
|
||||
}
|
||||
if hasSameNameBetweenMessageAndService {
|
||||
serviceName = stringx.From(service.Name + "_zrpc_client").ToCamel()
|
||||
}
|
||||
|
||||
functions, err := g.genFunction(proto.PbPackage, serviceName, service, isCallPkgSameToGrpcPkg)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -78,15 +95,6 @@ func (g *Generator) genCallGroup(ctx DirContext, proto parser.Proto, cfg *conf.C
|
||||
return err
|
||||
}
|
||||
|
||||
alias := collection.NewSet()
|
||||
if !isCallPkgSameToPbPkg {
|
||||
for _, item := range proto.Message {
|
||||
msgName := getMessageName(*item.Message)
|
||||
alias.AddStr(fmt.Sprintf("%s = %s", parser.CamelCase(msgName),
|
||||
fmt.Sprintf("%s.%s", proto.PbPackage, parser.CamelCase(msgName))))
|
||||
}
|
||||
}
|
||||
|
||||
pbPackage := fmt.Sprintf(`"%s"`, ctx.GetPb().Package)
|
||||
protoGoPackage := fmt.Sprintf(`"%s"`, ctx.GetProtoGo().Package)
|
||||
if isCallPkgSameToGrpcPkg {
|
||||
@@ -103,7 +111,7 @@ func (g *Generator) genCallGroup(ctx DirContext, proto parser.Proto, cfg *conf.C
|
||||
"filePackage": dir.Base,
|
||||
"pbPackage": pbPackage,
|
||||
"protoGoPackage": protoGoPackage,
|
||||
"serviceName": stringx.From(service.Name).ToCamel(),
|
||||
"serviceName": serviceName,
|
||||
"functions": strings.Join(functions, pathx.NL),
|
||||
"interface": strings.Join(iFunctions, pathx.NL),
|
||||
}, filename, true); err != nil {
|
||||
@@ -126,8 +134,26 @@ func (g *Generator) genCallInCompatibility(ctx DirContext, proto parser.Proto,
|
||||
return err
|
||||
}
|
||||
|
||||
serviceName := stringx.From(service.Name).ToCamel()
|
||||
alias := collection.NewSet()
|
||||
var hasSameNameBetweenMessageAndService bool
|
||||
for _, item := range proto.Message {
|
||||
msgName := getMessageName(*item.Message)
|
||||
if serviceName == msgName {
|
||||
hasSameNameBetweenMessageAndService = true
|
||||
}
|
||||
if !isCallPkgSameToPbPkg {
|
||||
alias.AddStr(fmt.Sprintf("%s = %s", parser.CamelCase(msgName),
|
||||
fmt.Sprintf("%s.%s", proto.PbPackage, parser.CamelCase(msgName))))
|
||||
}
|
||||
}
|
||||
|
||||
if hasSameNameBetweenMessageAndService {
|
||||
serviceName = stringx.From(service.Name + "_zrpc_client").ToCamel()
|
||||
}
|
||||
|
||||
filename := filepath.Join(dir.Filename, fmt.Sprintf("%s.go", callFilename))
|
||||
functions, err := g.genFunction(proto.PbPackage, service, isCallPkgSameToGrpcPkg)
|
||||
functions, err := g.genFunction(proto.PbPackage, serviceName, service, isCallPkgSameToGrpcPkg)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -142,15 +168,6 @@ func (g *Generator) genCallInCompatibility(ctx DirContext, proto parser.Proto,
|
||||
return err
|
||||
}
|
||||
|
||||
alias := collection.NewSet()
|
||||
if !isCallPkgSameToPbPkg {
|
||||
for _, item := range proto.Message {
|
||||
msgName := getMessageName(*item.Message)
|
||||
alias.AddStr(fmt.Sprintf("%s = %s", parser.CamelCase(msgName),
|
||||
fmt.Sprintf("%s.%s", proto.PbPackage, parser.CamelCase(msgName))))
|
||||
}
|
||||
}
|
||||
|
||||
pbPackage := fmt.Sprintf(`"%s"`, ctx.GetPb().Package)
|
||||
protoGoPackage := fmt.Sprintf(`"%s"`, ctx.GetProtoGo().Package)
|
||||
if isCallPkgSameToGrpcPkg {
|
||||
@@ -166,7 +183,7 @@ func (g *Generator) genCallInCompatibility(ctx DirContext, proto parser.Proto,
|
||||
"filePackage": dir.Base,
|
||||
"pbPackage": pbPackage,
|
||||
"protoGoPackage": protoGoPackage,
|
||||
"serviceName": stringx.From(service.Name).ToCamel(),
|
||||
"serviceName": serviceName,
|
||||
"functions": strings.Join(functions, pathx.NL),
|
||||
"interface": strings.Join(iFunctions, pathx.NL),
|
||||
}, filename, true)
|
||||
@@ -194,7 +211,7 @@ func getMessageName(msg proto.Message) string {
|
||||
return strings.Join(list, "_")
|
||||
}
|
||||
|
||||
func (g *Generator) genFunction(goPackage string, service parser.Service,
|
||||
func (g *Generator) genFunction(goPackage string, serviceName string, service parser.Service,
|
||||
isCallPkgSameToGrpcPkg bool) ([]string, error) {
|
||||
functions := make([]string, 0)
|
||||
|
||||
@@ -212,7 +229,7 @@ func (g *Generator) genFunction(goPackage string, service parser.Service,
|
||||
parser.CamelCase(rpc.Name), "Client")
|
||||
}
|
||||
buffer, err := util.With("sharedFn").Parse(text).Execute(map[string]interface{}{
|
||||
"serviceName": stringx.From(service.Name).ToCamel(),
|
||||
"serviceName": serviceName,
|
||||
"rpcServiceName": parser.CamelCase(service.Name),
|
||||
"method": parser.CamelCase(rpc.Name),
|
||||
"package": goPackage,
|
||||
|
||||
@@ -21,9 +21,11 @@ func NewGenerator(style string, verbose bool) *Generator {
|
||||
if err != nil {
|
||||
log.Fatalln(err)
|
||||
}
|
||||
log := console.NewColorConsole(verbose)
|
||||
|
||||
colorLogger := console.NewColorConsole(verbose)
|
||||
|
||||
return &Generator{
|
||||
log: log,
|
||||
log: colorLogger,
|
||||
cfg: cfg,
|
||||
verbose: verbose,
|
||||
}
|
||||
|
||||
@@ -8,8 +8,11 @@ import (
|
||||
"github.com/zeromicro/go-zero/zrpc/internal/auth"
|
||||
"github.com/zeromicro/go-zero/zrpc/internal/clientinterceptors"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/keepalive"
|
||||
)
|
||||
|
||||
const defaultClientKeepaliveTime = 20 * time.Second
|
||||
|
||||
var (
|
||||
// WithDialOption is an alias of internal.WithDialOption.
|
||||
WithDialOption = internal.WithDialOption
|
||||
@@ -62,6 +65,11 @@ func NewClient(c RpcClientConf, options ...ClientOption) (Client, error) {
|
||||
if c.Timeout > 0 {
|
||||
opts = append(opts, WithTimeout(time.Duration(c.Timeout)*time.Millisecond))
|
||||
}
|
||||
if c.KeepaliveTime > 0 {
|
||||
opts = append(opts, WithDialOption(grpc.WithKeepaliveParams(keepalive.ClientParameters{
|
||||
Time: c.KeepaliveTime,
|
||||
})))
|
||||
}
|
||||
|
||||
opts = append(opts, options...)
|
||||
|
||||
@@ -90,6 +98,12 @@ func NewClientWithTarget(target string, opts ...ClientOption) (Client, error) {
|
||||
Timeout: true,
|
||||
}
|
||||
|
||||
opts = append([]ClientOption{
|
||||
WithDialOption(grpc.WithKeepaliveParams(keepalive.ClientParameters{
|
||||
Time: defaultClientKeepaliveTime,
|
||||
})),
|
||||
}, opts...)
|
||||
|
||||
return internal.NewClient(target, middlewares, opts...)
|
||||
}
|
||||
|
||||
|
||||
@@ -113,10 +113,11 @@ func TestDepositServer_Deposit(t *testing.T) {
|
||||
)
|
||||
tarConfClient := MustNewClient(
|
||||
RpcClientConf{
|
||||
Target: "foo",
|
||||
App: "foo",
|
||||
Token: "bar",
|
||||
Timeout: 1000,
|
||||
Target: "foo",
|
||||
App: "foo",
|
||||
Token: "bar",
|
||||
Timeout: 1000,
|
||||
KeepaliveTime: time.Second * 15,
|
||||
Middlewares: ClientMiddlewaresConf{
|
||||
Trace: true,
|
||||
Duration: true,
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
package zrpc
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/zeromicro/go-zero/core/discov"
|
||||
"github.com/zeromicro/go-zero/core/service"
|
||||
"github.com/zeromicro/go-zero/core/stores/redis"
|
||||
@@ -14,6 +16,19 @@ type (
|
||||
// ServerMiddlewaresConf defines whether to use server middlewares.
|
||||
ServerMiddlewaresConf = internal.ServerMiddlewaresConf
|
||||
|
||||
// A RpcClientConf is a rpc client config.
|
||||
RpcClientConf struct {
|
||||
Etcd discov.EtcdConf `json:",optional,inherit"`
|
||||
Endpoints []string `json:",optional"`
|
||||
Target string `json:",optional"`
|
||||
App string `json:",optional"`
|
||||
Token string `json:",optional"`
|
||||
NonBlock bool `json:",optional"`
|
||||
Timeout int64 `json:",default=2000"`
|
||||
KeepaliveTime time.Duration `json:",default=20s"`
|
||||
Middlewares ClientMiddlewaresConf
|
||||
}
|
||||
|
||||
// A RpcServerConf is a rpc server config.
|
||||
RpcServerConf struct {
|
||||
service.ServiceConf
|
||||
@@ -29,18 +44,6 @@ type (
|
||||
Health bool `json:",default=true"`
|
||||
Middlewares ServerMiddlewaresConf
|
||||
}
|
||||
|
||||
// A RpcClientConf is a rpc client config.
|
||||
RpcClientConf struct {
|
||||
Etcd discov.EtcdConf `json:",optional,inherit"`
|
||||
Endpoints []string `json:",optional"`
|
||||
Target string `json:",optional"`
|
||||
App string `json:",optional"`
|
||||
Token string `json:",optional"`
|
||||
NonBlock bool `json:",optional"`
|
||||
Timeout int64 `json:",default=2000"`
|
||||
Middlewares ClientMiddlewaresConf
|
||||
}
|
||||
)
|
||||
|
||||
// NewDirectClientConf returns a RpcClientConf.
|
||||
|
||||
@@ -10,10 +10,35 @@ import (
|
||||
)
|
||||
|
||||
func TestRpcClientConf(t *testing.T) {
|
||||
conf := NewDirectClientConf([]string{"localhost:1234"}, "foo", "bar")
|
||||
assert.True(t, conf.HasCredential())
|
||||
conf = NewEtcdClientConf([]string{"localhost:1234", "localhost:5678"}, "key", "foo", "bar")
|
||||
assert.True(t, conf.HasCredential())
|
||||
t.Run("direct", func(t *testing.T) {
|
||||
conf := NewDirectClientConf([]string{"localhost:1234"}, "foo", "bar")
|
||||
assert.True(t, conf.HasCredential())
|
||||
})
|
||||
|
||||
t.Run("etcd", func(t *testing.T) {
|
||||
conf := NewEtcdClientConf([]string{"localhost:1234", "localhost:5678"},
|
||||
"key", "foo", "bar")
|
||||
assert.True(t, conf.HasCredential())
|
||||
})
|
||||
|
||||
t.Run("etcd with account", func(t *testing.T) {
|
||||
conf := NewEtcdClientConf([]string{"localhost:1234", "localhost:5678"},
|
||||
"key", "foo", "bar")
|
||||
conf.Etcd.User = "user"
|
||||
conf.Etcd.Pass = "pass"
|
||||
_, err := conf.BuildTarget()
|
||||
assert.NoError(t, err)
|
||||
})
|
||||
|
||||
t.Run("etcd with tls", func(t *testing.T) {
|
||||
conf := NewEtcdClientConf([]string{"localhost:1234", "localhost:5678"},
|
||||
"key", "foo", "bar")
|
||||
conf.Etcd.CertFile = "cert"
|
||||
conf.Etcd.CertKeyFile = "key"
|
||||
conf.Etcd.CACertFile = "ca"
|
||||
_, err := conf.BuildTarget()
|
||||
assert.Error(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
func TestRpcServerConf(t *testing.T) {
|
||||
|
||||
@@ -4,9 +4,21 @@ import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/zeromicro/go-zero/core/discov"
|
||||
"github.com/zeromicro/go-zero/core/netx"
|
||||
)
|
||||
|
||||
func TestNewRpcPubServer(t *testing.T) {
|
||||
s, err := NewRpcPubServer(discov.EtcdConf{
|
||||
User: "user",
|
||||
Pass: "pass",
|
||||
}, "", ServerMiddlewaresConf{})
|
||||
assert.NoError(t, err)
|
||||
assert.NotPanics(t, func() {
|
||||
s.Start(nil)
|
||||
})
|
||||
}
|
||||
|
||||
func TestFigureOutListenOn(t *testing.T) {
|
||||
tests := []struct {
|
||||
input string
|
||||
@@ -24,6 +36,10 @@ func TestFigureOutListenOn(t *testing.T) {
|
||||
input: ":8080",
|
||||
expect: netx.InternalIp() + ":8080",
|
||||
},
|
||||
{
|
||||
input: "",
|
||||
expect: netx.InternalIp(),
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
|
||||
@@ -59,12 +59,10 @@ func (s *rpcServer) Start(register RegisterFn) error {
|
||||
return err
|
||||
}
|
||||
|
||||
unaryInterceptors := s.buildUnaryInterceptors()
|
||||
unaryInterceptors = append(unaryInterceptors, s.unaryInterceptors...)
|
||||
streamInterceptors := s.buildStreamInterceptors()
|
||||
streamInterceptors = append(streamInterceptors, s.streamInterceptors...)
|
||||
options := append(s.options, grpc.ChainUnaryInterceptor(unaryInterceptors...),
|
||||
grpc.ChainStreamInterceptor(streamInterceptors...))
|
||||
unaryInterceptorOption := grpc.ChainUnaryInterceptor(s.buildUnaryInterceptors()...)
|
||||
streamInterceptorOption := grpc.ChainStreamInterceptor(s.buildStreamInterceptors()...)
|
||||
|
||||
options := append(s.options, unaryInterceptorOption, streamInterceptorOption)
|
||||
server := grpc.NewServer(options...)
|
||||
register(server)
|
||||
|
||||
@@ -102,7 +100,7 @@ func (s *rpcServer) buildStreamInterceptors() []grpc.StreamServerInterceptor {
|
||||
interceptors = append(interceptors, serverinterceptors.StreamBreakerInterceptor)
|
||||
}
|
||||
|
||||
return interceptors
|
||||
return append(interceptors, s.streamInterceptors...)
|
||||
}
|
||||
|
||||
func (s *rpcServer) buildUnaryInterceptors() []grpc.UnaryServerInterceptor {
|
||||
@@ -124,7 +122,7 @@ func (s *rpcServer) buildUnaryInterceptors() []grpc.UnaryServerInterceptor {
|
||||
interceptors = append(interceptors, serverinterceptors.UnaryBreakerInterceptor)
|
||||
}
|
||||
|
||||
return interceptors
|
||||
return append(interceptors, s.unaryInterceptors...)
|
||||
}
|
||||
|
||||
// WithMetrics returns a func that sets metrics to a Server.
|
||||
|
||||
@@ -1,10 +1,12 @@
|
||||
package internal
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/zeromicro/go-zero/core/proc"
|
||||
"github.com/zeromicro/go-zero/core/stat"
|
||||
"github.com/zeromicro/go-zero/zrpc/internal/mock"
|
||||
"google.golang.org/grpc"
|
||||
@@ -18,12 +20,13 @@ func TestRpcServer(t *testing.T) {
|
||||
Stat: true,
|
||||
Prometheus: true,
|
||||
Breaker: true,
|
||||
}, WithMetrics(metrics))
|
||||
}, WithMetrics(metrics), WithRpcHealth(true))
|
||||
server.SetName("mock")
|
||||
var wg sync.WaitGroup
|
||||
var wg, wgDone sync.WaitGroup
|
||||
var grpcServer *grpc.Server
|
||||
var lock sync.Mutex
|
||||
wg.Add(1)
|
||||
wgDone.Add(1)
|
||||
go func() {
|
||||
err := server.Start(func(server *grpc.Server) {
|
||||
lock.Lock()
|
||||
@@ -33,12 +36,16 @@ func TestRpcServer(t *testing.T) {
|
||||
wg.Done()
|
||||
})
|
||||
assert.Nil(t, err)
|
||||
wgDone.Done()
|
||||
}()
|
||||
|
||||
wg.Wait()
|
||||
lock.Lock()
|
||||
grpcServer.GracefulStop()
|
||||
lock.Unlock()
|
||||
|
||||
proc.WrapUp()
|
||||
wgDone.Wait()
|
||||
}
|
||||
|
||||
func TestRpcServer_WithBadAddress(t *testing.T) {
|
||||
@@ -48,10 +55,124 @@ func TestRpcServer_WithBadAddress(t *testing.T) {
|
||||
Stat: true,
|
||||
Prometheus: true,
|
||||
Breaker: true,
|
||||
})
|
||||
}, WithRpcHealth(true))
|
||||
server.SetName("mock")
|
||||
err := server.Start(func(server *grpc.Server) {
|
||||
mock.RegisterDepositServiceServer(server, new(mock.DepositServer))
|
||||
})
|
||||
assert.NotNil(t, err)
|
||||
|
||||
proc.WrapUp()
|
||||
}
|
||||
|
||||
func TestRpcServer_buildUnaryInterceptor(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
r *rpcServer
|
||||
len int
|
||||
}{
|
||||
{
|
||||
name: "empty",
|
||||
r: &rpcServer{
|
||||
baseRpcServer: &baseRpcServer{},
|
||||
},
|
||||
len: 0,
|
||||
},
|
||||
{
|
||||
name: "custom",
|
||||
r: &rpcServer{
|
||||
baseRpcServer: &baseRpcServer{
|
||||
unaryInterceptors: []grpc.UnaryServerInterceptor{
|
||||
func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo,
|
||||
handler grpc.UnaryHandler) (interface{}, error) {
|
||||
return nil, nil
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
len: 1,
|
||||
},
|
||||
{
|
||||
name: "middleware",
|
||||
r: &rpcServer{
|
||||
baseRpcServer: &baseRpcServer{
|
||||
unaryInterceptors: []grpc.UnaryServerInterceptor{
|
||||
func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo,
|
||||
handler grpc.UnaryHandler) (interface{}, error) {
|
||||
return nil, nil
|
||||
},
|
||||
},
|
||||
},
|
||||
middlewares: ServerMiddlewaresConf{
|
||||
Trace: true,
|
||||
Recover: true,
|
||||
Stat: true,
|
||||
Prometheus: true,
|
||||
Breaker: true,
|
||||
},
|
||||
},
|
||||
len: 6,
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
assert.Equal(t, test.len, len(test.r.buildUnaryInterceptors()))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRpcServer_buildStreamInterceptor(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
r *rpcServer
|
||||
len int
|
||||
}{
|
||||
{
|
||||
name: "empty",
|
||||
r: &rpcServer{
|
||||
baseRpcServer: &baseRpcServer{},
|
||||
},
|
||||
len: 0,
|
||||
},
|
||||
{
|
||||
name: "custom",
|
||||
r: &rpcServer{
|
||||
baseRpcServer: &baseRpcServer{
|
||||
streamInterceptors: []grpc.StreamServerInterceptor{
|
||||
func(srv interface{}, ss grpc.ServerStream, info *grpc.StreamServerInfo,
|
||||
handler grpc.StreamHandler) error {
|
||||
return nil
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
len: 1,
|
||||
},
|
||||
{
|
||||
name: "middleware",
|
||||
r: &rpcServer{
|
||||
baseRpcServer: &baseRpcServer{
|
||||
streamInterceptors: []grpc.StreamServerInterceptor{
|
||||
func(srv interface{}, ss grpc.ServerStream, info *grpc.StreamServerInfo,
|
||||
handler grpc.StreamHandler) error {
|
||||
return nil
|
||||
},
|
||||
},
|
||||
},
|
||||
middlewares: ServerMiddlewaresConf{
|
||||
Trace: true,
|
||||
Recover: true,
|
||||
Breaker: true,
|
||||
},
|
||||
},
|
||||
len: 4,
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
assert.Equal(t, test.len, len(test.r.buildStreamInterceptors()))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -48,5 +48,6 @@ func ParseTarget(target resolver.Target) (Service, error) {
|
||||
} else {
|
||||
service.Name = endpoints
|
||||
}
|
||||
|
||||
return service, nil
|
||||
}
|
||||
|
||||
@@ -18,6 +18,15 @@ func TestKubeBuilder_Build(t *testing.T) {
|
||||
var b kubeBuilder
|
||||
u, err := url.Parse(fmt.Sprintf("%s://%s", KubernetesScheme, "a,b"))
|
||||
assert.NoError(t, err)
|
||||
|
||||
_, err = b.Build(resolver.Target{
|
||||
URL: *u,
|
||||
}, nil, resolver.BuildOptions{})
|
||||
assert.Error(t, err)
|
||||
|
||||
u, err = url.Parse(fmt.Sprintf("%s://%s:9100/a:b:c", KubernetesScheme, "a,b,c,d"))
|
||||
assert.NoError(t, err)
|
||||
|
||||
_, err = b.Build(resolver.Target{
|
||||
URL: *u,
|
||||
}, nil, resolver.BuildOptions{})
|
||||
|
||||
@@ -3,15 +3,19 @@ package internal
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"google.golang.org/grpc/resolver"
|
||||
"google.golang.org/grpc/serviceconfig"
|
||||
)
|
||||
|
||||
func TestNopResolver(t *testing.T) {
|
||||
// make sure ResolveNow & Close don't panic
|
||||
var r nopResolver
|
||||
r.ResolveNow(resolver.ResolveNowOptions{})
|
||||
r.Close()
|
||||
assert.NotPanics(t, func() {
|
||||
RegisterResolver()
|
||||
// make sure ResolveNow & Close don't panic
|
||||
var r nopResolver
|
||||
r.ResolveNow(resolver.ResolveNowOptions{})
|
||||
r.Close()
|
||||
})
|
||||
}
|
||||
|
||||
type mockedClientConn struct {
|
||||
|
||||
@@ -7,6 +7,7 @@ import (
|
||||
"github.com/zeromicro/go-zero/core/load"
|
||||
"github.com/zeromicro/go-zero/core/logx"
|
||||
"github.com/zeromicro/go-zero/core/stat"
|
||||
"github.com/zeromicro/go-zero/core/stores/redis"
|
||||
"github.com/zeromicro/go-zero/zrpc/internal"
|
||||
"github.com/zeromicro/go-zero/zrpc/internal/auth"
|
||||
"github.com/zeromicro/go-zero/zrpc/internal/serverinterceptors"
|
||||
@@ -120,7 +121,12 @@ func setupInterceptors(server internal.Server, c RpcServerConf, metrics *stat.Me
|
||||
}
|
||||
|
||||
if c.Auth {
|
||||
authenticator, err := auth.NewAuthenticator(c.Redis.NewRedis(), c.Redis.Key, c.StrictControl)
|
||||
rds, err := redis.NewRedis(c.Redis.RedisConf)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
authenticator, err := auth.NewAuthenticator(rds, c.Redis.Key, c.StrictControl)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/alicebob/miniredis/v2"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/zeromicro/go-zero/core/discov"
|
||||
"github.com/zeromicro/go-zero/core/logx"
|
||||
@@ -16,12 +17,16 @@ import (
|
||||
)
|
||||
|
||||
func TestServer_setupInterceptors(t *testing.T) {
|
||||
rds, err := miniredis.Run()
|
||||
assert.NoError(t, err)
|
||||
defer rds.Close()
|
||||
|
||||
server := new(mockedServer)
|
||||
err := setupInterceptors(server, RpcServerConf{
|
||||
conf := RpcServerConf{
|
||||
Auth: true,
|
||||
Redis: redis.RedisKeyConf{
|
||||
RedisConf: redis.RedisConf{
|
||||
Host: "any",
|
||||
Host: rds.Addr(),
|
||||
Type: redis.NodeType,
|
||||
},
|
||||
Key: "foo",
|
||||
@@ -35,10 +40,15 @@ func TestServer_setupInterceptors(t *testing.T) {
|
||||
Prometheus: true,
|
||||
Breaker: true,
|
||||
},
|
||||
}, new(stat.Metrics))
|
||||
}
|
||||
err = setupInterceptors(server, conf, new(stat.Metrics))
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, 3, len(server.unaryInterceptors))
|
||||
assert.Equal(t, 1, len(server.streamInterceptors))
|
||||
|
||||
rds.SetError("mock error")
|
||||
err = setupInterceptors(server, conf, new(stat.Metrics))
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
func TestServer(t *testing.T) {
|
||||
|
||||
Reference in New Issue
Block a user