by Dr Andy Corbett
Gradient Boosting
11. Auto-Correction in a Forest of Stumps
Gradient Boosting In Forests
Our journey has taken us to the most powerful out-of-the-box machine learning algorithm: gradient-boosted ensemble models. We've already seen one ensemble model, the random forest, which is an ensemble of decision trees. However, the next algorithm under review today regularly outperforms random forests on machine-learning league tables: gradient-boosted trees. This is a model always worth pulling out of your back pocket and applying first to new data problems. Why? Its performance and ease of use are simply unparalleled. In this tutorial we'll begin by gently unpacking the theory behind the algorithm.
- Introduce the notion of Gradient Boosted Trees.
- Explain terminology: 'weak learners', 'stumps', 'residuals', 'gradients'.
- Write down the mathematical trick that drives the algorithm.
- Walk through the algorithm in its most basic form.
From trees to stumps with autocorrection
Our journey has brought us from trees to forests and out the other side, into a field full of stumps. When building a random forest, we decided that its trees should be trained independently, and thus could also be trained simultaneously. The final answer was then formed by averaging over all the outputs in the forest. This method required us to grow trees in proportion to the size (resp. the square root of the size) of the data set for regression (resp. classification) problems.
Gradient-boosted trees are trained in sequence, each dependent on the previous tree: each tree learns the error in the previous prediction (the residual, in regression problems), rather than recreating the data itself. This means that if a data point is mispredicted by the model, the next tree will divert its attention there. At the same time, all correct predictions will have a small error and be grouped together around zero.
This is so effective that we are able to reduce the size of each tree to a stump: a decision tree with a single split. These small intermediate models, each learning residuals, are formally known as 'weak learners'. Furthermore, we aren't tied to using just tree-based models: the core concept works well with any model.
Walking through the algorithm with a visual example
Let's define the algorithm by example, using a simple quadratic curve. For comparison, let's first construct a random forest of stumps to see how it fares.
Our training data shall consist of 10 points on this curve. Let's set up an example. Our goal is to interpolate between these points with 100 test points.
The training target is then the corresponding list of ten numbers; a sketch of this setup is given below.
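As a concrete sketch, assuming the quadratic curve is $y = x^2$ on $[-1, 1]$ and the sample points are evenly spaced (both illustrative choices, not necessarily those used to produce the figures), the setup might look like this:

```python
import numpy as np

# Assumed quadratic curve: y = x^2 on [-1, 1] (illustrative choice).
def curve(x):
    return x ** 2

# 10 training points spread across the interval.
X_train = np.linspace(-1.0, 1.0, 10).reshape(-1, 1)
y_train = curve(X_train).ravel()

# 100 test points at which we wish to interpolate.
X_test = np.linspace(-1.0, 1.0, 100).reshape(-1, 1)
y_test = curve(X_test).ravel()

print(y_train)  # the training target: a list of ten numbers
```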
A random forest of stumps
First, let's construct our base comparative model: a random forest of stumps. We'll train each stump in the forest on bagged data (recall this definition), whereas the gradient-boosted tree shall be trained on the entire data set.
This leaves us with a list of models, the 'stumps', each of which makes predictions via a single split. To form the aggregate model (the random forest), we take the mean of the predictions, since this is a regression problem.
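A minimal sketch of this baseline, assuming four stumps and simple bootstrap resampling (both my choices for illustration, continuing from the data setup above), might be:

```python
import numpy as np
from sklearn.tree import DecisionTreeRegressor

rng = np.random.default_rng(0)
n_stumps = 4  # assumed forest size, matching the table below

stumps = []
for _ in range(n_stumps):
    # Bagging: resample the training set with replacement.
    idx = rng.integers(0, len(X_train), size=len(X_train))
    stump = DecisionTreeRegressor(max_depth=1)  # a tree with a single split
    stump.fit(X_train[idx], y_train[idx])
    stumps.append(stump)

# The random forest prediction is the mean of the stump predictions.
forest_pred = np.mean([s.predict(X_test) for s in stumps], axis=0)
```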
Figure 1. Using a collection of shallow trees, 'stumps', to form a random forest does not give enough flexibility to solve even simple problems.
This is not an impressive model. The locations of the splits impact the model drastically. The shape of the error function forces those splits toward the sides, leaving a void in the centre. Averaging a collection of stumps is simply too coarse for this data set.
Let's check the Mean-Squared Error (MSE) for each stump, as well as the forest.
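Continuing the sketch above (here I evaluate on the 100-point test grid; the exact numbers in the table below depend on the author's data and random seeds):

```python
from sklearn.metrics import mean_squared_error

for i, stump in enumerate(stumps, start=1):
    stump_mse = mean_squared_error(y_test, stump.predict(X_test))
    print(f"Stump {i} MSE: {stump_mse:.3f}")

forest_mse = mean_squared_error(y_test, forest_pred)
print(f"Forest (mean) MSE: {forest_mse:.3f}")
```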
|      | Stump 1 | Stump 2 | Stump 3 | Stump 4 | Mean  |
|------|---------|---------|---------|---------|-------|
| MSE: | 0.074   | 0.194   | 0.194   | 0.201   | 0.042 |
The errors of the individual stumps are fairly uniform, and almost five times as high as that of the averaged forest. We shall keep this metric in mind as we move to our next model.
Out of the woods with gradient boosting
Now let's walk through the gradient boosting algorithm. Each stump is dependent on the previous one, and hence on every stump before it. We shall denote our dataset by $(X, y)$, where we are trying to predict the target vector $y$ from the columns of $X$.
- Let's call the first stump $f_1$. This model shall be a straightforward decision tree, trained on the original set $(X, y)$.
- Compute the residual error $r_1 = y - f_1(X)$.
  💠 The MSE for $f_1$ is the mean of the squares of the entries of $r_1$. Keep this in mind for later.
- Now we train the next tree stump, $f_2$, on these residuals, so that its training data is $(X, r_1)$. We are now modelling the difference between the first prediction $f_1(X)$ and the target $y$.
- Note that $F_2(X) = f_1(X) + f_2(X)$. This shall be our ensemble model. And the next residual is formed by the difference, so that $r_2 = y - F_2(X) = r_1 - f_2(X)$.
- Now we can iterate this up to a given number of stumps ($N$, say), where the $n$-th stump $f_n$ is trained by using $X$ to predict the previous residuals $r_{n-1}$, at which stage the model is evaluated as $F_n(X) = f_1(X) + f_2(X) + \cdots + f_n(X)$. A sketch of this loop is given below.
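Here is a minimal sketch of that loop in the same setting as before (the depth-one trees, the number of stumps, and evaluating on the test grid are my illustrative assumptions):

```python
import numpy as np
from sklearn.tree import DecisionTreeRegressor
from sklearn.metrics import mean_squared_error

N = 4                      # number of stumps in the ensemble
residual = y_train.copy()  # the first stump is trained on the raw target y
boosted_stumps = []

F_train = np.zeros_like(y_train)  # F_n evaluated on the training points
F_test = np.zeros_like(y_test)    # F_n evaluated on the test grid

for n in range(1, N + 1):
    # Train the n-th stump to predict the current residuals.
    stump = DecisionTreeRegressor(max_depth=1)
    stump.fit(X_train, residual)
    boosted_stumps.append(stump)

    # F_n(X) = f_1(X) + ... + f_n(X)
    F_train += stump.predict(X_train)
    F_test += stump.predict(X_test)

    # r_n = y - F_n(X): the target for the next stump.
    residual = y_train - F_train

    print(f"{n} stump(s): MSE = {mean_squared_error(y_test, F_test):.3f}")
```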
In our example, $N = 4$, and we can compare the final model to the output of the random forest.
Figure 2. Applying Gradient Boosting to the stumps immediately shows far better convergence with the small number of parameters available.
By eye, it seems that the performance is much better (see the fourth solid line). To be sure, let's assess the performance numerically.
|     | 1 Stump | 2 Stumps | 3 Stumps | 4 Stumps |
|-----|---------|----------|----------|----------|
| MSE | 0.074   | 0.039    | 0.022    | 0.019    |
As expected, and in contrast to the independent random forest stumps, the MSE is decreasing. With 4 stumps, the reported MSE of 0.019 is less than half the MSE of the forest's mean prediction (0.042).
More general gradients

This construction is theoretically delightful. If we were to impose a squared-error loss on our model,

$$L(y, F(X)) = \frac{1}{2}\sum_{i} \big(y_i - F(x_i)\big)^2,$$
then with a gradient-boosted model $F_n$, the negative gradient of this loss with respect to the model's predictions is precisely the vector of residuals:

$$-\frac{\partial L}{\partial F_n(x_i)} = y_i - F_n(x_i) = (r_n)_i,$$
and so we could pick a small real number $\nu$, a learning rate, and perform the model update $F_{n+1}(X) = F_n(X) + \nu\, f_{n+1}(X)$. Since $f_{n+1}$ is trained to approximate the residuals, i.e. the negative gradient, this update, scaled by the learning rate, recovers the update formula used in gradient descent. And so this algorithm fits into a common optimisation framework. We can then also justify using more general loss functions on bespoke problems.
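In code, the only change to the boosting loop sketched earlier is to scale each new stump's contribution by the learning rate (the value of $\nu$ below is an arbitrary illustrative choice, and the variables carry over from the previous sketch):

```python
nu = 0.3  # learning rate (shrinkage); an illustrative value

F_train = np.zeros_like(y_train)
residual = y_train.copy()

for n in range(N):
    stump = DecisionTreeRegressor(max_depth=1)
    stump.fit(X_train, residual)            # f_{n+1} approximates the residuals, i.e. the negative gradient
    F_train += nu * stump.predict(X_train)  # F_{n+1} = F_n + nu * f_{n+1}
    residual = y_train - F_train            # the new residuals are the new negative gradient
```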