add more tests
This commit is contained in:
@@ -2,6 +2,7 @@ package sqlx
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"errors"
|
||||
"testing"
|
||||
|
||||
"github.com/DATA-DOG/go-sqlmock"
|
||||
@@ -256,24 +257,6 @@ func TestUnmarshalRowStructWithTagsWrongColumns(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
func TestUnmarshalRowStructWithTagsPtr(t *testing.T) {
|
||||
var value = new(struct {
|
||||
Age *int `db:"age"`
|
||||
Name string `db:"name"`
|
||||
})
|
||||
|
||||
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||
rs := sqlmock.NewRows([]string{"name", "age"}).FromCSVString("liao,5")
|
||||
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
||||
|
||||
assert.Nil(t, query(db, func(rows *sql.Rows) error {
|
||||
return unmarshalRow(value, rows, true)
|
||||
}, "select name, age from users where user=?", "anyone"))
|
||||
assert.Equal(t, "liao", value.Name)
|
||||
assert.Equal(t, 5, *value.Age)
|
||||
})
|
||||
}
|
||||
|
||||
func TestUnmarshalRowsBool(t *testing.T) {
|
||||
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||
var expect = []bool{true, false}
|
||||
@@ -1001,6 +984,62 @@ func TestCommonSqlConn_QueryRowOptional(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
func TestUnmarshalRowError(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
colErr error
|
||||
scanErr error
|
||||
err error
|
||||
next int
|
||||
validate func(err error)
|
||||
}{
|
||||
{
|
||||
name: "with error",
|
||||
err: errors.New("foo"),
|
||||
validate: func(err error) {
|
||||
assert.NotNil(t, err)
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "without next",
|
||||
validate: func(err error) {
|
||||
assert.Equal(t, ErrNotFound, err)
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "with error",
|
||||
scanErr: errors.New("foo"),
|
||||
next: 1,
|
||||
validate: func(err error) {
|
||||
assert.Equal(t, ErrNotFound, err)
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||
rs := sqlmock.NewRows([]string{"age"}).FromCSVString("5")
|
||||
mock.ExpectQuery("select (.+) from users where user=?").WithArgs(
|
||||
"anyone").WillReturnRows(rs)
|
||||
|
||||
var r struct {
|
||||
User string `db:"user"`
|
||||
Age int `db:"age"`
|
||||
}
|
||||
test.validate(query(db, func(rows *sql.Rows) error {
|
||||
scanner := mockedScanner{
|
||||
colErr: test.colErr,
|
||||
scanErr: test.scanErr,
|
||||
err: test.err,
|
||||
}
|
||||
return unmarshalRow(&r, &scanner, false)
|
||||
}, "select age from users where user=?", "anyone"))
|
||||
})
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func runOrmTest(t *testing.T, fn func(db *sql.DB, mock sqlmock.Sqlmock)) {
|
||||
logx.Disable()
|
||||
|
||||
@@ -1016,3 +1055,30 @@ func runOrmTest(t *testing.T, fn func(db *sql.DB, mock sqlmock.Sqlmock)) {
|
||||
t.Errorf("there were unfulfilled expectations: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
type mockedScanner struct {
|
||||
colErr error
|
||||
scanErr error
|
||||
err error
|
||||
next int
|
||||
}
|
||||
|
||||
func (m *mockedScanner) Columns() ([]string, error) {
|
||||
return nil, m.colErr
|
||||
}
|
||||
|
||||
func (m *mockedScanner) Err() error {
|
||||
return m.err
|
||||
}
|
||||
|
||||
func (m *mockedScanner) Next() bool {
|
||||
if m.next > 0 {
|
||||
m.next--
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (m *mockedScanner) Scan(v ...interface{}) error {
|
||||
return m.scanErr
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user