expose sql.DB to let orm operate on it (#1015)
* expose sql.DB to let orm operate on it * add missing RawDB methods * add NewSqlConnFromDB for cooperate with dtm
This commit is contained in:
@@ -600,6 +600,10 @@ func (d dummySqlConn) QueryRowsPartial(v interface{}, query string, args ...inte
|
||||
return nil
|
||||
}
|
||||
|
||||
func (d dummySqlConn) RawDB() (*sql.DB, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (d dummySqlConn) Transact(func(session sqlx.Session) error) error {
|
||||
return nil
|
||||
}
|
||||
@@ -621,6 +625,10 @@ func (c *trackedConn) QueryRows(v interface{}, query string, args ...interface{}
|
||||
return c.dummySqlConn.QueryRows(v, query, args...)
|
||||
}
|
||||
|
||||
func (c *trackedConn) RawDB() (*sql.DB, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (c *trackedConn) Transact(fn func(session sqlx.Session) error) error {
|
||||
c.transactValue = true
|
||||
return c.dummySqlConn.Transact(fn)
|
||||
|
||||
@@ -43,6 +43,10 @@ func (c *mockedConn) QueryRowsPartial(v interface{}, query string, args ...inter
|
||||
panic("should not called")
|
||||
}
|
||||
|
||||
func (c *mockedConn) RawDB() (*sql.DB, error) {
|
||||
panic("should not called")
|
||||
}
|
||||
|
||||
func (c *mockedConn) Transact(func(session Session) error) error {
|
||||
panic("should not called")
|
||||
}
|
||||
|
||||
@@ -6,6 +6,9 @@ import (
|
||||
"github.com/tal-tech/go-zero/core/breaker"
|
||||
)
|
||||
|
||||
// datasource placeholder for logging error.
|
||||
const rawDB = "sql.DB"
|
||||
|
||||
// ErrNotFound is an alias of sql.ErrNoRows
|
||||
var ErrNotFound = sql.ErrNoRows
|
||||
|
||||
@@ -23,6 +26,7 @@ type (
|
||||
// SqlConn only stands for raw connections, so Transact method can be called.
|
||||
SqlConn interface {
|
||||
Session
|
||||
RawDB() (*sql.DB, error)
|
||||
Transact(func(session Session) error) error
|
||||
}
|
||||
|
||||
@@ -43,13 +47,15 @@ type (
|
||||
// Because CORBA doesn't support PREPARE, so we need to combine the
|
||||
// query arguments into one string and do underlying query without arguments
|
||||
commonSqlConn struct {
|
||||
driverName string
|
||||
datasource string
|
||||
connProv connProvider
|
||||
beginTx beginnable
|
||||
brk breaker.Breaker
|
||||
accept func(error) bool
|
||||
}
|
||||
|
||||
connProvider func() (*sql.DB, error)
|
||||
|
||||
sessionConn interface {
|
||||
Exec(query string, args ...interface{}) (sql.Result, error)
|
||||
Query(query string, args ...interface{}) (*sql.Rows, error)
|
||||
@@ -69,10 +75,30 @@ type (
|
||||
// NewSqlConn returns a SqlConn with given driver name and datasource.
|
||||
func NewSqlConn(driverName, datasource string, opts ...SqlOption) SqlConn {
|
||||
conn := &commonSqlConn{
|
||||
driverName: driverName,
|
||||
datasource: datasource,
|
||||
beginTx: begin,
|
||||
brk: breaker.NewBreaker(),
|
||||
connProv: func() (*sql.DB, error) {
|
||||
return getSqlConn(driverName, datasource)
|
||||
},
|
||||
beginTx: begin,
|
||||
brk: breaker.NewBreaker(),
|
||||
}
|
||||
for _, opt := range opts {
|
||||
opt(conn)
|
||||
}
|
||||
|
||||
return conn
|
||||
}
|
||||
|
||||
// NewSqlConnFromDB returns a SqlConn with the given sql.DB.
|
||||
// Use it with caution, it's provided for other ORM to interact with.
|
||||
func NewSqlConnFromDB(db *sql.DB, opts ...SqlOption) SqlConn {
|
||||
conn := &commonSqlConn{
|
||||
datasource: rawDB,
|
||||
connProv: func() (*sql.DB, error) {
|
||||
return db, nil
|
||||
},
|
||||
beginTx: begin,
|
||||
brk: breaker.NewBreaker(),
|
||||
}
|
||||
for _, opt := range opts {
|
||||
opt(conn)
|
||||
@@ -84,7 +110,7 @@ func NewSqlConn(driverName, datasource string, opts ...SqlOption) SqlConn {
|
||||
func (db *commonSqlConn) Exec(q string, args ...interface{}) (result sql.Result, err error) {
|
||||
err = db.brk.DoWithAcceptable(func() error {
|
||||
var conn *sql.DB
|
||||
conn, err = getSqlConn(db.driverName, db.datasource)
|
||||
conn, err = db.connProv()
|
||||
if err != nil {
|
||||
logInstanceError(db.datasource, err)
|
||||
return err
|
||||
@@ -100,7 +126,7 @@ func (db *commonSqlConn) Exec(q string, args ...interface{}) (result sql.Result,
|
||||
func (db *commonSqlConn) Prepare(query string) (stmt StmtSession, err error) {
|
||||
err = db.brk.DoWithAcceptable(func() error {
|
||||
var conn *sql.DB
|
||||
conn, err = getSqlConn(db.driverName, db.datasource)
|
||||
conn, err = db.connProv()
|
||||
if err != nil {
|
||||
logInstanceError(db.datasource, err)
|
||||
return err
|
||||
@@ -145,6 +171,10 @@ func (db *commonSqlConn) QueryRowsPartial(v interface{}, q string, args ...inter
|
||||
}, q, args...)
|
||||
}
|
||||
|
||||
func (db *commonSqlConn) RawDB() (*sql.DB, error) {
|
||||
return db.connProv()
|
||||
}
|
||||
|
||||
func (db *commonSqlConn) Transact(fn func(Session) error) error {
|
||||
return db.brk.DoWithAcceptable(func() error {
|
||||
return transact(db, db.beginTx, fn)
|
||||
@@ -163,7 +193,7 @@ func (db *commonSqlConn) acceptable(err error) bool {
|
||||
func (db *commonSqlConn) queryRows(scanner func(*sql.Rows) error, q string, args ...interface{}) error {
|
||||
var qerr error
|
||||
return db.brk.DoWithAcceptable(func() error {
|
||||
conn, err := getSqlConn(db.driverName, db.datasource)
|
||||
conn, err := db.connProv()
|
||||
if err != nil {
|
||||
logInstanceError(db.datasource, err)
|
||||
return err
|
||||
|
||||
@@ -21,12 +21,15 @@ func TestSqlConn(t *testing.T) {
|
||||
mock.ExpectExec("any")
|
||||
mock.ExpectQuery("any").WillReturnRows(sqlmock.NewRows([]string{"foo"}))
|
||||
conn := NewMysql(mockedDatasource)
|
||||
db, err := conn.RawDB()
|
||||
assert.Nil(t, err)
|
||||
rawConn := NewSqlConnFromDB(db, withMysqlAcceptable())
|
||||
badConn := NewMysql("badsql")
|
||||
_, err := conn.Exec("any", "value")
|
||||
_, err = conn.Exec("any", "value")
|
||||
assert.NotNil(t, err)
|
||||
_, err = badConn.Exec("any", "value")
|
||||
assert.NotNil(t, err)
|
||||
_, err = conn.Prepare("any")
|
||||
_, err = rawConn.Prepare("any")
|
||||
assert.NotNil(t, err)
|
||||
_, err = badConn.Prepare("any")
|
||||
assert.NotNil(t, err)
|
||||
|
||||
@@ -71,7 +71,7 @@ func begin(db *sql.DB) (trans, error) {
|
||||
}
|
||||
|
||||
func transact(db *commonSqlConn, b beginnable, fn func(Session) error) (err error) {
|
||||
conn, err := getSqlConn(db.driverName, db.datasource)
|
||||
conn, err := db.connProv()
|
||||
if err != nil {
|
||||
logInstanceError(db.datasource, err)
|
||||
return err
|
||||
|
||||
@@ -13,6 +13,7 @@ type (
|
||||
MockConn struct {
|
||||
db *sql.DB
|
||||
}
|
||||
|
||||
statement struct {
|
||||
stmt *sql.Stmt
|
||||
}
|
||||
@@ -62,6 +63,11 @@ func (conn *MockConn) QueryRowsPartial(v interface{}, q string, args ...interfac
|
||||
}, q, args...)
|
||||
}
|
||||
|
||||
// RawDB returns the underlying sql.DB.
|
||||
func (conn *MockConn) RawDB() (*sql.DB, error) {
|
||||
return conn.db, nil
|
||||
}
|
||||
|
||||
// Transact is the implemention of sqlx.SqlConn, nothing to do
|
||||
func (conn *MockConn) Transact(func(session sqlx.Session) error) error {
|
||||
return nil
|
||||
|
||||
Reference in New Issue
Block a user