Golang
在开始之前,先聊一下事务。
什么是数据库事务?
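举个例子:在银行系统中,张三要向李四转账 10 元。这个事务涉及 `accounts`、`entries` 和 `transfers` 三张表,包含以下 5 个操作:
1. 在 `transfers` 表中创建一条金额为 10 的转账记录;
2. 在 `entries` 表中为张三创建一条金额为 -10 的账目记录;
3. 在 `entries` 表中为李四创建一条金额为 +10 的账目记录;
4. 在 `accounts` 表中将张三的余额减去 10;
5. 在 `accounts` 表中将李四的余额加上 10。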
为什么需要使用数据库事务?
主要原因有2个:
- 我们希望这个操作单元可靠且一致,即使系统出现某些故障的情况下也如此。
- 在并发访问数据库的多个程序之间提供隔离。
ACID
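- A(Atomicity,原子性):事务中的操作要么全部成功完成,要么整体失败;失败时数据库保持不变。
- C(Consistency,一致性):事务执行后,数据库状态必须依然有效,满足所有预定义的约束规则。
- I(Isolation,隔离性):并发执行的事务之间互不影响。
- D(Durability,持久性):成功提交的事务所写入的数据必须被持久保存,即使系统发生故障也不会丢失。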
SQL中如何操作数据库事务?
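在 SQL 中,用 `BEGIN` 语句开启事务,然后执行一系列普通的 SQL 查询。如果所有查询都成功,就执行 `COMMIT` 提交事务,使数据库的变更永久生效;如果有任何查询失败,就执行 `ROLLBACK` 回滚事务,撤销该事务中所有查询所做的修改。下面看看如何在 Golang 中实现这一过程。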
在Golang中实现数据库事务
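在 sqlc 生成代码所在的目录下新建 `store.go` 文件,并在其中定义一个新的 `Store` 结构体。
sqlc 生成的 `Queries` 结构体中,每个查询方法只对 1 张表执行 1 个操作,因此 `Queries` 本身并不支持事务。我们需要把 `Queries` 嵌入到 `Store` 中来扩展它的功能。这种方式叫作组合(composition),是 Go 中扩展结构体功能的推荐做法,用来替代其他语言中的继承。
通过嵌入,`Queries` 提供的所有查询方法都可以直接在 `Store` 上调用。此外,`Store` 还需要持有一个 `sql.DB` 对象,因为创建新的数据库事务需要用到 `sql.DB`: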
// Store provides all functions to run db queries individually, through the
// embedded *Queries, as well as combined inside a database transaction
// (see execTx).
type Store struct {
	// Embedded sqlc query set; NewStore binds it to db, so its methods run
	// outside any transaction.
	*Queries
	// Database handle used to begin new transactions.
	db *sql.DB
}
接下来编写 `NewStore()` 函数,用于创建一个新的 `Store` 对象:
// NewStore wraps db in a Store. The embedded Queries is built on the same
// handle, so individual queries can be called on the Store directly, while db
// itself is kept for starting transactions.
func NewStore(db *sql.DB) *Store {
	s := &Store{db: db}
	s.Queries = New(db)
	return s
}
它接收一个 `*sql.DB` 作为入参,返回一个新的 `Store` 对象:`db` 字段直接保存传入的连接,`Queries` 字段则通过以 `db` 为参数调用 `New()` 函数来创建。`New()` 是 sqlc 生成的函数,它返回一个 `Queries` 对象。
接下来,再新建个方法来执行通用的数据库事务,思路很简单,它将一个上下文和一个回调函数作为入参,然后它将启动一个新的数据库事务。
// execTx runs fn inside a new database transaction.
//
// fn receives a Queries bound to the transaction, so every query it issues is
// part of that transaction. If fn returns an error, the transaction is rolled
// back and fn's error is returned. If the rollback also fails, both errors are
// combined, and fn's error stays reachable via errors.Is / errors.As. If fn
// panics, the deferred rollback still runs before the panic propagates.
// Otherwise the transaction is committed and the commit error, if any, is
// returned.
func (store *Store) execTx(ctx context.Context, fn func(*Queries) error) error {
	tx, err := store.db.BeginTx(ctx, nil)
	if err != nil {
		return err
	}
	// Safety net for a panic inside fn, which would otherwise skip both the
	// explicit Rollback and Commit and leave the transaction open. After a
	// Commit or Rollback this call just returns sql.ErrTxDone, which is
	// deliberately ignored.
	defer func() { _ = tx.Rollback() }()

	q := New(tx)
	if err = fn(q); err != nil {
		if rbErr := tx.Rollback(); rbErr != nil {
			// %w keeps fn's error inspectable; the rollback error is context.
			return fmt.Errorf("tx err: %w, rb err: %v", err, rbErr)
		}
		return err
	}
	return tx.Commit()
}
解释一下上面这段代码:
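- `tx, err := store.db.BeginTx(ctx, nil)`:开启一个新的数据库事务。第一个参数是上下文 `ctx`;第二个参数是可选的 `sql.TxOptions`,可以为该事务指定自定义的隔离级别。传 `nil` 则使用数据库服务器的默认隔离级别(Postgres 默认为 read committed)。如果 `err != nil`,说明 `BeginTx()` 失败,直接返回该错误。
- `q := New(tx)`:用事务对象 `tx` 调用 `New()`,得到一个新的 `Queries` 对象。这与 `NewStore()` 中调用的是同一个 `New()` 函数,区别仅在于 `NewStore()` 传入的是 `sql.DB`,而这里传入的是 `sql.Tx`。之所以可行,是因为 `New()` 的参数类型是 `DBTX` 接口,而 `sql.Tx` 和 `sql.DB` 都实现了该接口。这样,通过 `q` 执行的所有查询都运行在同一个事务中。
- `err = fn(q)`:用这个绑定了事务的 `Queries` 调用回调函数,并获取它返回的 `error`。
- `if err != nil {`:如果回调返回了错误,就调用 `tx.Rollback()` 回滚事务。回滚本身也可能失败,其错误记为 `rbErr`。若 `rbErr` 不为 `nil`,就需要同时报告这两个错误,因此用 `fmt.Errorf` 把它们合并成 1 个错误返回;若回滚成功,则通过 `return err` 只返回原始的事务错误。
- 最后,如果事务中的所有操作都执行成功,就调用 `tx.Commit()` 提交事务,并把提交时产生的错误(如果有)返回给调用者。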
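注意 `execTx()` 是未导出的(以小写字母 e 开头),因为我们不希望外部包直接调用它;对外,我们会为每个具体的事务提供一个导出的方法。
接下来实现 `TransferTx()` 方法,用来执行转账事务。先定义它的入参结构体 `TransferTxParams` 和返回结果结构体 `TransferTxResult`: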
// TransferTxParams contains the input parameters of the transfer transaction.
type TransferTxParams struct {
	// ID of the account the amount is taken from.
	FromAccountID int64 `json:"from_account_id"`
	// ID of the account the amount is added to.
	ToAccountID int64 `json:"to_account_id"`
	// Amount to move; TransferTx records it as -Amount on the source
	// account's entry and +Amount on the destination account's entry.
	Amount int64 `json:"amount"`
}
// TransferTxResult is the result of the transfer transaction.
type TransferTxResult struct {
	// The created transfer record.
	Transfer Transfer `json:"transfer"`
	// Source account. NOTE(review): TransferTx does not populate this yet
	// (the balance update is still a TODO there).
	FromAccount Account `json:"from_account"`
	// Destination account. NOTE(review): not populated by TransferTx yet.
	ToAccount Account `json:"to_account"`
	// Entry recording the amount leaving the source account (negative amount).
	FromEntry Entry `json:"from_entry"`
	// Entry recording the amount arriving at the destination account
	// (positive amount).
	ToEntry Entry `json:"to_entry"`
}
然后实现 `TransferTx()` 方法:
// TransferTx moves arg.Amount from arg.FromAccountID to arg.ToAccountID inside
// a single database transaction.
//
// Within the transaction it creates the transfer record, an entry of
// -arg.Amount for the source account and an entry of +arg.Amount for the
// destination account. Account balances are not updated yet, so
// FromAccount and ToAccount in the result are left as zero values.
//
// On error the partially built result is discarded and a zero
// TransferTxResult is returned alongside the error.
func (store *Store) TransferTx(ctx context.Context, arg TransferTxParams) (TransferTxResult, error) {
	var result TransferTxResult

	err := store.execTx(ctx, func(q *Queries) error {
		var err error

		// 1. Record the transfer itself.
		result.Transfer, err = q.CreateTransfer(ctx, CreateTransferParams{
			FromAccountID: arg.FromAccountID,
			ToAccountID:   arg.ToAccountID,
			Amount:        arg.Amount,
		})
		if err != nil {
			return err
		}

		// 2. Entry for the money leaving the source account.
		result.FromEntry, err = q.CreateEntry(ctx, CreateEntryParams{
			AccountID: arg.FromAccountID,
			Amount:    -arg.Amount,
		})
		if err != nil {
			return err
		}

		// 3. Entry for the money arriving at the destination account.
		result.ToEntry, err = q.CreateEntry(ctx, CreateEntryParams{
			AccountID: arg.ToAccountID,
			Amount:    arg.Amount,
		})
		if err != nil {
			return err
		}

		// TODO: steps 4 and 5 - update both account balances and fill in
		// result.FromAccount / result.ToAccount.
		return nil
	})
	if err != nil {
		return TransferTxResult{}, err
	}
	return result, nil
}
其中 `q.CreateTransfer` 和 `q.CreateEntry` 是 sqlc 根据我们编写的 SQL 查询生成的方法(通过 `make sqlc` 生成)。
上面的代码按照转账拆解出的步骤逐步实现。目前完成了前 3 步(创建转账记录以及两条账目记录);第 4、5 步(更新两个账户的余额)暂时留作 TODO,后面再实现。
单元测试
在编写 `store_test.go` 之前,先修改 `main_test.go`,让测试代码能拿到底层的数据库连接:
package db
import (
"database/sql"
"log"
"os"
"testing"
_ "github.com/lib/pq"
)
// Connection settings for the local Postgres instance the tests run against.
const (
	dbDriver = "postgres"
	dbSource = "postgresql://root:123456@localhost:5432/simple_bank?sslmode=disable"
)

// Package-level fixtures assigned in TestMain and shared by every test.
var (
	// sqlc query set bound to testDB.
	testQueries *Queries
	// Raw database handle; tests that need transactions build a Store on it.
	testDB *sql.DB
)
// TestMain is the entry point for every test in this package. It opens the
// shared database handle, fails fast if the server is unreachable, builds the
// package-level fixtures, runs the tests and exits with their status code.
func TestMain(m *testing.M) {
	var err error
	testDB, err = sql.Open(dbDriver, dbSource)
	if err != nil {
		log.Fatalf("cannot connect to db: %v", err)
	}
	// sql.Open may only validate its arguments without contacting the server,
	// so ping now: an unreachable database then fails here with one clear
	// message instead of inside every test that runs a query.
	if err = testDB.Ping(); err != nil {
		log.Fatalf("cannot connect to db: %v", err)
	}

	testQueries = New(testDB)

	code := m.Run()
	// os.Exit does not run deferred calls, so release the handle explicitly.
	// A close error cannot change the outcome of tests that already ran.
	_ = testDB.Close()
	os.Exit(code)
}
这里新增了一个全局变量 `var testDB *sql.DB`,因为 `store_test.go` 中创建 `Store` 时需要用到它。同时把 `conn, err := sql.Open(dbDriver, dbSource)` 改为 `testDB, err = sql.Open(dbDriver, dbSource)`,把 `testQueries = New(conn)` 改为 `testQueries = New(testDB)`。
然后编写 `store_test.go`:
package db
import (
"context"
"testing"
"github.com/stretchr/testify/require"
)
// TestTransferTx runs n concurrent transfers of the same amount from account1
// to account2. It checks the transfer and entry records each transaction
// returns, and that those records can be read back from the database.
func TestTransferTx(t *testing.T) {
	store := NewStore(testDB)
	ctx := context.Background()

	account1 := createRandomAccount(t)
	account2 := createRandomAccount(t)

	n := 5
	amount := int64(10)

	// Unbuffered on purpose. Each goroutine sends its error and then its
	// result, and the loop below always receives one error followed by one
	// result. So the pair received in one iteration comes from the same
	// goroutine. (If a check fails early, goroutines still blocked on send
	// are left behind until the test binary exits.)
	errs := make(chan error)
	results := make(chan TransferTxResult)

	for i := 0; i < n; i++ {
		go func() {
			result, err := store.TransferTx(ctx, TransferTxParams{
				FromAccountID: account1.ID,
				ToAccountID:   account2.ID,
				Amount:        amount,
			})
			// Report back instead of calling require here: FailNow must be
			// called from the goroutine running the test.
			errs <- err
			results <- result
		}()
	}

	// Check results.
	for i := 0; i < n; i++ {
		err := <-errs
		require.NoError(t, err)

		result := <-results
		require.NotEmpty(t, result)

		// Check the transfer record.
		transfer := result.Transfer
		require.NotEmpty(t, transfer)
		require.Equal(t, account1.ID, transfer.FromAccountID)
		require.Equal(t, account2.ID, transfer.ToAccountID)
		require.Equal(t, amount, transfer.Amount)
		require.NotZero(t, transfer.ID)
		require.NotZero(t, transfer.CreatedAt)

		_, err = store.GetTransfer(ctx, transfer.ID)
		require.NoError(t, err)

		// Check the entry of the source account.
		fromEntry := result.FromEntry
		require.NotEmpty(t, fromEntry)
		require.Equal(t, account1.ID, fromEntry.AccountID)
		require.Equal(t, -amount, fromEntry.Amount)
		require.NotZero(t, fromEntry.ID)
		require.NotZero(t, fromEntry.CreatedAt)

		_, err = store.GetEntry(ctx, fromEntry.ID)
		require.NoError(t, err)

		// Check the entry of the destination account.
		toEntry := result.ToEntry
		require.NotEmpty(t, toEntry)
		require.Equal(t, account2.ID, toEntry.AccountID)
		require.Equal(t, amount, toEntry.Amount)
		require.NotZero(t, toEntry.ID)
		require.NotZero(t, toEntry.CreatedAt)

		_, err = store.GetEntry(ctx, toEntry.ID)
		require.NoError(t, err)

		// TODO: check the updated account balances.
	}
}
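数据库事务必须谨慎处理,最好的测试办法是用多个 goroutine 并发地执行它。
测试中先创建两个随机账户 `account1` 和 `account2`,然后启动 `n := 5` 个并发的 goroutine,每个都从 `account1` 向 `account2` 转账 `amount`(10)。
用一个 `for i := 0; i < n; i++ {` 循环来启动这些 goroutine: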
// Start n goroutines; each one performs a single transfer
// (the body is elided in this excerpt).
for i := 0; i < n; i++ {
	go func() {
	}()
}
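在 `go func() {` 中调用 `store.TransferTx()`。注意,不能在这里直接用 testify 的 `require` 检查结果:这个函数运行在另一个 goroutine 中,与运行 `TestTransferTx` 的 goroutine 不同,即使条件不满足,也无法保证能让整个测试停下来。
正确的做法是把错误和结果通过 channel 发送回运行测试的主 goroutine,在那里统一检查。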
// One channel carries each goroutine's error, the other its TransferTxResult.
errs := make(chan error)
results := make(chan TransferTxResult)
为此创建两个 channel:一个用于接收错误,一个用于接收 `TransferTxResult`。channel 用内置的 `make` 函数创建。然后在 `go func() {` 内部,把 `err` 发送到 `errs`,把 `result` 发送到 `results`:
// Inside the goroutine: send the error first, then the result.
errs <- err
results <- result
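发送数据时,channel 写在箭头运算符 `<-` 的左边,要发送的数据写在右边。
在外层循环中,再从 channel 接收这些错误和结果并进行检查。接收数据时使用同样的 `<-` 运算符,只是 channel 写在箭头右边,接收数据的变量写在左边,例如 `err := <-errs`。
最后点击 `run test` 运行这个测试,或点击 `run package test` 运行整个包的所有测试。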