假设我以这种方式定义了一个数据集:
filename_dataset = tf.data.Dataset.list_files("{}/*.png".format(dataset))
我怎样才能得到数据集中的元素数量(因此,构成一个纪元的单个元素的数量)?
我知道 tf.data.Dataset
已经知道数据集的维度,因为 repeat()
方法允许重复输入管道指定数量的纪元。所以它一定是一种获取这些信息的方法。
原文由 nessuno 发布,翻译遵循 CC BY-SA 4.0 许可协议
tf.data.Dataset.list_files
创建一个名为MatchingFiles:0
的张量(如果适用,带有适当的前缀)。你可以评估
获取文件的数量。
当然,这只适用于简单的情况,特别是如果每张图像只有一个样本(或已知数量的样本)。
在更复杂的情况下,例如,当您不知道每个文件中的样本数时,您只能观察一个 epoch 结束时的样本数。
为此,您可以查看
Dataset
计算的纪元数。repeat()
创建一个名为_count
的成员,用于计算纪元数。通过在迭代期间观察它,您可以发现它何时发生变化并从那里计算数据集大小。这个计数器可能埋藏在依次调用成员函数时创建的
Dataset
s的层次结构中,所以我们要这样挖出来。请注意,使用此技术时,数据集大小的计算并不准确,因为批处理期间
epoch_counter
递增通常混合来自两个连续时期的样本。因此,此计算精确到您的批处理长度。