# yl0628's picture
# Update app.py
# ed1b89e verified
# app.py
import os
import shutil
import zipfile
import pathlib
import tempfile
import mimetypes
import gradio as gr
import pandas as pd
from PIL import Image, UnidentifiedImageError
import huggingface_hub
from autogluon.multimodal import MultiModalPredictor
# -----------------------------
# Config / constants
# -----------------------------
# Location of the trained predictor on the Hugging Face Hub.
MODEL_REPO_ID = "ddecosmo/nn_automl_model"  # model repository id
ZIP_FILENAME = "autogluon_image_predictor_dir.zip"  # zipped predictor stored in that repo
HF_TOKEN = os.environ.get("HF_TOKEN")  # optional access token; None when the variable is unset

# Local folders (relative to the working directory) for the download and its extraction.
CACHE_DIR = pathlib.Path("hf_assets")
EXTRACT_DIR = CACHE_DIR.joinpath("predictor_native")

# Upload restrictions and preprocessing size.
ALLOWED_EXTS = {".jpeg", ".jpg", ".png"}
ALLOWED_MIME = {"image/jpeg", "image/png"}
MAX_BYTES = 5 * 1024 ** 2  # largest accepted upload, in bytes (5 MiB)
TARGET = 224  # side length, in pixels, of the square image handed to the model
# -----------------------------
# Model setup
# -----------------------------
def _prepare_predictor_dir() -> str:
    """Fetch the zipped predictor from the Hub, unpack it, and return its root.

    The archive is downloaded into ``CACHE_DIR`` and extracted into
    ``EXTRACT_DIR``, which is deleted and recreated on every call.

    Returns:
        The directory to load the predictor from: the archive's single
        top-level folder when the extraction produced exactly one directory,
        otherwise ``EXTRACT_DIR`` itself.
    """
    CACHE_DIR.mkdir(parents=True, exist_ok=True)
    archive_path = huggingface_hub.hf_hub_download(
        repo_id=MODEL_REPO_ID,
        filename=ZIP_FILENAME,
        repo_type="model",
        token=HF_TOKEN,  # works if model is private; otherwise ignored
        local_dir=str(CACHE_DIR),
        local_dir_use_symlinks=False,
    )

    # Start from an empty extraction folder so files from an earlier run
    # cannot mix with the freshly unpacked ones.
    if EXTRACT_DIR.exists():
        shutil.rmtree(EXTRACT_DIR)
    EXTRACT_DIR.mkdir(parents=True, exist_ok=True)

    with zipfile.ZipFile(archive_path, "r") as archive:
        archive.extractall(str(EXTRACT_DIR))

    # A zip that wraps everything in one folder is unwrapped to that folder.
    entries = list(EXTRACT_DIR.iterdir())
    if len(entries) == 1 and entries[0].is_dir():
        return str(entries[0])
    return str(EXTRACT_DIR)
# Download/unpack the predictor and load it at import time; the inference
# function below reads this module-level PREDICTOR on every call.
PREDICTOR_DIR = _prepare_predictor_dir()
PREDICTOR = MultiModalPredictor.load(PREDICTOR_DIR)
# If you ever need explicit mapping, keep here (not strictly required if predictor already knows classes)
# NOTE(review): CLASS_LABELS is not referenced anywhere else in the visible code.
CLASS_LABELS = {0: "Western", 1: "Asian"}
# -----------------------------
# Helpers
# -----------------------------
def _human_filesize(n: int) -> str:
for unit in ["B","KB","MB","GB","TB"]:
if n < 1024.0:
return f"{n:.1f} {unit}"
n /= 1024.0
return f"{n:.1f} PB"
def _preprocess(img: Image.Image) -> Image.Image:
    """Scale the shorter side to ``TARGET`` pixels, then center-crop a square.

    Args:
        img: Image to prepare for the model.

    Returns:
        A ``TARGET`` x ``TARGET`` image. If either dimension of ``img`` is
        zero, a black RGB image of that size is returned instead.
    """
    width, height = img.size
    short_side = min(width, height)
    if short_side == 0:
        # Nothing to scale: fall back to an all-black placeholder.
        return Image.new("RGB", (TARGET, TARGET), (0, 0, 0))

    factor = TARGET / short_side
    new_size = (int(round(width * factor)), int(round(height * factor)))
    resized = img.resize(new_size, Image.BILINEAR)

    # Crop window centered on the resized image.
    x0 = (resized.width - TARGET) // 2
    y0 = (resized.height - TARGET) // 2
    return resized.crop((x0, y0, x0 + TARGET, y0 + TARGET))
# -----------------------------
# Inference function (filepath input)
# -----------------------------
def process_and_predict(file_path: str):
    """Validate an uploaded image, preprocess it, and return class probabilities.

    Args:
        file_path: Path of the uploaded or captured image on disk. A falsy
            value (for example after the input is cleared) is treated as
            "nothing uploaded".

    Returns:
        A 4-tuple ``(status, original, processed, probabilities)``:

        * ``status`` - message describing the outcome.
        * ``original`` - the upload converted to RGB, or ``None`` on failure.
        * ``processed`` - the output of ``_preprocess``, or ``None`` on failure.
        * ``probabilities`` - dict with keys ``"Western"`` and ``"Asian"``
          mapped to floats, or ``{}`` on failure.

        Validation failures are reported through the returned status (plus a
        toast) rather than by raising, so all four outputs are always set.
    """
    # basic validation
    if not file_path:
        gr.Warning("Please upload an image.")
        return "Please upload an image.", None, None, {}

    try:
        size_bytes = os.path.getsize(file_path)
    except (OSError, ValueError):
        # ok to continue (webcam/temp paths may not resolve size cleanly)
        size_bytes = None
    if size_bytes is not None and size_bytes > MAX_BYTES:
        # gr.Error is an exception class: constructing it without `raise`
        # displays nothing. gr.Warning shows a toast and still lets this
        # function return its status tuple.
        gr.Warning(f"File too large (max {_human_filesize(MAX_BYTES)}).")
        return "File too large.", None, None, {}

    ext = os.path.splitext(file_path)[1].lower()
    mime = (mimetypes.guess_type(file_path)[0] or "").lower()
    if (ext not in ALLOWED_EXTS) or (mime not in ALLOWED_MIME):
        gr.Warning("Unsupported file type. Please use PNG or JPG/JPEG.")
        return "Unsupported file type.", None, None, {}

    # open image safely
    try:
        # First pass: integrity check only.
        with Image.open(file_path) as probe:
            probe.verify()
        # Second pass: reopen to decode. convert() yields an image that no
        # longer depends on the file, so the handle can be closed right away.
        with Image.open(file_path) as reopened:
            original = reopened.convert("RGB")
    except UnidentifiedImageError:
        return "Invalid image file. Please upload a real image.", None, None, {}
    except Exception:
        # Deliberately broad: any decoding failure means the upload is unusable.
        return "Could not process the file. Is it a valid image?", None, None, {}

    # preprocess (what the model sees)
    processed = _preprocess(original)

    # The predictor is given a file path, so the processed image is written to
    # a scratch directory that is removed as soon as prediction finishes.
    with tempfile.TemporaryDirectory() as tmpdir:
        proc_path = pathlib.Path(tmpdir) / "model_input.png"
        processed.save(proc_path)
        df = pd.DataFrame({"image": [str(proc_path)]})
        proba_df = PREDICTOR.predict_proba(df).rename(
            columns={0: "Western (0)", 1: "Asian (1)"}
        )

    row = proba_df.iloc[0]
    pretty = {
        "Western": float(row.get("Western (0)", 0.0)),
        "Asian": float(row.get("Asian (1)", 0.0)),
    }
    gr.Info("Image accepted.")
    return "Image loaded successfully.", original, processed, pretty
# -----------------------------
# UI
# -----------------------------
# Sample images offered in the UI; each entry is a one-element row holding a path.
EXAMPLES = [
    [path]
    for path in (
        "examples/paella.jpg",
        "examples/mapo_tofu.jpg",
        "examples/kimchi.jpg",
    )
]
# Gradio UI: input image (upload or webcam) on the left; status, before/after
# previews and class probabilities on the right.
# NOTE(review): the nesting of rows/columns below is reconstructed from the
# order of the components - confirm it matches the intended layout.
with gr.Blocks() as demo:
    gr.Markdown("# Asian Food or Western Food?")
    gr.Markdown(
        "This Space demonstrates serving an AutoGluon image classifier in Gradio. "
        "Upload a **PNG/JPG** (≤ 5 MB) or use your **webcam**. "
        "You’ll see the original image and the preprocessed input the model sees."
    )
    # The library link points at MultiModalPredictor, the class this app
    # actually loads (it previously pointed at the TabularPredictor docs).
    gr.Markdown(
        """
**Model:** [ddecosmo/nn_automl_model](https://huggingface.co/ddecosmo/nn_automl_model)
**Library:** [AutoGluon MultiModalPredictor docs](https://auto.gluon.ai/stable/api/autogluon.multimodal.MultiModalPredictor.html)
**Dataset:** [maryzhang/hw1-24679-image-dataset](https://huggingface.co/datasets/maryzhang/hw1-24679-image-dataset)
"""
    )
    with gr.Row():
        with gr.Column():
            # type="filepath" makes Gradio hand the callback a path string.
            image_in = gr.Image(
                type="filepath",
                label="Input image",
                sources=["upload", "webcam"],
                show_label=True,
                height=300,
            )
            gr.Examples(
                examples=[e[0] for e in EXAMPLES],
                inputs=image_in,
                label="Representative examples",
                examples_per_page=8,
                cache_examples=False,
            )
        with gr.Column():
            status = gr.Markdown()
            with gr.Row():
                original_out = gr.Image(label="Original", height=300)
                processed_out = gr.Image(label=f"Preprocessed ({TARGET}x{TARGET})", height=300)
            proba_pretty = gr.Label(num_top_classes=2, label="Class probabilities")
    # Re-run inference whenever the input image changes; the four outputs
    # line up with the 4-tuple returned by process_and_predict.
    image_in.change(
        fn=process_and_predict,
        inputs=[image_in],
        outputs=[status, original_out, processed_out, proba_pretty],
        show_progress="minimal",
    )
# Launch the Blocks app only when this file is run directly, not when imported.
if __name__ == "__main__":
    demo.launch()