gorm v2.0

支持gormv2.0
This commit is contained in:
谢小军
2020-09-24 20:59:50 +08:00
parent 33d22fe2d4
commit f585a627c1
5 changed files with 46 additions and 46 deletions

View File

@@ -58,6 +58,11 @@ func (obj *_BaseMgr) SetIsRelated(b bool) {
obj.isRelated = b obj.isRelated = b
} }
// New new gorm.新gorm
func (obj *_BaseMgr) New() *gorm.DB {
return obj.DB.Session(&gorm.Session{WithConditions: false, Context: obj.ctx})
}
type options struct { type options struct {
query map[string]interface{} query map[string]interface{}
} }
@@ -96,9 +101,8 @@ func {{$obj.StructName}}Mgr(db *gorm.DB) *_{{$obj.StructName}}Mgr {
if db == nil { if db == nil {
panic(fmt.Errorf("{{$obj.StructName}}Mgr need init by db")) panic(fmt.Errorf("{{$obj.StructName}}Mgr need init by db"))
} }
timeout := 10 * time.Second ctx, cancel := context.WithCancel(context.Background())
ctx, cancel := context.WithTimeout(context.Background(), timeout) return &_{{$obj.StructName}}Mgr{_BaseMgr: &_BaseMgr{DB: db.Table("{{$obj.TableName}}"), isRelated: globalIsRelated,ctx:ctx,cancel:cancel,timeout:-1}}
return &_{{$obj.StructName}}Mgr{_BaseMgr: &_BaseMgr{DB: db.Table("{{$obj.TableName}}"), isRelated: globalIsRelated,ctx:ctx,cancel:cancel,timeout:timeout}}
} }
// GetTableName get sql table name.获取数据库名字 // GetTableName get sql table name.获取数据库名字
@@ -199,12 +203,12 @@ func (obj *_{{$obj.StructName}}Mgr) GetBatchFrom{{$oem.ColStructName}}({{CapLowe
` `
genPreload = `if err == nil && obj.isRelated { {{range $obj := .}}{{if $obj.IsMulti}} genPreload = `if err == nil && obj.isRelated { {{range $obj := .}}{{if $obj.IsMulti}}
if err = obj.WithContext(obj.ctx).Table("{{$obj.ForeignkeyTableName}}").Where("{{$obj.ForeignkeyCol}} = ?", result.{{$obj.ColStructName}}).Find(&result.{{$obj.ForeignkeyStructName}}List).Error;err != nil { // {{$obj.Notes}} if err = obj.New().Table("{{$obj.ForeignkeyTableName}}").Where("{{$obj.ForeignkeyCol}} = ?", result.{{$obj.ColStructName}}).Find(&result.{{$obj.ForeignkeyStructName}}List).Error;err != nil { // {{$obj.Notes}}
if err != gorm.ErrRecordNotFound { // 非 没找到 if err != gorm.ErrRecordNotFound { // 非 没找到
return return
} }
} {{else}} } {{else}}
if err = obj.WithContext(obj.ctx).Table("{{$obj.ForeignkeyTableName}}").Where("{{$obj.ForeignkeyCol}} = ?", result.{{$obj.ColStructName}}).Find(&result.{{$obj.ForeignkeyStructName}}).Error; err != nil { // {{$obj.Notes}} if err = obj.New().Table("{{$obj.ForeignkeyTableName}}").Where("{{$obj.ForeignkeyCol}} = ?", result.{{$obj.ColStructName}}).Find(&result.{{$obj.ForeignkeyStructName}}).Error; err != nil { // {{$obj.Notes}}
if err != gorm.ErrRecordNotFound { // 非 没找到 if err != gorm.ErrRecordNotFound { // 非 没找到
return return
} }
@@ -212,12 +216,12 @@ func (obj *_{{$obj.StructName}}Mgr) GetBatchFrom{{$oem.ColStructName}}({{CapLowe
` `
genPreloadMulti = `if err == nil && obj.isRelated { genPreloadMulti = `if err == nil && obj.isRelated {
for i := 0; i < len(results); i++ { {{range $obj := .}}{{if $obj.IsMulti}} for i := 0; i < len(results); i++ { {{range $obj := .}}{{if $obj.IsMulti}}
if err = obj.WithContext(obj.ctx).Table("{{$obj.ForeignkeyTableName}}").Where("{{$obj.ForeignkeyCol}} = ?", results[i].{{$obj.ColStructName}}).Find(&results[i].{{$obj.ForeignkeyStructName}}List).Error;err != nil { // {{$obj.Notes}} if err = obj.New().Table("{{$obj.ForeignkeyTableName}}").Where("{{$obj.ForeignkeyCol}} = ?", results[i].{{$obj.ColStructName}}).Find(&results[i].{{$obj.ForeignkeyStructName}}List).Error;err != nil { // {{$obj.Notes}}
if err != gorm.ErrRecordNotFound { // 非 没找到 if err != gorm.ErrRecordNotFound { // 非 没找到
return return
} }
} {{else}} } {{else}}
if err = obj.WithContext(obj.ctx).Table("{{$obj.ForeignkeyTableName}}").Where("{{$obj.ForeignkeyCol}} = ?", results[i].{{$obj.ColStructName}}).Find(&results[i].{{$obj.ForeignkeyStructName}}).Error; err != nil { // {{$obj.Notes}} if err = obj.New().Table("{{$obj.ForeignkeyTableName}}").Where("{{$obj.ForeignkeyCol}} = ?", results[i].{{$obj.ColStructName}}).Find(&results[i].{{$obj.ForeignkeyStructName}}).Error; err != nil { // {{$obj.Notes}}
if err != gorm.ErrRecordNotFound { // 非 没找到 if err != gorm.ErrRecordNotFound { // 非 没找到
return return
} }

View File

@@ -36,16 +36,18 @@ func GetGorm(dataSourceName string) *gorm.DB {
return db.Debug() return db.Debug()
} }
func NewDB(){ // func NewDB(){
db, _ := gorm.Open(...) // db, _ := gorm.Open(...)
db.Model(&AAA).Where("aaa = ?", 2) // db.Model(&AAA).Where("aaa = ?", 2)
} // CallFunc(db)
// }
func CallFunc(db *gorm.DB){ // func CallFunc(db *gorm.DB){
// select a... // // select a...
var bbb BBB // var bbb BBB
db.Table("bbb").Where("bbb = ?", 2).Find() // db.Table("bbb").Where("bbb = ?", 2).Find(&bbb)// in this case aaa = ? valid
} // // in this func how to us db to query BBB
// }
// TestFuncGet 测试条件获(Get/Gets) // TestFuncGet 测试条件获(Get/Gets)
func TestFuncGet(t *testing.T) { func TestFuncGet(t *testing.T) {

View File

@@ -5,7 +5,6 @@ import (
"time" "time"
"gorm.io/gorm" "gorm.io/gorm"
"gorm.io/gorm/clause"
) )
var globalIsRelated bool = true // 全局预加载 var globalIsRelated bool = true // 全局预加载
@@ -50,10 +49,9 @@ func (obj *_BaseMgr) SetIsRelated(b bool) {
obj.isRelated = b obj.isRelated = b
} }
func (obj *_BaseMgr) new() *gorm.DB { // New new gorm.新gorm
newDb := obj.DB.WithContext(obj.ctx) func (obj *_BaseMgr) New() *gorm.DB {
newDb.Statement.Clauses = make(map[string]clause.Clause) return obj.DB.Session(&gorm.Session{WithConditions: false, Context: obj.ctx})
return newDb
} }
type options struct { type options struct {

View File

@@ -3,7 +3,6 @@ package model
import ( import (
"context" "context"
"fmt" "fmt"
"time"
"gorm.io/gorm" "gorm.io/gorm"
) )
@@ -17,9 +16,8 @@ func AccountMgr(db *gorm.DB) *_AccountMgr {
if db == nil { if db == nil {
panic(fmt.Errorf("AccountMgr need init by db")) panic(fmt.Errorf("AccountMgr need init by db"))
} }
timeout := 10 * time.Second ctx, cancel := context.WithCancel(context.Background())
ctx, cancel := context.WithTimeout(context.Background(), timeout) return &_AccountMgr{_BaseMgr: &_BaseMgr{DB: db.Table("account"), isRelated: globalIsRelated, ctx: ctx, cancel: cancel, timeout: -1}}
return &_AccountMgr{_BaseMgr: &_BaseMgr{DB: db.Table("account"), isRelated: globalIsRelated, ctx: ctx, cancel: cancel, timeout: timeout}}
} }
// GetTableName get sql table name.获取数据库名字 // GetTableName get sql table name.获取数据库名字
@@ -31,7 +29,7 @@ func (obj *_AccountMgr) GetTableName() string {
func (obj *_AccountMgr) Get() (result Account, err error) { func (obj *_AccountMgr) Get() (result Account, err error) {
err = obj.DB.Table(obj.GetTableName()).Find(&result).Error err = obj.DB.Table(obj.GetTableName()).Find(&result).Error
if err == nil && obj.isRelated { if err == nil && obj.isRelated {
if err = obj.new().Table("user").Where("user_id = ?", result.UserID).Find(&result.User).Error; err != nil { // if err = obj.New().Table("user").Where("user_id = ?", result.UserID).Find(&result.User).Error; err != nil { //
if err != gorm.ErrRecordNotFound { // 非 没找到 if err != gorm.ErrRecordNotFound { // 非 没找到
return return
} }
@@ -46,7 +44,7 @@ func (obj *_AccountMgr) Gets() (results []*Account, err error) {
err = obj.DB.Table(obj.GetTableName()).Find(&results).Error err = obj.DB.Table(obj.GetTableName()).Find(&results).Error
if err == nil && obj.isRelated { if err == nil && obj.isRelated {
for i := 0; i < len(results); i++ { for i := 0; i < len(results); i++ {
if err = obj.Table("user").Where("user_id = ?", results[i].UserID).Find(&results[i].User).Error; err != nil { // if err = obj.New().Table("user").Where("user_id = ?", results[i].UserID).Find(&results[i].User).Error; err != nil { //
if err != gorm.ErrRecordNotFound { // 非 没找到 if err != gorm.ErrRecordNotFound { // 非 没找到
return return
} }
@@ -94,7 +92,7 @@ func (obj *_AccountMgr) GetByOption(opts ...Option) (result Account, err error)
err = obj.DB.Table(obj.GetTableName()).Where(options.query).Find(&result).Error err = obj.DB.Table(obj.GetTableName()).Where(options.query).Find(&result).Error
if err == nil && obj.isRelated { if err == nil && obj.isRelated {
if err = obj.WithContext(obj.ctx).Table("user").Where("user_id = ?", result.UserID).Find(&result.User).Error; err != nil { // if err = obj.New().Table("user").Where("user_id = ?", result.UserID).Find(&result.User).Error; err != nil { //
if err != gorm.ErrRecordNotFound { // 非 没找到 if err != gorm.ErrRecordNotFound { // 非 没找到
return return
} }
@@ -116,7 +114,7 @@ func (obj *_AccountMgr) GetByOptions(opts ...Option) (results []*Account, err er
err = obj.DB.Table(obj.GetTableName()).Where(options.query).Find(&results).Error err = obj.DB.Table(obj.GetTableName()).Where(options.query).Find(&results).Error
if err == nil && obj.isRelated { if err == nil && obj.isRelated {
for i := 0; i < len(results); i++ { for i := 0; i < len(results); i++ {
if err = obj.WithContext(obj.ctx).Table("user").Where("user_id = ?", results[i].UserID).Find(&results[i].User).Error; err != nil { // if err = obj.New().Table("user").Where("user_id = ?", results[i].UserID).Find(&results[i].User).Error; err != nil { //
if err != gorm.ErrRecordNotFound { // 非 没找到 if err != gorm.ErrRecordNotFound { // 非 没找到
return return
} }
@@ -132,7 +130,7 @@ func (obj *_AccountMgr) GetByOptions(opts ...Option) (results []*Account, err er
func (obj *_AccountMgr) GetFromID(id int) (result Account, err error) { func (obj *_AccountMgr) GetFromID(id int) (result Account, err error) {
err = obj.DB.Table(obj.GetTableName()).Where("id = ?", id).Find(&result).Error err = obj.DB.Table(obj.GetTableName()).Where("id = ?", id).Find(&result).Error
if err == nil && obj.isRelated { if err == nil && obj.isRelated {
if err = obj.WithContext(obj.ctx).Table("user").Where("user_id = ?", result.UserID).Find(&result.User).Error; err != nil { // if err = obj.New().Table("user").Where("user_id = ?", result.UserID).Find(&result.User).Error; err != nil { //
if err != gorm.ErrRecordNotFound { // 非 没找到 if err != gorm.ErrRecordNotFound { // 非 没找到
return return
} }
@@ -147,7 +145,7 @@ func (obj *_AccountMgr) GetBatchFromID(ids []int) (results []*Account, err error
err = obj.DB.Table(obj.GetTableName()).Where("id IN (?)", ids).Find(&results).Error err = obj.DB.Table(obj.GetTableName()).Where("id IN (?)", ids).Find(&results).Error
if err == nil && obj.isRelated { if err == nil && obj.isRelated {
for i := 0; i < len(results); i++ { for i := 0; i < len(results); i++ {
if err = obj.WithContext(obj.ctx).Table("user").Where("user_id = ?", results[i].UserID).Find(&results[i].User).Error; err != nil { // if err = obj.New().Table("user").Where("user_id = ?", results[i].UserID).Find(&results[i].User).Error; err != nil { //
if err != gorm.ErrRecordNotFound { // 非 没找到 if err != gorm.ErrRecordNotFound { // 非 没找到
return return
} }
@@ -162,7 +160,7 @@ func (obj *_AccountMgr) GetFromAccountID(accountID int) (results []*Account, err
err = obj.DB.Table(obj.GetTableName()).Where("account_id = ?", accountID).Find(&results).Error err = obj.DB.Table(obj.GetTableName()).Where("account_id = ?", accountID).Find(&results).Error
if err == nil && obj.isRelated { if err == nil && obj.isRelated {
for i := 0; i < len(results); i++ { for i := 0; i < len(results); i++ {
if err = obj.WithContext(obj.ctx).Table("user").Where("user_id = ?", results[i].UserID).Find(&results[i].User).Error; err != nil { // if err = obj.New().Table("user").Where("user_id = ?", results[i].UserID).Find(&results[i].User).Error; err != nil { //
if err != gorm.ErrRecordNotFound { // 非 没找到 if err != gorm.ErrRecordNotFound { // 非 没找到
return return
} }
@@ -177,7 +175,7 @@ func (obj *_AccountMgr) GetBatchFromAccountID(accountIDs []int) (results []*Acco
err = obj.DB.Table(obj.GetTableName()).Where("account_id IN (?)", accountIDs).Find(&results).Error err = obj.DB.Table(obj.GetTableName()).Where("account_id IN (?)", accountIDs).Find(&results).Error
if err == nil && obj.isRelated { if err == nil && obj.isRelated {
for i := 0; i < len(results); i++ { for i := 0; i < len(results); i++ {
if err = obj.WithContext(obj.ctx).Table("user").Where("user_id = ?", results[i].UserID).Find(&results[i].User).Error; err != nil { // if err = obj.New().Table("user").Where("user_id = ?", results[i].UserID).Find(&results[i].User).Error; err != nil { //
if err != gorm.ErrRecordNotFound { // 非 没找到 if err != gorm.ErrRecordNotFound { // 非 没找到
return return
} }
@@ -192,7 +190,7 @@ func (obj *_AccountMgr) GetFromUserID(userID int) (results []*Account, err error
err = obj.DB.Table(obj.GetTableName()).Where("user_id = ?", userID).Find(&results).Error err = obj.DB.Table(obj.GetTableName()).Where("user_id = ?", userID).Find(&results).Error
if err == nil && obj.isRelated { if err == nil && obj.isRelated {
for i := 0; i < len(results); i++ { for i := 0; i < len(results); i++ {
if err = obj.WithContext(obj.ctx).Table("user").Where("user_id = ?", results[i].UserID).Find(&results[i].User).Error; err != nil { // if err = obj.New().Table("user").Where("user_id = ?", results[i].UserID).Find(&results[i].User).Error; err != nil { //
if err != gorm.ErrRecordNotFound { // 非 没找到 if err != gorm.ErrRecordNotFound { // 非 没找到
return return
} }
@@ -207,7 +205,7 @@ func (obj *_AccountMgr) GetBatchFromUserID(userIDs []int) (results []*Account, e
err = obj.DB.Table(obj.GetTableName()).Where("user_id IN (?)", userIDs).Find(&results).Error err = obj.DB.Table(obj.GetTableName()).Where("user_id IN (?)", userIDs).Find(&results).Error
if err == nil && obj.isRelated { if err == nil && obj.isRelated {
for i := 0; i < len(results); i++ { for i := 0; i < len(results); i++ {
if err = obj.WithContext(obj.ctx).Table("user").Where("user_id = ?", results[i].UserID).Find(&results[i].User).Error; err != nil { // if err = obj.New().Table("user").Where("user_id = ?", results[i].UserID).Find(&results[i].User).Error; err != nil { //
if err != gorm.ErrRecordNotFound { // 非 没找到 if err != gorm.ErrRecordNotFound { // 非 没找到
return return
} }
@@ -222,7 +220,7 @@ func (obj *_AccountMgr) GetFromType(_type int) (results []*Account, err error) {
err = obj.DB.Table(obj.GetTableName()).Where("type = ?", _type).Find(&results).Error err = obj.DB.Table(obj.GetTableName()).Where("type = ?", _type).Find(&results).Error
if err == nil && obj.isRelated { if err == nil && obj.isRelated {
for i := 0; i < len(results); i++ { for i := 0; i < len(results); i++ {
if err = obj.WithContext(obj.ctx).Table("user").Where("user_id = ?", results[i].UserID).Find(&results[i].User).Error; err != nil { // if err = obj.New().Table("user").Where("user_id = ?", results[i].UserID).Find(&results[i].User).Error; err != nil { //
if err != gorm.ErrRecordNotFound { // 非 没找到 if err != gorm.ErrRecordNotFound { // 非 没找到
return return
} }
@@ -237,7 +235,7 @@ func (obj *_AccountMgr) GetBatchFromType(_types []int) (results []*Account, err
err = obj.DB.Table(obj.GetTableName()).Where("type IN (?)", _types).Find(&results).Error err = obj.DB.Table(obj.GetTableName()).Where("type IN (?)", _types).Find(&results).Error
if err == nil && obj.isRelated { if err == nil && obj.isRelated {
for i := 0; i < len(results); i++ { for i := 0; i < len(results); i++ {
if err = obj.WithContext(obj.ctx).Table("user").Where("user_id = ?", results[i].UserID).Find(&results[i].User).Error; err != nil { // if err = obj.New().Table("user").Where("user_id = ?", results[i].UserID).Find(&results[i].User).Error; err != nil { //
if err != gorm.ErrRecordNotFound { // 非 没找到 if err != gorm.ErrRecordNotFound { // 非 没找到
return return
} }
@@ -252,7 +250,7 @@ func (obj *_AccountMgr) GetFromName(name string) (results []*Account, err error)
err = obj.DB.Table(obj.GetTableName()).Where("name = ?", name).Find(&results).Error err = obj.DB.Table(obj.GetTableName()).Where("name = ?", name).Find(&results).Error
if err == nil && obj.isRelated { if err == nil && obj.isRelated {
for i := 0; i < len(results); i++ { for i := 0; i < len(results); i++ {
if err = obj.WithContext(obj.ctx).Table("user").Where("user_id = ?", results[i].UserID).Find(&results[i].User).Error; err != nil { // if err = obj.New().Table("user").Where("user_id = ?", results[i].UserID).Find(&results[i].User).Error; err != nil { //
if err != gorm.ErrRecordNotFound { // 非 没找到 if err != gorm.ErrRecordNotFound { // 非 没找到
return return
} }
@@ -267,7 +265,7 @@ func (obj *_AccountMgr) GetBatchFromName(names []string) (results []*Account, er
err = obj.DB.Table(obj.GetTableName()).Where("name IN (?)", names).Find(&results).Error err = obj.DB.Table(obj.GetTableName()).Where("name IN (?)", names).Find(&results).Error
if err == nil && obj.isRelated { if err == nil && obj.isRelated {
for i := 0; i < len(results); i++ { for i := 0; i < len(results); i++ {
if err = obj.WithContext(obj.ctx).Table("user").Where("user_id = ?", results[i].UserID).Find(&results[i].User).Error; err != nil { // if err = obj.New().Table("user").Where("user_id = ?", results[i].UserID).Find(&results[i].User).Error; err != nil { //
if err != gorm.ErrRecordNotFound { // 非 没找到 if err != gorm.ErrRecordNotFound { // 非 没找到
return return
} }
@@ -283,7 +281,7 @@ func (obj *_AccountMgr) GetBatchFromName(names []string) (results []*Account, er
func (obj *_AccountMgr) FetchByPrimaryKey(id int) (result Account, err error) { func (obj *_AccountMgr) FetchByPrimaryKey(id int) (result Account, err error) {
err = obj.DB.Table(obj.GetTableName()).Where("id = ?", id).Find(&result).Error err = obj.DB.Table(obj.GetTableName()).Where("id = ?", id).Find(&result).Error
if err == nil && obj.isRelated { if err == nil && obj.isRelated {
if err = obj.WithContext(obj.ctx).Table("user").Where("user_id = ?", result.UserID).Find(&result.User).Error; err != nil { // if err = obj.New().Table("user").Where("user_id = ?", result.UserID).Find(&result.User).Error; err != nil { //
if err != gorm.ErrRecordNotFound { // 非 没找到 if err != gorm.ErrRecordNotFound { // 非 没找到
return return
} }
@@ -297,7 +295,7 @@ func (obj *_AccountMgr) FetchByPrimaryKey(id int) (result Account, err error) {
func (obj *_AccountMgr) FetchUniqueIndexByAccount(accountID int, userID int) (result Account, err error) { func (obj *_AccountMgr) FetchUniqueIndexByAccount(accountID int, userID int) (result Account, err error) {
err = obj.DB.Table(obj.GetTableName()).Where("account_id = ? AND user_id = ?", accountID, userID).Find(&result).Error err = obj.DB.Table(obj.GetTableName()).Where("account_id = ? AND user_id = ?", accountID, userID).Find(&result).Error
if err == nil && obj.isRelated { if err == nil && obj.isRelated {
if err = obj.WithContext(obj.ctx).Table("user").Where("user_id = ?", result.UserID).Find(&result.User).Error; err != nil { // if err = obj.New().Table("user").Where("user_id = ?", result.UserID).Find(&result.User).Error; err != nil { //
if err != gorm.ErrRecordNotFound { // 非 没找到 if err != gorm.ErrRecordNotFound { // 非 没找到
return return
} }
@@ -312,7 +310,7 @@ func (obj *_AccountMgr) FetchIndexByTp(userID int, _type int) (results []*Accoun
err = obj.DB.Table(obj.GetTableName()).Where("user_id = ? AND type = ?", userID, _type).Find(&results).Error err = obj.DB.Table(obj.GetTableName()).Where("user_id = ? AND type = ?", userID, _type).Find(&results).Error
if err == nil && obj.isRelated { if err == nil && obj.isRelated {
for i := 0; i < len(results); i++ { for i := 0; i < len(results); i++ {
if err = obj.WithContext(obj.ctx).Table("user").Where("user_id = ?", results[i].UserID).Find(&results[i].User).Error; err != nil { // if err = obj.New().Table("user").Where("user_id = ?", results[i].UserID).Find(&results[i].User).Error; err != nil { //
if err != gorm.ErrRecordNotFound { // 非 没找到 if err != gorm.ErrRecordNotFound { // 非 没找到
return return
} }

View File

@@ -3,7 +3,6 @@ package model
import ( import (
"context" "context"
"fmt" "fmt"
"time"
"gorm.io/gorm" "gorm.io/gorm"
) )
@@ -17,9 +16,8 @@ func UserMgr(db *gorm.DB) *_UserMgr {
if db == nil { if db == nil {
panic(fmt.Errorf("UserMgr need init by db")) panic(fmt.Errorf("UserMgr need init by db"))
} }
timeout := 10 * time.Second ctx, cancel := context.WithCancel(context.Background())
ctx, cancel := context.WithTimeout(context.Background(), timeout) return &_UserMgr{_BaseMgr: &_BaseMgr{DB: db.Table("user"), isRelated: globalIsRelated, ctx: ctx, cancel: cancel, timeout: -1}}
return &_UserMgr{_BaseMgr: &_BaseMgr{DB: db.Table("user"), isRelated: globalIsRelated, ctx: ctx, cancel: cancel, timeout: timeout}}
} }
// GetTableName get sql table name.获取数据库名字 // GetTableName get sql table name.获取数据库名字