6. The MNIST Handwritten Digit Dataset

MNIST dataset (http://yann.lecun.com/exdb/mnist/)

A dataset of handwritten digit images with 60,000 training samples and 10,000 test samples. Each sample is a 28×28-pixel grayscale image.

The dataset consists of four gzip-compressed files:

  train-images-idx3-ubyte.gz: raw pixel data of the training images
  train-labels-idx1-ubyte.gz: labels of the training images
  t10k-images-idx3-ubyte.gz: raw pixel data of the test images
  t10k-labels-idx1-ubyte.gz: labels of the test images
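All four files use the IDX format described on the MNIST page. Each starts with a 4-byte magic number (two zero bytes, a type code where 0x08 means unsigned byte, then the number of dimensions), followed by one big-endian 32-bit size per dimension and then the raw data. Below is a minimal pure-Python sketch of a reader. It assumes only the unsigned-byte type that MNIST uses. The names `parse_idx` and `read_idx_gz` are hypothetical helpers, not part of torchvision, which has its own internal reader.

```python
import gzip
import struct


def parse_idx(raw: bytes):
    """Parse an IDX byte string (unsigned-byte data only, as in MNIST).

    Returns (shape, values): shape is a tuple of dimension sizes and
    values is a flat bytes object with prod(shape) entries.
    """
    zero1, zero2, dtype_code, ndim = struct.unpack(">BBBB", raw[:4])
    if zero1 != 0 or zero2 != 0:
        raise ValueError("not an IDX file: the first two bytes must be zero")
    if dtype_code != 0x08:
        raise ValueError(f"only unsigned-byte data (0x08) is handled, got {dtype_code:#04x}")
    # One big-endian unsigned 32-bit size per dimension follows the magic number
    shape = struct.unpack(f">{ndim}I", raw[4:4 + 4 * ndim])
    count = 1
    for size in shape:
        count *= size
    offset = 4 + 4 * ndim
    values = raw[offset:offset + count]
    if len(values) != count:
        raise ValueError("file is truncated")
    return shape, values


def read_idx_gz(path: str):
    """Decompress a .gz IDX file (e.g. train-labels-idx1-ubyte.gz) and parse it."""
    with gzip.open(path, "rb") as f:
        return parse_idx(f.read())
```

For train-images-idx3-ubyte.gz this sketch should report the shape (60000, 28, 28), and for train-labels-idx1-ubyte.gz it should report (60000,).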

Step 1: Downloading the dataset

Reference: MNIST, Torchvision 0.12 documentation (https://pytorch.org/vision/stable/generated/torchvision.datasets.MNIST.html#torchvision.datasets.MNIST)

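from torchvision.transforms import ToTensor
from torchvision.datasets import MNIST

trainData = MNIST(root="./",              # directory where the dataset is stored
                  train=True,             # the 60,000-sample training split
                  transform=ToTensor(),   # PIL image -> [1, 28, 28] float tensor in [0, 1]
                  download=True)          # download the files if they are not present yet
testData = MNIST(root="./",
                 train=False,             # the 10,000-sample test split
                 transform=ToTensor(),
                 download=True)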

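With download=True, torchvision creates an MNIST folder under root (here the current directory) and saves the four MNIST files in ./MNIST/raw. Files that are already there are not downloaded again. With download=False, the files are loaded from ./MNIST/raw, and an error is raised if they are missing.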

Step 2: Loading the dataset

Reference: torch.utils.data, PyTorch 1.11.0 documentation (https://pytorch.org/docs/stable/data.html#torch.utils.data.DataLoader)

from torch.utils.data import DataLoader

batch_size = 64
trainData_loader = DataLoader(dataset=trainData,
                              batch_size=batch_size,
                              shuffle=True)    # reshuffle the training data every epoch
testData_loader = DataLoader(dataset=testData,
                             batch_size=batch_size,
                             shuffle=False)   # evaluation does not depend on sample order

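batch_size = 64 means that each iteration over a DataLoader yields a batch of 64 samples. shuffle=True draws the samples in a new random order at the start of every epoch.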
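With its default drop_last=False, a DataLoader keeps a final, smaller batch whenever the dataset size is not a multiple of batch_size. The pure-Python sketch below reproduces that arithmetic. The helper name `batch_layout` is hypothetical. For the 60,000 training images it gives 938 batches, and the last batch holds 32 samples. This should agree with len(trainData_loader).

```python
def batch_layout(num_samples: int, batch_size: int, drop_last: bool = False):
    """Return (number_of_batches, size_of_last_batch) for a map-style dataset."""
    full, remainder = divmod(num_samples, batch_size)
    if remainder and not drop_last:
        # The leftover samples form one extra, smaller batch
        return full + 1, remainder
    # Either the split is exact or the leftover samples are discarded
    return full, (batch_size if full else 0)


for name, n in [("train", 60000), ("test", 10000)]:
    batches, last = batch_layout(n, 64)
    print(f"{name}: {batches} batches, last batch has {last} samples")
```

Passing drop_last=True to DataLoader discards the 32 leftover training samples, so every batch has exactly 64 samples.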

Step 3: Understanding the sample data

3.1 Inspecting the data

# Take the first batch from the training loader
data, labels = next(iter(trainData_loader))
print(data.shape)    # [batch, channels, height, width]
print(labels)
Output (the labels vary between runs because the training loader shuffles):

torch.Size([64, 1, 28, 28])
tensor([3, 9, 0, 1, 2, 1, 5, 1, 8, 1, 9, 8, 3, 4, 3, 0, 9, 8, 3, 9, 4, 9, 6, 9,
        7, 4, 5, 3, 0, 6, 1, 4, 0, 6, 1, 8, 5, 0, 5, 8, 0, 7, 1, 8, 1, 4, 6, 9,
        4, 6, 7, 4, 2, 5, 4, 7, 1, 2, 6, 1, 9, 0, 0, 7])

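data.shape is [64, 1, 28, 28]: 64 samples, each with one channel, and each channel holds 28×28 pixels;

labels: the labels of these 64 samples.

Note: a grayscale image usually has a single channel, while a color image has three channels, one for each of the RGB primaries.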
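ToTensor is what produces the [1, 28, 28] layout. It turns an H×W (grayscale) or H×W×C (color) image with 0 to 255 pixel values into a C×H×W float tensor scaled to [0, 1]. The pure-Python sketch below imitates only that reordering and scaling on nested lists. The helper `to_chw` is hypothetical and is not torchvision's implementation.

```python
def to_chw(image):
    """Convert an H x W (grayscale) or H x W x C image of 0-255 ints
    into a C x H x W nested list of floats in [0, 1]."""
    height, width = len(image), len(image[0])
    # A grayscale pixel is a single int; wrap it so every pixel is a list of channels
    if isinstance(image[0][0], int):
        image = [[[p] for p in row] for row in image]
    channels = len(image[0][0])
    # Move the channel axis to the front and scale 0..255 to 0.0..1.0
    return [
        [[image[y][x][c] / 255.0 for x in range(width)] for y in range(height)]
        for c in range(channels)
    ]


gray = [[0, 255], [51, 102]]            # 2 x 2 grayscale image
chw = to_chw(gray)
print(len(chw), len(chw[0]), len(chw[0][0]))   # one channel, like MNIST
rgb = [[[255, 0, 0], [0, 255, 0]]]      # 1 x 2 color image with 3 channels
print(len(to_chw(rgb)))                 # three channels
```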


3.2 Displaying the data

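import matplotlib.pyplot as plt
data = data.squeeze(1)   # drop the channel dimension: [64,1,28,28] -> [64,28,28]

fig = plt.figure(dpi=300)
for i in range(8):
    for j in range(8):
        plt.subplot(8, 8, i*8+j+1)             # 8x8 grid; subplot indices start at 1
        plt.imshow(data[i*8+j], cmap="gray")   # image number i*8+j, in grayscale
        plt.xticks([])                         # hide the axis ticks
        plt.yticks([])
plt.show()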

Read left to right and top to bottom, the generated images correspond one-to-one to the labels printed in step 3.1.
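In the subplot loop, image number k = i*8 + j is drawn at row i, column j, so the grid is filled in row-major order. The sketch below uses a hypothetical helper, `label_grid`, to lay out the labels printed in step 3.1 in the same way. This makes it easy to check each digit in the figure against its label.

```python
def label_grid(labels, rows=8, cols=8):
    """Arrange a flat label sequence in the same row-major order as
    plt.subplot(rows, cols, i*cols + j + 1)."""
    if len(labels) < rows * cols:
        raise ValueError("not enough labels to fill the grid")
    return [[labels[i * cols + j] for j in range(cols)] for i in range(rows)]


# Labels printed in step 3.1 (they change on every run because of shuffle=True)
labels = [3, 9, 0, 1, 2, 1, 5, 1, 8, 1, 9, 8, 3, 4, 3, 0, 9, 8, 3, 9, 4, 9, 6, 9,
          7, 4, 5, 3, 0, 6, 1, 4, 0, 6, 1, 8, 5, 0, 5, 8, 0, 7, 1, 8, 1, 4, 6, 9,
          4, 6, 7, 4, 2, 5, 4, 7, 1, 2, 6, 1, 9, 0, 0, 7]
for row in label_grid(labels):
    print(" ".join(str(d) for d in row))
```

With the real batch, pass labels.tolist() instead of the hard-coded list.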

Appendix: complete code

from torchvision.transforms import ToTensor
from torchvision.datasets import MNIST
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt

# Step 1: download (if needed) and load the training and test splits
trainData = MNIST(root="./",
                  train=True,
                  transform=ToTensor(),
                  download=True)
testData = MNIST(root="./",
                 train=False,
                 transform=ToTensor(),
                 download=True)

# Step 2: wrap the datasets in DataLoaders that yield batches of 64 samples
batch_size = 64
trainData_loader = DataLoader(dataset=trainData,
                              batch_size=batch_size,
                              shuffle=True)

testData_loader = DataLoader(dataset=testData,
                             batch_size=batch_size,
                             shuffle=False)

# Step 3: take the first training batch and show its 64 images in an 8x8 grid
data, labels = next(iter(trainData_loader))
images = data.squeeze(1)   # [64, 1, 28, 28] -> [64, 28, 28]

fig = plt.figure()
for i in range(8):
    for j in range(8):
        plt.subplot(8, 8, i*8+j+1)
        plt.imshow(images[i*8+j], cmap="gray")
        plt.xticks([])
        plt.yticks([])
plt.show()


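Copyright notice: This is an original article by Austin6035, licensed under CC 4.0 BY-SA. When reposting, please include a link to the original source and this notice.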