Tensorflow基本概念

喵先生的进阶之路

Tensorflow基本概念

1.Tensor

Tensorflow张量,是Tensorflow中最基础的概念,也是最主要的数据结构。它是一个N维数组。

2.Variable

Tensorflow变量,一般用于表示图中的各计算参数,包括矩阵,向量等。它在图中有固定的位置。

3.placeholder

Tensorflow占位符,用于表示输入输出数据的格式,允许传入指定的类型和形状的数据。

4.Session

Tensorflow会话,在Tensorflow中是计算图的具体执行者,与图进行实际的交互。

5.Operation

Tensorflow操作,是Tensorflow图中的节点,它的输入和输出都是Tensor。它的操作都是完成各种操作,包括算数操作、矩阵操作、神经网络构建操作等。

6.Queue

Tensorflow队列,也是图中的一个节点,是一种有状态的节点。

7.QueueRunner

队列管理器,通常会使用多个线程来读取数据,然后使用一个线程来使用数据。使用队列管理器来管理这些读写队列的线程。

8.Coordinator

使用QueueRunner时,由于入队和出队由各自线程完成,且未进行同步通讯,导致程序无法正常结束的情况。为了实现线程之间的同步,需要使用Coordinator

Tensorflow程序步骤

(一)加载训练数据

1.生成或导入样本数据集。
2.归一化处理。
3.划分样本数据集为训练样本集测试样本集

(二)构建训练模型

1.初始化超参数
2.初始化变量和占位符
3.定义模型结构
4.定义损失函数

(三)进行数据训练

1.初始化模型
2.加载数据进行训练

(四)评估和预测

1.评估机器学习模型
2.调优超参数
3.预测结果

加载数据

在Tensorflow中加载数据的方式一共有三种:预加载数据、填充数据和从文件读取数据。

预加载数据

在Tensorflow中定义常量或变量来保存所有数据,例如:

a = tf.constant([1, 2])
b = tf.constant([3, 4])
x = tf.add(a, b)
因为常数会直接存储在数据流图数据结构中,在训练过程中,这个结构体可能会被复制多次,从而导致内存的大量消耗。

填充数据

将数据填充到任意一个张量中。然后通过会话run()函数中的feed_dict参数进行获取数据:

数据量大时,填充数据的方式也存在消耗内存的问题。

从CVS文件中读取数据

要存文件中读取数据, 首先需要使用读取器将数据读取到队列中,然后从队列中获取数据进行处理:

1.创建队列
2.创建读取器获取数据
3.处理数据

读取TFRecords数据

Tensorflow针对处理数据量巨大的应用场景进行了优化,定义了TFRecords格式。

采用这种读取方式读取数据分为两个步骤:

1.把样本数据转换为TFRecords二进制文件
2.读取TFRecords格式。

存储和加载模型

Tensorflow中提供了tf.train.Saver类实现训练模型的保存和加载。

存储模型

在模型的设计和训练的过程中,会消耗大量的时间。为了降低训练过程中意外情况发生造成的不良影响,所以会对训练过程中模型进行定期存储。(模型复用,节省整体训练时间)

saver = tf.train.Saver(max_to_keep, keep_checkpoint_event_n_hours)

存储的模型,会生成四个文件:

image.png

加载模型

为了保证意外中断的模型能够继续训练以及训练完成的模型加载在其他数据上直接使用,会对模型进行加载使用。

加载存储好的模型,包括了两个步骤:

1.加载模型:

saver = tf.train.import_meta_graph("my_test_model-100.meta")

2.加载训练参数:

saver.restore(sess, tf.train.latest_checkpoint('./'))
阅读 570

朴世超
个人学习总结与项目实战问题记录
336 声望
19 粉丝
0 条评论
你知道吗?

336 声望
19 粉丝
宣传栏