本文使用 Go 的 reflect 包和 MySQL 驱动,实现一个迷你 ORM。
思路参考 Rails 的 ActiveRecord(Ruby 中依赖 method_missing、define_method 等元编程手段),这里用反射实现 CRUD 和 transaction(事务)。
准备工作
执行 create database orm_db 创建数据库,然后在其中创建 user 表:
-- Table used by all examples in this article.
-- NOTE(review): the column COMMENTs label first_name as 姓 (family name) and
-- last_name as 名 (given name), yet the examples store given names ("Tom")
-- in first_name — confirm which meaning is intended.
CREATE TABLE `user` (
  `id` int(10) unsigned NOT NULL AUTO_INCREMENT COMMENT '自增主键',
  `age` smallint(10) unsigned NOT NULL DEFAULT 0 COMMENT '年龄',
  `first_name` varchar(45) NOT NULL DEFAULT '' COMMENT '姓',
  `last_name` varchar(45) NOT NULL DEFAULT '' COMMENT '名',
  `email` varchar(45) NOT NULL DEFAULT '' COMMENT '邮箱地址',
  `created_at` timestamp NOT NULL DEFAULT CURRENT_TIMESTAMP COMMENT '创建时间',
  -- maintained by MySQL on every row update
  `updated_at` timestamp NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP COMMENT '更新时间',
  PRIMARY KEY (`id`),
  KEY `idx_email` (`email`)
) ENGINE=InnoDB DEFAULT CHARSET=utf8 COMMENT='用户表';
与之对应的 struct:
// User mirrors one row of the `user` table. The json tags double as column
// names: the ORM reads them to build and scan SQL.
type User struct {
	ID        int64     `json:"id"`         // auto-increment primary key
	Age       int64     `json:"age"`        // age in years
	FirstName string    `json:"first_name"` // given name (e.g. "Tom")
	LastName  string    `json:"last_name"`  // family name (e.g. "Cat")
	Email     string    `json:"email"`      // email address
	CreatedAt time.Time `json:"created_at"` // creation time
	UpdatedAt time.Time `json:"updated_at"` // last update time
}
需要 import 的包:
package orm

import (
	"database/sql"
	"errors"
	"fmt"
	"log"
	"reflect"
	"sort"
	"strconv"
	"strings"
	"time"

	// register driver
	_ "github.com/go-sql-driver/mysql"
)
连接 database:
// Connect opens a MySQL connection pool for dsn, e.g.
// "user:password@tcp(127.0.0.1:3306)/dbname", configures the pool limits and
// verifies connectivity with Ping.
//
// On any failure it returns a nil *sql.DB. If Ping fails, the already-opened
// pool is closed first so it does not leak.
func Connect(dsn string) (*sql.DB, error) {
	conn, err := sql.Open("mysql", dsn)
	if err != nil {
		return nil, err
	}
	// Connection pool limits.
	conn.SetMaxOpenConns(100)
	conn.SetMaxIdleConns(10)
	conn.SetConnMaxLifetime(10 * time.Minute)
	if err := conn.Ping(); err != nil {
		// Best effort: the Ping error is the one worth reporting.
		_ = conn.Close()
		return nil, err
	}
	return conn, nil
}
定义一个 struct,它的作用类似其他语言中的 class:
// Query holds what is needed to build one SQL statement against one table.
type Query struct {
	db    *sql.DB // connection pool the statement runs on
	table string  // name of the target table
}
通过闭包为指定的表生成 Query:
//Table bind db and tablefunc Table(db *sql.DB, tableName string) func() *Query { return func() *Query { return &Query{ db: db, table: tableName, } } }
例如,为 orm_db 库中的 user 表生成 Query:
//全局变量ormDB和usersormDB, _ := Connect("user:password@tcp(127.0.0.1:3306)/orm_db") users := Table(ormDB, "user")//调用users().Insert(...)
准备工作到此完成,下面进入正题。
Insert方法
先分析一条 insert 语句:
-- Multi-row insert: one column list, then one parenthesized tuple per row.
insert into user (first_name, last_name) values ('Tom', 'Cat'), ('Tom', 'Cruise')
可以看出,关键在于从参数中提取 key(列名)和 value(值)。Insert 方法的前半部分如下:
//Insert in can be *User, []*User, map[string]interface{}func (q *Query) Insert(in interface{}) (int64, error) { var keys, values []string v := reflect.ValueOf(in) //剥离指针 for v.Kind() == reflect.Ptr { v = v.Elem() } switch v.Kind() { case reflect.Struct: keys, values = sKV(v) case reflect.Map: keys, values = mKV(v) case reflect.Slice: for i := 0; i < v.Len(); i++ { //Kind是切片时,可以用Index()方法遍历 sv := v.Index(i) for sv.Kind() == reflect.Ptr || sv.Kind() == reflect.Interface { sv = sv.Elem() } //切片元素不是struct或者指针,报错 if sv.Kind() != reflect.Struct { return 0, errors.New("method Insert error: in slice is not structs") } //keys只保存一次就行,因为后面的都一样了 if len(keys) == 0 { keys, values = sKV(sv) continue } _, val := sKV(sv) values = append(values, val...) } default: return 0, errors.New("method Insert error: type error") } //todo //...}
参数 in 可以是 User 类型的值、指针、切片或 map,需要借助反射的 Type 和 Value 来分析它。reflect.TypeOf() 返回的 Type 底层是一个 *rtype,即指向类型元数据的地址;而 Kind 表示的是值的原始(底层)类型,与声明的类型名不一定相同,例如:
// Kind reports the underlying ("primitive") kind of a type, not its declared
// name: the declared type here is myInt, but its Kind is int.
type myInt int

var i myInt
t := reflect.TypeOf(i)
k := t.Kind() // reflect.Int, not myInt
通过 Value 的 Interface() 方法可以取出字段的实际值。从 struct 中提取键值对的 sKV() 函数如下:
func sKV(v reflect.Value) ([]string, []string) { var keys, values []string t := v.Type() for n := 0; n < t.NumField(); n++ { tf := t.Field(n) vf := v.Field(n) //忽略非导出字段 if tf.Anonymous { continue } //忽略无效、零值字段 if !vf.IsValid() || reflect.DeepEqual(vf.Interface(), reflect.Zero(vf.Type()).Interface()) { continue } for vf.Type().Kind() == reflect.Ptr { vf = vf.Elem() } //有时候根据需求会组合struct,这里处理下,支持获取嵌套的struct tag和value //如果字段值是time类型之外的struct,递归获取keys和values if vf.Kind() == reflect.Struct && tf.Type.Name() != "Time" { cKeys, cValues := sKV(vf) keys = append(keys, cKeys...) values = append(values, cValues...) continue } //根据字段的json tag获取key,忽略无tag字段 key := strings.Split(tf.Tag.Get("json"), ",")[0] if key == "" { continue } value := format(vf) if value != "" { keys = append(keys, key) values = append(values, value) } } return keys, values }
sKV() 调用 format() 把字段值格式化为 SQL 字面量,其中 time.Time 类型需要特殊处理:
// sqlStringEscaper escapes the characters MySQL treats specially inside a
// quoted string literal. It assumes backslash escapes are enabled (MySQL's
// default; NOTE(review): not valid under sql_mode NO_BACKSLASH_ESCAPES).
var sqlStringEscaper = strings.NewReplacer(
	`\`, `\\`,
	`'`, `\'`,
	`"`, `\"`,
	"\x00", `\0`,
	"\n", `\n`,
	"\r", `\r`,
	"\x1a", `\Z`,
)

// format renders v as a MySQL literal:
//   - time.Time becomes FROM_UNIXTIME(<unix seconds>)
//   - strings are single-quoted and escaped
//   - signed/unsigned integers and floats are written in decimal, with floats
//     at full precision
//   - slices become "(a,b,...)" (used for IN lists)
//   - pointers and interfaces are unwrapped
//
// It returns "" for anything it cannot render: invalid or nil values,
// unexported fields, and unsupported kinds. Callers treat "" as "no value".
func format(v reflect.Value) string {
	if !v.IsValid() || !v.CanInterface() {
		return ""
	}
	// time.Time is converted to its Unix timestamp.
	if t, ok := v.Interface().(time.Time); ok {
		return fmt.Sprintf("FROM_UNIXTIME(%d)", t.Unix())
	}
	switch v.Kind() {
	case reflect.String:
		// Escape the value so quotes in the data cannot end the literal early.
		return "'" + sqlStringEscaper.Replace(v.String()) + "'"
	case reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Int:
		return strconv.FormatInt(v.Int(), 10)
	case reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uint:
		return strconv.FormatUint(v.Uint(), 10)
	case reflect.Float32, reflect.Float64:
		// -1 precision keeps every significant digit; 'f' avoids exponents.
		return strconv.FormatFloat(v.Float(), 'f', -1, v.Type().Bits())
	case reflect.Slice:
		// Format each element recursively into "(a,b,c)".
		values := make([]string, 0, v.Len())
		for i := 0; i < v.Len(); i++ {
			values = append(values, format(v.Index(i)))
		}
		return fmt.Sprintf(`(%s)`, strings.Join(values, ","))
	case reflect.Ptr:
		if v.IsNil() {
			return ""
		}
		return format(v.Elem())
	case reflect.Interface:
		// Unwrap one level; a nil interface yields an invalid Value, i.e. "".
		return format(v.Elem())
	}
	return ""
}
从 map 中提取键值对的 mKV():
func mKV(v reflect.Value) ([]string, []string) { var keys, values []string //获取map的key组成的切片 mapKeys := v.MapKeys() for _, key := range mapKeys { value := format(v.MapIndex(key)) if value != "" { values = append(values, value) keys = append(keys, key.Interface().(string)) } } return keys, values }
接下来补全 Insert 中 todo 的部分:
//Insert in can be User, *User, []User, []*User, map[string]interface{}func (q *Query) Insert(in interface{}) (int64, error) { //already done kl := len(keys) vl := len(values) if kl == 0 || vl == 0 { return 0, errors.New("method Insert error: no data") } var insertValue string //插入多条记录时需要用","拼接一下values if kl < vl { var tmpValues []string for kl <= vl { if kl%(len(keys)) == 0 { tmpValues = append(tmpValues, fmt.Sprintf("(%s)", strings.Join(values[kl-len(keys):kl], ","))) } kl++ } insertValue = strings.Join(tmpValues, ",") } else { insertValue = fmt.Sprintf("(%s)", strings.Join(values, ",")) } query := fmt.Sprintf(`insert into %s (%s) values %s`, q.table, strings.Join(keys, ","), insertValue) log.Printf("insert sql: %s", query) st, err := q.DB.Prepare(query) if err != nil { return 0, err } result, err := st.Exec() if err != nil { return 0, err } return result.LastInsertId() }
原理很简单,利用反射分析参数,取键值对,然后拼接sql语句,再通过mysql驱动入库。
调用示例:
user1 := &User{
	Age:       30,
	FirstName: "Tom",
	LastName:  "Cat",
}
user2 := User{
	Age:       30,
	FirstName: "Tom",
	LastName:  "Curise",
}
user3 := User{
	Age:       30,
	FirstName: "Tom",
	LastName:  "Hanks",
}
user4 := map[string]interface{}{
	"age":        30,
	"first_name": "Tom",
	"last_name":  "Zzy",
}
// A []interface{} may mix *User and User; all elements must set the same
// non-zero fields.
if _, err := users().Insert([]interface{}{user1, user2}); err != nil {
	log.Println(err)
}
if _, err := users().Insert(user3); err != nil {
	log.Println(err)
}
if _, err := users().Insert(user4); err != nil {
	log.Println(err)
}
“增”已经完成,接下来实现“查”。
Select方法
先分析一条 select 语句:
-- Target shape for Select: the chosen columns plus conditions joined by "and".
select id, age from user where first_name = 'Tom' and last_name = 'Cat'
select 语句由查询字段、where 条件等部分组成。我们希望先通过链式调用 Where() 等方法收集这些部分,最后在 Select() 中拼接并执行,调用方式如下(? 表示待传入的条件参数):
// Intended API: chain condition and paging methods, then Select fills dest.
// `?` marks where condition arguments (strings, structs or maps) go.
var user []User
users().Where(?).WhereNot(?).Limit(100).Offset(100).Order("id desc").Only("id", "age").Select(&user)
所以需要改造Query如下,增加属性用于暂存链式调用中添加的值:
// Query collects the parts of one statement through chained method calls;
// the terminal method (e.g. Select) assembles and runs the SQL.
type Query struct {
	db     *sql.DB  // connection pool
	table  string   // target table
	wheres []string // conditions, later joined with "and"
	only   []string // columns to select
	limit  string   // "limit N" clause, or ""
	offset string   // "offset N" clause, or ""
	order  string   // "order by ..." clause, or ""
	errs   []string // errors gathered while chaining, reported by the terminal method
}
Where() 的参数可以是字符串条件(如 "age > 10"),也可以是 struct 或 map。初版如下:
//Where args can be string, User, *User, map[string]interface{}func (q *Query) Where(wheres ...interface{}) *Query { for _, w := range wheres { v := reflect.ValueOf(w) for v.Kind() == reflect.Ptr { v = v.Elem() } switch v.Kind() { case reflect.String: q.wheres = append(q.wheres, w.(string)) case reflect.Struct: //todo case reflect.Map: //todo default: q.errs = append(q.errs, "method Where error: type error") } } return q }
WhereNot() 与 Where() 的逻辑几乎相同,只是把“等于”换成“不等于”,因此把公共部分抽成 where() 函数,用参数 eq 区分两者:
func where(eq bool, w interface{}) (string, error) { var keys, values []string v := reflect.ValueOf(w) for v.Kind() == reflect.Ptr { v = v.Elem() } switch v.Kind() { case reflect.String: return w.(string), nil case reflect.Struct: keys, values = sKV(v) case reflect.Map: keys, values = mKV(v) default: return "", errors.New("method Where error: type error") } if len(keys) != len(values) { return "", errors.New("method Where error: len(keys) not equal len(values))") } var wheres []string //之前的format()函数里,已经将切片类型值处理成"( , , ,)“形式 for idx, key := range keys { if eq { if strings.HasPrefix(values[idx], "(") && strings.HasSuffix(values[idx], ")") { wheres = append(wheres, fmt.Sprintf("%s in %s", key, values[idx])) continue } wheres = append(wheres, fmt.Sprintf("%s = %s", key, values[idx])) continue } if strings.HasPrefix(values[idx], "(") && strings.HasSuffix(values[idx], ")") { wheres = append(wheres, fmt.Sprintf("%s not in %s", key, values[idx])) continue } wheres = append(wheres, fmt.Sprintf("%s != %s", key, values[idx])) } return strings.Join(wheres, " and "), nil }
Where()方法最终变成:
//Where args can be string, User, *User, map[string]interface{}func (q *Query) Where(wheres ...interface{}) *Query { for _, w := range wheres { str, err := where(true, w) q.wheres = append(q.wheres, str) if err != nil { //因为需要达到链式调用的效果,所以把错误都搜集起来,最后再处理 q.errs = append(q.errs, err.Error()) } } return q }
Limit()、Offset()、Order()、Only() 都比较简单:
//Limit .func (q *Query) Limit(limit uint) *Query { q.limit = fmt.Sprintf("limit %d", limit) return q }//Offset .func (q *Query) Offset(offset uint) *Query { q.offset = fmt.Sprintf("offset %d", offset) return q }//Order .func (q *Query) Order(ord string) *Query { q.order = fmt.Sprintf("order by %s", ord) return q }//Only 指定需要查询的字段func (q *Query) Only(columns ...string) *Query { q.only = append(q.only, columns...) return q }
最后由 toSQL() 把收集到的各部分拼成完整的 select 语句:
func (q *Query) toSQL() string { var where string if len(q.wheres) > 0 { where = fmt.Sprintf(`where %s`, strings.Join(q.wheres, " and ")) } sqlStr := fmt.Sprintf(`select %s from %s %s %s %s %s`, strings.Join(q.only, ","), q.table, where, q.order, q.limit, q.offset) log.Printf("select sql: %s", sqlStr) return sqlStr }
Select() 的参数必须是一个可以写入的指针。以 struct 类型 User 的变量 user 为例:
var user User users.Select(&user)var userPtr *User users.Select(user)
这两种声明方式是不同的:后者只声明了一个值为 nil 的 *User 指针,并没有可供写入的 User 变量,因此是错误的用法。
综上,我们首先在 Select() 方法中做参数检查,确保传入的是一个正确的指针,并确保 only 属性有值:
//Select dest must be a ptr, e.g. *user, *[]user, *[]*user, *map, *[]map, *int, *[]intfunc (q *Query) Select(dest interface{}) error { if len(q.errs) != 0 { return errors.New(strings.Join(q.errs, " ")) } t := reflect.TypeOf(dest) v := reflect.ValueOf(dest) typeErr := errors.New("method Select error: type error") if t.Kind() != reflect.Ptr { return typeErr } //如果是用 var userPtr *User 方式声明的变量,则不可取址 if !v.Elem().CanAddr() { return typeErr } t = t.Elem() v = v.Elem() //如果only此时仍然为空,说明Only()方法未被调用,我们从struct上取tag填充 if len(q.only) == 0 { switch t.Kind() { case reflect.Struct: if t.Name() != "Time" { q.only = sK(v) } case reflect.Slice: //获取切片的基本类型给一个局部变量 t := t.Elem() if t.Kind() == reflect.Ptr { t = t.Elem() } if t.Kind() == reflect.Struct { if t.Name() != "Time" { q.only = sK(reflect.Zero(t)) } } } } if len(q.only) == 0 { return errors.New("method Select error: type error, no columns to select") } if t.Kind() != reflect.Slice { q.limit = "limit 1" } //todo}
这里只取struct的tag,不取value,我们定义一个新的sK()函数:
// sK returns the column names (json tag names) declared by the struct type of
// v, in field order. Nested struct fields other than time.Time, including
// pointers to structs, are flattened. Only the type is inspected, so zero
// values and nil pointer fields are fine.
//
// Embedded, unexported, untagged and "-"-tagged fields are skipped. A struct
// type already being walked is not entered again, so self-referential types
// terminate.
func sK(v reflect.Value) []string {
	return structKeys(v.Type(), make(map[reflect.Type]bool))
}

// structKeys collects the json-tag column names of struct type t. path holds
// the struct types on the current recursion path.
func structKeys(t reflect.Type, path map[reflect.Type]bool) []string {
	timeType := reflect.TypeOf(time.Time{})
	path[t] = true
	defer delete(path, t)
	var keys []string
	for n := 0; n < t.NumField(); n++ {
		tf := t.Field(n)
		if tf.Anonymous || tf.PkgPath != "" {
			continue
		}
		ft := tf.Type
		for ft.Kind() == reflect.Ptr {
			ft = ft.Elem()
		}
		if ft.Kind() == reflect.Struct && ft != timeType {
			if !path[ft] {
				keys = append(keys, structKeys(ft, path)...)
			}
			continue
		}
		key := strings.Split(tf.Tag.Get("json"), ",")[0]
		if key == "" || key == "-" {
			continue
		}
		keys = append(keys, key)
	}
	return keys
}
查询结果需要通过 rows.Scan() 写入各字段的地址,再用 Set() 赋给目标变量。先实现获取字段地址的 address() 函数:
// address returns scan targets inside the value dest points to. The result
// has one entry per entry of columns that has a matching field, in columns
// order, ready for rows.Scan.
//
// For a struct, each column is matched against the json tags of exported,
// non-embedded fields. Nested structs other than time.Time are searched too,
// and nil pointers to them are allocated. A pointer-typed leaf field is
// returned as a pointer to that pointer, so a NULL scans as nil. Columns
// without a matching field are omitted, so callers compare the result length
// with len(columns). Any other type (including time.Time) is its own single
// scan target.
func address(dest reflect.Value, columns []string) []interface{} {
	dest = dest.Elem()
	if dest.Kind() != reflect.Struct || dest.Type() == reflect.TypeOf(time.Time{}) {
		return []interface{}{dest.Addr().Interface()}
	}
	byColumn := make(map[string]interface{})
	collectFieldAddrs(dest, byColumn, make(map[reflect.Type]bool))
	// rows.Scan assigns positionally, so follow the selected column order,
	// not the struct field order.
	addrs := make([]interface{}, 0, len(columns))
	for _, col := range columns {
		if addr, ok := byColumn[col]; ok {
			addrs = append(addrs, addr)
		}
	}
	return addrs
}

// collectFieldAddrs maps the json-tag column names of the addressable struct
// v to pointers to the corresponding fields, recursing into nested structs.
// The first field seen for a column wins. path holds the struct types on the
// current recursion path, so self-referential types terminate.
func collectFieldAddrs(v reflect.Value, byColumn map[string]interface{}, path map[reflect.Type]bool) {
	timeType := reflect.TypeOf(time.Time{})
	t := v.Type()
	path[t] = true
	defer delete(path, t)
	for n := 0; n < t.NumField(); n++ {
		tf := t.Field(n)
		if tf.Anonymous || tf.PkgPath != "" {
			continue
		}
		vf := v.Field(n)
		ft := tf.Type
		for ft.Kind() == reflect.Ptr {
			ft = ft.Elem()
		}
		if ft.Kind() == reflect.Struct && ft != timeType {
			if path[ft] {
				continue
			}
			// Scan into the struct that is actually stored in the field,
			// allocating it when the field is a nil pointer.
			for vf.Kind() == reflect.Ptr {
				if vf.IsNil() {
					vf.Set(reflect.New(vf.Type().Elem()))
				}
				vf = vf.Elem()
			}
			collectFieldAddrs(vf, byColumn, path)
			continue
		}
		column := strings.Split(tf.Tag.Get("json"), ",")[0]
		if column == "" || column == "-" {
			continue
		}
		if _, dup := byColumn[column]; !dup {
			byColumn[column] = vf.Addr().Interface()
		}
	}
}
其中,Value.Addr() 用于获取地址,取址前可以用 Value.CanAddr() 判断是否可取址;reflect.New() 类似于 new,根据 Type 创建一个新值,并返回指向它的指针(同样是一个 Value),之后可以通过 Set() 赋值。目标是 map 时,使用 setMap():
//map的value类型必须是interface{},因为无类型信息,所以mysql驱动会返回一个字节切片,需要自行用[]byte断言func (q *Query) setMap(rows *sql.Rows, t reflect.Type) (reflect.Value, error) { if t.Elem().Kind() != reflect.Interface { return reflect.ValueOf(nil), errors.New("method setMap error: type error, must be map[string]interface{}") } m := reflect.MakeMap(t) addrs := make([]interface{}, len(q.only)) for idx := range q.only { addrs[idx] = new(interface{}) } if err := rows.Scan(addrs...); err != nil { return reflect.ValueOf(nil), err } for idx, column := range q.only { //从指针剥出interface{},再剥出实际值 m.SetMapIndex(reflect.ValueOf(column), reflect.ValueOf(addrs[idx]).Elem().Elem()) } return m, nil }
reflect.MakeMap() 相当于 make(),用于为 Kind 为 reflect.Map 的 Type 创建值。基础类型和 struct 则用 reflect.New 创建,见 setElem():
//适用于基类型和structfunc (q *Query) setElem(rows *sql.Rows, t reflect.Type) (reflect.Value, error) { addrsErr := errors.New("method setElem error: columns not match addresses") dest := reflect.New(t) addrs := address(dest, q.only) if len(q.only) != len(addrs) { return reflect.ValueOf(nil), addrsErr } if err := rows.Scan(addrs...); err != nil { return reflect.ValueOf(nil), err } return dest, nil}
这些函数完成后,就可以着手完善Select()里的todo部分了:
//already donerows, err := q.DB.Query(q.toSQL()) if err != nil { return err } switch t.Kind() { case reflect.Slice: dt := t.Elem() for dt.Kind() == reflect.Ptr { dt = dt.Elem() } sl := reflect.MakeSlice(t, 0, 0) for rows.Next() { var destination reflect.Value if dt.Kind() == reflect.Map { destination, err = q.setMap(rows, dt) } else { destination, err = q.setElem(rows, dt) } if err != nil { return err } //区分切片元素是否指针 switch t.Elem().Kind() { case reflect.Ptr, reflect.Map: sl = reflect.Append(sl, destination) default: sl = reflect.Append(sl, destination.Elem()) } } v.Set(sl) return nil case reflect.Map: for rows.Next() { m, err := q.setMap(rows, t) if err != nil { return err } v.Set(m) } return nil default: for rows.Next() { destination, err := q.setElem(rows, t) if err != nil { return err } v.Set(destination.Elem()) } } return nil
至此,Select()方法就大功告成了,部分调用方式示例:
var user User users() .Where("first_name = 'Tom'", map[string]interface{}{ "id": []int{1, 2, 3, 4}, }) .WhereNot(&User{LastName: "Cat"}) .Only("last_name") .Select(&user)var userMore []User users().Where("first_name = 'Tom'").Order("id desc").Select(&userMore)var userMoreP []*User users().Where("first_name = 'Tom'").Select(&userMoreP)var lastName string users().Where(&User{FirstName: "Tom"}).Only("last_name").Select(&lastName)var lastNames []string users().Where(map[string]interface{}{ "first_name": "Tom", }).Only("last_name").Select(&lastNames)var userM map[string]interface{} users().Where(&User{FirstName: "Tom"}).Only("last_name").Select(&userM)var userMS []map[string]interface{} users().Where("age > 10").Only("last_name", "age").Limit(100).Select(&userMS)
Update方法
分析update sql语句:
-- Update shape: SET assignments joined by commas, WHERE conditions joined by "and".
update user set first_name = "z", last_name = "zy" where first_name = "Tom" and last_name = "Curise"
比较简单,直接复用之前写的sKV()和mKV()函数:
//Update src can be *user, user, map[string]interface{}, stringfunc (q *Query) Update(src interface{}) (int64, error) { if len(q.errs) != 0 { return 0, errors.New(strings.Join(q.errs, " ")) } v := reflect.ValueOf(src) for v.Kind() == reflect.Ptr { v = v.Elem() } var toBeUpdated, where string var keys, values []string switch v.Kind() { case reflect.String: toBeUpdated = src.(string) case reflect.Struct: keys, values = sKV(v) case reflect.Map: keys, values = mKV(v) default: return 0, errors.New("method Update error: type error") } if toBeUpdated == "" { if len(keys) != len(values) { return 0, errors.New("method Update error: keys not match values") } var kvs []string for idx, key := range keys { kvs = append(kvs, fmt.Sprintf("%s = %s", key, values[idx])) } toBeUpdated = strings.Join(kvs, ",") } if len(q.wheres) > 0 { where = fmt.Sprintf(`where %s`, strings.Join(q.wheres, " and ")) } query := fmt.Sprintf("update %s set %s %s", q.table, toBeUpdated, where) st, err := q.DB.Prepare(query) if err != nil { return 0, err } result, err := st.Exec() if err != nil { return 0, err } return result.RowsAffected() }
调用方式:
u1 := "age = 100" // raw SET clause
u2 := map[string]interface{}{
	"age":        100,
	"first_name": "z",
	"last_name":  "zy",
}
u3 := &User{
	Age:       100,
	FirstName: "z",
	LastName:  "zy",
}
// The same update expressed three ways: string, map and struct.
for _, u := range []interface{}{u1, u2, u3} {
	if _, err := users().Where("age > 10").Update(u); err != nil {
		log.Println(err)
	}
}
Delete方法
这个最简单,没啥好说的:
//Delete no args func (q *Query) Delete() (int64, error) { if len(q.errs) != 0 { return 0, errors.New(strings.Join(q.errs, " ")) } var where string if len(q.wheres) > 0 { where = fmt.Sprintf(`where %s`, strings.Join(q.wheres, " and ")) } st, err := q.DB.Prepare(fmt.Sprintf(`delete from %s %s`, q.table, where)) if err != nil { return 0, err } result, err := st.Exec() if err != nil { return 0, err } return result.RowsAffected() }
删除id为1,2,3,4,并且age大于10的用户的调用方式:
// Delete users whose id is 1, 2, 3 or 4 and whose age is over 10.
w := map[string]interface{}{
	"id": []int{1, 2, 3, 4},
}
if _, err := users().Where(w, "age > 10").Delete(); err != nil {
	log.Println(err)
}
最后来实现事务:Transaction()。
Transaction函数
事务需要 begin、rollback、commit,还要用 recover 捕获 panic。*sql.DB.Begin() 返回的 *sql.Tx 和 *sql.DB 一样,都有 Query() 和 Prepare() 方法,因此可以定义一个接口同时代表两者:
//Dba *sql.DB or *sql.Txtype Dba interface { Query(string, ...interface{}) (*sql.Rows, error) Prepare(string) (*sql.Stmt, error) }
然后把 Query 中的 db 字段改为接口类型的 DB:
// Query will build a sql statement.
// Revised version: the *sql.DB field is replaced by DB of interface type Dba,
// so the same builder can run on a *sql.DB or a *sql.Tx. `...` elides the
// remaining fields (table, wheres, only, ...).
type Query struct {
	DB Dba
	...
}
相应地修改 Table(),使其可以接收事务句柄:
//Table bind db and tablefunc Table(db *sql.DB, tableName string) func(...Dba) *Query { return func(tx ...Dba) *Query { if len(tx) == 1 { return &Query{ DB: tx[0], table: tableName, } } return &Query{ DB: db, table: tableName, } } }
最后实现 Transaction():
//Transaction .func Transaction(db *sql.DB, f func(Dba) error) (err error) { tx, err := db.Begin() if err != nil { return err } defer func() { p := recover() if err != nil { if rerr := tx.Rollback(); rerr != nil { panic(rerr) } return } if p != nil { if rerr := tx.Rollback(); rerr != nil { panic(rerr) } err = fmt.Errorf("function Transaction error: %v", p) return } if cerr := tx.Commit(); cerr != nil { panic(cerr) } }() err = f(tx) return err }
第二个参数是一个接收事务句柄、返回 error 的函数。把需要在事务中执行的操作都封装在这个函数里,Transaction() 就能捕获其中所有的 panic 和 error,并据此回滚或提交事务。
调用方式示例:
unc doTx() error { ormDB, err := Connect("root@tcp(127.0.0.1:3306)/orm_db?parseTime=true&loc=Local") if err != nil { panic(err) } users := Table(ormDB, "user") args := something() //利用闭包传递变量 f := func(tx Dba) error { var id int //select语句无需在事务具柄上进行 if err := users().Where(args).Select(&id); err != nil { return err } //增删改需要在事务上进行 if _, err = users(tx).Insert(args); err != nil { return err } if _, err = users(tx).Update(args); err != nil { return err } if _, err = users(tx).Where(args).Delete(); err != nil { return err } return nil } //开始事务 if err := Transaction(ormDB, f); err != nil { return err } return nil}
到此,这个迷你orm的增删改查和事务功能全部都实现了,代码大概600行,比我预想的多了一倍。