TensorFlow:如何从 SavedModel 进行预测?

新手上路,请多包涵

我导出了一个 SavedModel 现在我要将它加载回去并进行预测。它使用以下特征和标签进行训练:

 F1 : FLOAT32
F2 : FLOAT32
F3 : FLOAT32
L1 : FLOAT32

所以说我想输入值 20.9, 1.8, 0.9 得到一个 FLOAT32 预测。我该如何做到这一点?我已成功加载模型,但我不确定如何访问它以进行预测调用。

 with tf.Session(graph=tf.Graph()) as sess:
    tf.saved_model.loader.load(
        sess,
        [tf.saved_model.tag_constants.SERVING],
        "/job/export/Servo/1503723455"
    )

    # How can I predict from here?
    # I want to do something like prediction = model.predict([20.9, 1.8, 0.9])

此问题与 此处 发布的问题不重复。这个问题侧重于对任何模型类的 SavedModel 执行推理的最小示例(不仅限于 tf.estimator )以及指定输入和输出节点名称的语法。

原文由 jshapy8 发布,翻译遵循 CC BY-SA 4.0 许可协议

阅读 856
2 个回答

加载图形后,它在当前上下文中可用,您可以通过它提供输入数据以获得预测。每个用例都相当不同,但是添加到您的代码中的内容如下所示:

 with tf.Session(graph=tf.Graph()) as sess:
    tf.saved_model.loader.load(
        sess,
        [tf.saved_model.tag_constants.SERVING],
        "/job/export/Servo/1503723455"
    )

    prediction = sess.run(
        'prefix/predictions/Identity:0',
        feed_dict={
            'Placeholder:0': [20.9],
            'Placeholder_1:0': [1.8],
            'Placeholder_2:0': [0.9]
        }
    )

    print(prediction)

在这里,您需要知道预测输入的名称。如果你没有在你的 serving_fn 中给他们一个中殿,那么他们默认为 Placeholder_n ,其中 n

sess.run 的第一个字符串参数是预测目标的名称。这将根据您的用例而有所不同。

原文由 jshapy8 发布,翻译遵循 CC BY-SA 3.0 许可协议

假设您想要在 Python 中进行预测, SavedModelPredictor 可能是加载 SavedModel 并获取预测的最简单方法。假设您像这样保存模型:

 # Build the graph
f1 = tf.placeholder(shape=[], dtype=tf.float32)
f2 = tf.placeholder(shape=[], dtype=tf.float32)
f3 = tf.placeholder(shape=[], dtype=tf.float32)
l1 = tf.placeholder(shape=[], dtype=tf.float32)
output = build_graph(f1, f2, f3, l1)

# Save the model
inputs = {'F1': f1, 'F2': f2, 'F3': f3, 'L1': l1}
outputs = {'output': output_tensor}
tf.contrib.simple_save(sess, export_dir, inputs, outputs)

(输入可以是任何形状,甚至不必是图中的占位符或根节点)。

然后,在将使用 SavedModel 的 Python 程序中,我们可以获得如下预测:

 from tensorflow.contrib import predictor

predict_fn = predictor.from_saved_model(export_dir)
predictions = predict_fn(
    {"F1": 1.0, "F2": 2.0, "F3": 3.0, "L1": 4.0})
print(predictions)

这个答案 展示了如何在 Java、C++ 和 Python 中获得预测(尽管 问题 集中在 Estimators 上,但答案实际上独立于 SavedModel 是如何创建的)。

原文由 rhaertel80 发布,翻译遵循 CC BY-SA 3.0 许可协议

撰写回答
你尚未登录,登录后可以
  • 和开发者交流问题的细节
  • 关注并接收问题和回答的更新提醒
  • 参与内容的编辑和改进,让解决方法与时俱进
推荐问题