by Dr Freddy Wordingham
Web App
13. Returning internal model activations
In the last lesson, we added the bar chart to display the probability of each class being the focus of the picture.
This time, we'll send the internal activation images so we can also display those.
Again, we'll do this in two parts: first we'll update the backend code to send this information through, and next time we'll update the frontend to display it.
Update the Classify Endpoint
Update the imports:
from tensorflow.keras import models, Model
import matplotlib.pyplot as plt
from io import BytesIO
import base64
We need to update main.py once more so that the classify route also returns the images visualising the internal activations of the CNN.
Start by adding an activation_images field to the ClassifyOutput class in main.py:
class ClassifyOutput(BaseModel):
predicted_class: str
    predictions: dict[str, str]
activation_images: dict[str, list[str]]
We return a dictionary where each key is the name of an internal layer. The value for each key is a list of base64-encoded strings, one for each activation map image of that layer.
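To make that concrete, here's a rough sketch of the shape of the response body. The layer names, class probabilities, and (heavily truncated) base64 strings are purely illustrative, not real output:
# Illustrative shape of the /classify response (placeholder values;
# the base64 strings are truncated for readability)
example_response = {
    "predicted_class": "dog",
    "predictions": {"dog": "0.81", "cat": "0.09", "deer": "0.04"},
    "activation_images": {
        "conv2d": ["iVBORw0KGgoAAAANS", "iVBORw0KGgoAAAANS"],  # one string per channel
        "dense": ["iVBORw0KGgoAAAANS"],
    },
}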
We can then update the classify route itself:
# Classify an image
@app.post("/classify")
async def classify(file: UploadFile = File(...)):
# Load the image
image = Image.open(file.file)
image_array = np.array(image)
# Ensure the image has 3 channels for RGB, and resize to 32x32
image_pil = Image.fromarray((image_array * 255).astype("uint8"))
image_pil = image_pil.convert("RGB").resize((32, 32))
image_array = np.array(image_pil)
# Add a batch dimension
image_array = tf.expand_dims(image_array, 0)
# Load the model
model = models.load_model(os.path.join("output", "model.h5"))
    # Make a prediction
predictions = model.predict(image_array)
predicted_class = np.argmax(predictions)
# Sort the predictions
sorted_indices = np.argsort(predictions, axis=-1)[:, ::-1]
# Print the results
for i in sorted_indices[0]:
key = CLASS_NAMES[i].ljust(20, ".")
probability = "{:.1f}".format(predictions[0][i] * 100).rjust(5, " ")
print(f"{key} : {probability}%")
predictions = {
CLASS_NAMES[i]: f"{float(predictions[0][i]):.2f}" for i in sorted_indices[0]}
# Visualize intermediate layers
layer_names = [
layer.name for layer in model.layers if "conv" in layer.name or "dense" in layer.name]
activation_images = visualise_intermediate_layers(
model, image_array, layer_names)
return ClassifyOutput(predicted_class=CLASS_NAMES[predicted_class], predictions=predictions, activation_images=activation_images)
This is the same as we had before, but we've added a few lines just before the return statement: we specify which layers we'd like to capture the activations of, and then get those images from the visualise_intermediate_layers function.
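If you're not sure which layer names that filter will pick up, a quick sketch like the following (run in a Python shell, assuming the model from the earlier chapters is saved at output/model.h5) will list them:
from tensorflow.keras import models

model = models.load_model("output/model.h5")

# Print every layer name; the filter above keeps those containing "conv" or "dense"
print([layer.name for layer in model.layers])
# For the CNN from the earlier chapters this is typically something like:
# ['conv2d', 'max_pooling2d', 'conv2d_1', 'max_pooling2d_1', 'conv2d_2', 'flatten', 'dense', 'dense_1']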
We define the visualise_intermediate_layers function below:
def visualise_intermediate_layers(model, image_array, layer_names):
    """Visualise the intermediate layers of a model"""
    # Only capture the outputs of the layers we were asked for, so that the
    # activations line up with layer_names below
    layer_outputs = [model.get_layer(name).output for name in layer_names]
    activation_model = Model(inputs=model.input, outputs=layer_outputs)
    activations = activation_model.predict(image_array)

    activation_images = {}
    for layer_name, activation in zip(layer_names, activations):
        activation_images[layer_name] = []

        # Convolutional layers give a (batch, height, width, channels) array;
        # dense layers give a (batch, units) vector, shown here as a single row
        if activation.ndim == 4:
            channel_images = [activation[0, :, :, i]
                              for i in range(activation.shape[-1])]
        else:
            channel_images = [activation.reshape(1, -1)]

        for channel_image in channel_images:
            plt.figure()
            plt.imshow(channel_image, cmap="viridis")
            plt.axis("off")
            plt.subplots_adjust(left=0, right=1, top=1, bottom=0)

            # Save the image to a bytes buffer
            buf = BytesIO()
            plt.savefig(buf, format="png", bbox_inches="tight", pad_inches=0)
            buf.seek(0)

            # Convert the image to a base64 string
            img_str = base64.b64encode(buf.read()).decode()
            activation_images[layer_name].append(img_str)

            plt.close()

    return activation_images
This should look similar to the code we wrote back in chapter 4, but now instead of writing each image to disk we encode it as a base64 string and add it to the activation_images dictionary.
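If you'd like to check that the encoding round-trips correctly, here's a small sketch that turns one of the strings back into a PNG file. It assumes you've saved a /classify response to a file called response.json and that the model has a layer named conv2d:
import base64
import json

# Load a saved /classify response (hypothetical file name)
with open("response.json") as f:
    body = json.load(f)

# Decode the first activation map of the conv2d layer back into a PNG file
img_b64 = body["activation_images"]["conv2d"][0]
with open("conv2d_0.png", "wb") as f:
    f.write(base64.b64decode(img_b64))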
Run
Once again, it's good to check that things are behaving as we'd expect. We can start the backend server running:
source .venv/bin/activate
python -m uvicorn main:app --port 8000 --reload
and then poke at it with:
curl -X POST "http://127.0.0.1:8000/classify" \
-H "accept: application/json" \
-H "Content-Type: multipart/form-data" \
-F "file=@resources/dog.jpg"
But as we're expecting to receive many base64-encoded images in the output, don't be surprised if you just see a lot of random-looking characters appear in the terminal in response!
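If you'd rather not wade through the raw output, a short Python script can summarise the response instead. This is just a sketch, assuming the requests package is installed and the server is running locally:
import requests

# Send the same request as the curl command above
with open("resources/dog.jpg", "rb") as f:
    body = requests.post("http://127.0.0.1:8000/classify", files={"file": f}).json()

# Print the prediction and how many activation images came back per layer
print(body["predicted_class"])
print(body["predictions"])
for layer_name, images in body["activation_images"].items():
    print(f"{layer_name}: {len(images)} activation images")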
If we look in the terminal running the backend, we should see that we successfully received and processed the image.
Part 2
That's the backend done; in the next part of this chapter we'll display this new data in the frontend.
APPENDIX
How to Run
Build Frontend
Navigate to the frontend/ directory:
cd frontend
Install any missing frontend dependencies:
npm install
Build the files for distributing the frontend to clients:
npm run build
Run the Backend
Go back to the project root directory:
cd ..
Activate the virtual environment, if you haven't already:
source .venv/bin/activate
Install any missing packages:
pip install -r requirements.txt
If you haven't already, train a CNN:
python scripts/train.py
Continue training an existing model:
python scripts/continue_training.py
Serve the web app:
python -m uvicorn main:app --port 8000 --reload
Deploy
Deploy to the cloud:
serverless deploy
Remove from the cloud:
serverless remove
Updated Files
Project structure
.
├── .venv/
├── .gitignore
├── .serverless/
├── resources
│   └── dog.jpg
├── frontend
│   ├── build/
│   ├── node_modules/
│   ├── public/
│   ├── src
│   │   ├── App.css
│   │   ├── App.test.tsx
│   │   ├── App.tsx
│   │   ├── ImageGrid.tsx
│   │   ├── ImageUpload.tsx
│   │   ├── index.css
│   │   ├── index.tsx
│   │   ├── Predictions.css
│   │   ├── Predictions.tsx
│   │   ├── logo.svg
│   │   ├── react-app-env.d.ts
│   │   ├── reportWebVitals.ts
│   │   ├── setupTests.ts
│   │   └── Sum.tsx
│   ├── .gitignore
│   ├── package-lock.json
│   ├── package.json
│   ├── README.md
│   └── tsconfig.json
├── output
│   ├── activations_conv2d/
│   ├── activations_conv2d_1/
│   ├── activations_conv2d_2/
│   ├── activations_dense/
│   ├── activations_dense_1/
│   ├── model.h5
│   ├── sample_images.png
│   └── training_history.png
├── scripts
│   ├── classify.py
│   ├── continue_training.py
│   └── train.py
├── main.py
├── README.md
├── requirements.txt
└── serverless.yml
main.py
from fastapi import FastAPI, File, UploadFile
from fastapi.middleware.cors import CORSMiddleware
from fastapi.staticfiles import StaticFiles
from mangum import Mangum
from PIL import Image
from pydantic import BaseModel
from tensorflow.keras import models, Model
import numpy as np
import os
import tensorflow as tf
import matplotlib.pyplot as plt
from io import BytesIO
import base64
CLASS_NAMES = [
"airplane",
"automobile",
"bird",
"cat",
"deer",
"dog",
"frog",
"horse",
"ship",
"truck"
]
# Instantiate the app
app = FastAPI()
# Ping test method
@app.get("/ping")
def ping():
return "pong!"
class SumInput(BaseModel):
a: int
b: int
class SumOutput(BaseModel):
sum: int
# Sum two numbers together
@app.post("/sum")
def sum(input: SumInput):
return SumOutput(sum=input.a + input.b)
class DimensionsOutput(BaseModel):
width: int
height: int
# Tell us the dimensions of an image
@app.post("/dimensions")
def dimensions(file: UploadFile = File(...)):
image = Image.open(file.file)
image_array = np.array(image)
width = image_array.shape[1]
height = image_array.shape[0]
return DimensionsOutput(width=width, height=height)
class ClassifyOutput(BaseModel):
predicted_class: str
predictions: dict[str, str]
activation_images: dict[str, list[str]]
# Classify an image
@app.post("/classify")
async def classify(file: UploadFile = File(...)):
# Load the image
image = Image.open(file.file)
image_array = np.array(image)
# Ensure the image has 3 channels for RGB, and resize to 32x32
image_pil = Image.fromarray((image_array * 255).astype("uint8"))
image_pil = image_pil.convert("RGB").resize((32, 32))
image_array = np.array(image_pil)
# Add a batch dimension
image_array = tf.expand_dims(image_array, 0)
# Load the model
model = models.load_model(os.path.join("output", "model.h5"))
    # Make a prediction
predictions = model.predict(image_array)
predicted_class = np.argmax(predictions)
# Sort the predictions
sorted_indices = np.argsort(predictions, axis=-1)[:, ::-1]
# Print the results
for i in sorted_indices[0]:
key = CLASS_NAMES[i].ljust(20, ".")
probability = "{:.1f}".format(predictions[0][i] * 100).rjust(5, " ")
print(f"{key} : {probability}%")
predictions = {
CLASS_NAMES[i]: f"{float(predictions[0][i]):.2f}" for i in sorted_indices[0]}
# Visualize intermediate layers
layer_names = [
layer.name for layer in model.layers if "conv" in layer.name or "dense" in layer.name]
activation_images = visualise_intermediate_layers(
model, image_array, layer_names)
return ClassifyOutput(predicted_class=CLASS_NAMES[predicted_class], predictions=predictions, activation_images=activation_images)
def visualise_intermediate_layers(model, image_array, layer_names):
    """Visualise the intermediate layers of a model"""
    # Only capture the outputs of the layers we were asked for, so that the
    # activations line up with layer_names below
    layer_outputs = [model.get_layer(name).output for name in layer_names]
    activation_model = Model(inputs=model.input, outputs=layer_outputs)
    activations = activation_model.predict(image_array)

    activation_images = {}
    for layer_name, activation in zip(layer_names, activations):
        activation_images[layer_name] = []

        # Convolutional layers give a (batch, height, width, channels) array;
        # dense layers give a (batch, units) vector, shown here as a single row
        if activation.ndim == 4:
            channel_images = [activation[0, :, :, i]
                              for i in range(activation.shape[-1])]
        else:
            channel_images = [activation.reshape(1, -1)]

        for channel_image in channel_images:
            plt.figure()
            plt.imshow(channel_image, cmap="viridis")
            plt.axis("off")
            plt.subplots_adjust(left=0, right=1, top=1, bottom=0)

            # Save the image to a bytes buffer
            buf = BytesIO()
            plt.savefig(buf, format="png", bbox_inches="tight", pad_inches=0)
            buf.seek(0)

            # Convert the image to a base64 string
            img_str = base64.b64encode(buf.read()).decode()
            activation_images[layer_name].append(img_str)

            plt.close()

    return activation_images
# Serve our React application at the root
app.mount("/", StaticFiles(directory=os.path.join("frontend",
"build"), html=True), name="build")
# CORS
app.add_middleware(
CORSMiddleware,
allow_origins=["*"], # Permits requests from all origins.
# Allows cookies and credentials to be included in the request.
allow_credentials=True,
allow_methods=["*"], # Allows all HTTP methods.
allow_headers=["*"] # Allows all headers.
)
# Define the Lambda handler
handler = Mangum(app)
# Prevent Lambda showing errors in CloudWatch by handling warmup requests correctly
def lambda_handler(event, context):
if "source" in event and event["source"] == "aws.events":
        print("This is a warm-up invocation")
return {}
else:
return handler(event, context)