Compare commits

..

51 Commits

Author SHA1 Message Date
Kevin Wan
57b73d8b49 make sure offset less than size even it's checked inside (#354) 2021-01-05 16:06:36 +08:00
Kevin Wan
a79cee12ee add godoc for RollingWindow (#351) 2021-01-04 22:43:55 +08:00
zjbztianya
7a921f66e6 simple rolling windows code (#346) 2021-01-04 22:11:18 +08:00
kingxt
12e235efb0 optimized goctl format (#336)
* fix format

* refactor

* refactor

* optimized

* refactor

* refactor

* refactor

* add js path prefix
2021-01-04 18:59:48 +08:00
Kevin Wan
01060cf16d close issue of #337 (#347) 2021-01-04 16:36:27 +08:00
Kevin Wan
0786862a35 align bucket boundary to interval in rolling window (#345) 2021-01-04 11:17:59 +08:00
Kevin Wan
efa43483b2 fix potential data race in PeriodicalExecutor (#344)
* fix potential data race in PeriodicalExecutor

* add comment
2021-01-03 20:56:17 +08:00
Kevin Wan
771371e051 simplify rolling window code, and make tests run faster (#343) 2021-01-03 20:47:29 +08:00
zjbztianya
2ee95f8981 fix rolling window bug (#340) 2021-01-03 20:27:47 +08:00
Kevin Wan
5bc01e4bfd set guarded to false only on quitting background flush (#342)
* set guarded to false only on quitting background flush

* set guarded to false only on quitting background flush, cont.
2021-01-03 19:54:11 +08:00
Kevin Wan
510e966982 simplify periodical executor background routine (#339) 2021-01-03 14:02:51 +08:00
Kevin Wan
10e3b8ac80 optimize code that fixes issue #317 (#338) 2021-01-02 19:01:37 +08:00
Kevin Wan
04059bbf5a add discord chat group in readme 2021-01-02 18:35:33 +08:00
weibobo
d643007c79 fix bug #317 (#335)
* fix bug #317.
* add counter for current task. If it's bigger then zero, do not quit background thread

* Revert "fix issue #317 (#331)"

This reverts commit fc43876cc5.
2021-01-02 18:04:04 +08:00
Kevin Wan
fc43876cc5 fix issue #317 (#331) 2021-01-01 13:24:28 +08:00
FengZhang
a926cb514f modify the goctl gensvc template (#323) 2020-12-30 10:05:26 +08:00
kingxt
25cab2f273 Java (#327)
* add g4 file

* new define api by g4

* reactor parser to g4gen

* add syntax parser & test

* add syntax parser & test

* add syntax parser & test

* update g4 file

* add import parse & test

* ractor AT lexer

* panic with error

* revert AT

* update g4 file

* update g4 file

* update g4 file

* optimize parser

* update g4 file

* parse info

* optimized java generator

* revert

* optimize java generator

* update java generator

* update java generator

* update java generator

* update java generator

Co-authored-by: anqiansong <anqiansong@xiaoheiban.cn>
2020-12-29 17:50:41 +08:00
Kevin Wan
8d2e2753a2 simplify http.Flusher implementation (#326)
* simplify code with http.Flusher type conversion

* simplify code with http.Flusher type conversion, better version
2020-12-29 15:02:36 +08:00
Kevin Wan
cc4c50e3eb fix broken link. 2020-12-29 11:54:32 +08:00
Kevin Wan
751072bdb0 fix broken doc link 2020-12-29 11:52:55 +08:00
Kevin Wan
e97e1f10db simplify code with http.Flusher type conversion (#325)
* simplify code with http.Flusher type conversion

* simplify code with http.Flusher type conversion, better version
2020-12-29 10:25:55 +08:00
jichangyun
0bd2a0656c The ResponseWriters defined in rest.handler add Flush interface. (#318) 2020-12-28 21:30:24 +08:00
Kevin Wan
71a2b20301 add more tests for prof (#322) 2020-12-27 14:45:14 +08:00
Kevin Wan
8df7de94e3 add more tests for zrpc (#321) 2020-12-27 14:08:24 +08:00
Kevin Wan
bf21203297 add more tests (#320) 2020-12-27 12:26:31 +08:00
Kevin Wan
ae98375194 add more tests (#319) 2020-12-26 20:30:02 +08:00
Kevin Wan
82d1ccf376 fixes #286 (#315) 2020-12-25 19:47:27 +08:00
Kevin Wan
bb6d49c17e add go report card back (#313)
* add go report card back

* avoid test failure, run tests sequentially
2020-12-25 12:09:59 +08:00
Kevin Wan
ed735ec47c Update codeql-analysis.yml
disable python code analysis, python code is in examples.
2020-12-25 12:09:43 +08:00
Kevin Wan
ba4bac3a03 format code (#312) 2020-12-25 11:53:37 +08:00
FengZhang
08433d7e04 add config load support env var (#309) 2020-12-25 11:42:19 +08:00
anqiansong
a3b525b50d feature model fix (#296)
* add raw stirng quote for sql field

* remove unused code
2020-12-21 09:43:32 +08:00
Kevin Wan
097f6886f2 Update readme.md 2020-12-15 23:47:41 +08:00
Kevin Wan
07a1549634 add wechat micro practice qrcode image (#289) 2020-12-14 17:49:58 +08:00
Kevin Wan
befca26c58 Update readme.md
add goproxy.cn download badge
2020-12-13 00:02:32 +08:00
Kevin Wan
3556a2eef4 Update readme-en.md
goreportcard is not working, submitted an issue to them.
2020-12-12 23:40:26 +08:00
Kevin Wan
807765f77e Update readme.md
goreportcard is not working, submitted a issue to them.
2020-12-12 23:39:28 +08:00
Kevin Wan
e44584e549 Create codeql-analysis.yml 2020-12-12 23:01:15 +08:00
Kevin Wan
acd48f0abb optimize dockerfile generation (#284) 2020-12-12 16:53:06 +08:00
kingxt
f919bc6713 refactor (#283) 2020-12-12 11:18:22 +08:00
Kevin Wan
a0030b8f45 format dockerfile on non-chinese mode (#282) 2020-12-12 10:13:33 +08:00
Kevin Wan
a5f0cce1b1 Update readme-en.md 2020-12-12 09:06:09 +08:00
Kevin Wan
4d13dda605 add EXPOSE in dockerfile generation (#281) 2020-12-12 08:18:01 +08:00
songmeizi
b56cc8e459 optimize test case of TestRpcGenerate (#279)
Co-authored-by: anqiansong <anqiansong@xiaoheiban.cn>
2020-12-11 21:57:04 +08:00
Kevin Wan
c435811479 fix gocyclo warnings (#278) 2020-12-11 20:57:48 +08:00
Kevin Wan
c686c93fb5 fix dockerfile generation bug (#277) 2020-12-11 20:31:31 +08:00
Kevin Wan
da8f76e6bd add category docker & kube (#276) 2020-12-11 18:53:40 +08:00
Kevin Wan
99596a4149 fix issue #266 (#275)
* optimize dockerfile

* fix issue #266
2020-12-11 16:12:33 +08:00
wayne
ec2a9f2c57 fix tracelogger_test TestTraceLog (#271) 2020-12-10 17:04:57 +08:00
Kevin Wan
fd73ced6dc optimize dockerfile (#272) 2020-12-10 16:21:06 +08:00
Kevin Wan
5071736ab4 fmt code (#270) 2020-12-10 15:16:13 +08:00
75 changed files with 2265 additions and 387 deletions

67
.github/workflows/codeql-analysis.yml vendored Normal file
View File

@@ -0,0 +1,67 @@
# For most projects, this workflow file will not need changing; you simply need
# to commit it to your repository.
#
# You may wish to alter this file to override the set of languages analyzed,
# or to provide custom queries or build logic.
#
# ******** NOTE ********
# We have attempted to detect the languages in your repository. Please check
# the `language` matrix defined below to confirm you have the correct set of
# supported CodeQL languages.
#
name: "CodeQL"
on:
push:
branches: [ master ]
pull_request:
# The branches below must be a subset of the branches above
branches: [ master ]
schedule:
- cron: '18 19 * * 6'
jobs:
analyze:
name: Analyze
runs-on: ubuntu-latest
strategy:
fail-fast: false
matrix:
language: [ 'go' ]
# CodeQL supports [ 'cpp', 'csharp', 'go', 'java', 'javascript', 'python' ]
# Learn more:
# https://docs.github.com/en/free-pro-team@latest/github/finding-security-vulnerabilities-and-errors-in-your-code/configuring-code-scanning#changing-the-languages-that-are-analyzed
steps:
- name: Checkout repository
uses: actions/checkout@v2
# Initializes the CodeQL tools for scanning.
- name: Initialize CodeQL
uses: github/codeql-action/init@v1
with:
languages: ${{ matrix.language }}
# If you wish to specify custom queries, you can do so here or in a config file.
# By default, queries listed here will override any specified in a config file.
# Prefix the list here with "+" to use these queries and those in the config file.
# queries: ./path/to/local/query, your-org/your-repo/queries@main
# Autobuild attempts to build any compiled languages (C/C++, C#, or Java).
# If this step fails, then you should remove it and run the build manually (see below)
- name: Autobuild
uses: github/codeql-action/autobuild@v1
# Command-line programs to run using the OS shell.
# 📚 https://git.io/JvXDl
# ✏️ If the Autobuild fails above, remove it and uncomment the following three lines
# and modify them (or add more) to build your code if your project
# uses a compiled language
#- run: |
# make bootstrap
# make release
- name: Perform CodeQL Analysis
uses: github/codeql-action/analyze@v1

2
.gitignore vendored
View File

@@ -4,6 +4,7 @@
# Unignore all with extensions
!*.*
!**/Dockerfile
!**/Makefile
# Unignore all dirs
!*/
@@ -12,7 +13,6 @@
.idea
**/.DS_Store
**/logs
!Makefile
# gitlab ci
.cache

View File

@@ -8,8 +8,10 @@ import (
)
type (
// RollingWindowOption let callers customize the RollingWindow.
RollingWindowOption func(rollingWindow *RollingWindow)
// RollingWindow defines a rolling window to calculate the events in buckets with time interval.
RollingWindow struct {
lock sync.RWMutex
size int
@@ -17,10 +19,12 @@ type (
interval time.Duration
offset int
ignoreCurrent bool
lastTime time.Duration
lastTime time.Duration // start time of the last bucket
}
)
// NewRollingWindow returns a RollingWindow that with size buckets and time interval,
// use opts to customize the RollingWindow.
func NewRollingWindow(size int, interval time.Duration, opts ...RollingWindowOption) *RollingWindow {
if size < 1 {
panic("size must be greater than 0")
@@ -38,6 +42,7 @@ func NewRollingWindow(size int, interval time.Duration, opts ...RollingWindowOpt
return w
}
// Add adds value to current bucket.
func (rw *RollingWindow) Add(v float64) {
rw.lock.Lock()
defer rw.lock.Unlock()
@@ -45,6 +50,7 @@ func (rw *RollingWindow) Add(v float64) {
rw.win.add(rw.offset, v)
}
// Reduce runs fn on all buckets, ignore current bucket if ignoreCurrent was set.
func (rw *RollingWindow) Reduce(fn func(b *Bucket)) {
rw.lock.RLock()
defer rw.lock.RUnlock()
@@ -79,26 +85,18 @@ func (rw *RollingWindow) updateOffset() {
}
offset := rw.offset
start := offset + 1
steps := start + span
var remainder int
if steps > rw.size {
remainder = steps - rw.size
steps = rw.size
}
// reset expired buckets
for i := start; i < steps; i++ {
rw.win.resetBucket(i)
}
for i := 0; i < remainder; i++ {
rw.win.resetBucket(i)
for i := 0; i < span; i++ {
rw.win.resetBucket((offset + i + 1) % rw.size)
}
rw.offset = (offset + span) % rw.size
rw.lastTime = timex.Now()
now := timex.Now()
// align to interval time boundary
rw.lastTime = now - (now-rw.lastTime)%rw.interval
}
// Bucket defines the bucket that holds sum and num of additions.
type Bucket struct {
Sum float64
Count int64
@@ -144,6 +142,7 @@ func (w *window) resetBucket(offset int) {
w.buckets[offset%w.size].reset()
}
// IgnoreCurrentBucket lets the Reduce call ignore current bucket.
func IgnoreCurrentBucket() RollingWindowOption {
return func(w *RollingWindow) {
w.ignoreCurrent = true

View File

@@ -105,6 +105,37 @@ func TestRollingWindowReduce(t *testing.T) {
}
}
func TestRollingWindowBucketTimeBoundary(t *testing.T) {
const size = 3
interval := time.Millisecond * 30
r := NewRollingWindow(size, interval)
listBuckets := func() []float64 {
var buckets []float64
r.Reduce(func(b *Bucket) {
buckets = append(buckets, b.Sum)
})
return buckets
}
assert.Equal(t, []float64{0, 0, 0}, listBuckets())
r.Add(1)
assert.Equal(t, []float64{0, 0, 1}, listBuckets())
time.Sleep(time.Millisecond * 45)
r.Add(2)
r.Add(3)
assert.Equal(t, []float64{0, 1, 5}, listBuckets())
// sleep time should be less than interval, and make the bucket change happen
time.Sleep(time.Millisecond * 20)
r.Add(4)
r.Add(5)
r.Add(6)
assert.Equal(t, []float64{1, 5, 15}, listBuckets())
time.Sleep(time.Millisecond * 100)
r.Add(7)
r.Add(8)
r.Add(9)
assert.Equal(t, []float64{0, 0, 24}, listBuckets())
}
func TestRollingWindowDataRace(t *testing.T) {
const size = 3
r := NewRollingWindow(size, duration)

View File

@@ -4,6 +4,7 @@ import (
"fmt"
"io/ioutil"
"log"
"os"
"path"
"github.com/tal-tech/go-zero/core/mapping"
@@ -19,7 +20,7 @@ func LoadConfig(file string, v interface{}) error {
if content, err := ioutil.ReadFile(file); err != nil {
return err
} else if loader, ok := loaders[path.Ext(file)]; ok {
return loader(content, v)
return loader([]byte(os.ExpandEnv(string(content))), v)
} else {
return fmt.Errorf("unrecoginized file type: %s", file)
}

View File

@@ -17,13 +17,14 @@ func TestConfigJson(t *testing.T) {
}
text := `{
"a": "foo",
"b": 1
"b": 1,
"c": "${FOO}"
}`
for _, test := range tests {
test := test
t.Run(test, func(t *testing.T) {
t.Parallel()
os.Setenv("FOO", "2")
defer os.Unsetenv("FOO")
tmpfile, err := createTempFile(test, text)
assert.Nil(t, err)
defer os.Remove(tmpfile)
@@ -31,10 +32,12 @@ func TestConfigJson(t *testing.T) {
var val struct {
A string `json:"a"`
B int `json:"b"`
C string `json:"c"`
}
MustLoad(tmpfile, &val)
assert.Equal(t, "foo", val.A)
assert.Equal(t, 1, val.B)
assert.Equal(t, "2", val.C)
})
}
}

View File

@@ -3,6 +3,7 @@ package executors
import (
"reflect"
"sync"
"sync/atomic"
"time"
"github.com/tal-tech/go-zero/core/lang"
@@ -35,6 +36,7 @@ type (
// avoid race condition on waitGroup when calling wg.Add/Done/Wait(...)
wgBarrier syncx.Barrier
confirmChan chan lang.PlaceholderType
inflight int32
guarded bool
newTicker func(duration time.Duration) timex.Ticker
lock sync.Mutex
@@ -91,18 +93,16 @@ func (pe *PeriodicalExecutor) Wait() {
func (pe *PeriodicalExecutor) addAndCheck(task interface{}) (interface{}, bool) {
pe.lock.Lock()
defer func() {
var start bool
if !pe.guarded {
pe.guarded = true
start = true
// defer to unlock quickly
defer pe.backgroundFlush()
}
pe.lock.Unlock()
if start {
pe.backgroundFlush()
}
}()
if pe.container.AddTask(task) {
atomic.AddInt32(&pe.inflight, 1)
return pe.container.RemoveAll(), true
}
@@ -111,6 +111,9 @@ func (pe *PeriodicalExecutor) addAndCheck(task interface{}) (interface{}, bool)
func (pe *PeriodicalExecutor) backgroundFlush() {
threading.GoSafe(func() {
// flush before quit goroutine to avoid missing tasks
defer pe.Flush()
ticker := pe.newTicker(pe.interval)
defer ticker.Stop()
@@ -120,6 +123,7 @@ func (pe *PeriodicalExecutor) backgroundFlush() {
select {
case vals := <-pe.commander:
commanded = true
atomic.AddInt32(&pe.inflight, -1)
pe.enterExecution()
pe.confirmChan <- lang.Placeholder
pe.executeTasks(vals)
@@ -129,13 +133,7 @@ func (pe *PeriodicalExecutor) backgroundFlush() {
commanded = false
} else if pe.Flush() {
last = timex.Now()
} else if timex.Since(last) > pe.interval*idleRound {
pe.lock.Lock()
pe.guarded = false
pe.lock.Unlock()
// flush again to avoid missing tasks
pe.Flush()
} else if pe.shallQuit(last) {
return
}
}
@@ -178,3 +176,19 @@ func (pe *PeriodicalExecutor) hasTasks(tasks interface{}) bool {
return true
}
}
func (pe *PeriodicalExecutor) shallQuit(last time.Duration) (stop bool) {
if timex.Since(last) <= pe.interval*idleRound {
return
}
// checking pe.inflight and setting pe.guarded should be locked together
pe.lock.Lock()
if atomic.LoadInt32(&pe.inflight) == 0 {
pe.guarded = false
stop = true
}
pe.lock.Unlock()
return
}

View File

@@ -140,6 +140,26 @@ func TestPeriodicalExecutor_WaitFast(t *testing.T) {
assert.Equal(t, total, cnt)
}
func TestPeriodicalExecutor_Deadlock(t *testing.T) {
executor := NewBulkExecutor(func(tasks []interface{}) {
}, WithBulkTasks(1), WithBulkInterval(time.Millisecond))
for i := 0; i < 1e5; i++ {
executor.Add(1)
}
}
func TestPeriodicalExecutor_hasTasks(t *testing.T) {
ticker := timex.NewFakeTicker()
defer ticker.Stop()
exec := NewPeriodicalExecutor(time.Millisecond, newContainer(time.Millisecond, nil))
exec.newTicker = func(d time.Duration) timex.Ticker {
return ticker
}
assert.False(t, exec.hasTasks(nil))
assert.True(t, exec.hasTasks(1))
}
// go test -benchtime 10s -bench .
func BenchmarkExecutor(b *testing.B) {
b.ReportAllocs()

View File

@@ -21,6 +21,7 @@ var mock tracespec.Trace = new(mockTrace)
func TestTraceLog(t *testing.T) {
var buf mockWriter
atomic.StoreUint32(&initialized, 1)
ctx := context.WithValue(context.Background(), tracespec.TracingKey, mock)
WithContext(ctx).(*traceLogger).write(&buf, levelInfo, testlog)
assert.True(t, strings.Contains(buf.String(), mockTraceId))

View File

@@ -153,58 +153,57 @@ func doParseKeyAndOptions(field reflect.StructField, value string) (string, *fie
key := strings.TrimSpace(segments[0])
options := segments[1:]
if len(options) > 0 {
var fieldOpts fieldOptions
for _, segment := range options {
option := strings.TrimSpace(segment)
switch {
case option == stringOption:
fieldOpts.FromString = true
case strings.HasPrefix(option, optionalOption):
segs := strings.Split(option, equalToken)
switch len(segs) {
case 1:
fieldOpts.Optional = true
case 2:
fieldOpts.Optional = true
fieldOpts.OptionalDep = segs[1]
default:
return "", nil, fmt.Errorf("field %s has wrong optional", field.Name)
}
case option == optionalOption:
fieldOpts.Optional = true
case strings.HasPrefix(option, optionsOption):
segs := strings.Split(option, equalToken)
if len(segs) != 2 {
return "", nil, fmt.Errorf("field %s has wrong options", field.Name)
} else {
fieldOpts.Options = strings.Split(segs[1], optionSeparator)
}
case strings.HasPrefix(option, defaultOption):
segs := strings.Split(option, equalToken)
if len(segs) != 2 {
return "", nil, fmt.Errorf("field %s has wrong default option", field.Name)
} else {
fieldOpts.Default = strings.TrimSpace(segs[1])
}
case strings.HasPrefix(option, rangeOption):
segs := strings.Split(option, equalToken)
if len(segs) != 2 {
return "", nil, fmt.Errorf("field %s has wrong range", field.Name)
}
if nr, err := parseNumberRange(segs[1]); err != nil {
return "", nil, err
} else {
fieldOpts.Range = nr
}
}
}
return key, &fieldOpts, nil
if len(options) == 0 {
return key, nil, nil
}
return key, nil, nil
var fieldOpts fieldOptions
for _, segment := range options {
option := strings.TrimSpace(segment)
switch {
case option == stringOption:
fieldOpts.FromString = true
case strings.HasPrefix(option, optionalOption):
segs := strings.Split(option, equalToken)
switch len(segs) {
case 1:
fieldOpts.Optional = true
case 2:
fieldOpts.Optional = true
fieldOpts.OptionalDep = segs[1]
default:
return "", nil, fmt.Errorf("field %s has wrong optional", field.Name)
}
case option == optionalOption:
fieldOpts.Optional = true
case strings.HasPrefix(option, optionsOption):
segs := strings.Split(option, equalToken)
if len(segs) != 2 {
return "", nil, fmt.Errorf("field %s has wrong options", field.Name)
} else {
fieldOpts.Options = strings.Split(segs[1], optionSeparator)
}
case strings.HasPrefix(option, defaultOption):
segs := strings.Split(option, equalToken)
if len(segs) != 2 {
return "", nil, fmt.Errorf("field %s has wrong default option", field.Name)
} else {
fieldOpts.Default = strings.TrimSpace(segs[1])
}
case strings.HasPrefix(option, rangeOption):
segs := strings.Split(option, equalToken)
if len(segs) != 2 {
return "", nil, fmt.Errorf("field %s has wrong range", field.Name)
}
if nr, err := parseNumberRange(segs[1]); err != nil {
return "", nil, err
} else {
fieldOpts.Range = nr
}
}
}
return key, &fieldOpts, nil
}
func implicitValueRequiredStruct(tag string, tp reflect.Type) (bool, error) {

View File

@@ -0,0 +1,16 @@
package prof
import (
"testing"
"time"
"github.com/stretchr/testify/assert"
)
func TestReport(t *testing.T) {
once.Do(func() {})
assert.NotContains(t, generateReport(), "foo")
report("foo", time.Second)
assert.Contains(t, generateReport(), "foo")
report("foo", time.Second)
}

View File

@@ -0,0 +1,23 @@
package prof
import (
"testing"
"github.com/tal-tech/go-zero/core/utils"
)
func TestProfiler(t *testing.T) {
EnableProfiling()
Start()
Report("foo", ProfilePoint{
ElapsedTimer: utils.NewElapsedTimer(),
})
}
func TestNullProfiler(t *testing.T) {
p := newNullProfiler()
p.Start()
p.Report("foo", ProfilePoint{
ElapsedTimer: utils.NewElapsedTimer(),
})
}

View File

@@ -70,8 +70,6 @@ func (g *sharedGroup) createCall(key string) (c *call, done bool) {
func (g *sharedGroup) makeCall(c *call, key string, fn func() (interface{}, error)) {
defer func() {
// delete key first, done later. can't reverse the order, because if reverse,
// another Do call might wg.Wait() without get notified with wg.Done()
g.lock.Lock()
delete(g.calls, key)
g.lock.Unlock()

View File

@@ -129,7 +129,7 @@ go get -u github.com/tal-tech/go-zero
the .api files also can be generate by goctl, like below:
```shell
goctl api -o greet.api
goctl api -o greet.api
```
3. generate the go server side code
@@ -208,3 +208,7 @@ goctl api -o greet.api
* [Rapid development of microservice systems](https://github.com/tal-tech/zero-doc/blob/main/doc/shorturl-en.md)
* [Rapid development of microservice systems - multiple RPCs](https://github.com/tal-tech/zero-doc/blob/main/doc/bookstore-en.md)
## 9. Chat group
Join the chat via https://discord.gg/4JQvC5A4Fe

View File

@@ -5,8 +5,9 @@
[English](readme-en.md) | 简体中文
[![Go](https://github.com/tal-tech/go-zero/workflows/Go/badge.svg?branch=master)](https://github.com/tal-tech/go-zero/actions)
[![codecov](https://codecov.io/gh/tal-tech/go-zero/branch/master/graph/badge.svg)](https://codecov.io/gh/tal-tech/go-zero)
[![Go Report Card](https://goreportcard.com/badge/github.com/tal-tech/go-zero)](https://goreportcard.com/report/github.com/tal-tech/go-zero)
[![goproxy](https://goproxy.cn/stats/github.com/tal-tech/go-zero/badges/download-count.svg)](https://goproxy.cn/stats/github.com/tal-tech/go-zero/badges/download-count.svg)
[![codecov](https://codecov.io/gh/tal-tech/go-zero/branch/master/graph/badge.svg)](https://codecov.io/gh/tal-tech/go-zero)
[![Release](https://img.shields.io/github/v/release/tal-tech/go-zero.svg?style=flat-square)](https://github.com/tal-tech/go-zero)
[![License: MIT](https://img.shields.io/badge/License-MIT-yellow.svg)](https://opensource.org/licenses/MIT)
@@ -95,7 +96,7 @@ GO111MODULE=on GOPROXY=https://goproxy.cn/,direct go get -u github.com/tal-tech/
[快速构建高并发微服务](https://github.com/tal-tech/zero-doc/blob/main/doc/shorturl.md)
[快速构建高并发微服务 - 多 RPC 版](https://github.com/tal-tech/zero-doc/blob/main/docs/frame/bookstore.md)
[快速构建高并发微服务 - 多 RPC 版](https://github.com/tal-tech/zero-doc/blob/main/docs/zero/bookstore.md)
1. 安装 goctl 工具
@@ -162,7 +163,7 @@ GO111MODULE=on GOPROXY=https://goproxy.cn/,direct go get -u github.com/tal-tech/
* awesome 系列
* [快速构建高并发微服务](https://github.com/tal-tech/zero-doc/blob/main/doc/shorturl.md)
* [快速构建高并发微服务 - 多 RPC 版](https://github.com/tal-tech/zero-doc/blob/main/docs/frame/bookstore.md)
* [快速构建高并发微服务 - 多 RPC 版](https://github.com/tal-tech/zero-doc/blob/main/docs/zero/bookstore.md)
* [goctl 使用帮助](https://github.com/tal-tech/zero-doc/blob/main/doc/goctl.md)
* [通过 MapReduce 降低服务响应时间](https://github.com/tal-tech/zero-doc/blob/main/doc/mapreduce.md)
* [关键字替换和敏感词过滤工具](https://github.com/tal-tech/zero-doc/blob/main/doc/keywords.md)
@@ -172,7 +173,13 @@ GO111MODULE=on GOPROXY=https://goproxy.cn/,direct go get -u github.com/tal-tech/
* [文本序列化和反序列化](https://github.com/tal-tech/zero-doc/blob/main/doc/mapping.md)
* [快速构建 jwt 鉴权认证](https://github.com/tal-tech/zero-doc/blob/main/doc/jwt.md)
## 8. 微信交流群
## 8. 微信公众号
`go-zero` 相关文章都会在 `微服务实践` 公众号整理呈现,欢迎扫码关注,也可以通过公众号私信我 👏
<img src="https://gitee.com/kevwan/static/raw/master/images/wechat-micro.jpg" alt="wechat" width="300" />
## 9. 微信交流群
如果文档中未能覆盖的任何疑问,欢迎您在群里提出,我们会尽快答复。

171
rest/engine_test.go Normal file
View File

@@ -0,0 +1,171 @@
package rest
import (
"errors"
"net/http"
"testing"
"github.com/stretchr/testify/assert"
"github.com/tal-tech/go-zero/core/conf"
)
func TestNewEngine(t *testing.T) {
yamls := []string{
`Name: foo
Port: 54321
`,
`Name: foo
Port: 54321
CpuThreshold: 500
`,
`Name: foo
Port: 54321
CpuThreshold: 500
Verbose: true
`,
}
routes := []featuredRoutes{
{
jwt: jwtSetting{},
signature: signatureSetting{},
routes: []Route{{
Method: http.MethodGet,
Path: "/",
Handler: func(w http.ResponseWriter, r *http.Request) {},
}},
},
{
priority: true,
jwt: jwtSetting{},
signature: signatureSetting{},
routes: []Route{{
Method: http.MethodGet,
Path: "/",
Handler: func(w http.ResponseWriter, r *http.Request) {},
}},
},
{
priority: true,
jwt: jwtSetting{
enabled: true,
},
signature: signatureSetting{},
routes: []Route{{
Method: http.MethodGet,
Path: "/",
Handler: func(w http.ResponseWriter, r *http.Request) {},
}},
},
{
priority: true,
jwt: jwtSetting{
enabled: true,
prevSecret: "thesecret",
},
signature: signatureSetting{},
routes: []Route{{
Method: http.MethodGet,
Path: "/",
Handler: func(w http.ResponseWriter, r *http.Request) {},
}},
},
{
priority: true,
jwt: jwtSetting{
enabled: true,
},
signature: signatureSetting{},
routes: []Route{{
Method: http.MethodGet,
Path: "/",
Handler: func(w http.ResponseWriter, r *http.Request) {},
}},
},
{
priority: true,
jwt: jwtSetting{
enabled: true,
},
signature: signatureSetting{
enabled: true,
},
routes: []Route{{
Method: http.MethodGet,
Path: "/",
Handler: func(w http.ResponseWriter, r *http.Request) {},
}},
},
{
priority: true,
jwt: jwtSetting{
enabled: true,
},
signature: signatureSetting{
enabled: true,
SignatureConf: SignatureConf{
Strict: true,
},
},
routes: []Route{{
Method: http.MethodGet,
Path: "/",
Handler: func(w http.ResponseWriter, r *http.Request) {},
}},
},
{
priority: true,
jwt: jwtSetting{
enabled: true,
},
signature: signatureSetting{
enabled: true,
SignatureConf: SignatureConf{
Strict: true,
PrivateKeys: []PrivateKeyConf{
{
Fingerprint: "a",
KeyFile: "b",
},
},
},
},
routes: []Route{{
Method: http.MethodGet,
Path: "/",
Handler: func(w http.ResponseWriter, r *http.Request) {},
}},
},
}
for _, yaml := range yamls {
for _, route := range routes {
var cnf RestConf
assert.Nil(t, conf.LoadConfigFromYamlBytes([]byte(yaml), &cnf))
ng := newEngine(cnf)
ng.AddRoutes(route)
ng.use(func(next http.HandlerFunc) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
next.ServeHTTP(w, r)
}
})
assert.NotNil(t, ng.StartWithRouter(mockedRouter{}))
}
}
}
type mockedRouter struct {
}
func (m mockedRouter) ServeHTTP(writer http.ResponseWriter, request *http.Request) {
}
func (m mockedRouter) Handle(method string, path string, handler http.Handler) error {
return errors.New("foo")
}
func (m mockedRouter) SetNotFoundHandler(handler http.Handler) {
}
func (m mockedRouter) SetNotAllowedHandler(handler http.Handler) {
}

View File

@@ -46,18 +46,18 @@ func Authorize(secret string, opts ...AuthorizeOption) func(http.Handler) http.H
parser := token.NewTokenParser()
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
token, err := parser.ParseToken(r, secret, authOpts.PrevSecret)
tok, err := parser.ParseToken(r, secret, authOpts.PrevSecret)
if err != nil {
unauthorized(w, r, err, authOpts.Callback)
return
}
if !token.Valid {
if !tok.Valid {
unauthorized(w, r, errInvalidToken, authOpts.Callback)
return
}
claims, ok := token.Claims.(jwt.MapClaims)
claims, ok := tok.Claims.(jwt.MapClaims)
if !ok {
unauthorized(w, r, errNoClaims, authOpts.Callback)
return
@@ -122,6 +122,12 @@ func newGuardedResponseWriter(w http.ResponseWriter) *guardedResponseWriter {
}
}
func (grw *guardedResponseWriter) Flush() {
if flusher, ok := grw.writer.(http.Flusher); ok {
flusher.Flush()
}
}
func (grw *guardedResponseWriter) Header() http.Header {
return grw.writer.Header()
}

View File

@@ -41,6 +41,10 @@ func TestAuthHandler(t *testing.T) {
w.Header().Set("X-Test", "test")
_, err := w.Write([]byte("content"))
assert.Nil(t, err)
flusher, ok := w.(http.Flusher)
assert.True(t, ok)
flusher.Flush()
}))
resp := httptest.NewRecorder()

View File

@@ -83,6 +83,12 @@ func newCryptionResponseWriter(w http.ResponseWriter) *cryptionResponseWriter {
}
}
func (w *cryptionResponseWriter) Flush() {
if flusher, ok := w.ResponseWriter.(http.Flusher); ok {
flusher.Flush()
}
}
func (w *cryptionResponseWriter) Header() http.Header {
return w.ResponseWriter.Header()
}

View File

@@ -87,3 +87,19 @@ func TestCryptionHandlerWriteHeader(t *testing.T) {
handler.ServeHTTP(recorder, req)
assert.Equal(t, http.StatusServiceUnavailable, recorder.Code)
}
func TestCryptionHandlerFlush(t *testing.T) {
req := httptest.NewRequest(http.MethodGet, "/any", nil)
handler := CryptionHandler(aesKey)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte(respText))
flusher, ok := w.(http.Flusher)
assert.True(t, ok)
flusher.Flush()
}))
recorder := httptest.NewRecorder()
handler.ServeHTTP(recorder, req)
expect, err := codec.EcbEncrypt(aesKey, []byte(respText))
assert.Nil(t, err)
assert.Equal(t, base64.StdEncoding.EncodeToString(expect), recorder.Body.String())
}

View File

@@ -38,6 +38,12 @@ func (w *LoggedResponseWriter) WriteHeader(code int) {
w.code = code
}
func (w *LoggedResponseWriter) Flush() {
if flusher, ok := w.w.(http.Flusher); ok {
flusher.Flush()
}
}
func LogHandler(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
timer := utils.NewElapsedTimer()
@@ -68,6 +74,10 @@ func newDetailLoggedResponseWriter(writer *LoggedResponseWriter, buf *bytes.Buff
}
}
func (w *DetailLoggedResponseWriter) Flush() {
w.writer.Flush()
}
func (w *DetailLoggedResponseWriter) Header() http.Header {
return w.writer.Header()
}

View File

@@ -30,6 +30,10 @@ func TestLogHandler(t *testing.T) {
w.WriteHeader(http.StatusServiceUnavailable)
_, err := w.Write([]byte("content"))
assert.Nil(t, err)
flusher, ok := w.(http.Flusher)
assert.True(t, ok)
flusher.Flush()
}))
resp := httptest.NewRecorder()

View File

@@ -7,6 +7,12 @@ type WithCodeResponseWriter struct {
Code int
}
func (w *WithCodeResponseWriter) Flush() {
if flusher, ok := w.Writer.(http.Flusher); ok {
flusher.Flush()
}
}
func (w *WithCodeResponseWriter) Header() http.Header {
return w.Writer.Header()
}

View File

@@ -0,0 +1,33 @@
package security
import (
"net/http"
"net/http/httptest"
"testing"
"github.com/stretchr/testify/assert"
)
func TestWithCodeResponseWriter(t *testing.T) {
req := httptest.NewRequest(http.MethodGet, "http://localhost", nil)
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
cw := &WithCodeResponseWriter{Writer: w}
cw.Header().Set("X-Test", "test")
cw.WriteHeader(http.StatusServiceUnavailable)
assert.Equal(t, cw.Code, http.StatusServiceUnavailable)
_, err := cw.Write([]byte("content"))
assert.Nil(t, err)
flusher, ok := http.ResponseWriter(cw).(http.Flusher)
assert.True(t, ok)
flusher.Flush()
})
resp := httptest.NewRecorder()
handler.ServeHTTP(resp, req)
assert.Equal(t, http.StatusServiceUnavailable, resp.Code)
assert.Equal(t, "test", resp.Header().Get("X-Test"))
assert.Equal(t, "content", resp.Body.String())
}

View File

@@ -64,7 +64,7 @@ func (pr *patRouter) ServeHTTP(w http.ResponseWriter, r *http.Request) {
}
}
allow, ok := pr.methodNotAllowed(r.Method, reqPath)
allows, ok := pr.methodsAllowed(r.Method, reqPath)
if !ok {
pr.handleNotFound(w, r)
return
@@ -73,7 +73,7 @@ func (pr *patRouter) ServeHTTP(w http.ResponseWriter, r *http.Request) {
if pr.notAllowed != nil {
pr.notAllowed.ServeHTTP(w, r)
} else {
w.Header().Set(allowHeader, allow)
w.Header().Set(allowHeader, allows)
w.WriteHeader(http.StatusMethodNotAllowed)
}
}
@@ -94,7 +94,7 @@ func (pr *patRouter) handleNotFound(w http.ResponseWriter, r *http.Request) {
}
}
func (pr *patRouter) methodNotAllowed(method, path string) (string, bool) {
func (pr *patRouter) methodsAllowed(method, path string) (string, bool) {
var allows []string
for treeMethod, tree := range pr.trees {

View File

@@ -1,7 +1,6 @@
package rest
import (
"errors"
"log"
"net/http"
@@ -24,6 +23,9 @@ type (
}
)
// MustNewServer returns a server with given config of c and options defined in opts.
// Be aware that later RunOption might overwrite previous one that write the same option.
// The process will exit if error occurs.
func MustNewServer(c RestConf, opts ...RunOption) *Server {
engine, err := NewServer(c, opts...)
if err != nil {
@@ -33,11 +35,9 @@ func MustNewServer(c RestConf, opts ...RunOption) *Server {
return engine
}
// NewServer returns a server with given config of c and options defined in opts.
// Be aware that later RunOption might overwrite previous one that write the same option.
func NewServer(c RestConf, opts ...RunOption) (*Server, error) {
if len(opts) > 1 {
return nil, errors.New("only one RunOption is allowed")
}
if err := c.SetUp(); err != nil {
return nil, err
}

View File

@@ -8,18 +8,84 @@ import (
"testing"
"github.com/stretchr/testify/assert"
"github.com/tal-tech/go-zero/core/conf"
"github.com/tal-tech/go-zero/rest/httpx"
"github.com/tal-tech/go-zero/rest/router"
)
func TestNewServer(t *testing.T) {
_, err := NewServer(RestConf{}, WithNotFoundHandler(nil), WithNotAllowedHandler(nil))
assert.NotNil(t, err)
const configYaml = `
Name: foo
Port: 54321
`
var cnf RestConf
assert.Nil(t, conf.LoadConfigFromYamlBytes([]byte(configYaml), &cnf))
failStart := func(server *Server) {
server.opts.start = func(e *engine) error {
return http.ErrServerClosed
}
}
tests := []struct {
c RestConf
opts []RunOption
fail bool
}{
{
c: RestConf{},
opts: []RunOption{failStart},
fail: true,
},
{
c: cnf,
opts: []RunOption{failStart},
},
{
c: cnf,
opts: []RunOption{WithNotAllowedHandler(nil), failStart},
},
{
c: cnf,
opts: []RunOption{WithNotFoundHandler(nil), failStart},
},
{
c: cnf,
opts: []RunOption{WithUnauthorizedCallback(nil), failStart},
},
{
c: cnf,
opts: []RunOption{WithUnsignedCallback(nil), failStart},
},
}
for _, test := range tests {
srv, err := NewServer(test.c, test.opts...)
if test.fail {
assert.NotNil(t, err)
}
if err != nil {
continue
}
srv.Use(ToMiddleware(func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
next.ServeHTTP(w, r)
})
}))
srv.AddRoute(Route{
Method: http.MethodGet,
Path: "/",
Handler: nil,
}, WithJwt("thesecret"), WithSignature(SignatureConf{}),
WithJwtTransition("preivous", "thenewone"))
srv.Start()
srv.Stop()
}
}
func TestWithMiddleware(t *testing.T) {
m := make(map[string]string)
router := router.NewRouter()
rt := router.NewRouter()
handler := func(w http.ResponseWriter, r *http.Request) {
var v struct {
Nickname string `form:"nickname"`
@@ -56,14 +122,14 @@ func TestWithMiddleware(t *testing.T) {
"http://hello.com/second/wan/2020?nickname=whatever&zipcode=200000",
}
for _, route := range rs {
assert.Nil(t, router.Handle(route.Method, route.Path, route.Handler))
assert.Nil(t, rt.Handle(route.Method, route.Path, route.Handler))
}
for _, url := range urls {
r, err := http.NewRequest(http.MethodGet, url, nil)
assert.Nil(t, err)
rr := httptest.NewRecorder()
router.ServeHTTP(rr, r)
rt.ServeHTTP(rr, r)
assert.Equal(t, "whatever:200000", rr.Body.String())
}
@@ -76,7 +142,7 @@ func TestWithMiddleware(t *testing.T) {
func TestMultiMiddlewares(t *testing.T) {
m := make(map[string]string)
router := router.NewRouter()
rt := router.NewRouter()
handler := func(w http.ResponseWriter, r *http.Request) {
var v struct {
Nickname string `form:"nickname"`
@@ -127,14 +193,14 @@ func TestMultiMiddlewares(t *testing.T) {
"http://hello.com/second/wan/2020?nickname=whatever&zipcode=200000",
}
for _, route := range rs {
assert.Nil(t, router.Handle(route.Method, route.Path, route.Handler))
assert.Nil(t, rt.Handle(route.Method, route.Path, route.Handler))
}
for _, url := range urls {
r, err := http.NewRequest(http.MethodGet, url, nil)
assert.Nil(t, err)
rr := httptest.NewRecorder()
router.ServeHTTP(rr, r)
rt.ServeHTTP(rr, r)
assert.Equal(t, "whatever:200000200000", rr.Body.String())
}

View File

@@ -4,6 +4,7 @@ import (
"bufio"
"errors"
"fmt"
"go/format"
"go/scanner"
"io/ioutil"
"os"
@@ -13,6 +14,7 @@ import (
"github.com/tal-tech/go-zero/core/errorx"
"github.com/tal-tech/go-zero/tools/goctl/api/parser"
"github.com/tal-tech/go-zero/tools/goctl/api/util"
ctlutil "github.com/tal-tech/go-zero/tools/goctl/util"
"github.com/urfave/cli"
)
@@ -103,24 +105,108 @@ func apiFormat(data string) (string, error) {
var builder strings.Builder
s := bufio.NewScanner(strings.NewReader(data))
var tapCount = 0
var newLineCount = 0
var preLine string
for s.Scan() {
line := strings.TrimSpace(s.Text())
if len(line) == 0 {
if newLineCount > 0 {
continue
}
newLineCount++
} else {
if preLine == rightBrace {
builder.WriteString(ctlutil.NL)
}
newLineCount = 0
}
if tapCount == 0 {
format, err := formatGoTypeDef(line, s, &builder)
if err != nil {
return "", err
}
if format {
continue
}
}
noCommentLine := util.RemoveComment(line)
if noCommentLine == rightParenthesis || noCommentLine == rightBrace {
tapCount -= 1
}
if tapCount < 0 {
line = strings.TrimSuffix(line, rightBrace)
line := strings.TrimSuffix(noCommentLine, rightBrace)
line = strings.TrimSpace(line)
if strings.HasSuffix(line, leftBrace) {
tapCount += 1
}
}
util.WriteIndent(&builder, tapCount)
builder.WriteString(line + "\n")
builder.WriteString(line + ctlutil.NL)
if strings.HasSuffix(noCommentLine, leftParenthesis) || strings.HasSuffix(noCommentLine, leftBrace) {
tapCount += 1
}
preLine = line
}
return strings.TrimSpace(builder.String()), nil
}
func formatGoTypeDef(line string, scanner *bufio.Scanner, builder *strings.Builder) (bool, error) {
noCommentLine := util.RemoveComment(line)
tokenCount := 0
if strings.HasPrefix(noCommentLine, "type") && (strings.HasSuffix(noCommentLine, leftParenthesis) ||
strings.HasSuffix(noCommentLine, leftBrace)) {
var typeBuilder strings.Builder
typeBuilder.WriteString(mayInsertStructKeyword(line, &tokenCount) + ctlutil.NL)
for scanner.Scan() {
noCommentLine := util.RemoveComment(scanner.Text())
typeBuilder.WriteString(mayInsertStructKeyword(scanner.Text(), &tokenCount) + ctlutil.NL)
if noCommentLine == rightBrace || noCommentLine == rightParenthesis {
tokenCount--
}
if tokenCount == 0 {
ts, err := format.Source([]byte(typeBuilder.String()))
if err != nil {
return false, errors.New("error format \n" + typeBuilder.String())
}
result := strings.ReplaceAll(string(ts), " struct ", " ")
result = strings.ReplaceAll(result, "type ()", "")
builder.WriteString(result)
break
}
}
return true, nil
}
return false, nil
}
func mayInsertStructKeyword(line string, token *int) string {
insertStruct := func() string {
if strings.Contains(line, " struct") {
return line
}
index := strings.Index(line, leftBrace)
return line[:index] + " struct " + line[index:]
}
noCommentLine := util.RemoveComment(line)
if strings.HasSuffix(noCommentLine, leftBrace) {
*token++
return insertStruct()
}
if strings.HasSuffix(noCommentLine, rightBrace) {
noCommentLine = strings.TrimSuffix(noCommentLine, rightBrace)
noCommentLine = util.RemoveComment(noCommentLine)
if strings.HasSuffix(noCommentLine, leftBrace) {
return insertStruct()
}
}
if strings.HasSuffix(noCommentLine, leftParenthesis) {
*token++
}
return line
}

View File

@@ -24,11 +24,11 @@ handler: GreetHandler
}
`
formattedStr = `type Request struct {
formattedStr = `type Request {
Name string
}
type Response struct {
type Response {
Message string
}
@@ -40,7 +40,7 @@ service A-api {
}`
)
func TestInlineTypeNotExist(t *testing.T) {
func TestFormat(t *testing.T) {
r, err := apiFormat(notFormattedStr)
assert.Nil(t, err)
assert.Equal(t, r, formattedStr)

View File

@@ -38,11 +38,12 @@ func RevertTemplate(name string) error {
return util.CreateTemplate(category, name, content)
}
func Update(category string) error {
func Update() error {
err := Clean()
if err != nil {
return err
}
return util.InitTemplates(category, templates)
}
@@ -50,6 +51,6 @@ func Clean() error {
return util.Clean(category)
}
func GetCategory() string {
func Category() string {
return category
}

View File

@@ -84,7 +84,7 @@ func TestUpdate(t *testing.T) {
assert.Equal(t, string(data), modifyData)
assert.Nil(t, Update(category))
assert.Nil(t, Update())
data, err = ioutil.ReadFile(file)
assert.Nil(t, err)

View File

@@ -1,6 +1,7 @@
package javagen
import (
"errors"
"fmt"
"io"
"path"
@@ -17,6 +18,8 @@ const (
package com.xhb.logic.http.packet.{{.packet}}.model;
import com.xhb.logic.http.DeProguardable;
import org.jetbrains.annotations.NotNull;
import org.jetbrains.annotations.Nullable;
{{.componentType}}
`
@@ -28,7 +31,7 @@ func genComponents(dir, packetName string, api *spec.ApiSpec) error {
return nil
}
for _, ty := range types {
if err := createComponent(dir, packetName, ty); err != nil {
if err := createComponent(dir, packetName, ty, api.Types); err != nil {
return err
}
}
@@ -36,7 +39,7 @@ func genComponents(dir, packetName string, api *spec.ApiSpec) error {
return nil
}
func createComponent(dir, packetName string, ty spec.Type) error {
func createComponent(dir, packetName string, ty spec.Type, types []spec.Type) error {
modelFile := util.Title(ty.Name) + ".java"
filename := path.Join(dir, modelDir, modelFile)
if err := util.RemoveOrQuit(filename); err != nil {
@@ -52,7 +55,7 @@ func createComponent(dir, packetName string, ty spec.Type) error {
}
defer fp.Close()
tys, err := buildType(ty)
tys, err := buildType(ty, types)
if err != nil {
return err
}
@@ -64,22 +67,66 @@ func createComponent(dir, packetName string, ty spec.Type) error {
})
}
func buildType(ty spec.Type) (string, error) {
func buildType(ty spec.Type, types []spec.Type) (string, error) {
var builder strings.Builder
if err := writeType(&builder, ty); err != nil {
if err := writeType(&builder, ty, types); err != nil {
return "", apiutil.WrapErr(err, "Type "+ty.Name+" generate error")
}
return builder.String(), nil
}
func writeType(writer io.Writer, tp spec.Type) error {
func writeType(writer io.Writer, tp spec.Type, types []spec.Type) error {
fmt.Fprintf(writer, "public class %s implements DeProguardable {\n", util.Title(tp.Name))
for _, member := range tp.Members {
if err := writeProperty(writer, member, 1); err != nil {
return err
}
var members []spec.Member
err := writeMembers(writer, types, tp.Members, &members, 1)
if err != nil {
return err
}
genGetSet(writer, members, 1)
fmt.Fprintf(writer, "}")
return nil
}
func writeMembers(writer io.Writer, types []spec.Type, members []spec.Member, allMembers *[]spec.Member, indent int) error {
for _, member := range members {
if !member.IsInline {
_, err := member.GetPropertyName()
if err != nil {
return err
}
}
if !member.IsBodyMember() {
continue
}
for _, item := range *allMembers {
if item.Name == member.Name {
continue
}
}
if member.IsInline {
hasInline := false
for _, ty := range types {
if strings.ToLower(ty.Name) == strings.ToLower(member.Name) {
err := writeMembers(writer, types, ty.Members, allMembers, indent)
if err != nil {
return err
}
hasInline = true
break
}
}
if !hasInline {
return errors.New("inline type " + member.Name + " not exist, please correct api file")
}
} else {
if err := writeProperty(writer, member, indent); err != nil {
return err
}
*allMembers = append(*allMembers, member)
}
}
genGetSet(writer, tp, 1)
fmt.Fprintf(writer, "}\n")
return nil
}

View File

@@ -19,23 +19,27 @@ const packetTemplate = `package com.xhb.logic.http.packet.{{.packet}};
import com.google.gson.Gson;
import com.xhb.commons.JSON;
import com.xhb.commons.JsonParser;
import com.xhb.commons.JsonMarshal;
import com.xhb.core.network.HttpRequestClient;
import com.xhb.core.packet.HttpRequestPacket;
import com.xhb.core.response.HttpResponseData;
import com.xhb.logic.http.DeProguardable;
{{if not .HasRequestBody}}
import com.xhb.logic.http.request.EmptyRequest;
{{end}}
{{.import}}
import org.jetbrains.annotations.NotNull;
import org.jetbrains.annotations.Nullable;
import org.json.JSONObject;
public class {{.packetName}} extends HttpRequestPacket<{{.packetName}}.{{.packetName}}Response> {
{{.paramsDeclaration}}
public {{.packetName}}({{.params}}{{.requestType}} request) {
super(request);
this.request = request;{{.paramsSet}}
public {{.packetName}}({{.params}}{{if .HasRequestBody}}, {{.requestType}} request{{end}}) {
{{if .HasRequestBody}}super(request);{{else}}super(EmptyRequest.instance);{{end}}
{{if .HasRequestBody}}this.request = request;{{end}}{{.paramsSet}}
}
@Override
@@ -113,7 +117,8 @@ func createWith(dir string, api *spec.ApiSpec, route spec.Route, packetName stri
} else {
fmt.Fprintln(&builder)
}
if err := genType(&builder, tp); err != nil {
if err := genType(&builder, tp, api.Types); err != nil {
return err
}
}
@@ -126,7 +131,7 @@ func createWith(dir string, api *spec.ApiSpec, route spec.Route, packetName stri
t := template.Must(template.New("packetTemplate").Parse(packetTemplate))
var tmplBytes bytes.Buffer
err = t.Execute(&tmplBytes, map[string]string{
err = t.Execute(&tmplBytes, map[string]interface{}{
"packetName": packet,
"method": strings.ToUpper(route.Method),
"uri": processUri(route),
@@ -137,6 +142,7 @@ func createWith(dir string, api *spec.ApiSpec, route spec.Route, packetName stri
"paramsSet": paramsSet,
"packet": packetName,
"requestType": util.Title(route.RequestType.Name),
"HasRequestBody": len(route.RequestType.GetBodyMembers()) > 0,
"import": getImports(api, route, packetName),
})
if err != nil {
@@ -209,7 +215,7 @@ func paramsForRoute(route spec.Route) string {
builder.WriteString(fmt.Sprintf("String %s, ", cop[1:]))
}
}
return builder.String()
return strings.TrimSuffix(builder.String(), ", ")
}
func declarationForRoute(route spec.Route) string {
@@ -260,18 +266,22 @@ func processUri(route spec.Route) string {
return result
}
func genType(writer io.Writer, tp spec.Type) error {
writeIndent(writer, 1)
fmt.Fprintf(writer, "static class %s implements DeProguardable {\n", util.Title(tp.Name))
for _, member := range tp.Members {
if err := writeProperty(writer, member, 2); err != nil {
return err
}
func genType(writer io.Writer, tp spec.Type, types []spec.Type) error {
if len(tp.GetBodyMembers()) == 0 {
return nil
}
writeBreakline(writer)
writeIndent(writer, 1)
genGetSet(writer, tp, 2)
fmt.Fprintf(writer, "static class %s implements DeProguardable {\n", util.Title(tp.Name))
var members []spec.Member
err := writeMembers(writer, types, tp.Members, &members, 2)
if err != nil {
return err
}
writeNewline(writer)
writeIndent(writer, 1)
genGetSet(writer, members, 2)
writeIndent(writer, 1)
fmt.Fprintln(writer, "}")

View File

@@ -67,8 +67,8 @@ func indentString(indent int) string {
return result
}
func writeBreakline(writer io.Writer) {
fmt.Fprint(writer, "\n")
func writeNewline(writer io.Writer) {
fmt.Fprint(writer, util.NL)
}
func isPrimitiveType(tp string) bool {
@@ -87,6 +87,7 @@ func goTypeToJava(tp string) (string, error) {
if len(tp) == 0 {
return "", errors.New("property type empty")
}
if strings.HasPrefix(tp, "*") {
tp = tp[1:]
}
@@ -107,39 +108,44 @@ func goTypeToJava(tp string) (string, error) {
if err != nil {
return "", err
}
if len(tys) == 0 {
return "", fmt.Errorf("%s tp parse error", tp)
}
return fmt.Sprintf("java.util.ArrayList<%s>", util.Title(tys[0])), nil
} else if strings.HasPrefix(tp, "map") {
tys, err := apiutil.DecomposeType(tp)
if err != nil {
return "", err
}
if len(tys) == 2 {
return "", fmt.Errorf("%s tp parse error", tp)
}
return fmt.Sprintf("java.util.HashMap<String, %s>", util.Title(tys[1])), nil
}
return util.Title(tp), nil
}
func genGetSet(writer io.Writer, tp spec.Type, indent int) error {
func genGetSet(writer io.Writer, members []spec.Member, indent int) error {
t := template.Must(template.New("getSetTemplate").Parse(getSetTemplate))
for _, member := range tp.Members {
for _, member := range members {
var tmplBytes bytes.Buffer
oty, err := goTypeToJava(member.Type)
if err != nil {
return err
}
tyString := oty
decorator := ""
if !isPrimitiveType(member.Type) {
if member.IsOptional() {
decorator = "@org.jetbrains.annotations.Nullable "
decorator = "@Nullable "
} else {
decorator = "@org.jetbrains.annotations.NotNull "
decorator = "@NotNull "
}
tyString = decorator + tyString
}
@@ -155,6 +161,7 @@ func genGetSet(writer io.Writer, tp spec.Type, indent int) error {
if err != nil {
return err
}
r := tmplBytes.String()
r = strings.Replace(r, " boolean get", " boolean is", 1)
writer.Write([]byte(r))

View File

@@ -63,10 +63,6 @@ func (m Member) IsOmitempty() bool {
func (m Member) GetPropertyName() (string, error) {
tags := m.Tags()
if len(tags) == 0 {
return "", errors.New("json property name not exist, member: " + m.Name)
}
for _, tag := range tags {
if stringx.Contains(definedKeys, tag.Key) {
if tag.Name == "-" {

View File

@@ -85,7 +85,7 @@ func genHandler(dir, webApi, caller string, api *spec.ApiSpec, unwrapApi bool) e
imports += fmt.Sprintf(`import * as components from "%s"`, "./"+outputFile)
}
apis, err := genApi(api, localTypes, caller, prefixForType)
apis, err := genApi(api, caller, prefixForType)
if err != nil {
return err
}
@@ -119,32 +119,34 @@ func genTypes(localTypes []spec.Type, inlineType func(string) (*spec.Type, error
return types, nil
}
func genApi(api *spec.ApiSpec, localTypes []spec.Type, caller string, prefixForType func(string) string) (string, error) {
func genApi(api *spec.ApiSpec, caller string, prefixForType func(string) string) (string, error) {
var builder strings.Builder
for _, route := range api.Service.Routes() {
handler, ok := apiutil.GetAnnotationValue(route.Annotations, "server", "handler")
if !ok {
return "", fmt.Errorf("missing handler annotation for route %q", route.Path)
}
handler = util.Untitle(handler)
handler = strings.Replace(handler, "Handler", "", 1)
comment := commentForRoute(route)
if len(comment) > 0 {
fmt.Fprintf(&builder, "%s\n", comment)
}
fmt.Fprintf(&builder, "export function %s(%s) {\n", handler, paramsForRoute(route, prefixForType))
writeIndent(&builder, 1)
responseGeneric := "<null>"
if len(route.ResponseType.Name) > 0 {
val, err := goTypeToTs(route.ResponseType.Name, prefixForType)
if err != nil {
return "", err
for _, group := range api.Service.Groups {
for _, route := range group.Routes {
handler, ok := apiutil.GetAnnotationValue(route.Annotations, "server", "handler")
if !ok {
return "", fmt.Errorf("missing handler annotation for route %q", route.Path)
}
responseGeneric = fmt.Sprintf("<%s>", val)
handler = util.Untitle(handler)
handler = strings.Replace(handler, "Handler", "", 1)
comment := commentForRoute(route)
if len(comment) > 0 {
fmt.Fprintf(&builder, "%s\n", comment)
}
fmt.Fprintf(&builder, "export function %s(%s) {\n", handler, paramsForRoute(route, prefixForType))
writeIndent(&builder, 1)
responseGeneric := "<null>"
if len(route.ResponseType.Name) > 0 {
val, err := goTypeToTs(route.ResponseType.Name, prefixForType)
if err != nil {
return "", err
}
responseGeneric = fmt.Sprintf("<%s>", val)
}
fmt.Fprintf(&builder, `return %s.%s%s(%s)`, caller, strings.ToLower(route.Method),
util.Title(responseGeneric), callParamsForRoute(route, group))
builder.WriteString("\n}\n\n")
}
fmt.Fprintf(&builder, `return %s.%s%s(%s)`, caller, strings.ToLower(route.Method),
util.Title(responseGeneric), callParamsForRoute(route))
builder.WriteString("\n}\n\n")
}
apis := builder.String()
@@ -188,21 +190,28 @@ func commentForRoute(route spec.Route) string {
return builder.String()
}
func callParamsForRoute(route spec.Route) string {
func callParamsForRoute(route spec.Route, group spec.Group) string {
hasParams := pathHasParams(route)
hasBody := hasRequestBody(route)
if hasParams && hasBody {
return fmt.Sprintf("%s, %s, %s", pathForRoute(route), "params", "req")
return fmt.Sprintf("%s, %s, %s", pathForRoute(route, group), "params", "req")
} else if hasParams {
return fmt.Sprintf("%s, %s", pathForRoute(route), "params")
return fmt.Sprintf("%s, %s", pathForRoute(route, group), "params")
} else if hasBody {
return fmt.Sprintf("%s, %s", pathForRoute(route), "req")
return fmt.Sprintf("%s, %s", pathForRoute(route, group), "req")
}
return pathForRoute(route)
return pathForRoute(route, group)
}
func pathForRoute(route spec.Route) string {
return "\"" + route.Path + "\""
func pathForRoute(route spec.Route, group spec.Group) string {
value, ok := apiutil.GetAnnotationValue(group.Annotations, "server", pathPrefix)
if !ok {
return "\"" + route.Path + "\""
} else {
value = strings.TrimPrefix(value, `"`)
value = strings.TrimSuffix(value, `"`)
return fmt.Sprintf(`"%s/%s"`, value, strings.TrimPrefix(route.Path, "/"))
}
}
func pathHasParams(route spec.Route) bool {

View File

@@ -2,4 +2,5 @@ package tsgen
const (
packagePrefix = "components."
pathPrefix = "pathPrefix"
)

View File

@@ -9,15 +9,17 @@ import (
"text/template"
"time"
"github.com/logrusorgru/aurora"
"github.com/tal-tech/go-zero/tools/goctl/util"
ctlutil "github.com/tal-tech/go-zero/tools/goctl/util"
"github.com/urfave/cli"
)
const (
etcDir = "etc"
yamlEtx = ".yaml"
cstOffset = 60 * 60 * 8 // 8 hours offset for Chinese Standard Time
dockerfileName = "Dockerfile"
etcDir = "etc"
yamlEtx = ".yaml"
cstOffset = 60 * 60 * 8 // 8 hours offset for Chinese Standard Time
)
type Docker struct {
@@ -25,10 +27,18 @@ type Docker struct {
GoRelPath string
GoFile string
ExeFile string
HasPort bool
Port int
Argument string
}
func DockerCommand(c *cli.Context) error {
func DockerCommand(c *cli.Context) (err error) {
defer func() {
if err == nil {
fmt.Println(aurora.Green("Done."))
}
}()
goFile := c.String("go")
if len(goFile) == 0 {
return errors.New("-go can't be empty")
@@ -38,8 +48,9 @@ func DockerCommand(c *cli.Context) error {
return fmt.Errorf("file %q not found", goFile)
}
port := c.Int("port")
if _, err := os.Stat(etcDir); os.IsNotExist(err) {
return generateDockerfile(goFile)
return generateDockerfile(goFile, port)
}
cfg, err := findConfig(goFile, etcDir)
@@ -47,13 +58,13 @@ func DockerCommand(c *cli.Context) error {
return err
}
if err := generateDockerfile(goFile, "-f", "etc/"+cfg); err != nil {
if err := generateDockerfile(goFile, port, "-f", "etc/"+cfg); err != nil {
return err
}
projDir, ok := util.FindProjectPath(goFile)
if ok {
fmt.Printf("Run \"docker build ...\" command in dir %q\n", projDir)
fmt.Printf("Hint: run \"docker build ...\" command in dir %q\n", projDir)
}
return nil
@@ -88,18 +99,22 @@ func findConfig(file, dir string) (string, error) {
return files[0], nil
}
func generateDockerfile(goFile string, args ...string) error {
func generateDockerfile(goFile string, port int, args ...string) error {
projPath, err := getFilePath(filepath.Dir(goFile))
if err != nil {
return err
}
pos := strings.IndexByte(projPath, '/')
if pos >= 0 {
projPath = projPath[pos+1:]
if len(projPath) == 0 {
projPath = "."
} else {
pos := strings.IndexByte(projPath, os.PathSeparator)
if pos >= 0 {
projPath = projPath[pos+1:]
}
}
out, err := util.CreateIfNotExist("Dockerfile")
out, err := util.CreateIfNotExist(dockerfileName)
if err != nil {
return err
}
@@ -122,6 +137,8 @@ func generateDockerfile(goFile string, args ...string) error {
GoRelPath: projPath,
GoFile: goFile,
ExeFile: util.FileNameWithoutExt(filepath.Base(goFile)),
HasPort: port > 0,
Port: port,
Argument: builder.String(),
})
}

View File

@@ -14,34 +14,59 @@ LABEL stage=gobuilder
ENV CGO_ENABLED 0
ENV GOOS linux
{{if .Chinese}}ENV GOPROXY https://goproxy.cn,direct{{end}}
{{if .Chinese}}ENV GOPROXY https://goproxy.cn,direct
{{end}}
WORKDIR /build/zero
ADD go.mod .
ADD go.sum .
RUN go mod download
COPY . .
COPY {{.GoRelPath}}/etc /app/etc
RUN go build -ldflags="-s -w" -o /app/{{.ExeFile}} {{.GoRelPath}}/{{.GoFile}}
{{if .Argument}}COPY {{.GoRelPath}}/etc /app/etc
{{end}}RUN go build -ldflags="-s -w" -o /app/{{.ExeFile}} {{.GoRelPath}}/{{.GoFile}}
FROM alpine
RUN apk update --no-cache
RUN apk add --no-cache ca-certificates
RUN apk add --no-cache tzdata
RUN apk update --no-cache && apk add --no-cache ca-certificates tzdata
ENV TZ Asia/Shanghai
WORKDIR /app
COPY --from=builder /app/{{.ExeFile}} /app/{{.ExeFile}}
COPY --from=builder /app/etc /app/etc
COPY --from=builder /app/{{.ExeFile}} /app/{{.ExeFile}}{{if .Argument}}
COPY --from=builder /app/etc /app/etc{{end}}
{{if .HasPort}}
EXPOSE {{.Port}}
{{end}}
CMD ["./{{.ExeFile}}"{{.Argument}}]
`
)
func Clean() error {
return util.Clean(category)
}
func GenTemplates(_ *cli.Context) error {
return initTemplate()
}
func Category() string {
return category
}
func RevertTemplate(name string) error {
return util.CreateTemplate(category, name, dockerTemplate)
}
func Update() error {
err := Clean()
if err != nil {
return err
}
return initTemplate()
}
func initTemplate() error {
return util.InitTemplates(category, map[string]string{
dockerTemplateFile: dockerTemplate,
})

View File

@@ -27,7 +27,7 @@ import (
)
var (
BuildVersion = "20201125"
BuildVersion = "1.1.1"
commands = []cli.Command{
{
Name: "api",
@@ -54,14 +54,12 @@ var (
Usage: "the format target dir",
},
cli.BoolFlag{
Name: "iu",
Usage: "ignore update",
Required: false,
Name: "iu",
Usage: "ignore update",
},
cli.BoolFlag{
Name: "stdin",
Usage: "use stdin to input api doc content, press \"ctrl + d\" to send EOF",
Required: false,
Name: "stdin",
Usage: "use stdin to input api doc content, press \"ctrl + d\" to send EOF",
},
},
Action: format.GoFormatApi,
@@ -101,9 +99,8 @@ var (
Usage: "the api file",
},
cli.StringFlag{
Name: "style",
Required: false,
Usage: "the file naming format, see [https://github.com/tal-tech/go-zero/tree/master/tools/goctl/config/readme.md]",
Name: "style",
Usage: "the file naming format, see [https://github.com/tal-tech/go-zero/tree/master/tools/goctl/config/readme.md]",
},
},
Action: gogen.GoCommand,
@@ -136,19 +133,16 @@ var (
Usage: "the api file",
},
cli.StringFlag{
Name: "webapi",
Usage: "the web api file path",
Required: false,
Name: "webapi",
Usage: "the web api file path",
},
cli.StringFlag{
Name: "caller",
Usage: "the web api caller",
Required: false,
Name: "caller",
Usage: "the web api caller",
},
cli.BoolFlag{
Name: "unwrap",
Usage: "unwrap the webapi caller for import",
Required: false,
Name: "unwrap",
Usage: "unwrap the webapi caller for import",
},
},
Action: tsgen.TsCommand,
@@ -204,9 +198,8 @@ var (
Usage: "the api file",
},
cli.StringFlag{
Name: "style",
Required: false,
Usage: "the file naming format, see [https://github.com/tal-tech/go-zero/tree/master/tools/goctl/config/readme.md]",
Name: "style",
Usage: "the file naming format, see [https://github.com/tal-tech/go-zero/tree/master/tools/goctl/config/readme.md]",
},
},
Action: plugin.PluginCommand,
@@ -221,6 +214,11 @@ var (
Name: "go",
Usage: "the file that contains main function",
},
cli.IntFlag{
Name: "port",
Usage: "the port to expose, default none",
Value: 0,
},
},
Action: docker.DockerCommand,
},
@@ -248,9 +246,8 @@ var (
Required: true,
},
cli.StringFlag{
Name: "secret",
Usage: "the image pull secret",
Required: true,
Name: "secret",
Usage: "the secret to image pull from registry",
},
cli.IntFlag{
Name: "requestCpu",
@@ -321,9 +318,8 @@ var (
Usage: `generate rpc demo service`,
Flags: []cli.Flag{
cli.StringFlag{
Name: "style",
Required: false,
Usage: "the file naming format, see [https://github.com/tal-tech/go-zero/tree/master/tools/goctl/config/readme.md]",
Name: "style",
Usage: "the file naming format, see [https://github.com/tal-tech/go-zero/tree/master/tools/goctl/config/readme.md]",
},
cli.BoolFlag{
Name: "idea",
@@ -360,9 +356,8 @@ var (
Usage: `the target path of the code`,
},
cli.StringFlag{
Name: "style",
Required: false,
Usage: "the file naming format, see [https://github.com/tal-tech/go-zero/tree/master/tools/goctl/config/readme.md]",
Name: "style",
Usage: "the file naming format, see [https://github.com/tal-tech/go-zero/tree/master/tools/goctl/config/readme.md]",
},
cli.BoolFlag{
Name: "idea",
@@ -394,9 +389,8 @@ var (
Usage: "the target dir",
},
cli.StringFlag{
Name: "style",
Required: false,
Usage: "the file naming format, see [https://github.com/tal-tech/go-zero/tree/master/tools/goctl/config/readme.md]",
Name: "style",
Usage: "the file naming format, see [https://github.com/tal-tech/go-zero/tree/master/tools/goctl/config/readme.md]",
},
cli.BoolFlag{
Name: "cache, c",
@@ -430,9 +424,8 @@ var (
Usage: "the target dir",
},
cli.StringFlag{
Name: "style",
Required: false,
Usage: "the file naming format, see [https://github.com/tal-tech/go-zero/tree/master/tools/goctl/config/readme.md]",
Name: "style",
Usage: "the file naming format, see [https://github.com/tal-tech/go-zero/tree/master/tools/goctl/config/readme.md]",
},
cli.BoolFlag{
Name: "idea",
@@ -476,7 +469,7 @@ var (
Flags: []cli.Flag{
cli.StringFlag{
Name: "category,c",
Usage: "the category of template, enum [api,rpc,model]",
Usage: "the category of template, enum [api,rpc,model,docker,kube]",
},
},
Action: tpl.UpdateTemplates,
@@ -487,7 +480,7 @@ var (
Flags: []cli.Flag{
cli.StringFlag{
Name: "category,c",
Usage: "the category of template, enum [api,rpc,model]",
Usage: "the category of template, enum [api,rpc,model,docker,kube]",
},
cli.StringFlag{
Name: "name,n",

View File

@@ -47,9 +47,9 @@ spec:
volumeMounts:
- name: timezone
mountPath: /etc/localtime
imagePullSecrets:
{{if .Secret}}imagePullSecrets:
- name: {{.Secret}}
volumes:
{{end}}volumes:
- name: timezone
hostPath:
path: /usr/share/zoneinfo/Asia/Shanghai

View File

@@ -2,8 +2,10 @@ package kube
import (
"errors"
"fmt"
"text/template"
"github.com/logrusorgru/aurora"
"github.com/tal-tech/go-zero/tools/goctl/util"
"github.com/urfave/cli"
)
@@ -16,47 +18,23 @@ const (
portLimit = 32767
)
var errUnknownServiceType = errors.New("unknown service type")
type (
ServiceType string
KubeRequest struct {
Env string
ServiceName string
ServiceType ServiceType
Namespace string
Schedule string
Replicas int
RevisionHistoryLimit int
Port int
LimitCpu int
LimitMem int
RequestCpu int
RequestMem int
SuccessfulJobsHistoryLimit int
HpaMinReplicas int
HpaMaxReplicas int
}
Deployment struct {
Name string
Namespace string
Image string
Secret string
Replicas int
Revisions int
Port int
NodePort int
UseNodePort bool
RequestCpu int
RequestMem int
LimitCpu int
LimitMem int
MinReplicas int
MaxReplicas int
}
)
type Deployment struct {
Name string
Namespace string
Image string
Secret string
Replicas int
Revisions int
Port int
NodePort int
UseNodePort bool
RequestCpu int
RequestMem int
LimitCpu int
LimitMem int
MinReplicas int
MaxReplicas int
}
func DeploymentCommand(c *cli.Context) error {
nodePort := c.Int("nodePort")
@@ -77,7 +55,7 @@ func DeploymentCommand(c *cli.Context) error {
defer out.Close()
t := template.Must(template.New("deploymentTemplate").Parse(text))
return t.Execute(out, Deployment{
err = t.Execute(out, Deployment{
Name: c.String("name"),
Namespace: c.String("namespace"),
Image: c.String("image"),
@@ -94,6 +72,20 @@ func DeploymentCommand(c *cli.Context) error {
MinReplicas: c.Int("minReplicas"),
MaxReplicas: c.Int("maxReplicas"),
})
if err != nil {
return err
}
fmt.Println(aurora.Green("Done."))
return nil
}
func Category() string {
return category
}
func Clean() error {
return util.Clean(category)
}
func GenTemplates(_ *cli.Context) error {
@@ -102,3 +94,19 @@ func GenTemplates(_ *cli.Context) error {
jobTemplateFile: jobTmeplate,
})
}
func RevertTemplate(name string) error {
return util.CreateTemplate(category, name, deploymentTemplate)
}
func Update() error {
err := Clean()
if err != nil {
return err
}
return util.InitTemplates(category, map[string]string{
deployTemplateFile: deploymentTemplate,
jobTemplateFile: jobTmeplate,
})
}

View File

@@ -61,9 +61,9 @@ func FieldNames(in interface{}) []string {
// gets us a StructField
fi := typ.Field(i)
if tagv := fi.Tag.Get(dbTag); tagv != "" {
out = append(out, tagv)
out = append(out, fmt.Sprintf("`%v`", tagv))
} else {
out = append(out, fi.Name)
out = append(out, fmt.Sprintf("`%v`", fi.Name))
}
}
return out

View File

@@ -28,8 +28,7 @@ var userFields = FieldNames(User{})
func TestFieldNames(t *testing.T) {
var u User
out := FieldNames(&u)
fmt.Println(out)
actual := []string{"id", "user_name", "sex", "uuid", "age"}
actual := []string{"`id`", "`user_name`", "`sex`", "`uuid`", "`age`"}
assert.Equal(t, out, actual)
}
@@ -54,7 +53,7 @@ func TestBuilderSql(t *testing.T) {
sql, args, err := builder.Select(fields...).From("user").Where(eq).ToSQL()
fmt.Println(sql, args, err)
actualSql := "SELECT id,user_name,sex,uuid,age FROM user WHERE id=?"
actualSql := "SELECT `id`,`user_name`,`sex`,`uuid`,`age` FROM user WHERE id=?"
actualArgs := []interface{}{"123123"}
assert.Equal(t, sql, actualSql)
assert.Equal(t, args, actualArgs)
@@ -68,7 +67,7 @@ func TestBuildSqlDefaultValue(t *testing.T) {
sql, args, err := builder.Select(userFields...).From("user").Where(eq).ToSQL()
fmt.Println(sql, args, err)
actualSql := "SELECT id,user_name,sex,uuid,age FROM user WHERE age=? AND user_name=?"
actualSql := "SELECT `id`,`user_name`,`sex`,`uuid`,`age` FROM user WHERE age=? AND user_name=?"
actualArgs := []interface{}{0, ""}
assert.Equal(t, sql, actualSql)
assert.Equal(t, args, actualArgs)
@@ -83,7 +82,7 @@ func TestBuilderSqlIn(t *testing.T) {
sql, args, err := builder.Select(userFields...).From("user").Where(in).And(gtU).ToSQL()
fmt.Println(sql, args, err)
actualSql := "SELECT id,user_name,sex,uuid,age FROM user WHERE id IN (?,?,?) AND age>?"
actualSql := "SELECT `id`,`user_name`,`sex`,`uuid`,`age` FROM user WHERE id IN (?,?,?) AND age>?"
actualArgs := []interface{}{"1", "2", "3", 18}
assert.Equal(t, sql, actualSql)
assert.Equal(t, args, actualArgs)
@@ -94,7 +93,7 @@ func TestBuildSqlLike(t *testing.T) {
sql, args, err := builder.Select(userFields...).From("user").Where(like).ToSQL()
fmt.Println(sql, args, err)
actualSql := "SELECT id,user_name,sex,uuid,age FROM user WHERE name LIKE ?"
actualSql := "SELECT `id`,`user_name`,`sex`,`uuid`,`age` FROM user WHERE name LIKE ?"
actualArgs := []interface{}{"%wang%"}
assert.Equal(t, sql, actualSql)
assert.Equal(t, args, actualArgs)

View File

@@ -1,8 +1,13 @@
#!/bin/bash
# generate model with cache from ddl
fromDDL:
goctl model mysql ddl -src="./sql/*.sql" -dir="./sql/model/user" -cache
fromDDLWithCache:
goctl template clean;
goctl model mysql ddl -src="./sql/*.sql" -dir="./sql/model/cache/user" -cache;
fromDDLWithoutCache:
goctl template clean;
goctl model mysql ddl -src="./sql/*.sql" -dir="./sql/model/nocache/user";
# generate model with cache from data source
@@ -12,4 +17,5 @@ datasource=127.0.0.1:3306
database=gozero
fromDataSource:
goctl model mysql datasource -url="$(user):$(password)@tcp($(datasource))/$(database)" -table="*" -dir ./model/cache -c -style gozero
goctl template clean;
goctl model mysql datasource -url="$(user):$(password)@tcp($(datasource))/$(database)" -table="*" -dir ./model/cache -c -style gozero;

View File

@@ -36,7 +36,7 @@ func genDelete(table Table, withCache bool) (string, string, error) {
"lowerStartCamelPrimaryKey": stringx.From(table.PrimaryKey.Name.ToCamel()).Untitle(),
"dataType": table.PrimaryKey.DataType,
"keys": strings.Join(keySet.KeysStr(), "\n"),
"originalPrimaryKey": table.PrimaryKey.Name.Source(),
"originalPrimaryKey": wrapWithRawString(table.PrimaryKey.Name.Source()),
"keyValues": strings.Join(keyVariableSet.KeysStr(), ", "),
})
if err != nil {

View File

@@ -19,7 +19,7 @@ func genFindOne(table Table, withCache bool) (string, string, error) {
"withCache": withCache,
"upperStartCamelObject": camel,
"lowerStartCamelObject": stringx.From(camel).Untitle(),
"originalPrimaryKey": table.PrimaryKey.Name.Source(),
"originalPrimaryKey": wrapWithRawString(table.PrimaryKey.Name.Source()),
"lowerStartCamelPrimaryKey": stringx.From(table.PrimaryKey.Name.ToCamel()).Untitle(),
"dataType": table.PrimaryKey.DataType,
"cacheKey": table.CacheKey[table.PrimaryKey.Name.Source()].KeyExpression,

View File

@@ -39,7 +39,7 @@ func genFindOneByField(table Table, withCache bool) (*findOneCode, error) {
"lowerStartCamelObject": stringx.From(camelTableName).Untitle(),
"lowerStartCamelField": stringx.From(camelFieldName).Untitle(),
"upperStartCamelPrimaryKey": table.PrimaryKey.Name.ToCamel(),
"originalField": field.Name.Source(),
"originalField": wrapWithRawString(field.Name.Source()),
})
if err != nil {
return nil, err
@@ -82,7 +82,7 @@ func genFindOneByField(table Table, withCache bool) (*findOneCode, error) {
"upperStartCamelObject": camelTableName,
"primaryKeyLeft": table.CacheKey[table.PrimaryKey.Name.Source()].Left,
"lowerStartCamelObject": stringx.From(camelTableName).Untitle(),
"originalPrimaryField": table.PrimaryKey.Name.Source(),
"originalPrimaryField": wrapWithRawString(table.PrimaryKey.Name.Source()),
})
if err != nil {
return nil, err

View File

@@ -21,9 +21,6 @@ import (
const (
pwd = "."
createTableFlag = `(?m)^(?i)CREATE\s+TABLE` // ignore case
NamingLower = "lower"
NamingCamel = "camel"
NamingSnake = "snake"
)
type (
@@ -280,3 +277,20 @@ func (g *defaultGenerator) genModel(in parser.Table, withCache bool) (string, er
return output.String(), nil
}
func wrapWithRawString(v string) string {
if v == "`" {
return v
}
if !strings.HasPrefix(v, "`") {
v = "`" + v
}
if !strings.HasSuffix(v, "`") {
v = v + "`"
} else if len(v) == 1 {
v = v + "`"
}
return v
}

View File

@@ -1,13 +1,18 @@
package gen
import (
"database/sql"
"os"
"path/filepath"
"strings"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/tal-tech/go-zero/core/logx"
"github.com/tal-tech/go-zero/core/stringx"
"github.com/tal-tech/go-zero/tools/goctl/config"
"github.com/tal-tech/go-zero/tools/goctl/model/sql/builderx"
)
var (
@@ -79,3 +84,32 @@ func TestNamingModel(t *testing.T) {
return err == nil
}())
}
func TestWrapWithRawString(t *testing.T) {
assert.Equal(t, "``", wrapWithRawString(""))
assert.Equal(t, "``", wrapWithRawString("``"))
assert.Equal(t, "`a`", wrapWithRawString("a"))
assert.Equal(t, "` `", wrapWithRawString(" "))
}
func TestFields(t *testing.T) {
type Student struct {
Id int64 `db:"id"`
Name string `db:"name"`
Age sql.NullInt64 `db:"age"`
Score sql.NullFloat64 `db:"score"`
CreateTime time.Time `db:"create_time"`
UpdateTime sql.NullTime `db:"update_time"`
}
var (
studentFieldNames = builderx.FieldNames(&Student{})
studentRows = strings.Join(studentFieldNames, ",")
studentRowsExpectAutoSet = strings.Join(stringx.Remove(studentFieldNames, "`id`", "`create_time`", "`update_time`"), ",")
studentRowsWithPlaceHolder = strings.Join(stringx.Remove(studentFieldNames, "`id`", "`create_time`", "`update_time`"), "=?,") + "=?"
)
assert.Equal(t, []string{"`id`", "`name`", "`age`", "`score`", "`create_time`", "`update_time`"}, studentFieldNames)
assert.Equal(t, "`id`,`name`,`age`,`score`,`create_time`,`update_time`", studentRows)
assert.Equal(t, "`name`,`age`,`score`", studentRowsExpectAutoSet)
assert.Equal(t, "`name`=?,`age`=?,`score`=?", studentRowsWithPlaceHolder)
}

View File

@@ -14,7 +14,7 @@ func genNew(table Table, withCache bool) (string, error) {
output, err := util.With("new").
Parse(text).
Execute(map[string]interface{}{
"table": table.Name.Source(),
"table": wrapWithRawString(table.Name.Source()),
"withCache": withCache,
"upperStartCamelObject": table.Name.ToCamel(),
})

View File

@@ -54,6 +54,14 @@ var templates = map[string]string{
errTemplateFile: template.Error,
}
func Category() string {
return category
}
func Clean() error {
return util.Clean(category)
}
func GenTemplates(_ *cli.Context) error {
return util.InitTemplates(category, templates)
}
@@ -66,18 +74,10 @@ func RevertTemplate(name string) error {
return util.CreateTemplate(category, name, content)
}
func Clean() error {
return util.Clean(category)
}
func Update(category string) error {
func Update() error {
err := Clean()
if err != nil {
return err
}
return util.InitTemplates(category, templates)
}
func GetCategory() string {
return category
}

View File

@@ -85,7 +85,7 @@ func TestUpdate(t *testing.T) {
assert.Equal(t, string(data), modifyData)
assert.Nil(t, Update(category))
assert.Nil(t, Update())
data, err = ioutil.ReadFile(file)
assert.Nil(t, err)

View File

@@ -35,7 +35,7 @@ func genUpdate(table Table, withCache bool) (string, string, error) {
"primaryCacheKey": table.CacheKey[table.PrimaryKey.Name.Source()].DataKeyExpression,
"primaryKeyVariable": table.CacheKey[table.PrimaryKey.Name.Source()].Variable,
"lowerStartCamelObject": stringx.From(camelTableName).Untitle(),
"originalPrimaryKey": table.PrimaryKey.Name.Source(),
"originalPrimaryKey": wrapWithRawString(table.PrimaryKey.Name.Source()),
"expressionValues": strings.Join(expressionValues, ", "),
})
if err != nil {

View File

@@ -27,7 +27,7 @@ func genVars(table Table, withCache bool) (string, error) {
"upperStartCamelObject": camel,
"cacheKeys": strings.Join(keys, "\n"),
"autoIncrement": table.PrimaryKey.AutoIncrement,
"originalPrimaryKey": table.PrimaryKey.Name.Source(),
"originalPrimaryKey": wrapWithRawString(table.PrimaryKey.Name.Source()),
"withCache": withCache,
})
if err != nil {

View File

@@ -1,34 +0,0 @@
package model
import (
"github.com/tal-tech/go-zero/core/stores/sqlx"
)
type (
DDLModel struct {
conn sqlx.SqlConn
}
DDL struct {
Table string `db:"Table"`
DDL string `db:"Create Table"`
}
)
func NewDDLModel(conn sqlx.SqlConn) *DDLModel {
return &DDLModel{conn: conn}
}
func (m *DDLModel) ShowDDL(table ...string) ([]string, error) {
var ddl []string
for _, t := range table {
query := `show create table ` + t
var resp DDL
err := m.conn.QueryRow(&resp, query)
if err != nil {
return nil, err
}
ddl = append(ddl, resp.DDL)
}
return ddl, nil
}

View File

@@ -1,12 +1,14 @@
package template
var Vars = `
import "fmt"
var Vars = fmt.Sprintf(`
var (
{{.lowerStartCamelObject}}FieldNames = builderx.FieldNames(&{{.upperStartCamelObject}}{})
{{.lowerStartCamelObject}}Rows = strings.Join({{.lowerStartCamelObject}}FieldNames, ",")
{{.lowerStartCamelObject}}RowsExpectAutoSet = strings.Join(stringx.Remove({{.lowerStartCamelObject}}FieldNames, {{if .autoIncrement}}"{{.originalPrimaryKey}}",{{end}} "create_time", "update_time"), ",")
{{.lowerStartCamelObject}}RowsWithPlaceHolder = strings.Join(stringx.Remove({{.lowerStartCamelObject}}FieldNames, "{{.originalPrimaryKey}}", "create_time", "update_time"), "=?,") + "=?"
{{.lowerStartCamelObject}}RowsExpectAutoSet = strings.Join(stringx.Remove({{.lowerStartCamelObject}}FieldNames, {{if .autoIncrement}}"{{.originalPrimaryKey}}",{{end}} "%screate_time%s", "%supdate_time%s"), ",")
{{.lowerStartCamelObject}}RowsWithPlaceHolder = strings.Join(stringx.Remove({{.lowerStartCamelObject}}FieldNames, "{{.originalPrimaryKey}}", "%screate_time%s", "%supdate_time%s"), "=?,") + "=?"
{{if .withCache}}{{.cacheKeys}}{{end}}
)
`
`, "`", "`", "`", "`", "`", "`", "`", "`")

View File

@@ -0,0 +1,235 @@
package model
import (
"database/sql"
"fmt"
"testing"
"time"
"github.com/DATA-DOG/go-sqlmock"
"github.com/stretchr/testify/assert"
"github.com/tal-tech/go-zero/core/stores/cache"
"github.com/tal-tech/go-zero/core/stores/redis"
"github.com/tal-tech/go-zero/core/stores/redis/redistest"
mocksql "github.com/tal-tech/go-zero/tools/goctl/model/sql/test"
)
func TestStudentModel(t *testing.T) {
var (
testTimeValue = time.Now()
testTable = "`student`"
testUpdateName = "gozero1"
testRowsAffected int64 = 1
testInsertId int64 = 1
)
var data Student
data.Id = testInsertId
data.Name = "gozero"
data.Age = sql.NullInt64{
Int64: 1,
Valid: true,
}
data.Score = sql.NullFloat64{
Float64: 100,
Valid: true,
}
data.CreateTime = testTimeValue
data.UpdateTime = sql.NullTime{
Time: testTimeValue,
Valid: true,
}
err := mockStudent(func(mock sqlmock.Sqlmock) {
mock.ExpectExec(fmt.Sprintf("insert into %s", testTable)).
WithArgs(data.Name, data.Age, data.Score).
WillReturnResult(sqlmock.NewResult(testInsertId, testRowsAffected))
}, func(m StudentModel) {
r, err := m.Insert(data)
assert.Nil(t, err)
lastInsertId, err := r.LastInsertId()
assert.Nil(t, err)
assert.Equal(t, testInsertId, lastInsertId)
rowsAffected, err := r.RowsAffected()
assert.Nil(t, err)
assert.Equal(t, testRowsAffected, rowsAffected)
})
assert.Nil(t, err)
err = mockStudent(func(mock sqlmock.Sqlmock) {
mock.ExpectQuery(fmt.Sprintf("select (.+) from %s", testTable)).
WithArgs(testInsertId).
WillReturnRows(sqlmock.NewRows([]string{"id", "name", "age", "score", "create_time", "update_time"}).AddRow(testInsertId, data.Name, data.Age, data.Score, testTimeValue, testTimeValue))
}, func(m StudentModel) {
result, err := m.FindOne(testInsertId)
assert.Nil(t, err)
assert.Equal(t, *result, data)
})
assert.Nil(t, err)
err = mockStudent(func(mock sqlmock.Sqlmock) {
mock.ExpectExec(fmt.Sprintf("update %s", testTable)).WithArgs(testUpdateName, data.Age, data.Score, testInsertId).WillReturnResult(sqlmock.NewResult(testInsertId, testRowsAffected))
}, func(m StudentModel) {
data.Name = testUpdateName
err := m.Update(data)
assert.Nil(t, err)
})
assert.Nil(t, err)
err = mockStudent(func(mock sqlmock.Sqlmock) {
mock.ExpectQuery(fmt.Sprintf("select (.+) from %s ", testTable)).
WithArgs(testInsertId).
WillReturnRows(sqlmock.NewRows([]string{"id", "name", "age", "score", "create_time", "update_time"}).AddRow(testInsertId, data.Name, data.Age, data.Score, testTimeValue, testTimeValue))
}, func(m StudentModel) {
result, err := m.FindOne(testInsertId)
assert.Nil(t, err)
assert.Equal(t, *result, data)
})
assert.Nil(t, err)
err = mockStudent(func(mock sqlmock.Sqlmock) {
mock.ExpectExec(fmt.Sprintf("delete from %s where `id` = ?", testTable)).WithArgs(testInsertId).WillReturnResult(sqlmock.NewResult(testInsertId, testRowsAffected))
}, func(m StudentModel) {
err := m.Delete(testInsertId)
assert.Nil(t, err)
})
assert.Nil(t, err)
}
func TestUserModel(t *testing.T) {
var (
testTimeValue = time.Now()
testTable = "`user`"
testUpdateName = "gozero1"
testUser = "gozero"
testPassword = "test"
testMobile = "test_mobile"
testGender = "男"
testNickname = "test_nickname"
testRowsAffected int64 = 1
testInsertId int64 = 1
)
var data User
data.Id = testInsertId
data.User = testUser
data.Name = "gozero"
data.Password = testPassword
data.Mobile = testMobile
data.Gender = testGender
data.Nickname = testNickname
data.CreateTime = testTimeValue
data.UpdateTime = testTimeValue
err := mockUser(func(mock sqlmock.Sqlmock) {
mock.ExpectExec(fmt.Sprintf("insert into %s", testTable)).
WithArgs(data.User, data.Name, data.Password, data.Mobile, data.Gender, data.Nickname).
WillReturnResult(sqlmock.NewResult(testInsertId, testRowsAffected))
}, func(m UserModel) {
r, err := m.Insert(data)
assert.Nil(t, err)
lastInsertId, err := r.LastInsertId()
assert.Nil(t, err)
assert.Equal(t, testInsertId, lastInsertId)
rowsAffected, err := r.RowsAffected()
assert.Nil(t, err)
assert.Equal(t, testRowsAffected, rowsAffected)
})
assert.Nil(t, err)
err = mockUser(func(mock sqlmock.Sqlmock) {
mock.ExpectQuery(fmt.Sprintf("select (.+) from %s", testTable)).
WithArgs(testInsertId).
WillReturnRows(sqlmock.NewRows([]string{"id", "user", "name", "password", "mobile", "gender", "nickname", "create_time", "update_time"}).AddRow(testInsertId, data.User, data.Name, data.Password, data.Mobile, data.Gender, data.Nickname, testTimeValue, testTimeValue))
}, func(m UserModel) {
result, err := m.FindOne(testInsertId)
assert.Nil(t, err)
assert.Equal(t, *result, data)
})
assert.Nil(t, err)
err = mockUser(func(mock sqlmock.Sqlmock) {
mock.ExpectExec(fmt.Sprintf("update %s", testTable)).WithArgs(data.User, testUpdateName, data.Password, data.Mobile, data.Gender, data.Nickname, testInsertId).WillReturnResult(sqlmock.NewResult(testInsertId, testRowsAffected))
}, func(m UserModel) {
data.Name = testUpdateName
err := m.Update(data)
assert.Nil(t, err)
})
assert.Nil(t, err)
err = mockUser(func(mock sqlmock.Sqlmock) {
mock.ExpectQuery(fmt.Sprintf("select (.+) from %s ", testTable)).
WithArgs(testInsertId).
WillReturnRows(sqlmock.NewRows([]string{"id", "user", "name", "password", "mobile", "gender", "nickname", "create_time", "update_time"}).AddRow(testInsertId, data.User, data.Name, data.Password, data.Mobile, data.Gender, data.Nickname, testTimeValue, testTimeValue))
}, func(m UserModel) {
result, err := m.FindOne(testInsertId)
assert.Nil(t, err)
assert.Equal(t, *result, data)
})
assert.Nil(t, err)
err = mockUser(func(mock sqlmock.Sqlmock) {
mock.ExpectExec(fmt.Sprintf("delete from %s where `id` = ?", testTable)).WithArgs(testInsertId).WillReturnResult(sqlmock.NewResult(testInsertId, testRowsAffected))
}, func(m UserModel) {
err := m.Delete(testInsertId)
assert.Nil(t, err)
})
assert.Nil(t, err)
}
// with cache
func mockStudent(mockFn func(mock sqlmock.Sqlmock), fn func(m StudentModel)) error {
db, mock, err := sqlmock.New()
if err != nil {
return err
}
defer db.Close()
mock.ExpectBegin()
mockFn(mock)
mock.ExpectCommit()
conn := mocksql.NewMockConn(db)
r, clean, err := redistest.CreateRedis()
if err != nil {
return err
}
defer clean()
m := NewStudentModel(conn, cache.CacheConf{
{
RedisConf: redis.RedisConf{
Host: r.Addr,
Type: "node",
},
Weight: 100,
},
})
fn(m)
return nil
}
// without cache
func mockUser(mockFn func(mock sqlmock.Sqlmock), fn func(m UserModel)) error {
db, mock, err := sqlmock.New()
if err != nil {
return err
}
defer db.Close()
mock.ExpectBegin()
mockFn(mock)
mock.ExpectCommit()
conn := mocksql.NewMockConn(db)
m := NewUserModel(conn)
fn(m)
return nil
}

View File

@@ -0,0 +1,105 @@
package model
import (
"database/sql"
"fmt"
"strings"
"time"
"github.com/tal-tech/go-zero/core/stores/cache"
"github.com/tal-tech/go-zero/core/stores/sqlc"
"github.com/tal-tech/go-zero/core/stores/sqlx"
"github.com/tal-tech/go-zero/core/stringx"
"github.com/tal-tech/go-zero/tools/goctl/model/sql/builderx"
)
var (
studentFieldNames = builderx.FieldNames(&Student{})
studentRows = strings.Join(studentFieldNames, ",")
studentRowsExpectAutoSet = strings.Join(stringx.Remove(studentFieldNames, "`id`", "`create_time`", "`update_time`"), ",")
studentRowsWithPlaceHolder = strings.Join(stringx.Remove(studentFieldNames, "`id`", "`create_time`", "`update_time`"), "=?,") + "=?"
cacheStudentIdPrefix = "cache#Student#id#"
)
type (
StudentModel interface {
Insert(data Student) (sql.Result, error)
FindOne(id int64) (*Student, error)
Update(data Student) error
Delete(id int64) error
}
defaultStudentModel struct {
sqlc.CachedConn
table string
}
Student struct {
Id int64 `db:"id"`
Name string `db:"name"`
Age sql.NullInt64 `db:"age"`
Score sql.NullFloat64 `db:"score"`
CreateTime time.Time `db:"create_time"`
UpdateTime sql.NullTime `db:"update_time"`
}
)
func NewStudentModel(conn sqlx.SqlConn, c cache.CacheConf) StudentModel {
return &defaultStudentModel{
CachedConn: sqlc.NewConn(conn, c),
table: "`student`",
}
}
func (m *defaultStudentModel) Insert(data Student) (sql.Result, error) {
query := fmt.Sprintf("insert into %s (%s) values (?, ?, ?)", m.table, studentRowsExpectAutoSet)
ret, err := m.ExecNoCache(query, data.Name, data.Age, data.Score)
return ret, err
}
func (m *defaultStudentModel) FindOne(id int64) (*Student, error) {
studentIdKey := fmt.Sprintf("%s%v", cacheStudentIdPrefix, id)
var resp Student
err := m.QueryRow(&resp, studentIdKey, func(conn sqlx.SqlConn, v interface{}) error {
query := fmt.Sprintf("select %s from %s where `id` = ? limit 1", studentRows, m.table)
return conn.QueryRow(v, query, id)
})
switch err {
case nil:
return &resp, nil
case sqlc.ErrNotFound:
return nil, ErrNotFound
default:
return nil, err
}
}
func (m *defaultStudentModel) Update(data Student) error {
studentIdKey := fmt.Sprintf("%s%v", cacheStudentIdPrefix, data.Id)
_, err := m.Exec(func(conn sqlx.SqlConn) (result sql.Result, err error) {
query := fmt.Sprintf("update %s set %s where `id` = ?", m.table, studentRowsWithPlaceHolder)
return conn.Exec(query, data.Name, data.Age, data.Score, data.Id)
}, studentIdKey)
return err
}
func (m *defaultStudentModel) Delete(id int64) error {
studentIdKey := fmt.Sprintf("%s%v", cacheStudentIdPrefix, id)
_, err := m.Exec(func(conn sqlx.SqlConn) (result sql.Result, err error) {
query := fmt.Sprintf("delete from %s where `id` = ?", m.table)
return conn.Exec(query, id)
}, studentIdKey)
return err
}
func (m *defaultStudentModel) formatPrimary(primary interface{}) string {
return fmt.Sprintf("%s%v", cacheStudentIdPrefix, primary)
}
func (m *defaultStudentModel) queryPrimary(conn sqlx.SqlConn, v, primary interface{}) error {
query := fmt.Sprintf("select %s from %s where `id` = ? limit 1", studentRows, m.table)
return conn.QueryRow(v, query, primary)
}

View File

@@ -0,0 +1,130 @@
package model
import (
"database/sql"
"fmt"
"strings"
"time"
"github.com/tal-tech/go-zero/core/stores/sqlc"
"github.com/tal-tech/go-zero/core/stores/sqlx"
"github.com/tal-tech/go-zero/core/stringx"
"github.com/tal-tech/go-zero/tools/goctl/model/sql/builderx"
)
var (
userFieldNames = builderx.FieldNames(&User{})
userRows = strings.Join(userFieldNames, ",")
userRowsExpectAutoSet = strings.Join(stringx.Remove(userFieldNames, "`id`", "`create_time`", "`update_time`"), ",")
userRowsWithPlaceHolder = strings.Join(stringx.Remove(userFieldNames, "`id`", "`create_time`", "`update_time`"), "=?,") + "=?"
)
type (
UserModel interface {
Insert(data User) (sql.Result, error)
FindOne(id int64) (*User, error)
FindOneByUser(user string) (*User, error)
FindOneByName(name string) (*User, error)
FindOneByMobile(mobile string) (*User, error)
Update(data User) error
Delete(id int64) error
}
defaultUserModel struct {
conn sqlx.SqlConn
table string
}
User struct {
Id int64 `db:"id"`
User string `db:"user"` // 用户
Name string `db:"name"` // 用户名称
Password string `db:"password"` // 用户密码
Mobile string `db:"mobile"` // 手机号
Gender string `db:"gender"` // 男|女|未公开
Nickname string `db:"nickname"` // 用户昵称
CreateTime time.Time `db:"create_time"`
UpdateTime time.Time `db:"update_time"`
}
)
func NewUserModel(conn sqlx.SqlConn) UserModel {
return &defaultUserModel{
conn: conn,
table: "`user`",
}
}
func (m *defaultUserModel) Insert(data User) (sql.Result, error) {
query := fmt.Sprintf("insert into %s (%s) values (?, ?, ?, ?, ?, ?)", m.table, userRowsExpectAutoSet)
ret, err := m.conn.Exec(query, data.User, data.Name, data.Password, data.Mobile, data.Gender, data.Nickname)
return ret, err
}
func (m *defaultUserModel) FindOne(id int64) (*User, error) {
query := fmt.Sprintf("select %s from %s where `id` = ? limit 1", userRows, m.table)
var resp User
err := m.conn.QueryRow(&resp, query, id)
switch err {
case nil:
return &resp, nil
case sqlc.ErrNotFound:
return nil, ErrNotFound
default:
return nil, err
}
}
func (m *defaultUserModel) FindOneByUser(user string) (*User, error) {
var resp User
query := fmt.Sprintf("select %s from %s where `user` = ? limit 1", userRows, m.table)
err := m.conn.QueryRow(&resp, query, user)
switch err {
case nil:
return &resp, nil
case sqlc.ErrNotFound:
return nil, ErrNotFound
default:
return nil, err
}
}
func (m *defaultUserModel) FindOneByName(name string) (*User, error) {
var resp User
query := fmt.Sprintf("select %s from %s where `name` = ? limit 1", userRows, m.table)
err := m.conn.QueryRow(&resp, query, name)
switch err {
case nil:
return &resp, nil
case sqlc.ErrNotFound:
return nil, ErrNotFound
default:
return nil, err
}
}
func (m *defaultUserModel) FindOneByMobile(mobile string) (*User, error) {
var resp User
query := fmt.Sprintf("select %s from %s where `mobile` = ? limit 1", userRows, m.table)
err := m.conn.QueryRow(&resp, query, mobile)
switch err {
case nil:
return &resp, nil
case sqlc.ErrNotFound:
return nil, ErrNotFound
default:
return nil, err
}
}
func (m *defaultUserModel) Update(data User) error {
query := fmt.Sprintf("update %s set %s where `id` = ?", m.table, userRowsWithPlaceHolder)
_, err := m.conn.Exec(query, data.User, data.Name, data.Password, data.Mobile, data.Gender, data.Nickname, data.Id)
return err
}
func (m *defaultUserModel) Delete(id int64) error {
query := fmt.Sprintf("delete from %s where `id` = ?", m.table)
_, err := m.conn.Exec(query, id)
return err
}

View File

@@ -0,0 +1,5 @@
package model
import "github.com/tal-tech/go-zero/core/stores/sqlx"
var ErrNotFound = sqlx.ErrNotFound

View File

@@ -0,0 +1,255 @@
// copy from core/stores/sqlx/orm.go
package mocksql
import (
"errors"
"reflect"
"strings"
"github.com/tal-tech/go-zero/core/mapping"
)
const tagName = "db"
var (
ErrNotMatchDestination = errors.New("not matching destination to scan")
ErrNotReadableValue = errors.New("value not addressable or interfaceable")
ErrNotSettable = errors.New("passed in variable is not settable")
ErrUnsupportedValueType = errors.New("unsupported unmarshal type")
)
type rowsScanner interface {
Columns() ([]string, error)
Err() error
Next() bool
Scan(v ...interface{}) error
}
func getTaggedFieldValueMap(v reflect.Value) (map[string]interface{}, error) {
rt := mapping.Deref(v.Type())
size := rt.NumField()
result := make(map[string]interface{}, size)
for i := 0; i < size; i++ {
key := parseTagName(rt.Field(i))
if len(key) == 0 {
return nil, nil
}
valueField := reflect.Indirect(v).Field(i)
switch valueField.Kind() {
case reflect.Ptr:
if !valueField.CanInterface() {
return nil, ErrNotReadableValue
}
if valueField.IsNil() {
baseValueType := mapping.Deref(valueField.Type())
valueField.Set(reflect.New(baseValueType))
}
result[key] = valueField.Interface()
default:
if !valueField.CanAddr() || !valueField.Addr().CanInterface() {
return nil, ErrNotReadableValue
}
result[key] = valueField.Addr().Interface()
}
}
return result, nil
}
func mapStructFieldsIntoSlice(v reflect.Value, columns []string, strict bool) ([]interface{}, error) {
fields := unwrapFields(v)
if strict && len(columns) < len(fields) {
return nil, ErrNotMatchDestination
}
taggedMap, err := getTaggedFieldValueMap(v)
if err != nil {
return nil, err
}
values := make([]interface{}, len(columns))
if len(taggedMap) == 0 {
for i := 0; i < len(values); i++ {
valueField := fields[i]
switch valueField.Kind() {
case reflect.Ptr:
if !valueField.CanInterface() {
return nil, ErrNotReadableValue
}
if valueField.IsNil() {
baseValueType := mapping.Deref(valueField.Type())
valueField.Set(reflect.New(baseValueType))
}
values[i] = valueField.Interface()
default:
if !valueField.CanAddr() || !valueField.Addr().CanInterface() {
return nil, ErrNotReadableValue
}
values[i] = valueField.Addr().Interface()
}
}
} else {
for i, column := range columns {
if tagged, ok := taggedMap[column]; ok {
values[i] = tagged
} else {
var anonymous interface{}
values[i] = &anonymous
}
}
}
return values, nil
}
func parseTagName(field reflect.StructField) string {
key := field.Tag.Get(tagName)
if len(key) == 0 {
return ""
} else {
options := strings.Split(key, ",")
return options[0]
}
}
func unmarshalRow(v interface{}, scanner rowsScanner, strict bool) error {
if !scanner.Next() {
if err := scanner.Err(); err != nil {
return err
}
return ErrNotFound
}
rv := reflect.ValueOf(v)
if err := mapping.ValidatePtr(&rv); err != nil {
return err
}
rte := reflect.TypeOf(v).Elem()
rve := rv.Elem()
switch rte.Kind() {
case reflect.Bool,
reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64,
reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64,
reflect.Float32, reflect.Float64,
reflect.String:
if rve.CanSet() {
return scanner.Scan(v)
} else {
return ErrNotSettable
}
case reflect.Struct:
columns, err := scanner.Columns()
if err != nil {
return err
}
if values, err := mapStructFieldsIntoSlice(rve, columns, strict); err != nil {
return err
} else {
return scanner.Scan(values...)
}
default:
return ErrUnsupportedValueType
}
}
func unmarshalRows(v interface{}, scanner rowsScanner, strict bool) error {
rv := reflect.ValueOf(v)
if err := mapping.ValidatePtr(&rv); err != nil {
return err
}
rt := reflect.TypeOf(v)
rte := rt.Elem()
rve := rv.Elem()
switch rte.Kind() {
case reflect.Slice:
if rve.CanSet() {
ptr := rte.Elem().Kind() == reflect.Ptr
appendFn := func(item reflect.Value) {
if ptr {
rve.Set(reflect.Append(rve, item))
} else {
rve.Set(reflect.Append(rve, reflect.Indirect(item)))
}
}
fillFn := func(value interface{}) error {
if rve.CanSet() {
if err := scanner.Scan(value); err != nil {
return err
} else {
appendFn(reflect.ValueOf(value))
return nil
}
}
return ErrNotSettable
}
base := mapping.Deref(rte.Elem())
switch base.Kind() {
case reflect.Bool,
reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64,
reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64,
reflect.Float32, reflect.Float64,
reflect.String:
for scanner.Next() {
value := reflect.New(base)
if err := fillFn(value.Interface()); err != nil {
return err
}
}
case reflect.Struct:
columns, err := scanner.Columns()
if err != nil {
return err
}
for scanner.Next() {
value := reflect.New(base)
if values, err := mapStructFieldsIntoSlice(value, columns, strict); err != nil {
return err
} else {
if err := scanner.Scan(values...); err != nil {
return err
} else {
appendFn(value)
}
}
}
default:
return ErrUnsupportedValueType
}
return nil
} else {
return ErrNotSettable
}
default:
return ErrUnsupportedValueType
}
}
func unwrapFields(v reflect.Value) []reflect.Value {
var fields []reflect.Value
indirect := reflect.Indirect(v)
for i := 0; i < indirect.NumField(); i++ {
child := indirect.Field(i)
if child.Kind() == reflect.Ptr && child.IsNil() {
baseValueType := mapping.Deref(child.Type())
child.Set(reflect.New(baseValueType))
}
child = reflect.Indirect(child)
childType := indirect.Type().Field(i)
if child.Kind() == reflect.Struct && childType.Anonymous {
fields = append(fields, unwrapFields(child)...)
} else {
fields = append(fields, child)
}
}
return fields
}

View File

@@ -0,0 +1,90 @@
// copy from core/stores/sqlx/sqlconn.go
package mocksql
import (
"database/sql"
"github.com/tal-tech/go-zero/core/stores/sqlx"
)
type (
MockConn struct {
db *sql.DB
}
statement struct {
stmt *sql.Stmt
}
)
func NewMockConn(db *sql.DB) *MockConn {
return &MockConn{db: db}
}
func (conn *MockConn) Exec(query string, args ...interface{}) (sql.Result, error) {
return exec(conn.db, query, args...)
}
func (conn *MockConn) Prepare(query string) (sqlx.StmtSession, error) {
st, err := conn.db.Prepare(query)
return statement{stmt: st}, err
}
func (conn *MockConn) QueryRow(v interface{}, q string, args ...interface{}) error {
return query(conn.db, func(rows *sql.Rows) error {
return unmarshalRow(v, rows, true)
}, q, args...)
}
func (conn *MockConn) QueryRowPartial(v interface{}, q string, args ...interface{}) error {
return query(conn.db, func(rows *sql.Rows) error {
return unmarshalRow(v, rows, false)
}, q, args...)
}
func (conn *MockConn) QueryRows(v interface{}, q string, args ...interface{}) error {
return query(conn.db, func(rows *sql.Rows) error {
return unmarshalRows(v, rows, true)
}, q, args...)
}
func (conn *MockConn) QueryRowsPartial(v interface{}, q string, args ...interface{}) error {
return query(conn.db, func(rows *sql.Rows) error {
return unmarshalRows(v, rows, false)
}, q, args...)
}
func (conn *MockConn) Transact(func(session sqlx.Session) error) error {
return nil
}
func (s statement) Close() error {
return s.stmt.Close()
}
func (s statement) Exec(args ...interface{}) (sql.Result, error) {
return execStmt(s.stmt, args...)
}
func (s statement) QueryRow(v interface{}, args ...interface{}) error {
return queryStmt(s.stmt, func(rows *sql.Rows) error {
return unmarshalRow(v, rows, true)
}, args...)
}
func (s statement) QueryRowPartial(v interface{}, args ...interface{}) error {
return queryStmt(s.stmt, func(rows *sql.Rows) error {
return unmarshalRow(v, rows, false)
}, args...)
}
func (s statement) QueryRows(v interface{}, args ...interface{}) error {
return queryStmt(s.stmt, func(rows *sql.Rows) error {
return unmarshalRows(v, rows, true)
}, args...)
}
func (s statement) QueryRowsPartial(v interface{}, args ...interface{}) error {
return queryStmt(s.stmt, func(rows *sql.Rows) error {
return unmarshalRows(v, rows, false)
}, args...)
}

View File

@@ -0,0 +1,122 @@
// copy from core/stores/sqlx/stmt.go
package mocksql
import (
"database/sql"
"fmt"
"time"
"github.com/tal-tech/go-zero/core/logx"
"github.com/tal-tech/go-zero/core/timex"
)
const slowThreshold = time.Millisecond * 500
func exec(db *sql.DB, q string, args ...interface{}) (sql.Result, error) {
tx, err := db.Begin()
if err != nil {
return nil, err
}
defer func() {
switch err {
case nil:
err = tx.Commit()
default:
tx.Rollback()
}
}()
stmt, err := format(q, args...)
if err != nil {
return nil, err
}
startTime := timex.Now()
result, err := tx.Exec(q, args...)
duration := timex.Since(startTime)
if duration > slowThreshold {
logx.WithDuration(duration).Slowf("[SQL] exec: slowcall - %s", stmt)
} else {
logx.WithDuration(duration).Infof("sql exec: %s", stmt)
}
if err != nil {
logSqlError(stmt, err)
}
return result, err
}
func execStmt(conn *sql.Stmt, args ...interface{}) (sql.Result, error) {
stmt := fmt.Sprint(args...)
startTime := timex.Now()
result, err := conn.Exec(args...)
duration := timex.Since(startTime)
if duration > slowThreshold {
logx.WithDuration(duration).Slowf("[SQL] execStmt: slowcall - %s", stmt)
} else {
logx.WithDuration(duration).Infof("sql execStmt: %s", stmt)
}
if err != nil {
logSqlError(stmt, err)
}
return result, err
}
func query(db *sql.DB, scanner func(*sql.Rows) error, q string, args ...interface{}) error {
tx, err := db.Begin()
if err != nil {
return err
}
defer func() {
switch err {
case nil:
err = tx.Commit()
default:
tx.Rollback()
}
}()
stmt, err := format(q, args...)
if err != nil {
return err
}
startTime := timex.Now()
rows, err := tx.Query(q, args...)
duration := timex.Since(startTime)
if duration > slowThreshold {
logx.WithDuration(duration).Slowf("[SQL] query: slowcall - %s", stmt)
} else {
logx.WithDuration(duration).Infof("sql query: %s", stmt)
}
if err != nil {
logSqlError(stmt, err)
return err
}
defer rows.Close()
return scanner(rows)
}
func queryStmt(conn *sql.Stmt, scanner func(*sql.Rows) error, args ...interface{}) error {
stmt := fmt.Sprint(args...)
startTime := timex.Now()
rows, err := conn.Query(args...)
duration := timex.Since(startTime)
if duration > slowThreshold {
logx.WithDuration(duration).Slowf("[SQL] queryStmt: slowcall - %s", stmt)
} else {
logx.WithDuration(duration).Infof("sql queryStmt: %s", stmt)
}
if err != nil {
logSqlError(stmt, err)
return err
}
defer rows.Close()
return scanner(rows)
}

View File

@@ -0,0 +1,105 @@
// copy from core/stores/sqlx/utils.go
package mocksql
import (
"database/sql"
"fmt"
"strings"
"github.com/tal-tech/go-zero/core/logx"
"github.com/tal-tech/go-zero/core/mapping"
)
var ErrNotFound = sql.ErrNoRows
func desensitize(datasource string) string {
// remove account
pos := strings.LastIndex(datasource, "@")
if 0 <= pos && pos+1 < len(datasource) {
datasource = datasource[pos+1:]
}
return datasource
}
func escape(input string) string {
var b strings.Builder
for _, ch := range input {
switch ch {
case '\x00':
b.WriteString(`\x00`)
case '\r':
b.WriteString(`\r`)
case '\n':
b.WriteString(`\n`)
case '\\':
b.WriteString(`\\`)
case '\'':
b.WriteString(`\'`)
case '"':
b.WriteString(`\"`)
case '\x1a':
b.WriteString(`\x1a`)
default:
b.WriteRune(ch)
}
}
return b.String()
}
func format(query string, args ...interface{}) (string, error) {
numArgs := len(args)
if numArgs == 0 {
return query, nil
}
var b strings.Builder
argIndex := 0
for _, ch := range query {
if ch == '?' {
if argIndex >= numArgs {
return "", fmt.Errorf("error: %d ? in sql, but less arguments provided", argIndex)
}
arg := args[argIndex]
argIndex++
switch v := arg.(type) {
case bool:
if v {
b.WriteByte('1')
} else {
b.WriteByte('0')
}
case string:
b.WriteByte('\'')
b.WriteString(escape(v))
b.WriteByte('\'')
default:
b.WriteString(mapping.Repr(v))
}
} else {
b.WriteRune(ch)
}
}
if argIndex < numArgs {
return "", fmt.Errorf("error: %d ? in sql, but more arguments provided", argIndex)
}
return b.String(), nil
}
func logInstanceError(datasource string, err error) {
datasource = desensitize(datasource)
logx.Errorf("Error on getting sql instance of %s: %v", datasource, err)
}
func logSqlError(stmt string, err error) {
if err != nil && err != ErrNotFound {
logx.Errorf("stmt: %s, error: %s", stmt, err.Error())
}
}

View File

@@ -25,9 +25,10 @@ const (
)
type Plugin struct {
Api *spec.ApiSpec
Style string
Dir string
Api *spec.ApiSpec
ApiFilePath string
Style string
Dir string
}
func PluginCommand(c *cli.Context) error {
@@ -86,6 +87,12 @@ func prepareArgs(c *cli.Context) ([]byte, error) {
transferData.Api = api
}
absApiFilePath, err := filepath.Abs(apiPath)
if err != nil {
return nil, err
}
transferData.ApiFilePath = absApiFilePath
dirAbs, err := filepath.Abs(c.String("dir"))
if err != nil {
return nil, err

View File

@@ -4,6 +4,7 @@ import (
"go/build"
"os"
"path/filepath"
"strings"
"testing"
"github.com/stretchr/testify/assert"
@@ -44,7 +45,9 @@ func TestRpcGenerate(t *testing.T) {
assert.Nil(t, err)
_, err = execx.Run("go test "+projectName, projectDir)
if err != nil {
assert.Contains(t, err.Error(), "not in GOROOT")
assert.True(t, func() bool {
return strings.Contains(err.Error(), "not in GOROOT") || strings.Contains(err.Error(), "cannot find package")
}())
}
// case go mod
@@ -61,7 +64,9 @@ func TestRpcGenerate(t *testing.T) {
assert.Nil(t, err)
_, err = execx.Run("go test "+projectName, projectDir)
if err != nil {
assert.Contains(t, err.Error(), "not in GOROOT")
assert.True(t, func() bool {
return strings.Contains(err.Error(), "not in GOROOT") || strings.Contains(err.Error(), "cannot find package")
}())
}
// case not in go mod and go path
@@ -69,7 +74,9 @@ func TestRpcGenerate(t *testing.T) {
assert.Nil(t, err)
_, err = execx.Run("go test "+projectName, projectDir)
if err != nil {
assert.Contains(t, err.Error(), "not in GOROOT")
assert.True(t, func() bool {
return strings.Contains(err.Error(), "not in GOROOT") || strings.Contains(err.Error(), "cannot find package")
}())
}
// invalid directory

View File

@@ -15,12 +15,12 @@ const svcTemplate = `package svc
import {{.imports}}
type ServiceContext struct {
c config.Config
Config config.Config
}
func NewServiceContext(c config.Config) *ServiceContext {
return &ServiceContext{
c:c,
Config:c,
}
}
`

View File

@@ -54,14 +54,15 @@ func Clean() error {
return util.Clean(category)
}
func Update(category string) error {
func Update() error {
err := Clean()
if err != nil {
return err
}
return util.InitTemplates(category, templates)
}
func GetCategory() string {
func Category() string {
return category
}

View File

@@ -97,8 +97,7 @@ func TestUpdate(t *testing.T) {
}
assert.Equal(t, "modify", string(data))
err = Update(category)
assert.Nil(t, err)
assert.Nil(t, Update())
data, err = ioutil.ReadFile(mainTpl)
if err != nil {
@@ -109,6 +108,6 @@ func TestUpdate(t *testing.T) {
func TestGetCategory(t *testing.T) {
_ = Clean()
result := GetCategory()
result := Category()
assert.Equal(t, category, result)
}

View File

@@ -76,12 +76,16 @@ func UpdateTemplates(ctx *cli.Context) (err error) {
}
}()
switch category {
case gogen.GetCategory():
return gogen.Update(category)
case rpcgen.GetCategory():
return rpcgen.Update(category)
case modelgen.GetCategory():
return modelgen.Update(category)
case docker.Category():
return docker.Update()
case gogen.Category():
return gogen.Update()
case kube.Category():
return kube.Update()
case rpcgen.Category():
return rpcgen.Update()
case modelgen.Category():
return modelgen.Update()
default:
err = fmt.Errorf("unexpected category: %s", category)
return
@@ -97,11 +101,15 @@ func RevertTemplates(ctx *cli.Context) (err error) {
}
}()
switch category {
case gogen.GetCategory():
case docker.Category():
return docker.RevertTemplate(filename)
case kube.Category():
return kube.RevertTemplate(filename)
case gogen.Category():
return gogen.RevertTemplate(filename)
case rpcgen.GetCategory():
case rpcgen.Category():
return rpcgen.RevertTemplate(filename)
case modelgen.GetCategory():
case modelgen.Category():
return modelgen.RevertTemplate(filename)
default:
err = fmt.Errorf("unexpected category: %s", category)

View File

@@ -60,7 +60,6 @@ func FindGoModPath(dir string) (string, bool) {
var hasGoMod = false
for {
if FileExists(filepath.Join(tempPath, goModeIdentifier)) {
tempPath = filepath.Dir(tempPath)
rootPath = strings.TrimPrefix(absDir[len(tempPath):], "/")
hasGoMod = true
break

View File

@@ -19,10 +19,11 @@ func Untitle(s string) string {
}
func Index(slice []string, item string) int {
for i, _ := range slice {
for i := range slice {
if slice[i] == item {
return i
}
}
return -1
}

View File

@@ -64,10 +64,10 @@ func (l *Logger) Warning(args ...interface{}) {
// ignore builtin grpc warning
}
func (l *Logger) Warningln(args ...interface{}) {
// ignore builtin grpc warning
}
func (l *Logger) Warningf(format string, args ...interface{}) {
// ignore builtin grpc warning
}
func (l *Logger) Warningln(args ...interface{}) {
// ignore builtin grpc warning
}

View File

@@ -0,0 +1,83 @@
package internal
import (
"log"
"strings"
"testing"
"github.com/stretchr/testify/assert"
)
const content = "foo"
func TestLoggerError(t *testing.T) {
var builder strings.Builder
log.SetOutput(&builder)
logger := new(Logger)
logger.Error(content)
assert.Contains(t, builder.String(), content)
}
func TestLoggerErrorf(t *testing.T) {
var builder strings.Builder
log.SetOutput(&builder)
logger := new(Logger)
logger.Errorf(content)
assert.Contains(t, builder.String(), content)
}
func TestLoggerErrorln(t *testing.T) {
var builder strings.Builder
log.SetOutput(&builder)
logger := new(Logger)
logger.Errorln(content)
assert.Contains(t, builder.String(), content)
}
func TestLoggerFatal(t *testing.T) {
var builder strings.Builder
log.SetOutput(&builder)
logger := new(Logger)
logger.Fatal(content)
assert.Contains(t, builder.String(), content)
}
func TestLoggerFatalf(t *testing.T) {
var builder strings.Builder
log.SetOutput(&builder)
logger := new(Logger)
logger.Fatalf(content)
assert.Contains(t, builder.String(), content)
}
func TestLoggerFatalln(t *testing.T) {
var builder strings.Builder
log.SetOutput(&builder)
logger := new(Logger)
logger.Fatalln(content)
assert.Contains(t, builder.String(), content)
}
func TestLoggerWarning(t *testing.T) {
var builder strings.Builder
log.SetOutput(&builder)
logger := new(Logger)
logger.Warning(content)
assert.Empty(t, builder.String())
}
func TestLoggerWarningf(t *testing.T) {
var builder strings.Builder
log.SetOutput(&builder)
logger := new(Logger)
logger.Warningf(content)
assert.Empty(t, builder.String())
}
func TestLoggerWarningln(t *testing.T) {
var builder strings.Builder
log.SetOutput(&builder)
logger := new(Logger)
logger.Warningln(content)
assert.Empty(t, builder.String())
}