pytorch geometric
Contents
Intro
pytorch geometric 是基于 pytorch 的一个 GNN 框架库
安装
pip install torch-scatter -f https://pytorch-geometric.com/whl/torch-1.7.0+${CUDA}.html
pip install torch-sparse -f https://pytorch-geometric.com/whl/torch-1.7.0+${CUDA}.html
pip install torch-cluster -f https://pytorch-geometric.com/whl/torch-1.7.0+${CUDA}.html
pip install torch-spline-conv -f https://pytorch-geometric.com/whl/torch-1.7.0+${CUDA}.html
pip install torch-geometric
这里 ${CUDA}
根据对应的 cuda 版本选择:cpu
, cu92
, cu101
, cu102
, cu110
实验室 aimax 上的 cuda 版本是 cu101, 所以把这几行弄成一个脚本 install-pyg.sh
pip install torch-scatter -f https://pytorch-geometric.com/whl/torch-1.7.0+cu101.html
pip install torch-sparse -f https://pytorch-geometric.com/whl/torch-1.7.0+cu101.html
pip install torch-cluster -f https://pytorch-geometric.com/whl/torch-1.7.0+cu101.html
pip install torch-spline-conv -f https://pytorch-geometric.com/whl/torch-1.7.0+cu101.html
pip install torch-geometric
定义一个图数据
pyg 提供了一个类 torch_geometric.data.Data
来定义一个图,它包含了一个图的一些基本属性:
from torch_geometric.data import Data
import torch
edge_index = torch.tensor([[0, 0, 0, 1, 2, 3],
[1, 2, 3, 0, 0, 0]], dtype=torch.long) # edge_index: (2, num_edges)
x = torch.randn((4, 256), dtype=torch.float) # x: (num_nodes, num_node_features)
data = Data(x=x, edge_index=edge_index)
print('num node features:', data.num_node_features)
print('num edges:', data.num_edges)
print('num nodes:', data.num_nodes)
print('num features:', data.num_features)
print('num edge features', data.num_edge_features)
这里我们给 Data 类传入了 x 和 edge_index, 即所有的 node 和所有的 edge.
Basic Message Passing Network
MessagePassing 类是所有图神经网络 layer 的 Base class. 各种不同的 message passing 方式都是从继承 MessagePassing 这个 base class 而来,然后再定制自己的 message()
, aggregate()
, update()
函数。在 forward 函数里调用 propagate()
函数,就会依次调用 message()
, aggregate()
, update()
这三个函数,完成 message passing 的过程。可以参考 pyg torch.geometric.nn
中各种不同 graph layer 的实现,他们对应着不同的 message function.
propagate()
函数必须传入的是 edge_index, 其余的参数根据 message() 和 update() 的需要来传递。
propagate(edge_index: Union[torch.Tensor, torch_sparse.tensor.SparseTensor], size: Optional[Tuple[int, int]] = None, **kwargs)
:
在 **kwargs 里面,可以传递任意的参数,这些参数会被 message, update 这些函数用到。例如下面这个代码,我在 propagate 中额外添加了 lxs
这个参数,它可以在后面的 message, update 函数中被用到。
import torch.nn as nn
from torch_geometric.nn import MessagePassing
import numpy as np
from torch_geometric.utils import to_undirected
import torch
class MyMessagePassingLayer(MessagePassing):
def __init__(self):
super(MyMessagePassingLayer, self).__init__(aggr='add')
def forward(self, x, edge_index):
return self.propagate(edge_index, x=x, lxs=torch.randn((5, 4)))
def message(self, x_i, x_j):
print('x_i', x_i)
print('x_j', x_j)
return x_j
def update(self, inputs, x_i):
return inputs
x = torch.randn((8, 4))
net = MyMessagePassingLayer()
edge_index = torch.tensor([[1, 2, 3], [2, 4, 5]])
out = net(x, edge_index)
message()
可以接收 propagate()
中的任意参数,并且如果这些参数在变量名中加上 ‘_i’, ‘_j’, 那么这个参数将会被映射成 source_node 和 target_node 的形式。假如输入是表示结点的 x (N, d), 还传入 edge_index (2, m), 那么在 message 中, x_i 和 x_j 的形状将变成 (m, d). 这个映射会根据 edge_index 去进行索引,x_i 是 edge_index 中的 edge_index[1] 对应的结点下标值,x_j 是 edge_index[0] 对应的结点下标值。
import torch.nn as nn
from torch_geometric.nn import MessagePassing
import numpy as np
from torch_geometric.utils import to_undirected
import torch
class MyMessagePassingLayer(MessagePassing):
def __init__(self):
super(MyMessagePassingLayer, self).__init__(aggr='add')
def forward(self, x, edge_index):
print('x before message:\n', x)
print('edge index:\n', edge_index)
return self.propagate(edge_index, x=x, lxs=torch.randn((5, 4)))
def message(self, x_i, x_j):
print('x after enter message:\n', x)
print('x_i:\n', x_i)
print('x_j:\n', x_j)
return x_j
def update(self, inputs, x_i):
print('x_i in update:\n', x_i)
return inputs
x = torch.randn((4, 4))
net = MyMessagePassingLayer()
edge_index = torch.tensor([[0, 1, 2], [1, 2, 3]])
out = net(x, edge_index)
从这个结果还可以看到, x_i, x_j 在 update() 函数里面仍然可以使用。
update()
函数的原型是: update(inputs: torch.Tensor) → torch.Tensor
, 它接受的第一个参数 inputs
是每个结点聚合后的 message, 其大小为 (N, dim)
, N 是结点数目。另外还可以接受传给 propagate()
的任意参数,例如我们这里的 edge_index
, lxs
.
pyg 官方库中的 Graph convolutional layer
学习 pyg 官方库中对各种 graph layer 的实现
GraphSAGE
TODO 对于 bipartite graph 如何使用
source 结点和 target 结点的参数要不一样 pyg 中不同 GNN layer 对具有不同属性的图的支持可以在这里看到:
Ref
Author Li Xunsong
LastMod 2021-08-27