by Dr Freddy Wordingham

12. Returning the class predictions

In the previous lesson we looked at how we can deploy our app to the cloud.

Now let's improve our app by showing the probability of each category.

We're going to need to update the backend code to send these values through, which we'll cover in this part. Then we'll display that information in the frontend, which we'll do in the next part.

🔍 Classify Endpoint


In main.py we need to add a classify endpoint.

First, let's update the ClassifyOutput class to indicate that we're also going to be returning a dictionary mapping each class name to its predicted probability (formatted as a string):

class ClassifyOutput(BaseModel):
    predicted_class: str
    predictions: dict[str, str]
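
As an aside, the dict[str, str] annotation requires Python 3.9 or newer. If you're running an older Python, a minimal sketch of the equivalent using typing.Dict looks like this (nothing else about the class changes):

from typing import Dict

from pydantic import BaseModel


class ClassifyOutput(BaseModel):
    predicted_class: str
    predictions: Dict[str, str]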

And then we need to add the classify route:

# Classify an image
@app.post("/classify")
async def classify(file: UploadFile = File(...)):
    # Load the image
    image = Image.open(file.file)
    image_array = np.array(image)

    # Ensure the image has 3 channels for RGB, and resize to 32x32
    image_pil = Image.fromarray((image_array * 255).astype("uint8"))
    image_pil = image_pil.convert("RGB").resize((32, 32))
    image_array = np.array(image_pil)

    # Add a batch dimension
    image_array = tf.expand_dims(image_array, 0)

    # Load the model
    model = models.load_model(os.path.join("output", "model.h5"))

    # Run the model to get the class probabilities
    predictions = model.predict(image_array)
    predicted_class = np.argmax(predictions)

    # Sort the predictions
    sorted_indices = np.argsort(predictions, axis=-1)[:, ::-1]

    # Print the results
    for i in sorted_indices[0]:
        key = CLASS_NAMES[i].ljust(20, ".")
        probability = "{:.1f}".format(predictions[0][i] * 100).rjust(5, " ")
        print(f"{key} : {probability}%")

    # Format the sorted predictions as a dictionary of class name -> probability string
    predictions = {
        CLASS_NAMES[i]: f"{float(predictions[0][i]):.2f}"
        for i in sorted_indices[0]
    }

    return ClassifyOutput(predicted_class=CLASS_NAMES[predicted_class], predictions=predictions)

This looks like the scripts/classify.py file we wrote back in chapter 3. We've added lines that sort the predictions from highest to lowest and format the results into a dictionary mapping each human-readable class name to its probability, stored as a string with two decimal places. We then return a ClassifyOutput containing both the predicted class and the full set of predictions.
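
If the sorting step feels opaque, here's a minimal standalone sketch of what the np.argsort line is doing, using a made-up prediction array with just three classes (the values are illustrative only):

import numpy as np

CLASS_NAMES = ["airplane", "automobile", "bird"]

# A fake (1, 3) prediction array, shaped like the output of model.predict for a single image
predictions = np.array([[0.10, 0.75, 0.15]])

# argsort returns indices in ascending order; [:, ::-1] flips them to descending
sorted_indices = np.argsort(predictions, axis=-1)[:, ::-1]
print(sorted_indices)  # [[1 2 0]]

# Build the class-name -> probability-string dictionary, highest probability first
print({CLASS_NAMES[i]: f"{float(predictions[0][i]):.2f}" for i in sorted_indices[0]})
# {'automobile': '0.75', 'bird': '0.15', 'airplane': '0.10'}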

šŸ† Run


Check that things are behaving as we'd expect (that is, not throwing errors and returning some recognisable results) by running the backend server:

python -m uvicorn main:app --port 8000 --reload
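
If you just want a quick sanity check that the server is up before uploading anything, the ping endpoint from the earlier chapters is still available:

curl http://127.0.0.1:8000/ping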

Then poke the classify endpoint by sending it an image using cURL:

curl -X POST "http://127.0.0.1:8000/classify" \
    -H  "accept: application/json" \
    -H  "Content-Type: multipart/form-data" \
    -F "file=@resources/dog.jpg"
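
The exact numbers will depend on your trained model, but the response should now include the predictions dictionary alongside the predicted class. An illustrative (made-up) response looks something like this:

{
    "predicted_class": "dog",
    "predictions": {
        "dog": "0.92",
        "cat": "0.04",
        "horse": "0.02",
        "deer": "0.01",
        "bird": "0.01",
        "frog": "0.00",
        "airplane": "0.00",
        "automobile": "0.00",
        "ship": "0.00",
        "truck": "0.00"
    }
}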

ā— Part 2


That's the backend done; in the next part of this chapter we'll display these predictions using the frontend of our app.

📑 APPENDIX


⛹️ How to Run

🧱 Build Frontend

Navigate to the frontend/ directory:

cd frontend

Install any missing frontend dependencies:

npm install

Build the files for distributing the frontend to clients:

npm run build

🖲 Run the Backend

Go back to the project root directory:

cd ..

Activate the virtual environment, if you haven't already:

source .venv/bin/activate

Install any missing packages:

pip install -r requirements.txt

If you haven't already, train a CNN:

python scripts/train.py

Or continue training an existing model:

python scripts/continue_training.py

Serve the web app:

python -m uvicorn main:app --port 8000 --reload

🚀 Deploy

Deploy to the cloud:

serverless deploy

Remove from the cloud:

serverless remove

🗂️ Updated Files

Project structure
.
├── .venv/
├── .gitignore
├── .serverless/
├── resources
│   └── dog.jpg
├── frontend
│   ├── build/
│   ├── node_modules/
│   ├── public/
│   ├── src
│   │   ├── App.css
│   │   ├── App.test.tsx
│   │   ├── App.tsx
│   │   ├── ImageUpload.tsx
│   │   ├── index.css
│   │   ├── index.tsx
│   │   ├── logo.svg
│   │   ├── react-app-env.d.ts
│   │   ├── reportWebVitals.ts
│   │   ├── setupTests.ts
│   │   └── Sum.tsx
│   ├── .gitignore
│   ├── package-lock.json
│   ├── package.json
│   ├── README.md
│   └── tsconfig.json
├── output
│   ├── activations_conv2d/
│   ├── activations_conv2d_1/
│   ├── activations_conv2d_2/
│   ├── activations_dense/
│   ├── activations_dense_1/
│   ├── model.h5
│   ├── sample_images.png
│   └── training_history.png
├── scripts
│   ├── classify.py
│   ├── continue_training.py
│   └── train.py
├── main.py
├── README.md
├── requirements.txt
└── serverless.yml
main.py
from fastapi import FastAPI, File, UploadFile
from fastapi.middleware.cors import CORSMiddleware
from fastapi.staticfiles import StaticFiles
from mangum import Mangum
from PIL import Image
from pydantic import BaseModel
from tensorflow.keras import models
import numpy as np
import os
import tensorflow as tf


CLASS_NAMES = [
    "airplane",
    "automobile",
    "bird",
    "cat",
    "deer",
    "dog",
    "frog",
    "horse",
    "ship",
    "truck"
]


# Instantiate the app
app = FastAPI()


# Ping test method
@app.get("/ping")
def ping():
    return "pong!"


class SumInput(BaseModel):
    a: int
    b: int


class SumOutput(BaseModel):
    sum: int


# Sum two numbers together
@app.post("/sum")
def sum(input: SumInput):
    return SumOutput(sum=input.a + input.b)


class DimensionsOutput(BaseModel):
    width: int
    height: int


# Tell us the dimensions of an image
@app.post("/dimensions")
def dimensions(file: UploadFile = File(...)):
    image = Image.open(file.file)
    image_array = np.array(image)

    width = image_array.shape[1]
    height = image_array.shape[0]

    return DimensionsOutput(width=width, height=height)


class ClassifyOutput(BaseModel):
    predicted_class: str
    predictions: dict[str, str]


# Classify an image
@app.post("/classify")
async def classify(file: UploadFile = File(...)):
    # Load the image
    image = Image.open(file.file)
    image_array = np.array(image)

    # Ensure the image has 3 channels for RGB, and resize to 32x32
    image_pil = Image.fromarray((image_array * 255).astype("uint8"))
    image_pil = image_pil.convert("RGB").resize((32, 32))
    image_array = np.array(image_pil)

    # Add a batch dimension
    image_array = tf.expand_dims(image_array, 0)

    # Load the model
    model = models.load_model(os.path.join("output", "model.h5"))

    # Run the model to get the class probabilities
    predictions = model.predict(image_array)
    predicted_class = np.argmax(predictions)

    # Sort the predictions
    sorted_indices = np.argsort(predictions, axis=-1)[:, ::-1]

    # Print the results
    for i in sorted_indices[0]:
        key = CLASS_NAMES[i].ljust(20, ".")
        probability = "{:.1f}".format(predictions[0][i] * 100).rjust(5, " ")
        print(f"{key} : {probability}%")

    # Format the sorted predictions as a dictionary of class name -> probability string
    predictions = {
        CLASS_NAMES[i]: f"{float(predictions[0][i]):.2f}"
        for i in sorted_indices[0]
    }

    return ClassifyOutput(predicted_class=CLASS_NAMES[predicted_class], predictions=predictions)


# Serve our React application at the root
app.mount("/", StaticFiles(directory=os.path.join("frontend",
          "build"), html=True), name="build")


# CORS
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],    # Permits requests from all origins.
    # Allows cookies and credentials to be included in the request.
    allow_credentials=True,
    allow_methods=["*"],    # Allows all HTTP methods.
    allow_headers=["*"]     # Allows all headers.
)

# Define the Lambda handler
handler = Mangum(app)


# Prevent Lambda from showing errors in CloudWatch by handling warm-up requests correctly
def lambda_handler(event, context):
    if "source" in event and event["source"] == "aws.events":
        print("This is a warm-up invocation")
        return {}
    else:
        return handler(event, context)