Tensorflow保存神经网络参数有妙招:Saver和Restore

摘要:这篇文章将讲解TensorFlow如何保存变量和神经网络参数,通过Saver保存神经网络,再通过Restore调用训练好的神经网络。

本文分享自华为云社区《[[Python人工智能] 十一.Tensorflow如何保存神经网络参数 丨【百变AI秀】](https://bbs.huaweicloud.com/b...)》,作者: eastmount。

一.保存变量

通过tf.Variable()定义权重和偏置变量,然后调用tf.train.Saver()存储变量,将数据保存至本地“my_net/save_net.ckpt”文件中。

# -*- coding: utf-8 -*-
"""
Created on Thu Jan  2 20:04:57 2020
@author: xiuzhang Eastmount CSDN
"""
import tensorflow as tf
import numpy as np

#---------------------------------------保存文件---------------------------------------
W = tf.Variable([[1,2,3], [3,4,5]], dtype=tf.float32, name='weights') #2行3列的数据
b = tf.Variable([[1,2,3]], dtype=tf.float32, name='biases')

# 初始化
init = tf.initialize_all_variables()

# 定义saver 存储各种变量
saver = tf.train.Saver()

# 使用Session运行初始化
with tf.Session() as sess:
    sess.run(init)
    # 保存 官方保存格式为ckpt
    save_path = saver.save(sess, "my_net/save_net.ckpt")
    print("Save to path:", save_path)

“Save to path: my_net/save_net.ckpt”保存成功如下图所示:
image.png

打开内容如下图所示:
image.png

接着定义标记变量train,通过Restore操作使用我们保存好的变量。注意,在Restore时需要定义相同的dtype和shape,不需要再定义init。最后直接通过 saver.restore(sess, “my_net/save_net.ckpt”) 提取保存的变量并输出即可。

# -*- coding: utf-8 -*-
"""
Created on Thu Jan  2 20:04:57 2020
@author: xiuzhang Eastmount CSDN
"""
import tensorflow as tf
import numpy as np

# 标记变量
train = False

#---------------------------------------保存文件---------------------------------------
# Save
if train==True:
    # 定义变量
    W = tf.Variable([[1,2,3], [3,4,5]], dtype=tf.float32, name='weights') #2行3列的数据
    b = tf.Variable([[1,2,3]], dtype=tf.float32, name='biases')

    # 初始化
    init = tf.global_variables_initializer()
 
    # 定义saver 存储各种变量
    saver = tf.train.Saver()
 
    # 使用Session运行初始化
    with tf.Session() as sess:
        sess.run(init)
        # 保存 官方保存格式为ckpt
        save_path = saver.save(sess, "my_net/save_net.ckpt")
        print("Save to path:", save_path)
#---------------------------------------Restore变量-------------------------------------
# Restore
if train==False:
    # 记住在Restore时定义相同的dtype和shape
    # redefine the same shape and same type for your variables
    W = tf.Variable(np.arange(6).reshape((2,3)), dtype=tf.float32, name='weights') #空变量
    b = tf.Variable(np.arange(3).reshape((1,3)), dtype=tf.float32, name='biases') #空变量
 
    # Restore不需要定义init
    saver = tf.train.Saver()
    with tf.Session() as sess:
        # 提取保存的变量
        saver.restore(sess, "my_net/save_net.ckpt")
        # 寻找相同名字和标识的变量并存储在W和b中
        print("weights", sess.run(W))
        print("biases", sess.run(b))

运行代码,如果报错“NotFoundError: Restoring from checkpoint failed. This is most likely due to a Variable name or other graph key that is missing from the checkpoint. Please ensure that you have not altered the graph expected based on the checkpoint. ”,则需要重置Spyder即可。
image.png

最后输出之前所保存的变量,weights为 [[1,2,3], [3,4,5]],偏置为 [[1,2,3]]。
image.png

二.保存神经网络

那么,TensorFlow如何保存我们的神经网络框架呢?我们需要把整个网络训练好再进行保存,其方法和上面类似,完整代码如下:

"""
Created on Sun Dec 29 19:21:08 2019
@author: xiuzhang Eastmount CSDN
"""
import os
import glob
import cv2
import numpy as np
import tensorflow as tf

# 定义图片路径
path = 'photo/'

#---------------------------------第一步 读取图像-----------------------------------
def read_img(path):
    cate = [path + x for x in os.listdir(path) if os.path.isdir(path + x)]
    imgs = []
    labels = []
    fpath = []
    for idx, folder in enumerate(cate):
        # 遍历整个目录判断每个文件是不是符合
        for im in glob.glob(folder + '/*.jpg'):
            #print('reading the images:%s' % (im))
            img = cv2.imread(im)             #调用opencv库读取像素点
            img = cv2.resize(img, (32, 32))  #图像像素大小一致
            imgs.append(img)                 #图像数据
            labels.append(idx)               #图像类标
            fpath.append(path+im)            #图像路径名
            #print(path+im, idx)
 
    return np.asarray(fpath, np.string_), np.asarray(imgs, np.float32), np.asarray(labels, np.int32)

# 读取图像
fpaths, data, label = read_img(path)
print(data.shape)  # (1000, 256, 256, 3)
# 计算有多少类图片
num_classes = len(set(label))
print(num_classes)

# 生成等差数列随机调整图像顺序
num_example = data.shape[0]
arr = np.arange(num_example)
np.random.shuffle(arr)
data = data[arr]
label = label[arr]
fpaths = fpaths[arr]

# 拆分训练集和测试集 80%训练集 20%测试集
ratio = 0.8
s = np.int(num_example * ratio)
x_train = data[:s]
y_train = label[:s]
fpaths_train = fpaths[:s] 
x_val = data[s:]
y_val = label[s:]
fpaths_test = fpaths[s:] 
print(len(x_train),len(y_train),len(x_val),len(y_val)) #800 800 200 200
print(y_val)
#---------------------------------第二步 建立神经网络-----------------------------------
# 定义Placeholder
xs = tf.placeholder(tf.float32, [None, 32, 32, 3])  #每张图片32*32*3个点
ys = tf.placeholder(tf.int32, [None])               #每个样本有1个输出
# 存放DropOut参数的容器 
drop = tf.placeholder(tf.float32)                   #训练时为0.25 测试时为0

# 定义卷积层 conv0
conv0 = tf.layers.conv2d(xs, 20, 5, activation=tf.nn.relu)    #20个卷积核 卷积核大小为5 Relu激活
# 定义max-pooling层 pool0
pool0 = tf.layers.max_pooling2d(conv0, [2, 2], [2, 2])        #pooling窗口为2x2 步长为2x2
print("Layer0:\n", conv0, pool0)
 
# 定义卷积层 conv1
conv1 = tf.layers.conv2d(pool0, 40, 4, activation=tf.nn.relu) #40个卷积核 卷积核大小为4 Relu激活
# 定义max-pooling层 pool1
pool1 = tf.layers.max_pooling2d(conv1, [2, 2], [2, 2])        #pooling窗口为2x2 步长为2x2
print("Layer1:\n", conv1, pool1)

# 将3维特征转换为1维向量
flatten = tf.layers.flatten(pool1)

# 全连接层 转换为长度为400的特征向量
fc = tf.layers.dense(flatten, 400, activation=tf.nn.relu)
print("Layer2:\n", fc)

# 加上DropOut防止过拟合
dropout_fc = tf.layers.dropout(fc, drop)

# 未激活的输出层
logits = tf.layers.dense(dropout_fc, num_classes)
print("Output:\n", logits)

# 定义输出结果
predicted_labels = tf.arg_max(logits, 1)
#---------------------------------第三步 定义损失函数和优化器---------------------------------

# 利用交叉熵定义损失
losses = tf.nn.softmax_cross_entropy_with_logits(
        labels = tf.one_hot(ys, num_classes),       #将input转化为one-hot类型数据输出
        logits = logits)

# 平均损失
mean_loss = tf.reduce_mean(losses)

# 定义优化器 学习效率设置为0.0001
optimizer = tf.train.AdamOptimizer(learning_rate=1e-4).minimize(losses)
#------------------------------------第四步 模型训练和预测-----------------------------------
# 用于保存和载入模型
saver = tf.train.Saver()
# 训练或预测
train = False
# 模型文件路径
model_path = "model/image_model"

with tf.Session() as sess:
    if train:
        print("训练模式")
        # 训练初始化参数
        sess.run(tf.global_variables_initializer())
        # 定义输入和Label以填充容器 训练时dropout为0.25
        train_feed_dict = {
                xs: x_train,
                ys: y_train,
                drop: 0.25
        }
        # 训练学习1000次
        for step in range(1000):
            _, mean_loss_val = sess.run([optimizer, mean_loss], feed_dict=train_feed_dict)
            if step % 50 == 0:  #每隔50次输出一次结果
                print("step = {}\t mean loss = {}".format(step, mean_loss_val))
        # 保存模型
        saver.save(sess, model_path)
        print("训练结束,保存模型到{}".format(model_path))
    else:
        print("测试模式")
        # 测试载入参数
        saver.restore(sess, model_path)
        print("从{}载入模型".format(model_path))
        # label和名称的对照关系
        label_name_dict = {
            0: "人类",
            1: "沙滩",
            2: "建筑",
            3: "公交",
            4: "恐龙",
            5: "大象",
            6: "花朵",
            7: "野马",
            8: "雪山",
            9: "美食"
        }
        # 定义输入和Label以填充容器 测试时dropout为0
        test_feed_dict = {
            xs: x_val,
            ys: y_val,
            drop: 0
        }
 
        # 真实label与模型预测label
        predicted_labels_val = sess.run(predicted_labels, feed_dict=test_feed_dict)
        for fpath, real_label, predicted_label in zip(fpaths_test, y_val, predicted_labels_val):
            # 将label id转换为label名
            real_label_name = label_name_dict[real_label]
            predicted_label_name = label_name_dict[predicted_label]
            print("{}\t{} => {}".format(fpath, real_label_name, predicted_label_name))
        # 评价结果
        print("正确预测个数:", sum(y_val==predicted_labels_val))
        print("准确度为:", 1.0*sum(y_val==predicted_labels_val) / len(y_val))

核心步骤为:

saver = tf.train.Saver()
model_path = "model/image_model"
with tf.Session() as sess:
    if train:
        #保存神经网络
        sess.run(tf.global_variables_initializer())
        for step in range(1000):
            _, mean_loss_val = sess.run([optimizer, mean_loss], feed_dict=train_feed_dict)
            if step % 50 == 0:
                print("step = {}\t mean loss = {}".format(step, mean_loss_val))
        saver.save(sess, model_path)
    else:
        #载入神经网络
        saver.restore(sess, model_path)
        predicted_labels_val = sess.run(predicted_labels, feed_dict=test_feed_dict)
        for fpath, real_label, predicted_label in zip(fpaths_test, y_val, predicted_labels_val):
            real_label_name = label_name_dict[real_label]
            predicted_label_name = label_name_dict[predicted_label]
            print("{}\t{} => {}".format(fpath, real_label_name, predicted_label_name))    

预测输出结果如下图所示,最终预测正确181张图片,准确度为0.905。相比之前机器学习KNN的0.500有非常高的提升。
image.png

测试模式

INFO:tensorflow:Restoring parameters from model/image_model
从model/image_model载入模型
b'photo/photo/3\\335.jpg'       公交 => 公交
b'photo/photo/1\\129.jpg'       沙滩 => 沙滩
b'photo/photo/7\\740.jpg'       野马 => 野马
b'photo/photo/5\\564.jpg'       大象 => 大象
...
b'photo/photo/9\\974.jpg'       美食 => 美食
b'photo/photo/2\\220.jpg'       建筑 => 公交
b'photo/photo/9\\912.jpg'       美食 => 美食
b'photo/photo/4\\459.jpg'       恐龙 => 恐龙
b'photo/photo/5\\525.jpg'       大象 => 大象
b'photo/photo/0\\44.jpg'        人类 => 人类

正确预测个数: 181
准确度为: 0.905

点击关注,第一时间了解华为云新鲜技术~


开发者之家
华为云开发者社区,提供全面深入的云计算前景分析、丰富的技术干货、程序样例,分享华为云前沿资讯动态...

生于云,长于云,让开发者成为决定性力量

1.3k 声望
1.7k 粉丝
0 条评论
推荐阅读
【贺】来自开发者的点赞,华为云开发者联盟入选 2022 中国技术品牌影响力企业榜
2023 年 1 月 4 日,中国技术先锋年度评选 | 2022 中国技术品牌影响力企业榜单正式发布。作为中国领先的新一代开发者社区,SegmentFault 思否依托数百万开发者用户数据分析,各科技企业在国内技术领域的行为及影...

华为云开发者联盟阅读 349

Ubuntu20.04 从源代码编译安装 python3.10
Ubuntu 22.04 Release DateUbuntu 22.04 Jammy Jellyfish is scheduled for release on April 21, 2022If you’re ready to use Ubuntu 22.04 Jammy Jellyfish, you can either upgrade your current Ubuntu syste...

ponponon1阅读 4k评论 1

日常Python 代码片段整理
1、简单的 HTTP Web 服务器 {代码...} 2、单行循环List {代码...} 3、更新字典 {代码...} 4、拆分多行字符串 {代码...} 5、跟踪列表中元素的频率 {代码...} 6、不使用 Pandas 读取 CSV 文件 {代码...} 7、将列表...

墨城2阅读 301

Unicode 正则表达式(qbit)
前言本文根据《精通正则表达式》和 Unicode Regular Expressions 整理。本文的示例默认以 Python3 为实现语言,用到 Python3 的 re 模块或 regex 库。基本的 Unicode 属性分类 {代码...} 基本的 Unicode 子属性Le...

qbit阅读 4.3k

Python + Sqlalchemy 对数据库的批量插入或更新(Upsert)
由于不同数据库对这种 upsert 的实现机制不同,Sqlalchemy 也就不再试图做一致性的封装了,而是提供了各自的方言 API,具体到 Mysql,就是给 insert statement ,增加了 on_duplicate_key_update 方法。

songofhawk1阅读 1.9k评论 4

封面图
打脸了兄弟们,Go1.20 arena 来了!
大家好,我是煎鱼。大概半年前,我写过一篇文章《Go 要违背初心吗?新提案:手动管理内存》。有兴趣了深入解的同学,可以再回顾一下。当时我们还想着 Go 团队应该不会接纳,至少不会那么快:懒得翻也可以看我再次...

煎鱼阅读 3.2k

AIGC神器CLIP:技术详解及应用示例
编者按:上一期,我们介绍了Diffusion模型的发展历程、核心原理及其对AIGC发展的推动作用。本期,我们将共同走进另一项AI重要突破——CLIP,著名的DALLE和Stable Diffusion均采用了CLIP哦。

Baihai_IDP1阅读 910

封面图

生于云,长于云,让开发者成为决定性力量

1.3k 声望
1.7k 粉丝
宣传栏