by Dr Andy Corbett

Lesson

8. The Training Loop

In this video you will...

Learn the key steps and hyperparameters needed for training a neural net. In particular, you will:

  • ✅ Walk through the training cycle.
  • ✅ Use batches and dataloading techniques.
  • ✅ See how parameter updates are set.
  • ✅ Choose regularisation methods in training.

Training a neural network can seem like an art form. Or rather a dark art. But with any luck, some of the mysteries should become a little more familiar.

The goal of this video is to explain the following boilerplate code. This is the template for a neural network training routine. Once committed to muscle memory, it'll feel like riding a bike... looping over and over again! :bike:

import torch
from torch import nn

model = MLP()  # Instantiate the network (the MLP class is assumed defined earlier)
loss_func = nn.BCELoss()  # Loss function
optimiser = torch.optim.SGD(model.parameters(), lr=0.1)  # Optimisation method

NUM_EPOCHS = 4000

# To store loss statistics
loss_train, loss_val = list(), list()
for ep in range(NUM_EPOCHS):
    running_loss = 0.0
    for ii, batch in enumerate(loader_train):
        inputs, targets = batch
        optimiser.zero_grad()  # Zero gradients in network
        preds = model(inputs)  # Make a forward pass
        loss = loss_func(preds, targets)  # Compute loss
        loss.backward()  # Backpropagate loss (compute gradients)
        optimiser.step()  # Update parameters

        # Record loss at this batch
        running_loss += loss.item()

    # Update epoch-level training loss
    loss_train.append(running_loss / len(loader_train))  # Mean loss per batch over the epoch

    # Test on the validation set
    running_val = 0.0
    with torch.no_grad():  # No gradients needed for evaluation
        for ii, batch in enumerate(loader_val):
            inputs, targets = batch
            preds = model(inputs)  # Make a forward pass
            running_val += loss_func(preds, targets).item()  # Accumulate the batch loss
    loss_val.append(running_val / len(loader_val))  # Mean validation loss per batch

The Training Loop


First up is the cycle of epochs which we call training. This is by design a for loop. In a single epoch we pass through all the training data once.

But then within the epoch, we pass through all the batches. So it is a nested for loop. Lots of scope for parallelisation. But importantly, this is the recipe for what you must do in each iteration (i.e. upon each batch).

  1. Initialise/reset your variables. This includes clearing the accumulated gradients stored from the previous round, as well as getting the next batch and moving it onto your hardware device.
  2. Forward pass. This passes the batch through the network to obtain predictions.
  3. Compute the loss. This applies the loss function to the predictions vs. the ground truth--remember, this is a supervised training routine.
  4. Backpropagate. This means computing the gradients across the whole network.
  5. Update parameters. Now we apply our optimisation algorithm to update the network parameters. (A single iteration of this recipe is sketched below.)
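As a minimal sketch of a single iteration (assuming the model, loss_func, optimiser and loader_train from the boilerplate above; the device variable is an extra that does not appear in that code):

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)  # Move the network onto the hardware device

inputs, targets = next(iter(loader_train))  # 1. Get the next batch...
inputs, targets = inputs.to(device), targets.to(device)  # ...and move it onto the device
optimiser.zero_grad()  # 1. Reset the accumulated gradients
preds = model(inputs)  # 2. Forward pass
loss = loss_func(preds, targets)  # 3. Compute the loss
loss.backward()  # 4. Backpropagate (compute gradients)
optimiser.step()  # 5. Update parameters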

Batches and data loading


In the training loop we call on batches of data. This allows us to manipulate the data on the fly during training.

  • We can use pre-designed data loaders (such as PyTorch's DataLoader class) for computational efficiency; a minimal setup is sketched after this list.
  • We can shuffle the dataset each epoch as a form of regularisation (seeding the process to maintain reproducibility).
  • We can apply on-the-fly augmentation to our data with random probability.
  • Batch loading adds an element of stochasticity. By grouping the data differently each epoch, we help prevent the optimisation from settling into poor local minima, as a purely deterministic gradient descent might.
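As a rough sketch (X_train, y_train, X_val and y_val are assumed tensors defined elsewhere in the course, and the batch size is illustrative), the loaders used in the boilerplate might be built as:

from torch.utils.data import DataLoader, TensorDataset

torch.manual_seed(42)  # Seed the shuffling for reproducibility (value is arbitrary)

dataset_train = TensorDataset(X_train, y_train)  # Assumed tensors, not defined in this lesson
dataset_val = TensorDataset(X_val, y_val)

loader_train = DataLoader(dataset_train, batch_size=32, shuffle=True)  # Reshuffle each epoch
loader_val = DataLoader(dataset_val, batch_size=32, shuffle=False)  # Keep validation order fixed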

Updating parameters


We mentioned that batch loading adds an element of stochasticity. Indeed, this is what separates gradient descent from stochastic gradient descent (SGD), a popular optimisation algorithm.

We make the choice of optimiser early on in the code, choosing from SGD, Adam, or various others (see the sketch after this list). For each of these we are interested in at least two hyperparameters:

  • The learning rate: this controls the size of the update we make at each step. Setting this lower prevents the model from diverging due to random noise early on. The trade-off is that you must train for longer.
  • The regularisation: this can take various forms. In the optimiser we set the L^2 regularisation via the weight_decay parameter.
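For example (a sketch only; the values of lr and weight_decay here are illustrative, not recommendations):

# SGD with a learning rate and L^2 regularisation via weight decay
optimiser = torch.optim.SGD(model.parameters(), lr=0.1, weight_decay=1e-4)

# Or an adaptive alternative such as Adam
optimiser = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-4)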

Regularisation during training


  • Training with L^2 regularisation tempers the parameters by penalising large values. The penalty is measured in the loss function and hence takes effect through the parameter updates.
  • Batch normalisation: this is a layer you can add to your neural network that normalises the activations themselves during training. Again, this helps prevent unforeseen divergent behaviour.
  • Dropout: this is a technique that zeros nodes in the network at random with a specified probability. If the network is concentrating on a given pathway, this technique can help to smooth out that behaviour. (Both layers appear in the sketch after this list.)
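As a hypothetical sketch of how these layers slot into a small network (the layer sizes and dropout probability are illustrative only, not taken from the lesson):

class RegularisedMLP(nn.Module):
    """An illustrative MLP with batch normalisation and dropout."""
    def __init__(self, in_features=2, hidden=32, p_drop=0.2):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(in_features, hidden),
            nn.BatchNorm1d(hidden),  # Normalise activations during training
            nn.ReLU(),
            nn.Dropout(p=p_drop),  # Zero nodes at random with probability p_drop
            nn.Linear(hidden, 1),
            nn.Sigmoid(),  # Matches the BCELoss used above
        )

    def forward(self, x):
        return self.net(x)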

Note: both batch normalisation and dropout behave differently at evaluation time: dropout is switched off and batch normalisation uses its stored running statistics rather than the batch statistics. PyTorch knows when to switch these behaviours on and off when you call model.train() and model.eval(), respectively. Don't forget to include them in your code! (Although this is automated with Lightning.)
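In plain PyTorch the toggling might be layered onto the loop above like this (a sketch only, with the inner steps elided):

for ep in range(NUM_EPOCHS):
    model.train()  # Activate dropout and batch-norm updates
    for inputs, targets in loader_train:
        ...  # Training steps as in the boilerplate above

    model.eval()  # Switch to evaluation behaviour
    with torch.no_grad():  # No gradients needed for validation
        for inputs, targets in loader_val:
            ...  # Validation forward pass and loss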

Further reading. Applying dropout to your network can be thought of as switching between an ensemble of architectures. Firing samples through the network in evaluation mode, with dropout still active, builds up a Monte Carlo picture of the distribution of model outputs, given this probabilistic view of the ensemble of architectures. The technique is known as 'MC Dropout' and the paper introducing this notion can be found here.
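A rough sketch of the idea (the helper name mc_dropout_predict and the number of samples are made up for illustration; it assumes a model containing nn.Dropout layers, such as the RegularisedMLP sketched above):

def mc_dropout_predict(model, inputs, n_samples=50):
    """Draw stochastic predictions by keeping dropout active at evaluation time."""
    model.eval()
    for module in model.modules():
        if isinstance(module, nn.Dropout):
            module.train()  # Switch only the dropout layers back into training mode

    with torch.no_grad():
        samples = torch.stack([model(inputs) for _ in range(n_samples)])

    # The mean and spread of the samples approximate the predictive distribution
    return samples.mean(dim=0), samples.std(dim=0)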