GraphSAGE
Contents
pyg 中的实现
-
message propagation 的公式:
-
实现
from typing import Union, Tuple from torch_geometric.typing import OptPairTensor, Adj, Size from torch import Tensor import torch from torch.nn import Linear import torch.nn.functional as F from torch_sparse import SparseTensor, matmul from torch_geometric.nn.conv import MessagePassing class SAGEConv(MessagePassing): def __init__(self, in_channels: Union[int, Tuple[int, int]], out_channels: int, normalize: bool = False, root_weight: bool = True, bias: bool = True, **kwargs): # yapf: disable kwargs.setdefault('aggr', 'mean') super(SAGEConv, self).__init__(**kwargs) self.in_channels = in_channels self.out_channels = out_channels self.normalize = normalize self.root_weight = root_weight if isinstance(in_channels, int): in_channels = (in_channels, in_channels) self.lin_l = Linear(in_channels[0], out_channels, bias=bias) if self.root_weight: self.lin_r = Linear(in_channels[1], out_channels, bias=False) self.reset_parameters() def reset_parameters(self): self.lin_l.reset_parameters() if self.root_weight: self.lin_r.reset_parameters() def forward(self, x: Union[Tensor, OptPairTensor], edge_index: Adj, size: Size = None) -> Tensor: if isinstance(x, Tensor): x: OptPairTensor = (x, x) # x[0] 用于 message propagation 计算 message, x[1] 用于保留结点的输入特征 # propagate_type: (x: OptPairTensor) out = self.propagate(edge_index, x=x, size=size) print(f'out after propagate->:\n {out}') out = self.lin_l(out) x_r = x[1] if self.root_weight and x_r is not None: out += self.lin_r(x_r) if self.normalize: out = F.normalize(out, p=2., dim=-1) return out def message(self, x_j: Tensor) -> Tensor: print(f'x_j->\n{x_j}') return x_j def message_and_aggregate(self, adj_t: SparseTensor, x: OptPairTensor) -> Tensor: adj_t = adj_t.set_value(None, layout=None) # print(f'adj_t: type->\n{type(adj_t)}, values->\n{ajd_t}') return matmul(adj_t, x[0], reduce=self.aggr) def __repr__(self): return '{}({}, {})'.format(self.__class__.__name__, self.in_channels, self.out_channels) x = Tensor([[1., 2., 3.], [4., 5., 6.], [7., 8., 9.]]) # x: (3, 3) edge_index = Tensor([[0, 0, 1, 2], [1, 2, 0, 0]]).long() # (2, 4) sage_nn = SAGEConv(3, 3) out = sage_nn(x=x, edge_index=edge_index) print(f'out->:\n{out}')
- 实现里面有一些注意的地方:
-
message_and_aggregate
方法将message
和aggragation
放到一个函数中来执行,它的目的是为了节省时间和内存,但只有这个函数在子类中被实现,并且输入的 tensor x 是一个 SparseTensor 类型的时候,它才会被调用,目前不用管它 -
当把 x 变成一个 OptPairTensor 的时候 (即 x 扩展成元组 (x, x)), 在 propagate 时仍可以直接将 x 传进去,得到的 message 和传入 x: Tensor 的结果是一样的。这里之所以要将 x 从 Tensor 转换成 OptPairTensor, 是因为每次结点在计算更新的结点表示时,会将当前的结点特征和传入的 message 一并进行操作。
-
- 实现里面有一些注意的地方:
GraphSAGE 中的 mini-batch 采样
首先是前项传播的算法流程,聚合这里没有什么特别之处,和一般的 message passing 过程是一致的,主要理解这里的 K. K 在这里可以被认为是网络的 layer 数目,在每个 layer, aggragator function 的参数不一样,K 次迭代实际上是将结点的 K 阶邻居的信息聚合到一起。这是 graphsage mini-batch 采样的流程。1-6 行是得到所有采样的结点。注意这里 k 的循环方向和算法 1 的循环方向是相反的,原因是,假如我们想得到某个 batch B 的结点的表征,那么他们对应到算法 1, 应该是在最后第 K 次迭代输出得到的,所以这里需要反过来采样结点。那么对于每个 iteration k, 可以得到在这个 iteration 的 batch k. 采样完后,对应到 9-16 行,是算法 1 的过程,利用在每个 iteration 都构造好的 batch k, 来进行邻居结点的聚合。N_k(u) 是一个 random 的邻居结点采样函数。
Author Li Xunsong
LastMod 2021-08-27