#!/usr/bin/env python3
"""Gradio demo that runs RTMO pose estimation through ``MMPoseInferencer``.

The script patches ``mmdet`` at import time, downloads the selected checkpoint
on demand, runs inference on an uploaded image or video, and shows the
annotated result in a Gradio Blocks UI.
"""
# NOTE(review): the import order is kept exactly as in the original file
# (``spaces`` first); presumably it has to precede ``torch`` — confirm before
# regrouping these imports.
import spaces
import os
import sys
import importlib.util
import re
import gradio as gr
from PIL import Image
import torch
import requests  # for downloading remote checkpoints
import shutil

# CUDA info
try:
    print(f"CUDA available: {torch.cuda.is_available()}")
    print(f"CUDA version: {torch.version.cuda}")
    print(f"GPU device: {torch.cuda.get_device_name(0)}")
except Exception:
    # Best-effort diagnostics only; ``Exception`` (not a bare ``except``) so
    # KeyboardInterrupt / SystemExit still propagate.
    print('CUDA is not available !')

# ——— Monkey-patch mmdet to remove its mmcv-version assertion ———
spec = importlib.util.find_spec('mmdet')
if spec and spec.origin:
    with open(spec.origin, encoding='utf-8') as _mmdet_src:
        src = _mmdet_src.read()
    # Drop everything from the ``mmcv_minimum_version`` line up to ``__all__``.
    patched = re.sub(r'(?ms)^[ \t]*mmcv_minimum_version.*?^__all__', '__all__', src)
    m = importlib.util.module_from_spec(spec)
    m.__loader__ = spec.loader
    m.__file__ = spec.origin
    m.__path__ = spec.submodule_search_locations
    # Register the module before executing it so that imports of ``mmdet``
    # triggered by the patched source resolve to this module object.
    sys.modules['mmdet'] = m
    exec(compile(patched, spec.origin, 'exec'), m.__dict__)

from mmpose.apis.inferencers import MMPoseInferencer

# Remote checkpoints
REMOTE_CHECKPOINTS = {
    # COCO-trained
    "rtmo-s_8xb32-600e_coco": "https://download.openmmlab.com/mmpose/v1/projects/rtmo/rtmo-s_8xb32-600e_coco-640x640-8db55a59_20231211.pth",
    "rtmo-m_16xb16-600e_coco": "https://download.openmmlab.com/mmpose/v1/projects/rtmo/rtmo-m_16xb16-600e_coco-640x640-6f4e0306_20231211.pth",
    "rtmo-l_16xb16-600e_coco": "https://download.openmmlab.com/mmpose/v1/projects/rtmo/rtmo-l_16xb16-600e_coco-640x640-516a421f_20231211.pth",
    # BODY7-trained
    "rtmo-t_8xb32-600e_body7": "https://download.openmmlab.com/mmpose/v1/projects/rtmo/rtmo-t_8xb32-600e_body7-416x416-f48f75cb_20231219.pth",
    "rtmo-s_8xb32-600e_body7": "https://download.openmmlab.com/mmpose/v1/projects/rtmo/rtmo-s_8xb32-600e_body7-640x640-dac2bf74_20231211.pth",
    "rtmo-m_16xb16-600e_body7": "https://download.openmmlab.com/mmpose/v1/projects/rtmo/rtmo-m_16xb16-600e_body7-640x640-39e78cc4_20231211.pth",
    "rtmo-l_16xb16-600e_body7": "https://download.openmmlab.com/mmpose/v1/projects/rtmo/rtmo-l_16xb16-600e_body7-640x640-b37118ce_20231211.pth",
    # CrowdPose-trained
    "rtmo-s_8xb32-700e_crowdpose": "https://download.openmmlab.com/mmpose/v1/projects/rtmo/rtmo-s_8xb32-700e_crowdpose-640x640-79f81c0d_20231211.pth",
    # NOTE(review): the "rrtmo-m" file name below is reproduced verbatim from
    # the original; confirm against the upstream download index before "fixing".
    "rtmo-m_16xb16-700e_crowdpose": "https://download.openmmlab.com/mmpose/v1/projects/rtmo/rrtmo-m_16xb16-700e_crowdpose-640x640-0eaf670d_20231211.pth",
    "rtmo-l_16xb16-700e_crowdpose": "https://download.openmmlab.com/mmpose/v1/projects/rtmo/rtmo-l_16xb16-700e_crowdpose-640x640-1008211f_20231211.pth",
    # Retrainable from HF repo
    "rtmo-s_coco_retrainable": "https://huggingface.co/Luigi/Retrainable-RTMO-s/resolve/main/rtmo-s_coco_retrainable.pth",
}

# Variants for inference (prefixes), keyed by the output-channel count of the
# backbone stem convolution found in the checkpoint.
VARIANT_PREFIX = {
    24: "rtmo-t_8xb32-600e_body7-416x416",
    32: "rtmo-s_8xb32-600e_body7-640x640",
    48: "rtmo-m_16xb16-600e_body7-640x640",
    64: "rtmo-l_16xb16-600e_body7-640x640",
}

# File extensions treated as video, for both the input and the rendered output.
_VIDEO_EXTENSIONS = (".mp4", ".mov", ".avi", ".mkv", ".webm")


# ——— Helper: download checkpoint if remote ———
def get_checkpoint(path_or_key: str) -> str:
    """Return a local checkpoint path, downloading it first if needed.

    Args:
        path_or_key: Either a key of ``REMOTE_CHECKPOINTS`` or any other
            string, which is returned unchanged.

    Returns:
        ``/tmp/<key>.pth`` for a known key, otherwise ``path_or_key`` itself.

    Raises:
        requests.RequestException: If the download fails or the server
            answers with an HTTP error status.
    """
    if path_or_key not in REMOTE_CHECKPOINTS:
        return path_or_key

    url = REMOTE_CHECKPOINTS[path_or_key]
    local_path = f"/tmp/{path_or_key}.pth"
    if not os.path.exists(local_path):
        # Download to a side file and move it into place only on success, so
        # an error page or an interrupted transfer is never cached as a
        # checkpoint and picked up by later calls.
        part_path = f"{local_path}.part"
        try:
            with requests.get(url, stream=True, timeout=30) as r:
                r.raise_for_status()
                with open(part_path, 'wb') as f:
                    for chunk in r.iter_content(chunk_size=1024 * 1024):
                        f.write(chunk)
            os.replace(part_path, local_path)
        finally:
            # Only present here if the download did not complete.
            if os.path.exists(part_path):
                os.remove(part_path)
    return local_path


# ——— Detect variant alias from checkpoint ———
def detect_rtmo_variant(checkpoint_path: str) -> str:
    """Pick the RTMO variant name matching a checkpoint's stem width.

    Args:
        checkpoint_path: Path of a checkpoint file readable by ``torch.load``.

    Returns:
        The ``VARIANT_PREFIX`` entry for the stem convolution's output-channel
        count, or ``'rtmo-s_8xb32-600e_body7-640x640'`` for an unknown count.

    Raises:
        KeyError: If the checkpoint has no ``backbone.stem.conv.conv.weight``.
    """
    # SECURITY: ``torch.load`` may unpickle arbitrary objects, and this path
    # can be a user-uploaded file. Left unchanged on purpose — consider
    # ``weights_only=True`` once confirmed to work with these checkpoints.
    ckpt = torch.load(checkpoint_path, map_location='cpu')
    # Accept both a wrapped checkpoint and a bare state dict.
    state_dict = ckpt.get('state_dict', ckpt)
    key = 'backbone.stem.conv.conv.weight'
    if key not in state_dict:
        raise KeyError(f"Cannot find '{key}' in checkpoint.")
    out_ch = state_dict[key].shape[0]
    return VARIANT_PREFIX.get(out_ch, 'rtmo-s_8xb32-600e_body7-640x640')


# ——— Load inferencer ———
def load_inferencer(checkpoint_path=None, device=None):
    """Build an ``MMPoseInferencer`` for the given checkpoint.

    Args:
        checkpoint_path: Checkpoint file to load. When falsy, the inferencer
            is created with ``pose2d='rtmo'`` and no explicit weights.
        device: Passed through as the inferencer's ``device`` argument.

    Returns:
        The constructed ``MMPoseInferencer``.
    """
    kwargs = {'scope': 'mmpose', 'device': device, 'det_cat_ids': [0]}
    if checkpoint_path:
        kwargs['pose2d'] = detect_rtmo_variant(checkpoint_path)
        kwargs['pose2d_weights'] = checkpoint_path
    else:
        kwargs['pose2d'] = 'rtmo'
    return MMPoseInferencer(**kwargs)


def _as_path(file_obj):
    """Return the filesystem path behind a Gradio file value (dict, object or str)."""
    if isinstance(file_obj, dict) and 'name' in file_obj:
        return file_obj['name']
    if hasattr(file_obj, "name"):
        return file_obj.name
    return file_obj


# —─── Prediction function ────
@spaces.GPU()
def predict(image: Image.Image,
            video,  # new video input
            remote_ckpt: str,
            upload_ckpt,
            bbox_thr: float,
            nms_thr: float):
    """Run pose inference on an image or a video and return the annotated result.

    Args:
        image: Uploaded image; used only when ``video`` is empty.
        video: Uploaded video as a path string, a dict with a ``'name'`` key,
            or an object with a ``name`` attribute. Takes precedence over
            ``image``.
        remote_ckpt: Key of ``REMOTE_CHECKPOINTS`` (or a local path), used
            when no checkpoint is uploaded.
        upload_ckpt: Optional uploaded checkpoint, in any of the forms
            accepted for ``video``.
        bbox_thr: Forwarded to the inferencer as ``bbox_thr``.
        nms_thr: Forwarded to the inferencer as ``nms_thr``.

    Returns:
        ``(annotated_image, annotated_video_path, active_checkpoint_name)``;
        exactly one of the first two is set for a successful run, the other
        is ``None``.

    Raises:
        gr.Error: If neither an image nor a video was provided.
    """
    # 1) Write image or pick up video file
    if video:
        inp_path = _as_path(video)
    else:
        if image is None:
            raise gr.Error("Please upload an image or a video first.")
        inp_path = "/tmp/upload.jpg"
        # JPEG cannot store alpha/palette data, so normalise the mode first.
        if image.mode not in ("RGB", "L"):
            image = image.convert("RGB")
        image.save(inp_path)

    # 2) Determine by extension if this is video
    ext = os.path.splitext(inp_path)[1].lower()
    is_video = ext in _VIDEO_EXTENSIONS

    # checkpoint selection
    if upload_ckpt:
        ckpt_path = _as_path(upload_ckpt)
        active = os.path.basename(ckpt_path)
    else:
        ckpt_path = get_checkpoint(remote_ckpt)
        active = remote_ckpt

    # prepare (and clear) output dir
    vis_dir = "/tmp/vis"
    if os.path.exists(vis_dir):
        shutil.rmtree(vis_dir)
    os.makedirs(vis_dir, exist_ok=True)

    # run inferencer (handles both image & video); the call is iterated to
    # completion and only the files written to ``vis_dir`` are used.
    inferencer = load_inferencer(checkpoint_path=ckpt_path, device=None)
    for _ in inferencer(
        inputs=inp_path,
        bbox_thr=bbox_thr,
        nms_thr=nms_thr,
        pose_based_nms=True,
        show=False,
        vis_out_dir=vis_dir,
    ):
        pass

    # collect and return results
    out_files = sorted(os.listdir(vis_dir))
    if is_video:
        # return only the annotated video path
        out_vid = next(
            (f for f in out_files if f.lower().endswith(_VIDEO_EXTENSIONS)),
            None,
        )
        vid_path = os.path.join(vis_dir, out_vid) if out_vid else None
        return None, vid_path, active

    # return only the annotated image
    img_f = next(
        (f for f in out_files if not f.lower().endswith(_VIDEO_EXTENSIONS)),
        None,
    )
    vis_img = None
    if img_f:
        # ``Image.open`` is lazy; copy the pixels into memory so the result
        # stays valid after ``vis_dir`` is wiped by the next call.
        with Image.open(os.path.join(vis_dir, img_f)) as opened:
            vis_img = opened.copy()
    return vis_img, None, active


# —─── Gradio UI ────
def main():
    """Build the Gradio Blocks interface and launch it."""
    checkpoint_keys = list(REMOTE_CHECKPOINTS)

    with gr.Blocks() as demo:
        gr.Markdown("## RTMO Pose Demo")
        with gr.Row():
            with gr.Column(scale=1, min_width=300):
                img_input = gr.Image(type="pil", label="Upload Image")
                video_input = gr.Video(label="Upload Video")
                remote_dd = gr.Dropdown(
                    label="Select Remote Checkpoint",
                    choices=checkpoint_keys,
                    value=checkpoint_keys[0]
                )
                upload_ckpt = gr.File(file_types=['.pth'], label="Or Upload Your Own Checkpoint (optional)")
                bbox_thr = gr.Slider(0.0, 1.0, value=0.1, step=0.01, label="Bounding Box Threshold")
                nms_thr = gr.Slider(0.0, 1.0, value=0.65, step=0.01, label="NMS Threshold")
                run_btn = gr.Button("Run Inference")
            with gr.Column(scale=2):
                output_img = gr.Image(type="pil", label="Annotated Image", elem_id="output_image", interactive=False)
                output_video = gr.Video(label="Annotated Video", interactive=False)
                active_tb = gr.Textbox(label="Active Checkpoint", interactive=False)

        # Examples for quick testing
        gr.Examples(
            examples=[
                ["https://images.pexels.com/photos/1858175/pexels-photo-1858175.jpeg?auto=compress&cs=tinysrgb&h=614&w=614", None, "rtmo-s_coco_retrainable", None, 0.1, 0.65],
                ["https://images.pexels.com/photos/3779706/pexels-photo-3779706.jpeg?auto=compress&cs=tinysrgb&h=614&w=614", None, "rtmo-t_8xb32-600e_body7", None, 0.1, 0.65],
                ["https://images.pexels.com/photos/220453/pexels-photo-220453.jpeg?auto=compress&cs=tinysrgb&h=614&w=614", None, "rtmo-s_8xb32-600e_coco", None, 0.1, 0.65],
                # 4th example: a video clip, Fred Ott's Sneeze (1894), from archive.org
                [None, "https://archive.org/download/fred-otts-sneeze/Fred%20Ott%20Sneeze%201894%20GG%20Restore.mp4", "rtmo-s_coco_retrainable", None, 0.1, 0.65],
            ],
            inputs=[img_input, video_input, remote_dd, upload_ckpt, bbox_thr, nms_thr],
            outputs=[output_img, output_video, active_tb],
            fn=predict,
            cache_examples=False,
            label="Examples",
            examples_per_page=4
        )

        run_btn.click(
            predict,
            inputs=[img_input, video_input, remote_dd, upload_ckpt, bbox_thr, nms_thr],
            outputs=[output_img, output_video, active_tb]
        )

    demo.launch()


if __name__ == "__main__":
    main()