
FastTreeSHAP: Accelerating SHAP value computation for trees

Co-authors: Jilei Yang, Humberto Gonzalez, Parvez Ahammad

In this blog post, we introduce and announce the open sourcing of the FastTreeSHAP package, a Python package based on the paper Fast TreeSHAP: Accelerating SHAP Value Computation for Trees (presented at the NeurIPS 2021 XAI4Debugging Workshop). FastTreeSHAP enables efficient interpretation of tree-based machine learning models by computing sample-level feature importance values, and it is built as a new implementation of the widely used TreeSHAP algorithm in the SHAP package. The FastTreeSHAP package implements two new algorithms, FastTreeSHAP v1 and FastTreeSHAP v2, each of which improves the computational efficiency of TreeSHAP through a different approach. Our empirical benchmarks show that FastTreeSHAP v1 is 1.5x faster than TreeSHAP with unchanged memory cost, and FastTreeSHAP v2 is 2.5x faster than TreeSHAP at the cost of slightly higher memory usage. Parallel multi-core computing is fully enabled in the FastTreeSHAP package to further speed up computation. The package is easy to use: it shares the API of the TreeSHAP implementation in the SHAP package, apart from three additional arguments that are easy to tune in practice.

Background: SHAP and TreeSHAP

Predictive machine learning models are widespread in industry today. At LinkedIn, we build predictive models to improve our member experience in member-facing products such as People You May Know (PYMK), newsfeed ranking, search, and job recommendations, as well as in customer-facing products within sales and marketing. Among these models, complex models such as random forests, gradient boosted trees, and deep neural networks are widely used due to their high prediction accuracy. As we continue to build on our Responsible AI program at LinkedIn, a key part of our work is understanding how these models work (a.k.a. model interpretation), which remains an important challenge because these models are intrinsically opaque.

In a previous blog post, we described how we build transparent and explainable AI systems at LinkedIn, highlighting a few ways we've improved transparency in AI, including explainable AI for model consumers to build trust and augment decision-making (Project CrystalCandle, previously called "Intellige"), and explainable AI for modelers to perform model debugging and improvement. One of the key approaches in building transparent and explainable AI systems is to understand how inputs contribute to model output (i.e., feature reasoning), and interpretations at the level of individual samples are often of most interest. A few use cases of sample-level model interpretation at LinkedIn include:

  • In our business predictive models, such as customer acquisition models and customer churn models, sample-level feature reasoning is crucial for model end users (such as the sales and marketing teams) to trust prediction results, enabling them to derive meaningful insights and actionable items, which eventually leads to improvements in our key business metrics.

  • In our recruiter search models, sample-level feature reasoning can help answer questions from LinkedIn customers, such as why candidate 1 ranks higher than candidate 2, or why candidate 1’s rank has changed between this month’s and last month’s searches, to build user trust and improve user engagement. It can also help model developers debug the model to further improve its performance. While this functionality hasn’t been implemented on the LinkedIn website yet, it is part of our future plans.

  • In our job search models, sample-level feature reasoning is key to supporting legal and regulatory compliance objectives, and can be helpful in ensuring our job recommendation models are fair to LinkedIn members.

There exist several state-of-the-art sample-level model interpretation approaches, e.g., SHAP, LIME, and Integrated Gradients. Among them, SHAP (SHapley Additive exPlanation) calculates SHAP values, which quantify the contribution of each feature to the model prediction by incorporating concepts from game theory and local explanations. More concretely, the SHAP value of a feature is the weighted average of the change in model output when that feature is added to each possible subset of the other features. In contrast to other approaches, SHAP has been justified as the only consistent feature attribution approach with several unique properties (local accuracy, missingness, and consistency), which agree with human intuition. Due to its solid theoretical guarantees, SHAP has become a leading model interpretation approach in industry. For more technical details of SHAP, please refer to this paper.
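For reference, the Shapley value definition underlying SHAP can be written as follows, where F = {1, ..., N} is the set of features and f_x(S) denotes the expected model output for sample x when only the features in S are known. The notation is ours, following the standard formulation rather than anything specific to this post. The second line is the local accuracy property: the base value φ_0 plus the SHAP values add up to the prediction.

```latex
\phi_i(f, x) = \sum_{S \subseteq F \setminus \{i\}} \frac{|S|!\,(N - |S| - 1)!}{N!}\,\bigl[f_x(S \cup \{i\}) - f_x(S)\bigr]

f(x) = \phi_0 + \sum_{i=1}^{N} \phi_i(f, x), \qquad \phi_0 = f_x(\emptyset) = \mathbb{E}[f(X)]
```

Evaluating the first sum directly visits all 2^(N-1) subsets of the other features for every feature i, which is why exact computation is exponential in N.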

Figure 1 shows a typical example of SHAP values for two individual samples in the public dataset Adult, where the prediction task is to determine whether a person makes over $50K a year using features such as marital status, educational status, capital gain and capital loss, and age. The left plot shows a prediction score of 0.776 for Person A, much larger than the average prediction score of 0.241, indicating a high likelihood that Person A makes over $50K a year. The top driving features are ordered from top to bottom by their absolute SHAP values, where a red bar represents a positive value and a blue bar represents a negative value. From the left plot, we can easily see that the high capital gain and the marital status (married with a civilian spouse) contribute most to Person A’s high prediction score. Similarly, in the right plot, a prediction score of 0.007 for Person B indicates a very low likelihood of making over $50K a year, driven down mainly by this person’s marital status (single) and young age.


Figure 1. Example of SHAP values of two individual samples in dataset Adult. Left: Person A with prediction score 0.776; Right: Person B with prediction score 0.007.

Despite the strong theoretical guarantees and wide range of use cases of SHAP values, one of the major concerns in implementing SHAP is computation: the time needed to compute exact SHAP values grows exponentially with the number of features in the model. To improve computational efficiency, TreeSHAP was designed for tree-based models (e.g., decision trees, random forests, gradient boosted trees) and computes exact SHAP values in polynomial time. It achieves polynomial time complexity by considering only the root-to-leaf paths in the trees that contain the target feature, together with all subsets within these paths. For more technical details of TreeSHAP, please refer to this paper.
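To make the exponential cost concrete, here is a minimal brute-force sketch (not TreeSHAP itself, and not code from the FastTreeSHAP package) that computes exact SHAP values for a hypothetical depth-2 tree. The value function follows the path-dependent idea behind TreeSHAP: splits on known features are followed, and splits on unknown features are averaged by the number of training samples (cover) reaching each child. The tree arrays and the sample are made up for illustration.

```python
from itertools import combinations
from math import factorial
from typing import FrozenSet, List

# Hypothetical toy tree in flat-array form (-1 in LEFT/RIGHT marks a leaf).
LEFT = [1, -1, 3, -1, -1]
RIGHT = [2, -1, 4, -1, -1]
FEATURE = [0, -1, 1, -1, -1]
THRESHOLD = [0.5, 0.0, 0.5, 0.0, 0.0]  # go left when x[feature] <= threshold
VALUE = [0.0, 1.0, 0.0, 2.0, 4.0]      # only leaf values are used
COVER = [100, 60, 40, 10, 30]          # training samples reaching each node


def expected_value(x: List[float], known: FrozenSet[int], node: int = 0) -> float:
    """Expected tree output when only the features in `known` are observed."""
    if LEFT[node] == -1:
        return VALUE[node]
    left, right = LEFT[node], RIGHT[node]
    if FEATURE[node] in known:
        nxt = left if x[FEATURE[node]] <= THRESHOLD[node] else right
        return expected_value(x, known, nxt)
    # Unknown feature: average both children, weighted by their cover.
    return (COVER[left] * expected_value(x, known, left)
            + COVER[right] * expected_value(x, known, right)) / COVER[node]


def brute_force_shap(x: List[float]) -> List[float]:
    """Exact SHAP values by enumerating every subset: O(2^N) work per feature."""
    n = len(x)
    phi = [0.0] * n
    for i in range(n):
        others = [j for j in range(n) if j != i]
        for size in range(n):
            weight = factorial(size) * factorial(n - size - 1) / factorial(n)
            for subset in combinations(others, size):
                s = frozenset(subset)
                phi[i] += weight * (expected_value(x, s | {i}) - expected_value(x, s))
    return phi


x = [1.0, 1.0, 0.0]  # feature 2 is never used by the tree
phi = brute_force_shap(x)
base = expected_value(x, frozenset())
print(phi, base, base + sum(phi))  # approx. [1.65, 0.35, 0.0], 2.0, 4.0
```

Each feature requires visiting all 2^(N-1) subsets of the other features, so this loop becomes infeasible beyond a few dozen features; TreeSHAP obtains the same values in polynomial time by working along root-to-leaf paths instead.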

After looking into many TreeSHAP use cases, we found that despite its improved algorithmic complexity, computing SHAP values for a large sample size (e.g., tens of millions of samples) or a large model size (e.g., tree depth >= 8) remains a computational concern in practice. For example, we have seen in experiments that explaining 20 million samples for a random forest model with 400 trees and a maximum tree depth of 12 can take as long as 30 hours, even on a 50-core server. This is a problem because explaining (at least) tens of millions of samples is a common need for user-level predictive models in industry, e.g., feed ranking models, job search models, and subscription propensity models. Spending tens of hours on model interpretation becomes a significant bottleneck in these modeling pipelines:

  • It is likely to cause long delays in post-hoc model diagnosis via feature importance analysis, increasing both the risk of incorrect model implementations when features are misused and the risk of untimely model iterations.

  • It can lead to long waiting times in preparing actionable items for model end users (e.g., a marketing team using a subscription propensity model) based on feature reasoning, and as a result, end users may not take appropriate actions in a timely manner, which can negatively impact a company’s revenue.

It is worth noting that several initiatives scale SHAP/TreeSHAP computations by leveraging distributed or parallel computing mechanisms, e.g., Shparkley and PySpark-SHAP in Spark, and GPUTreeSHAP on GPUs. In the FastTreeSHAP package, we mainly focus on improving the computational complexity of the TreeSHAP algorithm itself, an improvement that can be further combined with these distributed and parallel computing mechanisms.

FastTreeSHAP algorithm

In the FastTreeSHAP package, we implement two new algorithms, FastTreeSHAP v1 and FastTreeSHAP v2, designed to improve the computational efficiency of TreeSHAP. From a series of evaluation studies, we empirically find that FastTreeSHAP v1 is 1.5x faster than TreeSHAP while keeping the memory cost unchanged, and FastTreeSHAP v2 is 2.5x faster than TreeSHAP, at the cost of a slightly higher memory usage.

Table 1 summarizes the time and space complexities of each variant of the TreeSHAP algorithm (M is the number of samples to be explained, N is the number of features, T is the number of trees, L is the maximum number of leaves in any tree, and D is the maximum depth of any tree). Note that although FastTreeSHAP v1 has the same time complexity as TreeSHAP, its (theoretical) average running time is reduced to 25% of TreeSHAP's. Also note that the time complexity of FastTreeSHAP v2 decomposes into two terms, of which only the second depends on the number of samples M; this second term is smaller than the time complexity of TreeSHAP and FastTreeSHAP v1 by a factor of D. We discuss the time and space complexities in more detail in the next few sections.


Table 1. Summary of computational complexities of TreeSHAP algorithms.

FastTreeSHAP v1
The key improvement in FastTreeSHAP v1 is to shrink the computation scope within the set of features. While TreeSHAP considers all features in each root-to-leaf path, FastTreeSHAP v1 only considers the features along the path whose split rules are satisfied by the sample. The split rules are defined by all the features along the path and their corresponding thresholds. One example of split rules along a root-to-leaf path is {x1 ≥ 0, x2 < 5, x3 ≥ 2}, where there are three features x1, x2, and x3 along this path. If a sample to be explained is (x1, x2, x3) = (2, 6, 0), then instead of considering all three features along this path as TreeSHAP does, FastTreeSHAP v1 only considers one feature, x1, since only x1 satisfies its split rule (x1 ≥ 0). On average, for a given sample to be explained, around half of the features along each root-to-leaf path satisfy the split rules. This reduces the constant associated with the tree depth D to 50%, and since the time complexity O(MTLD^2) is quadratic in D, it reduces the constant of the time complexity of FastTreeSHAP v1 to 25% (0.5^2 = 0.25).
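The split-rule filter from the example above can be sketched in a few lines of Python. The path encoding (feature name, comparison operator, threshold) and the function name are hypothetical illustrations, not the package's internal data structures, and the sketch shows only which features FastTreeSHAP v1 keeps, not the SHAP computation that follows.

```python
import operator
from typing import Dict, List, Tuple

# Hypothetical encoding of one root-to-leaf path: (feature, comparison, threshold).
SplitRule = Tuple[str, str, float]
COMPARE = {">=": operator.ge, ">": operator.gt, "<": operator.lt, "<=": operator.le}


def satisfied_features(path: List[SplitRule], sample: Dict[str, float]) -> List[str]:
    """Keep only the features whose split rule on this path the sample satisfies."""
    return [feat for feat, op, thr in path if COMPARE[op](sample[feat], thr)]


path = [("x1", ">=", 0.0), ("x2", "<", 5.0), ("x3", ">=", 2.0)]
sample = {"x1": 2.0, "x2": 6.0, "x3": 0.0}
print(satisfied_features(path, sample))  # ['x1']

# If about half of the D features on a path survive, a D^2 cost term shrinks by (1/2)^2.
D = 12
print((D / 2) ** 2 / D ** 2)  # 0.25
```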

FastTreeSHAP v2
The general idea of FastTreeSHAP v2 is to trade space complexity for time complexity. It is motivated by the observation that the most expensive TreeSHAP step, which calculates the weighted sum of the proportions of all feature subsets that flow down into each leaf node, actually produces replicated outcomes across samples (more details in the original paper). Based on this, we split the algorithm into two parts: Part I, FastTreeSHAP-Prep, pre-computes all possible outcomes of this expensive TreeSHAP step and stores them in a matrix of size L × 2^D. Then Part II, FastTreeSHAP-Score, calculates SHAP values for incoming samples by looking up the pre-computed matrix. These two parts contribute the two terms O(TL2^D D) and O(MTLD), respectively, to the time complexity of FastTreeSHAP v2. Only the second term depends on the number of samples M, and it is a factor of D smaller than the time complexity of TreeSHAP and FastTreeSHAP v1. The space complexity of FastTreeSHAP v2 is dominated by the pre-computed matrix, which is O(L2^D).
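The prep/score split can be illustrated with a toy Python sketch. It assumes a simplified path encoding where each split rule (feature index, threshold) is satisfied when x[feature] >= threshold, and it replaces TreeSHAP's weighted-sum step with a hypothetical placeholder function `toy_expensive`; the real pre-computed matrix stores different quantities. What the sketch does show is the trade-off: the expensive step runs once per leaf for each of the 2^d satisfied/unsatisfied patterns, and scoring any number of samples afterwards reduces to table lookups.

```python
from itertools import product
from typing import Callable, Dict, List, Tuple

Rule = Tuple[int, float]    # hypothetical: (feature index, threshold), satisfied if x[f] >= threshold
Pattern = Tuple[bool, ...]  # which rules on a path a sample satisfies


def split_pattern(path: List[Rule], x: List[float]) -> Pattern:
    return tuple(x[f] >= thr for f, thr in path)


def prep(paths: List[List[Rule]], expensive: Callable[[Pattern], float]) -> List[Dict[Pattern, float]]:
    """Part I (sketch): evaluate the expensive step for all 2^d patterns of each leaf's path."""
    return [{p: expensive(p) for p in product((False, True), repeat=len(path))} for path in paths]


def score(paths: List[List[Rule]], tables: List[Dict[Pattern, float]],
          samples: List[List[float]]) -> List[List[float]]:
    """Part II (sketch): each (sample, leaf) pair becomes a dictionary lookup."""
    return [[table[split_pattern(path, x)] for path, table in zip(paths, tables)] for x in samples]


calls = {"n": 0}


def toy_expensive(pattern: Pattern) -> float:
    """Placeholder for TreeSHAP's weighted-sum step; counts how often it runs."""
    calls["n"] += 1
    return float(sum(pattern))


paths = [[(0, 0.0), (1, 5.0)], [(0, 0.0), (2, 2.0)]]  # two leaves, each at depth 2
tables = prep(paths, toy_expensive)                    # 2 leaves x 2^2 patterns = 8 calls
scores = score(paths, tables, [[2.0, 6.0, 0.0], [-1.0, 1.0, 3.0]])
print(calls["n"], scores)  # 8 [[2.0, 1.0], [0.0, 1.0]]
```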

FastTreeSHAP comparison
In summary, FastTreeSHAP v1 strictly outperforms TreeSHAP. FastTreeSHAP v2 outperforms FastTreeSHAP v1 when the number of samples is sufficiently large (M > 2^(D+1)/D), which commonly occurs in moderate-sized datasets, e.g., M > 57 when D = 8, M > 630 when D = 12, and M > 7710 when D = 16 (most tree-based models produce trees with depth ≤ 16). FastTreeSHAP v2 also has a stricter memory constraint, O(L2^D) < memory tolerance, but this constraint is quite loose in practice (as shown in Table 3 and Table 4 in the next section). Both FastTreeSHAP v1 and FastTreeSHAP v2 produce exactly the same SHAP values as TreeSHAP.

FastTreeSHAP implementation

In the FastTreeSHAP package, we have fully enabled parallel computing to further speed up its computation. We have also designed a flexible and intuitive user interface for the FastTreeSHAP package.

Parallel computing performance
Parallel computing via OpenMP is implemented in the FastTreeSHAP package. By comparison, the SHAP package does not enable parallel computing except when interpreting XGBoost, LightGBM, and CatBoost models, for which it directly calls the TreeSHAP functions in those three packages, each with its own parallel computing implementation.

Implementing parallel computing is straightforward for FastTreeSHAP v1 and the original TreeSHAP: a parallel for-loop is built over all samples. Implementing it for FastTreeSHAP v2 is more complicated, and two versions have been implemented. Version I builds a parallel for-loop over all trees, which requires (MN + L2^D) · C · 8B of memory (C is the number of threads; each thread has its own matrices to store both SHAP values and pre-computed values). Version II builds two consecutive parallel for-loops, over all trees and over all samples respectively, which requires TL2^D · 8B of memory (the first parallel for-loop stores pre-computed values across all trees). In the FastTreeSHAP package, we have added logic to automatically choose the appropriate version by default. Version I is selected for FastTreeSHAP v2 as long as its memory constraint is satisfied. If not, Version II is selected as long as its memory constraint is satisfied. If neither memory constraint is satisfied, FastTreeSHAP v1 is chosen instead of FastTreeSHAP v2 due to its lower memory usage.
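The version-selection rule described above can be restated as a small Python function. This is a hypothetical sketch, not the package's code: the function name and the `memory_budget_bytes` parameter are assumptions (the paragraph does not say which memory limit is used), and the 8 bytes per value reflects the 8B factor in the formulas.

```python
BYTES_PER_VALUE = 8  # the "8B" factor: one double-precision float per stored value


def choose_parallel_strategy(M: int, N: int, T: int, L: int, D: int, C: int,
                             memory_budget_bytes: float) -> str:
    """Hypothetical restatement of the selection logic.

    M samples, N features, T trees, L max leaves per tree, D max depth, C threads.
    """
    # Version I: every thread holds its own SHAP matrix (M x N) and pre-computed matrix (L x 2^D).
    version_1 = (M * N + L * 2 ** D) * C * BYTES_PER_VALUE
    # Version II: pre-computed matrices for all T trees are held at once.
    version_2 = T * L * 2 ** D * BYTES_PER_VALUE
    if version_1 < memory_budget_bytes:
        return "v2, Version I"
    if version_2 < memory_budget_bytes:
        return "v2, Version II"
    return "v1"  # fall back to the algorithm with the lowest memory usage


budget = 0.25 * 8 * 2 ** 30  # assumed budget: a quarter of 8 GiB
print(choose_parallel_strategy(10_000, 64, 500, 256, 8, 8, budget))        # v2, Version I
print(choose_parallel_strategy(10_000_000, 64, 500, 256, 8, 8, budget))    # v2, Version II
print(choose_parallel_strategy(10_000_000, 64, 500, 4096, 12, 8, budget))  # v1
```

Checking Version I first favors the per-thread layout whenever it fits; Version II only becomes attractive when M · N is large enough that per-thread SHAP matrices dominate memory.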

We compare the execution times of FastTreeSHAP v1 and FastTreeSHAP v2 in the FastTreeSHAP package against the TreeSHAP algorithm in the SHAP package (or the TreeSHAP algorithm in the XGBoost and LightGBM packages when interpreting those two models) on two public datasets, Adult and Superconductor. Table 2 lists basic information about these two datasets. For each dataset, we train two scikit-learn random forest models, two XGBoost models, and two LightGBM models, fixing the number of trees at 500 and setting the maximum tree depth to 8 and 12, respectively. Other hyperparameters are left at their defaults. All evaluations were run in parallel on eight cores of an Azure Virtual Machine of size Standard_D8_v3 (eight cores and 32GB memory), except for scikit-learn models in the SHAP package, which can only run on a single core. We ran each evaluation on 10,000 samples. Figure 2 and Figure 3 show the results averaged over three runs, and Table 3 and Table 4 quantify the speedups shown in these plots, underscoring the benefit of FastTreeSHAP acceleration.

Name | # Instances | # Attributes (original) | # Attributes (one-hot) | Task | Classes
Adult | 48,842 | 14 | 64 | Classification | 2
Superconductor | 21,263 | 81 | 81 | Regression | -

Table 2. Datasets.


Figure 2. TreeSHAP vs FastTreeSHAP v1 vs FastTreeSHAP v2 - Adult.


Figure 3. TreeSHAP vs FastTreeSHAP v1 vs FastTreeSHAP v2 - Superconductor.

Model | Tree depth | SHAP (s) | FastTreeSHAP v1 (s) | Speedup (v1) | FastTreeSHAP v2 (s) | Speedup (v2) | Memory cost (v2)
sklearn random forest | 8 | 318.44* | 43.89 | 7.26 | 27.06 | 11.77 | 82MB
sklearn random forest | 12 | 2446.12 | 293.75 | 8.33 | 158.93 | 15.39 | 280MB
XGBoost | 8 | 17.35** | 12.31 | 1.41 | 6.53 | 2.66 | 42MB
XGBoost | 12 | 62.19 | 40.31 | 1.54 | 21.34 | 2.91 | 153MB
LightGBM | 8 | 7.64*** | 7.20 | 1.06 | 3.24 | 2.36 | 40MB
LightGBM | 12 | 9.95 | 7.96 | 1.25 | 4.02 | 2.48 | 47MB

Table 3. TreeSHAP vs FastTreeSHAP v1 vs FastTreeSHAP v2 - Adult.

*Parallel computing is not enabled in the SHAP package for scikit-learn models, so the TreeSHAP algorithm runs on a single core.
**The SHAP package calls the TreeSHAP algorithm in the XGBoost package, which by default enables parallel computing on all cores.
***The SHAP package calls the TreeSHAP algorithm in the LightGBM package, which by default enables parallel computing on all cores.

Model | Tree depth | SHAP (s) | FastTreeSHAP v1 (s) | Speedup (v1) | FastTreeSHAP v2 (s) | Speedup (v2) | Memory cost (v2)
sklearn random forest | 8 | 466.04 | 58.28 | 8.00 | 36.56 | 12.75 | 54MB
sklearn random forest | 12 | 5282.52 | 585.85 | 9.02 | 370.09 | 14.27 | 435MB
XGBoost | 8 | 35.31 | 21.09 | 1.67 | 13.00 | 2.72 | 53MB
XGBoost | 12 | 152.23 | 82.46 | 1.85 | 51.47 | 2.96 | 271MB
LightGBM | 8 | 8.73 | 7.11 | 1.23 | 3.58 | 2.44 | 51MB
LightGBM | 12 | 14.02 | 11.14 | 1.26 | 4.81 | 2.91 | 58MB

Table 4. TreeSHAP vs FastTreeSHAP v1 vs FastTreeSHAP v2 - Superconductor.

In Table 3 and Table 4, we observe that in both datasets, FastTreeSHAP v1 and v2 significantly outperform TreeSHAP in the SHAP package for the scikit-learn random forest models, by ~8x and ~14x respectively, since parallel computing is not enabled in the SHAP package for scikit-learn models. Even for the XGBoost and LightGBM models, where parallel computing is enabled on all available cores by default, FastTreeSHAP v1 and v2 still outperform TreeSHAP in the XGBoost and LightGBM packages by ~1.5x and ~2.7x respectively. We also observe that although FastTreeSHAP v2 uses more memory than the other two algorithms in theory, the memory constraint is quite loose in practice: none of the memory costs in Table 3 and Table 4 (at most 435MB) would cause out-of-memory issues, even on an ordinary laptop.
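As a quick consistency check, the "Speedup" columns in Table 3 are simply the SHAP running time divided by the FastTreeSHAP running time. The short Python script below recomputes them from the reported timings; the row layout is our own.

```python
# Timings in seconds copied from Table 3 (Adult):
# (model, tree depth, SHAP, FastTreeSHAP v1, FastTreeSHAP v2)
TABLE3 = [
    ("sklearn random forest", 8, 318.44, 43.89, 27.06),
    ("sklearn random forest", 12, 2446.12, 293.75, 158.93),
    ("XGBoost", 8, 17.35, 12.31, 6.53),
    ("XGBoost", 12, 62.19, 40.31, 21.34),
    ("LightGBM", 8, 7.64, 7.20, 3.24),
    ("LightGBM", 12, 9.95, 7.96, 4.02),
]


def speedups(rows):
    """Speedup = TreeSHAP time / FastTreeSHAP time, rounded to two decimals."""
    return [
        (model, depth, round(shap_s / v1_s, 2), round(shap_s / v2_s, 2))
        for model, depth, shap_s, v1_s, v2_s in rows
    ]


for row in speedups(TABLE3):
    print(row)  # e.g. ('sklearn random forest', 8, 7.26, 11.77)
```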

User experience
The FastTreeSHAP package is built on the SHAP package, and its user interface is flexible and intuitive. The following snippet shows a typical example of how FastTreeSHAP works:

Note that the user interface of FastTreeSHAP is exactly the same as that of SHAP, except for three additional arguments in the class “TreeExplainer”: “algorithm”, “n_jobs”, and “shortcut”. Users already familiar with SHAP should find FastTreeSHAP easy to pick up.

  • “algorithm” determines the specific TreeSHAP algorithm to use. It can take the values "v0", "v1", "v2", or "auto", where the first three correspond to the original TreeSHAP, FastTreeSHAP v1, and FastTreeSHAP v2, respectively. Its default value is “auto”, which automatically selects among "v0", "v1", and "v2" according to the number of samples to be explained and the constraint on allocated memory. Specifically, "v1" is always preferred to "v0", and "v2" is preferred to "v1" when the number of samples to be explained is sufficiently large (M > 2^(D+1)/D) and the memory constraint is also satisfied (min{(MN + L2^D) · C, TL2^D} · 8B < 0.25 · Total Memory). The “auto” option makes the FastTreeSHAP package easy to use without manual algorithm selection.

  • “n_jobs” specifies the number of parallel threads. Its default value is “-1”, which means utilizing all available cores.

  • “shortcut” determines whether to directly use the TreeSHAP implementations embedded in the XGBoost and LightGBM packages when computing SHAP values for these two models, and when computing SHAP interaction values for XGBoost models. Its default value is “False”, which means bypassing the “shortcut” and using the code in the FastTreeSHAP package to compute SHAP values for XGBoost and LightGBM models.
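The “auto” rule from the “algorithm” bullet can be written out as a hypothetical Python helper. It mirrors only the stated conditions (the sample-count threshold and the 25% memory rule); the function name and argument order are our own, and the package's actual selection code may differ in details.

```python
def auto_algorithm(M: int, N: int, T: int, L: int, D: int, C: int,
                   total_memory_bytes: float) -> str:
    """Hypothetical sketch of the "auto" rule; "v0" is never chosen since "v1" always beats it.

    M samples, N features, T trees, L max leaves, D max depth, C threads.
    """
    enough_samples = M > 2 ** (D + 1) / D
    # Memory of the cheaper FastTreeSHAP v2 parallel version, at 8 bytes per value.
    v2_bytes = min((M * N + L * 2 ** D) * C, T * L * 2 ** D) * 8
    if enough_samples and v2_bytes < 0.25 * total_memory_bytes:
        return "v2"
    return "v1"


total = 8 * 2 ** 30  # e.g. a machine with 8 GiB of memory
print(auto_algorithm(10_000, 64, 500, 256, 8, 8, total))        # v2
print(auto_algorithm(50, 64, 500, 256, 8, 8, total))            # v1: too few samples
print(auto_algorithm(10_000_000, 64, 500, 4096, 12, 8, total))  # v1: memory rule fails
```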

For more detailed implementations of FastTreeSHAP, as well as in-depth comparisons between FastTreeSHAP v1, FastTreeSHAP v2, and the original TreeSHAP, check out these notebooks: Census Income, Superconductor, and Crop Mapping.

Figure 5 depicts the aggregated SHAP values over all testing samples (~16K) in dataset Adult. Whereas Figure 1 uses SHAP values to quantify feature contributions for each individual sample, the aggregated SHAP values in Figure 5 measure the overall impact and direction of each feature across all samples to be explained. Consider the features “education-num” and “never-married” as examples. In the legend on the right, red means a higher feature value and blue means a lower feature value, so the red portion of “education-num” represents higher levels of education. The red portion lying on the positive x-axis (i.e., positive SHAP values) suggests that people with higher levels of education are more likely to make over $50K a year, which makes intuitive sense. Similarly, the blue portion of “education-num” on the negative x-axis (i.e., negative SHAP values) indicates that a lower level of education reduces the likelihood of making over $50K a year. For the feature “never-married”, we have a similar yet opposite observation: a positive value of the binary feature “never-married” (red portion) leads to negative SHAP values, suggesting that people who have never married are less likely to make over $50K a year.
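Plots like Figure 5 are built from the full SHAP matrix (samples × features). As a hedged sketch, the NumPy helper below ranks features by mean absolute SHAP value, one common way to order features in such plots, and uses the correlation between a feature's values and its SHAP values as a rough indicator of direction. The toy matrices reuse the feature names “education-num” and “never-married” purely for illustration; the numbers are made up and are not from the Adult dataset.

```python
import numpy as np


def summarize_shap(shap_values: np.ndarray, X: np.ndarray, feature_names):
    """Rank features by mean |SHAP| and report the sign of the value/SHAP relationship."""
    importance = np.abs(shap_values).mean(axis=0)
    # Pearson correlation per column; assumes no column of X or of the SHAP matrix is constant.
    direction = np.array(
        [np.corrcoef(X[:, j], shap_values[:, j])[0, 1] for j in range(X.shape[1])]
    )
    order = np.argsort(-importance)
    return [(feature_names[j], float(importance[j]), float(direction[j])) for j in order]


# Made-up toy data: column 0 mimics "education-num", column 1 the binary "never-married".
X = np.array([[9.0, 1.0], [11.0, 0.0], [13.0, 1.0], [15.0, 0.0]])
shap_values = np.array([[-3.0, -0.25], [-1.0, 0.25], [1.0, -0.25], [3.0, 0.25]])

summary = summarize_shap(shap_values, X, ["education-num", "never-married"])
for name, imp, corr in summary:
    print(f"{name}: mean|SHAP|={imp:.2f}, direction={corr:+.2f}")
```

A positive direction means higher feature values go with higher SHAP values, as for “education-num” in Figure 5; a negative one matches the “never-married” pattern.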


Figure 5. Aggregated SHAP values (over 16K samples) in dataset Adult. Red means the feature value is higher and blue means the feature value is lower. A more positive SHAP value indicates a larger contribution to the positive class (making over $50K a year) and vice versa.

Conclusion

TreeSHAP has been widely used for explaining tree-based models due to its desirable theoretical properties and polynomial computational complexity. Our FastTreeSHAP package implements FastTreeSHAP v1 and FastTreeSHAP v2, two new algorithms that further improve the computational efficiency of TreeSHAP, with an emphasis on explaining large numbers of samples. Moreover, the FastTreeSHAP package enables parallel computing to further improve its computational speed, and it provides a flexible and intuitive user interface.

The current version of the FastTreeSHAP package supports one-time usage scenarios (explaining all samples once), and we are working on extending it, with parallel computing, to multi-time usage scenarios (having a stable model in the backend and receiving new scoring data to be explained on a regular basis). Preliminary results of evaluation studies in our FastTreeSHAP paper show that FastTreeSHAP v2 can achieve more than 3x faster explanation in multi-time usage scenarios. Another future direction is to implement the FastTreeSHAP package in Spark to further scale TreeSHAP computations by leveraging distributed computing mechanisms.

Acknowledgements

The FastTreeSHAP package was developed by the Data Science Applied Research team at LinkedIn. Special thanks to Diana Negoescu, Wenrong Zeng, Kinjal Basu, Shaunak Chatterjee, Rupesh Gupta, Jon Adams, Hannah Sills, Kayla Guglielmo, Greg Earl, and Fred Han for their helpful comments and feedback. We also thank our management team Rahul Tandra, Sofus Macskássy, Romer Rosales, and Ya Xu for their continuous encouragement and support.