chore: add more tests (#3288)
This commit is contained in:
@@ -54,27 +54,39 @@ func getTaggedFieldValueMap(v reflect.Value) (map[string]any, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
valueField := reflect.Indirect(v).Field(i)
|
valueField := reflect.Indirect(v).Field(i)
|
||||||
switch valueField.Kind() {
|
valueData, err := getValueInterface(valueField)
|
||||||
case reflect.Ptr:
|
if err != nil {
|
||||||
if !valueField.CanInterface() {
|
return nil, err
|
||||||
return nil, ErrNotReadableValue
|
|
||||||
}
|
|
||||||
if valueField.IsNil() {
|
|
||||||
baseValueType := mapping.Deref(valueField.Type())
|
|
||||||
valueField.Set(reflect.New(baseValueType))
|
|
||||||
}
|
|
||||||
result[key] = valueField.Interface()
|
|
||||||
default:
|
|
||||||
if !valueField.CanAddr() || !valueField.Addr().CanInterface() {
|
|
||||||
return nil, ErrNotReadableValue
|
|
||||||
}
|
|
||||||
result[key] = valueField.Addr().Interface()
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
result[key] = valueData
|
||||||
}
|
}
|
||||||
|
|
||||||
return result, nil
|
return result, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func getValueInterface(value reflect.Value) (any, error) {
|
||||||
|
switch value.Kind() {
|
||||||
|
case reflect.Ptr:
|
||||||
|
if !value.CanInterface() {
|
||||||
|
return nil, ErrNotReadableValue
|
||||||
|
}
|
||||||
|
|
||||||
|
if value.IsNil() {
|
||||||
|
baseValueType := mapping.Deref(value.Type())
|
||||||
|
value.Set(reflect.New(baseValueType))
|
||||||
|
}
|
||||||
|
|
||||||
|
return value.Interface(), nil
|
||||||
|
default:
|
||||||
|
if !value.CanAddr() || !value.Addr().CanInterface() {
|
||||||
|
return nil, ErrNotReadableValue
|
||||||
|
}
|
||||||
|
|
||||||
|
return value.Addr().Interface(), nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func mapStructFieldsIntoSlice(v reflect.Value, columns []string, strict bool) ([]any, error) {
|
func mapStructFieldsIntoSlice(v reflect.Value, columns []string, strict bool) ([]any, error) {
|
||||||
fields := unwrapFields(v)
|
fields := unwrapFields(v)
|
||||||
if strict && len(columns) < len(fields) {
|
if strict && len(columns) < len(fields) {
|
||||||
@@ -88,24 +100,18 @@ func mapStructFieldsIntoSlice(v reflect.Value, columns []string, strict bool) ([
|
|||||||
|
|
||||||
values := make([]any, len(columns))
|
values := make([]any, len(columns))
|
||||||
if len(taggedMap) == 0 {
|
if len(taggedMap) == 0 {
|
||||||
|
if len(fields) < len(values) {
|
||||||
|
return nil, ErrNotMatchDestination
|
||||||
|
}
|
||||||
|
|
||||||
for i := 0; i < len(values); i++ {
|
for i := 0; i < len(values); i++ {
|
||||||
valueField := fields[i]
|
valueField := fields[i]
|
||||||
switch valueField.Kind() {
|
valueData, err := getValueInterface(valueField)
|
||||||
case reflect.Ptr:
|
if err != nil {
|
||||||
if !valueField.CanInterface() {
|
return nil, err
|
||||||
return nil, ErrNotReadableValue
|
|
||||||
}
|
|
||||||
if valueField.IsNil() {
|
|
||||||
baseValueType := mapping.Deref(valueField.Type())
|
|
||||||
valueField.Set(reflect.New(baseValueType))
|
|
||||||
}
|
|
||||||
values[i] = valueField.Interface()
|
|
||||||
default:
|
|
||||||
if !valueField.CanAddr() || !valueField.Addr().CanInterface() {
|
|
||||||
return nil, ErrNotReadableValue
|
|
||||||
}
|
|
||||||
values[i] = valueField.Addr().Interface()
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
values[i] = valueData
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
for i, column := range columns {
|
for i, column := range columns {
|
||||||
@@ -152,11 +158,11 @@ func unmarshalRow(v any, scanner rowsScanner, strict bool) error {
|
|||||||
reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64,
|
reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64,
|
||||||
reflect.Float32, reflect.Float64,
|
reflect.Float32, reflect.Float64,
|
||||||
reflect.String:
|
reflect.String:
|
||||||
if rve.CanSet() {
|
if !rve.CanSet() {
|
||||||
return scanner.Scan(v)
|
return ErrNotSettable
|
||||||
}
|
}
|
||||||
|
|
||||||
return ErrNotSettable
|
return scanner.Scan(v)
|
||||||
case reflect.Struct:
|
case reflect.Struct:
|
||||||
columns, err := scanner.Columns()
|
columns, err := scanner.Columns()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -183,69 +189,66 @@ func unmarshalRows(v any, scanner rowsScanner, strict bool) error {
|
|||||||
rt := reflect.TypeOf(v)
|
rt := reflect.TypeOf(v)
|
||||||
rte := rt.Elem()
|
rte := rt.Elem()
|
||||||
rve := rv.Elem()
|
rve := rv.Elem()
|
||||||
|
if !rve.CanSet() {
|
||||||
|
return ErrNotSettable
|
||||||
|
}
|
||||||
|
|
||||||
switch rte.Kind() {
|
switch rte.Kind() {
|
||||||
case reflect.Slice:
|
case reflect.Slice:
|
||||||
if rve.CanSet() {
|
ptr := rte.Elem().Kind() == reflect.Ptr
|
||||||
ptr := rte.Elem().Kind() == reflect.Ptr
|
appendFn := func(item reflect.Value) {
|
||||||
appendFn := func(item reflect.Value) {
|
if ptr {
|
||||||
if ptr {
|
rve.Set(reflect.Append(rve, item))
|
||||||
rve.Set(reflect.Append(rve, item))
|
} else {
|
||||||
} else {
|
rve.Set(reflect.Append(rve, reflect.Indirect(item)))
|
||||||
rve.Set(reflect.Append(rve, reflect.Indirect(item)))
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
fillFn := func(value any) error {
|
}
|
||||||
if rve.CanSet() {
|
fillFn := func(value any) error {
|
||||||
if err := scanner.Scan(value); err != nil {
|
if err := scanner.Scan(value); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
|
||||||
|
|
||||||
appendFn(reflect.ValueOf(value))
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
return ErrNotSettable
|
|
||||||
}
|
}
|
||||||
|
|
||||||
base := mapping.Deref(rte.Elem())
|
appendFn(reflect.ValueOf(value))
|
||||||
switch base.Kind() {
|
return nil
|
||||||
case reflect.Bool,
|
}
|
||||||
reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64,
|
|
||||||
reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64,
|
base := mapping.Deref(rte.Elem())
|
||||||
reflect.Float32, reflect.Float64,
|
switch base.Kind() {
|
||||||
reflect.String:
|
case reflect.Bool,
|
||||||
for scanner.Next() {
|
reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64,
|
||||||
value := reflect.New(base)
|
reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64,
|
||||||
if err := fillFn(value.Interface()); err != nil {
|
reflect.Float32, reflect.Float64,
|
||||||
return err
|
reflect.String:
|
||||||
}
|
for scanner.Next() {
|
||||||
|
value := reflect.New(base)
|
||||||
|
if err := fillFn(value.Interface()); err != nil {
|
||||||
|
return err
|
||||||
}
|
}
|
||||||
case reflect.Struct:
|
}
|
||||||
columns, err := scanner.Columns()
|
case reflect.Struct:
|
||||||
|
columns, err := scanner.Columns()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
for scanner.Next() {
|
||||||
|
value := reflect.New(base)
|
||||||
|
values, err := mapStructFieldsIntoSlice(value, columns, strict)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
for scanner.Next() {
|
if err := scanner.Scan(values...); err != nil {
|
||||||
value := reflect.New(base)
|
return err
|
||||||
values, err := mapStructFieldsIntoSlice(value, columns, strict)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := scanner.Scan(values...); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
appendFn(value)
|
|
||||||
}
|
}
|
||||||
default:
|
|
||||||
return ErrUnsupportedValueType
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
appendFn(value)
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
return ErrUnsupportedValueType
|
||||||
}
|
}
|
||||||
|
|
||||||
return ErrNotSettable
|
return nil
|
||||||
default:
|
default:
|
||||||
return ErrUnsupportedValueType
|
return ErrUnsupportedValueType
|
||||||
}
|
}
|
||||||
@@ -257,6 +260,10 @@ func unwrapFields(v reflect.Value) []reflect.Value {
|
|||||||
|
|
||||||
for i := 0; i < indirect.NumField(); i++ {
|
for i := 0; i < indirect.NumField(); i++ {
|
||||||
child := indirect.Field(i)
|
child := indirect.Field(i)
|
||||||
|
if !child.CanSet() {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
if child.Kind() == reflect.Ptr && child.IsNil() {
|
if child.Kind() == reflect.Ptr && child.IsNil() {
|
||||||
baseValueType := mapping.Deref(child.Type())
|
baseValueType := mapping.Deref(child.Type())
|
||||||
child.Set(reflect.New(baseValueType))
|
child.Set(reflect.New(baseValueType))
|
||||||
|
|||||||
@@ -22,6 +22,18 @@ func TestUnmarshalRowBool(t *testing.T) {
|
|||||||
}, "select value from users where user=?", "anyone"))
|
}, "select value from users where user=?", "anyone"))
|
||||||
assert.True(t, value)
|
assert.True(t, value)
|
||||||
})
|
})
|
||||||
|
|
||||||
|
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||||
|
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("1")
|
||||||
|
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
||||||
|
|
||||||
|
var value struct {
|
||||||
|
Value bool `db:"value"`
|
||||||
|
}
|
||||||
|
assert.Error(t, query(context.Background(), db, func(rows *sql.Rows) error {
|
||||||
|
return unmarshalRow(value, rows, true)
|
||||||
|
}, "select value from users where user=?", "anyone"))
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestUnmarshalRowBoolNotSettable(t *testing.T) {
|
func TestUnmarshalRowBoolNotSettable(t *testing.T) {
|
||||||
@@ -207,12 +219,12 @@ func TestUnmarshalRowString(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestUnmarshalRowStruct(t *testing.T) {
|
func TestUnmarshalRowStruct(t *testing.T) {
|
||||||
value := new(struct {
|
|
||||||
Name string
|
|
||||||
Age int
|
|
||||||
})
|
|
||||||
|
|
||||||
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||||
|
value := new(struct {
|
||||||
|
Name string
|
||||||
|
Age int
|
||||||
|
})
|
||||||
|
|
||||||
rs := sqlmock.NewRows([]string{"name", "age"}).FromCSVString("liao,5")
|
rs := sqlmock.NewRows([]string{"name", "age"}).FromCSVString("liao,5")
|
||||||
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
||||||
|
|
||||||
@@ -222,15 +234,58 @@ func TestUnmarshalRowStruct(t *testing.T) {
|
|||||||
assert.Equal(t, "liao", value.Name)
|
assert.Equal(t, "liao", value.Name)
|
||||||
assert.Equal(t, 5, value.Age)
|
assert.Equal(t, 5, value.Age)
|
||||||
})
|
})
|
||||||
}
|
|
||||||
|
|
||||||
func TestUnmarshalRowStructWithTags(t *testing.T) {
|
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||||
value := new(struct {
|
value := new(struct {
|
||||||
Age int `db:"age"`
|
Name string
|
||||||
Name string `db:"name"`
|
Age int
|
||||||
|
})
|
||||||
|
|
||||||
|
errAny := errors.New("any error")
|
||||||
|
rs := sqlmock.NewRows([]string{"name", "age"}).FromCSVString("liao,5")
|
||||||
|
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
||||||
|
|
||||||
|
assert.ErrorIs(t, query(context.Background(), db, func(rows *sql.Rows) error {
|
||||||
|
return unmarshalRow(value, &mockedScanner{
|
||||||
|
colErr: errAny,
|
||||||
|
next: 1,
|
||||||
|
}, true)
|
||||||
|
}, "select name, age from users where user=?", "anyone"), errAny)
|
||||||
})
|
})
|
||||||
|
|
||||||
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||||
|
value := new(struct {
|
||||||
|
Name string
|
||||||
|
age *int
|
||||||
|
})
|
||||||
|
|
||||||
|
rs := sqlmock.NewRows([]string{"name", "age"}).FromCSVString("liao,5")
|
||||||
|
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
||||||
|
|
||||||
|
assert.ErrorIs(t, query(context.Background(), db, func(rows *sql.Rows) error {
|
||||||
|
return unmarshalRow(value, rows, true)
|
||||||
|
}, "select name, age from users where user=?", "anyone"), ErrNotMatchDestination)
|
||||||
|
})
|
||||||
|
|
||||||
|
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||||
|
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("8")
|
||||||
|
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
||||||
|
|
||||||
|
type myString chan int
|
||||||
|
var value myString
|
||||||
|
assert.ErrorIs(t, query(context.Background(), db, func(rows *sql.Rows) error {
|
||||||
|
return unmarshalRow(&value, rows, true)
|
||||||
|
}, "select value from users where user=?", "anyone"), ErrUnsupportedValueType)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUnmarshalRowStructWithTags(t *testing.T) {
|
||||||
|
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||||
|
value := new(struct {
|
||||||
|
Age int `db:"age"`
|
||||||
|
Name string `db:"name"`
|
||||||
|
})
|
||||||
|
|
||||||
rs := sqlmock.NewRows([]string{"name", "age"}).FromCSVString("liao,5")
|
rs := sqlmock.NewRows([]string{"name", "age"}).FromCSVString("liao,5")
|
||||||
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
||||||
|
|
||||||
@@ -240,6 +295,51 @@ func TestUnmarshalRowStructWithTags(t *testing.T) {
|
|||||||
assert.Equal(t, "liao", value.Name)
|
assert.Equal(t, "liao", value.Name)
|
||||||
assert.Equal(t, 5, value.Age)
|
assert.Equal(t, 5, value.Age)
|
||||||
})
|
})
|
||||||
|
|
||||||
|
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||||
|
value := new(struct {
|
||||||
|
age *int `db:"age"`
|
||||||
|
Name string `db:"name"`
|
||||||
|
})
|
||||||
|
|
||||||
|
rs := sqlmock.NewRows([]string{"name", "age"}).FromCSVString("liao,5")
|
||||||
|
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
||||||
|
|
||||||
|
assert.ErrorIs(t, query(context.Background(), db, func(rows *sql.Rows) error {
|
||||||
|
return unmarshalRow(value, rows, true)
|
||||||
|
}, "select name, age from users where user=?", "anyone"), ErrNotReadableValue)
|
||||||
|
})
|
||||||
|
|
||||||
|
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||||
|
var value struct {
|
||||||
|
Age *int `db:"age"`
|
||||||
|
Name *string `db:"name"`
|
||||||
|
}
|
||||||
|
|
||||||
|
rs := sqlmock.NewRows([]string{"name", "age"}).FromCSVString("liao,5")
|
||||||
|
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
||||||
|
|
||||||
|
assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error {
|
||||||
|
return unmarshalRow(&value, rows, true)
|
||||||
|
}, "select name, age from users where user=?", "anyone"))
|
||||||
|
assert.Equal(t, "liao", *value.Name)
|
||||||
|
assert.Equal(t, 5, *value.Age)
|
||||||
|
})
|
||||||
|
|
||||||
|
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||||
|
value := new(struct {
|
||||||
|
Age int `db:"age"`
|
||||||
|
Name string
|
||||||
|
})
|
||||||
|
|
||||||
|
rs := sqlmock.NewRows([]string{"name", "age"}).FromCSVString("liao,5")
|
||||||
|
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
||||||
|
|
||||||
|
assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error {
|
||||||
|
return unmarshalRow(value, rows, true)
|
||||||
|
}, "select name, age from users where user=?", "anyone"))
|
||||||
|
assert.Equal(t, 5, value.Age)
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestUnmarshalRowStructWithTagsWrongColumns(t *testing.T) {
|
func TestUnmarshalRowStructWithTagsWrongColumns(t *testing.T) {
|
||||||
@@ -270,6 +370,42 @@ func TestUnmarshalRowsBool(t *testing.T) {
|
|||||||
}, "select value from users where user=?", "anyone"))
|
}, "select value from users where user=?", "anyone"))
|
||||||
assert.EqualValues(t, expect, value)
|
assert.EqualValues(t, expect, value)
|
||||||
})
|
})
|
||||||
|
|
||||||
|
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||||
|
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("1\n0")
|
||||||
|
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
||||||
|
|
||||||
|
var value []bool
|
||||||
|
assert.Error(t, query(context.Background(), db, func(rows *sql.Rows) error {
|
||||||
|
return unmarshalRows(value, rows, true)
|
||||||
|
}, "select value from users where user=?", "anyone"))
|
||||||
|
})
|
||||||
|
|
||||||
|
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||||
|
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("1\n0")
|
||||||
|
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
||||||
|
|
||||||
|
var value struct {
|
||||||
|
value []bool `db:"value"`
|
||||||
|
}
|
||||||
|
assert.ErrorIs(t, query(context.Background(), db, func(rows *sql.Rows) error {
|
||||||
|
return unmarshalRows(&value, rows, true)
|
||||||
|
}, "select value from users where user=?", "anyone"), ErrUnsupportedValueType)
|
||||||
|
})
|
||||||
|
|
||||||
|
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||||
|
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("1\n0")
|
||||||
|
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
||||||
|
|
||||||
|
var value []bool
|
||||||
|
errAny := errors.New("any")
|
||||||
|
assert.ErrorIs(t, query(context.Background(), db, func(rows *sql.Rows) error {
|
||||||
|
return unmarshalRows(&value, &mockedScanner{
|
||||||
|
scanErr: errAny,
|
||||||
|
next: 1,
|
||||||
|
}, true)
|
||||||
|
}, "select value from users where user=?", "anyone"), errAny)
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestUnmarshalRowsInt(t *testing.T) {
|
func TestUnmarshalRowsInt(t *testing.T) {
|
||||||
@@ -679,25 +815,25 @@ func TestUnmarshalRowsStringPtr(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestUnmarshalRowsStruct(t *testing.T) {
|
func TestUnmarshalRowsStruct(t *testing.T) {
|
||||||
expect := []struct {
|
|
||||||
Name string
|
|
||||||
Age int64
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
Name: "first",
|
|
||||||
Age: 2,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
Name: "second",
|
|
||||||
Age: 3,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
var value []struct {
|
|
||||||
Name string
|
|
||||||
Age int64
|
|
||||||
}
|
|
||||||
|
|
||||||
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||||
|
expect := []struct {
|
||||||
|
Name string
|
||||||
|
Age int64
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
Name: "first",
|
||||||
|
Age: 2,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Name: "second",
|
||||||
|
Age: 3,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
var value []struct {
|
||||||
|
Name string
|
||||||
|
Age int64
|
||||||
|
}
|
||||||
|
|
||||||
rs := sqlmock.NewRows([]string{"name", "age"}).FromCSVString("first,2\nsecond,3")
|
rs := sqlmock.NewRows([]string{"name", "age"}).FromCSVString("first,2\nsecond,3")
|
||||||
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
||||||
assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error {
|
assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error {
|
||||||
@@ -709,6 +845,56 @@ func TestUnmarshalRowsStruct(t *testing.T) {
|
|||||||
assert.Equal(t, each.Age, value[i].Age)
|
assert.Equal(t, each.Age, value[i].Age)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
|
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||||
|
var value []struct {
|
||||||
|
Name string
|
||||||
|
Age int64
|
||||||
|
}
|
||||||
|
|
||||||
|
errAny := errors.New("any error")
|
||||||
|
rs := sqlmock.NewRows([]string{"name", "age"}).FromCSVString("first,2\nsecond,3")
|
||||||
|
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
||||||
|
assert.ErrorIs(t, query(context.Background(), db, func(rows *sql.Rows) error {
|
||||||
|
return unmarshalRows(&value, &mockedScanner{
|
||||||
|
colErr: errAny,
|
||||||
|
next: 1,
|
||||||
|
}, true)
|
||||||
|
}, "select name, age from users where user=?", "anyone"), errAny)
|
||||||
|
})
|
||||||
|
|
||||||
|
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||||
|
var value []struct {
|
||||||
|
Name string
|
||||||
|
Age int64
|
||||||
|
}
|
||||||
|
|
||||||
|
errAny := errors.New("any error")
|
||||||
|
rs := sqlmock.NewRows([]string{"name", "age"}).FromCSVString("first,2\nsecond,3")
|
||||||
|
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
||||||
|
assert.ErrorIs(t, query(context.Background(), db, func(rows *sql.Rows) error {
|
||||||
|
return unmarshalRows(&value, &mockedScanner{
|
||||||
|
cols: []string{"name", "age"},
|
||||||
|
scanErr: errAny,
|
||||||
|
next: 1,
|
||||||
|
}, true)
|
||||||
|
}, "select name, age from users where user=?", "anyone"), errAny)
|
||||||
|
})
|
||||||
|
|
||||||
|
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||||
|
var value []chan int
|
||||||
|
|
||||||
|
errAny := errors.New("any error")
|
||||||
|
rs := sqlmock.NewRows([]string{"name", "age"}).FromCSVString("first,2\nsecond,3")
|
||||||
|
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
||||||
|
assert.ErrorIs(t, query(context.Background(), db, func(rows *sql.Rows) error {
|
||||||
|
return unmarshalRows(&value, &mockedScanner{
|
||||||
|
cols: []string{"name", "age"},
|
||||||
|
scanErr: errAny,
|
||||||
|
next: 1,
|
||||||
|
}, true)
|
||||||
|
}, "select name, age from users where user=?", "anyone"), ErrUnsupportedValueType)
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestUnmarshalRowsStructWithNullStringType(t *testing.T) {
|
func TestUnmarshalRowsStructWithNullStringType(t *testing.T) {
|
||||||
@@ -1163,6 +1349,7 @@ func TestAnonymousStructPrError(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type mockedScanner struct {
|
type mockedScanner struct {
|
||||||
|
cols []string
|
||||||
colErr error
|
colErr error
|
||||||
scanErr error
|
scanErr error
|
||||||
err error
|
err error
|
||||||
@@ -1170,7 +1357,7 @@ type mockedScanner struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (m *mockedScanner) Columns() ([]string, error) {
|
func (m *mockedScanner) Columns() ([]string, error) {
|
||||||
return nil, m.colErr
|
return m.cols, m.colErr
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *mockedScanner) Err() error {
|
func (m *mockedScanner) Err() error {
|
||||||
|
|||||||
Reference in New Issue
Block a user