x10z commited on
Commit
0200003
·
verified ·
1 Parent(s): 2ab4d40

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +11 -31
app.py CHANGED
@@ -46,7 +46,7 @@ description = (
46
  def predict(image: Image.Image, processing_res_choice: int):
47
  """
48
  Single-frame prediction wrapped for GPU execution.
49
- Returns a DepthNormalPipelineOutput with attributes depth_colored and normal_colored.
50
  """
51
  with torch.no_grad():
52
  return pipe(
@@ -61,7 +61,7 @@ def predict(image: Image.Image, processing_res_choice: int):
61
 
62
  def on_submit_video(video_path: str, processing_res_choice: int):
63
  """
64
- Processes each frame of the input video, generating separate depth and normal videos.
65
  """
66
  if video_path is None:
67
  print("No video uploaded.")
@@ -73,11 +73,9 @@ def on_submit_video(video_path: str, processing_res_choice: int):
73
  height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
74
  frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
75
 
76
- # Create temporary output files
77
- tmp_depth = tempfile.NamedTemporaryFile(suffix='.mp4', delete=False)
78
  tmp_normal = tempfile.NamedTemporaryFile(suffix='.mp4', delete=False)
79
  fourcc = cv2.VideoWriter_fourcc(*'mp4v')
80
- out_depth = cv2.VideoWriter(tmp_depth.name, fourcc, fps, (width, height))
81
  out_normal = cv2.VideoWriter(tmp_normal.name, fourcc, fps, (width, height))
82
 
83
  # Process each frame
@@ -90,16 +88,10 @@ def on_submit_video(video_path: str, processing_res_choice: int):
90
  rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
91
  pil_image = Image.fromarray(rgb)
92
 
93
- # Predict depth and normals
94
  result = predict(pil_image, processing_res_choice)
95
- depth_colored = result.depth_colored
96
  normal_colored = result.normal_colored
97
 
98
- # Write depth frame
99
- depth_frame = np.array(depth_colored)
100
- depth_bgr = cv2.cvtColor(depth_frame, cv2.COLOR_RGB2BGR)
101
- out_depth.write(depth_bgr)
102
-
103
  # Write normal frame
104
  normal_frame = np.array(normal_colored)
105
  normal_bgr = cv2.cvtColor(normal_frame, cv2.COLOR_RGB2BGR)
@@ -107,24 +99,19 @@ def on_submit_video(video_path: str, processing_res_choice: int):
107
 
108
  # Release resources
109
  cap.release()
110
- out_depth.release()
111
  out_normal.release()
112
 
113
- # Return video paths for download
114
- return tmp_depth.name, tmp_normal.name
115
-
116
 
117
  # Build Gradio interface
118
  with gr.Blocks() as demo:
119
  gr.Markdown(title)
120
  gr.Markdown(description)
121
- gr.Markdown("### Depth and Normals Prediction on Video")
122
 
123
  with gr.Row():
124
- input_video = gr.Video(
125
- label="Input Video",
126
- elem_id='video-display-input'
127
- )
128
  with gr.Column():
129
  processing_res_choice = gr.Radio(
130
  [
@@ -134,22 +121,15 @@ with gr.Blocks() as demo:
134
  label="Processing resolution",
135
  value=768,
136
  )
137
- submit = gr.Button(value="Compute Depth and Normals")
138
 
139
  with gr.Row():
140
- output_depth_video = gr.Video(
141
- label="Depth Video",
142
- elem_id='download'
143
- )
144
- output_normal_video = gr.Video(
145
- label="Normal Video",
146
- elem_id='download'
147
- )
148
 
149
  submit.click(
150
  fn=on_submit_video,
151
  inputs=[input_video, processing_res_choice],
152
- outputs=[output_depth_video, output_normal_video]
153
  )
154
 
155
  if __name__ == "__main__":
 
46
  def predict(image: Image.Image, processing_res_choice: int):
47
  """
48
  Single-frame prediction wrapped for GPU execution.
49
+ Returns a DepthNormalPipelineOutput with attribute normal_colored.
50
  """
51
  with torch.no_grad():
52
  return pipe(
 
61
 
62
  def on_submit_video(video_path: str, processing_res_choice: int):
63
  """
64
+ Processes each frame of the input video, generating a normal map video.
65
  """
66
  if video_path is None:
67
  print("No video uploaded.")
 
73
  height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
74
  frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
75
 
76
+ # Temporary output file for normals video
 
77
  tmp_normal = tempfile.NamedTemporaryFile(suffix='.mp4', delete=False)
78
  fourcc = cv2.VideoWriter_fourcc(*'mp4v')
 
79
  out_normal = cv2.VideoWriter(tmp_normal.name, fourcc, fps, (width, height))
80
 
81
  # Process each frame
 
88
  rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
89
  pil_image = Image.fromarray(rgb)
90
 
91
+ # Predict normals
92
  result = predict(pil_image, processing_res_choice)
 
93
  normal_colored = result.normal_colored
94
 
 
 
 
 
 
95
  # Write normal frame
96
  normal_frame = np.array(normal_colored)
97
  normal_bgr = cv2.cvtColor(normal_frame, cv2.COLOR_RGB2BGR)
 
99
 
100
  # Release resources
101
  cap.release()
 
102
  out_normal.release()
103
 
104
+ # Return video path for download
105
+ return tmp_normal.name
 
106
 
107
  # Build Gradio interface
108
  with gr.Blocks() as demo:
109
  gr.Markdown(title)
110
  gr.Markdown(description)
111
+ gr.Markdown("### Normals Prediction on Video")
112
 
113
  with gr.Row():
114
+ input_video = gr.Video(label="Input Video", elem_id='video-display-input')
 
 
 
115
  with gr.Column():
116
  processing_res_choice = gr.Radio(
117
  [
 
121
  label="Processing resolution",
122
  value=768,
123
  )
124
+ submit = gr.Button(value="Compute Normals")
125
 
126
  with gr.Row():
127
+ output_normal_video = gr.Video(label="Normal Video", elem_id='download')
 
 
 
 
 
 
 
128
 
129
  submit.click(
130
  fn=on_submit_video,
131
  inputs=[input_video, processing_res_choice],
132
+ outputs=[output_normal_video]
133
  )
134
 
135
  if __name__ == "__main__":