博客地址:陈小默的CSDN
PyTorch Geometric 为数据集提供了两个抽象类torch_geometric.data.Dataset和torch_geometric.data.InMemoryDataset。
其中InMemoryDataset继承自Dataset,如果要使用InMemoryDataset则需要使数据集大小适合存放在内存中。
首先需要一个保存有数据文件的文件夹root,该文件夹将被划分为两个文件夹,一个用于存储数据集的文件夹raw_dir和一个用来保存处理后数据集的文件夹processed_dir。
除了root,类初始化的init函数还接收三个函数参数transform, pre_transform 和pre_filter,这些参数的默认值都是None。transform函数用于动态的转换数据对象。pre_transform函数在数据保存到硬盘之前进行一次转换。pre_filter用于过滤某些数据对象。
保存在内存中的数据集
为了创建InMemoryDataset,需要实现下面四个方法:
raw_file_names():该函数返回的文件名需要在raw_dir文件夹下找到才可以跳过下载过程。processed_file_names():该函数放回的文件名需要在processed_dir中找到才可以跳过处理过程。download():下载文件到raw_dirprocess():处理原始数据并保存在processed_dir
在process()函数中,我们需要读入并创建一个Data对象列表之后将所有Data类型的对象保存在processed_dir文件夹中。由于无法将全部数据保存到内存中,需要在数据固化之前通过collate()函数保存Data对象的索引,此外,该函数还会返回一个slices字典用于从本地重建单个样例对象。于是在数据集对象new的时候,需要从本地读取self.data 和 self.slices对象。
import torch
from torch_geometric.data import InMemoryDataset
class MyOwnDataset(InMemoryDataset):
def __init__(self, root, transform=None, pre_transform=None):
super(MyOwnDataset, self).__init__(root, transform, pre_transform)
self.data, self.slices = torch.load(self.processed_paths[0])
@property
def raw_file_names(self):
return ['some_file_1', 'some_file_2', ...]
@property
def processed_file_names(self):
return ['data.pt']
def download(self):
# Download to `self.raw_dir`.
def process(self):
# Read data into huge `Data` list.
data_list = [...]
if self.pre_filter is not None:
data_list = [data for data in data_list if self.pre_filter(data)]
if self.pre_transform is not None:
data_list = [self.pre_transform(data) for data in data_list]
data, slices = self.collate(data_list)
torch.save((data, slices), self.processed_paths[0])
创建更大规模的数据集
有一些数据的规模太大,无法一次性加载到内存中,那么我们需要自己实现torch_geometric.data.Dataset,只需要额外实现两个方法:
len(): 返回数据集的长度get():自定义加载Graph的方法
import os.path as osp
import torch
from torch_geometric.data import Dataset
class MyOwnDataset(Dataset):
def __init__(self, root, transform=None, pre_transform=None):
super(MyOwnDataset, self).__init__(root, transform, pre_transform)
@property
def raw_file_names(self):
return ['some_file_1', 'some_file_2', ...]
@property
def processed_file_names(self):
return ['data_1.pt', 'data_2.pt', ...]
def download(self):
# Download to `self.raw_dir`.
def process(self):
i = 0
for raw_path in self.raw_paths:
# Read data from `raw_path`.
data = Data(...)
if self.pre_filter is not None and not self.pre_filter(data):
continue
if self.pre_transform is not None:
data = self.pre_transform(data)
torch.save(data, osp.join(self.processed_dir, 'data_{}.pt'.format(i)))
i += 1
def len(self):
return len(self.processed_file_names)
def get(self, idx):
data = torch.load(osp.join(self.processed_dir, 'data_{}.pt'.format(idx)))
return data
常见问题
如何跳过
download()和process()过程?可以通过忽略(不重写)该函数的方法实现。
class MyOwnDataset(Dataset): def __init__(self, transform=None, pre_transform=None): super(MyOwnDataset, self).__init__(None, transform, pre_transform)以上这些函数是否都是必须要使用到的?
对于动态创建的数据集而言,保存到内存中是非必要的,甚至在某些特定情况下,可以直接使用list充当数据集,如下:
from torch_geometric.data import Data, DataLoader data_list = [Data(...), ..., Data(...)] loader = DataLoader(data_list, batch_size=32)