• 企业400电话
  • 微网小程序
  • AI电话机器人
  • 电商代运营
  • 全 部 栏 目

    企业400电话 网络优化推广 AI电话机器人 呼叫中心 网站建设 商标✡知产 微网小程序 电商运营 彩铃•短信 增值拓展业务
    tensorflow学习笔记之tfrecord文件的生成与读取

    训练模型时,我们并不是直接将图像送入模型,而是先将图像转换为tfrecord文件,再将tfrecord文件送入模型。为进一步理解tfrecord文件,本例先将6幅图像及其标签转换为tfrecord文件,然后读取tfrecord文件,重现6幅图像及其标签。
    1、生成tfrecord文件

    import os
    import numpy as np
    import tensorflow as tf
    from PIL import Image
    
    filenames = [
    'images/cat/1.jpg',
    'images/cat/2.jpg',
    'images/dog/1.jpg',
    'images/dog/2.jpg',
    'images/pig/1.jpg',
    'images/pig/2.jpg',]
    
    labels = {'cat':0, 'dog':1, 'pig':2}
    
    def int64_feature(values):
    	if not isinstance(values, (tuple, list)):
    		values = [values]
    	return tf.train.Feature(int64_list=tf.train.Int64List(value=values))
    
    def bytes_feature(values):
    	return tf.train.Feature(bytes_list=tf.train.BytesList(value=[values]))
    
    with tf.Session() as sess:
    	output_filename = os.path.join('images/train.tfrecords')
    	with tf.python_io.TFRecordWriter(output_filename) as tfrecord_writer:
    		for filename in filenames:
    			#读取图像
    			image_data = Image.open(filename)
    			#图像灰度化
    			image_data = np.array(image_data.convert('L'))
    			#将图像转化为bytes
    			image_data = image_data.tobytes()
    			#读取label
    			label = labels[filename.split('/')[-2]]
    			#生成protocol数据类型
    			example = tf.train.Example(features=tf.train.Features(feature={'image': bytes_feature(image_data),
    																			'label': int64_feature(label)}))
    			tfrecord_writer.write(example.SerializeToString())

    2、读取tfrecord文件

    import tensorflow as tf
    import matplotlib.pyplot as plt
    from PIL import Image
    
    # 根据文件名生成一个队列
    filename_queue = tf.train.string_input_producer(['images/train.tfrecords'])
    reader = tf.TFRecordReader()
    # 返回文件名和文件
    _, serialized_example = reader.read(filename_queue)
    features = tf.parse_single_example(serialized_example, 
    									features={'image': tf.FixedLenFeature([], tf.string), 
    												'label': tf.FixedLenFeature([], tf.int64)})
    # 获取图像数据
    image = tf.decode_raw(features['image'], tf.uint8)
    # 恢复图像原始尺寸[高,宽]
    image = tf.reshape(image, [60, 160])
    # 获取label
    label = tf.cast(features['label'], tf.int32)
    
    with tf.Session() as sess:
    	# 创建一个协调器,管理线程
    	coord = tf.train.Coordinator()
    	# 启动QueueRunner, 此时文件名队列已经进队
    	threads = tf.train.start_queue_runners(sess=sess, coord=coord)
    
    	for i in range(6):
    		image_b, label_b = sess.run([image, label])
    		img = Image.fromarray(image_b, 'L')
    		plt.imshow(img)
    		plt.axis('off')
    		plt.show()
    		print(label_b)
    
    	# 通知其他线程关闭
    	coord.request_stop()
    	# 其他所有线程关闭之后,这一函数才能返回
    	coord.join(threads)

    到此这篇关于tensorflow学习笔记之tfrecord文件的生成与读取的文章就介绍到这了,更多相关tfrecord文件的生成与读取内容请搜索脚本之家以前的文章或继续浏览下面的相关文章希望大家以后多多支持脚本之家!

    您可能感兴趣的文章:
    • tensorflow TFRecords文件的生成和读取的方法
    • tensorflow生成多个tfrecord文件实例
    • Tensorflow中批量读取数据的案列分析及TFRecord文件的打包与读取
    • tensorflow将图片保存为tfrecord和tfrecord的读取方式
    • tensorflow入门:TFRecordDataset变长数据的batch读取详解
    • Tensorflow中使用tfrecord方式读取数据的方法
    上一篇:python创建与遍历二叉树的方法实例
    下一篇:Python中快速掌握Data Frame的常用操作
  • 相关文章
  • 

    © 2016-2020 巨人网络通讯 版权所有

    《增值电信业务经营许可证》 苏ICP备15040257号-8

    tensorflow学习笔记之tfrecord文件的生成与读取 tensorflow,学习,笔记,之,tfrecord,