Spaces:
Sleeping
Sleeping
| # app.py | |
| import os | |
| import shutil | |
| import zipfile | |
| import pathlib | |
| import tempfile | |
| import mimetypes | |
| import gradio as gr | |
| import pandas as pd | |
| from PIL import Image, UnidentifiedImageError | |
| import huggingface_hub | |
| from autogluon.multimodal import MultiModalPredictor | |
# -----------------------------
# Config / constants
# -----------------------------
MODEL_REPO_ID = "ddecosmo/nn_automl_model"  # Hugging Face repo id (model)
ZIP_FILENAME = "autogluon_image_predictor_dir.zip"  # the zipped predictor file in the repo
HF_TOKEN = os.getenv("HF_TOKEN", None)  # pulled from HF Space secret (optional)

# Local cache/extract dirs inside the Space
CACHE_DIR = pathlib.Path("hf_assets")          # where the zip is downloaded
EXTRACT_DIR = CACHE_DIR / "predictor_native"   # where the zip is unpacked

# App limits / UI config
ALLOWED_MIME = {"image/png", "image/jpeg"}  # MIME types accepted (guessed from path)
ALLOWED_EXTS = {".png", ".jpg", ".jpeg"}    # file extensions accepted
MAX_BYTES = 5 * 1024 * 1024                 # 5 MB upload limit
TARGET = 224  # model input size (images become TARGET x TARGET)
| # ----------------------------- | |
| # Model setup | |
| # ----------------------------- | |
def _prepare_predictor_dir() -> str:
    """Download the zipped AutoGluon predictor from the Hub and unpack it.

    Returns:
        Path (as ``str``) to the predictor root that
        ``MultiModalPredictor.load`` expects: if the zip contains a single
        top-level directory that directory is returned, otherwise the
        extraction directory itself.
    """
    CACHE_DIR.mkdir(parents=True, exist_ok=True)
    local_zip = huggingface_hub.hf_hub_download(
        repo_id=MODEL_REPO_ID,
        filename=ZIP_FILENAME,
        repo_type="model",
        token=HF_TOKEN,  # needed only if the model repo is private
        local_dir=str(CACHE_DIR),
        # NOTE: `local_dir_use_symlinks` is deprecated and ignored by recent
        # huggingface_hub releases (local_dir downloads are real files now),
        # so the argument has been dropped.
    )
    # Always re-extract from scratch so stale files from a previous run
    # cannot leak into the predictor directory.
    if EXTRACT_DIR.exists():
        shutil.rmtree(EXTRACT_DIR)
    EXTRACT_DIR.mkdir(parents=True, exist_ok=True)
    with zipfile.ZipFile(local_zip, "r") as zf:
        zf.extractall(str(EXTRACT_DIR))
    # Zips are often built with a single wrapping folder; unwrap it if present.
    contents = list(EXTRACT_DIR.iterdir())
    predictor_root = contents[0] if (len(contents) == 1 and contents[0].is_dir()) else EXTRACT_DIR
    return str(predictor_root)
# Download + unpack the predictor once at import time, then load it so a
# single predictor instance serves every request.
PREDICTOR_DIR = _prepare_predictor_dir()
PREDICTOR = MultiModalPredictor.load(PREDICTOR_DIR)

# If you ever need explicit mapping, keep here (not strictly required if predictor already knows classes)
CLASS_LABELS = {0: "Western", 1: "Asian"}
| # ----------------------------- | |
| # Helpers | |
| # ----------------------------- | |
| def _human_filesize(n: int) -> str: | |
| for unit in ["B","KB","MB","GB","TB"]: | |
| if n < 1024.0: | |
| return f"{n:.1f} {unit}" | |
| n /= 1024.0 | |
| return f"{n:.1f} PB" | |
def _preprocess(img: Image.Image) -> Image.Image:
    """Resize the shorter side to TARGET, then center-crop TARGET x TARGET."""
    width, height = img.size
    shorter = min(width, height)
    if shorter == 0:
        # Degenerate (zero-area) image: hand the model a black square
        # rather than dividing by zero below.
        return Image.new("RGB", (TARGET, TARGET), (0, 0, 0))
    ratio = TARGET / shorter
    resized = img.resize(
        (int(round(width * ratio)), int(round(height * ratio))),
        Image.BILINEAR,
    )
    x0 = (resized.width - TARGET) // 2
    y0 = (resized.height - TARGET) // 2
    return resized.crop((x0, y0, x0 + TARGET, y0 + TARGET))
| # ----------------------------- | |
| # Inference function (filepath input) | |
| # ----------------------------- | |
def process_and_predict(file_path: str):
    """Validate an uploaded image, preprocess it, and run the classifier.

    Args:
        file_path: Filesystem path from the ``gr.Image(type="filepath")``
            component; falsy when the user clears the input.

    Returns:
        Tuple of ``(status_markdown, original_image, processed_image,
        probability_dict)`` matching the four Gradio outputs. On validation
        failure the image/probability outputs are ``None``/``{}`` so the UI
        clears gracefully instead of raising.
    """
    # basic validation
    if not file_path:
        gr.Warning("Please upload an image.")
        return "Please upload an image.", None, None, {}
    try:
        size_bytes = os.path.getsize(file_path)
        if size_bytes > MAX_BYTES:
            # BUG FIX: gr.Error is an *exception class*; constructing it
            # without `raise` displayed nothing. gr.Warning pops the modal
            # while still letting us return the graceful fallback outputs.
            gr.Warning(f"File too large (max {_human_filesize(MAX_BYTES)}).")
            return "File too large.", None, None, {}
    except OSError:
        # ok to continue (webcam/temp paths may not resolve size cleanly)
        pass
    ext = os.path.splitext(file_path)[1].lower()
    mime = (mimetypes.guess_type(file_path)[0] or "").lower()
    if (ext not in ALLOWED_EXTS) or (mime not in ALLOWED_MIME):
        # Same gr.Error-never-raised fix as above.
        gr.Warning("Unsupported file type. Please use PNG or JPG/JPEG.")
        return "Unsupported file type.", None, None, {}
    # open image safely: verify() catches truncated/corrupt files, but leaves
    # the parser unusable, so the file must be reopened afterwards.
    try:
        with Image.open(file_path) as im:
            im.verify()
        original = Image.open(file_path).convert("RGB")
    except UnidentifiedImageError:
        return "Invalid image file. Please upload a real image.", None, None, {}
    except Exception:
        return "Could not process the file. Is it a valid image?", None, None, {}
    # preprocess (what the model sees)
    processed = _preprocess(original)
    # predict using processed image. BUG FIX: the old tempfile.mkdtemp() was
    # never removed, leaking one directory per request; TemporaryDirectory
    # cleans up as soon as the prediction has been read.
    with tempfile.TemporaryDirectory() as tmpdir:
        proc_path = pathlib.Path(tmpdir) / "model_input.png"
        processed.save(proc_path)
        df = pd.DataFrame({"image": [str(proc_path)]})
        proba_df = PREDICTOR.predict_proba(df).rename(columns={0: "Western (0)", 1: "Asian (1)"})
    row = proba_df.iloc[0]
    pretty = {
        "Western": float(row.get("Western (0)", 0.0)),
        "Asian": float(row.get("Asian (1)", 0.0)),
    }
    gr.Info("Image accepted.")
    return "Image loaded successfully.", original, processed, pretty
# -----------------------------
# UI
# -----------------------------
# Example image paths bundled with the Space; each is wrapped in a one-item
# list to match Gradio's rows-of-inputs example format.
EXAMPLES = [
    ["examples/paella.jpg"],
    ["examples/mapo_tofu.jpg"],
    ["examples/kimchi.jpg"],
]
with gr.Blocks() as demo:
    gr.Markdown("# Asian Food or Western Food?")
    gr.Markdown(
        "This Space demonstrates serving an AutoGluon image classifier in Gradio. "
        "Upload a **PNG/JPG** (≤ 5 MB) or use your **webcam**. "
        "You’ll see the original image and the preprocessed input the model sees."
    )
    # BUG FIX: the "Library" link previously pointed at the TabularPredictor
    # docs, but this app loads a MultiModalPredictor (see the import above).
    gr.Markdown(
        """
**Model:** [ddecosmo/nn_automl_model](https://huggingface.co/ddecosmo/nn_automl_model)
**Library:** [AutoGluon MultiModalPredictor docs](https://auto.gluon.ai/stable/api/autogluon.multimodal.MultiModalPredictor.html)
**Dataset:** [maryzhang/hw1-24679-image-dataset](https://huggingface.co/datasets/maryzhang/hw1-24679-image-dataset)
"""
    )
    with gr.Row():
        with gr.Column():
            # type="filepath" hands process_and_predict a path on disk,
            # which is what the size/extension/MIME validation expects.
            image_in = gr.Image(
                type="filepath",
                label="Input image",
                sources=["upload", "webcam"],
                show_label=True,
                height=300,
            )
            gr.Examples(
                examples=[e[0] for e in EXAMPLES],
                inputs=image_in,
                label="Representative examples",
                examples_per_page=8,
                cache_examples=False,
            )
        with gr.Column():
            status = gr.Markdown()  # human-readable validation/result message
            with gr.Row():
                original_out = gr.Image(label="Original", height=300)
                processed_out = gr.Image(label=f"Preprocessed ({TARGET}x{TARGET})", height=300)
            proba_pretty = gr.Label(num_top_classes=2, label="Class probabilities")

    # Re-run prediction whenever the input changes (upload, webcam, example click).
    image_in.change(
        fn=process_and_predict,
        inputs=[image_in],
        outputs=[status, original_out, processed_out, proba_pretty],
        show_progress="minimal",
    )

if __name__ == "__main__":
    demo.launch()