你如何解码 Tensorflow 中的 one-hot 标签?

新手上路,请多包涵

一直在寻找,但似乎找不到任何有关如何从 TensorFlow 中的单热值解码或转换回单个整数的示例。

我使用 tf.one_hot 并能够训练我的模型,但我对如何在分类后理解标签感到困惑。我的数据是通过我创建的 TFRecords 文件输入的。我考虑过在文件中存储一个文本标签,但无法让它工作。似乎 TFRecords 无法存储文本字符串,或者我弄错了。

原文由 Matt Camp 发布,翻译遵循 CC BY-SA 4.0 许可协议

阅读 506
2 个回答

您可以使用 tf.argmax 矩阵中最大元素的索引。由于您的一个热矢量将是一维的,并且只有一个 1 和其他 0 s,假设您正在处理单个矢量,这将起作用。

 index = tf.argmax(one_hot_vector, axis=0)

对于 batch_size * num_classes 的更标准矩阵,使用 axis=1 得到大小为 batch_size * 1 的结果。

原文由 martianwars 发布,翻译遵循 CC BY-SA 3.0 许可协议

由于 one-hot 编码通常只是一个具有 batch_size 行和 num_classes 列的矩阵,并且每行全为零,并且对应于所选类的单个非零值,您可以使用 tf.argmax() 恢复整数标签向量:

 BATCH_SIZE = 3
NUM_CLASSES = 4
one_hot_encoded = tf.constant([[0, 1, 0, 0],
                               [1, 0, 0, 0],
                               [0, 0, 0, 1]])

# Compute the argmax across the columns.
decoded = tf.argmax(one_hot_encoded, axis=1)

# ...
print sess.run(decoded)  # ==> array([1, 0, 3])

原文由 mrry 发布,翻译遵循 CC BY-SA 3.0 许可协议

撰写回答
你尚未登录,登录后可以
  • 和开发者交流问题的细节
  • 关注并接收问题和回答的更新提醒
  • 参与内容的编辑和改进,让解决方法与时俱进
推荐问题