"""Gradio demo for Cube 3D: turn a text prompt into a downloadable GLB mesh."""

import argparse
import os
import shutil
import sys
import uuid
from pathlib import Path

import gradio as gr
import torch
import trimesh
from huggingface_hub import snapshot_download

# The `cube` checkout sits next to this file; put it on sys.path so the
# `cube3d` package inside it can be imported below.
pathdir = Path(__file__).parent / 'cube'
sys.path.append(pathdir.as_posix())

from cube3d.inference.engine import EngineFast, Engine  # noqa: E402

# Process-wide state filled in by the __main__ block:
#   "engine_fast" -> the EngineFast instance used for generation
#   "SAVE_DIR"    -> root directory under which per-request folders are created
GLOBAL_STATE = {}


def gen_save_folder(max_size: int = 200) -> str:
    """Create and return a fresh, uniquely named folder under SAVE_DIR.

    If SAVE_DIR already holds ``max_size`` or more sub-directories, the one
    with the smallest ``st_ctime`` is deleted first, so at most one old
    folder is evicted per call.

    Args:
        max_size: Number of existing sub-directories at which eviction
            of the oldest one kicks in.

    Returns:
        Path of the newly created folder.
    """
    save_root = GLOBAL_STATE["SAVE_DIR"]
    os.makedirs(save_root, exist_ok=True)

    dirs = [f for f in Path(save_root).iterdir() if f.is_dir()]
    if len(dirs) >= max_size:
        oldest_dir = min(dirs, key=lambda x: x.stat().st_ctime)
        shutil.rmtree(oldest_dir)
        print(f"Removed the oldest folder: {oldest_dir}")

    new_folder = os.path.join(save_root, str(uuid.uuid4()))
    os.makedirs(new_folder, exist_ok=True)
    print(f"Created new folder: {new_folder}")
    return new_folder


def handle_text_prompt(input_prompt, variance=0) -> str:
    """Generate a mesh for ``input_prompt`` and export it as a GLB file.

    Args:
        input_prompt: Text prompt passed to the engine's ``t2s`` call.
        variance: Value from the UI slider (0-99). 0 passes ``top_p=None``
            to the engine; any other value is mapped to
            ``top_p = (100 - variance) / 100``.

    Returns:
        Path to the exported ``output.glb`` file.
    """
    print(f"prompt: {input_prompt}, variance: {variance}")
    top_p = None if variance == 0 else (100 - variance) / 100.0
    mesh_v_f = GLOBAL_STATE["engine_fast"].t2s(
        [input_prompt],
        use_kv_cache=True,
        resolution_base=8.0,
        top_p=top_p,
    )

    # A single prompt was submitted, so only the first result is used; its
    # elements 0 and 1 are treated as vertices and faces respectively.
    vertices, faces = mesh_v_f[0][0], mesh_v_f[0][1]

    save_folder = gen_save_folder()
    output_path = os.path.join(save_folder, "output.glb")
    trimesh.Trimesh(vertices=vertices, faces=faces).export(output_path)
    return output_path


def build_interface():
    """Build the Gradio Blocks UI: prompt + variance inputs, 3D model output."""
    title = "Cube 3D"
    with gr.Blocks(theme=gr.themes.Soft(), title=title, fill_width=True) as interface:
        gr.Markdown(
            f"""
            # {title}
            # Check out our [Github](https://github.com/Roblox/cube) to try it on your own machine!
            """
        )

        with gr.Row():
            with gr.Column(scale=2):
                with gr.Group():
                    input_text_box = gr.Textbox(
                        value=None,
                        label="Prompt",
                        lines=2,
                    )
                    variance = gr.Slider(
                        minimum=0, maximum=99, step=1, value=0, label="Variance"
                    )
                with gr.Row():
                    submit_button = gr.Button("Submit", variant="primary")

            with gr.Column(scale=3):
                model3d = gr.Model3D(
                    label="Output", height="45em", interactive=False
                )

        submit_button.click(
            handle_text_prompt,
            inputs=[input_text_box, variance],
            outputs=[model3d],
        )

    return interface


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--config_path",
        type=str,
        help="Path to the config file",
        default="cube/cube3d/configs/open_model.yaml",
    )
    parser.add_argument(
        "--gpt_ckpt_path",
        type=str,
        help="Path to the gpt ckpt path",
        default="model_weights/shape_gpt.safetensors",
    )
    parser.add_argument(
        "--shape_ckpt_path",
        type=str,
        help="Path to the shape ckpt path",
        default="model_weights/shape_tokenizer.safetensors",
    )
    parser.add_argument(
        "--save_dir",
        type=str,
        default="gradio_save_dir",
    )
    args = parser.parse_args()

    # Fetch the published weights into ./model_weights, which is where the
    # default checkpoint paths above point.
    snapshot_download(
        repo_id="Roblox/cube3d-v0.1",
        local_dir="./model_weights",
    )

    # Honor the CLI flags: previously these two paths were hard-coded here,
    # which silently ignored --gpt_ckpt_path / --shape_ckpt_path. The
    # defaults resolve to the same files as the old literals.
    config_path = args.config_path
    gpt_ckpt_path = args.gpt_ckpt_path
    shape_ckpt_path = args.shape_ckpt_path

    engine_fast = EngineFast(
        config_path,
        gpt_ckpt_path,
        shape_ckpt_path,
        device=torch.device("cuda"),
    )
    GLOBAL_STATE["engine_fast"] = engine_fast
    GLOBAL_STATE["SAVE_DIR"] = args.save_dir
    os.makedirs(GLOBAL_STATE["SAVE_DIR"], exist_ok=True)

    demo = build_interface()
    # One request at a time by default for every event handler.
    demo.queue(default_concurrency_limit=1)
    demo.launch()