This article takes a look at Advanced RAG in langchain4j.

Core flow

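  • The UserMessage is converted into an original Query
  • A QueryTransformer transforms the original Query into one or more Queries
  • Each Query is routed by a QueryRouter to one or more ContentRetrievers
  • Each ContentRetriever retrieves the Content relevant to its Query
  • A ContentAggregator merges all retrieved Content into a single, final ranked list
  • This list of Content is injected into the original UserMessage
  • Finally, the UserMessage, now containing the original query plus the injected relevant content, is sent to the LLM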
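The seven steps above can be sketched in plain Java as one synchronous method. This is a simplified illustration, not langchain4j code. Strings stand in for Query and Content, and the Retriever and Router interfaces and the RagPipelineSketch class are hypothetical. The wording appended in the injection step is an arbitrary choice.

```java
import java.util.ArrayList;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.function.Function;

// Hypothetical, simplified model of the Advanced RAG flow (not langchain4j code).
final class RagPipelineSketch {

    // Stand-in for ContentRetriever: returns contents ranked by relevance.
    interface Retriever { List<String> retrieve(String query); }

    // Stand-in for QueryRouter: picks the retrievers for one query.
    interface Router { List<Retriever> route(String query); }

    static String augment(String userMessage,
                          Function<String, List<String>> transformer,
                          Router router,
                          Function<Map<String, List<List<String>>>, List<String>> aggregator) {
        // 1. UserMessage -> original query
        String originalQuery = userMessage;

        // 2-4. transform into one or more queries, route each one, retrieve
        Map<String, List<List<String>>> queryToContents = new LinkedHashMap<>();
        for (String query : transformer.apply(originalQuery)) {
            List<List<String>> perRetriever = new ArrayList<>();
            for (Retriever retriever : router.route(query)) {
                perRetriever.add(retriever.retrieve(query));
            }
            queryToContents.put(query, perRetriever);
        }

        // 5. aggregate everything into one ranked list
        List<String> contents = aggregator.apply(queryToContents);

        // 6. inject into the original message (nothing retrieved -> unchanged)
        if (contents.isEmpty()) {
            return userMessage;
        }
        // 7. this augmented text is what would be sent to the LLM
        return userMessage + "\n\nAnswer using the following information:\n" + String.join("\n\n", contents);
    }
}
```

Passing the transformer, router and aggregator in as parameters mirrors how DefaultRetrievalAugmentor is assembled from pluggable components, which is what allows each step to be swapped independently.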

Example

public class _02_Advanced_RAG_with_Query_Routing_Example {

    /**
     * Please refer to {@link Naive_RAG_Example} for a basic context.
     * <p>
     * Advanced RAG in LangChain4j is described here: https://github.com/langchain4j/langchain4j/pull/538
     * <p>
     * This example showcases the implementation of a more advanced RAG application
     * using a technique known as "query routing".
     * <p>
     * Often, private data is spread across multiple sources and formats.
     * This might include internal company documentation on Confluence, your project's code in a Git repository,
     * a relational database with user data, or a search engine with the products you sell, among others.
     * In a RAG flow that utilizes data from multiple sources, you will likely have multiple
     * {@link EmbeddingStore}s or {@link ContentRetriever}s.
     * While you could route each user query to all available {@link ContentRetriever}s,
     * this approach might be inefficient and counterproductive.
     * <p>
     * "Query routing" is the solution to this challenge. It involves directing a query to the most appropriate
     * {@link ContentRetriever} (or several). Routing can be implemented in various ways:
     * - Using rules (e.g., depending on the user's privileges, location, etc.).
     * - Using keywords (e.g., if a query contains words X1, X2, X3, route it to {@link ContentRetriever} X, etc.).
     * - Using semantic similarity (see EmbeddingModelTextClassifierExample in this repository).
     * - Using an LLM to make a routing decision.
     * <p>
     * For scenarios 1, 2, and 3, you can implement a custom {@link QueryRouter}.
     * For scenario 4, this example will demonstrate how to use a {@link LanguageModelQueryRouter}.
     */

    public static void main(String[] args) {

        Assistant assistant = createAssistant();

        // First, ask "What is the legacy of John Doe?"
        // Then, ask "Can I cancel my reservation?"
        // Now, see the logs to observe how the queries are routed to different retrievers.
        startConversationWith(assistant);
    }

    private static Assistant createAssistant() {

        EmbeddingModel embeddingModel = new BgeSmallEnV15QuantizedEmbeddingModel();

        // Let's create a separate embedding store specifically for biographies.
        EmbeddingStore<TextSegment> biographyEmbeddingStore =
                embed(toPath("documents/biography-of-john-doe.txt"), embeddingModel);
        ContentRetriever biographyContentRetriever = EmbeddingStoreContentRetriever.builder()
                .embeddingStore(biographyEmbeddingStore)
                .embeddingModel(embeddingModel)
                .maxResults(2)
                .minScore(0.6)
                .build();

        // Additionally, let's create a separate embedding store dedicated to terms of use.
        EmbeddingStore<TextSegment> termsOfUseEmbeddingStore =
                embed(toPath("documents/miles-of-smiles-terms-of-use.txt"), embeddingModel);
        ContentRetriever termsOfUseContentRetriever = EmbeddingStoreContentRetriever.builder()
                .embeddingStore(termsOfUseEmbeddingStore)
                .embeddingModel(embeddingModel)
                .maxResults(2)
                .minScore(0.6)
                .build();

        ChatLanguageModel chatLanguageModel = OpenAiChatModel.builder()
                .apiKey(OPENAI_API_KEY)
                .modelName(GPT_4_O_MINI)
                .build();

        // Let's create a query router.
        Map<ContentRetriever, String> retrieverToDescription = new HashMap<>();
        retrieverToDescription.put(biographyContentRetriever, "biography of John Doe");
        retrieverToDescription.put(termsOfUseContentRetriever, "terms of use of car rental company");
        QueryRouter queryRouter = new LanguageModelQueryRouter(chatLanguageModel, retrieverToDescription);

        RetrievalAugmentor retrievalAugmentor = DefaultRetrievalAugmentor.builder()
                .queryRouter(queryRouter)
                .build();

        return AiServices.builder(Assistant.class)
                .chatLanguageModel(chatLanguageModel)
                .retrievalAugmentor(retrievalAugmentor)
                .chatMemory(MessageWindowChatMemory.withMaxMessages(10))
                .build();
    }

    private static EmbeddingStore<TextSegment> embed(Path documentPath, EmbeddingModel embeddingModel) {
        DocumentParser documentParser = new TextDocumentParser();
        Document document = loadDocument(documentPath, documentParser);

        DocumentSplitter splitter = DocumentSplitters.recursive(300, 0);
        List<TextSegment> segments = splitter.split(document);

        List<Embedding> embeddings = embeddingModel.embedAll(segments).content();

        EmbeddingStore<TextSegment> embeddingStore = new InMemoryEmbeddingStore<>();
        embeddingStore.addAll(embeddings, segments);
        return embeddingStore;
    }
}
Here a DefaultRetrievalAugmentor is built around a LanguageModelQueryRouter, which is given two ContentRetrievers, biographyContentRetriever and termsOfUseContentRetriever. The LLM decides, for each query, which of them to use.

Source code analysis

RetrievalAugmentor

dev/langchain4j/rag/RetrievalAugmentor.java

@Experimental
public interface RetrievalAugmentor {

    /**
     * Augments the {@link ChatMessage} provided in the {@link AugmentationRequest} with retrieved {@link Content}s.
     * <br>
     * This method has a default implementation in order to <b>temporarily</b> support
     * current custom implementations of {@code RetrievalAugmentor}. The default implementation will be removed soon.
     *
     * @param augmentationRequest The {@code AugmentationRequest} containing the {@code ChatMessage} to augment.
     * @return The {@link AugmentationResult} containing the augmented {@code ChatMessage}.
     */
    default AugmentationResult augment(AugmentationRequest augmentationRequest) {

        if (!(augmentationRequest.chatMessage() instanceof UserMessage)) {
            throw runtime("Please implement 'AugmentationResult augment(AugmentationRequest)' method " +
                    "in order to augment " + augmentationRequest.chatMessage().getClass());
        }

        UserMessage augmented = augment((UserMessage) augmentationRequest.chatMessage(), augmentationRequest.metadata());

        return AugmentationResult.builder()
                .chatMessage(augmented)
                .build();
    }

    /**
     * Augments the provided {@link UserMessage} with retrieved content.
     *
     * @param userMessage The {@link UserMessage} to be augmented.
     * @param metadata    The {@link Metadata} that may be useful or necessary for retrieval and augmentation.
     * @return The augmented {@link UserMessage}.
     * @deprecated Use/implement {@link #augment(AugmentationRequest)} instead.
     */
    @Deprecated
    UserMessage augment(UserMessage userMessage, Metadata metadata);
}
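The RetrievalAugmentor interface defines the augment(AugmentationRequest augmentationRequest) method, which is the RAG entry point in langchain4j. It augments the ChatMessage carried by the AugmentationRequest with retrieved Content. The method has a default implementation whose main purpose is to bridge to the deprecated augment(UserMessage userMessage, Metadata metadata) method. That default only handles UserMessage and throws for any other ChatMessage type.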
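The bridging trick used by this default method can be shown in isolation. The sketch below is a hypothetical, stripped-down model. Message, User, Ai and Augmentor are made-up stand-ins for ChatMessage, UserMessage and RetrievalAugmentor. An implementation written only against the legacy method still works through the new entry point, and non-user messages are rejected.

```java
// Hypothetical stand-ins; the real types are ChatMessage, UserMessage and RetrievalAugmentor.
final class DefaultMethodBridgeSketch {

    interface Message { String text(); }
    record User(String text) implements Message {}
    record Ai(String text) implements Message {}

    interface Augmentor {
        // New entry point. The default body keeps old implementations compiling:
        // it accepts only user messages and forwards them to the legacy method.
        default Message augment(Message message) {
            if (!(message instanceof User user)) {
                throw new IllegalStateException(
                        "Please implement augment(Message) in order to augment " + message.getClass());
            }
            return augmentUser(user);
        }

        // Legacy entry point: the only abstract method, so existing classes and lambdas still fit.
        @Deprecated
        User augmentUser(User userMessage);
    }
}
```

The Javadoc notes that this default implementation is temporary. Once it is removed, implementors will have to provide augment(AugmentationRequest) themselves.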

DefaultRetrievalAugmentor

dev/langchain4j/rag/DefaultRetrievalAugmentor.java

public class DefaultRetrievalAugmentor implements RetrievalAugmentor {

    private static final Logger log = LoggerFactory.getLogger(DefaultRetrievalAugmentor.class);

    private final QueryTransformer queryTransformer;
    private final QueryRouter queryRouter;
    private final ContentAggregator contentAggregator;
    private final ContentInjector contentInjector;
    private final Executor executor;

    public DefaultRetrievalAugmentor(QueryTransformer queryTransformer,
                                     QueryRouter queryRouter,
                                     ContentAggregator contentAggregator,
                                     ContentInjector contentInjector,
                                     Executor executor) {
        this.queryTransformer = getOrDefault(queryTransformer, DefaultQueryTransformer::new);
        this.queryRouter = ensureNotNull(queryRouter, "queryRouter");
        this.contentAggregator = getOrDefault(contentAggregator, DefaultContentAggregator::new);
        this.contentInjector = getOrDefault(contentInjector, DefaultContentInjector::new);
        this.executor = getOrDefault(executor, DefaultRetrievalAugmentor::createDefaultExecutor);
    }

    private static ExecutorService createDefaultExecutor() {
        return new ThreadPoolExecutor(
            0, Integer.MAX_VALUE,
            1, SECONDS,
            new SynchronousQueue<>()
        );
    }

    /**
     * @deprecated use {@link #augment(AugmentationRequest)} instead.
     */
    @Override
    @Deprecated
    public UserMessage augment(UserMessage userMessage, Metadata metadata) {
        AugmentationRequest augmentationRequest = new AugmentationRequest(userMessage, metadata);
        return (UserMessage) augment(augmentationRequest).chatMessage();
    }

    @Override
    public AugmentationResult augment(AugmentationRequest augmentationRequest) {

        ChatMessage chatMessage = augmentationRequest.chatMessage();
        Metadata metadata = augmentationRequest.metadata();

        Query originalQuery = Query.from(chatMessage.text(), metadata);

        Collection<Query> queries = queryTransformer.transform(originalQuery);
        logQueries(originalQuery, queries);

        Map<Query, Collection<List<Content>>> queryToContents = process(queries);

        List<Content> contents = contentAggregator.aggregate(queryToContents);
        log(queryToContents, contents);

        ChatMessage augmentedChatMessage = contentInjector.inject(contents, chatMessage);
        log(augmentedChatMessage);

        return AugmentationResult.builder()
            .chatMessage(augmentedChatMessage)
            .contents(contents)
            .build();
    }

    //......

}    
DefaultRetrievalAugmentor implements the RetrievalAugmentor interface. It holds a queryTransformer, queryRouter, contentAggregator, contentInjector and executor. Its augment method proceeds in five steps:
  • It converts the chatMessage into originalQuery.
  • It calls queryTransformer.transform to obtain queries.
  • It calls process to retrieve Content for each query.
  • It merges the results with contentAggregator.aggregate.
  • It calls contentInjector.inject to inject the contents into the chatMessage, producing augmentedChatMessage.
The returned AugmentationResult contains both augmentedChatMessage and the injected contents.
The constructor requires queryRouter to be non-null. Every other dependency falls back to a default when null:
  • queryTransformer defaults to DefaultQueryTransformer.
  • contentAggregator defaults to DefaultContentAggregator.
  • contentInjector defaults to DefaultContentInjector.
  • executor defaults to a ThreadPoolExecutor with corePoolSize 0, maximumPoolSize Integer.MAX_VALUE, a keepAliveTime of 1s and a SynchronousQueue as its workQueue.
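The default executor is essentially a cached thread pool. Executors.newCachedThreadPool() uses the same parameters except for a 60-second keep-alive. A SynchronousQueue holds no tasks, so every submission either reuses an idle thread or starts a new one, and concurrent retrieval calls never wait in a queue. The hypothetical DefaultExecutorSketch class below recreates the configuration. It demonstrates the growth with test tasks that deliberately wait for one another.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.Future;
import java.util.concurrent.SynchronousQueue;
import java.util.concurrent.ThreadPoolExecutor;
import java.util.concurrent.TimeUnit;

final class DefaultExecutorSketch {

    // Same parameters as createDefaultExecutor(): no core threads, unbounded
    // maximum, idle threads reclaimed after 1s, and a SynchronousQueue that
    // never buffers, so each task goes to an idle thread or to a brand-new one.
    static ThreadPoolExecutor create() {
        return new ThreadPoolExecutor(0, Integer.MAX_VALUE, 1, TimeUnit.SECONDS, new SynchronousQueue<>());
    }

    // Submits n tasks that all wait until every one of them has started.
    // These particular tasks only finish because the pool spawns a thread per
    // task; a fixed pool smaller than n would leave them blocking each other.
    static int largestPoolSizeFor(int n) {
        ThreadPoolExecutor executor = create();
        CountDownLatch allStarted = new CountDownLatch(n);
        List<Future<Object>> futures = new ArrayList<>();
        try {
            for (int i = 0; i < n; i++) {
                futures.add(executor.submit(() -> {
                    allStarted.countDown();
                    allStarted.await();
                    return null;
                }));
            }
            for (Future<Object> future : futures) {
                future.get(5, TimeUnit.SECONDS);
            }
            return executor.getLargestPoolSize();
        } catch (Exception e) {
            throw new IllegalStateException(e);
        } finally {
            executor.shutdown();
        }
    }
}
```

The trade-off is that the thread count is unbounded. A burst of queries fanned out to several retrievers turns into as many threads, and passing your own Executor to the constructor is the way to cap it.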
    private Map<Query, Collection<List<Content>>> process(Collection<Query> queries) {
        if (queries.size() == 1) {
            Query query = queries.iterator().next();
            Collection<ContentRetriever> retrievers = queryRouter.route(query);
            if (retrievers.size() == 1) {
                ContentRetriever contentRetriever = retrievers.iterator().next();
                List<Content> contents = contentRetriever.retrieve(query);
                return singletonMap(query, singletonList(contents));
            } else if (retrievers.size() > 1) {
                Collection<List<Content>> contents = retrieveFromAll(retrievers, query).join();
                return singletonMap(query, contents);
            } else {
                return emptyMap();
            }
        } else if (queries.size() > 1) {
            Map<Query, CompletableFuture<Collection<List<Content>>>> queryToFutureContents = new ConcurrentHashMap<>();
            queries.forEach(query -> {
                CompletableFuture<Collection<List<Content>>> futureContents =
                    supplyAsync(() -> {
                            Collection<ContentRetriever> retrievers = queryRouter.route(query);
                            log(query, retrievers);
                            return retrievers;
                        },
                        executor
                    ).thenCompose(retrievers -> retrieveFromAll(retrievers, query));
                queryToFutureContents.put(query, futureContents);
            });
            return join(queryToFutureContents);
        } else {
            return emptyMap();
        }
    }

    private CompletableFuture<Collection<List<Content>>> retrieveFromAll(Collection<ContentRetriever> retrievers,
                                                                         Query query) {
        List<CompletableFuture<List<Content>>> futureContents = retrievers.stream()
            .map(retriever -> supplyAsync(() -> retrieve(retriever, query), executor))
            .collect(Collectors.toList());

        return allOf(futureContents.toArray(new CompletableFuture[0]))
            .thenApply(ignored ->
                futureContents.stream()
                    .map(CompletableFuture::join)
                    .collect(Collectors.toList()));
    }

    private static List<Content> retrieve(ContentRetriever retriever, Query query) {
        List<Content> contents = retriever.retrieve(query);
        log(query, retriever, contents);
        return contents;
    }

    private static Map<Query, Collection<List<Content>>> join(
        Map<Query, CompletableFuture<Collection<List<Content>>>> queryToFutureContents) {
        return allOf(queryToFutureContents.values().toArray(new CompletableFuture[0]))
            .thenApply(ignored ->
                queryToFutureContents.entrySet().stream()
                    .collect(toMap(
                        Map.Entry::getKey,
                        entry -> entry.getValue().join()
                    ))
            ).join();
    }    
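The process method calls queryRouter.route(query) to decide which retrievers a query goes to. It then calls retrieve on each ContentRetriever to obtain its List<Content>. A single query and multiple queries are handled differently.
  • With one query, routing happens on the calling thread. If the query is routed to exactly one retriever, retrieval also runs synchronously. If it is routed to several, they are queried concurrently on the executor. If it is routed to none, an empty map is returned.
  • With multiple queries, routing and retrieval for every query run concurrently on the executor, and join waits for all of them to complete.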
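The multi-query branch of process can be reproduced with nothing but CompletableFuture. In the sketch below, FanOutSketch, its Retriever interface and the use of Strings for queries and contents are hypothetical simplifications. Unlike the original, it keeps results in a LinkedHashMap so the output order is predictable, and it omits the single-query shortcut.

```java
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.Executor;
import java.util.function.Function;
import java.util.stream.Collectors;

final class FanOutSketch {

    // Hypothetical stand-in for ContentRetriever.
    interface Retriever { List<String> retrieve(String query); }

    // Multi-query branch of process(): every query is routed and retrieved
    // asynchronously; the caller blocks once, at the end.
    static Map<String, List<List<String>>> process(List<String> queries,
                                                   Function<String, List<Retriever>> router,
                                                   Executor executor) {
        Map<String, CompletableFuture<List<List<String>>>> queryToFuture = new LinkedHashMap<>();
        for (String query : queries) {
            CompletableFuture<List<List<String>>> future = CompletableFuture
                    .supplyAsync(() -> router.apply(query), executor)
                    .thenCompose(retrievers -> retrieveFromAll(retrievers, query, executor));
            queryToFuture.put(query, future);
        }
        // Equivalent of join(queryToFutureContents): wait for all, then unwrap.
        CompletableFuture.allOf(queryToFuture.values().toArray(new CompletableFuture[0])).join();
        Map<String, List<List<String>>> result = new LinkedHashMap<>();
        queryToFuture.forEach((query, future) -> result.put(query, future.join()));
        return result;
    }

    // One async task per retriever; the results keep the retrievers' order.
    static CompletableFuture<List<List<String>>> retrieveFromAll(List<Retriever> retrievers,
                                                                 String query,
                                                                 Executor executor) {
        List<CompletableFuture<List<String>>> futures = retrievers.stream()
                .map(retriever -> CompletableFuture.supplyAsync(() -> retriever.retrieve(query), executor))
                .collect(Collectors.toList());
        return CompletableFuture.allOf(futures.toArray(new CompletableFuture[0]))
                .thenApply(ignored -> futures.stream()
                        .map(CompletableFuture::join)
                        .collect(Collectors.toList()));
    }
}
```

Routing happens inside supplyAsync and retrieval is chained with thenCompose, so no pool thread blocks while waiting for another. Only the caller blocks, once, at the final join.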

QueryTransformer

dev/langchain4j/rag/query/transformer/QueryTransformer.java

@Experimental
public interface QueryTransformer {

    /**
     * Transforms the given {@link Query} into one or multiple {@link Query}s.
     *
     * @param query The {@link Query} to be transformed.
     * @return A collection of one or more {@link Query}s derived from the original {@link Query}.
     */
    Collection<Query> transform(Query query);
}
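QueryTransformer defines the transform method, which modifies or expands the original Query. Known techniques include query compression, query expansion, query rewriting, step-back prompting and Hypothetical Document Embeddings (HyDE). It has three implementations: DefaultQueryTransformer, CompressingQueryTransformer and ExpandingQueryTransformer.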

DefaultQueryTransformer

dev/langchain4j/rag/query/transformer/DefaultQueryTransformer.java

public class DefaultQueryTransformer implements QueryTransformer {

    @Override
    public Collection<Query> transform(Query query) {
        return singletonList(query);
    }
}
DefaultQueryTransformer returns the query unchanged, wrapped in a single-element list.

CompressingQueryTransformer

dev/langchain4j/rag/query/transformer/CompressingQueryTransformer.java

public class CompressingQueryTransformer implements QueryTransformer {

    public static final PromptTemplate DEFAULT_PROMPT_TEMPLATE = PromptTemplate.from(
            """
                    Read and understand the conversation between the User and the AI. \
                    Then, analyze the new query from the User. \
                    Identify all relevant details, terms, and context from both the conversation and the new query. \
                    Reformulate this query into a clear, concise, and self-contained format suitable for information retrieval.
                    
                    Conversation:
                    {{chatMemory}}
                    
                    User query: {{query}}
                    
                    It is very important that you provide only reformulated query and nothing else! \
                    Do not prepend a query with anything!"""
    );

    protected final PromptTemplate promptTemplate;
    protected final ChatLanguageModel chatLanguageModel;

    public CompressingQueryTransformer(ChatLanguageModel chatLanguageModel) {
        this(chatLanguageModel, DEFAULT_PROMPT_TEMPLATE);
    }

    public CompressingQueryTransformer(ChatLanguageModel chatLanguageModel, PromptTemplate promptTemplate) {
        this.chatLanguageModel = ensureNotNull(chatLanguageModel, "chatLanguageModel");
        this.promptTemplate = getOrDefault(promptTemplate, DEFAULT_PROMPT_TEMPLATE);
    }

    public static CompressingQueryTransformerBuilder builder() {
        return new CompressingQueryTransformerBuilder();
    }

    @Override
    public Collection<Query> transform(Query query) {

        List<ChatMessage> chatMemory = query.metadata().chatMemory();
        if (chatMemory.isEmpty()) {
            // no need to compress if there are no previous messages
            return singletonList(query);
        }

        Prompt prompt = createPrompt(query, format(chatMemory));
        String compressedQueryText = chatLanguageModel.chat(prompt.text());
        Query compressedQuery = query.metadata() == null
                ? Query.from(compressedQueryText)
                : Query.from(compressedQueryText, query.metadata());
        return singletonList(compressedQuery);
    }

    protected String format(List<ChatMessage> chatMemory) {
        return chatMemory.stream()
                .map(this::format)
                .filter(Objects::nonNull)
                .collect(joining("\n"));
    }

    protected String format(ChatMessage message) {
        if (message instanceof UserMessage) {
            return "User: " + message.text();
        } else if (message instanceof AiMessage aiMessage) {
            if (aiMessage.hasToolExecutionRequests()) {
                return null;
            }
            return "AI: " + aiMessage.text();
        } else {
            return null;
        }
    }

    protected Prompt createPrompt(Query query, String chatMemory) {
        Map<String, Object> variables = new HashMap<>();
        variables.put("query", query.text());
        variables.put("chatMemory", chatMemory);
        return promptTemplate.apply(variables);
    }

    //......
}    
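CompressingQueryTransformer only has an effect when chat memory is in use. If the query's metadata carries no previous messages, it returns the original query unchanged. Otherwise it fills DEFAULT_PROMPT_TEMPLATE with the query and the formatted chat history. Only user and AI turns are included; AI messages with tool execution requests are skipped. It then asks the LLM to compress them into a single self-contained query.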
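The two string-building steps of CompressingQueryTransformer, formatting the history and filling the template, are easy to reproduce without an LLM. The sketch below is hypothetical:
  • Message, with its kind field, replaces ChatMessage.
  • An AI tool-call request is modeled as the kind "ai-tool-call".
  • Plain String.replace stands in for PromptTemplate.apply.

```java
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.stream.Collectors;

final class CompressionPromptSketch {

    // Hypothetical message model: kind is "user", "ai" or "ai-tool-call".
    record Message(String kind, String text) {}

    // Mirrors format(ChatMessage): user and plain AI turns get a prefix, the rest maps to null.
    static String format(Message m) {
        return switch (m.kind()) {
            case "user" -> "User: " + m.text();
            case "ai" -> "AI: " + m.text();
            default -> null; // e.g. AI messages that only request tool calls
        };
    }

    // Mirrors format(List<ChatMessage>): drop the nulls, one turn per line.
    static String formatMemory(List<Message> memory) {
        return memory.stream()
                .map(CompressionPromptSketch::format)
                .filter(Objects::nonNull)
                .collect(Collectors.joining("\n"));
    }

    // Empty history: no prompt at all, the original query is used as-is (no LLM call).
    static Optional<String> compressionPrompt(String query, List<Message> memory, String template) {
        if (memory.isEmpty()) {
            return Optional.empty();
        }
        // Naive {{variable}} substitution standing in for PromptTemplate.apply().
        Map<String, String> variables = Map.of("chatMemory", formatMemory(memory), "query", query);
        String prompt = template;
        for (Map.Entry<String, String> e : variables.entrySet()) {
            prompt = prompt.replace("{{" + e.getKey() + "}}", e.getValue());
        }
        return Optional.of(prompt);
    }
}
```

The resulting prompt is what would be passed to chatLanguageModel.chat(...). Its answer becomes the single compressed Query. This is how a follow-up such as "Where was he born?" is meant to turn into a query that names John Doe explicitly.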

ExpandingQueryTransformer

dev/langchain4j/rag/query/transformer/ExpandingQueryTransformer.java

public class ExpandingQueryTransformer implements QueryTransformer {

    public static final PromptTemplate DEFAULT_PROMPT_TEMPLATE = PromptTemplate.from(
            """
                    Generate {{n}} different versions of a provided user query. \
                    Each version should be worded differently, using synonyms or alternative sentence structures, \
                    but they should all retain the original meaning. \
                    These versions will be used to retrieve relevant documents. \
                    It is very important to provide each query version on a separate line, \
                    without enumerations, hyphens, or any additional formatting!
                    User query: {{query}}"""
    );
    public static final int DEFAULT_N = 3;

    protected final ChatLanguageModel chatLanguageModel;
    protected final PromptTemplate promptTemplate;
    protected final int n;

    public ExpandingQueryTransformer(ChatLanguageModel chatLanguageModel) {
        this(chatLanguageModel, DEFAULT_PROMPT_TEMPLATE, DEFAULT_N);
    }

    public ExpandingQueryTransformer(ChatLanguageModel chatLanguageModel, int n) {
        this(chatLanguageModel, DEFAULT_PROMPT_TEMPLATE, n);
    }

    public ExpandingQueryTransformer(ChatLanguageModel chatLanguageModel, PromptTemplate promptTemplate) {
        this(chatLanguageModel, ensureNotNull(promptTemplate, "promptTemplate"), DEFAULT_N);
    }

    public ExpandingQueryTransformer(ChatLanguageModel chatLanguageModel, PromptTemplate promptTemplate, Integer n) {
        this.chatLanguageModel = ensureNotNull(chatLanguageModel, "chatLanguageModel");
        this.promptTemplate = getOrDefault(promptTemplate, DEFAULT_PROMPT_TEMPLATE);
        this.n = ensureGreaterThanZero(getOrDefault(n, DEFAULT_N), "n");
    }

    public static ExpandingQueryTransformerBuilder builder() {
        return new ExpandingQueryTransformerBuilder();
    }

    @Override
    public Collection<Query> transform(Query query) {
        Prompt prompt = createPrompt(query);
        String response = chatLanguageModel.chat(prompt.text());
        List<String> queries = parse(response);
        return queries.stream()
                .map(queryText -> query.metadata() == null
                        ? Query.from(queryText)
                        : Query.from(queryText, query.metadata()))
                .collect(toList());
    }

    protected Prompt createPrompt(Query query) {
        Map<String, Object> variables = new HashMap<>();
        variables.put("query", query.text());
        variables.put("n", n);
        return promptTemplate.apply(variables);
    }

    protected List<String> parse(String queries) {
        return stream(queries.split("\n"))
                .filter(Utils::isNotNullOrBlank)
                .collect(toList());
    }

    //......
}    
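ExpandingQueryTransformer uses DEFAULT_PROMPT_TEMPLATE to ask the LLM for n differently worded versions of the user query (n defaults to 3). It then turns every non-blank line of the response into a separate Query.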
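The prompt filling and the line-based parsing can be sketched as follows. ExpansionSketch is hypothetical and its TEMPLATE is a shortened copy of DEFAULT_PROMPT_TEMPLATE. No model is called; the response string is supplied by the caller.

```java
import java.util.Arrays;
import java.util.List;
import java.util.stream.Collectors;

final class ExpansionSketch {

    // Shortened version of DEFAULT_PROMPT_TEMPLATE, kept only for illustration.
    static final String TEMPLATE = "Generate {{n}} different versions of a provided user query. "
            + "It is very important to provide each query version on a separate line, "
            + "without enumerations, hyphens, or any additional formatting!\n"
            + "User query: {{query}}";

    static String prompt(String query, int n) {
        return TEMPLATE.replace("{{n}}", String.valueOf(n)).replace("{{query}}", query);
    }

    // Mirrors parse(): every non-blank line becomes one query. Lines are not
    // trimmed and the count is not checked against n, so an LLM that returns
    // extra or fewer lines yields extra or fewer queries.
    static List<String> parse(String response) {
        return Arrays.stream(response.split("\n"))
                .filter(line -> !line.isBlank())
                .collect(Collectors.toList());
    }
}
```

Each line becomes its own Query, so the expanded queries go through routing and retrieval independently. That is exactly the case handled by the multi-query branch of process.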

QueryRouter

dev/langchain4j/rag/query/router/QueryRouter.java

@Experimental
public interface QueryRouter {

    /**
     * Routes the given {@link Query} to one or multiple {@link ContentRetriever}s.
     *
     * @param query The {@link Query} to be routed.
     * @return A collection of one or more {@link ContentRetriever}s to which the {@link Query} should be routed.
     */
    Collection<ContentRetriever> route(Query query);
}
QueryRouter defines the route method, which returns the ContentRetrievers a query should be sent to. It has two implementations: DefaultQueryRouter and LanguageModelQueryRouter.
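The example's Javadoc suggests a custom QueryRouter for rule-, keyword- or similarity-based routing. A keyword-based variant might look like the sketch below. KeywordQueryRouter and its Retriever interface are hypothetical. A real implementation would implement QueryRouter and read query.text() instead of taking a String.

```java
import java.util.Collection;
import java.util.LinkedHashMap;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Set;

// Hypothetical rule-based router (the "keywords" scenario in the example's Javadoc).
final class KeywordQueryRouter {

    // Stand-in for ContentRetriever.
    interface Retriever { List<String> retrieve(String query); }

    private final Map<String, Retriever> keywordToRetriever;

    KeywordQueryRouter(Map<String, Retriever> keywordToRetriever) {
        this.keywordToRetriever = new LinkedHashMap<>(keywordToRetriever);
    }

    // Returns every retriever whose keyword occurs in the query (case-insensitive),
    // without duplicates; an empty result means nothing is retrieved for this query.
    Collection<Retriever> route(String queryText) {
        String text = queryText.toLowerCase(Locale.ROOT);
        Set<Retriever> selected = new LinkedHashSet<>();
        keywordToRetriever.forEach((keyword, retriever) -> {
            if (text.contains(keyword.toLowerCase(Locale.ROOT))) {
                selected.add(retriever);
            }
        });
        return selected;
    }
}
```

Unlike LanguageModelQueryRouter, this costs no LLM call per query. The price is a set of rules that must be maintained by hand.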

DefaultQueryRouter

dev/langchain4j/rag/query/router/DefaultQueryRouter.java

public class DefaultQueryRouter implements QueryRouter {

    private final Collection<ContentRetriever> contentRetrievers;

    public DefaultQueryRouter(ContentRetriever... contentRetrievers) {
        this(asList(contentRetrievers));
    }

    public DefaultQueryRouter(Collection<ContentRetriever> contentRetrievers) {
        this.contentRetrievers = unmodifiableCollection(ensureNotEmpty(contentRetrievers, "contentRetrievers"));
    }

    @Override
    public Collection<ContentRetriever> route(Query query) {
        return contentRetrievers;
    }
}
DefaultQueryRouter takes a non-empty collection of contentRetrievers in its constructor. Its route method returns all of them for every query.

LanguageModelQueryRouter

dev/langchain4j/rag/query/router/LanguageModelQueryRouter.java

public class LanguageModelQueryRouter implements QueryRouter {

    private static final Logger log = LoggerFactory.getLogger(LanguageModelQueryRouter.class);

    public static final PromptTemplate DEFAULT_PROMPT_TEMPLATE = PromptTemplate.from(
            """
                    Based on the user query, determine the most suitable data source(s) \
                    to retrieve relevant information from the following options:
                    {{options}}
                    It is very important that your answer consists of either a single number \
                    or multiple numbers separated by commas and nothing else!
                    User query: {{query}}"""
    );

    protected final ChatLanguageModel chatLanguageModel;
    protected final PromptTemplate promptTemplate;
    protected final String options;
    protected final Map<Integer, ContentRetriever> idToRetriever;
    protected final FallbackStrategy fallbackStrategy;

    public LanguageModelQueryRouter(ChatLanguageModel chatLanguageModel,
                                    Map<ContentRetriever, String> retrieverToDescription) {
        this(chatLanguageModel, retrieverToDescription, DEFAULT_PROMPT_TEMPLATE, DO_NOT_ROUTE);
    }

    public LanguageModelQueryRouter(ChatLanguageModel chatLanguageModel,
                                    Map<ContentRetriever, String> retrieverToDescription,
                                    PromptTemplate promptTemplate,
                                    FallbackStrategy fallbackStrategy) {
        this.chatLanguageModel = ensureNotNull(chatLanguageModel, "chatLanguageModel");
        ensureNotEmpty(retrieverToDescription, "retrieverToDescription");
        this.promptTemplate = getOrDefault(promptTemplate, DEFAULT_PROMPT_TEMPLATE);

        Map<Integer, ContentRetriever> idToRetriever = new HashMap<>();
        StringBuilder optionsBuilder = new StringBuilder();
        int id = 1;
        for (Map.Entry<ContentRetriever, String> entry : retrieverToDescription.entrySet()) {
            idToRetriever.put(id, ensureNotNull(entry.getKey(), "ContentRetriever"));

            if (id > 1) {
                optionsBuilder.append("\n");
            }
            optionsBuilder.append(id);
            optionsBuilder.append(": ");
            optionsBuilder.append(ensureNotBlank(entry.getValue(), "ContentRetriever description"));

            id++;
        }
        this.idToRetriever = idToRetriever;
        this.options = optionsBuilder.toString();
        this.fallbackStrategy = getOrDefault(fallbackStrategy, DO_NOT_ROUTE);
    }

    public static LanguageModelQueryRouterBuilder builder() {
        return new LanguageModelQueryRouterBuilder();
    }

    @Override
    public Collection<ContentRetriever> route(Query query) {
        Prompt prompt = createPrompt(query);
        try {
            String response = chatLanguageModel.chat(prompt.text());
            return parse(response);
        } catch (Exception e) {
            log.warn("Failed to route query '{}'", query.text(), e);
            return fallback(query, e);
        }
    }

    protected Collection<ContentRetriever> fallback(Query query, Exception e) {
        return switch (fallbackStrategy) {
            case DO_NOT_ROUTE -> {
                log.debug("Fallback: query '{}' will not be routed", query.text());
                yield emptyList();
            }
            case ROUTE_TO_ALL -> {
                log.debug("Fallback: query '{}' will be routed to all available content retrievers", query.text());
                yield new ArrayList<>(idToRetriever.values());
            }
            default -> throw new RuntimeException(e);
        };
    }

    protected Prompt createPrompt(Query query) {
        Map<String, Object> variables = new HashMap<>();
        variables.put("query", query.text());
        variables.put("options", options);
        return promptTemplate.apply(variables);
    }

    protected Collection<ContentRetriever> parse(String choices) {
        return stream(choices.split(","))
                .map(String::trim)
                .map(Integer::parseInt)
                .map(idToRetriever::get)
                .collect(toList());
    }

    /**
     * Strategy applied if the call to the LLM fails or if the LLM does not return a valid response.
     * It could be because it was formatted improperly, or it is unclear where to route.
     */
    public enum FallbackStrategy {

        /**
         * In this case, the {@link Query} will not be routed to any {@link ContentRetriever},
         * thus skipping the RAG flow. No content will be appended to the original {@link UserMessage}.
         */
        DO_NOT_ROUTE,

        /**
         * In this case, the {@link Query} will be routed to all {@link ContentRetriever}s.
         */
        ROUTE_TO_ALL,

        /**
         * In this case, an original exception will be re-thrown, and the RAG flow will fail.
         */
        FAIL
    }

    //......
}
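LanguageModelQueryRouter lets chatLanguageModel make the routing decision. Its constructor takes a chatLanguageModel and a Map<ContentRetriever, String>. Each String describes its ContentRetriever so the model can decide where to route. The descriptions are numbered and listed in the prompt, and the model is asked to answer with one or more of those numbers. A fallbackStrategy specifies what to do when the call to chatLanguageModel throws or its answer cannot be parsed. The default is DO_NOT_ROUTE.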
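The prompt-side and parsing-side logic of LanguageModelQueryRouter can be sketched without a model. In the hypothetical LlmRouterSketch below, retrievers are represented by their 1-based ids, and the FallbackStrategy enum mirrors the real one.

```java
import java.util.Arrays;
import java.util.List;
import java.util.stream.Collectors;
import java.util.stream.IntStream;

// Hypothetical re-implementation of the string-handling steps of
// LanguageModelQueryRouter; retrievers are represented by their 1-based ids.
final class LlmRouterSketch {

    enum FallbackStrategy { DO_NOT_ROUTE, ROUTE_TO_ALL, FAIL }

    // Same shape as the constructor's options string: "1: desc\n2: desc".
    static String options(List<String> descriptions) {
        StringBuilder sb = new StringBuilder();
        for (int i = 0; i < descriptions.size(); i++) {
            if (i > 0) {
                sb.append("\n");
            }
            sb.append(i + 1).append(": ").append(descriptions.get(i));
        }
        return sb.toString();
    }

    // Parses answers such as "2" or "1, 2"; anything non-numeric triggers the fallback.
    static List<Integer> route(String llmAnswer, int optionCount, FallbackStrategy fallback) {
        try {
            return Arrays.stream(llmAnswer.split(","))
                    .map(String::trim)
                    .map(Integer::parseInt)
                    .collect(Collectors.toList());
        } catch (NumberFormatException e) {
            return switch (fallback) {
                case DO_NOT_ROUTE -> List.<Integer>of();
                case ROUTE_TO_ALL -> IntStream.rangeClosed(1, optionCount).boxed().collect(Collectors.toList());
                case FAIL -> throw new RuntimeException(e);
            };
        }
    }
}
```

One detail is visible in the source above. An answer such as 3, when only two retrievers are registered, parses without error, so the fallback is not triggered. In that case idToRetriever::get simply yields null for that id.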

ContentRetriever

dev/langchain4j/rag/content/retriever/ContentRetriever.java

public interface ContentRetriever {

    /**
     * Retrieves relevant {@link Content}s using a given {@link Query}.
     * The {@link Content}s are sorted by relevance, with the most relevant {@link Content}s appearing
     * at the beginning of the returned {@code List<Content>}.
     *
     * @param query The {@link Query} to use for retrieval.
     * @return A list of retrieved {@link Content}s.
     */
    List<Content> retrieve(Query query);
}
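The ContentRetriever interface defines the retrieve method, which returns the Content relevant to a query, sorted from most to least relevant. Implementations include EmbeddingStoreContentRetriever and WebSearchContentRetriever.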
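Here is a hypothetical in-memory retriever. It illustrates the contract (most relevant first) together with the maxResults and minScore settings used on EmbeddingStoreContentRetriever in the example. It skips the embedding step by taking a ready-made query vector, and it scores with plain cosine similarity. Real embedding stores may use a different score scale, so minScore values are not directly interchangeable.

```java
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;

// Hypothetical in-memory retriever illustrating the ContentRetriever contract:
// results sorted by relevance, capped by maxResults, filtered by minScore.
// The query is given as a ready-made vector instead of being embedded.
final class ToyVectorRetriever {

    record Segment(String text, double[] vector) {}

    private final List<Segment> segments;
    private final int maxResults;
    private final double minScore;

    ToyVectorRetriever(List<Segment> segments, int maxResults, double minScore) {
        this.segments = List.copyOf(segments);
        this.maxResults = maxResults;
        this.minScore = minScore;
    }

    List<String> retrieve(double[] queryVector) {
        return segments.stream()
                .map(s -> Map.entry(s.text(), cosine(queryVector, s.vector())))
                .filter(e -> e.getValue() >= minScore)                             // drop weak matches
                .sorted(Map.Entry.<String, Double>comparingByValue().reversed())   // most relevant first
                .limit(maxResults)
                .map(Map.Entry::getKey)
                .collect(Collectors.toList());
    }

    static double cosine(double[] a, double[] b) {
        double dot = 0, normA = 0, normB = 0;
        for (int i = 0; i < a.length; i++) {
            dot += a[i] * b[i];
            normA += a[i] * a[i];
            normB += b[i] * b[i];
        }
        return dot / (Math.sqrt(normA) * Math.sqrt(normB));
    }
}
```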

ContentAggregator

dev/langchain4j/rag/content/aggregator/ContentAggregator.java

@Experimental
public interface ContentAggregator {

    /**
     * Aggregates all {@link Content}s retrieved by all {@link ContentRetriever}s using all {@link Query}s.
     * The {@link Content}s, both on input and output, are sorted by relevance,
     * with the most relevant {@link Content}s appearing at the beginning of {@code List<Content>}.
     *
     * @param queryToContents A map from a {@link Query} to all {@code List<Content>} retrieved with that {@link Query}.
     *                        Given that each {@link Query} can be routed to multiple {@link ContentRetriever}s, the
     *                        value of this map is a {@code Collection<List<Content>>}
     *                        rather than a simple {@code List<Content>}.
     * @return A list of aggregated {@link Content}s.
     */
    List<Content> aggregate(Map<Query, Collection<List<Content>>> queryToContents);
}
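The ContentAggregator interface defines the aggregate method, which merges queryToContents into a single list of Content. The goal of this step is to make sure the Content passed to the LLM is the most relevant and free of redundancy. Effective approaches include:
  • re-ranking (ReRankingContentAggregator)
  • Reciprocal Rank Fusion (ReciprocalRankFuser, used by both DefaultContentAggregator and ReRankingContentAggregator)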

DefaultContentAggregator

dev/langchain4j/rag/content/aggregator/DefaultContentAggregator.java

public class DefaultContentAggregator implements ContentAggregator {

    @Override
    public List<Content> aggregate(Map<Query, Collection<List<Content>>> queryToContents) {

        // First, for each query, fuse all contents retrieved from different sources using that query.
        Map<Query, List<Content>> fused = fuse(queryToContents);

        // Then, fuse all contents retrieved using all queries
        return ReciprocalRankFuser.fuse(fused.values());
    }

    protected Map<Query, List<Content>> fuse(Map<Query, Collection<List<Content>>> queryToContents) {
        Map<Query, List<Content>> fused = new LinkedHashMap<>();
        for (Query query : queryToContents.keySet()) {
            Collection<List<Content>> contents = queryToContents.get(query);
            fused.put(query, ReciprocalRankFuser.fuse(contents));
        }
        return fused;
    }
}
DefaultContentAggregator fuses in two stages. First, for each query, it merges all the List<Content> retrieved with that query into one List<Content>. Second, it merges those per-query lists into a single List<Content>. Both stages use ReciprocalRankFuser.fuse.
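Reciprocal Rank Fusion scores each item by summing 1 / (k + rank) over every ranked list in which it appears. Items ranked highly by several lists rise to the top, and duplicates collapse into one entry. The sketch below implements that formula and the two-stage aggregation, with Strings standing in for Content. RrfSketch is hypothetical. k = 60 is the constant commonly used for RRF; it is an assumption here, not a value taken from ReciprocalRankFuser.

```java
import java.util.ArrayList;
import java.util.Collection;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;

// Hypothetical re-implementation of Reciprocal Rank Fusion plus the two-stage
// aggregation of DefaultContentAggregator, with Strings standing in for Content.
final class RrfSketch {

    // Conventional RRF constant; assumed here, not read from ReciprocalRankFuser.
    static final int K = 60;

    // score(item) = sum over lists of 1 / (K + rank), rank starting at 1.
    // An item found by several lists accumulates score and appears only once.
    static List<String> fuse(Collection<List<String>> rankedLists) {
        Map<String, Double> scores = new LinkedHashMap<>();
        for (List<String> list : rankedLists) {
            for (int i = 0; i < list.size(); i++) {
                scores.merge(list.get(i), 1.0 / (K + i + 1), Double::sum);
            }
        }
        return scores.entrySet().stream()
                .sorted(Map.Entry.<String, Double>comparingByValue().reversed())
                .map(Map.Entry::getKey)
                .collect(Collectors.toList());
    }

    // Stage 1: per query, fuse the lists from all retrievers.
    // Stage 2: fuse the per-query results into the final list.
    static List<String> aggregate(Map<String, Collection<List<String>>> queryToContents) {
        List<List<String>> perQuery = new ArrayList<>();
        for (Collection<List<String>> lists : queryToContents.values()) {
            perQuery.add(fuse(lists));
        }
        return fuse(perQuery);
    }
}
```

Consider q1 retrieving [a, b] and [b, c], and q2 retrieving [c]. The first stage gives [b, a, c] for q1 and [c] for q2. The second stage gives [c, b, a], because c is the only item found under both queries.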

ReRankingContentAggregator

dev/langchain4j/rag/content/aggregator/ReRankingContentAggregator.java

public class ReRankingContentAggregator implements ContentAggregator {

    public static final Function<Map<Query, Collection<List<Content>>>, Query> DEFAULT_QUERY_SELECTOR =
            (queryToContents) -> {
                if (queryToContents.size() > 1) {
                    throw illegalArgument(
                            "The 'queryToContents' contains %s queries, making the re-ranking ambiguous. " +
                                    "Because there are multiple queries, it is unclear which one should be " +
                                    "used for re-ranking. Please provide a 'querySelector' in the constructor/builder.",
                            queryToContents.size()
                    );
                }
                return queryToContents.keySet().iterator().next();
            };

    private final ScoringModel scoringModel;
    private final Function<Map<Query, Collection<List<Content>>>, Query> querySelector;
    private final Double minScore;
    private final Integer maxResults;

    public ReRankingContentAggregator(ScoringModel scoringModel) {
        this(scoringModel, DEFAULT_QUERY_SELECTOR, null);
    }

    public ReRankingContentAggregator(ScoringModel scoringModel,
                                      Function<Map<Query, Collection<List<Content>>>, Query> querySelector,
                                      Double minScore) {
        this(scoringModel, querySelector, minScore, null);
    }

    public ReRankingContentAggregator(ScoringModel scoringModel,
                                      Function<Map<Query, Collection<List<Content>>>, Query> querySelector,
                                      Double minScore,
                                      Integer maxResults) {
        this.scoringModel = ensureNotNull(scoringModel, "scoringModel");
        this.querySelector = getOrDefault(querySelector, DEFAULT_QUERY_SELECTOR);
        this.minScore = minScore;
        this.maxResults = getOrDefault(maxResults, Integer.MAX_VALUE);
    }

    public static ReRankingContentAggregatorBuilder builder() {
        return new ReRankingContentAggregatorBuilder();
    }

    @Override
    public List<Content> aggregate(Map<Query, Collection<List<Content>>> queryToContents) {

        if (queryToContents.isEmpty()) {
            return emptyList();
        }

        // Select a query against which all contents will be re-ranked
        Query query = querySelector.apply(queryToContents);

        // For each query, fuse all contents retrieved from different sources using that query
        Map<Query, List<Content>> queryToFusedContents = fuse(queryToContents);

        // Fuse all contents retrieved using all queries
        List<Content> fusedContents = ReciprocalRankFuser.fuse(queryToFusedContents.values());

        if (fusedContents.isEmpty()) {
            return fusedContents;
        }

        // Re-rank all the fused contents against the query selected by the query selector
        return reRankAndFilter(fusedContents, query);
    }

    protected Map<Query, List<Content>> fuse(Map<Query, Collection<List<Content>>> queryToContents) {
        Map<Query, List<Content>> fused = new LinkedHashMap<>();
        for (Query query : queryToContents.keySet()) {
            Collection<List<Content>> contents = queryToContents.get(query);
            fused.put(query, ReciprocalRankFuser.fuse(contents));
        }
        return fused;
    }

    protected List<Content> reRankAndFilter(List<Content> contents, Query query) {

        List<TextSegment> segments = contents.stream()
                .map(Content::textSegment)
                .collect(Collectors.toList());

        List<Double> scores = scoringModel.scoreAll(segments, query.text()).content();

        Map<TextSegment, Double> segmentToScore = new HashMap<>();
        for (int i = 0; i < segments.size(); i++) {
            segmentToScore.put(segments.get(i), scores.get(i));
        }

        return segmentToScore.entrySet().stream()
                .filter(entry -> minScore == null || entry.getValue() >= minScore)
                .sorted(Map.Entry.<TextSegment, Double>comparingByValue().reversed())
                .map(entry -> Content.from(entry.getKey(), Map.of(RERANKED_SCORE, entry.getValue())))
                .limit(maxResults)
                .collect(Collectors.toList());
    }

    //......
}    
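ReRankingContentAggregator performs re-ranking with a ScoringModel, such as Cohere's. The ScoringModel scores the Contents against a single Query, which has consequences when there are several Queries (for example, when ExpandingQueryTransformer is used). In that case you must provide a querySelector that picks the one Query against which all Contents are ranked; otherwise DEFAULT_QUERY_SELECTOR throws an IllegalArgumentException. Alternatively, you can write a custom implementation that scores each Content against the Query that was actually used to retrieve it, and then re-ranks on those scores.

Its aggregate method works in these steps:

  • An empty input yields an empty list.
  • Otherwise, querySelector picks the query to rank against.
  • The same two-stage fuse as in DefaultContentAggregator is applied.
  • scoringModel scores the fused contents against the selected query.
  • Contents scoring below minScore are dropped, and the rest are sorted by descending score.
  • Each remaining segment is wrapped as a Content carrying its RERANKED_SCORE, and at most maxResults of them are returned.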
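The re-ranking step can be illustrated with a small self-contained sketch. The names are hypothetical:

- selectSingleQuery imitates DEFAULT_QUERY_SELECTOR.
- overlapScore is a toy keyword-overlap scorer standing in for a real ScoringModel (such as a cross-encoder or Cohere's re-ranker).
- The input list is assumed to have already gone through the two-stage fuse.

```java
import java.util.*;
import java.util.function.BiFunction;
import java.util.stream.Collectors;

public class ReRankSketch {

    // Same rule as DEFAULT_QUERY_SELECTOR: the choice is only unambiguous when there is exactly one query.
    static <V> String selectSingleQuery(Map<String, V> queryToContents) {
        if (queryToContents.size() > 1) {
            throw new IllegalArgumentException(
                    queryToContents.size() + " queries make re-ranking ambiguous; supply a querySelector");
        }
        return queryToContents.keySet().iterator().next();
    }

    // Hypothetical stand-in for a ScoringModel: fraction of the query's words that occur in the text.
    static double overlapScore(String text, String query) {
        Set<String> words = new HashSet<>(Arrays.asList(text.toLowerCase().split("\\W+")));
        String[] queryWords = query.toLowerCase().split("\\W+");
        long hits = Arrays.stream(queryWords).filter(words::contains).count();
        return (double) hits / queryWords.length;
    }

    // Score every candidate against one query, drop those below minScore, sort descending, keep maxResults.
    static List<String> reRankAndFilter(List<String> contents, String query,
                                        BiFunction<String, String, Double> scorer,
                                        Double minScore, int maxResults) {
        return contents.stream()
                .map(c -> Map.entry(c, scorer.apply(c, query)))
                .filter(e -> minScore == null || e.getValue() >= minScore)
                .sorted(Map.Entry.<String, Double>comparingByValue().reversed())
                .limit(maxResults)
                .map(Map.Entry::getKey)
                .collect(Collectors.toList());
    }

    public static void main(String[] args) {
        String query = selectSingleQuery(Map.of("how to configure a retriever", List.of()));
        List<String> fused = List.of("Configure the retriever via builder", "Unrelated text", "Retriever basics");
        System.out.println(reRankAndFilter(fused, query, ReRankSketch::overlapScore, 0.2, 2));
    }
}
```

This run uses minScore = 0.2 and maxResults = 2:

- "Unrelated text" scores 0 and is filtered out.
- The other two candidates come back ordered by their scores of 0.4 and 0.2.

Since the minScore filter runs before the limit, maxResults is an upper bound on the result size, not a guaranteed count.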

ContentInjector

dev/langchain4j/rag/content/injector/ContentInjector.java

@Experimental
public interface ContentInjector {

    /**
     * Injects given {@link Content}s into a given {@link ChatMessage}.
     * <br>
     * This method has a default implementation in order to <b>temporarily</b> support
     * current custom implementations of {@code ContentInjector}. The default implementation will be removed soon.
     *
     * @param contents    The list of {@link Content} to be injected.
     * @param chatMessage The {@link ChatMessage} into which the {@link Content}s are to be injected.
     *                    Can be either a {@link UserMessage} or a {@link SystemMessage}.
     * @return The {@link UserMessage} with the injected {@link Content}s.
     */
    default ChatMessage inject(List<Content> contents, ChatMessage chatMessage) {

        if (!(chatMessage instanceof UserMessage)) {
            throw runtime("Please implement 'ChatMessage inject(List<Content>, ChatMessage)' method " +
                    "in order to inject contents into " + chatMessage);
        }

        return inject(contents, (UserMessage) chatMessage);
    }

    /**
     * Injects given {@link Content}s into a given {@link UserMessage}.
     *
     * @param contents    The list of {@link Content} to be injected.
     * @param userMessage The {@link UserMessage} into which the {@link Content}s are to be injected.
     * @return The {@link UserMessage} with the injected {@link Content}s.
     * @deprecated Use/implement {@link #inject(List, ChatMessage)} instead.
     */
    @Deprecated
    UserMessage inject(List<Content> contents, UserMessage userMessage);
}
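ContentInjector defines the inject method, which injects the contents into a chat message and returns the resulting ChatMessage. The newer inject(List<Content>, ChatMessage) has a default implementation that only handles a UserMessage, by delegating to the deprecated inject(List<Content>, UserMessage). For any other message type, it throws and asks the implementer to override it.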
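The default method in ContentInjector is a common way to evolve an interface without breaking existing implementers. The sketch below reproduces that bridging pattern with hypothetical stand-in types (Message, User, SystemMsg, Injector) rather than langchain4j's own classes. A legacy implementation that only overrides the user-message method still works through the new entry point, while any other message type fails fast.

```java
import java.util.List;

public class InjectorBridgeSketch {

    // Hypothetical stand-ins for langchain4j's ChatMessage, UserMessage and SystemMessage.
    interface Message {
        String text();
    }

    record User(String text) implements Message {}

    record SystemMsg(String text) implements Message {}

    interface Injector {

        // New entry point. The default body handles only User messages by delegating to the legacy
        // method, so implementations written against the old interface keep compiling and working.
        default Message inject(List<String> contents, Message message) {
            if (!(message instanceof User user)) {
                throw new IllegalStateException(
                        "Please implement inject(List, Message) in order to inject contents into " + message);
            }
            return inject(contents, user);
        }

        // Legacy, user-message-only entry point; it is the only abstract method, so old code implements just this.
        @Deprecated
        User inject(List<String> contents, User user);
    }

    public static void main(String[] args) {
        // A legacy-style implementation that appends the contents to the user text.
        Injector legacy = (contents, user) -> new User(user.text() + "\n\n" + String.join("\n\n", contents));

        Message injected = legacy.inject(List.of("doc A"), (Message) new User("question"));
        System.out.println(injected.text());

        try {
            legacy.inject(List.of("doc A"), new SystemMsg("be concise"));
        } catch (IllegalStateException e) {
            System.out.println("rejected: " + e.getMessage());
        }
    }
}
```

The cast to Message in main matters. It makes overload resolution pick the new entry point, just as a caller holding a ChatMessage would.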

DefaultContentInjector

dev/langchain4j/rag/content/injector/DefaultContentInjector.java

public class DefaultContentInjector implements ContentInjector {

    public static final PromptTemplate DEFAULT_PROMPT_TEMPLATE = PromptTemplate.from(
            """
                    {{userMessage}}

                    Answer using the following information:
                    {{contents}}""");

    private final PromptTemplate promptTemplate;
    private final List<String> metadataKeysToInclude;

    public DefaultContentInjector() {
        this(DEFAULT_PROMPT_TEMPLATE, null);
    }

    public DefaultContentInjector(List<String> metadataKeysToInclude) {
        this(DEFAULT_PROMPT_TEMPLATE, ensureNotEmpty(metadataKeysToInclude, "metadataKeysToInclude"));
    }

    public DefaultContentInjector(PromptTemplate promptTemplate) {
        this(ensureNotNull(promptTemplate, "promptTemplate"), null);
    }

    public DefaultContentInjector(PromptTemplate promptTemplate, List<String> metadataKeysToInclude) {
        this.promptTemplate = getOrDefault(promptTemplate, DEFAULT_PROMPT_TEMPLATE);
        this.metadataKeysToInclude = copyIfNotNull(metadataKeysToInclude);
    }

    public static DefaultContentInjectorBuilder builder() {
        return new DefaultContentInjectorBuilder();
    }

    @Override
    public ChatMessage inject(List<Content> contents, ChatMessage chatMessage) {

        if (contents.isEmpty()) {
            return chatMessage;
        }

        Prompt prompt = createPrompt(chatMessage, contents);
        if (chatMessage instanceof UserMessage message && isNotNullOrBlank(message.name())) {
            return prompt.toUserMessage(message.name());
        }

        return prompt.toUserMessage();
    }

    protected Prompt createPrompt(ChatMessage chatMessage, List<Content> contents) {
        return createPrompt((UserMessage) chatMessage, contents);
    }

    /**
     * @deprecated use {@link #inject(List, ChatMessage)} instead.
     */
    @Override
    @Deprecated
    public UserMessage inject(List<Content> contents, UserMessage userMessage) {

        if (contents.isEmpty()) {
            return userMessage;
        }

        Prompt prompt = createPrompt(userMessage, contents);
        if (isNotNullOrBlank(userMessage.name())) {
            return prompt.toUserMessage(userMessage.name());
        }
        return prompt.toUserMessage();
    }

    /**
     * @deprecated implement/override {@link #createPrompt(ChatMessage, List)} instead.
     */
    @Deprecated
    protected Prompt createPrompt(UserMessage userMessage, List<Content> contents) {
        Map<String, Object> variables = new HashMap<>();
        variables.put("userMessage", userMessage.singleText());
        variables.put("contents", format(contents));
        return promptTemplate.apply(variables);
    }

    protected String format(List<Content> contents) {
        return contents.stream().map(this::format).collect(joining("\n\n"));
    }

    protected String format(Content content) {

        TextSegment segment = content.textSegment();

        if (isNullOrEmpty(metadataKeysToInclude)) {
            return segment.text();
        }

        String segmentContent = segment.text();
        String segmentMetadata = format(segment.metadata());

        return format(segmentContent, segmentMetadata);
    }

    protected String format(Metadata metadata) {
        StringBuilder formattedMetadata = new StringBuilder();
        for (String metadataKey : metadataKeysToInclude) {
            String metadataValue = metadata.getString(metadataKey);
            if (metadataValue != null) {
                if (!formattedMetadata.isEmpty()) {
                    formattedMetadata.append("\n");
                }
                formattedMetadata.append(metadataKey).append(": ").append(metadataValue);
            }
        }
        return formattedMetadata.toString();
    }

    protected String format(String segmentContent, String segmentMetadata) {
        return segmentMetadata.isEmpty()
                ? segmentContent
                : String.format("content: %s\n%s", segmentContent, segmentMetadata);
    }

    //......
}    
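DefaultContentInjector injects the contents into the userMessage through promptTemplate. It formats each Content's text, prefixed with "content: " and followed by the keys listed in metadataKeysToInclude when any of them are present, and joins the results with blank lines to fill {{contents}}. The original message text fills {{userMessage}}. If contents is empty, the message is returned unchanged, and a UserMessage's name is carried over to the new message.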
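A plain-Java sketch of what DefaultContentInjector does, under two assumptions:

- Simple String.replace substitution stands in for langchain4j's PromptTemplate.
- A hypothetical Segment record stands in for TextSegment and Metadata.

The sketch follows the formatting rules of the source above:

- Contents are joined by blank lines.
- When metadataKeysToInclude is set, each content becomes a "content: ..." line followed by one "key: value" line per key that is present.

```java
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;

public class PromptInjectionSketch {

    // Same wording as DEFAULT_PROMPT_TEMPLATE; plain String.replace stands in for PromptTemplate (an assumption).
    static final String TEMPLATE = """
            {{userMessage}}

            Answer using the following information:
            {{contents}}""";

    // Hypothetical stand-in for TextSegment + Metadata: text plus string-valued metadata.
    record Segment(String text, Map<String, String> metadata) {}

    // Mirrors format(Content): plain text, or "content: ..." followed by "key: value" lines for present keys.
    static String format(Segment segment, List<String> metadataKeysToInclude) {
        if (metadataKeysToInclude == null || metadataKeysToInclude.isEmpty()) {
            return segment.text();
        }
        String meta = metadataKeysToInclude.stream()
                .filter(key -> segment.metadata().get(key) != null)
                .map(key -> key + ": " + segment.metadata().get(key))
                .collect(Collectors.joining("\n"));
        return meta.isEmpty() ? segment.text() : "content: " + segment.text() + "\n" + meta;
    }

    static String inject(String userMessage, List<Segment> contents, List<String> metadataKeysToInclude) {
        if (contents.isEmpty()) {
            return userMessage; // nothing retrieved: pass the message through unchanged
        }
        String joined = contents.stream()
                .map(s -> format(s, metadataKeysToInclude))
                .collect(Collectors.joining("\n\n"));
        return TEMPLATE.replace("{{userMessage}}", userMessage).replace("{{contents}}", joined);
    }

    public static void main(String[] args) {
        List<Segment> segments = List.of(
                new Segment("RAG adds retrieved text to the prompt.", Map.of("file_name", "rag.md")));
        System.out.println(inject("What is RAG?", segments, List.of("file_name")));
    }
}
```

Including a metadata key such as a file name lets the LLM cite its sources. Each extra key costs prompt tokens, though, which is presumably why the list of keys is opt-in.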

Summary

The entry point of langchain4j's Advanced RAG is RetrievalAugmentor. Its default implementation, DefaultRetrievalAugmentor, holds five properties: QueryTransformer, QueryRouter, ContentAggregator, ContentInjector, and Executor.

  • The DefaultRetrievalAugmentor constructor requires queryRouter to be non-null. QueryRouter holds the ContentRetrievers, and its route method returns the ones to use for a given Query.
  • If queryTransformer is null, DefaultQueryTransformer is used.
  • If contentAggregator is null, DefaultContentAggregator is used.
  • If contentInjector is null, DefaultContentInjector is used.
  • If executor is null, a ThreadPoolExecutor is created with a core pool size of 0, a maximumPoolSize of Integer.MAX_VALUE, a keepAliveTime of 1s, and a SynchronousQueue as its workQueue.
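The default executor described in the last bullet can be reproduced with the standard java.util.concurrent API. The sketch below only rebuilds a pool with those parameters and is not copied from DefaultRetrievalAugmentor. The three blocking tasks are hypothetical retrievals. They show that a SynchronousQueue never buffers work, so every submission that finds no idle thread starts a new one.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.Future;
import java.util.concurrent.SynchronousQueue;
import java.util.concurrent.ThreadPoolExecutor;
import java.util.concurrent.TimeUnit;

public class DefaultExecutorSketch {

    // Core size 0, max Integer.MAX_VALUE, 1 s keep-alive, SynchronousQueue: the parameters from the summary.
    static ThreadPoolExecutor createExecutor() {
        return new ThreadPoolExecutor(0, Integer.MAX_VALUE, 1, TimeUnit.SECONDS, new SynchronousQueue<>());
    }

    public static void main(String[] args) throws Exception {
        ThreadPoolExecutor executor = createExecutor();
        CountDownLatch release = new CountDownLatch(1);
        List<Future<String>> futures = new ArrayList<>();

        // Hypothetical retrievals that block until released, so no worker becomes idle in between.
        for (String query : List.of("q1", "q2", "q3")) {
            futures.add(executor.submit(() -> {
                release.await();
                return "results for " + query;
            }));
        }

        // The SynchronousQueue cannot hold tasks, so each submission started its own thread.
        System.out.println("pool size while blocked: " + executor.getPoolSize()); // 3

        release.countDown();
        for (Future<String> future : futures) {
            System.out.println(future.get());
        }
        executor.shutdown();
    }
}
```

This is the same shape as Executors.newCachedThreadPool(), except that idle threads are reclaimed after 1 second instead of 60 seconds. The trade-off is that the thread count is unbounded under load. Supplying your own bounded executor may therefore be preferable when many queries or retrievers run concurrently.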



codecraft