keras 中的 Flatten() 和 GlobalAveragePooling2D() 有什么区别

新手上路,请多包涵

我想将 ConvLSTM 和 Conv2D 的输出传递到 Keras 中的密集层,使用全局平均池和展平之间的区别是什么?两者都适用于我的情况。

 model.add(ConvLSTM2D(filters=256,kernel_size=(3,3)))
model.add(Flatten())
# or model.add(GlobalAveragePooling2D())
model.add(Dense(256,activation='relu'))

原文由 user239457 发布,翻译遵循 CC BY-SA 4.0 许可协议

阅读 1.5k
2 个回答

两者似乎都有效并不意味着它们的作用相同。

Flatten 会将任意形状的张量转换为一维张量(加上样本维度),但将所有值保留在张量中。例如,张量 (samples, 10, 20, 1) 将被展平为 (samples, 10 * 20 * 1)。

GlobalAveragePooling2D 做了一些不同的事情。它在空间维度上应用平均池化,直到每个空间维度为一个,并保持其他维度不变。在这种情况下,值不会保留为平均值。例如,张量 (samples, 10, 20, 1) 将输出为 (samples, 1, 1, 1),假设第 2 和第 3 维是空间维度(最后是通道)。

原文由 Dr. Snoopy 发布,翻译遵循 CC BY-SA 3.0 许可协议

Flatten 层的作用

After convolutional operations, tf.keras.layers.Flatten will reshape a tensor into (n_samples, height*width*channels) , for example turning (16, 28, 28, 3) into (16, 2352) .让我们试试看:

 import tensorflow as tf

x = tf.random.uniform(shape=(100, 28, 28, 3), minval=0, maxval=256, dtype=tf.int32)

flat = tf.keras.layers.Flatten()

flat(x).shape

 TensorShape([100, 2352])

GlobalAveragePooling 层的作用

卷积运算后, tf.keras.layers.GlobalAveragePooling根据最后一个轴对 所有值进行平均。这意味着生成的形状将为 (n_samples, last_axis) 。例如,如果您的最后一个卷积层有 64 个过滤器,它会将 (16, 7, 7, 64) 变成 (16, 64) 。让我们进行测试,经过一些卷积操作:

 import tensorflow as tf

x = tf.cast(
    tf.random.uniform(shape=(16, 28, 28, 3), minval=0, maxval=256, dtype=tf.int32),
    tf.float32)

gap = tf.keras.layers.GlobalAveragePooling2D()

for i in range(5):
    conv = tf.keras.layers.Conv2D(64, 3)
    x = conv(x)
    print(x.shape)

print(gap(x).shape)

 (16, 24, 24, 64)
(16, 22, 22, 64)
(16, 20, 20, 64)
(16, 18, 18, 64)
(16, 16, 16, 64)

(16, 64)

你应该使用哪个?

Flatten 层将始终至少具有与 GlobalAveragePooling2D 层一样多的参数。如果展平前的最终张量形状仍然很大,例如 (16, 240, 240, 128) ,使用 Flatten 将产生大量参数: 240*240*128 = 7,372,800 这个巨大的数字将乘以下一个密集层中的单元数!在那一刻, GlobalAveragePooling2D 在大多数情况下可能是首选。如果您使用 MaxPooling2DConv2D 以至于您的张量在展平之前的形状就像 (16, 1, 1, 128) 一样,它不会有什么不同。如果你过度拟合,你可能想尝试 GlobalAveragePooling2D

原文由 Nicolas Gervais 发布,翻译遵循 CC BY-SA 4.0 许可协议

撰写回答
你尚未登录,登录后可以
  • 和开发者交流问题的细节
  • 关注并接收问题和回答的更新提醒
  • 参与内容的编辑和改进,让解决方法与时俱进
推荐问题