Graph Neural Network: In a Nutshell
Graph Neural Networks (GNNs) are a class of deep learning models designed to identify and interpret complex patterns and dependencies in data that is structured as a graph. Unlike traditional neural networks, which assume that data instances are independent and identically distributed (i.i.d.), GNNs explicitly exploit the connections between data points. This makes them particularly useful for relational data such as social networks, citation networks, and molecules.
General Mathematics of GNNs
A GNN operates on a graph \(G = (V, E)\), where \(V\) is the set of nodes and \(E\) is the set of edges connecting these nodes. Each node \(v \in V\) is associated with a feature vector \(X_v\) that encodes the attributes or properties of that node.
For instance, consider a social network with three users, User 1, User 2, and User 3, where each node represents a user and the edges represent connections or interactions between them, such as friendships or follows. The three user nodes make up the set \(V\), and the edges that connect them make up the set \(E\). The feature vector \(X_v\) for each user includes personal details such as name, age, country, and date of birth, along with other demographic or behavioral information. In a real-world GNN application, these features must be encoded numerically before they can be processed.
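As a minimal sketch, the three-user graph can be stored as an edge list plus a dictionary of numeric feature vectors. The ages, countries, the `COUNTRIES` list, and the `encode` helper below are all hypothetical; free-text fields such as the name are left out because they would need a separate encoding.

```python
# Hypothetical 3-user social graph G = (V, E); all attribute values are made up.
V = ["User 1", "User 2", "User 3"]
E = [("User 1", "User 2"), ("User 2", "User 3")]  # undirected friendships

COUNTRIES = ["LU", "IN", "US"]  # hypothetical category list for one-hot encoding

def encode(age: int, country: str) -> list[float]:
    """Turn raw user attributes into a numeric feature vector X_v."""
    one_hot = [1.0 if c == country else 0.0 for c in COUNTRIES]
    return [age / 100.0] + one_hot  # age scaled to roughly [0, 1]

X = {
    "User 1": encode(25, "LU"),
    "User 2": encode(40, "IN"),
    "User 3": encode(31, "US"),
}

# Neighborhood lists N(v), built once from the undirected edge list.
N = {v: [] for v in V}
for u, w in E:
    N[u].append(w)
    N[w].append(u)

print(N["User 2"])  # ['User 1', 'User 3']
print(X["User 1"])  # [0.25, 1.0, 0.0, 0.0]
```

An adjacency list like `N` is convenient for message passing because each node can look up its neighbors directly.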
The fundamental operation of a GNN is the message passing mechanism, which iteratively updates the representation of each node by aggregating information from its neighbors. Mathematically, the message passing procedure at each layer \(l\) of a GNN can be summarized in two key steps: aggregation and combination.
Aggregation:
For each node \(v\), aggregate the representations of its neighboring nodes \(N(v)\) to produce an aggregated message \(m_v^{(l)}\). This step is formalized as: \(m_v^{(l)} = \text{AGGREGATE}^{(l)}(\{h_u^{(l-1)} : u \in N(v)\})\) where \(h_u^{(l-1)}\) is the representation of node \(u\) after layer \(l-1\); at the input, it is the raw feature vector, \(h_u^{(0)} = X_u\).
Here are some commonly used aggregation functions:
- Mean Aggregation: Takes the element-wise average of the neighboring nodes’ features. GCN (Graph Convolutional Networks) uses a closely related degree-normalized average, and GraphSAGE provides a mean aggregator.
- Sum Aggregation: Adds up the neighboring nodes’ features, so the result also reflects how many neighbors a node has; it is used, for example, in GIN (Graph Isomorphism Network).
- Max Pooling: Takes the maximum value among the neighboring nodes’ features in each dimension, capturing the most significant features present in the neighborhood.
- Attention-based Aggregation: As used in GAT (Graph Attention Networks), it computes learned attention scores to weight each neighbor’s features dynamically.
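The four aggregators can be sketched on plain Python lists. This is illustrative only: the attention variant scores each neighbor with a plain dot product against the target node’s own vector, a hypothetical simplification of GAT, which instead learns its scoring function.

```python
import math

Vector = list[float]

def mean_agg(msgs: list[Vector]) -> Vector:
    """Element-wise average of the neighbor vectors."""
    return [sum(col) / len(msgs) for col in zip(*msgs)]

def sum_agg(msgs: list[Vector]) -> Vector:
    """Element-wise sum; the result grows with the number of neighbors."""
    return [sum(col) for col in zip(*msgs)]

def max_agg(msgs: list[Vector]) -> Vector:
    """Element-wise maximum, taken separately for each feature dimension."""
    return [max(col) for col in zip(*msgs)]

def attention_agg(h_v: Vector, msgs: list[Vector]) -> Vector:
    """Softmax-weighted sum of neighbor vectors.

    Simplified: the score of neighbor u is the dot product h_v . h_u,
    not GAT's learned attention function.
    """
    scores = [sum(a * b for a, b in zip(h_v, m)) for m in msgs]
    top = max(scores)  # subtract the max before exp() for numerical stability
    weights = [math.exp(s - top) for s in scores]
    total = sum(weights)
    alphas = [w / total for w in weights]  # attention weights, summing to 1
    return [sum(a * m[i] for a, m in zip(alphas, msgs)) for i in range(len(msgs[0]))]

neighbors = [[1.0, 0.0], [3.0, 2.0]]
print(mean_agg(neighbors))  # [2.0, 1.0]
print(sum_agg(neighbors))   # [4.0, 2.0]
print(max_agg(neighbors))   # [3.0, 2.0]
print(attention_agg([1.0, 1.0], neighbors))  # weighted toward [3.0, 2.0]
```

Note that the mean discards the neighborhood size while the sum keeps it; with two neighbors the sum is exactly twice the mean.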
Combination:
Update the feature vector of node \(v\) by combining its previous feature vector \(h_v^{(l-1)}\) with the aggregated message \(m_v^{(l)}\). The updated feature vector \(h_v^{(l)}\) for layer \(l\) is computed as: \(h_v^{(l)} = \text{COMBINE}^{(l)}(h_v^{(l-1)}, m_v^{(l)})\)
Common combination functions include:
- Linear Combination: Applies a linear transformation to the aggregated message, optionally together with the node’s own previous features.
- Non-linear Combination: Introduces non-linearity by applying an activation function, such as ReLU or ELU, after the linear combination.
- GRU/LSTM-based Combination: Uses Gated Recurrent Units or Long Short-Term Memory units to combine the aggregated message with the node’s current state. Their learned gates control how much of the previous state is kept at each propagation step (as in Gated Graph Neural Networks), and the same units can also incorporate temporal dynamics into the node embeddings.
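Putting aggregation and combination together, here is a minimal sketch of one message-passing layer with mean aggregation and a non-linear combination \(h_v^{(l)} = \text{ReLU}(W_{\text{self}} h_v^{(l-1)} + W_{\text{neigh}} m_v^{(l)})\), similar in spirit to GraphSAGE. The graph, the names `W_self` and `W_neigh`, and their hand-picked values are hypothetical; in a real model the weights are learned.

```python
Vector = list[float]
Matrix = list[list[float]]

def matvec(W: Matrix, x: Vector) -> Vector:
    """Matrix-vector product W x."""
    return [sum(w * xi for w, xi in zip(row, x)) for row in W]

def mean_agg(msgs: list[Vector]) -> Vector:
    """AGGREGATE: element-wise mean of the neighbor representations."""
    return [sum(col) / len(msgs) for col in zip(*msgs)]

def gnn_layer(h: dict[str, Vector], neighbors: dict[str, list[str]],
              W_self: Matrix, W_neigh: Matrix) -> dict[str, Vector]:
    """One layer: m_v = mean of h_u over N(v); h_v' = ReLU(W_self h_v + W_neigh m_v)."""
    new_h = {}
    for v, h_v in h.items():
        nbrs = neighbors[v]
        # Isolated nodes receive an all-zero message.
        m_v = mean_agg([h[u] for u in nbrs]) if nbrs else [0.0] * len(h_v)
        combined = [a + b for a, b in zip(matvec(W_self, h_v), matvec(W_neigh, m_v))]
        new_h[v] = [max(0.0, x) for x in combined]  # ReLU non-linearity
    return new_h

h0 = {"a": [1.0, 0.0], "b": [0.0, 1.0], "c": [1.0, 1.0]}
neighbors = {"a": ["b"], "b": ["a", "c"], "c": ["b"]}
W_self = [[1.0, 0.0], [0.0, 1.0]]    # identity: keep the node's own features
W_neigh = [[1.0, 0.0], [0.0, -1.0]]  # hypothetical mixing of the neighbor message
h1 = gnn_layer(h0, neighbors, W_self, W_neigh)
print(h1)  # {'a': [1.0, 0.0], 'b': [1.0, 0.5], 'c': [1.0, 0.0]}
```

Keeping separate weights for the node itself and for its neighbors lets the layer treat a node’s own state differently from the information it receives.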
After \(L\) layers of message passing, the final output \(h_v^{(L)}\) represents the learned embedding of node \(v\), incorporating information from its local neighborhood up to \(L\) hops away. The choice of aggregation and combination functions, as well as the depth \(L\), are crucial design considerations that determine the capacity of a GNN to capture and process graph-structured information.
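To see why \(L\) layers reach exactly \(L\) hops, the sketch below runs a parameter-free layer (sum aggregation plus the node’s own value, a hypothetical simplification with no weights or activation) on a four-node path graph in which only node 0 starts with a non-zero scalar feature.

```python
def propagate(h, neighbors):
    """One parameter-free layer: COMBINE(h_v, m_v) = h_v + m_v with sum AGGREGATE."""
    return {v: h[v] + sum(h[u] for u in neighbors[v]) for v in h}

# Path graph 0 - 1 - 2 - 3; only node 0 starts with a non-zero feature.
neighbors = {0: [1], 1: [0, 2], 2: [1, 3], 3: [2]}
h = {0: 1.0, 1: 0.0, 2: 0.0, 3: 0.0}

for layer in range(1, 4):
    h = propagate(h, neighbors)
    reached = sorted(v for v, x in h.items() if x != 0.0)
    print(f"after {layer} layer(s): {reached}")
# after 1 layer(s): [0, 1]
# after 2 layer(s): [0, 1, 2]
# after 3 layer(s): [0, 1, 2, 3]
```

Node 3 only "sees" node 0’s feature after three layers, which matches its distance of three hops.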
What Problems Do GNNs Solve?
Node Classification
In node classification, the objective is to predict the label of a node based on its features and its position within the graph. For instance, classifying research papers into subjects based on their content and citation patterns is a typical node classification task.
After the final layer \(L\) of message passing, we obtain node embeddings \(h_v^{(L)}\) that encode both the feature information and the graph structure. The task is to predict the label \(Y_v\) for each node \(v\). The prediction is modeled as: \(\hat{Y}_v = \text{softmax}(W \cdot h_v^{(L)} + b)\) where \(W\) and \(b\) are the parameters of a fully connected layer that maps the node embedding to the output label space, and \(\text{softmax}\) ensures that the output can be interpreted as probabilities over the classes.
The loss function commonly used for node classification is the cross-entropy between the predicted class probabilities \(\hat{Y}_v\) and the true labels \(Y_v\): \(\mathcal{L}_{\text{node}} = -\sum_{v \in V_L} \sum_{c=1}^{C} Y_{v,c} \log(\hat{Y}_{v,c})\) where \(V_L\) is the set of labeled nodes, \(C\) is the number of classes, \(Y_{v,c}\) is a binary indicator that equals 1 if class \(c\) is the correct class for node \(v\) and 0 otherwise, and \(\hat{Y}_{v,c}\) is the predicted probability that node \(v\) belongs to class \(c\).
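A minimal sketch of the prediction head and loss, assuming the final embeddings \(h_v^{(L)}\) are already computed. The 2-class weights, the embeddings, and the node names (`paper1` and so on) are hypothetical. Because \(Y_{v,c}\) is 1 only for the true class, the inner sum over classes reduces to a single log term per labeled node.

```python
import math

def softmax(z: list[float]) -> list[float]:
    """Numerically stable softmax."""
    top = max(z)
    e = [math.exp(zi - top) for zi in z]
    s = sum(e)
    return [ei / s for ei in e]

def predict_node(W: list[list[float]], b: list[float], h_v: list[float]) -> list[float]:
    """Y_hat_v = softmax(W h_v + b): one probability per class."""
    logits = [sum(w * x for w, x in zip(row, h_v)) + bi for row, bi in zip(W, b)]
    return softmax(logits)

def node_cross_entropy(probs: dict, labels: dict) -> float:
    """L_node summed over the labeled nodes V_L.

    labels maps each labeled node to its true class index; unlabeled
    nodes are simply absent and contribute nothing to the loss.
    """
    return -sum(math.log(probs[v][c]) for v, c in labels.items())

# Hypothetical 2-class head applied to final-layer embeddings h_v^(L).
W = [[1.0, 0.0], [0.0, 1.0]]
b = [0.0, 0.0]
h_L = {"paper1": [2.0, 0.0], "paper2": [0.0, 0.0], "paper3": [0.5, 1.5]}
probs = {v: predict_node(W, b, h) for v, h in h_L.items()}
labels = {"paper1": 0, "paper2": 1}  # paper3 is unlabeled
print(node_cross_entropy(probs, labels))
```

The unlabeled node still receives a prediction; it just does not enter the loss, which is the usual semi-supervised node classification setting.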
Edge or Link Prediction
Edge prediction involves predicting the likelihood of a connection between two nodes. This task is foundational in applications like recommending new connections in social networks or predicting protein-protein interactions in biological networks.
We consider pairs of nodes \((u, v)\) and aim to predict the presence of an edge between them. A common approach involves computing a score for a potential edge by considering the embeddings of the endpoint nodes: \(\hat{Y}_{uv} = \sigma(W \cdot \text{concat}(h_u^{(L)}, h_v^{(L)}) + b)\) where \(\sigma\) is the sigmoid function ensuring the output is a probability, \(\text{concat}\) denotes the concatenation of the node embeddings, and \(W\) and \(b\) are parameters.
The loss function for edge prediction is typically the binary cross-entropy loss: \(\mathcal{L}_{\text{edge}} = -\sum_{(u,v) \in E_L} \left[ Y_{uv} \log(\hat{Y}_{uv}) + (1 - Y_{uv}) \log(1 - \hat{Y}_{uv}) \right]\) where \(E_L\) denotes the set of node pairs with known edge presence or absence, \(Y_{uv}\) is a binary indicator of the presence of an edge between \(u\) and \(v\), and \(\hat{Y}_{uv}\) is the predicted probability of an edge between \(u\) and \(v\).
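The concatenation-based scorer and its binary cross-entropy loss can be sketched as follows. The embeddings, the weight vector `w`, and the node names are hypothetical; since the output is a single probability, \(W\) reduces to one row vector here.

```python
import math

def sigmoid(x: float) -> float:
    return 1.0 / (1.0 + math.exp(-x))

def edge_prob(w: list[float], b: float, h_u: list[float], h_v: list[float]) -> float:
    """Y_hat_uv = sigmoid(w . concat(h_u, h_v) + b)."""
    z = h_u + h_v  # list concatenation = concat(h_u, h_v)
    return sigmoid(sum(wi * zi for wi, zi in zip(w, z)) + b)

def edge_bce(pairs, h, w, b, eps=1e-12) -> float:
    """L_edge over labeled pairs (u, v, y), with y = 1 for an edge and 0 for a non-edge."""
    loss = 0.0
    for u, v, y in pairs:
        p = min(max(edge_prob(w, b, h[u], h[v]), eps), 1.0 - eps)  # clip to avoid log(0)
        loss -= y * math.log(p) + (1 - y) * math.log(1 - p)
    return loss

h = {"a": [1.0, 0.0], "b": [0.9, 0.1], "c": [-1.0, 0.0]}  # hypothetical embeddings h^(L)
pairs = [("a", "b", 1), ("a", "c", 0)]  # one observed edge, one known non-edge
w = [1.0, 0.0, 1.0, 0.0]  # hypothetical weights
print(edge_bce(pairs, h, w, 0.0))
```

One design caveat: concatenation is order-dependent, so the score for \((u, v)\) can differ from the score for \((v, u)\). For undirected graphs, a symmetric scorer such as \(\sigma(h_u \cdot h_v)\) is a common alternative.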
Graph-Level Tasks
For graph-level tasks, the goal is to infer properties or labels of entire graphs. An example is predicting the toxicity of a chemical compound based on its molecular structure represented as a graph.
The objective is to predict a property or label of the entire graph \(G\). This requires aggregating node embeddings \(h_v^{(L)}\) into a single graph representation, which can be achieved through a readout function: \(h_G = \text{READOUT}(\{h_v^{(L)} : v \in V\})\) Common choices for the \(\text{READOUT}\) function include sum, mean, max or attention pooling.
The prediction for the graph is then made as: \(\hat{Y}_G = W \cdot h_G + b\) where \(W\) and \(b\) are parameters of a fully connected layer tailored to the graph-level output space.
The loss function for graph-level tasks depends on the nature of the task. For classification, it is usually the cross-entropy loss, as in node classification, with a softmax applied to \(\hat{Y}_G\); for regression, mean squared error (MSE) or the Huber loss is widely used.
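A minimal sketch of a readout followed by a graph-level regression head and an MSE loss. The two-node embeddings, the weights, and the target value are hypothetical.

```python
def readout(node_embs: list[list[float]], mode: str = "mean") -> list[float]:
    """h_G = READOUT({h_v^(L)}): element-wise sum, mean, or max over all nodes."""
    cols = list(zip(*node_embs))
    if mode == "sum":
        return [sum(c) for c in cols]
    if mode == "mean":
        return [sum(c) / len(c) for c in cols]
    if mode == "max":
        return [max(c) for c in cols]
    raise ValueError(f"unknown readout: {mode}")

def predict_graph(W, b, h_G):
    """Y_hat_G = W h_G + b (raw outputs; apply a softmax for classification)."""
    return [sum(w * x for w, x in zip(row, h_G)) + bi for row, bi in zip(W, b)]

def mse(pred, target):
    """Mean squared error for graph-level regression."""
    return sum((p - t) ** 2 for p, t in zip(pred, target)) / len(pred)

# Hypothetical final-layer embeddings of a 2-node graph.
node_embs = [[1.0, 2.0], [3.0, 4.0]]
h_G = readout(node_embs, "mean")                 # [2.0, 3.0]
y_hat = predict_graph([[1.0, 1.0]], [0.0], h_G)  # [5.0]
print(mse(y_hat, [4.0]))                         # 1.0
```

Sum, mean, and max all ignore the order in which nodes are listed, which is why they are natural readout choices: a graph has no canonical node ordering.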
Why Is Self-Supervised Learning on GNNs Important?
Self-supervised learning (SSL) has emerged as a powerful paradigm for GNNs because it can learn from unlabeled graph data, which is often much easier to obtain than labels. This section focuses on two prominent approaches in this domain: Graph Autoencoders (GAEs) and contrastive learning.
Graph Autoencoders
A GAE aims to learn low-dimensional, informative representations of nodes or entire graphs by first encoding the graph into a latent space and then reconstructing graph properties from these embeddings. The process involves two primary components: an encoder and a decoder.
- Encoder: The encoder function maps the input graph data (nodes, edges, and possibly node features) into a latent representation space: \(Z = f_{\text{enc}}(X, A)\), where \(X\) is the node feature matrix, \(A\) is the adjacency matrix of the graph, and the row \(z_v\) of \(Z\) is the latent representation of node \(v\). In practice, \(f_{\text{enc}}\) is often implemented as a GNN, which captures the structural information of the graph.
- Decoder: The decoder function attempts to reconstruct properties of the graph (such as the adjacency matrix) from the latent representations: \(\hat{A} = f_{\text{dec}}(Z)\), where \(Z\) is the matrix of latent representations for all nodes. A common choice for \(f_{\text{dec}}\) is a simple dot product followed by a sigmoid, \(\hat{A} = \sigma(Z Z^\top)\), which estimates the probability of an edge between each pair of nodes.
The objective is to minimize the difference between the original graph structure \(A\) and the reconstructed graph \(\hat{A}\), typically with a binary cross-entropy loss over the entries of \(A\).
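A minimal sketch of the decoder side of a GAE: an inner-product decoder \(\hat{A} = \sigma(Z Z^\top)\) and a binary cross-entropy reconstruction loss averaged over all entries. The encoder is skipped; `Z_good` is a hypothetical stand-in for its output. Keeping 1s on the diagonal (self-loops) as positive targets is an assumption of this sketch, since \(z_v \cdot z_v\) is never negative.

```python
import math

def decode(Z: list[list[float]]) -> list[list[float]]:
    """Inner-product decoder: A_hat[i][j] = sigmoid(z_i . z_j)."""
    def sig(x: float) -> float:
        return 1.0 / (1.0 + math.exp(-x))
    return [[sig(sum(a * b for a, b in zip(zi, zj))) for zj in Z] for zi in Z]

def reconstruction_bce(A, A_hat, eps=1e-12) -> float:
    """Binary cross-entropy averaged over every entry of the adjacency matrix."""
    n = len(A)
    total = 0.0
    for i in range(n):
        for j in range(n):
            p = min(max(A_hat[i][j], eps), 1.0 - eps)  # clip to avoid log(0)
            total -= A[i][j] * math.log(p) + (1 - A[i][j]) * math.log(1 - p)
    return total / (n * n)

# Hypothetical 3-node graph: nodes 0 and 1 are linked, node 2 is isolated.
A = [[1, 1, 0],
     [1, 1, 0],
     [0, 0, 1]]
Z_good = [[2.0, 0.0], [2.0, 0.0], [-1.0, 2.0]]  # stand-in for f_enc(X, A)
Z_zero = [[0.0, 0.0]] * 3                        # uninformative embeddings
print(reconstruction_bce(A, decode(Z_good)))  # lower than with Z_zero
print(reconstruction_bce(A, decode(Z_zero)))  # log(2), about 0.693
```

Embeddings that place linked nodes close together (large dot product) and unlinked nodes apart (negative dot product) reconstruct \(A\) better and therefore get a lower loss.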
We recently published a paper, Enhancing Deep Neural Network Performance Prediction Using GAE Model, in the NeurIPS MLSys Workshop.
Contrastive Learning
Contrastive Learning in the context of GNNs involves learning representations by maximizing agreement between related (positive) pairs of nodes, edges, or subgraphs, while minimizing agreement between unrelated (negative) pairs. This can be formalized as follows:
- Positive Pairs: Given a graph \(G\), a positive pair \((u, v)_+\) consists of two nodes (or subgraphs) that are considered similar based on some criterion, often nodes that are neighbors or connected by paths.
- Negative Pairs: A negative pair \((u, v)_-\) involves nodes (or subgraphs) that are dissimilar, such as nodes far apart in the graph.
The contrastive loss function aims to discriminate positive pairs from negative ones. A commonly used loss is InfoNCE, which builds on Noise Contrastive Estimation (NCE). For an anchor node \(u\) with a positive partner \(v^+\) and a set of negatives \(\mathcal{N}_u\), it can be written as: \(\mathcal{L}_{\text{InfoNCE}} = -\log \frac{\exp(\text{sim}(z_u, z_{v^+}))}{\exp(\text{sim}(z_u, z_{v^+})) + \sum_{v^- \in \mathcal{N}_u} \exp(\text{sim}(z_u, z_{v^-}))}\) where \(\text{sim}(z_u, z_v)\) denotes a similarity measure between the embeddings of two nodes \(u\) and \(v\), such as the dot product or cosine similarity.
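A minimal sketch of this softmax-style contrastive loss for one anchor node, one positive partner, and a list of negatives, using cosine similarity. The embeddings are hypothetical, and in practice the loss is averaged over many anchors.

```python
import math

def cosine(a: list[float], b: list[float]) -> float:
    """Cosine similarity between two non-zero vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    return dot / (math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b)))

def info_nce(z_u, z_pos, z_negs, sim=cosine) -> float:
    """-log( exp(sim(u, v+)) / (exp(sim(u, v+)) + sum over v- of exp(sim(u, v-))) )."""
    pos = math.exp(sim(z_u, z_pos))
    neg = sum(math.exp(sim(z_u, z_n)) for z_n in z_negs)
    return -math.log(pos / (pos + neg))

# Hypothetical embeddings: the anchor is close to its positive, far from the negatives.
z_u = [1.0, 0.0]
z_pos = [0.9, 0.1]
z_negs = [[-1.0, 0.0], [0.0, 1.0]]
print(info_nce(z_u, z_pos, z_negs))                     # low loss
print(info_nce(z_u, [-1.0, 0.0], [z_pos, [0.0, 1.0]]))  # higher: the "positive" is far away
```

Minimizing this loss raises the anchor’s similarity to its positive relative to every negative, which is exactly the "pull together, push apart" behavior described below.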
The overall goal of contrastive learning in GNNs is to optimize the embeddings such that similar nodes are brought closer in the embedding space, while dissimilar nodes are pushed apart, leveraging the graph’s structure and any available feature information even in the absence of labels.
In summary, both GAEs and contrastive learning provide powerful frameworks for self-supervised learning on graphs, enabling GNNs to learn meaningful representations from unlabeled data by exploiting the inherent structure and features of graphs.
This blog post gives a basic overview of GNNs. In my upcoming blog posts, I will cover each topic in detail and include implementations.
Thanks for stopping by, and happy learning!
References
- Karthick Panner Selvam and Mats Brorsson. Can semi-supervised learning improve prediction of deep learning model resource consumption? In Machine Learning for Systems Workshop at the 37th NeurIPS Conference, 2023, New Orleans, LA, US.
- Zonghan Wu et al. “A Comprehensive Survey on Graph Neural Networks.” IEEE Transactions on Neural Networks and Learning Systems, https://doi.org/10.1109/TNNLS.2020.2978386.
- PyTorch Geometric documentation, https://pytorch-geometric.readthedocs.io/en/latest/