This article takes a look at Spring AI's Chat Model.

Model

spring-ai-core/src/main/java/org/springframework/ai/model/Model.java

public interface Model<TReq extends ModelRequest<?>, TRes extends ModelResponse<?>> {

    /**
     * Executes a method call to the AI model.
     * @param request the request object to be sent to the AI model
     * @return the response from the AI model
     */
    TRes call(TReq request);

}
The Model interface defines a call method that takes a ModelRequest and returns a ModelResponse.

ModelRequest

spring-ai-core/src/main/java/org/springframework/ai/model/ModelRequest.java

public interface ModelRequest<T> {

    /**
     * Retrieves the instructions or input required by the AI model.
     * @return the instructions or input required by the AI model
     */
    T getInstructions(); // required input

    /**
     * Retrieves the customizable options for AI model interactions.
     * @return the customizable options for AI model interactions
     */
    ModelOptions getOptions();

}
ModelRequest defines the getInstructions and getOptions methods.
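
To make this concrete: in spring-ai-core, Prompt implements ModelRequest&lt;List&lt;Message&gt;&gt;, so getInstructions() returns the chat messages and getOptions() the ChatOptions. A minimal sketch (the message text and temperature value are arbitrary):

import java.util.List;

import org.springframework.ai.chat.messages.Message;
import org.springframework.ai.chat.prompt.ChatOptions;
import org.springframework.ai.chat.prompt.Prompt;

public class PromptAsModelRequestExample {

    public static void main(String[] args) {
        // Prompt implements ModelRequest<List<Message>>
        Prompt prompt = new Prompt("Hello", ChatOptions.builder().temperature(0.7).build());
        List<Message> instructions = prompt.getInstructions(); // a single UserMessage("Hello")
        ChatOptions options = prompt.getOptions();             // temperature = 0.7
        System.out.println(instructions + " " + options.getTemperature());
    }

}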

ModelResponse

spring-ai-core/src/main/java/org/springframework/ai/model/ModelResponse.java

public interface ModelResponse<T extends ModelResult<?>> {

    /**
     * Retrieves the result of the AI model.
     * @return the result generated by the AI model
     */
    T getResult();

    /**
     * Retrieves the list of generated outputs by the AI model.
     * @return the list of generated outputs
     */
    List<T> getResults();

    /**
     * Retrieves the response metadata associated with the AI model's response.
     * @return the response metadata
     */
    ResponseMetadata getMetadata();

}
ModelResponse defines the getResult, getResults, and getMetadata methods; each result is of type ModelResult.
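
On the chat side the concrete response type is ChatResponse, which implements ModelResponse&lt;Generation&gt;. A minimal sketch of walking a response (the generation content is arbitrary):

import java.util.List;

import org.springframework.ai.chat.messages.AssistantMessage;
import org.springframework.ai.chat.metadata.ChatResponseMetadata;
import org.springframework.ai.chat.model.ChatResponse;
import org.springframework.ai.chat.model.Generation;

public class ChatResponseExample {

    public static void main(String[] args) {
        // ChatResponse implements ModelResponse<Generation>
        ChatResponse response = new ChatResponse(List.of(new Generation(new AssistantMessage("Hi"))));
        Generation first = response.getResult();        // the first generation
        List<Generation> all = response.getResults();   // all generations
        ChatResponseMetadata metadata = response.getMetadata();
        System.out.println(all.size() + " " + metadata);
    }

}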

ModelResult

spring-ai-core/src/main/java/org/springframework/ai/model/ModelResult.java

public interface ModelResult<T> {

    /**
     * Retrieves the output generated by the AI model.
     * @return the output generated by the AI model
     */
    T getOutput();

    /**
     * Retrieves the metadata associated with the result of an AI model.
     * @return the metadata associated with the result
     */
    ResultMetadata getMetadata();

}
ModelResult defines the getOutput and getMetadata methods.
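
The chat-side ModelResult is Generation, which implements ModelResult&lt;AssistantMessage&gt;: getOutput() returns the AssistantMessage and getMetadata() the ChatGenerationMetadata. A minimal sketch:

import org.springframework.ai.chat.messages.AssistantMessage;
import org.springframework.ai.chat.metadata.ChatGenerationMetadata;
import org.springframework.ai.chat.model.Generation;

public class GenerationAsModelResultExample {

    public static void main(String[] args) {
        // Generation implements ModelResult<AssistantMessage>
        Generation generation = new Generation(new AssistantMessage("Hi"));
        AssistantMessage output = generation.getOutput();           // the model's message
        ChatGenerationMetadata metadata = generation.getMetadata(); // per-result metadata
        System.out.println(output.getText() + " " + metadata);
    }

}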

StreamingModel

spring-ai-core/src/main/java/org/springframework/ai/model/StreamingModel.java

public interface StreamingModel<TReq extends ModelRequest<?>, TResChunk extends ModelResponse<?>> {

    /**
     * Executes a method call to the AI model.
     * @param request the request object to be sent to the AI model
     * @return the streaming response from the AI model
     */
    Flux<TResChunk> stream(TReq request);

}
The StreamingModel interface defines a stream method that takes a ModelRequest and returns a Flux&lt;ModelResponse&gt;.

StreamingChatModel

spring-ai-core/src/main/java/org/springframework/ai/chat/model/StreamingChatModel.java

@FunctionalInterface
public interface StreamingChatModel extends StreamingModel<Prompt, ChatResponse> {

    default Flux<String> stream(String message) {
        Prompt prompt = new Prompt(message);
        return stream(prompt).map(response -> (response.getResult() == null || response.getResult().getOutput() == null
                || response.getResult().getOutput().getText() == null) ? ""
                        : response.getResult().getOutput().getText());
    }

    default Flux<String> stream(Message... messages) {
        Prompt prompt = new Prompt(Arrays.asList(messages));
        return stream(prompt).map(response -> (response.getResult() == null || response.getResult().getOutput() == null
                || response.getResult().getOutput().getText() == null) ? ""
                        : response.getResult().getOutput().getText());
    }

    @Override
    Flux<ChatResponse> stream(Prompt prompt);

}
StreamingChatModel extends the StreamingModel interface, binding the request type to Prompt and the return type to Flux&lt;ChatResponse&gt;; it also provides two default methods, Flux&lt;String&gt; stream(String message) and Flux&lt;String&gt; stream(Message... messages), which map each ChatResponse chunk to the text of its first generation (or an empty string when absent).
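
Usage is straightforward once an implementation is on hand; a minimal sketch, assuming a StreamingChatModel is provided (e.g. injected by Spring):

import org.springframework.ai.chat.model.StreamingChatModel;

import reactor.core.publisher.Flux;

public class StreamingChatModelExample {

    // streamingChatModel is assumed to be supplied, e.g. an auto-configured implementation
    static void printJoke(StreamingChatModel streamingChatModel) {
        Flux<String> chunks = streamingChatModel.stream("Tell me a joke");
        chunks.subscribe(System.out::print); // prints the answer chunk by chunk
    }

}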

ChatModel

spring-ai-core/src/main/java/org/springframework/ai/chat/model/ChatModel.java

public interface ChatModel extends Model<Prompt, ChatResponse>, StreamingChatModel {

    default String call(String message) {
        Prompt prompt = new Prompt(new UserMessage(message));
        Generation generation = call(prompt).getResult();
        return (generation != null) ? generation.getOutput().getText() : "";
    }

    default String call(Message... messages) {
        Prompt prompt = new Prompt(Arrays.asList(messages));
        Generation generation = call(prompt).getResult();
        return (generation != null) ? generation.getOutput().getText() : "";
    }

    @Override
    ChatResponse call(Prompt prompt);

    default ChatOptions getDefaultOptions() {
        return ChatOptions.builder().build();
    }

    default Flux<ChatResponse> stream(Prompt prompt) {
        throw new UnsupportedOperationException("streaming is not supported");
    }

}
ChatModel extends both the Model and StreamingChatModel interfaces, with the Model request type bound to Prompt and the return type to ChatResponse.
ChatModel has different implementations in different modules, for example spring-ai-ollama (OllamaChatModel), spring-ai-openai (OpenAiChatModel), spring-ai-minimax (MiniMaxChatModel), spring-ai-moonshot (MoonshotChatModel), and spring-ai-zhipuai (ZhiPuAiChatModel).
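
A minimal usage sketch, assuming a ChatModel implementation is available (e.g. injected by Spring): the default call(String) overload returns just the text, while call(Prompt) exposes the full ChatResponse.

import org.springframework.ai.chat.model.ChatModel;
import org.springframework.ai.chat.model.ChatResponse;
import org.springframework.ai.chat.prompt.Prompt;

public class ChatModelExample {

    // chatModel is assumed to be supplied, e.g. an auto-configured implementation
    static void ask(ChatModel chatModel) {
        String answer = chatModel.call("Why is the sky blue?"); // just the text
        ChatResponse response = chatModel.call(new Prompt("Why is the sky blue?"));
        String text = response.getResult().getOutput().getText(); // same text, via the full response
        System.out.println(answer + " / " + text);
    }

}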

OllamaAutoConfiguration

org/springframework/ai/autoconfigure/ollama/OllamaAutoConfiguration.java

@AutoConfiguration(after = RestClientAutoConfiguration.class)
@ConditionalOnClass(OllamaApi.class)
@EnableConfigurationProperties({ OllamaChatProperties.class, OllamaEmbeddingProperties.class,
        OllamaConnectionProperties.class, OllamaInitializationProperties.class })
@ImportAutoConfiguration(classes = { RestClientAutoConfiguration.class, WebClientAutoConfiguration.class })
public class OllamaAutoConfiguration {

    @Bean
    @ConditionalOnMissingBean(OllamaConnectionDetails.class)
    public PropertiesOllamaConnectionDetails ollamaConnectionDetails(OllamaConnectionProperties properties) {
        return new PropertiesOllamaConnectionDetails(properties);
    }

    @Bean
    @ConditionalOnMissingBean
    public OllamaApi ollamaApi(OllamaConnectionDetails connectionDetails,
            ObjectProvider<RestClient.Builder> restClientBuilderProvider,
            ObjectProvider<WebClient.Builder> webClientBuilderProvider) {
        return new OllamaApi(connectionDetails.getBaseUrl(),
                restClientBuilderProvider.getIfAvailable(RestClient::builder),
                webClientBuilderProvider.getIfAvailable(WebClient::builder));
    }

    @Bean
    @ConditionalOnMissingBean
    @ConditionalOnProperty(prefix = OllamaChatProperties.CONFIG_PREFIX, name = "enabled", havingValue = "true",
            matchIfMissing = true)
    public OllamaChatModel ollamaChatModel(OllamaApi ollamaApi, OllamaChatProperties properties,
            OllamaInitializationProperties initProperties, List<FunctionCallback> toolFunctionCallbacks,
            FunctionCallbackResolver functionCallbackResolver, ObjectProvider<ObservationRegistry> observationRegistry,
            ObjectProvider<ChatModelObservationConvention> observationConvention) {
        var chatModelPullStrategy = initProperties.getChat().isInclude() ? initProperties.getPullModelStrategy()
                : PullModelStrategy.NEVER;

        var chatModel = OllamaChatModel.builder()
            .ollamaApi(ollamaApi)
            .defaultOptions(properties.getOptions())
            .functionCallbackResolver(functionCallbackResolver)
            .toolFunctionCallbacks(toolFunctionCallbacks)
            .observationRegistry(observationRegistry.getIfUnique(() -> ObservationRegistry.NOOP))
            .modelManagementOptions(
                    new ModelManagementOptions(chatModelPullStrategy, initProperties.getChat().getAdditionalModels(),
                            initProperties.getTimeout(), initProperties.getMaxRetries()))
            .build();

        observationConvention.ifAvailable(chatModel::setObservationConvention);

        return chatModel;
    }

    @Bean
    @ConditionalOnMissingBean
    @ConditionalOnProperty(prefix = OllamaEmbeddingProperties.CONFIG_PREFIX, name = "enabled", havingValue = "true",
            matchIfMissing = true)
    public OllamaEmbeddingModel ollamaEmbeddingModel(OllamaApi ollamaApi, OllamaEmbeddingProperties properties,
            OllamaInitializationProperties initProperties, ObjectProvider<ObservationRegistry> observationRegistry,
            ObjectProvider<EmbeddingModelObservationConvention> observationConvention) {
        var embeddingModelPullStrategy = initProperties.getEmbedding().isInclude()
                ? initProperties.getPullModelStrategy() : PullModelStrategy.NEVER;

        var embeddingModel = OllamaEmbeddingModel.builder()
            .ollamaApi(ollamaApi)
            .defaultOptions(properties.getOptions())
            .observationRegistry(observationRegistry.getIfUnique(() -> ObservationRegistry.NOOP))
            .modelManagementOptions(new ModelManagementOptions(embeddingModelPullStrategy,
                    initProperties.getEmbedding().getAdditionalModels(), initProperties.getTimeout(),
                    initProperties.getMaxRetries()))
            .build();

        observationConvention.ifAvailable(embeddingModel::setObservationConvention);

        return embeddingModel;
    }

    @Bean
    @ConditionalOnMissingBean
    public FunctionCallbackResolver springAiFunctionManager(ApplicationContext context) {
        DefaultFunctionCallbackResolver manager = new DefaultFunctionCallbackResolver();
        manager.setApplicationContext(context);
        return manager;
    }

    static class PropertiesOllamaConnectionDetails implements OllamaConnectionDetails {

        private final OllamaConnectionProperties properties;

        PropertiesOllamaConnectionDetails(OllamaConnectionProperties properties) {
            this.properties = properties;
        }

        @Override
        public String getBaseUrl() {
            return this.properties.getBaseUrl();
        }

    }

}
spring-ai-spring-boot-autoconfigure provides a series of auto-configurations; for example, OllamaAutoConfiguration auto-configures the OllamaChatModel and OllamaEmbeddingModel beans shown above.
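
A minimal sketch of consuming the auto-configured bean, assuming the spring-ai-ollama-spring-boot-starter dependency (plus Spring Web for the controller) is on the classpath; the property names in the comment follow the Spring AI Ollama conventions, and the model value is just an example:

import org.springframework.ai.chat.model.ChatModel;
import org.springframework.web.bind.annotation.GetMapping;
import org.springframework.web.bind.annotation.RequestParam;
import org.springframework.web.bind.annotation.RestController;

// application.properties (example values):
//   spring.ai.ollama.base-url=http://localhost:11434
//   spring.ai.ollama.chat.options.model=llama3
@RestController
public class ChatController {

    private final ChatModel chatModel; // the auto-configured OllamaChatModel

    public ChatController(ChatModel chatModel) {
        this.chatModel = chatModel;
    }

    @GetMapping("/chat")
    public String chat(@RequestParam String message) {
        return this.chatModel.call(message);
    }

}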

Summary

Spring AI's Model interface defines a call method that takes a ModelRequest and returns a ModelResponse; the StreamingModel interface defines a stream method that takes a ModelRequest and returns a Flux&lt;ModelResponse&gt;. StreamingChatModel extends the StreamingModel interface, binding the request type to Prompt and the return type to Flux&lt;ChatResponse&gt;, and provides two default methods, Flux&lt;String&gt; stream(String message) and Flux&lt;String&gt; stream(Message... messages). ChatModel extends both the Model and StreamingChatModel interfaces, with the Model request type bound to Prompt and the return type to ChatResponse. ChatModel has different implementations in different modules, for example spring-ai-ollama (OllamaChatModel), spring-ai-openai (OpenAiChatModel), spring-ai-minimax (MiniMaxChatModel), spring-ai-moonshot (MoonshotChatModel), and spring-ai-zhipuai (ZhiPuAiChatModel).
