Dr Andy Corbett



6. Develop with Lightning

In this video you will...
  • ✅ Understand the lightning package for PyTorch.
  • ✅ Create a lightning module.
  • ✅ Assess training with TensorBoard.
  • ✅ Save and load models.

We pick up our code from the last tutorial and refactor it to use Lightning, a powerful high-level interface for PyTorch (a Python package previously known as pytorch-lightning). This interface lets us manipulate PyTorch objects and train models in a more fluid and simple way, taking care of much of the behind-the-scenes boilerplate code.
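To see what that boilerplate is, here is a sketch of a plain PyTorch training loop. It runs on hypothetical random data (two input features, labels given by the sign of their sum), not the tutorial's dataset, and it still leaves out validation, logging, device placement and checkpointing, all of which Lightning's Trainer handles for us.

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset

torch.manual_seed(0)

# Hypothetical toy data: 256 points, label 1 when the two features sum to > 0
x = torch.randn(256, 2)
y = (x.sum(dim=1, keepdim=True) > 0).float()
train_loader = DataLoader(TensorDataset(x, y), batch_size=32, shuffle=True)

model = nn.Sequential(nn.Linear(2, 16), nn.ReLU(), nn.Linear(16, 1), nn.Sigmoid())
optimiser = torch.optim.SGD(model.parameters(), lr=0.1)

history = []  # mean training loss per epoch
for epoch in range(20):
    model.train()
    epoch_loss = 0.0
    for xb, yb in train_loader:
        optimiser.zero_grad()  # clear gradients from the previous step
        loss = nn.functional.binary_cross_entropy(model(xb), yb)
        loss.backward()        # backpropagate
        optimiser.step()       # update the weights
        epoch_loss += loss.item() * len(xb)
    history.append(epoch_loss / len(x))
    # A real project would also validate, log, move tensors to a device
    # and save checkpoints here: this is the code Lightning writes for us.

print(history[0], history[-1])
```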

Here's an example.

Convert to a Lightning module

Our ground zero is to collect our datasets and data loaders built with the PyTorch framework from before. We also build our neural network in PyTorch as before.
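The previous tutorial's code is not reproduced on this page. To run the snippets below in isolation, a minimal hypothetical stand-in for MLP, data_train, data_val, BATCH_SIZE and n_val could look like this; the toy data and layer sizes are assumptions. The final Sigmoid matters: nn.functional.binary_cross_entropy expects probabilities in [0, 1], and targets with the same shape as the predictions.

```python
import torch
from torch import nn
from torch.utils.data import TensorDataset

BATCH_SIZE = 32  # hypothetical value


class MLP(nn.Module):
    """Hypothetical stand-in for the previous tutorial's network."""

    def __init__(self, in_features=2, hidden=16):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(in_features, hidden),
            nn.ReLU(),
            nn.Linear(hidden, 1),
            nn.Sigmoid(),  # outputs probabilities for binary_cross_entropy
        )

    def forward(self, x):
        return self.net(x)


torch.manual_seed(0)
x = torch.randn(500, 2)
y = (x[:, :1] * x[:, 1:] > 0).float()  # XOR-like labels with shape (N, 1)

data_train = TensorDataset(x[:400], y[:400])
data_val = TensorDataset(x[400:], y[400:])
n_val = len(data_val)  # validate on the whole set in a single batch
```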

Then, step 1 in using Lightning (the activation energy) is converting all of this into a LightningModule.

import torch
from torch import nn
from torch.utils.data import DataLoader
import lightning.pytorch as pl

# MLP, data_train, data_val, BATCH_SIZE and n_val come from the previous tutorial.

class PLModel(pl.LightningModule):
    def __init__(self):
        super().__init__()  # must run before any submodule is assigned
        self.model = MLP()

    def training_step(self, batch, batch_idx):
        """This defines a step in the training loop."""
        x, y = batch  # The batch comes from a torch dataloader
        y_hat = self.model(x)
        loss = nn.functional.binary_cross_entropy(y_hat, y)
        self.log('train_loss', loss)
        return loss

    def train_dataloader(self):
        return DataLoader(data_train, batch_size=BATCH_SIZE, shuffle=True)

    def validation_step(self, batch, batch_idx):
        """This defines a validation step."""
        x, y = batch
        y_hat = self.model(x)
        loss = nn.functional.binary_cross_entropy(y_hat, y)
        self.log('val_loss', loss)
        return loss

    def val_dataloader(self):
        return DataLoader(data_val, batch_size=n_val)

    def configure_optimizers(self):
        optimiser = torch.optim.SGD(self.parameters(), lr=0.1)
        return optimiser

model = PLModel()

This class gives a recipe for how Lightning should act:

  1. The __init__ gives us room to define the PyTorch neural network.
  2. The training_step and validation_step methods tell Lightning what to do at each step of training/validation, i.e. on each batch of each epoch. In these methods, we use self.log to record the values we want to track.
  3. The train_dataloader and val_dataloader construct the respective data loaders.
  4. The configure_optimizers method tells Lightning which optimiser to use and how to set it up.

With this class constructed, we have made all our choices about training and validation and need not specify anything further to plot or analyse the model.

Training with Lightning

With the hard work out of the way, training becomes very simple. We include the option to checkpoint our model, which is straightforward using the ModelCheckpoint callback. Training is then launched with trainer.fit(model).

from lightning.pytorch.callbacks import ModelCheckpoint

ckpt = ModelCheckpoint(save_top_k=2, monitor='val_loss')

trainer = pl.Trainer(callbacks=[ckpt])  # further options, e.g. max_epochs, can be passed here
trainer.fit(model)


Here are some further tips.

Save your best model path

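Don't forget to identify the best model. After training, the checkpoint callback records the path to the best checkpoint (here, the one with the lowest val_loss), which we can use to reload the model: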

pth = trainer.checkpoint_callback.best_model_path
model = PLModel.load_from_checkpoint(pth)

Inspect training logs with TensorBoard

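With TensorBoard, we never have to plot the loss or accuracy by hand again. After logging the quantities of interest with self.log, simply launch TensorBoard from your terminal and point it at the directory containing the logs (Lightning writes them under lightning_logs/ by default).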


tensorboard --logdir .

GPU acceleration

Here are some options to control Lightning's built-in hardware acceleration (CPU, GPU and TPU):

from lightning.pytorch import Trainer

# CPU accelerator
trainer = Trainer(accelerator="cpu")

# Training with GPU Accelerator using 2 GPUs
trainer = Trainer(devices=2, accelerator="gpu")

# Training with TPU Accelerator using 8 TPU cores
trainer = Trainer(devices=8, accelerator="tpu")

# Training with GPU Accelerator using the DistributedDataParallel strategy
trainer = Trainer(devices=4, accelerator="gpu", strategy="ddp")

The API docs

Here is a resource for the trainer class: https://lightning.ai/docs/pytorch/stable/common/trainer.html#

Validate your model

Once again, validating the model takes a single call to the trainer.


Loading from a checkpoint

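Importantly, we want to recover our trained models. Reloading the LightningModule is a one-liner, but recovering the plain PyTorch model is slightly more technical: the checkpoint's state dictionary stores the parameters under the LightningModule's attribute names (every key is prefixed with "model.", after self.model), not under the bare PyTorch model's names. Here is some code that fixes the renaming.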

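# Loading from a checkpoint as a LightningModule
trained_model = PLModel.load_from_checkpoint(pth)

# Load the raw checkpoint dictionary from the path.
# Lightning checkpoints hold more than tensors and may be rejected by PyTorch's
# weights-only loader, so we opt out here; only do this for files you trust.
checkpoint = torch.load(pth, map_location="cpu", weights_only=False)

# Create a fresh PyTorch model
torch_model = MLP()
model_weights = checkpoint["state_dict"]

# Update keys by dropping the `model.` prefix (the attribute name used in PLModel)
prefix = "model."
model_weights = {key[len(prefix):]: value for key, value in model_weights.items() if key.startswith(prefix)}
torch_model.load_state_dict(model_weights)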
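To see why the renaming is needed, the torch-only sketch below mimics PLModel with a hypothetical Wrapper module that stores the network under the attribute name model, just as PLModel does. Its state_dict keys carry the "model." prefix, and stripping it lets a bare MLP load the weights. The MLP here is a small hypothetical stand-in for the tutorial's network.

```python
import torch
from torch import nn


class MLP(nn.Module):
    """Hypothetical stand-in for the tutorial's MLP."""

    def __init__(self):
        super().__init__()
        self.layers = nn.Sequential(nn.Linear(2, 8), nn.ReLU(), nn.Linear(8, 1), nn.Sigmoid())

    def forward(self, x):
        return self.layers(x)


class Wrapper(nn.Module):
    """Mimics PLModel: the network is stored under the attribute name `model`."""

    def __init__(self):
        super().__init__()
        self.model = MLP()


wrapped = Wrapper()
state = wrapped.state_dict()
print(list(state))  # every key starts with "model.", e.g. "model.layers.0.weight"

# Strip the prefix so the keys match a bare MLP
prefix = "model."
stripped = {key[len(prefix):]: value for key, value in state.items() if key.startswith(prefix)}

torch_model = MLP()
torch_model.load_state_dict(stripped)  # strict by default: fails on missing or unexpected keys
```

If the LightningModule has already been loaded, trained_model.model is itself the plain PyTorch network, so trained_model.model.state_dict() yields the un-prefixed keys directly.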

Following these tips and the video, you can apply Lightning to your own projects.