by Dr Freddy Wordingham
Web App
12. Returning the class predictions
In the previous lesson we looked at how we can deploy our app to the cloud.
Now let's improve our app by showing the probabilities of each category.
We're going to need to update the backend code to send these values through, which we'll cover in this part. And then we'll need to display that information in the frontend, which we'll do in the next part.
š Classify Endpoint
In main.py we need to add a classify endpoint.
First, let's update the ClassifyOutput
class to indicate that we're also going to be returning a dictionary of probability values, formatted as strings:
class ClassifyOutput(BaseModel):
    """Response body for the /classify endpoint."""
    # Human-readable name of the most likely class
    predicted_class: str
    # Maps each class name to its probability, formatted as a string
    # with two decimal places (e.g. "0.97")
    predictions: dict[str, str]
And then we need to add the classify route:
# Classify an image
@app.post("/classify")
async def classify(file: UploadFile = File(...)):
    """Classify an uploaded image with the trained CNN.

    Returns a ClassifyOutput containing the most likely class name and a
    dictionary mapping every class name to its probability, formatted as a
    two-decimal-place string, ordered highest probability first.
    """
    # Load the uploaded image into a numpy array
    image = Image.open(file.file)
    image_array = np.array(image)

    # Normalise to uint8 pixel data. The previous code multiplied by 255
    # unconditionally, which wraps modulo 256 for the uint8 arrays PIL
    # produces from JPEG/PNG input, effectively inverting the image.
    # Only rescale when the data is floating point (assumed in [0, 1]).
    if np.issubdtype(image_array.dtype, np.floating):
        image_array = (image_array * 255).astype("uint8")
    else:
        image_array = image_array.astype("uint8")

    # Ensure the image has 3 channels for RGB, and resize to the 32x32
    # input size the model was trained on
    image_pil = Image.fromarray(image_array).convert("RGB").resize((32, 32))
    image_array = np.array(image_pil)

    # Add a batch dimension
    image_array = tf.expand_dims(image_array, 0)

    # Load the model
    # NOTE(review): loaded on every request; consider caching at module level
    model = models.load_model(os.path.join("output", "model.h5"))

    # Run inference on the single-image batch
    predictions = model.predict(image_array)
    predicted_class = int(np.argmax(predictions))

    # Class indices sorted by descending probability
    sorted_indices = np.argsort(predictions, axis=-1)[:, ::-1]

    # Print the results, highest probability first
    for i in sorted_indices[0]:
        key = CLASS_NAMES[i].ljust(20, ".")
        probability = "{:.1f}".format(predictions[0][i] * 100).rjust(5, " ")
        print(f"{key} : {probability}%")

    # Human-readable class name -> probability as a 2-dp string
    prediction_strings = {
        CLASS_NAMES[i]: f"{float(predictions[0][i]):.2f}" for i in sorted_indices[0]
    }
    return ClassifyOutput(
        predicted_class=CLASS_NAMES[predicted_class],
        predictions=prediction_strings,
    )
This looks like the scripts/classify.py
file we wrote back in chapter 3. We've added lines to sort the predictions, highest to lowest, and then format the results into a dictionary keyed by the human-readable class names, storing each predicted probability as a string with two decimal places. We then modify the returned value of ClassifyOutput to contain the predictions.
š Run
Check that things are behaving as we'd expect (that is, not throwing errors and returning some recognisable results) by running the backend server:
python -m uvicorn main:app --port 8000 --reload
Then poking the classify endpoint by sending it an image using cURL:
curl -X POST "http://127.0.0.1:8000/classify" \
-H "accept: application/json" \
-H "Content-Type: multipart/form-data" \
-F "file=@resources/dog.jpg"
ā Part 2
That's the backend done; in the next part of this chapter we'll display these predictions using the frontend of our app.
š APPENDIX
ā¹ļø How to Run
š§± Build Frontend
Navigate to the frontend/
directory:
cd frontend
Install any missing frontend dependencies:
npm install
Build the files for distributing the frontend to clients:
npm run build
š² Run the Backend
Go back to the project root directory:
cd ..
Activate the virtual environment, if you haven't already:
source .venv/bin/activate
Install any missing packages:
pip install -r requirements.txt
If you haven't already, train a CNN:
python scripts/train.py
Continue training an existing model:
python scripts/continue_training.py
Serve the web app:
python -m uvicorn main:app --port 8000 --reload
š Deploy
Deploy to the cloud:
serverless deploy
Remove from the cloud:
serverless remove
šļø Updated Files
Project structure
.
āāā .venv/
āāā .gitignore
āāā .serverless/
āāā resources
ā āāā dog.jpg
āāā frontend
ā āāā build/
ā āāā node_modules/
ā āāā public/
ā āāā src
ā ā āāā App.css
ā ā āāā App.test.tsx
ā ā āāā App.tsx
ā ā āāā ImageUpload.tsx
ā ā āāā index.css
ā ā āāā index.tsx
ā ā āāā logo.svg
ā ā āāā react-app-env.d.ts
ā ā āāā reportWebVitals.ts
ā ā āāā setupTests.ts
ā ā āāā Sum.tsx
ā āāā .gitignore
ā āāā package-lock.json
ā āāā package.json
ā āāā README.md
ā āāā tsconfig.json
āāā output
ā āāā activations_conv2d/
ā āāā activations_conv2d_1/
ā āāā activations_conv2d_2/
ā āāā activations_dense/
ā āāā activations_dense_1/
ā āāā model.h5
ā āāā sample_images.png
ā āāā training_history.png
āāā scripts
ā āāā classify.py
ā āāā continue_training.py
ā āāā train.py
āāā main.py
āāā README.md
āāā requirements.txt
āāā serverless.yml
main.py
from fastapi import FastAPI, File, UploadFile
from fastapi.middleware.cors import CORSMiddleware
from fastapi.staticfiles import StaticFiles
from mangum import Mangum
from PIL import Image
from pydantic import BaseModel
from tensorflow.keras import models
import numpy as np
import os
import tensorflow as tf
# CIFAR-10 class labels, in the index order the model was trained with
CLASS_NAMES = [
    "airplane", "automobile", "bird", "cat", "deer",
    "dog", "frog", "horse", "ship", "truck",
]
# Instantiate the FastAPI application; all routes below attach to this object
app = FastAPI()
# Ping test method: lightweight health check so deployments can verify
# the API is reachable; always returns the fixed string "pong!"
@app.get("/ping")
def ping():
    return "pong!"
class SumInput(BaseModel):
    """Request body for the /sum endpoint."""
    # First addend
    a: int
    # Second addend
    b: int
class SumOutput(BaseModel):
    """Response body for the /sum endpoint."""
    # The sum of the two request integers
    sum: int
# Sum two numbers together
# NOTE(review): the endpoint function shadows the builtin `sum`; kept as-is
# since the name is part of the existing interface.
@app.post("/sum")
def sum(input: SumInput):
    """Add the two integers in the request body and return their sum."""
    total = input.a + input.b
    return SumOutput(sum=total)
class DimensionsOutput(BaseModel):
    """Response body for the /dimensions endpoint."""
    # Image width in pixels
    width: int
    # Image height in pixels
    height: int
# Tell us the dimensions of an image
@app.post("/dimensions")
def dimensions(file: UploadFile = File(...)):
    """Report the width and height, in pixels, of an uploaded image."""
    pixels = np.array(Image.open(file.file))
    # numpy arrays are indexed (row, column) == (height, width)
    height, width = pixels.shape[0], pixels.shape[1]
    return DimensionsOutput(width=width, height=height)
class ClassifyOutput(BaseModel):
    """Response body for the /classify endpoint."""
    # Human-readable name of the most likely class
    predicted_class: str
    # Maps each class name to its probability, formatted as a string
    # with two decimal places (e.g. "0.97")
    predictions: dict[str, str]
# Classify an image
@app.post("/classify")
async def classify(file: UploadFile = File(...)):
    """Classify an uploaded image with the trained CNN.

    Returns a ClassifyOutput containing the most likely class name and a
    dictionary mapping every class name to its probability, formatted as a
    two-decimal-place string, ordered highest probability first.
    """
    # Load the uploaded image into a numpy array
    image = Image.open(file.file)
    image_array = np.array(image)

    # Normalise to uint8 pixel data. The previous code multiplied by 255
    # unconditionally, which wraps modulo 256 for the uint8 arrays PIL
    # produces from JPEG/PNG input, effectively inverting the image.
    # Only rescale when the data is floating point (assumed in [0, 1]).
    if np.issubdtype(image_array.dtype, np.floating):
        image_array = (image_array * 255).astype("uint8")
    else:
        image_array = image_array.astype("uint8")

    # Ensure the image has 3 channels for RGB, and resize to the 32x32
    # input size the model was trained on
    image_pil = Image.fromarray(image_array).convert("RGB").resize((32, 32))
    image_array = np.array(image_pil)

    # Add a batch dimension
    image_array = tf.expand_dims(image_array, 0)

    # Load the model
    # NOTE(review): loaded on every request; consider caching at module level
    model = models.load_model(os.path.join("output", "model.h5"))

    # Run inference on the single-image batch
    predictions = model.predict(image_array)
    predicted_class = int(np.argmax(predictions))

    # Class indices sorted by descending probability
    sorted_indices = np.argsort(predictions, axis=-1)[:, ::-1]

    # Print the results, highest probability first
    for i in sorted_indices[0]:
        key = CLASS_NAMES[i].ljust(20, ".")
        probability = "{:.1f}".format(predictions[0][i] * 100).rjust(5, " ")
        print(f"{key} : {probability}%")

    # Human-readable class name -> probability as a 2-dp string
    prediction_strings = {
        CLASS_NAMES[i]: f"{float(predictions[0][i]):.2f}" for i in sorted_indices[0]
    }
    return ClassifyOutput(
        predicted_class=CLASS_NAMES[predicted_class],
        predictions=prediction_strings,
    )
# Serve our React application at the root; html=True makes index.html the
# default response for directory requests
app.mount("/", StaticFiles(directory=os.path.join("frontend",
          "build"), html=True), name="build")
# CORS: allow the browser frontend to call this API from any origin.
# NOTE(review): the CORS spec disallows a wildcard origin combined with
# credentials — confirm this permissive policy is intended before production.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # Permits requests from all origins.
    # Allows cookies and credentials to be included in the request.
    allow_credentials=True,
    allow_methods=["*"],  # Allows all HTTP methods.
    allow_headers=["*"]  # Allows all headers.
)
# Define the Lambda handler: Mangum adapts API Gateway/Lambda events into
# ASGI requests for the FastAPI app
handler = Mangum(app)
# Prevent Lambda showing errors in CloudWatch by handling warmup requests correctly
def lambda_handler(event, context):
    """AWS Lambda entry point.

    Scheduled EventBridge warm-up pings arrive with source == "aws.events";
    they carry no HTTP request, so short-circuit them with an empty dict
    instead of passing them to Mangum (which would log an error).
    Real requests are forwarded to the Mangum ASGI handler.
    """
    # .get() avoids the separate "source" in event membership test
    if event.get("source") == "aws.events":
        # Fixed typo in the original message ("warm-ip")
        print("This is a warm-up invocation")
        return {}
    return handler(event, context)