tcp handler

tcpServer.Handle 负责处理每一个 TCP 连接:

// tcpServer accepts raw TCP connections for nsqd and hands each one to the
// protocol implementation selected by the client's 4-byte protocol magic.
type tcpServer struct {
    ctx *context
}

// Handle serves a single client connection.
//
// The client must first send a 4-byte protocol magic; only "  V2" is
// accepted. On success the connection is handed to protocolV2.IOLoop, which
// closes it when the loop ends. If the magic cannot be read or is not
// recognised, the connection is closed here.
func (p *tcpServer) Handle(clientConn net.Conn) {
    p.ctx.nsqd.logf("TCP: new client(%s)", clientConn.RemoteAddr())

    // The client should initialize itself by sending a 4 byte sequence indicating
    // the version of the protocol that it intends to communicate, this will allow us
    // to gracefully upgrade the protocol away from text/line oriented to whatever...
    // Every connection starts with this magic; as the switch below shows, only
    // V2 is currently supported.
    buf := make([]byte, 4)
    if _, err := io.ReadFull(clientConn, buf); err != nil {
        p.ctx.nsqd.logf("ERROR: failed to read protocol version - %s", err)
        // Close the connection on this early exit too, as the bad-magic
        // branch below does, so it is not leaked.
        clientConn.Close()
        return
    }
    protocolMagic := string(buf)

    p.ctx.nsqd.logf("CLIENT(%s): desired protocol magic '%s'",
        clientConn.RemoteAddr(), protocolMagic)

    var prot protocol.Protocol
    switch protocolMagic {
    case "  V2":
        prot = &protocolV2{ctx: p.ctx}
    default:
        protocol.SendFramedResponse(clientConn, frameTypeError, []byte("E_BAD_PROTOCOL"))
        clientConn.Close()
        p.ctx.nsqd.logf("ERROR: client(%s) bad protocol magic '%s'",
            clientConn.RemoteAddr(), protocolMagic)
        return
    }

    // IOLoop reads and executes the client's commands until it disconnects
    // or a fatal error occurs.
    if err := prot.IOLoop(clientConn); err != nil {
        p.ctx.nsqd.logf("ERROR: client(%s) - %s", clientConn.RemoteAddr(), err)
    }
}

protocolV2.IOLoop 对 TCP 连接的处理:

// IOLoop runs the V2 protocol on conn until the client disconnects or a fatal
// error occurs.
//
// It starts the client's messagePump goroutine. It then reads
// newline-terminated commands (an optional trailing '\r' is stripped),
// executes each one via p.Exec and writes back any non-nil response. Exec
// errors are reported to the client as error frames, and the connection is
// only dropped for *protocol.FatalClientErr or when that report cannot be sent.
// On exit it closes conn, closes client.ExitChan (which stops messagePump) and
// removes the client from its channel, if any. A clean EOF from the client
// yields a nil error.
func (p *protocolV2) IOLoop(conn net.Conn) error {
    var err error
    var line []byte
    var zeroTime time.Time

    clientID := atomic.AddInt64(&p.ctx.nsqd.clientIDSequence, 1)
    client := newClientV2(clientID, conn, p.ctx)

    // synchronize the startup of messagePump in order
    // to guarantee that it gets a chance to initialize
    // goroutine local state derived from client attributes
    // and avoid a potential race with IDENTIFY (where a client
    // could have changed or disabled said attributes)
    messagePumpStartedChan := make(chan bool)
    go p.messagePump(client, messagePumpStartedChan)
    <-messagePumpStartedChan

    for {
        // With heartbeats enabled, a client that sends nothing for two
        // heartbeat intervals hits the read deadline and is disconnected;
        // otherwise reads block indefinitely.
        if client.HeartbeatInterval > 0 {
            client.SetReadDeadline(time.Now().Add(client.HeartbeatInterval * 2))
        } else {
            client.SetReadDeadline(zeroTime)
        }

        // ReadSlice does not allocate new space for the data each request
        // ie. the returned slice is only valid until the next call to it
        line, err = client.Reader.ReadSlice('\n')
        if err != nil {
            if err == io.EOF {
                err = nil
            } else {
                err = fmt.Errorf("failed to read command - %s", err)
            }
            break
        }

        // trim the '\n'
        line = line[:len(line)-1]
        // optionally trim the '\r'
        if len(line) > 0 && line[len(line)-1] == '\r' {
            line = line[:len(line)-1]
        }
        // Command parameters are separated by separatorBytes (a space),
        // e.g. "SUB <topic> <channel>".
        params := bytes.Split(line, separatorBytes)

        if p.ctx.nsqd.getOpts().Verbose {
            p.ctx.nsqd.logf("PROTOCOL(V2): [%s] %s", client, params)
        }

        var response []byte
        // Exec runs the command named in params (SUB, PUB, FIN, ...).
        response, err = p.Exec(client, params)
        if err != nil {
            ctx := ""
            // Exec's errors are expected to implement protocol.ChildErr. A
            // checked assertion keeps an unexpected error type from panicking
            // this goroutine (an unrecovered panic would crash the process).
            if childErr, ok := err.(protocol.ChildErr); ok {
                if parentErr := childErr.Parent(); parentErr != nil {
                    ctx = " - " + parentErr.Error()
                }
            }
            p.ctx.nsqd.logf("ERROR: [%s] - %s%s", client, err, ctx)

            sendErr := p.Send(client, frameTypeError, []byte(err.Error()))
            if sendErr != nil {
                p.ctx.nsqd.logf("ERROR: [%s] - %s%s", client, sendErr, ctx)
                break
            }

            // errors of type FatalClientErr should forceably close the connection
            if _, ok := err.(*protocol.FatalClientErr); ok {
                break
            }
            continue
        }

        if response != nil {
            err = p.Send(client, frameTypeResponse, response)
            if err != nil {
                err = fmt.Errorf("failed to send response - %s", err)
                break
            }
        }
    }

    // Reached on client EOF, a read/send failure or a fatal client error.
    p.ctx.nsqd.logf("PROTOCOL(V2): [%s] exiting ioloop", client)
    conn.Close()
    // messagePump exits when it sees ExitChan closed.
    close(client.ExitChan)
    if client.Channel != nil {
        client.Channel.RemoveClient(client.ID)
    }

    return err
}

参照官网的 consumer 示例写了一个简单的 client。这个 client 的功能是订阅一个 topic 和 channel,当有 producer 向这个 topic 发布消息时,把经由该 channel 投递过来的消息打印在屏幕上。希望通过交互的过程来进一步理解 server 端的 nsqd,代码如下:

package main

import (
    "fmt"

    nsq "github.com/nsqio/go-nsq"
)

// main subscribes to topic "nsq" on channel "consumer" of the nsqd at
// 127.0.0.1:4150 and prints the body of every message it receives.
// It blocks until the consumer stops.
func main() {
    cfg := nsq.NewConfig()

    consumer, err := nsq.NewConsumer("nsq", "consumer", cfg)
    if err != nil {
        fmt.Println("Failed to init consumer: ", err.Error())
        return
    }

    // Print each message, then FIN it so nsqd stops tracking it as in flight.
    printAndFinish := func(msg *nsq.Message) error {
        fmt.Println("received message: ", string(msg.Body))
        msg.Finish()
        return nil
    }
    consumer.AddHandler(nsq.HandlerFunc(printAndFinish))

    if err = consumer.ConnectToNSQD("127.0.0.1:4150"); err != nil {
        fmt.Println("Failed to connect to nsqd: ", err.Error())
        return
    }

    // Wait for the consumer to be stopped.
    <-consumer.StopChan
}

在ConnectToNSQD 过程中,有两步与server 端的交互。第一步:

// step 1: Connect dials nsqd and sends the protocol magic (excerpt below)
resp, err := conn.Connect()

在Connect 部分:

    // open the TCP connection to nsqd
    conn, err := dialer.Dial("tcp", c.addr)
    if err != nil {
        return nil, err
    }
    // keep the TCP conn and use it as both the reader and the writer
    c.conn = conn.(*net.TCPConn)
    c.r = conn
    c.w = conn

    // the first bytes sent on every connection are the protocol magic,
    // which nsqd's tcp handler reads as a 4-byte version sequence
    _, err = c.Write(MagicV2)

TCP 连接建立后,客户端会先向 server 端发送协议版本号(MagicV2)。正如我们在 tcp handler 中看到的那样,server 在每个连接建立后都会先读取这 4 个字节的协议版本号。
第二步交互:

    // step 2: send "SUB <topic> <channel>" to nsqd
    cmd := Subscribe(r.topic, r.channel)
    err = conn.WriteCommand(cmd)

客户端会向服务器端发送"SUB topic channel" 这样一条命令。
下面来看server 端是如何处理这个命令的
在方法 `func (p *protocolV2) SUB(client *clientV2, params [][]byte) ([]byte, error)` 中,略过各种各样的检查后,核心代码如下:

    // resolve the topic and the channel within it, then register this client on the channel
    topic := p.ctx.nsqd.GetTopic(topicName)
    channel := topic.GetChannel(channelName)
    channel.AddClient(client.ID, client)

    // mark the client as subscribed and remember its channel
    atomic.StoreInt32(&client.State, stateSubscribed)
    client.Channel = channel
    // update message pump
    // (messagePump receives this on subEventChan and starts reading from the channel)
    client.SubEventChan <- channel

除了给 channel 添加了一个 client、给 client 分配了一个 channel 以外,还更新了 message pump。好了,是时候来看看这个 message pump 都做了什么了。

在 IOLoop 中,有这样的代码:

    messagePumpStartedChan := make(chan bool)
    go p.messagePump(client, messagePumpStartedChan)
    <-messagePumpStartedChan

进入到 protocol_v2 的 messagePump 中:

func (p *protocolV2) messagePump(client *clientV2, startedChan chan bool) {
    var err error
    var buf bytes.Buffer
    var memoryMsgChan chan *Message
    var backendMsgChan chan []byte
    var subChannel *Channel
    // NOTE: `flusherChan` is used to bound message latency for
    // the pathological case of a channel on a low volume topic
    // with >1 clients having >1 RDY counts
    var flusherChan <-chan time.Time
    var sampleRate int32

    subEventChan := client.SubEventChan
    identifyEventChan := client.IdentifyEventChan
    outputBufferTicker := time.NewTicker(client.OutputBufferTimeout)
    heartbeatTicker := time.NewTicker(client.HeartbeatInterval)
    heartbeatChan := heartbeatTicker.C
    msgTimeout := client.MsgTimeout

    // v2 opportunistically buffers data to clients to reduce write system calls
    // we force flush in two cases:
    //    1. when the client is not ready to receive messages
    //    2. we're buffered and the channel has nothing left to send us
    //       (ie. we would block in this loop anyway)
    //
    flushed := true

    // signal to the goroutine that started the messagePump
    // that we've started up
    close(startedChan)

    for {
        if subChannel == nil || !client.IsReadyForMessages() {
            // the client is not ready to receive messages...
            memoryMsgChan = nil
            backendMsgChan = nil
            flusherChan = nil
            // force flush
            client.writeLock.Lock()
            err = client.Flush()
            client.writeLock.Unlock()
            if err != nil {
                goto exit
            }
            flushed = true
        // ztd: 一旦subChannel 不是nil,就将memoryMsgChan 和            
        //backendMsgChan 赋值
        } else if flushed {
            // last iteration we flushed...
            // do not select on the flusher ticker channel
            memoryMsgChan = subChannel.memoryMsgChan
            backendMsgChan = subChannel.backend.ReadChan()
            flusherChan = nil
        } else {
            // we're buffered (if there isn't any more data we should flush)...
            // select on the flusher ticker channel, too
            memoryMsgChan = subChannel.memoryMsgChan
            backendMsgChan = subChannel.backend.ReadChan()
            flusherChan = outputBufferTicker.C
        }

        select {
        case <-flusherChan:
            // if this case wins, we're either starved
            // or we won the race between other channels...
            // in either case, force flush
            client.writeLock.Lock()
            err = client.Flush()
            client.writeLock.Unlock()
            if err != nil {
                goto exit
            }
            flushed = true
        case <-client.ReadyStateChan:
        // ztd: 在SUB 函数里client.SubEventChan <- channel 就是
        // 给这个subChannel 赋了值
        case subChannel = <-subEventChan:
            // you can't SUB anymore
            subEventChan = nil
        case identifyData := <-identifyEventChan:
            // you can't IDENTIFY anymore
            identifyEventChan = nil

            outputBufferTicker.Stop()
            if identifyData.OutputBufferTimeout > 0 {
                outputBufferTicker = time.NewTicker(identifyData.OutputBufferTimeout)
            }

            heartbeatTicker.Stop()
            heartbeatChan = nil
            if identifyData.HeartbeatInterval > 0 {
                heartbeatTicker = time.NewTicker(identifyData.HeartbeatInterval)
                heartbeatChan = heartbeatTicker.C
            }

            if identifyData.SampleRate > 0 {
                sampleRate = identifyData.SampleRate
            }

            msgTimeout = identifyData.MsgTimeout
        case <-heartbeatChan:
            err = p.Send(client, frameTypeResponse, heartbeatBytes)
            if err != nil {
                goto exit
            }
        case b := <-backendMsgChan:
            if sampleRate > 0 && rand.Int31n(100) > sampleRate {
                continue
            }

            msg, err := decodeMessage(b)
            if err != nil {
                p.ctx.nsqd.logf("ERROR: failed to decode message - %s", err)
                continue
            }
            msg.Attempts++

            subChannel.StartInFlightTimeout(msg, client.ID, msgTimeout)
            client.SendingMessage()
            err = p.SendMessage(client, msg, &buf)
            if err != nil {
                goto exit
            }
            flushed = false
        case msg := <-memoryMsgChan:
            if sampleRate > 0 && rand.Int31n(100) > sampleRate {
                continue
            }
            msg.Attempts++

            subChannel.StartInFlightTimeout(msg, client.ID, msgTimeout)
            client.SendingMessage()
            err = p.SendMessage(client, msg, &buf)
            if err != nil {
                goto exit
            }
            flushed = false
        case <-client.ExitChan:
            goto exit
        }
    }

exit:
    p.ctx.nsqd.logf("PROTOCOL(V2): [%s] exiting messagePump", client)
    heartbeatTicker.Stop()
    outputBufferTicker.Stop()
    if err != nil {
        p.ctx.nsqd.logf("PROTOCOL(V2): [%s] messagePump error - %s", client, err)
    }
}```
在这段代码里,一旦有 client 订阅了一个 channel(并且处于可接收消息的状态),messagePump 就开始监听该 channel 的 memoryMsgChan 和 backend 队列,等待 producer 发送消息过来。下面看一下 PUB 的过程:
// resolve the topic, wrap the body in a new Message with a topic-generated ID
// and hand it to the topic; a failure is returned to the publisher as a
// fatal E_PUB_FAILED client error
topic := p.ctx.nsqd.GetTopic(topicName)
msg := NewMessage(topic.GenerateID(), messageBody)
err = topic.PutMessage(msg)
if err != nil {
    return nil, protocol.NewFatalClientErr(err, "E_PUB_FAILED", "PUB failed "+err.Error())
}

在做了一系列检查之后,向topic put 了一条message。 每次new 一个topic 的时候,会启动一个topic的 messagePump:
// Topic constructor
// NewTopic creates the topic named topicName and starts its messagePump.
//
// Topics whose name ends in "#ephemeral" are marked ephemeral and get a dummy
// backend queue. All others are backed by a diskqueue in the configured
// DataPath. deleteCallback is stored on the topic (presumably invoked when the
// topic is deleted; confirm at the call site). nsqd is notified of the new
// topic before it is returned.
func NewTopic(topicName string, ctx *context, deleteCallback func(*Topic)) *Topic {
    // Fetch the options once so every setting below comes from the same
    // options value instead of eight separate getOpts lookups.
    opts := ctx.nsqd.getOpts()

    t := &Topic{
        name:              topicName,
        channelMap:        make(map[string]*Channel),
        memoryMsgChan:     make(chan *Message, opts.MemQueueSize),
        exitChan:          make(chan int),
        channelUpdateChan: make(chan int),
        ctx:               ctx,
        pauseChan:         make(chan bool),
        deleteCallback:    deleteCallback,
        idFactory:         NewGUIDFactory(opts.ID),
    }

    if strings.HasSuffix(topicName, "#ephemeral") {
        t.ephemeral = true
        t.backend = newDummyBackendQueue()
    } else {
        t.backend = diskqueue.New(topicName,
            opts.DataPath,
            opts.MaxBytesPerFile,
            int32(minValidMsgLength),
            int32(opts.MaxMsgSize)+minValidMsgLength,
            opts.SyncEvery,
            opts.SyncTimeout,
            opts.Logger)
    }

    // messagePump fans messages out from the topic to its channels.
    t.waitGroup.Wrap(func() { t.messagePump() })

    t.ctx.nsqd.Notify(t)

    return t
}

在messagePump 中,会监听topic 的memoryMsgChan:
for {
    select {
    case msg = <-memoryMsgChan:

而每次收到一个消息,会向topic 下面所有的channel 进行广播:
    // Fan the message out to every channel of the topic. The first channel
    // reuses msg; every further channel gets its own copy with the same ID,
    // body, timestamp and deferred value. Deferred messages go through
    // PutMessageDeferred. Put failures are logged, not returned.
    for i, channel := range chans {
        chanMsg := msg
        // copy the message because each channel
        // needs a unique instance but...
        // fastpath to avoid copy if its the first channel
        // (the topic already created the first copy)
        if i > 0 {
            chanMsg = NewMessage(msg.ID, msg.Body)
            chanMsg.Timestamp = msg.Timestamp
            chanMsg.deferred = msg.deferred
        }
        if chanMsg.deferred != 0 {
            channel.PutMessageDeferred(chanMsg, chanMsg.deferred)
            continue
        }
        err := channel.PutMessage(chanMsg)
        if err != nil {
            t.ctx.nsqd.logf(
                "TOPIC(%s) ERROR: failed to put msg(%s) to channel(%s) - %s",
                t.name, msg.ID, channel.name, err)
        }
    }
如果对nsq 有了解的话,会知道每一个topic 会将一个msg 广播给所有的channel,这个逻辑的实现就在这块。
channel 的 PutMessage 最终调用 put:
func (c *Channel) put(m *Message) error {
    select {
    case c.memoryMsgChan <- m:
put 将消息塞入了 channel 的 memoryMsgChan 中。这时,代码又回到了 protocol_v2 的 messagePump 中:
    case msg := <-memoryMsgChan:
        if sampleRate > 0 && rand.Int31n(100) > sampleRate {
            continue
        }
        msg.Attempts++
        // For reliability, sending a message to the subscriber does not delete
        // it: it is tracked as in flight with a timeout (msgTimeout) instead.
        subChannel.StartInFlightTimeout(msg, client.ID, msgTimeout)
        
        client.SendingMessage()
        // write the message to the client
        err = p.SendMessage(client, msg, &buf)
        if err != nil {
            goto exit
        }
        flushed = false
在客户端的代码中,我们注意有一行: `m.Finish()`,这行代码是告诉server 端我已经消费完这条信息了,可以丢弃了。这行代码向server 端发送一条`FIN` 命令。在server 端:
 // FinishMessage successfully discards an in-flight message.
// It takes the message identified by id (delivered to clientID) out of the
// in-flight set via popInFlightMessage and removes it from inFlightPQ. If
// e2eProcessingLatencyStream is configured (non-nil), it then records the
// message's timestamp there. Any error from popInFlightMessage is returned
// unchanged, and inFlightPQ is left untouched in that case.
func (c *Channel) FinishMessage(clientID int64, id MessageID) error {
    msg, err := c.popInFlightMessage(clientID, id)
    if err == nil {
        c.removeFromInFlightPQ(msg)
        if stream := c.e2eProcessingLatencyStream; stream != nil {
            stream.Insert(msg.Timestamp)
        }
    }
    return err
}
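发送消息时(StartInFlightTimeout),作者将这个 msg 加到了两个 queue:一个是 in-flight 消息表(数据结构其实是个 map),另一个是 inFlightPQ。FinishMessage 正好把消息从这两处移除:popInFlightMessage 从 map 中删除,removeFromInFlightPQ 从优先队列中删除。后面会讲到 inFlightPQ 的作用。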
如果没有及时Finish 消息,怎么处理timeout 的消息呢?在NSQD的 Main 函数中,启动了一个queueScanLoop:
// start queueScanLoop through nsqd's waitGroup (presumably in its own goroutine)
n.waitGroup.Wrap(func() { n.queueScanLoop() })
在这个loop 中,设置了一个ticker,每过一段时间,就会执行resizePool:
func (n *NSQD) queueScanLoop() {
workCh := make(chan *Channel, n.getOpts().QueueScanSelectionCount)
responseCh := make(chan bool, n.getOpts().QueueScanSelectionCount)
closeCh := make(chan int)

workTicker := time.NewTicker(n.getOpts().QueueScanInterval)
refreshTicker := time.NewTicker(n.getOpts().QueueScanRefreshInterval)

channels := n.channels()
n.resizePool(len(channels), workCh, responseCh, closeCh)

for {
    select {
    case <-workTicker.C:
        if len(channels) == 0 {
            continue
        }
    case <-refreshTicker.C:
        channels = n.channels()
        n.resizePool(len(channels), workCh, responseCh, closeCh)
        continue
    case <-n.exitChan:
        goto exit
    }
resizePool 负责调整扫描 worker 的数量,这些 worker 会对 channel 执行这么个函数:
// processInFlightQueue requeues every in-flight message whose priority
// (timeout) is <= t.
//
// Expired messages are taken from inFlightPQ in priority order. Each one is
// removed from the in-flight table, counted in timeoutCount and reported to
// its client if that client is still registered on this channel. It is then
// put back on the channel for redelivery. The function reports whether at
// least one expired message was found. It returns false immediately if the
// channel is exiting.
func (c *Channel) processInFlightQueue(t int64) bool {
    c.exitMutex.RLock()
    defer c.exitMutex.RUnlock()

    if c.Exiting() {
        return false
    }

    found := false
    for {
        c.inFlightMutex.Lock()
        expired, _ := c.inFlightPQ.PeekAndShift(t)
        c.inFlightMutex.Unlock()

        if expired == nil {
            return found
        }
        found = true

        // The message may already have been removed elsewhere (e.g. by
        // FinishMessage); stop scanning in that case.
        if _, err := c.popInFlightMessage(expired.clientID, expired.ID); err != nil {
            return found
        }
        atomic.AddUint64(&c.timeoutCount, 1)

        c.RLock()
        owner, registered := c.clients[expired.clientID]
        c.RUnlock()
        if registered {
            owner.TimedOutMessage()
        }

        c.put(expired)
    }
}

这个函数不断从 inFlightPQ 中取出一个过期的消息,把它从 in-flight 消息表中删除,再通过 c.put(msg) 重新放回 channel,等待再次投递。其中用到的 PeekAndShift:
// PeekAndShift pops and returns the head of the heap if its priority is <= max.
// Otherwise it leaves the heap untouched and returns nil together with
// head.pri - max, i.e. how far the head still is from max. An empty heap
// yields (nil, 0).
func (pq *inFlightPqueue) PeekAndShift(max int64) (*Message, int64) {
    if len(*pq) == 0 {
        return nil, 0
    }

    head := (*pq)[0]
    if head.pri <= max {
        pq.Pop()
        return head, 0
    }
    return nil, head.pri - max
}

inFlightPqueue 的数据结构是一个按 pri(过期时间)排序的最小堆,每次 push 一条新的消息时,都会通过 up 把它向上调整到合适的位置:
// up restores the min-heap property after element j has been placed. It
// swaps j with its parent while j's pri is strictly smaller than the
// parent's, stopping at the root.
func (pq *inFlightPqueue) up(j int) {
    for j > 0 {
        parent := (j - 1) / 2
        if (*pq)[parent].pri <= (*pq)[j].pri {
            break
        }
        pq.Swap(parent, j)
        j = parent
    }
}
所以,堆顶是过期时间(pri)最早的消息。如果连堆顶消息都还没有过期,那就没有任何过期的消息;如果堆顶已过期,就把它 pop 出来。这样在 for 循环中不断把过期消息 pop 出来,直到没有过期的消息为止。

buptztd
91 声望23 粉丝

Gopher