This article takes a look at Spring AI's RetrievalAugmentationAdvisor.

BaseAdvisor

spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/api/BaseAdvisor.java

public interface BaseAdvisor extends CallAroundAdvisor, StreamAroundAdvisor {

    Scheduler DEFAULT_SCHEDULER = Schedulers.boundedElastic();

    @Override
    default AdvisedResponse aroundCall(AdvisedRequest advisedRequest, CallAroundAdvisorChain chain) {
        Assert.notNull(advisedRequest, "advisedRequest cannot be null");
        Assert.notNull(chain, "chain cannot be null");

        AdvisedRequest processedAdvisedRequest = before(advisedRequest);
        AdvisedResponse advisedResponse = chain.nextAroundCall(processedAdvisedRequest);
        return after(advisedResponse);
    }

    @Override
    default Flux<AdvisedResponse> aroundStream(AdvisedRequest advisedRequest, StreamAroundAdvisorChain chain) {
        Assert.notNull(advisedRequest, "advisedRequest cannot be null");
        Assert.notNull(chain, "chain cannot be null");
        Assert.notNull(getScheduler(), "scheduler cannot be null");

        Flux<AdvisedResponse> advisedResponses = Mono.just(advisedRequest)
            .publishOn(getScheduler())
            .map(this::before)
            .flatMapMany(chain::nextAroundStream);

        return advisedResponses.map(ar -> {
            if (AdvisedResponseStreamUtils.onFinishReason().test(ar)) {
                ar = after(ar);
            }
            return ar;
        }).onErrorResume(error -> Flux.error(new IllegalStateException("Stream processing failed", error)));
    }

    @Override
    default String getName() {
        return this.getClass().getSimpleName();
    }

    /**
     * Logic to be executed before the rest of the advisor chain is called.
     */
    AdvisedRequest before(AdvisedRequest request);

    /**
     * Logic to be executed after the rest of the advisor chain is called.
     */
    AdvisedResponse after(AdvisedResponse advisedResponse);

    /**
     * Scheduler used for processing the advisor logic when streaming.
     */
    default Scheduler getScheduler() {
        return DEFAULT_SCHEDULER;
    }

}
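BaseAdvisor extends CallAroundAdvisor and StreamAroundAdvisor and provides default implementations of aroundCall and aroundStream. Both run before on the request, pass the result to the next advisor in the chain, and then run after on the response. In the streaming case, before is executed on the Scheduler returned by getScheduler() (Schedulers.boundedElastic() by default), after is applied only to the response chunk that carries a finish reason, and any error is wrapped in an IllegalStateException("Stream processing failed").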
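
To see the template-method shape of BaseAdvisor without pulling in Spring AI or Reactor, here is a minimal plain-Java sketch. Request, Response, MiniBaseAdvisor and UpperCaseAdvisor are hypothetical names invented for this illustration; the chain is modeled as a plain Function, and the streaming path uses a List of chunks in place of a Flux. It only mimics the control flow and is not the Spring AI API.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.Locale;
import java.util.function.Function;

// Hypothetical stand-ins for AdvisedRequest / AdvisedResponse (not Spring AI types).
record Request(String userText) {}
record Response(String text, boolean finished) {}

// Same shape as BaseAdvisor: implementors supply before/after,
// the default methods decide when they run around the rest of the chain.
interface MiniBaseAdvisor {

    Request before(Request request);

    Response after(Response response);

    // Call path: before -> next in chain -> after.
    default Response aroundCall(Request request, Function<Request, Response> chain) {
        Request processed = before(request);
        Response response = chain.apply(processed);
        return after(response);
    }

    // Streaming path: before runs once; after is applied only to the chunk
    // flagged as finished (analogous to the onFinishReason() check).
    default List<Response> aroundStream(Request request, Function<Request, List<Response>> chain) {
        List<Response> out = new ArrayList<>();
        for (Response chunk : chain.apply(before(request))) {
            out.add(chunk.finished() ? after(chunk) : chunk);
        }
        return out;
    }
}

// Example advisor: rewrites the request and tags the final response.
class UpperCaseAdvisor implements MiniBaseAdvisor {

    @Override
    public Request before(Request request) {
        return new Request(request.userText().toUpperCase(Locale.ROOT));
    }

    @Override
    public Response after(Response response) {
        return new Response(response.text() + " [checked]", response.finished());
    }
}

class MiniBaseAdvisorDemo {

    public static void main(String[] args) {
        MiniBaseAdvisor advisor = new UpperCaseAdvisor();

        Response single = advisor.aroundCall(new Request("hi"),
                req -> new Response("echo: " + req.userText(), true));
        System.out.println(single.text()); // prints "echo: HI [checked]"

        List<Response> chunks = advisor.aroundStream(new Request("hi"),
                req -> List.of(new Response("partial", false), new Response("done", true)));
        chunks.forEach(c -> System.out.println(c.text())); // prints "partial" then "done [checked]"
    }
}
```

The point of the design is that a concrete advisor only writes before and after; where they run relative to the rest of the chain lives in the default methods, which is what lets RetrievalAugmentationAdvisor focus purely on query processing.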

RetrievalAugmentationAdvisor

spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/RetrievalAugmentationAdvisor.java

public final class RetrievalAugmentationAdvisor implements BaseAdvisor {

    public static final String DOCUMENT_CONTEXT = "rag_document_context";

    private final List<QueryTransformer> queryTransformers;

    @Nullable
    private final QueryExpander queryExpander;

    private final DocumentRetriever documentRetriever;

    private final DocumentJoiner documentJoiner;

    private final QueryAugmenter queryAugmenter;

    private final TaskExecutor taskExecutor;

    private final Scheduler scheduler;

    private final int order;

    public RetrievalAugmentationAdvisor(@Nullable List<QueryTransformer> queryTransformers,
            @Nullable QueryExpander queryExpander, DocumentRetriever documentRetriever,
            @Nullable DocumentJoiner documentJoiner, @Nullable QueryAugmenter queryAugmenter,
            @Nullable TaskExecutor taskExecutor, @Nullable Scheduler scheduler, @Nullable Integer order) {
        Assert.notNull(documentRetriever, "documentRetriever cannot be null");
        Assert.noNullElements(queryTransformers, "queryTransformers cannot contain null elements");
        this.queryTransformers = queryTransformers != null ? queryTransformers : List.of();
        this.queryExpander = queryExpander;
        this.documentRetriever = documentRetriever;
        this.documentJoiner = documentJoiner != null ? documentJoiner : new ConcatenationDocumentJoiner();
        this.queryAugmenter = queryAugmenter != null ? queryAugmenter : ContextualQueryAugmenter.builder().build();
        this.taskExecutor = taskExecutor != null ? taskExecutor : buildDefaultTaskExecutor();
        this.scheduler = scheduler != null ? scheduler : BaseAdvisor.DEFAULT_SCHEDULER;
        this.order = order != null ? order : 0;
    }

    public static Builder builder() {
        return new Builder();
    }

    @Override
    public AdvisedRequest before(AdvisedRequest request) {
        Map<String, Object> context = new HashMap<>(request.adviseContext());

        // 0. Create a query from the user text, parameters, and conversation history.
        Query originalQuery = Query.builder()
            .text(new PromptTemplate(request.userText(), request.userParams()).render())
            .history(request.messages())
            .context(context)
            .build();

        // 1. Transform original user query based on a chain of query transformers.
        Query transformedQuery = originalQuery;
        for (var queryTransformer : this.queryTransformers) {
            transformedQuery = queryTransformer.apply(transformedQuery);
        }

        // 2. Expand query into one or multiple queries.
        List<Query> expandedQueries = this.queryExpander != null ? this.queryExpander.expand(transformedQuery)
                : List.of(transformedQuery);

        // 3. Get similar documents for each query.
        Map<Query, List<List<Document>>> documentsForQuery = expandedQueries.stream()
            .map(query -> CompletableFuture.supplyAsync(() -> getDocumentsForQuery(query), this.taskExecutor))
            .toList()
            .stream()
            .map(CompletableFuture::join)
            .collect(Collectors.toMap(Map.Entry::getKey, entry -> List.of(entry.getValue())));

        // 4. Combine documents retrieved based on multiple queries and from multiple data
        // sources.
        List<Document> documents = this.documentJoiner.join(documentsForQuery);
        context.put(DOCUMENT_CONTEXT, documents);

        // 5. Augment user query with the document contextual data.
        Query augmentedQuery = this.queryAugmenter.augment(originalQuery, documents);

        // 6. Update advised request with augmented prompt.
        return AdvisedRequest.from(request).userText(augmentedQuery.text()).adviseContext(context).build();
    }

    /**
     * Processes a single query by routing it to document retrievers and collecting
     * documents.
     */
    private Map.Entry<Query, List<Document>> getDocumentsForQuery(Query query) {
        List<Document> documents = this.documentRetriever.retrieve(query);
        return Map.entry(query, documents);
    }

    @Override
    public AdvisedResponse after(AdvisedResponse advisedResponse) {
        ChatResponse.Builder chatResponseBuilder;
        if (advisedResponse.response() == null) {
            chatResponseBuilder = ChatResponse.builder();
        }
        else {
            chatResponseBuilder = ChatResponse.builder().from(advisedResponse.response());
        }
        chatResponseBuilder.metadata(DOCUMENT_CONTEXT, advisedResponse.adviseContext().get(DOCUMENT_CONTEXT));
        return new AdvisedResponse(chatResponseBuilder.build(), advisedResponse.adviseContext());
    }

    @Override
    public Scheduler getScheduler() {
        return this.scheduler;
    }

    @Override
    public int getOrder() {
        return this.order;
    }

    private static TaskExecutor buildDefaultTaskExecutor() {
        ThreadPoolTaskExecutor taskExecutor = new ThreadPoolTaskExecutor();
        taskExecutor.setThreadNamePrefix("ai-advisor-");
        taskExecutor.setCorePoolSize(4);
        taskExecutor.setMaxPoolSize(16);
        taskExecutor.setTaskDecorator(new ContextPropagatingTaskDecorator());
        taskExecutor.initialize();
        return taskExecutor;
    }

    //......

}    
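RetrievalAugmentationAdvisor implements BaseAdvisor. Its before method first builds the original Query from the user text (rendered with the user parameters), the conversation history and the advise context, runs it through the queryTransformers in order, and expands the result into one or more queries with queryExpander (without an expander, the transformed query is used alone). Next, documentRetriever retrieves documents for every expanded query, with the retrievals submitted to taskExecutor in parallel; documentJoiner.join merges the results, and the merged list is put into the context under the key rag_document_context. Finally, queryAugmenter.augment(originalQuery, documents) produces the augmented query, whose text replaces the user text of the request. Note that augmentation is applied to the original query, not the transformed one. Its after method copies the documents stored under rag_document_context into the ChatResponse metadata, building an empty ChatResponse first if the response is null.

Only documentRetriever is required. The constructor otherwise falls back to a ConcatenationDocumentJoiner, a ContextualQueryAugmenter, a ThreadPoolTaskExecutor with the thread name prefix ai-advisor- (4 core and 16 max threads), BaseAdvisor.DEFAULT_SCHEDULER, and order 0.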
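
The steps of before can be approximated in plain Java. This is a sketch under stated assumptions, not Spring AI code: Query, Doc, RagResult and MiniRagPipeline are hypothetical, the pluggable stages are plain java.util.function types, deduplicating by document id in the join step is this sketch's own choice, and the prompt format in step 5 is made up rather than taken from ContextualQueryAugmenter.

```java
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.Executor;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.function.Function;
import java.util.function.UnaryOperator;

// Hypothetical, simplified stand-ins for Spring AI's Query and Document.
record Query(String text) {}
record Doc(String id, String text) {}
record RagResult(String prompt, List<Doc> documents) {}

final class MiniRagPipeline {

    private final List<UnaryOperator<Query>> transformers; // like queryTransformers
    private final Function<Query, List<Query>> expander;    // like queryExpander (may be null)
    private final Function<Query, List<Doc>> retriever;     // like documentRetriever
    private final Executor executor;                        // like taskExecutor

    MiniRagPipeline(List<UnaryOperator<Query>> transformers, Function<Query, List<Query>> expander,
            Function<Query, List<Doc>> retriever, Executor executor) {
        this.transformers = transformers;
        this.expander = expander;
        this.retriever = retriever;
        this.executor = executor;
    }

    RagResult before(String userText) {
        // 0. Build the original query.
        Query original = new Query(userText);

        // 1. Apply the transformers in order.
        Query transformed = original;
        for (UnaryOperator<Query> t : transformers) {
            transformed = t.apply(transformed);
        }

        // 2. Expand, or keep the single transformed query.
        List<Query> expanded = expander != null ? expander.apply(transformed) : List.of(transformed);

        // 3. Submit every retrieval before waiting on any, so they can run concurrently.
        List<CompletableFuture<List<Doc>>> futures = expanded.stream()
                .map(q -> CompletableFuture.supplyAsync(() -> retriever.apply(q), executor))
                .toList();

        // 4. Join: concatenate results, keeping the first copy of each document id.
        Map<String, Doc> joined = new LinkedHashMap<>();
        for (CompletableFuture<List<Doc>> future : futures) {
            for (Doc doc : future.join()) {
                joined.putIfAbsent(doc.id(), doc);
            }
        }
        List<Doc> documents = List.copyOf(joined.values());

        // 5. Augment the ORIGINAL query text with the retrieved context.
        StringBuilder context = new StringBuilder();
        for (Doc doc : documents) {
            context.append("- ").append(doc.text()).append('\n');
        }
        return new RagResult("Context:\n" + context + "Question: " + original.text(), documents);
    }
}

class MiniRagPipelineDemo {

    public static void main(String[] args) {
        ExecutorService pool = Executors.newFixedThreadPool(4);
        try {
            List<UnaryOperator<Query>> transformers = List.of(q -> new Query(q.text().toLowerCase(Locale.ROOT)));
            MiniRagPipeline pipeline = new MiniRagPipeline(transformers,
                    q -> List.of(q, new Query(q.text() + "?")),
                    q -> q.text().endsWith("?")
                            ? List.of(new Doc("1", "Loch"), new Doc("2", "Stars"))
                            : List.of(new Doc("1", "Loch")),
                    pool);
            System.out.println(pipeline.before("Where").prompt());
        } finally {
            pool.shutdown();
        }
    }
}
```

Two details carry over from the real advisor: all retrievals are submitted before any join(), so with a thread pool they overlap, and augmentation uses the original query text ("Where"), while the transformed and expanded queries only steer retrieval.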

Example

    @Test
    void ragWithRewrite() {
        String question = "Where are the main characters going?";

        RetrievalAugmentationAdvisor ragAdvisor = RetrievalAugmentationAdvisor.builder()
            .queryTransformers(RewriteQueryTransformer.builder()
                .chatClientBuilder(ChatClient.builder(this.openAiChatModel))
                .targetSearchSystem("vector store")
                .build())
            .documentRetriever(VectorStoreDocumentRetriever.builder().vectorStore(this.pgVectorStore).build())
            .build();

        ChatResponse chatResponse = ChatClient.builder(this.openAiChatModel)
            .build()
            .prompt()
            .user(question)
            .advisors(ragAdvisor)
            .call()
            .chatResponse();

        assertThat(chatResponse).isNotNull();

        String response = chatResponse.getResult().getOutput().getText();
        System.out.println(response);
        assertThat(response).containsIgnoringCase("Loch of the Stars");

        evaluateRelevancy(question, chatResponse);
    }
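This test uses RewriteQueryTransformer, which has the chat model rewrite the user question into a query better suited to the target search system (here "vector store"), and VectorStoreDocumentRetriever, which retrieves documents from the PGvector store. Everything else (joiner, augmenter, executor, scheduler, order) falls back to the defaults set in the constructor.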
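
VectorStoreDocumentRetriever delegates the actual search to the configured VectorStore (PGvector in this test). As a rough, self-contained illustration of what a similarity search does, the following toy retriever ranks in-memory chunks by cosine similarity and keeps at most topK results at or above a threshold. Chunk, Scored and ToyVectorRetriever are hypothetical; the embeddings are hand-written 2-D vectors rather than the output of an embedding model, and a real vector store would use an index instead of scanning every entry.

```java
import java.util.List;

// Hypothetical in-memory entries: text plus a precomputed embedding vector.
record Chunk(String text, double[] embedding) {}
record Scored(Chunk chunk, double score) {}

final class ToyVectorRetriever {

    private final List<Chunk> store;
    private final int topK;         // max number of results
    private final double threshold; // minimum cosine similarity to keep

    ToyVectorRetriever(List<Chunk> store, int topK, double threshold) {
        this.store = store;
        this.topK = topK;
        this.threshold = threshold;
    }

    // Cosine similarity: dot(a, b) / (|a| * |b|); 0 if either vector is all zeros.
    static double cosine(double[] a, double[] b) {
        double dot = 0, normA = 0, normB = 0;
        for (int i = 0; i < a.length; i++) {
            dot += a[i] * b[i];
            normA += a[i] * a[i];
            normB += b[i] * b[i];
        }
        return normA == 0 || normB == 0 ? 0 : dot / (Math.sqrt(normA) * Math.sqrt(normB));
    }

    // Score every chunk, drop those under the threshold, return the best topK texts.
    List<String> retrieve(double[] queryEmbedding) {
        return store.stream()
                .map(c -> new Scored(c, cosine(queryEmbedding, c.embedding())))
                .filter(s -> s.score() >= threshold)
                .sorted((x, y) -> Double.compare(y.score(), x.score()))
                .limit(topK)
                .map(s -> s.chunk().text())
                .toList();
    }
}

class ToyVectorRetrieverDemo {

    public static void main(String[] args) {
        ToyVectorRetriever retriever = new ToyVectorRetriever(List.of(
                new Chunk("Loch of the Stars", new double[] {1, 0}),
                new Chunk("Castle", new double[] {0, 1}),
                new Chunk("Loch shore", new double[] {0.8, 0.6})), 2, 0.5);
        System.out.println(retriever.retrieve(new double[] {1, 0}));
        // prints [Loch of the Stars, Loch shore]
    }
}
```

"Castle" is orthogonal to the query vector (similarity 0) and is filtered out; the remaining two are returned best first.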

Summary

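Spring AI defines BaseAdvisor, which extends CallAroundAdvisor and StreamAroundAdvisor and provides default aroundCall and aroundStream implementations that run before ahead of the next advisor in the chain and after once it returns. RetrievalAugmentationAdvisor implements BaseAdvisor: its before method transforms the original query with queryTransformers, expands it with queryExpander, retrieves documents for each expanded query with documentRetriever, merges them with documentJoiner.join and stores them in the context as rag_document_context, and finally calls queryAugmenter.augment(originalQuery, documents) to build the augmented query; its after method attaches the retrieved documents stored under rag_document_context to the response metadata.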

codecraft