import os
import sys
import math
import time
from typing import Union

# HF Spaces specific: `spaces` must be imported early for ZeroGPU
import spaces
import torch
import torch.nn as nn
import gradio as gr
import numpy as np
from PIL import Image
from omegaconf import OmegaConf
from pytorch_lightning import seed_everything
from huggingface_hub import hf_hub_download
from requests.exceptions import ReadTimeout, ConnectionError
from diffusers import DiffusionPipeline, EulerAncestralDiscreteScheduler
from einops import rearrange

from shap_e.diffusion.sample import sample_latents
from shap_e.diffusion.gaussian_diffusion import diffusion_from_config
from shap_e.models.download import load_model, load_config
from shap_e.models.transmitter.base import Transmitter, VectorDecoder
from shap_e.models.nn.camera import DifferentiableCameraBatch, DifferentiableProjectiveCamera
from shap_e.util.notebooks import create_pan_cameras
from shap_e.util.collections import AttrDict

from src.utils.train_util import instantiate_from_config
from src.utils.camera_util import (
    FOV_to_intrinsics,
    get_zero123plus_input_cameras,
    get_circular_camera_poses,
    spherical_camera_pose,
)
from src.utils.mesh_util import save_obj, save_glb
from src.utils.infer_util import remove_background, resize_foreground
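
# On Hugging Face Spaces with ZeroGPU, CUDA work must happen inside functions
# wrapped with the `spaces.GPU` decorator, which attaches a GPU for the duration
# of the call (a sketch of the convention; the optional duration argument is in
# seconds):
#
#     @spaces.GPU(duration=120)
#     def heavy_inference(...):
#         ...
#
# The generate/refine handlers in create_demo() below are wrapped this way.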

def decode_latent_images(
    xm: Union[Transmitter, VectorDecoder],
    latent: torch.Tensor,
    cameras: DifferentiableCameraBatch,
    rendering_mode: str = "stf",
    params=None,
    background_color: torch.Tensor = torch.tensor([255.0, 255.0, 255.0], dtype=torch.float32),
):
    """Render a Shap-E latent from the given cameras, composited onto a solid background."""
    params = params if params is not None else (xm.encoder if isinstance(xm, Transmitter) else xm).bottleneck_to_params(
        latent[None]
    )
    params = xm.renderer.update(params)
    decoded = xm.renderer.render_views(
        AttrDict(cameras=cameras),
        params=params,
        options=AttrDict(rendering_mode=rendering_mode, render_with_direction=False),
    )
    # Alpha-composite: transmittance weights the background, the rest is the rendered color
    bg_color = background_color.to(decoded.channels.device)
    images = bg_color * decoded.transmittance + (1 - decoded.transmittance) * decoded.channels
    return images
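
# Minimal usage sketch for decode_latent_images (assumes a latent from
# sample_latents and the `transmitter` model; `create_pan_cameras` is the stock
# shap_e helper imported above):
#
#     device = torch.device('cuda')
#     xm = load_model('transmitter', device=device)
#     cameras = create_pan_cameras(64, device)
#     views = decode_latent_images(xm, latents[0], cameras, rendering_mode='stf')
#
# This app builds its own cameras with create_custom_cameras() below instead.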

def create_custom_cameras(size: int, device: torch.device, azimuths: list, elevations: list,
                          fov_degrees: float, distance: float) -> DifferentiableCameraBatch:
    """Build a DifferentiableCameraBatch looking at the origin from the given azimuth/elevation pairs."""
    # The object lives in a 2x2x2 bounding box (-1 to 1 in each dimension);
    # treat `distance` as the extent the view must cover
    object_diagonal = distance
    # Choose the camera radius so that `object_diagonal` fills the field of view
    fov_radians = math.radians(fov_degrees)
    radius = (object_diagonal / 2) / math.tan(fov_radians / 2)
    origins = []
    xs = []
    ys = []
    zs = []
    for azimuth, elevation in zip(azimuths, elevations):
        # The -90 degree offset shifts azimuth 0 onto the +y axis
        azimuth_rad = np.radians(azimuth - 90)
        elevation_rad = np.radians(elevation)
        # Camera position on a sphere of the computed radius
        x = radius * np.cos(elevation_rad) * np.cos(azimuth_rad)
        y = radius * np.cos(elevation_rad) * np.sin(azimuth_rad)
        z = radius * np.sin(elevation_rad)
        origin = np.array([x, y, z])
        # Camera orientation: z points at the origin, x stays horizontal, y completes the frame
        z_axis = -origin / np.linalg.norm(origin)
        x_axis = np.array([-np.sin(azimuth_rad), np.cos(azimuth_rad), 0])
        y_axis = np.cross(z_axis, x_axis)
        origins.append(origin)
        zs.append(z_axis)
        xs.append(x_axis)
        ys.append(y_axis)
    return DifferentiableCameraBatch(
        shape=(1, len(origins)),
        flat_camera=DifferentiableProjectiveCamera(
            origin=torch.from_numpy(np.stack(origins, axis=0)).float().to(device),
            x=torch.from_numpy(np.stack(xs, axis=0)).float().to(device),
            y=torch.from_numpy(np.stack(ys, axis=0)).float().to(device),
            z=torch.from_numpy(np.stack(zs, axis=0)).float().to(device),
            width=size,
            height=size,
            x_fov=fov_radians,
            y_fov=fov_radians,
        ),
    )
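
# Worked example of the radius formula above: ShapERenderer.generate_views calls
# this with fov_degrees=30 and distance=3.0, giving
# radius = (3.0 / 2) / tan(15 deg) ~= 1.5 / 0.268 ~= 5.6 units from the origin.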

def load_models():
    """Initialize and load all required models."""
    config = OmegaConf.load('configs/instant-nerf-large-best.yaml')
    model_config = config.model_config
    infer_config = config.infer_config
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Load diffusion pipeline with retry logic
    print('Loading diffusion pipeline...')
    max_retries = 3
    retry_delay = 5
    for attempt in range(max_retries):
        try:
            pipeline = DiffusionPipeline.from_pretrained(
                "sudo-ai/zero123plus-v1.2",
                custom_pipeline="zero123plus",
                torch_dtype=torch.float16,
                local_files_only=False,
                resume_download=True,
            )
            break
        except (ReadTimeout, ConnectionError) as e:
            if attempt == max_retries - 1:
                raise RuntimeError(f"Failed to download pipeline after {max_retries} attempts: {str(e)}") from e
            print(f"Download attempt {attempt + 1} failed, retrying in {retry_delay} seconds...")
            time.sleep(retry_delay)
            retry_delay *= 2  # Exponential backoff

    pipeline.scheduler = EulerAncestralDiscreteScheduler.from_config(
        pipeline.scheduler.config, timestep_spacing='trailing'
    )

    # Modify the UNet to accept 8 input channels instead of 4: the first 4 keep
    # the pretrained weights, the extra 4 start at zero
    in_channels = 8
    out_channels = pipeline.unet.conv_in.out_channels
    pipeline.unet.register_to_config(in_channels=in_channels)
    with torch.no_grad():
        new_conv_in = nn.Conv2d(
            in_channels, out_channels, pipeline.unet.conv_in.kernel_size,
            pipeline.unet.conv_in.stride, pipeline.unet.conv_in.padding
        )
        new_conv_in.weight.zero_()
        new_conv_in.weight[:, :4, :, :].copy_(pipeline.unet.conv_in.weight)
        pipeline.unet.conv_in = new_conv_in

    # Load custom UNet with retry logic
    print('Loading custom UNet...')
    for attempt in range(max_retries):
        try:
            pipeline.unet = pipeline.unet.from_pretrained(
                "YiftachEde/Sharp-It",
                local_files_only=False,
                resume_download=True,
            ).to(torch.float16)
            break
        except (ReadTimeout, ConnectionError) as e:
            if attempt == max_retries - 1:
                raise RuntimeError(f"Failed to download UNet after {max_retries} attempts: {str(e)}") from e
            print(f"Download attempt {attempt + 1} failed, retrying in {retry_delay} seconds...")
            time.sleep(retry_delay)
            retry_delay *= 2

    pipeline = pipeline.to(device, torch.float16)

    # Load reconstruction model with retry logic
    print('Loading reconstruction model...')
    model = instantiate_from_config(model_config)
    for attempt in range(max_retries):
        try:
            model_path = hf_hub_download(
                repo_id="TencentARC/InstantMesh",
                filename="instant_nerf_large.ckpt",
                repo_type="model",
                local_files_only=False,
                resume_download=True,
                cache_dir="model_cache",  # Use a dedicated cache directory
            )
            break
        except (ReadTimeout, ConnectionError) as e:
            if attempt == max_retries - 1:
                raise RuntimeError(f"Failed to download model after {max_retries} attempts: {str(e)}") from e
            print(f"Download attempt {attempt + 1} failed, retrying in {retry_delay} seconds...")
            time.sleep(retry_delay)
            retry_delay *= 2

    # Strip the 'lrm_generator.' prefix (14 characters) and drop source-camera buffers
    state_dict = torch.load(model_path, map_location='cpu')['state_dict']
    state_dict = {k[14:]: v for k, v in state_dict.items()
                  if k.startswith('lrm_generator.') and 'source_camera' not in k}
    model.load_state_dict(state_dict, strict=True)
    model = model.to(device)
    model.eval()

    return pipeline, model, infer_config
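
# Usage sketch (illustrative file/prompt values) tying the three return values
# together; RefinerInterface below wraps the same flow:
#
#     pipeline, model, infer_config = load_models()
#     refined, layout = process_images([uploaded_file], "a red chair", pipeline=pipeline)
#     vertices, faces, vertex_colors = create_mesh(refined, model, infer_config)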

def process_images(input_images, prompt, steps=75, guidance_scale=7.5, pipeline=None):
    """Process input images and run refinement. `pipeline` is required despite the default."""
    device = pipeline.device

    def views_to_layout(views: np.ndarray) -> Image.Image:
        """Tile six 320x320 views, a (6, 320, 320, 3) float array in [0, 1],
        into a 640-wide x 960-tall layout (3 rows x 2 columns)."""
        images = torch.from_numpy(views).float()
        images = images.permute(0, 3, 1, 2)          # (view, channel, H, W)
        images = images.reshape(3, 2, 3, 320, 320)   # (row, col, channel, H, W)
        images = images.permute(0, 2, 3, 1, 4)       # (row, channel, H, col, W)
        images = images.reshape(3, 3, 320, 640)      # each row's two views side by side
        images = images.permute(1, 0, 2, 3)          # (channel, row, H, W)
        images = images.reshape(3, 960, 640)         # stack the three rows vertically
        images = images.permute(1, 2, 0)             # (960, 640, channel)
        return Image.fromarray((images.numpy() * 255).astype(np.uint8))

    if isinstance(input_images, list):
        if len(input_images) == 1:
            img = Image.open(input_images[0].name).convert('RGB')
            if img.size == (640, 960):
                # Already a pre-arranged layout, use it directly
                input_image = img
            else:
                # Single view: replicate it into all six slots
                img = img.resize((320, 320))
                img_array = np.array(img) / 255.0
                input_image = views_to_layout(np.stack([img_array] * 6))
        else:
            # Multiple individual views
            images = []
            for img_file in input_images:
                img = Image.open(img_file.name).convert('RGB')
                img = img.resize((320, 320))
                images.append(np.array(img) / 255.0)
            # Pad with blank views up to six; ignore any extras
            while len(images) < 6:
                images.append(np.zeros_like(images[0]))
            input_image = views_to_layout(np.stack(images[:6]))
    else:
        raise ValueError("Expected a list of images")

    # Generate refined output
    output = pipeline.refine(
        input_image,
        prompt=prompt,
        num_inference_steps=int(steps),
        guidance_scale=guidance_scale
    ).images[0]
    return output, input_image
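
# Layout convention used throughout this app: six 320x320 views tiled into a
# 640-wide x 960-tall image as 3 rows x 2 columns, in view order
#
#     [0 1]
#     [2 3]
#     [4 5]
#
# create_mesh() below inverts this tiling to recover the individual views.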

def create_mesh(refined_image, model, infer_config):
    """Generate a mesh from a refined 640x960 layout image."""
    # Convert the PIL image to a (3, 960, 640) float tensor
    image = np.array(refined_image) / 255.0
    image = torch.from_numpy(image).float().permute(2, 0, 1)

    # Undo the 3x2 tiling back into 6 separate 320x320 views
    image = image.reshape(3, 960, 640)
    image = image.reshape(3, 3, 320, 640)       # (channel, row, H, W): split the 3 rows
    image = image.permute(1, 0, 2, 3)           # (row, channel, H, W)
    image = image.reshape(3, 3, 320, 2, 320)    # (row, channel, H, col, W): split the 2 columns
    image = image.permute(0, 3, 1, 2, 4)        # (row, col, channel, H, W)
    image = image.reshape(6, 3, 320, 320)       # six views in row-major order

    # Add batch dimension
    image = image.unsqueeze(0)

    input_cameras = get_zero123plus_input_cameras(batch_size=1, radius=4.0).to("cuda")
    image = image.to("cuda")

    with torch.no_grad():
        planes = model.forward_planes(image, input_cameras)
        mesh_out = model.extract_mesh(planes, **infer_config)
        vertices, faces, vertex_colors = mesh_out

    return vertices, faces, vertex_colors
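
# Usage sketch: the mesh tuple feeds save_obj / save_glb from src.utils.mesh_util,
# exactly as RefinerInterface.refine_model does below:
#
#     vertices, faces, vertex_colors = create_mesh(refined_image, model, infer_config)
#     save_obj(vertices, faces, vertex_colors, "temp/refined_mesh.obj")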

class ShapERenderer:
    def __init__(self, device):
        print("Initializing Shap-E models...")
        self.device = device
        torch.cuda.empty_cache()  # Clear GPU memory before loading
        self.xm = load_model('transmitter', device=self.device)
        self.model = load_model('text300M', device=self.device)
        self.diffusion = diffusion_from_config(load_config('diffusion'))
        print("Shap-E models initialized!")

    def generate_views(self, prompt, guidance_scale=15.0, num_steps=64):
        try:
            torch.cuda.empty_cache()  # Clear GPU memory before generation
            # Generate latents using the text-to-3D model
            batch_size = 1
            guidance_scale = float(guidance_scale)
            with torch.amp.autocast('cuda'):  # Use automatic mixed precision
                # Generate latents directly, without a nested spaces.GPU context
                latents = sample_latents(
                    batch_size=batch_size,
                    model=self.model,
                    diffusion=self.diffusion,
                    guidance_scale=guidance_scale,
                    model_kwargs=dict(texts=[prompt] * batch_size),
                    progress=True,
                    clip_denoised=True,
                    use_fp16=True,
                    use_karras=True,
                    karras_steps=num_steps,
                    sigma_min=1e-3,
                    sigma_max=160,
                    s_churn=0,
                )

            # Render the 6 views we need, matching the camera angles used in refine.py
            size = 320  # Size of each rendered image
            images = []
            azimuths = [30, 90, 150, 210, 270, 330]
            elevations = [20, -10, 20, -10, 20, -10]
            for azimuth, elevation in zip(azimuths, elevations):
                cameras = create_custom_cameras(size, self.device, azimuths=[azimuth], elevations=[elevation], fov_degrees=30, distance=3.0)
                with torch.amp.autocast('cuda'):  # Use automatic mixed precision
                    rendered_image = decode_latent_images(
                        self.xm,
                        latents[0],
                        cameras=cameras,
                        rendering_mode='stf'
                    )
                images.append(rendered_image[0])
                torch.cuda.empty_cache()  # Clear GPU memory after each view

            # Move to CPU and convert to uint8 (decoded channels are on the 0-255
            # scale, matching the white background of 255); the reshape drops any
            # leading singleton batch dims
            images = [image.clamp(0, 255).to(torch.uint8).cpu().numpy().reshape(size, size, 3)
                      for image in images]

            # Tile into a 3x2 grid (3 rows x 2 columns) of 320px views: a 640x960 image
            layout = np.zeros((960, 640, 3), dtype=np.uint8)
            for i, img in enumerate(images):
                row = i // 2
                col = i % 2
                layout[row*320:(row+1)*320, col*320:(col+1)*320] = img
            return Image.fromarray(layout), images
        except Exception as e:
            print(f"Error in generate_views: {e}")
            torch.cuda.empty_cache()  # Clear GPU memory on error
            raise

class RefinerInterface:
    def __init__(self):
        print("Initializing InstantMesh models...")
        torch.cuda.empty_cache()  # Clear GPU memory before loading
        self.pipeline, self.model, self.infer_config = load_models()
        print("InstantMesh models initialized!")

    def refine_model(self, input_image, prompt, steps=75, guidance_scale=7.5):
        """Main refinement function."""
        try:
            torch.cuda.empty_cache()  # Clear GPU memory before processing
            input_image = Image.fromarray(input_image)

            # The pipeline expects a 640-wide x 960-tall layout (3 rows x 2 columns).
            # If we received a 960-wide x 640-tall layout (2 rows x 3 columns), retile it.
            if input_image.width == 960 and input_image.height == 640:
                input_array = np.array(input_image)
                new_layout = np.zeros((960, 640, 3), dtype=np.uint8)
                for i in range(6):
                    src_row = i // 3
                    src_col = i % 3
                    dst_row = i // 2
                    dst_col = i % 2
                    new_layout[dst_row*320:(dst_row+1)*320, dst_col*320:(dst_col+1)*320] = \
                        input_array[src_row*320:(src_row+1)*320, src_col*320:(src_col+1)*320]
                input_image = Image.fromarray(new_layout)

            # Run the refinement pipeline
            with torch.amp.autocast('cuda'):  # Use automatic mixed precision
                refined_output = self.pipeline.refine(
                    input_image,
                    prompt=prompt,
                    num_inference_steps=int(steps),
                    guidance_scale=guidance_scale
                ).images[0]
            torch.cuda.empty_cache()  # Clear GPU memory after refinement

            # Generate the mesh from the refined layout
            with torch.amp.autocast('cuda'):  # Use automatic mixed precision
                vertices, faces, vertex_colors = create_mesh(
                    refined_output,
                    self.model,
                    self.infer_config
                )
            torch.cuda.empty_cache()  # Clear GPU memory after mesh generation

            # Save temporary mesh file
            os.makedirs("temp", exist_ok=True)
            temp_obj = os.path.join("temp", "refined_mesh.obj")
            save_obj(vertices, faces, vertex_colors, temp_obj)

            # The refined output is already in the 640x960 display layout
            return refined_output, temp_obj
        except Exception as e:
            print(f"Error in refine_model: {e}")
            torch.cuda.empty_cache()  # Clear GPU memory on error
            raise

def create_demo():
    print("Initializing models...")
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Initialize models at startup
    shap_e = ShapERenderer(device)
    refiner = RefinerInterface()
    print("All models initialized!")

    with gr.Blocks() as demo:
        gr.Markdown("# Shap-E to InstantMesh Pipeline")

        # First row: controls
        with gr.Row():
            with gr.Column():
                # Shap-E inputs
                shape_prompt = gr.Textbox(
                    label="Shap-E Prompt",
                    placeholder="Enter text to generate initial 3D model..."
                )
                shape_guidance = gr.Slider(
                    minimum=1,
                    maximum=30,
                    value=15.0,
                    label="Shap-E Guidance Scale"
                )
                shape_steps = gr.Slider(
                    minimum=16,
                    maximum=128,
                    value=64,
                    step=16,
                    label="Shap-E Steps"
                )
                generate_btn = gr.Button("Generate Views")

            with gr.Column():
                # Refinement inputs
                refine_prompt = gr.Textbox(
                    label="Refinement Prompt",
                    placeholder="Enter prompt to guide refinement..."
                )
                refine_steps = gr.Slider(
                    minimum=30,
                    maximum=100,
                    value=75,
                    step=1,
                    label="Refinement Steps"
                )
                refine_guidance = gr.Slider(
                    minimum=1,
                    maximum=20,
                    value=7.5,
                    label="Refinement Guidance Scale"
                )
                refine_btn = gr.Button("Refine")
                error_output = gr.Textbox(label="Status/Error Messages", interactive=False)

        # Second row: image panels side by side
        with gr.Row():
            shape_output = gr.Image(
                label="Generated Views",
                width=640,
                height=960
            )
            refined_output = gr.Image(
                label="Refined Output",
                width=640,
                height=960
            )

        # Third row: 3D mesh viewer
        with gr.Row():
            mesh_output = gr.Model3D(
                label="3D Mesh",
                clear_color=[1.0, 1.0, 1.0, 1.0],
            )
        # Set up event handlers; the spaces.GPU decorator attaches a ZeroGPU
        # device for the duration of each call
        @spaces.GPU
        def generate(prompt, guidance_scale, num_steps):
            try:
                torch.cuda.empty_cache()  # Clear GPU memory before starting
                with torch.no_grad():
                    layout, _ = shap_e.generate_views(prompt, guidance_scale, num_steps)
                    return layout, None  # No error message
            except Exception as e:
                torch.cuda.empty_cache()  # Clear GPU memory on error
                error_msg = f"Error during generation: {str(e)}"
                print(error_msg)
                return None, error_msg

        @spaces.GPU
        def refine(input_image, prompt, steps, guidance_scale):
            try:
                torch.cuda.empty_cache()  # Clear GPU memory before starting
                refined_img, mesh_path = refiner.refine_model(
                    input_image,
                    prompt,
                    steps,
                    guidance_scale
                )
                return refined_img, mesh_path, None  # No error message
            except Exception as e:
                torch.cuda.empty_cache()  # Clear GPU memory on error
                error_msg = f"Error during refinement: {str(e)}"
                print(error_msg)
                return None, None, error_msg

        generate_btn.click(
            fn=generate,
            inputs=[shape_prompt, shape_guidance, shape_steps],
            outputs=[shape_output, error_output]
        )
        refine_btn.click(
            fn=refine,
            inputs=[shape_output, refine_prompt, refine_steps, refine_guidance],
            outputs=[refined_output, mesh_output, error_output]
        )

    return demo


if __name__ == "__main__":
    demo = create_demo()
    demo.launch(share=True)