From cd0f3726ed6fa139c79789b08bfd1ee4f2a62ec9 Mon Sep 17 00:00:00 2001 From: Kevin Wan Date: Sat, 27 May 2023 21:49:11 +0800 Subject: [PATCH] chore: add more tests (#3288) --- core/stores/sqlx/orm.go | 171 ++++++++++++------------ core/stores/sqlx/orm_test.go | 245 ++++++++++++++++++++++++++++++----- 2 files changed, 305 insertions(+), 111 deletions(-) diff --git a/core/stores/sqlx/orm.go b/core/stores/sqlx/orm.go index 956a41a6..d250059a 100644 --- a/core/stores/sqlx/orm.go +++ b/core/stores/sqlx/orm.go @@ -54,27 +54,39 @@ func getTaggedFieldValueMap(v reflect.Value) (map[string]any, error) { } valueField := reflect.Indirect(v).Field(i) - switch valueField.Kind() { - case reflect.Ptr: - if !valueField.CanInterface() { - return nil, ErrNotReadableValue - } - if valueField.IsNil() { - baseValueType := mapping.Deref(valueField.Type()) - valueField.Set(reflect.New(baseValueType)) - } - result[key] = valueField.Interface() - default: - if !valueField.CanAddr() || !valueField.Addr().CanInterface() { - return nil, ErrNotReadableValue - } - result[key] = valueField.Addr().Interface() + valueData, err := getValueInterface(valueField) + if err != nil { + return nil, err } + + result[key] = valueData } return result, nil } +func getValueInterface(value reflect.Value) (any, error) { + switch value.Kind() { + case reflect.Ptr: + if !value.CanInterface() { + return nil, ErrNotReadableValue + } + + if value.IsNil() { + baseValueType := mapping.Deref(value.Type()) + value.Set(reflect.New(baseValueType)) + } + + return value.Interface(), nil + default: + if !value.CanAddr() || !value.Addr().CanInterface() { + return nil, ErrNotReadableValue + } + + return value.Addr().Interface(), nil + } +} + func mapStructFieldsIntoSlice(v reflect.Value, columns []string, strict bool) ([]any, error) { fields := unwrapFields(v) if strict && len(columns) < len(fields) { @@ -88,24 +100,18 @@ func mapStructFieldsIntoSlice(v reflect.Value, columns []string, strict bool) ([ values := make([]any, len(columns)) if len(taggedMap) == 0 { + if len(fields) < len(values) { + return nil, ErrNotMatchDestination + } + for i := 0; i < len(values); i++ { valueField := fields[i] - switch valueField.Kind() { - case reflect.Ptr: - if !valueField.CanInterface() { - return nil, ErrNotReadableValue - } - if valueField.IsNil() { - baseValueType := mapping.Deref(valueField.Type()) - valueField.Set(reflect.New(baseValueType)) - } - values[i] = valueField.Interface() - default: - if !valueField.CanAddr() || !valueField.Addr().CanInterface() { - return nil, ErrNotReadableValue - } - values[i] = valueField.Addr().Interface() + valueData, err := getValueInterface(valueField) + if err != nil { + return nil, err } + + values[i] = valueData } } else { for i, column := range columns { @@ -152,11 +158,11 @@ func unmarshalRow(v any, scanner rowsScanner, strict bool) error { reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Float32, reflect.Float64, reflect.String: - if rve.CanSet() { - return scanner.Scan(v) + if !rve.CanSet() { + return ErrNotSettable } - return ErrNotSettable + return scanner.Scan(v) case reflect.Struct: columns, err := scanner.Columns() if err != nil { @@ -183,69 +189,66 @@ func unmarshalRows(v any, scanner rowsScanner, strict bool) error { rt := reflect.TypeOf(v) rte := rt.Elem() rve := rv.Elem() + if !rve.CanSet() { + return ErrNotSettable + } + switch rte.Kind() { case reflect.Slice: - if rve.CanSet() { - ptr := rte.Elem().Kind() == reflect.Ptr - appendFn := func(item reflect.Value) { - if ptr { - rve.Set(reflect.Append(rve, item)) - } else { - rve.Set(reflect.Append(rve, reflect.Indirect(item))) - } + ptr := rte.Elem().Kind() == reflect.Ptr + appendFn := func(item reflect.Value) { + if ptr { + rve.Set(reflect.Append(rve, item)) + } else { + rve.Set(reflect.Append(rve, reflect.Indirect(item))) } - fillFn := func(value any) error { - if rve.CanSet() { - if err := scanner.Scan(value); err != nil { - return err - } - - appendFn(reflect.ValueOf(value)) - return nil - } - return ErrNotSettable + } + fillFn := func(value any) error { + if err := scanner.Scan(value); err != nil { + return err } - base := mapping.Deref(rte.Elem()) - switch base.Kind() { - case reflect.Bool, - reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, - reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, - reflect.Float32, reflect.Float64, - reflect.String: - for scanner.Next() { - value := reflect.New(base) - if err := fillFn(value.Interface()); err != nil { - return err - } + appendFn(reflect.ValueOf(value)) + return nil + } + + base := mapping.Deref(rte.Elem()) + switch base.Kind() { + case reflect.Bool, + reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, + reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, + reflect.Float32, reflect.Float64, + reflect.String: + for scanner.Next() { + value := reflect.New(base) + if err := fillFn(value.Interface()); err != nil { + return err } - case reflect.Struct: - columns, err := scanner.Columns() + } + case reflect.Struct: + columns, err := scanner.Columns() + if err != nil { + return err + } + + for scanner.Next() { + value := reflect.New(base) + values, err := mapStructFieldsIntoSlice(value, columns, strict) if err != nil { return err } - for scanner.Next() { - value := reflect.New(base) - values, err := mapStructFieldsIntoSlice(value, columns, strict) - if err != nil { - return err - } - - if err := scanner.Scan(values...); err != nil { - return err - } - - appendFn(value) + if err := scanner.Scan(values...); err != nil { + return err } - default: - return ErrUnsupportedValueType - } - return nil + appendFn(value) + } + default: + return ErrUnsupportedValueType } - return ErrNotSettable + return nil default: return ErrUnsupportedValueType } @@ -257,6 +260,10 @@ func unwrapFields(v reflect.Value) []reflect.Value { for i := 0; i < indirect.NumField(); i++ { child := indirect.Field(i) + if !child.CanSet() { + continue + } + if child.Kind() == reflect.Ptr && child.IsNil() { baseValueType := mapping.Deref(child.Type()) child.Set(reflect.New(baseValueType)) diff --git a/core/stores/sqlx/orm_test.go b/core/stores/sqlx/orm_test.go index 4aab1bba..7aa82d13 100644 --- a/core/stores/sqlx/orm_test.go +++ b/core/stores/sqlx/orm_test.go @@ -22,6 +22,18 @@ func TestUnmarshalRowBool(t *testing.T) { }, "select value from users where user=?", "anyone")) assert.True(t, value) }) + + dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) { + rs := sqlmock.NewRows([]string{"value"}).FromCSVString("1") + mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs) + + var value struct { + Value bool `db:"value"` + } + assert.Error(t, query(context.Background(), db, func(rows *sql.Rows) error { + return unmarshalRow(value, rows, true) + }, "select value from users where user=?", "anyone")) + }) } func TestUnmarshalRowBoolNotSettable(t *testing.T) { @@ -207,12 +219,12 @@ func TestUnmarshalRowString(t *testing.T) { } func TestUnmarshalRowStruct(t *testing.T) { - value := new(struct { - Name string - Age int - }) - dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) { + value := new(struct { + Name string + Age int + }) + rs := sqlmock.NewRows([]string{"name", "age"}).FromCSVString("liao,5") mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs) @@ -222,15 +234,58 @@ func TestUnmarshalRowStruct(t *testing.T) { assert.Equal(t, "liao", value.Name) assert.Equal(t, 5, value.Age) }) -} -func TestUnmarshalRowStructWithTags(t *testing.T) { - value := new(struct { - Age int `db:"age"` - Name string `db:"name"` + dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) { + value := new(struct { + Name string + Age int + }) + + errAny := errors.New("any error") + rs := sqlmock.NewRows([]string{"name", "age"}).FromCSVString("liao,5") + mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs) + + assert.ErrorIs(t, query(context.Background(), db, func(rows *sql.Rows) error { + return unmarshalRow(value, &mockedScanner{ + colErr: errAny, + next: 1, + }, true) + }, "select name, age from users where user=?", "anyone"), errAny) }) dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) { + value := new(struct { + Name string + age *int + }) + + rs := sqlmock.NewRows([]string{"name", "age"}).FromCSVString("liao,5") + mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs) + + assert.ErrorIs(t, query(context.Background(), db, func(rows *sql.Rows) error { + return unmarshalRow(value, rows, true) + }, "select name, age from users where user=?", "anyone"), ErrNotMatchDestination) + }) + + dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) { + rs := sqlmock.NewRows([]string{"value"}).FromCSVString("8") + mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs) + + type myString chan int + var value myString + assert.ErrorIs(t, query(context.Background(), db, func(rows *sql.Rows) error { + return unmarshalRow(&value, rows, true) + }, "select value from users where user=?", "anyone"), ErrUnsupportedValueType) + }) +} + +func TestUnmarshalRowStructWithTags(t *testing.T) { + dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) { + value := new(struct { + Age int `db:"age"` + Name string `db:"name"` + }) + rs := sqlmock.NewRows([]string{"name", "age"}).FromCSVString("liao,5") mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs) @@ -240,6 +295,51 @@ func TestUnmarshalRowStructWithTags(t *testing.T) { assert.Equal(t, "liao", value.Name) assert.Equal(t, 5, value.Age) }) + + dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) { + value := new(struct { + age *int `db:"age"` + Name string `db:"name"` + }) + + rs := sqlmock.NewRows([]string{"name", "age"}).FromCSVString("liao,5") + mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs) + + assert.ErrorIs(t, query(context.Background(), db, func(rows *sql.Rows) error { + return unmarshalRow(value, rows, true) + }, "select name, age from users where user=?", "anyone"), ErrNotReadableValue) + }) + + dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) { + var value struct { + Age *int `db:"age"` + Name *string `db:"name"` + } + + rs := sqlmock.NewRows([]string{"name", "age"}).FromCSVString("liao,5") + mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs) + + assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error { + return unmarshalRow(&value, rows, true) + }, "select name, age from users where user=?", "anyone")) + assert.Equal(t, "liao", *value.Name) + assert.Equal(t, 5, *value.Age) + }) + + dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) { + value := new(struct { + Age int `db:"age"` + Name string + }) + + rs := sqlmock.NewRows([]string{"name", "age"}).FromCSVString("liao,5") + mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs) + + assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error { + return unmarshalRow(value, rows, true) + }, "select name, age from users where user=?", "anyone")) + assert.Equal(t, 5, value.Age) + }) } func TestUnmarshalRowStructWithTagsWrongColumns(t *testing.T) { @@ -270,6 +370,42 @@ func TestUnmarshalRowsBool(t *testing.T) { }, "select value from users where user=?", "anyone")) assert.EqualValues(t, expect, value) }) + + dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) { + rs := sqlmock.NewRows([]string{"value"}).FromCSVString("1\n0") + mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs) + + var value []bool + assert.Error(t, query(context.Background(), db, func(rows *sql.Rows) error { + return unmarshalRows(value, rows, true) + }, "select value from users where user=?", "anyone")) + }) + + dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) { + rs := sqlmock.NewRows([]string{"value"}).FromCSVString("1\n0") + mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs) + + var value struct { + value []bool `db:"value"` + } + assert.ErrorIs(t, query(context.Background(), db, func(rows *sql.Rows) error { + return unmarshalRows(&value, rows, true) + }, "select value from users where user=?", "anyone"), ErrUnsupportedValueType) + }) + + dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) { + rs := sqlmock.NewRows([]string{"value"}).FromCSVString("1\n0") + mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs) + + var value []bool + errAny := errors.New("any") + assert.ErrorIs(t, query(context.Background(), db, func(rows *sql.Rows) error { + return unmarshalRows(&value, &mockedScanner{ + scanErr: errAny, + next: 1, + }, true) + }, "select value from users where user=?", "anyone"), errAny) + }) } func TestUnmarshalRowsInt(t *testing.T) { @@ -679,25 +815,25 @@ func TestUnmarshalRowsStringPtr(t *testing.T) { } func TestUnmarshalRowsStruct(t *testing.T) { - expect := []struct { - Name string - Age int64 - }{ - { - Name: "first", - Age: 2, - }, - { - Name: "second", - Age: 3, - }, - } - var value []struct { - Name string - Age int64 - } - dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) { + expect := []struct { + Name string + Age int64 + }{ + { + Name: "first", + Age: 2, + }, + { + Name: "second", + Age: 3, + }, + } + var value []struct { + Name string + Age int64 + } + rs := sqlmock.NewRows([]string{"name", "age"}).FromCSVString("first,2\nsecond,3") mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs) assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error { @@ -709,6 +845,56 @@ func TestUnmarshalRowsStruct(t *testing.T) { assert.Equal(t, each.Age, value[i].Age) } }) + + dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) { + var value []struct { + Name string + Age int64 + } + + errAny := errors.New("any error") + rs := sqlmock.NewRows([]string{"name", "age"}).FromCSVString("first,2\nsecond,3") + mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs) + assert.ErrorIs(t, query(context.Background(), db, func(rows *sql.Rows) error { + return unmarshalRows(&value, &mockedScanner{ + colErr: errAny, + next: 1, + }, true) + }, "select name, age from users where user=?", "anyone"), errAny) + }) + + dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) { + var value []struct { + Name string + Age int64 + } + + errAny := errors.New("any error") + rs := sqlmock.NewRows([]string{"name", "age"}).FromCSVString("first,2\nsecond,3") + mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs) + assert.ErrorIs(t, query(context.Background(), db, func(rows *sql.Rows) error { + return unmarshalRows(&value, &mockedScanner{ + cols: []string{"name", "age"}, + scanErr: errAny, + next: 1, + }, true) + }, "select name, age from users where user=?", "anyone"), errAny) + }) + + dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) { + var value []chan int + + errAny := errors.New("any error") + rs := sqlmock.NewRows([]string{"name", "age"}).FromCSVString("first,2\nsecond,3") + mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs) + assert.ErrorIs(t, query(context.Background(), db, func(rows *sql.Rows) error { + return unmarshalRows(&value, &mockedScanner{ + cols: []string{"name", "age"}, + scanErr: errAny, + next: 1, + }, true) + }, "select name, age from users where user=?", "anyone"), ErrUnsupportedValueType) + }) } func TestUnmarshalRowsStructWithNullStringType(t *testing.T) { @@ -1163,6 +1349,7 @@ func TestAnonymousStructPrError(t *testing.T) { } type mockedScanner struct { + cols []string colErr error scanErr error err error @@ -1170,7 +1357,7 @@ type mockedScanner struct { } func (m *mockedScanner) Columns() ([]string, error) { - return nil, m.colErr + return m.cols, m.colErr } func (m *mockedScanner) Err() error {