by Dr Andy Corbett

Lesson

Decision Trees

3. Machine-Designed Decision Trees

In this tutorial, we introduce the basis for a new family of models called tree-based algorithms. This will be our first encounter with a 'deep' model, but one whose depth may be meaningfully interpreted. On their own, decision trees are vulnerable to overfitting, but they are important modules in state-of-the-art algorithms such as the random forest.

📑 Learning Objectives
  • Examine the decision tree algorithm for machine-trained trees.
  • Demonstrate its functionality on a classification example.
  • Examine ways to quantify success of the algorithm during training.
  • Introduce the notion of decision trees for regression problems.
Machine Designed Trees

Figure 1. Following the decision trail of a machine-designed tree.

Machines that learn trees


Let's try to solve a more difficult problem--a classification problem--using the power of machine learning to optimise our decision tree splits. This shall lead us to an algorithm, and hence a model, that can be deployed on general datasets.

For this example, here's the data: We collect 300 samples, split equally amongst three generic classes: 'Gold', 'Blue' and 'Pink'.
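The lesson's exact dataset is not reproduced here, but a stand-in with the same shape is easy to sketch with scikit-learn's make_blobs; the cluster spread and random seeds below are assumptions chosen to give three overlapping classes.

```python
# A minimal stand-in for the sample data: 300 points in three overlapping
# clusters (playing the roles of 'Gold', 'Blue' and 'Pink').
from sklearn.datasets import make_blobs
from sklearn.model_selection import train_test_split

X, y = make_blobs(
    n_samples=300,      # 100 samples per class
    centers=3,          # three generic classes
    cluster_std=1.5,    # enough spread to overlap at the boundaries (assumed)
    random_state=42,    # assumed seed, for reproducibility
)

# Hold out a test set so we can evaluate generalisation later on.
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.25, random_state=42
)
```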

Classification Problem

Figure 2. Sample data.

The classes are gathered in clusters, albeit with some overlap at the boundaries. A good machine learning model would disregard this overlap as noise and still capture the overall trend in the data. Our data points come in the form $\mathbf{x} = [x_0, x_1]$ along the two axes of the graph. To fit a decision tree, we answer two questions at each split (a code sketch follows the list below):

  1. Which feature, $x_0$ or $x_1$, maximises the split in the data?
  2. What is the cut-off in that feature which maximises the data split?
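
Here is a hedged sketch of how these two questions are answered automatically with scikit-learn, capping the tree at two layers of decisions; the X_train and y_train arrays are assumed from the data sketch above.

```python
# Fit a decision tree limited to two layers of splits.
from sklearn.tree import DecisionTreeClassifier

clf = DecisionTreeClassifier(
    max_depth=2,        # apply the feature/cut-off reasoning at most twice per path
    criterion="gini",   # default splitting criterion (see 'Optimisation' below)
    random_state=0,     # assumed seed
)
clf.fit(X_train, y_train)
```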

Applying this reasoning twice, we obtain the following graph.

Three-node graph

Figure 3. Output of the Decision Tree Classifier.
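
To read the same graph off the fitted model (rather than from the figure), scikit-learn can print a text rendering of the tree; with the stand-in data above, the thresholds will differ from those in Fig. 3.

```python
# Print the learned tree: one line per node, showing the feature, the
# threshold, and (at the leaves) the predicted class.
from sklearn.tree import export_text

print(export_text(clf, feature_names=["x0", "x1"]))
```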

The top node--the 'root node'--indicates to first split the $x_1$ variable at $2.152$. This choice determines the greatest split of the data, with 87 samples having $x_1 < 2.152$ and the remaining 213 larger. The next best splitting of the data is to take the node with 213 samples and split the $x_0$ variable at $2.057$. Let us visualise these two decisions as lines splitting the data.

Three-split data

Figure 4. Contours of the decision tree predictor.

Intuitively, this seems like a good fit. We have constructed a tree with five nodes--two decision nodes and three leaf nodes--which is just two layers deep. So let's measure the accuracy of this model on both the training set and the held-out test set.
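
A quick way to produce numbers like those in the table below, assuming the clf and the train/test split defined in the earlier sketches (the exact figures depend on the data and seeds):

```python
# Accuracy of the depth-2 tree on the data it saw and on held-out data.
from sklearn.metrics import accuracy_score

train_acc = accuracy_score(y_train, clf.predict(X_train))
test_acc = accuracy_score(y_test, clf.predict(X_test))
print(f"Training accuracy: {train_acc:.1%}")
print(f"Test accuracy:     {test_acc:.1%}")
```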

Depth 2         Accuracy
Training set    88.7%
Test set        89.6%

😋 This method of optimisation is known as greedy optimisation: at a given node, we do not worry about the best answer with respect to the whole tree, simply how best we can split the data set at that point. This is a short-cutting technique, and the gain in speed typically far outweighs the cost in accuracy.
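
To make the greedy step concrete, here is a minimal sketch of the search performed at a single node: scan every candidate cut-off on every feature and keep the split with the lowest weighted Gini impurity. This is illustrative NumPy code, not scikit-learn's internals.

```python
import numpy as np

def gini(labels):
    """Gini impurity of a set of class labels."""
    _, counts = np.unique(labels, return_counts=True)
    p = counts / counts.sum()
    return 1.0 - np.sum(p ** 2)

def best_greedy_split(X, y):
    """Return (feature index, threshold, weighted impurity) of the best local split."""
    best = (None, None, np.inf)
    n = len(y)
    for j in range(X.shape[1]):              # which feature?
        for t in np.unique(X[:, j]):         # which cut-off?
            left, right = y[X[:, j] <= t], y[X[:, j] > t]
            if len(left) == 0 or len(right) == 0:
                continue                     # not a genuine split
            score = (len(left) * gini(left) + len(right) * gini(right)) / n
            if score < best[2]:
                best = (j, t, score)
    return best

# Example usage, assuming the X_train/y_train arrays from the earlier sketch.
feature, threshold, impurity = best_greedy_split(X_train, y_train)
```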

The error recorded here reflects the inherent noise in the data. Visually, we can see a trend of three clusters, so this model is the correct choice. But what would happen if we allowed the tree to split further, until the entire training set was correctly classified?

Over-fit data

Figure 5. Overfitting can creep in when hyperparameters permit too much flexibility.

Now we have begun to 'overfit'. Here we let the tree grow until it classified each data point perfectly, so that accuracy on the training set reached 100%. But by inspection, we can't expect this tree to describe unseen data--and this point is demonstrated when the performance is evaluated on the test set.
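
A sketch of the overfitting experiment, again assuming the stand-in data from above: raise the depth limit so the tree keeps splitting until the training set is classified (close to) perfectly.

```python
# A much deeper tree: enough freedom to memorise the training data.
from sklearn.tree import DecisionTreeClassifier

deep_clf = DecisionTreeClassifier(max_depth=10, random_state=0)
deep_clf.fit(X_train, y_train)

print(f"Training accuracy: {deep_clf.score(X_train, y_train):.1%}")
print(f"Test accuracy:     {deep_clf.score(X_test, y_test):.1%}")
```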

Depth 10        Accuracy
Training set    100.0%
Test set        83.3%

What can we conclude? Decision trees are powerful tools for classification, but without user intervention--such as limiting the model's freedom (the depth), as we did here--they run the risk of overfitting the training data, producing a model which is not useful. Hang on to your seats as we visualise the graph of the tree in the headline Fig. 1.

Optimisation

What mechanisms can we use to automatically optimise these trees? Our objective is to sort as many samples as possible into their correct classes. For classification, this could be achieved with one of the following criteria (sketched in code after the lists below):

  • Gini impurity: a measure of how often a randomly chosen data point would be misclassified if it were labelled at random according to the class proportions in a split.
  • Entropy: a measure of disorder, which in this case penalises splits in which all classes are equally well represented.

Or in the case of regression:

  • The Mean Squared Error: since regression problems have no classes, we penalise splits by the squared error between the ground-truth values and the mean of the data in each split.
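
As a rough sketch of these three criteria (illustrative NumPy versions, not scikit-learn's internals), each scores a single candidate split; a tree-builder would minimise the weighted sum of the scores over the two halves of each split.

```python
import numpy as np

def gini_impurity(labels):
    """How often a random point would be misclassified if labelled at random
    according to the class proportions in this split."""
    _, counts = np.unique(labels, return_counts=True)
    p = counts / counts.sum()
    return 1.0 - np.sum(p ** 2)

def entropy(labels):
    """Shannon entropy: largest when all classes are equally represented."""
    _, counts = np.unique(labels, return_counts=True)
    p = counts / counts.sum()
    return -np.sum(p * np.log2(p))

def mse_criterion(targets):
    """Regression criterion: mean squared error of predicting the split's mean."""
    return np.mean((targets - np.mean(targets)) ** 2)
```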