Graph Neural Networks

Graphs are everywhere --- social networks, molecules, knowledge bases, recommendation systems, road networks. Standard neural networks (CNNs, transformers) assume grid or sequence structure. Graph neural networks (GNNs) operate directly on arbitrary graph topologies. This page covers the message passing framework, derives GCN from spectral graph theory, implements GraphSAGE and GAT, and builds a node classifier on the Cora citation network.

Why Graphs?

| Data Type | Structure | Examples |
| --- | --- | --- |
| Social networks | User-user connections | Friend recommendations |
| Molecules | Atom-bond graphs | Drug discovery |
| Knowledge graphs | Entity-relation triples | Question answering |
| Citation networks | Paper-paper citations | Paper classification |
| Scene graphs | Object-relationship | Visual reasoning |
| Traffic networks | Intersection-road | Traffic prediction |

Standard architectures do not directly handle three properties of graph data:

  • A variable number of neighbors per node
  • No canonical ordering of nodes
  • The requirement that outputs be invariant (or equivariant) to node permutations

The Message Passing Framework

Most GNNs, including all the models on this page, follow the same high-level pattern. At each layer k, every node v updates its representation by:

1. Aggregate messages from neighbors:

$$m_v^{(k)} = \text{AGGREGATE}^{(k)}\left(\{h_u^{(k-1)} : u \in \mathcal{N}(v)\}\right)$$

2. Update the node's representation:

$$h_v^{(k)} = \text{UPDATE}^{(k)}\left(h_v^{(k-1)}, m_v^{(k)}\right)$$

GNN variants differ in how they implement AGGREGATE and UPDATE.
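The two steps can be sketched as one generic layer that takes AGGREGATE and UPDATE as plain functions. This is a minimal NumPy illustration, not any library's API: the names `message_passing_layer`, `mean_aggregate`, and `average_update` are hypothetical, and the sketch assumes the update keeps the feature dimension unchanged.

```python
import numpy as np

def message_passing_layer(H, adj_lists, aggregate, update):
    """One round of message passing.

    H: (num_nodes, dim) array of node states h_v^(k-1)
    adj_lists: dict mapping node -> list of neighbor indices
    aggregate: maps a (num_neighbors, dim) array to a (dim,) message
    update: maps (h_v, m_v) to the new state h_v^(k) of the same dimension
    """
    H_new = np.zeros_like(H)
    for v in range(H.shape[0]):
        neighbors = adj_lists[v]
        # Isolated nodes receive a zero message (an assumption of this sketch)
        m_v = aggregate(H[neighbors]) if neighbors else np.zeros(H.shape[1])
        H_new[v] = update(H[v], m_v)
    return H_new

# Hypothetical choices: mean aggregation, and an update that averages
# the node's own state with the incoming message.
mean_aggregate = lambda msgs: msgs.mean(axis=0)
average_update = lambda h_v, m_v: 0.5 * (h_v + m_v)

# Path graph 0 - 1 - 2 with one-hot features
H0 = np.eye(3)
adj_lists = {0: [1], 1: [0, 2], 2: [1]}
H1 = message_passing_layer(H0, adj_lists, mean_aggregate, average_update)
print(H1)
```

Swapping in a different `aggregate` (sum, max) or `update` (a learned linear map plus nonlinearity) gives the variants discussed in the rest of this page.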

Receptive Field Growth

After K layers of message passing, each node's representation captures information from its K-hop neighborhood. This is analogous to receptive fields in CNNs.
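To make the K-hop claim concrete, the sketch below computes which nodes can influence a given node after K rounds of message passing, using a breadth-first expansion. The function name `k_hop_neighborhood` and the path graph are illustrative assumptions.

```python
def k_hop_neighborhood(adj_lists, source, num_hops):
    """Return the set of nodes within num_hops edges of source (source included)."""
    visited = {source}
    frontier = {source}
    for _ in range(num_hops):
        # Expand one hop: every not-yet-visited neighbor of the current frontier
        frontier = {u for v in frontier for u in adj_lists[v]} - visited
        visited |= frontier
    return visited

# Path graph 0 - 1 - 2 - 3 - 4
adj_lists = {0: [1], 1: [0, 2], 2: [1, 3], 3: [2, 4], 4: [3]}
for K in range(4):
    print(K, sorted(k_hop_neighborhood(adj_lists, 0, K)))
```

On this path graph the receptive field of node 0 grows by exactly one node per layer; on graphs with high-degree hubs it can grow much faster, which is one reason neighbor sampling is used for large graphs.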

Graph Convolutional Network (GCN)

Spectral Derivation

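The graph Laplacian is $L = D - A$, where $D$ is the degree matrix and $A$ is the adjacency matrix. The normalized Laplacian is:

$$\tilde{L} = I - D^{-1/2} A D^{-1/2}$$

Spectral graph convolution applies a filter $g_\theta$ in the spectral domain:

$$g_\theta \star x = U g_\theta(\Lambda) U^T x$$

where $\tilde{L} = U \Lambda U^T$ is the eigendecomposition of the normalized Laplacian. Multiplying by the dense matrix $U$ costs $O(n^2)$ per filter application, and computing the eigendecomposition itself costs $O(n^3)$ in general, which is impractical for large graphs.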

Kipf and Welling (2017) approximate the spectral filter with first-order Chebyshev polynomials, yielding:

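$$H^{(l+1)} = \sigma\left(\tilde{D}^{-1/2} \tilde{A} \tilde{D}^{-1/2} H^{(l)} W^{(l)}\right)$$

where $\tilde{A} = A + I$ (add self-loops) and $\tilde{D}_{ii} = \sum_j \tilde{A}_{ij}$.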
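The reason a polynomial filter avoids the eigendecomposition can be checked numerically. The sketch below applies a first-order filter $g_\theta(\lambda) = \theta_0 + \theta_1 \lambda$ in two ways: through the eigenvectors, and directly as $\theta_0 x + \theta_1 \tilde{L} x$. The graph, the signal, and the coefficients are arbitrary illustrative choices. It does not reproduce the further parameter tying and renormalization that Kipf and Welling apply to arrive at the layer above.

```python
import numpy as np

# Small undirected graph with edges 0-1, 0-2, 1-2, 2-3 (an illustrative choice)
A = np.array([[0, 1, 1, 0],
              [1, 0, 1, 0],
              [1, 1, 0, 1],
              [0, 0, 1, 0]], dtype=float)
deg = A.sum(axis=1)
D_inv_sqrt = np.diag(deg ** -0.5)
L_norm = np.eye(4) - D_inv_sqrt @ A @ D_inv_sqrt   # normalized Laplacian

# Spectral route: eigendecomposition, then filter the eigenvalues
eigvals, U = np.linalg.eigh(L_norm)                # L_norm is symmetric
theta0, theta1 = 0.7, -0.3                         # arbitrary filter coefficients
g = theta0 + theta1 * eigvals                      # g_theta evaluated at each eigenvalue
x = np.array([1.0, 0.0, 0.5, 1.0])                 # a signal with one value per node
y_spectral = U @ (g * (U.T @ x))

# Spatial route: the same filter as a polynomial in L_norm, no eigenvectors needed
y_spatial = theta0 * x + theta1 * (L_norm @ x)

print(np.allclose(y_spectral, y_spatial))
```

The spatial route only needs a matrix-vector product with the (typically sparse) Laplacian, so its cost scales with the number of edges rather than with $n^2$.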

Per-Node View

For a single node v:

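$$h_v^{(l+1)} = \sigma\left(W^{(l)} \sum_{u \in \mathcal{N}(v) \cup \{v\}} \frac{h_u^{(l)}}{\sqrt{\deg(u)\,\deg(v)}}\right)$$

Here $\deg(\cdot)$ counts the added self-loop, matching $\tilde{D}$ above. This is neighborhood averaging with symmetric normalization: each neighbor's contribution is scaled by $1/\sqrt{\deg(u)\deg(v)}$ rather than by $1/\deg(v)$ as in a plain mean.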

Worked Example — Message Passing on a 4-Node Graph

Setup: 4-node graph with edges: 0--1, 0--2, 1--2, 2--3. Each node has a 2D feature.

Node 0 --- Node 1
  \        /
   Node 2 --- Node 3

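Node features $H^{(0)}$:

| Node | $h_1$ | $h_2$ | Degree |
| --- | --- | --- | --- |
| 0 | 1.0 | 0.0 | 2 |
| 1 | 0.0 | 1.0 | 2 |
| 2 | 0.5 | 0.5 | 3 |
| 3 | 1.0 | 1.0 | 1 |

Weight matrix $W = \begin{bmatrix} 1 & 0 \\ 0 & 1 \end{bmatrix}$ (identity for simplicity)

Step 1: Add self-loops. $\tilde{A} = A + I$. New degrees: $\tilde{d}_0 = 3,\ \tilde{d}_1 = 3,\ \tilde{d}_2 = 4,\ \tilde{d}_3 = 2$

Step 2: Compute $h_v^{(1)}$ for node 0 (neighbors: 1, 2, plus self):

$$
\begin{aligned}
h_0^{(1)} &= \sigma\left(W \sum_{u \in \{0,1,2\}} \frac{h_u}{\sqrt{\tilde{d}_u \tilde{d}_0}}\right) \\
&= \sigma\left(\frac{[1,0]}{\sqrt{3 \cdot 3}} + \frac{[0,1]}{\sqrt{3 \cdot 3}} + \frac{[0.5,0.5]}{\sqrt{4 \cdot 3}}\right) \\
&= \sigma\left(\frac{[1,0]}{3} + \frac{[0,1]}{3} + \frac{[0.5,0.5]}{3.46}\right) \\
&= \sigma\left([0.333 + 0 + 0.144,\ 0 + 0.333 + 0.144]\right) \\
&= \sigma([0.478, 0.478]) = \text{ReLU}([0.478, 0.478]) = [0.478, 0.478]
\end{aligned}
$$

Step 3: For node 3 (neighbors: 2, plus self):

$$
\begin{aligned}
h_3^{(1)} &= \sigma\left(\frac{[0.5,0.5]}{\sqrt{4 \cdot 2}} + \frac{[1,1]}{\sqrt{2 \cdot 2}}\right) \\
&= \sigma\left(\frac{[0.5,0.5]}{2.83} + \frac{[1,1]}{2}\right) \\
&= \sigma([0.677, 0.677]) = [0.677, 0.677]
\end{aligned}
$$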

Result: After one GCN layer, node 0 (originally [1,0]) became [0.478,0.478] --- it absorbed features from its neighbors 1 ([0,1]) and 2 ([0.5,0.5]), averaging toward the local neighborhood. Node 3 has high values because both it and its neighbor (node 2) had positive features. The symmetric normalization prevents high-degree nodes from dominating.
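The hand computation can be checked with a few lines of NumPy. The helper name `normalized_adjacency` is an assumption of this sketch; it builds the symmetrically normalized adjacency with self-loops and applies one layer with the identity weight matrix and ReLU.

```python
import numpy as np

def normalized_adjacency(A):
    """Return D~^{-1/2} (A + I) D~^{-1/2} for a dense adjacency matrix A."""
    A_tilde = A + np.eye(A.shape[0])      # add self-loops
    d_tilde = A_tilde.sum(axis=1)         # degrees including self-loops
    d_inv_sqrt = 1.0 / np.sqrt(d_tilde)
    # Entry (v, u) becomes A~[v, u] / sqrt(d~_v * d~_u)
    return A_tilde * np.outer(d_inv_sqrt, d_inv_sqrt)

# Edges 0-1, 0-2, 1-2, 2-3
A = np.array([[0, 1, 1, 0],
              [1, 0, 1, 0],
              [1, 1, 0, 1],
              [0, 0, 1, 0]], dtype=float)
H0 = np.array([[1.0, 0.0],
               [0.0, 1.0],
               [0.5, 0.5],
               [1.0, 1.0]])
W = np.eye(2)                             # identity weights, as in the example

A_hat = normalized_adjacency(A)
H1 = np.maximum(A_hat @ H0 @ W, 0.0)      # one GCN layer with ReLU
print(np.round(H1, 3))
```

Rows 0 and 3 of the printed matrix should match the values derived by hand for nodes 0 and 3; rows 1 and 2 give the remaining nodes.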

GCN Implementation

python
import torch
import torch.nn as nn
import torch.nn.functional as F

class GCNLayer(nn.Module):
    def __init__(self, in_features, out_features):
        super().__init__()
        self.weight = nn.Parameter(torch.randn(in_features, out_features) * 0.01)
        self.bias = nn.Parameter(torch.zeros(out_features))

    def forward(self, X, A_hat):
        """
        X: (num_nodes, in_features) node features
        A_hat: (num_nodes, num_nodes) normalized adjacency (D^{-1/2} A_tilde D^{-1/2})
        """
        support = X @ self.weight    # Linear transform: X W
        output = A_hat @ support     # Neighborhood aggregation: A_hat X W
        return output + self.bias    # Bias added after aggregation so it is not rescaled per node

class GCN(nn.Module):
    def __init__(self, n_features, n_hidden, n_classes, dropout=0.5):
        super().__init__()
        self.gc1 = GCNLayer(n_features, n_hidden)
        self.gc2 = GCNLayer(n_hidden, n_classes)
        self.dropout = dropout

    def forward(self, X, A_hat):
        h = F.relu(self.gc1(X, A_hat))
        h = F.dropout(h, self.dropout, training=self.training)
        h = self.gc2(h, A_hat)
        return F.log_softmax(h, dim=1)

GraphSAGE: Inductive Learning

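GCN as described above is typically trained transductively: the normalized adjacency of the full graph, including test nodes, is needed during training. GraphSAGE (Hamilton et al., 2017) instead learns aggregation functions that generalize to nodes unseen during training.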

Algorithm

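For each layer k:

  1. Sample a fixed number of neighbors (not all)
  2. Aggregate neighbor features:
     $h_{\mathcal{N}(v)}^{(k)} = \text{AGGREGATE}_k\left(\{h_u^{(k-1)},\ u \in \mathcal{N}_{\text{sample}}(v)\}\right)$
  3. Concatenate and transform:
     $h_v^{(k)} = \sigma\left(W^{(k)} \cdot \text{CONCAT}\left(h_v^{(k-1)}, h_{\mathcal{N}(v)}^{(k)}\right)\right)$
  4. Normalize:
     $h_v^{(k)} = \dfrac{h_v^{(k)}}{\lVert h_v^{(k)} \rVert_2}$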
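The four steps above can be sketched in NumPy as follows. The function names `sample_neighbors` and `sage_mean_step` are hypothetical, the weights are random rather than learned, and two choices are assumptions of this sketch rather than part of the algorithm as stated: nodes with fewer neighbors than the sample size are sampled with replacement, and every node is assumed to have at least one neighbor.

```python
import random
import numpy as np

def sample_neighbors(adj_lists, node, num_samples, rng):
    """Draw a fixed-size neighbor sample for one node."""
    neighbors = adj_lists[node]
    if len(neighbors) >= num_samples:
        return rng.sample(neighbors, num_samples)       # without replacement
    # Fewer neighbors than requested: sample with replacement to keep the size fixed
    return [rng.choice(neighbors) for _ in range(num_samples)]

def sage_mean_step(H, adj_lists, W, num_samples, rng):
    """One GraphSAGE-style layer with a mean aggregator."""
    out = []
    for v in range(H.shape[0]):
        sampled = sample_neighbors(adj_lists, v, num_samples, rng)  # sample
        h_neigh = H[sampled].mean(axis=0)                           # aggregate
        h = np.maximum(W @ np.concatenate([H[v], h_neigh]), 0.0)    # concat, transform, ReLU
        norm = np.linalg.norm(h)
        out.append(h / norm if norm > 0 else h)                     # L2 normalize
    return np.stack(out)

rng = random.Random(0)
adj_lists = {0: [1, 2], 1: [0, 2], 2: [0, 1, 3], 3: [2]}
H = np.array([[1.0, 0.0], [0.0, 1.0], [0.5, 0.5], [1.0, 1.0]])
W = np.random.default_rng(0).normal(size=(3, 4))   # maps 2 + 2 concatenated features to 3
H1 = sage_mean_step(H, adj_lists, W, num_samples=2, rng=rng)
print(H1.shape)
```

Because the layer only ever touches a node's own features and a fixed-size sample of its neighbors, the same weights can be applied to nodes that were not present during training.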

Aggregation Functions

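| Aggregator | Formula | Properties |
| --- | --- | --- |
| Mean | $\frac{1}{\lvert\mathcal{N}(v)\rvert} \sum_{u \in \mathcal{N}(v)} h_u$ | |
| Max pool | $\max(\{\sigma(W_{\text{pool}} h_u + b)\})$ | Captures salient features |
| LSTM | LSTM on random permutation | Expressive but order-dependent |

python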
class GraphSAGELayer(nn.Module):
    def __init__(self, in_features, out_features, aggregator='mean'):
        super().__init__()
        self.aggregator = aggregator
        # Input is concatenation of self + aggregated neighbor features
        self.linear = nn.Linear(in_features * 2, out_features)

    def forward(self, X, adj_lists):
        """
        X: (num_nodes, features)
        adj_lists: dict mapping node -> list of neighbor indices
        """
        num_nodes = X.size(0)
        neigh_feats = torch.zeros_like(X)

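        for node in range(num_nodes):
            neighbors = adj_lists[node]
            if len(neighbors) == 0:
                continue  # Isolated node: neighbor features stay zero
            if self.aggregator == 'mean':
                neigh_feats[node] = X[neighbors].mean(dim=0)
            elif self.aggregator == 'max':
                # Element-wise max (without the pooling MLP from the table)
                neigh_feats[node] = X[neighbors].max(dim=0)[0]
            else:
                raise ValueError(f"Unsupported aggregator: {self.aggregator}")

        # Concatenate self and neighbor features, then transform
        combined = torch.cat([X, neigh_feats], dim=1)
        output = F.relu(self.linear(combined))  # Nonlinearity from the update rule
        # L2 normalize each node embedding
        output = F.normalize(output, p=2, dim=1)
        return output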

Graph Attention Network (GAT)

GAT (Veličković et al., 2018) learns a separate attention weight for each neighbor, rather than weighting neighbors by a fixed rule.

Attention Mechanism

For node v and neighbor u:

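$$e_{vu} = \text{LeakyReLU}\left(a^T \left[W h_v \,\Vert\, W h_u\right]\right)$$

$$\alpha_{vu} = \frac{\exp(e_{vu})}{\sum_{k \in \mathcal{N}(v)} \exp(e_{vk})}$$

$$h_v' = \sigma\left(\sum_{u \in \mathcal{N}(v)} \alpha_{vu} W h_u\right)$$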

Multi-head attention (concatenate or average K heads):

$$h_v' = \Big\Vert_{k=1}^{K} \sigma\left(\sum_{u \in \mathcal{N}(v)} \alpha_{vu}^{k} W^{k} h_u\right)$$
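A dense, single-head version of the attention computation can be sketched in NumPy for small graphs. The function name `gat_attention` is hypothetical, the parameters are random rather than learned, and the sketch makes three assumptions: the adjacency matrix already contains self-loops, features are stored as rows (so `W` here is the transpose of the $W$ in the formulas), and the final nonlinearity $\sigma$ is left to the caller.

```python
import numpy as np

def gat_attention(H, A, W, a, negative_slope=0.2):
    """Single-head GAT aggregation and attention matrix (dense, for small graphs).

    H: (n, f_in) features, A: (n, n) adjacency with self-loops,
    W: (f_in, f_out) weights, a: (2 * f_out,) attention vector.
    """
    Z = H @ W                                            # row v holds the transformed h_v
    f_out = Z.shape[1]
    # a^T [z_v || z_u] splits into a_left . z_v + a_right . z_u
    scores = (Z @ a[:f_out])[:, None] + (Z @ a[f_out:])[None, :]
    scores = np.where(scores > 0, scores, negative_slope * scores)  # LeakyReLU
    scores = np.where(A > 0, scores, -np.inf)            # mask out non-neighbors
    scores = scores - scores.max(axis=1, keepdims=True)  # numerical stability
    alpha = np.exp(scores)
    alpha = alpha / alpha.sum(axis=1, keepdims=True)     # softmax over each neighborhood
    return alpha @ Z, alpha

rng = np.random.default_rng(0)
A = np.array([[1, 1, 1, 0],
              [1, 1, 1, 0],
              [1, 1, 1, 1],
              [0, 0, 1, 1]], dtype=float)                # edges 0-1, 0-2, 1-2, 2-3 plus self-loops
H = rng.normal(size=(4, 2))
W = rng.normal(size=(2, 3))
a = rng.normal(size=6)
H_out, alpha = gat_attention(H, A, W, a)
print(alpha.round(3))
```

Each row of `alpha` is a probability distribution over that node's neighborhood. For multi-head attention, the same function would be called with K independent `(W, a)` pairs and the outputs concatenated or averaged.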

GCN vs GraphSAGE vs GAT

| Feature | GCN | GraphSAGE | GAT |
| --- | --- | --- | --- |
| Aggregation | Symmetric normalization | Learned (mean/max/LSTM) | Attention-weighted |
| Neighbor weighting | Fixed (degree-based) | Equal | Learned per edge |
| Inductive | No (transductive) | Yes | Yes |
| Scalability | Full graph needed | Mini-batch via sampling | Mini-batch possible |
| Expressiveness | Low | Medium | High |

PyTorch Geometric: Cora Classification

PyTorch Geometric (PyG) is a widely used library for GNNs built on PyTorch.

python
import torch
import torch.nn.functional as F
from torch_geometric.datasets import Planetoid
from torch_geometric.nn import GCNConv, GATConv

# ── Load Cora Dataset ────────────────────────────────────────────────
dataset = Planetoid(root='./data', name='Cora')
data = dataset[0]

print(f"Nodes: {data.num_nodes}")         # 2708
print(f"Edges: {data.num_edges}")         # 10556
print(f"Features: {data.num_features}")   # 1433
print(f"Classes: {dataset.num_classes}")  # 7
print(f"Train nodes: {int(data.train_mask.sum())}")  # 140

# ── GCN Model ────────────────────────────────────────────────────────
class GCNModel(torch.nn.Module):
    def __init__(self, n_features, n_hidden, n_classes):
        super().__init__()
        self.conv1 = GCNConv(n_features, n_hidden)
        self.conv2 = GCNConv(n_hidden, n_classes)

    def forward(self, data):
        x, edge_index = data.x, data.edge_index
        x = F.relu(self.conv1(x, edge_index))
        x = F.dropout(x, p=0.5, training=self.training)
        x = self.conv2(x, edge_index)
        return F.log_softmax(x, dim=1)

# ── GAT Model ────────────────────────────────────────────────────────
class GATModel(torch.nn.Module):
    def __init__(self, n_features, n_hidden, n_classes, heads=8):
        super().__init__()
        self.conv1 = GATConv(n_features, n_hidden, heads=heads, dropout=0.6)
        self.conv2 = GATConv(n_hidden * heads, n_classes, heads=1, dropout=0.6)

    def forward(self, data):
        x, edge_index = data.x, data.edge_index
        x = F.dropout(x, p=0.6, training=self.training)
        x = F.elu(self.conv1(x, edge_index))
        x = F.dropout(x, p=0.6, training=self.training)
        x = self.conv2(x, edge_index)
        return F.log_softmax(x, dim=1)

# ── Training ─────────────────────────────────────────────────────────
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = GATModel(dataset.num_features, 8, dataset.num_classes).to(device)
data = data.to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.005, weight_decay=5e-4)

for epoch in range(200):
    model.train()
    optimizer.zero_grad()
    out = model(data)
    loss = F.nll_loss(out[data.train_mask], data.y[data.train_mask])
    loss.backward()
    optimizer.step()

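    if (epoch + 1) % 50 == 0:
        model.eval()
        with torch.no_grad():  # No gradients needed for evaluation
            pred = model(data).argmax(dim=1)
        test_acc = (pred[data.test_mask] == data.y[data.test_mask]).float().mean()
        print(f"Epoch {epoch+1}: Test Accuracy = {test_acc:.4f}")

# Reported test accuracy on this split is about 81.5% for GCN (Kipf and Welling, 2017)
# and about 83% for GAT (Velickovic et al., 2018); results vary with the random seed.
# To train the GCN instead, instantiate GCNModel in place of GATModel above.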

Graph-Level Tasks

For graph classification (e.g., molecule property prediction), we need a graph-level representation.

Graph Readout

$$h_G = \text{READOUT}\left(\{h_v^{(K)} : v \in G\}\right)$$

Common readout functions:

  • Mean pooling: $h_G = \frac{1}{|V|} \sum_{v} h_v$
  • Sum pooling: $h_G = \sum_{v} h_v$
  • Hierarchical pooling: Learn to coarsen the graph (DiffPool, TopKPool)

python
from torch_geometric.nn import global_mean_pool, GINConv

class GraphClassifier(torch.nn.Module):
    def __init__(self, n_features, n_hidden, n_classes):
        super().__init__()
        nn1 = torch.nn.Sequential(
            torch.nn.Linear(n_features, n_hidden),
            torch.nn.ReLU(),
            torch.nn.Linear(n_hidden, n_hidden),
        )
        nn2 = torch.nn.Sequential(
            torch.nn.Linear(n_hidden, n_hidden),
            torch.nn.ReLU(),
            torch.nn.Linear(n_hidden, n_hidden),
        )
        self.conv1 = GINConv(nn1)
        self.conv2 = GINConv(nn2)
        self.classifier = torch.nn.Linear(n_hidden, n_classes)

    def forward(self, data):
        x, edge_index, batch = data.x, data.edge_index, data.batch
        x = F.relu(self.conv1(x, edge_index))
        x = F.relu(self.conv2(x, edge_index))
        x = global_mean_pool(x, batch)  # Graph-level readout
        return self.classifier(x)
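What the readout step does with the `batch` vector can be shown without any GNN library. The sketch below is a NumPy stand-in for mean and sum pooling over a batch of graphs; the function name `readout` is hypothetical and it is not the PyG implementation.

```python
import numpy as np

def readout(H, batch, num_graphs, mode="mean"):
    """Pool node embeddings into one vector per graph.

    H: (num_nodes, dim) node embeddings
    batch: (num_nodes,) integer array; batch[v] is the index of the graph containing v
    """
    sums = np.zeros((num_graphs, H.shape[1]))
    np.add.at(sums, batch, H)                  # scatter-add each row into its graph
    if mode == "sum":
        return sums
    counts = np.bincount(batch, minlength=num_graphs)[:, None]
    return sums / np.maximum(counts, 1)        # mean; guard against empty graphs

# Two graphs in one batch: nodes 0 and 1 belong to graph 0, node 2 to graph 1
H = np.array([[1.0, 2.0], [3.0, 4.0], [10.0, 20.0]])
batch = np.array([0, 0, 1])
print(readout(H, batch, 2, "mean"))
print(readout(H, batch, 2, "sum"))
```

Mean pooling discards graph size while sum pooling retains it, which matters when the target depends on the number of nodes (for example, properties that scale with molecule size).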

Over-Smoothing Problem

As GNN depth increases, node representations tend to converge toward the same value: after many rounds of averaging neighbor features, initially distinct node features become nearly indistinguishable.

Solutions:

  • Use few layers (2-3 often works best in practice)
  • Skip connections (like ResNet)
  • DropEdge: randomly remove edges during training
  • PairNorm: normalize to prevent convergence
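The effect is easy to reproduce on a toy graph. The sketch below repeatedly applies plain mean aggregation (a row-normalized adjacency with self-loops, no weights and no nonlinearity) and tracks how far apart the node features remain. This is a deliberate simplification of a real GNN, and the graph and the `feature_spread` measure are illustrative choices.

```python
import numpy as np

# Graph with edges 0-1, 0-2, 1-2, 2-3
A = np.array([[0, 1, 1, 0],
              [1, 0, 1, 0],
              [1, 1, 0, 1],
              [0, 0, 1, 0]], dtype=float)
A_tilde = A + np.eye(4)
P = A_tilde / A_tilde.sum(axis=1, keepdims=True)   # each row averages a node's neighborhood

H = np.array([[1.0, 0.0], [0.0, 1.0], [0.5, 0.5], [1.0, 1.0]])

def feature_spread(H):
    """Largest per-feature standard deviation across nodes (0 means all nodes agree)."""
    return H.std(axis=0).max()

spread_by_depth = {}
for depth in range(33):
    spread_by_depth[depth] = feature_spread(H)
    H = P @ H                                      # one more round of mean aggregation

for depth in (0, 1, 2, 4, 8, 16, 32):
    print(f"{depth:2d} layers: spread = {spread_by_depth[depth]:.2e}")
```

The spread shrinks roughly geometrically with depth. Learned weights and nonlinearities change the details, but the same contraction is what skip connections, DropEdge, and PairNorm are designed to counteract.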

Expressiveness: The WL Test

The Weisfeiler-Leman (WL) graph isomorphism test provides an upper bound on the expressiveness of message-passing GNNs.

Standard GNNs (GCN, GraphSAGE) are at most as powerful as the 1-WL test, so they cannot distinguish certain non-isomorphic graphs. GIN (Graph Isomorphism Network) matches the power of 1-WL by using a sum aggregator followed by an MLP:

$$h_v^{(k)} = \text{MLP}^{(k)}\left(\left(1 + \epsilon^{(k)}\right) \cdot h_v^{(k-1)} + \sum_{u \in \mathcal{N}(v)} h_u^{(k-1)}\right)$$

where $\epsilon^{(k)}$ is a learnable parameter (or a fixed scalar).

What GNNs Cannot Distinguish

Two regular graphs with the same degree sequence but different structure (e.g., a 6-cycle vs two 3-cycles) look identical to 1-WL. Higher-order GNNs (k-WL) or subgraph GNNs are needed for these cases.
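The 6-cycle example can be verified with a short implementation of 1-WL color refinement. The function name `wl_color_histogram` is hypothetical. To keep colors comparable across graphs, this sketch uses the nested signature tuples themselves as colors instead of relabeling them, which is fine for small graphs and few rounds but grows quickly otherwise.

```python
from collections import Counter

def wl_color_histogram(adj_lists, num_rounds=3):
    """Run 1-WL color refinement and return the multiset of final node colors."""
    colors = {v: 0 for v in adj_lists}     # every node starts with the same color
    for _ in range(num_rounds):
        # New color = (own color, sorted multiset of neighbor colors)
        colors = {
            v: (colors[v], tuple(sorted(colors[u] for u in adj_lists[v])))
            for v in adj_lists
        }
    return Counter(colors.values())

six_cycle = {i: [(i - 1) % 6, (i + 1) % 6] for i in range(6)}
two_triangles = {0: [1, 2], 1: [0, 2], 2: [0, 1],
                 3: [4, 5], 4: [3, 5], 5: [3, 4]}
path = {0: [1], 1: [0, 2], 2: [1, 3], 3: [2, 4], 4: [3, 5], 5: [4]}

# Same histogram: 1-WL cannot tell these two graphs apart
print(wl_color_histogram(six_cycle) == wl_color_histogram(two_triangles))  # True
# Different histogram: the path has endpoints of degree 1
print(wl_color_histogram(six_cycle) == wl_color_histogram(path))           # False
```

In both the 6-cycle and the pair of triangles, every node has exactly two neighbors with identical colors at every round, so refinement never separates them. A message-passing GNN with identical initial node features faces the same limitation.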

Link Prediction

Predict whether an edge should exist between two nodes:

python
class LinkPredictor(torch.nn.Module):
    def __init__(self, in_channels, hidden_channels):
        super().__init__()
        self.conv1 = GCNConv(in_channels, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, hidden_channels)

    def encode(self, x, edge_index):
        x = F.relu(self.conv1(x, edge_index))
        x = self.conv2(x, edge_index)
        return x

    def decode(self, z, edge_label_index):
        """Predict edge existence via dot product."""
        return (z[edge_label_index[0]] * z[edge_label_index[1]]).sum(dim=-1)

    def forward(self, x, edge_index, edge_label_index):
        z = self.encode(x, edge_index)
        return self.decode(z, edge_label_index)
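Training a dot-product decoder needs negative examples (node pairs without an edge) and a loss that pushes positive scores up and negative scores down. The NumPy sketch below shows one common setup, binary cross-entropy with uniformly sampled negatives. The function names are hypothetical, the embeddings `Z` stand in for an encoder's output, and the sampler assumes enough non-edges exist to fill the request (otherwise it would loop forever).

```python
import random
import numpy as np

def sample_negative_edges(num_nodes, edges, num_samples, rng):
    """Sample distinct node pairs that are not connected (self-pairs excluded)."""
    existing = {frozenset(e) for e in edges}
    negatives = set()
    while len(negatives) < num_samples:
        u, v = rng.randrange(num_nodes), rng.randrange(num_nodes)
        if u != v and frozenset((u, v)) not in existing:
            negatives.add((min(u, v), max(u, v)))
    return sorted(negatives)

def link_bce_loss(Z, pos_edges, neg_edges):
    """Binary cross-entropy on dot-product scores: positives -> 1, negatives -> 0."""
    def scores(edge_list):
        src, dst = np.array(edge_list).T
        return (Z[src] * Z[dst]).sum(axis=-1)
    pos, neg = scores(pos_edges), scores(neg_edges)
    # -log(sigmoid(s)) = log(1 + exp(-s)) and -log(1 - sigmoid(s)) = log(1 + exp(s));
    # logaddexp evaluates both without overflow
    total = np.logaddexp(0.0, -pos).sum() + np.logaddexp(0.0, neg).sum()
    return total / (len(pos) + len(neg))

rng = random.Random(0)
edges = [(0, 1), (0, 2), (1, 2), (2, 3)]
neg_edges = sample_negative_edges(4, edges, num_samples=2, rng=rng)
Z = np.random.default_rng(0).normal(size=(4, 8))   # stand-in for encoder output
print(neg_edges, link_bce_loss(Z, edges, neg_edges))
```

In a PyTorch training loop the same quantity would typically be computed with a logits-based binary cross-entropy on the decoder output, with fresh negatives drawn each epoch.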

Heterogeneous Graphs

Real-world graphs often have multiple node and edge types (e.g., paper-author-venue).

python
from torch_geometric.nn import HeteroConv, SAGEConv

class HeteroGNN(torch.nn.Module):
    def __init__(self, metadata, hidden_channels):
        super().__init__()
        self.conv1 = HeteroConv({
            edge_type: SAGEConv((-1, -1), hidden_channels)
            for edge_type in metadata[1]
        })
        self.conv2 = HeteroConv({
            edge_type: SAGEConv((-1, -1), hidden_channels)
            for edge_type in metadata[1]
        })

    def forward(self, x_dict, edge_index_dict):
        x_dict = self.conv1(x_dict, edge_index_dict)
        x_dict = {key: F.relu(x) for key, x in x_dict.items()}
        x_dict = self.conv2(x_dict, edge_index_dict)
        return x_dict

Temporal Graphs

For graphs that change over time (e.g., transaction networks), use temporal GNNs that process snapshots or continuous-time events:

| Method | Approach | Use Case |
| --- | --- | --- |
| TGAT | Temporal attention | Continuous-time events |
| TGN | Memory + attention | Dynamic interactions |
| Snapshot | GNN per timestep | Periodic updates |
| EvolveGCN | Evolving GCN weights | Slowly changing graphs |

GNN Applications in Practice

Molecular Property Prediction

python
from torch_geometric.datasets import MoleculeNet
from torch_geometric.nn import GINConv, global_add_pool

# Load ESOL (water solubility prediction)
dataset = MoleculeNet(root='./data', name='ESOL')

class MoleculeGNN(torch.nn.Module):
    def __init__(self, in_channels, hidden_channels):
        super().__init__()
        nn1 = torch.nn.Sequential(
            torch.nn.Linear(in_channels, hidden_channels),
            torch.nn.ReLU(),
            torch.nn.Linear(hidden_channels, hidden_channels)
        )
        nn2 = torch.nn.Sequential(
            torch.nn.Linear(hidden_channels, hidden_channels),
            torch.nn.ReLU(),
            torch.nn.Linear(hidden_channels, hidden_channels)
        )
        self.conv1 = GINConv(nn1)
        self.conv2 = GINConv(nn2)
        self.fc = torch.nn.Linear(hidden_channels, 1)

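    def forward(self, data):
        # MoleculeNet stores atom features as integers; cast to float for nn.Linear
        x, edge_index, batch = data.x.float(), data.edge_index, data.batch
        x = F.relu(self.conv1(x, edge_index))
        x = F.relu(self.conv2(x, edge_index))
        x = global_add_pool(x, batch)  # Sum over atoms: one vector per molecule
        return self.fc(x).squeeze(-1)  # Shape (num_graphs,), also for a batch of one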

Scalability Techniques

| Technique | Description | Scale |
| --- | --- | --- |
| Mini-batch (GraphSAGE) | Sample neighbors per layer | Millions of nodes |
| Cluster-GCN | Partition graph, train on subgraphs | Millions of nodes |
| GraphSAINT | Subgraph sampling with normalization | Billions of edges |
| DistDGL | Distributed GNN training | Multi-machine |

"What I cannot create, I do not understand." — Richard Feynman