Mnist数据集介绍
这是一个包含了手写数字的图片集合。官网网站:http://yann.lecun.com/exdb/mnist/
它分成了4个文件:
FileName | Description |
---|---|
train-images-idx3-ubyte | 训练集图片,60000张图片 |
train-labels-idx1-ubyte | 训练集标签,60000个标签 |
t10k-images-idx3-ubyte | 测试集图片,10000张图片 |
t10k-labels-idx1-ubyte | 测试集标签,10000个标签 |
Mnist的文件结构
- 所有文件的0x0000~0x0003 表示一个魔数※关于魔数就不展开说明了,反正用不到,你就当一个标签好了
- 0x0004~0x0007 表示文件的内容大小(比如多少个标签,多少个图片)
- 标签文件从0x0008开始就是正式的标签数字了,每个byte放一个标签值,标签值范围为0~9
- 图像文件0x0008~0x0011 代表行数,0x0012~0x0015代表列数,从0x0016开始是每个图片的像素值了。像素值为0~255, 0表示白色,255表示黑色
- 图片行列代表图片由rowscols个像素组成.mnist的图片是28行28列的像素构成的
- 所以对我们来说,标签真正的内容是从0x0008开始的,图片真正的内容是从0x0016开始的。读取文件的时候,我们就要分别偏移8和16个byte了。
下载Mnist数据集
- 直接到官网下载,下载4个文件
- 通过代码
下面展示通过代码的方法
import os
import os.path
import urllib.request
# mnist的下载地址
mnist_url = 'http://yann.lecun.com/exdb/mnist/'
# mnist的下载文件,用dict形式来存储。后续处理文件会比较方便点
mnist_file = {
'tr-img':'train-images-idx3-ubyte.gz', # 训练图片
'tr-lbl':'train-labels-idx1-ubyte.gz', # 训练标签
'te-img':'t10k-images-idx3-ubyte.gz', # 测试图片
'te-lbl':'t10k-labels-idx1-ubyte.gz', # 测试标签
}
# 当前文件的文件夹地址
py_dir = os.path.dirname(os.path.abspath(__file__))
# dataset的文件夹地址,把mnist数据都放入dataset文件夹中
ds_dir = py_dir + '/dataset'
# 放到类里面,先实现下载
class JhuMnist:
def __init__(self):
self._mnist_download() # 下载数据集
def _mnist_download(self):
"""下载mnist数据集"""
# 创建下载的文件夹
if not os.path.exists(ds_dir):
os.makedirs(ds_dir)
# 循环遍历下载文件
for f in mnist_file.values():
f_path = ds_dir + '/' + f
if os.path.exists(f_path):
continue
print(f'开始下载{f} ...')
urllib.request.urlretrieve(mnist_url + f, f_path)
print('Done')
if __name__ == '__main__':
j = JhuMnist()
第一次下载的话会比较慢,下载完成后你会发现新建立了一个dataset文件夹,里面有4个文件:
- train-images-idx3-ubyte.gz
- train-labels-idx1-ubyte.gz
- t10k-images-idx3-ubyte.gz
- t10k-labels-idx1-ubyte.gz
转换数据到Numpy中
用代码来解释,这里代码我只写出新增或修改的部分
import gzip # gzip的解压缩
import numpy as np # Numpy导入
class JhuMnist:
def __init__(self):
self._mnist_download() # 下载数据集
self._convert_numpy() # Numpy转换
def _read_mnist(self,type,f_path):
"""
读取mnist文件
input:
- type: 读取的文件类型('lbl':标签,'img':图片)
- f_path: 读取的文件路径
output:
- numpy array 数组
"""
offset = 8 if type == 'lbl' else 16 # 标签正文移动8,图片移动16
with gzip.open(f_path,'rb') as f:
ds = np.frombuffer(f.read(), np.uint8, offset=offset)
"""
注意如果这里对图片不重新reshape的话,那么图片的ds.shape就是
训练:(10000*28*28,) 而不是(60000,28*28)
测试:(10000*28*28,) 而不是(10000,28*28)
reshape(-1,28*28) 等同于 reshape([-1,28*28])
设置为-1代表让系统自己去计算,等同于reshape(60000,28) or (10000,28)
"""
if type == 'img':
ds = ds.reshape(-1,28*28)
return ds
def _convert_numpy(self):
"""转换数据集变成Numpy类型"""
self.ds = {} # 将内容都放入到dict
for i in mnist_file:
type = 'img' if i.find('img') != -1 else 'lbl' # 根据标签的key设置文件类型,因为要根据文件类型来偏移
self.ds[i] = self._read_mnist(type,ds_dir+'/'+mnist_file[i])
# 这个时候你已经可以通过 j.ds获取到各个数据集的numpy内容了
if __name__ == '__main__':
j = JhuMnist()
把Numpy数据保存起来
我们看到,我们使用numpy的数据其实是通过gzip->numpy的,那么如果每次我们都要gzip解压,转换numpy数据就比较麻烦,所以大家都通过pickle工具来保存训练好的一些numpy数据 老样子,用代码来解释,这里代码我只写出新增或修改的部分
import pickle
# pkl文件的地址
pkl_file = py_dir + '/' + 'mnist.pkl'
class JhuMnist:
def __init__(self):
# 如果已经有pkl文件了就不再去下载转换保存了
if not os.path.exists(pkl_file):
self._mnist_download() # 下载数据集
self._convert_numpy() # Numpy转换
self._save_as_pkl() # 保存pkl
def _save_as_pkl(self):
with open(pkl_file, 'wb') as f:
pickle.dump(self.ds, f, -1)
print('Done')
def load_mnist(self) -> dict:
with open(pkl_file,'rb') as f:
ds = pickle.load(f)
return ds
# 通过Load方法就可以快速的读取数据
if __name__ == '__main__':
j = JhuMnist()
ds = j.load_mnist()
展示图片数据
用matplotlib可以展示图片,老样子,用代码来解释,这里代码我只写出新增或修改的部分
import matplotlib.pyplot as plt
if __name__ == '__main__':
j = JhuMnist()
ds = j.load_mnist()
# 读取训练集的第一个数据,因为图片展示是2维的数据,所以需要在做一次reshape
img = ds['tr-img'][0].reshape(28, 28)
fig =plt.figure()
plt.axis('off')
plt.imshow(img, cmap='gray')
plt.show()
OK,mnist的数据就讲完了,是不是很简单~