图神经网络(GNN)入门:处理非欧几里得数据
图神经网络(GNN)入门:处理非欧几里得数据
什么是图神经网络?
传统的神经网络(CNN、RNN)处理的是规则的网格数据(图像、序列),但现实世界中很多数据是图结构的:
- 社交网络
- 分子结构
- 交通网络
- 知识图谱
图神经网络(Graph Neural Network,GNN)专门用于处理图结构数据。
图的基本概念
图的表示
一个图G = (V, E)由节点集V和边集E组成:
- 节点(Node/Vertex):图中的实体
- 边(Edge/Link):节点之间的连接
- 邻接矩阵A:表示节点之间的连接关系
- 特征矩阵X:节点的特征向量
import torch
import torch.nn as nn
# 示例:社交网络
# 节点:用户
# 边:好友关系
# 特征:用户的属性(年龄、兴趣等)
# 邻接矩阵
adj_matrix = torch.tensor([
[0, 1, 1, 0], # 用户0与用户1、2相连
[1, 0, 0, 1], # 用户1与用户0、3相连
[1, 0, 0, 1], # 用户2与用户0、3相连
[0, 1, 1, 0], # 用户3与用户1、2相连
], dtype=torch.float)
# 节点特征
node_features = torch.tensor([
[0.1, 0.2], # 用户0的特征
[0.3, 0.4], # 用户1的特征
[0.5, 0.6], # 用户2的特征
[0.7, 0.8], # 用户3的特征
])
图卷积网络(GCN)
GCN是最早也是最经典的图神经网络,通过聚合邻居信息来更新节点表示。
核心思想
每个节点的新表示 = 自己的特征 + 邻居特征的加权平均
数学公式
H^(l+1) = σ(D^(-1/2) A D^(-1/2) H^(l) W^(l))
其中:
- A:邻接矩阵
- D:度矩阵(对角矩阵)
- H:节点特征
- W:可学习的权重
- σ:激活函数
实现
class GCNLayer(nn.Module):
def __init__(self, in_features, out_features):
super().__init__()
self.weight = nn.Parameter(torch.FloatTensor(in_features, out_features))
self.bias = nn.Parameter(torch.FloatTensor(out_features))
nn.init.xavier_uniform_(self.weight)
nn.init.zeros_(self.bias)
def forward(self, X, A):
# X: 节点特征 (num_nodes, in_features)
# A: 邻接矩阵 (num_nodes, num_nodes)
# 计算度矩阵
D = torch.diag(A.sum(dim=1))
D_inv_sqrt = torch.diag(torch.pow(D.diag(), -0.5))
# 归一化邻接矩阵
A_norm = D_inv_sqrt @ A @ D_inv_sqrt
# 图卷积
support = X @ self.weight
output = A_norm @ support + self.bias
return torch.relu(output)
图注意力网络(GAT)
GAT在GCN的基础上引入注意力机制,让每个节点自适应地聚合不同邻居的信息。
class GATLayer(nn.Module):
def __init__(self, in_features, out_features, num_heads):
super().__init__()
self.num_heads = num_heads
self.head_dim = out_features // num_heads
self.W = nn.Linear(in_features, out_features)
self.a = nn.Parameter(torch.zeros(num_heads, 2 * self.head_dim))
nn.init.xavier_uniform_(self.a.unsqueeze(0))
self.leaky_relu = nn.LeakyReLU(0.2)
def forward(self, X, A):
batch_size, num_nodes, _ = X.shape
# 线性变换
Wh = self.W(X) # (batch, num_nodes, out_features)
Wh = Wh.view(batch_size, num_nodes, self.num_heads, self.head_dim)
# 计算注意力分数
Wh_i = Wh.unsqueeze(2).expand(-1, -1, num_nodes, -1, -1)
Wh_j = Wh.unsqueeze(1).expand(-1, num_nodes, -1, -1, -1)
e = self.leaky_relu(torch.cat([Wh_i, Wh_j], dim=-1) @ self.a)
# 使用邻接矩阵mask
mask = A.unsqueeze(-1).unsqueeze(-1)
e = e.masked_fill(mask == 0, -1e9)
# 注意力权重
attention = torch.softmax(e, dim=2)
# 聚合
output = torch.einsum('bnhm,bnmh->bnh', attention, Wh)
return output.reshape(batch_size, num_nodes, -1)
GraphSAGE
GraphSAGE通过采样和聚合邻居来处理大规模图,支持归纳学习。
class GraphSAGE(nn.Module):
def __init__(self, in_features, hidden_features, out_features, num_layers):
super().__init__()
self.layers = nn.ModuleList()
# 输入层
self.layers.append(nn.Linear(in_features, hidden_features))
# 隐藏层
for _ in range(num_layers - 2):
self.layers.append(nn.Linear(hidden_features * 2, hidden_features))
# 输出层
self.layers.append(nn.Linear(hidden_features * 2, out_features))
def aggregate(self, neighbor_features):
# 聚合邻居特征(这里用平均)
return neighbor_features.mean(dim=1)
def forward(self, X, adjacency_list):
# X: 节点特征
# adjacency_list: 每个节点的邻居列表
for i, layer in enumerate(self.layers):
new_X = []
for node_idx in range(X.size(0)):
# 获取邻居特征
neighbors = adjacency_list[node_idx]
neighbor_features = X[neighbors]
# 聚合
aggregated = self.aggregate(neighbor_features)
# 拼接自身特征和聚合特征
node_features = torch.cat([X[node_idx], aggregated])
new_X.append(node_features)
X = layer(torch.stack(new_X))
if i < len(self.layers) - 1:
X = torch.relu(X)
return X
GNN的应用
1. 社交网络分析
- 社区检测
- 用户推荐
- 影响力预测
2. 推荐系统
- 用户-商品二部图
- 知识图谱增强推荐
3. 分子性质预测
- 药物发现
- 材料设计
4. 交通预测
- 路网流量预测
- 出行需求预测
5. 知识图谱
- 链接预测
- 实体分类
- 关系推理
GNN的挑战
- 过平滑:多层GNN可能导致所有节点表示趋同
- 过拟合:图结构可能包含噪声
- 可扩展性:大规模图的训练效率
- 动态图:处理时变图结构
总结
图神经网络是处理图结构数据的强大工具,通过聚合邻居信息来学习节点表示。从GCN到GAT再到GraphSAGE,GNN不断演进,在社交网络、推荐系统、分子设计等领域展现出巨大潜力。