序
本文主要研究一下langchain4j-spring-boot-starter的AiServicesAutoConfig
LangChain4jAutoConfig
dev/langchain4j/spring/LangChain4jAutoConfig.java
/**
 * Entry-point auto-configuration of the starter. It declares no beans of its own and only
 * imports {@code AiServicesAutoConfig}, {@code RagAutoConfig} and {@code AiServiceScannerProcessor}.
 */
@AutoConfiguration
@Import({
AiServicesAutoConfig.class,
RagAutoConfig.class,
AiServiceScannerProcessor.class
})
public class LangChain4jAutoConfig {
}
LangChain4jAutoConfig自动import了AiServicesAutoConfig、RagAutoConfig、AiServiceScannerProcessor
AiServiceScannerProcessor
dev/langchain4j/service/spring/AiServiceScannerProcessor.java
/**
 * Registry post-processor that scans the application's base packages with a
 * {@link ClassPathAiServiceScanner} and afterwards removes the bean definitions of
 * {@link AiService}-annotated classes whose {@link Profile} does not match the environment.
 */
@Component
public class AiServiceScannerProcessor implements BeanDefinitionRegistryPostProcessor, EnvironmentAware {

    private Environment environment;

    /**
     * Scans all collected base packages for AI services and then drops the ones whose
     * profiles are not active.
     *
     * @param registry registry that receives the scanned bean definitions; it is also cast to
     *                 {@link ConfigurableListableBeanFactory} to look up the base packages
     * @throws BeansException declared by the overridden interface method
     */
    @Override
    public void postProcessBeanDefinitionRegistry(BeanDefinitionRegistry registry) throws BeansException {
        ClassPathAiServiceScanner scanner = new ClassPathAiServiceScanner(registry, false);
        Set<String> basePackages = getBasePackages((ConfigurableListableBeanFactory) registry);
        // The scanner asserts that at least one base package is given, so an empty set
        // must skip the scan instead of failing the whole context refresh.
        if (!basePackages.isEmpty()) {
            scanner.scan(StringUtils.toStringArray(basePackages));
        }
        removeAiServicesWithInactiveProfiles(registry);
    }

    /**
     * Collects the packages to scan: the auto-configuration packages (when registered) followed
     * by the packages declared through {@link ComponentScan}.
     *
     * @param beanFactory bean factory used to look up both sources
     * @return the base packages in discovery order, without duplicates; may be empty
     */
    private Set<String> getBasePackages(ConfigurableListableBeanFactory beanFactory) {
        Set<String> basePackages = new LinkedHashSet<>();

        // AutoConfiguration: get(..) throws when no auto-configuration packages have been
        // registered (e.g. a context without @EnableAutoConfiguration), so check first.
        if (AutoConfigurationPackages.has(beanFactory)) {
            basePackages.addAll(AutoConfigurationPackages.get(beanFactory));
        }

        // ComponentScan
        addComponentScanPackages(beanFactory, basePackages);

        return basePackages;
    }

    /**
     * Adds the packages of every {@link ComponentScan} found on beans of the factory.
     * For each annotation, {@code basePackages} entries have their placeholders resolved and are
     * tokenized, {@code basePackageClasses} contribute their package, and when neither yields
     * anything the package of the annotated bean class is used.
     *
     * @param beanFactory           bean factory whose beans are inspected
     * @param collectedBasePackages target set the discovered packages are added to
     */
    private void addComponentScanPackages(ConfigurableListableBeanFactory beanFactory, Set<String> collectedBasePackages) {
        for (String beanName : beanFactory.getBeanNamesForAnnotation(ComponentScan.class)) {
            Class<?> beanClass = beanFactory.getType(beanName);
            if (beanClass == null) {
                continue;
            }
            Set<ComponentScan> componentScans =
                    AnnotatedElementUtils.getMergedRepeatableAnnotations(beanClass, ComponentScan.class);
            for (ComponentScan componentScan : componentScans) {
                Set<String> basePackages = new LinkedHashSet<>();
                for (String pkg : componentScan.basePackages()) {
                    String[] tokenized = StringUtils.tokenizeToStringArray(
                            this.environment.resolvePlaceholders(pkg),
                            ConfigurableApplicationContext.CONFIG_LOCATION_DELIMITERS);
                    Collections.addAll(basePackages, tokenized);
                }
                for (Class<?> clazz : componentScan.basePackageClasses()) {
                    basePackages.add(ClassUtils.getPackageName(clazz));
                }
                if (basePackages.isEmpty()) {
                    basePackages.add(ClassUtils.getPackageName(beanClass));
                }
                collectedBasePackages.addAll(basePackages);
            }
        }
    }

    /**
     * Removes every bean definition whose class carries both {@link AiService} and
     * {@link Profile} when the environment does not match the declared profiles.
     *
     * @param registry registry to inspect and modify
     */
    private void removeAiServicesWithInactiveProfiles(BeanDefinitionRegistry registry) {
        for (String beanName : registry.getBeanDefinitionNames()) {
            if (isAiServiceWithInactiveProfile(registry.getBeanDefinition(beanName))) {
                registry.removeBeanDefinition(beanName);
            }
        }
    }

    /**
     * Tells whether the bean definition's class is an {@link AiService} restricted to profiles
     * that the environment does not match.
     *
     * @param beanDefinition definition to inspect
     * @return {@code true} only when the class is annotated with both {@link AiService} and
     *         {@link Profile} and the profiles do not match; {@code false} otherwise, including
     *         when the definition has no class name or the class cannot be loaded
     */
    private boolean isAiServiceWithInactiveProfile(BeanDefinition beanDefinition) {
        String beanClassName = beanDefinition.getBeanClassName();
        if (beanClassName == null) {
            return false;
        }
        Class<?> beanClass;
        try {
            // Load without initializing: only annotations are read here, so static initializers
            // of arbitrary bean classes must not be triggered by this check.
            beanClass = Class.forName(beanClassName, false, AiServiceScannerProcessor.class.getClassLoader());
        } catch (ClassNotFoundException ignored) {
            // Best effort: a definition whose class cannot be loaded here is left untouched.
            return false;
        }
        if (!beanClass.isAnnotationPresent(AiService.class)) {
            return false;
        }
        Profile profileAnnotation = beanClass.getAnnotation(Profile.class);
        return profileAnnotation != null && !environment.matchesProfiles(profileAnnotation.value());
    }

    /**
     * Stores the environment used for placeholder resolution and profile matching.
     *
     * @param environment the environment supplied by the container
     */
    @Override
    public void setEnvironment(Environment environment) {
        this.environment = environment;
    }
}
AiServiceScannerProcessor实现了BeanDefinitionRegistryPostProcessor接口,其postProcessBeanDefinitionRegistry通过ClassPathAiServiceScanner去扫描@AiService注解的类
ClassPathAiServiceScanner
dev/langchain4j/service/spring/ClassPathAiServiceScanner.java
/**
 * Class-path scanner that includes types annotated with {@link AiService} and accepts only
 * candidates that are independent interfaces.
 */
class ClassPathAiServiceScanner extends ClassPathBeanDefinitionScanner {

    /**
     * @param registry          registry handed to the parent scanner
     * @param useDefaultFilters forwarded unchanged to the parent scanner
     */
    ClassPathAiServiceScanner(BeanDefinitionRegistry registry, boolean useDefaultFilters) {
        super(registry, useDefaultFilters);
        AnnotationTypeFilter aiServiceFilter = new AnnotationTypeFilter(AiService.class);
        addIncludeFilter(aiServiceFilter);
    }

    /**
     * Accepts a candidate only when its metadata reports an interface that is independent.
     */
    @Override
    protected boolean isCandidateComponent(AnnotatedBeanDefinition beanDefinition) {
        AnnotationMetadata metadata = beanDefinition.getMetadata();
        if (!metadata.isInterface()) {
            return false;
        }
        return metadata.isIndependent();
    }
}
ClassPathAiServiceScanner继承了ClassPathBeanDefinitionScanner,其构造器新增了AiService类型的AnnotationTypeFilter,其isCandidateComponent要求被扫描到的类是接口,而且是独立的(top-level的class或者是静态内部class)
ClassPathBeanDefinitionScanner
org/springframework/context/annotation/ClassPathBeanDefinitionScanner.java
/**
 * Finds the candidate components of every given base package, applies scope, bean name and
 * common annotation processing to each candidate, and registers those that pass
 * {@code checkCandidate}.
 *
 * @param basePackages packages to scan; must not be empty
 * @return holders of the bean definitions that were registered, in registration order
 */
protected Set<BeanDefinitionHolder> doScan(String... basePackages) {
    Assert.notEmpty(basePackages, "At least one base package must be specified");
    Set<BeanDefinitionHolder> registered = new LinkedHashSet<>();
    for (String pkg : basePackages) {
        for (BeanDefinition definition : findCandidateComponents(pkg)) {
            ScopeMetadata scope = this.scopeMetadataResolver.resolveScopeMetadata(definition);
            definition.setScope(scope.getScopeName());
            String name = this.beanNameGenerator.generateBeanName(definition, this.registry);
            if (definition instanceof AbstractBeanDefinition abstractDefinition) {
                postProcessBeanDefinition(abstractDefinition, name);
            }
            if (definition instanceof AnnotatedBeanDefinition annotatedDefinition) {
                AnnotationConfigUtils.processCommonDefinitionAnnotations(annotatedDefinition);
            }
            // Candidates rejected by checkCandidate are neither returned nor registered.
            if (!checkCandidate(name, definition)) {
                continue;
            }
            BeanDefinitionHolder holder = AnnotationConfigUtils.applyScopedProxyMode(
                    scope, new BeanDefinitionHolder(definition, name), this.registry);
            registered.add(holder);
            registerBeanDefinition(holder, this.registry);
        }
    }
    return registered;
}
/**
 * Registers the given bean definition holder with the given registry by delegating to
 * {@code BeanDefinitionReaderUtils.registerBeanDefinition}.
 *
 * @param definitionHolder the bean definition, together with its name, to register
 * @param registry         the registry to register with
 */
protected void registerBeanDefinition(BeanDefinitionHolder definitionHolder, BeanDefinitionRegistry registry) {
BeanDefinitionReaderUtils.registerBeanDefinition(definitionHolder, registry);
}
扫描出来的是ScannedGenericBeanDefinition(既是AbstractBeanDefinition类型,也实现了AnnotatedBeanDefinition接口),因此会先执行postProcessBeanDefinition,再执行AnnotationConfigUtils.processCommonDefinitionAnnotations处理@Lazy、@Primary、@DependsOn、@Role、@Description等注解,最后通过checkCandidate判断是否要注册beanDefinition,是则调用registerBeanDefinition,通过BeanDefinitionReaderUtils.registerBeanDefinition(definitionHolder, registry)注册到registry
AiServicesAutoConfig
dev/langchain4j/service/spring/AiServicesAutoConfig.java
/**
 * Configuration that replaces the bean definition of every {@link AiService}-annotated bean
 * with an {@link AiServiceFactory} definition wired with references to the model, memory,
 * retrieval, moderation and tool beans found in the bean factory.
 */
public class AiServicesAutoConfig implements ApplicationEventPublisherAware {

    private static final Logger log = LoggerFactory.getLogger(AiServicesAutoConfig.class);

    private ApplicationEventPublisher eventPublisher;

    /**
     * Stores the publisher used to announce each registered AI service.
     *
     * @param eventPublisher the publisher supplied by the container
     */
    @Override
    public void setApplicationEventPublisher(ApplicationEventPublisher eventPublisher) {
        this.eventPublisher = eventPublisher;
    }

    /**
     * Creates the post-processor that performs the re-registration. For each AI service it:
     * builds an {@link AiServiceFactory} definition taking the service interface as constructor
     * argument, adds the bean references and the tool list according to the wiring mode,
     * removes the original definition, registers the factory definition under
     * {@code lowercaseFirstLetter(beanName)} and, when a publisher is available, publishes an
     * {@link AiServiceRegisteredEvent}.
     *
     * @return the bean factory post-processor
     */
    @Bean
    BeanFactoryPostProcessor aiServicesRegisteringBeanFactoryPostProcessor() {
        return beanFactory -> {

            // all components available in the application context
            String[] chatLanguageModels = beanFactory.getBeanNamesForType(ChatLanguageModel.class);
            String[] streamingChatLanguageModels = beanFactory.getBeanNamesForType(StreamingChatLanguageModel.class);
            String[] chatMemories = beanFactory.getBeanNamesForType(ChatMemory.class);
            String[] chatMemoryProviders = beanFactory.getBeanNamesForType(ChatMemoryProvider.class);
            String[] contentRetrievers = beanFactory.getBeanNamesForType(ContentRetriever.class);
            String[] retrievalAugmentors = beanFactory.getBeanNamesForType(RetrievalAugmentor.class);
            String[] moderationModels = beanFactory.getBeanNamesForType(ModerationModel.class);

            // beans declaring at least one @Tool method, and the specifications of those methods
            Set<String> toolBeanNames = new HashSet<>();
            List<ToolSpecification> toolSpecifications = new ArrayList<>();
            for (String beanName : beanFactory.getBeanDefinitionNames()) {
                try {
                    String beanClassName = beanFactory.getBeanDefinition(beanName).getBeanClassName();
                    if (beanClassName == null) {
                        continue;
                    }
                    // Load without initializing: only method annotations are inspected, so the
                    // static initializers of arbitrary bean classes must not run here.
                    Class<?> beanClass =
                            Class.forName(beanClassName, false, AiServicesAutoConfig.class.getClassLoader());
                    for (Method beanMethod : beanClass.getDeclaredMethods()) {
                        if (beanMethod.isAnnotationPresent(Tool.class)) {
                            toolBeanNames.add(beanName);
                            try {
                                toolSpecifications.add(ToolSpecifications.toolSpecificationFrom(beanMethod));
                            } catch (Exception e) {
                                log.warn("Cannot convert {}.{} method annotated with @Tool into ToolSpecification",
                                        beanClass.getName(), beanMethod.getName(), e);
                            }
                        }
                    }
                } catch (Exception e) {
                    // Best effort: a bean that cannot be inspected is skipped, but leave a trace
                    // so a missing tool can be diagnosed.
                    log.debug("Skipping bean '{}' while looking for @Tool methods", beanName, e);
                }
            }

            String[] aiServices = beanFactory.getBeanNamesForAnnotation(AiService.class);
            for (String aiService : aiServices) {
                Class<?> aiServiceClass = beanFactory.getType(aiService);

                GenericBeanDefinition aiServiceBeanDefinition = new GenericBeanDefinition();
                aiServiceBeanDefinition.setBeanClass(AiServiceFactory.class);
                aiServiceBeanDefinition.getConstructorArgumentValues().addGenericArgumentValue(aiServiceClass);
                MutablePropertyValues propertyValues = aiServiceBeanDefinition.getPropertyValues();

                AiService aiServiceAnnotation = aiServiceClass.getAnnotation(AiService.class);

                addBeanReference(
                        ChatLanguageModel.class,
                        aiServiceAnnotation,
                        aiServiceAnnotation.chatModel(),
                        chatLanguageModels,
                        "chatModel",
                        "chatLanguageModel",
                        propertyValues
                );
                addBeanReference(
                        StreamingChatLanguageModel.class,
                        aiServiceAnnotation,
                        aiServiceAnnotation.streamingChatModel(),
                        streamingChatLanguageModels,
                        "streamingChatModel",
                        "streamingChatLanguageModel",
                        propertyValues
                );
                addBeanReference(
                        ChatMemory.class,
                        aiServiceAnnotation,
                        aiServiceAnnotation.chatMemory(),
                        chatMemories,
                        "chatMemory",
                        "chatMemory",
                        propertyValues
                );
                addBeanReference(
                        ChatMemoryProvider.class,
                        aiServiceAnnotation,
                        aiServiceAnnotation.chatMemoryProvider(),
                        chatMemoryProviders,
                        "chatMemoryProvider",
                        "chatMemoryProvider",
                        propertyValues
                );
                addBeanReference(
                        ContentRetriever.class,
                        aiServiceAnnotation,
                        aiServiceAnnotation.contentRetriever(),
                        contentRetrievers,
                        "contentRetriever",
                        "contentRetriever",
                        propertyValues
                );
                addBeanReference(
                        RetrievalAugmentor.class,
                        aiServiceAnnotation,
                        aiServiceAnnotation.retrievalAugmentor(),
                        retrievalAugmentors,
                        "retrievalAugmentor",
                        "retrievalAugmentor",
                        propertyValues
                );
                addBeanReference(
                        ModerationModel.class,
                        aiServiceAnnotation,
                        aiServiceAnnotation.moderationModel(),
                        moderationModels,
                        "moderationModel",
                        "moderationModel",
                        propertyValues
                );

                // EXPLICIT wires only the tools named on the annotation; AUTOMATIC wires every
                // bean that declares a @Tool method.
                if (aiServiceAnnotation.wiringMode() == EXPLICIT) {
                    propertyValues.add("tools", toManagedList(asList(aiServiceAnnotation.tools())));
                } else if (aiServiceAnnotation.wiringMode() == AUTOMATIC) {
                    propertyValues.add("tools", toManagedList(toolBeanNames));
                } else {
                    throw illegalArgument("Unknown wiring mode: " + aiServiceAnnotation.wiringMode());
                }

                // Swap the original definition for the factory definition.
                BeanDefinitionRegistry registry = (BeanDefinitionRegistry) beanFactory;
                registry.removeBeanDefinition(aiService);
                registry.registerBeanDefinition(lowercaseFirstLetter(aiService), aiServiceBeanDefinition);

                if (eventPublisher != null) {
                    eventPublisher.publishEvent(new AiServiceRegisteredEvent(this, aiServiceClass, toolSpecifications));
                }
            }
        };
    }
    //......
}
AiServicesAutoConfig是被LangChain4jAutoConfig import的配置类,它定义了aiServicesRegisteringBeanFactoryPostProcessor,在postProcessBeanFactory的时候先获取ChatLanguageModel、StreamingChatLanguageModel、ChatMemory、ChatMemoryProvider、ContentRetriever、RetrievalAugmentor、ModerationModel类型的bean名称,再解析bean中标注@Tool的方法,然后遍历标注@AiService的bean,为每个bean构造AiServiceFactory类型的aiServiceBeanDefinition并添加beanReference,之后先removeBeanDefinition,再把填充好属性的aiServiceBeanDefinition(AiServiceFactory)重新注册上去,最后发布AiServiceRegisteredEvent事件。
AiServiceFactory
dev/langchain4j/service/spring/AiServiceFactory.java
/**
 * {@link FactoryBean} that builds the implementation of an AI service interface through
 * {@code AiServices.builder(..)}, using whichever collaborators have been set on it.
 */
class AiServiceFactory implements FactoryBean<Object> {

    private final Class<Object> aiServiceClass;
    private ChatLanguageModel chatLanguageModel;
    private StreamingChatLanguageModel streamingChatLanguageModel;
    private ChatMemory chatMemory;
    private ChatMemoryProvider chatMemoryProvider;
    private ContentRetriever contentRetriever;
    private RetrievalAugmentor retrievalAugmentor;
    private ModerationModel moderationModel;
    private List<Object> tools;

    /**
     * @param aiServiceClass the AI service interface to build an implementation for
     */
    public AiServiceFactory(Class<Object> aiServiceClass) {
        this.aiServiceClass = aiServiceClass;
    }

    /** Sets the optional chat model. */
    public void setChatLanguageModel(ChatLanguageModel chatLanguageModel) {
        this.chatLanguageModel = chatLanguageModel;
    }

    /** Sets the optional streaming chat model. */
    public void setStreamingChatLanguageModel(StreamingChatLanguageModel streamingChatLanguageModel) {
        this.streamingChatLanguageModel = streamingChatLanguageModel;
    }

    /** Sets the optional chat memory. */
    public void setChatMemory(ChatMemory chatMemory) {
        this.chatMemory = chatMemory;
    }

    /** Sets the optional chat memory provider. */
    public void setChatMemoryProvider(ChatMemoryProvider chatMemoryProvider) {
        this.chatMemoryProvider = chatMemoryProvider;
    }

    /** Sets the optional content retriever; it is only used when no retrieval augmentor is set. */
    public void setContentRetriever(ContentRetriever contentRetriever) {
        this.contentRetriever = contentRetriever;
    }

    /** Sets the optional retrieval augmentor; it takes precedence over the content retriever. */
    public void setRetrievalAugmentor(RetrievalAugmentor retrievalAugmentor) {
        this.retrievalAugmentor = retrievalAugmentor;
    }

    /** Sets the optional moderation model. */
    public void setModerationModel(ModerationModel moderationModel) {
        this.moderationModel = moderationModel;
    }

    /** Sets the optional tool objects. */
    public void setTools(List<Object> tools) {
        this.tools = tools;
    }

    /**
     * Builds the AI service. Only collaborators that are non-null are passed to the builder;
     * tools that are AOP proxies are passed through {@code aopEnhancedTools} first.
     *
     * @return the built AI service instance
     */
    @Override
    public Object getObject() {
        AiServices<Object> serviceBuilder = AiServices.builder(aiServiceClass);

        if (chatLanguageModel != null) {
            serviceBuilder = serviceBuilder.chatLanguageModel(chatLanguageModel);
        }
        if (streamingChatLanguageModel != null) {
            serviceBuilder = serviceBuilder.streamingChatLanguageModel(streamingChatLanguageModel);
        }
        if (chatMemory != null) {
            serviceBuilder.chatMemory(chatMemory);
        }
        if (chatMemoryProvider != null) {
            serviceBuilder.chatMemoryProvider(chatMemoryProvider);
        }

        // The retrieval augmentor wins over a plain content retriever.
        if (retrievalAugmentor != null) {
            serviceBuilder = serviceBuilder.retrievalAugmentor(retrievalAugmentor);
        } else if (contentRetriever != null) {
            serviceBuilder = serviceBuilder.contentRetriever(contentRetriever);
        }

        if (moderationModel != null) {
            serviceBuilder = serviceBuilder.moderationModel(moderationModel);
        }

        if (isNullOrEmpty(tools)) {
            return serviceBuilder.build();
        }
        for (Object tool : tools) {
            if (!isAopProxy(tool)) {
                serviceBuilder = serviceBuilder.tools(tool);
            } else {
                serviceBuilder = serviceBuilder.tools(aopEnhancedTools(tool));
            }
        }
        return serviceBuilder.build();
    }

    /**
     * @return the AI service interface this factory was created for
     */
    @Override
    public Class<?> getObjectType() {
        return aiServiceClass;
    }

    /**
     * @return always {@code true}
     */
    @Override
    public boolean isSingleton() {
        return true; // TODO
    }
    //......
}
AiServiceFactory实现了FactoryBean接口,其getObject主要通过AiServices.builder(aiServiceClass)来进行构建,默认实现类是DefaultAiServices,它通过Proxy.newProxyInstance来创建实现类,InvocationHandler的实现主要是处理systemMessage、userMessage、构建chatMemory、toolExecutionContext,最后构建ChatRequest,通过context.chatModel.chat(chatRequest)执行请求,然后解析和适配输出。
小结
langchain4j-spring-boot-starter的LangChain4jAutoConfig自动import了AiServicesAutoConfig、RagAutoConfig、AiServiceScannerProcessor,其中AiServiceScannerProcessor会扫描标注@AiService的接口并注册到BeanDefinitionRegistry,之后AiServicesAutoConfig会在postProcessBeanFactory的时候填充好相关属性,然后移除掉之前定义的BeanDefinition,把填充好属性的aiServiceBeanDefinition(AiServiceFactory)重新注册上去。
支持Markdown语法:**粗体** _斜体_ [链接](http://example.com) `代码` - 列表 > 引用。你还可以使用@来通知其他用户。