Loading and Processing OGB Datasets (PyG-based)

GNN-Dataset

Loading and using typical graph datasets (based on PyG).

  • OGB datasets
    • ogbn-arxiv
    • ogbn-products

My code: https://github.com/ytchx1999/GNN-Dataset/blob/main/OGBn.ipynb

from ogb.nodeproppred import PygNodePropPredDataset, Evaluator
import torch_geometric.transforms as T

ogbn-arxiv

1. Loading the dataset

The first run downloads the dataset, which can be slow; depending on your network, you may need a proxy or VPN.

By default, the graph structure is stored as an edge list, edge_index.

dataset = PygNodePropPredDataset(name='ogbn-arxiv', root='./arxiv/')
print(dataset)
# PygNodePropPredDataset()
data = dataset[0]
print(data)
# Data(edge_index=[2, 1166243], node_year=[169343, 1], x=[169343, 128], y=[169343, 1])

You can also pass the transform argument to convert the graph into a sparse adjacency matrix, adj_t.

Note: with this transform, every later use of edge_index must be replaced with adj_t.

dataset = PygNodePropPredDataset(name='ogbn-arxiv', root='./arxiv/', transform=T.ToSparseTensor())
print(dataset)
# PygNodePropPredDataset()
data = dataset[0]
print(data)
# Data(adj_t=[169343, 169343, nnz=1166243], node_year=[169343, 1], x=[169343, 128], y=[169343, 1])
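To make the relationship between the two layouts concrete, here is a minimal plain-Python sketch (no PyG required) that turns an edge list into a transposed adjacency structure. As far as I understand it, T.ToSparseTensor() stores the transpose, so row i of adj_t lists the source nodes of the edges pointing into node i. The helper name edge_index_to_adj_t and the toy graph are made up for illustration, and the real SparseTensor uses a compressed layout rather than Python lists.

```python
def edge_index_to_adj_t(edge_index, num_nodes):
    """Hypothetical helper: group source nodes by destination node.

    edge_index is [sources, destinations], mirroring PyG's [2, num_edges] layout.
    """
    sources, destinations = edge_index
    adj_t = [[] for _ in range(num_nodes)]
    for src, dst in zip(sources, destinations):
        adj_t[dst].append(src)  # transposed: rows are indexed by destination
    for row in adj_t:
        row.sort()
    return adj_t


# Toy directed graph with 4 nodes and edges 0->1, 0->2, 1->2, 3->0.
edge_index = [[0, 0, 1, 3], [1, 2, 2, 0]]
adj_t = edge_index_to_adj_t(edge_index, num_nodes=4)
nnz = sum(len(row) for row in adj_t)

print(adj_t)  # [[3], [0], [0, 1], []]
print(nnz)    # 4, the same as the number of edges
```

The nnz value printed for adj_t above plays the same role: it equals the number of columns in edge_index (1166243 for ogbn-arxiv).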

Exploring the dataset's attributes works the same way as for Cora, so it is not repeated here.

2. Splitting the dataset and defining the evaluator

split_idx = dataset.get_idx_split()
evaluator = Evaluator(name='ogbn-arxiv')
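The evaluator only comes into play after training, so here is a reminder of what it measures. For the ogbn node-classification datasets used here, evaluator.eval takes a dict with 'y_true' and 'y_pred' (both shaped [num_nodes, 1]) and returns a dict with an 'acc' entry. The sketch below is a plain-Python stand-in for that computation, not OGB's code: the function accuracy_like_evaluator and the toy labels are hypothetical.

```python
def accuracy_like_evaluator(input_dict):
    """Hypothetical stand-in: fraction of nodes whose predicted label matches."""
    y_true = input_dict['y_true']
    y_pred = input_dict['y_pred']
    if len(y_true) != len(y_pred):
        raise ValueError('y_true and y_pred must have the same number of rows')
    # Each row holds a single label, mirroring the [num_nodes, 1] shape.
    correct = sum(1 for t, p in zip(y_true, y_pred) if t[0] == p[0])
    return {'acc': correct / len(y_true)}


y_true = [[3], [1], [0], [2]]
y_pred = [[3], [1], [2], [2]]  # e.g. the argmax over class scores, kept as a column
result = accuracy_like_evaluator({'y_true': y_true, 'y_pred': y_pred})
print(result)  # {'acc': 0.75}
```

With the real evaluator, the usual pattern is to index both tensors with the split of interest (for example the validation indices) before passing them in.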

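train_idx holds the indices of the training nodes. It can be used directly to select the training subset and the model outputs for those nodes, so it plays the same role as train_mask.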

train_idx = split_idx['train']
val_idx = split_idx['valid']
test_idx = split_idx['test']
print(train_idx.shape)
print(val_idx.shape)
print(test_idx.shape)
# torch.Size([90941])
# torch.Size([29799])
# torch.Size([48603])
y = data.y.squeeze(1)[train_idx]
print(y.shape)
# torch.Size([90941])
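The equivalence between an index list and a boolean mask can be sketched in plain Python. The labels and indices below are toy values chosen for illustration; with real tensors the same idea applies to data.y and split_idx['train'].

```python
num_nodes = 6
labels = [2, 0, 1, 1, 0, 2]
train_idx = [0, 3, 4]  # toy stand-in for split_idx['train']

# Index-based selection: pick the labels at the listed positions.
y_by_idx = [labels[i] for i in train_idx]

# Equivalent boolean mask of length num_nodes, in the style of train_mask.
train_mask = [False] * num_nodes
for i in train_idx:
    train_mask[i] = True
y_by_mask = [y for y, keep in zip(labels, train_mask) if keep]

print(y_by_idx)   # [2, 1, 0]
print(y_by_mask)  # [2, 1, 0]
```

The two selections agree element by element only when the indices are sorted in ascending order, because a mask always returns nodes in node-id order.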

You can also sample neighbors to train in mini-batches.

# From PyG 2.0 onward, NeighborSampler is exported from torch_geometric.loader; adjust the import if needed.
from torch_geometric.data import NeighborSampler
train_loader = NeighborSampler(edge_index=data.adj_t, node_idx=train_idx,
                               sizes=[15, 10, 5], batch_size=1024, shuffle=True, num_workers=12)
print(train_loader)
# NeighborSampler(sizes=[15, 10, 5])
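As a rough picture of what sizes=[15, 10, 5] asks for, the following plain-Python sketch samples at most sizes[l] neighbors per node at hop l, starting from a batch of seed nodes. It is a conceptual illustration only: the function sample_hops and the toy graph are assumptions, and PyG's NeighborSampler is implemented differently and returns bipartite adjacency blocks rather than plain node lists.

```python
import random


def sample_hops(neighbors, seeds, sizes, rng):
    """Hypothetical helper: return the set of required nodes after each hop."""
    frontiers = [sorted(set(seeds))]
    current = set(seeds)
    for size in sizes:
        expanded = set(current)  # nodes already needed stay in the set
        for node in current:
            nbrs = neighbors.get(node, [])
            k = min(size, len(nbrs))  # take all neighbors if fewer than size
            expanded.update(rng.sample(nbrs, k))  # sample without replacement
        frontiers.append(sorted(expanded))
        current = expanded
    return frontiers


# Toy undirected graph stored as adjacency lists.
neighbors = {0: [1, 2, 3], 1: [0, 4], 2: [0], 3: [0, 4, 5], 4: [1, 3], 5: [3]}
rng = random.Random(0)  # fixed seed so repeated runs sample the same nodes
frontiers = sample_hops(neighbors, seeds=[0], sizes=[2, 1], rng=rng)
for hop, nodes in enumerate(frontiers):
    print(hop, nodes)
```

The node set can only grow from hop to hop, which is why small per-hop sizes matter on large graphs: they cap how many nodes each mini-batch has to load.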

ogbn-products

The steps are the same as for ogbn-arxiv.

dataset = PygNodePropPredDataset(name='ogbn-products', root='./products/', transform=T.ToSparseTensor())
print(dataset)
# PygNodePropPredDataset()
data = dataset[0]
print(data)
# Data(adj_t=[2449029, 2449029, nnz=123718280], x=[2449029, 100], y=[2449029, 1])
split_idx = dataset.get_idx_split()
evaluator = Evaluator(name='ogbn-products')
train_idx = split_idx['train']
val_idx = split_idx['valid']
test_idx = split_idx['test']
print(train_idx.shape)
print(val_idx.shape)
print(test_idx.shape)
# torch.Size([196615])
# torch.Size([39323])
# torch.Size([2213091])
from torch_geometric.data import NeighborSampler
train_loader = NeighborSampler(edge_index=data.adj_t, node_idx=train_idx,
                               sizes=[15, 10, 5], batch_size=1024, shuffle=True, num_workers=12)
print(train_loader)
# NeighborSampler(sizes=[15, 10, 5])
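The split sizes printed in this post can be turned into split fractions and a mini-batch count per epoch with a little arithmetic. The sketch below uses only the numbers shown above and the batch_size=1024 passed to the sampler; the dictionary layout and variable names are my own.

```python
import math

# Split sizes and node counts copied from the printed shapes in this post.
splits = {
    'ogbn-arxiv': {'train': 90941, 'valid': 29799, 'test': 48603, 'nodes': 169343},
    'ogbn-products': {'train': 196615, 'valid': 39323, 'test': 2213091, 'nodes': 2449029},
}
batch_size = 1024

summary = {}
for name, s in splits.items():
    # The three splits should cover every node exactly once.
    assert s['train'] + s['valid'] + s['test'] == s['nodes']
    fractions = {k: s[k] / s['nodes'] for k in ('train', 'valid', 'test')}
    # The last batch may be smaller, hence the ceiling.
    batches_per_epoch = math.ceil(s['train'] / batch_size)
    summary[name] = (fractions, batches_per_epoch)
    print(name, {k: round(v, 4) for k, v in fractions.items()}, batches_per_epoch)
```

This works out to 89 training batches per epoch for ogbn-arxiv and 193 for ogbn-products. It also shows how different the two splits are: over half of the ogbn-arxiv nodes are in the training set, while roughly 90% of the ogbn-products nodes are held out for testing.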

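Copyright notice: This is an original article by weixin_41650348, licensed under CC 4.0 BY-SA. When reposting, please include a link to the original source and this notice.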