前言
去年更新了一系列和SPI相关的内容。最近因为业务需要,我又基于实际业务场景实现了一版。关于什么是SPI,我很久之前写过一篇文章《java之spi机制简介》,感兴趣的朋友可以看一下。
需求分析
用过原生JDK所提供SPI(ServiceLoader)的朋友应该知道,原生SPI有个缺陷:没法按名称按需加载。想拿到某个具体实现时,通常只能遍历并实例化所声明的实现类,直到找到需要的那一个为止。本文的实现就是为了解决这个问题。
自定义SPI
核心代码
/**
 * A name-based SPI loader: implementations are looked up by a key and instantiated only
 * when requested, instead of being iterated one by one.
 *
 * <p>Implementations of an extension interface {@code X} are declared in classpath resources
 * named {@code META-INF/spi/<fully-qualified name of X>}, in {@link Properties} format
 * ({@code name=implementation.class.Name}). All matching resources are merged; mapping the
 * same name to two different classes is an error. Classes are loaded on first use and
 * instances are cached, so each implementation class is instantiated at most once per loader.
 *
 * @param <T> the extension interface type
 */
@Slf4j
public final class SpiLoader<T> {

    private static final String SPI_DIRECTORY = "META-INF/spi/";

    /** One loader per extension interface, shared by all callers. */
    private static final Map<Class<?>, SpiLoader<?>> LOADERS = new ConcurrentHashMap<>();

    private final Class<T> clazz;

    private final ClassLoader classLoader;

    /** Lazily loaded "name -> implementation class" table. */
    private final Holder<Map<String, ClassEntity>> cachedClasses = new Holder<>();

    /** Per-name instance holders; only registered names ever get an entry. */
    private final Map<String, Holder<Object>> cachedInstances = new ConcurrentHashMap<>();

    /** Per-class instances, so several names mapped to one class share one instance. */
    private final Map<Class<?>, Object> targetInstances = new ConcurrentHashMap<>();

    /**
     * Instantiates a new extension loader.
     *
     * <p>Every loader other than the {@code SpiFactory} one first triggers loading of the
     * {@code SpiFactory} extension classes. NOTE(review): the reason for this eager load is not
     * visible here; presumably factory extensions must be resolved first — confirm.
     *
     * @param clazz the extension interface
     * @param cl    the class loader used to find resources and load classes; may be null, in
     *              which case system resources are scanned and classes are loaded through
     *              {@link Class#forName(String)}
     */
    private SpiLoader(final Class<T> clazz, final ClassLoader cl) {
        this.clazz = clazz;
        this.classLoader = cl;
        if (!Objects.equals(clazz, SpiFactory.class)) {
            SpiLoader.getExtensionLoader(SpiFactory.class).getExtensionClassesEntity();
        }
    }

    /**
     * Returns the shared loader for {@code clazz}, creating it on the first request.
     *
     * <p>{@code cl} only matters when the loader is first created; later calls return the
     * cached loader regardless of the class loader passed in.
     *
     * @param <T>   the extension interface type
     * @param clazz the extension interface
     * @param cl    the class loader to use if the loader has to be created
     * @return the extension loader
     * @throws NullPointerException     if {@code clazz} is null
     * @throws IllegalArgumentException if {@code clazz} is not an interface
     */
    @SuppressWarnings("unchecked")
    public static <T> SpiLoader<T> getExtensionLoader(final Class<T> clazz, final ClassLoader cl) {
        Objects.requireNonNull(clazz, "extension clazz is null");
        if (!clazz.isInterface()) {
            throw new IllegalArgumentException("extension clazz (" + clazz + ") is not interface!");
        }
        SpiLoader<T> extensionLoader = (SpiLoader<T>) LOADERS.get(clazz);
        if (Objects.nonNull(extensionLoader)) {
            return extensionLoader;
        }
        // putIfAbsent rather than computeIfAbsent: the constructor itself calls
        // getExtensionLoader(SpiFactory.class), and ConcurrentHashMap forbids modifying the
        // map from inside a computeIfAbsent mapping function.
        LOADERS.putIfAbsent(clazz, new SpiLoader<>(clazz, cl));
        return (SpiLoader<T>) LOADERS.get(clazz);
    }

    /**
     * Returns the shared loader for {@code clazz}, using the class loader that loaded
     * {@code SpiLoader} if the loader has to be created.
     *
     * @param <T>   the extension interface type
     * @param clazz the extension interface
     * @return the extension loader
     */
    public static <T> SpiLoader<T> getExtensionLoader(final Class<T> clazz) {
        return getExtensionLoader(clazz, SpiLoader.class.getClassLoader());
    }

    /**
     * Returns the implementation registered under {@code name}, instantiating it on first use.
     *
     * @param name the registered extension name
     * @return the (cached) implementation instance
     * @throws NullPointerException     if {@code name} is null or blank
     * @throws IllegalArgumentException if no implementation is registered under {@code name}
     * @throws IllegalStateException    if the implementation cannot be instantiated
     */
    @SuppressWarnings("unchecked")
    public T getTarget(final String name) {
        if (StringUtils.isBlank(name)) {
            throw new NullPointerException("get target name is null");
        }
        // Reject unknown names before caching a holder for them; otherwise every bad lookup
        // would leave an empty holder behind in cachedInstances.
        if (!getExtensionClassesEntity().containsKey(name)) {
            throw new IllegalArgumentException(name + " name is error");
        }
        Holder<Object> objectHolder = cachedInstances.computeIfAbsent(name, key -> new Holder<>());
        Object value = objectHolder.getValue();
        if (Objects.isNull(value)) {
            // Holder.value is volatile, so the unlocked read above is safe; re-check under lock.
            synchronized (cachedInstances) {
                value = objectHolder.getValue();
                if (Objects.isNull(value)) {
                    createExtension(name, objectHolder);
                    value = objectHolder.getValue();
                }
            }
        }
        return (T) value;
    }

    /**
     * Returns an instance of every registered implementation, creating any not yet created.
     *
     * @return a new list of instances in no particular order, or an immutable empty list when
     *         nothing is registered
     */
    public List<T> getTargets() {
        Map<String, ClassEntity> extensionClassesEntity = this.getExtensionClassesEntity();
        if (extensionClassesEntity.isEmpty()) {
            return Collections.emptyList();
        }
        // Always resolve through getTarget (which is cached). A size-based shortcut over
        // cachedInstances could return holders that are not yet populated (null elements).
        List<T> targets = new ArrayList<>(extensionClassesEntity.size());
        for (ClassEntity classEntity : extensionClassesEntity.values()) {
            targets.add(this.getTarget(classEntity.getName()));
        }
        return targets;
    }

    /**
     * Instantiates (or reuses) the implementation class registered under {@code name} and
     * stores the instance in {@code holder}. Called only while holding the
     * {@code cachedInstances} lock.
     */
    private void createExtension(final String name, final Holder<Object> holder) {
        ClassEntity classEntity = getExtensionClassesEntity().get(name);
        if (Objects.isNull(classEntity)) {
            throw new IllegalArgumentException(name + " name is error");
        }
        Class<?> aClass = classEntity.getClazz();
        Object o = targetInstances.get(aClass);
        if (Objects.isNull(o)) {
            try {
                // Constructor#newInstance replaces the deprecated Class#newInstance, which could
                // propagate checked constructor exceptions without declaring them.
                targetInstances.putIfAbsent(aClass, aClass.getDeclaredConstructor().newInstance());
                o = targetInstances.get(aClass);
            } catch (ReflectiveOperationException e) {
                throw new IllegalStateException("Extension instance(name: " + name + ", class: "
                        + aClass + ") could not be instantiated: " + e.getMessage(), e);
            }
        }
        holder.setValue(o);
    }

    /**
     * Returns the registered implementation classes keyed by extension name.
     *
     * @return a new "name -> implementation class" map
     */
    public Map<String, Class<?>> getTargetClassesMap() {
        Map<String, ClassEntity> classes = this.getExtensionClassesEntity();
        return classes.values().stream()
                .collect(Collectors.toMap(ClassEntity::getName, ClassEntity::getClazz, (a, b) -> a));
    }

    /** Returns the class table, loading it once (double-checked on a volatile holder). */
    private Map<String, ClassEntity> getExtensionClassesEntity() {
        Map<String, ClassEntity> classes = cachedClasses.getValue();
        if (Objects.isNull(classes)) {
            synchronized (cachedClasses) {
                classes = cachedClasses.getValue();
                if (Objects.isNull(classes)) {
                    classes = loadExtensionClass();
                    cachedClasses.setValue(classes);
                }
            }
        }
        return classes;
    }

    /** Builds the class table from all SPI resources for {@link #clazz}. */
    private Map<String, ClassEntity> loadExtensionClass() {
        Map<String, ClassEntity> classes = new HashMap<>(16);
        loadDirectory(classes);
        return classes;
    }

    /**
     * Load files under SPI_DIRECTORY: every resource named
     * {@code META-INF/spi/<interface name>} visible to the class loader. An I/O failure while
     * enumerating the resources is logged, not thrown.
     */
    private void loadDirectory(final Map<String, ClassEntity> classes) {
        String fileName = SPI_DIRECTORY + clazz.getName();
        try {
            Enumeration<URL> urls = Objects.nonNull(this.classLoader) ? classLoader.getResources(fileName)
                    : ClassLoader.getSystemResources(fileName);
            if (Objects.nonNull(urls)) {
                while (urls.hasMoreElements()) {
                    URL url = urls.nextElement();
                    loadResources(classes, url);
                }
            }
        } catch (IOException t) {
            log.error("load extension class error {}", fileName, t);
        }
    }

    /**
     * Parses one SPI resource as {@link Properties} and registers each non-blank
     * {@code name=className} entry.
     *
     * @throws IllegalStateException if the resource cannot be read or a class cannot be found
     */
    private void loadResources(final Map<String, ClassEntity> classes, final URL url) throws IOException {
        try (InputStream inputStream = url.openStream()) {
            Properties properties = new Properties();
            properties.load(inputStream);
            properties.forEach((k, v) -> {
                String name = (String) k;
                String classPath = (String) v;
                if (StringUtils.isNotBlank(name) && StringUtils.isNotBlank(classPath)) {
                    try {
                        loadClass(classes, name, classPath);
                    } catch (ClassNotFoundException e) {
                        throw new IllegalStateException("load extension resources error", e);
                    }
                }
            });
        } catch (IOException e) {
            throw new IllegalStateException("load extension resources error", e);
        }
    }

    /**
     * Loads (and initializes) {@code classPath}, checks that it implements {@link #clazz} and
     * registers it under {@code name}. Re-registering the same class is a no-op; mapping the
     * name to a different class is an error.
     */
    private void loadClass(final Map<String, ClassEntity> classes,
                           final String name, final String classPath) throws ClassNotFoundException {
        Class<?> subClass = Objects.nonNull(this.classLoader)
                ? Class.forName(classPath, true, this.classLoader) : Class.forName(classPath);
        if (!clazz.isAssignableFrom(subClass)) {
            throw new IllegalStateException("load extension resources error," + subClass + " subtype is not of " + clazz);
        }
        ClassEntity oldClassEntity = classes.get(name);
        if (Objects.isNull(oldClassEntity)) {
            ClassEntity classEntity = new ClassEntity(name, subClass);
            classes.put(name, classEntity);
        } else if (!Objects.equals(oldClassEntity.getClazz(), subClass)) {
            throw new IllegalStateException("load extension resources error,Duplicate class " + clazz.getName() + " name "
                    + name + " on " + oldClassEntity.getClazz().getName() + " or " + subClass.getName());
        }
    }

    /**
     * A mutable cell whose value is {@code volatile}, so a value published under a lock is
     * visible to later unlocked reads.
     *
     * @param <T> the type parameter.
     */
    private static final class Holder<T> {

        private volatile T value;

        /**
         * Gets value.
         *
         * @return the value, or null if not set yet
         */
        public T getValue() {
            return value;
        }

        /**
         * Sets value.
         *
         * @param value the value
         */
        public void setValue(final T value) {
            this.value = value;
        }
    }

    /** An extension name paired with its implementation class. */
    private static final class ClassEntity {

        /**
         * The extension name (the key in the SPI file).
         */
        private final String name;

        /**
         * The implementation class.
         */
        private Class<?> clazz;

        private ClassEntity(final String name, final Class<?> clazz) {
            this.name = name;
            this.clazz = clazz;
        }

        /**
         * get class.
         *
         * @return class.
         */
        public Class<?> getClazz() {
            return clazz;
        }

        /**
         * set class.
         *
         * @param clazz class.
         */
        public void setClazz(final Class<?> clazz) {
            this.clazz = clazz;
        }

        /**
         * get name.
         *
         * @return name.
         */
        public String getName() {
            return name;
        }
    }
}
代码解读:
从classpath下查找META-INF/spi/目录中以接口全限定名命名的文件(可能存在多个)。然后解析文件中"名称=实现类全限定名"形式的配置,并将名称与对应的class放入本地缓存。最后根据业务实际需要,按名称将class实例化为对象。实例同样会被缓存,因此在同一个SpiLoader内,同一个实现类只会被实例化一次。
示例
以模拟不同日志门面的打印为例
1、创建日志接口
/**
 * Logging facade that serves as the SPI extension interface in this example.
 */
public interface LogService {

    /**
     * Logs a message at info level.
     *
     * @param msg the message to log
     */
    void info(final String msg);
}
2、创建日志实现
/**
 * Demo {@link LogService} implementation registered under the name {@code log4j};
 * it just writes a prefixed line to standard output.
 */
public class Log4jService implements LogService {

    /** Fully-qualified class name followed by the level marker. */
    private static final String PREFIX = Log4jService.class.getName() + " info: ";

    @Override
    public void info(final String msg) {
        System.out.println(PREFIX + msg);
    }
}
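3、在实现类所在模块的classpath根目录下创建文件META-INF/spi/com.github.lybgeek.log.LogService(文件名即接口的全限定名)。文件中每行按"名称=实现类全限定名"的格式填写。注意:下文测试中用到的log4j2也需要按同样方式登记,例如log4j2=com.github.lybgeek.log.Log4j2Service。本例填入如下内容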
log4j=com.github.lybgeek.log.Log4jService
4、测试(其中的SpiFactoriesLoader并未在本文中给出,可参考文末的demo链接)
/**
 * Demo entry point: resolves two {@link LogService} implementations by name and logs
 * one line with each.
 *
 * <p>NOTE(review): {@code "log4j2"} must be registered in
 * {@code META-INF/spi/com.github.lybgeek.log.LogService}, and {@code SpiFactoriesLoader}
 * is not defined in this article — presumably both come from the demo project.
 */
public class LogMainTest {

    public static void main(String[] args) {
        LogService log4j2Service = SpiLoader.getExtensionLoader(LogService.class).getTarget("log4j2");
        log4j2Service.info("log4j2-hello");

        LogService log4jService = SpiFactoriesLoader.loadFactories().getTarget("log4j", LogService.class);
        log4jService.info("log4j-hello");
    }
}
可以看到控制台输出如下内容
com.github.lybgeek.log.Log4j2Service info: log4j2-hello
com.github.lybgeek.log.Log4jService info: log4j-hello
总结
本文主要实现了原生SPI所不支持的按需加载能力。另外,本文的核心实现其实是从dubbo的SPI机制中搬过来的。由于我们的业务场景比较简单,并不需要dubbo那么灵活的SPI能力,因此实现时只搬了dubbo的一部分扩展能力。
demo链接
https://github.com/lyb-geek/springboot-learning/tree/master/springboot-custom-spi
**粗体** _斜体_ [链接](http://example.com) `代码` - 列表 > 引用
。你还可以使用@
来通知其他用户。