Exploring the Potential of Graph Neural Networks to Transform Recommendations at Zalando
Delivering personalized recommendations is key to engaging Zalando’s customers, but traditional models can miss the complexity of user-content interactions. By integrating graph neural networks (GNNs), we’re exploring a cutting-edge approach to better predict clicks and enhance the shopping experience.
Recommender systems are vital for personalizing user experiences across various platforms. At Zalando, these systems play a crucial role in tailoring content to individual users, thereby enhancing engagement and satisfaction. This is particularly important for the Zalando Homepage, which serves as the customers' first impression of the company. The recommendation system currently deployed on the Homepage leverages user-content interactions and optimizes for predicted click-through rate (CTR). The research introduced in this post focuses primarily on the approach and design of integrating a GNN into the existing recommender system. We aim to validate the feasibility and effectiveness of this integration before transitioning to a fully production-ready implementation.
The Problem Statement
Given a preselected pool of content that could be shown to a user on the Zalando Homepage, we need to predict the CTR for each piece of content, so that a later stage of the system can show the content with the highest expected value (of which predicted CTR is one component).
Our production model relies on traditional tabular data that captures user-content interactions such as views and clicks. Graph neural networks take a different view of the same data: they have emerged as a powerful tool for modeling relational data, offering a way to represent and learn from complex interaction patterns more effectively. GNNs operate on data represented as graphs, and a recommender system can be naturally modeled as a bipartite graph with two node types, users and items, whose links connect users to items and indicate user-item interactions (e.g., click, view, order).
Our task can then be formulated as follows:
- Given: Past user-item interactions
- Task:
  - Predict user-item interactions in the future
  - This can be cast as a link prediction problem: predict new user-item interaction links given the past links
- For a user 𝑢 ∈ 𝑼 and an item 𝑣 ∈ 𝑽, we need to get a real-valued score 𝑓(𝑢, 𝑣)
- The 𝑲 items with the largest scores for a given user 𝑢 are then recommended
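The scoring and top-𝑲 steps above can be sketched in a few lines of Python. This is only an illustration under stated assumptions: the dot product stands in for 𝑓(𝑢, 𝑣), and the embeddings, item names, and helper functions are hypothetical rather than taken from our system.

```python
from typing import Dict, List, Sequence


def score(user_emb: Sequence[float], item_emb: Sequence[float]) -> float:
    """f(u, v): assumed here to be a plain dot product of two embeddings."""
    return sum(a * b for a, b in zip(user_emb, item_emb))


def recommend_top_k(
    user_emb: Sequence[float],
    item_embs: Dict[str, Sequence[float]],
    k: int,
) -> List[str]:
    """Return the ids of the k items with the largest score for one user."""
    ranked = sorted(
        item_embs,
        key=lambda item_id: score(user_emb, item_embs[item_id]),
        reverse=True,
    )
    return ranked[:k]


# Hypothetical toy embeddings, for illustration only.
user = [1.0, 0.0]
items = {
    "teaser_a": [0.9, 0.1],
    "teaser_b": [0.1, 0.9],
    "teaser_c": [0.5, 0.5],
}
top2 = recommend_top_k(user, items, k=2)
print(top2)
```

In practice 𝑓 can be any learned function of the user and item representations; the dot product is simply the smallest choice that makes the ranking step concrete.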
Solution and Methods
While we could train a GNN to predict clicks (user-content links) directly, in this experiment we instead use a graph neural network to train embeddings for Zalando users and content on a click prediction task, and use these embeddings as additional inputs to our production model. Node embeddings are learned as an inherent part of a link prediction task: the GNN generates them to capture the relational structure and features of nodes in the graph, and they are then used to predict the presence or absence of links.
We represent users and content on the Homepage as two types of nodes in a graph, and their interactions (views and clicks) as two types of links. We then design an architecture based on the GraphSAGE neural network and train it to predict the “clicked” link given a “viewed” link.
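To make this graph layout concrete, here is a minimal, library-free sketch of a graph with two node types and two link types. The `HeteroGraph` class, the node ids, and the feature values are hypothetical. Keying links by a (source type, relation, target type) triple is a common convention in heterogeneous graph libraries, and we use it here only for illustration.

```python
from dataclasses import dataclass, field
from typing import Dict, List, Tuple

# (source node type, relation, target node type)
EdgeType = Tuple[str, str, str]

VIEWED: EdgeType = ("user", "viewed", "content")
CLICKED: EdgeType = ("user", "clicked", "content")


@dataclass
class HeteroGraph:
    """Hypothetical container: features per node type, links per link type."""
    node_features: Dict[str, Dict[str, List[float]]] = field(default_factory=dict)
    edges: Dict[EdgeType, List[Tuple[str, str]]] = field(default_factory=dict)

    def add_node(self, node_type: str, node_id: str, features: List[float]) -> None:
        self.node_features.setdefault(node_type, {})[node_id] = features

    def add_edge(self, edge_type: EdgeType, src: str, dst: str) -> None:
        self.edges.setdefault(edge_type, []).append((src, dst))


graph = HeteroGraph()
graph.add_node("user", "u1", [0.2, 0.7])      # made-up user features
graph.add_node("content", "c1", [0.9, 0.1])   # made-up content features
graph.add_node("content", "c2", [0.4, 0.4])

graph.add_edge(VIEWED, "u1", "c1")
graph.add_edge(VIEWED, "u1", "c2")
graph.add_edge(CLICKED, "u1", "c1")

# Supervision: every "viewed" link becomes one example, labeled 1 if a
# matching "clicked" link exists and 0 otherwise.
clicked = set(graph.edges.get(CLICKED, []))
examples = [(u, c, int((u, c) in clicked)) for u, c in graph.edges[VIEWED]]
print(examples)
```

The last two lines show how "predict the clicked link given a viewed link" turns into a binary classification dataset over viewed pairs.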
Dataset and data sources for the GNN embeddings training
Training and evaluation datasets are prepared using the PyTorch Geometric library, which provides a rich set of functionalities, including efficient graph data loading, manipulation, and batching. Both datasets are based on user-content activity data at the request level, with each record labeled as clicked or not clicked.
A graph data structure allows GNNs to capture higher-order interactions and dependencies that traditional methods might miss. For example, in a recommender system, a GNN can model not just the direct interactions between a user and an item but also how similar users have interacted with similar items or how users following the same brand might be interested in the same article.
GNN Architecture
The GNN propagates and aggregates features from nodes along the links of a graph to capture interactions between nodes. Initially, each node has a specific set of features. In our case, these are information about the most recently ordered articles for user nodes, and article representations for content nodes (each piece of content on the Zalando Homepage is associated with the specific articles presented in it). As the GNN operates, nodes send their features to adjacent nodes through a process called message passing, during which features may be transformed by neural network layers such as convolutions. Each node then combines the incoming features from its neighbors using an aggregation operation such as summing or averaging, and updates its own features with the result. As the network depth increases, allowing more rounds of message passing, the GNN can consider more distant relationships. In this way the GNN model generates embeddings for all nodes, which are then passed through a classifier to predict the existence of the “clicked” link between a user and a content node. The model is trained with a binary cross-entropy loss, whose gradients are used to update the model weights.
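The following sketch illustrates one round of message passing with mean aggregation, followed by a link classifier and a binary cross-entropy loss. It is a simplification under stated assumptions: scalar weights `w_self` and `w_neigh` stand in for the learned weight matrices of a GraphSAGE-style layer, the classifier is a sigmoid over a dot product, and all node ids and feature values are made up.

```python
import math
from typing import Dict, List

Vector = List[float]


def mean(vectors: List[Vector], dim: int) -> Vector:
    """Element-wise mean; a zero vector if the node has no neighbors."""
    if not vectors:
        return [0.0] * dim
    return [sum(v[i] for v in vectors) / len(vectors) for i in range(dim)]


def message_passing_round(
    features: Dict[str, Vector],
    neighbors: Dict[str, List[str]],
    w_self: float,
    w_neigh: float,
) -> Dict[str, Vector]:
    """One update: ReLU(w_self * own features + w_neigh * mean of neighbors)."""
    updated = {}
    for node, h in features.items():
        agg = mean([features[n] for n in neighbors.get(node, [])], len(h))
        updated[node] = [max(0.0, w_self * a + w_neigh * b) for a, b in zip(h, agg)]
    return updated


def click_probability(user_emb: Vector, content_emb: Vector) -> float:
    """Assumed classifier: sigmoid of the dot product of the two embeddings."""
    logit = sum(a * b for a, b in zip(user_emb, content_emb))
    return 1.0 / (1.0 + math.exp(-logit))


def bce_loss(p: float, label: int, eps: float = 1e-12) -> float:
    """Binary cross-entropy for a single predicted probability."""
    return -(label * math.log(p + eps) + (1 - label) * math.log(1.0 - p + eps))


# Toy bipartite graph: one user connected to two pieces of content.
features = {"u1": [1.0, 0.0], "c1": [0.0, 1.0], "c2": [1.0, 1.0]}
neighbors = {"u1": ["c1", "c2"], "c1": ["u1"], "c2": ["u1"]}

h1 = message_passing_round(features, neighbors, w_self=0.5, w_neigh=0.5)
p = click_probability(h1["u1"], h1["c1"])
loss_if_clicked = bce_loss(p, label=1)
print(h1["u1"], round(p, 3), round(loss_if_clicked, 3))
```

Calling `message_passing_round` again on `h1` corresponds to a second layer. In a bipartite user-content graph, two rounds are what let a user's embedding reflect other users who interacted with the same content.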
Graph Mini-batching
To handle large-scale data, we employ mini-batch training, sampling subgraphs and computing embeddings in parallel. We sample links together with neighborhoods of both of their adjacent nodes. The depth of the sampled neighborhood is equal to the depth of the GNN. This approach ensures scalability and efficient use of computational resources, allowing GNNs to handle real-world large-scale graph datasets. For each mini-batch we sample disjoint subgraphs. We also use disjoint sets of links for message passing and for the supervision signal to prevent information leakage.
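A simplified sketch of these two ideas, splitting links into message-passing and supervision sets and sampling a bounded neighborhood around a supervised link, might look as follows. The function names, the `fanout` parameter, and the toy links are assumptions made for illustration. In our experiments this work is handled by the graph library's loaders rather than by hand-written code.

```python
import random
from typing import Dict, List, Set, Tuple

Edge = Tuple[str, str]  # (user id, content id)


def split_edges(
    edges: List[Edge], supervision_ratio: float, rng: random.Random
) -> Tuple[List[Edge], List[Edge]]:
    """Split links into disjoint message-passing and supervision sets."""
    shuffled = edges[:]
    rng.shuffle(shuffled)
    n_sup = max(1, int(len(shuffled) * supervision_ratio))
    return shuffled[n_sup:], shuffled[:n_sup]


def build_adjacency(edges: List[Edge]) -> Dict[str, List[str]]:
    """Undirected adjacency built from message-passing links only."""
    adj: Dict[str, List[str]] = {}
    for user, content in edges:
        adj.setdefault(user, []).append(content)
        adj.setdefault(content, []).append(user)
    return adj


def sample_neighborhood(
    seeds: List[str],
    adj: Dict[str, List[str]],
    num_hops: int,
    fanout: int,
    rng: random.Random,
) -> Set[str]:
    """Sample up to `fanout` neighbors per node, for `num_hops` hops."""
    visited = set(seeds)
    frontier = set(seeds)
    for _ in range(num_hops):
        next_frontier: Set[str] = set()
        for node in sorted(frontier):  # sorted for reproducible sampling
            neigh = adj.get(node, [])
            sampled = rng.sample(neigh, min(fanout, len(neigh)))
            next_frontier.update(n for n in sampled if n not in visited)
        visited |= next_frontier
        frontier = next_frontier
    return visited


rng = random.Random(0)
edges = [("u1", "c1"), ("u1", "c2"), ("u2", "c2"), ("u2", "c3"), ("u3", "c3")]
mp_edges, sup_edges = split_edges(edges, supervision_ratio=0.2, rng=rng)
adj = build_adjacency(mp_edges)

# One subgraph per supervised link: seed with both of its endpoints and go
# as many hops deep as the GNN has layers (two in this toy setup).
seed_user, seed_content = sup_edges[0]
subgraph_nodes = sample_neighborhood([seed_user, seed_content], adj, 2, 2, rng)
print(sup_edges, sorted(subgraph_nodes))
```

Because the adjacency is built from `mp_edges` only, the link being predicted never appears among the links used for aggregation. This is the leakage the disjoint split is meant to prevent.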
Integrating GNN trained features into our production model
While our offline evaluation showed that it is possible to predict clicks directly with a GNN model, integrating such a model into our current production system presents several challenges:
- Graph data generation: generating and maintaining the graph data structure creates operational overhead, because raw user activity data is logged in a tabular format and takes time to convert into a graph. The graph also needs to be updated in real time (within the user session) with new user interactions, which requires developing a new approach to data logging and training dataset generation.
- Inference challenges: inference on a graph is fundamentally different from inference on tabular data, as it requires not only the information about a particular user-content pair but also all (or part) of the neighboring pairs. Aggregating information from a node’s neighbors can be computationally intensive and may require specialized infrastructure to handle the graph operations efficiently.
- Scalability: running GNN inference at scale, especially for a large number of users and pieces of content, can pose significant scalability challenges and may require distributed computing environments.
Given these complexities, as an initial step we decided to use the embeddings generated by the GNN model for users and content as additional features in our existing production model. This approach leverages the strengths of GNNs while integrating more seamlessly with our current infrastructure: unlike running click prediction on a GNN end to end, it does not require significant operational changes.
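Conceptually, this integration amounts to looking up the precomputed embeddings for a user-content pair and appending them to the existing tabular feature vector. The sketch below shows the idea. The lookup tables, feature values, embedding dimension, and the zero-vector fallback for ids without an embedding are all assumptions made for illustration, not a description of our production pipeline.

```python
from typing import Dict, List

EMB_DIM = 3  # hypothetical embedding size


def lookup(table: Dict[str, List[float]], key: str, dim: int) -> List[float]:
    """Fetch a precomputed embedding; assumed fallback is a zero vector."""
    return table.get(key, [0.0] * dim)


def build_feature_row(
    tabular: List[float],
    user_id: str,
    content_id: str,
    user_embs: Dict[str, List[float]],
    content_embs: Dict[str, List[float]],
) -> List[float]:
    """Existing tabular features followed by user and content embeddings."""
    return (
        list(tabular)
        + lookup(user_embs, user_id, EMB_DIM)
        + lookup(content_embs, content_id, EMB_DIM)
    )


# Made-up embeddings as they might be exported after GNN inference.
user_embs = {"u1": [0.1, 0.2, 0.3]}
content_embs = {"c1": [0.4, 0.5, 0.6]}

row = build_feature_row([12.0, 0.03], "u1", "c1", user_embs, content_embs)
row_unknown = build_feature_row([5.0, 0.01], "u1", "c_new", user_embs, content_embs)
print(row)
print(row_unknown)
```

The downstream model then consumes a fixed-width vector as before, which is why this route needs no graph operations at serving time.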
The GNN model can be retrained daily, ensuring that its features are regularly updated to reflect the latest user-content interactions. A key advantage of using a GNN is its ability to address the cold-start problem for nodes (e.g., newly introduced content) that were not part of the initial training. Even if a new node has no clicks yet, GNN inference can still be performed using the node's initial features and existing 'view' links formed during the content exploration phase. These initial features are dynamically updated as the node gains more connections and interactions within the graph.
What makes GNN features particularly valuable, compared to static features of individual articles, is their ability to capture and adapt to the relational context in the graph. Unlike static features that rely solely on precomputed attributes, GNN-generated embeddings are task-specific and are trained directly for the click prediction objective. This allows the model to encode not only the intrinsic properties of the content but also its evolving relationships with users and other content, leading to more accurate and context-aware predictions.
Experiments and Results
Evaluation approach and metric
We evaluate our new modeling approach in two stages:
- We evaluate our GNN model on the user-content click binary classification task using the ROC-AUC metric. We conduct several experiments, varying the number of layers (or hops on the graph) and the neighborhood size for graph mini-batching, ultimately selecting the best-performing configuration. To support the offline evaluation of our main production model, we run GNN inference on both the training and evaluation datasets, generating and saving user and content embeddings for all nodes in the respective graphs.
- We feed the generated GNN embeddings for users and content, together with other features, to the main production model and evaluate it on the CTR prediction task, also with the ROC-AUC metric.
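For readers less familiar with the metric used in both stages: ROC-AUC is the probability that a randomly chosen clicked example receives a higher score than a randomly chosen non-clicked one, with ties counted as half. The sketch below computes it directly from that definition. The labels and scores are made-up toy values, not our experimental data, and a library implementation would normally be used on real datasets because this pairwise version is quadratic in the number of examples.

```python
from typing import List


def roc_auc(labels: List[int], scores: List[float]) -> float:
    """ROC-AUC as the fraction of (positive, negative) pairs ranked correctly."""
    pos = [s for y, s in zip(labels, scores) if y == 1]
    neg = [s for y, s in zip(labels, scores) if y == 0]
    if not pos or not neg:
        raise ValueError("ROC-AUC needs at least one positive and one negative")
    wins = 0.0
    for p in pos:
        for n in neg:
            if p > n:
                wins += 1.0
            elif p == n:
                wins += 0.5  # ties count as half a correctly ranked pair
    return wins / (len(pos) * len(neg))


# Toy example: 1 = clicked, 0 = not clicked.
labels = [1, 0, 1, 0, 0]
scores_a = [0.6, 0.5, 0.3, 0.2, 0.4]    # hypothetical model A
scores_b = [0.7, 0.5, 0.45, 0.2, 0.4]   # hypothetical model B

auc_a = roc_auc(labels, scores_a)
auc_b = roc_auc(labels, scores_b)
print(round(auc_a, 4), round(auc_b, 4))
```

Since the metric depends only on how examples are ranked, it is insensitive to the calibration of the predicted probabilities.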
Adding GNN features into the production model has improved our main offline evaluation metric, ROC-AUC, by 0.6 percentage points. While this improvement might seem modest, it's important to note that this was an initial experimentation round focused primarily on validating the feasibility of integrating GNNs into our system, rather than fully optimizing the GNN configuration or the broader model pipeline. The improvements achieved thus far suggest significant untapped potential, as further tuning of hyperparameters, node feature engineering, and experimentation with different graph structures could unlock more substantial performance gains.
On top of that, GNNs offer capabilities that extend beyond traditional deep learning algorithms. They allow us to model complex aspects such as the novelty and diversity of content recommendations, and even the inspirational value of the content. These advanced capabilities can enable us to better align recommendations with higher-level business goals, such as enhancing user engagement through diverse and inspiring content.
Conclusion and Next Steps
We demonstrated the feasibility of using graph neural networks to model user interactions on Zalando’s Homepage. By leveraging GNN embeddings, we have improved the ROC-AUC performance of our recommender system. However, there is still a lot of room for improvement on both sides: fine-tuning the hyperparameters of the production model with GNN features, as well as testing architectural enhancements for training the GNN embeddings. Future work involves experimenting with such improvements and validating the impact of the approach in a production setting. Additionally, building solid customer representations with GNNs has strong potential to enable a variety of ML tasks within Zalando, enhancing applications like our recommender model to improve CTR prediction accuracy and enrich the overall user experience.
We're hiring! Do you like working in an ever-evolving organization such as Zalando? Consider joining our teams as a Machine Learning Engineer!