matplotlib画图越来越慢?

新手上路,请多包涵
import matplotlib.pyplot as plt
%matplotlib inline
for i in range(100000):
    print('\r %d' %i, end = '')
    for j in range(100):
        plt.scatter(mx[i], WX_b[i][j])
     

循环越多每轮越慢,外层循环几百次就慢的不行了,是怎么回事?如何改进?

阅读 11.5k
1 个回答

尽少调用 plt.scatter 方法便可大幅提升性能.

详解
假设 WX_b 为 M N 矩阵, mx 为 M 1 矩阵, 下面代码

for i in range(WX_b.shape[0]):
    for j in range(WX_b.shape[1]):
        plt.scatter(mx[i], WX_b[i][j])

可以优化成

plt.scatter(mx.repeat(WX_b.shape[1], axis=1), WX_b)

jupyter 示例代码

%matplotlib inline

import matplotlib.pyplot as plt
import numpy as np

WX_b = np.random.randn(30, 5)
mx = np.random.randn(WX_b.shape[0], 1)

def func1():
    for i in range(WX_b.shape[0]):
        for j in range(WX_b.shape[1]):
            plt.scatter(mx[i], WX_b[i][j])
            
def func2():
    plt.scatter(mx.repeat(WX_b.shape[1], axis=1), WX_b)
    
%time func1()
%time func2()

参考结果: func2 运行时间大约是 func1 的 5%.

撰写回答
你尚未登录,登录后可以
  • 和开发者交流问题的细节
  • 关注并接收问题和回答的更新提醒
  • 参与内容的编辑和改进,让解决方法与时俱进