上一节,学习了如何实现一个简单的转账事务,但是,我们还没做更新账户余额的操作,因为,它稍复杂一些,需要小心处理并发事务以避免死锁。

本节,将实现这个功能,顺便学习一下数据库锁,以及如何调试死锁的情况。(有点硬核,需要耐心学习,最好自己手动操作一遍,以便深入理解)

测试驱动开发(TDD)

我们采用 TDD（测试驱动开发）的方式：先修改 `store_test.go` 中的单元测试，再去实现功能代码。先回顾一下 `store.go` 中 `TransferTx` 的代码，上一节在创建完账目记录之后，留下了一个 `TODO`：
		// 3. Create a ledger entry for ToAccount with amount +arg.Amount (+10 in the test).
		result.ToEntry, err = q.CreateEntry(ctx, CreateEntryParams{
			AccountID: arg.ToAccountID,
			Amount:    arg.Amount,
		})
		if err != nil {
			return err
		}

		// TODO: updating the account balances is implemented later in this section.

要完成这个单元测试,需要先检查钱是从哪转出来的,然后,检查钱转到哪个账户里面去了。

		// First, check the account the money was transferred out of.
		fromAccount := result.FromAccount
		require.NotEmpty(t, fromAccount)
		require.Equal(t, account1.ID, fromAccount.ID)

		// Then, check the account the money was transferred into.
		toAccount := result.ToAccount
		require.NotEmpty(t, toAccount)
		require.Equal(t, account2.ID, toAccount.ID)
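接着检查账户余额的变化：转出方转出的金额是 `account1.Balance - fromAccount.Balance`，收款方收到的金额是 `toAccount.Balance - account2.Balance`。这两个值应该相等，并且是每笔转账金额 `amount` 的整数倍。之后还会计算 `k = diff1 / amount`，`k` 必须满足 `1 <= k <= n`（`n` 是并发交易的数量），并且每笔交易得到的 `k` 各不相同：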
		// Sender: total amount transferred out so far.
		diff1 := account1.Balance - fromAccount.Balance
		// Receiver: total amount transferred in so far.
		diff2 := toAccount.Balance - account2.Balance
		// Both values must be equal.
		require.Equal(t, diff1, diff2)
		// The amount transferred out must be positive.
		require.True(t, diff1 > 0)
		// It must be a whole multiple of the per-transfer amount.
		require.True(t, diff1%amount == 0)

最后，在 for 循环外面，检查这两个账户的最终余额：

	// Finally, outside the for loop, check the final balances of both accounts.
	updateAccount1, err := testQueries.GetAccount(context.Background(), account1.ID)
	require.NoError(t, err)

	updateAccount2, err := testQueries.GetAccount(context.Background(), account2.ID)
	require.NoError(t, err)

	fmt.Println(">> after:", updateAccount1.Balance, updateAccount2.Balance)
	// account1's original balance minus n transfers of amount must equal its
	// final balance; account2 must have gained the same total.
	require.Equal(t, account1.Balance-int64(n)*amount, updateAccount1.Balance)
	require.Equal(t, account2.Balance+int64(n)*amount, updateAccount2.Balance)

完整的代码如下:

package db

import (
	"context"
	"fmt"
	"testing"

	"github.com/stretchr/testify/require"
)

// TestTransferTx runs n concurrent transfers of amount from account1 to
// account2. For every transaction it verifies the transfer record, both ledger
// entries and the returned account balances. Finally it checks the balances
// stored in the database after all transfers have finished.
func TestTransferTx(t *testing.T) {
	store := NewStore(testDB)

	account1 := createRandomAccount(t)
	account2 := createRandomAccount(t)
	fmt.Println(">> before:", account1.Balance, account2.Balance)

	n := 5
	amount := int64(10)

	// Both channels hold one value per goroutine. Every goroutine can then
	// deliver its outcome and exit, even if an assertion below stops the test
	// before all values have been received.
	errs := make(chan error, n)
	results := make(chan TransferTxResult, n)

	for i := 0; i < n; i++ {
		go func() {
			result, err := store.TransferTx(context.Background(), TransferTxParams{
				FromAccountID: account1.ID,
				ToAccountID:   account2.ID,
				Amount:        amount,
			})

			errs <- err
			results <- result
		}()
	}

	// Check the results. existed records which values of k (see below) have
	// already been observed.
	existed := make(map[int]bool, n)

	for i := 0; i < n; i++ {
		err := <-errs
		require.NoError(t, err)

		result := <-results
		require.NotEmpty(t, result)

		// check transfer
		transfer := result.Transfer
		require.NotEmpty(t, transfer)
		require.Equal(t, account1.ID, transfer.FromAccountID)
		require.Equal(t, account2.ID, transfer.ToAccountID)
		require.Equal(t, amount, transfer.Amount)
		require.NotZero(t, transfer.ID)
		require.NotZero(t, transfer.CreatedAt)

		_, err = store.GetTransfer(context.Background(), transfer.ID)
		require.NoError(t, err)

		// check entries
		fromEntry := result.FromEntry
		require.NotEmpty(t, fromEntry)
		require.Equal(t, account1.ID, fromEntry.AccountID)
		require.Equal(t, -amount, fromEntry.Amount)
		require.NotZero(t, fromEntry.ID)
		require.NotZero(t, fromEntry.CreatedAt)

		_, err = store.GetEntry(context.Background(), fromEntry.ID)
		require.NoError(t, err)

		toEntry := result.ToEntry
		require.NotEmpty(t, toEntry)
		require.Equal(t, account2.ID, toEntry.AccountID)
		require.Equal(t, amount, toEntry.Amount)
		require.NotZero(t, toEntry.ID)
		require.NotZero(t, toEntry.CreatedAt)

		_, err = store.GetEntry(context.Background(), toEntry.ID)
		require.NoError(t, err)

		// Check the account the money was transferred out of.
		fromAccount := result.FromAccount
		require.NotEmpty(t, fromAccount)
		require.Equal(t, account1.ID, fromAccount.ID)

		// Check the account the money was transferred into.
		toAccount := result.ToAccount
		require.NotEmpty(t, toAccount)
		require.Equal(t, account2.ID, toAccount.ID)

		// Check the updated account balances.
		fmt.Println(">> tx:", fromAccount.Balance, toAccount.Balance)
		// diff1: total amount transferred out of account1 so far.
		diff1 := account1.Balance - fromAccount.Balance
		// diff2: total amount transferred into account2 so far.
		diff2 := toAccount.Balance - account2.Balance
		// Both sides must agree.
		require.Equal(t, diff1, diff2)
		// Something must have been transferred.
		require.True(t, diff1 > 0)
		// The total must be a whole number of transfers.
		require.True(t, diff1%amount == 0)

		// k is the number of transfers applied up to and including this one.
		// Updates to the same account rows are serialized by the database,
		// so every transaction sees a different running total. k must
		// therefore lie in [1, n] and be unique across transactions.
		k := int(diff1 / amount)
		require.True(t, k >= 1 && k <= n)
		require.NotContains(t, existed, k)
		existed[k] = true
	}

	// After the loop, check the final balances stored in the database.
	updateAccount1, err := testQueries.GetAccount(context.Background(), account1.ID)
	require.NoError(t, err)

	updateAccount2, err := testQueries.GetAccount(context.Background(), account2.ID)
	require.NoError(t, err)

	fmt.Println(">> after:", updateAccount1.Balance, updateAccount2.Balance)
	// account1 must have lost n*amount in total and account2 gained n*amount.
	require.Equal(t, account1.Balance-int64(n)*amount, updateAccount1.Balance)
	require.Equal(t, account2.Balance+int64(n)*amount, updateAccount2.Balance)
}
运行测试，会失败，因为我们还没有实现更新账户余额的功能。下面就在 `store.go` 中实现它。

更新账户的余额(错误的方法)

最直观的做法是：先从数据库中查询出账户，再用计算后的新余额去更新它，代码如下：
		// Incorrect first attempt: read the account (GetAccount), compute the new
		// balance in Go, then write it back (UpdateAccount).
		// NOTE(review): GetAccount is a plain SELECT with no row lock, so
		// concurrent transactions can read the same old balance and overwrite
		// each other's update.
		account1, err := q.GetAccount(ctx, arg.FromAccountID)
		if err != nil {
			return err
		}

		result.FromAccount, err = q.UpdateAccount(ctx, UpdateAccountParams{
			ID:      arg.FromAccountID,
			Balance: account1.Balance - arg.Amount,
		})
		if err != nil {
			return err
		}

		account2, err := q.GetAccount(ctx, arg.ToAccountID)
		if err != nil {
			return err
		}

		result.ToAccount, err = q.UpdateAccount(ctx, UpdateAccountParams{
			ID:      arg.ToAccountID,
			Balance: account2.Balance + arg.Amount,
		})
		if err != nil {
			return err
		}
运行测试，失败了：`account1` 的余额与期望值不一致。原因是多个并发事务可能读到同一个旧余额，各自减去转账金额后写回，互相覆盖了对方的更新。问题出在 `GetAccount` 的 SQL 上：
-- name: GetAccount :one
-- Plain read of one account by id; it takes no row lock, so concurrent
-- transactions can read the same row at the same time.
SELECT * FROM accounts
WHERE id = $1 LIMIT 1;
这是一个普通的 `SELECT`，不会阻止其他事务读取同一条 `account1` 记录。我们在 postgres 控制台中验证一下，打开终端，进入数据库：
docker exec -it postgres14 psql -U root -d simple_bank

不加锁查询的情况

让我们在两个不同的终端运行2个并行事务。

在两个终端中分别执行 `BEGIN;` 开启事务，然后都执行 `select * from accounts where id=1;`。可以看到，两个事务都能立即查到 id 为 1 的账户，互不阻塞。实验结束后，在两个终端中执行：
ROLLBACK;

加锁查询的情况

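同样在两个终端中分别执行 `BEGIN;`。先在第一个终端执行 `select * from accounts where id=1 for update;`，再在第二个终端执行同样的 `select * from accounts where id=1 for update;`，会发现第二个事务被阻塞，必须等待第一个事务提交或回滚。接着在第一个终端执行 `update accounts set balance=500 where id=1;` 和 `COMMIT;`，第二个终端的查询会立即返回，并且查到的是更新后的 `balance`，即 500。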

加锁更新账户余额

在 `account.sql` 中新增一个加锁查询的 sql，命名为 `GetAccountForUpdate`：
-- name: GetAccountForUpdate :one
-- Reads one account and locks the row with FOR UPDATE until the transaction
-- ends; other transactions selecting the same row FOR UPDATE must wait.
SELECT * FROM accounts
WHERE id = $1 LIMIT 1
FOR UPDATE;
运行 `make sqlc` 重新生成代码，`account.sql.go` 中会多出一个 `GetAccountForUpdate` 函数。然后在 `store.go` 中改用它：
		// Read each account with GetAccountForUpdate (SELECT ... FOR UPDATE) so the
		// row stays locked until the transaction ends, then write back the new
		// balance computed in Go.
		// NOTE(review): with this version the concurrent test reports
		// "deadlock detected".
		account1, err := q.GetAccountForUpdate(ctx, arg.FromAccountID)
		if err != nil {
			return err
		}

		result.FromAccount, err = q.UpdateAccount(ctx, UpdateAccountParams{
			ID:      arg.FromAccountID,
			Balance: account1.Balance - arg.Amount,
		})
		if err != nil {
			return err
		}

		account2, err := q.GetAccountForUpdate(ctx, arg.ToAccountID)
		if err != nil {
			return err
		}

		result.ToAccount, err = q.UpdateAccount(ctx, UpdateAccountParams{
			ID:      arg.ToAccountID,
			Balance: account2.Balance + arg.Amount,
		})
		if err != nil {
			return err
		}
再次运行测试，报错：`deadlock detected`（检测到死锁）。

调试死锁

为了找出死锁（deadlock）的原因，我们需要知道每个事务执行到了哪一条 sql。做法是给每个 `TransferTx()` 事务起一个名字，并在每一步打印日志。在 `store_test.go` 的 `for i := 0; i < n; i++ {` 循环中，为每个事务生成名字：
txName := fmt.Sprintf("tx %d", i+1)
然后通过 context 把 `txName` 传给 `TransferTx`。为此，先在 `store.go` 中声明一个用作 context 键的变量：
// txKeyType is the unexported type of the context key below. Because it is a
// named type private to this package, no other package can construct an equal
// key, so values stored under it cannot collide with theirs.
type txKeyType struct{}

// txKey is the context key under which a debug name for the transaction
// (e.g. "tx 1") is stored.
var txKey = txKeyType{}
接着在 `store_test.go` 中，用 `context.WithValue()` 把 `txName` 以 `txKey` 为键存入 context，再传给 `TransferTx`：
	for i := 0; i < n; i++ {
		// Name each transaction ("tx 1", "tx 2", ...) for debug logging.
		txName := fmt.Sprintf("tx %d", i+1)
		go func() {
			// Carry the name in the context so TransferTx can read it via txKey.
			ctx := context.WithValue(context.Background(), txKey, txName)
			result, err := store.TransferTx(ctx, TransferTxParams{
				FromAccountID: account1.ID,
				ToAccountID:   account2.ID,
				Amount:        amount,
			})

			errs <- err
			results <- result
		}()
	}

这样,上下文就可以保存事务名称了。

在 `store.go` 的 `TransferTx` 中，从 context 中取出事务名称，并在每一步操作之前打印日志：
// txKeyType is the unexported type of the context key below. Because it is a
// named type private to this package, no other package can construct an equal
// key, so values stored under it cannot collide with theirs.
type txKeyType struct{}

// txKey is the context key under which a debug name for the transaction
// (e.g. "tx 1") is stored.
var txKey = txKeyType{}

// TransferTx is the instrumented debugging version used to investigate the
// deadlock. Before each step it prints the transaction name stored in ctx
// under txKey, followed by the step name. This makes the interleaving of
// concurrent transactions visible in the test output.
//
// NOTE(review): accounts are read with GetAccountForUpdate (SELECT ... FOR
// UPDATE) and always in from -> to order. With this version the concurrent
// test in this section fails with "deadlock detected".
func (store *Store) TransferTx(ctx context.Context, arg TransferTxParams) (TransferTxResult, error) {
	var result TransferTxResult

	err := store.execTx(ctx, func(q *Queries) error {
		var err error

		// Debug name of this transaction (nil if the caller did not set one).
		txName := ctx.Value(txKey)

		// 1. Create a transfer record for arg.Amount (10 in the test).
		fmt.Println(txName, "create transfer")
		result.Transfer, err = q.CreateTransfer(ctx, CreateTransferParams{
			FromAccountID: arg.FromAccountID,
			ToAccountID:   arg.ToAccountID,
			Amount:        arg.Amount,
		})
		if err != nil {
			return err
		}

		// 2. Create a ledger entry for FromAccount with amount -arg.Amount (-10 in the test).
		fmt.Println(txName, "create entry 1")
		result.FromEntry, err = q.CreateEntry(ctx, CreateEntryParams{
			AccountID: arg.FromAccountID,
			Amount:    -arg.Amount,
		})
		if err != nil {
			return err
		}

		// 3. Create a ledger entry for ToAccount with amount +arg.Amount (+10 in the test).
		fmt.Println(txName, "create entry 2")
		result.ToEntry, err = q.CreateEntry(ctx, CreateEntryParams{
			AccountID: arg.ToAccountID,
			Amount:    arg.Amount,
		})
		if err != nil {
			return err
		}

		// Lock and read each account, then write back the new balance.
		fmt.Println(txName, "get account 1")
		account1, err := q.GetAccountForUpdate(ctx, arg.FromAccountID)
		if err != nil {
			return err
		}

		fmt.Println(txName, "update account 1")
		result.FromAccount, err = q.UpdateAccount(ctx, UpdateAccountParams{
			ID:      arg.FromAccountID,
			Balance: account1.Balance - arg.Amount,
		})
		if err != nil {
			return err
		}

		fmt.Println(txName, "get account 2")
		account2, err := q.GetAccountForUpdate(ctx, arg.ToAccountID)
		if err != nil {
			return err
		}

		fmt.Println(txName, "update account 2")
		result.ToAccount, err = q.UpdateAccount(ctx, UpdateAccountParams{
			ID:      arg.ToAccountID,
			Balance: account2.Balance + arg.Amount,
		})
		if err != nil {
			return err
		}

		// err is nil here: every failing step has already returned above.
		return err
	})

	return result, err
}
为了让日志更容易阅读，先把 `store_test.go` 中的并发事务数量 `n` 调小，然后运行测试，输出如下：
=== RUN   TestTransferTx
>> before: 103 980
tx 3 create transfer
tx 3 create entry 1
tx 3 create entry 2
tx 2 create transfer
tx 1 create transfer
tx 3 get account 1
tx 3 update account 1
tx 3 get account 2
tx 3 update account 2
tx 2 create entry 1
tx 1 create entry 1
tx 1 create entry 2
tx 2 create entry 2
tx 1 get account 1
tx 2 get account 1
>> tx: 93 990
tx 2 update account 1
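日志在 `tx 2 update account 1` 之后就停止了，随后测试报告死锁。只看 go 的日志很难弄清原因，所以我们把 go 代码中 `TransferTx` 的每一步翻译成对应的 sql，放到 postgres 控制台中手动重现：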
-- SQL equivalent of one TransferTx call (transfer 10 from account 1 to
-- account 2), for replaying the transaction step by step in psql.
BEGIN;
-- create transfer
INSERT INTO transfers (from_account_id, to_account_id, amount) VALUES  (1, 2, 10) RETURNING *;
-- create entry 1
INSERT INTO entries (account_id, amount) VALUES (1, -10) RETURNING *;
-- create entry 2
INSERT INTO entries (account_id, amount) VALUES (2, 10) RETURNING *;
-- get account 1
SELECT * FROM accounts WHERE id = 1 FOR UPDATE;
-- update account 1
UPDATE accounts SET balance = 90 WHERE id = 1 RETURNING *;
-- get account 2
SELECT * FROM accounts WHERE id = 2 FOR UPDATE;
-- update account 2
UPDATE accounts SET balance = 110 WHERE id = 2 RETURNING *;
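打开两个 postgres 控制台（两个终端），分别模拟两个并发的事务，按照日志中的顺序手动执行上面的 sql：

1. 在两个终端中分别执行 `BEGIN;`。
2. 在终端 1 中执行 create transfer 对应的 `INSERT INTO transfers (from_account_id, to_account_id, amount) VALUES (1, 2, 10) RETURNING *;`。
3. 在终端 2 中依次执行同一条 create transfer 的 `INSERT INTO transfers ...`，以及 create entry 1 对应的 `INSERT INTO entries (account_id, amount) VALUES (1, -10) RETURNING *;`。然后执行 get account 1 对应的 `SELECT * FROM accounts WHERE id = 1 FOR UPDATE;`。

会发现最后这条 `SELECT` 被阻塞了，它在等待另一个事务提交或回滚。奇怪的是，另一个事务只是向 `transfers` 表 `INSERT` 了记录，而被阻塞的 `SELECT` 查询的是 `accounts` 表。为了弄清原因，我们使用 postgres 官方 WIKI 中“锁监控（Lock Monitoring）”页面提供的查询，找出被阻塞的语句以及阻塞它的语句：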
-- Lock-monitoring query (from the PostgreSQL wiki). For every lock request
-- that has not been granted (NOT blocked_locks.granted), it finds another
-- process holding a lock on the same object. It shows both pids, users and
-- their current statements.
SELECT blocked_locks.pid     AS blocked_pid,
         blocked_activity.usename  AS blocked_user,
         blocking_locks.pid     AS blocking_pid,
         blocking_activity.usename AS blocking_user,
         blocked_activity.query    AS blocked_statement,
         blocking_activity.query   AS current_statement_in_blocking_process
   FROM  pg_catalog.pg_locks         blocked_locks
    JOIN pg_catalog.pg_stat_activity blocked_activity  ON blocked_activity.pid = blocked_locks.pid
    JOIN pg_catalog.pg_locks         blocking_locks 
        ON blocking_locks.locktype = blocked_locks.locktype
        AND blocking_locks.database IS NOT DISTINCT FROM blocked_locks.database
        AND blocking_locks.relation IS NOT DISTINCT FROM blocked_locks.relation
        AND blocking_locks.page IS NOT DISTINCT FROM blocked_locks.page
        AND blocking_locks.tuple IS NOT DISTINCT FROM blocked_locks.tuple
        AND blocking_locks.virtualxid IS NOT DISTINCT FROM blocked_locks.virtualxid
        AND blocking_locks.transactionid IS NOT DISTINCT FROM blocked_locks.transactionid
        AND blocking_locks.classid IS NOT DISTINCT FROM blocked_locks.classid
        AND blocking_locks.objid IS NOT DISTINCT FROM blocked_locks.objid
        AND blocking_locks.objsubid IS NOT DISTINCT FROM blocked_locks.objsubid
        AND blocking_locks.pid != blocked_locks.pid

    JOIN pg_catalog.pg_stat_activity blocking_activity ON blocking_activity.pid = blocking_locks.pid
   WHERE NOT blocked_locks.granted;
在 navicat 中执行这条 sql，可以看到被阻塞（blocked）的语句是 `SELECT * FROM accounts WHERE id = 1 FOR UPDATE;`，而阻塞它（blocking）的语句是 `INSERT INTO transfers (from_account_id, to_account_id, amount) VALUES (1, 2, 10) RETURNING *;`。为什么 `SELECT` 会被 `INSERT` 阻塞呢？我们再用 postgres WIKI 上的另一条查询，列出所有的锁：
-- Lists every lock currently held or requested, together with the owning
-- session's user, query, query start time and how long it has been running.
-- Oldest query first.
SELECT a.datname,
         l.relation::regclass,
         l.transactionid,
         l.mode,
         l.GRANTED,
         a.usename,
         a.query,
         a.query_start,
         age(now(), a.query_start) AS "age",
         a.pid
FROM pg_stat_activity a
JOIN pg_locks l ON l.pid = a.pid
ORDER BY a.query_start;
在 navicat 中执行时，可以把 `a.datname,` 换成 `a.application_name,`，以便看到锁属于哪个客户端。

这个sql中,

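`l.relation::regclass` 是被锁定的表名；`l.transactionid` 是事务 ID；`l.mode` 是锁的模式，还可以加上 `l.locktype` 显示锁的类型；`l.GRANTED` 表示锁是否已被授予；`a.usename` 是执行该 sql 的用户名；`a.query` 是持有或等待锁的 sql；`a.query_start` 是查询开始的时间；`age(now(), a.query_start) AS "age"` 是查询已经运行的时长；`a.pid` 是进程 ID。为了方便查看，把 `ORDER BY a.query_start` 改为 `ORDER BY a.pid`，并加上 `WHERE a.application_name = 'psql'`，只显示 psql 客户端的锁。修改后的 sql 如下：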
-- Same lock listing, restricted to psql sessions, with the lock type added
-- and rows ordered by process id.
SELECT a.application_name,
         l.relation::regclass,
         l.transactionid,
         l.mode,
		 l.locktype,
         l.GRANTED,
         a.usename,
         a.query,
         a.pid
FROM pg_stat_activity a
JOIN pg_locks l ON l.pid = a.pid
WHERE a.application_name = 'psql'
ORDER BY a.pid;
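在 navicat 中执行后可以看到，有一个锁的 `GRANTED` 列为 `f`（未被授予）。它的锁类型是 `transactionid`，等待的是事务 ID 为 1075 的事务。而持有 `transactionid` 1075 锁的，正是 pid 为 6991 的进程所执行的 `INSERT` 语句。也就是说，`SELECT FROM accounts` 被 `INSERT INTO transfers` 阻塞了。原因在于表之间的外键约束，请看建表 sql 中的这一句：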
-- transfers.from_account_id must reference an existing accounts.id. postgres
-- enforces this on every INSERT into transfers by checking the referenced
-- accounts row (and locking it against key changes until the transaction ends).
ALTER TABLE "transfers" ADD FOREIGN KEY ("from_account_id") REFERENCES "accounts" ("id");
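`transfers` 表的 `from_account_id` 列是引用 `accounts` 表 `id` 的外键。向 `transfers` 插入记录时，postgres 要确认被引用的账户确实存在，并对该 `accounts` 行加一个共享锁（`FOR KEY SHARE`），防止它在事务结束前被删除或修改主键。而 `SELECT ... FOR UPDATE` 需要的锁与这个共享锁冲突，所以会被阻塞。当两个事务都插入了 `transfers` 记录，又都去 `SELECT ... FOR UPDATE` 同一个账户时，它们就会互相等待，形成死锁（deadlock）。实验结束后，在两个终端中分别执行 `ROLLBACK;` 回滚事务，再输入 `\q` 退出 psql。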

修复死锁方案1

最简单粗暴的办法是去掉外键约束：在迁移文件 `000001_init_schema.up.sql` 中，把这几行外键约束的 sql 注释掉：
-- Fix 1 (experiment): the foreign-key constraints are commented out, so that
-- inserts into entries/transfers no longer lock the referenced accounts rows.
-- This removes the deadlock but gives up referential integrity.
-- ALTER TABLE "entries" ADD FOREIGN KEY ("account_id") REFERENCES "accounts" ("id");

-- ALTER TABLE "transfers" ADD FOREIGN KEY ("from_account_id") REFERENCES "accounts" ("id");

-- ALTER TABLE "transfers" ADD FOREIGN KEY ("to_account_id") REFERENCES "accounts" ("id");
然后执行 `make migratedown` 和 `make migrateup` 重新建表，再运行测试，测试通过了。但是外键约束对于保证数据一致性非常重要，去掉它并不是好的方案。所以把 sql 文件恢复原样，再执行一次 `make migratedown` 和 `make migrateup`。

修复死锁方案2

更好的方案要从 `account.sql` 入手。先看一下更新账户余额的 sql：
-- name: UpdateAccount :one
-- Overwrites the balance of account $1 with the value $2; only the balance
-- column is changed, never the id.
UPDATE accounts SET balance = $2 WHERE id = $1
RETURNING *;
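这条 `UPDATE` 只修改 `balance` 列，并不会修改主键 `id`。但我们在 `GetAccountForUpdate` 中使用的 `SELECT * FROM accounts ... FOR UPDATE`，会让 postgres 按“可能修改主键”的级别加锁。这个锁与外键检查加的锁冲突，这正是死锁（deadlock）的原因。所以只需告诉 postgres 我们不会修改主键，即把 `FOR UPDATE` 改为 `FOR NO KEY UPDATE`：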
-- name: GetAccountForUpdate :one
-- Reads one account and locks the row with FOR NO KEY UPDATE. Concurrent
-- updaters of the same row still wait, but the lock does not conflict with
-- the FOR KEY SHARE lock taken by foreign-key checks on inserts into
-- transfers/entries.
SELECT * FROM accounts
WHERE id = $1 LIMIT 1
FOR NO KEY UPDATE;
然后运行 `make sqlc` 重新生成代码。

之后,清理一下代码,把我们之前加的打印日志都去掉吧,再运行一下单元测试,没问题,通过。

从日志中可以看到，每次交易之后两个账户的余额是如何变化的。

修复死锁方案3

还可以进一步简化。目前更新一个账户的余额需要两条 sql：先用 `GetAccountForUpdate` 查询并锁定账户，再用 `UpdateAccount` 写回计算好的新余额：
		// Sender's balance update: lock and read the row, compute the new
		// balance in Go, then write it back (two SQL statements).
		account1, err := q.GetAccountForUpdate(ctx, arg.FromAccountID)
		if err != nil {
			return err
		}

		result.FromAccount, err = q.UpdateAccount(ctx, UpdateAccountParams{
			ID:      arg.FromAccountID,
			Balance: account1.Balance - arg.Amount,
		})
		if err != nil {
			return err
		}
其实一条 sql 就能完成。在 `account.sql` 中新增一个 `AddAccountBalance` 查询。与 `UpdateAccount` 直接设置 `balance` 不同，它在原来的 `balance` 上加上一个金额：
-- name: AddAccountBalance :one
-- Adds $2 to the current balance of account $1 in a single statement, so no
-- separate SELECT is needed to compute the new balance.
UPDATE accounts SET balance = balance + $2 WHERE id = $1
RETURNING *;
运行 `make sqlc`，在 `account.sql.go` 中生成的参数结构体如下：
// AddAccountBalanceParams holds the arguments of the AddAccountBalance query
// (generated by sqlc). ID is $1. Balance is $2, i.e. the amount added to the
// current balance, not the new balance.
type AddAccountBalanceParams struct {
	ID      int64 `json:"id"`
	Balance int64 `json:"balance"`
}
这里的字段名叫 `Balance`，容易引起误解，因为我们传入的是要增加的金额，而不是新的余额。可以用 `sqlc.arg()` 为参数指定名字：
-- name: AddAccountBalance :one
-- Adds amount to the current balance of account id in a single statement;
-- named arguments make sqlc generate the fields Amount and ID.
UPDATE accounts SET balance = balance + sqlc.arg(amount) 
WHERE id = sqlc.arg(id)
RETURNING *;
即把 `$2` 换成 `sqlc.arg(amount)`，把 `$1` 换成 `sqlc.arg(id)`。再次运行 `make sqlc`，生成的结构体变为：
// AddAccountBalanceParams holds the arguments of the AddAccountBalance query
// (generated by sqlc). Amount is added to the current balance of the account
// identified by ID.
type AddAccountBalanceParams struct {
	Amount int64 `json:"amount"`
	ID     int64 `json:"id"`
}
字段名已经从 `Balance` 变成了 `Amount`。最后，在 `store.go` 中用 `AddAccountBalance` 替换原来的 `GetAccountForUpdate` + `UpdateAccount`，完整代码如下：
package db

import (
	"context"
	"database/sql"
	"fmt"
)

// Store provides the generated queries through the embedded *Queries, which
// run directly against db. It also keeps db so that execTx can open
// transactions on it.
type Store struct {
	*Queries
	db *sql.DB
}

// NewStore returns a Store backed by db. Its embedded Queries execute
// statements directly on db, outside any transaction.
func NewStore(db *sql.DB) *Store {
	s := &Store{db: db}
	s.Queries = New(db)
	return s
}

// execTx runs fn inside a new database transaction. fn receives a *Queries
// bound to that transaction. The transaction is committed when fn returns
// nil and rolled back otherwise.
//
// If fn fails, its error is returned. If the rollback itself also fails, both
// errors are reported. fn's error is wrapped with %w, so callers can still
// match it with errors.Is / errors.As.
func (store *Store) execTx(ctx context.Context, fn func(*Queries) error) error {
	tx, err := store.db.BeginTx(ctx, nil)
	if err != nil {
		return err
	}

	q := New(tx)
	err = fn(q)
	if err != nil {
		if rbErr := tx.Rollback(); rbErr != nil {
			// Keep fn's error in the chain; the rollback error is informational.
			return fmt.Errorf("tx err: %w, rb err: %v", err, rbErr)
		}
		return err
	}

	return tx.Commit()
}

// TransferTxParams contains the input of TransferTx. Amount is moved from the
// account FromAccountID to the account ToAccountID.
type TransferTxParams struct {
	FromAccountID int64 `json:"from_account_id"`
	ToAccountID   int64 `json:"to_account_id"`
	Amount        int64 `json:"amount"`
}

// TransferTxResult is the outcome of TransferTx. It contains:
//   - the created transfer record;
//   - both accounts as returned by their balance updates;
//   - the two ledger entries (negative for the sender, positive for the receiver).
type TransferTxResult struct {
	Transfer    Transfer `json:"transfer"`
	FromAccount Account  `json:"from_account"`
	ToAccount   Account  `json:"to_account"`
	FromEntry   Entry    `json:"from_entry"`
	ToEntry     Entry    `json:"to_entry"`
}

// TransferTx transfers arg.Amount from arg.FromAccountID to arg.ToAccountID
// within a single database transaction. It does four things:
//   - creates a transfer record;
//   - creates a ledger entry for the sender (-Amount);
//   - creates a ledger entry for the receiver (+Amount);
//   - updates both balances.
//
// If any step fails, execTx rolls the whole transaction back and the error is
// returned.
//
// The two balance updates are always issued in ascending account-ID order.
// Each UPDATE keeps its row locked until the transaction ends. Consider
// transfers 1->2 and 2->1 running at the same time with a fixed from-then-to
// order: one locks row 1 then waits for row 2, the other locks row 2 then
// waits for row 1. postgres would then abort one of them with
// "deadlock detected". A single global lock order makes that cycle impossible.
func (store *Store) TransferTx(ctx context.Context, arg TransferTxParams) (TransferTxResult, error) {
	var result TransferTxResult

	err := store.execTx(ctx, func(q *Queries) error {
		var err error

		// 1. Record the transfer itself.
		result.Transfer, err = q.CreateTransfer(ctx, CreateTransferParams{
			FromAccountID: arg.FromAccountID,
			ToAccountID:   arg.ToAccountID,
			Amount:        arg.Amount,
		})
		if err != nil {
			return err
		}

		// 2. Ledger entry for the sender: money out, negative amount.
		result.FromEntry, err = q.CreateEntry(ctx, CreateEntryParams{
			AccountID: arg.FromAccountID,
			Amount:    -arg.Amount,
		})
		if err != nil {
			return err
		}

		// 3. Ledger entry for the receiver: money in, positive amount.
		result.ToEntry, err = q.CreateEntry(ctx, CreateEntryParams{
			AccountID: arg.ToAccountID,
			Amount:    arg.Amount,
		})
		if err != nil {
			return err
		}

		// 4. Update both balances, lower account ID first (see doc comment).
		// When the IDs are equal, the from-then-to order is kept.
		if arg.ToAccountID < arg.FromAccountID {
			result.ToAccount, result.FromAccount, err = addMoney(ctx, q,
				arg.ToAccountID, arg.Amount,
				arg.FromAccountID, -arg.Amount,
			)
		} else {
			result.FromAccount, result.ToAccount, err = addMoney(ctx, q,
				arg.FromAccountID, -arg.Amount,
				arg.ToAccountID, arg.Amount,
			)
		}
		return err
	})

	return result, err
}

// addMoney adds amount1 to the balance of accountID1 and then amount2 to the
// balance of accountID2, in exactly that order, and returns both updated
// accounts. It stops at the first error.
func addMoney(
	ctx context.Context,
	q *Queries,
	accountID1 int64,
	amount1 int64,
	accountID2 int64,
	amount2 int64,
) (account1 Account, account2 Account, err error) {
	account1, err = q.AddAccountBalance(ctx, AddAccountBalanceParams{
		ID:     accountID1,
		Amount: amount1,
	})
	if err != nil {
		return account1, account2, err
	}

	account2, err = q.AddAccountBalance(ctx, AddAccountBalanceParams{
		ID:     accountID2,
		Amount: amount2,
	})
	return account1, account2, err
}

再次运行单元测试,没问题,测试通过。

如何避免数据库事务查询中的死锁顺序问题