import gradio as gr
import cv2
import numpy as np
import torch
import tempfile
from PIL import Image
import spaces
from tqdm.auto import tqdm
from diffusers import DDIMScheduler, AutoencoderKL
from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
from GeoWizard.geowizard.models.unet_2d_condition import UNet2DConditionModel
from GeoWizard.geowizard.models.geowizard_pipeline import DepthNormalEstimationPipeline
# Device setup
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
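# Hugging Face Hub repo id of the end-to-end fine-tuned GeoWizard checkpoint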
CHECKPOINT_PATH = "GonzaloMG/geowizard-e2e-ft"
# Load pretrained components
vae = AutoencoderKL.from_pretrained(CHECKPOINT_PATH, subfolder='vae')
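# "trailing" timestep spacing is what enables single-step inference with the
# end-to-end fine-tuned checkpoint (see the paper linked below)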
scheduler = DDIMScheduler.from_pretrained(CHECKPOINT_PATH, timestep_spacing="trailing", subfolder='scheduler')
image_encoder = CLIPVisionModelWithProjection.from_pretrained(CHECKPOINT_PATH, subfolder="image_encoder")
feature_extractor = CLIPImageProcessor.from_pretrained(CHECKPOINT_PATH, subfolder="feature_extractor")
unet = UNet2DConditionModel.from_pretrained(CHECKPOINT_PATH, subfolder="unet")
# Instantiate pipeline
pipe = DepthNormalEstimationPipeline(
    vae=vae,
    image_encoder=image_encoder,
    feature_extractor=feature_extractor,
    unet=unet,
    scheduler=scheduler
).to(DEVICE)
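# Inference only: put the UNet in eval mode (disables training-time behavior such as dropout)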
pipe.unet.eval()
# UI texts
title = "# End-to-End Fine-Tuned GeoWizard Video"
description = (
"""
Please refer to our [paper](https://arxiv.org/abs/2409.11355) and
[GitHub](https://vision.rwth-aachen.de/diffusion-e2e-ft) for more details.
"""
)
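# `spaces.GPU` allocates GPU time per call on ZeroGPU Spaces and is a no-op
# when running outside Hugging Face Spaces.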
@spaces.GPU
def predict(image: Image.Image, processing_res_choice: int):
"""
Single-frame prediction wrapped for GPU execution.
Returns a DepthNormalPipelineOutput with attribute normal_colored.
"""
with torch.no_grad():
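        # The end-to-end fine-tuned model is deterministic, so a single
        # denoising step with zero noise and no ensembling is sufficient.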
        return pipe(
            image,
            denoising_steps=1,
            ensemble_size=1,
            noise="zeros",
            processing_res=processing_res_choice,
            match_input_res=True
        )
def on_submit_video(video_path: str, processing_res_choice: int):
"""
Processes each frame of the input video, generating a normal map video.
"""
if video_path is None:
print("No video uploaded.")
return None
cap = cv2.VideoCapture(video_path)
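    # CAP_PROP_FPS can be 0 for some containers, so fall back to 30 fps.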
    fps = cap.get(cv2.CAP_PROP_FPS) or 30
    width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    # Temporary output file for normals video
    tmp_normal = tempfile.NamedTemporaryFile(suffix='.mp4', delete=False)
    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
    out_normal = cv2.VideoWriter(tmp_normal.name, fourcc, fps, (width, height))
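    # Note: 'mp4v' (MPEG-4 Part 2) may not play inline in every browser,
    # but the file can still be downloaded from the output component.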
    # Process each frame
    for _ in tqdm(range(frame_count), desc="Processing frames"):
        ret, frame = cap.read()
        if not ret:
            break
        # Convert the BGR frame to a PIL RGB image
        rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        pil_image = Image.fromarray(rgb)
        # Predict normals
        result = predict(pil_image, processing_res_choice)
        normal_colored = result.normal_colored
        # Write the normal frame back as BGR
        normal_frame = np.array(normal_colored)
        normal_bgr = cv2.cvtColor(normal_frame, cv2.COLOR_RGB2BGR)
        out_normal.write(normal_bgr)
    # Release resources
    cap.release()
    out_normal.release()
    # Return video path for download
    return tmp_normal.name
# Build Gradio interface
with gr.Blocks() as demo:
    gr.Markdown(title)
    gr.Markdown(description)
    gr.Markdown("### Normals Prediction on Video")
    with gr.Row():
        input_video = gr.Video(label="Input Video", elem_id='video-display-input')
        with gr.Column():
            processing_res_choice = gr.Radio(
                [
                    ("Recommended (768)", 768),
                    ("Native (original)", 0),
                ],
                label="Processing resolution",
                value=768,
            )
            submit = gr.Button(value="Compute Normals")
    with gr.Row():
        output_normal_video = gr.Video(label="Normal Video", elem_id='download')
    submit.click(
        fn=on_submit_video,
        inputs=[input_video, processing_res_choice],
        outputs=[output_normal_video]
    )
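# Gradio's queue processes requests sequentially, which suits long-running
# per-frame video inference.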
if __name__ == "__main__":
    demo.queue().launch(share=True)