alexnasa commited on
Commit
c037013
·
verified ·
1 Parent(s): 35b0546

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +23 -16
app.py CHANGED
@@ -55,7 +55,14 @@ def install_cuda_toolkit():
55
  print("==> finished installation")
56
 
57
  install_cuda_toolkit()
 
 
 
 
 
58
 
 
 
59
 
60
 
61
  # Utility to select first image from a folder
@@ -84,7 +91,7 @@ def reset_all():
84
  )
85
 
86
  # Step 1: Preprocess the input image (Save and Crop)
87
- # @spaces.GPU()
88
  def preprocess_image(image_array, state):
89
  if image_array is None:
90
  return "❌ Please upload an image first.", None, state, gr.update(interactive=True), gr.update(interactive=False)
@@ -154,23 +161,23 @@ def step3_uv_map(state):
154
  @spaces.GPU()
155
  def step4_track(state):
156
 
157
- import os
158
- import torch
159
- import numpy as np
160
- import trimesh
161
- from pytorch3d.io import load_obj
162
-
163
- from pixel3dmm.tracking.flame.FLAME import FLAME
164
- from pixel3dmm.tracking.tracker import Tracker
165
- from omegaconf import OmegaConf
166
-
167
-
168
- DEVICE = "cuda"
169
-
170
- base_conf = OmegaConf.load(f'{env_paths.CODE_BASE}/configs/tracking.yaml')
171
 
172
- flame_model = FLAME(base_conf).to(DEVICE)
 
 
173
 
 
174
  session_id = state.get("session_id")
175
  base_conf.video_name = f'{session_id}'
176
  tracker = Tracker(base_conf, flame_model)
 
55
  print("==> finished installation")
56
 
57
  install_cuda_toolkit()
58
+
59
+ DEVICE = "cuda"
60
+
61
+ # 1. Prepare config at import time (no CUDA calls)
62
+ base_conf = OmegaConf.load("configs/tracking.yaml")
63
 
64
+ # 2. Empty cache for our heavy objects
65
+ _model_cache = {}
66
 
67
 
68
  # Utility to select first image from a folder
 
91
  )
92
 
93
  # Step 1: Preprocess the input image (Save and Crop)
94
+ @spaces.GPU()
95
  def preprocess_image(image_array, state):
96
  if image_array is None:
97
  return "❌ Please upload an image first.", None, state, gr.update(interactive=True), gr.update(interactive=False)
 
161
  @spaces.GPU()
162
  def step4_track(state):
163
 
164
+ # Lazy init + caching of FLAME model on GPU
165
+ if "flame_model" not in _model_cache:
166
+ import os
167
+ import torch
168
+ import numpy as np
169
+ import trimesh
170
+ from pytorch3d.io import load_obj
171
+
172
+ from pixel3dmm.tracking.flame.FLAME import FLAME
173
+ from pixel3dmm.tracking.tracker import Tracker
174
+ from omegaconf import OmegaConf
 
 
 
175
 
176
+ flame = FLAME(base_conf) # CPU instantiation
177
+ flame = flame.to(DEVICE) # CUDA init happens here
178
+ _model_cache["flame_model"] = flame
179
 
180
+ flame_model = _model_cache["flame_model"]
181
  session_id = state.get("session_id")
182
  base_conf.video_name = f'{session_id}'
183
  tracker = Tracker(base_conf, flame_model)