# x10z's picture
# Update app.py
# 0200003 verified
import gradio as gr
import cv2
import numpy as np
import torch
import tempfile
from PIL import Image
import spaces
from tqdm.auto import tqdm
from diffusers import DDIMScheduler, AutoencoderKL
from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
from GeoWizard.geowizard.models.unet_2d_condition import UNet2DConditionModel
from GeoWizard.geowizard.models.geowizard_pipeline import DepthNormalEstimationPipeline
# Use the GPU when torch can see one, otherwise fall back to the CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# Checkpoint location shared by every from_pretrained call below; each
# component is loaded from its own subfolder of it.
CHECKPOINT_PATH = "GonzaloMG/geowizard-e2e-ft"

# --- Pretrained components -------------------------------------------------
vae = AutoencoderKL.from_pretrained(CHECKPOINT_PATH, subfolder="vae")
# The DDIM scheduler is loaded with "trailing" timestep spacing requested.
scheduler = DDIMScheduler.from_pretrained(
    CHECKPOINT_PATH,
    timestep_spacing="trailing",
    subfolder="scheduler",
)
image_encoder = CLIPVisionModelWithProjection.from_pretrained(
    CHECKPOINT_PATH,
    subfolder="image_encoder",
)
feature_extractor = CLIPImageProcessor.from_pretrained(
    CHECKPOINT_PATH,
    subfolder="feature_extractor",
)
unet = UNet2DConditionModel.from_pretrained(CHECKPOINT_PATH, subfolder="unet")

# --- Pipeline assembly -----------------------------------------------------
# Wire the components into the depth/normal estimation pipeline, move it to
# DEVICE and put the UNet in eval mode for inference.
pipe = DepthNormalEstimationPipeline(
    vae=vae,
    image_encoder=image_encoder,
    feature_extractor=feature_extractor,
    unet=unet,
    scheduler=scheduler,
)
pipe = pipe.to(DEVICE)
pipe.unet.eval()
# Markdown rendered at the top of the demo page.
title = "# End-to-End Fine-Tuned GeoWizard Video"
description = (
    "\n"
    "Please refer to our [paper](https://arxiv.org/abs/2409.11355) and\n"
    "[GitHub](https://vision.rwth-aachen.de/diffusion-e2e-ft) for more details.\n"
)
@spaces.GPU
def predict(image: Image.Image, processing_res_choice: int):
    """
    Run the estimation pipeline on one frame inside a ``spaces.GPU`` call.

    The pipeline is invoked with a single denoising step, an ensemble of one
    and ``noise="zeros"``, and with ``match_input_res=True`` (presumably so
    the output matches the input image size — confirm in the pipeline).

    Args:
        image: The frame to process.
        processing_res_choice: Forwarded unchanged as ``processing_res``.

    Returns:
        The pipeline output object; callers read its ``normal_colored``
        attribute.
    """
    pipeline_kwargs = {
        "denoising_steps": 1,
        "ensemble_size": 1,
        "noise": "zeros",
        "processing_res": processing_res_choice,
        "match_input_res": True,
    }
    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        output = pipe(image, **pipeline_kwargs)
    return output
def _normal_frame_bgr(frame, processing_res_choice, width, height):
    """Predict the colored normal map for one BGR frame and return it as a BGR frame of size (width, height)."""
    rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
    result = predict(Image.fromarray(rgb), processing_res_choice)
    normal_frame = np.array(result.normal_colored)
    normal_bgr = cv2.cvtColor(normal_frame, cv2.COLOR_RGB2BGR)
    # cv2.VideoWriter does not write frames whose size differs from the size it
    # was opened with, so force every frame to the writer's size.
    if normal_bgr.shape[:2] != (height, width):
        normal_bgr = cv2.resize(normal_bgr, (width, height), interpolation=cv2.INTER_LINEAR)
    return normal_bgr


def on_submit_video(video_path: str, processing_res_choice: int):
    """
    Process each frame of the input video, generating a normal map video.

    Every decoded frame is passed through ``predict``. Its colored normal map
    is written to a temporary ``.mp4`` file at the source frame rate and
    frame size.

    Args:
        video_path: Path of the uploaded video, or ``None`` if nothing was
            uploaded.
        processing_res_choice: Processing resolution forwarded to ``predict``
            (the UI offers 768, or 0 for the native resolution).

    Returns:
        Path of the generated normals video. Returns ``None`` when no video
        was given or the video could not be opened.
    """
    if video_path is None:
        print("No video uploaded.")
        return None

    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        cap.release()
        print(f"Could not open video: {video_path}")
        return None

    out_normal = None
    try:
        # Some containers report 0 FPS; fall back to 30 so the writer is valid.
        fps = cap.get(cv2.CAP_PROP_FPS) or 30
        width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        # CAP_PROP_FRAME_COUNT is only an estimate for many formats (it may be
        # 0, -1, or too small). It only sizes the progress bar; frames are
        # read until the stream actually ends.
        frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))

        # Reserve a persistent output path, then close our handle so that the
        # VideoWriter is the only open handle on the file.
        with tempfile.NamedTemporaryFile(suffix='.mp4', delete=False) as tmp_normal:
            output_path = tmp_normal.name
        fourcc = cv2.VideoWriter_fourcc(*'mp4v')
        out_normal = cv2.VideoWriter(output_path, fourcc, fps, (width, height))

        total = max(frame_count, 0) or None
        with tqdm(total=total, desc="Processing frames") as progress:
            while True:
                ret, frame = cap.read()
                if not ret:
                    break
                out_normal.write(_normal_frame_bgr(frame, processing_res_choice, width, height))
                progress.update(1)
    finally:
        # Release both handles even if prediction raises mid-video.
        cap.release()
        if out_normal is not None:
            out_normal.release()

    # Return video path for download
    return output_path
# Gradio UI: header text, then a row with the input video next to a column
# holding the resolution selector and the submit button, then the output row.
with gr.Blocks() as demo:
    gr.Markdown(title)
    gr.Markdown(description)
    gr.Markdown("### Normals Prediction on Video")

    with gr.Row():
        input_video = gr.Video(elem_id="video-display-input", label="Input Video")
        with gr.Column():
            # Each choice is a (display label, value passed to the handler) pair.
            processing_res_choice = gr.Radio(
                choices=[("Recommended (768)", 768), ("Native (original)", 0)],
                value=768,
                label="Processing resolution",
            )
            submit = gr.Button(value="Compute Normals")

    with gr.Row():
        output_normal_video = gr.Video(elem_id="download", label="Normal Video")

    # Clicking the button runs on_submit_video on the selected inputs.
    submit.click(
        inputs=[input_video, processing_res_choice],
        outputs=[output_normal_video],
        fn=on_submit_video,
    )
if __name__ == "__main__":
    # Enable the request queue, then launch the app with share=True.
    queued_demo = demo.queue()
    queued_demo.launch(share=True)