【RAG 项目实战 04】添加多轮对话能力


NLP Github 项目:


[!NOTE] 添加多轮对话能力

  1. 存储对话历史
  2. 添加 session_id
  3. 提示模板中添加 chat_history
  4. RunnableWithMessageHistory 包装 Chain 添加对话历史能力
  5. 配置中使用 session_id 进行大模型交互,可以根据 session_id 区分不同用户的对话历史

一、添加多轮对话能力

01 存储对话历史

# 存储对话历史
store = {}


def get_session_history(session_id: str) -> BaseChatMessageHistory:
    if session_id not in store:
        store[session_id] = ChatMessageHistory()
    return store[session_id]

02 添加 session_id

# 添加 session_id
cl.user_session.set('session_id', 'abc2')

03 在提示模板中添加 chat_history

# 添加 session_id
cl.user_session.set('session_id', 'abc2')

# 提示模板中添加 chat_history
prompt = ChatPromptTemplate.from_messages(
    [
        (
            "system",
            "You're a very knowledgeable historian who provides accurate and eloquent answers to historical questions.",
        ),
        MessagesPlaceholder("chat_history"),
        ("human", "{question}"),
    ]
)

04 使用 RunnableWithMessageHistory 包装 Chain

# 获取当前的 session_id
session_id = cl.user_session.get("session_id")

# 用 RunnableWithMessageHistory 包装 Chain 添加对话历史能力
runnable_with_history = RunnableWithMessageHistory(
    runnable,
    get_session_history,
    input_messages_key="question",
    history_messages_key="chat_history",
)

05 使用 session_id 进行大模型交互

msg = cl.Message(content="")

# 配置中使用 session_id 进行大模型交互
async for chunk in runnable_with_history.astream(
        {"question": message.content},
        config=RunnableConfig(configurable={"session_id": session_id},
                              callbacks=[cl.LangchainCallbackHandler()])
):
    await msg.stream_token(chunk)

await msg.send()
Tips:用 RunnableWithMessageHistory 包装 Chain 添加对话历史能力,不能将其插入到 Chain 中

二、效果展示

多轮能力展示

🌈 这是本人最喜欢的一首诗,它陪我走过不少困难时光。以此与诸君共勉之!祝君前程似锦~

《小松》 -- 唐 杜荀鹤
自小刺头深草里,而今渐觉出蓬蒿。
时人不识凌云木,直待凌云始道高。

三、完整代码

# @Author:青松
# 公众号:FasterAI
# 理想使命:让每个人的AI学习之路走的更容易些,若我的经验能为你前行的道路增添一丝轻松,我将倍感荣幸🌈🌈🌈 让知识传递更容易,用知识让生活更美好!
# Python, version 3.10.14
# Pytorch, version 2.3.0
# Chainlit, version 1.1.301

import chainlit as cl
from langchain.prompts import ChatPromptTemplate
from langchain.schema import StrOutputParser
from langchain.schema.runnable import Runnable
from langchain_community.chat_message_histories import ChatMessageHistory
from langchain_community.chat_models import QianfanChatEndpoint
from langchain_core.chat_history import BaseChatMessageHistory
from langchain_core.prompts import MessagesPlaceholder
from langchain_core.runnables import RunnableConfig
from langchain_core.runnables.history import RunnableWithMessageHistory


model = QianfanChatEndpoint(
    streaming=True,
    model="ERNIE-Speed-8K",
)

# 存储对话历史
store = {}


def get_session_history(session_id: str) -> BaseChatMessageHistory:
    if session_id not in store:
        store[session_id] = ChatMessageHistory()
    return store[session_id]


@cl.password_auth_callback
def auth_callback(username: str, password: str):
    """ 持久化客户端聊天历史代码,不需要请删除 """
    if (username, password) == ("admin", "admin"):
        return cl.User(
            identifier="admin", metadata={"role": "admin", "provider": "credentials"}
        )
    else:
        return None


@cl.on_chat_start
async def on_chat_start():
    """ 监听会话开始事件 """
    # todo: 添加 FasterAI 知识星球图片以及 FastAI 知识库地址
    image = cl.Image(url="https://qingsong-1257401904.cos.ap-nanjing.myqcloud.com/wecaht.png")

    # 发送一个图片
    await cl.Message(
        content="**青松** 邀你关注 **FasterAI**, 让每个人的 AI 学习之路走的更容易些!开启 AI 学习、面试快车道 **(^_^)** ",
        elements=[image],
    ).send()

    # 添加 session_id
    cl.user_session.set('session_id', 'abc2')

    # 提示模板中添加 chat_history
    prompt = ChatPromptTemplate.from_messages(
        [
            (
                "system",
                "你是一个中国古诗词专家,能准确的一字不差的背诵很多古诗词,请用你最大的能力来回答用户的问题。",
            ),
            MessagesPlaceholder("chat_history"),
            ("human", "{question}"),
        ]
    )
    runnable = prompt | model | StrOutputParser()
    cl.user_session.set("runnable", runnable)


@cl.on_message
async def on_message(message: cl.Message):
    """ 监听用户消息事件 """

    runnable = cl.user_session.get("runnable")  # type: Runnable

    # 获取当前的 session_id
    session_id = cl.user_session.get("session_id")

    # 用 RunnableWithMessageHistory 包装 Chain 添加对话历史能力
    runnable_with_history = RunnableWithMessageHistory(
        runnable,
        get_session_history,
        input_messages_key="question",
        history_messages_key="chat_history",
    )

    msg = cl.Message(content="")

    # 配置中使用 session_id 进行大模型交互
    async for chunk in runnable_with_history.astream(
            {"question": message.content},
            config=RunnableConfig(configurable={"session_id": session_id},
                                  callbacks=[cl.LangchainCallbackHandler()])
    ):
        await msg.stream_token(chunk)

    await msg.send()




【动手学 RAG】系列文章:

本文由mdnice多平台发布


青松
1 声望0 粉丝