Introduction to Graph Neural Networks: The Message Passing Framework

Graph Neural Networks (GNNs) and Geometric Deep Learning are at the forefront of modern machine learning, offering powerful tools for modelling complex systems and structures. These techniques are especially advantageous in fields such as chemistry, biology, and social network analysis, where data can naturally be represented as graphs.

In this blog post, we will explore the fundamentals of GNNs, their applications, and the underlying mathematical principles.

Introduction to Graph Neural Networks

Graph Neural Networks are a class of neural networks designed to perform inference on data structured as graphs. Unlike traditional neural networks that operate on grid-like data (e.g. images, sequences, tabular data), GNNs can directly handle the complex relationships and dependencies inherent in graph-structured data.

Typical tasks include node classification, link prediction, and graph classification. This makes GNNs particularly well suited to problems where the data is naturally represented as a network of interconnected entities, such as:

  • Cheminformatics and Drug Discovery: In chemistry, GNNs can predict molecular properties, generate new molecules, and model chemical reactions.
  • Social Networks: GNNs can be used to analyse social networks, predicting user behaviour, identifying communities, and recommending friends or content.
  • Protein Biology: GNNs play a crucial role in understanding the 3D conformations of proteins, which are essential for drug design and understanding biological processes at a molecular level.
  • Recommender Systems: GNNs can enhance recommendation systems by modelling user-item interactions as a graph and capturing complex dependencies.

Graph Representation

A graph $G$ is represented as a tuple $(\mathcal{V}, \mathcal{E})$, where $\mathcal{V}$ is a set of vertices (or nodes), and $\mathcal{E}$ is a set of edges connecting pairs of nodes.

A graph $G$ can be represented by an adjacency matrix $A \in \mathbb{R}^{N_v \times N_v}$ capturing the connections between nodes, where $N_v = |\mathcal{V}|$ is the number of nodes.

$$A = \begin{pmatrix} a_{11} & \ldots & a_{1j} & \ldots & a_{1N_v} \\ \vdots & \ddots & \vdots & \ddots & \vdots \\ a_{i1} & \ldots & a_{ij} & \ldots & a_{iN_v} \\ \vdots & \ddots & \vdots & \ddots & \vdots \\ a_{N_v1} & \ldots & a_{N_vj} & \ldots & a_{N_vN_v} \end{pmatrix}$$

In this matrix, $a_{ij} = 1$ indicates an edge between nodes $v_i$ and $v_j$, while $a_{ij} = 0$ indicates that there is no edge. For weighted graphs, $a_{ij}$ holds the edge weight rather than a binary indicator. For undirected graphs, $A$ is symmetric, i.e. $a_{ij} = a_{ji}$.
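Here is a minimal sketch of building a dense adjacency matrix from an edge list in plain Python. It assumes nodes are indexed from 0 (rather than from 1 as in the matrix above) and that an optional third tuple entry holds the edge weight. The helper name `adjacency_matrix` is hypothetical and not part of any library:

```python
def adjacency_matrix(num_nodes, edges, directed=False):
    """Build a dense num_nodes x num_nodes adjacency matrix A.

    edges: (i, j) tuples for binary edges, or (i, j, weight) tuples
    for weighted edges. Nodes are indexed 0 .. num_nodes - 1.
    """
    A = [[0] * num_nodes for _ in range(num_nodes)]
    for edge in edges:
        i, j = edge[0], edge[1]
        weight = edge[2] if len(edge) == 3 else 1  # a_ij = 1 unless a weight is given
        A[i][j] = weight
        if not directed:
            A[j][i] = weight  # undirected graphs have a symmetric A
    return A


# Path graph 0 - 1 - 2 - 3
A = adjacency_matrix(4, [(0, 1), (1, 2), (2, 3)])
for row in A:
    print(row)
# [0, 1, 0, 0]
# [1, 0, 1, 0]
# [0, 1, 0, 1]
# [0, 0, 1, 0]
```

For large, sparse graphs, a dense $N_v \times N_v$ matrix wastes memory. GNN libraries therefore commonly store the edge list (or a sparse matrix) instead.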

The Message Passing Framework

The message passing framework is the foundation of most modern GNNs and provides a unified view of how information propagates through graph structures. It involves two main phases: the message passing phase and the readout phase.

Message Passing Phase

During the message passing phase, information is propagated through the graph. Each node aggregates messages from its neighbours and updates its state.

This process is typically defined using two neural networks:

  • Message Function $M$: Computes the message from one node to another.
  • Vertex Update Function $U$: Updates the node's hidden state based on the aggregated messages.

(Figure: A graph with example nodes, hidden state vectors and embedding state vectors, illustrating the dependencies in the message passing phase.)

Given an undirected graph $G$ with node features $\boldsymbol{x}_i$ and edge features $\boldsymbol{e}_{ij}$, we apply the message passing phase following these steps:

$$\boldsymbol{h}_{i}^{0} = \boldsymbol{x}_{i} \quad \forall i \quad \text{or} \quad \boldsymbol{h}_{i}^{0} = \text{MLP}(\boldsymbol{x}_{i}; \boldsymbol{\theta}) \quad \forall i$$

At $t=0$, the hidden state $\boldsymbol{h}_i^0$ is either set directly to the input features $\boldsymbol{x}_i$ or computed from them with an MLP.
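The two initialisation options can be sketched as follows. The sketch assumes a one-layer MLP with a ReLU activation and uses hand-picked weights in place of the learned parameters $\boldsymbol{\theta}$. The helpers `linear`, `relu`, and `init_hidden_states` are hypothetical names for this illustration:

```python
def relu(v):
    return [max(0.0, a) for a in v]


def linear(W, b, x):
    """Affine map W x + b, with W given as a list of rows."""
    return [sum(w * xk for w, xk in zip(row, x)) + bias for row, bias in zip(W, b)]


def init_hidden_states(X, theta=None):
    """Return h_i^0 for every node i.

    theta=None   -> h_i^0 = x_i (use the input features directly)
    theta=(W, b) -> h_i^0 = ReLU(W x_i + b), a one-layer MLP (assumption)
    """
    if theta is None:
        return [list(x) for x in X]
    W, b = theta
    return [relu(linear(W, b, x)) for x in X]


X = [[1.0, 2.0], [0.0, -1.0]]  # two nodes with 2-dim input features
theta = ([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]], [0.0, 0.0, 0.0])  # maps 2 -> 3 dims
print(init_hidden_states(X))         # [[1.0, 2.0], [0.0, -1.0]]
print(init_hidden_states(X, theta))  # [[1.0, 2.0, 3.0], [0.0, 0.0, 0.0]]
```

Projecting with an MLP also lets the hidden dimension differ from the input dimension (here 2 to 3), which is a common reason to use it.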

$$\boldsymbol{m}_{i}^{t+1} = \sum_{j: a_{ij}=1} M(\boldsymbol{h}_{i}^{t}, \boldsymbol{h}_{j}^{t}, \boldsymbol{e}_{ij}; \boldsymbol{w})$$

For each neighbour $j$ of node $i$ (every $j$ with $a_{ij}=1$), the message function $M$ computes a message from the node's own hidden state $\boldsymbol{h}_i^t$, the neighbour's hidden state $\boldsymbol{h}_j^t$, and the edge features $\boldsymbol{e}_{ij}$. These messages are then summed to form the aggregated message $\boldsymbol{m}_i^{t+1}$.

💡
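Note that $\sum$ can be replaced by any permutation-invariant aggregation function, such as the mean, max, or min. Permutation invariance is required because a node's neighbours have no natural ordering, so the result must not depend on the order in which messages are combined.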
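To illustrate the message and aggregation step, here is a sketch for a single node. It assumes a hypothetical message function that scales the neighbour's state $\boldsymbol{h}_j^t$ by a scalar edge feature $e_{ij}$ (a learned $M$ would typically be an MLP), with edge features stored in a matrix `E` shaped like $A$. It compares sum, mean, and max aggregation and checks that reordering the messages does not change the result:

```python
def message(h_i, h_j, e_ij):
    # Hypothetical message function M: the neighbour's state scaled by a
    # scalar edge feature. A learned M would typically be an MLP instead.
    return [e_ij * a for a in h_j]


AGGREGATORS = {
    "sum": lambda msgs: [sum(col) for col in zip(*msgs)],
    "mean": lambda msgs: [sum(col) / len(msgs) for col in zip(*msgs)],
    "max": lambda msgs: [max(col) for col in zip(*msgs)],
}


def aggregate_messages(A, H, E, i, agg="sum"):
    """m_i^{t+1}: aggregate M(h_i^t, h_j^t, e_ij) over all j with a_ij = 1."""
    msgs = [message(H[i], H[j], E[i][j]) for j in range(len(A)) if A[i][j] == 1]
    if not msgs:  # isolated node: no neighbours, so return a zero message
        return [0.0] * len(H[i])
    return AGGREGATORS[agg](msgs)


# Path graph 0 - 1 - 2 with scalar edge features e_01 = 1.0 and e_12 = 0.5
A = [[0, 1, 0], [1, 0, 1], [0, 1, 0]]
E = [[0.0, 1.0, 0.0], [1.0, 0.0, 0.5], [0.0, 0.5, 0.0]]
H = [[1.0, 0.0], [0.0, 1.0], [2.0, 2.0]]
for agg in ("sum", "mean", "max"):
    print(agg, aggregate_messages(A, H, E, i=1, agg=agg))
# sum [2.0, 1.0]
# mean [1.0, 0.5]
# max [1.0, 1.0]

# Permutation invariance: reversing the order of node 1's messages gives the same result
msgs = [[1.0, 0.0], [1.0, 1.0]]
assert all(f(msgs) == f(msgs[::-1]) for f in AGGREGATORS.values())
```

The choice of aggregator matters: the sum keeps information about how many neighbours a node has, while the mean and max do not.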

$$\boldsymbol{h}_{i}^{t+1} = U(\boldsymbol{h}_{i}^{t}, \boldsymbol{m}_{i}^{t+1}; \boldsymbol{v})$$

The update function $U$ then computes the new hidden state $\boldsymbol{h}_i^{t+1}$ of each node $i$ from its previous hidden state $\boldsymbol{h}_i^t$ and the aggregated message $\boldsymbol{m}_i^{t+1}$.

Here, $\boldsymbol{\theta}$, $\boldsymbol{w}$, and $\boldsymbol{v}$ denote the learnable weights of the initial MLP, the message function, and the update function, respectively.
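A simple update function $U$ can be sketched as a linear transformation of the previous state plus a linear transformation of the aggregated message, followed by a ReLU. This particular form and the hand-picked matrices `W_self` and `W_msg` (standing in for the learned weights $\boldsymbol{v}$) are assumptions made for illustration. Other choices, such as a GRU cell, are also used in practice:

```python
def relu(v):
    return [max(0.0, a) for a in v]


def matvec(W, x):
    """Matrix-vector product, with W given as a list of rows."""
    return [sum(w * xk for w, xk in zip(row, x)) for row in W]


def update(h_i, m_i, W_self, W_msg):
    """Hypothetical update U: h_i^{t+1} = ReLU(W_self h_i^t + W_msg m_i^{t+1})."""
    return relu([a + b for a, b in zip(matvec(W_self, h_i), matvec(W_msg, m_i))])


W_self = [[1.0, 0.0], [0.0, 1.0]]  # identity: keep the node's own state
W_msg = [[0.5, 0.0], [0.0, 0.5]]   # scale the aggregated message by 0.5
print(update([1.0, -1.0], [2.0, 1.0], W_self, W_msg))  # [2.0, 0.0]
```

Keeping a separate term for the node's own state $\boldsymbol{h}_i^t$ ensures that a node's previous representation is not simply overwritten by its neighbours' messages.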

The message passing phase runs for $T$ steps, iteratively updating the node states to obtain useful hidden representations. After $t$ steps, the state of each node can depend on information from nodes up to $t$ hops away.
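Putting the pieces together, the following sketch runs $T$ synchronous message passing steps with sum aggregation. It assumes messages have the same dimension as the hidden states. The toy functions `M` (passes the neighbour's state on unchanged) and `U` (adds the message to the node's state) are hypothetical and have no learnable weights. They exist only to make the information flow easy to follow:

```python
def message_passing(A, H0, T, message_fn, update_fn, E=None):
    """Run T synchronous message passing steps with sum aggregation.

    A: N x N binary adjacency matrix; H0: list of N initial states h_i^0;
    E: optional N x N matrix of edge features (None if the graph has none).
    Assumes messages have the same dimension as the hidden states.
    """
    H = [list(h) for h in H0]
    N = len(A)
    for _ in range(T):
        H_next = []
        for i in range(N):
            m_i = [0.0] * len(H[i])  # m_i^{t+1}, accumulated over neighbours
            for j in range(N):
                if A[i][j] == 1:
                    e_ij = E[i][j] if E is not None else None
                    msg = message_fn(H[i], H[j], e_ij)
                    m_i = [a + b for a, b in zip(m_i, msg)]
            H_next.append(update_fn(H[i], m_i))  # h_i^{t+1} = U(h_i^t, m_i^{t+1})
        H = H_next  # every node at step t+1 uses states from step t only
    return H


def M(h_i, h_j, e_ij):
    return list(h_j)  # toy message: pass the neighbour's state unchanged


def U(h_i, m_i):
    return [a + b for a, b in zip(h_i, m_i)]  # toy update: add the message


A = [[0, 1, 0], [1, 0, 1], [0, 1, 0]]  # path graph 0 - 1 - 2
H0 = [[1.0], [0.0], [0.0]]             # only node 0 carries a signal
print(message_passing(A, H0, T=1, message_fn=M, update_fn=U))  # [[1.0], [1.0], [0.0]]
print(message_passing(A, H0, T=2, message_fn=M, update_fn=U))  # [[2.0], [2.0], [1.0]]
```

Node 2 is two hops away from node 0, so its state only changes after the second step. With $T=1$ it never receives node 0's signal.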

Readout Phase

The readout phase maps the final node states to a prediction, which can be made either at the node level or at the graph level.

For node-level predictions, the readout function generates predictions for each node based on its final hidden state:

$$\hat{y}_i = R(\boldsymbol{h}_i^T)$$
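For node classification, one common choice for $R$ is a linear layer followed by a softmax over the classes. The sketch below assumes that form and uses hand-picked weights in place of learned ones. `node_readout` is a hypothetical helper name:

```python
import math


def softmax(z):
    shift = max(z)  # subtracting the largest logit avoids overflow in exp
    exps = [math.exp(a - shift) for a in z]
    total = sum(exps)
    return [e / total for e in exps]


def node_readout(h, W, b):
    """Hypothetical node-level readout R: class probabilities softmax(W h + b),
    where h is a node's final hidden state h_i^T."""
    logits = [sum(w * hk for w, hk in zip(row, h)) + bias for row, bias in zip(W, b)]
    return softmax(logits)


W = [[1.0, 0.0], [0.0, 1.0]]  # 2 classes x 2 hidden dims (hand-picked, normally learned)
b = [0.0, 0.0]
H_T = [[2.0, 0.0], [0.0, 3.0]]  # final hidden states of two nodes
for i, h in enumerate(H_T):
    probs = node_readout(h, W, b)
    print(i, probs.index(max(probs)))  # node index, predicted class
# 0 0
# 1 1
```

For node-level regression, the softmax would simply be dropped and the linear output used directly.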

For graph-level predictions, a permutation-invariant readout function aggregates the hidden states of all nodes in the graph into a single prediction:

$$\hat{y} = R(\{\boldsymbol{h}_1^T, \boldsymbol{h}_2^T, \ldots, \boldsymbol{h}_{N_v}^T\})$$

An example of a readout function for graph-level classification/regression might be:

$$\hat{y} = \boldsymbol{w}^{\top} \left( \frac{1}{|\mathcal{V}|} \sum_{v \in \mathcal{V}} \boldsymbol{h}_v^T \right)$$

The final hidden states of all nodes are average-pooled and then projected with the weights $\boldsymbol{w}$, which are learned for the final classification or regression task. As in the message passing phase, we are free to replace the mean with any other permutation-invariant aggregation function, such as the sum or max.
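The graph-level readout can be sketched as follows. The pooling operation is swappable, reflecting the freedom to choose any permutation-invariant aggregation. The weights are hand-picked stand-ins for the learned $\boldsymbol{w}$, and `graph_readout` is a hypothetical name:

```python
def graph_readout(H_T, w, pool="mean"):
    """Graph-level prediction y_hat = w . pool({h_v^T}) for a permutation-invariant pool."""
    cols = list(zip(*H_T))  # one tuple per feature dimension, across all nodes
    if pool == "mean":
        pooled = [sum(c) / len(H_T) for c in cols]
    elif pool == "sum":
        pooled = [sum(c) for c in cols]
    elif pool == "max":
        pooled = [max(c) for c in cols]
    else:
        raise ValueError(f"unknown pooling: {pool}")
    return sum(wk * pk for wk, pk in zip(w, pooled))


H_T = [[1.0, 2.0], [3.0, 4.0]]  # final hidden states of a 2-node graph
w = [0.5, -1.0]                 # readout weights (hand-picked, normally learned)
print(graph_readout(H_T, w))        # -2.0
print(graph_readout(H_T[::-1], w))  # -2.0 (node order does not matter)
```

Reversing the node order gives the same prediction. This is exactly the permutation invariance a graph-level readout needs.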

Problems

While the message passing framework of GNNs is powerful, it comes with several challenges:

  • Oversmoothing: As the number of message passing steps grows, the hidden representations of nodes become increasingly similar and lose their distinctiveness, eventually making nodes indistinguishable.
  • Oversquashing: The number of nodes in a node's receptive field can grow exponentially with the number of steps, yet all of their information must be compressed into a fixed-size hidden vector. As a result, the influence of one node on another can decay rapidly with the length of the shortest path between them, akin to the vanishing gradient problem in traditional neural networks.
  • Under-reaching: Two nodes that are more than $T$ hops apart cannot exchange any information within $T$ message passing steps, leading to incomplete information transfer. This limits the model's ability to capture long-range dependencies in the graph.
  • Limited expressivity: Because GNNs rely primarily on local neighbourhood information, their ability to distinguish non-isomorphic graphs is limited. Standard message passing GNNs are at most as expressive as the 1-dimensional Weisfeiler-Leman (1-WL) graph isomorphism test. Nodes in different graphs may end up with identical representations if their local neighbourhoods are structurally similar, so some structurally different graphs cannot be told apart.
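Oversmoothing can be illustrated with a deliberately simplified model: repeated, parameter-free mean aggregation over each node and its neighbours. A real GNN would also have learned weights and nonlinearities, which this sketch omits. It tracks the gap between the largest and smallest node state on a three-node path graph. The helpers `mean_aggregate_step` and `spread` are hypothetical names:

```python
def mean_aggregate_step(A, H):
    """One step of parameter-free mean aggregation over each node and its neighbours."""
    N = len(A)
    H_next = []
    for i in range(N):
        group = [H[i]] + [H[j] for j in range(N) if A[i][j] == 1]  # include a self-loop
        H_next.append([sum(col) / len(group) for col in zip(*group)])
    return H_next


def spread(H):
    """Largest gap between any two nodes' (scalar) states."""
    values = [h[0] for h in H]
    return max(values) - min(values)


A = [[0, 1, 0], [1, 0, 1], [0, 1, 0]]  # path graph 0 - 1 - 2
H = [[1.0], [0.0], [0.0]]              # initially distinct node states
spreads = [spread(H)]
for _ in range(20):
    H = mean_aggregate_step(A, H)
    spreads.append(spread(H))

print([round(s, 4) for s in spreads[:4]])  # [1.0, 0.5, 0.25, 0.125]
print(spreads[-1] < 1e-5)                  # True
```

On this small graph, the gap halves at every step. After 20 steps the three node states are nearly identical, even though they started out clearly distinct.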

Conclusion

Graph Neural Networks and Geometric Deep Learning provide powerful frameworks for modelling complex, graph-structured data. By leveraging the message passing framework, these models can capture rich relational patterns and achieve strong performance across diverse domains where data naturally exhibits graph structure.

However, practitioners must carefully consider these inherent limitations and assess the applicability of GNNs when designing architectures for their specific use cases.