yuyutsu07 committed
Commit 16f4e59 · verified · 1 Parent(s): 2291646

Update app.py

Files changed (1)
  1. app.py +67 -56
app.py CHANGED
@@ -3,23 +3,26 @@ import gradio as gr
 import imageio
 import numpy as np
 from PIL import Image
-from torchvision.transforms import ToTensor
+from torchvision.transforms import ToTensor, Resize
 import spaces
 import tempfile
+from scipy.ndimage import gaussian_filter
 
 @spaces.GPU
-def generate_parallax_video(image, depth_map, animation_style, amplitude, k, fps, duration):
+def generate_parallax_video(image, depth_map, animation_style, amplitude, k, fps, duration, ssaa_factor, use_taa):
     """
-    Generate a 3D parallax video from an image and depth map with the selected animation style.
+    Generate a 3D parallax video with enhanced quality features.
 
     Args:
         image (PIL.Image): Input RGB image.
-        depth_map (PIL.Image): Grayscale depth map (white = closer, black = farther).
-        animation_style (str): Animation type ("horizontal", "vertical", "circle", "perspective").
-        amplitude (float): Intensity of camera movement.
-        k (float): Depth displacement scale factor.
+        depth_map (PIL.Image): Grayscale depth map.
+        animation_style (str): Animation type (e.g., horizontal, spiral).
+        amplitude (float): Camera movement intensity.
+        k (float): Depth displacement scale.
         fps (int): Frames per second.
         duration (float): Video duration in seconds.
+        ssaa_factor (int): Super sampling factor (1, 2, 4).
+        use_taa (bool): Enable temporal anti-aliasing.
 
     Returns:
         str: Path to the generated video file.
@@ -28,68 +31,86 @@ def generate_parallax_video(image, depth_map, animation_style, amplitude, k, fps
     if image.size != depth_map.size:
         raise ValueError("Image and depth map must have the same dimensions")
 
-    # Convert inputs to PyTorch tensors on GPU
-    image_tensor = ToTensor()(image).unsqueeze(0).to('cuda')  # Shape: (1, 3, H, W)
-    depth_tensor = ToTensor()(depth_map.convert('L')).to('cuda')  # Shape: (1, 1, H, W)
+    # Convert to tensors with high precision
+    image_tensor = ToTensor()(image).to('cuda', dtype=torch.float32)
+    depth_tensor = ToTensor()(depth_map.convert('L')).to('cuda', dtype=torch.float32)
     depth_tensor = (depth_tensor - depth_tensor.min()) / (depth_tensor.max() - depth_tensor.min() + 1e-6)
-    depth_tensor = depth_tensor.squeeze(0).squeeze(0)  # Shape: (H, W)
 
-    H, W = image_tensor.shape[2], image_tensor.shape[3]
+    # Smooth depth map to improve intersections
+    depth_np = depth_tensor.squeeze().cpu().numpy()
+    depth_np = gaussian_filter(depth_np, sigma=1)  # Basic smoothing
+    depth_tensor = torch.tensor(depth_np, device='cuda', dtype=torch.float32).unsqueeze(0)
 
-    # Create coordinate grid for warping
+    # Apply SSAA: upscale image and depth map
+    if ssaa_factor > 1:
+        upscale = Resize((int(image.height * ssaa_factor), int(image.width * ssaa_factor)), antialias=True)
+        image_tensor = upscale(image_tensor)
+        depth_tensor = upscale(depth_tensor)
+
+    H, W = image_tensor.shape[1], image_tensor.shape[2]
+
+    # Create coordinate grid
     x = torch.arange(0, W).float().to('cuda')
     y = torch.arange(0, H).float().to('cuda')
     xx, yy = torch.meshgrid(x, y, indexing='xy')
-    pixel_grid = torch.stack((xx, yy), dim=-1)  # Shape: (H, W, 2)
+    pixel_grid = torch.stack((xx, yy), dim=-1)
 
-    # Calculate number of frames
+    # Generate frames
     num_frames = int(fps * duration)
     frames = []
+    prev_frame = None
 
-    # Generate frames based on animation style
     for frame in range(num_frames):
-        t = frame / num_frames  # Normalized time [0, 1]
+        t = frame / num_frames
         if animation_style == "horizontal":
             camera_x = amplitude * np.sin(2 * np.pi * t)
             camera_y = 0
-            displacement_scale = 1
         elif animation_style == "vertical":
             camera_x = 0
             camera_y = amplitude * np.sin(2 * np.pi * t)
-            displacement_scale = 1
        elif animation_style == "circle":
             camera_x = amplitude * np.sin(2 * np.pi * t)
             camera_y = amplitude * np.cos(2 * np.pi * t)
-            displacement_scale = 1
-        elif animation_style == "perspective":
-            camera_x = amplitude  # Fixed horizontal base for consistency
-            camera_y = 0
-            displacement_scale = 1 + 0.5 * np.sin(2 * np.pi * t)  # Scales displacement for zoom effect
+        elif animation_style == "spiral":  # Inspired by DepthFlow
+            radius = amplitude * (1 - t)
+            camera_x = radius * np.sin(4 * np.pi * t)
+            camera_y = radius * np.cos(4 * np.pi * t)
         else:
             raise ValueError(f"Unsupported animation style: {animation_style}")
 
-        # Compute displacements in both x and y directions
-        displacement_x = displacement_scale * k * camera_x * depth_tensor
-        displacement_y = displacement_scale * k * camera_y * depth_tensor
+        # Compute displacements
+        displacement_x = k * camera_x * depth_tensor.squeeze()
+        displacement_y = k * camera_y * depth_tensor.squeeze()
 
-        # Calculate source coordinates for warping
+        # Calculate source coordinates
         source_pixel_x = pixel_grid[:, :, 0] + displacement_x
         source_pixel_y = pixel_grid[:, :, 1] + displacement_y
 
-        # Normalize coordinates to [-1, 1] for grid_sample
+        # Normalize to [-1, 1]
         grid_x = 2 * source_pixel_x / (W - 1) - 1
         grid_y = 2 * source_pixel_y / (H - 1) - 1
-        grid = torch.stack((grid_x, grid_y), dim=-1).unsqueeze(0)  # Shape: (1, H, W, 2)
+        grid = torch.stack((grid_x, grid_y), dim=-1).unsqueeze(0)
 
-        # Warp the image using grid sampling
-        warped = torch.nn.functional.grid_sample(image_tensor, grid, align_corners=True)
+        # Warp with high-quality interpolation
+        warped = torch.nn.functional.grid_sample(image_tensor.unsqueeze(0), grid, mode='bicubic', align_corners=True)
 
-        # Convert warped tensor to numpy image
-        warped_np = warped.squeeze(0).permute(1, 2, 0).cpu().numpy()
-        frame_img = (warped_np * 255).astype(np.uint8)
+        # Downsample if SSAA is enabled
+        if ssaa_factor > 1:
+            downscale = Resize((image.height, image.width), antialias=True)
+            warped = downscale(warped.squeeze(0)).unsqueeze(0)
+
+        # Convert to numpy
+        frame_img = warped.squeeze(0).permute(1, 2, 0).cpu().numpy()
+        frame_img = (frame_img * 255).astype(np.uint8)
+
+        # Apply TAA if enabled
+        if use_taa and prev_frame is not None:
+            frame_img = (frame_img * 0.8 + prev_frame * 0.2).astype(np.uint8)
+
         frames.append(frame_img)
+        prev_frame = frame_img
 
-    # Save frames as a video
+    # Save video
     with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmpfile:
         output_path = tmpfile.name
         writer = imageio.get_writer(output_path, fps=fps, codec='libx264')
@@ -99,41 +120,31 @@ def generate_parallax_video(image, depth_map, animation_style, amplitude, k, fps
 
     return output_path
 
-# Define Gradio interface
-with gr.Blocks(title="3D Parallax Video Generator") as demo:
-    gr.Markdown("# 3D Parallax Video Generator")
-    gr.Markdown("""
-    Upload an image and its depth map (white = closer, black = farther) to create a 3D parallax video.
-    Select an animation style and adjust parameters below. Start with small amplitude and k values to avoid empty frames.
-    """)
+# Gradio interface
+with gr.Blocks(title="Enhanced 3D Parallax Video Generator") as demo:
+    gr.Markdown("# Enhanced 3D Parallax Video Generator")
+    gr.Markdown("Create high-quality 3D parallax videos with advanced features.")
 
-    # Input section
     with gr.Row():
         image_input = gr.Image(type="pil", label="Upload Image")
         depth_input = gr.Image(type="pil", label="Upload Depth Map")
 
-    # Parameter controls
     with gr.Row():
-        animation_style = gr.Dropdown(
-            choices=["horizontal", "vertical", "circle", "perspective"],
-            label="Animation Style",
-            value="horizontal"
-        )
+        animation_style = gr.Dropdown(["horizontal", "vertical", "circle", "spiral"], label="Animation Style", value="horizontal")
         amplitude_slider = gr.Slider(0, 10, value=2, label="Amplitude", step=0.1)
         k_slider = gr.Slider(1, 20, value=5, label="Depth Scale (k)", step=0.1)
-        fps_slider = gr.Slider(10, 60, value=30, label="Frames Per Second", step=1)
-        duration_slider = gr.Slider(1, 10, value=5, label="Duration (seconds)", step=0.1)
+        fps_slider = gr.Slider(10, 60, value=30, label="FPS", step=1)
+        duration_slider = gr.Slider(1, 10, value=5, label="Duration (s)", step=0.1)
+        ssaa_factor = gr.Dropdown([1, 2, 4], label="SSAA Factor", value=1)
+        use_taa = gr.Checkbox(label="Enable TAA", value=False)
 
-    # Output and interaction
     generate_btn = gr.Button("Generate Video")
     video_output = gr.Video(label="Parallax Video")
 
-    # Connect button to function
     generate_btn.click(
         fn=generate_parallax_video,
-        inputs=[image_input, depth_input, animation_style, amplitude_slider, k_slider, fps_slider, duration_slider],
+        inputs=[image_input, depth_input, animation_style, amplitude_slider, k_slider, fps_slider, duration_slider, ssaa_factor, use_taa],
         outputs=video_output
     )
 
-# Launch the application
 demo.launch()
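
A few notes on the techniques this change relies on. The core of both versions is a depth-weighted backward warp: every output pixel samples the source at (x + k*camera_x*depth, y + k*camera_y*depth), with coordinates normalized to [-1, 1] for grid_sample. A minimal standalone sketch of that step (parallax_warp is an illustrative name, not a function in app.py):

    import torch
    import torch.nn.functional as F

    def parallax_warp(image, depth, dx, dy):
        # image: (1, 3, H, W) floats in [0, 1]; depth: (H, W) in [0, 1],
        # white = closer; dx, dy: camera offsets in pixels (k * camera_x, ...).
        H, W = depth.shape
        xs = torch.arange(W, dtype=torch.float32, device=depth.device)
        ys = torch.arange(H, dtype=torch.float32, device=depth.device)
        xx, yy = torch.meshgrid(xs, ys, indexing='xy')  # each (H, W)
        src_x = xx + dx * depth   # closer pixels shift further
        src_y = yy + dy * depth
        grid = torch.stack((2 * src_x / (W - 1) - 1,   # grid_sample wants (x, y)
                            2 * src_y / (H - 1) - 1), dim=-1).unsqueeze(0)
        return F.grid_sample(image, grid, mode='bicubic', align_corners=True)

Source coordinates that fall outside the frame are filled by grid_sample's default zero padding, which is why large amplitude or k values produce black borders and, in the extreme, empty frames.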
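
The new "spiral" style (replacing "perspective") moves the camera on a circle whose radius decays linearly to zero while the angle sweeps two full turns over the clip. As a standalone helper (spiral_offsets is an illustrative name):

    import numpy as np

    def spiral_offsets(t, amplitude):
        # t = frame / num_frames in [0, 1): the radius shrinks from amplitude
        # toward 0 while the angle covers 4*pi, so the camera spirals inward.
        radius = amplitude * (1 - t)
        return radius * np.sin(4 * np.pi * t), radius * np.cos(4 * np.pi * t)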
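
The SSAA option is the classic supersampling round-trip: warp at ssaa_factor times the target resolution, then filter back down, which averages out resampling artifacts. Schematically (ssaa and render_fn are hypothetical stand-ins for the per-frame warp, not names from app.py):

    import torch
    from torchvision.transforms import Resize

    def ssaa(render_fn, image, factor):
        # image: (3, H, W). Render at factor x the resolution, then downscale;
        # antialias=True makes the downscale act as the low-pass filter.
        _, H, W = image.shape
        if factor > 1:
            image = Resize((H * factor, W * factor), antialias=True)(image)
        out = render_fn(image)  # hypothetical callback: (3, h, w) -> (3, h, w)
        if factor > 1:
            out = Resize((H, W), antialias=True)(out)
        return out

One side effect in this version: the displacements stay in k * camera units while the pixel grid becomes ssaa_factor times denser, so the same amplitude yields proportionally less on-screen parallax at SSAA 2 or 4.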
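
One caveat with mode='bicubic': the kernel can overshoot the [0, 1] range near hard edges, and numpy's uint8 cast wraps rather than saturates, so a slight undershoot can flip to near-white speckles. Clamping before the cast is a cheap guard; this is a suggested hardening, not something the commit does:

    import numpy as np
    import torch

    def to_uint8_frame(warped):
        # warped: (1, 3, H, W) float tensor from grid_sample. Bicubic output
        # can leave [0, 1], and -0.01 * 255 wraps to 254 under astype(np.uint8),
        # so clamp before converting.
        frame = warped.squeeze(0).permute(1, 2, 0).clamp(0.0, 1.0).cpu().numpy()
        return (frame * 255).astype(np.uint8)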
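
The TAA checkbox enables a single-tap exponential blend: each stored frame keeps 80% of the current warp plus 20% of the previously stored frame, so the contribution of older frames decays geometrically. An equivalent standalone form (taa is an illustrative name):

    import numpy as np

    def taa(frames_in, alpha=0.8):
        # Running exponential blend over uint8 frames; alpha weights the
        # current frame, matching the 0.8 / 0.2 split used in the commit.
        out, prev = [], None
        for frame in frames_in:
            if prev is not None:
                frame = (alpha * frame.astype(np.float32)
                         + (1 - alpha) * prev.astype(np.float32)).astype(np.uint8)
            out.append(frame)
            prev = frame
        return out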