Luigi committed
Commit db758db · 1 Parent(s): 986677d

Add video input

Files changed (2):
  1. README.md +0 -10
  2. app.py +68 -31
README.md CHANGED
@@ -61,13 +61,3 @@ The following variants are available out of the box:
 - **app.py**: Main Gradio application script.
 - **requirements.txt**: Python dependencies, including MMCV and MMPose.
 - **README.md**: This documentation file.
-
-## Development
-
-To update dependencies, edit `requirements.txt`. To extend functionality or add new variants, modify `app.py` accordingly.
-
-## Future Plans
-
-1. Support video input streams.
-2. Enable ONNX model inference via `rtmlib`.
-

app.py CHANGED
@@ -1,10 +1,14 @@
 #!/usr/bin/env python3
 import spaces
-import os, sys, importlib.util, re
+import os
+import sys
+import importlib.util
+import re
 import gradio as gr
 from PIL import Image
 import torch
 import requests # for downloading remote checkpoints
+import shutil

 # CUDA info
 try:
@@ -92,22 +96,45 @@ def load_inferencer(checkpoint_path=None, device=None):
 # —─── Prediction function ────
 @spaces.GPU()
 def predict(image: Image.Image,
+            video, # new video input
             remote_ckpt: str,
             upload_ckpt,
             bbox_thr: float,
             nms_thr: float):
-    inp_path = "/tmp/upload.jpg"
-    image.save(inp_path)
+    # 1) Write image or pick up video file
+    if video:
+        # Gradio Video can come in as a filepath string or dict
+        if isinstance(video, dict) and 'name' in video:
+            inp_path = video['name']
+        elif hasattr(video, "name"):
+            inp_path = video.name
+        else:
+            inp_path = video
+    else:
+        inp_path = "/tmp/upload.jpg"
+        image.save(inp_path)
+
+    # 2) Determine by extension if this is video
+    ext = os.path.splitext(inp_path)[1].lower()
+    is_video = ext in (".mp4", ".mov", ".avi", ".mkv", ".webm")
+
+    # checkpoint selection
     if upload_ckpt:
         ckpt_path = upload_ckpt.name
         active = os.path.basename(ckpt_path)
     else:
         ckpt_path = get_checkpoint(remote_ckpt)
         active = remote_ckpt
+
+    # prepare (and clear) output dir
     vis_dir = "/tmp/vis"
+    if os.path.exists(vis_dir):
+        shutil.rmtree(vis_dir)
     os.makedirs(vis_dir, exist_ok=True)
+
+    # run inferencer (handles both image & video)
     inferencer = load_inferencer(checkpoint_path=ckpt_path, device=None)
-    for result in inferencer(
+    for _ in inferencer(
         inputs=inp_path,
         bbox_thr=bbox_thr,
         nms_thr=nms_thr,
@@ -116,9 +143,18 @@ def predict(image: Image.Image,
         vis_out_dir=vis_dir,
     ):
         pass
+
+    # collect and return results
     out_files = sorted(os.listdir(vis_dir))
-    vis_img = Image.open(os.path.join(vis_dir, out_files[0])) if out_files else None
-    return vis_img, active
+    if is_video:
+        # return only the annotated video path
+        out_vid = next((f for f in out_files if f.lower().endswith((".mp4", ".mov", ".avi"))), None)
+        return None, os.path.join(vis_dir, out_vid) if out_vid else None, active
+    else:
+        # return only the annotated image
+        img_f = out_files[0] if out_files else None
+        vis_img = Image.open(os.path.join(vis_dir, img_f)) if img_f and not img_f.lower().endswith((".mp4", ".mov", ".avi")) else None
+        return vis_img, None, active

 # —─── Gradio UI ────
 def main():
@@ -126,43 +162,44 @@ def main():
         gr.Markdown("## RTMO Pose Demo")
         with gr.Row():
             with gr.Column(scale=1, min_width=300):
-                img_input = gr.Image(type="pil", label="Upload Image")
-                remote_dd = gr.Dropdown(label="Select Remote Checkpoint",
-                                        choices=list(REMOTE_CHECKPOINTS.keys()),
-                                        value=list(REMOTE_CHECKPOINTS.keys())[0])
+                img_input = gr.Image(type="pil", label="Upload Image")
+                video_input = gr.Video(label="Upload Video")
+                remote_dd = gr.Dropdown(
+                    label="Select Remote Checkpoint",
+                    choices=list(REMOTE_CHECKPOINTS.keys()),
+                    value=list(REMOTE_CHECKPOINTS.keys())[0]
+                )
                 upload_ckpt = gr.File(file_types=['.pth'], label="Or Upload Your Own Checkpoint (optional)")
-                bbox_thr = gr.Slider(minimum=0.0, maximum=1.0, step=0.01,
-                                     value=0.1, label="Bounding Box Threshold")
-                nms_thr = gr.Slider(minimum=0.0, maximum=1.0, step=0.01,
-                                    value=0.65, label="NMS Threshold")
-                run_btn = gr.Button("Run Inference")
+                bbox_thr = gr.Slider(0.0, 1.0, value=0.1, step=0.01, label="Bounding Box Threshold")
+                nms_thr = gr.Slider(0.0, 1.0, value=0.65, step=0.01, label="NMS Threshold")
+                run_btn = gr.Button("Run Inference")
             with gr.Column(scale=2):
-                output_img = gr.Image(type="pil", label="Annotated Image",
-                                      elem_id="output_image", interactive=False)
-                active_tb = gr.Textbox(label="Active Checkpoint", interactive=False)
-
+                output_img = gr.Image(type="pil", label="Annotated Image", elem_id="output_image", interactive=False)
+                output_video = gr.Video(label="Annotated Video", interactive=False)
+                active_tb = gr.Textbox(label="Active Checkpoint", interactive=False)
+
         # Examples for quick testing
         gr.Examples(
             examples=[
-                ["https://images.pexels.com/photos/1858175/pexels-photo-1858175.jpeg?auto=compress&cs=tinysrgb&dpr=1&w=614",
-                 "rtmo-s_coco_retrainable", None, 0.1, 0.65],
-                ["https://images.pexels.com/photos/3779706/pexels-photo-3779706.jpeg?auto=compress&cs=tinysrgb&dpr=1&w=614",
-                 "rtmo-t_8xb32-600e_body7", None, 0.1, 0.65],
-                ["https://images.pexels.com/photos/220453/pexels-photo-220453.jpeg?auto=compress&cs=tinysrgb&dpr=1&w=614",
-                 "rtmo-s_8xb32-600e_coco", None, 0.1, 0.65],
+                ["https://images.pexels.com/photos/1858175/pexels-photo-1858175.jpeg?auto=compress&cs=tinysrgb&h=614&w=614", None, "rtmo-s_coco_retrainable", None, 0.1, 0.65],
+                ["https://images.pexels.com/photos/3779706/pexels-photo-3779706.jpeg?auto=compress&cs=tinysrgb&h=614&w=614", None, "rtmo-t_8xb32-600e_body7", None, 0.1, 0.65],
+                ["https://images.pexels.com/photos/220453/pexels-photo-220453.jpeg?auto=compress&cs=tinysrgb&h=614&w=614", None, "rtmo-s_8xb32-600e_coco", None, 0.1, 0.65],
             ],
-            inputs=[img_input, remote_dd, upload_ckpt, bbox_thr, nms_thr],
-            outputs=[output_img, active_tb],
+            inputs=[img_input, video_input, remote_dd, upload_ckpt, bbox_thr, nms_thr],
+            outputs=[output_img, output_video, active_tb],
             fn=predict,
             cache_examples=False,
             label="Examples",
            examples_per_page=3
         )

-        run_btn.click(predict,
-                      inputs=[img_input, remote_dd, upload_ckpt, bbox_thr, nms_thr],
-                      outputs=[output_img, active_tb])
+        run_btn.click(
+            predict,
+            inputs=[img_input, video_input, remote_dd, upload_ckpt, bbox_thr, nms_thr],
+            outputs=[output_img, output_video, active_tb]
+        )
+
     demo.launch()

 if __name__ == "__main__":
-    main()
+    main()
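
With this change, `predict` now takes a video input and returns a `(image, video, active_checkpoint)` triple. The sketch below is a minimal, hypothetical way to exercise the new video branch directly, without launching the Gradio UI; it assumes `app.py` is importable with its dependencies installed in an environment comparable to the Space, and the clip name `sample.mp4` and the chosen checkpoint key are illustrative placeholders, not files shipped with this commit.

```python
# Hypothetical smoke test for the new video path in predict().
# Assumes: app.py importable, dependencies installed, a local clip "sample.mp4".
from app import predict

vis_img, vis_video_path, active_ckpt = predict(
    image=None,                            # no still image, so the video branch is taken
    video="sample.mp4",                    # plain filepath, as Gradio may also pass it
    remote_ckpt="rtmo-s_8xb32-600e_coco",  # one of the REMOTE_CHECKPOINTS keys used in the examples
    upload_ckpt=None,                      # no uploaded .pth, use the remote checkpoint
    bbox_thr=0.1,
    nms_thr=0.65,
)

print("Annotated video:", vis_video_path)  # path under /tmp/vis, or None if nothing was written
print("Active checkpoint:", active_ckpt)
```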