This article takes a look at Prompt in Spring AI.

Prompt

org/springframework/ai/chat/prompt/Prompt.java

public class Prompt implements ModelRequest<List<Message>> {

    private final List<Message> messages;

    private ChatOptions chatOptions;

    public Prompt(String contents) {
        this(new UserMessage(contents));
    }

    public Prompt(Message message) {
        this(Collections.singletonList(message));
    }

    public Prompt(List<Message> messages) {
        this(messages, null);
    }

    public Prompt(Message... messages) {
        this(Arrays.asList(messages), null);
    }

    public Prompt(String contents, ChatOptions chatOptions) {
        this(new UserMessage(contents), chatOptions);
    }

    public Prompt(Message message, ChatOptions chatOptions) {
        this(Collections.singletonList(message), chatOptions);
    }

    public Prompt(List<Message> messages, ChatOptions chatOptions) {
        this.messages = messages;
        this.chatOptions = chatOptions;
    }

    public String getContents() {
        StringBuilder sb = new StringBuilder();
        for (Message message : getInstructions()) {
            sb.append(message.getText());
        }
        return sb.toString();
    }

    //......
}    
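Prompt implements the ModelRequest interface; its getInstructions method returns a List<Message>, and its getContents method iterates over getInstructions and appends each message.getText(), with no separator between messages.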
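To make the getContents behaviour concrete, here is a minimal, self-contained sketch. It does not use the real Spring AI classes: SimpleMessage and joinContents are hypothetical stand-ins that only mirror the loop in the listing above, where each text is appended in order with nothing in between.

```java
import java.util.List;

// Hypothetical stand-ins for illustration; these are NOT the Spring AI classes.
class PromptContentsSketch {

    // Minimal message holder with a role label and a text body.
    static final class SimpleMessage {
        final String role;
        final String text;

        SimpleMessage(String role, String text) {
            this.role = role;
            this.text = text;
        }

        String getText() {
            return this.text;
        }
    }

    // Mirrors the loop in Prompt#getContents: append each text in order, no separator.
    static String joinContents(List<SimpleMessage> messages) {
        StringBuilder sb = new StringBuilder();
        for (SimpleMessage message : messages) {
            sb.append(message.getText());
        }
        return sb.toString();
    }

    public static void main(String[] args) {
        List<SimpleMessage> messages = List.of(
                new SimpleMessage("system", "You are terse."),
                new SimpleMessage("user", "Hello"));
        System.out.println(joinContents(messages)); // prints: You are terse.Hello
    }
}
```

Because the texts run together, the concatenated string is probably best treated as a debugging aid rather than as something to send to a model.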

Message

org/springframework/ai/chat/messages/Message.java

public interface Message extends Content {

    /**
     * Get the message type.
     * @return the message type
     */
    MessageType getMessageType();

}

MessageType

org/springframework/ai/chat/messages/MessageType.java

public enum MessageType {

    /**
     * A {@link Message} of type {@literal user}, having the user role and originating
     * from an end-user or developer.
     * @see UserMessage
     */
    USER("user"),

    /**
     * A {@link Message} of type {@literal assistant} passed in subsequent input
     * {@link Message Messages} as the {@link Message} generated in response to the user.
     * @see AssistantMessage
     */
    ASSISTANT("assistant"),

    /**
     * A {@link Message} of type {@literal system} passed as input {@link Message
     * Messages} containing high-level instructions for the conversation, such as behave
     * like a certain character or provide answers in a specific format.
     * @see SystemMessage
     */
    SYSTEM("system"),

    /**
     * A {@link Message} of type {@literal function} passed as input {@link Message
     * Messages} with function content in a chat application.
     * @see ToolResponseMessage
     */
    TOOL("tool");

    private final String value;

    MessageType(String value) {
        this.value = value;
    }

    public static MessageType fromValue(String value) {
        for (MessageType messageType : MessageType.values()) {
            if (messageType.getValue().equals(value)) {
                return messageType;
            }
        }
        throw new IllegalArgumentException("Invalid MessageType value: " + value);
    }

    public String getValue() {
        return this.value;
    }

}
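MessageType defines four types: USER, SYSTEM, ASSISTANT, and TOOL. Its fromValue method matches the string value exactly and throws IllegalArgumentException for an unknown value.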
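The lookup logic of fromValue can be tried out in isolation. The sketch below copies that logic into a hypothetical stand-in enum named Role (not the Spring AI type) so that it runs without any dependency; the helper isRejected is also an assumption added for the demonstration.

```java
// Hypothetical stand-in enum that copies the lookup logic of the listing above.
class MessageTypeLookupSketch {

    enum Role {
        USER("user"), ASSISTANT("assistant"), SYSTEM("system"), TOOL("tool");

        private final String value;

        Role(String value) {
            this.value = value;
        }

        String getValue() {
            return this.value;
        }

        // Exact, case-sensitive match on the string value; unknown values are rejected.
        static Role fromValue(String value) {
            for (Role role : values()) {
                if (role.getValue().equals(value)) {
                    return role;
                }
            }
            throw new IllegalArgumentException("Invalid MessageType value: " + value);
        }
    }

    // Helper for the demo: true when fromValue throws for the given value.
    static boolean isRejected(String value) {
        try {
            Role.fromValue(value);
            return false;
        }
        catch (IllegalArgumentException ex) {
            return true;
        }
    }

    public static void main(String[] args) {
        System.out.println(Role.fromValue("tool")); // prints: TOOL
        System.out.println(isRejected("USER"));     // prints: true (lookup is case-sensitive)
    }
}
```

Note that the lookup compares against the lowercase value ("user"), not the constant name ("USER"), so callers holding upper-case role names would need to normalise them first.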

PromptTemplate

PromptTemplateMessageActions

org/springframework/ai/chat/prompt/PromptTemplateMessageActions.java

public interface PromptTemplateMessageActions {

    Message createMessage();

    Message createMessage(List<Media> mediaList);

    Message createMessage(Map<String, Object> model);

}
PromptTemplateMessageActions defines the createMessage methods, which return a Message.

PromptTemplateStringActions

org/springframework/ai/chat/prompt/PromptTemplateStringActions.java

public interface PromptTemplateStringActions {

    String render();

    String render(Map<String, Object> model);

}
PromptTemplateStringActions defines the render methods, which render the template to a String.

PromptTemplateChatActions

org/springframework/ai/chat/prompt/PromptTemplateChatActions.java

public interface PromptTemplateChatActions {

    List<Message> createMessages();

    List<Message> createMessages(Map<String, Object> model);

}
The PromptTemplateChatActions interface defines the createMessages methods, which return a List<Message>.

PromptTemplateActions

org/springframework/ai/chat/prompt/PromptTemplateActions.java

public interface PromptTemplateActions extends PromptTemplateStringActions {

    Prompt create();

    Prompt create(ChatOptions modelOptions);

    Prompt create(Map<String, Object> model);

    Prompt create(Map<String, Object> model, ChatOptions modelOptions);

}
PromptTemplateActions extends the PromptTemplateStringActions interface and defines the create methods, which build a Prompt.

PromptTemplate

org/springframework/ai/chat/prompt/PromptTemplate.java

public class PromptTemplate implements PromptTemplateActions, PromptTemplateMessageActions {

    protected String template;

    protected TemplateFormat templateFormat = TemplateFormat.ST;

    private ST st;

    private Map<String, Object> dynamicModel = new HashMap<>();

    //......

    // Render Methods
    @Override
    public String render() {
        validate(this.dynamicModel);
        return this.st.render();
    }

    @Override
    public String render(Map<String, Object> model) {
        validate(model);
        for (Entry<String, Object> entry : model.entrySet()) {
            if (this.st.getAttribute(entry.getKey()) != null) {
                this.st.remove(entry.getKey());
            }
            if (entry.getValue() instanceof Resource) {
                this.st.add(entry.getKey(), renderResource((Resource) entry.getValue()));
            }
            else {
                this.st.add(entry.getKey(), entry.getValue());
            }

        }
        return this.st.render();
    }

    @Override
    public Message createMessage() {
        return new UserMessage(render());
    }

    @Override
    public Message createMessage(List<Media> mediaList) {
        return new UserMessage(render(), mediaList);
    }

    @Override
    public Message createMessage(Map<String, Object> model) {
        return new UserMessage(render(model));
    }

    @Override
    public Prompt create() {
        return new Prompt(render(new HashMap<>()));
    }

    @Override
    public Prompt create(ChatOptions modelOptions) {
        return new Prompt(render(new HashMap<>()), modelOptions);
    }

    @Override
    public Prompt create(Map<String, Object> model) {
        return new Prompt(render(model));
    }

    @Override
    public Prompt create(Map<String, Object> model, ChatOptions modelOptions) {
        return new Prompt(render(model), modelOptions);
    }

    //......        
}    
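PromptTemplate implements the PromptTemplateActions and PromptTemplateMessageActions interfaces; its render methods use org.stringtemplate.v4.ST to do the rendering.
PromptTemplateStringActions focuses on creating and rendering prompt strings, the most basic form of prompt generation.
PromptTemplateMessageActions is designed for creating prompts by generating and manipulating Message objects.
PromptTemplateActions is designed to return a Prompt object, which can be passed to a ChatModel to generate a response.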
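As a rough illustration of what rendering a template with {name}-style placeholders means, here is a dependency-free sketch. It is not how PromptTemplate is implemented (that delegates to StringTemplate's ST); the class PlaceholderRenderSketch, its regex, and the choice to throw IllegalArgumentException for a missing variable are all assumptions made for this example.

```java
import java.util.Map;
import java.util.regex.Matcher;
import java.util.regex.Pattern;

// Hypothetical, simplified renderer; the real PromptTemplate uses org.stringtemplate.v4.ST.
class PlaceholderRenderSketch {

    // Matches placeholders of the form {name}, where name is made of word characters.
    private static final Pattern PLACEHOLDER = Pattern.compile("\\{(\\w+)\\}");

    // Replaces every {key} with the model value for that key.
    // Throwing on a missing key is this sketch's own choice.
    static String render(String template, Map<String, ?> model) {
        Matcher matcher = PLACEHOLDER.matcher(template);
        StringBuffer out = new StringBuffer();
        while (matcher.find()) {
            String key = matcher.group(1);
            if (!model.containsKey(key)) {
                throw new IllegalArgumentException("Missing template variable: " + key);
            }
            // quoteReplacement keeps '$' and '\' in values from being read as group references.
            matcher.appendReplacement(out, Matcher.quoteReplacement(String.valueOf(model.get(key))));
        }
        matcher.appendTail(out);
        return out.toString();
    }

    public static void main(String[] args) {
        String text = render("Tell me a {adjective} joke about {topic}",
                Map.of("adjective", "funny", "topic", "cats"));
        System.out.println(text); // prints: Tell me a funny joke about cats
    }
}
```

The three interfaces then differ only in what they wrap around this rendered string: nothing (render), a Message (createMessage), or a Prompt (create).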

SystemPromptTemplate

org/springframework/ai/chat/prompt/SystemPromptTemplate.java

public class SystemPromptTemplate extends PromptTemplate {

    public SystemPromptTemplate(String template) {
        super(template);
    }

    public SystemPromptTemplate(Resource resource) {
        super(resource);
    }

    @Override
    public Message createMessage() {
        return new SystemMessage(render());
    }

    @Override
    public Message createMessage(Map<String, Object> model) {
        return new SystemMessage(render(model));
    }

    @Override
    public Prompt create() {
        return new Prompt(new SystemMessage(render()));
    }

    @Override
    public Prompt create(Map<String, Object> model) {
        return new Prompt(new SystemMessage(render(model)));
    }

}
SystemPromptTemplate extends PromptTemplate; the createMessage methods it overrides return a SystemMessage, and its create methods wrap that SystemMessage in a Prompt.

FunctionPromptTemplate

org/springframework/ai/chat/prompt/FunctionPromptTemplate.java

public class FunctionPromptTemplate extends PromptTemplate {

    private String name;

    public FunctionPromptTemplate(String template) {
        super(template);
    }

}
FunctionPromptTemplate extends PromptTemplate and adds a name field.

ChatPromptTemplate

org/springframework/ai/chat/prompt/ChatPromptTemplate.java

public class ChatPromptTemplate implements PromptTemplateActions, PromptTemplateChatActions {

    private final List<PromptTemplate> promptTemplates;

    public ChatPromptTemplate(List<PromptTemplate> promptTemplates) {
        this.promptTemplates = promptTemplates;
    }

    @Override
    public String render() {
        StringBuilder sb = new StringBuilder();
        for (PromptTemplate promptTemplate : this.promptTemplates) {
            sb.append(promptTemplate.render());
        }
        return sb.toString();
    }

    @Override
    public String render(Map<String, Object> model) {
        StringBuilder sb = new StringBuilder();
        for (PromptTemplate promptTemplate : this.promptTemplates) {
            sb.append(promptTemplate.render(model));
        }
        return sb.toString();
    }

    @Override
    public List<Message> createMessages() {
        List<Message> messages = new ArrayList<>();
        for (PromptTemplate promptTemplate : this.promptTemplates) {
            messages.add(promptTemplate.createMessage());
        }
        return messages;
    }

    @Override
    public List<Message> createMessages(Map<String, Object> model) {
        List<Message> messages = new ArrayList<>();
        for (PromptTemplate promptTemplate : this.promptTemplates) {
            messages.add(promptTemplate.createMessage(model));
        }
        return messages;
    }

    @Override
    public Prompt create() {
        List<Message> messages = createMessages();
        return new Prompt(messages);
    }

    @Override
    public Prompt create(ChatOptions modelOptions) {
        List<Message> messages = createMessages();
        return new Prompt(messages, modelOptions);
    }

    @Override
    public Prompt create(Map<String, Object> model) {
        List<Message> messages = createMessages(model);
        return new Prompt(messages);
    }

    @Override
    public Prompt create(Map<String, Object> model, ChatOptions modelOptions) {
        List<Message> messages = createMessages(model);
        return new Prompt(messages, modelOptions);
    }

}
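ChatPromptTemplate implements the PromptTemplateActions and PromptTemplateChatActions interfaces. Its constructor takes a list of promptTemplates; its render methods iterate over promptTemplates and append each promptTemplate.render() result in turn, and its createMessages methods iterate over promptTemplates and collect each promptTemplate.createMessage() result.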
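The aggregation pattern described above can be sketched without Spring AI on the classpath. In the example below, Template, SystemTemplate, and createMessages are hypothetical stand-ins: each template decides its own role by overriding a method, and the aggregator simply asks every template for its message in list order.

```java
import java.util.ArrayList;
import java.util.List;

// Hypothetical stand-ins; they imitate the structure of the listing, not the real API.
class ChatTemplateAggregationSketch {

    // Base template: produces a {role, text} pair and defaults to the user role.
    static class Template {
        final String text;

        Template(String text) {
            this.text = text;
        }

        String role() {
            return "user";
        }

        String[] createMessage() {
            return new String[] { role(), this.text };
        }
    }

    // Subclass that changes only the role, as SystemPromptTemplate does with SystemMessage.
    static class SystemTemplate extends Template {
        SystemTemplate(String text) {
            super(text);
        }

        @Override
        String role() {
            return "system";
        }
    }

    // Mirrors ChatPromptTemplate#createMessages: one message per template, in order.
    static List<String[]> createMessages(List<? extends Template> templates) {
        List<String[]> messages = new ArrayList<>();
        for (Template template : templates) {
            messages.add(template.createMessage());
        }
        return messages;
    }

    public static void main(String[] args) {
        List<String[]> messages = createMessages(List.of(
                new SystemTemplate("Be brief."),
                new Template("Hi")));
        for (String[] message : messages) {
            System.out.println(message[0] + ": " + message[1]);
        }
        // prints:
        // system: Be brief.
        // user: Hi
    }
}
```

The design keeps the aggregator unaware of roles: adding a new kind of message only requires a new template subclass.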

AssistantPromptTemplate

org/springframework/ai/chat/prompt/AssistantPromptTemplate.java

public class AssistantPromptTemplate extends PromptTemplate {

    public AssistantPromptTemplate(String template) {
        super(template);
    }

    public AssistantPromptTemplate(Resource resource) {
        super(resource);
    }

    @Override
    public Prompt create() {
        return new Prompt(new AssistantMessage(render()));
    }

    @Override
    public Prompt create(Map<String, Object> model) {
        return new Prompt(new AssistantMessage(render(model)));
    }

    @Override
    public Message createMessage() {
        return new AssistantMessage(render());
    }

    @Override
    public Message createMessage(Map<String, Object> model) {
        return new AssistantMessage(render(model));
    }

}
AssistantPromptTemplate extends PromptTemplate; its createMessage methods return an AssistantMessage.

Examples

PromptTemplate example

PromptTemplate promptTemplate = new PromptTemplate("Tell me a {adjective} joke about {topic}");

Prompt prompt = promptTemplate.create(Map.of("adjective", adjective, "topic", topic));

return chatModel.call(prompt).getResult();

SystemPromptTemplate example

String userText = """
    Tell me about three famous pirates from the Golden Age of Piracy and why they did.
    Write at least a sentence for each pirate.
    """;

Message userMessage = new UserMessage(userText);

String systemText = """
  You are a helpful AI assistant that helps people find information.
  Your name is {name}
  You should reply to the user's request with your name and also in the style of a {voice}.
  """;

SystemPromptTemplate systemPromptTemplate = new SystemPromptTemplate(systemText);
Message systemMessage = systemPromptTemplate.createMessage(Map.of("name", name, "voice", voice));

Prompt prompt = new Prompt(List.of(userMessage, systemMessage));

List<Generation> response = chatModel.call(prompt).getResults();

Summary

Spring AI's Message defines a MessageType property with the types USER, SYSTEM, ASSISTANT, and TOOL. The createMessage method of PromptTemplate returns a UserMessage, that of SystemPromptTemplate returns a SystemMessage, and that of AssistantPromptTemplate returns an AssistantMessage. SystemPromptTemplate and AssistantPromptTemplate both extend PromptTemplate, whose render methods use org.stringtemplate.v4.ST to do the rendering; ChatPromptTemplate aggregates a list of promptTemplates.


codecraft