Dr Freddy Wordingham


Lesson

Training a CNN

5. Continuing training

In the last lesson, we added an additional function to our scripts/classify.py file which visualised the intermediate activations of our CNN's layers.

This time we're going to add a script which continues training a pre-existing model.

💪🏼 Continue Training Script


Each time we train our model, we only have a finite amount of time available. So we'd also like a script which can take a pre-existing model and continue training it.

Let's go ahead and create a new file scripts/continue_training.py, and add the following:

from tensorflow.keras import datasets, models
import os
import tensorflow as tf

if __name__ == "__main__":
    # Load data
    (train_images, train_labels), (test_images,
                                   test_labels) = datasets.cifar10.load_data()

    # Normalise pixel values to the range [0, 1]
    train_images, test_images = train_images / 255.0, test_images / 255.0

    # Load model
    init_model_path = os.path.join("output", "model.h5")
    model = models.load_model(init_model_path)

    # Compile the model (this re-creates the Adam optimizer with default settings)
    model.compile(optimizer="adam", loss=tf.keras.losses.SparseCategoricalCrossentropy(
        from_logits=False), metrics=["accuracy"])

    # Training loop
    epoch_counter = 0
    epochs_per_iteration = 10
    while True:
        history = model.fit(train_images, train_labels, epochs=epochs_per_iteration,
                            validation_data=(test_images, test_labels))
        epoch_counter += epochs_per_iteration

        # Save a checkpoint, overwriting the model file we loaded from
        model.save(os.path.join("output", "model.h5"))

        print(f"Total number of epochs this training session: {epoch_counter}")

This script is relatively straightforward and should look familiar from the code we saw in Lesson 2: Training a CNN.

This script:

  1. Imports necessary libraries from TensorFlow.
  2. Loads the CIFAR-10 dataset.
  3. Normalises the image data.
  4. Loads a pre-existing model from a specified path.
  5. Compiles the model with the Adam optimizer and sparse categorical cross-entropy loss.
  6. Enters an infinite training loop, training the model in 10-epoch chunks.
  7. Saves the model after each chunk.
  8. Prints the total number of epochs trained so far.

We save the model to disk at the end of every loop iteration (every ten epochs by default), so if the script is cancelled or fails at some point, we'll still have a recent checkpoint.
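
One caveat with checkpointing is that a save which gets interrupted part-way through may leave a damaged file behind. A minimal sketch of one way to guard against this is to write to a temporary file first and then rename it over the real checkpoint. The helper name save_atomically and the ".tmp" naming scheme are assumptions made for illustration; they are not part of the lesson's script or of Keras.

```python
import os


def save_atomically(save_fn, final_path):
    """Save via a temporary file, then rename it over the final path.

    save_fn is any callable that writes a file to the path it is given,
    for example model.save. The temporary name keeps the original
    extension (model.tmp.h5) so that a saver which picks its format from
    the extension still behaves the same way.
    """
    root, ext = os.path.splitext(final_path)
    tmp_path = f"{root}.tmp{ext}"
    save_fn(tmp_path)
    # os.replace overwrites the destination if it already exists
    os.replace(tmp_path, final_path)


# Hypothetical usage inside the training loop, in place of the plain save:
#     save_atomically(model.save, os.path.join("output", "model.h5"))
```

If the script stops during the save, the previous output/model.h5 is left untouched and only the temporary file is affected.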

Update the epochs_per_iteration as you see fit!

⚠️ The script employs an infinite loop for ongoing training of the CNN model. However, over time the model is prone to overfitting, where it performs exceptionally well on the training data but poorly on new, unseen data. Additionally, the gains in accuracy will eventually reach a point of diminishing returns, making the extra computational resources spent on training less justifiable. You may wish to incorporate an exit strategy based on evaluation metrics, or set a maximum epoch limit to ensure efficient use of resources. On your machine this probably isn't an issue, but if we left this running on a server, it could rack up substantial costs and possibly exhaust available resources, which could impact other services or tasks handled by the server.
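
As a rough sketch of such an exit strategy, the helper below combines a maximum epoch limit with a simple check on whether validation accuracy has stopped improving. The function name should_stop and its parameters (max_epochs, patience, min_delta) are hypothetical and chosen for illustration, and the default values are arbitrary. Keras also provides a built-in EarlyStopping callback, but it only tracks progress within a single fit call, so this sketch keeps its own history across chunks.

```python
def should_stop(val_accuracies, total_epochs, max_epochs=100, patience=2, min_delta=0.001):
    """Decide whether the training loop should exit.

    val_accuracies: one validation accuracy per completed chunk, oldest first.
    total_epochs:   epochs trained so far in this session.
    Stops when the epoch limit is reached, or when none of the last
    `patience` chunks beat the earlier best by at least `min_delta`.
    """
    if total_epochs >= max_epochs:
        return True
    if len(val_accuracies) <= patience:
        # Not enough history yet to judge whether progress has stalled
        return False
    best_before = max(val_accuracies[:-patience])
    recent_best = max(val_accuracies[-patience:])
    return recent_best < best_before + min_delta


# Hypothetical usage inside the training loop:
#     val_accuracies = []
#     while True:
#         history = model.fit(...)
#         epoch_counter += epochs_per_iteration
#         val_accuracies.append(max(history.history["val_accuracy"]))
#         if should_stop(val_accuracies, epoch_counter):
#             break
```

Note that the script uses the test set as validation data, so a stopping rule like this is tuned against the same images used to report accuracy.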

🚀 Run


Continue training the existing model:

python scripts/continue_training.py

ℹ️ This trains the model indefinitely, so you will need to interrupt the program manually in the terminal (usually with Ctrl+C) to stop it.
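
In Python, pressing Ctrl+C raises a KeyboardInterrupt, which by default ends the script with a traceback. If you would prefer a tidy exit, one option is to catch it around the loop, as in the sketch below. The names run_until_interrupted and train_chunk are assumptions for illustration: train_chunk stands in for one pass of the loop body (fit, save, print).

```python
def run_until_interrupted(train_chunk):
    """Call train_chunk repeatedly until the user presses Ctrl+C.

    Returns the number of chunks that finished completely. Work done in
    the chunk that was interrupted is not counted, and would not be in
    the saved model either, since saving happens at the end of a chunk.
    """
    completed = 0
    try:
        while True:
            train_chunk()
            completed += 1
    except KeyboardInterrupt:
        print(f"Interrupted after {completed} completed chunk(s).")
    return completed
```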

Then, try classifying your dog and bird images again to see how the model has improved:

python scripts/classify.py <path/to/your/image.png>

📑 APPENDIX


🏃 How to Run

Remember to activate the virtual environment, if you haven't already:

source .venv/bin/activate

Install any required packages:

pip install -r requirements.txt

If you haven't already, train a CNN:

python scripts/train.py

Continue training an existing model:

python scripts/continue_training.py

🗂️ Updated Files

Project structure
.
├── .venv/
├── .gitignore
├── resources
│   ├── bird.jpg
│   └── dog.png
├── output
│   ├── activations_conv2d/
│   ├── activations_conv2d_1/
│   ├── activations_conv2d_2/
│   ├── activations_dense/
│   ├── activations_dense_1/
│   ├── model.h5
│   ├── sample_images.png
│   └── training_history.png
├── scripts
│   ├── classify.py
│   ├── continue_training.py
│   └── train.py
├── README.md
└── requirements.txt

requirements.txt
matplotlib
tensorflow

scripts/continue_training.py
from tensorflow.keras import datasets, models
import os
import tensorflow as tf

if __name__ == "__main__":
    # Load data
    (train_images, train_labels), (test_images, test_labels) = datasets.cifar10.load_data()

    # Normalise pixel values to the range [0, 1]
    train_images, test_images = train_images / 255.0, test_images / 255.0

    # Load model
    init_model_path = os.path.join("output", "model.h5")
    model = models.load_model(init_model_path)

    # Compile the model (this re-creates the Adam optimizer with default settings)
    model.compile(optimizer="adam", loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False), metrics=["accuracy"])

    # Training loop
    epoch_counter = 0
    epochs_per_iteration = 10
    while True:
        history = model.fit(train_images, train_labels, epochs=epochs_per_iteration, validation_data=(test_images, test_labels))
        epoch_counter += epochs_per_iteration

        # Save a checkpoint, overwriting the model file we loaded from
        model.save(os.path.join("output", "model.h5"))

        print(f"Total number of epochs this training session: {epoch_counter}")