原文地址-我的博客

也许你在数据科学/AI/机器学习的研究中头疼于大型数据加载与落盘的速度问题,毕竟IO过程是最磨人时间的。大家常调侃于python能优化的空间的不多,但事实上我们可以尽量地做到更好。希望本文对你的程序有点帮助。

本文的IO效率提升的探讨限定在数据科学领域内的以numpy.ndarray为代表的大型数组(张量、矩阵)数据对象的IO问题上。解决问题的手段是以多线程/多进程为基础的并行写入/读取。同网络io和普通的小数据量的io问题不同,数据科学的大矩阵对象往往伴随着矩阵的切片等操作,他们对于内存的占用(是否复制、移动等)不明,更容易陷入内存冗余占用问题,这些都会影响io效率。本文探讨如下几个主题:

  • 基于多线程/进程的并行读写方法及性能对比
  • 并行IO中注意内存的冗余拷贝现象
  • 最佳实践总结

IO情景

本文讨论的IO情景很简单,从磁盘上加载大数据进行处理,再将结果存储。这种情况常见于各类机器学习框架中,对数据的load和dump是最基本要解决的问题。下文中讨论的一些原理和技巧也在pytorchtensorflow等的IO接口中体现。

在数据科学场景下,要优化读写的效率,可以从以下几个方向入手:

  • 从文件编码格式入手,采用pkl等的二进制编码加速读写
  • 从读写接口优化入手,采用DirectIO/零拷贝等的优化
  • 分块、分批并行读写,适合数据相对独立情景

上述三种方法第一种操作简单,但编码的形式不方便与其他语言/工具兼容。第二种对于Python来讲有点小题大做,而且Python的IO接口不如静态语言那样显式,虽然也能直接采用os.open(CLONE_FLATS=...)的最底层接口,但采用DirectIO[4]或mmap之类的优化都需要增加设计成本。第三种方法虽涉及多线程/进程,但不涉及通信与同步,实践相对简单。

多线程/多进程并行读写

并行基本逻辑

多进程导致的并行读写逻辑很简单,主要的开销在操作系统对进程的管理上。多线程对并行读写的理论支撑有必要再提一下(针对Cpython), 下图[1]所示的是GIL针对线程IO情景的处理。

上图也显示了多线程的主要开销是各个线程run阶段的总和以及操作系统对线程的管理开销。

针对Cpython的多线程仍需要注意的是

  • Linux下完全是POSIX-thread, 这意味着调度模式仍然是1:1的用户-内核映射关系
  • Cpython多线程默认共享解释器中的全局变量
  • 线程释放GIL的IO时机是进行底层基本的IO系统调用后
  • 多线程关于调度通信使用信号量、条件变量等方法

标准库接口测评

我们设计一个小实验对CPython标准库提供的多线程/进程结果的并行写文件效率进行测试:

import os
import numpy as np
import time
from multiprocessing import Process
from multiprocessing import Pool
from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor
from threading import Thread
from memory_profiler import profile
# Time calculator
class Benchmark:
    def __init__(self, text):
        self.text = text
    def __enter__(self):
        self.start = time.time()
    def __exit__(self, *args):
        self.end = time.time()
        print("%s: consume: %s" % (self.text, self.end - self.start))

# Base Task
def store_task(data: np.ndarray, output, index):
    fname = "%s_worker_%s.csv" % (output, index)
    np.savetxt(fname, data, delimiter='\t')

#main data source
worker_num = os.cpu_count()
big_data = np.random.rand(1000000, 10)
task_num = big_data.shape[0] // worker_num

# 1. multiprocessing.Porcess
@profile
def loop_mp():
    pool = []
    for i in range(worker_num):
        start = i * task_num
        end = (i+1) * task_num
        p = Process(target=store_task, args=(big_data[start: end], 'testdata/', i))
        p.start()
        pool.append(p)
    for p in pool:
        p.join()

# 2. threading.Thread
@profile
def mt_thread():
    pool = []
    for i in range(worker_num):
        start = i * task_num
        end = (i+1) * task_num
        t = Thread(target=store_task, args=(big_data[start: end], 'testdata/thread', i))
        t.start()
        pool.append(t)
    for p in pool:
        p.join()

# 3. multiprocessing.Pool
@profile
def mp_pool():
    with Pool(processes=worker_num) as pool:
        tasks = []
        for i in range(worker_num):
            start = i * task_num
            end = (i+1) * task_num
            tasks.append(
                pool.apply_async(store_task_inner, (big_data[start: end], 'testdata/mp_pool', i)))
        pool.close()
        pool.join()

# 4. ProcessPoolExecutor
@profile
def loop_pool():
    with ProcessPoolExecutor(max_workers=worker_num) as exe:
        for i in range(worker_num):
            start = i * task_num
            end = (i+1) * task_num
            exe.submit(store_task, big_data[start: end], 'testdata/pool', i)

# 5. ThreadPoolExecutor
def loop_thread():
    with ThreadPoolExecutor(max_workers=worker_num) as exe:
        for i in range(worker_num):
            start = i * task_num
            end = (i+1) * task_num
            exe.submit(store_task, big_data[start: end], 'testdata/pool_thread', i)

# 6.  direct
@profile
def direct():
    store_task(big_data, 'testdata/all', 0)

if __name__ == '__main__':
    with Benchmark("loop mp"):
        loop_mp()
    with Benchmark("mt thread"):
        mt_thread()
    with Benchmark("mp pool"):
        mp_pool()
    with Benchmark("loop pool"):
        loop_pool()
    with Benchmark("direct"):
        direct()
    with Benchmark("Thread"):
        loop_thread()

从时间消耗和内存上分析下各个接口的效率(测试环境MacOS 2.2 GHz 四核Intel Core i7):

接口耗时内存
multiprocessing.Process5.14sp.start()产生额外开销,触发参数的复制
theading.Thread10.34s无额外开销
multiprocessing.Pool4.18sPool()构建额外开销, 参数未发生复制
ProcessPoolExecutor3.69s参数未发生复制
ThreadPoolExecutor10.82s无额外开销
direct22.04s无额外开销

时间开销分析

直观上看,多进程的接口加速了4-4.5x, 多线程加速了一半的时间。多线程比多进程要慢的原因比较复杂,原则上切换的开销线程要小于进程,但此例中多线程还涉及到线程间调度上的通信,而多进程则独立运行。当然有兴趣的朋友也可以选择asyncio.tasks基于多路复用的接口对比下,缺点是比较难找到适合的非阻塞读写接口。

值得注意的是,多进程的两个接口的速度也有很大差别,Process的模式比线程池的要慢很多,原因可能是数据拷贝的开销。下节讨论池技术为何避免了数据的拷贝。

内存开销分析

由于CPython的数据类型的限制,对于多线程threading和多进程multiprocessing的数据是否复制不能显式地展现,从原理上讲Thread()是无需拷贝数据的,Process是需要拷贝数据的。然而上表中显示multiprcocessing.PoolProcessPoolExecutor这两个基于线程池的方法未发生数据的拷贝。

代码中的@profile是一个内存分析的三方库,但他的结果也不能充分说明本质。
其中Process的结果是

Line #    Mem usage    Increment  Occurences   Line Contents
============================================================
    29    101.3 MiB    101.3 MiB           1   @profile
    30                                         def loop_mp():
    31    101.3 MiB      0.0 MiB           1       pool = []
    32    120.6 MiB      0.0 MiB           9       for i in range(worker_num):
    33    120.6 MiB      0.0 MiB           8           start = i * task_num
    34    120.6 MiB      0.0 MiB           8           end = (i+1) * task_num
    35    120.6 MiB      0.0 MiB           8           p = Process(target=store_task, args=(big_data[start: end], 'testdata/', i))
    36    120.6 MiB     19.3 MiB           8           p.start()
    37    120.6 MiB      0.0 MiB           8           pool.append(p)
    38    120.6 MiB      0.0 MiB           9       for p in pool:
    39    120.6 MiB      0.0 MiB           8           p.join()

明显可以看出 p.start()发生了数据的拷贝,拷贝的就是big_data[start: end]实际大小。这与fork系统调用差别很大,系统调用要明确地传入CLONE_FLAGS来约定子进程与父进程的数据拷贝情况。再来看ProcessPoolExecutor

Line #    Mem usage    Increment  Occurences   Line Contents
============================================================
    68    121.1 MiB    121.1 MiB           1   @profile
    69                                         def loop_pool():
    70    121.1 MiB      0.0 MiB           1       with ProcessPoolExecutor(max_workers=worker_num) as exe:
    71    121.2 MiB     -0.0 MiB           9           for i in range(worker_num):
    72    121.2 MiB      0.0 MiB           8               start = i * task_num
    73    121.2 MiB      0.0 MiB           8               end = (i+1) * task_num
    74    121.2 MiB      0.1 MiB           8               exe.submit(store_task, big_data[start: end], 'testdata/pool', i)

表面上看没有发生拷贝,但事实如此吗?因为exe.submit毕竟不是直接触发了Process()的构建,想弄明白这个问题还得深究Pool技术的原理。

关于Cpython的源码解析,已经不少Pythonista做了大量工作。从[2]的参考看到ProcessPoolExecutor的封装逻辑是

|======================= In-process =====================|== Out-of-process ==|

+----------+     +----------+       +--------+     +-----------+    +---------+
|          |  => | Work Ids |    => |        |  => | Call Q    | => |         |
|          |     +----------+       |        |     +-----------+    |         |
|          |     | ...      |       |        |     | ...       |    |         |
|          |     | 6        |       |        |     | 5, call() |    |         |
|          |     | 7        |       |        |     | ...       |    |         |
| Process  |     | ...      |       | Local  |     +-----------+    | Process |
|  Pool    |     +----------+       | Worker |                      |  #1..n  |
| Executor |                        | Thread |                      |         |
|          |     +----------- +     |        |     +-----------+    |         |
|          | <=> | Work Items | <=> |        | <=  | Result Q  | <= |         |
|          |     +------------+     |        |     +-----------+    |         |
|          |     | 6: call()  |     |        |     | ...       |    |         |
|          |     |    future  |     |        |     | 4, result |    |         |
|          |     | ...        |     |        |     | 3, except |    |         |
+
−+----------+     +------------+     +--------+     +-----------+    +---------+

这个流程是否似曾相识?没错,他与之前文章[[C++造轮子] 基于pthread线程池](http://zhikai.pro/post/103)中...:

  • 使用队列维护任务task
  • Pool伴随着空进程的创建
  • 有专门的管理线程来负责Pool的管理与监控

那么具体到参数数据拷贝上便是Queue.put()Queue.get()的操作是否发生数据拷贝了。multiporcessing.Queue是多进程通信的一种重要接口,他是基于共享内存的,参数数据的传递不发生拷贝,这对于大的ndarray对象而言是极其重要的。

ndarray的对象拷贝

Python世界里一切皆对象。 -- Py圈名言

面对企业级大数据时,Python程序出现的内存/显存占用率过高往往不是那么容易查明原因。动态引用类型+gc给python的内存管理带来了方便,但不必要的数据拷贝发生情景还是要尽量避免。

切片与组合

切片和组合是在以numpy为代表的向量/矩阵/张量运算库的常用操作,他们底层是否发生复制很难分析:

import numpy as np
A = np.random.rand(1 << 8, 1 << 8)
B = A[:16]
del A  ## can not release A's mem, for B's reference
print(A)  ## error, the ref A has not exist yet,however its mem still exist
C = np.random.rand(1 << 4, 1 << 8)
D = np.concatenate([B, C], axis=1) ## D is a copy of B+C memory

对于concatenate主要看内存分布决定是否发生复制[6]:

00    04    08    0C    10    14    18    1C    20    24    28    2C
|     |     |     |     |     |     |     |     |     |     |     |
[data1     ][foo ][data2     ][bar ][concat(data1, data2)  ]

data1 & data2 displayed in different place, concat them can only cover a new place.

切片同样是看内存分布,基于row和column的内存排列是不同的,具体的可以使用order=['C', 'F']决定数组是按行在内存排列还是按列。[7] 还有一种办法是探究切片最终能否转换成slice(start, offset, stride)的形式,如是则为view, 不能则大概率是copy, 例如诸多的fancy_index形式都是copy, [:] 其实就是slice(None, None, None),它也是copy.[8]

切片到底是view还是copy在小数据量时无需care,但数据规模达到与内存上限时,大型的ndarray切片操作一定要小心了.

进程创建时的复制

我们希望把数据切片后传递给子进程, 同时我们希望这份数据不发生复制,各个进程共享这一大型ndarray。首先从上一章明确的是,采用multiprocessing.Process(target=func, args=(ndarray[start:offset]))创建子进程的方式是一定会复制ndarray的。其实这里主要用到的技术是multiprocessing的共享内存方法。

Python3.8之后新增加了shared_memeory, 给之前各种共享内存的方式做了一个统一的简易使用接口。我们使用share_memory改造一下上节的代码:

from multiprocessing import shared_memory
def store_task_sha_v2(start, end, output, index, sha_name, shape, dtype):
    fname = "%s_worker_%s.csv" % (output, index)
    exist_sham = shared_memory.SharedMemory(name=sha_name)
    data = np.ndarray(shape, dtype=dtype, buffer=exist_sham.buf)
    print(sha_name, data.shape, index)
    np.savetxt(fname, data[start: end], delimiter='\t')
    del data
    exist_sham.close()

@profile
def mp_pool_sha():
    shm = shared_memory.SharedMemory(create=True, size=big_data.nbytes)
    b = np.ndarray(big_data.shape, dtype=big_data.dtype, buffer=shm.buf)
    b[:] = big_data[:]
    print(b.shape)
    with ProcessPoolExecutor(max_workers=worker_num) as pool:
        tasks = []
        for i in range(worker_num):
            start = i * task_num
            end = (i+1) * task_num
            tasks.append(
                pool.submit(store_task_sha_v2, 
                    start, end, 'testdata/mp_pool_sha', i ,
                    shm.name, b.shape, b.dtype))
        for t in tasks:
            # Note! 在这里捕获异常,ProcessPoolExecutor推荐这么使用!
            try:
                print(t.result())
            except Exception as e:
                print(f'{e}')
    del b
    shm.close()
    shm.unlink() 

代码复杂了不少,但逻辑很简单: 共享缓冲区申请->映射local-ndarray对象->放数据进入共享缓存区->其他进程读写->关闭缓存区。share_memeory的好处还有他可以随时申请local-variable进行共享。

最佳实践总结

并行读文件加载ndarray

加入你的训练数据很大,需要流处理(训练),直接使用torch.datasets等模块加载,他们封装好了并行流处理过程。

如果需要一次性载入RAM处理(如KNN等算法)则可以采用分块并行读:

def parallize_load(file, total_num, worker_num):
    """Load embedding file parallelization
       @emb_file: source filename
       @total_num: total lines
       @worker_num: parallelize process num
    return: np.ndaary
    """
    def load_from_txt(emb, start, n_rows, arr_list):
        data = np.loadtxt(emb, skiprows=start, max_rows=n_rows)
        arr_list.append(data)

    worker_load_num = total_num // worker_num
    pool = []
    with Manager() as manager:
        arr_list = manager.list([])
        for index in range(worker_num):
            s = index * worker_load_num
            if index != worker_num - 1:
                e = worker_load_num
            else:
                e = total_num - (worker_load_num * index)
            p = Process(target=load_from_txt, args=(emb_file, s, e, arr_list))
            pool.append(p)
            p.start()
        for p in pool:
            p.join()
        arr = np.concatenate(arr_list)
    return arr
source_total_num = sum(1 for line in open("souce_big_file", "rb"))
source_emb_data = parallize_load("souce_big_file", source_total_num, worker_num)

这基本上是worker_numX 倍的加速。

并行写入实践

  • 尽量避免对large-ndarray对象的切片、组合操作。
  • 尽量避免使用for-loop, 多用矩阵运算
  • 写入文件多进程效率更高,逻辑更简洁,但要时刻注意进程间数据不要发生复制
  • 尽可能采用三方库的io接口如np.savetxt,df.to_csv等,他们可能对异常、分chunk写入等方面都有优化
  • 写入字符串时,能尽量地拼接'\t'.join(List[]), 就不要使用for ele in List: fp.write("%s\t%s\n" % (ele))

More work

本文讨论的对象只局限于host-device的RAM和disk, 对于更常见的GPU-mem,对于Python诸多三方库的接口来讲可就太痛苦了,他们往往都省略了分配-申请-调度-通信-销毁的过程,出现OOM异常后排查只能靠指标观察。于此,接下来可以继续研究下显存的最佳实践。

最后,也许本文的内容会让你很诧异,因为对Python做优化是一件出力不讨好的事情。但不得不说这些办法在我目前的工作中,在一定资源的constrain下解决了原程序的很多问题。当然目前主流的机器学习算法流程都基于流处理,一次性地过大占用很少出现了,但也有存在embedding读写等需要用到手动读写的地方。

zhikai
42 声望7 粉丝

infra engineer & 统计与机器学习