PyG 中 SAGEConv 的消息传播（message propagation）公式为 $\mathbf{x}'_i = \mathbf{W}_1 \mathbf{x}_i + \mathbf{W}_2 \cdot \operatorname{mean}_{j \in \mathcal{N}(i)} \mathbf{x}_j$，其实现如下： from typing import Union, Tuple from torch_geometric.typing import OptPairTensor, Adj, Size from torch import Tensor import torch from torch.nn import Linear import torch.nn.functional as F from torch_sparse import SparseTensor, matmul from torch_geometric.nn.conv import MessagePassing class SAGEConv(MessagePassing): def __init__(self, in_channels: Union[int, Tuple[int, int]], out_channels: int, normalize: bool = False, root_weight: bool = True, bias: bool = True, **kwargs):