Compare commits
51 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
57b73d8b49 | ||
|
|
a79cee12ee | ||
|
|
7a921f66e6 | ||
|
|
12e235efb0 | ||
|
|
01060cf16d | ||
|
|
0786862a35 | ||
|
|
efa43483b2 | ||
|
|
771371e051 | ||
|
|
2ee95f8981 | ||
|
|
5bc01e4bfd | ||
|
|
510e966982 | ||
|
|
10e3b8ac80 | ||
|
|
04059bbf5a | ||
|
|
d643007c79 | ||
|
|
fc43876cc5 | ||
|
|
a926cb514f | ||
|
|
25cab2f273 | ||
|
|
8d2e2753a2 | ||
|
|
cc4c50e3eb | ||
|
|
751072bdb0 | ||
|
|
e97e1f10db | ||
|
|
0bd2a0656c | ||
|
|
71a2b20301 | ||
|
|
8df7de94e3 | ||
|
|
bf21203297 | ||
|
|
ae98375194 | ||
|
|
82d1ccf376 | ||
|
|
bb6d49c17e | ||
|
|
ed735ec47c | ||
|
|
ba4bac3a03 | ||
|
|
08433d7e04 | ||
|
|
a3b525b50d | ||
|
|
097f6886f2 | ||
|
|
07a1549634 | ||
|
|
befca26c58 | ||
|
|
3556a2eef4 | ||
|
|
807765f77e | ||
|
|
e44584e549 | ||
|
|
acd48f0abb | ||
|
|
f919bc6713 | ||
|
|
a0030b8f45 | ||
|
|
a5f0cce1b1 | ||
|
|
4d13dda605 | ||
|
|
b56cc8e459 | ||
|
|
c435811479 | ||
|
|
c686c93fb5 | ||
|
|
da8f76e6bd | ||
|
|
99596a4149 | ||
|
|
ec2a9f2c57 | ||
|
|
fd73ced6dc | ||
|
|
5071736ab4 |
67
.github/workflows/codeql-analysis.yml
vendored
Normal file
67
.github/workflows/codeql-analysis.yml
vendored
Normal file
@@ -0,0 +1,67 @@
|
||||
# For most projects, this workflow file will not need changing; you simply need
|
||||
# to commit it to your repository.
|
||||
#
|
||||
# You may wish to alter this file to override the set of languages analyzed,
|
||||
# or to provide custom queries or build logic.
|
||||
#
|
||||
# ******** NOTE ********
|
||||
# We have attempted to detect the languages in your repository. Please check
|
||||
# the `language` matrix defined below to confirm you have the correct set of
|
||||
# supported CodeQL languages.
|
||||
#
|
||||
name: "CodeQL"
|
||||
|
||||
on:
|
||||
push:
|
||||
branches: [ master ]
|
||||
pull_request:
|
||||
# The branches below must be a subset of the branches above
|
||||
branches: [ master ]
|
||||
schedule:
|
||||
- cron: '18 19 * * 6'
|
||||
|
||||
jobs:
|
||||
analyze:
|
||||
name: Analyze
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
language: [ 'go' ]
|
||||
# CodeQL supports [ 'cpp', 'csharp', 'go', 'java', 'javascript', 'python' ]
|
||||
# Learn more:
|
||||
# https://docs.github.com/en/free-pro-team@latest/github/finding-security-vulnerabilities-and-errors-in-your-code/configuring-code-scanning#changing-the-languages-that-are-analyzed
|
||||
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v2
|
||||
|
||||
# Initializes the CodeQL tools for scanning.
|
||||
- name: Initialize CodeQL
|
||||
uses: github/codeql-action/init@v1
|
||||
with:
|
||||
languages: ${{ matrix.language }}
|
||||
# If you wish to specify custom queries, you can do so here or in a config file.
|
||||
# By default, queries listed here will override any specified in a config file.
|
||||
# Prefix the list here with "+" to use these queries and those in the config file.
|
||||
# queries: ./path/to/local/query, your-org/your-repo/queries@main
|
||||
|
||||
# Autobuild attempts to build any compiled languages (C/C++, C#, or Java).
|
||||
# If this step fails, then you should remove it and run the build manually (see below)
|
||||
- name: Autobuild
|
||||
uses: github/codeql-action/autobuild@v1
|
||||
|
||||
# ℹ️ Command-line programs to run using the OS shell.
|
||||
# 📚 https://git.io/JvXDl
|
||||
|
||||
# ✏️ If the Autobuild fails above, remove it and uncomment the following three lines
|
||||
# and modify them (or add more) to build your code if your project
|
||||
# uses a compiled language
|
||||
|
||||
#- run: |
|
||||
# make bootstrap
|
||||
# make release
|
||||
|
||||
- name: Perform CodeQL Analysis
|
||||
uses: github/codeql-action/analyze@v1
|
||||
2
.gitignore
vendored
2
.gitignore
vendored
@@ -4,6 +4,7 @@
|
||||
# Unignore all with extensions
|
||||
!*.*
|
||||
!**/Dockerfile
|
||||
!**/Makefile
|
||||
|
||||
# Unignore all dirs
|
||||
!*/
|
||||
@@ -12,7 +13,6 @@
|
||||
.idea
|
||||
**/.DS_Store
|
||||
**/logs
|
||||
!Makefile
|
||||
|
||||
# gitlab ci
|
||||
.cache
|
||||
|
||||
@@ -8,8 +8,10 @@ import (
|
||||
)
|
||||
|
||||
type (
|
||||
// RollingWindowOption let callers customize the RollingWindow.
|
||||
RollingWindowOption func(rollingWindow *RollingWindow)
|
||||
|
||||
// RollingWindow defines a rolling window to calculate the events in buckets with time interval.
|
||||
RollingWindow struct {
|
||||
lock sync.RWMutex
|
||||
size int
|
||||
@@ -17,10 +19,12 @@ type (
|
||||
interval time.Duration
|
||||
offset int
|
||||
ignoreCurrent bool
|
||||
lastTime time.Duration
|
||||
lastTime time.Duration // start time of the last bucket
|
||||
}
|
||||
)
|
||||
|
||||
// NewRollingWindow returns a RollingWindow that with size buckets and time interval,
|
||||
// use opts to customize the RollingWindow.
|
||||
func NewRollingWindow(size int, interval time.Duration, opts ...RollingWindowOption) *RollingWindow {
|
||||
if size < 1 {
|
||||
panic("size must be greater than 0")
|
||||
@@ -38,6 +42,7 @@ func NewRollingWindow(size int, interval time.Duration, opts ...RollingWindowOpt
|
||||
return w
|
||||
}
|
||||
|
||||
// Add adds value to current bucket.
|
||||
func (rw *RollingWindow) Add(v float64) {
|
||||
rw.lock.Lock()
|
||||
defer rw.lock.Unlock()
|
||||
@@ -45,6 +50,7 @@ func (rw *RollingWindow) Add(v float64) {
|
||||
rw.win.add(rw.offset, v)
|
||||
}
|
||||
|
||||
// Reduce runs fn on all buckets, ignore current bucket if ignoreCurrent was set.
|
||||
func (rw *RollingWindow) Reduce(fn func(b *Bucket)) {
|
||||
rw.lock.RLock()
|
||||
defer rw.lock.RUnlock()
|
||||
@@ -79,26 +85,18 @@ func (rw *RollingWindow) updateOffset() {
|
||||
}
|
||||
|
||||
offset := rw.offset
|
||||
start := offset + 1
|
||||
steps := start + span
|
||||
var remainder int
|
||||
if steps > rw.size {
|
||||
remainder = steps - rw.size
|
||||
steps = rw.size
|
||||
}
|
||||
|
||||
// reset expired buckets
|
||||
for i := start; i < steps; i++ {
|
||||
rw.win.resetBucket(i)
|
||||
}
|
||||
for i := 0; i < remainder; i++ {
|
||||
rw.win.resetBucket(i)
|
||||
for i := 0; i < span; i++ {
|
||||
rw.win.resetBucket((offset + i + 1) % rw.size)
|
||||
}
|
||||
|
||||
rw.offset = (offset + span) % rw.size
|
||||
rw.lastTime = timex.Now()
|
||||
now := timex.Now()
|
||||
// align to interval time boundary
|
||||
rw.lastTime = now - (now-rw.lastTime)%rw.interval
|
||||
}
|
||||
|
||||
// Bucket defines the bucket that holds sum and num of additions.
|
||||
type Bucket struct {
|
||||
Sum float64
|
||||
Count int64
|
||||
@@ -144,6 +142,7 @@ func (w *window) resetBucket(offset int) {
|
||||
w.buckets[offset%w.size].reset()
|
||||
}
|
||||
|
||||
// IgnoreCurrentBucket lets the Reduce call ignore current bucket.
|
||||
func IgnoreCurrentBucket() RollingWindowOption {
|
||||
return func(w *RollingWindow) {
|
||||
w.ignoreCurrent = true
|
||||
|
||||
@@ -105,6 +105,37 @@ func TestRollingWindowReduce(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestRollingWindowBucketTimeBoundary(t *testing.T) {
|
||||
const size = 3
|
||||
interval := time.Millisecond * 30
|
||||
r := NewRollingWindow(size, interval)
|
||||
listBuckets := func() []float64 {
|
||||
var buckets []float64
|
||||
r.Reduce(func(b *Bucket) {
|
||||
buckets = append(buckets, b.Sum)
|
||||
})
|
||||
return buckets
|
||||
}
|
||||
assert.Equal(t, []float64{0, 0, 0}, listBuckets())
|
||||
r.Add(1)
|
||||
assert.Equal(t, []float64{0, 0, 1}, listBuckets())
|
||||
time.Sleep(time.Millisecond * 45)
|
||||
r.Add(2)
|
||||
r.Add(3)
|
||||
assert.Equal(t, []float64{0, 1, 5}, listBuckets())
|
||||
// sleep time should be less than interval, and make the bucket change happen
|
||||
time.Sleep(time.Millisecond * 20)
|
||||
r.Add(4)
|
||||
r.Add(5)
|
||||
r.Add(6)
|
||||
assert.Equal(t, []float64{1, 5, 15}, listBuckets())
|
||||
time.Sleep(time.Millisecond * 100)
|
||||
r.Add(7)
|
||||
r.Add(8)
|
||||
r.Add(9)
|
||||
assert.Equal(t, []float64{0, 0, 24}, listBuckets())
|
||||
}
|
||||
|
||||
func TestRollingWindowDataRace(t *testing.T) {
|
||||
const size = 3
|
||||
r := NewRollingWindow(size, duration)
|
||||
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"log"
|
||||
"os"
|
||||
"path"
|
||||
|
||||
"github.com/tal-tech/go-zero/core/mapping"
|
||||
@@ -19,7 +20,7 @@ func LoadConfig(file string, v interface{}) error {
|
||||
if content, err := ioutil.ReadFile(file); err != nil {
|
||||
return err
|
||||
} else if loader, ok := loaders[path.Ext(file)]; ok {
|
||||
return loader(content, v)
|
||||
return loader([]byte(os.ExpandEnv(string(content))), v)
|
||||
} else {
|
||||
return fmt.Errorf("unrecoginized file type: %s", file)
|
||||
}
|
||||
|
||||
@@ -17,13 +17,14 @@ func TestConfigJson(t *testing.T) {
|
||||
}
|
||||
text := `{
|
||||
"a": "foo",
|
||||
"b": 1
|
||||
"b": 1,
|
||||
"c": "${FOO}"
|
||||
}`
|
||||
for _, test := range tests {
|
||||
test := test
|
||||
t.Run(test, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
os.Setenv("FOO", "2")
|
||||
defer os.Unsetenv("FOO")
|
||||
tmpfile, err := createTempFile(test, text)
|
||||
assert.Nil(t, err)
|
||||
defer os.Remove(tmpfile)
|
||||
@@ -31,10 +32,12 @@ func TestConfigJson(t *testing.T) {
|
||||
var val struct {
|
||||
A string `json:"a"`
|
||||
B int `json:"b"`
|
||||
C string `json:"c"`
|
||||
}
|
||||
MustLoad(tmpfile, &val)
|
||||
assert.Equal(t, "foo", val.A)
|
||||
assert.Equal(t, 1, val.B)
|
||||
assert.Equal(t, "2", val.C)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -3,6 +3,7 @@ package executors
|
||||
import (
|
||||
"reflect"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/tal-tech/go-zero/core/lang"
|
||||
@@ -35,6 +36,7 @@ type (
|
||||
// avoid race condition on waitGroup when calling wg.Add/Done/Wait(...)
|
||||
wgBarrier syncx.Barrier
|
||||
confirmChan chan lang.PlaceholderType
|
||||
inflight int32
|
||||
guarded bool
|
||||
newTicker func(duration time.Duration) timex.Ticker
|
||||
lock sync.Mutex
|
||||
@@ -91,18 +93,16 @@ func (pe *PeriodicalExecutor) Wait() {
|
||||
func (pe *PeriodicalExecutor) addAndCheck(task interface{}) (interface{}, bool) {
|
||||
pe.lock.Lock()
|
||||
defer func() {
|
||||
var start bool
|
||||
if !pe.guarded {
|
||||
pe.guarded = true
|
||||
start = true
|
||||
// defer to unlock quickly
|
||||
defer pe.backgroundFlush()
|
||||
}
|
||||
pe.lock.Unlock()
|
||||
if start {
|
||||
pe.backgroundFlush()
|
||||
}
|
||||
}()
|
||||
|
||||
if pe.container.AddTask(task) {
|
||||
atomic.AddInt32(&pe.inflight, 1)
|
||||
return pe.container.RemoveAll(), true
|
||||
}
|
||||
|
||||
@@ -111,6 +111,9 @@ func (pe *PeriodicalExecutor) addAndCheck(task interface{}) (interface{}, bool)
|
||||
|
||||
func (pe *PeriodicalExecutor) backgroundFlush() {
|
||||
threading.GoSafe(func() {
|
||||
// flush before quit goroutine to avoid missing tasks
|
||||
defer pe.Flush()
|
||||
|
||||
ticker := pe.newTicker(pe.interval)
|
||||
defer ticker.Stop()
|
||||
|
||||
@@ -120,6 +123,7 @@ func (pe *PeriodicalExecutor) backgroundFlush() {
|
||||
select {
|
||||
case vals := <-pe.commander:
|
||||
commanded = true
|
||||
atomic.AddInt32(&pe.inflight, -1)
|
||||
pe.enterExecution()
|
||||
pe.confirmChan <- lang.Placeholder
|
||||
pe.executeTasks(vals)
|
||||
@@ -129,13 +133,7 @@ func (pe *PeriodicalExecutor) backgroundFlush() {
|
||||
commanded = false
|
||||
} else if pe.Flush() {
|
||||
last = timex.Now()
|
||||
} else if timex.Since(last) > pe.interval*idleRound {
|
||||
pe.lock.Lock()
|
||||
pe.guarded = false
|
||||
pe.lock.Unlock()
|
||||
|
||||
// flush again to avoid missing tasks
|
||||
pe.Flush()
|
||||
} else if pe.shallQuit(last) {
|
||||
return
|
||||
}
|
||||
}
|
||||
@@ -178,3 +176,19 @@ func (pe *PeriodicalExecutor) hasTasks(tasks interface{}) bool {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
func (pe *PeriodicalExecutor) shallQuit(last time.Duration) (stop bool) {
|
||||
if timex.Since(last) <= pe.interval*idleRound {
|
||||
return
|
||||
}
|
||||
|
||||
// checking pe.inflight and setting pe.guarded should be locked together
|
||||
pe.lock.Lock()
|
||||
if atomic.LoadInt32(&pe.inflight) == 0 {
|
||||
pe.guarded = false
|
||||
stop = true
|
||||
}
|
||||
pe.lock.Unlock()
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
@@ -140,6 +140,26 @@ func TestPeriodicalExecutor_WaitFast(t *testing.T) {
|
||||
assert.Equal(t, total, cnt)
|
||||
}
|
||||
|
||||
func TestPeriodicalExecutor_Deadlock(t *testing.T) {
|
||||
executor := NewBulkExecutor(func(tasks []interface{}) {
|
||||
}, WithBulkTasks(1), WithBulkInterval(time.Millisecond))
|
||||
for i := 0; i < 1e5; i++ {
|
||||
executor.Add(1)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPeriodicalExecutor_hasTasks(t *testing.T) {
|
||||
ticker := timex.NewFakeTicker()
|
||||
defer ticker.Stop()
|
||||
|
||||
exec := NewPeriodicalExecutor(time.Millisecond, newContainer(time.Millisecond, nil))
|
||||
exec.newTicker = func(d time.Duration) timex.Ticker {
|
||||
return ticker
|
||||
}
|
||||
assert.False(t, exec.hasTasks(nil))
|
||||
assert.True(t, exec.hasTasks(1))
|
||||
}
|
||||
|
||||
// go test -benchtime 10s -bench .
|
||||
func BenchmarkExecutor(b *testing.B) {
|
||||
b.ReportAllocs()
|
||||
|
||||
@@ -21,6 +21,7 @@ var mock tracespec.Trace = new(mockTrace)
|
||||
|
||||
func TestTraceLog(t *testing.T) {
|
||||
var buf mockWriter
|
||||
atomic.StoreUint32(&initialized, 1)
|
||||
ctx := context.WithValue(context.Background(), tracespec.TracingKey, mock)
|
||||
WithContext(ctx).(*traceLogger).write(&buf, levelInfo, testlog)
|
||||
assert.True(t, strings.Contains(buf.String(), mockTraceId))
|
||||
|
||||
@@ -153,58 +153,57 @@ func doParseKeyAndOptions(field reflect.StructField, value string) (string, *fie
|
||||
key := strings.TrimSpace(segments[0])
|
||||
options := segments[1:]
|
||||
|
||||
if len(options) > 0 {
|
||||
var fieldOpts fieldOptions
|
||||
|
||||
for _, segment := range options {
|
||||
option := strings.TrimSpace(segment)
|
||||
switch {
|
||||
case option == stringOption:
|
||||
fieldOpts.FromString = true
|
||||
case strings.HasPrefix(option, optionalOption):
|
||||
segs := strings.Split(option, equalToken)
|
||||
switch len(segs) {
|
||||
case 1:
|
||||
fieldOpts.Optional = true
|
||||
case 2:
|
||||
fieldOpts.Optional = true
|
||||
fieldOpts.OptionalDep = segs[1]
|
||||
default:
|
||||
return "", nil, fmt.Errorf("field %s has wrong optional", field.Name)
|
||||
}
|
||||
case option == optionalOption:
|
||||
fieldOpts.Optional = true
|
||||
case strings.HasPrefix(option, optionsOption):
|
||||
segs := strings.Split(option, equalToken)
|
||||
if len(segs) != 2 {
|
||||
return "", nil, fmt.Errorf("field %s has wrong options", field.Name)
|
||||
} else {
|
||||
fieldOpts.Options = strings.Split(segs[1], optionSeparator)
|
||||
}
|
||||
case strings.HasPrefix(option, defaultOption):
|
||||
segs := strings.Split(option, equalToken)
|
||||
if len(segs) != 2 {
|
||||
return "", nil, fmt.Errorf("field %s has wrong default option", field.Name)
|
||||
} else {
|
||||
fieldOpts.Default = strings.TrimSpace(segs[1])
|
||||
}
|
||||
case strings.HasPrefix(option, rangeOption):
|
||||
segs := strings.Split(option, equalToken)
|
||||
if len(segs) != 2 {
|
||||
return "", nil, fmt.Errorf("field %s has wrong range", field.Name)
|
||||
}
|
||||
if nr, err := parseNumberRange(segs[1]); err != nil {
|
||||
return "", nil, err
|
||||
} else {
|
||||
fieldOpts.Range = nr
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return key, &fieldOpts, nil
|
||||
if len(options) == 0 {
|
||||
return key, nil, nil
|
||||
}
|
||||
|
||||
return key, nil, nil
|
||||
var fieldOpts fieldOptions
|
||||
for _, segment := range options {
|
||||
option := strings.TrimSpace(segment)
|
||||
switch {
|
||||
case option == stringOption:
|
||||
fieldOpts.FromString = true
|
||||
case strings.HasPrefix(option, optionalOption):
|
||||
segs := strings.Split(option, equalToken)
|
||||
switch len(segs) {
|
||||
case 1:
|
||||
fieldOpts.Optional = true
|
||||
case 2:
|
||||
fieldOpts.Optional = true
|
||||
fieldOpts.OptionalDep = segs[1]
|
||||
default:
|
||||
return "", nil, fmt.Errorf("field %s has wrong optional", field.Name)
|
||||
}
|
||||
case option == optionalOption:
|
||||
fieldOpts.Optional = true
|
||||
case strings.HasPrefix(option, optionsOption):
|
||||
segs := strings.Split(option, equalToken)
|
||||
if len(segs) != 2 {
|
||||
return "", nil, fmt.Errorf("field %s has wrong options", field.Name)
|
||||
} else {
|
||||
fieldOpts.Options = strings.Split(segs[1], optionSeparator)
|
||||
}
|
||||
case strings.HasPrefix(option, defaultOption):
|
||||
segs := strings.Split(option, equalToken)
|
||||
if len(segs) != 2 {
|
||||
return "", nil, fmt.Errorf("field %s has wrong default option", field.Name)
|
||||
} else {
|
||||
fieldOpts.Default = strings.TrimSpace(segs[1])
|
||||
}
|
||||
case strings.HasPrefix(option, rangeOption):
|
||||
segs := strings.Split(option, equalToken)
|
||||
if len(segs) != 2 {
|
||||
return "", nil, fmt.Errorf("field %s has wrong range", field.Name)
|
||||
}
|
||||
if nr, err := parseNumberRange(segs[1]); err != nil {
|
||||
return "", nil, err
|
||||
} else {
|
||||
fieldOpts.Range = nr
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return key, &fieldOpts, nil
|
||||
}
|
||||
|
||||
func implicitValueRequiredStruct(tag string, tp reflect.Type) (bool, error) {
|
||||
|
||||
16
core/prof/profilecenter_test.go
Normal file
16
core/prof/profilecenter_test.go
Normal file
@@ -0,0 +1,16 @@
|
||||
package prof
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestReport(t *testing.T) {
|
||||
once.Do(func() {})
|
||||
assert.NotContains(t, generateReport(), "foo")
|
||||
report("foo", time.Second)
|
||||
assert.Contains(t, generateReport(), "foo")
|
||||
report("foo", time.Second)
|
||||
}
|
||||
23
core/prof/profiler_test.go
Normal file
23
core/prof/profiler_test.go
Normal file
@@ -0,0 +1,23 @@
|
||||
package prof
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/tal-tech/go-zero/core/utils"
|
||||
)
|
||||
|
||||
func TestProfiler(t *testing.T) {
|
||||
EnableProfiling()
|
||||
Start()
|
||||
Report("foo", ProfilePoint{
|
||||
ElapsedTimer: utils.NewElapsedTimer(),
|
||||
})
|
||||
}
|
||||
|
||||
func TestNullProfiler(t *testing.T) {
|
||||
p := newNullProfiler()
|
||||
p.Start()
|
||||
p.Report("foo", ProfilePoint{
|
||||
ElapsedTimer: utils.NewElapsedTimer(),
|
||||
})
|
||||
}
|
||||
@@ -70,8 +70,6 @@ func (g *sharedGroup) createCall(key string) (c *call, done bool) {
|
||||
|
||||
func (g *sharedGroup) makeCall(c *call, key string, fn func() (interface{}, error)) {
|
||||
defer func() {
|
||||
// delete key first, done later. can't reverse the order, because if reverse,
|
||||
// another Do call might wg.Wait() without get notified with wg.Done()
|
||||
g.lock.Lock()
|
||||
delete(g.calls, key)
|
||||
g.lock.Unlock()
|
||||
|
||||
@@ -129,7 +129,7 @@ go get -u github.com/tal-tech/go-zero
|
||||
the .api files also can be generate by goctl, like below:
|
||||
|
||||
```shell
|
||||
goctl api -o greet.api
|
||||
goctl api -o greet.api
|
||||
```
|
||||
|
||||
3. generate the go server side code
|
||||
@@ -208,3 +208,7 @@ goctl api -o greet.api
|
||||
|
||||
* [Rapid development of microservice systems](https://github.com/tal-tech/zero-doc/blob/main/doc/shorturl-en.md)
|
||||
* [Rapid development of microservice systems - multiple RPCs](https://github.com/tal-tech/zero-doc/blob/main/doc/bookstore-en.md)
|
||||
|
||||
## 9. Chat group
|
||||
|
||||
Join the chat via https://discord.gg/4JQvC5A4Fe
|
||||
|
||||
15
readme.md
15
readme.md
@@ -5,8 +5,9 @@
|
||||
[English](readme-en.md) | 简体中文
|
||||
|
||||
[](https://github.com/tal-tech/go-zero/actions)
|
||||
[](https://codecov.io/gh/tal-tech/go-zero)
|
||||
[](https://goreportcard.com/report/github.com/tal-tech/go-zero)
|
||||
[](https://goproxy.cn/stats/github.com/tal-tech/go-zero/badges/download-count.svg)
|
||||
[](https://codecov.io/gh/tal-tech/go-zero)
|
||||
[](https://github.com/tal-tech/go-zero)
|
||||
[](https://opensource.org/licenses/MIT)
|
||||
|
||||
@@ -95,7 +96,7 @@ GO111MODULE=on GOPROXY=https://goproxy.cn/,direct go get -u github.com/tal-tech/
|
||||
|
||||
[快速构建高并发微服务](https://github.com/tal-tech/zero-doc/blob/main/doc/shorturl.md)
|
||||
|
||||
[快速构建高并发微服务 - 多 RPC 版](https://github.com/tal-tech/zero-doc/blob/main/docs/frame/bookstore.md)
|
||||
[快速构建高并发微服务 - 多 RPC 版](https://github.com/tal-tech/zero-doc/blob/main/docs/zero/bookstore.md)
|
||||
|
||||
1. 安装 goctl 工具
|
||||
|
||||
@@ -162,7 +163,7 @@ GO111MODULE=on GOPROXY=https://goproxy.cn/,direct go get -u github.com/tal-tech/
|
||||
|
||||
* awesome 系列
|
||||
* [快速构建高并发微服务](https://github.com/tal-tech/zero-doc/blob/main/doc/shorturl.md)
|
||||
* [快速构建高并发微服务 - 多 RPC 版](https://github.com/tal-tech/zero-doc/blob/main/docs/frame/bookstore.md)
|
||||
* [快速构建高并发微服务 - 多 RPC 版](https://github.com/tal-tech/zero-doc/blob/main/docs/zero/bookstore.md)
|
||||
* [goctl 使用帮助](https://github.com/tal-tech/zero-doc/blob/main/doc/goctl.md)
|
||||
* [通过 MapReduce 降低服务响应时间](https://github.com/tal-tech/zero-doc/blob/main/doc/mapreduce.md)
|
||||
* [关键字替换和敏感词过滤工具](https://github.com/tal-tech/zero-doc/blob/main/doc/keywords.md)
|
||||
@@ -172,7 +173,13 @@ GO111MODULE=on GOPROXY=https://goproxy.cn/,direct go get -u github.com/tal-tech/
|
||||
* [文本序列化和反序列化](https://github.com/tal-tech/zero-doc/blob/main/doc/mapping.md)
|
||||
* [快速构建 jwt 鉴权认证](https://github.com/tal-tech/zero-doc/blob/main/doc/jwt.md)
|
||||
|
||||
## 8. 微信交流群
|
||||
## 8. 微信公众号
|
||||
|
||||
`go-zero` 相关文章都会在 `微服务实践` 公众号整理呈现,欢迎扫码关注,也可以通过公众号私信我 👏
|
||||
|
||||
<img src="https://gitee.com/kevwan/static/raw/master/images/wechat-micro.jpg" alt="wechat" width="300" />
|
||||
|
||||
## 9. 微信交流群
|
||||
|
||||
如果文档中未能覆盖的任何疑问,欢迎您在群里提出,我们会尽快答复。
|
||||
|
||||
|
||||
171
rest/engine_test.go
Normal file
171
rest/engine_test.go
Normal file
@@ -0,0 +1,171 @@
|
||||
package rest
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net/http"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/tal-tech/go-zero/core/conf"
|
||||
)
|
||||
|
||||
func TestNewEngine(t *testing.T) {
|
||||
yamls := []string{
|
||||
`Name: foo
|
||||
Port: 54321
|
||||
`,
|
||||
`Name: foo
|
||||
Port: 54321
|
||||
CpuThreshold: 500
|
||||
`,
|
||||
`Name: foo
|
||||
Port: 54321
|
||||
CpuThreshold: 500
|
||||
Verbose: true
|
||||
`,
|
||||
}
|
||||
|
||||
routes := []featuredRoutes{
|
||||
{
|
||||
jwt: jwtSetting{},
|
||||
signature: signatureSetting{},
|
||||
routes: []Route{{
|
||||
Method: http.MethodGet,
|
||||
Path: "/",
|
||||
Handler: func(w http.ResponseWriter, r *http.Request) {},
|
||||
}},
|
||||
},
|
||||
{
|
||||
priority: true,
|
||||
jwt: jwtSetting{},
|
||||
signature: signatureSetting{},
|
||||
routes: []Route{{
|
||||
Method: http.MethodGet,
|
||||
Path: "/",
|
||||
Handler: func(w http.ResponseWriter, r *http.Request) {},
|
||||
}},
|
||||
},
|
||||
{
|
||||
priority: true,
|
||||
jwt: jwtSetting{
|
||||
enabled: true,
|
||||
},
|
||||
signature: signatureSetting{},
|
||||
routes: []Route{{
|
||||
Method: http.MethodGet,
|
||||
Path: "/",
|
||||
Handler: func(w http.ResponseWriter, r *http.Request) {},
|
||||
}},
|
||||
},
|
||||
{
|
||||
priority: true,
|
||||
jwt: jwtSetting{
|
||||
enabled: true,
|
||||
prevSecret: "thesecret",
|
||||
},
|
||||
signature: signatureSetting{},
|
||||
routes: []Route{{
|
||||
Method: http.MethodGet,
|
||||
Path: "/",
|
||||
Handler: func(w http.ResponseWriter, r *http.Request) {},
|
||||
}},
|
||||
},
|
||||
{
|
||||
priority: true,
|
||||
jwt: jwtSetting{
|
||||
enabled: true,
|
||||
},
|
||||
signature: signatureSetting{},
|
||||
routes: []Route{{
|
||||
Method: http.MethodGet,
|
||||
Path: "/",
|
||||
Handler: func(w http.ResponseWriter, r *http.Request) {},
|
||||
}},
|
||||
},
|
||||
{
|
||||
priority: true,
|
||||
jwt: jwtSetting{
|
||||
enabled: true,
|
||||
},
|
||||
signature: signatureSetting{
|
||||
enabled: true,
|
||||
},
|
||||
routes: []Route{{
|
||||
Method: http.MethodGet,
|
||||
Path: "/",
|
||||
Handler: func(w http.ResponseWriter, r *http.Request) {},
|
||||
}},
|
||||
},
|
||||
{
|
||||
priority: true,
|
||||
jwt: jwtSetting{
|
||||
enabled: true,
|
||||
},
|
||||
signature: signatureSetting{
|
||||
enabled: true,
|
||||
SignatureConf: SignatureConf{
|
||||
Strict: true,
|
||||
},
|
||||
},
|
||||
routes: []Route{{
|
||||
Method: http.MethodGet,
|
||||
Path: "/",
|
||||
Handler: func(w http.ResponseWriter, r *http.Request) {},
|
||||
}},
|
||||
},
|
||||
{
|
||||
priority: true,
|
||||
jwt: jwtSetting{
|
||||
enabled: true,
|
||||
},
|
||||
signature: signatureSetting{
|
||||
enabled: true,
|
||||
SignatureConf: SignatureConf{
|
||||
Strict: true,
|
||||
PrivateKeys: []PrivateKeyConf{
|
||||
{
|
||||
Fingerprint: "a",
|
||||
KeyFile: "b",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
routes: []Route{{
|
||||
Method: http.MethodGet,
|
||||
Path: "/",
|
||||
Handler: func(w http.ResponseWriter, r *http.Request) {},
|
||||
}},
|
||||
},
|
||||
}
|
||||
|
||||
for _, yaml := range yamls {
|
||||
for _, route := range routes {
|
||||
var cnf RestConf
|
||||
assert.Nil(t, conf.LoadConfigFromYamlBytes([]byte(yaml), &cnf))
|
||||
ng := newEngine(cnf)
|
||||
ng.AddRoutes(route)
|
||||
ng.use(func(next http.HandlerFunc) http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
next.ServeHTTP(w, r)
|
||||
}
|
||||
})
|
||||
assert.NotNil(t, ng.StartWithRouter(mockedRouter{}))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
type mockedRouter struct {
|
||||
}
|
||||
|
||||
func (m mockedRouter) ServeHTTP(writer http.ResponseWriter, request *http.Request) {
|
||||
}
|
||||
|
||||
func (m mockedRouter) Handle(method string, path string, handler http.Handler) error {
|
||||
return errors.New("foo")
|
||||
}
|
||||
|
||||
func (m mockedRouter) SetNotFoundHandler(handler http.Handler) {
|
||||
}
|
||||
|
||||
func (m mockedRouter) SetNotAllowedHandler(handler http.Handler) {
|
||||
}
|
||||
@@ -46,18 +46,18 @@ func Authorize(secret string, opts ...AuthorizeOption) func(http.Handler) http.H
|
||||
parser := token.NewTokenParser()
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
token, err := parser.ParseToken(r, secret, authOpts.PrevSecret)
|
||||
tok, err := parser.ParseToken(r, secret, authOpts.PrevSecret)
|
||||
if err != nil {
|
||||
unauthorized(w, r, err, authOpts.Callback)
|
||||
return
|
||||
}
|
||||
|
||||
if !token.Valid {
|
||||
if !tok.Valid {
|
||||
unauthorized(w, r, errInvalidToken, authOpts.Callback)
|
||||
return
|
||||
}
|
||||
|
||||
claims, ok := token.Claims.(jwt.MapClaims)
|
||||
claims, ok := tok.Claims.(jwt.MapClaims)
|
||||
if !ok {
|
||||
unauthorized(w, r, errNoClaims, authOpts.Callback)
|
||||
return
|
||||
@@ -122,6 +122,12 @@ func newGuardedResponseWriter(w http.ResponseWriter) *guardedResponseWriter {
|
||||
}
|
||||
}
|
||||
|
||||
func (grw *guardedResponseWriter) Flush() {
|
||||
if flusher, ok := grw.writer.(http.Flusher); ok {
|
||||
flusher.Flush()
|
||||
}
|
||||
}
|
||||
|
||||
func (grw *guardedResponseWriter) Header() http.Header {
|
||||
return grw.writer.Header()
|
||||
}
|
||||
|
||||
@@ -41,6 +41,10 @@ func TestAuthHandler(t *testing.T) {
|
||||
w.Header().Set("X-Test", "test")
|
||||
_, err := w.Write([]byte("content"))
|
||||
assert.Nil(t, err)
|
||||
|
||||
flusher, ok := w.(http.Flusher)
|
||||
assert.True(t, ok)
|
||||
flusher.Flush()
|
||||
}))
|
||||
|
||||
resp := httptest.NewRecorder()
|
||||
|
||||
@@ -83,6 +83,12 @@ func newCryptionResponseWriter(w http.ResponseWriter) *cryptionResponseWriter {
|
||||
}
|
||||
}
|
||||
|
||||
func (w *cryptionResponseWriter) Flush() {
|
||||
if flusher, ok := w.ResponseWriter.(http.Flusher); ok {
|
||||
flusher.Flush()
|
||||
}
|
||||
}
|
||||
|
||||
func (w *cryptionResponseWriter) Header() http.Header {
|
||||
return w.ResponseWriter.Header()
|
||||
}
|
||||
|
||||
@@ -87,3 +87,19 @@ func TestCryptionHandlerWriteHeader(t *testing.T) {
|
||||
handler.ServeHTTP(recorder, req)
|
||||
assert.Equal(t, http.StatusServiceUnavailable, recorder.Code)
|
||||
}
|
||||
|
||||
func TestCryptionHandlerFlush(t *testing.T) {
|
||||
req := httptest.NewRequest(http.MethodGet, "/any", nil)
|
||||
handler := CryptionHandler(aesKey)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Write([]byte(respText))
|
||||
flusher, ok := w.(http.Flusher)
|
||||
assert.True(t, ok)
|
||||
flusher.Flush()
|
||||
}))
|
||||
recorder := httptest.NewRecorder()
|
||||
handler.ServeHTTP(recorder, req)
|
||||
|
||||
expect, err := codec.EcbEncrypt(aesKey, []byte(respText))
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, base64.StdEncoding.EncodeToString(expect), recorder.Body.String())
|
||||
}
|
||||
|
||||
@@ -38,6 +38,12 @@ func (w *LoggedResponseWriter) WriteHeader(code int) {
|
||||
w.code = code
|
||||
}
|
||||
|
||||
func (w *LoggedResponseWriter) Flush() {
|
||||
if flusher, ok := w.w.(http.Flusher); ok {
|
||||
flusher.Flush()
|
||||
}
|
||||
}
|
||||
|
||||
func LogHandler(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
timer := utils.NewElapsedTimer()
|
||||
@@ -68,6 +74,10 @@ func newDetailLoggedResponseWriter(writer *LoggedResponseWriter, buf *bytes.Buff
|
||||
}
|
||||
}
|
||||
|
||||
func (w *DetailLoggedResponseWriter) Flush() {
|
||||
w.writer.Flush()
|
||||
}
|
||||
|
||||
func (w *DetailLoggedResponseWriter) Header() http.Header {
|
||||
return w.writer.Header()
|
||||
}
|
||||
|
||||
@@ -30,6 +30,10 @@ func TestLogHandler(t *testing.T) {
|
||||
w.WriteHeader(http.StatusServiceUnavailable)
|
||||
_, err := w.Write([]byte("content"))
|
||||
assert.Nil(t, err)
|
||||
|
||||
flusher, ok := w.(http.Flusher)
|
||||
assert.True(t, ok)
|
||||
flusher.Flush()
|
||||
}))
|
||||
|
||||
resp := httptest.NewRecorder()
|
||||
|
||||
@@ -7,6 +7,12 @@ type WithCodeResponseWriter struct {
|
||||
Code int
|
||||
}
|
||||
|
||||
func (w *WithCodeResponseWriter) Flush() {
|
||||
if flusher, ok := w.Writer.(http.Flusher); ok {
|
||||
flusher.Flush()
|
||||
}
|
||||
}
|
||||
|
||||
func (w *WithCodeResponseWriter) Header() http.Header {
|
||||
return w.Writer.Header()
|
||||
}
|
||||
|
||||
33
rest/internal/security/withcoderesponsewriter_test.go
Normal file
33
rest/internal/security/withcoderesponsewriter_test.go
Normal file
@@ -0,0 +1,33 @@
|
||||
package security
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestWithCodeResponseWriter(t *testing.T) {
|
||||
req := httptest.NewRequest(http.MethodGet, "http://localhost", nil)
|
||||
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
cw := &WithCodeResponseWriter{Writer: w}
|
||||
|
||||
cw.Header().Set("X-Test", "test")
|
||||
cw.WriteHeader(http.StatusServiceUnavailable)
|
||||
assert.Equal(t, cw.Code, http.StatusServiceUnavailable)
|
||||
|
||||
_, err := cw.Write([]byte("content"))
|
||||
assert.Nil(t, err)
|
||||
|
||||
flusher, ok := http.ResponseWriter(cw).(http.Flusher)
|
||||
assert.True(t, ok)
|
||||
flusher.Flush()
|
||||
})
|
||||
|
||||
resp := httptest.NewRecorder()
|
||||
handler.ServeHTTP(resp, req)
|
||||
assert.Equal(t, http.StatusServiceUnavailable, resp.Code)
|
||||
assert.Equal(t, "test", resp.Header().Get("X-Test"))
|
||||
assert.Equal(t, "content", resp.Body.String())
|
||||
}
|
||||
@@ -64,7 +64,7 @@ func (pr *patRouter) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
}
|
||||
|
||||
allow, ok := pr.methodNotAllowed(r.Method, reqPath)
|
||||
allows, ok := pr.methodsAllowed(r.Method, reqPath)
|
||||
if !ok {
|
||||
pr.handleNotFound(w, r)
|
||||
return
|
||||
@@ -73,7 +73,7 @@ func (pr *patRouter) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
if pr.notAllowed != nil {
|
||||
pr.notAllowed.ServeHTTP(w, r)
|
||||
} else {
|
||||
w.Header().Set(allowHeader, allow)
|
||||
w.Header().Set(allowHeader, allows)
|
||||
w.WriteHeader(http.StatusMethodNotAllowed)
|
||||
}
|
||||
}
|
||||
@@ -94,7 +94,7 @@ func (pr *patRouter) handleNotFound(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
}
|
||||
|
||||
func (pr *patRouter) methodNotAllowed(method, path string) (string, bool) {
|
||||
func (pr *patRouter) methodsAllowed(method, path string) (string, bool) {
|
||||
var allows []string
|
||||
|
||||
for treeMethod, tree := range pr.trees {
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
package rest
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"log"
|
||||
"net/http"
|
||||
|
||||
@@ -24,6 +23,9 @@ type (
|
||||
}
|
||||
)
|
||||
|
||||
// MustNewServer returns a server with given config of c and options defined in opts.
|
||||
// Be aware that later RunOption might overwrite previous one that write the same option.
|
||||
// The process will exit if error occurs.
|
||||
func MustNewServer(c RestConf, opts ...RunOption) *Server {
|
||||
engine, err := NewServer(c, opts...)
|
||||
if err != nil {
|
||||
@@ -33,11 +35,9 @@ func MustNewServer(c RestConf, opts ...RunOption) *Server {
|
||||
return engine
|
||||
}
|
||||
|
||||
// NewServer returns a server with given config of c and options defined in opts.
|
||||
// Be aware that later RunOption might overwrite previous one that write the same option.
|
||||
func NewServer(c RestConf, opts ...RunOption) (*Server, error) {
|
||||
if len(opts) > 1 {
|
||||
return nil, errors.New("only one RunOption is allowed")
|
||||
}
|
||||
|
||||
if err := c.SetUp(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -8,18 +8,84 @@ import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/tal-tech/go-zero/core/conf"
|
||||
"github.com/tal-tech/go-zero/rest/httpx"
|
||||
"github.com/tal-tech/go-zero/rest/router"
|
||||
)
|
||||
|
||||
func TestNewServer(t *testing.T) {
|
||||
_, err := NewServer(RestConf{}, WithNotFoundHandler(nil), WithNotAllowedHandler(nil))
|
||||
assert.NotNil(t, err)
|
||||
const configYaml = `
|
||||
Name: foo
|
||||
Port: 54321
|
||||
`
|
||||
var cnf RestConf
|
||||
assert.Nil(t, conf.LoadConfigFromYamlBytes([]byte(configYaml), &cnf))
|
||||
failStart := func(server *Server) {
|
||||
server.opts.start = func(e *engine) error {
|
||||
return http.ErrServerClosed
|
||||
}
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
c RestConf
|
||||
opts []RunOption
|
||||
fail bool
|
||||
}{
|
||||
{
|
||||
c: RestConf{},
|
||||
opts: []RunOption{failStart},
|
||||
fail: true,
|
||||
},
|
||||
{
|
||||
c: cnf,
|
||||
opts: []RunOption{failStart},
|
||||
},
|
||||
{
|
||||
c: cnf,
|
||||
opts: []RunOption{WithNotAllowedHandler(nil), failStart},
|
||||
},
|
||||
{
|
||||
c: cnf,
|
||||
opts: []RunOption{WithNotFoundHandler(nil), failStart},
|
||||
},
|
||||
{
|
||||
c: cnf,
|
||||
opts: []RunOption{WithUnauthorizedCallback(nil), failStart},
|
||||
},
|
||||
{
|
||||
c: cnf,
|
||||
opts: []RunOption{WithUnsignedCallback(nil), failStart},
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
srv, err := NewServer(test.c, test.opts...)
|
||||
if test.fail {
|
||||
assert.NotNil(t, err)
|
||||
}
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
srv.Use(ToMiddleware(func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}))
|
||||
srv.AddRoute(Route{
|
||||
Method: http.MethodGet,
|
||||
Path: "/",
|
||||
Handler: nil,
|
||||
}, WithJwt("thesecret"), WithSignature(SignatureConf{}),
|
||||
WithJwtTransition("preivous", "thenewone"))
|
||||
srv.Start()
|
||||
srv.Stop()
|
||||
}
|
||||
}
|
||||
|
||||
func TestWithMiddleware(t *testing.T) {
|
||||
m := make(map[string]string)
|
||||
router := router.NewRouter()
|
||||
rt := router.NewRouter()
|
||||
handler := func(w http.ResponseWriter, r *http.Request) {
|
||||
var v struct {
|
||||
Nickname string `form:"nickname"`
|
||||
@@ -56,14 +122,14 @@ func TestWithMiddleware(t *testing.T) {
|
||||
"http://hello.com/second/wan/2020?nickname=whatever&zipcode=200000",
|
||||
}
|
||||
for _, route := range rs {
|
||||
assert.Nil(t, router.Handle(route.Method, route.Path, route.Handler))
|
||||
assert.Nil(t, rt.Handle(route.Method, route.Path, route.Handler))
|
||||
}
|
||||
for _, url := range urls {
|
||||
r, err := http.NewRequest(http.MethodGet, url, nil)
|
||||
assert.Nil(t, err)
|
||||
|
||||
rr := httptest.NewRecorder()
|
||||
router.ServeHTTP(rr, r)
|
||||
rt.ServeHTTP(rr, r)
|
||||
|
||||
assert.Equal(t, "whatever:200000", rr.Body.String())
|
||||
}
|
||||
@@ -76,7 +142,7 @@ func TestWithMiddleware(t *testing.T) {
|
||||
|
||||
func TestMultiMiddlewares(t *testing.T) {
|
||||
m := make(map[string]string)
|
||||
router := router.NewRouter()
|
||||
rt := router.NewRouter()
|
||||
handler := func(w http.ResponseWriter, r *http.Request) {
|
||||
var v struct {
|
||||
Nickname string `form:"nickname"`
|
||||
@@ -127,14 +193,14 @@ func TestMultiMiddlewares(t *testing.T) {
|
||||
"http://hello.com/second/wan/2020?nickname=whatever&zipcode=200000",
|
||||
}
|
||||
for _, route := range rs {
|
||||
assert.Nil(t, router.Handle(route.Method, route.Path, route.Handler))
|
||||
assert.Nil(t, rt.Handle(route.Method, route.Path, route.Handler))
|
||||
}
|
||||
for _, url := range urls {
|
||||
r, err := http.NewRequest(http.MethodGet, url, nil)
|
||||
assert.Nil(t, err)
|
||||
|
||||
rr := httptest.NewRecorder()
|
||||
router.ServeHTTP(rr, r)
|
||||
rt.ServeHTTP(rr, r)
|
||||
|
||||
assert.Equal(t, "whatever:200000200000", rr.Body.String())
|
||||
}
|
||||
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"bufio"
|
||||
"errors"
|
||||
"fmt"
|
||||
"go/format"
|
||||
"go/scanner"
|
||||
"io/ioutil"
|
||||
"os"
|
||||
@@ -13,6 +14,7 @@ import (
|
||||
"github.com/tal-tech/go-zero/core/errorx"
|
||||
"github.com/tal-tech/go-zero/tools/goctl/api/parser"
|
||||
"github.com/tal-tech/go-zero/tools/goctl/api/util"
|
||||
ctlutil "github.com/tal-tech/go-zero/tools/goctl/util"
|
||||
"github.com/urfave/cli"
|
||||
)
|
||||
|
||||
@@ -103,24 +105,108 @@ func apiFormat(data string) (string, error) {
|
||||
var builder strings.Builder
|
||||
s := bufio.NewScanner(strings.NewReader(data))
|
||||
var tapCount = 0
|
||||
var newLineCount = 0
|
||||
var preLine string
|
||||
for s.Scan() {
|
||||
line := strings.TrimSpace(s.Text())
|
||||
if len(line) == 0 {
|
||||
if newLineCount > 0 {
|
||||
continue
|
||||
}
|
||||
newLineCount++
|
||||
} else {
|
||||
if preLine == rightBrace {
|
||||
builder.WriteString(ctlutil.NL)
|
||||
}
|
||||
newLineCount = 0
|
||||
}
|
||||
|
||||
if tapCount == 0 {
|
||||
format, err := formatGoTypeDef(line, s, &builder)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
if format {
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
noCommentLine := util.RemoveComment(line)
|
||||
if noCommentLine == rightParenthesis || noCommentLine == rightBrace {
|
||||
tapCount -= 1
|
||||
}
|
||||
if tapCount < 0 {
|
||||
line = strings.TrimSuffix(line, rightBrace)
|
||||
line := strings.TrimSuffix(noCommentLine, rightBrace)
|
||||
line = strings.TrimSpace(line)
|
||||
if strings.HasSuffix(line, leftBrace) {
|
||||
tapCount += 1
|
||||
}
|
||||
}
|
||||
util.WriteIndent(&builder, tapCount)
|
||||
builder.WriteString(line + "\n")
|
||||
builder.WriteString(line + ctlutil.NL)
|
||||
if strings.HasSuffix(noCommentLine, leftParenthesis) || strings.HasSuffix(noCommentLine, leftBrace) {
|
||||
tapCount += 1
|
||||
}
|
||||
preLine = line
|
||||
}
|
||||
return strings.TrimSpace(builder.String()), nil
|
||||
}
|
||||
|
||||
func formatGoTypeDef(line string, scanner *bufio.Scanner, builder *strings.Builder) (bool, error) {
|
||||
noCommentLine := util.RemoveComment(line)
|
||||
tokenCount := 0
|
||||
if strings.HasPrefix(noCommentLine, "type") && (strings.HasSuffix(noCommentLine, leftParenthesis) ||
|
||||
strings.HasSuffix(noCommentLine, leftBrace)) {
|
||||
var typeBuilder strings.Builder
|
||||
typeBuilder.WriteString(mayInsertStructKeyword(line, &tokenCount) + ctlutil.NL)
|
||||
for scanner.Scan() {
|
||||
noCommentLine := util.RemoveComment(scanner.Text())
|
||||
typeBuilder.WriteString(mayInsertStructKeyword(scanner.Text(), &tokenCount) + ctlutil.NL)
|
||||
if noCommentLine == rightBrace || noCommentLine == rightParenthesis {
|
||||
tokenCount--
|
||||
}
|
||||
if tokenCount == 0 {
|
||||
ts, err := format.Source([]byte(typeBuilder.String()))
|
||||
if err != nil {
|
||||
return false, errors.New("error format \n" + typeBuilder.String())
|
||||
}
|
||||
|
||||
result := strings.ReplaceAll(string(ts), " struct ", " ")
|
||||
result = strings.ReplaceAll(result, "type ()", "")
|
||||
builder.WriteString(result)
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
return true, nil
|
||||
}
|
||||
return false, nil
|
||||
}
|
||||
|
||||
func mayInsertStructKeyword(line string, token *int) string {
|
||||
insertStruct := func() string {
|
||||
if strings.Contains(line, " struct") {
|
||||
return line
|
||||
}
|
||||
index := strings.Index(line, leftBrace)
|
||||
return line[:index] + " struct " + line[index:]
|
||||
}
|
||||
|
||||
noCommentLine := util.RemoveComment(line)
|
||||
if strings.HasSuffix(noCommentLine, leftBrace) {
|
||||
*token++
|
||||
return insertStruct()
|
||||
}
|
||||
if strings.HasSuffix(noCommentLine, rightBrace) {
|
||||
noCommentLine = strings.TrimSuffix(noCommentLine, rightBrace)
|
||||
noCommentLine = util.RemoveComment(noCommentLine)
|
||||
if strings.HasSuffix(noCommentLine, leftBrace) {
|
||||
return insertStruct()
|
||||
}
|
||||
}
|
||||
if strings.HasSuffix(noCommentLine, leftParenthesis) {
|
||||
*token++
|
||||
}
|
||||
return line
|
||||
}
|
||||
|
||||
@@ -24,11 +24,11 @@ handler: GreetHandler
|
||||
}
|
||||
`
|
||||
|
||||
formattedStr = `type Request struct {
|
||||
formattedStr = `type Request {
|
||||
Name string
|
||||
}
|
||||
|
||||
type Response struct {
|
||||
type Response {
|
||||
Message string
|
||||
}
|
||||
|
||||
@@ -40,7 +40,7 @@ service A-api {
|
||||
}`
|
||||
)
|
||||
|
||||
func TestInlineTypeNotExist(t *testing.T) {
|
||||
func TestFormat(t *testing.T) {
|
||||
r, err := apiFormat(notFormattedStr)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, r, formattedStr)
|
||||
|
||||
@@ -38,11 +38,12 @@ func RevertTemplate(name string) error {
|
||||
return util.CreateTemplate(category, name, content)
|
||||
}
|
||||
|
||||
func Update(category string) error {
|
||||
func Update() error {
|
||||
err := Clean()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return util.InitTemplates(category, templates)
|
||||
}
|
||||
|
||||
@@ -50,6 +51,6 @@ func Clean() error {
|
||||
return util.Clean(category)
|
||||
}
|
||||
|
||||
func GetCategory() string {
|
||||
func Category() string {
|
||||
return category
|
||||
}
|
||||
|
||||
@@ -84,7 +84,7 @@ func TestUpdate(t *testing.T) {
|
||||
|
||||
assert.Equal(t, string(data), modifyData)
|
||||
|
||||
assert.Nil(t, Update(category))
|
||||
assert.Nil(t, Update())
|
||||
|
||||
data, err = ioutil.ReadFile(file)
|
||||
assert.Nil(t, err)
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package javagen
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"path"
|
||||
@@ -17,6 +18,8 @@ const (
|
||||
package com.xhb.logic.http.packet.{{.packet}}.model;
|
||||
|
||||
import com.xhb.logic.http.DeProguardable;
|
||||
import org.jetbrains.annotations.NotNull;
|
||||
import org.jetbrains.annotations.Nullable;
|
||||
|
||||
{{.componentType}}
|
||||
`
|
||||
@@ -28,7 +31,7 @@ func genComponents(dir, packetName string, api *spec.ApiSpec) error {
|
||||
return nil
|
||||
}
|
||||
for _, ty := range types {
|
||||
if err := createComponent(dir, packetName, ty); err != nil {
|
||||
if err := createComponent(dir, packetName, ty, api.Types); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
@@ -36,7 +39,7 @@ func genComponents(dir, packetName string, api *spec.ApiSpec) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func createComponent(dir, packetName string, ty spec.Type) error {
|
||||
func createComponent(dir, packetName string, ty spec.Type, types []spec.Type) error {
|
||||
modelFile := util.Title(ty.Name) + ".java"
|
||||
filename := path.Join(dir, modelDir, modelFile)
|
||||
if err := util.RemoveOrQuit(filename); err != nil {
|
||||
@@ -52,7 +55,7 @@ func createComponent(dir, packetName string, ty spec.Type) error {
|
||||
}
|
||||
defer fp.Close()
|
||||
|
||||
tys, err := buildType(ty)
|
||||
tys, err := buildType(ty, types)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -64,22 +67,66 @@ func createComponent(dir, packetName string, ty spec.Type) error {
|
||||
})
|
||||
}
|
||||
|
||||
func buildType(ty spec.Type) (string, error) {
|
||||
func buildType(ty spec.Type, types []spec.Type) (string, error) {
|
||||
var builder strings.Builder
|
||||
if err := writeType(&builder, ty); err != nil {
|
||||
if err := writeType(&builder, ty, types); err != nil {
|
||||
return "", apiutil.WrapErr(err, "Type "+ty.Name+" generate error")
|
||||
}
|
||||
return builder.String(), nil
|
||||
}
|
||||
|
||||
func writeType(writer io.Writer, tp spec.Type) error {
|
||||
func writeType(writer io.Writer, tp spec.Type, types []spec.Type) error {
|
||||
fmt.Fprintf(writer, "public class %s implements DeProguardable {\n", util.Title(tp.Name))
|
||||
for _, member := range tp.Members {
|
||||
if err := writeProperty(writer, member, 1); err != nil {
|
||||
return err
|
||||
}
|
||||
var members []spec.Member
|
||||
err := writeMembers(writer, types, tp.Members, &members, 1)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
genGetSet(writer, members, 1)
|
||||
fmt.Fprintf(writer, "}")
|
||||
return nil
|
||||
}
|
||||
|
||||
func writeMembers(writer io.Writer, types []spec.Type, members []spec.Member, allMembers *[]spec.Member, indent int) error {
|
||||
for _, member := range members {
|
||||
if !member.IsInline {
|
||||
_, err := member.GetPropertyName()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
if !member.IsBodyMember() {
|
||||
continue
|
||||
}
|
||||
|
||||
for _, item := range *allMembers {
|
||||
if item.Name == member.Name {
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
if member.IsInline {
|
||||
hasInline := false
|
||||
for _, ty := range types {
|
||||
if strings.ToLower(ty.Name) == strings.ToLower(member.Name) {
|
||||
err := writeMembers(writer, types, ty.Members, allMembers, indent)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
hasInline = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !hasInline {
|
||||
return errors.New("inline type " + member.Name + " not exist, please correct api file")
|
||||
}
|
||||
} else {
|
||||
if err := writeProperty(writer, member, indent); err != nil {
|
||||
return err
|
||||
}
|
||||
*allMembers = append(*allMembers, member)
|
||||
}
|
||||
}
|
||||
genGetSet(writer, tp, 1)
|
||||
fmt.Fprintf(writer, "}\n")
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -19,23 +19,27 @@ const packetTemplate = `package com.xhb.logic.http.packet.{{.packet}};
|
||||
|
||||
import com.google.gson.Gson;
|
||||
import com.xhb.commons.JSON;
|
||||
import com.xhb.commons.JsonParser;
|
||||
import com.xhb.commons.JsonMarshal;
|
||||
import com.xhb.core.network.HttpRequestClient;
|
||||
import com.xhb.core.packet.HttpRequestPacket;
|
||||
import com.xhb.core.response.HttpResponseData;
|
||||
import com.xhb.logic.http.DeProguardable;
|
||||
{{if not .HasRequestBody}}
|
||||
import com.xhb.logic.http.request.EmptyRequest;
|
||||
{{end}}
|
||||
{{.import}}
|
||||
|
||||
import org.jetbrains.annotations.NotNull;
|
||||
import org.jetbrains.annotations.Nullable;
|
||||
import org.json.JSONObject;
|
||||
|
||||
public class {{.packetName}} extends HttpRequestPacket<{{.packetName}}.{{.packetName}}Response> {
|
||||
|
||||
{{.paramsDeclaration}}
|
||||
|
||||
public {{.packetName}}({{.params}}{{.requestType}} request) {
|
||||
super(request);
|
||||
this.request = request;{{.paramsSet}}
|
||||
public {{.packetName}}({{.params}}{{if .HasRequestBody}}, {{.requestType}} request{{end}}) {
|
||||
{{if .HasRequestBody}}super(request);{{else}}super(EmptyRequest.instance);{{end}}
|
||||
{{if .HasRequestBody}}this.request = request;{{end}}{{.paramsSet}}
|
||||
}
|
||||
|
||||
@Override
|
||||
@@ -113,7 +117,8 @@ func createWith(dir string, api *spec.ApiSpec, route spec.Route, packetName stri
|
||||
} else {
|
||||
fmt.Fprintln(&builder)
|
||||
}
|
||||
if err := genType(&builder, tp); err != nil {
|
||||
|
||||
if err := genType(&builder, tp, api.Types); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
@@ -126,7 +131,7 @@ func createWith(dir string, api *spec.ApiSpec, route spec.Route, packetName stri
|
||||
|
||||
t := template.Must(template.New("packetTemplate").Parse(packetTemplate))
|
||||
var tmplBytes bytes.Buffer
|
||||
err = t.Execute(&tmplBytes, map[string]string{
|
||||
err = t.Execute(&tmplBytes, map[string]interface{}{
|
||||
"packetName": packet,
|
||||
"method": strings.ToUpper(route.Method),
|
||||
"uri": processUri(route),
|
||||
@@ -137,6 +142,7 @@ func createWith(dir string, api *spec.ApiSpec, route spec.Route, packetName stri
|
||||
"paramsSet": paramsSet,
|
||||
"packet": packetName,
|
||||
"requestType": util.Title(route.RequestType.Name),
|
||||
"HasRequestBody": len(route.RequestType.GetBodyMembers()) > 0,
|
||||
"import": getImports(api, route, packetName),
|
||||
})
|
||||
if err != nil {
|
||||
@@ -209,7 +215,7 @@ func paramsForRoute(route spec.Route) string {
|
||||
builder.WriteString(fmt.Sprintf("String %s, ", cop[1:]))
|
||||
}
|
||||
}
|
||||
return builder.String()
|
||||
return strings.TrimSuffix(builder.String(), ", ")
|
||||
}
|
||||
|
||||
func declarationForRoute(route spec.Route) string {
|
||||
@@ -260,18 +266,22 @@ func processUri(route spec.Route) string {
|
||||
return result
|
||||
}
|
||||
|
||||
func genType(writer io.Writer, tp spec.Type) error {
|
||||
writeIndent(writer, 1)
|
||||
fmt.Fprintf(writer, "static class %s implements DeProguardable {\n", util.Title(tp.Name))
|
||||
for _, member := range tp.Members {
|
||||
if err := writeProperty(writer, member, 2); err != nil {
|
||||
return err
|
||||
}
|
||||
func genType(writer io.Writer, tp spec.Type, types []spec.Type) error {
|
||||
if len(tp.GetBodyMembers()) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
writeBreakline(writer)
|
||||
writeIndent(writer, 1)
|
||||
genGetSet(writer, tp, 2)
|
||||
fmt.Fprintf(writer, "static class %s implements DeProguardable {\n", util.Title(tp.Name))
|
||||
var members []spec.Member
|
||||
err := writeMembers(writer, types, tp.Members, &members, 2)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
writeNewline(writer)
|
||||
writeIndent(writer, 1)
|
||||
genGetSet(writer, members, 2)
|
||||
writeIndent(writer, 1)
|
||||
fmt.Fprintln(writer, "}")
|
||||
|
||||
|
||||
@@ -67,8 +67,8 @@ func indentString(indent int) string {
|
||||
return result
|
||||
}
|
||||
|
||||
func writeBreakline(writer io.Writer) {
|
||||
fmt.Fprint(writer, "\n")
|
||||
func writeNewline(writer io.Writer) {
|
||||
fmt.Fprint(writer, util.NL)
|
||||
}
|
||||
|
||||
func isPrimitiveType(tp string) bool {
|
||||
@@ -87,6 +87,7 @@ func goTypeToJava(tp string) (string, error) {
|
||||
if len(tp) == 0 {
|
||||
return "", errors.New("property type empty")
|
||||
}
|
||||
|
||||
if strings.HasPrefix(tp, "*") {
|
||||
tp = tp[1:]
|
||||
}
|
||||
@@ -107,39 +108,44 @@ func goTypeToJava(tp string) (string, error) {
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
if len(tys) == 0 {
|
||||
return "", fmt.Errorf("%s tp parse error", tp)
|
||||
}
|
||||
|
||||
return fmt.Sprintf("java.util.ArrayList<%s>", util.Title(tys[0])), nil
|
||||
} else if strings.HasPrefix(tp, "map") {
|
||||
tys, err := apiutil.DecomposeType(tp)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
if len(tys) == 2 {
|
||||
return "", fmt.Errorf("%s tp parse error", tp)
|
||||
}
|
||||
|
||||
return fmt.Sprintf("java.util.HashMap<String, %s>", util.Title(tys[1])), nil
|
||||
}
|
||||
return util.Title(tp), nil
|
||||
}
|
||||
|
||||
func genGetSet(writer io.Writer, tp spec.Type, indent int) error {
|
||||
func genGetSet(writer io.Writer, members []spec.Member, indent int) error {
|
||||
t := template.Must(template.New("getSetTemplate").Parse(getSetTemplate))
|
||||
for _, member := range tp.Members {
|
||||
for _, member := range members {
|
||||
var tmplBytes bytes.Buffer
|
||||
|
||||
oty, err := goTypeToJava(member.Type)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
tyString := oty
|
||||
decorator := ""
|
||||
if !isPrimitiveType(member.Type) {
|
||||
if member.IsOptional() {
|
||||
decorator = "@org.jetbrains.annotations.Nullable "
|
||||
decorator = "@Nullable "
|
||||
} else {
|
||||
decorator = "@org.jetbrains.annotations.NotNull "
|
||||
decorator = "@NotNull "
|
||||
}
|
||||
tyString = decorator + tyString
|
||||
}
|
||||
@@ -155,6 +161,7 @@ func genGetSet(writer io.Writer, tp spec.Type, indent int) error {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
r := tmplBytes.String()
|
||||
r = strings.Replace(r, " boolean get", " boolean is", 1)
|
||||
writer.Write([]byte(r))
|
||||
|
||||
@@ -63,10 +63,6 @@ func (m Member) IsOmitempty() bool {
|
||||
|
||||
func (m Member) GetPropertyName() (string, error) {
|
||||
tags := m.Tags()
|
||||
if len(tags) == 0 {
|
||||
return "", errors.New("json property name not exist, member: " + m.Name)
|
||||
}
|
||||
|
||||
for _, tag := range tags {
|
||||
if stringx.Contains(definedKeys, tag.Key) {
|
||||
if tag.Name == "-" {
|
||||
|
||||
@@ -85,7 +85,7 @@ func genHandler(dir, webApi, caller string, api *spec.ApiSpec, unwrapApi bool) e
|
||||
imports += fmt.Sprintf(`import * as components from "%s"`, "./"+outputFile)
|
||||
}
|
||||
|
||||
apis, err := genApi(api, localTypes, caller, prefixForType)
|
||||
apis, err := genApi(api, caller, prefixForType)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -119,32 +119,34 @@ func genTypes(localTypes []spec.Type, inlineType func(string) (*spec.Type, error
|
||||
return types, nil
|
||||
}
|
||||
|
||||
func genApi(api *spec.ApiSpec, localTypes []spec.Type, caller string, prefixForType func(string) string) (string, error) {
|
||||
func genApi(api *spec.ApiSpec, caller string, prefixForType func(string) string) (string, error) {
|
||||
var builder strings.Builder
|
||||
for _, route := range api.Service.Routes() {
|
||||
handler, ok := apiutil.GetAnnotationValue(route.Annotations, "server", "handler")
|
||||
if !ok {
|
||||
return "", fmt.Errorf("missing handler annotation for route %q", route.Path)
|
||||
}
|
||||
handler = util.Untitle(handler)
|
||||
handler = strings.Replace(handler, "Handler", "", 1)
|
||||
comment := commentForRoute(route)
|
||||
if len(comment) > 0 {
|
||||
fmt.Fprintf(&builder, "%s\n", comment)
|
||||
}
|
||||
fmt.Fprintf(&builder, "export function %s(%s) {\n", handler, paramsForRoute(route, prefixForType))
|
||||
writeIndent(&builder, 1)
|
||||
responseGeneric := "<null>"
|
||||
if len(route.ResponseType.Name) > 0 {
|
||||
val, err := goTypeToTs(route.ResponseType.Name, prefixForType)
|
||||
if err != nil {
|
||||
return "", err
|
||||
for _, group := range api.Service.Groups {
|
||||
for _, route := range group.Routes {
|
||||
handler, ok := apiutil.GetAnnotationValue(route.Annotations, "server", "handler")
|
||||
if !ok {
|
||||
return "", fmt.Errorf("missing handler annotation for route %q", route.Path)
|
||||
}
|
||||
responseGeneric = fmt.Sprintf("<%s>", val)
|
||||
handler = util.Untitle(handler)
|
||||
handler = strings.Replace(handler, "Handler", "", 1)
|
||||
comment := commentForRoute(route)
|
||||
if len(comment) > 0 {
|
||||
fmt.Fprintf(&builder, "%s\n", comment)
|
||||
}
|
||||
fmt.Fprintf(&builder, "export function %s(%s) {\n", handler, paramsForRoute(route, prefixForType))
|
||||
writeIndent(&builder, 1)
|
||||
responseGeneric := "<null>"
|
||||
if len(route.ResponseType.Name) > 0 {
|
||||
val, err := goTypeToTs(route.ResponseType.Name, prefixForType)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
responseGeneric = fmt.Sprintf("<%s>", val)
|
||||
}
|
||||
fmt.Fprintf(&builder, `return %s.%s%s(%s)`, caller, strings.ToLower(route.Method),
|
||||
util.Title(responseGeneric), callParamsForRoute(route, group))
|
||||
builder.WriteString("\n}\n\n")
|
||||
}
|
||||
fmt.Fprintf(&builder, `return %s.%s%s(%s)`, caller, strings.ToLower(route.Method),
|
||||
util.Title(responseGeneric), callParamsForRoute(route))
|
||||
builder.WriteString("\n}\n\n")
|
||||
}
|
||||
|
||||
apis := builder.String()
|
||||
@@ -188,21 +190,28 @@ func commentForRoute(route spec.Route) string {
|
||||
return builder.String()
|
||||
}
|
||||
|
||||
func callParamsForRoute(route spec.Route) string {
|
||||
func callParamsForRoute(route spec.Route, group spec.Group) string {
|
||||
hasParams := pathHasParams(route)
|
||||
hasBody := hasRequestBody(route)
|
||||
if hasParams && hasBody {
|
||||
return fmt.Sprintf("%s, %s, %s", pathForRoute(route), "params", "req")
|
||||
return fmt.Sprintf("%s, %s, %s", pathForRoute(route, group), "params", "req")
|
||||
} else if hasParams {
|
||||
return fmt.Sprintf("%s, %s", pathForRoute(route), "params")
|
||||
return fmt.Sprintf("%s, %s", pathForRoute(route, group), "params")
|
||||
} else if hasBody {
|
||||
return fmt.Sprintf("%s, %s", pathForRoute(route), "req")
|
||||
return fmt.Sprintf("%s, %s", pathForRoute(route, group), "req")
|
||||
}
|
||||
return pathForRoute(route)
|
||||
return pathForRoute(route, group)
|
||||
}
|
||||
|
||||
func pathForRoute(route spec.Route) string {
|
||||
return "\"" + route.Path + "\""
|
||||
func pathForRoute(route spec.Route, group spec.Group) string {
|
||||
value, ok := apiutil.GetAnnotationValue(group.Annotations, "server", pathPrefix)
|
||||
if !ok {
|
||||
return "\"" + route.Path + "\""
|
||||
} else {
|
||||
value = strings.TrimPrefix(value, `"`)
|
||||
value = strings.TrimSuffix(value, `"`)
|
||||
return fmt.Sprintf(`"%s/%s"`, value, strings.TrimPrefix(route.Path, "/"))
|
||||
}
|
||||
}
|
||||
|
||||
func pathHasParams(route spec.Route) bool {
|
||||
|
||||
@@ -2,4 +2,5 @@ package tsgen
|
||||
|
||||
const (
|
||||
packagePrefix = "components."
|
||||
pathPrefix = "pathPrefix"
|
||||
)
|
||||
|
||||
@@ -9,15 +9,17 @@ import (
|
||||
"text/template"
|
||||
"time"
|
||||
|
||||
"github.com/logrusorgru/aurora"
|
||||
"github.com/tal-tech/go-zero/tools/goctl/util"
|
||||
ctlutil "github.com/tal-tech/go-zero/tools/goctl/util"
|
||||
"github.com/urfave/cli"
|
||||
)
|
||||
|
||||
const (
|
||||
etcDir = "etc"
|
||||
yamlEtx = ".yaml"
|
||||
cstOffset = 60 * 60 * 8 // 8 hours offset for Chinese Standard Time
|
||||
dockerfileName = "Dockerfile"
|
||||
etcDir = "etc"
|
||||
yamlEtx = ".yaml"
|
||||
cstOffset = 60 * 60 * 8 // 8 hours offset for Chinese Standard Time
|
||||
)
|
||||
|
||||
type Docker struct {
|
||||
@@ -25,10 +27,18 @@ type Docker struct {
|
||||
GoRelPath string
|
||||
GoFile string
|
||||
ExeFile string
|
||||
HasPort bool
|
||||
Port int
|
||||
Argument string
|
||||
}
|
||||
|
||||
func DockerCommand(c *cli.Context) error {
|
||||
func DockerCommand(c *cli.Context) (err error) {
|
||||
defer func() {
|
||||
if err == nil {
|
||||
fmt.Println(aurora.Green("Done."))
|
||||
}
|
||||
}()
|
||||
|
||||
goFile := c.String("go")
|
||||
if len(goFile) == 0 {
|
||||
return errors.New("-go can't be empty")
|
||||
@@ -38,8 +48,9 @@ func DockerCommand(c *cli.Context) error {
|
||||
return fmt.Errorf("file %q not found", goFile)
|
||||
}
|
||||
|
||||
port := c.Int("port")
|
||||
if _, err := os.Stat(etcDir); os.IsNotExist(err) {
|
||||
return generateDockerfile(goFile)
|
||||
return generateDockerfile(goFile, port)
|
||||
}
|
||||
|
||||
cfg, err := findConfig(goFile, etcDir)
|
||||
@@ -47,13 +58,13 @@ func DockerCommand(c *cli.Context) error {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := generateDockerfile(goFile, "-f", "etc/"+cfg); err != nil {
|
||||
if err := generateDockerfile(goFile, port, "-f", "etc/"+cfg); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
projDir, ok := util.FindProjectPath(goFile)
|
||||
if ok {
|
||||
fmt.Printf("Run \"docker build ...\" command in dir %q\n", projDir)
|
||||
fmt.Printf("Hint: run \"docker build ...\" command in dir %q\n", projDir)
|
||||
}
|
||||
|
||||
return nil
|
||||
@@ -88,18 +99,22 @@ func findConfig(file, dir string) (string, error) {
|
||||
return files[0], nil
|
||||
}
|
||||
|
||||
func generateDockerfile(goFile string, args ...string) error {
|
||||
func generateDockerfile(goFile string, port int, args ...string) error {
|
||||
projPath, err := getFilePath(filepath.Dir(goFile))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
pos := strings.IndexByte(projPath, '/')
|
||||
if pos >= 0 {
|
||||
projPath = projPath[pos+1:]
|
||||
if len(projPath) == 0 {
|
||||
projPath = "."
|
||||
} else {
|
||||
pos := strings.IndexByte(projPath, os.PathSeparator)
|
||||
if pos >= 0 {
|
||||
projPath = projPath[pos+1:]
|
||||
}
|
||||
}
|
||||
|
||||
out, err := util.CreateIfNotExist("Dockerfile")
|
||||
out, err := util.CreateIfNotExist(dockerfileName)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -122,6 +137,8 @@ func generateDockerfile(goFile string, args ...string) error {
|
||||
GoRelPath: projPath,
|
||||
GoFile: goFile,
|
||||
ExeFile: util.FileNameWithoutExt(filepath.Base(goFile)),
|
||||
HasPort: port > 0,
|
||||
Port: port,
|
||||
Argument: builder.String(),
|
||||
})
|
||||
}
|
||||
|
||||
@@ -14,34 +14,59 @@ LABEL stage=gobuilder
|
||||
|
||||
ENV CGO_ENABLED 0
|
||||
ENV GOOS linux
|
||||
{{if .Chinese}}ENV GOPROXY https://goproxy.cn,direct{{end}}
|
||||
|
||||
{{if .Chinese}}ENV GOPROXY https://goproxy.cn,direct
|
||||
{{end}}
|
||||
WORKDIR /build/zero
|
||||
|
||||
ADD go.mod .
|
||||
ADD go.sum .
|
||||
RUN go mod download
|
||||
COPY . .
|
||||
COPY {{.GoRelPath}}/etc /app/etc
|
||||
RUN go build -ldflags="-s -w" -o /app/{{.ExeFile}} {{.GoRelPath}}/{{.GoFile}}
|
||||
{{if .Argument}}COPY {{.GoRelPath}}/etc /app/etc
|
||||
{{end}}RUN go build -ldflags="-s -w" -o /app/{{.ExeFile}} {{.GoRelPath}}/{{.GoFile}}
|
||||
|
||||
|
||||
FROM alpine
|
||||
|
||||
RUN apk update --no-cache
|
||||
RUN apk add --no-cache ca-certificates
|
||||
RUN apk add --no-cache tzdata
|
||||
RUN apk update --no-cache && apk add --no-cache ca-certificates tzdata
|
||||
ENV TZ Asia/Shanghai
|
||||
|
||||
WORKDIR /app
|
||||
COPY --from=builder /app/{{.ExeFile}} /app/{{.ExeFile}}
|
||||
COPY --from=builder /app/etc /app/etc
|
||||
|
||||
COPY --from=builder /app/{{.ExeFile}} /app/{{.ExeFile}}{{if .Argument}}
|
||||
COPY --from=builder /app/etc /app/etc{{end}}
|
||||
{{if .HasPort}}
|
||||
EXPOSE {{.Port}}
|
||||
{{end}}
|
||||
CMD ["./{{.ExeFile}}"{{.Argument}}]
|
||||
`
|
||||
)
|
||||
|
||||
func Clean() error {
|
||||
return util.Clean(category)
|
||||
}
|
||||
|
||||
func GenTemplates(_ *cli.Context) error {
|
||||
return initTemplate()
|
||||
}
|
||||
|
||||
func Category() string {
|
||||
return category
|
||||
}
|
||||
|
||||
func RevertTemplate(name string) error {
|
||||
return util.CreateTemplate(category, name, dockerTemplate)
|
||||
}
|
||||
|
||||
func Update() error {
|
||||
err := Clean()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return initTemplate()
|
||||
}
|
||||
|
||||
func initTemplate() error {
|
||||
return util.InitTemplates(category, map[string]string{
|
||||
dockerTemplateFile: dockerTemplate,
|
||||
})
|
||||
|
||||
@@ -27,7 +27,7 @@ import (
|
||||
)
|
||||
|
||||
var (
|
||||
BuildVersion = "20201125"
|
||||
BuildVersion = "1.1.1"
|
||||
commands = []cli.Command{
|
||||
{
|
||||
Name: "api",
|
||||
@@ -54,14 +54,12 @@ var (
|
||||
Usage: "the format target dir",
|
||||
},
|
||||
cli.BoolFlag{
|
||||
Name: "iu",
|
||||
Usage: "ignore update",
|
||||
Required: false,
|
||||
Name: "iu",
|
||||
Usage: "ignore update",
|
||||
},
|
||||
cli.BoolFlag{
|
||||
Name: "stdin",
|
||||
Usage: "use stdin to input api doc content, press \"ctrl + d\" to send EOF",
|
||||
Required: false,
|
||||
Name: "stdin",
|
||||
Usage: "use stdin to input api doc content, press \"ctrl + d\" to send EOF",
|
||||
},
|
||||
},
|
||||
Action: format.GoFormatApi,
|
||||
@@ -101,9 +99,8 @@ var (
|
||||
Usage: "the api file",
|
||||
},
|
||||
cli.StringFlag{
|
||||
Name: "style",
|
||||
Required: false,
|
||||
Usage: "the file naming format, see [https://github.com/tal-tech/go-zero/tree/master/tools/goctl/config/readme.md]",
|
||||
Name: "style",
|
||||
Usage: "the file naming format, see [https://github.com/tal-tech/go-zero/tree/master/tools/goctl/config/readme.md]",
|
||||
},
|
||||
},
|
||||
Action: gogen.GoCommand,
|
||||
@@ -136,19 +133,16 @@ var (
|
||||
Usage: "the api file",
|
||||
},
|
||||
cli.StringFlag{
|
||||
Name: "webapi",
|
||||
Usage: "the web api file path",
|
||||
Required: false,
|
||||
Name: "webapi",
|
||||
Usage: "the web api file path",
|
||||
},
|
||||
cli.StringFlag{
|
||||
Name: "caller",
|
||||
Usage: "the web api caller",
|
||||
Required: false,
|
||||
Name: "caller",
|
||||
Usage: "the web api caller",
|
||||
},
|
||||
cli.BoolFlag{
|
||||
Name: "unwrap",
|
||||
Usage: "unwrap the webapi caller for import",
|
||||
Required: false,
|
||||
Name: "unwrap",
|
||||
Usage: "unwrap the webapi caller for import",
|
||||
},
|
||||
},
|
||||
Action: tsgen.TsCommand,
|
||||
@@ -204,9 +198,8 @@ var (
|
||||
Usage: "the api file",
|
||||
},
|
||||
cli.StringFlag{
|
||||
Name: "style",
|
||||
Required: false,
|
||||
Usage: "the file naming format, see [https://github.com/tal-tech/go-zero/tree/master/tools/goctl/config/readme.md]",
|
||||
Name: "style",
|
||||
Usage: "the file naming format, see [https://github.com/tal-tech/go-zero/tree/master/tools/goctl/config/readme.md]",
|
||||
},
|
||||
},
|
||||
Action: plugin.PluginCommand,
|
||||
@@ -221,6 +214,11 @@ var (
|
||||
Name: "go",
|
||||
Usage: "the file that contains main function",
|
||||
},
|
||||
cli.IntFlag{
|
||||
Name: "port",
|
||||
Usage: "the port to expose, default none",
|
||||
Value: 0,
|
||||
},
|
||||
},
|
||||
Action: docker.DockerCommand,
|
||||
},
|
||||
@@ -248,9 +246,8 @@ var (
|
||||
Required: true,
|
||||
},
|
||||
cli.StringFlag{
|
||||
Name: "secret",
|
||||
Usage: "the image pull secret",
|
||||
Required: true,
|
||||
Name: "secret",
|
||||
Usage: "the secret to image pull from registry",
|
||||
},
|
||||
cli.IntFlag{
|
||||
Name: "requestCpu",
|
||||
@@ -321,9 +318,8 @@ var (
|
||||
Usage: `generate rpc demo service`,
|
||||
Flags: []cli.Flag{
|
||||
cli.StringFlag{
|
||||
Name: "style",
|
||||
Required: false,
|
||||
Usage: "the file naming format, see [https://github.com/tal-tech/go-zero/tree/master/tools/goctl/config/readme.md]",
|
||||
Name: "style",
|
||||
Usage: "the file naming format, see [https://github.com/tal-tech/go-zero/tree/master/tools/goctl/config/readme.md]",
|
||||
},
|
||||
cli.BoolFlag{
|
||||
Name: "idea",
|
||||
@@ -360,9 +356,8 @@ var (
|
||||
Usage: `the target path of the code`,
|
||||
},
|
||||
cli.StringFlag{
|
||||
Name: "style",
|
||||
Required: false,
|
||||
Usage: "the file naming format, see [https://github.com/tal-tech/go-zero/tree/master/tools/goctl/config/readme.md]",
|
||||
Name: "style",
|
||||
Usage: "the file naming format, see [https://github.com/tal-tech/go-zero/tree/master/tools/goctl/config/readme.md]",
|
||||
},
|
||||
cli.BoolFlag{
|
||||
Name: "idea",
|
||||
@@ -394,9 +389,8 @@ var (
|
||||
Usage: "the target dir",
|
||||
},
|
||||
cli.StringFlag{
|
||||
Name: "style",
|
||||
Required: false,
|
||||
Usage: "the file naming format, see [https://github.com/tal-tech/go-zero/tree/master/tools/goctl/config/readme.md]",
|
||||
Name: "style",
|
||||
Usage: "the file naming format, see [https://github.com/tal-tech/go-zero/tree/master/tools/goctl/config/readme.md]",
|
||||
},
|
||||
cli.BoolFlag{
|
||||
Name: "cache, c",
|
||||
@@ -430,9 +424,8 @@ var (
|
||||
Usage: "the target dir",
|
||||
},
|
||||
cli.StringFlag{
|
||||
Name: "style",
|
||||
Required: false,
|
||||
Usage: "the file naming format, see [https://github.com/tal-tech/go-zero/tree/master/tools/goctl/config/readme.md]",
|
||||
Name: "style",
|
||||
Usage: "the file naming format, see [https://github.com/tal-tech/go-zero/tree/master/tools/goctl/config/readme.md]",
|
||||
},
|
||||
cli.BoolFlag{
|
||||
Name: "idea",
|
||||
@@ -476,7 +469,7 @@ var (
|
||||
Flags: []cli.Flag{
|
||||
cli.StringFlag{
|
||||
Name: "category,c",
|
||||
Usage: "the category of template, enum [api,rpc,model]",
|
||||
Usage: "the category of template, enum [api,rpc,model,docker,kube]",
|
||||
},
|
||||
},
|
||||
Action: tpl.UpdateTemplates,
|
||||
@@ -487,7 +480,7 @@ var (
|
||||
Flags: []cli.Flag{
|
||||
cli.StringFlag{
|
||||
Name: "category,c",
|
||||
Usage: "the category of template, enum [api,rpc,model]",
|
||||
Usage: "the category of template, enum [api,rpc,model,docker,kube]",
|
||||
},
|
||||
cli.StringFlag{
|
||||
Name: "name,n",
|
||||
|
||||
@@ -47,9 +47,9 @@ spec:
|
||||
volumeMounts:
|
||||
- name: timezone
|
||||
mountPath: /etc/localtime
|
||||
imagePullSecrets:
|
||||
{{if .Secret}}imagePullSecrets:
|
||||
- name: {{.Secret}}
|
||||
volumes:
|
||||
{{end}}volumes:
|
||||
- name: timezone
|
||||
hostPath:
|
||||
path: /usr/share/zoneinfo/Asia/Shanghai
|
||||
|
||||
@@ -2,8 +2,10 @@ package kube
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"text/template"
|
||||
|
||||
"github.com/logrusorgru/aurora"
|
||||
"github.com/tal-tech/go-zero/tools/goctl/util"
|
||||
"github.com/urfave/cli"
|
||||
)
|
||||
@@ -16,47 +18,23 @@ const (
|
||||
portLimit = 32767
|
||||
)
|
||||
|
||||
var errUnknownServiceType = errors.New("unknown service type")
|
||||
|
||||
type (
|
||||
ServiceType string
|
||||
|
||||
KubeRequest struct {
|
||||
Env string
|
||||
ServiceName string
|
||||
ServiceType ServiceType
|
||||
Namespace string
|
||||
Schedule string
|
||||
Replicas int
|
||||
RevisionHistoryLimit int
|
||||
Port int
|
||||
LimitCpu int
|
||||
LimitMem int
|
||||
RequestCpu int
|
||||
RequestMem int
|
||||
SuccessfulJobsHistoryLimit int
|
||||
HpaMinReplicas int
|
||||
HpaMaxReplicas int
|
||||
}
|
||||
|
||||
Deployment struct {
|
||||
Name string
|
||||
Namespace string
|
||||
Image string
|
||||
Secret string
|
||||
Replicas int
|
||||
Revisions int
|
||||
Port int
|
||||
NodePort int
|
||||
UseNodePort bool
|
||||
RequestCpu int
|
||||
RequestMem int
|
||||
LimitCpu int
|
||||
LimitMem int
|
||||
MinReplicas int
|
||||
MaxReplicas int
|
||||
}
|
||||
)
|
||||
type Deployment struct {
|
||||
Name string
|
||||
Namespace string
|
||||
Image string
|
||||
Secret string
|
||||
Replicas int
|
||||
Revisions int
|
||||
Port int
|
||||
NodePort int
|
||||
UseNodePort bool
|
||||
RequestCpu int
|
||||
RequestMem int
|
||||
LimitCpu int
|
||||
LimitMem int
|
||||
MinReplicas int
|
||||
MaxReplicas int
|
||||
}
|
||||
|
||||
func DeploymentCommand(c *cli.Context) error {
|
||||
nodePort := c.Int("nodePort")
|
||||
@@ -77,7 +55,7 @@ func DeploymentCommand(c *cli.Context) error {
|
||||
defer out.Close()
|
||||
|
||||
t := template.Must(template.New("deploymentTemplate").Parse(text))
|
||||
return t.Execute(out, Deployment{
|
||||
err = t.Execute(out, Deployment{
|
||||
Name: c.String("name"),
|
||||
Namespace: c.String("namespace"),
|
||||
Image: c.String("image"),
|
||||
@@ -94,6 +72,20 @@ func DeploymentCommand(c *cli.Context) error {
|
||||
MinReplicas: c.Int("minReplicas"),
|
||||
MaxReplicas: c.Int("maxReplicas"),
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
fmt.Println(aurora.Green("Done."))
|
||||
return nil
|
||||
}
|
||||
|
||||
func Category() string {
|
||||
return category
|
||||
}
|
||||
|
||||
func Clean() error {
|
||||
return util.Clean(category)
|
||||
}
|
||||
|
||||
func GenTemplates(_ *cli.Context) error {
|
||||
@@ -102,3 +94,19 @@ func GenTemplates(_ *cli.Context) error {
|
||||
jobTemplateFile: jobTmeplate,
|
||||
})
|
||||
}
|
||||
|
||||
func RevertTemplate(name string) error {
|
||||
return util.CreateTemplate(category, name, deploymentTemplate)
|
||||
}
|
||||
|
||||
func Update() error {
|
||||
err := Clean()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return util.InitTemplates(category, map[string]string{
|
||||
deployTemplateFile: deploymentTemplate,
|
||||
jobTemplateFile: jobTmeplate,
|
||||
})
|
||||
}
|
||||
|
||||
@@ -61,9 +61,9 @@ func FieldNames(in interface{}) []string {
|
||||
// gets us a StructField
|
||||
fi := typ.Field(i)
|
||||
if tagv := fi.Tag.Get(dbTag); tagv != "" {
|
||||
out = append(out, tagv)
|
||||
out = append(out, fmt.Sprintf("`%v`", tagv))
|
||||
} else {
|
||||
out = append(out, fi.Name)
|
||||
out = append(out, fmt.Sprintf("`%v`", fi.Name))
|
||||
}
|
||||
}
|
||||
return out
|
||||
|
||||
@@ -28,8 +28,7 @@ var userFields = FieldNames(User{})
|
||||
func TestFieldNames(t *testing.T) {
|
||||
var u User
|
||||
out := FieldNames(&u)
|
||||
fmt.Println(out)
|
||||
actual := []string{"id", "user_name", "sex", "uuid", "age"}
|
||||
actual := []string{"`id`", "`user_name`", "`sex`", "`uuid`", "`age`"}
|
||||
assert.Equal(t, out, actual)
|
||||
}
|
||||
|
||||
@@ -54,7 +53,7 @@ func TestBuilderSql(t *testing.T) {
|
||||
sql, args, err := builder.Select(fields...).From("user").Where(eq).ToSQL()
|
||||
fmt.Println(sql, args, err)
|
||||
|
||||
actualSql := "SELECT id,user_name,sex,uuid,age FROM user WHERE id=?"
|
||||
actualSql := "SELECT `id`,`user_name`,`sex`,`uuid`,`age` FROM user WHERE id=?"
|
||||
actualArgs := []interface{}{"123123"}
|
||||
assert.Equal(t, sql, actualSql)
|
||||
assert.Equal(t, args, actualArgs)
|
||||
@@ -68,7 +67,7 @@ func TestBuildSqlDefaultValue(t *testing.T) {
|
||||
sql, args, err := builder.Select(userFields...).From("user").Where(eq).ToSQL()
|
||||
fmt.Println(sql, args, err)
|
||||
|
||||
actualSql := "SELECT id,user_name,sex,uuid,age FROM user WHERE age=? AND user_name=?"
|
||||
actualSql := "SELECT `id`,`user_name`,`sex`,`uuid`,`age` FROM user WHERE age=? AND user_name=?"
|
||||
actualArgs := []interface{}{0, ""}
|
||||
assert.Equal(t, sql, actualSql)
|
||||
assert.Equal(t, args, actualArgs)
|
||||
@@ -83,7 +82,7 @@ func TestBuilderSqlIn(t *testing.T) {
|
||||
sql, args, err := builder.Select(userFields...).From("user").Where(in).And(gtU).ToSQL()
|
||||
fmt.Println(sql, args, err)
|
||||
|
||||
actualSql := "SELECT id,user_name,sex,uuid,age FROM user WHERE id IN (?,?,?) AND age>?"
|
||||
actualSql := "SELECT `id`,`user_name`,`sex`,`uuid`,`age` FROM user WHERE id IN (?,?,?) AND age>?"
|
||||
actualArgs := []interface{}{"1", "2", "3", 18}
|
||||
assert.Equal(t, sql, actualSql)
|
||||
assert.Equal(t, args, actualArgs)
|
||||
@@ -94,7 +93,7 @@ func TestBuildSqlLike(t *testing.T) {
|
||||
sql, args, err := builder.Select(userFields...).From("user").Where(like).ToSQL()
|
||||
fmt.Println(sql, args, err)
|
||||
|
||||
actualSql := "SELECT id,user_name,sex,uuid,age FROM user WHERE name LIKE ?"
|
||||
actualSql := "SELECT `id`,`user_name`,`sex`,`uuid`,`age` FROM user WHERE name LIKE ?"
|
||||
actualArgs := []interface{}{"%wang%"}
|
||||
assert.Equal(t, sql, actualSql)
|
||||
assert.Equal(t, args, actualArgs)
|
||||
|
||||
@@ -1,8 +1,13 @@
|
||||
#!/bin/bash
|
||||
|
||||
# generate model with cache from ddl
|
||||
fromDDL:
|
||||
goctl model mysql ddl -src="./sql/*.sql" -dir="./sql/model/user" -cache
|
||||
fromDDLWithCache:
|
||||
goctl template clean;
|
||||
goctl model mysql ddl -src="./sql/*.sql" -dir="./sql/model/cache/user" -cache;
|
||||
|
||||
fromDDLWithoutCache:
|
||||
goctl template clean;
|
||||
goctl model mysql ddl -src="./sql/*.sql" -dir="./sql/model/nocache/user";
|
||||
|
||||
|
||||
# generate model with cache from data source
|
||||
@@ -12,4 +17,5 @@ datasource=127.0.0.1:3306
|
||||
database=gozero
|
||||
|
||||
fromDataSource:
|
||||
goctl model mysql datasource -url="$(user):$(password)@tcp($(datasource))/$(database)" -table="*" -dir ./model/cache -c -style gozero
|
||||
goctl template clean;
|
||||
goctl model mysql datasource -url="$(user):$(password)@tcp($(datasource))/$(database)" -table="*" -dir ./model/cache -c -style gozero;
|
||||
@@ -36,7 +36,7 @@ func genDelete(table Table, withCache bool) (string, string, error) {
|
||||
"lowerStartCamelPrimaryKey": stringx.From(table.PrimaryKey.Name.ToCamel()).Untitle(),
|
||||
"dataType": table.PrimaryKey.DataType,
|
||||
"keys": strings.Join(keySet.KeysStr(), "\n"),
|
||||
"originalPrimaryKey": table.PrimaryKey.Name.Source(),
|
||||
"originalPrimaryKey": wrapWithRawString(table.PrimaryKey.Name.Source()),
|
||||
"keyValues": strings.Join(keyVariableSet.KeysStr(), ", "),
|
||||
})
|
||||
if err != nil {
|
||||
|
||||
@@ -19,7 +19,7 @@ func genFindOne(table Table, withCache bool) (string, string, error) {
|
||||
"withCache": withCache,
|
||||
"upperStartCamelObject": camel,
|
||||
"lowerStartCamelObject": stringx.From(camel).Untitle(),
|
||||
"originalPrimaryKey": table.PrimaryKey.Name.Source(),
|
||||
"originalPrimaryKey": wrapWithRawString(table.PrimaryKey.Name.Source()),
|
||||
"lowerStartCamelPrimaryKey": stringx.From(table.PrimaryKey.Name.ToCamel()).Untitle(),
|
||||
"dataType": table.PrimaryKey.DataType,
|
||||
"cacheKey": table.CacheKey[table.PrimaryKey.Name.Source()].KeyExpression,
|
||||
|
||||
@@ -39,7 +39,7 @@ func genFindOneByField(table Table, withCache bool) (*findOneCode, error) {
|
||||
"lowerStartCamelObject": stringx.From(camelTableName).Untitle(),
|
||||
"lowerStartCamelField": stringx.From(camelFieldName).Untitle(),
|
||||
"upperStartCamelPrimaryKey": table.PrimaryKey.Name.ToCamel(),
|
||||
"originalField": field.Name.Source(),
|
||||
"originalField": wrapWithRawString(field.Name.Source()),
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -82,7 +82,7 @@ func genFindOneByField(table Table, withCache bool) (*findOneCode, error) {
|
||||
"upperStartCamelObject": camelTableName,
|
||||
"primaryKeyLeft": table.CacheKey[table.PrimaryKey.Name.Source()].Left,
|
||||
"lowerStartCamelObject": stringx.From(camelTableName).Untitle(),
|
||||
"originalPrimaryField": table.PrimaryKey.Name.Source(),
|
||||
"originalPrimaryField": wrapWithRawString(table.PrimaryKey.Name.Source()),
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
||||
@@ -21,9 +21,6 @@ import (
|
||||
const (
|
||||
pwd = "."
|
||||
createTableFlag = `(?m)^(?i)CREATE\s+TABLE` // ignore case
|
||||
NamingLower = "lower"
|
||||
NamingCamel = "camel"
|
||||
NamingSnake = "snake"
|
||||
)
|
||||
|
||||
type (
|
||||
@@ -280,3 +277,20 @@ func (g *defaultGenerator) genModel(in parser.Table, withCache bool) (string, er
|
||||
|
||||
return output.String(), nil
|
||||
}
|
||||
|
||||
func wrapWithRawString(v string) string {
|
||||
if v == "`" {
|
||||
return v
|
||||
}
|
||||
|
||||
if !strings.HasPrefix(v, "`") {
|
||||
v = "`" + v
|
||||
}
|
||||
|
||||
if !strings.HasSuffix(v, "`") {
|
||||
v = v + "`"
|
||||
} else if len(v) == 1 {
|
||||
v = v + "`"
|
||||
}
|
||||
return v
|
||||
}
|
||||
|
||||
@@ -1,13 +1,18 @@
|
||||
package gen
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/tal-tech/go-zero/core/logx"
|
||||
"github.com/tal-tech/go-zero/core/stringx"
|
||||
"github.com/tal-tech/go-zero/tools/goctl/config"
|
||||
"github.com/tal-tech/go-zero/tools/goctl/model/sql/builderx"
|
||||
)
|
||||
|
||||
var (
|
||||
@@ -79,3 +84,32 @@ func TestNamingModel(t *testing.T) {
|
||||
return err == nil
|
||||
}())
|
||||
}
|
||||
|
||||
func TestWrapWithRawString(t *testing.T) {
|
||||
assert.Equal(t, "``", wrapWithRawString(""))
|
||||
assert.Equal(t, "``", wrapWithRawString("``"))
|
||||
assert.Equal(t, "`a`", wrapWithRawString("a"))
|
||||
assert.Equal(t, "` `", wrapWithRawString(" "))
|
||||
}
|
||||
|
||||
func TestFields(t *testing.T) {
|
||||
type Student struct {
|
||||
Id int64 `db:"id"`
|
||||
Name string `db:"name"`
|
||||
Age sql.NullInt64 `db:"age"`
|
||||
Score sql.NullFloat64 `db:"score"`
|
||||
CreateTime time.Time `db:"create_time"`
|
||||
UpdateTime sql.NullTime `db:"update_time"`
|
||||
}
|
||||
var (
|
||||
studentFieldNames = builderx.FieldNames(&Student{})
|
||||
studentRows = strings.Join(studentFieldNames, ",")
|
||||
studentRowsExpectAutoSet = strings.Join(stringx.Remove(studentFieldNames, "`id`", "`create_time`", "`update_time`"), ",")
|
||||
studentRowsWithPlaceHolder = strings.Join(stringx.Remove(studentFieldNames, "`id`", "`create_time`", "`update_time`"), "=?,") + "=?"
|
||||
)
|
||||
|
||||
assert.Equal(t, []string{"`id`", "`name`", "`age`", "`score`", "`create_time`", "`update_time`"}, studentFieldNames)
|
||||
assert.Equal(t, "`id`,`name`,`age`,`score`,`create_time`,`update_time`", studentRows)
|
||||
assert.Equal(t, "`name`,`age`,`score`", studentRowsExpectAutoSet)
|
||||
assert.Equal(t, "`name`=?,`age`=?,`score`=?", studentRowsWithPlaceHolder)
|
||||
}
|
||||
|
||||
@@ -14,7 +14,7 @@ func genNew(table Table, withCache bool) (string, error) {
|
||||
output, err := util.With("new").
|
||||
Parse(text).
|
||||
Execute(map[string]interface{}{
|
||||
"table": table.Name.Source(),
|
||||
"table": wrapWithRawString(table.Name.Source()),
|
||||
"withCache": withCache,
|
||||
"upperStartCamelObject": table.Name.ToCamel(),
|
||||
})
|
||||
|
||||
@@ -54,6 +54,14 @@ var templates = map[string]string{
|
||||
errTemplateFile: template.Error,
|
||||
}
|
||||
|
||||
func Category() string {
|
||||
return category
|
||||
}
|
||||
|
||||
func Clean() error {
|
||||
return util.Clean(category)
|
||||
}
|
||||
|
||||
func GenTemplates(_ *cli.Context) error {
|
||||
return util.InitTemplates(category, templates)
|
||||
}
|
||||
@@ -66,18 +74,10 @@ func RevertTemplate(name string) error {
|
||||
return util.CreateTemplate(category, name, content)
|
||||
}
|
||||
|
||||
func Clean() error {
|
||||
return util.Clean(category)
|
||||
}
|
||||
|
||||
func Update(category string) error {
|
||||
func Update() error {
|
||||
err := Clean()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return util.InitTemplates(category, templates)
|
||||
}
|
||||
|
||||
func GetCategory() string {
|
||||
return category
|
||||
}
|
||||
|
||||
@@ -85,7 +85,7 @@ func TestUpdate(t *testing.T) {
|
||||
|
||||
assert.Equal(t, string(data), modifyData)
|
||||
|
||||
assert.Nil(t, Update(category))
|
||||
assert.Nil(t, Update())
|
||||
|
||||
data, err = ioutil.ReadFile(file)
|
||||
assert.Nil(t, err)
|
||||
|
||||
@@ -35,7 +35,7 @@ func genUpdate(table Table, withCache bool) (string, string, error) {
|
||||
"primaryCacheKey": table.CacheKey[table.PrimaryKey.Name.Source()].DataKeyExpression,
|
||||
"primaryKeyVariable": table.CacheKey[table.PrimaryKey.Name.Source()].Variable,
|
||||
"lowerStartCamelObject": stringx.From(camelTableName).Untitle(),
|
||||
"originalPrimaryKey": table.PrimaryKey.Name.Source(),
|
||||
"originalPrimaryKey": wrapWithRawString(table.PrimaryKey.Name.Source()),
|
||||
"expressionValues": strings.Join(expressionValues, ", "),
|
||||
})
|
||||
if err != nil {
|
||||
|
||||
@@ -27,7 +27,7 @@ func genVars(table Table, withCache bool) (string, error) {
|
||||
"upperStartCamelObject": camel,
|
||||
"cacheKeys": strings.Join(keys, "\n"),
|
||||
"autoIncrement": table.PrimaryKey.AutoIncrement,
|
||||
"originalPrimaryKey": table.PrimaryKey.Name.Source(),
|
||||
"originalPrimaryKey": wrapWithRawString(table.PrimaryKey.Name.Source()),
|
||||
"withCache": withCache,
|
||||
})
|
||||
if err != nil {
|
||||
|
||||
@@ -1,34 +0,0 @@
|
||||
package model
|
||||
|
||||
import (
|
||||
"github.com/tal-tech/go-zero/core/stores/sqlx"
|
||||
)
|
||||
|
||||
type (
|
||||
DDLModel struct {
|
||||
conn sqlx.SqlConn
|
||||
}
|
||||
DDL struct {
|
||||
Table string `db:"Table"`
|
||||
DDL string `db:"Create Table"`
|
||||
}
|
||||
)
|
||||
|
||||
func NewDDLModel(conn sqlx.SqlConn) *DDLModel {
|
||||
return &DDLModel{conn: conn}
|
||||
}
|
||||
|
||||
func (m *DDLModel) ShowDDL(table ...string) ([]string, error) {
|
||||
var ddl []string
|
||||
for _, t := range table {
|
||||
query := `show create table ` + t
|
||||
var resp DDL
|
||||
err := m.conn.QueryRow(&resp, query)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
ddl = append(ddl, resp.DDL)
|
||||
}
|
||||
return ddl, nil
|
||||
}
|
||||
@@ -1,12 +1,14 @@
|
||||
package template
|
||||
|
||||
var Vars = `
|
||||
import "fmt"
|
||||
|
||||
var Vars = fmt.Sprintf(`
|
||||
var (
|
||||
{{.lowerStartCamelObject}}FieldNames = builderx.FieldNames(&{{.upperStartCamelObject}}{})
|
||||
{{.lowerStartCamelObject}}Rows = strings.Join({{.lowerStartCamelObject}}FieldNames, ",")
|
||||
{{.lowerStartCamelObject}}RowsExpectAutoSet = strings.Join(stringx.Remove({{.lowerStartCamelObject}}FieldNames, {{if .autoIncrement}}"{{.originalPrimaryKey}}",{{end}} "create_time", "update_time"), ",")
|
||||
{{.lowerStartCamelObject}}RowsWithPlaceHolder = strings.Join(stringx.Remove({{.lowerStartCamelObject}}FieldNames, "{{.originalPrimaryKey}}", "create_time", "update_time"), "=?,") + "=?"
|
||||
{{.lowerStartCamelObject}}RowsExpectAutoSet = strings.Join(stringx.Remove({{.lowerStartCamelObject}}FieldNames, {{if .autoIncrement}}"{{.originalPrimaryKey}}",{{end}} "%screate_time%s", "%supdate_time%s"), ",")
|
||||
{{.lowerStartCamelObject}}RowsWithPlaceHolder = strings.Join(stringx.Remove({{.lowerStartCamelObject}}FieldNames, "{{.originalPrimaryKey}}", "%screate_time%s", "%supdate_time%s"), "=?,") + "=?"
|
||||
|
||||
{{if .withCache}}{{.cacheKeys}}{{end}}
|
||||
)
|
||||
`
|
||||
`, "`", "`", "`", "`", "`", "`", "`", "`")
|
||||
|
||||
235
tools/goctl/model/sql/test/model/model_test.go
Normal file
235
tools/goctl/model/sql/test/model/model_test.go
Normal file
@@ -0,0 +1,235 @@
|
||||
package model
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/DATA-DOG/go-sqlmock"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/tal-tech/go-zero/core/stores/cache"
|
||||
"github.com/tal-tech/go-zero/core/stores/redis"
|
||||
"github.com/tal-tech/go-zero/core/stores/redis/redistest"
|
||||
mocksql "github.com/tal-tech/go-zero/tools/goctl/model/sql/test"
|
||||
)
|
||||
|
||||
func TestStudentModel(t *testing.T) {
|
||||
var (
|
||||
testTimeValue = time.Now()
|
||||
testTable = "`student`"
|
||||
testUpdateName = "gozero1"
|
||||
testRowsAffected int64 = 1
|
||||
testInsertId int64 = 1
|
||||
)
|
||||
|
||||
var data Student
|
||||
data.Id = testInsertId
|
||||
data.Name = "gozero"
|
||||
data.Age = sql.NullInt64{
|
||||
Int64: 1,
|
||||
Valid: true,
|
||||
}
|
||||
data.Score = sql.NullFloat64{
|
||||
Float64: 100,
|
||||
Valid: true,
|
||||
}
|
||||
data.CreateTime = testTimeValue
|
||||
data.UpdateTime = sql.NullTime{
|
||||
Time: testTimeValue,
|
||||
Valid: true,
|
||||
}
|
||||
|
||||
err := mockStudent(func(mock sqlmock.Sqlmock) {
|
||||
mock.ExpectExec(fmt.Sprintf("insert into %s", testTable)).
|
||||
WithArgs(data.Name, data.Age, data.Score).
|
||||
WillReturnResult(sqlmock.NewResult(testInsertId, testRowsAffected))
|
||||
}, func(m StudentModel) {
|
||||
r, err := m.Insert(data)
|
||||
assert.Nil(t, err)
|
||||
|
||||
lastInsertId, err := r.LastInsertId()
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, testInsertId, lastInsertId)
|
||||
|
||||
rowsAffected, err := r.RowsAffected()
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, testRowsAffected, rowsAffected)
|
||||
})
|
||||
assert.Nil(t, err)
|
||||
|
||||
err = mockStudent(func(mock sqlmock.Sqlmock) {
|
||||
mock.ExpectQuery(fmt.Sprintf("select (.+) from %s", testTable)).
|
||||
WithArgs(testInsertId).
|
||||
WillReturnRows(sqlmock.NewRows([]string{"id", "name", "age", "score", "create_time", "update_time"}).AddRow(testInsertId, data.Name, data.Age, data.Score, testTimeValue, testTimeValue))
|
||||
}, func(m StudentModel) {
|
||||
result, err := m.FindOne(testInsertId)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, *result, data)
|
||||
})
|
||||
assert.Nil(t, err)
|
||||
|
||||
err = mockStudent(func(mock sqlmock.Sqlmock) {
|
||||
mock.ExpectExec(fmt.Sprintf("update %s", testTable)).WithArgs(testUpdateName, data.Age, data.Score, testInsertId).WillReturnResult(sqlmock.NewResult(testInsertId, testRowsAffected))
|
||||
}, func(m StudentModel) {
|
||||
data.Name = testUpdateName
|
||||
err := m.Update(data)
|
||||
assert.Nil(t, err)
|
||||
})
|
||||
assert.Nil(t, err)
|
||||
|
||||
err = mockStudent(func(mock sqlmock.Sqlmock) {
|
||||
mock.ExpectQuery(fmt.Sprintf("select (.+) from %s ", testTable)).
|
||||
WithArgs(testInsertId).
|
||||
WillReturnRows(sqlmock.NewRows([]string{"id", "name", "age", "score", "create_time", "update_time"}).AddRow(testInsertId, data.Name, data.Age, data.Score, testTimeValue, testTimeValue))
|
||||
}, func(m StudentModel) {
|
||||
result, err := m.FindOne(testInsertId)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, *result, data)
|
||||
})
|
||||
assert.Nil(t, err)
|
||||
|
||||
err = mockStudent(func(mock sqlmock.Sqlmock) {
|
||||
mock.ExpectExec(fmt.Sprintf("delete from %s where `id` = ?", testTable)).WithArgs(testInsertId).WillReturnResult(sqlmock.NewResult(testInsertId, testRowsAffected))
|
||||
}, func(m StudentModel) {
|
||||
err := m.Delete(testInsertId)
|
||||
assert.Nil(t, err)
|
||||
})
|
||||
assert.Nil(t, err)
|
||||
}
|
||||
|
||||
func TestUserModel(t *testing.T) {
|
||||
var (
|
||||
testTimeValue = time.Now()
|
||||
testTable = "`user`"
|
||||
testUpdateName = "gozero1"
|
||||
testUser = "gozero"
|
||||
testPassword = "test"
|
||||
testMobile = "test_mobile"
|
||||
testGender = "男"
|
||||
testNickname = "test_nickname"
|
||||
testRowsAffected int64 = 1
|
||||
testInsertId int64 = 1
|
||||
)
|
||||
|
||||
var data User
|
||||
data.Id = testInsertId
|
||||
data.User = testUser
|
||||
data.Name = "gozero"
|
||||
data.Password = testPassword
|
||||
data.Mobile = testMobile
|
||||
data.Gender = testGender
|
||||
data.Nickname = testNickname
|
||||
data.CreateTime = testTimeValue
|
||||
data.UpdateTime = testTimeValue
|
||||
|
||||
err := mockUser(func(mock sqlmock.Sqlmock) {
|
||||
mock.ExpectExec(fmt.Sprintf("insert into %s", testTable)).
|
||||
WithArgs(data.User, data.Name, data.Password, data.Mobile, data.Gender, data.Nickname).
|
||||
WillReturnResult(sqlmock.NewResult(testInsertId, testRowsAffected))
|
||||
}, func(m UserModel) {
|
||||
r, err := m.Insert(data)
|
||||
assert.Nil(t, err)
|
||||
|
||||
lastInsertId, err := r.LastInsertId()
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, testInsertId, lastInsertId)
|
||||
|
||||
rowsAffected, err := r.RowsAffected()
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, testRowsAffected, rowsAffected)
|
||||
})
|
||||
assert.Nil(t, err)
|
||||
|
||||
err = mockUser(func(mock sqlmock.Sqlmock) {
|
||||
mock.ExpectQuery(fmt.Sprintf("select (.+) from %s", testTable)).
|
||||
WithArgs(testInsertId).
|
||||
WillReturnRows(sqlmock.NewRows([]string{"id", "user", "name", "password", "mobile", "gender", "nickname", "create_time", "update_time"}).AddRow(testInsertId, data.User, data.Name, data.Password, data.Mobile, data.Gender, data.Nickname, testTimeValue, testTimeValue))
|
||||
}, func(m UserModel) {
|
||||
result, err := m.FindOne(testInsertId)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, *result, data)
|
||||
})
|
||||
assert.Nil(t, err)
|
||||
|
||||
err = mockUser(func(mock sqlmock.Sqlmock) {
|
||||
mock.ExpectExec(fmt.Sprintf("update %s", testTable)).WithArgs(data.User, testUpdateName, data.Password, data.Mobile, data.Gender, data.Nickname, testInsertId).WillReturnResult(sqlmock.NewResult(testInsertId, testRowsAffected))
|
||||
}, func(m UserModel) {
|
||||
data.Name = testUpdateName
|
||||
err := m.Update(data)
|
||||
assert.Nil(t, err)
|
||||
})
|
||||
assert.Nil(t, err)
|
||||
|
||||
err = mockUser(func(mock sqlmock.Sqlmock) {
|
||||
mock.ExpectQuery(fmt.Sprintf("select (.+) from %s ", testTable)).
|
||||
WithArgs(testInsertId).
|
||||
WillReturnRows(sqlmock.NewRows([]string{"id", "user", "name", "password", "mobile", "gender", "nickname", "create_time", "update_time"}).AddRow(testInsertId, data.User, data.Name, data.Password, data.Mobile, data.Gender, data.Nickname, testTimeValue, testTimeValue))
|
||||
}, func(m UserModel) {
|
||||
result, err := m.FindOne(testInsertId)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, *result, data)
|
||||
})
|
||||
assert.Nil(t, err)
|
||||
|
||||
err = mockUser(func(mock sqlmock.Sqlmock) {
|
||||
mock.ExpectExec(fmt.Sprintf("delete from %s where `id` = ?", testTable)).WithArgs(testInsertId).WillReturnResult(sqlmock.NewResult(testInsertId, testRowsAffected))
|
||||
}, func(m UserModel) {
|
||||
err := m.Delete(testInsertId)
|
||||
assert.Nil(t, err)
|
||||
})
|
||||
assert.Nil(t, err)
|
||||
}
|
||||
|
||||
// with cache
|
||||
func mockStudent(mockFn func(mock sqlmock.Sqlmock), fn func(m StudentModel)) error {
|
||||
db, mock, err := sqlmock.New()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
defer db.Close()
|
||||
|
||||
mock.ExpectBegin()
|
||||
mockFn(mock)
|
||||
mock.ExpectCommit()
|
||||
|
||||
conn := mocksql.NewMockConn(db)
|
||||
r, clean, err := redistest.CreateRedis()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
defer clean()
|
||||
|
||||
m := NewStudentModel(conn, cache.CacheConf{
|
||||
{
|
||||
RedisConf: redis.RedisConf{
|
||||
Host: r.Addr,
|
||||
Type: "node",
|
||||
},
|
||||
Weight: 100,
|
||||
},
|
||||
})
|
||||
fn(m)
|
||||
return nil
|
||||
}
|
||||
|
||||
// without cache
|
||||
func mockUser(mockFn func(mock sqlmock.Sqlmock), fn func(m UserModel)) error {
|
||||
db, mock, err := sqlmock.New()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
defer db.Close()
|
||||
|
||||
mock.ExpectBegin()
|
||||
mockFn(mock)
|
||||
mock.ExpectCommit()
|
||||
|
||||
conn := mocksql.NewMockConn(db)
|
||||
m := NewUserModel(conn)
|
||||
fn(m)
|
||||
return nil
|
||||
}
|
||||
105
tools/goctl/model/sql/test/model/studentmodel.go
Executable file
105
tools/goctl/model/sql/test/model/studentmodel.go
Executable file
@@ -0,0 +1,105 @@
|
||||
package model
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/tal-tech/go-zero/core/stores/cache"
|
||||
"github.com/tal-tech/go-zero/core/stores/sqlc"
|
||||
"github.com/tal-tech/go-zero/core/stores/sqlx"
|
||||
"github.com/tal-tech/go-zero/core/stringx"
|
||||
"github.com/tal-tech/go-zero/tools/goctl/model/sql/builderx"
|
||||
)
|
||||
|
||||
var (
|
||||
studentFieldNames = builderx.FieldNames(&Student{})
|
||||
studentRows = strings.Join(studentFieldNames, ",")
|
||||
studentRowsExpectAutoSet = strings.Join(stringx.Remove(studentFieldNames, "`id`", "`create_time`", "`update_time`"), ",")
|
||||
studentRowsWithPlaceHolder = strings.Join(stringx.Remove(studentFieldNames, "`id`", "`create_time`", "`update_time`"), "=?,") + "=?"
|
||||
|
||||
cacheStudentIdPrefix = "cache#Student#id#"
|
||||
)
|
||||
|
||||
type (
|
||||
StudentModel interface {
|
||||
Insert(data Student) (sql.Result, error)
|
||||
FindOne(id int64) (*Student, error)
|
||||
Update(data Student) error
|
||||
Delete(id int64) error
|
||||
}
|
||||
|
||||
defaultStudentModel struct {
|
||||
sqlc.CachedConn
|
||||
table string
|
||||
}
|
||||
|
||||
Student struct {
|
||||
Id int64 `db:"id"`
|
||||
Name string `db:"name"`
|
||||
Age sql.NullInt64 `db:"age"`
|
||||
Score sql.NullFloat64 `db:"score"`
|
||||
CreateTime time.Time `db:"create_time"`
|
||||
UpdateTime sql.NullTime `db:"update_time"`
|
||||
}
|
||||
)
|
||||
|
||||
func NewStudentModel(conn sqlx.SqlConn, c cache.CacheConf) StudentModel {
|
||||
return &defaultStudentModel{
|
||||
CachedConn: sqlc.NewConn(conn, c),
|
||||
table: "`student`",
|
||||
}
|
||||
}
|
||||
|
||||
func (m *defaultStudentModel) Insert(data Student) (sql.Result, error) {
|
||||
query := fmt.Sprintf("insert into %s (%s) values (?, ?, ?)", m.table, studentRowsExpectAutoSet)
|
||||
ret, err := m.ExecNoCache(query, data.Name, data.Age, data.Score)
|
||||
|
||||
return ret, err
|
||||
}
|
||||
|
||||
func (m *defaultStudentModel) FindOne(id int64) (*Student, error) {
|
||||
studentIdKey := fmt.Sprintf("%s%v", cacheStudentIdPrefix, id)
|
||||
var resp Student
|
||||
err := m.QueryRow(&resp, studentIdKey, func(conn sqlx.SqlConn, v interface{}) error {
|
||||
query := fmt.Sprintf("select %s from %s where `id` = ? limit 1", studentRows, m.table)
|
||||
return conn.QueryRow(v, query, id)
|
||||
})
|
||||
switch err {
|
||||
case nil:
|
||||
return &resp, nil
|
||||
case sqlc.ErrNotFound:
|
||||
return nil, ErrNotFound
|
||||
default:
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
func (m *defaultStudentModel) Update(data Student) error {
|
||||
studentIdKey := fmt.Sprintf("%s%v", cacheStudentIdPrefix, data.Id)
|
||||
_, err := m.Exec(func(conn sqlx.SqlConn) (result sql.Result, err error) {
|
||||
query := fmt.Sprintf("update %s set %s where `id` = ?", m.table, studentRowsWithPlaceHolder)
|
||||
return conn.Exec(query, data.Name, data.Age, data.Score, data.Id)
|
||||
}, studentIdKey)
|
||||
return err
|
||||
}
|
||||
|
||||
func (m *defaultStudentModel) Delete(id int64) error {
|
||||
|
||||
studentIdKey := fmt.Sprintf("%s%v", cacheStudentIdPrefix, id)
|
||||
_, err := m.Exec(func(conn sqlx.SqlConn) (result sql.Result, err error) {
|
||||
query := fmt.Sprintf("delete from %s where `id` = ?", m.table)
|
||||
return conn.Exec(query, id)
|
||||
}, studentIdKey)
|
||||
return err
|
||||
}
|
||||
|
||||
func (m *defaultStudentModel) formatPrimary(primary interface{}) string {
|
||||
return fmt.Sprintf("%s%v", cacheStudentIdPrefix, primary)
|
||||
}
|
||||
|
||||
func (m *defaultStudentModel) queryPrimary(conn sqlx.SqlConn, v, primary interface{}) error {
|
||||
query := fmt.Sprintf("select %s from %s where `id` = ? limit 1", studentRows, m.table)
|
||||
return conn.QueryRow(v, query, primary)
|
||||
}
|
||||
130
tools/goctl/model/sql/test/model/usermodel.go
Executable file
130
tools/goctl/model/sql/test/model/usermodel.go
Executable file
@@ -0,0 +1,130 @@
|
||||
package model
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/tal-tech/go-zero/core/stores/sqlc"
|
||||
"github.com/tal-tech/go-zero/core/stores/sqlx"
|
||||
"github.com/tal-tech/go-zero/core/stringx"
|
||||
"github.com/tal-tech/go-zero/tools/goctl/model/sql/builderx"
|
||||
)
|
||||
|
||||
var (
|
||||
userFieldNames = builderx.FieldNames(&User{})
|
||||
userRows = strings.Join(userFieldNames, ",")
|
||||
userRowsExpectAutoSet = strings.Join(stringx.Remove(userFieldNames, "`id`", "`create_time`", "`update_time`"), ",")
|
||||
userRowsWithPlaceHolder = strings.Join(stringx.Remove(userFieldNames, "`id`", "`create_time`", "`update_time`"), "=?,") + "=?"
|
||||
)
|
||||
|
||||
type (
|
||||
UserModel interface {
|
||||
Insert(data User) (sql.Result, error)
|
||||
FindOne(id int64) (*User, error)
|
||||
FindOneByUser(user string) (*User, error)
|
||||
FindOneByName(name string) (*User, error)
|
||||
FindOneByMobile(mobile string) (*User, error)
|
||||
Update(data User) error
|
||||
Delete(id int64) error
|
||||
}
|
||||
|
||||
defaultUserModel struct {
|
||||
conn sqlx.SqlConn
|
||||
table string
|
||||
}
|
||||
|
||||
User struct {
|
||||
Id int64 `db:"id"`
|
||||
User string `db:"user"` // 用户
|
||||
Name string `db:"name"` // 用户名称
|
||||
Password string `db:"password"` // 用户密码
|
||||
Mobile string `db:"mobile"` // 手机号
|
||||
Gender string `db:"gender"` // 男|女|未公开
|
||||
Nickname string `db:"nickname"` // 用户昵称
|
||||
CreateTime time.Time `db:"create_time"`
|
||||
UpdateTime time.Time `db:"update_time"`
|
||||
}
|
||||
)
|
||||
|
||||
func NewUserModel(conn sqlx.SqlConn) UserModel {
|
||||
return &defaultUserModel{
|
||||
conn: conn,
|
||||
table: "`user`",
|
||||
}
|
||||
}
|
||||
|
||||
func (m *defaultUserModel) Insert(data User) (sql.Result, error) {
|
||||
query := fmt.Sprintf("insert into %s (%s) values (?, ?, ?, ?, ?, ?)", m.table, userRowsExpectAutoSet)
|
||||
ret, err := m.conn.Exec(query, data.User, data.Name, data.Password, data.Mobile, data.Gender, data.Nickname)
|
||||
return ret, err
|
||||
}
|
||||
|
||||
func (m *defaultUserModel) FindOne(id int64) (*User, error) {
|
||||
query := fmt.Sprintf("select %s from %s where `id` = ? limit 1", userRows, m.table)
|
||||
var resp User
|
||||
err := m.conn.QueryRow(&resp, query, id)
|
||||
switch err {
|
||||
case nil:
|
||||
return &resp, nil
|
||||
case sqlc.ErrNotFound:
|
||||
return nil, ErrNotFound
|
||||
default:
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
func (m *defaultUserModel) FindOneByUser(user string) (*User, error) {
|
||||
var resp User
|
||||
query := fmt.Sprintf("select %s from %s where `user` = ? limit 1", userRows, m.table)
|
||||
err := m.conn.QueryRow(&resp, query, user)
|
||||
switch err {
|
||||
case nil:
|
||||
return &resp, nil
|
||||
case sqlc.ErrNotFound:
|
||||
return nil, ErrNotFound
|
||||
default:
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
func (m *defaultUserModel) FindOneByName(name string) (*User, error) {
|
||||
var resp User
|
||||
query := fmt.Sprintf("select %s from %s where `name` = ? limit 1", userRows, m.table)
|
||||
err := m.conn.QueryRow(&resp, query, name)
|
||||
switch err {
|
||||
case nil:
|
||||
return &resp, nil
|
||||
case sqlc.ErrNotFound:
|
||||
return nil, ErrNotFound
|
||||
default:
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
func (m *defaultUserModel) FindOneByMobile(mobile string) (*User, error) {
|
||||
var resp User
|
||||
query := fmt.Sprintf("select %s from %s where `mobile` = ? limit 1", userRows, m.table)
|
||||
err := m.conn.QueryRow(&resp, query, mobile)
|
||||
switch err {
|
||||
case nil:
|
||||
return &resp, nil
|
||||
case sqlc.ErrNotFound:
|
||||
return nil, ErrNotFound
|
||||
default:
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
func (m *defaultUserModel) Update(data User) error {
|
||||
query := fmt.Sprintf("update %s set %s where `id` = ?", m.table, userRowsWithPlaceHolder)
|
||||
_, err := m.conn.Exec(query, data.User, data.Name, data.Password, data.Mobile, data.Gender, data.Nickname, data.Id)
|
||||
return err
|
||||
}
|
||||
|
||||
func (m *defaultUserModel) Delete(id int64) error {
|
||||
query := fmt.Sprintf("delete from %s where `id` = ?", m.table)
|
||||
_, err := m.conn.Exec(query, id)
|
||||
return err
|
||||
}
|
||||
5
tools/goctl/model/sql/test/model/vars.go
Normal file
5
tools/goctl/model/sql/test/model/vars.go
Normal file
@@ -0,0 +1,5 @@
|
||||
package model
|
||||
|
||||
import "github.com/tal-tech/go-zero/core/stores/sqlx"
|
||||
|
||||
var ErrNotFound = sqlx.ErrNotFound
|
||||
255
tools/goctl/model/sql/test/orm.go
Normal file
255
tools/goctl/model/sql/test/orm.go
Normal file
@@ -0,0 +1,255 @@
|
||||
// copy from core/stores/sqlx/orm.go
|
||||
package mocksql
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"reflect"
|
||||
"strings"
|
||||
|
||||
"github.com/tal-tech/go-zero/core/mapping"
|
||||
)
|
||||
|
||||
const tagName = "db"
|
||||
|
||||
var (
|
||||
ErrNotMatchDestination = errors.New("not matching destination to scan")
|
||||
ErrNotReadableValue = errors.New("value not addressable or interfaceable")
|
||||
ErrNotSettable = errors.New("passed in variable is not settable")
|
||||
ErrUnsupportedValueType = errors.New("unsupported unmarshal type")
|
||||
)
|
||||
|
||||
type rowsScanner interface {
|
||||
Columns() ([]string, error)
|
||||
Err() error
|
||||
Next() bool
|
||||
Scan(v ...interface{}) error
|
||||
}
|
||||
|
||||
func getTaggedFieldValueMap(v reflect.Value) (map[string]interface{}, error) {
|
||||
rt := mapping.Deref(v.Type())
|
||||
size := rt.NumField()
|
||||
result := make(map[string]interface{}, size)
|
||||
|
||||
for i := 0; i < size; i++ {
|
||||
key := parseTagName(rt.Field(i))
|
||||
if len(key) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
valueField := reflect.Indirect(v).Field(i)
|
||||
switch valueField.Kind() {
|
||||
case reflect.Ptr:
|
||||
if !valueField.CanInterface() {
|
||||
return nil, ErrNotReadableValue
|
||||
}
|
||||
if valueField.IsNil() {
|
||||
baseValueType := mapping.Deref(valueField.Type())
|
||||
valueField.Set(reflect.New(baseValueType))
|
||||
}
|
||||
result[key] = valueField.Interface()
|
||||
default:
|
||||
if !valueField.CanAddr() || !valueField.Addr().CanInterface() {
|
||||
return nil, ErrNotReadableValue
|
||||
}
|
||||
result[key] = valueField.Addr().Interface()
|
||||
}
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func mapStructFieldsIntoSlice(v reflect.Value, columns []string, strict bool) ([]interface{}, error) {
|
||||
fields := unwrapFields(v)
|
||||
if strict && len(columns) < len(fields) {
|
||||
return nil, ErrNotMatchDestination
|
||||
}
|
||||
|
||||
taggedMap, err := getTaggedFieldValueMap(v)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
values := make([]interface{}, len(columns))
|
||||
if len(taggedMap) == 0 {
|
||||
for i := 0; i < len(values); i++ {
|
||||
valueField := fields[i]
|
||||
switch valueField.Kind() {
|
||||
case reflect.Ptr:
|
||||
if !valueField.CanInterface() {
|
||||
return nil, ErrNotReadableValue
|
||||
}
|
||||
if valueField.IsNil() {
|
||||
baseValueType := mapping.Deref(valueField.Type())
|
||||
valueField.Set(reflect.New(baseValueType))
|
||||
}
|
||||
values[i] = valueField.Interface()
|
||||
default:
|
||||
if !valueField.CanAddr() || !valueField.Addr().CanInterface() {
|
||||
return nil, ErrNotReadableValue
|
||||
}
|
||||
values[i] = valueField.Addr().Interface()
|
||||
}
|
||||
}
|
||||
} else {
|
||||
for i, column := range columns {
|
||||
if tagged, ok := taggedMap[column]; ok {
|
||||
values[i] = tagged
|
||||
} else {
|
||||
var anonymous interface{}
|
||||
values[i] = &anonymous
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return values, nil
|
||||
}
|
||||
|
||||
func parseTagName(field reflect.StructField) string {
|
||||
key := field.Tag.Get(tagName)
|
||||
if len(key) == 0 {
|
||||
return ""
|
||||
} else {
|
||||
options := strings.Split(key, ",")
|
||||
return options[0]
|
||||
}
|
||||
}
|
||||
|
||||
func unmarshalRow(v interface{}, scanner rowsScanner, strict bool) error {
|
||||
if !scanner.Next() {
|
||||
if err := scanner.Err(); err != nil {
|
||||
return err
|
||||
}
|
||||
return ErrNotFound
|
||||
}
|
||||
|
||||
rv := reflect.ValueOf(v)
|
||||
if err := mapping.ValidatePtr(&rv); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
rte := reflect.TypeOf(v).Elem()
|
||||
rve := rv.Elem()
|
||||
switch rte.Kind() {
|
||||
case reflect.Bool,
|
||||
reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64,
|
||||
reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64,
|
||||
reflect.Float32, reflect.Float64,
|
||||
reflect.String:
|
||||
if rve.CanSet() {
|
||||
return scanner.Scan(v)
|
||||
} else {
|
||||
return ErrNotSettable
|
||||
}
|
||||
case reflect.Struct:
|
||||
columns, err := scanner.Columns()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if values, err := mapStructFieldsIntoSlice(rve, columns, strict); err != nil {
|
||||
return err
|
||||
} else {
|
||||
return scanner.Scan(values...)
|
||||
}
|
||||
default:
|
||||
return ErrUnsupportedValueType
|
||||
}
|
||||
}
|
||||
|
||||
func unmarshalRows(v interface{}, scanner rowsScanner, strict bool) error {
|
||||
rv := reflect.ValueOf(v)
|
||||
if err := mapping.ValidatePtr(&rv); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
rt := reflect.TypeOf(v)
|
||||
rte := rt.Elem()
|
||||
rve := rv.Elem()
|
||||
switch rte.Kind() {
|
||||
case reflect.Slice:
|
||||
if rve.CanSet() {
|
||||
ptr := rte.Elem().Kind() == reflect.Ptr
|
||||
appendFn := func(item reflect.Value) {
|
||||
if ptr {
|
||||
rve.Set(reflect.Append(rve, item))
|
||||
} else {
|
||||
rve.Set(reflect.Append(rve, reflect.Indirect(item)))
|
||||
}
|
||||
}
|
||||
fillFn := func(value interface{}) error {
|
||||
if rve.CanSet() {
|
||||
if err := scanner.Scan(value); err != nil {
|
||||
return err
|
||||
} else {
|
||||
appendFn(reflect.ValueOf(value))
|
||||
return nil
|
||||
}
|
||||
}
|
||||
return ErrNotSettable
|
||||
}
|
||||
|
||||
base := mapping.Deref(rte.Elem())
|
||||
switch base.Kind() {
|
||||
case reflect.Bool,
|
||||
reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64,
|
||||
reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64,
|
||||
reflect.Float32, reflect.Float64,
|
||||
reflect.String:
|
||||
for scanner.Next() {
|
||||
value := reflect.New(base)
|
||||
if err := fillFn(value.Interface()); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
case reflect.Struct:
|
||||
columns, err := scanner.Columns()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for scanner.Next() {
|
||||
value := reflect.New(base)
|
||||
if values, err := mapStructFieldsIntoSlice(value, columns, strict); err != nil {
|
||||
return err
|
||||
} else {
|
||||
if err := scanner.Scan(values...); err != nil {
|
||||
return err
|
||||
} else {
|
||||
appendFn(value)
|
||||
}
|
||||
}
|
||||
}
|
||||
default:
|
||||
return ErrUnsupportedValueType
|
||||
}
|
||||
|
||||
return nil
|
||||
} else {
|
||||
return ErrNotSettable
|
||||
}
|
||||
default:
|
||||
return ErrUnsupportedValueType
|
||||
}
|
||||
}
|
||||
|
||||
func unwrapFields(v reflect.Value) []reflect.Value {
|
||||
var fields []reflect.Value
|
||||
indirect := reflect.Indirect(v)
|
||||
|
||||
for i := 0; i < indirect.NumField(); i++ {
|
||||
child := indirect.Field(i)
|
||||
if child.Kind() == reflect.Ptr && child.IsNil() {
|
||||
baseValueType := mapping.Deref(child.Type())
|
||||
child.Set(reflect.New(baseValueType))
|
||||
}
|
||||
|
||||
child = reflect.Indirect(child)
|
||||
childType := indirect.Type().Field(i)
|
||||
if child.Kind() == reflect.Struct && childType.Anonymous {
|
||||
fields = append(fields, unwrapFields(child)...)
|
||||
} else {
|
||||
fields = append(fields, child)
|
||||
}
|
||||
}
|
||||
|
||||
return fields
|
||||
}
|
||||
90
tools/goctl/model/sql/test/sqlconn.go
Normal file
90
tools/goctl/model/sql/test/sqlconn.go
Normal file
@@ -0,0 +1,90 @@
|
||||
// copy from core/stores/sqlx/sqlconn.go
|
||||
package mocksql
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
|
||||
"github.com/tal-tech/go-zero/core/stores/sqlx"
|
||||
)
|
||||
|
||||
type (
|
||||
MockConn struct {
|
||||
db *sql.DB
|
||||
}
|
||||
statement struct {
|
||||
stmt *sql.Stmt
|
||||
}
|
||||
)
|
||||
|
||||
func NewMockConn(db *sql.DB) *MockConn {
|
||||
return &MockConn{db: db}
|
||||
}
|
||||
|
||||
func (conn *MockConn) Exec(query string, args ...interface{}) (sql.Result, error) {
|
||||
return exec(conn.db, query, args...)
|
||||
}
|
||||
|
||||
func (conn *MockConn) Prepare(query string) (sqlx.StmtSession, error) {
|
||||
st, err := conn.db.Prepare(query)
|
||||
return statement{stmt: st}, err
|
||||
}
|
||||
|
||||
func (conn *MockConn) QueryRow(v interface{}, q string, args ...interface{}) error {
|
||||
return query(conn.db, func(rows *sql.Rows) error {
|
||||
return unmarshalRow(v, rows, true)
|
||||
}, q, args...)
|
||||
}
|
||||
|
||||
func (conn *MockConn) QueryRowPartial(v interface{}, q string, args ...interface{}) error {
|
||||
return query(conn.db, func(rows *sql.Rows) error {
|
||||
return unmarshalRow(v, rows, false)
|
||||
}, q, args...)
|
||||
}
|
||||
|
||||
func (conn *MockConn) QueryRows(v interface{}, q string, args ...interface{}) error {
|
||||
return query(conn.db, func(rows *sql.Rows) error {
|
||||
return unmarshalRows(v, rows, true)
|
||||
}, q, args...)
|
||||
}
|
||||
|
||||
func (conn *MockConn) QueryRowsPartial(v interface{}, q string, args ...interface{}) error {
|
||||
return query(conn.db, func(rows *sql.Rows) error {
|
||||
return unmarshalRows(v, rows, false)
|
||||
}, q, args...)
|
||||
}
|
||||
|
||||
func (conn *MockConn) Transact(func(session sqlx.Session) error) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s statement) Close() error {
|
||||
return s.stmt.Close()
|
||||
}
|
||||
|
||||
func (s statement) Exec(args ...interface{}) (sql.Result, error) {
|
||||
return execStmt(s.stmt, args...)
|
||||
}
|
||||
|
||||
func (s statement) QueryRow(v interface{}, args ...interface{}) error {
|
||||
return queryStmt(s.stmt, func(rows *sql.Rows) error {
|
||||
return unmarshalRow(v, rows, true)
|
||||
}, args...)
|
||||
}
|
||||
|
||||
func (s statement) QueryRowPartial(v interface{}, args ...interface{}) error {
|
||||
return queryStmt(s.stmt, func(rows *sql.Rows) error {
|
||||
return unmarshalRow(v, rows, false)
|
||||
}, args...)
|
||||
}
|
||||
|
||||
func (s statement) QueryRows(v interface{}, args ...interface{}) error {
|
||||
return queryStmt(s.stmt, func(rows *sql.Rows) error {
|
||||
return unmarshalRows(v, rows, true)
|
||||
}, args...)
|
||||
}
|
||||
|
||||
func (s statement) QueryRowsPartial(v interface{}, args ...interface{}) error {
|
||||
return queryStmt(s.stmt, func(rows *sql.Rows) error {
|
||||
return unmarshalRows(v, rows, false)
|
||||
}, args...)
|
||||
}
|
||||
122
tools/goctl/model/sql/test/stmt.go
Normal file
122
tools/goctl/model/sql/test/stmt.go
Normal file
@@ -0,0 +1,122 @@
|
||||
// copy from core/stores/sqlx/stmt.go
|
||||
|
||||
package mocksql
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/tal-tech/go-zero/core/logx"
|
||||
"github.com/tal-tech/go-zero/core/timex"
|
||||
)
|
||||
|
||||
const slowThreshold = time.Millisecond * 500
|
||||
|
||||
func exec(db *sql.DB, q string, args ...interface{}) (sql.Result, error) {
|
||||
tx, err := db.Begin()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
defer func() {
|
||||
switch err {
|
||||
case nil:
|
||||
err = tx.Commit()
|
||||
default:
|
||||
tx.Rollback()
|
||||
}
|
||||
}()
|
||||
|
||||
stmt, err := format(q, args...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
startTime := timex.Now()
|
||||
result, err := tx.Exec(q, args...)
|
||||
duration := timex.Since(startTime)
|
||||
if duration > slowThreshold {
|
||||
logx.WithDuration(duration).Slowf("[SQL] exec: slowcall - %s", stmt)
|
||||
} else {
|
||||
logx.WithDuration(duration).Infof("sql exec: %s", stmt)
|
||||
}
|
||||
if err != nil {
|
||||
logSqlError(stmt, err)
|
||||
}
|
||||
|
||||
return result, err
|
||||
}
|
||||
|
||||
func execStmt(conn *sql.Stmt, args ...interface{}) (sql.Result, error) {
|
||||
stmt := fmt.Sprint(args...)
|
||||
startTime := timex.Now()
|
||||
result, err := conn.Exec(args...)
|
||||
duration := timex.Since(startTime)
|
||||
if duration > slowThreshold {
|
||||
logx.WithDuration(duration).Slowf("[SQL] execStmt: slowcall - %s", stmt)
|
||||
} else {
|
||||
logx.WithDuration(duration).Infof("sql execStmt: %s", stmt)
|
||||
}
|
||||
if err != nil {
|
||||
logSqlError(stmt, err)
|
||||
}
|
||||
|
||||
return result, err
|
||||
}
|
||||
|
||||
func query(db *sql.DB, scanner func(*sql.Rows) error, q string, args ...interface{}) error {
|
||||
tx, err := db.Begin()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
defer func() {
|
||||
switch err {
|
||||
case nil:
|
||||
err = tx.Commit()
|
||||
default:
|
||||
tx.Rollback()
|
||||
}
|
||||
}()
|
||||
|
||||
stmt, err := format(q, args...)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
startTime := timex.Now()
|
||||
rows, err := tx.Query(q, args...)
|
||||
duration := timex.Since(startTime)
|
||||
if duration > slowThreshold {
|
||||
logx.WithDuration(duration).Slowf("[SQL] query: slowcall - %s", stmt)
|
||||
} else {
|
||||
logx.WithDuration(duration).Infof("sql query: %s", stmt)
|
||||
}
|
||||
if err != nil {
|
||||
logSqlError(stmt, err)
|
||||
return err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
return scanner(rows)
|
||||
}
|
||||
|
||||
func queryStmt(conn *sql.Stmt, scanner func(*sql.Rows) error, args ...interface{}) error {
|
||||
stmt := fmt.Sprint(args...)
|
||||
startTime := timex.Now()
|
||||
rows, err := conn.Query(args...)
|
||||
duration := timex.Since(startTime)
|
||||
if duration > slowThreshold {
|
||||
logx.WithDuration(duration).Slowf("[SQL] queryStmt: slowcall - %s", stmt)
|
||||
} else {
|
||||
logx.WithDuration(duration).Infof("sql queryStmt: %s", stmt)
|
||||
}
|
||||
if err != nil {
|
||||
logSqlError(stmt, err)
|
||||
return err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
return scanner(rows)
|
||||
}
|
||||
105
tools/goctl/model/sql/test/utils.go
Normal file
105
tools/goctl/model/sql/test/utils.go
Normal file
@@ -0,0 +1,105 @@
|
||||
// copy from core/stores/sqlx/utils.go
|
||||
package mocksql
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/tal-tech/go-zero/core/logx"
|
||||
"github.com/tal-tech/go-zero/core/mapping"
|
||||
)
|
||||
|
||||
var ErrNotFound = sql.ErrNoRows
|
||||
|
||||
func desensitize(datasource string) string {
|
||||
// remove account
|
||||
pos := strings.LastIndex(datasource, "@")
|
||||
if 0 <= pos && pos+1 < len(datasource) {
|
||||
datasource = datasource[pos+1:]
|
||||
}
|
||||
|
||||
return datasource
|
||||
}
|
||||
|
||||
func escape(input string) string {
|
||||
var b strings.Builder
|
||||
|
||||
for _, ch := range input {
|
||||
switch ch {
|
||||
case '\x00':
|
||||
b.WriteString(`\x00`)
|
||||
case '\r':
|
||||
b.WriteString(`\r`)
|
||||
case '\n':
|
||||
b.WriteString(`\n`)
|
||||
case '\\':
|
||||
b.WriteString(`\\`)
|
||||
case '\'':
|
||||
b.WriteString(`\'`)
|
||||
case '"':
|
||||
b.WriteString(`\"`)
|
||||
case '\x1a':
|
||||
b.WriteString(`\x1a`)
|
||||
default:
|
||||
b.WriteRune(ch)
|
||||
}
|
||||
}
|
||||
|
||||
return b.String()
|
||||
}
|
||||
|
||||
func format(query string, args ...interface{}) (string, error) {
|
||||
numArgs := len(args)
|
||||
if numArgs == 0 {
|
||||
return query, nil
|
||||
}
|
||||
|
||||
var b strings.Builder
|
||||
argIndex := 0
|
||||
|
||||
for _, ch := range query {
|
||||
if ch == '?' {
|
||||
if argIndex >= numArgs {
|
||||
return "", fmt.Errorf("error: %d ? in sql, but less arguments provided", argIndex)
|
||||
}
|
||||
|
||||
arg := args[argIndex]
|
||||
argIndex++
|
||||
|
||||
switch v := arg.(type) {
|
||||
case bool:
|
||||
if v {
|
||||
b.WriteByte('1')
|
||||
} else {
|
||||
b.WriteByte('0')
|
||||
}
|
||||
case string:
|
||||
b.WriteByte('\'')
|
||||
b.WriteString(escape(v))
|
||||
b.WriteByte('\'')
|
||||
default:
|
||||
b.WriteString(mapping.Repr(v))
|
||||
}
|
||||
} else {
|
||||
b.WriteRune(ch)
|
||||
}
|
||||
}
|
||||
|
||||
if argIndex < numArgs {
|
||||
return "", fmt.Errorf("error: %d ? in sql, but more arguments provided", argIndex)
|
||||
}
|
||||
|
||||
return b.String(), nil
|
||||
}
|
||||
|
||||
func logInstanceError(datasource string, err error) {
|
||||
datasource = desensitize(datasource)
|
||||
logx.Errorf("Error on getting sql instance of %s: %v", datasource, err)
|
||||
}
|
||||
|
||||
func logSqlError(stmt string, err error) {
|
||||
if err != nil && err != ErrNotFound {
|
||||
logx.Errorf("stmt: %s, error: %s", stmt, err.Error())
|
||||
}
|
||||
}
|
||||
@@ -25,9 +25,10 @@ const (
|
||||
)
|
||||
|
||||
type Plugin struct {
|
||||
Api *spec.ApiSpec
|
||||
Style string
|
||||
Dir string
|
||||
Api *spec.ApiSpec
|
||||
ApiFilePath string
|
||||
Style string
|
||||
Dir string
|
||||
}
|
||||
|
||||
func PluginCommand(c *cli.Context) error {
|
||||
@@ -86,6 +87,12 @@ func prepareArgs(c *cli.Context) ([]byte, error) {
|
||||
transferData.Api = api
|
||||
}
|
||||
|
||||
absApiFilePath, err := filepath.Abs(apiPath)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
transferData.ApiFilePath = absApiFilePath
|
||||
dirAbs, err := filepath.Abs(c.String("dir"))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"go/build"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
@@ -44,7 +45,9 @@ func TestRpcGenerate(t *testing.T) {
|
||||
assert.Nil(t, err)
|
||||
_, err = execx.Run("go test "+projectName, projectDir)
|
||||
if err != nil {
|
||||
assert.Contains(t, err.Error(), "not in GOROOT")
|
||||
assert.True(t, func() bool {
|
||||
return strings.Contains(err.Error(), "not in GOROOT") || strings.Contains(err.Error(), "cannot find package")
|
||||
}())
|
||||
}
|
||||
|
||||
// case go mod
|
||||
@@ -61,7 +64,9 @@ func TestRpcGenerate(t *testing.T) {
|
||||
assert.Nil(t, err)
|
||||
_, err = execx.Run("go test "+projectName, projectDir)
|
||||
if err != nil {
|
||||
assert.Contains(t, err.Error(), "not in GOROOT")
|
||||
assert.True(t, func() bool {
|
||||
return strings.Contains(err.Error(), "not in GOROOT") || strings.Contains(err.Error(), "cannot find package")
|
||||
}())
|
||||
}
|
||||
|
||||
// case not in go mod and go path
|
||||
@@ -69,7 +74,9 @@ func TestRpcGenerate(t *testing.T) {
|
||||
assert.Nil(t, err)
|
||||
_, err = execx.Run("go test "+projectName, projectDir)
|
||||
if err != nil {
|
||||
assert.Contains(t, err.Error(), "not in GOROOT")
|
||||
assert.True(t, func() bool {
|
||||
return strings.Contains(err.Error(), "not in GOROOT") || strings.Contains(err.Error(), "cannot find package")
|
||||
}())
|
||||
}
|
||||
|
||||
// invalid directory
|
||||
|
||||
@@ -15,12 +15,12 @@ const svcTemplate = `package svc
|
||||
import {{.imports}}
|
||||
|
||||
type ServiceContext struct {
|
||||
c config.Config
|
||||
Config config.Config
|
||||
}
|
||||
|
||||
func NewServiceContext(c config.Config) *ServiceContext {
|
||||
return &ServiceContext{
|
||||
c:c,
|
||||
Config:c,
|
||||
}
|
||||
}
|
||||
`
|
||||
|
||||
@@ -54,14 +54,15 @@ func Clean() error {
|
||||
return util.Clean(category)
|
||||
}
|
||||
|
||||
func Update(category string) error {
|
||||
func Update() error {
|
||||
err := Clean()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return util.InitTemplates(category, templates)
|
||||
}
|
||||
|
||||
func GetCategory() string {
|
||||
func Category() string {
|
||||
return category
|
||||
}
|
||||
|
||||
@@ -97,8 +97,7 @@ func TestUpdate(t *testing.T) {
|
||||
}
|
||||
assert.Equal(t, "modify", string(data))
|
||||
|
||||
err = Update(category)
|
||||
assert.Nil(t, err)
|
||||
assert.Nil(t, Update())
|
||||
|
||||
data, err = ioutil.ReadFile(mainTpl)
|
||||
if err != nil {
|
||||
@@ -109,6 +108,6 @@ func TestUpdate(t *testing.T) {
|
||||
|
||||
func TestGetCategory(t *testing.T) {
|
||||
_ = Clean()
|
||||
result := GetCategory()
|
||||
result := Category()
|
||||
assert.Equal(t, category, result)
|
||||
}
|
||||
|
||||
@@ -76,12 +76,16 @@ func UpdateTemplates(ctx *cli.Context) (err error) {
|
||||
}
|
||||
}()
|
||||
switch category {
|
||||
case gogen.GetCategory():
|
||||
return gogen.Update(category)
|
||||
case rpcgen.GetCategory():
|
||||
return rpcgen.Update(category)
|
||||
case modelgen.GetCategory():
|
||||
return modelgen.Update(category)
|
||||
case docker.Category():
|
||||
return docker.Update()
|
||||
case gogen.Category():
|
||||
return gogen.Update()
|
||||
case kube.Category():
|
||||
return kube.Update()
|
||||
case rpcgen.Category():
|
||||
return rpcgen.Update()
|
||||
case modelgen.Category():
|
||||
return modelgen.Update()
|
||||
default:
|
||||
err = fmt.Errorf("unexpected category: %s", category)
|
||||
return
|
||||
@@ -97,11 +101,15 @@ func RevertTemplates(ctx *cli.Context) (err error) {
|
||||
}
|
||||
}()
|
||||
switch category {
|
||||
case gogen.GetCategory():
|
||||
case docker.Category():
|
||||
return docker.RevertTemplate(filename)
|
||||
case kube.Category():
|
||||
return kube.RevertTemplate(filename)
|
||||
case gogen.Category():
|
||||
return gogen.RevertTemplate(filename)
|
||||
case rpcgen.GetCategory():
|
||||
case rpcgen.Category():
|
||||
return rpcgen.RevertTemplate(filename)
|
||||
case modelgen.GetCategory():
|
||||
case modelgen.Category():
|
||||
return modelgen.RevertTemplate(filename)
|
||||
default:
|
||||
err = fmt.Errorf("unexpected category: %s", category)
|
||||
|
||||
@@ -60,7 +60,6 @@ func FindGoModPath(dir string) (string, bool) {
|
||||
var hasGoMod = false
|
||||
for {
|
||||
if FileExists(filepath.Join(tempPath, goModeIdentifier)) {
|
||||
tempPath = filepath.Dir(tempPath)
|
||||
rootPath = strings.TrimPrefix(absDir[len(tempPath):], "/")
|
||||
hasGoMod = true
|
||||
break
|
||||
|
||||
@@ -19,10 +19,11 @@ func Untitle(s string) string {
|
||||
}
|
||||
|
||||
func Index(slice []string, item string) int {
|
||||
for i, _ := range slice {
|
||||
for i := range slice {
|
||||
if slice[i] == item {
|
||||
return i
|
||||
}
|
||||
}
|
||||
|
||||
return -1
|
||||
}
|
||||
|
||||
@@ -64,10 +64,10 @@ func (l *Logger) Warning(args ...interface{}) {
|
||||
// ignore builtin grpc warning
|
||||
}
|
||||
|
||||
func (l *Logger) Warningln(args ...interface{}) {
|
||||
// ignore builtin grpc warning
|
||||
}
|
||||
|
||||
func (l *Logger) Warningf(format string, args ...interface{}) {
|
||||
// ignore builtin grpc warning
|
||||
}
|
||||
|
||||
func (l *Logger) Warningln(args ...interface{}) {
|
||||
// ignore builtin grpc warning
|
||||
}
|
||||
|
||||
83
zrpc/internal/rpclogger_test.go
Normal file
83
zrpc/internal/rpclogger_test.go
Normal file
@@ -0,0 +1,83 @@
|
||||
package internal
|
||||
|
||||
import (
|
||||
"log"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
const content = "foo"
|
||||
|
||||
func TestLoggerError(t *testing.T) {
|
||||
var builder strings.Builder
|
||||
log.SetOutput(&builder)
|
||||
logger := new(Logger)
|
||||
logger.Error(content)
|
||||
assert.Contains(t, builder.String(), content)
|
||||
}
|
||||
|
||||
func TestLoggerErrorf(t *testing.T) {
|
||||
var builder strings.Builder
|
||||
log.SetOutput(&builder)
|
||||
logger := new(Logger)
|
||||
logger.Errorf(content)
|
||||
assert.Contains(t, builder.String(), content)
|
||||
}
|
||||
|
||||
func TestLoggerErrorln(t *testing.T) {
|
||||
var builder strings.Builder
|
||||
log.SetOutput(&builder)
|
||||
logger := new(Logger)
|
||||
logger.Errorln(content)
|
||||
assert.Contains(t, builder.String(), content)
|
||||
}
|
||||
|
||||
func TestLoggerFatal(t *testing.T) {
|
||||
var builder strings.Builder
|
||||
log.SetOutput(&builder)
|
||||
logger := new(Logger)
|
||||
logger.Fatal(content)
|
||||
assert.Contains(t, builder.String(), content)
|
||||
}
|
||||
|
||||
func TestLoggerFatalf(t *testing.T) {
|
||||
var builder strings.Builder
|
||||
log.SetOutput(&builder)
|
||||
logger := new(Logger)
|
||||
logger.Fatalf(content)
|
||||
assert.Contains(t, builder.String(), content)
|
||||
}
|
||||
|
||||
func TestLoggerFatalln(t *testing.T) {
|
||||
var builder strings.Builder
|
||||
log.SetOutput(&builder)
|
||||
logger := new(Logger)
|
||||
logger.Fatalln(content)
|
||||
assert.Contains(t, builder.String(), content)
|
||||
}
|
||||
|
||||
func TestLoggerWarning(t *testing.T) {
|
||||
var builder strings.Builder
|
||||
log.SetOutput(&builder)
|
||||
logger := new(Logger)
|
||||
logger.Warning(content)
|
||||
assert.Empty(t, builder.String())
|
||||
}
|
||||
|
||||
func TestLoggerWarningf(t *testing.T) {
|
||||
var builder strings.Builder
|
||||
log.SetOutput(&builder)
|
||||
logger := new(Logger)
|
||||
logger.Warningf(content)
|
||||
assert.Empty(t, builder.String())
|
||||
}
|
||||
|
||||
func TestLoggerWarningln(t *testing.T) {
|
||||
var builder strings.Builder
|
||||
log.SetOutput(&builder)
|
||||
logger := new(Logger)
|
||||
logger.Warningln(content)
|
||||
assert.Empty(t, builder.String())
|
||||
}
|
||||
Reference in New Issue
Block a user