如何把 .ckpt 转换成 .pb?

新手上路,请多包涵

我是深度学习的新手,我想使用预训练 (EAST) 模型从 AI Platform Serving 提供服务,开发人员提供了以下文件:

  1. model.ckpt-49491.data-00000-of-00001
  2. 检查点
  3. 模型.ckpt-49491.index
  4. 型号.ckpt-49491.meta

我想把它转换成 TensorFlow .pb 格式。有办法吗?我从 这里 拿走了模型

完整代码可 在此处 获得。

我在 这里 查找,它显示了以下代码来转换它:

来自 tensorflow/models/research/

 INPUT_TYPE=image_tensor
PIPELINE_CONFIG_PATH={path to pipeline config file}
TRAINED_CKPT_PREFIX={path to model.ckpt}
EXPORT_DIR={path to folder that will be used for export}

python object_detection/export_inference_graph.py \
    --input_type=${INPUT_TYPE} \
    --pipeline_config_path=${PIPELINE_CONFIG_PATH} \
    --trained_checkpoint_prefix=${TRAINED_CKPT_PREFIX} \
    --output_directory=${EXPORT_DIR}

我无法弄清楚要传递什么值:

  • 输入类型
  • PIPELINE_CONFIG_PATH。

原文由 Shivam Sahu 发布,翻译遵循 CC BY-SA 4.0 许可协议

阅读 687
2 个回答

这是将检查点转换为 SavedModel 的代码

import os
import tensorflow as tf

trained_checkpoint_prefix = 'models/model.ckpt-49491'
export_dir = os.path.join('export_dir', '0')

graph = tf.Graph()
with tf.compat.v1.Session(graph=graph) as sess:
    # Restore from checkpoint
    loader = tf.compat.v1.train.import_meta_graph(trained_checkpoint_prefix + '.meta')
    loader.restore(sess, trained_checkpoint_prefix)

    # Export checkpoint to SavedModel
    builder = tf.compat.v1.saved_model.builder.SavedModelBuilder(export_dir)
    builder.add_meta_graph_and_variables(sess,
                                         [tf.saved_model.TRAINING, tf.saved_model.SERVING],
                                         strip_default_attrs=True)
    builder.save()

原文由 Puneith Kaul 发布,翻译遵循 CC BY-SA 4.0 许可协议

按照@Puneith Kaul 的回答,这里是 tensorflow 1.7 版的语法:

 import os
import tensorflow as tf

export_dir = 'export_dir'
trained_checkpoint_prefix = 'models/model.ckpt'
graph = tf.Graph()
loader = tf.train.import_meta_graph(trained_checkpoint_prefix + ".meta" )
sess = tf.Session()
loader.restore(sess,trained_checkpoint_prefix)
builder = tf.saved_model.builder.SavedModelBuilder(export_dir)
builder.add_meta_graph_and_variables(sess, [tf.saved_model.tag_constants.TRAINING, tf.saved_model.tag_constants.SERVING], strip_default_attrs=True)
builder.save()

原文由 mcExchange 发布,翻译遵循 CC BY-SA 4.0 许可协议

撰写回答
你尚未登录,登录后可以
  • 和开发者交流问题的细节
  • 关注并接收问题和回答的更新提醒
  • 参与内容的编辑和改进,让解决方法与时俱进