"""ComfyUI custom nodes for SPAR3D: model loading, sampling, preview and I/O.

Two custom socket types are exchanged between the nodes in this module:

* ``MESH``: a list of objects that ``trimesh.Scene`` accepts.
* ``POINTCLOUD``: a flat list of floats.  The loader and saver nodes use six
  values per point, interleaved as ``x, y, z, r, g, b`` with colours in the
  0-1 range.
"""

import base64
import logging
import os
import random
import sys

import comfy.model_management
import folder_paths
import numpy as np
import torch
import trimesh
from PIL import Image
from trimesh.exchange import gltf

# Make the bundled ``spar3d`` package importable from this node's directory.
sys.path.append(os.path.dirname(__file__))
from spar3d.models.mesh import QUAD_REMESH_AVAILABLE, TRIANGLE_REMESH_AVAILABLE
from spar3d.system import SPAR3D
from spar3d.utils import foreground_crop

SPAR3D_CATEGORY = "SPAR3D"
SPAR3D_MODEL_NAME = "stabilityai/spar3d"


class SPAR3DLoader:
    """Node that loads the pretrained SPAR3D model onto the ComfyUI device."""

    CATEGORY = SPAR3D_CATEGORY
    FUNCTION = "load"
    RETURN_NAMES = ("spar3d_model",)
    RETURN_TYPES = ("SPAR3D_MODEL",)

    @classmethod
    def INPUT_TYPES(cls):
        """Declare the node inputs; this node takes none."""
        return {"required": {}}

    def load(self):
        """Load ``SPAR3D_MODEL_NAME``, move it to the torch device, set eval mode.

        Returns:
            A 1-tuple containing the model.
        """
        device = comfy.model_management.get_torch_device()
        model = SPAR3D.from_pretrained(
            SPAR3D_MODEL_NAME,
            config_name="config.yaml",
            weight_name="model.safetensors",
        )
        model.to(device)
        model.eval()
        return (model,)


class SPAR3DPreview:
    """Output node that sends meshes to the UI as base64-encoded GLB blobs."""

    CATEGORY = SPAR3D_CATEGORY
    FUNCTION = "preview"
    OUTPUT_NODE = True
    RETURN_TYPES = ()

    @classmethod
    def INPUT_TYPES(s):
        """Declare the single required ``mesh`` input."""
        return {"required": {"mesh": ("MESH",)}}

    def preview(self, mesh):
        """Export every mesh to GLB and return the base64 strings for the UI.

        Args:
            mesh: Iterable of meshes; each is wrapped in a ``trimesh.Scene``.

        Returns:
            ``{"ui": {"glbs": [...]}}`` with one base64 string per mesh.
        """
        glbs = [
            base64.b64encode(
                gltf.export_glb(trimesh.Scene(m), include_normals=True)
            ).decode("utf-8")
            for m in mesh
        ]
        return {"ui": {"glbs": glbs}}


class SPAR3DSampler:
    """Node that runs SPAR3D on a single image and returns a mesh + point cloud."""

    CATEGORY = SPAR3D_CATEGORY
    FUNCTION = "predict"
    RETURN_NAMES = ("mesh", "pointcloud")
    RETURN_TYPES = ("MESH", "POINTCLOUD")

    @classmethod
    def INPUT_TYPES(s):
        """Declare node inputs.

        The ``remesh`` option is only exposed when at least one remeshing
        backend was importable, and then only lists the available backends.
        """
        remesh_choices = ["none"]
        if TRIANGLE_REMESH_AVAILABLE:
            remesh_choices.append("triangle")
        if QUAD_REMESH_AVAILABLE:
            remesh_choices.append("quad")

        opt_dict = {
            "mask": ("MASK",),
            "pointcloud": ("POINTCLOUD",),
            "target_type": (["none", "vertex", "face"],),
            "target_count": (
                "INT",
                {"default": 1000, "min": 3, "max": 20000, "step": 1},
            ),
            "guidance_scale": (
                "FLOAT",
                {"default": 3.0, "min": 1.0, "max": 5.0, "step": 0.05},
            ),
            "seed": (
                "INT",
                {"default": 42, "min": 0, "max": 2**32 - 1, "step": 1},
            ),
        }
        if TRIANGLE_REMESH_AVAILABLE or QUAD_REMESH_AVAILABLE:
            opt_dict["remesh"] = (remesh_choices,)

        return {
            "required": {
                "model": ("SPAR3D_MODEL",),
                "image": ("IMAGE",),
                "foreground_ratio": (
                    "FLOAT",
                    {"default": 1.3, "min": 1.0, "max": 2.0, "step": 0.01},
                ),
                "texture_resolution": (
                    "INT",
                    {"default": 1024, "min": 512, "max": 2048, "step": 256},
                ),
            },
            "optional": opt_dict,
        }

    def predict(
        s,
        model,
        image,
        mask=None,
        foreground_ratio=1.3,
        texture_resolution=1024,
        pointcloud=None,
        remesh="none",
        target_type="none",
        target_count=1000,
        guidance_scale=3.0,
        seed=42,
    ):
        """Reconstruct a textured mesh from one image.

        ``mask`` is declared as an *optional* input, so it must have a default:
        without one, running the node with no mask connected fails before this
        method body is reached.  ``foreground_ratio`` and ``texture_resolution``
        receive the same defaults as in ``INPUT_TYPES`` purely so the parameter
        order can stay unchanged.

        Args:
            model: Model object produced by ``SPAR3DLoader``.
            image: Image tensor; indexed as ``image[0]`` and ``image.shape[3]``
                (channel count), i.e. a 4-D batch with channels last.
            mask: Optional mask; ``mask[0]`` is scaled by 255 and used as the
                alpha channel.  When ``None`` the image's own alpha is used, or
                the image is converted to RGBA if it has no alpha channel.
            foreground_ratio: Forwarded to ``foreground_crop``.
            texture_resolution: Forwarded as ``bake_resolution``.
            pointcloud: Optional point cloud forwarded to the model.
            remesh: ``"none"``, ``"triangle"`` or ``"quad"``.
            target_type: ``"none"``, ``"vertex"`` or ``"face"``.
            target_count: Desired vertex/face count; halved for ``"face"``.
            guidance_scale: Written to ``model.cfg.guidance_scale``.
            seed: Seed applied to ``random``, ``torch`` and ``numpy``.

        Returns:
            ``([mesh], pointcloud)`` where ``pointcloud`` is the model's point
            cloud flattened to a plain list.

        Raises:
            ValueError: If the batch holds more than one image, or the
                resulting mesh has no vertices.
            ImportError: If the requested remesh backend is unavailable.
        """
        if image.shape[0] != 1:
            raise ValueError("Only one image can be processed at a time")

        if target_type == "none":
            vertex_count = -1
        elif target_type == "face":
            vertex_count = target_count // 2
        else:
            vertex_count = target_count

        pil_image = Image.fromarray(
            torch.clamp(torch.round(255.0 * image[0]), 0, 255)
            .type(torch.uint8)
            .cpu()
            .numpy()
        )

        if mask is not None:
            logging.info("Using Mask")
            mask_np = np.clip(255.0 * mask[0].detach().cpu().numpy(), 0, 255).astype(
                np.uint8
            )
            mask_pil = Image.fromarray(mask_np, mode="L")
            pil_image.putalpha(mask_pil)
        elif image.shape[3] != 4:
            logging.info("No mask or alpha channel detected, Converting to RGBA")
            pil_image = pil_image.convert("RGBA")

        pil_image = foreground_crop(pil_image, foreground_ratio)

        model.cfg.guidance_scale = guidance_scale
        random.seed(seed)
        torch.manual_seed(seed)
        np.random.seed(seed)

        logging.info("Remesh mode: %s", remesh)

        # Validate the remesh backend before entering the inference contexts.
        if not TRIANGLE_REMESH_AVAILABLE and remesh == "triangle":
            raise ImportError("Triangle remeshing requires gpytoolbox to be installed")
        if not QUAD_REMESH_AVAILABLE and remesh == "quad":
            raise ImportError("Quad remeshing requires pynim to be installed")

        with torch.no_grad(), torch.autocast(device_type="cuda", dtype=torch.float16):
            mesh, glob_dict = model.run_image(
                pil_image,
                bake_resolution=texture_resolution,
                pointcloud=pointcloud,
                remesh=remesh,
                vertex_count=vertex_count,
            )

        if mesh.vertices.shape[0] == 0:
            raise ValueError("No subject detected in the image")

        return (
            [mesh],
            glob_dict["pointcloud"].view(-1).detach().cpu().numpy().tolist(),
        )


class SPAR3DSave:
    """Output node that writes meshes to the output directory as ``.glb``."""

    CATEGORY = SPAR3D_CATEGORY
    FUNCTION = "save"
    OUTPUT_NODE = True
    RETURN_TYPES = ()

    @classmethod
    def INPUT_TYPES(s):
        """Declare the ``mesh`` and ``filename_prefix`` inputs."""
        return {
            "required": {
                "mesh": ("MESH",),
                "filename_prefix": ("STRING", {"default": "SPAR3D"}),
            }
        }

    def __init__(self):
        self.type = "output"

    def save(self, mesh, filename_prefix):
        """Export each mesh to a numbered GLB file and echo it to the UI.

        Args:
            mesh: Iterable of meshes; each is wrapped in a ``trimesh.Scene``.
            filename_prefix: Prefix handed to
                ``folder_paths.get_save_image_path``.

        Returns:
            ``{"ui": {"glbs": [...]}}`` with one base64 string per mesh.
        """
        output_dir = folder_paths.get_output_directory()
        glbs = []
        for idx, m in enumerate(mesh):
            scene = trimesh.Scene(m)
            glb_data = gltf.export_glb(scene, include_normals=True)
            logging.info(f"Generated GLB model with {len(glb_data)} bytes")

            # Re-queried on every iteration so the counter reflects the file
            # written by the previous iteration.
            full_output_folder, filename, counter, subfolder, filename_prefix = (
                folder_paths.get_save_image_path(filename_prefix, output_dir)
            )
            filename = filename.replace("%batch_num%", str(idx))
            out_path = os.path.join(full_output_folder, f"{filename}_{counter:05}_.glb")
            with open(out_path, "wb") as f:
                f.write(glb_data)
            glbs.append(base64.b64encode(glb_data).decode("utf-8"))
        return {"ui": {"glbs": glbs}}


class SPAR3DPointCloudLoader:
    """Node that reads a geometry file and emits a flat XYZRGB point list."""

    CATEGORY = SPAR3D_CATEGORY
    FUNCTION = "load_pointcloud"
    RETURN_TYPES = ("POINTCLOUD",)
    RETURN_NAMES = ("pointcloud",)

    @classmethod
    def INPUT_TYPES(cls):
        """Declare the single ``file`` path input."""
        return {
            "required": {
                "file": ("STRING", {"default": None}),
            }
        }

    def load_pointcloud(self, file):
        """Load ``file`` with trimesh and interleave vertices with colours.

        Args:
            file: Path passed to ``trimesh.load``.  ``None`` or an empty string
                yields ``(None,)``.

        Returns:
            A 1-tuple holding a flat list of floats, six per point
            (``x, y, z, r, g, b``), with colours scaled from 0-255 to 0-1.
            Points default to white when the loaded object has no vertex
            colours.
        """
        if not file:
            return (None,)

        mesh = trimesh.load(file)
        vertices = np.asarray(mesh.vertices, dtype=np.float64)

        vertex_colors = mesh.visual.vertex_colors
        if vertex_colors is not None:
            # Drop alpha and convert 0-255 to the 0-1 range.
            colors = np.asarray(vertex_colors)[:, :3] / 255.0
        else:
            colors = np.ones((len(vertices), 3))

        # Pair points with colours up to the shorter of the two arrays, then
        # flatten row-major so each point contributes x, y, z, r, g, b.
        count = min(len(vertices), len(colors))
        interleaved = np.hstack([vertices[:count, :3], colors[:count, :3]])
        return (interleaved.astype(np.float64).ravel().tolist(),)


class SPAR3DPointCloudSaver:
    """Output node that writes a flat XYZRGB point list to a ``.ply`` file."""

    CATEGORY = SPAR3D_CATEGORY
    FUNCTION = "save_pointcloud"
    OUTPUT_NODE = True
    RETURN_TYPES = ()

    @classmethod
    def INPUT_TYPES(s):
        """Declare the ``pointcloud`` and ``filename_prefix`` inputs."""
        return {
            "required": {
                "pointcloud": ("POINTCLOUD",),
                "filename_prefix": ("STRING", {"default": "SPAR3D"}),
            }
        }

    def save_pointcloud(self, pointcloud, filename_prefix):
        """Write ``pointcloud`` to a numbered PLY file in the output directory.

        Args:
            pointcloud: Flat sequence of floats, six per point
                (``x, y, z, r, g, b``) with colours in the 0-1 range, or
                ``None``.
            filename_prefix: Prefix handed to
                ``folder_paths.get_save_image_path``.

        Returns:
            ``{"ui": {"text": ...}}`` describing the outcome.

        Raises:
            ValueError: If the number of values is not a multiple of six
                (raised by ``reshape``).
        """
        if pointcloud is None:
            return {"ui": {"text": "No point cloud data to save"}}

        # Reshape the flat list into rows of XYZ + RGB.
        points = np.asarray(pointcloud, dtype=np.float64).reshape(-1, 6)

        # Hand trimesh explicit 8-bit colours so the 0-255 values cannot be
        # misread as 0-1 floats; clip first so out-of-range values saturate
        # instead of wrapping on the integer cast.
        colors = np.clip(points[:, 3:] * 255, 0, 255).astype(np.uint8)
        ply_data = trimesh.PointCloud(vertices=points[:, :3], colors=colors)

        output_dir = folder_paths.get_output_directory()
        full_output_folder, filename, counter, subfolder, filename_prefix = (
            folder_paths.get_save_image_path(filename_prefix, output_dir)
        )
        # The trailing underscore after the counter mirrors SPAR3DSave.
        # NOTE(review): ComfyUI appears to derive the next counter from
        # existing "<prefix>_<digits>_..." names, so omitting it would make
        # every save reuse the same file name -- confirm against folder_paths.
        out_path = os.path.join(full_output_folder, f"{filename}_{counter:05}_.ply")
        ply_data.export(out_path)

        return {"ui": {"text": f"Saved point cloud to {out_path}"}}


NODE_DISPLAY_NAME_MAPPINGS = {
    "SPAR3DLoader": "SPAR3D Loader",
    "SPAR3DPreview": "SPAR3D Preview",
    "SPAR3DSampler": "SPAR3D Sampler",
    "SPAR3DSave": "SPAR3D Save",
    "SPAR3DPointCloudLoader": "SPAR3D Point Cloud Loader",
    "SPAR3DPointCloudSaver": "SPAR3D Point Cloud Saver",
}

NODE_CLASS_MAPPINGS = {
    "SPAR3DLoader": SPAR3DLoader,
    "SPAR3DPreview": SPAR3DPreview,
    "SPAR3DSampler": SPAR3DSampler,
    "SPAR3DSave": SPAR3DSave,
    "SPAR3DPointCloudLoader": SPAR3DPointCloudLoader,
    "SPAR3DPointCloudSaver": SPAR3DPointCloudSaver,
}

WEB_DIRECTORY = "./comfyui"

__all__ = ["NODE_CLASS_MAPPINGS", "NODE_DISPLAY_NAME_MAPPINGS", "WEB_DIRECTORY"]