Identifying churn drivers with Random Forests

At RetainKit, we aim to tackle the challenging problem of churn at SaaS companies by using AI and machine learning. If you run a SaaS company and you have churn issues, we’d be happy to talk to you and see if our product could help. You can also follow us on Product Hunt Upcoming.


In the early days of Post Planner (my previous startup), everything was going fine, except that it wasn’t. We had built a product that solved a problem, or so we thought. Customers were signing up in numbers. We had even reached a break-even point, something crucial for a bootstrapped company. And yet we weren’t growing.

Missing the forest for the trees

Like most companies focused on growth, growth, growth, we had made a rookie mistake. We cared only about the number of new customers coming in (acquisition) and paid little attention to the customers leaving (churn).

It was a problem because we didn’t know why people were leaving — was Post Planner too expensive for them? Did we truly solve their problem? Or was it something else we were unaware of? Also, the more customers we got through the front door, the more people left. It felt like we had a very leaky bucket.

At Post Planner, our focus eventually shifted to balancing acquisition with churn reduction. And with great effort, we managed to climb out of that particular hole.

However, the importance of customer retention and the cost of ignoring it stuck with me. Now, years later, I’m working to help SaaS companies minimize churn using the power of Machine Learning.

My new project, RetainKit, uses product usage and customer data to predict which customers might leave soon, and why.
In this blog post, I’m going to outline some of the techniques we use in RetainKit to analyze the “Why” of churn.

Random Forests

One of the approaches we use for this is the Random Forest (RF). RF is an ensemble method: it builds many decision trees and combines their predictions by averaging (for regression) or voting (for classification).
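As a minimal illustration, here is how a Random Forest classifier might be trained with scikit-learn. The synthetic "churn" dataset and the hyperparameters below are purely illustrative:

```python
# A minimal sketch: training a Random Forest classifier on synthetic data.
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier

# Stand-in for churn data: 1000 customers, 8 usage features
X, y = make_classification(n_samples=1000, n_features=8,
                           n_informative=4, random_state=42)

# An ensemble of 100 decision trees; each tree votes on the class
forest = RandomForestClassifier(n_estimators=100, random_state=42)
forest.fit(X, y)

print(forest.score(X, y))  # accuracy on the training data
```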

This post can be viewed as a review of the main approaches to estimating feature importance in Random Forest (RF).

A further review of feature importance methods (including ones for Gradient Boosted Trees) by CeShine Lee is also available.

Feature Importance

One of the great advantages of Decision Tree-based models is their interpretability. That is, we can understand the reasoning behind each prediction. Using such a white-box model also gives us the ability to reason about which features were most important. Let’s take a look at the main methods of computing these:

Mean Decrease Impurity (MDI)

TL;DR: The average decrease in error (impurity) contributed by each feature, summed over all splits that use it.

The authors of “Elements of Statistical Learning” explain the approach quite succinctly:

“At each split in each tree, the improvement in the split-criterion is the importance measure attributed to the splitting variable, and is accumulated over all the trees in the forest separately for each variable.”

Let’s elaborate on what this means. As we know, a decision tree is a series of nodes (splits), each of which divides the samples that reach it into two groups. Each split uses a single feature (“the splitting variable”) to do so, and that feature is chosen so that it minimizes the error (“the split-criterion”). The error can be Mean Squared Error, Gini impurity, information gain, or something else, depending on your particular model. To get a feature’s importance, we sum the decrease in error over all splits on that variable across all trees.

In the scikit-learn package, the improvement each split brings is weighted by the number of samples that reach the node. The feature importances are then normalized.
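In code, these normalized MDI importances appear as the `feature_importances_` attribute after fitting. A small sketch on synthetic data (names and parameters are illustrative):

```python
# MDI (impurity-based) importances in scikit-learn: exposed as
# `feature_importances_` after fitting, normalized to sum to 1.
import numpy as np
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier

X, y = make_classification(n_samples=500, n_features=6,
                           n_informative=3, random_state=0)
forest = RandomForestClassifier(n_estimators=50, random_state=0).fit(X, y)

importances = forest.feature_importances_  # one value per feature
print(np.round(importances, 3))
```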

It’s worth noting that this approach tends to overestimate the importance of features with many categories (high cardinality). An alternative approach that corrects the bias of MDI is described here.

Mean Decrease Accuracy (MDA)

TL;DR: Measure the decrease in accuracy when a given feature is replaced with a shuffled version of itself.

This clever method computes importance using out-of-bag (OOB) data. The OOB data for a tree is the portion of the training data that was not sampled to train that particular tree.

First, the baseline error of the tree on its OOB data is computed. Then, for each feature, its values are randomly shuffled across the samples. Effectively, this replaces the variable with random data drawn from the same distribution, destroying any knowledge the tree has about the feature.

If the variable was an important one, the performance of the tree will decrease compared to the baseline. This decrease becomes the importance score for the shuffled feature. Each feature’s importance is then averaged across all trees.

While this method was long absent from scikit-learn (recent versions now provide it as sklearn.inspection.permutation_importance), it’s relatively simple to implement yourself.
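Here is a sketch of one common practical variant, which shuffles features on a held-out set for the whole forest rather than on per-tree OOB samples (the dataset and parameters are illustrative):

```python
# A simplified sketch of permutation (Mean Decrease Accuracy) importance.
# For brevity, features are shuffled on a held-out set for the whole
# forest, rather than on each tree's own OOB samples.
import numpy as np
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split

X, y = make_classification(n_samples=1000, n_features=6,
                           n_informative=3, random_state=1)
X_tr, X_te, y_tr, y_te = train_test_split(X, y, random_state=1)

forest = RandomForestClassifier(n_estimators=100, random_state=1).fit(X_tr, y_tr)
baseline = forest.score(X_te, y_te)

rng = np.random.default_rng(1)
importances = []
for j in range(X_te.shape[1]):
    X_perm = X_te.copy()
    rng.shuffle(X_perm[:, j])  # destroy the feature's signal in place
    # Importance = drop in accuracy caused by scrambling this feature
    importances.append(baseline - forest.score(X_perm, y_te))

print(np.round(importances, 3))
```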


Boruta

TL;DR: Iteratively remove features that perform worse than the best shuffled (shadow) features.

The main idea is to check which features are more important than random noise. To do this we create shadow variables, which are shuffled versions of all features. It’s like performing the variable shuffling described in “Mean Decrease Accuracy”, but for all variables at once. We combine the “shadowed” and the original features in a new dataset.

Then, we train a Random Forest on this new data. Using the MDA or MDI metric described above, we note which of the original, un-shuffled features are deemed more important than the best-performing shadow variables.

The importance metrics are more accurate when there are fewer irrelevant features, so the process above is repeated a predefined number of times or until a minimum feature count is reached.

The algorithm gradually removes features, starting with the least relevant ones. Thus, we can use the order of removal as a proxy for feature importance.
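To make the shadow-feature idea concrete, here is a deliberately simplified, single-round sketch. The real Boruta algorithm iterates and adds statistical testing on top of this; all names and parameters here are illustrative:

```python
# A single round of the shadow-feature idea behind Boruta (simplified:
# the full algorithm repeats this and applies statistical tests).
import numpy as np
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier

X, y = make_classification(n_samples=500, n_features=5,
                           n_informative=3, random_state=2)
rng = np.random.default_rng(2)

# Shadow features: a column-wise shuffled copy of every feature
shadow = np.apply_along_axis(rng.permutation, 0, X)
X_aug = np.hstack([X, shadow])

forest = RandomForestClassifier(n_estimators=200, random_state=2).fit(X_aug, y)
imp = forest.feature_importances_
best_shadow = imp[X.shape[1]:].max()  # strongest purely-random feature

# Keep only original features that beat the best shadow feature
kept = [j for j in range(X.shape[1]) if imp[j] > best_shadow]
print(kept)
```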

Boruta is an “all relevant” feature selection algorithm. This is subtly different than identifying the minimal set of features that give optimal prediction accuracy. As the authors of the method put it:

“…it tries to find all features carrying information usable for prediction, rather than finding a possibly compact subset of features on which some classifier has a minimal error.”

A Python implementation of Boruta is available, along with a detailed post explaining it.

Predicting churn is a very challenging problem, but it is a serious issue that hurts many businesses today. At RetainKit, we are working to solve it using the techniques outlined in this post, and much more.

If you are interested in learning more about our tech, please follow this blog.



Identifying churn drivers with Random Forests was originally published in Slav on Medium, where people are continuing the conversation by highlighting and responding to this story.
