In a previous post, I described how decision tree algorithms work and demonstrated their use via the rpart library in R. Decision trees work by splitting a dataset recursively. That is, subsets arising from a split are further split until a predetermined termination criterion is reached. At each step, a split is made based on the independent variable that results in the largest possible reduction in heterogeneity of the dependent variable.
(Note: readers unfamiliar with decision trees may want to read that post before proceeding)
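For readers who just want a quick reminder of what that looks like in code, the sketch below grows a single classification tree with rpart on the built-in iris dataset. This is purely illustrative and is not necessarily the example used in that post.

```r
library(rpart)

# Grow a classification tree: at each split, rpart picks the variable and
# cut point that most reduce the heterogeneity (impurity) of the class labels
fit <- rpart(Species ~ ., data = iris, method = "class")
print(fit)
```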
The main drawback of decision trees is that they are prone to overfitting. The reason for this is that trees, if grown deep, are able to fit all kinds of variations in the data, including noise. Although it is possible to address this partially by pruning, the result often remains less than satisfactory. This is because the algorithm makes a locally optimal choice at each split without any regard to whether that choice is the best one overall. A poor split made in the initial stages can thus doom the model, a problem that cannot be fixed by post-hoc pruning.
In this post I describe random forests, a tree-based algorithm that addresses the above shortcoming of decision trees. I’ll first describe the intuition behind the algorithm via an analogy and then do a demo using the R randomForest library.
Motivating random forests
One of the reasons for the popularity of decision trees is that they reflect the way humans make decisions: by weighing up options at each stage and choosing the best one available. The analogy is particularly useful because it also suggests how decision trees can be improved.
One of the lifelines in the game show, Who Wants to be A Millionaire, is “Ask The Audience”, wherein a contestant can ask the audience to vote on the answer to a question. The rationale here is that the majority response from a large number of independent decision makers is more likely to yield a correct answer than one from a randomly chosen person. There are two factors at play here:
- People have different experiences and will therefore draw upon different “data” to answer the question.
- People have different knowledge bases and preferences and will therefore draw upon different “variables” to make their choices at each stage in their decision process.
Taking a cue from the above, it seems reasonable to build many decision trees using:
- Different sets of training data.
- Randomly selected subsets of variables at each split of every decision tree.
Predictions can then be made by taking the majority vote over all trees (for classification problems) or averaging results over all trees (for regression problems). This is essentially how the random forest algorithm works.
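To make this concrete, here is a minimal sketch of the idea in R. It uses the randomForest library on the built-in iris dataset purely as an illustration; the detailed demo comes later in the post.

```r
library(randomForest)

set.seed(42)  # for reproducibility

# Grow 500 trees, each on a bootstrap sample of the data; at every split,
# only mtry randomly chosen predictors are considered as split candidates
rf <- randomForest(Species ~ ., data = iris, ntree = 500, mtry = 2)

# For classification, predict() returns the majority vote over all trees
predict(rf, newdata = iris[1:5, ])
```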
The net effect of the two strategies is to reduce overfitting by a) averaging over trees created from different samples of the dataset and b) decreasing the likelihood of a small set of strong predictors dominating the splits. The price paid is reduced interpretability as well as increased computational complexity. But then, there is no such thing as a free lunch.
The mechanics of the algorithm
Although we will not delve into the mathematical details of the algorithm, it is important to understand how the two points made above are implemented in it.
Bootstrap aggregating… and a (rather cool) error estimate
A key feature of the algorithm is the use of multiple datasets for training individual decision trees. This is done via a neat statistical trick called bootstrap aggregating (also called bagging).
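Before walking through the steps, here is a small illustration in R of what a single bootstrap sample looks like; the dataset is purely illustrative.

```r
set.seed(1)
n <- nrow(iris)                                    # 150 rows in the original dataset
boot_rows <- sample(n, size = n, replace = TRUE)   # draw n rows with replacement
boot_sample <- iris[boot_rows, ]                   # one bootstrapped training set

# On average, roughly a third of the original rows are never drawn --
# these "out of bag" rows are the basis of the error estimate
# mentioned in the heading above
mean(!(1:n %in% boot_rows))                        # roughly 0.37
```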
Here’s how bagging works: