在 Tensorflow 中训练模型后:
- 你如何保存训练好的模型?
- 你以后如何恢复这个保存的模型?
原文由 mathetes 发布,翻译遵循 CC BY-SA 4.0 许可协议
在 Tensorflow 中训练模型后:
原文由 mathetes 发布,翻译遵循 CC BY-SA 4.0 许可协议
我正在改进我的答案以添加更多关于保存和恢复模型的细节。
在(及之后) Tensorflow 版本 0.11 中:
保存模型:
import tensorflow as tf
#Prepare to feed input, i.e. feed_dict and placeholders
w1 = tf.placeholder("float", name="w1")
w2 = tf.placeholder("float", name="w2")
b1= tf.Variable(2.0,name="bias")
feed_dict ={w1:4,w2:8}
#Define a test operation that we will restore
w3 = tf.add(w1,w2)
w4 = tf.multiply(w3,b1,name="op_to_restore")
sess = tf.Session()
sess.run(tf.global_variables_initializer())
#Create a saver object which will save all the variables
saver = tf.train.Saver()
#Run the operation by feeding input
print sess.run(w4,feed_dict)
#Prints 24 which is sum of (w1+w2)*b1
#Now, save the graph
saver.save(sess, 'my_test_model',global_step=1000)
恢复模型:
import tensorflow as tf
sess=tf.Session()
#First let's load meta graph and restore weights
saver = tf.train.import_meta_graph('my_test_model-1000.meta')
saver.restore(sess,tf.train.latest_checkpoint('./'))
# Access saved Variables directly
print(sess.run('bias:0'))
# This will print 2, which is the value of bias that we saved
# Now, let's access and create placeholders variables and
# create feed-dict to feed new data
graph = tf.get_default_graph()
w1 = graph.get_tensor_by_name("w1:0")
w2 = graph.get_tensor_by_name("w2:0")
feed_dict ={w1:13.0,w2:17.0}
#Now, access the op that you want to run.
op_to_restore = graph.get_tensor_by_name("op_to_restore:0")
print sess.run(op_to_restore,feed_dict)
#This will print 60 which is calculated
这个和一些更高级的用例已经在这里得到了很好的解释。
原文由 sankit 发布,翻译遵循 CC BY-SA 3.0 许可协议
2 回答5.1k 阅读✓ 已解决
2 回答1.1k 阅读✓ 已解决
4 回答1.4k 阅读✓ 已解决
3 回答1.3k 阅读✓ 已解决
3 回答1.2k 阅读✓ 已解决
1 回答1.7k 阅读✓ 已解决
1 回答1.2k 阅读✓ 已解决
TensorFlow 2 文档
保存检查点
改编自 文档
更多链接
关于
saved_model
的详尽而有用的教程—-> https://www.tensorflow.org/guide/saved_modelkeras
保存模型的详细指南-> https://www.tensorflow.org/guide/keras/save_and_serialize(重点是我自己的)
张量流 < 2
从文档:
节省
恢复
simple_save
很多好的答案,为了完整起见,我将添加我的 2 美分: simple_save 。也是一个使用
tf.data.Dataset
API 的独立代码示例。蟒蛇3;张量流 1.14
恢复:
独立示例
原创博文
为了演示,以下代码生成随机数据。
Dataset
然后是Iterator
。我们得到迭代器生成的张量,称为input_tensor
它将作为我们模型的输入。input_tensor
构建的:一个基于 GRU 的双向 RNN,后跟一个密集分类器。因为为什么不。softmax_cross_entropy_with_logits
,用Adam
优化。在 2 个 epoch(每个 2 个批次)之后,我们使用tf.saved_model.simple_save
保存“训练过的”模型。如果您按原样运行代码,则模型将保存在您当前工作目录中名为simple/
的文件夹中。tf.saved_model.loader.load
恢复保存的模型。我们使用graph.get_tensor_by_name
和Iterator
使用graph.get_operation_by_name
来获取占位符和 logits。代码:
这将打印: