如何删除 matplotlib.pyplot 中子图之间的空格?

新手上路,请多包涵

我正在做一个项目,我需要将 10 行和 3 列的绘图网格放在一起。虽然我已经能够制作情节并安排子情节,但我无法制作出没有空白的漂亮情节,例如下面来自 gridspec documentatation 的情节。 没有空白的图像 .

我尝试了以下帖子,但仍然无法像示例图像中那样完全删除空白区域。有人可以给我一些指导吗?谢谢!

这是我的形象: 我的形象

下面是我的代码。 完整的脚本在 GitHub 上。注意:images_2 和 images_fool 都是形状为 (1032, 10) 的扁平图像的 numpy 数组,而 delta 是形状为 (28, 28) 的图像数组。

 def plot_im(array=None, ind=0):
    """A function to plot the image given a images matrix, type of the matrix: \
    either original or fool, and the order of images in the matrix"""
    img_reshaped = array[ind, :].reshape((28, 28))
    imgplot = plt.imshow(img_reshaped)

# Output as a grid of 10 rows and 3 cols with first column being original, second being
# delta and third column being adversaril
nrow = 10
ncol = 3
n = 0

from matplotlib import gridspec
fig = plt.figure(figsize=(30, 30))
gs = gridspec.GridSpec(nrow, ncol, width_ratios=[1, 1, 1])

for row in range(nrow):
    for col in range(ncol):
        plt.subplot(gs[n])
        if col == 0:
            #plt.subplot(nrow, ncol, n)
            plot_im(array=images_2, ind=row)
        elif col == 1:
            #plt.subplot(nrow, ncol, n)
            plt.imshow(w_delta)
        else:
            #plt.subplot(nrow, ncol, n)
            plot_im(array=images_fool, ind=row)
        n += 1

plt.tight_layout()
#plt.show()
plt.savefig('grid_figure.pdf')

原文由 George Liu 发布,翻译遵循 CC BY-SA 4.0 许可协议

阅读 1.1k
2 个回答

开头的注释:如果您想完全控制间距,请避免使用 plt.tight_layout() 因为它会尝试将图中的图排列得均匀且分布均匀。这基本上很好并且产生令人愉快的结果,但可以随意调整间距。

您从 Matplotlib 示例库中引用的 GridSpec 示例运行良好的原因是子图的方面未预定义。也就是说,子图将简单地在网格上扩展并保留与图形大小无关的设置间距(在本例中为 wspace=0.0, hspace=0.0 )。

与您使用 imshow 绘制图像相反,默认情况下图像的纵横比设置为相等(相当于 ax.set_aspect("equal") )。也就是说,您当然可以将 set_aspect("auto") 放入每个图(并另外添加 wspace=0.0, hspace=0.0 作为 GridSpec 的参数,如画廊示例中所示),这将生成一个没有间距的图。

然而,当使用图像时,保持相等的纵横比很有意义,这样每个像素的宽度和高度一样,并且方形阵列显示为方形图像。

然后您需要做的是调整图像大小和图形边距以获得预期结果。 figsize figure 的参数是以英寸为单位的数字(宽度,高度),这里可以使用两个数字的比率。并且可以手动调整子图参数 wspace, hspace, top, bottom, left 以获得所需的结果。下面是一个例子:

 import numpy as np
import matplotlib.pyplot as plt
from matplotlib import gridspec

nrow = 10
ncol = 3

fig = plt.figure(figsize=(4, 10))

gs = gridspec.GridSpec(nrow, ncol, width_ratios=[1, 1, 1],
         wspace=0.0, hspace=0.0, top=0.95, bottom=0.05, left=0.17, right=0.845)

for i in range(10):
    for j in range(3):
        im = np.random.rand(28,28)
        ax= plt.subplot(gs[i,j])
        ax.imshow(im)
        ax.set_xticklabels([])
        ax.set_yticklabels([])

#plt.tight_layout() # do not use this!!
plt.show()

在此处输入图像描述

编辑:

不必手动调整参数当然是可取的。所以可以根据行数和列数计算出一些最优的。

 nrow = 7
ncol = 7

fig = plt.figure(figsize=(ncol+1, nrow+1))

gs = gridspec.GridSpec(nrow, ncol,
         wspace=0.0, hspace=0.0,
         top=1.-0.5/(nrow+1), bottom=0.5/(nrow+1),
         left=0.5/(ncol+1), right=1-0.5/(ncol+1))

for i in range(nrow):
    for j in range(ncol):
        im = np.random.rand(28,28)
        ax= plt.subplot(gs[i,j])
        ax.imshow(im)
        ax.set_xticklabels([])
        ax.set_yticklabels([])

plt.show()

原文由 ImportanceOfBeingErnest 发布,翻译遵循 CC BY-SA 3.0 许可协议

尝试将此行添加到您的代码中:

 fig.subplots_adjust(wspace=0, hspace=0)

对于每个轴对象集:

 ax.set_xticklabels([])
ax.set_yticklabels([])

原文由 Serenity 发布,翻译遵循 CC BY-SA 3.0 许可协议

撰写回答
你尚未登录,登录后可以
  • 和开发者交流问题的细节
  • 关注并接收问题和回答的更新提醒
  • 参与内容的编辑和改进,让解决方法与时俱进
推荐问题