diff --git a/core/stores/sqlx/orm_test.go b/core/stores/sqlx/orm_test.go index 74b82b38..91f05648 100644 --- a/core/stores/sqlx/orm_test.go +++ b/core/stores/sqlx/orm_test.go @@ -2,6 +2,7 @@ package sqlx import ( "database/sql" + "errors" "testing" "github.com/DATA-DOG/go-sqlmock" @@ -256,24 +257,6 @@ func TestUnmarshalRowStructWithTagsWrongColumns(t *testing.T) { }) } -func TestUnmarshalRowStructWithTagsPtr(t *testing.T) { - var value = new(struct { - Age *int `db:"age"` - Name string `db:"name"` - }) - - runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) { - rs := sqlmock.NewRows([]string{"name", "age"}).FromCSVString("liao,5") - mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs) - - assert.Nil(t, query(db, func(rows *sql.Rows) error { - return unmarshalRow(value, rows, true) - }, "select name, age from users where user=?", "anyone")) - assert.Equal(t, "liao", value.Name) - assert.Equal(t, 5, *value.Age) - }) -} - func TestUnmarshalRowsBool(t *testing.T) { runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) { var expect = []bool{true, false} @@ -1001,6 +984,62 @@ func TestCommonSqlConn_QueryRowOptional(t *testing.T) { }) } +func TestUnmarshalRowError(t *testing.T) { + tests := []struct { + name string + colErr error + scanErr error + err error + next int + validate func(err error) + }{ + { + name: "with error", + err: errors.New("foo"), + validate: func(err error) { + assert.NotNil(t, err) + }, + }, + { + name: "without next", + validate: func(err error) { + assert.Equal(t, ErrNotFound, err) + }, + }, + { + name: "with error", + scanErr: errors.New("foo"), + next: 1, + validate: func(err error) { + assert.Equal(t, ErrNotFound, err) + }, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) { + rs := sqlmock.NewRows([]string{"age"}).FromCSVString("5") + mock.ExpectQuery("select (.+) from users where user=?").WithArgs( + "anyone").WillReturnRows(rs) + + var r struct { + User string `db:"user"` + Age int `db:"age"` + } + test.validate(query(db, func(rows *sql.Rows) error { + scanner := mockedScanner{ + colErr: test.colErr, + scanErr: test.scanErr, + err: test.err, + } + return unmarshalRow(&r, &scanner, false) + }, "select age from users where user=?", "anyone")) + }) + }) + } +} + func runOrmTest(t *testing.T, fn func(db *sql.DB, mock sqlmock.Sqlmock)) { logx.Disable() @@ -1016,3 +1055,30 @@ func runOrmTest(t *testing.T, fn func(db *sql.DB, mock sqlmock.Sqlmock)) { t.Errorf("there were unfulfilled expectations: %s", err) } } + +type mockedScanner struct { + colErr error + scanErr error + err error + next int +} + +func (m *mockedScanner) Columns() ([]string, error) { + return nil, m.colErr +} + +func (m *mockedScanner) Err() error { + return m.err +} + +func (m *mockedScanner) Next() bool { + if m.next > 0 { + m.next-- + return true + } + return false +} + +func (m *mockedScanner) Scan(v ...interface{}) error { + return m.scanErr +}