This post takes a look at Spring AI's StructuredOutputConverter.

StructuredOutputConverter

org/springframework/ai/converter/StructuredOutputConverter.java

public interface StructuredOutputConverter<T> extends Converter<String, T>, FormatProvider {

}
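The StructuredOutputConverter interface extends both Converter and FormatProvider. Two abstract classes implement it, AbstractMessageOutputConverter and AbstractConversionServiceOutputConverter, while BeanOutputConverter implements it directly.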
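
To make the two halves of that contract concrete, here is a minimal sketch in plain Java. It does not use Spring AI: OutputConverter, IntConverter, ask, and the fake model are all hypothetical stand-ins introduced for illustration. The point is only that getFormat feeds the prompt and convert parses the reply.

```java
import java.util.function.Function;

class StructuredOutputSketch {

    // Hypothetical stand-in for StructuredOutputConverter<T>: one method supplies
    // the format instructions, the other parses the model's raw text.
    interface OutputConverter<T> {
        String getFormat();
        T convert(String text);
    }

    // Toy converter that asks for a single integer and parses it.
    static final class IntConverter implements OutputConverter<Integer> {
        @Override
        public String getFormat() {
            return "Respond with a single integer and nothing else.";
        }

        @Override
        public Integer convert(String text) {
            return Integer.parseInt(text.trim());
        }
    }

    // Append the format instructions to the question, call the model,
    // then hand the raw reply to the converter.
    static <T> T ask(String question, OutputConverter<T> converter, Function<String, String> model) {
        String prompt = question + "\n" + converter.getFormat();
        String reply = model.apply(prompt);
        return converter.convert(reply);
    }

    public static void main(String[] args) {
        // Fake model: ignores the prompt and returns a canned reply.
        Function<String, String> fakeModel = prompt -> " 42\n";
        Integer answer = ask("How many items are there?", new IntConverter(), fakeModel);
        System.out.println(answer); // prints 42
    }
}
```

In the real library the same two calls appear in the test further down: getFormat is substituted into the prompt template, and convert is applied to the generation's text.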

Converter

org/springframework/core/convert/converter/Converter.java

@FunctionalInterface
public interface Converter<S, T> {

    /**
     * Convert the source object of type {@code S} to target type {@code T}.
     * @param source the source object to convert, which must be an instance of {@code S} (never {@code null})
     * @return the converted object, which must be an instance of {@code T} (potentially {@code null})
     * @throws IllegalArgumentException if the source cannot be converted to the desired target type
     */
    @Nullable
    T convert(S source);

    /**
     * Construct a composed {@link Converter} that first applies this {@link Converter}
     * to its input, and then applies the {@code after} {@link Converter} to the
     * result.
     * @param after the {@link Converter} to apply after this {@link Converter}
     * is applied
     * @param <U> the type of output of both the {@code after} {@link Converter}
     * and the composed {@link Converter}
     * @return a composed {@link Converter} that first applies this {@link Converter}
     * and then applies the {@code after} {@link Converter}
     * @since 5.3
     */
    default <U> Converter<S, U> andThen(Converter<? super T, ? extends U> after) {
        Assert.notNull(after, "'after' Converter must not be null");
        return (S s) -> {
            T initialResult = convert(s);
            return (initialResult != null ? after.convert(initialResult) : null);
        };
    }

}
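The Converter interface declares the convert method and provides a default implementation of andThen, which chains a second converter onto the result and returns null without calling it when the first conversion yields null.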
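
The null short-circuit in andThen is easy to miss, so here is a small sketch of it. The nested Converter interface is a local copy of the shape shown above, declared only so the example runs without spring-core on the classpath (Objects.requireNonNull replaces Spring's Assert); parseThenLabel is a hypothetical helper.

```java
import java.util.Objects;

class AndThenSketch {

    // Local stand-in mirroring the shape of Spring's Converter<S, T>.
    @FunctionalInterface
    interface Converter<S, T> {
        T convert(S source);

        default <U> Converter<S, U> andThen(Converter<? super T, ? extends U> after) {
            Objects.requireNonNull(after, "'after' Converter must not be null");
            return (S s) -> {
                T initialResult = convert(s);
                // Skip the second converter when the first one produced null.
                return (initialResult != null ? after.convert(initialResult) : null);
            };
        }
    }

    // Compose "parse an integer" with "format a label".
    static Converter<String, String> parseThenLabel() {
        Converter<String, Integer> parse = s -> s.trim().isEmpty() ? null : Integer.valueOf(s.trim());
        Converter<Integer, String> label = n -> "value=" + n;
        return parse.andThen(label);
    }

    public static void main(String[] args) {
        Converter<String, String> composed = parseThenLabel();
        System.out.println(composed.convert("42"));  // prints value=42
        System.out.println(composed.convert("   ")); // prints null
    }
}
```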

FormatProvider

org/springframework/ai/converter/FormatProvider.java

public interface FormatProvider {

    /**
     * Get the format of the output of a language generative.
     * @return Returns a string containing instructions for how the output of a language
     * generative should be formatted.
     */
    String getFormat();

}
FormatProvider declares a single method, getFormat, which returns the formatting instructions to include in the prompt.

AbstractMessageOutputConverter

org/springframework/ai/converter/AbstractMessageOutputConverter.java

public abstract class AbstractMessageOutputConverter<T> implements StructuredOutputConverter<T> {

    private MessageConverter messageConverter;

    /**
     * Create a new AbstractMessageOutputConverter.
     * @param messageConverter the message converter to use
     */
    public AbstractMessageOutputConverter(MessageConverter messageConverter) {
        this.messageConverter = messageConverter;
    }

    /**
     * Return the message converter used by this output converter.
     * @return the message converter
     */
    public MessageConverter getMessageConverter() {
        return this.messageConverter;
    }

}
AbstractMessageOutputConverter holds a MessageConverter field; its concrete subclass is MapOutputConverter.

MapOutputConverter

org/springframework/ai/converter/MapOutputConverter.java

public class MapOutputConverter extends AbstractMessageOutputConverter<Map<String, Object>> {

    public MapOutputConverter() {
        super(new MappingJackson2MessageConverter());
    }

    @Override
    public Map<String, Object> convert(@NonNull String text) {
        if (text.startsWith("```json") && text.endsWith("```")) {
            text = text.substring(7, text.length() - 3);
        }

        Message<?> message = MessageBuilder.withPayload(text.getBytes(StandardCharsets.UTF_8)).build();
        return (Map) this.getMessageConverter().fromMessage(message, HashMap.class);
    }

    @Override
    public String getFormat() {
        String raw = """
                Your response should be in JSON format.
                The data structure for the JSON should match this Java class: %s
                Do not include any explanations, only provide a RFC8259 compliant JSON response following this format without deviation.
                Remove the ```json markdown surrounding the output including the trailing "```".
                """;
        return String.format(raw, HashMap.class.getName());
    }

}
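MapOutputConverter extends AbstractMessageOutputConverter and uses MappingJackson2MessageConverter as its MessageConverter. Its convert method strips a surrounding Markdown json code fence if present, wraps the UTF-8 bytes of the text in a Message, and deserializes the payload into a HashMap.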
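
The fence check in convert is exact: it does not trim the text first. The sketch below copies only that if-block into a hypothetical helper, stripJsonFence, to show which inputs it does and does not unwrap. It leaves out the Jackson step entirely.

```java
class MapFenceSketch {

    // Same check as in MapOutputConverter.convert: strip only when the text
    // starts with ```json and ends with ``` exactly.
    static String stripJsonFence(String text) {
        if (text.startsWith("```json") && text.endsWith("```")) {
            text = text.substring(7, text.length() - 3);
        }
        return text;
    }

    public static void main(String[] args) {
        String fenced = "```json\n{\"R\":\"Red\"}\n```";
        // The fence is removed; the surrounding newlines remain until trimmed.
        System.out.println(stripJsonFence(fenced).trim()); // prints {"R":"Red"}

        // A trailing newline after the closing fence defeats endsWith("```"),
        // so the text comes back unchanged, fence included.
        String withNewline = fenced + "\n";
        System.out.println(stripJsonFence(withNewline).equals(withNewline)); // prints true
    }
}
```

Under this reading, a reply with leading or trailing whitespace around the fence would reach the JSON parser with the fence still attached, which is one reason getFormat asks the model to omit the fence altogether.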

AbstractConversionServiceOutputConverter

org/springframework/ai/converter/AbstractConversionServiceOutputConverter.java

public abstract class AbstractConversionServiceOutputConverter<T> implements StructuredOutputConverter<T> {

    private final DefaultConversionService conversionService;

    /**
     * Create a new {@link AbstractConversionServiceOutputConverter} instance.
     * @param conversionService the {@link DefaultConversionService} to use for converting
     * the output.
     */
    public AbstractConversionServiceOutputConverter(DefaultConversionService conversionService) {
        this.conversionService = conversionService;
    }

    /**
     * Return the ConversionService used by this converter.
     * @return the ConversionService used by this converter.
     */
    public DefaultConversionService getConversionService() {
        return this.conversionService;
    }

}
AbstractConversionServiceOutputConverter holds a DefaultConversionService field; its concrete subclass is ListOutputConverter.

ListOutputConverter

org/springframework/ai/converter/ListOutputConverter.java

public class ListOutputConverter extends AbstractConversionServiceOutputConverter<List<String>> {

    public ListOutputConverter(DefaultConversionService defaultConversionService) {
        super(defaultConversionService);
    }

    @Override
    public String getFormat() {
        return """
                Respond with only a list of comma-separated values, without any leading or trailing text.
                Example format: foo, bar, baz
                """;
    }

    @Override
    public List<String> convert(@NonNull String text) {
        return this.getConversionService().convert(text, List.class);
    }

}
ListOutputConverter extends AbstractConversionServiceOutputConverter; its convert method delegates to the conversion service to turn the text into a List<String>.
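
For a reply in the "foo, bar, baz" shape that getFormat requests, the conversion amounts to splitting on commas and trimming each element. The following is only an approximation in plain Java, not the DefaultConversionService code path; toList is a hypothetical helper, and edge cases such as empty input or trailing commas may behave differently in Spring.

```java
import java.util.Arrays;
import java.util.List;
import java.util.stream.Collectors;

class CommaListSketch {

    // Approximation: split on commas, then trim whitespace around each element.
    static List<String> toList(String text) {
        return Arrays.stream(text.split(","))
            .map(String::trim)
            .collect(Collectors.toList());
    }

    public static void main(String[] args) {
        // A trailing newline from the model is absorbed by the per-element trim.
        System.out.println(toList("foo, bar, baz\n")); // prints [foo, bar, baz]
    }
}
```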

BeanOutputConverter

org/springframework/ai/converter/BeanOutputConverter.java

public class BeanOutputConverter<T> implements StructuredOutputConverter<T> {

    private final Logger logger = LoggerFactory.getLogger(BeanOutputConverter.class);

    /**
     * The target class type reference to which the output will be converted.
     */
    private final Type type;

    /** The object mapper used for deserialization and other JSON operations. */
    private final ObjectMapper objectMapper;

    /** Holds the generated JSON schema for the target type. */
    private String jsonSchema;

    /**
     * Constructor to initialize with the target type's class.
     * @param clazz The target type's class.
     */
    public BeanOutputConverter(Class<T> clazz) {
        this(ParameterizedTypeReference.forType(clazz));
    }

    /**
     * Constructor to initialize with the target type's class, a custom object mapper, and
     * a line endings normalizer to ensure consistent line endings on any platform.
     * @param clazz The target type's class.
     * @param objectMapper Custom object mapper for JSON operations. endings.
     */
    public BeanOutputConverter(Class<T> clazz, ObjectMapper objectMapper) {
        this(ParameterizedTypeReference.forType(clazz), objectMapper);
    }

    /**
     * Constructor to initialize with the target class type reference.
     * @param typeRef The target class type reference.
     */
    public BeanOutputConverter(ParameterizedTypeReference<T> typeRef) {
        this(typeRef.getType(), null);
    }

    /**
     * Constructor to initialize with the target class type reference, a custom object
     * mapper, and a line endings normalizer to ensure consistent line endings on any
     * platform.
     * @param typeRef The target class type reference.
     * @param objectMapper Custom object mapper for JSON operations. endings.
     */
    public BeanOutputConverter(ParameterizedTypeReference<T> typeRef, ObjectMapper objectMapper) {
        this(typeRef.getType(), objectMapper);
    }

    /**
     * Constructor to initialize with the target class type reference, a custom object
     * mapper, and a line endings normalizer to ensure consistent line endings on any
     * platform.
     * @param type The target class type.
     * @param objectMapper Custom object mapper for JSON operations. endings.
     */
    private BeanOutputConverter(Type type, ObjectMapper objectMapper) {
        Objects.requireNonNull(type, "Type cannot be null;");
        this.type = type;
        this.objectMapper = objectMapper != null ? objectMapper : getObjectMapper();
        generateSchema();
    }

    /**
     * Generates the JSON schema for the target type.
     */
    private void generateSchema() {
        JacksonModule jacksonModule = new JacksonModule(JacksonOption.RESPECT_JSONPROPERTY_REQUIRED,
                JacksonOption.RESPECT_JSONPROPERTY_ORDER);
        SchemaGeneratorConfigBuilder configBuilder = new SchemaGeneratorConfigBuilder(
                com.github.victools.jsonschema.generator.SchemaVersion.DRAFT_2020_12,
                com.github.victools.jsonschema.generator.OptionPreset.PLAIN_JSON)
            .with(jacksonModule)
            .with(Option.FORBIDDEN_ADDITIONAL_PROPERTIES_BY_DEFAULT);
        SchemaGeneratorConfig config = configBuilder.build();
        SchemaGenerator generator = new SchemaGenerator(config);
        JsonNode jsonNode = generator.generateSchema(this.type);
        ObjectWriter objectWriter = this.objectMapper.writer(new DefaultPrettyPrinter()
            .withObjectIndenter(new DefaultIndenter().withLinefeed(System.lineSeparator())));
        try {
            this.jsonSchema = objectWriter.writeValueAsString(jsonNode);
        }
        catch (JsonProcessingException e) {
            logger.error("Could not pretty print json schema for jsonNode: {}", jsonNode);
            throw new RuntimeException("Could not pretty print json schema for " + this.type, e);
        }
    }

    /**
     * Parses the given text to transform it to the desired target type.
     * @param text The LLM output in string format.
     * @return The parsed output in the desired target type.
     */
    @SuppressWarnings("unchecked")
    @Override
    public T convert(@NonNull String text) {
        try {
            // Remove leading and trailing whitespace
            text = text.trim();

            // Check for and remove triple backticks and "json" identifier
            if (text.startsWith("```") && text.endsWith("```")) {
                // Remove the first line if it contains "```json"
                String[] lines = text.split("\n", 2);
                if (lines[0].trim().equalsIgnoreCase("```json")) {
                    text = lines.length > 1 ? lines[1] : "";
                }
                else {
                    text = text.substring(3); // Remove leading ```
                }

                // Remove trailing ```
                text = text.substring(0, text.length() - 3);

                // Trim again to remove any potential whitespace
                text = text.trim();
            }
            return (T) this.objectMapper.readValue(text, this.objectMapper.constructType(this.type));
        }
        catch (JsonProcessingException e) {
            logger.error(SENSITIVE_DATA_MARKER,
                    "Could not parse the given text to the desired target type: \"{}\" into {}", text, this.type);
            throw new RuntimeException(e);
        }
    }

    /**
     * Configures and returns an object mapper for JSON operations.
     * @return Configured object mapper.
     */
    protected ObjectMapper getObjectMapper() {
        return JsonMapper.builder()
            .addModules(JacksonUtils.instantiateAvailableModules())
            .configure(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES, false)
            .build();
    }

    /**
     * Provides the expected format of the response, instructing that it should adhere to
     * the generated JSON schema.
     * @return The instruction format string.
     */
    @Override
    public String getFormat() {
        String template = """
                Your response should be in JSON format.
                Do not include any explanations, only provide a RFC8259 compliant JSON response following this format without deviation.
                Do not include markdown code blocks in your response.
                Remove the ```json markdown from the output.
                Here is the JSON Schema instance your output must adhere to:
                ```%s```
                """;
        return String.format(template, this.jsonSchema);
    }

    /**
     * Provides the generated JSON schema for the target type.
     * @return The generated JSON schema.
     */
    public String getJsonSchema() {
        return this.jsonSchema;
    }

    public Map<String, Object> getJsonSchemaMap() {
        try {
            return this.objectMapper.readValue(this.jsonSchema, Map.class);
        }
        catch (JsonProcessingException ex) {
            logger.error("Could not parse the JSON Schema to a Map object", ex);
            throw new IllegalStateException(ex);
        }
    }

}
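BeanOutputConverter uses an ObjectMapper to deserialize the JSON into a bean. Its constructor also generates a JSON Schema for the target type, which getFormat embeds in the prompt instructions, and convert trims the text and strips any surrounding Markdown code fence before parsing.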
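
The pre-processing in convert is more forgiving than the MapOutputConverter check because it trims first and accepts a fence with or without the json tag. The sketch below lifts just that portion into a hypothetical stripFence helper so the behavior can be tried without Jackson; the logic follows the listing above line for line.

```java
class BeanFenceSketch {

    // Mirrors the text clean-up at the start of BeanOutputConverter.convert.
    static String stripFence(String text) {
        text = text.trim();
        if (text.startsWith("```") && text.endsWith("```")) {
            // Drop the whole first line when it is exactly ```json (any case).
            String[] lines = text.split("\n", 2);
            if (lines[0].trim().equalsIgnoreCase("```json")) {
                text = lines.length > 1 ? lines[1] : "";
            }
            else {
                text = text.substring(3); // remove leading ```
            }
            text = text.substring(0, text.length() - 3); // remove trailing ```
            text = text.trim();
        }
        return text;
    }

    public static void main(String[] args) {
        System.out.println(stripFence("  ```json\n{\"name\":\"John\",\"age\":30}\n```  \n"));
        // prints {"name":"John","age":30}
        System.out.println(stripFence("```\n{\"a\":1}\n```")); // prints {"a":1}
        System.out.println(stripFence("{\"a\":1}"));           // prints {"a":1}
    }
}
```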

Examples

chatModel + outputConverter

    @Test
    void mapOutputConvert() {
        MapOutputConverter outputConverter = new MapOutputConverter();

        String format = outputConverter.getFormat();
        String template = """
                For each letter in the RGB color scheme, tell me what it stands for.
                Example: R -> Red.
                {format}
                """;
        PromptTemplate promptTemplate = new PromptTemplate(template, Map.of("format", format));
        Prompt prompt = new Prompt(promptTemplate.createMessage());

        Generation generation = this.chatModel.call(prompt).getResult();

        Map<String, Object> result = outputConverter.convert(generation.getOutput().getText());
        assertThat(result).isNotNull();
        assertThat((String) result.get("R")).containsIgnoringCase("red");
        assertThat((String) result.get("G")).containsIgnoringCase("green");
        assertThat((String) result.get("B")).containsIgnoringCase("blue");
    }

chatClient + outputConverter

    @Test
    public void responseEntityTest() {

        ChatResponseMetadata metadata = ChatResponseMetadata.builder().keyValue("key1", "value1").build();

        var chatResponse = new ChatResponse(List.of(new Generation(new AssistantMessage("""
                {"name":"John", "age":30}
                """))), metadata);

        given(this.chatModel.call(this.promptCaptor.capture())).willReturn(chatResponse);

        ResponseEntity<ChatResponse, MyBean> responseEntity = ChatClient.builder(this.chatModel)
            .build()
            .prompt()
            .user("Tell me about John")
            .call()
            .responseEntity(MyBean.class);

        assertThat(responseEntity.getResponse()).isEqualTo(chatResponse);
        assertThat(responseEntity.getResponse().getMetadata().get("key1").toString()).isEqualTo("value1");

        assertThat(responseEntity.getEntity()).isEqualTo(new MyBean("John", 30));

        Message userMessage = this.promptCaptor.getValue().getInstructions().get(0);
        assertThat(userMessage.getMessageType()).isEqualTo(MessageType.USER);
        assertThat(userMessage.getText()).contains("Tell me about John");
    }
Calling responseEntity with a target type on ChatClient uses a BeanOutputConverter internally.

Summary

Spring AI provides Structured Output Converters to turn LLM output into structured formats. The main implementations at present are MapOutputConverter, ListOutputConverter, and BeanOutputConverter.
