帅气猫咪


帅气猫咪 published an article · 2020-11-07

Working with DOCX Files

Introduction

Word is a very common file format, and Word documents can be manipulated with Python.

Format Types

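Word documents come in two types, with the file extensions .doc and .docx.
The former is the format used up to Office 2003; the latter is the newer format introduced with Office 2007.
In general I recommend the newer format, but you will often still have to deal with the old one.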

Converting doc to docx

The approach below is for Linux and relies on soffice (LibreOffice).
On Windows, I suggest using win32com instead.
Here is a reference link.

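"""
Convert every .doc file under this script's directory to .docx using soffice.
"""
import glob
import subprocess
from pathlib import Path

base_dir = Path(__file__).resolve().parent
print(base_dir)

# Recursively collect all .doc files below base_dir
doc_list = glob.glob(base_dir.as_posix() + "/**/*.doc", recursive=True)
print(doc_list)


for doc in doc_list:
    # Run LibreOffice headlessly; the .docx is written next to the source file
    subprocess.call(
        [
            "soffice",
            "--headless",
            "--convert-to",
            "docx",
            "--outdir",
            Path(doc).parent.as_posix(),
            doc,
        ]
    )


# List the resulting .docx files to confirm the conversion
doc_list = glob.glob(base_dir.as_posix() + "/**/*.docx", recursive=True)
print(doc_list)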

Reading and writing docx files

For this I recommend python-docx.

pip install python-docx

Basic usage

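from docx import Document

doc_path = "example.docx"  # path to an existing .docx file (placeholder)

doc = Document(doc_path)

# List of all tables in the document
table_list = doc.tables
# Text of every non-empty paragraph
paragraph_list = [x.text.strip() for x in doc.paragraphs if x.text.strip()]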

For more, see the official documentation.

Summary

We can now use Python to convert doc files to docx and read their content.


帅气猫咪 published an article · 2020-06-21

04 GORM Source Code Walkthrough

Introduction

A walkthrough of the GORM source code, based on v1.9.11.

Querying

In the previous part, we explored how models are defined and how tables are created.
This time, let's see how querying is implemented.

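Querying covers a large amount of code, because it has to support many kinds of methods.
Let's start with a few of the simplest query methods from the official documentation.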

// Get the first record, ordered by primary key
db.First(&user)
//// SELECT * FROM users ORDER BY id LIMIT 1;

// Get one record, no specified order
db.Take(&user)
//// SELECT * FROM users LIMIT 1;

// Get the last record, ordered by primary key
db.Last(&user)
//// SELECT * FROM users ORDER BY id DESC LIMIT 1;

// Get all records
db.Find(&users)
//// SELECT * FROM users;

// Get a record by primary key (only works when the primary key is an integer)
db.First(&user, 10)
//// SELECT * FROM users WHERE id = 10;

Taking First as an example, let's look at its implementation:

// First find first record that match given conditions, order by primary key
func (s *DB) First(out interface{}, where ...interface{}) *DB {
    newScope := s.NewScope(out)
    newScope.Search.Limit(1)

    return newScope.Set("gorm:order_by_primary_key", "ASC").
        inlineCondition(where...).callCallbacks(s.parent.callbacks.queries).db
}

First fetches the first record from the database, ordered by primary key in ascending order.

As covered earlier, the actual database operations are implemented through callbacks. Here, callbacks.queries is used.

The default callbacks register three query callback functions.

// Define callbacks for querying
func init() {
    DefaultCallback.Query().Register("gorm:query", queryCallback)
    DefaultCallback.Query().Register("gorm:preload", preloadCallback)
    DefaultCallback.Query().Register("gorm:after_query", afterQueryCallback)
}
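Conceptually, the registration above builds an ordered list of named functions that all run against the same scope. The sketch below illustrates only that idea. `Scope`, `QueryCallbacks`, and their methods are hypothetical stand-ins. They are much simpler than GORM's real callback processor, which also offers Before/After ordering, Replace and Remove.

```go
package main

import "fmt"

// Scope is a hypothetical, stripped-down stand-in for gorm.Scope.
type Scope struct {
	Log []string
}

// namedCallback pairs a registration name with its function.
type namedCallback struct {
	name string
	fn   func(*Scope)
}

// QueryCallbacks keeps callbacks in registration order.
type QueryCallbacks struct {
	list []namedCallback
}

// Register appends a callback; later registrations run later.
func (c *QueryCallbacks) Register(name string, fn func(*Scope)) {
	c.list = append(c.list, namedCallback{name: name, fn: fn})
}

// Run invokes every registered callback, in order, on the same scope.
func (c *QueryCallbacks) Run(s *Scope) {
	for _, cb := range c.list {
		cb.fn(s)
	}
}

func main() {
	var qc QueryCallbacks
	qc.Register("gorm:query", func(s *Scope) { s.Log = append(s.Log, "query") })
	qc.Register("gorm:preload", func(s *Scope) { s.Log = append(s.Log, "preload") })
	qc.Register("gorm:after_query", func(s *Scope) { s.Log = append(s.Log, "after_query") })

	s := &Scope{}
	qc.Run(s)
	fmt.Println(s.Log) // [query preload after_query]
}
```

All callbacks share one scope, so a later callback can read whatever an earlier one stored on it.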

The query flow

Let's start with the most important function, queryCallback.

// queryCallback used to query data from database
func queryCallback(scope *Scope) {
    if _, skip := scope.InstanceGet("gorm:skip_query_callback"); skip {
        return
    }

    //we are only preloading relations, dont touch base model
    if _, skip := scope.InstanceGet("gorm:only_preload"); skip {
        return
    }

    defer scope.trace(scope.db.nowFunc())

    var (
        isSlice, isPtr bool
        resultType     reflect.Type
        results        = scope.IndirectValue()
    )

    if orderBy, ok := scope.Get("gorm:order_by_primary_key"); ok {
        if primaryField := scope.PrimaryField(); primaryField != nil {
            scope.Search.Order(fmt.Sprintf("%v.%v %v", scope.QuotedTableName(), scope.Quote(primaryField.DBName), orderBy))
        }
    }

    if value, ok := scope.Get("gorm:query_destination"); ok {
        results = indirect(reflect.ValueOf(value))
    }

    if kind := results.Kind(); kind == reflect.Slice {
        isSlice = true
        resultType = results.Type().Elem()
        results.Set(reflect.MakeSlice(results.Type(), 0, 0))

        if resultType.Kind() == reflect.Ptr {
            isPtr = true
            resultType = resultType.Elem()
        }
    } else if kind != reflect.Struct {
        scope.Err(errors.New("unsupported destination, should be slice or struct"))
        return
    }

    scope.prepareQuerySQL()

    if !scope.HasError() {
        scope.db.RowsAffected = 0
        if str, ok := scope.Get("gorm:query_option"); ok {
            scope.SQL += addExtraSpaceIfExist(fmt.Sprint(str))
        }

        if rows, err := scope.SQLDB().Query(scope.SQL, scope.SQLVars...); scope.Err(err) == nil {
            defer rows.Close()

            columns, _ := rows.Columns()
            for rows.Next() {
                scope.db.RowsAffected++

                elem := results
                if isSlice {
                    elem = reflect.New(resultType).Elem()
                }

                scope.scan(rows, columns, scope.New(elem.Addr().Interface()).Fields())

                if isSlice {
                    if isPtr {
                        results.Set(reflect.Append(results, elem.Addr()))
                    } else {
                        results.Set(reflect.Append(results, elem))
                    }
                }
            }

            if err := rows.Err(); err != nil {
                scope.Err(err)
            } else if scope.db.RowsAffected == 0 && !isSlice {
                scope.Err(ErrRecordNotFound)
            }
        }
    }
}

The core step is scope.prepareQuerySQL(), which builds the SQL statement.
The query is then executed with rows, err := scope.SQLDB().Query(scope.SQL, scope.SQLVars...).

So how are the query results passed on, and to what?

At the start of the function, results = scope.IndirectValue() is defined. This is where the query results end up.

results can only be a struct or a slice, and the slice elements may be structs or pointers to structs.

if kind := results.Kind(); kind == reflect.Slice {
  isSlice = true
  resultType = results.Type().Elem()
  results.Set(reflect.MakeSlice(results.Type(), 0, 0))

  if resultType.Kind() == reflect.Ptr {
    isPtr = true
    resultType = resultType.Elem()
  }
} else if kind != reflect.Struct {
  scope.Err(errors.New("unsupported destination, should be slice or struct"))
  return
}

The fetched rows are actually handled in this part of the code:

columns, _ := rows.Columns()
for rows.Next() {
  scope.db.RowsAffected++

  elem := results
  if isSlice {
    elem = reflect.New(resultType).Elem()
  }

  scope.scan(rows, columns, scope.New(elem.Addr().Interface()).Fields())

  if isSlice {
    if isPtr {
      results.Set(reflect.Append(results, elem.Addr()))
    } else {
      results.Set(reflect.Append(results, elem))
    }
  }
}
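The destination handling shown above uses nothing beyond package reflect, so it can be tried without a database. In the sketch below, the hypothetical `fill` function accepts a pointer to one of three things: a struct, a slice of structs, or a slice of struct pointers. It treats each string in `rows` as one fetched row. Setting the `Name` field stands in for `scope.scan`, and the `User` type is made up for the example.

```go
package main

import (
	"errors"
	"fmt"
	"reflect"
)

// User is a hypothetical model used only for this demonstration.
type User struct {
	Name string
}

// fill copies each row into out, which must point to a struct or a slice.
func fill(out interface{}, rows []string) error {
	results := reflect.ValueOf(out).Elem() // like scope.IndirectValue()
	var isSlice, isPtr bool
	var resultType reflect.Type

	if kind := results.Kind(); kind == reflect.Slice {
		isSlice = true
		resultType = results.Type().Elem()
		results.Set(reflect.MakeSlice(results.Type(), 0, 0)) // reset the slice
		if resultType.Kind() == reflect.Ptr {
			isPtr = true
			resultType = resultType.Elem()
		}
	} else if kind != reflect.Struct {
		return errors.New("unsupported destination, should be slice or struct")
	}

	for _, row := range rows {
		elem := results
		if isSlice {
			elem = reflect.New(resultType).Elem() // fresh element per row
		}
		// Stand-in for scope.scan: write the row into the Name field.
		elem.FieldByName("Name").SetString(row)
		if isSlice {
			if isPtr {
				results.Set(reflect.Append(results, elem.Addr()))
			} else {
				results.Set(reflect.Append(results, elem))
			}
		}
	}
	return nil
}

func main() {
	var users []User
	_ = fill(&users, []string{"alice", "bob"})
	fmt.Println(len(users), users[1].Name) // 2 bob

	var ptrs []*User
	_ = fill(&ptrs, []string{"carol"})
	fmt.Println(ptrs[0].Name) // carol

	var n int
	fmt.Println(fill(&n, nil)) // unsupported destination, should be slice or struct
}
```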

The key statement here is scope.scan. Here is its definition:

func (scope *Scope) scan(rows *sql.Rows, columns []string, fields []*Field) {
    var (
        ignored            interface{}
        values             = make([]interface{}, len(columns))
        selectFields       []*Field
        selectedColumnsMap = map[string]int{}
        resetFields        = map[int]*Field{}
    )

    for index, column := range columns {
        values[index] = &ignored

        selectFields = fields
        offset := 0
        if idx, ok := selectedColumnsMap[column]; ok {
            offset = idx + 1
            selectFields = selectFields[offset:]
        }

        for fieldIndex, field := range selectFields {
            if field.DBName == column {
                if field.Field.Kind() == reflect.Ptr {
                    values[index] = field.Field.Addr().Interface()
                } else {
                    reflectValue := reflect.New(reflect.PtrTo(field.Struct.Type))
                    reflectValue.Elem().Set(field.Field.Addr())
                    values[index] = reflectValue.Interface()
                    resetFields[index] = field
                }

                selectedColumnsMap[column] = offset + fieldIndex

                if field.IsNormal {
                    break
                }
            }
        }
    }

    scope.Err(rows.Scan(values...))

    for index, field := range resetFields {
        if v := reflect.ValueOf(values[index]).Elem().Elem(); v.IsValid() {
            field.Field.Set(v)
        }
    }
}

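As its name suggests, it essentially calls rows.Scan(values...) and copies the fetched data into the corresponding fields. Columns with no matching field are scanned into the throwaway variable ignored. Non-pointer fields are scanned through a temporary pointer-to-pointer recorded in resetFields. This way a NULL column leaves the field unchanged instead of causing a scan error.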
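The pointer-to-pointer trick is easier to see in isolation. The sketch below is an approximation built on a few assumptions:

- `fakeScan` imitates how database/sql fills a pointer destination. A NULL source sets the inner pointer to nil, and any other value is stored in a freshly allocated variable.
- Columns are matched to fields by lower-cased field name instead of GORM's `DBName`.
- `User` is hypothetical.

```go
package main

import (
	"fmt"
	"reflect"
	"strings"
)

// User is a hypothetical model.
type User struct {
	Name string
	Age  int
}

// fakeScan imitates database/sql for pointer destinations: a nil (NULL)
// source zeroes the destination; otherwise a pointer destination gets a newly
// allocated value, and any other destination is assigned directly.
func fakeScan(dests []interface{}, src []interface{}) {
	for i, d := range dests {
		dv := reflect.ValueOf(d).Elem() // the value the destination points at
		if src[i] == nil {
			dv.Set(reflect.Zero(dv.Type()))
			continue
		}
		if dv.Kind() == reflect.Ptr {
			p := reflect.New(dv.Type().Elem())
			p.Elem().Set(reflect.ValueOf(src[i]))
			dv.Set(p)
		} else {
			dv.Set(reflect.ValueOf(src[i]))
		}
	}
}

// scanInto matches columns to fields, scans through temporary **T pointers,
// and copies back only non-NULL values, like scope.scan's resetFields.
func scanInto(out interface{}, columns []string, row []interface{}) {
	v := reflect.ValueOf(out).Elem()
	var ignored interface{}
	values := make([]interface{}, len(columns))
	resetFields := map[int]reflect.Value{}

	for i, col := range columns {
		values[i] = &ignored // unknown columns are discarded
		for j := 0; j < v.NumField(); j++ {
			if strings.ToLower(v.Type().Field(j).Name) == col {
				field := v.Field(j)
				ptr := reflect.New(reflect.PtrTo(field.Type())) // **T
				ptr.Elem().Set(field.Addr())                   // *T -> the field
				values[i] = ptr.Interface()
				resetFields[i] = field
				break
			}
		}
	}

	fakeScan(values, row)

	for i, field := range resetFields {
		// A nil *T (NULL column) yields an invalid Value and is skipped.
		if inner := reflect.ValueOf(values[i]).Elem().Elem(); inner.IsValid() {
			field.Set(inner)
		}
	}
}

func main() {
	u := User{Name: "old", Age: 30}
	// "extra" has no matching field; age is NULL.
	scanInto(&u, []string{"name", "age", "extra"}, []interface{}{"alice", nil, 1})
	fmt.Println(u.Name, u.Age) // alice 30
}
```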

With that, we understand the main query flow.

So far we have focused on the flow and skipped the details of building the SQL statement. Let's take a closer look at prepareQuerySQL.

Building the query SQL

func (scope *Scope) prepareQuerySQL() {
    if scope.Search.raw {
        scope.Raw(scope.CombinedConditionSql())
    } else {
        scope.Raw(fmt.Sprintf("SELECT %v FROM %v %v", scope.selectSQL(), scope.QuotedTableName(), scope.CombinedConditionSql()))
    }
    return
}

Both branches call scope.Raw. Here is its implementation:

// Raw set raw sql
func (scope *Scope) Raw(sql string) *Scope {
    scope.SQL = strings.Replace(sql, "$$$", "?", -1)
    return scope
}

It assigns the given SQL string to scope.SQL, replacing every $$$ with ?.

Back in prepareQuerySQL, what really matters is the argument passed to Raw.
The else branch is the easier one to understand: it simply builds a SELECT statement.

A SELECT statement needs three pieces: the column list, the table name, and the conditions.

Let's look at each of them.

func (scope *Scope) selectSQL() string {
    if len(scope.Search.selects) == 0 {
        if len(scope.Search.joinConditions) > 0 {
            return fmt.Sprintf("%v.*", scope.QuotedTableName())
        }
        return "*"
    }
    return scope.buildSelectQuery(scope.Search.selects)
}

func (scope *Scope) buildSelectQuery(clause map[string]interface{}) (str string) {
    switch value := clause["query"].(type) {
    case string:
        str = value
    case []string:
        str = strings.Join(value, ", ")
    }

    args := clause["args"].([]interface{})
    replacements := []string{}
    for _, arg := range args {
        switch reflect.ValueOf(arg).Kind() {
        case reflect.Slice:
            values := reflect.ValueOf(arg)
            var tempMarks []string
            for i := 0; i < values.Len(); i++ {
                tempMarks = append(tempMarks, scope.AddToVars(values.Index(i).Interface()))
            }
            replacements = append(replacements, strings.Join(tempMarks, ","))
        default:
            if valuer, ok := interface{}(arg).(driver.Valuer); ok {
                arg, _ = valuer.Value()
            }
            replacements = append(replacements, scope.AddToVars(arg))
        }
    }

    buff := bytes.NewBuffer([]byte{})
    i := 0
    for pos, char := range str {
        if str[pos] == '?' {
            buff.WriteString(replacements[i])
            i++
        } else {
            buff.WriteRune(char)
        }
    }

    str = buff.String()

    return
}

When scope.Search.selects is empty, things are simple.
Depending on whether there are join conditions, it returns either table.* or *.

buildSelectQuery builds the column list from scope.Search.selects.

The first half is self-explanatory.

switch value := clause["query"].(type) {
case string:
  str = value
case []string:
  str = strings.Join(value, ", ")
}

The interesting part is how arguments are handled, which is the second half of the code.

args := clause["args"].([]interface{})
replacements := []string{}
for _, arg := range args {
  switch reflect.ValueOf(arg).Kind() {
  case reflect.Slice:
    values := reflect.ValueOf(arg)
    var tempMarks []string
    for i := 0; i < values.Len(); i++ {
      tempMarks = append(tempMarks, scope.AddToVars(values.Index(i).Interface()))
    }
    replacements = append(replacements, strings.Join(tempMarks, ","))
  default:
    if valuer, ok := interface{}(arg).(driver.Valuer); ok {
      arg, _ = valuer.Value()
    }
    replacements = append(replacements, scope.AddToVars(arg))
  }
}

buff := bytes.NewBuffer([]byte{})
i := 0
for pos, char := range str {
  if str[pos] == '?' {
    buff.WriteString(replacements[i])
    i++
  } else {
    buff.WriteRune(char)
  }
}

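The main step iterates over args := clause["args"].([]interface{}) to build the replacements slice.
Each argument is passed to scope.AddToVars, which records the value in scope.SQLVars and returns a bind-variable placeholder. Slice arguments are expanded into a comma-separated list of placeholders.
Every ? in str is then replaced, in order, with the corresponding entry of replacements.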
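The replacement logic can be exercised on its own. This sketch is a simplified version that makes two assumptions:

- The hypothetical `vars.add` records the value and returns a plain `?`. GORM's `scope.AddToVars` returns a dialect-specific marker instead.
- `buildSelectQuery` indexes `replacements[i]` without a bounds check. The sketch instead keeps the `i < len(replacements)` guard that `buildCondition` uses.

```go
package main

import (
	"bytes"
	"fmt"
	"reflect"
	"strings"
)

// vars collects bound values; add returns "?" for simplicity.
type vars struct{ list []interface{} }

func (v *vars) add(x interface{}) string {
	v.list = append(v.list, x)
	return "?"
}

// expand replaces each "?" in query with one marker per argument,
// expanding slice arguments into a comma-separated list.
func expand(query string, args []interface{}, v *vars) string {
	var replacements []string
	for _, arg := range args {
		if rv := reflect.ValueOf(arg); rv.Kind() == reflect.Slice {
			var marks []string
			for i := 0; i < rv.Len(); i++ {
				marks = append(marks, v.add(rv.Index(i).Interface()))
			}
			replacements = append(replacements, strings.Join(marks, ","))
		} else {
			replacements = append(replacements, v.add(arg))
		}
	}

	var buf bytes.Buffer
	i := 0
	for _, r := range query {
		// Guard against more "?" than arguments instead of panicking.
		if r == '?' && i < len(replacements) {
			buf.WriteString(replacements[i])
			i++
		} else {
			buf.WriteRune(r)
		}
	}
	return buf.String()
}

func main() {
	v := &vars{}
	sql := expand("id IN (?) AND name = ?", []interface{}{[]int{1, 2, 3}, "jinzhu"}, v)
	fmt.Println(sql)    // id IN (?,?,?) AND name = ?
	fmt.Println(v.list) // [1 2 3 jinzhu]
}
```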

That completes the construction of the SELECT column list.

Getting the table name is relatively simple, so here is the code:

// QuotedTableName return quoted table name
func (scope *Scope) QuotedTableName() (name string) {
    if scope.Search != nil && len(scope.Search.tableName) > 0 {
        if strings.Contains(scope.Search.tableName, " ") {
            return scope.Search.tableName
        }
        return scope.Quote(scope.Search.tableName)
    }

    return scope.Quote(scope.TableName())
}

Conditions

The more interesting question is how the filter conditions are built, which is done by CombinedConditionSql.

// CombinedConditionSql return combined condition sql
func (scope *Scope) CombinedConditionSql() string {
    joinSQL := scope.joinsSQL()
    whereSQL := scope.whereSQL()
    if scope.Search.raw {
        whereSQL = strings.TrimSuffix(strings.TrimPrefix(whereSQL, "WHERE ("), ")")
    }
    return joinSQL + whereSQL + scope.groupSQL() +
        scope.havingSQL() + scope.orderSQL() + scope.limitAndOffsetSQL()
}

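This short function packs in a lot of logic: a condition is made of several parts, six clauses in total here.
Let's go through all of them. Afterwards, building conditional clauses should feel familiar.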

func (scope *Scope) joinsSQL() string {
    var joinConditions []string
    for _, clause := range scope.Search.joinConditions {
        if sql := scope.buildCondition(clause, true); sql != "" {
            joinConditions = append(joinConditions, strings.TrimSuffix(strings.TrimPrefix(sql, "("), ")"))
        }
    }

    return strings.Join(joinConditions, " ") + " "
}

Building joinSQL mainly relies on buildCondition, so let's dig into it:

func (scope *Scope) buildCondition(clause map[string]interface{}, include bool) (str string) {
    var (
        quotedTableName  = scope.QuotedTableName()
        quotedPrimaryKey = scope.Quote(scope.PrimaryKey())
        equalSQL         = "="
        inSQL            = "IN"
    )

    // If building not conditions
    if !include {
        equalSQL = "<>"
        inSQL = "NOT IN"
    }

    switch value := clause["query"].(type) {
    case sql.NullInt64:
        return fmt.Sprintf("(%v.%v %s %v)", quotedTableName, quotedPrimaryKey, equalSQL, value.Int64)
    case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64:
        return fmt.Sprintf("(%v.%v %s %v)", quotedTableName, quotedPrimaryKey, equalSQL, value)
    case []int, []int8, []int16, []int32, []int64, []uint, []uint8, []uint16, []uint32, []uint64, []string, []interface{}:
        if !include && reflect.ValueOf(value).Len() == 0 {
            return
        }
        str = fmt.Sprintf("(%v.%v %s (?))", quotedTableName, quotedPrimaryKey, inSQL)
        clause["args"] = []interface{}{value}
    case string:
        if isNumberRegexp.MatchString(value) {
            return fmt.Sprintf("(%v.%v %s %v)", quotedTableName, quotedPrimaryKey, equalSQL, scope.AddToVars(value))
        }

        if value != "" {
            if !include {
                if comparisonRegexp.MatchString(value) {
                    str = fmt.Sprintf("NOT (%v)", value)
                } else {
                    str = fmt.Sprintf("(%v.%v NOT IN (?))", quotedTableName, scope.Quote(value))
                }
            } else {
                str = fmt.Sprintf("(%v)", value)
            }
        }
    case map[string]interface{}:
        var sqls []string
        for key, value := range value {
            if value != nil {
                sqls = append(sqls, fmt.Sprintf("(%v.%v %s %v)", quotedTableName, scope.Quote(key), equalSQL, scope.AddToVars(value)))
            } else {
                if !include {
                    sqls = append(sqls, fmt.Sprintf("(%v.%v IS NOT NULL)", quotedTableName, scope.Quote(key)))
                } else {
                    sqls = append(sqls, fmt.Sprintf("(%v.%v IS NULL)", quotedTableName, scope.Quote(key)))
                }
            }
        }
        return strings.Join(sqls, " AND ")
    case interface{}:
        var sqls []string
        newScope := scope.New(value)

        if len(newScope.Fields()) == 0 {
            scope.Err(fmt.Errorf("invalid query condition: %v", value))
            return
        }
        scopeQuotedTableName := newScope.QuotedTableName()
        for _, field := range newScope.Fields() {
            if !field.IsIgnored && !field.IsBlank {
                sqls = append(sqls, fmt.Sprintf("(%v.%v %s %v)", scopeQuotedTableName, scope.Quote(field.DBName), equalSQL, scope.AddToVars(field.Field.Interface())))
            }
        }
        return strings.Join(sqls, " AND ")
    default:
        scope.Err(fmt.Errorf("invalid query condition: %v", value))
        return
    }

    replacements := []string{}
    args := clause["args"].([]interface{})
    for _, arg := range args {
        var err error
        switch reflect.ValueOf(arg).Kind() {
        case reflect.Slice: // For where("id in (?)", []int64{1,2})
            if scanner, ok := interface{}(arg).(driver.Valuer); ok {
                arg, err = scanner.Value()
                replacements = append(replacements, scope.AddToVars(arg))
            } else if b, ok := arg.([]byte); ok {
                replacements = append(replacements, scope.AddToVars(b))
            } else if as, ok := arg.([][]interface{}); ok {
                var tempMarks []string
                for _, a := range as {
                    var arrayMarks []string
                    for _, v := range a {
                        arrayMarks = append(arrayMarks, scope.AddToVars(v))
                    }

                    if len(arrayMarks) > 0 {
                        tempMarks = append(tempMarks, fmt.Sprintf("(%v)", strings.Join(arrayMarks, ",")))
                    }
                }

                if len(tempMarks) > 0 {
                    replacements = append(replacements, strings.Join(tempMarks, ","))
                }
            } else if values := reflect.ValueOf(arg); values.Len() > 0 {
                var tempMarks []string
                for i := 0; i < values.Len(); i++ {
                    tempMarks = append(tempMarks, scope.AddToVars(values.Index(i).Interface()))
                }
                replacements = append(replacements, strings.Join(tempMarks, ","))
            } else {
                replacements = append(replacements, scope.AddToVars(Expr("NULL")))
            }
        default:
            if valuer, ok := interface{}(arg).(driver.Valuer); ok {
                arg, err = valuer.Value()
            }

            replacements = append(replacements, scope.AddToVars(arg))
        }

        if err != nil {
            scope.Err(err)
        }
    }

    buff := bytes.NewBuffer([]byte{})
    i := 0
    for _, s := range str {
        if s == '?' && len(replacements) > i {
            buff.WriteString(replacements[i])
            i++
        } else {
            buff.WriteRune(s)
        }
    }

    str = buff.String()

    return
}

It opens with a neat trick: the include parameter implements NOT conditions by switching = to <> and IN to NOT IN.

var (
  quotedTableName  = scope.QuotedTableName()
  quotedPrimaryKey = scope.Quote(scope.PrimaryKey())
  equalSQL         = "="
  inSQL            = "IN"
)

// If building not conditions
if !include {
  equalSQL = "<>"
  inSQL = "NOT IN"
}

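Next comes the type switch switch value := clause["query"].(type).
Most of its cases return immediately.
The remaining ones (slices of primary keys and plain string conditions) build the str variable instead.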

Those cases continue to the final part, which closely resembles the code we saw earlier.
It builds the replacements slice from clause["args"]
and uses it to replace the ? marks in str.
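The effect of the include flag is easiest to see on the map branch of the switch. The sketch below is a simplified, hypothetical `mapCondition` that differs from GORM in three ways:

- It writes a plain `?` instead of calling `AddToVars`.
- It does not quote identifiers.
- It sorts the keys so the output is deterministic. GORM ranges over the map directly, so its clause order may vary.

```go
package main

import (
	"fmt"
	"sort"
	"strings"
)

// mapCondition builds "(table.col = ?)" or "(table.col IS NULL)" per key.
// include=false flips = to <> and IS NULL to IS NOT NULL, as buildCondition does.
func mapCondition(table string, cond map[string]interface{}, include bool) (string, []interface{}) {
	equalSQL := "="
	if !include {
		equalSQL = "<>"
	}
	keys := make([]string, 0, len(cond))
	for k := range cond {
		keys = append(keys, k)
	}
	sort.Strings(keys)

	var sqls []string
	var args []interface{}
	for _, k := range keys {
		if v := cond[k]; v != nil {
			sqls = append(sqls, fmt.Sprintf("(%s.%s %s ?)", table, k, equalSQL))
			args = append(args, v)
		} else if include {
			sqls = append(sqls, fmt.Sprintf("(%s.%s IS NULL)", table, k))
		} else {
			sqls = append(sqls, fmt.Sprintf("(%s.%s IS NOT NULL)", table, k))
		}
	}
	return strings.Join(sqls, " AND "), args
}

func main() {
	cond := map[string]interface{}{"name": "jinzhu", "deleted_at": nil}
	where, args := mapCondition("users", cond, true)
	fmt.Println(where, args)
	// (users.deleted_at IS NULL) AND (users.name = ?) [jinzhu]
	not, _ := mapCondition("users", cond, false)
	fmt.Println(not)
	// (users.deleted_at IS NOT NULL) AND (users.name <> ?)
}
```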

Next, the whereSQL method.

func (scope *Scope) whereSQL() (sql string) {
    var (
        quotedTableName                                = scope.QuotedTableName()
        deletedAtField, hasDeletedAtField              = scope.FieldByName("DeletedAt")
        primaryConditions, andConditions, orConditions []string
    )

    if !scope.Search.Unscoped && hasDeletedAtField {
        sql := fmt.Sprintf("%v.%v IS NULL", quotedTableName, scope.Quote(deletedAtField.DBName))
        primaryConditions = append(primaryConditions, sql)
    }

    if !scope.PrimaryKeyZero() {
        for _, field := range scope.PrimaryFields() {
            sql := fmt.Sprintf("%v.%v = %v", quotedTableName, scope.Quote(field.DBName), scope.AddToVars(field.Field.Interface()))
            primaryConditions = append(primaryConditions, sql)
        }
    }

    for _, clause := range scope.Search.whereConditions {
        if sql := scope.buildCondition(clause, true); sql != "" {
            andConditions = append(andConditions, sql)
        }
    }

    for _, clause := range scope.Search.orConditions {
        if sql := scope.buildCondition(clause, true); sql != "" {
            orConditions = append(orConditions, sql)
        }
    }

    for _, clause := range scope.Search.notConditions {
        if sql := scope.buildCondition(clause, false); sql != "" {
            andConditions = append(andConditions, sql)
        }
    }

    orSQL := strings.Join(orConditions, " OR ")
    combinedSQL := strings.Join(andConditions, " AND ")
    if len(combinedSQL) > 0 {
        if len(orSQL) > 0 {
            combinedSQL = combinedSQL + " OR " + orSQL
        }
    } else {
        combinedSQL = orSQL
    }

    if len(primaryConditions) > 0 {
        sql = "WHERE " + strings.Join(primaryConditions, " AND ")
        if len(combinedSQL) > 0 {
            sql = sql + " AND (" + combinedSQL + ")"
        }
    } else if len(combinedSQL) > 0 {
        sql = "WHERE " + combinedSQL
    }
    return
}

It mainly builds three parts: primaryConditions, andConditions, and orConditions.

if !scope.Search.Unscoped && hasDeletedAtField {
  sql := fmt.Sprintf("%v.%v IS NULL", quotedTableName, scope.Quote(deletedAtField.DBName))
  primaryConditions = append(primaryConditions, sql)
}

if !scope.PrimaryKeyZero() {
  for _, field := range scope.PrimaryFields() {
    sql := fmt.Sprintf("%v.%v = %v", quotedTableName, scope.Quote(field.DBName), scope.AddToVars(field.Field.Interface()))
    primaryConditions = append(primaryConditions, sql)
  }
}

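The first two if statements build primaryConditions. The first adds a soft-delete filter (DeletedAt IS NULL) unless Unscoped is set. The second adds an equality condition for each primary key field when the primary key is non-zero.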

for _, clause := range scope.Search.whereConditions {
  if sql := scope.buildCondition(clause, true); sql != "" {
    andConditions = append(andConditions, sql)
  }
}

for _, clause := range scope.Search.orConditions {
  if sql := scope.buildCondition(clause, true); sql != "" {
    orConditions = append(orConditions, sql)
  }
}

for _, clause := range scope.Search.notConditions {
  if sql := scope.buildCondition(clause, false); sql != "" {
    andConditions = append(andConditions, sql)
  }
}

Then all three for loops use buildCondition.
Note that scope.Search.notConditions are added to andConditions, built with include set to false.

orSQL := strings.Join(orConditions, " OR ")
combinedSQL := strings.Join(andConditions, " AND ")
if len(combinedSQL) > 0 {
  if len(orSQL) > 0 {
    combinedSQL = combinedSQL + " OR " + orSQL
  }
} else {
  combinedSQL = orSQL
}

orConditions and andConditions are then combined into a single condition expression.

if len(primaryConditions) > 0 {
  sql = "WHERE " + strings.Join(primaryConditions, " AND ")
  if len(combinedSQL) > 0 {
    sql = sql + " AND (" + combinedSQL + ")"
  }
} else if len(combinedSQL) > 0 {
  sql = "WHERE " + combinedSQL
}
return

Finally, primaryConditions are placed in front to produce the final WHERE clause.
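The joining rules of whereSQL can be isolated into a pure string function. The hypothetical `combineWhere` below takes already-built condition strings and reproduces only the combination step shown above.

```go
package main

import (
	"fmt"
	"strings"
)

// combineWhere joins AND conditions (which include the NOT ones) with AND,
// appends OR conditions with OR, and wraps that group in parentheses after
// any primary conditions, following whereSQL.
func combineWhere(primary, and, or []string) string {
	orSQL := strings.Join(or, " OR ")
	combined := strings.Join(and, " AND ")
	if len(combined) > 0 {
		if len(orSQL) > 0 {
			combined = combined + " OR " + orSQL
		}
	} else {
		combined = orSQL
	}

	if len(primary) > 0 {
		sql := "WHERE " + strings.Join(primary, " AND ")
		if len(combined) > 0 {
			sql += " AND (" + combined + ")"
		}
		return sql
	}
	if len(combined) > 0 {
		return "WHERE " + combined
	}
	return ""
}

func main() {
	fmt.Println(combineWhere(
		[]string{"users.deleted_at IS NULL"},
		[]string{"(name = 'a')", "(age > 18)"},
		[]string{"(role = 'admin')"},
	))
	// WHERE users.deleted_at IS NULL AND ((name = 'a') AND (age > 18) OR (role = 'admin'))
}
```

In SQL, AND binds more tightly than OR, so the combined part reads as `(a AND b) OR c`. An `Or` condition is therefore an alternative to the whole group of `Where`/`Not` conditions. The primary conditions still apply to every row because they sit outside the parentheses.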

On to the next one:

func (scope *Scope) groupSQL() string {
    if len(scope.Search.group) == 0 {
        return ""
    }
    return " GROUP BY " + scope.Search.group
}

The GROUP BY clause is simple and can be built directly.

Next:

func (scope *Scope) havingSQL() string {
    if len(scope.Search.havingConditions) == 0 {
        return ""
    }

    var andConditions []string
    for _, clause := range scope.Search.havingConditions {
        if sql := scope.buildCondition(clause, true); sql != "" {
            andConditions = append(andConditions, sql)
        }
    }

    combinedSQL := strings.Join(andConditions, " AND ")
    if len(combinedSQL) == 0 {
        return ""
    }

    return " HAVING " + combinedSQL
}

The HAVING clause is not hard either: build the conditions, join them with AND, and prefix the result with HAVING.

Next:

func (scope *Scope) orderSQL() string {
    if len(scope.Search.orders) == 0 || scope.Search.ignoreOrderQuery {
        return ""
    }

    var orders []string
    for _, order := range scope.Search.orders {
        if str, ok := order.(string); ok {
            orders = append(orders, scope.quoteIfPossible(str))
        } else if expr, ok := order.(*expr); ok {
            exp := expr.expr
            for _, arg := range expr.args {
                exp = strings.Replace(exp, "?", scope.AddToVars(arg), 1)
            }
            orders = append(orders, exp)
        }
    }
    return " ORDER BY " + strings.Join(orders, ",")
}

The structure is similar: it iterates over the scope.Search.orders slice, where each order is either a string or an *expr value.
The latter is used for orders that take arguments.
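Here is a sketch of that string-or-expression handling. It assumes a hypothetical local `expr` type that mirrors the shape of GORM's unexported one: an SQL fragment plus its arguments. It collects arguments into a slice instead of calling `AddToVars` and skips `quoteIfPossible`. It also writes PostgreSQL-style numbered markers (`$1`, `$2`, ...) so the left-to-right, one-at-a-time substitution is visible in the output.

```go
package main

import (
	"fmt"
	"strings"
)

// expr is a hypothetical local type shaped like GORM's unexported expr:
// an SQL fragment plus the arguments for its "?" placeholders.
type expr struct {
	expr string
	args []interface{}
}

// orderClause accepts plain strings or *expr values, as orderSQL does, and
// silently skips anything else. Arguments are collected in vars and each "?"
// is replaced, left to right, with a numbered marker.
func orderClause(orders []interface{}) (string, []interface{}) {
	if len(orders) == 0 {
		return "", nil
	}
	var parts []string
	var vars []interface{}
	for _, o := range orders {
		switch v := o.(type) {
		case string:
			parts = append(parts, v)
		case *expr:
			exp := v.expr
			for _, arg := range v.args {
				vars = append(vars, arg)
				// Replace only the first remaining "?".
				exp = strings.Replace(exp, "?", fmt.Sprintf("$%d", len(vars)), 1)
			}
			parts = append(parts, exp)
		}
	}
	return " ORDER BY " + strings.Join(parts, ","), vars
}

func main() {
	sql, vars := orderClause([]interface{}{
		"age DESC",
		&expr{expr: "name = ? DESC", args: []interface{}{"first"}},
	})
	fmt.Printf("%q\n", sql) // " ORDER BY age DESC,name = $1 DESC"
	fmt.Println(vars)       // [first]
}
```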

Finally, there is the limitAndOffsetSQL method:

func (scope *Scope) limitAndOffsetSQL() string {
    return scope.Dialect().LimitAndOffsetSQL(scope.Search.limit, scope.Search.offset)
}

This simply delegates to the LimitAndOffsetSQL method of the concrete database dialect.

Let's look at two concrete implementations: the common one and the MySQL one.

func (commonDialect) LimitAndOffsetSQL(limit, offset interface{}) (sql string) {
    if limit != nil {
        if parsedLimit, err := strconv.ParseInt(fmt.Sprint(limit), 0, 0); err == nil && parsedLimit >= 0 {
            sql += fmt.Sprintf(" LIMIT %d", parsedLimit)
        }
    }
    if offset != nil {
        if parsedOffset, err := strconv.ParseInt(fmt.Sprint(offset), 0, 0); err == nil && parsedOffset >= 0 {
            sql += fmt.Sprintf(" OFFSET %d", parsedOffset)
        }
    }
    return
}

It parses limit and offset as integers and appends the corresponding keywords.

Now the MySQL implementation:

func (s mysql) LimitAndOffsetSQL(limit, offset interface{}) (sql string) {
    if limit != nil {
        if parsedLimit, err := strconv.ParseInt(fmt.Sprint(limit), 0, 0); err == nil && parsedLimit >= 0 {
            sql += fmt.Sprintf(" LIMIT %d", parsedLimit)

            if offset != nil {
                if parsedOffset, err := strconv.ParseInt(fmt.Sprint(offset), 0, 0); err == nil && parsedOffset >= 0 {
                    sql += fmt.Sprintf(" OFFSET %d", parsedOffset)
                }
            }
        }
    }
    return
}

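The difference is that offset is nested inside the limit check. In MySQL, OFFSET must be used together with LIMIT, so it is only emitted when a valid limit is present.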
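The two dialect methods can be compared side by side in a standalone sketch. The helper `parseNonNegative` and the two function names are hypothetical. The parsing itself follows the code above: `fmt.Sprint`, then `strconv.ParseInt` with base 0, rejecting negative values.

```go
package main

import (
	"fmt"
	"strconv"
)

// parseNonNegative mirrors the parsing in both dialects; nil, unparsable,
// or negative values are reported as absent.
func parseNonNegative(v interface{}) (int64, bool) {
	if v == nil {
		return 0, false
	}
	n, err := strconv.ParseInt(fmt.Sprint(v), 0, 0)
	return n, err == nil && n >= 0
}

// commonLimitOffset emits LIMIT and OFFSET independently.
func commonLimitOffset(limit, offset interface{}) (sql string) {
	if n, ok := parseNonNegative(limit); ok {
		sql += fmt.Sprintf(" LIMIT %d", n)
	}
	if n, ok := parseNonNegative(offset); ok {
		sql += fmt.Sprintf(" OFFSET %d", n)
	}
	return
}

// mysqlLimitOffset only emits OFFSET when a valid LIMIT is present.
func mysqlLimitOffset(limit, offset interface{}) (sql string) {
	if n, ok := parseNonNegative(limit); ok {
		sql += fmt.Sprintf(" LIMIT %d", n)
		if m, ok := parseNonNegative(offset); ok {
			sql += fmt.Sprintf(" OFFSET %d", m)
		}
	}
	return
}

func main() {
	fmt.Printf("%q\n", commonLimitOffset(nil, 20)) // " OFFSET 20"
	fmt.Printf("%q\n", mysqlLimitOffset(nil, 20))  // ""
	fmt.Printf("%q\n", mysqlLimitOffset(10, "20")) // " LIMIT 10 OFFSET 20"
	fmt.Printf("%q\n", commonLimitOffset(-1, -1))  // ""
}
```

Because negative values are rejected, passing -1 suppresses the clause entirely.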

With that, we have gone through every clause in CombinedConditionSql.
In the end there is no magic: it just builds different SQL fragments depending on the conditions.

Summary

We have followed First all the way down into the internals of querying. With these details understood, the other similar methods are easy to follow.

// Take return a record that match given conditions, the order will depend on the database implementation
func (s *DB) Take(out interface{}, where ...interface{}) *DB {
    newScope := s.NewScope(out)
    newScope.Search.Limit(1)
    return newScope.inlineCondition(where...).callCallbacks(s.parent.callbacks.queries).db
}

// Last find last record that match given conditions, order by primary key
func (s *DB) Last(out interface{}, where ...interface{}) *DB {
    newScope := s.NewScope(out)
    newScope.Search.Limit(1)
    return newScope.Set("gorm:order_by_primary_key", "DESC").
        inlineCondition(where...).callCallbacks(s.parent.callbacks.queries).db
}

// Find find records that match given conditions
func (s *DB) Find(out interface{}, where ...interface{}) *DB {
    return s.NewScope(out).inlineCondition(where...).callCallbacks(s.parent.callbacks.queries).db
}

The search struct

So far we have only seen how the simplest query is produced.
We have not yet looked closely at how query conditions are stored.

Let's see how the Where method adds query conditions.

// Get first matched record
db.Where("name = ?", "jinzhu").First(&user)
//// SELECT * FROM users WHERE name = 'jinzhu' limit 1;

// Get all matched records
db.Where("name = ?", "jinzhu").Find(&users)
//// SELECT * FROM users WHERE name = 'jinzhu';

The examples above come from the official documentation. GORM uses a method-chaining style, so you can chain several Where calls or other query conditions.

// Where return a new relation, filter records with given conditions, accepts `map`, `struct` or `string` as conditions, refer http://jinzhu.github.io/gorm/crud.html#query
func (s *DB) Where(query interface{}, args ...interface{}) *DB {
    return s.clone().search.Where(query, args...).db
}

That is the code of the Where method. Next to it in the source are many similar methods.

// Or filter records that match before conditions or this one, similar to `Where`
func (s *DB) Or(query interface{}, args ...interface{}) *DB {
    return s.clone().search.Or(query, args...).db
}

// Not filter records that don't match current conditions, similar to `Where`
func (s *DB) Not(query interface{}, args ...interface{}) *DB {
    return s.clone().search.Not(query, args...).db
}

It is easy to see that all of them go through the search object.

The DB struct has a field named search:

search            *search

Definition of search

This is where query conditions are stored. Its definition is:

type search struct {
    db               *DB
    whereConditions  []map[string]interface{}
    orConditions     []map[string]interface{}
    notConditions    []map[string]interface{}
    havingConditions []map[string]interface{}
    joinConditions   []map[string]interface{}
    initAttrs        []interface{}
    assignAttrs      []interface{}
    selects          map[string]interface{}
    omits            []string
    orders           []interface{}
    preload          []searchPreload
    offset           interface{}
    limit            interface{}
    group            string
    tableName        string
    raw              bool
    Unscoped         bool
    ignoreOrderQuery bool
}

type searchPreload struct {
    schema     string
    conditions []interface{}
}

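Many of these fields have the type []map[string]interface{}. From the condition-building code above, we can recognize these as the places where the various conditions are stored.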

Other fields, such as offset and limit, are self-explanatory.

Methods of search

search has quite a few methods, but they are all short: together they come to just over a hundred lines.

func (s *search) clone() *search {
    clone := *s
    return &clone
}

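This clone method looks a little unusual, as if it did nothing. What it actually does is dereference s, which copies every field of the struct, and return a pointer to that copy. The caller therefore gets an independent search whose fields can be changed without affecting the original. Slice fields, however, still share their backing arrays with the original.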
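To see what `clone := *s` buys, here is a small sketch with a hypothetical `miniSearch` holding one slice field and one scalar field. Dereferencing copies both fields, so the clone gets its own `limit` and its own slice header. Appending to the clone's slice therefore does not change the original's length.

```go
package main

import "fmt"

// miniSearch is a hypothetical two-field stand-in for gorm's search struct.
type miniSearch struct {
	whereConditions []string
	limit           interface{}
}

// clone copies the struct by value and returns a pointer to the copy.
func (s *miniSearch) clone() *miniSearch {
	c := *s
	return &c
}

// Where appends a condition and returns the receiver for chaining.
func (s *miniSearch) Where(cond string) *miniSearch {
	s.whereConditions = append(s.whereConditions, cond)
	return s
}

func main() {
	base := (&miniSearch{limit: -1}).Where("name = ?")
	derived := base.clone().Where("age > ?")
	derived.limit = 10

	fmt.Println(len(base.whereConditions), base.limit)       // 1 -1
	fmt.Println(len(derived.whereConditions), derived.limit) // 2 10
}
```

This is presumably why the public DB methods start with `s.clone()`: a base query can be reused to derive several independent queries.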

func (s *search) Where(query interface{}, values ...interface{}) *search {
    s.whereConditions = append(s.whereConditions, map[string]interface{}{"query": query, "args": values})
    return s
}

func (s *search) Not(query interface{}, values ...interface{}) *search {
    s.notConditions = append(s.notConditions, map[string]interface{}{"query": query, "args": values})
    return s
}

func (s *search) Or(query interface{}, values ...interface{}) *search {
    s.orConditions = append(s.orConditions, map[string]interface{}{"query": query, "args": values})
    return s
}

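Each of these methods builds a map from its arguments and appends it to the corresponding slice. It then returns the receiver itself so calls can be chained.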

func (s *search) Attrs(attrs ...interface{}) *search {
    s.initAttrs = append(s.initAttrs, toSearchableMap(attrs...))
    return s
}

func (s *search) Assign(attrs ...interface{}) *search {
    s.assignAttrs = append(s.assignAttrs, toSearchableMap(attrs...))
    return s
}

func toSearchableMap(attrs ...interface{}) (result interface{}) {
    if len(attrs) > 1 {
        if str, ok := attrs[0].(string); ok {
            result = map[string]interface{}{str: attrs[1]}
        }
    } else if len(attrs) == 1 {
        if attr, ok := attrs[0].(map[string]interface{}); ok {
            result = attr
        }

        if attr, ok := attrs[0].(interface{}); ok {
            result = attr
        }
    }
    return
}

These two methods are similar, and they use toSearchableMap to convert their arguments.

func (s *search) Order(value interface{}, reorder ...bool) *search {
    if len(reorder) > 0 && reorder[0] {
        s.orders = []interface{}{}
    }

    if value != nil && value != "" {
        s.orders = append(s.orders, value)
    }
    return s
}

This may look puzzling at first. The documentation and the comments on the public method explain it:

// Order specify order when retrieve records from database, set reorder to `true` to overwrite defined conditions
//     db.Order("name DESC")
//     db.Order("name DESC", true) // reorder
//     db.Order(gorm.Expr("name = ? DESC", "first")) // sql expression
func (s *DB) Order(value interface{}, reorder ...bool) *DB {
    return s.clone().search.Order(value, reorder...).db
}

The second argument decides whether to overwrite previously defined orders.

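It may seem odd that reorder is a variadic parameter. Go has no default argument values, so a variadic parameter is a common way to make an argument optional. Compatibility or historical reasons may also play a part.

Another point that may be confusing is []interface{}{}. It breaks into two parts: []interface{} is the type, and {} constructs an empty value of that type, i.e. an empty slice.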
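Both points can be tried in a small sketch. `orderSearch` is a hypothetical type holding only the `orders` field. Its `Order` method follows the code shown earlier, with the variadic `reorder ...bool` acting as an optional argument.

```go
package main

import "fmt"

// orderSearch is a hypothetical stand-in holding only the orders field.
type orderSearch struct {
	orders []interface{}
}

// Order appends value; passing true as the optional second argument
// first discards previously defined orders.
func (s *orderSearch) Order(value interface{}, reorder ...bool) *orderSearch {
	if len(reorder) > 0 && reorder[0] {
		s.orders = []interface{}{} // empty slice literal of type []interface{}
	}
	if value != nil && value != "" {
		s.orders = append(s.orders, value)
	}
	return s
}

func main() {
	s := &orderSearch{}
	s.Order("name DESC").Order("age")
	fmt.Println(s.orders) // [name DESC age]
	s.Order("id", true)
	fmt.Println(s.orders) // [id]
}
```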

func (s *search) Select(query interface{}, args ...interface{}) *search {
    s.selects = map[string]interface{}{"query": query, "args": args}
    return s
}

func (s *search) Omit(columns ...string) *search {
    s.omits = columns
    return s
}

func (s *search) Limit(limit interface{}) *search {
    s.limit = limit
    return s
}

func (s *search) Offset(offset interface{}) *search {
    s.offset = offset
    return s
}

These methods simply replace the stored value: each call keeps only the latest one.

func (s *search) Group(query string) *search {
    s.group = s.getInterfaceAsSQL(query)
    return s
}

func (s *search) getInterfaceAsSQL(value interface{}) (str string) {
    switch value.(type) {
    case string, int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64:
        str = fmt.Sprintf("%v", value)
    default:
        s.db.AddError(ErrInvalidSQL)
    }

    if str == "-1" {
        return ""
    }
    return
}

One feature of getInterfaceAsSQL is that passing -1 resets the value, because "-1" is turned into an empty string.

func (s *search) Having(query interface{}, values ...interface{}) *search {
    if val, ok := query.(*expr); ok {
        s.havingConditions = append(s.havingConditions, map[string]interface{}{"query": val.expr, "args": val.args})
    } else {
        s.havingConditions = append(s.havingConditions, map[string]interface{}{"query": query, "args": values})
    }
    return s
}

func (s *search) Joins(query string, values ...interface{}) *search {
    s.joinConditions = append(s.joinConditions, map[string]interface{}{"query": query, "args": values})
    return s
}

These are similar to what we saw earlier, so I will not explain them further.

func (s *search) Preload(schema string, values ...interface{}) *search {
    var preloads []searchPreload
    for _, preload := range s.preload {
        if preload.schema != schema {
            preloads = append(preloads, preload)
        }
    }
    preloads = append(preloads, searchPreload{schema, values})
    s.preload = preloads
    return s
}

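Preload must avoid duplicates, so it first walks the existing preloads, keeps only those with a different schema, and then appends the new entry.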
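A sketch of the effect, using hypothetical preloadSearch and searchPreload types that copy the loop above: preloading a schema a second time replaces its earlier entry, carries the newest conditions, and moves the entry to the end.

```go
package main

// searchPreload and preloadSearch are hypothetical reductions of GORM's types.
type searchPreload struct {
	schema     string
	conditions []interface{}
}

type preloadSearch struct {
	preload []searchPreload
}

// Preload keeps every entry whose schema differs, then appends the new one,
// so each schema appears once and carries the most recent conditions.
func (s *preloadSearch) Preload(schema string, values ...interface{}) *preloadSearch {
	var preloads []searchPreload
	for _, p := range s.preload {
		if p.schema != schema {
			preloads = append(preloads, p)
		}
	}
	preloads = append(preloads, searchPreload{schema, values})
	s.preload = preloads
	return s
}
```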

func (s *search) Raw(b bool) *search {
    s.raw = b
    return s
}

func (s *search) unscoped() *search {
    s.Unscoped = true
    return s
}

func (s *search) Table(name string) *search {
    s.tableName = name
    return s
}

The last few methods are nothing special either.

Recap

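The search struct is fairly simple: its definition plus methods come to only a little over a hundred lines.
It matters a lot, though, because every query-related condition is stored there.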

Summary

This part looked at how SQL queries are issued, explored along the way how the various query clauses are implemented, and examined the search struct and its role.


帅气猫咪 published an article · 2020-01-14

03 GORM Source Code Walkthrough

Introduction

A walkthrough of the GORM source code, based on v1.9.11.

Model interaction

We have already studied how models are defined and parsed; this time we look at how models interact with the database.

package main

import (
  "github.com/jinzhu/gorm"
  _ "github.com/jinzhu/gorm/dialects/sqlite"
)

type Product struct {
  gorm.Model
  Code string
  Price uint
}

func main() {
  db, err := gorm.Open("sqlite3", "test.db")
  if err != nil {
    panic("failed to connect database")
  }
  defer db.Close()

  // Migrate the schema
  db.AutoMigrate(&Product{})

  // Create
  db.Create(&Product{Code: "L1212", Price: 1000})

  // Read
  var product Product
  db.First(&product, 1) // find the product with id 1
  db.First(&product, "code = ?", "L1212") // find the product with code L1212

  // Update - set the product's price to 2000
  db.Model(&product).Update("Price", 2000)

  // Delete - delete the product
  db.Delete(&product)
}

AutoMigrate

Once the model is defined, the first step is to migrate it with AutoMigrate:

db.AutoMigrate(&Product{})

Let's look at its source:

// AutoMigrate run auto migration for given models, will only add missing fields, won't delete/change current data
func (s *DB) AutoMigrate(values ...interface{}) *DB {
    db := s.Unscoped()
    for _, value := range values {
        db = db.NewScope(value).autoMigrate().db
    }
    return db
}

Internally, it calls db.NewScope(value).autoMigrate() for every argument passed in.

So how exactly is the migration carried out?

func (scope *Scope) autoMigrate() *Scope {
    tableName := scope.TableName()
    quotedTableName := scope.QuotedTableName()

    if !scope.Dialect().HasTable(tableName) {
        scope.createTable()
    } else {
        for _, field := range scope.GetModelStruct().StructFields {
            if !scope.Dialect().HasColumn(tableName, field.DBName) {
                if field.IsNormal {
                    sqlTag := scope.Dialect().DataTypeOf(field)
                    scope.Raw(fmt.Sprintf("ALTER TABLE %v ADD %v %v;", quotedTableName, scope.Quote(field.DBName), sqlTag)).Exec()
                }
            }
            scope.createJoinTable(field)
        }
        scope.autoIndex()
    }
    return scope
}

The if in the middle shows two paths. If the table does not exist yet, it is simply created.

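Otherwise, every field of the model is processed: if its column does not exist yet and the field is a normal field (field.IsNormal), the table is altered to add the column.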

scope.Raw(fmt.Sprintf("ALTER TABLE %v ADD %v %v;", quotedTableName, scope.Quote(field.DBName), sqlTag)).Exec()

How the SQL is actually executed can wait; in terms of form, the code is quite concise: Raw builds the statement and Exec runs it.

In addition, for every field of the model, the join table is updated as well via scope.createJoinTable(field).

After the for loop has processed every field, the indexes are updated via scope.autoIndex().

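In short, auto-migration does four things: create the table, add newly defined columns, update table relationships (join tables), and update indexes. As the AutoMigrate comment says, existing columns and data are never changed or deleted.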
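A minimal sketch of the "table already exists" branch. column and planAddColumns are hypothetical; each column is assumed to already carry the SQL type a dialect's DataTypeOf would return. The sketch skips the IsNormal check, join tables, and indexes, and uses %q for double-quoted identifiers (SQLite/PostgreSQL style), which only matches real dialect quoting for plain names.

```go
package main

import "fmt"

// column is a hypothetical summary of a model field: its DB column name and
// the SQL type a dialect's DataTypeOf would produce for it.
type column struct {
	DBName  string
	SQLType string
}

// planAddColumns sketches the "table exists" branch of autoMigrate: every
// model column missing from the table yields an ALTER TABLE ... ADD statement.
// Existing columns are never altered or dropped.
func planAddColumns(table string, existing map[string]bool, cols []column) []string {
	var stmts []string
	for _, c := range cols {
		if !existing[c.DBName] {
			stmts = append(stmts, fmt.Sprintf("ALTER TABLE %q ADD %q %v;", table, c.DBName, c.SQLType))
		}
	}
	return stmts
}
```

For a products table that already has id and code, adding a Price field to the model yields a single ALTER TABLE "products" ADD "price" integer; statement.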

createTable

The table-creation step was skipped above, so let's look at it closely.

func (scope *Scope) createTable() *Scope {
    var tags []string
    var primaryKeys []string
    var primaryKeyInColumnType = false
    for _, field := range scope.GetModelStruct().StructFields {
        if field.IsNormal {
            sqlTag := scope.Dialect().DataTypeOf(field)

            // Check if the primary key constraint was specified as
            // part of the column type. If so, we can only support
            // one column as the primary key.
            if strings.Contains(strings.ToLower(sqlTag), "primary key") {
                primaryKeyInColumnType = true
            }

            tags = append(tags, scope.Quote(field.DBName)+" "+sqlTag)
        }

        if field.IsPrimaryKey {
            primaryKeys = append(primaryKeys, scope.Quote(field.DBName))
        }
        scope.createJoinTable(field)
    }

    var primaryKeyStr string
    if len(primaryKeys) > 0 && !primaryKeyInColumnType {
        primaryKeyStr = fmt.Sprintf(", PRIMARY KEY (%v)", strings.Join(primaryKeys, ","))
    }

    scope.Raw(fmt.Sprintf("CREATE TABLE %v (%v %v)%s", scope.QuotedTableName(), strings.Join(tags, ","), primaryKeyStr, scope.getTableOptions())).Exec()

    scope.autoIndex()
    return scope
}

This is how the CREATE TABLE statement is built; the key line is:

scope.Raw(fmt.Sprintf("CREATE TABLE %v (%v %v)%s", scope.QuotedTableName(), strings.Join(tags, ","), primaryKeyStr, scope.getTableOptions())).Exec()

The preceding loop walks over the model's fields, gets each field's sqlTag, and appends it to tags:

tags = append(tags, scope.Quote(field.DBName)+" "+sqlTag)

That is: the quoted column name (double quotes or backticks, depending on the dialect), a space, and the sqlTag.
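The string assembly can be reproduced in isolation. buildCreateTable and field are hypothetical; table options from getTableOptions are omitted and identifiers are double-quoted with %q. Note that the "(%v %v)" format leaves a space before the closing parenthesis when there is no PRIMARY KEY clause, which matches the trailing space in the issue 2270 log quoted in this article.

```go
package main

import (
	"fmt"
	"strings"
)

// field is a hypothetical, simplified view of a GORM StructField.
type field struct {
	DBName       string
	SQLType      string // what Dialect().DataTypeOf would return
	IsPrimaryKey bool
}

// buildCreateTable follows the shape of Scope.createTable: each column becomes
// `"name" type`, primary keys are collected separately, and a trailing
// PRIMARY KEY (...) clause is added only if no column type already contains
// "primary key".
func buildCreateTable(table string, fields []field) string {
	var tags, primaryKeys []string
	primaryKeyInColumnType := false
	for _, f := range fields {
		if strings.Contains(strings.ToLower(f.SQLType), "primary key") {
			primaryKeyInColumnType = true
		}
		tags = append(tags, fmt.Sprintf("%q %s", f.DBName, f.SQLType))
		if f.IsPrimaryKey {
			primaryKeys = append(primaryKeys, fmt.Sprintf("%q", f.DBName))
		}
	}
	var pk string
	if len(primaryKeys) > 0 && !primaryKeyInColumnType {
		pk = fmt.Sprintf(", PRIMARY KEY (%v)", strings.Join(primaryKeys, ","))
	}
	return fmt.Sprintf("CREATE TABLE %q (%v %v)", table, strings.Join(tags, ","), pk)
}
```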

Primary-key detection is also part of this loop, and it is a bit of a trap, because the result of
sqlTag := scope.Dialect().DataTypeOf(field) depends on each database's own DataTypeOf implementation.

Issue 2270 reports a table ending up with more than one primary key,
using the following model definition on sqlite3:

type Permission struct {
    ID   int64  `gorm:"AUTO_INCREMENT;column:id;primary_key"`
    Name string `gorm:"column:name;type:varchar;unique;not null"`
    Idx  int64  `gorm:"AUTO_INCREMENT"`
}

Although the model marks only one field as primary_key, Idx became a primary key too:

[2019-01-19 19:40:30]  table "permission" has more than one primary key

[2019-01-19 19:40:30]  [0.14ms]  CREATE TABLE "permission" ("id" integer primary key autoincrement,"name" varchar NOT NULL UNIQUE,"idx" integer primary key autoincrement )
[0 rows affected or returned ]

The sole cause is the AUTO_INCREMENT option, because in the sqlite3 implementation of DataTypeOf:

case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr:
  if s.fieldCanAutoIncrement(field) {
    field.TagSettingsSet("AUTO_INCREMENT", "AUTO_INCREMENT")
    sqlType = "integer primary key autoincrement"
  } else {
    sqlType = "integer"
  }
case reflect.Int64, reflect.Uint64:
  if s.fieldCanAutoIncrement(field) {
    field.TagSettingsSet("AUTO_INCREMENT", "AUTO_INCREMENT")
    sqlType = "integer primary key autoincrement"
  } else {
    sqlType = "bigint"
  }

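For a field that can auto-increment, the returned type is "integer primary key autoincrement", so the AUTO_INCREMENT option puts a primary key constraint into the column type whether or not the field is actually the primary key.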

I suspect this is a bug, because primary keys are checked again afterwards and added through primaryKeyStr:

if field.IsPrimaryKey {
  primaryKeys = append(primaryKeys, scope.Quote(field.DBName))
}
var primaryKeyStr string
if len(primaryKeys) > 0 && !primaryKeyInColumnType {
  primaryKeyStr = fmt.Sprintf(", PRIMARY KEY (%v)", strings.Join(primaryKeys, ","))
}

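In my view, sqlType should not carry primary-key information;
the primary key can be set afterwards through primaryKeyStr. (One caveat: SQLite only accepts AUTOINCREMENT on an INTEGER PRIMARY KEY column, which is probably why the dialect emits the two together.)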

That concludes the discussion of primary keys.

Both migration and table creation call createJoinTable, but since relationships haven't been studied in depth yet, let's skip it for now.

callbacks

Create, read, update, and delete are all driven by the callbacks field of the DB struct:

// DB contains information for current db connection
type DB struct {
  ...
    // global db
    parent        *DB
    callbacks     *Callback
    dialect       Dialect
    singularTable bool
  ...
}

Here is the Create method:

// Create insert the value into database
func (s *DB) Create(value interface{}) *DB {
    scope := s.NewScope(value)
    return scope.callCallbacks(s.parent.callbacks.creates).db
}

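It calls callCallbacks on a new scope, passing s.parent.callbacks.creates.
parent is also a *DB (the struct comment calls it the global db), so the callbacks are shared from the root DB, a bit like inheritance.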

Digging further into callCallbacks:

func (scope *Scope) callCallbacks(funcs []*func(s *Scope)) *Scope {
    defer func() {
        if err := recover(); err != nil {
            if db, ok := scope.db.db.(sqlTx); ok {
                db.Rollback()
            }
            panic(err)
        }
    }()
    for _, f := range funcs {
        (*f)(scope)
        if scope.skipLeft {
            break
        }
    }
    return scope
}

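This uses the defer + recover pattern, which I covered in an earlier article, so I won't go deeper: if a callback panics while a transaction is open, the transaction is rolled back and the panic is re-raised.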

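The parameter of callCallbacks is a slice of function pointers; they are called one after another until scope.skipLeft becomes true, at which point the remaining ones are skipped.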
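The control flow of callCallbacks can be sketched without a database. scope here is a hypothetical struct; instead of calling Rollback on a real transaction, the deferred function only records that a rollback would have happened before re-panicking.

```go
package main

// scope is a hypothetical stand-in for gorm.Scope with only what this needs.
type scope struct {
	skipLeft   bool     // set by a callback to stop the chain
	rolledBack bool     // records that the deferred rollback path ran
	trace      []string // names of callbacks that ran
}

// runCallbacks mirrors the shape of callCallbacks: call each function in
// order, stop once skipLeft is set, and on panic roll back, then re-panic.
func runCallbacks(s *scope, funcs []*func(*scope)) *scope {
	defer func() {
		if err := recover(); err != nil {
			s.rolledBack = true // GORM calls Rollback() here if it holds a sqlTx
			panic(err)
		}
	}()
	for _, f := range funcs {
		(*f)(s)
		if s.skipLeft {
			break
		}
	}
	return s
}
```

Re-panicking after the rollback keeps the failure visible to the caller instead of silently swallowing it.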

Having seen how callbacks are invoked, let's look at what Callback actually is.

// Callback is a struct that contains all CRUD callbacks
//   Field `creates` contains callbacks will be call when creating object
//   Field `updates` contains callbacks will be call when updating object
//   Field `deletes` contains callbacks will be call when deleting object
//   Field `queries` contains callbacks will be call when querying object with query methods like Find, First, Related, Association...
//   Field `rowQueries` contains callbacks will be call when querying object with Row, Rows...
//   Field `processors` contains all callback processors, will be used to generate above callbacks in order
type Callback struct {
    logger     logger
    creates    []*func(scope *Scope)
    updates    []*func(scope *Scope)
    deletes    []*func(scope *Scope)
    queries    []*func(scope *Scope)
    rowQueries []*func(scope *Scope)
    processors []*CallbackProcessor
}

Callback holds one function slice per operation (create, update, delete, query, row query); the comments explain them clearly.

Note CallbackProcessor, which is used to generate all of the callbacks above in order.

// CallbackProcessor contains callback informations
type CallbackProcessor struct {
    logger    logger
    name      string              // current callback's name
    before    string              // register current callback before a callback
    after     string              // register current callback after a callback
    replace   bool                // replace callbacks with same name
    remove    bool                // delete callbacks with same name
    kind      string              // callback type: create, update, delete, query, row_query
    processor *func(scope *Scope) // callback handler
    parent    *Callback
}

// Create could be used to register callbacks for creating object
//     db.Callback().Create().After("gorm:create").Register("plugin:run_after_create", func(*Scope) {
//       // business logic
//       ...
//
//       // set error if some thing wrong happened, will rollback the creating
//       scope.Err(errors.New("error"))
//     })
func (c *Callback) Create() *CallbackProcessor {
    return &CallbackProcessor{logger: c.logger, kind: "create", parent: c}
}

// Update could be used to register callbacks for updating object, refer `Create` for usage
func (c *Callback) Update() *CallbackProcessor {
    return &CallbackProcessor{logger: c.logger, kind: "update", parent: c}
}

// Delete could be used to register callbacks for deleting object, refer `Create` for usage
func (c *Callback) Delete() *CallbackProcessor {
    return &CallbackProcessor{logger: c.logger, kind: "delete", parent: c}
}

// Query could be used to register callbacks for querying objects with query methods like `Find`, `First`, `Related`, `Association`...
// Refer `Create` for usage
func (c *Callback) Query() *CallbackProcessor {
    return &CallbackProcessor{logger: c.logger, kind: "query", parent: c}
}

// RowQuery could be used to register callbacks for querying objects with `Row`, `Rows`, refer `Create` for usage
func (c *Callback) RowQuery() *CallbackProcessor {
    return &CallbackProcessor{logger: c.logger, kind: "row_query", parent: c}
}

Callback has one method per kind for creating the corresponding CallbackProcessor.

// After insert a new callback after callback `callbackName`, refer `Callbacks.Create`
func (cp *CallbackProcessor) After(callbackName string) *CallbackProcessor {
    cp.after = callbackName
    return cp
}

// Before insert a new callback before callback `callbackName`, refer `Callbacks.Create`
func (cp *CallbackProcessor) Before(callbackName string) *CallbackProcessor {
    cp.before = callbackName
    return cp
}

After and Before set the corresponding fields on the CallbackProcessor, which are later used to work out the order in which callbacks run.

db.Callback().Create().After("gorm:create").Register("plugin:run_after_create", func(scope *gorm.Scope) {
  // business logic
  ...

  // set error if some thing wrong happened, will rollback the creating
  scope.Err(errors.New("error"))
})

That is the example from the comment; next, the Register method.

// Register a new callback, refer `Callbacks.Create`
func (cp *CallbackProcessor) Register(callbackName string, callback func(scope *Scope)) {
    if cp.kind == "row_query" {
        if cp.before == "" && cp.after == "" && callbackName != "gorm:row_query" {
            cp.logger.Print(fmt.Sprintf("Registering RowQuery callback %v without specify order with Before(), After(), applying Before('gorm:row_query') by default for compatibility...\n", callbackName))
            cp.before = "gorm:row_query"
        }
    }

    cp.name = callbackName
    cp.processor = &callback
    cp.parent.processors = append(cp.parent.processors, cp)
    cp.parent.reorder()
}

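It mainly sets cp's processor field and appends cp to cp.parent.processors,
then calls cp.parent.reorder() to re-sort. A row_query callback registered without Before or After defaults to Before("gorm:row_query") for compatibility.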

Where there is registration, there is also removal:

// Remove a registered callback
//     db.Callback().Create().Remove("gorm:update_time_stamp_when_create")
func (cp *CallbackProcessor) Remove(callbackName string) {
    cp.logger.Print(fmt.Sprintf("[info] removing callback `%v` from %v\n", callbackName, fileWithLineNum()))
    cp.name = callbackName
    cp.remove = true
    cp.parent.processors = append(cp.parent.processors, cp)
    cp.parent.reorder()
}

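It sets remove to true, appends the processor, and re-sorts. Nothing is deleted from processors; the removal marker is simply appended and takes effect during sorting.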

Replacement works the same way:

// Replace a registered callback with new callback
//     db.Callback().Create().Replace("gorm:update_time_stamp_when_create", func(*Scope) {
//           scope.SetColumn("Created", now)
//           scope.SetColumn("Updated", now)
//     })
func (cp *CallbackProcessor) Replace(callbackName string, callback func(scope *Scope)) {
    cp.logger.Print(fmt.Sprintf("[info] replacing callback `%v` from %v\n", callbackName, fileWithLineNum()))
    cp.name = callbackName
    cp.processor = &callback
    cp.replace = true
    cp.parent.processors = append(cp.parent.processors, cp)
    cp.parent.reorder()
}

Now let's see how the re-sorting works:

// reorder all registered processors, and reset CRUD callbacks
func (c *Callback) reorder() {
    var creates, updates, deletes, queries, rowQueries []*CallbackProcessor

    for _, processor := range c.processors {
        if processor.name != "" {
            switch processor.kind {
            case "create":
                creates = append(creates, processor)
            case "update":
                updates = append(updates, processor)
            case "delete":
                deletes = append(deletes, processor)
            case "query":
                queries = append(queries, processor)
            case "row_query":
                rowQueries = append(rowQueries, processor)
            }
        }
    }

    c.creates = sortProcessors(creates)
    c.updates = sortProcessors(updates)
    c.deletes = sortProcessors(deletes)
    c.queries = sortProcessors(queries)
    c.rowQueries = sortProcessors(rowQueries)
}

The first half only groups the processors by kind; the real work happens in sortProcessors:

// sortProcessors sort callback processors based on its before, after, remove, replace
func sortProcessors(cps []*CallbackProcessor) []*func(scope *Scope) {
    var (
        allNames, sortedNames []string
        sortCallbackProcessor func(c *CallbackProcessor)
    )

    for _, cp := range cps {
        // show warning message the callback name already exists
        if index := getRIndex(allNames, cp.name); index > -1 && !cp.replace && !cp.remove {
            cp.logger.Print(fmt.Sprintf("[warning] duplicated callback `%v` from %v\n", cp.name, fileWithLineNum()))
        }
        allNames = append(allNames, cp.name)
    }

    sortCallbackProcessor = func(c *CallbackProcessor) {
        if getRIndex(sortedNames, c.name) == -1 { // if not sorted
            if c.before != "" { // if defined before callback
                if index := getRIndex(sortedNames, c.before); index != -1 {
                    // if before callback already sorted, append current callback just after it
                    sortedNames = append(sortedNames[:index], append([]string{c.name}, sortedNames[index:]...)...)
                } else if index := getRIndex(allNames, c.before); index != -1 {
                    // if before callback exists but haven't sorted, append current callback to last
                    sortedNames = append(sortedNames, c.name)
                    sortCallbackProcessor(cps[index])
                }
            }

            if c.after != "" { // if defined after callback
                if index := getRIndex(sortedNames, c.after); index != -1 {
                    // if after callback already sorted, append current callback just before it
                    sortedNames = append(sortedNames[:index+1], append([]string{c.name}, sortedNames[index+1:]...)...)
                } else if index := getRIndex(allNames, c.after); index != -1 {
                    // if after callback exists but haven't sorted
                    cp := cps[index]
                    // set after callback's before callback to current callback
                    if cp.before == "" {
                        cp.before = c.name
                    }
                    sortCallbackProcessor(cp)
                }
            }

            // if current callback haven't been sorted, append it to last
            if getRIndex(sortedNames, c.name) == -1 {
                sortedNames = append(sortedNames, c.name)
            }
        }
    }

    for _, cp := range cps {
        sortCallbackProcessor(cp)
    }

    var sortedFuncs []*func(scope *Scope)
    for _, name := range sortedNames {
        if index := getRIndex(allNames, name); !cps[index].remove {
            sortedFuncs = append(sortedFuncs, cps[index].processor)
        }
    }

    return sortedFuncs
}

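First, it collects the names of all cps and warns about duplicates (unless the duplicate is a replace or remove). sortedNames holds the names in sorted order.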

// getRIndex get right index from string slice
func getRIndex(strs []string, str string) int {
    for i := len(strs) - 1; i >= 0; i-- {
        if strs[i] == str {
            return i
        }
    }
    return -1
}

getRIndex returns the right-most index of a string in the slice, or -1 if it is absent.

Let's see what sortCallbackProcessor actually does.

It consists of two conditional parts. First, the before part:

if c.before != "" { // if defined before callback
  if index := getRIndex(sortedNames, c.before); index != -1 {
    // if before callback already sorted, append current callback just after it
    sortedNames = append(sortedNames[:index], append([]string{c.name}, sortedNames[index:]...)...)
  } else if index := getRIndex(allNames, c.before); index != -1 {
    // if before callback exists but haven't sorted, append current callback to last
    sortedNames = append(sortedNames, c.name)
    sortCallbackProcessor(cps[index])
  }
}

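There are two cases. If the before callback has already been sorted, the current callback is inserted at its index, i.e. immediately in front of it. (The source comment says "just after it", but the code places the current name before the before callback, which is what "before" requires.)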

If the before callback exists but has not been sorted yet, the current name is appended to the end of sortedNames.
Then sortCallbackProcessor(cps[index]) is called recursively, which moves on to sorting the before callback itself.

Then the second part:

if c.after != "" { // if defined after callback
  if index := getRIndex(sortedNames, c.after); index != -1 {
    // if after callback already sorted, append current callback just before it
    sortedNames = append(sortedNames[:index+1], append([]string{c.name}, sortedNames[index+1:]...)...)
  } else if index := getRIndex(allNames, c.after); index != -1 {
    // if after callback exists but haven't sorted
    cp := cps[index]
    // set after callback's before callback to current callback
    if cp.before == "" {
      cp.before = c.name
    }
    sortCallbackProcessor(cp)
  }
}

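The logic mirrors the first part. If the after callback has already been sorted, the current callback is inserted at index+1, i.e. immediately after it (again, the source comment says "just before it", but the code inserts after).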

If the after callback exists but has not been sorted yet, its before field is set to the current callback (only if it was empty),
and sortCallbackProcessor(cp) is called recursively to sort the after callback first.

// if current callback haven't been sorted, append it to last
if getRIndex(sortedNames, c.name) == -1 {
  sortedNames = append(sortedNames, c.name)
}

Finally, if the current callback still has not been placed, it is appended to the end. That is all sortCallbackProcessor does.

for _, cp := range cps {
  sortCallbackProcessor(cp)
}

Sorting starts here. Once it finishes, sortedNames is complete:

var sortedFuncs []*func(scope *Scope)
for _, name := range sortedNames {
  if index := getRIndex(allNames, name); !cps[index].remove {
    sortedFuncs = append(sortedFuncs, cps[index].processor)
  }
}

return sortedFuncs

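Every callback not marked as remove is appended to sortedFuncs in order. Because getRIndex returns the right-most match in allNames, the most recently added processor for a given name (for example one added by Replace or Remove) is the one that takes effect.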
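To check the ordering rules end to end, here is a name-only reimplementation of sortProcessors. proc, rIndex, and sortNames are hypothetical names; the logic follows the code above, including the rule that the right-most registration of a name decides whether it is removed.

```go
package main

// proc is a hypothetical, name-only reduction of CallbackProcessor.
type proc struct {
	name, before, after string
	remove              bool
}

// rIndex returns the right-most index of str in strs, or -1 (like getRIndex).
func rIndex(strs []string, str string) int {
	for i := len(strs) - 1; i >= 0; i-- {
		if strs[i] == str {
			return i
		}
	}
	return -1
}

// sortNames replays sortProcessors on names only and drops every name whose
// right-most registration is a removal.
func sortNames(cps []*proc) []string {
	var allNames, sortedNames []string
	for _, cp := range cps {
		allNames = append(allNames, cp.name)
	}
	var sortOne func(c *proc)
	sortOne = func(c *proc) {
		if rIndex(sortedNames, c.name) != -1 {
			return // already placed
		}
		if c.before != "" {
			if i := rIndex(sortedNames, c.before); i != -1 {
				// insert in front of the already-placed "before" target
				sortedNames = append(sortedNames[:i], append([]string{c.name}, sortedNames[i:]...)...)
			} else if i := rIndex(allNames, c.before); i != -1 {
				sortedNames = append(sortedNames, c.name)
				sortOne(cps[i])
			}
		}
		if c.after != "" {
			if i := rIndex(sortedNames, c.after); i != -1 {
				// insert right behind the already-placed "after" target
				sortedNames = append(sortedNames[:i+1], append([]string{c.name}, sortedNames[i+1:]...)...)
			} else if i := rIndex(allNames, c.after); i != -1 {
				if cps[i].before == "" {
					cps[i].before = c.name
				}
				sortOne(cps[i])
			}
		}
		if rIndex(sortedNames, c.name) == -1 {
			sortedNames = append(sortedNames, c.name)
		}
	}
	for _, cp := range cps {
		sortOne(cp)
	}
	var out []string
	for _, name := range sortedNames {
		if !cps[rIndex(allNames, name)].remove {
			out = append(out, name)
		}
	}
	return out
}
```

A plugin registered last with After("gorm:create") lands directly behind gorm:create, and a callback registered with After before its target exists still ends up behind that target.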

Finally, there is a Get method for retrieving a registered callback:

// Get registered callback
//    db.Callback().Create().Get("gorm:create")
func (cp *CallbackProcessor) Get(callbackName string) (callback func(scope *Scope)) {
    for _, p := range cp.parent.processors {
        if p.name == callbackName && p.kind == cp.kind {
            if p.remove {
                callback = nil
            } else {
                callback = *p.processor
            }
        }
    }
    return
}

By now it should be clear how callbacks are registered and sorted, and how a single callback is retrieved by name.

Where callbacks are actually registered

So far we have only covered the definitions; let's see where the callbacks are actually registered.

When the DB is initialized, the Open method executes the following statement:

db = &DB{
  db:        dbSQL,
  logger:    defaultLogger,
  callbacks: DefaultCallback,
  dialect:   newDialect(dialect, dbSQL),
}

DefaultCallback is defined as follows:

// DefaultCallback default callbacks defined by gorm
var DefaultCallback = &Callback{}

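At first this looked worrying: it is just an empty definition, so it must be populated somewhere else. A quick look at the directory listing makes it clear.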

callback_create.go registers the create callbacks in its init function:

// Define callbacks for creating
func init() {
    DefaultCallback.Create().Register("gorm:begin_transaction", beginTransactionCallback)
    DefaultCallback.Create().Register("gorm:before_create", beforeCreateCallback)
    DefaultCallback.Create().Register("gorm:save_before_associations", saveBeforeAssociationsCallback)
    DefaultCallback.Create().Register("gorm:update_time_stamp", updateTimeStampForCreateCallback)
    DefaultCallback.Create().Register("gorm:create", createCallback)
    DefaultCallback.Create().Register("gorm:force_reload_after_create", forceReloadAfterCreateCallback)
    DefaultCallback.Create().Register("gorm:save_after_associations", saveAfterAssociationsCallback)
    DefaultCallback.Create().Register("gorm:after_create", afterCreateCallback)
    DefaultCallback.Create().Register("gorm:commit_or_rollback_transaction", commitOrRollbackTransactionCallback)
}

With the documentation in mind,
let's see how hooks such as BeforeSave and BeforeCreate are implemented.

When you define a model, you can implement methods such as BeforeSave and BeforeCreate on it,
and they will be called at the appropriate time.

func (u *User) BeforeSave() (err error) {
  if !u.IsValid() {
    err = errors.New("can't save invalid data")
  }
  return
}

func (u *User) AfterCreate(scope *gorm.Scope) (err error) {
  if u.ID == 1 {
    scope.DB().Model(u).Update("role", "admin")
  }
  return
}

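That is the example from the official documentation. Earlier, a comment showed how to register a callback manually,
e.g. DefaultCallback.Create().Register("gorm:begin_transaction", beginTransactionCallback).
But how do methods defined on the model itself get called?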

Let's look at beforeCreateCallback:

// beforeCreateCallback will invoke `BeforeSave`, `BeforeCreate` method before creating
func beforeCreateCallback(scope *Scope) {
    if !scope.HasError() {
        scope.CallMethod("BeforeSave")
    }
    if !scope.HasError() {
        scope.CallMethod("BeforeCreate")
    }
}

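So it is done through scope.CallMethod: passing a method name invokes that method. Each call is skipped if an earlier step has already recorded an error.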

// CallMethod call scope value's method, if it is a slice, will call its element's method one by one
func (scope *Scope) CallMethod(methodName string) {
    if scope.Value == nil {
        return
    }

    if indirectScopeValue := scope.IndirectValue(); indirectScopeValue.Kind() == reflect.Slice {
        for i := 0; i < indirectScopeValue.Len(); i++ {
            scope.callMethod(methodName, indirectScopeValue.Index(i))
        }
    } else {
        scope.callMethod(methodName, indirectScopeValue)
    }
}

After that detour, let's look at callMethod:

func (scope *Scope) callMethod(methodName string, reflectValue reflect.Value) {
    // Only get address from non-pointer
    if reflectValue.CanAddr() && reflectValue.Kind() != reflect.Ptr {
        reflectValue = reflectValue.Addr()
    }

    if methodValue := reflectValue.MethodByName(methodName); methodValue.IsValid() {
        switch method := methodValue.Interface().(type) {
        case func():
            method()
        case func(*Scope):
            method(scope)
        case func(*DB):
            newDB := scope.NewDB()
            method(newDB)
            scope.Err(newDB.Error)
        case func() error:
            scope.Err(method())
        case func(*Scope) error:
            scope.Err(method(scope))
        case func(*DB) error:
            newDB := scope.NewDB()
            scope.Err(method(newDB))
            scope.Err(newDB.Error)
        default:
            scope.Err(fmt.Errorf("unsupported function %v", methodName))
        }
    }
}

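All this flexibility relies on reflection; the key line is methodValue := reflectValue.MethodByName(methodName). An addressable non-pointer value is first converted to its address, so methods with pointer receivers are found too.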

From the switch, we can see that hook methods may have several signatures:

switch method := methodValue.Interface().(type) {
case func():
  method()
case func(*Scope):
  method(scope)
case func(*DB):
  newDB := scope.NewDB()
  method(newDB)
  scope.Err(newDB.Error)
case func() error:
  scope.Err(method())
case func(*Scope) error:
  scope.Err(method(scope))
case func(*DB) error:
  newDB := scope.NewDB()
  scope.Err(method(newDB))
  scope.Err(newDB.Error)
default:
  scope.Err(fmt.Errorf("unsupported function %v", methodName))
}

So the whole thing can be read as an extended example of how to use reflect.
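A self-contained sketch of the same reflection technique, supporting only the func() and func() error signatures. user, callMethod, and callOne are hypothetical; unlike GORM, which records errors with scope.Err and keeps going, this version returns the first error.

```go
package main

import (
	"errors"
	"fmt"
	"reflect"
)

// user is a hypothetical model with two hooks, modeled on the docs example.
type user struct {
	Name    string
	created bool
}

func (u *user) BeforeSave() error {
	if u.Name == "" {
		return errors.New("can't save invalid data")
	}
	return nil
}

func (u *user) AfterCreate() { u.created = true }

// callMethod visits every element of a slice, otherwise calls the hook once.
func callMethod(value interface{}, name string) error {
	rv := reflect.Indirect(reflect.ValueOf(value))
	if rv.Kind() == reflect.Slice {
		for i := 0; i < rv.Len(); i++ {
			if err := callOne(rv.Index(i), name); err != nil {
				return err
			}
		}
		return nil
	}
	return callOne(rv, name)
}

func callOne(rv reflect.Value, name string) error {
	// Take the address of addressable non-pointers so pointer-receiver
	// methods are part of the method set.
	if rv.CanAddr() && rv.Kind() != reflect.Ptr {
		rv = rv.Addr()
	}
	m := rv.MethodByName(name)
	if !m.IsValid() {
		return nil // the model simply doesn't define this hook
	}
	switch fn := m.Interface().(type) {
	case func():
		fn()
		return nil
	case func() error:
		return fn()
	default:
		return fmt.Errorf("unsupported function %v", name)
	}
}
```

Slice elements are addressable, which is why AfterCreate with a pointer receiver still works when a plain []user is passed.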

createCallback

Skipping the other hooks, let's look at what happens when a single record is inserted:

// createCallback the callback used to insert data into database
func createCallback(scope *Scope) {
    if !scope.HasError() {
        defer scope.trace(scope.db.nowFunc())

        var (
            columns, placeholders        []string
            blankColumnsWithDefaultValue []string
        )

        for _, field := range scope.Fields() {
            if scope.changeableField(field) {
                if field.IsNormal && !field.IsIgnored {
                    if field.IsBlank && field.HasDefaultValue {
                        blankColumnsWithDefaultValue = append(blankColumnsWithDefaultValue, scope.Quote(field.DBName))
                        scope.InstanceSet("gorm:blank_columns_with_default_value", blankColumnsWithDefaultValue)
                    } else if !field.IsPrimaryKey || !field.IsBlank {
                        columns = append(columns, scope.Quote(field.DBName))
                        placeholders = append(placeholders, scope.AddToVars(field.Field.Interface()))
                    }
                } else if field.Relationship != nil && field.Relationship.Kind == "belongs_to" {
                    for _, foreignKey := range field.Relationship.ForeignDBNames {
                        if foreignField, ok := scope.FieldByName(foreignKey); ok && !scope.changeableField(foreignField) {
                            columns = append(columns, scope.Quote(foreignField.DBName))
                            placeholders = append(placeholders, scope.AddToVars(foreignField.Field.Interface()))
                        }
                    }
                }
            }
        }

        var (
            returningColumn = "*"
            quotedTableName = scope.QuotedTableName()
            primaryField    = scope.PrimaryField()
            extraOption     string
            insertModifier  string
        )

        if str, ok := scope.Get("gorm:insert_option"); ok {
            extraOption = fmt.Sprint(str)
        }
        if str, ok := scope.Get("gorm:insert_modifier"); ok {
            insertModifier = strings.ToUpper(fmt.Sprint(str))
            if insertModifier == "INTO" {
                insertModifier = ""
            }
        }

        if primaryField != nil {
            returningColumn = scope.Quote(primaryField.DBName)
        }

        lastInsertIDReturningSuffix := scope.Dialect().LastInsertIDReturningSuffix(quotedTableName, returningColumn)

        if len(columns) == 0 {
            scope.Raw(fmt.Sprintf(
                "INSERT %v INTO %v %v%v%v",
                addExtraSpaceIfExist(insertModifier),
                quotedTableName,
                scope.Dialect().DefaultValueStr(),
                addExtraSpaceIfExist(extraOption),
                addExtraSpaceIfExist(lastInsertIDReturningSuffix),
            ))
        } else {
            scope.Raw(fmt.Sprintf(
                "INSERT %v INTO %v (%v) VALUES (%v)%v%v",
                addExtraSpaceIfExist(insertModifier),
                scope.QuotedTableName(),
                strings.Join(columns, ","),
                strings.Join(placeholders, ","),
                addExtraSpaceIfExist(extraOption),
                addExtraSpaceIfExist(lastInsertIDReturningSuffix),
            ))
        }

        // execute create sql
        if lastInsertIDReturningSuffix == "" || primaryField == nil {
            if result, err := scope.SQLDB().Exec(scope.SQL, scope.SQLVars...); scope.Err(err) == nil {
                // set rows affected count
                scope.db.RowsAffected, _ = result.RowsAffected()

                // set primary value to primary field
                if primaryField != nil && primaryField.IsBlank {
                    if primaryValue, err := result.LastInsertId(); scope.Err(err) == nil {
                        scope.Err(primaryField.Set(primaryValue))
                    }
                }
            }
        } else {
            if primaryField.Field.CanAddr() {
                if err := scope.SQLDB().QueryRow(scope.SQL, scope.SQLVars...).Scan(primaryField.Field.Addr().Interface()); scope.Err(err) == nil {
                    primaryField.IsBlank = false
                    scope.db.RowsAffected = 1
                }
            } else {
                scope.Err(ErrUnaddressable)
            }
        }
    }
}

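First, the for loop walks over all fields and fills the three slices declared at the top: blank fields that have a default value go into blankColumnsWithDefaultValue and are left to the database, blank primary keys are skipped, and every other normal field contributes a column and a placeholder. Foreign key columns of belongs_to relationships may be added as well.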

for _, field := range scope.Fields() {
  if scope.changeableField(field) {
    if field.IsNormal && !field.IsIgnored {
      if field.IsBlank && field.HasDefaultValue {
        blankColumnsWithDefaultValue = append(blankColumnsWithDefaultValue, scope.Quote(field.DBName))
        scope.InstanceSet("gorm:blank_columns_with_default_value", blankColumnsWithDefaultValue)
      } else if !field.IsPrimaryKey || !field.IsBlank {
        columns = append(columns, scope.Quote(field.DBName))
        placeholders = append(placeholders, scope.AddToVars(field.Field.Interface()))
      }
    } else if field.Relationship != nil && field.Relationship.Kind == "belongs_to" {
      for _, foreignKey := range field.Relationship.ForeignDBNames {
        if foreignField, ok := scope.FieldByName(foreignKey); ok && !scope.changeableField(foreignField) {
          columns = append(columns, scope.Quote(foreignField.DBName))
          placeholders = append(placeholders, scope.AddToVars(foreignField.Field.Interface()))
        }
      }
    }
  }
}

Then some information is gathered and set:

var (
  returningColumn = "*"
  quotedTableName = scope.QuotedTableName()
  primaryField    = scope.PrimaryField()
  extraOption     string
  insertModifier  string
)

Once everything is gathered, the INSERT statement is built:

if len(columns) == 0 {
  scope.Raw(fmt.Sprintf(
    "INSERT %v INTO %v %v%v%v",
    addExtraSpaceIfExist(insertModifier),
    quotedTableName,
    scope.Dialect().DefaultValueStr(),
    addExtraSpaceIfExist(extraOption),
    addExtraSpaceIfExist(lastInsertIDReturningSuffix),
  ))
} else {
  scope.Raw(fmt.Sprintf(
    "INSERT %v INTO %v (%v) VALUES (%v)%v%v",
    addExtraSpaceIfExist(insertModifier),
    scope.QuotedTableName(),
    strings.Join(columns, ","),
    strings.Join(placeholders, ","),
    addExtraSpaceIfExist(extraOption),
    addExtraSpaceIfExist(lastInsertIDReturningSuffix),
  ))
}
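The statement assembly can be tried out on its own. buildInsert is a hypothetical helper; addExtraSpaceIfExist is assumed to behave like GORM's helper of the same name (prefix a non-empty string with one space), and the $1, $2 placeholders assume a PostgreSQL-style dialect, whereas GORM obtains its placeholders from scope.AddToVars. The double space after INSERT when no modifier is set is a side effect of the format string, and SQL ignores it.

```go
package main

import (
	"fmt"
	"strings"
)

// addExtraSpaceIfExist prefixes a non-empty string with a single space
// (assumed to match the GORM helper of the same name).
func addExtraSpaceIfExist(str string) string {
	if str != "" {
		return " " + str
	}
	return ""
}

// buildInsert is a hypothetical reduction of the SQL-building part of
// createCallback; "$n" placeholders assume a PostgreSQL-style dialect.
func buildInsert(table string, columns []string, insertModifier, extraOption, returningSuffix, defaultValueStr string) string {
	if len(columns) == 0 {
		// nothing to insert explicitly: rely on the dialect's default-values form
		return fmt.Sprintf("INSERT %v INTO %v %v%v%v",
			addExtraSpaceIfExist(insertModifier), table, defaultValueStr,
			addExtraSpaceIfExist(extraOption), addExtraSpaceIfExist(returningSuffix))
	}
	placeholders := make([]string, len(columns))
	for i := range columns {
		placeholders[i] = fmt.Sprintf("$%d", i+1)
	}
	return fmt.Sprintf("INSERT %v INTO %v (%v) VALUES (%v)%v%v",
		addExtraSpaceIfExist(insertModifier), table,
		strings.Join(columns, ","), strings.Join(placeholders, ","),
		addExtraSpaceIfExist(extraOption), addExtraSpaceIfExist(returningSuffix))
}
```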

Finally, the SQL is executed:

// execute create sql
if lastInsertIDReturningSuffix == "" || primaryField == nil {
  if result, err := scope.SQLDB().Exec(scope.SQL, scope.SQLVars...); scope.Err(err) == nil {
    // set rows affected count
    scope.db.RowsAffected, _ = result.RowsAffected()

    // set primary value to primary field
    if primaryField != nil && primaryField.IsBlank {
      if primaryValue, err := result.LastInsertId(); scope.Err(err) == nil {
        scope.Err(primaryField.Set(primaryValue))
      }
    }
  }
} else {
  if primaryField.Field.CanAddr() {
    if err := scope.SQLDB().QueryRow(scope.SQL, scope.SQLVars...).Scan(primaryField.Field.Addr().Interface()); scope.Err(err) == nil {
      primaryField.IsBlank = false
      scope.db.RowsAffected = 1
    }
  } else {
    scope.Err(ErrUnaddressable)
  }
}

The first condition depends on lastInsertIDReturningSuffix; only PostgreSQL returns a non-empty string for it.

var userid int
err := db.QueryRow(`INSERT INTO users(name, favorite_fruit, age)
    VALUES('beatrice', 'starfruit', 93) RETURNING id`).Scan(&userid)

PostgreSQL (more precisely, the lib/pq driver) does not support the LastInsertId() method; to get the ID you have to use RETURNING as shown above.
See PostgreSQL Queries.

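That is why the execution differs: with a RETURNING suffix, QueryRow(...).Scan writes the new ID straight into the primary key field; otherwise Exec is used and the ID is read from result.LastInsertId().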

That completes createCallback, and with it the insertion process.

Summary

This part looked at how tables are created and migrated, how hooks are registered and sorted, and when they are called.


帅气猫咪 published an article · 2019-12-29

02 GORM Source Code Walkthrough

Introduction

A walkthrough of the GORM source code, based on v1.9.11.

Defining models

GORM is an ORM, so model definition is its most important part; this time we explore how it is implemented.

type User struct {
  gorm.Model
  Name         string
  Age          sql.NullInt64
  Birthday     *time.Time
  Email        string  `gorm:"type:varchar(100);unique_index"`
  Role         string  `gorm:"size:255"` // set the field size to 255
  MemberNumber *string `gorm:"unique;not null"` // make the member number unique and not null
  Num          int     `gorm:"AUTO_INCREMENT"` // make num auto-incrementing
  Address      string  `gorm:"index:addr"` // create an index named addr on the address field
  IgnoreMe     int     `gorm:"-"` // ignore this field
}

This model definition is taken from the official documentation. It looks like an ordinary struct, but it carries extra gorm tags.

Models usually embed gorm.Model. Here is its definition:

// Model base model definition, including fields `ID`, `CreatedAt`, `UpdatedAt`, `DeletedAt`, which could be embedded in your models
//    type User struct {
//      gorm.Model
//    }
type Model struct {
    ID        uint `gorm:"primary_key"`
    CreatedAt time.Time
    UpdatedAt time.Time
    DeletedAt *time.Time `sql:"index"`
}

Embedding gorm.Model is not mandatory, however. It merely defines a few basic, practical fields, and a model can leave it out.

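The documentation describes many conventions for tables. For example, a field named ID is the primary key by default, and the table name is the plural of the struct name in snake_case.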

ModelStruct

To understand model definitions in depth, we start with ModelStruct:

// ModelStruct model definition
type ModelStruct struct {
    PrimaryFields []*StructField
    StructFields  []*StructField
    ModelType     reflect.Type

    defaultTableName string
    l                sync.Mutex
}

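ModelStruct describes the outline of a model struct. It holds a slice of the primary-key fields, a slice of all struct fields, the model's type, and the default table name, which is guarded by the mutex l.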

Getting the table name

ModelStruct has a method that returns the model's table name. Here is its code:

// TableName returns model's table name
func (s *ModelStruct) TableName(db *DB) string {
    s.l.Lock()
    defer s.l.Unlock()

    if s.defaultTableName == "" && db != nil && s.ModelType != nil {
        // Set default table name
        if tabler, ok := reflect.New(s.ModelType).Interface().(tabler); ok {
            s.defaultTableName = tabler.TableName()
        } else {
            tableName := ToTableName(s.ModelType.Name())
            db.parent.RLock()
            if db == nil || (db.parent != nil && !db.parent.singularTable) {
                tableName = inflection.Plural(tableName)
            }
            db.parent.RUnlock()
            s.defaultTableName = tableName
        }
    }

    return DefaultTableNameHandler(db, s.defaultTableName)
}

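First, reflection is used to check whether the model implements the tabler interface. If it does, its TableName() method is called directly.
Otherwise, ToTableName converts the struct name into a table name, which is then pluralized unless singular table names are enabled (db.parent.singularTable). The result is cached in defaultTableName.
Finally, every table name is passed through the DefaultTableNameHandler hook for one more transformation.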
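The order of these steps can be mimicked without GORM, which makes the precedence concrete. The following is a minimal sketch with several assumptions. toSnake, pluralize and defaultTableNameHandler are hypothetical simplifications: the real ToTableName also handles initialisms such as ID, and inflection.Plural knows irregular plurals, whereas this pluralize only appends "s". The tabler interface and the reflect.New(...).Interface().(tabler) check follow the source above.

```go
package main

import (
	"fmt"
	"reflect"
	"strings"
	"unicode"
)

// tabler mirrors the interface GORM checks: a model may override its table name.
type tabler interface {
	TableName() string
}

// toSnake is a simplified stand-in for ToTableName ("UserProfile" -> "user_profile").
func toSnake(name string) string {
	var b strings.Builder
	for i, r := range name {
		if unicode.IsUpper(r) {
			if i > 0 {
				b.WriteByte('_')
			}
			b.WriteRune(unicode.ToLower(r))
		} else {
			b.WriteRune(r)
		}
	}
	return b.String()
}

// pluralize is a naive placeholder for inflection.Plural.
func pluralize(s string) string { return s + "s" }

// defaultTableNameHandler plays the role of the DefaultTableNameHandler hook;
// here it returns the name unchanged.
var defaultTableNameHandler = func(name string) string { return name }

// tableName resolves a name in the same order as ModelStruct.TableName.
func tableName(model interface{}, singular bool) string {
	t := reflect.TypeOf(model)
	for t.Kind() == reflect.Ptr {
		t = t.Elem()
	}
	var name string
	if tb, ok := reflect.New(t).Interface().(tabler); ok {
		name = tb.TableName() // 1. an explicit TableName() wins
	} else {
		name = toSnake(t.Name()) // 2. derive from the struct name ...
		if !singular {
			name = pluralize(name) // ... and pluralize unless singular tables are on
		}
	}
	return defaultTableNameHandler(name) // 3. the global hook always runs last
}

type UserProfile struct{}

type Order struct{}

func (Order) TableName() string { return "shop_orders" }

func main() {
	fmt.Println(tableName(UserProfile{}, false)) // user_profiles
	fmt.Println(tableName(UserProfile{}, true))  // user_profile
	fmt.Println(tableName(&Order{}, false))      // shop_orders
}
```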

Having read the source, we can now better understand the documentation's explanation of table names.

StructField

Next, look at the definition of StructField, which describes how a field of the table is represented:

// StructField model field's struct definition
type StructField struct {
    DBName          string
    Name            string
    Names           []string
    IsPrimaryKey    bool
    IsNormal        bool
    IsIgnored       bool
    IsScanner       bool
    HasDefaultValue bool
    Tag             reflect.StructTag
    TagSettings     map[string]string
    Struct          reflect.StructField
    IsForeignKey    bool
    Relationship    *Relationship

    tagSettingsLock sync.RWMutex
}

It has many fields, and most of their purposes can be guessed from their names, for example whether the field is a primary key.

Note the TagSettings field and its companion read-write lock, tagSettingsLock.

// TagSettingsSet Sets a tag in the tag settings map
func (sf *StructField) TagSettingsSet(key, val string) {
    sf.tagSettingsLock.Lock()
    defer sf.tagSettingsLock.Unlock()
    sf.TagSettings[key] = val
}

// TagSettingsGet returns a tag from the tag settings
func (sf *StructField) TagSettingsGet(key string) (string, bool) {
    sf.tagSettingsLock.RLock()
    defer sf.tagSettingsLock.RUnlock()
    val, ok := sf.TagSettings[key]
    return val, ok
}

// TagSettingsDelete deletes a tag
func (sf *StructField) TagSettingsDelete(key string) {
    sf.tagSettingsLock.Lock()
    defer sf.tagSettingsLock.Unlock()
    delete(sf.TagSettings, key)
}

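All of these methods operate on TagSettings. The writers (TagSettingsSet, TagSettingsDelete) take the exclusive lock with Lock(), while the reader (TagSettingsGet) takes the shared lock with RLock(). They also serve as a good example of how to use the read-write lock sync.RWMutex.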
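The same pattern is easy to exercise under concurrency. This sketch uses a hypothetical stand-alone type, tagSettings, rather than GORM's StructField. Concurrent writers and readers share one map guarded by a sync.RWMutex. Without the mutex, this program would be a data race, which `go run -race` would report.

```go
package main

import (
	"fmt"
	"sync"
)

// tagSettings is a hypothetical stand-alone version of the StructField pattern.
type tagSettings struct {
	mu sync.RWMutex
	m  map[string]string
}

func newTagSettings() *tagSettings { return &tagSettings{m: map[string]string{}} }

// Set takes the exclusive (write) lock because it mutates the map.
func (t *tagSettings) Set(key, val string) {
	t.mu.Lock()
	defer t.mu.Unlock()
	t.m[key] = val
}

// Get only needs the shared (read) lock, so readers do not block each other.
func (t *tagSettings) Get(key string) (string, bool) {
	t.mu.RLock()
	defer t.mu.RUnlock()
	v, ok := t.m[key]
	return v, ok
}

func main() {
	ts := newTagSettings()
	var wg sync.WaitGroup
	for i := 0; i < 10; i++ {
		wg.Add(2)
		go func(i int) { // writer
			defer wg.Done()
			ts.Set(fmt.Sprintf("KEY%d", i), "v")
		}(i)
		go func() { // reader
			defer wg.Done()
			ts.Get("KEY0")
		}()
	}
	wg.Wait()
	_, ok := ts.Get("KEY9")
	fmt.Println(ok) // true: every writer has finished after Wait
}
```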

The last method clones a StructField:

func (sf *StructField) clone() *StructField {
    clone := &StructField{
        DBName:          sf.DBName,
        Name:            sf.Name,
        Names:           sf.Names,
        IsPrimaryKey:    sf.IsPrimaryKey,
        IsNormal:        sf.IsNormal,
        IsIgnored:       sf.IsIgnored,
        IsScanner:       sf.IsScanner,
        HasDefaultValue: sf.HasDefaultValue,
        Tag:             sf.Tag,
        TagSettings:     map[string]string{},
        Struct:          sf.Struct,
        IsForeignKey:    sf.IsForeignKey,
    }

    if sf.Relationship != nil {
        relationship := *sf.Relationship
        clone.Relationship = &relationship
    }

    // copy the struct field tagSettings, they should be read-locked while they are copied
    sf.tagSettingsLock.Lock()
    defer sf.tagSettingsLock.Unlock()
    for key, value := range sf.TagSettings {
        clone.TagSettings[key] = value
    }

    return clone
}

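When copying the entries of TagSettings, the method also takes tagSettingsLock. The comment says the map should be read-locked, but the code actually calls Lock(), the exclusive write lock. RLock() would be enough, because copying only reads the source map.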
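Below is a minimal sketch of the same copy on a cut-down, hypothetical field type. It holds only the read lock while reading the source map, as the source comment describes. It also shows why the map is copied entry by entry instead of being assigned: the clone gets its own map, so later changes to the clone do not leak into the original.

```go
package main

import (
	"fmt"
	"sync"
)

// field is a hypothetical, cut-down stand-in for StructField.
type field struct {
	Name        string
	TagSettings map[string]string
	lock        sync.RWMutex
}

// clone copies the map entry by entry into a fresh map, holding only the
// read lock on the source because copying never mutates it.
func (f *field) clone() *field {
	c := &field{Name: f.Name, TagSettings: map[string]string{}}
	f.lock.RLock()
	defer f.lock.RUnlock()
	for k, v := range f.TagSettings {
		c.TagSettings[k] = v
	}
	return c
}

func main() {
	orig := &field{Name: "ID", TagSettings: map[string]string{"PRIMARY_KEY": "PRIMARY_KEY"}}
	c := orig.clone()
	c.TagSettings["COLUMN"] = "user_id" // modify only the clone
	_, leaked := orig.TagSettings["COLUMN"]
	fmt.Println(leaked) // false: the original map is untouched
}
```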

Relationship

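The Relationship struct describes an association between models. It records the kind of relationship (such as has_one, has_many or many_to_many), optional polymorphic settings, the foreign key names on both sides, and, for many-to-many, the join table handler.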

type Relationship struct {
    Kind                         string
    PolymorphicType              string
    PolymorphicDBName            string
    PolymorphicValue             string
    ForeignFieldNames            []string
    ForeignDBNames               []string
    AssociationForeignFieldNames []string
    AssociationForeignDBNames    []string
    JoinTableHandler             JoinTableHandlerInterface
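}

The helper getForeignField, used throughout relationship parsing, returns the first field whose Go name or DB column name equals column, or whose DB name equals ToColumnName(column). It returns nil when nothing matches.

func getForeignField(column string, fields []*StructField) *StructField {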
    for _, field := range fields {
        if field.Name == column || field.DBName == column || field.DBName == ToColumnName(column) {
            return field
        }
    }
    return nil
}
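The three matching rules can be tried out in isolation. The following sketch copies getForeignField's logic onto a hypothetical two-field struct. toColumnName is a simplified stand-in for GORM's ToColumnName: it only splits words at case changes, so "UserID" becomes "user_id". The lookup of "AuthorID" succeeds only through the third rule, the converted column name.

```go
package main

import (
	"fmt"
	"strings"
	"unicode"
)

// structField is a hypothetical, cut-down StructField.
type structField struct {
	Name   string
	DBName string
}

// toColumnName inserts "_" before an upper-case letter that follows a
// lower-case one, or that starts a new word after an acronym.
func toColumnName(name string) string {
	runes := []rune(name)
	var b strings.Builder
	for i, r := range runes {
		if unicode.IsUpper(r) && i > 0 {
			prevLower := unicode.IsLower(runes[i-1])
			nextLower := i+1 < len(runes) && unicode.IsLower(runes[i+1])
			if prevLower || (unicode.IsUpper(runes[i-1]) && nextLower) {
				b.WriteByte('_')
			}
		}
		b.WriteRune(unicode.ToLower(r))
	}
	return b.String()
}

// getForeignField applies the same three rules as the GORM helper.
func getForeignField(column string, fields []*structField) *structField {
	for _, field := range fields {
		if field.Name == column || field.DBName == column || field.DBName == toColumnName(column) {
			return field
		}
	}
	return nil
}

func main() {
	fields := []*structField{{Name: "ID", DBName: "id"}, {Name: "AuthorRef", DBName: "author_id"}}
	fmt.Println(getForeignField("ID", fields).DBName)      // id (Go name)
	fmt.Println(getForeignField("author_id", fields).Name) // AuthorRef (DB name)
	fmt.Println(getForeignField("AuthorID", fields).Name)  // AuthorRef (ToColumnName)
	fmt.Println(getForeignField("OwnerID", fields) == nil) // true
}
```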

More

Before exploring how model definitions are parsed, let's get to know the Scope struct.

Scope

// Scope contain current operation's information when you perform any operation on the database
type Scope struct {
    Search          *search
    Value           interface{}
    SQL             string
    SQLVars         []interface{}
    db              *DB
    instanceID      string
    primaryKeyField *Field
    skipLeft        bool
    fields          *[]*Field
    selectAttrs     *[]string
}

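Scope is a central piece of GORM. As its comment says, whenever you perform any operation on the database, the Scope holds the information about the current operation.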

// IndirectValue return scope's reflect value's indirect value
func (scope *Scope) IndirectValue() reflect.Value {
    return indirect(reflect.ValueOf(scope.Value))
}

func indirect(reflectValue reflect.Value) reflect.Value {
    for reflectValue.Kind() == reflect.Ptr {
        reflectValue = reflectValue.Elem()
    }
    return reflectValue
}

// New create a new Scope without search information
func (scope *Scope) New(value interface{}) *Scope {
    return &Scope{db: scope.NewDB(), Search: &search{}, Value: value}
}

// NewDB create a new DB without search information
func (scope *Scope) NewDB() *DB {
    if scope.db != nil {
        db := scope.db.clone()
        db.search = nil
        db.Value = nil
        return db
    }
    return nil
}

Scope has many more methods, which we will skip for now. With its structure in mind, let's return to model parsing.

Model parsing

Once the user has defined a model, it has to be parsed. This work is done within a Scope, so it is implemented as a method on Scope.

The code is long, so let's skim it first to get a feel for its overall structure.
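Several relationship branches in the function below build default key names by plain string concatenation:

- associationType+field.Name for has_one and has_many, where associationType is the owner's struct name, or the POLYMORPHIC tag value if one is set.
- field.Name+primaryField.Name for belongs_to.
- ToColumnName(...) + "_" + DBName for many-to-many join-table columns.

The helpers in this sketch are hypothetical, each isolating one of those expressions, and toColumnName is a simplified stand-in for ToColumnName.

```go
package main

import (
	"fmt"
	"strings"
	"unicode"
)

// toColumnName: simplified stand-in for ToColumnName ("CreditCard" -> "credit_card").
func toColumnName(name string) string {
	runes := []rune(name)
	var b strings.Builder
	for i, r := range runes {
		if unicode.IsUpper(r) && i > 0 {
			prevLower := unicode.IsLower(runes[i-1])
			nextLower := i+1 < len(runes) && unicode.IsLower(runes[i+1])
			if prevLower || (unicode.IsUpper(runes[i-1]) && nextLower) {
				b.WriteByte('_')
			}
		}
		b.WriteRune(unicode.ToLower(r))
	}
	return b.String()
}

// hasForeignKey: has_one / has_many default, e.g. User has many Comment -> "UserID" on Comment.
func hasForeignKey(associationType, primaryFieldName string) string {
	return associationType + primaryFieldName
}

// belongsToForeignKey: belongs_to default, e.g. field Profile -> "ProfileID" on the owner.
func belongsToForeignKey(fieldName, targetPrimaryFieldName string) string {
	return fieldName + targetPrimaryFieldName
}

// joinTableColumn: many_to_many default join-table column, e.g. "user_id".
func joinTableColumn(structName, primaryDBName string) string {
	return toColumnName(structName) + "_" + primaryDBName
}

func main() {
	fmt.Println(hasForeignKey("User", "ID"))          // UserID
	fmt.Println(hasForeignKey("Owner", "ID"))         // OwnerID (POLYMORPHIC:Owner)
	fmt.Println(belongsToForeignKey("Profile", "ID")) // ProfileID
	fmt.Println(joinTableColumn("CreditCard", "id"))  // credit_card_id
}
```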

// GetModelStruct get value's model struct, relationships based on struct and tag definition
func (scope *Scope) GetModelStruct() *ModelStruct {
    var modelStruct ModelStruct
    // Scope value can't be nil
    if scope.Value == nil {
        return &modelStruct
    }

    reflectType := reflect.ValueOf(scope.Value).Type()
    for reflectType.Kind() == reflect.Slice || reflectType.Kind() == reflect.Ptr {
        reflectType = reflectType.Elem()
    }

    // Scope value need to be a struct
    if reflectType.Kind() != reflect.Struct {
        return &modelStruct
    }

    // Get Cached model struct
    isSingularTable := false
    if scope.db != nil && scope.db.parent != nil {
        scope.db.parent.RLock()
        isSingularTable = scope.db.parent.singularTable
        scope.db.parent.RUnlock()
    }

    hashKey := struct {
        singularTable bool
        reflectType   reflect.Type
    }{isSingularTable, reflectType}
    if value, ok := modelStructsMap.Load(hashKey); ok && value != nil {
        return value.(*ModelStruct)
    }

    modelStruct.ModelType = reflectType

    // Get all fields
    for i := 0; i < reflectType.NumField(); i++ {
        if fieldStruct := reflectType.Field(i); ast.IsExported(fieldStruct.Name) {
            field := &StructField{
                Struct:      fieldStruct,
                Name:        fieldStruct.Name,
                Names:       []string{fieldStruct.Name},
                Tag:         fieldStruct.Tag,
                TagSettings: parseTagSetting(fieldStruct.Tag),
            }

            // is ignored field
            if _, ok := field.TagSettingsGet("-"); ok {
                field.IsIgnored = true
            } else {
                if _, ok := field.TagSettingsGet("PRIMARY_KEY"); ok {
                    field.IsPrimaryKey = true
                    modelStruct.PrimaryFields = append(modelStruct.PrimaryFields, field)
                }

                if _, ok := field.TagSettingsGet("DEFAULT"); ok && !field.IsPrimaryKey {
                    field.HasDefaultValue = true
                }

                if _, ok := field.TagSettingsGet("AUTO_INCREMENT"); ok && !field.IsPrimaryKey {
                    field.HasDefaultValue = true
                }

                indirectType := fieldStruct.Type
                for indirectType.Kind() == reflect.Ptr {
                    indirectType = indirectType.Elem()
                }

                fieldValue := reflect.New(indirectType).Interface()
                if _, isScanner := fieldValue.(sql.Scanner); isScanner {
                    // is scanner
                    field.IsScanner, field.IsNormal = true, true
                    if indirectType.Kind() == reflect.Struct {
                        for i := 0; i < indirectType.NumField(); i++ {
                            for key, value := range parseTagSetting(indirectType.Field(i).Tag) {
                                if _, ok := field.TagSettingsGet(key); !ok {
                                    field.TagSettingsSet(key, value)
                                }
                            }
                        }
                    }
                } else if _, isTime := fieldValue.(*time.Time); isTime {
                    // is time
                    field.IsNormal = true
                } else if _, ok := field.TagSettingsGet("EMBEDDED"); ok || fieldStruct.Anonymous {
                    // is embedded struct
                    for _, subField := range scope.New(fieldValue).GetModelStruct().StructFields {
                        subField = subField.clone()
                        subField.Names = append([]string{fieldStruct.Name}, subField.Names...)
                        if prefix, ok := field.TagSettingsGet("EMBEDDED_PREFIX"); ok {
                            subField.DBName = prefix + subField.DBName
                        }

                        if subField.IsPrimaryKey {
                            if _, ok := subField.TagSettingsGet("PRIMARY_KEY"); ok {
                                modelStruct.PrimaryFields = append(modelStruct.PrimaryFields, subField)
                            } else {
                                subField.IsPrimaryKey = false
                            }
                        }

                        if subField.Relationship != nil && subField.Relationship.JoinTableHandler != nil {
                            if joinTableHandler, ok := subField.Relationship.JoinTableHandler.(*JoinTableHandler); ok {
                                newJoinTableHandler := &JoinTableHandler{}
                                newJoinTableHandler.Setup(subField.Relationship, joinTableHandler.TableName, reflectType, joinTableHandler.Destination.ModelType)
                                subField.Relationship.JoinTableHandler = newJoinTableHandler
                            }
                        }

                        modelStruct.StructFields = append(modelStruct.StructFields, subField)
                    }
                    continue
                } else {
                    // build relationships
                    switch indirectType.Kind() {
                    case reflect.Slice:
                        defer func(field *StructField) {
                            var (
                                relationship           = &Relationship{}
                                toScope                = scope.New(reflect.New(field.Struct.Type).Interface())
                                foreignKeys            []string
                                associationForeignKeys []string
                                elemType               = field.Struct.Type
                            )

                            if foreignKey, _ := field.TagSettingsGet("FOREIGNKEY"); foreignKey != "" {
                                foreignKeys = strings.Split(foreignKey, ",")
                            }

                            if foreignKey, _ := field.TagSettingsGet("ASSOCIATION_FOREIGNKEY"); foreignKey != "" {
                                associationForeignKeys = strings.Split(foreignKey, ",")
                            } else if foreignKey, _ := field.TagSettingsGet("ASSOCIATIONFOREIGNKEY"); foreignKey != "" {
                                associationForeignKeys = strings.Split(foreignKey, ",")
                            }

                            for elemType.Kind() == reflect.Slice || elemType.Kind() == reflect.Ptr {
                                elemType = elemType.Elem()
                            }

                            if elemType.Kind() == reflect.Struct {
                                if many2many, _ := field.TagSettingsGet("MANY2MANY"); many2many != "" {
                                    relationship.Kind = "many_to_many"

                                    { // Foreign Keys for Source
                                        joinTableDBNames := []string{}

                                        if foreignKey, _ := field.TagSettingsGet("JOINTABLE_FOREIGNKEY"); foreignKey != "" {
                                            joinTableDBNames = strings.Split(foreignKey, ",")
                                        }

                                        // if no foreign keys defined with tag
                                        if len(foreignKeys) == 0 {
                                            for _, field := range modelStruct.PrimaryFields {
                                                foreignKeys = append(foreignKeys, field.DBName)
                                            }
                                        }

                                        for idx, foreignKey := range foreignKeys {
                                            if foreignField := getForeignField(foreignKey, modelStruct.StructFields); foreignField != nil {
                                                // source foreign keys (db names)
                                                relationship.ForeignFieldNames = append(relationship.ForeignFieldNames, foreignField.DBName)

                                                // setup join table foreign keys for source
                                                if len(joinTableDBNames) > idx {
                                                    // if defined join table's foreign key
                                                    relationship.ForeignDBNames = append(relationship.ForeignDBNames, joinTableDBNames[idx])
                                                } else {
                                                    defaultJointableForeignKey := ToColumnName(reflectType.Name()) + "_" + foreignField.DBName
                                                    relationship.ForeignDBNames = append(relationship.ForeignDBNames, defaultJointableForeignKey)
                                                }
                                            }
                                        }
                                    }

                                    { // Foreign Keys for Association (Destination)
                                        associationJoinTableDBNames := []string{}

                                        if foreignKey, _ := field.TagSettingsGet("ASSOCIATION_JOINTABLE_FOREIGNKEY"); foreignKey != "" {
                                            associationJoinTableDBNames = strings.Split(foreignKey, ",")
                                        }

                                        // if no association foreign keys defined with tag
                                        if len(associationForeignKeys) == 0 {
                                            for _, field := range toScope.PrimaryFields() {
                                                associationForeignKeys = append(associationForeignKeys, field.DBName)
                                            }
                                        }

                                        for idx, name := range associationForeignKeys {
                                            if field, ok := toScope.FieldByName(name); ok {
                                                // association foreign keys (db names)
                                                relationship.AssociationForeignFieldNames = append(relationship.AssociationForeignFieldNames, field.DBName)

                                                // setup join table foreign keys for association
                                                if len(associationJoinTableDBNames) > idx {
                                                    relationship.AssociationForeignDBNames = append(relationship.AssociationForeignDBNames, associationJoinTableDBNames[idx])
                                                } else {
                                                    // join table foreign keys for association
                                                    joinTableDBName := ToColumnName(elemType.Name()) + "_" + field.DBName
                                                    relationship.AssociationForeignDBNames = append(relationship.AssociationForeignDBNames, joinTableDBName)
                                                }
                                            }
                                        }
                                    }

                                    joinTableHandler := JoinTableHandler{}
                                    joinTableHandler.Setup(relationship, many2many, reflectType, elemType)
                                    relationship.JoinTableHandler = &joinTableHandler
                                    field.Relationship = relationship
                                } else {
                                    // User has many comments, associationType is User, comment use UserID as foreign key
                                    var associationType = reflectType.Name()
                                    var toFields = toScope.GetStructFields()
                                    relationship.Kind = "has_many"

                                    if polymorphic, _ := field.TagSettingsGet("POLYMORPHIC"); polymorphic != "" {
                                        // Dog has many toys, tag polymorphic is Owner, then associationType is Owner
                                        // Toy use OwnerID, OwnerType ('dogs') as foreign key
                                        if polymorphicType := getForeignField(polymorphic+"Type", toFields); polymorphicType != nil {
                                            associationType = polymorphic
                                            relationship.PolymorphicType = polymorphicType.Name
                                            relationship.PolymorphicDBName = polymorphicType.DBName
                                            // if Dog has multiple set of toys set name of the set (instead of default 'dogs')
                                            if value, ok := field.TagSettingsGet("POLYMORPHIC_VALUE"); ok {
                                                relationship.PolymorphicValue = value
                                            } else {
                                                relationship.PolymorphicValue = scope.TableName()
                                            }
                                            polymorphicType.IsForeignKey = true
                                        }
                                    }

                                    // if no foreign keys defined with tag
                                    if len(foreignKeys) == 0 {
                                        // if no association foreign keys defined with tag
                                        if len(associationForeignKeys) == 0 {
                                            for _, field := range modelStruct.PrimaryFields {
                                                foreignKeys = append(foreignKeys, associationType+field.Name)
                                                associationForeignKeys = append(associationForeignKeys, field.Name)
                                            }
                                        } else {
                                            // generate foreign keys from defined association foreign keys
                                            for _, scopeFieldName := range associationForeignKeys {
                                                if foreignField := getForeignField(scopeFieldName, modelStruct.StructFields); foreignField != nil {
                                                    foreignKeys = append(foreignKeys, associationType+foreignField.Name)
                                                    associationForeignKeys = append(associationForeignKeys, foreignField.Name)
                                                }
                                            }
                                        }
                                    } else {
                                        // generate association foreign keys from foreign keys
                                        if len(associationForeignKeys) == 0 {
                                            for _, foreignKey := range foreignKeys {
                                                if strings.HasPrefix(foreignKey, associationType) {
                                                    associationForeignKey := strings.TrimPrefix(foreignKey, associationType)
                                                    if foreignField := getForeignField(associationForeignKey, modelStruct.StructFields); foreignField != nil {
                                                        associationForeignKeys = append(associationForeignKeys, associationForeignKey)
                                                    }
                                                }
                                            }
                                            if len(associationForeignKeys) == 0 && len(foreignKeys) == 1 {
                                                associationForeignKeys = []string{scope.PrimaryKey()}
                                            }
                                        } else if len(foreignKeys) != len(associationForeignKeys) {
                                            scope.Err(errors.New("invalid foreign keys, should have same length"))
                                            return
                                        }
                                    }

                                    for idx, foreignKey := range foreignKeys {
                                        if foreignField := getForeignField(foreignKey, toFields); foreignField != nil {
                                            if associationField := getForeignField(associationForeignKeys[idx], modelStruct.StructFields); associationField != nil {
                                                // source foreign keys
                                                foreignField.IsForeignKey = true
                                                relationship.AssociationForeignFieldNames = append(relationship.AssociationForeignFieldNames, associationField.Name)
                                                relationship.AssociationForeignDBNames = append(relationship.AssociationForeignDBNames, associationField.DBName)

                                                // association foreign keys
                                                relationship.ForeignFieldNames = append(relationship.ForeignFieldNames, foreignField.Name)
                                                relationship.ForeignDBNames = append(relationship.ForeignDBNames, foreignField.DBName)
                                            }
                                        }
                                    }

                                    if len(relationship.ForeignFieldNames) != 0 {
                                        field.Relationship = relationship
                                    }
                                }
                            } else {
                                field.IsNormal = true
                            }
                        }(field)
                    case reflect.Struct:
                        defer func(field *StructField) {
                            var (
                                // user has one profile, associationType is User, profile use UserID as foreign key
                                // user belongs to profile, associationType is Profile, user use ProfileID as foreign key
                                associationType           = reflectType.Name()
                                relationship              = &Relationship{}
                                toScope                   = scope.New(reflect.New(field.Struct.Type).Interface())
                                toFields                  = toScope.GetStructFields()
                                tagForeignKeys            []string
                                tagAssociationForeignKeys []string
                            )

                            if foreignKey, _ := field.TagSettingsGet("FOREIGNKEY"); foreignKey != "" {
                                tagForeignKeys = strings.Split(foreignKey, ",")
                            }

                            if foreignKey, _ := field.TagSettingsGet("ASSOCIATION_FOREIGNKEY"); foreignKey != "" {
                                tagAssociationForeignKeys = strings.Split(foreignKey, ",")
                            } else if foreignKey, _ := field.TagSettingsGet("ASSOCIATIONFOREIGNKEY"); foreignKey != "" {
                                tagAssociationForeignKeys = strings.Split(foreignKey, ",")
                            }

                            if polymorphic, _ := field.TagSettingsGet("POLYMORPHIC"); polymorphic != "" {
                                // Cat has one toy, tag polymorphic is Owner, then associationType is Owner
                                // Toy use OwnerID, OwnerType ('cats') as foreign key
                                if polymorphicType := getForeignField(polymorphic+"Type", toFields); polymorphicType != nil {
                                    associationType = polymorphic
                                    relationship.PolymorphicType = polymorphicType.Name
                                    relationship.PolymorphicDBName = polymorphicType.DBName
                                    // if Cat has several different types of toys set name for each (instead of default 'cats')
                                    if value, ok := field.TagSettingsGet("POLYMORPHIC_VALUE"); ok {
                                        relationship.PolymorphicValue = value
                                    } else {
                                        relationship.PolymorphicValue = scope.TableName()
                                    }
                                    polymorphicType.IsForeignKey = true
                                }
                            }

                            // Has One
                            {
                                var foreignKeys = tagForeignKeys
                                var associationForeignKeys = tagAssociationForeignKeys
                                // if no foreign keys defined with tag
                                if len(foreignKeys) == 0 {
                                    // if no association foreign keys defined with tag
                                    if len(associationForeignKeys) == 0 {
                                        for _, primaryField := range modelStruct.PrimaryFields {
                                            foreignKeys = append(foreignKeys, associationType+primaryField.Name)
                                            associationForeignKeys = append(associationForeignKeys, primaryField.Name)
                                        }
                                    } else {
                                        // generate foreign keys form association foreign keys
                                        for _, associationForeignKey := range tagAssociationForeignKeys {
                                            if foreignField := getForeignField(associationForeignKey, modelStruct.StructFields); foreignField != nil {
                                                foreignKeys = append(foreignKeys, associationType+foreignField.Name)
                                                associationForeignKeys = append(associationForeignKeys, foreignField.Name)
                                            }
                                        }
                                    }
                                } else {
                                    // generate association foreign keys from foreign keys
                                    if len(associationForeignKeys) == 0 {
                                        for _, foreignKey := range foreignKeys {
                                            if strings.HasPrefix(foreignKey, associationType) {
                                                associationForeignKey := strings.TrimPrefix(foreignKey, associationType)
                                                if foreignField := getForeignField(associationForeignKey, modelStruct.StructFields); foreignField != nil {
                                                    associationForeignKeys = append(associationForeignKeys, associationForeignKey)
                                                }
                                            }
                                        }
                                        if len(associationForeignKeys) == 0 && len(foreignKeys) == 1 {
                                            associationForeignKeys = []string{scope.PrimaryKey()}
                                        }
                                    } else if len(foreignKeys) != len(associationForeignKeys) {
                                        scope.Err(errors.New("invalid foreign keys, should have same length"))
                                        return
                                    }
                                }

                                for idx, foreignKey := range foreignKeys {
                                    if foreignField := getForeignField(foreignKey, toFields); foreignField != nil {
                                        if scopeField := getForeignField(associationForeignKeys[idx], modelStruct.StructFields); scopeField != nil {
                                            foreignField.IsForeignKey = true
                                            // source foreign keys
                                            relationship.AssociationForeignFieldNames = append(relationship.AssociationForeignFieldNames, scopeField.Name)
                                            relationship.AssociationForeignDBNames = append(relationship.AssociationForeignDBNames, scopeField.DBName)

                                            // association foreign keys
                                            relationship.ForeignFieldNames = append(relationship.ForeignFieldNames, foreignField.Name)
                                            relationship.ForeignDBNames = append(relationship.ForeignDBNames, foreignField.DBName)
                                        }
                                    }
                                }
                            }

                            if len(relationship.ForeignFieldNames) != 0 {
                                relationship.Kind = "has_one"
                                field.Relationship = relationship
                            } else {
                                var foreignKeys = tagForeignKeys
                                var associationForeignKeys = tagAssociationForeignKeys

                                if len(foreignKeys) == 0 {
                                    // generate foreign keys & association foreign keys
                                    if len(associationForeignKeys) == 0 {
                                        for _, primaryField := range toScope.PrimaryFields() {
                                            foreignKeys = append(foreignKeys, field.Name+primaryField.Name)
                                            associationForeignKeys = append(associationForeignKeys, primaryField.Name)
                                        }
                                    } else {
                                        // generate foreign keys with association foreign keys
                                        for _, associationForeignKey := range associationForeignKeys {
                                            if foreignField := getForeignField(associationForeignKey, toFields); foreignField != nil {
                                                foreignKeys = append(foreignKeys, field.Name+foreignField.Name)
                                                associationForeignKeys = append(associationForeignKeys, foreignField.Name)
                                            }
                                        }
                                    }
                                } else {
                                    // generate foreign keys & association foreign keys
                                    if len(associationForeignKeys) == 0 {
                                        for _, foreignKey := range foreignKeys {
                                            if strings.HasPrefix(foreignKey, field.Name) {
                                                associationForeignKey := strings.TrimPrefix(foreignKey, field.Name)
                                                if foreignField := getForeignField(associationForeignKey, toFields); foreignField != nil {
                                                    associationForeignKeys = append(associationForeignKeys, associationForeignKey)
                                                }
                                            }
                                        }
                                        if len(associationForeignKeys) == 0 && len(foreignKeys) == 1 {
                                            associationForeignKeys = []string{toScope.PrimaryKey()}
                                        }
                                    } else if len(foreignKeys) != len(associationForeignKeys) {
                                        scope.Err(errors.New("invalid foreign keys, should have same length"))
                                        return
                                    }
                                }

                                for idx, foreignKey := range foreignKeys {
                                    if foreignField := getForeignField(foreignKey, modelStruct.StructFields); foreignField != nil {
                                        if associationField := getForeignField(associationForeignKeys[idx], toFields); associationField != nil {
                                            foreignField.IsForeignKey = true

                                            // association foreign keys
                                            relationship.AssociationForeignFieldNames = append(relationship.AssociationForeignFieldNames, associationField.Name)
                                            relationship.AssociationForeignDBNames = append(relationship.AssociationForeignDBNames, associationField.DBName)

                                            // source foreign keys
                                            relationship.ForeignFieldNames = append(relationship.ForeignFieldNames, foreignField.Name)
                                            relationship.ForeignDBNames = append(relationship.ForeignDBNames, foreignField.DBName)
                                        }
                                    }
                                }

                                if len(relationship.ForeignFieldNames) != 0 {
                                    relationship.Kind = "belongs_to"
                                    field.Relationship = relationship
                                }
                            }
                        }(field)
                    default:
                        field.IsNormal = true
                    }
                }
            }

            // Even it is ignored, also possible to decode db value into the field
            if value, ok := field.TagSettingsGet("COLUMN"); ok {
                field.DBName = value
            } else {
                field.DBName = ToColumnName(fieldStruct.Name)
            }

            modelStruct.StructFields = append(modelStruct.StructFields, field)
        }
    }

    if len(modelStruct.PrimaryFields) == 0 {
        if field := getForeignField("id", modelStruct.StructFields); field != nil {
            field.IsPrimaryKey = true
            modelStruct.PrimaryFields = append(modelStruct.PrimaryFields, field)
        }
    }

    modelStructsMap.Store(hashKey, &modelStruct)

    return &modelStruct
}

It reads much more easily if we first fold the for loop in the middle:

// GetModelStruct get value's model struct, relationships based on struct and tag definition
func (scope *Scope) GetModelStruct() *ModelStruct {
    var modelStruct ModelStruct
    // Scope value can't be nil
    if scope.Value == nil {
        return &modelStruct
    }

    reflectType := reflect.ValueOf(scope.Value).Type()
    for reflectType.Kind() == reflect.Slice || reflectType.Kind() == reflect.Ptr {
        reflectType = reflectType.Elem()
    }

    // Scope value need to be a struct
    if reflectType.Kind() != reflect.Struct {
        return &modelStruct
    }

    // Get Cached model struct
    isSingularTable := false
    if scope.db != nil && scope.db.parent != nil {
        scope.db.parent.RLock()
        isSingularTable = scope.db.parent.singularTable
        scope.db.parent.RUnlock()
    }

    hashKey := struct {
        singularTable bool
        reflectType   reflect.Type
    }{isSingularTable, reflectType}
    if value, ok := modelStructsMap.Load(hashKey); ok && value != nil {
        return value.(*ModelStruct)
    }

    modelStruct.ModelType = reflectType

    // Get all fields
    for i := 0; i < reflectType.NumField(); i++ {
        ... // folded, skipped for now
    }

    if len(modelStruct.PrimaryFields) == 0 {
        if field := getForeignField("id", modelStruct.StructFields); field != nil {
            field.IsPrimaryKey = true
            modelStruct.PrimaryFields = append(modelStruct.PrimaryFields, field)
        }
    }

    modelStructsMap.Store(hashKey, &modelStruct)

    return &modelStruct
}

The function starts by declaring var modelStruct ModelStruct, which is also the value it eventually returns.

It first checks that scope.Value is not nil; if it is, it returns immediately.

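Next it resolves the concrete type of scope.Value. For slices and pointers, it keeps calling Elem() until it reaches the underlying element type: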

reflectType := reflect.ValueOf(scope.Value).Type()
for reflectType.Kind() == reflect.Slice || reflectType.Kind() == reflect.Ptr {
  reflectType = reflectType.Elem()
}

If the resolved type is not a struct, it also returns immediately.

Then it checks whether a cached ModelStruct already exists:

// Get Cached model struct
isSingularTable := false
if scope.db != nil && scope.db.parent != nil {
  scope.db.parent.RLock()
  isSingularTable = scope.db.parent.singularTable
  scope.db.parent.RUnlock()
}

hashKey := struct {
  singularTable bool
  reflectType   reflect.Type
}{isSingularTable, reflectType}
if value, ok := modelStructsMap.Load(hashKey); ok && value != nil {
  return value.(*ModelStruct)
}

modelStructsMap is a package-level variable that serves as a shared cache:

var modelStructsMap sync.Map

If the entry is found in modelStructsMap, the cached value is returned directly. Otherwise, parsing starts by recording the model type:

modelStruct.ModelType = reflectType

Skipping the Get all fields part for now, let's look at what comes after it:

if len(modelStruct.PrimaryFields) == 0 {
  if field := getForeignField("id", modelStruct.StructFields); field != nil {
    field.IsPrimaryKey = true
    modelStruct.PrimaryFields = append(modelStruct.PrimaryFields, field)
  }
}

If no primary key was declared, the field whose column name is id (usually the ID field) is made the primary key.

modelStructsMap.Store(hashKey, &modelStruct)

return &modelStruct

The parsed result is stored in modelStructsMap as a cache, which speeds up later parsing of the same type, and is then returned.

That covers the whole parsing flow. Apart from how the fields are collected, everything else should now be clear.

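The cache used during parsing is worth borrowing: a sync.Map can be shared safely across goroutines.
Another point is using a struct as the key, which covers both singular and plural table names: the same Go type gets a separate cache entry for each setting.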

Field parsing

The walk-through above skipped field parsing, which is a very important part: most of the code in GetModelStruct lives here.

for i := 0; i < reflectType.NumField(); i++ {

reflectType.NumField() returns the number of fields in the struct.

if fieldStruct := reflectType.Field(i); ast.IsExported(fieldStruct.Name) {

Only exported fields are parsed. reflectType.Field(i) returns the i-th field of the struct.

field := &StructField{
  Struct:      fieldStruct,
  Name:        fieldStruct.Name,
  Names:       []string{fieldStruct.Name},
  Tag:         fieldStruct.Tag,
  TagSettings: parseTagSetting(fieldStruct.Tag),
}

This initializes a StructField; most of its information comes from fieldStruct.

This part is very helpful for learning how to parse struct tags, so let's look at it closely.

fieldStruct.Tag returns the field's tag. For example, in:

type Model struct {
  ID        uint `gorm:"primary_key"`
  CreatedAt time.Time
  UpdatedAt time.Time
  DeletedAt *time.Time
}

the tag is the gorm:"primary_key" part of the ID field.

fieldStruct.Name returns the field's name.

Now let's see how the tag string is actually parsed:

func parseTagSetting(tags reflect.StructTag) map[string]string {
    setting := map[string]string{}
    for _, str := range []string{tags.Get("sql"), tags.Get("gorm")} {
        if str == "" {
            continue
        }
        tags := strings.Split(str, ";")
        for _, value := range tags {
            v := strings.Split(value, ":")
            k := strings.TrimSpace(strings.ToUpper(v[0]))
            if len(v) >= 2 {
                setting[k] = strings.Join(v[1:], ":")
            } else {
                setting[k] = k
            }
        }
    }
    return setting
}

tags has type reflect.StructTag, which provides some handy methods; for example, Get returns the value stored under a given key.
Here it reads the sql and gorm keys.

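Within each of them, options are separated by ;. An option is either a key/value pair separated by : (the key is upper-cased, and everything after the first : is kept as the value),
or a single bare value, in which case the upper-cased name is stored as both key and value.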

Let's look at a concrete example:

type User struct {
  gorm.Model
  Name         string
  Age          sql.NullInt64
  Birthday     *time.Time
  Email        string  `gorm:"type:varchar(100);unique_index"`
  Role         string  `gorm:"size:255"` // set the field size to 255
  MemberNumber *string `gorm:"unique;not null"` // make the member number unique and not null
  Num          int     `gorm:"AUTO_INCREMENT"` // make Num auto-increment
  Address      string  `gorm:"index:addr"` // create an index named addr on the Address field
  IgnoreMe     int     `gorm:"-"` // ignore this field
}

For example, the gorm part of the Email field's tag has two options: type:varchar(100) and unique_index.

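Combined with the parseTagSetting code above, this shows how the field's tag is parsed: the result maps TYPE to varchar(100) and UNIQUE_INDEX to UNIQUE_INDEX.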

An exported field can still be ignored: just set the - option.

// is ignored field
if _, ok := field.TagSettingsGet("-"); ok {
  field.IsIgnored = true
}

Next, the individual options are interpreted. Most of the code lives here, so let's go through it bit by bit:

if _, ok := field.TagSettingsGet("PRIMARY_KEY"); ok {
  field.IsPrimaryKey = true
  modelStruct.PrimaryFields = append(modelStruct.PrimaryFields, field)
}

if _, ok := field.TagSettingsGet("DEFAULT"); ok && !field.IsPrimaryKey {
  field.HasDefaultValue = true
}

if _, ok := field.TagSettingsGet("AUTO_INCREMENT"); ok && !field.IsPrimaryKey {
  field.HasDefaultValue = true
}

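This sets the IsPrimaryKey and HasDefaultValue attributes. A field tagged PRIMARY_KEY is also appended to PrimaryFields, and a non-primary-key field with DEFAULT or AUTO_INCREMENT is marked as having a default value.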

indirectType := fieldStruct.Type
for indirectType.Kind() == reflect.Ptr {
  indirectType = indirectType.Elem()
}

This gets the field's type, dereferencing pointers until a non-pointer type is reached.

fieldValue := reflect.New(indirectType).Interface()

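reflect.New creates a pointer to a new zero value of that type, and Interface() wraps it as an interface{}. The type assertions that follow therefore test the pointer type (for example *sql.NullInt64), which matters because methods such as Scan are usually defined on pointer receivers.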

Then comes a series of checks on fieldValue's type; let's go through them one at a time.

  • Check 1
if _, isScanner := fieldValue.(sql.Scanner); isScanner {
  // is scanner
  field.IsScanner, field.IsNormal = true, true
  if indirectType.Kind() == reflect.Struct {
    for i := 0; i < indirectType.NumField(); i++ {
      for key, value := range parseTagSetting(indirectType.Field(i).Tag) {
        if _, ok := field.TagSettingsGet(key); !ok {
          field.TagSettingsSet(key, value)
        }
      }
    }
  }
}

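If the type implements the sql.Scanner interface, both IsScanner and IsNormal are set to true. If the type is also a struct, the tag settings of each of its inner fields are merged in, without overriding settings the field already has.

  • Check 2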

else if _, isTime := fieldValue.(*time.Time); isTime {
  // is time
  field.IsNormal = true
}

If it is a *time.Time, IsNormal is set to true.

  • Check 3

If the tag settings contain EMBEDDED, or the field is anonymous (fieldStruct.Anonymous, as with an embedded gorm.Model), the field is treated as an embedded struct:

else if _, ok := field.TagSettingsGet("EMBEDDED"); ok || fieldStruct.Anonymous {
  // is embedded struct
  for _, subField := range scope.New(fieldValue).GetModelStruct().StructFields {
    subField = subField.clone()
    subField.Names = append([]string{fieldStruct.Name}, subField.Names...)
    if prefix, ok := field.TagSettingsGet("EMBEDDED_PREFIX"); ok {
      subField.DBName = prefix + subField.DBName
    }

    if subField.IsPrimaryKey {
      if _, ok := subField.TagSettingsGet("PRIMARY_KEY"); ok {
        modelStruct.PrimaryFields = append(modelStruct.PrimaryFields, subField)
      } else {
        subField.IsPrimaryKey = false
      }
    }

    if subField.Relationship != nil && subField.Relationship.JoinTableHandler != nil {
      if joinTableHandler, ok := subField.Relationship.JoinTableHandler.(*JoinTableHandler); ok {
        newJoinTableHandler := &JoinTableHandler{}
        newJoinTableHandler.Setup(subField.Relationship, joinTableHandler.TableName, reflectType, joinTableHandler.Destination.ModelType)
        subField.Relationship.JoinTableHandler = newJoinTableHandler
      }
    }

    modelStruct.StructFields = append(modelStruct.StructFields, subField)
  }
  continue
}

Let's go through it step by step.

for _, subField := range scope.New(fieldValue).GetModelStruct().StructFields {

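It builds the ModelStruct of the embedded type (by calling GetModelStruct recursively on a new scope) and iterates over each StructField in its StructFields.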

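subField = subField.clone() makes the following changes on a copy, so the StructField cached for the embedded type itself is not modified.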

subField.Names = append([]string{fieldStruct.Name}, subField.Names...)
if prefix, ok := field.TagSettingsGet("EMBEDDED_PREFIX"); ok {
  subField.DBName = prefix + subField.DBName
}

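This resets Names and DBName: the outer field name is prepended to Names, and if EMBEDDED_PREFIX is set, the prefix is prepended to DBName.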

if subField.IsPrimaryKey {
  if _, ok := subField.TagSettingsGet("PRIMARY_KEY"); ok {
    modelStruct.PrimaryFields = append(modelStruct.PrimaryFields, subField)
  } else {
    subField.IsPrimaryKey = false
  }
}

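If subField is a primary key and carries an explicit PRIMARY_KEY tag option, it is added to modelStruct.PrimaryFields.
Otherwise, subField.IsPrimaryKey is reset to false, so a primary key that the embedded type only inferred (through the id fallback) does not carry over to the outer model.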

if subField.Relationship != nil && subField.Relationship.JoinTableHandler != nil {
  if joinTableHandler, ok := subField.Relationship.JoinTableHandler.(*JoinTableHandler); ok {
    newJoinTableHandler := &JoinTableHandler{}
    newJoinTableHandler.Setup(subField.Relationship, joinTableHandler.TableName, reflectType, joinTableHandler.Destination.ModelType)
    subField.Relationship.JoinTableHandler = newJoinTableHandler
  }
}

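If the subfield has a relationship with a JoinTableHandler, a fresh JoinTableHandler is set up for it, using the outer struct type (reflectType) as the source model.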

modelStruct.StructFields = append(modelStruct.StructFields, subField)

Finally, subField is appended to modelStruct.StructFields.

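After the loop, continue moves on to the next field of the outer struct, which ends the EMBEDDED branch. Because of this continue, the column-name handling and the append at the end of the loop body are skipped: the embedded field itself is not stored, only its subfields are.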

  • Check 4

If none of the three checks above matches, we reach the final else branch.

And inside it is yet another switch, which is honestly a bit depressing.

switch indirectType.Kind() { mainly distinguishes slices and structs. Let's look at the default case first:

default:
  field.IsNormal = true
}

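Both case reflect.Slice: and case reflect.Struct: wrap their work in a deferred function, so it runs only when GetModelStruct returns, after every field and the primary-key fallback have been processed. The slice case looks like this: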

case reflect.Slice:
  defer func(field *StructField) {
    var (
      relationship           = &Relationship{}
      toScope                = scope.New(reflect.New(field.Struct.Type).Interface())
      foreignKeys            []string
      associationForeignKeys []string
      elemType               = field.Struct.Type
    )

    if foreignKey, _ := field.TagSettingsGet("FOREIGNKEY"); foreignKey != "" {
      foreignKeys = strings.Split(foreignKey, ",")
    }

    if foreignKey, _ := field.TagSettingsGet("ASSOCIATION_FOREIGNKEY"); foreignKey != "" {
      associationForeignKeys = strings.Split(foreignKey, ",")
    } else if foreignKey, _ := field.TagSettingsGet("ASSOCIATIONFOREIGNKEY"); foreignKey != "" {
      associationForeignKeys = strings.Split(foreignKey, ",")
    }

    for elemType.Kind() == reflect.Slice || elemType.Kind() == reflect.Ptr {
      elemType = elemType.Elem()
    }

    if elemType.Kind() == reflect.Struct {
      if many2many, _ := field.TagSettingsGet("MANY2MANY"); many2many != "" {
        relationship.Kind = "many_to_many"

        { // Foreign Keys for Source
          joinTableDBNames := []string{}

          if foreignKey, _ := field.TagSettingsGet("JOINTABLE_FOREIGNKEY"); foreignKey != "" {
            joinTableDBNames = strings.Split(foreignKey, ",")
          }

          // if no foreign keys defined with tag
          if len(foreignKeys) == 0 {
            for _, field := range modelStruct.PrimaryFields {
              foreignKeys = append(foreignKeys, field.DBName)
            }
          }

          for idx, foreignKey := range foreignKeys {
            if foreignField := getForeignField(foreignKey, modelStruct.StructFields); foreignField != nil {
              // source foreign keys (db names)
              relationship.ForeignFieldNames = append(relationship.ForeignFieldNames, foreignField.DBName)

              // setup join table foreign keys for source
              if len(joinTableDBNames) > idx {
                // if defined join table's foreign key
                relationship.ForeignDBNames = append(relationship.ForeignDBNames, joinTableDBNames[idx])
              } else {
                defaultJointableForeignKey := ToColumnName(reflectType.Name()) + "_" + foreignField.DBName
                relationship.ForeignDBNames = append(relationship.ForeignDBNames, defaultJointableForeignKey)
              }
            }
          }
        }

        { // Foreign Keys for Association (Destination)
          associationJoinTableDBNames := []string{}

          if foreignKey, _ := field.TagSettingsGet("ASSOCIATION_JOINTABLE_FOREIGNKEY"); foreignKey != "" {
            associationJoinTableDBNames = strings.Split(foreignKey, ",")
          }

          // if no association foreign keys defined with tag
          if len(associationForeignKeys) == 0 {
            for _, field := range toScope.PrimaryFields() {
              associationForeignKeys = append(associationForeignKeys, field.DBName)
            }
          }

          for idx, name := range associationForeignKeys {
            if field, ok := toScope.FieldByName(name); ok {
              // association foreign keys (db names)
              relationship.AssociationForeignFieldNames = append(relationship.AssociationForeignFieldNames, field.DBName)

              // setup join table foreign keys for association
              if len(associationJoinTableDBNames) > idx {
                relationship.AssociationForeignDBNames = append(relationship.AssociationForeignDBNames, associationJoinTableDBNames[idx])
              } else {
                // join table foreign keys for association
                joinTableDBName := ToColumnName(elemType.Name()) + "_" + field.DBName
                relationship.AssociationForeignDBNames = append(relationship.AssociationForeignDBNames, joinTableDBName)
              }
            }
          }
        }

        joinTableHandler := JoinTableHandler{}
        joinTableHandler.Setup(relationship, many2many, reflectType, elemType)
        relationship.JoinTableHandler = &joinTableHandler
        field.Relationship = relationship
      } else {
        // User has many comments, associationType is User, comment use UserID as foreign key
        var associationType = reflectType.Name()
        var toFields = toScope.GetStructFields()
        relationship.Kind = "has_many"

        if polymorphic, _ := field.TagSettingsGet("POLYMORPHIC"); polymorphic != "" {
          // Dog has many toys, tag polymorphic is Owner, then associationType is Owner
          // Toy use OwnerID, OwnerType ('dogs') as foreign key
          if polymorphicType := getForeignField(polymorphic+"Type", toFields); polymorphicType != nil {
            associationType = polymorphic
            relationship.PolymorphicType = polymorphicType.Name
            relationship.PolymorphicDBName = polymorphicType.DBName
            // if Dog has multiple set of toys set name of the set (instead of default 'dogs')
            if value, ok := field.TagSettingsGet("POLYMORPHIC_VALUE"); ok {
              relationship.PolymorphicValue = value
            } else {
              relationship.PolymorphicValue = scope.TableName()
            }
            polymorphicType.IsForeignKey = true
          }
        }

        // if no foreign keys defined with tag
        if len(foreignKeys) == 0 {
          // if no association foreign keys defined with tag
          if len(associationForeignKeys) == 0 {
            for _, field := range modelStruct.PrimaryFields {
              foreignKeys = append(foreignKeys, associationType+field.Name)
              associationForeignKeys = append(associationForeignKeys, field.Name)
            }
          } else {
            // generate foreign keys from defined association foreign keys
            for _, scopeFieldName := range associationForeignKeys {
              if foreignField := getForeignField(scopeFieldName, modelStruct.StructFields); foreignField != nil {
                foreignKeys = append(foreignKeys, associationType+foreignField.Name)
                associationForeignKeys = append(associationForeignKeys, foreignField.Name)
              }
            }
          }
        } else {
          // generate association foreign keys from foreign keys
          if len(associationForeignKeys) == 0 {
            for _, foreignKey := range foreignKeys {
              if strings.HasPrefix(foreignKey, associationType) {
                associationForeignKey := strings.TrimPrefix(foreignKey, associationType)
                if foreignField := getForeignField(associationForeignKey, modelStruct.StructFields); foreignField != nil {
                  associationForeignKeys = append(associationForeignKeys, associationForeignKey)
                }
              }
            }
            if len(associationForeignKeys) == 0 && len(foreignKeys) == 1 {
              associationForeignKeys = []string{scope.PrimaryKey()}
            }
          } else if len(foreignKeys) != len(associationForeignKeys) {
            scope.Err(errors.New("invalid foreign keys, should have same length"))
            return
          }
        }

        for idx, foreignKey := range foreignKeys {
          if foreignField := getForeignField(foreignKey, toFields); foreignField != nil {
            if associationField := getForeignField(associationForeignKeys[idx], modelStruct.StructFields); associationField != nil {
              // source foreign keys
              foreignField.IsForeignKey = true
              relationship.AssociationForeignFieldNames = append(relationship.AssociationForeignFieldNames, associationField.Name)
              relationship.AssociationForeignDBNames = append(relationship.AssociationForeignDBNames, associationField.DBName)

              // association foreign keys
              relationship.ForeignFieldNames = append(relationship.ForeignFieldNames, foreignField.Name)
              relationship.ForeignDBNames = append(relationship.ForeignDBNames, foreignField.DBName)
            }
          }
        }

        if len(relationship.ForeignFieldNames) != 0 {
          field.Relationship = relationship
        }
      }
    } else {
      field.IsNormal = true
    }
  }(field)

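It mainly handles relationship tags (FOREIGNKEY, ASSOCIATION_FOREIGNKEY, MANY2MANY, POLYMORPHIC and others): a slice of structs becomes a many_to_many relationship when MANY2MANY is set, otherwise a has_many relationship (if matching foreign keys are found); any other slice is treated as a normal field.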
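One rule from the has_many branch is easy to isolate. When the tag sets neither FOREIGNKEY nor ASSOCIATION_FOREIGNKEY, each primary field of the owning type is turned into a foreign key name by prefixing associationType, matching the code comment "User has many comments, ... comment use UserID as foreign key". The helper below is hypothetical and stops there; GORM then keeps only the keys that match an actual field of the element type via getForeignField.

```go
package main

import "fmt"

// defaultHasManyKeys isolates the default branch: each primary field P of the
// owner type T yields foreign key T+P (on the element type) and association
// foreign key P (on the owner).
func defaultHasManyKeys(associationType string, primaryFieldNames []string) (foreignKeys, associationForeignKeys []string) {
	for _, name := range primaryFieldNames {
		foreignKeys = append(foreignKeys, associationType+name)
		associationForeignKeys = append(associationForeignKeys, name)
	}
	return foreignKeys, associationForeignKeys
}

func main() {
	// "User has many comments": Comment is expected to have a UserID field.
	fmt.Println(defaultHasManyKeys("User", []string{"ID"}))
	// With POLYMORPHIC:"Owner", associationType becomes Owner (OwnerID).
	fmt.Println(defaultHasManyKeys("Owner", []string{"ID"}))
}
```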

Let's skip the rest of this part for now and come back to it when studying how relationships are implemented.

Once the whole if chain is done, the column name is resolved, and finally the field is appended to modelStruct.StructFields:

// Even it is ignored, also possible to decode db value into the field
if value, ok := field.TagSettingsGet("COLUMN"); ok {
  field.DBName = value
} else {
  field.DBName = ToColumnName(fieldStruct.Name)
}

modelStruct.StructFields = append(modelStruct.StructFields, field)
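A sketch of this last step. GORM's ToColumnName is not shown here, so toSnake below is a simplified, hypothetical snake_case converter that gives the usual results for names like CreatedAt and UserID but does not claim to reproduce GORM's exact rules; tag settings are modeled as a plain map instead of field.TagSettingsGet.

```go
package main

import (
	"fmt"
	"strings"
	"unicode"
)

// toSnake inserts an underscore before an upper-case letter that follows a
// lower-case letter or digit, or that starts a new word after a run of
// capitals (HTTPServer -> http_server), and lower-cases everything.
func toSnake(name string) string {
	runes := []rune(name)
	var b strings.Builder
	for i, r := range runes {
		if unicode.IsUpper(r) {
			if i > 0 {
				prev := runes[i-1]
				nextLower := i+1 < len(runes) && unicode.IsLower(runes[i+1])
				if unicode.IsLower(prev) || unicode.IsDigit(prev) || (unicode.IsUpper(prev) && nextLower) {
					b.WriteByte('_')
				}
			}
			b.WriteRune(unicode.ToLower(r))
		} else {
			b.WriteRune(r)
		}
	}
	return b.String()
}

// dbName mirrors the final step: an explicit COLUMN setting wins, otherwise
// the column name is derived from the Go field name.
func dbName(tagSettings map[string]string, fieldName string) string {
	if value, ok := tagSettings["COLUMN"]; ok {
		return value
	}
	return toSnake(fieldName)
}

func main() {
	fmt.Println(dbName(nil, "CreatedAt")) // created_at
	fmt.Println(dbName(nil, "UserID"))    // user_id
	fmt.Println(dbName(map[string]string{"COLUMN": "member_no"}, "MemberNumber")) // member_no
}
```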

Summary

That is the whole model-parsing process. Most of the work lies in parsing each field: its tags and its relationships with other models.

Conclusion

We have now seen how models are defined and parsed, but there is much more to models, such as turning a model into a table in the database.


帅气猫咪 published an article · 2019-12-22

01 GORM Source Code Walkthrough

Introduction

A walkthrough of the GORM source code, based on v1.9.11.

Getting started

The getting-started example from the official documentation is:

package main

import (
  "github.com/jinzhu/gorm"
  _ "github.com/jinzhu/gorm/dialects/sqlite"
)

type Product struct {
  gorm.Model
  Code string
  Price uint
}

func main() {
  db, err := gorm.Open("sqlite3", "test.db")
  if err != nil {
    panic("failed to connect database")
  }
  defer db.Close()

  // Migrate the schema
  db.AutoMigrate(&Product{})

  // Create
  db.Create(&Product{Code: "L1212", Price: 1000})

  // Read
  var product Product
  db.First(&product, 1) // find the product with id 1
  db.First(&product, "code = ?", "L1212") // find the product with code L1212

  // Update - set the product's price to 2000
  db.Model(&product).Update("Price", 2000)

  // Delete - delete the product
  db.Delete(&product)
}

Database connection

Let's start with gorm.Open and see how the database connection is established:

// Open initialize a new db connection, need to import driver first, e.g:
//
//     import _ "github.com/go-sql-driver/mysql"
//     func main() {
//       db, err := gorm.Open("mysql", "user:password@/dbname?charset=utf8&parseTime=True&loc=Local")
//     }
// GORM has wrapped some drivers, for easier to remember driver's import path, so you could import the mysql driver with
//    import _ "github.com/jinzhu/gorm/dialects/mysql"
//    // import _ "github.com/jinzhu/gorm/dialects/postgres"
//    // import _ "github.com/jinzhu/gorm/dialects/sqlite"
//    // import _ "github.com/jinzhu/gorm/dialects/mssql"
func Open(dialect string, args ...interface{}) (db *DB, err error) {
    if len(args) == 0 {
        err = errors.New("invalid database source")
        return nil, err
    }
    var source string
    var dbSQL SQLCommon
    var ownDbSQL bool

    switch value := args[0].(type) {
    case string:
        var driver = dialect
        if len(args) == 1 {
            source = value
        } else if len(args) >= 2 {
            driver = value
            source = args[1].(string)
        }
        dbSQL, err = sql.Open(driver, source)
        ownDbSQL = true
    case SQLCommon:
        dbSQL = value
        ownDbSQL = false
    default:
        return nil, fmt.Errorf("invalid database source: %v is not a valid type", value)
    }

    db = &DB{
        db:        dbSQL,
        logger:    defaultLogger,
        callbacks: DefaultCallback,
        dialect:   newDialect(dialect, dbSQL),
    }
    db.parent = db
    if err != nil {
        return
    }
    // Send a ping to make sure the database connection is alive.
    if d, ok := dbSQL.(*sql.DB); ok {
        if err = d.Ping(); err != nil && ownDbSQL {
            d.Close()
        }
    }
    return
}

gorm.Open takes a dialect name (such as "sqlite3" or "mysql") followed by one or more connection arguments.

switch 语句中, 可以发现如果第一个参数是 string 类型, 实际上是通过 Golang 中的 sql 模块连接:

dbSQL, err = sql.Open(driver, source)

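Alternatively, you can pass an existing value that implements the SQLCommon interface (a *sql.DB, for instance); in that case ownDbSQL is false, meaning GORM did not open the connection itself.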

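It then initializes a gorm.DB instance and, at the end, sends a ping to check that the database connection is alive; if the ping fails and GORM opened the connection itself, the connection is closed again.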
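The argument handling can be exercised without a database. In the sketch below, pinger is a hypothetical one-method stand-in for SQLCommon, and resolveSource only reports what gorm.Open would do instead of calling sql.Open.

```go
package main

import (
	"errors"
	"fmt"
)

// pinger is a hypothetical one-method stand-in for GORM's SQLCommon.
type pinger interface {
	Ping() error
}

// fakeConn is a hypothetical existing connection.
type fakeConn struct{}

func (fakeConn) Ping() error { return nil }

// resolveSource reproduces the argument handling of gorm.Open: it reports the
// driver and DSN that would be passed to sql.Open, or whether an existing
// connection would be reused (ownDbSQL == false in GORM).
func resolveSource(dialect string, args ...interface{}) (driver, source string, reused bool, err error) {
	if len(args) == 0 {
		return "", "", false, errors.New("invalid database source")
	}
	switch value := args[0].(type) {
	case string:
		driver, source = dialect, value
		if len(args) >= 2 {
			// two strings: the first is the driver name, the second the DSN
			s, ok := args[1].(string)
			if !ok {
				return "", "", false, fmt.Errorf("invalid database source: %v", args[1])
			}
			driver, source = value, s
		}
		return driver, source, false, nil
	case pinger:
		return dialect, "", true, nil
	default:
		return "", "", false, fmt.Errorf("invalid database source: %v is not a valid type", value)
	}
}

func main() {
	fmt.Println(resolveSource("sqlite3", "test.db"))
	fmt.Println(resolveSource("postgres", "pgx", "postgres://localhost/db"))
	fmt.Println(resolveSource("mysql", fakeConn{}))
	fmt.Println(resolveSource("mysql", 42))
}
```

One design difference: GORM uses an unchecked args[1].(string), which panics if the second argument is not a string; the sketch returns an error instead.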

Here is the gorm.DB struct:

// DB contains information for current db connection
type DB struct {
    sync.RWMutex
    Value        interface{}
    Error        error
    RowsAffected int64

    // single db
    db                SQLCommon
    blockGlobalUpdate bool
    logMode           logModeValue
    logger            logger
    search            *search
    values            sync.Map

    // global db
    parent        *DB
    callbacks     *Callback
    dialect       Dialect
    singularTable bool

    // function to be used to override the creating of a new timestamp
    nowFuncOverride func() time.Time
}

gorm.DB embeds sync.RWMutex, a reader/writer mutual exclusion lock.
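Embedding promotes the mutex methods, so a gorm.DB value can call Lock, RLock and friends directly (the GetModelStruct code in this document calls scope.db.parent.RLock(), for example). A minimal sketch with a hypothetical settings type; SetSingular is an illustrative writer, not a GORM method.

```go
package main

import (
	"fmt"
	"sync"
)

// settings is hypothetical; like gorm.DB it embeds sync.RWMutex, so Lock,
// Unlock, RLock and RUnlock are promoted to settings itself.
type settings struct {
	sync.RWMutex
	singularTable bool
}

// SetSingular takes the write lock.
func (s *settings) SetSingular(v bool) {
	s.Lock()
	defer s.Unlock()
	s.singularTable = v
}

// Singular takes the read lock, the way GetModelStruct reads singularTable.
func (s *settings) Singular() bool {
	s.RLock()
	defer s.RUnlock()
	return s.singularTable
}

func main() {
	var s settings
	var wg sync.WaitGroup
	for i := 0; i < 4; i++ {
		wg.Add(1)
		go func() {
			defer wg.Done()
			s.SetSingular(true)
			_ = s.Singular()
		}()
	}
	wg.Wait()
	fmt.Println(s.Singular()) // true
}
```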

gorm.DB

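We have already seen the definition of the gorm.DB struct. As the getting-started example shows, every operation revolves around it,
so gorm.DB is the core struct. Let's look at the methods it implements.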

// New clone a new db connection without search conditions
func (s *DB) New() *DB {
    clone := s.clone()
    clone.search = nil
    clone.Value = nil
    return clone
}

type closer interface {
    Close() error
}

// Close close current db connection.  If database connection is not an io.Closer, returns an error.
func (s *DB) Close() error {
    if db, ok := s.parent.db.(closer); ok {
        return db.Close()
    }
    return errors.New("can't close current db")
}

These methods clone the DB handle (without search conditions) and close the database connection. Internally, New relies on s.clone():

func (s *DB) clone() *DB {
    db := &DB{
        db:                s.db,
        parent:            s.parent,
        logger:            s.logger,
        logMode:           s.logMode,
        Value:             s.Value,
        Error:             s.Error,
        blockGlobalUpdate: s.blockGlobalUpdate,
        dialect:           newDialect(s.dialect.GetName(), s.db),
        nowFuncOverride:   s.nowFuncOverride,
    }

    s.values.Range(func(k, v interface{}) bool {
        db.values.Store(k, v)
        return true
    })

    if s.search == nil {
        db.search = &search{limit: -1, offset: -1}
    } else {
        db.search = s.search.clone()
    }

    db.search.db = db
    return db
}

Skipping some simple getter/setter methods, let's move on:

// NewScope create a scope for current operation
func (s *DB) NewScope(value interface{}) *Scope {
    dbClone := s.clone()
    dbClone.Value = value
    scope := &Scope{db: dbClone, Value: value}
    if s.search != nil {
        scope.Search = s.search.clone()
    } else {
        scope.Search = &search{}
    }
    return scope
}

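NewScope creates a new Scope for the current operation: it holds a clone of the DB with Value set, plus a copy of the current search conditions (or an empty search).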

// QueryExpr returns the query as expr object
func (s *DB) QueryExpr() *expr {
    scope := s.NewScope(s.Value)
    scope.InstanceSet("skip_bindvar", true)
    scope.prepareQuerySQL()

    return Expr(scope.SQL, scope.SQLVars...)
}

// SubQuery returns the query as sub query
func (s *DB) SubQuery() *expr {
    scope := s.NewScope(s.Value)
    scope.InstanceSet("skip_bindvar", true)
    scope.prepareQuerySQL()

    return Expr(fmt.Sprintf("(%v)", scope.SQL), scope.SQLVars...)
}

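Both QueryExpr and SubQuery use NewScope: they build the query SQL in a new scope (with skip_bindvar set) and return it as an expression together with its variables. SubQuery additionally wraps the SQL in parentheses so it can be embedded in another query.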

Next come many query methods, such as Where:

// Where return a new relation, filter records with given conditions, accepts `map`, `struct` or `string` as conditions, refer http://jinzhu.github.io/gorm/crud.html#query
func (s *DB) Where(query interface{}, args ...interface{}) *DB {
    return s.clone().search.Where(query, args...).db
}

Let's skip these for now and study them in detail later, when exploring query expressions.

// Scopes pass current database connection to arguments `func(*DB) *DB`, which could be used to add conditions dynamically
//     func AmountGreaterThan1000(db *gorm.DB) *gorm.DB {
//         return db.Where("amount > ?", 1000)
//     }
//
//     func OrderStatus(status []string) func (db *gorm.DB) *gorm.DB {
//         return func (db *gorm.DB) *gorm.DB {
//             return db.Scopes(AmountGreaterThan1000).Where("status in (?)", status)
//         }
//     }
//
//     db.Scopes(AmountGreaterThan1000, OrderStatus([]string{"paid", "shipped"})).Find(&orders)
// Refer https://jinzhu.github.io/gorm/crud.html#scopes
func (s *DB) Scopes(funcs ...func(*DB) *DB) *DB {
    for _, f := range funcs {
        s = f(s)
    }
    return s
}

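Scopes is a hook that applies each function to the DB in turn, which makes it easy to add query conditions dynamically. This is a common pattern in languages where functions are first-class values.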
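The same pattern with a hypothetical, tiny query builder (Query, Where and SQL are illustrative, not GORM's API). Scopes is the same loop as above, and Where returns a copy, in the spirit of gorm.DB's s.clone(), so the base query stays unchanged.

```go
package main

import (
	"fmt"
	"strings"
)

// Query is a hypothetical query builder that only collects WHERE conditions.
type Query struct {
	conds []string
}

// Where returns a new Query instead of modifying the receiver.
func (q *Query) Where(cond string) *Query {
	conds := append(append([]string(nil), q.conds...), cond)
	return &Query{conds: conds}
}

// Scopes applies each function in order, like the loop in gorm's Scopes.
func (q *Query) Scopes(funcs ...func(*Query) *Query) *Query {
	for _, f := range funcs {
		q = f(q)
	}
	return q
}

func (q *Query) SQL() string {
	if len(q.conds) == 0 {
		return "SELECT * FROM orders"
	}
	return "SELECT * FROM orders WHERE " + strings.Join(q.conds, " AND ")
}

func AmountGreaterThan1000(q *Query) *Query {
	return q.Where("amount > 1000")
}

// OrderStatus returns a closure, as in the doc comment above. Values are
// concatenated only to keep the sketch short; real queries should use placeholders.
func OrderStatus(status []string) func(*Query) *Query {
	return func(q *Query) *Query {
		return q.Where("status IN ('" + strings.Join(status, "','") + "')")
	}
}

func main() {
	base := &Query{}
	q := base.Scopes(AmountGreaterThan1000, OrderStatus([]string{"paid", "shipped"}))
	fmt.Println(q.SQL())
	fmt.Println(base.SQL()) // base is untouched
}
```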

Transaction implementation

Let's see how transactions are implemented:

// Begin begins a transaction
func (s *DB) Begin() *DB {
    return s.BeginTx(context.Background(), &sql.TxOptions{})
}

// BeginTx begins a transaction with options
func (s *DB) BeginTx(ctx context.Context, opts *sql.TxOptions) *DB {
    c := s.clone()
    if db, ok := c.db.(sqlDb); ok && db != nil {
        tx, err := db.BeginTx(ctx, opts)
        c.db = interface{}(tx).(SQLCommon)

        c.dialect.SetDB(c.db)
        c.AddError(err)
    } else {
        c.AddError(ErrCantStartTransaction)
    }
    return c
}

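This is what happens when a transaction starts: if c.db implements the sqlDb interface (as *sql.DB does), its BeginTx method is called and the resulting *sql.Tx replaces c.db on the cloned DB; otherwise ErrCantStartTransaction is recorded.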

type sqlDb interface {
    Begin() (*sql.Tx, error)
    BeginTx(ctx context.Context, opts *sql.TxOptions) (*sql.Tx, error)
}

Next, committing a transaction:

// Commit commit a transaction
func (s *DB) Commit() *DB {
    var emptySQLTx *sql.Tx
    if db, ok := s.db.(sqlTx); ok && db != nil && db != emptySQLTx {
        s.AddError(db.Commit())
    } else {
        s.AddError(ErrInvalidTransaction)
    }
    return s
}

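Just like Begin: if s.db implements the sqlTx interface (as *sql.Tx does) and is not a nil *sql.Tx, its Commit method is called; otherwise ErrInvalidTransaction is recorded.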

type sqlTx interface {
    Commit() error
    Rollback() error
}

The sqlTx interface also has a Rollback method, so rolling back works the same way (an sql.ErrTxDone error is ignored):

// Rollback rollback a transaction
func (s *DB) Rollback() *DB {
    var emptySQLTx *sql.Tx
    if db, ok := s.db.(sqlTx); ok && db != nil && db != emptySQLTx {
        if err := db.Rollback(); err != nil && err != sql.ErrTxDone {
            s.AddError(err)
        }
    } else {
        s.AddError(ErrInvalidTransaction)
    }
    return s
}

// RollbackUnlessCommitted rollback a transaction if it has not yet been
// committed.
func (s *DB) RollbackUnlessCommitted() *DB {
    var emptySQLTx *sql.Tx
    if db, ok := s.db.(sqlTx); ok && db != nil && db != emptySQLTx {
        err := db.Rollback()
        // Ignore the error indicating that the transaction has already
        // been committed.
        if err != sql.ErrTxDone {
            s.AddError(err)
        }
    } else {
        s.AddError(ErrInvalidTransaction)
    }
    return s
}

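The only difference between RollbackUnlessCommitted and Rollback is that the former omits the err != nil check. Since AddError appears to ignore a nil error, that omission should not change behavior, and I still find it hard to see what the practical difference is.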

The author of RollbackUnlessCommitted gave the following example:

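func doTransaction(DB *gorm.DB) error {
  tx := DB.Begin()
  // Safe to defer unconditionally: after a successful Commit the rollback
  // fails with sql.ErrTxDone, which RollbackUnlessCommitted ignores.
  defer tx.RollbackUnlessCommitted()

  u := User{Name: "test"}
  if err := tx.Save(&u).Error; err != nil {
    return err
  }
  return tx.Commit().Error
}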

By comparison, the transaction example in the official documentation is:

func CreateAnimals(db *gorm.DB) error {
  // Note: once the transaction has started, use tx (not db) as the database handle
  tx := db.Begin()
  defer func() {
    if r := recover(); r != nil {
      tx.Rollback()
    }
  }()

  if err := tx.Error; err != nil {
    return err
  }

  if err := tx.Create(&Animal{Name: "Giraffe"}).Error; err != nil {
    tx.Rollback()
    return err
  }

  if err := tx.Create(&Animal{Name: "Lion"}).Error; err != nil {
    tx.Rollback()
    return err
  }

  return tx.Commit().Error
}

Summary

That's it for now; gorm.DB has many more methods left to explore later.

We mainly looked at how the database connection is established, which is essentially done through dbSQL, err = sql.Open(driver, source).

We also looked at transactions, which mainly rely on two interfaces:

type sqlDb interface {
    Begin() (*sql.Tx, error)
    BeginTx(ctx context.Context, opts *sql.TxOptions) (*sql.Tx, error)
}

type sqlTx interface {
    Commit() error
    Rollback() error
}

That said, I am still a little puzzled about RollbackUnlessCommitted versus Rollback, because I cannot work out what actually differs between them.

Since this is an ORM, model definitions should be the heart of it; the next step is to explore how Model is implemented.


帅气猫咪 published an article · 2019-12-15

04 Gin Source Code Walkthrough

Introduction

A walkthrough of the Gin source code, based on v1.5.0.

How the built-in middleware is implemented

We have already studied how middleware works in general; this time, let's see how Gin's built-in middleware is implemented.

recovery

// Recovery returns a middleware that recovers from any panics and writes a 500 if there was one.
func Recovery() HandlerFunc {
    return RecoveryWithWriter(DefaultErrorWriter)
}

The recovery middleware recovers from panics and responds with a 500 status.

Before reading the code, here is a quick introduction to the built-in recover function.

func recover() interface{}
The recover built-in function allows a program to manage behavior of a panicking goroutine. Executing a call to recover inside a deferred function (but not any function called by it) stops the panicking sequence by restoring normal execution and retrieves the error value passed to the call of panic. If recover is called outside the deferred function it will not stop a panicking sequence. In this case, or when the goroutine is not panicking, or if the argument supplied to panic was nil, recover returns nil. Thus the return value from recover reports whether the goroutine is panicking.

recover lets a program regain control of a panicking goroutine. It only has an effect when called directly inside a deferred function.

A simple example:

package main

import (
    "fmt"
)

func main() {
    defer func() {
        err := recover()
        if err != nil {
            fmt.Println("catch panic:", err)
        }
    }()

    panic("hello error")
}

Now let's look at how RecoveryWithWriter is implemented.

// RecoveryWithWriter returns a middleware for a given writer that recovers from any panics and writes a 500 if there was one.
func RecoveryWithWriter(out io.Writer) HandlerFunc {
    var logger *log.Logger
    if out != nil {
        logger = log.New(out, "\n\n\x1b[31m", log.LstdFlags)
    }
    return func(c *Context) {
        defer func() {
            if err := recover(); err != nil {
                // Check for a broken connection, as it is not really a
                // condition that warrants a panic stack trace.
                var brokenPipe bool
                if ne, ok := err.(*net.OpError); ok {
                    if se, ok := ne.Err.(*os.SyscallError); ok {
                        if strings.Contains(strings.ToLower(se.Error()), "broken pipe") || strings.Contains(strings.ToLower(se.Error()), "connection reset by peer") {
                            brokenPipe = true
                        }
                    }
                }
                if logger != nil {
                    stack := stack(3)
                    httpRequest, _ := httputil.DumpRequest(c.Request, false)
                    headers := strings.Split(string(httpRequest), "\r\n")
                    for idx, header := range headers {
                        current := strings.Split(header, ":")
                        if current[0] == "Authorization" {
                            headers[idx] = current[0] + ": *"
                        }
                    }
                    if brokenPipe {
                        logger.Printf("%s\n%s%s", err, string(httpRequest), reset)
                    } else if IsDebugging() {
                        logger.Printf("[Recovery] %s panic recovered:\n%s\n%s\n%s%s",
                            timeFormat(time.Now()), strings.Join(headers, "\r\n"), err, stack, reset)
                    } else {
                        logger.Printf("[Recovery] %s panic recovered:\n%s\n%s%s",
                            timeFormat(time.Now()), err, stack, reset)
                    }
                }

                // If the connection is dead, we can't write a status to it.
                if brokenPipe {
                    c.Error(err.(error)) // nolint: errcheck
                    c.Abort()
                } else {
                    c.AbortWithStatus(http.StatusInternalServerError)
                }
            }
        }()
        c.Next()
    }
}

At a high level, the returned func(c *Context) middleware has two main parts: the deferred handler and the call to c.Next().

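On the normal path the middleware does nothing except call c.Next(), which hands control down the chain to the remaining middleware and the handler.
The deferred function runs once c.Next() returns. If anything further down the chain panics, it runs while the panic unwinds the stack, and recover() then returns the panic value so the cleanup below can take over.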
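To see this control flow in isolation, here is a stripped-down sketch of the same pattern. handlerFunc and ctx are hypothetical stand-ins for gin.HandlerFunc and *gin.Context; next uses the same loop as Context.Next, and abortWithStatus simply moves the index past the end of the chain.

```go
package main

// handlerFunc and ctx are simplified, hypothetical stand-ins for
// gin.HandlerFunc and *gin.Context.
type handlerFunc func(c *ctx)

type ctx struct {
	handlers []handlerFunc
	index    int
	status   int
	aborted  bool
}

// next runs the remaining handlers, using the same loop as Context.Next.
func (c *ctx) next() {
	c.index++
	for c.index < len(c.handlers) {
		c.handlers[c.index](c)
		c.index++
	}
}

// abortWithStatus records the status and moves the index past the chain.
func (c *ctx) abortWithStatus(code int) {
	c.status = code
	c.aborted = true
	c.index = len(c.handlers)
}

// recovery defers a recover() call, then hands control down the chain.
func recovery() handlerFunc {
	return func(c *ctx) {
		defer func() {
			// Runs after c.next() returns, or while a later panic unwinds.
			if r := recover(); r != nil {
				c.abortWithStatus(500) // 500 Internal Server Error
			}
		}()
		c.next()
	}
}

// run starts a chain from index -1, roughly as Gin does after reset.
func run(handlers ...handlerFunc) *ctx {
	c := &ctx{handlers: handlers, index: -1, status: 200}
	c.next()
	return c
}
```

A panic in the second handler unwinds through next, the deferred recover turns it into a 500, and the third handler never runs.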

First, it checks whether the connection is already broken:

// Check for a broken connection, as it is not really a
// condition that warrants a panic stack trace.
var brokenPipe bool
if ne, ok := err.(*net.OpError); ok {
  if se, ok := ne.Err.(*os.SyscallError); ok {
    if strings.Contains(strings.ToLower(se.Error()), "broken pipe") || strings.Contains(strings.ToLower(se.Error()), "connection reset by peer") {
      brokenPipe = true
    }
  }
}

Then it logs the panic:

if logger != nil {
  stack := stack(3)
  httpRequest, _ := httputil.DumpRequest(c.Request, false)
  headers := strings.Split(string(httpRequest), "\r\n")
  for idx, header := range headers {
    current := strings.Split(header, ":")
    if current[0] == "Authorization" {
      headers[idx] = current[0] + ": *"
    }
  }
  if brokenPipe {
    logger.Printf("%s\n%s%s", err, string(httpRequest), reset)
  } else if IsDebugging() {
    logger.Printf("[Recovery] %s panic recovered:\n%s\n%s\n%s%s",
      timeFormat(time.Now()), strings.Join(headers, "\r\n"), err, stack, reset)
  } else {
    logger.Printf("[Recovery] %s panic recovered:\n%s\n%s%s",
      timeFormat(time.Now()), err, stack, reset)
  }
}

Finally, it handles the two connection states differently:

// If the connection is dead, we can't write a status to it.
if brokenPipe {
  c.Error(err.(error)) // nolint: errcheck
  c.Abort()
} else {
  c.AbortWithStatus(http.StatusInternalServerError)
}

Overall, there is nothing special here if you are already familiar with Go's built-in recover mechanism.

auth

The auth middleware implements Basic HTTP Authorization.

// BasicAuth returns a Basic HTTP Authorization middleware. It takes as argument a map[string]string where
// the key is the user name and the value is the password.
func BasicAuth(accounts Accounts) HandlerFunc {
    return BasicAuthForRealm(accounts, "")
}

Its implementation:

// BasicAuthForRealm returns a Basic HTTP Authorization middleware. It takes as arguments a map[string]string where
// the key is the user name and the value is the password, as well as the name of the Realm.
// If the realm is empty, "Authorization Required" will be used by default.
// (see http://tools.ietf.org/html/rfc2617#section-1.2)
func BasicAuthForRealm(accounts Accounts, realm string) HandlerFunc {
    if realm == "" {
        realm = "Authorization Required"
    }
    realm = "Basic realm=" + strconv.Quote(realm)
    pairs := processAccounts(accounts)
    return func(c *Context) {
        // Search user in the slice of allowed credentials
        user, found := pairs.searchCredential(c.requestHeader("Authorization"))
        if !found {
            // Credentials doesn't match, we return 401 and abort handlers chain.
            c.Header("WWW-Authenticate", realm)
            c.AbortWithStatus(http.StatusUnauthorized)
            return
        }

        // The user credentials was found, set user's id to key AuthUserKey in this context, the user's id can be read later using
        // c.MustGet(gin.AuthUserKey).
        c.Set(AuthUserKey, user)
    }
}

The pairs variable holds the precomputed credential pairs. If no matching user is found, the middleware sets the WWW-Authenticate header and responds with 401.

// AuthUserKey is the cookie name for user credential in basic auth.
const AuthUserKey = "user"

// Accounts defines a key/value for user/pass list of authorized logins.
type Accounts map[string]string

type authPair struct {
    value string
    user  string
}

type authPairs []authPair

func (a authPairs) searchCredential(authValue string) (string, bool) {
    if authValue == "" {
        return "", false
    }
    for _, pair := range a {
        if pair.value == authValue {
            return pair.user, true
        }
    }
    return "", false
}

func processAccounts(accounts Accounts) authPairs {
    assert1(len(accounts) > 0, "Empty list of authorized credentials")
    pairs := make(authPairs, 0, len(accounts))
    for user, password := range accounts {
        assert1(user != "", "User can not be empty")
        value := authorizationHeader(user, password)
        pairs = append(pairs, authPair{
            value: value,
            user:  user,
        })
    }
    return pairs
}

func authorizationHeader(user, password string) string {
    base := user + ":" + password
    return "Basic " + base64.StdEncoding.EncodeToString([]byte(base))
}

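The basic auth middleware is straightforward too, but reading the source makes the authentication flow clearer: each account is converted up front into the header value "Basic " + base64("user:password"), and each request's Authorization header is compared against that list.
For background, see MDN's guide to HTTP authentication.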
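A sketch of the same precompute-then-compare approach. authorizationHeader reproduces Gin's helper; credential, buildCredentials and search are hypothetical names. One deliberate deviation: search compares with crypto/subtle.ConstantTimeCompare instead of ==, a common hardening step for credential checks and not what the v1.5.0 code above does.

```go
package main

import (
	"crypto/subtle"
	"encoding/base64"
)

// authorizationHeader builds "Basic " + base64("user:password").
func authorizationHeader(user, password string) string {
	return "Basic " + base64.StdEncoding.EncodeToString([]byte(user+":"+password))
}

// credential pairs a precomputed header value with its user name.
type credential struct {
	value string
	user  string
}

// buildCredentials precomputes header values once, like processAccounts.
func buildCredentials(accounts map[string]string) []credential {
	creds := make([]credential, 0, len(accounts))
	for user, password := range accounts {
		creds = append(creds, credential{value: authorizationHeader(user, password), user: user})
	}
	return creds
}

// search looks up the incoming header value using a constant-time comparison.
func search(creds []credential, authValue string) (string, bool) {
	if authValue == "" {
		return "", false
	}
	for _, c := range creds {
		if subtle.ConstantTimeCompare([]byte(c.value), []byte(authValue)) == 1 {
			return c.user, true
		}
	}
	return "", false
}
```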

logger

logger implements the built-in request logger.

The logger is configurable; let's start with the data structures.

// LoggerConfig defines the config for Logger middleware.
type LoggerConfig struct {
    // Optional. Default value is gin.defaultLogFormatter
    Formatter LogFormatter

    // Output is a writer where logs are written.
    // Optional. Default value is gin.DefaultWriter.
    Output io.Writer

    // SkipPaths is a url path array which logs are not written.
    // Optional.
    SkipPaths []string
}

// LogFormatter gives the signature of the formatter function passed to LoggerWithFormatter
type LogFormatter func(params LogFormatterParams) string

// LogFormatterParams is the structure any formatter will be handed when time to log comes
type LogFormatterParams struct {
    Request *http.Request

    // TimeStamp shows the time after the server returns a response.
    TimeStamp time.Time
    // StatusCode is HTTP response code.
    StatusCode int
    // Latency is how much time the server cost to process a certain request.
    Latency time.Duration
    // ClientIP equals Context's ClientIP method.
    ClientIP string
    // Method is the HTTP method given to the request.
    Method string
    // Path is a path the client requests.
    Path string
    // ErrorMessage is set if error has occurred in processing the request.
    ErrorMessage string
    // isTerm shows whether does gin's output descriptor refers to a terminal.
    isTerm bool
    // BodySize is the size of the Response Body
    BodySize int
    // Keys are the keys set on the request's context.
    Keys map[string]interface{}
}

The isTerm field in LogFormatterParams records whether the output goes to a terminal; it is used to decide whether to print colors.
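A small sketch of how the status color and the color switch combine. The constants are copied from the listing below; statusColor and colorize are hypothetical helpers. When colors are disabled, defaultLogFormatter leaves its color variables empty, which is equivalent to colorize returning the text unchanged.

```go
package main

// ANSI escape codes copied from Gin's constants.
const (
	green  = "\033[97;42m"
	white  = "\033[90;47m"
	yellow = "\033[90;43m"
	red    = "\033[97;41m"
	reset  = "\033[0m"
)

// statusColor follows the same status-class switch as StatusCodeColor.
func statusColor(code int) string {
	switch {
	case code >= 200 && code < 300:
		return green
	case code >= 300 && code < 400:
		return white
	case code >= 400 && code < 500:
		return yellow
	default:
		return red
	}
}

// colorize wraps s in color codes only when colors are enabled.
func colorize(s, color string, enabled bool) string {
	if !enabled {
		return s
	}
	return color + s + reset
}
```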

const (
    green   = "\033[97;42m"
    white   = "\033[90;47m"
    yellow  = "\033[90;43m"
    red     = "\033[97;41m"
    blue    = "\033[97;44m"
    magenta = "\033[97;45m"
    cyan    = "\033[97;46m"
    reset   = "\033[0m"
)

var consoleColorMode = autoColor

// StatusCodeColor is the ANSI color for appropriately logging http status code to a terminal.
func (p *LogFormatterParams) StatusCodeColor() string {
    code := p.StatusCode

    switch {
    case code >= http.StatusOK && code < http.StatusMultipleChoices:
        return green
    case code >= http.StatusMultipleChoices && code < http.StatusBadRequest:
        return white
    case code >= http.StatusBadRequest && code < http.StatusInternalServerError:
        return yellow
    default:
        return red
    }
}

// MethodColor is the ANSI color for appropriately logging http method to a terminal.
func (p *LogFormatterParams) MethodColor() string {
    method := p.Method

    switch method {
    case "GET":
        return blue
    case "POST":
        return cyan
    case "PUT":
        return yellow
    case "DELETE":
        return red
    case "PATCH":
        return green
    case "HEAD":
        return magenta
    case "OPTIONS":
        return white
    default:
        return reset
    }
}

// ResetColor resets all escape attributes.
func (p *LogFormatterParams) ResetColor() string {
    return reset
}

// IsOutputColor indicates whether can colors be outputted to the log.
func (p *LogFormatterParams) IsOutputColor() bool {
    return consoleColorMode == forceColor || (consoleColorMode == autoColor && p.isTerm)
}

Now the middleware itself:

// Logger instances a Logger middleware that will write the logs to gin.DefaultWriter.
// By default gin.DefaultWriter = os.Stdout.
func Logger() HandlerFunc {
    return LoggerWithConfig(LoggerConfig{})
}

// LoggerWithFormatter instance a Logger middleware with the specified log format function.
func LoggerWithFormatter(f LogFormatter) HandlerFunc {
    return LoggerWithConfig(LoggerConfig{
        Formatter: f,
    })
}

// LoggerWithWriter instance a Logger middleware with the specified writer buffer.
// Example: os.Stdout, a file opened in write mode, a socket...
func LoggerWithWriter(out io.Writer, notlogged ...string) HandlerFunc {
    return LoggerWithConfig(LoggerConfig{
        Output:    out,
        SkipPaths: notlogged,
    })
}

// LoggerWithConfig instance a Logger middleware with config.
func LoggerWithConfig(conf LoggerConfig) HandlerFunc {
    formatter := conf.Formatter
    if formatter == nil {
        formatter = defaultLogFormatter
    }

    out := conf.Output
    if out == nil {
        out = DefaultWriter
    }

    notlogged := conf.SkipPaths

    isTerm := true

    if w, ok := out.(*os.File); !ok || os.Getenv("TERM") == "dumb" ||
        (!isatty.IsTerminal(w.Fd()) && !isatty.IsCygwinTerminal(w.Fd())) {
        isTerm = false
    }

    var skip map[string]struct{}

    if length := len(notlogged); length > 0 {
        skip = make(map[string]struct{}, length)

        for _, path := range notlogged {
            skip[path] = struct{}{}
        }
    }

    return func(c *Context) {
        // Start timer
        start := time.Now()
        path := c.Request.URL.Path
        raw := c.Request.URL.RawQuery

        // Process request
        c.Next()

        // Log only when path is not being skipped
        if _, ok := skip[path]; !ok {
            param := LogFormatterParams{
                Request: c.Request,
                isTerm:  isTerm,
                Keys:    c.Keys,
            }

            // Stop timer
            param.TimeStamp = time.Now()
            param.Latency = param.TimeStamp.Sub(start)

            param.ClientIP = c.ClientIP()
            param.Method = c.Request.Method
            param.StatusCode = c.Writer.Status()
            param.ErrorMessage = c.Errors.ByType(ErrorTypePrivate).String()

            param.BodySize = c.Writer.Size()

            if raw != "" {
                path = path + "?" + raw
            }

            param.Path = path

            fmt.Fprint(out, formatter(param))
        }
    }
}

Logger, LoggerWithFormatter and LoggerWithWriter are all thin wrappers around LoggerWithConfig.

In the middle of LoggerWithConfig there is a conversion: notlogged := conf.SkipPaths is a []string, but it is turned into a map during initialization.

var skip map[string]struct{}

if length := len(notlogged); length > 0 {
  skip = make(map[string]struct{}, length)

  for _, path := range notlogged {
    skip[path] = struct{}{}
  }
}

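This is because checking whether an element is present takes O(1) on average with a hash map, versus O(n) for scanning a slice. The per-request check is if _, ok := skip[path]; !ok {.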
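The conversion pattern in isolation. buildSkipSet and shouldLog are hypothetical names. When no paths are configured the map stays nil, which is safe because reading from a nil map in Go simply reports a miss.

```go
package main

// buildSkipSet converts the SkipPaths slice into a set once, at setup time.
func buildSkipSet(paths []string) map[string]struct{} {
	if len(paths) == 0 {
		return nil // lookups on a nil map are valid and always miss
	}
	skip := make(map[string]struct{}, len(paths))
	for _, p := range paths {
		skip[p] = struct{}{}
	}
	return skip
}

// shouldLog is the per-request check: one map lookup instead of a slice scan.
func shouldLog(skip map[string]struct{}, path string) bool {
	_, skipped := skip[path]
	return !skipped
}
```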

Finally, the most important statement is fmt.Fprint(out, formatter(param)), which writes the formatted log line to out.

The default formatter is defaultLogFormatter:

formatter := conf.Formatter
if formatter == nil {
  formatter = defaultLogFormatter
}

Here is defaultLogFormatter:

// defaultLogFormatter is the default log format function Logger middleware uses.
var defaultLogFormatter = func(param LogFormatterParams) string {
    var statusColor, methodColor, resetColor string
    if param.IsOutputColor() {
        statusColor = param.StatusCodeColor()
        methodColor = param.MethodColor()
        resetColor = param.ResetColor()
    }

    if param.Latency > time.Minute {
        // Truncate in a golang < 1.8 safe way
        param.Latency = param.Latency - param.Latency%time.Second
    }
    return fmt.Sprintf("[GIN] %v |%s %3d %s| %13v | %15s |%s %-7s %s %s\n%s",
        param.TimeStamp.Format("2006/01/02 - 15:04:05"),
        statusColor, param.StatusCode, resetColor,
        param.Latency,
        param.ClientIP,
        methodColor, param.Method, resetColor,
        param.Path,
        param.ErrorMessage,
    )
}

So to customize the log format, you supply a function with the signature func(param LogFormatterParams) string.

The official documentation gives this custom-format example:

func main() {
    router := gin.New()

    // LoggerWithFormatter middleware will write the logs to gin.DefaultWriter
    // By default gin.DefaultWriter = os.Stdout
    router.Use(gin.LoggerWithFormatter(func(param gin.LogFormatterParams) string {

        // your custom format
        return fmt.Sprintf("%s - [%s] \"%s %s %s %d %s \"%s\" %s\"\n",
                param.ClientIP,
                param.TimeStamp.Format(time.RFC1123),
                param.Method,
                param.Path,
                param.Request.Proto,
                param.StatusCode,
                param.Latency,
                param.Request.UserAgent(),
                param.ErrorMessage,
        )
    }))
    router.Use(gin.Recovery())

    router.GET("/ping", func(c *gin.Context) {
        c.String(200, "pong")
    })

    router.Run(":8080")
}

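The other point is how latency is measured: the timer starts at the top of the function with start := time.Now(), and after c.Next() has processed the request,
it stops with param.Latency = param.TimeStamp.Sub(start).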

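So if you want the full latency, register logger as the very first middleware.
If you want to exclude the time spent in other middleware and measure only the handler, it would have to be registered last.
In that second case, though, you are better off writing a dedicated timing middleware.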

errors

Here is how the error types are defined.

// ErrorType is an unsigned 64-bit error code as defined in the gin spec.
type ErrorType uint64

const (
    // ErrorTypeBind is used when Context.Bind() fails.
    ErrorTypeBind ErrorType = 1 << 63
    // ErrorTypeRender is used when Context.Render() fails.
    ErrorTypeRender ErrorType = 1 << 62
    // ErrorTypePrivate indicates a private error.
    ErrorTypePrivate ErrorType = 1 << 0
    // ErrorTypePublic indicates a public error.
    ErrorTypePublic ErrorType = 1 << 1
    // ErrorTypeAny indicates any other error.
    ErrorTypeAny ErrorType = 1<<64 - 1
    // ErrorTypeNu indicates any other error.
    ErrorTypeNu = 2
)

// Error represents a error's specification.
type Error struct {
    Err  error
    Type ErrorType
    Meta interface{}
}

type errorMsgs []*Error

The Error struct has three fields: the original error Err, the error type Type, and metadata Meta.

// SetType sets the error's type.
func (msg *Error) SetType(flags ErrorType) *Error {
    msg.Type = flags
    return msg
}

// SetMeta sets the error's meta data.
func (msg *Error) SetMeta(data interface{}) *Error {
    msg.Meta = data
    return msg
}

// JSON creates a properly formatted JSON
func (msg *Error) JSON() interface{} {
    json := H{}
    if msg.Meta != nil {
        value := reflect.ValueOf(msg.Meta)
        switch value.Kind() {
        case reflect.Struct:
            return msg.Meta
        case reflect.Map:
            for _, key := range value.MapKeys() {
                json[key.String()] = value.MapIndex(key).Interface()
            }
        default:
            json["meta"] = msg.Meta
        }
    }
    if _, ok := json["error"]; !ok {
        json["error"] = msg.Error()
    }
    return json
}

// MarshalJSON implements the json.Marshaller interface.
func (msg *Error) MarshalJSON() ([]byte, error) {
    return json.Marshal(msg.JSON())
}

// Error implements the error interface.
func (msg Error) Error() string {
    return msg.Err.Error()
}

The way error types are checked is a little unusual:

// IsType judges one error.
func (msg *Error) IsType(flags ErrorType) bool {
    return (msg.Type & flags) > 0
}

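This uses a bitwise & rather than ==. Is that faster? Speed is probably not the point: ErrorType is a bit set, so one error can carry several flags at once (for example ErrorTypeBind|ErrorTypePublic), a single mask can match any of several types, and ErrorTypeAny, with all 64 bits set, matches every error. A plain == could not express that.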
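A quick sketch of why a bitmask is handy here. The constants are copied from the Gin source above; isType is a standalone version of Error.IsType.

```go
package main

// ErrorType and its constants are copied from Gin v1.5.0.
type ErrorType uint64

const (
	ErrorTypeBind    ErrorType = 1 << 63
	ErrorTypeRender  ErrorType = 1 << 62
	ErrorTypePrivate ErrorType = 1 << 0
	ErrorTypePublic  ErrorType = 1 << 1
	ErrorTypeAny     ErrorType = 1<<64 - 1
)

// isType reports whether t shares at least one bit with flags.
func isType(t, flags ErrorType) bool {
	return t&flags > 0
}
```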

The rest of the file consists of methods on errorMsgs:

// ByType returns a readonly copy filtered the byte.
// ie ByType(gin.ErrorTypePublic) returns a slice of errors with type=ErrorTypePublic.
func (a errorMsgs) ByType(typ ErrorType) errorMsgs {
    if len(a) == 0 {
        return nil
    }
    if typ == ErrorTypeAny {
        return a
    }
    var result errorMsgs
    for _, msg := range a {
        if msg.IsType(typ) {
            result = append(result, msg)
        }
    }
    return result
}

// Last returns the last error in the slice. It returns nil if the array is empty.
// Shortcut for errors[len(errors)-1].
func (a errorMsgs) Last() *Error {
    if length := len(a); length > 0 {
        return a[length-1]
    }
    return nil
}

// Errors returns an array will all the error messages.
// Example:
//         c.Error(errors.New("first"))
//         c.Error(errors.New("second"))
//         c.Error(errors.New("third"))
//         c.Errors.Errors() // == []string{"first", "second", "third"}
func (a errorMsgs) Errors() []string {
    if len(a) == 0 {
        return nil
    }
    errorStrings := make([]string, len(a))
    for i, err := range a {
        errorStrings[i] = err.Error()
    }
    return errorStrings
}

func (a errorMsgs) JSON() interface{} {
    switch len(a) {
    case 0:
        return nil
    case 1:
        return a.Last().JSON()
    default:
        json := make([]interface{}, len(a))
        for i, err := range a {
            json[i] = err.JSON()
        }
        return json
    }
}

// MarshalJSON implements the json.Marshaller interface.
func (a errorMsgs) MarshalJSON() ([]byte, error) {
    return json.Marshal(a.JSON())
}

func (a errorMsgs) String() string {
    if len(a) == 0 {
        return ""
    }
    var buffer strings.Builder
    for i, msg := range a {
        fmt.Fprintf(&buffer, "Error #%02d: %s\n", i+1, msg.Err)
        if msg.Meta != nil {
            fmt.Fprintf(&buffer, "     Meta: %v\n", msg.Meta)
        }
    }
    return buffer.String()
}

Summary

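That's about it. Together with the previous articles, we have now covered most of Gin's source code.
For binding and render, only the JSON implementations were examined.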


帅气猫咪 published an article · 2019-12-08

03 Gin Source Code Walkthrough

Introduction

A walkthrough of the Gin source code, based on v1.5.0.

Context initialization

Context is a central part of Gin; let's start with what its doc comment says.

// Context is the most important part of gin. It allows us to pass variables between middleware,
// manage the flow, validate the JSON of a request and render a JSON response for example.
type Context struct {
    writermem responseWriter
    Request   *http.Request
    Writer    ResponseWriter

    Params   Params
    handlers HandlersChain
    index    int8
    fullPath string

    engine *Engine

    // Keys is a key/value pair exclusively for the context of each request.
    Keys map[string]interface{}

    // Errors is a list of errors attached to all the handlers/middlewares who used this context.
    Errors errorMsgs

    // Accepted defines a list of manually accepted formats for content negotiation.
    Accepted []string

    // queryCache use url.ParseQuery cached the param query result from c.Request.URL.Query()
    queryCache url.Values

    // formCache use url.ParseQuery cached PostForm contains the parsed form data from POST, PATCH,
    // or PUT body parameters.
    formCache url.Values
}

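As the comment says, Context is used to pass variables between middleware, manage the flow of control, validate a request's JSON, render JSON responses, and so on.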

A Context is prepared for every incoming request (taken from a pool and reset):

// ServeHTTP conforms to the http.Handler interface.
func (engine *Engine) ServeHTTP(w http.ResponseWriter, req *http.Request) {
    c := engine.pool.Get().(*Context)
    c.writermem.reset(w)
    c.Request = req
    c.reset()

    engine.handleHTTPRequest(c)

    engine.pool.Put(c)
}

This uses sync.Pool, which caches allocated but currently unused items for later reuse, reducing pressure on the garbage collector.
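The Get / reset / use / Put cycle in ServeHTTP can be sketched with the standard library alone. requestCtx is a hypothetical, stripped-down stand-in for gin.Context, and the allocations counter exists only to observe New; it is not safe for concurrent use.

```go
package main

import "sync"

// requestCtx is a hypothetical, stripped-down stand-in for gin.Context.
type requestCtx struct {
	path string
	keys map[string]interface{}
}

// reset clears per-request state before each use.
func (c *requestCtx) reset() {
	c.path = ""
	c.keys = nil
}

// allocations counts how often New had to create a fresh object.
var allocations int

var ctxPool = sync.Pool{
	New: func() interface{} {
		allocations++
		return &requestCtx{}
	},
}

// serve follows the Get / reset / use / Put sequence from Engine.ServeHTTP.
func serve(path string) string {
	c := ctxPool.Get().(*requestCtx)
	c.reset()
	c.path = path
	result := "handled " + c.path
	ctxPool.Put(c) // c must not be used after Put
	return result
}
```

The pool may drop idle objects at any time (for example during garbage collection), so callers must always reset what they Get and must not rely on reuse.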

type Pool struct {

    // New optionally specifies a function to generate
    // a value when Get would otherwise return nil.
    // It may not be changed concurrently with calls to Get.
    New func() interface{}
    // contains filtered or unexported fields
}

sync.Pool's optional New field holds a function that creates a value whenever Get would otherwise return nil; Gin assigns it when the Engine is created.

// New returns a new blank Engine instance without any middleware attached.
// By default the configuration is:
// - RedirectTrailingSlash:  true
// - RedirectFixedPath:      false
// - HandleMethodNotAllowed: false
// - ForwardedByClientIP:    true
// - UseRawPath:             false
// - UnescapePathValues:     true
func New() *Engine {
    debugPrintWARNINGNew()
    engine := &Engine{
        RouterGroup: RouterGroup{
            Handlers: nil,
            basePath: "/",
            root:     true,
        },
        FuncMap:                template.FuncMap{},
        RedirectTrailingSlash:  true,
        RedirectFixedPath:      false,
        HandleMethodNotAllowed: false,
        ForwardedByClientIP:    true,
        AppEngine:              defaultAppEngine,
        UseRawPath:             false,
        UnescapePathValues:     true,
        MaxMultipartMemory:     defaultMultipartMemory,
        trees:                  make(methodTrees, 0, 9),
        delims:                 render.Delims{Left: "{{", Right: "}}"},
        secureJsonPrefix:       "while(1);",
    }
    engine.RouterGroup.engine = engine
    engine.pool.New = func() interface{} {
        return engine.allocateContext()
    }
    return engine
}

func (engine *Engine) allocateContext() *Context {
    return &Context{engine: engine}
}

That is how a Context gets initialized.

Context: reading request parameters

Context carries a lot of responsibility: it is the sole argument of every handler function.

// HandlerFunc defines the handler used by gin middleware as return value.
type HandlerFunc func(*Context)

When exploring how middleware works, we already looked at flow control, namely the Context.Next() method:

// Next should be used only inside middleware.
// It executes the pending handlers in the chain inside the calling handler.
// See example in GitHub.
func (c *Context) Next() {
    c.index++
    for c.index < int8(len(c.handlers)) {
        c.handlers[c.index](c)
        c.index++
    }
}

Next, let's see how request parameters are read: parameters in the URL path, the query string of a GET request, or the data in a POST body.

// However, this one will match /user/john/ and also /user/john/send
// If no other routers match /user/john, it will redirect to /user/john/
router.GET("/user/:name/*action", func(c *gin.Context) {
  name := c.Param("name")
  action := c.Param("action")
  message := name + " is " + action
  c.String(http.StatusOK, message)
})

// Query string parameters are parsed using the existing underlying request object.
// The request responds to a url matching:  /welcome?firstname=Jane&lastname=Doe
router.GET("/welcome", func(c *gin.Context) {
  firstname := c.DefaultQuery("firstname", "Guest")
  lastname := c.Query("lastname") // shortcut for c.Request.URL.Query().Get("lastname")

  c.String(http.StatusOK, "Hello %s %s", firstname, lastname)
})

router.POST("/form_post", func(c *gin.Context) {
  message := c.PostForm("message")
  nick := c.DefaultPostForm("nick", "anonymous")

  c.JSON(200, gin.H{
    "status":  "posted",
    "message": message,
    "nick":    nick,
  })
})

router.POST("/post", func(c *gin.Context) {

  ids := c.QueryMap("ids")
  names := c.PostFormMap("names")

  fmt.Printf("ids: %v; names: %v", ids, names)
})

These examples come from the official documentation; let's look at how the methods they use are implemented.

func (c *Context) Param(key string) string {
    return c.Params.ByName(key)
}

func (c *Context) Query(key string) string {
    value, _ := c.GetQuery(key)
    return value
}

func (c *Context) DefaultQuery(key, defaultValue string) string {
    if value, ok := c.GetQuery(key); ok {
        return value
    }
    return defaultValue
}

func (c *Context) GetQuery(key string) (string, bool) {
    if values, ok := c.GetQueryArray(key); ok {
        return values[0], ok
    }
    return "", false
}

func (c *Context) getQueryCache() {
    if c.queryCache == nil {
        c.queryCache = c.Request.URL.Query()
    }
}

func (c *Context) GetQueryArray(key string) ([]string, bool) {
    c.getQueryCache()
    if values, ok := c.queryCache[key]; ok && len(values) > 0 {
        return values, true
    }
    return []string{}, false
}

func (c *Context) PostForm(key string) string {
    value, _ := c.GetPostForm(key)
    return value
}

func (c *Context) DefaultPostForm(key, defaultValue string) string {
    if value, ok := c.GetPostForm(key); ok {
        return value
    }
    return defaultValue
}

func (c *Context) GetPostForm(key string) (string, bool) {
    if values, ok := c.GetPostFormArray(key); ok {
        return values[0], ok
    }
    return "", false
}

func (c *Context) PostFormArray(key string) []string {
    values, _ := c.GetPostFormArray(key)
    return values
}

func (c *Context) getFormCache() {
    if c.formCache == nil {
        c.formCache = make(url.Values)
        req := c.Request
        if err := req.ParseMultipartForm(c.engine.MaxMultipartMemory); err != nil {
            if err != http.ErrNotMultipart {
                debugPrint("error on parse multipart form array: %v", err)
            }
        }
        c.formCache = req.PostForm
    }
}

func (c *Context) GetPostFormArray(key string) ([]string, bool) {
    c.getFormCache()
    if values := c.formCache[key]; len(values) > 0 {
        return values, true
    }
    return []string{}, false
}

GetQueryArray and GetPostFormArray are implemented very similarly: both parse lazily and cache the result on the Context (queryCache and formCache).

func (c *Context) QueryMap(key string) map[string]string {
    dicts, _ := c.GetQueryMap(key)
    return dicts
}

func (c *Context) GetQueryMap(key string) (map[string]string, bool) {
    c.getQueryCache()
    return c.get(c.queryCache, key)
}

func (c *Context) PostFormMap(key string) map[string]string {
    dicts, _ := c.GetPostFormMap(key)
    return dicts
}

func (c *Context) GetPostFormMap(key string) (map[string]string, bool) {
    c.getFormCache()
    return c.get(c.formCache, key)
}

// get is an internal method and returns a map which satisfy conditions.
func (c *Context) get(m map[string][]string, key string) (map[string]string, bool) {
    dicts := make(map[string]string)
    exist := false
    for k, v := range m {
        if i := strings.IndexByte(k, '['); i >= 1 && k[0:i] == key {
            if j := strings.IndexByte(k[i+1:], ']'); j >= 1 {
                exist = true
                dicts[k[i+1:][:j]] = v[0]
            }
        }
    }
    return dicts, exist
}

The code above turns bracketed parameters into maps. In the concrete request below, ids is a map with two keys (and names is another):

POST /post?ids[a]=1234&ids[b]=hello HTTP/1.1
Content-Type: application/x-www-form-urlencoded

names[first]=thinkerou&names[second]=tianou

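This convention is not defined by HTTP, so clients have to follow it explicitly. It may be handy in specific scenarios,
but it is rarely used in practice: for a public API, clients written in other languages would all have to implement this format too.
The parsing code itself is nothing special.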
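The bracket-parsing logic can be exercised independently. bracketMap is a hypothetical standalone copy of Context.get; it assumes the keys are already URL-decoded, as they would be after url.ParseQuery or Request.ParseForm.

```go
package main

import "strings"

// bracketMap collects keys of the form key[sub] into a map, keeping only
// the first value of each key, as Context.get does.
func bracketMap(m map[string][]string, key string) (map[string]string, bool) {
	dicts := make(map[string]string)
	exist := false
	for k, v := range m {
		if i := strings.IndexByte(k, '['); i >= 1 && k[0:i] == key {
			if j := strings.IndexByte(k[i+1:], ']'); j >= 1 {
				exist = true
				dicts[k[i+1:][:j]] = v[0]
			}
		}
	}
	return dicts, exist
}
```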

Next, file uploads, i.e. FormFile.

// FormFile returns the first file for the provided form key.
func (c *Context) FormFile(name string) (*multipart.FileHeader, error) {
    if c.Request.MultipartForm == nil {
        if err := c.Request.ParseMultipartForm(c.engine.MaxMultipartMemory); err != nil {
            return nil, err
        }
    }
    f, fh, err := c.Request.FormFile(name)
    if err != nil {
        return nil, err
    }
    f.Close()
    return fh, err
}

// MultipartForm is the parsed multipart form, including file uploads.
func (c *Context) MultipartForm() (*multipart.Form, error) {
    err := c.Request.ParseMultipartForm(c.engine.MaxMultipartMemory)
    return c.Request.MultipartForm, err
}

// SaveUploadedFile uploads the form file to specific dst.
func (c *Context) SaveUploadedFile(file *multipart.FileHeader, dst string) error {
    src, err := file.Open()
    if err != nil {
        return err
    }
    defer src.Close()

    out, err := os.Create(dst)
    if err != nil {
        return err
    }
    defer out.Close()

    _, err = io.Copy(out, src)
    return err
}

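These are thin wrappers around c.Request.MultipartForm that make the single-file case more convenient, and they add a method for saving an uploaded file as well.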
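The flow behind FormFile and SaveUploadedFile can be reproduced with net/http and mime/multipart. In this hypothetical sketch, newUploadRequest builds a multipart request and readUpload follows the same steps (parse with a 32 MiB limit, which I believe matches Gin's default MaxMultipartMemory, keep the *FileHeader, reopen it). It returns the contents instead of writing them to disk.

```go
package main

import (
	"bytes"
	"fmt"
	"io"
	"mime/multipart"
	"net/http"
	"net/http/httptest"
)

// maxMemory: parts beyond this size are spilled to temp files by net/http.
const maxMemory = 32 << 20

// newUploadRequest builds a multipart/form-data POST with one file field.
func newUploadRequest(field, filename, content string) (*http.Request, error) {
	var body bytes.Buffer
	mw := multipart.NewWriter(&body)
	fw, err := mw.CreateFormFile(field, filename)
	if err != nil {
		return nil, err
	}
	if _, err := io.WriteString(fw, content); err != nil {
		return nil, err
	}
	// Close writes the terminating boundary; the body is incomplete without it.
	if err := mw.Close(); err != nil {
		return nil, err
	}
	req := httptest.NewRequest(http.MethodPost, "/upload", &body)
	req.Header.Set("Content-Type", mw.FormDataContentType())
	return req, nil
}

// readUpload mirrors FormFile followed by FileHeader.Open.
func readUpload(req *http.Request, field string) (string, string, error) {
	if req.MultipartForm == nil {
		if err := req.ParseMultipartForm(maxMemory); err != nil {
			return "", "", err
		}
	}
	f, fh, err := req.FormFile(field)
	if err != nil {
		return "", "", err
	}
	f.Close() // like Gin, keep only the *FileHeader and reopen it later
	src, err := fh.Open()
	if err != nil {
		return "", "", err
	}
	defer src.Close()
	data, err := io.ReadAll(src)
	if err != nil {
		return "", "", err
	}
	return fh.Filename, string(data), nil
}

func main() {
	req, err := newUploadRequest("file", "hello.txt", "hello, gin")
	if err != nil {
		panic(err)
	}
	name, content, err := readUpload(req, "file")
	if err != nil {
		panic(err)
	}
	fmt.Println(name, content) // hello.txt hello, gin
}
```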

Context: Model Binding and Validation

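Model binding is a very useful capability, especially when combined with validation, because validation is one of the main concerns when handling request parameters.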

Gin supports two kinds of binding, Must bind and Should bind. Supported request formats include JSON, XML, YAML and standard form values.

Let's look at Must bind first:

// Bind checks the Content-Type to select a binding engine automatically,
// Depending the "Content-Type" header different bindings are used:
//     "application/json" --> JSON binding
//     "application/xml"  --> XML binding
// otherwise --> returns an error.
// It parses the request's body as JSON if Content-Type == "application/json" using JSON or XML as a JSON input.
// It decodes the json payload into the struct specified as a pointer.
// It writes a 400 error and sets Content-Type header "text/plain" in the response if input is not valid.
func (c *Context) Bind(obj interface{}) error {
    b := binding.Default(c.Request.Method, c.ContentType())
    return c.MustBindWith(obj, b)
}

// BindJSON is a shortcut for c.MustBindWith(obj, binding.JSON).
func (c *Context) BindJSON(obj interface{}) error {
    return c.MustBindWith(obj, binding.JSON)
}

// BindXML is a shortcut for c.MustBindWith(obj, binding.BindXML).
func (c *Context) BindXML(obj interface{}) error {
    return c.MustBindWith(obj, binding.XML)
}

// BindQuery is a shortcut for c.MustBindWith(obj, binding.Query).
func (c *Context) BindQuery(obj interface{}) error {
    return c.MustBindWith(obj, binding.Query)
}

// BindYAML is a shortcut for c.MustBindWith(obj, binding.YAML).
func (c *Context) BindYAML(obj interface{}) error {
    return c.MustBindWith(obj, binding.YAML)
}

// BindHeader is a shortcut for c.MustBindWith(obj, binding.Header).
func (c *Context) BindHeader(obj interface{}) error {
    return c.MustBindWith(obj, binding.Header)
}

// BindUri binds the passed struct pointer using binding.Uri.
// It will abort the request with HTTP 400 if any error occurs.
func (c *Context) BindUri(obj interface{}) error {
    if err := c.ShouldBindUri(obj); err != nil {
        c.AbortWithError(http.StatusBadRequest, err).SetType(ErrorTypeBind) // nolint: errcheck
        return err
    }
    return nil
}

// MustBindWith binds the passed struct pointer using the specified binding engine.
// It will abort the request with HTTP 400 if any error occurs.
// See the binding package.
func (c *Context) MustBindWith(obj interface{}, b binding.Binding) error {
    if err := c.ShouldBindWith(obj, b); err != nil {
        c.AbortWithError(http.StatusBadRequest, err).SetType(ErrorTypeBind) // nolint: errcheck
        return err
    }
    return nil
}

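The code above shows that MustBindWith is merely a wrapper around ShouldBindWith that aborts the request with HTTP 400 on error, so the real work is done in ShouldBindWith.
Also note that binding supports several data sources, such as BindQuery, BindHeader and BindUri.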

// ShouldBindWith binds the passed struct pointer using the specified binding engine.
// See the binding package.
func (c *Context) ShouldBindWith(obj interface{}, b binding.Binding) error {
    return b.Bind(c.Request, obj)
}

In fact, ShouldBindWith itself does nothing more than call the Bind method of the given binding.Binding.

// Binding describes the interface which needs to be implemented for binding the
// data present in the request such as JSON request body, query parameters or
// the form POST.
type Binding interface {
    Name() string
    Bind(*http.Request, interface{}) error
}

// BindingBody adds BindBody method to Binding. BindBody is similar with Bind,
// but it reads the body from supplied bytes instead of req.Body.
type BindingBody interface {
    Binding
    BindBody([]byte, interface{}) error
}

// BindingUri adds BindUri method to Binding. BindUri is similar with Bind,
// but it read the Params.
type BindingUri interface {
    Name() string
    BindUri(map[string][]string, interface{}) error
}

// These implement the Binding interface and can be used to bind the data
// present in the request to struct instances.
var (
    JSON          = jsonBinding{}
    XML           = xmlBinding{}
    Form          = formBinding{}
    Query         = queryBinding{}
    FormPost      = formPostBinding{}
    FormMultipart = formMultipartBinding{}
    ProtoBuf      = protobufBinding{}
    MsgPack       = msgpackBinding{}
    YAML          = yamlBinding{}
    Uri           = uriBinding{}
    Header        = headerBinding{}
)

The code above shows the Binding interface and the types that implement it. Taking JSON as an example, let's see how jsonBinding is implemented.

package binding

import (
    "bytes"
    "fmt"
    "io"
    "net/http"

    "github.com/gin-gonic/gin/internal/json"
)

// EnableDecoderUseNumber is used to call the UseNumber method on the JSON
// Decoder instance. UseNumber causes the Decoder to unmarshal a number into an
// interface{} as a Number instead of as a float64.
var EnableDecoderUseNumber = false

// EnableDecoderDisallowUnknownFields is used to call the DisallowUnknownFields method
// on the JSON Decoder instance. DisallowUnknownFields causes the Decoder to
// return an error when the destination is a struct and the input contains object
// keys which do not match any non-ignored, exported fields in the destination.
var EnableDecoderDisallowUnknownFields = false

type jsonBinding struct{}

func (jsonBinding) Name() string {
    return "json"
}

func (jsonBinding) Bind(req *http.Request, obj interface{}) error {
    if req == nil || req.Body == nil {
        return fmt.Errorf("invalid request")
    }
    return decodeJSON(req.Body, obj)
}

func (jsonBinding) BindBody(body []byte, obj interface{}) error {
    return decodeJSON(bytes.NewReader(body), obj)
}

func decodeJSON(r io.Reader, obj interface{}) error {
    decoder := json.NewDecoder(r)
    if EnableDecoderUseNumber {
        decoder.UseNumber()
    }
    if EnableDecoderDisallowUnknownFields {
        decoder.DisallowUnknownFields()
    }
    if err := decoder.Decode(obj); err != nil {
        return err
    }
    return validate(obj)
}

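The code is short. It imports Gin's internal json package instead of encoding/json directly, so the JSON codec can be swapped out (Gin switches to jsoniter with the jsoniter build tag).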

The last step of decoding is validation, done by the validate function:

func validate(obj interface{}) error {
    if Validator == nil {
        return nil
    }
    return Validator.ValidateStruct(obj)
}

This leads us to validation. Let's see how it is integrated.

// StructValidator is the minimal interface which needs to be implemented in
// order for it to be used as the validator engine for ensuring the correctness
// of the request. Gin provides a default implementation for this using
// https://github.com/go-playground/validator/tree/v8.18.2.
type StructValidator interface {
    // ValidateStruct can receive any kind of type and it should never panic, even if the configuration is not right.
    // If the received type is not a struct, any validation should be skipped and nil must be returned.
    // If the received type is a struct or pointer to a struct, the validation should be performed.
    // If the struct is not valid or the validation itself fails, a descriptive error should be returned.
    // Otherwise nil must be returned.
    ValidateStruct(interface{}) error

    // Engine returns the underlying validator engine which powers the
    // StructValidator implementation.
    Engine() interface{}
}

// Validator is the default validator which implements the StructValidator
// interface. It uses https://github.com/go-playground/validator/tree/v8.18.2
// under the hood.
var Validator StructValidator = &defaultValidator{}

A validator must implement the StructValidator interface. Here is the default implementation:

package binding

import (
    "reflect"
    "sync"

    "gopkg.in/go-playground/validator.v9"
)

type defaultValidator struct {
    once     sync.Once
    validate *validator.Validate
}

var _ StructValidator = &defaultValidator{}

// ValidateStruct receives any kind of type, but only performed struct or pointer to struct type.
func (v *defaultValidator) ValidateStruct(obj interface{}) error {
    value := reflect.ValueOf(obj)
    valueType := value.Kind()
    if valueType == reflect.Ptr {
        valueType = value.Elem().Kind()
    }
    if valueType == reflect.Struct {
        v.lazyinit()
        if err := v.validate.Struct(obj); err != nil {
            return err
        }
    }
    return nil
}

// Engine returns the underlying validator engine which powers the default
// Validator instance. This is useful if you want to register custom validations
// or struct level validations. See validator GoDoc for more info -
// https://godoc.org/gopkg.in/go-playground/validator.v8
func (v *defaultValidator) Engine() interface{} {
    v.lazyinit()
    return v.validate
}

func (v *defaultValidator) lazyinit() {
    v.once.Do(func() {
        v.validate = validator.New()
        v.validate.SetTagName("binding")
    })
}

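The default validator is go-playground/validator v9 (the doc comments above still link to v8, but the import is validator.v9). It is initialized lazily through sync.Once, reads its rules from the binding struct tag (SetTagName("binding")), and uses reflect to check the value's kind, so only structs and pointers to structs are validated.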
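The lazy-initialization pattern can be reproduced without go-playground/validator. In this hypothetical sketch, requiredChecker is a toy stand-in that understands only a "required" rule under the binding tag. defaultValidator keeps the sync.Once and the reflect kind check from the code above. The signup type is illustrative.

```go
package main

import (
	"fmt"
	"reflect"
	"sync"
)

// requiredChecker is a toy stand-in for validator.Validate.
type requiredChecker struct {
	tagName string
}

// Struct fails if any field tagged "required" still holds its zero value.
func (c *requiredChecker) Struct(obj interface{}) error {
	v := reflect.Indirect(reflect.ValueOf(obj))
	t := v.Type()
	for i := 0; i < t.NumField(); i++ {
		if t.Field(i).Tag.Get(c.tagName) == "required" && v.Field(i).IsZero() {
			return fmt.Errorf("field %s is required", t.Field(i).Name)
		}
	}
	return nil
}

type defaultValidator struct {
	once     sync.Once
	validate *requiredChecker
}

// lazyinit builds the engine on first use only, even under concurrency.
func (v *defaultValidator) lazyinit() {
	v.once.Do(func() {
		v.validate = &requiredChecker{tagName: "binding"}
	})
}

// ValidateStruct skips anything that is not a struct or pointer to struct.
func (v *defaultValidator) ValidateStruct(obj interface{}) error {
	value := reflect.ValueOf(obj)
	kind := value.Kind()
	if kind == reflect.Ptr {
		kind = value.Elem().Kind()
	}
	if kind != reflect.Struct {
		return nil
	}
	v.lazyinit()
	return v.validate.Struct(obj)
}

type signup struct {
	Email string `binding:"required"`
	Age   int
}

func main() {
	var v defaultValidator
	fmt.Println(v.ValidateStruct(&signup{}))            // field Email is required
	fmt.Println(v.ValidateStruct(signup{Email: "a@b"})) // <nil>
	fmt.Println(v.ValidateStruct([]int{1, 2}))          // <nil>
}
```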

Context: Responses

Having covered reading request parameters and model binding, let's see how responses are sent.

First, the responseWriter and ResponseWriter types used in Context:

type Context struct {
    writermem responseWriter
    Request   *http.Request
    Writer    ResponseWriter
    // ... other fields omitted
}

// ResponseWriter ...
type ResponseWriter interface {
    http.ResponseWriter
    http.Hijacker
    http.Flusher
    http.CloseNotifier

    // Returns the HTTP response status code of the current request.
    Status() int

    // Returns the number of bytes already written into the response http body.
    // See Written()
    Size() int

    // Writes the string into the response body.
    WriteString(string) (int, error)

    // Returns true if the response body was already written.
    Written() bool

    // Forces to write the http header (status code + headers).
    WriteHeaderNow()

    // get the http.Pusher for server push
    Pusher() http.Pusher
}

type responseWriter struct {
    http.ResponseWriter
    size   int
    status int
}

var _ ResponseWriter = &responseWriter{}

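The ResponseWriter interface combines the net/http interfaces involved in writing a response, and every method is documented.
responseWriter is the struct that implements the ResponseWriter interface.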

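Before moving on, let's figure out what writermem in Context is for.
Writer is used to write the response, and the mem suffix of writermem suggests it has something to do with memory.
Let's find where it is used.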

// ServeHTTP conforms to the http.Handler interface.
func (engine *Engine) ServeHTTP(w http.ResponseWriter, req *http.Request) {
    c := engine.pool.Get().(*Context)
    c.writermem.reset(w)
    c.Request = req
    c.reset()

    engine.handleHTTPRequest(c)

    engine.pool.Put(c)
}

func (c *Context) reset() {
    c.Writer = &c.writermem
    // ... the remaining fields are reset here
}

func (w *responseWriter) reset(writer http.ResponseWriter) {
    w.ResponseWriter = writer
    w.size = noWritten
    w.status = defaultStatus
}

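So writermem holds the per-request w http.ResponseWriter, and c.Writer is a pointer to it. Because writermem is a value field of the pooled Context, each request reuses it instead of allocating a new wrapper.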

Let's continue with how Context handles the response.

func (c *Context) requestHeader(key string) string {
    return c.Request.Header.Get(key)
}

// Status sets the HTTP response code.
func (c *Context) Status(code int) {
    c.Writer.WriteHeader(code)
}

// Header is a intelligent shortcut for c.Writer.Header().Set(key, value).
// It writes a header in the response.
// If value == "", this method removes the header `c.Writer.Header().Del(key)`
func (c *Context) Header(key, value string) {
    if value == "" {
        c.Writer.Header().Del(key)
        return
    }
    c.Writer.Header().Set(key, value)
}

// GetHeader returns value from request headers.
func (c *Context) GetHeader(key string) string {
    return c.requestHeader(key)
}

That is the header-related part; it simply proxies to Writer.Header(). Note that GetHeader reads a request header, not a response header.

Next, the cookie-related part:

// SetCookie adds a Set-Cookie header to the ResponseWriter's headers.
// The provided cookie must have a valid Name. Invalid cookies may be
// silently dropped.
func (c *Context) SetCookie(name, value string, maxAge int, path, domain string, secure, httpOnly bool) {
    if path == "" {
        path = "/"
    }
    http.SetCookie(c.Writer, &http.Cookie{
        Name:     name,
        Value:    url.QueryEscape(value),
        MaxAge:   maxAge,
        Path:     path,
        Domain:   domain,
        Secure:   secure,
        HttpOnly: httpOnly,
    })
}

// Cookie returns the named cookie provided in the request or
// ErrNoCookie if not found. And return the named cookie is unescaped.
// If multiple cookies match the given name, only one cookie will
// be returned.
func (c *Context) Cookie(name string) (string, error) {
    cookie, err := c.Request.Cookie(name)
    if err != nil {
        return "", err
    }
    val, _ := url.QueryUnescape(cookie.Value)
    return val, nil
}

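These methods wrap both setting and reading cookies. Values are escaped with url.QueryEscape when set and unescaped when read.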

After headers and cookies comes the main part: rendering content, i.e. the response body.

Gin supports XML, JSON, YAML and ProtoBuf rendering. Let's look at how this is implemented.

// Render writes the response headers and calls render.Render to render data.
func (c *Context) Render(code int, r render.Render) {
    c.Status(code)

    if !bodyAllowedForStatus(code) {
        r.WriteContentType(c.Writer)
        c.Writer.WriteHeaderNow()
        return
    }

    if err := r.Render(c.Writer); err != nil {
        panic(err)
    }
}

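The central method is Render, which calls the Render method of the render.Render interface. For status codes that do not allow a body, it only writes the headers.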

// HTML renders the HTTP template specified by its file name.
// It also updates the HTTP code and sets the Content-Type as "text/html".
// See http://golang.org/doc/articles/wiki/
func (c *Context) HTML(code int, name string, obj interface{}) {
    instance := c.engine.HTMLRender.Instance(name, obj)
    c.Render(code, instance)
}

// IndentedJSON serializes the given struct as pretty JSON (indented + endlines) into the response body.
// It also sets the Content-Type as "application/json".
// WARNING: we recommend to use this only for development purposes since printing pretty JSON is
// more CPU and bandwidth consuming. Use Context.JSON() instead.
func (c *Context) IndentedJSON(code int, obj interface{}) {
    c.Render(code, render.IndentedJSON{Data: obj})
}

// SecureJSON serializes the given struct as Secure JSON into the response body.
// Default prepends "while(1)," to response body if the given struct is array values.
// It also sets the Content-Type as "application/json".
func (c *Context) SecureJSON(code int, obj interface{}) {
    c.Render(code, render.SecureJSON{Prefix: c.engine.secureJsonPrefix, Data: obj})
}

// JSONP serializes the given struct as JSON into the response body.
// It add padding to response body to request data from a server residing in a different domain than the client.
// It also sets the Content-Type as "application/javascript".
func (c *Context) JSONP(code int, obj interface{}) {
    callback := c.DefaultQuery("callback", "")
    if callback == "" {
        c.Render(code, render.JSON{Data: obj})
        return
    }
    c.Render(code, render.JsonpJSON{Callback: callback, Data: obj})
}

// JSON serializes the given struct as JSON into the response body.
// It also sets the Content-Type as "application/json".
func (c *Context) JSON(code int, obj interface{}) {
    c.Render(code, render.JSON{Data: obj})
}

// AsciiJSON serializes the given struct as JSON into the response body with unicode to ASCII string.
// It also sets the Content-Type as "application/json".
func (c *Context) AsciiJSON(code int, obj interface{}) {
    c.Render(code, render.AsciiJSON{Data: obj})
}

// PureJSON serializes the given struct as JSON into the response body.
// PureJSON, unlike JSON, does not replace special html characters with their unicode entities.
func (c *Context) PureJSON(code int, obj interface{}) {
    c.Render(code, render.PureJSON{Data: obj})
}

// XML serializes the given struct as XML into the response body.
// It also sets the Content-Type as "application/xml".
func (c *Context) XML(code int, obj interface{}) {
    c.Render(code, render.XML{Data: obj})
}

// YAML serializes the given struct as YAML into the response body.
func (c *Context) YAML(code int, obj interface{}) {
    c.Render(code, render.YAML{Data: obj})
}

// ProtoBuf serializes the given struct as ProtoBuf into the response body.
func (c *Context) ProtoBuf(code int, obj interface{}) {
    c.Render(code, render.ProtoBuf{Data: obj})
}

// String writes the given string into the response body.
func (c *Context) String(code int, format string, values ...interface{}) {
    c.Render(code, render.String{Format: format, Data: values})
}

// Redirect returns a HTTP redirect to the specific location.
func (c *Context) Redirect(code int, location string) {
    c.Render(-1, render.Redirect{
        Code:     code,
        Location: location,
        Request:  c.Request,
    })
}

// Data writes some data into the body stream and updates the HTTP code.
func (c *Context) Data(code int, contentType string, data []byte) {
    c.Render(code, render.Data{
        ContentType: contentType,
        Data:        data,
    })
}

// DataFromReader writes the specified reader into the body stream and updates the HTTP code.
func (c *Context) DataFromReader(code int, contentLength int64, contentType string, reader io.Reader, extraHeaders map[string]string) {
    c.Render(code, render.Reader{
        Headers:       extraHeaders,
        ContentType:   contentType,
        ContentLength: contentLength,
        Reader:        reader,
    })
}

Let's look at the various implementers of the render.Render interface.

package render

import "net/http"

// Render interface is to be implemented by JSON, XML, HTML, YAML and so on.
type Render interface {
    // Render writes data with custom ContentType.
    Render(http.ResponseWriter) error
    // WriteContentType writes custom ContentType.
    WriteContentType(w http.ResponseWriter)
}

var (
    _ Render     = JSON{}
    _ Render     = IndentedJSON{}
    _ Render     = SecureJSON{}
    _ Render     = JsonpJSON{}
    _ Render     = XML{}
    _ Render     = String{}
    _ Render     = Redirect{}
    _ Render     = Data{}
    _ Render     = HTML{}
    _ HTMLRender = HTMLDebug{}
    _ HTMLRender = HTMLProduction{}
    _ Render     = YAML{}
    _ Render     = MsgPack{}
    _ Render     = Reader{}
    _ Render     = AsciiJSON{}
    _ Render     = ProtoBuf{}
)

func writeContentType(w http.ResponseWriter, value []string) {
    header := w.Header()
    if val := header["Content-Type"]; len(val) == 0 {
        header["Content-Type"] = value
    }
}

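Above is the definition of the Render interface; the main method to implement is Render.
WriteContentType is essentially already handled by the writeContentType function, which only sets Content-Type if it has not been set yet;
each renderer just passes its own Content-Type value.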

Taking JSON as an example, let's see how it is implemented.

// JSON contains the given interface object.
type JSON struct {
    Data interface{}
}

var jsonContentType = []string{"application/json; charset=utf-8"}

// Render (JSON) writes data with custom ContentType.
func (r JSON) Render(w http.ResponseWriter) (err error) {
    if err = WriteJSON(w, r.Data); err != nil {
        panic(err)
    }
    return
}

// WriteContentType (JSON) writes JSON ContentType.
func (r JSON) WriteContentType(w http.ResponseWriter) {
    writeContentType(w, jsonContentType)
}

// WriteJSON marshals the given interface object and writes it with custom ContentType.
func WriteJSON(w http.ResponseWriter, obj interface{}) error {
    writeContentType(w, jsonContentType)
    encoder := json.NewEncoder(w)
    err := encoder.Encode(&obj)
    return err
}

It is concise, and the implementation is not complicated.
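WriteJSON uses a json.Encoder with default settings, and encoding/json escapes <, > and & by default. That is the difference the PureJSON doc comment above refers to. I believe PureJSON turns this off with SetEscapeHTML(false), but its body is not shown here. The encode helper below is only a sketch of the two behaviours.

```go
package main

import (
	"bytes"
	"encoding/json"
	"fmt"
)

// encode writes v like WriteJSON (escapeHTML=true, the encoding/json
// default) or with HTML escaping turned off (escapeHTML=false).
func encode(v interface{}, escapeHTML bool) (string, error) {
	var buf bytes.Buffer
	enc := json.NewEncoder(&buf)
	enc.SetEscapeHTML(escapeHTML)
	if err := enc.Encode(v); err != nil { // Encode appends a trailing newline
		return "", err
	}
	return buf.String(), nil
}

func main() {
	data := map[string]string{"html": "<b>Tom & Jerry</b>"}
	j, _ := encode(data, true)
	p, _ := encode(data, false)
	fmt.Print(j) // {"html":"\u003cb\u003eTom \u0026 Jerry\u003c/b\u003e"}
	fmt.Print(p) // {"html":"<b>Tom & Jerry</b>"}
}
```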

Render and Binding follow a very similar design: define an interface, then implement each concrete format with its own struct.

Context: Advanced Responses

// File writes the specified file into the body stream in a efficient way.
func (c *Context) File(filepath string) {
    http.ServeFile(c.Writer, c.Request, filepath)
}

// FileAttachment writes the specified file into the body stream in an efficient way
// On the client side, the file will typically be downloaded with the given filename
func (c *Context) FileAttachment(filepath, filename string) {
    c.Writer.Header().Set("content-disposition", fmt.Sprintf("attachment; filename=\"%s\"", filename))
    http.ServeFile(c.Writer, c.Request, filepath)
}

Files are served with http.ServeFile. Attachment downloads are supported too, which is convenient,
although that amounts to nothing more than setting the Content-Disposition header.

// SSEvent writes a Server-Sent Event into the body stream.
func (c *Context) SSEvent(name string, message interface{}) {
    c.Render(-1, sse.Event{
        Event: name,
        Data:  message,
    })
}

SSEvent implements server-sent events. Let's look at the implementation.

package sse

import (
    "encoding/json"
    "fmt"
    "io"
    "net/http"
    "reflect"
    "strconv"
    "strings"
)

// Server-Sent Events
// W3C Working Draft 29 October 2009
// http://www.w3.org/TR/2009/WD-eventsource-20091029/

const ContentType = "text/event-stream"

var contentType = []string{ContentType}
var noCache = []string{"no-cache"}

var fieldReplacer = strings.NewReplacer(
    "\n", "\\n",
    "\r", "\\r")

var dataReplacer = strings.NewReplacer(
    "\n", "\ndata:",
    "\r", "\\r")

type Event struct {
    Event string
    Id    string
    Retry uint
    Data  interface{}
}

func Encode(writer io.Writer, event Event) error {
    w := checkWriter(writer)
    writeId(w, event.Id)
    writeEvent(w, event.Event)
    writeRetry(w, event.Retry)
    return writeData(w, event.Data)
}

func writeId(w stringWriter, id string) {
    if len(id) > 0 {
        w.WriteString("id:")
        fieldReplacer.WriteString(w, id)
        w.WriteString("\n")
    }
}

func writeEvent(w stringWriter, event string) {
    if len(event) > 0 {
        w.WriteString("event:")
        fieldReplacer.WriteString(w, event)
        w.WriteString("\n")
    }
}

func writeRetry(w stringWriter, retry uint) {
    if retry > 0 {
        w.WriteString("retry:")
        w.WriteString(strconv.FormatUint(uint64(retry), 10))
        w.WriteString("\n")
    }
}

func writeData(w stringWriter, data interface{}) error {
    w.WriteString("data:")
    switch kindOfData(data) {
    case reflect.Struct, reflect.Slice, reflect.Map:
        err := json.NewEncoder(w).Encode(data)
        if err != nil {
            return err
        }
        w.WriteString("\n")
    default:
        dataReplacer.WriteString(w, fmt.Sprint(data))
        w.WriteString("\n\n")
    }
    return nil
}

func (r Event) Render(w http.ResponseWriter) error {
    r.WriteContentType(w)
    return Encode(w, r)
}

func (r Event) WriteContentType(w http.ResponseWriter) {
    header := w.Header()
    header["Content-Type"] = contentType

    if _, exist := header["Cache-Control"]; !exist {
        header["Cache-Control"] = noCache
    }
}

func kindOfData(data interface{}) reflect.Kind {
    value := reflect.ValueOf(data)
    valueType := value.Kind()
    if valueType == reflect.Ptr {
        valueType = value.Elem().Kind()
    }
    return valueType
}

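SSE support is implemented as an extension: the code is not in Gin's own repository but in the github.com/gin-contrib/sse package. Let's start with the Event struct.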

type Event struct {
    Event string
    Id    string
    Retry uint
    Data  interface{}
}

func (r Event) Render(w http.ResponseWriter) error {
    r.WriteContentType(w)
    return Encode(w, r)
}

Event implements the Render interface. Let's look at the Encode function it calls.

func Encode(writer io.Writer, event Event) error {
    w := checkWriter(writer)
    writeId(w, event.Id)
    writeEvent(w, event.Event)
    writeRetry(w, event.Retry)
    return writeData(w, event.Data)
}

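The process is simple: four fields are written in order, the event ID, the event name (Event), the reconnection time (Retry) and the message body (Data).
If you are not familiar with server-sent events, see
MDN: Using server-sent events.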

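The event stream is a simple stream of text data, which must be encoded as UTF-8. Each message is followed by a blank line as a separator. Lines starting with a colon are comments and are ignored.
Note: comment lines can be used to prevent connections from timing out; a server can periodically send a comment to keep the connection alive.
Each message consists of one or more fields, each made up of a field name, a colon, and the field value.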

The format of the message body itself is not specified at all; that is for the frontend and backend to agree on.

func writeData(w stringWriter, data interface{}) error {
    w.WriteString("data:")
    switch kindOfData(data) {
    case reflect.Struct, reflect.Slice, reflect.Map:
        err := json.NewEncoder(w).Encode(data)
        if err != nil {
            return err
        }
        w.WriteString("\n")
    default:
        dataReplacer.WriteString(w, fmt.Sprint(data))
        w.WriteString("\n\n")
    }
    return nil
}

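This implementation JSON-encodes structs, slices and maps (json.Encoder already appends a newline, so one more "\n" ends the message). Any other value is written as plain text via fmt.Sprint, with embedded newlines turned into additional data: lines.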

Next, let's see how streaming responses are implemented.

// Stream sends a streaming response and returns a boolean
// indicates "Is client disconnected in middle of stream"
func (c *Context) Stream(step func(w io.Writer) bool) bool {
    w := c.Writer
    clientGone := w.CloseNotify()
    for {
        select {
        case <-clientGone:
            return true
        default:
            keepOpen := step(w)
            w.Flush()
            if !keepOpen {
                return false
            }
        }
    }
}

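This is a very common pattern: an endless for loop around a select whose default branch makes the channel check non-blocking. Each iteration first checks whether the client has disconnected (the CloseNotify channel); otherwise it calls step and flushes, and it returns once step reports false.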

Context: Content Negotiation

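Content negotiation uses the Accept family of request headers (Accept, Accept-Language and so on) to serve different representations of a resource to different clients,
for example choosing the page language or the response format. Gin's Negotiate only looks at Accept, i.e. the response format.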

For details, see MDN: Content negotiation.

// Negotiate contains all negotiations data.
type Negotiate struct {
    Offered  []string
    HTMLName string
    HTMLData interface{}
    JSONData interface{}
    XMLData  interface{}
    Data     interface{}
}

// Negotiate calls different Render according acceptable Accept format.
func (c *Context) Negotiate(code int, config Negotiate) {
    switch c.NegotiateFormat(config.Offered...) {
    case binding.MIMEJSON:
        data := chooseData(config.JSONData, config.Data)
        c.JSON(code, data)

    case binding.MIMEHTML:
        data := chooseData(config.HTMLData, config.Data)
        c.HTML(code, config.HTMLName, data)

    case binding.MIMEXML:
        data := chooseData(config.XMLData, config.Data)
        c.XML(code, data)

    default:
        c.AbortWithError(http.StatusNotAcceptable, errors.New("the accepted formats are not offered by the server")) // nolint: errcheck
    }
}

// NegotiateFormat returns an acceptable Accept format.
func (c *Context) NegotiateFormat(offered ...string) string {
    assert1(len(offered) > 0, "you must provide at least one offer")

    if c.Accepted == nil {
        c.Accepted = parseAccept(c.requestHeader("Accept"))
    }
    if len(c.Accepted) == 0 {
        return offered[0]
    }
    for _, accepted := range c.Accepted {
        for _, offert := range offered {
            // According to RFC 2616 and RFC 2396, non-ASCII characters are not allowed in headers,
            // therefore we can just iterate over the string without casting it into []rune
            i := 0
            for ; i < len(accepted); i++ {
                if accepted[i] == '*' || offert[i] == '*' {
                    return offert
                }
                if accepted[i] != offert[i] {
                    break
                }
            }
            if i == len(accepted) {
                return offert
            }
        }
    }
    return ""
}

// SetAccepted sets Accept header data.
func (c *Context) SetAccepted(formats ...string) {
    c.Accepted = formats
}
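The matching loop can be tried in isolation. parseAccept below is my reconstruction of the helper NegotiateFormat calls, whose body is not shown in this section. It assumes the helper splits on commas and drops parameters such as ;q=0.9, so q-values are ignored and header order wins. negotiate is a hypothetical wrapper around the same loop. It also bounds i by len(offer), because offert[i] in the code above would index past the end of an offer that is a strict prefix of the accepted value.

```go
package main

import (
	"fmt"
	"strings"
)

// parseAccept splits an Accept header into media ranges without parameters.
func parseAccept(header string) []string {
	parts := strings.Split(header, ",")
	out := make([]string, 0, len(parts))
	for _, part := range parts {
		if i := strings.IndexByte(part, ';'); i >= 0 {
			part = part[:i] // drop ";q=0.9" and other parameters
		}
		if part = strings.TrimSpace(part); part != "" {
			out = append(out, part)
		}
	}
	return out
}

// negotiate returns the first offer matching an accepted range, comparing
// byte by byte and treating '*' on either side as "matches the rest".
// The caller must pass at least one offer.
func negotiate(acceptHeader string, offered ...string) string {
	accepted := parseAccept(acceptHeader)
	if len(accepted) == 0 {
		return offered[0] // no Accept header: use the server's first choice
	}
	for _, acc := range accepted {
		for _, offer := range offered {
			i := 0
			for ; i < len(acc) && i < len(offer); i++ {
				if acc[i] == '*' || offer[i] == '*' {
					return offer
				}
				if acc[i] != offer[i] {
					break
				}
			}
			if i == len(acc) {
				return offer
			}
		}
	}
	return ""
}

func main() {
	fmt.Println(negotiate("text/html, application/json;q=0.9", "application/json", "text/html")) // text/html
	fmt.Println(negotiate("*/*", "application/json", "text/html"))                             // application/json
	fmt.Println(negotiate("application/xml", "application/json") == "")                        // true
}
```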

Summary

That's all for Context. The source file is a bit long, but with its comments it is quite easy to follow.


帅气猫咪 published an article · 2019-12-04

02 Gin Source Code Walkthrough

Introduction

A walkthrough of the Gin source code, based on v1.5.0.

HttpRouter Implementation

Routes are added mainly by addRoute:

func (engine *Engine) addRoute(method, path string, handlers HandlersChain) {
    assert1(path[0] == '/', "path must begin with '/'")
    assert1(method != "", "HTTP method can not be empty")
    assert1(len(handlers) > 0, "there must be at least one handler")

    debugPrintRoute(method, path, handlers)
    root := engine.trees.get(method)
    if root == nil {
        root = new(node)
        root.fullPath = "/"
        engine.trees = append(engine.trees, methodTree{method: method, root: root})
    }
    root.addRoute(path, handlers)
}

Gin's router is based on httprouter (Gin's tree.go is adapted from it), so let's dig into that code.

Data Structures

The httprouter README on GitHub explains how it works; see its How does it work? section.

HttpRouter uses a radix tree internally, a compact variant of the prefix tree (trie).

[Figure: structure of a radix tree]

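The figure above, from Wikipedia, shows the structure of a radix tree. Compared with an ordinary trie, each edge of a radix tree can hold several characters, which greatly reduces the depth of the tree.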
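To make the "several characters per edge" idea concrete, here is a hypothetical, heavily simplified radix tree for static paths only. It has no :param or *catchAll segments, no priorities or handler chains, and uses an empty-label root. The names rnode, insert and lookup are my own. The fields path, indices and children are modeled on the node struct below, and the edge split mirrors the one in addRoute.

```go
package main

import "fmt"

// rnode is a hypothetical radix-tree node for static paths only.
type rnode struct {
	path     string // edge label; may hold several characters
	indices  string // first byte of each child's path, in children order
	children []*rnode
	value    string // non-empty when a route ends at this node
}

// lcp returns the length of the longest common prefix of a and b.
func lcp(a, b string) int {
	i := 0
	for i < len(a) && i < len(b) && a[i] == b[i] {
		i++
	}
	return i
}

// insert adds path, splitting an edge when it only partly matches.
func (n *rnode) insert(path, value string) {
walk:
	for {
		i := lcp(path, n.path)
		if i < len(n.path) {
			// Split edge: move the unmatched tail of n.path into a new child.
			child := &rnode{path: n.path[i:], indices: n.indices, children: n.children, value: n.value}
			n.path = n.path[:i]
			n.indices = string([]byte{child.path[0]})
			n.children = []*rnode{child}
			n.value = ""
		}
		if i == len(path) {
			n.value = value // the route ends exactly at this node
			return
		}
		path = path[i:]
		for k := 0; k < len(n.indices); k++ {
			if n.indices[k] == path[0] {
				n = n.children[k]
				continue walk
			}
		}
		// No child starts with the next byte: add a new leaf edge.
		n.indices += string([]byte{path[0]})
		n.children = append(n.children, &rnode{path: path, value: value})
		return
	}
}

// lookup walks the edges, consuming a whole label per step.
func (n *rnode) lookup(path string) (string, bool) {
walk:
	for {
		if len(path) < len(n.path) || path[:len(n.path)] != n.path {
			return "", false
		}
		path = path[len(n.path):]
		if path == "" {
			return n.value, n.value != ""
		}
		for k := 0; k < len(n.indices); k++ {
			if n.indices[k] == path[0] {
				n = n.children[k]
				continue walk
			}
		}
		return "", false
	}
}

func main() {
	root := &rnode{}
	for _, p := range []string{"/search", "/support", "/"} {
		root.insert(p, "handler for "+p)
	}
	for _, p := range []string{"/search", "/support", "/", "/s", "/blog"} {
		v, ok := root.lookup(p)
		fmt.Printf("%-9s -> %q %v\n", p, v, ok)
	}
}
```

Inserting /search, /support and / produces the edges "/", "s", "earch" and "upport", so the shared prefix is stored only once. Unlike Gin, inserting a duplicate path here silently overwrites the value instead of panicking.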

Here are the data structure definitions:

// Param is a single URL parameter, consisting of a key and a value.
type Param struct {
    Key   string
    Value string
}

// Params is a Param-slice, as returned by the router.
// The slice is ordered, the first URL parameter is also the first slice value.
// It is therefore safe to read values by the index.
type Params []Param

type methodTree struct {
    method string
    root   *node
}

type methodTrees []methodTree

Engine.trees has type methodTrees and is initialized with trees: make(methodTrees, 0, 9), presumably leaving room for the nine standard HTTP methods.

func (trees methodTrees) get(method string) *node {
    for _, tree := range trees {
        if tree.method == method {
            return tree.root
        }
    }
    return nil
}

The first step of addRoute above is to find the root, i.e. root := engine.trees.get(method). Together with the code of get,
this shows that methodTrees is grouped by HTTP method: each method has its own tree.

If no tree exists yet for that HTTP method, a new methodTree is created:

root = new(node)
root.fullPath = "/"
engine.trees = append(engine.trees, methodTree{method: method, root: root})

Next, how a tree node is defined:

type nodeType uint8

const (
    static nodeType = iota // default
    root
    param
    catchAll
)

type node struct {
    path      string
    indices   string
    children  []*node
    handlers  HandlersChain
    priority  uint32
    nType     nodeType
    maxParams uint8
    wildChild bool
    fullPath  string
}

Adding Routes

With the data structures covered, let's see how a route is actually added, i.e. root.addRoute(path, handlers).

// addRoute adds a node with the given handle to the path.
// Not concurrency-safe!
func (n *node) addRoute(path string, handlers HandlersChain) {
    fullPath := path
    n.priority++
    numParams := countParams(path)

    parentFullPathIndex := 0

    // non-empty tree
    if len(n.path) > 0 || len(n.children) > 0 {
    walk:
        for {
            // Update maxParams of the current node
            if numParams > n.maxParams {
                n.maxParams = numParams
            }

            // Find the longest common prefix.
            // This also implies that the common prefix contains no ':' or '*'
            // since the existing key can't contain those chars.
            i := 0
            max := min(len(path), len(n.path))
            for i < max && path[i] == n.path[i] {
                i++
            }

            // Split edge
            if i < len(n.path) {
                child := node{
                    path:      n.path[i:],
                    wildChild: n.wildChild,
                    indices:   n.indices,
                    children:  n.children,
                    handlers:  n.handlers,
                    priority:  n.priority - 1,
                    fullPath:  n.fullPath,
                }

                // Update maxParams (max of all children)
                for i := range child.children {
                    if child.children[i].maxParams > child.maxParams {
                        child.maxParams = child.children[i].maxParams
                    }
                }

                n.children = []*node{&child}
                // []byte for proper unicode char conversion, see #65
                n.indices = string([]byte{n.path[i]})
                n.path = path[:i]
                n.handlers = nil
                n.wildChild = false
                n.fullPath = fullPath[:parentFullPathIndex+i]
            }

            // Make new node a child of this node
            if i < len(path) {
                path = path[i:]

                if n.wildChild {
                    parentFullPathIndex += len(n.path)
                    n = n.children[0]
                    n.priority++

                    // Update maxParams of the child node
                    if numParams > n.maxParams {
                        n.maxParams = numParams
                    }
                    numParams--

                    // Check if the wildcard matches
                    if len(path) >= len(n.path) && n.path == path[:len(n.path)] {
                        // check for longer wildcard, e.g. :name and :names
                        if len(n.path) >= len(path) || path[len(n.path)] == '/' {
                            continue walk
                        }
                    }

                    pathSeg := path
                    if n.nType != catchAll {
                        pathSeg = strings.SplitN(path, "/", 2)[0]
                    }
                    prefix := fullPath[:strings.Index(fullPath, pathSeg)] + n.path
                    panic("'" + pathSeg +
                        "' in new path '" + fullPath +
                        "' conflicts with existing wildcard '" + n.path +
                        "' in existing prefix '" + prefix +
                        "'")
                }

                c := path[0]

                // slash after param
                if n.nType == param && c == '/' && len(n.children) == 1 {
                    parentFullPathIndex += len(n.path)
                    n = n.children[0]
                    n.priority++
                    continue walk
                }

                // Check if a child with the next path byte exists
                for i := 0; i < len(n.indices); i++ {
                    if c == n.indices[i] {
                        parentFullPathIndex += len(n.path)
                        i = n.incrementChildPrio(i)
                        n = n.children[i]
                        continue walk
                    }
                }

                // Otherwise insert it
                if c != ':' && c != '*' {
                    // []byte for proper unicode char conversion, see #65
                    n.indices += string([]byte{c})
                    child := &node{
                        maxParams: numParams,
                        fullPath:  fullPath,
                    }
                    n.children = append(n.children, child)
                    n.incrementChildPrio(len(n.indices) - 1)
                    n = child
                }
                n.insertChild(numParams, path, fullPath, handlers)
                return

            } else if i == len(path) { // Make node a (in-path) leaf
                if n.handlers != nil {
                    panic("handlers are already registered for path '" + fullPath + "'")
                }
                n.handlers = handlers
            }
            return
        }
    } else { // Empty tree
        n.insertChild(numParams, path, fullPath, handlers)
        n.nType = root
    }
}

addRoute

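The code is somewhat long. The top-level if splits it into two cases: initialization, when the tree is still empty, and the case where the tree is non-empty.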

n.insertChild(numParams, path, fullPath, handlers)
n.nType = root

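The tree is empty when n.path is still the empty string (its initial value) and n.children is an empty slice.
In that case the node is simply filled in through insertChild, and its type is then set to root.
insertChild is also fairly long; we will come back to it later.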

When the tree is non-empty, the code enters a for loop. Following the comments, here is roughly what the loop does.

// Update maxParams of the current node
if numParams > n.maxParams {
  n.maxParams = numParams
}

// Find the longest common prefix.
// This also implies that the common prefix contains no ':' or '*'
// since the existing key can't contain those chars.
i := 0
max := min(len(path), len(n.path))
for i < max && path[i] == n.path[i] {
  i++
}

// Split edge

// Make new node a child of this node

The first two steps are straightforward: updating maxParams and computing the length of the longest common prefix. The code speaks for itself.
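To make the prefix step concrete, here is the same loop pulled out into a standalone function. This is a minimal sketch. The name commonPrefixLen is hypothetical, because Gin inlines this loop inside addRoute rather than exposing a helper.

```go
package main

import "fmt"

// commonPrefixLen returns the number of leading bytes shared by a and b.
// It mirrors the loop in addRoute that computes i (hypothetical helper).
func commonPrefixLen(a, b string) int {
	n := len(a)
	if len(b) < n {
		n = len(b)
	}
	i := 0
	for i < n && a[i] == b[i] {
		i++
	}
	return i
}

func main() {
	fmt.Println(commonPrefixLen("/ping", "/pong"))  // 2
	fmt.Println(commonPrefixLen("/user", "/users")) // 5
}
```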

Now let's see how a node is split, which is the third step:

// Split edge
if i < len(n.path) {
  child := node{
    path:      n.path[i:],
    wildChild: n.wildChild,
    indices:   n.indices,
    children:  n.children,
    handlers:  n.handlers,
    priority:  n.priority - 1,
    fullPath:  n.fullPath,
  }

  // Update maxParams (max of all children)
  for i := range child.children {
    if child.children[i].maxParams > child.maxParams {
      child.maxParams = child.children[i].maxParams
    }
  }

  n.children = []*node{&child}
  // []byte for proper unicode char conversion, see #65
  n.indices = string([]byte{n.path[i]})
  n.path = path[:i]
  n.handlers = nil
  n.wildChild = false
  n.fullPath = fullPath[:parentFullPathIndex+i]
}

When the common prefix is shorter than n.path, the current node splits off a child node.

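For example, suppose the current node has node.path = "/ping" and the new path is "/pong". The node splits.
The common prefix has length i=2, so the node becomes node.path = "/p" with a child whose node.path = "ing".
The split-off child takes over most of the current node's attributes: children, indices, handlers, wildChild and fullPath.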

Next comes the fourth step: adding a child to the current node. This is the core of root.addRoute(path, handlers).

This is again an if statement. Let's look at its second half first, the case where an error can occur:

else if i == len(path) { // Make node a (in-path) leaf
  if n.handlers != nil {
    panic("handlers are already registered for path '" + fullPath + "'")
  }
  n.handlers = handlers
}

If n.handlers is already non-nil, the code panics. In other words, handlers can be registered for a given path only once.

Now the first half, the case if i < len(path):

// Make new node a child of this node
if i < len(path) {
  path = path[i:]

  if n.wildChild {
    parentFullPathIndex += len(n.path)
    n = n.children[0]
    n.priority++

    // Update maxParams of the child node
    if numParams > n.maxParams {
      n.maxParams = numParams
    }
    numParams--

    // Check if the wildcard matches
    if len(path) >= len(n.path) && n.path == path[:len(n.path)] {
      // check for longer wildcard, e.g. :name and :names
      if len(n.path) >= len(path) || path[len(n.path)] == '/' {
        continue walk
      }
    }

    pathSeg := path
    if n.nType != catchAll {
      pathSeg = strings.SplitN(path, "/", 2)[0]
    }
    prefix := fullPath[:strings.Index(fullPath, pathSeg)] + n.path
    panic("'" + pathSeg +
      "' in new path '" + fullPath +
      "' conflicts with existing wildcard '" + n.path +
      "' in existing prefix '" + prefix +
      "'")
  }

  c := path[0]

  // slash after param
  if n.nType == param && c == '/' && len(n.children) == 1 {
    parentFullPathIndex += len(n.path)
    n = n.children[0]
    n.priority++
    continue walk
  }

  // Check if a child with the next path byte exists
  for i := 0; i < len(n.indices); i++ {
    if c == n.indices[i] {
      parentFullPathIndex += len(n.path)
      i = n.incrementChildPrio(i)
      n = n.children[i]
      continue walk
    }
  }

  // Otherwise insert it
  if c != ':' && c != '*' {
    // []byte for proper unicode char conversion, see #65
    n.indices += string([]byte{c})
    child := &node{
      maxParams: numParams,
      fullPath:  fullPath,
    }
    n.children = append(n.children, child)
    n.incrementChildPrio(len(n.indices) - 1)
    n = child
  }
  n.insertChild(numParams, path, fullPath, handlers)
  return

}

This part is long as well, so let's break it down step by step.

First, path = path[i:] shows that path has already been stripped of the common prefix.

The first check is if n.wildChild, which means the current node has a wildcard child:

if n.wildChild {
  parentFullPathIndex += len(n.path)
  n = n.children[0]
  n.priority++

  // Update maxParams of the child node
  if numParams > n.maxParams {
    n.maxParams = numParams
  }
  numParams--

  // Check if the wildcard matches
  if len(path) >= len(n.path) && n.path == path[:len(n.path)] {
    // check for longer wildcard, e.g. :name and :names
    if len(n.path) >= len(path) || path[len(n.path)] == '/' {
      continue walk
    }
  }

  pathSeg := path
  if n.nType != catchAll {
    pathSeg = strings.SplitN(path, "/", 2)[0]
  }
  prefix := fullPath[:strings.Index(fullPath, pathSeg)] + n.path
  panic("'" + pathSeg +
    "' in new path '" + fullPath +
    "' conflicts with existing wildcard '" + n.path +
    "' in existing prefix '" + prefix +
    "'")
}

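In this wildcard branch, a wildcard conflict panic is usually raised. There is one exception: the remaining path starts with exactly the same wildcard, and that wildcard is followed either by the end of the path or by a /. In that case the walk continues into the wildcard child.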
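The two nested checks can be read as a single predicate. This is a minimal sketch using the hypothetical name wildcardCompatible. Here wild stands for the existing wildcard child's path and path for the remaining new path. Wherever this returns false, addRoute panics with the conflict message shown above.

```go
package main

import "fmt"

// wildcardCompatible reports whether the remaining path can reuse an
// existing wildcard child whose path is wild (e.g. ":name"). The new path
// must start with exactly the same wildcard, and that wildcard must end
// there: either at the end of the path or at a '/'.
func wildcardCompatible(wild, path string) bool {
	if len(path) >= len(wild) && wild == path[:len(wild)] {
		// rejects longer names, e.g. existing ":name" vs new ":names"
		if len(wild) >= len(path) || path[len(wild)] == '/' {
			return true
		}
	}
	return false
}

func main() {
	fmt.Println(wildcardCompatible(":name", ":name"))         // true
	fmt.Println(wildcardCompatible(":name", ":name/profile")) // true
	fmt.Println(wildcardCompatible(":name", ":names"))        // false
	fmt.Println(wildcardCompatible(":name", ":id"))           // false
}
```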

c := path[0]

// slash after param
if n.nType == param && c == '/' && len(n.children) == 1 {
  parentFullPathIndex += len(n.path)
  n = n.children[0]
  n.priority++
  continue walk
}

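Suppose the current node is a : param node, the remaining path starts with /, and the node has exactly one child. Then the walk moves into that child and starts a new iteration of the loop.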

// Check if a child with the next path byte exists
for i := 0; i < len(n.indices); i++ {
  if c == n.indices[i] {
    parentFullPathIndex += len(n.path)
    i = n.incrementChildPrio(i)
    n = n.children[i]
    continue walk
  }
}

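This checks whether a child node starting with the next byte c exists. If it does, the code jumps to that child and starts a new iteration of the loop.
The lookup relies on n.indices. Recall that the edge split earlier set n.indices = string([]byte{n.path[i]}).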

// Otherwise insert it
if c != ':' && c != '*' {
  // []byte for proper unicode char conversion, see #65
  n.indices += string([]byte{c})
  child := &node{
    maxParams: numParams,
    fullPath:  fullPath,
  }
  n.children = append(n.children, child)
  n.incrementChildPrio(len(n.indices) - 1)
  n = child
}

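Suppose none of the previous checks matched and c is neither : nor *. Then a new child node is created, its first byte is appended to n.indices, and the current node is replaced by this child.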

n.insertChild(numParams, path, fullPath, handlers)
return

Finally, insertChild is called once more, and return exits the loop and ends the whole method.

insertChild

n.insertChild(numParams, path, fullPath, handlers) is called in two places above. Let's look at its implementation.

func (n *node) insertChild(numParams uint8, path string, fullPath string, handlers HandlersChain) {
    var offset int // already handled bytes of the path

    // find prefix until first wildcard (beginning with ':' or '*')
    for i, max := 0, len(path); numParams > 0; i++ {
        c := path[i]
        if c != ':' && c != '*' {
            continue
        }

        // find wildcard end (either '/' or path end)
        end := i + 1
        for end < max && path[end] != '/' {
            switch path[end] {
            // the wildcard name must not contain ':' and '*'
            case ':', '*':
                panic("only one wildcard per path segment is allowed, has: '" +
                    path[i:] + "' in path '" + fullPath + "'")
            default:
                end++
            }
        }

        // check if this Node existing children which would be
        // unreachable if we insert the wildcard here
        if len(n.children) > 0 {
            panic("wildcard route '" + path[i:end] +
                "' conflicts with existing children in path '" + fullPath + "'")
        }

        // check if the wildcard has a name
        if end-i < 2 {
            panic("wildcards must be named with a non-empty name in path '" + fullPath + "'")
        }

        if c == ':' { // param
            // split path at the beginning of the wildcard
            if i > 0 {
                n.path = path[offset:i]
                offset = i
            }

            child := &node{
                nType:     param,
                maxParams: numParams,
                fullPath:  fullPath,
            }
            n.children = []*node{child}
            n.wildChild = true
            n = child
            n.priority++
            numParams--

            // if the path doesn't end with the wildcard, then there
            // will be another non-wildcard subpath starting with '/'
            if end < max {
                n.path = path[offset:end]
                offset = end

                child := &node{
                    maxParams: numParams,
                    priority:  1,
                    fullPath:  fullPath,
                }
                n.children = []*node{child}
                n = child
            }

        } else { // catchAll
            if end != max || numParams > 1 {
                panic("catch-all routes are only allowed at the end of the path in path '" + fullPath + "'")
            }

            if len(n.path) > 0 && n.path[len(n.path)-1] == '/' {
                panic("catch-all conflicts with existing handle for the path segment root in path '" + fullPath + "'")
            }

            // currently fixed width 1 for '/'
            i--
            if path[i] != '/' {
                panic("no / before catch-all in path '" + fullPath + "'")
            }

            n.path = path[offset:i]

            // first node: catchAll node with empty path
            child := &node{
                wildChild: true,
                nType:     catchAll,
                maxParams: 1,
                fullPath:  fullPath,
            }
            n.children = []*node{child}
            n.indices = string(path[i])
            n = child
            n.priority++

            // second node: node holding the variable
            child = &node{
                path:      path[i:],
                nType:     catchAll,
                maxParams: 1,
                handlers:  handlers,
                priority:  1,
                fullPath:  fullPath,
            }
            n.children = []*node{child}

            return
        }
    }

    // insert remaining path part and handle to the leaf
    n.path = path[offset:]
    n.handlers = handlers
    n.fullPath = fullPath
}

With the code folded, it has two main parts: a for loop, followed by a few statements that update the node's fields.

// insert remaining path part and handle to the leaf
n.path = path[offset:]
n.handlers = handlers
n.fullPath = fullPath

Let's focus on the for loop:

// find prefix until first wildcard (beginning with ':' or '*')
for i, max := 0, len(path); numParams > 0; i++ {
  c := path[i]
  if c != ':' && c != '*' {
    continue
  }

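As the comment says, these lines skip ahead until a wildcard character, ':' or '*', is found. Only then does the real processing begin.
Note that the loop condition is numParams, the number of wildcard parameters in the path. The loop ends once all of them have been handled.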

// find wildcard end (either '/' or path end)
end := i + 1
for end < max && path[end] != '/' {
  switch path[end] {
  // the wildcard name must not contain ':' and '*'
  case ':', '*':
    panic("only one wildcard per path segment is allowed, has: '" +
      path[i:] + "' in path '" + fullPath + "'")
  default:
    end++
  }
}

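This loop finds where the wildcard ends, at the next '/' or at the end of the path. It panics if another ':' or '*' appears before that, because only one wildcard is allowed per path segment.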

// check if this Node existing children which would be
// unreachable if we insert the wildcard here
if len(n.children) > 0 {
  panic("wildcard route '" + path[i:end] +
    "' conflicts with existing children in path '" + fullPath + "'")
}

// check if the wildcard has a name
if end-i < 2 {
  panic("wildcards must be named with a non-empty name in path '" + fullPath + "'")
}

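Two more checks follow. The first panics if the current node already has children, because they would become unreachable once a wildcard is inserted here.
The second requires the wildcard to have a non-empty name, meaning at least one character after ':' or '*'.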

Finally, the nodes are built differently depending on the kind of wildcard. First, the code for c == ':':

if c == ':' { // param
  // split path at the beginning of the wildcard
  if i > 0 {
    n.path = path[offset:i]
    offset = i
  }

  child := &node{
    nType:     param,
    maxParams: numParams,
    fullPath:  fullPath,
  }
  n.children = []*node{child}
  n.wildChild = true
  n = child
  n.priority++
  numParams--

  // if the path doesn't end with the wildcard, then there
  // will be another non-wildcard subpath starting with '/'
  if end < max {
    n.path = path[offset:end]
    offset = end

    child := &node{
      maxParams: numParams,
      priority:  1,
      fullPath:  fullPath,
    }
    n.children = []*node{child}
    n = child
  }

}

Then the code for c == '*', which is the else branch:

else { // catchAll
  if end != max || numParams > 1 {
    panic("catch-all routes are only allowed at the end of the path in path '" + fullPath + "'")
  }

  if len(n.path) > 0 && n.path[len(n.path)-1] == '/' {
    panic("catch-all conflicts with existing handle for the path segment root in path '" + fullPath + "'")
  }

  // currently fixed width 1 for '/'
  i--
  if path[i] != '/' {
    panic("no / before catch-all in path '" + fullPath + "'")
  }

  n.path = path[offset:i]

  // first node: catchAll node with empty path
  child := &node{
    wildChild: true,
    nType:     catchAll,
    maxParams: 1,
    fullPath:  fullPath,
  }
  n.children = []*node{child}
  n.indices = string(path[i])
  n = child
  n.priority++

  // second node: node holding the variable
  child = &node{
    path:      path[i:],
    nType:     catchAll,
    maxParams: 1,
    handlers:  handlers,
    priority:  1,
    fullPath:  fullPath,
  }
  n.children = []*node{child}

  return
}

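The catchAll wildcard is special. It must come at the end of the path, so no other wildcard parameter may follow it, and the first few lines check exactly that.
The process creates two nodes of type catchAll. The first has an empty path and only marks that a wildcard child exists (wildChild=true).
The second holds the actual content: the variable part of the path, such as /*filepath, together with the handlers.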

That is essentially how routes are added. Next, let's look at how data is read back.

Retrieving data

Reading from the tree mainly happens in func (engine *Engine) handleHTTPRequest(c *Context).
Here is the relevant snippet:

root := t[i].root
// Find route in tree
value := root.getValue(rPath, c.Params, unescape)
if value.handlers != nil {
  c.handlers = value.handlers
  c.Params = value.params
  c.fullPath = value.fullPath
  c.Next()
  c.writermem.WriteHeaderNow()
  return
}

The data is fetched by the getValue method. Here is the full code:

// getValue returns the handle registered with the given path (key). The values of
// wildcards are saved to a map.
// If no handle can be found, a TSR (trailing slash redirect) recommendation is
// made if a handle exists with an extra (without the) trailing slash for the
// given path.
func (n *node) getValue(path string, po Params, unescape bool) (value nodeValue) {
    value.params = po
walk: // Outer loop for walking the tree
    for {
        if len(path) > len(n.path) {
            if path[:len(n.path)] == n.path {
                path = path[len(n.path):]
                // If this node does not have a wildcard (param or catchAll)
                // child,  we can just look up the next child node and continue
                // to walk down the tree
                if !n.wildChild {
                    c := path[0]
                    for i := 0; i < len(n.indices); i++ {
                        if c == n.indices[i] {
                            n = n.children[i]
                            continue walk
                        }
                    }

                    // Nothing found.
                    // We can recommend to redirect to the same URL without a
                    // trailing slash if a leaf exists for that path.
                    value.tsr = path == "/" && n.handlers != nil
                    return
                }

                // handle wildcard child
                n = n.children[0]
                switch n.nType {
                case param:
                    // find param end (either '/' or path end)
                    end := 0
                    for end < len(path) && path[end] != '/' {
                        end++
                    }

                    // save param value
                    if cap(value.params) < int(n.maxParams) {
                        value.params = make(Params, 0, n.maxParams)
                    }
                    i := len(value.params)
                    value.params = value.params[:i+1] // expand slice within preallocated capacity
                    value.params[i].Key = n.path[1:]
                    val := path[:end]
                    if unescape {
                        var err error
                        if value.params[i].Value, err = url.QueryUnescape(val); err != nil {
                            value.params[i].Value = val // fallback, in case of error
                        }
                    } else {
                        value.params[i].Value = val
                    }

                    // we need to go deeper!
                    if end < len(path) {
                        if len(n.children) > 0 {
                            path = path[end:]
                            n = n.children[0]
                            continue walk
                        }

                        // ... but we can't
                        value.tsr = len(path) == end+1
                        return
                    }

                    if value.handlers = n.handlers; value.handlers != nil {
                        value.fullPath = n.fullPath
                        return
                    }
                    if len(n.children) == 1 {
                        // No handle found. Check if a handle for this path + a
                        // trailing slash exists for TSR recommendation
                        n = n.children[0]
                        value.tsr = n.path == "/" && n.handlers != nil
                    }

                    return

                case catchAll:
                    // save param value
                    if cap(value.params) < int(n.maxParams) {
                        value.params = make(Params, 0, n.maxParams)
                    }
                    i := len(value.params)
                    value.params = value.params[:i+1] // expand slice within preallocated capacity
                    value.params[i].Key = n.path[2:]
                    if unescape {
                        var err error
                        if value.params[i].Value, err = url.QueryUnescape(path); err != nil {
                            value.params[i].Value = path // fallback, in case of error
                        }
                    } else {
                        value.params[i].Value = path
                    }

                    value.handlers = n.handlers
                    value.fullPath = n.fullPath
                    return

                default:
                    panic("invalid node type")
                }
            }
        } else if path == n.path {
            // We should have reached the node containing the handle.
            // Check if this node has a handle registered.
            if value.handlers = n.handlers; value.handlers != nil {
                value.fullPath = n.fullPath
                return
            }

            if path == "/" && n.wildChild && n.nType != root {
                value.tsr = true
                return
            }

            // No handle found. Check if a handle for this path + a
            // trailing slash exists for trailing slash recommendation
            for i := 0; i < len(n.indices); i++ {
                if n.indices[i] == '/' {
                    n = n.children[i]
                    value.tsr = (len(n.path) == 1 && n.handlers != nil) ||
                        (n.nType == catchAll && n.children[0].handlers != nil)
                    return
                }
            }

            return
        }

        // Nothing found. We can recommend to redirect to the same URL with an
        // extra trailing slash if a leaf exists for that path
        value.tsr = (path == "/") ||
            (len(n.path) == len(path)+1 && n.path[len(path)] == '/' &&
                path == n.path[:len(n.path)-1] && n.handlers != nil)
        return
    }
}

The code is long, so start with the comments. Given a path and the params, it returns the handlers registered for that path.

// nodeValue holds return values of (*Node).getValue method
type nodeValue struct {
    handlers HandlersChain
    params   Params
    tsr      bool
    fullPath string
}

// Param is a single URL parameter, consisting of a key and a value.
type Param struct {
    Key   string
    Value string
}

These are the structs it uses. The body of the method is a single for loop.

Inside the for loop, the first part is an if/else if chain. Let's look first at the part that follows it.

// Nothing found. We can recommend to redirect to the same URL with an
// extra trailing slash if a leaf exists for that path
value.tsr = (path == "/") ||
  (len(n.path) == len(path)+1 && n.path[len(path)] == '/' &&
    path == n.path[:len(n.path)-1] && n.handlers != nil)
return

If no match is found, a flag named tsr is set. It indicates whether a TSR (trailing slash redirect) is recommended. For example, /path could be redirected to /path/.
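The expression is easier to read as a function. This is a minimal sketch with a few hypothetical pieces:

* the function name recommendExtraSlash;
* the hasHandlers flag, which stands in for n.handlers != nil.

For simplicity, the examples pass whole paths. In getValue, path is only what remains after the consumed prefixes.

```go
package main

import "fmt"

// recommendExtraSlash mirrors the final "Nothing found" expression in
// getValue: path is the remaining request path, nodePath the current node's
// path, and hasHandlers tells whether that node has handlers registered.
func recommendExtraSlash(path, nodePath string, hasHandlers bool) bool {
	return path == "/" ||
		(len(nodePath) == len(path)+1 && nodePath[len(path)] == '/' &&
			path == nodePath[:len(nodePath)-1] && hasHandlers)
}

func main() {
	// only "/path/" is registered, the request asks for "/path"
	fmt.Println(recommendExtraSlash("/path", "/path/", true))  // true
	fmt.Println(recommendExtraSlash("/path", "/path/", false)) // false
	fmt.Println(recommendExtraSlash("/pat", "/path/", true))   // false
}
```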

Back to the if chain. The first branch is if len(path) > len(n.path).

if len(path) > len(n.path) {
  if path[:len(n.path)] == n.path {
    path = path[len(n.path):]
    // If this node does not have a wildcard (param or catchAll)
    // child,  we can just look up the next child node and continue
    // to walk down the tree
    if !n.wildChild {
      c := path[0]
      for i := 0; i < len(n.indices); i++ {
        if c == n.indices[i] {
          n = n.children[i]
          continue walk
        }
      }

      // Nothing found.
      // We can recommend to redirect to the same URL without a
      // trailing slash if a leaf exists for that path.
      value.tsr = path == "/" && n.handlers != nil
      return
    }

    // handle wildcard child
    n = n.children[0]
    switch n.nType {
    case param:
      // find param end (either '/' or path end)
      end := 0
      for end < len(path) && path[end] != '/' {
        end++
      }

      // save param value
      if cap(value.params) < int(n.maxParams) {
        value.params = make(Params, 0, n.maxParams)
      }
      i := len(value.params)
      value.params = value.params[:i+1] // expand slice within preallocated capacity
      value.params[i].Key = n.path[1:]
      val := path[:end]
      if unescape {
        var err error
        if value.params[i].Value, err = url.QueryUnescape(val); err != nil {
          value.params[i].Value = val // fallback, in case of error
        }
      } else {
        value.params[i].Value = val
      }

      // we need to go deeper!
      if end < len(path) {
        if len(n.children) > 0 {
          path = path[end:]
          n = n.children[0]
          continue walk
        }

        // ... but we can't
        value.tsr = len(path) == end+1
        return
      }

      if value.handlers = n.handlers; value.handlers != nil {
        value.fullPath = n.fullPath
        return
      }
      if len(n.children) == 1 {
        // No handle found. Check if a handle for this path + a
        // trailing slash exists for TSR recommendation
        n = n.children[0]
        value.tsr = n.path == "/" && n.handlers != nil
      }

      return

    case catchAll:
      // save param value
      if cap(value.params) < int(n.maxParams) {
        value.params = make(Params, 0, n.maxParams)
      }
      i := len(value.params)
      value.params = value.params[:i+1] // expand slice within preallocated capacity
      value.params[i].Key = n.path[2:]
      if unescape {
        var err error
        if value.params[i].Value, err = url.QueryUnescape(path); err != nil {
          value.params[i].Value = path // fallback, in case of error
        }
      } else {
        value.params[i].Value = path
      }

      value.handlers = n.handlers
      value.fullPath = n.fullPath
      return

    default:
      panic("invalid node type")
    }
  }
}

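Inside this branch is a nested if that checks whether the path starts with the current node's path. If it does not, the branch is skipped and execution falls through to the TSR check at the end of the loop.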

Next, it branches on n.wildChild, that is, on whether the node has a wildcard child.

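If there is no wildcard child, it looks up the next child node and starts a new iteration of the for loop.
If no child is found, it returns immediately. Whether a child exists is determined by n.indices,
a string that holds the first byte of every child's path.
For example, if two paths are registered, /ping and /pong, the current node is their common prefix /p,
and its n.indices is "io".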
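Here is a small sketch of the indices convention. It assumes a hypothetical child type that only carries a path, plus two hypothetical helpers, buildIndices and lookup. Byte i of indices always corresponds to children[i]. Gin's incrementChildPrio may reorder both by priority, which this sketch leaves out.

```go
package main

import "fmt"

// child is a hypothetical minimal node used only to show how indices works.
type child struct {
	path string
}

// buildIndices concatenates the first byte of every child path, in the same
// order as the children slice, the way Gin keeps n.indices.
func buildIndices(children []child) string {
	b := make([]byte, 0, len(children))
	for _, c := range children {
		b = append(b, c.path[0])
	}
	return string(b)
}

// lookup returns the position of the child whose path starts with c, or -1.
func lookup(indices string, c byte) int {
	for i := 0; i < len(indices); i++ {
		if indices[i] == c {
			return i
		}
	}
	return -1
}

func main() {
	// after registering /ping and /pong, node "/p" has these two children
	children := []child{{path: "ing"}, {path: "ong"}}
	idx := buildIndices(children)
	fmt.Println(idx)                             // io
	fmt.Println(children[lookup(idx, 'o')].path) // ong
	fmt.Println(lookup(idx, 'x'))                // -1
}
```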

If there is a wildcard child, the code branches on the child's n.nType.

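If the type is param, meaning a variable named with :, the variable's value is saved first.
If some path remains (if end < len(path) {), the walk continues into the child node in a new iteration of the for loop, or sets the tsr flag if there is no child.
Otherwise the lookup is finished, and handlers and fullPath are simply copied into the result.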
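Saving a param value boils down to three steps: cut at the next '/', optionally unescape, and fall back to the raw text on error. This is a minimal sketch with the hypothetical name readParam. url.QueryUnescape is the standard library function that getValue also calls. Note that it also turns '+' into a space.

```go
package main

import (
	"fmt"
	"net/url"
)

// readParam mimics the param case in getValue: the value runs up to the
// next '/' (or the end), is optionally unescaped, and falls back to the raw
// text if unescaping fails. rest is what is left for the next iteration.
func readParam(path string, unescape bool) (value, rest string) {
	end := 0
	for end < len(path) && path[end] != '/' {
		end++
	}
	val := path[:end]
	if unescape {
		if v, err := url.QueryUnescape(val); err == nil {
			val = v
		}
	}
	return val, path[end:]
}

func main() {
	v, rest := readParam("john%20doe/profile", true)
	fmt.Printf("%q %q\n", v, rest) // "john doe" "/profile"
	v, rest = readParam("bad%zz", true)
	fmt.Printf("%q %q\n", v, rest) // "bad%zz" ""
}
```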

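If the type is catchAll, meaning a match-anything variable named with *, the handling is simpler.
There is no remaining path to worry about, because * matches everything left in the path.
The value is saved directly, and handlers and fullPath are copied into the result.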

If the type is neither of these two, the code panics.

Now the other branch, else if path == n.path.

else if path == n.path {
  // We should have reached the node containing the handle.
  // Check if this node has a handle registered.
  if value.handlers = n.handlers; value.handlers != nil {
    value.fullPath = n.fullPath
    return
  }

  if path == "/" && n.wildChild && n.nType != root {
    value.tsr = true
    return
  }

  // No handle found. Check if a handle for this path + a
  // trailing slash exists for trailing slash recommendation
  for i := 0; i < len(n.indices); i++ {
    if n.indices[i] == '/' {
      n = n.children[i]
      value.tsr = (len(n.path) == 1 && n.handlers != nil) ||
        (n.nType == catchAll && n.children[0].handlers != nil)
      return
    }
  }

  return
}

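This part is fairly simple and follows the same logic as before. It checks whether handlers are registered on the node for this path.
If none are, it decides whether value.tsr should be set, that is, whether a trailing slash redirect applies.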

With that, we have also gone through how data is retrieved from the tree.

Summary

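Good code is worth reading often. It helps you understand how things work and broadens your horizons.
Also, a debugger is extremely useful when reading code, especially for seeing how the data structures are actually stored.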


帅气猫咪 published an article · 2019-11-30

01 Gin Source Code Walkthrough

Introduction

A walkthrough of the Gin source code, based on version v1.5.0.

Overview of the flow

The official documentation gives the following getting-started example:

package main

import "github.com/gin-gonic/gin"

func main() {
    r := gin.Default()
    r.GET("/ping", func(c *gin.Context) {
        c.JSON(200, gin.H{
            "message": "pong",
        })
    })
    r.Run() // listen and serve on 0.0.0.0:8080
}

It looks very simple. First it initializes with gin.Default(), then defines a route called /ping, and finally starts the server with r.Run().

Initialization

First, let's dig into what gin.Default() does:

// Default returns an Engine instance with the Logger and Recovery middleware already attached.
func Default() *Engine {
    debugPrintWARNINGDefault()
    engine := New()
    engine.Use(Logger(), Recovery())
    return engine
}

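Together with the comment, it is clear that Default mainly initializes an Engine. It then attaches two middlewares, one for logging and one for recovering from panics.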

Engine is actually a struct, and it is the core of the Gin framework. Here is its definition.

// Engine is the framework's instance, it contains the muxer, middleware and configuration settings.
// Create an instance of Engine, by using New() or Default()
type Engine struct {
    RouterGroup

    // Enables automatic redirection if the current route can't be matched but a
    // handler for the path with (without) the trailing slash exists.
    // For example if /foo/ is requested but a route only exists for /foo, the
    // client is redirected to /foo with http status code 301 for GET requests
    // and 307 for all other request methods.
    RedirectTrailingSlash bool

    // If enabled, the router tries to fix the current request path, if no
    // handle is registered for it.
    // First superfluous path elements like ../ or // are removed.
    // Afterwards the router does a case-insensitive lookup of the cleaned path.
    // If a handle can be found for this route, the router makes a redirection
    // to the corrected path with status code 301 for GET requests and 307 for
    // all other request methods.
    // For example /FOO and /..//Foo could be redirected to /foo.
    // RedirectTrailingSlash is independent of this option.
    RedirectFixedPath bool

    // If enabled, the router checks if another method is allowed for the
    // current route, if the current request can not be routed.
    // If this is the case, the request is answered with 'Method Not Allowed'
    // and HTTP status code 405.
    // If no other Method is allowed, the request is delegated to the NotFound
    // handler.
    HandleMethodNotAllowed bool
    ForwardedByClientIP    bool

    // #726 #755 If enabled, it will thrust some headers starting with
    // 'X-AppEngine...' for better integration with that PaaS.
    AppEngine bool

    // If enabled, the url.RawPath will be used to find parameters.
    UseRawPath bool

    // If true, the path value will be unescaped.
    // If UseRawPath is false (by default), the UnescapePathValues effectively is true,
    // as url.Path gonna be used, which is already unescaped.
    UnescapePathValues bool

    // Value of 'maxMemory' param that is given to http.Request's ParseMultipartForm
    // method call.
    MaxMultipartMemory int64

    delims           render.Delims
    secureJsonPrefix string
    HTMLRender       render.HTMLRender
    FuncMap          template.FuncMap
    allNoRoute       HandlersChain
    allNoMethod      HandlersChain
    noRoute          HandlersChain
    noMethod         HandlersChain
    pool             sync.Pool
    trees            methodTrees
}

The comment says an Engine can be created in two ways. Default() has been covered already, so let's look at New():

// New returns a new blank Engine instance without any middleware attached.
// By default the configuration is:
// - RedirectTrailingSlash:  true
// - RedirectFixedPath:      false
// - HandleMethodNotAllowed: false
// - ForwardedByClientIP:    true
// - UseRawPath:             false
// - UnescapePathValues:     true
func New() *Engine {
    debugPrintWARNINGNew()
    engine := &Engine{
        RouterGroup: RouterGroup{
            Handlers: nil,
            basePath: "/",
            root:     true,
        },
        FuncMap:                template.FuncMap{},
        RedirectTrailingSlash:  true,
        RedirectFixedPath:      false,
        HandleMethodNotAllowed: false,
        ForwardedByClientIP:    true,
        AppEngine:              defaultAppEngine,
        UseRawPath:             false,
        UnescapePathValues:     true,
        MaxMultipartMemory:     defaultMultipartMemory,
        trees:                  make(methodTrees, 0, 9),
        delims:                 render.Delims{Left: "{{", Right: "}}"},
        secureJsonPrefix:       "while(1);",
    }
    engine.RouterGroup.engine = engine
    engine.pool.New = func() interface{} {
        return engine.allocateContext()
    }
    return engine
}

Next, let's look at how middleware is added.

// Use attaches a global middleware to the router. ie. the middleware attached though Use() will be
// included in the handlers chain for every single request. Even 404, 405, static files...
// For example, this is the right place for a logger or error management middleware.
func (engine *Engine) Use(middleware ...HandlerFunc) IRoutes {
    engine.RouterGroup.Use(middleware...)
    engine.rebuild404Handlers()
    engine.rebuild405Handlers()
    return engine
}

Adding middleware actually registers it on the RouterGroup. So what is RouterGroup?

// RouterGroup is used internally to configure router, a RouterGroup is associated with
// a prefix and an array of handlers (middleware).
type RouterGroup struct {
    Handlers HandlersChain
    basePath string
    engine   *Engine
    root     bool
}

So RouterGroup is used to configure routing. It contains a route prefix, basePath, and a slice of middleware, Handlers.
Adding middleware therefore just appends elements to Handlers:

// Use adds middleware to the group, see example code in GitHub.
func (group *RouterGroup) Use(middleware ...HandlerFunc) IRoutes {
    group.Handlers = append(group.Handlers, middleware...)
    return group.returnObj()
}

func (group *RouterGroup) returnObj() IRoutes {
    if group.root {
        return group.engine
    }
    return group
}

That is roughly all there is to initialization.

Registering handlers

The most important job of a web server is, of course, defining routes and their handler functions.

r.GET("/ping", func(c *gin.Context) {
  c.JSON(200, gin.H{
    "message": "pong",
  })
})

We have already seen the definition of Engine. Note this part of it:

type Engine struct {
  RouterGroup

This shows that Engine embeds RouterGroup, so all the HTTP methods are actually registered on RouterGroup. Let's look at the GET method:

// GET is a shortcut for router.Handle("GET", path, handle).
func (group *RouterGroup) GET(relativePath string, handlers ...HandlerFunc) IRoutes {
    return group.handle("GET", relativePath, handlers)
}

From the comment and the code, GET is just a shortcut. Every HTTP method registration ends up in the internal group.handle method shown below.

func (group *RouterGroup) handle(httpMethod, relativePath string, handlers HandlersChain) IRoutes {
    absolutePath := group.calculateAbsolutePath(relativePath)
    handlers = group.combineHandlers(handlers)
    group.engine.addRoute(httpMethod, absolutePath, handlers)
    return group.returnObj()
}

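The core of handle is the call group.engine.addRoute(httpMethod, absolutePath, handlers). Before that, it computes the absolute path from the group's basePath and merges the group's middleware with the route's handlers.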
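The body of group.calculateAbsolutePath(relativePath) is not shown here. The sketch below models it under an assumption: the group's basePath and the relative path are joined with path.Join, and a trailing slash written by the caller is kept (path.Join would otherwise drop it). joinPaths is a name used for this sketch only.

```go
package main

import (
	"fmt"
	"path"
	"strings"
)

// joinPaths is a hypothetical model of calculateAbsolutePath:
// join basePath and relativePath, then restore a trailing slash
// if relativePath had one, since path.Join removes it.
func joinPaths(basePath, relativePath string) string {
	if relativePath == "" {
		return basePath
	}
	finalPath := path.Join(basePath, relativePath)
	if strings.HasSuffix(relativePath, "/") && !strings.HasSuffix(finalPath, "/") {
		return finalPath + "/"
	}
	return finalPath
}

func main() {
	fmt.Println(joinPaths("/", "/ping"))        // /ping
	fmt.Println(joinPaths("/api/v1", "users/")) // /api/v1/users/
	fmt.Println(joinPaths("/api", ""))          // /api
}
```

path.Join also cleans the result, so a relative path such as "../x" is resolved against the base path rather than appended verbatim.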

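Let's first look at combineHandlers. It turns out the number of handlers is limited: the check finalSize >= int(abortIndex) panics once the merged chain reaches 63 handlers (abortIndex is math.MaxInt8 / 2 in gin), so a route can have at most 62 handlers, middleware included.
Merging slices in Go is also a bit tedious: it takes three lines, one make and two copy calls.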

func (group *RouterGroup) combineHandlers(handlers HandlersChain) HandlersChain {
    finalSize := len(group.Handlers) + len(handlers)
    if finalSize >= int(abortIndex) {
        panic("too many handlers")
    }
    mergedHandlers := make(HandlersChain, finalSize)
    copy(mergedHandlers, group.Handlers)
    copy(mergedHandlers[len(group.Handlers):], handlers)
    return mergedHandlers
}
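Why allocate a new slice instead of simply appending? A minimal sketch, using []string as a stand-in for HandlersChain: when the group's slice has spare capacity, append writes into the shared backing array, so two routes registered on the same group can overwrite each other's handlers. Copying into a fresh slice, as combineHandlers does, gives each route its own copy. combine, naiveCombine and maxHandlers are names chosen for this sketch.

```go
package main

import "fmt"

const maxHandlers = 63 // mirrors abortIndex (math.MaxInt8 / 2)

// combine mirrors combineHandlers: allocate a new slice and copy both parts in.
func combine(group, handlers []string) []string {
	finalSize := len(group) + len(handlers)
	if finalSize >= maxHandlers {
		panic("too many handlers")
	}
	merged := make([]string, finalSize)
	copy(merged, group)
	copy(merged[len(group):], handlers)
	return merged
}

// naiveCombine appends directly and may share group's backing array.
func naiveCombine(group, handlers []string) []string {
	return append(group, handlers...)
}

func main() {
	// A group slice with spare capacity, as repeated Use calls can produce.
	group := make([]string, 1, 4)
	group[0] = "logger"

	a := naiveCombine(group, []string{"pingHandler"})
	b := naiveCombine(group, []string{"usersHandler"})
	// The second append overwrote the first route's handler.
	fmt.Println(a[1], b[1]) // usersHandler usersHandler

	c := combine(group, []string{"pingHandler"})
	d := combine(group, []string{"usersHandler"})
	fmt.Println(c[1], d[1]) // pingHandler usersHandler
}
```

A side effect of the copy: middleware added to a group with Use after a route is registered does not apply to that route, because the route already holds its own copy of the chain.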

Now let's see what addRoute does.

func (engine *Engine) addRoute(method, path string, handlers HandlersChain) {
    assert1(path[0] == '/', "path must begin with '/'")
    assert1(method != "", "HTTP method can not be empty")
    assert1(len(handlers) > 0, "there must be at least one handler")

    debugPrintRoute(method, path, handlers)
    root := engine.trees.get(method)
    if root == nil {
        root = new(node)
        root.fullPath = "/"
        engine.trees = append(engine.trees, methodTree{method: method, root: root})
    }
    root.addRoute(path, handlers)
}

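Skipping the assertions at the top, the core work is on engine.trees, which holds one routing tree per HTTP method; the tree implementation is adapted from httprouter.
The body of root.addRoute(path, handlers) is fairly long, so I won't go into it here.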
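The radix tree is out of scope here, but the one-tree-per-method layout around it is easy to model. A minimal sketch, assuming a plain map from exact path to handler names in place of the radix-tree node (so :param and *wildcard routes are not supported); methodTree, methodTrees, get and addRoute echo the names in the excerpt, but their bodies are simplified:

```go
package main

import "fmt"

// node stands in for the radix-tree node: exact path -> handler names.
type node struct {
	routes map[string][]string
}

type methodTree struct {
	method string
	root   *node
}

// methodTrees is a slice; with only a handful of HTTP methods a linear scan is cheap.
type methodTrees []methodTree

func (trees methodTrees) get(method string) *node {
	for _, tree := range trees {
		if tree.method == method {
			return tree.root
		}
	}
	return nil
}

type Engine struct {
	trees methodTrees
}

// addRoute follows the excerpt: find or create the method's tree, then insert.
func (engine *Engine) addRoute(method, path string, handlers []string) {
	if path == "" || path[0] != '/' {
		panic("path must begin with '/'")
	}
	root := engine.trees.get(method)
	if root == nil {
		root = &node{routes: map[string][]string{}}
		engine.trees = append(engine.trees, methodTree{method: method, root: root})
	}
	root.routes[path] = handlers
}

func main() {
	e := &Engine{}
	e.addRoute("GET", "/ping", []string{"logger", "ping"})
	e.addRoute("POST", "/ping", []string{"logger", "createPing"})
	e.addRoute("GET", "/users", []string{"logger", "users"})
	fmt.Println(len(e.trees))                           // 2: one tree for GET, one for POST
	fmt.Println(e.trees.get("GET").routes["/users"][1]) // users
}
```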

At this point the route has been registered.

Running

Finally, let's look at r.Run().

// Run attaches the router to a http.Server and starts listening and serving HTTP requests.
// It is a shortcut for http.ListenAndServe(addr, router)
// Note: this method will block the calling goroutine indefinitely unless an error happens.
func (engine *Engine) Run(addr ...string) (err error) {
    defer func() { debugPrintError(err) }()

    address := resolveAddress(addr)
    debugPrint("Listening and serving HTTP on %s\n", address)
    err = http.ListenAndServe(address, engine)
    return
}

This method blocks unless an error occurs. Internally it simply calls ListenAndServe from the standard net/http package.

Receiving requests

As shown above, Run works by calling http.ListenAndServe(address, engine),
which comes from the standard net/http package. Its signature is:

func ListenAndServe(addr string, handler Handler) error

The second parameter has type Handler, which, as you might guess, is an interface. Let's see what it requires:

type Handler interface {
  ServeHTTP(ResponseWriter, *Request)
}

Let's look at how Engine implements ServeHTTP.

// ServeHTTP conforms to the http.Handler interface.
func (engine *Engine) ServeHTTP(w http.ResponseWriter, req *http.Request) {
    c := engine.pool.Get().(*Context)
    c.writermem.reset(w)
    c.Request = req
    c.reset()

    engine.handleHTTPRequest(c)

    engine.pool.Put(c)
}
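Note that ServeHTTP does not allocate a new Context per request: it takes one from engine.pool (a sync.Pool), resets it, and puts it back afterwards. A minimal sketch of that get/reset/put cycle, with a simplified Context whose fields are assumptions for illustration:

```go
package main

import (
	"fmt"
	"sync"
)

// Context is a simplified stand-in for gin.Context; the fields are illustrative.
type Context struct {
	path  string
	index int8
	keys  map[string]interface{}
}

// reset clears per-request state so a pooled Context can be reused safely.
// index goes back to -1 so that the first Next() call runs handlers[0].
func (c *Context) reset() {
	c.path = ""
	c.index = -1
	c.keys = nil
}

// pool plays the role of engine.pool; New runs only when the pool is empty.
var pool = sync.Pool{
	New: func() interface{} { return &Context{index: -1} },
}

// serve mirrors the shape of Engine.ServeHTTP: get, reset, handle, put back.
func serve(path string) string {
	c := pool.Get().(*Context)
	c.reset()
	c.path = path

	result := "handled " + c.path // stands in for engine.handleHTTPRequest(c)

	pool.Put(c) // c may be handed to another request after this point
	return result
}

func main() {
	fmt.Println(serve("/ping"))  // handled /ping
	fmt.Println(serve("/users")) // handled /users
}
```

The reset step matters because a pooled object still carries the previous request's state; once Put is called, the same object may be reused by another request, so nothing should keep a reference to it.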

The key call is engine.handleHTTPRequest(c), which does the actual request handling.

func (engine *Engine) handleHTTPRequest(c *Context) {
    httpMethod := c.Request.Method
    rPath := c.Request.URL.Path
    unescape := false
    if engine.UseRawPath && len(c.Request.URL.RawPath) > 0 {
        rPath = c.Request.URL.RawPath
        unescape = engine.UnescapePathValues
    }
    rPath = cleanPath(rPath)

    // Find root of the tree for the given HTTP method
    t := engine.trees
    for i, tl := 0, len(t); i < tl; i++ {
        if t[i].method != httpMethod {
            continue
        }
        root := t[i].root
        // Find route in tree
        value := root.getValue(rPath, c.Params, unescape)
        if value.handlers != nil {
            c.handlers = value.handlers
            c.Params = value.params
            c.fullPath = value.fullPath
            c.Next()
            c.writermem.WriteHeaderNow()
            return
        }
        if httpMethod != "CONNECT" && rPath != "/" {
            if value.tsr && engine.RedirectTrailingSlash {
                redirectTrailingSlash(c)
                return
            }
            if engine.RedirectFixedPath && redirectFixedPath(c, root, engine.RedirectFixedPath) {
                return
            }
        }
        break
    }

    if engine.HandleMethodNotAllowed {
        for _, tree := range engine.trees {
            if tree.method == httpMethod {
                continue
            }
            if value := tree.root.getValue(rPath, nil, unescape); value.handlers != nil {
                c.handlers = engine.allNoMethod
                serveError(c, http.StatusMethodNotAllowed, default405Body)
                return
            }
        }
    }
    c.handlers = engine.allNoRoute
    serveError(c, http.StatusNotFound, default404Body)
}

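The code is long, but its two comments reveal the logic: it finds the tree for the request's HTTP method in engine.trees,
then looks up the path in that tree to get the handlers. When they are found: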

if value.handlers != nil {
  c.handlers = value.handlers
  c.Params = value.params
  c.fullPath = value.fullPath
  c.Next()
  c.writermem.WriteHeaderNow()
  return
}

Pay particular attention to c.Next():

// Next should be used only inside middleware.
// It executes the pending handlers in the chain inside the calling handler.
// See example in GitHub.
func (c *Context) Next() {
    c.index++
    for c.index < int8(len(c.handlers)) {
        c.handlers[c.index](c)
        c.index++
    }
}

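This runs the handlers stored in the Context one after another. Since handlers holds both the middleware and the main handler function
(combineHandlers merged them at registration time), this completes the handling of the route.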

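Of course, sometimes no matching route is found. There are several possible reasons: for example, the path is registered only for other HTTP methods, or the path does not exist at all. In handleHTTPRequest, the first case produces 405 Method Not Allowed when engine.HandleMethodNotAllowed is enabled; otherwise the request ends in 404 Not Found.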
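The fallback order at the end of handleHTTPRequest can be modelled without the radix tree. A minimal sketch, assuming exact-match routes stored as method, then path, then handler name; router, lookup and handleMethodNotAllowed are names made up for this sketch, and the trailing-slash and fixed-path redirects are left out:

```go
package main

import (
	"fmt"
	"net/http"
)

// router stands in for Engine: routes maps method -> path -> handler name.
type router struct {
	routes                 map[string]map[string]string
	handleMethodNotAllowed bool
}

// lookup returns the status code and handler name a request would get.
func (r *router) lookup(method, path string) (int, string) {
	if h, ok := r.routes[method][path]; ok {
		return http.StatusOK, h
	}
	if r.handleMethodNotAllowed {
		// The path is registered under some other method -> 405.
		for m, tree := range r.routes {
			if m == method {
				continue
			}
			if _, ok := tree[path]; ok {
				return http.StatusMethodNotAllowed, ""
			}
		}
	}
	return http.StatusNotFound, ""
}

func main() {
	r := &router{routes: map[string]map[string]string{
		"GET": {"/ping": "ping"},
	}}
	fmt.Println(r.lookup("GET", "/ping"))  // status 200, handler ping
	fmt.Println(r.lookup("POST", "/ping")) // status 404: 405 handling is off
	r.handleMethodNotAllowed = true
	fmt.Println(r.lookup("POST", "/ping")) // status 405
	fmt.Println(r.lookup("GET", "/nope"))  // status 404
}
```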

How middleware works

The call to Next() is where middleware comes into play, so let's look at it in more detail.

First, here is how to define and register middleware, taken from the official documentation:

package main

import (
    "log"
    "time"

    "github.com/gin-gonic/gin"
)

func Logger() gin.HandlerFunc {
    return func(c *gin.Context) {
        t := time.Now()

        // Set example variable
        c.Set("example", "12345")

        // before request

        c.Next()

        // after request
        latency := time.Since(t)
        log.Print(latency)

        // access the status we are sending
        status := c.Writer.Status()
        log.Println(status)
    }
}

func main() {
    r := gin.New()
    r.Use(Logger())

    r.GET("/test", func(c *gin.Context) {
        example := c.MustGet("example").(string)

        // it would print: "12345"
        log.Println(example)
    })

    // Listen and serve on 0.0.0.0:8080
    r.Run(":8080")
}

As the example shows, a middleware is defined no differently from a main handler function: both have the type gin.HandlerFunc:

type HandlerFunc func(*Context)

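The middleware also calls c.Next(), which splits the processing into two phases: code before c.Next() runs before the request is handled (before request), and code after it runs once the rest of the chain has finished (after request).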

[Figure: middleware execution flow]

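The flowchart above assumes two registered handlers: the first is a log middleware that records logs, the second is the main handler for the path.
When a request arrives, execution goes through several nested Next calls, because middleware may itself call c.Next(). You can think of c.Next()
as a transfer of control: each call runs the next handler, and when that handler returns, control comes back to the code after c.Next(). It is similar to the call stack in a recursion.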
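The "call stack" behaviour of c.Next() can be reproduced in a few lines. A minimal sketch: Context keeps only index and handlers, plus a log slice added to record the order; Next is the same loop as Context.Next above; middleware, run and the handler names are made up for the example.

```go
package main

import (
	"fmt"
	"strings"
)

type HandlerFunc func(*Context)

// Context keeps only what Next needs, plus a log to record the call order.
type Context struct {
	index    int8
	handlers []HandlerFunc
	log      []string
}

// Next is the same loop as Context.Next in the excerpt above.
func (c *Context) Next() {
	c.index++
	for c.index < int8(len(c.handlers)) {
		c.handlers[c.index](c)
		c.index++
	}
}

// middleware records entry, hands control to the rest of the chain, then records exit.
func middleware(name string) HandlerFunc {
	return func(c *Context) {
		c.log = append(c.log, name+" before")
		c.Next()
		c.log = append(c.log, name+" after")
	}
}

func run() string {
	c := &Context{index: -1} // -1 so the first Next() runs handlers[0]
	c.handlers = []HandlerFunc{
		middleware("logger"),
		middleware("auth"),
		func(ctx *Context) { ctx.log = append(ctx.log, "main handler") },
	}
	c.Next()
	return strings.Join(c.log, " -> ")
}

func main() {
	// logger before -> auth before -> main handler -> auth after -> logger after
	fmt.Println(run())
}
```

Each handler runs exactly once: after the innermost handler returns, every enclosing loop sees an index past the end and exits. If a middleware never calls c.Next(), the loop in the outer Next still moves on to the following handler once it returns, so calling c.Next() only matters when the middleware needs to run code after the rest of the chain.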

Summary

That is the basic flow of Gin. Even at a first read, the code is quite clear; there is more left to explore.

帅气猫咪 published an article · 2019-11-24

Basic gRPC usage

Introduction

RPC stands for Remote Procedure Call: a client application can directly call methods defined on another computer (the server).

gRPC is an RPC framework that uses protobuf (Protocol Buffers) as its data interchange format.

Defining a service

Since this is an RPC system, the main thing to define is the set of methods, that is, the service:

service HelloService {
  rpc SayHello (HelloRequest) returns (HelloResponse);
}

message HelloRequest {
  string greeting = 1;
}

message HelloResponse {
  string reply = 1;
}

gRPC lets you define four kinds of service methods:

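  • Unary RPC: the client sends a single request to the server and gets a single response back, like a normal function call
  • Server streaming RPC: the client sends a single request, and the server returns a stream of responses
  • Client streaming RPC: the client sends a stream of requests, and the server returns a single response
  • Bidirectional streaming RPC: two independent streams are used; the client sends a stream of requests and the server returns a stream of responses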
rpc SayHello(HelloRequest) returns (HelloResponse){
}
rpc LotsOfReplies(HelloRequest) returns (stream HelloResponse){
}
rpc LotsOfGreetings(stream HelloRequest) returns (HelloResponse) {
}
rpc BidiHello(stream HelloRequest) returns (stream HelloResponse){
}

Using gRPC in Go

Define the protobuf file hello.proto as follows:

syntax = "proto3";

import "google/protobuf/any.proto";

package hello;
option go_package = "hello";

message HelloReq {
  string name = 1;
}

message HelloResp {
  int32 code = 1;
  string greet = 2;
  google.protobuf.Any details = 3;
}

service HelloService {
  rpc Greet(HelloReq) returns (HelloResp) {};
  rpc GreetWithServerStream(HelloReq) returns (stream HelloResp) {};
  rpc GreetWithClientStream(stream HelloReq) returns (HelloResp) {};
  rpc GreetWithBidirectionalStream(stream HelloReq) returns (stream HelloResp) {};
}

Initialize the project and install the required dependencies:

go mod init tzh.com/app
go get -u github.com/golang/protobuf/protoc-gen-go
go get -u google.golang.org/grpc
mkdir hello
# Assumes protoc3 has already been unpacked here
.\protoc3\bin\protoc.exe --proto_path=. hello.proto --go_out=plugins=grpc:./hello

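When generating the code, make sure to enable the gRPC plugin with plugins=grpc; otherwise only the message types are generated, without the service client and server code.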

main.go looks like this:

package main

import (
    "context"
    "flag"
    "fmt"
    "io"
    "log"
    "net"
    "strings"

    "google.golang.org/grpc"
    pb "tzh.com/app/hello"
)

const port = ":5000"

type server struct {
    pb.UnimplementedHelloServiceServer
}

func (s *server) Greet(ctx context.Context, in *pb.HelloReq) (*pb.HelloResp, error) {
    return &pb.HelloResp{Code: 0, Greet: "hello " + in.GetName()}, nil
}

// Split a space-separated name and send one response per part over the stream
func (s *server) GreetWithServerStream(in *pb.HelloReq, stream pb.HelloService_GreetWithServerStreamServer) error {
    names := strings.Split(in.GetName(), " ")
    for i, name := range names {
        err := stream.Send(&pb.HelloResp{
            Code:  int32(i),
            Greet: fmt.Sprintf("part %d: hello %s", i, name),
        })
        if err != nil {
            return err
        }
    }
    return nil
}

// Collect every name the client streams in, then send a single merged response
func (s *server) GreetWithClientStream(stream pb.HelloService_GreetWithClientStreamServer) error {
    names := make([]string, 0)
    for {
        msg, err := stream.Recv()
        if err == io.EOF {
            // The client has finished sending: reply once and close the stream
            return stream.SendAndClose(&pb.HelloResp{
                Code:  0,
                Greet: fmt.Sprintf("hello %s count: %d", strings.Join(names, " "), len(names)),
            })
        }
        if err != nil {
            return err
        }
        names = append(names, msg.GetName())
    }
}

// Bidirectional stream: reply to each request as it arrives
func (s *server) GreetWithBidirectionalStream(stream pb.HelloService_GreetWithBidirectionalStreamServer) error {
    for {
        msg, err := stream.Recv()
        if err == io.EOF {
            return nil
        }
        if err != nil {
            return err
        }
        if err := stream.Send(&pb.HelloResp{
            Code:  0,
            Greet: "hello " + msg.GetName(),
        }); err != nil {
            return err
        }
    }
}

func runServer() {
    lis, err := net.Listen("tcp", port)
    if err != nil {
        log.Fatalf("failed to listen on %s: %v", port, err)
    }
    s := grpc.NewServer()
    // Register the service implementation with the gRPC server
    pb.RegisterHelloServiceServer(s, &server{})
    if err := s.Serve(lis); err != nil {
        log.Fatalf("failed to serve: %v", err)
    }
}

func run1(client pb.HelloServiceClient) {
    log.Println("############ Greet ")
    r, err := client.Greet(context.Background(), &pb.HelloReq{
        Name: "tt",
    })
    if err != nil {
        log.Fatalf("failed to get greet resp: %v", err)
    }
    log.Printf("get code : %d, get greet: %s \n", r.GetCode(), r.GetGreet())
}

func run2(client pb.HelloServiceClient) {
    log.Println("############ GreetWithServerStream")
    serverStream, err := client.GreetWithServerStream(context.Background(), &pb.HelloReq{
        Name: "tt aa xx ff",
    })
    if err != nil {
        log.Fatalf("failed with GreetWithServerStream: %v", err)
    }

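    for {
        r, err := serverStream.Recv()
        if err == io.EOF {
            // The server has finished streaming
            break
        }
        if err != nil {
            log.Fatalf("failed to receive from GreetWithServerStream: %v", err)
        }
        log.Printf("get code : %d, get greet: %s \n", r.GetCode(), r.GetGreet())
    }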
}

func run3(client pb.HelloServiceClient) {
    log.Println("############ GreetWithClientStream ")
    clientStream, err := client.GreetWithClientStream(context.Background())
    if err != nil {
        log.Fatalf("failed with GreetWithClientStream: %v", err)
    }

    for _, name := range []string{"tt", "qq", "aa", "yy"} {
        err := clientStream.Send(&pb.HelloReq{
            Name: name,
        })
        if err != nil {
            log.Fatalf("failed with GreetWithClientStream during send request: %v", err)
        }
    }
    r, err := clientStream.CloseAndRecv()
    if err != nil {
        log.Fatalf("failed with GreetWithClientStream when get response: %v", err)
    }
    log.Printf("get code : %d, get greet: %s \n", r.GetCode(), r.GetGreet())
}

func run4(client pb.HelloServiceClient) {
    log.Println("############ GreetWithBidirectionalStream ")
    stream, err := client.GreetWithBidirectionalStream(context.Background())
    if err != nil {
        log.Fatalf("failed with GreetWithBidirectionalStream: %v", err)
    }

    for _, name := range []string{"tt", "qq", "aa", "yy"} {
        err := stream.Send(&pb.HelloReq{
            Name: name,
        })
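        if err != nil {
            log.Fatalf("failed with GreetWithBidirectionalStream during send request: %v", err)
        }
    }
    stream.CloseSend()

    for {
        r, err := stream.Recv()
        if err == io.EOF {
            // The server has closed its side of the stream
            break
        }
        if err != nil {
            log.Fatalf("failed to receive from GreetWithBidirectionalStream: %v", err)
        }
        log.Printf("get code : %d, get greet: %s \n", r.GetCode(), r.GetGreet())
    }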
}

func runClient() {
    address := "localhost" + port
    conn, err := grpc.Dial(address, grpc.WithInsecure())
    if err != nil {
        log.Fatalf("failed to connect %s: %v", address, err)
    }
    defer conn.Close()
    client := pb.NewHelloServiceClient(conn)

    run1(client)
    run2(client)
    run3(client)
    run4(client)
}

func main() {
    isClient := flag.Bool("client", false, "run client")
    flag.Parse()
    if *isClient {
        runClient()
    } else {
        runServer()
    }
}

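The code is a bit long because the server and the client live in the same file. Use the following commands to start the server and then run the client:

# Run the server
go run main.go
# Run the client (in another terminal)
go run main.go --client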

Certifications and achievements

  • Received 21 likes
  • Earned 2 badges: 0 gold, 0 silver, 2 bronze

Registered on 2019-08-31
Profile viewed by 862 people