统计数组中小于4的数

假设在一个1千万长度的数组中,基本可以肯定只有几十个数小于4。如果要统计具体小于4的数有多少,是否有比较高效率的方法。

我现在的思路是,因为小于4的数很少,是否能减少if分支的数量,比如四个数为一组,因为即使四个位一组,if的结果也99%都为false,应该有利于分支预测,但是实际测试下来并没有一个提升。

if(pb[i]<4||pb[i+1]<4||pb[i+2]<4||pb[i+3]<4){
    //在逐一判断一次,看具体有几个数小于4

}

现在我想通过avx2指令做类似上述的操作。
通过_mm256_cmpgt_epi64指令一次比较四个数,但是返回的也是__m256i类型,而且为true时,64个位填充的0xFFFFFFFFFFFFFFFF。

这样的话,问题就转变成判读__m256i中的每个数是否都含有1,但是我找了半天没发现有这样的指令。

我也不确定这样的做法是否有效率,只是想到了就想测试一下。求大家解惑。

阅读 3.1k
3 个回答

确实可以用simd优化这个过程,下边这个函数,统计ms中小于5的数据有多少,假设数组中大概35个小于等于4的元素。

int test(long long ms[], int length) {
    int similar = 0;
    for (int i = 0; i < length; i += 1)
    {
        if (ms[i] <= 4) {
            similar++;
        }
    }
    ms[0] = similar;
    return similar;
}

我用visiual studio 2019,分别使用msvc和llvm编译器,msvc编译出来的是循环展开成5,逐次判断是否小于等于4,大概耗时6毫秒。

用llvm编译出来的是使用通过avx2指令集版本,一次循环比较16个元素的大小。耗时4.5毫秒左右。

下边是循环体部分的汇编:

00007FF70DD91240  vpcmpgtq    ymm5,ymm1,ymmword ptr [rdx+rax*8]  
00007FF70DD91246  vextracti128 xmm6,ymm5,1  
00007FF70DD9124C  vpackssdw   xmm5,xmm5,xmm6  
00007FF70DD91250  vpcmpgtq    ymm6,ymm1,ymmword ptr [rdx+rax*8+20h]  
00007FF70DD91257  vpsubd      xmm0,xmm0,xmm5  
00007FF70DD9125B  vextracti128 xmm5,ymm6,1  
00007FF70DD91261  vpackssdw   xmm5,xmm6,xmm5  
00007FF70DD91265  vpsubd      xmm2,xmm2,xmm5  
00007FF70DD91269  vpcmpgtq    ymm5,ymm1,ymmword ptr [rdx+rax*8+40h]  
00007FF70DD91270  vextracti128 xmm6,ymm5,1  
00007FF70DD91276  vpackssdw   xmm5,xmm5,xmm6  
00007FF70DD9127A  vpsubd      xmm3,xmm3,xmm5  
00007FF70DD9127E  vpcmpgtq    ymm5,ymm1,ymmword ptr [rdx+rax*8+60h]  
00007FF70DD91285  vextracti128 xmm6,ymm5,1  
00007FF70DD9128B  vpackssdw   xmm5,xmm5,xmm6  
00007FF70DD9128F  vpsubd      xmm4,xmm4,xmm5 

上边的内容可以看作下边的代码重复了四次:

00007FF70DD91240  vpcmpgtq    ymm5,ymm1,ymmword ptr [rdx+rax*8]  --从数组中取四个数和ymm1(ymm1中是四个5)比较大小
00007FF70DD91246  vextracti128 xmm6,ymm5,1     --取256位中255:128位
00007FF70DD9124C  vpackssdw   xmm5,xmm5,xmm6    --因为大于5的时候,64个比特位都为1,所以这里把4个数long压缩为4个int
00007FF70DD91257  vpsubd      xmm0,xmm0,xmm5     --通过减法就可以让xmm0中保存有几个数小于等于4

xmm0,xmm2,xmm3,xmm4都保存的是部分结果,循环结束后,把xmm0,xmm2,xmm3,xmm4中数据相加就是结果了。

这个没有别的办法,因为必须要遍历,而且至少需要1次遍历。
但如果要存储这些数据中,那些位置是小于4的,并且利于后面的再次使用,则可以用bitmap数组映射的方式,把相应结果进行存储,这样数据至少的存储只需要 数组长度/8 + ( (数组长度%8)?1:0) 字节即可,需要注意,这里的除法是整除法。
这样的连续数组,后面还可以类似拼接为更多位数的无符号整数数组,来快速找到到底那些位是小于4的

/* C 伪代码 */
uint8 A[La]; // 数据已经存储在A中,下次遍历找到那些位置是小于4的(该bit用1标记),可以用
uint64 * B= (uint64 *)A;
for(i=0;i<(La/8);i++){
     if(B[i]) printf("%d ---: %lld\n",i,B[i]);// 这里的i表示第i个64bit数据中存在小于4的数据,这里其实可以优化输出bit位信息
   }
}

如果一个数字 < 4,那么这个数字 >> 2 就会变成 0,你可以把所有的数组里面的数字 >> 2,然后再用汇编语言的 repz cmpsd,来找 0,利用 DMA 操作指令来加速。我这仅仅是个思路,如果真的做一次右移操作在查找,不一定会真的快。

撰写回答
你尚未登录,登录后可以
  • 和开发者交流问题的细节
  • 关注并接收问题和回答的更新提醒
  • 参与内容的编辑和改进,让解决方法与时俱进
推荐问题