tf.data.Dataset:如何获取数据集大小(一个纪元中的元素数)?

新手上路,请多包涵

假设我以这种方式定义了一个数据集:

 filename_dataset = tf.data.Dataset.list_files("{}/*.png".format(dataset))

我怎样才能得到数据集中的元素数量(因此,构成一个纪元的单个元素的数量)?

我知道 tf.data.Dataset 已经知道数据集的维度,因为 repeat() 方法允许重复输入管道指定数量的纪元。所以它一定是一种获取这些信息的方法。

原文由 nessuno 发布,翻译遵循 CC BY-SA 4.0 许可协议

阅读 1k
2 个回答

tf.data.Dataset.list_files 创建一个名为 MatchingFiles:0 的张量(如果适用,带有适当的前缀)。

你可以评估

tf.shape(tf.get_default_graph().get_tensor_by_name('MatchingFiles:0'))[0]

获取文件的数量。

当然,这只适用于简单的情况,特别是如果每张图像只有一个样本(或已知数量的样本)。

在更复杂的情况下,例如,当您不知道每个文件中的样本数时,您只能观察一个 epoch 结束时的样本数。

为此,您可以查看 Dataset 计算的纪元数。 repeat() 创建一个名为 _count 的成员,用于计算纪元数。通过在迭代期间观察它,您可以发现它何时发生变化并从那里计算数据集大小。

这个计数器可能埋藏在依次调用成员函数时创建的 Dataset s的层次结构中,所以我们要这样挖出来。

 d = my_dataset
# RepeatDataset seems not to be exposed -- this is a possible workaround
RepeatDataset = type(tf.data.Dataset().repeat())
try:
  while not isinstance(d, RepeatDataset):
    d = d._input_dataset
except AttributeError:
  warnings.warn('no epoch counter found')
  epoch_counter = None
else:
  epoch_counter = d._count

请注意,使用此技术时,数据集大小的计算并不准确,因为批处理期间 epoch_counter 递增通常混合来自两个连续时期的样本。因此,此计算精确到您的批处理长度。

原文由 P-Gn 发布,翻译遵循 CC BY-SA 4.0 许可协议

len(list(dataset)) 在 eager 模式下工作,尽管这显然不是一个好的通用解决方案。

原文由 markemus 发布,翻译遵循 CC BY-SA 4.0 许可协议

撰写回答
你尚未登录,登录后可以
  • 和开发者交流问题的细节
  • 关注并接收问题和回答的更新提醒
  • 参与内容的编辑和改进,让解决方法与时俱进
推荐问题
logo
Stack Overflow 翻译
子站问答
访问
宣传栏