一、背景

通过感知算法(比如pHash等)可以将一张图片处理成一个64bit的二进制字符串。这样,通过计算两张图片的字符串的hamming距离,我们就可以判断两张图片是否相似(一般,两张图片的hamming距离d<=5即可认为相似)。

业务场景是,需要对图片进行鉴定,之后打标签并把标签保存起来。这样我们就有了一个集合,key是图片处理成的64bit二进制字符串,value是对应图片的标签(比如色情等)。每当获取到一个新的图片,先在集合中进行搜索,查看是否有相似图片(即,在集合中是否存在一个key,使得hamming距离d<=5)。如果存在,则直接获取对应的标签返回,如果不存在,则进行鉴定打标签(另外一个算法,这里不关心)并将结果放入集合。(有点缓存的味道)

那么,在这里的难题是,如何在一定时间内(比如100ms内),在集合中获取到近似的图片。这个难题抽象为,在一定时间内,如何在一个二进制字符串的集合中搜索出hamming距离d<=5的字符串。

二、现有的解决方案于局限性

对于这个问题其实网上已经有比较通用的解决方案,可以参考 这个链接 。按照这个链接,可选的算法有

但是这些算法是通过将集合内的字符串全部放到内存中,然后通过算法搜索出符合要求的字符串。这就导致一个问题:当集合中的二进制字符串越来越多,对于执行这个算法的服务要求的内存也越来越大。

当然,如果把这个抽离为独立的服务来维护,这点内存是接受的。但是目前在我们项目中并没有打算把近似图片抽离成一个服务的计划,只是一个当成一个小模块来使用。那么,为了尽量减少服务内存的占用以及搜索的速度,上面的算法就都不可行。

三、解决思路

为了减少占用服务内存,计算是将集合放到redis中,这样占用服务的内存就不会增加多少(当然redis内存会是会增大的)。如何搭配redis设计好一个算法?

1 暴力算法

每次获取到一个新图片转成的二进制字符串后,去取存放在redis中所有的二进制字符串,逐个的比较hamming距离,一直到找到符合要求的字符串或者全部遍历一次为止。
这种算法在时间上有瑕疵,当集合足够大,很明显时间就会很长。

2 针对64bit字符串的优化

2.1 原始思路

搜索算法的优化第一直觉就是剪枝。我们可以考虑,我们其实不需要每次都将一个64bit二进制字符串都遍历一次才知道他是否符合要求的。可以将集合中和新获取的64bit字符串s分为8bit一组的子串,s1 s2 ... s8,即 s1 + s2 + ... + s8 = s。
no1.png

接着还是对集合的元素进行遍历,逐个与新字符串进行比较。跟之前的区别在于:之前是64bit一起比较,现在是按照顺序从s1开始到s8进行比较,每次比较完记录对应的距离d1,d2 ... d8。假设比较到某个di(1<=i<=8),sum(di)>5,即可停止比较。

no2.png

2.2 利用redis

将64bit字符串分成8份以后,需要将他们分别放到redis中。设计为一个hash结构,key是8个固定值,'key1','key2',...,'key8',field则是集合中每个字符串对应的s(i), value 则是原本的64bit字符串(对,这里其实重复存取了8份)。结构如下

no3.png

针对待搜索字符串str中每个s(i),

  • 我们从 0000,0000 到 1111,1111 进行遍历,找出所有与s(i) 的hamming距离d<=5的字符串(根据计算,应该有218个),设为 s(i,j) (0<=j<=218)

    no4png.png

  • 通过redis的hget(key(i), s(i,j)) 查看集合是否存在对应的field,如果不存在说明hamming距离已经大于5,则抛弃;如果存在,记录对应的haming距离d的值 以及 对应的原64bit字符串str。因为有8个部分,所以会遍历8次,每次将得到的字符串存到对应的list,计为l(i),其中0<=i<=8
  • 计算最后sum(d)>5,则抛弃
2.3 整合

经过2.2以后,我们可以得到一个列表l(i),记录每一部分得到的子串,结构如下:

no5.png

其中,i,j,k 相互独立,0<= i+n, j+n, k+n <= 集合数量

这8个list,每一个list代表里面的值满足d<=5,但是还存在两个问题:

  • 8个list中是否存在相同的str

  • 8个list中相同的str的sum(d)是否小于等于5

很明显,我们只需要每个list遍历一次,计算哪个str出现了8次即可。时间复杂度是O(n),n代表8个list的个数和。然后计算出现了8次的str的sum(d) ,找出小于等于5的值即可。

四、代码如下

private static String PIC_HASH_ = "KEY";

        // hash:图片处理后得到的64bit字符串
    private String isSamePic(String hash){
          // 切割成8bit一份的子串
        List<String> hashs = Splitter.fixedLength(8).splitToList(hash);
        Map<String, Integer> value1List = new HashMap<>();
        Map<String, Integer> value2List = new HashMap<>();
        Map<String, Integer> value3List = new HashMap<>();
        Map<String, Integer> value4List = new HashMap<>();
        Map<String, Integer> value5List = new HashMap<>();
        Map<String, Integer> value6List = new HashMap<>();
        Map<String, Integer> value7List = new HashMap<>();
        Map<String, Integer> value8List = new HashMap<>();
      
          // 从第一部分到第八部分遍历
        for (int i = 1; i <= hashs.size(); i++) {
            String key = PIC_HASH_ + i;
              // 遍历8bit所有可能,可到有嫌疑的字符串
            Map<Integer, Integer> reviewList = getReviewList(hashs.get(i - 1));
              // 过滤:嫌疑数据是否在redis存在,如果不存在,说明d>5
            Map<String, Integer> valueAndDistince = filterReviewList(reviewList, key);
            switch (i){
                case 1:value1List = valueAndDistince;break;
                case 2:value2List = valueAndDistince;break;
                case 3:value3List = valueAndDistince;break;
                case 4:value4List = valueAndDistince;break;
                case 5:value5List = valueAndDistince;break;
                case 6:value6List = valueAndDistince;break;
                case 7:value7List = valueAndDistince;break;
                case 8:value8List = valueAndDistince;break;
            }
        }
      
          // 过滤:找到出现8次的字符串
        Map<String, Integer> valueCnt = new HashMap<>();
        for (String value : value1List.keySet()) {
            Integer cnt = valueCnt.get(value);
            if (cnt == null){
                valueCnt.put(value, 1);
            } else {
                valueCnt.put(value, cnt+1);
            }
        }

        for (String value : value2List.keySet()) {
            Integer cnt = valueCnt.get(value);
            if (cnt == null){
                valueCnt.put(value, 1);
            } else {
                valueCnt.put(value, cnt+1);
            }
        }

        for (String value : value3List.keySet()) {
            Integer cnt = valueCnt.get(value);
            if (cnt == null){
                valueCnt.put(value, 1);
            } else {
                valueCnt.put(value, cnt+1);
            }
        }

        for (String value : value4List.keySet()) {
            Integer cnt = valueCnt.get(value);
            if (cnt == null){
                valueCnt.put(value, 1);
            } else {
                valueCnt.put(value, cnt+1);
            }
        }

        for (String value : value5List.keySet()) {
            Integer cnt = valueCnt.get(value);
            if (cnt == null){
                valueCnt.put(value, 1);
            } else {
                valueCnt.put(value, cnt+1);
            }
        }

        for (String value : value6List.keySet()) {
            Integer cnt = valueCnt.get(value);
            if (cnt == null){
                valueCnt.put(value, 1);
            } else {
                valueCnt.put(value, cnt+1);
            }
        }

        for (String value : value7List.keySet()) {
            Integer cnt = valueCnt.get(value);
            if (cnt == null){
                valueCnt.put(value, 1);
            } else {
                valueCnt.put(value, cnt+1);
            }
        }

        for (String value : value8List.keySet()) {
            Integer cnt = valueCnt.get(value);
            if (cnt == null){
                valueCnt.put(value, 1);
            } else {
                valueCnt.put(value, cnt+1);
            }
        }

          // 找到的字符串中,查看总距离是否小于等于5
        for (Map.Entry<String, Integer> entry : valueCnt.entrySet()) {
            if (entry.getValue() == 8){
                List<String> resultHash = Splitter.fixedLength(8).splitToList(entry.getKey());
                Integer v1 = value1List.get(resultHash.get(0));
                Integer v2 = value1List.get(resultHash.get(1));
                Integer v3 = value1List.get(resultHash.get(2));
                Integer v4 = value1List.get(resultHash.get(3));
                Integer v5 = value1List.get(resultHash.get(4));
                Integer v6 = value1List.get(resultHash.get(5));
                Integer v7 = value1List.get(resultHash.get(6));
                Integer v8 = value1List.get(resultHash.get(7));
                if (v1 + v2 + v3 + v4 + v5 + v6 + v7 + v8 <=5){
                    return entry.getKey();
                }
            }
        }

        return null;
    }

    /**
     * 过滤一部分数据
     * @param reviewList        8bit 以及 对应的d
     * @param key               key,第几部分
     * @return                  8bit对应的hash以及8bit对应的d
     */
    private Map<String, Integer> filterReviewList(Map<Integer, Integer> reviewList, String key){
        Map<String, Integer> hashs = new HashMap<>();
        Iterator<Map.Entry<Integer, Integer>> iterator = reviewList.entrySet().iterator();
        while (iterator.hasNext()){
            Map.Entry<Integer, Integer> entry = iterator.next();
            String reviewStr = toBinary(entry.getKey(), 8);
            String value = redis.hget(key, reviewStr);
            if (StringUtils.isBlank(value)) {
                iterator.remove();
            } else {
                hashs.put(value, entry.getValue());
            }
        }
        return hashs;
    }



        // 从 0000,0000 遍历到 1111,1111 得到与hash的d<=5的值,有218个
    private Map<Integer, Integer> getReviewList(String hash){
        Map<Integer, Integer> reviewMap = new HashMap<>();
        for (int i = 0; i < 256; i++) {
            int distance = hammingDistance(i, Integer.parseInt(hash, 2));
            if (distance<=5){
                reviewMap.put(i, distance);
            }
        }
        return reviewMap;
    }

        // 计算hamming距离
    public int hammingDistance(int x, int y) {
        int res = x ^ y;
        int cnt = 0;
        while(res!=0){
            cnt++;
            res =  res & (res-1);
        }
        return cnt;
    }

    /**
     * 将一个int数字转换为二进制的字符串形式。
     * @param num 需要转换的int类型数据
     * @param digits 要转换的二进制位数,位数不足则在前面补0
     * @return 二进制的字符串形式
     */
    public static String toBinary(int num, int digits) {
        String cover = Integer.toBinaryString(1 << digits).substring(1);
        String s = Integer.toBinaryString(num);
        return s.length() < digits ? cover.substring(s.length()) + s : s;
    }

五、参考链接


CheukKwan
0 声望0 粉丝