Compare commits
53 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
b4572fa064 | ||
|
|
ccbabf6f58 | ||
|
|
5989444227 | ||
|
|
dc286a03f5 | ||
|
|
b82c02ed16 | ||
|
|
59ba4ecc5b | ||
|
|
5e7b514ae2 | ||
|
|
2b1466e41e | ||
|
|
9c9f80518f | ||
|
|
25973d6b59 | ||
|
|
6237d01948 | ||
|
|
49316b113e | ||
|
|
6a673e8cb0 | ||
|
|
0efa28ddbd | ||
|
|
0b6a13fe84 | ||
|
|
11aa6668e8 | ||
|
|
267a283328 | ||
|
|
2d8366b30e | ||
|
|
db83843558 | ||
|
|
50565c9765 | ||
|
|
4c02a19a14 | ||
|
|
a1b990c5ec | ||
|
|
2607bb8863 | ||
|
|
5bf37535fe | ||
|
|
ed85775fd5 | ||
|
|
418f8f6666 | ||
|
|
22e75cdf78 | ||
|
|
e79c42add1 | ||
|
|
9e14820698 | ||
|
|
2ebb5b6b58 | ||
|
|
2673dbc6e1 | ||
|
|
d21d770b5b | ||
|
|
1252bd9cde | ||
|
|
054d9b5540 | ||
|
|
f03cfb0ff7 | ||
|
|
0214161bfc | ||
|
|
d4e38cb7f0 | ||
|
|
693a8b627a | ||
|
|
701208b6f4 | ||
|
|
b65fcc5512 | ||
|
|
3321ed3519 | ||
|
|
5e007c1f9f | ||
|
|
de2f8c06fb | ||
|
|
926d746df5 | ||
|
|
4b636cd293 | ||
|
|
4bdf5e4c90 | ||
|
|
721b7def7c | ||
|
|
f294090130 | ||
|
|
489980ea0f | ||
|
|
e12c8ae993 | ||
|
|
21aad62513 | ||
|
|
0b08aca554 | ||
|
|
6ef1b5e14c |
4
.codecov.yml
Normal file
4
.codecov.yml
Normal file
@@ -0,0 +1,4 @@
|
||||
ignore:
|
||||
- "doc"
|
||||
- "example"
|
||||
- "tools"
|
||||
7
.github/workflows/go.yml
vendored
7
.github/workflows/go.yml
vendored
@@ -27,4 +27,9 @@ jobs:
|
||||
go get -v -t -d ./...
|
||||
|
||||
- name: Test
|
||||
run: go test -v -race ./...
|
||||
run: go test -race -coverprofile=coverage.txt -covermode=atomic ./...
|
||||
|
||||
- name: Codecov
|
||||
uses: codecov/codecov-action@v1.0.6
|
||||
with:
|
||||
token: ${{secrets.CODECOV_TOKEN}}
|
||||
|
||||
@@ -1,18 +0,0 @@
|
||||
stages:
|
||||
- analysis
|
||||
|
||||
variables:
|
||||
GOPATH: '/runner-cache/zero'
|
||||
GOCACHE: '/runner-cache/zero'
|
||||
GOPROXY: 'https://goproxy.cn,direct'
|
||||
|
||||
analysis:
|
||||
stage: analysis
|
||||
image: golang
|
||||
script:
|
||||
- go version && go env
|
||||
- go test -short $(go list ./...) | grep -v "no test"
|
||||
only:
|
||||
- merge_requests
|
||||
tags:
|
||||
- common
|
||||
@@ -60,17 +60,15 @@ func do(name string, execute func(b Breaker) error) error {
|
||||
lock.RUnlock()
|
||||
if ok {
|
||||
return execute(b)
|
||||
} else {
|
||||
lock.Lock()
|
||||
b, ok = breakers[name]
|
||||
if ok {
|
||||
lock.Unlock()
|
||||
return execute(b)
|
||||
} else {
|
||||
b = NewBreaker(WithName(name))
|
||||
breakers[name] = b
|
||||
lock.Unlock()
|
||||
return execute(b)
|
||||
}
|
||||
}
|
||||
|
||||
lock.Lock()
|
||||
b, ok = breakers[name]
|
||||
if !ok {
|
||||
b = NewBreaker(WithName(name))
|
||||
breakers[name] = b
|
||||
}
|
||||
lock.Unlock()
|
||||
|
||||
return execute(b)
|
||||
}
|
||||
|
||||
@@ -72,9 +72,9 @@ func TestCacheWithLruEvicts(t *testing.T) {
|
||||
cache.Set("third", "third element")
|
||||
cache.Set("fourth", "fourth element")
|
||||
|
||||
value, ok := cache.Get("first")
|
||||
_, ok := cache.Get("first")
|
||||
assert.False(t, ok)
|
||||
value, ok = cache.Get("second")
|
||||
value, ok := cache.Get("second")
|
||||
assert.True(t, ok)
|
||||
assert.Equal(t, "second element", value)
|
||||
value, ok = cache.Get("third")
|
||||
@@ -94,9 +94,9 @@ func TestCacheWithLruEvicted(t *testing.T) {
|
||||
cache.Set("third", "third element")
|
||||
cache.Set("fourth", "fourth element")
|
||||
|
||||
value, ok := cache.Get("first")
|
||||
_, ok := cache.Get("first")
|
||||
assert.False(t, ok)
|
||||
value, ok = cache.Get("second")
|
||||
value, ok := cache.Get("second")
|
||||
assert.True(t, ok)
|
||||
assert.Equal(t, "second element", value)
|
||||
cache.Set("fifth", "fifth element")
|
||||
|
||||
@@ -213,7 +213,10 @@ func TestTimingWheel_SetTimer(t *testing.T) {
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
test := test
|
||||
t.Run(stringx.RandId(), func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
var count int32
|
||||
ticker := timex.NewFakeTicker()
|
||||
tick := func() {
|
||||
@@ -291,7 +294,10 @@ func TestTimingWheel_SetAndMoveThenStart(t *testing.T) {
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
test := test
|
||||
t.Run(stringx.RandId(), func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
var count int32
|
||||
ticker := timex.NewFakeTicker()
|
||||
tick := func() {
|
||||
@@ -376,7 +382,10 @@ func TestTimingWheel_SetAndMoveTwice(t *testing.T) {
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
test := test
|
||||
t.Run(stringx.RandId(), func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
var count int32
|
||||
ticker := timex.NewFakeTicker()
|
||||
tick := func() {
|
||||
@@ -454,7 +463,10 @@ func TestTimingWheel_ElapsedAndSet(t *testing.T) {
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
test := test
|
||||
t.Run(stringx.RandId(), func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
var count int32
|
||||
ticker := timex.NewFakeTicker()
|
||||
tick := func() {
|
||||
@@ -542,7 +554,10 @@ func TestTimingWheel_ElapsedAndSetThenMove(t *testing.T) {
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
test := test
|
||||
t.Run(stringx.RandId(), func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
var count int32
|
||||
ticker := timex.NewFakeTicker()
|
||||
tick := func() {
|
||||
|
||||
@@ -34,7 +34,7 @@ func TestContextCancel(t *testing.T) {
|
||||
assert.NotEqual(t, context.Canceled, c2.Err())
|
||||
}
|
||||
|
||||
func TestConextDeadline(t *testing.T) {
|
||||
func TestContextDeadline(t *testing.T) {
|
||||
c, _ := context.WithDeadline(context.Background(), time.Now().Add(10*time.Millisecond))
|
||||
o := ValueOnlyFrom(c)
|
||||
select {
|
||||
|
||||
@@ -12,14 +12,14 @@ func TestBulkExecutor(t *testing.T) {
|
||||
var values []int
|
||||
var lock sync.Mutex
|
||||
|
||||
exeutor := NewBulkExecutor(func(items []interface{}) {
|
||||
executor := NewBulkExecutor(func(items []interface{}) {
|
||||
lock.Lock()
|
||||
values = append(values, len(items))
|
||||
lock.Unlock()
|
||||
}, WithBulkTasks(10), WithBulkInterval(time.Minute))
|
||||
|
||||
for i := 0; i < 50; i++ {
|
||||
exeutor.Add(1)
|
||||
executor.Add(1)
|
||||
time.Sleep(time.Millisecond)
|
||||
}
|
||||
|
||||
@@ -40,13 +40,13 @@ func TestBulkExecutorFlushInterval(t *testing.T) {
|
||||
var wait sync.WaitGroup
|
||||
|
||||
wait.Add(1)
|
||||
exeutor := NewBulkExecutor(func(items []interface{}) {
|
||||
executor := NewBulkExecutor(func(items []interface{}) {
|
||||
assert.Equal(t, size, len(items))
|
||||
wait.Done()
|
||||
}, WithBulkTasks(caches), WithBulkInterval(time.Millisecond*100))
|
||||
|
||||
for i := 0; i < size; i++ {
|
||||
exeutor.Add(1)
|
||||
executor.Add(1)
|
||||
}
|
||||
|
||||
wait.Wait()
|
||||
|
||||
@@ -12,14 +12,14 @@ func TestChunkExecutor(t *testing.T) {
|
||||
var values []int
|
||||
var lock sync.Mutex
|
||||
|
||||
exeutor := NewChunkExecutor(func(items []interface{}) {
|
||||
executor := NewChunkExecutor(func(items []interface{}) {
|
||||
lock.Lock()
|
||||
values = append(values, len(items))
|
||||
lock.Unlock()
|
||||
}, WithChunkBytes(10), WithFlushInterval(time.Minute))
|
||||
|
||||
for i := 0; i < 50; i++ {
|
||||
exeutor.Add(1, 1)
|
||||
executor.Add(1, 1)
|
||||
time.Sleep(time.Millisecond)
|
||||
}
|
||||
|
||||
@@ -40,13 +40,13 @@ func TestChunkExecutorFlushInterval(t *testing.T) {
|
||||
var wait sync.WaitGroup
|
||||
|
||||
wait.Add(1)
|
||||
exeutor := NewChunkExecutor(func(items []interface{}) {
|
||||
executor := NewChunkExecutor(func(items []interface{}) {
|
||||
assert.Equal(t, size, len(items))
|
||||
wait.Done()
|
||||
}, WithChunkBytes(caches), WithFlushInterval(time.Millisecond*100))
|
||||
|
||||
for i := 0; i < size; i++ {
|
||||
exeutor.Add(1, 1)
|
||||
executor.Add(1, 1)
|
||||
}
|
||||
|
||||
wait.Wait()
|
||||
|
||||
@@ -4,9 +4,8 @@ import (
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
"github.com/tal-tech/go-zero/core/fs"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/tal-tech/go-zero/core/fs"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -15,34 +14,34 @@ const (
|
||||
text = `first line
|
||||
Cum sociis natoque penatibus et magnis dis parturient. Phasellus laoreet lorem vel dolor tempus vehicula. Vivamus sagittis lacus vel augue laoreet rutrum faucibus. Integer legentibus erat a ante historiarum dapibus.
|
||||
Quisque ut dolor gravida, placerat libero vel, euismod. Quam temere in vitiis, legem sancimus haerentia. Qui ipsorum lingua Celtae, nostra Galli appellantur. Quis aute iure reprehenderit in voluptate velit esse. Fabio vel iudice vincam, sunt in culpa qui officia. Cras mattis iudicium purus sit amet fermentum.
|
||||
Quo usque tandem abutere, Catilina, patientia nostra? Gallia est omnis divisa in partes tres, quarum. Quam diu etiam furor iste tuus nos eludet? Quid securi etiam tamquam eu fugiat nulla pariatur. Curabitur blandit tempus ardua ridiculus sed magna.
|
||||
Quo usque tandem abutere, Catilina, patientia nostra? Gallia est omnis divisa in partes tres, quarum. Quam diu etiam furor iste tuus nos eludet? Quid securi etiam tamquam eu fugiat nulla pariatur. Curabitur blandit tempus ardua ridiculous sed magna.
|
||||
Magna pars studiorum, prodita quaerimus. Cum ceteris in veneratione tui montes, nascetur mus. Morbi odio eros, volutpat ut pharetra vitae, lobortis sed nibh. Plura mihi bona sunt, inclinet, amari petere vellent. Idque Caesaris facere voluntate liceret: sese habere. Tu quoque, Brute, fili mi, nihil timor populi, nihil!
|
||||
Tityre, tu patulae recubans sub tegmine fagi dolor. Inmensae subtilitatis, obscuris et malesuada fames. Quae vero auctorem tractata ab fiducia dicuntur.
|
||||
Cum sociis natoque penatibus et magnis dis parturient. Phasellus laoreet lorem vel dolor tempus vehicula. Vivamus sagittis lacus vel augue laoreet rutrum faucibus. Integer legentibus erat a ante historiarum dapibus.
|
||||
Quisque ut dolor gravida, placerat libero vel, euismod. Quam temere in vitiis, legem sancimus haerentia. Qui ipsorum lingua Celtae, nostra Galli appellantur. Quis aute iure reprehenderit in voluptate velit esse. Fabio vel iudice vincam, sunt in culpa qui officia. Cras mattis iudicium purus sit amet fermentum.
|
||||
Quo usque tandem abutere, Catilina, patientia nostra? Gallia est omnis divisa in partes tres, quarum. Quam diu etiam furor iste tuus nos eludet? Quid securi etiam tamquam eu fugiat nulla pariatur. Curabitur blandit tempus ardua ridiculus sed magna.
|
||||
Quo usque tandem abutere, Catilina, patientia nostra? Gallia est omnis divisa in partes tres, quarum. Quam diu etiam furor iste tuus nos eludet? Quid securi etiam tamquam eu fugiat nulla pariatur. Curabitur blandit tempus ardua ridiculous sed magna.
|
||||
Magna pars studiorum, prodita quaerimus. Cum ceteris in veneratione tui montes, nascetur mus. Morbi odio eros, volutpat ut pharetra vitae, lobortis sed nibh. Plura mihi bona sunt, inclinet, amari petere vellent. Idque Caesaris facere voluntate liceret: sese habere. Tu quoque, Brute, fili mi, nihil timor populi, nihil!
|
||||
Tityre, tu patulae recubans sub tegmine fagi dolor. Inmensae subtilitatis, obscuris et malesuada fames. Quae vero auctorem tractata ab fiducia dicuntur.
|
||||
Cum sociis natoque penatibus et magnis dis parturient. Phasellus laoreet lorem vel dolor tempus vehicula. Vivamus sagittis lacus vel augue laoreet rutrum faucibus. Integer legentibus erat a ante historiarum dapibus.
|
||||
Quisque ut dolor gravida, placerat libero vel, euismod. Quam temere in vitiis, legem sancimus haerentia. Qui ipsorum lingua Celtae, nostra Galli appellantur. Quis aute iure reprehenderit in voluptate velit esse. Fabio vel iudice vincam, sunt in culpa qui officia. Cras mattis iudicium purus sit amet fermentum.
|
||||
Quo usque tandem abutere, Catilina, patientia nostra? Gallia est omnis divisa in partes tres, quarum. Quam diu etiam furor iste tuus nos eludet? Quid securi etiam tamquam eu fugiat nulla pariatur. Curabitur blandit tempus ardua ridiculus sed magna.
|
||||
Quo usque tandem abutere, Catilina, patientia nostra? Gallia est omnis divisa in partes tres, quarum. Quam diu etiam furor iste tuus nos eludet? Quid securi etiam tamquam eu fugiat nulla pariatur. Curabitur blandit tempus ardua ridiculous sed magna.
|
||||
Magna pars studiorum, prodita quaerimus. Cum ceteris in veneratione tui montes, nascetur mus. Morbi odio eros, volutpat ut pharetra vitae, lobortis sed nibh. Plura mihi bona sunt, inclinet, amari petere vellent. Idque Caesaris facere voluntate liceret: sese habere. Tu quoque, Brute, fili mi, nihil timor populi, nihil!
|
||||
Tityre, tu patulae recubans sub tegmine fagi dolor. Inmensae subtilitatis, obscuris et malesuada fames. Quae vero auctorem tractata ab fiducia dicuntur.
|
||||
` + longLine
|
||||
textWithLastNewline = `first line
|
||||
Cum sociis natoque penatibus et magnis dis parturient. Phasellus laoreet lorem vel dolor tempus vehicula. Vivamus sagittis lacus vel augue laoreet rutrum faucibus. Integer legentibus erat a ante historiarum dapibus.
|
||||
Quisque ut dolor gravida, placerat libero vel, euismod. Quam temere in vitiis, legem sancimus haerentia. Qui ipsorum lingua Celtae, nostra Galli appellantur. Quis aute iure reprehenderit in voluptate velit esse. Fabio vel iudice vincam, sunt in culpa qui officia. Cras mattis iudicium purus sit amet fermentum.
|
||||
Quo usque tandem abutere, Catilina, patientia nostra? Gallia est omnis divisa in partes tres, quarum. Quam diu etiam furor iste tuus nos eludet? Quid securi etiam tamquam eu fugiat nulla pariatur. Curabitur blandit tempus ardua ridiculus sed magna.
|
||||
Quo usque tandem abutere, Catilina, patientia nostra? Gallia est omnis divisa in partes tres, quarum. Quam diu etiam furor iste tuus nos eludet? Quid securi etiam tamquam eu fugiat nulla pariatur. Curabitur blandit tempus ardua ridiculous sed magna.
|
||||
Magna pars studiorum, prodita quaerimus. Cum ceteris in veneratione tui montes, nascetur mus. Morbi odio eros, volutpat ut pharetra vitae, lobortis sed nibh. Plura mihi bona sunt, inclinet, amari petere vellent. Idque Caesaris facere voluntate liceret: sese habere. Tu quoque, Brute, fili mi, nihil timor populi, nihil!
|
||||
Tityre, tu patulae recubans sub tegmine fagi dolor. Inmensae subtilitatis, obscuris et malesuada fames. Quae vero auctorem tractata ab fiducia dicuntur.
|
||||
Cum sociis natoque penatibus et magnis dis parturient. Phasellus laoreet lorem vel dolor tempus vehicula. Vivamus sagittis lacus vel augue laoreet rutrum faucibus. Integer legentibus erat a ante historiarum dapibus.
|
||||
Quisque ut dolor gravida, placerat libero vel, euismod. Quam temere in vitiis, legem sancimus haerentia. Qui ipsorum lingua Celtae, nostra Galli appellantur. Quis aute iure reprehenderit in voluptate velit esse. Fabio vel iudice vincam, sunt in culpa qui officia. Cras mattis iudicium purus sit amet fermentum.
|
||||
Quo usque tandem abutere, Catilina, patientia nostra? Gallia est omnis divisa in partes tres, quarum. Quam diu etiam furor iste tuus nos eludet? Quid securi etiam tamquam eu fugiat nulla pariatur. Curabitur blandit tempus ardua ridiculus sed magna.
|
||||
Quo usque tandem abutere, Catilina, patientia nostra? Gallia est omnis divisa in partes tres, quarum. Quam diu etiam furor iste tuus nos eludet? Quid securi etiam tamquam eu fugiat nulla pariatur. Curabitur blandit tempus ardua ridiculous sed magna.
|
||||
Magna pars studiorum, prodita quaerimus. Cum ceteris in veneratione tui montes, nascetur mus. Morbi odio eros, volutpat ut pharetra vitae, lobortis sed nibh. Plura mihi bona sunt, inclinet, amari petere vellent. Idque Caesaris facere voluntate liceret: sese habere. Tu quoque, Brute, fili mi, nihil timor populi, nihil!
|
||||
Tityre, tu patulae recubans sub tegmine fagi dolor. Inmensae subtilitatis, obscuris et malesuada fames. Quae vero auctorem tractata ab fiducia dicuntur.
|
||||
Cum sociis natoque penatibus et magnis dis parturient. Phasellus laoreet lorem vel dolor tempus vehicula. Vivamus sagittis lacus vel augue laoreet rutrum faucibus. Integer legentibus erat a ante historiarum dapibus.
|
||||
Quisque ut dolor gravida, placerat libero vel, euismod. Quam temere in vitiis, legem sancimus haerentia. Qui ipsorum lingua Celtae, nostra Galli appellantur. Quis aute iure reprehenderit in voluptate velit esse. Fabio vel iudice vincam, sunt in culpa qui officia. Cras mattis iudicium purus sit amet fermentum.
|
||||
Quo usque tandem abutere, Catilina, patientia nostra? Gallia est omnis divisa in partes tres, quarum. Quam diu etiam furor iste tuus nos eludet? Quid securi etiam tamquam eu fugiat nulla pariatur. Curabitur blandit tempus ardua ridiculus sed magna.
|
||||
Quo usque tandem abutere, Catilina, patientia nostra? Gallia est omnis divisa in partes tres, quarum. Quam diu etiam furor iste tuus nos eludet? Quid securi etiam tamquam eu fugiat nulla pariatur. Curabitur blandit tempus ardua ridiculous sed magna.
|
||||
Magna pars studiorum, prodita quaerimus. Cum ceteris in veneratione tui montes, nascetur mus. Morbi odio eros, volutpat ut pharetra vitae, lobortis sed nibh. Plura mihi bona sunt, inclinet, amari petere vellent. Idque Caesaris facere voluntate liceret: sese habere. Tu quoque, Brute, fili mi, nihil timor populi, nihil!
|
||||
Tityre, tu patulae recubans sub tegmine fagi dolor. Inmensae subtilitatis, obscuris et malesuada fames. Quae vero auctorem tractata ab fiducia dicuntur.
|
||||
` + longLine + "\n"
|
||||
|
||||
@@ -49,7 +49,7 @@ func From(generate GenerateFunc) Stream {
|
||||
return Range(source)
|
||||
}
|
||||
|
||||
// Just converts the given arbitary items to a Stream.
|
||||
// Just converts the given arbitrary items to a Stream.
|
||||
func Just(items ...interface{}) Stream {
|
||||
source := make(chan interface{}, len(items))
|
||||
for _, item := range items {
|
||||
@@ -195,7 +195,7 @@ func (p Stream) Merge() Stream {
|
||||
return Range(source)
|
||||
}
|
||||
|
||||
// Parallel applies the given ParallenFunc to each item concurrently with given number of workers.
|
||||
// Parallel applies the given ParallelFunc to each item concurrently with given number of workers.
|
||||
func (p Stream) Parallel(fn ParallelFunc, opts ...Option) {
|
||||
p.Walk(func(item interface{}, pipe chan<- interface{}) {
|
||||
fn(item)
|
||||
|
||||
@@ -9,24 +9,24 @@ import (
|
||||
|
||||
var ErrNoAvailablePusher = errors.New("no available pusher")
|
||||
|
||||
type BalancedQueuePusher struct {
|
||||
type BalancedPusher struct {
|
||||
name string
|
||||
pushers []Pusher
|
||||
index uint64
|
||||
}
|
||||
|
||||
func NewBalancedQueuePusher(pushers []Pusher) Pusher {
|
||||
return &BalancedQueuePusher{
|
||||
func NewBalancedPusher(pushers []Pusher) Pusher {
|
||||
return &BalancedPusher{
|
||||
name: generateName(pushers),
|
||||
pushers: pushers,
|
||||
}
|
||||
}
|
||||
|
||||
func (pusher *BalancedQueuePusher) Name() string {
|
||||
func (pusher *BalancedPusher) Name() string {
|
||||
return pusher.name
|
||||
}
|
||||
|
||||
func (pusher *BalancedQueuePusher) Push(message string) error {
|
||||
func (pusher *BalancedPusher) Push(message string) error {
|
||||
size := len(pusher.pushers)
|
||||
|
||||
for i := 0; i < size; i++ {
|
||||
@@ -20,7 +20,7 @@ func TestBalancedQueuePusher(t *testing.T) {
|
||||
mockedPushers = append(mockedPushers, p)
|
||||
}
|
||||
|
||||
pusher := NewBalancedQueuePusher(pushers)
|
||||
pusher := NewBalancedPusher(pushers)
|
||||
assert.True(t, len(pusher.Name()) > 0)
|
||||
|
||||
for i := 0; i < numPushers*1000; i++ {
|
||||
@@ -37,7 +37,7 @@ func TestBalancedQueuePusher(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestBalancedQueuePusher_NoAvailable(t *testing.T) {
|
||||
pusher := NewBalancedQueuePusher(nil)
|
||||
pusher := NewBalancedPusher(nil)
|
||||
assert.True(t, len(pusher.Name()) == 0)
|
||||
assert.Equal(t, ErrNoAvailablePusher, pusher.Push("item"))
|
||||
}
|
||||
@@ -2,23 +2,23 @@ package queue
|
||||
|
||||
import "github.com/tal-tech/go-zero/core/errorx"
|
||||
|
||||
type MultiQueuePusher struct {
|
||||
type MultiPusher struct {
|
||||
name string
|
||||
pushers []Pusher
|
||||
}
|
||||
|
||||
func NewMultiQueuePusher(pushers []Pusher) Pusher {
|
||||
return &MultiQueuePusher{
|
||||
func NewMultiPusher(pushers []Pusher) Pusher {
|
||||
return &MultiPusher{
|
||||
name: generateName(pushers),
|
||||
pushers: pushers,
|
||||
}
|
||||
}
|
||||
|
||||
func (pusher *MultiQueuePusher) Name() string {
|
||||
func (pusher *MultiPusher) Name() string {
|
||||
return pusher.name
|
||||
}
|
||||
|
||||
func (pusher *MultiQueuePusher) Push(message string) error {
|
||||
func (pusher *MultiPusher) Push(message string) error {
|
||||
var batchError errorx.BatchError
|
||||
|
||||
for _, each := range pusher.pushers {
|
||||
@@ -21,7 +21,7 @@ func TestMultiQueuePusher(t *testing.T) {
|
||||
mockedPushers = append(mockedPushers, p)
|
||||
}
|
||||
|
||||
pusher := NewMultiQueuePusher(pushers)
|
||||
pusher := NewMultiPusher(pushers)
|
||||
assert.True(t, len(pusher.Name()) > 0)
|
||||
|
||||
for i := 0; i < 1000; i++ {
|
||||
@@ -14,7 +14,6 @@ import (
|
||||
"github.com/tal-tech/go-zero/core/proc"
|
||||
"github.com/tal-tech/go-zero/core/sysx"
|
||||
"github.com/tal-tech/go-zero/core/timex"
|
||||
"github.com/tal-tech/go-zero/core/utils"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -24,7 +23,7 @@ const (
|
||||
)
|
||||
|
||||
var (
|
||||
reporter = utils.Report
|
||||
reporter func(string)
|
||||
lock sync.RWMutex
|
||||
lessExecutor = executors.NewLessExecutor(time.Minute * 5)
|
||||
dropped int32
|
||||
|
||||
@@ -212,10 +212,12 @@ func TestRedis_Persist(t *testing.T) {
|
||||
assert.Nil(t, err)
|
||||
assert.False(t, ok)
|
||||
err = client.Expire("key", 5)
|
||||
assert.Nil(t, err)
|
||||
ok, err = client.Persist("key")
|
||||
assert.Nil(t, err)
|
||||
assert.True(t, ok)
|
||||
err = client.Expireat("key", time.Now().Unix()+5)
|
||||
assert.Nil(t, err)
|
||||
ok, err = client.Persist("key")
|
||||
assert.Nil(t, err)
|
||||
assert.True(t, ok)
|
||||
@@ -379,7 +381,7 @@ func TestRedis_SortedSet(t *testing.T) {
|
||||
rank, err := client.Zrank("key", "value2")
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, int64(1), rank)
|
||||
rank, err = client.Zrank("key", "value4")
|
||||
_, err = client.Zrank("key", "value4")
|
||||
assert.Equal(t, redis.Nil, err)
|
||||
num, err := client.Zrem("key", "value2", "value3")
|
||||
assert.Nil(t, err)
|
||||
|
||||
@@ -249,10 +249,12 @@ func TestRedis_Persist(t *testing.T) {
|
||||
assert.Nil(t, err)
|
||||
assert.False(t, ok)
|
||||
err = client.Expire("key", 5)
|
||||
assert.Nil(t, err)
|
||||
ok, err = client.Persist("key")
|
||||
assert.Nil(t, err)
|
||||
assert.True(t, ok)
|
||||
err = client.Expireat("key", time.Now().Unix()+5)
|
||||
assert.Nil(t, err)
|
||||
ok, err = client.Persist("key")
|
||||
assert.Nil(t, err)
|
||||
assert.True(t, ok)
|
||||
@@ -447,7 +449,7 @@ func TestRedis_SortedSet(t *testing.T) {
|
||||
rank, err := client.Zrank("key", "value2")
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, int64(1), rank)
|
||||
rank, err = client.Zrank("key", "value4")
|
||||
_, err = client.Zrank("key", "value4")
|
||||
assert.Equal(t, Nil, err)
|
||||
num, err := client.Zrem("key", "value2", "value3")
|
||||
assert.Nil(t, err)
|
||||
@@ -558,6 +560,7 @@ func TestRedis_Pipelined(t *testing.T) {
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, "1", value)
|
||||
score, err := client.Zscore("zadd", "zadd")
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, int64(12), score)
|
||||
})
|
||||
}
|
||||
|
||||
@@ -63,11 +63,6 @@ func (r *replacer) Replace(text string) string {
|
||||
i = j - 1
|
||||
builder.WriteString(r.mapping[string(chars[start:end])])
|
||||
} else {
|
||||
if j < size {
|
||||
end = j + 1
|
||||
} else {
|
||||
end = size
|
||||
}
|
||||
builder.WriteRune(chars[i])
|
||||
}
|
||||
start = -1
|
||||
|
||||
@@ -2,7 +2,11 @@ package stringx
|
||||
|
||||
import "github.com/tal-tech/go-zero/core/lang"
|
||||
|
||||
const defaultMask = '*'
|
||||
|
||||
type (
|
||||
TrieOption func(trie *trieNode)
|
||||
|
||||
Trie interface {
|
||||
Filter(text string) (string, []string, bool)
|
||||
FindKeywords(text string) []string
|
||||
@@ -10,6 +14,7 @@ type (
|
||||
|
||||
trieNode struct {
|
||||
node
|
||||
mask rune
|
||||
}
|
||||
|
||||
scope struct {
|
||||
@@ -18,8 +23,15 @@ type (
|
||||
}
|
||||
)
|
||||
|
||||
func NewTrie(words []string) Trie {
|
||||
func NewTrie(words []string, opts ...TrieOption) Trie {
|
||||
n := new(trieNode)
|
||||
|
||||
for _, opt := range opts {
|
||||
opt(n)
|
||||
}
|
||||
if n.mask == 0 {
|
||||
n.mask = defaultMask
|
||||
}
|
||||
for _, word := range words {
|
||||
n.add(word)
|
||||
}
|
||||
@@ -114,6 +126,12 @@ func (n *trieNode) findKeywordScopes(chars []rune) []scope {
|
||||
|
||||
func (n *trieNode) replaceWithAsterisk(chars []rune, start, stop int) {
|
||||
for i := start; i < stop; i++ {
|
||||
chars[i] = '*'
|
||||
chars[i] = n.mask
|
||||
}
|
||||
}
|
||||
|
||||
func WithMask(mask rune) TrieOption {
|
||||
return func(n *trieNode) {
|
||||
n.mask = mask
|
||||
}
|
||||
}
|
||||
|
||||
@@ -109,25 +109,25 @@ func TestTrie(t *testing.T) {
|
||||
func TestTrieSingleWord(t *testing.T) {
|
||||
trie := NewTrie([]string{
|
||||
"闹",
|
||||
})
|
||||
}, WithMask('#'))
|
||||
output, keywords, ok := trie.Filter("今晚真热闹")
|
||||
assert.ElementsMatch(t, []string{"闹"}, keywords)
|
||||
assert.True(t, ok)
|
||||
assert.Equal(t, "今晚真热*", output)
|
||||
assert.Equal(t, "今晚真热#", output)
|
||||
}
|
||||
|
||||
func TestTrieOverlap(t *testing.T) {
|
||||
trie := NewTrie([]string{
|
||||
"一二三四五",
|
||||
"二三四五六七八",
|
||||
})
|
||||
}, WithMask('#'))
|
||||
output, keywords, ok := trie.Filter("零一二三四五六七八九十")
|
||||
assert.ElementsMatch(t, []string{
|
||||
"一二三四五",
|
||||
"二三四五六七八",
|
||||
}, keywords)
|
||||
assert.True(t, ok)
|
||||
assert.Equal(t, "零********九十", output)
|
||||
assert.Equal(t, "零########九十", output)
|
||||
}
|
||||
|
||||
func TestTrieNested(t *testing.T) {
|
||||
@@ -135,7 +135,7 @@ func TestTrieNested(t *testing.T) {
|
||||
"一二三",
|
||||
"一二三四五",
|
||||
"一二三四五六七八",
|
||||
})
|
||||
}, WithMask('#'))
|
||||
output, keywords, ok := trie.Filter("零一二三四五六七八九十")
|
||||
assert.ElementsMatch(t, []string{
|
||||
"一二三",
|
||||
@@ -143,7 +143,7 @@ func TestTrieNested(t *testing.T) {
|
||||
"一二三四五六七八",
|
||||
}, keywords)
|
||||
assert.True(t, ok)
|
||||
assert.Equal(t, "零********九十", output)
|
||||
assert.Equal(t, "零########九十", output)
|
||||
}
|
||||
|
||||
func BenchmarkTrie(b *testing.B) {
|
||||
|
||||
@@ -1,5 +0,0 @@
|
||||
package utils
|
||||
|
||||
func Report(content string) {
|
||||
// TODO: implement the report method
|
||||
}
|
||||
46
doc/goctl.md
46
doc/goctl.md
@@ -3,8 +3,8 @@
|
||||
## goctl用途
|
||||
* 定义api请求
|
||||
* 根据定义的api自动生成golang(后端), java(iOS & Android), typescript(web & 晓程序),dart(flutter)
|
||||
* 生成MySQL CURD (https://goctl.xiaoheiban.cn)
|
||||
* 生成MongoDB CURD (https://goctl.xiaoheiban.cn)
|
||||
* 生成MySQL CURD+Cache
|
||||
* 生成MongoDB CURD+Cache
|
||||
|
||||
## goctl使用说明
|
||||
#### goctl参数说明
|
||||
@@ -179,23 +179,38 @@ service user-api {
|
||||
* 在定义的get/post/put/delete等请求的handler和logic里增加处理业务逻辑的代码
|
||||
|
||||
#### 根据定义好的api文件生成java代码
|
||||
`goctl api java -api user/user.api -dir ./src`
|
||||
```shell
|
||||
goctl api java -api user/user.api -dir ./src
|
||||
```
|
||||
|
||||
#### 根据定义好的api文件生成typescript代码
|
||||
`goctl api ts -api user/user.api -dir ./src -webapi ***`
|
||||
|
||||
ts需要指定webapi所在目录
|
||||
```shell
|
||||
goctl api ts -api user/user.api -dir ./src -webapi ***
|
||||
|
||||
ts需要指定webapi所在目录
|
||||
```
|
||||
|
||||
#### 根据定义好的api文件生成Dart代码
|
||||
`goctl api dart -api user/user.api -dir ./src`
|
||||
```shell
|
||||
goctl api dart -api user/user.api -dir ./src
|
||||
```
|
||||
|
||||
## 根据mysql ddl或者datasource生成model文件
|
||||
|
||||
```shell script
|
||||
$ goctl model mysql -src={filename} -dir={dir} -cache={true|false}
|
||||
```
|
||||
详情参考[model文档](https://github.com/tal-tech/go-zero/blob/master/tools/goctl/model/sql/README.MD)
|
||||
|
||||
## 根据定义好的简单go文件生成mongo代码文件(仅限golang使用)
|
||||
`goctl model mongo -src {{yourDir}}/xiao/service/xhb/user/model/usermodel.go -cache yes`
|
||||
|
||||
-src需要提供简单的usermodel.go文件,里面只需要提供一个结构体即可
|
||||
-cache 控制是否需要缓存 yes=需要 no=不需要
|
||||
src 示例代码如下
|
||||
```
|
||||
```shell
|
||||
goctl model mongo -src {{yourDir}}/xiao/service/xhb/user/model/usermodel.go -cache yes
|
||||
|
||||
-src需要提供简单的usermodel.go文件,里面只需要提供一个结构体即可
|
||||
-cache 控制是否需要缓存 yes=需要 no=不需要
|
||||
src 示例代码如下
|
||||
```
|
||||
```go
|
||||
package model
|
||||
|
||||
type User struct {
|
||||
@@ -210,7 +225,7 @@ type User struct {
|
||||
o是改字段需要生产的操作函数 可以取得get,find,set 分别表示生成返回单个对象的查询方法,返回多个对象的查询方法,设置该字段方法
|
||||
生成的目标文件会覆盖该简单go文件
|
||||
|
||||
## goctl rpc生成
|
||||
## goctl rpc生成(业务剥离中,暂未开放)
|
||||
|
||||
命令 `goctl rpc proto -proto ${proto} -service ${serviceName} -project ${projectName} -dir ${directory} -shared ${shared}`
|
||||
如: `goctl rpc proto -proto test.proto -service test -project xjy -dir .`
|
||||
@@ -261,5 +276,4 @@ type User struct {
|
||||
│ └── test.go [强制覆盖更新]
|
||||
└── test.proto
|
||||
```
|
||||
- 注意 :目前rpc目录生成的proto文件暂不支持import外部proto文件
|
||||
* 如有不理解的地方,随时问Kim/Kevin
|
||||
- 注意 :目前rpc目录生成的proto文件暂不支持import外部proto文件
|
||||
BIN
doc/images/architecture.png
Normal file
BIN
doc/images/architecture.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 333 KiB |
BIN
doc/images/benchmark.png
Normal file
BIN
doc/images/benchmark.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 26 KiB |
BIN
doc/images/trie.png
Normal file
BIN
doc/images/trie.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 91 KiB |
86
doc/keywords.md
Normal file
86
doc/keywords.md
Normal file
@@ -0,0 +1,86 @@
|
||||
# 高效的关键词替换和敏感词过滤工具
|
||||
|
||||
## 1. 算法介绍
|
||||
|
||||
利用高效的Trie树建立关键词树,如下图所示,然后依次查找字符串中的相连字符是否形成树的一条路径
|
||||
|
||||
<img src="images/trie.png" alt="trie" width="350" />
|
||||
|
||||
发现掘金上[这篇文章](https://juejin.im/post/6844903750490914829)写的比较详细,可以一读,具体原理在此不详述。
|
||||
|
||||
## 2. 关键词替换
|
||||
|
||||
支持关键词重叠,自动选用最长的关键词,代码示例如下:
|
||||
|
||||
```go
|
||||
replacer := stringx.NewReplacer(map[string]string{
|
||||
"日本": "法国",
|
||||
"日本的首都": "东京",
|
||||
"东京": "日本的首都",
|
||||
})
|
||||
fmt.Println(replacer.Replace("日本的首都是东京"))
|
||||
```
|
||||
|
||||
可以得到:
|
||||
```
|
||||
东京是日本的首都
|
||||
```
|
||||
|
||||
示例代码见`example/stringx/replace/replace.go`
|
||||
|
||||
## 3. 查找敏感词
|
||||
|
||||
代码示例如下:
|
||||
|
||||
```go
|
||||
filter := stringx.NewTrie([]string{
|
||||
"AV演员",
|
||||
"苍井空",
|
||||
"AV",
|
||||
"日本AV女优",
|
||||
"AV演员色情",
|
||||
})
|
||||
keywords := filter.FindKeywords("日本AV演员兼电视、电影演员。苍井空AV女优是xx出道, 日本AV女优们最精彩的表演是AV演员色情表演")
|
||||
fmt.Println(keywords)
|
||||
```
|
||||
|
||||
可以得到:
|
||||
|
||||
```
|
||||
[苍井空 日本AV女优 AV演员色情 AV AV演员]
|
||||
```
|
||||
|
||||
## 4. 敏感词过滤
|
||||
|
||||
代码示例如下:
|
||||
|
||||
```go
|
||||
filter := stringx.NewTrie([]string{
|
||||
"AV演员",
|
||||
"苍井空",
|
||||
"AV",
|
||||
"日本AV女优",
|
||||
"AV演员色情",
|
||||
}, stringx.WithMask('?')) // 默认替换为*
|
||||
safe, keywords, found := filter.Filter("日本AV演员兼电视、电影演员。苍井空AV女优是xx出道, 日本AV女优们最精彩的表演是AV演员色情表演")
|
||||
fmt.Println(safe)
|
||||
fmt.Println(keywords)
|
||||
fmt.Println(found)
|
||||
```
|
||||
|
||||
可以得到:
|
||||
|
||||
```
|
||||
日本????兼电视、电影演员。?????女优是xx出道, ??????们最精彩的表演是??????表演
|
||||
[苍井空 日本AV女优 AV演员色情 AV AV演员]
|
||||
true
|
||||
```
|
||||
|
||||
示例代码见`example/stringx/filter/filter.go`
|
||||
|
||||
## 5. Benchmark
|
||||
|
||||
| Sentences | Keywords | Regex | go-zero |
|
||||
| --------- | -------- | -------- | ------- |
|
||||
| 10000 | 10000 | 16min10s | 27.2ms |
|
||||
|
||||
@@ -15,7 +15,7 @@ import (
|
||||
)
|
||||
|
||||
var (
|
||||
freq = flag.Int("freq", 100, "frequence")
|
||||
freq = flag.Int("freq", 100, "frequency")
|
||||
duration = flag.String("duration", "10s", "duration")
|
||||
)
|
||||
|
||||
@@ -84,8 +84,8 @@ func (m *metric) reset() counting {
|
||||
return result
|
||||
}
|
||||
|
||||
func runRequests(url string, frequence int, metrics *metric, done <-chan lang.PlaceholderType) {
|
||||
ticker := time.NewTicker(time.Second / time.Duration(frequence))
|
||||
func runRequests(url string, frequency int, metrics *metric, done <-chan lang.PlaceholderType) {
|
||||
ticker := time.NewTicker(time.Second / time.Duration(frequency))
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
|
||||
@@ -8,11 +8,11 @@ import (
|
||||
)
|
||||
|
||||
func main() {
|
||||
exeutor := executors.NewBulkExecutor(func(items []interface{}) {
|
||||
executor := executors.NewBulkExecutor(func(items []interface{}) {
|
||||
fmt.Println(len(items))
|
||||
}, executors.WithBulkTasks(10))
|
||||
for {
|
||||
exeutor.Add(1)
|
||||
executor.Add(1)
|
||||
time.Sleep(time.Millisecond * 90)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2,7 +2,9 @@ package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"flag"
|
||||
"fmt"
|
||||
"log"
|
||||
"time"
|
||||
|
||||
"github.com/tal-tech/go-zero/core/discov"
|
||||
@@ -10,13 +12,31 @@ import (
|
||||
"github.com/tal-tech/go-zero/rpcx"
|
||||
)
|
||||
|
||||
var lb = flag.String("t", "direct", "the load balancer type")
|
||||
|
||||
func main() {
|
||||
cli := rpcx.MustNewClient(rpcx.RpcClientConf{
|
||||
Etcd: discov.EtcdConf{
|
||||
Hosts: []string{"localhost:2379"},
|
||||
Key: "rpcx",
|
||||
},
|
||||
})
|
||||
flag.Parse()
|
||||
|
||||
var cli rpcx.Client
|
||||
switch *lb {
|
||||
case "direct":
|
||||
cli = rpcx.MustNewClient(rpcx.RpcClientConf{
|
||||
Endpoints: []string{
|
||||
"localhost:3456",
|
||||
"localhost:3457",
|
||||
},
|
||||
})
|
||||
case "discov":
|
||||
cli = rpcx.MustNewClient(rpcx.RpcClientConf{
|
||||
Etcd: discov.EtcdConf{
|
||||
Hosts: []string{"localhost:2379"},
|
||||
Key: "rpcx",
|
||||
},
|
||||
})
|
||||
default:
|
||||
log.Fatal("bad load balancing type")
|
||||
}
|
||||
|
||||
greet := unary.NewGreeterClient(cli.Conn())
|
||||
ticker := time.NewTicker(time.Second)
|
||||
defer ticker.Stop()
|
||||
|
||||
21
example/stringx/filter/filter.go
Normal file
21
example/stringx/filter/filter.go
Normal file
@@ -0,0 +1,21 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/tal-tech/go-zero/core/stringx"
|
||||
)
|
||||
|
||||
func main() {
|
||||
filter := stringx.NewTrie([]string{
|
||||
"AV演员",
|
||||
"苍井空",
|
||||
"AV",
|
||||
"日本AV女优",
|
||||
"AV演员色情",
|
||||
}, stringx.WithMask('?'))
|
||||
safe, keywords, found := filter.Filter("日本AV演员兼电视、电影演员。苍井空AV女优是xx出道, 日本AV女优们最精彩的表演是AV演员色情表演")
|
||||
fmt.Println(safe)
|
||||
fmt.Println(keywords)
|
||||
fmt.Println(found)
|
||||
}
|
||||
16
example/stringx/replace/replace.go
Normal file
16
example/stringx/replace/replace.go
Normal file
@@ -0,0 +1,16 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/tal-tech/go-zero/core/stringx"
|
||||
)
|
||||
|
||||
func main() {
|
||||
replacer := stringx.NewReplacer(map[string]string{
|
||||
"日本": "法国",
|
||||
"日本的首都": "东京",
|
||||
"东京": "日本的首都",
|
||||
})
|
||||
fmt.Println(replacer.Replace("日本的首都是东京"))
|
||||
}
|
||||
3
go.mod
3
go.mod
@@ -13,6 +13,7 @@ require (
|
||||
github.com/globalsign/mgo v0.0.0-20181015135952-eeefdecb41b8
|
||||
github.com/go-redis/redis v6.15.7+incompatible
|
||||
github.com/go-sql-driver/mysql v1.5.0
|
||||
github.com/go-xorm/builder v0.3.4
|
||||
github.com/gogo/protobuf v1.3.1 // indirect
|
||||
github.com/golang/groupcache v0.0.0-20200121045136-8c9f03a8e57e // indirect
|
||||
github.com/golang/mock v1.4.3
|
||||
@@ -22,6 +23,7 @@ require (
|
||||
github.com/google/uuid v1.1.1
|
||||
github.com/gorilla/websocket v1.4.2 // indirect
|
||||
github.com/grpc-ecosystem/grpc-gateway v1.14.3 // indirect
|
||||
github.com/iancoleman/strcase v0.0.0-20191112232945-16388991a334
|
||||
github.com/justinas/alice v1.2.0
|
||||
github.com/kardianos/osext v0.0.0-20190222173326-2bc1f35cddc0 // indirect
|
||||
github.com/konsorten/go-windows-terminal-sequences v1.0.2 // indirect
|
||||
@@ -41,6 +43,7 @@ require (
|
||||
github.com/stretchr/testify v1.5.1
|
||||
github.com/tmc/grpc-websocket-proxy v0.0.0-20171017195756-830351dc03c6 // indirect
|
||||
github.com/urfave/cli v1.22.4
|
||||
github.com/xwb1989/sqlparser v0.0.0-20180606152119-120387863bf2
|
||||
github.com/yuin/gopher-lua v0.0.0-20191220021717-ab39c6098bdb // indirect
|
||||
go.etcd.io/etcd v0.0.0-20200402134248-51bdeb39e698
|
||||
go.uber.org/automaxprocs v1.3.0
|
||||
|
||||
8
go.sum
8
go.sum
@@ -76,6 +76,10 @@ github.com/go-sql-driver/mysql v1.5.0 h1:ozyZYNQW3x3HtqT1jira07DN2PArx2v7/mN66gG
|
||||
github.com/go-sql-driver/mysql v1.5.0/go.mod h1:DCzpHaOWr8IXmIStZouvnhqoel9Qv2LBy8hT2VhHyBg=
|
||||
github.com/go-stack/stack v1.8.0 h1:5SgMzNM5HxrEjV0ww2lTmX6E2Izsfxas4+YHWRs3Lsk=
|
||||
github.com/go-stack/stack v1.8.0/go.mod h1:v0f6uXyyMGvRgIKkXu+yp6POWl0qKG85gN/melR3HDY=
|
||||
github.com/go-xorm/builder v0.3.4 h1:FxkeGB4Cggdw3tPwutLCpfjng2jugfkg6LDMrd/KsoY=
|
||||
github.com/go-xorm/builder v0.3.4/go.mod h1:KxkQkNN1DpPKTedxXyTQcmH+rXfvk4LZ9SOOBoZBAxw=
|
||||
github.com/go-xorm/sqlfiddle v0.0.0-20180821085327-62ce714f951a h1:9wScpmSP5A3Bk8V3XHWUcJmYTh+ZnlHVyc+A4oZYS3Y=
|
||||
github.com/go-xorm/sqlfiddle v0.0.0-20180821085327-62ce714f951a/go.mod h1:56xuuqnHyryaerycW3BfssRdxQstACi0Epw/yC5E2xM=
|
||||
github.com/godbus/dbus/v5 v5.0.3/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA=
|
||||
github.com/gogo/protobuf v1.1.1/go.mod h1:r8qH/GZQm5c6nD/R0oafs1akxWv10x8SbQlK7atdtwQ=
|
||||
github.com/gogo/protobuf v1.2.1/go.mod h1:hp+jE20tsWTFYpLwKvXlhS1hjn+gTNwPg2I6zVXpSg4=
|
||||
@@ -134,6 +138,8 @@ github.com/grpc-ecosystem/grpc-gateway v1.14.3 h1:OCJlWkOUoTnl0neNGlf4fUm3TmbEtg
|
||||
github.com/grpc-ecosystem/grpc-gateway v1.14.3/go.mod h1:6CwZWGDSPRJidgKAtJVvND6soZe6fT7iteq8wDPdhb0=
|
||||
github.com/hpcloud/tail v1.0.0 h1:nfCOvKYfkgYP8hkirhJocXT2+zOD8yUNjXaWfTlyFKI=
|
||||
github.com/hpcloud/tail v1.0.0/go.mod h1:ab1qPbhIpdTxEkNHXyeSf5vhxWSCs/tWer42PpOxQnU=
|
||||
github.com/iancoleman/strcase v0.0.0-20191112232945-16388991a334 h1:VHgatEHNcBFEB7inlalqfNqw65aNkM1lGX2yt3NmbS8=
|
||||
github.com/iancoleman/strcase v0.0.0-20191112232945-16388991a334/go.mod h1:SK73tn/9oHe+/Y0h39VT4UCxmurVJkR5NA7kMEAOgSE=
|
||||
github.com/inconshreveable/mousetrap v1.0.0/go.mod h1:PxqpIevigyE2G7u3NXJIT2ANytuPF1OarO4DADm73n8=
|
||||
github.com/jonboulle/clockwork v0.1.0 h1:VKV+ZcuP6l3yW9doeqz6ziZGgcynBVQO+obU0+0hcPo=
|
||||
github.com/jonboulle/clockwork v0.1.0/go.mod h1:Ii8DK3G1RaLaWxj9trq07+26W01tbo22gdxWY5EU2bo=
|
||||
@@ -266,6 +272,8 @@ github.com/xiang90/probing v0.0.0-20190116061207-43a291ad63a2 h1:eY9dn8+vbi4tKz5
|
||||
github.com/xiang90/probing v0.0.0-20190116061207-43a291ad63a2/go.mod h1:UETIi67q53MR2AWcXfiuqkDkRtnGDLqkBTpCHuJHxtU=
|
||||
github.com/xlab/treeprint v0.0.0-20180616005107-d6fb6747feb6 h1:YdYsPAZ2pC6Tow/nPZOPQ96O3hm/ToAkGsPLzedXERk=
|
||||
github.com/xlab/treeprint v0.0.0-20180616005107-d6fb6747feb6/go.mod h1:ce1O1j6UtZfjr22oyGxGLbauSBp2YVXpARAosm7dHBg=
|
||||
github.com/xwb1989/sqlparser v0.0.0-20180606152119-120387863bf2 h1:zzrxE1FKn5ryBNl9eKOeqQ58Y/Qpo3Q9QNxKHX5uzzQ=
|
||||
github.com/xwb1989/sqlparser v0.0.0-20180606152119-120387863bf2/go.mod h1:hzfGeIUDq/j97IG+FhNqkowIyEcD88LrW6fyU3K3WqY=
|
||||
github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74=
|
||||
github.com/yuin/gopher-lua v0.0.0-20191220021717-ab39c6098bdb h1:ZkM6LRnq40pR1Ox0hTHlnpkcOTuFIDQpZ1IN8rKKhX0=
|
||||
github.com/yuin/gopher-lua v0.0.0-20191220021717-ab39c6098bdb/go.mod h1:gqRgreBUhTSL0GeU64rtZ3Uq3wtjOa/TB2YfrtkCbVQ=
|
||||
|
||||
82
readme.md
82
readme.md
@@ -1,6 +1,27 @@
|
||||
# go-zero项目介绍
|
||||
# go-zero
|
||||
|
||||

|
||||
[](https://github.com/tal-tech/go-zero/actions)
|
||||
[](https://codecov.io/gh/tal-tech/go-zero)
|
||||
[](https://goreportcard.com/report/github.com/tal-tech/go-zero)
|
||||
[](https://github.com/tal-tech/go-zero)
|
||||
[](https://opensource.org/licenses/MIT)
|
||||
|
||||
## 0. go-zero介绍
|
||||
|
||||
go-zero是一个集成了各种工程实践的web和rpc框架。通过弹性设计保障了大并发服务端的稳定性,经受了充分的实战检验。
|
||||
|
||||
go-zero 包含极简的 API 定义和生成工具 goctl,可以根据定义的 api 文件一键生成 Go, iOS, Android, Kotlin, Dart, TypeScript, JavaScript 代码,并可直接运行。
|
||||
|
||||
使用go-zero的好处:
|
||||
|
||||
* 轻松获得支撑千万日活服务的稳定性
|
||||
* 内建级联超时控制、限流、自适应熔断、自适应降载等微服务治理能力,无需配置和额外代码
|
||||
* 微服务治理中间件可无缝集成到其它现有框架使用
|
||||
* 极简的API描述,一键生成各端代码
|
||||
* 自动校验客户端请求参数合法性
|
||||
* 大量微服务治理和并发工具包
|
||||
|
||||
<img src="doc/images/architecture.png" alt="架构图" width="1500" />
|
||||
|
||||
## 1. go-zero框架背景
|
||||
|
||||
@@ -53,33 +74,20 @@ go-zero是一个集成了各种工程实践的包含web和rpc框架,有如下
|
||||
|
||||

|
||||
|
||||
## 4. go-zero框架收益
|
||||
|
||||
* 保障大并发服务端的稳定性,经受了充分的实战检验
|
||||
* 极简的API定义
|
||||
* 一键生成Go, iOS, Android, Dart, TypeScript, JavaScript代码,并可直接运行
|
||||
* 服务端自动校验参数合法性
|
||||
|
||||
## 5. go-zero近期开发计划
|
||||
## 4. go-zero近期开发计划
|
||||
|
||||
* 自动生成API mock server,便于客户端开发
|
||||
* 自动生成服务端功能测试
|
||||
|
||||
## 6. Installation
|
||||
## 5. Installation
|
||||
|
||||
1. 在项目目录下通过如下命令安装:
|
||||
在项目目录下通过如下命令安装:
|
||||
|
||||
```shell
|
||||
go get -u github.com/tal-tech/go-zero
|
||||
```
|
||||
```shell
|
||||
go get -u github.com/tal-tech/go-zero
|
||||
```
|
||||
|
||||
2. 代码里导入go-zero
|
||||
|
||||
```go
|
||||
import "github.com/tal-tech/go-zero"
|
||||
```
|
||||
|
||||
## 7. Quick Start
|
||||
## 6. Quick Start
|
||||
|
||||
1. 编译goctl工具
|
||||
|
||||
@@ -93,7 +101,7 @@ go-zero是一个集成了各种工程实践的包含web和rpc框架,有如下
|
||||
|
||||
```go
|
||||
type Request struct {
|
||||
Name string `path:"name"`
|
||||
Name string `path:"name,options=you|me"` // 框架自动验证请求参数是否合法
|
||||
}
|
||||
|
||||
type Response struct {
|
||||
@@ -123,7 +131,6 @@ go-zero是一个集成了各种工程实践的包含web和rpc框架,有如下
|
||||
生成的文件结构如下:
|
||||
|
||||
```
|
||||
.
|
||||
├── greet
|
||||
│ ├── etc
|
||||
│ │ └── greet-api.json // 配置文件
|
||||
@@ -141,26 +148,24 @@ go-zero是一个集成了各种工程实践的包含web和rpc框架,有如下
|
||||
│ └── types
|
||||
│ └── types.go // 请求、返回等类型定义
|
||||
└── greet.api // api描述文件
|
||||
|
||||
8 directories, 9 files
|
||||
```
|
||||
生成的代码可以直接运行:
|
||||
|
||||
```shell
|
||||
cd greet
|
||||
go run greet.go -f etc/greet-api.json
|
||||
```
|
||||
|
||||
```
|
||||
|
||||
默认侦听在8888端口(可以在配置文件里修改),可以通过curl请求:
|
||||
|
||||
|
||||
```shell
|
||||
➜ go-zero git:(master) curl -w "\ncode: %{http_code}\n" http://localhost:8888/greet/from/kevin
|
||||
{"code":0}
|
||||
code: 200
|
||||
```
|
||||
|
||||
```
|
||||
|
||||
编写业务代码:
|
||||
|
||||
|
||||
* 可以在servicecontext.go里面传递依赖给logic,比如mysql, redis等
|
||||
* 在api定义的get/post/put/delete等请求对应的logic里增加业务处理逻辑
|
||||
|
||||
@@ -172,6 +177,17 @@ go-zero是一个集成了各种工程实践的包含web和rpc框架,有如下
|
||||
...
|
||||
```
|
||||
|
||||
### 微信交流群
|
||||
## 7. Benchmark
|
||||
|
||||

|
||||
|
||||
[测试代码见这里](https://github.com/smallnest/go-web-framework-benchmark)
|
||||
|
||||
## 8. 文档 (逐步完善中)
|
||||
|
||||
* [goctl使用帮助](doc/goctl.md)
|
||||
* [关键字替换和敏感词过滤工具](doc/keywords.md)
|
||||
|
||||
## 9. 微信交流群
|
||||
|
||||
添加我的微信:kevwan,请注明go-zero,我拉进go-zero社区群🤝
|
||||
|
||||
214
rest/engine.go
Normal file
214
rest/engine.go
Normal file
@@ -0,0 +1,214 @@
|
||||
package rest
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/justinas/alice"
|
||||
"github.com/tal-tech/go-zero/core/codec"
|
||||
"github.com/tal-tech/go-zero/core/load"
|
||||
"github.com/tal-tech/go-zero/core/stat"
|
||||
"github.com/tal-tech/go-zero/rest/handler"
|
||||
"github.com/tal-tech/go-zero/rest/httpx"
|
||||
"github.com/tal-tech/go-zero/rest/internal"
|
||||
"github.com/tal-tech/go-zero/rest/router"
|
||||
)
|
||||
|
||||
// use 1000m to represent 100%
|
||||
const topCpuUsage = 1000
|
||||
|
||||
var ErrSignatureConfig = errors.New("bad config for Signature")
|
||||
|
||||
type engine struct {
|
||||
conf RestConf
|
||||
routes []featuredRoutes
|
||||
unauthorizedCallback handler.UnauthorizedCallback
|
||||
unsignedCallback handler.UnsignedCallback
|
||||
middlewares []Middleware
|
||||
shedder load.Shedder
|
||||
priorityShedder load.Shedder
|
||||
}
|
||||
|
||||
func newEngine(c RestConf) *engine {
|
||||
srv := &engine{
|
||||
conf: c,
|
||||
}
|
||||
if c.CpuThreshold > 0 {
|
||||
srv.shedder = load.NewAdaptiveShedder(load.WithCpuThreshold(c.CpuThreshold))
|
||||
srv.priorityShedder = load.NewAdaptiveShedder(load.WithCpuThreshold(
|
||||
(c.CpuThreshold + topCpuUsage) >> 1))
|
||||
}
|
||||
|
||||
return srv
|
||||
}
|
||||
|
||||
func (s *engine) AddRoutes(r featuredRoutes) {
|
||||
s.routes = append(s.routes, r)
|
||||
}
|
||||
|
||||
func (s *engine) SetUnauthorizedCallback(callback handler.UnauthorizedCallback) {
|
||||
s.unauthorizedCallback = callback
|
||||
}
|
||||
|
||||
func (s *engine) SetUnsignedCallback(callback handler.UnsignedCallback) {
|
||||
s.unsignedCallback = callback
|
||||
}
|
||||
|
||||
func (s *engine) Start() error {
|
||||
return s.StartWithRouter(router.NewPatRouter())
|
||||
}
|
||||
|
||||
func (s *engine) StartWithRouter(router httpx.Router) error {
|
||||
if err := s.bindRoutes(router); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return internal.StartHttp(s.conf.Host, s.conf.Port, router)
|
||||
}
|
||||
|
||||
func (s *engine) appendAuthHandler(fr featuredRoutes, chain alice.Chain,
|
||||
verifier func(alice.Chain) alice.Chain) alice.Chain {
|
||||
if fr.jwt.enabled {
|
||||
if len(fr.jwt.prevSecret) == 0 {
|
||||
chain = chain.Append(handler.Authorize(fr.jwt.secret,
|
||||
handler.WithUnauthorizedCallback(s.unauthorizedCallback)))
|
||||
} else {
|
||||
chain = chain.Append(handler.Authorize(fr.jwt.secret,
|
||||
handler.WithPrevSecret(fr.jwt.prevSecret),
|
||||
handler.WithUnauthorizedCallback(s.unauthorizedCallback)))
|
||||
}
|
||||
}
|
||||
|
||||
return verifier(chain)
|
||||
}
|
||||
|
||||
func (s *engine) bindFeaturedRoutes(router httpx.Router, fr featuredRoutes, metrics *stat.Metrics) error {
|
||||
verifier, err := s.signatureVerifier(fr.signature)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, route := range fr.routes {
|
||||
if err := s.bindRoute(fr, router, metrics, route, verifier); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *engine) bindRoute(fr featuredRoutes, router httpx.Router, metrics *stat.Metrics,
|
||||
route Route, verifier func(chain alice.Chain) alice.Chain) error {
|
||||
chain := alice.New(
|
||||
handler.TracingHandler,
|
||||
s.getLogHandler(),
|
||||
handler.MaxConns(s.conf.MaxConns),
|
||||
handler.BreakerHandler(route.Method, route.Path, metrics),
|
||||
handler.SheddingHandler(s.getShedder(fr.priority), metrics),
|
||||
handler.TimeoutHandler(time.Duration(s.conf.Timeout)*time.Millisecond),
|
||||
handler.RecoverHandler,
|
||||
handler.MetricHandler(metrics),
|
||||
handler.PromMetricHandler(route.Path),
|
||||
handler.MaxBytesHandler(s.conf.MaxBytes),
|
||||
handler.GunzipHandler,
|
||||
)
|
||||
chain = s.appendAuthHandler(fr, chain, verifier)
|
||||
|
||||
for _, middleware := range s.middlewares {
|
||||
chain = chain.Append(convertMiddleware(middleware))
|
||||
}
|
||||
handle := chain.ThenFunc(route.Handler)
|
||||
|
||||
return router.Handle(route.Method, route.Path, handle)
|
||||
}
|
||||
|
||||
func (s *engine) bindRoutes(router httpx.Router) error {
|
||||
metrics := s.createMetrics()
|
||||
|
||||
for _, fr := range s.routes {
|
||||
if err := s.bindFeaturedRoutes(router, fr, metrics); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *engine) createMetrics() *stat.Metrics {
|
||||
var metrics *stat.Metrics
|
||||
|
||||
if len(s.conf.Name) > 0 {
|
||||
metrics = stat.NewMetrics(s.conf.Name)
|
||||
} else {
|
||||
metrics = stat.NewMetrics(fmt.Sprintf("%s:%d", s.conf.Host, s.conf.Port))
|
||||
}
|
||||
|
||||
return metrics
|
||||
}
|
||||
|
||||
func (s *engine) getLogHandler() func(http.Handler) http.Handler {
|
||||
if s.conf.Verbose {
|
||||
return handler.DetailedLogHandler
|
||||
} else {
|
||||
return handler.LogHandler
|
||||
}
|
||||
}
|
||||
|
||||
func (s *engine) getShedder(priority bool) load.Shedder {
|
||||
if priority && s.priorityShedder != nil {
|
||||
return s.priorityShedder
|
||||
}
|
||||
return s.shedder
|
||||
}
|
||||
|
||||
func (s *engine) signatureVerifier(signature signatureSetting) (func(chain alice.Chain) alice.Chain, error) {
|
||||
if !signature.enabled {
|
||||
return func(chain alice.Chain) alice.Chain {
|
||||
return chain
|
||||
}, nil
|
||||
}
|
||||
|
||||
if len(signature.PrivateKeys) == 0 {
|
||||
if signature.Strict {
|
||||
return nil, ErrSignatureConfig
|
||||
} else {
|
||||
return func(chain alice.Chain) alice.Chain {
|
||||
return chain
|
||||
}, nil
|
||||
}
|
||||
}
|
||||
|
||||
decrypters := make(map[string]codec.RsaDecrypter)
|
||||
for _, key := range signature.PrivateKeys {
|
||||
fingerprint := key.Fingerprint
|
||||
file := key.KeyFile
|
||||
decrypter, err := codec.NewRsaDecrypter(file)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
decrypters[fingerprint] = decrypter
|
||||
}
|
||||
|
||||
return func(chain alice.Chain) alice.Chain {
|
||||
if s.unsignedCallback != nil {
|
||||
return chain.Append(handler.ContentSecurityHandler(
|
||||
decrypters, signature.Expiry, signature.Strict, s.unsignedCallback))
|
||||
} else {
|
||||
return chain.Append(handler.ContentSecurityHandler(
|
||||
decrypters, signature.Expiry, signature.Strict))
|
||||
}
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *engine) use(middleware Middleware) {
|
||||
s.middlewares = append(s.middlewares, middleware)
|
||||
}
|
||||
|
||||
func convertMiddleware(ware Middleware) func(http.Handler) http.Handler {
|
||||
return func(next http.Handler) http.Handler {
|
||||
return ware(next.ServeHTTP)
|
||||
}
|
||||
}
|
||||
@@ -217,6 +217,7 @@ func TestContentSecurityHandler(t *testing.T) {
|
||||
signature: test.signature,
|
||||
}
|
||||
req, err := buildRequest(setting)
|
||||
assert.Nil(t, err)
|
||||
resp := httptest.NewRecorder()
|
||||
handler.ServeHTTP(resp, req)
|
||||
assert.Equal(t, test.statusCode, resp.Code)
|
||||
@@ -249,6 +250,7 @@ func TestContentSecurityHandler_UnsignedCallback(t *testing.T) {
|
||||
signature: "badone",
|
||||
}
|
||||
req, err := buildRequest(setting)
|
||||
assert.Nil(t, err)
|
||||
resp := httptest.NewRecorder()
|
||||
handler.ServeHTTP(resp, req)
|
||||
assert.Equal(t, http.StatusOK, resp.Code)
|
||||
@@ -285,6 +287,7 @@ func TestContentSecurityHandler_UnsignedCallback_WrongTime(t *testing.T) {
|
||||
fingerprint: fingerprint,
|
||||
}
|
||||
req, err := buildRequest(setting)
|
||||
assert.Nil(t, err)
|
||||
resp := httptest.NewRecorder()
|
||||
handler.ServeHTTP(resp, req)
|
||||
assert.Equal(t, http.StatusOK, resp.Code)
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package httpx
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net/http"
|
||||
"strings"
|
||||
"testing"
|
||||
@@ -17,6 +18,24 @@ func init() {
|
||||
logx.Disable()
|
||||
}
|
||||
|
||||
func TestError(t *testing.T) {
|
||||
const body = "foo"
|
||||
w := tracedResponseWriter{
|
||||
headers: make(map[string][]string),
|
||||
}
|
||||
Error(&w, errors.New(body))
|
||||
assert.Equal(t, http.StatusBadRequest, w.code)
|
||||
assert.Equal(t, body, strings.TrimSpace(w.builder.String()))
|
||||
}
|
||||
|
||||
func TestOk(t *testing.T) {
|
||||
w := tracedResponseWriter{
|
||||
headers: make(map[string][]string),
|
||||
}
|
||||
Ok(&w)
|
||||
assert.Equal(t, http.StatusOK, w.code)
|
||||
}
|
||||
|
||||
func TestOkJson(t *testing.T) {
|
||||
w := tracedResponseWriter{
|
||||
headers: make(map[string][]string),
|
||||
|
||||
170
rest/ngin.go
170
rest/ngin.go
@@ -1,170 +0,0 @@
|
||||
package rest
|
||||
|
||||
import (
|
||||
"log"
|
||||
"net/http"
|
||||
|
||||
"github.com/tal-tech/go-zero/core/logx"
|
||||
"github.com/tal-tech/go-zero/rest/handler"
|
||||
"github.com/tal-tech/go-zero/rest/httpx"
|
||||
)
|
||||
|
||||
type (
|
||||
runOptions struct {
|
||||
start func(*engine) error
|
||||
}
|
||||
|
||||
RunOption func(*Server)
|
||||
|
||||
Server struct {
|
||||
ngin *engine
|
||||
opts runOptions
|
||||
}
|
||||
)
|
||||
|
||||
func MustNewServer(c RestConf, opts ...RunOption) *Server {
|
||||
engine, err := NewServer(c, opts...)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
return engine
|
||||
}
|
||||
|
||||
func NewServer(c RestConf, opts ...RunOption) (*Server, error) {
|
||||
if err := c.SetUp(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
server := &Server{
|
||||
ngin: newEngine(c),
|
||||
opts: runOptions{
|
||||
start: func(srv *engine) error {
|
||||
return srv.Start()
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, opt := range opts {
|
||||
opt(server)
|
||||
}
|
||||
|
||||
return server, nil
|
||||
}
|
||||
|
||||
func (e *Server) AddRoutes(rs []Route, opts ...RouteOption) {
|
||||
r := featuredRoutes{
|
||||
routes: rs,
|
||||
}
|
||||
for _, opt := range opts {
|
||||
opt(&r)
|
||||
}
|
||||
e.ngin.AddRoutes(r)
|
||||
}
|
||||
|
||||
func (e *Server) AddRoute(r Route, opts ...RouteOption) {
|
||||
e.AddRoutes([]Route{r}, opts...)
|
||||
}
|
||||
|
||||
func (e *Server) Start() {
|
||||
handleError(e.opts.start(e.ngin))
|
||||
}
|
||||
|
||||
func (e *Server) Stop() {
|
||||
logx.Close()
|
||||
}
|
||||
|
||||
func (e *Server) Use(middleware Middleware) {
|
||||
e.ngin.use(middleware)
|
||||
}
|
||||
|
||||
func ToMiddleware(handler func(next http.Handler) http.Handler) Middleware {
|
||||
return func(handle http.HandlerFunc) http.HandlerFunc {
|
||||
return handler(handle).ServeHTTP
|
||||
}
|
||||
}
|
||||
|
||||
func WithJwt(secret string) RouteOption {
|
||||
return func(r *featuredRoutes) {
|
||||
validateSecret(secret)
|
||||
r.jwt.enabled = true
|
||||
r.jwt.secret = secret
|
||||
}
|
||||
}
|
||||
|
||||
func WithJwtTransition(secret, prevSecret string) RouteOption {
|
||||
return func(r *featuredRoutes) {
|
||||
// why not validate prevSecret, because prevSecret is an already used one,
|
||||
// even it not meet our requirement, we still need to allow the transition.
|
||||
validateSecret(secret)
|
||||
r.jwt.enabled = true
|
||||
r.jwt.secret = secret
|
||||
r.jwt.prevSecret = prevSecret
|
||||
}
|
||||
}
|
||||
|
||||
func WithMiddleware(middleware Middleware, rs ...Route) []Route {
|
||||
routes := make([]Route, len(rs))
|
||||
|
||||
for i := range rs {
|
||||
route := rs[i]
|
||||
routes[i] = Route{
|
||||
Method: route.Method,
|
||||
Path: route.Path,
|
||||
Handler: middleware(route.Handler),
|
||||
}
|
||||
}
|
||||
|
||||
return routes
|
||||
}
|
||||
|
||||
func WithPriority() RouteOption {
|
||||
return func(r *featuredRoutes) {
|
||||
r.priority = true
|
||||
}
|
||||
}
|
||||
|
||||
func WithRouter(router httpx.Router) RunOption {
|
||||
return func(server *Server) {
|
||||
server.opts.start = func(srv *engine) error {
|
||||
return srv.StartWithRouter(router)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func WithSignature(signature SignatureConf) RouteOption {
|
||||
return func(r *featuredRoutes) {
|
||||
r.signature.enabled = true
|
||||
r.signature.Strict = signature.Strict
|
||||
r.signature.Expiry = signature.Expiry
|
||||
r.signature.PrivateKeys = signature.PrivateKeys
|
||||
}
|
||||
}
|
||||
|
||||
func WithUnauthorizedCallback(callback handler.UnauthorizedCallback) RunOption {
|
||||
return func(engine *Server) {
|
||||
engine.ngin.SetUnauthorizedCallback(callback)
|
||||
}
|
||||
}
|
||||
|
||||
func WithUnsignedCallback(callback handler.UnsignedCallback) RunOption {
|
||||
return func(engine *Server) {
|
||||
engine.ngin.SetUnsignedCallback(callback)
|
||||
}
|
||||
}
|
||||
|
||||
func handleError(err error) {
|
||||
// ErrServerClosed means the server is closed manually
|
||||
if err == nil || err == http.ErrServerClosed {
|
||||
return
|
||||
}
|
||||
|
||||
logx.Error(err)
|
||||
panic(err)
|
||||
}
|
||||
|
||||
func validateSecret(secret string) {
|
||||
if len(secret) < 8 {
|
||||
panic("secret's length can't be less than 8")
|
||||
}
|
||||
}
|
||||
316
rest/server.go
316
rest/server.go
@@ -1,214 +1,170 @@
|
||||
package rest
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"log"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/justinas/alice"
|
||||
"github.com/tal-tech/go-zero/core/codec"
|
||||
"github.com/tal-tech/go-zero/core/load"
|
||||
"github.com/tal-tech/go-zero/core/stat"
|
||||
"github.com/tal-tech/go-zero/core/logx"
|
||||
"github.com/tal-tech/go-zero/rest/handler"
|
||||
"github.com/tal-tech/go-zero/rest/httpx"
|
||||
"github.com/tal-tech/go-zero/rest/internal"
|
||||
"github.com/tal-tech/go-zero/rest/router"
|
||||
)
|
||||
|
||||
// use 1000m to represent 100%
|
||||
const topCpuUsage = 1000
|
||||
|
||||
var ErrSignatureConfig = errors.New("bad config for Signature")
|
||||
|
||||
type engine struct {
|
||||
conf RestConf
|
||||
routes []featuredRoutes
|
||||
unauthorizedCallback handler.UnauthorizedCallback
|
||||
unsignedCallback handler.UnsignedCallback
|
||||
middlewares []Middleware
|
||||
shedder load.Shedder
|
||||
priorityShedder load.Shedder
|
||||
}
|
||||
|
||||
func newEngine(c RestConf) *engine {
|
||||
srv := &engine{
|
||||
conf: c,
|
||||
}
|
||||
if c.CpuThreshold > 0 {
|
||||
srv.shedder = load.NewAdaptiveShedder(load.WithCpuThreshold(c.CpuThreshold))
|
||||
srv.priorityShedder = load.NewAdaptiveShedder(load.WithCpuThreshold(
|
||||
(c.CpuThreshold + topCpuUsage) >> 1))
|
||||
type (
|
||||
runOptions struct {
|
||||
start func(*engine) error
|
||||
}
|
||||
|
||||
return srv
|
||||
}
|
||||
RunOption func(*Server)
|
||||
|
||||
func (s *engine) AddRoutes(r featuredRoutes) {
|
||||
s.routes = append(s.routes, r)
|
||||
}
|
||||
|
||||
func (s *engine) SetUnauthorizedCallback(callback handler.UnauthorizedCallback) {
|
||||
s.unauthorizedCallback = callback
|
||||
}
|
||||
|
||||
func (s *engine) SetUnsignedCallback(callback handler.UnsignedCallback) {
|
||||
s.unsignedCallback = callback
|
||||
}
|
||||
|
||||
func (s *engine) Start() error {
|
||||
return s.StartWithRouter(router.NewPatRouter())
|
||||
}
|
||||
|
||||
func (s *engine) StartWithRouter(router httpx.Router) error {
|
||||
if err := s.bindRoutes(router); err != nil {
|
||||
return err
|
||||
Server struct {
|
||||
ngin *engine
|
||||
opts runOptions
|
||||
}
|
||||
)
|
||||
|
||||
return internal.StartHttp(s.conf.Host, s.conf.Port, router)
|
||||
}
|
||||
|
||||
func (s *engine) appendAuthHandler(fr featuredRoutes, chain alice.Chain,
|
||||
verifier func(alice.Chain) alice.Chain) alice.Chain {
|
||||
if fr.jwt.enabled {
|
||||
if len(fr.jwt.prevSecret) == 0 {
|
||||
chain = chain.Append(handler.Authorize(fr.jwt.secret,
|
||||
handler.WithUnauthorizedCallback(s.unauthorizedCallback)))
|
||||
} else {
|
||||
chain = chain.Append(handler.Authorize(fr.jwt.secret,
|
||||
handler.WithPrevSecret(fr.jwt.prevSecret),
|
||||
handler.WithUnauthorizedCallback(s.unauthorizedCallback)))
|
||||
}
|
||||
}
|
||||
|
||||
return verifier(chain)
|
||||
}
|
||||
|
||||
func (s *engine) bindFeaturedRoutes(router httpx.Router, fr featuredRoutes, metrics *stat.Metrics) error {
|
||||
verifier, err := s.signatureVerifier(fr.signature)
|
||||
func MustNewServer(c RestConf, opts ...RunOption) *Server {
|
||||
engine, err := NewServer(c, opts...)
|
||||
if err != nil {
|
||||
return err
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
for _, route := range fr.routes {
|
||||
if err := s.bindRoute(fr, router, metrics, route, verifier); err != nil {
|
||||
return err
|
||||
return engine
|
||||
}
|
||||
|
||||
func NewServer(c RestConf, opts ...RunOption) (*Server, error) {
|
||||
if err := c.SetUp(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
server := &Server{
|
||||
ngin: newEngine(c),
|
||||
opts: runOptions{
|
||||
start: func(srv *engine) error {
|
||||
return srv.Start()
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, opt := range opts {
|
||||
opt(server)
|
||||
}
|
||||
|
||||
return server, nil
|
||||
}
|
||||
|
||||
func (e *Server) AddRoutes(rs []Route, opts ...RouteOption) {
|
||||
r := featuredRoutes{
|
||||
routes: rs,
|
||||
}
|
||||
for _, opt := range opts {
|
||||
opt(&r)
|
||||
}
|
||||
e.ngin.AddRoutes(r)
|
||||
}
|
||||
|
||||
func (e *Server) AddRoute(r Route, opts ...RouteOption) {
|
||||
e.AddRoutes([]Route{r}, opts...)
|
||||
}
|
||||
|
||||
func (e *Server) Start() {
|
||||
handleError(e.opts.start(e.ngin))
|
||||
}
|
||||
|
||||
func (e *Server) Stop() {
|
||||
logx.Close()
|
||||
}
|
||||
|
||||
func (e *Server) Use(middleware Middleware) {
|
||||
e.ngin.use(middleware)
|
||||
}
|
||||
|
||||
func ToMiddleware(handler func(next http.Handler) http.Handler) Middleware {
|
||||
return func(handle http.HandlerFunc) http.HandlerFunc {
|
||||
return handler(handle).ServeHTTP
|
||||
}
|
||||
}
|
||||
|
||||
func WithJwt(secret string) RouteOption {
|
||||
return func(r *featuredRoutes) {
|
||||
validateSecret(secret)
|
||||
r.jwt.enabled = true
|
||||
r.jwt.secret = secret
|
||||
}
|
||||
}
|
||||
|
||||
func WithJwtTransition(secret, prevSecret string) RouteOption {
|
||||
return func(r *featuredRoutes) {
|
||||
// why not validate prevSecret, because prevSecret is an already used one,
|
||||
// even it not meet our requirement, we still need to allow the transition.
|
||||
validateSecret(secret)
|
||||
r.jwt.enabled = true
|
||||
r.jwt.secret = secret
|
||||
r.jwt.prevSecret = prevSecret
|
||||
}
|
||||
}
|
||||
|
||||
func WithMiddleware(middleware Middleware, rs ...Route) []Route {
|
||||
routes := make([]Route, len(rs))
|
||||
|
||||
for i := range rs {
|
||||
route := rs[i]
|
||||
routes[i] = Route{
|
||||
Method: route.Method,
|
||||
Path: route.Path,
|
||||
Handler: middleware(route.Handler),
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
return routes
|
||||
}
|
||||
|
||||
func (s *engine) bindRoute(fr featuredRoutes, router httpx.Router, metrics *stat.Metrics,
|
||||
route Route, verifier func(chain alice.Chain) alice.Chain) error {
|
||||
chain := alice.New(
|
||||
handler.TracingHandler,
|
||||
s.getLogHandler(),
|
||||
handler.MaxConns(s.conf.MaxConns),
|
||||
handler.BreakerHandler(route.Method, route.Path, metrics),
|
||||
handler.SheddingHandler(s.getShedder(fr.priority), metrics),
|
||||
handler.TimeoutHandler(time.Duration(s.conf.Timeout)*time.Millisecond),
|
||||
handler.RecoverHandler,
|
||||
handler.MetricHandler(metrics),
|
||||
handler.PromMetricHandler(route.Path),
|
||||
handler.MaxBytesHandler(s.conf.MaxBytes),
|
||||
handler.GunzipHandler,
|
||||
)
|
||||
chain = s.appendAuthHandler(fr, chain, verifier)
|
||||
|
||||
for _, middleware := range s.middlewares {
|
||||
chain = chain.Append(convertMiddleware(middleware))
|
||||
func WithPriority() RouteOption {
|
||||
return func(r *featuredRoutes) {
|
||||
r.priority = true
|
||||
}
|
||||
handle := chain.ThenFunc(route.Handler)
|
||||
|
||||
return router.Handle(route.Method, route.Path, handle)
|
||||
}
|
||||
|
||||
func (s *engine) bindRoutes(router httpx.Router) error {
|
||||
metrics := s.createMetrics()
|
||||
|
||||
for _, fr := range s.routes {
|
||||
if err := s.bindFeaturedRoutes(router, fr, metrics); err != nil {
|
||||
return err
|
||||
func WithRouter(router httpx.Router) RunOption {
|
||||
return func(server *Server) {
|
||||
server.opts.start = func(srv *engine) error {
|
||||
return srv.StartWithRouter(router)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *engine) createMetrics() *stat.Metrics {
|
||||
var metrics *stat.Metrics
|
||||
|
||||
if len(s.conf.Name) > 0 {
|
||||
metrics = stat.NewMetrics(s.conf.Name)
|
||||
} else {
|
||||
metrics = stat.NewMetrics(fmt.Sprintf("%s:%d", s.conf.Host, s.conf.Port))
|
||||
}
|
||||
|
||||
return metrics
|
||||
}
|
||||
|
||||
func (s *engine) getLogHandler() func(http.Handler) http.Handler {
|
||||
if s.conf.Verbose {
|
||||
return handler.DetailedLogHandler
|
||||
} else {
|
||||
return handler.LogHandler
|
||||
func WithSignature(signature SignatureConf) RouteOption {
|
||||
return func(r *featuredRoutes) {
|
||||
r.signature.enabled = true
|
||||
r.signature.Strict = signature.Strict
|
||||
r.signature.Expiry = signature.Expiry
|
||||
r.signature.PrivateKeys = signature.PrivateKeys
|
||||
}
|
||||
}
|
||||
|
||||
func (s *engine) getShedder(priority bool) load.Shedder {
|
||||
if priority && s.priorityShedder != nil {
|
||||
return s.priorityShedder
|
||||
}
|
||||
return s.shedder
|
||||
}
|
||||
|
||||
func (s *engine) signatureVerifier(signature signatureSetting) (func(chain alice.Chain) alice.Chain, error) {
|
||||
if !signature.enabled {
|
||||
return func(chain alice.Chain) alice.Chain {
|
||||
return chain
|
||||
}, nil
|
||||
}
|
||||
|
||||
if len(signature.PrivateKeys) == 0 {
|
||||
if signature.Strict {
|
||||
return nil, ErrSignatureConfig
|
||||
} else {
|
||||
return func(chain alice.Chain) alice.Chain {
|
||||
return chain
|
||||
}, nil
|
||||
}
|
||||
}
|
||||
|
||||
decrypters := make(map[string]codec.RsaDecrypter)
|
||||
for _, key := range signature.PrivateKeys {
|
||||
fingerprint := key.Fingerprint
|
||||
file := key.KeyFile
|
||||
decrypter, err := codec.NewRsaDecrypter(file)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
decrypters[fingerprint] = decrypter
|
||||
}
|
||||
|
||||
return func(chain alice.Chain) alice.Chain {
|
||||
if s.unsignedCallback != nil {
|
||||
return chain.Append(handler.ContentSecurityHandler(
|
||||
decrypters, signature.Expiry, signature.Strict, s.unsignedCallback))
|
||||
} else {
|
||||
return chain.Append(handler.ContentSecurityHandler(
|
||||
decrypters, signature.Expiry, signature.Strict))
|
||||
}
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *engine) use(middleware Middleware) {
|
||||
s.middlewares = append(s.middlewares, middleware)
|
||||
}
|
||||
|
||||
func convertMiddleware(ware Middleware) func(http.Handler) http.Handler {
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(ware(next.ServeHTTP))
|
||||
func WithUnauthorizedCallback(callback handler.UnauthorizedCallback) RunOption {
|
||||
return func(engine *Server) {
|
||||
engine.ngin.SetUnauthorizedCallback(callback)
|
||||
}
|
||||
}
|
||||
|
||||
func WithUnsignedCallback(callback handler.UnsignedCallback) RunOption {
|
||||
return func(engine *Server) {
|
||||
engine.ngin.SetUnsignedCallback(callback)
|
||||
}
|
||||
}
|
||||
|
||||
func handleError(err error) {
|
||||
// ErrServerClosed means the server is closed manually
|
||||
if err == nil || err == http.ErrServerClosed {
|
||||
return
|
||||
}
|
||||
|
||||
logx.Error(err)
|
||||
panic(err)
|
||||
}
|
||||
|
||||
func validateSecret(secret string) {
|
||||
if len(secret) < 8 {
|
||||
panic("secret's length can't be less than 8")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -49,10 +49,10 @@ func NewClient(c RpcClientConf, options ...internal.ClientOption) (Client, error
|
||||
|
||||
var client Client
|
||||
var err error
|
||||
if len(c.Server) > 0 {
|
||||
client, err = internal.NewDirectClient(c.Server, opts...)
|
||||
if len(c.Endpoints) > 0 {
|
||||
client, err = internal.NewClient(internal.BuildDirectTarget(c.Endpoints), opts...)
|
||||
} else if err = c.Etcd.Validate(); err == nil {
|
||||
client, err = internal.NewDiscovClient(c.Etcd.Hosts, c.Etcd.Key, opts...)
|
||||
client, err = internal.NewClient(internal.BuildDiscovTarget(c.Etcd.Hosts, c.Etcd.Key), opts...)
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -64,7 +64,7 @@ func NewClient(c RpcClientConf, options ...internal.ClientOption) (Client, error
|
||||
}
|
||||
|
||||
func NewClientNoAuth(c discov.EtcdConf) (Client, error) {
|
||||
client, err := internal.NewDiscovClient(c.Hosts, c.Key)
|
||||
client, err := internal.NewClient(internal.BuildDiscovTarget(c.Hosts, c.Key))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -74,6 +74,10 @@ func NewClientNoAuth(c discov.EtcdConf) (Client, error) {
|
||||
}, nil
|
||||
}
|
||||
|
||||
func NewClientWithTarget(target string, opts ...internal.ClientOption) (Client, error) {
|
||||
return internal.NewClient(target, opts...)
|
||||
}
|
||||
|
||||
func (rc *RpcClient) Conn() *grpc.ClientConn {
|
||||
return rc.client.Conn()
|
||||
}
|
||||
|
||||
@@ -21,19 +21,19 @@ type (
|
||||
}
|
||||
|
||||
RpcClientConf struct {
|
||||
Etcd discov.EtcdConf `json:",optional"`
|
||||
Server string `json:",optional=!Etcd"`
|
||||
App string `json:",optional"`
|
||||
Token string `json:",optional"`
|
||||
Timeout int64 `json:",optional"`
|
||||
Etcd discov.EtcdConf `json:",optional"`
|
||||
Endpoints []string `json:",optional=!Etcd"`
|
||||
App string `json:",optional"`
|
||||
Token string `json:",optional"`
|
||||
Timeout int64 `json:",optional"`
|
||||
}
|
||||
)
|
||||
|
||||
func NewDirectClientConf(server, app, token string) RpcClientConf {
|
||||
func NewDirectClientConf(endpoints []string, app, token string) RpcClientConf {
|
||||
return RpcClientConf{
|
||||
Server: server,
|
||||
App: app,
|
||||
Token: token,
|
||||
Endpoints: endpoints,
|
||||
App: app,
|
||||
Token: token,
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
62
rpcx/internal/auth/credential_test.go
Normal file
62
rpcx/internal/auth/credential_test.go
Normal file
@@ -0,0 +1,62 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"google.golang.org/grpc/metadata"
|
||||
)
|
||||
|
||||
func TestParseCredential(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
withNil bool
|
||||
withEmptyMd bool
|
||||
app string
|
||||
token string
|
||||
}{
|
||||
{
|
||||
name: "nil",
|
||||
withNil: true,
|
||||
},
|
||||
{
|
||||
name: "empty md",
|
||||
withEmptyMd: true,
|
||||
},
|
||||
{
|
||||
name: "empty",
|
||||
},
|
||||
{
|
||||
name: "valid",
|
||||
app: "foo",
|
||||
token: "bar",
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
test := test
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
var ctx context.Context
|
||||
if test.withNil {
|
||||
ctx = context.Background()
|
||||
} else if test.withEmptyMd {
|
||||
ctx = metadata.NewIncomingContext(context.Background(), metadata.MD{})
|
||||
} else {
|
||||
md := metadata.New(map[string]string{
|
||||
"app": test.app,
|
||||
"token": test.token,
|
||||
})
|
||||
ctx = metadata.NewIncomingContext(context.Background(), md)
|
||||
}
|
||||
cred := ParseCredential(ctx)
|
||||
assert.False(t, cred.RequireTransportSecurity())
|
||||
m, err := cred.GetRequestMetadata(context.Background())
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, test.app, m[appKey])
|
||||
assert.Equal(t, test.token, m[tokenKey])
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -21,7 +21,7 @@ import (
|
||||
|
||||
const (
|
||||
Name = "p2c_ewma"
|
||||
decayTime = int64(time.Millisecond * 600)
|
||||
decayTime = int64(time.Second * 10) // default value from finagle
|
||||
forcePick = int64(time.Second)
|
||||
initSuccess = 1000
|
||||
throttleSuccess = initSuccess / 2
|
||||
|
||||
@@ -3,7 +3,9 @@ package p2c
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"runtime"
|
||||
"strconv"
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
@@ -33,19 +35,31 @@ func TestP2cPicker_Pick(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
candidates int
|
||||
threshold float64
|
||||
}{
|
||||
{
|
||||
name: "single",
|
||||
candidates: 1,
|
||||
threshold: 0.9,
|
||||
},
|
||||
{
|
||||
name: "two",
|
||||
candidates: 2,
|
||||
threshold: 0.5,
|
||||
},
|
||||
{
|
||||
name: "multiple",
|
||||
candidates: 100,
|
||||
threshold: 0.95,
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
test := test
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
const total = 10000
|
||||
builder := new(p2cPickerBuilder)
|
||||
ready := make(map[resolver.Address]balancer.SubConn)
|
||||
for i := 0; i < test.candidates; i++ {
|
||||
@@ -55,7 +69,9 @@ func TestP2cPicker_Pick(t *testing.T) {
|
||||
}
|
||||
|
||||
picker := builder.Build(ready)
|
||||
for i := 0; i < 10000; i++ {
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(total)
|
||||
for i := 0; i < total; i++ {
|
||||
_, done, err := picker.Pick(context.Background(), balancer.PickInfo{
|
||||
FullMethodName: "/",
|
||||
Ctx: context.Background(),
|
||||
@@ -64,11 +80,16 @@ func TestP2cPicker_Pick(t *testing.T) {
|
||||
if i%100 == 0 {
|
||||
err = status.Error(codes.DeadlineExceeded, "deadline")
|
||||
}
|
||||
done(balancer.DoneInfo{
|
||||
Err: err,
|
||||
})
|
||||
go func() {
|
||||
runtime.Gosched()
|
||||
done(balancer.DoneInfo{
|
||||
Err: err,
|
||||
})
|
||||
wg.Done()
|
||||
}()
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
dist := make(map[interface{}]int)
|
||||
conns := picker.(*p2cPicker).conns
|
||||
for _, conn := range conns {
|
||||
@@ -76,7 +97,8 @@ func TestP2cPicker_Pick(t *testing.T) {
|
||||
}
|
||||
|
||||
entropy := mathx.CalcEntropy(dist)
|
||||
assert.True(t, entropy > .95, fmt.Sprintf("entropy is %f, less than .95", entropy))
|
||||
assert.True(t, entropy > test.threshold, fmt.Sprintf("entropy is %f, less than %f",
|
||||
entropy, test.threshold))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
47
rpcx/internal/chainclientinterceptors_test.go
Normal file
47
rpcx/internal/chainclientinterceptors_test.go
Normal file
@@ -0,0 +1,47 @@
|
||||
package internal
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"google.golang.org/grpc"
|
||||
)
|
||||
|
||||
func TestWithStreamClientInterceptors(t *testing.T) {
|
||||
opts := WithStreamClientInterceptors()
|
||||
assert.NotNil(t, opts)
|
||||
}
|
||||
|
||||
func TestWithUnaryClientInterceptors(t *testing.T) {
|
||||
opts := WithUnaryClientInterceptors()
|
||||
assert.NotNil(t, opts)
|
||||
}
|
||||
|
||||
func TestChainStreamClientInterceptors_zero(t *testing.T) {
|
||||
interceptors := chainStreamClientInterceptors()
|
||||
_, err := interceptors(context.Background(), nil, new(grpc.ClientConn), "/foo",
|
||||
func(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string,
|
||||
opts ...grpc.CallOption) (grpc.ClientStream, error) {
|
||||
return nil, nil
|
||||
})
|
||||
assert.Nil(t, err)
|
||||
}
|
||||
|
||||
func TestChainStreamClientInterceptors_one(t *testing.T) {
|
||||
var called int32
|
||||
interceptors := chainStreamClientInterceptors(func(ctx context.Context, desc *grpc.StreamDesc,
|
||||
cc *grpc.ClientConn, method string, streamer grpc.Streamer, opts ...grpc.CallOption) (
|
||||
grpc.ClientStream, error) {
|
||||
atomic.AddInt32(&called, 1)
|
||||
return nil, nil
|
||||
})
|
||||
_, err := interceptors(context.Background(), nil, new(grpc.ClientConn), "/foo",
|
||||
func(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string,
|
||||
opts ...grpc.CallOption) (grpc.ClientStream, error) {
|
||||
return nil, nil
|
||||
})
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, int32(1), atomic.LoadInt32(&called))
|
||||
}
|
||||
@@ -5,12 +5,18 @@ import (
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/tal-tech/go-zero/rpcx/internal/balancer/p2c"
|
||||
"github.com/tal-tech/go-zero/rpcx/internal/clientinterceptors"
|
||||
"github.com/tal-tech/go-zero/rpcx/internal/resolver"
|
||||
"google.golang.org/grpc"
|
||||
)
|
||||
|
||||
const dialTimeout = time.Second * 3
|
||||
|
||||
func init() {
|
||||
resolver.RegisterResolver()
|
||||
}
|
||||
|
||||
type (
|
||||
ClientOptions struct {
|
||||
Timeout time.Duration
|
||||
@@ -18,8 +24,26 @@ type (
|
||||
}
|
||||
|
||||
ClientOption func(options *ClientOptions)
|
||||
|
||||
client struct {
|
||||
conn *grpc.ClientConn
|
||||
}
|
||||
)
|
||||
|
||||
func NewClient(target string, opts ...ClientOption) (*client, error) {
|
||||
opts = append(opts, WithDialOption(grpc.WithBalancerName(p2c.Name)))
|
||||
conn, err := dial(target, opts...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &client{conn: conn}, nil
|
||||
}
|
||||
|
||||
func (c *client) Conn() *grpc.ClientConn {
|
||||
return c.conn
|
||||
}
|
||||
|
||||
func WithDialOption(opt grpc.DialOption) ClientOption {
|
||||
return func(options *ClientOptions) {
|
||||
options.DialOptions = append(options.DialOptions, opt)
|
||||
|
||||
@@ -1,12 +1,15 @@
|
||||
package clientinterceptors
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/tal-tech/go-zero/core/breaker"
|
||||
"github.com/tal-tech/go-zero/core/stat"
|
||||
rcodes "github.com/tal-tech/go-zero/rpcx/internal/codes"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/status"
|
||||
)
|
||||
@@ -49,3 +52,30 @@ func TestBreakerInterceptorDeadlineExceeded(t *testing.T) {
|
||||
assert.True(t, errs[err] > 0)
|
||||
assert.True(t, errs[breaker.ErrServiceUnavailable] > 0)
|
||||
}
|
||||
|
||||
func TestBreakerInterceptor(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
err error
|
||||
}{
|
||||
{
|
||||
name: "nil",
|
||||
err: nil,
|
||||
},
|
||||
{
|
||||
name: "with error",
|
||||
err: errors.New("mock"),
|
||||
},
|
||||
}
|
||||
for _, test := range tests {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
cc := new(grpc.ClientConn)
|
||||
err := BreakerInterceptor(context.Background(), "/foo", nil, nil, cc,
|
||||
func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn,
|
||||
opts ...grpc.CallOption) error {
|
||||
return test.err
|
||||
})
|
||||
assert.Equal(t, test.err, err)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
37
rpcx/internal/clientinterceptors/durationinterceptor_test.go
Normal file
37
rpcx/internal/clientinterceptors/durationinterceptor_test.go
Normal file
@@ -0,0 +1,37 @@
|
||||
package clientinterceptors
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"google.golang.org/grpc"
|
||||
)
|
||||
|
||||
func TestDurationInterceptor(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
err error
|
||||
}{
|
||||
{
|
||||
name: "nil",
|
||||
err: nil,
|
||||
},
|
||||
{
|
||||
name: "with error",
|
||||
err: errors.New("mock"),
|
||||
},
|
||||
}
|
||||
for _, test := range tests {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
cc := new(grpc.ClientConn)
|
||||
err := DurationInterceptor(context.Background(), "/foo", nil, nil, cc,
|
||||
func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn,
|
||||
opts ...grpc.CallOption) error {
|
||||
return test.err
|
||||
})
|
||||
assert.Equal(t, test.err, err)
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,37 @@
|
||||
package clientinterceptors
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"google.golang.org/grpc"
|
||||
)
|
||||
|
||||
func TestPromMetricInterceptor(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
err error
|
||||
}{
|
||||
{
|
||||
name: "nil",
|
||||
err: nil,
|
||||
},
|
||||
{
|
||||
name: "with error",
|
||||
err: errors.New("mock"),
|
||||
},
|
||||
}
|
||||
for _, test := range tests {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
cc := new(grpc.ClientConn)
|
||||
err := PromMetricInterceptor(context.Background(), "/foo", nil, nil, cc,
|
||||
func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn,
|
||||
opts ...grpc.CallOption) error {
|
||||
return test.err
|
||||
})
|
||||
assert.Equal(t, test.err, err)
|
||||
})
|
||||
}
|
||||
}
|
||||
50
rpcx/internal/clientinterceptors/timeoutinterceptor_test.go
Normal file
50
rpcx/internal/clientinterceptors/timeoutinterceptor_test.go
Normal file
@@ -0,0 +1,50 @@
|
||||
package clientinterceptors
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strconv"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"google.golang.org/grpc"
|
||||
)
|
||||
|
||||
func TestTimeoutInterceptor(t *testing.T) {
|
||||
timeouts := []time.Duration{0, time.Millisecond * 10}
|
||||
for _, timeout := range timeouts {
|
||||
t.Run(strconv.FormatInt(int64(timeout), 10), func(t *testing.T) {
|
||||
interceptor := TimeoutInterceptor(timeout)
|
||||
cc := new(grpc.ClientConn)
|
||||
err := interceptor(context.Background(), "/foo", nil, nil, cc,
|
||||
func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn,
|
||||
opts ...grpc.CallOption) error {
|
||||
return nil
|
||||
},
|
||||
)
|
||||
assert.Nil(t, err)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestTimeoutInterceptor_timeout(t *testing.T) {
|
||||
const timeout = time.Millisecond * 10
|
||||
interceptor := TimeoutInterceptor(timeout)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Minute)
|
||||
defer cancel()
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(1)
|
||||
cc := new(grpc.ClientConn)
|
||||
err := interceptor(ctx, "/foo", nil, nil, cc,
|
||||
func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn,
|
||||
opts ...grpc.CallOption) error {
|
||||
defer wg.Done()
|
||||
tm, ok := ctx.Deadline()
|
||||
assert.True(t, ok)
|
||||
assert.True(t, tm.Before(time.Now().Add(timeout+time.Millisecond)))
|
||||
return nil
|
||||
})
|
||||
wg.Wait()
|
||||
assert.Nil(t, err)
|
||||
}
|
||||
53
rpcx/internal/clientinterceptors/tracinginterceptor_test.go
Normal file
53
rpcx/internal/clientinterceptors/tracinginterceptor_test.go
Normal file
@@ -0,0 +1,53 @@
|
||||
package clientinterceptors
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/tal-tech/go-zero/core/trace"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/metadata"
|
||||
)
|
||||
|
||||
func TestTracingInterceptor(t *testing.T) {
|
||||
var run int32
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(1)
|
||||
cc := new(grpc.ClientConn)
|
||||
err := TracingInterceptor(context.Background(), "/foo", nil, nil, cc,
|
||||
func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn,
|
||||
opts ...grpc.CallOption) error {
|
||||
defer wg.Done()
|
||||
atomic.AddInt32(&run, 1)
|
||||
return nil
|
||||
})
|
||||
wg.Wait()
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, int32(1), atomic.LoadInt32(&run))
|
||||
}
|
||||
|
||||
func TestTracingInterceptor_GrpcFormat(t *testing.T) {
|
||||
var run int32
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(1)
|
||||
md := metadata.New(map[string]string{
|
||||
"foo": "bar",
|
||||
})
|
||||
carrier, err := trace.Inject(trace.GrpcFormat, md)
|
||||
assert.Nil(t, err)
|
||||
ctx, _ := trace.StartServerSpan(context.Background(), carrier, "user", "/foo")
|
||||
cc := new(grpc.ClientConn)
|
||||
err = TracingInterceptor(ctx, "/foo", nil, nil, cc,
|
||||
func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn,
|
||||
opts ...grpc.CallOption) error {
|
||||
defer wg.Done()
|
||||
atomic.AddInt32(&run, 1)
|
||||
return nil
|
||||
})
|
||||
wg.Wait()
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, int32(1), atomic.LoadInt32(&run))
|
||||
}
|
||||
@@ -1,26 +0,0 @@
|
||||
package internal
|
||||
|
||||
import (
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/balancer/roundrobin"
|
||||
)
|
||||
|
||||
type DirectClient struct {
|
||||
conn *grpc.ClientConn
|
||||
}
|
||||
|
||||
func NewDirectClient(server string, opts ...ClientOption) (*DirectClient, error) {
|
||||
opts = append(opts, WithDialOption(grpc.WithBalancerName(roundrobin.Name)))
|
||||
conn, err := dial(server, opts...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &DirectClient{
|
||||
conn: conn,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (c *DirectClient) Conn() *grpc.ClientConn {
|
||||
return c.conn
|
||||
}
|
||||
@@ -1,34 +0,0 @@
|
||||
package internal
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/tal-tech/go-zero/rpcx/internal/balancer/p2c"
|
||||
"github.com/tal-tech/go-zero/rpcx/internal/resolver"
|
||||
"google.golang.org/grpc"
|
||||
)
|
||||
|
||||
func init() {
|
||||
resolver.RegisterResolver()
|
||||
}
|
||||
|
||||
type DiscovClient struct {
|
||||
conn *grpc.ClientConn
|
||||
}
|
||||
|
||||
func NewDiscovClient(endpoints []string, key string, opts ...ClientOption) (*DiscovClient, error) {
|
||||
opts = append(opts, WithDialOption(grpc.WithBalancerName(p2c.Name)))
|
||||
target := fmt.Sprintf("%s://%s/%s", resolver.DiscovScheme,
|
||||
strings.Join(endpoints, resolver.EndpointSep), key)
|
||||
conn, err := dial(target, opts...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &DiscovClient{conn: conn}, nil
|
||||
}
|
||||
|
||||
func (c *DiscovClient) Conn() *grpc.ClientConn {
|
||||
return c.conn
|
||||
}
|
||||
32
rpcx/internal/resolver/directbuilder.go
Normal file
32
rpcx/internal/resolver/directbuilder.go
Normal file
@@ -0,0 +1,32 @@
|
||||
package resolver
|
||||
|
||||
import (
|
||||
"strings"
|
||||
|
||||
"google.golang.org/grpc/resolver"
|
||||
)
|
||||
|
||||
type directBuilder struct{}
|
||||
|
||||
func (d *directBuilder) Build(target resolver.Target, cc resolver.ClientConn, opts resolver.BuildOptions) (
|
||||
resolver.Resolver, error) {
|
||||
var addrs []resolver.Address
|
||||
endpoints := strings.FieldsFunc(target.Endpoint, func(r rune) bool {
|
||||
return r == EndpointSep
|
||||
})
|
||||
|
||||
for _, val := range subset(endpoints, subsetSize) {
|
||||
addrs = append(addrs, resolver.Address{
|
||||
Addr: val,
|
||||
})
|
||||
}
|
||||
cc.UpdateState(resolver.State{
|
||||
Addresses: addrs,
|
||||
})
|
||||
|
||||
return &nopResolver{cc: cc}, nil
|
||||
}
|
||||
|
||||
func (d *directBuilder) Scheme() string {
|
||||
return DirectScheme
|
||||
}
|
||||
52
rpcx/internal/resolver/directbuilder_test.go
Normal file
52
rpcx/internal/resolver/directbuilder_test.go
Normal file
@@ -0,0 +1,52 @@
|
||||
package resolver
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strconv"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/tal-tech/go-zero/core/lang"
|
||||
"github.com/tal-tech/go-zero/core/mathx"
|
||||
"google.golang.org/grpc/resolver"
|
||||
)
|
||||
|
||||
func TestDirectBuilder_Build(t *testing.T) {
|
||||
tests := []int{
|
||||
0,
|
||||
1,
|
||||
2,
|
||||
subsetSize / 2,
|
||||
subsetSize,
|
||||
subsetSize * 2,
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
t.Run(strconv.Itoa(test), func(t *testing.T) {
|
||||
var servers []string
|
||||
for i := 0; i < test; i++ {
|
||||
servers = append(servers, fmt.Sprintf("localhost:%d", i))
|
||||
}
|
||||
var b directBuilder
|
||||
cc := new(mockedClientConn)
|
||||
_, err := b.Build(resolver.Target{
|
||||
Scheme: DirectScheme,
|
||||
Endpoint: strings.Join(servers, ","),
|
||||
}, cc, resolver.BuildOptions{})
|
||||
assert.Nil(t, err)
|
||||
size := mathx.MinInt(test, subsetSize)
|
||||
assert.Equal(t, size, len(cc.state.Addresses))
|
||||
m := make(map[string]lang.PlaceholderType)
|
||||
for _, each := range cc.state.Addresses {
|
||||
m[each.Addr] = lang.Placeholder
|
||||
}
|
||||
assert.Equal(t, size, len(m))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestDirectBuilder_Scheme(t *testing.T) {
|
||||
var b directBuilder
|
||||
assert.Equal(t, DirectScheme, b.Scheme())
|
||||
}
|
||||
41
rpcx/internal/resolver/discovbuilder.go
Normal file
41
rpcx/internal/resolver/discovbuilder.go
Normal file
@@ -0,0 +1,41 @@
|
||||
package resolver
|
||||
|
||||
import (
|
||||
"strings"
|
||||
|
||||
"github.com/tal-tech/go-zero/core/discov"
|
||||
"google.golang.org/grpc/resolver"
|
||||
)
|
||||
|
||||
type discovBuilder struct{}
|
||||
|
||||
func (d *discovBuilder) Build(target resolver.Target, cc resolver.ClientConn, opts resolver.BuildOptions) (
|
||||
resolver.Resolver, error) {
|
||||
hosts := strings.FieldsFunc(target.Authority, func(r rune) bool {
|
||||
return r == EndpointSep
|
||||
})
|
||||
sub, err := discov.NewSubscriber(hosts, target.Endpoint)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
update := func() {
|
||||
var addrs []resolver.Address
|
||||
for _, val := range subset(sub.Values(), subsetSize) {
|
||||
addrs = append(addrs, resolver.Address{
|
||||
Addr: val,
|
||||
})
|
||||
}
|
||||
cc.UpdateState(resolver.State{
|
||||
Addresses: addrs,
|
||||
})
|
||||
}
|
||||
sub.AddListener(update)
|
||||
update()
|
||||
|
||||
return &nopResolver{cc: cc}, nil
|
||||
}
|
||||
|
||||
func (d *discovBuilder) Scheme() string {
|
||||
return DiscovScheme
|
||||
}
|
||||
@@ -1,68 +1,30 @@
|
||||
package resolver
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/tal-tech/go-zero/core/discov"
|
||||
"google.golang.org/grpc/resolver"
|
||||
)
|
||||
import "google.golang.org/grpc/resolver"
|
||||
|
||||
const (
|
||||
DirectScheme = "direct"
|
||||
DiscovScheme = "discov"
|
||||
EndpointSep = ","
|
||||
EndpointSep = ','
|
||||
subsetSize = 32
|
||||
)
|
||||
|
||||
var builder discovBuilder
|
||||
var (
|
||||
dirBuilder directBuilder
|
||||
disBuilder discovBuilder
|
||||
)
|
||||
|
||||
type discovBuilder struct{}
|
||||
|
||||
func (b *discovBuilder) Scheme() string {
|
||||
return DiscovScheme
|
||||
func RegisterResolver() {
|
||||
resolver.Register(&dirBuilder)
|
||||
resolver.Register(&disBuilder)
|
||||
}
|
||||
|
||||
func (b *discovBuilder) Build(target resolver.Target, cc resolver.ClientConn, opts resolver.BuildOptions) (
|
||||
resolver.Resolver, error) {
|
||||
if target.Scheme != DiscovScheme {
|
||||
return nil, fmt.Errorf("bad scheme: %s", target.Scheme)
|
||||
}
|
||||
|
||||
hosts := strings.Split(target.Authority, EndpointSep)
|
||||
sub, err := discov.NewSubscriber(hosts, target.Endpoint)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
update := func() {
|
||||
var addrs []resolver.Address
|
||||
for _, val := range subset(sub.Values(), subsetSize) {
|
||||
addrs = append(addrs, resolver.Address{
|
||||
Addr: val,
|
||||
})
|
||||
}
|
||||
cc.UpdateState(resolver.State{
|
||||
Addresses: addrs,
|
||||
})
|
||||
}
|
||||
sub.AddListener(update)
|
||||
update()
|
||||
|
||||
return &discovResolver{
|
||||
cc: cc,
|
||||
}, nil
|
||||
}
|
||||
|
||||
type discovResolver struct {
|
||||
type nopResolver struct {
|
||||
cc resolver.ClientConn
|
||||
}
|
||||
|
||||
func (r *discovResolver) Close() {
|
||||
func (r *nopResolver) Close() {
|
||||
}
|
||||
|
||||
func (r *discovResolver) ResolveNow(options resolver.ResolveNowOptions) {
|
||||
}
|
||||
|
||||
func RegisterResolver() {
|
||||
resolver.Register(&builder)
|
||||
func (r *nopResolver) ResolveNow(options resolver.ResolveNowOptions) {
|
||||
}
|
||||
|
||||
36
rpcx/internal/resolver/resolver_test.go
Normal file
36
rpcx/internal/resolver/resolver_test.go
Normal file
@@ -0,0 +1,36 @@
|
||||
package resolver
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"google.golang.org/grpc/resolver"
|
||||
"google.golang.org/grpc/serviceconfig"
|
||||
)
|
||||
|
||||
func TestNopResolver(t *testing.T) {
|
||||
// make sure ResolveNow & Close don't panic
|
||||
var r nopResolver
|
||||
r.ResolveNow(resolver.ResolveNowOptions{})
|
||||
r.Close()
|
||||
}
|
||||
|
||||
type mockedClientConn struct {
|
||||
state resolver.State
|
||||
}
|
||||
|
||||
func (m *mockedClientConn) UpdateState(state resolver.State) {
|
||||
m.state = state
|
||||
}
|
||||
|
||||
func (m *mockedClientConn) ReportError(err error) {
|
||||
}
|
||||
|
||||
func (m *mockedClientConn) NewAddress(addresses []resolver.Address) {
|
||||
}
|
||||
|
||||
func (m *mockedClientConn) NewServiceConfig(serviceConfig string) {
|
||||
}
|
||||
|
||||
func (m *mockedClientConn) ParseServiceConfig(serviceConfigJSON string) *serviceconfig.ParseResult {
|
||||
return nil
|
||||
}
|
||||
200
rpcx/internal/serverinterceptors/authinterceptor_test.go
Normal file
200
rpcx/internal/serverinterceptors/authinterceptor_test.go
Normal file
@@ -0,0 +1,200 @@
|
||||
package serverinterceptors
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/alicebob/miniredis"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/tal-tech/go-zero/core/stores/redis"
|
||||
"github.com/tal-tech/go-zero/rpcx/internal/auth"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/metadata"
|
||||
)
|
||||
|
||||
func TestStreamAuthorizeInterceptor(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
app string
|
||||
token string
|
||||
strict bool
|
||||
hasError bool
|
||||
}{
|
||||
{
|
||||
name: "strict=false",
|
||||
strict: false,
|
||||
hasError: false,
|
||||
},
|
||||
{
|
||||
name: "strict=true",
|
||||
strict: true,
|
||||
hasError: true,
|
||||
},
|
||||
{
|
||||
name: "strict=true,with token",
|
||||
app: "foo",
|
||||
token: "bar",
|
||||
strict: true,
|
||||
hasError: false,
|
||||
},
|
||||
{
|
||||
name: "strict=true,with error token",
|
||||
app: "foo",
|
||||
token: "error",
|
||||
strict: true,
|
||||
hasError: true,
|
||||
},
|
||||
}
|
||||
|
||||
r := miniredis.NewMiniRedis()
|
||||
assert.Nil(t, r.Start())
|
||||
defer r.Close()
|
||||
|
||||
for _, test := range tests {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
store := redis.NewRedis(r.Addr(), redis.NodeType)
|
||||
if len(test.app) > 0 {
|
||||
assert.Nil(t, store.Hset("apps", test.app, test.token))
|
||||
defer store.Hdel("apps", test.app)
|
||||
}
|
||||
|
||||
authenticator, err := auth.NewAuthenticator(store, "apps", test.strict)
|
||||
assert.Nil(t, err)
|
||||
interceptor := StreamAuthorizeInterceptor(authenticator)
|
||||
md := metadata.New(map[string]string{
|
||||
"app": "foo",
|
||||
"token": "bar",
|
||||
})
|
||||
ctx := metadata.NewIncomingContext(context.Background(), md)
|
||||
stream := mockedStream{ctx: ctx}
|
||||
err = interceptor(nil, stream, nil, func(srv interface{}, stream grpc.ServerStream) error {
|
||||
return nil
|
||||
})
|
||||
if test.hasError {
|
||||
assert.NotNil(t, err)
|
||||
} else {
|
||||
assert.Nil(t, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestUnaryAuthorizeInterceptor(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
app string
|
||||
token string
|
||||
strict bool
|
||||
hasError bool
|
||||
}{
|
||||
{
|
||||
name: "strict=false",
|
||||
strict: false,
|
||||
hasError: false,
|
||||
},
|
||||
{
|
||||
name: "strict=true",
|
||||
strict: true,
|
||||
hasError: true,
|
||||
},
|
||||
{
|
||||
name: "strict=true,with token",
|
||||
app: "foo",
|
||||
token: "bar",
|
||||
strict: true,
|
||||
hasError: false,
|
||||
},
|
||||
{
|
||||
name: "strict=true,with error token",
|
||||
app: "foo",
|
||||
token: "error",
|
||||
strict: true,
|
||||
hasError: true,
|
||||
},
|
||||
}
|
||||
|
||||
r := miniredis.NewMiniRedis()
|
||||
assert.Nil(t, r.Start())
|
||||
defer r.Close()
|
||||
|
||||
for _, test := range tests {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
store := redis.NewRedis(r.Addr(), redis.NodeType)
|
||||
if len(test.app) > 0 {
|
||||
assert.Nil(t, store.Hset("apps", test.app, test.token))
|
||||
defer store.Hdel("apps", test.app)
|
||||
}
|
||||
|
||||
authenticator, err := auth.NewAuthenticator(store, "apps", test.strict)
|
||||
assert.Nil(t, err)
|
||||
interceptor := UnaryAuthorizeInterceptor(authenticator)
|
||||
md := metadata.New(map[string]string{
|
||||
"app": "foo",
|
||||
"token": "bar",
|
||||
})
|
||||
ctx := metadata.NewIncomingContext(context.Background(), md)
|
||||
_, err = interceptor(ctx, nil, nil,
|
||||
func(ctx context.Context, req interface{}) (interface{}, error) {
|
||||
return nil, nil
|
||||
})
|
||||
if test.hasError {
|
||||
assert.NotNil(t, err)
|
||||
} else {
|
||||
assert.Nil(t, err)
|
||||
}
|
||||
if test.strict {
|
||||
_, err = interceptor(context.Background(), nil, nil,
|
||||
func(ctx context.Context, req interface{}) (interface{}, error) {
|
||||
return nil, nil
|
||||
})
|
||||
assert.NotNil(t, err)
|
||||
|
||||
var md metadata.MD
|
||||
ctx := metadata.NewIncomingContext(context.Background(), md)
|
||||
_, err = interceptor(ctx, nil, nil,
|
||||
func(ctx context.Context, req interface{}) (interface{}, error) {
|
||||
return nil, nil
|
||||
})
|
||||
assert.NotNil(t, err)
|
||||
|
||||
md = metadata.New(map[string]string{
|
||||
"app": "",
|
||||
"token": "",
|
||||
})
|
||||
ctx = metadata.NewIncomingContext(context.Background(), md)
|
||||
_, err = interceptor(ctx, nil, nil,
|
||||
func(ctx context.Context, req interface{}) (interface{}, error) {
|
||||
return nil, nil
|
||||
})
|
||||
assert.NotNil(t, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
type mockedStream struct {
|
||||
ctx context.Context
|
||||
}
|
||||
|
||||
func (m mockedStream) SetHeader(md metadata.MD) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m mockedStream) SendHeader(md metadata.MD) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m mockedStream) SetTrailer(md metadata.MD) {
|
||||
}
|
||||
|
||||
func (m mockedStream) Context() context.Context {
|
||||
return m.ctx
|
||||
}
|
||||
|
||||
func (m mockedStream) SendMsg(v interface{}) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m mockedStream) RecvMsg(v interface{}) error {
|
||||
return nil
|
||||
}
|
||||
31
rpcx/internal/serverinterceptors/crashinterceptor_test.go
Normal file
31
rpcx/internal/serverinterceptors/crashinterceptor_test.go
Normal file
@@ -0,0 +1,31 @@
|
||||
package serverinterceptors
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/tal-tech/go-zero/core/logx"
|
||||
"google.golang.org/grpc"
|
||||
)
|
||||
|
||||
func init() {
|
||||
logx.Disable()
|
||||
}
|
||||
|
||||
func TestStreamCrashInterceptor(t *testing.T) {
|
||||
err := StreamCrashInterceptor(nil, nil, nil, func(
|
||||
srv interface{}, stream grpc.ServerStream) error {
|
||||
panic("mock panic")
|
||||
})
|
||||
assert.NotNil(t, err)
|
||||
}
|
||||
|
||||
func TestUnaryCrashInterceptor(t *testing.T) {
|
||||
interceptor := UnaryCrashInterceptor()
|
||||
_, err := interceptor(context.Background(), nil, nil,
|
||||
func(ctx context.Context, req interface{}) (interface{}, error) {
|
||||
panic("mock panic")
|
||||
})
|
||||
assert.NotNil(t, err)
|
||||
}
|
||||
@@ -33,12 +33,12 @@ var (
|
||||
)
|
||||
|
||||
func UnaryPromMetricInterceptor() grpc.UnaryServerInterceptor {
|
||||
return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) {
|
||||
return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (
|
||||
interface{}, error) {
|
||||
startTime := timex.Now()
|
||||
resp, err := handler(ctx, req)
|
||||
metricServerReqDur.Observe(int64(timex.Since(startTime)/time.Millisecond), info.FullMethod)
|
||||
metricServerReqCodeTotal.Inc(info.FullMethod, strconv.Itoa(int(status.Code(err))))
|
||||
return resp, err
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -0,0 +1,19 @@
|
||||
package serverinterceptors
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"google.golang.org/grpc"
|
||||
)
|
||||
|
||||
func TestUnaryPromMetricInterceptor(t *testing.T) {
|
||||
interceptor := UnaryPromMetricInterceptor()
|
||||
_, err := interceptor(context.Background(), nil, &grpc.UnaryServerInfo{
|
||||
FullMethod: "/",
|
||||
}, func(ctx context.Context, req interface{}) (interface{}, error) {
|
||||
return nil, nil
|
||||
})
|
||||
assert.Nil(t, err)
|
||||
}
|
||||
77
rpcx/internal/serverinterceptors/sheddinginterceptor_test.go
Normal file
77
rpcx/internal/serverinterceptors/sheddinginterceptor_test.go
Normal file
@@ -0,0 +1,77 @@
|
||||
package serverinterceptors
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/tal-tech/go-zero/core/load"
|
||||
"github.com/tal-tech/go-zero/core/stat"
|
||||
"google.golang.org/grpc"
|
||||
)
|
||||
|
||||
func TestUnarySheddingInterceptor(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
allow bool
|
||||
handleErr error
|
||||
expect error
|
||||
}{
|
||||
{
|
||||
name: "allow",
|
||||
allow: true,
|
||||
handleErr: nil,
|
||||
expect: nil,
|
||||
},
|
||||
{
|
||||
name: "allow",
|
||||
allow: true,
|
||||
handleErr: context.DeadlineExceeded,
|
||||
expect: context.DeadlineExceeded,
|
||||
},
|
||||
{
|
||||
name: "reject",
|
||||
allow: false,
|
||||
handleErr: nil,
|
||||
expect: load.ErrServiceOverloaded,
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
test := test
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
shedder := mockedShedder{allow: test.allow}
|
||||
metrics := stat.NewMetrics("mock")
|
||||
interceptor := UnarySheddingInterceptor(shedder, metrics)
|
||||
_, err := interceptor(context.Background(), nil, &grpc.UnaryServerInfo{
|
||||
FullMethod: "/",
|
||||
}, func(ctx context.Context, req interface{}) (interface{}, error) {
|
||||
return nil, test.handleErr
|
||||
})
|
||||
assert.Equal(t, test.expect, err)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
type mockedShedder struct {
|
||||
allow bool
|
||||
}
|
||||
|
||||
func (m mockedShedder) Allow() (load.Promise, error) {
|
||||
if m.allow {
|
||||
return mockedPromise{}, nil
|
||||
} else {
|
||||
return nil, load.ErrServiceOverloaded
|
||||
}
|
||||
}
|
||||
|
||||
type mockedPromise struct {
|
||||
}
|
||||
|
||||
func (m mockedPromise) Pass() {
|
||||
}
|
||||
|
||||
func (m mockedPromise) Fail() {
|
||||
}
|
||||
32
rpcx/internal/serverinterceptors/statinterceptor_test.go
Normal file
32
rpcx/internal/serverinterceptors/statinterceptor_test.go
Normal file
@@ -0,0 +1,32 @@
|
||||
package serverinterceptors
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/tal-tech/go-zero/core/stat"
|
||||
"google.golang.org/grpc"
|
||||
)
|
||||
|
||||
func TestUnaryStatInterceptor(t *testing.T) {
|
||||
metrics := stat.NewMetrics("mock")
|
||||
interceptor := UnaryStatInterceptor(metrics)
|
||||
_, err := interceptor(context.Background(), nil, &grpc.UnaryServerInfo{
|
||||
FullMethod: "/",
|
||||
}, func(ctx context.Context, req interface{}) (interface{}, error) {
|
||||
return nil, nil
|
||||
})
|
||||
assert.Nil(t, err)
|
||||
}
|
||||
|
||||
func TestUnaryStatInterceptor_crash(t *testing.T) {
|
||||
metrics := stat.NewMetrics("mock")
|
||||
interceptor := UnaryStatInterceptor(metrics)
|
||||
_, err := interceptor(context.Background(), nil, &grpc.UnaryServerInfo{
|
||||
FullMethod: "/",
|
||||
}, func(ctx context.Context, req interface{}) (interface{}, error) {
|
||||
panic("error")
|
||||
})
|
||||
assert.NotNil(t, err)
|
||||
}
|
||||
41
rpcx/internal/serverinterceptors/timeoutinterceptor_test.go
Normal file
41
rpcx/internal/serverinterceptors/timeoutinterceptor_test.go
Normal file
@@ -0,0 +1,41 @@
|
||||
package serverinterceptors
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"google.golang.org/grpc"
|
||||
)
|
||||
|
||||
func TestUnaryTimeoutInterceptor(t *testing.T) {
|
||||
interceptor := UnaryTimeoutInterceptor(time.Millisecond * 10)
|
||||
_, err := interceptor(context.Background(), nil, &grpc.UnaryServerInfo{
|
||||
FullMethod: "/",
|
||||
}, func(ctx context.Context, req interface{}) (interface{}, error) {
|
||||
return nil, nil
|
||||
})
|
||||
assert.Nil(t, err)
|
||||
}
|
||||
|
||||
func TestUnaryTimeoutInterceptor_timeout(t *testing.T) {
|
||||
const timeout = time.Millisecond * 10
|
||||
interceptor := UnaryTimeoutInterceptor(timeout)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Minute)
|
||||
defer cancel()
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(1)
|
||||
_, err := interceptor(ctx, nil, &grpc.UnaryServerInfo{
|
||||
FullMethod: "/",
|
||||
}, func(ctx context.Context, req interface{}) (interface{}, error) {
|
||||
defer wg.Done()
|
||||
tm, ok := ctx.Deadline()
|
||||
assert.True(t, ok)
|
||||
assert.True(t, tm.Before(time.Now().Add(timeout+time.Millisecond)))
|
||||
return nil, nil
|
||||
})
|
||||
wg.Wait()
|
||||
assert.Nil(t, err)
|
||||
}
|
||||
48
rpcx/internal/serverinterceptors/tracinginterceptor_test.go
Normal file
48
rpcx/internal/serverinterceptors/tracinginterceptor_test.go
Normal file
@@ -0,0 +1,48 @@
|
||||
package serverinterceptors
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/tal-tech/go-zero/core/trace/tracespec"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/metadata"
|
||||
)
|
||||
|
||||
func TestUnaryTracingInterceptor(t *testing.T) {
|
||||
interceptor := UnaryTracingInterceptor("foo")
|
||||
var run int32
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(1)
|
||||
_, err := interceptor(context.Background(), nil, &grpc.UnaryServerInfo{
|
||||
FullMethod: "/",
|
||||
}, func(ctx context.Context, req interface{}) (interface{}, error) {
|
||||
defer wg.Done()
|
||||
atomic.AddInt32(&run, 1)
|
||||
return nil, nil
|
||||
})
|
||||
wg.Wait()
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, int32(1), atomic.LoadInt32(&run))
|
||||
}
|
||||
|
||||
func TestUnaryTracingInterceptor_GrpcFormat(t *testing.T) {
|
||||
interceptor := UnaryTracingInterceptor("foo")
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(1)
|
||||
var md metadata.MD
|
||||
ctx := metadata.NewIncomingContext(context.Background(), md)
|
||||
_, err := interceptor(ctx, nil, &grpc.UnaryServerInfo{
|
||||
FullMethod: "/",
|
||||
}, func(ctx context.Context, req interface{}) (interface{}, error) {
|
||||
defer wg.Done()
|
||||
assert.True(t, len(ctx.Value(tracespec.TracingKey).(tracespec.Trace).TraceId()) > 0)
|
||||
assert.True(t, len(ctx.Value(tracespec.TracingKey).(tracespec.Trace).SpanId()) > 0)
|
||||
return nil, nil
|
||||
})
|
||||
wg.Wait()
|
||||
assert.Nil(t, err)
|
||||
}
|
||||
18
rpcx/internal/target.go
Normal file
18
rpcx/internal/target.go
Normal file
@@ -0,0 +1,18 @@
|
||||
package internal
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/tal-tech/go-zero/rpcx/internal/resolver"
|
||||
)
|
||||
|
||||
func BuildDirectTarget(endpoints []string) string {
|
||||
return fmt.Sprintf("%s:///%s", resolver.DirectScheme, strings.Join(
|
||||
endpoints, fmt.Sprintf("%c", resolver.EndpointSep)))
|
||||
}
|
||||
|
||||
func BuildDiscovTarget(endpoints []string, key string) string {
|
||||
return fmt.Sprintf("%s://%s/%s", resolver.DiscovScheme, strings.Join(
|
||||
endpoints, fmt.Sprintf("%c", resolver.EndpointSep)), key)
|
||||
}
|
||||
17
rpcx/internal/target_test.go
Normal file
17
rpcx/internal/target_test.go
Normal file
@@ -0,0 +1,17 @@
|
||||
package internal
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestBuildDirectTarget(t *testing.T) {
|
||||
target := BuildDirectTarget([]string{"localhost:123", "localhost:456"})
|
||||
assert.Equal(t, "direct:///localhost:123,localhost:456", target)
|
||||
}
|
||||
|
||||
func TestBuildDiscovTarget(t *testing.T) {
|
||||
target := BuildDiscovTarget([]string{"localhost:123", "localhost:456"}, "foo")
|
||||
assert.Equal(t, "discov://localhost:123,localhost:456/foo", target)
|
||||
}
|
||||
@@ -38,11 +38,11 @@ func (p *RpcProxy) TakeConn(ctx context.Context) (*grpc.ClientConn, error) {
|
||||
return client, nil
|
||||
}
|
||||
|
||||
client, err := NewClient(RpcClientConf{
|
||||
Server: p.backend,
|
||||
App: cred.App,
|
||||
Token: cred.Token,
|
||||
}, p.options...)
|
||||
opts := append(p.options, WithDialOption(grpc.WithPerRPCCredentials(&auth.Credential{
|
||||
App: cred.App,
|
||||
Token: cred.Token,
|
||||
})))
|
||||
client, err := NewClientWithTarget(p.backend, opts...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -19,29 +19,24 @@ const apiTemplate = `info(
|
||||
email: {{.gitEmail}}
|
||||
)
|
||||
|
||||
type request struct{
|
||||
type request struct {
|
||||
// TODO: add members here and delete this comment
|
||||
}
|
||||
|
||||
type response struct{
|
||||
type response struct {
|
||||
// TODO: add members here and delete this comment
|
||||
}
|
||||
|
||||
@server(
|
||||
port: // TODO: add port here and delete this comment
|
||||
)
|
||||
service {{.serviceName}} {
|
||||
@server(
|
||||
handler: // TODO: set handler name and delete this comment
|
||||
)
|
||||
// TODO: edit the below line
|
||||
// get /users/id/:userId(request) returns(response)
|
||||
get /users/id/:userId(request) returns(response)
|
||||
|
||||
@server(
|
||||
handler: // TODO: set handler name and delete this comment
|
||||
)
|
||||
// TODO: edit the below line
|
||||
// post /users/create(request)
|
||||
post /users/create(request)
|
||||
}
|
||||
`
|
||||
|
||||
|
||||
@@ -148,7 +148,6 @@ func createGoModFileIfNeed(dir string) {
|
||||
}
|
||||
tempPath = filepath.Dir(tempPath)
|
||||
if util.FileExists(filepath.Join(tempPath, goModeIdentifier)) {
|
||||
tempPath = filepath.Dir(tempPath)
|
||||
hasGoMod = true
|
||||
break
|
||||
}
|
||||
|
||||
@@ -72,7 +72,7 @@ func genHandler(dir string, group spec.Group, route spec.Route) error {
|
||||
req = ""
|
||||
}
|
||||
var logicResponse string
|
||||
var writeResponse = "nil, nil"
|
||||
var writeResponse string
|
||||
var respWriter = `httpx.WriteJson(w, http.StatusOK, resp)`
|
||||
if len(route.ResponseType.Name) > 0 {
|
||||
logicResponse = "resp, err :="
|
||||
|
||||
@@ -43,6 +43,7 @@ var mapping = map[string]string{
|
||||
"head": "http.MethodHead",
|
||||
"post": "http.MethodPost",
|
||||
"put": "http.MethodPut",
|
||||
"patch": "http.MethodPatch",
|
||||
}
|
||||
|
||||
type (
|
||||
|
||||
42
tools/goctl/api/ktgen/cmd.go
Normal file
42
tools/goctl/api/ktgen/cmd.go
Normal file
@@ -0,0 +1,42 @@
|
||||
package ktgen
|
||||
|
||||
import (
|
||||
"errors"
|
||||
|
||||
"github.com/tal-tech/go-zero/tools/goctl/api/parser"
|
||||
"github.com/urfave/cli"
|
||||
)
|
||||
|
||||
func KtCommand(c *cli.Context) error {
|
||||
apiFile := c.String("api")
|
||||
if apiFile == "" {
|
||||
return errors.New("missing -api")
|
||||
}
|
||||
dir := c.String("dir")
|
||||
if dir == "" {
|
||||
return errors.New("missing -dir")
|
||||
}
|
||||
pkg := c.String("pkg")
|
||||
if pkg == "" {
|
||||
return errors.New("missing -pkg")
|
||||
}
|
||||
|
||||
p, e := parser.NewParser(apiFile)
|
||||
if e != nil {
|
||||
return e
|
||||
}
|
||||
api, e := p.Parse()
|
||||
if e != nil {
|
||||
return e
|
||||
}
|
||||
|
||||
e = genBase(dir, pkg, api)
|
||||
if e != nil {
|
||||
return e
|
||||
}
|
||||
e = genApi(dir, pkg, api)
|
||||
if e != nil {
|
||||
return e
|
||||
}
|
||||
return nil
|
||||
}
|
||||
77
tools/goctl/api/ktgen/funcs.go
Normal file
77
tools/goctl/api/ktgen/funcs.go
Normal file
@@ -0,0 +1,77 @@
|
||||
package ktgen
|
||||
|
||||
import (
|
||||
"log"
|
||||
"strings"
|
||||
"text/template"
|
||||
|
||||
"github.com/iancoleman/strcase"
|
||||
"github.com/tal-tech/go-zero/tools/goctl/api/util"
|
||||
)
|
||||
|
||||
var funcsMap = template.FuncMap{
|
||||
"lowCamelCase": lowCamelCase,
|
||||
"routeToFuncName": routeToFuncName,
|
||||
"parseType": parseType,
|
||||
"add": add,
|
||||
"upperCase": upperCase,
|
||||
}
|
||||
|
||||
func lowCamelCase(s string) string {
|
||||
if len(s) < 1 {
|
||||
return ""
|
||||
}
|
||||
s = util.ToCamelCase(util.ToSnakeCase(s))
|
||||
return util.ToLower(s[:1]) + s[1:]
|
||||
}
|
||||
|
||||
func routeToFuncName(method, path string) string {
|
||||
if !strings.HasPrefix(path, "/") {
|
||||
path = "/" + path
|
||||
}
|
||||
|
||||
path = strings.ReplaceAll(path, "/", "_")
|
||||
path = strings.ReplaceAll(path, "-", "_")
|
||||
path = strings.ReplaceAll(path, ":", "With_")
|
||||
|
||||
return strings.ToLower(method) + strcase.ToCamel(path)
|
||||
}
|
||||
|
||||
func parseType(t string) string {
|
||||
t = strings.Replace(t, "*", "", -1)
|
||||
if strings.HasPrefix(t, "[]") {
|
||||
return "List<" + parseType(t[2:]) + ">"
|
||||
}
|
||||
|
||||
if strings.HasPrefix(t, "map") {
|
||||
tys, e := util.DecomposeType(t)
|
||||
if e != nil {
|
||||
log.Fatal(e)
|
||||
}
|
||||
if len(tys) != 2 {
|
||||
log.Fatal("Map type number !=2")
|
||||
}
|
||||
return "Map<String," + parseType(tys[1]) + ">"
|
||||
}
|
||||
|
||||
switch t {
|
||||
case "string":
|
||||
return "String"
|
||||
case "int", "int32", "int64":
|
||||
return "Int"
|
||||
case "float", "float32", "float64":
|
||||
return "Double"
|
||||
case "bool":
|
||||
return "Boolean"
|
||||
default:
|
||||
return t
|
||||
}
|
||||
}
|
||||
|
||||
func add(a, i int) int {
|
||||
return a + i
|
||||
}
|
||||
|
||||
func upperCase(s string) string {
|
||||
return strings.ToUpper(s)
|
||||
}
|
||||
150
tools/goctl/api/ktgen/gen.go
Normal file
150
tools/goctl/api/ktgen/gen.go
Normal file
@@ -0,0 +1,150 @@
|
||||
package ktgen
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"text/template"
|
||||
|
||||
"github.com/iancoleman/strcase"
|
||||
"github.com/tal-tech/go-zero/tools/goctl/api/spec"
|
||||
)
|
||||
|
||||
const (
|
||||
apiBaseTemplate = `package {{.}}
|
||||
|
||||
import com.google.gson.Gson
|
||||
import kotlinx.coroutines.Dispatchers
|
||||
import kotlinx.coroutines.withContext
|
||||
import java.io.BufferedReader
|
||||
import java.io.InputStreamReader
|
||||
import java.io.OutputStreamWriter
|
||||
import java.net.HttpURLConnection
|
||||
import java.net.URL
|
||||
|
||||
const val SERVER = "http://localhost:8080"
|
||||
|
||||
suspend fun apiRequest(
|
||||
method: String,
|
||||
uri: String,
|
||||
body: Any = "",
|
||||
onOk: ((String) -> Unit)? = null,
|
||||
onFail: ((String) -> Unit)? = null,
|
||||
eventually: (() -> Unit)? = null
|
||||
) = withContext(Dispatchers.IO) {
|
||||
val url = URL(SERVER + uri)
|
||||
with(url.openConnection() as HttpURLConnection) {
|
||||
connectTimeout = 3000
|
||||
requestMethod = method
|
||||
doInput = true
|
||||
if (method == "POST" || method == "PUT" || method == "PATCH") {
|
||||
setRequestProperty("Content-Type", "application/json")
|
||||
doOutput = true
|
||||
val data = when (body) {
|
||||
is String -> {
|
||||
body
|
||||
}
|
||||
else -> {
|
||||
Gson().toJson(body)
|
||||
}
|
||||
}
|
||||
val wr = OutputStreamWriter(outputStream)
|
||||
wr.write(data)
|
||||
wr.flush()
|
||||
}
|
||||
|
||||
try {
|
||||
if (responseCode >= 400) {
|
||||
BufferedReader(InputStreamReader(errorStream)).use {
|
||||
val response = it.readText()
|
||||
onFail?.invoke(response)
|
||||
}
|
||||
return@with
|
||||
}
|
||||
//response
|
||||
BufferedReader(InputStreamReader(inputStream)).use {
|
||||
val response = it.readText()
|
||||
onOk?.invoke(response)
|
||||
}
|
||||
} catch (e: Exception) {
|
||||
e.message?.let { onFail?.invoke(it) }
|
||||
}
|
||||
}
|
||||
eventually?.invoke()
|
||||
}
|
||||
`
|
||||
apiTemplate = `package {{with .Info}}{{.Desc}}{{end}}
|
||||
|
||||
import com.google.gson.Gson
|
||||
|
||||
object {{with .Info}}{{.Title}}{{end}}{
|
||||
{{range .Types}}
|
||||
data class {{.Name}}({{$length := (len .Members)}}{{range $i,$item := .Members}}
|
||||
val {{with $item}}{{lowCamelCase .Name}}: {{parseType .Type}}{{end}}{{if ne $i (add $length -1)}},{{end}}{{end}}
|
||||
){{end}}
|
||||
{{with .Service}}
|
||||
{{range .Routes}}suspend fun {{routeToFuncName .Method .Path}}({{with .RequestType}}{{if ne .Name ""}}
|
||||
req:{{.Name}},{{end}}{{end}}
|
||||
onOk: (({{with .ResponseType}}{{.Name}}{{end}}) -> Unit)? = null,
|
||||
onFail: ((String) -> Unit)? = null,
|
||||
eventually: (() -> Unit)? = null
|
||||
){
|
||||
apiRequest("{{upperCase .Method}}","{{.Path}}",{{with .RequestType}}{{if ne .Name ""}}body=req,{{end}}{{end}} onOk = { {{with .ResponseType}}
|
||||
onOk?.invoke({{if ne .Name ""}}Gson().fromJson(it,{{.Name}}::class.java){{end}}){{end}}
|
||||
}, onFail = onFail, eventually =eventually)
|
||||
}
|
||||
{{end}}{{end}}
|
||||
}`
|
||||
)
|
||||
|
||||
func genBase(dir, pkg string, api *spec.ApiSpec) error {
|
||||
e := os.MkdirAll(dir, 0755)
|
||||
if e != nil {
|
||||
return e
|
||||
}
|
||||
path := filepath.Join(dir, "BaseApi.kt")
|
||||
if _, e := os.Stat(path); e == nil {
|
||||
fmt.Println("BaseApi.kt already exists, skipped it.")
|
||||
return nil
|
||||
}
|
||||
|
||||
file, e := os.OpenFile(path, os.O_CREATE|os.O_TRUNC|os.O_WRONLY, 0644)
|
||||
if e != nil {
|
||||
return e
|
||||
}
|
||||
defer file.Close()
|
||||
|
||||
t, e := template.New("n").Parse(apiBaseTemplate)
|
||||
if e != nil {
|
||||
return e
|
||||
}
|
||||
e = t.Execute(file, pkg)
|
||||
if e != nil {
|
||||
return e
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func genApi(dir, pkg string, api *spec.ApiSpec) error {
|
||||
name := strcase.ToCamel(api.Info.Title + "Api")
|
||||
path := filepath.Join(dir, name+".kt")
|
||||
api.Info.Title = name
|
||||
api.Info.Desc = pkg
|
||||
|
||||
e := os.MkdirAll(dir, 0755)
|
||||
if e != nil {
|
||||
return e
|
||||
}
|
||||
|
||||
file, e := os.OpenFile(path, os.O_WRONLY|os.O_TRUNC|os.O_CREATE, 0644)
|
||||
if e != nil {
|
||||
return e
|
||||
}
|
||||
defer file.Close()
|
||||
|
||||
t, e := template.New("api").Funcs(funcsMap).Parse(apiTemplate)
|
||||
if e != nil {
|
||||
return e
|
||||
}
|
||||
return t.Execute(file, api)
|
||||
}
|
||||
@@ -99,8 +99,6 @@ func (s rootState) processToken(token string, annos []spec.Annotation) (state, e
|
||||
switch token {
|
||||
case infoDirective:
|
||||
return newInfoState(s.baseState), nil
|
||||
//case typeDirective:
|
||||
//return newTypeState(s.baseState, annos), nil
|
||||
case serviceDirective:
|
||||
return newServiceState(s.baseState, annos), nil
|
||||
default:
|
||||
|
||||
@@ -13,6 +13,10 @@ import (
|
||||
func writeProperty(writer io.Writer, member spec.Member, indent int, prefixForType func(string) string) error {
|
||||
writeIndent(writer, indent)
|
||||
ty, err := goTypeToTs(member.Type, prefixForType)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
optionalTag := ""
|
||||
if member.IsOptional() || member.IsOmitempty() {
|
||||
optionalTag = "?"
|
||||
@@ -21,13 +25,14 @@ func writeProperty(writer io.Writer, member spec.Member, indent int, prefixForTy
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
comment := member.GetComment()
|
||||
if len(comment) > 0 {
|
||||
comment = strings.TrimPrefix(comment, "//")
|
||||
comment = " // " + strings.TrimSpace(comment)
|
||||
}
|
||||
if len(member.Docs) > 0 {
|
||||
_, err = fmt.Fprintf(writer, "%s\n", strings.Join(member.Docs, ""))
|
||||
fmt.Fprintf(writer, "%s\n", strings.Join(member.Docs, ""))
|
||||
writeIndent(writer, 1)
|
||||
}
|
||||
_, err = fmt.Fprintf(writer, "%s%s: %s%s\n", name, optionalTag, ty, comment)
|
||||
|
||||
@@ -28,6 +28,9 @@ func MaybeCreateFile(dir, subdir, file string) (fp *os.File, created bool, err e
|
||||
|
||||
func ClearAndOpenFile(fpath string) (*os.File, error) {
|
||||
f, err := os.OpenFile(fpath, os.O_WRONLY|os.O_TRUNC, 0600)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
_, err = f.WriteString("")
|
||||
if err != nil {
|
||||
|
||||
@@ -8,12 +8,10 @@ import (
|
||||
)
|
||||
|
||||
var feature = `
|
||||
1、新增对rpc错误转换处理
|
||||
1.1、目前暂时仅处理not found 和 unknown错误
|
||||
2、增加feature命令支持,详细使用请通过命令[goctl -feature]查看
|
||||
1、增加goctl model支持
|
||||
`
|
||||
|
||||
func Feature(c *cli.Context) error {
|
||||
func Feature(_ *cli.Context) error {
|
||||
fmt.Println(aurora.Blue("\nFEATURE:"))
|
||||
fmt.Println(aurora.Blue(feature))
|
||||
return nil
|
||||
|
||||
@@ -11,11 +11,13 @@ import (
|
||||
"github.com/tal-tech/go-zero/tools/goctl/api/format"
|
||||
"github.com/tal-tech/go-zero/tools/goctl/api/gogen"
|
||||
"github.com/tal-tech/go-zero/tools/goctl/api/javagen"
|
||||
"github.com/tal-tech/go-zero/tools/goctl/api/ktgen"
|
||||
"github.com/tal-tech/go-zero/tools/goctl/api/tsgen"
|
||||
"github.com/tal-tech/go-zero/tools/goctl/api/validate"
|
||||
"github.com/tal-tech/go-zero/tools/goctl/configgen"
|
||||
"github.com/tal-tech/go-zero/tools/goctl/docker"
|
||||
"github.com/tal-tech/go-zero/tools/goctl/feature"
|
||||
"github.com/tal-tech/go-zero/tools/goctl/model/sql/command"
|
||||
"github.com/urfave/cli"
|
||||
)
|
||||
|
||||
@@ -150,6 +152,25 @@ var (
|
||||
},
|
||||
Action: dartgen.DartCommand,
|
||||
},
|
||||
{
|
||||
Name: "kt",
|
||||
Usage: "generate kotlin code for provided api file",
|
||||
Flags: []cli.Flag{
|
||||
cli.StringFlag{
|
||||
Name: "dir",
|
||||
Usage: "the target directory",
|
||||
},
|
||||
cli.StringFlag{
|
||||
Name: "api",
|
||||
Usage: "the api file",
|
||||
},
|
||||
cli.StringFlag{
|
||||
Name: "pkg",
|
||||
Usage: "define package name for kotlin file",
|
||||
},
|
||||
},
|
||||
Action: ktgen.KtCommand,
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
@@ -169,15 +190,63 @@ var (
|
||||
},
|
||||
{
|
||||
Name: "model",
|
||||
Usage: "generate sql model",
|
||||
Flags: []cli.Flag{
|
||||
cli.StringFlag{
|
||||
Name: "config, c",
|
||||
Usage: "the file that contains main function",
|
||||
},
|
||||
cli.StringFlag{
|
||||
Name: "dir, d",
|
||||
Usage: "the target dir",
|
||||
Usage: "generate model code",
|
||||
Subcommands: []cli.Command{
|
||||
{
|
||||
Name: "mysql",
|
||||
Usage: `generate mysql model"`,
|
||||
Subcommands: []cli.Command{
|
||||
{
|
||||
Name: "ddl",
|
||||
Usage: `generate mysql model from ddl"`,
|
||||
Flags: []cli.Flag{
|
||||
cli.StringFlag{
|
||||
Name: "src, s",
|
||||
Usage: "the file path of the ddl source file",
|
||||
},
|
||||
cli.StringFlag{
|
||||
Name: "dir, d",
|
||||
Usage: "the target dir",
|
||||
},
|
||||
cli.BoolFlag{
|
||||
Name: "cache, c",
|
||||
Usage: "generate code with cache [optional]",
|
||||
},
|
||||
cli.BoolFlag{
|
||||
Name: "idea",
|
||||
Usage: "for idea plugin [optional]",
|
||||
},
|
||||
},
|
||||
Action: command.MysqlDDL,
|
||||
},
|
||||
{
|
||||
Name: "datasource",
|
||||
Usage: `generate model from datasource"`,
|
||||
Flags: []cli.Flag{
|
||||
cli.StringFlag{
|
||||
Name: "url",
|
||||
Usage: `the data source of database,like "root:password@tcp(127.0.0.1:3306)/database"`,
|
||||
},
|
||||
cli.StringFlag{
|
||||
Name: "table, t",
|
||||
Usage: `source table,tables separated by commas,like "user,course"`,
|
||||
},
|
||||
cli.BoolFlag{
|
||||
Name: "cache, c",
|
||||
Usage: "generate code with cache [optional]",
|
||||
},
|
||||
cli.StringFlag{
|
||||
Name: "dir, d",
|
||||
Usage: "the target dir",
|
||||
},
|
||||
cli.BoolFlag{
|
||||
Name: "idea",
|
||||
Usage: "for idea plugin [optional]",
|
||||
},
|
||||
},
|
||||
Action: command.MyDataSource,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
|
||||
@@ -3,9 +3,7 @@
|
||||
## goctl用途
|
||||
* 定义api请求
|
||||
* 根据定义的api自动生成golang(后端), java(iOS & Android), typescript(web & 晓程序),dart(flutter)
|
||||
* 生成MySQL CURD (https://goctl.xiaoheiban.cn)
|
||||
* 生成MongoDB CURD (https://goctl.xiaoheiban.cn)
|
||||
|
||||
* 生成MySQL CURD 详情见[goctl model模块](https://github.com/tal-tech/go-zero/tools/goctl/model)
|
||||
## goctl使用说明
|
||||
#### goctl参数说明
|
||||
|
||||
@@ -188,79 +186,5 @@ service user-api {
|
||||
|
||||
#### 根据定义好的api文件生成Dart代码
|
||||
`goctl api dart -api user/user.api -dir ./src`
|
||||
|
||||
## 根据定义好的简单go文件生成mongo代码文件(仅限golang使用)
|
||||
`goctl model mongo -src {{yourDir}}/xiao/service/xhb/user/model/usermodel.go -cache yes`
|
||||
|
||||
-src需要提供简单的usermodel.go文件,里面只需要提供一个结构体即可
|
||||
-cache 控制是否需要缓存 yes=需要 no=不需要
|
||||
src 示例代码如下
|
||||
```
|
||||
package model
|
||||
|
||||
type User struct {
|
||||
Name string `o:"find,get,set" c:"姓名"`
|
||||
Age int `o:"find,get,set" c:"年纪"`
|
||||
School string `c:"学校"`
|
||||
}
|
||||
|
||||
```
|
||||
结构体中不需要提供Id,CreateTime,UpdateTime三个字段,会自动生成
|
||||
结构体中每个tag有两个可选标签 c 和 o
|
||||
c是改字段的注释
|
||||
o是改字段需要生产的操作函数 可以取得get,find,set 分别表示生成返回单个对象的查询方法,返回多个对象的查询方法,设置该字段方法
|
||||
生成的目标文件会覆盖该简单go文件
|
||||
|
||||
## goctl rpc生成
|
||||
|
||||
命令 `goctl rpc proto -proto ${proto} -service ${serviceName} -project ${projectName} -dir ${directory} -shared ${shared}`
|
||||
如: `goctl rpc proto -proto test.proto -service test -project xjy -dir .`
|
||||
|
||||
参数说明:
|
||||
|
||||
- ${proto}: proto文件
|
||||
- ${serviceName}: rpc服务名称
|
||||
- ${projectName}: 所属项目,如xjy,xhb,crm,hera,具体查看help,主要为了根据不同项目服务往redis注册key,可选
|
||||
- ${directory}: 输出目录
|
||||
- ${shared}: shared文件生成目录,可选,默认为${pwd}/shared
|
||||
|
||||
生成目录结构示例:
|
||||
|
||||
``` go
|
||||
.
|
||||
├── shared [示例目录,可自己指定,强制覆盖更新]
|
||||
│ └── contentservicemodel.go
|
||||
├── test
|
||||
│ ├── etc
|
||||
│ │ └── test.json
|
||||
│ ├── internal
|
||||
│ │ ├── config
|
||||
│ │ │ └── config.go
|
||||
│ │ ├── handler [强制覆盖更新]
|
||||
│ │ │ ├── changeavatarhandler.go
|
||||
│ │ │ ├── changebirthdayhandler.go
|
||||
│ │ │ ├── changenamehandler.go
|
||||
│ │ │ ├── changepasswordhandler.go
|
||||
│ │ │ ├── changeuserinfohandler.go
|
||||
│ │ │ ├── getuserinfohandler.go
|
||||
│ │ │ ├── loginhandler.go
|
||||
│ │ │ ├── logouthandler.go
|
||||
│ │ │ └── testhandler.go
|
||||
│ │ ├── logic
|
||||
│ │ │ ├── changeavatarlogic.go
|
||||
│ │ │ ├── changebirthdaylogic.go
|
||||
│ │ │ ├── changenamelogic.go
|
||||
│ │ │ ├── changepasswordlogic.go
|
||||
│ │ │ ├── changeuserinfologic.go
|
||||
│ │ │ ├── getuserinfologic.go
|
||||
│ │ │ ├── loginlogic.go
|
||||
│ │ │ └── logoutlogic.go
|
||||
│ │ └── svc
|
||||
│ │ └── servicecontext.go
|
||||
│ ├── pb
|
||||
│ │ └── test.pb.go
|
||||
│ └── test.go [强制覆盖更新]
|
||||
└── test.proto
|
||||
```
|
||||
- 注意 :目前rpc目录生成的proto文件暂不支持import外部proto文件
|
||||
|
||||
* 如有不理解的地方,随时问Kim/Kevin
|
||||
10
tools/goctl/model/sql/CHANGELOG.md
Normal file
10
tools/goctl/model/sql/CHANGELOG.md
Normal file
@@ -0,0 +1,10 @@
|
||||
# Change log
|
||||
|
||||
# 2020-08-20
|
||||
* 新增支持通过连接数据库生成model
|
||||
* 支持数据库多表生成
|
||||
* 优化stringx
|
||||
|
||||
# 2020-08-19
|
||||
* 重构model代码生成逻辑
|
||||
* 实现从ddl解析表信息生成代码
|
||||
@@ -1,430 +1,277 @@
|
||||
<div style="text-align: center;"><h1>Sql生成工具说明文档</h1></div>
|
||||
# Goctl Model
|
||||
|
||||
<h2>前言</h2>
|
||||
在当前Sql代码生成工具是基于sqlc生成的逻辑。
|
||||
goctl model 为go-zero下的工具模块中的组件之一,目前支持识别mysql ddl进行model层代码生成,通过命令行或者idea插件(即将支持)可以有选择地生成带redis cache或者不带redis cache的代码逻辑。
|
||||
|
||||
<h2>关键字</h2>
|
||||
# 快速开始
|
||||
|
||||
+ 查询类型(前暂不支持同一字段多种类型混合生成,如按照campus_id查询单结果又查询All或者Limit)
|
||||
- 单结果查询
|
||||
- FindOne(主键特有)
|
||||
- FindOneByXxx
|
||||
- 多结果查询
|
||||
- FindAllByXxx
|
||||
- FindLimitByXxx
|
||||
- withCache
|
||||
- withoutCache
|
||||
* 通过ddl生成
|
||||
|
||||
<h2>准备工作</h2>
|
||||
|
||||
- table
|
||||
|
||||
```
|
||||
CREATE TABLE `user_info` (
|
||||
`id` bigint(20) NOT NULL COMMENT '主键',
|
||||
`campus_id` bigint(20) DEFAULT NULL COMMENT '整校id',
|
||||
`name` varchar(255) DEFAULT NULL COMMENT '用户姓名',
|
||||
`id_number` varchar(255) DEFAULT NULL COMMENT '身份证',
|
||||
`age` int(10) DEFAULT NULL COMMENT '年龄',
|
||||
`gender` tinyint(1) DEFAULT NULL COMMENT '性别,0-男,1-女,2-不限',
|
||||
`mobile` varchar(20) DEFAULT NULL COMMENT '手机号',
|
||||
`create_time` timestamp NULL DEFAULT CURRENT_TIMESTAMP COMMENT '创建时间',
|
||||
`update_time` timestamp NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP COMMENT '更新时间',
|
||||
PRIMARY KEY (`id`)
|
||||
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_0900_ai_ci;
|
||||
```shell script
|
||||
$ goctl model mysql ddl -src="./sql/user.sql" -dir="./sql/model" -c=true
|
||||
```
|
||||
|
||||
<h2>imports生成</h2>
|
||||
imports代码生成对应model中包的引入管理,仅使用于晓黑板项目中(非相对路径动态生成),目前受`withCache`参数的影响,除此之外其实为固定代码。
|
||||
|
||||
- withCache
|
||||
|
||||
```
|
||||
import (
|
||||
"database/sql""fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/tal-tech/go-zero/core/stores/sqlc"
|
||||
"github.com/tal-tech/go-zero/core/stores/sqlx"
|
||||
"github.com/tal-tech/go-zero/core/stringx"
|
||||
"xiao/service/shared/builderx"
|
||||
)
|
||||
```
|
||||
|
||||
- withoutCache
|
||||
|
||||
```
|
||||
import (
|
||||
"database/sql""fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/tal-tech/go-zero/core/stores/sqlx"
|
||||
"github.com/tal-tech/go-zero/core/stringx"
|
||||
"xiao/service/shared/builderx"
|
||||
)
|
||||
```
|
||||
|
||||
<h2>vars生成</h2>
|
||||
|
||||
vars部分对应model中var声明的包含的代码块,由`table`名和`withCache`来决定其中的代码生成内容,`withCache`决定是否要生成缓存key变量的声明。
|
||||
|
||||
- withCache
|
||||
执行上述命令后即可快速生成CURD代码。
|
||||
|
||||
```
|
||||
var (
|
||||
UserInfoFieldNames = builderx.FieldNames(&UserInfo{})
|
||||
UserInfoRows = strings.Join(UserInfoFieldNames, ",")
|
||||
UserInfoRowsExpectAutoSet = strings.Join(stringx.Remove(UserInfoFieldNames, "id", "create_time", "update_time"), ",")
|
||||
UserInfoRowsWithPlaceHolder = strings.Join(stringx.Remove(UserInfoFieldNames, "id", "create_time", "update_time"), "=?,") + "=?"
|
||||
model
|
||||
│ ├── error.go
|
||||
│ └── usermodel.go
|
||||
```
|
||||
* 通过datasource生成
|
||||
|
||||
```shell script
|
||||
$ goctl model mysql datasource -url="user:password@tcp(127.0.0.1:3306)/database" -table="table1,table2" -dir="./model"
|
||||
```
|
||||
|
||||
cacheUserInfoIdPrefix = "cache#userInfo#id#"
|
||||
cacheUserInfoCampusIdPrefix = "cache#userInfo#campusId#"
|
||||
cacheUserInfoNamePrefix = "cache#userInfo#name#"
|
||||
cacheUserInfoMobilePrefix = "cache#userInfo#mobile#"
|
||||
)
|
||||
```
|
||||
|
||||
- withoutCache
|
||||
|
||||
```
|
||||
var (
|
||||
UserInfoFieldNames = builderx.FieldNames(&UserInfo{})
|
||||
UserInfoRows = strings.Join(UserInfoFieldNames, ",")
|
||||
UserInfoRowsExpectAutoSet = strings.Join(stringx.Remove(UserInfoFieldNames, "id", "create_time", "update_time"), ",")
|
||||
UserInfoRowsWithPlaceHolder = strings.Join(stringx.Remove(UserInfoFieldNames, "id", "create_time", "update_time"), "=?,") + "=?"
|
||||
)
|
||||
```
|
||||
|
||||
<h2>types生成</h2>
|
||||
|
||||
ypes部分对应model中type声明的包含的代码块,由`table`名和`withCache`来决定其中的代码生成内容,`withCache`决定引入sqlc还是sqlx。
|
||||
|
||||
- withCache
|
||||
```
|
||||
type (
|
||||
UserInfoModel struct {
|
||||
conn sqlc.CachedConn
|
||||
table string
|
||||
}
|
||||
|
||||
UserInfo struct {
|
||||
Id int64 `db:"id"` // 主键id
|
||||
CampusId int64 `db:"campus_id"` // 整校id
|
||||
Name string `db:"name"` // 用户姓名
|
||||
IdNumber string `db:"id_number"` // 身份证
|
||||
Age int64 `db:"age"` // 年龄
|
||||
Gender int64 `db:"gender"` // 性别,0-男,1-女,2-不限
|
||||
Mobile string `db:"mobile"` // 手机号
|
||||
CreateTime time.Time `db:"create_time"` // 创建时间
|
||||
UpdateTime time.Time `db:"update_time"` // 更新时间
|
||||
}
|
||||
)
|
||||
```
|
||||
|
||||
- withoutCache
|
||||
```
|
||||
type (
|
||||
UserInfoModel struct {
|
||||
conn sqlx.SqlConn
|
||||
table string
|
||||
}
|
||||
|
||||
UserInfo struct {
|
||||
Id int64 `db:"id"` // 主键id
|
||||
CampusId int64 `db:"campus_id"` // 整校id
|
||||
Name string `db:"name"` // 用户姓名
|
||||
IdNumber string `db:"id_number"` // 身份证
|
||||
Age int64 `db:"age"` // 年龄
|
||||
Gender int64 `db:"gender"` // 性别,0-男,1-女,2-不限
|
||||
Mobile string `db:"mobile"` // 手机号
|
||||
CreateTime time.Time `db:"create_time"` // 创建时间
|
||||
UpdateTime time.Time `db:"update_time"` // 更新时间
|
||||
}
|
||||
)
|
||||
```
|
||||
<h2>New生成</h2>
|
||||
new生成对应model中struct的New函数,受`withCache`影响决定是否要引入cacheRedis
|
||||
|
||||
- withCache
|
||||
```
|
||||
func NewUserInfoModel(conn sqlx.SqlConn, c cache.CacheConf, table string) *UserInfoModel {
|
||||
return &UserInfoModel{
|
||||
CachedConn: sqlc.NewConn(conn, c),
|
||||
table: table,
|
||||
}
|
||||
}
|
||||
```
|
||||
- withoutCache
|
||||
```
|
||||
func NewUserInfoModel(conn sqlx.SqlConn, table string) *UserInfoModel {
|
||||
return &UserInfoModel{conn: conn, table: table}
|
||||
}
|
||||
```
|
||||
> 详情用法请参考[example](https://github.com/tal-tech/go-zero/tree/master/tools/goctl/model/sql/example)
|
||||
|
||||
|
||||
<h2>FindOne查询生成</h2>
|
||||
FindOne查询代码生成仅对主键有效。如`user_info`中生成的FindOne如下:
|
||||
|
||||
- withCache
|
||||
|
||||
```
|
||||
func (m *UserInfoModel) FindOne(id int64) (*UserInfo, error) {
|
||||
idKey := fmt.Sprintf("%s%v", cacheUserInfoIdPrefix, id)
|
||||
var resp UserInfo
|
||||
err := m.QueryRow(&resp, idKey, func(conn sqlx.SqlConn, v interface{}) error {
|
||||
query := `select ` + userInfoRows + ` from ` + m.table + `where id = ? limit 1`
|
||||
return conn.QueryRow(v, query, id)
|
||||
})
|
||||
switch err {
|
||||
case nil:
|
||||
return &resp, nil
|
||||
case sqlc.ErrNotFound:
|
||||
return nil, ErrNotFound
|
||||
default:
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
- withoutCache
|
||||
|
||||
```
|
||||
func (m *UserInfoModel) FindOne(id int64) (*UserInfo, error) {
|
||||
|
||||
query := `select ` + userInfoRows + ` from ` + m.table + `where id = ? limit 1`
|
||||
var resp UserInfo
|
||||
err := m.conn.QueryRow(&resp, query, id)
|
||||
switch err {
|
||||
case nil:
|
||||
return &resp, nil
|
||||
case sqlx.ErrNotFound:
|
||||
return nil, ErrNotFound
|
||||
default:
|
||||
return nil, err
|
||||
|
||||
}
|
||||
```
|
||||
|
||||
<h2>FindOneByXxx查询生成</h2>
|
||||
|
||||
FindOneByXxx查询生成可以按照单个字段查询、多个字段以AND关系且表达式符号为`=`的查询(下称:组合查询),对除主键之外的字段有效,对于单个字段可以用`withCache`来控制是否需要缓存,这里的缓存只缓存主键,并不缓存整个struct,注意:这里有一个隐藏的规则,如果单个字段查询需要cache,那么主键一定有cache;多个字段组成的`组合查询`一律没有缓存处理,<strong><i>且组合查询不能相互嵌套</i></strong>,否则会报`circle query with other fields`错误,下面我们按场景来依次查看对应代码生成后的示例。
|
||||
|
||||
>注:目前暂不支持除equals之外的条件查询。
|
||||
|
||||
+ 单字段查询
|
||||
以name查询为例
|
||||
- withCache
|
||||
```
|
||||
func (m *UserInfoModel) FindOneByName(name string) (*UserInfo, error) {
|
||||
nameKey := fmt.Sprintf("%s%v", cacheUserInfoNamePrefix, name)
|
||||
var id string
|
||||
err := m.GetCache(key, &id)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if id != "" {
|
||||
return m.FindOne(id)
|
||||
}
|
||||
var resp UserInfo
|
||||
query := `select ` + userInfoRows + ` from ` + m.table + `where name = ? limit 1`
|
||||
err = m.QueryRowNoCache(&resp, query, name)
|
||||
switch err {
|
||||
case nil:
|
||||
err = m.SetCache(nameKey, resp.Id)
|
||||
if err != nil {
|
||||
logx.Error(err)
|
||||
}
|
||||
return &resp, nil
|
||||
case sqlc.ErrNotFound:
|
||||
return nil, ErrNotFound
|
||||
default:
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
```
|
||||
- withoutCache
|
||||
|
||||
```
|
||||
func (m *UserInfoModel) FindOneByName(name string) (*UserInfo, error) {
|
||||
var resp UserInfo
|
||||
query := `select ` + userInfoRows + ` from ` + m.table + `where name = ? limit 1`
|
||||
err = m.conn.QueryRow(&resp, query, name)
|
||||
switch err {
|
||||
case nil:
|
||||
return &resp, nil
|
||||
case sqlx.ErrNotFound:
|
||||
return nil, ErrNotFound
|
||||
default:
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
- 组合查询
|
||||
以`campus_id`和`id_number`查询为例。
|
||||
* 生成代码示例
|
||||
|
||||
``` go
|
||||
package model
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/tal-tech/go-zero/core/stores/cache"
|
||||
"github.com/tal-tech/go-zero/core/stores/sqlc"
|
||||
"github.com/tal-tech/go-zero/core/stores/sqlx"
|
||||
"github.com/tal-tech/go-zero/core/stringx"
|
||||
"github.com/tal-tech/go-zero/tools/goctl/model/sql/builderx"
|
||||
)
|
||||
|
||||
var (
|
||||
userFieldNames = builderx.FieldNames(&User{})
|
||||
userRows = strings.Join(userFieldNames, ",")
|
||||
userRowsExpectAutoSet = strings.Join(stringx.Remove(userFieldNames, "id", "create_time", "update_time"), ",")
|
||||
userRowsWithPlaceHolder = strings.Join(stringx.Remove(userFieldNames, "id", "create_time", "update_time"), "=?,") + "=?"
|
||||
|
||||
cacheUserMobilePrefix = "cache#User#mobile#"
|
||||
cacheUserIdPrefix = "cache#User#id#"
|
||||
cacheUserNamePrefix = "cache#User#name#"
|
||||
)
|
||||
|
||||
type (
|
||||
UserModel struct {
|
||||
sqlc.CachedConn
|
||||
table string
|
||||
}
|
||||
|
||||
User struct {
|
||||
Id int64 `db:"id"`
|
||||
Name string `db:"name"` // 用户名称
|
||||
Password string `db:"password"` // 用户密码
|
||||
Mobile string `db:"mobile"` // 手机号
|
||||
Gender string `db:"gender"` // 男|女|未公开
|
||||
Nickname string `db:"nickname"` // 用户昵称
|
||||
CreateTime time.Time `db:"create_time"`
|
||||
UpdateTime time.Time `db:"update_time"`
|
||||
}
|
||||
)
|
||||
|
||||
func NewUserModel(conn sqlx.SqlConn, c cache.CacheConf, table string) *UserModel {
|
||||
return &UserModel{
|
||||
CachedConn: sqlc.NewConn(conn, c),
|
||||
table: table,
|
||||
}
|
||||
}
|
||||
|
||||
func (m *UserModel) Insert(data User) (sql.Result, error) {
|
||||
query := `insert into ` + m.table + `(` + userRowsExpectAutoSet + `) value (?, ?, ?, ?, ?)`
|
||||
return m.ExecNoCache(query, data.Name, data.Password, data.Mobile, data.Gender, data.Nickname)
|
||||
}
|
||||
|
||||
func (m *UserModel) FindOne(id int64) (*User, error) {
|
||||
userIdKey := fmt.Sprintf("%s%v", cacheUserIdPrefix, id)
|
||||
var resp User
|
||||
err := m.QueryRow(&resp, userIdKey, func(conn sqlx.SqlConn, v interface{}) error {
|
||||
query := `select ` + userRows + ` from ` + m.table + ` where id = ? limit 1`
|
||||
return conn.QueryRow(v, query, id)
|
||||
})
|
||||
switch err {
|
||||
case nil:
|
||||
return &resp, nil
|
||||
case sqlc.ErrNotFound:
|
||||
return nil, ErrNotFound
|
||||
default:
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
func (m *UserModel) FindOneByName(name string) (*User, error) {
|
||||
userNameKey := fmt.Sprintf("%s%v", cacheUserNamePrefix, name)
|
||||
var resp User
|
||||
err := m.QueryRowIndex(&resp, userNameKey, func(primary interface{}) string {
|
||||
return fmt.Sprintf("%s%v", cacheUserIdPrefix, primary)
|
||||
}, func(conn sqlx.SqlConn, v interface{}) (i interface{}, e error) {
|
||||
query := `select ` + userRows + ` from ` + m.table + ` where name = ? limit 1`
|
||||
if err := conn.QueryRow(&resp, query, name); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return resp.Id, nil
|
||||
}, func(conn sqlx.SqlConn, v, primary interface{}) error {
|
||||
query := `select ` + userRows + ` from ` + m.table + ` where id = ? limit 1`
|
||||
return conn.QueryRow(v, query, primary)
|
||||
})
|
||||
switch err {
|
||||
case nil:
|
||||
return &resp, nil
|
||||
case sqlc.ErrNotFound:
|
||||
return nil, ErrNotFound
|
||||
default:
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
func (m *UserModel) FindOneByMobile(mobile string) (*User, error) {
|
||||
userMobileKey := fmt.Sprintf("%s%v", cacheUserMobilePrefix, mobile)
|
||||
var resp User
|
||||
err := m.QueryRowIndex(&resp, userMobileKey, func(primary interface{}) string {
|
||||
return fmt.Sprintf("%s%v", cacheUserIdPrefix, primary)
|
||||
}, func(conn sqlx.SqlConn, v interface{}) (i interface{}, e error) {
|
||||
query := `select ` + userRows + ` from ` + m.table + ` where mobile = ? limit 1`
|
||||
if err := conn.QueryRow(&resp, query, mobile); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return resp.Id, nil
|
||||
}, func(conn sqlx.SqlConn, v, primary interface{}) error {
|
||||
query := `select ` + userRows + ` from ` + m.table + ` where id = ? limit 1`
|
||||
return conn.QueryRow(v, query, primary)
|
||||
})
|
||||
switch err {
|
||||
case nil:
|
||||
return &resp, nil
|
||||
case sqlc.ErrNotFound:
|
||||
return nil, ErrNotFound
|
||||
default:
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
func (m *UserModel) Update(data User) error {
|
||||
userIdKey := fmt.Sprintf("%s%v", cacheUserIdPrefix, data.Id)
|
||||
_, err := m.Exec(func(conn sqlx.SqlConn) (result sql.Result, err error) {
|
||||
query := `update ` + m.table + ` set ` + userRowsWithPlaceHolder + ` where id = ?`
|
||||
return conn.Exec(query, data.Name, data.Password, data.Mobile, data.Gender, data.Nickname, data.Id)
|
||||
}, userIdKey)
|
||||
return err
|
||||
}
|
||||
|
||||
func (m *UserModel) Delete(id int64) error {
|
||||
data, err := m.FindOne(id)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
userIdKey := fmt.Sprintf("%s%v", cacheUserIdPrefix, id)
|
||||
userNameKey := fmt.Sprintf("%s%v", cacheUserNamePrefix, data.Name)
|
||||
userMobileKey := fmt.Sprintf("%s%v", cacheUserMobilePrefix, data.Mobile)
|
||||
_, err = m.Exec(func(conn sqlx.SqlConn) (result sql.Result, err error) {
|
||||
query := `delete from ` + m.table + ` where id = ?`
|
||||
return conn.Exec(query, id)
|
||||
}, userIdKey, userNameKey, userMobileKey)
|
||||
return err
|
||||
}
|
||||
```
|
||||
|
||||
# 用法
|
||||
|
||||
```
|
||||
$ goctl model mysql -h
|
||||
```
|
||||
|
||||
```
|
||||
NAME:
|
||||
goctl model mysql - generate mysql model"
|
||||
|
||||
USAGE:
|
||||
goctl model mysql command [command options] [arguments...]
|
||||
|
||||
COMMANDS:
|
||||
ddl generate mysql model from ddl"
|
||||
datasource generate model from datasource"
|
||||
|
||||
OPTIONS:
|
||||
--help, -h show help
|
||||
```
|
||||
|
||||
# 生成规则
|
||||
|
||||
* 默认规则
|
||||
|
||||
我们默认用户在建表时会创建createTime、updateTime字段(忽略大小写、下划线命名风格)且默认值均为`CURRENT_TIMESTAMP`,而updateTime支持`ON UPDATE CURRENT_TIMESTAMP`,对于这两个字段生成`insert`、`update`时会被移除,不在赋值范畴内,当然,如果你不需要这两个字段那也无大碍。
|
||||
* 带缓存模式
|
||||
* ddl
|
||||
|
||||
```shell script
|
||||
$ goctl model mysql -src={filename} -dir={dir} -cache=true
|
||||
```
|
||||
func (m *UserInfoModel) FindOneByCampusIdAndIdNumber(campusId int64,idNumber string) (*UserInfo, error) {
|
||||
var resp UserInfo
|
||||
query := `select ` + userInfoRows + ` from ` + m.table + `where campus_id = ? AND id_number = ? limit 1`
|
||||
err = m.QueryRowNoCache(&resp, query, campusId, idNumber)
|
||||
// err = m.conn.QueryRows(&resp, query, campusId, idNumber)
|
||||
switch err {
|
||||
case nil:
|
||||
return &resp, nil
|
||||
case sqlx.ErrNotFound:
|
||||
return nil, ErrNotFound
|
||||
default:
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
* datasource
|
||||
|
||||
```shell script
|
||||
$ goctl model mysql datasource -url={datasource} -table={tables} -dir={dir} -cache=true
|
||||
```
|
||||
<h2>FindAllByXxx生成</h2>
|
||||
FindAllByXxx查询和FindOneByXxx功能相似,只是FindOneByXxx限制了limit等于1,而FindAllByXxx是查询所有,以两个例子来说明
|
||||
|
||||
目前仅支持redis缓存,如果选择带缓存模式,即生成的`FindOne(ByXxx)`&`Delete`代码会生成带缓存逻辑的代码,目前仅支持单索引字段(除全文索引外),对于联合索引我们默认认为不需要带缓存,且不属于通用型代码,因此没有放在代码生成行列,如example中user表中的`id`、`name`、`mobile`字段均属于单字段索引。
|
||||
|
||||
- 查询单个字段`name`等于某值的所有数据
|
||||
```
|
||||
func (m *UserInfoModel) FindAllByName(name string) ([]*UserInfo, error) {
|
||||
var resp []*UserInfo
|
||||
query := `select ` + userInfoRows + ` from ` + m.table + `where name = ?`
|
||||
err := m.QueryRowsNoCache(&resp, query, name)
|
||||
// err := m.conn.QueryRows(&resp, query, name)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return resp, nil
|
||||
}
|
||||
```
|
||||
- 查询多个组合字段`campus_id`等于某值且`gender`等于某值的所有数据
|
||||
```
|
||||
func (m *UserInfoModel) FindAllByCampusIdAndGender(campusId int64,gender int64) ([]*UserInfo, error) {
|
||||
var resp []*UserInfo
|
||||
query := `select ` + userInfoRows + ` from ` + m.table + `where campus_id = ? AND gender = ?`
|
||||
err := m.QueryRowsNoCache(&resp, query, campusId, gender)
|
||||
// err := m.conn.QueryRows(&resp, query, campusId, gender)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return resp, nil
|
||||
}
|
||||
```
|
||||
* 不带缓存模式
|
||||
|
||||
* ddl
|
||||
|
||||
```shell script
|
||||
$ goctl model -src={filename} -dir={dir}
|
||||
```
|
||||
* datasource
|
||||
|
||||
```shell script
|
||||
$ goctl model mysql datasource -url={datasource} -table={tables} -dir={dir}
|
||||
```
|
||||
or
|
||||
* ddl
|
||||
|
||||
```shell script
|
||||
$ goctl model -src={filename} -dir={dir} -cache=false
|
||||
```
|
||||
* datasource
|
||||
|
||||
```shell script
|
||||
$ goctl model mysql datasource -url={datasource} -table={tables} -dir={dir} -cache=false
|
||||
```
|
||||
|
||||
生成代码仅基本的CURD结构。
|
||||
|
||||
<h2>FindLimitByXxx生成</h2>
|
||||
FindLimitByXxx查询和FindAllByXxx功能相似,只是FindAllByXxx限制了limit,除此之外还会生成查询对应Count总数的代码,而FindAllByXxx是查询所有数据,以几个例子来说明
|
||||
# 缓存
|
||||
|
||||
- 查询`gender`等于某值的分页数据,按照`create_time`降序
|
||||
```
|
||||
func (m *UserInfoModel) FindLimitByGender(gender int64, page, limit int) ([]*UserInfo, error) {
|
||||
var resp []*UserInfo
|
||||
query := `select ` + userInfoRows + `from ` + m.table + `where gender = ? order by create_time DESC limit ?,?`
|
||||
err := m.QueryRowsNoCache(&resp, query, gender, (page-1)*limit, limit)
|
||||
// err := m.conn.QueryRows(&resp, query, gender, (page-1)*limit, limit)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return resp, nil
|
||||
}
|
||||
对于缓存这一块我选择用一问一答的形式进行罗列。我想这样能够更清晰的描述model中缓存的功能。
|
||||
|
||||
func (m *UserInfoModel) FindAllCountByGender(gender int64) (int64, error) {
|
||||
var count int64
|
||||
query := `select count(1) from ` + m.table + `where gender = ? `
|
||||
err := m.QueryRowsNoCache(&count, query, gender)
|
||||
// err := m.conn.QueryRow(&count, query, gender)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return count, nil
|
||||
}
|
||||
```
|
||||
- 查询`gender`等于某值的分页数据,按照`create_time`降序、`update_time`生序排序
|
||||
```
|
||||
func (m *UserInfoModel) FindLimitByGender(gender int64, page, limit int) ([]*UserInfo, error) {
|
||||
var resp []*UserInfo
|
||||
query := `select ` + userInfoRows + `from ` + m.table + `where gender = ? order by create_time DESC,update_time ASC limit ?,?`
|
||||
err := m.QueryRowsNoCache(&resp, query, gender, (page-1)*limit, limit)
|
||||
// err := m.conn.QueryRows(&resp, query, gender, (page-1)*limit, limit)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return resp, nil
|
||||
}
|
||||
* 缓存会缓存哪些信息?
|
||||
|
||||
func (m *UserInfoModel) FindAllCountByGender(gender int64) (int64, error) {
|
||||
var count int64
|
||||
query := `select count(1) from ` + m.table + `where gender = ? `
|
||||
err := m.QueryRowNoCache(&count, query, gender)
|
||||
// err := m.conn.QueryRow(&count, query, gender)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return count, nil
|
||||
}
|
||||
```
|
||||
- 查询`gender`等于某值且`campus_id`为某值按照`create_time`降序的分页数据
|
||||
```
|
||||
func (m *UserInfoModel) FindLimitByGenderAndCampusId(gender int64,campusId int64, page, limit int) ([]*UserInfo, error) {
|
||||
var resp []*UserInfo
|
||||
query := `select ` + userInfoRows + `from ` + m.table + `where gender = ? AND campus_id = ? order by create_time DESC limit ?,?`
|
||||
err := m.QueryRowsNoCache(&resp, query, gender, campusId, (page-1)*limit, limit)
|
||||
// err := m.conn.QueryRows(&resp, query, gender, campusId, (page-1)*limit, limit)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return resp, nil
|
||||
}
|
||||
对于主键字段缓存,会缓存整个结构体信息,而对于单索引字段(除全文索引)则缓存主键字段值。
|
||||
|
||||
func (m *UserInfoModel) FindAllCountByGenderAndCampusId(gender int64,campusId int64) (int64, error) {
|
||||
var count int64
|
||||
query := `select count(1) from ` + m.table + `where gender = ? AND campus_id = ? `
|
||||
err := m.QueryRowsNoCache(&count, query, gender, campusId)
|
||||
// err := m.conn.QueryRow(&count, query, gender, campusId)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return count, nil
|
||||
}
|
||||
```
|
||||
* 数据有更新(`update`)操作会清空缓存吗?
|
||||
|
||||
会,但仅清空主键缓存的信息,why?这里就不做详细赘述了。
|
||||
|
||||
<h2>Delete生成</h2>
|
||||
Delete代码根据`withCache`的不同可以生成带缓存逻辑代码和不带缓存逻辑代码,<strong><i>Delete代码生成仅按照主键删除</i></strong>。从FindOneByXxx方法描述得知,非主键`withCache`了那么主键会强制被cache,因此在delete时也会删除主键cache。
|
||||
* 为什么不按照单索引字段生成`updateByXxx`和`deleteByXxx`的代码?
|
||||
|
||||
理论上是没任何问题,但是我们认为,对于model层的数据操作均是以整个结构体为单位,包括查询,我不建议只查询某部分字段(不反对),否则我们的缓存就没有意义了。
|
||||
|
||||
- withCache
|
||||
根据`mobile`查询用户信息
|
||||
* 为什么不支持`findPageLimit`、`findAll`这么模式代码生层?
|
||||
|
||||
目前,我认为除了基本的CURD外,其他的代码均属于<i>业务型</i>代码,这个我觉得开发人员根据业务需要进行编写更好。
|
||||
|
||||
```
|
||||
func (m *UserInfoModel) Delete(userId int64) error {
|
||||
userIdKey := fmt.Sprintf("%s%v", cacheUserInfoUserIdPrefix, userId)
|
||||
mobileKey := fmt.Sprintf("%s%v", cacheUserInfoMobilePrefix, mobile)
|
||||
_, err := m.Exec(func(conn sqlx.SqlConn) (result sql.Result, err error) {
|
||||
query := `delete from ` + m.table + + `where user_id = ?`
|
||||
return conn.Exec(query, userId)
|
||||
}, userIdKey, mobileKey)
|
||||
return err
|
||||
}
|
||||
```
|
||||
- withoutCache
|
||||
```
|
||||
func (m *UserInfoModel) Delete(userId int64) error {
|
||||
_, err := m.Exec(func(conn sqlx.SqlConn) (result sql.Result, err error) {
|
||||
query := `delete from ` + m.table + + `where user_id = ?`
|
||||
return conn.Exec(query, userId)
|
||||
}, )
|
||||
return err
|
||||
}
|
||||
```
|
||||
<h2>Insert生成</h2>
|
||||
# QA
|
||||
|
||||
<h2>Update生成</h2>
|
||||
* goctl model除了命令行模式,支持插件模式吗?
|
||||
|
||||
<h2>待完善(TODO)</h2>
|
||||
很快支持idea插件。
|
||||
|
||||
- 同一字段多种查询方式代码生成(优先级较高)
|
||||
- 条件查询
|
||||
- 范围查询
|
||||
- ...
|
||||
|
||||
<h2>反馈与建议</h2>
|
||||
|
||||
|
||||
- 无
|
||||
|
||||
|
||||
|
||||
97
tools/goctl/model/sql/builderx/builder.go
Normal file
97
tools/goctl/model/sql/builderx/builder.go
Normal file
@@ -0,0 +1,97 @@
|
||||
package builderx
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"reflect"
|
||||
|
||||
"github.com/go-xorm/builder"
|
||||
)
|
||||
|
||||
const dbTag = "db"
|
||||
|
||||
func NewEq(in interface{}) builder.Eq {
|
||||
return builder.Eq(ToMap(in))
|
||||
}
|
||||
|
||||
func NewGt(in interface{}) builder.Gt {
|
||||
return builder.Gt(ToMap(in))
|
||||
}
|
||||
|
||||
func ToMap(in interface{}) map[string]interface{} {
|
||||
out := make(map[string]interface{})
|
||||
v := reflect.ValueOf(in)
|
||||
if v.Kind() == reflect.Ptr {
|
||||
v = v.Elem()
|
||||
}
|
||||
// we only accept structs
|
||||
if v.Kind() != reflect.Struct {
|
||||
panic(fmt.Errorf("ToMap only accepts structs; got %T", v))
|
||||
}
|
||||
typ := v.Type()
|
||||
for i := 0; i < v.NumField(); i++ {
|
||||
// gets us a StructField
|
||||
fi := typ.Field(i)
|
||||
if tagv := fi.Tag.Get(dbTag); tagv != "" {
|
||||
// set key of map to value in struct field
|
||||
val := v.Field(i)
|
||||
zero := reflect.Zero(val.Type()).Interface()
|
||||
current := val.Interface()
|
||||
|
||||
if reflect.DeepEqual(current, zero) {
|
||||
continue
|
||||
}
|
||||
out[tagv] = current
|
||||
}
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func FieldNames(in interface{}) []string {
|
||||
out := make([]string, 0)
|
||||
v := reflect.ValueOf(in)
|
||||
if v.Kind() == reflect.Ptr {
|
||||
v = v.Elem()
|
||||
}
|
||||
// we only accept structs
|
||||
if v.Kind() != reflect.Struct {
|
||||
panic(fmt.Errorf("ToMap only accepts structs; got %T", v))
|
||||
}
|
||||
typ := v.Type()
|
||||
for i := 0; i < v.NumField(); i++ {
|
||||
// gets us a StructField
|
||||
fi := typ.Field(i)
|
||||
if tagv := fi.Tag.Get(dbTag); tagv != "" {
|
||||
out = append(out, tagv)
|
||||
} else {
|
||||
out = append(out, fi.Name)
|
||||
}
|
||||
}
|
||||
return out
|
||||
}
|
||||
func FieldNamesAlias(in interface{}, alias string) []string {
|
||||
out := make([]string, 0)
|
||||
v := reflect.ValueOf(in)
|
||||
if v.Kind() == reflect.Ptr {
|
||||
v = v.Elem()
|
||||
}
|
||||
// we only accept structs
|
||||
if v.Kind() != reflect.Struct {
|
||||
panic(fmt.Errorf("ToMap only accepts structs; got %T", v))
|
||||
}
|
||||
typ := v.Type()
|
||||
for i := 0; i < v.NumField(); i++ {
|
||||
// gets us a StructField
|
||||
fi := typ.Field(i)
|
||||
tagName := ""
|
||||
if tagv := fi.Tag.Get(dbTag); tagv != "" {
|
||||
tagName = tagv
|
||||
} else {
|
||||
tagName = fi.Name
|
||||
}
|
||||
if len(alias) > 0 {
|
||||
tagName = alias + "." + tagName
|
||||
}
|
||||
out = append(out, tagName)
|
||||
}
|
||||
return out
|
||||
}
|
||||
101
tools/goctl/model/sql/builderx/builder_test.go
Normal file
101
tools/goctl/model/sql/builderx/builder_test.go
Normal file
@@ -0,0 +1,101 @@
|
||||
package builderx
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
"github.com/go-xorm/builder"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
type (
|
||||
User struct {
|
||||
// 自增id
|
||||
Id string `db:"id" json:"id,omitempty"`
|
||||
// 姓名
|
||||
UserName string `db:"user_name" json:"userName,omitempty"`
|
||||
// 1男,2女
|
||||
Sex int `db:"sex" json:"sex,omitempty"`
|
||||
|
||||
Uuid string `db:"uuid" uuid:"uuid,omitempty"`
|
||||
|
||||
Age int `db:"age" json:"age"`
|
||||
}
|
||||
)
|
||||
|
||||
var userFields = FieldNames(User{})
|
||||
|
||||
func TestFieldNames(t *testing.T) {
|
||||
var u User
|
||||
out := FieldNames(&u)
|
||||
fmt.Println(out)
|
||||
actual := []string{"id", "user_name", "sex", "uuid", "age"}
|
||||
assert.Equal(t, out, actual)
|
||||
}
|
||||
|
||||
func TestNewEq(t *testing.T) {
|
||||
u := &User{
|
||||
Id: "123456",
|
||||
UserName: "wahaha",
|
||||
}
|
||||
out := NewEq(u)
|
||||
fmt.Println(out)
|
||||
actual := builder.Eq{"id": "123456", "user_name": "wahaha"}
|
||||
assert.Equal(t, out, actual)
|
||||
}
|
||||
|
||||
// @see https://github.com/go-xorm/builder
|
||||
func TestBuilderSql(t *testing.T) {
|
||||
u := &User{
|
||||
Id: "123123",
|
||||
}
|
||||
fields := FieldNames(u)
|
||||
eq := NewEq(u)
|
||||
sql, args, err := builder.Select(fields...).From("user").Where(eq).ToSQL()
|
||||
fmt.Println(sql, args, err)
|
||||
|
||||
actualSql := "SELECT id,user_name,sex,uuid,age FROM user WHERE id=?"
|
||||
actualArgs := []interface{}{"123123"}
|
||||
assert.Equal(t, sql, actualSql)
|
||||
assert.Equal(t, args, actualArgs)
|
||||
}
|
||||
|
||||
func TestBuildSqlDefaultValue(t *testing.T) {
|
||||
var eq = builder.Eq{}
|
||||
eq["age"] = 0
|
||||
eq["user_name"] = ""
|
||||
|
||||
sql, args, err := builder.Select(userFields...).From("user").Where(eq).ToSQL()
|
||||
fmt.Println(sql, args, err)
|
||||
|
||||
actualSql := "SELECT id,user_name,sex,uuid,age FROM user WHERE age=? AND user_name=?"
|
||||
actualArgs := []interface{}{0, ""}
|
||||
assert.Equal(t, sql, actualSql)
|
||||
assert.Equal(t, args, actualArgs)
|
||||
}
|
||||
|
||||
func TestBuilderSqlIn(t *testing.T) {
|
||||
u := &User{
|
||||
Age: 18,
|
||||
}
|
||||
gtU := NewGt(u)
|
||||
in := builder.In("id", []string{"1", "2", "3"})
|
||||
sql, args, err := builder.Select(userFields...).From("user").Where(in).And(gtU).ToSQL()
|
||||
fmt.Println(sql, args, err)
|
||||
|
||||
actualSql := "SELECT id,user_name,sex,uuid,age FROM user WHERE id IN (?,?,?) AND age>?"
|
||||
actualArgs := []interface{}{"1", "2", "3", 18}
|
||||
assert.Equal(t, sql, actualSql)
|
||||
assert.Equal(t, args, actualArgs)
|
||||
}
|
||||
|
||||
func TestBuildSqlLike(t *testing.T) {
|
||||
like := builder.Like{"name", "wang"}
|
||||
sql, args, err := builder.Select(userFields...).From("user").Where(like).ToSQL()
|
||||
fmt.Println(sql, args, err)
|
||||
|
||||
actualSql := "SELECT id,user_name,sex,uuid,age FROM user WHERE name LIKE ?"
|
||||
actualArgs := []interface{}{"%wang%"}
|
||||
assert.Equal(t, sql, actualSql)
|
||||
assert.Equal(t, args, actualArgs)
|
||||
}
|
||||
86
tools/goctl/model/sql/command/command.go
Normal file
86
tools/goctl/model/sql/command/command.go
Normal file
@@ -0,0 +1,86 @@
|
||||
package command
|
||||
|
||||
import (
|
||||
"io/ioutil"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
|
||||
"github.com/tal-tech/go-zero/core/collection"
|
||||
"github.com/tal-tech/go-zero/core/logx"
|
||||
"github.com/tal-tech/go-zero/core/stores/sqlx"
|
||||
"github.com/tal-tech/go-zero/tools/goctl/model/sql/gen"
|
||||
"github.com/tal-tech/go-zero/tools/goctl/model/sql/model"
|
||||
"github.com/tal-tech/go-zero/tools/goctl/util/console"
|
||||
"github.com/urfave/cli"
|
||||
)
|
||||
|
||||
const (
|
||||
flagSrc = "src"
|
||||
flagDir = "dir"
|
||||
flagCache = "cache"
|
||||
flagIdea = "idea"
|
||||
flagUrl = "url"
|
||||
flagTable = "table"
|
||||
)
|
||||
|
||||
func MysqlDDL(ctx *cli.Context) error {
|
||||
src := ctx.String(flagSrc)
|
||||
dir := ctx.String(flagDir)
|
||||
cache := ctx.Bool(flagCache)
|
||||
idea := ctx.Bool(flagIdea)
|
||||
log := console.NewConsole(idea)
|
||||
fileSrc, err := filepath.Abs(src)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
data, err := ioutil.ReadFile(fileSrc)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
source := string(data)
|
||||
generator := gen.NewDefaultGenerator(source, dir, gen.WithConsoleOption(log))
|
||||
err = generator.Start(cache)
|
||||
if err != nil {
|
||||
log.Error("%v", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func MyDataSource(ctx *cli.Context) error {
|
||||
url := strings.TrimSpace(ctx.String(flagUrl))
|
||||
dir := strings.TrimSpace(ctx.String(flagDir))
|
||||
cache := ctx.Bool(flagCache)
|
||||
idea := ctx.Bool(flagIdea)
|
||||
table := strings.TrimSpace(ctx.String(flagTable))
|
||||
log := console.NewConsole(idea)
|
||||
if len(url) == 0 {
|
||||
log.Error("%v", "expected data source of mysql, but is empty")
|
||||
return nil
|
||||
}
|
||||
if len(table) == 0 {
|
||||
log.Error("%v", "expected table(s), but nothing found")
|
||||
return nil
|
||||
}
|
||||
logx.Disable()
|
||||
conn := sqlx.NewMysql(url)
|
||||
m := model.NewDDLModel(conn)
|
||||
tables := collection.NewSet()
|
||||
for _, item := range strings.Split(table, ",") {
|
||||
item = strings.TrimSpace(item)
|
||||
if len(item) == 0 {
|
||||
continue
|
||||
}
|
||||
tables.AddStr(item)
|
||||
}
|
||||
ddl, err := m.ShowDDL(tables.KeysStr()...)
|
||||
if err != nil {
|
||||
log.Error("%v", err)
|
||||
return nil
|
||||
}
|
||||
generator := gen.NewDefaultGenerator(strings.Join(ddl, "\n"), dir, gen.WithConsoleOption(log))
|
||||
err = generator.Start(cache)
|
||||
if err != nil {
|
||||
log.Error("%v", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -1,19 +1,25 @@
|
||||
package model
|
||||
package converter
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
)
|
||||
|
||||
var (
|
||||
CommonMysqlDataTypeMap = map[string]string{
|
||||
"tinyint": "int",
|
||||
"smallint": "int",
|
||||
commonMysqlDataTypeMap = map[string]string{
|
||||
// For consistency, all integer types are converted to int64
|
||||
"tinyint": "int64",
|
||||
"smallint": "int64",
|
||||
"mediumint": "int64",
|
||||
"int": "int64",
|
||||
"integer": "int64",
|
||||
"bigint": "int64",
|
||||
"float": "float32",
|
||||
"float": "float64",
|
||||
"double": "float64",
|
||||
"decimal": "float64",
|
||||
"date": "time.Time",
|
||||
"time": "string",
|
||||
"year": "int",
|
||||
"year": "int64",
|
||||
"datetime": "time.Time",
|
||||
"timestamp": "time.Time",
|
||||
"char": "string",
|
||||
@@ -29,6 +35,12 @@ var (
|
||||
}
|
||||
)
|
||||
|
||||
const (
|
||||
ModeDirPerm = 0755
|
||||
)
|
||||
func ConvertDataType(dataBaseType string) (goDataType string, err error) {
|
||||
tp, ok := commonMysqlDataTypeMap[strings.ToLower(dataBaseType)]
|
||||
if !ok {
|
||||
err = fmt.Errorf("unexpected database type: %s", dataBaseType)
|
||||
return
|
||||
}
|
||||
goDataType = tp
|
||||
return
|
||||
}
|
||||
7
tools/goctl/model/sql/example/generator.sh
Normal file
7
tools/goctl/model/sql/example/generator.sh
Normal file
@@ -0,0 +1,7 @@
|
||||
#!/bin/bash
|
||||
|
||||
# generate model with cache from ddl
|
||||
goctl model mysql ddl -src="./sql/user.sql" -dir="./sql/model" -c=true
|
||||
|
||||
# generate model with cache from data source
|
||||
goctl model mysql datasource -url="user:password@tcp(127.0.0.1:3306)/database" -table="table1,table2" -dir="./model"
|
||||
15
tools/goctl/model/sql/example/sql/user.sql
Normal file
15
tools/goctl/model/sql/example/sql/user.sql
Normal file
@@ -0,0 +1,15 @@
|
||||
-- 用户表 --
|
||||
CREATE TABLE `user` (
|
||||
`id` bigint(10) NOT NULL AUTO_INCREMENT,
|
||||
`name` varchar(255) COLLATE utf8mb4_general_ci NOT NULL DEFAULT '' COMMENT '用户名称',
|
||||
`password` varchar(255) COLLATE utf8mb4_general_ci NOT NULL DEFAULT '' COMMENT '用户密码',
|
||||
`mobile` varchar(255) COLLATE utf8mb4_general_ci NOT NULL DEFAULT '' COMMENT '手机号',
|
||||
`gender` char(5) COLLATE utf8mb4_general_ci NOT NULL COMMENT '男|女|未公开',
|
||||
`nickname` varchar(255) COLLATE utf8mb4_general_ci DEFAULT '' COMMENT '用户昵称',
|
||||
`create_time` timestamp NULL DEFAULT CURRENT_TIMESTAMP,
|
||||
`update_time` timestamp NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,
|
||||
PRIMARY KEY (`id`),
|
||||
UNIQUE KEY `name_index` (`name`),
|
||||
KEY `mobile_index` (`mobile`)
|
||||
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_general_ci;
|
||||
|
||||
@@ -1,108 +0,0 @@
|
||||
package gen
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"sort"
|
||||
"strings"
|
||||
|
||||
"github.com/tal-tech/go-zero/tools/goctl/model/sql/util"
|
||||
)
|
||||
|
||||
func TableConvert(outerTable OuterTable) (*InnerTable, error) {
|
||||
var table InnerTable
|
||||
table.CreateNotFound = outerTable.CreateNotFound
|
||||
tableSnakeCase, tableUpperCamelCase, tableLowerCamelCase := util.FormatField(outerTable.Table)
|
||||
table.SnakeCase = tableSnakeCase
|
||||
table.UpperCamelCase = tableUpperCamelCase
|
||||
table.LowerCamelCase = tableLowerCamelCase
|
||||
fields := make([]*InnerField, 0)
|
||||
var primaryField *InnerField
|
||||
conflict := make(map[string]struct{})
|
||||
var containsCache bool
|
||||
for _, field := range outerTable.Fields {
|
||||
if field.Cache && !containsCache {
|
||||
containsCache = true
|
||||
}
|
||||
fieldSnakeCase, fieldUpperCamelCase, fieldLowerCamelCase := util.FormatField(field.Name)
|
||||
tag, err := genTag(fieldSnakeCase)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var comment string
|
||||
if field.Comment != "" {
|
||||
comment = fmt.Sprintf("// %s", field.Comment)
|
||||
}
|
||||
withFields := make([]InnerWithField, 0)
|
||||
unique := make([]string, 0)
|
||||
unique = append(unique, fmt.Sprintf("%v", field.QueryType))
|
||||
unique = append(unique, field.Name)
|
||||
|
||||
for _, item := range field.WithFields {
|
||||
unique = append(unique, item.Name)
|
||||
withFieldSnakeCase, withFieldUpperCamelCase, withFieldLowerCamelCase := util.FormatField(item.Name)
|
||||
withFields = append(withFields, InnerWithField{
|
||||
Case: Case{
|
||||
SnakeCase: withFieldSnakeCase,
|
||||
LowerCamelCase: withFieldLowerCamelCase,
|
||||
UpperCamelCase: withFieldUpperCamelCase,
|
||||
},
|
||||
DataType: commonMysqlDataTypeMap[item.DataBaseType],
|
||||
})
|
||||
}
|
||||
sort.Strings(unique)
|
||||
uniqueKey := strings.Join(unique, "#")
|
||||
if _, ok := conflict[uniqueKey]; ok {
|
||||
return nil, ErrCircleQuery
|
||||
} else {
|
||||
conflict[uniqueKey] = struct{}{}
|
||||
}
|
||||
sortFields := make([]InnerSort, 0)
|
||||
for _, sortField := range field.OuterSort {
|
||||
sortSnake, sortUpperCamelCase, sortLowerCamelCase := util.FormatField(sortField.Field)
|
||||
sortFields = append(sortFields, InnerSort{
|
||||
Field: Case{
|
||||
SnakeCase: sortSnake,
|
||||
LowerCamelCase: sortUpperCamelCase,
|
||||
UpperCamelCase: sortLowerCamelCase,
|
||||
},
|
||||
Asc: sortField.Asc,
|
||||
})
|
||||
}
|
||||
innerField := &InnerField{
|
||||
IsPrimaryKey: field.IsPrimaryKey,
|
||||
InnerWithField: InnerWithField{
|
||||
Case: Case{
|
||||
SnakeCase: fieldSnakeCase,
|
||||
LowerCamelCase: fieldLowerCamelCase,
|
||||
UpperCamelCase: fieldUpperCamelCase,
|
||||
},
|
||||
DataType: commonMysqlDataTypeMap[field.DataBaseType],
|
||||
},
|
||||
DataBaseType: field.DataBaseType,
|
||||
Tag: tag,
|
||||
Comment: comment,
|
||||
Cache: field.Cache,
|
||||
QueryType: field.QueryType,
|
||||
WithFields: withFields,
|
||||
Sort: sortFields,
|
||||
}
|
||||
if field.IsPrimaryKey {
|
||||
primaryField = innerField
|
||||
}
|
||||
fields = append(fields, innerField)
|
||||
}
|
||||
if primaryField == nil {
|
||||
return nil, errors.New("please ensure that primary exists")
|
||||
}
|
||||
table.ContainsCache = containsCache
|
||||
primaryField.Cache = containsCache
|
||||
table.PrimaryField = primaryField
|
||||
table.Fields = fields
|
||||
cacheKey, err := genCacheKeys(&table)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
table.CacheKey = cacheKey
|
||||
return &table, nil
|
||||
}
|
||||
@@ -1,51 +1,47 @@
|
||||
package gen
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"strings"
|
||||
"text/template"
|
||||
|
||||
sqltemplate "github.com/tal-tech/go-zero/tools/goctl/model/sql/template"
|
||||
"github.com/tal-tech/go-zero/core/collection"
|
||||
"github.com/tal-tech/go-zero/tools/goctl/model/sql/template"
|
||||
"github.com/tal-tech/go-zero/tools/goctl/util/stringx"
|
||||
"github.com/tal-tech/go-zero/tools/goctl/util/templatex"
|
||||
)
|
||||
|
||||
func genDelete(table *InnerTable) (string, error) {
|
||||
t, err := template.New("delete").Parse(sqltemplate.Delete)
|
||||
if err != nil {
|
||||
return "", nil
|
||||
}
|
||||
deleteBuffer := new(bytes.Buffer)
|
||||
keys := make([]string, 0)
|
||||
keyValues := make([]string, 0)
|
||||
for snake, key := range table.CacheKey {
|
||||
if snake == table.PrimaryField.SnakeCase {
|
||||
keys = append(keys, key.Key)
|
||||
func genDelete(table Table, withCache bool) (string, error) {
|
||||
keySet := collection.NewSet()
|
||||
keyVariableSet := collection.NewSet()
|
||||
for fieldName, key := range table.CacheKey {
|
||||
if fieldName == table.PrimaryKey.Name.Source() {
|
||||
keySet.AddStr(key.KeyExpression)
|
||||
} else {
|
||||
keys = append(keys, key.DataKey)
|
||||
keySet.AddStr(key.DataKeyExpression)
|
||||
}
|
||||
keyValues = append(keyValues, key.KeyVariable)
|
||||
keyVariableSet.AddStr(key.Variable)
|
||||
}
|
||||
var isOnlyPrimaryKeyCache = true
|
||||
var containsIndexCache = false
|
||||
for _, item := range table.Fields {
|
||||
if item.IsPrimaryKey {
|
||||
continue
|
||||
}
|
||||
if item.Cache {
|
||||
isOnlyPrimaryKeyCache = false
|
||||
if item.IsKey {
|
||||
containsIndexCache = true
|
||||
break
|
||||
}
|
||||
}
|
||||
err = t.Execute(deleteBuffer, map[string]interface{}{
|
||||
"upperObject": table.UpperCamelCase,
|
||||
"containsCache": table.ContainsCache,
|
||||
"isNotPrimaryKey": !isOnlyPrimaryKeyCache,
|
||||
"lowerPrimaryKey": table.PrimaryField.LowerCamelCase,
|
||||
"dataType": table.PrimaryField.DataType,
|
||||
"keys": strings.Join(keys, "\r\n"),
|
||||
"snakePrimaryKey": table.PrimaryField.SnakeCase,
|
||||
"keyValues": strings.Join(keyValues, ", "),
|
||||
})
|
||||
camel := table.Name.ToCamel()
|
||||
output, err := templatex.With("delete").
|
||||
Parse(template.Delete).
|
||||
Execute(map[string]interface{}{
|
||||
"upperStartCamelObject": camel,
|
||||
"withCache": withCache,
|
||||
"containsIndexCache": containsIndexCache,
|
||||
"lowerStartCamelPrimaryKey": stringx.From(table.PrimaryKey.Name.ToCamel()).UnTitle(),
|
||||
"dataType": table.PrimaryKey.DataType,
|
||||
"keys": strings.Join(keySet.KeysStr(), "\n"),
|
||||
"originalPrimaryKey": table.PrimaryKey.Name.Source(),
|
||||
"keyValues": strings.Join(keyVariableSet.KeysStr(), ", "),
|
||||
})
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return deleteBuffer.String(), nil
|
||||
return output.String(), nil
|
||||
}
|
||||
|
||||
@@ -1,15 +1,15 @@
|
||||
package gen
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"strings"
|
||||
"text/template"
|
||||
|
||||
sqltemplate "github.com/tal-tech/go-zero/tools/goctl/model/sql/template"
|
||||
"github.com/tal-tech/go-zero/tools/goctl/model/sql/parser"
|
||||
"github.com/tal-tech/go-zero/tools/goctl/model/sql/template"
|
||||
"github.com/tal-tech/go-zero/tools/goctl/util/templatex"
|
||||
)
|
||||
|
||||
func genFields(fields []*InnerField) (string, error) {
|
||||
list := make([]string, 0)
|
||||
func genFields(fields []parser.Field) (string, error) {
|
||||
var list []string
|
||||
for _, field := range fields {
|
||||
result, err := genField(field)
|
||||
if err != nil {
|
||||
@@ -17,23 +17,25 @@ func genFields(fields []*InnerField) (string, error) {
|
||||
}
|
||||
list = append(list, result)
|
||||
}
|
||||
return strings.Join(list, "\r\n"), nil
|
||||
return strings.Join(list, "\n"), nil
|
||||
}
|
||||
|
||||
func genField(field *InnerField) (string, error) {
|
||||
t, err := template.New("types").Parse(sqltemplate.Field)
|
||||
if err != nil {
|
||||
return "", nil
|
||||
}
|
||||
var typeBuffer = new(bytes.Buffer)
|
||||
err = t.Execute(typeBuffer, map[string]string{
|
||||
"name": field.UpperCamelCase,
|
||||
"type": field.DataType,
|
||||
"tag": field.Tag,
|
||||
"comment": field.Comment,
|
||||
})
|
||||
func genField(field parser.Field) (string, error) {
|
||||
tag, err := genTag(field.Name.Source())
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return typeBuffer.String(), nil
|
||||
output, err := templatex.With("types").
|
||||
Parse(template.Field).
|
||||
Execute(map[string]interface{}{
|
||||
"name": field.Name.ToCamel(),
|
||||
"type": field.DataType,
|
||||
"tag": tag,
|
||||
"hasComment": field.Comment != "",
|
||||
"comment": field.Comment,
|
||||
})
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return output.String(), nil
|
||||
}
|
||||
|
||||
@@ -1,55 +0,0 @@
|
||||
package gen
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"strings"
|
||||
"text/template"
|
||||
|
||||
sqltemplate "github.com/tal-tech/go-zero/tools/goctl/model/sql/template"
|
||||
)
|
||||
|
||||
func genFindAllByField(table *InnerTable) (string, error) {
|
||||
t, err := template.New("findAllByField").Parse(sqltemplate.FindAllByField)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
list := make([]string, 0)
|
||||
for _, field := range table.Fields {
|
||||
if field.IsPrimaryKey {
|
||||
continue
|
||||
}
|
||||
if field.QueryType != QueryAll {
|
||||
continue
|
||||
}
|
||||
fineOneByFieldBuffer := new(bytes.Buffer)
|
||||
upperFields := make([]string, 0)
|
||||
in := make([]string, 0)
|
||||
expressionFields := make([]string, 0)
|
||||
expressionValuesFields := make([]string, 0)
|
||||
upperFields = append(upperFields, field.UpperCamelCase)
|
||||
in = append(in, field.LowerCamelCase+" "+field.DataType)
|
||||
expressionFields = append(expressionFields, field.SnakeCase+" = ?")
|
||||
expressionValuesFields = append(expressionValuesFields, field.LowerCamelCase)
|
||||
for _, withField := range field.WithFields {
|
||||
upperFields = append(upperFields, withField.UpperCamelCase)
|
||||
in = append(in, withField.LowerCamelCase+" "+withField.DataType)
|
||||
expressionFields = append(expressionFields, withField.SnakeCase+" = ?")
|
||||
expressionValuesFields = append(expressionValuesFields, withField.LowerCamelCase)
|
||||
}
|
||||
err = t.Execute(fineOneByFieldBuffer, map[string]interface{}{
|
||||
"in": strings.Join(in, ","),
|
||||
"upperObject": table.UpperCamelCase,
|
||||
"upperFields": strings.Join(upperFields, "And"),
|
||||
"lowerObject": table.LowerCamelCase,
|
||||
"snakePrimaryKey": field.SnakeCase,
|
||||
"expression": strings.Join(expressionFields, " AND "),
|
||||
"expressionValues": strings.Join(expressionValuesFields, ", "),
|
||||
"containsCache": table.ContainsCache,
|
||||
})
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
list = append(list, fineOneByFieldBuffer.String())
|
||||
}
|
||||
return strings.Join(list, ""), nil
|
||||
}
|
||||
@@ -1,63 +0,0 @@
|
||||
package gen
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"strings"
|
||||
"text/template"
|
||||
|
||||
sqltemplate "github.com/tal-tech/go-zero/tools/goctl/model/sql/template"
|
||||
)
|
||||
|
||||
func genFindLimitByField(table *InnerTable) (string, error) {
|
||||
t, err := template.New("findLimitByField").Parse(sqltemplate.FindLimitByField)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
list := make([]string, 0)
|
||||
for _, field := range table.Fields {
|
||||
if field.IsPrimaryKey {
|
||||
continue
|
||||
}
|
||||
if field.QueryType != QueryLimit {
|
||||
continue
|
||||
}
|
||||
fineOneByFieldBuffer := new(bytes.Buffer)
|
||||
upperFields := make([]string, 0)
|
||||
in := make([]string, 0)
|
||||
expressionFields := make([]string, 0)
|
||||
expressionValuesFields := make([]string, 0)
|
||||
upperFields = append(upperFields, field.UpperCamelCase)
|
||||
in = append(in, field.LowerCamelCase+" "+field.DataType)
|
||||
expressionFields = append(expressionFields, field.SnakeCase+" = ?")
|
||||
expressionValuesFields = append(expressionValuesFields, field.LowerCamelCase)
|
||||
for _, withField := range field.WithFields {
|
||||
upperFields = append(upperFields, withField.UpperCamelCase)
|
||||
in = append(in, withField.LowerCamelCase+" "+withField.DataType)
|
||||
expressionFields = append(expressionFields, withField.SnakeCase+" = ?")
|
||||
expressionValuesFields = append(expressionValuesFields, withField.LowerCamelCase)
|
||||
}
|
||||
sortList := make([]string, 0)
|
||||
for _, item := range field.Sort {
|
||||
var sort = "ASC"
|
||||
if !item.Asc {
|
||||
sort = "DESC"
|
||||
}
|
||||
sortList = append(sortList, item.Field.SnakeCase+" "+sort)
|
||||
}
|
||||
err = t.Execute(fineOneByFieldBuffer, map[string]interface{}{
|
||||
"in": strings.Join(in, ","),
|
||||
"upperObject": table.UpperCamelCase,
|
||||
"upperFields": strings.Join(upperFields, "And"),
|
||||
"lowerObject": table.LowerCamelCase,
|
||||
"expression": strings.Join(expressionFields, " AND "),
|
||||
"expressionValues": strings.Join(expressionValuesFields, ", "),
|
||||
"sortExpression": strings.Join(sortList, ","),
|
||||
"containsCache": table.ContainsCache,
|
||||
})
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
list = append(list, fineOneByFieldBuffer.String())
|
||||
}
|
||||
return strings.Join(list, ""), nil
|
||||
}
|
||||
@@ -1,30 +1,27 @@
|
||||
package gen
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"text/template"
|
||||
|
||||
sqltemplate "github.com/tal-tech/go-zero/tools/goctl/model/sql/template"
|
||||
"github.com/tal-tech/go-zero/tools/goctl/model/sql/template"
|
||||
"github.com/tal-tech/go-zero/tools/goctl/util/stringx"
|
||||
"github.com/tal-tech/go-zero/tools/goctl/util/templatex"
|
||||
)
|
||||
|
||||
func genFindOne(table *InnerTable) (string, error) {
|
||||
t, err := template.New("findOne").Parse(sqltemplate.FindOne)
|
||||
func genFindOne(table Table, withCache bool) (string, error) {
|
||||
camel := table.Name.ToCamel()
|
||||
output, err := templatex.With("findOne").
|
||||
Parse(template.FindOne).
|
||||
Execute(map[string]interface{}{
|
||||
"withCache": withCache,
|
||||
"upperStartCamelObject": camel,
|
||||
"lowerStartCamelObject": stringx.From(camel).UnTitle(),
|
||||
"originalPrimaryKey": table.PrimaryKey.Name.Source(),
|
||||
"lowerStartCamelPrimaryKey": stringx.From(table.PrimaryKey.Name.ToCamel()).UnTitle(),
|
||||
"dataType": table.PrimaryKey.DataType,
|
||||
"cacheKey": table.CacheKey[table.PrimaryKey.Name.Source()].KeyExpression,
|
||||
"cacheKeyVariable": table.CacheKey[table.PrimaryKey.Name.Source()].Variable,
|
||||
})
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
fineOneBuffer := new(bytes.Buffer)
|
||||
err = t.Execute(fineOneBuffer, map[string]interface{}{
|
||||
"withCache": table.PrimaryField.Cache,
|
||||
"upperObject": table.UpperCamelCase,
|
||||
"lowerObject": table.LowerCamelCase,
|
||||
"snakePrimaryKey": table.PrimaryField.SnakeCase,
|
||||
"lowerPrimaryKey": table.PrimaryField.LowerCamelCase,
|
||||
"dataType": table.PrimaryField.DataType,
|
||||
"cacheKey": table.CacheKey[table.PrimaryField.SnakeCase].Key,
|
||||
"cacheKeyVariable": table.CacheKey[table.PrimaryField.SnakeCase].KeyVariable,
|
||||
})
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return fineOneBuffer.String(), nil
|
||||
return output.String(), nil
|
||||
}
|
||||
|
||||
@@ -1,67 +1,41 @@
|
||||
package gen
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"strings"
|
||||
"text/template"
|
||||
|
||||
sqltemplate "github.com/tal-tech/go-zero/tools/goctl/model/sql/template"
|
||||
"github.com/tal-tech/go-zero/tools/goctl/model/sql/template"
|
||||
"github.com/tal-tech/go-zero/tools/goctl/util/stringx"
|
||||
"github.com/tal-tech/go-zero/tools/goctl/util/templatex"
|
||||
)
|
||||
|
||||
func genFineOneByField(table *InnerTable) (string, error) {
|
||||
t, err := template.New("findOneByField").Parse(sqltemplate.FindOneByField)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
list := make([]string, 0)
|
||||
func genFineOneByField(table Table, withCache bool) (string, error) {
|
||||
t := templatex.With("findOneByField").Parse(template.FindOneByField)
|
||||
var list []string
|
||||
camelTableName := table.Name.ToCamel()
|
||||
for _, field := range table.Fields {
|
||||
if field.IsPrimaryKey {
|
||||
if field.IsPrimaryKey || !field.IsKey {
|
||||
continue
|
||||
}
|
||||
if field.QueryType != QueryOne {
|
||||
continue
|
||||
}
|
||||
fineOneByFieldBuffer := new(bytes.Buffer)
|
||||
upperFields := make([]string, 0)
|
||||
in := make([]string, 0)
|
||||
expressionFields := make([]string, 0)
|
||||
expressionValuesFields := make([]string, 0)
|
||||
upperFields = append(upperFields, field.UpperCamelCase)
|
||||
in = append(in, field.LowerCamelCase+" "+field.DataType)
|
||||
expressionFields = append(expressionFields, field.SnakeCase+" = ?")
|
||||
expressionValuesFields = append(expressionValuesFields, field.LowerCamelCase)
|
||||
for _, withField := range field.WithFields {
|
||||
upperFields = append(upperFields, withField.UpperCamelCase)
|
||||
in = append(in, withField.LowerCamelCase+" "+withField.DataType)
|
||||
expressionFields = append(expressionFields, withField.SnakeCase+" = ?")
|
||||
expressionValuesFields = append(expressionValuesFields, withField.LowerCamelCase)
|
||||
}
|
||||
err = t.Execute(fineOneByFieldBuffer, map[string]interface{}{
|
||||
"in": strings.Join(in, ","),
|
||||
"upperObject": table.UpperCamelCase,
|
||||
"upperFields": strings.Join(upperFields, "And"),
|
||||
"onlyOneFiled": len(field.WithFields) == 0,
|
||||
"withCache": field.Cache,
|
||||
"containsCache": table.ContainsCache,
|
||||
"lowerObject": table.LowerCamelCase,
|
||||
"lowerField": field.LowerCamelCase,
|
||||
"snakeField": field.SnakeCase,
|
||||
"lowerPrimaryKey": table.PrimaryField.LowerCamelCase,
|
||||
"UpperPrimaryKey": table.PrimaryField.UpperCamelCase,
|
||||
"primaryKeyDefine": table.CacheKey[table.PrimaryField.SnakeCase].Define,
|
||||
"primarySnakeField": table.PrimaryField.SnakeCase,
|
||||
"primaryDataType": table.PrimaryField.DataType,
|
||||
"primaryDataTypeString": table.PrimaryField.DataType == "string",
|
||||
"upperObjectKey": table.PrimaryField.UpperCamelCase,
|
||||
"cacheKey": table.CacheKey[field.SnakeCase].Key,
|
||||
"cacheKeyVariable": table.CacheKey[field.SnakeCase].KeyVariable,
|
||||
"expression": strings.Join(expressionFields, " AND "),
|
||||
"expressionValues": strings.Join(expressionValuesFields, ", "),
|
||||
camelFieldName := field.Name.ToCamel()
|
||||
output, err := t.Execute(map[string]interface{}{
|
||||
"upperStartCamelObject": camelTableName,
|
||||
"upperField": camelFieldName,
|
||||
"in": fmt.Sprintf("%s %s", stringx.From(camelFieldName).UnTitle(), field.DataType),
|
||||
"withCache": withCache,
|
||||
"cacheKey": table.CacheKey[field.Name.Source()].KeyExpression,
|
||||
"cacheKeyVariable": table.CacheKey[field.Name.Source()].Variable,
|
||||
"primaryKeyLeft": table.CacheKey[table.PrimaryKey.Name.Source()].Left,
|
||||
"lowerStartCamelObject": stringx.From(camelTableName).UnTitle(),
|
||||
"lowerStartCamelField": stringx.From(camelFieldName).UnTitle(),
|
||||
"upperStartCamelPrimaryKey": table.PrimaryKey.Name.ToCamel(),
|
||||
"originalField": field.Name.Source(),
|
||||
"originalPrimaryField": table.PrimaryKey.Name.Source(),
|
||||
})
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
list = append(list, fineOneByFieldBuffer.String())
|
||||
list = append(list, output.String())
|
||||
}
|
||||
return strings.Join(list, ""), nil
|
||||
return strings.Join(list, "\n"), nil
|
||||
}
|
||||
|
||||
183
tools/goctl/model/sql/gen/gen.go
Normal file
183
tools/goctl/model/sql/gen/gen.go
Normal file
@@ -0,0 +1,183 @@
|
||||
package gen
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
|
||||
"github.com/tal-tech/go-zero/tools/goctl/model/sql/parser"
|
||||
"github.com/tal-tech/go-zero/tools/goctl/model/sql/template"
|
||||
"github.com/tal-tech/go-zero/tools/goctl/util"
|
||||
"github.com/tal-tech/go-zero/tools/goctl/util/console"
|
||||
"github.com/tal-tech/go-zero/tools/goctl/util/stringx"
|
||||
"github.com/tal-tech/go-zero/tools/goctl/util/templatex"
|
||||
)
|
||||
|
||||
const (
|
||||
pwd = "."
|
||||
createTableFlag = `(?m)^(?i)CREATE\s+TABLE` // ignore case
|
||||
)
|
||||
|
||||
type (
|
||||
defaultGenerator struct {
|
||||
source string
|
||||
dir string
|
||||
console.Console
|
||||
}
|
||||
Option func(generator *defaultGenerator)
|
||||
)
|
||||
|
||||
func NewDefaultGenerator(source, dir string, opt ...Option) *defaultGenerator {
|
||||
if dir == "" {
|
||||
dir = pwd
|
||||
}
|
||||
generator := &defaultGenerator{source: source, dir: dir}
|
||||
var optionList []Option
|
||||
optionList = append(optionList, newDefaultOption())
|
||||
optionList = append(optionList, opt...)
|
||||
for _, fn := range optionList {
|
||||
fn(generator)
|
||||
}
|
||||
return generator
|
||||
}
|
||||
|
||||
func WithConsoleOption(c console.Console) Option {
|
||||
return func(generator *defaultGenerator) {
|
||||
generator.Console = c
|
||||
}
|
||||
}
|
||||
|
||||
func newDefaultOption() Option {
|
||||
return func(generator *defaultGenerator) {
|
||||
generator.Console = console.NewColorConsole()
|
||||
}
|
||||
}
|
||||
|
||||
func (g *defaultGenerator) Start(withCache bool) error {
|
||||
dirAbs, err := filepath.Abs(g.dir)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = util.MkdirIfNotExist(dirAbs)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
modelList, err := g.genFromDDL(withCache)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for tableName, code := range modelList {
|
||||
name := fmt.Sprintf("%smodel.go", strings.ToLower(stringx.From(tableName).ToCamel()))
|
||||
filename := filepath.Join(dirAbs, name)
|
||||
if util.FileExists(filename) {
|
||||
g.Warning("%s already exists, ignored.", name)
|
||||
continue
|
||||
}
|
||||
err = ioutil.WriteFile(filename, []byte(code), os.ModePerm)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
// generate error file
|
||||
filename := filepath.Join(dirAbs, "error.go")
|
||||
if !util.FileExists(filename) {
|
||||
err = ioutil.WriteFile(filename, []byte(template.Error), os.ModePerm)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
g.Success("Done.")
|
||||
return nil
|
||||
}
|
||||
|
||||
// ret1: key-table name,value-code
|
||||
func (g *defaultGenerator) genFromDDL(withCache bool) (map[string]string, error) {
|
||||
ddlList := g.split()
|
||||
m := make(map[string]string)
|
||||
for _, ddl := range ddlList {
|
||||
table, err := parser.Parse(ddl)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
code, err := g.genModel(*table, withCache)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
m[table.Name.Source()] = code
|
||||
}
|
||||
return m, nil
|
||||
}
|
||||
|
||||
type (
|
||||
Table struct {
|
||||
parser.Table
|
||||
CacheKey map[string]Key
|
||||
}
|
||||
)
|
||||
|
||||
func (g *defaultGenerator) genModel(in parser.Table, withCache bool) (string, error) {
|
||||
t := templatex.With("model").
|
||||
Parse(template.Model).
|
||||
GoFmt(true)
|
||||
|
||||
m, err := genCacheKeys(in)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
importsCode := genImports(withCache)
|
||||
var table Table
|
||||
table.Table = in
|
||||
table.CacheKey = m
|
||||
|
||||
varsCode, err := genVars(table, withCache)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
typesCode, err := genTypes(table, withCache)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
newCode, err := genNew(table, withCache)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
insertCode, err := genInsert(table, withCache)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
var findCode = make([]string, 0)
|
||||
findOneCode, err := genFindOne(table, withCache)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
findOneByFieldCode, err := genFineOneByField(table, withCache)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
findCode = append(findCode, findOneCode, findOneByFieldCode)
|
||||
updateCode, err := genUpdate(table, withCache)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
deleteCode, err := genDelete(table, withCache)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
output, err := t.Execute(map[string]interface{}{
|
||||
"imports": importsCode,
|
||||
"vars": varsCode,
|
||||
"types": typesCode,
|
||||
"new": newCode,
|
||||
"insert": insertCode,
|
||||
"find": strings.Join(findCode, "\r\n"),
|
||||
"update": updateCode,
|
||||
"delete": deleteCode,
|
||||
})
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return output.String(), nil
|
||||
}
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user