# Sharp-It / app2.py
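"""Gradio demo for Sharp-It: generate six views of a 3D object with Shap-E,
refine them with a zero123plus-based diffusion pipeline, and reconstruct a
textured mesh with InstantMesh."""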
import os
import torch
import torch.nn as nn
import gradio as gr
import numpy as np
from PIL import Image
from omegaconf import OmegaConf
from pytorch_lightning import seed_everything
from huggingface_hub import hf_hub_download
from diffusers import DiffusionPipeline, EulerAncestralDiscreteScheduler
from einops import rearrange
from shap_e.diffusion.sample import sample_latents
from shap_e.diffusion.gaussian_diffusion import diffusion_from_config
from shap_e.models.download import load_model, load_config
from shap_e.util.notebooks import create_pan_cameras, decode_latent_images
from util import create_custom_cameras
from src.utils.train_util import instantiate_from_config
from src.utils.camera_util import (
FOV_to_intrinsics,
get_zero123plus_input_cameras,
get_circular_camera_poses,
spherical_camera_pose
)
from src.utils.mesh_util import save_obj, save_glb
from src.utils.infer_util import remove_background, resize_foreground
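
def views_to_grid(views):
    """Tile six (320, 320, 3) float arrays in [0, 1] into the 3-row x 2-column
    layout the refinement pipeline expects, returned as a 640x960 (width x
    height) PIL image. Helper shared by both input branches of process_images."""
    grid = rearrange(np.stack(views), '(row col) h w c -> (row h) (col w) c', row=3, col=2)
    return Image.fromarray((grid * 255).astype(np.uint8))
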
def load_models():
"""Initialize and load all required models"""
config = OmegaConf.load('configs/instant-nerf-large-best.yaml')
model_config = config.model_config
infer_config = config.infer_config
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Load diffusion pipeline
print('Loading diffusion pipeline...')
pipeline = DiffusionPipeline.from_pretrained(
"sudo-ai/zero123plus-v1.2",
custom_pipeline="zero123plus",
torch_dtype=torch.float16
)
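    # zero123plus is sampled with Euler ancestral + trailing timestep spacing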
pipeline.scheduler = EulerAncestralDiscreteScheduler.from_config(
pipeline.scheduler.config, timestep_spacing='trailing'
)
    # Expand the UNet input layer from 4 to 8 channels (the extra channels carry
    # image conditioning); the new input weights are zero-initialized so the
    # pretrained behavior is preserved
    in_channels = 8
    out_channels = pipeline.unet.conv_in.out_channels
    pipeline.unet.register_to_config(in_channels=in_channels)
    with torch.no_grad():
        new_conv_in = nn.Conv2d(
            in_channels, out_channels, pipeline.unet.conv_in.kernel_size,
            pipeline.unet.conv_in.stride, pipeline.unet.conv_in.padding
        )
        new_conv_in.weight.zero_()
        new_conv_in.weight[:, :4, :, :].copy_(pipeline.unet.conv_in.weight)
        # Keep the pretrained bias; only the new input channels start at zero
        new_conv_in.bias.copy_(pipeline.unet.conv_in.bias)
        pipeline.unet.conv_in = new_conv_in
# Load custom UNet
print('Loading custom UNet...')
unet_path = "best_21.ckpt"
state_dict = torch.load(unet_path, map_location='cpu')
# Process the state dict to match the model keys
if 'state_dict' in state_dict:
new_state_dict = {key.replace('unet.unet.', ''): value for key, value in state_dict['state_dict'].items()}
pipeline.unet.load_state_dict(new_state_dict, strict=False)
else:
pipeline.unet.load_state_dict(state_dict, strict=False)
    # Move the pipeline (including the replaced conv_in) to the device in fp16
    pipeline = pipeline.to(device, torch.float16)
# Load reconstruction model
print('Loading reconstruction model...')
model = instantiate_from_config(model_config)
model_path = hf_hub_download(
repo_id="TencentARC/InstantMesh",
filename="instant_nerf_large.ckpt",
repo_type="model"
)
state_dict = torch.load(model_path, map_location='cpu')['state_dict']
state_dict = {k[14:]: v for k, v in state_dict.items()
if k.startswith('lrm_generator.') and 'source_camera' not in k}
model.load_state_dict(state_dict, strict=True)
model = model.to(device)
model.eval()
return pipeline, model, infer_config
def process_images(input_images, prompt, steps=75, guidance_scale=7.5, pipeline=None):
    """Process input images and run refinement"""
    assert pipeline is not None, "a loaded diffusion pipeline is required"
if isinstance(input_images, list):
if len(input_images) == 1:
# Check if this is a pre-arranged layout
img = Image.open(input_images[0].name).convert('RGB')
if img.size == (640, 960):
# This is already a layout, use it directly
input_image = img
else:
                # Single view - tile six copies into the 640x960 grid layout
                img = img.resize((320, 320))
                img_array = np.array(img) / 255.0
                input_image = views_to_grid([img_array] * 6)
        else:
            # Multiple individual views
            images = []
            for img_file in input_images:
                img = Image.open(img_file.name).convert('RGB')
                img = img.resize((320, 320))
                images.append(np.array(img) / 255.0)
            # Pad with blank (black) views up to six
            while len(images) < 6:
                images.append(np.zeros_like(images[0]))
            input_image = views_to_grid(images[:6])
else:
raise ValueError("Expected a list of images")
# Generate refined output
output = pipeline.refine(
input_image,
prompt=prompt,
num_inference_steps=int(steps),
guidance_scale=guidance_scale
).images[0]
return output, input_image
def create_mesh(refined_image, model, infer_config):
"""Generate mesh from refined image"""
    # Convert PIL image to a (3, 960, 640) float tensor
    image = np.array(refined_image) / 255.0
    image = torch.from_numpy(image).float().permute(2, 0, 1)
    # Split the 3-row x 2-column grid back into six (3, 320, 320) views
    image = rearrange(image, 'c (row h) (col w) -> (row col) c h w', row=3, col=2)
# Add batch dimension
image = image.unsqueeze(0)
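    # Cameras for the six fixed zero123plus viewpoints expected by the reconstructor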
input_cameras = get_zero123plus_input_cameras(batch_size=1, radius=4.0).to("cuda")
image = image.to("cuda")
with torch.no_grad():
planes = model.forward_planes(image, input_cameras)
mesh_out = model.extract_mesh(planes, **infer_config)
vertices, faces, vertex_colors = mesh_out
return vertices, faces, vertex_colors
class ShapERenderer:
def __init__(self, device):
print("Loading Shap-E models...")
self.device = device
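        # 'transmitter' decodes Shap-E latents into renderable representations;
        # 'text300M' is the text-conditional latent diffusion model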
self.xm = load_model('transmitter', device=device)
self.model = load_model('text300M', device=device)
self.diffusion = diffusion_from_config(load_config('diffusion'))
print("Shap-E models loaded!")
def generate_views(self, prompt, guidance_scale=15.0, num_steps=64):
# Generate latents using the text-to-3D model
batch_size = 1
guidance_scale = float(guidance_scale)
latents = sample_latents(
batch_size=batch_size,
model=self.model,
diffusion=self.diffusion,
guidance_scale=guidance_scale,
model_kwargs=dict(texts=[prompt] * batch_size),
progress=True,
clip_denoised=True,
use_fp16=True,
use_karras=True,
karras_steps=num_steps,
sigma_min=1e-3,
sigma_max=160,
s_churn=0,
)
# Render the 6 views we need with specific viewing angles
size = 320 # Size of each rendered image
images = []
# Define our 6 specific camera positions to match refine.py
azimuths = [30, 90, 150, 210, 270, 330]
elevations = [20, -10, 20, -10, 20, -10]
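        # These azimuth/elevation pairs follow the zero123plus convention:
        # azimuths in 60-degree steps, elevations alternating 20 / -10 degrees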
        for azimuth, elevation in zip(azimuths, elevations):
            cameras = create_custom_cameras(
                size, self.device, azimuths=[azimuth], elevations=[elevation],
                fov_degrees=30, distance=3.0
            )
            rendered_image = decode_latent_images(
                self.xm,
                latents[0],
                rendering_mode='stf',
                cameras=cameras
            )
            images.append(rendered_image.detach().cpu().numpy())
        # Convert rendered views to uint8
        images = [image.astype(np.uint8) for image in images]
        # Tile the six views into a 3-row x 2-column grid (960 tall x 640 wide)
        layout = np.zeros((960, 640, 3), dtype=np.uint8)
        for i, img in enumerate(images):
            row = i // 2  # two views per row
            col = i % 2
            layout[row*320:(row+1)*320, col*320:(col+1)*320] = img
return Image.fromarray(layout), images
class RefinerInterface:
def __init__(self):
print("Initializing InstantMesh models...")
self.pipeline, self.model, self.infer_config = load_models()
print("InstantMesh models loaded!")
def refine_model(self, input_image, prompt, steps=75, guidance_scale=7.5):
"""Main refinement function"""
# Process image and get refined output
input_image = Image.fromarray(input_image)
        # The pipeline expects a 640-wide x 960-tall layout (3 rows x 2 columns);
        # if we received a transposed 960x640 layout (2 rows x 3 columns), re-tile it
        if input_image.width == 960 and input_image.height == 640:
            input_array = np.array(input_image)
            new_layout = np.zeros((960, 640, 3), dtype=np.uint8)
            for i in range(6):
                src_row, src_col = i // 3, i % 3  # 2x3 source grid
                dst_row, dst_col = i // 2, i % 2  # 3x2 destination grid
                new_layout[dst_row*320:(dst_row+1)*320, dst_col*320:(dst_col+1)*320] = \
                    input_array[src_row*320:(src_row+1)*320, src_col*320:(src_col+1)*320]
            input_image = Image.fromarray(new_layout)
        # Run the refinement pipeline on the 640x960 layout
        refined_layout = self.pipeline.refine(
            input_image,
            prompt=prompt,
            num_inference_steps=int(steps),
            guidance_scale=guidance_scale
        ).images[0]
        # Generate mesh from the refined layout
        vertices, faces, vertex_colors = create_mesh(
            refined_layout,
            self.model,
            self.infer_config
        )
# Save temporary mesh file
os.makedirs("temp", exist_ok=True)
temp_obj = os.path.join("temp", "refined_mesh.obj")
save_obj(vertices, faces, vertex_colors, temp_obj)
        # The refined layout is already in the 640x960 display format
        return refined_layout, temp_obj
def create_demo():
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
shap_e = ShapERenderer(device)
refiner = RefinerInterface()
with gr.Blocks() as demo:
gr.Markdown("# Shap-E to InstantMesh Pipeline")
# First row: Controls
with gr.Row():
with gr.Column():
# Shap-E inputs
shape_prompt = gr.Textbox(
label="Shap-E Prompt",
placeholder="Enter text to generate initial 3D model..."
)
shape_guidance = gr.Slider(
minimum=1,
maximum=30,
value=15.0,
label="Shap-E Guidance Scale"
)
shape_steps = gr.Slider(
minimum=16,
maximum=128,
value=64,
step=16,
label="Shap-E Steps"
)
generate_btn = gr.Button("Generate Views")
with gr.Column():
# Refinement inputs
refine_prompt = gr.Textbox(
label="Refinement Prompt",
placeholder="Enter prompt to guide refinement..."
)
refine_steps = gr.Slider(
minimum=30,
maximum=100,
value=75,
step=1,
label="Refinement Steps"
)
refine_guidance = gr.Slider(
minimum=1,
maximum=20,
value=7.5,
label="Refinement Guidance Scale"
)
refine_btn = gr.Button("Refine")
# Second row: Image panels side by side
with gr.Row():
# Outputs - Images side by side
            shape_output = gr.Image(
                label="Generated Views",
                width=640,
                height=960
            )
            refined_output = gr.Image(
                label="Refined Output",
                width=640,
                height=960
            )
# Third row: 3D mesh panel below
with gr.Row():
# 3D mesh centered
mesh_output = gr.Model3D(
label="3D Mesh",
clear_color=[1.0, 1.0, 1.0, 1.0],
width=1280, # Full width
height=600 # Taller for better visualization
)
# Set up event handlers
def generate(prompt, guidance_scale, num_steps):
with torch.no_grad():
layout, _ = shap_e.generate_views(prompt, guidance_scale, num_steps)
return layout
def refine(input_image, prompt, steps, guidance_scale):
refined_img, mesh_path = refiner.refine_model(
input_image,
prompt,
steps,
guidance_scale
)
return refined_img, mesh_path
generate_btn.click(
fn=generate,
inputs=[shape_prompt, shape_guidance, shape_steps],
outputs=[shape_output]
)
refine_btn.click(
fn=refine,
inputs=[shape_output, refine_prompt, refine_steps, refine_guidance],
outputs=[refined_output, mesh_output]
)
return demo
if __name__ == "__main__":
demo = create_demo()
demo.launch(share=True)