initial import
This commit is contained in:
187
core/stores/sqlx/bulkinserter.go
Normal file
187
core/stores/sqlx/bulkinserter.go
Normal file
@@ -0,0 +1,187 @@
|
||||
package sqlx
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"zero/core/executors"
|
||||
"zero/core/logx"
|
||||
"zero/core/stringx"
|
||||
)
|
||||
|
||||
const (
|
||||
flushInterval = time.Second
|
||||
maxBulkRows = 1000
|
||||
valuesKeyword = "values"
|
||||
)
|
||||
|
||||
var emptyBulkStmt bulkStmt
|
||||
|
||||
type (
|
||||
ResultHandler func(sql.Result, error)
|
||||
|
||||
BulkInserter struct {
|
||||
executor *executors.PeriodicalExecutor
|
||||
inserter *dbInserter
|
||||
stmt bulkStmt
|
||||
}
|
||||
|
||||
bulkStmt struct {
|
||||
prefix string
|
||||
valueFormat string
|
||||
suffix string
|
||||
}
|
||||
)
|
||||
|
||||
func NewBulkInserter(sqlConn SqlConn, stmt string) (*BulkInserter, error) {
|
||||
bkStmt, err := parseInsertStmt(stmt)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
inserter := &dbInserter{
|
||||
sqlConn: sqlConn,
|
||||
stmt: bkStmt,
|
||||
}
|
||||
|
||||
return &BulkInserter{
|
||||
executor: executors.NewPeriodicalExecutor(flushInterval, inserter),
|
||||
inserter: inserter,
|
||||
stmt: bkStmt,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (bi *BulkInserter) Flush() {
|
||||
bi.executor.Flush()
|
||||
}
|
||||
|
||||
func (bi *BulkInserter) Insert(args ...interface{}) error {
|
||||
value, err := format(bi.stmt.valueFormat, args...)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
bi.executor.Add(value)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (bi *BulkInserter) SetResultHandler(handler ResultHandler) {
|
||||
bi.executor.Sync(func() {
|
||||
bi.inserter.resultHandler = handler
|
||||
})
|
||||
}
|
||||
|
||||
func (bi *BulkInserter) UpdateOrDelete(fn func()) {
|
||||
bi.executor.Flush()
|
||||
fn()
|
||||
}
|
||||
|
||||
func (bi *BulkInserter) UpdateStmt(stmt string) error {
|
||||
bkStmt, err := parseInsertStmt(stmt)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
bi.executor.Flush()
|
||||
bi.executor.Sync(func() {
|
||||
bi.inserter.stmt = bkStmt
|
||||
})
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
type dbInserter struct {
|
||||
sqlConn SqlConn
|
||||
stmt bulkStmt
|
||||
values []string
|
||||
resultHandler ResultHandler
|
||||
}
|
||||
|
||||
func (in *dbInserter) AddTask(task interface{}) bool {
|
||||
in.values = append(in.values, task.(string))
|
||||
return len(in.values) >= maxBulkRows
|
||||
}
|
||||
|
||||
func (in *dbInserter) Execute(bulk interface{}) {
|
||||
values := bulk.([]string)
|
||||
if len(values) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
stmtWithoutValues := in.stmt.prefix
|
||||
valuesStr := strings.Join(values, ", ")
|
||||
stmt := strings.Join([]string{stmtWithoutValues, valuesStr}, " ")
|
||||
if len(in.stmt.suffix) > 0 {
|
||||
stmt = strings.Join([]string{stmt, in.stmt.suffix}, " ")
|
||||
}
|
||||
result, err := in.sqlConn.Exec(stmt)
|
||||
if in.resultHandler != nil {
|
||||
in.resultHandler(result, err)
|
||||
} else if err != nil {
|
||||
logx.Errorf("sql: %s, error: %s", stmt, err)
|
||||
}
|
||||
}
|
||||
|
||||
func (in *dbInserter) RemoveAll() interface{} {
|
||||
values := in.values
|
||||
in.values = nil
|
||||
return values
|
||||
}
|
||||
|
||||
func parseInsertStmt(stmt string) (bulkStmt, error) {
|
||||
lower := strings.ToLower(stmt)
|
||||
pos := strings.Index(lower, valuesKeyword)
|
||||
if pos <= 0 {
|
||||
return emptyBulkStmt, fmt.Errorf("bad sql: %q", stmt)
|
||||
}
|
||||
|
||||
var columns int
|
||||
right := strings.LastIndexByte(lower[:pos], ')')
|
||||
if right > 0 {
|
||||
left := strings.LastIndexByte(lower[:right], '(')
|
||||
if left > 0 {
|
||||
values := lower[left+1 : right]
|
||||
values = stringx.Filter(values, func(r rune) bool {
|
||||
return r == ' ' || r == '\t' || r == '\r' || r == '\n'
|
||||
})
|
||||
fields := strings.FieldsFunc(values, func(r rune) bool {
|
||||
return r == ','
|
||||
})
|
||||
columns = len(fields)
|
||||
}
|
||||
}
|
||||
|
||||
var variables int
|
||||
var valueFormat string
|
||||
var suffix string
|
||||
left := strings.IndexByte(lower[pos:], '(')
|
||||
if left > 0 {
|
||||
right = strings.IndexByte(lower[pos+left:], ')')
|
||||
if right > 0 {
|
||||
values := lower[pos+left : pos+left+right]
|
||||
for _, x := range values {
|
||||
if x == '?' {
|
||||
variables++
|
||||
}
|
||||
}
|
||||
valueFormat = stmt[pos+left : pos+left+right+1]
|
||||
suffix = strings.TrimSpace(stmt[pos+left+right+1:])
|
||||
}
|
||||
}
|
||||
|
||||
if variables == 0 {
|
||||
return emptyBulkStmt, fmt.Errorf("no variables: %q", stmt)
|
||||
}
|
||||
if columns > 0 && columns != variables {
|
||||
return emptyBulkStmt, fmt.Errorf("columns and variables mismatch: %q", stmt)
|
||||
}
|
||||
|
||||
return bulkStmt{
|
||||
prefix: stmt[:pos+len(valuesKeyword)],
|
||||
valueFormat: valueFormat,
|
||||
suffix: suffix,
|
||||
}, nil
|
||||
}
|
||||
98
core/stores/sqlx/bulkinserter_test.go
Normal file
98
core/stores/sqlx/bulkinserter_test.go
Normal file
@@ -0,0 +1,98 @@
|
||||
package sqlx
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"strconv"
|
||||
"testing"
|
||||
|
||||
"zero/core/logx"
|
||||
|
||||
"github.com/DATA-DOG/go-sqlmock"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
type mockedConn struct {
|
||||
query string
|
||||
args []interface{}
|
||||
}
|
||||
|
||||
func (c *mockedConn) Exec(query string, args ...interface{}) (sql.Result, error) {
|
||||
c.query = query
|
||||
c.args = args
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (c *mockedConn) Prepare(query string) (StmtSession, error) {
|
||||
panic("should not called")
|
||||
}
|
||||
|
||||
func (c *mockedConn) QueryRow(v interface{}, query string, args ...interface{}) error {
|
||||
panic("should not called")
|
||||
}
|
||||
|
||||
func (c *mockedConn) QueryRowPartial(v interface{}, query string, args ...interface{}) error {
|
||||
panic("should not called")
|
||||
}
|
||||
|
||||
func (c *mockedConn) QueryRows(v interface{}, query string, args ...interface{}) error {
|
||||
panic("should not called")
|
||||
}
|
||||
|
||||
func (c *mockedConn) QueryRowsPartial(v interface{}, query string, args ...interface{}) error {
|
||||
panic("should not called")
|
||||
}
|
||||
|
||||
func (c *mockedConn) Transact(func(session Session) error) error {
|
||||
panic("should not called")
|
||||
}
|
||||
|
||||
func TestBulkInserter(t *testing.T) {
|
||||
runSqlTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||
var conn mockedConn
|
||||
inserter, err := NewBulkInserter(&conn, `INSERT INTO classroom_dau(classroom, user, count) VALUES(?, ?, ?)`)
|
||||
assert.Nil(t, err)
|
||||
for i := 0; i < 5; i++ {
|
||||
assert.Nil(t, inserter.Insert("class_"+strconv.Itoa(i), "user_"+strconv.Itoa(i), i))
|
||||
}
|
||||
inserter.Flush()
|
||||
assert.Equal(t, `INSERT INTO classroom_dau(classroom, user, count) VALUES `+
|
||||
`('class_0', 'user_0', 0), ('class_1', 'user_1', 1), ('class_2', 'user_2', 2), `+
|
||||
`('class_3', 'user_3', 3), ('class_4', 'user_4', 4)`,
|
||||
conn.query)
|
||||
assert.Nil(t, conn.args)
|
||||
})
|
||||
}
|
||||
|
||||
func TestBulkInserterSuffix(t *testing.T) {
|
||||
runSqlTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||
var conn mockedConn
|
||||
inserter, err := NewBulkInserter(&conn, `INSERT INTO classroom_dau(classroom, user, count) VALUES`+
|
||||
`(?, ?, ?) ON DUPLICATE KEY UPDATE is_overtime=VALUES(is_overtime)`)
|
||||
assert.Nil(t, err)
|
||||
for i := 0; i < 5; i++ {
|
||||
assert.Nil(t, inserter.Insert("class_"+strconv.Itoa(i), "user_"+strconv.Itoa(i), i))
|
||||
}
|
||||
inserter.Flush()
|
||||
assert.Equal(t, `INSERT INTO classroom_dau(classroom, user, count) VALUES `+
|
||||
`('class_0', 'user_0', 0), ('class_1', 'user_1', 1), ('class_2', 'user_2', 2), `+
|
||||
`('class_3', 'user_3', 3), ('class_4', 'user_4', 4) ON DUPLICATE KEY UPDATE is_overtime=VALUES(is_overtime)`,
|
||||
conn.query)
|
||||
assert.Nil(t, conn.args)
|
||||
})
|
||||
}
|
||||
|
||||
func runSqlTest(t *testing.T, fn func(db *sql.DB, mock sqlmock.Sqlmock)) {
|
||||
logx.Disable()
|
||||
|
||||
db, mock, err := sqlmock.New()
|
||||
if err != nil {
|
||||
t.Fatalf("an error '%s' was not expected when opening a stub database connection", err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
fn(db, mock)
|
||||
|
||||
if err := mock.ExpectationsWereMet(); err != nil {
|
||||
t.Errorf("there were unfulfilled expectations: %s", err)
|
||||
}
|
||||
}
|
||||
37
core/stores/sqlx/mysql.go
Normal file
37
core/stores/sqlx/mysql.go
Normal file
@@ -0,0 +1,37 @@
|
||||
package sqlx
|
||||
|
||||
import "github.com/go-sql-driver/mysql"
|
||||
|
||||
const (
|
||||
mysqlDriverName = "mysql"
|
||||
duplicateEntryCode uint16 = 1062
|
||||
)
|
||||
|
||||
func NewMysql(datasource string, opts ...SqlOption) SqlConn {
|
||||
opts = append(opts, withMysqlAcceptable())
|
||||
return NewSqlConn(mysqlDriverName, datasource, opts...)
|
||||
}
|
||||
|
||||
func mysqlAcceptable(err error) bool {
|
||||
if err == nil {
|
||||
return true
|
||||
}
|
||||
|
||||
myerr, ok := err.(*mysql.MySQLError)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
|
||||
switch myerr.Number {
|
||||
case duplicateEntryCode:
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
func withMysqlAcceptable() SqlOption {
|
||||
return func(conn *commonSqlConn) {
|
||||
conn.accept = mysqlAcceptable
|
||||
}
|
||||
}
|
||||
56
core/stores/sqlx/mysql_test.go
Normal file
56
core/stores/sqlx/mysql_test.go
Normal file
@@ -0,0 +1,56 @@
|
||||
package sqlx
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"zero/core/breaker"
|
||||
"zero/core/logx"
|
||||
"zero/core/stat"
|
||||
|
||||
"github.com/go-sql-driver/mysql"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func init() {
|
||||
stat.SetReporter(nil)
|
||||
}
|
||||
|
||||
func TestBreakerOnDuplicateEntry(t *testing.T) {
|
||||
logx.Disable()
|
||||
|
||||
err := tryOnDuplicateEntryError(t, mysqlAcceptable)
|
||||
assert.Equal(t, duplicateEntryCode, err.(*mysql.MySQLError).Number)
|
||||
}
|
||||
|
||||
func TestBreakerOnNotHandlingDuplicateEntry(t *testing.T) {
|
||||
logx.Disable()
|
||||
|
||||
var found bool
|
||||
for i := 0; i < 100; i++ {
|
||||
if tryOnDuplicateEntryError(t, nil) == breaker.ErrServiceUnavailable {
|
||||
found = true
|
||||
}
|
||||
}
|
||||
assert.True(t, found)
|
||||
}
|
||||
|
||||
func tryOnDuplicateEntryError(t *testing.T, accept func(error) bool) error {
|
||||
logx.Disable()
|
||||
|
||||
conn := commonSqlConn{
|
||||
brk: breaker.NewBreaker(),
|
||||
accept: accept,
|
||||
}
|
||||
for i := 0; i < 1000; i++ {
|
||||
assert.NotNil(t, conn.brk.DoWithAcceptable(func() error {
|
||||
return &mysql.MySQLError{
|
||||
Number: duplicateEntryCode,
|
||||
}
|
||||
}, conn.acceptable))
|
||||
}
|
||||
return conn.brk.DoWithAcceptable(func() error {
|
||||
return &mysql.MySQLError{
|
||||
Number: duplicateEntryCode,
|
||||
}
|
||||
}, conn.acceptable)
|
||||
}
|
||||
254
core/stores/sqlx/orm.go
Normal file
254
core/stores/sqlx/orm.go
Normal file
@@ -0,0 +1,254 @@
|
||||
package sqlx
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"reflect"
|
||||
"strings"
|
||||
|
||||
"zero/core/mapping"
|
||||
)
|
||||
|
||||
const tagName = "db"
|
||||
|
||||
var (
|
||||
ErrNotMatchDestination = errors.New("not matching destination to scan")
|
||||
ErrNotReadableValue = errors.New("value not addressable or interfaceable")
|
||||
ErrNotSettable = errors.New("passed in variable is not settable")
|
||||
ErrUnsupportedValueType = errors.New("unsupported unmarshal type")
|
||||
)
|
||||
|
||||
type rowsScanner interface {
|
||||
Columns() ([]string, error)
|
||||
Err() error
|
||||
Next() bool
|
||||
Scan(v ...interface{}) error
|
||||
}
|
||||
|
||||
func getTaggedFieldValueMap(v reflect.Value) (map[string]interface{}, error) {
|
||||
rt := mapping.Deref(v.Type())
|
||||
size := rt.NumField()
|
||||
result := make(map[string]interface{}, size)
|
||||
|
||||
for i := 0; i < size; i++ {
|
||||
key := parseTagName(rt.Field(i))
|
||||
if len(key) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
valueField := reflect.Indirect(v).Field(i)
|
||||
switch valueField.Kind() {
|
||||
case reflect.Ptr:
|
||||
if !valueField.CanInterface() {
|
||||
return nil, ErrNotReadableValue
|
||||
}
|
||||
if valueField.IsNil() {
|
||||
baseValueType := mapping.Deref(valueField.Type())
|
||||
valueField.Set(reflect.New(baseValueType))
|
||||
}
|
||||
result[key] = valueField.Interface()
|
||||
default:
|
||||
if !valueField.CanAddr() || !valueField.Addr().CanInterface() {
|
||||
return nil, ErrNotReadableValue
|
||||
}
|
||||
result[key] = valueField.Addr().Interface()
|
||||
}
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func mapStructFieldsIntoSlice(v reflect.Value, columns []string, strict bool) ([]interface{}, error) {
|
||||
fields := unwrapFields(v)
|
||||
if strict && len(columns) < len(fields) {
|
||||
return nil, ErrNotMatchDestination
|
||||
}
|
||||
|
||||
taggedMap, err := getTaggedFieldValueMap(v)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
values := make([]interface{}, len(columns))
|
||||
if len(taggedMap) == 0 {
|
||||
for i := 0; i < len(values); i++ {
|
||||
valueField := fields[i]
|
||||
switch valueField.Kind() {
|
||||
case reflect.Ptr:
|
||||
if !valueField.CanInterface() {
|
||||
return nil, ErrNotReadableValue
|
||||
}
|
||||
if valueField.IsNil() {
|
||||
baseValueType := mapping.Deref(valueField.Type())
|
||||
valueField.Set(reflect.New(baseValueType))
|
||||
}
|
||||
values[i] = valueField.Interface()
|
||||
default:
|
||||
if !valueField.CanAddr() || !valueField.Addr().CanInterface() {
|
||||
return nil, ErrNotReadableValue
|
||||
}
|
||||
values[i] = valueField.Addr().Interface()
|
||||
}
|
||||
}
|
||||
} else {
|
||||
for i, column := range columns {
|
||||
if tagged, ok := taggedMap[column]; ok {
|
||||
values[i] = tagged
|
||||
} else {
|
||||
var anonymous interface{}
|
||||
values[i] = &anonymous
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return values, nil
|
||||
}
|
||||
|
||||
func parseTagName(field reflect.StructField) string {
|
||||
key := field.Tag.Get(tagName)
|
||||
if len(key) == 0 {
|
||||
return ""
|
||||
} else {
|
||||
options := strings.Split(key, ",")
|
||||
return options[0]
|
||||
}
|
||||
}
|
||||
|
||||
func unmarshalRow(v interface{}, scanner rowsScanner, strict bool) error {
|
||||
if !scanner.Next() {
|
||||
if err := scanner.Err(); err != nil {
|
||||
return err
|
||||
}
|
||||
return ErrNotFound
|
||||
}
|
||||
|
||||
rv := reflect.ValueOf(v)
|
||||
if err := mapping.ValidatePtr(&rv); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
rte := reflect.TypeOf(v).Elem()
|
||||
rve := rv.Elem()
|
||||
switch rte.Kind() {
|
||||
case reflect.Bool,
|
||||
reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64,
|
||||
reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64,
|
||||
reflect.Float32, reflect.Float64,
|
||||
reflect.String:
|
||||
if rve.CanSet() {
|
||||
return scanner.Scan(v)
|
||||
} else {
|
||||
return ErrNotSettable
|
||||
}
|
||||
case reflect.Struct:
|
||||
columns, err := scanner.Columns()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if values, err := mapStructFieldsIntoSlice(rve, columns, strict); err != nil {
|
||||
return err
|
||||
} else {
|
||||
return scanner.Scan(values...)
|
||||
}
|
||||
default:
|
||||
return ErrUnsupportedValueType
|
||||
}
|
||||
}
|
||||
|
||||
func unmarshalRows(v interface{}, scanner rowsScanner, strict bool) error {
|
||||
rv := reflect.ValueOf(v)
|
||||
if err := mapping.ValidatePtr(&rv); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
rt := reflect.TypeOf(v)
|
||||
rte := rt.Elem()
|
||||
rve := rv.Elem()
|
||||
switch rte.Kind() {
|
||||
case reflect.Slice:
|
||||
if rve.CanSet() {
|
||||
ptr := rte.Elem().Kind() == reflect.Ptr
|
||||
appendFn := func(item reflect.Value) {
|
||||
if ptr {
|
||||
rve.Set(reflect.Append(rve, item))
|
||||
} else {
|
||||
rve.Set(reflect.Append(rve, reflect.Indirect(item)))
|
||||
}
|
||||
}
|
||||
fillFn := func(value interface{}) error {
|
||||
if rve.CanSet() {
|
||||
if err := scanner.Scan(value); err != nil {
|
||||
return err
|
||||
} else {
|
||||
appendFn(reflect.ValueOf(value))
|
||||
return nil
|
||||
}
|
||||
}
|
||||
return ErrNotSettable
|
||||
}
|
||||
|
||||
base := mapping.Deref(rte.Elem())
|
||||
switch base.Kind() {
|
||||
case reflect.Bool,
|
||||
reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64,
|
||||
reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64,
|
||||
reflect.Float32, reflect.Float64,
|
||||
reflect.String:
|
||||
for scanner.Next() {
|
||||
value := reflect.New(base)
|
||||
if err := fillFn(value.Interface()); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
case reflect.Struct:
|
||||
columns, err := scanner.Columns()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for scanner.Next() {
|
||||
value := reflect.New(base)
|
||||
if values, err := mapStructFieldsIntoSlice(value, columns, strict); err != nil {
|
||||
return err
|
||||
} else {
|
||||
if err := scanner.Scan(values...); err != nil {
|
||||
return err
|
||||
} else {
|
||||
appendFn(value)
|
||||
}
|
||||
}
|
||||
}
|
||||
default:
|
||||
return ErrUnsupportedValueType
|
||||
}
|
||||
|
||||
return nil
|
||||
} else {
|
||||
return ErrNotSettable
|
||||
}
|
||||
default:
|
||||
return ErrUnsupportedValueType
|
||||
}
|
||||
}
|
||||
|
||||
func unwrapFields(v reflect.Value) []reflect.Value {
|
||||
var fields []reflect.Value
|
||||
indirect := reflect.Indirect(v)
|
||||
|
||||
for i := 0; i < indirect.NumField(); i++ {
|
||||
child := indirect.Field(i)
|
||||
if child.Kind() == reflect.Ptr && child.IsNil() {
|
||||
baseValueType := mapping.Deref(child.Type())
|
||||
child.Set(reflect.New(baseValueType))
|
||||
}
|
||||
|
||||
child = reflect.Indirect(child)
|
||||
childType := indirect.Type().Field(i)
|
||||
if child.Kind() == reflect.Struct && childType.Anonymous {
|
||||
fields = append(fields, unwrapFields(child)...)
|
||||
} else {
|
||||
fields = append(fields, child)
|
||||
}
|
||||
}
|
||||
|
||||
return fields
|
||||
}
|
||||
973
core/stores/sqlx/orm_test.go
Normal file
973
core/stores/sqlx/orm_test.go
Normal file
@@ -0,0 +1,973 @@
|
||||
package sqlx
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"testing"
|
||||
|
||||
"zero/core/logx"
|
||||
|
||||
"github.com/DATA-DOG/go-sqlmock"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestUnmarshalRowBool(t *testing.T) {
|
||||
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("1")
|
||||
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
||||
|
||||
var value bool
|
||||
assert.Nil(t, query(db, func(rows *sql.Rows) error {
|
||||
return unmarshalRow(&value, rows, true)
|
||||
}, "select value from users where user=?", "anyone"))
|
||||
assert.True(t, value)
|
||||
})
|
||||
}
|
||||
|
||||
func TestUnmarshalRowInt(t *testing.T) {
|
||||
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2")
|
||||
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
||||
|
||||
var value int
|
||||
assert.Nil(t, query(db, func(rows *sql.Rows) error {
|
||||
return unmarshalRow(&value, rows, true)
|
||||
}, "select value from users where user=?", "anyone"))
|
||||
assert.EqualValues(t, 2, value)
|
||||
})
|
||||
}
|
||||
|
||||
func TestUnmarshalRowInt8(t *testing.T) {
|
||||
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("3")
|
||||
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
||||
|
||||
var value int8
|
||||
assert.Nil(t, query(db, func(rows *sql.Rows) error {
|
||||
return unmarshalRow(&value, rows, true)
|
||||
}, "select value from users where user=?", "anyone"))
|
||||
assert.EqualValues(t, int8(3), value)
|
||||
})
|
||||
}
|
||||
|
||||
func TestUnmarshalRowInt16(t *testing.T) {
|
||||
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("4")
|
||||
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
||||
|
||||
var value int16
|
||||
assert.Nil(t, query(db, func(rows *sql.Rows) error {
|
||||
return unmarshalRow(&value, rows, true)
|
||||
}, "select value from users where user=?", "anyone"))
|
||||
assert.Equal(t, int16(4), value)
|
||||
})
|
||||
}
|
||||
|
||||
func TestUnmarshalRowInt32(t *testing.T) {
|
||||
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("5")
|
||||
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
||||
|
||||
var value int32
|
||||
assert.Nil(t, query(db, func(rows *sql.Rows) error {
|
||||
return unmarshalRow(&value, rows, true)
|
||||
}, "select value from users where user=?", "anyone"))
|
||||
assert.Equal(t, int32(5), value)
|
||||
})
|
||||
}
|
||||
|
||||
func TestUnmarshalRowInt64(t *testing.T) {
|
||||
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("6")
|
||||
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
||||
|
||||
var value int64
|
||||
assert.Nil(t, query(db, func(rows *sql.Rows) error {
|
||||
return unmarshalRow(&value, rows, true)
|
||||
}, "select value from users where user=?", "anyone"))
|
||||
assert.EqualValues(t, int64(6), value)
|
||||
})
|
||||
}
|
||||
|
||||
func TestUnmarshalRowUint(t *testing.T) {
|
||||
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2")
|
||||
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
||||
|
||||
var value uint
|
||||
assert.Nil(t, query(db, func(rows *sql.Rows) error {
|
||||
return unmarshalRow(&value, rows, true)
|
||||
}, "select value from users where user=?", "anyone"))
|
||||
assert.EqualValues(t, uint(2), value)
|
||||
})
|
||||
}
|
||||
|
||||
func TestUnmarshalRowUint8(t *testing.T) {
|
||||
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("3")
|
||||
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
||||
|
||||
var value uint8
|
||||
assert.Nil(t, query(db, func(rows *sql.Rows) error {
|
||||
return unmarshalRow(&value, rows, true)
|
||||
}, "select value from users where user=?", "anyone"))
|
||||
assert.EqualValues(t, uint8(3), value)
|
||||
})
|
||||
}
|
||||
|
||||
func TestUnmarshalRowUint16(t *testing.T) {
|
||||
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("4")
|
||||
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
||||
|
||||
var value uint16
|
||||
assert.Nil(t, query(db, func(rows *sql.Rows) error {
|
||||
return unmarshalRow(&value, rows, true)
|
||||
}, "select value from users where user=?", "anyone"))
|
||||
assert.EqualValues(t, uint16(4), value)
|
||||
})
|
||||
}
|
||||
|
||||
func TestUnmarshalRowUint32(t *testing.T) {
|
||||
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("5")
|
||||
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
||||
|
||||
var value uint32
|
||||
assert.Nil(t, query(db, func(rows *sql.Rows) error {
|
||||
return unmarshalRow(&value, rows, true)
|
||||
}, "select value from users where user=?", "anyone"))
|
||||
assert.EqualValues(t, uint32(5), value)
|
||||
})
|
||||
}
|
||||
|
||||
func TestUnmarshalRowUint64(t *testing.T) {
|
||||
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("6")
|
||||
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
||||
|
||||
var value uint64
|
||||
assert.Nil(t, query(db, func(rows *sql.Rows) error {
|
||||
return unmarshalRow(&value, rows, true)
|
||||
}, "select value from users where user=?", "anyone"))
|
||||
assert.EqualValues(t, uint16(6), value)
|
||||
})
|
||||
}
|
||||
|
||||
func TestUnmarshalRowFloat32(t *testing.T) {
|
||||
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("7")
|
||||
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
||||
|
||||
var value float32
|
||||
assert.Nil(t, query(db, func(rows *sql.Rows) error {
|
||||
return unmarshalRow(&value, rows, true)
|
||||
}, "select value from users where user=?", "anyone"))
|
||||
assert.EqualValues(t, float32(7), value)
|
||||
})
|
||||
}
|
||||
|
||||
func TestUnmarshalRowFloat64(t *testing.T) {
|
||||
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("8")
|
||||
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
||||
|
||||
var value float64
|
||||
assert.Nil(t, query(db, func(rows *sql.Rows) error {
|
||||
return unmarshalRow(&value, rows, true)
|
||||
}, "select value from users where user=?", "anyone"))
|
||||
assert.EqualValues(t, float64(8), value)
|
||||
})
|
||||
}
|
||||
|
||||
func TestUnmarshalRowString(t *testing.T) {
|
||||
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||
const expect = "hello"
|
||||
rs := sqlmock.NewRows([]string{"value"}).FromCSVString(expect)
|
||||
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
||||
|
||||
var value string
|
||||
assert.Nil(t, query(db, func(rows *sql.Rows) error {
|
||||
return unmarshalRow(&value, rows, true)
|
||||
}, "select value from users where user=?", "anyone"))
|
||||
assert.EqualValues(t, expect, value)
|
||||
})
|
||||
}
|
||||
|
||||
func TestUnmarshalRowStruct(t *testing.T) {
|
||||
var value = new(struct {
|
||||
Name string
|
||||
Age int
|
||||
})
|
||||
|
||||
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||
rs := sqlmock.NewRows([]string{"name", "age"}).FromCSVString("liao,5")
|
||||
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
||||
|
||||
assert.Nil(t, query(db, func(rows *sql.Rows) error {
|
||||
return unmarshalRow(value, rows, true)
|
||||
}, "select name, age from users where user=?", "anyone"))
|
||||
assert.Equal(t, "liao", value.Name)
|
||||
assert.Equal(t, 5, value.Age)
|
||||
})
|
||||
}
|
||||
|
||||
func TestUnmarshalRowStructWithTags(t *testing.T) {
|
||||
var value = new(struct {
|
||||
Age int `db:"age"`
|
||||
Name string `db:"name"`
|
||||
})
|
||||
|
||||
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||
rs := sqlmock.NewRows([]string{"name", "age"}).FromCSVString("liao,5")
|
||||
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
||||
|
||||
assert.Nil(t, query(db, func(rows *sql.Rows) error {
|
||||
return unmarshalRow(value, rows, true)
|
||||
}, "select name, age from users where user=?", "anyone"))
|
||||
assert.Equal(t, "liao", value.Name)
|
||||
assert.Equal(t, 5, value.Age)
|
||||
})
|
||||
}
|
||||
|
||||
func TestUnmarshalRowsBool(t *testing.T) {
|
||||
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||
var expect = []bool{true, false}
|
||||
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("1\n0")
|
||||
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
||||
|
||||
var value []bool
|
||||
assert.Nil(t, query(db, func(rows *sql.Rows) error {
|
||||
return unmarshalRows(&value, rows, true)
|
||||
}, "select value from users where user=?", "anyone"))
|
||||
assert.EqualValues(t, expect, value)
|
||||
})
|
||||
}
|
||||
|
||||
func TestUnmarshalRowsInt(t *testing.T) {
|
||||
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||
var expect = []int{2, 3}
|
||||
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2\n3")
|
||||
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
||||
|
||||
var value []int
|
||||
assert.Nil(t, query(db, func(rows *sql.Rows) error {
|
||||
return unmarshalRows(&value, rows, true)
|
||||
}, "select value from users where user=?", "anyone"))
|
||||
assert.EqualValues(t, expect, value)
|
||||
})
|
||||
}
|
||||
|
||||
func TestUnmarshalRowsInt8(t *testing.T) {
|
||||
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||
var expect = []int8{2, 3}
|
||||
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2\n3")
|
||||
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
||||
|
||||
var value []int8
|
||||
assert.Nil(t, query(db, func(rows *sql.Rows) error {
|
||||
return unmarshalRows(&value, rows, true)
|
||||
}, "select value from users where user=?", "anyone"))
|
||||
assert.EqualValues(t, expect, value)
|
||||
})
|
||||
}
|
||||
|
||||
func TestUnmarshalRowsInt16(t *testing.T) {
|
||||
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||
var expect = []int16{2, 3}
|
||||
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2\n3")
|
||||
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
||||
|
||||
var value []int16
|
||||
assert.Nil(t, query(db, func(rows *sql.Rows) error {
|
||||
return unmarshalRows(&value, rows, true)
|
||||
}, "select value from users where user=?", "anyone"))
|
||||
assert.EqualValues(t, expect, value)
|
||||
})
|
||||
}
|
||||
|
||||
func TestUnmarshalRowsInt32(t *testing.T) {
|
||||
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||
var expect = []int32{2, 3}
|
||||
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2\n3")
|
||||
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
||||
|
||||
var value []int32
|
||||
assert.Nil(t, query(db, func(rows *sql.Rows) error {
|
||||
return unmarshalRows(&value, rows, true)
|
||||
}, "select value from users where user=?", "anyone"))
|
||||
assert.EqualValues(t, expect, value)
|
||||
})
|
||||
}
|
||||
|
||||
func TestUnmarshalRowsInt64(t *testing.T) {
|
||||
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||
var expect = []int64{2, 3}
|
||||
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2\n3")
|
||||
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
||||
|
||||
var value []int64
|
||||
assert.Nil(t, query(db, func(rows *sql.Rows) error {
|
||||
return unmarshalRows(&value, rows, true)
|
||||
}, "select value from users where user=?", "anyone"))
|
||||
assert.EqualValues(t, expect, value)
|
||||
})
|
||||
}
|
||||
|
||||
func TestUnmarshalRowsUint(t *testing.T) {
|
||||
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||
var expect = []uint{2, 3}
|
||||
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2\n3")
|
||||
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
||||
|
||||
var value []uint
|
||||
assert.Nil(t, query(db, func(rows *sql.Rows) error {
|
||||
return unmarshalRows(&value, rows, true)
|
||||
}, "select value from users where user=?", "anyone"))
|
||||
assert.EqualValues(t, expect, value)
|
||||
})
|
||||
}
|
||||
|
||||
func TestUnmarshalRowsUint8(t *testing.T) {
|
||||
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||
var expect = []uint8{2, 3}
|
||||
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2\n3")
|
||||
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
||||
|
||||
var value []uint8
|
||||
assert.Nil(t, query(db, func(rows *sql.Rows) error {
|
||||
return unmarshalRows(&value, rows, true)
|
||||
}, "select value from users where user=?", "anyone"))
|
||||
assert.EqualValues(t, expect, value)
|
||||
})
|
||||
}
|
||||
|
||||
func TestUnmarshalRowsUint16(t *testing.T) {
|
||||
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||
var expect = []uint16{2, 3}
|
||||
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2\n3")
|
||||
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
||||
|
||||
var value []uint16
|
||||
assert.Nil(t, query(db, func(rows *sql.Rows) error {
|
||||
return unmarshalRows(&value, rows, true)
|
||||
}, "select value from users where user=?", "anyone"))
|
||||
assert.EqualValues(t, expect, value)
|
||||
})
|
||||
}
|
||||
|
||||
func TestUnmarshalRowsUint32(t *testing.T) {
|
||||
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||
var expect = []uint32{2, 3}
|
||||
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2\n3")
|
||||
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
||||
|
||||
var value []uint32
|
||||
assert.Nil(t, query(db, func(rows *sql.Rows) error {
|
||||
return unmarshalRows(&value, rows, true)
|
||||
}, "select value from users where user=?", "anyone"))
|
||||
assert.EqualValues(t, expect, value)
|
||||
})
|
||||
}
|
||||
|
||||
func TestUnmarshalRowsUint64(t *testing.T) {
|
||||
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||
var expect = []uint64{2, 3}
|
||||
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2\n3")
|
||||
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
||||
|
||||
var value []uint64
|
||||
assert.Nil(t, query(db, func(rows *sql.Rows) error {
|
||||
return unmarshalRows(&value, rows, true)
|
||||
}, "select value from users where user=?", "anyone"))
|
||||
assert.EqualValues(t, expect, value)
|
||||
})
|
||||
}
|
||||
|
||||
func TestUnmarshalRowsFloat32(t *testing.T) {
|
||||
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||
var expect = []float32{2, 3}
|
||||
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2\n3")
|
||||
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
||||
|
||||
var value []float32
|
||||
assert.Nil(t, query(db, func(rows *sql.Rows) error {
|
||||
return unmarshalRows(&value, rows, true)
|
||||
}, "select value from users where user=?", "anyone"))
|
||||
assert.EqualValues(t, expect, value)
|
||||
})
|
||||
}
|
||||
|
||||
func TestUnmarshalRowsFloat64(t *testing.T) {
|
||||
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||
var expect = []float64{2, 3}
|
||||
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2\n3")
|
||||
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
||||
|
||||
var value []float64
|
||||
assert.Nil(t, query(db, func(rows *sql.Rows) error {
|
||||
return unmarshalRows(&value, rows, true)
|
||||
}, "select value from users where user=?", "anyone"))
|
||||
assert.EqualValues(t, expect, value)
|
||||
})
|
||||
}
|
||||
|
||||
func TestUnmarshalRowsString(t *testing.T) {
|
||||
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||
var expect = []string{"hello", "world"}
|
||||
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("hello\nworld")
|
||||
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
||||
|
||||
var value []string
|
||||
assert.Nil(t, query(db, func(rows *sql.Rows) error {
|
||||
return unmarshalRows(&value, rows, true)
|
||||
}, "select value from users where user=?", "anyone"))
|
||||
assert.EqualValues(t, expect, value)
|
||||
})
|
||||
}
|
||||
|
||||
func TestUnmarshalRowsBoolPtr(t *testing.T) {
|
||||
yes := true
|
||||
no := false
|
||||
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||
var expect = []*bool{&yes, &no}
|
||||
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("1\n0")
|
||||
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
||||
|
||||
var value []*bool
|
||||
assert.Nil(t, query(db, func(rows *sql.Rows) error {
|
||||
return unmarshalRows(&value, rows, true)
|
||||
}, "select value from users where user=?", "anyone"))
|
||||
assert.EqualValues(t, expect, value)
|
||||
})
|
||||
}
|
||||
|
||||
func TestUnmarshalRowsIntPtr(t *testing.T) {
|
||||
two := 2
|
||||
three := 3
|
||||
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||
var expect = []*int{&two, &three}
|
||||
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2\n3")
|
||||
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
||||
|
||||
var value []*int
|
||||
assert.Nil(t, query(db, func(rows *sql.Rows) error {
|
||||
return unmarshalRows(&value, rows, true)
|
||||
}, "select value from users where user=?", "anyone"))
|
||||
assert.EqualValues(t, expect, value)
|
||||
})
|
||||
}
|
||||
|
||||
func TestUnmarshalRowsInt8Ptr(t *testing.T) {
|
||||
two := int8(2)
|
||||
three := int8(3)
|
||||
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||
var expect = []*int8{&two, &three}
|
||||
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2\n3")
|
||||
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
||||
|
||||
var value []*int8
|
||||
assert.Nil(t, query(db, func(rows *sql.Rows) error {
|
||||
return unmarshalRows(&value, rows, true)
|
||||
}, "select value from users where user=?", "anyone"))
|
||||
assert.EqualValues(t, expect, value)
|
||||
})
|
||||
}
|
||||
|
||||
func TestUnmarshalRowsInt16Ptr(t *testing.T) {
|
||||
two := int16(2)
|
||||
three := int16(3)
|
||||
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||
var expect = []*int16{&two, &three}
|
||||
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2\n3")
|
||||
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
||||
|
||||
var value []*int16
|
||||
assert.Nil(t, query(db, func(rows *sql.Rows) error {
|
||||
return unmarshalRows(&value, rows, true)
|
||||
}, "select value from users where user=?", "anyone"))
|
||||
assert.EqualValues(t, expect, value)
|
||||
})
|
||||
}
|
||||
|
||||
func TestUnmarshalRowsInt32Ptr(t *testing.T) {
|
||||
two := int32(2)
|
||||
three := int32(3)
|
||||
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||
var expect = []*int32{&two, &three}
|
||||
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2\n3")
|
||||
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
||||
|
||||
var value []*int32
|
||||
assert.Nil(t, query(db, func(rows *sql.Rows) error {
|
||||
return unmarshalRows(&value, rows, true)
|
||||
}, "select value from users where user=?", "anyone"))
|
||||
assert.EqualValues(t, expect, value)
|
||||
})
|
||||
}
|
||||
|
||||
func TestUnmarshalRowsInt64Ptr(t *testing.T) {
|
||||
two := int64(2)
|
||||
three := int64(3)
|
||||
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||
var expect = []*int64{&two, &three}
|
||||
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2\n3")
|
||||
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
||||
|
||||
var value []*int64
|
||||
assert.Nil(t, query(db, func(rows *sql.Rows) error {
|
||||
return unmarshalRows(&value, rows, true)
|
||||
}, "select value from users where user=?", "anyone"))
|
||||
assert.EqualValues(t, expect, value)
|
||||
})
|
||||
}
|
||||
|
||||
func TestUnmarshalRowsUintPtr(t *testing.T) {
|
||||
two := uint(2)
|
||||
three := uint(3)
|
||||
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||
var expect = []*uint{&two, &three}
|
||||
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2\n3")
|
||||
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
||||
|
||||
var value []*uint
|
||||
assert.Nil(t, query(db, func(rows *sql.Rows) error {
|
||||
return unmarshalRows(&value, rows, true)
|
||||
}, "select value from users where user=?", "anyone"))
|
||||
assert.EqualValues(t, expect, value)
|
||||
})
|
||||
}
|
||||
|
||||
func TestUnmarshalRowsUint8Ptr(t *testing.T) {
|
||||
two := uint8(2)
|
||||
three := uint8(3)
|
||||
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||
var expect = []*uint8{&two, &three}
|
||||
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2\n3")
|
||||
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
||||
|
||||
var value []*uint8
|
||||
assert.Nil(t, query(db, func(rows *sql.Rows) error {
|
||||
return unmarshalRows(&value, rows, true)
|
||||
}, "select value from users where user=?", "anyone"))
|
||||
assert.EqualValues(t, expect, value)
|
||||
})
|
||||
}
|
||||
|
||||
func TestUnmarshalRowsUint16Ptr(t *testing.T) {
|
||||
two := uint16(2)
|
||||
three := uint16(3)
|
||||
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||
var expect = []*uint16{&two, &three}
|
||||
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2\n3")
|
||||
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
||||
|
||||
var value []*uint16
|
||||
assert.Nil(t, query(db, func(rows *sql.Rows) error {
|
||||
return unmarshalRows(&value, rows, true)
|
||||
}, "select value from users where user=?", "anyone"))
|
||||
assert.EqualValues(t, expect, value)
|
||||
})
|
||||
}
|
||||
|
||||
func TestUnmarshalRowsUint32Ptr(t *testing.T) {
|
||||
two := uint32(2)
|
||||
three := uint32(3)
|
||||
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||
var expect = []*uint32{&two, &three}
|
||||
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2\n3")
|
||||
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
||||
|
||||
var value []*uint32
|
||||
assert.Nil(t, query(db, func(rows *sql.Rows) error {
|
||||
return unmarshalRows(&value, rows, true)
|
||||
}, "select value from users where user=?", "anyone"))
|
||||
assert.EqualValues(t, expect, value)
|
||||
})
|
||||
}
|
||||
|
||||
func TestUnmarshalRowsUint64Ptr(t *testing.T) {
|
||||
two := uint64(2)
|
||||
three := uint64(3)
|
||||
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||
var expect = []*uint64{&two, &three}
|
||||
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2\n3")
|
||||
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
||||
|
||||
var value []*uint64
|
||||
assert.Nil(t, query(db, func(rows *sql.Rows) error {
|
||||
return unmarshalRows(&value, rows, true)
|
||||
}, "select value from users where user=?", "anyone"))
|
||||
assert.EqualValues(t, expect, value)
|
||||
})
|
||||
}
|
||||
|
||||
func TestUnmarshalRowsFloat32Ptr(t *testing.T) {
|
||||
two := float32(2)
|
||||
three := float32(3)
|
||||
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||
var expect = []*float32{&two, &three}
|
||||
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2\n3")
|
||||
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
||||
|
||||
var value []*float32
|
||||
assert.Nil(t, query(db, func(rows *sql.Rows) error {
|
||||
return unmarshalRows(&value, rows, true)
|
||||
}, "select value from users where user=?", "anyone"))
|
||||
assert.EqualValues(t, expect, value)
|
||||
})
|
||||
}
|
||||
|
||||
func TestUnmarshalRowsFloat64Ptr(t *testing.T) {
|
||||
two := float64(2)
|
||||
three := float64(3)
|
||||
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||
var expect = []*float64{&two, &three}
|
||||
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2\n3")
|
||||
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
||||
|
||||
var value []*float64
|
||||
assert.Nil(t, query(db, func(rows *sql.Rows) error {
|
||||
return unmarshalRows(&value, rows, true)
|
||||
}, "select value from users where user=?", "anyone"))
|
||||
assert.EqualValues(t, expect, value)
|
||||
})
|
||||
}
|
||||
|
||||
func TestUnmarshalRowsStringPtr(t *testing.T) {
|
||||
hello := "hello"
|
||||
world := "world"
|
||||
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||
var expect = []*string{&hello, &world}
|
||||
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("hello\nworld")
|
||||
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
||||
|
||||
var value []*string
|
||||
assert.Nil(t, query(db, func(rows *sql.Rows) error {
|
||||
return unmarshalRows(&value, rows, true)
|
||||
}, "select value from users where user=?", "anyone"))
|
||||
assert.EqualValues(t, expect, value)
|
||||
})
|
||||
}
|
||||
|
||||
func TestUnmarshalRowsStruct(t *testing.T) {
|
||||
var expect = []struct {
|
||||
Name string
|
||||
Age int64
|
||||
}{
|
||||
{
|
||||
Name: "first",
|
||||
Age: 2,
|
||||
},
|
||||
{
|
||||
Name: "second",
|
||||
Age: 3,
|
||||
},
|
||||
}
|
||||
var value []struct {
|
||||
Name string
|
||||
Age int64
|
||||
}
|
||||
|
||||
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||
rs := sqlmock.NewRows([]string{"name", "age"}).FromCSVString("first,2\nsecond,3")
|
||||
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
||||
assert.Nil(t, query(db, func(rows *sql.Rows) error {
|
||||
return unmarshalRows(&value, rows, true)
|
||||
}, "select name, age from users where user=?", "anyone"))
|
||||
|
||||
for i, each := range expect {
|
||||
assert.Equal(t, each.Name, value[i].Name)
|
||||
assert.Equal(t, each.Age, value[i].Age)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestUnmarshalRowsStructWithNullStringType(t *testing.T) {
|
||||
var expect = []struct {
|
||||
Name string
|
||||
NullString sql.NullString
|
||||
}{
|
||||
{
|
||||
Name: "first",
|
||||
NullString: sql.NullString{
|
||||
String: "firstnullstring",
|
||||
Valid: true,
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "second",
|
||||
NullString: sql.NullString{
|
||||
String: "",
|
||||
Valid: false,
|
||||
},
|
||||
},
|
||||
}
|
||||
var value []struct {
|
||||
Name string `db:"name"`
|
||||
NullString sql.NullString `db:"value"`
|
||||
}
|
||||
|
||||
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||
rs := sqlmock.NewRows([]string{"name", "value"}).AddRow(
|
||||
"first", "firstnullstring").AddRow("second", nil)
|
||||
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
||||
assert.Nil(t, query(db, func(rows *sql.Rows) error {
|
||||
return unmarshalRows(&value, rows, true)
|
||||
}, "select name, age from users where user=?", "anyone"))
|
||||
|
||||
for i, each := range expect {
|
||||
assert.Equal(t, each.Name, value[i].Name)
|
||||
assert.Equal(t, each.NullString.String, value[i].NullString.String)
|
||||
assert.Equal(t, each.NullString.Valid, value[i].NullString.Valid)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestUnmarshalRowsStructWithTags(t *testing.T) {
|
||||
var expect = []struct {
|
||||
Name string
|
||||
Age int64
|
||||
}{
|
||||
{
|
||||
Name: "first",
|
||||
Age: 2,
|
||||
},
|
||||
{
|
||||
Name: "second",
|
||||
Age: 3,
|
||||
},
|
||||
}
|
||||
var value []struct {
|
||||
Age int64 `db:"age"`
|
||||
Name string `db:"name"`
|
||||
}
|
||||
|
||||
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||
rs := sqlmock.NewRows([]string{"name", "age"}).FromCSVString("first,2\nsecond,3")
|
||||
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
||||
assert.Nil(t, query(db, func(rows *sql.Rows) error {
|
||||
return unmarshalRows(&value, rows, true)
|
||||
}, "select name, age from users where user=?", "anyone"))
|
||||
|
||||
for i, each := range expect {
|
||||
assert.Equal(t, each.Name, value[i].Name)
|
||||
assert.Equal(t, each.Age, value[i].Age)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestUnmarshalRowsStructAndEmbeddedAnonymousStructWithTags(t *testing.T) {
|
||||
type Embed struct {
|
||||
Value int64 `db:"value"`
|
||||
}
|
||||
|
||||
var expect = []struct {
|
||||
Name string
|
||||
Age int64
|
||||
Value int64
|
||||
}{
|
||||
{
|
||||
Name: "first",
|
||||
Age: 2,
|
||||
Value: 3,
|
||||
},
|
||||
{
|
||||
Name: "second",
|
||||
Age: 3,
|
||||
Value: 4,
|
||||
},
|
||||
}
|
||||
var value []struct {
|
||||
Name string `db:"name"`
|
||||
Age int64 `db:"age"`
|
||||
Embed
|
||||
}
|
||||
|
||||
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||
rs := sqlmock.NewRows([]string{"name", "age", "value"}).FromCSVString("first,2,3\nsecond,3,4")
|
||||
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
||||
assert.Nil(t, query(db, func(rows *sql.Rows) error {
|
||||
return unmarshalRows(&value, rows, true)
|
||||
}, "select name, age, value from users where user=?", "anyone"))
|
||||
|
||||
for i, each := range expect {
|
||||
assert.Equal(t, each.Name, value[i].Name)
|
||||
assert.Equal(t, each.Age, value[i].Age)
|
||||
assert.Equal(t, each.Value, value[i].Value)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestUnmarshalRowsStructAndEmbeddedStructPtrAnonymousWithTags(t *testing.T) {
|
||||
type Embed struct {
|
||||
Value int64 `db:"value"`
|
||||
}
|
||||
|
||||
var expect = []struct {
|
||||
Name string
|
||||
Age int64
|
||||
Value int64
|
||||
}{
|
||||
{
|
||||
Name: "first",
|
||||
Age: 2,
|
||||
Value: 3,
|
||||
},
|
||||
{
|
||||
Name: "second",
|
||||
Age: 3,
|
||||
Value: 4,
|
||||
},
|
||||
}
|
||||
var value []struct {
|
||||
Name string `db:"name"`
|
||||
Age int64 `db:"age"`
|
||||
*Embed
|
||||
}
|
||||
|
||||
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||
rs := sqlmock.NewRows([]string{"name", "age", "value"}).FromCSVString("first,2,3\nsecond,3,4")
|
||||
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
||||
assert.Nil(t, query(db, func(rows *sql.Rows) error {
|
||||
return unmarshalRows(&value, rows, true)
|
||||
}, "select name, age, value from users where user=?", "anyone"))
|
||||
|
||||
for i, each := range expect {
|
||||
assert.Equal(t, each.Name, value[i].Name)
|
||||
assert.Equal(t, each.Age, value[i].Age)
|
||||
assert.Equal(t, each.Value, value[i].Value)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestUnmarshalRowsStructPtr(t *testing.T) {
|
||||
var expect = []*struct {
|
||||
Name string
|
||||
Age int64
|
||||
}{
|
||||
{
|
||||
Name: "first",
|
||||
Age: 2,
|
||||
},
|
||||
{
|
||||
Name: "second",
|
||||
Age: 3,
|
||||
},
|
||||
}
|
||||
var value []*struct {
|
||||
Name string
|
||||
Age int64
|
||||
}
|
||||
|
||||
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||
rs := sqlmock.NewRows([]string{"name", "age"}).FromCSVString("first,2\nsecond,3")
|
||||
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
||||
assert.Nil(t, query(db, func(rows *sql.Rows) error {
|
||||
return unmarshalRows(&value, rows, true)
|
||||
}, "select name, age from users where user=?", "anyone"))
|
||||
|
||||
for i, each := range expect {
|
||||
assert.Equal(t, each.Name, value[i].Name)
|
||||
assert.Equal(t, each.Age, value[i].Age)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestUnmarshalRowsStructWithTagsPtr(t *testing.T) {
|
||||
var expect = []*struct {
|
||||
Name string
|
||||
Age int64
|
||||
}{
|
||||
{
|
||||
Name: "first",
|
||||
Age: 2,
|
||||
},
|
||||
{
|
||||
Name: "second",
|
||||
Age: 3,
|
||||
},
|
||||
}
|
||||
var value []*struct {
|
||||
Age int64 `db:"age"`
|
||||
Name string `db:"name"`
|
||||
}
|
||||
|
||||
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||
rs := sqlmock.NewRows([]string{"name", "age"}).FromCSVString("first,2\nsecond,3")
|
||||
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
||||
assert.Nil(t, query(db, func(rows *sql.Rows) error {
|
||||
return unmarshalRows(&value, rows, true)
|
||||
}, "select name, age from users where user=?", "anyone"))
|
||||
|
||||
for i, each := range expect {
|
||||
assert.Equal(t, each.Name, value[i].Name)
|
||||
assert.Equal(t, each.Age, value[i].Age)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestUnmarshalRowsStructWithTagsPtrWithInnerPtr(t *testing.T) {
|
||||
var expect = []*struct {
|
||||
Name string
|
||||
Age int64
|
||||
}{
|
||||
{
|
||||
Name: "first",
|
||||
Age: 2,
|
||||
},
|
||||
{
|
||||
Name: "second",
|
||||
Age: 3,
|
||||
},
|
||||
}
|
||||
var value []*struct {
|
||||
Age *int64 `db:"age"`
|
||||
Name string `db:"name"`
|
||||
}
|
||||
|
||||
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||
rs := sqlmock.NewRows([]string{"name", "age"}).FromCSVString("first,2\nsecond,3")
|
||||
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
||||
assert.Nil(t, query(db, func(rows *sql.Rows) error {
|
||||
return unmarshalRows(&value, rows, true)
|
||||
}, "select name, age from users where user=?", "anyone"))
|
||||
|
||||
for i, each := range expect {
|
||||
assert.Equal(t, each.Name, value[i].Name)
|
||||
assert.Equal(t, each.Age, *value[i].Age)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestCommonSqlConn_QueryRowOptional(t *testing.T) {
|
||||
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||
rs := sqlmock.NewRows([]string{"age"}).FromCSVString("5")
|
||||
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
||||
|
||||
var r struct {
|
||||
User string `db:"user"`
|
||||
Age int `db:"age"`
|
||||
}
|
||||
assert.Nil(t, query(db, func(rows *sql.Rows) error {
|
||||
return unmarshalRow(&r, rows, false)
|
||||
}, "select age from users where user=?", "anyone"))
|
||||
assert.Empty(t, r.User)
|
||||
assert.Equal(t, 5, r.Age)
|
||||
})
|
||||
}
|
||||
|
||||
func runOrmTest(t *testing.T, fn func(db *sql.DB, mock sqlmock.Sqlmock)) {
|
||||
logx.Disable()
|
||||
|
||||
db, mock, err := sqlmock.New()
|
||||
if err != nil {
|
||||
t.Fatalf("an error '%s' was not expected when opening a stub database connection", err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
fn(db, mock)
|
||||
|
||||
if err := mock.ExpectationsWereMet(); err != nil {
|
||||
t.Errorf("there were unfulfilled expectations: %s", err)
|
||||
}
|
||||
}
|
||||
204
core/stores/sqlx/sqlconn.go
Normal file
204
core/stores/sqlx/sqlconn.go
Normal file
@@ -0,0 +1,204 @@
|
||||
package sqlx
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
|
||||
"zero/core/breaker"
|
||||
)
|
||||
|
||||
var ErrNotFound = sql.ErrNoRows
|
||||
|
||||
type (
|
||||
// Session stands for raw connections or transaction sessions
|
||||
Session interface {
|
||||
Exec(query string, args ...interface{}) (sql.Result, error)
|
||||
Prepare(query string) (StmtSession, error)
|
||||
QueryRow(v interface{}, query string, args ...interface{}) error
|
||||
QueryRowPartial(v interface{}, query string, args ...interface{}) error
|
||||
QueryRows(v interface{}, query string, args ...interface{}) error
|
||||
QueryRowsPartial(v interface{}, query string, args ...interface{}) error
|
||||
}
|
||||
|
||||
// SqlConn only stands for raw connections, so Transact method can be called.
|
||||
SqlConn interface {
|
||||
Session
|
||||
Transact(func(session Session) error) error
|
||||
}
|
||||
|
||||
SqlOption func(*commonSqlConn)
|
||||
|
||||
StmtSession interface {
|
||||
Close() error
|
||||
Exec(args ...interface{}) (sql.Result, error)
|
||||
QueryRow(v interface{}, args ...interface{}) error
|
||||
QueryRowPartial(v interface{}, args ...interface{}) error
|
||||
QueryRows(v interface{}, args ...interface{}) error
|
||||
QueryRowsPartial(v interface{}, args ...interface{}) error
|
||||
}
|
||||
|
||||
// thread-safe
|
||||
// Because CORBA doesn't support PREPARE, so we need to combine the
|
||||
// query arguments into one string and do underlying query without arguments
|
||||
commonSqlConn struct {
|
||||
driverName string
|
||||
datasource string
|
||||
beginTx beginnable
|
||||
brk breaker.Breaker
|
||||
accept func(error) bool
|
||||
}
|
||||
|
||||
sessionConn interface {
|
||||
Exec(query string, args ...interface{}) (sql.Result, error)
|
||||
Query(query string, args ...interface{}) (*sql.Rows, error)
|
||||
}
|
||||
|
||||
statement struct {
|
||||
stmt *sql.Stmt
|
||||
}
|
||||
|
||||
stmtConn interface {
|
||||
Exec(args ...interface{}) (sql.Result, error)
|
||||
Query(args ...interface{}) (*sql.Rows, error)
|
||||
}
|
||||
)
|
||||
|
||||
func NewSqlConn(driverName, datasource string, opts ...SqlOption) SqlConn {
|
||||
conn := &commonSqlConn{
|
||||
driverName: driverName,
|
||||
datasource: datasource,
|
||||
beginTx: begin,
|
||||
brk: breaker.NewBreaker(),
|
||||
}
|
||||
for _, opt := range opts {
|
||||
opt(conn)
|
||||
}
|
||||
|
||||
return conn
|
||||
}
|
||||
|
||||
func (db *commonSqlConn) Exec(q string, args ...interface{}) (result sql.Result, err error) {
|
||||
err = db.brk.DoWithAcceptable(func() error {
|
||||
var conn *sql.DB
|
||||
conn, err = getSqlConn(db.driverName, db.datasource)
|
||||
if err != nil {
|
||||
logInstanceError(db.datasource, err)
|
||||
return err
|
||||
}
|
||||
|
||||
result, err = exec(conn, q, args...)
|
||||
return err
|
||||
}, db.acceptable)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func (db *commonSqlConn) Prepare(query string) (stmt StmtSession, err error) {
|
||||
err = db.brk.DoWithAcceptable(func() error {
|
||||
var conn *sql.DB
|
||||
conn, err = getSqlConn(db.driverName, db.datasource)
|
||||
if err != nil {
|
||||
logInstanceError(db.datasource, err)
|
||||
return err
|
||||
}
|
||||
|
||||
if st, err := conn.Prepare(query); err != nil {
|
||||
return err
|
||||
} else {
|
||||
stmt = statement{
|
||||
stmt: st,
|
||||
}
|
||||
return nil
|
||||
}
|
||||
}, db.acceptable)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func (db *commonSqlConn) QueryRow(v interface{}, q string, args ...interface{}) error {
|
||||
return db.queryRows(func(rows *sql.Rows) error {
|
||||
return unmarshalRow(v, rows, true)
|
||||
}, q, args...)
|
||||
}
|
||||
|
||||
func (db *commonSqlConn) QueryRowPartial(v interface{}, q string, args ...interface{}) error {
|
||||
return db.queryRows(func(rows *sql.Rows) error {
|
||||
return unmarshalRow(v, rows, false)
|
||||
}, q, args...)
|
||||
}
|
||||
|
||||
func (db *commonSqlConn) QueryRows(v interface{}, q string, args ...interface{}) error {
|
||||
return db.queryRows(func(rows *sql.Rows) error {
|
||||
return unmarshalRows(v, rows, true)
|
||||
}, q, args...)
|
||||
}
|
||||
|
||||
func (db *commonSqlConn) QueryRowsPartial(v interface{}, q string, args ...interface{}) error {
|
||||
return db.queryRows(func(rows *sql.Rows) error {
|
||||
return unmarshalRows(v, rows, false)
|
||||
}, q, args...)
|
||||
}
|
||||
|
||||
func (db *commonSqlConn) Transact(fn func(Session) error) error {
|
||||
return db.brk.DoWithAcceptable(func() error {
|
||||
return transact(db, db.beginTx, fn)
|
||||
}, db.acceptable)
|
||||
}
|
||||
|
||||
func (db *commonSqlConn) acceptable(err error) bool {
|
||||
ok := err == nil || err == sql.ErrNoRows || err == sql.ErrTxDone
|
||||
if db.accept == nil {
|
||||
return ok
|
||||
} else {
|
||||
return ok || db.accept(err)
|
||||
}
|
||||
}
|
||||
|
||||
func (db *commonSqlConn) queryRows(scanner func(*sql.Rows) error, q string, args ...interface{}) error {
|
||||
var qerr error
|
||||
return db.brk.DoWithAcceptable(func() error {
|
||||
conn, err := getSqlConn(db.driverName, db.datasource)
|
||||
if err != nil {
|
||||
logInstanceError(db.datasource, err)
|
||||
return err
|
||||
}
|
||||
|
||||
return query(conn, func(rows *sql.Rows) error {
|
||||
qerr = scanner(rows)
|
||||
return qerr
|
||||
}, q, args...)
|
||||
}, func(err error) bool {
|
||||
return qerr == err || db.acceptable(err)
|
||||
})
|
||||
}
|
||||
|
||||
func (s statement) Close() error {
|
||||
return s.stmt.Close()
|
||||
}
|
||||
|
||||
func (s statement) Exec(args ...interface{}) (sql.Result, error) {
|
||||
return execStmt(s.stmt, args...)
|
||||
}
|
||||
|
||||
func (s statement) QueryRow(v interface{}, args ...interface{}) error {
|
||||
return queryStmt(s.stmt, func(rows *sql.Rows) error {
|
||||
return unmarshalRow(v, rows, true)
|
||||
}, args...)
|
||||
}
|
||||
|
||||
func (s statement) QueryRowPartial(v interface{}, args ...interface{}) error {
|
||||
return queryStmt(s.stmt, func(rows *sql.Rows) error {
|
||||
return unmarshalRow(v, rows, false)
|
||||
}, args...)
|
||||
}
|
||||
|
||||
func (s statement) QueryRows(v interface{}, args ...interface{}) error {
|
||||
return queryStmt(s.stmt, func(rows *sql.Rows) error {
|
||||
return unmarshalRows(v, rows, true)
|
||||
}, args...)
|
||||
}
|
||||
|
||||
func (s statement) QueryRowsPartial(v interface{}, args ...interface{}) error {
|
||||
return queryStmt(s.stmt, func(rows *sql.Rows) error {
|
||||
return unmarshalRows(v, rows, false)
|
||||
}, args...)
|
||||
}
|
||||
74
core/stores/sqlx/sqlmanager.go
Normal file
74
core/stores/sqlx/sqlmanager.go
Normal file
@@ -0,0 +1,74 @@
|
||||
package sqlx
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"io"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"zero/core/syncx"
|
||||
)
|
||||
|
||||
const (
|
||||
maxIdleConns = 64
|
||||
maxOpenConns = 64
|
||||
maxLifetime = time.Minute
|
||||
)
|
||||
|
||||
var connManager = syncx.NewResourceManager()
|
||||
|
||||
type pingedDB struct {
|
||||
*sql.DB
|
||||
once sync.Once
|
||||
}
|
||||
|
||||
func getCachedSqlConn(driverName, server string) (*pingedDB, error) {
|
||||
val, err := connManager.GetResource(server, func() (io.Closer, error) {
|
||||
conn, err := newDBConnection(driverName, server)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &pingedDB{
|
||||
DB: conn,
|
||||
}, nil
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return val.(*pingedDB), nil
|
||||
}
|
||||
|
||||
func getSqlConn(driverName, server string) (*sql.DB, error) {
|
||||
pdb, err := getCachedSqlConn(driverName, server)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
pdb.once.Do(func() {
|
||||
err = pdb.Ping()
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return pdb.DB, nil
|
||||
}
|
||||
|
||||
func newDBConnection(driverName, datasource string) (*sql.DB, error) {
|
||||
conn, err := sql.Open(driverName, datasource)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// we need to do this until the issue https://github.com/golang/go/issues/9851 get fixed
|
||||
// discussed here https://github.com/go-sql-driver/mysql/issues/257
|
||||
// if the discussed SetMaxIdleTimeout methods added, we'll change this behavior
|
||||
// 8 means we can't have more than 8 goroutines to concurrently access the same database.
|
||||
conn.SetMaxIdleConns(maxIdleConns)
|
||||
conn.SetMaxOpenConns(maxOpenConns)
|
||||
conn.SetConnMaxLifetime(maxLifetime)
|
||||
|
||||
return conn, nil
|
||||
}
|
||||
92
core/stores/sqlx/stmt.go
Normal file
92
core/stores/sqlx/stmt.go
Normal file
@@ -0,0 +1,92 @@
|
||||
package sqlx
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"zero/core/logx"
|
||||
"zero/core/timex"
|
||||
)
|
||||
|
||||
const slowThreshold = time.Millisecond * 500
|
||||
|
||||
func exec(conn sessionConn, q string, args ...interface{}) (sql.Result, error) {
|
||||
stmt, err := format(q, args...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
startTime := timex.Now()
|
||||
result, err := conn.Exec(q, args...)
|
||||
duration := timex.Since(startTime)
|
||||
if duration > slowThreshold {
|
||||
logx.WithDuration(duration).Slowf("[SQL] exec: slowcall - %s", stmt)
|
||||
} else {
|
||||
logx.WithDuration(duration).Infof("sql exec: %s", stmt)
|
||||
}
|
||||
if err != nil {
|
||||
logSqlError(stmt, err)
|
||||
}
|
||||
|
||||
return result, err
|
||||
}
|
||||
|
||||
func execStmt(conn stmtConn, args ...interface{}) (sql.Result, error) {
|
||||
stmt := fmt.Sprint(args...)
|
||||
startTime := timex.Now()
|
||||
result, err := conn.Exec(args...)
|
||||
duration := timex.Since(startTime)
|
||||
if duration > slowThreshold {
|
||||
logx.WithDuration(duration).Slowf("[SQL] execStmt: slowcall - %s", stmt)
|
||||
} else {
|
||||
logx.WithDuration(duration).Infof("sql execStmt: %s", stmt)
|
||||
}
|
||||
if err != nil {
|
||||
logSqlError(stmt, err)
|
||||
}
|
||||
|
||||
return result, err
|
||||
}
|
||||
|
||||
func query(conn sessionConn, scanner func(*sql.Rows) error, q string, args ...interface{}) error {
|
||||
stmt, err := format(q, args...)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
startTime := timex.Now()
|
||||
rows, err := conn.Query(q, args...)
|
||||
duration := timex.Since(startTime)
|
||||
if duration > slowThreshold {
|
||||
logx.WithDuration(duration).Slowf("[SQL] query: slowcall - %s", stmt)
|
||||
} else {
|
||||
logx.WithDuration(duration).Infof("sql query: %s", stmt)
|
||||
}
|
||||
if err != nil {
|
||||
logSqlError(stmt, err)
|
||||
return err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
return scanner(rows)
|
||||
}
|
||||
|
||||
func queryStmt(conn stmtConn, scanner func(*sql.Rows) error, args ...interface{}) error {
|
||||
stmt := fmt.Sprint(args...)
|
||||
startTime := timex.Now()
|
||||
rows, err := conn.Query(args...)
|
||||
duration := timex.Since(startTime)
|
||||
if duration > slowThreshold {
|
||||
logx.WithDuration(duration).Slowf("[SQL] queryStmt: slowcall - %s", stmt)
|
||||
} else {
|
||||
logx.WithDuration(duration).Infof("sql queryStmt: %s", stmt)
|
||||
}
|
||||
if err != nil {
|
||||
logSqlError(stmt, err)
|
||||
return err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
return scanner(rows)
|
||||
}
|
||||
103
core/stores/sqlx/tx.go
Normal file
103
core/stores/sqlx/tx.go
Normal file
@@ -0,0 +1,103 @@
|
||||
package sqlx
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
type (
|
||||
beginnable func(*sql.DB) (trans, error)
|
||||
|
||||
trans interface {
|
||||
Session
|
||||
Commit() error
|
||||
Rollback() error
|
||||
}
|
||||
|
||||
txSession struct {
|
||||
*sql.Tx
|
||||
}
|
||||
)
|
||||
|
||||
func (t txSession) Exec(q string, args ...interface{}) (sql.Result, error) {
|
||||
return exec(t.Tx, q, args...)
|
||||
}
|
||||
|
||||
func (t txSession) Prepare(q string) (StmtSession, error) {
|
||||
if stmt, err := t.Tx.Prepare(q); err != nil {
|
||||
return nil, err
|
||||
} else {
|
||||
return statement{
|
||||
stmt: stmt,
|
||||
}, nil
|
||||
}
|
||||
}
|
||||
|
||||
func (t txSession) QueryRow(v interface{}, q string, args ...interface{}) error {
|
||||
return query(t.Tx, func(rows *sql.Rows) error {
|
||||
return unmarshalRow(v, rows, true)
|
||||
}, q, args...)
|
||||
}
|
||||
|
||||
func (t txSession) QueryRowPartial(v interface{}, q string, args ...interface{}) error {
|
||||
return query(t.Tx, func(rows *sql.Rows) error {
|
||||
return unmarshalRow(v, rows, false)
|
||||
}, q, args...)
|
||||
}
|
||||
|
||||
func (t txSession) QueryRows(v interface{}, q string, args ...interface{}) error {
|
||||
return query(t.Tx, func(rows *sql.Rows) error {
|
||||
return unmarshalRows(v, rows, true)
|
||||
}, q, args...)
|
||||
}
|
||||
|
||||
func (t txSession) QueryRowsPartial(v interface{}, q string, args ...interface{}) error {
|
||||
return query(t.Tx, func(rows *sql.Rows) error {
|
||||
return unmarshalRows(v, rows, false)
|
||||
}, q, args...)
|
||||
}
|
||||
|
||||
func begin(db *sql.DB) (trans, error) {
|
||||
if tx, err := db.Begin(); err != nil {
|
||||
return nil, err
|
||||
} else {
|
||||
return txSession{
|
||||
Tx: tx,
|
||||
}, nil
|
||||
}
|
||||
}
|
||||
|
||||
func transact(db *commonSqlConn, b beginnable, fn func(Session) error) (err error) {
|
||||
conn, err := getSqlConn(db.driverName, db.datasource)
|
||||
if err != nil {
|
||||
logInstanceError(db.datasource, err)
|
||||
return err
|
||||
}
|
||||
|
||||
return transactOnConn(conn, b, fn)
|
||||
}
|
||||
|
||||
func transactOnConn(conn *sql.DB, b beginnable, fn func(Session) error) (err error) {
|
||||
var tx trans
|
||||
tx, err = b(conn)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
defer func() {
|
||||
if p := recover(); p != nil {
|
||||
if e := tx.Rollback(); e != nil {
|
||||
err = fmt.Errorf("recover from %#v, rollback failed: %s", p, e)
|
||||
} else {
|
||||
err = fmt.Errorf("recoveer from %#v", p)
|
||||
}
|
||||
} else if err != nil {
|
||||
if e := tx.Rollback(); e != nil {
|
||||
err = fmt.Errorf("transaction failed: %s, rollback failed: %s", err, e)
|
||||
}
|
||||
} else {
|
||||
err = tx.Commit()
|
||||
}
|
||||
}()
|
||||
|
||||
return fn(tx)
|
||||
}
|
||||
76
core/stores/sqlx/tx_test.go
Normal file
76
core/stores/sqlx/tx_test.go
Normal file
@@ -0,0 +1,76 @@
|
||||
package sqlx
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"errors"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
const (
|
||||
mockCommit = 1
|
||||
mockRollback = 2
|
||||
)
|
||||
|
||||
type mockTx struct {
|
||||
status int
|
||||
}
|
||||
|
||||
func (mt *mockTx) Commit() error {
|
||||
mt.status |= mockCommit
|
||||
return nil
|
||||
}
|
||||
|
||||
func (mt *mockTx) Exec(q string, args ...interface{}) (sql.Result, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (mt *mockTx) Prepare(query string) (StmtSession, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (mt *mockTx) QueryRow(v interface{}, q string, args ...interface{}) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (mt *mockTx) QueryRowPartial(v interface{}, q string, args ...interface{}) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (mt *mockTx) QueryRows(v interface{}, q string, args ...interface{}) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (mt *mockTx) QueryRowsPartial(v interface{}, q string, args ...interface{}) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (mt *mockTx) Rollback() error {
|
||||
mt.status |= mockRollback
|
||||
return nil
|
||||
}
|
||||
|
||||
func beginMock(mock *mockTx) beginnable {
|
||||
return func(*sql.DB) (trans, error) {
|
||||
return mock, nil
|
||||
}
|
||||
}
|
||||
|
||||
func TestTransactCommit(t *testing.T) {
|
||||
mock := &mockTx{}
|
||||
err := transactOnConn(nil, beginMock(mock), func(Session) error {
|
||||
return nil
|
||||
})
|
||||
assert.Equal(t, mockCommit, mock.status)
|
||||
assert.Nil(t, err)
|
||||
}
|
||||
|
||||
func TestTransactRollback(t *testing.T) {
|
||||
mock := &mockTx{}
|
||||
err := transactOnConn(nil, beginMock(mock), func(Session) error {
|
||||
return errors.New("rollback")
|
||||
})
|
||||
assert.Equal(t, mockRollback, mock.status)
|
||||
assert.NotNil(t, err)
|
||||
}
|
||||
101
core/stores/sqlx/utils.go
Normal file
101
core/stores/sqlx/utils.go
Normal file
@@ -0,0 +1,101 @@
|
||||
package sqlx
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"zero/core/logx"
|
||||
"zero/core/mapping"
|
||||
)
|
||||
|
||||
func desensitize(datasource string) string {
|
||||
// remove account
|
||||
pos := strings.LastIndex(datasource, "@")
|
||||
if 0 <= pos && pos+1 < len(datasource) {
|
||||
datasource = datasource[pos+1:]
|
||||
}
|
||||
|
||||
return datasource
|
||||
}
|
||||
|
||||
func escape(input string) string {
|
||||
var b strings.Builder
|
||||
|
||||
for _, ch := range input {
|
||||
switch ch {
|
||||
case '\x00':
|
||||
b.WriteString(`\x00`)
|
||||
case '\r':
|
||||
b.WriteString(`\r`)
|
||||
case '\n':
|
||||
b.WriteString(`\n`)
|
||||
case '\\':
|
||||
b.WriteString(`\\`)
|
||||
case '\'':
|
||||
b.WriteString(`\'`)
|
||||
case '"':
|
||||
b.WriteString(`\"`)
|
||||
case '\x1a':
|
||||
b.WriteString(`\x1a`)
|
||||
default:
|
||||
b.WriteRune(ch)
|
||||
}
|
||||
}
|
||||
|
||||
return b.String()
|
||||
}
|
||||
|
||||
func format(query string, args ...interface{}) (string, error) {
|
||||
numArgs := len(args)
|
||||
if numArgs == 0 {
|
||||
return query, nil
|
||||
}
|
||||
|
||||
var b strings.Builder
|
||||
argIndex := 0
|
||||
|
||||
for _, ch := range query {
|
||||
if ch == '?' {
|
||||
if argIndex >= numArgs {
|
||||
return "", fmt.Errorf("error: %d ? in sql, but less arguments provided", argIndex)
|
||||
}
|
||||
|
||||
arg := args[argIndex]
|
||||
argIndex++
|
||||
|
||||
switch v := arg.(type) {
|
||||
case bool:
|
||||
if v {
|
||||
b.WriteByte('1')
|
||||
} else {
|
||||
b.WriteByte('0')
|
||||
}
|
||||
case string:
|
||||
b.WriteByte('\'')
|
||||
b.WriteString(escape(v))
|
||||
b.WriteByte('\'')
|
||||
default:
|
||||
b.WriteString(mapping.Repr(v))
|
||||
}
|
||||
} else {
|
||||
b.WriteRune(ch)
|
||||
}
|
||||
}
|
||||
|
||||
if argIndex < numArgs {
|
||||
return "", fmt.Errorf("error: %d ? in sql, but more arguments provided", argIndex)
|
||||
}
|
||||
|
||||
return b.String(), nil
|
||||
}
|
||||
|
||||
func logInstanceError(datasource string, err error) {
|
||||
datasource = desensitize(datasource)
|
||||
logx.Errorf("Error on getting sql instance of %s: %v", datasource, err)
|
||||
}
|
||||
|
||||
func logSqlError(stmt string, err error) {
|
||||
if err != nil && err != ErrNotFound {
|
||||
logx.Errorf("stmt: %s, error: %s", stmt, err.Error())
|
||||
}
|
||||
}
|
||||
30
core/stores/sqlx/utils_test.go
Normal file
30
core/stores/sqlx/utils_test.go
Normal file
@@ -0,0 +1,30 @@
|
||||
package sqlx
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestEscape(t *testing.T) {
|
||||
s := "a\x00\n\r\\'\"\x1ab"
|
||||
|
||||
out := escape(s)
|
||||
|
||||
assert.Equal(t, `a\x00\n\r\\\'\"\x1ab`, out)
|
||||
}
|
||||
|
||||
func TestDesensitize(t *testing.T) {
|
||||
datasource := "user:pass@tcp(111.222.333.44:3306)/any_table?charset=utf8mb4&parseTime=true&loc=Asia%2FShanghai"
|
||||
datasource = desensitize(datasource)
|
||||
assert.False(t, strings.Contains(datasource, "user"))
|
||||
assert.False(t, strings.Contains(datasource, "pass"))
|
||||
assert.True(t, strings.Contains(datasource, "tcp(111.222.333.44:3306)"))
|
||||
}
|
||||
|
||||
func TestDesensitize_WithoutAccount(t *testing.T) {
|
||||
datasource := "tcp(111.222.333.44:3306)/any_table?charset=utf8mb4&parseTime=true&loc=Asia%2FShanghai"
|
||||
datasource = desensitize(datasource)
|
||||
assert.True(t, strings.Contains(datasource, "tcp(111.222.333.44:3306)"))
|
||||
}
|
||||
Reference in New Issue
Block a user