PyG图神经网络框架学习--示例介绍

实例介绍

通过自带的示例介绍并学习PyG。主要从以下4各方面:

图数据处理
通用基准数据集
Mini-batches
数据转换
图学习方法

图数据处理

图用于对对象(节点)之前的关系(边)进行建模。PyG中的图可以用torch_geometric.data.Data的一个实例表示,默认情况下包含以下属性:

  • data.x: 具有[num_nodes, num_node_features]形状的节点特征矩阵
  • data.edge_index: 形状为[2, num_edges],类型为Torch.long
  • data.edge_attr: 形状为[num_edges, num_edge_features]的边特征矩阵
  • data.y:标签(可能具有任意形状),例如,node-level任务形状为[num_nodes, *]或graph-level任务形状为[1, *]
  • data.pos: 节点位置矩阵,形状为[num_nodes, num_dimensions]

我们展示了一个简单的例子:一个有三个节点和四条边的无权无向图。每个节点正好包含一个特征。

import torch
from torch_geometric.data import Data

edge_index = torch.tensor([[0,1,1,2],
                          [1,0,2,1]],dtype=torch.long)
x = torch.tensor([[-1], [0], [1]], dtype=torch.float)
data = Data(x=x, edge_index = edge_index)
>>> Data(edge_index=[2, 4], x=[3, 1])

在这里插入图片描述
值得注意得是,edge_index,即定义所有源节点和目标节点的张量,不是索引对的列表。如果想写成索引对的形式,则需要在将它们传递给构造函数前对他们进行转置 .t() 并调用 .contiguous() 函数:

import torch
from torch_geometric.data import Data

edge_index = torch.tensor([[0, 1],
                          [1, 0],
                          [1, 2],
                          [2, 1]],dtype=torch.long)
x = torch.tensor([[-1], [0], [1]], dtype=torch.float)
data = Data(x=x, edge_index = edge_index.t().contiguous())
>>> Data(edge_index=[2, 4], x=[3, 1])

注意:此图虽然只有两条边,但是却需要四个索引对来说明一条边的两个方向。

除了持有一些node-level、edge-level或graph-level的属性外,Data 还提供了一些有用的实用功能,例如:

print(data.keys)
>>> ['x', 'edge_index']

print(data['x'])
>>> tensor([[-1.0],
            [0.0],
            [1.0]])

for key, item in data:
    print(f'{key} found in data')
>>> x found in data
>>> edge_index found in data

'edge_attr' in data
>>> False

data.num_nodes
>>> 3

data.num_edges
>>> 4

data.num_node_features
>>> 1

data.has_isolated_nodes()
>>> False

data.has_self_loops()
>>> False

data.is_directed()
>>> False

# 将数据从cpu转移到gpu上
device = torch.device('cuda')
data = data.to(device)

更多详情可以从 torch_geometric.data.Data 源码中查看

通用基准数据集

由于我做的是node-level,这里仅记录所有的 Planetoid datasets (Cora, Citeseer, Pubmed)。

from torch_geometric.datasets import Planetoid

dataset = Planetoid(root='./data/Cora/', name='Cora')
>>> Cora()

len(dataset)
>>> 1

dataset.num_classes
>>> 7

dataset.num_node_features
>>> 1433

注意,如果本地没有数据集,会自动下载到 ./data/Cora/raw 文件夹中。如果本地已有数据集,请将 ./data/Cora/raw/ 作为 root 传入。

Cora数据集仅包含一个无向图,代表引用关系。

data = dataset[0]
>>> Data(edge_index=[2, 10556], test_mask=[2708],
         train_mask=[2708], val_mask=[2708], x=[2708, 1433], y=[2708])

边有 10556 / 2条, 节点个数有 2708, 特征为 1433。test_mask – 测试集, train_mask – 训练集, val_mask – 验证集。 x – 节点特征, y – 节点标签。

data.is_undirected()
>>> True

data.train_mask.sum().item()
>>> 140

data.val_mask.sum().item()
>>> 500

data.test_mask.sum().item()
>>> 1000

Mini-batches和数据转换不太适用于node-level任务(我没用到,个人见解,如有错误请指出)。

图学习方法

我们开始搭建第一个图神经网络(GNN)。我们将使用一个简单的GCN层,并在Cora引文数据集上进行复制实验。

首先我们要加载数据集:

from torch_geometric.datasets import Planetoid

dataset = Planetoid(root='/data/Cora/raw', name='Cora')
>>> Cora()

我们不需要使用数据转换或dataloader。我们接下来实现一个两层的GCN:

import torch
import torch.nn.functional as F
from torch_geometric.nn import GCNConv

class GCN(torch.nn.Module):
    def __init__(self,):
        super().__init__()
        self.conv1 = GCNConv(dataset.num_node_features, 16)
        self.conv2 = GCNConv(16, dataset.num_classes)
    
    def forward(self, data):
        x, edge_index = data.x, data.edge_index
        
        x = self.conv1(x, edge_index)
        x = F.relu(x)
        x = F.dropout(x, training=self.training)
        x = self.conv2(x, edge_index)
        
        return F.log_softmax(x, dim=1)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = GCN().to(device)
data = dataset[0].to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)

model.train()
for epoch in range(200):
    optimizer.zero_grad()
    out = model(data)
    loss = F.nll_loss(out[data.train_mask], data.y[data.train_mask])
    loss.backward()
    optimizer.step()

model.eval()
pred = model(data).argmax(dim=1)
correct = (pred[data.test_mask] == data.y[data.test_mask]).sum()
acc = int(correct) / int(data.test_mask.sum())
print(f'Accuracy: {acc:.4f}')
>>> Accuracy: 0.8150

版权声明:本文为weixin_42727538原创文章,遵循CC 4.0 BY-SA版权协议,转载请附上原文出处链接和本声明。