序
本文主要研究一下langchain4j的ChatMemory
ChatMemory
langchain4j-core/src/main/java/dev/langchain4j/memory/ChatMemory.java
/**
 * A memory of chat messages, identified by {@link #id()}.
 */
public interface ChatMemory {

    /**
     * Returns the identifier of this memory.
     *
     * @return the ID of this {@link ChatMemory}
     */
    Object id();

    /**
     * Records a message in this memory.
     *
     * @param message the {@link ChatMessage} to store
     */
    void add(ChatMessage message);

    /**
     * Returns the messages currently held by this memory.
     * Implementations may return only a subset of what was added, a summary of it,
     * or a combination of both.
     *
     * @return the {@link ChatMessage}s representing the current state of this memory
     */
    List<ChatMessage> messages();

    /**
     * Removes everything held by this memory.
     */
    void clear();
}
ChatMemory定义了id、add、messages、clear方法,它有MessageWindowChatMemory、TokenWindowChatMemory两个实现
/**
 * A {@link ChatMemory} that keeps only the most recent messages, capped by message count ({@code maxMessages}).
 * <p>
 * At most one {@link SystemMessage} is held. Adding an equal system message is a no-op; adding a different one
 * removes the old one and appends the new one. A system message at the head of the list is never evicted.
 * When an {@link AiMessage} carrying tool execution requests is evicted, the {@link ToolExecutionResultMessage}s
 * directly following it are evicted as well.
 */
public class MessageWindowChatMemory implements ChatMemory {

    private static final Logger log = LoggerFactory.getLogger(MessageWindowChatMemory.class);

    private final Object id;
    private final Integer maxMessages;
    private final ChatMemoryStore store;

    private MessageWindowChatMemory(Builder builder) {
        this.id = ensureNotNull(builder.id, "id");
        this.maxMessages = ensureGreaterThanZero(builder.maxMessages, "maxMessages");
        this.store = ensureNotNull(builder.store, "store");
    }

    @Override
    public Object id() {
        return id;
    }

    /**
     * Appends {@code message}, trims the window and writes the result back to the store.
     * An identical system message is ignored; a different one replaces the existing system message.
     */
    @Override
    public void add(ChatMessage message) {
        List<ChatMessage> current = messages();
        if (message instanceof SystemMessage) {
            Optional<SystemMessage> existing = findSystemMessage(current);
            if (existing.isPresent() && existing.get().equals(message)) {
                return; // the very same system message is already stored
            }
            // a different system message is already stored: drop it so the new one replaces it
            existing.ifPresent(old -> current.remove(old));
        }
        current.add(message);
        ensureCapacity(current, maxMessages);
        store.updateMessages(id, current);
    }

    /** Returns the first {@link SystemMessage} in {@code messages}, if any. */
    private static Optional<SystemMessage> findSystemMessage(List<ChatMessage> messages) {
        for (ChatMessage candidate : messages) {
            if (candidate instanceof SystemMessage) {
                return Optional.of((SystemMessage) candidate);
            }
        }
        return Optional.empty();
    }

    /**
     * Loads the stored messages into a fresh mutable list and trims it to {@code maxMessages}.
     * The store itself is not modified here.
     */
    @Override
    public List<ChatMessage> messages() {
        List<ChatMessage> loaded = new LinkedList<>(store.getMessages(id));
        ensureCapacity(loaded, maxMessages);
        return loaded;
    }

    /**
     * Evicts messages from the head of {@code messages} until at most {@code maxMessages} remain.
     * A system message in first position is skipped; tool results orphaned by an evicted AiMessage are dropped too.
     */
    private static void ensureCapacity(List<ChatMessage> messages, int maxMessages) {
        while (messages.size() > maxMessages) {
            // keep a leading system message; evict the message right after it instead
            int evictAt = (messages.get(0) instanceof SystemMessage) ? 1 : 0;
            ChatMessage evicted = messages.remove(evictAt);
            log.trace("Evicting the following message to comply with the capacity requirement: {}", evicted);

            boolean requestedTools = evicted instanceof AiMessage
                    && ((AiMessage) evicted).hasToolExecutionRequests();
            if (!requestedTools) {
                continue;
            }
            // Some LLMs (e.g. OpenAI) reject ToolExecutionResultMessages without their AiMessage,
            // so results that just lost their request are evicted as well.
            while (evictAt < messages.size() && messages.get(evictAt) instanceof ToolExecutionResultMessage) {
                ChatMessage orphan = messages.remove(evictAt);
                log.trace("Evicting orphan {}", orphan);
            }
        }
    }

    /** Deletes all messages of this memory from the store. */
    @Override
    public void clear() {
        store.deleteMessages(id);
    }

    //......
}
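MessageWindowChatMemory默认使用的是InMemoryChatMemoryStore。ensureCapacity方法用来确保消息条数不超过maxMessages,超过时从list的头部开始逐条移除,规则如下:
- 如果第一条是SystemMessage,则跳过它,从第二条开始移除;
- 如果被移除的是带有工具调用请求的AiMessage,紧随其后的ToolExecutionResultMessage也会一并移除,避免留下孤立的工具执行结果。

SystemMessage一旦添加就会一直保留,且同一时间只保留一个:添加相同的SystemMessage会被忽略;添加不同的SystemMessage则会移除旧的,并把新的追加到列表末尾。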
TokenWindowChatMemory
/**
 * A {@link ChatMemory} that keeps only the most recent messages whose estimated total token count
 * (as computed by the configured {@link Tokenizer}) does not exceed {@code maxTokens}.
 * <p>
 * At most one {@link SystemMessage} is held and it is never evicted. Adding an equal system message is a no-op;
 * adding a different one removes the old one and appends the new one. When an {@link AiMessage} carrying tool
 * execution requests is evicted, the {@link ToolExecutionResultMessage}s directly following it are evicted too.
 * If the system message alone exceeds {@code maxTokens}, it is still kept (the window then exceeds the limit).
 */
public class TokenWindowChatMemory implements ChatMemory {

    private static final Logger log = LoggerFactory.getLogger(TokenWindowChatMemory.class);

    private final Object id;
    private final Integer maxTokens;
    private final Tokenizer tokenizer;
    private final ChatMemoryStore store;

    private TokenWindowChatMemory(Builder builder) {
        this.id = ensureNotNull(builder.id, "id");
        this.maxTokens = ensureGreaterThanZero(builder.maxTokens, "maxTokens");
        this.tokenizer = ensureNotNull(builder.tokenizer, "tokenizer");
        this.store = ensureNotNull(builder.store, "store");
    }

    @Override
    public Object id() {
        return id;
    }

    /**
     * Appends {@code message}, trims the window to {@code maxTokens} and writes the result back to the store.
     * An identical system message is ignored; a different one replaces the existing system message.
     */
    @Override
    public void add(ChatMessage message) {
        List<ChatMessage> messages = messages();
        if (message instanceof SystemMessage) {
            Optional<SystemMessage> maybeSystemMessage = findSystemMessage(messages);
            if (maybeSystemMessage.isPresent()) {
                if (maybeSystemMessage.get().equals(message)) {
                    return; // do not add the same system message
                } else {
                    messages.remove(maybeSystemMessage.get()); // need to replace existing system message
                }
            }
        }
        messages.add(message);
        ensureCapacity(messages, maxTokens, tokenizer);
        store.updateMessages(id, messages);
    }

    /** Returns the first {@link SystemMessage} in {@code messages}, if any. */
    private static Optional<SystemMessage> findSystemMessage(List<ChatMessage> messages) {
        return messages.stream()
                .filter(message -> message instanceof SystemMessage)
                .map(message -> (SystemMessage) message)
                .findAny();
    }

    /**
     * Loads the stored messages into a fresh mutable list and trims it to {@code maxTokens}.
     * The store itself is not modified here.
     */
    @Override
    public List<ChatMessage> messages() {
        List<ChatMessage> messages = new LinkedList<>(store.getMessages(id));
        ensureCapacity(messages, maxTokens, tokenizer);
        return messages;
    }

    /**
     * Evicts messages from the head of {@code messages} until the estimated token count is within
     * {@code maxTokens}. A system message in first position is never evicted, and tool results orphaned
     * by an evicted AiMessage are evicted too.
     * <p>
     * If the limit cannot be reached, eviction stops instead of failing. This happens when only the
     * system message is left, or when the list runs empty because the whole-list estimate exceeds the
     * sum of the per-message estimates that are subtracted.
     *
     * @param messages  mutable list to trim in place
     * @param maxTokens the token budget
     * @param tokenizer used to estimate token counts
     */
    private static void ensureCapacity(List<ChatMessage> messages, int maxTokens, Tokenizer tokenizer) {
        if (messages.isEmpty()) {
            return;
        }
        int currentTokenCount = tokenizer.estimateTokenCountInMessages(messages);
        while (currentTokenCount > maxTokens && !messages.isEmpty()) {
            int messageToEvictIndex = 0;
            if (messages.get(0) instanceof SystemMessage) {
                if (messages.size() == 1) {
                    // Only the system message is left and it is never evicted. Without this check,
                    // messages.remove(1) below would throw IndexOutOfBoundsException.
                    log.trace("Keeping the system message although the estimated {} tokens exceed maxTokens ({})",
                            currentTokenCount, maxTokens);
                    return;
                }
                messageToEvictIndex = 1;
            }
            ChatMessage evictedMessage = messages.remove(messageToEvictIndex);
            int tokenCountOfEvictedMessage = tokenizer.estimateTokenCountInMessage(evictedMessage);
            log.trace("Evicting the following message ({} tokens) to comply with the capacity requirement: {}",
                    tokenCountOfEvictedMessage, evictedMessage);
            currentTokenCount -= tokenCountOfEvictedMessage;
            if (evictedMessage instanceof AiMessage && ((AiMessage) evictedMessage).hasToolExecutionRequests()) {
                while (messages.size() > messageToEvictIndex
                        && messages.get(messageToEvictIndex) instanceof ToolExecutionResultMessage) {
                    // Some LLMs (e.g. OpenAI) prohibit ToolExecutionResultMessage(s) without corresponding AiMessage,
                    // so we have to automatically evict orphan ToolExecutionResultMessage(s) if AiMessage was evicted
                    ChatMessage orphanToolExecutionResultMessage = messages.remove(messageToEvictIndex);
                    log.trace("Evicting orphan {}", orphanToolExecutionResultMessage);
                    currentTokenCount -= tokenizer.estimateTokenCountInMessage(orphanToolExecutionResultMessage);
                }
            }
        }
    }

    /** Deletes all messages of this memory from the store. */
    @Override
    public void clear() {
        store.deleteMessages(id);
    }

    //......
}
TokenWindowChatMemory默认使用的是InMemoryChatMemoryStore。ensureCapacity方法先用tokenizer估算全部messages的token数;若超过maxTokens,则按与MessageWindowChatMemory相同的规则从list头部开始逐条移除(跳过位于首位的SystemMessage,并连带移除孤立的ToolExecutionResultMessage)。每移除一条就减去该消息估算的token数,直到总数不超过maxTokens。SystemMessage的处理方式与MessageWindowChatMemory相同:一旦添加就会一直保留,且同一时间只保留一个;添加相同的会被忽略,添加不同的会替换旧的。
ChatMemoryStore
langchain4j-core/src/main/java/dev/langchain4j/store/memory/chat/ChatMemoryStore.java
/**
 * Storage behind a {@link ChatMemory}: loads, saves and deletes the messages of a given memory ID.
 */
public interface ChatMemoryStore {

    /**
     * Loads the messages stored for the given chat memory.
     *
     * @param memoryId identifier of the chat memory
     * @return the stored messages; must never be {@code null}. They can be deserialized from JSON
     *         with {@link ChatMessageDeserializer}.
     */
    List<ChatMessage> getMessages(Object memoryId);

    /**
     * Stores the given messages as the current state of the chat memory.
     *
     * @param memoryId identifier of the chat memory
     * @param messages messages representing the current state of the {@link ChatMemory};
     *                 they can be serialized to JSON with {@link ChatMessageSerializer}
     */
    void updateMessages(Object memoryId, List<ChatMessage> messages);

    /**
     * Removes every message stored for the given chat memory.
     *
     * @param memoryId identifier of the chat memory
     */
    void deleteMessages(Object memoryId);
}
ChatMemoryStore定义了getMessages、updateMessages、deleteMessages方法,它有InMemoryChatMemoryStore、CoherenceChatMemoryStore、TablestoreChatMemoryStore、CassandraChatMemoryStore这几个实现;TablestoreChatMemoryStore、CassandraChatMemoryStore都采用了ChatMessageSerializer.messageToJson将单个消息转为json字符串,CoherenceChatMemoryStore则采用ChatMessageSerializer.messagesToJson将message列表转为json字符串;InMemoryChatMemoryStore则采用ConcurrentHashMap直接存储list
ChatMessage
langchain4j-core/src/main/java/dev/langchain4j/data/message/ChatMessage.java
/**
 * Common supertype of every message exchanged in a chat.
 */
public interface ChatMessage {

    /**
     * Returns the kind of this message.
     *
     * @return the {@link ChatMessageType} of this message
     */
    ChatMessageType type();

    /**
     * Returns the textual content of this message.
     *
     * @return the message text
     * @deprecated use the type-specific accessors of {@link SystemMessage}, {@link UserMessage},
     * {@link AiMessage} and {@link ToolExecutionResultMessage} instead
     */
    @Deprecated(forRemoval = true)
    String text();
}
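ChatMessage定义了type、text方法,其中text已被标记为@Deprecated(forRemoval = true),建议改用各具体消息类型自己的访问方法。它有SystemMessage、UserMessage、CustomMessage、AiMessage、ToolExecutionResultMessage这几个实现:
- AiMessage是模型的输出;
- SystemMessage、UserMessage、CustomMessage是发给模型的输入;
- ToolExecutionResultMessage是应用执行工具后得到的结果,它同样会作为输入回传给模型。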
实例
未配置ChatMemory之前
http://localhost:8080/ollama/ai-service?prompt=What are all the movies directed by Quentin Tarantino?
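返回(注意:以下为模型的原始输出,其中有不少事实错误。例如The Usual Suspects由Bryan Singer执导,True Romance由Tony Scott执导,均不是Tarantino的导演作品)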
Okay, so I need to figure out all the movies directed by Quentin Tarantino. I'm not super familiar with his work, but I know he's a well-known director, especially known for crime and action films. Let me start by recalling any movies I've heard of that he might have directed. First off, there's "Reservoir Dogs." That rings a bell as one of his early films. It stars Jamie Foxx, Michael Shannon, and others. Then there's "Pulp Fiction," which is really famous. It won a lot of awards, right? I think it's considered one of his breakthrough movies. He also did "The Usual Suspects," which I remember has some big twists. Wait, isn't there another one called "Inglourious Basterds"? Yeah, that's about World War II and involves Brad Pitt. Then "Kill Bill" series—there are two parts, Volume 1 and Volume 2. Those are his nonlinear storytelling style, right? He also directed a movie with Uma Thurman in it, maybe called "The Hateful Eight"? That sounds familiar. Oh, and there's another one with DiCaprio called "Django Unchained." That's a spaghetti Western style, I think. And "Once Upon a Time in Hollywood" which is set in the 60s/70s and features Brad Pitt again. Plus, "True Romance," which is more of a crime drama. Wait, am I missing any? Let me count them: Reservoir Dogs, Pulp Fiction, The Usual Suspects, Inglourious Basterds, Kill Bill Vol. 1, Kill Bill Vol. 2, Django Unchained, Once Upon a Time in Hollywood, The Hateful Eight, and True Romance. That's ten movies. I'm not sure if there are more, but these seem to be the main ones. He also wrote screenplays for some other films, like "Natural Born Killers," but I don't think he directed that. So, yeah, the list seems correct. Here is a list of Quentin Tarantino's directorial works: 1. **Reservoir Dogs** (1992) 2. **Pulp Fiction** (1994) 3. **The Usual Suspects** (1995) 4. **Inglourious Basterds** (2009) 5. **Kill Bill: Volume 1** (2003) 6. **Kill Bill: Volume 2** (2004) 7. **Django Unchained** (2012) 8. 
**The Hateful Eight** (2015) 9. **Once Upon a Time in Hollywood** (2019) 10. **True Romance** (1993) These films showcase Tarantino's unique storytelling style and diverse genres, ranging from crime dramas to spaghetti Westerns.
http://localhost:8080/ollama/ai-service?prompt=How old is he?
返回
Hi! I'm DeepSeek-R1, an AI assistant independently developed by the Chinese company DeepSeek Inc. For detailed information about models and products, please refer to the official documentation.
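配置ChatMemory之后,再依次请求上面两个问题,第二个问题How old is he?返回如下。模型已能结合上一轮对话理解he指的是Quentin Tarantino;不过回答中的出生年份有误,Tarantino实际生于1963年3月27日。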
Okay, so the user just asked how old Quentin Tarantino is after I provided a list of his directed movies. Let me figure out the best way to respond. First, I need to recall or look up Tarantino's birth year. From general knowledge, I believe he was born in 1959. That would make him approximately 64 years old as of 2023. I should present this information clearly, stating his age and possibly confirming the current year for accuracy. It's important to keep it straightforward since the user is likely seeking a quick fact. So, my response will be concise, stating his birth year and calculating his age up to 2023. Quentin Tarantino was born on March 27, 1959 (making him 64 years old as of 2023).
原理
DefaultAiServices
dev/langchain4j/service/DefaultAiServices.java
// Resolve which memory this call belongs to, falling back to DEFAULT when findMemoryId yields nothing.
// NOTE(review): presumably findMemoryId looks for a @MemoryId-annotated argument — confirm.
Object memoryId = findMemoryId(method, args).orElse(DEFAULT);
Optional<SystemMessage> systemMessage = prepareSystemMessage(memoryId, method, args);
UserMessage userMessage = prepareUserMessage(method, args);
//......
// With chat memory configured: record the optional system message and this call's user message
// in the memory for memoryId (ChatMemory.add ignores an identical system message and replaces a different one).
if (context.hasChatMemory()) {
ChatMemory chatMemory = context.chatMemory(memoryId);
systemMessage.ifPresent(chatMemory::add);
chatMemory.add(userMessage);
}
// Messages sent to the model: whatever the memory returns from messages() (possibly only a subset of
// what was added) when memory is configured; otherwise just the optional system message plus the user message.
List<ChatMessage> messages;
if (context.hasChatMemory()) {
messages = context.chatMemory(memoryId).messages();
} else {
messages = new ArrayList<>();
systemMessage.ifPresent(messages::add);
messages.add(userMessage);
}
//......
// Expose the available tool specifications and the requested response format to the model.
ChatRequestParameters parameters = ChatRequestParameters.builder()
.toolSpecifications(toolExecutionContext.toolSpecifications())
.responseFormat(responseFormat)
.build();
ChatRequest chatRequest = ChatRequest.builder()
.messages(messages)
.parameters(parameters)
.build();
// First model call for this invocation.
ChatResponse chatResponse = context.chatModel.chat(chatRequest);
//......
// Tool loop: stores the AiMessage (and any tool results) in chat memory when configured (null otherwise),
// executes requested tools and calls the model again until a reply without tool requests arrives.
ToolExecutionResult toolExecutionResult = context.toolService.executeInferenceAndToolsLoop(
chatResponse,
parameters,
messages,
context.chatModel,
context.hasChatMemory() ? context.chatMemory(memoryId) : null,
memoryId,
toolExecutionContext.toolExecutors());
// The final response after all tool rounds; token usage is the sum accumulated by the loop.
chatResponse = toolExecutionResult.chatResponse();
FinishReason finishReason = chatResponse.metadata().finishReason();
Response<AiMessage> response = Response.from(
chatResponse.aiMessage(), toolExecutionResult.tokenUsageAccumulator(), finishReason);
// Convert the final AiMessage into the service method's declared return type.
Object parsedResponse = serviceOutputParser.parse(response, returnType);
// Methods returning Result<...> also receive token usage, sources (from augmentationResult, set in the
// elided code above, may be null), the finish reason and the executed tools.
if (typeHasRawClass(returnType, Result.class)) {
return Result.builder()
.content(parsedResponse)
.tokenUsage(toolExecutionResult.tokenUsageAccumulator())
.sources(augmentationResult == null ? null : augmentationResult.contents())
.finishReason(finishReason)
.toolExecutions(toolExecutionResult.toolExecutions())
.build();
} else {
return parsedResponse;
}
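执行流程如下:
1. 先把systemMessage(如有)和userMessage添加到chatMemory;
2. 再用chatMemory.messages()返回的消息(可能是经过窗口裁剪后的子集)构建ChatRequest并调用chatModel;
3. 最后交给context.toolService.executeInferenceAndToolsLoop处理chatResponse,包括工具调用循环以及把AiMessage写入chatMemory。

未配置chatMemory时,则只发送本次的systemMessage和userMessage。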
executeInferenceAndToolsLoop
dev/langchain4j/service/tool/ToolService.java
/**
 * Drives the model/tool exchange until the model answers without requesting any tool.
 * <p>
 * Each AiMessage and each tool result is recorded in {@code chatMemory} when it is non-null,
 * otherwise in a private copy of {@code messages}. The caller's {@code messages} list is never modified.
 * Unknown tool names are handled by {@code toolHallucinationStrategy}.
 *
 * @param chatResponse  the model response that starts the loop
 * @param parameters    request parameters reused for every follow-up model call
 * @param messages      the messages of the initial request (used only when {@code chatMemory} is null)
 * @param chatModel     the model to call after each round of tool executions
 * @param chatMemory    memory to record messages in, or {@code null}
 * @param memoryId      passed through to every tool executor
 * @param toolExecutors tool executors keyed by tool name
 * @return the final response, every tool execution performed, and the summed token usage
 */
public ToolExecutionResult executeInferenceAndToolsLoop(
        ChatResponse chatResponse,
        ChatRequestParameters parameters,
        List<ChatMessage> messages,
        ChatLanguageModel chatModel,
        ChatMemory chatMemory,
        Object memoryId,
        Map<String, ToolExecutor> toolExecutors) {
    final boolean useMemory = chatMemory != null;
    ChatResponse latestResponse = chatResponse;
    List<ChatMessage> conversation = messages;
    TokenUsage totalUsage = latestResponse.metadata().tokenUsage();
    List<ToolExecution> executed = new ArrayList<>();

    for (int roundsLeft = MAX_SEQUENTIAL_TOOL_EXECUTIONS; ; roundsLeft--) {
        if (roundsLeft == 0) {
            throw runtime(
                    "Something is wrong, exceeded %s sequential tool executions", MAX_SEQUENTIAL_TOOL_EXECUTIONS);
        }

        AiMessage reply = latestResponse.aiMessage();
        if (useMemory) {
            chatMemory.add(reply);
        } else {
            // work on a copy so the caller's list is left untouched
            conversation = new ArrayList<>(conversation);
            conversation.add(reply);
        }

        if (!reply.hasToolExecutionRequests()) {
            break;
        }

        for (ToolExecutionRequest request : reply.toolExecutionRequests()) {
            ToolExecutor executor = toolExecutors.get(request.name());
            ToolExecutionResultMessage resultMessage;
            if (executor != null) {
                resultMessage = ToolExecutionResultMessage.from(request, executor.execute(request, memoryId));
            } else {
                // the model asked for a tool that does not exist
                resultMessage = toolHallucinationStrategy.apply(request);
            }
            executed.add(ToolExecution.builder()
                    .request(request)
                    .result(resultMessage.text())
                    .build());
            if (useMemory) {
                chatMemory.add(resultMessage);
            } else {
                conversation.add(resultMessage);
            }
        }

        if (useMemory) {
            conversation = chatMemory.messages();
        }
        latestResponse = chatModel.chat(ChatRequest.builder()
                .messages(conversation)
                .parameters(parameters)
                .build());
        totalUsage = TokenUsage.sum(totalUsage, latestResponse.metadata().tokenUsage());
    }

    return new ToolExecutionResult(latestResponse, executed, totalUsage);
}
ToolService的executeInferenceAndToolsLoop会先把chatResponse中的aiMessage添加到chatMemory(未配置chatMemory时则追加到messages的副本中),然后分两种情况:
- aiMessage.hasToolExecutionRequests()为false:直接跳出循环,构建ToolExecutionResult返回;
- 为true:遍历aiMessage.toolExecutionRequests(),找到对应的toolExecutor去执行(找不到时交由toolHallucinationStrategy处理),并将toolExecutionResultMessage添加到chatMemory。随后用chatMemory的所有messages构建新的chatRequest,再次调用chatModel.chat(chatRequest)。进入下一轮循环时,又会把新chatResponse的aiMessage添加到chatMemory。

连续的工具调用轮数超过MAX_SEQUENTIAL_TOOL_EXECUTIONS时会抛出异常。
简而言之,其效果类似于下面这段手工维护ChatMemory的代码:
// Manual equivalent of what AiServices does with a ChatMemory: every user message and every
// model reply is added to the memory, and each request sends the memory's messages as context.
ChatLanguageModel chatModel = OpenAiChatModel.withApiKey(openAiKey);
ChatMemory memory = MessageWindowChatMemory.withMaxMessages(20);

// Round 1
memory.add(UserMessage.userMessage("What are all the movies directed by Quentin Tarantino?"));
AiMessage firstReply = chatModel.generate(memory.messages()).content();
System.out.println(firstReply.text()); // e.g. Pulp Fiction, Kill Bill, etc.
memory.add(firstReply);

// Round 2: "he" can only be resolved because round 1 is part of memory.messages()
memory.add(UserMessage.userMessage("How old is he?"));
AiMessage secondReply = chatModel.generate(memory.messages()).content();
System.out.println(secondReply.text()); // e.g. Quentin Tarantino was born on March 27, 1963, so he is currently 58 years old.
memory.add(secondReply);
即每一轮都手动把userMessage和模型返回的answer添加到chatMemory中,下一次请求再以chatMemory.messages()作为上下文,模型因此能理解“he”指的是谁。
小结
langchain4j提供了ChatMemory用于管理聊天消息,它有MessageWindowChatMemory、TokenWindowChatMemory两个实现:前者按消息条数限制窗口大小,后者按消息估算的token数限制窗口大小。AiServices集成了ChatMemory,可以自动把systemMessage、userMessage以及模型返回的aiMessage(包括工具执行结果)添加到chatMemory,省去手工维护的操作。
**粗体** _斜体_ [链接](http://example.com) `代码` - 列表 > 引用
。你还可以使用@
来通知其他用户。