websocket: cannot call recv while another coroutine is already waiting for the next message?

题目描述

多个协程读取websocket 返回的数据会报错,错误:cannot call recv while another coroutine is already waiting for the next message

题目来源及自己的思路

  1. client.py 里同时实现主动请求及被动接收广播的数据
  2. ws.py 里,在异步协程里收发请求

基本确定 是ws.py 有两个地方的协程都在读取websocket的数据导致。

 async def receive_message(self, callback=None):
        try:
            async with self.conn as conn:
                # 这里在读取数据
                async for message in conn:
                    if callback:
                        callback(message)
                    ....
    async def coroutine_make_request(self, request_data):
        request_data = json.dumps(request_data)

        async with self.conn as conn:
            await asyncio.wait_for(
                conn.send(request_data),
                self.timeout
            )
            # 以及这里也在读取websockets数据
            resp = await asyncio.wait_for(
                conn.recv(),
                self.timeout
            )
            return json.loads(resp)

两个地方都在读取websocket数据,导致报错。
实际看到报错:
image.png

相关代码

总共3个文件,拷贝即可运行。

client.py

import time

from ws import WSProvider


def receive(message):
    print("receive message:", message)


def main():
    client = WSProvider(endpoint="ws://localhost:9801", kwargs={})

    # Make request at first time
    resp = client.make_request({
        "req": "method_1"
    })
    print("1st:", resp)

    # Continuously receive message
    client.persistent_receive_message(receive)

    # Make request at second time
    # Likely to result in an error: cannot call recv while another coroutine is already waiting for the next message
    resp = client.make_request({
        "req": "method_2"
    })
    print("2nd:", resp)


if __name__ == "__main__":
    main()
    time.sleep(15)

server.py

import time
import websockets
import asyncio
import json
import threading

clients = set()


async def set_interval(interval: int, func):
    async def run():
        msg = time.asctime()
        await func(msg)
        threading.Timer(interval, schedule).start()

    def schedule():
        asyncio.run(run())
    threading.Timer(interval, schedule).start()


async def echo(ws):
    clients.add(ws)
    try:
        async for message in ws:
            msg = json.loads(message)
            msg["current_time"] = time.asctime()
            # send message back
            await ws.send(json.dumps(msg))
    except websockets.ConnectionClosed:
        pass
    finally:
        clients.remove(ws)


async def broadcast(msg: str):
    try:
        print(f"try send message to client {msg}, {len(clients)}")
        for client in clients:
            data = json.dumps({
                "current_time": msg
            })
            print(f"send message to client {data}")
            await client.send(data)
    except websockets.ConnectionClosed:
        pass


async def main():
    async with websockets.serve(echo, "localhost", "9801"):
        print("websocket server wake up")
        # run forever, broadcast time to every client per 5 seconds
        await set_interval(5, broadcast)
        await asyncio.Future()

asyncio.run(main())

ws.py

import asyncio
import json

import websockets

from typing import Type, Callable
from types import TracebackType
from threading import Thread


def start_event_loop(loop: asyncio.AbstractEventLoop) -> None:
    asyncio.set_event_loop(loop)
    loop.run_forever()
    loop.close()


def get_thread_loop() -> asyncio.AbstractEventLoop:
    new_loop = asyncio.new_event_loop()
    thread = Thread(target=start_event_loop, args=(new_loop,), daemon=True)
    thread.start()
    return new_loop


class PersistentConnection:
    def __init__(self, endpoint: str, kwargs) -> None:
        self.ws: websockets.WebSocketClientProtocol = None
        self.endpoint = endpoint
        self.kwargs = kwargs

    async def __aenter__(self):
        if self.ws is None:
            self.ws = await websockets.connect(uri=self.endpoint, **self.kwargs)
        return self.ws

    async def __aexit__(self, exec_type: Type[BaseException], exec_val: BaseException,  exec_traceback: TracebackType):
        if exec_val is not None:
            try:
                await self.ws.close()
            except websockets.ConnectionClosed:
                pass
            finally:
                self.ws = None


class WSProvider:
    _loop: asyncio.AbstractEventLoop = get_thread_loop()

    def __init__(self, *, endpoint: str, timeout: int = None, kwargs):
        self.endpoint = endpoint
        self.timeout = timeout
        self.kwargs = kwargs or {}

        if WSProvider._loop is None:
            WSProvider._loop = get_thread_loop()

        self.conn = PersistentConnection(self.endpoint, self.kwargs)

    # websocket 服务可能 持续广播消息过来
    def persistent_receive_message(self, callback: Callable):
        loop = WSProvider._loop
        if loop and loop.is_running():
            loop.create_task(self.receive_message(callback))
        else:
            asyncio.run(self.receive_message(callback))

    async def receive_message(self, callback=None):
        try:
            async with self.conn as conn:
                async for message in conn:
                    if callback:
                        callback(message)
        except websockets.ConnectionClosed:
            pass

    def make_request(self, request_data):
        future = asyncio.run_coroutine_threadsafe(
            self.coroutine_make_request(request_data),
            WSProvider._loop
        )
        data = future.result()
        return data

    async def coroutine_make_request(self, request_data):
        request_data = json.dumps(request_data)

        async with self.conn as conn:
            await asyncio.wait_for(
                conn.send(request_data),
                self.timeout
            )
            resp = await asyncio.wait_for(
                conn.recv(),
                self.timeout
            )
            return json.loads(resp)

你期待的结果是什么?实际看到的错误信息又是什么?

期望大佬给改良思路,改良此处代码实现。比如 coroutine_make_request 只返回asyncio.Future, 由 receive_message 去修改该 Future的状态。本人多次尝试都没成功。

async def coroutine_make_request(self, request_data):
    request_data = json.dumps(request_data)

    async with self.conn as conn:
        await asyncio.wait_for(
            conn.send(request_data),
            self.timeout
        )
        # 改良此处逻辑,将数据读取统一挪到 receive_message 内
        # resp = await asyncio.wait_for(
        #     conn.recv(),
        #     self.timeout
        # )
        # return json.loads(resp)

==========================

2023.8.3
我尝试改良WSProvider.py

class WSProvider:
    _loop: asyncio.AbstractEventLoop = get_thread_loop()

    def __init__(self, *, endpoint: str, timeout: int = None, kwargs):
        #...
        self.waiters: Dict[str, asyncio.Future] = {}
    # ...
    async def receive_message(self, callback=None):
        try:
            async with self.conn as conn:
                async for message in conn:
                    data = json.loads(message)
                    uid = data.get("uid")
                    # 这里去修改 future的状态,让 coroutine_make_request 能走过await
                    if uid and self.waiters.get(uid):
                        self.waiters.get(uid).set_result(message)
                        del self.waiters[uid]
                    if callback:
                        callback(message)
        except websockets.ConnectionClosed:
            pass

    def make_request(self, request_data):
        #...

    async def coroutine_make_request(self, request_data):
        uid = uuid4().hex
        request_data["uid"] = uid
        request_data = json.dumps(request_data)
        async with self.conn as conn:
            await asyncio.wait_for(
                conn.send(request_data),
                self.timeout
            )
            # resp = await asyncio.wait_for(
            #     conn.recv(),
            #     self.timeout
            # )
            # return json.loads(resp)

            # 我尝试改成这样,但是运行发现 await future 直接就把整个协程阻塞,
            # 其他eventloop的代码根本没办法继续运行,最后代码就停在这里
            future = asyncio.Future()
            self.waiters[uid] = future
            await future
            return future.result()

阅读 3.4k
1 个回答
import asyncio
import websockets

class WSProvider:
    def __init__(self, uri):
        self.uri = uri
        self.queue = asyncio.Queue()

    async def producer(self):
        async with websockets.connect(self.uri) as websocket:
            while True:
                message = await websocket.recv()
                await self.queue.put(message)

    async def consumer(self):
        while True:
            message = await self.queue.get()
            print(f"Received message: {message}")

    async def run(self):
        producer_coro = self.producer()
        consumer_coro = self.consumer()
        await asyncio.gather(producer_coro, consumer_coro)

ws_provider = WSProvider('ws://localhost:8765')
asyncio.run(ws_provider.run())
宣传栏