数据集读取-Dataset,DataLoader

#%%

import torch
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
import os
import PIL.Image as Image
import matplotlib.pyplot as plt
from torchvision import transforms

#%% md
'''
数据量小
数据量小的时候,没有大问题,直接加载到内存。比如我们利用一些数据做的线性回归

数据量大
数据量大的时候,将所有的数据读取到内存中训练就会内存不够。而大数据量是非常常见的现象。

思路:Dataloader的处理逻辑是先通过Dataset类里面的 __getitem__ 函数获取单个的数据,然后组合成batch,
再使用collate_fn所指定的函数对这个batch做一些操作,比如padding啊之类的。如果没有使用collate_fn,默认就是基本的操作。
实现:
Dataset类:Pytorch读取图片,主要是通过Dataset类,Dataset类作为所有的datasets的基类存在,所有的datasets都需要继承它,类似于C++中的虚基类
1、需要继承Dataset类
2、需要实现__getitem__(self, index),及__len__(self)方法。__getitem__方法输入一个index(通常指图片数据的路径和标签信息),输出图片数据和标签;
__len__方法返回数据集的大小
说明:
1、Dataset类及其子类是迭代器
2、DataLoader类是迭代器
'''

#%% 肝脏数据类
def make_dataset(root):
    imgsPath=[]
    n=len(os.listdir(root))//2
    for i in range(n):
        img=os.path.join(root, "%03d.png"%i)
        mask=os.path.join(root,"%03d_mask.png"%i)
        imgsPath.append((img,mask))
    return imgsPath

class LiverDataset(Dataset):
    # 初始化的目的是获取所有的图像或者所有图像的索引,方便__getitem__读取
    def __init__(self,path):
        imgsPath = make_dataset(path)                   #"data/train",获取路径下的所有文件路径
        self.imgsPath = imgsPath                        #存放着图片路径
    
    def __getitem__(self, index):
        x_path, y_path = self.imgsPath[index]           #输入图像路径,标签图像路径
        img_x = Image.open(x_path)
        img_y = Image.open(y_path)
        trans_x = transforms.ToTensor()
        trans_y = transforms.ToTensor()
        return trans_x(img_x), trans_y(img_y)           #输入图像,标签图像(转换为tensor)
    
    def __len__(self):
        return len(self.imgsPath)                       #数据集的组数(1输入图像+1标签图像=1组)


#输入,1)训练集路径 2)输入图片处理方式  3)标签处理方式
batch_size = 2
path = r'G:\code\python\deeplearning\project-U-Net\data\train'
liver_dataset = LiverDataset(path)

dataloaders   = DataLoader(liver_dataset, batch_size=batch_size, shuffle=True, num_workers=0)
#%% LiverDataset是一个迭代器
a = next(iter(liver_dataset))
t = transforms.ToPILImage()
plt.imshow(t(a[0]))
plt.show()
plt.imshow(t(a[1]))
plt.show()

#%% DataLoader是一个迭代器
x = dataloaders.__iter__().__next__()
t = transforms.ToPILImage()
plt.imshow(t(x[0][0,:,:,:]))
plt.show()
plt.imshow(t(x[1][0,:,:,:]))
plt.show()
# print(dataloaders.__len__())
# for x,y in dataloaders:
#     x
#     y

i=0,for x,y in dataloaders: 打散dataset数据索引,遍历整个dataset数据

i=1,for x,y in dataloaders:重新打散dataset数据索引,继续遍历dataset数据

for i in range(2):
    for x,y in dataloaders:
        print((x))
        print((y))

情形2:DataLoader会依次读取 顺序/打乱的迭代器,当数据读取完再进行读取时候,raise StopIteration异常 

from torch.utils.data.dataloader import DataLoader

loader = DataLoader(dataset=range(100),batch_size=1,shuffle=True)
data = iter(loader)

for i in range(101):
    print(i,next(data))

StopIteration

可能会有疑问,为什么一般代码中写for step in training_data_loader: 执行,再下一个epoch的时候,却没有给出StopIteration异常!!!注意,通过查阅相关资料可以知道,在for循环(当前epoch)的过程中,都使用的同一个迭代器,使用next方法获取数据,当数据遍历一遍,迭代器销毁。下一个epoch时,使用的是另外一个迭代器,因此不会发生停止迭代异常。

情形3:batchSize >1 时,队列中最后一个数据不满足batchSize大小,输出个数则小于batchSize

from torch.utils.data.dataloader import DataLoader

loader = DataLoader(dataset=range(100),batch_size=3,shuffle=False)
data = iter(loader)

for i in range(101):
    print(i,next(data))

输出:

......

30 tensor([90, 91, 92])
31 tensor([93, 94, 95])
32 tensor([96, 97, 98])
33 tensor([99])

文件夹中列出所有文件,并添加到自定义dataset类中

class tublinDataset_selfDefineTest(Dataset):
    def __init__(self,dirName):
        super(tublinDataset_selfDefineTest, self).__init__()
        if not os.path.isdir(dirName):
            raise ValueError('input file_path is not a dir')
        self.dirName = dirName
        # 获取路径下所有的图片名称,必须保证路径内没有图片以外的数据
        self.image_list = os.listdir(self.dirName)

版权声明:本文为xingghaoyuxitong原创文章,遵循CC 4.0 BY-SA版权协议,转载请附上原文出处链接和本声明。