# Page banner captured together with this file: "Spaces: Running on L4".
import base64
import logging
import os
import random
import sys

import comfy.model_management
import folder_paths
import numpy as np
import torch
import trimesh
from PIL import Image
from trimesh.exchange import gltf

# Make the bundled ``spar3d`` package importable when this module is loaded
# as a ComfyUI custom node; must run before the ``spar3d`` imports below.
sys.path.append(os.path.dirname(__file__))

from spar3d.models.mesh import QUAD_REMESH_AVAILABLE, TRIANGLE_REMESH_AVAILABLE
from spar3d.system import SPAR3D
from spar3d.utils import foreground_crop
# Value assigned to the CATEGORY attribute of every node class in this module.
SPAR3D_CATEGORY = "SPAR3D"
# Model identifier passed to SPAR3D.from_pretrained by the loader node.
SPAR3D_MODEL_NAME = "stabilityai/spar3d"
class SPAR3DLoader:
    """ComfyUI node that loads the pretrained SPAR3D model."""

    CATEGORY = SPAR3D_CATEGORY
    FUNCTION = "load"
    RETURN_NAMES = ("spar3d_model",)
    RETURN_TYPES = ("SPAR3D_MODEL",)

    @classmethod
    def INPUT_TYPES(cls):
        """Return the node's input declaration; the loader takes no inputs."""
        return {"required": {}}

    def load(self):
        """Load SPAR3D, move it to ComfyUI's torch device and set eval mode.

        Returns:
            A 1-tuple containing the loaded model.
        """
        device = comfy.model_management.get_torch_device()
        model = SPAR3D.from_pretrained(
            SPAR3D_MODEL_NAME,
            config_name="config.yaml",
            weight_name="model.safetensors",
        )
        model.to(device)
        model.eval()
        return (model,)
class SPAR3DPreview:
    """ComfyUI output node that returns meshes to the UI as base64 GLB data."""

    CATEGORY = SPAR3D_CATEGORY
    FUNCTION = "preview"
    OUTPUT_NODE = True
    RETURN_TYPES = ()

    @classmethod
    def INPUT_TYPES(s):
        """Return the node's input declaration: a single required ``mesh``."""
        return {"required": {"mesh": ("MESH",)}}

    def preview(self, mesh):
        """Export every mesh to GLB and return the base64-encoded payloads.

        Args:
            mesh: Iterable of objects accepted by ``trimesh.Scene``.

        Returns:
            ``{"ui": {"glbs": [...]}}`` with one base64 string per mesh.
        """
        glbs = []
        for m in mesh:
            scene = trimesh.Scene(m)
            glb_data = gltf.export_glb(scene, include_normals=True)
            glbs.append(base64.b64encode(glb_data).decode("utf-8"))
        return {"ui": {"glbs": glbs}}
class SPAR3DSampler:
    """ComfyUI node that runs SPAR3D on one image and returns mesh + points."""

    CATEGORY = SPAR3D_CATEGORY
    FUNCTION = "predict"
    RETURN_NAMES = ("mesh", "pointcloud")
    RETURN_TYPES = ("MESH", "POINTCLOUD")

    @classmethod
    def INPUT_TYPES(s):
        """Return the node's input declaration.

        The ``remesh`` option is only offered when at least one remeshing
        backend is available.
        """
        remesh_choices = ["none"]
        if TRIANGLE_REMESH_AVAILABLE:
            remesh_choices.append("triangle")
        if QUAD_REMESH_AVAILABLE:
            remesh_choices.append("quad")

        opt_dict = {
            "mask": ("MASK",),
            "pointcloud": ("POINTCLOUD",),
            "target_type": (["none", "vertex", "face"],),
            "target_count": (
                "INT",
                {"default": 1000, "min": 3, "max": 20000, "step": 1},
            ),
            "guidance_scale": (
                "FLOAT",
                {"default": 3.0, "min": 1.0, "max": 5.0, "step": 0.05},
            ),
            "seed": (
                "INT",
                {"default": 42, "min": 0, "max": 2**32 - 1, "step": 1},
            ),
        }
        if TRIANGLE_REMESH_AVAILABLE or QUAD_REMESH_AVAILABLE:
            opt_dict["remesh"] = (remesh_choices,)

        return {
            "required": {
                "model": ("SPAR3D_MODEL",),
                "image": ("IMAGE",),
                "foreground_ratio": (
                    "FLOAT",
                    {"default": 1.3, "min": 1.0, "max": 2.0, "step": 0.01},
                ),
                "texture_resolution": (
                    "INT",
                    {"default": 1024, "min": 512, "max": 2048, "step": 256},
                ),
            },
            "optional": opt_dict,
        }

    def predict(
        s,
        model,
        image,
        mask=None,
        foreground_ratio=1.3,
        texture_resolution=1024,
        pointcloud=None,
        remesh="none",
        target_type="none",
        target_count=1000,
        guidance_scale=3.0,
        seed=42,
    ):
        """Reconstruct a mesh and point cloud from a single image.

        ``mask`` is declared as an optional input, so it defaults to ``None``
        here; ``foreground_ratio`` and ``texture_resolution`` carry the same
        defaults as their input declarations so the signature stays valid.

        Args:
            model: Model object exposing ``cfg.guidance_scale`` and
                ``run_image``.
            image: Tensor indexed as ``image[0]`` with the channel count in
                ``image.shape[3]``; values are scaled by 255 before
                conversion to a PIL image.
            mask: Optional tensor; ``mask[0]`` is scaled by 255 and used as
                the alpha channel.
            foreground_ratio: Passed to ``foreground_crop``.
            texture_resolution: Passed to ``run_image`` as
                ``bake_resolution``.
            pointcloud: Optional point cloud forwarded to ``run_image``.
            remesh: ``"none"``, ``"triangle"`` or ``"quad"``.
            target_type: ``"none"``, ``"vertex"`` or ``"face"``.
            target_count: Target vertex count; halved when ``target_type``
                is ``"face"``.
            guidance_scale: Written to ``model.cfg.guidance_scale``.
            seed: Seed applied to ``random``, ``torch`` and ``numpy``.

        Returns:
            ``([mesh], flat_pointcloud_list)``.

        Raises:
            ValueError: If more than one image is supplied or the resulting
                mesh has no vertices.
            ImportError: If the requested remesh backend is unavailable.
        """
        if image.shape[0] != 1:
            raise ValueError("Only one image can be processed at a time")

        # Validate the remesh request before doing any expensive work.
        if remesh == "triangle" and not TRIANGLE_REMESH_AVAILABLE:
            raise ImportError(
                "Triangle remeshing requires gpytoolbox to be installed"
            )
        if remesh == "quad" and not QUAD_REMESH_AVAILABLE:
            raise ImportError("Quad remeshing requires pynim to be installed")

        if target_type == "none":
            vertex_count = -1
        elif target_type == "face":
            vertex_count = target_count // 2
        else:
            vertex_count = target_count

        pixels = (
            torch.clamp(torch.round(255.0 * image[0]), 0, 255)
            .type(torch.uint8)
            .cpu()
            .numpy()
        )
        pil_image = Image.fromarray(pixels)

        if mask is not None:
            logging.info("Using Mask")
            mask_np = np.clip(
                255.0 * mask[0].detach().cpu().numpy(), 0, 255
            ).astype(np.uint8)
            mask_pil = Image.fromarray(mask_np, mode="L")
            pil_image.putalpha(mask_pil)
        else:
            if image.shape[3] != 4:
                logging.info(
                    "No mask or alpha channel detected, Converting to RGBA"
                )
            pil_image = pil_image.convert("RGBA")

        pil_image = foreground_crop(pil_image, foreground_ratio)

        model.cfg.guidance_scale = guidance_scale
        random.seed(seed)
        torch.manual_seed(seed)
        np.random.seed(seed)
        logging.debug("Remesh mode: %s", remesh)

        with torch.no_grad():
            # NOTE(review): autocast is pinned to "cuda" regardless of the
            # device the model was loaded on - confirm for non-CUDA setups.
            with torch.autocast(device_type="cuda", dtype=torch.float16):
                mesh, glob_dict = model.run_image(
                    pil_image,
                    bake_resolution=texture_resolution,
                    pointcloud=pointcloud,
                    remesh=remesh,
                    vertex_count=vertex_count,
                )

        if mesh.vertices.shape[0] == 0:
            raise ValueError("No subject detected in the image")

        return (
            [mesh],
            glob_dict["pointcloud"].view(-1).detach().cpu().numpy().tolist(),
        )
class SPAR3DSave:
    """ComfyUI output node that writes meshes to disk as GLB files."""

    CATEGORY = SPAR3D_CATEGORY
    FUNCTION = "save"
    OUTPUT_NODE = True
    RETURN_TYPES = ()

    @classmethod
    def INPUT_TYPES(s):
        """Return the node's input declaration: ``mesh`` and a name prefix."""
        return {
            "required": {
                "mesh": ("MESH",),
                "filename_prefix": ("STRING", {"default": "SPAR3D"}),
            }
        }

    def __init__(self):
        self.type = "output"

    def save(self, mesh, filename_prefix):
        """Export each mesh to a ``.glb`` file in ComfyUI's output directory.

        Args:
            mesh: Iterable of objects accepted by ``trimesh.Scene``.
            filename_prefix: Prefix handed to
                ``folder_paths.get_save_image_path``.

        Returns:
            ``{"ui": {"glbs": [...]}}`` with one base64 string per mesh.
        """
        output_dir = folder_paths.get_output_directory()
        glbs = []
        for idx, m in enumerate(mesh):
            scene = trimesh.Scene(m)
            glb_data = gltf.export_glb(scene, include_normals=True)
            logging.info("Generated GLB model with %d bytes", len(glb_data))

            # Resolve the path once per mesh so the counter is re-read after
            # the previous file was written. The caller's prefix is passed
            # unchanged each time instead of being rebound in the loop.
            full_output_folder, filename, counter, _subfolder, _prefix = (
                folder_paths.get_save_image_path(filename_prefix, output_dir)
            )
            filename = filename.replace("%batch_num%", str(idx))
            out_path = os.path.join(
                full_output_folder, f"{filename}_{counter:05}_.glb"
            )
            with open(out_path, "wb") as f:
                f.write(glb_data)

            glbs.append(base64.b64encode(glb_data).decode("utf-8"))
        return {"ui": {"glbs": glbs}}
class SPAR3DPointCloudLoader:
    """ComfyUI node that loads a point cloud file into a flat XYZRGB list."""

    CATEGORY = SPAR3D_CATEGORY
    FUNCTION = "load_pointcloud"
    RETURN_TYPES = ("POINTCLOUD",)
    RETURN_NAMES = ("pointcloud",)

    @classmethod
    def INPUT_TYPES(cls):
        """Return the node's input declaration: a single ``file`` string."""
        return {
            "required": {
                "file": ("STRING", {"default": None}),
            }
        }

    def load_pointcloud(self, file):
        """Load ``file`` with trimesh and flatten it to ``[x, y, z, r, g, b, ...]``.

        Args:
            file: Path handed to ``trimesh.load``; ``None`` or ``""`` means
                "no point cloud".

        Returns:
            A 1-tuple holding the flat list of floats (colours scaled to
            0-1), or ``(None,)`` when no file is given.
        """
        if file is None or file == "":
            return (None,)

        mesh = trimesh.load(file)
        vertices = np.asarray(mesh.vertices, dtype=np.float64)

        # Not every visual object exposes ``vertex_colors``; fall back to
        # white instead of raising AttributeError.
        vertex_colors = getattr(
            getattr(mesh, "visual", None), "vertex_colors", None
        )
        if vertex_colors is not None:
            # Convert 0-255 to 0-1 range, dropping any alpha channel.
            colors = np.asarray(vertex_colors)[:, :3] / 255.0
        else:
            colors = np.ones((len(vertices), 3))

        # Interleave XYZ and RGB per point. Truncating to the shorter array
        # mirrors the behaviour of zipping vertices with colours.
        count = min(len(vertices), len(colors))
        point_cloud = (
            np.hstack((vertices[:count, :3], colors[:count])).ravel().tolist()
        )
        return (point_cloud,)
class SPAR3DPointCloudSaver:
    """ComfyUI output node that writes a flat XYZRGB point cloud to PLY."""

    CATEGORY = SPAR3D_CATEGORY
    FUNCTION = "save_pointcloud"
    OUTPUT_NODE = True
    RETURN_TYPES = ()

    @classmethod
    def INPUT_TYPES(s):
        """Return the node's input declaration: point cloud and name prefix."""
        return {
            "required": {
                "pointcloud": ("POINTCLOUD",),
                "filename_prefix": ("STRING", {"default": "SPAR3D"}),
            }
        }

    def save_pointcloud(self, pointcloud, filename_prefix):
        """Write ``pointcloud`` to a ``.ply`` file in the output directory.

        Args:
            pointcloud: Flat sequence ``[x, y, z, r, g, b, ...]`` with colours
                in the 0-1 range, or ``None``.
            filename_prefix: Prefix handed to
                ``folder_paths.get_save_image_path``.

        Returns:
            ``{"ui": {"text": ...}}`` describing the outcome.

        Raises:
            ValueError: From ``reshape`` if the length is not a multiple of 6.
        """
        if pointcloud is None:
            return {"ui": {"text": "No point cloud data to save"}}

        # Reshape the flat list into rows of XYZ + RGB.
        points = np.asarray(pointcloud, dtype=np.float64).reshape(-1, 6)

        # Convert RGB from 0-1 to explicit uint8 0-255 so the colour range is
        # unambiguous no matter how trimesh interprets float colour arrays.
        colors = np.clip(points[:, 3:] * 255, 0, 255).astype(np.uint8)
        ply_data = trimesh.PointCloud(vertices=points[:, :3], colors=colors)

        output_dir = folder_paths.get_output_directory()
        full_output_folder, filename, counter, _subfolder, _prefix = (
            folder_paths.get_save_image_path(filename_prefix, output_dir)
        )
        # The trailing underscore after the counter matches the GLB naming
        # used by SPAR3DSave; without it ComfyUI's counter scan does not
        # appear to recognise earlier files, so saves would overwrite.
        out_path = os.path.join(
            full_output_folder, f"{filename}_{counter:05}_.ply"
        )
        ply_data.export(out_path)
        return {"ui": {"text": f"Saved point cloud to {out_path}"}}
# Display title for each node identifier.
NODE_DISPLAY_NAME_MAPPINGS = {
    "SPAR3DLoader": "SPAR3D Loader",
    "SPAR3DPreview": "SPAR3D Preview",
    "SPAR3DSampler": "SPAR3D Sampler",
    "SPAR3DSave": "SPAR3D Save",
    "SPAR3DPointCloudLoader": "SPAR3D Point Cloud Loader",
    "SPAR3DPointCloudSaver": "SPAR3D Point Cloud Saver",
}

# Node identifier -> implementing class.
NODE_CLASS_MAPPINGS = {
    "SPAR3DLoader": SPAR3DLoader,
    "SPAR3DPreview": SPAR3DPreview,
    "SPAR3DSampler": SPAR3DSampler,
    "SPAR3DSave": SPAR3DSave,
    "SPAR3DPointCloudLoader": SPAR3DPointCloudLoader,
    "SPAR3DPointCloudSaver": SPAR3DPointCloudSaver,
}

# Relative directory name exported alongside the node mappings.
WEB_DIRECTORY = "./comfyui"

__all__ = ["NODE_CLASS_MAPPINGS", "NODE_DISPLAY_NAME_MAPPINGS", "WEB_DIRECTORY"]