Anuji committed on
Commit
b975069
·
verified ·
1 Parent(s): 55b3488

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +8 -4
app.py CHANGED
@@ -13,17 +13,21 @@ DEPLOY_MODELS = {}
13
  IMAGE_TOKEN = "<image>"
14
 
15
  # Fetch model
16
- def fetch_model(model_name: str, dtype=torch.bfloat16):
17
  global DEPLOY_MODELS
18
  if model_name not in DEPLOY_MODELS:
19
- logger.info(f"Loading {model_name}...")
 
 
 
 
20
  model_info = load_model(model_name, dtype=dtype)
21
  tokenizer, model, vl_chat_processor = model_info
22
- device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
23
  model = model.to(device)
24
  DEPLOY_MODELS[model_name] = (tokenizer, model, vl_chat_processor)
25
- logger.info(f"Loaded {model_name} on {device}")
26
  return DEPLOY_MODELS[model_name]
 
27
 
28
  # Generate prompt with history
29
  def generate_prompt_with_history(text, images, history, vl_chat_processor, tokenizer, max_length=2048):
 
13
  IMAGE_TOKEN = "<image>"
14
 
15
  # Fetch model
16
+ def fetch_model(model_name: str, dtype=None):
17
  global DEPLOY_MODELS
18
  if model_name not in DEPLOY_MODELS:
19
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
20
+ # Use bfloat16 only if using GPU
21
+ dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32
22
+
23
+ logger.info(f"Loading {model_name} on {device} with dtype={dtype}...")
24
  model_info = load_model(model_name, dtype=dtype)
25
  tokenizer, model, vl_chat_processor = model_info
 
26
  model = model.to(device)
27
  DEPLOY_MODELS[model_name] = (tokenizer, model, vl_chat_processor)
28
+ logger.info(f"Loaded {model_name} successfully.")
29
  return DEPLOY_MODELS[model_name]
30
+
31
 
32
  # Generate prompt with history
33
  def generate_prompt_with_history(text, images, history, vl_chat_processor, tokenizer, max_length=2048):