MNIST dataset (http://yann.lecun.com/exdb/mnist/)
A dataset of handwritten digit images, with 60,000 training samples and 10,000 test samples. Each sample is a 28×28-pixel grayscale image.

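It consists of four compressed files:
train-images-idx3-ubyte.gz: raw pixel data of the training images
train-labels-idx1-ubyte.gz: labels of the training images
t10k-images-idx3-ubyte.gz: raw pixel data of the test images
t10k-labels-idx1-ubyte.gz: labels of the test images

Step 1: Downloading the dataset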
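from torchvision.transforms import ToTensor
from torchvision.datasets import MNIST

# Training set (60,000 images); ToTensor converts each image into a float tensor
trainData = MNIST(root="./",
                  train=True,
                  transform=ToTensor(),
                  download=True)
# Test set (10,000 images)
testData = MNIST(root="./",
                 train=False,
                 transform=ToTensor(),
                 download=True)

If download is True, an MNIST folder is created in the current working directory and the four MNIST files are placed in ./MNIST/raw (files that are already there are not downloaded again). Otherwise, the four files are loaded from ./MNIST/raw, and an error is raised if they are missing.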
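The four .gz files use the IDX format described on the MNIST page linked above: a 4-byte magic number (two zero bytes, a type code where 0x08 means unsigned byte, and the number of dimensions), then one big-endian 32-bit size per dimension, then the raw values, one byte each. torchvision parses these files for you, so the following is only a minimal standard-library sketch of what is inside them. The helper names parse_idx, expected_idx_size and load_idx_gz are hypothetical, and only the unsigned-byte type used by MNIST is handled.

```python
import gzip
import struct
from math import prod


def parse_idx(raw: bytes):
    """Parse a decompressed IDX buffer into (shape, flat list of values)."""
    # Magic number: two zero bytes, a type code (0x08 = unsigned byte), number of dimensions
    zero, type_code, ndim = struct.unpack(">HBB", raw[:4])
    if zero != 0 or type_code != 0x08:
        raise ValueError("not an unsigned-byte IDX buffer")
    # One big-endian unsigned 32-bit integer per dimension size
    shape = struct.unpack(">" + "I" * ndim, raw[4:4 + 4 * ndim])
    body = raw[4 + 4 * ndim:]
    if len(body) != prod(shape):
        raise ValueError(f"expected {prod(shape)} values, got {len(body)}")
    # Each remaining byte is one pixel value (0..255) or one label (0..9)
    return shape, list(body)


def expected_idx_size(shape):
    """Uncompressed file size: 4-byte magic + 4 bytes per dimension + 1 byte per value."""
    return 4 + 4 * len(shape) + prod(shape)


def load_idx_gz(path):
    """Decompress and parse one of the four MNIST .gz files."""
    with gzip.open(path, "rb") as f:
        return parse_idx(f.read())
```

For example, load_idx_gz("./MNIST/raw/train-labels-idx1-ubyte.gz") should return the shape (60000,) and the 60,000 training labels, assuming the downloaded .gz files were left in ./MNIST/raw. By the same layout, the uncompressed training-image file is 16 + 60000 × 784 = 47,040,016 bytes.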
Step 2: Loading the dataset
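from torch.utils.data import DataLoader

batch_size = 64
# Each loader yields mini-batches of batch_size samples;
# shuffle=True reshuffles the sample order at every epoch
trainData_loader = DataLoader(dataset=trainData,
                              batch_size=batch_size,
                              shuffle=True)
testData_loader = DataLoader(dataset=testData,
                             batch_size=batch_size,
                             shuffle=True)

batch_size = 64 means that each iteration over a loader returns 64 samples.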
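To make batch_size and shuffle concrete, here is a plain-Python sketch of the index bookkeeping a DataLoader performs: draw a permutation of the sample indices, then cut it into consecutive chunks of batch_size. It is a simplification (no worker processes, no collation into tensors), and batch_indices and num_batches are hypothetical names, not PyTorch APIs.

```python
import math
import random


def batch_indices(num_samples, batch_size, shuffle=True, seed=None):
    """Yield one list of sample indices per mini-batch."""
    indices = list(range(num_samples))
    if shuffle:
        # A fresh permutation; DataLoader(shuffle=True) draws a new one every epoch
        random.Random(seed).shuffle(indices)
    for start in range(0, num_samples, batch_size):
        # The final slice may be shorter than batch_size (it is kept, as with drop_last=False)
        yield indices[start:start + batch_size]


def num_batches(num_samples, batch_size, drop_last=False):
    """Batches per epoch, i.e. what len(loader) reports for a map-style dataset."""
    if drop_last:
        return num_samples // batch_size
    return math.ceil(num_samples / batch_size)
```

With batch_size = 64, the training loader therefore yields 938 batches per epoch, the last holding 60000 − 937 × 64 = 32 samples, and the test loader yields 157 batches, the last holding 16. Passing drop_last=True to DataLoader would discard those incomplete final batches.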
Step 3: Understanding the sample data
3.1 Inspecting the data
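# Take the first batch from the training loader
examples = enumerate(trainData_loader)
idx, (data, labels) = next(examples)
print(data.shape)
print(labels)

Output (the labels differ from run to run because the training data is shuffled):
torch.Size([64, 1, 28, 28])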
tensor([3, 9, 0, 1, 2, 1, 5, 1, 8, 1, 9, 8, 3, 4, 3, 0, 9, 8, 3, 9, 4, 9, 6, 9,
7, 4, 5, 3, 0, 6, 1, 4, 0, 6, 1, 8, 5, 0, 5, 8, 0, 7, 1, 8, 1, 4, 6, 9,
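4, 6, 7, 4, 2, 5, 4, 7, 1, 2, 6, 1, 9, 0, 0, 7])

data.shape is [64, 1, 28, 28]: 64 samples, each with one channel, and each channel holds 28×28 pixels.
labels: the labels of these 64 sample images.
Note: a grayscale image usually has only one channel; a color image has three channels, corresponding to the three primary colors R, G and B.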
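The channel dimension comes from ToTensor: MNIST hands each image to the transform as a 28×28 grayscale (mode "L") PIL image, and ToTensor returns a float tensor of shape 1×28×28 with the pixel values 0 to 255 scaled to [0.0, 1.0]; the DataLoader then stacks 64 of these into [64, 1, 28, 28]. The nested-list sketch below mimics those shape changes, plus the squeeze used in 3.2, without PyTorch. to_chw_float, squeeze_channel and shape_of are hypothetical helpers, not torchvision's implementation.

```python
def to_chw_float(image_hw):
    """Mimic ToTensor on one grayscale image: H x W ints in 0..255 -> 1 x H x W floats in [0, 1]."""
    # Wrapping the rows in one extra list adds the single channel dimension
    return [[[pixel / 255.0 for pixel in row] for row in image_hw]]


def squeeze_channel(batch_nchw):
    """Mimic data.squeeze(1): drop the size-1 channel axis, N x 1 x H x W -> N x H x W."""
    return [image[0] for image in batch_nchw]


def shape_of(nested):
    """Shape of a rectangular nested list, read off its first element at each level."""
    dims = []
    while isinstance(nested, list):
        dims.append(len(nested))
        nested = nested[0] if nested else None
    return tuple(dims)


# A batch of 64 blank 28 x 28 "images", converted and stacked like a DataLoader batch
images = [[[0] * 28 for _ in range(28)] for _ in range(64)]
batch = [to_chw_float(image) for image in images]
print(shape_of(batch))                   # (64, 1, 28, 28)
print(shape_of(squeeze_channel(batch)))  # (64, 28, 28)
```

A color dataset would carry three channels per image instead, giving batches of shape [64, 3, H, W].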
3.2 Visualizing the data
import matplotlib.pyplot as plt

data = data.squeeze(1)  # drop the channel dimension: [64, 1, 28, 28] -> [64, 28, 28]
fig = plt.figure(dpi=300)
for i in range(8):
    for j in range(8):
        plt.subplot(8, 8, i*8 + j + 1)  # subplot indices are 1-based, row by row
        plt.imshow(data[i*8 + j])
        plt.xticks([])  # hide the axis ticks
        plt.yticks([])
plt.show()
The images in the resulting figure correspond one-to-one, in order, to the labels printed in step 3.1.
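The correspondence holds because image k is drawn at row k // 8, column k % 8 (subplot index k + 1), the same order in which labels are printed. Laying the labels out in the same 8×8 grid makes the comparison easy; grid_position and label_grid are hypothetical helpers, and printed is simply the label tensor shown in 3.1, which will differ on another run. With a live batch, pass labels.tolist() instead.

```python
def grid_position(k, ncols=8):
    """Row, column and 1-based subplot index of the k-th image in the figure."""
    row, col = divmod(k, ncols)
    return row, col, row * ncols + col + 1


def label_grid(labels, ncols=8):
    """Split a flat label sequence into rows of ncols, matching the subplot layout."""
    labels = list(labels)
    return [labels[start:start + ncols] for start in range(0, len(labels), ncols)]


# The labels printed in step 3.1
printed = [3, 9, 0, 1, 2, 1, 5, 1, 8, 1, 9, 8, 3, 4, 3, 0, 9, 8, 3, 9, 4, 9, 6, 9,
           7, 4, 5, 3, 0, 6, 1, 4, 0, 6, 1, 8, 5, 0, 5, 8, 0, 7, 1, 8, 1, 4, 6, 9,
           4, 6, 7, 4, 2, 5, 4, 7, 1, 2, 6, 1, 9, 0, 0, 7]
for row in label_grid(printed):
    print(" ".join(str(label) for label in row))
```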
Appendix: Complete code
from torchvision.transforms import ToTensor
from torchvision.datasets import MNIST
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt

trainData = MNIST(root="./",
                  train=True,
                  transform=ToTensor(),
                  download=True)
testData = MNIST(root="./",
                 train=False,
                 transform=ToTensor(),
                 download=True)

batch_size = 64
trainData_loader = DataLoader(dataset=trainData,
                              batch_size=batch_size,
                              shuffle=True)
testData_loader = DataLoader(dataset=testData,
                             batch_size=batch_size,
                             shuffle=True)

examples = enumerate(trainData_loader)
idx, (data, labels) = next(examples)

fig = plt.figure()
for i in range(8):
    for j in range(8):
        plt.subplot(8, 8, i*8 + j + 1)
        plt.imshow(data.squeeze(1)[i*8 + j])
        plt.xticks([])
        plt.yticks([])
plt.show()

Copyright notice: This is an original article by Austin6035, licensed under CC 4.0 BY-SA. When reposting, please include a link to the original source and this notice.