by Prof Tim Dodwell
Generalised Linear Models
9. From Regression to Classification - Logistic Regression
In this explainer we look at how we can "generalise" linear models to work for classification problems. This is called logistic regression. By the end of this explainer you should:
- Understand the key idea of logistic regression.
- Understand its make-up (e.g. linear predictor + link function).
- Understand how logistic models are trained (e.g. loss function + gradient descent).
- Know how logistic regression can be extended to multiple classes.
- Know about the issue of overfitting, and its solution (regularisation).
Logistic regression is a statistical model which is used for classification problems. Logistic regression estimates the probability that an event will occur, hence the output of the model is between 0 and 1.
So we have a supervised learning problem, with our normal data set
$$\mathcal{D} = \{(\mathbf{x}_i, y_i)\}_{i=1}^{N},$$
where $\mathbf{x}_i$ are our inputs and $y_i$ are our target variables, in this case a binary label, $y_i \in \{0, 1\}$, where we refer to $y = 1$ as label A and $y = 0$ as label B.
Instead of fitting a straight line or hyper-plane (or any linear model), the logistic regression model uses the logistic function to squeeze the output of a linear model between $0$ and $1$, so that it represents the probability of giving label A (i.e. $y = 1$).
The logistic function is defined as:
$$\sigma(z) = \frac{1}{1 + e^{-z}}.$$
It is probably easier if we sketch this function out, so let's do that.
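As a minimal sketch (assuming numpy and matplotlib, which are my choices here rather than anything prescribed in the original), we can plot the logistic function:

```python
import numpy as np
import matplotlib.pyplot as plt

def sigmoid(z):
    """Logistic (sigmoid) function: squashes any real z into (0, 1)."""
    return 1.0 / (1.0 + np.exp(-z))

z = np.linspace(-8, 8, 200)
plt.plot(z, sigmoid(z))
plt.axhline(0.5, linestyle="--", color="grey")  # the 0.5 decision threshold
plt.xlabel("z (output of the linear model)")
plt.ylabel("sigma(z)")
plt.title("The logistic (sigmoid) function")
plt.show()
```

The curve is S-shaped: large negative inputs map close to 0, large positive inputs map close to 1, and $\sigma(0) = 0.5$.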
The point here is that the input to the sigmoid (logistic) function is the output of a linear model itself, $z = f(\mathbf{x})$. As we will see, this gives great flexibility.
To give an interpretation of this, we can rearrange the equation in terms of $f(\mathbf{x})$, the linear model, so we get
$$f(\mathbf{x}) = \log\left(\frac{p}{1 - p}\right), \qquad \text{where } p = \sigma\big(f(\mathbf{x})\big).$$
So $\frac{p}{1-p}$ is the 'odds': the ratio of the probability of label A to not A (i.e. B). The interpretation of a logistic model is therefore that it builds a linear model for the 'log odds'.
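To make the 'log odds' interpretation concrete, here is a quick numerical check (the probability value 0.8 is an arbitrary choice for illustration):

```python
import numpy as np

p = 0.8                    # probability of label A (arbitrary example value)
odds = p / (1 - p)         # odds of A versus not-A: 4.0
log_odds = np.log(odds)    # this is the quantity the linear model f(x) predicts
print(odds, log_odds)      # 4.0  ~1.386

# Pushing the log odds back through the sigmoid recovers the probability:
print(1 / (1 + np.exp(-log_odds)))   # ~0.8
```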
Before moving on to how we train a logistic regression model, let us sketch some toy examples.
Here we look at a simple linear model $f(x) = w_0 + w_1 x$; this function is then pushed through the sigmoid function, generating the probability of label A against input values $x$. In this example the decision boundary is where $f(x) = 0$, i.e. where the predicted probability crosses $0.5$.
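A sketch of this kind of toy example (the weight values below are made up for illustration, not the ones behind the original figure):

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

w0, w1 = -2.0, 1.0              # hypothetical weights for f(x) = w0 + w1 * x
x = np.linspace(-2.0, 6.0, 9)
p = sigmoid(w0 + w1 * x)        # probability of label A at each input x

# The decision boundary sits where f(x) = 0, i.e. p = 0.5: here x = -w0/w1 = 2
print(np.round(p, 3))
print("boundary at x =", -w0 / w1)
```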
As discussed, a linear model does not mean a linear function of the inputs, so with logistic regression models there is the possibility of great flexibility. Here is an example where $f(x)$ is a quadratic function of $x$, resulting in a more complex decision boundary, as in the sketch below.
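A hedged sketch of how such a model could be fitted with scikit-learn, using quadratic basis functions (the data below is synthetic and generated purely for illustration):

```python
import numpy as np
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import PolynomialFeatures
from sklearn.linear_model import LogisticRegression

# Synthetic 1D data: label A (y = 1) only inside an interval, so no single
# straight-line boundary separates the classes, but a quadratic f(x) can.
rng = np.random.default_rng(0)
x = rng.uniform(-3, 3, size=(200, 1))
y = ((x[:, 0] > -1) & (x[:, 0] < 1)).astype(int)

# Quadratic basis functions phi(x) = [1, x, x^2] feeding a logistic model.
model = make_pipeline(PolynomialFeatures(degree=2), LogisticRegression())
model.fit(x, y)

print(model.predict_proba([[0.0], [2.5]]))  # high p(A) inside, low outside
```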
Before looking at training a logistic model, we note that logistic regression is a general extension of ordinary linear models. There are two ingredients:
- A linear predictor, just as in a standard linear model, $f(\mathbf{x}) = \sum_j w_j \phi_j(\mathbf{x}) = \mathbf{w}^\top \boldsymbol{\phi}(\mathbf{x})$.
- A link function, in our case the logit (whose inverse is the sigmoid), but it could be more general. The link function provides the relationship between the linear predictor and the mean of the distribution being modelled.
Training a Logistic Model
For a logistic regression problem we can use a cross-entropy loss (the binary cross-entropy, since there are two classes), which is given by
$$\mathcal{L}(\mathbf{w}) = -\sum_{i=1}^{N} \Big[ y_i \log p_i + (1 - y_i) \log(1 - p_i) \Big], \qquad p_i = \sigma\big(f(\mathbf{x}_i)\big).$$
The optimal weights $\mathbf{w}^*$ can then be found via a gradient-based optimisation scheme (e.g. steepest descent, Newton or quasi-Newton).
The gradients can be calculated using the chain rule. We don't do the full calculation here, but schematically
$$\frac{\partial \mathcal{L}}{\partial w_j} = \frac{\partial \mathcal{L}}{\partial f}\,\frac{\partial f}{\partial w_j}.$$
We note that our models are linear with respect to their weights, so if we differentiate $f(\mathbf{x}) = \sum_j w_j \phi_j(\mathbf{x})$ with respect to $w_j$ we simply get the basis functions,
$$\frac{\partial f}{\partial w_j} = \phi_j(\mathbf{x}).$$
This leaves us to calculate $\partial \mathcal{L}/\partial f$. The loss is just the loss for each sample added up, so per sample we have
$$\frac{\partial \mathcal{L}}{\partial f(\mathbf{x}_i)} = -\left[\frac{y_i}{p_i} - \frac{1-y_i}{1-p_i}\right]\frac{\partial p_i}{\partial f(\mathbf{x}_i)}.$$
The last part requires a bit of manipulation, but remember that
$$\frac{\mathrm{d}\sigma}{\mathrm{d}z} = \sigma(z)\,\big(1-\sigma(z)\big), \qquad \text{so} \qquad \frac{\partial p_i}{\partial f(\mathbf{x}_i)} = p_i(1-p_i).$$
Putting the pieces together gives the compact result
$$\frac{\partial \mathcal{L}}{\partial w_j} = \sum_{i=1}^{N}\big(p_i - y_i\big)\,\phi_j(\mathbf{x}_i).$$
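Putting the loss and its gradient together, here is a minimal steepest-descent sketch in numpy (the toy data, basis functions, step size and iteration count are all illustrative choices, not prescriptions):

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def cross_entropy(y, p):
    """Binary cross-entropy, summed over the samples."""
    return -np.sum(y * np.log(p) + (1 - y) * np.log(1 - p))

# Toy data with simple basis functions phi(x) = [1, x]
rng = np.random.default_rng(1)
x = rng.normal(size=50)
y = (x + 0.5 * rng.normal(size=50) > 0).astype(float)
Phi = np.column_stack([np.ones_like(x), x])   # design matrix of basis functions

w = np.zeros(Phi.shape[1])
lr = 0.01                                     # small fixed step size
for _ in range(5000):
    p = sigmoid(Phi @ w)                      # model probabilities p_i
    grad = Phi.T @ (p - y)                    # dL/dw_j = sum_i (p_i - y_i) phi_j(x_i)
    w -= lr * grad

print(w, cross_entropy(y, sigmoid(Phi @ w)))
```

In practice one would use a library optimiser (or sklearn, below), but the gradient line is exactly the expression derived above.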
Importance of Regularisation
Logistic regression is a convex optimisation problem (the likelihood function is concave), but it is known not to have a finite solution when the model can fully separate the data: the loss function can then only reach its lowest value asymptotically, as the magnitude of the weights tends to infinity.
Data is often fully separable when we have limited data relative to the flexibility of the model, which is the classic trade-off between over- and under-fitting.
When the data is separable and the linear model is sufficiently expressive, this has the effect of tightening the decision boundary around the data points, so that the model asymptotically overfits the training set.
Without regularisation, this asymptotic behaviour would keep driving the loss towards 0, particularly in high dimensions. There are therefore two well-used strategies:
- $L_2$ (or Ridge) regularisation, where an additional penalty term $\lambda\,\lVert \mathbf{w} \rVert_2^2$ is added to the loss (see the sketch after this list).
- Early stopping.
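As a small illustration of the first strategy, the ridge term simply adds $\lambda\,\lVert\mathbf{w}\rVert_2^2$ to the loss and $2\lambda\mathbf{w}$ to the gradient from the earlier sketch (the value of $\lambda$ below is an arbitrary choice):

```python
import numpy as np

lam = 0.1   # regularisation strength lambda (a hyperparameter; arbitrary value here)

def ridge_loss(w, Phi, y):
    """Binary cross-entropy plus an L2 (ridge) penalty on the weights."""
    p = 1.0 / (1.0 + np.exp(-Phi @ w))
    ce = -np.sum(y * np.log(p) + (1 - y) * np.log(1 - p))
    return ce + lam * np.sum(w ** 2)

def ridge_grad(w, Phi, y):
    """Gradient of the regularised loss: the extra term is 2 * lam * w."""
    p = 1.0 / (1.0 + np.exp(-Phi @ w))
    return Phi.T @ (p - y) + 2 * lam * w

# Quick check on a tiny design matrix
Phi = np.array([[1.0, -1.0], [1.0, 0.5], [1.0, 2.0]])
y = np.array([0.0, 1.0, 1.0])
w = np.zeros(2)
print(ridge_loss(w, Phi, y), ridge_grad(w, Phi, y))
```

The penalty keeps the weights finite even when the data is perfectly separable, which prevents the decision boundary from tightening indefinitely around the training points.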
We deal with regularisation as a separate topic. For now, note that sklearn automatically applies ridge regularisation ('l2') by default, with the parameter $C$ set to 1.0. So you now know what this does, and you understand that the regularisation parameter is actually a hyperparameter which you should also optimise over during training.
It is worth looking at the class documentation for logistic regression. Whilst there is no need to write your own code for this, since sklearn's implementation is good, it is important that you understand the meaning of the default assumptions. Fitting a good logistic model will often require tuning of the regularisation parameter; in the documentation you will see that the user sets $C$, the inverse of the regularisation strength.
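A hedged sketch of what tuning $C$ looks like in practice with scikit-learn (the grid of values and the synthetic data are illustrative choices):

```python
import numpy as np
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import GridSearchCV

# Synthetic data, purely for illustration.
rng = np.random.default_rng(2)
X = rng.normal(size=(200, 5))
y = (X[:, 0] - X[:, 1] + 0.5 * rng.normal(size=200) > 0).astype(int)

# By default LogisticRegression() uses an 'l2' (ridge) penalty with C = 1.0,
# where C is the *inverse* of the regularisation strength.
# Treat C as a hyperparameter and select it by cross-validation:
search = GridSearchCV(
    LogisticRegression(penalty="l2"),
    param_grid={"C": [0.01, 0.1, 1.0, 10.0, 100.0]},
    cv=5,
)
search.fit(X, y)
print(search.best_params_, search.best_score_)
```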
Multi-class Classification
In what we have discussed so far we have considered binary classification, where there is a choice of only two classes, labels or outcomes. In general, classification can map to any number of classes, which is referred to as multi-class classification.
Multi-class classification can be achieved via a simple extension of the binary classification described above, by following an approach called one-vs-all (or one-vs-rest).
Suppose we now have $K$ classes; a logistic regression classifier is then fitted for each class $k$, from $1$ to $K$.
That is, we build a model which predicts the probability of each class separately, each being a binary classifier: is it class $k$ or is it not class $k$?
The classifier with the highest probability wins:
$$\hat{y} = \arg\max_{k} \; p_k(\mathbf{x}).$$
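A minimal one-vs-rest sketch with scikit-learn (the three-class synthetic data is an illustrative assumption). OneVsRestClassifier fits one binary logistic model per class, and the prediction is the class whose classifier reports the highest probability:

```python
import numpy as np
from sklearn.linear_model import LogisticRegression
from sklearn.multiclass import OneVsRestClassifier

# Synthetic 3-class data in 2D, purely for illustration.
rng = np.random.default_rng(3)
X = np.vstack([rng.normal(loc=c, size=(50, 2)) for c in (-2.0, 0.0, 2.0)])
y = np.repeat([0, 1, 2], 50)

# One binary logistic regression per class: "class k" vs "not class k".
ovr = OneVsRestClassifier(LogisticRegression()).fit(X, y)

probs = ovr.predict_proba(X[:3])        # one probability column per class
print(probs)
print(probs.argmax(axis=1))             # the highest-probability class wins
```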
Conclusion
Logistic regression is an extension of linear models to classification problems. It is part of a broader class of models called "Generalised Linear Models". These are the composition of two models:
- a linear predictor (here defined as $f(\mathbf{x}) = \mathbf{w}^\top \boldsymbol{\phi}(\mathbf{x})$);
- a link function, in this case the logit, whose inverse is the sigmoid (logistic) function.
It is the sigmoid that plays the role of squashing the output of the linear predictor between $0$ and $1$, transforming the output of the model into a probability which can be used for classification.
Regularisation is an important part of fitting classification models, and when you use packages there are often default parameter choices which play a central role in the quality of the results.
Finally, binary classification (two classes/labels) can be extended easily to multi-class classification using the so-called "one-vs-all" strategy. This is nothing more than building multiple binary classifiers, one for each class.