序
本文主要研究一下Spring AI的Advisors
Advisor
org/springframework/ai/chat/client/advisor/api/Advisor.java
/**
 * Base contract for advisors: an {@link Ordered} component that exposes a name.
 */
public interface Advisor extends Ordered {
/**
* Default order value for the chat memory advisors, defined as
* {@code Ordered.HIGHEST_PRECEDENCE + 1000}. The offset is intended to give chat
* memory advisors a lower precedence than the Spring AI internal advisors while
* leaving 1000 slots for users to plug in their own advisors with higher
* priority.
*/
int DEFAULT_CHAT_MEMORY_PRECEDENCE_ORDER = Ordered.HIGHEST_PRECEDENCE + 1000;
/**
* Return the name of the advisor.
* @return the advisor name.
*/
String getName();
}
Advisor接口继承了Ordered,定义了getName方法
CallAroundAdvisor
org/springframework/ai/chat/client/advisor/api/CallAroundAdvisor.java
/**
 * Advisor variant for the synchronous (call) path: it receives the request together
 * with the remaining chain and returns the response.
 */
public interface CallAroundAdvisor extends Advisor {
/**
* Around advice that wraps the ChatModel#call(Prompt) method.
* @param advisedRequest the advised request
* @param chain the advisor chain, used to hand the request on to the next advisor
* @return the response
*/
AdvisedResponse aroundCall(AdvisedRequest advisedRequest, CallAroundAdvisorChain chain);
}
CallAroundAdvisor继承了Advisor接口,它定义了aroundCall方法,入参是AdvisedRequest及CallAroundAdvisorChain,返回AdvisedResponse
StreamAroundAdvisor
org/springframework/ai/chat/client/advisor/api/StreamAroundAdvisor.java
/**
 * Advisor variant for the streaming path: it receives the request together with the
 * remaining chain and returns a {@code Flux} of responses.
 */
public interface StreamAroundAdvisor extends Advisor {
/**
* Around advice that wraps the invocation of the advised request.
* @param advisedRequest the advised request
* @param chain the chain of advisors to execute
* @return the stream of responses produced for the advised request
*/
Flux<AdvisedResponse> aroundStream(AdvisedRequest advisedRequest, StreamAroundAdvisorChain chain);
}
StreamAroundAdvisor继承了Advisor接口,它定义了aroundStream方法,入参是AdvisedRequest及StreamAroundAdvisorChain,返回的是Flux<AdvisedResponse>
SimpleLoggerAdvisor
org/springframework/ai/chat/client/advisor/SimpleLoggerAdvisor.java
/**
 * Advisor that logs, at DEBUG level, the request before it is passed down the chain
 * and the response once it is available. Works for both the call and stream paths.
 */
public class SimpleLoggerAdvisor implements CallAroundAdvisor, StreamAroundAdvisor {

    /** Default request formatter: the request's own {@code toString()}. */
    public static final Function<AdvisedRequest, String> DEFAULT_REQUEST_TO_STRING = request -> request.toString();

    /** Default response formatter: JSON rendering via {@code ModelOptionsUtils}. */
    public static final Function<ChatResponse, String> DEFAULT_RESPONSE_TO_STRING = response -> ModelOptionsUtils
        .toJsonString(response);

    private static final Logger logger = LoggerFactory.getLogger(SimpleLoggerAdvisor.class);

    private final Function<AdvisedRequest, String> requestToString;

    private final Function<ChatResponse, String> responseToString;

    private int order;

    /** Creates an advisor with the default formatters and order {@code 0}. */
    public SimpleLoggerAdvisor() {
        this(0);
    }

    /**
     * Creates an advisor with the default formatters.
     * @param order the advisor order
     */
    public SimpleLoggerAdvisor(int order) {
        this(DEFAULT_REQUEST_TO_STRING, DEFAULT_RESPONSE_TO_STRING, order);
    }

    /**
     * Creates a fully customised advisor.
     * @param requestToString renders the request for the log line
     * @param responseToString renders the chat response for the log line
     * @param order the advisor order
     */
    public SimpleLoggerAdvisor(Function<AdvisedRequest, String> requestToString,
            Function<ChatResponse, String> responseToString, int order) {
        this.requestToString = requestToString;
        this.responseToString = responseToString;
        this.order = order;
    }

    @Override
    public int getOrder() {
        return this.order;
    }

    @Override
    public String getName() {
        return this.getClass().getSimpleName();
    }

    @Override
    public String toString() {
        return SimpleLoggerAdvisor.class.getSimpleName();
    }

    /** Logs the request, delegates to the chain, logs the response, and returns it. */
    @Override
    public AdvisedResponse aroundCall(AdvisedRequest advisedRequest, CallAroundAdvisorChain chain) {
        AdvisedRequest loggedRequest = before(advisedRequest);
        AdvisedResponse result = chain.nextAroundCall(loggedRequest);
        observeAfter(result);
        return result;
    }

    /**
     * Logs the request, delegates to the chain, and hands the resulting stream to a
     * {@code MessageAggregator} with {@code observeAfter} as the callback.
     */
    @Override
    public Flux<AdvisedResponse> aroundStream(AdvisedRequest advisedRequest, StreamAroundAdvisorChain chain) {
        AdvisedRequest loggedRequest = before(advisedRequest);
        Flux<AdvisedResponse> downstream = chain.nextAroundStream(loggedRequest);
        return new MessageAggregator().aggregateAdvisedResponse(downstream, this::observeAfter);
    }

    /** Logs the request and returns it unchanged. */
    private AdvisedRequest before(AdvisedRequest request) {
        String rendered = this.requestToString.apply(request);
        logger.debug("request: {}", rendered);
        return request;
    }

    /** Logs the chat response carried by the advised response. */
    private void observeAfter(AdvisedResponse advisedResponse) {
        String rendered = this.responseToString.apply(advisedResponse.response());
        logger.debug("response: {}", rendered);
    }

}
SimpleLoggerAdvisor实现了CallAroundAdvisor、StreamAroundAdvisor接口,内部定义了before、observeAfter方法打印request和response,其aroundCall、aroundStream分别在chain执行next之前调用before,之后调用observeAfter
SafeGuardAdvisor
org/springframework/ai/chat/client/advisor/SafeGuardAdvisor.java
/**
 * Advisor that short-circuits the chain with a canned failure response when the user
 * text contains any of the configured sensitive words.
 */
public class SafeGuardAdvisor implements CallAroundAdvisor, StreamAroundAdvisor {

    private final static String DEFAULT_FAILURE_RESPONSE = "I'm unable to respond to that due to sensitive content. Could we rephrase or discuss something else?";

    private final static int DEFAULT_ORDER = 0;

    private final String failureResponse;

    private final List<String> sensitiveWords;

    private final int order;

    /**
     * Creates an advisor with the default failure response and default order.
     * @param sensitiveWords words that trigger the failure response; must not be null
     */
    public SafeGuardAdvisor(List<String> sensitiveWords) {
        this(sensitiveWords, DEFAULT_FAILURE_RESPONSE, DEFAULT_ORDER);
    }

    /**
     * Creates a fully configured advisor.
     * @param sensitiveWords words that trigger the failure response; must not be null
     * @param failureResponse text returned when a sensitive word is found; must not
     * be null
     * @param order the advisor order
     */
    public SafeGuardAdvisor(List<String> sensitiveWords, String failureResponse, int order) {
        Assert.notNull(sensitiveWords, "Sensitive words must not be null!");
        Assert.notNull(failureResponse, "Failure response must not be null!");
        this.sensitiveWords = sensitiveWords;
        this.failureResponse = failureResponse;
        this.order = order;
    }

    public static Builder builder() {
        return new Builder();
    }

    @Override
    public String getName() {
        return this.getClass().getSimpleName();
    }

    /**
     * Returns the failure response if the user text is flagged, otherwise delegates
     * to the next advisor in the chain.
     */
    @Override
    public AdvisedResponse aroundCall(AdvisedRequest advisedRequest, CallAroundAdvisorChain chain) {
        if (containsSensitiveWord(advisedRequest)) {
            return createFailureResponse(advisedRequest);
        }
        return chain.nextAroundCall(advisedRequest);
    }

    /**
     * Streaming counterpart of {@link #aroundCall}: emits a single failure response
     * if the user text is flagged, otherwise delegates to the chain.
     */
    @Override
    public Flux<AdvisedResponse> aroundStream(AdvisedRequest advisedRequest, StreamAroundAdvisorChain chain) {
        if (containsSensitiveWord(advisedRequest)) {
            return Flux.just(createFailureResponse(advisedRequest));
        }
        return chain.nextAroundStream(advisedRequest);
    }

    /**
     * Shared check for both the call and stream paths.
     * @return {@code true} if the request's user text contains at least one
     * configured sensitive word
     */
    private boolean containsSensitiveWord(AdvisedRequest advisedRequest) {
        String userText = advisedRequest.userText();
        // A missing user text cannot match anything; treat it as "not sensitive"
        // instead of failing with a NullPointerException.
        if (userText == null || CollectionUtils.isEmpty(this.sensitiveWords)) {
            return false;
        }
        return this.sensitiveWords.stream().anyMatch(userText::contains);
    }

    /** Builds a response with a single assistant message carrying the failure text. */
    private AdvisedResponse createFailureResponse(AdvisedRequest advisedRequest) {
        return new AdvisedResponse(ChatResponse.builder()
            .withGenerations(List.of(new Generation(new AssistantMessage(this.failureResponse))))
            .build(), advisedRequest.adviseContext());
    }

    @Override
    public int getOrder() {
        return this.order;
    }

    //......

}
SafeGuardAdvisor实现了CallAroundAdvisor、StreamAroundAdvisor接口,其构造器可以输入sensitiveWords、failureResponse、order,其aroundCall及aroundStream方法主要是执行before逻辑,通过判断用户的输入是否包含sensitiveWords实现安全拦截,命中的话返回failureResponse,不继续往下执行。
QuestionAnswerAdvisor
org/springframework/ai/chat/client/advisor/QuestionAnswerAdvisor.java
/**
 * Advisor that retrieves context documents from a {@code VectorStore} and appends
 * them, through a template, to the user's text. The retrieved documents are also
 * exposed in the response metadata under {@link #RETRIEVED_DOCUMENTS}.
 */
public class QuestionAnswerAdvisor implements CallAroundAdvisor, StreamAroundAdvisor {

    public static final String RETRIEVED_DOCUMENTS = "qa_retrieved_documents";

    public static final String FILTER_EXPRESSION = "qa_filter_expression";

    private static final String DEFAULT_USER_TEXT_ADVISE = """
            Context information is below, surrounded by ---------------------
            ---------------------
            {question_answer_context}
            ---------------------
            Given the context and provided history information and not prior knowledge,
            reply to the user comment. If the answer is not in the context, inform
            the user that you can't answer the question.
            """;

    private static final int DEFAULT_ORDER = 0;

    private final VectorStore vectorStore;

    private final String userTextAdvise;

    private final SearchRequest searchRequest;

    private final boolean protectFromBlocking;

    private final int order;

    /**
     * Creates an advisor with a default search request, the default user text
     * template, blocking protection enabled and the default order.
     * @param vectorStore the vector store to query
     */
    public QuestionAnswerAdvisor(VectorStore vectorStore) {
        this(vectorStore, SearchRequest.builder().build(), DEFAULT_USER_TEXT_ADVISE, true, DEFAULT_ORDER);
    }

    /**
     * Creates an advisor with the default user text template, blocking protection
     * enabled and the default order.
     * @param vectorStore the vector store to query
     * @param searchRequest the base search request (portable filter expression syntax)
     */
    public QuestionAnswerAdvisor(VectorStore vectorStore, SearchRequest searchRequest) {
        this(vectorStore, searchRequest, DEFAULT_USER_TEXT_ADVISE, true, DEFAULT_ORDER);
    }

    /**
     * Creates an advisor with blocking protection enabled and the default order.
     * @param vectorStore the vector store to query
     * @param searchRequest the base search request (portable filter expression syntax)
     * @param userTextAdvise text appended to the user prompt; should contain the
     * {@code question_answer_context} placeholder
     */
    public QuestionAnswerAdvisor(VectorStore vectorStore, SearchRequest searchRequest, String userTextAdvise) {
        this(vectorStore, searchRequest, userTextAdvise, true, DEFAULT_ORDER);
    }

    /**
     * Creates an advisor with the default order.
     * @param vectorStore the vector store to query
     * @param searchRequest the base search request (portable filter expression syntax)
     * @param userTextAdvise text appended to the user prompt; should contain the
     * {@code question_answer_context} placeholder
     * @param protectFromBlocking when {@code true}, the streaming path runs the
     * retrieval step on {@code Schedulers.boundedElastic()}
     */
    public QuestionAnswerAdvisor(VectorStore vectorStore, SearchRequest searchRequest, String userTextAdvise,
            boolean protectFromBlocking) {
        this(vectorStore, searchRequest, userTextAdvise, protectFromBlocking, DEFAULT_ORDER);
    }

    /**
     * Creates a fully configured advisor.
     * @param vectorStore the vector store to query; must not be null
     * @param searchRequest the base search request; must not be null
     * @param userTextAdvise text appended to the user prompt; must have text and
     * should contain the {@code question_answer_context} placeholder
     * @param protectFromBlocking when {@code true}, the streaming path runs the
     * retrieval step on {@code Schedulers.boundedElastic()}
     * @param order the advisor order
     */
    public QuestionAnswerAdvisor(VectorStore vectorStore, SearchRequest searchRequest, String userTextAdvise,
            boolean protectFromBlocking, int order) {
        Assert.notNull(vectorStore, "The vectorStore must not be null!");
        Assert.notNull(searchRequest, "The searchRequest must not be null!");
        Assert.hasText(userTextAdvise, "The userTextAdvise must not be empty!");
        this.vectorStore = vectorStore;
        this.searchRequest = searchRequest;
        this.userTextAdvise = userTextAdvise;
        this.protectFromBlocking = protectFromBlocking;
        this.order = order;
    }

    public static Builder builder(VectorStore vectorStore) {
        return new Builder(vectorStore);
    }

    @Override
    public int getOrder() {
        return this.order;
    }

    @Override
    public String getName() {
        return this.getClass().getSimpleName();
    }

    /** Enriches the request, delegates to the chain and decorates the response. */
    @Override
    public AdvisedResponse aroundCall(AdvisedRequest advisedRequest, CallAroundAdvisorChain chain) {
        AdvisedRequest enriched = before(advisedRequest);
        return after(chain.nextAroundCall(enriched));
    }

    /**
     * Streaming variant. May be invoked from blocking threads (command line, Tomcat)
     * as well as from a non-blocking WebFlux dispatch, hence the optional hop onto
     * the bounded-elastic scheduler before the retrieval step runs.
     */
    @Override
    public Flux<AdvisedResponse> aroundStream(AdvisedRequest advisedRequest, StreamAroundAdvisorChain chain) {
        Flux<AdvisedResponse> upstream;
        if (this.protectFromBlocking) {
            upstream = Mono.just(advisedRequest)
                .publishOn(Schedulers.boundedElastic())
                .map(this::before)
                .flatMapMany(request -> chain.nextAroundStream(request));
        }
        else {
            upstream = chain.nextAroundStream(before(advisedRequest));
        }
        // Only the elements accepted by onFinishReason() get the documents attached.
        return upstream.map(response -> onFinishReason().test(response) ? after(response) : response);
    }

    /**
     * Queries the vector store with the rendered user text and returns a copy of the
     * request whose user text, user params and advise context carry the result.
     */
    private AdvisedRequest before(AdvisedRequest request) {
        var context = new HashMap<>(request.adviseContext());

        // Search the vector store using the rendered user text as the query.
        String query = new PromptTemplate(request.userText(), request.userParams()).render();
        var effectiveSearch = SearchRequest.from(this.searchRequest)
            .query(query)
            .filterExpression(doGetFilterExpression(context))
            .build();
        List<Document> documents = this.vectorStore.similaritySearch(effectiveSearch);

        // Keep the raw documents in the context so after() can expose them.
        context.put(RETRIEVED_DOCUMENTS, documents);

        // Join the document texts; this fills the template placeholder.
        String documentContext = documents.stream()
            .map(Document::getText)
            .collect(Collectors.joining(System.lineSeparator()));
        Map<String, Object> advisedUserParams = new HashMap<>(request.userParams());
        advisedUserParams.put("question_answer_context", documentContext);

        // Append the advise template to the user text (not the system text).
        String advisedUserText = request.userText() + System.lineSeparator() + this.userTextAdvise;

        return AdvisedRequest.from(request)
            .userText(advisedUserText)
            .userParams(advisedUserParams)
            .adviseContext(context)
            .build();
    }

    /**
     * Copies the chat response and adds the retrieved documents from the advise
     * context to its metadata under {@link #RETRIEVED_DOCUMENTS}.
     */
    private AdvisedResponse after(AdvisedResponse advisedResponse) {
        Object retrieved = advisedResponse.adviseContext().get(RETRIEVED_DOCUMENTS);
        ChatResponse.Builder responseBuilder = ChatResponse.builder().from(advisedResponse.response());
        responseBuilder.withMetadata(RETRIEVED_DOCUMENTS, retrieved);
        return new AdvisedResponse(responseBuilder.build(), advisedResponse.adviseContext());
    }

    //......

}
QuestionAnswerAdvisor实现了CallAroundAdvisor及StreamAroundAdvisor接口,其构造器要求输入VectorStore;其before方法先执行vectorStore.similaritySearch,将检索到的documents放入adviseContext(key为qa_retrieved_documents),并将文档内容拼接后作为question_answer_context添加到advisedUserParams,一起构建advisedRequest;其after方法从adviseContext中取出检索到的documents,以qa_retrieved_documents为key添加到chatResponse的metadata中
AbstractChatMemoryAdvisor
org/springframework/ai/chat/client/advisor/AbstractChatMemoryAdvisor.java
/**
 * Base class for chat memory advisors. Holds the memory store, the default
 * conversation id and retrieve size, and helpers that resolve the per-request values
 * from the advise context.
 *
 * @param <T> the type of the chat memory store
 */
public abstract class AbstractChatMemoryAdvisor<T> implements CallAroundAdvisor, StreamAroundAdvisor {

    /**
     * The key to retrieve the chat memory conversation id from the context.
     */
    public static final String CHAT_MEMORY_CONVERSATION_ID_KEY = "chat_memory_conversation_id";

    /**
     * The key to retrieve the chat memory response size from the context.
     */
    public static final String CHAT_MEMORY_RETRIEVE_SIZE_KEY = "chat_memory_response_size";

    /**
     * The default conversation id to use when no conversation id is provided.
     */
    public static final String DEFAULT_CHAT_MEMORY_CONVERSATION_ID = "default";

    /**
     * The default chat memory retrieve size to use when no retrieve size is provided.
     */
    public static final int DEFAULT_CHAT_MEMORY_RESPONSE_SIZE = 100;

    /**
     * The chat memory store.
     */
    protected final T chatMemoryStore;

    /**
     * The default conversation id.
     */
    protected final String defaultConversationId;

    /**
     * The default chat memory retrieve size.
     */
    protected final int defaultChatMemoryRetrieveSize;

    /**
     * Whether to protect from blocking.
     */
    private final boolean protectFromBlocking;

    /**
     * The order of the advisor.
     */
    private final int order;

    /**
     * Constructor to create a new {@link AbstractChatMemoryAdvisor} instance.
     * @param chatMemory the chat memory store
     */
    protected AbstractChatMemoryAdvisor(T chatMemory) {
        this(chatMemory, DEFAULT_CHAT_MEMORY_CONVERSATION_ID, DEFAULT_CHAT_MEMORY_RESPONSE_SIZE, true);
    }

    /**
     * Constructor to create a new {@link AbstractChatMemoryAdvisor} instance.
     * @param chatMemory the chat memory store
     * @param defaultConversationId the default conversation id
     * @param defaultChatMemoryRetrieveSize the default chat memory retrieve size
     * @param protectFromBlocking whether to protect from blocking
     */
    protected AbstractChatMemoryAdvisor(T chatMemory, String defaultConversationId, int defaultChatMemoryRetrieveSize,
            boolean protectFromBlocking) {
        this(chatMemory, defaultConversationId, defaultChatMemoryRetrieveSize, protectFromBlocking,
                Advisor.DEFAULT_CHAT_MEMORY_PRECEDENCE_ORDER);
    }

    /**
     * Constructor to create a new {@link AbstractChatMemoryAdvisor} instance.
     * @param chatMemory the chat memory store; must not be null
     * @param defaultConversationId the default conversation id; must have text
     * @param defaultChatMemoryRetrieveSize the default chat memory retrieve size;
     * must be greater than 0
     * @param protectFromBlocking whether to protect from blocking
     * @param order the order
     */
    protected AbstractChatMemoryAdvisor(T chatMemory, String defaultConversationId, int defaultChatMemoryRetrieveSize,
            boolean protectFromBlocking, int order) {
        Assert.notNull(chatMemory, "The chatMemory must not be null!");
        Assert.hasText(defaultConversationId, "The conversationId must not be empty!");
        Assert.isTrue(defaultChatMemoryRetrieveSize > 0, "The defaultChatMemoryRetrieveSize must be greater than 0!");
        this.chatMemoryStore = chatMemory;
        this.defaultConversationId = defaultConversationId;
        this.defaultChatMemoryRetrieveSize = defaultChatMemoryRetrieveSize;
        this.protectFromBlocking = protectFromBlocking;
        this.order = order;
    }

    @Override
    public String getName() {
        return this.getClass().getSimpleName();
    }

    @Override
    public int getOrder() {
        // by default the (Ordered.HIGHEST_PRECEDENCE + 1000) value is meant to give
        // this advisor a lower precedence than the internal Spring AI advisors. It
        // leaves room (1000 slots) for the user to plug in their own advisors with
        // higher priority.
        return this.order;
    }

    /**
     * Get the chat memory store.
     * @return the chat memory store
     */
    protected T getChatMemoryStore() {
        return this.chatMemoryStore;
    }

    /**
     * Resolve the conversation id for the current request.
     * @param context the advise context
     * @return the value stored under {@link #CHAT_MEMORY_CONVERSATION_ID_KEY}
     * rendered via {@code toString()}, or the default conversation id when the key
     * is absent or mapped to {@code null}
     */
    protected String doGetConversationId(Map<String, Object> context) {
        // A single get() covers both "key absent" and "key mapped to null"; the
        // latter previously caused a NullPointerException.
        Object conversationId = context.get(CHAT_MEMORY_CONVERSATION_ID_KEY);
        return (conversationId != null) ? conversationId.toString() : this.defaultConversationId;
    }

    /**
     * Resolve the chat memory retrieve size for the current request.
     * @param context the advise context
     * @return the value stored under {@link #CHAT_MEMORY_RETRIEVE_SIZE_KEY} parsed
     * as an int, or the default retrieve size when the key is absent or mapped to
     * {@code null}
     * @throws NumberFormatException if the stored value does not parse as an int
     */
    protected int doGetChatMemoryRetrieveSize(Map<String, Object> context) {
        Object retrieveSize = context.get(CHAT_MEMORY_RETRIEVE_SIZE_KEY);
        return (retrieveSize != null) ? Integer.parseInt(retrieveSize.toString())
                : this.defaultChatMemoryRetrieveSize;
    }

    /**
     * Apply the before-advise function and execute the next advisor in the chain.
     * When blocking protection is enabled, the function runs on
     * {@code Schedulers.boundedElastic()}; otherwise it runs on the calling thread.
     * @param advisedRequest the advised request
     * @param chain the advisor chain
     * @param beforeAdvise the before advise function
     * @return the advised response stream
     */
    protected Flux<AdvisedResponse> doNextWithProtectFromBlockingBefore(AdvisedRequest advisedRequest,
            StreamAroundAdvisorChain chain, Function<AdvisedRequest, AdvisedRequest> beforeAdvise) {
        // This can be executed by both blocking and non-blocking Threads
        // E.g. a command line or Tomcat blocking Thread implementation
        // or by a WebFlux dispatch in a non-blocking manner.
        return (this.protectFromBlocking) ?
        // @formatter:off
            Mono.just(advisedRequest)
                .publishOn(Schedulers.boundedElastic())
                .map(beforeAdvise)
                .flatMapMany(request -> chain.nextAroundStream(request))
            : chain.nextAroundStream(beforeAdvise.apply(advisedRequest));
        // @formatter:on
    }

    //......

}
AbstractChatMemoryAdvisor声明实现CallAroundAdvisor、StreamAroundAdvisor接口,它有三个实现类,分别是MessageChatMemoryAdvisor、PromptChatMemoryAdvisor、VectorStoreChatMemoryAdvisor
MessageChatMemoryAdvisor
/**
 * Chat memory advisor that adds the stored conversation messages to the request's
 * message list and records the new user and assistant messages in the
 * {@code ChatMemory}.
 */
public class MessageChatMemoryAdvisor extends AbstractChatMemoryAdvisor<ChatMemory> {

    /**
     * Creates an advisor using the base-class defaults.
     * @param chatMemory the chat memory
     */
    public MessageChatMemoryAdvisor(ChatMemory chatMemory) {
        super(chatMemory);
    }

    /**
     * Creates an advisor with the default chat-memory order.
     * @param chatMemory the chat memory
     * @param defaultConversationId conversation id used when the context has none
     * @param chatHistoryWindowSize default number of messages to retrieve
     */
    public MessageChatMemoryAdvisor(ChatMemory chatMemory, String defaultConversationId, int chatHistoryWindowSize) {
        this(chatMemory, defaultConversationId, chatHistoryWindowSize, Advisor.DEFAULT_CHAT_MEMORY_PRECEDENCE_ORDER);
    }

    /**
     * Creates a fully configured advisor; blocking protection is always enabled.
     * @param chatMemory the chat memory
     * @param defaultConversationId conversation id used when the context has none
     * @param chatHistoryWindowSize default number of messages to retrieve
     * @param order the advisor order
     */
    public MessageChatMemoryAdvisor(ChatMemory chatMemory, String defaultConversationId, int chatHistoryWindowSize,
            int order) {
        super(chatMemory, defaultConversationId, chatHistoryWindowSize, true, order);
    }

    public static Builder builder(ChatMemory chatMemory) {
        return new Builder(chatMemory);
    }

    @Override
    public AdvisedResponse aroundCall(AdvisedRequest advisedRequest, CallAroundAdvisorChain chain) {
        AdvisedRequest withMemory = this.before(advisedRequest);
        AdvisedResponse result = chain.nextAroundCall(withMemory);
        this.observeAfter(result);
        return result;
    }

    @Override
    public Flux<AdvisedResponse> aroundStream(AdvisedRequest advisedRequest, StreamAroundAdvisorChain chain) {
        Flux<AdvisedResponse> downstream = this.doNextWithProtectFromBlockingBefore(advisedRequest, chain,
                this::before);
        return new MessageAggregator().aggregateAdvisedResponse(downstream, this::observeAfter);
    }

    /**
     * Builds a request whose message list is the original messages followed by the
     * stored memory messages, then records the current user input in the memory.
     */
    private AdvisedRequest before(AdvisedRequest request) {
        String conversationId = this.doGetConversationId(request.adviseContext());
        int retrieveSize = this.doGetChatMemoryRetrieveSize(request.adviseContext());

        // Memory is read before the new user message is stored, so the current
        // input is not part of what gets appended here.
        List<Message> remembered = this.getChatMemoryStore().get(conversationId, retrieveSize);

        List<Message> combined = new ArrayList<>(request.messages());
        combined.addAll(remembered);
        AdvisedRequest withMemory = AdvisedRequest.from(request).messages(combined).build();

        // Record the new user input in the conversation memory.
        this.getChatMemoryStore()
            .add(this.doGetConversationId(request.adviseContext()),
                    new UserMessage(request.userText(), request.media()));
        return withMemory;
    }

    /** Stores the output of every generation in the response in the memory. */
    private void observeAfter(AdvisedResponse advisedResponse) {
        List<Message> outputs = advisedResponse.response()
            .getResults()
            .stream()
            .map(generation -> (Message) generation.getOutput())
            .toList();
        String conversationId = this.doGetConversationId(advisedResponse.adviseContext());
        this.getChatMemoryStore().add(conversationId, outputs);
    }

    /** Builder for {@link MessageChatMemoryAdvisor}. */
    public static class Builder extends AbstractChatMemoryAdvisor.AbstractBuilder<ChatMemory> {

        protected Builder(ChatMemory chatMemory) {
            super(chatMemory);
        }

        public MessageChatMemoryAdvisor build() {
            return new MessageChatMemoryAdvisor(this.chatMemory, this.conversationId, this.chatMemoryRetrieveSize,
                    this.order);
        }

    }

}
MessageChatMemoryAdvisor继承了AbstractChatMemoryAdvisor,其泛型为ChatMemory;其before方法先获取conversationId、chatMemoryRetrieveSize,之后从chatMemoryStore获取memoryMessages,然后将请求的messages与memoryMessages一起构造advisedMessages形成advisedRequest,同时将userMessage添加到chatMemoryStore中;其observeAfter方法将返回的assistantMessages添加到chatMemoryStore
PromptChatMemoryAdvisor
/**
 * Chat memory advisor that renders the stored USER and ASSISTANT messages as text
 * and injects them into the system prompt through a {@code memory} template
 * parameter.
 */
public class PromptChatMemoryAdvisor extends AbstractChatMemoryAdvisor<ChatMemory> {

    private static final String DEFAULT_SYSTEM_TEXT_ADVISE = """
            Use the conversation memory from the MEMORY section to provide accurate answers.
            ---------------------
            MEMORY:
            {memory}
            ---------------------
            """;

    private final String systemTextAdvise;

    /**
     * Creates an advisor with the default system text template.
     * @param chatMemory the chat memory
     */
    public PromptChatMemoryAdvisor(ChatMemory chatMemory) {
        this(chatMemory, DEFAULT_SYSTEM_TEXT_ADVISE);
    }

    /**
     * Creates an advisor with a custom system text template.
     * @param chatMemory the chat memory
     * @param systemTextAdvise text appended to the system prompt; the default
     * template uses a {@code memory} placeholder
     */
    public PromptChatMemoryAdvisor(ChatMemory chatMemory, String systemTextAdvise) {
        super(chatMemory);
        this.systemTextAdvise = systemTextAdvise;
    }

    /**
     * Creates an advisor with the default chat-memory order.
     * @param chatMemory the chat memory
     * @param defaultConversationId conversation id used when the context has none
     * @param chatHistoryWindowSize default number of messages to retrieve
     * @param systemTextAdvise text appended to the system prompt
     */
    public PromptChatMemoryAdvisor(ChatMemory chatMemory, String defaultConversationId, int chatHistoryWindowSize,
            String systemTextAdvise) {
        this(chatMemory, defaultConversationId, chatHistoryWindowSize, systemTextAdvise,
                Advisor.DEFAULT_CHAT_MEMORY_PRECEDENCE_ORDER);
    }

    /**
     * Creates a fully configured advisor; blocking protection is always enabled.
     * @param chatMemory the chat memory
     * @param defaultConversationId conversation id used when the context has none
     * @param chatHistoryWindowSize default number of messages to retrieve
     * @param systemTextAdvise text appended to the system prompt
     * @param order the advisor order
     */
    public PromptChatMemoryAdvisor(ChatMemory chatMemory, String defaultConversationId, int chatHistoryWindowSize,
            String systemTextAdvise, int order) {
        super(chatMemory, defaultConversationId, chatHistoryWindowSize, true, order);
        this.systemTextAdvise = systemTextAdvise;
    }

    public static Builder builder(ChatMemory chatMemory) {
        return new Builder(chatMemory);
    }

    @Override
    public AdvisedResponse aroundCall(AdvisedRequest advisedRequest, CallAroundAdvisorChain chain) {
        advisedRequest = this.before(advisedRequest);
        AdvisedResponse advisedResponse = chain.nextAroundCall(advisedRequest);
        this.observeAfter(advisedResponse);
        return advisedResponse;
    }

    @Override
    public Flux<AdvisedResponse> aroundStream(AdvisedRequest advisedRequest, StreamAroundAdvisorChain chain) {
        Flux<AdvisedResponse> advisedResponses = this.doNextWithProtectFromBlockingBefore(advisedRequest, chain,
                this::before);
        return new MessageAggregator().aggregateAdvisedResponse(advisedResponses, this::observeAfter);
    }

    /**
     * Renders the stored conversation into the {@code memory} system parameter,
     * appends the advise template to the system text, and records the current user
     * input in the memory.
     */
    private AdvisedRequest before(AdvisedRequest request) {
        // 1. Advise system parameters.
        List<Message> memoryMessages = this.getChatMemoryStore()
            .get(this.doGetConversationId(request.adviseContext()),
                    this.doGetChatMemoryRetrieveSize(request.adviseContext()));
        String memory = (memoryMessages != null) ? memoryMessages.stream()
            .filter(m -> m.getMessageType() == MessageType.USER || m.getMessageType() == MessageType.ASSISTANT)
            .map(m -> m.getMessageType() + ":" + ((Content) m).getText())
            .collect(Collectors.joining(System.lineSeparator())) : "";
        Map<String, Object> advisedSystemParams = new HashMap<>(request.systemParams());
        advisedSystemParams.put("memory", memory);

        // 2. Advise the system text. Without an existing system text, use the
        // template alone; concatenating would otherwise yield a literal "null"
        // prefix in the prompt.
        String systemText = request.systemText();
        String advisedSystemText = (systemText != null && !systemText.isBlank())
                ? systemText + System.lineSeparator() + this.systemTextAdvise : this.systemTextAdvise;

        // 3. Create a new request with the advised system text and parameters.
        AdvisedRequest advisedRequest = AdvisedRequest.from(request)
            .systemText(advisedSystemText)
            .systemParams(advisedSystemParams)
            .build();

        // 4. Add the new user input to the conversation memory.
        UserMessage userMessage = new UserMessage(request.userText(), request.media());
        this.getChatMemoryStore().add(this.doGetConversationId(request.adviseContext()), userMessage);
        return advisedRequest;
    }

    /** Stores the output of every generation in the response in the memory. */
    private void observeAfter(AdvisedResponse advisedResponse) {
        List<Message> assistantMessages = advisedResponse.response()
            .getResults()
            .stream()
            .map(g -> (Message) g.getOutput())
            .toList();
        this.getChatMemoryStore().add(this.doGetConversationId(advisedResponse.adviseContext()), assistantMessages);
    }

    //......

}
PromptChatMemoryAdvisor继承了AbstractChatMemoryAdvisor,其before方法先从chatMemoryStore获取memoryMessages,过滤出type为USER或者ASSISTANT的,作为memory加入到advisedSystemParams,一起构建AdvisedRequest,同时将userMessage添加到chatMemoryStore中;其observeAfter方法将返回的assistantMessages添加到chatMemoryStore
VectorStoreChatMemoryAdvisor
org/springframework/ai/chat/client/advisor/VectorStoreChatMemoryAdvisor.java
/**
 * Chat memory advisor backed by a {@code VectorStore}: it retrieves documents for
 * the current conversation by similarity search, injects their text into the system
 * prompt as {@code long_term_memory}, and writes new user and assistant messages
 * back to the store as documents.
 */
public class VectorStoreChatMemoryAdvisor extends AbstractChatMemoryAdvisor<VectorStore> {

    private static final String DOCUMENT_METADATA_CONVERSATION_ID = "conversationId";

    private static final String DOCUMENT_METADATA_MESSAGE_TYPE = "messageType";

    private static final String DEFAULT_SYSTEM_TEXT_ADVISE = """
            Use the long term conversation memory from the LONG_TERM_MEMORY section to provide accurate answers.
            ---------------------
            LONG_TERM_MEMORY:
            {long_term_memory}
            ---------------------
            """;

    private final String systemTextAdvise;

    /**
     * Creates an advisor with the default system text template.
     * @param vectorStore the vector store used as memory
     */
    public VectorStoreChatMemoryAdvisor(VectorStore vectorStore) {
        this(vectorStore, DEFAULT_SYSTEM_TEXT_ADVISE);
    }

    /**
     * Creates an advisor with a custom system text template.
     * @param vectorStore the vector store used as memory
     * @param systemTextAdvise text appended to the system prompt
     */
    public VectorStoreChatMemoryAdvisor(VectorStore vectorStore, String systemTextAdvise) {
        super(vectorStore);
        this.systemTextAdvise = systemTextAdvise;
    }

    /**
     * Creates an advisor with the default system text template and default order.
     * @param vectorStore the vector store used as memory
     * @param defaultConversationId conversation id used when the context has none
     * @param chatHistoryWindowSize default number of documents to retrieve
     */
    public VectorStoreChatMemoryAdvisor(VectorStore vectorStore, String defaultConversationId,
            int chatHistoryWindowSize) {
        this(vectorStore, defaultConversationId, chatHistoryWindowSize, DEFAULT_SYSTEM_TEXT_ADVISE);
    }

    /**
     * Creates an advisor with the default chat-memory order.
     * @param vectorStore the vector store used as memory
     * @param defaultConversationId conversation id used when the context has none
     * @param chatHistoryWindowSize default number of documents to retrieve
     * @param systemTextAdvise text appended to the system prompt
     */
    public VectorStoreChatMemoryAdvisor(VectorStore vectorStore, String defaultConversationId,
            int chatHistoryWindowSize, String systemTextAdvise) {
        this(vectorStore, defaultConversationId, chatHistoryWindowSize, systemTextAdvise,
                Advisor.DEFAULT_CHAT_MEMORY_PRECEDENCE_ORDER);
    }

    /**
     * Creates a fully configured advisor; blocking protection is always enabled.
     * @param vectorStore the vector store used for storing and querying documents
     * @param defaultConversationId conversation id used when the context has none
     * @param chatHistoryWindowSize default number of documents to retrieve
     * @param systemTextAdvise text appended to the system prompt
     * @param order the advisor order
     */
    public VectorStoreChatMemoryAdvisor(VectorStore vectorStore, String defaultConversationId,
            int chatHistoryWindowSize, String systemTextAdvise, int order) {
        super(vectorStore, defaultConversationId, chatHistoryWindowSize, true, order);
        this.systemTextAdvise = systemTextAdvise;
    }

    public static Builder builder(VectorStore chatMemory) {
        return new Builder(chatMemory);
    }

    @Override
    public AdvisedResponse aroundCall(AdvisedRequest advisedRequest, CallAroundAdvisorChain chain) {
        AdvisedRequest withMemory = this.before(advisedRequest);
        AdvisedResponse result = chain.nextAroundCall(withMemory);
        this.observeAfter(result);
        return result;
    }

    @Override
    public Flux<AdvisedResponse> aroundStream(AdvisedRequest advisedRequest, StreamAroundAdvisorChain chain) {
        Flux<AdvisedResponse> downstream = this.doNextWithProtectFromBlockingBefore(advisedRequest, chain,
                this::before);
        // The observeAfter will certainly be executed on non-blocking Threads in case
        // of some models - e.g. when the model client is a WebClient
        return new MessageAggregator().aggregateAdvisedResponse(downstream, this::observeAfter);
    }

    /**
     * Appends the advise template to the system text (or uses it alone when there is
     * no system text), fills {@code long_term_memory} from a similarity search scoped
     * to the conversation, and writes the current user message to the store.
     */
    private AdvisedRequest before(AdvisedRequest request) {
        String advisedSystemText = StringUtils.hasText(request.systemText())
                ? request.systemText() + System.lineSeparator() + this.systemTextAdvise : this.systemTextAdvise;

        // NOTE(review): the filter expression is built by string concatenation; a
        // conversation id containing a single quote would presumably break or alter
        // the filter - confirm against the filter expression syntax.
        var memorySearch = SearchRequest.builder()
            .query(request.userText())
            .topK(this.doGetChatMemoryRetrieveSize(request.adviseContext()))
            .filterExpression(
                    DOCUMENT_METADATA_CONVERSATION_ID + "=='" + this.doGetConversationId(request.adviseContext()) + "'")
            .build();
        List<Document> remembered = this.getChatMemoryStore().similaritySearch(memorySearch);

        String longTermMemory = remembered.stream()
            .map(Document::getText)
            .collect(Collectors.joining(System.lineSeparator()));
        Map<String, Object> advisedSystemParams = new HashMap<>(request.systemParams());
        advisedSystemParams.put("long_term_memory", longTermMemory);

        AdvisedRequest withMemory = AdvisedRequest.from(request)
            .systemText(advisedSystemText)
            .systemParams(advisedSystemParams)
            .build();

        // Persist the current user input after the search, so it is not retrieved
        // as part of its own memory.
        UserMessage userMessage = new UserMessage(request.userText(), request.media());
        List<Document> userDocuments = toDocuments(List.of(userMessage),
                this.doGetConversationId(request.adviseContext()));
        this.getChatMemoryStore().write(userDocuments);
        return withMemory;
    }

    /** Writes the output of every generation in the response to the vector store. */
    private void observeAfter(AdvisedResponse advisedResponse) {
        List<Message> outputs = advisedResponse.response()
            .getResults()
            .stream()
            .map(generation -> (Message) generation.getOutput())
            .toList();
        String conversationId = this.doGetConversationId(advisedResponse.adviseContext());
        this.getChatMemoryStore().write(toDocuments(outputs, conversationId));
    }

    /**
     * Converts the USER and ASSISTANT messages to documents tagged with the
     * conversation id and message type; other message types are dropped.
     */
    private List<Document> toDocuments(List<Message> messages, String conversationId) {
        return messages.stream()
            .filter(m -> m.getMessageType() == MessageType.USER || m.getMessageType() == MessageType.ASSISTANT)
            .map(message -> toDocument(message, conversationId))
            .toList();
    }

    /**
     * Converts a single message to a document carrying the message text and a copy
     * of its metadata plus the conversation id and message type.
     */
    private Document toDocument(Message message, String conversationId) {
        var metadata = new HashMap<>(message.getMetadata() != null ? message.getMetadata() : new HashMap<>());
        metadata.put(DOCUMENT_METADATA_CONVERSATION_ID, conversationId);
        metadata.put(DOCUMENT_METADATA_MESSAGE_TYPE, message.getMessageType().name());
        if (message instanceof UserMessage userMessage) {
            return Document.builder()
                .text(userMessage.getText())
                // userMessage.getMedia().get(0).getId()
                // TODO vector store for memory would not store this into the
                // vector store, could store an 'id' instead
                // .media(userMessage.getMedia())
                .metadata(metadata)
                .build();
        }
        if (message instanceof AssistantMessage assistantMessage) {
            return Document.builder().text(assistantMessage.getText()).metadata(metadata).build();
        }
        throw new RuntimeException("Unknown message type: " + message.getMessageType());
    }

    //......

}
VectorStoreChatMemoryAdvisor继承了AbstractChatMemoryAdvisor,其泛型为VectorStore;其before方法先构建searchRequest,从VectorStore中按conversationId过滤并获取topK的documents,将其文本作为long_term_memory加入到advisedSystemParams,一起构建advisedRequest,同时将userMessage写入到VectorStore;其observeAfter方法将返回的assistantMessages写入到VectorStore
DefaultChatClient
org/springframework/ai/chat/client/DefaultChatClient.java
/**
 * Runs the request through the around advisor chain and returns the chat response.
 * <p>
 * The request spec is first converted to an {@code AdvisedRequest}; a chain is then
 * built from the spec's chain builder and invoked. The chain terminates with the
 * last model call advisor.
 * @param inputRequestSpec the request spec providing the advisor chain builder
 * @param formatParam optional format parameter forwarded to the advised request
 * @param parentObservation the parent observation (not used in this method body)
 * @return the chat response carried by the advised response
 */
private ChatResponse doGetChatResponse(DefaultChatClientRequestSpec inputRequestSpec,
        @Nullable String formatParam, Observation parentObservation) {
    AdvisedRequest request = toAdvisedRequest(inputRequestSpec, formatParam);
    var advisorChain = inputRequestSpec.aroundAdvisorChainBuilder.build();
    return advisorChain.nextAroundCall(request).response();
}
DefaultChatClient的doGetChatResponse会构建DefaultAroundAdvisorChain然后执行其nextAroundCall方法
DefaultAroundAdvisorChain
org/springframework/ai/chat/client/advisor/DefaultAroundAdvisorChain.java
public class DefaultAroundAdvisorChain implements CallAroundAdvisorChain, StreamAroundAdvisorChain {
public static final AdvisorObservationConvention DEFAULT_OBSERVATION_CONVENTION = new DefaultAdvisorObservationConvention();
private final Deque<CallAroundAdvisor> callAroundAdvisors;
private final Deque<StreamAroundAdvisor> streamAroundAdvisors;
private final ObservationRegistry observationRegistry;
DefaultAroundAdvisorChain(ObservationRegistry observationRegistry, Deque<CallAroundAdvisor> callAroundAdvisors,
Deque<StreamAroundAdvisor> streamAroundAdvisors) {
Assert.notNull(observationRegistry, "the observationRegistry must be non-null");
Assert.notNull(callAroundAdvisors, "the callAroundAdvisors must be non-null");
Assert.notNull(streamAroundAdvisors, "the streamAroundAdvisors must be non-null");
this.observationRegistry = observationRegistry;
this.callAroundAdvisors = callAroundAdvisors;
this.streamAroundAdvisors = streamAroundAdvisors;
}
public static Builder builder(ObservationRegistry observationRegistry) {
return new Builder(observationRegistry);
}
@Override
public AdvisedResponse nextAroundCall(AdvisedRequest advisedRequest) {
if (this.callAroundAdvisors.isEmpty()) {
throw new IllegalStateException("No AroundAdvisor available to execute");
}
var advisor = this.callAroundAdvisors.pop();
var observationContext = AdvisorObservationContext.builder()
.advisorName(advisor.getName())
.advisorType(AdvisorObservationContext.Type.AROUND)
.advisedRequest(advisedRequest)
.advisorRequestContext(advisedRequest.adviseContext())
.order(advisor.getOrder())
.build();
return AdvisorObservationDocumentation.AI_ADVISOR
.observation(null, DEFAULT_OBSERVATION_CONVENTION, () -> observationContext, this.observationRegistry)
.observe(() -> advisor.aroundCall(advisedRequest, this));
}
@Override
public Flux<AdvisedResponse> nextAroundStream(AdvisedRequest advisedRequest) {
return Flux.deferContextual(contextView -> {
if (this.streamAroundAdvisors.isEmpty()) {
return Flux.error(new IllegalStateException("No AroundAdvisor available to execute"));
}
var advisor = this.streamAroundAdvisors.pop();
AdvisorObservationContext observationContext = AdvisorObservationContext.builder()
.advisorName(advisor.getName())
.advisorType(AdvisorObservationContext.Type.AROUND)
.advisedRequest(advisedRequest)
.advisorRequestContext(advisedRequest.adviseContext())
.order(advisor.getOrder())
.build();
var observation = AdvisorObservationDocumentation.AI_ADVISOR.observation(null,
DEFAULT_OBSERVATION_CONVENTION, () -> observationContext, this.observationRegistry);
observation.parentObservation(contextView.getOrDefault(ObservationThreadLocalAccessor.KEY, null)).start();
// @formatter:off
return Flux.defer(() -> advisor.aroundStream(advisedRequest, this))
.doOnError(observation::error)
.doFinally(s -> observation.stop())
.contextWrite(ctx -> ctx.put(ObservationThreadLocalAccessor.KEY, observation));
// @formatter:on
});
}
public static class Builder {
private final ObservationRegistry observationRegistry;
private final Deque<CallAroundAdvisor> callAroundAdvisors;
private final Deque<StreamAroundAdvisor> streamAroundAdvisors;
public Builder(ObservationRegistry observationRegistry) {
this.observationRegistry = observationRegistry;
this.callAroundAdvisors = new ConcurrentLinkedDeque<>();
this.streamAroundAdvisors = new ConcurrentLinkedDeque<>();
}
public Builder push(Advisor aroundAdvisor) {
Assert.notNull(aroundAdvisor, "the aroundAdvisor must be non-null");
return this.pushAll(List.of(aroundAdvisor));
}
public Builder pushAll(List<? extends Advisor> advisors) {
Assert.notNull(advisors, "the advisors must be non-null");
if (!CollectionUtils.isEmpty(advisors)) {
List<CallAroundAdvisor> callAroundAdvisorList = advisors.stream()
.filter(a -> a instanceof CallAroundAdvisor)
.map(a -> (CallAroundAdvisor) a)
.toList();
if (!CollectionUtils.isEmpty(callAroundAdvisorList)) {
callAroundAdvisorList.forEach(this.callAroundAdvisors::push);
}
List<StreamAroundAdvisor> streamAroundAdvisorList = advisors.stream()
.filter(a -> a instanceof StreamAroundAdvisor)
.map(a -> (StreamAroundAdvisor) a)
.toList();
if (!CollectionUtils.isEmpty(streamAroundAdvisorList)) {
streamAroundAdvisorList.forEach(this.streamAroundAdvisors::push);
}
this.reOrder();
}
return this;
}
/**
* (Re)orders the advisors in priority order based on their Ordered attribute.
*/
private void reOrder() {
ArrayList<CallAroundAdvisor> callAdvisors = new ArrayList<>(this.callAroundAdvisors);
OrderComparator.sort(callAdvisors);
this.callAroundAdvisors.clear();
callAdvisors.forEach(this.callAroundAdvisors::addLast);
ArrayList<StreamAroundAdvisor> streamAdvisors = new ArrayList<>(this.streamAroundAdvisors);
OrderComparator.sort(streamAdvisors);
this.streamAroundAdvisors.clear();
streamAdvisors.forEach(this.streamAroundAdvisors::addLast);
}
public DefaultAroundAdvisorChain build() {
return new DefaultAroundAdvisorChain(this.observationRegistry, this.callAroundAdvisors,
this.streamAroundAdvisors);
}
}
}
DefaultAroundAdvisorChain实现了CallAroundAdvisorChain, StreamAroundAdvisorChain接口,它用Deque类型存储了aroundAdvisors,其nextAroundCall、nextAroundStream会先pop出当前的advisor,然后执行其aroundCall、aroundStream方法,具体的处理逻辑由每个实现类各自实现。
其Builder的push方法每次都会执行reOrder方法对aroundAdvisors进行重新排序。
DefaultChatClientRequestSpec
org/springframework/ai/chat/client/DefaultChatClient.java
public static class DefaultChatClientRequestSpec implements ChatClientRequestSpec {
    //......

    /**
     * Creates a request spec by copying the supplied values into this instance.
     * <p>
     * After the caller-provided advisors are added, two terminal advisors with
     * {@code Ordered.LOWEST_PRECEDENCE} are appended, one invoking
     * {@code chatModel.call} and one invoking {@code chatModel.stream}. The
     * around advisor chain builder is then populated with all advisors.
     * <p>
     * All arguments except those marked {@code @Nullable} must be non-null.
     * When {@code chatOptions} is null, a copy of the model's default options is
     * used if the model provides any; when {@code customObservationConvention}
     * is null, the default chat client observation convention is used.
     */
    public DefaultChatClientRequestSpec(ChatModel chatModel, @Nullable String userText,
            Map<String, Object> userParams, @Nullable String systemText, Map<String, Object> systemParams,
            List<FunctionCallback> functionCallbacks, List<Message> messages, List<String> functionNames,
            List<Media> media, @Nullable ChatOptions chatOptions, List<Advisor> advisors,
            Map<String, Object> advisorParams, ObservationRegistry observationRegistry,
            @Nullable ChatClientObservationConvention customObservationConvention,
            Map<String, Object> toolContext) {

        // Argument validation.
        Assert.notNull(chatModel, "chatModel cannot be null");
        Assert.notNull(userParams, "userParams cannot be null");
        Assert.notNull(systemParams, "systemParams cannot be null");
        Assert.notNull(functionCallbacks, "functionCallbacks cannot be null");
        Assert.notNull(messages, "messages cannot be null");
        Assert.notNull(functionNames, "functionNames cannot be null");
        Assert.notNull(media, "media cannot be null");
        Assert.notNull(advisors, "advisors cannot be null");
        Assert.notNull(advisorParams, "advisorParams cannot be null");
        Assert.notNull(observationRegistry, "observationRegistry cannot be null");
        Assert.notNull(toolContext, "toolContext cannot be null");

        this.chatModel = chatModel;

        // Explicit options win; otherwise fall back to a copy of the model defaults.
        if (chatOptions != null) {
            this.chatOptions = chatOptions.copy();
        }
        else if (chatModel.getDefaultOptions() != null) {
            this.chatOptions = chatModel.getDefaultOptions().copy();
        }
        else {
            this.chatOptions = null;
        }

        this.userText = userText;
        this.userParams.putAll(userParams);
        this.systemText = systemText;
        this.systemParams.putAll(systemParams);
        this.functionNames.addAll(functionNames);
        this.functionCallbacks.addAll(functionCallbacks);
        this.messages.addAll(messages);
        this.media.addAll(media);
        this.advisors.addAll(advisors);
        this.advisorParams.putAll(advisorParams);
        this.observationRegistry = observationRegistry;

        if (customObservationConvention != null) {
            this.customObservationConvention = customObservationConvention;
        }
        else {
            this.customObservationConvention = DEFAULT_CHAT_CLIENT_OBSERVATION_CONVENTION;
        }

        this.toolContext.putAll(toolContext);

        // @formatter:off
        // The two advisors below sit at the bottom of the stack and play the role
        // of the last advisor in the around advisor chain: they call the model.
        CallAroundAdvisor modelCallAdvisor = new CallAroundAdvisor() {

            @Override
            public String getName() {
                return CallAroundAdvisor.class.getSimpleName();
            }

            @Override
            public int getOrder() {
                return Ordered.LOWEST_PRECEDENCE;
            }

            @Override
            public AdvisedResponse aroundCall(AdvisedRequest advisedRequest, CallAroundAdvisorChain chain) {
                ChatResponse modelResponse = chatModel.call(advisedRequest.toPrompt());
                return new AdvisedResponse(modelResponse, Collections.unmodifiableMap(advisedRequest.adviseContext()));
            }
        };
        this.advisors.add(modelCallAdvisor);

        StreamAroundAdvisor modelStreamAdvisor = new StreamAroundAdvisor() {

            @Override
            public String getName() {
                return StreamAroundAdvisor.class.getSimpleName();
            }

            @Override
            public int getOrder() {
                return Ordered.LOWEST_PRECEDENCE;
            }

            @Override
            public Flux<AdvisedResponse> aroundStream(AdvisedRequest advisedRequest, StreamAroundAdvisorChain chain) {
                Flux<ChatResponse> modelStream = chatModel.stream(advisedRequest.toPrompt());
                return modelStream
                    .map(chatResponse -> new AdvisedResponse(chatResponse, Collections.unmodifiableMap(advisedRequest.adviseContext())))
                    .publishOn(Schedulers.boundedElastic()); // TODO add option to disable.
            }
        };
        this.advisors.add(modelStreamAdvisor);
        // @formatter:on

        var chainBuilder = DefaultAroundAdvisorChain.builder(observationRegistry);
        this.aroundAdvisorChainBuilder = chainBuilder.pushAll(this.advisors);
    }
    //......
}
DefaultChatClientRequestSpec会以Ordered.LOWEST_PRECEDENCE的顺序加上最后一道advisor,来终止整个chain,避免调用chain.nextAroundCall(advisedRequest)导致抛出异常。
小结
Spring AI的Advisor提供了类似AOP的机制,核心接口是CallAroundAdvisor, StreamAroundAdvisor,它们有SimpleLoggerAdvisor、SafeGuardAdvisor、QuestionAnswerAdvisor、AbstractChatMemoryAdvisor(MessageChatMemoryAdvisor、PromptChatMemoryAdvisor、VectorStoreChatMemoryAdvisor)、BaseAdvisor(RetrievalAugmentationAdvisor)这些实现。
**粗体** _斜体_ [链接](http://example.com) `代码` - 列表 > 引用。你还可以使用@来通知其他用户。