Preface
This article takes a look at Spring AI's Chat Model.
Model
spring-ai-core/src/main/java/org/springframework/ai/model/Model.java
public interface Model<TReq extends ModelRequest<?>, TRes extends ModelResponse<?>> {

    /**
     * Executes a method call to the AI model.
     * @param request the request object to be sent to the AI model
     * @return the response from the AI model
     */
    TRes call(TReq request);

}
The Model interface defines a call method whose parameter is of type ModelRequest and whose return value is of type ModelResponse.
ModelRequest
spring-ai-core/src/main/java/org/springframework/ai/model/ModelRequest.java
public interface ModelRequest<T> {

    /**
     * Retrieves the instructions or input required by the AI model.
     * @return the instructions or input required by the AI model
     */
    T getInstructions(); // required input

    /**
     * Retrieves the customizable options for AI model interactions.
     * @return the customizable options for AI model interactions
     */
    ModelOptions getOptions();

}
ModelRequest defines the getInstructions and getOptions methods.
ModelResponse
spring-ai-core/src/main/java/org/springframework/ai/model/ModelResponse.java
public interface ModelResponse<T extends ModelResult<?>> {

    /**
     * Retrieves the result of the AI model.
     * @return the result generated by the AI model
     */
    T getResult();

    /**
     * Retrieves the list of generated outputs by the AI model.
     * @return the list of generated outputs
     */
    List<T> getResults();

    /**
     * Retrieves the response metadata associated with the AI model's response.
     * @return the response metadata
     */
    ResponseMetadata getMetadata();

}
ModelResponse defines the getResult, getResults, and getMetadata methods, where each result is of type ModelResult.
ModelResult
spring-ai-core/src/main/java/org/springframework/ai/model/ModelResult.java
public interface ModelResult<T> {

    /**
     * Retrieves the output generated by the AI model.
     * @return the output generated by the AI model
     */
    T getOutput();

    /**
     * Retrieves the metadata associated with the result of an AI model.
     * @return the metadata associated with the result
     */
    ResultMetadata getMetadata();

}
ModelResult defines the getOutput and getMetadata methods.
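To see how these four contracts fit together, here is a minimal, hypothetical sketch: the EchoRequest/EchoResult/EchoResponse/EchoModel names are made up for illustration, and a real implementation would return meaningful options and metadata instead of null.

import java.util.List;

import org.springframework.ai.model.Model;
import org.springframework.ai.model.ModelOptions;
import org.springframework.ai.model.ModelRequest;
import org.springframework.ai.model.ModelResponse;
import org.springframework.ai.model.ModelResult;
import org.springframework.ai.model.ResponseMetadata;
import org.springframework.ai.model.ResultMetadata;

// hypothetical request: the "instructions" are just a String
class EchoRequest implements ModelRequest<String> {
    private final String text;
    EchoRequest(String text) { this.text = text; }
    @Override public String getInstructions() { return this.text; }
    @Override public ModelOptions getOptions() { return null; } // no options in this sketch
}

// hypothetical result: wraps the generated String output
class EchoResult implements ModelResult<String> {
    private final String output;
    EchoResult(String output) { this.output = output; }
    @Override public String getOutput() { return this.output; }
    @Override public ResultMetadata getMetadata() { return null; } // no metadata in this sketch
}

// hypothetical response: carries a single result
class EchoResponse implements ModelResponse<EchoResult> {
    private final EchoResult result;
    EchoResponse(EchoResult result) { this.result = result; }
    @Override public EchoResult getResult() { return this.result; }
    @Override public List<EchoResult> getResults() { return List.of(this.result); }
    @Override public ResponseMetadata getMetadata() { return null; } // no metadata in this sketch
}

// hypothetical model: "calls" the AI by echoing the instructions back
class EchoModel implements Model<EchoRequest, EchoResponse> {
    @Override
    public EchoResponse call(EchoRequest request) {
        return new EchoResponse(new EchoResult("echo: " + request.getInstructions()));
    }
}

Calling new EchoModel().call(new EchoRequest("hi")) would then yield a response whose single result outputs "echo: hi".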
StreamingModel
spring-ai-core/src/main/java/org/springframework/ai/model/StreamingModel.java
public interface StreamingModel<TReq extends ModelRequest<?>, TResChunk extends ModelResponse<?>> {

    /**
     * Executes a method call to the AI model.
     * @param request the request object to be sent to the AI model
     * @return the streaming response from the AI model
     */
    Flux<TResChunk> stream(TReq request);

}
The StreamingModel interface defines a stream method whose parameter is of type ModelRequest and which returns a Flux<ModelResponse>.
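Continuing the hypothetical Echo* sketch from above, a streaming counterpart could emit the output in chunks through Reactor's Flux:

import reactor.core.publisher.Flux;

import org.springframework.ai.model.StreamingModel;

// hypothetical streaming model: emits the echoed instructions word by word,
// reusing the EchoRequest/EchoResponse/EchoResult types from the sketch above
class EchoStreamingModel implements StreamingModel<EchoRequest, EchoResponse> {

    @Override
    public Flux<EchoResponse> stream(EchoRequest request) {
        String[] words = ("echo: " + request.getInstructions()).split(" ");
        // each emitted EchoResponse represents one chunk of the overall answer
        return Flux.fromArray(words)
                .map(word -> new EchoResponse(new EchoResult(word)));
    }
}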
StreamingChatModel
spring-ai-core/src/main/java/org/springframework/ai/chat/model/StreamingChatModel.java
@FunctionalInterface
public interface StreamingChatModel extends StreamingModel<Prompt, ChatResponse> {

    default Flux<String> stream(String message) {
        Prompt prompt = new Prompt(message);
        return stream(prompt).map(response -> (response.getResult() == null || response.getResult().getOutput() == null
                || response.getResult().getOutput().getText() == null) ? ""
                        : response.getResult().getOutput().getText());
    }

    default Flux<String> stream(Message... messages) {
        Prompt prompt = new Prompt(Arrays.asList(messages));
        return stream(prompt).map(response -> (response.getResult() == null || response.getResult().getOutput() == null
                || response.getResult().getOutput().getText() == null) ? ""
                        : response.getResult().getOutput().getText());
    }

    @Override
    Flux<ChatResponse> stream(Prompt prompt);

}
StreamingChatModel extends the StreamingModel interface, fixing the request type to Prompt and the return type to Flux<ChatResponse>; it also provides two default methods, Flux<String> stream(String message) and Flux<String> stream(Message... messages).
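In practice these default methods let callers consume the generated text as a plain Flux<String> without unwrapping ChatResponse themselves. A minimal sketch, assuming some StreamingChatModel implementation (e.g. an auto-configured chat model bean) is at hand:

import reactor.core.publisher.Flux;

import org.springframework.ai.chat.model.StreamingChatModel;

public class StreamDemo {

    // chatModel is assumed to be an injected StreamingChatModel implementation,
    // e.g. an OllamaChatModel or OpenAiChatModel bean
    void streamAnswer(StreamingChatModel chatModel) {
        Flux<String> chunks = chatModel.stream("Tell me a joke");
        // each emitted element is one chunk of generated text;
        // blockLast() is only for demo purposes, real code would subscribe reactively
        chunks.doOnNext(System.out::print).blockLast();
    }
}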
ChatModel
spring-ai-core/src/main/java/org/springframework/ai/chat/model/ChatModel.java
public interface ChatModel extends Model<Prompt, ChatResponse>, StreamingChatModel {

    default String call(String message) {
        Prompt prompt = new Prompt(new UserMessage(message));
        Generation generation = call(prompt).getResult();
        return (generation != null) ? generation.getOutput().getText() : "";
    }

    default String call(Message... messages) {
        Prompt prompt = new Prompt(Arrays.asList(messages));
        Generation generation = call(prompt).getResult();
        return (generation != null) ? generation.getOutput().getText() : "";
    }

    @Override
    ChatResponse call(Prompt prompt);

    default ChatOptions getDefaultOptions() {
        return ChatOptions.builder().build();
    }

    default Flux<ChatResponse> stream(Prompt prompt) {
        throw new UnsupportedOperationException("streaming is not supported");
    }

}
ChatModel extends both the Model and StreamingChatModel interfaces, with the Model request type fixed to Prompt and the response type to ChatResponse.
ChatModel has different implementations in different modules, e.g. spring-ai-ollama (OllamaChatModel), spring-ai-openai (OpenAiChatModel), spring-ai-minimax (MiniMaxChatModel), spring-ai-moonshot (MoonshotChatModel), and spring-ai-zhipuai (ZhiPuAiChatModel).
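A minimal usage sketch, assuming a ChatModel bean is available (the prompt text and option values are arbitrary examples):

import org.springframework.ai.chat.messages.SystemMessage;
import org.springframework.ai.chat.messages.UserMessage;
import org.springframework.ai.chat.model.ChatModel;
import org.springframework.ai.chat.model.ChatResponse;
import org.springframework.ai.chat.prompt.ChatOptions;
import org.springframework.ai.chat.prompt.Prompt;

public class ChatDemo {

    void chat(ChatModel chatModel) {
        // simplest form: the default call(String) wraps the text in a UserMessage
        String answer = chatModel.call("What is Spring AI?");

        // multi-message form via the call(Message...) default method
        String guided = chatModel.call(
                new SystemMessage("You are a concise assistant."),
                new UserMessage("What is Spring AI?"));

        // full form: build a Prompt with per-call options and inspect the ChatResponse
        ChatResponse response = chatModel.call(new Prompt("What is Spring AI?",
                ChatOptions.builder().temperature(0.2).build()));
        System.out.println(response.getResult().getOutput().getText());
    }
}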
OllamaAutoConfiguration
org/springframework/ai/autoconfigure/ollama/OllamaAutoConfiguration.java
@AutoConfiguration(after = RestClientAutoConfiguration.class)
@ConditionalOnClass(OllamaApi.class)
@EnableConfigurationProperties({ OllamaChatProperties.class, OllamaEmbeddingProperties.class,
        OllamaConnectionProperties.class, OllamaInitializationProperties.class })
@ImportAutoConfiguration(classes = { RestClientAutoConfiguration.class, WebClientAutoConfiguration.class })
public class OllamaAutoConfiguration {

    @Bean
    @ConditionalOnMissingBean(OllamaConnectionDetails.class)
    public PropertiesOllamaConnectionDetails ollamaConnectionDetails(OllamaConnectionProperties properties) {
        return new PropertiesOllamaConnectionDetails(properties);
    }

    @Bean
    @ConditionalOnMissingBean
    public OllamaApi ollamaApi(OllamaConnectionDetails connectionDetails,
            ObjectProvider<RestClient.Builder> restClientBuilderProvider,
            ObjectProvider<WebClient.Builder> webClientBuilderProvider) {
        return new OllamaApi(connectionDetails.getBaseUrl(),
                restClientBuilderProvider.getIfAvailable(RestClient::builder),
                webClientBuilderProvider.getIfAvailable(WebClient::builder));
    }

    @Bean
    @ConditionalOnMissingBean
    @ConditionalOnProperty(prefix = OllamaChatProperties.CONFIG_PREFIX, name = "enabled", havingValue = "true",
            matchIfMissing = true)
    public OllamaChatModel ollamaChatModel(OllamaApi ollamaApi, OllamaChatProperties properties,
            OllamaInitializationProperties initProperties, List<FunctionCallback> toolFunctionCallbacks,
            FunctionCallbackResolver functionCallbackResolver, ObjectProvider<ObservationRegistry> observationRegistry,
            ObjectProvider<ChatModelObservationConvention> observationConvention) {
        var chatModelPullStrategy = initProperties.getChat().isInclude() ? initProperties.getPullModelStrategy()
                : PullModelStrategy.NEVER;
        var chatModel = OllamaChatModel.builder()
            .ollamaApi(ollamaApi)
            .defaultOptions(properties.getOptions())
            .functionCallbackResolver(functionCallbackResolver)
            .toolFunctionCallbacks(toolFunctionCallbacks)
            .observationRegistry(observationRegistry.getIfUnique(() -> ObservationRegistry.NOOP))
            .modelManagementOptions(
                    new ModelManagementOptions(chatModelPullStrategy, initProperties.getChat().getAdditionalModels(),
                            initProperties.getTimeout(), initProperties.getMaxRetries()))
            .build();
        observationConvention.ifAvailable(chatModel::setObservationConvention);
        return chatModel;
    }

    @Bean
    @ConditionalOnMissingBean
    @ConditionalOnProperty(prefix = OllamaEmbeddingProperties.CONFIG_PREFIX, name = "enabled", havingValue = "true",
            matchIfMissing = true)
    public OllamaEmbeddingModel ollamaEmbeddingModel(OllamaApi ollamaApi, OllamaEmbeddingProperties properties,
            OllamaInitializationProperties initProperties, ObjectProvider<ObservationRegistry> observationRegistry,
            ObjectProvider<EmbeddingModelObservationConvention> observationConvention) {
        var embeddingModelPullStrategy = initProperties.getEmbedding().isInclude()
                ? initProperties.getPullModelStrategy() : PullModelStrategy.NEVER;
        var embeddingModel = OllamaEmbeddingModel.builder()
            .ollamaApi(ollamaApi)
            .defaultOptions(properties.getOptions())
            .observationRegistry(observationRegistry.getIfUnique(() -> ObservationRegistry.NOOP))
            .modelManagementOptions(new ModelManagementOptions(embeddingModelPullStrategy,
                    initProperties.getEmbedding().getAdditionalModels(), initProperties.getTimeout(),
                    initProperties.getMaxRetries()))
            .build();
        observationConvention.ifAvailable(embeddingModel::setObservationConvention);
        return embeddingModel;
    }

    @Bean
    @ConditionalOnMissingBean
    public FunctionCallbackResolver springAiFunctionManager(ApplicationContext context) {
        DefaultFunctionCallbackResolver manager = new DefaultFunctionCallbackResolver();
        manager.setApplicationContext(context);
        return manager;
    }

    static class PropertiesOllamaConnectionDetails implements OllamaConnectionDetails {

        private final OllamaConnectionProperties properties;

        PropertiesOllamaConnectionDetails(OllamaConnectionProperties properties) {
            this.properties = properties;
        }

        @Override
        public String getBaseUrl() {
            return this.properties.getBaseUrl();
        }

    }

}
spring-ai-spring-boot-autoconfigure provides a series of AutoConfiguration classes; OllamaAutoConfiguration, for example, auto-configures the OllamaChatModel (and OllamaEmbeddingModel) beans.
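For example, with the Ollama starter on the classpath, a few properties are enough for the auto-configuration above to kick in (property names follow the OllamaConnectionProperties/OllamaChatProperties prefixes referenced in the code; the model name is just an example):

spring.ai.ollama.base-url=http://localhost:11434
spring.ai.ollama.chat.enabled=true
spring.ai.ollama.chat.options.model=llama3
spring.ai.ollama.chat.options.temperature=0.7

The resulting OllamaChatModel bean can then be injected through its ChatModel interface; a hypothetical service:

import org.springframework.ai.chat.model.ChatModel;
import org.springframework.stereotype.Service;

// hypothetical service: the auto-configured OllamaChatModel is injected by type
@Service
public class JokeService {

    private final ChatModel chatModel;

    public JokeService(ChatModel chatModel) {
        this.chatModel = chatModel;
    }

    public String tellJoke() {
        return chatModel.call("Tell me a joke about Spring");
    }
}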
Summary
Spring AI's Model interface defines a call method that takes a ModelRequest and returns a ModelResponse; the StreamingModel interface defines a stream method that takes a ModelRequest and returns a Flux<ModelResponse>. StreamingChatModel extends StreamingModel, fixing the request type to Prompt and the return type to Flux<ChatResponse>, and provides the two default methods Flux<String> stream(String message) and Flux<String> stream(Message... messages). ChatModel extends both Model and StreamingChatModel, with the Model request type fixed to Prompt and the response type to ChatResponse. ChatModel has different implementations in different modules, e.g. spring-ai-ollama (OllamaChatModel), spring-ai-openai (OpenAiChatModel), spring-ai-minimax (MiniMaxChatModel), spring-ai-moonshot (MoonshotChatModel), and spring-ai-zhipuai (ZhiPuAiChatModel).