序
本文主要研究一下Spring AI的RetrievalAugmentationAdvisor
BaseAdvisor
spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/api/BaseAdvisor.java
public interface BaseAdvisor extends CallAroundAdvisor, StreamAroundAdvisor {
Scheduler DEFAULT_SCHEDULER = Schedulers.boundedElastic();
@Override
default AdvisedResponse aroundCall(AdvisedRequest advisedRequest, CallAroundAdvisorChain chain) {
Assert.notNull(advisedRequest, "advisedRequest cannot be null");
Assert.notNull(chain, "chain cannot be null");
AdvisedRequest processedAdvisedRequest = before(advisedRequest);
AdvisedResponse advisedResponse = chain.nextAroundCall(processedAdvisedRequest);
return after(advisedResponse);
}
@Override
default Flux<AdvisedResponse> aroundStream(AdvisedRequest advisedRequest, StreamAroundAdvisorChain chain) {
Assert.notNull(advisedRequest, "advisedRequest cannot be null");
Assert.notNull(chain, "chain cannot be null");
Assert.notNull(getScheduler(), "scheduler cannot be null");
Flux<AdvisedResponse> advisedResponses = Mono.just(advisedRequest)
.publishOn(getScheduler())
.map(this::before)
.flatMapMany(chain::nextAroundStream);
return advisedResponses.map(ar -> {
if (AdvisedResponseStreamUtils.onFinishReason().test(ar)) {
ar = after(ar);
}
return ar;
}).onErrorResume(error -> Flux.error(new IllegalStateException("Stream processing failed", error)));
}
@Override
default String getName() {
return this.getClass().getSimpleName();
}
/**
* Logic to be executed before the rest of the advisor chain is called.
*/
AdvisedRequest before(AdvisedRequest request);
/**
* Logic to be executed after the rest of the advisor chain is called.
*/
AdvisedResponse after(AdvisedResponse advisedResponse);
/**
* Scheduler used for processing the advisor logic when streaming.
*/
default Scheduler getScheduler() {
return DEFAULT_SCHEDULER;
}
}
BaseAdvisor继承了CallAroundAdvisor、StreamAroundAdvisor,它为aroundCall、aroundStream提供了default实现:在执行chain的next之前执行before,之后执行after方法(流式场景下,after只会作用于满足AdvisedResponseStreamUtils.onFinishReason()的响应)
RetrievalAugmentationAdvisor
spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/RetrievalAugmentationAdvisor.java
/**
 * {@link BaseAdvisor} implementing a retrieval-augmentation flow: the user query is
 * transformed, optionally expanded into several queries, documents are retrieved for
 * each query, joined, and finally used to augment the user text before the rest of
 * the advisor chain is called. The joined documents are also exposed on the response
 * metadata under {@link #DOCUMENT_CONTEXT}.
 */
public final class RetrievalAugmentationAdvisor implements BaseAdvisor {

    /** Key used for the retrieved documents in both the advise context and the response metadata. */
    public static final String DOCUMENT_CONTEXT = "rag_document_context";

    private final List<QueryTransformer> queryTransformers;

    @Nullable
    private final QueryExpander queryExpander;

    private final DocumentRetriever documentRetriever;

    private final DocumentJoiner documentJoiner;

    private final QueryAugmenter queryAugmenter;

    private final TaskExecutor taskExecutor;

    private final Scheduler scheduler;

    private final int order;

    /**
     * Creates the advisor; only {@code documentRetriever} is mandatory.
     * @param queryTransformers transformers applied in order to the original query;
     * defaults to an empty list, must not contain null elements
     * @param queryExpander optional expander; when null the transformed query is used
     * as the single query
     * @param documentRetriever retriever invoked once per query, must not be null
     * @param documentJoiner defaults to a {@code ConcatenationDocumentJoiner}
     * @param queryAugmenter defaults to {@code ContextualQueryAugmenter.builder().build()}
     * @param taskExecutor executor running the retrievals; defaults to the pool created
     * by {@link #buildDefaultTaskExecutor()}
     * @param scheduler scheduler used when streaming; defaults to
     * {@link BaseAdvisor#DEFAULT_SCHEDULER}
     * @param order advisor order; defaults to 0
     */
    public RetrievalAugmentationAdvisor(@Nullable List<QueryTransformer> queryTransformers,
            @Nullable QueryExpander queryExpander, DocumentRetriever documentRetriever,
            @Nullable DocumentJoiner documentJoiner, @Nullable QueryAugmenter queryAugmenter,
            @Nullable TaskExecutor taskExecutor, @Nullable Scheduler scheduler, @Nullable Integer order) {
        Assert.notNull(documentRetriever, "documentRetriever cannot be null");
        Assert.noNullElements(queryTransformers, "queryTransformers cannot contain null elements");
        this.queryTransformers = queryTransformers != null ? queryTransformers : List.of();
        this.queryExpander = queryExpander;
        this.documentRetriever = documentRetriever;
        this.documentJoiner = documentJoiner != null ? documentJoiner : new ConcatenationDocumentJoiner();
        this.queryAugmenter = queryAugmenter != null ? queryAugmenter : ContextualQueryAugmenter.builder().build();
        this.taskExecutor = taskExecutor != null ? taskExecutor : buildDefaultTaskExecutor();
        this.scheduler = scheduler != null ? scheduler : BaseAdvisor.DEFAULT_SCHEDULER;
        this.order = order != null ? order : 0;
    }

    /**
     * @return a new {@code Builder} for this advisor
     */
    public static Builder builder() {
        return new Builder();
    }

    /**
     * Runs the retrieval-augmentation steps and returns a copy of the request whose
     * user text is the augmented query and whose advise context additionally holds
     * the joined documents under {@link #DOCUMENT_CONTEXT}.
     * @param request the incoming advised request
     * @return the advised request carrying the augmented user text and updated context
     */
    @Override
    public AdvisedRequest before(AdvisedRequest request) {
        Map<String, Object> context = new HashMap<>(request.adviseContext());

        // 0. Create a query from the user text, parameters, and conversation history.
        Query originalQuery = Query.builder()
            .text(new PromptTemplate(request.userText(), request.userParams()).render())
            .history(request.messages())
            .context(context)
            .build();

        // 1. Transform original user query based on a chain of query transformers.
        Query transformedQuery = originalQuery;
        for (var queryTransformer : this.queryTransformers) {
            transformedQuery = queryTransformer.apply(transformedQuery);
        }

        // 2. Expand query into one or multiple queries.
        List<Query> expandedQueries = this.queryExpander != null ? this.queryExpander.expand(transformedQuery)
                : List.of(transformedQuery);

        // 3. Get similar documents for each query.
        // distinct(): Collectors.toMap without a merge function throws
        // IllegalStateException on equal keys, so equal queries returned by the
        // expander are collapsed first (which also avoids retrieving twice for them).
        // toList().stream(): materializes the futures so that every retrieval has
        // been submitted to the executor before the first join() blocks.
        Map<Query, List<List<Document>>> documentsForQuery = expandedQueries.stream()
            .distinct()
            .map(query -> CompletableFuture.supplyAsync(() -> getDocumentsForQuery(query), this.taskExecutor))
            .toList()
            .stream()
            .map(CompletableFuture::join)
            .collect(Collectors.toMap(Map.Entry::getKey, entry -> List.of(entry.getValue())));

        // 4. Combine documents retrieved based on multiple queries and from multiple data
        // sources.
        List<Document> documents = this.documentJoiner.join(documentsForQuery);
        context.put(DOCUMENT_CONTEXT, documents);

        // 5. Augment user query with the document contextual data.
        // Note: the augmenter receives the original query, not the transformed one.
        Query augmentedQuery = this.queryAugmenter.augment(originalQuery, documents);

        // 6. Update advised request with augmented prompt.
        return AdvisedRequest.from(request).userText(augmentedQuery.text()).adviseContext(context).build();
    }

    /**
     * Processes a single query by routing it to document retrievers and collecting
     * documents.
     * @param query the query to retrieve documents for
     * @return an entry pairing the query with the documents retrieved for it
     */
    private Map.Entry<Query, List<Document>> getDocumentsForQuery(Query query) {
        List<Document> documents = this.documentRetriever.retrieve(query);
        return Map.entry(query, documents);
    }

    /**
     * Copies the value stored under {@link #DOCUMENT_CONTEXT} in the advise context
     * onto the chat response metadata. When the advised response carries no chat
     * response, a new one is built that holds only this metadata.
     * @param advisedResponse the response produced by the rest of the chain
     * @return a new advised response with the enriched chat response and the same
     * advise context
     */
    @Override
    public AdvisedResponse after(AdvisedResponse advisedResponse) {
        ChatResponse.Builder chatResponseBuilder;
        if (advisedResponse.response() == null) {
            chatResponseBuilder = ChatResponse.builder();
        }
        else {
            chatResponseBuilder = ChatResponse.builder().from(advisedResponse.response());
        }
        chatResponseBuilder.metadata(DOCUMENT_CONTEXT, advisedResponse.adviseContext().get(DOCUMENT_CONTEXT));
        return new AdvisedResponse(chatResponseBuilder.build(), advisedResponse.adviseContext());
    }

    /**
     * @return the scheduler configured for this advisor
     */
    @Override
    public Scheduler getScheduler() {
        return this.scheduler;
    }

    /**
     * @return the order configured for this advisor
     */
    @Override
    public int getOrder() {
        return this.order;
    }

    /**
     * Builds the executor used when none is supplied: a {@code ThreadPoolTaskExecutor}
     * with thread name prefix {@code "ai-advisor-"}, core pool size 4, max pool size 16
     * and a {@code ContextPropagatingTaskDecorator}.
     * @return the initialized task executor
     */
    private static TaskExecutor buildDefaultTaskExecutor() {
        ThreadPoolTaskExecutor taskExecutor = new ThreadPoolTaskExecutor();
        taskExecutor.setThreadNamePrefix("ai-advisor-");
        taskExecutor.setCorePoolSize(4);
        taskExecutor.setMaxPoolSize(16);
        taskExecutor.setTaskDecorator(new ContextPropagatingTaskDecorator());
        taskExecutor.initialize();
        return taskExecutor;
    }

    //......

}
RetrievalAugmentationAdvisor实现了BaseAdvisor接口,其before方法先通过queryTransformers来转换原始query,之后使用queryExpander来扩展query,接着通过documentRetriever来对扩展query检索相关文档,然后通过documentJoiner.join来把检索结果结合起来,作为rag_document_context放到context中,最后通过queryAugmenter.augment(originalQuery, documents)来生成最后的augmentedQuery;其after方法将before存储下来的key为rag_document_context的检索到的文档作为metadata附加到response中。
示例
/**
 * Example of a RAG call in which the user question is rewritten by a
 * {@code RewriteQueryTransformer} before documents are fetched through a
 * {@code VectorStoreDocumentRetriever}.
 */
@Test
void ragWithRewrite() {
    String question = "Where are the main characters going?";

    // The rewriting step gets its own chat client builder on the same chat model.
    var rewriter = RewriteQueryTransformer.builder()
        .chatClientBuilder(ChatClient.builder(this.openAiChatModel))
        .targetSearchSystem("vector store")
        .build();
    var retriever = VectorStoreDocumentRetriever.builder().vectorStore(this.pgVectorStore).build();

    RetrievalAugmentationAdvisor ragAdvisor = RetrievalAugmentationAdvisor.builder()
        .queryTransformers(rewriter)
        .documentRetriever(retriever)
        .build();

    ChatClient chatClient = ChatClient.builder(this.openAiChatModel).build();
    ChatResponse chatResponse = chatClient.prompt().user(question).advisors(ragAdvisor).call().chatResponse();

    assertThat(chatResponse).isNotNull();

    String response = chatResponse.getResult().getOutput().getText();
    System.out.println(response);
    assertThat(response).containsIgnoringCase("Loch of the Stars");

    evaluateRelevancy(question, chatResponse);
}
这里使用了RewriteQueryTransformer、VectorStoreDocumentRetriever
小结
Spring AI定义了BaseAdvisor,它继承了CallAroundAdvisor、StreamAroundAdvisor,并为aroundCall、aroundStream提供了default实现,在执行chain的next之前执行before,之后执行after方法。RetrievalAugmentationAdvisor实现了BaseAdvisor接口,其before方法先通过queryTransformers来转换原始query,之后使用queryExpander来扩展query,接着通过documentRetriever来对扩展query检索相关文档,然后通过documentJoiner.join来把检索结果结合起来,作为rag_document_context放到context中,最后通过queryAugmenter.augment(originalQuery, documents)来生成最后的augmentedQuery;其after方法将before存储下来的key为rag_document_context的检索到的文档作为metadata附加到response中。
**粗体** _斜体_ [链接](http://example.com) `代码` - 列表 > 引用
。你还可以使用@
来通知其他用户。