Challenges and practical lessons from building a deep-learning-based ads CTR prediction model
August 29, 2022
At LinkedIn, our ads business is powered by a core machine learning model for click-through-rate (CTR) prediction. CTR prediction estimates the probability that a LinkedIn member will click on a given ad. That probability is then used in ads auctions, which decide the order in which ads are displayed to members. A better CTR model can improve the experience of both members and advertisers by surfacing more relevant ads and spending advertiser budgets more efficiently.
In the past, we predicted ads CTR with a GLMix model. As a highly optimized framework backed by extensive feature engineering, it was a hard baseline to surpass. We recently replaced it with a deep-learning-based system. In this blog post, we describe some of the challenges we tackled and the practical lessons we learned, and explain how the transition brought large relevance lifts (+8.5% CTR) to our ads business.
We would also like to highlight that this work was enabled by LinkedIn ML frameworks and infrastructure including GDMix, Lambda Learner, and other libraries.
Three towers, three challenges
Figure 1: The three-tower model architecture. The shallow and the deep towers take in generalization features and are trained at daily frequency while the wide tower takes in memorization features and is re-trained at hourly frequency.
Our deep CTR model has a three-tower architecture: the “deep tower,” the “wide tower,” and the “shallow tower.” The outputs of the three towers are summed and fed into a sigmoid layer, and the model is trained with a regular cross-entropy loss. While at first glance this looks similar to the popular wide-and-deep model, the actual setup is quite different. In this section, we use each tower to introduce one unique challenge we tackled: three towers, three challenges.
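As a minimal sketch of how the tower outputs combine, assuming each tower emits a single scalar logit (the function names and numbers are our own, not LinkedIn's), the summed logit goes through a sigmoid and the prediction is scored with binary cross-entropy:

```python
import math


def sigmoid(z: float) -> float:
    """Numerically stable logistic function."""
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)
    return e / (1.0 + e)


def three_tower_pctr(deep_logit: float, wide_logit: float, shallow_logit: float) -> float:
    """Sum the three tower logits, then squash the total into a click probability."""
    return sigmoid(deep_logit + wide_logit + shallow_logit)


def cross_entropy(p: float, label: int, eps: float = 1e-12) -> float:
    """Binary cross-entropy for one example; p is clipped to avoid log(0)."""
    p = min(max(p, eps), 1.0 - eps)
    return -(label * math.log(p) + (1 - label) * math.log(1.0 - p))


# Hypothetical tower outputs for one member-ad-context triple.
p = three_tower_pctr(deep_logit=-2.0, wide_logit=0.5, shallow_logit=-1.0)
loss = cross_entropy(p, label=1)
```

Because the towers are summed before the sigmoid, each one contributes additively in logit space, so the output of any subset of towers can be treated as a fixed offset while the remaining tower is trained.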
The deep tower: complete feature interaction
The deep tower is a vanilla multilayer perceptron (MLP) whose input features include member and advertiser profiles, activities, and context features. These features are first converted to dense embeddings through embedding layers, then concatenated and fed to fully connected layers. The challenge with the deep tower is achieving complete feature interaction across member, ad, and context features.
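A minimal pure-Python sketch of the deep tower's forward pass as described: look up one embedding per categorical field, concatenate, then apply fully connected ReLU layers ending in a single logit. The field names, vocabularies, layer sizes, and random initialization are all hypothetical; a production model would use a deep learning framework and trained weights.

```python
import random


def make_embedding_table(vocab, dim, rng):
    """One dense vector per categorical value (learned in practice; random here)."""
    return {value: [rng.gauss(0.0, 0.1) for _ in range(dim)] for value in vocab}


def make_layer(n_in, n_out, rng):
    """He-style initialised weights (n_out rows of n_in) plus zero biases."""
    scale = (2.0 / n_in) ** 0.5
    weights = [[rng.gauss(0.0, scale) for _ in range(n_in)] for _ in range(n_out)]
    return weights, [0.0] * n_out


def dense(x, layer):
    """Fully connected layer: one weighted sum plus bias per output unit."""
    weights, biases = layer
    return [sum(w * xi for w, xi in zip(row, x)) + b for row, b in zip(weights, biases)]


def deep_tower_logit(features, tables, layers):
    """Embed each field, concatenate in a fixed field order, then run the MLP."""
    x = []
    for field in sorted(tables):
        x.extend(tables[field][features[field]])
    for layer in layers[:-1]:
        x = [max(0.0, v) for v in dense(x, layer)]  # ReLU hidden layers
    return dense(x, layers[-1])[0]  # last layer has one output unit: the logit


rng = random.Random(0)
DIM = 4
# Hypothetical member, ad, and context fields with tiny vocabularies.
tables = {
    "member_industry": make_embedding_table(["software", "finance"], DIM, rng),
    "ad_category": make_embedding_table(["jobs", "courses"], DIM, rng),
    "context_device": make_embedding_table(["mobile", "desktop"], DIM, rng),
}
layers = [make_layer(3 * DIM, 8, rng), make_layer(8, 1, rng)]

example = {"member_industry": "software", "ad_category": "jobs", "context_device": "mobile"}
logit = deep_tower_logit(example, tables, layers)
```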
In general, there are two ways to productionize deep learning:
- (a) Train a deep model to generate some type of offline embedding, such as ads embedding and member embedding, then inject the embedding as a feature to the baseline model. The embedding can be stored in a key-value store and fetched during online scoring time.
- (b) Train a deep model and serve the entire deep model online.
Comparing the two, approach (a) has a much lower engineering cost, because approach (b) requires serving the entire deep model in the online stack. The downside of (a), however, is that each deep model is built only on member features or only on ad features, so it cannot capture the complete feature interaction across members, ads, and context.
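A toy illustration of this limitation, under a simplifying assumption of ours: the baseline consumes the looked-up member and ad embeddings linearly, with no hand-crafted crosses or per-entity coefficients (a simplification of the real baseline). All vectors and weights are made up. The linear scorer gives every member the same context effect, while even a single ReLU unit over the concatenated features, as in approach (b), lets that effect depend on the member.

```python
# Approach (a): embeddings computed offline and fetched from a key-value store.
# Every vector and weight below is a made-up illustrative value.
MEMBER_EMB = {"m1": [0.2, -0.1], "m2": [0.5, 0.3]}
AD_EMB = {"a1": [0.4, 0.1]}
CONTEXT = {"mobile": [1.0, 0.0], "desktop": [0.0, 1.0]}  # one-hot context feature


def features(member, ad, ctx):
    """Concatenate member embedding, ad embedding, and context features."""
    return MEMBER_EMB[member] + AD_EMB[ad] + CONTEXT[ctx]


def linear_score(member, ad, ctx, w=(1.0, 0.5, 0.3, -0.2, 0.8, 0.1)):
    """Approach (a) baseline: linear in the injected embedding features."""
    return sum(wi * xi for wi, xi in zip(w, features(member, ad, ctx)))


def deep_score(member, ad, ctx, u=(1.0, 0.0, 0.0, 0.0, 1.0, 0.0), b=-0.6):
    """Approach (b): one ReLU unit over the full concatenation can mix fields."""
    x = features(member, ad, ctx)
    return max(0.0, sum(ui * xi for ui, xi in zip(u, x)) + b)


def context_effect(score, member):
    """Change in score when the context switches from desktop to mobile."""
    return score(member, "a1", "mobile") - score(member, "a1", "desktop")
```

Here `context_effect(linear_score, "m1")` and `context_effect(linear_score, "m2")` are both 0.7, whereas the deep scorer yields about 0.6 for `m1` and 0.9 for `m2`.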
Our earlier attempts to use deep learning for ads CTR followed approach (a). Those attempts did not succeed, which led us to take approach (b) to get complete feature interaction. How to build a large-scale deep learning serving system under the strict latency requirements of the ads auction is beyond the scope of this post, but our post-ramp analysis found that the interaction between context features and other features was critical to the relevance lift we saw in A/B tests. This showed that the complete feature interaction enabled by end-to-end deep model serving is key to success.
The wide tower: fast memorization
The wide tower is a linear layer that takes in sparse ID features such as ad ID and advertiser ID. Essentially, these features help the model memorize the historical performance of each entity. Freshness matters for this type of feature, because ad performance trends differently over time and new ads and advertisers keep entering our platform. To keep the model fresh, we frequently re-train the wide tower on its own. For each model, we first perform cold-start training on the other two towers. We then freeze their coefficients and perform frequent warm-start training on the wide tower using the latest data. As that data is collected from Kafka tracking events, its generalization features are immediately scored by the two frozen towers, and the resulting scores are stored on HDFS as the cold-start offset. The cold-start offset and the sparse ID features are then used to update the coefficients of the wide tower. Because the input features are lightweight, warm-start retraining is fast, and with GDMix and Lambda Learner as our backbone, we are able to perform this partial re-training on an hourly basis.
Figure 2: The complete training process is decomposed into three steps. Step 1: Train a model containing only the deep and shallow towers, using generalization features. Step 2: Whenever new data arrives in our offline system, run inference on its generalization features with the deep and shallow towers trained in Step 1 to get a “cold-start offset.” Step 3: Train only the wide tower, using the cold-start offset plus memorization features. Step 1 happens daily. Step 2 happens whenever a new batch of tracking data becomes available (every few minutes). Step 3 happens hourly.
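A minimal sketch of the hourly warm-start step (Step 3), assuming plain SGD on the logistic loss; the post does not say which optimizer GDMix and Lambda Learner use. Each example carries the frozen cold-start offset from the deep and shallow towers, and only the sparse-ID weights of the wide tower are updated. The ID strings and data are hypothetical.

```python
import math
from collections import defaultdict


def sigmoid(z):
    """Numerically stable logistic function."""
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)
    return e / (1.0 + e)


def warm_start_wide_tower(examples, prev_weights=None, lr=0.5, epochs=20):
    """SGD on logistic loss that updates only the wide tower's sparse-ID weights.

    examples: (cold_start_offset, sparse_ids, label) triples. The offset is the
    frozen deep + shallow logit, so it is a constant input, never a parameter.
    """
    w = defaultdict(float, prev_weights or {})  # warm start from the previous run
    for _ in range(epochs):
        for offset, ids, label in examples:
            z = offset + sum(w[i] for i in ids)
            grad = sigmoid(z) - label  # d(cross-entropy)/d(logit)
            for i in ids:
                w[i] -= lr * grad
    return dict(w)


# Hypothetical hour of data: ad a1 beats what the frozen towers predicted, a2 does not.
hourly = [
    (-2.0, ["ad:a1", "advertiser:x"], 1),
    (-2.0, ["ad:a1", "advertiser:x"], 0),
    (-2.0, ["ad:a2", "advertiser:x"], 0),
    (-2.0, ["ad:a2", "advertiser:x"], 0),
]
weights = warm_start_wide_tower(hourly)
next_hour = warm_start_wide_tower(hourly, prev_weights=weights, epochs=1)
```

Because the offset is just a number attached to each example, the expensive towers never run inside this loop, which is what keeps the frequent update cheap.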
We ran ablation studies and found that 1) the wide tower gives a significant boost to model performance in A/B tests, and 2) increasing the re-training frequency from daily to hourly noticeably improves model performance.
The shallow tower: ease of calibration
Unlike many verticals where better relevance is the only major goal, ads ranking also accounts for monetization value through its ranking objective, called Expected Cost Per Impression (ECPI). For click-type ads, a simplified formula for ECPI is:
$$\mathrm{ECPI} = \mathrm{pCTR} \times \mathrm{biddingPrice}$$
where pCTR is the prediction score from our CTR model and biddingPrice is the amount of money advertisers are willing to pay if the member clicks on the ad.
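A small worked example of the simplified ECPI formula with three hypothetical candidates. It also shows why the absolute value of pCTR matters: scaling every pCTR by 1.4, as a uniform 40% over-prediction would, leaves the ranking unchanged but raises every ECPI by 40%.

```python
def ecpi(pctr: float, bidding_price: float) -> float:
    """Simplified expected cost per impression for a click-type ad."""
    return pctr * bidding_price


def rank_by_ecpi(candidates):
    """candidates: (ad_id, pctr, bidding_price) tuples; highest ECPI first."""
    ranked = sorted(candidates, key=lambda c: ecpi(c[1], c[2]), reverse=True)
    return [ad for ad, _, _ in ranked]


# Hypothetical candidates: ECPI is 0.06, 0.05, and 0.10 respectively.
candidates = [("a", 0.02, 3.0), ("b", 0.05, 1.0), ("c", 0.01, 10.0)]
order = rank_by_ecpi(candidates)

# A uniform 40% over-prediction leaves the order untouched but inflates every ECPI.
inflated = [(ad, 1.4 * p, bid) for ad, p, bid in candidates]
same_order = rank_by_ecpi(inflated) == order
```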
ECPI is not just used for ranking; in many cases it is also used to charge advertisers. So beyond the relative order derived from ranking, the absolute values of ECPI and pCTR matter, because inaccurate pCTR can lead to overcharging or undercharging advertisers. The process of bringing pCTR to the right absolute value (i.e., oCTR, the observed ground-truth probability of a click) is called calibration. Deep models tend to produce a different pCTR distribution, which makes calibration rather challenging. For both the GLMix baseline model and the new deep model, we use isotonic regression as a post-training calibration module. However, it did not solve the problem for the deep model. When we tested the first version of the deep model, it produced pCTR that was on average 40% higher than that of our baseline model, which meant that it could overcharge advertisers and hurt their ROI if it were ramped to production. We call this issue over-prediction.
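Isotonic regression fits a non-decreasing map from raw pCTR to observed click rate. Below is a minimal pool-adjacent-violators (PAV) sketch in plain Python, offered as an illustration rather than LinkedIn's calibration module. It uses a step-function lookup between fitted blocks, whereas scikit-learn's `IsotonicRegression` interpolates linearly between fitted points.

```python
from bisect import bisect_left
from collections import defaultdict


def fit_isotonic(scores, labels):
    """Pool adjacent violators; returns (upper_score, calibrated_rate) blocks."""
    totals = defaultdict(lambda: [0.0, 0])  # raw score -> [clicks, impressions]
    for s, y in zip(scores, labels):
        totals[s][0] += y
        totals[s][1] += 1
    blocks = []  # each block: [clicks, impressions, highest raw score in block]
    for s in sorted(totals):
        clicks, n = totals[s]
        blocks.append([clicks, n, s])
        # Merge while the previous block's click rate exceeds the newest block's.
        while len(blocks) > 1 and blocks[-2][0] * blocks[-1][1] > blocks[-1][0] * blocks[-2][1]:
            merged_clicks, merged_n, merged_upper = blocks.pop()
            blocks[-1][0] += merged_clicks
            blocks[-1][1] += merged_n
            blocks[-1][2] = merged_upper
    return [(upper, clicks / n) for clicks, n, upper in blocks]


def calibrate(model, score):
    """Step lookup: rate of the first block whose upper score is >= score."""
    uppers = [upper for upper, _ in model]
    i = min(bisect_left(uppers, score), len(model) - 1)
    return model[i][1]


model = fit_isotonic([0.1, 0.2, 0.3, 0.4], [0, 1, 0, 1])
```

The out-of-order pair at raw scores 0.2 and 0.3 is pooled into one block with rate 0.5, so the fitted map stays monotone.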
The shallow tower trick
We found that a simple trick alleviates the over-prediction problem: inserting a shallow tower into the model. The shallow tower is a linear layer that takes in almost the same features as the deep tower. While the theoretical explanation needs further study, we can offer a hypothesis for why the shallow tower trick works. Empirical studies have shown that deep models tend to be more overconfident in their predictions than linear models. The shallow and deep towers together can be thought of as a special residual block that combines a linear model and a deep model. Instead of fitting the desired underlying mapping directly, the deep tower now fits the residual between the desired mapping and the linear model's mapping. We hypothesize that this architecture not only prevents model degradation but also produces a mapping function that is closer to that of linear models, reducing calibration error. However, more case studies are needed to confirm this.
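Written out in logit space with symbols of our own choosing, ignoring the wide tower and treating both towers as sharing one feature vector $\mathbf{x}$ (the post says their feature sets are only almost the same), the hypothesis reads as follows. This is a sketch of the analogy, not a derivation from the post:

$$
h(\mathbf{x}) = \underbrace{\mathbf{w}^\top \mathbf{x}}_{\text{shallow tower}} + \underbrace{F_\theta(\mathbf{x})}_{\text{deep tower}}
\qquad\Longrightarrow\qquad
F_\theta(\mathbf{x}) \approx h^*(\mathbf{x}) - \mathbf{w}^\top \mathbf{x}
$$

Here $h^*$ is the desired logit mapping and $\sigma(h(\mathbf{x}))$ is the predicted CTR. The form mirrors a residual block $h(\mathbf{x}) = \mathbf{x} + F(\mathbf{x})$ with the identity shortcut replaced by a linear model, so the deep tower only has to learn what the linear model misses.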
In practice, we found that adding the shallow tower reduced over-prediction from 40% to about 10%. Note that although both the shallow tower and the wide tower are linear layers, we do not merge them, because the shallow tower takes in heavy features that cannot be processed and trained at hourly frequency.
Figure 3: Comparing the pCTR distributions before and after inserting the shallow tower. The Deep+Wide Network generates over-confident predictions (e.g., pCTR > 0.5), while the Deep+Wide+Shallow Network shows much less of this issue.
The position feature
Another change we made was removing the position feature from the deep tower and feeding it only to the shallow tower. Position refers to the position of the ad in the LinkedIn home feed, e.g., the second feed position. The position feature is special in that it is a de-biasing feature that is available only at training time, not during online serving, due to the nature of our ads system. An ablation study showed that putting the position feature into the deep tower lets the model learn unwanted interactions between position and other features, which makes calibration harder.
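A minimal sketch, with hypothetical feature names, weights, and position biases, of why feeding position only to the linear shallow tower rules out position interactions: position enters as a per-position bias added to the logit, so moving any ad between two positions shifts its logit by the same amount.

```python
def shallow_logit(weights, features, position_bias, position):
    """Linear shallow tower: weighted feature sum plus a per-position bias term."""
    return sum(weights[name] * value for name, value in features.items()) + position_bias[position]


# Hypothetical weights, ads, and position biases.
weights = {"ad_quality": 1.2, "member_affinity": 0.8}
position_bias = {1: 0.0, 2: -0.4, 3: -0.7}
ads = {
    "a": {"ad_quality": 0.5, "member_affinity": 0.1},
    "b": {"ad_quality": 0.2, "member_affinity": 0.9},
}

# Moving each ad from position 1 to position 2 shifts its logit by the same amount,
# because the position term is added and never multiplies another feature.
shifts = {
    ad: shallow_logit(weights, f, position_bias, 2) - shallow_logit(weights, f, position_bias, 1)
    for ad, f in ads.items()
}
```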
Beyond the shallow tower: calibration and exposure bias
Although the shallow tower trick alleviated the issue, we still had about 10% over-prediction, which prevented us from ramping the model to production. This raised another question: why was our isotonic-regression-based calibration module not fixing the over-prediction?
The short answer is that exposure bias in the system makes the data distributions of the offline and online datasets differ, so calibration models trained on the offline dataset do not generalize to online request data. This has been observed in other industry applications under the name “selection bias,” but we think “exposure bias” is a more accurate term here.
For each request in our online system, the model scores a few hundred ad candidates, but only the most competitive ads (based on their bids and predicted CTRs) win the auction, get exposed to members, and are then collected into our offline dataset. In other words, our offline dataset is biased by the baseline linear model that was used to predict CTRs, and thus differs from our true online test set.
Figure 4: A system with the baseline model and a system with the deep model induce different biases in which ads get exposed and collected into the offline dataset. Their data samples and distributions are shown in yellow and blue, respectively. When we trained the first deep model, all our offline data came from baseline model scoring (yellow), and the deep model did not show over-prediction offline. However, when the deep model was ramped online, it over-predicted on its own distribution of exposure data (blue). The solution is therefore to collect the deep model's own data samples into the offline dataset and train calibration models on them.
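An illustrative simulation of one mechanism by which selection shapes the logged data; the CTR range, equal bids, and log-normal error model are all our assumptions, not LinkedIn's system. Even a model that is unbiased over every candidate it scores over-predicts on the ads it chooses to expose, because winning the auction favors candidates whose errors happen to be positive. Which errors get logged therefore depends on which model runs the auction.

```python
import math
import random


def simulate_exposure(n_requests=1000, n_candidates=100, noise=0.3, seed=7):
    """Return predicted/true CTR ratios over all candidates and over exposed winners."""
    rng = random.Random(seed)
    all_pred = all_true = won_pred = won_true = 0.0
    for _ in range(n_requests):
        best = None
        for _ in range(n_candidates):
            true_ctr = rng.uniform(0.005, 0.05)
            # Log-normal multiplicative error with mean 1: unbiased on average.
            pctr = true_ctr * math.exp(noise * rng.gauss(0.0, 1.0) - noise ** 2 / 2)
            all_pred += pctr
            all_true += true_ctr
            if best is None or pctr > best[0]:
                best = (pctr, true_ctr)  # equal bids, so the highest pCTR wins
        won_pred += best[0]
        won_true += best[1]
    return all_pred / all_true, won_pred / won_true


all_ratio, exposed_ratio = simulate_exposure()
```

Over all scored candidates the ratio stays close to 1, while on the exposed winners it is well above 1, even though the model itself has not changed.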
One naive solution was to ramp our deep model to 100% of traffic, collect the over-predicted data into the offline dataset, and then train the calibration model on that data. However, this was not practical because it could cause drastic shifts in business metrics. Instead, we made a compromise. First, we ramped the deep model to a small percentage of traffic and used only the tracking data generated by the deep model to train its calibration module. We then gradually ramped up the deep model and found that over-prediction eventually came down to zero as we ramped it higher and collected more data for calibration training.
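A minimal sketch of the key step during a partial ramp: fit calibration only on impressions the deep model itself served. The `served_by` field, the data, and the single multiplicative correction factor are our simplifications; the calibration module in the post is isotonic regression.

```python
from dataclasses import dataclass


@dataclass
class Impression:
    served_by: str  # hypothetical field: which model scored this impression
    pctr: float     # raw, pre-calibration prediction from that model
    clicked: int    # 1 if the member clicked, else 0


def calibration_factor(impressions, model_id):
    """Observed clicks / predicted clicks, using only the given model's own traffic."""
    own = [r for r in impressions if r.served_by == model_id]
    predicted = sum(r.pctr for r in own)
    observed = sum(r.clicked for r in own)
    if predicted == 0:
        raise ValueError(f"no predictions logged for model {model_id!r}")
    return observed / predicted


# Hypothetical logs from a small ramp: most traffic still goes to the baseline.
logs = [Impression("baseline", 0.10, int(i < 10)) for i in range(100)]
logs += [Impression("deep", 0.11, int(i < 5)) for i in range(50)]

factor = calibration_factor(logs, "deep")  # about 0.909
over_prediction = 1 / factor - 1           # about 0.1, i.e. 10% over-prediction
```

Mixing the baseline's impressions into the deep model's calibration data would dilute exactly the signal this filter isolates.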
The new ads CTR model combines deep feature interaction, fast memorization, and ease of calibration. In this post, we discussed some practical lessons from building a deep-learning-based CTR model, such as using embeddings as features versus serving an end-to-end deep model, and how to achieve hourly model retraining. In particular, we shared that solving the over-prediction caused by deep models is a challenge unique to the ads domain, and that we are conducting more studies to confirm our solution.
This article is a slice of a larger project that spanned more than one year and involved multiple teams. In particular, we would like to thank our teammates and leadership from the Ads AI team: Renpeng Fang, Mark Yang, Zhenqi Hu, David Pardoe, Hiroto Udagawa, Arjun Kulothungun, Onkar Dalal, and our collaborators from the AI Foundations team and Machine Learning Infra team: Jun Shi, Sida Wang, Keerthi Selvaraj, Haichao Wei, Yun Dai, Pei-Lun Liao. We would also like to thank Rupesh Gupta, Kayla Guglielmo, Katherine H. Vaiente and the LinkedIn Editorial team for their reviews and suggestions.