ai-building-blocks / image_classification.py
LiKenun's picture
Move environment variable querying code out of the Gradio UI-construction functions all the way to the root of the application, `app.py`
55d79e2
raw
history blame
3.57 kB
from functools import partial
from huggingface_hub import InferenceClient
from os import path, unlink
import gradio as gr
from PIL.Image import Image
import pandas as pd
from pandas import DataFrame
from utils import save_image_to_temp_file, request_image
def image_classification(client: InferenceClient, model: str, image: Image) -> DataFrame:
    """Classify an image using the Hugging Face Inference API.

    Intended for recyclable-item images (e.g. cardboard, glass, metal, paper,
    plastic, or other); the actual label set is whatever ``model`` returns.
    The image is first written to a temporary file because the
    InferenceClient is handed a file path rather than a PIL Image object.

    Args:
        client: Hugging Face InferenceClient instance for API calls.
        model: Hugging Face model ID to use for image classification.
        image: PIL Image object to classify.

    Returns:
        Pandas DataFrame with two columns, one row per classification in the
        order returned by the API:
        - Label: The classification label (e.g., "cardboard", "glass").
        - Probability: The confidence score as a percentage string (e.g., "95.23%").
        If the API returns no classifications, the DataFrame is empty but
        still has both column headers.

    Note:
        The temporary file is deleted afterwards on a best-effort basis, even
        when the API call raises; deletion errors are ignored.
    """
    # Bound before ``try`` so the ``finally`` clause can never hit an unbound
    # name (which would mask the real error if saving the image failed).
    temp_file_path = None
    try:
        # Needed because InferenceClient does not accept PIL Images directly.
        temp_file_path = save_image_to_temp_file(image)
        classifications = client.image_classification(temp_file_path, model=model)
        rows = [
            {
                "Label": classification.label,
                "Probability": f"{classification.score:.2%}",
            }
            for classification in classifications
        ]
        # Explicit columns keep the headers even when ``rows`` is empty.
        return pd.DataFrame(rows, columns=["Label", "Probability"])
    finally:
        if temp_file_path and path.exists(temp_file_path):  # Clean up temporary file.
            try:
                unlink(temp_file_path)
            except OSError:
                pass  # Best-effort clean-up; a leftover temp file is not fatal.
def create_image_classification_tab(client: InferenceClient, model: str):
    """Build the image-classification tab's components and wire their events.

    Components are created in this order:
    - a description (Markdown),
    - an "Image URL" textbox with a "Get Image" button that loads the image,
    - an image preview (PIL),
    - a "Classify" button with a read-only results table (Label, Probability).

    No layout container is opened here, so the components land in whatever
    Gradio context the caller has active (presumably a Blocks/Tab — confirm
    at the call site).

    Args:
        client: Hugging Face InferenceClient instance forwarded to image_classification.
        model: Hugging Face model ID forwarded to image_classification.
    """
    gr.Markdown("Classify a recyclable item as one of: cardboard, glass, metal, paper, plastic, or other using [Trash-Net](https://huggingface.co/prithivMLmods/Trash-Net).")

    # Fetch-by-URL section: the button downloads the URL into the preview.
    url_box = gr.Textbox(label="Image URL")
    fetch_button = gr.Button("Get Image")
    preview = gr.Image(label="Image", type="pil")
    fetch_button.click(fn=request_image, inputs=url_box, outputs=preview)

    # Pre-bind client and model so the callback receives only the image.
    classify_handler = partial(image_classification, client, model)

    classify_button = gr.Button("Classify")
    results_table = gr.Dataframe(
        label="Classification",
        headers=["Label", "Probability"],
        interactive=False,
    )
    classify_button.click(fn=classify_handler, inputs=preview, outputs=results_table)