TensorFlow教程中的next_batch batch_xs, batch_ys = mnist.train.next_batch(100) 从哪里来?

新手上路,请多包涵

我正在试用 TensorFlow 教程,但不明白这一行中的 next_batch 是从哪里来的?

  batch_xs, batch_ys = mnist.train.next_batch(100)

我在看

from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)

也没有在那里看到 next_batch。

现在,在我自己的代码中尝试 next_batch 时,我得到了

AttributeError: 'numpy.ndarray' object has no attribute 'next_batch'

所以我想了解 next_batch 是从哪里来的?

原文由 Dan 发布,翻译遵循 CC BY-SA 4.0 许可协议

阅读 849
2 个回答

next_batchDataSet 类的一种方法(见 https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/learn/python/learn/datasets/ mnist.py 以获取有关课程内容的更多信息)。

当您加载 mnist 数据并将其分配给变量 mnist 时:

 mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)

查看 mnist.train 的类。您可以通过键入以下内容来查看它:

 print mnist.train.__class__

您会看到以下内容:

 <class 'tensorflow.contrib.learn.python.learn.datasets.mnist.Dataset'>

因为 mnist.train 是类 DataSet 的一个实例,你可以使用类的函数 next_batch 。有关类的更多信息,请查看 文档

原文由 Nick Becker 发布,翻译遵循 CC BY-SA 3.0 许可协议

翻了一下tensorflow的仓库,好像起源于这里:

https://github.com/tensorflow/tensorflow/blob/9230423668770036179a72414482d45ddde40a3b/tensorflow/contrib/training/python/training/sequence_queueing_state_saver.py#L905

但是,如果您希望在自己的代码中(针对您自己的数据集)实现它,那么像我所做的那样,自己将其编写在数据集对象中可能会简单得多。据我了解,这是一种打乱整个数据集并从打乱的数据集中返回 $mini_batch_size 个样本的方法。

这是一些伪代码:

shuffle data.x and data.y while retaining relation return [data.x[:mb_n], data.y[:mb_n]]

原文由 Dark Element 发布,翻译遵循 CC BY-SA 3.0 许可协议

撰写回答
你尚未登录,登录后可以
  • 和开发者交流问题的细节
  • 关注并接收问题和回答的更新提醒
  • 参与内容的编辑和改进,让解决方法与时俱进
推荐问题