plt.imshow() 中图像数据的尺寸无效

新手上路,请多包涵

我正在使用 mnist 数据集在 keras 背景下训练胶囊网络。训练后,我想显示来自 mnist 数据集的图像。为了加载图像,使用了 mnist.load_data()。数据存储为 (x_train, y_train),(x_test, y_test)。现在,为了可视化图像,我的代码如下:

 img_path = x_test[1]
print(img_path.shape)
plt.imshow(img_path)
plt.show()

代码给出如下输出:

 (28, 28, 1)

和 plt.imshow(img_path) 上的错误如下:

 TypeError: Invalid dimensions for image data

如何以png格式显示图像。帮助!

原文由 Anusha Mehta 发布,翻译遵循 CC BY-SA 4.0 许可协议

阅读 1.2k
2 个回答

您可以使用 tf.squeeze 从张量的形状中删除大小为 1 的维度。

 plt.imshow( tf.shape( tf.squeeze(x_train) ) )

查看 TF2.0 示例

原文由 Parth Patel 发布,翻译遵循 CC BY-SA 4.0 许可协议

根据 @sdcbr 的评论,使用 np.sqeeze 减少了不必要的维度。如果图像是二维的,则 imshow 函数可以正常工作。如果图像有 3 个维度,那么你必须减少额外的 1 个维度。但是,对于更高暗淡的数据,您必须将其减少到 2 暗淡,因此 np.sqeeze 可能会应用多次。 (或者您可以使用其他一些暗淡减少功能来获取更高暗淡的数据)

 import numpy as np
import matplotlib.pyplot as plt
img_path = x_test[1]
print(img_path.shape)
if(len(img_path.shape) == 3):
    plt.imshow(np.squeeze(img_path))
elif(len(img_path.shape) == 2):
    plt.imshow(img_path)
else:
    print("Higher dimensional data")

原文由 shantanu pathak 发布,翻译遵循 CC BY-SA 4.0 许可协议

撰写回答
你尚未登录,登录后可以
  • 和开发者交流问题的细节
  • 关注并接收问题和回答的更新提醒
  • 参与内容的编辑和改进,让解决方法与时俱进