Add MustTempDir (#1069)

This commit is contained in:
anqiansong
2021-09-21 10:13:43 +08:00
committed by GitHub
parent 30e49f2939
commit 9a724fe907
11 changed files with 48 additions and 25 deletions

View File

@@ -8,7 +8,9 @@ import (
"testing" "testing"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/tal-tech/go-zero/tools/goctl/api/parser/g4/ast" "github.com/tal-tech/go-zero/tools/goctl/api/parser/g4/ast"
"github.com/tal-tech/go-zero/tools/goctl/util"
) )
var ( var (
@@ -118,7 +120,7 @@ func TestApiParser(t *testing.T) {
}) })
t.Run("nestedImport", func(t *testing.T) { t.Run("nestedImport", func(t *testing.T) {
file := filepath.Join(t.TempDir(), "foo.api") file := filepath.Join(util.MustTempDir(), "foo.api")
err := ioutil.WriteFile(file, []byte(nestedAPIImport), os.ModePerm) err := ioutil.WriteFile(file, []byte(nestedAPIImport), os.ModePerm)
if err != nil { if err != nil {
return return
@@ -148,7 +150,7 @@ func TestApiParser(t *testing.T) {
}) })
t.Run("ambiguousSyntax", func(t *testing.T) { t.Run("ambiguousSyntax", func(t *testing.T) {
file := filepath.Join(t.TempDir(), "foo.api") file := filepath.Join(util.MustTempDir(), "foo.api")
err := ioutil.WriteFile(file, []byte(ambiguousSyntax), os.ModePerm) err := ioutil.WriteFile(file, []byte(ambiguousSyntax), os.ModePerm)
if err != nil { if err != nil {
return return
@@ -162,7 +164,7 @@ func TestApiParser(t *testing.T) {
}) })
t.Run("ambiguousSyntax", func(t *testing.T) { t.Run("ambiguousSyntax", func(t *testing.T) {
file := filepath.Join(t.TempDir(), "foo.api") file := filepath.Join(util.MustTempDir(), "foo.api")
err := ioutil.WriteFile(file, []byte(ambiguousSyntax), os.ModePerm) err := ioutil.WriteFile(file, []byte(ambiguousSyntax), os.ModePerm)
if err != nil { if err != nil {
return return
@@ -176,7 +178,7 @@ func TestApiParser(t *testing.T) {
}) })
t.Run("ambiguousService", func(t *testing.T) { t.Run("ambiguousService", func(t *testing.T) {
file := filepath.Join(t.TempDir(), "foo.api") file := filepath.Join(util.MustTempDir(), "foo.api")
err := ioutil.WriteFile(file, []byte(ambiguousService), os.ModePerm) err := ioutil.WriteFile(file, []byte(ambiguousService), os.ModePerm)
if err != nil { if err != nil {
return return
@@ -206,7 +208,7 @@ func TestApiParser(t *testing.T) {
`) `)
assert.Error(t, err) assert.Error(t, err)
file := filepath.Join(t.TempDir(), "foo.api") file := filepath.Join(util.MustTempDir(), "foo.api")
err = ioutil.WriteFile(file, []byte(duplicateHandler), os.ModePerm) err = ioutil.WriteFile(file, []byte(duplicateHandler), os.ModePerm)
if err != nil { if err != nil {
return return
@@ -235,7 +237,7 @@ func TestApiParser(t *testing.T) {
`) `)
assert.Error(t, err) assert.Error(t, err)
file := filepath.Join(t.TempDir(), "foo.api") file := filepath.Join(util.MustTempDir(), "foo.api")
err = ioutil.WriteFile(file, []byte(duplicateRoute), os.ModePerm) err = ioutil.WriteFile(file, []byte(duplicateRoute), os.ModePerm)
if err != nil { if err != nil {
return return
@@ -259,7 +261,7 @@ func TestApiParser(t *testing.T) {
`) `)
assert.Error(t, err) assert.Error(t, err)
file := filepath.Join(t.TempDir(), "foo.api") file := filepath.Join(util.MustTempDir(), "foo.api")
err = ioutil.WriteFile(file, []byte(duplicateType), os.ModePerm) err = ioutil.WriteFile(file, []byte(duplicateType), os.ModePerm)
if err != nil { if err != nil {
return return

View File

@@ -48,7 +48,7 @@ func convertVersion(version string) (versionNumber float64, tag string) {
} }
return '_' return '_'
}, splits[0]) }, splits[0])
numberStr = strings.ReplaceAll(numberStr, "_", "") numberStr = strings.Replace(numberStr, "_", "", -1)
versionNumber, _ = json.Number(numberStr).Float64() versionNumber, _ = json.Number(numberStr).Float64()
return return
} }

View File

@@ -6,7 +6,9 @@ import (
"testing" "testing"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/tal-tech/go-zero/tools/goctl/config" "github.com/tal-tech/go-zero/tools/goctl/config"
"github.com/tal-tech/go-zero/tools/goctl/util"
) )
var testTypes = ` var testTypes = `
@@ -18,7 +20,7 @@ func TestDo(t *testing.T) {
cfg, err := config.NewConfig(config.DefaultFormat) cfg, err := config.NewConfig(config.DefaultFormat)
assert.Nil(t, err) assert.Nil(t, err)
tempDir := t.TempDir() tempDir := util.MustTempDir()
typesfile := filepath.Join(tempDir, "types.go") typesfile := filepath.Join(tempDir, "types.go")
err = ioutil.WriteFile(typesfile, []byte(testTypes), 0o666) err = ioutil.WriteFile(typesfile, []byte(testTypes), 0o666)
assert.Nil(t, err) assert.Nil(t, err)

View File

@@ -7,6 +7,7 @@ import (
"testing" "testing"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/tal-tech/go-zero/tools/goctl/config" "github.com/tal-tech/go-zero/tools/goctl/config"
"github.com/tal-tech/go-zero/tools/goctl/model/sql/gen" "github.com/tal-tech/go-zero/tools/goctl/model/sql/gen"
"github.com/tal-tech/go-zero/tools/goctl/util" "github.com/tal-tech/go-zero/tools/goctl/util"
@@ -23,12 +24,12 @@ func TestFromDDl(t *testing.T) {
err := gen.Clean() err := gen.Clean()
assert.Nil(t, err) assert.Nil(t, err)
err = fromDDL("./user.sql", t.TempDir(), cfg, true, false, "go_zero") err = fromDDL("./user.sql", util.MustTempDir(), cfg, true, false, "go_zero")
assert.Equal(t, errNotMatched, err) assert.Equal(t, errNotMatched, err)
// case dir is not exists // case dir is not exists
unknownDir := filepath.Join(t.TempDir(), "test", "user.sql") unknownDir := filepath.Join(util.MustTempDir(), "test", "user.sql")
err = fromDDL(unknownDir, t.TempDir(), cfg, true, false, "go_zero") err = fromDDL(unknownDir, util.MustTempDir(), cfg, true, false, "go_zero")
assert.True(t, func() bool { assert.True(t, func() bool {
switch err.(type) { switch err.(type) {
case *os.PathError: case *os.PathError:
@@ -39,12 +40,12 @@ func TestFromDDl(t *testing.T) {
}()) }())
// case empty src // case empty src
err = fromDDL("", t.TempDir(), cfg, true, false, "go_zero") err = fromDDL("", util.MustTempDir(), cfg, true, false, "go_zero")
if err != nil { if err != nil {
assert.Equal(t, "expected path or path globbing patterns, but nothing found", err.Error()) assert.Equal(t, "expected path or path globbing patterns, but nothing found", err.Error())
} }
tempDir := filepath.Join(t.TempDir(), "test") tempDir := filepath.Join(util.MustTempDir(), "test")
err = util.MkdirIfNotExist(tempDir) err = util.MkdirIfNotExist(tempDir)
if err != nil { if err != nil {
return return

View File

@@ -10,10 +10,12 @@ import (
"time" "time"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/tal-tech/go-zero/core/logx" "github.com/tal-tech/go-zero/core/logx"
"github.com/tal-tech/go-zero/core/stringx" "github.com/tal-tech/go-zero/core/stringx"
"github.com/tal-tech/go-zero/tools/goctl/config" "github.com/tal-tech/go-zero/tools/goctl/config"
"github.com/tal-tech/go-zero/tools/goctl/model/sql/builderx" "github.com/tal-tech/go-zero/tools/goctl/model/sql/builderx"
"github.com/tal-tech/go-zero/tools/goctl/util"
) )
var source = "CREATE TABLE `test_user` (\n `id` bigint NOT NULL AUTO_INCREMENT,\n `mobile` varchar(255) COLLATE utf8mb4_bin NOT NULL,\n `class` bigint NOT NULL,\n `name` varchar(255) CHARACTER SET utf8mb4 COLLATE utf8mb4_bin NOT NULL,\n `create_time` timestamp NULL DEFAULT CURRENT_TIMESTAMP,\n `update_time` timestamp NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,\n PRIMARY KEY (`id`),\n UNIQUE KEY `mobile_unique` (`mobile`),\n UNIQUE KEY `class_name_unique` (`class`,`name`),\n KEY `create_index` (`create_time`),\n KEY `name_index` (`name`)\n) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_bin;" var source = "CREATE TABLE `test_user` (\n `id` bigint NOT NULL AUTO_INCREMENT,\n `mobile` varchar(255) COLLATE utf8mb4_bin NOT NULL,\n `class` bigint NOT NULL,\n `name` varchar(255) CHARACTER SET utf8mb4 COLLATE utf8mb4_bin NOT NULL,\n `create_time` timestamp NULL DEFAULT CURRENT_TIMESTAMP,\n `update_time` timestamp NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,\n PRIMARY KEY (`id`),\n UNIQUE KEY `mobile_unique` (`mobile`),\n UNIQUE KEY `class_name_unique` (`class`,`name`),\n KEY `create_index` (`create_time`),\n KEY `name_index` (`name`)\n) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_bin;"
@@ -22,11 +24,11 @@ func TestCacheModel(t *testing.T) {
logx.Disable() logx.Disable()
_ = Clean() _ = Clean()
sqlFile := filepath.Join(t.TempDir(), "tmp.sql") sqlFile := filepath.Join(util.MustTempDir(), "tmp.sql")
err := ioutil.WriteFile(sqlFile, []byte(source), 0o777) err := ioutil.WriteFile(sqlFile, []byte(source), 0o777)
assert.Nil(t, err) assert.Nil(t, err)
dir := filepath.Join(t.TempDir(), "./testmodel") dir := filepath.Join(util.MustTempDir(), "./testmodel")
cacheDir := filepath.Join(dir, "cache") cacheDir := filepath.Join(dir, "cache")
noCacheDir := filepath.Join(dir, "nocache") noCacheDir := filepath.Join(dir, "nocache")
g, err := NewDefaultGenerator(cacheDir, &config.Config{ g, err := NewDefaultGenerator(cacheDir, &config.Config{
@@ -57,7 +59,7 @@ func TestNamingModel(t *testing.T) {
logx.Disable() logx.Disable()
_ = Clean() _ = Clean()
sqlFile := filepath.Join(t.TempDir(), "tmp.sql") sqlFile := filepath.Join(util.MustTempDir(), "tmp.sql")
err := ioutil.WriteFile(sqlFile, []byte(source), 0o777) err := ioutil.WriteFile(sqlFile, []byte(source), 0o777)
assert.Nil(t, err) assert.Nil(t, err)

View File

@@ -6,12 +6,14 @@ import (
"testing" "testing"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/tal-tech/go-zero/tools/goctl/model/sql/model" "github.com/tal-tech/go-zero/tools/goctl/model/sql/model"
"github.com/tal-tech/go-zero/tools/goctl/model/sql/util" "github.com/tal-tech/go-zero/tools/goctl/model/sql/util"
ctlutil "github.com/tal-tech/go-zero/tools/goctl/util"
) )
func TestParsePlainText(t *testing.T) { func TestParsePlainText(t *testing.T) {
sqlFile := filepath.Join(t.TempDir(), "tmp.sql") sqlFile := filepath.Join(ctlutil.MustTempDir(), "tmp.sql")
err := ioutil.WriteFile(sqlFile, []byte("plain text"), 0o777) err := ioutil.WriteFile(sqlFile, []byte("plain text"), 0o777)
assert.Nil(t, err) assert.Nil(t, err)
@@ -20,7 +22,7 @@ func TestParsePlainText(t *testing.T) {
} }
func TestParseSelect(t *testing.T) { func TestParseSelect(t *testing.T) {
sqlFile := filepath.Join(t.TempDir(), "tmp.sql") sqlFile := filepath.Join(ctlutil.MustTempDir(), "tmp.sql")
err := ioutil.WriteFile(sqlFile, []byte("select * from user"), 0o777) err := ioutil.WriteFile(sqlFile, []byte("select * from user"), 0o777)
assert.Nil(t, err) assert.Nil(t, err)
@@ -30,7 +32,7 @@ func TestParseSelect(t *testing.T) {
} }
func TestParseCreateTable(t *testing.T) { func TestParseCreateTable(t *testing.T) {
sqlFile := filepath.Join(t.TempDir(), "tmp.sql") sqlFile := filepath.Join(ctlutil.MustTempDir(), "tmp.sql")
err := ioutil.WriteFile(sqlFile, []byte("CREATE TABLE `test_user` (\n `id` bigint NOT NULL AUTO_INCREMENT,\n `mobile` varchar(255) COLLATE utf8mb4_bin NOT NULL comment '手\\t机 号',\n `class` bigint NOT NULL comment '班级',\n `name` varchar(255) CHARACTER SET utf8mb4 COLLATE utf8mb4_bin NOT NULL comment '姓\n 名',\n `create_time` timestamp NULL DEFAULT CURRENT_TIMESTAMP comment '创建\\r时间',\n `update_time` timestamp NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,\n PRIMARY KEY (`id`),\n UNIQUE KEY `mobile_unique` (`mobile`),\n UNIQUE KEY `class_name_unique` (`class`,`name`),\n KEY `create_index` (`create_time`),\n KEY `name_index` (`name`)\n) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_bin;"), 0o777) err := ioutil.WriteFile(sqlFile, []byte("CREATE TABLE `test_user` (\n `id` bigint NOT NULL AUTO_INCREMENT,\n `mobile` varchar(255) COLLATE utf8mb4_bin NOT NULL comment '手\\t机 号',\n `class` bigint NOT NULL comment '班级',\n `name` varchar(255) CHARACTER SET utf8mb4 COLLATE utf8mb4_bin NOT NULL comment '姓\n 名',\n `create_time` timestamp NULL DEFAULT CURRENT_TIMESTAMP comment '创建\\r时间',\n `update_time` timestamp NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,\n PRIMARY KEY (`id`),\n UNIQUE KEY `mobile_unique` (`mobile`),\n UNIQUE KEY `class_name_unique` (`class`,`name`),\n KEY `create_index` (`create_time`),\n KEY `name_index` (`name`)\n) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_bin;"), 0o777)
assert.Nil(t, err) assert.Nil(t, err)

View File

@@ -12,6 +12,7 @@ import (
"github.com/tal-tech/go-zero/core/stringx" "github.com/tal-tech/go-zero/core/stringx"
conf "github.com/tal-tech/go-zero/tools/goctl/config" conf "github.com/tal-tech/go-zero/tools/goctl/config"
"github.com/tal-tech/go-zero/tools/goctl/rpc/execx" "github.com/tal-tech/go-zero/tools/goctl/rpc/execx"
"github.com/tal-tech/go-zero/tools/goctl/util"
) )
var cfg = &conf.Config{ var cfg = &conf.Config{
@@ -57,7 +58,7 @@ func TestRpcGenerate(t *testing.T) {
// case go mod // case go mod
t.Run("GOMOD", func(t *testing.T) { t.Run("GOMOD", func(t *testing.T) {
workDir := t.TempDir() workDir := util.MustTempDir()
name := filepath.Base(workDir) name := filepath.Base(workDir)
_, err = execx.Run("go mod init "+name, workDir) _, err = execx.Run("go mod init "+name, workDir)
if err != nil { if err != nil {

View File

@@ -25,7 +25,7 @@ service {{.serviceName}} {
} }
` `
// ProtoTmpl returns an sample of a proto file // ProtoTmpl returns a sample of a proto file
func ProtoTmpl(out string) error { func ProtoTmpl(out string) error {
protoFilename := filepath.Base(out) protoFilename := filepath.Base(out)
serviceName := stringx.From(strings.TrimSuffix(protoFilename, filepath.Ext(protoFilename))) serviceName := stringx.From(strings.TrimSuffix(protoFilename, filepath.Ext(protoFilename)))

View File

@@ -5,16 +5,18 @@ import (
"testing" "testing"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/tal-tech/go-zero/tools/goctl/util"
) )
func TestProtoTmpl(t *testing.T) { func TestProtoTmpl(t *testing.T) {
_ = Clean() _ = Clean()
// exists dir // exists dir
err := ProtoTmpl(t.TempDir()) err := ProtoTmpl(util.MustTempDir())
assert.Nil(t, err) assert.Nil(t, err)
// not exist dir // not exist dir
dir := filepath.Join(t.TempDir(), "test") dir := filepath.Join(util.MustTempDir(), "test")
err = ProtoTmpl(dir) err = ProtoTmpl(dir)
assert.Nil(t, err) assert.Nil(t, err)
} }

View File

@@ -4,6 +4,7 @@ import (
"bufio" "bufio"
"fmt" "fmt"
"io/ioutil" "io/ioutil"
"log"
"os" "os"
"path/filepath" "path/filepath"
"strings" "strings"
@@ -180,3 +181,13 @@ func createTemplate(file, content string, force bool) error {
_, err = f.WriteString(content) _, err = f.WriteString(content)
return err return err
} }
// MustTempDir creates a temporary directory
func MustTempDir() string {
dir, err := ioutil.TempDir("", "")
if err != nil {
log.Fatalln(err)
}
return dir
}

View File

@@ -39,7 +39,7 @@ func (s String) Upper() string {
// ReplaceAll calls the strings.ReplaceAll // ReplaceAll calls the strings.ReplaceAll
func (s String) ReplaceAll(old, new string) string { func (s String) ReplaceAll(old, new string) string {
return strings.ReplaceAll(s.source, old, new) return strings.Replace(s.source, old, new, -1)
} }
// Source returns the source string value // Source returns the source string value