chore: add more tests (#3288)

This commit is contained in:
Kevin Wan
2023-05-27 21:49:11 +08:00
committed by GitHub
parent 0217044900
commit cd0f3726ed
2 changed files with 305 additions and 111 deletions

View File

@@ -54,27 +54,39 @@ func getTaggedFieldValueMap(v reflect.Value) (map[string]any, error) {
} }
valueField := reflect.Indirect(v).Field(i) valueField := reflect.Indirect(v).Field(i)
switch valueField.Kind() { valueData, err := getValueInterface(valueField)
case reflect.Ptr: if err != nil {
if !valueField.CanInterface() { return nil, err
return nil, ErrNotReadableValue
}
if valueField.IsNil() {
baseValueType := mapping.Deref(valueField.Type())
valueField.Set(reflect.New(baseValueType))
}
result[key] = valueField.Interface()
default:
if !valueField.CanAddr() || !valueField.Addr().CanInterface() {
return nil, ErrNotReadableValue
}
result[key] = valueField.Addr().Interface()
} }
result[key] = valueData
} }
return result, nil return result, nil
} }
func getValueInterface(value reflect.Value) (any, error) {
switch value.Kind() {
case reflect.Ptr:
if !value.CanInterface() {
return nil, ErrNotReadableValue
}
if value.IsNil() {
baseValueType := mapping.Deref(value.Type())
value.Set(reflect.New(baseValueType))
}
return value.Interface(), nil
default:
if !value.CanAddr() || !value.Addr().CanInterface() {
return nil, ErrNotReadableValue
}
return value.Addr().Interface(), nil
}
}
func mapStructFieldsIntoSlice(v reflect.Value, columns []string, strict bool) ([]any, error) { func mapStructFieldsIntoSlice(v reflect.Value, columns []string, strict bool) ([]any, error) {
fields := unwrapFields(v) fields := unwrapFields(v)
if strict && len(columns) < len(fields) { if strict && len(columns) < len(fields) {
@@ -88,24 +100,18 @@ func mapStructFieldsIntoSlice(v reflect.Value, columns []string, strict bool) ([
values := make([]any, len(columns)) values := make([]any, len(columns))
if len(taggedMap) == 0 { if len(taggedMap) == 0 {
if len(fields) < len(values) {
return nil, ErrNotMatchDestination
}
for i := 0; i < len(values); i++ { for i := 0; i < len(values); i++ {
valueField := fields[i] valueField := fields[i]
switch valueField.Kind() { valueData, err := getValueInterface(valueField)
case reflect.Ptr: if err != nil {
if !valueField.CanInterface() { return nil, err
return nil, ErrNotReadableValue
}
if valueField.IsNil() {
baseValueType := mapping.Deref(valueField.Type())
valueField.Set(reflect.New(baseValueType))
}
values[i] = valueField.Interface()
default:
if !valueField.CanAddr() || !valueField.Addr().CanInterface() {
return nil, ErrNotReadableValue
}
values[i] = valueField.Addr().Interface()
} }
values[i] = valueData
} }
} else { } else {
for i, column := range columns { for i, column := range columns {
@@ -152,11 +158,11 @@ func unmarshalRow(v any, scanner rowsScanner, strict bool) error {
reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64,
reflect.Float32, reflect.Float64, reflect.Float32, reflect.Float64,
reflect.String: reflect.String:
if rve.CanSet() { if !rve.CanSet() {
return scanner.Scan(v) return ErrNotSettable
} }
return ErrNotSettable return scanner.Scan(v)
case reflect.Struct: case reflect.Struct:
columns, err := scanner.Columns() columns, err := scanner.Columns()
if err != nil { if err != nil {
@@ -183,69 +189,66 @@ func unmarshalRows(v any, scanner rowsScanner, strict bool) error {
rt := reflect.TypeOf(v) rt := reflect.TypeOf(v)
rte := rt.Elem() rte := rt.Elem()
rve := rv.Elem() rve := rv.Elem()
if !rve.CanSet() {
return ErrNotSettable
}
switch rte.Kind() { switch rte.Kind() {
case reflect.Slice: case reflect.Slice:
if rve.CanSet() { ptr := rte.Elem().Kind() == reflect.Ptr
ptr := rte.Elem().Kind() == reflect.Ptr appendFn := func(item reflect.Value) {
appendFn := func(item reflect.Value) { if ptr {
if ptr { rve.Set(reflect.Append(rve, item))
rve.Set(reflect.Append(rve, item)) } else {
} else { rve.Set(reflect.Append(rve, reflect.Indirect(item)))
rve.Set(reflect.Append(rve, reflect.Indirect(item)))
}
} }
fillFn := func(value any) error { }
if rve.CanSet() { fillFn := func(value any) error {
if err := scanner.Scan(value); err != nil { if err := scanner.Scan(value); err != nil {
return err return err
}
appendFn(reflect.ValueOf(value))
return nil
}
return ErrNotSettable
} }
base := mapping.Deref(rte.Elem()) appendFn(reflect.ValueOf(value))
switch base.Kind() { return nil
case reflect.Bool, }
reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64,
reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, base := mapping.Deref(rte.Elem())
reflect.Float32, reflect.Float64, switch base.Kind() {
reflect.String: case reflect.Bool,
for scanner.Next() { reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64,
value := reflect.New(base) reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64,
if err := fillFn(value.Interface()); err != nil { reflect.Float32, reflect.Float64,
return err reflect.String:
} for scanner.Next() {
value := reflect.New(base)
if err := fillFn(value.Interface()); err != nil {
return err
} }
case reflect.Struct: }
columns, err := scanner.Columns() case reflect.Struct:
columns, err := scanner.Columns()
if err != nil {
return err
}
for scanner.Next() {
value := reflect.New(base)
values, err := mapStructFieldsIntoSlice(value, columns, strict)
if err != nil { if err != nil {
return err return err
} }
for scanner.Next() { if err := scanner.Scan(values...); err != nil {
value := reflect.New(base) return err
values, err := mapStructFieldsIntoSlice(value, columns, strict)
if err != nil {
return err
}
if err := scanner.Scan(values...); err != nil {
return err
}
appendFn(value)
} }
default:
return ErrUnsupportedValueType
}
return nil appendFn(value)
}
default:
return ErrUnsupportedValueType
} }
return ErrNotSettable return nil
default: default:
return ErrUnsupportedValueType return ErrUnsupportedValueType
} }
@@ -257,6 +260,10 @@ func unwrapFields(v reflect.Value) []reflect.Value {
for i := 0; i < indirect.NumField(); i++ { for i := 0; i < indirect.NumField(); i++ {
child := indirect.Field(i) child := indirect.Field(i)
if !child.CanSet() {
continue
}
if child.Kind() == reflect.Ptr && child.IsNil() { if child.Kind() == reflect.Ptr && child.IsNil() {
baseValueType := mapping.Deref(child.Type()) baseValueType := mapping.Deref(child.Type())
child.Set(reflect.New(baseValueType)) child.Set(reflect.New(baseValueType))

View File

@@ -22,6 +22,18 @@ func TestUnmarshalRowBool(t *testing.T) {
}, "select value from users where user=?", "anyone")) }, "select value from users where user=?", "anyone"))
assert.True(t, value) assert.True(t, value)
}) })
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("1")
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
var value struct {
Value bool `db:"value"`
}
assert.Error(t, query(context.Background(), db, func(rows *sql.Rows) error {
return unmarshalRow(value, rows, true)
}, "select value from users where user=?", "anyone"))
})
} }
func TestUnmarshalRowBoolNotSettable(t *testing.T) { func TestUnmarshalRowBoolNotSettable(t *testing.T) {
@@ -207,12 +219,12 @@ func TestUnmarshalRowString(t *testing.T) {
} }
func TestUnmarshalRowStruct(t *testing.T) { func TestUnmarshalRowStruct(t *testing.T) {
value := new(struct {
Name string
Age int
})
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) { dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
value := new(struct {
Name string
Age int
})
rs := sqlmock.NewRows([]string{"name", "age"}).FromCSVString("liao,5") rs := sqlmock.NewRows([]string{"name", "age"}).FromCSVString("liao,5")
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs) mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
@@ -222,15 +234,58 @@ func TestUnmarshalRowStruct(t *testing.T) {
assert.Equal(t, "liao", value.Name) assert.Equal(t, "liao", value.Name)
assert.Equal(t, 5, value.Age) assert.Equal(t, 5, value.Age)
}) })
}
func TestUnmarshalRowStructWithTags(t *testing.T) { dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
value := new(struct { value := new(struct {
Age int `db:"age"` Name string
Name string `db:"name"` Age int
})
errAny := errors.New("any error")
rs := sqlmock.NewRows([]string{"name", "age"}).FromCSVString("liao,5")
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
assert.ErrorIs(t, query(context.Background(), db, func(rows *sql.Rows) error {
return unmarshalRow(value, &mockedScanner{
colErr: errAny,
next: 1,
}, true)
}, "select name, age from users where user=?", "anyone"), errAny)
}) })
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) { dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
value := new(struct {
Name string
age *int
})
rs := sqlmock.NewRows([]string{"name", "age"}).FromCSVString("liao,5")
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
assert.ErrorIs(t, query(context.Background(), db, func(rows *sql.Rows) error {
return unmarshalRow(value, rows, true)
}, "select name, age from users where user=?", "anyone"), ErrNotMatchDestination)
})
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("8")
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
type myString chan int
var value myString
assert.ErrorIs(t, query(context.Background(), db, func(rows *sql.Rows) error {
return unmarshalRow(&value, rows, true)
}, "select value from users where user=?", "anyone"), ErrUnsupportedValueType)
})
}
func TestUnmarshalRowStructWithTags(t *testing.T) {
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
value := new(struct {
Age int `db:"age"`
Name string `db:"name"`
})
rs := sqlmock.NewRows([]string{"name", "age"}).FromCSVString("liao,5") rs := sqlmock.NewRows([]string{"name", "age"}).FromCSVString("liao,5")
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs) mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
@@ -240,6 +295,51 @@ func TestUnmarshalRowStructWithTags(t *testing.T) {
assert.Equal(t, "liao", value.Name) assert.Equal(t, "liao", value.Name)
assert.Equal(t, 5, value.Age) assert.Equal(t, 5, value.Age)
}) })
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
value := new(struct {
age *int `db:"age"`
Name string `db:"name"`
})
rs := sqlmock.NewRows([]string{"name", "age"}).FromCSVString("liao,5")
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
assert.ErrorIs(t, query(context.Background(), db, func(rows *sql.Rows) error {
return unmarshalRow(value, rows, true)
}, "select name, age from users where user=?", "anyone"), ErrNotReadableValue)
})
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
var value struct {
Age *int `db:"age"`
Name *string `db:"name"`
}
rs := sqlmock.NewRows([]string{"name", "age"}).FromCSVString("liao,5")
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error {
return unmarshalRow(&value, rows, true)
}, "select name, age from users where user=?", "anyone"))
assert.Equal(t, "liao", *value.Name)
assert.Equal(t, 5, *value.Age)
})
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
value := new(struct {
Age int `db:"age"`
Name string
})
rs := sqlmock.NewRows([]string{"name", "age"}).FromCSVString("liao,5")
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error {
return unmarshalRow(value, rows, true)
}, "select name, age from users where user=?", "anyone"))
assert.Equal(t, 5, value.Age)
})
} }
func TestUnmarshalRowStructWithTagsWrongColumns(t *testing.T) { func TestUnmarshalRowStructWithTagsWrongColumns(t *testing.T) {
@@ -270,6 +370,42 @@ func TestUnmarshalRowsBool(t *testing.T) {
}, "select value from users where user=?", "anyone")) }, "select value from users where user=?", "anyone"))
assert.EqualValues(t, expect, value) assert.EqualValues(t, expect, value)
}) })
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("1\n0")
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
var value []bool
assert.Error(t, query(context.Background(), db, func(rows *sql.Rows) error {
return unmarshalRows(value, rows, true)
}, "select value from users where user=?", "anyone"))
})
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("1\n0")
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
var value struct {
value []bool `db:"value"`
}
assert.ErrorIs(t, query(context.Background(), db, func(rows *sql.Rows) error {
return unmarshalRows(&value, rows, true)
}, "select value from users where user=?", "anyone"), ErrUnsupportedValueType)
})
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("1\n0")
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
var value []bool
errAny := errors.New("any")
assert.ErrorIs(t, query(context.Background(), db, func(rows *sql.Rows) error {
return unmarshalRows(&value, &mockedScanner{
scanErr: errAny,
next: 1,
}, true)
}, "select value from users where user=?", "anyone"), errAny)
})
} }
func TestUnmarshalRowsInt(t *testing.T) { func TestUnmarshalRowsInt(t *testing.T) {
@@ -679,25 +815,25 @@ func TestUnmarshalRowsStringPtr(t *testing.T) {
} }
func TestUnmarshalRowsStruct(t *testing.T) { func TestUnmarshalRowsStruct(t *testing.T) {
expect := []struct {
Name string
Age int64
}{
{
Name: "first",
Age: 2,
},
{
Name: "second",
Age: 3,
},
}
var value []struct {
Name string
Age int64
}
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) { dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
expect := []struct {
Name string
Age int64
}{
{
Name: "first",
Age: 2,
},
{
Name: "second",
Age: 3,
},
}
var value []struct {
Name string
Age int64
}
rs := sqlmock.NewRows([]string{"name", "age"}).FromCSVString("first,2\nsecond,3") rs := sqlmock.NewRows([]string{"name", "age"}).FromCSVString("first,2\nsecond,3")
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs) mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error { assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error {
@@ -709,6 +845,56 @@ func TestUnmarshalRowsStruct(t *testing.T) {
assert.Equal(t, each.Age, value[i].Age) assert.Equal(t, each.Age, value[i].Age)
} }
}) })
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
var value []struct {
Name string
Age int64
}
errAny := errors.New("any error")
rs := sqlmock.NewRows([]string{"name", "age"}).FromCSVString("first,2\nsecond,3")
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
assert.ErrorIs(t, query(context.Background(), db, func(rows *sql.Rows) error {
return unmarshalRows(&value, &mockedScanner{
colErr: errAny,
next: 1,
}, true)
}, "select name, age from users where user=?", "anyone"), errAny)
})
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
var value []struct {
Name string
Age int64
}
errAny := errors.New("any error")
rs := sqlmock.NewRows([]string{"name", "age"}).FromCSVString("first,2\nsecond,3")
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
assert.ErrorIs(t, query(context.Background(), db, func(rows *sql.Rows) error {
return unmarshalRows(&value, &mockedScanner{
cols: []string{"name", "age"},
scanErr: errAny,
next: 1,
}, true)
}, "select name, age from users where user=?", "anyone"), errAny)
})
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
var value []chan int
errAny := errors.New("any error")
rs := sqlmock.NewRows([]string{"name", "age"}).FromCSVString("first,2\nsecond,3")
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
assert.ErrorIs(t, query(context.Background(), db, func(rows *sql.Rows) error {
return unmarshalRows(&value, &mockedScanner{
cols: []string{"name", "age"},
scanErr: errAny,
next: 1,
}, true)
}, "select name, age from users where user=?", "anyone"), ErrUnsupportedValueType)
})
} }
func TestUnmarshalRowsStructWithNullStringType(t *testing.T) { func TestUnmarshalRowsStructWithNullStringType(t *testing.T) {
@@ -1163,6 +1349,7 @@ func TestAnonymousStructPrError(t *testing.T) {
} }
type mockedScanner struct { type mockedScanner struct {
cols []string
colErr error colErr error
scanErr error scanErr error
err error err error
@@ -1170,7 +1357,7 @@ type mockedScanner struct {
} }
func (m *mockedScanner) Columns() ([]string, error) { func (m *mockedScanner) Columns() ([]string, error) {
return nil, m.colErr return m.cols, m.colErr
} }
func (m *mockedScanner) Err() error { func (m *mockedScanner) Err() error {