jxwang1 commited on
Commit
7e70d08
·
1 Parent(s): 3640d21

add download weights

Browse files
Files changed (1) hide show
  1. app.py +12 -4
app.py CHANGED
@@ -4,11 +4,12 @@ import os
4
  import torch
5
  import trimesh
6
  import sys
7
- sys.path.append("/home/user/app/cube")
8
  from cube3d.inference.engine import EngineFast
9
  from pathlib import Path
10
  import uuid
11
  import shutil
 
12
 
13
 
14
  GLOBAL_STATE = {}
@@ -104,10 +105,17 @@ if __name__=="__main__":
104
  )
105
 
106
  args = parser.parse_args()
 
 
 
 
 
 
 
107
  engine_fast = EngineFast(
108
- args.config_path,
109
- args.gpt_ckpt_path,
110
- args.shape_ckpt_path,
111
  device=torch.device("cuda"),
112
  )
113
  GLOBAL_STATE["engine_fast"] = engine_fast
 
4
  import torch
5
  import trimesh
6
  import sys
7
+ sys.path.append("cube")
8
  from cube3d.inference.engine import EngineFast
9
  from pathlib import Path
10
  import uuid
11
  import shutil
12
+ from huggingface_hub import snapshot_download
13
 
14
 
15
  GLOBAL_STATE = {}
 
105
  )
106
 
107
  args = parser.parse_args()
108
+ snapshot_download(
109
+ repo_id="Roblox/cube3d-v0.1",
110
+ local_dir="./model_weights"
111
+ )
112
+ config_path = "./model_weights/shape_tokenizer.safetensors"
113
+ gpt_ckpt_path = "./model_weights/shape_gpt.safetensors"
114
+ shape_ckpt_path = "./model_weights/shape_tokenizer.safetensors"
115
  engine_fast = EngineFast(
116
+ config_path,
117
+ gpt_ckpt_path,
118
+ shape_ckpt_path,
119
  device=torch.device("cuda"),
120
  )
121
  GLOBAL_STATE["engine_fast"] = engine_fast