yuyutsu07 committed on
Commit 6c8993e · verified · 1 Parent(s): aef98ce

Update app.py

Files changed (1)
  1. app.py +96 -317
app.py CHANGED
@@ -2,359 +2,138 @@ import torch
  import gradio as gr
  import imageio
  import numpy as np
- import cv2
  from PIL import Image
- from torchvision.transforms import ToTensor, Resize, Compose, ToPILImage
  import spaces
  import tempfile
- import os
- import gc
- import warnings
- import traceback
- from huggingface_hub import hf_hub_download
- from transformers import pipeline
- from diffusers import DPTForDepthEstimation, DPTImageProcessor
- from accelerate import Accelerator
-
- # Suppress warnings
- warnings.filterwarnings("ignore")
-
- # Global model cache
- DEPTH_MODEL = None
- DEPTH_PROCESSOR = None
-
- class DepthModelManager:
-     @staticmethod
-     def get_depth_model():
-         """Lazy-loads the depth estimation model on first use"""
-         global DEPTH_MODEL, DEPTH_PROCESSOR
-
-         if DEPTH_MODEL is None:
-             try:
-                 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-                 model_id = "Intel/dpt-large"
-
-                 print(f"Loading depth model on {device}...")
-                 DEPTH_MODEL = DPTForDepthEstimation.from_pretrained(model_id).to(device)
-                 DEPTH_PROCESSOR = DPTImageProcessor.from_pretrained(model_id)
-                 print("Depth model loaded successfully")
-             except Exception as e:
-                 print(f"Error loading depth model: {e}")
-                 raise
-
-         return DEPTH_MODEL, DEPTH_PROCESSOR
-
-     @staticmethod
-     def generate_depth_map(image):
-         """Generate a depth map from an input image"""
-         model, processor = DepthModelManager.get_depth_model()
-         device = next(model.parameters()).device
-
-         # Preprocess the image
-         image_size = image.size
-         inputs = processor(images=image, return_tensors="pt").to(device)
-
-         # Get depth prediction
-         with torch.no_grad():
-             outputs = model(**inputs)
-             predicted_depth = outputs.predicted_depth
-
-         # Postprocess the depth map
-         prediction = torch.nn.functional.interpolate(
-             predicted_depth.unsqueeze(1),
-             size=image_size[::-1],
-             mode="bicubic",
-             align_corners=False,
-         ).squeeze()
-
-         # Normalize the depth map
-         depth_map = (prediction - prediction.min()) / (prediction.max() - prediction.min())
-         depth_map = ToPILImage()(depth_map.cpu())
-
-         return depth_map
 
  @spaces.GPU
- def generate_parallax_video(image, depth_map=None, use_auto_depth=False, animation_style="horizontal",
-                             amplitude=2.0, k=5.0, fps=30, duration=5.0, smooth_edges=True,
-                             invert_depth=False, progress=gr.Progress()):
      """
      Generate a 3D parallax video from an image and depth map with the selected animation style.
-
      Args:
          image (PIL.Image): Input RGB image.
-         depth_map (PIL.Image, optional): Grayscale depth map (white = closer, black = farther).
-         use_auto_depth (bool): Whether to auto-generate the depth map.
          animation_style (str): Animation type ("horizontal", "vertical", "circle", "perspective").
          amplitude (float): Intensity of camera movement.
          k (float): Depth displacement scale factor.
          fps (int): Frames per second.
          duration (float): Video duration in seconds.
-         smooth_edges (bool): Whether to apply edge smoothing to reduce artifacts.
-         invert_depth (bool): Whether to invert the depth map.
-         progress (gr.Progress): Gradio progress indicator.
-
      Returns:
-         str: Path to the generated video file or error message.
      """
-     try:
-         if image is None:
-             return "Error: Please upload an input image"
-
-         # Generate depth map if auto-mode is selected
-         if use_auto_depth or depth_map is None:
-             progress(0.1, desc="Generating depth map...")
-             depth_map = DepthModelManager.generate_depth_map(image)
-
-         # Validate input dimensions
-         original_size = image.size
-         if depth_map.size != image.size:
-             depth_map = depth_map.resize(image.size, Image.BICUBIC)
-
-         # Handle depth map inversion if requested
-         if invert_depth:
-             depth_map = Image.fromarray(255 - np.array(depth_map))
-
-         # Convert to device tensors
-         device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-         progress(0.2, desc="Processing inputs...")
-
-         # Optimize for memory - use 16-bit precision
-         torch.set_grad_enabled(False)
-         if torch.cuda.is_available():
-             torch.cuda.empty_cache()
-
-         image_tensor = ToTensor()(image).unsqueeze(0).to(device, dtype=torch.float16)  # (1, 3, H, W)
-         depth_tensor = ToTensor()(depth_map.convert('L')).to(device, dtype=torch.float16)  # (1, H, W)
-
-         # Normalize depth (min-max)
-         depth_min = depth_tensor.min()
-         depth_max = depth_tensor.max()
-         if depth_max - depth_min < 1e-5:  # Handle flat depth maps
-             depth_tensor = torch.ones_like(depth_tensor) * 0.5
          else:
-             depth_tensor = (depth_tensor - depth_min) / (depth_max - depth_min + 1e-6)
-
-         depth_tensor = depth_tensor.squeeze(0)  # (H, W)
-
-         # Apply optional mild gaussian blur to depth for smoother transitions
-         if smooth_edges:
-             kernel_size = max(3, min(int(min(image.size) / 100) * 2 + 1, 11))
-             depth_np = depth_tensor.cpu().numpy()
-             depth_np = cv2.GaussianBlur(depth_np, (kernel_size, kernel_size), 0)
-             depth_tensor = torch.tensor(depth_np, device=device, dtype=torch.float16)
-
-         # Extract dimensions
-         H, W = image_tensor.shape[2], image_tensor.shape[3]
-
-         # Create coordinate grid for warping
-         x = torch.arange(0, W, device=device, dtype=torch.float16)
-         y = torch.arange(0, H, device=device, dtype=torch.float16)
-         xx, yy = torch.meshgrid(x, y, indexing='xy')
-         pixel_grid = torch.stack((xx, yy), dim=-1)  # (H, W, 2)
-
-         # Calculate number of frames
-         num_frames = int(fps * duration)
-
-         # Prepare video writer
-         output_path = os.path.join(tempfile.gettempdir(), "parallax_video.mp4")
-         writer = imageio.get_writer(output_path, fps=fps, codec='libx264', quality=8,
-                                     pixelformat='yuv420p', bitrate='8000k')
-
-         # Define easing function for smoother animation
-         def ease_in_out(t):
-             return 0.5 * (1 - np.cos(np.pi * t))
-
-         # Animation and rendering
-         progress(0.3, desc="Generating frames...")
-         frame_count = 0
-
-         for frame in range(num_frames):
-             # Report progress
-             frame_progress = 0.3 + (0.65 * (frame / num_frames))
-             progress(frame_progress, desc=f"Rendering frame {frame+1}/{num_frames}")
-
-             # Normalized time with easing
-             t = frame / (num_frames - 1)  # [0, 1]
-             t_eased = ease_in_out(t)
-
-             # Calculate camera position based on animation style
-             if animation_style == "horizontal":
-                 camera_x = amplitude * np.sin(2 * np.pi * t_eased)
-                 camera_y = 0
-                 displacement_scale = 1
-             elif animation_style == "vertical":
-                 camera_x = 0
-                 camera_y = amplitude * np.sin(2 * np.pi * t_eased)
-                 displacement_scale = 1
-             elif animation_style == "circle":
-                 camera_x = amplitude * np.sin(2 * np.pi * t_eased)
-                 camera_y = amplitude * np.cos(2 * np.pi * t_eased)
-                 displacement_scale = 1
-             elif animation_style == "perspective":
-                 # Better perspective effect
-                 zoom_factor = 0.1 * np.sin(2 * np.pi * t_eased) + 1.0  # [0.9, 1.1]
-                 camera_x = amplitude * 0.5 * np.sin(2 * np.pi * t_eased)
-                 camera_y = amplitude * 0.3 * np.sin(2 * np.pi * t_eased)
-                 displacement_scale = zoom_factor
-             else:
-                 camera_x = 0
-                 camera_y = 0
-                 displacement_scale = 1
-
-             # Compute displacements with a more natural depth response
-             displacement_x = displacement_scale * k * camera_x * depth_tensor
-             displacement_y = displacement_scale * k * camera_y * depth_tensor
-
-             # Calculate source coordinates for warping
-             source_pixel_x = pixel_grid[:, :, 0] + displacement_x
-             source_pixel_y = pixel_grid[:, :, 1] + displacement_y
-
-             # Normalize coordinates to [-1, 1] for grid_sample
-             grid_x = 2 * source_pixel_x / (W - 1) - 1
-             grid_y = 2 * source_pixel_y / (H - 1) - 1
-             grid = torch.stack((grid_x, grid_y), dim=-1).unsqueeze(0)  # (1, H, W, 2)
-
-             # Warp the image using grid sampling with improved border handling
-             warped = torch.nn.functional.grid_sample(
-                 image_tensor,
-                 grid,
-                 align_corners=True,
-                 mode='bilinear',
-                 padding_mode='reflection'  # Using reflection padding for smoother edges
-             )
-
-             # Convert warped tensor to numpy image
-             warped_np = warped.squeeze(0).permute(1, 2, 0).cpu().numpy()
-             # Convert to 8-bit for video
-             frame_img = (warped_np * 255).clip(0, 255).astype(np.uint8)
-
-             # Apply a mild vignette effect to hide edge artifacts
-             if smooth_edges:
-                 h, w = frame_img.shape[:2]
-                 center_x, center_y = w // 2, h // 2
-                 max_dist = np.sqrt(center_x**2 + center_y**2)
-                 y, x = np.ogrid[:h, :w]
-                 dist = np.sqrt((x - center_x)**2 + (y - center_y)**2)
-                 vignette = np.clip(1.0 - dist / max_dist * 0.15, 0.95, 1.0)
-                 frame_img = (frame_img * vignette[:, :, np.newaxis]).astype(np.uint8)
-
-             # Add frame to video
-             writer.append_data(frame_img)
-             frame_count += 1
-
-             # Prevent memory issues by cleaning up tensors
-             del grid, warped
-             if frame % 10 == 0 and torch.cuda.is_available():
-                 torch.cuda.empty_cache()
-
-         # Ensure all frames are written and close the writer
          writer.close()
 
-         # Clean up tensors
-         del image_tensor, depth_tensor, pixel_grid
-         if torch.cuda.is_available():
-             torch.cuda.empty_cache()
-         gc.collect()
-
-         progress(1.0, desc="Processing complete")
-
-         if frame_count > 0:
-             return output_path
-         else:
-             return "Error: No frames were generated. Please adjust your parameters."
-
-     except Exception as e:
-         # Clean up any resources
-         if torch.cuda.is_available():
-             torch.cuda.empty_cache()
-
-         error_msg = f"Error: {str(e)}\n{traceback.format_exc()}"
-         print(error_msg)
-         return f"An error occurred: {str(e)}"
 
  # Define Gradio interface
- with gr.Blocks(title="3D Parallax Video Generator", theme=gr.themes.Soft()) as demo:
-     gr.Markdown("# Advanced 3D Parallax Video Generator")
-
-     with gr.Accordion("About this app", open=False):
-         gr.Markdown("""
-         This application converts 2D images into 3D parallax motion videos. Upload an image and
-         either provide a depth map or use our built-in depth estimation model to automatically
-         generate one. Customize the animation style and parameters to create your desired effect.
-
-         ### Tips for best results:
-         - Start with small amplitude and k values (2-5) to avoid extreme distortions
-         - The depth map should have white areas for objects closer to camera, black for farther objects
-         - For automatic depth generation, images with clear foreground/background separation work best
-         - If you see artifacts at the edges, enable the "Smooth edges" option
-         """)
 
      # Input section
      with gr.Row():
-         with gr.Column():
-             image_input = gr.Image(label="Upload Image", type="pil")
-
-             with gr.Row():
-                 use_auto_depth = gr.Checkbox(label="Auto-generate depth map", value=True)
-                 invert_depth = gr.Checkbox(label="Invert depth map", value=False)
-
-             depth_input = gr.Image(label="Upload Depth Map (optional)", type="pil")
 
      # Parameter controls
      with gr.Row():
-         with gr.Column():
-             animation_style = gr.Dropdown(
-                 choices=["horizontal", "vertical", "circle", "perspective"],
-                 label="Animation Style",
-                 value="horizontal"
-             )
-             amplitude_slider = gr.Slider(0.5, 10, value=2, label="Movement Amplitude", step=0.1)
-             k_slider = gr.Slider(1, 20, value=5, label="Depth Effect Strength", step=0.1)
-
-         with gr.Column():
-             fps_slider = gr.Slider(15, 60, value=30, label="Frames Per Second", step=1)
-             duration_slider = gr.Slider(1, 10, value=3, label="Duration (seconds)", step=0.1)
-             smooth_edges = gr.Checkbox(label="Smooth edges (reduces artifacts)", value=True)
 
      # Output and interaction
-     with gr.Row():
-         generate_btn = gr.Button("Generate Video", variant="primary")
-
      video_output = gr.Video(label="Parallax Video")
 
-     # Handle automatic depth map generation
-     def update_depth_visibility(auto_generate):
-         return gr.update(visible=not auto_generate)
-
-     use_auto_depth.change(update_depth_visibility, inputs=[use_auto_depth], outputs=[depth_input])
-
      # Connect button to function
      generate_btn.click(
          fn=generate_parallax_video,
-         inputs=[
-             image_input,
-             depth_input,
-             use_auto_depth,
-             animation_style,
-             amplitude_slider,
-             k_slider,
-             fps_slider,
-             duration_slider,
-             smooth_edges,
-             invert_depth
-         ],
          outputs=video_output
      )
 
-     # Add examples
-     gr.Examples(
-         examples=[
-             ["https://huggingface.co/spaces/stabilityai/stable-diffusion/resolve/main/images/lincoln.jpg"],
-             ["https://images.unsplash.com/photo-1546614042-7df3c24c9e5d"],
-             ["https://images.unsplash.com/photo-1563473213013-de2a0133c100"],
-         ],
-         inputs=[image_input],
-     )
-
  # Launch the application
  demo.launch()
 
  import gradio as gr
  import imageio
  import numpy as np
  from PIL import Image
+ from torchvision.transforms import ToTensor
  import spaces
  import tempfile
 
  @spaces.GPU
+ def generate_parallax_video(image, depth_map, animation_style, amplitude, k, fps, duration):
      """
      Generate a 3D parallax video from an image and depth map with the selected animation style.
+
      Args:
          image (PIL.Image): Input RGB image.
+         depth_map (PIL.Image): Grayscale depth map (white = closer, black = farther).
          animation_style (str): Animation type ("horizontal", "vertical", "circle", "perspective").
          amplitude (float): Intensity of camera movement.
          k (float): Depth displacement scale factor.
          fps (int): Frames per second.
          duration (float): Video duration in seconds.
+
      Returns:
+         str: Path to the generated video file.
      """
+     # Validate input dimensions
+     if image.size != depth_map.size:
+         raise ValueError("Image and depth map must have the same dimensions")
+
+     # Convert inputs to PyTorch tensors on GPU
+     image_tensor = ToTensor()(image).unsqueeze(0).to('cuda')  # Shape: (1, 3, H, W)
+     depth_tensor = ToTensor()(depth_map.convert('L')).to('cuda')  # Shape: (1, 1, H, W)
+     depth_tensor = (depth_tensor - depth_tensor.min()) / (depth_tensor.max() - depth_tensor.min() + 1e-6)
+     depth_tensor = depth_tensor.squeeze(0).squeeze(0)  # Shape: (H, W)
+
+     H, W = image_tensor.shape[2], image_tensor.shape[3]
+
+     # Create coordinate grid for warping
+     x = torch.arange(0, W).float().to('cuda')
+     y = torch.arange(0, H).float().to('cuda')
+     xx, yy = torch.meshgrid(x, y, indexing='xy')
+     pixel_grid = torch.stack((xx, yy), dim=-1)  # Shape: (H, W, 2)
+
+     # Calculate number of frames
+     num_frames = int(fps * duration)
+     frames = []
+
+     # Generate frames based on animation style
+     for frame in range(num_frames):
+         t = frame / num_frames  # Normalized time [0, 1]
+         if animation_style == "horizontal":
+             camera_x = amplitude * np.sin(2 * np.pi * t)
+             camera_y = 0
+             displacement_scale = 1
+         elif animation_style == "vertical":
+             camera_x = 0
+             camera_y = amplitude * np.sin(2 * np.pi * t)
+             displacement_scale = 1
+         elif animation_style == "circle":
+             camera_x = amplitude * np.sin(2 * np.pi * t)
+             camera_y = amplitude * np.cos(2 * np.pi * t)
+             displacement_scale = 1
+         elif animation_style == "perspective":
+             camera_x = amplitude  # Fixed horizontal base for consistency
+             camera_y = 0
+             displacement_scale = 1 + 0.5 * np.sin(2 * np.pi * t)  # Scales displacement for zoom effect
          else:
+             raise ValueError(f"Unsupported animation style: {animation_style}")
+
+         # Compute displacements in both x and y directions
+         displacement_x = displacement_scale * k * camera_x * depth_tensor
+         displacement_y = displacement_scale * k * camera_y * depth_tensor
+
+         # Calculate source coordinates for warping
+         source_pixel_x = pixel_grid[:, :, 0] + displacement_x
+         source_pixel_y = pixel_grid[:, :, 1] + displacement_y
+
+         # Normalize coordinates to [-1, 1] for grid_sample
+         grid_x = 2 * source_pixel_x / (W - 1) - 1
+         grid_y = 2 * source_pixel_y / (H - 1) - 1
+         grid = torch.stack((grid_x, grid_y), dim=-1).unsqueeze(0)  # Shape: (1, H, W, 2)
+
+         # Warp the image using grid sampling
+         warped = torch.nn.functional.grid_sample(image_tensor, grid, align_corners=True)
+
+         # Convert warped tensor to numpy image
+         warped_np = warped.squeeze(0).permute(1, 2, 0).cpu().numpy()
+         frame_img = (warped_np * 255).astype(np.uint8)
+         frames.append(frame_img)
+
+     # Save frames as a video
+     with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmpfile:
+         output_path = tmpfile.name
+         writer = imageio.get_writer(output_path, fps=fps, codec='libx264')
+         for frame in frames:
+             writer.append_data(frame)
          writer.close()
 
+     return output_path
 
  # Define Gradio interface
+ with gr.Blocks(title="3D Parallax Video Generator") as demo:
+     gr.Markdown("# 3D Parallax Video Generator")
+     gr.Markdown("""
+     Upload an image and its depth map (white = closer, black = farther) to create a 3D parallax video.
+     Select an animation style and adjust parameters below. Start with small amplitude and k values to avoid empty frames.
+     """)
 
      # Input section
      with gr.Row():
+         image_input = gr.Image(type="pil", label="Upload Image")
+         depth_input = gr.Image(type="pil", label="Upload Depth Map")
 
      # Parameter controls
      with gr.Row():
+         animation_style = gr.Dropdown(
+             choices=["horizontal", "vertical", "circle", "perspective"],
+             label="Animation Style",
+             value="horizontal"
+         )
+         amplitude_slider = gr.Slider(0, 10, value=2, label="Amplitude", step=0.1)
+         k_slider = gr.Slider(1, 20, value=5, label="Depth Scale (k)", step=0.1)
+         fps_slider = gr.Slider(10, 60, value=30, label="Frames Per Second", step=1)
+         duration_slider = gr.Slider(1, 10, value=5, label="Duration (seconds)", step=0.1)
 
      # Output and interaction
+     generate_btn = gr.Button("Generate Video")
      video_output = gr.Video(label="Parallax Video")
 
      # Connect button to function
      generate_btn.click(
          fn=generate_parallax_video,
+         inputs=[image_input, depth_input, animation_style, amplitude_slider, k_slider, fps_slider, duration_slider],
          outputs=video_output
      )
 
  # Launch the application
  demo.launch()
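
For readers skimming the diff: the heart of the new app.py is a depth-weighted warp. Each pixel is displaced in proportion to its normalized depth value, the displaced coordinates are rescaled to grid_sample's [-1, 1] convention, and the frame is resampled bilinearly. Below is a minimal, CPU-only sketch of that single step on synthetic inputs; the helper name demo_parallax_frame and the test data are illustrative, not part of the commit.

import numpy as np
import torch
from PIL import Image

def demo_parallax_frame(image, depth, dx, dy):
    """Shift each pixel by (dx, dy) pixels scaled by its depth, then resample."""
    # Image to (1, 3, H, W) float tensor in [0, 1]; depth to (H, W), 1.0 = near.
    img = torch.from_numpy(np.asarray(image, dtype=np.float32) / 255.0).permute(2, 0, 1).unsqueeze(0)
    d = torch.from_numpy(np.asarray(depth.convert("L"), dtype=np.float32) / 255.0)
    H, W = d.shape
    xx, yy = torch.meshgrid(torch.arange(W, dtype=torch.float32),
                            torch.arange(H, dtype=torch.float32), indexing="xy")  # each (H, W)
    # Near pixels (depth ~ 1) move the farthest -- this is the parallax effect.
    src_x = xx + dx * d
    src_y = yy + dy * d
    # grid_sample expects sampling coordinates normalized to [-1, 1]; coordinates
    # pushed outside that range sample as black with the default zero padding,
    # which is why the UI text above warns against large amplitude and k values.
    grid = torch.stack((2 * src_x / (W - 1) - 1, 2 * src_y / (H - 1) - 1), dim=-1).unsqueeze(0)
    warped = torch.nn.functional.grid_sample(img, grid, mode="bilinear", align_corners=True)
    return (warped.squeeze(0).permute(1, 2, 0).numpy() * 255).clip(0, 255).astype(np.uint8)

# A flat depth map shifts the image rigidly; a left-to-right ramp shears it.
rgb = Image.fromarray(np.random.randint(0, 255, (64, 64, 3), dtype=np.uint8))
ramp = Image.fromarray(np.tile(np.linspace(0, 255, 64, dtype=np.uint8), (64, 1)))
print(demo_parallax_frame(rgb, ramp, dx=5.0, dy=0.0).shape)  # (64, 64, 3)

Called directly (outside Gradio, and assuming a CUDA device, since the committed code moves its tensors to 'cuda'), the Space's function would be invoked as generate_parallax_video(image, depth_map, "horizontal", 2.0, 5.0, 30, 5.0).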