import gradio as gr
import cv2
import numpy as np
import torch
import tempfile
from PIL import Image
import spaces
from tqdm.auto import tqdm

from diffusers import DDIMScheduler, AutoencoderKL
from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
from GeoWizard.geowizard.models.unet_2d_condition import UNet2DConditionModel
from GeoWizard.geowizard.models.geowizard_pipeline import DepthNormalEstimationPipeline

# Device setup
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
CHECKPOINT_PATH = "GonzaloMG/geowizard-e2e-ft"

# Load pretrained components
vae = AutoencoderKL.from_pretrained(CHECKPOINT_PATH, subfolder='vae')
scheduler = DDIMScheduler.from_pretrained(CHECKPOINT_PATH, timestep_spacing="trailing", subfolder='scheduler')
image_encoder = CLIPVisionModelWithProjection.from_pretrained(CHECKPOINT_PATH, subfolder="image_encoder")
feature_extractor = CLIPImageProcessor.from_pretrained(CHECKPOINT_PATH, subfolder="feature_extractor")
unet = UNet2DConditionModel.from_pretrained(CHECKPOINT_PATH, subfolder="unet")
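
# Optional: on CUDA, memory use can be reduced by passing
# torch_dtype=torch.float16 to the from_pretrained calls above
# (assumes the checkpoint weights behave well in half precision).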

# Instantiate pipeline
pipe = DepthNormalEstimationPipeline(
    vae=vae,
    image_encoder=image_encoder,
    feature_extractor=feature_extractor,
    unet=unet,
    scheduler=scheduler
).to(DEVICE)
pipe.unet.eval()
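
# Minimal single-image smoke test (a sketch, not part of the app;
# "example.jpg" is a placeholder path):
#
#   test_img = Image.open("example.jpg").convert("RGB")
#   with torch.no_grad():
#       out = pipe(test_img, denoising_steps=1, ensemble_size=1,
#                  noise="zeros", processing_res=768, match_input_res=True)
#   out.normal_colored.save("example_normal.png")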

# UI texts
title = "# End-to-End Fine-Tuned GeoWizard Video"
description = (
    """
    Please refer to our [paper](https://arxiv.org/abs/2409.11355) and
    [project page](https://vision.rwth-aachen.de/diffusion-e2e-ft) for more details.
    """
)

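# @spaces.GPU requests a GPU for each call when the app runs on Hugging Face
# ZeroGPU hardware; elsewhere the decorator is effectively a no-op.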
@spaces.GPU
def predict(image: Image.Image, processing_res_choice: int):
    """
    Single-frame prediction wrapped for GPU execution.
    Returns a DepthNormalPipelineOutput with attribute normal_colored.
    """
    with torch.no_grad():
        return pipe(
            image,
            denoising_steps=1,
            ensemble_size=1,
            noise="zeros",
            processing_res=processing_res_choice,
            match_input_res=True
        )


def on_submit_video(video_path: str, processing_res_choice: int):
    """
    Processes each frame of the input video, generating a normal map video.
    """
    if video_path is None:
        print("No video uploaded.")
        return None

    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        print(f"Could not open video: {video_path}")
        return None
    fps = cap.get(cv2.CAP_PROP_FPS) or 30  # fall back to 30 FPS if metadata is missing
    width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))

    # Temporary output file for the normals video; close the handle so
    # cv2.VideoWriter can (re)open the path itself
    tmp_normal = tempfile.NamedTemporaryFile(suffix='.mp4', delete=False)
    tmp_normal.close()
    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
    out_normal = cv2.VideoWriter(tmp_normal.name, fourcc, fps, (width, height))
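    # 'mp4v' is broadly supported by OpenCV builds; if inline browser playback
    # of the result matters, re-encoding to H.264 may be needed.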

    # Process each frame
    for _ in tqdm(range(frame_count), desc="Processing frames"):
        ret, frame = cap.read()
        if not ret:
            break

        # Convert frame to PIL image
        rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        pil_image = Image.fromarray(rgb)

        # Predict normals
        result = predict(pil_image, processing_res_choice)
        normal_colored = result.normal_colored

        # Write normal frame
        normal_frame = np.array(normal_colored)
        normal_bgr = cv2.cvtColor(normal_frame, cv2.COLOR_RGB2BGR)
        out_normal.write(normal_bgr)

    # Release resources
    cap.release()
    out_normal.release()

    # Return video path for download
    return tmp_normal.name
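
# Headless usage without the UI (a minimal sketch; "input.mp4" is a
# placeholder path):
#
#   out_path = on_submit_video("input.mp4", processing_res_choice=768)
#   print(f"Normals video written to {out_path}")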

# Build Gradio interface
with gr.Blocks() as demo:
    gr.Markdown(title)
    gr.Markdown(description)
    gr.Markdown("### Normals Prediction on Video")

    with gr.Row():
        input_video = gr.Video(label="Input Video", elem_id='video-display-input')
        with gr.Column():
            processing_res_choice = gr.Radio(
                [
                    ("Recommended (768)", 768),
                    ("Native (original)", 0),
                ],
                label="Processing resolution",
                value=768,
            )
            submit = gr.Button(value="Compute Normals")

    with gr.Row():
        output_normal_video = gr.Video(label="Normal Video", elem_id='download')

    submit.click(
        fn=on_submit_video,
        inputs=[input_video, processing_res_choice],
        outputs=[output_normal_video]
    )

if __name__ == "__main__":
    demo.queue().launch(share=True)