reflectmysql
railsactive_recordmethod_missingdefine_methodCRUDtransaction

准备工作

create database orm_dbcreateuser
CREATE TABLE `user` (
  `id` int(10) unsigned NOT NULL AUTO_INCREMENT COMMENT '自增主键',
  `age` smallint(10) unsigned NOT NULL DEFAULT 0 COMMENT '年龄',
  `first_name` varchar(45) NOT NULL DEFAULT '' COMMENT '姓',
  `last_name` varchar(45) NOT NULL DEFAULT '' COMMENT '名',
  `email` varchar(45) NOT NULL DEFAULT '' COMMENT '邮箱地址',
  `created_at` timestamp NOT NULL DEFAULT CURRENT_TIMESTAMP COMMENT '创建时间',
  `updated_at` timestamp NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP COMMENT '更新时间',
  PRIMARY KEY (`id`),
  KEY `idx_email` (`email`)
) ENGINE=InnoDB DEFAULT CHARSET=utf8 COMMENT='用户表';
struct
type User struct {
    ID        int64     `json:"id"`         // 自增主键
    Age       int64     `json:"age"`        // 年龄
    FirstName string    `json:"first_name"` // 姓
    LastName  string    `json:"last_name"`  // 名
    Email     string    `json:"email"`      // 邮箱地址
    CreatedAt time.Time `json:"created_at"` // 创建时间
    UpdatedAt time.Time `json:"updated_at"` // 更新时间}
import
package ormimport (    "database/sql"
    
    //register driver
    _ "github.com/go-sql-driver/mysql")
database
//Connect db by dsn e.g. "user:password@tcp(127.0.0.1:3306)/dbname"func Connect(dsn string) (*sql.DB, error) {
    conn, err := sql.Open("mysql", dsn)    if err != nil {        return nil, err
    }    //设置连接池
    conn.SetMaxOpenConns(100)
    conn.SetMaxIdleConns(10)
    conn.SetConnMaxLifetime(10 * time.Minute)    return conn, conn.Ping()
}
structclass
//Query will build a sqltype Query struct {
    db      *sql.DB
    table   string}
Query
//Table bind db and tablefunc Table(db *sql.DB, tableName string) func() *Query {    return func() *Query {        return &Query{
            db:    db,
            table: tableName,
        }
    }
}
Queryorm_dbuser
//全局变量ormDB和usersormDB, _ := Connect("user:password@tcp(127.0.0.1:3306)/orm_db")
users := Table(ormDB, "user")//调用users().Insert(...)

准备工作到此完成,下面进入正题。

Insert方法

insert
insert into user (first_name, last_name) values ('Tom', 'Cat'), ('Tom', 'Cruise')
keyvalueInsert
//Insert in can be *User, []*User, map[string]interface{}func (q *Query) Insert(in interface{}) (int64, error) {    var keys, values []string
    v := reflect.ValueOf(in)    //剥离指针
    for v.Kind() == reflect.Ptr {
        v = v.Elem()
    }    switch v.Kind() {    case reflect.Struct:
        keys, values = sKV(v)    case reflect.Map:
        keys, values = mKV(v)    case reflect.Slice:        for i := 0; i < v.Len(); i++ {            //Kind是切片时,可以用Index()方法遍历
            sv := v.Index(i)            for sv.Kind() == reflect.Ptr || sv.Kind() == reflect.Interface {
                sv = sv.Elem()
            }            //切片元素不是struct或者指针,报错
            if sv.Kind() != reflect.Struct {                return 0, errors.New("method Insert error: in slice is not structs")
            }            //keys只保存一次就行,因为后面的都一样了
            if len(keys) == 0 {
                keys, values = sKV(sv)                continue
            }
            _, val := sKV(sv)
            values = append(values, val...)
        }    default:        return 0, errors.New("method Insert error: type error")
    }    //todo
    //...}
inUser类型TypeValue*rtype*rtype*rtype地址元数据Kind原始类型
type myInt int
var i myInt
t := reflect.TypeOf(i)
k := t.Kind()
Interface()sKV()
func sKV(v reflect.Value) ([]string, []string) {
    var keys, values []string
    t := v.Type()    for n := 0; n < t.NumField(); n++ {
        tf := t.Field(n)
        vf := v.Field(n)        //忽略非导出字段
        if tf.Anonymous {            continue
        }        //忽略无效、零值字段
        if !vf.IsValid() || reflect.DeepEqual(vf.Interface(), reflect.Zero(vf.Type()).Interface()) {            continue
        }        for vf.Type().Kind() == reflect.Ptr {
            vf = vf.Elem()
        }        //有时候根据需求会组合struct,这里处理下,支持获取嵌套的struct tag和value
        //如果字段值是time类型之外的struct,递归获取keys和values
        if vf.Kind() == reflect.Struct && tf.Type.Name() != "Time" {
            cKeys, cValues := sKV(vf)
            keys = append(keys, cKeys...)
            values = append(values, cValues...)            continue
        }        //根据字段的json tag获取key,忽略无tag字段
        key := strings.Split(tf.Tag.Get("json"), ",")[0]        if key == "" {            continue
        }
        value := format(vf)        if value != "" {
            keys = append(keys, key)
            values = append(values, value)
        }
    }    return keys, values
}
sKV()format()time.Time
func format(v reflect.Value) string {    //断言出time类型直接转unix时间戳
    if t, ok := v.Interface().(time.Time); ok {        return fmt.Sprintf("FROM_UNIXTIME(%d)", t.Unix())
    }    switch v.Kind() {    case reflect.String:        return fmt.Sprintf(`'%s'`, v.Interface())    case reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Int:        return fmt.Sprintf(`%d`, v.Interface())    case reflect.Float32, reflect.Float64:        return fmt.Sprintf(`%f`, v.Interface())    //如果是切片类型,遍历元素,递归格式化成"(, , , )"形式
    case reflect.Slice:        var values []string        for i := 0; i < v.Len(); i++ {
            values = append(values, format(v.Index(i)))
        }        return fmt.Sprintf(`(%s)`, strings.Join(values, ","))    //接口类型剥一层递归
    case reflect.Interface:        return format(v.Elem())
    }    return ""}
mKV()
func mKV(v reflect.Value) ([]string, []string) {
    var keys, values []string
    //获取map的key组成的切片
    mapKeys := v.MapKeys()    for _, key := range mapKeys {
        value := format(v.MapIndex(key))        if value != "" {
            values = append(values, value)
            keys = append(keys, key.Interface().(string))
        }
    }    return keys, values
}
todo
//Insert in can be User, *User, []User, []*User, map[string]interface{}func (q *Query) Insert(in interface{}) (int64, error) {    //already done
    kl := len(keys)
    vl := len(values)    if kl == 0 || vl == 0 {        return 0, errors.New("method Insert error: no data")
    }    var insertValue string    //插入多条记录时需要用","拼接一下values
    if kl < vl {        var tmpValues []string        for kl <= vl {            if kl%(len(keys)) == 0 {
                tmpValues = append(tmpValues, fmt.Sprintf("(%s)", strings.Join(values[kl-len(keys):kl], ",")))
            }
            kl++
        }
        insertValue = strings.Join(tmpValues, ",")
    } else {
        insertValue = fmt.Sprintf("(%s)", strings.Join(values, ","))
    }
    query := fmt.Sprintf(`insert into %s (%s) values %s`, q.table, strings.Join(keys, ","), insertValue)
    log.Printf("insert sql: %s", query)
    st, err := q.DB.Prepare(query)    if err != nil {        return 0, err
    }
    result, err := st.Exec()    if err != nil {        return 0, err
    }    return result.LastInsertId()
}

原理很简单,利用反射分析参数,取键值对,然后拼接sql语句,再通过mysql驱动入库。
调用示例:

user1 := &User{
    Age:       30,
    FirstName: "Tom",
    LastName:  "Cat",
}
user2 := User{
    Age:       30,
    FirstName: "Tom",
    LastName:  "Curise",
}
user3 := User{
    Age:       30,
    FirstName: "Tom",
    LastName:  "Hanks",
}
user4 := map[string]interface{}{    "age":        30,    "first_name": "Tom",    "last_name":  "Zzy",
}
users().Insert([]interface{}{user1, user2})
users().Insert(user3)
users().Insert(user4)
增查

Select方法

select
select id, age from user where first_name = 'Tom' and last_name = 'Cat'
selectwhereWhere()Select()
var user []Userusers().Where(?).WhereNot(?).Limit(100).Offset(100).Order("id desc").Only("id", "age").Select(&user)

所以需要改造Query如下,增加属性用于暂存链式调用中添加的值:

//Query will build a sqltype Query struct {
    db     *sql.DB
    table  string
    wheres []string
    only   []string
    limit  string
    offset string
    order  string
    errs   []string}
"age > 10"
//Where args can be string, User, *User, map[string]interface{}func (q *Query) Where(wheres ...interface{}) *Query {    for _, w := range wheres {
        v := reflect.ValueOf(w)        for v.Kind() == reflect.Ptr {
            v = v.Elem()
        }        switch v.Kind() {        case reflect.String:
            q.wheres = append(q.wheres, w.(string))        case reflect.Struct:            //todo
        case reflect.Map:            //todo
        default:
            q.errs = append(q.errs, "method Where error: type error")
        }
    }    return q
}
WhereNot()where()
func where(eq bool, w interface{}) (string, error) {
    var keys, values []string
    v := reflect.ValueOf(w)    for v.Kind() == reflect.Ptr {
        v = v.Elem()
    }    switch v.Kind() {    case reflect.String:        return w.(string), nil    case reflect.Struct:
        keys, values = sKV(v)    case reflect.Map:
        keys, values = mKV(v)    default:        return "", errors.New("method Where error: type error")
    }    if len(keys) != len(values) {        return "", errors.New("method Where error: len(keys) not equal len(values))")
    }
    var wheres []string
    //之前的format()函数里,已经将切片类型值处理成"( , , ,)“形式
    for idx, key := range keys {        if eq {            if strings.HasPrefix(values[idx], "(") && strings.HasSuffix(values[idx], ")") {
                wheres = append(wheres, fmt.Sprintf("%s in %s", key, values[idx]))                continue
            }
            wheres = append(wheres, fmt.Sprintf("%s = %s", key, values[idx]))            continue
        }        if strings.HasPrefix(values[idx], "(") && strings.HasSuffix(values[idx], ")") {
            wheres = append(wheres, fmt.Sprintf("%s not in %s", key, values[idx]))            continue
        }
        wheres = append(wheres, fmt.Sprintf("%s != %s", key, values[idx]))
    }    return strings.Join(wheres, " and "), nil
}

Where()方法最终变成:

//Where args can be string, User, *User, map[string]interface{}func (q *Query) Where(wheres ...interface{}) *Query {    for _, w := range wheres {
        str, err := where(true, w)
        q.wheres = append(q.wheres, str)        if err != nil {            //因为需要达到链式调用的效果,所以把错误都搜集起来,最后再处理
            q.errs = append(q.errs, err.Error())
        }
    }    return q
}
Limit()Offset()Order()Only()
//Limit .func (q *Query) Limit(limit uint) *Query {
    q.limit = fmt.Sprintf("limit %d", limit)    return q
}//Offset .func (q *Query) Offset(offset uint) *Query {
    q.offset = fmt.Sprintf("offset %d", offset)    return q
}//Order .func (q *Query) Order(ord string) *Query {
    q.order = fmt.Sprintf("order by %s", ord)    return q
}//Only 指定需要查询的字段func (q *Query) Only(columns ...string) *Query {
    q.only = append(q.only, columns...)    return q
}
toSQL()
func (q *Query) toSQL() string {
    var where string    if len(q.wheres) > 0 {        where = fmt.Sprintf(`where %s`, strings.Join(q.wheres, " and "))
    }
    sqlStr := fmt.Sprintf(`select %s from %s %s %s %s %s`, strings.Join(q.only, ","), q.table, where, q.order, q.limit, q.offset)
    log.Printf("select sql: %s", sqlStr)    return sqlStr
}
structuserUser
var user User
users.Select(&user)var userPtr *User
users.Select(user)

这两种声明方式是不同的,后者只声明了一个指针类型,是错误的。
综上,我们首先为Select()方法做一下的参数检查,确保传入值是一个正确的指针,并确保only属性有值:

//Select dest must be a ptr, e.g. *user, *[]user, *[]*user, *map, *[]map, *int, *[]intfunc (q *Query) Select(dest interface{}) error {    if len(q.errs) != 0 {        return errors.New(strings.Join(q.errs, "
"))
    }
    t := reflect.TypeOf(dest)
    v := reflect.ValueOf(dest)
    typeErr := errors.New("method Select error: type error")    if t.Kind() != reflect.Ptr {        return typeErr
    }    //如果是用 var userPtr *User 方式声明的变量,则不可取址
    if !v.Elem().CanAddr() {        return typeErr
    }
    t = t.Elem()
    v = v.Elem()    //如果only此时仍然为空,说明Only()方法未被调用,我们从struct上取tag填充
    if len(q.only) == 0 {        switch t.Kind() {        case reflect.Struct:            if t.Name() != "Time" {
                q.only = sK(v)
            }        case reflect.Slice:            //获取切片的基本类型给一个局部变量
            t := t.Elem()            if t.Kind() == reflect.Ptr {
                t = t.Elem()
            }            if t.Kind() == reflect.Struct {                if t.Name() != "Time" {
                    q.only = sK(reflect.Zero(t))
                }
            }
        }
    }    if len(q.only) == 0 {        return errors.New("method Select error: type error, no columns to select")
    }    if t.Kind() != reflect.Slice {
        q.limit = "limit 1"
    }    //todo}

这里只取struct的tag,不取value,我们定义一个新的sK()函数:

func sK(v reflect.Value) []string {
    var keys []string
    t := v.Type()    for n := 0; n < t.NumField(); n++ {
        tf := t.Field(n)
        vf := v.Field(n)        //忽略非导出字段
        if tf.Anonymous {            continue
        }        for vf.Type().Kind() == reflect.Ptr {
            vf = vf.Elem()
        }        //如果字段值是time类型之外的struct,递归获取keys
        if vf.Kind() == reflect.Struct && tf.Type.Name() != "Time" {
            keys = append(keys, sK(vf)...)            continue
        }        //根据字段的json tag获取key,忽略无tag字段
        key := strings.Split(tf.Tag.Get("json"), ",")[0]        if key == "" {            continue
        }
        keys = append(keys, key)
    }    return keys
}
Scan()Set()address()
func address(dest reflect.Value, columns []string) []interface{} {
    dest = dest.Elem()
    t := dest.Type()
    addrs := make([]interface{}, 0)    switch t.Kind() {    case reflect.Struct:        for n := 0; n < t.NumField(); n++ {
            tf := t.Field(n)
            vf := dest.Field(n)            if tf.Anonymous {                continue
            }            for vf.Type().Kind() == reflect.Ptr {
                vf = vf.Elem()
            }            //如果字段值是time类型之外的struct,递归取址
            if vf.Kind() == reflect.Struct && tf.Type.Name() != "Time" {
                nVf := reflect.New(vf.Type())
                vf.Set(nVf.Elem())
                addrs = append(addrs, address(nVf, columns)...)                continue
            }
            column := strings.Split(tf.Tag.Get("json"), ",")[0]            if column == "" {                continue
            }            //只取选定的字段的地址
            for _, col := range columns {                if col == column {
                    addrs = append(addrs, vf.Addr().Interface())                    break
                }
            }
        }    default:
        addrs = append(addrs, dest.Addr().Interface())
    }    return addrs
}
Value.Addr()Value.CanAddr()relfect.New()TypenewValue指针Set()setMap()
//map的value类型必须是interface{},因为无类型信息,所以mysql驱动会返回一个字节切片,需要自行用[]byte断言func (q *Query) setMap(rows *sql.Rows, t reflect.Type) (reflect.Value, error) {    if t.Elem().Kind() != reflect.Interface {        return reflect.ValueOf(nil), errors.New("method setMap error: type error, must be map[string]interface{}")
    }
    m := reflect.MakeMap(t)
    addrs := make([]interface{}, len(q.only))    for idx := range q.only {
        addrs[idx] = new(interface{})
    }    if err := rows.Scan(addrs...); err != nil {        return reflect.ValueOf(nil), err
    }    for idx, column := range q.only {        //从指针剥出interface{},再剥出实际值
        m.SetMapIndex(reflect.ValueOf(column), reflect.ValueOf(addrs[idx]).Elem().Elem())
    }    return m, nil
}
reflect.MakeMap()make()Kindreflect.MapTypenewsetElem()
//适用于基类型和structfunc (q *Query) setElem(rows *sql.Rows, t reflect.Type) (reflect.Value, error) {
    addrsErr := errors.New("method setElem error: columns not match addresses")
    dest := reflect.New(t)
    addrs := address(dest, q.only)    if len(q.only) != len(addrs) {        return reflect.ValueOf(nil), addrsErr
    }    if err := rows.Scan(addrs...); err != nil {        return reflect.ValueOf(nil), err
    }    return dest, nil}

这些函数完成后,就可以着手完善Select()里的todo部分了:

//already donerows, err := q.DB.Query(q.toSQL())    if err != nil {        return err
    }    switch t.Kind() {    case reflect.Slice:
        dt := t.Elem()        for dt.Kind() == reflect.Ptr {
            dt = dt.Elem()
        }
        sl := reflect.MakeSlice(t, 0, 0)        for rows.Next() {
            var destination reflect.Value            if dt.Kind() == reflect.Map {
                destination, err = q.setMap(rows, dt)
            } else {
                destination, err = q.setElem(rows, dt)
            }            if err != nil {                return err
            }            //区分切片元素是否指针
            switch t.Elem().Kind() {            case reflect.Ptr, reflect.Map:
                sl = reflect.Append(sl, destination)            default:
                sl = reflect.Append(sl, destination.Elem())
            }
        }
        v.Set(sl)        return nil
    case reflect.Map:        for rows.Next() {
            m, err := q.setMap(rows, t)            if err != nil {                return err
            }
            v.Set(m)
        }        return nil
    default:        for rows.Next() {
            destination, err := q.setElem(rows, t)            if err != nil {                return err
            }
            v.Set(destination.Elem())
        }
    }    return nil

至此,Select()方法就大功告成了,部分调用方式示例:

var user User
users()
.Where("first_name = 'Tom'", map[string]interface{}{    "id": []int{1, 2, 3, 4},
})
.WhereNot(&User{LastName: "Cat"})
.Only("last_name")
.Select(&user)var userMore []User
users().Where("first_name = 'Tom'").Order("id desc").Select(&userMore)var userMoreP []*User
users().Where("first_name = 'Tom'").Select(&userMoreP)var lastName string
users().Where(&User{FirstName: "Tom"}).Only("last_name").Select(&lastName)var lastNames []string
users().Where(map[string]interface{}{    "first_name": "Tom",
}).Only("last_name").Select(&lastNames)var userM map[string]interface{}
users().Where(&User{FirstName: "Tom"}).Only("last_name").Select(&userM)var userMS []map[string]interface{}
users().Where("age > 10").Only("last_name", "age").Limit(100).Select(&userMS)

Update方法

分析update sql语句:

update user set first_name = "z", last_name = "zy" where first_name = "Tom" and last_name = "Curise"

比较简单,直接复用之前写的sKV()和mKV()函数:

//Update src can be *user, user, map[string]interface{}, stringfunc (q *Query) Update(src interface{}) (int64, error) {    if len(q.errs) != 0 {        return 0, errors.New(strings.Join(q.errs, "
"))
    }
    v := reflect.ValueOf(src)    for v.Kind() == reflect.Ptr {
        v = v.Elem()
    }    var toBeUpdated, where string    var keys, values []string    switch v.Kind() {    case reflect.String:
        toBeUpdated = src.(string)    case reflect.Struct:
        keys, values = sKV(v)    case reflect.Map:
        keys, values = mKV(v)    default:        return 0, errors.New("method Update error: type error")
    }    if toBeUpdated == "" {        if len(keys) != len(values) {            return 0, errors.New("method Update error: keys not match values")
        }        var kvs []string        for idx, key := range keys {
            kvs = append(kvs, fmt.Sprintf("%s = %s", key, values[idx]))
        }
        toBeUpdated = strings.Join(kvs, ",")
    }    if len(q.wheres) > 0 {
        where = fmt.Sprintf(`where %s`, strings.Join(q.wheres, " and "))
    }
    query := fmt.Sprintf("update %s set %s %s", q.table, toBeUpdated, where)
    st, err := q.DB.Prepare(query)    if err != nil {        return 0, err
    }
    result, err := st.Exec()    if err != nil {        return 0, err
    }    return result.RowsAffected()
}

调用方式:

u1 := "age = 100"u2 := map[string]interface{}{    "age":        100,    "first_name": "z",    "last_name":  "zy",
}
u3 := &User{
    Age:       100,
    FirstName: "z",
    LastName:  "zy",
}
_, _ = users().Where("age > 10").Update(u1)
_, _ = users().Where("age > 10").Update(u2)
_, _ = users().Where("age > 10").Update(u3)

Delete方法

这个最简单,没啥好说的:

//Delete no args
func (q *Query) Delete() (int64, error) {    if len(q.errs) != 0 {        return 0, errors.New(strings.Join(q.errs, "
"))
    }
    var where string    if len(q.wheres) > 0 {        where = fmt.Sprintf(`where %s`, strings.Join(q.wheres, " and "))
    }
    st, err := q.DB.Prepare(fmt.Sprintf(`delete from %s %s`, q.table, where))    if err != nil {        return 0, err
    }
    result, err := st.Exec()    if err != nil {        return 0, err
    }    return result.RowsAffected()
}

删除id为1,2,3,4,并且age大于10的用户的调用方式:

w := map[string]interface{}{    "id": []int{1, 2, 3, 4},
}
_, _ = users().Where(w, "age > 10").Delete()
Transaction()

Transaction函数

beginrollbackcommitrecover*sql.DB.Begin()Query()Prepare()
//Dba *sql.DB or *sql.Txtype Dba interface {
    Query(string, ...interface{}) (*sql.Rows, error)
    Prepare(string) (*sql.Stmt, error)
}
QueryDB
//Query will build a sqltype Query struct {
    DB     Dba
    ...
}
Table()
//Table bind db and tablefunc Table(db *sql.DB, tableName string) func(...Dba) *Query {    return func(tx ...Dba) *Query {        if len(tx) == 1 {            return &Query{
                DB:    tx[0],
                table: tableName,
            }
        }        return &Query{
            DB:    db,
            table: tableName,
        }
    }
}
Transaction()
//Transaction .func Transaction(db *sql.DB, f func(Dba) error) (err error) {
    tx, err := db.Begin()    if err != nil {        return err
    }
    defer func() {
        p := recover()        if err != nil {            if rerr := tx.Rollback(); rerr != nil {
                panic(rerr)
            }            return
        }        if p != nil {            if rerr := tx.Rollback(); rerr != nil {
                panic(rerr)
            }
            err = fmt.Errorf("function Transaction error: %v", p)            return
        }        if cerr := tx.Commit(); cerr != nil {
            panic(cerr)
        }
    }()
    err = f(tx)    return err
}

第二个参数是一个接受事务具柄,返回error的函数,我们将需要事务的操作全部封装在这个函数里,就能抓到所有的panic和error。
调用方式示例:

unc doTx() error {
    ormDB, err := Connect("root@tcp(127.0.0.1:3306)/orm_db?parseTime=true&loc=Local")    if err != nil {
        panic(err)
    }
    users := Table(ormDB, "user")
    args := something()    //利用闭包传递变量
    f := func(tx Dba) error {
        var id int
        //select语句无需在事务具柄上进行
        if err := users().Where(args).Select(&id); err != nil {            return err
        }        //增删改需要在事务上进行
        if _, err = users(tx).Insert(args); err != nil {            return err
        }        if _, err = users(tx).Update(args); err != nil {            return err
        }        if _, err = users(tx).Where(args).Delete(); err != nil {            return err
        }        return nil
    }    //开始事务
    if err := Transaction(ormDB, f); err != nil {        return err
    }    return nil}

到此,这个迷你orm的增删改查和事务功能全部都实现了,代码大概600行,比我预想的多了一倍。