Spaces:
Running
on
Zero
| from functools import partial | |
| from huggingface_hub import InferenceClient | |
| from os import path, unlink | |
| import gradio as gr | |
| from PIL.Image import Image | |
| import pandas as pd | |
| from pandas import DataFrame | |
| from utils import save_image_to_temp_file, request_image | |
def image_classification(client: InferenceClient, model: str, image: Image) -> DataFrame:
    """Classify an image using the Hugging Face Inference API.

    This function classifies a recyclable item image into categories:
    cardboard, glass, metal, paper, plastic, or other. The image is saved
    to a temporary file since InferenceClient requires a file path rather than
    a PIL Image object directly.

    Args:
        client: Hugging Face InferenceClient instance for API calls.
        model: Hugging Face model ID to use for image classification.
        image: PIL Image object to classify.

    Returns:
        Pandas DataFrame with two columns:
            - Label: The classification label (e.g., "cardboard", "glass")
            - Probability: The confidence score as a percentage string (e.g., "95.23%")

    Note:
        - Automatically cleans up the temporary file after classification.
        - Cleanup is best-effort: removal errors are ignored.
    """
    # Initialize before the try so the finally block never reads an unbound
    # name if save_image_to_temp_file itself raises (that NameError would
    # otherwise mask the original exception).
    temp_file_path = None
    try:
        # Needed because InferenceClient does not accept PIL Images directly.
        temp_file_path = save_image_to_temp_file(image)
        classifications = client.image_classification(temp_file_path, model=model)
        return pd.DataFrame(
            {
                "Label": classification.label,
                "Probability": f"{classification.score:.2%}",
            }
            for classification in classifications
        )
    finally:
        # Clean up the temporary file.
        if temp_file_path and path.exists(temp_file_path):
            try:
                unlink(temp_file_path)
            except OSError:
                pass  # Ignore clean-up errors; classification result still returns.
def create_image_classification_tab(client: InferenceClient, model: str):
    """Build the image-classification tab of the Gradio UI.

    Lays out, in order: a URL textbox plus a button that fetches the image
    via request_image, an image preview, and a classify button that sends
    the preview through image_classification and shows the resulting
    label/probability table.

    Args:
        client: Hugging Face InferenceClient instance forwarded to image_classification.
        model: Hugging Face model ID to use for image classification.
    """
    gr.Markdown("Classify a recyclable item as one of: cardboard, glass, metal, paper, plastic, or other using [Trash-Net](https://huggingface.co/prithivMLmods/Trash-Net).")

    # Image acquisition: fetch an image from a user-supplied URL.
    url_box = gr.Textbox(label="Image URL")
    fetch_button = gr.Button("Get Image")
    preview = gr.Image(label="Image", type="pil")
    fetch_button.click(fn=request_image, inputs=url_box, outputs=preview)

    # Classification: run the model on the previewed image and tabulate scores.
    classify_button = gr.Button("Classify")
    results = gr.Dataframe(label="Classification", headers=["Label", "Probability"], interactive=False)
    classify_button.click(
        # Bind client/model up front; Gradio supplies only the image input.
        fn=partial(image_classification, client, model),
        inputs=preview,
        outputs=results,
    )