import os
import time
import glob
import json
import argparse

# Select the headless EGL backend *before* the rendering-related packages
# below are imported: mesh_to_sdf's documentation asks for this variable to be
# set prior to importing the package on headless machines.
os.environ['PYOPENGL_PLATFORM'] = 'egl'

import yaml
import torch
import trimesh
import mesh2sdf.core
import numpy as np
import skimage.measure
import seaborn as sns
from scipy.spatial.transform import Rotation
from mesh_to_sdf import get_surface_point_cloud
from accelerate.utils import set_seed
from accelerate import Accelerator
from huggingface_hub.file_download import hf_hub_download
from huggingface_hub import list_repo_files

from primitive_anything.utils import path_mkdir, count_parameters
from primitive_anything.utils.logger import print_log

import spaces

# Fetch every file of the model repository into ./ckpt, skipping files that
# are already present.  The destination path is checked in addition to the
# working-directory path the original check used, so files downloaded by a
# previous run are actually recognised.
repo_id = "hyz317/PrimitiveAnything"
all_files = list_repo_files(repo_id, revision="main")
for file in all_files:
    if os.path.exists(file) or os.path.exists(os.path.join("./ckpt", file)):
        continue
    hf_hub_download(repo_id, file, local_dir="./ckpt")
hf_hub_download("Maikou/Michelangelo", "checkpoints/aligned_shape_latents/shapevae-256.ckpt", local_dir="./ckpt")


def parse_args():
    """Parse the command-line options of the demo.

    Returns:
        argparse.Namespace: namespace with ``input`` (file or directory to
        process) and ``log_path`` (output directory).
    """
    parser = argparse.ArgumentParser(description='Process 3D model files')
    parser.add_argument(
        '--input',
        type=str,
        default='./data/demo_glb/',
        help='Input file or directory path (default: ./data/demo_glb/)'
    )
    parser.add_argument(
        '--log_path',
        type=str,
        default='./results/demo',
        help='Output directory path (default: results/demo)'
    )
    return parser.parse_args()


def get_input_files(input_path):
    """Return the list of files designated by ``input_path``.

    Args:
        input_path (str): path to a single file or to a directory.

    Returns:
        list[str]: ``[input_path]`` for a file, or every entry directly inside
        the directory, sorted so the processing order is deterministic.

    Raises:
        ValueError: if ``input_path`` is neither a file nor a directory.
    """
    if os.path.isfile(input_path):
        return [input_path]
    if os.path.isdir(input_path):
        return sorted(glob.glob(os.path.join(input_path, '*')))
    raise ValueError(f"Input path {input_path} is neither a file nor a directory")


args = parse_args()

LOG_PATH = args.log_path
os.makedirs(LOG_PATH, exist_ok=True)
print(f"Output directory: {LOG_PATH}")

# Maps the type code read from the generated primitives to the file name of
# the corresponding basic-shape mesh.
CODE_SHAPE = {
    0: 'SM_GR_BS_CubeBevel_001.ply',
    1: 'SM_GR_BS_SphereSharp_001.ply',
    2: 'SM_GR_BS_CylinderSharp_001.ply',
}

# Maps a basic-shape file name to the numeric ``type_id`` written to the
# output JSON.
shapename_map = {
    'SM_GR_BS_CubeBevel_001.ply': 1101002001034001,
    'SM_GR_BS_SphereSharp_001.ply': 1101002001034010,
    'SM_GR_BS_CylinderSharp_001.ply': 1101002001034002,
}

#### config
bs_dir = 'data/basic_shapes_norm'
config_path = './configs/infer.yml'
AR_checkpoint_path = './ckpt/mesh-transformer.ckpt.60.pt'
temperature = 0.0

#### init model
# Load every basic-shape mesh once, keyed by file name, with UVs clipped to
# [0, 1] and the visual converted to a colour visual so it can be recoloured
# per primitive later on.
mesh_bs = {}
for bs_path in glob.glob(os.path.join(bs_dir, '*.ply')):
    bs_name = os.path.basename(bs_path)
    bs = trimesh.load(bs_path)
    bs.visual.uv = np.clip(bs.visual.uv, 0, 1)
    bs.visual = bs.visual.to_color()
    mesh_bs[bs_name] = bs


def create_model(cfg_model):
    """Instantiate the model described by ``cfg_model``.

    Args:
        cfg_model (dict): model configuration; the ``'name'`` entry selects
            the class via ``get_model`` and every other entry is forwarded as
            a keyword argument to its constructor.

    Returns:
        The constructed model instance.
    """
    # Work on a shallow copy so the caller's configuration keeps its 'name'.
    kwargs = dict(cfg_model)
    name = kwargs.pop('name')
    model = get_model(name)(**kwargs)
    print_log("Model '{}' init: nb_params={:,}, kwargs={}".format(name, count_parameters(model), kwargs))
    return model


from primitive_anything.primitive_transformer import PrimitiveTransformerDiscrete


def get_model(name):
    """Return the model class registered under ``name`` (KeyError if unknown)."""
    return {
        'discrete': PrimitiveTransformerDiscrete,
    }[name]


# NOTE(review): yaml.FullLoader is kept as in the original; the config is a
# local project file. Switch to yaml.safe_load if it can ever come from an
# untrusted source.
with open(config_path, mode='r') as fp:
    AR_train_cfg = yaml.load(fp, Loader=yaml.FullLoader)

AR_checkpoint = torch.load(AR_checkpoint_path)

transformer = create_model(AR_train_cfg['model'])
transformer.load_state_dict(AR_checkpoint)

device = torch.device('cuda')
accelerator = Accelerator(
    mixed_precision='fp16',
)
transformer = accelerator.prepare(transformer)
transformer.eval()
# These two attributes are moved to the GPU explicitly, as in the original.
transformer.bs_pc = transformer.bs_pc.cuda()
transformer.rotation_matrix_align_coord = transformer.rotation_matrix_align_coord.cuda()
print('model loaded to device')


def sample_surface_points(mesh, number_of_points=500000, surface_point_method='scan', sign_method='normal',
                          scan_count=100, scan_resolution=400, sample_point_count=10000000, return_gradients=False,
                          return_surface_pc_normals=False, normalized=False):
    """Sample points (optionally with normals) from the surface of ``mesh``.

    The surface point cloud is built with ``mesh_to_sdf.get_surface_point_cloud``
    and ``number_of_points`` samples are then drawn from it.

    Args:
        mesh: mesh handed to ``get_surface_point_cloud``.
        number_of_points (int): number of samples to return.
        surface_point_method (str): forwarded to ``get_surface_point_cloud``.
        sign_method (str): ``'normal'`` or ``'depth'``; forced to ``'normal'``
            when combined with ``surface_point_method='sample'``.
        scan_count (int): forwarded to ``get_surface_point_cloud``.
        scan_resolution (int): forwarded to ``get_surface_point_cloud``.
        sample_point_count (int): forwarded to ``get_surface_point_cloud``.
        return_gradients (bool): also requests normal computation.
        return_surface_pc_normals (bool): when True, each returned row is the
            point concatenated with its normal; otherwise only points from
            ``get_random_surface_points`` are returned.
        normalized (bool): when True a bounding radius of 1 is used.

    Returns:
        numpy.ndarray: the sampled surface points.
    """
    sample_start = time.time()
    if surface_point_method == 'sample' and sign_method == 'depth':
        print("Incompatible methods for sampling points and determining sign, using sign_method='normal' instead.")
        sign_method = 'normal'

    surface_start = time.time()
    bound_radius = 1 if normalized else None
    surface_point_cloud = get_surface_point_cloud(mesh, surface_point_method, bound_radius, scan_count, scan_resolution,
                                                  sample_point_count,
                                                  calculate_normals=sign_method == 'normal' or return_gradients)
    surface_end = time.time()
    print('surface point cloud time cost :', surface_end - surface_start)

    normal_start = time.time()
    if return_surface_pc_normals:
        rng = np.random.default_rng()
        assert surface_point_cloud.points.shape[0] == surface_point_cloud.normals.shape[0]
        # Sampling is done with replacement, so the cloud may hold fewer
        # points than requested.
        indices = rng.choice(surface_point_cloud.points.shape[0], number_of_points, replace=True)
        points = surface_point_cloud.points[indices]
        normals = surface_point_cloud.normals[indices]
        surface_points = np.concatenate([points, normals], axis=-1)
    else:
        surface_points = surface_point_cloud.get_random_surface_points(number_of_points, use_scans=True)
    normal_end = time.time()
    print('normal time cost :', normal_end - normal_start)
    sample_end = time.time()
    print('sample surface point time cost :', sample_end - sample_start)
    return surface_points


def normalize_vertices(vertices, scale=0.9):
    """Centre ``vertices`` on their bounding box and scale the longest side to ``2 * scale``.

    Args:
        vertices (numpy.ndarray): points, one per row.
        scale (float): target half-extent of the longest bounding-box side.

    Returns:
        tuple: ``(normalized_vertices, center, applied_scale)`` such that
        ``normalized_vertices / applied_scale + center`` restores the input.
    """
    bbmin, bbmax = vertices.min(0), vertices.max(0)
    center = (bbmin + bbmax) * 0.5
    scale = 2.0 * scale / (bbmax - bbmin).max()
    vertices = (vertices - center) * scale
    return vertices, center, scale


def export_to_watertight(normalized_mesh, octree_depth: int = 7):
    """Convert a (possibly non-watertight) mesh to a watertight one.

    An SDF is computed on a ``2 ** octree_depth`` grid with ``mesh2sdf`` and
    the iso-surface of its absolute value is extracted with marching cubes;
    the result is mapped back to the coordinate frame of the input mesh.

    Args:
        normalized_mesh (trimesh.Trimesh): input mesh.
        octree_depth (int): log2 of the SDF grid resolution.

    Returns:
        trimesh.Trimesh: watertight mesh.
    """
    size = 2 ** octree_depth
    level = 2 / size

    scaled_vertices, to_orig_center, to_orig_scale = normalize_vertices(normalized_mesh.vertices)

    sdf = mesh2sdf.core.compute(scaled_vertices, normalized_mesh.faces, size=size)

    vertices, faces, normals, _ = skimage.measure.marching_cubes(np.abs(sdf), level)

    # watertight mesh
    vertices = vertices / size * 2 - 1  # grid indices -> [-1, 1]
    vertices = vertices / to_orig_scale + to_orig_center  # undo normalize_vertices
    mesh = trimesh.Trimesh(vertices, faces, normals=normals)

    return mesh


def process_mesh_to_surface_pc(mesh_list, marching_cubes=False, dilated_offset=0.0, sample_num=10000):
    """Pre-process meshes and sample a point cloud with normals from each.

    Args:
        mesh_list (list): list of trimesh meshes.
        marching_cubes (bool): remesh each mesh with ``export_to_watertight``.
        dilated_offset (float): when > 0, vertices are pushed along their
            vertex normals by this amount.
        sample_num (int): number of surface samples per mesh.

    Returns:
        tuple: ``(pc_normal_list, return_mesh_list)`` - the sampled
        point+normal arrays and the processed meshes, in input order.
    """
    pc_normal_list = []
    return_mesh_list = []
    for mesh in mesh_list:
        if marching_cubes:
            mesh = export_to_watertight(mesh)
            print("MC over!")
        if dilated_offset > 0:
            new_vertices = mesh.vertices + mesh.vertex_normals * dilated_offset
            mesh.vertices = new_vertices
            print("dilate over!")

        mesh.merge_vertices()
        mesh.update_faces(mesh.unique_faces())
        mesh.fix_normals()

        return_mesh_list.append(mesh)

        pc_normal = np.asarray(sample_surface_points(mesh, sample_num, return_surface_pc_normals=True))
        pc_normal_list.append(pc_normal)
        print("process mesh success")
    return pc_normal_list, return_mesh_list


#### utils
def euler_to_quat(euler):
    """Convert intrinsic 'XYZ' Euler angles in degrees to a quaternion (scipy ``as_quat`` layout)."""
    return Rotation.from_euler('XYZ', euler, degrees=True).as_quat()


def SRT_quat_to_matrix(scale, quat, translation):
    """Build a 4x4 transform from a scale, a quaternion and a translation.

    The rotation matrix is multiplied element-wise by ``scale``; with a
    per-axis scale this scales column-wise, i.e. the scale is applied before
    the rotation.
    """
    rotation_matrix = Rotation.from_quat(quat).as_matrix()
    transform_matrix = np.eye(4)
    transform_matrix[:3, :3] = rotation_matrix * scale
    transform_matrix[:3, 3] = translation
    return transform_matrix


def write_output(primitives, name):
    """Write the generated primitives as a JSON description and a GLB scene.

    Args:
        primitives (dict): holds the tensors ``'scale'``, ``'rotation'``,
            ``'translation'`` and ``'type_code'``; they are squeezed and
            iterated primitive by primitive (assumes a batch of one -
            TODO confirm against the transformer output).
        name (str): stem used for ``output_{name}.json`` / ``output_{name}.glb``
            inside ``LOG_PATH``.

    Returns:
        tuple: ``(glb_path, out_json)``.
    """
    out_json = {}
    out_json['operation'] = 0
    out_json['type'] = 1
    out_json['scene_id'] = None

    new_group = []
    model_scene = trimesh.Scene()
    # One distinct colour per primitive slot.
    color_map = sns.color_palette("hls", primitives['type_code'].squeeze().shape[0])
    color_map = (np.array(color_map) * 255).astype("uint8")
    for idx, (scale, rotation, translation, type_code) in enumerate(zip(
        primitives['scale'].squeeze().cpu().numpy(),
        primitives['rotation'].squeeze().cpu().numpy(),
        primitives['translation'].squeeze().cpu().numpy(),
        primitives['type_code'].squeeze().cpu().numpy()
    )):
        # A type code of -1 ends the sequence of primitives.
        if type_code == -1:
            break
        bs_name = CODE_SHAPE[type_code]
        quat = euler_to_quat(rotation)

        new_block = {}
        new_block['type_id'] = shapename_map[bs_name]
        new_block['data'] = {}
        new_block['data']['location'] = translation.tolist()
        new_block['data']['rotation'] = quat.tolist()
        new_block['data']['scale'] = scale.tolist()
        new_block['data']['color'] = ['808080']
        new_group.append(new_block)

        trans = SRT_quat_to_matrix(scale, quat, translation)
        bs = mesh_bs[bs_name].copy().apply_transform(trans)
        new_vertex_colors = np.repeat(color_map[idx:idx+1], bs.visual.vertex_colors.shape[0], axis=0)
        bs.visual.vertex_colors[:, :3] = new_vertex_colors
        # Re-orient for export: (x, y, z) -> (x, z, -y).
        # NOTE(review): presumably a Z-up to Y-up conversion for the viewer - confirm.
        vertices = bs.vertices.copy()
        vertices[:, 1] = bs.vertices[:, 2]
        vertices[:, 2] = -bs.vertices[:, 1]
        bs.vertices = vertices
        model_scene.add_geometry(bs)
    out_json['group'] = new_group

    json_path = os.path.join(LOG_PATH, f'output_{name}.json')
    with open(json_path, 'w') as json_file:
        json.dump(out_json, json_file, indent=4)

    glb_path = os.path.join(LOG_PATH, f'output_{name}.glb')
    model_scene.export(glb_path)

    return glb_path, out_json


def _bbox_normalize(points, bounds, extent=1.6):
    """Centre ``points`` on the box ``bounds`` (row 0 = min, row 1 = max) and scale its longest side to ``extent``."""
    points = points - (bounds[0] + bounds[1])[None, :] / 2
    points = points / (bounds[1] - bounds[0]).max() * extent
    return points


@torch.no_grad()
def do_inference(input_3d, dilated_offset=0.0, sample_seed=0, do_sampling=False, do_marching_cubes=False, postprocess='none'):
    """Run the primitive-assembly pipeline on one 3D model file.

    Args:
        input_3d (str): path of the model file, loaded with ``trimesh.load``.
        dilated_offset (float): dilation applied along vertex normals; when
            > 0 the mesh and point cloud are re-normalised afterwards.
        sample_seed (int): seed passed to ``set_seed``.
        do_sampling (bool): unused; kept for interface compatibility.
        do_marching_cubes (bool): remesh the input to a watertight mesh first.
        postprocess (str): ``'postprocess1'`` selects
            ``transformer.generate_w_recon_loss``; any other value uses
            ``transformer.generate``.

    Returns:
        tuple: ``(processed_mesh_path, output_glb_path, output_json)``.
    """
    set_seed(sample_seed)
    input_mesh = trimesh.load(input_3d, force='mesh')

    # scale mesh
    vertices = input_mesh.vertices
    bounds = np.array([vertices.min(axis=0), vertices.max(axis=0)])
    input_mesh.vertices = _bbox_normalize(vertices, bounds)

    pc_list, mesh_list = process_mesh_to_surface_pc(
        [input_mesh],
        marching_cubes=do_marching_cubes,
        dilated_offset=dilated_offset
    )
    pc_normal = pc_list[0]  # 10000, 6
    mesh = mesh_list[0]

    pc_coor = pc_normal[:, :3]
    normals = pc_normal[:, 3:]

    if dilated_offset > 0:
        # scale mesh and pc: dilation grew the bounding box, so both are
        # normalised again using the bounds of the dilated mesh.
        vertices = mesh.vertices
        bounds = np.array([vertices.min(axis=0), vertices.max(axis=0)])
        mesh.vertices = _bbox_normalize(vertices, bounds)
        pc_coor = _bbox_normalize(pc_coor, bounds)

    input_basename = os.path.basename(input_3d)
    input_save_name = os.path.join(LOG_PATH, f'processed_{input_basename}')
    mesh.export(input_save_name)

    assert (np.linalg.norm(normals, axis=-1) > 0.99).all(), 'normals should be unit vectors, something wrong'
    normalized_pc_normal = np.concatenate([pc_coor, normals], axis=-1, dtype=np.float16)

    input_pc = torch.tensor(normalized_pc_normal, dtype=torch.float16, device=device)[None]

    with accelerator.autocast():
        if postprocess == 'postprocess1':
            recon_primitives, mask = transformer.generate_w_recon_loss(pc=input_pc, temperature=temperature, single_directional=True)
        else:
            recon_primitives, mask = transformer.generate(pc=input_pc, temperature=temperature)

    # Strip the real extension; slicing off four characters broke names whose
    # extension is not exactly three characters long (e.g. ".gltf").
    output_stem = os.path.splitext(input_basename)[0]
    output_glb, output_json = write_output(recon_primitives, output_stem)

    return input_save_name, output_glb, output_json


import gradio as gr


@spaces.GPU
def process_3d_model(input_3d, dilated_offset, do_marching_cubes, postprocess_method="postprocess1"):
    """Gradio callback: run inference on the uploaded model and return the GLB path.

    Args:
        input_3d (str): path of the uploaded model file (``None`` when nothing
            was uploaded).
        dilated_offset (float): forwarded to ``do_inference``.
        do_marching_cubes (bool): forwarded to ``do_inference``.
        postprocess_method (str): forwarded as ``postprocess``.

    Returns:
        str: path of the generated GLB file.

    Raises:
        gr.Error: if no model file was provided.
    """
    if input_3d is None:
        raise gr.Error("Please upload a 3D model file first.")
    print(f"processing: {input_3d}")
    preprocess_model_obj, output_model_obj, output_model_json = do_inference(
        input_3d,
        dilated_offset=dilated_offset,
        do_marching_cubes=do_marching_cubes,
        postprocess=postprocess_method
    )
    return output_model_obj


_HEADER_ = '''

[SIGGRAPH 2025] PrimitiveAnything 🤗 Gradio Demo

This is official demo for our SIGGRAPH 2025 paper PrimitiveAnything: Human-Crafted 3D Primitive Assembly Generation with Auto-Regressive Transformer.

Code: GitHub. Paper: ArXiv.

❗️❗️❗️**Important Notes:**
- Currently our demo supports 3D models only. You can use other text- and image-conditioned models (e.g. [Tencent Hunyuan3D](https://huggingface.co/spaces/tencent/Hunyuan3D-2) or [TRELLIS](https://huggingface.co/spaces/theseanlavery/TRELLIS-3D)) to generate 3D models and then upload them here.
- For optimal results with fine structures, we apply marching cubes and dilation operations by default (which differs from testing and evaluation). This prevents quality degradation in thin areas.
'''

_CITE_ = r"""
If PrimitiveAnything is helpful, please help to ⭐ the GitHub Repo. Thanks! [![GitHub Stars](https://img.shields.io/github/stars/PrimitiveAnything/PrimitiveAnything?style=social)](https://github.com/PrimitiveAnything/PrimitiveAnything)

---

📝 **Citation**

If you find our work useful for your research or applications, please cite using this bibtex:
```bibtex
@misc{ye2025primitiveanything,
      title={PrimitiveAnything: Human-Crafted 3D Primitive Assembly Generation with Auto-Regressive Transformer},
      author={Jingwen Ye and Yuze He and Yanning Zhou and Yiqin Zhu and Kaiwen Xiao and Yong-Jin Liu and Wei Yang and Xiao Han},
      year={2025},
      eprint={2505.04622},
      archivePrefix={arXiv},
      primaryClass={cs.GR}
}
```

📧 **Contact**

If you have any questions, feel free to open a discussion or contact us at hyz22@mails.tsinghua.edu.cn.
"""

with gr.Blocks(title="PrimitiveAnything: Human-Crafted 3D Primitive Assembly Generation with Auto-Regressive Transformer") as demo:
    # Title section
    gr.Markdown(_HEADER_)

    with gr.Row():
        with gr.Column():
            # Input components
            input_3d = gr.Model3D(label="Upload 3D Model File")
            dilated_offset = gr.Number(label="Dilated Offset", value=0.015)
            do_marching_cubes = gr.Checkbox(label="Perform Marching Cubes", value=True)
            submit_btn = gr.Button("Process Model")

        with gr.Column():
            # Output components
            output = gr.Model3D(label="Primitive Assembly Prediction")

    submit_btn.click(
        fn=process_3d_model,
        inputs=[input_3d, dilated_offset, do_marching_cubes],
        outputs=output
    )

    # Prepare examples: one single-element row per GLB file (only the
    # Model3D input is populated), sorted for a stable gallery order.
    example_files = [[f] for f in sorted(glob.glob('./data/demo_glb/*.glb'))]

    example = gr.Examples(
        examples=example_files,
        inputs=[input_3d],  # Only include the Model3D input
        examples_per_page=14,
    )

    gr.Markdown(_CITE_)

if __name__ == "__main__":
    demo.launch(ssr_mode=False)