关键词:reflect、MySQL、Rails ActiveRecord、method_missing、define_method、CRUD、transaction

准备工作
-- Create the database and the `user` table used throughout the examples.
-- utf8mb4 is used instead of MySQL's legacy 3-byte `utf8` (utf8mb3), which
-- cannot store 4-byte characters such as emoji.
CREATE DATABASE orm_db;
USE orm_db;

CREATE TABLE `user` (
  `id`         int(10) unsigned     NOT NULL AUTO_INCREMENT COMMENT '自增主键',
  `age`        smallint(10) unsigned NOT NULL DEFAULT 0 COMMENT '年龄',
  `first_name` varchar(45)          NOT NULL DEFAULT '' COMMENT '姓',
  `last_name`  varchar(45)          NOT NULL DEFAULT '' COMMENT '名',
  `email`      varchar(45)          NOT NULL DEFAULT '' COMMENT '邮箱地址',
  `created_at` timestamp            NOT NULL DEFAULT CURRENT_TIMESTAMP COMMENT '创建时间',
  `updated_at` timestamp            NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP COMMENT '更新时间',
  PRIMARY KEY (`id`),
  KEY `idx_email` (`email`)
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COMMENT='用户表';
structtype User struct {
ID int64 `json:"id"` // 自增主键
Age int64 `json:"age"` // 年龄
FirstName string `json:"first_name"` // 姓
LastName string `json:"last_name"` // 名
Email string `json:"email"` // 邮箱地址
CreatedAt time.Time `json:"created_at"` // 创建时间
UpdatedAt time.Time `json:"updated_at"` // 更新时间}importpackage ormimport ( "database/sql" //register driver _ "github.com/go-sql-driver/mysql")
// Connect opens a MySQL connection pool for dsn, e.g.
// "user:password@tcp(127.0.0.1:3306)/dbname?parseTime=true&loc=Local",
// configures the pool and verifies connectivity with Ping.
// Include parseTime=true when scanning DATETIME/TIMESTAMP columns into
// time.Time fields (see the go-sql-driver/mysql DSN documentation).
// On any failure the pool is closed and a nil *sql.DB is returned.
func Connect(dsn string) (*sql.DB, error) {
	conn, err := sql.Open("mysql", dsn)
	if err != nil {
		return nil, err
	}
	// Connection pool settings.
	conn.SetMaxOpenConns(100)
	conn.SetMaxIdleConns(10)
	conn.SetConnMaxLifetime(10 * time.Minute)
	if err := conn.Ping(); err != nil {
		_ = conn.Close() // best effort; the Ping error is the one worth reporting
		return nil, err
	}
	return conn, nil
}

// Query accumulates the parts of one SQL statement against a single table.
// DB is exported because the CRUD methods access it as q.DB.
type Query struct {
	DB    *sql.DB
	table string
}

// Table binds db and tableName and returns a factory. Each call of the
// factory yields a fresh *Query, so chained conditions never leak from one
// statement into the next.
func Table(db *sql.DB, tableName string) func() *Query {
	return func() *Query {
		return &Query{
			DB:    db,
			table: tableName,
		}
	}
}

// Package-level handles used by the examples: ormDB and users.
// Check connErr (e.g. at the start of main) before use; ormDB is nil when
// the connection could not be established.
var (
	ormDB, connErr = Connect("user:password@tcp(127.0.0.1:3306)/orm_db?parseTime=true&loc=Local")
	// Usage: users().Insert(...)
	users = Table(ormDB, "user")
)

// Preparation is complete; on to the main topic.
Insert方法
insertinsert into user (first_name, last_name) values ('Tom', 'Cat'), ('Tom', 'Cruise')keyvalueInsert//Insert in can be *User, []*User, map[string]interface{}func (q *Query) Insert(in interface{}) (int64, error) { var keys, values []string
v := reflect.ValueOf(in) //剥离指针
for v.Kind() == reflect.Ptr {
v = v.Elem()
} switch v.Kind() { case reflect.Struct:
keys, values = sKV(v) case reflect.Map:
keys, values = mKV(v) case reflect.Slice: for i := 0; i < v.Len(); i++ { //Kind是切片时,可以用Index()方法遍历
sv := v.Index(i) for sv.Kind() == reflect.Ptr || sv.Kind() == reflect.Interface {
sv = sv.Elem()
} //切片元素不是struct或者指针,报错
if sv.Kind() != reflect.Struct { return 0, errors.New("method Insert error: in slice is not structs")
} //keys只保存一次就行,因为后面的都一样了
if len(keys) == 0 {
keys, values = sKV(sv) continue
}
_, val := sKV(sv)
values = append(values, val...)
} default: return 0, errors.New("method Insert error: type error")
} //todo
//...}inUser类型TypeValue*rtype*rtype*rtype地址元数据Kind原始类型type myInt int var i myInt t := reflect.TypeOf(i) k := t.Kind()
Interface()sKV()func sKV(v reflect.Value) ([]string, []string) {
var keys, values []string
t := v.Type() for n := 0; n < t.NumField(); n++ {
tf := t.Field(n)
vf := v.Field(n) //忽略非导出字段
if tf.Anonymous { continue
} //忽略无效、零值字段
if !vf.IsValid() || reflect.DeepEqual(vf.Interface(), reflect.Zero(vf.Type()).Interface()) { continue
} for vf.Type().Kind() == reflect.Ptr {
vf = vf.Elem()
} //有时候根据需求会组合struct,这里处理下,支持获取嵌套的struct tag和value
//如果字段值是time类型之外的struct,递归获取keys和values
if vf.Kind() == reflect.Struct && tf.Type.Name() != "Time" {
cKeys, cValues := sKV(vf)
keys = append(keys, cKeys...)
values = append(values, cValues...) continue
} //根据字段的json tag获取key,忽略无tag字段
key := strings.Split(tf.Tag.Get("json"), ",")[0] if key == "" { continue
}
value := format(vf) if value != "" {
keys = append(keys, key)
values = append(values, value)
}
} return keys, values
}sKV()format()time.Timefunc format(v reflect.Value) string { //断言出time类型直接转unix时间戳
if t, ok := v.Interface().(time.Time); ok { return fmt.Sprintf("FROM_UNIXTIME(%d)", t.Unix())
} switch v.Kind() { case reflect.String: return fmt.Sprintf(`'%s'`, v.Interface()) case reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Int: return fmt.Sprintf(`%d`, v.Interface()) case reflect.Float32, reflect.Float64: return fmt.Sprintf(`%f`, v.Interface()) //如果是切片类型,遍历元素,递归格式化成"(, , , )"形式
case reflect.Slice: var values []string for i := 0; i < v.Len(); i++ {
values = append(values, format(v.Index(i)))
} return fmt.Sprintf(`(%s)`, strings.Join(values, ",")) //接口类型剥一层递归
case reflect.Interface: return format(v.Elem())
} return ""}mKV()func mKV(v reflect.Value) ([]string, []string) {
var keys, values []string
//获取map的key组成的切片
mapKeys := v.MapKeys() for _, key := range mapKeys {
value := format(v.MapIndex(key)) if value != "" {
values = append(values, value)
keys = append(keys, key.Interface().(string))
}
} return keys, values
}todo//Insert in can be User, *User, []User, []*User, map[string]interface{}func (q *Query) Insert(in interface{}) (int64, error) { //already done
kl := len(keys)
vl := len(values) if kl == 0 || vl == 0 { return 0, errors.New("method Insert error: no data")
} var insertValue string //插入多条记录时需要用","拼接一下values
if kl < vl { var tmpValues []string for kl <= vl { if kl%(len(keys)) == 0 {
tmpValues = append(tmpValues, fmt.Sprintf("(%s)", strings.Join(values[kl-len(keys):kl], ",")))
}
kl++
}
insertValue = strings.Join(tmpValues, ",")
} else {
insertValue = fmt.Sprintf("(%s)", strings.Join(values, ","))
}
query := fmt.Sprintf(`insert into %s (%s) values %s`, q.table, strings.Join(keys, ","), insertValue)
log.Printf("insert sql: %s", query)
st, err := q.DB.Prepare(query) if err != nil { return 0, err
}
result, err := st.Exec() if err != nil { return 0, err
} return result.LastInsertId()
}原理很简单,利用反射分析参数,取键值对,然后拼接sql语句,再通过mysql驱动入库。
调用示例:
// A pointer to a struct, a struct value and a map are all accepted.
user1 := &User{
	Age:       30,
	FirstName: "Tom",
	LastName:  "Cat",
}
user2 := User{
	Age:       30,
	FirstName: "Tom",
	LastName:  "Curise",
}
user3 := User{
	Age:       30,
	FirstName: "Tom",
	LastName:  "Hanks",
}
// Map keys are used directly as column names.
user4 := map[string]interface{}{
	"age":        30,
	"first_name": "Tom",
	"last_name":  "Zzy",
}
// A slice may mix *User and User; all of its rows go into one insert
// statement. Return values are ignored for brevity.
users().Insert([]interface{}{user1, user2})
users().Insert(user3)
users().Insert(user4)

// Next: reading rows back with the Select method.
select 语句形如:
select id, age from user where first_name = 'Tom' and last_name = 'Cat'
其中的 where 条件等部分通过 Where() 等方法链式添加,最后由 Select() 执行查询并写入结果,期望的调用方式如下:
var user []User
users().Where(?).WhereNot(?).Limit(100).Offset(100).Order("id desc").Only("id", "age").Select(&user)
所以需要改造Query如下,增加属性用于暂存链式调用中添加的值:
//Query will build a sqltype Query struct {
db *sql.DB
table string
wheres []string
only []string
limit string
offset string
order string
errs []string}"age > 10"//Where args can be string, User, *User, map[string]interface{}func (q *Query) Where(wheres ...interface{}) *Query { for _, w := range wheres {
v := reflect.ValueOf(w) for v.Kind() == reflect.Ptr {
v = v.Elem()
} switch v.Kind() { case reflect.String:
q.wheres = append(q.wheres, w.(string)) case reflect.Struct: //todo
case reflect.Map: //todo
default:
q.errs = append(q.errs, "method Where error: type error")
}
} return q
}WhereNot()where()func where(eq bool, w interface{}) (string, error) {
var keys, values []string
v := reflect.ValueOf(w) for v.Kind() == reflect.Ptr {
v = v.Elem()
} switch v.Kind() { case reflect.String: return w.(string), nil case reflect.Struct:
keys, values = sKV(v) case reflect.Map:
keys, values = mKV(v) default: return "", errors.New("method Where error: type error")
} if len(keys) != len(values) { return "", errors.New("method Where error: len(keys) not equal len(values))")
}
var wheres []string
//之前的format()函数里,已经将切片类型值处理成"( , , ,)“形式
for idx, key := range keys { if eq { if strings.HasPrefix(values[idx], "(") && strings.HasSuffix(values[idx], ")") {
wheres = append(wheres, fmt.Sprintf("%s in %s", key, values[idx])) continue
}
wheres = append(wheres, fmt.Sprintf("%s = %s", key, values[idx])) continue
} if strings.HasPrefix(values[idx], "(") && strings.HasSuffix(values[idx], ")") {
wheres = append(wheres, fmt.Sprintf("%s not in %s", key, values[idx])) continue
}
wheres = append(wheres, fmt.Sprintf("%s != %s", key, values[idx]))
} return strings.Join(wheres, " and "), nil
}Where()方法最终变成:
//Where args can be string, User, *User, map[string]interface{}func (q *Query) Where(wheres ...interface{}) *Query { for _, w := range wheres {
str, err := where(true, w)
q.wheres = append(q.wheres, str) if err != nil { //因为需要达到链式调用的效果,所以把错误都搜集起来,最后再处理
q.errs = append(q.errs, err.Error())
}
} return q
}Limit()Offset()Order()Only()//Limit .func (q *Query) Limit(limit uint) *Query {
q.limit = fmt.Sprintf("limit %d", limit) return q
}//Offset .func (q *Query) Offset(offset uint) *Query {
q.offset = fmt.Sprintf("offset %d", offset) return q
}//Order .func (q *Query) Order(ord string) *Query {
q.order = fmt.Sprintf("order by %s", ord) return q
}//Only 指定需要查询的字段func (q *Query) Only(columns ...string) *Query {
q.only = append(q.only, columns...) return q
}toSQL()func (q *Query) toSQL() string {
var where string if len(q.wheres) > 0 { where = fmt.Sprintf(`where %s`, strings.Join(q.wheres, " and "))
}
sqlStr := fmt.Sprintf(`select %s from %s %s %s %s %s`, strings.Join(q.only, ","), q.table, where, q.order, q.limit, q.offset)
log.Printf("select sql: %s", sqlStr) return sqlStr
}structuserUservar user User users.Select(&user)var userPtr *User users.Select(user)
这两种声明方式是不同的:前者声明了一个 User 值,传入 &user 后可以写入;后者只声明了一个值为 nil 的 *User 指针,它没有指向任何可写入的内存,所以是错误的。
综上,我们首先为 Select() 方法做以下参数检查:确保传入值是一个正确的(非 nil)指针,并确保 only 属性有值:
//Select dest must be a ptr, e.g. *user, *[]user, *[]*user, *map, *[]map, *int, *[]intfunc (q *Query) Select(dest interface{}) error { if len(q.errs) != 0 { return errors.New(strings.Join(q.errs, "
"))
}
t := reflect.TypeOf(dest)
v := reflect.ValueOf(dest)
typeErr := errors.New("method Select error: type error") if t.Kind() != reflect.Ptr { return typeErr
} //如果是用 var userPtr *User 方式声明的变量,则不可取址
if !v.Elem().CanAddr() { return typeErr
}
t = t.Elem()
v = v.Elem() //如果only此时仍然为空,说明Only()方法未被调用,我们从struct上取tag填充
if len(q.only) == 0 { switch t.Kind() { case reflect.Struct: if t.Name() != "Time" {
q.only = sK(v)
} case reflect.Slice: //获取切片的基本类型给一个局部变量
t := t.Elem() if t.Kind() == reflect.Ptr {
t = t.Elem()
} if t.Kind() == reflect.Struct { if t.Name() != "Time" {
q.only = sK(reflect.Zero(t))
}
}
}
} if len(q.only) == 0 { return errors.New("method Select error: type error, no columns to select")
} if t.Kind() != reflect.Slice {
q.limit = "limit 1"
} //todo}这里只取struct的tag,不取value,我们定义一个新的sK()函数:
func sK(v reflect.Value) []string {
var keys []string
t := v.Type() for n := 0; n < t.NumField(); n++ {
tf := t.Field(n)
vf := v.Field(n) //忽略非导出字段
if tf.Anonymous { continue
} for vf.Type().Kind() == reflect.Ptr {
vf = vf.Elem()
} //如果字段值是time类型之外的struct,递归获取keys
if vf.Kind() == reflect.Struct && tf.Type.Name() != "Time" {
keys = append(keys, sK(vf)...) continue
} //根据字段的json tag获取key,忽略无tag字段
key := strings.Split(tf.Tag.Get("json"), ",")[0] if key == "" { continue
}
keys = append(keys, key)
} return keys
}Scan()Set()address()func address(dest reflect.Value, columns []string) []interface{} {
dest = dest.Elem()
t := dest.Type()
addrs := make([]interface{}, 0) switch t.Kind() { case reflect.Struct: for n := 0; n < t.NumField(); n++ {
tf := t.Field(n)
vf := dest.Field(n) if tf.Anonymous { continue
} for vf.Type().Kind() == reflect.Ptr {
vf = vf.Elem()
} //如果字段值是time类型之外的struct,递归取址
if vf.Kind() == reflect.Struct && tf.Type.Name() != "Time" {
nVf := reflect.New(vf.Type())
vf.Set(nVf.Elem())
addrs = append(addrs, address(nVf, columns)...) continue
}
column := strings.Split(tf.Tag.Get("json"), ",")[0] if column == "" { continue
} //只取选定的字段的地址
for _, col := range columns { if col == column {
addrs = append(addrs, vf.Addr().Interface()) break
}
}
} default:
addrs = append(addrs, dest.Addr().Interface())
} return addrs
}Value.Addr()Value.CanAddr()relfect.New()TypenewValue指针Set()setMap()//map的value类型必须是interface{},因为无类型信息,所以mysql驱动会返回一个字节切片,需要自行用[]byte断言func (q *Query) setMap(rows *sql.Rows, t reflect.Type) (reflect.Value, error) { if t.Elem().Kind() != reflect.Interface { return reflect.ValueOf(nil), errors.New("method setMap error: type error, must be map[string]interface{}")
}
m := reflect.MakeMap(t)
addrs := make([]interface{}, len(q.only)) for idx := range q.only {
addrs[idx] = new(interface{})
} if err := rows.Scan(addrs...); err != nil { return reflect.ValueOf(nil), err
} for idx, column := range q.only { //从指针剥出interface{},再剥出实际值
m.SetMapIndex(reflect.ValueOf(column), reflect.ValueOf(addrs[idx]).Elem().Elem())
} return m, nil
}reflect.MakeMap()make()Kindreflect.MapTypenewsetElem()//适用于基类型和structfunc (q *Query) setElem(rows *sql.Rows, t reflect.Type) (reflect.Value, error) {
addrsErr := errors.New("method setElem error: columns not match addresses")
dest := reflect.New(t)
addrs := address(dest, q.only) if len(q.only) != len(addrs) { return reflect.ValueOf(nil), addrsErr
} if err := rows.Scan(addrs...); err != nil { return reflect.ValueOf(nil), err
} return dest, nil}这些函数完成后,就可以着手完善Select()里的todo部分了:
//already donerows, err := q.DB.Query(q.toSQL()) if err != nil { return err
} switch t.Kind() { case reflect.Slice:
dt := t.Elem() for dt.Kind() == reflect.Ptr {
dt = dt.Elem()
}
sl := reflect.MakeSlice(t, 0, 0) for rows.Next() {
var destination reflect.Value if dt.Kind() == reflect.Map {
destination, err = q.setMap(rows, dt)
} else {
destination, err = q.setElem(rows, dt)
} if err != nil { return err
} //区分切片元素是否指针
switch t.Elem().Kind() { case reflect.Ptr, reflect.Map:
sl = reflect.Append(sl, destination) default:
sl = reflect.Append(sl, destination.Elem())
}
}
v.Set(sl) return nil
case reflect.Map: for rows.Next() {
m, err := q.setMap(rows, t) if err != nil { return err
}
v.Set(m)
} return nil
default: for rows.Next() {
destination, err := q.setElem(rows, t) if err != nil { return err
}
v.Set(destination.Elem())
}
} return nil至此,Select()方法就大功告成了,部分调用方式示例:
// A chain split over several lines must end each line with "."; a line that
// starts with ".Where" does not compile, because Go inserts a semicolon
// after the closing ")" of the previous line. Errors are ignored for brevity.
var user User
users().
	Where("first_name = 'Tom'", map[string]interface{}{
		"id": []int{1, 2, 3, 4},
	}).
	WhereNot(&User{LastName: "Cat"}).
	Only("last_name").
	Select(&user)

// Several rows into a slice of structs ...
var userMore []User
users().Where("first_name = 'Tom'").Order("id desc").Select(&userMore)

// ... or into a slice of struct pointers.
var userMoreP []*User
users().Where("first_name = 'Tom'").Select(&userMoreP)

// One column into a scalar ...
var lastName string
users().Where(&User{FirstName: "Tom"}).Only("last_name").Select(&lastName)

// ... or into a slice of scalars.
var lastNames []string
users().Where(map[string]interface{}{
	"first_name": "Tom",
}).Only("last_name").Select(&lastNames)

// Rows as maps; values arrive as driver types (e.g. []byte).
var userM map[string]interface{}
users().Where(&User{FirstName: "Tom"}).Only("last_name").Select(&userM)

var userMS []map[string]interface{}
users().Where("age > 10").Only("last_name", "age").Limit(100).Select(&userMS)

// Update method
分析update sql语句:
update user set first_name = "z", last_name = "zy" where first_name = "Tom" and last_name = "Curise"
比较简单,直接复用之前写的sKV()和mKV()函数:
//Update src can be *user, user, map[string]interface{}, stringfunc (q *Query) Update(src interface{}) (int64, error) { if len(q.errs) != 0 { return 0, errors.New(strings.Join(q.errs, "
"))
}
v := reflect.ValueOf(src) for v.Kind() == reflect.Ptr {
v = v.Elem()
} var toBeUpdated, where string var keys, values []string switch v.Kind() { case reflect.String:
toBeUpdated = src.(string) case reflect.Struct:
keys, values = sKV(v) case reflect.Map:
keys, values = mKV(v) default: return 0, errors.New("method Update error: type error")
} if toBeUpdated == "" { if len(keys) != len(values) { return 0, errors.New("method Update error: keys not match values")
} var kvs []string for idx, key := range keys {
kvs = append(kvs, fmt.Sprintf("%s = %s", key, values[idx]))
}
toBeUpdated = strings.Join(kvs, ",")
} if len(q.wheres) > 0 {
where = fmt.Sprintf(`where %s`, strings.Join(q.wheres, " and "))
}
query := fmt.Sprintf("update %s set %s %s", q.table, toBeUpdated, where)
st, err := q.DB.Prepare(query) if err != nil { return 0, err
}
result, err := st.Exec() if err != nil { return 0, err
} return result.RowsAffected()
}调用方式:
// The three accepted forms: a raw set clause, a map, and a struct pointer
// (only non-zero fields are set).
u1 := "age = 100"
u2 := map[string]interface{}{
	"age":        100,
	"first_name": "z",
	"last_name":  "zy",
}
u3 := &User{
	Age:       100,
	FirstName: "z",
	LastName:  "zy",
}
// Each call updates the users older than 10; results are discarded here.
_, _ = users().Where("age > 10").Update(u1)
_, _ = users().Where("age > 10").Update(u2)
_, _ = users().Where("age > 10").Update(u3)

// Delete method
这个最简单,没啥好说的:
//Delete no args
func (q *Query) Delete() (int64, error) { if len(q.errs) != 0 { return 0, errors.New(strings.Join(q.errs, "
"))
}
var where string if len(q.wheres) > 0 { where = fmt.Sprintf(`where %s`, strings.Join(q.wheres, " and "))
}
st, err := q.DB.Prepare(fmt.Sprintf(`delete from %s %s`, q.table, where)) if err != nil { return 0, err
}
result, err := st.Exec() if err != nil { return 0, err
} return result.RowsAffected()
}删除id为1,2,3,4,并且age大于10的用户的调用方式:
// The []int value is rendered as an IN list: id in (1,2,3,4).
w := map[string]interface{}{
	"id": []int{1, 2, 3, 4},
}
_, _ = users().Where(w, "age > 10").Delete()

// Transactions: the Transaction function
beginrollbackcommitrecover*sql.DB.Begin()Query()Prepare()//Dba *sql.DB or *sql.Txtype Dba interface {
Query(string, ...interface{}) (*sql.Rows, error)
Prepare(string) (*sql.Stmt, error)
}QueryDB//Query will build a sqltype Query struct {
DB Dba
...
}Table()//Table bind db and tablefunc Table(db *sql.DB, tableName string) func(...Dba) *Query { return func(tx ...Dba) *Query { if len(tx) == 1 { return &Query{
DB: tx[0],
table: tableName,
}
} return &Query{
DB: db,
table: tableName,
}
}
}Transaction()//Transaction .func Transaction(db *sql.DB, f func(Dba) error) (err error) {
tx, err := db.Begin() if err != nil { return err
}
defer func() {
p := recover() if err != nil { if rerr := tx.Rollback(); rerr != nil {
panic(rerr)
} return
} if p != nil { if rerr := tx.Rollback(); rerr != nil {
panic(rerr)
}
err = fmt.Errorf("function Transaction error: %v", p) return
} if cerr := tx.Commit(); cerr != nil {
panic(cerr)
}
}()
err = f(tx) return err
}第二个参数是一个接受事务具柄,返回error的函数,我们将需要事务的操作全部封装在这个函数里,就能抓到所有的panic和error。
调用方式示例:
unc doTx() error {
ormDB, err := Connect("root@tcp(127.0.0.1:3306)/orm_db?parseTime=true&loc=Local") if err != nil {
panic(err)
}
users := Table(ormDB, "user")
args := something() //利用闭包传递变量
f := func(tx Dba) error {
var id int
//select语句无需在事务具柄上进行
if err := users().Where(args).Select(&id); err != nil { return err
} //增删改需要在事务上进行
if _, err = users(tx).Insert(args); err != nil { return err
} if _, err = users(tx).Update(args); err != nil { return err
} if _, err = users(tx).Where(args).Delete(); err != nil { return err
} return nil
} //开始事务
if err := Transaction(ormDB, f); err != nil { return err
} return nil}到此,这个迷你orm的增删改查和事务功能全部都实现了,代码大概600行,比我预想的多了一倍。