import gradio as gr
import cv2
import numpy as np
import torch
import tempfile
from PIL import Image
import spaces
from tqdm.auto import tqdm
from diffusers import DDIMScheduler, AutoencoderKL
from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
from GeoWizard.geowizard.models.unet_2d_condition import UNet2DConditionModel
from GeoWizard.geowizard.models.geowizard_pipeline import DepthNormalEstimationPipeline

# Device setup
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
CHECKPOINT_PATH = "GonzaloMG/geowizard-e2e-ft"

# Load pretrained components
vae = AutoencoderKL.from_pretrained(CHECKPOINT_PATH, subfolder='vae')
scheduler = DDIMScheduler.from_pretrained(CHECKPOINT_PATH, timestep_spacing="trailing", subfolder='scheduler')
image_encoder = CLIPVisionModelWithProjection.from_pretrained(CHECKPOINT_PATH, subfolder="image_encoder")
feature_extractor = CLIPImageProcessor.from_pretrained(CHECKPOINT_PATH, subfolder="feature_extractor")
unet = UNet2DConditionModel.from_pretrained(CHECKPOINT_PATH, subfolder="unet")

# Instantiate pipeline
pipe = DepthNormalEstimationPipeline(
    vae=vae,
    image_encoder=image_encoder,
    feature_extractor=feature_extractor,
    unet=unet,
    scheduler=scheduler
).to(DEVICE)
pipe.unet.eval()

# UI texts
title = "# End-to-End Fine-Tuned GeoWizard Video"
description = """
Please refer to our [paper](https://arxiv.org/abs/2409.11355) and [GitHub](https://vision.rwth-aachen.de/diffusion-e2e-ft) for more details.
"""


@spaces.GPU
def predict(image: Image.Image, processing_res_choice: int):
    """
    Single-frame prediction wrapped for GPU execution.
    Returns a DepthNormalPipelineOutput with attribute `normal_colored`.
    """
    with torch.no_grad():
        return pipe(
            image,
            denoising_steps=1,   # single-step, deterministic inference for the end-to-end fine-tuned model
            ensemble_size=1,
            noise="zeros",
            processing_res=processing_res_choice,
            match_input_res=True,
        )


def on_submit_video(video_path: str, processing_res_choice: int):
    """
    Processes each frame of the input video and generates a normal map video.
""" if video_path is None: print("No video uploaded.") return None cap = cv2.VideoCapture(video_path) fps = cap.get(cv2.CAP_PROP_FPS) or 30 width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)) height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)) frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) # Temporary output file for normals video tmp_normal = tempfile.NamedTemporaryFile(suffix='.mp4', delete=False) fourcc = cv2.VideoWriter_fourcc(*'mp4v') out_normal = cv2.VideoWriter(tmp_normal.name, fourcc, fps, (width, height)) # Process each frame for _ in tqdm(range(frame_count), desc="Processing frames"): ret, frame = cap.read() if not ret: break # Convert frame to PIL image rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) pil_image = Image.fromarray(rgb) # Predict normals result = predict(pil_image, processing_res_choice) normal_colored = result.normal_colored # Write normal frame normal_frame = np.array(normal_colored) normal_bgr = cv2.cvtColor(normal_frame, cv2.COLOR_RGB2BGR) out_normal.write(normal_bgr) # Release resources cap.release() out_normal.release() # Return video path for download return tmp_normal.name # Build Gradio interface with gr.Blocks() as demo: gr.Markdown(title) gr.Markdown(description) gr.Markdown("### Normals Prediction on Video") with gr.Row(): input_video = gr.Video(label="Input Video", elem_id='video-display-input') with gr.Column(): processing_res_choice = gr.Radio( [ ("Recommended (768)", 768), ("Native (original)", 0), ], label="Processing resolution", value=768, ) submit = gr.Button(value="Compute Normals") with gr.Row(): output_normal_video = gr.Video(label="Normal Video", elem_id='download') submit.click( fn=on_submit_video, inputs=[input_video, processing_res_choice], outputs=[output_normal_video] ) if __name__ == "__main__": demo.queue().launch(share=True)