← 返回首页
🤖

图神经网络(GNN)入门:处理非欧几里得数据

📂 ai ⏱ 2 min 377 words

图神经网络(GNN)入门:处理非欧几里得数据

什么是图神经网络?

传统的神经网络(CNN、RNN)处理的是规则的网格数据(图像、序列),但现实世界中很多数据是图结构的:

图神经网络(Graph Neural Network,GNN)专门用于处理图结构数据。

图的基本概念

图的表示

一个图G = (V, E)由节点集V和边集E组成:

import torch
import torch.nn as nn

# 示例:社交网络
# 节点:用户
# 边:好友关系
# 特征:用户的属性(年龄、兴趣等)

# 邻接矩阵
adj_matrix = torch.tensor([
    [0, 1, 1, 0],  # 用户0与用户1、2相连
    [1, 0, 0, 1],  # 用户1与用户0、3相连
    [1, 0, 0, 1],  # 用户2与用户0、3相连
    [0, 1, 1, 0],  # 用户3与用户1、2相连
], dtype=torch.float)

# 节点特征
node_features = torch.tensor([
    [0.1, 0.2],  # 用户0的特征
    [0.3, 0.4],  # 用户1的特征
    [0.5, 0.6],  # 用户2的特征
    [0.7, 0.8],  # 用户3的特征
])

图卷积网络(GCN)

GCN是最早也是最经典的图神经网络,通过聚合邻居信息来更新节点表示。

核心思想

每个节点的新表示 = 自己的特征 + 邻居特征的加权平均

数学公式

H^(l+1) = σ(D^(-1/2) A D^(-1/2) H^(l) W^(l))

其中:

实现

class GCNLayer(nn.Module):
    def __init__(self, in_features, out_features):
        super().__init__()
        self.weight = nn.Parameter(torch.FloatTensor(in_features, out_features))
        self.bias = nn.Parameter(torch.FloatTensor(out_features))
        nn.init.xavier_uniform_(self.weight)
        nn.init.zeros_(self.bias)
    
    def forward(self, X, A):
        # X: 节点特征 (num_nodes, in_features)
        # A: 邻接矩阵 (num_nodes, num_nodes)
        
        # 计算度矩阵
        D = torch.diag(A.sum(dim=1))
        D_inv_sqrt = torch.diag(torch.pow(D.diag(), -0.5))
        
        # 归一化邻接矩阵
        A_norm = D_inv_sqrt @ A @ D_inv_sqrt
        
        # 图卷积
        support = X @ self.weight
        output = A_norm @ support + self.bias
        
        return torch.relu(output)

图注意力网络(GAT)

GAT在GCN的基础上引入注意力机制,让每个节点自适应地聚合不同邻居的信息。

class GATLayer(nn.Module):
    def __init__(self, in_features, out_features, num_heads):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = out_features // num_heads
        
        self.W = nn.Linear(in_features, out_features)
        self.a = nn.Parameter(torch.zeros(num_heads, 2 * self.head_dim))
        nn.init.xavier_uniform_(self.a.unsqueeze(0))
        
        self.leaky_relu = nn.LeakyReLU(0.2)
    
    def forward(self, X, A):
        batch_size, num_nodes, _ = X.shape
        
        # 线性变换
        Wh = self.W(X)  # (batch, num_nodes, out_features)
        Wh = Wh.view(batch_size, num_nodes, self.num_heads, self.head_dim)
        
        # 计算注意力分数
        Wh_i = Wh.unsqueeze(2).expand(-1, -1, num_nodes, -1, -1)
        Wh_j = Wh.unsqueeze(1).expand(-1, num_nodes, -1, -1, -1)
        
        e = self.leaky_relu(torch.cat([Wh_i, Wh_j], dim=-1) @ self.a)
        
        # 使用邻接矩阵mask
        mask = A.unsqueeze(-1).unsqueeze(-1)
        e = e.masked_fill(mask == 0, -1e9)
        
        # 注意力权重
        attention = torch.softmax(e, dim=2)
        
        # 聚合
        output = torch.einsum('bnhm,bnmh->bnh', attention, Wh)
        
        return output.reshape(batch_size, num_nodes, -1)

GraphSAGE

GraphSAGE通过采样和聚合邻居来处理大规模图,支持归纳学习。

class GraphSAGE(nn.Module):
    def __init__(self, in_features, hidden_features, out_features, num_layers):
        super().__init__()
        self.layers = nn.ModuleList()
        
        # 输入层
        self.layers.append(nn.Linear(in_features, hidden_features))
        
        # 隐藏层
        for _ in range(num_layers - 2):
            self.layers.append(nn.Linear(hidden_features * 2, hidden_features))
        
        # 输出层
        self.layers.append(nn.Linear(hidden_features * 2, out_features))
    
    def aggregate(self, neighbor_features):
        # 聚合邻居特征(这里用平均)
        return neighbor_features.mean(dim=1)
    
    def forward(self, X, adjacency_list):
        # X: 节点特征
        # adjacency_list: 每个节点的邻居列表
        
        for i, layer in enumerate(self.layers):
            new_X = []
            for node_idx in range(X.size(0)):
                # 获取邻居特征
                neighbors = adjacency_list[node_idx]
                neighbor_features = X[neighbors]
                
                # 聚合
                aggregated = self.aggregate(neighbor_features)
                
                # 拼接自身特征和聚合特征
                node_features = torch.cat([X[node_idx], aggregated])
                new_X.append(node_features)
            
            X = layer(torch.stack(new_X))
            if i < len(self.layers) - 1:
                X = torch.relu(X)
        
        return X

GNN的应用

1. 社交网络分析

2. 推荐系统

3. 分子性质预测

4. 交通预测

5. 知识图谱

GNN的挑战

  1. 过平滑:多层GNN可能导致所有节点表示趋同
  2. 过拟合:图结构可能包含噪声
  3. 可扩展性:大规模图的训练效率
  4. 动态图:处理时变图结构

总结

图神经网络是处理图结构数据的强大工具,通过聚合邻居信息来学习节点表示。从GCN到GAT再到GraphSAGE,GNN不断演进,在社交网络、推荐系统、分子设计等领域展现出巨大潜力。