PyTorch 数据集 datasets DataLoader 示例
# 安装依赖包
! pip install torchvision
Looking in indexes: https://pypi.tuna.tsinghua.edu.cn/simple
Collecting torchvision
Downloading https://pypi.tuna.tsinghua.edu.cn/packages/f0/cb/b484ba727714926cbebe68687960da3481df5619280d17b1d5c90fb610bc/torchvision-0.11.3-cp38-cp38-win_amd64.whl (947 kB)
-------------------------------------- 948.0/948.0 KB 1.7 MB/s eta 0:00:00
Requirement already satisfied: numpy in c:\python38\lib\site-packages (from torchvision) (1.20.3)
Requirement already satisfied: torch==1.10.2 in c:\python38\lib\site-packages (from torchvision) (1.10.2)
Requirement already satisfied: pillow!=8.3.0,>=5.3.0 in c:\python38\lib\site-packages (from torchvision) (8.3.2)
Requirement already satisfied: typing-extensions in c:\python38\lib\site-packages (from torch==1.10.2->torchvision) (3.7.4.3)
Installing collected packages: torchvision
Successfully installed torchvision-0.11.3
import torch
from torch.utils.data import Dataset
from torchvision import datasets
from torchvision.transforms import ToTensor
import matplotlib.pyplot as plt
# pytorch自带的datasets类
# Build the FashionMNIST training split first, then the test split; the two
# calls differ only in the `train` flag.
training_data, test_data = (
    datasets.FashionMNIST(
        root="data",  # directory where the train/test data is stored
        train=is_train,  # True -> training split, False -> test split
        download=True,  # download the data from the internet
        transform=ToTensor(),  # transform applied to each sample's features
    )
    for is_train in (True, False)
)
# Maps the integer class index (0-9) to its human-readable label name.
# The position of each name in the tuple is its class index.
labels_map = dict(
    enumerate(
        (
            "T-Shirt",
            "Trouser",
            "Pullover",
            "Dress",
            "Coat",
            "Sandal",
            "Shirt",
            "Sneaker",
            "Bag",
            "Ankle Boot",
        )
    )
)
# Show a 3x3 grid of randomly chosen training samples with their label names.
figure = plt.figure(figsize=(8, 8))  # size of the display window
cols, rows = 3, 3
for i in range(1, cols * rows + 1):  # subplot positions are numbered from 1
    # Draw one random index into the training set and fetch that sample.
    sample_idx = torch.randint(len(training_data), size=(1,)).item()
    img, label = training_data[sample_idx]
    figure.add_subplot(rows, cols, i)
    plt.title(labels_map[label])
    plt.axis("off")
    # squeeze() removes size-1 dimensions so imshow receives a plain 2-D image.
    plt.imshow(img.squeeze(), cmap="gray")
plt.show()  # display the whole grid once every cell has been drawn

# 创建数据集
import os
import pandas as pd
from torchvision.io import read_image
# 自定义dataset类
class CustomImageDataset(Dataset):
    """Dataset that pairs image files in a directory with labels from a CSV file.

    The annotations file is loaded with ``pd.read_csv`` using its defaults
    (so the first line is treated as a header row). For each remaining row,
    column 0 is the image file name (joined onto ``img_dir``) and column 1 is
    the label.
    """

    def __init__(
        self,
        annotations_file,
        img_dir,
        transform=None,
        target_transform=None,
    ):
        """Load the annotations table and remember where the images live.

        Args:
            annotations_file: Path (or anything ``pd.read_csv`` accepts) of
                the CSV holding file names and labels.
            img_dir: Directory containing the image files.
            transform: Optional callable applied to each loaded image.
            target_transform: Optional callable applied to each label.
        """
        self.img_labels = pd.read_csv(annotations_file)  # file-name / label table
        self.img_dir = img_dir  # image folder
        self.transform = transform  # image transform
        self.target_transform = target_transform  # label transform

    def __len__(self):
        """Return the number of samples, i.e. rows in the annotations table."""
        return len(self.img_labels)

    def __getitem__(self, idx):
        """Load and return the ``(image, label)`` pair at position ``idx``.

        The image is read from disk with ``read_image`` on every call; the
        optional transforms are applied before returning.
        """
        # Column 0 holds the file name relative to img_dir.
        img_path = os.path.join(self.img_dir, self.img_labels.iloc[idx, 0])
        image = read_image(img_path)
        label = self.img_labels.iloc[idx, 1]  # column 1 holds the label
        if self.transform:
            image = self.transform(image)
        if self.target_transform:
            label = self.target_transform(label)
        return image, label
from torch.utils.data import DataLoader
# Wrap both splits in DataLoaders that yield shuffled mini-batches of 64 samples.
train_dataloader = DataLoader(dataset=training_data, batch_size=64, shuffle=True)
test_dataloader = DataLoader(
    dataset=test_data,
    batch_size=64,
    shuffle=True,
)

# Pull a single batch from the training loader and report its shapes.
train_features, train_labels = next(iter(train_dataloader))
print(f"Feature batch shape: {train_features.size()}")
print(f"Labels batch shape: {train_labels.size()}")

# Preview the first sample of that batch together with its label.
img, label = train_features[0].squeeze(), train_labels[0]
plt.imshow(img, cmap="gray")
plt.show()
print(f"Label: {label}")

版权声明:本文为weixin_44493841原创文章,遵循CC 4.0 BY-SA版权协议,转载请附上原文出处链接和本声明。