jxwang1 commited on
Commit
be383f9
·
1 Parent(s): 5e9b2d2

add text2shape

Browse files
Files changed (1) hide show
  1. app.py +121 -4
app.py CHANGED
@@ -1,7 +1,124 @@
 
1
  import gradio as gr
 
 
 
 
 
 
 
 
 
2
 
3
- def greet(name):
4
- return "Hello " + name + "!!"
5
 
6
- demo = gr.Interface(fn=greet, inputs="text", outputs="text")
7
- demo.launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
  import gradio as gr
3
+ import os
4
+ import torch
5
+ import trimesh
6
+ import sys
7
+ sys.path.append("cube")
8
+ from cube3d.inference.engine import EngineFast
9
+ from pathlib import Path
10
+ import uuid
11
+ import shutil
12
 
 
 
13
 
14
+ GLOBAL_STATE = {}
15
+
16
+ def gen_save_folder(max_size=200):
17
+ os.makedirs(GLOBAL_STATE["SAVE_DIR"], exist_ok=True)
18
+
19
+ dirs = [f for f in Path(GLOBAL_STATE["SAVE_DIR"]).iterdir() if f.is_dir()]
20
+
21
+ if len(dirs) >= max_size:
22
+ oldest_dir = min(dirs, key=lambda x: x.stat().st_ctime)
23
+ shutil.rmtree(oldest_dir)
24
+ print(f"Removed the oldest folder: {oldest_dir}")
25
+
26
+ new_folder = os.path.join(GLOBAL_STATE["SAVE_DIR"], str(uuid.uuid4()))
27
+ os.makedirs(new_folder, exist_ok=True)
28
+ print(f"Created new folder: {new_folder}")
29
+
30
+ return new_folder
31
+
32
+ def handle_text_prompt(input_prompt, ):
33
+ mesh_v_f = GLOBAL_STATE["engine_fast"].t2s([input_prompt], use_kv_cache=True, resolution_base=8.0)
34
+ # save output
35
+ vertices, faces = mesh_v_f[0][0], mesh_v_f[0][1]
36
+ save_folder = gen_save_folder()
37
+ output_path = os.path.join(save_folder, "output.obj")
38
+ trimesh.Trimesh(vertices=vertices, faces=faces).export(output_path)
39
+ return output_path
40
+
41
+ def build_interface():
42
+ """Build UI for gradio app
43
+
44
+ Three tabs:
45
+ - Scene Generation
46
+ - Scene Completion
47
+ - Scene Understanding
48
+ """
49
+ title = "Cube 3D"
50
+ with gr.Blocks(theme=gr.themes.Soft(), title=title, fill_width=True) as interface:
51
+ gr.Markdown(
52
+ f"""
53
+ # {title}
54
+ """
55
+ )
56
+
57
+ with gr.Row():
58
+ with gr.Column(scale=2):
59
+ with gr.Group():
60
+ input_text_box = gr.Textbox(
61
+ value=None,
62
+ label="Prompt",
63
+ lines=2,
64
+ )
65
+ with gr.Row():
66
+ submit_button = gr.Button("Submit", variant="primary")
67
+ with gr.Column(scale=3):
68
+ model3d = gr.Model3D(
69
+ label="Output", height="45em", interactive=False
70
+ )
71
+
72
+ submit_button.click(
73
+ handle_text_prompt,
74
+ inputs=[
75
+ input_text_box
76
+ ],
77
+ outputs=[
78
+ model3d
79
+ ]
80
+ )
81
+
82
+ return interface
83
+
84
+ if __name__=="__main__":
85
+
86
+ parser = argparse.ArgumentParser()
87
+ parser.add_argument(
88
+ "--config_path",
89
+ type=str,
90
+ help="Path to the config file",
91
+ default="cube/cube3d/configs/open_model.yaml",
92
+ )
93
+ parser.add_argument(
94
+ "--gpt_ckpt_path",
95
+ type=str,
96
+ help="Path to the gpt ckpt path",
97
+ default="model_weights/shape_gpt.safetensors",
98
+ )
99
+ parser.add_argument(
100
+ "--shape_ckpt_path",
101
+ type=str,
102
+ help="Path to the shape ckpt path",
103
+ default="model_weights/shape_tokenizer.safetensors",
104
+ )
105
+ parser.add_argument(
106
+ "--save_dir",
107
+ type=str,
108
+ default="gradio_save_dir",
109
+ )
110
+
111
+ args = parser.parse_args()
112
+ engine_fast = EngineFast(
113
+ args.config_path,
114
+ args.gpt_ckpt_path,
115
+ args.shape_ckpt_path,
116
+ device=torch.device("cuda"),
117
+ )
118
+ GLOBAL_STATE["engine_fast"] = engine_fast
119
+ GLOBAL_STATE["SAVE_DIR"] = args.save_dir
120
+ os.makedirs(GLOBAL_STATE["SAVE_DIR"], exist_ok=True)
121
+
122
+ demo = build_interface()
123
+ demo.queue(default_concurrency_limit=None)
124
+ demo.launch()