也许你在数据科学/AI/机器学习的研究中头疼于大型数据加载与落盘的速度问题,毕竟IO过程是最磨人时间的。大家常调侃于python能优化的空间的不多,但事实上我们可以尽量地做到更好。希望本文对你的程序有点帮助。
本文的IO效率提升的探讨限定在数据科学领域内的以numpy.ndarray
为代表的大型数组(张量、矩阵)数据对象的IO问题上。解决问题的手段是以多线程/多进程为基础的并行写入/读取。同网络io和普通的小数据量的io问题不同,数据科学的大矩阵对象往往伴随着矩阵的切片等操作,他们对于内存的占用(是否复制、移动等)不明,更容易陷入内存冗余占用问题,这些都会影响io效率。本文探讨如下几个主题:
- 基于多线程/进程的并行读写方法及性能对比
- 并行IO中注意内存的冗余拷贝现象
- 最佳实践总结
IO情景
本文讨论的IO情景很简单,从磁盘上加载大数据进行处理,再将结果存储。这种情况常见于各类机器学习框架中,对数据的load和dump是最基本要解决的问题。下文中讨论的一些原理和技巧也在pytorch
、tensorflow
等的IO接口中体现。
在数据科学场景下,要优化读写的效率,可以从以下几个方向入手:
- 从文件编码格式入手,采用
pkl
等的二进制编码加速读写 - 从读写接口优化入手,采用DirectIO/零拷贝等的优化
- 分块、分批并行读写,适合数据相对独立情景
上述三种方法第一种操作简单,但编码的形式不方便与其他语言/工具兼容。第二种对于Python来讲有点小题大做,而且Python的IO接口不如静态语言那样显式,虽然也能直接采用os.open(CLONE_FLATS=...)
的最底层接口,但采用DirectIO
[4]或mmap
之类的优化都需要增加设计成本。第三种方法虽涉及多线程/进程,但不涉及通信与同步,实践相对简单。
多线程/多进程并行读写
并行基本逻辑
多进程导致的并行读写逻辑很简单,主要的开销在操作系统对进程的管理上。多线程对并行读写的理论支撑有必要再提一下(针对Cpython), 下图[1]所示的是GIL针对线程IO情景的处理。
上图也显示了多线程的主要开销是各个线程run
阶段的总和以及操作系统对线程的管理开销。
针对Cpython的多线程仍需要注意的是
- Linux下完全是POSIX-thread, 这意味着调度模式仍然是1:1的用户-内核映射关系
- Cpython多线程默认共享解释器中的全局变量
- 线程释放GIL的IO时机是进行底层基本的IO系统调用后
- 多线程关于调度通信使用信号量、条件变量等方法
标准库接口测评
我们设计一个小实验对CPython标准库提供的多线程/进程结果的并行写文件效率进行测试:
import os
import numpy as np
import time
from multiprocessing import Process
from multiprocessing import Pool
from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor
from threading import Thread
from memory_profiler import profile
# Time calculator
class Benchmark:
def __init__(self, text):
self.text = text
def __enter__(self):
self.start = time.time()
def __exit__(self, *args):
self.end = time.time()
print("%s: consume: %s" % (self.text, self.end - self.start))
# Base Task
def store_task(data: np.ndarray, output, index):
fname = "%s_worker_%s.csv" % (output, index)
np.savetxt(fname, data, delimiter='\t')
#main data source
worker_num = os.cpu_count()
big_data = np.random.rand(1000000, 10)
task_num = big_data.shape[0] // worker_num
# 1. multiprocessing.Porcess
@profile
def loop_mp():
pool = []
for i in range(worker_num):
start = i * task_num
end = (i+1) * task_num
p = Process(target=store_task, args=(big_data[start: end], 'testdata/', i))
p.start()
pool.append(p)
for p in pool:
p.join()
# 2. threading.Thread
@profile
def mt_thread():
pool = []
for i in range(worker_num):
start = i * task_num
end = (i+1) * task_num
t = Thread(target=store_task, args=(big_data[start: end], 'testdata/thread', i))
t.start()
pool.append(t)
for p in pool:
p.join()
# 3. multiprocessing.Pool
@profile
def mp_pool():
with Pool(processes=worker_num) as pool:
tasks = []
for i in range(worker_num):
start = i * task_num
end = (i+1) * task_num
tasks.append(
pool.apply_async(store_task_inner, (big_data[start: end], 'testdata/mp_pool', i)))
pool.close()
pool.join()
# 4. ProcessPoolExecutor
@profile
def loop_pool():
with ProcessPoolExecutor(max_workers=worker_num) as exe:
for i in range(worker_num):
start = i * task_num
end = (i+1) * task_num
exe.submit(store_task, big_data[start: end], 'testdata/pool', i)
# 5. ThreadPoolExecutor
def loop_thread():
with ThreadPoolExecutor(max_workers=worker_num) as exe:
for i in range(worker_num):
start = i * task_num
end = (i+1) * task_num
exe.submit(store_task, big_data[start: end], 'testdata/pool_thread', i)
# 6. direct
@profile
def direct():
store_task(big_data, 'testdata/all', 0)
if __name__ == '__main__':
with Benchmark("loop mp"):
loop_mp()
with Benchmark("mt thread"):
mt_thread()
with Benchmark("mp pool"):
mp_pool()
with Benchmark("loop pool"):
loop_pool()
with Benchmark("direct"):
direct()
with Benchmark("Thread"):
loop_thread()
从时间消耗和内存上分析下各个接口的效率(测试环境MacOS 2.2 GHz 四核Intel Core i7
):
接口 | 耗时 | 内存 |
---|---|---|
multiprocessing.Process | 5.14s | p.start() 产生额外开销,触发参数的复制 |
theading.Thread | 10.34s | 无额外开销 |
multiprocessing.Pool | 4.18s | Pool() 构建额外开销, 参数未发生复制 |
ProcessPoolExecutor | 3.69s | 参数未发生复制 |
ThreadPoolExecutor | 10.82s | 无额外开销 |
direct | 22.04s | 无额外开销 |
时间开销分析
直观上看,多进程的接口加速了4-4.5x, 多线程加速了一半的时间。多线程比多进程要慢的原因比较复杂,原则上切换的开销线程要小于进程,但此例中多线程还涉及到线程间调度上的通信,而多进程则独立运行。当然有兴趣的朋友也可以选择asyncio.tasks
基于多路复用的接口对比下,缺点是比较难找到适合的非阻塞读写接口。
值得注意的是,多进程的两个接口的速度也有很大差别,Process
的模式比线程池的要慢很多,原因可能是数据拷贝的开销。下节讨论池技术为何避免了数据的拷贝。
内存开销分析
由于CPython的数据类型的限制,对于多线程threading
和多进程multiprocessing
的数据是否复制不能显式地展现,从原理上讲Thread()
是无需拷贝数据的,Process
是需要拷贝数据的。然而上表中显示multiprcocessing.Pool
和ProcessPoolExecutor
这两个基于线程池的方法未发生数据的拷贝。
代码中的@profile
是一个内存分析的三方库,但他的结果也不能充分说明本质。
其中Process
的结果是
Line # Mem usage Increment Occurences Line Contents
============================================================
29 101.3 MiB 101.3 MiB 1 @profile
30 def loop_mp():
31 101.3 MiB 0.0 MiB 1 pool = []
32 120.6 MiB 0.0 MiB 9 for i in range(worker_num):
33 120.6 MiB 0.0 MiB 8 start = i * task_num
34 120.6 MiB 0.0 MiB 8 end = (i+1) * task_num
35 120.6 MiB 0.0 MiB 8 p = Process(target=store_task, args=(big_data[start: end], 'testdata/', i))
36 120.6 MiB 19.3 MiB 8 p.start()
37 120.6 MiB 0.0 MiB 8 pool.append(p)
38 120.6 MiB 0.0 MiB 9 for p in pool:
39 120.6 MiB 0.0 MiB 8 p.join()
明显可以看出 p.start()
发生了数据的拷贝,拷贝的就是big_data[start: end]
实际大小。这与fork
系统调用差别很大,系统调用要明确地传入CLONE_FLAGS
来约定子进程与父进程的数据拷贝情况。再来看ProcessPoolExecutor
Line # Mem usage Increment Occurences Line Contents
============================================================
68 121.1 MiB 121.1 MiB 1 @profile
69 def loop_pool():
70 121.1 MiB 0.0 MiB 1 with ProcessPoolExecutor(max_workers=worker_num) as exe:
71 121.2 MiB -0.0 MiB 9 for i in range(worker_num):
72 121.2 MiB 0.0 MiB 8 start = i * task_num
73 121.2 MiB 0.0 MiB 8 end = (i+1) * task_num
74 121.2 MiB 0.1 MiB 8 exe.submit(store_task, big_data[start: end], 'testdata/pool', i)
表面上看没有发生拷贝,但事实如此吗?因为exe.submit
毕竟不是直接触发了Process()
的构建,想弄明白这个问题还得深究Pool
技术的原理。
关于Cpython的源码解析,已经不少Pythonista做了大量工作。从[2]的参考看到ProcessPoolExecutor
的封装逻辑是
|======================= In-process =====================|== Out-of-process ==|
+----------+ +----------+ +--------+ +-----------+ +---------+
| | => | Work Ids | => | | => | Call Q | => | |
| | +----------+ | | +-----------+ | |
| | | ... | | | | ... | | |
| | | 6 | | | | 5, call() | | |
| | | 7 | | | | ... | | |
| Process | | ... | | Local | +-----------+ | Process |
| Pool | +----------+ | Worker | | #1..n |
| Executor | | Thread | | |
| | +----------- + | | +-----------+ | |
| | <=> | Work Items | <=> | | <= | Result Q | <= | |
| | +------------+ | | +-----------+ | |
| | | 6: call() | | | | ... | | |
| | | future | | | | 4, result | | |
| | | ... | | | | 3, except | | |
+
−+----------+ +------------+ +--------+ +-----------+ +---------+
这个流程是否似曾相识?没错,他与之前文章[[C++造轮子] 基于pthread线程池](http://zhikai.pro/post/103)中...:
- 使用队列维护任务task
- Pool伴随着空进程的创建
- 有专门的管理线程来负责Pool的管理与监控
那么具体到参数数据拷贝上便是Queue.put()
与Queue.get()
的操作是否发生数据拷贝了。multiporcessing.Queue
是多进程通信的一种重要接口,他是基于共享内存的,参数数据的传递不发生拷贝,这对于大的ndarray
对象而言是极其重要的。
ndarray
的对象拷贝
Python世界里一切皆对象。 -- Py圈名言
面对企业级大数据时,Python程序出现的内存/显存占用率过高往往不是那么容易查明原因。动态引用类型+gc给python的内存管理带来了方便,但不必要的数据拷贝发生情景还是要尽量避免。
切片与组合
切片和组合是在以numpy
为代表的向量/矩阵/张量运算库的常用操作,他们底层是否发生复制很难分析:
import numpy as np
A = np.random.rand(1 << 8, 1 << 8)
B = A[:16]
del A ## can not release A's mem, for B's reference
print(A) ## error, the ref A has not exist yet,however its mem still exist
C = np.random.rand(1 << 4, 1 << 8)
D = np.concatenate([B, C], axis=1) ## D is a copy of B+C memory
对于concatenate
主要看内存分布决定是否发生复制[6]:
00 04 08 0C 10 14 18 1C 20 24 28 2C
| | | | | | | | | | | |
[data1 ][foo ][data2 ][bar ][concat(data1, data2) ]
data1 & data2 displayed in different place, concat them can only cover a new place.
切片同样是看内存分布,基于row和column的内存排列是不同的,具体的可以使用order=['C', 'F']
决定数组是按行在内存排列还是按列。[7] 还有一种办法是探究切片最终能否转换成slice(start, offset, stride)
的形式,如是则为view, 不能则大概率是copy, 例如诸多的fancy_index
形式都是copy, [:]
其实就是slice(None, None, None)
,它也是copy.[8]
切片到底是view还是copy在小数据量时无需care,但数据规模达到与内存上限时,大型的ndarray切片操作一定要小心了.
进程创建时的复制
我们希望把数据切片后传递给子进程, 同时我们希望这份数据不发生复制,各个进程共享这一大型ndarray
。首先从上一章明确的是,采用multiprocessing.Process(target=func, args=(ndarray[start:offset]))
创建子进程的方式是一定会复制ndarray的。其实这里主要用到的技术是multiprocessing
的共享内存方法。
Python3.8之后新增加了shared_memeory
, 给之前各种共享内存的方式做了一个统一的简易使用接口。我们使用share_memory改造一下上节的代码:
from multiprocessing import shared_memory
def store_task_sha_v2(start, end, output, index, sha_name, shape, dtype):
fname = "%s_worker_%s.csv" % (output, index)
exist_sham = shared_memory.SharedMemory(name=sha_name)
data = np.ndarray(shape, dtype=dtype, buffer=exist_sham.buf)
print(sha_name, data.shape, index)
np.savetxt(fname, data[start: end], delimiter='\t')
del data
exist_sham.close()
@profile
def mp_pool_sha():
shm = shared_memory.SharedMemory(create=True, size=big_data.nbytes)
b = np.ndarray(big_data.shape, dtype=big_data.dtype, buffer=shm.buf)
b[:] = big_data[:]
print(b.shape)
with ProcessPoolExecutor(max_workers=worker_num) as pool:
tasks = []
for i in range(worker_num):
start = i * task_num
end = (i+1) * task_num
tasks.append(
pool.submit(store_task_sha_v2,
start, end, 'testdata/mp_pool_sha', i ,
shm.name, b.shape, b.dtype))
for t in tasks:
# Note! 在这里捕获异常,ProcessPoolExecutor推荐这么使用!
try:
print(t.result())
except Exception as e:
print(f'{e}')
del b
shm.close()
shm.unlink()
代码复杂了不少,但逻辑很简单: 共享缓冲区申请->映射local-ndarray对象->放数据进入共享缓存区->其他进程读写->关闭缓存区。share_memeory
的好处还有他可以随时申请local-variable进行共享。
最佳实践总结
并行读文件加载ndarray
加入你的训练数据很大,需要流处理(训练),直接使用torch.datasets
等模块加载,他们封装好了并行流处理过程。
如果需要一次性载入RAM处理(如KNN等算法)则可以采用分块并行读:
def parallize_load(file, total_num, worker_num):
"""Load embedding file parallelization
@emb_file: source filename
@total_num: total lines
@worker_num: parallelize process num
return: np.ndaary
"""
def load_from_txt(emb, start, n_rows, arr_list):
data = np.loadtxt(emb, skiprows=start, max_rows=n_rows)
arr_list.append(data)
worker_load_num = total_num // worker_num
pool = []
with Manager() as manager:
arr_list = manager.list([])
for index in range(worker_num):
s = index * worker_load_num
if index != worker_num - 1:
e = worker_load_num
else:
e = total_num - (worker_load_num * index)
p = Process(target=load_from_txt, args=(emb_file, s, e, arr_list))
pool.append(p)
p.start()
for p in pool:
p.join()
arr = np.concatenate(arr_list)
return arr
source_total_num = sum(1 for line in open("souce_big_file", "rb"))
source_emb_data = parallize_load("souce_big_file", source_total_num, worker_num)
这基本上是worker_num
X 倍的加速。
并行写入实践
- 尽量避免对large-ndarray对象的切片、组合操作。
- 尽量避免使用
for-loop
, 多用矩阵运算 - 写入文件多进程效率更高,逻辑更简洁,但要时刻注意进程间数据不要发生复制
- 尽可能采用三方库的io接口如
np.savetxt
,df.to_csv
等,他们可能对异常、分chunk写入等方面都有优化 - 写入字符串时,能尽量地拼接
'\t'.join(List[])
, 就不要使用for ele in List: fp.write("%s\t%s\n" % (ele))
More work
本文讨论的对象只局限于host-device
的RAM和disk, 对于更常见的GPU-mem,对于Python诸多三方库的接口来讲可就太痛苦了,他们往往都省略了分配-申请-调度-通信-销毁的过程,出现OOM异常后排查只能靠指标观察。于此,接下来可以继续研究下显存的最佳实践。
最后,也许本文的内容会让你很诧异,因为对Python做优化是一件出力不讨好的事情。但不得不说这些办法在我目前的工作中,在一定资源的constrain下解决了原程序的很多问题。当然目前主流的机器学习算法流程都基于流处理,一次性地过大占用很少出现了,但也有存在embedding读写等需要用到手动读写的地方。
**粗体** _斜体_ [链接](http://example.com) `代码` - 列表 > 引用
。你还可以使用@
来通知其他用户。