• 企业400电话
  • 微网小程序
  • AI电话机器人
  • 电商代运营
  • 全 部 栏 目

    企业400电话 网络优化推广 AI电话机器人 呼叫中心 网站建设 商标✡知产 微网小程序 电商运营 彩铃•短信 增值拓展业务
    python读取mnist数据集方法案例详解

    mnist手写数字数据集在机器学习中非常常见,这里记录一下用python从本地读取mnist数据集的方法。

    数据集格式介绍

    这部分内容网络上很常见,这里还是简明介绍一下。网络上下载的mnist数据集包含4个文件:

    前两个分别是测试集的image和label,包含10000个样本。后两个是训练集的,包含60000个样本。.gz表示这个一个压缩包,如果进行解压的话,会得到.ubyte格式的二进制文件。

    上图是训练集的label和image数据的存储格式。两个文件最开始都有magic number和number of images/items两个数据,有用的是第二个,表示文件中存储的样本个数。另外要注意的是数据的位数,有32位整型和8位整型两种。

    读取方法

    .gz格式的文件读取

    需要import gzip
    读取训练集的代码如下:

    def load_mnist_train(path, kind='train'): 
    '‘'
    path:数据集的路径
    kind:值为train,代表读取训练集
    ‘'‘   
        labels_path = os.path.join(path,'%s-labels-idx1-ubyte.gz'% kind)
        images_path = os.path.join(path,'%s-images-idx3-ubyte.gz'% kind)
        #使用gzip打开文件
        with gzip.open(labels_path, 'rb') as lbpath:
    	    #使用struct.unpack方法读取前两个数据,>代表高位在前,I代表32位整型。lbpath.read(8)表示一次从文件中读取8个字节
    	    #这样读到的前两个数据分别是magic number和样本个数
            magic, n = struct.unpack('>II',lbpath.read(8))
            #使用np.fromstring读取剩下的数据,lbpath.read()表示读取所有的数据
            labels = np.fromstring(lbpath.read(),dtype=np.uint8)
        with gzip.open(images_path, 'rb') as imgpath:
            magic, num, rows, cols = struct.unpack('>IIII',imgpath.read(16))
            images = np.fromstring(imgpath.read(),dtype=np.uint8).reshape(len(labels), 784)
        return images, labels
    

    读取测试集的代码类似。

    非压缩文件的读取

    如果在本地对四个文件解压缩之后,得到的就是.ubyte格式的文件,这时读取的代码有所变化。

    def load_mnist_train(path, kind='train'): 
    '‘'
    path:数据集的路径
    kind:值为train,代表读取训练集
    ‘'‘   
        labels_path = os.path.join(path,'%s-labels-idx1-ubyte'% kind)
        images_path = os.path.join(path,'%s-images-idx3-ubyte'% kind)
        #不再用gzip打开文件
        with open(labels_path, 'rb') as lbpath:
    	    #使用struct.unpack方法读取前两个数据,>代表高位在前,I代表32位整型。lbpath.read(8)表示一次从文件中读取8个字节
    	    #这样读到的前两个数据分别是magic number和样本个数
            magic, n = struct.unpack('>II',lbpath.read(8))
            #使用np.fromfile读取剩下的数据
            labels = np.fromfile(lbpath,dtype=np.uint8)
        with gzip.open(images_path, 'rb') as imgpath:
            magic, num, rows, cols = struct.unpack('>IIII',imgpath.read(16))
            images = np.fromfile(imgpath,dtype=np.uint8).reshape(len(labels), 784)
        return images, labels
    

    读取之后可以查看images和labels的长度,确认读取是否正确。

    到此这篇关于python读取mnist数据集方法案例详解的文章就介绍到这了,更多相关python读取mnist数据集方法内容请搜索脚本之家以前的文章或继续浏览下面的相关文章希望大家以后多多支持脚本之家!

    您可能感兴趣的文章:
    • Python rindex()方法案例详解
    • Python 实现静态链表案例详解
    • Python 概率生成问题案例详解
    • Python 二叉树的概念案例详解
    • Python实现堆排序案例详解
    • 超实用的 10 段 Python 案例
    上一篇:Pyqt5将多个类组合在一个界面显示的完整示例
    下一篇:python实现Nao机器人的单目测距
  • 相关文章
  • 

    © 2016-2020 巨人网络通讯 版权所有

    《增值电信业务经营许可证》 苏ICP备15040257号-8

    python读取mnist数据集方法案例详解 python,读取,mnist,数据,集,