pytorch通过ImageFolder函数读取数据集(详细实例)

pytorch通过ImageFolder函数读取数据集(详细实例)

dataset = ImageFolder(“E:/pycharmproject/dataset_read/veg200_images/”, transform=data_transform) #主要有两个参数,一个是图像根目录(被映射成标签的子目录的上一级),一个是数据操作

本文理论参考,以下博客的方法二:https://blog.csdn.net/qq_36852276/article/details/94588656

1 代码

#通过子文件夹映射成标签来使用

from torchvision.datasets import ImageFolder #该方法主要函数,将子文件夹映射成标签;主要有两个参数,一个是图像根目录(被映射成标签的子目录的上一级),一个是数据操作
import torch
import torchvision #用于下载数据集,进行图像增广操作等
import matplotlib.pyplot as plt #用于显示图片
import numpy as np


import os
os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE"

#忽略警告
import warnings
warnings.filterwarnings('ignore')

#选择运行设备
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

#加载数据集
#图像增广一定要注意三者的顺序,具体可见:https://blog.csdn.net/DD_PP_JJ/article/details/102730050
data_transform = torchvision.transforms.Compose([
    torchvision.transforms.RandomResizedCrop(224, scale=(1, 1), ratio=(1, 1)),  #scale裁剪,ratio宽高比,100宽高缩放到相应像素,
    torchvision.transforms.ToTensor(),
    #torchvision.transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
])
dataset = ImageFolder("E:/pycharmproject/dataset_read/veg200_images/", transform=data_transform) #主要有两个参数,一个是图像根目录(被映射成标签的子目录的上一级),一个是数据操作
train_loader = torch.utils.data.DataLoader(dataset=dataset, batch_size=10, shuffle=True)
print(dataset[0][0].size()) #第一张图片的图片矩阵
print(dataset[0][1]) #第一张图片的标签
print(dataset.class_to_idx) #查看子文件夹与标签的映射,注意:不是顺序映射

#画图
#batch 要求是这些类型:tensors, numpy arrays, numbers, dicts or lists
#stack要求每个tensor大小相等,也就是images要大小相同

#print(train_loader)
#print(len(train_loader))
#print(type(train_loader))
'''
print(train_loader.shape)
print(train_loader.size())
DataLoader对象没有以上两个数据
'''

for i, (images, labels) in enumerate(train_loader):
    images = np.array(images)
    images = torch.from_numpy(images)
    for j in range(len(images)):
        image = images[j]
        #print(image)
        #print(image.size()) #torch.Size([3, 100, 100])
        #matplotlib.pyplot.imshow()需要数据是二维的数组或者第三维深度是3或4的三维数组,当第三维深度为1时,使用np.squeeze()压缩数据成为二维数组
        #https://jianzhuwang.blog.csdn.net/article/details/103723536
        plt.imshow((image).numpy().transpose(1, 2, 0))  # 显示图片
        plt.axis('off')  # 不显示坐标轴
        plt.title("$The label of the picture is {} $".format(labels[j]))
        plt.show()

2 运行结构

总体结构
数据集在veg200_images中,包含200个子文件夹;代码在method_3中
在这里插入图片描述
veg200_images数据集结构(部分)
在这里插入图片描述

3 结果

以下代码运行结果(部分)
在这里插入图片描述
在这里插入图片描述
读取图片结果

在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述


版权声明:本文为wqy1837154675原创文章,遵循CC 4.0 BY-SA版权协议,转载请附上原文出处链接和本声明。