Graph Neural Networks
Graphs are everywhere --- social networks, molecules, knowledge bases, recommendation systems, road networks. Standard neural networks (CNNs, transformers) assume grid or sequence structure. Graph neural networks (GNNs) operate directly on arbitrary graph topologies. This page covers the message passing framework, derives GCN from spectral graph theory, implements GraphSAGE and GAT, and builds a node classifier on the Cora citation network.
Why Graphs?
| Data Type | Structure | Examples |
|---|---|---|
| Social networks | User-user connections | Friend recommendations |
| Molecules | Atom-bond graphs | Drug discovery |
| Knowledge graphs | Entity-relation triples | Question answering |
| Citation networks | Paper-paper citations | Paper classification |
| Scene graphs | Object-relationship | Visual reasoning |
| Traffic networks | Intersection-road | Traffic prediction |
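Standard neural networks struggle with three properties of graph data:
- A variable number of neighbors per node
- No canonical ordering of nodes
- The resulting need for permutation invariance: the output must not depend on the order in which nodes are listed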
The Message Passing Framework
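Most GNNs follow the same high-level pattern. At each layer $l$, every node $v$ performs two steps, starting from $h_v^{(0)} = x_v$ (the input features):
1. Aggregate messages from neighbors: $m_v^{(l)} = \text{AGGREGATE}^{(l)}\left(\{h_u^{(l-1)} : u \in \mathcal{N}(v)\}\right)$
2. Update the node's representation: $h_v^{(l)} = \text{UPDATE}^{(l)}\left(h_v^{(l-1)}, m_v^{(l)}\right)$
Different GNN variants differ in how they implement AGGREGATE and UPDATE.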
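The AGGREGATE/UPDATE pattern can be sketched in plain Python. This is a minimal illustration, not a reference implementation: the function names, the mean aggregator, and the 50/50 blending update are all assumptions chosen for readability, and the toy graph has edges 0-1, 0-2, 1-2, 2-3.

```python
def message_passing_layer(features, neighbors, aggregate, update):
    """One round of message passing.

    features:  dict node -> feature vector from the previous layer
    neighbors: dict node -> list of neighbor nodes
    aggregate: function(list of neighbor vectors) -> vector
    update:    function(own vector, aggregated vector) -> new vector
    """
    new_features = {}
    for v, h_v in features.items():
        messages = [features[u] for u in neighbors[v]]
        m_v = aggregate(messages)            # step 1: AGGREGATE
        new_features[v] = update(h_v, m_v)   # step 2: UPDATE
    return new_features


def mean_aggregate(messages):
    # Element-wise mean; independent of neighbor order.
    # Assumes every node has at least one neighbor.
    return [sum(col) / len(messages) for col in zip(*messages)]


def blend_update(h_v, m_v):
    # Hypothetical UPDATE: equal blend of own state and aggregated message.
    return [0.5 * a + 0.5 * b for a, b in zip(h_v, m_v)]


neighbors = {0: [1, 2], 1: [0, 2], 2: [0, 1, 3], 3: [2]}
features = {0: [1.0, 0.0], 1: [0.0, 1.0], 2: [0.5, 0.5], 3: [1.0, 1.0]}

h1 = message_passing_layer(features, neighbors, mean_aggregate, blend_update)
print(h1[0])  # [0.625, 0.375]
```

Swapping `mean_aggregate` or `blend_update` for other functions gives a different GNN variant without touching the layer loop.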
Receptive Field Growth
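After $k$ layers, each node's representation depends on its $k$-hop neighborhood: every additional layer extends the receptive field by one hop.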
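As a rough illustration of how the receptive field grows with depth, the sketch below collects the nodes that can influence a target node after a given number of message passing rounds. The helper name `receptive_field` and the toy graph (edges 0-1, 0-2, 1-2, 2-3) are assumptions for the example.

```python
def receptive_field(neighbors, node, num_layers):
    """Nodes whose input features can reach `node` within `num_layers` hops."""
    reached = {node}
    frontier = {node}
    for _ in range(num_layers):
        # Expand by one hop, keeping only newly discovered nodes
        frontier = {u for v in frontier for u in neighbors[v]} - reached
        reached |= frontier
    return reached


neighbors = {0: [1, 2], 1: [0, 2], 2: [0, 1, 3], 3: [2]}
for k in range(3):
    print(k, sorted(receptive_field(neighbors, 3, k)))
# 0 [3]
# 1 [2, 3]
# 2 [0, 1, 2, 3]
```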
Graph Convolutional Network (GCN)
Spectral Derivation
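The graph Laplacian is $L = D - A$, where $A$ is the adjacency matrix and $D$ is the diagonal degree matrix. Its symmetric normalized form is $L_{\text{sym}} = I - D^{-1/2} A D^{-1/2} = U \Lambda U^\top$.
Spectral graph convolution applies a filter $g_\theta$ in the eigenbasis of the Laplacian:
$$g_\theta \star x = U \, g_\theta(\Lambda) \, U^\top x$$
where $U$ is the matrix of eigenvectors of $L_{\text{sym}}$, $\Lambda$ is the diagonal matrix of its eigenvalues, and $U^\top x$ is the graph Fourier transform of the signal $x$.
Kipf and Welling (2017) approximate the spectral filter with a first-order truncation of its Chebyshev polynomial expansion, yielding:
$$H^{(l)} = \sigma\left(\tilde{D}^{-1/2} \tilde{A} \tilde{D}^{-1/2} H^{(l-1)} W^{(l)}\right)$$
where $\tilde{A} = A + I$ is the adjacency matrix with self-loops, $\tilde{D}$ is the degree matrix of $\tilde{A}$, $W^{(l)}$ is the layer's weight matrix, $\sigma$ is a nonlinearity such as ReLU, and $H^{(0)} = X$.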
Per-Node View
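For a single node $v$ (with features as row vectors):
$$h_v^{(l)} = \sigma\left(\sum_{u \in \mathcal{N}(v) \cup \{v\}} \frac{1}{\sqrt{\tilde{d}_v \tilde{d}_u}} \, h_u^{(l-1)} W^{(l)}\right)$$
where $\tilde{d}_v$ is the degree of $v$ including its self-loop.
This is a mean-like aggregation with symmetric normalization: each neighbor's contribution is scaled by $1/\sqrt{\tilde{d}_v \tilde{d}_u}$ rather than by $1/\tilde{d}_v$.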
Worked Example — Message Passing on a 4-Node Graph
Setup: 4-node graph with edges: 0--1, 0--2, 1--2, 2--3. Each node has a 2D feature.
```
Node 0 --- Node 1
     \     /
      Node 2 --- Node 3
```
Node features:
| Node | $x_1$ | $x_2$ | Degree |
|---|---|---|---|
| 0 | 1.0 | 0.0 | 2 |
| 1 | 0.0 | 1.0 | 2 |
| 2 | 0.5 | 0.5 | 3 |
| 3 | 1.0 | 1.0 | 1 |
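Weight matrix $W$: applied after aggregation, as in $\hat{A} X W$. The steps below follow the aggregation $\hat{A} X$, which does not depend on $W$.
Step 1: Add self-loops. $\tilde{A} = A + I$, so the degrees become $\tilde{d} = (3, 3, 4, 2)$.
Step 2: Compute $\hat{A} = \tilde{D}^{-1/2} \tilde{A} \tilde{D}^{-1/2}$, whose entries are $\hat{A}_{vu} = 1/\sqrt{\tilde{d}_v \tilde{d}_u}$ for every edge or self-loop and 0 elsewhere.
Step 3: For node 3 (neighbors: 2, plus self): $(\hat{A}X)_3 = \frac{1}{\sqrt{2 \cdot 2}} x_3 + \frac{1}{\sqrt{2 \cdot 4}} x_2 = 0.5 \cdot [1.0, 1.0] + 0.354 \cdot [0.5, 0.5] \approx [0.677, 0.677]$.
Result: After one GCN layer, node 0 (originally $[1.0, 0.0]$) aggregates to $\frac{1}{3} x_0 + \frac{1}{3} x_1 + \frac{1}{\sqrt{12}} x_2 \approx [0.478, 0.478]$ before $W$ and the nonlinearity are applied, so its representation now mixes in the features of nodes 1 and 2.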
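The normalization steps of this worked example can be reproduced numerically. The sketch below is a plain-Python illustration under the setup stated above (4 nodes, edges 0-1, 0-2, 1-2, 2-3); it computes only the aggregation $\hat{A}X$, before any weight matrix or nonlinearity, and the variable names are assumptions.

```python
import math

edges = [(0, 1), (0, 2), (1, 2), (2, 3)]
X = [[1.0, 0.0], [0.0, 1.0], [0.5, 0.5], [1.0, 1.0]]
n = len(X)

# Step 1: A_tilde = A + I (undirected edges plus self-loops)
A_tilde = [[1.0 if i == j else 0.0 for j in range(n)] for i in range(n)]
for i, j in edges:
    A_tilde[i][j] = A_tilde[j][i] = 1.0

# Degrees including self-loops
deg = [sum(row) for row in A_tilde]

# Step 2: A_hat[i][j] = A_tilde[i][j] / sqrt(deg[i] * deg[j])
A_hat = [
    [A_tilde[i][j] / math.sqrt(deg[i] * deg[j]) for j in range(n)]
    for i in range(n)
]

# Step 3: aggregated features A_hat @ X
agg = [
    [sum(A_hat[i][k] * X[k][d] for k in range(n)) for d in range(2)]
    for i in range(n)
]

print(deg)  # [3.0, 3.0, 4.0, 2.0]
for i, row in enumerate(agg):
    print(i, [round(v, 3) for v in row])
# 0 [0.478, 0.478]
# 1 [0.478, 0.478]
# 2 [0.767, 0.767]
# 3 [0.677, 0.677]
```

Converting `A_hat` to a tensor gives a matrix of the shape expected by a dense GCN layer that takes a pre-normalized adjacency as input.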
GCN Implementation
```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class GCNLayer(nn.Module):
    """Single graph convolution: A_hat @ (X @ W) + b."""

    def __init__(self, in_features, out_features):
        super().__init__()
        self.weight = nn.Parameter(torch.randn(in_features, out_features) * 0.01)
        self.bias = nn.Parameter(torch.zeros(out_features))

    def forward(self, X, A_hat):
        """
        X:     (num_nodes, in_features) node features
        A_hat: (num_nodes, num_nodes) normalized adjacency (D^{-1/2} A_tilde D^{-1/2})
        """
        support = X @ self.weight   # Per-node linear transform
        output = A_hat @ support    # Neighborhood aggregation
        # Bias is added after aggregation so it is not rescaled by node degree
        return output + self.bias


class GCN(nn.Module):
    """Two-layer GCN for node classification."""

    def __init__(self, n_features, n_hidden, n_classes, dropout=0.5):
        super().__init__()
        self.gc1 = GCNLayer(n_features, n_hidden)
        self.gc2 = GCNLayer(n_hidden, n_classes)
        self.dropout = dropout

    def forward(self, X, A_hat):
        h = F.relu(self.gc1(X, A_hat))
        h = F.dropout(h, self.dropout, training=self.training)
        h = self.gc2(h, A_hat)
        return F.log_softmax(h, dim=1)  # Per-node class log-probabilities
```
GraphSAGE: Inductive Learning
GCN as originally formulated is transductive: training uses the full graph, including the features of test nodes. GraphSAGE (Hamilton et al., 2017) instead learns aggregation functions that generalize to nodes unseen during training.
Algorithm
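For each layer $l$ and each node $v$:
- Sample a fixed number of neighbors (not all), giving a subset $\mathcal{N}_s(v) \subseteq \mathcal{N}(v)$
- Aggregate neighbor features: $h_{\mathcal{N}(v)}^{(l)} = \text{AGGREGATE}^{(l)}\left(\{h_u^{(l-1)} : u \in \mathcal{N}_s(v)\}\right)$
- Concatenate and transform: $h_v^{(l)} = \sigma\left(W^{(l)} \cdot [h_v^{(l-1)} \,\|\, h_{\mathcal{N}(v)}^{(l)}]\right)$
- Normalize: $h_v^{(l)} \leftarrow h_v^{(l)} / \|h_v^{(l)}\|_2$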
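
The sampling step is what keeps the per-node cost bounded. The sketch below shows one possible fixed-size sampler; the function name `sample_neighbors` is hypothetical, and falling back to sampling with replacement for low-degree nodes is an assumption (one common convention), not something the text above prescribes.

```python
import random


def sample_neighbors(adj_lists, node, num_samples, rng):
    """Return a neighbor sample of length `num_samples` for `node`.

    Samples without replacement when the node has enough neighbors and
    with replacement otherwise. Isolated nodes yield an empty list.
    """
    neighbors = adj_lists[node]
    if not neighbors:
        return []
    if len(neighbors) >= num_samples:
        return rng.sample(neighbors, num_samples)
    return [rng.choice(neighbors) for _ in range(num_samples)]


adj_lists = {0: [1, 2], 1: [0, 2], 2: [0, 1, 3], 3: [2]}
rng = random.Random(0)  # fixed seed so the sketch is repeatable
sampled = {v: sample_neighbors(adj_lists, v, 2, rng) for v in adj_lists}
print(sampled)
```

The result is again a dict mapping each node to a list of neighbor indices, so it can stand in for a full adjacency list wherever one is expected.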
Aggregation Functions
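| Aggregator | Formula | Properties |
|---|---|---|
| Mean | $\frac{1}{\lvert \mathcal{N}(v) \rvert} \sum_{u \in \mathcal{N}(v)} h_u$ | Simple, permutation-invariant |
| Max pool | $\max_{u \in \mathcal{N}(v)} \sigma(W_{\text{pool}} h_u + b)$ (element-wise) | Captures salient features |
| LSTM | LSTM on random permutation | Expressive but order-dependent |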
```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class GraphSAGELayer(nn.Module):
    def __init__(self, in_features, out_features, aggregator='mean'):
        super().__init__()
        if aggregator not in ('mean', 'max'):
            raise ValueError(f"Unsupported aggregator: {aggregator}")
        self.aggregator = aggregator
        # Input is concatenation of self + aggregated neighbor features
        self.linear = nn.Linear(in_features * 2, out_features)

    def forward(self, X, adj_lists):
        """
        X: (num_nodes, features)
        adj_lists: dict mapping node -> list of neighbor indices
        """
        num_nodes = X.size(0)
        # Nodes without neighbors keep an all-zero neighbor vector
        neigh_feats = torch.zeros_like(X)
        for node in range(num_nodes):
            neighbors = adj_lists[node]
            if len(neighbors) > 0:
                if self.aggregator == 'mean':
                    neigh_feats[node] = X[neighbors].mean(dim=0)
                elif self.aggregator == 'max':
                    neigh_feats[node] = X[neighbors].max(dim=0)[0]
        combined = torch.cat([X, neigh_feats], dim=1)
        output = self.linear(combined)
        # L2 normalize
        output = F.normalize(output, p=2, dim=1)
        return output
```
Graph Attention Network (GAT)
GAT (Velickovic et al., 2018) learns different attention weights for different neighbors, rather than treating all neighbors equally.
Attention Mechanism
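For node $v$ and each neighbor $u \in \mathcal{N}(v)$ (the neighborhood includes $v$ itself), compute an unnormalized score, normalize the scores with a softmax over the neighborhood, and take the weighted sum:
$$e_{vu} = \text{LeakyReLU}\left(a^\top [W h_v \,\|\, W h_u]\right)$$
$$\alpha_{vu} = \frac{\exp(e_{vu})}{\sum_{k \in \mathcal{N}(v)} \exp(e_{vk})}$$
$$h_v' = \sigma\left(\sum_{u \in \mathcal{N}(v)} \alpha_{vu} W h_u\right)$$
Multi-head attention (concatenate or average $K$ independent attention heads) stabilizes learning: hidden layers typically concatenate the head outputs, while the output layer averages them.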
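To make the attention computation concrete, here is a single-head sketch in plain Python. The feature vectors and the attention vector `a` are made-up values rather than learned parameters, the features are assumed to be already multiplied by $W$, and the function names are hypothetical.

```python
import math


def leaky_relu(x, slope=0.2):
    return x if x > 0 else slope * x


def attention_weights(z, neighbors, a, node):
    """Softmax-normalized attention of `node` over its neighborhood.

    z:         dict node -> transformed features (W h) as a list
    neighbors: dict node -> list of neighbors (the node itself included)
    a:         attention vector of length 2 * feature dimension
    """
    scores = {}
    for u in neighbors[node]:
        concat = z[node] + z[u]  # [W h_v || W h_u]
        scores[u] = leaky_relu(sum(ai * ci for ai, ci in zip(a, concat)))
    max_score = max(scores.values())  # subtract max for numerical stability
    exp_scores = {u: math.exp(s - max_score) for u, s in scores.items()}
    total = sum(exp_scores.values())
    return {u: e / total for u, e in exp_scores.items()}


# Hypothetical transformed features and attention vector (not learned)
z = {0: [1.0, 0.0], 1: [0.0, 1.0], 2: [0.5, 0.5]}
neighbors = {0: [0, 1, 2]}
a = [0.0, 0.0, 1.0, -1.0]

alpha = attention_weights(z, neighbors, a, 0)
# Attention-weighted sum of neighbor features (before the nonlinearity)
h_new = [sum(alpha[u] * z[u][d] for u in alpha) for d in range(2)]
print(alpha)
print(h_new)
```

With this choice of `a`, node 0 attends most strongly to itself and least to node 1; a learned `a` would set these preferences from data.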
GCN vs GraphSAGE vs GAT
| Feature | GCN | GraphSAGE | GAT |
|---|---|---|---|
| Aggregation | Symmetric normalization | Learned (mean/max/LSTM) | Attention-weighted |
| Neighbor weighting | Fixed (degree-based) | Equal | Learned per edge |
| Inductive | No (transductive) | Yes | Yes |
| Scalability | Full graph needed | Mini-batch via sampling | Mini-batch possible |
| Expressiveness | Low | Medium | High |
PyTorch Geometric: Cora Classification
PyTorch Geometric (PyG) is a widely used library for GNNs in PyTorch.
```python
import torch
import torch.nn.functional as F
from torch_geometric.datasets import Planetoid
from torch_geometric.nn import GCNConv, GATConv

# ── Load Cora Dataset ────────────────────────────────────────────────
dataset = Planetoid(root='./data', name='Cora')
data = dataset[0]
print(f"Nodes: {data.num_nodes}")                    # 2708
print(f"Edges: {data.num_edges}")                    # 10556 (each undirected edge stored twice)
print(f"Features: {data.num_features}")              # 1433
print(f"Classes: {dataset.num_classes}")             # 7
print(f"Train nodes: {int(data.train_mask.sum())}")  # 140


# ── GCN Model ────────────────────────────────────────────────────────
class GCNModel(torch.nn.Module):
    def __init__(self, n_features, n_hidden, n_classes):
        super().__init__()
        self.conv1 = GCNConv(n_features, n_hidden)
        self.conv2 = GCNConv(n_hidden, n_classes)

    def forward(self, data):
        x, edge_index = data.x, data.edge_index
        x = F.relu(self.conv1(x, edge_index))
        x = F.dropout(x, p=0.5, training=self.training)
        x = self.conv2(x, edge_index)
        return F.log_softmax(x, dim=1)


# ── GAT Model ────────────────────────────────────────────────────────
class GATModel(torch.nn.Module):
    def __init__(self, n_features, n_hidden, n_classes, heads=8):
        super().__init__()
        # First layer concatenates the heads, so its output width is n_hidden * heads
        self.conv1 = GATConv(n_features, n_hidden, heads=heads, dropout=0.6)
        self.conv2 = GATConv(n_hidden * heads, n_classes, heads=1, dropout=0.6)

    def forward(self, data):
        x, edge_index = data.x, data.edge_index
        x = F.dropout(x, p=0.6, training=self.training)
        x = F.elu(self.conv1(x, edge_index))
        x = F.dropout(x, p=0.6, training=self.training)
        x = self.conv2(x, edge_index)
        return F.log_softmax(x, dim=1)


# ── Training ─────────────────────────────────────────────────────────
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# To train the GCN instead: GCNModel(dataset.num_features, 16, dataset.num_classes)
model = GATModel(dataset.num_features, 8, dataset.num_classes).to(device)
data = data.to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.005, weight_decay=5e-4)

for epoch in range(200):
    model.train()
    optimizer.zero_grad()
    out = model(data)
    loss = F.nll_loss(out[data.train_mask], data.y[data.train_mask])
    loss.backward()
    optimizer.step()

    if (epoch + 1) % 50 == 0:
        model.eval()
        with torch.no_grad():  # No gradients needed for evaluation
            pred = model(data).argmax(dim=1)
        test_acc = (pred[data.test_mask] == data.y[data.test_mask]).float().mean()
        print(f"Epoch {epoch+1}: Test Accuracy = {test_acc:.4f}")

# Typical test accuracy on this split: roughly 81-82% (GCN), 82-84% (GAT);
# exact numbers vary with the random seed.
```
Graph-Level Tasks
For graph classification (e.g., molecule property prediction), we need a graph-level representation.
Graph Readout
Common readout functions:
- Mean pooling: $h_G = \frac{1}{\lvert V \rvert} \sum_{v \in V} h_v$
- Sum pooling: $h_G = \sum_{v \in V} h_v$
- Hierarchical pooling: Learn to coarsen the graph (DiffPool, TopKPool)
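
The difference between mean and sum pooling shows up when graphs differ only in size. The plain-Python sketch below pools node features per graph using a batch vector that maps each node to its graph index; the function name `readout` and the toy features are assumptions for illustration.

```python
def readout(node_features, batch, mode="mean"):
    """Pool node features into one vector per graph.

    node_features: list of feature vectors, one per node
    batch:         list mapping each node to its graph index
    Assumes every graph index in range has at least one node.
    """
    num_graphs = max(batch) + 1
    dim = len(node_features[0])
    sums = [[0.0] * dim for _ in range(num_graphs)]
    counts = [0] * num_graphs
    for h, g in zip(node_features, batch):
        counts[g] += 1
        for d in range(dim):
            sums[g][d] += h[d]
    if mode == "sum":
        return sums
    return [[s / counts[g] for s in sums[g]] for g in range(num_graphs)]


# Two graphs with identical node features but different sizes (2 and 3 nodes)
node_features = [[1.0, 2.0]] * 2 + [[1.0, 2.0]] * 3
batch = [0, 0, 1, 1, 1]

print(readout(node_features, batch, "mean"))  # [[1.0, 2.0], [1.0, 2.0]]
print(readout(node_features, batch, "sum"))   # [[2.0, 4.0], [3.0, 6.0]]
```

Mean pooling maps both graphs to the same vector, while sum pooling keeps them apart, which is one reason sum readouts are often preferred when graph size carries signal.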
```python
import torch
import torch.nn.functional as F
from torch_geometric.nn import global_mean_pool, GINConv


class GraphClassifier(torch.nn.Module):
    def __init__(self, n_features, n_hidden, n_classes):
        super().__init__()
        # Each GINConv wraps a small MLP that transforms the aggregated features
        nn1 = torch.nn.Sequential(
            torch.nn.Linear(n_features, n_hidden),
            torch.nn.ReLU(),
            torch.nn.Linear(n_hidden, n_hidden),
        )
        nn2 = torch.nn.Sequential(
            torch.nn.Linear(n_hidden, n_hidden),
            torch.nn.ReLU(),
            torch.nn.Linear(n_hidden, n_hidden),
        )
        self.conv1 = GINConv(nn1)
        self.conv2 = GINConv(nn2)
        self.classifier = torch.nn.Linear(n_hidden, n_classes)

    def forward(self, data):
        x, edge_index, batch = data.x, data.edge_index, data.batch
        x = F.relu(self.conv1(x, edge_index))
        x = F.relu(self.conv2(x, edge_index))
        x = global_mean_pool(x, batch)  # Graph-level readout
        return self.classifier(x)
```
Over-Smoothing Problem
As GNN depth increases, node representations within a connected component tend to converge to the same value. After many layers of averaging neighbor features, initially distinct node features become indistinguishable.
Solutions:
- Use few layers (2-3 is often enough in practice)
- Skip connections (like ResNet)
- DropEdge: randomly remove edges during training
- PairNorm: normalize to prevent convergence
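
Over-smoothing can be observed even without learned weights. The sketch below repeatedly applies a parameter-free mean over each node and its neighbors on a 4-node toy graph and tracks how far apart the node features remain. This is a simplification (row-normalized averaging, no weight matrices or nonlinearities), and the helper names are assumptions.

```python
def mean_propagate(features, neighbors):
    """One parameter-free layer: average each node with its neighbors."""
    out = {}
    for v, h_v in features.items():
        group = [h_v] + [features[u] for u in neighbors[v]]
        out[v] = [sum(col) / len(group) for col in zip(*group)]
    return out


def spread(features):
    """Largest per-dimension gap between any two nodes."""
    return max(max(col) - min(col) for col in zip(*features.values()))


neighbors = {0: [1, 2], 1: [0, 2], 2: [0, 1, 3], 3: [2]}
h = {0: [1.0, 0.0], 1: [0.0, 1.0], 2: [0.5, 0.5], 3: [1.0, 1.0]}

spreads = [spread(h)]
for _ in range(20):
    h = mean_propagate(h, neighbors)
    spreads.append(spread(h))

# The gap between node features shrinks toward zero as depth grows
print([round(s, 4) for s in spreads[:5]], spreads[-1])
```

After 20 rounds all four nodes carry nearly the same vector, which is why a classifier on top of a very deep stack of averaging layers has little left to separate.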
Expressiveness: The WL Test
The Weisfeiler-Leman (WL) graph isomorphism test provides an upper bound on GNN expressiveness.
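Standard message-passing GNNs (GCN, GraphSAGE) are at most as powerful as the 1-WL test, so they cannot distinguish certain non-isomorphic graphs. GIN (Graph Isomorphism Network) matches 1-WL power by using an injective sum aggregation followed by an MLP:
$$h_v^{(l)} = \text{MLP}^{(l)}\left((1 + \epsilon^{(l)}) \cdot h_v^{(l-1)} + \sum_{u \in \mathcal{N}(v)} h_u^{(l-1)}\right)$$
where $\epsilon^{(l)}$ is a learnable or fixed scalar and $\mathcal{N}(v)$ is the set of neighbors of node $v$.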
What GNNs Cannot Distinguish
Two regular graphs with the same degree sequence but different structure (e.g., a 6-cycle vs two 3-cycles) look identical to 1-WL. Higher-order GNNs (k-WL) or subgraph GNNs are needed for these cases.
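The 6-cycle versus two-triangles case can be checked with a small 1-WL color refinement. The sketch below is a plain-Python illustration with hypothetical helper names; it refines all graphs jointly with a shared color palette so that the resulting color histograms are comparable, and it adds a 6-node path as a contrast case that 1-WL does separate.

```python
from collections import Counter


def wl_histograms(graphs, num_rounds=3):
    """Run 1-WL refinement jointly on several graphs.

    graphs: list of dicts, node -> list of neighbors
    Returns one Counter of final node colors per graph.
    """
    colors = [{v: 0 for v in g} for g in graphs]  # uniform initial color
    for _ in range(num_rounds):
        # Signature = (own color, sorted multiset of neighbor colors)
        sigs = [
            {v: (c[v], tuple(sorted(c[u] for u in g[v]))) for v in g}
            for g, c in zip(graphs, colors)
        ]
        # Shared palette so equal signatures get equal colors across graphs
        all_sigs = sorted({s for d in sigs for s in d.values()})
        palette = {s: i for i, s in enumerate(all_sigs)}
        colors = [{v: palette[s] for v, s in d.items()} for d in sigs]
    return [Counter(c.values()) for c in colors]


def cycle(nodes):
    """Adjacency lists for a cycle over the given node labels."""
    n = len(nodes)
    return {nodes[i]: [nodes[(i - 1) % n], nodes[(i + 1) % n]] for i in range(n)}


six_cycle = cycle([0, 1, 2, 3, 4, 5])
two_triangles = {**cycle([0, 1, 2]), **cycle([3, 4, 5])}
path = {0: [1], 1: [0, 2], 2: [1, 3], 3: [2, 4], 4: [3, 5], 5: [4]}

cycle_hist, triangles_hist, path_hist = wl_histograms(
    [six_cycle, two_triangles, path]
)
print(cycle_hist == triangles_hist)  # True: 1-WL cannot tell them apart
print(cycle_hist == path_hist)       # False: degrees already differ
```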
Link Prediction
Predict whether an edge should exist between two nodes:
```python
import torch
import torch.nn.functional as F
from torch_geometric.nn import GCNConv


class LinkPredictor(torch.nn.Module):
    def __init__(self, in_channels, hidden_channels):
        super().__init__()
        self.conv1 = GCNConv(in_channels, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, hidden_channels)

    def encode(self, x, edge_index):
        """Compute node embeddings from the observed edges."""
        x = F.relu(self.conv1(x, edge_index))
        x = self.conv2(x, edge_index)
        return x

    def decode(self, z, edge_label_index):
        """Predict edge existence via dot product."""
        # Returns raw scores (logits), one per candidate edge
        return (z[edge_label_index[0]] * z[edge_label_index[1]]).sum(dim=-1)

    def forward(self, x, edge_index, edge_label_index):
        z = self.encode(x, edge_index)
        return self.decode(z, edge_label_index)
```
Heterogeneous Graphs
Real-world graphs often have multiple node and edge types (e.g., paper-author-venue).
```python
import torch
import torch.nn.functional as F
from torch_geometric.nn import HeteroConv, SAGEConv


class HeteroGNN(torch.nn.Module):
    def __init__(self, metadata, hidden_channels):
        super().__init__()
        # metadata[1] lists the edge types; one SAGEConv per edge type.
        # (-1, -1) defers input sizes until the first forward pass.
        self.conv1 = HeteroConv({
            edge_type: SAGEConv((-1, -1), hidden_channels)
            for edge_type in metadata[1]
        })
        self.conv2 = HeteroConv({
            edge_type: SAGEConv((-1, -1), hidden_channels)
            for edge_type in metadata[1]
        })

    def forward(self, x_dict, edge_index_dict):
        x_dict = self.conv1(x_dict, edge_index_dict)
        x_dict = {key: F.relu(x) for key, x in x_dict.items()}
        x_dict = self.conv2(x_dict, edge_index_dict)
        return x_dict
```
Temporal Graphs
For graphs that change over time (e.g., transaction networks), use temporal GNNs that process snapshots or continuous-time events:
| Method | Approach | Use Case |
|---|---|---|
| TGAT | Temporal attention | Continuous-time events |
| TGN | Memory + attention | Dynamic interactions |
| Snapshot | GNN per timestep | Periodic updates |
| EvolveGCN | Evolving GCN weights | Slowly changing graphs |
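
For the snapshot approach, the preprocessing step is to bucket timestamped edges into fixed-width windows so that a static GNN can be run on each window. The sketch below uses a hypothetical helper `build_snapshots` and made-up transaction events; the window width is an arbitrary choice.

```python
from collections import defaultdict


def build_snapshots(timed_edges, window):
    """Group (src, dst, timestamp) events into fixed-width time windows.

    Returns a dict: window index -> list of (src, dst) edges, ordered by
    window index. Windows without events are simply absent.
    """
    snapshots = defaultdict(list)
    for src, dst, t in timed_edges:
        snapshots[int(t // window)].append((src, dst))
    return dict(sorted(snapshots.items()))


# Hypothetical transaction events: (sender, receiver, time in seconds)
events = [(0, 1, 5), (1, 2, 12), (0, 2, 14), (2, 3, 31)]
print(build_snapshots(events, window=10))
# {0: [(0, 1)], 1: [(1, 2), (0, 2)], 3: [(2, 3)]}
```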
GNN Applications in Practice
Molecular Property Prediction
```python
import torch
import torch.nn.functional as F
from torch_geometric.datasets import MoleculeNet
from torch_geometric.nn import GINConv, global_add_pool

# Load ESOL (water solubility prediction)
dataset = MoleculeNet(root='./data', name='ESOL')


class MoleculeGNN(torch.nn.Module):
    def __init__(self, in_channels, hidden_channels):
        super().__init__()
        nn1 = torch.nn.Sequential(
            torch.nn.Linear(in_channels, hidden_channels),
            torch.nn.ReLU(),
            torch.nn.Linear(hidden_channels, hidden_channels)
        )
        nn2 = torch.nn.Sequential(
            torch.nn.Linear(hidden_channels, hidden_channels),
            torch.nn.ReLU(),
            torch.nn.Linear(hidden_channels, hidden_channels)
        )
        self.conv1 = GINConv(nn1)
        self.conv2 = GINConv(nn2)
        self.fc = torch.nn.Linear(hidden_channels, 1)

    def forward(self, data):
        x, edge_index, batch = data.x, data.edge_index, data.batch
        # MoleculeNet stores atom features as integers; Linear layers need floats
        x = x.float()
        x = F.relu(self.conv1(x, edge_index))
        x = F.relu(self.conv2(x, edge_index))
        x = global_add_pool(x, batch)    # Sum readout: one vector per molecule
        return self.fc(x).squeeze(-1)    # One scalar prediction per molecule
```
Scalability Techniques
| Technique | Description | Scale |
|---|---|---|
| Mini-batch (GraphSAGE) | Sample neighbors per layer | Millions of nodes |
| Cluster-GCN | Partition graph, train on subgraphs | Millions of nodes |
| GraphSAINT | Subgraph sampling with normalization | Billions of edges |
| DistDGL | Distributed GNN training | Multi-machine |
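
The partition-based idea behind Cluster-GCN can be sketched in a few lines: once nodes are assigned to clusters, each training step only sees the edges inside one cluster. The example below assumes a precomputed, hypothetical partition (Cluster-GCN itself obtains one from a graph clustering algorithm such as METIS) and a made-up edge list.

```python
def induced_subgraph_edges(edges, partition, cluster_id):
    """Keep only edges whose endpoints both lie in the given cluster."""
    members = {v for v, c in partition.items() if c == cluster_id}
    return [(u, v) for u, v in edges if u in members and v in members]


edges = [(0, 1), (0, 2), (1, 2), (2, 3), (3, 4), (4, 5), (3, 5)]
# Hypothetical precomputed partition: node -> cluster id
partition = {0: 0, 1: 0, 2: 0, 3: 1, 4: 1, 5: 1}

for cluster_id in sorted(set(partition.values())):
    print(cluster_id, induced_subgraph_edges(edges, partition, cluster_id))
# 0 [(0, 1), (0, 2), (1, 2)]
# 1 [(3, 4), (4, 5), (3, 5)]
```

The edge (2, 3) crosses clusters and is dropped, which is the approximation this family of methods accepts in exchange for bounded memory per batch.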
Cross-References
- Foundations: Neural Network Basics --- backprop through computation graphs
- Attention mechanism: Transformers --- attention applied to sequences
- Architecture guide: Architecture Selection Guide --- when to use GNNs
- Training: Training Techniques --- dropout, normalization
- Multimodal: Multimodal Models --- combining graph + text + vision