import os

# PyOpenGL chooses its platform backend from PYOPENGL_PLATFORM when it is first
# imported, so the variable must be set before any import below that may pull
# it in (presumably via mesh_to_sdf, whose 'scan' sampling renders the mesh).
# Assigning it after those imports, as before, may have no effect.
os.environ['PYOPENGL_PLATFORM'] = 'egl'

import time
import glob
import json
import yaml
import torch
import trimesh
import argparse
import mesh2sdf.core
import numpy as np
import skimage.measure
import seaborn as sns
from scipy.spatial.transform import Rotation
from mesh_to_sdf import get_surface_point_cloud
from accelerate.utils import set_seed
from accelerate import Accelerator
from huggingface_hub.file_download import hf_hub_download
from huggingface_hub import list_repo_files

from primitive_anything.utils import path_mkdir, count_parameters
from primitive_anything.utils.logger import print_log

import spaces

# Fetch every file of the PrimitiveAnything model repo, plus the Michelangelo
# shape-VAE checkpoint, into ./ckpt.
repo_id = "hyz317/PrimitiveAnything"
all_files = list_repo_files(repo_id, revision="main")
for file in all_files:
    # NOTE(review): this tests the repo-relative path against the current
    # working directory, while downloads are written under ./ckpt. Files
    # already present in ./ckpt are therefore not skipped by this check.
    # Confirm which location the skip is meant to inspect.
    if os.path.exists(file):
        continue
    hf_hub_download(repo_id, file, local_dir="./ckpt")

hf_hub_download("Maikou/Michelangelo", "checkpoints/aligned_shape_latents/shapevae-256.ckpt", local_dir="./ckpt")


def parse_args():
    """Parse the demo's command-line options.

    Returns:
        argparse.Namespace with ``input`` (file or directory path, default
        ``./data/demo_glb/``) and ``log_path`` (output directory, default
        ``./results/demo``).
    """
    parser = argparse.ArgumentParser(description='Process 3D model files')
    parser.add_argument(
        '--input',
        type=str,
        default='./data/demo_glb/',
        help='Input file or directory path (default: ./data/demo_glb/)'
    )
    parser.add_argument(
        '--log_path',
        type=str,
        default='./results/demo',
        help='Output directory path (default: results/demo)'
    )
    return parser.parse_args()


def get_input_files(input_path):
    """Resolve ``input_path`` to the list of input files to process.

    Args:
        input_path (str): a single file, or a directory whose direct children
            are used.

    Returns:
        list[str]: ``[input_path]`` for a file. For a directory, the regular
        files directly inside it, sorted so the processing order is
        deterministic. Sub-directories are skipped.

    Raises:
        ValueError: if ``input_path`` is neither a file nor a directory.
    """
    if os.path.isfile(input_path):
        return [input_path]
    if os.path.isdir(input_path):
        # glob order is arbitrary and '*' also matches sub-directories,
        # which cannot be loaded as meshes.
        return sorted(
            path for path in glob.glob(os.path.join(input_path, '*'))
            if os.path.isfile(path)
        )
    raise ValueError(f"Input path {input_path} is neither a file nor a directory")


args = parse_args()
LOG_PATH = args.log_path
os.makedirs(LOG_PATH, exist_ok=True)
print(f"Output directory: {LOG_PATH}")

# Primitive type code (as predicted by the model) -> template mesh file name.
CODE_SHAPE = {
    0: 'SM_GR_BS_CubeBevel_001.ply',
    1: 'SM_GR_BS_SphereSharp_001.ply',
    2: 'SM_GR_BS_CylinderSharp_001.ply',
}

# Template mesh file name -> 'type_id' written into the output JSON.
shapename_map = {
    'SM_GR_BS_CubeBevel_001.ply': 1101002001034001,
    'SM_GR_BS_SphereSharp_001.ply': 1101002001034010,
    'SM_GR_BS_CylinderSharp_001.ply': 1101002001034002,
}

#### config
bs_dir = 'data/basic_shapes_norm' config_path = './configs/infer.yml' AR_checkpoint_path = './ckpt/mesh-transformer.ckpt.60.pt' temperature= 0.0 #### init model mesh_bs = {} for bs_path in glob.glob(os.path.join(bs_dir, '*.ply')): bs_name = os.path.basename(bs_path) bs = trimesh.load(bs_path) bs.visual.uv = np.clip(bs.visual.uv, 0, 1) bs.visual = bs.visual.to_color() mesh_bs[bs_name] = bs def create_model(cfg_model): kwargs = cfg_model name = kwargs.pop('name') model = get_model(name)(**kwargs) print_log("Model '{}' init: nb_params={:,}, kwargs={}".format(name, count_parameters(model), kwargs)) return model from primitive_anything.primitive_transformer import PrimitiveTransformerDiscrete def get_model(name): return { 'discrete': PrimitiveTransformerDiscrete, }[name] with open(config_path, mode='r') as fp: AR_train_cfg = yaml.load(fp, Loader=yaml.FullLoader) AR_checkpoint = torch.load(AR_checkpoint_path) transformer = create_model(AR_train_cfg['model']) transformer.load_state_dict(AR_checkpoint) device = torch.device('cuda') accelerator = Accelerator( mixed_precision='fp16', ) transformer = accelerator.prepare(transformer) transformer.eval() transformer.bs_pc = transformer.bs_pc.cuda() transformer.rotation_matrix_align_coord = transformer.rotation_matrix_align_coord.cuda() print('model loaded to device') def sample_surface_points(mesh, number_of_points=500000, surface_point_method='scan', sign_method='normal', scan_count=100, scan_resolution=400, sample_point_count=10000000, return_gradients=False, return_surface_pc_normals=False, normalized=False): sample_start = time.time() if surface_point_method == 'sample' and sign_method == 'depth': print("Incompatible methods for sampling points and determining sign, using sign_method='normal' instead.") sign_method = 'normal' surface_start = time.time() bound_radius = 1 if normalized else None surface_point_cloud = get_surface_point_cloud(mesh, surface_point_method, bound_radius, scan_count, scan_resolution, 
sample_point_count, calculate_normals=sign_method == 'normal' or return_gradients) surface_end = time.time() print('surface point cloud time cost :', surface_end - surface_start) normal_start = time.time() if return_surface_pc_normals: rng = np.random.default_rng() assert surface_point_cloud.points.shape[0] == surface_point_cloud.normals.shape[0] indices = rng.choice(surface_point_cloud.points.shape[0], number_of_points, replace=True) points = surface_point_cloud.points[indices] normals = surface_point_cloud.normals[indices] surface_points = np.concatenate([points, normals], axis=-1) else: surface_points = surface_point_cloud.get_random_surface_points(number_of_points, use_scans=True) normal_end = time.time() print('normal time cost :', normal_end - normal_start) sample_end = time.time() print('sample surface point time cost :', sample_end - sample_start) return surface_points def normalize_vertices(vertices, scale=0.9): bbmin, bbmax = vertices.min(0), vertices.max(0) center = (bbmin + bbmax) * 0.5 scale = 2.0 * scale / (bbmax - bbmin).max() vertices = (vertices - center) * scale return vertices, center, scale def export_to_watertight(normalized_mesh, octree_depth: int = 7): """ Convert the non-watertight mesh to watertight. 
Args: input_path (str): normalized path octree_depth (int): Returns: mesh(trimesh.Trimesh): watertight mesh """ size = 2 ** octree_depth level = 2 / size scaled_vertices, to_orig_center, to_orig_scale = normalize_vertices(normalized_mesh.vertices) sdf = mesh2sdf.core.compute(scaled_vertices, normalized_mesh.faces, size=size) vertices, faces, normals, _ = skimage.measure.marching_cubes(np.abs(sdf), level) # watertight mesh vertices = vertices / size * 2 - 1 # -1 to 1 vertices = vertices / to_orig_scale + to_orig_center mesh = trimesh.Trimesh(vertices, faces, normals=normals) return mesh def process_mesh_to_surface_pc(mesh_list, marching_cubes=False, dilated_offset=0.0, sample_num=10000): # mesh_list : list of trimesh pc_normal_list = [] return_mesh_list = [] for mesh in mesh_list: if marching_cubes: mesh = export_to_watertight(mesh) print("MC over!") if dilated_offset > 0: new_vertices = mesh.vertices + mesh.vertex_normals * dilated_offset mesh.vertices = new_vertices print("dilate over!") mesh.merge_vertices() mesh.update_faces(mesh.unique_faces()) mesh.fix_normals() return_mesh_list.append(mesh) pc_normal = np.asarray(sample_surface_points(mesh, sample_num, return_surface_pc_normals=True)) pc_normal_list.append(pc_normal) print("process mesh success") return pc_normal_list, return_mesh_list #### utils def euler_to_quat(euler): return Rotation.from_euler('XYZ', euler, degrees=True).as_quat() def SRT_quat_to_matrix(scale, quat, translation): rotation_matrix = Rotation.from_quat(quat).as_matrix() transform_matrix = np.eye(4) transform_matrix[:3, :3] = rotation_matrix * scale transform_matrix[:3, 3] = translation return transform_matrix def write_output(primitives, name): out_json = {} out_json['operation'] = 0 out_json['type'] = 1 out_json['scene_id'] = None new_group = [] model_scene = trimesh.Scene() color_map = sns.color_palette("hls", primitives['type_code'].squeeze().shape[0]) color_map = (np.array(color_map) * 255).astype("uint8") for idx, (scale, rotation, 
translation, type_code) in enumerate(zip( primitives['scale'].squeeze().cpu().numpy(), primitives['rotation'].squeeze().cpu().numpy(), primitives['translation'].squeeze().cpu().numpy(), primitives['type_code'].squeeze().cpu().numpy() )): if type_code == -1: break bs_name = CODE_SHAPE[type_code] new_block = {} new_block['type_id'] = shapename_map[bs_name] new_block['data'] = {} new_block['data']['location'] = translation.tolist() new_block['data']['rotation'] = euler_to_quat(rotation).tolist() new_block['data']['scale'] = scale.tolist() new_block['data']['color'] = ['808080'] new_group.append(new_block) trans = SRT_quat_to_matrix(scale, euler_to_quat(rotation), translation) bs = mesh_bs[bs_name].copy().apply_transform(trans) new_vertex_colors = np.repeat(color_map[idx:idx+1], bs.visual.vertex_colors.shape[0], axis=0) bs.visual.vertex_colors[:, :3] = new_vertex_colors vertices = bs.vertices.copy() vertices[:, 1] = bs.vertices[:, 2] vertices[:, 2] = -bs.vertices[:, 1] bs.vertices = vertices model_scene.add_geometry(bs) out_json['group'] = new_group json_path = os.path.join(LOG_PATH, f'output_{name}.json') with open(json_path, 'w') as json_file: json.dump(out_json, json_file, indent=4) glb_path = os.path.join(LOG_PATH, f'output_{name}.glb') model_scene.export(glb_path) return glb_path, out_json @torch.no_grad() def do_inference(input_3d, dilated_offset=0.0, sample_seed=0, do_sampling=False, do_marching_cubes=False, postprocess='none'): t1 = time.time() set_seed(sample_seed) input_mesh = trimesh.load(input_3d, force='mesh') # scale mesh vertices = input_mesh.vertices bounds = np.array([vertices.min(axis=0), vertices.max(axis=0)]) vertices = vertices - (bounds[0] + bounds[1])[None, :] / 2 vertices = vertices / (bounds[1] - bounds[0]).max() * 1.6 input_mesh.vertices = vertices pc_list, mesh_list = process_mesh_to_surface_pc( [input_mesh], marching_cubes=do_marching_cubes, dilated_offset=dilated_offset ) pc_normal = pc_list[0] # 10000, 6 mesh = mesh_list[0] pc_coor = 
pc_normal[:, :3] normals = pc_normal[:, 3:] if dilated_offset > 0: # scale mesh and pc vertices = mesh.vertices bounds = np.array([vertices.min(axis=0), vertices.max(axis=0)]) vertices = vertices - (bounds[0] + bounds[1])[None, :] / 2 vertices = vertices / (bounds[1] - bounds[0]).max() * 1.6 mesh.vertices = vertices pc_coor = pc_coor - (bounds[0] + bounds[1])[None, :] / 2 pc_coor = pc_coor / (bounds[1] - bounds[0]).max() * 1.6 input_save_name = os.path.join(LOG_PATH, f'processed_{os.path.basename(input_3d)}') mesh.export(input_save_name) assert (np.linalg.norm(normals, axis=-1) > 0.99).all(), 'normals should be unit vectors, something wrong' normalized_pc_normal = np.concatenate([pc_coor, normals], axis=-1, dtype=np.float16) input_pc = torch.tensor(normalized_pc_normal, dtype=torch.float16, device=device)[None] with accelerator.autocast(): if postprocess == 'postprocess1': recon_primitives, mask = transformer.generate_w_recon_loss(pc=input_pc, temperature=temperature, single_directional=True) else: recon_primitives, mask = transformer.generate(pc=input_pc, temperature=temperature) output_glb, output_json = write_output(recon_primitives, os.path.basename(input_3d)[:-4]) return input_save_name, output_glb, output_json import gradio as gr @spaces.GPU def process_3d_model(input_3d, dilated_offset, do_marching_cubes, postprocess_method="postprocess1"): print(f"processing: {input_3d}") # try: preprocess_model_obj, output_model_obj, output_model_json = do_inference( input_3d, dilated_offset=dilated_offset, do_marching_cubes=do_marching_cubes, postprocess=postprocess_method ) return output_model_obj # except Exception as e: # return f"Error processing file: {str(e)}" _HEADER_ = '''