SkalskiP commited on
Commit
75be6c3
·
1 Parent(s): 1015457

test video processing on HF spaces

Browse files
Files changed (2) hide show
  1. app.py +61 -21
  2. requirements.txt +2 -1
app.py CHANGED
@@ -1,5 +1,7 @@
1
- from typing import Union
 
2
 
 
3
  import gradio as gr
4
  import numpy as np
5
  import supervision as sv
@@ -9,22 +11,16 @@ from rfdetr.detr import RFDETR
9
  from rfdetr.util.coco_classes import COCO_CLASSES
10
 
11
  from utils.image import calculate_resolution_wh
12
- from utils.video import create_directory
 
 
13
 
14
  MARKDOWN = """
15
  # RF-DETR 🔥
16
 
17
- <div>
18
- <a href="https://colab.research.google.com/github/roboflow-ai/notebooks/blob/main/notebooks/how-to-finetune-rf-detr-on-detection-dataset.ipynb">
19
- <img src="https://colab.research.google.com/assets/colab-badge.svg" alt="colab" style="display:inline-block;">
20
- </a>
21
- <a href="https://blog.roboflow.com/rf-detr">
22
- <img src="https://raw.githubusercontent.com/roboflow-ai/notebooks/main/assets/badges/roboflow-blogpost.svg" alt="roboflow" style="display:inline-block;">
23
- </a>
24
- <a href="https://github.com/roboflow/rf-detr">
25
- <img src="https://badges.aleen42.com/src/github.svg" alt="roboflow" style="display:inline-block;">
26
- </a>
27
- </div>
28
 
29
  RF-DETR is a real-time, transformer-based object detection model architecture developed
30
  by [Roboflow](https://roboflow.com/) and released under the Apache 2.0 license.
@@ -41,12 +37,18 @@ COLOR = sv.ColorPalette.from_hex([
41
  "#9999ff", "#3399ff", "#66ffff", "#33ff99", "#66ff66", "#99ff00"
42
  ])
43
 
 
44
  VIDEO_SCALE_FACTOR = 0.5
45
  VIDEO_TARGET_DIRECTORY = "tmp"
 
46
  create_directory(directory_path=VIDEO_TARGET_DIRECTORY)
47
 
48
 
49
- def detect_and_annotate(model: RFDETR, image: Union[Image.Image, np.ndarray], confidence: float):
 
 
 
 
50
  detections = model.predict(image, threshold=confidence)
51
 
52
  resolution_wh = calculate_resolution_wh(image)
@@ -73,16 +75,54 @@ def detect_and_annotate(model: RFDETR, image: Union[Image.Image, np.ndarray], co
73
  return annotated_image
74
 
75
 
76
- def image_processing_inference(input_image: Image.Image, confidence: float, resolution: int, checkpoint: str):
77
- model_class = RFDETRBase if checkpoint == "base" else RFDETRLarge
78
- model = model_class(resolution=resolution)
 
 
 
 
 
 
 
 
 
 
 
 
79
  return detect_and_annotate(model=model, image=input_image, confidence=confidence)
80
 
81
 
82
- def video_processing_inference(input_video: str, confidence: float, resolution: int, checkpoint: str):
83
- model_class = RFDETRBase if checkpoint == "base" else RFDETRLarge
84
- model = model_class(resolution=resolution)
85
- return input_video
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
86
 
87
  with gr.Blocks() as demo:
88
  gr.Markdown(MARKDOWN)
 
1
+ import os
2
+ from typing import TypeVar
3
 
4
+ from tqdm import tqdm
5
  import gradio as gr
6
  import numpy as np
7
  import supervision as sv
 
11
  from rfdetr.util.coco_classes import COCO_CLASSES
12
 
13
  from utils.image import calculate_resolution_wh
14
+ from utils.video import create_directory, generate_unique_name
15
+
16
+ ImageType = TypeVar("ImageType", Image.Image, np.ndarray)
17
 
18
  MARKDOWN = """
19
  # RF-DETR 🔥
20
 
21
+ [`[code]`](https://github.com/roboflow/rf-detr)
22
+ [`[blog]`](https://blog.roboflow.com/rf-detr)
23
+ [`[notebook]`](https://colab.research.google.com/github/roboflow-ai/notebooks/blob/main/notebooks/how-to-finetune-rf-detr-on-detection-dataset.ipynb)
 
 
 
 
 
 
 
 
24
 
25
  RF-DETR is a real-time, transformer-based object detection model architecture developed
26
  by [Roboflow](https://roboflow.com/) and released under the Apache 2.0 license.
 
37
  "#9999ff", "#3399ff", "#66ffff", "#33ff99", "#66ff66", "#99ff00"
38
  ])
39
 
40
+ MAX_VIDEO_LENGTH_SECONDS = 2
41
  VIDEO_SCALE_FACTOR = 0.5
42
  VIDEO_TARGET_DIRECTORY = "tmp"
43
+
44
  create_directory(directory_path=VIDEO_TARGET_DIRECTORY)
45
 
46
 
47
+ def detect_and_annotate(
48
+ model: RFDETR,
49
+ image: ImageType,
50
+ confidence: float
51
+ ) -> ImageType:
52
  detections = model.predict(image, threshold=confidence)
53
 
54
  resolution_wh = calculate_resolution_wh(image)
 
75
  return annotated_image
76
 
77
 
78
def load_model(resolution: int, checkpoint: str) -> RFDETR:
    """Instantiate an RF-DETR model for the requested checkpoint size.

    Args:
        resolution: Input resolution the model is built for.
        checkpoint: Either "base" or "large", selecting the model variant.

    Returns:
        A constructed RFDETRBase or RFDETRLarge instance.

    Raises:
        ValueError: If `checkpoint` is not one of "base" or "large".
    """
    if checkpoint == "base":
        return RFDETRBase(resolution=resolution)
    elif checkpoint == "large":
        return RFDETRLarge(resolution=resolution)
    # An unsupported *value* (not type) was passed, so ValueError is the
    # appropriate exception; the old TypeError misclassified the failure.
    raise ValueError(f"Checkpoint must be 'base' or 'large', got {checkpoint!r}.")
84
+
85
+
86
def image_processing_inference(
    input_image: Image.Image,
    confidence: float,
    resolution: int,
    checkpoint: str
):
    """Run RF-DETR detection on a single image and return the annotated copy.

    Args:
        input_image: Image to run inference on.
        confidence: Detection confidence threshold.
        resolution: Model input resolution.
        checkpoint: Model variant, "base" or "large".
    """
    return detect_and_annotate(
        model=load_model(resolution=resolution, checkpoint=checkpoint),
        image=input_image,
        confidence=confidence,
    )
94
 
95
 
96
def video_processing_inference(
    input_video: str,
    confidence: float,
    resolution: int,
    checkpoint: str,
    progress=gr.Progress(track_tqdm=True)
):
    """Annotate a short clip of a video with RF-DETR detections.

    Processes at most MAX_VIDEO_LENGTH_SECONDS of footage, downscales each
    frame by VIDEO_SCALE_FACTOR, writes the annotated result into
    VIDEO_TARGET_DIRECTORY, and returns the output file path.

    Args:
        input_video: Path of the source video file.
        confidence: Detection confidence threshold.
        resolution: Model input resolution.
        checkpoint: Model variant, "base" or "large".
        progress: Gradio progress tracker wired to tqdm.
    """
    model = load_model(resolution=resolution, checkpoint=checkpoint)

    target_path = os.path.join(
        VIDEO_TARGET_DIRECTORY, f"{generate_unique_name()}.mp4"
    )

    # Shrink the output resolution so processing stays fast on Spaces.
    info = sv.VideoInfo.from_video_path(input_video)
    info.width = int(info.width * VIDEO_SCALE_FACTOR)
    info.height = int(info.height * VIDEO_SCALE_FACTOR)

    # Cap work at MAX_VIDEO_LENGTH_SECONDS worth of frames.
    frame_limit = min(info.total_frames, info.fps * MAX_VIDEO_LENGTH_SECONDS)
    frames = sv.get_video_frames_generator(input_video, end=frame_limit)

    with sv.VideoSink(target_path, video_info=info) as sink:
        for raw_frame in tqdm(frames, total=frame_limit):
            scaled = sv.scale_image(raw_frame, VIDEO_SCALE_FACTOR)
            sink.write_frame(
                detect_and_annotate(
                    model=model, image=scaled, confidence=confidence
                )
            )

    return target_path
126
 
127
  with gr.Blocks() as demo:
128
  gr.Markdown(MARKDOWN)
requirements.txt CHANGED
@@ -1,3 +1,4 @@
1
  gradio
2
  spaces
3
- rfdetr
 
 
1
  gradio
2
  spaces
3
+ rfdetr
4
+ tqdm