1.准备数据,使用占位符,动态加载训练数据
# Placeholders let every training step feed a fresh mini-batch.
# x: rows of 784 pixel features; y_true: one-hot labels over 10 classes.
x = tf.placeholder(tf.float32, [None, 784])
# Softmax cross-entropy requires floating-point labels, and the one-hot
# batches fed here are floats, so this must be float32 rather than int32.
y_true = tf.placeholder(tf.float32, [None, 10])
2.初始化参数,建立模型
# Weights start from a standard normal distribution, biases start at zero.
weight = tf.Variable(tf.random_normal([784, 10], mean=0.0, stddev=1.0))
bias = tf.Variable(tf.constant(0.0, shape=[10]))
# Linear model producing raw logits; softmax is applied inside the loss.
y_predict = tf.matmul(x, weight) + bias
3.求平均交叉熵损失
# Softmax cross-entropy of each example against its label, then averaged
# over the batch to give a single scalar loss.
per_example_loss = tf.nn.softmax_cross_entropy_with_logits(labels=y_true, logits=y_predict)
loss = tf.reduce_mean(per_example_loss)
4.梯度下降优化
# Plain gradient descent (learning rate 0.3) minimizing the mean loss.
# In TF1 the optimizer classes live under tf.train.
train_op = tf.train.GradientDescentOptimizer(0.3).minimize(loss)
5.求准确率
# A prediction is correct when the index of the largest logit matches the
# index of the one-hot label (tf.argmax replaces the deprecated tf.arg_max).
equal_list = tf.equal(tf.argmax(y_true, 1), tf.argmax(y_predict, 1))
# Fraction of correctly classified rows in the batch.
accuracy = tf.reduce_mean(tf.cast(equal_list, tf.float32))
完整代码:
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
import os
# Load MNIST with one-hot labels; full_connection() reads mnist.train and
# mnist.test from this module-level object.
# NOTE(review): 'MNISI_data' looks like a typo of 'MNIST_data'; kept as-is so
# an existing local copy of the dataset is still found.
mnist = input_data.read_data_sets(train_dir='./data/MNISI_data/', one_hot=True)
def full_connection(is_train=False, train_steps=4000, batch_size=50,
                    num_test=100, learning_rate=0.4,
                    ckpt_path='./temp/ckpt/full_conn',
                    summary_dir='./temp/summary/test'):
    """Train or evaluate a one-layer softmax classifier on MNIST (TF1 graph API).

    Builds ``y_predict = x @ weight + bias``, a batch-mean softmax
    cross-entropy loss, a gradient-descent train op and an accuracy metric,
    then either trains the model or restores it and prints predictions.

    Reads the module-level ``mnist`` dataset, which must provide one-hot
    labels through ``mnist.train.next_batch`` / ``mnist.test.next_batch``.

    Args:
        is_train: True to train (resuming from ``ckpt_path`` if a checkpoint
            exists), log TensorBoard summaries and save the model. False (the
            previous hard-coded behaviour) restores the model from
            ``ckpt_path`` and prints target vs. predicted digit per image.
        train_steps: Number of mini-batch updates when training.
        batch_size: Images per training mini-batch.
        num_test: Number of single test images to predict when not training.
        learning_rate: Gradient-descent step size.
        ckpt_path: Checkpoint prefix passed to ``Saver.save``/``Saver.restore``.
        summary_dir: Directory for TensorBoard event files while training.
    """
    ckpt_dir = os.path.dirname(ckpt_path)

    # Build in a private graph so repeated calls (e.g. train, then test) do
    # not accumulate ops or get uniquified variable names such as
    # 'predict_model_1/w', which would no longer match the checkpoint.
    with tf.Graph().as_default():
        # 1. Inputs: rows of 784 pixel features and one-hot labels for 10 classes.
        with tf.variable_scope('data'):
            x = tf.placeholder(tf.float32, [None, 784])
            # Softmax cross-entropy needs floating-point labels, and the one-hot
            # batches fed here are floats, so float32 (not int32) is required.
            y_true = tf.placeholder(tf.float32, [None, 10])

        # 2. Linear model. Scopes and variable names are deliberately unchanged
        # so checkpoints written by earlier runs still restore.
        with tf.variable_scope('predict_model'):
            weight = tf.Variable(tf.random_normal([784, 10], mean=0.0, stddev=1.0), name='w')
            bias = tf.Variable(tf.constant(0.0, shape=[10]))
            y_predict = tf.matmul(x, weight) + bias

        # 3. Batch-mean softmax cross-entropy (y_predict holds raw logits).
        with tf.variable_scope('loss'):
            loss = tf.reduce_mean(
                tf.nn.softmax_cross_entropy_with_logits(labels=y_true, logits=y_predict))

        # 4. Plain gradient descent.
        with tf.variable_scope('optimizer'):
            train_op = tf.train.GradientDescentOptimizer(learning_rate).minimize(loss)

        # 5. Accuracy. The argmax ops are kept as names so the prediction loop
        # reuses them instead of adding new ops to the graph every iteration.
        with tf.variable_scope('acc'):
            true_label = tf.argmax(y_true, 1)
            predicted_label = tf.argmax(y_predict, 1)
            equal_list = tf.equal(true_label, predicted_label)
            accuracy = tf.reduce_mean(tf.cast(equal_list, tf.float32))

        init_op = tf.global_variables_initializer()

        # Collect values for TensorBoard.
        tf.summary.scalar('loss', loss)
        tf.summary.scalar('accuracy', accuracy)
        tf.summary.histogram('weight', weight)
        tf.summary.histogram('bias', bias)
        merged = tf.summary.merge_all()

        saver = tf.train.Saver()

        with tf.Session() as sess:
            if is_train:
                sess.run(init_op)
                file_writer = tf.summary.FileWriter(summary_dir, graph=sess.graph)
                try:
                    # Resume from a previous run's checkpoint when one exists.
                    if os.path.exists(os.path.join(ckpt_dir, 'checkpoint')):
                        saver.restore(sess, ckpt_path)
                    for i in range(train_steps):
                        # Fetch the next mini-batch of training images.
                        mnist_x, mnist_y = mnist.train.next_batch(batch_size)
                        feed = {x: mnist_x, y_true: mnist_y}
                        sess.run(train_op, feed_dict=feed)
                        # Summaries and accuracy are measured after the update,
                        # in a single run on the same batch.
                        summary, acc = sess.run([merged, accuracy], feed_dict=feed)
                        file_writer.add_summary(summary, i)
                        print("训练第%d步,准确率为:%f" % (i, acc))
                finally:
                    # Flush pending events and release the event file.
                    file_writer.close()
                # TF1 Saver.save refuses to write when the parent directory is missing.
                if ckpt_dir:
                    os.makedirs(ckpt_dir, exist_ok=True)
                saver.save(sess, ckpt_path)
            else:
                saver.restore(sess, ckpt_path)
                for i in range(num_test):
                    # Predict one test image at a time.
                    x_test, y_test = mnist.test.next_batch(1)
                    target, predicted = sess.run(
                        [true_label, predicted_label],
                        feed_dict={x: x_test, y_true: y_test})
                    # Each result has shape [1]; index it so %d gets a scalar.
                    print('第%d张图片,手写数字图片目标:%d--%d' % (i, target[0], predicted[0]))
# Run the demo only when this file is executed as a script, not on import.
if __name__ == "__main__":
    full_connection()
**粗体** _斜体_ [链接](http://example.com) `代码` - 列表 > 引用。你还可以使用 @ 来通知其他用户。