求段尽量靠谱的完整的在 Go 中实现 SQL 查询超时的代码?

感觉 Go 文档 中的例子好像不完整,函数参数中的 context 是哪来的?直接用默认的那个上下文就行?另外好像还需要处理 context 的错误?文档里好像没有。网上搜了一会儿,也看了一会儿,失去耐心了。

所谓靠谱就是最好是官方文档中的代码。我怀疑我搜索的方法不对,可能有完整的例子。如果确实没有的话,经验丰富的大佬给段代码也行。或者大型开源项目中的代码段也行,我自己试着在 github 上搜了一下,搜索代码的时候好像不能按照项目的收藏数排序。

阅读 539
2 个回答

代码

package main

import (
    "context"
    "database/sql"
    "fmt"
    "log"
    "time"

    _ "github.com/go-sql-driver/mysql" // MySQL 驱动
)

// DBClient 封装了数据库客户端,包含连接池管理和超时控制。
type DBClient struct {
    db             *sql.DB
    defaultTimeout time.Duration
}

// QueryResult 保存数据库查询的结果。
type QueryResult struct {
    Rows   *sql.Rows
    Error  error
    closed bool // 私有字段,防止外部误用
}

// Close 关闭查询结果的行,如果尚未关闭。
func (qr *QueryResult) Close() error {
    if qr.Rows != nil && !qr.closed {
        qr.closed = true
        return qr.Rows.Close()
    }
    return nil
}

// NewDBClient 创建一个新的数据库客户端,指定驱动、数据源和默认超时时间。
// 它会配置连接池并测试数据库连接。
func NewDBClient(driverName, dataSourceName string, defaultTimeout time.Duration) (*DBClient, error) {
    db, err := sql.Open(driverName, dataSourceName)
    if err != nil {
        return nil, fmt.Errorf("打开数据库失败: %w", err)
    }

    // 配置连接池
    db.SetMaxOpenConns(25)                 // 最大打开连接数
    db.SetMaxIdleConns(5)                  // 最大空闲连接数
    db.SetConnMaxLifetime(5 * time.Minute) // 连接最大生存时间

    // 测试连接
    ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
    defer cancel()
    if err := db.PingContext(ctx); err != nil {
        db.Close()
        return nil, fmt.Errorf("数据库连接测试失败: %w", err)
    }

    return &DBClient{
        db:             db,
        defaultTimeout: defaultTimeout,
    }, nil
}

// QueryWithTimeout 执行带超时的查询,使用指定的超时时间或客户端的默认超时时间。
// 返回 QueryResult,包含查询结果行或错误。调用者需负责调用 Close 方法关闭结果。
func (client *DBClient) QueryWithTimeout(ctx context.Context, timeout time.Duration, query string, args ...interface{}) (*QueryResult, error) {
    var queryCtx context.Context
    var cancel context.CancelFunc

    if timeout > 0 {
        queryCtx, cancel = context.WithTimeout(ctx, timeout)
    } else if client.defaultTimeout > 0 {
        queryCtx, cancel = context.WithTimeout(ctx, client.defaultTimeout)
    } else {
        queryCtx, cancel = context.WithCancel(ctx)
    }
    defer cancel()

    rows, err := client.db.QueryContext(queryCtx, query, args...)
    if err != nil {
        if err == context.DeadlineExceeded {
            return nil, fmt.Errorf("查询超时: %w", err)
        }
        return nil, fmt.Errorf("查询失败: %w", err)
    }
    return &QueryResult{Rows: rows, Error: err}, nil
}

// QueryRowWithTimeout 执行单行查询,使用指定的超时时间或客户端的默认超时时间。
// 返回 sql.Row 用于扫描结果。
func (client *DBClient) QueryRowWithTimeout(ctx context.Context, timeout time.Duration, query string, args ...interface{}) *sql.Row {
    var queryCtx context.Context
    var cancel context.CancelFunc

    if timeout > 0 {
        queryCtx, cancel = context.WithTimeout(ctx, timeout)
    } else if client.defaultTimeout > 0 {
        queryCtx, cancel = context.WithTimeout(ctx, client.defaultTimeout)
    } else {
        queryCtx = ctx
        cancel = func() {}
    }
    defer cancel()

    return client.db.QueryRowContext(queryCtx, query, args...)
}

// ExecWithTimeout 执行 INSERT、UPDATE 或 DELETE 查询,使用指定的超时时间或客户端的默认超时时间。
// 返回 sql.Result 或错误(超时或执行失败)。
func (client *DBClient) ExecWithTimeout(ctx context.Context, timeout time.Duration, query string, args ...interface{}) (sql.Result, error) {
    var queryCtx context.Context
    var cancel context.CancelFunc

    if timeout > 0 {
        queryCtx, cancel = context.WithTimeout(ctx, timeout)
    } else if client.defaultTimeout > 0 {
        queryCtx, cancel = context.WithTimeout(ctx, client.defaultTimeout)
    } else {
        queryCtx, cancel = context.WithCancel(ctx)
    }
    defer cancel()

    result, err := client.db.ExecContext(queryCtx, query, args...)
    if err != nil {
        if err == context.DeadlineExceeded {
            return nil, fmt.Errorf("执行超时: %w", err)
        }
        return nil, fmt.Errorf("执行失败: %w", err)
    }
    return result, nil
}

// BeginTxWithTimeout 开始一个事务,使用指定的超时时间或客户端的默认超时时间。
// 返回 sql.Tx 或错误(事务启动失败或超时)。
func (client *DBClient) BeginTxWithTimeout(ctx context.Context, timeout time.Duration, opts *sql.TxOptions) (*sql.Tx, error) {
    var txCtx context.Context
    var cancel context.CancelFunc

    if timeout > 0 {
        txCtx, cancel = context.WithTimeout(ctx, timeout)
    } else if client.defaultTimeout > 0 {
        txCtx, cancel = context.WithTimeout(ctx, client.defaultTimeout)
    } else {
        txCtx, cancel = context.WithCancel(ctx)
    }
    defer cancel()

    return client.db.BeginTx(txCtx, opts)
}

// PrepareWithTimeout 准备一个预编译语句,使用指定的超时时间或客户端的默认超时时间。
// 返回 sql.Stmt 或错误(准备失败或超时)。
func (client *DBClient) PrepareWithTimeout(ctx context.Context, timeout time.Duration, query string) (*sql.Stmt, error) {
    var queryCtx context.Context
    var cancel context.CancelFunc

    if timeout > 0 {
        queryCtx, cancel = context.WithTimeout(ctx, timeout)
    } else if client.defaultTimeout > 0 {
        queryCtx, cancel = context.WithTimeout(ctx, client.defaultTimeout)
    } else {
        queryCtx, cancel = context.WithCancel(ctx)
    }
    defer cancel()

    return client.db.PrepareContext(queryCtx, query)
}

// Close 关闭数据库连接。
func (client *DBClient) Close() error {
    if client.db != nil {
        return client.db.Close()
    }
    return nil
}

// GetDB 返回底层的 sql.DB 对象,供特殊场景使用。
func (client *DBClient) GetDB() *sql.DB {
    return client.db
}

// main 展示 DBClient 的使用示例。
func main() {
    // 创建数据库客户端,默认超时时间 30 秒
    // 请替换为实际的 DSN,例如从环境变量加载
    client, err := NewDBClient("mysql", "user:password@tcp(localhost:3306)/database", 30*time.Second)
    if err != nil {
        log.Fatal("创建数据库客户端失败:", err)
    }
    defer client.Close()

    // 根上下文
    ctx := context.Background()

    // 示例 1:带超时的查询
    fmt.Println("=== 查询示例 ===")
    result, err := client.QueryWithTimeout(ctx, 10*time.Second, "SELECT id, name FROM users WHERE status = ?", "active")
    if err != nil {
        log.Printf("查询失败: %v", err)
    } else {
        defer result.Close()
        for result.Rows.Next() {
            var id int
            var name string
            if err := result.Rows.Scan(&id, &name); err != nil {
                log.Printf("扫描错误: %v", err)
                continue
            }
            fmt.Printf("ID: %d, 名称: %s\n", id, name)
        }
        if err := result.Rows.Err(); err != nil {
            log.Printf("行错误: %v", err)
        }
    }

    // 示例 2:单行查询
    fmt.Println("\n=== 单行查询示例 ===")
    var userCount int
    row := client.QueryRowWithTimeout(ctx, 5*time.Second, "SELECT COUNT(*) FROM users")
    if err := row.Scan(&userCount); err != nil {
        log.Printf("单行查询失败: %v", err)
    } else {
        fmt.Printf("用户数量: %d\n", userCount)
    }

    // 示例 3:执行更新
    fmt.Println("\n=== 更新示例 ===")
    execResult, err := client.ExecWithTimeout(ctx, 15*time.Second, "UPDATE users SET last_login = NOW() WHERE id = ?", 1)
    if err != nil {
        log.Printf("更新失败: %v", err)
    } else {
        rowsAffected, _ := execResult.RowsAffected()
        fmt.Printf("影响的行数: %d\n", rowsAffected)
    }

    // 示例 4:预编译语句
    fmt.Println("\n=== 预编译语句示例 ===")
    stmt, err := client.PrepareWithTimeout(ctx, 5*time.Second, "SELECT name FROM users WHERE id = ?")
    if err != nil {
        log.Printf("预编译失败: %v", err)
    } else {
        defer stmt.Close()
        var userName string
        err = stmt.QueryRowContext(ctx, 1).Scan(&userName)
        if err != nil {
            log.Printf("预编译语句查询失败: %v", err)
        } else {
            fmt.Printf("用户名称: %s\n", userName)
        }
    }

    // 示例 5:超时测试
    fmt.Println("\n=== 超时测试 ===")
    shortCtx, cancel := context.WithTimeout(ctx, 1*time.Millisecond)
    defer cancel()
    _, err = client.QueryWithTimeout(shortCtx, 0, "SELECT SLEEP(1)")
    if err != nil {
        fmt.Printf("预期超时错误: %v\n", err)
    }

    // 示例 6:事务
    fmt.Println("\n=== 事务示例 ===")
    tx, err := client.BeginTxWithTimeout(ctx, 10*time.Second, nil)
    if err != nil {
        log.Printf("开启事务失败: %v", err)
    } else {
        _, err := tx.ExecContext(ctx, "UPDATE users SET status = ? WHERE id = ?", "inactive", 1)
        if err != nil {
            tx.Rollback()
            log.Printf("事务执行失败: %v", err)
        } else {
            if err := tx.Commit(); err != nil {
                log.Printf("事务提交失败: %v", err)
            } else {
                fmt.Println("事务提交成功")
            }
        }
    }
}

使用说明

数据库连接

使用环境变量加载 DSN(数据源名称):

import "os"

dsn := os.Getenv("DB_DSN")
if dsn == "" {
    log.Fatal("未设置 DB_DSN 环境变量")
}
client, err := NewDBClient("mysql", dsn, 30*time.Second)

安装 MySQL 驱动

确定安装 MySQL 驱动:

go get github.com/go-sql-driver/mysql

测试数据库准备

确定数据库和表存在,例如 users 表结构:

CREATE TABLE users (
    id INT PRIMARY KEY,
    name VARCHAR(255),
    status VARCHAR(50),
    last_login DATETIME
);
INSERT INTO users (id, name, status) VALUES (1, 'Alice', 'active');
package main

import (
    "context"
    "database/sql"
    "fmt"
    "log"
    "time"

    _ "github.com/go-sql-driver/mysql"
)

func main() {
    // 替换为你的 MySQL 数据源名称 (DSN)
    dsn := "user:password@tcp(127.0.0.1:3306)/dbname"

    // 打开数据库连接
    db, err := sql.Open("mysql", dsn)
    if err != nil {
        log.Fatalf("Failed to connect to database: %v", err)
    }
    defer db.Close()

    // 设置查询的上下文,带超时
    ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) // 超时时间为 5 秒
    defer cancel()

    // 执行带上下文的查询
    query := "SELECT id, name FROM users WHERE status = ?"
    rows, err := db.QueryContext(ctx, query, 1)

    if err != nil {
        if ctx.Err() == context.DeadlineExceeded {
            log.Fatalf("Query timed out: %v", err)
        } else {
            log.Fatalf("Database error: %v", err)
        }
    }
    defer rows.Close()

    // 处理结果
    for rows.Next() {
        var id int
        var name string
        if err := rows.Scan(&id, &name); err != nil {
            log.Fatalf("Scan error: %v", err)
        }
        fmt.Printf("ID: %d, Name: %s\n", id, name)
    }

    if err := rows.Err(); err != nil {
        log.Fatalf("Error iterating rows: %v", err)
    }
}
推荐问题