In 2021 and 2022, when Amazon Science asked members of the program committees of the Knowledge Discovery and Data Mining Conference (KDD) to discuss the state of their field, the conversations revolved around graph neural networks.
Graph learning remains the most popular topic at KDD 2023, but as Yizhou Sun, an associate professor of computer science at the University of California, Los Angeles; an Amazon Scholar; and the conference’s general chair, explains, that doesn’t mean that the field has stood still.
Graph neural networks (GNNs) are machine learning models that produce embeddings, or vector representations, of graph nodes that capture information about the nodes’ relationships to other nodes. They can be used for graph-related tasks, such as predicting edges or labeling nodes, but they can also be used for arbitrary downstream processing tasks, which simply take advantage of the information encoded in graph structure.
But within that general definition, “the implication of ‘graph neural network’ could be very different,” Sun says. “‘Graph neural network’ is a very broad term.”
For instance, Sun explains, traditional GNNs use message passing to produce embeddings. Each node in the graph is embedded, and then each node receives the embeddings of its neighboring nodes (the passed messages), which it integrates into an updated embedding. Typically, this process is performed two or three times, so that the embedding of each node captures information about its two- or three-hop neighborhood.
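To make the message-passing description concrete, here is a minimal Python sketch under simplifying assumptions: embeddings are plain lists of floats, and each round simply averages a node's embedding with its neighbors' embeddings. Real GNN layers also apply learned weights and nonlinearities, which are omitted here, and the function names are illustrative rather than taken from any library. On a four-node path, a signal placed at one end does not reach the other end until the third round, which is why k rounds correspond to a k-hop neighborhood.

```python
from typing import Dict, List

Graph = Dict[int, List[int]]  # adjacency list: node -> neighbors
Embeddings = Dict[int, List[float]]


def message_passing_round(graph: Graph, emb: Embeddings) -> Embeddings:
    """One round: each node averages its own embedding with the messages from its neighbors."""
    updated = {}
    for node, nbrs in graph.items():
        messages = [emb[u] for u in nbrs] + [emb[node]]  # neighbors' embeddings plus a self-loop
        dim = len(emb[node])
        updated[node] = [sum(m[d] for m in messages) / len(messages) for d in range(dim)]
    return updated


def run_message_passing(graph: Graph, emb: Embeddings, rounds: int) -> Embeddings:
    """Apply `rounds` rounds of message passing; after k rounds, each node sees its k-hop neighborhood."""
    for _ in range(rounds):
        emb = message_passing_round(graph, emb)
    return emb


path = {0: [1], 1: [0, 2], 2: [1, 3], 3: [2]}  # 0 - 1 - 2 - 3
init = {0: [0.0], 1: [0.0], 2: [0.0], 3: [1.0]}  # only node 3 carries a signal
print(run_message_passing(path, init, 2)[0])  # [0.0]: node 3 is three hops from node 0
print(run_message_passing(path, init, 3)[0])  # about [0.0556], i.e. 1/18: the signal arrives in round 3
```

Averaging (rather than summing) keeps embeddings on a stable scale as rounds accumulate; it is one of several common aggregation choices.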
“If I do message passing, I can only collect information from my immediate neighbors,” Sun explains. “I need to go through many, many layers to model long-range dependencies. For some specific applications, like software analysis or simulation of physical systems, long-range dependency becomes critical.
“So people asked how we can change this architecture. They were inspired by the transformer” — the attention-based neural architecture that underlies today’s large language models — “because the transformer can be considered a special case of a graph neural network, where in the input window, every token can be connected to every other token.
“If every node can communicate with every node in the graph, you can easily address this long-range-dependency issue. But there will be two limitations. One is efficiency. For some graphs, there are many millions or even billions of nodes. You cannot efficiently talk to everyone else in the graph.”
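The transformer-as-graph analogy can be sketched as attention over a complete graph. In the minimal Python sketch below (an illustration with hypothetical names and no learned query, key, or value projections), `neighbors=None` lets every node attend to every other node, while passing an adjacency list restricts attention to graph edges. The `num_interactions` helper counts the attention scores computed per layer, which grows as n² in the complete-graph case; that quadratic growth is the efficiency limitation Sun describes.

```python
import math
from typing import Dict, List, Optional

Graph = Dict[int, List[int]]  # adjacency list: node -> neighbors


def attention_layer(emb: List[List[float]], neighbors: Optional[Graph] = None) -> List[List[float]]:
    """Scaled dot-product attention with identity projections (no learned weights).

    With neighbors=None every node attends to every node (a complete graph, as in a
    transformer); otherwise node i attends only to itself and its graph neighbors.
    """
    n, dim = len(emb), len(emb[0])
    out = []
    for i in range(n):
        candidates = list(range(n)) if neighbors is None else neighbors[i] + [i]
        scores = [sum(a * b for a, b in zip(emb[i], emb[j])) / math.sqrt(dim) for j in candidates]
        top = max(scores)  # subtract the max for a numerically stable softmax
        weights = [math.exp(s - top) for s in scores]
        total = sum(weights)
        out.append([sum(w * emb[j][d] for w, j in zip(weights, candidates)) / total for d in range(dim)])
    return out


def num_interactions(n: int, neighbors: Optional[Graph] = None) -> int:
    """Number of attention scores computed in one layer."""
    if neighbors is None:
        return n * n
    return sum(len(nbrs) + 1 for nbrs in neighbors.values())


path = {0: [1], 1: [0, 2], 2: [1, 3], 3: [2]}
emb = [[0.0], [0.0], [0.0], [1.0]]
print(attention_layer(emb)[0])        # [0.25]: node 0 already sees node 3
print(attention_layer(emb, path)[0])  # [0.0]: node 0 sees only nodes 0 and 1
print(num_interactions(1_000_000))    # 10**12 scores for a million-node graph
```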
The second concern, Sun explains, is that too much long-range connectivity undermines the very point of graphical representation. Graphs are useful because they capture meaningful relationships between nodes — which means leaving out the meaningless ones. If every node in the graph communicates with every other node, the meaningful connections are diluted.
To combat this problem, “people try to find a way to mimic the position encoding in the text setting or the image setting,” Sun says. “In the text setting, we just turned the position into some encoding. And later, in the computer vision domain, people said, ‘Okay, let's also do that with image patches.’ So, for example, we can break each image into six-by-six patches, and the relative position of those patches can be turned into a position encoding.
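The sinusoidal encoding from the original transformer paper is one standard way to turn a text position into an encoding. The Python sketch below implements it for a single token position and, as an assumed extension for image patches, concatenates a row encoding and a column encoding for each patch in a six-by-six grid. This 2D scheme is one common choice, not the only one (the original Vision Transformer, for instance, learned its position embeddings), and the function names are illustrative.

```python
import math
from typing import List


def sinusoidal_encoding(pos: int, dim: int) -> List[float]:
    """Transformer-style encoding: sin/cos pairs at geometrically spaced frequencies."""
    enc = []
    for i in range(0, dim, 2):
        freq = 1.0 / (10000 ** (i / dim))  # i plays the role of 2k in sin(pos / 10000^(2k/dim))
        enc.append(math.sin(pos * freq))
        enc.append(math.cos(pos * freq))
    return enc[:dim]  # trim the extra cos term when dim is odd


def patch_encoding(row: int, col: int, dim: int) -> List[float]:
    """2D encoding for an image patch: half the dimensions for the row, the rest for the column."""
    half = dim // 2
    return sinusoidal_encoding(row, half) + sinusoidal_encoding(col, dim - half)


# Encode every patch of a six-by-six grid with 8-dimensional vectors.
grid = {(r, c): patch_encoding(r, c, 8) for r in range(6) for c in range(6)}
print(sinusoidal_encoding(0, 4))  # [0.0, 1.0, 0.0, 1.0]
print(len(grid))                  # 36 patches, each with its own encoding
```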
“So the next question is, in the graph setting, how can we get that natural kind of relative position? There are different ways to do that, like random walks, a very simple one. People also try eigendecomposition, where we use eigenvectors to encode the relative positions of those nodes. But eigendecomposition is very time consuming, so again, it comes down to the efficiency problem.”
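Random-walk positional encoding can be sketched in a few lines. The version below follows one common formulation, often called random-walk positional encoding (RWPE): node v's k-th feature is the probability that a random walk starting at v is back at v after k steps. The function names are illustrative, and the sketch uses dense matrices for readability, so it is itself cubic in the number of nodes per step; it illustrates the idea, not an efficient implementation, which would rely on sparse operations.

```python
from typing import Dict, List

Graph = Dict[int, List[int]]  # adjacency list: node -> neighbors
Matrix = List[List[float]]


def matmul(a: Matrix, b: Matrix) -> Matrix:
    """Dense matrix product."""
    n, m, p = len(a), len(b), len(b[0])
    return [[sum(a[i][k] * b[k][j] for k in range(m)) for j in range(p)] for i in range(n)]


def random_walk_pe(graph: Graph, k_steps: int) -> Dict[int, List[float]]:
    """For each node v, return [P^1[v][v], ..., P^k[v][v]], the k-step return probabilities."""
    nodes = sorted(graph)
    idx = {v: i for i, v in enumerate(nodes)}
    n = len(nodes)
    # Row-stochastic transition matrix: from v, step to each neighbor with probability 1/deg(v).
    trans = [[0.0] * n for _ in range(n)]
    for v in nodes:
        for u in graph[v]:
            trans[idx[v]][idx[u]] = 1.0 / len(graph[v])
    pe: Dict[int, List[float]] = {v: [] for v in nodes}
    power = trans  # P^1
    for _ in range(k_steps):
        for v in nodes:
            pe[v].append(power[idx[v]][idx[v]])
        power = matmul(power, trans)  # advance to the next power of P
    return pe


path = {0: [1], 1: [0, 2], 2: [1]}
print(random_walk_pe(path, 2))  # {0: [0.0, 0.5], 1: [0.0, 1.0], 2: [0.0, 0.5]}
```

The two endpoints of the path receive identical encodings, as expected from the graph's symmetry, while the middle node's encoding differs.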
Efficiency
Indeed, Sun explains, improving the efficiency of GNNs is itself an active area of research — from high-level algorithmic design down to the level of chip design.
“At the algorithm level, you might try to do some sort of sampling technique, just try to make the number of operations smaller,” she says. “Or maybe just design some more efficient algorithms to sparsify the graphs. For example, let's say we wanted to do some sort of similarity search, to keep the most similar nodes to each target node. Then people can design some smart index technology to make that part very fast.
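Both algorithm-level ideas can be sketched in plain Python. `sample_neighbors` caps the number of neighbors a node aggregates per layer (a fixed fanout, in the spirit of neighbor-sampling methods such as GraphSAGE), and `topk_sparsify` keeps only each node's k most similar neighbors by cosine similarity. All names are illustrative, and the similarity step here is brute force; the index technology Sun mentions would replace that scoring loop with, for example, an approximate-nearest-neighbor index, which this sketch does not implement.

```python
import math
import random
from typing import Dict, List

Graph = Dict[int, List[int]]  # adjacency list: node -> neighbors


def sample_neighbors(graph: Graph, node: int, fanout: int, rng: random.Random) -> List[int]:
    """Return at most `fanout` neighbors, sampled uniformly without replacement."""
    nbrs = graph[node]
    if len(nbrs) <= fanout:
        return list(nbrs)
    return rng.sample(nbrs, fanout)


def cosine(a: List[float], b: List[float]) -> float:
    dot = sum(x * y for x, y in zip(a, b))
    norm = math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b))
    return dot / norm if norm else 0.0


def topk_sparsify(graph: Graph, feats: Dict[int, List[float]], k: int) -> Graph:
    """Keep, for each node, only its k neighbors with the most similar features (brute force)."""
    return {
        v: sorted(nbrs, key=lambda u: cosine(feats[v], feats[u]), reverse=True)[:k]
        for v, nbrs in graph.items()
    }


star = {0: list(range(1, 101)), **{i: [0] for i in range(1, 101)}}
print(len(sample_neighbors(star, 0, fanout=10, rng=random.Random(0))))  # 10 instead of 100

g = {0: [1, 2, 3], 1: [0], 2: [0], 3: [0]}
feats = {0: [1.0, 0.0], 1: [1.0, 0.1], 2: [0.0, 1.0], 3: [0.9, 0.0]}
print(topk_sparsify(g, feats, k=2)[0])  # [3, 1]: node 2 is orthogonal to node 0, so it is dropped
```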
“And in the inference stage, we can do knowledge distillation to distill a very complicated model, let's say a graph neural network, into a very simple graph neural network — or not necessarily a graph neural network, maybe just a very simple kind of structure, like an MLP [multilayer perceptron]. Then we can do the calculation much faster. Quantization can also be applied in the inference stage to make computation much faster.
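The two inference-stage techniques can be illustrated with a minimal sketch. `distillation_loss` is the standard soft-target objective (KL divergence between temperature-softened teacher and student predictions, scaled by T²); in the setting Sun describes, the teacher would be a trained GNN and the student a simpler model such as an MLP, trained to minimize this loss on each node. `quantize_int8` is a basic symmetric per-tensor scheme that maps floats to integers in [-127, 127] with a single scale factor. Names and defaults (for example, T = 2) are assumptions for illustration, not a specific library's API.

```python
import math
from typing import List, Tuple


def softmax(logits: List[float], temperature: float = 1.0) -> List[float]:
    scaled = [z / temperature for z in logits]
    top = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(s - top) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def distillation_loss(teacher_logits: List[float], student_logits: List[float], temperature: float = 2.0) -> float:
    """T^2 * KL(teacher || student) on temperature-softened class distributions."""
    p = softmax(teacher_logits, temperature)
    q = softmax(student_logits, temperature)
    kl = sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)
    return temperature ** 2 * kl


def quantize_int8(weights: List[float]) -> Tuple[List[int], float]:
    """Symmetric per-tensor quantization: w is approximated by q * scale with q in [-127, 127]."""
    scale = max(abs(w) for w in weights) / 127 or 1.0  # avoid a zero scale for all-zero weights
    q = [max(-127, min(127, round(w / scale))) for w in weights]
    return q, scale


def dequantize(q: List[int], scale: float) -> List[float]:
    return [v * scale for v in q]


teacher = [2.0, 0.5, -1.0]  # e.g., a GNN's logits for one node
student = [1.5, 0.7, -0.8]  # e.g., an MLP's logits for the same node
print(distillation_loss(teacher, student))  # a small positive number; 0.0 if the logits match

weights = [0.5, -1.0, 0.25, 0.003]
q, scale = quantize_int8(weights)
print(q, dequantize(q, scale))  # integers in [-127, 127] and their approximate float values
```

Each dequantized weight differs from the original by at most half the scale, which is the precision traded for cheaper 8-bit arithmetic.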
“So that's at the algorithm level. But nowadays people go deeper. Sometimes, if you want to solve the problem, you need to go to the system level. So people say, let's see how we can design this distributed system to accelerate the training, accelerate the inference.
“For example, in some cases, memory becomes the main constraint. In this case, probably the only thing we can do is distribute the workload. Then the natural problems are how we can coordinate or synchronize the model parameters trained by each computational node. If we have to distribute the data to 10 machines, how can we coordinate those 10 machines to make sure we end up with only one final version?
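The coordination problem can be illustrated with synchronous data-parallel training, one common approach among several. In the sketch below, each of 10 simulated workers computes a gradient on its own shard of the data, the gradients are averaged (the role an all-reduce step plays in a real cluster), and every replica applies the same update, so there is only ever one version of the parameters. The workers run sequentially here, the model is a one-parameter linear fit, and all names are hypothetical.

```python
from typing import List, Tuple

Sample = Tuple[float, float]  # (x, y) pair


def local_gradient(w: float, shard: List[Sample]) -> float:
    """Gradient of mean squared error for the model y ~ w * x on one worker's shard."""
    return sum(2 * (w * x - y) * x for x, y in shard) / len(shard)


def all_reduce_mean(values: List[float]) -> float:
    """Stand-in for an all-reduce: every worker ends up with the same average."""
    return sum(values) / len(values)


def train_data_parallel(data: List[Sample], num_workers: int = 10, lr: float = 0.5, steps: int = 200) -> float:
    # Split the data across workers; with equal-sized shards, the average of the
    # shard gradients equals the full-batch gradient.
    shards = [data[i::num_workers] for i in range(num_workers)]
    w = 0.0  # all replicas start from the same parameter value
    for _ in range(steps):
        grads = [local_gradient(w, shard) for shard in shards]  # computed in parallel in a real system
        w -= lr * all_reduce_mean(grads)  # identical update keeps the replicas in sync
    return w


data = [(i / 100, 3 * i / 100) for i in range(1, 101)]  # 100 samples from y = 3x
print(round(train_data_parallel(data), 6))  # 3.0
```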
“And now people go even deeper, to do the acceleration on the hardware side. So software-hardware co-design also becomes more and more popular. It requires people to really know so many different fields.
“By the way, at KDD, compared to many other machine learning conferences, real-world problems are always our top focus. In many cases, in order to solve the real-world problem, we have to talk to people with different backgrounds, because we cannot just wrap it up into the kind of ideal problems we solved when we were in high school.”
Applications
Beyond such general efforts to improve GNNs’ versatility, accuracy, and efficiency, however, there’s also new research on specific applications of GNN technology.
“There’s some work on how we can do causal analysis in the graph setting, meaning that the objects actually interfere with each other,” Sun explains. “This is quite different from the traditional setting, where the patients in a drug study, for example, are independent of each other.
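A toy example shows what interference means in practice. The sketch below assumes a hypothetical outcome model in which each node's outcome depends on its own treatment and on the fraction of its neighbors that are treated (a simple "exposure mapping"); the coefficients are made up for illustration, and the function names are not from any library. Under the usual no-interference assumption, an untreated node's outcome would not depend on anyone else's treatment; here it does.

```python
from typing import Dict, List, Tuple

Graph = Dict[int, List[int]]  # adjacency list: node -> neighbors


def exposure(graph: Graph, treated: Dict[int, int], node: int) -> Tuple[int, float]:
    """Return (own treatment, fraction of neighbors treated) for one node."""
    nbrs = graph[node]
    frac = sum(treated[u] for u in nbrs) / len(nbrs) if nbrs else 0.0
    return treated[node], frac


def toy_outcome(graph: Graph, treated: Dict[int, int], node: int,
                direct: float = 2.0, spillover: float = 1.0) -> float:
    """Hypothetical outcome: a direct effect of own treatment plus a spillover from neighbors."""
    own, frac = exposure(graph, treated, node)
    return direct * own + spillover * frac


g = {0: [1], 1: [0, 2], 2: [1]}
nobody = {0: 0, 1: 0, 2: 0}
neighbor_treated = {0: 1, 1: 0, 2: 0}
# Node 1 is untreated in both scenarios, yet its outcome differs.
print(toy_outcome(g, nobody, 1), toy_outcome(g, neighbor_treated, 1))  # 0.0 0.5
```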
“There is also a new trend to combine deep representation learning with causal inference. For example, how can we represent the treatment you try as a continuous vector, instead of just a binary treatment? Can we make the treatment timewise continuous, meaning that it's not just a static, one-time treatment? If I apply the treatment 10 days later, how would the outcome compare to applying it 20 days later? Time is very important; how can we inject that time information?
“Graphs can also be considered a good data structure to describe multiagent dynamical systems — how those objects interact with each other in a dynamic network setting. And then, how can we incorporate the generative idea into graphs? Graph generation is very useful for many fields, such as in the drug industry.
“And then there are so many applications where we can benefit from large language models [LLMs]. For example, knowledge graph reasoning. We know that LLMs hallucinate, and reasoning on KGs is very rigorous. What would be a good combination of these two?
“With GNNs, there’s always new stuff. Graphs are just a very useful data structure to model our interconnected world.”