import os
import sys
import math
import time
from typing import Union

# HF Spaces-specific import for ZeroGPU support
import spaces

import torch
import torch.nn as nn
import gradio as gr
import numpy as np
from PIL import Image
from omegaconf import OmegaConf
from pytorch_lightning import seed_everything
from huggingface_hub import hf_hub_download
from diffusers import DiffusionPipeline, EulerAncestralDiscreteScheduler
from einops import rearrange
from requests.exceptions import ReadTimeout, ConnectionError

from shap_e.models.transmitter.base import Transmitter, VectorDecoder
from shap_e.diffusion.sample import sample_latents
from shap_e.diffusion.gaussian_diffusion import diffusion_from_config
from shap_e.models.download import load_model, load_config
from shap_e.util.notebooks import create_pan_cameras
from shap_e.models.nn.camera import DifferentiableCameraBatch, DifferentiableProjectiveCamera
from shap_e.util.collections import AttrDict

from src.utils.train_util import instantiate_from_config
from src.utils.camera_util import (
    FOV_to_intrinsics,
    get_zero123plus_input_cameras,
    get_circular_camera_poses,
    spherical_camera_pose,
)
from src.utils.mesh_util import save_obj, save_glb
from src.utils.infer_util import remove_background, resize_foreground


def decode_latent_images(
    xm: Union[Transmitter, VectorDecoder],
    latent: torch.Tensor,
    cameras: DifferentiableCameraBatch,
    rendering_mode: str = "stf",
    params=None,
    background_color: torch.Tensor = torch.tensor([255.0, 255.0, 255.0], dtype=torch.float32),
):
    params = params if params is not None else (
        xm.encoder if isinstance(xm, Transmitter) else xm
    ).bottleneck_to_params(latent[None])
    params = xm.renderer.update(params)
    decoded = xm.renderer.render_views(
        AttrDict(cameras=cameras),
        params=params,
        options=AttrDict(rendering_mode=rendering_mode, render_with_direction=False),
    )
    # Composite the rendered colors over the background using the transmittance
    bg_color = background_color.to(decoded.channels.device)
    images = bg_color * decoded.transmittance + (1 - decoded.transmittance) * decoded.channels
    # arr = decoded.channels.clamp(0, 255).to(torch.uint8)[0].cpu().numpy()
    return images


def create_custom_cameras(
    size: int,
    device: torch.device,
    azimuths: list,
    elevations: list,
    fov_degrees: float,
    distance: float,
) -> DifferentiableCameraBatch:
    # The object lives in a 2x2x2 bounding box (-1 to 1 in each dimension);
    # treat `distance` as the diagonal extent the camera should frame
    object_diagonal = distance

    # Choose the camera radius so the object fills the field of view
    fov_radians = math.radians(fov_degrees)
    radius = (object_diagonal / 2) / math.tan(fov_radians / 2)

    origins, xs, ys, zs = [], [], [], []
    for azimuth, elevation in zip(azimuths, elevations):
        azimuth_rad = np.radians(azimuth - 90)
        elevation_rad = np.radians(elevation)

        # Camera position on a sphere of the computed radius
        x = radius * np.cos(elevation_rad) * np.cos(azimuth_rad)
        y = radius * np.cos(elevation_rad) * np.sin(azimuth_rad)
        z = radius * np.sin(elevation_rad)
        origin = np.array([x, y, z])

        # Camera orientation: look at the center of the scene
        z_axis = -origin / np.linalg.norm(origin)
        x_axis = np.array([-np.sin(azimuth_rad), np.cos(azimuth_rad), 0])
        y_axis = np.cross(z_axis, x_axis)

        origins.append(origin)
        zs.append(z_axis)
        xs.append(x_axis)
        ys.append(y_axis)

    return DifferentiableCameraBatch(
        shape=(1, len(origins)),
        flat_camera=DifferentiableProjectiveCamera(
            origin=torch.from_numpy(np.stack(origins, axis=0)).float().to(device),
            x=torch.from_numpy(np.stack(xs, axis=0)).float().to(device),
            y=torch.from_numpy(np.stack(ys, axis=0)).float().to(device),
            z=torch.from_numpy(np.stack(zs, axis=0)).float().to(device),
            width=size,
            height=size,
            x_fov=fov_radians,
            y_fov=fov_radians,
        ),
    )
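

# Illustrative sketch (not called by the app): how the two helpers above fit
# together. Given a Shap-E latent, build one custom camera and decode a single
# view. The function name and argument defaults are ours; the angles mirror the
# front-facing camera used in ShapERenderer.generate_views below.
def _render_single_view_example(xm, latent: torch.Tensor, device: torch.device) -> torch.Tensor:
    cameras = create_custom_cameras(
        size=320, device=device,
        azimuths=[30], elevations=[20],
        fov_degrees=30, distance=3.0,
    )
    # Returns a float tensor in [0, 255] of shape (1, n_views, H, W, 3)
    return decode_latent_images(xm, latent, cameras=cameras, rendering_mode='stf')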
def load_models():
    """Initialize and load all required models."""
    config = OmegaConf.load('configs/instant-nerf-large-best.yaml')
    model_config = config.model_config
    infer_config = config.infer_config

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Load diffusion pipeline with retry logic
    print('Loading diffusion pipeline...')
    max_retries = 3
    retry_delay = 5
    for attempt in range(max_retries):
        try:
            pipeline = DiffusionPipeline.from_pretrained(
                "sudo-ai/zero123plus-v1.2",
                custom_pipeline="zero123plus",
                torch_dtype=torch.float16,
                local_files_only=False,
                resume_download=True,
            )
            break
        except (ReadTimeout, ConnectionError) as e:
            if attempt == max_retries - 1:
                raise Exception(f"Failed to download pipeline after {max_retries} attempts: {str(e)}")
            print(f"Download attempt {attempt + 1} failed, retrying in {retry_delay} seconds...")
            time.sleep(retry_delay)
            retry_delay *= 2  # Exponential backoff

    pipeline.scheduler = EulerAncestralDiscreteScheduler.from_config(
        pipeline.scheduler.config,
        timestep_spacing='trailing'
    )

    # Modify the UNet to accept 8 input channels instead of 4
    in_channels = 8
    out_channels = pipeline.unet.conv_in.out_channels
    pipeline.unet.register_to_config(in_channels=in_channels)
    with torch.no_grad():
        new_conv_in = nn.Conv2d(
            in_channels, out_channels,
            pipeline.unet.conv_in.kernel_size,
            pipeline.unet.conv_in.stride,
            pipeline.unet.conv_in.padding
        )
        new_conv_in.weight.zero_()
        # Copy the pretrained weights into the first 4 channels; the extra
        # channels stay zero so the widened layer initially matches the original
        new_conv_in.weight[:, :4, :, :].copy_(pipeline.unet.conv_in.weight)
        if pipeline.unet.conv_in.bias is not None:
            new_conv_in.bias.copy_(pipeline.unet.conv_in.bias)
    pipeline.unet.conv_in = new_conv_in

    # Load custom UNet with retry logic
    print('Loading custom UNet...')
    for attempt in range(max_retries):
        try:
            pipeline.unet = pipeline.unet.from_pretrained(
                "YiftachEde/Sharp-It",
                local_files_only=False,
                resume_download=True,
            ).to(torch.float16)
            break
        except (ReadTimeout, ConnectionError) as e:
            if attempt == max_retries - 1:
                raise Exception(f"Failed to download UNet after {max_retries} attempts: {str(e)}")
            print(f"Download attempt {attempt + 1} failed, retrying in {retry_delay} seconds...")
            time.sleep(retry_delay)
            retry_delay *= 2

    pipeline = pipeline.to(device, torch.float16)

    # Load reconstruction model with retry logic
    print('Loading reconstruction model...')
    model = instantiate_from_config(model_config)
    for attempt in range(max_retries):
        try:
            model_path = hf_hub_download(
                repo_id="TencentARC/InstantMesh",
                filename="instant_nerf_large.ckpt",
                repo_type="model",
                local_files_only=False,
                resume_download=True,
                cache_dir="model_cache"  # Use a dedicated cache directory
            )
            break
        except (ReadTimeout, ConnectionError) as e:
            if attempt == max_retries - 1:
                raise Exception(f"Failed to download model after {max_retries} attempts: {str(e)}")
            print(f"Download attempt {attempt + 1} failed, retrying in {retry_delay} seconds...")
            time.sleep(retry_delay)
            retry_delay *= 2

    state_dict = torch.load(model_path, map_location='cpu')['state_dict']
    state_dict = {
        k[14:]: v
        for k, v in state_dict.items()
        if k.startswith('lrm_generator.') and 'source_camera' not in k
    }
    model.load_state_dict(state_dict, strict=True)
    model = model.to(device)
    model.eval()

    return pipeline, model, infer_config
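

# Illustrative sketch (not called anywhere): the zero-init channel expansion
# performed inside load_models(), shown in isolation. The first in_channels of
# the new layer reuse the pretrained weights and the extra channels start at
# zero, so the widened layer initially behaves exactly like the original one.
# The helper name is ours.
def _expand_conv_in_example(conv: nn.Conv2d, new_in_channels: int = 8) -> nn.Conv2d:
    new_conv = nn.Conv2d(
        new_in_channels, conv.out_channels,
        conv.kernel_size, conv.stride, conv.padding,
    )
    with torch.no_grad():
        new_conv.weight.zero_()
        new_conv.weight[:, :conv.in_channels].copy_(conv.weight)
        if conv.bias is not None:
            new_conv.bias.copy_(conv.bias)
    return new_conv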
@spaces.GPU(duration=20)
def process_images(input_images, prompt, steps=75, guidance_scale=7.5, pipeline=None):
    """Process input images and run refinement."""
    device = pipeline.device

    if isinstance(input_images, list):
        if len(input_images) == 1:
            # Check whether this is a pre-arranged layout
            img = Image.open(input_images[0].name).convert('RGB')
            if img.size == (640, 960):
                # Already a portrait layout; use it directly
                input_image = img
            else:
                # Single view: replicate it into 6 copies
                img = img.resize((320, 320))
                img_array = np.array(img) / 255.0
                images = np.stack([img_array] * 6)

                # Convert to tensor and assemble the 3x2 grid layout
                images = torch.from_numpy(images).float()   # (6, 320, 320, 3)
                images = images.permute(0, 3, 1, 2)         # (6, C, H, W)
                images = images.reshape(3, 2, 3, 320, 320)  # (row, col, C, H, W)
                images = images.permute(0, 2, 3, 1, 4)      # (row, C, H, col, W)
                images = images.reshape(3, 3, 320, 640)     # (row, C, 320, 640)
                # Bring channels first before stacking rows vertically
                # (mirrors the inverse reshape in create_mesh)
                images = images.permute(1, 0, 2, 3)         # (C, row, 320, 640)
                images = images.reshape(1, 3, 960, 640)

                # Convert back to PIL
                images = images.permute(0, 2, 3, 1)[0]
                images = (images.numpy() * 255).astype(np.uint8)
                input_image = Image.fromarray(images)
        else:
            # Multiple individual views
            images = []
            for img_file in input_images:
                img = Image.open(img_file.name).convert('RGB')
                img = img.resize((320, 320))
                img = np.array(img) / 255.0
                images.append(img)

            # Pad to 6 images if needed
            while len(images) < 6:
                images.append(np.zeros_like(images[0]))
            images = np.stack(images[:6])

            # Convert to tensor and assemble the 3x2 grid layout
            images = torch.from_numpy(images).float()   # (6, 320, 320, 3)
            images = images.permute(0, 3, 1, 2)         # (6, C, H, W)
            images = images.reshape(3, 2, 3, 320, 320)  # (row, col, C, H, W)
            images = images.permute(0, 2, 3, 1, 4)      # (row, C, H, col, W)
            images = images.reshape(3, 3, 320, 640)     # (row, C, 320, 640)
            # Bring channels first before stacking rows vertically
            # (mirrors the inverse reshape in create_mesh)
            images = images.permute(1, 0, 2, 3)         # (C, row, 320, 640)
            images = images.reshape(1, 3, 960, 640)

            # Convert back to PIL
            images = images.permute(0, 2, 3, 1)[0]
            images = (images.numpy() * 255).astype(np.uint8)
            input_image = Image.fromarray(images)
    else:
        raise ValueError("Expected a list of images")

    # Generate refined output
    output = pipeline.refine(
        input_image,
        prompt=prompt,
        num_inference_steps=int(steps),
        guidance_scale=guidance_scale
    ).images[0]

    return output, input_image


@spaces.GPU(duration=20)
def create_mesh(refined_image, model, infer_config):
    """Generate a mesh from the refined 960x640 multi-view image."""
    # Convert PIL image to tensor
    image = np.array(refined_image) / 255.0
    image = torch.from_numpy(image).float().permute(2, 0, 1)

    # Split the 3x2 grid of 320x320 tiles back into 6 views
    image = image.reshape(3, 960, 640)
    image = image.reshape(3, 3, 320, 640)      # (C, row, 320, 640)
    image = image.permute(1, 0, 2, 3)          # (row, C, 320, 640)
    image = image.reshape(3, 3, 320, 2, 320)   # (row, C, 320, col, 320)
    image = image.permute(0, 3, 1, 2, 4)       # (row, col, C, 320, 320)
    image = image.reshape(6, 3, 320, 320)

    # Add batch dimension
    image = image.unsqueeze(0)

    input_cameras = get_zero123plus_input_cameras(batch_size=1, radius=4.0).to("cuda")
    image = image.to("cuda")

    with torch.no_grad():
        planes = model.forward_planes(image, input_cameras)
        mesh_out = model.extract_mesh(planes, **infer_config)
        vertices, faces, vertex_colors = mesh_out

    return vertices, faces, vertex_colors
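

# Illustrative sketch (not used by the app): the 960x640 layouts handled above
# are a row-major grid of three rows and two columns of 320x320 views. This
# numpy equivalent of the tensor reshapes in create_mesh() makes the tiling
# explicit; the helper name is ours.
def _split_layout_example(layout: np.ndarray) -> np.ndarray:
    """Split a (960, 640, 3) grid image into 6 (320, 320, 3) views, row by row."""
    assert layout.shape == (960, 640, 3)
    views = [
        layout[r * 320:(r + 1) * 320, c * 320:(c + 1) * 320]
        for r in range(3) for c in range(2)
    ]
    return np.stack(views)  # shape (6, 320, 320, 3)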
class ShapERenderer:
    def __init__(self, device):
        print("Initializing Shap-E models...")
        self.device = device
        torch.cuda.empty_cache()  # Clear GPU memory before loading
        self.xm = load_model('transmitter', device=self.device)
        self.model = load_model('text300M', device=self.device)
        self.diffusion = diffusion_from_config(load_config('diffusion'))
        print("Shap-E models initialized!")

    @spaces.GPU(duration=80)
    def generate_views(self, prompt, guidance_scale=15.0, num_steps=64):
        try:
            torch.cuda.empty_cache()  # Clear GPU memory before generation

            # Generate latents with the text-to-3D model
            batch_size = 1
            guidance_scale = float(guidance_scale)
            with torch.amp.autocast('cuda'):  # Use automatic mixed precision
                latents = sample_latents(
                    batch_size=batch_size,
                    model=self.model,
                    diffusion=self.diffusion,
                    guidance_scale=guidance_scale,
                    model_kwargs=dict(texts=[prompt] * batch_size),
                    progress=True,
                    clip_denoised=True,
                    use_fp16=True,
                    use_karras=True,
                    karras_steps=num_steps,
                    sigma_min=1e-3,
                    sigma_max=160,
                    s_churn=0,
                )

            # Render the 6 views we need, at viewing angles matching refine.py
            size = 320  # Size of each rendered image
            images = []
            azimuths = [30, 90, 150, 210, 270, 330]
            elevations = [20, -10, 20, -10, 20, -10]

            for azimuth, elevation in zip(azimuths, elevations):
                cameras = create_custom_cameras(
                    size, self.device,
                    azimuths=[azimuth], elevations=[elevation],
                    fov_degrees=30, distance=3.0
                )
                with torch.amp.autocast('cuda'):  # Use automatic mixed precision
                    rendered_image = decode_latent_images(
                        self.xm,
                        latents[0],
                        cameras=cameras,
                        rendering_mode='stf'
                    )
                images.append(rendered_image[0])
                torch.cuda.empty_cache()  # Clear GPU memory after each view

            # Move renders off the GPU and convert to uint8 arrays
            # (the renders are float tensors in [0, 255])
            images = [
                img.clamp(0, 255).to(torch.uint8).squeeze().cpu().numpy()
                for img in images
            ]

            # Assemble a 3x2 grid layout (960 tall x 640 wide)
            layout = np.zeros((960, 640, 3), dtype=np.uint8)
            for i, img in enumerate(images):
                row = i // 2
                col = i % 2
                layout[row*320:(row+1)*320, col*320:(col+1)*320] = img

            return Image.fromarray(layout), images
        except Exception as e:
            print(f"Error in generate_views: {e}")
            torch.cuda.empty_cache()  # Clear GPU memory on error
            raise


class RefinerInterface:
    def __init__(self):
        print("Initializing InstantMesh models...")
        torch.cuda.empty_cache()  # Clear GPU memory before loading
        self.pipeline, self.model, self.infer_config = load_models()
        print("InstantMesh models initialized!")

    def refine_model(self, input_image, prompt, steps=75, guidance_scale=7.5):
        """Main refinement function."""
        try:
            torch.cuda.empty_cache()  # Clear GPU memory before processing

            input_image = Image.fromarray(input_image)

            # If we received a landscape layout (2 rows x 3 cols, 960 wide by
            # 640 tall), rearrange it into the portrait layout (3 rows x 2
            # cols, 640 wide by 960 tall) that the pipeline expects
            if input_image.width == 960 and input_image.height == 640:
                input_array = np.array(input_image)
                new_layout = np.zeros((960, 640, 3), dtype=np.uint8)
                for i in range(6):
                    src_row = i // 3
                    src_col = i % 3
                    dst_row = i // 2
                    dst_col = i % 2
                    new_layout[dst_row*320:(dst_row+1)*320, dst_col*320:(dst_col+1)*320] = \
                        input_array[src_row*320:(src_row+1)*320, src_col*320:(src_col+1)*320]
                input_image = Image.fromarray(new_layout)

            # Run the refinement pipeline on the portrait layout
            with torch.amp.autocast('cuda'):  # Use automatic mixed precision
                refined_output = self.pipeline.refine(
                    input_image,
                    prompt=prompt,
                    num_inference_steps=int(steps),
                    guidance_scale=guidance_scale
                ).images[0]

            torch.cuda.empty_cache()  # Clear GPU memory after refinement

            # Generate the mesh from the refined portrait layout
            with torch.amp.autocast('cuda'):  # Use automatic mixed precision
                vertices, faces, vertex_colors = create_mesh(
                    refined_output, self.model, self.infer_config
                )

            torch.cuda.empty_cache()  # Clear GPU memory after mesh generation

            # Save a temporary mesh file
            os.makedirs("temp", exist_ok=True)
            temp_obj = os.path.join("temp", "refined_mesh.obj")
            save_obj(vertices, faces, vertex_colors, temp_obj)

            # The refined output is already in the portrait layout used for
            # display, so it can be returned as-is
            return refined_output, temp_obj
        except Exception as e:
            print(f"Error in refine_model: {e}")
            torch.cuda.empty_cache()  # Clear GPU memory on error
            raise
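

# Illustrative sketch (not used by the app): the landscape-to-portrait tile
# shuffle from RefinerInterface.refine_model as a standalone numpy function.
# View i sits at (i // 3, i % 3) in the 2x3 landscape grid and moves to
# (i // 2, i % 2) in the 3x2 portrait grid; the helper name is ours.
def _landscape_to_portrait_example(landscape: np.ndarray) -> np.ndarray:
    assert landscape.shape == (640, 960, 3)
    portrait = np.zeros((960, 640, 3), dtype=landscape.dtype)
    for i in range(6):
        src_r, src_c = i // 3, i % 3
        dst_r, dst_c = i // 2, i % 2
        portrait[dst_r * 320:(dst_r + 1) * 320, dst_c * 320:(dst_c + 1) * 320] = \
            landscape[src_r * 320:(src_r + 1) * 320, src_c * 320:(src_c + 1) * 320]
    return portrait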
def create_demo():
    print("Initializing models...")
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Initialize models at startup
    shap_e = ShapERenderer(device)
    refiner = RefinerInterface()
    print("All models initialized!")

    with gr.Blocks() as demo:
        gr.Markdown("# Shap-E to InstantMesh Pipeline")

        # First row: controls
        with gr.Row():
            with gr.Column():
                # Shap-E inputs
                shape_prompt = gr.Textbox(
                    label="Shap-E Prompt",
                    placeholder="Enter text to generate initial 3D model..."
                )
                shape_guidance = gr.Slider(
                    minimum=1, maximum=30, value=15.0,
                    label="Shap-E Guidance Scale"
                )
                shape_steps = gr.Slider(
                    minimum=16, maximum=128, value=64, step=16,
                    label="Shap-E Steps"
                )
                generate_btn = gr.Button("Generate Views")

            with gr.Column():
                # Refinement inputs
                refine_prompt = gr.Textbox(
                    label="Refinement Prompt",
                    placeholder="Enter prompt to guide refinement..."
                )
                refine_steps = gr.Slider(
                    minimum=30, maximum=100, value=75, step=1,
                    label="Refinement Steps"
                )
                refine_guidance = gr.Slider(
                    minimum=1, maximum=20, value=7.5,
                    label="Refinement Guidance Scale"
                )
                refine_btn = gr.Button("Refine")
                error_output = gr.Textbox(label="Status/Error Messages", interactive=False)

        # Second row: image panels side by side
        with gr.Row():
            shape_output = gr.Image(
                label="Generated Views",
                width=640,
                height=960
            )
            refined_output = gr.Image(
                label="Refined Output",
                width=640,
                height=960
            )

        # Third row: 3D mesh panel below
        with gr.Row():
            mesh_output = gr.Model3D(
                label="3D Mesh",
                clear_color=[1.0, 1.0, 1.0, 1.0],
            )

        # Event handlers
        @spaces.GPU(duration=100)  # GPU decorator for the generate step
        def generate(prompt, guidance_scale, num_steps):
            try:
                torch.cuda.empty_cache()  # Clear GPU memory before starting
                with torch.no_grad():
                    layout, _ = shap_e.generate_views(prompt, guidance_scale, num_steps)
                return layout, None  # None clears the error message
            except Exception as e:
                torch.cuda.empty_cache()  # Clear GPU memory on error
                error_msg = f"Error during generation: {str(e)}"
                print(error_msg)
                return None, error_msg

        @spaces.GPU(duration=20)
        def refine(input_image, prompt, steps, guidance_scale):
            try:
                torch.cuda.empty_cache()  # Clear GPU memory before starting
                refined_img, mesh_path = refiner.refine_model(
                    input_image, prompt, steps, guidance_scale
                )
                return refined_img, mesh_path, None  # None clears the error message
            except Exception as e:
                torch.cuda.empty_cache()  # Clear GPU memory on error
                error_msg = f"Error during refinement: {str(e)}"
                print(error_msg)
                return None, None, error_msg

        generate_btn.click(
            fn=generate,
            inputs=[shape_prompt, shape_guidance, shape_steps],
            outputs=[shape_output, error_output]
        )

        refine_btn.click(
            fn=refine,
            inputs=[shape_output, refine_prompt, refine_steps, refine_guidance],
            outputs=[refined_output, mesh_output, error_output]
        )

    return demo


if __name__ == "__main__":
    demo = create_demo()
    demo.launch(share=True)
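
# Note on the launch call above: share=True opens a temporary public Gradio
# tunnel, which matters for local runs; on Hugging Face Spaces the app is
# already served publicly and Gradio warns that the flag is unsupported there.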