Dr Freddy Wordingham

Web App

13. Returning internal model activations

In the last lesson, we added the bar chart to display the probability of each class being the focus of the picture.

This time, we'll send the internal activation images so we can also display those.

We'll do this in two parts, again. First we'll update the backend code to send through this information, and then next time we'll update the frontend to display it.

🗜 Update the Classify Endpoint


We need to update main.py once more so that the classify route also returns images visualising the internal activations of the CNN.

Start by updating the imports at the top of main.py so that they include:

from tensorflow.keras import models, Model
import matplotlib.pyplot as plt
from io import BytesIO
import base64

Then add activation_images to the ClassifyOutput output class in main.py:

class ClassifyOutput(BaseModel):
    predicted_class: str
    predictions: dict[str, str]
    activation_images: dict[str, list[str]]

The activation_images field is a dictionary in which each key is the name of an internal layer. The value for each key is a list of base64-encoded strings, each one a PNG image of the activation map for one channel of that layer.
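Because each string is a complete PNG file in base64 form, a client can display it directly using a data URL of the form data:image/png;base64,<string>. Here is a minimal sketch of that conversion; the helper name to_data_urls is hypothetical, and the payload is made up (just the 8-byte PNG file signature, not a real image):

```python
def to_data_urls(activation_images):
    """Turn {layer_name: [base64 PNG, ...]} into {layer_name: [data URL, ...]}.

    Hypothetical helper: assumes every string is a base64-encoded PNG.
    """
    return {
        layer_name: [f"data:image/png;base64,{img_str}" for img_str in images]
        for layer_name, images in activation_images.items()
    }


# "iVBORw0KGgo=" is the base64 of the 8-byte PNG signature, not a full image
example = {"conv2d": ["iVBORw0KGgo=", "iVBORw0KGgo="], "dense": ["iVBORw0KGgo="]}
urls = to_data_urls(example)
print(urls["dense"][0])  # data:image/png;base64,iVBORw0KGgo=
```

The frontend will need to do something equivalent when we display these images next time.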

We can then update the actual classify route:

# Classify an image
@app.post("/classify")
async def classify(file: UploadFile = File(...)):
    # Load the image
    image = Image.open(file.file)

    # Ensure the image has 3 channels for RGB, and resize to 32x32.
    # PIL already gives 0-255 pixel values, so no rescaling is needed here
    # (multiplying uint8 pixels by 255 would overflow and corrupt the image)
    image_pil = image.convert("RGB").resize((32, 32))
    image_array = np.array(image_pil)

    # Add a batch dimension
    image_array = tf.expand_dims(image_array, 0)

    # Load the model
    model = models.load_model(os.path.join("output", "model.h5"))

    # Sample image
    predictions = model.predict(image_array)
    predicted_class = np.argmax(predictions)

    # Sort the predictions
    sorted_indices = np.argsort(predictions, axis=-1)[:, ::-1]

    # Print the results
    for i in sorted_indices[0]:
        key = CLASS_NAMES[i].ljust(20, ".")
        probability = "{:.1f}".format(predictions[0][i] * 100).rjust(5, " ")
        print(f"{key} : {probability}%")

    predictions = {
        CLASS_NAMES[i]: f"{float(predictions[0][i]):.2f}" for i in sorted_indices[0]}

    # Visualize intermediate layers
    layer_names = [
        layer.name for layer in model.layers if "conv" in layer.name or "dense" in layer.name]
    activation_images = visualise_intermediate_layers(
        model, image_array, layer_names)

    return ClassifyOutput(predicted_class=CLASS_NAMES[predicted_class], predictions=predictions, activation_images=activation_images)

This is the same as we had before, but we've added a few lines just before the return statement. They pick out the layers whose activations we'd like to capture (any layer with "conv" or "dense" in its name), and then get the images for those layers from the visualise_intermediate_layers function.
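The layer selection is a simple substring match on each layer's name. As a rough sketch, here it is run over an assumed list of names: the default Keras names for a small conv/pool/dense CIFAR-10 network, which lines up with the activations_* folders in output/, though your model's names may differ. It also shows that the selected layers sit at different positions in the full layer list, so their activations need to be matched up by name rather than by position:

```python
# Assumed layer names (Keras defaults for conv -> pool -> conv -> pool ->
# conv -> flatten -> dense -> dense). Check your own model with
# [layer.name for layer in model.layers].
all_layer_names = [
    "conv2d", "max_pooling2d", "conv2d_1", "max_pooling2d_1",
    "conv2d_2", "flatten", "dense", "dense_1",
]

# The same substring filter used in the classify route
layer_names = [
    name for name in all_layer_names if "conv" in name or "dense" in name]
print(layer_names)  # ['conv2d', 'conv2d_1', 'conv2d_2', 'dense', 'dense_1']

# Position i in layer_names is generally NOT position i in model.layers,
# so look layers up by name (e.g. model.get_layer(name)) instead
print([all_layer_names.index(name) for name in layer_names])  # [0, 2, 4, 6, 7]
```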

We define the visualise_intermediate_layers function below:

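def visualise_intermediate_layers(model, image_array, layer_names):
    """Visualise the intermediate layers of a model"""

    # Only request the outputs of the named layers, so that each activation
    # lines up with its layer name when we zip them together below
    layer_outputs = [model.get_layer(name).output for name in layer_names]
    activation_model = Model(inputs=model.input, outputs=layer_outputs)
    activations = activation_model.predict(image_array)
    if not isinstance(activations, (list, tuple)):
        # predict() may return a bare array rather than a list for one output
        activations = [activations]
    activation_images = {}

    for layer_name, activation in zip(layer_names, activations):
        activation_images[layer_name] = []

        # Dense layers output a flat vector of shape (1, units); reshape it to
        # (1, 1, units, 1) so it is plotted as a single one-pixel-high strip
        if activation.ndim == 2:
            activation = activation.reshape(1, 1, -1, 1)

        num_channels = activation.shape[-1]

        for i in range(num_channels):
            plt.figure()
            plt.imshow(activation[0, :, :, i], cmap="viridis")
            plt.axis("off")
            plt.subplots_adjust(left=0, right=1, top=1, bottom=0)

            # Save the image to a bytes buffer
            buf = BytesIO()
            plt.savefig(buf, format="png", bbox_inches="tight", pad_inches=0)
            buf.seek(0)

            # Convert the image to a base64 string
            img_str = base64.b64encode(buf.read()).decode()

            activation_images[layer_name].append(img_str)

            plt.close()

    return activation_images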

This should look familiar from the code we wrote back in chapter 4, but now, instead of writing each image to disk, we save it to an in-memory buffer, encode it as a base64 string, and add it to the activation_images dictionary.
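To make the in-memory step concrete without needing matplotlib, here's a minimal sketch of just the buffer and base64 part. It uses only the standard library, with the 8-byte PNG file signature standing in for the bytes plt.savefig would write:

```python
import base64
from io import BytesIO

# Stand-in for the PNG bytes plt.savefig(buf, format="png") would write
png_bytes = b"\x89PNG\r\n\x1a\n"

# Write to an in-memory buffer instead of a file on disk
buf = BytesIO()
buf.write(png_bytes)
buf.seek(0)  # rewind, otherwise read() starts at the end and returns b""

# Encode to base64, then decode the resulting bytes to a str for JSON
img_str = base64.b64encode(buf.read()).decode()
print(img_str)  # iVBORw0KGgo=

# Any client can reverse the process losslessly
print(base64.b64decode(img_str) == png_bytes)  # True

# base64 produces 4 characters for every 3 input bytes (rounded up)
print(len(img_str), 4 * ((len(png_bytes) + 2) // 3))  # 12 12
```

The last line hints at why the responses get large: each PNG grows by roughly a third on its way to the frontend.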

⛹️ Run


Once again, it's good to check that things are behaving as we'd expect. We can start the backend server running:

source .venv/bin/activate
python -m uvicorn main:app --port 8000 --reload

and then poke at it with:

curl -X POST "http://127.0.0.1:8000/classify" \
    -H  "accept: application/json" \
    -H  "Content-Type: multipart/form-data" \
    -F "file=@resources/dog.jpg"

But as we're expecting to receive many base64-encoded images in the response, don't be surprised if the terminal fills up with what looks like a wall of random characters!
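If you'd rather see a readable summary than a screen of base64, one option is to save the response with curl's -o flag (for example -o response.json) and summarise it in Python. The helper below, summarise_response, is hypothetical and assumes the response matches the ClassifyOutput model defined above; the sample body it's run on is made up:

```python
import json


def summarise_response(body):
    """Summarise a /classify JSON response without printing the base64 data.

    Hypothetical helper; assumes the ClassifyOutput fields from main.py.
    """
    data = json.loads(body)
    lines = [f"predicted_class: {data['predicted_class']}"]
    for layer_name, images in data["activation_images"].items():
        # Count images and total base64 characters per layer
        total_chars = sum(len(img) for img in images)
        lines.append(f"{layer_name}: {len(images)} images, {total_chars} base64 chars")
    return "\n".join(lines)


# A tiny made-up response; real responses are far larger
fake_body = json.dumps({
    "predicted_class": "dog",
    "predictions": {"dog": "0.91", "cat": "0.05"},
    "activation_images": {"conv2d": ["iVBORw0KGgo=", "iVBORw0KGgo="]},
})
print(summarise_response(fake_body))
```

With a saved response you could call summarise_response(open("response.json").read()) instead.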

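If we look in the terminal running the backend, we should see the class probabilities printed by the classify route, sorted from most to least likely, which confirms that the image was received and processed successfully.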
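Those log lines come from the print loop in the classify route. Pulled out into a small hypothetical function, format_prediction, the formatting works like this (the probabilities passed in are made-up values, not real model output):

```python
def format_prediction(class_name, probability):
    """Format one line of the table printed by the classify route."""
    # Pad the class name with dots to a fixed width of 20 characters
    key = class_name.ljust(20, ".")
    # Convert to a percentage with one decimal place, right-aligned in 5 characters
    percent = "{:.1f}".format(probability * 100).rjust(5, " ")
    return f"{key} : {percent}%"


# Made-up probabilities, already sorted from most to least likely
for class_name, probability in [("dog", 0.912), ("cat", 0.051)]:
    print(format_prediction(class_name, probability))
```

The fixed widths keep the percentages lined up in a column, whatever the length of the class name.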

📢 Part 2


That's the backend done; in the next part of this chapter we'll display this new data in the frontend.

📑 APPENDIX


🏃 How to Run

🧱 Build Frontend

Navigate to the frontend/ directory:

cd frontend

Install any missing frontend dependencies:

npm install

Build the files for distributing the frontend to clients:

npm run build

🖲 Run the Backend

Go back to the project root directory:

cd ..

Activate the virtual environment, if you haven't already:

source .venv/bin/activate

Install any missing packages:

pip install -r requirements.txt

If you haven't already, train a CNN:

python scripts/train.py

Optionally, continue training an existing model:

python scripts/continue_training.py

Serve the web app:

python -m uvicorn main:app --port 8000 --reload

🚀 Deploy

Deploy to the cloud:

serverless deploy

Remove from the cloud:

serverless remove

🗂️ Updated Files

Project structure
.
├── .venv/
├── .gitignore
├── .serverless/
├── resources
│   └── dog.jpg
├── frontend
│   ├── build/
│   ├── node_modules/
│   ├── public/
│   ├── src
│   │   ├── App.css
│   │   ├── App.test.tsx
│   │   ├── App.tsx
│   │   ├── ImageGrid.tsx
│   │   ├── ImageUpload.tsx
│   │   ├── index.css
│   │   ├── index.tsx
│   │   ├── Predictions.css
│   │   ├── Predictions.tsx
│   │   ├── logo.svg
│   │   ├── react-app-env.d.ts
│   │   ├── reportWebVitals.ts
│   │   ├── setupTests.ts
│   │   └── Sum.tsx
│   ├── .gitignore
│   ├── package-lock.json
│   ├── package.json
│   ├── README.md
│   └── tsconfig.json
├── output
│   ├── activations_conv2d/
│   ├── activations_conv2d_1/
│   ├── activations_conv2d_2/
│   ├── activations_dense/
│   ├── activations_dense_1/
│   ├── model.h5
│   ├── sample_images.png
│   └── training_history.png
├── scripts
│   ├── classify.py
│   ├── continue_training.py
│   └── train.py
├── main.py
├── README.md
├── requirements.txt
└── serverless.yml
main.py
from fastapi import FastAPI, File, UploadFile
from fastapi.middleware.cors import CORSMiddleware
from fastapi.staticfiles import StaticFiles
from mangum import Mangum
from PIL import Image
from pydantic import BaseModel
from tensorflow.keras import models, Model
import numpy as np
import os
import tensorflow as tf
import matplotlib.pyplot as plt
from io import BytesIO
import base64


CLASS_NAMES = [
    "airplane",
    "automobile",
    "bird",
    "cat",
    "deer",
    "dog",
    "frog",
    "horse",
    "ship",
    "truck"
]


# Instantiate the app
app = FastAPI()


# Ping test method
@app.get("/ping")
def ping():
    return "pong!"


class SumInput(BaseModel):
    a: int
    b: int


class SumOutput(BaseModel):
    sum: int


# Sum two numbers together
@app.post("/sum")
def sum(input: SumInput):
    return SumOutput(sum=input.a + input.b)


class DimensionsOutput(BaseModel):
    width: int
    height: int


# Tell us the dimensions of an image
@app.post("/dimensions")
def dimensions(file: UploadFile = File(...)):
    image = Image.open(file.file)
    image_array = np.array(image)

    width = image_array.shape[1]
    height = image_array.shape[0]

    return DimensionsOutput(width=width, height=height)


class ClassifyOutput(BaseModel):
    predicted_class: str
    predictions: dict[str, str]
    activation_images: dict[str, list[str]]


# Classify an image
@app.post("/classify")
async def classify(file: UploadFile = File(...)):
    # Load the image
    image = Image.open(file.file)

    # Ensure the image has 3 channels for RGB, and resize to 32x32.
    # PIL already gives 0-255 pixel values, so no rescaling is needed here
    # (multiplying uint8 pixels by 255 would overflow and corrupt the image)
    image_pil = image.convert("RGB").resize((32, 32))
    image_array = np.array(image_pil)

    # Add a batch dimension
    image_array = tf.expand_dims(image_array, 0)

    # Load the model
    model = models.load_model(os.path.join("output", "model.h5"))

    # Sample image
    predictions = model.predict(image_array)
    predicted_class = np.argmax(predictions)

    # Sort the predictions
    sorted_indices = np.argsort(predictions, axis=-1)[:, ::-1]

    # Print the results
    for i in sorted_indices[0]:
        key = CLASS_NAMES[i].ljust(20, ".")
        probability = "{:.1f}".format(predictions[0][i] * 100).rjust(5, " ")
        print(f"{key} : {probability}%")

    predictions = {
        CLASS_NAMES[i]: f"{float(predictions[0][i]):.2f}" for i in sorted_indices[0]}

    # Visualize intermediate layers
    layer_names = [
        layer.name for layer in model.layers if "conv" in layer.name or "dense" in layer.name]
    activation_images = visualise_intermediate_layers(
        model, image_array, layer_names)

    return ClassifyOutput(predicted_class=CLASS_NAMES[predicted_class], predictions=predictions, activation_images=activation_images)


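def visualise_intermediate_layers(model, image_array, layer_names):
    """Visualise the intermediate layers of a model"""

    # Only request the outputs of the named layers, so that each activation
    # lines up with its layer name when we zip them together below
    layer_outputs = [model.get_layer(name).output for name in layer_names]
    activation_model = Model(inputs=model.input, outputs=layer_outputs)
    activations = activation_model.predict(image_array)
    if not isinstance(activations, (list, tuple)):
        # predict() may return a bare array rather than a list for one output
        activations = [activations]
    activation_images = {}

    for layer_name, activation in zip(layer_names, activations):
        activation_images[layer_name] = []

        # Dense layers output a flat vector of shape (1, units); reshape it to
        # (1, 1, units, 1) so it is plotted as a single one-pixel-high strip
        if activation.ndim == 2:
            activation = activation.reshape(1, 1, -1, 1)

        num_channels = activation.shape[-1]

        for i in range(num_channels):
            plt.figure()
            plt.imshow(activation[0, :, :, i], cmap="viridis")
            plt.axis("off")
            plt.subplots_adjust(left=0, right=1, top=1, bottom=0)

            # Save the image to a bytes buffer
            buf = BytesIO()
            plt.savefig(buf, format="png", bbox_inches="tight", pad_inches=0)
            buf.seek(0)

            # Convert the image to a base64 string
            img_str = base64.b64encode(buf.read()).decode()

            activation_images[layer_name].append(img_str)

            plt.close()

    return activation_images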


# Serve our React application at the root
app.mount("/", StaticFiles(directory=os.path.join("frontend",
          "build"), html=True), name="build")


# CORS
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],    # Permits requests from all origins.
    # Allows cookies and credentials to be included in the request.
    allow_credentials=True,
    allow_methods=["*"],    # Allows all HTTP methods.
    allow_headers=["*"]     # Allows all headers.
)

# Define the Lambda handler
handler = Mangum(app)


# Prevent Lambda showing errors in CloudWatch by handling warmup requests correctly
def lambda_handler(event, context):
    if "source" in event and event["source"] == "aws.events":
        print("This is a warm-up invocation")
        return {}
    else:
        return handler(event, context)