Introduction
This article takes a look at Naive RAG in langchain4j.
Example
public class Naive_RAG_Example {

    /**
     * This example demonstrates how to implement a naive Retrieval-Augmented Generation (RAG) application.
     * By "naive", we mean that we won't use any advanced RAG techniques.
     * In each interaction with the Large Language Model (LLM), we will:
     * 1. Take the user's query as-is.
     * 2. Embed it using an embedding model.
     * 3. Use the query's embedding to search an embedding store (containing small segments of your documents)
     *    for the X most relevant segments.
     * 4. Append the found segments to the user's query.
     * 5. Send the combined input (user query + segments) to the LLM.
     * 6. Hope that:
     *    - The user's query is well-formulated and contains all necessary details for retrieval.
     *    - The found segments are relevant to the user's query.
     */
    public static void main(String[] args) {

        // Let's create an assistant that will know about our document
        Assistant assistant = createAssistant("documents/miles-of-smiles-terms-of-use.txt");

        // Now, let's start the conversation with the assistant. We can ask questions like:
        // - Can I cancel my reservation?
        // - I had an accident, should I pay extra?
        startConversationWith(assistant);
    }

    private static Assistant createAssistant(String documentPath) {

        // First, let's create a chat model, also known as an LLM, which will answer our queries.
        // In this example, we will use OpenAI's gpt-4o-mini, but you can choose any supported model.
        // LangChain4j currently supports more than 10 popular LLM providers.
        ChatLanguageModel chatLanguageModel = OpenAiChatModel.builder()
                .apiKey(OPENAI_API_KEY)
                .modelName(GPT_4_O_MINI)
                .build();

        // Now, let's load a document that we want to use for RAG.
        // We will use the terms of use from an imaginary car rental company, "Miles of Smiles".
        // For this example, we'll import only a single document, but you can load as many as you need.
        // LangChain4j offers built-in support for loading documents from various sources:
        // File System, URL, Amazon S3, Azure Blob Storage, GitHub, Tencent COS.
        // Additionally, LangChain4j supports parsing multiple document types:
        // text, pdf, doc, xls, ppt.
        // However, you can also manually import your data from other sources.
        DocumentParser documentParser = new TextDocumentParser();
        Document document = loadDocument(toPath(documentPath), documentParser);

        // Now, we need to split this document into smaller segments, also known as "chunks."
        // This approach allows us to send only relevant segments to the LLM in response to a user query,
        // rather than the entire document. For instance, if a user asks about cancellation policies,
        // we will identify and send only those segments related to cancellation.
        // A good starting point is to use a recursive document splitter that initially attempts
        // to split by paragraphs. If a paragraph is too large to fit into a single segment,
        // the splitter will recursively divide it by newlines, then by sentences, and finally by words,
        // if necessary, to ensure each piece of text fits into a single segment.
        DocumentSplitter splitter = DocumentSplitters.recursive(300, 0);
        List<TextSegment> segments = splitter.split(document);

        // Now, we need to embed (also known as "vectorize") these segments.
        // Embedding is needed for performing similarity searches.
        // For this example, we'll use a local in-process embedding model, but you can choose any supported model.
        // LangChain4j currently supports more than 10 popular embedding model providers.
        EmbeddingModel embeddingModel = new BgeSmallEnV15QuantizedEmbeddingModel();
        List<Embedding> embeddings = embeddingModel.embedAll(segments).content();

        // Next, we will store these embeddings in an embedding store (also known as a "vector database").
        // This store will be used to search for relevant segments during each interaction with the LLM.
        // For simplicity, this example uses an in-memory embedding store, but you can choose from any supported store.
        // LangChain4j currently supports more than 15 popular embedding stores.
        EmbeddingStore<TextSegment> embeddingStore = new InMemoryEmbeddingStore<>();
        embeddingStore.addAll(embeddings, segments);

        // We could also use EmbeddingStoreIngestor to hide the manual steps above behind a simpler API.
        // See an example of using EmbeddingStoreIngestor in _01_Advanced_RAG_with_Query_Compression_Example.

        // The content retriever is responsible for retrieving relevant content based on a user query.
        // Currently, it is capable of retrieving text segments, but future enhancements will include support for
        // additional modalities like images, audio, and more.
        ContentRetriever contentRetriever = EmbeddingStoreContentRetriever.builder()
                .embeddingStore(embeddingStore)
                .embeddingModel(embeddingModel)
                .maxResults(2) // on each interaction we will retrieve the 2 most relevant segments
                .minScore(0.5) // we want to retrieve segments at least somewhat similar to the user query
                .build();

        // Optionally, we can use a chat memory, enabling back-and-forth conversation with the LLM
        // and allowing it to remember previous interactions.
        // Currently, LangChain4j offers two chat memory implementations:
        // MessageWindowChatMemory and TokenWindowChatMemory.
        ChatMemory chatMemory = MessageWindowChatMemory.withMaxMessages(10);

        // The final step is to build our AI Service,
        // configuring it to use the components we've created above.
        return AiServices.builder(Assistant.class)
                .chatLanguageModel(chatLanguageModel)
                .contentRetriever(contentRetriever)
                .chatMemory(chatMemory)
                .build();
    }

    public static void startConversationWith(Assistant assistant) {
        Logger log = LoggerFactory.getLogger(Assistant.class);
        try (Scanner scanner = new Scanner(System.in)) {
            while (true) {
                log.info("==================================================");
                log.info("User: ");
                String userQuery = scanner.nextLine();
                log.info("==================================================");
                if ("exit".equalsIgnoreCase(userQuery)) {
                    break;
                }
                String agentAnswer = assistant.answer(userQuery);
                log.info("==================================================");
                log.info("Assistant: " + agentAnswer);
            }
        }
    }
}
By "naive" we mean that no advanced RAG techniques are used. Easy RAG relies on EmbeddingStoreIngestor to hide document parsing, splitting, embedding, and embedding storage behind one API; Naive RAG can use it as well. Naive RAG customizes the query side through EmbeddingStoreContentRetriever.builder(), whereas Easy RAG simply calls EmbeddingStoreContentRetriever.from(embeddingStore). The flow behind startConversationWith is:
- Take the user's query as-is
- Embed the query using an embedding model
- Use the query's embedding to search the embedding store (containing small segments of the documents) for the X most relevant segments
- Append the found segments to the user's query
- Send the combined input (user query + segments) to the LLM
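The steps above (minus the final LLM call) can be sketched end to end with stand-in types. The hashing-free "embedding" below is just a bag-of-words term-frequency vector and the store is a plain list; these are toy illustrations, not langchain4j APIs:

```java
import java.util.*;
import java.util.stream.*;

// Toy sketch of the naive RAG flow: embed the query, rank segments, augment the prompt.
public class NaiveRagSketch {

    // Stand-in "embedding": a bag-of-words term-frequency vector (not a real model).
    static Map<String, Integer> embed(String text) {
        Map<String, Integer> vector = new HashMap<>();
        for (String token : text.toLowerCase().split("\\W+")) {
            vector.merge(token, 1, Integer::sum);
        }
        return vector;
    }

    // Cosine similarity between two sparse vectors.
    static double cosine(Map<String, Integer> a, Map<String, Integer> b) {
        double dot = 0;
        for (Map.Entry<String, Integer> e : a.entrySet()) {
            dot += e.getValue() * b.getOrDefault(e.getKey(), 0);
        }
        double normA = Math.sqrt(a.values().stream().mapToDouble(v -> v * v).sum());
        double normB = Math.sqrt(b.values().stream().mapToDouble(v -> v * v).sum());
        return (normA == 0 || normB == 0) ? 0 : dot / (normA * normB);
    }

    // Steps 2-3: embed the query, rank stored segments by similarity, take the top X.
    static List<String> retrieve(String query, List<String> segments, int maxResults) {
        Map<String, Integer> queryVector = embed(query);
        return segments.stream()
                .sorted(Comparator.comparingDouble(
                        (String s) -> cosine(queryVector, embed(s))).reversed())
                .limit(maxResults)
                .collect(Collectors.toList());
    }

    // Steps 4-5: combine retrieved segments with the user query before calling the LLM.
    static String augment(String query, List<String> retrieved) {
        return String.join("\n", retrieved) + "\n\nUser query: " + query;
    }

    public static void main(String[] args) {
        List<String> segments = List.of(
                "Cancellation: reservations can be cancelled up to 24 hours in advance.",
                "Payment: we accept all major credit cards.");
        List<String> top = retrieve("Can I cancel my reservation?", segments, 1);
        System.out.println(augment("Can I cancel my reservation?", top));
    }
}
```

A real embedding model produces dense vectors with semantic similarity; the mechanics of "embed, search, append, send" are the same.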
ContentRetriever
dev/langchain4j/rag/content/retriever/ContentRetriever.java
public interface ContentRetriever {

    /**
     * Retrieves relevant {@link Content}s using a given {@link Query}.
     * The {@link Content}s are sorted by relevance, with the most relevant {@link Content}s appearing
     * at the beginning of the returned {@code List<Content>}.
     *
     * @param query The {@link Query} to use for retrieval.
     * @return A list of retrieved {@link Content}s.
     */
    List<Content> retrieve(Query query);
}
ContentRetriever defines a single retrieve method, which returns a List<Content> from a data source for a given Query. Typical data sources include an embedding (vector) store (EmbeddingStoreContentRetriever), a full-text search source (AzureAiSearchContentRetriever), a hybrid search source (also AzureAiSearchContentRetriever), a web search engine (WebSearchContentRetriever), a knowledge graph (Neo4jContentRetriever), and a relational database (SqlDatabaseContentRetriever).
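The interface is small enough that any data source can back it. As a toy illustration of the contract (using minimal stand-in Query/Content records rather than the langchain4j classes), here is a keyword-based retriever in the spirit of a full-text source:

```java
import java.util.*;
import java.util.stream.*;

// Stand-ins for langchain4j's Query and Content, just for this sketch.
record Query(String text) {}
record Content(String textSegment) {}

interface ContentRetriever {
    // Returns contents sorted by relevance, most relevant first.
    List<Content> retrieve(Query query);
}

// A toy full-text-style retriever: ranks documents by how many query words they contain.
class KeywordContentRetriever implements ContentRetriever {

    private final List<String> documents;

    KeywordContentRetriever(List<String> documents) {
        this.documents = documents;
    }

    @Override
    public List<Content> retrieve(Query query) {
        Set<String> queryWords =
                new HashSet<>(Arrays.asList(query.text().toLowerCase().split("\\W+")));
        return documents.stream()
                .map(doc -> Map.entry(doc, score(doc, queryWords)))
                .filter(e -> e.getValue() > 0) // drop documents with no overlap
                .sorted(Map.Entry.<String, Long>comparingByValue().reversed())
                .map(e -> new Content(e.getKey()))
                .collect(Collectors.toList());
    }

    private static long score(String doc, Set<String> queryWords) {
        return Arrays.stream(doc.toLowerCase().split("\\W+"))
                .filter(queryWords::contains)
                .count();
    }
}
```

Because every implementation honors the same "most relevant first" ordering, downstream components such as AI Services can swap retrievers without other changes.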
EmbeddingStoreContentRetriever
dev/langchain4j/rag/content/retriever/EmbeddingStoreContentRetriever.java
public class EmbeddingStoreContentRetriever implements ContentRetriever {

    public static final Function<Query, Integer> DEFAULT_MAX_RESULTS = (query) -> 3;
    public static final Function<Query, Double> DEFAULT_MIN_SCORE = (query) -> 0.0;
    public static final Function<Query, Filter> DEFAULT_FILTER = (query) -> null;

    public static final String DEFAULT_DISPLAY_NAME = "Default";

    private final EmbeddingStore<TextSegment> embeddingStore;
    private final EmbeddingModel embeddingModel;
    private final Function<Query, Integer> maxResultsProvider;
    private final Function<Query, Double> minScoreProvider;
    private final Function<Query, Filter> filterProvider;
    private final String displayName;

    public EmbeddingStoreContentRetriever(EmbeddingStore<TextSegment> embeddingStore,
                                          EmbeddingModel embeddingModel) {
        this(
                DEFAULT_DISPLAY_NAME,
                embeddingStore,
                embeddingModel,
                DEFAULT_MAX_RESULTS,
                DEFAULT_MIN_SCORE,
                DEFAULT_FILTER
        );
    }

    public EmbeddingStoreContentRetriever(EmbeddingStore<TextSegment> embeddingStore,
                                          EmbeddingModel embeddingModel,
                                          int maxResults) {
        this(
                DEFAULT_DISPLAY_NAME,
                embeddingStore,
                embeddingModel,
                (query) -> maxResults,
                DEFAULT_MIN_SCORE,
                DEFAULT_FILTER
        );
    }

    public EmbeddingStoreContentRetriever(EmbeddingStore<TextSegment> embeddingStore,
                                          EmbeddingModel embeddingModel,
                                          Integer maxResults,
                                          Double minScore) {
        this(
                DEFAULT_DISPLAY_NAME,
                embeddingStore,
                embeddingModel,
                (query) -> maxResults,
                (query) -> minScore,
                DEFAULT_FILTER
        );
    }

    private EmbeddingStoreContentRetriever(String displayName,
                                           EmbeddingStore<TextSegment> embeddingStore,
                                           EmbeddingModel embeddingModel,
                                           Function<Query, Integer> dynamicMaxResults,
                                           Function<Query, Double> dynamicMinScore,
                                           Function<Query, Filter> dynamicFilter) {
        this.displayName = getOrDefault(displayName, DEFAULT_DISPLAY_NAME);
        this.embeddingStore = ensureNotNull(embeddingStore, "embeddingStore");
        this.embeddingModel = ensureNotNull(
                getOrDefault(embeddingModel, EmbeddingStoreContentRetriever::loadEmbeddingModel),
                "embeddingModel"
        );
        this.maxResultsProvider = getOrDefault(dynamicMaxResults, DEFAULT_MAX_RESULTS);
        this.minScoreProvider = getOrDefault(dynamicMinScore, DEFAULT_MIN_SCORE);
        this.filterProvider = getOrDefault(dynamicFilter, DEFAULT_FILTER);
    }

    //......

    @Override
    public List<Content> retrieve(Query query) {

        Embedding embeddedQuery = embeddingModel.embed(query.text()).content();

        EmbeddingSearchRequest searchRequest = EmbeddingSearchRequest.builder()
                .queryEmbedding(embeddedQuery)
                .maxResults(maxResultsProvider.apply(query))
                .minScore(minScoreProvider.apply(query))
                .filter(filterProvider.apply(query))
                .build();

        EmbeddingSearchResult<TextSegment> searchResult = embeddingStore.search(searchRequest);

        return searchResult.matches().stream()
                .map(embeddingMatch -> Content.from(
                        embeddingMatch.embedded(),
                        Map.of(
                                ContentMetadata.SCORE, embeddingMatch.score(),
                                ContentMetadata.EMBEDDING_ID, embeddingMatch.embeddingId()
                        )
                ))
                .collect(Collectors.toList());
    }

    @Override
    public String toString() {
        return "EmbeddingStoreContentRetriever{" +
                "displayName='" + displayName + '\'' +
                '}';
    }
}
EmbeddingStoreContentRetriever implements the ContentRetriever interface. It defines the properties embeddingStore, embeddingModel, maxResultsProvider, minScoreProvider, filterProvider, and displayName, with DEFAULT_MAX_RESULTS of 3, DEFAULT_MIN_SCORE of 0.0, and DEFAULT_FILTER of null. Its retrieve method first calls embeddingModel.embed(query.text()).content() to embed the query, then builds an EmbeddingSearchRequest (setting queryEmbedding, maxResults, minScore, and filter), then issues the search via embeddingStore.search(searchRequest), and finally maps the EmbeddingSearchResult into a List<Content>.
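The store-side search that retrieve delegates to boils down to "score everything, cut below minScore, keep the top maxResults". A self-contained toy version (plain double[] vectors and an in-memory map, not a real embedding store) makes the cutoff and limit explicit:

```java
import java.util.*;
import java.util.stream.*;

// Toy sketch of the similarity search behind EmbeddingStoreContentRetriever.retrieve.
class ToyEmbeddingSearch {

    record Match(String segment, double score) {}

    static double cosine(double[] a, double[] b) {
        double dot = 0, na = 0, nb = 0;
        for (int i = 0; i < a.length; i++) {
            dot += a[i] * b[i];
            na += a[i] * a[i];
            nb += b[i] * b[i];
        }
        return dot / (Math.sqrt(na) * Math.sqrt(nb));
    }

    // Score every stored vector against the query vector, drop matches below
    // minScore, sort best-first, and keep at most maxResults.
    static List<Match> search(double[] query, Map<String, double[]> store,
                              int maxResults, double minScore) {
        return store.entrySet().stream()
                .map(e -> new Match(e.getKey(), cosine(query, e.getValue())))
                .filter(m -> m.score() >= minScore)                     // minScore cutoff
                .sorted(Comparator.comparingDouble(Match::score).reversed())
                .limit(maxResults)                                      // top-N only
                .collect(Collectors.toList());
    }
}
```

This mirrors how the maxResults(2)/minScore(0.5) settings in the example above shape each retrieval: at most two segments, and only those at least somewhat similar to the query.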
It also provides several static methods:
private static EmbeddingModel loadEmbeddingModel() {
    Collection<EmbeddingModelFactory> factories = loadFactories(EmbeddingModelFactory.class);
    if (factories.size() > 1) {
        throw new RuntimeException("Conflict: multiple embedding models have been found in the classpath. " +
                "Please explicitly specify the one you wish to use.");
    }

    for (EmbeddingModelFactory factory : factories) {
        return factory.create();
    }

    return null;
}

public static EmbeddingStoreContentRetrieverBuilder builder() {
    return new EmbeddingStoreContentRetrieverBuilder();
}

public static EmbeddingStoreContentRetriever from(EmbeddingStore<TextSegment> embeddingStore) {
    return builder().embeddingStore(embeddingStore).build();
}
EmbeddingStoreContentRetrieverBuilder
public static class EmbeddingStoreContentRetrieverBuilder {

    private String displayName;
    private EmbeddingStore<TextSegment> embeddingStore;
    private EmbeddingModel embeddingModel;
    private Function<Query, Integer> dynamicMaxResults;
    private Function<Query, Double> dynamicMinScore;
    private Function<Query, Filter> dynamicFilter;

    EmbeddingStoreContentRetrieverBuilder() {
    }

    public EmbeddingStoreContentRetrieverBuilder maxResults(Integer maxResults) {
        if (maxResults != null) {
            dynamicMaxResults = (query) -> ensureGreaterThanZero(maxResults, "maxResults");
        }
        return this;
    }

    public EmbeddingStoreContentRetrieverBuilder minScore(Double minScore) {
        if (minScore != null) {
            dynamicMinScore = (query) -> ensureBetween(minScore, 0, 1, "minScore");
        }
        return this;
    }

    public EmbeddingStoreContentRetrieverBuilder filter(Filter filter) {
        if (filter != null) {
            dynamicFilter = (query) -> filter;
        }
        return this;
    }

    public EmbeddingStoreContentRetrieverBuilder displayName(String displayName) {
        this.displayName = displayName;
        return this;
    }

    public EmbeddingStoreContentRetrieverBuilder embeddingStore(EmbeddingStore<TextSegment> embeddingStore) {
        this.embeddingStore = embeddingStore;
        return this;
    }

    public EmbeddingStoreContentRetrieverBuilder embeddingModel(EmbeddingModel embeddingModel) {
        this.embeddingModel = embeddingModel;
        return this;
    }

    public EmbeddingStoreContentRetrieverBuilder dynamicMaxResults(Function<Query, Integer> dynamicMaxResults) {
        this.dynamicMaxResults = dynamicMaxResults;
        return this;
    }

    public EmbeddingStoreContentRetrieverBuilder dynamicMinScore(Function<Query, Double> dynamicMinScore) {
        this.dynamicMinScore = dynamicMinScore;
        return this;
    }

    public EmbeddingStoreContentRetrieverBuilder dynamicFilter(Function<Query, Filter> dynamicFilter) {
        this.dynamicFilter = dynamicFilter;
        return this;
    }

    public EmbeddingStoreContentRetriever build() {
        return new EmbeddingStoreContentRetriever(this.displayName, this.embeddingStore, this.embeddingModel, this.dynamicMaxResults, this.dynamicMinScore, this.dynamicFilter);
    }

    public String toString() {
        return "EmbeddingStoreContentRetriever.EmbeddingStoreContentRetrieverBuilder(displayName=" + this.displayName + ", embeddingStore=" + this.embeddingStore + ", embeddingModel=" + this.embeddingModel + ", dynamicMaxResults=" + this.dynamicMaxResults + ", dynamicMinScore=" + this.dynamicMinScore + ", dynamicFilter=" + this.dynamicFilter + ")";
    }
}
EmbeddingStoreContentRetrieverBuilder provides setters for the embeddingStore, embeddingModel, maxResults, minScore, filter, and displayName properties. Note that maxResults, minScore, and filter validate their arguments and wrap them into constant Function<Query, ?> providers, while the dynamic* variants accept per-query functions directly.
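The fixed-versus-dynamic distinction is just the difference between a constant function and one that inspects the query. A tiny illustration of that provider pattern, with a stand-in Query record and a hypothetical length-based heuristic (neither is langchain4j API):

```java
import java.util.function.Function;

// Stand-in for langchain4j's Query, just to show the provider pattern.
record Query(String text) {}

class ProviderDemo {

    // A fixed maxResults is simply a constant function over the query...
    static final Function<Query, Integer> FIXED = query -> 2;

    // ...while a dynamic provider can inspect the query itself, e.g. (hypothetically)
    // fetching more segments for longer, more detailed questions.
    static final Function<Query, Integer> DYNAMIC =
            query -> query.text().length() > 50 ? 5 : 2;
}
```

Inside retrieve, both cases are handled uniformly by calling maxResultsProvider.apply(query).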
Summary
langchain4j's EmbeddingStoreContentRetriever enables Naive RAG. EmbeddingStoreContentRetriever.builder() customizes query-related behavior, letting you set embeddingStore, embeddingModel, maxResults, and minScore. At query time, EmbeddingStoreContentRetriever first calls embeddingModel.embed(query.text()).content() to embed the query, then builds an EmbeddingSearchRequest (setting queryEmbedding, maxResults, minScore, and filter), issues the search via embeddingStore.search(searchRequest), and finally maps the EmbeddingSearchResult into a List<Content>.