Update app.py
app.py CHANGED
@@ -3,23 +3,26 @@ import gradio as gr
 import imageio
 import numpy as np
 from PIL import Image
-from torchvision.transforms import ToTensor
+from torchvision.transforms import ToTensor, Resize
 import spaces
 import tempfile
+from scipy.ndimage import gaussian_filter
 
 @spaces.GPU
-def generate_parallax_video(image, depth_map, animation_style, amplitude, k, fps, duration):
+def generate_parallax_video(image, depth_map, animation_style, amplitude, k, fps, duration, ssaa_factor, use_taa):
     """
-    Generate a 3D parallax video
+    Generate a 3D parallax video with enhanced quality features.
 
     Args:
         image (PIL.Image): Input RGB image.
-        depth_map (PIL.Image): Grayscale depth map
-        animation_style (str): Animation type (
-        amplitude (float):
-        k (float): Depth displacement scale
+        depth_map (PIL.Image): Grayscale depth map.
+        animation_style (str): Animation type (e.g., horizontal, spiral).
+        amplitude (float): Camera movement intensity.
+        k (float): Depth displacement scale.
         fps (int): Frames per second.
         duration (float): Video duration in seconds.
+        ssaa_factor (int): Super sampling factor (1, 2, 4).
+        use_taa (bool): Enable temporal anti-aliasing.
 
     Returns:
         str: Path to the generated video file.
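The two new parameters map onto standard real-time rendering tricks: `ssaa_factor` renders each frame at a higher resolution and downsamples it, while `use_taa` blends a fraction of the previous frame into the current one. As a rough illustration of the temporal blend applied in the frame loop below (a sketch, not part of the commit; the 0.8/0.2 weights come from the new code):

import numpy as np

# Toy TAA blend: 80% current frame, 20% previous frame (weights from the diff).
prev = np.full((2, 2, 3), 100, dtype=np.uint8)
curr = np.full((2, 2, 3), 200, dtype=np.uint8)
blended = (curr * 0.8 + prev * 0.2).astype(np.uint8)
print(blended[0, 0])  # [180 180 180]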
@@ -28,68 +31,86 @@ def generate_parallax_video(image, depth_map, animation_style, amplitude, k, fps
     if image.size != depth_map.size:
         raise ValueError("Image and depth map must have the same dimensions")
 
-    # Convert
-    image_tensor = ToTensor()(image).to('cuda')
-    depth_tensor = ToTensor()(depth_map.convert('L')).to('cuda')
+    # Convert to tensors with high precision
+    image_tensor = ToTensor()(image).to('cuda', dtype=torch.float32)
+    depth_tensor = ToTensor()(depth_map.convert('L')).to('cuda', dtype=torch.float32)
     depth_tensor = (depth_tensor - depth_tensor.min()) / (depth_tensor.max() - depth_tensor.min() + 1e-6)
-    depth_tensor = depth_tensor.squeeze(0).squeeze(0)  # Shape: (H, W)
 
+    # Smooth depth map to improve intersections
+    depth_np = depth_tensor.squeeze().cpu().numpy()
+    depth_np = gaussian_filter(depth_np, sigma=1)  # Basic smoothing
+    depth_tensor = torch.tensor(depth_np, device='cuda', dtype=torch.float32).unsqueeze(0)
+
+    # Apply SSAA: upscale image and depth map
+    if ssaa_factor > 1:
+        upscale = Resize((int(image.height * ssaa_factor), int(image.width * ssaa_factor)), antialias=True)
+        image_tensor = upscale(image_tensor)
+        depth_tensor = upscale(depth_tensor)
+
+    H, W = image_tensor.shape[1], image_tensor.shape[2]
+
+    # Create coordinate grid
     x = torch.arange(0, W).float().to('cuda')
     y = torch.arange(0, H).float().to('cuda')
     xx, yy = torch.meshgrid(x, y, indexing='xy')
     pixel_grid = torch.stack((xx, yy), dim=-1)
 
+    # Generate frames
     num_frames = int(fps * duration)
     frames = []
+    prev_frame = None
 
-    # Generate frames based on animation style
     for frame in range(num_frames):
         t = frame / num_frames
         if animation_style == "horizontal":
             camera_x = amplitude * np.sin(2 * np.pi * t)
             camera_y = 0
-            displacement_scale = 1
         elif animation_style == "vertical":
             camera_x = 0
             camera_y = amplitude * np.sin(2 * np.pi * t)
-            displacement_scale = 1
         elif animation_style == "circle":
             camera_x = amplitude * np.sin(2 * np.pi * t)
             camera_y = amplitude * np.cos(2 * np.pi * t)
-            displacement_scale = 1
-        elif animation_style == "perspective":
-            camera_x =
-            camera_y =
-            displacement_scale = 1 + 0.5 * np.sin(2 * np.pi * t)  # Scales displacement for zoom effect
+        elif animation_style == "spiral":  # Inspired by DepthFlow
+            radius = amplitude * (1 - t)
+            camera_x = radius * np.sin(4 * np.pi * t)
+            camera_y = radius * np.cos(4 * np.pi * t)
         else:
             raise ValueError(f"Unsupported animation style: {animation_style}")
 
         # Compute displacements
-        displacement_x =
-        displacement_y =
+        displacement_x = k * camera_x * depth_tensor.squeeze()
+        displacement_y = k * camera_y * depth_tensor.squeeze()
 
         # Calculate source coordinates
         source_pixel_x = pixel_grid[:, :, 0] + displacement_x
         source_pixel_y = pixel_grid[:, :, 1] + displacement_y
 
-        # Normalize
+        # Normalize to [-1, 1]
         grid_x = 2 * source_pixel_x / (W - 1) - 1
         grid_y = 2 * source_pixel_y / (H - 1) - 1
         grid = torch.stack((grid_x, grid_y), dim=-1).unsqueeze(0)
 
-        # Convert warped tensor to numpy image
-        warped_np = warped.squeeze(0).permute(1, 2, 0).cpu().numpy()
-        frame_img = (warped_np * 255).astype(np.uint8)
+        # Warp with high-quality interpolation
+        warped = torch.nn.functional.grid_sample(image_tensor.unsqueeze(0), grid, mode='bicubic', align_corners=True)
+
+        # Downsample if SSAA is enabled
+        if ssaa_factor > 1:
+            downscale = Resize((image.height, image.width), antialias=True)
+            warped = downscale(warped.squeeze(0)).unsqueeze(0)
+
+        # Convert to numpy
+        frame_img = warped.squeeze(0).permute(1, 2, 0).cpu().numpy()
+        frame_img = (frame_img * 255).astype(np.uint8)
+
+        # Apply TAA if enabled
+        if use_taa and prev_frame is not None:
+            frame_img = (frame_img * 0.8 + prev_frame * 0.2).astype(np.uint8)
 
         frames.append(frame_img)
+        prev_frame = frame_img
 
-    # Save
+    # Save video
     with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmpfile:
         output_path = tmpfile.name
         writer = imageio.get_writer(output_path, fps=fps, codec='libx264')
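The core of the frame loop is a depth-weighted warp: each pixel is shifted by `k * camera_offset * depth`, the shifted coordinates are normalized to `[-1, 1]`, and `grid_sample` resamples the image. A minimal CPU sketch of that math (illustrative only; shapes and values are made up, and out-of-range grid values are why large `amplitude`/`k` settings produce empty borders):

import torch
import torch.nn.functional as F

H, W = 4, 4
image = torch.rand(1, 3, H, W)                     # (N, C, H, W)
depth = torch.rand(H, W)                           # normalized to [0, 1]

k, camera_x, camera_y = 5.0, 0.5, 0.0              # one frame's camera offset
dx = k * camera_x * depth                          # near pixels (depth ~ 1) move most
dy = k * camera_y * depth

x = torch.arange(W).float()
y = torch.arange(H).float()
xx, yy = torch.meshgrid(x, y, indexing='xy')       # both (H, W)
grid_x = 2 * (xx + dx) / (W - 1) - 1               # map pixel coords to [-1, 1]
grid_y = 2 * (yy + dy) / (H - 1) - 1
grid = torch.stack((grid_x, grid_y), dim=-1).unsqueeze(0)  # (1, H, W, 2)

warped = F.grid_sample(image, grid, align_corners=True)    # (1, 3, H, W)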
@@ -99,41 +120,31 @@ def generate_parallax_video(image, depth_map, animation_style, amplitude, k, fps
 
     return output_path
 
-with gr.Blocks(title="3D Parallax Video Generator") as demo:
-    gr.Markdown("# 3D Parallax Video Generator")
-    gr.Markdown("""
-    Upload an image and its depth map (white = closer, black = farther) to create a 3D parallax video.
-    Select an animation style and adjust parameters below. Start with small amplitude and k values to avoid empty frames.
-    """)
+# Gradio interface
+with gr.Blocks(title="Enhanced 3D Parallax Video Generator") as demo:
+    gr.Markdown("# Enhanced 3D Parallax Video Generator")
+    gr.Markdown("Create high-quality 3D parallax videos with advanced features.")
 
-    # Input section
     with gr.Row():
         image_input = gr.Image(type="pil", label="Upload Image")
         depth_input = gr.Image(type="pil", label="Upload Depth Map")
 
-    # Parameter controls
     with gr.Row():
-        animation_style = gr.Dropdown(
-            choices=["horizontal", "vertical", "circle", "perspective"],
-            label="Animation Style",
-            value="horizontal"
-        )
+        animation_style = gr.Dropdown(["horizontal", "vertical", "circle", "spiral"], label="Animation Style", value="horizontal")
         amplitude_slider = gr.Slider(0, 10, value=2, label="Amplitude", step=0.1)
         k_slider = gr.Slider(1, 20, value=5, label="Depth Scale (k)", step=0.1)
-        fps_slider = gr.Slider(10, 60, value=30, label="
-        duration_slider = gr.Slider(1, 10, value=5, label="Duration (
+        fps_slider = gr.Slider(10, 60, value=30, label="FPS", step=1)
+        duration_slider = gr.Slider(1, 10, value=5, label="Duration (s)", step=0.1)
+        ssaa_factor = gr.Dropdown([1, 2, 4], label="SSAA Factor", value=1)
+        use_taa = gr.Checkbox(label="Enable TAA", value=False)
 
-    # Output and interaction
     generate_btn = gr.Button("Generate Video")
     video_output = gr.Video(label="Parallax Video")
 
-    # Connect button to function
     generate_btn.click(
         fn=generate_parallax_video,
-        inputs=[image_input, depth_input, animation_style, amplitude_slider, k_slider, fps_slider, duration_slider],
+        inputs=[image_input, depth_input, animation_style, amplitude_slider, k_slider, fps_slider, duration_slider, ssaa_factor, use_taa],
         outputs=video_output
     )
 
-    # Launch the application
 demo.launch()
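For quick testing outside the Gradio UI, the updated function should remain directly callable despite the `@spaces.GPU` decorator. A hypothetical headless invocation (assumes a CUDA device and the `spaces` package; `photo.png` and `photo_depth.png` are placeholder filenames, not part of the commit):

from PIL import Image

rgb = Image.open("photo.png").convert("RGB")
depth = Image.open("photo_depth.png").convert("L").resize(rgb.size)

video_path = generate_parallax_video(
    rgb, depth,
    animation_style="spiral",   # new style added in this commit
    amplitude=2.0, k=5.0,
    fps=30, duration=5.0,
    ssaa_factor=2,              # render at 2x, then downsample
    use_taa=True,               # temporal smoothing between frames
)
print(video_path)               # path to the temporary .mp4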