用 Go 写一个 RPC 框架 s05（客户端编写）

dollarkillerx

前言

前面几章我们完成了服务端的编写，现在开始编写客户端。

https://github.com/dollarkill...

Client

// Client is the entry point of the RPC client. It only holds the
// configuration (see Options) that is shared by every Connect created from it
// through NewConnect.
type Client struct {
    options *Options // effective configuration: defaults overlaid with the Options passed to NewClient
}

// NewClient builds a Client that locates servers through discover.
//
// The configuration starts from defaultOptions, the discovery plugin is
// installed, and then every supplied Option is applied in order, so later
// options override earlier ones (and may even replace Discovery).
func NewClient(discover discovery.Discovery, options ...Option) *Client {
	opts := defaultOptions()
	opts.Discovery = discover

	for _, apply := range options {
		apply(opts)
	}

	return &Client{options: opts}
}

Options（配置项）

// Options is the client configuration. NewClient starts from defaultOptions
// and then applies the Option functions it is given.
type Options struct {
    Discovery         discovery.Discovery                 // service-discovery plugin
    loadBalancing     load_banlancing.LoadBalancing       // load-balancing plugin; picks the instance each new connection dials
    serializationType codes.SerializationType             // serialization codec
    compressorType    codes.CompressorType                // compression codec

    pool         int                                      // connection-pool size (connections opened per Connect)
    cryptology   cryptology.Cryptology                    // cipher selection; NOTE(review): the code shown calls the AES helpers directly — confirm where this is read
    rsaPublicKey []byte                                   // server RSA public key (PEM) used to encrypt the handshake
    writeTimeout time.Duration                            // deadline applied before each write
    readTimeout  time.Duration                            // read deadline applied before each incoming frame
    heartBeat    time.Duration                            // interval between heartbeat frames
    Trace        bool                                     // when true, log heartbeats, unmatched responses and write errors
    AUTH         string                                   // AUTH token, sent in the handshake and attached to every call
}

// defaultOptions returns the baseline configuration that NewClient starts from:
// a pool of 4*runtime.NumCPU() connections (but at least 20), MsgPack
// serialization, Snappy compression, polling load balancing, AES, a one-minute
// write timeout and heartbeat interval, a three-minute read timeout, tracing
// off and an empty AUTH token.
//
// NOTE(review): the default RSA public key is an empty PEM envelope, so the
// handshake presumably cannot succeed until a real key is configured — confirm.
func defaultOptions() *Options {
	// Four connections per CPU, with a floor of 20.
	poolSize := 4 * runtime.NumCPU()
	if poolSize < 20 {
		poolSize = 20
	}

	opts := &Options{
		serializationType: codes.MsgPack,
		compressorType:    codes.Snappy,
		loadBalancing:     load_banlancing.NewPolling(),
		cryptology:        cryptology.AES,
	}
	opts.pool = poolSize
	opts.rsaPublicKey = []byte(`
-----BEGIN PUBLIC KEY-----
-----END PUBLIC KEY-----`)
	opts.readTimeout = 3 * time.Minute
	opts.writeTimeout = time.Minute
	opts.heartBeat = time.Minute
	opts.Trace = false
	opts.AUTH = ""

	return opts
}

具体的每个连接（Connect）

// Connect is a handle to one named remote service. Calls made through it are
// served by a dedicated pool of connections to that service's instances.
type Connect struct {
    Client     *Client       // owning client; supplies the configuration
    pool       *connectPool  // pooled connections to serverName
    close      chan struct{} // created by NewConnect; presumably a shutdown signal — TODO confirm who closes it
    serverName string        // service name resolved through the discovery plugin
}

// NewConnect creates a Connect for serverName and eagerly fills its connection
// pool (service discovery, load-balancer initialisation and options.pool
// dials).
//
// On failure it returns a nil *Connect together with the error. A
// half-initialised handle with an incomplete pool is never handed out.
func (c *Client) NewConnect(serverName string) (conn *Connect, err error) {
	connect := &Connect{
		Client:     c,
		serverName: serverName,
		close:      make(chan struct{}),
	}

	if connect.pool, err = initPool(connect); err != nil {
		return nil, err
	}
	return connect, nil
}

初始化连接池

// initPool allocates the connection pool for c, a channel buffered to the
// configured pool size, and fills it.
//
// The pool is returned even when filling fails; the error is returned
// alongside it.
func initPool(c *Connect) (*connectPool, error) {
	size := c.Client.options.pool

	cp := &connectPool{connect: c}
	cp.pool = make(chan LightClient, size)

	err := cp.initPool()
	return cp, err
}

// initPool discovers the instances of the service, seeds the load balancer
// with them, and fills the pool with options.pool freshly dialled connections.
//
// It fails if discovery fails, if no instance is found ("<name> server 404"),
// or if any dial/handshake fails. In the last case every connection opened so
// far is closed before returning. Otherwise their sockets would leak, and so
// would their heartbeat and reader goroutines, which only stop when the
// connection fails.
func (c *connectPool) initPool() error {
	options := c.connect.Client.options

	// Ask service discovery which hosts currently serve this service.
	hosts, err := options.Discovery.Discovery(c.connect.serverName)
	if err != nil {
		return err
	}
	if len(hosts) == 0 {
		return errors.New(fmt.Sprintf("%s server 404", c.connect.serverName))
	}

	// Seed the load balancer; newBaseClient picks its target instance from it.
	options.loadBalancing.InitBalancing(hosts)

	var opened []*BaseClient
	for i := 0; i < options.pool; i++ {
		client, err := newBaseClient(c.connect.serverName, options)
		if err != nil {
			// Tear down what was already established: closing the socket
			// fails the reader's pending read, which stops both goroutines.
			for _, bc := range opened {
				_ = bc.conn.Close()
			}
			return errors.WithStack(err)
		}
		opened = append(opened, client)
		c.pool <- client
	}

	return nil
}

// Get borrows a connection from the pool. It blocks until one is free or
// until ctx is done, in which case it fails with "pool get timeout".
func (c *connectPool) Get(ctx context.Context) (LightClient, error) {
	var client LightClient
	select {
	case client = <-c.pool:
	case <-ctx.Done():
		return nil, errors.New("pool get timeout")
	}
	return client, nil
}

// Put hands client back to the pool once a call has finished with it.
//
// A healthy client (Error() == nil) goes straight back into the pool. A client
// that reported an error is discarded instead. Its socket is closed so the
// file descriptor is released. A background goroutine then retries once per
// second: it re-runs service discovery, refreshes the load balancer and dials
// a replacement, so the pool eventually regains its configured size. The retry
// loop also stops if the owning Connect's close channel is closed.
func (c *connectPool) Put(client LightClient) {
	if client.Error() == nil {
		c.pool <- client
		return
	}

	// The broken client leaves the pool for good. Closing its socket releases
	// the descriptor and makes any pending read in its reader goroutine fail,
	// which in turn stops its heartbeat goroutine.
	if bc, ok := client.(*BaseClient); ok && bc.conn != nil {
		_ = bc.conn.Close()
	}

	// Restore the pool to full size in the background.
	go func() {
		fmt.Println("The server starts to restore")
		for {
			select {
			case <-c.connect.close:
				return
			case <-time.After(time.Second):
			}

			hosts, err := c.connect.Client.options.Discovery.Discovery(c.connect.serverName)
			if err != nil {
				log.Println(err)
				continue
			}
			if len(hosts) == 0 {
				log.Println(errors.New(fmt.Sprintf("%s server 404", c.connect.serverName)))
				continue
			}

			c.connect.Client.options.loadBalancing.InitBalancing(hosts)
			baseClient, err := newBaseClient(c.connect.serverName, c.connect.Client.options)
			if err != nil {
				log.Println(err)
				continue
			}

			// A slot was freed when the broken client was dropped, so this
			// send does not block.
			c.pool <- baseClient
			fmt.Println("Service recovery success")
			return
		}
	}()
}

Connect 调用具体服务

// Call invokes serviceMethod on the remote service using a pooled connection.
//
// It waits at most 6 seconds for a free connection. The connection is handed
// back to the pool when Call returns; the pool's Put decides whether to reuse
// or replace it. Before the request is sent, the configured AUTH token is
// attached to ctx under the "Light_AUTH" key.
func (c *Connect) Call(ctx *light.Context, serviceMethod string, request interface{}, response interface{}) error {
	// Bound the wait for a free connection. The cancel func must be called to
	// release the timer as soon as we are done (go vet: lostcancel).
	getCtx, cancel := context.WithTimeout(context.TODO(), time.Second*6)
	defer cancel()

	client, err := c.pool.Get(getCtx)
	if err != nil {
		return errors.WithStack(err)
	}
	defer c.pool.Put(client)

	// Attach the auth token to the request metadata.
	ctx.SetValue("Light_AUTH", c.Client.options.AUTH)

	if err := client.Call(ctx, serviceMethod, request, response); err != nil {
		return errors.WithStack(err)
	}
	return nil
}

调用核心（重点）

复习 s03 协议设计

/**
    协议设计
    起始符 :  版本号 :  crc32校验 :   magicNumberSize:    serverNameSize :   serverMethodSize :  metaDataSize : payloadSize:  respType :   compressorType :    serializationType :    magicNumber :  serverName :   serverMethod :  metaData :  payload
        0x05  :  0x01  :     4     :        4         :         4         :         4          :       4       :      4     :      1    :          1       :           1          :        xxx     :       xxx   :        xxx     :    xxx    :    xxx
*/

注意：每一个请求都有一个 magicNumber，即该请求的请求 ID。

单个连接的定义

// BaseClient is a single connection to one server instance, shared by many
// in-flight requests. Requests are written under writeMu and registered in
// respInterMap by their magic number (request ID). A background reader
// (processMessageManager) matches response frames back to them.
type BaseClient struct {
    conn       net.Conn // underlying transport connection
    options    *Options // client configuration (timeouts, codecs, trace, ...)
    serverName string   // remote service name written into every request frame

    aesKey        []byte               // per-connection AES key sent (RSA-encrypted) during the handshake
    serialization codes.Serialization  // codec for metadata and request/response bodies
    compressor    codes.Compressor     // compressor for outgoing requests; responses use the one named in their header

    respInterMap map[string]*respMessage // in-flight requests keyed by magic number
    respInterRM  sync.RWMutex     // guards respInterMap
    writeMu      sync.Mutex   // serializes writes to conn

    err   error          // last connection error (presumably what Error() reports); NOTE(review): written from several goroutines without a lock
    close chan struct{}  // closed by the reader when a read fails; stops heartBeat
}

// respMessage tracks one in-flight request until its response frame arrives.
type respMessage struct {
    response interface{}    // destination the response payload is decoded into
    ctx      *light.Context // request context; receives the response metadata
    respChan chan error     // receives the call's error, or is closed on success
}

初始化单个连接

// newBaseClient opens one connection to an instance of serverName chosen by
// the load balancer. It performs the handshake and starts the connection's
// heartbeat and response-reader goroutines.
//
// The handshake sends the AUTH token and a freshly generated AES key, both
// RSA-encrypted with options.rsaPublicKey. The server's reply is checked, and
// a non-empty Error field is returned as an error.
//
// Everything that does not need the network (codec lookup, key generation,
// RSA encryption) is done before dialling. The socket is therefore opened only
// once the handshake payload is ready, and it is closed on every later
// failure path.
func newBaseClient(serverName string, options *Options) (*BaseClient, error) {
	// Pick the target instance through the load-balancing plugin.
	service, err := options.loadBalancing.GetService()
	if err != nil {
		return nil, err
	}

	serialization, ex := codes.SerializationManager.Get(options.serializationType)
	if !ex {
		return nil, pkg.ErrSerialization404
	}
	compressor, ex := codes.CompressorManager.Get(options.compressorType)
	if !ex {
		return nil, pkg.ErrCompressor404
	}

	// Build the handshake: the RSA-encrypted AUTH token plus a new
	// per-connection AES key, itself RSA-encrypted for the key exchange.
	encryptedAuth, err := cryptology.RsaEncrypt([]byte(options.AUTH), options.rsaPublicKey)
	if err != nil {
		return nil, err
	}
	aesKey := []byte(strings.ReplaceAll(uuid.New().String(), "-", ""))
	encryptedKey, err := cryptology.RsaEncrypt(aesKey, options.rsaPublicKey)
	if err != nil {
		return nil, err
	}
	handshake := protocol.EncodeHandshake(encryptedKey, encryptedAuth, []byte(""))

	con, err := transport.Client.Gen(service.Protocol, service.Addr)
	if err != nil {
		return nil, errors.WithStack(err)
	}

	if _, err := con.Write(handshake); err != nil {
		con.Close()
		return nil, err
	}

	hsk := &protocol.Handshake{}
	if err := hsk.Handshake(con); err != nil {
		con.Close()
		return nil, err
	}
	// A non-empty Error field means the server rejected the handshake.
	if len(hsk.Error) > 0 {
		con.Close()
		return nil, errors.New(string(hsk.Error))
	}

	bc := &BaseClient{
		serverName:    serverName,
		conn:          con,
		options:       options,
		serialization: serialization,
		compressor:    compressor,
		respInterMap:  map[string]*respMessage{},
		aesKey:        aesKey,
		close:         make(chan struct{}),
	}

	go bc.heartBeat()             // periodic keep-alive frames
	go bc.processMessageManager() // routes response frames to waiting callers

	return bc, nil
}

heartBeat 心跳服务

// heartBeat sends a HeartBeat frame every options.heartBeat interval.
//
// It returns when b.close is closed, or when writing a heartbeat fails; in the
// latter case the write error is recorded in b.err. A frame that fails to
// encode is logged and skipped.
func (b *BaseClient) heartBeat() {
	defer fmt.Println("heartBeat Close")

	for {
		select {
		case <-b.close:
			return
		case <-time.After(b.options.heartBeat):
		}

		_, frame, err := protocol.EncodeMessage("x", []byte(""), []byte(""), []byte(""), byte(protocol.HeartBeat), byte(b.options.compressorType), byte(b.options.serializationType), []byte(""))
		if err != nil {
			log.Println(err)
			continue
		}

		// NOTE(review): SetDeadline also moves the connection's read deadline
		// (set per frame by processMessage) to now+writeTimeout. This appears
		// to act as a liveness probe: if nothing is read back in time, the
		// reader fails. Confirm that this is intended.
		deadline := time.Now().Add(b.options.writeTimeout)
		b.conn.SetDeadline(deadline)
		b.conn.SetWriteDeadline(deadline)

		b.writeMu.Lock()
		_, err = b.conn.Write(frame)
		b.writeMu.Unlock()
		if err != nil {
			b.err = err
			return
		}
	}
}

processMessageManager：处理返回消息的服务（注意这里可以并发处理）

// processMessageManager is the connection's reader loop. It repeatedly calls
// processMessage and delivers each outcome to the matching request:
//   - a fatal read error (no magic) ends the loop;
//   - frames with nothing to deliver are skipped;
//   - a per-request error is sent on the request's channel;
//   - success is signalled by closing the request's channel.
func (b *BaseClient) processMessageManager() {
	defer fmt.Println("processMessageManager Close")

	for {
		magic, respChan, err := b.processMessage()
		switch {
		case magic == "" && err != nil:
			return
		case magic == "" || respChan == nil:
			continue
		case err != nil:
			respChan <- err
		default:
			close(respChan)
		}
	}
}

// processMessage reads one frame from the connection. If the frame answers a
// registered request, it is decoded into that request's response.
//
// processMessageManager interprets the results as follows:
//   - magic == "" && err != nil: reading from the connection failed. b.err is
//     set and b.close is closed, and the reader loop must stop.
//   - magic == "" && err == nil: nothing to deliver. Either the frame was a
//     heartbeat, or its magic number is no longer registered.
//   - magic != "": the frame answers the request registered under magic, and
//     respChan is that request's channel. err is the per-request outcome: nil
//     on success, otherwise a remote "RespError" or a decode failure.
//
// Decode failures are reported to the waiting request rather than returned
// with an empty magic. An empty magic would stop the reader loop for good
// while the connection still looked healthy.
func (b *BaseClient) processMessage() (magic string, respChan chan error, err error) {
	b.conn.SetReadDeadline(time.Now().Add(b.options.readTimeout))

	msg, err := protocol.NewProtocol().IODecode(b.conn)
	if err != nil {
		b.err = err
		close(b.close)
		return "", nil, err
	}

	if msg.Header.RespType == byte(protocol.HeartBeat) {
		if b.options.Trace {
			log.Println("is HeartBeat")
		}
		return "", nil, nil
	}

	b.respInterRM.RLock()
	message, ex := b.respInterMap[msg.MagicNumber]
	b.respInterRM.RUnlock()
	if !ex { // not registered: the request is no longer waiting
		if b.options.Trace {
			log.Println("Not Ex", msg.MagicNumber)
		}
		return "", nil, nil
	}

	// From here on the frame belongs to a known request, and every failure is
	// reported to it.
	magic, respChan = msg.MagicNumber, message.respChan

	comp, ex := codes.CompressorManager.Get(codes.CompressorType(msg.Header.CompressorType))
	if !ex {
		return magic, respChan, pkg.ErrCompressor404
	}

	// 1. Decompress.
	if msg.MetaData, err = comp.Unzip(msg.MetaData); err != nil {
		return magic, respChan, err
	}
	if msg.Payload, err = comp.Unzip(msg.Payload); err != nil {
		return magic, respChan, err
	}

	// 2. Decrypt. Empty sections are left empty instead of being fed to the
	// cipher. A non-empty section that fails to decrypt is an error; it is
	// not silently replaced with nothing.
	if len(msg.MetaData) > 0 {
		if msg.MetaData, err = cryptology.AESDecrypt(b.aesKey, msg.MetaData); err != nil {
			return magic, respChan, err
		}
	}
	if len(msg.Payload) > 0 {
		if msg.Payload, err = cryptology.AESDecrypt(b.aesKey, msg.Payload); err != nil {
			return magic, respChan, err
		}
	}

	// 3. Decode the metadata and surface a remote error, if any.
	mtData := make(map[string]string)
	if err = b.serialization.Decode(msg.MetaData, &mtData); err != nil {
		return magic, respChan, err
	}
	message.ctx.SetMetaData(mtData)

	if remoteErr := message.ctx.Value("RespError"); remoteErr != "" {
		return magic, respChan, errors.New(remoteErr)
	}

	// 4. Decode the payload into the caller's response value.
	return magic, respChan, b.serialization.Decode(msg.Payload, message.response)
}

服务调用

// call encodes and sends one request frame for serviceMethod. It registers the
// request so the reader goroutine can route the response back to it.
//
// Pipeline: the metadata from ctx and the request body are serialized, then
// AES-encrypted with the connection key. They are compressed when the
// encrypted metadata length lies strictly between compressorMin and
// compressorMax; otherwise the frame is marked codes.RawData. The pending
// request (response, ctx, respChan) is stored in respInterMap under the
// frame's magic number before the frame is written.
//
// On success call returns that magic number. If the write fails, the
// registration is removed again, because no response can ever be routed to
// it and the caller never learns the magic. In that case b.err is set and the
// error is returned.
func (b *BaseClient) call(ctx *light.Context, serviceMethod string, request interface{}, response interface{}, respChan chan error) (magic string, err error) {
	metaData := ctx.GetMetaData()

	// 1. Build the request.
	// 1.1 Serialize.
	serviceNameByte := []byte(b.serverName)
	serviceMethodByte := []byte(serviceMethod)
	var metaDataBytes []byte
	var requestBytes []byte
	metaDataBytes, err = b.serialization.Encode(metaData)
	if err != nil {
		return "", err
	}
	requestBytes, err = b.serialization.Encode(request)
	if err != nil {
		return "", err
	}

	// 1.2 Encrypt.
	metaDataBytes, err = cryptology.AESEncrypt(b.aesKey, metaDataBytes)
	if err != nil {
		return "", err
	}
	requestBytes, err = cryptology.AESEncrypt(b.aesKey, requestBytes)
	if err != nil {
		return "", err
	}

	// 1.3 Compress. The frame header carries a single compressor type for both
	// sections, so both are compressed or neither is.
	// NOTE(review): the size check looks only at the metadata, not the request
	// body — confirm this is intended.
	compressorType := b.options.compressorType
	if len(metaDataBytes) > compressorMin && len(metaDataBytes) < compressorMax {
		metaDataBytes, err = b.compressor.Zip(metaDataBytes)
		if err != nil {
			return "", err
		}
		requestBytes, err = b.compressor.Zip(requestBytes)
		if err != nil {
			return "", err
		}
	} else {
		compressorType = codes.RawData
	}

	// 1.4 Frame the message; EncodeMessage assigns the magic number.
	magic, message, err := protocol.EncodeMessage("", serviceNameByte, serviceMethodByte, metaDataBytes, byte(protocol.Request), byte(compressorType), byte(b.options.serializationType), requestBytes)
	if err != nil {
		return "", err
	}

	// 2. Arm the deadline: use the per-call timeout from ctx if one is set,
	// otherwise the configured writeTimeout.
	// NOTE(review): SetDeadline also moves the read deadline of the shared
	// connection used by the reader goroutine — confirm this is intended.
	if b.options.writeTimeout > 0 {
		timeout := ctx.GetTimeout()
		if timeout <= 0 {
			timeout = b.options.writeTimeout
		}
		deadline := time.Now().Add(timeout)
		b.conn.SetDeadline(deadline)
		b.conn.SetWriteDeadline(deadline)
	}

	// 3. Register the pending request before writing, so a fast response
	// always finds it.
	b.respInterRM.Lock()
	b.respInterMap[magic] = &respMessage{
		response: response,
		ctx:      ctx,
		respChan: respChan,
	}
	b.respInterRM.Unlock()

	// 4. Send. Writes from concurrent calls and the heartbeat share the
	// connection, so they are serialized.
	b.writeMu.Lock()
	_, err = b.conn.Write(message)
	b.writeMu.Unlock()
	if err != nil {
		// The request never fully reached the server, so drop its
		// registration; otherwise the entry would leak forever.
		b.respInterRM.Lock()
		delete(b.respInterMap, magic)
		b.respInterRM.Unlock()

		if b.options.Trace {
			log.Println(err)
		}
		b.err = err
		return "", errors.WithStack(err)
	}

	return magic, nil
}
阅读 168

手把手教你用GO编写RPC框架
手把手教你用GO编写RPC框架

坎坷之路,终抵群星

43 声望
2 粉丝
0 条评论
你知道吗?

坎坷之路,终抵群星

43 声望
2 粉丝
宣传栏