
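Preface

Last year I published a series of posts about SPI. Recently, because of business needs, I implemented another version tailored to our own scenario. As for what SPI is, I wrote an article on it a long time ago, "java之spi机制简介" (An Introduction to the SPI Mechanism in Java). Interested readers can take a look.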

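Requirements

Anyone who has used the JDK's native SPI (java.util.ServiceLoader) knows its shortcoming: it cannot load on demand. A META-INF/services file only lists implementation class names, with no keys, and iterating over a ServiceLoader instantiates every provider it passes. As a result, there is no way to ask for one specific implementation by name. The implementation in this article solves that problem.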
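To make "on demand" concrete, here is a minimal, self-contained sketch of the idea. It is not the SpiLoader shown later. The Greeter interface, its implementations and the "en"/"fr" names are hypothetical. A hard-coded map stands in for the SPI file, and ConcurrentHashMap.computeIfAbsent replaces SpiLoader's Holder-based double-checked locking. Implementation classes are registered by name, and an object is constructed only when its name is first requested.

```java
import java.util.HashMap;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicInteger;

public class OnDemandDemo {

    // Hypothetical SPI interface and implementations, for illustration only
    public interface Greeter { String greet(String who); }

    // Counts how many implementation objects have been constructed so far
    static final AtomicInteger CREATED = new AtomicInteger();

    public static class English implements Greeter {
        public English() { CREATED.incrementAndGet(); }
        public String greet(String who) { return "Hello, " + who; }
    }

    public static class French implements Greeter {
        public French() { CREATED.incrementAndGet(); }
        public String greet(String who) { return "Bonjour, " + who; }
    }

    // name -> class: registering a class does not construct it
    static final Map<String, Class<? extends Greeter>> CLASSES = new HashMap<>();
    static {
        CLASSES.put("en", English.class);
        CLASSES.put("fr", French.class);
    }

    // name -> instance, filled only when a name is first requested
    static final Map<String, Greeter> INSTANCES = new ConcurrentHashMap<>();

    static Greeter get(String name) {
        return INSTANCES.computeIfAbsent(name, n -> {
            Class<? extends Greeter> c = CLASSES.get(n);
            if (c == null) {
                throw new IllegalArgumentException("unknown extension: " + n);
            }
            try {
                return c.getDeclaredConstructor().newInstance();
            } catch (ReflectiveOperationException e) {
                throw new IllegalStateException(e);
            }
        });
    }

    public static void main(String[] args) {
        System.out.println(get("en").greet("SPI"));      // prints "Hello, SPI"
        System.out.println(get("en").greet("again"));    // reuses the cached instance
        System.out.println("created: " + CREATED.get()); // prints "created: 1" (French is never built)
    }
}
```

Only English is ever constructed. A ServiceLoader iteration would have built every listed provider it walked past.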

Custom SPI

Core code

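import lombok.extern.slf4j.Slf4j;
// Assumed: StringUtils comes from Apache Commons Lang 3 (isBlank / isNotBlank)
import org.apache.commons.lang3.StringUtils;

import java.io.IOException;
import java.io.InputStream;
import java.net.URL;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Enumeration;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Properties;
import java.util.concurrent.ConcurrentHashMap;
import java.util.stream.Collectors;

/**
 * Loads implementations of an SPI interface from META-INF/spi/<interface FQCN>
 * and instantiates them lazily, by name.
 */
@Slf4j
public final class SpiLoader<T> {

    /** Directory scanned for SPI files; each file is named after the interface's fully qualified name. */
    private static final String SPI_DIRECTORY = "META-INF/spi/";

    /** One loader per SPI interface. */
    private static final Map<Class<?>, SpiLoader<?>> LOADERS = new ConcurrentHashMap<>();

    /** The SPI interface handled by this loader. */
    private final Class<T> clazz;

    private final ClassLoader classLoader;

    /** name -> implementation class, loaded once on first access. */
    private final Holder<Map<String, ClassEntity>> cachedClasses = new Holder<>();

    /** name -> holder of the lazily created instance. */
    private final Map<String, Holder<Object>> cachedInstances = new ConcurrentHashMap<>();

    /** implementation class -> instance, so two names mapped to one class share an instance. */
    private final Map<Class<?>, Object> targetInstances = new ConcurrentHashMap<>();
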
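    /**
     * Instantiates a new extension loader.
     *
     * @param clazz the SPI interface.
     * @param cl    the class loader used to find SPI files and load classes.
     */
    private SpiLoader(final Class<T> clazz, final ClassLoader cl) {
        this.clazz = clazz;
        this.classLoader = cl;
        // Load the SpiFactory extension classes first. SpiFactory is defined in the demo
        // project and is not shown in this article.
        if (!Objects.equals(clazz, SpiFactory.class)) {
            SpiLoader.getExtensionLoader(SpiFactory.class).getExtensionClassesEntity();
        }
    }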
    
    /**
     * Gets extension loader.
     *
     * @param <T>   the type parameter
     * @param clazz the clazz
     * @param cl    the cl
     * @return the extension loader.
     */
    @SuppressWarnings("unchecked")
    public static <T> SpiLoader<T> getExtensionLoader(final Class<T> clazz, final ClassLoader cl) {
        Objects.requireNonNull(clazz, "extension clazz is null");
        if (!clazz.isInterface()) {
            throw new IllegalArgumentException("extension clazz (" + clazz + ") is not interface!");
        }
        SpiLoader<T> extensionLoader = (SpiLoader<T>) LOADERS.get(clazz);
        if (Objects.nonNull(extensionLoader)) {
            return extensionLoader;
        }
        // If two threads race here, putIfAbsent keeps only the first loader and both return it
        LOADERS.putIfAbsent(clazz, new SpiLoader<>(clazz, cl));
        return (SpiLoader<T>) LOADERS.get(clazz);
    }
    
    /**
     * Gets extension loader.
     *
     * @param <T>   the type parameter
     * @param clazz the clazz
     * @return the extension loader
     */
    public static <T> SpiLoader<T> getExtensionLoader(final Class<T> clazz) {
        return getExtensionLoader(clazz, SpiLoader.class.getClassLoader());
    }
    

    
    /**
     * Gets the implementation registered under the given name, creating it on first use.
     *
     * @param name the name (the key in the SPI file)
     * @return the target.
     * @throws IllegalArgumentException if the name is blank or not registered
     */
    @SuppressWarnings("unchecked")
    public T getTarget(final String name) {
        if (StringUtils.isBlank(name)) {
            throw new IllegalArgumentException("get target name is blank");
        }
        Holder<Object> objectHolder = cachedInstances.computeIfAbsent(name, k -> new Holder<>());
        Object value = objectHolder.getValue();
        if (Objects.isNull(value)) {
            // Double-checked locking: only the first caller for this name creates the instance
            synchronized (objectHolder) {
                value = objectHolder.getValue();
                if (Objects.isNull(value)) {
                    createExtension(name, objectHolder);
                    value = objectHolder.getValue();
                }
            }
        }
        return (T) value;
    }
    
    /**
     * Gets all registered implementations.
     *
     * @return list of targets.
     */
    public List<T> getTargets() {
        Map<String, ClassEntity> extensionClassesEntity = this.getExtensionClassesEntity();
        if (extensionClassesEntity.isEmpty()) {
            return Collections.emptyList();
        }
        // getTarget() reuses cached instances, so only not-yet-created classes are instantiated.
        // (Comparing cache sizes is unreliable: a failed getTarget() can leave an empty holder behind.)
        List<T> targets = new ArrayList<>(extensionClassesEntity.size());
        for (String name : extensionClassesEntity.keySet()) {
            targets.add(this.getTarget(name));
        }
        return targets;
    }
    
    private void createExtension(final String name, final Holder<Object> holder) {
        ClassEntity classEntity = getExtensionClassesEntity().get(name);
        if (Objects.isNull(classEntity)) {
            throw new IllegalArgumentException("unknown extension name: " + name);
        }
        Class<?> aClass = classEntity.getClazz();
        Object o = targetInstances.get(aClass);
        if (Objects.isNull(o)) {
            try {
                // Class.newInstance() is deprecated since Java 9; use the no-arg constructor instead
                targetInstances.putIfAbsent(aClass, aClass.getDeclaredConstructor().newInstance());
                o = targetInstances.get(aClass);
            } catch (ReflectiveOperationException e) {
                throw new IllegalStateException("Extension instance(name: " + name + ", class: "
                        + aClass + ") could not be instantiated: " + e.getMessage(), e);
            }
        }
        holder.setValue(o);
    }
    
    /**
     * Gets extension classes.
     *
     * @return the extension classes
     */
    public Map<String, Class<?>> getTargetClassesMap() {
        Map<String, ClassEntity> classes = this.getExtensionClassesEntity();
        return classes.values().stream().collect(Collectors.toMap(ClassEntity::getName, ClassEntity::getClazz, (a, b) -> a));
    }
    
    private Map<String, ClassEntity> getExtensionClassesEntity() {
        // Load the SPI files only once: double-checked locking on a volatile holder
        Map<String, ClassEntity> classes = cachedClasses.getValue();
        if (Objects.isNull(classes)) {
            synchronized (cachedClasses) {
                classes = cachedClasses.getValue();
                if (Objects.isNull(classes)) {
                    classes = loadExtensionClass();
                    cachedClasses.setValue(classes);
                }
            }
        }
        return classes;
    }

    /** Builds the name -> class map from every SPI file for this interface. */
    private Map<String, ClassEntity> loadExtensionClass() {
        Map<String, ClassEntity> classes = new HashMap<>(16);
        loadDirectory(classes);
        return classes;
    }
    
    /**
     * Load files under SPI_DIRECTORY.
     */
    private void loadDirectory(final Map<String, ClassEntity> classes) {
        String fileName = SPI_DIRECTORY + clazz.getName();
        try {
            Enumeration<URL> urls = Objects.nonNull(this.classLoader) ? classLoader.getResources(fileName)
                    : ClassLoader.getSystemResources(fileName);
            if (Objects.nonNull(urls)) {
                while (urls.hasMoreElements()) {
                    URL url = urls.nextElement();
                    loadResources(classes, url);
                }
            }
        } catch (IOException t) {
            log.error("load extension class error {}", fileName, t);
        }
    }
    
    private void loadResources(final Map<String, ClassEntity> classes, final URL url) throws IOException {
        try (InputStream inputStream = url.openStream()) {
            Properties properties = new Properties();
            properties.load(inputStream);
            properties.forEach((k, v) -> {
                String name = (String) k;
                String classPath = (String) v;
                if (StringUtils.isNotBlank(name) && StringUtils.isNotBlank(classPath)) {
                    try {
                        loadClass(classes, name, classPath);
                    } catch (ClassNotFoundException e) {
                        throw new IllegalStateException("load extension resources error", e);
                    }
                }
            });
        } catch (IOException e) {
            throw new IllegalStateException("load extension resources error", e);
        }
    }
    
    private void loadClass(final Map<String, ClassEntity> classes,
                           final String name, final String classPath) throws ClassNotFoundException {
        Class<?> subClass = Objects.nonNull(this.classLoader) ? Class.forName(classPath, true, this.classLoader) : Class.forName(classPath);
        if (!clazz.isAssignableFrom(subClass)) {
            throw new IllegalStateException("load extension resources error," + subClass + " subtype is not of " + clazz);
        }

        ClassEntity oldClassEntity = classes.get(name);
        if (Objects.isNull(oldClassEntity)) {
            ClassEntity classEntity = new ClassEntity(name,  subClass);
            classes.put(name, classEntity);
        } else if (!Objects.equals(oldClassEntity.getClazz(), subClass)) {
            throw new IllegalStateException("load extension resources error,Duplicate class " + clazz.getName() + " name "
                    + name + " on " + oldClassEntity.getClazz().getName() + " or " + subClass.getName());
        }
    }
    
    /**
     * The type Holder.
     *
     * @param <T> the type parameter.
     */
    private static final class Holder<T> {
        
        private volatile T value;

        
        /**
         * Gets value.
         *
         * @return the value
         */
        public T getValue() {
            return value;
        }
        
        /**
         * Sets value.
         *
         * @param value the value
         */
        public void setValue(final T value) {
            this.value = value;
        }
    }
    
    private static final class ClassEntity {
        
        /**
         * name.
         */
        private final String name;
        

        /**
         * class.
         */
        private Class<?> clazz;
        
        private ClassEntity(final String name, final Class<?> clazz) {
            this.name = name;
            this.clazz = clazz;
        }
        
        /**
         * get class.
         *
         * @return class.
         */
        public Class<?> getClazz() {
            return clazz;
        }
        
        /**
         * set class.
         *
         * @param clazz class.
         */
        public void setClazz(final Class<?> clazz) {
            this.clazz = clazz;
        }
        
        /**
         * get name.
         *
         * @return name.
         */
        public String getName() {
            return name;
        }
        

    }
}

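Code walkthrough:

The first time a loader needs its classes, it looks up every META-INF/spi/<fully qualified interface name> file visible to its class loader (one per JAR or classpath directory). It parses each file as name=implementation-class properties, loads and type-checks the classes, and caches them by name. An object is created only when getTarget(name) first asks for that name (getTargets() creates all of them). After that, the cached instance is returned.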
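The lookup step can be tried in isolation with the sketch below. It is a simplified stand-in for loadDirectory/loadResources, not the author's code. Each temporary directory plays the role of one JAR on the classpath, and a URLClassLoader over those directories is used in place of the application class loader. ClassLoader.getResources returns one URL per matching file, and the entries are merged. As in SpiLoader.loadClass, the same name mapped to two different classes is rejected. Classes are not loaded here; only the name to class-name pairs are collected.

```java
import java.io.IOException;
import java.io.InputStream;
import java.io.UncheckedIOException;
import java.net.URL;
import java.net.URLClassLoader;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.Enumeration;
import java.util.LinkedHashMap;
import java.util.Map;
import java.util.Properties;

public class SpiFileLookupDemo {

    // Same directory convention as SpiLoader.SPI_DIRECTORY
    static final String SPI_DIRECTORY = "META-INF/spi/";
    static final String INTERFACE_NAME = "com.github.lybgeek.log.LogService";

    /** Merges the name=class entries of every SPI file for interfaceName visible to cl. */
    static Map<String, String> lookup(ClassLoader cl, String interfaceName) throws IOException {
        Map<String, String> result = new LinkedHashMap<>();
        Enumeration<URL> urls = cl.getResources(SPI_DIRECTORY + interfaceName);
        while (urls.hasMoreElements()) {
            try (InputStream in = urls.nextElement().openStream()) {
                Properties props = new Properties();
                props.load(in);
                for (String name : props.stringPropertyNames()) {
                    String className = props.getProperty(name);
                    String previous = result.putIfAbsent(name, className);
                    // Like SpiLoader.loadClass: one name may not map to two different classes
                    if (previous != null && !previous.equals(className)) {
                        throw new IllegalStateException(
                                "Duplicate name " + name + ": " + previous + " vs " + className);
                    }
                }
            }
        }
        return result;
    }

    /** Writes each content into its own temp directory (one per "JAR") and looks them all up. */
    static Map<String, String> lookupIn(String... fileContents) {
        try {
            URL[] roots = new URL[fileContents.length];
            for (int i = 0; i < fileContents.length; i++) {
                Path root = Files.createTempDirectory("spi-demo");
                Path file = root.resolve(SPI_DIRECTORY + INTERFACE_NAME);
                Files.createDirectories(file.getParent());
                Files.write(file, fileContents[i].getBytes(StandardCharsets.ISO_8859_1));
                roots[i] = root.toUri().toURL();
            }
            // parent = null: apart from the bootstrap loader, only the temp roots are searched
            try (URLClassLoader cl = new URLClassLoader(roots, null)) {
                return lookup(cl, INTERFACE_NAME);
            }
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        }
    }

    public static void main(String[] args) {
        System.out.println(lookupIn(
                "log4j=com.github.lybgeek.log.Log4jService\n",
                "# comment lines are ignored\nlog4j2=com.github.lybgeek.log.Log4j2Service\n"));
    }
}
```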

Example

As an example, we mock log printing through different logging facades.

1. Create the logging interface
public interface LogService {

    void info(String msg);
}
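2. Create the logging implementations
public class Log4jService implements LogService {
    @Override
    public void info(String msg) {
        System.out.println(Log4jService.class.getName() + " info: " + msg);
    }
}

// The test in step 4 also uses Log4j2Service; it is assumed to follow the same pattern
public class Log4j2Service implements LogService {
    @Override
    public void info(String msg) {
        System.out.println(Log4j2Service.class.getName() + " info: " + msg);
    }
}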
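3. Under the classpath of the implementations (typically src/main/resources in a Maven or Gradle project), create the file

META-INF/spi/com.github.lybgeek.log.LogService with the following content:

log4j=com.github.lybgeek.log.Log4jService
log4j2=com.github.lybgeek.log.Log4j2Service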
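Because the file is read with java.util.Properties, each line is one name=class entry, and lines starting with # or ! are comments. The sketch below is a simplified mirror of what loadResources and loadClass do with those entries: load each class and reject anything that does not implement the interface. The LogService, Log4jService and NotALogService types here are hypothetical nested stand-ins. The file content is a string rather than a classpath resource. Class.getName() is used to build the entries because nested classes have binary names containing '$'.

```java
import java.io.IOException;
import java.io.StringReader;
import java.io.UncheckedIOException;
import java.util.LinkedHashMap;
import java.util.Map;
import java.util.Properties;

public class SpiEntryCheckDemo {

    // Hypothetical stand-ins for the article's LogService and its implementations
    public interface LogService { void info(String msg); }

    public static class Log4jService implements LogService {
        public void info(String msg) { System.out.println("log4j info: " + msg); }
    }

    public static class NotALogService { }

    /** Parses name=class lines and loads each class, rejecting types that do not implement iface. */
    static <T> Map<String, Class<? extends T>> parse(String fileContent, Class<T> iface) {
        Properties props = new Properties();
        try {
            props.load(new StringReader(fileContent));
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        }
        Map<String, Class<? extends T>> classes = new LinkedHashMap<>();
        for (String name : props.stringPropertyNames()) {
            String className = props.getProperty(name);
            try {
                Class<?> c = Class.forName(className, true, iface.getClassLoader());
                // Same check as SpiLoader.loadClass
                if (!iface.isAssignableFrom(c)) {
                    throw new IllegalStateException(c + " is not a subtype of " + iface);
                }
                classes.put(name, c.asSubclass(iface));
            } catch (ClassNotFoundException e) {
                throw new IllegalStateException("cannot load " + className, e);
            }
        }
        return classes;
    }

    public static void main(String[] args) {
        String file = "# META-INF/spi/<interface name>\n"
                + "log4j=" + Log4jService.class.getName() + "\n";
        Map<String, Class<? extends LogService>> classes = parse(file, LogService.class);
        System.out.println(classes.keySet()); // prints [log4j]
    }
}
```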
4. Test
public class LogMainTest {
    public static void main(String[] args) {
        LogService logService = SpiLoader.getExtensionLoader(LogService.class).getTarget("log4j2");
        logService.info("log4j2-hello");

        // SpiFactoriesLoader is another loader from the demo project; it is not covered in this article
        logService = SpiFactoriesLoader.loadFactories().getTarget("log4j", LogService.class);
        logService.info("log4j-hello");
    }
}

The console prints the following:

com.github.lybgeek.log.Log4j2Service info: log4j2-hello
com.github.lybgeek.log.Log4jService info: log4j-hello

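Summary

This article adds the on-demand loading that the native SPI lacks. The core implementation is essentially ported from Dubbo's SPI. Our business scenario is fairly simple and does not need the full flexibility of Dubbo's SPI, so only part of Dubbo's extension capability was carried over.

Demo link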

https://github.com/lyb-geek/springboot-learning/tree/master/springboot-custom-spi


linyb极客之路