# Spaces: Running on Zero
import os | |
import time | |
import glob | |
import json | |
import yaml | |
import torch | |
import trimesh | |
import argparse | |
import mesh2sdf.core | |
import numpy as np | |
import skimage.measure | |
import seaborn as sns | |
from scipy.spatial.transform import Rotation | |
from mesh_to_sdf import get_surface_point_cloud | |
from accelerate.utils import set_seed | |
from accelerate import Accelerator | |
from huggingface_hub.file_download import hf_hub_download | |
from huggingface_hub import list_repo_files | |
from primitive_anything.utils import path_mkdir, count_parameters | |
from primitive_anything.utils.logger import print_log | |
os.environ['PYOPENGL_PLATFORM'] = 'egl' | |
import spaces | |
# Hugging Face repository whose files are fetched into ./ckpt at import time.
repo_id = "hyz317/PrimitiveAnything"
all_files = list_repo_files(repo_id, revision="main")
for file in all_files:
    # NOTE(review): this checks the repo-relative path against the current
    # working directory, while downloads are written under ./ckpt — confirm
    # that skipping on a CWD match is the intended behavior.
    if os.path.exists(file):
        continue
    hf_hub_download(repo_id, file, local_dir="./ckpt")
# Also fetch the shapevae-256 checkpoint from the Maikou/Michelangelo repository.
hf_hub_download("Maikou/Michelangelo", "checkpoints/aligned_shape_latents/shapevae-256.ckpt", local_dir="./ckpt")
def parse_args():
    """Parse the command-line options of the demo.

    Returns:
        argparse.Namespace with two string attributes: ``input`` (file or
        directory to process) and ``log_path`` (output directory).
    """
    cli = argparse.ArgumentParser(description='Process 3D model files')
    # (flag, default, help) for every supported string option.
    options = (
        ('--input', './data/demo_glb/',
         'Input file or directory path (default: ./data/demo_glb/)'),
        ('--log_path', './results/demo',
         'Output directory path (default: results/demo)'),
    )
    for flag, default, help_text in options:
        cli.add_argument(flag, type=str, default=default, help=help_text)
    return cli.parse_args()
def get_input_files(input_path):
    """Resolve *input_path* to a list of paths.

    A file path yields a one-element list; a directory yields every entry
    matched by ``<dir>/*``.

    Raises:
        ValueError: if *input_path* is neither an existing file nor a directory.
    """
    if os.path.isfile(input_path):
        return [input_path]
    if os.path.isdir(input_path):
        pattern = os.path.join(input_path, '*')
        return glob.glob(pattern)
    raise ValueError(f"Input path {input_path} is neither a file nor a directory")
# Parse the CLI options at import time and make sure the output directory
# exists; LOG_PATH is read by write_output and do_inference below.
args = parse_args()
LOG_PATH = args.log_path
os.makedirs(LOG_PATH, exist_ok=True)
print(f"Output directory: {LOG_PATH}")
# Maps the integer primitive type code to the basic-shape .ply file name.
CODE_SHAPE = {
    0: 'SM_GR_BS_CubeBevel_001.ply',
    1: 'SM_GR_BS_SphereSharp_001.ply',
    2: 'SM_GR_BS_CylinderSharp_001.ply',
}
# Maps each basic-shape file name to the numeric id written as 'type_id'
# in the output JSON.
shapename_map = {
    'SM_GR_BS_CubeBevel_001.ply': 1101002001034001,
    'SM_GR_BS_SphereSharp_001.ply': 1101002001034010,
    'SM_GR_BS_CylinderSharp_001.ply': 1101002001034002,
}
#### config
# Directory holding the basic-shape .ply meshes.
bs_dir = 'data/basic_shapes_norm'
# YAML file providing the model configuration (its 'model' section is used).
config_path = './configs/infer.yml'
# Weights loaded into the transformer.
AR_checkpoint_path = './ckpt/mesh-transformer.ckpt.60.pt'
# Sampling temperature passed to the transformer's generate calls.
temperature= 0.0
#### init model
# Load every .ply basic shape found in bs_dir, keyed by its file name.
mesh_bs = {}
for bs_path in glob.glob(os.path.join(bs_dir, '*.ply')):
    bs_name = os.path.basename(bs_path)
    bs = trimesh.load(bs_path)
    # Clamp UV coordinates into [0, 1] before converting the visual with to_color().
    bs.visual.uv = np.clip(bs.visual.uv, 0, 1)
    bs.visual = bs.visual.to_color()
    mesh_bs[bs_name] = bs
def create_model(cfg_model):
    """Instantiate the model described by *cfg_model*.

    Args:
        cfg_model: mapping with a ``'name'`` key selecting the model class
            (resolved through ``get_model``); every other entry is forwarded
            to the class constructor as a keyword argument.

    Returns:
        The constructed model instance.

    Raises:
        KeyError: if ``'name'`` is missing from *cfg_model*.
    """
    # Work on a shallow copy so popping 'name' does not mutate the caller's
    # configuration dictionary.
    kwargs = dict(cfg_model)
    name = kwargs.pop('name')
    model = get_model(name)(**kwargs)
    print_log("Model '{}' init: nb_params={:,}, kwargs={}".format(name, count_parameters(model), kwargs))
    return model
from primitive_anything.primitive_transformer import PrimitiveTransformerDiscrete | |
def get_model(name):
    """Return the model class registered under *name*.

    Raises:
        KeyError: if *name* is not a registered model name.
    """
    registry = {'discrete': PrimitiveTransformerDiscrete}
    return registry[name]
# Read the configuration that describes the model to build.
# NOTE(review): yaml.FullLoader is acceptable for a trusted local config file;
# switch to yaml.safe_load if this path can ever point at untrusted input.
with open(config_path, mode='r') as fp:
    AR_train_cfg = yaml.load(fp, Loader=yaml.FullLoader)

# Deserialize the checkpoint onto the CPU: without map_location, torch.load
# restores tensors to the device they were saved from, which fails when that
# device is unavailable and otherwise causes a GPU memory spike. The model is
# moved to the target device by accelerator.prepare() below.
AR_checkpoint = torch.load(AR_checkpoint_path, map_location='cpu')

transformer = create_model(AR_train_cfg['model'])
transformer.load_state_dict(AR_checkpoint)

device = torch.device('cuda')
accelerator = Accelerator(
    mixed_precision='fp16',
)
transformer = accelerator.prepare(transformer)
transformer.eval()
# These two attributes are moved to CUDA explicitly here.
# NOTE(review): presumably they are plain tensors not handled by prepare() — confirm.
transformer.bs_pc = transformer.bs_pc.cuda()
transformer.rotation_matrix_align_coord = transformer.rotation_matrix_align_coord.cuda()
print('model loaded to device')
def sample_surface_points(mesh, number_of_points=500000, surface_point_method='scan', sign_method='normal',
                          scan_count=100, scan_resolution=400, sample_point_count=10000000, return_gradients=False,
                          return_surface_pc_normals=False, normalized=False):
    """Sample surface points of *mesh* through ``get_surface_point_cloud``.

    Args:
        mesh: mesh forwarded to ``get_surface_point_cloud``.
        number_of_points: how many points to return.
        surface_point_method, scan_count, scan_resolution, sample_point_count:
            forwarded unchanged to ``get_surface_point_cloud``.
        sign_method: forced to ``'normal'`` when combined with
            ``surface_point_method='sample'`` and the value ``'depth'``.
        return_gradients: together with ``sign_method`` decides whether
            normals are requested from the point cloud.
        return_surface_pc_normals: when true, return points concatenated with
            their normals along the last axis, drawn with replacement.
        normalized: when true, a bounding radius of 1 is passed on; otherwise None.

    Returns:
        The sampled points (with normals appended when requested).
    """
    t_total = time.time()

    incompatible = surface_point_method == 'sample' and sign_method == 'depth'
    if incompatible:
        print("Incompatible methods for sampling points and determining sign, using sign_method='normal' instead.")
        sign_method = 'normal'

    t_cloud = time.time()
    need_normals = sign_method == 'normal' or return_gradients
    cloud = get_surface_point_cloud(
        mesh, surface_point_method, 1 if normalized else None, scan_count, scan_resolution,
        sample_point_count, calculate_normals=need_normals,
    )
    print('surface point cloud time cost :', time.time() - t_cloud)

    t_pick = time.time()
    if not return_surface_pc_normals:
        surface_points = cloud.get_random_surface_points(number_of_points, use_scans=True)
    else:
        generator = np.random.default_rng()
        total = cloud.points.shape[0]
        assert total == cloud.normals.shape[0]
        # Draw indices with replacement so any requested count can be served.
        picked = generator.choice(total, number_of_points, replace=True)
        surface_points = np.concatenate([cloud.points[picked], cloud.normals[picked]], axis=-1)
    print('normal time cost :', time.time() - t_pick)

    print('sample surface point time cost :', time.time() - t_total)
    return surface_points
def normalize_vertices(vertices, scale=0.9):
    """Center *vertices* on their bounding box and rescale them uniformly.

    The longest bounding-box edge is mapped to a length of ``2 * scale``.

    Returns:
        Tuple ``(normalized_vertices, center, factor)`` where *center* is the
        bounding-box center that was subtracted and *factor* the multiplier
        that was applied.
    """
    lower = vertices.min(0)
    upper = vertices.max(0)
    center = (lower + upper) * 0.5
    factor = 2.0 * scale / (upper - lower).max()
    return (vertices - center) * factor, center, factor
def export_to_watertight(normalized_mesh, octree_depth: int = 7):
    """Rebuild *normalized_mesh* as a marching-cubes surface of its distance field.

    Args:
        normalized_mesh: mesh exposing ``vertices`` and ``faces``.
        octree_depth (int): the distance field is computed on a grid with
            ``2 ** octree_depth`` cells per axis.

    Returns:
        trimesh.Trimesh: the extracted surface, mapped back to the coordinate
        frame of the input vertices.
    """
    grid_size = 2 ** octree_depth
    iso_level = 2 / grid_size

    unit_vertices, center, factor = normalize_vertices(normalized_mesh.vertices)
    sdf_grid = mesh2sdf.core.compute(unit_vertices, normalized_mesh.faces, size=grid_size)
    mc_vertices, mc_faces, mc_normals, _ = skimage.measure.marching_cubes(np.abs(sdf_grid), iso_level)

    # Grid indices -> [-1, 1], then undo the normalization applied above.
    mc_vertices = mc_vertices / grid_size * 2 - 1  # -1 to 1
    mc_vertices = mc_vertices / factor + center
    return trimesh.Trimesh(mc_vertices, mc_faces, normals=mc_normals)
def process_mesh_to_surface_pc(mesh_list, marching_cubes=False, dilated_offset=0.0, sample_num=10000):
    """Clean up each mesh and sample a point cloud with normals from it.

    Args:
        mesh_list: list of trimesh meshes.
        marching_cubes: when true, each mesh is first replaced by the result
            of ``export_to_watertight``.
        dilated_offset: when positive, vertices are pushed along their vertex
            normals by this amount.
        sample_num: number of surface points sampled per mesh.

    Returns:
        Tuple ``(pc_normal_list, return_mesh_list)``: the sampled arrays and
        the processed meshes, in input order.
    """
    sampled_clouds = []
    processed_meshes = []
    for current in mesh_list:
        if marching_cubes:
            current = export_to_watertight(current)
            print("MC over!")
        if dilated_offset > 0:
            current.vertices = current.vertices + current.vertex_normals * dilated_offset
            print("dilate over!")

        current.merge_vertices()
        current.update_faces(current.unique_faces())
        current.fix_normals()
        processed_meshes.append(current)

        samples = sample_surface_points(current, sample_num, return_surface_pc_normals=True)
        sampled_clouds.append(np.asarray(samples))
        print("process mesh success")
    return sampled_clouds, processed_meshes
#### utils
def euler_to_quat(euler):
    """Convert 'XYZ' Euler angles given in degrees to a quaternion via SciPy's ``as_quat``."""
    rotation = Rotation.from_euler('XYZ', euler, degrees=True)
    return rotation.as_quat()
def SRT_quat_to_matrix(scale, quat, translation):
    """Build a 4x4 transform from a scale, a quaternion and a translation.

    The upper-left 3x3 block is the rotation matrix multiplied element-wise
    (with broadcasting) by *scale*; the last column holds *translation*.
    """
    matrix = np.eye(4)
    matrix[:3, 3] = translation
    matrix[:3, :3] = Rotation.from_quat(quat).as_matrix() * scale
    return matrix
def write_output(primitives, name):
    """Write the predicted primitives as a JSON description and a GLB scene.

    Args:
        primitives: mapping with ``'scale'``, ``'rotation'``, ``'translation'``
            and ``'type_code'`` tensors; each is squeezed and iterated in
            lockstep. Iteration stops at the first ``type_code`` of -1.
        name: stem used for the ``output_<name>.json`` / ``output_<name>.glb``
            files created under ``LOG_PATH``.

    Returns:
        Tuple ``(glb_path, out_json)``.
    """
    out_json = {'operation': 0, 'type': 1, 'scene_id': None}
    blocks = []
    scene = trimesh.Scene()

    # One palette entry per primitive slot.
    palette = sns.color_palette("hls", primitives['type_code'].squeeze().shape[0])
    palette = (np.array(palette) * 255).astype("uint8")

    columns = zip(
        primitives['scale'].squeeze().cpu().numpy(),
        primitives['rotation'].squeeze().cpu().numpy(),
        primitives['translation'].squeeze().cpu().numpy(),
        primitives['type_code'].squeeze().cpu().numpy(),
    )
    for slot, (scale, rotation, translation, type_code) in enumerate(columns):
        if type_code == -1:
            break

        shape_file = CODE_SHAPE[type_code]
        quat = euler_to_quat(rotation)
        blocks.append({
            'type_id': shapename_map[shape_file],
            'data': {
                'location': translation.tolist(),
                'rotation': quat.tolist(),
                'scale': scale.tolist(),
                'color': ['808080'],
            },
        })

        transform = SRT_quat_to_matrix(scale, quat, translation)
        shape = mesh_bs[shape_file].copy().apply_transform(transform)
        tint = np.repeat(palette[slot:slot + 1], shape.visual.vertex_colors.shape[0], axis=0)
        shape.visual.vertex_colors[:, :3] = tint

        # Axis swap for export: new y = old z, new z = -old y.
        source = shape.vertices
        swapped = source.copy()
        swapped[:, 1] = source[:, 2]
        swapped[:, 2] = -source[:, 1]
        shape.vertices = swapped
        scene.add_geometry(shape)

    out_json['group'] = blocks

    json_path = os.path.join(LOG_PATH, f'output_{name}.json')
    with open(json_path, 'w') as handle:
        json.dump(out_json, handle, indent=4)

    glb_path = os.path.join(LOG_PATH, f'output_{name}.glb')
    scene.export(glb_path)
    return glb_path, out_json
def _bbox_center_and_extent(vertices):
    """Return the bounding-box center and longest bounding-box edge of *vertices*."""
    points = np.asarray(vertices)
    lower, upper = points.min(axis=0), points.max(axis=0)
    return (lower + upper) / 2, (upper - lower).max()


def do_inference(input_3d, dilated_offset=0.0, sample_seed=0, do_sampling=False, do_marching_cubes=False, postprocess='none'):
    """Run primitive-assembly generation for one 3D model file.

    Args:
        input_3d: path of the model file, loaded with ``trimesh.load``.
        dilated_offset: dilation forwarded to ``process_mesh_to_surface_pc``;
            when positive, the mesh and point cloud are re-normalized afterwards.
        sample_seed: seed passed to ``set_seed``.
        do_sampling: accepted for interface compatibility; not used.
        do_marching_cubes: forwarded to ``process_mesh_to_surface_pc``.
        postprocess: ``'postprocess1'`` selects ``generate_w_recon_loss``;
            any other value uses ``generate``.

    Returns:
        Tuple ``(input_save_name, output_glb, output_json)``: path of the
        exported pre-processed mesh, path of the generated GLB, and the
        JSON-serializable description of the primitives.

    Raises:
        AssertionError: if any sampled normal has a norm of 0.99 or less.
    """
    set_seed(sample_seed)
    input_mesh = trimesh.load(input_3d, force='mesh')

    # Center the mesh and scale its longest bounding-box edge to 1.6.
    center, extent = _bbox_center_and_extent(input_mesh.vertices)
    input_mesh.vertices = (input_mesh.vertices - center) / extent * 1.6

    pc_list, mesh_list = process_mesh_to_surface_pc(
        [input_mesh],
        marching_cubes=do_marching_cubes,
        dilated_offset=dilated_offset
    )
    pc_normal = pc_list[0]  # 10000, 6
    mesh = mesh_list[0]

    pc_coor = pc_normal[:, :3]
    normals = pc_normal[:, 3:]

    if dilated_offset > 0:
        # Dilation grows the mesh, so normalize again; the point cloud is
        # transformed with the mesh's bounds to keep both aligned.
        center, extent = _bbox_center_and_extent(mesh.vertices)
        mesh.vertices = (mesh.vertices - center) / extent * 1.6
        pc_coor = (pc_coor - center) / extent * 1.6

    input_name = os.path.basename(input_3d)
    input_save_name = os.path.join(LOG_PATH, f'processed_{input_name}')
    mesh.export(input_save_name)

    assert (np.linalg.norm(normals, axis=-1) > 0.99).all(), 'normals should be unit vectors, something wrong'
    normalized_pc_normal = np.concatenate([pc_coor, normals], axis=-1, dtype=np.float16)

    input_pc = torch.tensor(normalized_pc_normal, dtype=torch.float16, device=device)[None]

    with accelerator.autocast():
        if postprocess == 'postprocess1':
            recon_primitives, mask = transformer.generate_w_recon_loss(pc=input_pc, temperature=temperature, single_directional=True)
        else:
            recon_primitives, mask = transformer.generate(pc=input_pc, temperature=temperature)

    # Strip the extension with splitext rather than a fixed 4-character slice,
    # so extensions such as '.gltf' produce a clean output name.
    output_stem = os.path.splitext(input_name)[0]
    output_glb, output_json = write_output(recon_primitives, output_stem)

    return input_save_name, output_glb, output_json
import gradio as gr | |
def process_3d_model(input_3d, dilated_offset, do_marching_cubes, postprocess_method="postprocess1"):
    """Gradio callback: run inference on *input_3d* and return the generated GLB path.

    Exceptions raised by ``do_inference`` are not caught here.
    """
    print(f"processing: {input_3d}")
    _, primitive_glb, _ = do_inference(
        input_3d,
        dilated_offset=dilated_offset,
        do_marching_cubes=do_marching_cubes,
        postprocess=postprocess_method,
    )
    return primitive_glb
# Markdown/HTML banner rendered at the top of the demo page.
_HEADER_ = '''
<h2><b>[SIGGRAPH 2025] PrimitiveAnything 🤗 Gradio Demo</b></h2>
This is official demo for our SIGGRAPH 2025 paper <a href="">PrimitiveAnything: Human-Crafted 3D Primitive Assembly Generation with Auto-Regressive Transformer</a>.
Code: <a href='https://github.com/PrimitiveAnything/PrimitiveAnything' target='_blank'>GitHub</a>. Paper: <a href='https://arxiv.org/abs/2505.04622' target='_blank'>ArXiv</a>.
❗️❗️❗️**Important Notes:**
- Currently our demo supports 3D models only. You can use other text- and image-conditioned models (e.g. [Tencent Hunyuan3D](https://huggingface.co/spaces/tencent/Hunyuan3D-2) or [TRELLIS](https://huggingface.co/spaces/theseanlavery/TRELLIS-3D)) to generate 3D models and then upload them here.
- For optimal results with fine structures, we apply marching cubes and dilation operations by default (which differs from testing and evaluation). This prevents quality degradation in thin areas.
'''
# Markdown footer with the star request, BibTeX citation and contact line.
_CITE_ = r"""
If PrimitiveAnything is helpful, please help to ⭐ the <a href='https://github.com/PrimitiveAnything/PrimitiveAnything' target='_blank'>GitHub Repo</a>. Thanks! [](https://github.com/PrimitiveAnything/PrimitiveAnything)
---
📝 **Citation**
If you find our work useful for your research or applications, please cite using this bibtex:
```bibtex
@misc{ye2025primitiveanything,
title={PrimitiveAnything: Human-Crafted 3D Primitive Assembly Generation with Auto-Regressive Transformer},
author={Jingwen Ye and Yuze He and Yanning Zhou and Yiqin Zhu and Kaiwen Xiao and Yong-Jin Liu and Wei Yang and Xiao Han},
year={2025},
eprint={2505.04622},
archivePrefix={arXiv},
primaryClass={cs.GR}
}
```
📧 **Contact**
If you have any questions, feel free to open a discussion or contact us at <b>[email protected]</b>.
"""
# Gradio UI: upload and options on the left, generated assembly on the right,
# followed by example models and the citation footer.
with gr.Blocks(title="PrimitiveAnything: Human-Crafted 3D Primitive Assembly Generation with Auto-Regressive Transformer") as demo:
    # Title section
    gr.Markdown(_HEADER_)

    with gr.Row():
        with gr.Column():
            # Input components
            input_3d = gr.Model3D(label="Upload 3D Model File")
            dilated_offset = gr.Number(label="Dilated Offset", value=0.015)
            do_marching_cubes = gr.Checkbox(label="Perform Marching Cubes", value=True)
            submit_btn = gr.Button("Process Model")

        with gr.Column():
            # Output components
            output = gr.Model3D(label="Primitive Assembly Prediction")

    # process_3d_model's postprocess_method keeps its default; only the three
    # visible inputs are wired in.
    submit_btn.click(
        fn=process_3d_model,
        inputs=[input_3d, dilated_offset, do_marching_cubes],
        outputs=output
    )

    # Prepare examples: one single-element row per GLB file. glob returns
    # entries in arbitrary order, so sort for a stable gallery.
    example_files = [[f] for f in sorted(glob.glob('./data/demo_glb/*.glb'))]

    example = gr.Examples(
        examples=example_files,
        inputs=[input_3d],  # Only include the Model3D input
        examples_per_page=14,
    )

    gr.Markdown(_CITE_)
# Start the Gradio server only when executed as a script; ssr_mode is passed as False.
if __name__ == "__main__":
    demo.launch(ssr_mode=False)