Preface
This article takes a look at Spring AI's EmbeddingModel.
EmbeddingModel
spring-ai-core/src/main/java/org/springframework/ai/embedding/EmbeddingModel.java
public interface EmbeddingModel extends Model<EmbeddingRequest, EmbeddingResponse> {
@Override
EmbeddingResponse call(EmbeddingRequest request);
/**
* Embeds the given text into a vector.
* @param text the text to embed.
* @return the embedded vector.
*/
default float[] embed(String text) {
Assert.notNull(text, "Text must not be null");
List<float[]> response = this.embed(List.of(text));
return response.iterator().next();
}
/**
* Embeds the given document's content into a vector.
* @param document the document to embed.
* @return the embedded vector.
*/
float[] embed(Document document);
/**
* Embeds a batch of texts into vectors.
* @param texts list of texts to embed.
* @return list of embedded vectors.
*/
default List<float[]> embed(List<String> texts) {
Assert.notNull(texts, "Texts must not be null");
return this.call(new EmbeddingRequest(texts, EmbeddingOptionsBuilder.builder().build()))
.getResults()
.stream()
.map(Embedding::getOutput)
.toList();
}
/**
* Embeds a batch of {@link Document}s into vectors based on a
* {@link BatchingStrategy}.
* @param documents list of {@link Document}s.
* @param options {@link EmbeddingOptions}.
* @param batchingStrategy {@link BatchingStrategy}.
* @return a list of float[] that represents the vectors for the incoming
* {@link Document}s. The returned list is expected to be in the same order of the
* {@link Document} list.
*/
default List<float[]> embed(List<Document> documents, EmbeddingOptions options, BatchingStrategy batchingStrategy) {
Assert.notNull(documents, "Documents must not be null");
List<float[]> embeddings = new ArrayList<>(documents.size());
List<List<Document>> batch = batchingStrategy.batch(documents);
for (List<Document> subBatch : batch) {
List<String> texts = subBatch.stream().map(Document::getText).toList();
EmbeddingRequest request = new EmbeddingRequest(texts, options);
EmbeddingResponse response = this.call(request);
for (int i = 0; i < subBatch.size(); i++) {
embeddings.add(response.getResults().get(i).getOutput());
}
}
Assert.isTrue(embeddings.size() == documents.size(),
"Embeddings must have the same number as that of the documents");
return embeddings;
}
/**
* Embeds a batch of texts into vectors and returns the {@link EmbeddingResponse}.
* @param texts list of texts to embed.
* @return the embedding response.
*/
default EmbeddingResponse embedForResponse(List<String> texts) {
Assert.notNull(texts, "Texts must not be null");
return this.call(new EmbeddingRequest(texts, EmbeddingOptionsBuilder.builder().build()));
}
/**
* Get the number of dimensions of the embedded vectors. Note that by default, this
* method will call the remote Embedding endpoint to get the dimensions of the
* embedded vectors. If the dimensions are known ahead of time, it is recommended to
* override this method.
* @return the number of dimensions of the embedded vectors.
*/
default int dimensions() {
return embed("Test String").length;
}
}
EmbeddingModel extends the Model interface, with EmbeddingRequest as its request type and EmbeddingResponse as its response type. It declares call(EmbeddingRequest) and embed(Document) as abstract methods and provides default implementations of embed(String), embed(List&lt;String&gt;), embed(List&lt;Document&gt;, EmbeddingOptions, BatchingStrategy), embedForResponse, and dimensions.
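To make the contract concrete, here is a minimal usage sketch; the EmbeddingModelUsage class and the input strings are made up for illustration, and embeddingModel can be any EmbeddingModel implementation.
import java.util.List;
import org.springframework.ai.embedding.EmbeddingModel;
import org.springframework.ai.embedding.EmbeddingResponse;

public class EmbeddingModelUsage {

	void demo(EmbeddingModel embeddingModel) {
		// embed a single text into a float[] vector
		float[] single = embeddingModel.embed("Hello World");
		// embed a batch of texts; one vector per input, in input order
		List<float[]> batch = embeddingModel.embed(List.of("Hello World", "Spring AI"));
		// get the full response, including metadata
		EmbeddingResponse response = embeddingModel.embedForResponse(List.of("Hello World"));
		// vector dimensionality (by default this triggers a dummy remote embed call)
		int dims = embeddingModel.dimensions();
	}

}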
EmbeddingRequest
spring-ai-core/src/main/java/org/springframework/ai/embedding/EmbeddingRequest.java
public class EmbeddingRequest implements ModelRequest<List<String>> {
private final List<String> inputs;
private final EmbeddingOptions options;
public EmbeddingRequest(List<String> inputs, EmbeddingOptions options) {
this.inputs = inputs;
this.options = options;
}
@Override
public List<String> getInstructions() {
return this.inputs;
}
@Override
public EmbeddingOptions getOptions() {
return this.options;
}
}
EmbeddingRequest implements the ModelRequest interface; its getInstructions method returns a List&lt;String&gt; of the texts to embed.
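As a quick sketch (the input texts are placeholders), a request can be constructed directly with the same empty default options that the interface's default methods use:
// construct a request with empty default options, as EmbeddingModel#embed(List) does internally
EmbeddingRequest request = new EmbeddingRequest(
		List.of("Hello World", "Spring AI"),
		EmbeddingOptionsBuilder.builder().build());
List<String> inputs = request.getInstructions(); // the texts to embed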
EmbeddingResponse
spring-ai-core/src/main/java/org/springframework/ai/embedding/EmbeddingResponse.java
public class EmbeddingResponse implements ModelResponse<Embedding> {
/**
* Embedding data.
*/
private final List<Embedding> embeddings;
/**
* Embedding metadata.
*/
private final EmbeddingResponseMetadata metadata;
/**
* Creates a new {@link EmbeddingResponse} instance with empty metadata.
* @param embeddings the embedding data.
*/
public EmbeddingResponse(List<Embedding> embeddings) {
this(embeddings, new EmbeddingResponseMetadata());
}
/**
* Creates a new {@link EmbeddingResponse} instance.
* @param embeddings the embedding data.
* @param metadata the embedding metadata.
*/
public EmbeddingResponse(List<Embedding> embeddings, EmbeddingResponseMetadata metadata) {
this.embeddings = embeddings;
this.metadata = metadata;
}
/**
* @return Get the embedding metadata.
*/
public EmbeddingResponseMetadata getMetadata() {
return this.metadata;
}
@Override
public Embedding getResult() {
Assert.notEmpty(this.embeddings, "No embedding data available.");
return this.embeddings.get(0);
}
/**
* @return Get the embedding data.
*/
@Override
public List<Embedding> getResults() {
return this.embeddings;
}
//......
}
EmbeddingResponse implements the ModelResponse interface; its result type is Embedding.
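For illustration, reading a response returned by call or embedForResponse looks like the following sketch; the response variable is assumed to come from one of those calls.
Embedding first = response.getResult();                      // first embedding, asserts data is present
float[] vector = first.getOutput();                          // the raw vector values
List<Embedding> all = response.getResults();                 // all embeddings, in input order
EmbeddingResponseMetadata metadata = response.getMetadata(); // response-level metadata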
AbstractEmbeddingModel
spring-ai-core/src/main/java/org/springframework/ai/embedding/AbstractEmbeddingModel.java
public abstract class AbstractEmbeddingModel implements EmbeddingModel {
private static final Map<String, Integer> KNOWN_EMBEDDING_DIMENSIONS = loadKnownModelDimensions();
/**
* Default constructor.
*/
public AbstractEmbeddingModel() {
}
/**
* Cached embedding dimensions.
*/
protected final AtomicInteger embeddingDimensions = new AtomicInteger(-1);
/**
* Return the dimension of the requested embedding generative name. If the generative
* name is unknown uses the EmbeddingModel to perform a dummy EmbeddingModel#embed and
* count the response dimensions.
* @param embeddingModel Fall-back client to determine, empirically the dimensions.
* @param modelName Embedding generative name to retrieve the dimensions for.
* @param dummyContent Dummy content to use for the empirical dimension calculation.
* @return Returns the embedding dimensions for the modelName.
*/
public static int dimensions(EmbeddingModel embeddingModel, String modelName, String dummyContent) {
if (KNOWN_EMBEDDING_DIMENSIONS.containsKey(modelName)) {
// Retrieve the dimension from a pre-configured file.
return KNOWN_EMBEDDING_DIMENSIONS.get(modelName);
}
else {
// Determine the dimensions empirically.
// Generate an embedding and count the dimension size;
return embeddingModel.embed(dummyContent).length;
}
}
private static Map<String, Integer> loadKnownModelDimensions() {
try {
Properties properties = new Properties();
properties.load(new DefaultResourceLoader()
.getResource("classpath:/embedding/embedding-model-dimensions.properties")
.getInputStream());
return properties.entrySet()
.stream()
.collect(Collectors.toMap(e -> e.getKey().toString(), e -> Integer.parseInt(e.getValue().toString())));
}
catch (IOException e) {
throw new RuntimeException(e);
}
}
@Override
public int dimensions() {
if (this.embeddingDimensions.get() < 0) {
this.embeddingDimensions.set(dimensions(this, "Test", "Hello World"));
}
return this.embeddingDimensions.get();
}
}
AbstractEmbeddingModel implements the dimensions method defined by the EmbeddingModel interface. Concrete subclasses live in the individual modules, for example OpenAiEmbeddingModel in spring-ai-openai, OllamaEmbeddingModel in spring-ai-ollama, and MiniMaxEmbeddingModel in spring-ai-minimax.
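As a minimal sketch (the class name and the fixed dimension are made up), a subclass can override dimensions() to skip the dummy remote call when the vector size is known ahead of time:
import org.springframework.ai.document.Document;
import org.springframework.ai.embedding.AbstractEmbeddingModel;
import org.springframework.ai.embedding.EmbeddingRequest;
import org.springframework.ai.embedding.EmbeddingResponse;

// hypothetical subclass that hard-codes its vector size instead of probing the endpoint
public class MyEmbeddingModel extends AbstractEmbeddingModel {

	@Override
	public EmbeddingResponse call(EmbeddingRequest request) {
		// a real implementation would call the provider's embedding endpoint here
		throw new UnsupportedOperationException("omitted in this sketch");
	}

	@Override
	public float[] embed(Document document) {
		return embed(document.getText());
	}

	@Override
	public int dimensions() {
		return 1024; // known up front, so no dummy embed() round trip is needed
	}

}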
OllamaEmbeddingAutoConfiguration
org/springframework/ai/model/ollama/autoconfigure/OllamaEmbeddingAutoConfiguration.java
@AutoConfiguration(after = RestClientAutoConfiguration.class)
@ConditionalOnClass(OllamaEmbeddingModel.class)
@ConditionalOnProperty(name = SpringAIModelProperties.EMBEDDING_MODEL, havingValue = SpringAIModels.OLLAMA,
matchIfMissing = true)
@EnableConfigurationProperties({ OllamaEmbeddingProperties.class, OllamaInitializationProperties.class })
@ImportAutoConfiguration(classes = { OllamaApiAutoConfiguration.class, RestClientAutoConfiguration.class,
WebClientAutoConfiguration.class })
public class OllamaEmbeddingAutoConfiguration {
@Bean
@ConditionalOnMissingBean
public OllamaEmbeddingModel ollamaEmbeddingModel(OllamaApi ollamaApi, OllamaEmbeddingProperties properties,
OllamaInitializationProperties initProperties, ObjectProvider<ObservationRegistry> observationRegistry,
ObjectProvider<EmbeddingModelObservationConvention> observationConvention) {
var embeddingModelPullStrategy = initProperties.getEmbedding().isInclude()
? initProperties.getPullModelStrategy() : PullModelStrategy.NEVER;
var embeddingModel = OllamaEmbeddingModel.builder()
.ollamaApi(ollamaApi)
.defaultOptions(properties.getOptions())
.observationRegistry(observationRegistry.getIfUnique(() -> ObservationRegistry.NOOP))
.modelManagementOptions(new ModelManagementOptions(embeddingModelPullStrategy,
initProperties.getEmbedding().getAdditionalModels(), initProperties.getTimeout(),
initProperties.getMaxRetries()))
.build();
observationConvention.ifAvailable(embeddingModel::setObservationConvention);
return embeddingModel;
}
}
OllamaEmbeddingAutoConfiguration is activated when spring.ai.model.embedding is set to ollama (or not set at all, since matchIfMissing = true), and it auto-configures an OllamaEmbeddingModel bean.
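Because the bean is guarded by @ConditionalOnMissingBean, defining your own OllamaEmbeddingModel makes the auto-configured one back off. A rough sketch reusing the builder calls shown above; the class name and option values are assumptions.
import org.springframework.ai.ollama.OllamaEmbeddingModel;
import org.springframework.ai.ollama.api.OllamaApi;
import org.springframework.ai.ollama.api.OllamaOptions;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;

@Configuration
public class MyOllamaEmbeddingConfig {

	// user-defined bean; OllamaEmbeddingAutoConfiguration will not register its own
	@Bean
	public OllamaEmbeddingModel ollamaEmbeddingModel(OllamaApi ollamaApi) {
		return OllamaEmbeddingModel.builder()
			.ollamaApi(ollamaApi)
			.defaultOptions(OllamaOptions.builder().model("bge-m3:latest").build())
			.build();
	}

}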
OllamaEmbeddingProperties
org/springframework/ai/model/ollama/autoconfigure/OllamaEmbeddingProperties.java
@ConfigurationProperties(OllamaEmbeddingProperties.CONFIG_PREFIX)
public class OllamaEmbeddingProperties {
public static final String CONFIG_PREFIX = "spring.ai.ollama.embedding";
/**
* Client lever Ollama options. Use this property to configure generative temperature,
* topK and topP and alike parameters. The null values are ignored defaulting to the
* generative's defaults.
*/
@NestedConfigurationProperty
private OllamaOptions options = OllamaOptions.builder().model(OllamaModel.MXBAI_EMBED_LARGE.id()).build();
public String getModel() {
return this.options.getModel();
}
public void setModel(String model) {
this.options.setModel(model);
}
public OllamaOptions getOptions() {
return this.options;
}
}
OllamaEmbeddingProperties mainly exposes the OllamaOptions configuration; for the underlying options see https://github.com/ggerganov/llama.cpp/blob/master/examples/main/README.md
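The configured options serve as defaults and can also be overridden per request, similar to the test in the example section below; the model name and truncate value here are assumptions.
// per-request options override the defaults configured via spring.ai.ollama.embedding.options
EmbeddingRequest request = new EmbeddingRequest(
		List.of("Hello World"),
		OllamaOptions.builder()
			.model("bge-m3:latest")
			.truncate(true)
			.build());
EmbeddingResponse response = embeddingModel.call(request);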
OllamaInitializationProperties
org/springframework/ai/model/ollama/autoconfigure/OllamaInitializationProperties.java
@ConfigurationProperties(OllamaInitializationProperties.CONFIG_PREFIX)
public class OllamaInitializationProperties {
public static final String CONFIG_PREFIX = "spring.ai.ollama.init";
/**
* Chat models initialization settings.
*/
private final ModelTypeInit chat = new ModelTypeInit();
/**
* Embedding models initialization settings.
*/
private final ModelTypeInit embedding = new ModelTypeInit();
/**
* Whether to pull models at startup-time and how.
*/
private PullModelStrategy pullModelStrategy = PullModelStrategy.NEVER;
/**
* How long to wait for a model to be pulled.
*/
private Duration timeout = Duration.ofMinutes(5);
/**
* Maximum number of retries for the model pull operation.
*/
private int maxRetries = 0;
public PullModelStrategy getPullModelStrategy() {
return this.pullModelStrategy;
}
public void setPullModelStrategy(PullModelStrategy pullModelStrategy) {
this.pullModelStrategy = pullModelStrategy;
}
public ModelTypeInit getChat() {
return this.chat;
}
public ModelTypeInit getEmbedding() {
return this.embedding;
}
public Duration getTimeout() {
return this.timeout;
}
public void setTimeout(Duration timeout) {
this.timeout = timeout;
}
public int getMaxRetries() {
return this.maxRetries;
}
public void setMaxRetries(int maxRetries) {
this.maxRetries = maxRetries;
}
public static class ModelTypeInit {
/**
* Include this type of models in the initialization task.
*/
private boolean include = true;
/**
* Additional models to initialize besides the ones configured via default
* properties.
*/
private List<String> additionalModels = List.of();
public boolean isInclude() {
return this.include;
}
public void setInclude(boolean include) {
this.include = include;
}
public List<String> getAdditionalModels() {
return this.additionalModels;
}
public void setAdditionalModels(List<String> additionalModels) {
this.additionalModels = additionalModels;
}
}
}
OllamaInitializationProperties provides the spring.ai.ollama.init configuration for Ollama initialization; its nested ModelTypeInit class controls whether a model type is included in the initialization task and which additional models to initialize.
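For reference, the auto-configuration shown earlier folds these properties into a ModelManagementOptions roughly as in this sketch; the values are the defaults declared in the properties class.
// mirrors what OllamaEmbeddingAutoConfiguration builds from OllamaInitializationProperties
ModelManagementOptions managementOptions = new ModelManagementOptions(
		PullModelStrategy.NEVER,  // spring.ai.ollama.init.pull-model-strategy
		List.of(),                // spring.ai.ollama.init.embedding.additional-models
		Duration.ofMinutes(5),    // spring.ai.ollama.init.timeout
		0);                       // spring.ai.ollama.init.max-retries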
Example
pom.xml
<dependency>
<groupId>org.springframework.ai</groupId>
<artifactId>spring-ai-starter-model-ollama</artifactId>
</dependency>
Configuration
spring:
  ai:
    model:
      embedding: ollama
    ollama:
      init:
        timeout: 5m
        max-retries: 0
        embedding:
          include: true
          additional-models: []
      base-url: http://localhost:11434
      embedding:
        enabled: true
        options:
          model: bge-m3:latest
          truncate: true
example
@Test
public void testCall() {
EmbeddingRequest request = new EmbeddingRequest(List.of("Hello World", "World is big and salvation is near"),
OllamaOptions.builder()
.model("bge-m3:latest")
.truncate(false)
.build());
EmbeddingResponse embeddingResponse = embeddingModel.call(request);
log.info("resp:{}", JSON.toJSONString(embeddingResponse));
}
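Here embeddingModel is the auto-configured OllamaEmbeddingModel injected into the test class. A companion sketch using the convenience methods instead of call (the test name is made up):
@Test
public void testEmbed() {
	// single-text embedding via the default method, using the configured default options
	float[] vector = embeddingModel.embed("Hello World");
	log.info("dimensions:{}, first value:{}", embeddingModel.dimensions(), vector[0]);
}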
Summary
Spring AI defines the EmbeddingModel interface, which extends the Model interface with EmbeddingRequest as its request type and EmbeddingResponse as its response type; it declares the call and embed(Document) methods and provides default implementations of embed, embedForResponse, and dimensions. AbstractEmbeddingModel implements the dimensions method defined by the EmbeddingModel interface and has concrete subclasses in the individual modules, for example OpenAiEmbeddingModel in spring-ai-openai, OllamaEmbeddingModel in spring-ai-ollama, and MiniMaxEmbeddingModel in spring-ai-minimax. OllamaEmbeddingAutoConfiguration is activated when spring.ai.model.embedding is set to ollama and auto-configures an OllamaEmbeddingModel.