Anuji commited on
Commit
755d5e1
·
verified ·
1 Parent(s): b975069

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +11 -8
app.py CHANGED
@@ -13,19 +13,22 @@ DEPLOY_MODELS = {}
13
  IMAGE_TOKEN = "<image>"
14
 
15
  # Fetch model
16
- def fetch_model(model_name: str, dtype=None):
17
  global DEPLOY_MODELS
18
  if model_name not in DEPLOY_MODELS:
19
- device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
20
- # Use bfloat16 only if using GPU
21
- dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32
22
-
23
- logger.info(f"Loading {model_name} on {device} with dtype={dtype}...")
24
  model_info = load_model(model_name, dtype=dtype)
25
  tokenizer, model, vl_chat_processor = model_info
26
- model = model.to(device)
 
 
 
 
 
 
 
27
  DEPLOY_MODELS[model_name] = (tokenizer, model, vl_chat_processor)
28
- logger.info(f"Loaded {model_name} successfully.")
29
  return DEPLOY_MODELS[model_name]
30
 
31
 
 
13
  IMAGE_TOKEN = "<image>"
14
 
15
  # Fetch model
16
+ def fetch_model(model_name: str, dtype=torch.bfloat16):
17
  global DEPLOY_MODELS
18
  if model_name not in DEPLOY_MODELS:
19
+ logger.info(f"Loading {model_name}...")
 
 
 
 
20
  model_info = load_model(model_name, dtype=dtype)
21
  tokenizer, model, vl_chat_processor = model_info
22
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
23
+ try:
24
+ model = model.to(device)
25
+ except RuntimeError as e:
26
+ logger.warning(f"Could not move model to {device}: {e}")
27
+ device = torch.device('cpu')
28
+ model = model.to(device)
29
+ logger.warning("Model fallback to CPU. Inference might be slow.")
30
  DEPLOY_MODELS[model_name] = (tokenizer, model, vl_chat_processor)
31
+ logger.info(f"Loaded {model_name} on {device}")
32
  return DEPLOY_MODELS[model_name]
33
 
34