import gradio as gr
import cv2
import numpy as np
import torch
import tempfile
from PIL import Image
import spaces
from tqdm.auto import tqdm
from diffusers import DDIMScheduler, AutoencoderKL
from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
from GeoWizard.geowizard.models.unet_2d_condition import UNet2DConditionModel
from GeoWizard.geowizard.models.geowizard_pipeline import DepthNormalEstimationPipeline

# Device setup
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
CHECKPOINT_PATH = "GonzaloMG/geowizard-e2e-ft"

# Load pretrained components
vae = AutoencoderKL.from_pretrained(CHECKPOINT_PATH, subfolder='vae')
scheduler = DDIMScheduler.from_pretrained(CHECKPOINT_PATH, timestep_spacing="trailing", subfolder='scheduler')
image_encoder = CLIPVisionModelWithProjection.from_pretrained(CHECKPOINT_PATH, subfolder="image_encoder")
feature_extractor = CLIPImageProcessor.from_pretrained(CHECKPOINT_PATH, subfolder="feature_extractor")
unet = UNet2DConditionModel.from_pretrained(CHECKPOINT_PATH, subfolder="unet")
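
# The scheduler's "trailing" timestep spacing matches the deterministic single-step
# setup used in predict() below (denoising_steps=1, noise="zeros").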

# Instantiate pipeline
pipe = DepthNormalEstimationPipeline(
    vae=vae,
    image_encoder=image_encoder,
    feature_extractor=feature_extractor,
    unet=unet,
    scheduler=scheduler
).to(DEVICE)
pipe.unet.eval()
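
# eval() disables training-time behaviour such as dropout; gradients are turned off
# per call via torch.no_grad() in predict() below.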

# UI texts
title = "# End-to-End Fine-Tuned GeoWizard Video"
description = (
    """
Please refer to our [paper](https://arxiv.org/abs/2409.11355) and
[project page](https://vision.rwth-aachen.de/diffusion-e2e-ft) for more details.
"""
)

# Allocate a GPU for this call when running on ZeroGPU hardware (uses the `spaces` import above)
@spaces.GPU
def predict(image: Image.Image, processing_res_choice: int):
    """
    Single-frame prediction wrapped for GPU execution.
    Returns a DepthNormalPipelineOutput with attribute normal_colored.
    """
    with torch.no_grad():
        return pipe(
            image,
            denoising_steps=1,
            ensemble_size=1,
            noise="zeros",
            processing_res=processing_res_choice,
            match_input_res=True
        )
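
# Single-image usage sketch (hypothetical file names; assumes normal_colored is a
# PIL image, as the frame loop below treats it):
#   result = predict(Image.open("input.jpg"), 768)
#   result.normal_colored.save("input_normal.png")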

def on_submit_video(video_path: str, processing_res_choice: int):
    """
    Processes each frame of the input video, generating a normal map video.
    """
    if video_path is None:
        print("No video uploaded.")
        return None

    cap = cv2.VideoCapture(video_path)
    fps = cap.get(cv2.CAP_PROP_FPS) or 30
    width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))

    # Temporary output file for normals video
    tmp_normal = tempfile.NamedTemporaryFile(suffix='.mp4', delete=False)
    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
    out_normal = cv2.VideoWriter(tmp_normal.name, fourcc, fps, (width, height))
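    # The writer expects BGR frames at the original (width, height); predict() is
    # called with match_input_res=True, so each normal map comes back at the input
    # frame size.
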
    # Process each frame
    for _ in tqdm(range(frame_count), desc="Processing frames"):
        ret, frame = cap.read()
        if not ret:
            break
        # Convert frame to PIL image
        rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        pil_image = Image.fromarray(rgb)
        # Predict normals
        result = predict(pil_image, processing_res_choice)
        normal_colored = result.normal_colored
        # Write normal frame
        normal_frame = np.array(normal_colored)
        normal_bgr = cv2.cvtColor(normal_frame, cv2.COLOR_RGB2BGR)
        out_normal.write(normal_bgr)

    # Release resources
    cap.release()
    out_normal.release()

    # Return video path for download
    return tmp_normal.name
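
# The file path returned by on_submit_video() is consumed by the gr.Video output
# component below, which renders the .mp4 and offers it for download.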

# Build Gradio interface
with gr.Blocks() as demo:
    gr.Markdown(title)
    gr.Markdown(description)
    gr.Markdown("### Normals Prediction on Video")

    with gr.Row():
        input_video = gr.Video(label="Input Video", elem_id='video-display-input')
        with gr.Column():
            processing_res_choice = gr.Radio(
                [
                    ("Recommended (768)", 768),
                    ("Native (original)", 0),
                ],
                label="Processing resolution",
                value=768,
            )
            submit = gr.Button(value="Compute Normals")

    with gr.Row():
        output_normal_video = gr.Video(label="Normal Video", elem_id='download')

    submit.click(
        fn=on_submit_video,
        inputs=[input_video, processing_res_choice],
        outputs=[output_normal_video]
    )

if __name__ == "__main__":
    demo.queue().launch(share=True)