上一节,我们学习了如何实现一个简单的转账事务。但是,我们还没有实现更新账户余额的操作,因为它稍复杂一些,需要小心处理并发事务以避免死锁。
本节将实现这个功能,顺便学习一下数据库锁,以及如何调试死锁。(内容有点硬核,需要耐心学习,最好自己手动操作一遍,以便深入理解。)
测试驱动开发(TDD)
我们采用 TDD 的方式:先写测试,再实现功能,让测试通过。
所以先修改 `store_test.go`。回顾一下,上一节在 `store.go` 的 `TransferTx` 中留下了这个 TODO:
// 3. 为`ToAccount`创建一个账目记录,金额为`+10`
result.ToEntry, err = q.CreateEntry(ctx, CreateEntryParams{
AccountID: arg.ToAccountID,
Amount: arg.Amount,
})
if err != nil {
return err
}
// TODO: 更新账户余额操作后面再做
要完成这个单元测试,需要先检查钱是从哪转出来的,然后,检查钱转到哪个账户里面去了。
// 首先,检查钱是从哪转出来的,
fromAccount := result.FromAccount
require.NotEmpty(t, fromAccount)
require.Equal(t, account1.ID, fromAccount.ID)
// 然后,检查钱转到哪个账户里面去了
toAccount := result.ToAccount
require.NotEmpty(t, toAccount)
require.Equal(t, account2.ID, toAccount.ID)
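转出的金额是 `account1.Balance - fromAccount.Balance`,转入的金额是 `toAccount.Balance - account2.Balance`,两者应该相等,并且都是每笔交易金额 `amount` 的整数倍。
设 k = 转出的金额 / `amount`,那么 k 必须满足 1 ≤ k ≤ n(n 是并发转账的次数),并且每笔交易得到的 k 互不相同。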
// 转账方:转出的金额
diff1 := account1.Balance - fromAccount.Balance
// 收钱方:转入的金额
diff2 := toAccount.Balance - account2.Balance
// 这两个值应该相同
require.Equal(t, diff1, diff2)
// 转出来的钱应该大于0
require.True(t, diff1 > 0)
// 转出的金额应该可以被每笔交易的金额整除
require.True(t, diff1%amount == 0)
最后,在 for 循环外面,检查这两个账户的最终余额:
// 最后在 for 循环外面,检查这两个账户的最终余额
updateAccount1, err := testQueries.GetAccount(context.Background(), account1.ID)
require.NoError(t, err)
updateAccount2, err := testQueries.GetAccount(context.Background(), account2.ID)
require.NoError(t, err)
fmt.Println(">> after:", updateAccount1.Balance, updateAccount2.Balance)
// account1的余额减去转账次数乘以每次转账的金额,必须等于最终的余额
require.Equal(t, account1.Balance-int64(n)*amount, updateAccount1.Balance)
require.Equal(t, account2.Balance+int64(n)*amount, updateAccount2.Balance)
完整的代码如下:
package db
import (
"context"
"fmt"
"testing"
"github.com/stretchr/testify/require"
)
// TestTransferTx runs n concurrent transfers of the same amount from account1
// to account2 through Store.TransferTx. For every transaction it checks the
// transfer record, both entry records and the returned account balances, and
// finally it verifies the balances stored in the database.
func TestTransferTx(t *testing.T) {
	store := NewStore(testDB)

	account1 := createRandomAccount(t)
	account2 := createRandomAccount(t)
	fmt.Println(">> before:", account1.Balance, account2.Balance)

	n := 5
	amount := int64(10)

	// Buffered to n so every goroutine can hand over its outcome and exit even
	// if an assertion below stops the test early; with unbuffered channels the
	// remaining goroutines would block forever on their sends. Errors and
	// results are checked independently, so it does not matter that the i-th
	// error and the i-th result may come from different goroutines.
	errs := make(chan error, n)
	results := make(chan TransferTxResult, n)

	for i := 0; i < n; i++ {
		go func() {
			result, err := store.TransferTx(context.Background(), TransferTxParams{
				FromAccountID: account1.ID,
				ToAccountID:   account2.ID,
				Amount:        amount,
			})
			errs <- err
			results <- result
		}()
	}

	// Check the outcome of every transaction.
	existed := make(map[int]bool)
	for i := 0; i < n; i++ {
		err := <-errs
		require.NoError(t, err)

		result := <-results
		require.NotEmpty(t, result)

		// check transfer
		transfer := result.Transfer
		require.NotEmpty(t, transfer)
		require.Equal(t, account1.ID, transfer.FromAccountID)
		require.Equal(t, account2.ID, transfer.ToAccountID)
		require.Equal(t, amount, transfer.Amount)
		require.NotZero(t, transfer.ID)
		require.NotZero(t, transfer.CreatedAt)

		_, err = store.GetTransfer(context.Background(), transfer.ID)
		require.NoError(t, err)

		// check entries
		fromEntry := result.FromEntry
		require.NotEmpty(t, fromEntry)
		require.Equal(t, account1.ID, fromEntry.AccountID)
		require.Equal(t, -amount, fromEntry.Amount)
		require.NotZero(t, fromEntry.ID)
		require.NotZero(t, fromEntry.CreatedAt)

		_, err = store.GetEntry(context.Background(), fromEntry.ID)
		require.NoError(t, err)

		toEntry := result.ToEntry
		require.NotEmpty(t, toEntry)
		require.Equal(t, account2.ID, toEntry.AccountID)
		require.Equal(t, amount, toEntry.Amount)
		require.NotZero(t, toEntry.ID)
		require.NotZero(t, toEntry.CreatedAt)

		_, err = store.GetEntry(context.Background(), toEntry.ID)
		require.NoError(t, err)

		// check accounts: money leaves account1 and arrives in account2
		fromAccount := result.FromAccount
		require.NotEmpty(t, fromAccount)
		require.Equal(t, account1.ID, fromAccount.ID)

		toAccount := result.ToAccount
		require.NotEmpty(t, toAccount)
		require.Equal(t, account2.ID, toAccount.ID)

		// check the balances returned by this transaction
		fmt.Println(">> tx:", fromAccount.Balance, toAccount.Balance)
		diff1 := account1.Balance - fromAccount.Balance // amount taken out of account1
		diff2 := toAccount.Balance - account2.Balance   // amount added to account2
		require.Equal(t, diff1, diff2)
		require.True(t, diff1 > 0)
		require.True(t, diff1%amount == 0) // diff1 must be a whole number of transfers

		// k is how many transfers (including this one) had been applied to
		// account1 when this transaction updated it, so it must lie in [1, n]
		// and be different for every transaction.
		k := int(diff1 / amount)
		require.True(t, k >= 1 && k <= n)
		require.NotContains(t, existed, k)
		existed[k] = true
	}

	// Finally, outside the loop, verify the balances persisted in the database:
	// each account must have moved by exactly n * amount.
	updatedAccount1, err := testQueries.GetAccount(context.Background(), account1.ID)
	require.NoError(t, err)

	updatedAccount2, err := testQueries.GetAccount(context.Background(), account2.ID)
	require.NoError(t, err)

	fmt.Println(">> after:", updatedAccount1.Balance, updatedAccount2.Balance)
	require.Equal(t, account1.Balance-int64(n)*amount, updatedAccount1.Balance)
	require.Equal(t, account2.Balance+int64(n)*amount, updatedAccount2.Balance)
}
运行测试(run test)。接下来,在 `store.go` 中实现更新账户余额的功能。
更新账户的余额(错误的方法)
先从数据库中获取 account,然后更新它的余额:
// 从数据库中获取 account -> 更新账户余额
account1, err := q.GetAccount(ctx, arg.FromAccountID)
if err != nil {
return err
}
result.FromAccount, err = q.UpdateAccount(ctx, UpdateAccountParams{
ID: arg.FromAccountID,
Balance: account1.Balance - arg.Amount,
})
if err != nil {
return err
}
account2, err := q.GetAccount(ctx, arg.ToAccountID)
if err != nil {
return err
}
result.ToAccount, err = q.UpdateAccount(ctx, UpdateAccountParams{
ID: arg.ToAccountID,
Balance: account2.Balance + arg.Amount,
})
if err != nil {
return err
}
run testaccount110166
看一下 `GetAccount` 的 SQL:
-- name: GetAccount :one
SELECT * FROM accounts
WHERE id = $1 LIMIT 1;
这是一条普通的 SELECT,不会阻止其他事务同时读取同一条 account1 记录,因此多个并发事务可能读到相同的旧余额。
下面在 postgres 中演示一下,先进入 postgres 容器中的 psql 控制台:
docker exec -it postgres14 psql -U root -d simple_bank
不加锁查询的情况
让我们在两个不同的终端运行2个并行事务。
在两个终端中分别执行 `BEGIN;` 开启事务,然后在两个终端中都执行 `select * from accounts where id=1;`。两个查询都会立刻返回同一条记录,互不阻塞,也就是说两个事务读到的是同一个余额。最后在两个终端中执行:
ROLLBACK;
加锁查询的情况
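打开两个终端,都执行 `BEGIN;` 开启事务。在终端 1 中执行 `select * from accounts where id=1 for update;`,再在终端 2 中执行同样的 `select * from accounts where id=1 for update;`,这时终端 2 的查询会被阻塞,必须等待终端 1 的事务提交或回滚。接着在终端 1 中执行 `update accounts set balance=500 where id=1;`,然后执行 `COMMIT;`。此时终端 2 的查询立刻返回,并且读到的 balance 是更新后的 500。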
加锁更新账户余额
在 `account.sql` 中新增一条加锁查询的 sql `GetAccountForUpdate`:
-- name: GetAccountForUpdate :one
SELECT * FROM accounts
WHERE id = $1 LIMIT 1
FOR UPDATE;
执行 `make sqlc` 重新生成代码,`account.sql.go` 中会出现 `GetAccountForUpdate` 方法。
然后在 `store.go` 中,用它替换 `GetAccount`:
// 从数据库中获取 account -> 更新账户余额
account1, err := q.GetAccountForUpdate(ctx, arg.FromAccountID)
if err != nil {
return err
}
result.FromAccount, err = q.UpdateAccount(ctx, UpdateAccountParams{
ID: arg.FromAccountID,
Balance: account1.Balance - arg.Amount,
})
if err != nil {
return err
}
account2, err := q.GetAccountForUpdate(ctx, arg.ToAccountID)
if err != nil {
return err
}
result.ToAccount, err = q.UpdateAccount(ctx, UpdateAccountParams{
ID: arg.ToAccountID,
Balance: account2.Balance + arg.Amount,
})
if err != nil {
return err
}
再次运行测试(run test),报错:`deadlock detected`(检测到死锁)。
调试死锁
为了找出 deadlock 的原因,需要知道每个事务执行到了哪一条 sql。为此,我们在 `TransferTx()` 中打印日志,并给每个事务起一个名字。在 `store_test.go` 的 for 循环中生成事务名称:
for i := 0; i < n; i++ {
txName := fmt.Sprintf("tx %d", i+1)
接下来要把 txName 通过 context 传给 `TransferTx`。为此,先在 `store.go` 中声明一个 context 的 key:
// txKey is the context key under which the test stores each transaction's
// name (via context.WithValue), so TransferTx can prefix its debug logs with it.
// NOTE(review): the context.WithValue docs recommend a dedicated key type
// (e.g. `type txKeyType struct{}`) to avoid collisions with other packages.
var txKey = struct{}{}
其中第一个 `{}` 属于类型 `struct{}`(空结构体),第二个 `{}` 表示创建这个类型的一个值。然后在 `store_test.go` 中,以 `txKey` 为键、`txName` 为值,调用 `context.WithValue()` 创建新的上下文:
for i := 0; i < n; i++ {
txName := fmt.Sprintf("tx %d", i+1)
go func() {
ctx := context.WithValue(context.Background(), txKey, txName)
result, err := store.TransferTx(ctx, TransferTxParams{
FromAccountID: account1.ID,
ToAccountID: account2.ID,
Amount: amount,
})
errs <- err
results <- result
}()
}
这样,上下文就可以保存事务名称了。
然后在 `store.go` 的 `TransferTx` 中,从 context 中取出事务名称,并在每一步打印日志:
// txKey is the context key read by TransferTx below to obtain the name of the
// current transaction for its debug logs; store_test.go sets it via context.WithValue.
// NOTE(review): the context.WithValue docs recommend a dedicated key type
// (e.g. `type txKeyType struct{}`) to avoid collisions with other packages.
var txKey = struct{}{}
// TransferTx (debug version) moves arg.Amount from arg.FromAccountID to
// arg.ToAccountID inside one transaction and prints every step prefixed with
// the transaction name stored in ctx under txKey, so the interleaving of
// concurrent transactions can be read from the test output.
func (store *Store) TransferTx(ctx context.Context, arg TransferTxParams) (TransferTxResult, error) {
	var result TransferTxResult

	err := store.execTx(ctx, func(q *Queries) error {
		name := ctx.Value(txKey)
		trace := func(step string) {
			fmt.Println(name, step)
		}

		var txErr error

		// Transfer record for the full amount.
		trace("create transfer")
		result.Transfer, txErr = q.CreateTransfer(ctx, CreateTransferParams{
			FromAccountID: arg.FromAccountID,
			ToAccountID:   arg.ToAccountID,
			Amount:        arg.Amount,
		})
		if txErr != nil {
			return txErr
		}

		// Debit entry (-Amount) for the sender.
		trace("create entry 1")
		result.FromEntry, txErr = q.CreateEntry(ctx, CreateEntryParams{
			AccountID: arg.FromAccountID,
			Amount:    -arg.Amount,
		})
		if txErr != nil {
			return txErr
		}

		// Credit entry (+Amount) for the receiver.
		trace("create entry 2")
		result.ToEntry, txErr = q.CreateEntry(ctx, CreateEntryParams{
			AccountID: arg.ToAccountID,
			Amount:    arg.Amount,
		})
		if txErr != nil {
			return txErr
		}

		// Read the sender row with GetAccountForUpdate, then write its new balance.
		trace("get account 1")
		sender, txErr := q.GetAccountForUpdate(ctx, arg.FromAccountID)
		if txErr != nil {
			return txErr
		}
		trace("update account 1")
		result.FromAccount, txErr = q.UpdateAccount(ctx, UpdateAccountParams{
			ID:      arg.FromAccountID,
			Balance: sender.Balance - arg.Amount,
		})
		if txErr != nil {
			return txErr
		}

		// Same for the receiver row.
		trace("get account 2")
		receiver, txErr := q.GetAccountForUpdate(ctx, arg.ToAccountID)
		if txErr != nil {
			return txErr
		}
		trace("update account 2")
		result.ToAccount, txErr = q.UpdateAccount(ctx, UpdateAccountParams{
			ID:      arg.ToAccountID,
			Balance: receiver.Balance + arg.Amount,
		})
		if txErr != nil {
			return txErr
		}

		return nil
	})

	return result, err
}
调整 `store_test.go` 中并发事务的数量 `n`,然后运行测试(run test),输出如下:
=== RUN TestTransferTx
>> before: 103 980
tx 3 create transfer
tx 3 create entry 1
tx 3 create entry 2
tx 2 create transfer
tx 1 create transfer
tx 3 get account 1
tx 3 update account 1
tx 3 get account 2
tx 3 update account 2
tx 2 create entry 1
tx 1 create entry 1
tx 1 create entry 2
tx 2 create entry 2
tx 1 get account 1
tx 2 get account 1
>> tx: 93 990
tx 2 update account 1
为了分析原因,把 go 代码中一个事务依次执行的 sql 整理如下,方便在 psql 中手动执行:
BEGIN;
-- create transfer
INSERT INTO transfers (from_account_id, to_account_id, amount) VALUES (1, 2, 10) RETURNING *;
-- create entry 1
INSERT INTO entries (account_id, amount) VALUES (1, -10) RETURNING *;
-- create entry 2
INSERT INTO entries (account_id, amount) VALUES (2, 10) RETURNING *;
-- get account 1
SELECT * FROM accounts WHERE id = 1 FOR UPDATE;
-- update account 1
UPDATE accounts SET balance = 90 WHERE id = 1 RETURNING *;
-- get account 2
SELECT * FROM accounts WHERE id = 2 FOR UPDATE;
-- update account 2
UPDATE accounts SET balance = 110 WHERE id = 2 RETURNING *;
为了复现这个问题,打开两个终端进入 postgres 的 psql 控制台,分别模拟两个事务,按照日志中的顺序手动执行上面的 sql:
两个终端都先执行 `BEGIN;`。按照 `tx 3 create transfer`,执行 `INSERT INTO transfers (from_account_id, to_account_id, amount) VALUES (1, 2, 10) RETURNING *;`;按照 `tx 3 create entry 1`,执行 `INSERT INTO entries (account_id, amount) VALUES (1, -10) RETURNING *;`;……当执行到 `tx 3 get account 1` 对应的 `SELECT * FROM accounts WHERE id = 1 FOR UPDATE;` 时,查询被阻塞了。为什么另一个事务的 INSERT transfers 会阻塞对 accounts 表的 SELECT 呢?为了查看 postgres 中的 lock 情况,可以使用 postgres WIKI 上提供的查询锁信息的 sql:
SELECT blocked_locks.pid AS blocked_pid,
blocked_activity.usename AS blocked_user,
blocking_locks.pid AS blocking_pid,
blocking_activity.usename AS blocking_user,
blocked_activity.query AS blocked_statement,
blocking_activity.query AS current_statement_in_blocking_process
FROM pg_catalog.pg_locks blocked_locks
JOIN pg_catalog.pg_stat_activity blocked_activity ON blocked_activity.pid = blocked_locks.pid
JOIN pg_catalog.pg_locks blocking_locks
ON blocking_locks.locktype = blocked_locks.locktype
AND blocking_locks.database IS NOT DISTINCT FROM blocked_locks.database
AND blocking_locks.relation IS NOT DISTINCT FROM blocked_locks.relation
AND blocking_locks.page IS NOT DISTINCT FROM blocked_locks.page
AND blocking_locks.tuple IS NOT DISTINCT FROM blocked_locks.tuple
AND blocking_locks.virtualxid IS NOT DISTINCT FROM blocked_locks.virtualxid
AND blocking_locks.transactionid IS NOT DISTINCT FROM blocked_locks.transactionid
AND blocking_locks.classid IS NOT DISTINCT FROM blocked_locks.classid
AND blocking_locks.objid IS NOT DISTINCT FROM blocked_locks.objid
AND blocking_locks.objsubid IS NOT DISTINCT FROM blocked_locks.objsubid
AND blocking_locks.pid != blocked_locks.pid
JOIN pg_catalog.pg_stat_activity blocking_activity ON blocking_activity.pid = blocking_locks.pid
WHERE NOT blocked_locks.granted;
在 navicat 中执行这条 sql,可以看到被阻塞(blocked)的语句是 `SELECT * FROM accounts WHERE id = 1 FOR UPDATE;`,而阻塞它(blocking)的语句是 `INSERT INTO transfers (from_account_id, to_account_id, amount) VALUES (1, 2, 10) RETURNING *;`。为什么 SELECT 会被 INSERT 阻塞?再用 postgres WIKI 上的另一条 sql 查看所有的锁:
SELECT a.datname,
l.relation::regclass,
l.transactionid,
l.mode,
l.GRANTED,
a.usename,
a.query,
a.query_start,
age(now(), a.query_start) AS "age",
a.pid
FROM pg_stat_activity a
JOIN pg_locks l ON l.pid = a.pid
ORDER BY a.query_start;
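在 navicat 中执行这条 sql。为了看得更清楚,把 `a.datname` 换成 `a.application_name`。
这个 sql 中,
`l.relation::regclass` 是被锁的表(relation),`l.transactionid` 是事务 ID,`l.mode` 是锁模式,另外加上 `l.locktype` 查看锁的类型;`l.GRANTED` 表示锁是否已被授予;`a.usename` 是执行 sql 的用户名,`a.query` 是 sql 语句;`a.query_start` 和 `age(now(), a.query_start) AS "age"` 不再需要,可以去掉;`a.pid` 是进程 ID。再把 `ORDER BY a.query_start` 改为 `ORDER BY a.pid`,并加上 `WHERE a.application_name = 'psql'`,只查看 psql 会话的锁: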
SELECT a.application_name,
l.relation::regclass,
l.transactionid,
l.mode,
l.locktype,
l.GRANTED,
a.usename,
a.query,
a.pid
FROM pg_stat_activity a
JOIN pg_locks l ON l.pid = a.pid
WHERE a.application_name = 'psql'
ORDER BY a.pid;
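在 navicat 中执行后,可以看到有一条锁的 GRANTED 值为 f(即尚未获得锁),它在等待 transactionid 为 1075 的锁;而 transactionid 1075 的锁由 pid 为 6991、正在执行 INSERT 的进程持有。也就是说,`SELECT FROM accounts` 在等待 `INSERT INTO transfers` 所在的事务释放锁。
为什么会这样?看一下建表的 sql: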
ALTER TABLE "transfers" ADD FOREIGN KEY ("from_account_id") REFERENCES "accounts" ("id");
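`transfers` 表的 `from_account_id` 是引用 `accounts` 表 id 的外键。向 transfers 插入记录时,postgres 要对 accounts 中被引用的行加锁,以保证数据一致性;这个锁与 `SELECT ... FOR UPDATE` 加的锁冲突,两个事务互相等待,从而导致 deadlock。实验完毕后,在两个终端中执行 `ROLLBACK;`,然后用 `\q` 退出 psql。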
修复死锁方案1
最简单的办法是去掉外键约束:在 `000001_init_schema.up.sql` 中把下面几行 sql 注释掉:
-- ALTER TABLE "entries" ADD FOREIGN KEY ("account_id") REFERENCES "accounts" ("id");
-- ALTER TABLE "transfers" ADD FOREIGN KEY ("from_account_id") REFERENCES "accounts" ("id");
-- ALTER TABLE "transfers" ADD FOREIGN KEY ("to_account_id") REFERENCES "accounts" ("id");
然后执行 `make migratedown` 和 `make migrateup` 重建数据库,再运行测试(run test),这次测试就能通过了。
但是去掉外键约束并不是好办法,因为外键能保证数据的一致性。所以把 sql 恢复原样,再次执行 `make migratedown` 和 `make migrateup`。
修复死锁方案2
看一下 `account.sql` 中的 `UpdateAccount`,它只更新 balance:
-- name: UpdateAccount :one
UPDATE accounts SET balance = $2 WHERE id = $1
RETURNING *;
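我们只会更新账户余额,不会修改账户的 id(主键)。但是 postgres 并不知道这一点:`SELECT * FROM accounts ... FOR UPDATE` 加的锁会和插入 transfers 时外键检查所需的锁冲突,从而导致 deadlock。
所以只需告诉 postgres 我们不会更新主键,把 `FOR UPDATE` 改为 `FOR NO KEY UPDATE`: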
-- name: GetAccountForUpdate :one
SELECT * FROM accounts
WHERE id = $1 LIMIT 1
FOR NO KEY UPDATE;
make sqlc
之后,清理一下代码,把我们之前加的打印日志都去掉吧,再运行一下单元测试,没问题,通过。
可以看到,日志中记录了每次交易后 2 个账户的余额是怎么变化的。
修复死锁方案3
其实还有更好的办法。目前更新账户余额需要执行两条 sql:先获取账户,再更新余额:
account1, err := q.GetAccountForUpdate(ctx, arg.FromAccountID)
if err != nil {
return err
}
result.FromAccount, err = q.UpdateAccount(ctx, UpdateAccountParams{
ID: arg.FromAccountID,
Balance: account1.Balance - arg.Amount,
})
if err != nil {
return err
}
我们可以只用一条 sql 完成。在 `account.sql` 中新增一个 `AddAccountBalance` 查询。它和 `UpdateAccount` 类似,只是不直接设置新的 balance,而是在原来的 balance 上加上一个值:
-- name: AddAccountBalance :one
UPDATE accounts SET balance = balance + $2 WHERE id = $1
RETURNING *;
执行 `make sqlc`,在生成的 `account.sql.go` 中可以看到参数结构体:
// AddAccountBalanceParams is the sqlc-generated parameter struct for the first
// version of the AddAccountBalance query (`balance = balance + $2 WHERE id = $1`).
type AddAccountBalanceParams struct {
ID int64 `json:"id"`
// NOTE: despite its name, Balance is the amount added to the current
// balance ($2), not the new balance — a misleading name.
Balance int64 `json:"balance"`
}
注意,生成的字段名是 `Balance`,但它实际上是要加到余额上的金额,很容易引起误解。可以使用 sqlc 的 `sqlc.arg()` 为参数指定名字:
-- name: AddAccountBalance :one
UPDATE accounts SET balance = balance + sqlc.arg(amount)
WHERE id = sqlc.arg(id)
RETURNING *;
把 `$2` 替换为 `sqlc.arg(amount)`,把 `$1` 替换为 `sqlc.arg(id)`,然后再执行 `make sqlc`:
// AddAccountBalanceParams is the sqlc-generated parameter struct once the query
// names its parameters with sqlc.arg(amount) and sqlc.arg(id).
type AddAccountBalanceParams struct {
// Amount is added to the current balance (`balance = balance + sqlc.arg(amount)`).
Amount int64 `json:"amount"`
// ID selects the account row to update.
ID int64 `json:"id"`
}
现在字段名从 `Balance` 变成了 `Amount`。
最后,在 `store.go` 中用 `AddAccountBalance` 替换 `GetAccountForUpdate` 和 `UpdateAccount`,完整代码如下:
package db
import (
"context"
"database/sql"
"fmt"
)
// Store combines the generated query methods (through the embedded *Queries)
// with the underlying *sql.DB, which execTx uses to run several queries inside
// one database transaction.
type Store struct {
*Queries
db *sql.DB
}
// NewStore builds a Store on top of db: the embedded Queries run statements
// directly against db, while execTx opens transactions on the same db.
func NewStore(db *sql.DB) *Store {
	s := &Store{db: db}
	s.Queries = New(db)
	return s
}
// execTx runs fn inside a single database transaction; fn receives a Queries
// bound to that transaction.
//
// If fn returns an error, the transaction is rolled back and fn's error is
// returned. If the rollback itself also fails, both errors are reported and
// fn's error stays wrapped (inspectable with errors.Is / errors.As).
// If fn succeeds, the transaction is committed and the commit error, if any,
// is returned. If fn panics, the transaction is rolled back and the panic is
// re-raised.
func (store *Store) execTx(ctx context.Context, fn func(*Queries) error) error {
	tx, err := store.db.BeginTx(ctx, nil)
	if err != nil {
		return err
	}

	// Without this, a panic inside fn would leave the transaction (and the
	// connection it holds) open, e.g. when a recovery middleware catches the
	// panic and the process keeps running.
	defer func() {
		if p := recover(); p != nil {
			_ = tx.Rollback() // best effort: the panic is what gets reported
			panic(p)
		}
	}()

	q := New(tx)
	if err := fn(q); err != nil {
		if rbErr := tx.Rollback(); rbErr != nil {
			// %w keeps fn's error (e.g. a deadlock error from the driver)
			// available to callers; the message text is unchanged.
			return fmt.Errorf("tx err: %w, rb err: %v", err, rbErr)
		}
		return err
	}

	return tx.Commit()
}
// TransferTxParams is the input of TransferTx.
type TransferTxParams struct {
// FromAccountID is the account the money is taken from.
FromAccountID int64 `json:"from_account_id"`
// ToAccountID is the account the money is added to.
ToAccountID int64 `json:"to_account_id"`
// Amount is the value moved between the two accounts.
Amount int64 `json:"amount"`
}
// TransferTxResult holds the rows written or returned by TransferTx.
type TransferTxResult struct {
// Transfer is the created transfer record.
Transfer Transfer `json:"transfer"`
// FromAccount is the sender account as returned after its balance update.
FromAccount Account `json:"from_account"`
// ToAccount is the receiver account as returned after its balance update.
ToAccount Account `json:"to_account"`
// FromEntry is the entry recording money leaving the sender (-Amount).
FromEntry Entry `json:"from_entry"`
// ToEntry is the entry recording money arriving at the receiver (+Amount).
ToEntry Entry `json:"to_entry"`
}
// TransferTx moves arg.Amount from arg.FromAccountID to arg.ToAccountID inside
// a single database transaction. It creates a transfer record and one entry per
// account (-Amount for the sender, +Amount for the receiver), then adjusts both
// account balances with AddAccountBalance. The result holds every row written.
// If any step fails, execTx rolls the transaction back and the error is returned.
func (store *Store) TransferTx(ctx context.Context, arg TransferTxParams) (TransferTxResult, error) {
	var result TransferTxResult

	err := store.execTx(ctx, func(q *Queries) error {
		var err error

		// 1. Record the transfer.
		result.Transfer, err = q.CreateTransfer(ctx, CreateTransferParams{
			FromAccountID: arg.FromAccountID,
			ToAccountID:   arg.ToAccountID,
			Amount:        arg.Amount,
		})
		if err != nil {
			return err
		}

		// 2. Entry for the sender: -Amount.
		result.FromEntry, err = q.CreateEntry(ctx, CreateEntryParams{
			AccountID: arg.FromAccountID,
			Amount:    -arg.Amount,
		})
		if err != nil {
			return err
		}

		// 3. Entry for the receiver: +Amount.
		result.ToEntry, err = q.CreateEntry(ctx, CreateEntryParams{
			AccountID: arg.ToAccountID,
			Amount:    arg.Amount,
		})
		if err != nil {
			return err
		}

		// 4. Update the balances. Each UPDATE locks its account row until the
		// transaction ends, so the two rows are always locked in the same order
		// (smaller account ID first). Otherwise a 1→2 transfer and a concurrent
		// 2→1 transfer could each hold one row while waiting for the other's,
		// and postgres would abort one of them with a deadlock error.
		if arg.ToAccountID < arg.FromAccountID {
			result.ToAccount, err = q.AddAccountBalance(ctx, AddAccountBalanceParams{
				ID:     arg.ToAccountID,
				Amount: arg.Amount,
			})
			if err != nil {
				return err
			}

			result.FromAccount, err = q.AddAccountBalance(ctx, AddAccountBalanceParams{
				ID:     arg.FromAccountID,
				Amount: -arg.Amount,
			})
			if err != nil {
				return err
			}
			return nil
		}

		result.FromAccount, err = q.AddAccountBalance(ctx, AddAccountBalanceParams{
			ID:     arg.FromAccountID,
			Amount: -arg.Amount,
		})
		if err != nil {
			return err
		}

		result.ToAccount, err = q.AddAccountBalance(ctx, AddAccountBalanceParams{
			ID:     arg.ToAccountID,
			Amount: arg.Amount,
		})
		if err != nil {
			return err
		}
		return nil
	})

	return result, err
}
再次运行单元测试,没问题,测试通过。
如何避免数据库事务查询中的死锁顺序问题