This article takes a look at RAG (Retrieval-Augmented Generation) in Spring AI.

Sequential RAG Flows

Naive RAG

Advisor retrievalAugmentationAdvisor = RetrievalAugmentationAdvisor.builder()
        .documentRetriever(VectorStoreDocumentRetriever.builder()
                .similarityThreshold(0.50)
                .vectorStore(vectorStore)
                .build())
        .queryAugmenter(ContextualQueryAugmenter.builder()
                .allowEmptyContext(true)
                .build())
        .build();

String answer = chatClient.prompt()
        .advisors(retrievalAugmentationAdvisor)
        .user(question)
        .call()
        .content();
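Setting allowEmptyContext to true lets the model answer even when no relevant documents are retrieved. With the default (false), the augmenter instead instructs the model not to answer when the retrieved context is empty.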

Advanced RAG

Advisor retrievalAugmentationAdvisor = RetrievalAugmentationAdvisor.builder()
        .queryTransformers(RewriteQueryTransformer.builder()
                .chatClientBuilder(chatClientBuilder.build().mutate())
                .build())
        .documentRetriever(VectorStoreDocumentRetriever.builder()
                .similarityThreshold(0.50)
                .vectorStore(vectorStore)
                .build())
        .build();

String answer = chatClient.prompt()
        .advisors(retrievalAugmentationAdvisor)
        .user(question)
        .call()
        .content();
Advanced RAG adds queryTransformers, which transform the user query before retrieval (here, an LLM rewrites it).

Modular RAG

Inspired by the paper "Modular RAG: Transforming RAG Systems into LEGO-like Reconfigurable Frameworks", Spring AI implements Modular RAG, which is organized into the following stages: Pre-Retrieval, Retrieval, Post-Retrieval, and Generation.

Pre-Retrieval

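Enhances and transforms the user input so that retrieval is more effective, addressing problems such as poorly formed queries, semantically ambiguous queries, or unsupported languages.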

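1. QueryAugmenter (query augmentation)

Augments the user query with additional contextual data, giving the model the context it needs to answer. Strictly speaking this step runs after retrieval and consumes the retrieved documents; Spring AI itself places QueryAugmenter in the Generation stage (package org.springframework.ai.rag.generation.augmentation).
  • ContextualQueryAugmenter augments the query with the content of the retrieved documents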

    QueryAugmenter augmenter = ContextualQueryAugmenter.builder()
          .allowEmptyContext(false) // if documents is empty, instruct the model not to answer
          .build();
    Query augmentedQuery = augmenter.augment(query, documents);
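To make the augmentation step concrete, here is a plain-Java sketch with no Spring AI dependency. It assumes documents are plain strings and uses hypothetical prompt wording (Spring AI ships its own templates). It shows the two paths controlled by allowEmptyContext: with retrieved documents, their text is stuffed into the prompt; with none, the query either passes through unchanged or is replaced by a "do not answer" instruction.

```java
import java.util.List;

// Plain-Java sketch of context-based query augmentation (no Spring AI dependency).
// Both prompt texts are hypothetical; Spring AI ships its own templates.
final class ContextAugmenterSketch {

    private static final String CONTEXT_TEMPLATE = """
            Context information is below.
            ---------------------
            %s
            ---------------------
            Answer the query using only the context above.
            Query: %s
            """;

    private static final String EMPTY_CONTEXT_PROMPT =
            "The user query is outside your knowledge base. Politely say that you cannot answer it.";

    static String augment(String query, List<String> documents, boolean allowEmptyContext) {
        if (documents.isEmpty()) {
            // allowEmptyContext = true: pass the query through so the model can still answer.
            // allowEmptyContext = false: replace it with an instruction not to answer.
            return allowEmptyContext ? query : EMPTY_CONTEXT_PROMPT;
        }
        // Stuff the retrieved document texts into the prompt, one per line.
        String context = String.join("\n", documents);
        return CONTEXT_TEMPLATE.formatted(context, query);
    }
}
```

Calling `augment("What is RAG?", List.of(), true)` simply returns the original question, while the same call with `false` returns the refusal instruction.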

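2. QueryTransformer (query transformation)

User input is often incomplete and carries little key information, which makes it hard for the model to understand and answer. The query therefore needs to be improved, either through prompt tuning or by having an LLM rewrite it.
When using a QueryTransformer, it is recommended to configure the model with a low temperature (e.g. 0.0) so that the results are more deterministic and accurate.
There are three implementations: CompressionQueryTransformer, RewriteQueryTransformer, and TranslationQueryTransformer.
  • CompressionQueryTransformer uses an LLM to compress the conversation history and the follow-up query into a standalone query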

    Query query = Query.builder()
          .text("And what is its second largest city?")
          .history(new UserMessage("What is the capital of Denmark?"),
                  new AssistantMessage("Copenhagen is the capital of Denmark."))
          .build();
    
    QueryTransformer queryTransformer = CompressionQueryTransformer.builder()
          .chatClientBuilder(chatClientBuilder)
          .build();
    
    Query transformedQuery = queryTransformer.transform(query);
  • RewriteQueryTransformer uses an LLM to rewrite the query so that it produces better results against the target search system

    Query query = new Query("I'm studying machine learning. What is an LLM?");
    
    QueryTransformer queryTransformer = RewriteQueryTransformer.builder()
          .chatClientBuilder(chatClientBuilder)
          .build();
    
    Query transformedQuery = queryTransformer.transform(query);
  • TranslationQueryTransformer uses an LLM to translate the query into a target language (for example, one supported by the embedding model)

    Query query = new Query("Hvad er Danmarks hovedstad?");
    
    QueryTransformer queryTransformer = TranslationQueryTransformer.builder()
          .chatClientBuilder(chatClientBuilder)
          .targetLanguage("english")
          .build();
    
    Query transformedQuery = queryTransformer.transform(query);
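All three transformers follow one pattern: wrap the query in a prompt and let the model return a better query. A minimal plain-Java sketch of the compression case, assuming hypothetical prompt wording and a `UnaryOperator<String>` standing in for the chat model call (the record `Turn` is likewise illustrative, not a Spring AI type):

```java
import java.util.List;
import java.util.function.UnaryOperator;

// Sketch of an LLM-backed query compression step; all names and prompt text are hypothetical.
final class CompressionSketch {

    record Turn(String role, String text) {}

    // Render the history plus the follow-up query into a single prompt.
    static String buildPrompt(List<Turn> history, String followUp) {
        StringBuilder sb = new StringBuilder(
                "Given the conversation history and a follow-up query, "
                + "rewrite the follow-up as a standalone query.\n\nHistory:\n");
        for (Turn turn : history) {
            sb.append(turn.role()).append(": ").append(turn.text()).append('\n');
        }
        sb.append("\nFollow-up query: ").append(followUp).append("\nStandalone query:");
        return sb.toString();
    }

    // 'model' stands in for the chat model call (ideally configured with a low temperature).
    static String compress(List<Turn> history, String followUp, UnaryOperator<String> model) {
        return model.apply(buildPrompt(history, followUp)).strip();
    }
}
```

Swapping the prompt text is all it takes to turn the same skeleton into a rewrite or translation step, which is why the three Spring AI implementations share a builder shape.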

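3. QueryExpander (query expansion)

Expands the user query into several semantically diverse variants to capture different perspectives, which helps retrieve additional context and increases the chance of finding relevant results.
  • MultiQueryExpander uses an LLM to expand the query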

    MultiQueryExpander queryExpander = MultiQueryExpander.builder()
      .chatClientBuilder(chatClientBuilder)
      .numberOfQueries(3)
      .includeOriginal(false) // false: return only the generated variants, not the original query
      .build();
    List<Query> queries = queryExpander.expand(new Query("How to run a Spring Boot app?"));
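MultiQueryExpander asks the model for several variants and then turns the reply into a list of queries. The sketch below covers only that post-processing step in plain Java, assuming the model returns one variant per line. Falling back to the original query when the model returns the wrong number of variants is an assumption of this sketch, not documented Spring AI behavior.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

// Sketch: parse a newline-separated model reply into query variants.
final class MultiQuerySketch {

    static List<String> expand(String original, String modelResponse,
                               int numberOfQueries, boolean includeOriginal) {
        List<String> variants = Arrays.stream(modelResponse.split("\n"))
                .map(String::strip)
                .filter(line -> !line.isEmpty())
                .toList();
        // Assumption: a malformed reply degrades gracefully to the original query.
        if (variants.size() != numberOfQueries) {
            return List.of(original);
        }
        List<String> result = new ArrayList<>();
        if (includeOriginal) {
            result.add(original);
        }
        result.addAll(variants);
        return result;
    }
}
```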

Retrieval

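Responsible for querying data systems such as vector stores and retrieving the Documents most relevant to the user query.

1. DocumentRetriever (retriever)

Retrieves documents for each query (for example, each variant produced by the QueryExpander) from a data source such as a search engine, vector store, database, or knowledge graph. Its main implementations are VectorStoreDocumentRetriever and WebSearchRetriever.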
  • VectorStoreDocumentRetriever

    DocumentRetriever retriever = VectorStoreDocumentRetriever.builder()
      .vectorStore(vectorStore)
      .similarityThreshold(0.73)
      .topK(5)
      .filterExpression(new FilterExpressionBuilder()
          .eq("genre", "fairytale")
          .build())
      .build();
    List<Document> documents = retriever.retrieve(new Query("What is the main character of the story?"));
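Conceptually, the retriever configured above combines a metadata filter, a similarity threshold, and a top-K cut. Below is an in-memory plain-Java sketch using cosine similarity, assuming hand-made embeddings and a single equality filter; `Doc`, `Hit`, and `search` are hypothetical names. A real VectorStore performs this search itself, and its score definition may differ.

```java
import java.util.Comparator;
import java.util.List;
import java.util.Map;

// In-memory sketch of threshold + topK + metadata-filter retrieval.
final class VectorSearchSketch {

    record Doc(String id, double[] embedding, Map<String, String> metadata) {}

    record Hit(Doc doc, double score) {}

    static double cosine(double[] a, double[] b) {
        double dot = 0, normA = 0, normB = 0;
        for (int i = 0; i < a.length; i++) {
            dot += a[i] * b[i];
            normA += a[i] * a[i];
            normB += b[i] * b[i];
        }
        return dot / (Math.sqrt(normA) * Math.sqrt(normB));
    }

    // Filter on metadata, score, drop hits below the threshold, keep the best topK.
    static List<Hit> search(List<Doc> docs, double[] queryEmbedding, double threshold, int topK,
                            String filterKey, String filterValue) {
        return docs.stream()
                .filter(d -> filterValue.equals(d.metadata().get(filterKey)))
                .map(d -> new Hit(d, cosine(d.embedding(), queryEmbedding)))
                .filter(h -> h.score() >= threshold)
                .sorted(Comparator.comparingDouble((Hit h) -> h.score()).reversed())
                .limit(topK)
                .toList();
    }
}
```

Raising the threshold trades recall for precision, while topK caps how much context ends up in the prompt.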

2. DocumentJoiner

Merges the Documents retrieved for multiple queries and from multiple data sources into a single collection of Documents. Its implementation is ConcatenationDocumentJoiner.
  • ConcatenationDocumentJoiner

    Map<Query, List<List<Document>>> documentsForQuery = ...
    DocumentJoiner documentJoiner = new ConcatenationDocumentJoiner();
    List<Document> documents = documentJoiner.join(documentsForQuery);
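A minimal plain-Java sketch of concatenation with de-duplication, assuming a document is identified by its id and that the first occurrence wins. Queries are plain strings here instead of Spring AI's Query type, and `Doc` is a hypothetical record.

```java
import java.util.ArrayList;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;

// Sketch: flatten query -> data sources -> documents into one de-duplicated list.
final class ConcatJoinSketch {

    record Doc(String id, String text) {}

    static List<Doc> join(Map<String, List<List<Doc>>> documentsForQuery) {
        // LinkedHashMap keeps insertion order; putIfAbsent keeps the first occurrence of an id.
        Map<String, Doc> unique = new LinkedHashMap<>();
        for (List<List<Doc>> perSource : documentsForQuery.values()) {
            for (List<Doc> docs : perSource) {
                for (Doc doc : docs) {
                    unique.putIfAbsent(doc.id(), doc);
                }
            }
        }
        return new ArrayList<>(unique.values());
    }
}
```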

Post-Retrieval

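Processes the retrieved Documents to get the best possible output, addressing problems such as the model getting "lost in the middle" of a long context and the limits of the context window.
  1. DocumentRanker: sorts and ranks Documents by their relevance to the user query;
  2. DocumentSelector: removes irrelevant or redundant documents from the retrieved list;
  3. DocumentCompressor: compresses each Document to reduce noise and redundancy in the retrieved information.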
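The three post-retrieval steps can be illustrated with a naive plain-Java sketch: ranking by retrieval score, selecting by a minimum score with de-duplication by text, and a crude keyword-based extractive compression. All three heuristics are assumptions for illustration; real rankers and compressors usually rely on a re-ranking model or a chat model.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Comparator;
import java.util.HashSet;
import java.util.List;
import java.util.Locale;
import java.util.Set;
import java.util.stream.Collectors;

// Naive illustrations of ranking, selection, and compression; heuristics are hypothetical.
final class PostRetrievalSketch {

    record Doc(String text, double score) {}

    // Ranking: highest score first.
    static List<Doc> rank(List<Doc> docs) {
        return docs.stream()
                .sorted(Comparator.comparingDouble((Doc d) -> d.score()).reversed())
                .toList();
    }

    // Selection: drop low-scoring and duplicate texts, keep at most maxDocs.
    static List<Doc> select(List<Doc> ranked, double minScore, int maxDocs) {
        Set<String> seen = new HashSet<>();
        List<Doc> selected = new ArrayList<>();
        for (Doc doc : ranked) {
            if (selected.size() == maxDocs) {
                break;
            }
            if (doc.score() >= minScore && seen.add(doc.text())) {
                selected.add(doc);
            }
        }
        return selected;
    }

    // Compression: keep only sentences mentioning a query term longer than three characters.
    static Doc compress(Doc doc, String query) {
        List<String> terms = Arrays.stream(query.toLowerCase(Locale.ROOT).split("\\W+"))
                .filter(term -> term.length() > 3)
                .toList();
        String kept = Arrays.stream(doc.text().split("(?<=[.!?])\\s+"))
                .filter(sentence -> {
                    String lower = sentence.toLowerCase(Locale.ROOT);
                    return terms.stream().anyMatch(lower::contains);
                })
                .collect(Collectors.joining(" "));
        return new Doc(kept, doc.score());
    }
}
```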

Generation

Generates the LLM output for the user query.

Source code

org/springframework/ai/chat/client/advisor/RetrievalAugmentationAdvisor.java

    public static final class Builder {

        private List<QueryTransformer> queryTransformers;

        private QueryExpander queryExpander;

        private DocumentRetriever documentRetriever;

        private DocumentJoiner documentJoiner;

        private QueryAugmenter queryAugmenter;

        private TaskExecutor taskExecutor;

        private Scheduler scheduler;

        private Integer order;

        private Builder() {
        }

        //......
    }    
The Builder of RetrievalAugmentationAdvisor lets you configure the Pre-Retrieval components (queryTransformers, queryExpander), the Retrieval components (documentRetriever, documentJoiner), and the Generation component (queryAugmenter).
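How these components fit together can be sketched in plain Java, with functional interfaces standing in for the Spring AI types (all names here are hypothetical). The order follows the stages described above: transformers run in sequence, the expander fans out, retrieval runs concurrently per query (the Builder's TaskExecutor appears to serve this purpose), results are joined, and the original user query is augmented with the joined documents. This is a sketch of the flow, not the advisor's actual code.

```java
import java.util.ArrayList;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.function.BiFunction;
import java.util.function.Function;
import java.util.function.UnaryOperator;

// Hypothetical end-to-end flow of a modular RAG advisor, using strings for queries and documents.
final class RagPipelineSketch {

    static String run(String userQuery,
                      List<UnaryOperator<String>> queryTransformers,
                      Function<String, List<String>> queryExpander,
                      Function<String, List<String>> documentRetriever,
                      BiFunction<String, List<String>, String> queryAugmenter) {
        // Pre-Retrieval: apply the transformers in order, then expand into several queries.
        String query = userQuery;
        for (UnaryOperator<String> transformer : queryTransformers) {
            query = transformer.apply(query);
        }
        List<String> queries = queryExpander.apply(query);

        // Retrieval: one asynchronous retrieval per expanded query.
        ExecutorService executor = Executors.newFixedThreadPool(Math.max(1, queries.size()));
        try {
            List<CompletableFuture<List<String>>> futures = queries.stream()
                    .map(q -> CompletableFuture.supplyAsync(() -> documentRetriever.apply(q), executor))
                    .toList();
            // Join: concatenate in query order, dropping duplicate documents.
            LinkedHashSet<String> joined = new LinkedHashSet<>();
            for (CompletableFuture<List<String>> future : futures) {
                joined.addAll(future.join());
            }
            // Generation: augment the original user query with the joined documents.
            return queryAugmenter.apply(userQuery, new ArrayList<>(joined));
        } finally {
            executor.shutdown();
        }
    }
}
```

Because each stage is just a function, a stage can be replaced or omitted independently, which is the "LEGO-like" property the Modular RAG paper is named after.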

Examples

ModuleRAGBasicController

@RestController
@RequestMapping("/module-rag")
public class ModuleRAGBasicController {

    private final ChatClient chatClient;
    private final RetrievalAugmentationAdvisor retrievalAugmentationAdvisor;

    public ModuleRAGBasicController(ChatClient.Builder chatClientBuilder, VectorStore vectorStore) {

        this.chatClient = chatClientBuilder.build();
        this.retrievalAugmentationAdvisor = RetrievalAugmentationAdvisor.builder()
                .documentRetriever(VectorStoreDocumentRetriever.builder()
                        .similarityThreshold(0.50)
                        .vectorStore(vectorStore)
                        .build()
                ).build();
    }

    @GetMapping("/rag/basic")
    public String chatWithDocument(@RequestParam("prompt") String prompt) {

        return chatClient.prompt()
                .advisors(retrievalAugmentationAdvisor)
                .user(prompt)
                .call()
                .content();
    }

}

ModuleRAGCompressionController

@RestController
@RequestMapping("/module-rag")
public class ModuleRAGCompressionController {

    private final ChatClient chatClient;

    private final MessageChatMemoryAdvisor chatMemoryAdvisor;

    private final RetrievalAugmentationAdvisor retrievalAugmentationAdvisor;

    public ModuleRAGCompressionController(
            ChatClient.Builder chatClientBuilder,
            ChatMemory chatMemory,
            VectorStore vectorStore) {

        this.chatClient = chatClientBuilder.build();

        this.chatMemoryAdvisor = MessageChatMemoryAdvisor.builder(chatMemory)
                .build();

        var documentRetriever = VectorStoreDocumentRetriever.builder()
                .vectorStore(vectorStore)
                .similarityThreshold(0.50)
                .build();

        var queryTransformer = CompressionQueryTransformer.builder()
                .chatClientBuilder(chatClientBuilder.build().mutate())
                .build();

        this.retrievalAugmentationAdvisor = RetrievalAugmentationAdvisor.builder()
                .documentRetriever(documentRetriever)
                .queryTransformers(queryTransformer)
                .build();
    }

    @PostMapping("/rag/compression/{chatId}")
    public String rag(
            @RequestBody String prompt,
            @PathVariable("chatId") String conversationId
    ) {

        return chatClient.prompt()
                .advisors(chatMemoryAdvisor, retrievalAugmentationAdvisor)
                .advisors(advisors -> advisors.param(
                        AbstractChatMemoryAdvisor.CHAT_MEMORY_CONVERSATION_ID_KEY, conversationId))
                .user(prompt)
                .call()
                .content();
    }

}

ModuleRAGMemoryController

@RestController
@RequestMapping("/module-rag")
public class ModuleRAGMemoryController {

    private final ChatClient chatClient;

    private final MessageChatMemoryAdvisor chatMemoryAdvisor;

    private final RetrievalAugmentationAdvisor retrievalAugmentationAdvisor;

    public ModuleRAGMemoryController(
            ChatClient.Builder chatClientBuilder,
            ChatMemory chatMemory,
            VectorStore vectorStore
    ) {

        this.chatClient = chatClientBuilder.build();
        this.chatMemoryAdvisor = MessageChatMemoryAdvisor.builder(chatMemory)
                .build();

        this.retrievalAugmentationAdvisor = RetrievalAugmentationAdvisor.builder()
                .documentRetriever(VectorStoreDocumentRetriever.builder()
                        .similarityThreshold(0.50)
                        .vectorStore(vectorStore)
                        .build())
                .build();
    }

    @PostMapping("/rag/memory/{chatId}")
    public String chatWithDocument(
            @RequestBody String prompt,
            @PathVariable("chatId") String conversationId
    ) {

        return chatClient.prompt()
                .advisors(chatMemoryAdvisor, retrievalAugmentationAdvisor)
                .advisors(advisors -> advisors.param(
                        AbstractChatMemoryAdvisor.CHAT_MEMORY_CONVERSATION_ID_KEY, conversationId))
                .user(prompt)
                .call()
                .content();
    }

}

ModuleRAGRewriteController

@RestController
@RequestMapping("/module-rag")
public class ModuleRAGRewriteController {

    private final ChatClient chatClient;

    private final RetrievalAugmentationAdvisor retrievalAugmentationAdvisor;

    public ModuleRAGRewriteController(
            ChatClient.Builder chatClientBuilder,
            VectorStore vectorStore
    ) {

        this.chatClient = chatClientBuilder.build();

        var documentRetriever = VectorStoreDocumentRetriever.builder()
                .vectorStore(vectorStore)
                .similarityThreshold(0.50)
                .build();

        var queryTransformer = RewriteQueryTransformer.builder()
                .chatClientBuilder(chatClientBuilder.build().mutate())
                .targetSearchSystem("vector store")
                .build();

        this.retrievalAugmentationAdvisor = RetrievalAugmentationAdvisor.builder()
                .documentRetriever(documentRetriever)
                .queryTransformers(queryTransformer)
                .build();
    }

    @PostMapping("/rag/rewrite")
    public String rag(@RequestBody String prompt) {

        return chatClient.prompt()
                .advisors(retrievalAugmentationAdvisor)
                .user(prompt)
                .call()
                .content();
    }
}

ModuleRAGTranslationController

@RestController
@RequestMapping("/module-rag")
public class ModuleRAGTranslationController {

    private final ChatClient chatClient;

    private final RetrievalAugmentationAdvisor retrievalAugmentationAdvisor;

    public ModuleRAGTranslationController(ChatClient.Builder chatClientBuilder, VectorStore vectorStore) {
        this.chatClient = chatClientBuilder.build();

        var documentRetriever = VectorStoreDocumentRetriever.builder()
                .vectorStore(vectorStore)
                .similarityThreshold(0.50)
                .build();

        var queryTransformer = TranslationQueryTransformer.builder()
                .chatClientBuilder(chatClientBuilder.build().mutate())
                .targetLanguage("english")
                .build();

        this.retrievalAugmentationAdvisor = RetrievalAugmentationAdvisor.builder()
                .documentRetriever(documentRetriever)
                .queryTransformers(queryTransformer)
                .build();
    }

    @PostMapping("/rag/translation")
    public String rag(@RequestBody String prompt) {

        return chatClient.prompt()
                .advisors(retrievalAugmentationAdvisor)
                .user(prompt)
                .call()
                .content();
    }

}

Summary

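Spring AI provides out-of-the-box RAG flows through RetrievalAugmentationAdvisor, in two broad categories. The first is Sequential RAG Flows (Naive RAG and Advanced RAG). The second is Modular RAG, which Spring AI implements following the ideas in "Modular RAG: Transforming RAG Systems into LEGO-like Reconfigurable Frameworks"; it is organized into four stages: Pre-Retrieval, Retrieval, Post-Retrieval, and Generation. The Builder of RetrievalAugmentationAdvisor lets you configure the Pre-Retrieval components (queryTransformers, queryExpander), the Retrieval components (documentRetriever, documentJoiner), and the Generation component (queryAugmenter).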

codecraft