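Preface
This article takes a look at Advanced RAG in langchain4j.
Core flow
- The UserMessage is converted into an original Query.
- A QueryTransformer transforms the original Query into one or more Queries.
- A QueryRouter routes each Query to one or more ContentRetrievers.
- Each ContentRetriever retrieves the Content relevant to its Query.
- A ContentAggregator merges all retrieved Content into a single, final ranked list.
- This list of Content is injected into the original UserMessage.
- Finally, the UserMessage, now containing the original query plus the injected content, is sent to the LLM.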
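The steps above can be sketched with plain Java functions. This is a toy model and not the langchain4j API: the ToyRagPipeline class and all of its method names are hypothetical, queries and contents are plain strings, routing is a keyword match, and aggregation is a de-duplicating concatenation rather than a real ranking.

```java
import java.util.ArrayList;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.function.Function;

// Toy model of the core flow; every name here is hypothetical.
class ToyRagPipeline {

    // Transform: one original query becomes one or more queries.
    // The lower-cased variant is only a stand-in for real query expansion.
    static List<String> transform(String originalQuery) {
        return List.of(originalQuery, originalQuery.toLowerCase());
    }

    // Route: select the retrievers whose keyword occurs in the query.
    static List<Function<String, List<String>>> route(
            String query, Map<String, Function<String, List<String>>> retrieversByKeyword) {
        List<Function<String, List<String>>> selected = new ArrayList<>();
        retrieversByKeyword.forEach((keyword, retriever) -> {
            if (query.toLowerCase().contains(keyword)) {
                selected.add(retriever);
            }
        });
        return selected;
    }

    // Retrieve, aggregate and inject, in that order.
    static String augment(String userMessage,
                          Map<String, Function<String, List<String>>> retrieversByKeyword) {
        Set<String> contents = new LinkedHashSet<>(); // keeps first-seen order, drops duplicates
        for (String query : transform(userMessage)) {
            for (Function<String, List<String>> retriever : route(query, retrieversByKeyword)) {
                contents.addAll(retriever.apply(query));
            }
        }
        if (contents.isEmpty()) {
            return userMessage; // nothing retrieved: the message is sent as-is
        }
        return userMessage
                + "\n\nAnswer using the following information:\n"
                + String.join("\n\n", contents);
    }
}
```

Calling ToyRagPipeline.augment with a message that matches no keyword returns the message unchanged, which corresponds to skipping the RAG flow.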
Example
public class _02_Advanced_RAG_with_Query_Routing_Example {
/**
* Please refer to {@link Naive_RAG_Example} for a basic context.
* <p>
* Advanced RAG in LangChain4j is described here: https://github.com/langchain4j/langchain4j/pull/538
* <p>
* This example showcases the implementation of a more advanced RAG application
* using a technique known as "query routing".
* <p>
* Often, private data is spread across multiple sources and formats.
* This might include internal company documentation on Confluence, your project's code in a Git repository,
* a relational database with user data, or a search engine with the products you sell, among others.
* In a RAG flow that utilizes data from multiple sources, you will likely have multiple
* {@link EmbeddingStore}s or {@link ContentRetriever}s.
* While you could route each user query to all available {@link ContentRetriever}s,
* this approach might be inefficient and counterproductive.
* <p>
* "Query routing" is the solution to this challenge. It involves directing a query to the most appropriate
* {@link ContentRetriever} (or several). Routing can be implemented in various ways:
* - Using rules (e.g., depending on the user's privileges, location, etc.).
* - Using keywords (e.g., if a query contains words X1, X2, X3, route it to {@link ContentRetriever} X, etc.).
* - Using semantic similarity (see EmbeddingModelTextClassifierExample in this repository).
* - Using an LLM to make a routing decision.
* <p>
* For scenarios 1, 2, and 3, you can implement a custom {@link QueryRouter}.
* For scenario 4, this example will demonstrate how to use a {@link LanguageModelQueryRouter}.
*/
public static void main(String[] args) {
Assistant assistant = createAssistant();
// First, ask "What is the legacy of John Doe?"
// Then, ask "Can I cancel my reservation?"
// Now, see the logs to observe how the queries are routed to different retrievers.
startConversationWith(assistant);
}
private static Assistant createAssistant() {
EmbeddingModel embeddingModel = new BgeSmallEnV15QuantizedEmbeddingModel();
// Let's create a separate embedding store specifically for biographies.
EmbeddingStore<TextSegment> biographyEmbeddingStore =
embed(toPath("documents/biography-of-john-doe.txt"), embeddingModel);
ContentRetriever biographyContentRetriever = EmbeddingStoreContentRetriever.builder()
.embeddingStore(biographyEmbeddingStore)
.embeddingModel(embeddingModel)
.maxResults(2)
.minScore(0.6)
.build();
// Additionally, let's create a separate embedding store dedicated to terms of use.
EmbeddingStore<TextSegment> termsOfUseEmbeddingStore =
embed(toPath("documents/miles-of-smiles-terms-of-use.txt"), embeddingModel);
ContentRetriever termsOfUseContentRetriever = EmbeddingStoreContentRetriever.builder()
.embeddingStore(termsOfUseEmbeddingStore)
.embeddingModel(embeddingModel)
.maxResults(2)
.minScore(0.6)
.build();
ChatLanguageModel chatLanguageModel = OpenAiChatModel.builder()
.apiKey(OPENAI_API_KEY)
.modelName(GPT_4_O_MINI)
.build();
// Let's create a query router.
Map<ContentRetriever, String> retrieverToDescription = new HashMap<>();
retrieverToDescription.put(biographyContentRetriever, "biography of John Doe");
retrieverToDescription.put(termsOfUseContentRetriever, "terms of use of car rental company");
QueryRouter queryRouter = new LanguageModelQueryRouter(chatLanguageModel, retrieverToDescription);
RetrievalAugmentor retrievalAugmentor = DefaultRetrievalAugmentor.builder()
.queryRouter(queryRouter)
.build();
return AiServices.builder(Assistant.class)
.chatLanguageModel(chatLanguageModel)
.retrievalAugmentor(retrievalAugmentor)
.chatMemory(MessageWindowChatMemory.withMaxMessages(10))
.build();
}
private static EmbeddingStore<TextSegment> embed(Path documentPath, EmbeddingModel embeddingModel) {
DocumentParser documentParser = new TextDocumentParser();
Document document = loadDocument(documentPath, documentParser);
DocumentSplitter splitter = DocumentSplitters.recursive(300, 0);
List<TextSegment> segments = splitter.split(document);
List<Embedding> embeddings = embeddingModel.embedAll(segments).content();
EmbeddingStore<TextSegment> embeddingStore = new InMemoryEmbeddingStore<>();
embeddingStore.addAll(embeddings, segments);
return embeddingStore;
}
}
Here DefaultRetrievalAugmentor is configured with a LanguageModelQueryRouter, which in turn is given two ContentRetrievers: biographyContentRetriever and termsOfUseContentRetriever.
Source code analysis
RetrievalAugmentor
dev/langchain4j/rag/RetrievalAugmentor.java
@Experimental
public interface RetrievalAugmentor {
/**
* Augments the {@link ChatMessage} provided in the {@link AugmentationRequest} with retrieved {@link Content}s.
* <br>
* This method has a default implementation in order to <b>temporarily</b> support
* current custom implementations of {@code RetrievalAugmentor}. The default implementation will be removed soon.
*
* @param augmentationRequest The {@code AugmentationRequest} containing the {@code ChatMessage} to augment.
* @return The {@link AugmentationResult} containing the augmented {@code ChatMessage}.
*/
default AugmentationResult augment(AugmentationRequest augmentationRequest) {
if (!(augmentationRequest.chatMessage() instanceof UserMessage)) {
throw runtime("Please implement 'AugmentationResult augment(AugmentationRequest)' method " +
"in order to augment " + augmentationRequest.chatMessage().getClass());
}
UserMessage augmented = augment((UserMessage) augmentationRequest.chatMessage(), augmentationRequest.metadata());
return AugmentationResult.builder()
.chatMessage(augmented)
.build();
}
/**
* Augments the provided {@link UserMessage} with retrieved content.
*
* @param userMessage The {@link UserMessage} to be augmented.
* @param metadata The {@link Metadata} that may be useful or necessary for retrieval and augmentation.
* @return The augmented {@link UserMessage}.
* @deprecated Use/implement {@link #augment(AugmentationRequest)} instead.
*/
@Deprecated
UserMessage augment(UserMessage userMessage, Metadata metadata);
}
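The RetrievalAugmentor interface defines augment(AugmentationRequest augmentationRequest). It is the entry point to RAG in langchain4j and is responsible for retrieving the Content relevant to an AugmentationRequest and augmenting the ChatMessage with it. The default implementation exists mainly to adapt to the deprecated augment(UserMessage userMessage, Metadata metadata) method.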
DefaultRetrievalAugmentor
dev/langchain4j/rag/DefaultRetrievalAugmentor.java
public class DefaultRetrievalAugmentor implements RetrievalAugmentor {
private static final Logger log = LoggerFactory.getLogger(DefaultRetrievalAugmentor.class);
private final QueryTransformer queryTransformer;
private final QueryRouter queryRouter;
private final ContentAggregator contentAggregator;
private final ContentInjector contentInjector;
private final Executor executor;
public DefaultRetrievalAugmentor(QueryTransformer queryTransformer,
QueryRouter queryRouter,
ContentAggregator contentAggregator,
ContentInjector contentInjector,
Executor executor) {
this.queryTransformer = getOrDefault(queryTransformer, DefaultQueryTransformer::new);
this.queryRouter = ensureNotNull(queryRouter, "queryRouter");
this.contentAggregator = getOrDefault(contentAggregator, DefaultContentAggregator::new);
this.contentInjector = getOrDefault(contentInjector, DefaultContentInjector::new);
this.executor = getOrDefault(executor, DefaultRetrievalAugmentor::createDefaultExecutor);
}
private static ExecutorService createDefaultExecutor() {
return new ThreadPoolExecutor(
0, Integer.MAX_VALUE,
1, SECONDS,
new SynchronousQueue<>()
);
}
/**
* @deprecated use {@link #augment(AugmentationRequest)} instead.
*/
@Override
@Deprecated
public UserMessage augment(UserMessage userMessage, Metadata metadata) {
AugmentationRequest augmentationRequest = new AugmentationRequest(userMessage, metadata);
return (UserMessage) augment(augmentationRequest).chatMessage();
}
@Override
public AugmentationResult augment(AugmentationRequest augmentationRequest) {
ChatMessage chatMessage = augmentationRequest.chatMessage();
Metadata metadata = augmentationRequest.metadata();
Query originalQuery = Query.from(chatMessage.text(), metadata);
Collection<Query> queries = queryTransformer.transform(originalQuery);
logQueries(originalQuery, queries);
Map<Query, Collection<List<Content>>> queryToContents = process(queries);
List<Content> contents = contentAggregator.aggregate(queryToContents);
log(queryToContents, contents);
ChatMessage augmentedChatMessage = contentInjector.inject(contents, chatMessage);
log(augmentedChatMessage);
return AugmentationResult.builder()
.chatMessage(augmentedChatMessage)
.contents(contents)
.build();
}
//......
}
DefaultRetrievalAugmentor implements the RetrievalAugmentor interface and holds a queryTransformer, queryRouter, contentAggregator, contentInjector, and executor. Its augment method first converts the chatMessage into originalQuery, then calls queryTransformer.transform to obtain queries, then calls process to retrieve the Contents for those queries, then calls contentAggregator.aggregate to aggregate them, and finally calls contentInjector.inject to inject the contents into the chatMessage, producing augmentedChatMessage. The returned AugmentationResult contains both augmentedChatMessage and the injected contents.
The constructor requires queryRouter to be non-null. A null queryTransformer defaults to DefaultQueryTransformer, a null contentAggregator to DefaultContentAggregator, and a null contentInjector to DefaultContentInjector. For a null executor it creates a ThreadPoolExecutor with a corePoolSize of 0, a maximumPoolSize of Integer.MAX_VALUE, a keepAliveTime of 1s, and a SynchronousQueue as its workQueue.
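The default executor has the same shape as the pool returned by Executors.newCachedThreadPool(), except that idle threads are kept for 1 second instead of 60. A SynchronousQueue has no capacity, so a task is handed to an idle worker if one is waiting and otherwise gets a new thread. The sketch below illustrates this behaviour; the class and method names are hypothetical and only the constructor arguments come from the source above.

```java
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.SynchronousQueue;
import java.util.concurrent.ThreadPoolExecutor;
import java.util.concurrent.TimeUnit;

// Hypothetical demo class; only the pool parameters mirror createDefaultExecutor().
class DefaultExecutorSketch {

    static ThreadPoolExecutor newExecutor() {
        return new ThreadPoolExecutor(
                0, Integer.MAX_VALUE,
                1, TimeUnit.SECONDS,
                new SynchronousQueue<>());
    }

    // Submits `tasks` tasks that all block, then reports how many threads the pool holds.
    static int threadsCreatedFor(int tasks) {
        ThreadPoolExecutor executor = newExecutor();
        CountDownLatch release = new CountDownLatch(1);
        for (int i = 0; i < tasks; i++) {
            executor.execute(() -> {
                try {
                    release.await(); // keep the worker busy so it cannot take another task
                } catch (InterruptedException e) {
                    Thread.currentThread().interrupt();
                }
            });
        }
        int poolSize = executor.getPoolSize();
        release.countDown(); // let all tasks finish
        executor.shutdown();
        return poolSize;
    }

    public static void main(String[] args) {
        System.out.println(threadsCreatedFor(4));
    }
}
```

Because no worker is idle while the tasks block, each submitted task gets its own thread. In practice this means the default executor does not bound concurrency, so supplying a bounded executor through the builder may be preferable when many queries and retrievers are involved.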
private Map<Query, Collection<List<Content>>> process(Collection<Query> queries) {
if (queries.size() == 1) {
Query query = queries.iterator().next();
Collection<ContentRetriever> retrievers = queryRouter.route(query);
if (retrievers.size() == 1) {
ContentRetriever contentRetriever = retrievers.iterator().next();
List<Content> contents = contentRetriever.retrieve(query);
return singletonMap(query, singletonList(contents));
} else if (retrievers.size() > 1) {
Collection<List<Content>> contents = retrieveFromAll(retrievers, query).join();
return singletonMap(query, contents);
} else {
return emptyMap();
}
} else if (queries.size() > 1) {
Map<Query, CompletableFuture<Collection<List<Content>>>> queryToFutureContents = new ConcurrentHashMap<>();
queries.forEach(query -> {
CompletableFuture<Collection<List<Content>>> futureContents =
supplyAsync(() -> {
Collection<ContentRetriever> retrievers = queryRouter.route(query);
log(query, retrievers);
return retrievers;
},
executor
).thenCompose(retrievers -> retrieveFromAll(retrievers, query));
queryToFutureContents.put(query, futureContents);
});
return join(queryToFutureContents);
} else {
return emptyMap();
}
}
private CompletableFuture<Collection<List<Content>>> retrieveFromAll(Collection<ContentRetriever> retrievers,
Query query) {
List<CompletableFuture<List<Content>>> futureContents = retrievers.stream()
.map(retriever -> supplyAsync(() -> retrieve(retriever, query), executor))
.collect(Collectors.toList());
return allOf(futureContents.toArray(new CompletableFuture[0]))
.thenApply(ignored ->
futureContents.stream()
.map(CompletableFuture::join)
.collect(Collectors.toList()));
}
private static List<Content> retrieve(ContentRetriever retriever, Query query) {
List<Content> contents = retriever.retrieve(query);
log(query, retriever, contents);
return contents;
}
private static Map<Query, Collection<List<Content>>> join(
Map<Query, CompletableFuture<Collection<List<Content>>>> queryToFutureContents) {
return allOf(queryToFutureContents.values().toArray(new CompletableFuture[0]))
.thenApply(ignored ->
queryToFutureContents.entrySet().stream()
.collect(toMap(
Map.Entry::getKey,
entry -> entry.getValue().join()
))
).join();
}
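The process method calls queryRouter.route(query) to obtain the ContentRetrievers for each Query, then calls retrieve on each ContentRetriever to obtain a List<Content>. It treats one query and several queries differently. A single Query routed to a single ContentRetriever is retrieved synchronously on the calling thread, while a single Query routed to several ContentRetrievers is fanned out on the executor via retrieveFromAll. With several queries, both routing and retrieval run concurrently on the executor, and join waits for all of the results.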
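The fan-out in retrieveFromAll follows a common CompletableFuture pattern: start one future per retriever, wait for all of them with allOf, then collect the results in the original order. A minimal sketch with plain strings, where the FanOutSketch class and the use of Function as a retriever are assumptions made for illustration:

```java
import java.util.List;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.function.Function;
import java.util.stream.Collectors;

// Hypothetical stand-in: a "retriever" is just a function from query text to results.
class FanOutSketch {

    static List<List<String>> retrieveFromAll(List<Function<String, List<String>>> retrievers,
                                              String query,
                                              ExecutorService executor) {
        // Start one asynchronous retrieval per retriever.
        List<CompletableFuture<List<String>>> futures = retrievers.stream()
                .map(retriever -> CompletableFuture.supplyAsync(() -> retriever.apply(query), executor))
                .collect(Collectors.toList());

        // Wait for all of them, then read each result; join() no longer blocks at that point.
        return CompletableFuture.allOf(futures.toArray(new CompletableFuture[0]))
                .thenApply(ignored -> futures.stream()
                        .map(CompletableFuture::join)
                        .collect(Collectors.toList()))
                .join();
    }

    public static void main(String[] args) {
        ExecutorService executor = Executors.newCachedThreadPool();
        try {
            List<Function<String, List<String>>> retrievers = List.of(
                    q -> List.of("biography: " + q),
                    q -> List.of("terms: " + q));
            System.out.println(retrieveFromAll(retrievers, "john", executor));
        } finally {
            executor.shutdown();
        }
    }
}
```

The result list keeps the order of the retrievers regardless of which retrieval finishes first, because the futures are joined in the order they were created.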
QueryTransformer
dev/langchain4j/rag/query/transformer/QueryTransformer.java
@Experimental
public interface QueryTransformer {
/**
* Transforms the given {@link Query} into one or multiple {@link Query}s.
*
* @param query The {@link Query} to be transformed.
* @return A collection of one or more {@link Query}s derived from the original {@link Query}.
*/
Collection<Query> transform(Query query);
}
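QueryTransformer defines the transform method, which modifies or expands the original Query. Known techniques include query compression, query expansion, query rewriting, step-back prompting, and hypothetical document embeddings (HyDE). Its implementations are DefaultQueryTransformer, CompressingQueryTransformer, and ExpandingQueryTransformer.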
DefaultQueryTransformer
dev/langchain4j/rag/query/transformer/DefaultQueryTransformer.java
public class DefaultQueryTransformer implements QueryTransformer {
@Override
public Collection<Query> transform(Query query) {
return singletonList(query);
}
}
DefaultQueryTransformer returns the query unchanged, wrapped in a singleton list.
CompressingQueryTransformer
dev/langchain4j/rag/query/transformer/CompressingQueryTransformer.java
public class CompressingQueryTransformer implements QueryTransformer {
public static final PromptTemplate DEFAULT_PROMPT_TEMPLATE = PromptTemplate.from(
"""
Read and understand the conversation between the User and the AI. \
Then, analyze the new query from the User. \
Identify all relevant details, terms, and context from both the conversation and the new query. \
Reformulate this query into a clear, concise, and self-contained format suitable for information retrieval.
Conversation:
{{chatMemory}}
User query: {{query}}
It is very important that you provide only reformulated query and nothing else! \
Do not prepend a query with anything!"""
);
protected final PromptTemplate promptTemplate;
protected final ChatLanguageModel chatLanguageModel;
public CompressingQueryTransformer(ChatLanguageModel chatLanguageModel) {
this(chatLanguageModel, DEFAULT_PROMPT_TEMPLATE);
}
public CompressingQueryTransformer(ChatLanguageModel chatLanguageModel, PromptTemplate promptTemplate) {
this.chatLanguageModel = ensureNotNull(chatLanguageModel, "chatLanguageModel");
this.promptTemplate = getOrDefault(promptTemplate, DEFAULT_PROMPT_TEMPLATE);
}
public static CompressingQueryTransformerBuilder builder() {
return new CompressingQueryTransformerBuilder();
}
@Override
public Collection<Query> transform(Query query) {
List<ChatMessage> chatMemory = query.metadata().chatMemory();
if (chatMemory.isEmpty()) {
// no need to compress if there are no previous messages
return singletonList(query);
}
Prompt prompt = createPrompt(query, format(chatMemory));
String compressedQueryText = chatLanguageModel.chat(prompt.text());
Query compressedQuery = query.metadata() == null
? Query.from(compressedQueryText)
: Query.from(compressedQueryText, query.metadata());
return singletonList(compressedQuery);
}
protected String format(List<ChatMessage> chatMemory) {
return chatMemory.stream()
.map(this::format)
.filter(Objects::nonNull)
.collect(joining("\n"));
}
protected String format(ChatMessage message) {
if (message instanceof UserMessage) {
return "User: " + message.text();
} else if (message instanceof AiMessage aiMessage) {
if (aiMessage.hasToolExecutionRequests()) {
return null;
}
return "AI: " + aiMessage.text();
} else {
return null;
}
}
protected Prompt createPrompt(Query query, String chatMemory) {
Map<String, Object> variables = new HashMap<>();
variables.put("query", query.text());
variables.put("chatMemory", chatMemory);
return promptTemplate.apply(variables);
}
//......
}
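CompressingQueryTransformer is only useful when ChatMemory is enabled: if the chat memory is empty, it returns the query unchanged. Otherwise it uses DEFAULT_PROMPT_TEMPLATE to build a prompt from the query and the previous ChatMessages and sends it to the LLM, which rewrites the query into a compressed, self-contained form.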
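To make the {{chatMemory}} variable concrete, the sketch below mimics how the two format methods turn the conversation into text: user and AI messages become prefixed lines, everything else is dropped. The Message record and its string roles are hypothetical simplifications of the ChatMessage hierarchy, and the tool-execution check on AI messages is left out.

```java
import java.util.List;
import java.util.Objects;
import java.util.stream.Collectors;

// Hypothetical simplification: a message is a role ("user", "ai", "system", ...) plus text.
class CompressionPromptSketch {

    record Message(String role, String text) {}

    // Returns a prefixed line for user and AI messages, null for anything else.
    static String formatMessage(Message message) {
        return switch (message.role()) {
            case "user" -> "User: " + message.text();
            case "ai" -> "AI: " + message.text();
            default -> null;
        };
    }

    // Joins the formatted lines with newlines, skipping messages that formatted to null.
    static String formatMemory(List<Message> chatMemory) {
        return chatMemory.stream()
                .map(CompressionPromptSketch::formatMessage)
                .filter(Objects::nonNull)
                .collect(Collectors.joining("\n"));
    }

    public static void main(String[] args) {
        System.out.println(formatMemory(List.of(
                new Message("system", "You are a polite assistant."),
                new Message("user", "Who is John Doe?"),
                new Message("ai", "A pioneer of aviation."))));
    }
}
```

With this conversation in the prompt, a follow-up such as "What is his legacy?" can be rewritten by the LLM into a query that names John Doe explicitly, which is what makes it usable for retrieval.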
ExpandingQueryTransformer
dev/langchain4j/rag/query/transformer/ExpandingQueryTransformer.java
public class ExpandingQueryTransformer implements QueryTransformer {
public static final PromptTemplate DEFAULT_PROMPT_TEMPLATE = PromptTemplate.from(
"""
Generate {{n}} different versions of a provided user query. \
Each version should be worded differently, using synonyms or alternative sentence structures, \
but they should all retain the original meaning. \
These versions will be used to retrieve relevant documents. \
It is very important to provide each query version on a separate line, \
without enumerations, hyphens, or any additional formatting!
User query: {{query}}"""
);
public static final int DEFAULT_N = 3;
protected final ChatLanguageModel chatLanguageModel;
protected final PromptTemplate promptTemplate;
protected final int n;
public ExpandingQueryTransformer(ChatLanguageModel chatLanguageModel) {
this(chatLanguageModel, DEFAULT_PROMPT_TEMPLATE, DEFAULT_N);
}
public ExpandingQueryTransformer(ChatLanguageModel chatLanguageModel, int n) {
this(chatLanguageModel, DEFAULT_PROMPT_TEMPLATE, n);
}
public ExpandingQueryTransformer(ChatLanguageModel chatLanguageModel, PromptTemplate promptTemplate) {
this(chatLanguageModel, ensureNotNull(promptTemplate, "promptTemplate"), DEFAULT_N);
}
public ExpandingQueryTransformer(ChatLanguageModel chatLanguageModel, PromptTemplate promptTemplate, Integer n) {
this.chatLanguageModel = ensureNotNull(chatLanguageModel, "chatLanguageModel");
this.promptTemplate = getOrDefault(promptTemplate, DEFAULT_PROMPT_TEMPLATE);
this.n = ensureGreaterThanZero(getOrDefault(n, DEFAULT_N), "n");
}
public static ExpandingQueryTransformerBuilder builder() {
return new ExpandingQueryTransformerBuilder();
}
@Override
public Collection<Query> transform(Query query) {
Prompt prompt = createPrompt(query);
String response = chatLanguageModel.chat(prompt.text());
List<String> queries = parse(response);
return queries.stream()
.map(queryText -> query.metadata() == null
? Query.from(queryText)
: Query.from(queryText, query.metadata()))
.collect(toList());
}
protected Prompt createPrompt(Query query) {
Map<String, Object> variables = new HashMap<>();
variables.put("query", query.text());
variables.put("n", n);
return promptTemplate.apply(variables);
}
protected List<String> parse(String queries) {
return stream(queries.split("\n"))
.filter(Utils::isNotNullOrBlank)
.collect(toList());
}
//......
}
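ExpandingQueryTransformer uses DEFAULT_PROMPT_TEMPLATE to ask the LLM for n differently worded versions of the user query (n defaults to 3), then splits the response into one Query per non-blank line.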
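The parsing step is simple enough to reproduce with the standard library. The sketch below mirrors the parse method shown above, with String.isBlank standing in for Utils::isNotNullOrBlank; the class name is hypothetical.

```java
import java.util.Arrays;
import java.util.List;
import java.util.stream.Collectors;

// Hypothetical class; mirrors the line-splitting logic of parse(String) above.
class ExpandedQueryParserSketch {

    // One query per non-blank line; lines are not trimmed or otherwise cleaned.
    static List<String> parse(String response) {
        return Arrays.stream(response.split("\n"))
                .filter(line -> !line.isBlank())
                .collect(Collectors.toList());
    }

    public static void main(String[] args) {
        System.out.println(parse("Can I cancel my booking?\n\nIs cancellation possible?"));
    }
}
```

Since lines are kept as they are, any numbering or bullets the LLM adds despite the prompt's instructions would end up inside the query text, which is presumably why the prompt insists on plain lines.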
QueryRouter
dev/langchain4j/rag/query/router/QueryRouter.java
@Experimental
public interface QueryRouter {
/**
* Routes the given {@link Query} to one or multiple {@link ContentRetriever}s.
*
* @param query The {@link Query} to be routed.
* @return A collection of one or more {@link ContentRetriever}s to which the {@link Query} should be routed.
*/
Collection<ContentRetriever> route(Query query);
}
QueryRouter defines the route method, which returns the ContentRetrievers that a given query should be sent to. It has two implementations: DefaultQueryRouter and LanguageModelQueryRouter.
DefaultQueryRouter
dev/langchain4j/rag/query/router/DefaultQueryRouter.java
public class DefaultQueryRouter implements QueryRouter {
private final Collection<ContentRetriever> contentRetrievers;
public DefaultQueryRouter(ContentRetriever... contentRetrievers) {
this(asList(contentRetrievers));
}
public DefaultQueryRouter(Collection<ContentRetriever> contentRetrievers) {
this.contentRetrievers = unmodifiableCollection(ensureNotEmpty(contentRetrievers, "contentRetrievers"));
}
@Override
public Collection<ContentRetriever> route(Query query) {
return contentRetrievers;
}
}
The DefaultQueryRouter constructor requires a non-empty collection of contentRetrievers, and its route method returns all of them for every query.
LanguageModelQueryRouter
dev/langchain4j/rag/query/router/LanguageModelQueryRouter.java
public class LanguageModelQueryRouter implements QueryRouter {
private static final Logger log = LoggerFactory.getLogger(LanguageModelQueryRouter.class);
public static final PromptTemplate DEFAULT_PROMPT_TEMPLATE = PromptTemplate.from(
"""
Based on the user query, determine the most suitable data source(s) \
to retrieve relevant information from the following options:
{{options}}
It is very important that your answer consists of either a single number \
or multiple numbers separated by commas and nothing else!
User query: {{query}}"""
);
protected final ChatLanguageModel chatLanguageModel;
protected final PromptTemplate promptTemplate;
protected final String options;
protected final Map<Integer, ContentRetriever> idToRetriever;
protected final FallbackStrategy fallbackStrategy;
public LanguageModelQueryRouter(ChatLanguageModel chatLanguageModel,
Map<ContentRetriever, String> retrieverToDescription) {
this(chatLanguageModel, retrieverToDescription, DEFAULT_PROMPT_TEMPLATE, DO_NOT_ROUTE);
}
public LanguageModelQueryRouter(ChatLanguageModel chatLanguageModel,
Map<ContentRetriever, String> retrieverToDescription,
PromptTemplate promptTemplate,
FallbackStrategy fallbackStrategy) {
this.chatLanguageModel = ensureNotNull(chatLanguageModel, "chatLanguageModel");
ensureNotEmpty(retrieverToDescription, "retrieverToDescription");
this.promptTemplate = getOrDefault(promptTemplate, DEFAULT_PROMPT_TEMPLATE);
Map<Integer, ContentRetriever> idToRetriever = new HashMap<>();
StringBuilder optionsBuilder = new StringBuilder();
int id = 1;
for (Map.Entry<ContentRetriever, String> entry : retrieverToDescription.entrySet()) {
idToRetriever.put(id, ensureNotNull(entry.getKey(), "ContentRetriever"));
if (id > 1) {
optionsBuilder.append("\n");
}
optionsBuilder.append(id);
optionsBuilder.append(": ");
optionsBuilder.append(ensureNotBlank(entry.getValue(), "ContentRetriever description"));
id++;
}
this.idToRetriever = idToRetriever;
this.options = optionsBuilder.toString();
this.fallbackStrategy = getOrDefault(fallbackStrategy, DO_NOT_ROUTE);
}
public static LanguageModelQueryRouterBuilder builder() {
return new LanguageModelQueryRouterBuilder();
}
@Override
public Collection<ContentRetriever> route(Query query) {
Prompt prompt = createPrompt(query);
try {
String response = chatLanguageModel.chat(prompt.text());
return parse(response);
} catch (Exception e) {
log.warn("Failed to route query '{}'", query.text(), e);
return fallback(query, e);
}
}
protected Collection<ContentRetriever> fallback(Query query, Exception e) {
return switch (fallbackStrategy) {
case DO_NOT_ROUTE -> {
log.debug("Fallback: query '{}' will not be routed", query.text());
yield emptyList();
}
case ROUTE_TO_ALL -> {
log.debug("Fallback: query '{}' will be routed to all available content retrievers", query.text());
yield new ArrayList<>(idToRetriever.values());
}
default -> throw new RuntimeException(e);
};
}
protected Prompt createPrompt(Query query) {
Map<String, Object> variables = new HashMap<>();
variables.put("query", query.text());
variables.put("options", options);
return promptTemplate.apply(variables);
}
protected Collection<ContentRetriever> parse(String choices) {
return stream(choices.split(","))
.map(String::trim)
.map(Integer::parseInt)
.map(idToRetriever::get)
.collect(toList());
}
/**
* Strategy applied if the call to the LLM fails or if the LLM does not return a valid response.
* It could be because it was formatted improperly, or it is unclear where to route.
*/
public enum FallbackStrategy {
/**
* In this case, the {@link Query} will not be routed to any {@link ContentRetriever},
* thus skipping the RAG flow. No content will be appended to the original {@link UserMessage}.
*/
DO_NOT_ROUTE,
/**
* In this case, the {@link Query} will be routed to all {@link ContentRetriever}s.
*/
ROUTE_TO_ALL,
/**
* In this case, an original exception will be re-thrown, and the RAG flow will fail.
*/
FAIL
}
//......
}
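LanguageModelQueryRouter uses chatLanguageModel to make the routing decision. Its constructor requires a chatLanguageModel and a Map<ContentRetriever, String>, where the String is a description of the ContentRetriever that helps chatLanguageModel decide where to route the query. It also defines a fallbackStrategy that specifies what to do when the call to chatLanguageModel fails or its response cannot be parsed. The default is DO_NOT_ROUTE.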
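The interplay between parsing and the fallback strategy can be shown without an LLM. In the sketch below the model's response is passed in as a string and retrievers are represented by their names; the class, the Fallback enum, and the method signatures are hypothetical, while the parsing steps follow the parse method above.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;

// Hypothetical sketch: retrievers are represented by their names.
class RoutingFallbackSketch {

    enum Fallback { DO_NOT_ROUTE, ROUTE_TO_ALL, FAIL }

    // Turns "1, 2" into the retrievers with ids 1 and 2.
    // Throws NumberFormatException when a piece is not an integer.
    static List<String> parse(String response, Map<Integer, String> idToRetriever) {
        return Arrays.stream(response.split(","))
                .map(String::trim)
                .map(Integer::parseInt)
                .map(idToRetriever::get)
                .collect(Collectors.toList());
    }

    // Applies the fallback when the response cannot be parsed.
    static List<String> route(String response, Map<Integer, String> idToRetriever, Fallback fallback) {
        try {
            return parse(response, idToRetriever);
        } catch (RuntimeException e) {
            return switch (fallback) {
                case DO_NOT_ROUTE -> List.of();
                case ROUTE_TO_ALL -> new ArrayList<>(idToRetriever.values());
                case FAIL -> throw new RuntimeException(e);
            };
        }
    }
}
```

With DO_NOT_ROUTE, a chatty answer such as "I would pick the biography" yields an empty list, so no content is retrieved and the user message reaches the LLM without augmentation. ROUTE_TO_ALL trades extra retrieval cost for a better chance of still finding relevant content.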
ContentRetriever
dev/langchain4j/rag/content/retriever/ContentRetriever.java
public interface ContentRetriever {
/**
* Retrieves relevant {@link Content}s using a given {@link Query}.
* The {@link Content}s are sorted by relevance, with the most relevant {@link Content}s appearing
* at the beginning of the returned {@code List<Content>}.
*
* @param query The {@link Query} to use for retrieval.
* @return A list of retrieved {@link Content}s.
*/
List<Content> retrieve(Query query);
}
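The ContentRetriever interface defines the retrieve method, which returns a list of Content for a query, ordered by relevance. Its implementations include EmbeddingStoreContentRetriever and WebSearchContentRetriever.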
ContentAggregator
dev/langchain4j/rag/content/aggregator/ContentAggregator.java
@Experimental
public interface ContentAggregator {
/**
* Aggregates all {@link Content}s retrieved by all {@link ContentRetriever}s using all {@link Query}s.
* The {@link Content}s, both on input and output, are sorted by relevance,
* with the most relevant {@link Content}s appearing at the beginning of {@code List<Content>}.
*
* @param queryToContents A map from a {@link Query} to all {@code List<Content>} retrieved with that {@link Query}.
* Given that each {@link Query} can be routed to multiple {@link ContentRetriever}s, the
* value of this map is a {@code Collection<List<Content>>}
* rather than a simple {@code List<Content>}.
* @return A list of aggregated {@link Content}s.
*/
List<Content> aggregate(Map<Query, Collection<List<Content>>> queryToContents);
}
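The ContentAggregator interface defines the aggregate method, which aggregates queryToContents into a single list of Content. The purpose of this step is to make sure that the Contents passed to the LLM are the most relevant ones and contain no redundancy. Effective techniques include re-ranking (ReRankingContentAggregator) and Reciprocal Rank Fusion (ReciprocalRankFuser, which is used by both DefaultContentAggregator and ReRankingContentAggregator).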
DefaultContentAggregator
dev/langchain4j/rag/content/aggregator/DefaultContentAggregator.java
public class DefaultContentAggregator implements ContentAggregator {
@Override
public List<Content> aggregate(Map<Query, Collection<List<Content>>> queryToContents) {
// First, for each query, fuse all contents retrieved from different sources using that query.
Map<Query, List<Content>> fused = fuse(queryToContents);
// Then, fuse all contents retrieved using all queries
return ReciprocalRankFuser.fuse(fused.values());
}
protected Map<Query, List<Content>> fuse(Map<Query, Collection<List<Content>>> queryToContents) {
Map<Query, List<Content>> fused = new LinkedHashMap<>();
for (Query query : queryToContents.keySet()) {
Collection<List<Content>> contents = queryToContents.get(query);
fused.put(query, ReciprocalRankFuser.fuse(contents));
}
return fused;
}
}
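DefaultContentAggregator performs a two-stage fuse. In the first stage, all the List<Content> retrieved for each query are merged into one List<Content> per query. In the second stage, those per-query lists (the results of the first stage) are merged into a single List<Content>. Both stages use ReciprocalRankFuser.fuse.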
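ReciprocalRankFuser itself is not shown in this article, so the following is only a sketch of the general Reciprocal Rank Fusion technique: every item receives 1 / (k + rank) from each list it appears in, and items are ordered by the sum. The constant k = 60 is the value commonly used for this technique and is an assumption here, as are the class name and the use of strings for contents.

```java
import java.util.ArrayList;
import java.util.Collection;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;

// Generic Reciprocal Rank Fusion over strings; not the library's implementation.
class ReciprocalRankFusionSketch {

    // Assumed smoothing constant; 60 is the commonly used value.
    static final int K = 60;

    static List<String> fuse(Collection<List<String>> rankedLists) {
        Map<String, Double> scores = new LinkedHashMap<>();
        for (List<String> ranked : rankedLists) {
            for (int i = 0; i < ranked.size(); i++) {
                int rank = i + 1; // ranks are 1-based
                scores.merge(ranked.get(i), 1.0 / (K + rank), Double::sum);
            }
        }
        // Highest fused score first.
        List<String> fused = new ArrayList<>(scores.keySet());
        fused.sort((a, b) -> Double.compare(scores.get(b), scores.get(a)));
        return fused;
    }

    public static void main(String[] args) {
        System.out.println(fuse(List.of(
                List.of("a", "b", "c"),
                List.of("b", "c", "d")))); // prints [b, c, a, d]
    }
}
```

In the example, b and c appear in both lists and therefore outrank a, even though a is first in one list. This is the property that makes the technique suitable for merging results from several retrievers without comparing their raw scores.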
ReRankingContentAggregator
dev/langchain4j/rag/content/aggregator/ReRankingContentAggregator.java
public class ReRankingContentAggregator implements ContentAggregator {
public static final Function<Map<Query, Collection<List<Content>>>, Query> DEFAULT_QUERY_SELECTOR =
(queryToContents) -> {
if (queryToContents.size() > 1) {
throw illegalArgument(
"The 'queryToContents' contains %s queries, making the re-ranking ambiguous. " +
"Because there are multiple queries, it is unclear which one should be " +
"used for re-ranking. Please provide a 'querySelector' in the constructor/builder.",
queryToContents.size()
);
}
return queryToContents.keySet().iterator().next();
};
private final ScoringModel scoringModel;
private final Function<Map<Query, Collection<List<Content>>>, Query> querySelector;
private final Double minScore;
private final Integer maxResults;
public ReRankingContentAggregator(ScoringModel scoringModel) {
this(scoringModel, DEFAULT_QUERY_SELECTOR, null);
}
public ReRankingContentAggregator(ScoringModel scoringModel,
Function<Map<Query, Collection<List<Content>>>, Query> querySelector,
Double minScore) {
this(scoringModel, querySelector, minScore, null);
}
public ReRankingContentAggregator(ScoringModel scoringModel,
Function<Map<Query, Collection<List<Content>>>, Query> querySelector,
Double minScore,
Integer maxResults) {
this.scoringModel = ensureNotNull(scoringModel, "scoringModel");
this.querySelector = getOrDefault(querySelector, DEFAULT_QUERY_SELECTOR);
this.minScore = minScore;
this.maxResults = getOrDefault(maxResults, Integer.MAX_VALUE);
}
public static ReRankingContentAggregatorBuilder builder() {
return new ReRankingContentAggregatorBuilder();
}
@Override
public List<Content> aggregate(Map<Query, Collection<List<Content>>> queryToContents) {
if (queryToContents.isEmpty()) {
return emptyList();
}
// Select a query against which all contents will be re-ranked
Query query = querySelector.apply(queryToContents);
// For each query, fuse all contents retrieved from different sources using that query
Map<Query, List<Content>> queryToFusedContents = fuse(queryToContents);
// Fuse all contents retrieved using all queries
List<Content> fusedContents = ReciprocalRankFuser.fuse(queryToFusedContents.values());
if (fusedContents.isEmpty()) {
return fusedContents;
}
// Re-rank all the fused contents against the query selected by the query selector
return reRankAndFilter(fusedContents, query);
}
protected Map<Query, List<Content>> fuse(Map<Query, Collection<List<Content>>> queryToContents) {
Map<Query, List<Content>> fused = new LinkedHashMap<>();
for (Query query : queryToContents.keySet()) {
Collection<List<Content>> contents = queryToContents.get(query);
fused.put(query, ReciprocalRankFuser.fuse(contents));
}
return fused;
}
protected List<Content> reRankAndFilter(List<Content> contents, Query query) {
List<TextSegment> segments = contents.stream()
.map(Content::textSegment)
.collect(Collectors.toList());
List<Double> scores = scoringModel.scoreAll(segments, query.text()).content();
Map<TextSegment, Double> segmentToScore = new HashMap<>();
for (int i = 0; i < segments.size(); i++) {
segmentToScore.put(segments.get(i), scores.get(i));
}
return segmentToScore.entrySet().stream()
.filter(entry -> minScore == null || entry.getValue() >= minScore)
.sorted(Map.Entry.<TextSegment, Double>comparingByValue().reversed())
.map(entry -> Content.from(entry.getKey(), Map.of(RERANKED_SCORE, entry.getValue())))
.limit(maxResults)
.collect(Collectors.toList());
}
//......
}
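ReRankingContentAggregator performs re-ranking with a ScoringModel, such as the one from Cohere. The ScoringModel scores the Contents against a Query. If there are multiple Queries (for example, when ExpandingQueryTransformer is used), a querySelector must be provided to choose the single Query against which all Contents are ranked. Alternatively, a custom implementation can score each Content against the Query that was used to retrieve it and re-rank based on those scores. The aggregate method first selects a query via querySelector, then performs the two-stage fuse, then has scoringModel score the fusedContents against the selected query, and finally filters by minScore, sorts by score in descending order, and returns at most maxResults Contents.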
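The score, filter, sort, and limit sequence of reRankAndFilter can be sketched with a plain scoring function in place of a ScoringModel. Everything named below (the class, the Scored record, and the ToDoubleBiFunction scorer) is hypothetical; a real scoring model would be a remote or local re-ranker.

```java
import java.util.Comparator;
import java.util.List;
import java.util.function.ToDoubleBiFunction;
import java.util.stream.Collectors;

// Hypothetical sketch of re-ranking with a pluggable scoring function.
class ReRankSketch {

    record Scored(String text, double score) {}

    static List<Scored> reRankAndFilter(List<String> contents,
                                        String query,
                                        ToDoubleBiFunction<String, String> scorer,
                                        Double minScore,
                                        int maxResults) {
        return contents.stream()
                // Score every content against the single selected query.
                .map(text -> new Scored(text, scorer.applyAsDouble(text, query)))
                // A null minScore disables filtering, as in the source above.
                .filter(scored -> minScore == null || scored.score() >= minScore)
                // Best score first.
                .sorted(Comparator.comparingDouble(Scored::score).reversed())
                .limit(maxResults)
                .collect(Collectors.toList());
    }
}
```

One difference from the source is worth noting: the source collects scores into a HashMap keyed by TextSegment, so identical segments collapse into one entry, whereas this sketch keeps duplicates.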
ContentInjector
dev/langchain4j/rag/content/injector/ContentInjector.java
@Experimental
public interface ContentInjector {
/**
* Injects given {@link Content}s into a given {@link ChatMessage}.
* <br>
* This method has a default implementation in order to <b>temporarily</b> support
* current custom implementations of {@code ContentInjector}. The default implementation will be removed soon.
*
* @param contents The list of {@link Content} to be injected.
* @param chatMessage The {@link ChatMessage} into which the {@link Content}s are to be injected.
* Can be either a {@link UserMessage} or a {@link SystemMessage}.
* @return The {@link UserMessage} with the injected {@link Content}s.
*/
default ChatMessage inject(List<Content> contents, ChatMessage chatMessage) {
if (!(chatMessage instanceof UserMessage)) {
throw runtime("Please implement 'ChatMessage inject(List<Content>, ChatMessage)' method " +
"in order to inject contents into " + chatMessage);
}
return inject(contents, (UserMessage) chatMessage);
}
/**
* Injects given {@link Content}s into a given {@link UserMessage}.
*
* @param contents The list of {@link Content} to be injected.
* @param userMessage The {@link UserMessage} into which the {@link Content}s are to be injected.
* @return The {@link UserMessage} with the injected {@link Content}s.
* @deprecated Use/implement {@link #inject(List, ChatMessage)} instead.
*/
@Deprecated
UserMessage inject(List<Content> contents, UserMessage userMessage);
}
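ContentInjector defines the inject method, which injects the contents into the given ChatMessage (a UserMessage or a SystemMessage) and returns the resulting ChatMessage.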
DefaultContentInjector
dev/langchain4j/rag/content/injector/DefaultContentInjector.java
public class DefaultContentInjector implements ContentInjector {
public static final PromptTemplate DEFAULT_PROMPT_TEMPLATE = PromptTemplate.from(
"""
{{userMessage}}
Answer using the following information:
{{contents}}""");
private final PromptTemplate promptTemplate;
private final List<String> metadataKeysToInclude;
public DefaultContentInjector() {
this(DEFAULT_PROMPT_TEMPLATE, null);
}
public DefaultContentInjector(List<String> metadataKeysToInclude) {
this(DEFAULT_PROMPT_TEMPLATE, ensureNotEmpty(metadataKeysToInclude, "metadataKeysToInclude"));
}
public DefaultContentInjector(PromptTemplate promptTemplate) {
this(ensureNotNull(promptTemplate, "promptTemplate"), null);
}
public DefaultContentInjector(PromptTemplate promptTemplate, List<String> metadataKeysToInclude) {
this.promptTemplate = getOrDefault(promptTemplate, DEFAULT_PROMPT_TEMPLATE);
this.metadataKeysToInclude = copyIfNotNull(metadataKeysToInclude);
}
public static DefaultContentInjectorBuilder builder() {
return new DefaultContentInjectorBuilder();
}
@Override
public ChatMessage inject(List<Content> contents, ChatMessage chatMessage) {
if (contents.isEmpty()) {
return chatMessage;
}
Prompt prompt = createPrompt(chatMessage, contents);
if (chatMessage instanceof UserMessage message && isNotNullOrBlank(message.name())) {
return prompt.toUserMessage(message.name());
}
return prompt.toUserMessage();
}
protected Prompt createPrompt(ChatMessage chatMessage, List<Content> contents) {
return createPrompt((UserMessage) chatMessage, contents);
}
/**
* @deprecated use {@link #inject(List, ChatMessage)} instead.
*/
@Override
@Deprecated
public UserMessage inject(List<Content> contents, UserMessage userMessage) {
if (contents.isEmpty()) {
return userMessage;
}
Prompt prompt = createPrompt(userMessage, contents);
if (isNotNullOrBlank(userMessage.name())) {
return prompt.toUserMessage(userMessage.name());
}
return prompt.toUserMessage();
}
/**
* @deprecated implement/override {@link #createPrompt(ChatMessage, List)} instead.
*/
@Deprecated
protected Prompt createPrompt(UserMessage userMessage, List<Content> contents) {
Map<String, Object> variables = new HashMap<>();
variables.put("userMessage", userMessage.singleText());
variables.put("contents", format(contents));
return promptTemplate.apply(variables);
}
protected String format(List<Content> contents) {
return contents.stream().map(this::format).collect(joining("\n\n"));
}
protected String format(Content content) {
TextSegment segment = content.textSegment();
if (isNullOrEmpty(metadataKeysToInclude)) {
return segment.text();
}
String segmentContent = segment.text();
String segmentMetadata = format(segment.metadata());
return format(segmentContent, segmentMetadata);
}
protected String format(Metadata metadata) {
StringBuilder formattedMetadata = new StringBuilder();
for (String metadataKey : metadataKeysToInclude) {
String metadataValue = metadata.getString(metadataKey);
if (metadataValue != null) {
if (!formattedMetadata.isEmpty()) {
formattedMetadata.append("\n");
}
formattedMetadata.append(metadataKey).append(": ").append(metadataValue);
}
}
return formattedMetadata.toString();
}
protected String format(String segmentContent, String segmentMetadata) {
return segmentMetadata.isEmpty()
? segmentContent
: String.format("content: %s\n%s", segmentContent, segmentMetadata);
}
//......
}
DefaultContentInjector uses promptTemplate to inject the contents into the userMessage. If contents is empty, it returns the message unchanged.
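The effect of metadataKeysToInclude is easiest to see on a concrete string. The sketch below reproduces the formatting rules from the format methods above using plain strings and maps; the class and the Segment record are hypothetical, and the exact blank lines of the default template are an assumption because they are not visible in the listing.

```java
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;

// Hypothetical sketch of the injector's formatting rules.
class InjectionSketch {

    record Segment(String text, Map<String, String> metadata) {}

    // Without keys, or when none of the keys has a value, only the text is used.
    static String format(Segment segment, List<String> metadataKeysToInclude) {
        if (metadataKeysToInclude == null || metadataKeysToInclude.isEmpty()) {
            return segment.text();
        }
        String metadata = metadataKeysToInclude.stream()
                .filter(key -> segment.metadata().get(key) != null)
                .map(key -> key + ": " + segment.metadata().get(key))
                .collect(Collectors.joining("\n"));
        return metadata.isEmpty()
                ? segment.text()
                : String.format("content: %s\n%s", segment.text(), metadata);
    }

    // Appends the formatted contents to the user message (template whitespace assumed).
    static String inject(String userMessage, List<Segment> segments, List<String> keys) {
        if (segments.isEmpty()) {
            return userMessage;
        }
        String contents = segments.stream()
                .map(segment -> format(segment, keys))
                .collect(Collectors.joining("\n\n"));
        return userMessage + "\n\nAnswer using the following information:\n" + contents;
    }
}
```

Including a key such as a source file name lets the LLM cite where an answer came from, at the cost of a few extra tokens per segment.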
Summary
The entry point to Advanced RAG in langchain4j is RetrievalAugmentor. Its default implementation, DefaultRetrievalAugmentor, holds a QueryTransformer, a QueryRouter, a ContentAggregator, a ContentInjector, and an Executor.
- The DefaultRetrievalAugmentor constructor requires queryRouter to be non-null. The route method of QueryRouter returns the ContentRetrievers, so the retrievers are supplied through the QueryRouter.
- A null queryTransformer defaults to DefaultQueryTransformer.
- A null contentAggregator defaults to DefaultContentAggregator.
- A null contentInjector defaults to DefaultContentInjector.
- A null executor defaults to a ThreadPoolExecutor with a corePoolSize of 0, a maximumPoolSize of Integer.MAX_VALUE, a keepAliveTime of 1s, and a SynchronousQueue as its workQueue.