initial import

This commit is contained in:
kevin
2020-07-26 17:09:05 +08:00
commit 7e3a369a8f
647 changed files with 54754 additions and 0 deletions

View File

@@ -0,0 +1,187 @@
package sqlx
import (
"database/sql"
"fmt"
"strings"
"time"
"zero/core/executors"
"zero/core/logx"
"zero/core/stringx"
)
const (
flushInterval = time.Second
maxBulkRows = 1000
valuesKeyword = "values"
)
var emptyBulkStmt bulkStmt
type (
ResultHandler func(sql.Result, error)
BulkInserter struct {
executor *executors.PeriodicalExecutor
inserter *dbInserter
stmt bulkStmt
}
bulkStmt struct {
prefix string
valueFormat string
suffix string
}
)
func NewBulkInserter(sqlConn SqlConn, stmt string) (*BulkInserter, error) {
bkStmt, err := parseInsertStmt(stmt)
if err != nil {
return nil, err
}
inserter := &dbInserter{
sqlConn: sqlConn,
stmt: bkStmt,
}
return &BulkInserter{
executor: executors.NewPeriodicalExecutor(flushInterval, inserter),
inserter: inserter,
stmt: bkStmt,
}, nil
}
func (bi *BulkInserter) Flush() {
bi.executor.Flush()
}
func (bi *BulkInserter) Insert(args ...interface{}) error {
value, err := format(bi.stmt.valueFormat, args...)
if err != nil {
return err
}
bi.executor.Add(value)
return nil
}
func (bi *BulkInserter) SetResultHandler(handler ResultHandler) {
bi.executor.Sync(func() {
bi.inserter.resultHandler = handler
})
}
func (bi *BulkInserter) UpdateOrDelete(fn func()) {
bi.executor.Flush()
fn()
}
func (bi *BulkInserter) UpdateStmt(stmt string) error {
bkStmt, err := parseInsertStmt(stmt)
if err != nil {
return err
}
bi.executor.Flush()
bi.executor.Sync(func() {
bi.inserter.stmt = bkStmt
})
return nil
}
type dbInserter struct {
sqlConn SqlConn
stmt bulkStmt
values []string
resultHandler ResultHandler
}
func (in *dbInserter) AddTask(task interface{}) bool {
in.values = append(in.values, task.(string))
return len(in.values) >= maxBulkRows
}
func (in *dbInserter) Execute(bulk interface{}) {
values := bulk.([]string)
if len(values) == 0 {
return
}
stmtWithoutValues := in.stmt.prefix
valuesStr := strings.Join(values, ", ")
stmt := strings.Join([]string{stmtWithoutValues, valuesStr}, " ")
if len(in.stmt.suffix) > 0 {
stmt = strings.Join([]string{stmt, in.stmt.suffix}, " ")
}
result, err := in.sqlConn.Exec(stmt)
if in.resultHandler != nil {
in.resultHandler(result, err)
} else if err != nil {
logx.Errorf("sql: %s, error: %s", stmt, err)
}
}
func (in *dbInserter) RemoveAll() interface{} {
values := in.values
in.values = nil
return values
}
func parseInsertStmt(stmt string) (bulkStmt, error) {
lower := strings.ToLower(stmt)
pos := strings.Index(lower, valuesKeyword)
if pos <= 0 {
return emptyBulkStmt, fmt.Errorf("bad sql: %q", stmt)
}
var columns int
right := strings.LastIndexByte(lower[:pos], ')')
if right > 0 {
left := strings.LastIndexByte(lower[:right], '(')
if left > 0 {
values := lower[left+1 : right]
values = stringx.Filter(values, func(r rune) bool {
return r == ' ' || r == '\t' || r == '\r' || r == '\n'
})
fields := strings.FieldsFunc(values, func(r rune) bool {
return r == ','
})
columns = len(fields)
}
}
var variables int
var valueFormat string
var suffix string
left := strings.IndexByte(lower[pos:], '(')
if left > 0 {
right = strings.IndexByte(lower[pos+left:], ')')
if right > 0 {
values := lower[pos+left : pos+left+right]
for _, x := range values {
if x == '?' {
variables++
}
}
valueFormat = stmt[pos+left : pos+left+right+1]
suffix = strings.TrimSpace(stmt[pos+left+right+1:])
}
}
if variables == 0 {
return emptyBulkStmt, fmt.Errorf("no variables: %q", stmt)
}
if columns > 0 && columns != variables {
return emptyBulkStmt, fmt.Errorf("columns and variables mismatch: %q", stmt)
}
return bulkStmt{
prefix: stmt[:pos+len(valuesKeyword)],
valueFormat: valueFormat,
suffix: suffix,
}, nil
}

View File

@@ -0,0 +1,98 @@
package sqlx
import (
"database/sql"
"strconv"
"testing"
"zero/core/logx"
"github.com/DATA-DOG/go-sqlmock"
"github.com/stretchr/testify/assert"
)
type mockedConn struct {
query string
args []interface{}
}
func (c *mockedConn) Exec(query string, args ...interface{}) (sql.Result, error) {
c.query = query
c.args = args
return nil, nil
}
func (c *mockedConn) Prepare(query string) (StmtSession, error) {
panic("should not called")
}
func (c *mockedConn) QueryRow(v interface{}, query string, args ...interface{}) error {
panic("should not called")
}
func (c *mockedConn) QueryRowPartial(v interface{}, query string, args ...interface{}) error {
panic("should not called")
}
func (c *mockedConn) QueryRows(v interface{}, query string, args ...interface{}) error {
panic("should not called")
}
func (c *mockedConn) QueryRowsPartial(v interface{}, query string, args ...interface{}) error {
panic("should not called")
}
func (c *mockedConn) Transact(func(session Session) error) error {
panic("should not called")
}
func TestBulkInserter(t *testing.T) {
runSqlTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
var conn mockedConn
inserter, err := NewBulkInserter(&conn, `INSERT INTO classroom_dau(classroom, user, count) VALUES(?, ?, ?)`)
assert.Nil(t, err)
for i := 0; i < 5; i++ {
assert.Nil(t, inserter.Insert("class_"+strconv.Itoa(i), "user_"+strconv.Itoa(i), i))
}
inserter.Flush()
assert.Equal(t, `INSERT INTO classroom_dau(classroom, user, count) VALUES `+
`('class_0', 'user_0', 0), ('class_1', 'user_1', 1), ('class_2', 'user_2', 2), `+
`('class_3', 'user_3', 3), ('class_4', 'user_4', 4)`,
conn.query)
assert.Nil(t, conn.args)
})
}
func TestBulkInserterSuffix(t *testing.T) {
runSqlTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
var conn mockedConn
inserter, err := NewBulkInserter(&conn, `INSERT INTO classroom_dau(classroom, user, count) VALUES`+
`(?, ?, ?) ON DUPLICATE KEY UPDATE is_overtime=VALUES(is_overtime)`)
assert.Nil(t, err)
for i := 0; i < 5; i++ {
assert.Nil(t, inserter.Insert("class_"+strconv.Itoa(i), "user_"+strconv.Itoa(i), i))
}
inserter.Flush()
assert.Equal(t, `INSERT INTO classroom_dau(classroom, user, count) VALUES `+
`('class_0', 'user_0', 0), ('class_1', 'user_1', 1), ('class_2', 'user_2', 2), `+
`('class_3', 'user_3', 3), ('class_4', 'user_4', 4) ON DUPLICATE KEY UPDATE is_overtime=VALUES(is_overtime)`,
conn.query)
assert.Nil(t, conn.args)
})
}
func runSqlTest(t *testing.T, fn func(db *sql.DB, mock sqlmock.Sqlmock)) {
logx.Disable()
db, mock, err := sqlmock.New()
if err != nil {
t.Fatalf("an error '%s' was not expected when opening a stub database connection", err)
}
defer db.Close()
fn(db, mock)
if err := mock.ExpectationsWereMet(); err != nil {
t.Errorf("there were unfulfilled expectations: %s", err)
}
}

37
core/stores/sqlx/mysql.go Normal file
View File

@@ -0,0 +1,37 @@
package sqlx
import "github.com/go-sql-driver/mysql"
const (
mysqlDriverName = "mysql"
duplicateEntryCode uint16 = 1062
)
func NewMysql(datasource string, opts ...SqlOption) SqlConn {
opts = append(opts, withMysqlAcceptable())
return NewSqlConn(mysqlDriverName, datasource, opts...)
}
func mysqlAcceptable(err error) bool {
if err == nil {
return true
}
myerr, ok := err.(*mysql.MySQLError)
if !ok {
return false
}
switch myerr.Number {
case duplicateEntryCode:
return true
default:
return false
}
}
func withMysqlAcceptable() SqlOption {
return func(conn *commonSqlConn) {
conn.accept = mysqlAcceptable
}
}

View File

@@ -0,0 +1,56 @@
package sqlx
import (
"testing"
"zero/core/breaker"
"zero/core/logx"
"zero/core/stat"
"github.com/go-sql-driver/mysql"
"github.com/stretchr/testify/assert"
)
func init() {
stat.SetReporter(nil)
}
func TestBreakerOnDuplicateEntry(t *testing.T) {
logx.Disable()
err := tryOnDuplicateEntryError(t, mysqlAcceptable)
assert.Equal(t, duplicateEntryCode, err.(*mysql.MySQLError).Number)
}
func TestBreakerOnNotHandlingDuplicateEntry(t *testing.T) {
logx.Disable()
var found bool
for i := 0; i < 100; i++ {
if tryOnDuplicateEntryError(t, nil) == breaker.ErrServiceUnavailable {
found = true
}
}
assert.True(t, found)
}
func tryOnDuplicateEntryError(t *testing.T, accept func(error) bool) error {
logx.Disable()
conn := commonSqlConn{
brk: breaker.NewBreaker(),
accept: accept,
}
for i := 0; i < 1000; i++ {
assert.NotNil(t, conn.brk.DoWithAcceptable(func() error {
return &mysql.MySQLError{
Number: duplicateEntryCode,
}
}, conn.acceptable))
}
return conn.brk.DoWithAcceptable(func() error {
return &mysql.MySQLError{
Number: duplicateEntryCode,
}
}, conn.acceptable)
}

254
core/stores/sqlx/orm.go Normal file
View File

@@ -0,0 +1,254 @@
package sqlx
import (
"errors"
"reflect"
"strings"
"zero/core/mapping"
)
const tagName = "db"
var (
ErrNotMatchDestination = errors.New("not matching destination to scan")
ErrNotReadableValue = errors.New("value not addressable or interfaceable")
ErrNotSettable = errors.New("passed in variable is not settable")
ErrUnsupportedValueType = errors.New("unsupported unmarshal type")
)
type rowsScanner interface {
Columns() ([]string, error)
Err() error
Next() bool
Scan(v ...interface{}) error
}
func getTaggedFieldValueMap(v reflect.Value) (map[string]interface{}, error) {
rt := mapping.Deref(v.Type())
size := rt.NumField()
result := make(map[string]interface{}, size)
for i := 0; i < size; i++ {
key := parseTagName(rt.Field(i))
if len(key) == 0 {
return nil, nil
}
valueField := reflect.Indirect(v).Field(i)
switch valueField.Kind() {
case reflect.Ptr:
if !valueField.CanInterface() {
return nil, ErrNotReadableValue
}
if valueField.IsNil() {
baseValueType := mapping.Deref(valueField.Type())
valueField.Set(reflect.New(baseValueType))
}
result[key] = valueField.Interface()
default:
if !valueField.CanAddr() || !valueField.Addr().CanInterface() {
return nil, ErrNotReadableValue
}
result[key] = valueField.Addr().Interface()
}
}
return result, nil
}
func mapStructFieldsIntoSlice(v reflect.Value, columns []string, strict bool) ([]interface{}, error) {
fields := unwrapFields(v)
if strict && len(columns) < len(fields) {
return nil, ErrNotMatchDestination
}
taggedMap, err := getTaggedFieldValueMap(v)
if err != nil {
return nil, err
}
values := make([]interface{}, len(columns))
if len(taggedMap) == 0 {
for i := 0; i < len(values); i++ {
valueField := fields[i]
switch valueField.Kind() {
case reflect.Ptr:
if !valueField.CanInterface() {
return nil, ErrNotReadableValue
}
if valueField.IsNil() {
baseValueType := mapping.Deref(valueField.Type())
valueField.Set(reflect.New(baseValueType))
}
values[i] = valueField.Interface()
default:
if !valueField.CanAddr() || !valueField.Addr().CanInterface() {
return nil, ErrNotReadableValue
}
values[i] = valueField.Addr().Interface()
}
}
} else {
for i, column := range columns {
if tagged, ok := taggedMap[column]; ok {
values[i] = tagged
} else {
var anonymous interface{}
values[i] = &anonymous
}
}
}
return values, nil
}
func parseTagName(field reflect.StructField) string {
key := field.Tag.Get(tagName)
if len(key) == 0 {
return ""
} else {
options := strings.Split(key, ",")
return options[0]
}
}
func unmarshalRow(v interface{}, scanner rowsScanner, strict bool) error {
if !scanner.Next() {
if err := scanner.Err(); err != nil {
return err
}
return ErrNotFound
}
rv := reflect.ValueOf(v)
if err := mapping.ValidatePtr(&rv); err != nil {
return err
}
rte := reflect.TypeOf(v).Elem()
rve := rv.Elem()
switch rte.Kind() {
case reflect.Bool,
reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64,
reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64,
reflect.Float32, reflect.Float64,
reflect.String:
if rve.CanSet() {
return scanner.Scan(v)
} else {
return ErrNotSettable
}
case reflect.Struct:
columns, err := scanner.Columns()
if err != nil {
return err
}
if values, err := mapStructFieldsIntoSlice(rve, columns, strict); err != nil {
return err
} else {
return scanner.Scan(values...)
}
default:
return ErrUnsupportedValueType
}
}
func unmarshalRows(v interface{}, scanner rowsScanner, strict bool) error {
rv := reflect.ValueOf(v)
if err := mapping.ValidatePtr(&rv); err != nil {
return err
}
rt := reflect.TypeOf(v)
rte := rt.Elem()
rve := rv.Elem()
switch rte.Kind() {
case reflect.Slice:
if rve.CanSet() {
ptr := rte.Elem().Kind() == reflect.Ptr
appendFn := func(item reflect.Value) {
if ptr {
rve.Set(reflect.Append(rve, item))
} else {
rve.Set(reflect.Append(rve, reflect.Indirect(item)))
}
}
fillFn := func(value interface{}) error {
if rve.CanSet() {
if err := scanner.Scan(value); err != nil {
return err
} else {
appendFn(reflect.ValueOf(value))
return nil
}
}
return ErrNotSettable
}
base := mapping.Deref(rte.Elem())
switch base.Kind() {
case reflect.Bool,
reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64,
reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64,
reflect.Float32, reflect.Float64,
reflect.String:
for scanner.Next() {
value := reflect.New(base)
if err := fillFn(value.Interface()); err != nil {
return err
}
}
case reflect.Struct:
columns, err := scanner.Columns()
if err != nil {
return err
}
for scanner.Next() {
value := reflect.New(base)
if values, err := mapStructFieldsIntoSlice(value, columns, strict); err != nil {
return err
} else {
if err := scanner.Scan(values...); err != nil {
return err
} else {
appendFn(value)
}
}
}
default:
return ErrUnsupportedValueType
}
return nil
} else {
return ErrNotSettable
}
default:
return ErrUnsupportedValueType
}
}
func unwrapFields(v reflect.Value) []reflect.Value {
var fields []reflect.Value
indirect := reflect.Indirect(v)
for i := 0; i < indirect.NumField(); i++ {
child := indirect.Field(i)
if child.Kind() == reflect.Ptr && child.IsNil() {
baseValueType := mapping.Deref(child.Type())
child.Set(reflect.New(baseValueType))
}
child = reflect.Indirect(child)
childType := indirect.Type().Field(i)
if child.Kind() == reflect.Struct && childType.Anonymous {
fields = append(fields, unwrapFields(child)...)
} else {
fields = append(fields, child)
}
}
return fields
}

View File

@@ -0,0 +1,973 @@
package sqlx
import (
"database/sql"
"testing"
"zero/core/logx"
"github.com/DATA-DOG/go-sqlmock"
"github.com/stretchr/testify/assert"
)
func TestUnmarshalRowBool(t *testing.T) {
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("1")
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
var value bool
assert.Nil(t, query(db, func(rows *sql.Rows) error {
return unmarshalRow(&value, rows, true)
}, "select value from users where user=?", "anyone"))
assert.True(t, value)
})
}
func TestUnmarshalRowInt(t *testing.T) {
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2")
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
var value int
assert.Nil(t, query(db, func(rows *sql.Rows) error {
return unmarshalRow(&value, rows, true)
}, "select value from users where user=?", "anyone"))
assert.EqualValues(t, 2, value)
})
}
func TestUnmarshalRowInt8(t *testing.T) {
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("3")
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
var value int8
assert.Nil(t, query(db, func(rows *sql.Rows) error {
return unmarshalRow(&value, rows, true)
}, "select value from users where user=?", "anyone"))
assert.EqualValues(t, int8(3), value)
})
}
func TestUnmarshalRowInt16(t *testing.T) {
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("4")
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
var value int16
assert.Nil(t, query(db, func(rows *sql.Rows) error {
return unmarshalRow(&value, rows, true)
}, "select value from users where user=?", "anyone"))
assert.Equal(t, int16(4), value)
})
}
func TestUnmarshalRowInt32(t *testing.T) {
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("5")
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
var value int32
assert.Nil(t, query(db, func(rows *sql.Rows) error {
return unmarshalRow(&value, rows, true)
}, "select value from users where user=?", "anyone"))
assert.Equal(t, int32(5), value)
})
}
func TestUnmarshalRowInt64(t *testing.T) {
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("6")
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
var value int64
assert.Nil(t, query(db, func(rows *sql.Rows) error {
return unmarshalRow(&value, rows, true)
}, "select value from users where user=?", "anyone"))
assert.EqualValues(t, int64(6), value)
})
}
func TestUnmarshalRowUint(t *testing.T) {
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2")
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
var value uint
assert.Nil(t, query(db, func(rows *sql.Rows) error {
return unmarshalRow(&value, rows, true)
}, "select value from users where user=?", "anyone"))
assert.EqualValues(t, uint(2), value)
})
}
func TestUnmarshalRowUint8(t *testing.T) {
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("3")
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
var value uint8
assert.Nil(t, query(db, func(rows *sql.Rows) error {
return unmarshalRow(&value, rows, true)
}, "select value from users where user=?", "anyone"))
assert.EqualValues(t, uint8(3), value)
})
}
func TestUnmarshalRowUint16(t *testing.T) {
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("4")
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
var value uint16
assert.Nil(t, query(db, func(rows *sql.Rows) error {
return unmarshalRow(&value, rows, true)
}, "select value from users where user=?", "anyone"))
assert.EqualValues(t, uint16(4), value)
})
}
func TestUnmarshalRowUint32(t *testing.T) {
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("5")
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
var value uint32
assert.Nil(t, query(db, func(rows *sql.Rows) error {
return unmarshalRow(&value, rows, true)
}, "select value from users where user=?", "anyone"))
assert.EqualValues(t, uint32(5), value)
})
}
func TestUnmarshalRowUint64(t *testing.T) {
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("6")
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
var value uint64
assert.Nil(t, query(db, func(rows *sql.Rows) error {
return unmarshalRow(&value, rows, true)
}, "select value from users where user=?", "anyone"))
assert.EqualValues(t, uint16(6), value)
})
}
func TestUnmarshalRowFloat32(t *testing.T) {
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("7")
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
var value float32
assert.Nil(t, query(db, func(rows *sql.Rows) error {
return unmarshalRow(&value, rows, true)
}, "select value from users where user=?", "anyone"))
assert.EqualValues(t, float32(7), value)
})
}
func TestUnmarshalRowFloat64(t *testing.T) {
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("8")
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
var value float64
assert.Nil(t, query(db, func(rows *sql.Rows) error {
return unmarshalRow(&value, rows, true)
}, "select value from users where user=?", "anyone"))
assert.EqualValues(t, float64(8), value)
})
}
func TestUnmarshalRowString(t *testing.T) {
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
const expect = "hello"
rs := sqlmock.NewRows([]string{"value"}).FromCSVString(expect)
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
var value string
assert.Nil(t, query(db, func(rows *sql.Rows) error {
return unmarshalRow(&value, rows, true)
}, "select value from users where user=?", "anyone"))
assert.EqualValues(t, expect, value)
})
}
func TestUnmarshalRowStruct(t *testing.T) {
var value = new(struct {
Name string
Age int
})
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
rs := sqlmock.NewRows([]string{"name", "age"}).FromCSVString("liao,5")
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
assert.Nil(t, query(db, func(rows *sql.Rows) error {
return unmarshalRow(value, rows, true)
}, "select name, age from users where user=?", "anyone"))
assert.Equal(t, "liao", value.Name)
assert.Equal(t, 5, value.Age)
})
}
func TestUnmarshalRowStructWithTags(t *testing.T) {
var value = new(struct {
Age int `db:"age"`
Name string `db:"name"`
})
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
rs := sqlmock.NewRows([]string{"name", "age"}).FromCSVString("liao,5")
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
assert.Nil(t, query(db, func(rows *sql.Rows) error {
return unmarshalRow(value, rows, true)
}, "select name, age from users where user=?", "anyone"))
assert.Equal(t, "liao", value.Name)
assert.Equal(t, 5, value.Age)
})
}
func TestUnmarshalRowsBool(t *testing.T) {
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
var expect = []bool{true, false}
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("1\n0")
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
var value []bool
assert.Nil(t, query(db, func(rows *sql.Rows) error {
return unmarshalRows(&value, rows, true)
}, "select value from users where user=?", "anyone"))
assert.EqualValues(t, expect, value)
})
}
func TestUnmarshalRowsInt(t *testing.T) {
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
var expect = []int{2, 3}
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2\n3")
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
var value []int
assert.Nil(t, query(db, func(rows *sql.Rows) error {
return unmarshalRows(&value, rows, true)
}, "select value from users where user=?", "anyone"))
assert.EqualValues(t, expect, value)
})
}
func TestUnmarshalRowsInt8(t *testing.T) {
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
var expect = []int8{2, 3}
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2\n3")
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
var value []int8
assert.Nil(t, query(db, func(rows *sql.Rows) error {
return unmarshalRows(&value, rows, true)
}, "select value from users where user=?", "anyone"))
assert.EqualValues(t, expect, value)
})
}
func TestUnmarshalRowsInt16(t *testing.T) {
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
var expect = []int16{2, 3}
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2\n3")
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
var value []int16
assert.Nil(t, query(db, func(rows *sql.Rows) error {
return unmarshalRows(&value, rows, true)
}, "select value from users where user=?", "anyone"))
assert.EqualValues(t, expect, value)
})
}
func TestUnmarshalRowsInt32(t *testing.T) {
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
var expect = []int32{2, 3}
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2\n3")
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
var value []int32
assert.Nil(t, query(db, func(rows *sql.Rows) error {
return unmarshalRows(&value, rows, true)
}, "select value from users where user=?", "anyone"))
assert.EqualValues(t, expect, value)
})
}
func TestUnmarshalRowsInt64(t *testing.T) {
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
var expect = []int64{2, 3}
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2\n3")
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
var value []int64
assert.Nil(t, query(db, func(rows *sql.Rows) error {
return unmarshalRows(&value, rows, true)
}, "select value from users where user=?", "anyone"))
assert.EqualValues(t, expect, value)
})
}
func TestUnmarshalRowsUint(t *testing.T) {
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
var expect = []uint{2, 3}
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2\n3")
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
var value []uint
assert.Nil(t, query(db, func(rows *sql.Rows) error {
return unmarshalRows(&value, rows, true)
}, "select value from users where user=?", "anyone"))
assert.EqualValues(t, expect, value)
})
}
func TestUnmarshalRowsUint8(t *testing.T) {
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
var expect = []uint8{2, 3}
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2\n3")
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
var value []uint8
assert.Nil(t, query(db, func(rows *sql.Rows) error {
return unmarshalRows(&value, rows, true)
}, "select value from users where user=?", "anyone"))
assert.EqualValues(t, expect, value)
})
}
func TestUnmarshalRowsUint16(t *testing.T) {
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
var expect = []uint16{2, 3}
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2\n3")
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
var value []uint16
assert.Nil(t, query(db, func(rows *sql.Rows) error {
return unmarshalRows(&value, rows, true)
}, "select value from users where user=?", "anyone"))
assert.EqualValues(t, expect, value)
})
}
func TestUnmarshalRowsUint32(t *testing.T) {
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
var expect = []uint32{2, 3}
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2\n3")
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
var value []uint32
assert.Nil(t, query(db, func(rows *sql.Rows) error {
return unmarshalRows(&value, rows, true)
}, "select value from users where user=?", "anyone"))
assert.EqualValues(t, expect, value)
})
}
func TestUnmarshalRowsUint64(t *testing.T) {
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
var expect = []uint64{2, 3}
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2\n3")
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
var value []uint64
assert.Nil(t, query(db, func(rows *sql.Rows) error {
return unmarshalRows(&value, rows, true)
}, "select value from users where user=?", "anyone"))
assert.EqualValues(t, expect, value)
})
}
func TestUnmarshalRowsFloat32(t *testing.T) {
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
var expect = []float32{2, 3}
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2\n3")
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
var value []float32
assert.Nil(t, query(db, func(rows *sql.Rows) error {
return unmarshalRows(&value, rows, true)
}, "select value from users where user=?", "anyone"))
assert.EqualValues(t, expect, value)
})
}
func TestUnmarshalRowsFloat64(t *testing.T) {
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
var expect = []float64{2, 3}
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2\n3")
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
var value []float64
assert.Nil(t, query(db, func(rows *sql.Rows) error {
return unmarshalRows(&value, rows, true)
}, "select value from users where user=?", "anyone"))
assert.EqualValues(t, expect, value)
})
}
func TestUnmarshalRowsString(t *testing.T) {
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
var expect = []string{"hello", "world"}
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("hello\nworld")
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
var value []string
assert.Nil(t, query(db, func(rows *sql.Rows) error {
return unmarshalRows(&value, rows, true)
}, "select value from users where user=?", "anyone"))
assert.EqualValues(t, expect, value)
})
}
func TestUnmarshalRowsBoolPtr(t *testing.T) {
yes := true
no := false
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
var expect = []*bool{&yes, &no}
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("1\n0")
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
var value []*bool
assert.Nil(t, query(db, func(rows *sql.Rows) error {
return unmarshalRows(&value, rows, true)
}, "select value from users where user=?", "anyone"))
assert.EqualValues(t, expect, value)
})
}
func TestUnmarshalRowsIntPtr(t *testing.T) {
two := 2
three := 3
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
var expect = []*int{&two, &three}
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2\n3")
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
var value []*int
assert.Nil(t, query(db, func(rows *sql.Rows) error {
return unmarshalRows(&value, rows, true)
}, "select value from users where user=?", "anyone"))
assert.EqualValues(t, expect, value)
})
}
func TestUnmarshalRowsInt8Ptr(t *testing.T) {
two := int8(2)
three := int8(3)
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
var expect = []*int8{&two, &three}
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2\n3")
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
var value []*int8
assert.Nil(t, query(db, func(rows *sql.Rows) error {
return unmarshalRows(&value, rows, true)
}, "select value from users where user=?", "anyone"))
assert.EqualValues(t, expect, value)
})
}
func TestUnmarshalRowsInt16Ptr(t *testing.T) {
two := int16(2)
three := int16(3)
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
var expect = []*int16{&two, &three}
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2\n3")
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
var value []*int16
assert.Nil(t, query(db, func(rows *sql.Rows) error {
return unmarshalRows(&value, rows, true)
}, "select value from users where user=?", "anyone"))
assert.EqualValues(t, expect, value)
})
}
func TestUnmarshalRowsInt32Ptr(t *testing.T) {
two := int32(2)
three := int32(3)
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
var expect = []*int32{&two, &three}
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2\n3")
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
var value []*int32
assert.Nil(t, query(db, func(rows *sql.Rows) error {
return unmarshalRows(&value, rows, true)
}, "select value from users where user=?", "anyone"))
assert.EqualValues(t, expect, value)
})
}
func TestUnmarshalRowsInt64Ptr(t *testing.T) {
two := int64(2)
three := int64(3)
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
var expect = []*int64{&two, &three}
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2\n3")
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
var value []*int64
assert.Nil(t, query(db, func(rows *sql.Rows) error {
return unmarshalRows(&value, rows, true)
}, "select value from users where user=?", "anyone"))
assert.EqualValues(t, expect, value)
})
}
func TestUnmarshalRowsUintPtr(t *testing.T) {
two := uint(2)
three := uint(3)
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
var expect = []*uint{&two, &three}
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2\n3")
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
var value []*uint
assert.Nil(t, query(db, func(rows *sql.Rows) error {
return unmarshalRows(&value, rows, true)
}, "select value from users where user=?", "anyone"))
assert.EqualValues(t, expect, value)
})
}
func TestUnmarshalRowsUint8Ptr(t *testing.T) {
two := uint8(2)
three := uint8(3)
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
var expect = []*uint8{&two, &three}
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2\n3")
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
var value []*uint8
assert.Nil(t, query(db, func(rows *sql.Rows) error {
return unmarshalRows(&value, rows, true)
}, "select value from users where user=?", "anyone"))
assert.EqualValues(t, expect, value)
})
}
func TestUnmarshalRowsUint16Ptr(t *testing.T) {
two := uint16(2)
three := uint16(3)
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
var expect = []*uint16{&two, &three}
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2\n3")
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
var value []*uint16
assert.Nil(t, query(db, func(rows *sql.Rows) error {
return unmarshalRows(&value, rows, true)
}, "select value from users where user=?", "anyone"))
assert.EqualValues(t, expect, value)
})
}
func TestUnmarshalRowsUint32Ptr(t *testing.T) {
two := uint32(2)
three := uint32(3)
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
var expect = []*uint32{&two, &three}
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2\n3")
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
var value []*uint32
assert.Nil(t, query(db, func(rows *sql.Rows) error {
return unmarshalRows(&value, rows, true)
}, "select value from users where user=?", "anyone"))
assert.EqualValues(t, expect, value)
})
}
func TestUnmarshalRowsUint64Ptr(t *testing.T) {
two := uint64(2)
three := uint64(3)
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
var expect = []*uint64{&two, &three}
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2\n3")
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
var value []*uint64
assert.Nil(t, query(db, func(rows *sql.Rows) error {
return unmarshalRows(&value, rows, true)
}, "select value from users where user=?", "anyone"))
assert.EqualValues(t, expect, value)
})
}
func TestUnmarshalRowsFloat32Ptr(t *testing.T) {
two := float32(2)
three := float32(3)
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
var expect = []*float32{&two, &three}
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2\n3")
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
var value []*float32
assert.Nil(t, query(db, func(rows *sql.Rows) error {
return unmarshalRows(&value, rows, true)
}, "select value from users where user=?", "anyone"))
assert.EqualValues(t, expect, value)
})
}
func TestUnmarshalRowsFloat64Ptr(t *testing.T) {
two := float64(2)
three := float64(3)
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
var expect = []*float64{&two, &three}
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2\n3")
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
var value []*float64
assert.Nil(t, query(db, func(rows *sql.Rows) error {
return unmarshalRows(&value, rows, true)
}, "select value from users where user=?", "anyone"))
assert.EqualValues(t, expect, value)
})
}
func TestUnmarshalRowsStringPtr(t *testing.T) {
hello := "hello"
world := "world"
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
var expect = []*string{&hello, &world}
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("hello\nworld")
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
var value []*string
assert.Nil(t, query(db, func(rows *sql.Rows) error {
return unmarshalRows(&value, rows, true)
}, "select value from users where user=?", "anyone"))
assert.EqualValues(t, expect, value)
})
}
func TestUnmarshalRowsStruct(t *testing.T) {
var expect = []struct {
Name string
Age int64
}{
{
Name: "first",
Age: 2,
},
{
Name: "second",
Age: 3,
},
}
var value []struct {
Name string
Age int64
}
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
rs := sqlmock.NewRows([]string{"name", "age"}).FromCSVString("first,2\nsecond,3")
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
assert.Nil(t, query(db, func(rows *sql.Rows) error {
return unmarshalRows(&value, rows, true)
}, "select name, age from users where user=?", "anyone"))
for i, each := range expect {
assert.Equal(t, each.Name, value[i].Name)
assert.Equal(t, each.Age, value[i].Age)
}
})
}
func TestUnmarshalRowsStructWithNullStringType(t *testing.T) {
var expect = []struct {
Name string
NullString sql.NullString
}{
{
Name: "first",
NullString: sql.NullString{
String: "firstnullstring",
Valid: true,
},
},
{
Name: "second",
NullString: sql.NullString{
String: "",
Valid: false,
},
},
}
var value []struct {
Name string `db:"name"`
NullString sql.NullString `db:"value"`
}
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
rs := sqlmock.NewRows([]string{"name", "value"}).AddRow(
"first", "firstnullstring").AddRow("second", nil)
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
assert.Nil(t, query(db, func(rows *sql.Rows) error {
return unmarshalRows(&value, rows, true)
}, "select name, age from users where user=?", "anyone"))
for i, each := range expect {
assert.Equal(t, each.Name, value[i].Name)
assert.Equal(t, each.NullString.String, value[i].NullString.String)
assert.Equal(t, each.NullString.Valid, value[i].NullString.Valid)
}
})
}
func TestUnmarshalRowsStructWithTags(t *testing.T) {
var expect = []struct {
Name string
Age int64
}{
{
Name: "first",
Age: 2,
},
{
Name: "second",
Age: 3,
},
}
var value []struct {
Age int64 `db:"age"`
Name string `db:"name"`
}
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
rs := sqlmock.NewRows([]string{"name", "age"}).FromCSVString("first,2\nsecond,3")
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
assert.Nil(t, query(db, func(rows *sql.Rows) error {
return unmarshalRows(&value, rows, true)
}, "select name, age from users where user=?", "anyone"))
for i, each := range expect {
assert.Equal(t, each.Name, value[i].Name)
assert.Equal(t, each.Age, value[i].Age)
}
})
}
func TestUnmarshalRowsStructAndEmbeddedAnonymousStructWithTags(t *testing.T) {
type Embed struct {
Value int64 `db:"value"`
}
var expect = []struct {
Name string
Age int64
Value int64
}{
{
Name: "first",
Age: 2,
Value: 3,
},
{
Name: "second",
Age: 3,
Value: 4,
},
}
var value []struct {
Name string `db:"name"`
Age int64 `db:"age"`
Embed
}
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
rs := sqlmock.NewRows([]string{"name", "age", "value"}).FromCSVString("first,2,3\nsecond,3,4")
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
assert.Nil(t, query(db, func(rows *sql.Rows) error {
return unmarshalRows(&value, rows, true)
}, "select name, age, value from users where user=?", "anyone"))
for i, each := range expect {
assert.Equal(t, each.Name, value[i].Name)
assert.Equal(t, each.Age, value[i].Age)
assert.Equal(t, each.Value, value[i].Value)
}
})
}
func TestUnmarshalRowsStructAndEmbeddedStructPtrAnonymousWithTags(t *testing.T) {
type Embed struct {
Value int64 `db:"value"`
}
var expect = []struct {
Name string
Age int64
Value int64
}{
{
Name: "first",
Age: 2,
Value: 3,
},
{
Name: "second",
Age: 3,
Value: 4,
},
}
var value []struct {
Name string `db:"name"`
Age int64 `db:"age"`
*Embed
}
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
rs := sqlmock.NewRows([]string{"name", "age", "value"}).FromCSVString("first,2,3\nsecond,3,4")
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
assert.Nil(t, query(db, func(rows *sql.Rows) error {
return unmarshalRows(&value, rows, true)
}, "select name, age, value from users where user=?", "anyone"))
for i, each := range expect {
assert.Equal(t, each.Name, value[i].Name)
assert.Equal(t, each.Age, value[i].Age)
assert.Equal(t, each.Value, value[i].Value)
}
})
}
func TestUnmarshalRowsStructPtr(t *testing.T) {
var expect = []*struct {
Name string
Age int64
}{
{
Name: "first",
Age: 2,
},
{
Name: "second",
Age: 3,
},
}
var value []*struct {
Name string
Age int64
}
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
rs := sqlmock.NewRows([]string{"name", "age"}).FromCSVString("first,2\nsecond,3")
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
assert.Nil(t, query(db, func(rows *sql.Rows) error {
return unmarshalRows(&value, rows, true)
}, "select name, age from users where user=?", "anyone"))
for i, each := range expect {
assert.Equal(t, each.Name, value[i].Name)
assert.Equal(t, each.Age, value[i].Age)
}
})
}
func TestUnmarshalRowsStructWithTagsPtr(t *testing.T) {
var expect = []*struct {
Name string
Age int64
}{
{
Name: "first",
Age: 2,
},
{
Name: "second",
Age: 3,
},
}
var value []*struct {
Age int64 `db:"age"`
Name string `db:"name"`
}
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
rs := sqlmock.NewRows([]string{"name", "age"}).FromCSVString("first,2\nsecond,3")
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
assert.Nil(t, query(db, func(rows *sql.Rows) error {
return unmarshalRows(&value, rows, true)
}, "select name, age from users where user=?", "anyone"))
for i, each := range expect {
assert.Equal(t, each.Name, value[i].Name)
assert.Equal(t, each.Age, value[i].Age)
}
})
}
func TestUnmarshalRowsStructWithTagsPtrWithInnerPtr(t *testing.T) {
var expect = []*struct {
Name string
Age int64
}{
{
Name: "first",
Age: 2,
},
{
Name: "second",
Age: 3,
},
}
var value []*struct {
Age *int64 `db:"age"`
Name string `db:"name"`
}
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
rs := sqlmock.NewRows([]string{"name", "age"}).FromCSVString("first,2\nsecond,3")
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
assert.Nil(t, query(db, func(rows *sql.Rows) error {
return unmarshalRows(&value, rows, true)
}, "select name, age from users where user=?", "anyone"))
for i, each := range expect {
assert.Equal(t, each.Name, value[i].Name)
assert.Equal(t, each.Age, *value[i].Age)
}
})
}
func TestCommonSqlConn_QueryRowOptional(t *testing.T) {
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
rs := sqlmock.NewRows([]string{"age"}).FromCSVString("5")
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
var r struct {
User string `db:"user"`
Age int `db:"age"`
}
assert.Nil(t, query(db, func(rows *sql.Rows) error {
return unmarshalRow(&r, rows, false)
}, "select age from users where user=?", "anyone"))
assert.Empty(t, r.User)
assert.Equal(t, 5, r.Age)
})
}
func runOrmTest(t *testing.T, fn func(db *sql.DB, mock sqlmock.Sqlmock)) {
logx.Disable()
db, mock, err := sqlmock.New()
if err != nil {
t.Fatalf("an error '%s' was not expected when opening a stub database connection", err)
}
defer db.Close()
fn(db, mock)
if err := mock.ExpectationsWereMet(); err != nil {
t.Errorf("there were unfulfilled expectations: %s", err)
}
}

204
core/stores/sqlx/sqlconn.go Normal file
View File

@@ -0,0 +1,204 @@
package sqlx
import (
"database/sql"
"zero/core/breaker"
)
var ErrNotFound = sql.ErrNoRows
type (
// Session stands for raw connections or transaction sessions
Session interface {
Exec(query string, args ...interface{}) (sql.Result, error)
Prepare(query string) (StmtSession, error)
QueryRow(v interface{}, query string, args ...interface{}) error
QueryRowPartial(v interface{}, query string, args ...interface{}) error
QueryRows(v interface{}, query string, args ...interface{}) error
QueryRowsPartial(v interface{}, query string, args ...interface{}) error
}
// SqlConn only stands for raw connections, so Transact method can be called.
SqlConn interface {
Session
Transact(func(session Session) error) error
}
SqlOption func(*commonSqlConn)
StmtSession interface {
Close() error
Exec(args ...interface{}) (sql.Result, error)
QueryRow(v interface{}, args ...interface{}) error
QueryRowPartial(v interface{}, args ...interface{}) error
QueryRows(v interface{}, args ...interface{}) error
QueryRowsPartial(v interface{}, args ...interface{}) error
}
// thread-safe
// Because CORBA doesn't support PREPARE, so we need to combine the
// query arguments into one string and do underlying query without arguments
commonSqlConn struct {
driverName string
datasource string
beginTx beginnable
brk breaker.Breaker
accept func(error) bool
}
sessionConn interface {
Exec(query string, args ...interface{}) (sql.Result, error)
Query(query string, args ...interface{}) (*sql.Rows, error)
}
statement struct {
stmt *sql.Stmt
}
stmtConn interface {
Exec(args ...interface{}) (sql.Result, error)
Query(args ...interface{}) (*sql.Rows, error)
}
)
func NewSqlConn(driverName, datasource string, opts ...SqlOption) SqlConn {
conn := &commonSqlConn{
driverName: driverName,
datasource: datasource,
beginTx: begin,
brk: breaker.NewBreaker(),
}
for _, opt := range opts {
opt(conn)
}
return conn
}
func (db *commonSqlConn) Exec(q string, args ...interface{}) (result sql.Result, err error) {
err = db.brk.DoWithAcceptable(func() error {
var conn *sql.DB
conn, err = getSqlConn(db.driverName, db.datasource)
if err != nil {
logInstanceError(db.datasource, err)
return err
}
result, err = exec(conn, q, args...)
return err
}, db.acceptable)
return
}
func (db *commonSqlConn) Prepare(query string) (stmt StmtSession, err error) {
err = db.brk.DoWithAcceptable(func() error {
var conn *sql.DB
conn, err = getSqlConn(db.driverName, db.datasource)
if err != nil {
logInstanceError(db.datasource, err)
return err
}
if st, err := conn.Prepare(query); err != nil {
return err
} else {
stmt = statement{
stmt: st,
}
return nil
}
}, db.acceptable)
return
}
func (db *commonSqlConn) QueryRow(v interface{}, q string, args ...interface{}) error {
return db.queryRows(func(rows *sql.Rows) error {
return unmarshalRow(v, rows, true)
}, q, args...)
}
func (db *commonSqlConn) QueryRowPartial(v interface{}, q string, args ...interface{}) error {
return db.queryRows(func(rows *sql.Rows) error {
return unmarshalRow(v, rows, false)
}, q, args...)
}
func (db *commonSqlConn) QueryRows(v interface{}, q string, args ...interface{}) error {
return db.queryRows(func(rows *sql.Rows) error {
return unmarshalRows(v, rows, true)
}, q, args...)
}
func (db *commonSqlConn) QueryRowsPartial(v interface{}, q string, args ...interface{}) error {
return db.queryRows(func(rows *sql.Rows) error {
return unmarshalRows(v, rows, false)
}, q, args...)
}
func (db *commonSqlConn) Transact(fn func(Session) error) error {
return db.brk.DoWithAcceptable(func() error {
return transact(db, db.beginTx, fn)
}, db.acceptable)
}
func (db *commonSqlConn) acceptable(err error) bool {
ok := err == nil || err == sql.ErrNoRows || err == sql.ErrTxDone
if db.accept == nil {
return ok
} else {
return ok || db.accept(err)
}
}
func (db *commonSqlConn) queryRows(scanner func(*sql.Rows) error, q string, args ...interface{}) error {
var qerr error
return db.brk.DoWithAcceptable(func() error {
conn, err := getSqlConn(db.driverName, db.datasource)
if err != nil {
logInstanceError(db.datasource, err)
return err
}
return query(conn, func(rows *sql.Rows) error {
qerr = scanner(rows)
return qerr
}, q, args...)
}, func(err error) bool {
return qerr == err || db.acceptable(err)
})
}
func (s statement) Close() error {
return s.stmt.Close()
}
func (s statement) Exec(args ...interface{}) (sql.Result, error) {
return execStmt(s.stmt, args...)
}
func (s statement) QueryRow(v interface{}, args ...interface{}) error {
return queryStmt(s.stmt, func(rows *sql.Rows) error {
return unmarshalRow(v, rows, true)
}, args...)
}
func (s statement) QueryRowPartial(v interface{}, args ...interface{}) error {
return queryStmt(s.stmt, func(rows *sql.Rows) error {
return unmarshalRow(v, rows, false)
}, args...)
}
func (s statement) QueryRows(v interface{}, args ...interface{}) error {
return queryStmt(s.stmt, func(rows *sql.Rows) error {
return unmarshalRows(v, rows, true)
}, args...)
}
func (s statement) QueryRowsPartial(v interface{}, args ...interface{}) error {
return queryStmt(s.stmt, func(rows *sql.Rows) error {
return unmarshalRows(v, rows, false)
}, args...)
}

View File

@@ -0,0 +1,74 @@
package sqlx
import (
"database/sql"
"io"
"sync"
"time"
"zero/core/syncx"
)
const (
maxIdleConns = 64
maxOpenConns = 64
maxLifetime = time.Minute
)
var connManager = syncx.NewResourceManager()
type pingedDB struct {
*sql.DB
once sync.Once
}
func getCachedSqlConn(driverName, server string) (*pingedDB, error) {
val, err := connManager.GetResource(server, func() (io.Closer, error) {
conn, err := newDBConnection(driverName, server)
if err != nil {
return nil, err
}
return &pingedDB{
DB: conn,
}, nil
})
if err != nil {
return nil, err
}
return val.(*pingedDB), nil
}
func getSqlConn(driverName, server string) (*sql.DB, error) {
pdb, err := getCachedSqlConn(driverName, server)
if err != nil {
return nil, err
}
pdb.once.Do(func() {
err = pdb.Ping()
})
if err != nil {
return nil, err
}
return pdb.DB, nil
}
func newDBConnection(driverName, datasource string) (*sql.DB, error) {
conn, err := sql.Open(driverName, datasource)
if err != nil {
return nil, err
}
// we need to do this until the issue https://github.com/golang/go/issues/9851 get fixed
// discussed here https://github.com/go-sql-driver/mysql/issues/257
// if the discussed SetMaxIdleTimeout methods added, we'll change this behavior
// 8 means we can't have more than 8 goroutines to concurrently access the same database.
conn.SetMaxIdleConns(maxIdleConns)
conn.SetMaxOpenConns(maxOpenConns)
conn.SetConnMaxLifetime(maxLifetime)
return conn, nil
}

92
core/stores/sqlx/stmt.go Normal file
View File

@@ -0,0 +1,92 @@
package sqlx
import (
"database/sql"
"fmt"
"time"
"zero/core/logx"
"zero/core/timex"
)
const slowThreshold = time.Millisecond * 500
func exec(conn sessionConn, q string, args ...interface{}) (sql.Result, error) {
stmt, err := format(q, args...)
if err != nil {
return nil, err
}
startTime := timex.Now()
result, err := conn.Exec(q, args...)
duration := timex.Since(startTime)
if duration > slowThreshold {
logx.WithDuration(duration).Slowf("[SQL] exec: slowcall - %s", stmt)
} else {
logx.WithDuration(duration).Infof("sql exec: %s", stmt)
}
if err != nil {
logSqlError(stmt, err)
}
return result, err
}
func execStmt(conn stmtConn, args ...interface{}) (sql.Result, error) {
stmt := fmt.Sprint(args...)
startTime := timex.Now()
result, err := conn.Exec(args...)
duration := timex.Since(startTime)
if duration > slowThreshold {
logx.WithDuration(duration).Slowf("[SQL] execStmt: slowcall - %s", stmt)
} else {
logx.WithDuration(duration).Infof("sql execStmt: %s", stmt)
}
if err != nil {
logSqlError(stmt, err)
}
return result, err
}
func query(conn sessionConn, scanner func(*sql.Rows) error, q string, args ...interface{}) error {
stmt, err := format(q, args...)
if err != nil {
return err
}
startTime := timex.Now()
rows, err := conn.Query(q, args...)
duration := timex.Since(startTime)
if duration > slowThreshold {
logx.WithDuration(duration).Slowf("[SQL] query: slowcall - %s", stmt)
} else {
logx.WithDuration(duration).Infof("sql query: %s", stmt)
}
if err != nil {
logSqlError(stmt, err)
return err
}
defer rows.Close()
return scanner(rows)
}
func queryStmt(conn stmtConn, scanner func(*sql.Rows) error, args ...interface{}) error {
stmt := fmt.Sprint(args...)
startTime := timex.Now()
rows, err := conn.Query(args...)
duration := timex.Since(startTime)
if duration > slowThreshold {
logx.WithDuration(duration).Slowf("[SQL] queryStmt: slowcall - %s", stmt)
} else {
logx.WithDuration(duration).Infof("sql queryStmt: %s", stmt)
}
if err != nil {
logSqlError(stmt, err)
return err
}
defer rows.Close()
return scanner(rows)
}

103
core/stores/sqlx/tx.go Normal file
View File

@@ -0,0 +1,103 @@
package sqlx
import (
"database/sql"
"fmt"
)
type (
beginnable func(*sql.DB) (trans, error)
trans interface {
Session
Commit() error
Rollback() error
}
txSession struct {
*sql.Tx
}
)
func (t txSession) Exec(q string, args ...interface{}) (sql.Result, error) {
return exec(t.Tx, q, args...)
}
func (t txSession) Prepare(q string) (StmtSession, error) {
if stmt, err := t.Tx.Prepare(q); err != nil {
return nil, err
} else {
return statement{
stmt: stmt,
}, nil
}
}
func (t txSession) QueryRow(v interface{}, q string, args ...interface{}) error {
return query(t.Tx, func(rows *sql.Rows) error {
return unmarshalRow(v, rows, true)
}, q, args...)
}
func (t txSession) QueryRowPartial(v interface{}, q string, args ...interface{}) error {
return query(t.Tx, func(rows *sql.Rows) error {
return unmarshalRow(v, rows, false)
}, q, args...)
}
func (t txSession) QueryRows(v interface{}, q string, args ...interface{}) error {
return query(t.Tx, func(rows *sql.Rows) error {
return unmarshalRows(v, rows, true)
}, q, args...)
}
func (t txSession) QueryRowsPartial(v interface{}, q string, args ...interface{}) error {
return query(t.Tx, func(rows *sql.Rows) error {
return unmarshalRows(v, rows, false)
}, q, args...)
}
func begin(db *sql.DB) (trans, error) {
if tx, err := db.Begin(); err != nil {
return nil, err
} else {
return txSession{
Tx: tx,
}, nil
}
}
func transact(db *commonSqlConn, b beginnable, fn func(Session) error) (err error) {
conn, err := getSqlConn(db.driverName, db.datasource)
if err != nil {
logInstanceError(db.datasource, err)
return err
}
return transactOnConn(conn, b, fn)
}
func transactOnConn(conn *sql.DB, b beginnable, fn func(Session) error) (err error) {
var tx trans
tx, err = b(conn)
if err != nil {
return
}
defer func() {
if p := recover(); p != nil {
if e := tx.Rollback(); e != nil {
err = fmt.Errorf("recover from %#v, rollback failed: %s", p, e)
} else {
err = fmt.Errorf("recoveer from %#v", p)
}
} else if err != nil {
if e := tx.Rollback(); e != nil {
err = fmt.Errorf("transaction failed: %s, rollback failed: %s", err, e)
}
} else {
err = tx.Commit()
}
}()
return fn(tx)
}

View File

@@ -0,0 +1,76 @@
package sqlx
import (
"database/sql"
"errors"
"testing"
"github.com/stretchr/testify/assert"
)
const (
mockCommit = 1
mockRollback = 2
)
type mockTx struct {
status int
}
func (mt *mockTx) Commit() error {
mt.status |= mockCommit
return nil
}
func (mt *mockTx) Exec(q string, args ...interface{}) (sql.Result, error) {
return nil, nil
}
func (mt *mockTx) Prepare(query string) (StmtSession, error) {
return nil, nil
}
func (mt *mockTx) QueryRow(v interface{}, q string, args ...interface{}) error {
return nil
}
func (mt *mockTx) QueryRowPartial(v interface{}, q string, args ...interface{}) error {
return nil
}
func (mt *mockTx) QueryRows(v interface{}, q string, args ...interface{}) error {
return nil
}
func (mt *mockTx) QueryRowsPartial(v interface{}, q string, args ...interface{}) error {
return nil
}
func (mt *mockTx) Rollback() error {
mt.status |= mockRollback
return nil
}
func beginMock(mock *mockTx) beginnable {
return func(*sql.DB) (trans, error) {
return mock, nil
}
}
func TestTransactCommit(t *testing.T) {
mock := &mockTx{}
err := transactOnConn(nil, beginMock(mock), func(Session) error {
return nil
})
assert.Equal(t, mockCommit, mock.status)
assert.Nil(t, err)
}
func TestTransactRollback(t *testing.T) {
mock := &mockTx{}
err := transactOnConn(nil, beginMock(mock), func(Session) error {
return errors.New("rollback")
})
assert.Equal(t, mockRollback, mock.status)
assert.NotNil(t, err)
}

101
core/stores/sqlx/utils.go Normal file
View File

@@ -0,0 +1,101 @@
package sqlx
import (
"fmt"
"strings"
"zero/core/logx"
"zero/core/mapping"
)
func desensitize(datasource string) string {
// remove account
pos := strings.LastIndex(datasource, "@")
if 0 <= pos && pos+1 < len(datasource) {
datasource = datasource[pos+1:]
}
return datasource
}
func escape(input string) string {
var b strings.Builder
for _, ch := range input {
switch ch {
case '\x00':
b.WriteString(`\x00`)
case '\r':
b.WriteString(`\r`)
case '\n':
b.WriteString(`\n`)
case '\\':
b.WriteString(`\\`)
case '\'':
b.WriteString(`\'`)
case '"':
b.WriteString(`\"`)
case '\x1a':
b.WriteString(`\x1a`)
default:
b.WriteRune(ch)
}
}
return b.String()
}
func format(query string, args ...interface{}) (string, error) {
numArgs := len(args)
if numArgs == 0 {
return query, nil
}
var b strings.Builder
argIndex := 0
for _, ch := range query {
if ch == '?' {
if argIndex >= numArgs {
return "", fmt.Errorf("error: %d ? in sql, but less arguments provided", argIndex)
}
arg := args[argIndex]
argIndex++
switch v := arg.(type) {
case bool:
if v {
b.WriteByte('1')
} else {
b.WriteByte('0')
}
case string:
b.WriteByte('\'')
b.WriteString(escape(v))
b.WriteByte('\'')
default:
b.WriteString(mapping.Repr(v))
}
} else {
b.WriteRune(ch)
}
}
if argIndex < numArgs {
return "", fmt.Errorf("error: %d ? in sql, but more arguments provided", argIndex)
}
return b.String(), nil
}
func logInstanceError(datasource string, err error) {
datasource = desensitize(datasource)
logx.Errorf("Error on getting sql instance of %s: %v", datasource, err)
}
func logSqlError(stmt string, err error) {
if err != nil && err != ErrNotFound {
logx.Errorf("stmt: %s, error: %s", stmt, err.Error())
}
}

View File

@@ -0,0 +1,30 @@
package sqlx
import (
"strings"
"testing"
"github.com/stretchr/testify/assert"
)
func TestEscape(t *testing.T) {
s := "a\x00\n\r\\'\"\x1ab"
out := escape(s)
assert.Equal(t, `a\x00\n\r\\\'\"\x1ab`, out)
}
func TestDesensitize(t *testing.T) {
datasource := "user:pass@tcp(111.222.333.44:3306)/any_table?charset=utf8mb4&parseTime=true&loc=Asia%2FShanghai"
datasource = desensitize(datasource)
assert.False(t, strings.Contains(datasource, "user"))
assert.False(t, strings.Contains(datasource, "pass"))
assert.True(t, strings.Contains(datasource, "tcp(111.222.333.44:3306)"))
}
func TestDesensitize_WithoutAccount(t *testing.T) {
datasource := "tcp(111.222.333.44:3306)/any_table?charset=utf8mb4&parseTime=true&loc=Asia%2FShanghai"
datasource = desensitize(datasource)
assert.True(t, strings.Contains(datasource, "tcp(111.222.333.44:3306)"))
}