如何提高大量的字符串比较速度

libxd
  • 502

补充数据链接:

candidates: https://pan.baidu.com/s/1nvGWbrV
bg_db: https://pan.baidu.com/s/1sllFLAd

每个字符串长度都是23,只要前面20个字符就行,由于数据太大,我只传了五分之一,大神们可以挑战一下,有速度快的可以贴一下代码,让小弟拜读一下,谢谢!

下面是正式的问题:

我现在有两个字符串数组,姑且称为candidates和bg_db,全部都是长度为20的短字符串,并且每个字符串的每个字符只有ATCG四种可能(没错!就是基因组序列啦!):

candidates = [
    'GGGAGCAGGCAAGGACTCTG',
    'GCTCGGGCTTGTCCACAGGA',
    '...',
    # 被你看出来啦,这些其实人类基因的片段
]

bg_db = [
    'CTGCTGACGGGTGACACCCA',
    'AGGAACTGGTGCTTGATGGC',
    '...',
    # 这个更多,有十亿左右
]

我的任务是对candidates的每一个candidate,找到bg_db中所有与其小于等于4个差异的记录,
举个例子来说:

# 上面一条为candidate,即candidates的一个记录
# 中间|代表相同,*代表不相同
# 下面一条代表bg_db的一条记录

A T C G A T C G A T C G A T C G A T C G
| | | | | | | | | | | | | | | | | | | |    # 差异为0
A T C G A T C G A T C G A T C G A T C G

A T C G A T C G A T C G A T C G A T C G
* | | | | | | | | | | | | | | | | | | |    # 差异为1
T T C G A T C G A T C G A T C G A T C G

A T C G A T C G A T C G A T C G A T C G
* | | | * | | | | | | | | | | | | | | |    # 差异为2
T T C G T T C G A T C G A T C G A T C G

A T C G A T C G A T C G A T C G A T C G
* | | | * | | | | | | * | | | | | | | |    # 差异为3
T T C G T T C G A T C C A T C G A T C G

A T C G A T C G A T C G A T C G A T C G
* | | | * | | | | | | * | | | * | | | |    # 差异为4
T T C G T T C G A T C C A T C A A T C G

我的问题是如果快速地找到:每一个candidate在bg_db中与之差异小于等于4的所有记录,
如果采用暴力遍历的话,以Python为例:

def align(candidate, record_from_bg_db):
    mismatches = 0
    for i in range(20):
        if candidate[i] != record_from_bg_db[i]:
            mismatches += 1
            if mismatches >= 4:
                return False
    return True

candidate = 'GGGAGCAGGCAAGGACTCTG'
record_from_bg_db = 'CTGCTGACGGGTGACACCCA'

align(candidate, record_from_bg_db) # 1.24微秒左右

# 总时间:

10000000 * 1000000000 * 1.24 / 1000 / 1000 / 60 / 60 / 24 / 365
# = 393
# 1千万个candidates,10亿条bg_db记录
# 耗时大约393年
# 完全无法忍受啊

我的想法是,bg_db是高度有序的字符串(长度固定,每个字符的可能只有四种),有没有什么算法,可以让candidate快速比较完所有的bg_db,各位大神,求赐教。

回复
阅读 16.1k
18 个回答

写一个思路

candidates = [
    'GGGAGCAGGCAAGGACTCTG',
    'GCTCGGGCTTGTCCACAGGA',
    '...',
    # 被你看出来啦,这些其实人类基因的片段
]

bg_db = [
    'CTGCTGACGGGTGACACCCA',
    'AGGAACTGGTGCTTGATGGC',
    '...',
    # 这个更多,有十亿左右
]

因为你的数据其实是很有特点的,这里可以进行精简。
因为所有的字符串都是20个字符长度,且都由ATCG四个字符组成。那么可以把它们变换为整数来进行比较。
二进制表现形式如下

A  ==>  00
T  ==>  01
C  ==>  10
G  ==>  11

因为一个字符串长度固定,每个字符可以由2个比特位表示,所以每个字符串可以表示为一个40位的整数。可以表示为32+8的形式,也可以直接使用64位整形。建议使用C语言来做。

再来说说比较。
因为要找到每一个candidate在bg_db中与之差异小于等于4的所有记录,所以只要两个整数一做^按位异或操作,结果中二进制中1不超过8个,且这不超过8个1最多只能分为4个组的才有可能是符合要求的(00^11=11,10^01=11)。
把结果的40个比特位分作20个组,那么就是说最多只有4个组为b01 b10 b11这三个值,其余的全部为b00
那么比较算法就很好写了。
可以对每个字节(四个组)获取其中有几个组是为三个非零值的,来简介获取整体的比较结果。
因为每个字节只有256种可能的值,而符合条件的值只有3^4=81,所以可以先将结果存储起来,然后进行获取。
这里给出一个函数,来获取结果中有几个是非零组。

/*****************下面table中值的生成******//**
  int i;
  for( i=0;i<256;++i){
    int t =0;
    t += (i&0x01 || i&0x02)?1:0;
    t += (i&0x04 || i&0x08)?1:0;
    t += (i&0x10 || i&0x20)?1:0;
    t += (i&0x40 || i&0x80)?1:0;
    printf("%d,",t);
    if(i%10 ==9){putchar('\n');}
  }
********************************************//

int table[] = {
0,1,1,1,1,2,2,2,1,2,
2,2,1,2,2,2,1,2,2,2,
2,3,3,3,2,3,3,3,2,3,
3,3,1,2,2,2,2,3,3,3,
2,3,3,3,2,3,3,3,1,2,
2,2,2,3,3,3,2,3,3,3,
2,3,3,3,1,2,2,2,2,3,
3,3,2,3,3,3,2,3,3,3,
2,3,3,3,3,4,4,4,3,4,
4,4,3,4,4,4,2,3,3,3,
3,4,4,4,3,4,4,4,3,4,
4,4,2,3,3,3,3,4,4,4,
3,4,4,4,3,4,4,4,1,2,
2,2,2,3,3,3,2,3,3,3,
2,3,3,3,2,3,3,3,3,4,
4,4,3,4,4,4,3,4,4,4,
2,3,3,3,3,4,4,4,3,4,
4,4,3,4,4,4,2,3,3,3,
3,4,4,4,3,4,4,4,3,4,
4,4,1,2,2,2,2,3,3,3,
2,3,3,3,2,3,3,3,2,3,
3,3,3,4,4,4,3,4,4,4,
3,4,4,4,2,3,3,3,3,4,
4,4,3,4,4,4,3,4,4,4,
2,3,3,3,3,4,4,4,3,4,
4,4,3,4,4,4
};

int getCount(uint64_t cmpresult)
{
    uint8_t* pb = &cmpresult;    // 这里假设是小段模式,且之前比较结果是存在低40位
    return table[pb[0]]+table[pb[1]]+table[pb[2]]+table[pb[3]]+table[pb[4]];
}

首先,你的时间估算完全不对,这种大规模的数据量处理,好歹跑个几万条,持续十秒以上的时间,才能拿来做乘法算总时间,只算一条的话,这个时间几乎都是初始化进程的开销,而非关键的IO、CPU开销

以下正文

ACTG四种可能性相当于2bit,用一个字符表示一个基因位太过浪费了,一个字符8bit,可以放4个基因位

即使不用任何算法,只是把你的20个基因写成二进制形式,也能节省5倍时间

另外,循环20次,CPU的指令数是20*n条,n估计至少有3,但对于二进制来说,做比较的异或运算直接是cpu指令,指令数是1

算法白痴一个, 不过刚好手上有可以用的计算资源, 就按照原来的思路(暴力无脑并行计算流)做一下题主这个例子:

run.py

from multiprocessing import Pool, cpu_count


def do_align(item):
    with open("bg_db.txt") as fh, open("result.txt", "a") as rh:
        db_line = fh.readline().strip()
        while db_line:
            counts = 0
            for i in [(i, j) for (i, j) in zip(db_line, item)]:
                if i[0] != i[1]:
                    counts += 1
                if counts >= 4:
                    break
            if counts < 4:
                rh.write("{}\n".format(db_line))
            db_line = fh.readline().strip()


def main():
    pool = Pool(cpu_count())
    with open("candidates.txt") as fh:
        pool.map(do_align, map(str.strip, fh))
    pool.close()
    pool.join()


if __name__ == "__main__":
    main()

简单先生成点数据

import random
import string


def id_generator(size=8, chars=string.ascii_letters + string.digits):
    return ''.join(random.choice(chars) for _ in range(size))


with open("candidates.txt", "w") as fh:
    for i in range(10000):
        fh.write("{}\n".format(id_generator(20, "ATCG")))

with open("bg_db.txt", "w") as fh:
    for i in range(1000000):
        fh.write("{}\n".format(id_generator(20, "ATCG")))

嗯, 造了10000行的candidates.txt1000000行的bg_db.txt
运行看下时间:

$time python run.py

real    15m45.445s
user    1362m41.712s
sys     1m12.099s

题主实际的数据是千万行的candidates.txt和十亿行的bg_db.txt, 简单估算下时间
16*1000/(60*24) = 11.11
也就是11天, 这是我用的一个计算节点的配置

CPU Utilization:    1.0    0.0    98.9
user    sys    idle
Hardware
CPUs: 96 x 2.10 GHz
Memory (RAM): 1009.68 GB
Local Disk: Using 351.623 of 941.596 GB
Most Full Disk Partition: 53.5% used.

时间确实好长, 急需神优化

算法不是很了解 但是就经验来说 复杂的算法反而耗时更久 不如这种简单粗暴来的迅速

可以考虑下多线程和集群来处理数据

对了 还有汉明距离貌似可以算这个

同样没有使用算法,暴力解法,用c写的

在我的机器上(CPU: Core 2 Duo E7500, RAM: 4G, OS: Fedora 19),测试结果

candidates    bg.db        cost
10000    1000000    318758165微秒
500      1000000    14950302微秒 

如果换成题主的24核CPU,怎么也得有20倍的性能提升,然后再加上48台机器一起运算,5000W次运算为15s, 时间为
10000000 * 1000000000 / 500 / 1000000 * 15 / 20 / 48 / 3600 / 24 = 3.616898 天

#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <stdbool.h>
#include <sys/time.h>

#define START_CC(flag)  \
    struct timeval st_##flag##_beg; \
    gettimeofday(&st_##flag##_beg, 0)

#define STOP_CC(flag)   \
    struct timeval st_##flag##_end; \
    gettimeofday(&st_##flag##_end, 0)

#define PRINT_CC(flag)  \
    double cost_##flag = 0.0L; \
    cost_##flag = (double)(st_##flag##_end.tv_sec - st_##flag##_beg.tv_sec); \
    cost_##flag = cost_##flag * 1000000L + (double)(st_##flag##_end.tv_usec - st_##flag##_beg.tv_usec);    \
    printf(#flag" cost time %.6lf microsecond.\n", cost_##flag);

#define GENEORDER_CODE_LENGTH 20 + 1

typedef struct GeneOrder
{
    char code[GENEORDER_CODE_LENGTH];
}GeneOrder, *GeneOrderPtr;

typedef struct GOArray
{
    size_t          capacity;
    size_t          length;
    GeneOrderPtr    data;
}GOArray;

GOArray createGOAarray(size_t capacity)
{
    GOArray goa;

    goa.capacity    = capacity;
    goa.length      = 0;
    goa.data        = (GeneOrderPtr)malloc(capacity * sizeof(GeneOrder));

    return goa;
}

void destroyGOArray(GOArray* goa)
{
    if (goa->capacity > 0) {
        free(goa->data);
    }
}

bool readGOFile(char const *file, GOArray *goarray)
{
    FILE* fp = NULL;

    if ((fp = fopen(file, "r+")) == NULL) {
        return false;
    }

    char buff[64];

    while (fgets(buff, 64, fp) != NULL) {
        if (goarray->length < goarray->capacity) {
            memcpy(goarray->data[goarray->length].code,
                buff,
                GENEORDER_CODE_LENGTH * sizeof(char)
            );
            goarray->data[goarray->length].code[GENEORDER_CODE_LENGTH - 1] = '\0';
            goarray->length ++;
        } else {
            fclose(fp);
            return true;
        }
    }

    fclose(fp);
    return true;
}

int main(int argc, char* argv[])
{
    (void)argc;

    GOArray condgo  = createGOAarray(10000);
    GOArray bggo    = createGOAarray(1000000);

    printf("loading ...\n");

    START_CC(loading);
    if (!readGOFile(argv[1], &condgo) || !readGOFile(argv[2], &bggo)) {
        destroyGOArray(&condgo);
        destroyGOArray(&bggo);
        return -1;
    }
    STOP_CC(loading);
    PRINT_CC(loading);


    int count = 0;

    START_CC(compare);
    for (size_t i = 0;i < 500;i ++) {
        const GeneOrderPtr gop = condgo.data + i;
        for (size_t j = 0;j < bggo.length;j ++) {
            const GeneOrderPtr inner_gop = bggo.data + j;
            int inner_count = 0;

            for (size_t k = 0;k < 20;k ++) {
                if (gop->code[k] != inner_gop->code[k]) {
                    if (++inner_count > 4) {
                        break;
                    }
                }
            }

            if (inner_count <= 4) {
            #ifdef DBGPRINT
                printf("%d %s - %s\n", i, gop->code, inner_gop->code);
            #endif
                count++;
            }
        }
    }
    STOP_CC(compare);
    PRINT_CC(compare);

    printf("result = %d\n", count);

    destroyGOArray(&condgo);
    destroyGOArray(&bggo);

    return 0;
}

编译参数&运行

gcc -Wall -Wextra -o ccs main.c -std=c99 -Os && ./ccs candidate.list bg.db

如果改成多线程的话速度会更快一些,在我的机器改为2个线程简单使用500条candidates测试,速度可以提升到9040257微秒,线程增加到4个性能提升就不是很大了,但是较新的CPU都具有超线程技术,速度估计会更好一些。。。

#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <stdbool.h>
#include <sys/time.h>
#include <pthread.h>

#define START_CC(flag)              \
    struct timeval st_##flag##_beg; \
    gettimeofday(&st_##flag##_beg, 0)

#define STOP_CC(flag)               \
    struct timeval st_##flag##_end; \
    gettimeofday(&st_##flag##_end, 0)

#define PRINT_CC(flag)                                                                                  \
    double cost_##flag = 0.0L;                                                                          \
    cost_##flag = (double)(st_##flag##_end.tv_sec - st_##flag##_beg.tv_sec);                            \
    cost_##flag = cost_##flag * 1000000L + (double)(st_##flag##_end.tv_usec - st_##flag##_beg.tv_usec); \
    printf(#flag " cost time %.6lf microsecond.\n", cost_##flag);

#define GENEORDER_CODE_LENGTH 20 + 1

typedef struct GeneOrder {
    char code[GENEORDER_CODE_LENGTH];
} GeneOrder, *GeneOrderPtr;

typedef struct GOArray {
    size_t capacity;
    size_t length;
    GeneOrderPtr data;
} GOArray;

GOArray createGOAarray(size_t capacity)
{
    GOArray goa;

    goa.capacity = capacity;
    goa.length = 0;
    goa.data = (GeneOrderPtr)malloc(capacity * sizeof(GeneOrder));

    return goa;
}

void destroyGOArray(GOArray* goa)
{
    if (goa->capacity > 0) {
        free(goa->data);
    }
}

bool readGOFile(char const* file, GOArray* goarray)
{
    FILE* fp = NULL;

    if ((fp = fopen(file, "r+")) == NULL) {
        return false;
    }

    char buff[64];

    while (fgets(buff, 64, fp) != NULL) {
        if (goarray->length < goarray->capacity) {
            memcpy(goarray->data[goarray->length].code, buff,
                GENEORDER_CODE_LENGTH * sizeof(char));
            goarray->data[goarray->length].code[GENEORDER_CODE_LENGTH - 1] = '\0';
            goarray->length++;
        }
        else {
            fclose(fp);
            return true;
        }
    }

    fclose(fp);
    return true;
}

typedef struct ProcessST {
    GOArray* pcond;
    GOArray* pbg;
    size_t beg;
    size_t end; // [beg, end)
} ProcessST;

void* processThread(void* parg)
{
    ProcessST* ppst = (ProcessST*)parg;
    GOArray* pcond = ppst->pcond;
    GOArray* pbg = ppst->pbg;
    int count = 0;

    for (size_t i = ppst->beg; i < ppst->end; i++) {
        const GeneOrderPtr gop = pcond->data + i;
        for (size_t j = 0; j < pbg->length; j++) {
            const GeneOrderPtr inner_gop = pbg->data + j;
            int inner_count = 0;

            for (size_t k = 0; k < 20; k++) {
                if (gop->code[k] != inner_gop->code[k]) {
                    if (++inner_count > 4) {
                        break;
                    }
                }
            }

            if (inner_count <= 4) {
#ifdef DBGPRINT
                printf("%d %s - %s\n", i, gop->code, inner_gop->code);
#endif
                count++;
            }
        }
    }

    return (void*)count;
}

int main(int argc, char* argv[])
{
    (void)argc;

    GOArray condgo = createGOAarray(10000);
    GOArray bggo = createGOAarray(1000000);

    printf("loading ...\n");

    START_CC(loading);
    if (!readGOFile(argv[1], &condgo) || !readGOFile(argv[2], &bggo)) {
        destroyGOArray(&condgo);
        destroyGOArray(&bggo);
        return -1;
    }
    STOP_CC(loading);
    PRINT_CC(loading);

    size_t range[] = { 0, 250, 500 };
    pthread_t thr[2] = { 0 };
    ProcessST pst[2];

    START_CC(compare);

    for (size_t i = 0; i < 2; i++) {
        pst[i].pcond = &condgo;
        pst[i].pbg = &bggo;
        pst[i].beg = range[i];
        pst[i].end = range[i + 1];
        pthread_create(&thr[i], NULL, processThread, &pst[i]);
    }

    int count = 0;
    int* ret = NULL;

    for (size_t i = 0; i < 2; i++) {
        pthread_join(thr[i], (void**)&ret);
        count += (int)ret;
    }

    STOP_CC(compare);
    PRINT_CC(compare);

    printf("result = %d\n", count);

    destroyGOArray(&condgo);
    destroyGOArray(&bggo);

    return 0;
}

编译测试

gcc -Wall -Wextra -o ccs main.c -std=c99 -O3 -lpthread && ./ccs candidate.list bg.db

抱歉,今天看到还有人回复。
仔细看了一下问题,发现我以前以为只是匹配。
所以我提出用ac自动机。

但是题主是为了找到序列的差异。
这就是找两者的编辑距离。
wiki:编辑距离
wiki:来文斯坦距离

以前刷OJ的时候是使用DP(动态规划)来找一个字符串转换成另外一个字符串的最少编辑次数。

for(i:0->len)
    for(j:0->len)
        str1[i]==str2[j] ? cost=0 : cost=1
        dp[i,j]=min(
            dp[i-1, j  ] + 1,     // 刪除
            dp[i  , j-1] + 1,     // 插入
            dp[i-1, j-1] + cost   // 替換
        )

比如 :

str1    "abcdef"
str2    "acddff"

str2 转换为 str1

插入b 算一次
删除d 算一次
修改f 算一次

对于题主的ATCG基因序列来说,是不是只需要找到修改的就行了。
然而像这种
ATCGATCG
TCGATCGA
这样应该怎么算。

如果仅仅是找到修改的话,直接比较 str1[i] 和 str2[i] 就可以了。

for(i:0->len)
if(str1[i]!=str2[i] )++count;

受到@rockford 的启发。
我们可以对 原始数据 进行预处理。

candidates 中的串
GGGAGCAGGCAAGGACTCTG
A5 T2 C4 G9

进行处理之后的额外数据 A5 T2 C4 G9

bg_db 中的串
CTGCTGACGGGTGACACCCA
A4 T3 C7 G6

进行处理之后的额外数据 A4 T3 C7 G6

A5 -> A4 记作 -1
T2 -> T3 记作 +1
C4 -> C7 记作 +3
G9 -> G6 记作 -3

很明显 A 如果修改只能变成 TCG。
同理,我们只需要统计所有的+ 或者所有的 -
就可以知道他们的至少有多少不同之处。
大于4的都可以不进行比较。

通过先对比预处理的额外数据,然后再通过单次的比较算法来 进行比对。
(星期六加班-ing,下班后写一下)

你单个的任务是确定的,需要的是把这些任务下发给 worker 去做,对于这样的计算都不是同步单进程进行的。
其实就相当于你有[a,b] 和 [c, d] 要对比,你的任务是

  1. [a, c]

  2. [a, d]

  3. [b, c]

  4. [b, d]

如果你是同步串行你需要的时间就是 4 * 单个时间
如果是你 4 个 cpu 或者 4 个 机器并行, 你的时间差不多是单个时间

所以对于像基因组这样的计算基本上都是用大型机器多核并行的任务来完成,基本上参考的原理都是 google MapReduce 这篇论文的原理

算法我不行,但是,像你这样的大量数据,一台电脑对比肯定是不行的,像你这样数据CPU密集型任务,同意其他人说的使用集群或者多进程的方式来计算,也就是我们用map-reduce的模型去计算
map就是映射,你先将你每个candidates一个一个映射到bg_db形成类似这样的数据结构(candidates,bg_db)
做成队列然后交给不同的服务器,每个服务器用多进程去计算,只能这样了,但是你这个数据量太大了,想办法把你的任务分配好,并行计算吧,

可以尝试用字典树来保存所有的字符串。然后在查询的时候就可以用在字典树上遍历的方法。
在树上遍历的时候,可以维护一个当前节点的序列,这个序列里保存着当前遍历到的节点和对应节点mismatch的数量。
在遍历下一个节点的时候,要把当前序列里所有的节点都尝试向下,并形成新的节点序列。
好处是可以把很多个串的当前位放在一起进行比较,可以节约一些时间。由于每个位置选择不多,mismatch也不大,所有应该不会出现当前节点序列膨胀过大的情况。(这是猜想… 没太认真验证过…)

def match(candidate, root):
  nset = [[],[]]
  currentId = 0
  nset[currentId].append((root, 0))
  for ch in candidate:
    nextId = 1 - currentId
    for item in nset[currentId]:
      node, mismatch = item
      for nx in node.child:
        if nx.key == ch:
          nset[nextId].append((nx, mismatch))
        else:
          if mismatch:
            nset[nextId].append((nx, mismatch - 1))
    nset[currentId].clear()
    currentId = 1 - currentId
  return nset[currentId]

上面的代码就是一个大概的意思。如果用C++写的话会再快很多。
整个过程都可以用集群做分布式计算。

目测题主没有多台机器供他计算,
我有一个朴素的思路,计算每个串的字母序数和(A:0,T:1,C:2,G:3),先计算两个字符串的序数和的差值,最大不能超过12,四个A变成四个G,差值小于12的再进行处理,
只是一个大概的想法,具体的权值可以另外设置,理论上可以快很多。
另外有一个算法是计算字符串编辑距离(将一个字符串修改为另一个字符串的最少编辑次数增、删、改)的,我一下子想不起来,你可以自行查一下。

我用blast blastn-short

:)

HEYsir
  • 2
新手上路,请多包涵

基于@乌合之众 的思路,多用一倍空间用4个位表示四种字符,可以简化后面不一致字符个数的查找

思路如下:
第一:比如对比4个差异的数据
**TGACGGGTGACACCCA(删除字符串某4个位置的字符),将字符长度变为16,匹配完全相同的字符串
用map之类的保存TGACGGGTGACACCCA为key值,差异的四个作为values值

第二:对比3个差异的数据
在上述基础的上,进行对比上述的values值为比较长度为3完全相同的字符串

以此类推

可以了解一下CUDA等并行计算,你这种大量重复简单的运算性能提升非常明显

cyer
  • 4
新手上路,请多包涵

基本算法的复杂度是 O(M*N*c)
M =candidates数量
N = bg_db 数量

比较操作看做常数时间


优化:
    1 把20个字符分为4组 每组5个 ,每组可能的类型有 4^5=1024种 。把bg_db 先按照 第一个组 聚合 ,再按照第二个组聚合...构成一个4层的树结构
    2 建立 1024*1024 的表 表示 两个5个字符的串的差异数 内存 1MB
    2 匹配的时候,如果第一个组差异已经超过4个了,就不需要看它的子节点了,可以立刻排除大量节点

每组的长度也可以不确定,最好第一组有足够的区分性,这样可以一下子排除大量数据

提供一个思路,把四个字符简化成 00, 01, 10, 11。
比较的时候先执行 XOR,这样完全相同的就会变成 00;不完全相同的则是 01, 10或者 11。
然后再对XOR的结果每相邻的pair进行 OR,这样 00 会变成0, 01,10或者11就变成1。最后统计 1 的数量。由于都是位运算,理论上应该很快。

不过我的C学得渣,代码就不能贴出来了。

看了@乌合之众 的方法,我想了好久能不能再进一步。

今天看到一个关于 计算int的二进制串中1出现次数的问题。一个很巧妙的方法是通过 n&(n-1),其运算结果恰为把n的二进制位中的最低位的1变为0之后的结果.所以判断几次n&(n-1)后n变为0就可以判断1的个数。这个dna差异问题要是可以转化成找1的个数,那么只需要判断n&(n-1)四次后是否还不为0就可以了。

下边就是我根据乌合之众的方法,ATCG还是用两个比特位代表

A  ==>  00
T  ==>  01
C  ==>  10
G  ==>  11

这样的话,任意两个数做异或运算。如果两个数相同,则结果为00,两个不同的数做异或,结果中有1个或者2个1。

因为是两个二进制位为一组,假设用AB表示的话,通过左移和或运算,可以让两个1的变成一个1
大概过程是:

AB | (AB<<1) | & 0b10L
或
B
与
1 0 

pb = ((pb | (pb << 1)) & 0b001010101010101010101010101010101010101010L);

这样之后,二进制串中就只有00和10两种情况,1出现的次数就等于差异数。

一次pb &= pb - 1运算可以让最低位的1置0,所以只需要重复四次该过程,判断pb的值是否大于0即可。

    public final static boolean similar(long pb) {
        pb = (pb | (pb << 1)) & 0b001010101010101010101010101010101010101010L;
        pb &= pb - 1;
        pb &= pb - 1;
        pb &= pb - 1;
        pb &= pb - 1;
        return pb == 0 ? true : false;
    }

这里提一下,有个blsr的指令,等效于pb &= pb - 1操作。但还有个popcnt指令可以直接统计比特位中1的个数。

上边是一条一条计算的办法。其实可以通过simd加速这个计算过程。

int* similar(long long p, long long ms[], int length) {
    int* result = (int*)malloc(4 * length);
    int position = 1;
    __m256i mp = _mm256_set1_epi64x(p);
    __m256i vpshufbIndex = _mm256_setr_epi64x(0x0202020101010100L, 0x0202020102020201L, 0x0202020101010100L, 0x0202020102020201L);
    __m256i vpshufbMask = _mm256_set1_epi64x(0x0f0f0f0f0f0f0f0fL);
    __m256i zero = _mm256_set1_epi64x(0L);
    __m256i limit = _mm256_set1_epi64x(5);

    const  int length1 = length - 3;
    for (int i = 0; i < length1; i += 4)
    {
        __m256i  ymm1 = _mm256_load_si256((__m256i*) & ms[i]);
        //按位或运算
        ymm1 = _mm256_xor_si256(mp, ymm1);

        __m256i ymm2 = _mm256_and_si256(ymm1, vpshufbMask);
        ymm2 = _mm256_shuffle_epi8(vpshufbIndex, ymm2);

        ymm1 = _mm256_srli_epi16(ymm1, 4);
        ymm1 = _mm256_and_si256(ymm1, vpshufbMask);
        ymm1 = _mm256_shuffle_epi8(vpshufbIndex, ymm1);
        ymm1 = _mm256_add_epi8(ymm1, ymm2);

        ymm1 = _mm256_sad_epu8(zero, ymm1);

        ymm2 = _mm256_cmpgt_epi32(limit, ymm1);

        int match = _mm256_testz_ps(_mm256_castsi256_ps(ymm2), _mm256_castsi256_ps(ymm2));

        if (match == 0) {
            _int64* p = (_int64*)&ymm1;
            if (p[0] < 5) {
                result[position++] = i;
            }
            if (p[1] < 5) {
                result[position++] = i + 1;
            }
            if (p[2] < 5) {
                result[position++] = i + 2;
            }
            if (p[3] < 5) {
                result[position++] = i + 3;
            }
        }
    }
    result[0] = position - 1;
    return result;
}

然后测试了一下,在i7-10750H上,一个数和1千万数据做比较时。单线程需要5.2毫秒。

因为这个数据集有个特点,就是字符串相似度小于等于4的数据是很少的,1千万次比较可能就几十个数符合要求。基于这个特点,可以先计算部分字符串是否匹配。总的思路是,低32位保留了16个字符的信息,高32位保留的只有4个字符的信息,先计算低32位中字符串的差异是否小于等于4,如果小于等于3时再对另外4个字符做比较。

这样的话,低32位映射关系不变,高32位使用下边的映射关系。

A  ==>  0000
T  ==>  0010
C  ==>  0100
G  ==>  1000

因为avx2,一次可以对8个32位数做计算,所以加速效果还是很好的。

int* similar7(unsigned int p, unsigned int p2, unsigned int ms[], unsigned int ms2[], int length) {
    int* result = (int*)malloc(4 * length * 0.002);
    int position = 1;
    __m256i mp = _mm256_set1_epi32(p);
    __m256i vpshufbMask = _mm256_set1_epi64x(0x0f0f0f0f0f0f0f0fL);
    __m256i vpshufbIndex = _mm256_setr_epi64x(0x0202020101010100L, 0x0202020102020201L, 0x0202020101010100L, 0x0202020102020201L);
    __m256i zero = _mm256_set1_epi64x(0L);
    __m256i limit = _mm256_set1_epi64x(5);
    __m256i addMask1 = _mm256_set1_epi64x(0xFFFFFFFF00000000L);
    __m256i addMask2 = _mm256_set1_epi64x(0x00000000FFFFFFFFL);

    const  int length1 = length - 7;
    for (int i = 0; i < length1; i += 8)
    {
        __m256i  pbs = _mm256_load_si256((__m256i*) & ms[i]);
        //按位或运算
        __m256i pb1 = _mm256_xor_si256(mp, pbs);
        /**
        * 下边是计算每个数中1的个数,主要是在vpshufbIndex中放入每4字节对应的1个数,通过_mm256_shuffle_epi8之后
        *
        * pb1中是8个int32,现在是要统计每个int中1bit位的数量
        *
        * 思路是:
        * 1.将int32每4bit一组,放入1个字节的低四位
        * 2.4bit对应的1的数量总共有16种可能,所以用vpshufbIndex的[127:0]位中16个字节对应低四位值对应的1数量,[255:128]也一样
        * 3.因为是对m256计算,所以拆成了两次计算
        *
        */
        pbs = _mm256_and_si256(pb1, vpshufbMask);
        pbs = _mm256_shuffle_epi8(vpshufbIndex, pbs);
        pb1 = _mm256_srli_epi16(pb1, 4);
        pb1 = _mm256_and_si256(pb1, vpshufbMask);
        pb1 = _mm256_shuffle_epi8(vpshufbIndex, pb1);
        pb1 = _mm256_add_epi8(pb1, pbs);
        //4.因为每个数计算结果是分散在4个字节上的,为了把这四个字节的值相加,主要通过_mm256_sad_epu8指令可以将每8字节相加,但是因为每8字节对应的两个int,
        //  所以这里通过addMask1和addMask2把8个字节拆开
        pbs = _mm256_and_si256(pb1, addMask1);
        __m256i add2 = _mm256_and_si256(pb1, addMask2);
        //5. 计算8个字节的和就是一个数对应的1bit位的个数了
        pb1 = _mm256_sad_epu8(zero, pbs);
        add2 = _mm256_sad_epu8(zero, add2);
        //6.  得到1的数量后,和5比大小
        __m256i matchResult1 = _mm256_cmpgt_epi32(limit, pb1);
        __m256i matchResult2 = _mm256_cmpgt_epi32(limit, add2);
        //7. 判断是否全0,因为大部分结果都是1,所以很少走下边的if  _mm256_testz_ps是double类型的符号位判断,比整数的全0判断指令少一个周期,这里用正好
        int match1 = _mm256_testz_ps(_mm256_castsi256_ps(matchResult1), _mm256_castsi256_ps(matchResult1));

        //上边的一些列操作都是为了可以一次判断四个数是否都差异大于4,因为提前对数据集测试过,已经知道match等于1的几率比99.9%还大,这个分支基本不会走。
        if (match1 == 0) {  //如果确实存在差异小于4的情况,则退化成逐一比较,但好在几率低
            _int64* p = (_int64*)&pb1;
            if ((p[0] + (__popcnt(p2 ^ ms2[i + 1]) >> 1) < 5)) {
                result[position++] = i;
            }
            if ((p[1] + (__popcnt(p2 ^ ms2[i + 3]) >> 1)) < 5) {
                result[position++] = i + 2;
            }
            if ((p[2] + (__popcnt(p2 ^ ms2[i + 5]) >> 1)) < 5) {
                result[position++] = i + 4;
            }
            if ((p[3] + (__popcnt(p2 ^ ms2[i + 7]) >> 1)) < 5) {
                result[position++] = i + 6;
            }
        }
        int match2 = _mm256_testz_ps(_mm256_castsi256_ps(matchResult2), _mm256_castsi256_ps(matchResult2));
        if (match2 == 0) {
            _int64* p = (_int64*)&add2;
            if ((p[0] + (__popcnt(p2 ^ ms2[i]) >> 1)) < 5) {
                result[position++] = i + 1;
            }
            if ((p[1] + (__popcnt(p2 ^ ms2[i + 2]) >> 1)) < 5) {
                result[position++] = i + 3;
            }
            if ((p[2] + (__popcnt(p2 ^ ms2[i + 4]) >> 1)) < 5) {
                result[position++] = i + 5;
            }
            if ((p[3] + (__popcnt(p2 ^ ms2[i + 6]) >> 1)) < 5) {
                result[position++] = i + 7;
            }
        }
    }
    result[0] = position - 1;

    return result;
}

一个数和1千万数据做比较时。单线程需要3.4毫秒。

看到有说cuda的,这个了解过一点点,写了个测试用例,也不确定是不是最优秀,但是提升很明显了。

#ifndef __CUDACC__
#define __CUDACC__
#include "cuda_texture_types.h"
#endif

#include <cuda.h>
#include "cuda_runtime.h"
#include "device_launch_parameters.h"
#include <iostream>

#include <windows.h>

#include "kernel.cuh"


using namespace std;

#define M 1000000
#define N 200000000

__global__ void  similar(unsigned long long* pbs, unsigned long long* ms,int length, unsigned int* result)
{
    int position = blockDim.x * blockIdx.x + threadIdx.x;
    unsigned long long pb1 = pbs[position];
    int similar = 0;
    for (int i = 0; i < length; ++i)
    {
        unsigned long long pb = pb1 ^ ms[i];
        pb = (pb | (pb << 1)) & 0b001010101010101010101010101010101010101010L;
        pb = __popcll(pb);
        if (pb <= 4) {
            similar++;
        }
       // similar += (pb <= 4);
    }
    result[position] = similar;
}

void test(unsigned long long* host_pbs,const int h_length,unsigned long long* host_ms,int ms_length,unsigned int* host_result) {

    unsigned long long* device_pbs;
    unsigned long long* device_ms;
    unsigned int* device_result;

    cudaMalloc((void**)&device_pbs, sizeof(unsigned long long) * h_length);
    cudaMalloc((void**)&device_ms, sizeof(unsigned long long) * ms_length);
    cudaMalloc((void**)&device_result, sizeof(unsigned int) * ms_length);

    cudaMemcpy(device_pbs, host_pbs, sizeof(unsigned long long) * h_length, cudaMemcpyHostToDevice);
    cudaMemcpy(device_ms, host_ms, sizeof(unsigned long long) * ms_length, cudaMemcpyHostToDevice);

    similar << <30, 1024 >> > (device_pbs, device_ms, ms_length, device_result);
    cudaDeviceSynchronize();
    cudaMemcpy(host_result, device_result, sizeof(unsigned int) * ms_length, cudaMemcpyDeviceToHost);

}

int main(int argc, char* argv[]) {
    unsigned long long* host_pbs;
    unsigned long long* host_ms;
    unsigned int* host_result;

    host_pbs = (unsigned long long*)malloc( M * sizeof(unsigned long long));
    host_ms = (unsigned long long*)malloc(N * sizeof(unsigned long long));
    host_result = (unsigned int*)malloc(N * sizeof(unsigned int));


    for (int i = 0; i < M; ++i)
    {
        host_pbs[i] = rand();
    }
    for (int i = 0; i < N; ++i)
    {
        host_ms[i] = rand();
    }
    
    unsigned long long* device_pbs;
    unsigned long long* device_ms;
    unsigned int* device_result;

    cudaMalloc((void**)&device_pbs, sizeof(unsigned long long) * M);
    cudaMalloc((void**)&device_ms, sizeof(unsigned long long) * N);
    cudaMalloc((void**)&device_result, sizeof(unsigned int) * N);

    cudaMemcpy(device_pbs, host_pbs, sizeof(unsigned long long) * M, cudaMemcpyHostToDevice);
    cudaMemcpy(device_ms, host_ms, sizeof(unsigned long long) * N, cudaMemcpyHostToDevice);
  

    long t1 = GetTickCount();
    // 第一个参数的大小最好是sm数量的倍数  3060有30个,所以一般是30的倍数
    similar << <60,1024 >> > (device_pbs, device_ms,N,device_result);
    cudaDeviceSynchronize();
    long t2 = GetTickCount();
    std::cout << (t2 - t1) << std::endl;


    //cudaMemcpy(host_result, device_result, sizeof(unsigned int) * N, cudaMemcpyDeviceToHost);
}

GPU: 105瓦的 RTX3060 6G
测试结果是43469毫秒。2亿的数据量,比较61440次,用时43秒。平均0.708毫秒一次。
换算成10亿数据的话,也就是一次比较需要3.54毫秒。二百二十万次比较也就2个多小时。

宣传栏