by Dr Andy Corbett
6. Develop with Lightning
Download the resources for this lesson here.
- ✅ Understand the lightning package for PyTorch.
- ✅ Create a lightning module.
- ✅ Assess training with TensorBoard.
- ✅ Save and load models.
We pick up our code from the last tutorial and refactor it to use lightning, a powerful high-level interface for PyTorch previously published as pytorch-lightning. This interface lets us manipulate pytorch objects and train models in a more fluid and simple way, taking care of much of the behind-the-scenes boilerplate code.
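If you don't already have it, lightning is published on PyPI and installs with pip:
pip install lightning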
Let's work through an example.
Convert to a Lightning module
Our starting point is to collect the datasets and data loaders built with the PyTorch framework in the previous tutorial, along with the neural network we built there.
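For reference, here is a minimal sketch of those ingredients. The layer sizes, placeholder datasets and constants below are stand-ins for the previous tutorial's code, not the original definitions.
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset

BATCH_SIZE = 32  # placeholder batch size

# Placeholder datasets standing in for data_train / data_val from the last tutorial
data_train = TensorDataset(torch.randn(800, 2), torch.rand(800, 1).round())
data_val = TensorDataset(torch.randn(200, 2), torch.rand(200, 1).round())
n_val = len(data_val)

class MLP(nn.Module):
    """Stand-in for the MLP built in the previous tutorial."""
    def __init__(self):
        super().__init__()
        self.layers = nn.Sequential(
            nn.Linear(2, 16), nn.ReLU(),
            nn.Linear(16, 1), nn.Sigmoid(),  # sigmoid output for binary cross-entropy
        )

    def forward(self, x):
        return self.layers(x)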
Then, step one in using lightning (the activation energy, so to speak) is converting all of this into a LightningModule.
import lightning.pytorch as pl
from lightning.pytorch.utilities.types import TRAIN_DATALOADERS


class PLModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.model = MLP()

    def training_step(self, batch):
        """This defines a step in the training loop."""
        x, y = batch  # The batch comes from a torch dataloader
        y_hat = self.model(x)
        loss = nn.functional.binary_cross_entropy(y_hat, y)
        self.log('train_loss', loss)
        return loss

    def train_dataloader(self) -> TRAIN_DATALOADERS:
        return DataLoader(data_train, batch_size=BATCH_SIZE, shuffle=True)

    def validation_step(self, batch):
        """This defines a validation step."""
        x, y = batch
        y_hat = self.model(x)
        loss = nn.functional.binary_cross_entropy(y_hat, y)
        self.log('val_loss', loss)
        return loss

    def val_dataloader(self):
        return DataLoader(data_val, batch_size=n_val)

    def configure_optimizers(self):
        optimiser = torch.optim.SGD(self.parameters(), lr=0.1)
        return optimiser


model = PLModel()
This class gives a recipe for how lightning should act:

- The __init__ method gives us room to define the PyTorch neural network.
- The training_step and validation_step methods tell us what to do in each step of training/validation, i.e. in each batch of each epoch. In these methods we use self.log to record the values we want to track.
- The train_dataloader and val_dataloader methods construct the respective data loaders.
- The configure_optimizers method specifies which optimiser to use and how; it can also hand back a learning-rate scheduler, as in the sketch below.
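A minimal sketch of that extension, which drops in for the configure_optimizers method above (the StepLR settings are illustrative, not taken from the tutorial):
    def configure_optimizers(self):
        optimiser = torch.optim.SGD(self.parameters(), lr=0.1)
        # Optional: halve the learning rate every 1000 epochs; Lightning steps the scheduler for us
        scheduler = torch.optim.lr_scheduler.StepLR(optimiser, step_size=1000, gamma=0.5)
        return {"optimizer": optimiser, "lr_scheduler": scheduler}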
With this class constructed, we have made all our choices about training and validation and need not specify anything further to plot or analyse the model.
Training with Lightning
With the hard work out of the way, training becomes very simple. We include the option to checkpoint our model, which is very straightforward using the ModelCheckpoint API. Training is then launched with trainer.fit().
from lightning.pytorch.callbacks import ModelCheckpoint

ckpt = ModelCheckpoint(save_top_k=2, monitor='val_loss')
trainer = pl.Trainer(
    check_val_every_n_epoch=100,
    max_epochs=4000,
    callbacks=[ckpt],
)
trainer.fit(model=model)
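Alternatively, if you prefer not to define the dataloader hooks on the module, the same loaders can be passed straight to fit(); a short sketch reusing the datasets from before:
trainer.fit(
    model=model,
    train_dataloaders=DataLoader(data_train, batch_size=BATCH_SIZE, shuffle=True),
    val_dataloaders=DataLoader(data_val, batch_size=n_val),
)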
Here are some further tips.
Save your best model path
Don't forget to identify the best-performing checkpoint and keep hold of its path; the checkpoint callback tracks it for you:
pth = trainer.checkpoint_callback.best_model_path
model = PLModel.load_from_checkpoint(pth)
Inspect training logs with TensorBoard
With TensorBoard we never have to hand-plot the loss or accuracy again. After logging all the quantities of interest, simply launch tensorboard from your terminal, pointing it at the log directory.
%%bash
tensorboard --logdir .
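By default the Trainer writes its logs to a lightning_logs/ directory. If you want to name the run, you can pass a TensorBoardLogger explicitly; a sketch reusing the ckpt callback from above (the directory and experiment name are just examples):
from lightning.pytorch.loggers import TensorBoardLogger

# "mlp_experiment" is an example run name, not one used in the tutorial
logger = TensorBoardLogger(save_dir="lightning_logs", name="mlp_experiment")
trainer = pl.Trainer(logger=logger, max_epochs=4000, callbacks=[ckpt])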
GPU acceleration
Here are some options for controlling the integrated hardware acceleration:
# CPU accelerator
trainer = pl.Trainer(accelerator="cpu")

# Training with GPU accelerator using 2 GPUs
trainer = pl.Trainer(devices=2, accelerator="gpu")

# Training with TPU accelerator using 8 TPU cores
trainer = pl.Trainer(devices=8, accelerator="tpu")

# Training with GPU accelerator using the DistributedDataParallel strategy
trainer = pl.Trainer(devices=4, accelerator="gpu", strategy="ddp")
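If you are unsure what hardware will be available at run time, Lightning can also pick for you:
# Let Lightning detect the available hardware and use every visible device
trainer = pl.Trainer(accelerator="auto", devices="auto")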
The API docs
Here is a resource for the trainer class: https://lightning.ai/docs/pytorch/stable/common/trainer.html#
Validate your model
Once again, we can validate the model with a single call to the trainer.
trainer.validate(ckpt_path='best')
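The call returns a list containing one dictionary of logged metrics per validation dataloader, so the numbers can be inspected directly; for example:
results = trainer.validate(ckpt_path='best')
print(results[0]['val_loss'])  # metrics logged with self.log in validation_step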
Loading from a checkpoint
Importantly, we want to recover our trained models. This is slightly more technical: the checkpoint's state dictionary stores the parameters under the LightningModule's attribute names (prefixed with model.), so the keys need renaming before they can be loaded into the plain PyTorch model. Here is some code to fix the renaming.
# Loading from a checkpoint
trained_model = PLModel.load_from_checkpoint(pth)

# Load the raw checkpoint from the path
checkpoint = torch.load(pth)

# Instantiate a plain PyTorch model to receive the weights
torch_model = MLP()
model_weights = checkpoint["state_dict"]

# Update keys by dropping the `model.` prefix added by the LightningModule
for key in list(model_weights):
    model_weights[key.replace('model.', '')] = model_weights.pop(key)

torch_model.load_state_dict(model_weights)
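With the weights restored, the plain PyTorch model can be switched to evaluation mode for inference; a quick sketch with a dummy batch (the input shape is assumed from the placeholder MLP above, not fixed by the tutorial):
torch_model.eval()  # disable training-specific behaviour such as dropout
with torch.no_grad():
    x_new = torch.randn(5, 2)          # dummy batch: 5 samples, 2 features (assumed shape)
    predictions = torch_model(x_new)   # probabilities from the sigmoid output layer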
Following these tips and the video, you can apply lightning to accelerate your own projects.