Compare commits

...

72 Commits

Author SHA1 Message Date
kevin
2ea0a843f8 chore: remove any keywords 2023-03-04 20:54:26 +08:00
Kevin Wan
9e0e01b2bc chore: add tests (#2960) 2023-03-04 20:38:50 +08:00
yangjinheng
af50a80d01 timeout writer add hijack 2023-03-04 20:38:45 +08:00
yangjinheng
703fb8d970 Update timeouthandler.go 2023-03-04 20:38:40 +08:00
MarkJoyMa
e964e530e1 x 2023-03-04 20:32:21 +08:00
MarkJoyMa
52265087d1 x 2023-03-04 20:32:16 +08:00
MarkJoyMa
b4c2677eb9 add ut 2023-03-04 20:32:10 +08:00
MarkJoyMa
30296fb1ca feat: conf add FillDefault func 2023-03-04 20:31:44 +08:00
zhoumingji
356c80defd Fix bug in dartgen: The property 'isEmpty' can't be unconditionally accessed because the receiver can be 'null' 2023-03-04 20:31:38 +08:00
zhoumingji
8c31525378 Fix bug in dartgen: Increase the processing logic when route.RequestType is empty 2023-03-04 20:31:30 +08:00
cui fliter
2cf09f3c36 fix functiom name
Signed-off-by: cui fliter <imcusg@gmail.com>
2023-03-04 20:31:20 +08:00
Kevin Wan
d41e542c92 feat: support grpc client keepalive config (#2950) 2023-03-04 20:30:31 +08:00
tanglihao
265a24ac6d fix code format style use const config.DefaultFormat 2023-03-04 20:30:21 +08:00
tanglihao
7d88fc39dc fix log name conflict 2023-03-04 20:30:16 +08:00
anqiansong
6957b6a344 format code 2023-03-04 20:30:10 +08:00
anqiansong
bca6a230c8 remove unused code 2023-03-04 20:30:04 +08:00
anqiansong
cc8413d683 remove unused code 2023-03-04 20:29:56 +08:00
anqiansong
3842283fa8 Fix #2879 2023-03-04 20:29:41 +08:00
qiying.wang
fe13a533f5 chore: remove redundant prefix of "error: " in error creation 2023-03-04 20:26:40 +08:00
qiying.wang
7a327ccda4 chore: add tests for logc debug 2023-03-04 20:25:52 +08:00
qiying.wang
06e4507406 feat: add debug log for logc 2023-03-04 20:25:27 +08:00
kevin
8794d5b753 chore: add comments 2023-03-04 20:25:21 +08:00
kevin
9bfa63d995 chore: add more tests 2023-03-04 20:25:15 +08:00
kevin
a432b121fb chore: add more tests 2023-03-04 20:25:07 +08:00
kevin
b61c94bb66 feat: check key overwritten 2023-03-04 20:24:33 +08:00
Kevin Wan
93fcf899dc fix: config map cannot handle case-insensitive keys. (#2932)
* fix: #2922

* chore: rename const

* feat: support anonymous map field

* feat: support anonymous map field
2023-03-04 20:23:53 +08:00
Kevin Wan
9f4b3bae92 fix: #2899, using autoscaling/v2beta2 instead of v2beta1 (#2900)
* fix: #2899, using autoscaling/v2 instead of v2beta1

* chore: change hpa definition

---------

Co-authored-by: kevin.wan <kevin.wan@yijinin.com>
2023-03-04 20:22:27 +08:00
Kevin Wan
805cb87d98 chore: refine rest validator (#2928)
* chore: refine rest validator

* chore: add more tests

* chore: reformat code

* chore: add comments
2023-03-04 20:22:10 +08:00
Qiying Wang
366131640e feat: add configurable validator for httpx.Parse (#2923)
Co-authored-by: qiying.wang <qiying.wang@highlight.mobi>
2023-03-04 20:22:05 +08:00
Kevin Wan
956884a3ff fix: timeout not working if greater than global rest timeout (#2926) 2023-03-04 20:21:59 +08:00
raymonder jin
f571cb8af2 del unnecessary blank 2023-03-04 20:21:54 +08:00
Kevin Wan
cc5acf3b90 chore: reformat code (#2925) 2023-03-04 20:21:49 +08:00
chenquan
e1aa665443 fix: fixed the bug that old trace instances may be fetched 2023-03-04 20:21:43 +08:00
xiandong
cd357d9484 rm parseErr when kindJaeger 2023-03-04 20:21:28 +08:00
xiandong
6d4d7cbd6b rm kindJaegerUdp 2023-03-04 20:21:18 +08:00
xiandong
c593b5b531 add parseEndpoint 2023-03-04 20:20:29 +08:00
xiandong
fd5b38b07c add parseEndpoint 2023-03-04 20:20:17 +08:00
xiandong
41efb48f55 add test for Endpoint of kindJaegerUdp 2023-03-04 20:19:40 +08:00
xiandong
0ef3626839 add test for Endpoint of kindJaegerUdp 2023-03-04 20:19:34 +08:00
xiandong
77a72b16e9 add kindJaegerUdp 2023-03-04 20:19:25 +08:00
Kevin Wan
21566f1b7a chore: reformat code (#2903) 2023-03-04 20:17:35 +08:00
anqiansong
b2646e228b feat: Add request.ts (#2901)
* Add request.ts

* Update comments

* Refactor request filename
2023-03-04 20:17:21 +08:00
cong
588b883710 refactor: simplify sqlx fail fast ping and simplify miniredis setup in test (#2897)
* chore(redistest): simplify miniredis setup in test

* refactor(sqlx): simplify sqlx fail fast ping

* chore: close connection if not available
2023-03-04 20:17:16 +08:00
Kevin Wan
033910bbd8 Update readme-cn.md 2023-03-04 20:17:11 +08:00
fondoger
530dd79e3f Fix bug in dart api gen: path parameter is not replaced 2023-03-04 20:17:05 +08:00
Kevin Wan
cd5263ac75 Update readme-cn.md 2023-03-04 20:16:58 +08:00
Kevin Wan
ea3302a468 fix: test failures (#2892)
Co-authored-by: kevin.wan <kevin.wan@yijinin.com>
2023-03-04 20:16:50 +08:00
fondoger
abf15b373c Fix Dart API generation bugs; Add ability to generate API for path parameters (#2887)
* Fix bug in dartgen: Import path should match the generated api filename

* Use Route.HandlerName as generated dart API function name

Reasons:
- There is bug when using url path name as function name, because it may have invalid characters such as ":"
- Switching to HandlerName aligns with other languages such as typescript generation

* [DartGen] Add ability to generate api for url path parameters such as /path/:param
2023-03-04 20:16:44 +08:00
Kevin Wan
a865e9ee29 refactor: simplify stringx.Replacer, and avoid potential infinite loops (#2877)
* simplify replace

* backup

* refactor: simplify stringx.Replacer

* chore: add comments and const

* chore: add more tests

* chore: rename variable
2023-03-04 20:16:37 +08:00
Kevin Wan
f8292198cf Update readme-cn.md 2023-03-04 20:15:38 +08:00
Kevin Wan
016d965f56 chore: refactor (#2875) 2023-03-04 20:15:30 +08:00
dahaihu
95d7c73409 fix Replacer suffix match, and add test case (#2867)
* fix: replace shoud replace the longest match

* feat: revert bytes.Buffer to strings.Builder

* fix: loop reset nextStart

* feat: add node longest match test

* feat: add replacer suffix match test case

* feat: multiple match

* fix: partial match ends

* fix: replace look back upon error

* feat: rm unnecessary branch

---------

Co-authored-by: hudahai <hscxrzs@gmail.com>
Co-authored-by: hushichang <hushichang@sensetime.com>
2023-03-04 20:15:25 +08:00
Kevin Wan
939ef2a181 chore: add more tests (#2873) 2023-03-04 20:15:18 +08:00
Kevin Wan
f0b8dd45fe fix: test failure (#2874) 2023-03-04 20:15:08 +08:00
Mikael
0ba9335b04 only unmashal public variables (#2872)
* only unmashal public variables

* only unmashal public variables

* only unmashal public variables

* only unmashal public variables
2023-03-04 20:15:01 +08:00
Kevin Wan
04f181f0b4 chore: add more tests (#2866)
* chore: add more tests

* chore: add more tests

* chore: fix test failure
2023-03-04 20:14:54 +08:00
hudahai
89f841c126 fix: loop reset nextStart 2023-03-04 20:14:48 +08:00
hudahai
d785c8c377 feat: revert bytes.Buffer to strings.Builder 2023-03-04 20:14:41 +08:00
hudahai
687a1d15da fix: replace shoud replace the longest match 2023-03-04 20:14:35 +08:00
Kevin Wan
aaa974e1ad Update readme-cn.md 2023-03-04 20:14:22 +08:00
Kevin Wan
2779568ccf fix: conf anonymous overlay problem (#2847) 2023-03-04 20:14:10 +08:00
Kevin Wan
f7d50ae626 Update readme-cn.md 2023-03-04 20:14:01 +08:00
Kevin Wan
33594ea350 Chore/rewire (#2836)
* fix: problem on name overlaping in config (#2820)

* chore: fix missing funcs on windows (#2825)

* chore: add more tests (#2812)

* chore: add more tests

* chore: add more tests

* chore: add more tests (#2814)

* chore: add more tests (#2815)

* chore: add more tests

* chore: add more tests

* chore: add more tests

* chore: add more tests

* chore: add more tests

* chore: add more tests

* feat: upgrade go to v1.18 (#2817)

* feat: upgrade go to v1.18

* feat: upgrade go to v1.18

* chore: change interface{} to any (#2818)

* chore: change interface{} to any

* chore: update goctl version to 1.5.0

* chore: update goctl deps

* chore: update goctl interface{} to any (#2819)

* chore: update goctl interface{} to any

* chore: update goctl interface{} to any

* chore(deps): bump google.golang.org/grpc from 1.52.0 to 1.52.3 (#2823)

* support custom maxBytes in API file (#2822)

* feat: mapreduce generic version (#2827)

* feat: mapreduce generic version

* fix: gateway mr type issue

---------

Co-authored-by: kevin.wan <kevin.wan@yijinin.com>

* feat: add MustNewRedis (#2824)

* feat: add MustNewRedis

* feat: add MustNewRedis

* feat: add MustNewRedis

* x

* x

* fix ut

* x

* x

* x

* x

* x

* chore: improve codecov (#2828)

* feat: converge grpc interceptor processing (#2830)

* feat: converge grpc interceptor processing

* x

* x

* chore(deps): bump go.opentelemetry.io/otel/exporters/zipkin (#2831)

* chore(deps): bump go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracehttp (#2833)

Bumps [go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracehttp](https://github.com/open-telemetry/opentelemetry-go) from 1.11.2 to 1.12.0.
- [Release notes](https://github.com/open-telemetry/opentelemetry-go/releases)
- [Changelog](https://github.com/open-telemetry/opentelemetry-go/blob/main/CHANGELOG.md)
- [Commits](https://github.com/open-telemetry/opentelemetry-go/compare/v1.11.2...v1.12.0)

---
updated-dependencies:
- dependency-name: go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracehttp
  dependency-type: direct:production
  update-type: version-update:semver-minor
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>

* chore(deps): bump go.opentelemetry.io/otel/exporters/jaeger (#2832)

Bumps [go.opentelemetry.io/otel/exporters/jaeger](https://github.com/open-telemetry/opentelemetry-go) from 1.11.2 to 1.12.0.
- [Release notes](https://github.com/open-telemetry/opentelemetry-go/releases)
- [Changelog](https://github.com/open-telemetry/opentelemetry-go/blob/main/CHANGELOG.md)
- [Commits](https://github.com/open-telemetry/opentelemetry-go/compare/v1.11.2...v1.12.0)

---
updated-dependencies:
- dependency-name: go.opentelemetry.io/otel/exporters/jaeger
  dependency-type: direct:production
  update-type: version-update:semver-minor
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>

---------

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
Co-authored-by: Xiaoju Jiang <44432198+jiang4869@users.noreply.github.com>
Co-authored-by: kevin.wan <kevin.wan@yijinin.com>
Co-authored-by: MarkJoyMa <64180138+MarkJoyMa@users.noreply.github.com>
2023-03-04 20:13:37 +08:00
MarkJoyMa
ee2ec974c4 feat: converge grpc interceptor processing (#2830)
* feat: converge grpc interceptor processing

* x

* x
2023-03-04 20:12:30 +08:00
Kevin Wan
fd2f2f0f54 chore: improve codecov (#2828) 2023-03-04 20:12:16 +08:00
MarkJoyMa
86a2429d7d feat: add MustNewRedis (#2824)
* feat: add MustNewRedis

* feat: add MustNewRedis

* feat: add MustNewRedis

* x

* x

* fix ut

* x

* x

* x

* x

* x
2023-03-04 20:12:05 +08:00
Xiaoju Jiang
e5fe5dcc50 support custom maxBytes in API file (#2822) 2023-03-04 20:11:55 +08:00
Kevin Wan
b510e7c242 chore: fix missing funcs on windows (#2825) 2023-03-04 20:11:46 +08:00
Kevin Wan
dfe92e709f fix: problem on name overlaping in config (#2820) 2023-03-04 20:11:18 +08:00
Kevin Wan
cb649cf627 chore: add more tests (#2815)
* chore: add more tests

* chore: add more tests

* chore: add more tests

* chore: add more tests

* chore: add more tests

* chore: add more tests
2023-03-04 20:11:03 +08:00
Kevin Wan
ce19a5ade6 chore: add more tests (#2814) 2023-03-04 20:10:57 +08:00
Kevin Wan
6dc56de714 chore: add more tests (#2812)
* chore: add more tests

* chore: add more tests
2023-03-04 20:09:03 +08:00
83 changed files with 2306 additions and 384 deletions

View File

@@ -213,23 +213,23 @@ func (s *Set) validate(i interface{}) {
switch i.(type) {
case int:
if s.tp != intType {
logx.Errorf("Error: element is int, but set contains elements with type %d", s.tp)
logx.Errorf("element is int, but set contains elements with type %d", s.tp)
}
case int64:
if s.tp != int64Type {
logx.Errorf("Error: element is int64, but set contains elements with type %d", s.tp)
logx.Errorf("element is int64, but set contains elements with type %d", s.tp)
}
case uint:
if s.tp != uintType {
logx.Errorf("Error: element is uint, but set contains elements with type %d", s.tp)
logx.Errorf("element is uint, but set contains elements with type %d", s.tp)
}
case uint64:
if s.tp != uint64Type {
logx.Errorf("Error: element is uint64, but set contains elements with type %d", s.tp)
logx.Errorf("element is uint64, but set contains elements with type %d", s.tp)
}
case string:
if s.tp != stringType {
logx.Errorf("Error: element is string, but set contains elements with type %d", s.tp)
logx.Errorf("element is string, but set contains elements with type %d", s.tp)
}
}
}

View File

@@ -13,17 +13,29 @@ import (
"github.com/zeromicro/go-zero/internal/encoding"
)
var loaders = map[string]func([]byte, interface{}) error{
".json": LoadFromJsonBytes,
".toml": LoadFromTomlBytes,
".yaml": LoadFromYamlBytes,
".yml": LoadFromYamlBytes,
const jsonTagKey = "json"
var (
fillDefaultUnmarshaler = mapping.NewUnmarshaler(jsonTagKey, mapping.WithDefault())
loaders = map[string]func([]byte, interface{}) error{
".json": LoadFromJsonBytes,
".toml": LoadFromTomlBytes,
".yaml": LoadFromYamlBytes,
".yml": LoadFromYamlBytes,
}
)
// children and mapField should not be both filled.
// named fields and map cannot be bound to the same field name.
type fieldInfo struct {
children map[string]*fieldInfo
mapField *fieldInfo
}
type fieldInfo struct {
name string
kind reflect.Kind
children map[string]fieldInfo
// FillDefault fills the default values for the given v,
// and the premise is that the value of v must be guaranteed to be empty.
func FillDefault(v interface{}) error {
return fillDefaultUnmarshaler.Unmarshal(map[string]interface{}{}, v)
}
// Load loads config into v from file, .json, .yaml and .yml are acceptable.
@@ -58,13 +70,17 @@ func LoadConfig(file string, v interface{}, opts ...Option) error {
// LoadFromJsonBytes loads config into v from content json bytes.
func LoadFromJsonBytes(content []byte, v interface{}) error {
info, err := buildFieldsInfo(reflect.TypeOf(v))
if err != nil {
return err
}
var m map[string]interface{}
if err := jsonx.Unmarshal(content, &m); err != nil {
return err
}
finfo := buildFieldsInfo(reflect.TypeOf(v))
lowerCaseKeyMap := toLowerCaseKeyMap(m, finfo)
lowerCaseKeyMap := toLowerCaseKeyMap(m, info)
return mapping.UnmarshalJsonMap(lowerCaseKeyMap, v, mapping.WithCanonicalKeyFunc(toLowerCase))
}
@@ -108,7 +124,63 @@ func MustLoad(path string, v interface{}, opts ...Option) {
}
}
func buildFieldsInfo(tp reflect.Type) map[string]fieldInfo {
func addOrMergeFields(info *fieldInfo, key string, child *fieldInfo) error {
if prev, ok := info.children[key]; ok {
if child.mapField != nil {
return newDupKeyError(key)
}
if err := mergeFields(prev, key, child.children); err != nil {
return err
}
} else {
info.children[key] = child
}
return nil
}
func buildAnonymousFieldInfo(info *fieldInfo, lowerCaseName string, ft reflect.Type) error {
switch ft.Kind() {
case reflect.Struct:
fields, err := buildFieldsInfo(ft)
if err != nil {
return err
}
for k, v := range fields.children {
if err = addOrMergeFields(info, k, v); err != nil {
return err
}
}
case reflect.Map:
elemField, err := buildFieldsInfo(mapping.Deref(ft.Elem()))
if err != nil {
return err
}
if _, ok := info.children[lowerCaseName]; ok {
return newDupKeyError(lowerCaseName)
}
info.children[lowerCaseName] = &fieldInfo{
children: make(map[string]*fieldInfo),
mapField: elemField,
}
default:
if _, ok := info.children[lowerCaseName]; ok {
return newDupKeyError(lowerCaseName)
}
info.children[lowerCaseName] = &fieldInfo{
children: make(map[string]*fieldInfo),
}
}
return nil
}
func buildFieldsInfo(tp reflect.Type) (*fieldInfo, error) {
tp = mapping.Deref(tp)
switch tp.Kind() {
@@ -116,61 +188,95 @@ func buildFieldsInfo(tp reflect.Type) map[string]fieldInfo {
return buildStructFieldsInfo(tp)
case reflect.Array, reflect.Slice:
return buildFieldsInfo(mapping.Deref(tp.Elem()))
case reflect.Chan, reflect.Func:
return nil, fmt.Errorf("unsupported type: %s", tp.Kind())
default:
return nil
return &fieldInfo{
children: make(map[string]*fieldInfo),
}, nil
}
}
func buildStructFieldsInfo(tp reflect.Type) map[string]fieldInfo {
info := make(map[string]fieldInfo)
func buildNamedFieldInfo(info *fieldInfo, lowerCaseName string, ft reflect.Type) error {
var finfo *fieldInfo
var err error
switch ft.Kind() {
case reflect.Struct:
finfo, err = buildFieldsInfo(ft)
if err != nil {
return err
}
case reflect.Array, reflect.Slice:
finfo, err = buildFieldsInfo(ft.Elem())
if err != nil {
return err
}
case reflect.Map:
elemInfo, err := buildFieldsInfo(mapping.Deref(ft.Elem()))
if err != nil {
return err
}
finfo = &fieldInfo{
children: make(map[string]*fieldInfo),
mapField: elemInfo,
}
default:
finfo, err = buildFieldsInfo(ft)
if err != nil {
return err
}
}
return addOrMergeFields(info, lowerCaseName, finfo)
}
func buildStructFieldsInfo(tp reflect.Type) (*fieldInfo, error) {
info := &fieldInfo{
children: make(map[string]*fieldInfo),
}
for i := 0; i < tp.NumField(); i++ {
field := tp.Field(i)
name := field.Name
lowerCaseName := toLowerCase(name)
ft := mapping.Deref(field.Type)
// flatten anonymous fields
if field.Anonymous {
if ft.Kind() == reflect.Struct {
fields := buildFieldsInfo(ft)
for k, v := range fields {
info[k] = v
}
} else {
info[lowerCaseName] = fieldInfo{
name: name,
kind: ft.Kind(),
}
if err := buildAnonymousFieldInfo(info, lowerCaseName, ft); err != nil {
return nil, err
}
continue
}
var fields map[string]fieldInfo
switch ft.Kind() {
case reflect.Struct:
fields = buildFieldsInfo(ft)
case reflect.Array, reflect.Slice:
fields = buildFieldsInfo(ft.Elem())
case reflect.Map:
fields = buildFieldsInfo(ft.Elem())
}
info[lowerCaseName] = fieldInfo{
name: name,
kind: ft.Kind(),
children: fields,
} else if err := buildNamedFieldInfo(info, lowerCaseName, ft); err != nil {
return nil, err
}
}
return info
return info, nil
}
func mergeFields(prev *fieldInfo, key string, children map[string]*fieldInfo) error {
if len(prev.children) == 0 || len(children) == 0 {
return newDupKeyError(key)
}
// merge fields
for k, v := range children {
if _, ok := prev.children[k]; ok {
return newDupKeyError(k)
}
prev.children[k] = v
}
return nil
}
func toLowerCase(s string) string {
return strings.ToLower(s)
}
func toLowerCaseInterface(v interface{}, info map[string]fieldInfo) interface{} {
func toLowerCaseInterface(v interface{}, info *fieldInfo) interface{} {
switch vv := v.(type) {
case map[string]interface{}:
return toLowerCaseKeyMap(vv, info)
@@ -185,19 +291,21 @@ func toLowerCaseInterface(v interface{}, info map[string]fieldInfo) interface{}
}
}
func toLowerCaseKeyMap(m map[string]interface{}, info map[string]fieldInfo) map[string]interface{} {
func toLowerCaseKeyMap(m map[string]interface{}, info *fieldInfo) map[string]interface{} {
res := make(map[string]interface{})
for k, v := range m {
ti, ok := info[k]
ti, ok := info.children[k]
if ok {
res[k] = toLowerCaseInterface(v, ti.children)
res[k] = toLowerCaseInterface(v, ti)
continue
}
lk := toLowerCase(k)
if ti, ok = info[lk]; ok {
res[lk] = toLowerCaseInterface(v, ti.children)
if ti, ok = info.children[lk]; ok {
res[lk] = toLowerCaseInterface(v, ti)
} else if info.mapField != nil {
res[k] = toLowerCaseInterface(v, info.mapField)
} else {
res[k] = v
}
@@ -205,3 +313,15 @@ func toLowerCaseKeyMap(m map[string]interface{}, info map[string]fieldInfo) map[
return res
}
type dupKeyError struct {
key string
}
func newDupKeyError(key string) dupKeyError {
return dupKeyError{key: key}
}
func (e dupKeyError) Error() string {
return fmt.Sprintf("duplicated key %s", e.key)
}

View File

@@ -9,6 +9,8 @@ import (
"github.com/zeromicro/go-zero/core/hash"
)
var dupErr dupKeyError
func TestLoadConfig_notExists(t *testing.T) {
assert.NotNil(t, Load("not_a_file", nil))
}
@@ -17,7 +19,7 @@ func TestLoadConfig_notRecogFile(t *testing.T) {
filename, err := fs.TempFilenameWithText("hello")
assert.Nil(t, err)
defer os.Remove(filename)
assert.NotNil(t, Load(filename, nil))
assert.NotNil(t, LoadConfig(filename, nil))
}
func TestConfigJson(t *testing.T) {
@@ -64,7 +66,7 @@ func TestLoadFromJsonBytesArray(t *testing.T) {
}
}
assert.NoError(t, LoadFromJsonBytes(input, &val))
assert.NoError(t, LoadConfigFromJsonBytes(input, &val))
var expect []string
for _, user := range val.Users {
expect = append(expect, user.Name)
@@ -172,7 +174,7 @@ B: bar`)
A string
B string
}
assert.NoError(t, LoadFromYamlBytes(text, &val1))
assert.NoError(t, LoadConfigFromYamlBytes(text, &val1))
assert.Equal(t, "foo", val1.A)
assert.Equal(t, "bar", val1.B)
assert.NoError(t, LoadFromYamlBytes(text, &val2))
@@ -384,6 +386,102 @@ func TestLoadFromYamlBytesLayers(t *testing.T) {
assert.Equal(t, "foo", val.Value)
}
func TestLoadFromYamlItemOverlay(t *testing.T) {
type (
Redis struct {
Host string
Port int
}
RedisKey struct {
Redis
Key string
}
Server struct {
Redis RedisKey
}
TestConfig struct {
Server
Redis Redis
}
)
input := []byte(`Redis:
Host: localhost
Port: 6379
Key: test
`)
var c TestConfig
assert.ErrorAs(t, LoadFromYamlBytes(input, &c), &dupErr)
}
func TestLoadFromYamlItemOverlayReverse(t *testing.T) {
type (
Redis struct {
Host string
Port int
}
RedisKey struct {
Redis
Key string
}
Server struct {
Redis Redis
}
TestConfig struct {
Redis RedisKey
Server
}
)
input := []byte(`Redis:
Host: localhost
Port: 6379
Key: test
`)
var c TestConfig
assert.ErrorAs(t, LoadFromYamlBytes(input, &c), &dupErr)
}
func TestLoadFromYamlItemOverlayWithMap(t *testing.T) {
type (
Redis struct {
Host string
Port int
}
RedisKey struct {
Redis
Key string
}
Server struct {
Redis RedisKey
}
TestConfig struct {
Server
Redis map[string]interface{}
}
)
input := []byte(`Redis:
Host: localhost
Port: 6379
Key: test
`)
var c TestConfig
assert.ErrorAs(t, LoadFromYamlBytes(input, &c), &dupErr)
}
func TestUnmarshalJsonBytesMap(t *testing.T) {
input := []byte(`{"foo":{"/mtproto.RPCTos": "bff.bff","bar":"baz"}}`)
@@ -450,6 +548,480 @@ func TestUnmarshalJsonBytesWithAnonymousField(t *testing.T) {
assert.Equal(t, Int(3), c.Int)
}
func TestUnmarshalJsonBytesWithMapValueOfStruct(t *testing.T) {
type (
Value struct {
Name string
}
Config struct {
Items map[string]Value
}
)
var inputs = [][]byte{
[]byte(`{"Items": {"Key":{"Name": "foo"}}}`),
[]byte(`{"Items": {"Key":{"Name": "foo"}}}`),
[]byte(`{"items": {"key":{"name": "foo"}}}`),
[]byte(`{"items": {"key":{"name": "foo"}}}`),
}
for _, input := range inputs {
var c Config
if assert.NoError(t, LoadFromJsonBytes(input, &c)) {
assert.Equal(t, 1, len(c.Items))
for _, v := range c.Items {
assert.Equal(t, "foo", v.Name)
}
}
}
}
func TestUnmarshalJsonBytesWithMapTypeValueOfStruct(t *testing.T) {
type (
Value struct {
Name string
}
Map map[string]Value
Config struct {
Map
}
)
var inputs = [][]byte{
[]byte(`{"Map": {"Key":{"Name": "foo"}}}`),
[]byte(`{"Map": {"Key":{"Name": "foo"}}}`),
[]byte(`{"map": {"key":{"name": "foo"}}}`),
[]byte(`{"map": {"key":{"name": "foo"}}}`),
}
for _, input := range inputs {
var c Config
if assert.NoError(t, LoadFromJsonBytes(input, &c)) {
assert.Equal(t, 1, len(c.Map))
for _, v := range c.Map {
assert.Equal(t, "foo", v.Name)
}
}
}
}
func Test_FieldOverwrite(t *testing.T) {
t.Run("normal", func(t *testing.T) {
type Base struct {
Name string
}
type St1 struct {
Base
Name2 string
}
type St2 struct {
Base
Name2 string
}
type St3 struct {
*Base
Name2 string
}
type St4 struct {
*Base
Name2 *string
}
validate := func(val interface{}) {
input := []byte(`{"Name": "hello", "Name2": "world"}`)
assert.NoError(t, LoadFromJsonBytes(input, val))
}
validate(&St1{})
validate(&St2{})
validate(&St3{})
validate(&St4{})
})
t.Run("Inherit Override", func(t *testing.T) {
type Base struct {
Name string
}
type St1 struct {
Base
Name string
}
type St2 struct {
Base
Name int
}
type St3 struct {
*Base
Name int
}
type St4 struct {
*Base
Name *string
}
validate := func(val interface{}) {
input := []byte(`{"Name": "hello"}`)
err := LoadFromJsonBytes(input, val)
assert.ErrorAs(t, err, &dupErr)
assert.Equal(t, newDupKeyError("name").Error(), err.Error())
}
validate(&St1{})
validate(&St2{})
validate(&St3{})
validate(&St4{})
})
t.Run("Inherit more", func(t *testing.T) {
type Base1 struct {
Name string
}
type St0 struct {
Base1
Name string
}
type St1 struct {
St0
Name string
}
type St2 struct {
St0
Name int
}
type St3 struct {
*St0
Name int
}
type St4 struct {
*St0
Name *int
}
validate := func(val interface{}) {
input := []byte(`{"Name": "hello"}`)
err := LoadFromJsonBytes(input, val)
assert.ErrorAs(t, err, &dupErr)
assert.Equal(t, newDupKeyError("name").Error(), err.Error())
}
validate(&St0{})
validate(&St1{})
validate(&St2{})
validate(&St3{})
validate(&St4{})
})
}
func TestFieldOverwriteComplicated(t *testing.T) {
t.Run("double maps", func(t *testing.T) {
type (
Base1 struct {
Values map[string]string
}
Base2 struct {
Values map[string]string
}
Config struct {
Base1
Base2
}
)
var c Config
input := []byte(`{"Values": {"Key": "Value"}}`)
assert.ErrorAs(t, LoadFromJsonBytes(input, &c), &dupErr)
})
t.Run("merge children", func(t *testing.T) {
type (
Inner1 struct {
Name string
}
Inner2 struct {
Age int
}
Base1 struct {
Inner Inner1
}
Base2 struct {
Inner Inner2
}
Config struct {
Base1
Base2
}
)
var c Config
input := []byte(`{"Inner": {"Name": "foo", "Age": 10}}`)
if assert.NoError(t, LoadFromJsonBytes(input, &c)) {
assert.Equal(t, "foo", c.Base1.Inner.Name)
assert.Equal(t, 10, c.Base2.Inner.Age)
}
})
t.Run("overwritten maps", func(t *testing.T) {
type (
Inner struct {
Map map[string]string
}
Config struct {
Map map[string]string
Inner
}
)
var c Config
input := []byte(`{"Inner": {"Map": {"Key": "Value"}}}`)
assert.ErrorAs(t, LoadFromJsonBytes(input, &c), &dupErr)
})
t.Run("overwritten nested maps", func(t *testing.T) {
type (
Inner struct {
Map map[string]string
}
Middle1 struct {
Map map[string]string
Inner
}
Middle2 struct {
Map map[string]string
Inner
}
Config struct {
Middle1
Middle2
}
)
var c Config
input := []byte(`{"Middle1": {"Inner": {"Map": {"Key": "Value"}}}}`)
assert.ErrorAs(t, LoadFromJsonBytes(input, &c), &dupErr)
})
t.Run("overwritten outer/inner maps", func(t *testing.T) {
type (
Inner struct {
Map map[string]string
}
Middle struct {
Inner
Map map[string]string
}
Config struct {
Middle
}
)
var c Config
input := []byte(`{"Middle": {"Inner": {"Map": {"Key": "Value"}}}}`)
assert.ErrorAs(t, LoadFromJsonBytes(input, &c), &dupErr)
})
t.Run("overwritten anonymous maps", func(t *testing.T) {
type (
Inner struct {
Map map[string]string
}
Middle struct {
Inner
Map map[string]string
}
Elem map[string]Middle
Config struct {
Elem
}
)
var c Config
input := []byte(`{"Elem": {"Key": {"Inner": {"Map": {"Key": "Value"}}}}}`)
assert.ErrorAs(t, LoadFromJsonBytes(input, &c), &dupErr)
})
t.Run("overwritten primitive and map", func(t *testing.T) {
type (
Inner struct {
Value string
}
Elem map[string]Inner
Named struct {
Elem string
}
Config struct {
Named
Elem
}
)
var c Config
input := []byte(`{"Elem": {"Key": {"Value": "Value"}}}`)
assert.ErrorAs(t, LoadFromJsonBytes(input, &c), &dupErr)
})
t.Run("overwritten map and slice", func(t *testing.T) {
type (
Inner struct {
Value string
}
Elem []Inner
Named struct {
Elem string
}
Config struct {
Named
Elem
}
)
var c Config
input := []byte(`{"Elem": {"Key": {"Value": "Value"}}}`)
assert.ErrorAs(t, LoadFromJsonBytes(input, &c), &dupErr)
})
t.Run("overwritten map and string", func(t *testing.T) {
type (
Elem string
Named struct {
Elem string
}
Config struct {
Named
Elem
}
)
var c Config
input := []byte(`{"Elem": {"Key": {"Value": "Value"}}}`)
assert.ErrorAs(t, LoadFromJsonBytes(input, &c), &dupErr)
})
}
func TestLoadNamedFieldOverwritten(t *testing.T) {
t.Run("overwritten named struct", func(t *testing.T) {
type (
Elem string
Named struct {
Elem string
}
Base struct {
Named
Elem
}
Config struct {
Val Base
}
)
var c Config
input := []byte(`{"Val": {"Elem": {"Key": {"Value": "Value"}}}}`)
assert.ErrorAs(t, LoadFromJsonBytes(input, &c), &dupErr)
})
t.Run("overwritten named []struct", func(t *testing.T) {
type (
Elem string
Named struct {
Elem string
}
Base struct {
Named
Elem
}
Config struct {
Vals []Base
}
)
var c Config
input := []byte(`{"Vals": [{"Elem": {"Key": {"Value": "Value"}}}]}`)
assert.ErrorAs(t, LoadFromJsonBytes(input, &c), &dupErr)
})
t.Run("overwritten named map[string]struct", func(t *testing.T) {
type (
Elem string
Named struct {
Elem string
}
Base struct {
Named
Elem
}
Config struct {
Vals map[string]Base
}
)
var c Config
input := []byte(`{"Vals": {"Key": {"Elem": {"Key": {"Value": "Value"}}}}}`)
assert.ErrorAs(t, LoadFromJsonBytes(input, &c), &dupErr)
})
t.Run("overwritten named *struct", func(t *testing.T) {
type (
Elem string
Named struct {
Elem string
}
Base struct {
Named
Elem
}
Config struct {
Vals *Base
}
)
var c Config
input := []byte(`{"Vals": [{"Elem": {"Key": {"Value": "Value"}}}]}`)
assert.ErrorAs(t, LoadFromJsonBytes(input, &c), &dupErr)
})
t.Run("overwritten named struct", func(t *testing.T) {
type (
Named struct {
Elem string
}
Base struct {
Named
Elem Named
}
Config struct {
Val Base
}
)
var c Config
input := []byte(`{"Val": {"Elem": "Value"}}`)
assert.ErrorAs(t, LoadFromJsonBytes(input, &c), &dupErr)
})
t.Run("overwritten named struct", func(t *testing.T) {
type Config struct {
Val chan int
}
var c Config
input := []byte(`{"Val": 1}`)
assert.Error(t, LoadFromJsonBytes(input, &c))
})
}
func createTempFile(ext, text string) (string, error) {
tmpfile, err := os.CreateTemp(os.TempDir(), hash.Md5Hex([]byte(text))+"*"+ext)
if err != nil {
@@ -467,3 +1039,55 @@ func createTempFile(ext, text string) (string, error) {
return filename, nil
}
func TestFillDefaultUnmarshal(t *testing.T) {
t.Run("nil", func(t *testing.T) {
type St struct{}
err := FillDefault(St{})
assert.Error(t, err)
})
t.Run("not nil", func(t *testing.T) {
type St struct{}
err := FillDefault(&St{})
assert.NoError(t, err)
})
t.Run("default", func(t *testing.T) {
type St struct {
A string `json:",default=a"`
B string
}
var st St
err := FillDefault(&st)
assert.NoError(t, err)
assert.Equal(t, st.A, "a")
})
t.Run("env", func(t *testing.T) {
type St struct {
A string `json:",default=a"`
B string
C string `json:",env=TEST_C"`
}
t.Setenv("TEST_C", "c")
var st St
err := FillDefault(&st)
assert.NoError(t, err)
assert.Equal(t, st.A, "a")
assert.Equal(t, st.C, "c")
})
t.Run("has vaue", func(t *testing.T) {
type St struct {
A string `json:",default=a"`
B string
}
var st = St{
A: "b",
}
err := FillDefault(&st)
assert.Error(t, err)
})
}

View File

@@ -12,7 +12,6 @@ import (
// PropertyError represents a configuration error message.
type PropertyError struct {
error
message string
}

View File

@@ -4,6 +4,7 @@
```go
type RestfulConf struct {
ServiceName string `json:",env=SERVICE_NAME"` // read from env automatically
Host string `json:",default=0.0.0.0"`
Port int
LogMode string `json:",options=[file,console]"`
@@ -21,20 +22,20 @@ type RestfulConf struct {
```yaml
# most fields are optional or have default values
Port: 8080
LogMode: console
port: 8080
logMode: console
# you can use env settings
MaxBytes: ${MAX_BYTES}
maxBytes: ${MAX_BYTES}
```
- toml example
```toml
# most fields are optional or have default values
Port = 8_080
LogMode = "console"
port = 8_080
logMode = "console"
# you can use env settings
MaxBytes = "${MAX_BYTES}"
maxBytes = "${MAX_BYTES}"
```
3. Load the config from a file:

View File

@@ -53,10 +53,11 @@ func TestChunkExecutorFlushInterval(t *testing.T) {
}
func TestChunkExecutorEmpty(t *testing.T) {
NewChunkExecutor(func(items []interface{}) {
executor := NewChunkExecutor(func(items []interface{}) {
assert.Fail(t, "should not called")
}, WithChunkBytes(10), WithFlushInterval(time.Millisecond))
time.Sleep(time.Millisecond * 100)
executor.Wait()
}
func TestChunkExecutorFlush(t *testing.T) {

View File

@@ -8,6 +8,7 @@ import (
"time"
"github.com/stretchr/testify/assert"
"github.com/zeromicro/go-zero/core/proc"
"github.com/zeromicro/go-zero/core/timex"
)
@@ -67,6 +68,7 @@ func TestPeriodicalExecutor_QuitGoroutine(t *testing.T) {
ticker.Tick()
ticker.Wait(time.Millisecond * idleRound)
assert.Equal(t, routines, runtime.NumGoroutine())
proc.Shutdown()
}
func TestPeriodicalExecutor_Bulk(t *testing.T) {

View File

@@ -27,6 +27,26 @@ func Close() error {
return logx.Close()
}
// Debug writes v into access log.
func Debug(ctx context.Context, v ...interface{}) {
getLogger(ctx).Debug(v...)
}
// Debugf writes v with format into access log.
func Debugf(ctx context.Context, format string, v ...interface{}) {
getLogger(ctx).Debugf(format, v...)
}
// Debugv writes v into access log with json content.
func Debugv(ctx context.Context, v interface{}) {
getLogger(ctx).Debugv(v)
}
// Debugw writes msg along with fields into access log.
func Debugw(ctx context.Context, msg string, fields ...LogField) {
getLogger(ctx).Debugw(msg, fields...)
}
// Error writes v into error log.
func Error(ctx context.Context, v ...interface{}) {
getLogger(ctx).Error(v...)

View File

@@ -140,6 +140,54 @@ func TestInfow(t *testing.T) {
assert.True(t, strings.Contains(buf.String(), fmt.Sprintf("%s:%d", file, line+1)))
}
func TestDebug(t *testing.T) {
var buf strings.Builder
writer := logx.NewWriter(&buf)
old := logx.Reset()
logx.SetWriter(writer)
defer logx.SetWriter(old)
file, line := getFileLine()
Debug(context.Background(), "foo")
assert.True(t, strings.Contains(buf.String(), fmt.Sprintf("%s:%d", file, line+1)))
}
func TestDebugf(t *testing.T) {
var buf strings.Builder
writer := logx.NewWriter(&buf)
old := logx.Reset()
logx.SetWriter(writer)
defer logx.SetWriter(old)
file, line := getFileLine()
Debugf(context.Background(), "foo %s", "bar")
assert.True(t, strings.Contains(buf.String(), fmt.Sprintf("%s:%d", file, line+1)))
}
func TestDebugv(t *testing.T) {
var buf strings.Builder
writer := logx.NewWriter(&buf)
old := logx.Reset()
logx.SetWriter(writer)
defer logx.SetWriter(old)
file, line := getFileLine()
Debugv(context.Background(), "foo")
assert.True(t, strings.Contains(buf.String(), fmt.Sprintf("%s:%d", file, line+1)))
}
func TestDebugw(t *testing.T) {
var buf strings.Builder
writer := logx.NewWriter(&buf)
old := logx.Reset()
logx.SetWriter(writer)
defer logx.SetWriter(old)
file, line := getFileLine()
Debugw(context.Background(), "foo", Field("a", "b"))
assert.True(t, strings.Contains(buf.String(), fmt.Sprintf("%s:%d", file, line+1)))
}
func TestMust(t *testing.T) {
assert.NotPanics(t, func() {
Must(nil)

View File

@@ -2,17 +2,34 @@ package logx
// A LogConf is a logging config.
type LogConf struct {
ServiceName string `json:",optional"`
Mode string `json:",default=console,options=[console,file,volume]"`
Encoding string `json:",default=json,options=[json,plain]"`
TimeFormat string `json:",optional"`
Path string `json:",default=logs"`
Level string `json:",default=info,options=[debug,info,error,severe]"`
MaxContentLength uint32 `json:",optional"`
Compress bool `json:",optional"`
Stat bool `json:",default=true"`
KeepDays int `json:",optional"`
StackCooldownMillis int `json:",default=100"`
// ServiceName represents the service name.
ServiceName string `json:",optional"`
// Mode represents the logging mode, default is `console`.
// console: log to console.
// file: log to file.
// volume: used in k8s, prepend the hostname to the log file name.
Mode string `json:",default=console,options=[console,file,volume]"`
// Encoding represents the encoding type, default is `json`.
// json: json encoding.
// plain: plain text encoding, typically used in development.
Encoding string `json:",default=json,options=[json,plain]"`
// TimeFormat represents the time format, default is `2006-01-02T15:04:05.000Z07:00`.
TimeFormat string `json:",optional"`
// Path represents the log file path, default is `logs`.
Path string `json:",default=logs"`
// Level represents the log level, default is `info`.
Level string `json:",default=info,options=[debug,info,error,severe]"`
// MaxContentLength represents the max content bytes, default is no limit.
MaxContentLength uint32 `json:",optional"`
// Compress represents whether to compress the log file, default is `false`.
Compress bool `json:",optional"`
// Stdout represents whether to log statistics, default is `true`.
Stat bool `json:",default=true"`
// KeepDays represents how many days the log files will be kept. Default to keep all files.
// Only take effect when Mode is `file` or `volume`, both work when Rotation is `daily` or `size`.
KeepDays int `json:",optional"`
// StackCooldownMillis represents the cooldown time for stack logging, default is 100ms.
StackCooldownMillis int `json:",default=100"`
// MaxBackups represents how many backup log files will be kept. 0 means all files will be kept forever.
// Only take effect when RotationRuleType is `size`.
// Even thougth `MaxBackups` sets 0, log files will still be removed

View File

@@ -108,7 +108,7 @@ func TestNopWriter(t *testing.T) {
w.Stack("foo")
w.Stat("foo")
w.Slow("foo")
w.Close()
_ = w.Close()
})
}

View File

@@ -47,6 +47,7 @@ type (
UnmarshalOption func(*unmarshalOptions)
unmarshalOptions struct {
fillDefault bool
fromString bool
canonicalKey func(key string) string
}
@@ -710,7 +711,14 @@ func (u *Unmarshaler) processNamedField(field reflect.StructField, value reflect
valuer := createValuer(m, opts)
mapValue, hasValue := getValue(valuer, canonicalKey)
if !hasValue {
// When fillDefault is used, m is a null value, hasValue must be false, all priority judgments fillDefault,
if u.opts.fillDefault {
if !value.IsZero() {
return fmt.Errorf("set the default value, %s must be zero", fullName)
}
return u.processNamedFieldWithoutValue(field.Type, value, opts, fullName)
} else if !hasValue {
return u.processNamedFieldWithoutValue(field.Type, value, opts, fullName)
}
@@ -801,6 +809,10 @@ func (u *Unmarshaler) processNamedFieldWithoutValue(fieldType reflect.Type, valu
}
}
if u.opts.fillDefault {
return nil
}
switch fieldKind {
case reflect.Array, reflect.Map, reflect.Slice:
if !opts.optional() {
@@ -853,7 +865,12 @@ func (u *Unmarshaler) unmarshalWithFullName(m valuerWithParent, v interface{}, f
numFields := baseType.NumField()
for i := 0; i < numFields; i++ {
if err := u.processField(baseType.Field(i), valElem.Field(i), m, fullName); err != nil {
field := baseType.Field(i)
if !field.IsExported() {
continue
}
if err := u.processField(field, valElem.Field(i), m, fullName); err != nil {
return err
}
}
@@ -868,13 +885,20 @@ func WithStringValues() UnmarshalOption {
}
}
// WithCanonicalKeyFunc customizes an Unmarshaler with Canonical Key func
// WithCanonicalKeyFunc customizes an Unmarshaler with Canonical Key func.
func WithCanonicalKeyFunc(f func(string) string) UnmarshalOption {
return func(opt *unmarshalOptions) {
opt.canonicalKey = f
}
}
// WithDefault customizes an Unmarshaler with fill default values.
func WithDefault() UnmarshalOption {
return func(opt *unmarshalOptions) {
opt.fillDefault = true
}
}
func createValuer(v valuerWithParent, opts *fieldOptionsWithContext) valuerWithParent {
if opts.inherit() {
return recursiveValuer{
@@ -1004,7 +1028,7 @@ func newInitError(name string) error {
}
func newTypeMismatchError(name string) error {
return fmt.Errorf("error: type mismatch for field %s", name)
return fmt.Errorf("type mismatch for field %s", name)
}
func readKeys(key string) []string {

View File

@@ -12,6 +12,7 @@ import (
"github.com/google/uuid"
"github.com/stretchr/testify/assert"
"github.com/zeromicro/go-zero/core/stringx"
)
@@ -793,7 +794,9 @@ func TestUnmarshalStringMapFromNotSettableValue(t *testing.T) {
}
ast := assert.New(t)
ast.Error(UnmarshalKey(m, &v))
ast.NoError(UnmarshalKey(m, &v))
assert.Empty(t, v.sort)
assert.Nil(t, v.psort)
}
func TestUnmarshalStringMapFromString(t *testing.T) {
@@ -4265,6 +4268,24 @@ func TestUnmarshalStructPtrOfPtr(t *testing.T) {
}
}
func TestUnmarshalOnlyPublicVariables(t *testing.T) {
type demo struct {
age int `key:"age"`
Name string `key:"name"`
}
m := map[string]interface{}{
"age": 3,
"name": "go-zero",
}
var in demo
if assert.NoError(t, UnmarshalKey(m, &in)) {
assert.Equal(t, 0, in.age)
assert.Equal(t, "go-zero", in.Name)
}
}
func BenchmarkDefaultValue(b *testing.B) {
for i := 0; i < b.N; i++ {
var a struct {
@@ -4364,3 +4385,56 @@ func BenchmarkUnmarshal(b *testing.B) {
UnmarshalKey(data, &an)
}
}
func TestFillDefaultUnmarshal(t *testing.T) {
fillDefaultUnmarshal := NewUnmarshaler(jsonTagKey, WithDefault())
t.Run("nil", func(t *testing.T) {
type St struct{}
err := fillDefaultUnmarshal.Unmarshal(map[string]interface{}{}, St{})
assert.Error(t, err)
})
t.Run("not nil", func(t *testing.T) {
type St struct{}
err := fillDefaultUnmarshal.Unmarshal(map[string]interface{}{}, &St{})
assert.NoError(t, err)
})
t.Run("default", func(t *testing.T) {
type St struct {
A string `json:",default=a"`
B string
}
var st St
err := fillDefaultUnmarshal.Unmarshal(map[string]interface{}{}, &st)
assert.NoError(t, err)
assert.Equal(t, st.A, "a")
})
t.Run("env", func(t *testing.T) {
type St struct {
A string `json:",default=a"`
B string
C string `json:",env=TEST_C"`
}
t.Setenv("TEST_C", "c")
var st St
err := fillDefaultUnmarshal.Unmarshal(map[string]interface{}{}, &st)
assert.NoError(t, err)
assert.Equal(t, st.A, "a")
assert.Equal(t, st.C, "c")
})
t.Run("has value", func(t *testing.T) {
type St struct {
A string `json:",default=a"`
B string
}
var st = St{
A: "b",
}
err := fillDefaultUnmarshal.Unmarshal(map[string]interface{}{}, &st)
assert.Error(t, err)
})
}

View File

@@ -5,6 +5,7 @@ import (
"github.com/prometheus/client_golang/prometheus/testutil"
"github.com/stretchr/testify/assert"
"github.com/zeromicro/go-zero/core/proc"
"github.com/zeromicro/go-zero/core/prometheus"
)
@@ -17,6 +18,9 @@ func TestNewCounterVec(t *testing.T) {
})
defer counterVec.close()
counterVecNil := NewCounterVec(nil)
counterVec.Inc("path", "code")
counterVec.Add(1, "path", "code")
proc.Shutdown()
assert.NotNil(t, counterVec)
assert.Nil(t, counterVecNil)
}

View File

@@ -5,6 +5,7 @@ import (
"github.com/prometheus/client_golang/prometheus/testutil"
"github.com/stretchr/testify/assert"
"github.com/zeromicro/go-zero/core/proc"
)
func TestNewGaugeVec(t *testing.T) {
@@ -18,6 +19,8 @@ func TestNewGaugeVec(t *testing.T) {
gaugeVecNil := NewGaugeVec(nil)
assert.NotNil(t, gaugeVec)
assert.Nil(t, gaugeVecNil)
proc.Shutdown()
}
func TestGaugeInc(t *testing.T) {

View File

@@ -6,6 +6,7 @@ import (
"github.com/prometheus/client_golang/prometheus/testutil"
"github.com/stretchr/testify/assert"
"github.com/zeromicro/go-zero/core/proc"
)
func TestNewHistogramVec(t *testing.T) {
@@ -47,4 +48,6 @@ func TestHistogramObserve(t *testing.T) {
err := testutil.CollectAndCompare(hv.histogram, strings.NewReader(metadata+val))
assert.Nil(t, err)
proc.Shutdown()
}

View File

@@ -15,5 +15,14 @@ func AddWrapUpListener(fn func()) func() {
return fn
}
// SetTimeToForceQuit does nothing on windows.
func SetTimeToForceQuit(duration time.Duration) {
}
// Shutdown does nothing on windows.
func Shutdown() {
}
// WrapUp does nothing on windows.
func WrapUp() {
}

View File

@@ -43,6 +43,16 @@ func SetTimeToForceQuit(duration time.Duration) {
delayTimeBeforeForceQuit = duration
}
// Shutdown calls the registered shutdown listeners, only for test purpose.
func Shutdown() {
shutdownListeners.notifyListeners()
}
// WrapUp wraps up the process, only for test purpose.
func WrapUp() {
wrapUpListeners.notifyListeners()
}
func gracefulStop(signals chan os.Signal) {
signal.Stop(signals)

View File

@@ -18,14 +18,14 @@ func TestShutdown(t *testing.T) {
called := AddWrapUpListener(func() {
val++
})
wrapUpListeners.notifyListeners()
WrapUp()
called()
assert.Equal(t, 1, val)
called = AddShutdownListener(func() {
val += 2
})
shutdownListeners.notifyListeners()
Shutdown()
called()
assert.Equal(t, 3, val)
}

View File

@@ -3,6 +3,7 @@ package service
import (
"testing"
"github.com/stretchr/testify/assert"
"github.com/zeromicro/go-zero/core/logx"
)
@@ -16,3 +17,15 @@ func TestServiceConf(t *testing.T) {
}
c.MustSetUp()
}
func TestServiceConfWithMetricsUrl(t *testing.T) {
c := ServiceConf{
Name: "foo",
Log: logx.LogConf{
Mode: "volume",
},
Mode: "dev",
MetricsUrl: "http://localhost:8080",
}
assert.NoError(t, c.SetUp())
}

View File

@@ -5,6 +5,7 @@ import (
"testing"
"github.com/stretchr/testify/assert"
"github.com/zeromicro/go-zero/core/proc"
)
var (
@@ -55,6 +56,7 @@ func TestServiceGroup(t *testing.T) {
}
group.Stop()
proc.Shutdown()
mutex.Lock()
defer mutex.Unlock()

View File

@@ -9,6 +9,7 @@ import (
"github.com/zeromicro/go-zero/core/errorx"
"github.com/zeromicro/go-zero/core/hash"
"github.com/zeromicro/go-zero/core/stores/redis"
"github.com/zeromicro/go-zero/core/syncx"
)
@@ -62,12 +63,12 @@ func New(c ClusterConf, barrier syncx.SingleFlight, st *Stat, errNotFound error,
}
if len(c) == 1 {
return NewNode(c[0].NewRedis(), barrier, st, errNotFound, opts...)
return NewNode(redis.MustNewRedis(c[0].RedisConf), barrier, st, errNotFound, opts...)
}
dispatcher := hash.NewConsistentHash()
for _, node := range c {
cn := NewNode(node.NewRedis(), barrier, st, errNotFound, opts...)
cn := NewNode(redis.MustNewRedis(node.RedisConf), barrier, st, errNotFound, opts...)
dispatcher.AddWithWeight(cn, node.Weight)
}

View File

@@ -163,12 +163,10 @@ func TestCache_SetDel(t *testing.T) {
r1, err := miniredis.Run()
assert.NoError(t, err)
defer r1.Close()
r1.SetError("mock error")
r2, err := miniredis.Run()
assert.NoError(t, err)
defer r2.Close()
r2.SetError("mock error")
conf := ClusterConf{
{
@@ -187,6 +185,8 @@ func TestCache_SetDel(t *testing.T) {
},
}
c := New(conf, syncx.NewSingleFlight(), NewStat("mock"), errPlaceholder)
r1.SetError("mock error")
r2.SetError("mock error")
assert.NoError(t, c.Del("a", "b", "c"))
})
}

View File

@@ -277,5 +277,6 @@ func (c cacheNode) processCache(ctx context.Context, key, data string, v interfa
func (c cacheNode) setCacheWithNotFound(ctx context.Context, key string) error {
seconds := int(math.Ceil(c.aroundDuration(c.notFoundExpiry).Seconds()))
return c.rds.SetexCtx(ctx, key, notFoundPlaceholder, seconds)
_, err := c.rds.SetnxExCtx(ctx, key, notFoundPlaceholder, seconds)
return err
}

View File

@@ -209,6 +209,35 @@ func TestCacheNode_TakeNotFound(t *testing.T) {
assert.Equal(t, errDummy, err)
}
func TestCacheNode_TakeNotFoundButChangedByOthers(t *testing.T) {
store, clean, err := redistest.CreateRedis()
assert.NoError(t, err)
defer clean()
cn := cacheNode{
rds: store,
r: rand.New(rand.NewSource(time.Now().UnixNano())),
barrier: syncx.NewSingleFlight(),
lock: new(sync.Mutex),
unstableExpiry: mathx.NewUnstable(expiryDeviation),
stat: NewStat("any"),
errNotFound: errTestNotFound,
}
var str string
err = cn.Take(&str, "any", func(v interface{}) error {
store.Set("any", "foo")
return errTestNotFound
})
assert.True(t, cn.IsNotFound(err))
val, err := store.Get("any")
if assert.NoError(t, err) {
assert.Equal(t, "foo", val)
}
assert.True(t, cn.IsNotFound(cn.Get("any", &str)))
}
func TestCacheNode_TakeWithExpire(t *testing.T) {
store, clean, err := redistest.CreateRedis()
assert.Nil(t, err)

View File

@@ -5,6 +5,7 @@ import (
"time"
"github.com/stretchr/testify/assert"
"github.com/zeromicro/go-zero/core/proc"
)
func TestNextDelay(t *testing.T) {
@@ -51,6 +52,7 @@ func TestNextDelay(t *testing.T) {
next, ok := nextDelay(test.input)
assert.Equal(t, test.ok, ok)
assert.Equal(t, test.output, next)
proc.Shutdown()
})
}
}

View File

@@ -164,7 +164,7 @@ func NewStore(c KvConf) Store {
// because Store and redis.Redis has different methods.
dispatcher := hash.NewConsistentHash()
for _, node := range c {
cn := node.NewRedis()
cn := redis.MustNewRedis(node.RedisConf)
dispatcher.AddWithWeight(cn, node.Weight)
}

View File

@@ -5,7 +5,6 @@ import (
"github.com/zeromicro/go-zero/core/trace"
"go.mongodb.org/mongo-driver/mongo"
"go.opentelemetry.io/otel"
"go.opentelemetry.io/otel/attribute"
"go.opentelemetry.io/otel/codes"
oteltrace "go.opentelemetry.io/otel/trace"
@@ -14,11 +13,8 @@ import (
var mongoCmdAttributeKey = attribute.Key("mongo.cmd")
func startSpan(ctx context.Context, cmd string) (context.Context, oteltrace.Span) {
tracer := otel.Tracer(trace.TraceName)
ctx, span := tracer.Start(ctx,
spanName,
oteltrace.WithSpanKind(oteltrace.SpanKindClient),
)
tracer := trace.TracerFromContext(ctx)
ctx, span := tracer.Start(ctx, spanName, oteltrace.WithSpanKind(oteltrace.SpanKindClient))
span.SetAttributes(mongoCmdAttributeKey.String(cmd))
return ctx, span

View File

@@ -9,6 +9,8 @@ var (
ErrEmptyType = errors.New("empty redis type")
// ErrEmptyKey is an error that indicates no redis key is set.
ErrEmptyKey = errors.New("empty redis key")
// ErrPing is an error that indicates ping failed.
ErrPing = errors.New("ping redis failed")
)
type (
@@ -28,6 +30,7 @@ type (
)
// NewRedis returns a Redis.
// Deprecated: use MustNewRedis or NewRedis instead.
func (rc RedisConf) NewRedis() *Redis {
var opts []Option
if rc.Type == ClusterType {

View File

@@ -14,7 +14,6 @@ import (
"github.com/zeromicro/go-zero/core/mapping"
"github.com/zeromicro/go-zero/core/timex"
"github.com/zeromicro/go-zero/core/trace"
"go.opentelemetry.io/otel"
"go.opentelemetry.io/otel/attribute"
"go.opentelemetry.io/otel/codes"
oteltrace "go.opentelemetry.io/otel/trace"
@@ -25,15 +24,13 @@ const spanName = "redis"
var (
startTimeKey = contextKey("startTime")
durationHook = hook{tracer: otel.Tracer(trace.TraceName)}
durationHook = hook{}
redisCmdsAttributeKey = attribute.Key("redis.cmds")
)
type (
contextKey string
hook struct {
tracer oteltrace.Tracer
}
hook struct{}
)
func (h hook) BeforeProcess(ctx context.Context, cmd red.Cmder) (context.Context, error) {
@@ -155,7 +152,9 @@ func logDuration(ctx context.Context, cmds []red.Cmder, duration time.Duration)
}
func (h hook) startSpan(ctx context.Context, cmds ...red.Cmder) context.Context {
ctx, span := h.tracer.Start(ctx,
tracer := trace.TracerFromContext(ctx)
ctx, span := tracer.Start(ctx,
spanName,
oteltrace.WithSpanKind(oteltrace.SpanKindClient),
)

View File

@@ -4,6 +4,7 @@ import (
"context"
"errors"
"fmt"
"log"
"strconv"
"time"
@@ -86,7 +87,46 @@ type (
)
// New returns a Redis with given options.
// Deprecated: use MustNewRedis or NewRedis instead.
func New(addr string, opts ...Option) *Redis {
return newRedis(addr, opts...)
}
// MustNewRedis returns a Redis with given options.
func MustNewRedis(conf RedisConf, opts ...Option) *Redis {
rds, err := NewRedis(conf, opts...)
if err != nil {
log.Fatal(err)
}
return rds
}
// NewRedis returns a Redis with given options.
func NewRedis(conf RedisConf, opts ...Option) (*Redis, error) {
if err := conf.Validate(); err != nil {
return nil, err
}
if conf.Type == ClusterType {
opts = append([]Option{Cluster()}, opts...)
}
if len(conf.Pass) > 0 {
opts = append([]Option{WithPass(conf.Pass)}, opts...)
}
if conf.Tls {
opts = append([]Option{WithTLS()}, opts...)
}
rds := newRedis(conf.Host, opts...)
if !rds.Ping() {
return nil, ErrPing
}
return rds, nil
}
func newRedis(addr string, opts ...Option) *Redis {
r := &Redis{
Addr: addr,
Type: NodeType,

View File

@@ -16,6 +16,116 @@ import (
"github.com/zeromicro/go-zero/core/stringx"
)
func TestNewRedis(t *testing.T) {
r1, err := miniredis.Run()
assert.NoError(t, err)
defer r1.Close()
r2, err := miniredis.Run()
assert.NoError(t, err)
defer r2.Close()
r2.SetError("mock")
tests := []struct {
name string
RedisConf
ok bool
redisErr bool
}{
{
name: "missing host",
RedisConf: RedisConf{
Host: "",
Type: NodeType,
Pass: "",
},
ok: false,
},
{
name: "missing type",
RedisConf: RedisConf{
Host: "localhost:6379",
Type: "",
Pass: "",
},
ok: false,
},
{
name: "ok",
RedisConf: RedisConf{
Host: r1.Addr(),
Type: NodeType,
Pass: "",
},
ok: true,
},
{
name: "ok",
RedisConf: RedisConf{
Host: r1.Addr(),
Type: ClusterType,
Pass: "",
},
ok: true,
},
{
name: "password",
RedisConf: RedisConf{
Host: r1.Addr(),
Type: NodeType,
Pass: "pw",
},
ok: true,
},
{
name: "tls",
RedisConf: RedisConf{
Host: r1.Addr(),
Type: NodeType,
Tls: true,
},
ok: true,
},
{
name: "node error",
RedisConf: RedisConf{
Host: r2.Addr(),
Type: NodeType,
Pass: "",
},
ok: true,
redisErr: true,
},
{
name: "cluster error",
RedisConf: RedisConf{
Host: r2.Addr(),
Type: ClusterType,
Pass: "",
},
ok: true,
redisErr: true,
},
}
for _, test := range tests {
t.Run(stringx.RandId(), func(t *testing.T) {
rds, err := NewRedis(test.RedisConf)
if test.ok {
if test.redisErr {
assert.Error(t, err)
assert.Nil(t, rds)
} else {
assert.NoError(t, err)
assert.NotNil(t, rds)
}
} else {
assert.Error(t, err)
}
})
}
}
func TestRedis_Decr(t *testing.T) {
runOnRedis(t, func(client *Redis) {
_, err := New(client.Addr, badType()).Decr("a")
@@ -1651,42 +1761,17 @@ func TestRedis_WithPass(t *testing.T) {
func runOnRedis(t *testing.T, fn func(client *Redis)) {
logx.Disable()
s, err := miniredis.Run()
assert.Nil(t, err)
defer func() {
client, err := clientManager.GetResource(s.Addr(), func() (io.Closer, error) {
return nil, errors.New("should already exist")
})
if err != nil {
t.Error(err)
}
if client != nil {
_ = client.Close()
}
}()
fn(New(s.Addr()))
s := miniredis.RunT(t)
fn(MustNewRedis(RedisConf{
Host: s.Addr(),
Type: NodeType,
}))
}
func runOnRedisWithError(t *testing.T, fn func(client *Redis)) {
logx.Disable()
s, err := miniredis.Run()
assert.NoError(t, err)
defer func() {
client, err := clientManager.GetResource(s.Addr(), func() (io.Closer, error) {
return nil, errors.New("should already exist")
})
if err != nil {
t.Error(err)
}
if client != nil {
_ = client.Close()
}
}()
s := miniredis.RunT(t)
s.SetError("mock error")
fn(New(s.Addr()))
}

View File

@@ -52,14 +52,11 @@ func TestSqlConn(t *testing.T) {
}
func buildConn() (mock sqlmock.Sqlmock, err error) {
_, err = connManager.GetResource(mockedDatasource, func() (io.Closer, error) {
connManager.GetResource(mockedDatasource, func() (io.Closer, error) {
var db *sql.DB
var err error
db, mock, err = sqlmock.New()
return &pingedDB{
DB: db,
}, err
return db, err
})
return
}

View File

@@ -3,7 +3,6 @@ package sqlx
import (
"database/sql"
"io"
"sync"
"time"
"github.com/zeromicro/go-zero/core/syncx"
@@ -17,43 +16,29 @@ const (
var connManager = syncx.NewResourceManager()
type pingedDB struct {
*sql.DB
once sync.Once
}
func getCachedSqlConn(driverName, server string) (*pingedDB, error) {
func getCachedSqlConn(driverName, server string) (*sql.DB, error) {
val, err := connManager.GetResource(server, func() (io.Closer, error) {
conn, err := newDBConnection(driverName, server)
if err != nil {
return nil, err
}
return &pingedDB{
DB: conn,
}, nil
return conn, nil
})
if err != nil {
return nil, err
}
return val.(*pingedDB), nil
return val.(*sql.DB), nil
}
func getSqlConn(driverName, server string) (*sql.DB, error) {
pdb, err := getCachedSqlConn(driverName, server)
conn, err := getCachedSqlConn(driverName, server)
if err != nil {
return nil, err
}
pdb.once.Do(func() {
err = pdb.Ping()
})
if err != nil {
return nil, err
}
return pdb.DB, nil
return conn, nil
}
func newDBConnection(driverName, datasource string) (*sql.DB, error) {
@@ -70,5 +55,10 @@ func newDBConnection(driverName, datasource string) (*sql.DB, error) {
conn.SetMaxOpenConns(maxOpenConns)
conn.SetConnMaxLifetime(maxLifetime)
if err := conn.Ping(); err != nil {
_ = conn.Close()
return nil, err
}
return conn, nil
}

View File

@@ -5,7 +5,6 @@ import (
"database/sql"
"github.com/zeromicro/go-zero/core/trace"
"go.opentelemetry.io/otel"
"go.opentelemetry.io/otel/attribute"
"go.opentelemetry.io/otel/codes"
oteltrace "go.opentelemetry.io/otel/trace"
@@ -14,11 +13,8 @@ import (
var sqlAttributeKey = attribute.Key("sql.method")
func startSpan(ctx context.Context, method string) (context.Context, oteltrace.Span) {
tracer := otel.Tracer(trace.TraceName)
start, span := tracer.Start(ctx,
spanName,
oteltrace.WithSpanKind(oteltrace.SpanKindClient),
)
tracer := trace.TracerFromContext(ctx)
start, span := tracer.Start(ctx, spanName, oteltrace.WithSpanKind(oteltrace.SpanKindClient))
span.SetAttributes(sqlAttributeKey.String(method))
return start, span

View File

@@ -66,7 +66,7 @@ func format(query string, args ...interface{}) (string, error) {
switch ch {
case '?':
if argIndex >= numArgs {
return "", fmt.Errorf("error: %d ? in sql, but less arguments provided", argIndex)
return "", fmt.Errorf("%d ? in sql, but less arguments provided", argIndex)
}
writeValue(&b, args[argIndex])
@@ -93,7 +93,7 @@ func format(query string, args ...interface{}) (string, error) {
index--
if index < 0 || numArgs <= index {
return "", fmt.Errorf("error: wrong index %d in sql", index)
return "", fmt.Errorf("wrong index %d in sql", index)
}
writeValue(&b, args[index])
@@ -124,7 +124,7 @@ func format(query string, args ...interface{}) (string, error) {
}
if argIndex < numArgs {
return "", fmt.Errorf("error: %d arguments provided, not matching sql", argIndex)
return "", fmt.Errorf("%d arguments provided, not matching sql", argIndex)
}
return b.String(), nil

View File

@@ -14,7 +14,6 @@ func (n *node) add(word string) {
}
nd := n
var depth int
for i, char := range chars {
if nd.children == nil {
child := new(node)
@@ -23,7 +22,6 @@ func (n *node) add(word string) {
nd = child
} else if child, ok := nd.children[char]; ok {
nd = child
depth++
} else {
child := new(node)
child.depth = i + 1

View File

@@ -1,6 +1,13 @@
package stringx
import "strings"
import (
"sort"
"strings"
)
// replace more than once to avoid overlapped keywords after replace.
// only try 2 times to avoid too many or infinite loops.
const replaceTimes = 2
type (
// Replacer interface wraps the Replace method.
@@ -30,68 +37,48 @@ func NewReplacer(mapping map[string]string) Replacer {
// Replace replaces text with given substitutes.
func (r *replacer) Replace(text string) string {
var builder strings.Builder
var start int
chars := []rune(text)
size := len(chars)
for start < size {
cur := r.node
if start > 0 {
builder.WriteString(string(chars[:start]))
}
for i := start; i < size; i++ {
child, ok := cur.children[chars[i]]
if ok {
cur = child
} else if cur == r.node {
builder.WriteRune(chars[i])
// cur already points to root, set start only
start = i + 1
continue
} else {
curDepth := cur.depth
cur = cur.fail
child, ok = cur.children[chars[i]]
if !ok {
// write this path
builder.WriteString(string(chars[i-curDepth : i+1]))
// go to root
cur = r.node
start = i + 1
continue
}
failDepth := cur.depth
// write path before jump
builder.WriteString(string(chars[start : start+curDepth-failDepth]))
start += curDepth - failDepth
cur = child
}
if cur.end {
val := string(chars[i+1-cur.depth : i+1])
builder.WriteString(r.mapping[val])
builder.WriteString(string(chars[i+1:]))
// only matching this path, all previous paths are done
if start >= i+1-cur.depth && i+1 >= size {
return builder.String()
}
chars = []rune(builder.String())
size = len(chars)
builder.Reset()
break
}
}
if !cur.end {
builder.WriteString(string(chars[start:]))
return builder.String()
for i := 0; i < replaceTimes; i++ {
var replaced bool
if text, replaced = r.doReplace(text); !replaced {
return text
}
}
return string(chars)
return text
}
func (r *replacer) doReplace(text string) (string, bool) {
chars := []rune(text)
scopes := r.find(chars)
if len(scopes) == 0 {
return text, false
}
sort.Slice(scopes, func(i, j int) bool {
if scopes[i].start < scopes[j].start {
return true
}
if scopes[i].start == scopes[j].start {
return scopes[i].stop > scopes[j].stop
}
return false
})
var buf strings.Builder
var index int
for i := 0; i < len(scopes); i++ {
scp := &scopes[i]
if scp.start < index {
continue
}
buf.WriteString(string(chars[index:scp.start]))
buf.WriteString(r.mapping[string(chars[scp.start:scp.stop])])
index = scp.stop
}
if index < len(chars) {
buf.WriteString(string(chars[index:]))
}
return buf.String(), true
}

View File

@@ -1,5 +1,4 @@
//go:build go1.18
// +build go1.18
package stringx

View File

@@ -15,6 +15,15 @@ func TestReplacer_Replace(t *testing.T) {
assert.Equal(t, "零1234五", NewReplacer(mapping).Replace("零一二三四五"))
}
func TestReplacer_ReplaceJumpMatch(t *testing.T) {
mapping := map[string]string{
"abcdeg": "ABCDEG",
"cdef": "CDEF",
"cde": "CDE",
}
assert.Equal(t, "abCDEF", NewReplacer(mapping).Replace("abcdef"))
}
func TestReplacer_ReplaceOverlap(t *testing.T) {
mapping := map[string]string{
"3d": "34",
@@ -44,6 +53,14 @@ func TestReplacer_ReplacePartialMatch(t *testing.T) {
assert.Equal(t, "零一二三四五", NewReplacer(mapping).Replace("零一二三四五"))
}
func TestReplacer_ReplacePartialMatchEnds(t *testing.T) {
mapping := map[string]string{
"二三四七": "2347",
"三四": "34",
}
assert.Equal(t, "零一二34", NewReplacer(mapping).Replace("零一二三四"))
}
func TestReplacer_ReplaceMultiMatches(t *testing.T) {
mapping := map[string]string{
"二三": "23",
@@ -51,6 +68,54 @@ func TestReplacer_ReplaceMultiMatches(t *testing.T) {
assert.Equal(t, "零一23四五一23四五", NewReplacer(mapping).Replace("零一二三四五一二三四五"))
}
func TestReplacer_ReplaceLongestMatching(t *testing.T) {
keywords := map[string]string{
"日本": "japan",
"日本的首都": "东京",
}
replacer := NewReplacer(keywords)
assert.Equal(t, "东京在japan", replacer.Replace("日本的首都在日本"))
}
func TestReplacer_ReplaceSuffixMatch(t *testing.T) {
// case1
{
keywords := map[string]string{
"abcde": "ABCDE",
"bcde": "BCDE",
"bcd": "BCD",
}
assert.Equal(t, "aBCDf", NewReplacer(keywords).Replace("abcdf"))
}
// case2
{
keywords := map[string]string{
"abcde": "ABCDE",
"bcde": "BCDE",
"cde": "CDE",
"c": "C",
"cd": "CD",
}
assert.Equal(t, "abCDf", NewReplacer(keywords).Replace("abcdf"))
}
}
func TestReplacer_ReplaceLongestOverlap(t *testing.T) {
keywords := map[string]string{
"456": "def",
"abcd": "1234",
}
replacer := NewReplacer(keywords)
assert.Equal(t, "123def7", replacer.Replace("abcd567"))
}
func TestReplacer_ReplaceLongestLonger(t *testing.T) {
mapping := map[string]string{
"c": "3",
}
assert.Equal(t, "3d", NewReplacer(mapping).Replace("cd"))
}
func TestReplacer_ReplaceJumpToFail(t *testing.T) {
mapping := map[string]string{
"bcdf": "1235",
@@ -146,3 +211,21 @@ func TestFuzzReplacerCase2(t *testing.T) {
t.Errorf("result: %s, match: %v", val, keys)
}
}
func TestReplacer_ReplaceLongestMatch(t *testing.T) {
replacer := NewReplacer(map[string]string{
"日本的首都": "东京",
"日本": "本日",
})
assert.Equal(t, "东京是东京", replacer.Replace("日本的首都是东京"))
}
func TestReplacer_ReplaceIndefinitely(t *testing.T) {
mapping := map[string]string{
"日本的首都": "东京",
"东京": "日本的首都",
}
assert.NotPanics(t, func() {
NewReplacer(mapping).Replace("日本的首都是东京")
})
}

View File

@@ -3,6 +3,7 @@ package trace
import (
"context"
"fmt"
"net/url"
"sync"
"github.com/zeromicro/go-zero/core/lang"
@@ -57,6 +58,10 @@ func createExporter(c Config) (sdktrace.SpanExporter, error) {
// Just support jaeger and zipkin now, more for later
switch c.Batcher {
case kindJaeger:
u, _ := url.Parse(c.Endpoint)
if u.Scheme == "udp" {
return jaeger.New(jaeger.WithAgentEndpoint(jaeger.WithAgentHost(u.Hostname()), jaeger.WithAgentPort(u.Port())))
}
return jaeger.New(jaeger.WithCollectorEndpoint(jaeger.WithEndpoint(c.Endpoint)))
case kindZipkin:
return zipkin.New(c.Endpoint)

View File

@@ -15,6 +15,7 @@ func TestStartAgent(t *testing.T) {
endpoint2 = "remotehost:1234"
endpoint3 = "localhost:1235"
endpoint4 = "localhost:1236"
endpoint5 = "udp://localhost:6831"
)
c1 := Config{
Name: "foo",
@@ -44,6 +45,11 @@ func TestStartAgent(t *testing.T) {
Endpoint: endpoint4,
Batcher: kindOtlpHttp,
}
c7 := Config{
Name: "UDP",
Endpoint: endpoint5,
Batcher: kindJaeger,
}
StartAgent(c1)
StartAgent(c1)
@@ -52,16 +58,19 @@ func TestStartAgent(t *testing.T) {
StartAgent(c4)
StartAgent(c5)
StartAgent(c6)
StartAgent(c7)
lock.Lock()
defer lock.Unlock()
// because remotehost cannot be resolved
assert.Equal(t, 4, len(agents))
assert.Equal(t, 5, len(agents))
_, ok := agents[""]
assert.True(t, ok)
_, ok = agents[endpoint1]
assert.True(t, ok)
_, ok = agents[endpoint2]
assert.False(t, ok)
_, ok = agents[endpoint5]
assert.True(t, ok)
}

View File

@@ -0,0 +1,73 @@
package trace
import (
"context"
"testing"
"github.com/stretchr/testify/assert"
"go.opentelemetry.io/otel/attribute"
"go.opentelemetry.io/otel/codes"
"go.opentelemetry.io/otel/trace"
"google.golang.org/protobuf/reflect/protoreflect"
"google.golang.org/protobuf/types/dynamicpb"
)
func TestMessageType_Event(t *testing.T) {
var span mockSpan
ctx := trace.ContextWithSpan(context.Background(), &span)
MessageReceived.Event(ctx, 1, "foo")
assert.Equal(t, messageEvent, span.name)
assert.NotEmpty(t, span.options)
}
func TestMessageType_EventProtoMessage(t *testing.T) {
var span mockSpan
var message mockMessage
ctx := trace.ContextWithSpan(context.Background(), &span)
MessageReceived.Event(ctx, 1, message)
assert.Equal(t, messageEvent, span.name)
assert.NotEmpty(t, span.options)
}
type mockSpan struct {
name string
options []trace.EventOption
}
func (m *mockSpan) End(options ...trace.SpanEndOption) {
}
func (m *mockSpan) AddEvent(name string, options ...trace.EventOption) {
m.name = name
m.options = options
}
func (m *mockSpan) IsRecording() bool {
return false
}
func (m *mockSpan) RecordError(err error, options ...trace.EventOption) {
}
func (m *mockSpan) SpanContext() trace.SpanContext {
panic("implement me")
}
func (m *mockSpan) SetStatus(code codes.Code, description string) {
}
func (m *mockSpan) SetName(name string) {
}
func (m *mockSpan) SetAttributes(kv ...attribute.KeyValue) {
}
func (m *mockSpan) TracerProvider() trace.TracerProvider {
return nil
}
type mockMessage struct{}
func (m mockMessage) ProtoReflect() protoreflect.Message {
return new(dynamicpb.Message)
}

View File

@@ -6,8 +6,10 @@ import (
"strings"
ztrace "github.com/zeromicro/go-zero/internal/trace"
"go.opentelemetry.io/otel"
"go.opentelemetry.io/otel/attribute"
semconv "go.opentelemetry.io/otel/semconv/v1.4.0"
"go.opentelemetry.io/otel/trace"
"google.golang.org/grpc/peer"
)
@@ -20,25 +22,6 @@ var (
TraceIDFromContext = ztrace.TraceIDFromContext
)
// PeerFromCtx returns the peer from ctx.
func PeerFromCtx(ctx context.Context) string {
p, ok := peer.FromContext(ctx)
if !ok || p == nil {
return ""
}
return p.Addr.String()
}
// SpanInfo returns the span info.
func SpanInfo(fullMethod, peerAddress string) (string, []attribute.KeyValue) {
attrs := []attribute.KeyValue{RPCSystemGRPC}
name, mAttrs := ParseFullMethod(fullMethod)
attrs = append(attrs, mAttrs...)
attrs = append(attrs, PeerAttr(peerAddress)...)
return name, attrs
}
// ParseFullMethod returns the method name and attributes.
func ParseFullMethod(fullMethod string) (string, []attribute.KeyValue) {
name := strings.TrimLeft(fullMethod, "/")
@@ -75,3 +58,33 @@ func PeerAttr(addr string) []attribute.KeyValue {
semconv.NetPeerPortKey.String(port),
}
}
// PeerFromCtx returns the peer from ctx.
func PeerFromCtx(ctx context.Context) string {
p, ok := peer.FromContext(ctx)
if !ok || p == nil {
return ""
}
return p.Addr.String()
}
// SpanInfo returns the span info.
func SpanInfo(fullMethod, peerAddress string) (string, []attribute.KeyValue) {
attrs := []attribute.KeyValue{RPCSystemGRPC}
name, mAttrs := ParseFullMethod(fullMethod)
attrs = append(attrs, mAttrs...)
attrs = append(attrs, PeerAttr(peerAddress)...)
return name, attrs
}
// TracerFromContext returns a tracer in ctx, otherwise returns a global tracer.
func TracerFromContext(ctx context.Context) (tracer trace.Tracer) {
if span := trace.SpanFromContext(ctx); span.SpanContext().IsValid() {
tracer = span.TracerProvider().Tracer(TraceName)
} else {
tracer = otel.Tracer(TraceName)
}
return
}

View File

@@ -6,8 +6,12 @@ import (
"testing"
"github.com/stretchr/testify/assert"
"go.opentelemetry.io/otel"
"go.opentelemetry.io/otel/attribute"
"go.opentelemetry.io/otel/sdk/resource"
sdktrace "go.opentelemetry.io/otel/sdk/trace"
semconv "go.opentelemetry.io/otel/semconv/v1.4.0"
"go.opentelemetry.io/otel/trace"
"google.golang.org/grpc/peer"
)
@@ -151,3 +155,50 @@ func TestPeerAttr(t *testing.T) {
})
}
}
func TestTracerFromContext(t *testing.T) {
traceFn := func(ctx context.Context, hasTraceId bool) {
spanContext := trace.SpanContextFromContext(ctx)
assert.Equal(t, spanContext.IsValid(), hasTraceId)
parentTraceId := spanContext.TraceID().String()
tracer := TracerFromContext(ctx)
_, span := tracer.Start(ctx, "b")
defer span.End()
spanContext = span.SpanContext()
assert.True(t, spanContext.IsValid())
if hasTraceId {
assert.Equal(t, parentTraceId, spanContext.TraceID().String())
}
}
t.Run("context", func(t *testing.T) {
opts := []sdktrace.TracerProviderOption{
// Set the sampling rate based on the parent span to 100%
sdktrace.WithSampler(sdktrace.ParentBased(sdktrace.TraceIDRatioBased(1))),
// Record information about this application in a Resource.
sdktrace.WithResource(resource.NewSchemaless(semconv.ServiceNameKey.String("test"))),
}
tp = sdktrace.NewTracerProvider(opts...)
otel.SetTracerProvider(tp)
ctx, span := tp.Tracer(TraceName).Start(context.Background(), "a")
defer span.End()
traceFn(ctx, true)
})
t.Run("global", func(t *testing.T) {
opts := []sdktrace.TracerProviderOption{
// Set the sampling rate based on the parent span to 100%
sdktrace.WithSampler(sdktrace.ParentBased(sdktrace.TraceIDRatioBased(1))),
// Record information about this application in a Resource.
sdktrace.WithResource(resource.NewSchemaless(semconv.ServiceNameKey.String("test"))),
}
tp = sdktrace.NewTracerProvider(opts...)
otel.SetTracerProvider(tp)
traceFn(context.Background(), false)
})
}

View File

@@ -20,9 +20,9 @@
> ***注意:***
>
> 从 v1.3.0 之前版本升级请执行以下命令:
>
>
> `GOPROXY=https://goproxy.cn/,direct go install github.com/zeromicro/go-zero/tools/goctl@latest`
>
>
> `goctl migrate —verbose —version v1.4.3`
## 0. go-zero 介绍
@@ -121,10 +121,10 @@ GO111MODULE=on GOPROXY=https://goproxy.cn/,direct go get -u github.com/zeromicro
```shell
# Go 1.15 及之前版本
GO111MODULE=on GOPROXY=https://goproxy.cn/,direct go get -u github.com/zeromicro/go-zero/tools/goctl@latest
# Go 1.16 及以后版本
GOPROXY=https://goproxy.cn/,direct go install github.com/zeromicro/go-zero/tools/goctl@latest
# For Mac
brew install goctl
@@ -200,7 +200,7 @@ GO111MODULE=on GOPROXY=https://goproxy.cn/,direct go get -u github.com/zeromicro
* [快速构建高并发微服务 - 多 RPC 版](https://github.com/zeromicro/zero-doc/blob/main/docs/zero/bookstore.md)
* [goctl 使用帮助](https://github.com/zeromicro/zero-doc/blob/main/doc/goctl.md)
* [Examples](https://github.com/zeromicro/zero-examples)
* 精选 `goctl` 插件
| 插件 | 用途 |
@@ -296,6 +296,11 @@ go-zero 已被许多公司用于生产部署,接入场景如在线教育、电
>81. 广州机智云物联网科技有限公司
>82. 厦门亿联网络技术股份有限公司
>83. 北京麦芽田网络科技有限公司
>84. 佛山市振联科技有限公司
>85. 苏州智言信息科技有限公司
>86. 中国移动上海产业研究院
>87. 天枢数链(浙江)科技有限公司
>88. 北京娱人共享智能科技有限公司
如果贵公司也已使用 go-zero欢迎在 [登记地址](https://github.com/zeromicro/go-zero/issues/602) 登记,仅仅为了推广,不做其它用途。

View File

@@ -129,10 +129,10 @@ goctl migrate —verbose —version v1.4.3
```shell
# for Go 1.15 and earlier
GO111MODULE=on go get -u github.com/zeromicro/go-zero/tools/goctl@latest
# for Go 1.16 and later
go install github.com/zeromicro/go-zero/tools/goctl@latest
# For Mac
brew install goctl
@@ -156,24 +156,24 @@ goctl migrate —verbose —version v1.4.3
Request {
Name string `path:"name,options=[you,me]"` // parameters are auto validated
}
Response {
Message string `json:"message"`
}
)
service greet-api {
@handler GreetHandler
get /greet/from/:name(Request) returns (Response)
}
```
the .api files also can be generated by goctl, like below:
```shell
goctl api -o greet.api
```
4. generate the go server-side code
```shell

View File

@@ -25,8 +25,10 @@ const topCpuUsage = 1000
var ErrSignatureConfig = errors.New("bad config for Signature")
type engine struct {
conf RestConf
routes []featuredRoutes
conf RestConf
routes []featuredRoutes
// timeout is the max timeout of all routes
timeout time.Duration
unauthorizedCallback handler.UnauthorizedCallback
unsignedCallback handler.UnsignedCallback
chain chain.Chain
@@ -38,8 +40,10 @@ type engine struct {
func newEngine(c RestConf) *engine {
svr := &engine{
conf: c,
conf: c,
timeout: time.Duration(c.Timeout) * time.Millisecond,
}
if c.CpuThreshold > 0 {
svr.shedder = load.NewAdaptiveShedder(load.WithCpuThreshold(c.CpuThreshold))
svr.priorityShedder = load.NewAdaptiveShedder(load.WithCpuThreshold(
@@ -51,6 +55,12 @@ func newEngine(c RestConf) *engine {
func (ng *engine) addRoutes(r featuredRoutes) {
ng.routes = append(ng.routes, r)
// need to guarantee the timeout is the max of all routes
// otherwise impossible to set http.Server.ReadTimeout & WriteTimeout
if r.timeout > ng.timeout {
ng.timeout = r.timeout
}
}
func (ng *engine) appendAuthHandler(fr featuredRoutes, chn chain.Chain,
@@ -314,15 +324,15 @@ func (ng *engine) use(middleware Middleware) {
func (ng *engine) withTimeout() internal.StartOption {
return func(svr *http.Server) {
timeout := ng.conf.Timeout
timeout := ng.timeout
if timeout > 0 {
// factor 0.8, to avoid clients send longer content-length than the actual content,
// without this timeout setting, the server will time out and respond 503 Service Unavailable,
// which triggers the circuit breaker.
svr.ReadTimeout = 4 * time.Duration(timeout) * time.Millisecond / 5
svr.ReadTimeout = 4 * timeout / 5
// factor 1.1, to avoid servers don't have enough time to write responses.
// setting the factor less than 1.0 may lead clients not receiving the responses.
svr.WriteTimeout = 11 * time.Duration(timeout) * time.Millisecond / 10
svr.WriteTimeout = 11 * timeout / 10
}
}
}

View File

@@ -43,6 +43,7 @@ Verbose: true
Path: "/",
Handler: func(w http.ResponseWriter, r *http.Request) {},
}},
timeout: time.Minute,
},
{
priority: true,
@@ -53,6 +54,7 @@ Verbose: true
Path: "/",
Handler: func(w http.ResponseWriter, r *http.Request) {},
}},
timeout: time.Second,
},
{
priority: true,
@@ -159,6 +161,11 @@ Verbose: true
}
})
assert.NotNil(t, ng.start(mockedRouter{}))
timeout := time.Second * 3
if route.timeout > timeout {
timeout = route.timeout
}
assert.Equal(t, timeout, ng.timeout)
}
}
}

View File

@@ -1,11 +1,13 @@
package handler
import (
"bufio"
"bytes"
"context"
"errors"
"fmt"
"io"
"net"
"net/http"
"path"
"runtime"
@@ -125,6 +127,14 @@ type timeoutWriter struct {
var _ http.Pusher = (*timeoutWriter)(nil)
func (tw *timeoutWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
if hijacked, ok := tw.w.(http.Hijacker); ok {
return hijacked.Hijack()
}
return nil, nil, errors.New("server doesn't support hijacking")
}
// Header returns the underline temporary http.Header.
func (tw *timeoutWriter) Header() http.Header { return tw.h }

View File

@@ -10,6 +10,7 @@ import (
"time"
"github.com/stretchr/testify/assert"
"github.com/zeromicro/go-zero/rest/internal/response"
)
func init() {
@@ -134,6 +135,30 @@ func TestTimeoutClientClosed(t *testing.T) {
assert.Equal(t, statusClientClosedRequest, resp.Code)
}
func TestTimeoutHijack(t *testing.T) {
resp := httptest.NewRecorder()
writer := &timeoutWriter{
w: &response.WithCodeResponseWriter{
Writer: resp,
},
}
assert.NotPanics(t, func() {
writer.Hijack()
})
writer = &timeoutWriter{
w: &response.WithCodeResponseWriter{
Writer: mockedHijackable{resp},
},
}
assert.NotPanics(t, func() {
writer.Hijack()
})
}
func TestTimeoutPusher(t *testing.T) {
handler := &timeoutWriter{
w: mockedPusher{},

View File

@@ -156,12 +156,13 @@ func fillPath(u *nurl.URL, val map[string]interface{}) error {
}
func request(r *http.Request, cli client) (*http.Response, error) {
tracer := otel.Tracer(trace.TraceName)
ctx := r.Context()
tracer := trace.TracerFromContext(ctx)
propagator := otel.GetTextMapPropagator()
spanName := r.URL.Path
ctx, span := tracer.Start(
r.Context(),
ctx,
spanName,
oteltrace.WithSpanKind(oteltrace.SpanKindClient),
oteltrace.WithAttributes(semconv.HTTPClientAttributesFromHTTPRequest(r)...),

View File

@@ -4,6 +4,7 @@ import (
"io"
"net/http"
"strings"
"sync/atomic"
"github.com/zeromicro/go-zero/core/mapping"
"github.com/zeromicro/go-zero/rest/internal/encoding"
@@ -23,8 +24,15 @@ const (
var (
formUnmarshaler = mapping.NewUnmarshaler(formKey, mapping.WithStringValues())
pathUnmarshaler = mapping.NewUnmarshaler(pathKey, mapping.WithStringValues())
validator atomic.Value
)
// Validator defines the interface for validating the request.
type Validator interface {
// Validate validates the request and parsed data.
Validate(r *http.Request, data interface{}) error
}
// Parse parses the request.
func Parse(r *http.Request, v interface{}) error {
if err := ParsePath(r, v); err != nil {
@@ -39,7 +47,15 @@ func Parse(r *http.Request, v interface{}) error {
return err
}
return ParseJsonBody(r, v)
if err := ParseJsonBody(r, v); err != nil {
return err
}
if val := validator.Load(); val != nil {
return val.(Validator).Validate(r, v)
}
return nil
}
// ParseHeaders parses the headers request.
@@ -101,6 +117,13 @@ func ParsePath(r *http.Request, v interface{}) error {
return pathUnmarshaler.Unmarshal(m, v)
}
// SetValidator sets the validator.
// The validator is used to validate the request, only called in Parse,
// not in ParseHeaders, ParseForm, ParseHeader, ParseJsonBody, ParsePath.
func SetValidator(val Validator) {
validator.Store(val)
}
func withJsonBody(r *http.Request) bool {
return r.ContentLength > 0 && strings.Contains(r.Header.Get(header.ContentType), header.ApplicationJson)
}

View File

@@ -1,8 +1,10 @@
package httpx
import (
"errors"
"net/http"
"net/http/httptest"
"reflect"
"strconv"
"strings"
"testing"
@@ -207,9 +209,23 @@ func TestParseJsonBody(t *testing.T) {
r := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(body))
r.Header.Set(ContentType, header.JsonContentType)
assert.Nil(t, Parse(r, &v))
assert.Equal(t, "kevin", v.Name)
assert.Equal(t, 18, v.Age)
if assert.NoError(t, Parse(r, &v)) {
assert.Equal(t, "kevin", v.Name)
assert.Equal(t, 18, v.Age)
}
})
t.Run("bad body", func(t *testing.T) {
var v struct {
Name string `json:"name"`
Age int `json:"age"`
}
body := `{"name":"kevin", "ag": 18}`
r := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(body))
r.Header.Set(ContentType, header.JsonContentType)
assert.Error(t, Parse(r, &v))
})
t.Run("hasn't body", func(t *testing.T) {
@@ -308,6 +324,36 @@ func TestParseHeaders_Error(t *testing.T) {
assert.NotNil(t, Parse(r, &v))
}
func TestParseWithValidator(t *testing.T) {
SetValidator(mockValidator{})
var v struct {
Name string `form:"name"`
Age int `form:"age"`
Percent float64 `form:"percent,optional"`
}
r, err := http.NewRequest(http.MethodGet, "/a?name=hello&age=18&percent=3.4", http.NoBody)
assert.Nil(t, err)
if assert.NoError(t, Parse(r, &v)) {
assert.Equal(t, "hello", v.Name)
assert.Equal(t, 18, v.Age)
assert.Equal(t, 3.4, v.Percent)
}
}
func TestParseWithValidatorWithError(t *testing.T) {
SetValidator(mockValidator{})
var v struct {
Name string `form:"name"`
Age int `form:"age"`
Percent float64 `form:"percent,optional"`
}
r, err := http.NewRequest(http.MethodGet, "/a?name=world&age=18&percent=3.4", http.NoBody)
assert.Nil(t, err)
assert.Error(t, Parse(r, &v))
}
func BenchmarkParseRaw(b *testing.B) {
r, err := http.NewRequest(http.MethodGet, "http://hello.com/a?name=hello&age=18&percent=3.4", http.NoBody)
if err != nil {
@@ -351,3 +397,16 @@ func BenchmarkParseAuto(b *testing.B) {
}
}
}
type mockValidator struct{}
func (m mockValidator) Validate(r *http.Request, data interface{}) error {
if r.URL.Path == "/a" {
val := reflect.ValueOf(data).Elem().FieldByName("Name").String()
if val != "hello" {
return errors.New("name is not hello")
}
}
return nil
}

View File

@@ -8,6 +8,7 @@ import (
"testing"
"github.com/stretchr/testify/assert"
"github.com/zeromicro/go-zero/core/proc"
)
func TestStartHttp(t *testing.T) {
@@ -19,6 +20,7 @@ func TestStartHttp(t *testing.T) {
svr.IdleTimeout = 0
})
assert.NotNil(t, err)
proc.WrapUp()
}
func TestStartHttps(t *testing.T) {
@@ -30,4 +32,5 @@ func TestStartHttps(t *testing.T) {
svr.IdleTimeout = 0
})
assert.NotNil(t, err)
proc.WrapUp()
}

View File

@@ -12,6 +12,7 @@ import (
"github.com/zeromicro/go-zero/tools/goctl/api/new"
"github.com/zeromicro/go-zero/tools/goctl/api/tsgen"
"github.com/zeromicro/go-zero/tools/goctl/api/validate"
"github.com/zeromicro/go-zero/tools/goctl/config"
"github.com/zeromicro/go-zero/tools/goctl/plugin"
)
@@ -127,7 +128,7 @@ func init() {
"https://github.com/zeromicro/go-zero-template directory structure")
goCmd.Flags().StringVar(&gogen.VarStringBranch, "branch", "", "The branch of "+
"the remote repo, it does work with --remote")
goCmd.Flags().StringVar(&gogen.VarStringStyle, "style", "gozero", "The file naming format,"+
goCmd.Flags().StringVar(&gogen.VarStringStyle, "style", config.DefaultFormat, "The file naming format,"+
" see [https://github.com/zeromicro/go-zero/blob/master/tools/goctl/config/readme.md]")
javaCmd.Flags().StringVar(&javagen.VarStringDir, "dir", "", "The target dir")
@@ -146,7 +147,7 @@ func init() {
"https://github.com/zeromicro/go-zero-template directory structure")
newCmd.Flags().StringVar(&new.VarStringBranch, "branch", "", "The branch of "+
"the remote repo, it does work with --remote")
newCmd.Flags().StringVar(&new.VarStringStyle, "style", "gozero", "The file naming format,"+
newCmd.Flags().StringVar(&new.VarStringStyle, "style", config.DefaultFormat, "The file naming format,"+
" see [https://github.com/zeromicro/go-zero/blob/master/tools/goctl/config/readme.md]")
pluginCmd.Flags().StringVarP(&plugin.VarStringPlugin, "plugin", "p", "", "The plugin file")
@@ -157,7 +158,6 @@ func init() {
tsCmd.Flags().StringVar(&tsgen.VarStringDir, "dir", "", "The target dir")
tsCmd.Flags().StringVar(&tsgen.VarStringAPI, "api", "", "The api file")
tsCmd.Flags().StringVar(&tsgen.VarStringWebAPI, "webapi", "", "The web api file path")
tsCmd.Flags().StringVar(&tsgen.VarStringCaller, "caller", "", "The web api caller")
tsCmd.Flags().BoolVar(&tsgen.VarBoolUnWrap, "unwrap", false, "Unwrap the webapi caller for import")

View File

@@ -30,19 +30,21 @@ Future {{pathToFuncName .Path}}( {{if ne .Method "get"}}{{with .RequestType}}{{.
{{end}}`
const apiTemplateV2 = `import 'api.dart';
import '../data/{{with .Info}}{{getBaseName .Title}}{{end}}.dart';
import '../data/{{with .Service}}{{.Name}}{{end}}.dart';
{{with .Service}}
/// {{.Name}}
{{range .Routes}}
{{range $i, $Route := .Routes}}
/// --{{.Path}}--
///
/// request: {{with .RequestType}}{{.Name}}{{end}}
/// response: {{with .ResponseType}}{{.Name}}{{end}}
Future {{pathToFuncName .Path}}( {{if ne .Method "get"}}{{with .RequestType}}{{.Name}} request,{{end}}{{end}}
Future {{normalizeHandlerName .Handler}}(
{{if hasUrlPathParams $Route}}{{extractPositionalParamsFromPath $Route}},{{end}}
{{if ne .Method "get"}}{{with .RequestType}}{{.Name}} request,{{end}}{{end}}
{Function({{with .ResponseType}}{{.Name}}{{end}})? ok,
Function(String)? fail,
Function? eventually}) async {
await api{{if eq .Method "get"}}Get{{else}}Post{{end}}('{{.Path}}',{{if ne .Method "get"}}request,{{end}}
await api{{if eq .Method "get"}}Get{{else}}Post{{end}}({{makeDartRequestUrlPath $Route}},{{if ne .Method "get"}}request,{{end}}
ok: (data) {
if (ok != null) ok({{with .ResponseType}}{{.Name}}.fromJson(data){{end}});
}, fail: fail, eventually: eventually);

View File

@@ -31,7 +31,7 @@ Future<Tokens> getTokens() async {
try {
var sp = await SharedPreferences.getInstance();
var str = sp.getString('tokens');
if (str.isEmpty) {
if (str == null || str.isEmpty) {
return null;
}
return Tokens.fromJson(jsonDecode(str));
@@ -65,7 +65,7 @@ Future<Tokens?> getTokens() async {
try {
var sp = await SharedPreferences.getInstance();
var str = sp.getString('tokens');
if (str.isEmpty) {
if (str == null || str.isEmpty) {
return null;
}
return Tokens.fromJson(jsonDecode(str));

View File

@@ -11,6 +11,18 @@ import (
"github.com/zeromicro/go-zero/tools/goctl/api/util"
)
const (
formTagKey = "form"
pathTagKey = "path"
headerTagKey = "header"
)
func normalizeHandlerName(handlerName string) string {
handler := strings.Replace(handlerName, "Handler", "", 1)
handler = lowCamelCase(handler)
return handler
}
func lowCamelCase(s string) string {
if len(s) < 1 {
return ""
@@ -20,21 +32,6 @@ func lowCamelCase(s string) string {
return util.ToLower(s[:1]) + s[1:]
}
func pathToFuncName(path string) string {
if !strings.HasPrefix(path, "/") {
path = "/" + path
}
if !strings.HasPrefix(path, "/api") {
path = "/api" + path
}
path = strings.Replace(path, "/", "_", -1)
path = strings.Replace(path, "-", "_", -1)
camel := util.ToCamelCase(path)
return util.ToLower(camel[:1]) + camel[1:]
}
func getBaseName(str string) string {
return path.Base(str)
}
@@ -170,3 +167,46 @@ func primitiveType(tp string) (string, bool) {
return "", false
}
func hasUrlPathParams(route spec.Route) bool {
ds, ok := route.RequestType.(spec.DefineStruct)
if !ok {
return false
}
return len(route.RequestTypeName()) > 0 && len(ds.GetTagMembers(pathTagKey)) > 0
}
func extractPositionalParamsFromPath(route spec.Route) string {
ds, ok := route.RequestType.(spec.DefineStruct)
if !ok {
return ""
}
var params []string
for _, member := range ds.GetTagMembers(pathTagKey) {
dartType := member.Type.Name()
params = append(params, fmt.Sprintf("%s %s", dartType, getPropertyFromMember(member)))
}
return strings.Join(params, ", ")
}
func makeDartRequestUrlPath(route spec.Route) string {
path := route.Path
if route.RequestType == nil {
return `"` + path + `"`
}
ds, ok := route.RequestType.(spec.DefineStruct)
if !ok {
return path
}
for _, member := range ds.GetTagMembers(pathTagKey) {
paramName := member.Tags()[0].Name
path = strings.ReplaceAll(path, ":"+paramName, "${"+getPropertyFromMember(member)+"}")
}
return `"` + path + `"`
}

View File

@@ -3,13 +3,16 @@ package dartgen
import "text/template"
var funcMap = template.FuncMap{
"getBaseName": getBaseName,
"getPropertyFromMember": getPropertyFromMember,
"isDirectType": isDirectType,
"isClassListType": isClassListType,
"getCoreType": getCoreType,
"pathToFuncName": pathToFuncName,
"lowCamelCase": lowCamelCase,
"getBaseName": getBaseName,
"getPropertyFromMember": getPropertyFromMember,
"isDirectType": isDirectType,
"isClassListType": isClassListType,
"getCoreType": getCoreType,
"lowCamelCase": lowCamelCase,
"normalizeHandlerName": normalizeHandlerName,
"hasUrlPathParams": hasUrlPathParams,
"extractPositionalParamsFromPath": extractPositionalParamsFromPath,
"makeDartRequestUrlPath": makeDartRequestUrlPath,
}
const (

View File

@@ -5,6 +5,7 @@ import (
"os"
"path"
"sort"
"strconv"
"strings"
"text/template"
"time"
@@ -36,7 +37,7 @@ func RegisterHandlers(server *rest.Server, serverCtx *svc.ServiceContext) {
`
routesAdditionTemplate = `
server.AddRoutes(
{{.routes}} {{.jwt}}{{.signature}} {{.prefix}} {{.timeout}}
{{.routes}} {{.jwt}}{{.signature}} {{.prefix}} {{.timeout}} {{.maxBytes}}
)
`
timeoutThreshold = time.Millisecond
@@ -64,6 +65,7 @@ type (
middlewares []string
prefix string
jwtTrans string
maxBytes string
}
route struct {
method string
@@ -127,10 +129,20 @@ rest.WithPrefix("%s"),`, g.prefix)
return fmt.Errorf("timeout should not less than 1ms, now %v", duration)
}
timeout = fmt.Sprintf("rest.WithTimeout(%d * time.Millisecond),", duration/time.Millisecond)
timeout = fmt.Sprintf("\n rest.WithTimeout(%d * time.Millisecond),", duration/time.Millisecond)
hasTimeout = true
}
var maxBytes string
if len(g.maxBytes) > 0 {
_, err := strconv.ParseInt(g.maxBytes, 10, 64)
if err != nil {
return fmt.Errorf("maxBytes %s parse error,it is an invalid number", g.maxBytes)
}
maxBytes = fmt.Sprintf("\n rest.WithMaxBytes(%s),", g.maxBytes)
}
var routes string
if len(g.middlewares) > 0 {
gbuilder.WriteString("\n}...,")
@@ -152,6 +164,7 @@ rest.WithPrefix("%s"),`, g.prefix)
"signature": signature,
"prefix": prefix,
"timeout": timeout,
"maxBytes": maxBytes,
}); err != nil {
return err
}
@@ -230,6 +243,7 @@ func getRoutes(api *spec.ApiSpec) ([]group, error) {
}
groupedRoutes.timeout = g.GetAnnotation("timeout")
groupedRoutes.maxBytes = g.GetAnnotation("maxBytes")
jwt := g.GetAnnotation("jwt")
if len(jwt) > 0 {

View File

@@ -176,7 +176,7 @@ func (v *ApiVisitor) VisitAtHandler(ctx *api.AtHandlerContext) interface{} {
return &atHandler
}
// serVisitRoute implements from api.BaseApiParserVisitor
// VisitRoute implements from api.BaseApiParserVisitor
func (v *ApiVisitor) VisitRoute(ctx *api.RouteContext) interface{} {
var route Route
path := ctx.Path()

View File

@@ -39,6 +39,10 @@ func TsCommand(_ *cobra.Command, _ []string) error {
return errors.New("missing -dir")
}
if len(webAPI) == 0 {
webAPI = "."
}
api, err := parser.Parse(apiFile)
if err != nil {
fmt.Println(aurora.Red("Failed"))
@@ -51,6 +55,7 @@ func TsCommand(_ *cobra.Command, _ []string) error {
api.Service = api.Service.JoinPrefix()
logx.Must(pathx.MkdirIfNotExist(dir))
logx.Must(genRequest(dir))
logx.Must(genHandler(dir, webAPI, caller, api, unwrapAPI))
logx.Must(genComponents(dir, api))

View File

@@ -39,7 +39,7 @@ func genHandler(dir, webAPI, caller string, api *spec.ApiSpec, unwrapAPI bool) e
importCaller = "{ " + importCaller + " }"
}
if len(webAPI) > 0 {
imports += `import ` + importCaller + ` from ` + "\"" + webAPI + "\""
imports += `import ` + importCaller + ` from ` + `"./gocliRequest"`
}
if len(api.Types) != 0 {

View File

@@ -0,0 +1,26 @@
package tsgen
import (
_ "embed"
"os"
"path/filepath"
"github.com/zeromicro/go-zero/tools/goctl/util/pathx"
)
//go:embed request.ts
var requestTemplate string
func genRequest(dir string) error {
abs, err := filepath.Abs(dir)
if err != nil {
return err
}
filename := filepath.Join(abs, "gocliRequest.ts")
if pathx.FileExists(filename) {
return nil
}
return os.WriteFile(filename, []byte(requestTemplate), 0644)
}

View File

@@ -0,0 +1,126 @@
export type Method =
| 'get'
| 'GET'
| 'delete'
| 'DELETE'
| 'head'
| 'HEAD'
| 'options'
| 'OPTIONS'
| 'post'
| 'POST'
| 'put'
| 'PUT'
| 'patch'
| 'PATCH';
/**
* Parse route parameters for responseType
*/
const reg = /:[a-z|A-Z]+/g;
export function parseParams(url: string): Array<string> {
const ps = url.match(reg);
if (!ps) {
return [];
}
return ps.map((k) => k.replace(/:/, ''));
}
/**
* Generate url and parameters
* @param url
* @param params
*/
export function genUrl(url: string, params: unknown) {
if (!params) {
return url;
}
const ps = parseParams(url);
ps.forEach((k) => {
const reg = new RegExp(`:${k}`);
url = url.replace(reg, params[k]);
});
const path: Array<string> = [];
for (const key of Object.keys(params)) {
if (!ps.find((k) => k === key)) {
path.push(`${key}=${params[key]}`);
}
}
return url + (path.length > 0 ? `?${path.join('&')}` : '');
}
export async function request({
method,
url,
data,
config = {}
}: {
method: Method;
url: string;
data?: unknown;
config?: unknown;
}) {
const response = await fetch(url, {
method: method.toLocaleUpperCase(),
credentials: 'include',
headers: {
'Content-Type': 'application/json'
},
body: data ? JSON.stringify(data) : undefined,
// @ts-ignore
...config
});
return response.json();
}
function api<T>(
method: Method = 'get',
url: string,
req: any,
config?: unknown
): Promise<T> {
if (url.match(/:/) || method.match(/get|delete/i)) {
url = genUrl(url, req.params || req.forms);
}
method = method.toLocaleLowerCase() as Method;
switch (method) {
case 'get':
return request({method: 'get', url, data: req, config});
case 'delete':
return request({method: 'delete', url, data: req, config});
case 'put':
return request({method: 'put', url, data: req, config});
case 'post':
return request({method: 'post', url, data: req, config});
case 'patch':
return request({method: 'patch', url, data: req, config});
default:
return request({method: 'post', url, data: req, config});
}
}
export const webapi = {
get<T>(url: string, req: unknown, config?: unknown): Promise<T> {
return api<T>('get', url, req, config);
},
delete<T>(url: string, req: unknown, config?: unknown): Promise<T> {
return api<T>('delete', url, req, config);
},
put<T>(url: string, req: unknown, config?: unknown): Promise<T> {
return api<T>('get', url, req, config);
},
post<T>(url: string, req: unknown, config?: unknown): Promise<T> {
return api<T>('post', url, req, config);
},
patch<T>(url: string, req: unknown, config?: unknown): Promise<T> {
return api<T>('patch', url, req, config);
}
};
export default webapi

View File

@@ -70,7 +70,7 @@ spec:
---
apiVersion: autoscaling/v2beta1
apiVersion: autoscaling/v2beta2
kind: HorizontalPodAutoscaler
metadata:
name: {{.Name}}-hpa-c
@@ -88,11 +88,13 @@ spec:
- type: Resource
resource:
name: cpu
targetAverageUtilization: 80
target:
type: Utilization
averageUtilization: 80
---
apiVersion: autoscaling/v2beta1
apiVersion: autoscaling/v2beta2
kind: HorizontalPodAutoscaler
metadata:
name: {{.Name}}-hpa-m
@@ -110,4 +112,6 @@ spec:
- type: Resource
resource:
name: memory
targetAverageUtilization: 80
target:
type: Utilization
averageUtilization: 80

View File

@@ -53,7 +53,7 @@ func format(query string, args ...interface{}) (string, error) {
for _, ch := range query {
if ch == '?' {
if argIndex >= numArgs {
return "", fmt.Errorf("error: %d ? in sql, but less arguments provided", argIndex)
return "", fmt.Errorf("%d ? in sql, but less arguments provided", argIndex)
}
arg := args[argIndex]
@@ -79,7 +79,7 @@ func format(query string, args ...interface{}) (string, error) {
}
if argIndex < numArgs {
return "", fmt.Errorf("error: %d ? in sql, but more arguments provided", argIndex)
return "", fmt.Errorf("%d ? in sql, but more arguments provided", argIndex)
}
return b.String(), nil

View File

@@ -2,6 +2,7 @@ package rpc
import (
"github.com/spf13/cobra"
"github.com/zeromicro/go-zero/tools/goctl/config"
"github.com/zeromicro/go-zero/tools/goctl/rpc/cli"
)
@@ -53,7 +54,7 @@ func init() {
newCmd.Flags().StringSliceVar(&cli.VarStringSliceGoOpt, "go_opt", nil, "")
newCmd.Flags().StringSliceVar(&cli.VarStringSliceGoGRPCOpt, "go-grpc_opt", nil, "")
newCmd.Flags().StringVar(&cli.VarStringStyle, "style", "gozero", "The file "+
newCmd.Flags().StringVar(&cli.VarStringStyle, "style", config.DefaultFormat, "The file "+
"naming format, see [https://github.com/zeromicro/go-zero/tree/master/tools/goctl/config/readme.md]")
newCmd.Flags().BoolVar(&cli.VarBoolIdea, "idea", false, "Whether the command "+
"execution environment is from idea plugin.")
@@ -79,7 +80,7 @@ func init() {
protocCmd.Flags().StringSliceVar(&cli.VarStringSlicePlugin, "plugin", nil, "")
protocCmd.Flags().StringSliceVarP(&cli.VarStringSliceProtoPath, "proto_path", "I", nil, "")
protocCmd.Flags().StringVar(&cli.VarStringZRPCOut, "zrpc_out", "", "The zrpc output directory")
protocCmd.Flags().StringVar(&cli.VarStringStyle, "style", "gozero", "The file "+
protocCmd.Flags().StringVar(&cli.VarStringStyle, "style", config.DefaultFormat, "The file "+
"naming format, see [https://github.com/zeromicro/go-zero/tree/master/tools/goctl/config/readme.md]")
protocCmd.Flags().StringVar(&cli.VarStringHome, "home", "", "The goctl home "+
"path of the template, --home and --remote cannot be set at the same time, if they are, "+

View File

@@ -63,7 +63,24 @@ func (g *Generator) genCallGroup(ctx DirContext, proto parser.Proto, cfg *conf.C
isCallPkgSameToPbPkg := childDir == ctx.GetProtoGo().Filename
isCallPkgSameToGrpcPkg := childDir == ctx.GetProtoGo().Filename
functions, err := g.genFunction(proto.PbPackage, service, isCallPkgSameToGrpcPkg)
serviceName := stringx.From(service.Name).ToCamel()
alias := collection.NewSet()
var hasSameNameBetweenMessageAndService bool
for _, item := range proto.Message {
msgName := getMessageName(*item.Message)
if serviceName == msgName {
hasSameNameBetweenMessageAndService = true
}
if !isCallPkgSameToPbPkg {
alias.AddStr(fmt.Sprintf("%s = %s", parser.CamelCase(msgName),
fmt.Sprintf("%s.%s", proto.PbPackage, parser.CamelCase(msgName))))
}
}
if hasSameNameBetweenMessageAndService {
serviceName = stringx.From(service.Name + "_zrpc_client").ToCamel()
}
functions, err := g.genFunction(proto.PbPackage, serviceName, service, isCallPkgSameToGrpcPkg)
if err != nil {
return err
}
@@ -78,15 +95,6 @@ func (g *Generator) genCallGroup(ctx DirContext, proto parser.Proto, cfg *conf.C
return err
}
alias := collection.NewSet()
if !isCallPkgSameToPbPkg {
for _, item := range proto.Message {
msgName := getMessageName(*item.Message)
alias.AddStr(fmt.Sprintf("%s = %s", parser.CamelCase(msgName),
fmt.Sprintf("%s.%s", proto.PbPackage, parser.CamelCase(msgName))))
}
}
pbPackage := fmt.Sprintf(`"%s"`, ctx.GetPb().Package)
protoGoPackage := fmt.Sprintf(`"%s"`, ctx.GetProtoGo().Package)
if isCallPkgSameToGrpcPkg {
@@ -103,7 +111,7 @@ func (g *Generator) genCallGroup(ctx DirContext, proto parser.Proto, cfg *conf.C
"filePackage": dir.Base,
"pbPackage": pbPackage,
"protoGoPackage": protoGoPackage,
"serviceName": stringx.From(service.Name).ToCamel(),
"serviceName": serviceName,
"functions": strings.Join(functions, pathx.NL),
"interface": strings.Join(iFunctions, pathx.NL),
}, filename, true); err != nil {
@@ -126,8 +134,26 @@ func (g *Generator) genCallInCompatibility(ctx DirContext, proto parser.Proto,
return err
}
serviceName := stringx.From(service.Name).ToCamel()
alias := collection.NewSet()
var hasSameNameBetweenMessageAndService bool
for _, item := range proto.Message {
msgName := getMessageName(*item.Message)
if serviceName == msgName {
hasSameNameBetweenMessageAndService = true
}
if !isCallPkgSameToPbPkg {
alias.AddStr(fmt.Sprintf("%s = %s", parser.CamelCase(msgName),
fmt.Sprintf("%s.%s", proto.PbPackage, parser.CamelCase(msgName))))
}
}
if hasSameNameBetweenMessageAndService {
serviceName = stringx.From(service.Name + "_zrpc_client").ToCamel()
}
filename := filepath.Join(dir.Filename, fmt.Sprintf("%s.go", callFilename))
functions, err := g.genFunction(proto.PbPackage, service, isCallPkgSameToGrpcPkg)
functions, err := g.genFunction(proto.PbPackage, serviceName, service, isCallPkgSameToGrpcPkg)
if err != nil {
return err
}
@@ -142,15 +168,6 @@ func (g *Generator) genCallInCompatibility(ctx DirContext, proto parser.Proto,
return err
}
alias := collection.NewSet()
if !isCallPkgSameToPbPkg {
for _, item := range proto.Message {
msgName := getMessageName(*item.Message)
alias.AddStr(fmt.Sprintf("%s = %s", parser.CamelCase(msgName),
fmt.Sprintf("%s.%s", proto.PbPackage, parser.CamelCase(msgName))))
}
}
pbPackage := fmt.Sprintf(`"%s"`, ctx.GetPb().Package)
protoGoPackage := fmt.Sprintf(`"%s"`, ctx.GetProtoGo().Package)
if isCallPkgSameToGrpcPkg {
@@ -166,7 +183,7 @@ func (g *Generator) genCallInCompatibility(ctx DirContext, proto parser.Proto,
"filePackage": dir.Base,
"pbPackage": pbPackage,
"protoGoPackage": protoGoPackage,
"serviceName": stringx.From(service.Name).ToCamel(),
"serviceName": serviceName,
"functions": strings.Join(functions, pathx.NL),
"interface": strings.Join(iFunctions, pathx.NL),
}, filename, true)
@@ -194,7 +211,7 @@ func getMessageName(msg proto.Message) string {
return strings.Join(list, "_")
}
func (g *Generator) genFunction(goPackage string, service parser.Service,
func (g *Generator) genFunction(goPackage string, serviceName string, service parser.Service,
isCallPkgSameToGrpcPkg bool) ([]string, error) {
functions := make([]string, 0)
@@ -212,7 +229,7 @@ func (g *Generator) genFunction(goPackage string, service parser.Service,
parser.CamelCase(rpc.Name), "Client")
}
buffer, err := util.With("sharedFn").Parse(text).Execute(map[string]interface{}{
"serviceName": stringx.From(service.Name).ToCamel(),
"serviceName": serviceName,
"rpcServiceName": parser.CamelCase(service.Name),
"method": parser.CamelCase(rpc.Name),
"package": goPackage,

View File

@@ -21,9 +21,11 @@ func NewGenerator(style string, verbose bool) *Generator {
if err != nil {
log.Fatalln(err)
}
log := console.NewColorConsole(verbose)
colorLogger := console.NewColorConsole(verbose)
return &Generator{
log: log,
log: colorLogger,
cfg: cfg,
verbose: verbose,
}

View File

@@ -8,8 +8,11 @@ import (
"github.com/zeromicro/go-zero/zrpc/internal/auth"
"github.com/zeromicro/go-zero/zrpc/internal/clientinterceptors"
"google.golang.org/grpc"
"google.golang.org/grpc/keepalive"
)
const defaultClientKeepaliveTime = 20 * time.Second
var (
// WithDialOption is an alias of internal.WithDialOption.
WithDialOption = internal.WithDialOption
@@ -62,6 +65,11 @@ func NewClient(c RpcClientConf, options ...ClientOption) (Client, error) {
if c.Timeout > 0 {
opts = append(opts, WithTimeout(time.Duration(c.Timeout)*time.Millisecond))
}
if c.KeepaliveTime > 0 {
opts = append(opts, WithDialOption(grpc.WithKeepaliveParams(keepalive.ClientParameters{
Time: c.KeepaliveTime,
})))
}
opts = append(opts, options...)
@@ -90,6 +98,12 @@ func NewClientWithTarget(target string, opts ...ClientOption) (Client, error) {
Timeout: true,
}
opts = append([]ClientOption{
WithDialOption(grpc.WithKeepaliveParams(keepalive.ClientParameters{
Time: defaultClientKeepaliveTime,
})),
}, opts...)
return internal.NewClient(target, middlewares, opts...)
}

View File

@@ -113,10 +113,11 @@ func TestDepositServer_Deposit(t *testing.T) {
)
tarConfClient := MustNewClient(
RpcClientConf{
Target: "foo",
App: "foo",
Token: "bar",
Timeout: 1000,
Target: "foo",
App: "foo",
Token: "bar",
Timeout: 1000,
KeepaliveTime: time.Second * 15,
Middlewares: ClientMiddlewaresConf{
Trace: true,
Duration: true,

View File

@@ -1,6 +1,8 @@
package zrpc
import (
"time"
"github.com/zeromicro/go-zero/core/discov"
"github.com/zeromicro/go-zero/core/service"
"github.com/zeromicro/go-zero/core/stores/redis"
@@ -14,6 +16,19 @@ type (
// ServerMiddlewaresConf defines whether to use server middlewares.
ServerMiddlewaresConf = internal.ServerMiddlewaresConf
// A RpcClientConf is a rpc client config.
RpcClientConf struct {
Etcd discov.EtcdConf `json:",optional,inherit"`
Endpoints []string `json:",optional"`
Target string `json:",optional"`
App string `json:",optional"`
Token string `json:",optional"`
NonBlock bool `json:",optional"`
Timeout int64 `json:",default=2000"`
KeepaliveTime time.Duration `json:",default=20s"`
Middlewares ClientMiddlewaresConf
}
// A RpcServerConf is a rpc server config.
RpcServerConf struct {
service.ServiceConf
@@ -29,18 +44,6 @@ type (
Health bool `json:",default=true"`
Middlewares ServerMiddlewaresConf
}
// A RpcClientConf is a rpc client config.
RpcClientConf struct {
Etcd discov.EtcdConf `json:",optional,inherit"`
Endpoints []string `json:",optional"`
Target string `json:",optional"`
App string `json:",optional"`
Token string `json:",optional"`
NonBlock bool `json:",optional"`
Timeout int64 `json:",default=2000"`
Middlewares ClientMiddlewaresConf
}
)
// NewDirectClientConf returns a RpcClientConf.

View File

@@ -10,10 +10,35 @@ import (
)
func TestRpcClientConf(t *testing.T) {
conf := NewDirectClientConf([]string{"localhost:1234"}, "foo", "bar")
assert.True(t, conf.HasCredential())
conf = NewEtcdClientConf([]string{"localhost:1234", "localhost:5678"}, "key", "foo", "bar")
assert.True(t, conf.HasCredential())
t.Run("direct", func(t *testing.T) {
conf := NewDirectClientConf([]string{"localhost:1234"}, "foo", "bar")
assert.True(t, conf.HasCredential())
})
t.Run("etcd", func(t *testing.T) {
conf := NewEtcdClientConf([]string{"localhost:1234", "localhost:5678"},
"key", "foo", "bar")
assert.True(t, conf.HasCredential())
})
t.Run("etcd with account", func(t *testing.T) {
conf := NewEtcdClientConf([]string{"localhost:1234", "localhost:5678"},
"key", "foo", "bar")
conf.Etcd.User = "user"
conf.Etcd.Pass = "pass"
_, err := conf.BuildTarget()
assert.NoError(t, err)
})
t.Run("etcd with tls", func(t *testing.T) {
conf := NewEtcdClientConf([]string{"localhost:1234", "localhost:5678"},
"key", "foo", "bar")
conf.Etcd.CertFile = "cert"
conf.Etcd.CertKeyFile = "key"
conf.Etcd.CACertFile = "ca"
_, err := conf.BuildTarget()
assert.Error(t, err)
})
}
func TestRpcServerConf(t *testing.T) {

View File

@@ -4,9 +4,21 @@ import (
"testing"
"github.com/stretchr/testify/assert"
"github.com/zeromicro/go-zero/core/discov"
"github.com/zeromicro/go-zero/core/netx"
)
func TestNewRpcPubServer(t *testing.T) {
s, err := NewRpcPubServer(discov.EtcdConf{
User: "user",
Pass: "pass",
}, "", ServerMiddlewaresConf{})
assert.NoError(t, err)
assert.NotPanics(t, func() {
s.Start(nil)
})
}
func TestFigureOutListenOn(t *testing.T) {
tests := []struct {
input string
@@ -24,6 +36,10 @@ func TestFigureOutListenOn(t *testing.T) {
input: ":8080",
expect: netx.InternalIp() + ":8080",
},
{
input: "",
expect: netx.InternalIp(),
},
}
for _, test := range tests {

View File

@@ -59,12 +59,10 @@ func (s *rpcServer) Start(register RegisterFn) error {
return err
}
unaryInterceptors := s.buildUnaryInterceptors()
unaryInterceptors = append(unaryInterceptors, s.unaryInterceptors...)
streamInterceptors := s.buildStreamInterceptors()
streamInterceptors = append(streamInterceptors, s.streamInterceptors...)
options := append(s.options, grpc.ChainUnaryInterceptor(unaryInterceptors...),
grpc.ChainStreamInterceptor(streamInterceptors...))
unaryInterceptorOption := grpc.ChainUnaryInterceptor(s.buildUnaryInterceptors()...)
streamInterceptorOption := grpc.ChainStreamInterceptor(s.buildStreamInterceptors()...)
options := append(s.options, unaryInterceptorOption, streamInterceptorOption)
server := grpc.NewServer(options...)
register(server)
@@ -102,7 +100,7 @@ func (s *rpcServer) buildStreamInterceptors() []grpc.StreamServerInterceptor {
interceptors = append(interceptors, serverinterceptors.StreamBreakerInterceptor)
}
return interceptors
return append(interceptors, s.streamInterceptors...)
}
func (s *rpcServer) buildUnaryInterceptors() []grpc.UnaryServerInterceptor {
@@ -124,7 +122,7 @@ func (s *rpcServer) buildUnaryInterceptors() []grpc.UnaryServerInterceptor {
interceptors = append(interceptors, serverinterceptors.UnaryBreakerInterceptor)
}
return interceptors
return append(interceptors, s.unaryInterceptors...)
}
// WithMetrics returns a func that sets metrics to a Server.

View File

@@ -1,10 +1,12 @@
package internal
import (
"context"
"sync"
"testing"
"github.com/stretchr/testify/assert"
"github.com/zeromicro/go-zero/core/proc"
"github.com/zeromicro/go-zero/core/stat"
"github.com/zeromicro/go-zero/zrpc/internal/mock"
"google.golang.org/grpc"
@@ -18,12 +20,13 @@ func TestRpcServer(t *testing.T) {
Stat: true,
Prometheus: true,
Breaker: true,
}, WithMetrics(metrics))
}, WithMetrics(metrics), WithRpcHealth(true))
server.SetName("mock")
var wg sync.WaitGroup
var wg, wgDone sync.WaitGroup
var grpcServer *grpc.Server
var lock sync.Mutex
wg.Add(1)
wgDone.Add(1)
go func() {
err := server.Start(func(server *grpc.Server) {
lock.Lock()
@@ -33,12 +36,16 @@ func TestRpcServer(t *testing.T) {
wg.Done()
})
assert.Nil(t, err)
wgDone.Done()
}()
wg.Wait()
lock.Lock()
grpcServer.GracefulStop()
lock.Unlock()
proc.WrapUp()
wgDone.Wait()
}
func TestRpcServer_WithBadAddress(t *testing.T) {
@@ -48,10 +55,124 @@ func TestRpcServer_WithBadAddress(t *testing.T) {
Stat: true,
Prometheus: true,
Breaker: true,
})
}, WithRpcHealth(true))
server.SetName("mock")
err := server.Start(func(server *grpc.Server) {
mock.RegisterDepositServiceServer(server, new(mock.DepositServer))
})
assert.NotNil(t, err)
proc.WrapUp()
}
func TestRpcServer_buildUnaryInterceptor(t *testing.T) {
tests := []struct {
name string
r *rpcServer
len int
}{
{
name: "empty",
r: &rpcServer{
baseRpcServer: &baseRpcServer{},
},
len: 0,
},
{
name: "custom",
r: &rpcServer{
baseRpcServer: &baseRpcServer{
unaryInterceptors: []grpc.UnaryServerInterceptor{
func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo,
handler grpc.UnaryHandler) (interface{}, error) {
return nil, nil
},
},
},
},
len: 1,
},
{
name: "middleware",
r: &rpcServer{
baseRpcServer: &baseRpcServer{
unaryInterceptors: []grpc.UnaryServerInterceptor{
func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo,
handler grpc.UnaryHandler) (interface{}, error) {
return nil, nil
},
},
},
middlewares: ServerMiddlewaresConf{
Trace: true,
Recover: true,
Stat: true,
Prometheus: true,
Breaker: true,
},
},
len: 6,
},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
assert.Equal(t, test.len, len(test.r.buildUnaryInterceptors()))
})
}
}
func TestRpcServer_buildStreamInterceptor(t *testing.T) {
tests := []struct {
name string
r *rpcServer
len int
}{
{
name: "empty",
r: &rpcServer{
baseRpcServer: &baseRpcServer{},
},
len: 0,
},
{
name: "custom",
r: &rpcServer{
baseRpcServer: &baseRpcServer{
streamInterceptors: []grpc.StreamServerInterceptor{
func(srv interface{}, ss grpc.ServerStream, info *grpc.StreamServerInfo,
handler grpc.StreamHandler) error {
return nil
},
},
},
},
len: 1,
},
{
name: "middleware",
r: &rpcServer{
baseRpcServer: &baseRpcServer{
streamInterceptors: []grpc.StreamServerInterceptor{
func(srv interface{}, ss grpc.ServerStream, info *grpc.StreamServerInfo,
handler grpc.StreamHandler) error {
return nil
},
},
},
middlewares: ServerMiddlewaresConf{
Trace: true,
Recover: true,
Breaker: true,
},
},
len: 4,
},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
assert.Equal(t, test.len, len(test.r.buildStreamInterceptors()))
})
}
}

View File

@@ -48,5 +48,6 @@ func ParseTarget(target resolver.Target) (Service, error) {
} else {
service.Name = endpoints
}
return service, nil
}

View File

@@ -18,6 +18,15 @@ func TestKubeBuilder_Build(t *testing.T) {
var b kubeBuilder
u, err := url.Parse(fmt.Sprintf("%s://%s", KubernetesScheme, "a,b"))
assert.NoError(t, err)
_, err = b.Build(resolver.Target{
URL: *u,
}, nil, resolver.BuildOptions{})
assert.Error(t, err)
u, err = url.Parse(fmt.Sprintf("%s://%s:9100/a:b:c", KubernetesScheme, "a,b,c,d"))
assert.NoError(t, err)
_, err = b.Build(resolver.Target{
URL: *u,
}, nil, resolver.BuildOptions{})

View File

@@ -3,15 +3,19 @@ package internal
import (
"testing"
"github.com/stretchr/testify/assert"
"google.golang.org/grpc/resolver"
"google.golang.org/grpc/serviceconfig"
)
func TestNopResolver(t *testing.T) {
// make sure ResolveNow & Close don't panic
var r nopResolver
r.ResolveNow(resolver.ResolveNowOptions{})
r.Close()
assert.NotPanics(t, func() {
RegisterResolver()
// make sure ResolveNow & Close don't panic
var r nopResolver
r.ResolveNow(resolver.ResolveNowOptions{})
r.Close()
})
}
type mockedClientConn struct {

View File

@@ -7,6 +7,7 @@ import (
"github.com/zeromicro/go-zero/core/load"
"github.com/zeromicro/go-zero/core/logx"
"github.com/zeromicro/go-zero/core/stat"
"github.com/zeromicro/go-zero/core/stores/redis"
"github.com/zeromicro/go-zero/zrpc/internal"
"github.com/zeromicro/go-zero/zrpc/internal/auth"
"github.com/zeromicro/go-zero/zrpc/internal/serverinterceptors"
@@ -120,7 +121,12 @@ func setupInterceptors(server internal.Server, c RpcServerConf, metrics *stat.Me
}
if c.Auth {
authenticator, err := auth.NewAuthenticator(c.Redis.NewRedis(), c.Redis.Key, c.StrictControl)
rds, err := redis.NewRedis(c.Redis.RedisConf)
if err != nil {
return err
}
authenticator, err := auth.NewAuthenticator(rds, c.Redis.Key, c.StrictControl)
if err != nil {
return err
}

View File

@@ -4,6 +4,7 @@ import (
"testing"
"time"
"github.com/alicebob/miniredis/v2"
"github.com/stretchr/testify/assert"
"github.com/zeromicro/go-zero/core/discov"
"github.com/zeromicro/go-zero/core/logx"
@@ -16,12 +17,16 @@ import (
)
func TestServer_setupInterceptors(t *testing.T) {
rds, err := miniredis.Run()
assert.NoError(t, err)
defer rds.Close()
server := new(mockedServer)
err := setupInterceptors(server, RpcServerConf{
conf := RpcServerConf{
Auth: true,
Redis: redis.RedisKeyConf{
RedisConf: redis.RedisConf{
Host: "any",
Host: rds.Addr(),
Type: redis.NodeType,
},
Key: "foo",
@@ -35,10 +40,15 @@ func TestServer_setupInterceptors(t *testing.T) {
Prometheus: true,
Breaker: true,
},
}, new(stat.Metrics))
}
err = setupInterceptors(server, conf, new(stat.Metrics))
assert.Nil(t, err)
assert.Equal(t, 3, len(server.unaryInterceptors))
assert.Equal(t, 1, len(server.streamInterceptors))
rds.SetError("mock error")
err = setupInterceptors(server, conf, new(stat.Metrics))
assert.Error(t, err)
}
func TestServer(t *testing.T) {