feras-vbrl committed on
Commit d847350 · verified · 1 Parent(s): d876e1d

Upload app.py

Files changed (1)
  1. app.py +21 -7
app.py CHANGED
@@ -26,7 +26,7 @@ st.set_page_config(
 # Cache the model loading to avoid reloading on each interaction
 @st.cache_resource
 def load_model():
-    with st.spinner("Loading model with vllm for T4 GPU..."):
+    with st.spinner("Loading model..."):
         # Check if GPU is available
         gpu_available = torch.cuda.is_available()
         st.info(f"GPU available: {gpu_available}")
@@ -86,7 +86,9 @@ def load_model():
         )
 
         # Move model to appropriate device if needed
-        if 'device_map' not in locals() or device_map is None:
+        # Check if the model has a device_map attribute and if it's not None
+        # If it has a device_map, it's already distributed across devices and shouldn't be moved
+        if not hasattr(model, 'device_map') or model.device_map is None:
             model = model.to(device)
 
         tokenizer = AutoTokenizer.from_pretrained("sciphi/triplex", trust_remote_code=True)
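For reference, the new guard can be read in isolation as the pattern below; a minimal sketch, assuming a hypothetical helper name move_if_unmapped and the same plain attribute check as the diff (getattr keeps it safe on models that never set a device_map attribute):

import torch

def move_if_unmapped(model, device):
    # Same guard as the diff: only call .to(device) when the model was not
    # already dispatched across devices via a device_map.
    if getattr(model, "device_map", None) is None:
        model = model.to(device)
    return model

# Usage with an ordinary torch module (no device_map attribute, so it is moved):
device = "cuda" if torch.cuda.is_available() else "cpu"
model = move_if_unmapped(torch.nn.Linear(4, 4), device)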
@@ -123,8 +125,15 @@ def triplextract(model, tokenizer, text, entity_types, predicates, use_vllm=True
     else:
         # Use standard transformers
         messages = [{'role': 'user', 'content': message}]
-        device = next(model.parameters()).device  # Get the device the model is on
-        input_ids = tokenizer.apply_chat_template(messages, add_generation_prompt=True, return_tensors="pt").to(device)
+
+        # Handle device mapping differently based on model configuration
+        if hasattr(model, 'device_map') and model.device_map is not None:
+            # Model already has device mapping, don't need to specify device for input_ids
+            input_ids = tokenizer.apply_chat_template(messages, add_generation_prompt=True, return_tensors="pt")
+        else:
+            # Get the device the model is on
+            device = next(model.parameters()).device
+            input_ids = tokenizer.apply_chat_template(messages, add_generation_prompt=True, return_tensors="pt").to(device)
         output = tokenizer.decode(model.generate(input_ids=input_ids, max_length=2048)[0], skip_special_tokens=True)
 
     processing_time = time.time() - start_time
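The same attribute check drives the input placement added above; a minimal standalone sketch of that branch (build_input_ids is a hypothetical name, and a chat-template-capable tokenizer is assumed, as in the app):

def build_input_ids(tokenizer, model, messages):
    input_ids = tokenizer.apply_chat_template(
        messages, add_generation_prompt=True, return_tensors="pt"
    )
    # With a device_map the model is already spread across devices, so the
    # tensor is left as produced; otherwise move it to the parameters' device.
    if getattr(model, "device_map", None) is None:
        input_ids = input_ids.to(next(model.parameters()).device)
    return input_ids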
@@ -278,9 +287,14 @@ def main():
 
     # Add a note about performance
     if torch.cuda.is_available():
-        st.success("""
-        🚀 Running on GPU with vllm for optimal performance!
-        """)
+        if use_vllm:
+            st.success("""
+            🚀 Running on GPU with vllm for optimal performance!
+            """)
+        else:
+            st.success("""
+            🚀 Running on GPU for improved performance!
+            """)
     else:
         st.warning("""
         ⚠️ You are running on CPU which can be very slow for the SciPhi/Triplex model.
 