还记得第一次遇到需要比较对象引用而非内容的场景吗?大多数 Java 开发者习惯了使用 HashMap 等集合类,它们通过 equals()方法比较键值对象。比如在处理用户信息时,我们只关心两个 User 对象的 id 是否相同,而不管它们是不是同一个对象实例。但有些场景下,这种行为却成了阻碍。比如处理对象图遍历、深拷贝或序列化时,我们更关心"这是否是同一个对象实例",而非"这两个对象的内容是否相同"。Java 集合框架中有一个不太起眼却很有用的类正是为解决这类问题设计的—— IdentityHashMap。

IdentityHashMap 基本概念

IdentityHashMap 是 java.util 包中一个特殊的 Map 实现,它使用"=="操作符而非 equals()方法来比较键。这个简单的区别改变了 Map 的整个行为模式。

graph TD
    A[HashMap] -->|"使用equals()和hashCode()"| B[比较对象内容相等性]
    C[IdentityHashMap] -->|使用==操作符| D[比较对象引用相等性]

看个简单例子,对比 IdentityHashMap 和 HashMap 的区别:

import java.util.*;

public class IdentityHashMapDemo {
    public static void main(String[] args) {
        // 创建两个内容相同的字符串
        String key1 = new String("测试键");
        String key2 = new String("测试键");

        // 确认两个键内容相同但引用不同
        System.out.println("key1.equals(key2): " + key1.equals(key2));  // true
        System.out.println("key1 == key2: " + (key1 == key2));          // false

        // 使用HashMap测试
        Map<String, String> hashMap = new HashMap<>();
        hashMap.put(key1, "HashMap值");
        System.out.println("HashMap中通过key2能否找到值: " + hashMap.get(key2)); // 能找到

        // 使用IdentityHashMap测试
        Map<String, String> identityMap = new IdentityHashMap<>();
        identityMap.put(key1, "IdentityHashMap值");
        System.out.println("IdentityHashMap中通过key2能否找到值: " + identityMap.get(key2)); // 找不到
    }
}

在这个例子中,虽然 key1 和 key2 内容相同,但它们是不同的对象引用。HashMap 认为它们相等(通过 equals()方法),而 IdentityHashMap 将它们视为不同键(通过==操作符)。

IdentityHashMap 的内部实现

IdentityHashMap 内部使用开放寻址法(open addressing)而非链表法来处理哈希冲突。具体来说,它采用线性探测法(linear probing):当发生冲突时,它会线性查找下一个可用位置。

graph TD
    subgraph IdentityHashMap内部结构
    A[table数组]
    A -->|索引0| K1[键1引用]
    A -->|索引1| V1[值1]
    A -->|索引2| K2[键2引用]
    A -->|索引3| V2[值2]
    A -->|索引4| NULL[...]
    end

键的哈希值通过 System.identityHashCode()方法计算,这个方法返回对象的内存地址相关哈希值,而不依赖于对象的 hashCode()实现。这确保了即使两个对象内容相同,只要它们是不同的实例,IdentityHashMap 就能区分它们。

与 HashMap 使用链表/红黑树处理冲突不同,开放寻址法的特点是:

  • 优点:结构更简单,内存分配更连续,对 CPU 缓存更友好
  • 缺点:当负载因子增高时,性能可能急剧下降(因此 IdentityHashMap 默认维持较低的负载因子)

IdentityHashMap 的应用场景

1. 对象图遍历(避免循环引用,精确识别已访问对象)

处理复杂对象图结构时,特别是可能包含循环引用的情况下,我们需要跟踪已访问过的对象以避免无限循环。这种场景下,对象的引用身份比内容更重要:

import java.util.*;
import java.lang.reflect.Field;

public class ObjectGraphTraversal {
    public static void traverseObjectGraph(Object obj) {
        Set<Object> visited = Collections.newSetFromMap(new IdentityHashMap<>());
        traverseObjectGraph(obj, visited);
    }

    private static void traverseObjectGraph(Object obj, Set<Object> visited) {
        if (obj == null || visited.contains(obj)) {
            return; // null或已访问过的对象,直接返回
        }

        // 标记当前对象为已访问
        visited.add(obj);
        System.out.println("访问对象: " + obj);

        // 使用反射获取并遍历当前对象的所有引用字段
        try {
            Class<?> clazz = obj.getClass();

            // 处理数组
            if (clazz.isArray() && !clazz.getComponentType().isPrimitive()) {
                Object[] array = (Object[]) obj;
                for (Object item : array) {
                    traverseObjectGraph(item, visited);
                }
                return;
            }

            // 处理集合
            if (obj instanceof Collection) {
                for (Object item : (Collection<?>) obj) {
                    traverseObjectGraph(item, visited);
                }
                return;
            }

            // 处理Map
            if (obj instanceof Map) {
                for (Map.Entry<?, ?> entry : ((Map<?, ?>) obj).entrySet()) {
                    traverseObjectGraph(entry.getKey(), visited);
                    traverseObjectGraph(entry.getValue(), visited);
                }
                return;
            }

            // 处理普通对象字段
            for (Field field : clazz.getDeclaredFields()) {
                field.setAccessible(true);

                // 只处理引用类型字段,忽略基本类型
                if (!field.getType().isPrimitive()) {
                    Object fieldValue = field.get(obj);
                    traverseObjectGraph(fieldValue, visited);
                }
            }
        } catch (Exception e) {
            e.printStackTrace();
        }
    }

    public static void main(String[] args) {
        List<Object> cyclicList = new ArrayList<>();
        cyclicList.add("元素1");
        cyclicList.add("元素2");
        cyclicList.add(cyclicList); // 创建循环引用

        traverseObjectGraph(cyclicList);
    }
}

如果使用普通 HashMap 或 HashSet,对于内容相同但引用不同的对象,会错误地被认为"已访问",导致遍历不完整。

2. 对象深度复制(精确跟踪已复制对象,避免重复复制)

在实现对象深拷贝时,我们需要追踪已复制过的对象,避免重复复制或陷入循环引用的无限递归:

import java.util.*;
import java.lang.reflect.*;

public class DeepCopyExample {
    private Map<Object, Object> copiedObjects = new IdentityHashMap<>();

    public Object deepCopy(Object original) {
        if (original == null) {
            return null;
        }

        // 检查是否已经复制过这个对象
        if (copiedObjects.containsKey(original)) {
            return copiedObjects.get(original);
        }

        // 根据original的类型创建新的对象实例
        Object copy = createCopyFor(original);

        // 先将创建的空对象放入映射表,防止循环引用导致的无限递归
        copiedObjects.put(original, copy);

        // 复制对象的所有字段
        copyFields(original, copy);

        return copy;
    }

    private Object createCopyFor(Object original) {
        Class<?> clazz = original.getClass();

        // 处理不可变对象或基本类型包装类
        if (original instanceof String || original instanceof Integer ||
            original instanceof Long || original instanceof Boolean ||
            original instanceof Double || original instanceof Float ||
            original instanceof Byte || original instanceof Character ||
            original instanceof Short) {
            return original;
        }

        // 处理集合类
        if (original instanceof ArrayList) {
            return new ArrayList<>();
        } else if (original instanceof HashMap) {
            return new HashMap<>();
        } else if (original instanceof HashSet) {
            return new HashSet<>();
        }

        // 处理数组
        if (clazz.isArray()) {
            Class<?> componentType = clazz.getComponentType();
            int length = Array.getLength(original);
            return Array.newInstance(componentType, length);
        }

        // 处理自定义类
        try {
            // 尝试找到无参构造器
            Constructor<?> constructor = clazz.getDeclaredConstructor();
            constructor.setAccessible(true);
            return constructor.newInstance();
        } catch (NoSuchMethodException e) {
            throw new RuntimeException("类 " + clazz.getName() + " 缺少无参构造器,无法深拷贝", e);
        } catch (Exception e) {
            throw new RuntimeException("无法创建" + clazz.getName() + "的实例", e);
        }
    }

    private void copyFields(Object original, Object copy) {
        Class<?> clazz = original.getClass();

        // 处理集合类型
        if (original instanceof ArrayList && copy instanceof ArrayList) {
            ArrayList<?> originalList = (ArrayList<?>) original;
            ArrayList<Object> copyList = (ArrayList<Object>) copy;

            // 递归复制列表中的每个元素
            for (Object item : originalList) {
                copyList.add(deepCopy(item));
            }
        } else if (original instanceof HashMap && copy instanceof HashMap) {
            HashMap<?, ?> originalMap = (HashMap<?, ?>) original;
            HashMap<Object, Object> copyMap = (HashMap<Object, Object>) copy;

            // 递归复制Map中的每个键值对
            for (Map.Entry<?, ?> entry : originalMap.entrySet()) {
                Object keyClone = deepCopy(entry.getKey());
                Object valueClone = deepCopy(entry.getValue());
                copyMap.put(keyClone, valueClone);
            }
        } else if (clazz.isArray()) {
            // 处理数组
            int length = Array.getLength(original);
            for (int i = 0; i < length; i++) {
                Object originalItem = Array.get(original, i);
                Object copyItem = deepCopy(originalItem);
                Array.set(copy, i, copyItem);
            }
        } else {
            // 处理自定义类,复制所有字段(包括继承的)
            for (Class<?> c = clazz; c != Object.class; c = c.getSuperclass()) {
                Field[] fields = c.getDeclaredFields();
                for (Field field : fields) {
                    if (Modifier.isStatic(field.getModifiers())) {
                        continue; // 跳过静态字段
                    }

                    field.setAccessible(true);
                    try {
                        Object value = field.get(original);
                        Object valueCopy = deepCopy(value);
                        field.set(copy, valueCopy);
                    } catch (Exception e) {
                        throw new RuntimeException("复制字段 " + field.getName() + " 出错", e);
                    }
                }
            }
        }
    }
}

3. 序列化和反序列化(维护对象引用一致性)

在序列化和反序列化过程中,需要保持对象引用的一致性,避免重复序列化同一个对象:

import java.util.*;
import java.io.*;

public class SerializationHelper {
    // 序列化部分
    public static class Serializer {
        private Map<Object, Integer> serializedObjects = new IdentityHashMap<>();
        private DataOutputStream outputStream;

        public Serializer(OutputStream out) {
            this.outputStream = new DataOutputStream(out);
        }

        public void serialize(Object obj) throws IOException {
            serialize(obj, 0);
        }

        private void serialize(Object obj, int objectId) throws IOException {
            // 检查对象是否已经序列化过
            if (obj == null) {
                // 写入null标记
                outputStream.writeInt(-2);
                return;
            }

            if (serializedObjects.containsKey(obj)) {
                // 写入引用标记和已存在对象的ID
                int existingId = serializedObjects.get(obj);
                outputStream.writeInt(-1); // 引用标记
                outputStream.writeInt(existingId);
                return;
            }

            // 记录新对象
            serializedObjects.put(obj, objectId);

            // 序列化对象类型信息
            outputStream.writeInt(objectId);
            outputStream.writeUTF(obj.getClass().getName());

            // 根据对象类型执行不同的序列化逻辑
            if (obj instanceof String) {
                outputStream.writeUTF((String) obj);
            } else if (obj instanceof Integer) {
                outputStream.writeInt((Integer) obj);
            } else if (obj instanceof ArrayList) {
                ArrayList<?> list = (ArrayList<?>) obj;
                // 写入列表大小
                outputStream.writeInt(list.size());
                // 序列化每个元素
                int nextId = objectId + 1;
                for (Object element : list) {
                    serialize(element, nextId++);
                }
            }
            // 其他类型的序列化逻辑...
        }
    }

    // 反序列化部分
    public static class Deserializer {
        private Map<Integer, Object> deserializedObjects = new HashMap<>();
        private DataInputStream inputStream;

        public Deserializer(InputStream in) {
            this.inputStream = new DataInputStream(in);
        }

        public Object deserialize() throws IOException, ClassNotFoundException {
            int marker = inputStream.readInt();

            // 处理特殊标记
            if (marker == -2) {
                return null; // null对象
            } else if (marker == -1) {
                // 引用已存在对象
                int objectId = inputStream.readInt();
                return deserializedObjects.get(objectId);
            }

            // 否则是新对象,读取类型
            int objectId = marker;
            String className = inputStream.readUTF();
            Class<?> clazz = Class.forName(className);

            // 根据类型创建对象并执行反序列化
            Object result = null;

            if (clazz == String.class) {
                result = inputStream.readUTF();
            } else if (clazz == Integer.class) {
                result = inputStream.readInt();
            } else if (clazz == ArrayList.class) {
                int size = inputStream.readInt();
                ArrayList<Object> list = new ArrayList<>(size);
                // 先记录对象,再填充内容,处理循环引用
                deserializedObjects.put(objectId, list);

                // 读取列表元素
                for (int i = 0; i < size; i++) {
                    list.add(deserialize());
                }
                result = list;
            } else {
                // 其他类型的反序列化逻辑...
            }

            // 如果尚未记录对象(处理基本类型等),现在记录
            if (!deserializedObjects.containsKey(objectId)) {
                deserializedObjects.put(objectId, result);
            }

            return result;
        }
    }
}

IdentityHashMap 性能特点

IdentityHashMap 在特定情况下的性能表现与 HashMap 有明显差异。由于使用==操作符而不调用 equals()和 hashCode()方法,当处理内容相等但引用不同的对象,或处理 equals()计算复杂的对象时,IdentityHashMap 通常提供更快的查找性能。

graph LR
    subgraph 性能对比
    A[操作类型] --- B[HashMap] --- C[IdentityHashMap]
    D[引用比较] --- E["需调用equals()"] --- F[直接使用==]
    G[内存结构] --- H[链表/红黑树节点] --- I[紧凑数组存储]
    J[哈希冲突处理] --- K[链表/红黑树] --- L[开放寻址法]
    M[高负载性能] --- N["较稳定(红黑树O(log n))"] --- O["退化严重(最坏O(n))"]
    end

随着负载因子增加,IdentityHashMap 的性能变化:

graph LR
    subgraph "负载因子对性能的影响"
    A[负载因子] --- B[0.1] --- C[0.5] --- D[0.7] --- E[0.9]
    F[IdentityHashMap性能] --- G[极佳] --- H[良好] --- I[下降] --- J[严重退化]
    K[HashMap性能] --- L[良好] --- M[良好] --- N[变化不大] --- O[链表变红黑树]
    end

下面是一个详细的性能测试对比:

import java.util.*;

public class MapPerformanceTest {
    public static void main(String[] args) {
        final int COUNT = 1000000;

        // 准备测试数据 - 普通字符串键
        String[] keys = new String[COUNT];
        for (int i = 0; i < COUNT; i++) {
            keys[i] = new String("key" + i);
        }

        // 准备一个equals方法耗时较长的对象类型
        ComplexKey[] complexKeys = new ComplexKey[COUNT];
        for (int i = 0; i < COUNT; i++) {
            complexKeys[i] = new ComplexKey("key" + i, i);
        }

        // 测试HashMap - 普通键
        Map<String, Integer> hashMap = new HashMap<>();
        long start = System.nanoTime();
        for (int i = 0; i < COUNT; i++) {
            hashMap.put(keys[i], i);
        }
        for (int i = 0; i < COUNT; i++) {
            hashMap.get(keys[i]);
        }
        long hashMapTime = System.nanoTime() - start;

        // 测试IdentityHashMap - 普通键
        Map<String, Integer> identityMap = new IdentityHashMap<>();
        start = System.nanoTime();
        for (int i = 0; i < COUNT; i++) {
            identityMap.put(keys[i], i);
        }
        for (int i = 0; i < COUNT; i++) {
            identityMap.get(keys[i]);
        }
        long identityMapTime = System.nanoTime() - start;

        // 测试HashMap - 复杂键
        Map<ComplexKey, Integer> complexHashMap = new HashMap<>();
        start = System.nanoTime();
        for (int i = 0; i < COUNT; i++) {
            complexHashMap.put(complexKeys[i], i);
        }
        for (int i = 0; i < COUNT; i++) {
            complexHashMap.get(complexKeys[i]);
        }
        long complexHashMapTime = System.nanoTime() - start;

        // 测试IdentityHashMap - 复杂键
        Map<ComplexKey, Integer> complexIdentityMap = new IdentityHashMap<>();
        start = System.nanoTime();
        for (int i = 0; i < COUNT; i++) {
            complexIdentityMap.put(complexKeys[i], i);
        }
        for (int i = 0; i < COUNT; i++) {
            complexIdentityMap.get(complexKeys[i]);
        }
        long complexIdentityMapTime = System.nanoTime() - start;

        // 输出性能测试结果
        System.out.println("普通键 - HashMap时间: " + hashMapTime / 1000000 + "ms");
        System.out.println("普通键 - IdentityHashMap时间: " + identityMapTime / 1000000 + "ms");
        System.out.println("复杂键 - HashMap时间: " + complexHashMapTime / 1000000 + "ms");
        System.out.println("复杂键 - IdentityHashMap时间: " + complexIdentityMapTime / 1000000 + "ms");
    }

    // 一个equals方法耗时较长的复杂键
    static class ComplexKey {
        private String id;
        private int num;
        private long[] data;

        public ComplexKey(String id, int num) {
            this.id = id;
            this.num = num;
            this.data = new long[100];
            for (int i = 0; i < data.length; i++) {
                data[i] = i * num;
            }
        }

        @Override
        public boolean equals(Object o) {
            if (this == o) return true;
            if (o == null || getClass() != o.getClass()) return false;

            // 模拟复杂的相等性比较,消耗更多时间
            ComplexKey that = (ComplexKey) o;
            if (num != that.num) return false;
            if (!id.equals(that.id)) return false;

            // 比较整个数组内容
            for (int i = 0; i < data.length; i++) {
                if (data[i] != that.data[i]) return false;
            }
            return true;
        }

        @Override
        public int hashCode() {
            int result = id.hashCode();
            result = 31 * result + num;
            for (int i = 0; i < 10; i++) {
                result = 31 * result + (int)(data[i] ^ (data[i] >>> 32));
            }
            return result;
        }
    }
}

性能测试结果分析:

  • 普通字符串键场景:IdentityHashMap 通常快 10%-20%,因为它跳过了 equals()调用
  • 复杂键场景:IdentityHashMap 可能快 50%以上,因为复杂键的 equals()方法耗时较长

不过,测试结果会受 JVM 优化、数据规模、冲突率等因素影响,实际应用中的性能差异可能有所不同。

使用建议与注意事项

使用 IdentityHashMap 时需要注意:

  1. 只在确实需要基于引用相等的场景使用,大多数业务场景更适合使用 HashMap
  2. 不是线程安全的,多线程环境需要外部同步:
Map<K, V> synchronizedMap = Collections.synchronizedMap(new IdentityHashMap<>());
  1. 违反了 Map 接口的通用约定,因为它不使用 equals()判断键相等,这会导致一些意外行为:
  • 当 IdentityHashMap 的键作为普通 HashMap 的键时,可能因 equals()和==的差异导致查找失败
  • 即使内容完全相同的两个对象在 IdentityHashMap 中会被视为不同的键
  1. 不要依赖迭代顺序,它可能与插入顺序不同
  2. 内存占用可能高于预期,由于开放寻址法需要维持较低的负载因子以避免性能下降

一个容易踩的坑是使用字符串字面量作为键:

IdentityHashMap<String, String> map = new IdentityHashMap<>();
map.put("key", "value1");
System.out.println(map.get("key")); // 可能返回null!

这是因为 Java 会对字符串字面量进行池化,两个"key"可能指向不同的对象引用。使用时应当保存键的引用:

String key = "key";
map.put(key, "value1");
System.out.println(map.get(key)); // 正确获取到值

总结

特性IdentityHashMapHashMap
比较方式== (引用相等)equals() (内容相等)
哈希函数System.identityHashCode()对象的 hashCode()
冲突解决开放寻址法(线性探测)链表/红黑树
内存占用因开放寻址法需预留空间,可能略高紧凑结构,通常较低
高负载性能显著下降较为稳定
典型应用对象图遍历、深拷贝、缓存一般映射需求
特别适合场景处理对象身份而非内容基于内容相等性的映射

异常君
1 声望2 粉丝

在 Java 的世界里,永远有下一座技术高峰等着你。我愿做你登山路上的同频伙伴,陪你从看懂代码到写出让自己骄傲的代码。咱们,代码里见!