Golang实现数据库事务

在开始之前,先聊一下事务。

什么是数据库事务?

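数据库事务是由多个数据库操作组成的一个不可分割的工作单元。举个银行转账的例子:张三要给李四转账10元,这个操作涉及`accounts`、`entries`和`transfers`三张表,需要完成5个步骤:
1. 在`transfers`表中创建一条金额为10的转账记录;2. 在`entries`表中为张三创建一条金额为-10的账目记录;3. 在`entries`表中为李四创建一条金额为+10的账目记录;4. 在`accounts`表中将张三的余额减去10;5. 在`accounts`表中将李四的余额加上10。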

为什么需要使用数据库事务?

主要原因有2个:

  1. 我们希望这个操作单元可靠且一致,即使系统出现某些故障的情况下也如此。
  2. 在程序和访问数据库之间提供隔离。
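为此,数据库事务必须满足ACID特性:
A(Atomicity,原子性):事务中的操作要么全部成功,要么全部失败并回滚,数据库保持不变;C(Consistency,一致性):事务执行后数据库的状态必须合法,满足所有预定义的约束;I(Isolation,隔离性):并发执行的事务之间互不影响;D(Durability,持久性):事务一旦提交成功,其写入的数据就会被永久保存,即使系统崩溃也不会丢失。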

SQL中如何操作数据库事务?

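使用`BEGIN`语句开启一个事务,然后执行一系列SQL语句。如果所有语句都执行成功,就用`COMMIT`提交事务,使修改永久生效;如果任意一条语句失败,就用`ROLLBACK`回滚事务,撤销该事务中已执行的所有修改,数据库恢复到事务开始前的状态。下面看看如何在Golang中实现这一过程。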

在Golang中实现数据库事务

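在sqlc生成代码的目录下新建一个`store.go`文件,并在其中定义一个新的`Store`结构体。
sqlc生成的`Queries`中,每个查询只对1张表执行1个操作,不支持把多个操作组合成一个事务,所以需要用`Store`来扩展它的功能。这里采用Go推荐的组合(composition)方式而非继承:在`Store`中嵌入`Queries`,这样`Queries`的所有方法都可以直接通过`Store`调用。
除了嵌入`Queries`之外,`Store`还需要一个`sql.DB`字段,用来开启新的数据库事务: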
// Store provides all functions to execute db queries and transactions.
// It embeds the sqlc-generated *Queries, so every single-query method is
// callable directly on a Store, and keeps the underlying *sql.DB so that
// new transactions can be started from it.
type Store struct {
	*Queries
	db *sql.DB
}
接着编写`NewStore`函数,用来创建一个新的`Store`对象:
// NewStore returns a Store backed by db. The embedded Queries execute
// directly against db (outside any transaction); multi-statement work goes
// through the Store's transaction helpers, which also use db.
func NewStore(db *sql.DB) *Store {
	store := new(Store)
	store.db = db
	store.Queries = New(db)
	return store
}
它接收一个`sql.DB`作为参数,返回一个`Store`。函数内部构建一个`Store`对象:把传入的`db`赋给`db`字段,再调用`New`函数基于这个`db`创建`Queries`对象。`New`函数由sqlc生成,返回一个`Queries`对象。

接下来,再新建个方法来执行通用的数据库事务,思路很简单,它将一个上下文和一个回调函数作为入参,然后它将启动一个新的数据库事务。

// execTx runs fn inside a single database transaction.
//
// A transaction is started on store.db and a Queries bound to that
// transaction is passed to fn:
//   - if fn returns an error, the transaction is rolled back and fn's error is
//     returned (combined with the rollback error if the rollback also fails);
//   - if fn panics, the transaction is rolled back and the panic is re-raised,
//     so the connection is not left holding an open transaction;
//   - otherwise the transaction is committed and Commit's error is returned.
func (store *Store) execTx(ctx context.Context, fn func(*Queries) error) error {
	// nil TxOptions: use the database's default isolation level.
	tx, err := store.db.BeginTx(ctx, nil)
	if err != nil {
		return err
	}

	// Without this, a panic in fn would leave tx open until ctx is cancelled,
	// which may be never (e.g. context.Background()).
	defer func() {
		if p := recover(); p != nil {
			_ = tx.Rollback() // best effort: we are already unwinding a panic
			panic(p)
		}
	}()

	// New is called with a *sql.DB in NewStore and with a *sql.Tx here, so the
	// same Queries methods run inside the transaction.
	q := New(tx)
	if err = fn(q); err != nil {
		if rbErr := tx.Rollback(); rbErr != nil {
			// %w keeps fn's error inspectable via errors.Is / errors.As.
			return fmt.Errorf("tx err: %w, rb err: %v", err, rbErr)
		}
		return err
	}

	return tx.Commit()
}

解释一下上面这段代码:

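`tx, err := store.db.BeginTx(ctx, nil)`:开启一个新事务。第一个参数是上下文`ctx`;第二个参数是`sql.TxOptions`,可以用来设置事务的隔离级别,这里传`nil`表示使用数据库的默认隔离级别(对postgres来说是read committed)。如果`err`不为`nil`,说明`BeginTx()`失败,直接返回错误。`q := New(tx)`:用事务`tx`调用`New()`,得到一个在该事务中执行查询的`Queries`对象。这里的`New()`与`NewStore()`中调用的是同一个函数,区别在于`NewStore()`传入的是`sql.DB`,而这里传入的是`sql.Tx`;之所以两者都可以传入,是因为`New()`的参数类型是`DBTX`接口,`sql.Tx`和`sql.DB`都实现了它。`err = fn(q)`:调用传入的回调函数执行具体的查询,它返回一个`error`。`if err != nil {`:如果`err`不为`nil`,就调用`tx.Rollback()`回滚事务。回滚本身也可能失败并返回`rbErr`;如果`rbErr`不为`nil`,就用`fmt.Errorf`把这两个错误合并成1个错误返回;如果回滚成功,则只返回原始错误`return err`。最后,如果回调函数中的所有操作都成功,就调用`tx.Commit()`提交事务,并把它的错误返回给调用者。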
`execTx`首字母小写,因此是一个不导出的函数,只供包内使用;对外则为每种具体的事务提供一个导出的函数。
下面来实现转账事务`TransferTx()`。先定义它的入参结构体`TransferTxParams`和返回结果结构体`TransferTxResult`:
// TransferTxParams contains the input parameters of the transfer transaction:
// Amount is moved from the account FromAccountID to the account ToAccountID.
type TransferTxParams struct {
	FromAccountID int64 `json:"from_account_id"`
	ToAccountID   int64 `json:"to_account_id"`
	Amount        int64 `json:"amount"`
}

// TransferTxResult is the result of the transfer transaction.
// Transfer is the created transfer record; FromEntry and ToEntry are the two
// account entries it produced (-Amount for the sender, +Amount for the
// receiver). FromAccount and ToAccount are meant to hold the updated
// accounts, but TransferTx does not populate them yet (the balance update is
// still a TODO there).
type TransferTxResult struct {
	Transfer    Transfer `json:"transfer"`
	FromAccount Account  `json:"from_account"`
	ToAccount   Account  `json:"to_account"`
	FromEntry   Entry    `json:"from_entry"`
	ToEntry     Entry    `json:"to_entry"`
}
然后实现`TransferTx()`方法:
// TransferTx performs a money transfer from one account to another inside a
// single database transaction.
//
// It creates a transfer record for arg.Amount, an entry of -arg.Amount for
// the sender's account and an entry of +arg.Amount for the receiver's
// account. If any step fails, the transaction is not committed and
// TransferTx returns a zero TransferTxResult together with the error.
//
// Updating the account balances is not implemented yet, so FromAccount and
// ToAccount in the result are left at their zero values.
func (store *Store) TransferTx(ctx context.Context, arg TransferTxParams) (TransferTxResult, error) {
	var result TransferTxResult

	err := store.execTx(ctx, func(q *Queries) error {
		var err error

		// 1. Create the transfer record for arg.Amount.
		result.Transfer, err = q.CreateTransfer(ctx, CreateTransferParams{
			FromAccountID: arg.FromAccountID,
			ToAccountID:   arg.ToAccountID,
			Amount:        arg.Amount,
		})
		if err != nil {
			return err
		}

		// 2. Record the money leaving the sender's account (-arg.Amount).
		result.FromEntry, err = q.CreateEntry(ctx, CreateEntryParams{
			AccountID: arg.FromAccountID,
			Amount:    -arg.Amount,
		})
		if err != nil {
			return err
		}

		// 3. Record the money arriving in the receiver's account (+arg.Amount).
		result.ToEntry, err = q.CreateEntry(ctx, CreateEntryParams{
			AccountID: arg.ToAccountID,
			Amount:    arg.Amount,
		})
		if err != nil {
			return err
		}

		// TODO: steps 4 and 5 - update both account balances.

		return nil
	})
	if err != nil {
		// The transaction did not commit, so rows captured in result were
		// never persisted; don't hand a half-filled result back to the caller.
		return TransferTxResult{}, err
	}

	return result, nil
}
其中的`q.CreateTransfer`、`q.CreateEntry`等方法是sqlc根据SQL查询生成的;如果修改了SQL查询文件,需要执行`make sqlc`重新生成代码。

上面的代码按照转账拆解出的5个步骤分步实现:目前完成了前3步(创建转账记录和两条账目记录),更新两个账户余额的第4、5步留到后面实现(见代码中的TODO)。

单元测试

接下来在`store.go`同目录下新建测试文件`store_test.go`。在此之前,需要先修改一下`main_test.go`:
package db

import (
	"database/sql"
	"log"
	"os"
	"testing"

	_ "github.com/lib/pq"
)

// Connection settings for the test database: the lib/pq "postgres" driver
// and a local simple_bank database with SSL disabled.
const dbDriver = "postgres"
const dbSource = "postgresql://root:123456@localhost:5432/simple_bank?sslmode=disable"

// Package-level fixtures initialised by TestMain and shared by the tests.
var (
	testQueries *Queries // Queries running directly on testDB
	testDB      *sql.DB  // connection pool, used to build a Store in tests
)

// TestMain opens the shared test database connection, verifies that the
// server is reachable, initialises testDB and testQueries, runs the package's
// tests and closes the connection before exiting.
func TestMain(m *testing.M) {
	var err error
	testDB, err = sql.Open(dbDriver, dbSource)
	if err != nil {
		log.Fatal("cannot connect to db:", err)
	}
	// sql.Open only validates its arguments; Ping actually dials the server,
	// so an unreachable database fails here once instead of in every test.
	if err = testDB.Ping(); err != nil {
		log.Fatal("cannot connect to db:", err)
	}

	testQueries = New(testDB)

	// os.Exit skips deferred calls, so release the pool explicitly first.
	code := m.Run()
	_ = testDB.Close() // nothing useful to do with a close error at exit
	os.Exit(code)
}
与之前相比有3处改动:新增全局变量`var testDB *sql.DB`,以便在`store_test.go`中使用;把`conn, err := sql.Open(dbDriver, dbSource)`改为`testDB, err = sql.Open(dbDriver, dbSource)`;把`testQueries = New(conn)`改为`testQueries = New(testDB)`。
`store_test.go`的内容如下:
package db

import (
	"context"
	"testing"

	"github.com/stretchr/testify/require"
)

// TestTransferTx runs n concurrent transfers of amount from account1 to
// account2 and checks that every transaction succeeds and that the transfer
// and both entries it returns are persisted with the expected values.
//
// All assertions happen on the test goroutine: require calls t.FailNow,
// which must not be called from other goroutines, so the workers only send
// their outcomes back over channels.
func TestTransferTx(t *testing.T) {
	store := NewStore(testDB)

	account1 := createRandomAccount(t)
	account2 := createRandomAccount(t)

	n := 5
	amount := int64(10)

	// Buffered with capacity n so that no worker blocks forever on send if an
	// assertion below fails and the test stops receiving.
	errs := make(chan error, n)
	results := make(chan TransferTxResult, n)

	for i := 0; i < n; i++ {
		go func() {
			result, err := store.TransferTx(context.Background(), TransferTxParams{
				FromAccountID: account1.ID,
				ToAccountID:   account2.ID,
				Amount:        amount,
			})

			errs <- err
			results <- result
		}()
	}

	// Check the results.
	for i := 0; i < n; i++ {
		err := <-errs
		require.NoError(t, err)

		result := <-results
		require.NotEmpty(t, result)

		// Check the transfer record.
		transfer := result.Transfer
		require.NotEmpty(t, transfer)
		require.Equal(t, account1.ID, transfer.FromAccountID)
		require.Equal(t, account2.ID, transfer.ToAccountID)
		require.Equal(t, amount, transfer.Amount)
		require.NotZero(t, transfer.ID)
		require.NotZero(t, transfer.CreatedAt)

		_, err = store.GetTransfer(context.Background(), transfer.ID)
		require.NoError(t, err)

		// Check the sender's entry.
		fromEntry := result.FromEntry
		require.NotEmpty(t, fromEntry)
		require.Equal(t, account1.ID, fromEntry.AccountID)
		require.Equal(t, -amount, fromEntry.Amount)
		require.NotZero(t, fromEntry.ID)
		require.NotZero(t, fromEntry.CreatedAt)

		_, err = store.GetEntry(context.Background(), fromEntry.ID)
		require.NoError(t, err)

		// Check the receiver's entry.
		toEntry := result.ToEntry
		require.NotEmpty(t, toEntry)
		require.Equal(t, account2.ID, toEntry.AccountID)
		require.Equal(t, amount, toEntry.Amount)
		require.NotZero(t, toEntry.ID)
		require.NotZero(t, toEntry.CreatedAt)

		_, err = store.GetEntry(context.Background(), toEntry.ID)
		require.NoError(t, err)

		// TODO: check the updated account balances.
	}
}
为了确保事务在并发场景下也能正确工作,这里使用Go的goroutine来并发执行多个转账事务。
`n := 5`表示并发执行5次转账,每次都从`account1`向`account2`转账`amount`(即10)。
通过`for i := 0; i < n; i++ {`循环启动n个goroutine:
for i := 0; i < n; i++ {
		go func() {
		}()
}
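在`go func() {`中调用`store.TransferTx()`执行转账。注意不能直接在这个goroutine里用`testify`的`require`检查结果:`require`失败时会调用`t.FailNow()`,而它必须在运行测试函数`TestTransferTx`的goroutine中调用,在其他go routine中调用无法正确终止测试。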
正确的做法是通过channel把每个go routine的结果发送回运行测试的主go routine,再在那里统一检查。为此创建两个channel:
	errs := make(chan error)
	results := make(chan TransferTxResult)
channel用于在并发的goroutine之间安全地传递数据,使用`make`创建:`errs`用来接收错误,`results`用来接收转账结果。在`go func() {`中,把`err`发送到`errs`,把`result`发送到`results`:
errs <- err
results <- result
发送数据时,channel写在`<-`操作符的左边,要发送的数据写在右边。
在主go routine的检查循环中,则把channel写在`<-`的右边来接收数据(如`err := <-errs`),依次检查每个转账的结果。
最后,点击`run test`运行这个测试,或者点击`run package test`运行整个包的所有测试。