题目描述
多个协程读取websocket 返回的数据会报错,错误:cannot call recv while another coroutine is already waiting for the next message
题目来源及自己的思路
- client.py 里同时实现主动请求及被动接收广播的数据
- ws.py 里,在异步协程里收发请求
基本确定 是ws.py 有两个地方的协程都在读取websocket的数据导致。
async def receive_message(self, callback=None):
try:
async with self.conn as conn:
# 这里在读取数据
async for message in conn:
if callback:
callback(message)
....
async def coroutine_make_request(self, request_data):
request_data = json.dumps(request_data)
async with self.conn as conn:
await asyncio.wait_for(
conn.send(request_data),
self.timeout
)
# 以及这里也在读取websockets数据
resp = await asyncio.wait_for(
conn.recv(),
self.timeout
)
return json.loads(resp)
两个地方都在读取websocket数据,导致报错。
实际看到报错:
相关代码
总共3个文件,拷贝即可运行。
client.py
import time
from ws import WSProvider
def receive(message):
print("receive message:", message)
def main():
client = WSProvider(endpoint="ws://localhost:9801", kwargs={})
# Make request at first time
resp = client.make_request({
"req": "method_1"
})
print("1st:", resp)
# Continuously receive message
client.persistent_receive_message(receive)
# Make request at second time
# Likely to result in an error: cannot call recv while another coroutine is already waiting for the next message
resp = client.make_request({
"req": "method_2"
})
print("2nd:", resp)
if __name__ == "__main__":
main()
time.sleep(15)
server.py
import time
import websockets
import asyncio
import json
import threading
clients = set()
async def set_interval(interval: int, func):
async def run():
msg = time.asctime()
await func(msg)
threading.Timer(interval, schedule).start()
def schedule():
asyncio.run(run())
threading.Timer(interval, schedule).start()
async def echo(ws):
clients.add(ws)
try:
async for message in ws:
msg = json.loads(message)
msg["current_time"] = time.asctime()
# send message back
await ws.send(json.dumps(msg))
except websockets.ConnectionClosed:
pass
finally:
clients.remove(ws)
async def broadcast(msg: str):
try:
print(f"try send message to client {msg}, {len(clients)}")
for client in clients:
data = json.dumps({
"current_time": msg
})
print(f"send message to client {data}")
await client.send(data)
except websockets.ConnectionClosed:
pass
async def main():
async with websockets.serve(echo, "localhost", "9801"):
print("websocket server wake up")
# run forever, broadcast time to every client per 5 seconds
await set_interval(5, broadcast)
await asyncio.Future()
asyncio.run(main())
ws.py
import asyncio
import json
import websockets
from typing import Type, Callable
from types import TracebackType
from threading import Thread
def start_event_loop(loop: asyncio.AbstractEventLoop) -> None:
asyncio.set_event_loop(loop)
loop.run_forever()
loop.close()
def get_thread_loop() -> asyncio.AbstractEventLoop:
new_loop = asyncio.new_event_loop()
thread = Thread(target=start_event_loop, args=(new_loop,), daemon=True)
thread.start()
return new_loop
class PersistentConnection:
def __init__(self, endpoint: str, kwargs) -> None:
self.ws: websockets.WebSocketClientProtocol = None
self.endpoint = endpoint
self.kwargs = kwargs
async def __aenter__(self):
if self.ws is None:
self.ws = await websockets.connect(uri=self.endpoint, **self.kwargs)
return self.ws
async def __aexit__(self, exec_type: Type[BaseException], exec_val: BaseException, exec_traceback: TracebackType):
if exec_val is not None:
try:
await self.ws.close()
except websockets.ConnectionClosed:
pass
finally:
self.ws = None
class WSProvider:
_loop: asyncio.AbstractEventLoop = get_thread_loop()
def __init__(self, *, endpoint: str, timeout: int = None, kwargs):
self.endpoint = endpoint
self.timeout = timeout
self.kwargs = kwargs or {}
if WSProvider._loop is None:
WSProvider._loop = get_thread_loop()
self.conn = PersistentConnection(self.endpoint, self.kwargs)
# websocket 服务可能 持续广播消息过来
def persistent_receive_message(self, callback: Callable):
loop = WSProvider._loop
if loop and loop.is_running():
loop.create_task(self.receive_message(callback))
else:
asyncio.run(self.receive_message(callback))
async def receive_message(self, callback=None):
try:
async with self.conn as conn:
async for message in conn:
if callback:
callback(message)
except websockets.ConnectionClosed:
pass
def make_request(self, request_data):
future = asyncio.run_coroutine_threadsafe(
self.coroutine_make_request(request_data),
WSProvider._loop
)
data = future.result()
return data
async def coroutine_make_request(self, request_data):
request_data = json.dumps(request_data)
async with self.conn as conn:
await asyncio.wait_for(
conn.send(request_data),
self.timeout
)
resp = await asyncio.wait_for(
conn.recv(),
self.timeout
)
return json.loads(resp)
你期待的结果是什么?实际看到的错误信息又是什么?
期望大佬给改良思路,改良此处代码实现。比如 coroutine_make_request 只返回asyncio.Future, 由 receive_message 去修改该 Future的状态。本人多次尝试都没成功。
async def coroutine_make_request(self, request_data):
request_data = json.dumps(request_data)
async with self.conn as conn:
await asyncio.wait_for(
conn.send(request_data),
self.timeout
)
# 改良此处逻辑,将数据读取统一挪到 receive_message 内
# resp = await asyncio.wait_for(
# conn.recv(),
# self.timeout
# )
# return json.loads(resp)
==========================
2023.8.3
我尝试改良WSProvider.py
class WSProvider:
_loop: asyncio.AbstractEventLoop = get_thread_loop()
def __init__(self, *, endpoint: str, timeout: int = None, kwargs):
#...
self.waiters: Dict[str, asyncio.Future] = {}
# ...
async def receive_message(self, callback=None):
try:
async with self.conn as conn:
async for message in conn:
data = json.loads(message)
uid = data.get("uid")
# 这里去修改 future的状态,让 coroutine_make_request 能走过await
if uid and self.waiters.get(uid):
self.waiters.get(uid).set_result(message)
del self.waiters[uid]
if callback:
callback(message)
except websockets.ConnectionClosed:
pass
def make_request(self, request_data):
#...
async def coroutine_make_request(self, request_data):
uid = uuid4().hex
request_data["uid"] = uid
request_data = json.dumps(request_data)
async with self.conn as conn:
await asyncio.wait_for(
conn.send(request_data),
self.timeout
)
# resp = await asyncio.wait_for(
# conn.recv(),
# self.timeout
# )
# return json.loads(resp)
# 我尝试改成这样,但是运行发现 await future 直接就把整个协程阻塞,
# 其他eventloop的代码根本没办法继续运行,最后代码就停在这里
future = asyncio.Future()
self.waiters[uid] = future
await future
return future.result()