TFrecord是一个Google提供的用于深度学习的数据格式,个人觉得很方便规范,值得学习。本文主要讲的是怎么存储array,别的数据存储较为简单,举一反三就行。

在TFrecord中的数据都需要进行一个转化的过程,这个转化分成三种

  • int64
  • float
  • bytes

一般来讲我们的图片读进来以后是两种形式,

  1. tf.image.decode_jpeg 解码图片读取成 (width,height,channels)的矩阵,这个读取的方式和cv2.imread以及ndimage.imread一样
  2. tf.image.convert_image_dtype会将读进来的上面的矩阵归一化,一般来讲我们都要进行这个归一化的过程。归一化的好处可以去查。

但是存储在TFrecord里面的不能是array的形式,所以我们需要利用tostring()将上面的矩阵转化成字符串再通过tf.train.BytesList转化成可以存储的形式。

下面给个实例代码,大家看看就懂了

adjust_pic.py : 作用就是转化Image大小

# -*- coding: utf-8 -*-  
  
import tensorflow as tf  
  
def resize(img_data, width, high, method=0):  
    return tf.image.resize_images(img_data,[width, high], method)
    

pic2tfrecords.py :将图片存成TFrecord

# -*- coding: utf-8 -*-  
# 将图片保存成 TFRecord  
import os.path  
import matplotlib.image as mpimg  
import tensorflow as tf  
import adjust_pic as ap  
from PIL import Image  
  
  
SAVE_PATH = 'data/dataset.tfrecords'  
  
  
def _int64_feature(value):  
    return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))  
  
def _bytes_feature(value):  
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))  
  
def load_data(datafile, width, high, method=0, save=False):  
    train_list = open(datafile,'r')  
    # 准备一个 writer 用来写 TFRecord 文件  
    writer = tf.python_io.TFRecordWriter(SAVE_PATH)  
  
    with tf.Session() as sess:  
        for line in train_list:  
            # 获得图片的路径和类型  
            tmp = line.strip().split(' ')  
            img_path = tmp[0]  
            label = int(tmp[1])  
  
            # 读取图片  
            image = tf.gfile.FastGFile(img_path, 'r').read()  
            # 解码图片(如果是 png 格式就使用 decode_png)  
            image = tf.image.decode_jpeg(image)  
            # 转换数据类型  
            # 因为为了将图片数据能够保存到 TFRecord 结构体中,所以需要将其图片矩阵转换成 string,所以为了在使用时能够转换回来,这里确定下数据格式为 tf.float32  
            image = tf.image.convert_image_dtype(image, dtype=tf.float32)  
            # 既然都将图片保存成 TFRecord 了,那就先把图片转换成希望的大小吧  
            image = ap.resize(image, width, high)  
            # 执行 op: image  
            image = sess.run(image)  
              
            # 将其图片矩阵转换成 string  
            image_raw = image.tostring()  
            # 将数据整理成 TFRecord 需要的数据结构  
            example = tf.train.Example(features=tf.train.Features(feature={  
                'image_raw': _bytes_feature(image_raw),  
                'label': _int64_feature(label),  
                }))  
  
            # 写 TFRecord  
            writer.write(example.SerializeToString())  
  
    writer.close()  
  
  
load_data('train_list.txt_bak', 224, 224) 


tfrecords2data.py :读取Tfrecord里的内容

# -*- coding: utf-8 -*-  
# 从 TFRecord 中读取并保存图片  
import tensorflow as tf  
import numpy as np  
  
  
SAVE_PATH = 'data/dataset.tfrecords'  
  
  
def load_data(width, high):  
    reader = tf.TFRecordReader()  
    filename_queue = tf.train.string_input_producer([SAVE_PATH])  
  
    # 从 TFRecord 读取内容并保存到 serialized_example 中  
    _, serialized_example = reader.read(filename_queue)  
    # 读取 serialized_example 的格式  
    features = tf.parse_single_example(  
        serialized_example,  
        features={  
            'image_raw': tf.FixedLenFeature([], tf.string),  
            'label': tf.FixedLenFeature([], tf.int64),  
        })  
  
    # 解析从 serialized_example 读取到的内容  
    images = tf.decode_raw(features['image_raw'], tf.uint8)  
    labels = tf.cast(features['label'], tf.int64)  
  
    with tf.Session() as sess:  
        # 启动多线程  
        coord = tf.train.Coordinator()  
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)  
  
        # 因为我这里只有 2 张图片,所以下面循环 2 次  
        for i in range(2):  
            # 获取一张图片和其对应的类型  
            label, image = sess.run([labels, images])  
            # 这里特别说明下:  
            #   因为要想把图片保存成 TFRecord,那就必须先将图片矩阵转换成 string,即:  
            #       pic2tfrecords.py 中 image_raw = image.tostring() 这行  
            #   所以这里需要执行下面这行将 string 转换回来,否则会无法 reshape 成图片矩阵,请看下面的小例子:  
            #       a = np.array([[1, 2], [3, 4]], dtype=np.int64) # 2*2 的矩阵  
            #       b = a.tostring()  
            #       # 下面这行的输出是 32,即: 2*2 之后还要再乘 8  
            #       # 如果 tostring 之后的长度是 2*2=4 的话,那可以将 b 直接 reshape([2, 2]),但现在的长度是 2*2*8 = 32,所以无法直接 reshape  
            #       # 同理如果你的图片是 500*500*3 的话,那 tostring() 之后的长度是 500*500*3 后再乘上一个数  
            #       print len(b)  
            #  
            #   但在网上有很多提供的代码里都没有下面这一行,你们那真的能 reshape ?  
            image = np.fromstring(image, dtype=np.float32)  
            # reshape 成图片矩阵  
            image = tf.reshape(image, [224, 224, 3])  
            # 因为要保存图片,所以将其转换成 uint8  
            image = tf.image.convert_image_dtype(image, dtype=tf.uint8)  
            # 按照 jpeg 格式编码  
            image = tf.image.encode_jpeg(image)  
            # 保存图片  
            with tf.gfile.GFile('pic_%d.jpg' % label, 'wb') as f:  
                f.write(sess.run(image))  
  
  
load_data(224, 224)  

以上代码摘自TFRecord 的使用,觉得挺好的,没改原样照搬,我自己做实验时改了很多,因为我是在im2txt的基础上写的。


jasperyang
203 声望58 粉丝

Highest purpose is Hacking...