The Essential Guide to GNN (Graph Neural Networks)


Introduction to Graph Neural Networks

Graph neural networks (GNNs) are a set of deep learning methods that work in the graph domain. These networks have recently been applied in multiple areas, including combinatorial optimization, recommender systems, and computer vision, to mention just a few. They can also be used to model large systems such as social networks, protein-protein interaction networks, and knowledge graphs, among others. Unlike data such as images, which lie on a regular grid in Euclidean space, graph data is non-Euclidean. Common graph analysis tasks include node classification, link prediction, and clustering.


In this article, let’s explore Graph neural networks (GNNs) further. 

What is a Graph?

A graph is a data structure made up of nodes (also called vertices) and edges. The edges define the relationships between the nodes. If the edges have a direction, the graph is said to be directed; otherwise, it is undirected.


A great example of graphs in use is modeling the connections between people in a social network.
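The definition above can be made concrete with a plain adjacency list. The following is a minimal sketch in pure Python; the helper name `build_adjacency` and the small "who follows whom" network are hypothetical, chosen only for illustration. A directed graph stores each edge once, while an undirected graph stores it in both directions.

```python
def build_adjacency(edges, directed=False):
    """Return a dict mapping each node to the set of nodes it is connected to."""
    adj = {}
    for u, v in edges:
        adj.setdefault(u, set()).add(v)      # store the edge u -> v
        if directed:
            adj.setdefault(v, set())         # v is still a node, even with no outgoing edges
        else:
            adj.setdefault(v, set()).add(u)  # undirected: store v -> u as well
    return adj

# A tiny social network: (follower, followed)
follows = [("alice", "bob"), ("bob", "carol"), ("carol", "alice"), ("dave", "alice")]

directed = build_adjacency(follows, directed=True)
undirected = build_adjacency(follows)
print(directed["alice"])            # only the people alice follows
print(sorted(undirected["alice"]))  # everyone alice is connected to in either direction
```

The same social network yields different neighborhoods depending on whether the direction of each relationship is kept.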

GNN (Graph Neural Networks)

Graph Neural Networks are a special class of neural networks that are capable of working with data that is represented in graph form. These networks are strongly inspired by Convolutional Neural Networks (CNNs) and graph embedding. CNNs cannot handle graph data directly because the nodes in a graph have no natural order and because the dependency information between two nodes is represented by edges rather than by a fixed grid of neighboring pixels.

Graphs with NetworkX

Let’s take a minute and look at how one can create graphs using NetworkX. NetworkX is a Python package that can be used for creating graphs. Here is how you can use the package to create an empty graph with no nodes. 

import networkx as nx
G = nx.Graph()


You can then add some nodes to the graph using the `add_nodes_from` function.

G.add_nodes_from([2, 3])


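Next, add some edges to the graph using the `add_edges_from` function. Nodes that appear in an edge but are not yet in the graph (here 1, 4, 5, 6, 7, and 14) are added automatically.

edges = [(2,1),(2,2),(3,2),(4,3),(6,4),(7,5),(14,5)]
G.add_edges_from(edges)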


The graph can be visualized using Matplotlib. That is done by calling the NetworkX `draw` function and then displaying the figure with Matplotlib.

import matplotlib.pyplot as plt
nx.draw(G, with_labels=True, font_weight='bold')
plt.show()


How do Graph Neural Networks work?

The graph neural network (GNN) model was presented by Franco Scarselli et al. in 2009. In their paper, "The graph neural network model", they proposed extending existing neural networks to process data represented in graph form. The model could process graphs that are acyclic, cyclic, directed, and undirected. The objective of a GNN is to learn, for each node, a state embedding that encapsulates information about its neighborhood. This embedding is then used to produce the output, for example, a node label.
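The state-embedding idea can be sketched as a fixed-point iteration. This is a simplified illustration, not the paper's exact model: it assumes a transition function of the form tanh(X W_x + A H W_h), whereas the original uses a more general learned function, and the weights here are random rather than trained. Each node's state is repeatedly recomputed from its own features and its neighbors' states until the states stop changing. The function name `fixed_point_states` is hypothetical.

```python
import numpy as np

def fixed_point_states(adj, features, hidden_dim=4, tol=1e-6, max_iters=200, seed=0):
    """Iterate H <- tanh(X W_x + A H W_h) until the node states stop changing."""
    rng = np.random.default_rng(seed)
    w_x = rng.normal(scale=0.5, size=(features.shape[1], hidden_dim))  # random, untrained input weights
    w_h = rng.normal(size=(hidden_dim, hidden_dim))                    # random, untrained state weights
    # Rescale W_h so that ||A||_2 * ||W_h||_2 is at most 0.5: the update becomes a
    # contraction, which guarantees a unique fixed point that the iteration converges to
    w_h *= 0.5 / (np.linalg.norm(w_h, 2) * max(np.linalg.norm(adj, 2), 1e-12))
    h = np.zeros((features.shape[0], hidden_dim))
    for step in range(1, max_iters + 1):
        h_new = np.tanh(features @ w_x + adj @ h @ w_h)  # own features + neighbor states
        if np.max(np.abs(h_new - h)) < tol:
            return h_new, step
        h = h_new
    return h, max_iters

# 4-node path graph 0 - 1 - 2 - 3 with one-hot node features
adj = np.array([[0, 1, 0, 0],
                [1, 0, 1, 0],
                [0, 1, 0, 1],
                [0, 0, 1, 0]], dtype=float)
x = np.eye(4)
states, steps = fixed_point_states(adj, x)
print(steps, states.shape)
```

The contraction requirement is what makes this scheme slow and restrictive, which relates to the first limitation listed next.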

The original GNN proposal had a couple of limitations: 

  • Iteratively updating the hidden states of nodes until they reached a fixed point was inefficient
  • The GNNs used the same parameters in each iteration while other neural networks use different parameters in each layer
  • Modeling of informative features obtained from the edges was difficult 

Traditional Graph Analysis methods

Graphs can also be analyzed using traditional methods, which are usually classical graph algorithms.

The challenge with these methods is that they require prior knowledge, so they cannot be used for graph classification.

Types of Graph Neural Networks

There are several types of Graph Neural Networks. Let’s take a look at a couple of them.

Graph Convolutional Networks (GCNs)

Graph Convolutional Networks (GCNs) generalize the convolution operation of standard Convolutional Neural Networks to graphs. GCNs learn features by inspecting neighboring nodes. A GCN layer is usually made up of a graph convolution, a linear layer, and a non-linear activation: it aggregates the feature vectors in each node's neighborhood, passes the result through a dense layer, and finally applies a non-linearity.
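The aggregate, transform, activate pipeline can be sketched in NumPy. This minimal sketch assumes the symmetric normalization from Kipf and Welling's GCN formulation, ReLU(D^-1/2 (A + I) D^-1/2 H W); other GCN variants normalize differently. The function name `gcn_layer` is hypothetical.

```python
import numpy as np

def gcn_layer(adj, h, weight):
    """One graph convolution: ReLU(D^-1/2 (A + I) D^-1/2 H W)."""
    a_hat = adj + np.eye(adj.shape[0])        # add self-loops so a node keeps its own features
    deg = a_hat.sum(axis=1)                   # degree of each node, including the self-loop
    d_inv_sqrt = np.diag(1.0 / np.sqrt(deg))
    a_norm = d_inv_sqrt @ a_hat @ d_inv_sqrt  # symmetric normalization
    aggregated = a_norm @ h                   # 1) aggregate neighbor feature vectors
    transformed = aggregated @ weight         # 2) dense (linear) layer
    return np.maximum(transformed, 0.0)       # 3) ReLU non-linearity

# Path graph 0 - 1 - 2 with one-hot node features
adj = np.array([[0, 1, 0],
                [1, 0, 1],
                [0, 1, 0]], dtype=float)
h = np.eye(3)
out = gcn_layer(adj, h, weight=np.eye(3))
print(out.round(3))
```

With identity features and weights, the output is just the normalized adjacency matrix, which makes it easy to see how much each neighbor contributes to each node.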

GCNs differ from CNNs in that they are built to work with non-Euclidean structured data. There are two major types of GCNs:

  • Spatial Convolutional Networks: In these networks, the features of neighboring nodes are combined into a central node. The features are aggregated (for example, summed), similar to the normal convolution operation.
  • Spectral Convolutional Networks: In spectral networks, the convolution operation is defined in the Fourier domain by computing the eigendecomposition of the graph Laplacian.
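To make the spectral view concrete, here is a minimal NumPy sketch (an illustration, not any library's implementation). The eigenvectors of the Laplacian act as a Fourier basis, so a filter is applied by transforming the node signal, scaling each frequency component, and transforming back. Spectral GCNs learn the filter; here `filter_fn` is a hand-written, hypothetical stand-in.

```python
import numpy as np

def spectral_filter(adj, signal, filter_fn):
    """Filter a node signal in the graph Fourier domain defined by the Laplacian."""
    lap = np.diag(adj.sum(axis=1)) - adj           # combinatorial Laplacian L = D - A
    eigvals, eigvecs = np.linalg.eigh(lap)         # L = U diag(lambda) U^T (L is symmetric)
    x_hat = eigvecs.T @ signal                     # graph Fourier transform
    return eigvecs @ (filter_fn(eigvals) * x_hat)  # scale each frequency, then invert

# 4-node path graph and an arbitrary signal on its nodes
adj = np.array([[0, 1, 0, 0],
                [1, 0, 1, 0],
                [0, 1, 0, 1],
                [0, 0, 1, 0]], dtype=float)
x = np.array([1.0, 2.0, 3.0, 4.0])

all_pass = spectral_filter(adj, x, lambda lam: np.ones_like(lam))
low_pass = spectral_filter(adj, x, lambda lam: (lam < 1e-9).astype(float))
print(all_pass)  # recovers x (up to floating-point error)
print(low_pass)  # keeps only the constant component, so every entry equals the mean of x
```

A full eigendecomposition is expensive for large graphs, which is one reason later spectral methods approximate it instead of computing it exactly.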

Graph Auto-Encoder Networks

Graph Auto-Encoder Networks are made up of an encoder and a decoder joined by a bottleneck layer. In a conventional autoencoder for images, the encoder extracts features by passing the input through convolutional filters, and the decoder attempts to reconstruct the input from those features. Autoencoder models are known to cope well with the extreme class imbalance that is common in link prediction problems. Graph Auto-Encoder Networks therefore try to learn graph representations with the encoder and then rebuild the graph using the decoder.
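The encode-then-reconstruct idea can be sketched in NumPy. This assumes the design of Kipf and Welling's graph auto-encoder: a GCN-style encoder produces node embeddings Z, and an inner-product decoder sigmoid(Z Z^T) predicts the probability of an edge between every pair of nodes. The weights here are random; in practice they are trained to minimize the reconstruction loss. Function names are hypothetical.

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def encode(adj, x, weight):
    """One-layer GCN-style encoder: Z = D^-1/2 (A + I) D^-1/2 X W."""
    a_hat = adj + np.eye(adj.shape[0])
    d_inv_sqrt = np.diag(1.0 / np.sqrt(a_hat.sum(axis=1)))
    return d_inv_sqrt @ a_hat @ d_inv_sqrt @ x @ weight

def decode(z):
    """Inner-product decoder: predicted edge probability for every node pair."""
    return sigmoid(z @ z.T)

def reconstruction_loss(adj, adj_rec, eps=1e-9):
    """Binary cross-entropy between the true and reconstructed adjacency matrices."""
    return -np.mean(adj * np.log(adj_rec + eps) + (1 - adj) * np.log(1 - adj_rec + eps))

rng = np.random.default_rng(0)
adj = np.array([[0, 1, 1, 0],
                [1, 0, 1, 0],
                [1, 1, 0, 1],
                [0, 0, 1, 0]], dtype=float)
x = np.eye(4)                # one-hot node features
w = rng.normal(size=(4, 2))  # random (untrained) weights giving 2-D embeddings
z = encode(adj, x, w)        # the bottleneck: a 2-D vector per node
adj_rec = decode(z)
print(adj_rec.round(2))
print(reconstruction_loss(adj, adj_rec))
```

For link prediction, the reconstructed entries for node pairs that are not yet connected can be read as predicted link probabilities.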

Recurrent Graph Neural Networks

Graph recurrent neural networks (GRNNs) utilize multi-relational graphs and use graph-based regularizers to boost smoothness and mitigate over-parametrization. Since the exact size of the neighborhood is not always known, a recurrent GNN layer is used to make the network more flexible. A GRNN can learn the diffusion pattern that best fits the data, and it can handle situations where a node is involved in multiple relations. The network is also computationally inexpensive because the number of operations scales linearly with the number of graph edges.
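The linear-in-edges cost comes from aggregating messages by walking the edge list instead of multiplying dense matrices. The sketch below is a simplified recurrent propagation, assuming one shared weight matrix that is reused at every step (which is what makes it recurrent); it is not the specific GRNN architecture described above, and `recurrent_propagate` is a hypothetical name.

```python
import numpy as np

def recurrent_propagate(edges, x, weight, steps=3):
    """Repeat h <- tanh(x + (sum of incoming neighbor states) @ weight), sharing `weight` across steps."""
    src = np.array([s for s, _ in edges])
    dst = np.array([d for _, d in edges])
    h = x.copy()
    for _ in range(steps):
        agg = np.zeros_like(h)
        np.add.at(agg, dst, h[src])    # one pass over the edge list: O(|E| * d) work per step
        h = np.tanh(x + agg @ weight)  # the same weight matrix is used at every step
    return h

# Undirected path 0 - 1 - 2, stored as two directed edges per link
edges = [(0, 1), (1, 0), (1, 2), (2, 1)]
x = np.eye(3)
h = recurrent_propagate(edges, x, weight=0.5 * np.eye(3))
print(h.round(3))
```

`np.add.at` accumulates correctly even when several edges point to the same node, which a plain fancy-indexed `+=` would not.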


Gated Graph Neural Network (GGNN)

Gated Graph Neural Networks (GGNNs) perform better than Recurrent Graph Neural Networks on problems with long-term dependencies. These long-term dependencies are encoded by node and edge gates, while long-term temporal dependencies are encoded by time gates. In other words, Gated Graph Neural Networks improve on Recurrent Graph Neural Networks by adding gating mechanisms, which are responsible for remembering and forgetting information across states.
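The remember/forget behavior can be sketched with GRU-style gates, which is how the GGNN of Li et al. updates node states. This minimal NumPy sketch uses random, untrained weights and hypothetical parameter names; it does not model the separate node, edge, and time gates mentioned above.

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def gated_update(adj, h, params):
    """One GGNN-style step: aggregate neighbor messages, then update states with GRU gates."""
    m = adj @ h @ params["w_msg"]                                   # message from neighbors
    z = sigmoid(m @ params["w_z"] + h @ params["u_z"])              # update gate: how much to change
    r = sigmoid(m @ params["w_r"] + h @ params["u_r"])              # reset gate: how much old state to use
    h_tilde = np.tanh(m @ params["w_h"] + (r * h) @ params["u_h"])  # candidate new state
    return (1 - z) * h + z * h_tilde                                # forget old info vs. write new info

rng = np.random.default_rng(0)
d = 3
params = {name: rng.normal(scale=0.5, size=(d, d))
          for name in ["w_msg", "w_z", "u_z", "w_r", "u_r", "w_h", "u_h"]}
adj = np.array([[0, 1, 0],
                [1, 0, 1],
                [0, 1, 0]], dtype=float)
h = np.tanh(rng.normal(size=(3, d)))  # initial node states in (-1, 1)
for _ in range(5):                    # the same gated update is unrolled for several steps
    h = gated_update(adj, h, params)
print(h.round(3))
```

Because each new state is a gate-weighted blend of the old state and the candidate, a node can keep information over many steps when its update gate stays close to zero.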

List of GNN Applications

Let’s now take a moment to look at what GNNs can do:

  • Node classification: The objective here is to predict the labels of nodes by considering the labels of their neighbors. 
  • Link prediction: In this case, the goal is to predict the relationships between various entities in a graph. This can, for example, be applied to predicting connections in social networks.
  • Graph clustering: This involves dividing the nodes of a graph into clusters. The partitioning can be done based on edge weights or edge distances or by considering the graphs as objects and grouping similar objects together. 
  • Graph classification: This entails classifying a graph into a category. This can be applied in social network analysis and categorizing documents in natural language processing. Other applications in NLP include text classification, extracting semantic relationships between texts, and sequence labeling. 
  • Computer vision: In the computer vision world, GNNs can be used to generate regions of interest for object detection. They can also be used for image classification and for scene graph generation, where a model identifies the objects in an image and the semantic relationships between them. Other applications in this field include interaction detection and region classification.
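The intuition behind node classification, using neighbors' labels to infer a node's own label, can be shown with a simple non-learned baseline. This hypothetical helper is for illustration only and is not a GNN: a GNN learns how to combine neighbor features instead of taking a fixed majority vote.

```python
from collections import Counter

def predict_from_neighbors(adj, labels, node):
    """Predict a node's label as the most common label among its labeled neighbors."""
    neighbor_labels = [labels[n] for n in adj[node] if n in labels]
    if not neighbor_labels:
        return None  # no labeled neighbors to infer from
    return Counter(neighbor_labels).most_common(1)[0][0]

# Undirected graph as an adjacency list; only some nodes are labeled
adj = {
    "a": ["b", "c", "d"],
    "b": ["a", "c"],
    "c": ["a", "b"],
    "d": ["a", "e"],
    "e": ["d"],
}
labels = {"b": "sports", "c": "sports", "d": "music"}
print(predict_from_neighbors(adj, labels, "a"))  # sports (2 of 3 labeled neighbors)
print(predict_from_neighbors(adj, labels, "e"))  # music
```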

Problems Associated with GNNs

Graph Neural Networks are powerful networks. However, there are a couple of known problems associated with them:

  • Shallow in nature: Other deep neural networks, such as CNNs, can stack many layers to obtain better performance. Unfortunately, GNNs are usually shallow, with most having no more than three layers. The creation of deep GNNs is still an active research area.
  • Dynamic Graphs: Dynamic graphs have a structure that keeps changing hence making them hard to model. Dynamic GNN is also an active research area. 
  • Lack of standard graph generation methods: There is no standard way of generating graphs. In some applications, fully connected graphs are used while in others algorithms detect graph nodes. 
  • Scalability: Applying GNNs in applications such as recommender systems and social networks at scale is a challenge. The main hurdle here is that these methods are computationally expensive. 

Example: Graph Neural Networks with PyTorch

PyTorch can be coupled with DGL to build Graph Neural Networks for node classification. Deep Graph Library (DGL) is a Python package that can be used to implement GNNs with PyTorch and TensorFlow. The official docs provide the following example on how to get started.

Let’s take a look at a PyTorch example. The first step is to import the packages and load the data. 

import dgl
import torch
import torch.nn as nn
import torch.nn.functional as F
dataset =
g = dataset[0]


The example shows how to build a GNN for semi-supervised node classification on the Cora dataset. The next step is to define the Graph Convolutional Network that will compute node representations using neighborhood information. This is done using `dgl.nn.GraphConv`.

from dgl.nn import GraphConv
class GCN(nn.Module):
    def __init__(self, in_feats, h_feats, num_classes):
        super(GCN, self).__init__()
        self.conv1 = GraphConv(in_feats, h_feats)
        self.conv2 = GraphConv(h_feats, num_classes)

    def forward(self, g, in_feat):
        h = self.conv1(g, in_feat)
        h = F.relu(h)
        h = self.conv2(g, h)
        return h
# Create the model with given dimensions
model = GCN(g.ndata['feat'].shape[1], 16, dataset.num_classes)


The next step is to train the neural network. The training loop is the same as a standard PyTorch training loop, except that the loss is computed only on the nodes in the training set.

def train(g, model):
    optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
    best_val_acc = 0
    best_test_acc = 0

    features = g.ndata['feat']
    labels = g.ndata['label']
    train_mask = g.ndata['train_mask']
    val_mask = g.ndata['val_mask']
    test_mask = g.ndata['test_mask']
    for e in range(100):
        # Forward
        logits = model(g, features)

        # Compute prediction
        pred = logits.argmax(1)

        # Compute loss
        # Note that you should only compute the losses of the nodes in the training set.
        loss = F.cross_entropy(logits[train_mask], labels[train_mask])

        # Compute accuracy on training/validation/test
        train_acc = (pred[train_mask] == labels[train_mask]).float().mean()
        val_acc = (pred[val_mask] == labels[val_mask]).float().mean()
        test_acc = (pred[test_mask] == labels[test_mask]).float().mean()

        # Save the best validation accuracy and the corresponding test accuracy.
        if best_val_acc < val_acc:
            best_val_acc = val_acc
            best_test_acc = test_acc

        # Backward
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if e % 5 == 0:
            print('In epoch {}, loss: {:.3f}, val acc: {:.3f} (best {:.3f}), test acc: {:.3f} (best {:.3f})'.format(
                e, loss, val_acc, best_val_acc, test_acc, best_test_acc))
model = GCN(g.ndata['feat'].shape[1], 16, dataset.num_classes)
train(g, model)


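You can also use the Deep Graph Library with TensorFlow. That requires setting the `DGLBACKEND` environment variable before DGL is imported. A shell command such as `!export` runs in a separate subshell on Google Colab and does not affect the notebook, so set the variable from Python instead:

import os
os.environ["DGLBACKEND"] = "tensorflow"  # must be set before `import dgl`
import dgl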


Implementing Graph Neural Networks in TensorFlow and Keras

The Keras Graph Convolutional Neural Network Python package (kgcnn), introduced in an accompanying paper, is built on TensorFlow and Keras and provides Keras layers for Graph Neural Networks. The official page provides numerous examples of how to use the package, including one that uses kgcnn for node classification on the Cora dataset. Let's take a look at a snippet of this illustration.

Training a GNN for node classification

The first step is usually to load the required packages. 

import time
import numpy as np
import matplotlib.pyplot as plt
from import cora_graph
from kgcnn.literature.GCN import make_gcn
from kgcnn.utils.adj import precompute_adjacency_scaled, convert_scaled_adjacency_to_list, make_adjacency_undirected_logical_or
from import ragged_tensor_from_nested_numpy
from kgcnn.utils.learning import lr_lin_reduction


The next step is to load the data and convert the node features into a dense matrix.

# Download and load Dataset
A_data, X_data, y_data = cora_graph()
# Make node features dense
nodes = X_data.todense()


The next step is to precompute the scaled, undirected adjacency matrix and convert it into an edge index list plus edge weights. Finally, `np.expand_dims` adds a trailing axis to the edge weights so that each edge carries a feature vector of length 1.

# Precompute scaled and undirected (symmetric) adjacency matrix
A_scaled = precompute_adjacency_scaled(make_adjacency_undirected_logical_or(A_data))
# Use edge_indices and weights instead of adj_matrix
edge_index, edge_weight = convert_scaled_adjacency_to_list(A_scaled)
edge_weight = np.expand_dims(edge_weight, axis=-1)


Next, one-hot encode the labels.

# Change labels to one-hot-encoding
labels = np.expand_dims(y_data, axis=-1)
labels = np.array(labels == np.arange(70), dtype=float)  # 70 classes; np.float was removed in NumPy 1.24


The model can be defined using the `make_gcn` function. The function expects the shape of the input node features, the shape of the input edge features, and the model depth, among other arguments.

model = make_gcn(
    input_node_shape=[None, 8710],
    input_edge_shape=[None, 1],
    # Output
    output_embedd={"output_mode": 'node'},
    output_mlp={"use_bias": [True, True, False], "units": [140, 70, 70], "activation": ['relu', 'relu', 'softmax']},
    # model specs
    gcn_args={"units": 140, "use_bias": True, "activation": "relu", "has_unconnected": True}
)


You can print a summary of the model with `model.summary()`.

The next step is to train this model. In the accompanying Google Colab notebook, training ran for 300 epochs.

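# Training loop
# Assumes the model was compiled and that xtrain, ytrain, train_mask, val_mask,
# epo and epostep were defined earlier in the full kgcnn example (not shown here)
trainlossall = []
testlossall = []
start = time.process_time()
for iepoch in range(0, epo, epostep):
    hist =, ytrain,
                     epochs=iepoch + epostep,
                     initial_epoch=iepoch,  # continue from the last epoch instead of restarting at 0
                     sample_weight=train_mask  # only training nodes contribute to the loss
                     )
    trainlossall.extend(hist.history['loss'])
    testlossall.append(model.evaluate(xtrain, ytrain, sample_weight=val_mask))
stop = time.process_time()
testlossall = np.array(testlossall)  # convert to an array so it can be indexed as testlossall[:, 1]
print("Time for training: ", stop - start)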


You can then check the training and test loss by plotting them using Matplotlib. 

plt.plot(np.arange(1, len(trainlossall) + 1), trainlossall, label='Training Loss', c='blue')
plt.plot(np.arange(epostep, epo + epostep, epostep), testlossall[:, 1], label='Test Loss', c='red')
plt.legend(loc='lower right', fontsize='x-large')
plt.show()

This is one of the numerous `kgcnn` examples available in the official repository.

Other Graph Neural Network libraries

Let’s briefly mention other Graph Neural Network libraries:

Final thoughts

In this article, you have learned about Graph Neural Networks. You have seen that they can be used to solve problems that involve graph-structured data. In a nutshell, you have covered:

  • What is a Graph?
  • GNN (Graph Neural Networks)
  • What GNN can Do
  • Applications of GNNs
  • How Graph Neural Networks work
  • Types of Graph Neural Networks
  • Graph Neural Networks with PyTorch
  • Implementing Graph Neural networks with TensorFlow and Keras
  • Training a GNN for Node Classification

Must-read papers on GNN

How powerful are Graph Neural Networks?

Semi-supervised Classification with Graph Convolutional Networks

Simplifying Graph Convolutional Networks

Gated Graph Recurrent Neural Networks


Graph Neural Networks: A Review of Methods and Applications

Notebook used in this article 

A Comprehensive Survey on Graph Neural Networks

Graph LSTMs

How To Visualize Large-Scale Data by Learning a Graph Neural Network Representation
