Preface
This article takes a look at Spring AI's StructuredOutputConverter.
StructuredOutputConverter
org/springframework/ai/converter/StructuredOutputConverter.java
public interface StructuredOutputConverter<T> extends Converter<String, T>, FormatProvider {
}
The StructuredOutputConverter interface extends both Converter and FormatProvider. It has two abstract implementations: AbstractMessageOutputConverter and AbstractConversionServiceOutputConverter.
Converter
org/springframework/core/convert/converter/Converter.java
@FunctionalInterface
public interface Converter<S, T> {
/**
* Convert the source object of type {@code S} to target type {@code T}.
* @param source the source object to convert, which must be an instance of {@code S} (never {@code null})
* @return the converted object, which must be an instance of {@code T} (potentially {@code null})
* @throws IllegalArgumentException if the source cannot be converted to the desired target type
*/
@Nullable
T convert(S source);
/**
* Construct a composed {@link Converter} that first applies this {@link Converter}
* to its input, and then applies the {@code after} {@link Converter} to the
* result.
* @param after the {@link Converter} to apply after this {@link Converter}
* is applied
* @param <U> the type of output of both the {@code after} {@link Converter}
* and the composed {@link Converter}
* @return a composed {@link Converter} that first applies this {@link Converter}
* and then applies the {@code after} {@link Converter}
* @since 5.3
*/
default <U> Converter<S, U> andThen(Converter<? super T, ? extends U> after) {
Assert.notNull(after, "'after' Converter must not be null");
return (S s) -> {
T initialResult = convert(s);
return (initialResult != null ? after.convert(initialResult) : null);
};
}
}
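The Converter interface declares the convert method and provides a default implementation of andThen, which chains a second converter after this one and returns null if the first conversion yields null.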
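To make the andThen composition concrete, here is a minimal sketch that runs without Spring on the classpath. SketchConverter is a hypothetical stand-in that copies only the two methods shown above, and AndThenSketch and parseTrimmed are names invented for this illustration.

```java
import java.util.Objects;

// Hypothetical stand-in for Spring's Converter, reduced to the two methods shown above.
@FunctionalInterface
interface SketchConverter<S, T> {

    T convert(S source);

    // Same shape as Converter.andThen: run this converter, then 'after', short-circuiting on null.
    default <U> SketchConverter<S, U> andThen(SketchConverter<? super T, ? extends U> after) {
        Objects.requireNonNull(after, "'after' Converter must not be null");
        return (S s) -> {
            T initialResult = convert(s);
            return (initialResult != null ? after.convert(initialResult) : null);
        };
    }
}

class AndThenSketch {

    // Compose "trim the raw text" with "parse it as an integer".
    static Integer parseTrimmed(String raw) {
        SketchConverter<String, String> trim = String::trim;
        SketchConverter<String, Integer> toInt = Integer::valueOf;
        return trim.andThen(toInt).convert(raw);
    }

    public static void main(String[] args) {
        System.out.println(parseTrimmed("  42 \n"));
    }
}
```

Because StructuredOutputConverter extends Converter<String, T>, the same pattern should in principle let a post-processing step be chained after any output converter.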
FormatProvider
org/springframework/ai/converter/FormatProvider.java
public interface FormatProvider {
/**
* Get the format of the output of a language generative.
* @return Returns a string containing instructions for how the output of a language
* generative should be formatted.
*/
String getFormat();
}
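FormatProvider declares a single method, getFormat, which returns the formatting instructions to include in the prompt sent to the model.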
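The role of getFormat is easiest to see when its text is spliced into a prompt. The sketch below does the substitution by hand with String.replace; in the real API a PromptTemplate fills the {format} placeholder. SketchFormatProvider, FormatPromptSketch, and render are hypothetical names used only for this illustration.

```java
// Hypothetical stand-in for Spring AI's FormatProvider.
interface SketchFormatProvider {
    String getFormat();
}

class FormatPromptSketch {

    // Replace the {format} placeholder with the provider's instructions.
    static String render(String template, SketchFormatProvider provider) {
        return template.replace("{format}", provider.getFormat());
    }

    public static void main(String[] args) {
        SketchFormatProvider csv = () -> "Respond with only a list of comma-separated values.";
        System.out.println(render("List three primary colors.\n{format}", csv));
    }
}
```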
AbstractMessageOutputConverter
org/springframework/ai/converter/AbstractMessageOutputConverter.java
public abstract class AbstractMessageOutputConverter<T> implements StructuredOutputConverter<T> {
private MessageConverter messageConverter;
/**
* Create a new AbstractMessageOutputConverter.
* @param messageConverter the message converter to use
*/
public AbstractMessageOutputConverter(MessageConverter messageConverter) {
this.messageConverter = messageConverter;
}
/**
* Return the message converter used by this output converter.
* @return the message converter
*/
public MessageConverter getMessageConverter() {
return this.messageConverter;
}
}
AbstractMessageOutputConverter holds a MessageConverter field. Its concrete implementation is MapOutputConverter.
MapOutputConverter
org/springframework/ai/converter/MapOutputConverter.java
public class MapOutputConverter extends AbstractMessageOutputConverter<Map<String, Object>> {
public MapOutputConverter() {
super(new MappingJackson2MessageConverter());
}
@Override
public Map<String, Object> convert(@NonNull String text) {
if (text.startsWith("```json") && text.endsWith("```")) {
text = text.substring(7, text.length() - 3);
}
Message<?> message = MessageBuilder.withPayload(text.getBytes(StandardCharsets.UTF_8)).build();
return (Map) this.getMessageConverter().fromMessage(message, HashMap.class);
}
@Override
public String getFormat() {
String raw = """
Your response should be in JSON format.
The data structure for the JSON should match this Java class: %s
Do not include any explanations, only provide a RFC8259 compliant JSON response following this format without deviation.
Remove the ```json markdown surrounding the output including the trailing "```".
""";
return String.format(raw, HashMap.class.getName());
}
}
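MapOutputConverter extends AbstractMessageOutputConverter and uses MappingJackson2MessageConverter as its MessageConverter. Its convert method strips a surrounding ```json fence when one is present, wraps the text in a Message, and lets the message converter deserialize it into a HashMap.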
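The fence handling in MapOutputConverter.convert is worth a closer look, because the check runs on the untrimmed text. The following sketch isolates just that step; MapFenceSketch and stripJsonFence are hypothetical names, and the JSON parsing done by the message converter is left out.

```java
class MapFenceSketch {

    // Mirrors the fence handling in MapOutputConverter.convert shown above.
    static String stripJsonFence(String text) {
        if (text.startsWith("```json") && text.endsWith("```")) {
            // 7 is the length of "```json", 3 is the length of the closing "```".
            text = text.substring(7, text.length() - 3);
        }
        return text;
    }

    public static void main(String[] args) {
        String fenced = "```json\n{\"R\":\"Red\"}\n```";
        // The fence is removed; the newlines around the JSON remain.
        System.out.println(stripJsonFence(fenced));
        // A trailing newline means the text no longer ends with ```, so nothing is stripped.
        System.out.println(stripJsonFence(fenced + "\n"));
    }
}
```

Under this logic, a response with trailing whitespace after the closing fence is passed to the JSON parser with the fence still attached, which is presumably why getFormat also asks the model not to emit the fence at all.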
AbstractConversionServiceOutputConverter
org/springframework/ai/converter/AbstractConversionServiceOutputConverter.java
public abstract class AbstractConversionServiceOutputConverter<T> implements StructuredOutputConverter<T> {
private final DefaultConversionService conversionService;
/**
* Create a new {@link AbstractConversionServiceOutputConverter} instance.
* @param conversionService the {@link DefaultConversionService} to use for converting
* the output.
*/
public AbstractConversionServiceOutputConverter(DefaultConversionService conversionService) {
this.conversionService = conversionService;
}
/**
* Return the ConversionService used by this converter.
* @return the ConversionService used by this converter.
*/
public DefaultConversionService getConversionService() {
return this.conversionService;
}
}
AbstractConversionServiceOutputConverter holds a DefaultConversionService field. Its concrete implementation is ListOutputConverter.
ListOutputConverter
org/springframework/ai/converter/ListOutputConverter.java
public class ListOutputConverter extends AbstractConversionServiceOutputConverter<List<String>> {
public ListOutputConverter(DefaultConversionService defaultConversionService) {
super(defaultConversionService);
}
@Override
public String getFormat() {
return """
Respond with only a list of comma-separated values, without any leading or trailing text.
Example format: foo, bar, baz
""";
}
@Override
public List<String> convert(@NonNull String text) {
return this.getConversionService().convert(text, List.class);
}
}
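ListOutputConverter extends AbstractConversionServiceOutputConverter. Its convert method delegates to the conversion service to turn the text into a List<String>, and getFormat asks the model for comma-separated values.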
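As far as I understand, the String-to-List conversion registered in DefaultConversionService splits the text on commas and trims each element. The sketch below approximates that behavior with the JDK only; it is not the Spring implementation, and CommaListSketch and toList are hypothetical names.

```java
import java.util.ArrayList;
import java.util.List;

class CommaListSketch {

    // Approximation of the String -> List conversion: split on commas, trim each element.
    static List<String> toList(String text) {
        List<String> items = new ArrayList<>();
        if (text.isEmpty()) {
            return items; // assumption: an empty string yields an empty list
        }
        for (String field : text.split(",", -1)) {
            items.add(field.trim());
        }
        return items;
    }

    public static void main(String[] args) {
        System.out.println(toList("foo, bar, baz"));
    }
}
```

This explains why the format instruction insists on no leading or trailing text: any extra sentence from the model would simply end up inside the first or last list element.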
BeanOutputConverter
org/springframework/ai/converter/BeanOutputConverter.java
public class BeanOutputConverter<T> implements StructuredOutputConverter<T> {
private final Logger logger = LoggerFactory.getLogger(BeanOutputConverter.class);
/**
* The target class type reference to which the output will be converted.
*/
private final Type type;
/** The object mapper used for deserialization and other JSON operations. */
private final ObjectMapper objectMapper;
/** Holds the generated JSON schema for the target type. */
private String jsonSchema;
/**
* Constructor to initialize with the target type's class.
* @param clazz The target type's class.
*/
public BeanOutputConverter(Class<T> clazz) {
this(ParameterizedTypeReference.forType(clazz));
}
/**
* Constructor to initialize with the target type's class, a custom object mapper, and
* a line endings normalizer to ensure consistent line endings on any platform.
* @param clazz The target type's class.
* @param objectMapper Custom object mapper for JSON operations. endings.
*/
public BeanOutputConverter(Class<T> clazz, ObjectMapper objectMapper) {
this(ParameterizedTypeReference.forType(clazz), objectMapper);
}
/**
* Constructor to initialize with the target class type reference.
* @param typeRef The target class type reference.
*/
public BeanOutputConverter(ParameterizedTypeReference<T> typeRef) {
this(typeRef.getType(), null);
}
/**
* Constructor to initialize with the target class type reference, a custom object
* mapper, and a line endings normalizer to ensure consistent line endings on any
* platform.
* @param typeRef The target class type reference.
* @param objectMapper Custom object mapper for JSON operations. endings.
*/
public BeanOutputConverter(ParameterizedTypeReference<T> typeRef, ObjectMapper objectMapper) {
this(typeRef.getType(), objectMapper);
}
/**
* Constructor to initialize with the target class type reference, a custom object
* mapper, and a line endings normalizer to ensure consistent line endings on any
* platform.
* @param type The target class type.
* @param objectMapper Custom object mapper for JSON operations. endings.
*/
private BeanOutputConverter(Type type, ObjectMapper objectMapper) {
Objects.requireNonNull(type, "Type cannot be null;");
this.type = type;
this.objectMapper = objectMapper != null ? objectMapper : getObjectMapper();
generateSchema();
}
/**
* Generates the JSON schema for the target type.
*/
private void generateSchema() {
JacksonModule jacksonModule = new JacksonModule(JacksonOption.RESPECT_JSONPROPERTY_REQUIRED,
JacksonOption.RESPECT_JSONPROPERTY_ORDER);
SchemaGeneratorConfigBuilder configBuilder = new SchemaGeneratorConfigBuilder(
com.github.victools.jsonschema.generator.SchemaVersion.DRAFT_2020_12,
com.github.victools.jsonschema.generator.OptionPreset.PLAIN_JSON)
.with(jacksonModule)
.with(Option.FORBIDDEN_ADDITIONAL_PROPERTIES_BY_DEFAULT);
SchemaGeneratorConfig config = configBuilder.build();
SchemaGenerator generator = new SchemaGenerator(config);
JsonNode jsonNode = generator.generateSchema(this.type);
ObjectWriter objectWriter = this.objectMapper.writer(new DefaultPrettyPrinter()
.withObjectIndenter(new DefaultIndenter().withLinefeed(System.lineSeparator())));
try {
this.jsonSchema = objectWriter.writeValueAsString(jsonNode);
}
catch (JsonProcessingException e) {
logger.error("Could not pretty print json schema for jsonNode: {}", jsonNode);
throw new RuntimeException("Could not pretty print json schema for " + this.type, e);
}
}
/**
* Parses the given text to transform it to the desired target type.
* @param text The LLM output in string format.
* @return The parsed output in the desired target type.
*/
@SuppressWarnings("unchecked")
@Override
public T convert(@NonNull String text) {
try {
// Remove leading and trailing whitespace
text = text.trim();
// Check for and remove triple backticks and "json" identifier
if (text.startsWith("```") && text.endsWith("```")) {
// Remove the first line if it contains "```json"
String[] lines = text.split("\n", 2);
if (lines[0].trim().equalsIgnoreCase("```json")) {
text = lines.length > 1 ? lines[1] : "";
}
else {
text = text.substring(3); // Remove leading ```
}
// Remove trailing ```
text = text.substring(0, text.length() - 3);
// Trim again to remove any potential whitespace
text = text.trim();
}
return (T) this.objectMapper.readValue(text, this.objectMapper.constructType(this.type));
}
catch (JsonProcessingException e) {
logger.error(SENSITIVE_DATA_MARKER,
"Could not parse the given text to the desired target type: \"{}\" into {}", text, this.type);
throw new RuntimeException(e);
}
}
/**
* Configures and returns an object mapper for JSON operations.
* @return Configured object mapper.
*/
protected ObjectMapper getObjectMapper() {
return JsonMapper.builder()
.addModules(JacksonUtils.instantiateAvailableModules())
.configure(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES, false)
.build();
}
/**
* Provides the expected format of the response, instructing that it should adhere to
* the generated JSON schema.
* @return The instruction format string.
*/
@Override
public String getFormat() {
String template = """
Your response should be in JSON format.
Do not include any explanations, only provide a RFC8259 compliant JSON response following this format without deviation.
Do not include markdown code blocks in your response.
Remove the ```json markdown from the output.
Here is the JSON Schema instance your output must adhere to:
```%s```
""";
return String.format(template, this.jsonSchema);
}
/**
* Provides the generated JSON schema for the target type.
* @return The generated JSON schema.
*/
public String getJsonSchema() {
return this.jsonSchema;
}
public Map<String, Object> getJsonSchemaMap() {
try {
return this.objectMapper.readValue(this.jsonSchema, Map.class);
}
catch (JsonProcessingException ex) {
logger.error("Could not parse the JSON Schema to a Map object", ex);
throw new IllegalStateException(ex);
}
}
}
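BeanOutputConverter generates a JSON Schema for the target type in its constructor and embeds that schema in the instructions returned by getFormat. Its convert method trims the text, strips any surrounding Markdown code fence, and then uses the ObjectMapper to deserialize the JSON into the target bean.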
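The cleanup at the top of BeanOutputConverter.convert is more forgiving than the one in MapOutputConverter: it trims first and accepts a fence with or without the json tag. Here is that step isolated as a sketch; BeanFenceSketch and cleanup are hypothetical names, and the Jackson deserialization that follows is omitted.

```java
class BeanFenceSketch {

    // Mirrors the cleanup steps at the top of BeanOutputConverter.convert shown above.
    static String cleanup(String text) {
        text = text.trim();
        if (text.startsWith("```") && text.endsWith("```")) {
            // Split off the first line so a "```json" opener can be dropped as a whole.
            String[] lines = text.split("\n", 2);
            if (lines[0].trim().equalsIgnoreCase("```json")) {
                text = lines.length > 1 ? lines[1] : "";
            }
            else {
                text = text.substring(3); // remove a bare leading ```
            }
            text = text.substring(0, text.length() - 3); // remove the trailing ```
            text = text.trim();
        }
        return text;
    }

    public static void main(String[] args) {
        System.out.println(cleanup("  ```json\n{\"name\":\"John\", \"age\":30}\n```\n"));
        System.out.println(cleanup("```\n{\"a\":1}\n```"));
    }
}
```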
Examples
chatModel + outputConverter
@Test
void mapOutputConvert() {
MapOutputConverter outputConverter = new MapOutputConverter();
String format = outputConverter.getFormat();
String template = """
For each letter in the RGB color scheme, tell me what it stands for.
Example: R -> Red.
{format}
""";
PromptTemplate promptTemplate = new PromptTemplate(template, Map.of("format", format));
Prompt prompt = new Prompt(promptTemplate.createMessage());
Generation generation = this.chatModel.call(prompt).getResult();
Map<String, Object> result = outputConverter.convert(generation.getOutput().getText());
assertThat(result).isNotNull();
assertThat((String) result.get("R")).containsIgnoringCase("red");
assertThat((String) result.get("G")).containsIgnoringCase("green");
assertThat((String) result.get("B")).containsIgnoringCase("blue");
}
chatClient + outputConverter
@Test
public void responseEntityTest() {
ChatResponseMetadata metadata = ChatResponseMetadata.builder().keyValue("key1", "value1").build();
var chatResponse = new ChatResponse(List.of(new Generation(new AssistantMessage("""
{"name":"John", "age":30}
"""))), metadata);
given(this.chatModel.call(this.promptCaptor.capture())).willReturn(chatResponse);
ResponseEntity<ChatResponse, MyBean> responseEntity = ChatClient.builder(this.chatModel)
.build()
.prompt()
.user("Tell me about John")
.call()
.responseEntity(MyBean.class);
assertThat(responseEntity.getResponse()).isEqualTo(chatResponse);
assertThat(responseEntity.getResponse().getMetadata().get("key1").toString()).isEqualTo("value1");
assertThat(responseEntity.getEntity()).isEqualTo(new MyBean("John", 30));
Message userMessage = this.promptCaptor.getValue().getInstructions().get(0);
assertThat(userMessage.getMessageType()).isEqualTo(MessageType.USER);
assertThat(userMessage.getText()).contains("Tell me about John");
}
When a target type is passed to responseEntity, ChatClient uses a BeanOutputConverter internally to produce the entity.
Summary
Spring AI provides Structured Output Converters to turn LLM output into structured data. The main implementations at present are MapOutputConverter, ListOutputConverter, and BeanOutputConverter.