使用TensorFlow怎么高效的读取数据
更新:HHH   时间:2023-1-7


使用TensorFlow怎么高效的读取数据,针对这个问题,这篇文章详细介绍了相对应的分析和解答,希望可以帮助更多想解决这个问题的小伙伴找到更简单易行的方法。

TFRecords

TFRecords其实是一种二进制文件,虽然它不如其他格式好理解,但是它能更好的利用内存,更方便复制和移动,并且不需要单独的标签文件(等会儿就知道为什么了)… …总而言之,这样的文件格式好处多多,所以让我们用起来吧。

TFRecords文件包含了tf.train.Example 协议内存块(protocol buffer)(协议内存块包含了字段 Features)。我们可以写一段代码获取你的数据, 将数据填入到Example协议内存块(protocol buffer),将协议内存块序列化为一个字符串, 并且通过tf.python_io.TFRecordWriter 写入到TFRecords文件。

从TFRecords文件中读取数据, 可以使用tf.TFRecordReader的tf.parse_single_example解析器。这个操作可以将Example协议内存块(protocol buffer)解析为张量。

接下来,让我们开始读取数据之旅吧~

生成TFRecords文件

我们使用tf.train.Example来定义我们要填入的数据格式,然后使用tf.python_io.TFRecordWriter来写入。

import os
import tensorflow as tf 
from PIL import Image

cwd = os.getcwd()

'''
此处我加载的数据目录如下:
0 -- img1.jpg
   img2.jpg
   img3.jpg
   ...
1 -- img1.jpg
   img2.jpg
   ...
2 -- ...
 这里的0, 1, 2...就是类别,也就是下文中的classes
 classes是我根据自己数据类型定义的一个列表,大家可以根据自己的数据情况灵活运用
...
'''
writer = tf.python_io.TFRecordWriter("train.tfrecords")
for index, name in enumerate(classes):
  class_path = cwd + name + "/"
  for img_name in os.listdir(class_path):
    img_path = class_path + img_name
      img = Image.open(img_path)
      img = img.resize((224, 224))
    img_raw = img.tobytes()       #将图片转化为原生bytes
    example = tf.train.Example(features=tf.train.Features(feature={
      "label": tf.train.Feature(int64_list=tf.train.Int64List(value=[index])),
      'img_raw': tf.train.Feature(bytes_list=tf.train.BytesList(value=[img_raw]))
    }))
    writer.write(example.SerializeToString()) #序列化为字符串
writer.close()

关于Example Feature的相关定义和详细内容,我推荐去官网查看相关API。

基本的,一个Example中包含Features,Features里包含Feature(这里没s)的字典。最后,Feature里包含有一个 FloatList, 或者ByteList,或者Int64List

就这样,我们把相关的信息都存到了一个文件中,所以前面才说不用单独的label文件。而且读取也很方便。

接下来是一个简单的读取小例子:

for serialized_example in tf.python_io.tf_record_iterator("train.tfrecords"):
  example = tf.train.Example()
  example.ParseFromString(serialized_example)

  image = example.features.feature['image'].bytes_list.value
  label = example.features.feature['label'].int64_list.value
  # 可以做一些预处理之类的
  print image, label

使用队列读取

一旦生成了TFRecords文件,为了高效地读取数据,TF中使用队列(queue)读取数据。

def read_and_decode(filename):
  #根据文件名生成一个队列
  filename_queue = tf.train.string_input_producer([filename])

  reader = tf.TFRecordReader()
  _, serialized_example = reader.read(filename_queue)  #返回文件名和文件
  features = tf.parse_single_example(serialized_example,
                    features={
                      'label': tf.FixedLenFeature([], tf.int64),
                      'img_raw' : tf.FixedLenFeature([], tf.string),
                    })

  img = tf.decode_raw(features['img_raw'], tf.uint8)
  img = tf.reshape(img, [224, 224, 3])
  img = tf.cast(img, tf.float32) * (1. / 255) - 0.5
  label = tf.cast(features['label'], tf.int32)

  return img, label

之后我们可以在训练的时候这样使用

img, label = read_and_decode("train.tfrecords")

#使用shuffle_batch可以随机打乱输入
img_batch, label_batch = tf.train.shuffle_batch([img, label],
                        batch_size=30, capacity=2000,
                        min_after_dequeue=1000)
init = tf.initialize_all_variables()

with tf.Session() as sess:
  sess.run(init)
  threads = tf.train.start_queue_runners(sess=sess)
  for i in range(3):
    val, l= sess.run([img_batch, label_batch])
    #我们也可以根据需要对val, l进行处理
    #l = to_categorical(l, 12) 
    print(val.shape, l)

关于使用TensorFlow怎么高效的读取数据问题的解答就分享到这里了,希望以上内容可以对大家有一定的帮助,如果你还有很多疑惑没有解开,可以关注天达云行业资讯频道了解更多相关知识。

返回开发技术教程...