Upload app.py
app.py
CHANGED
@@ -26,7 +26,7 @@ st.set_page_config(
 # Cache the model loading to avoid reloading on each interaction
 @st.cache_resource
 def load_model():
-    with st.spinner("Loading model
+    with st.spinner("Loading model..."):
         # Check if GPU is available
         gpu_available = torch.cuda.is_available()
         st.info(f"GPU available: {gpu_available}")
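Note on this hunk: @st.cache_resource memoizes load_model() across Streamlit reruns, so the spinner only shows on the first load of a session. A minimal sketch of the pattern, assuming a plain transformers load (the checkpoint name comes from the diff; the from_pretrained arguments here are illustrative):

    import streamlit as st
    import torch
    from transformers import AutoModelForCausalLM

    @st.cache_resource  # cached once per process; later reruns reuse the same object
    def load_model():
        with st.spinner("Loading model..."):
            gpu_available = torch.cuda.is_available()
            st.info(f"GPU available: {gpu_available}")
            # Illustrative load; the app passes further arguments not shown here
            model = AutoModelForCausalLM.from_pretrained(
                "sciphi/triplex", trust_remote_code=True
            )
        return model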
@@ -86,7 +86,9 @@ def load_model():
         )
 
         # Move model to appropriate device if needed
-        if
+        # Check if the model has a device_map attribute and if it's not None
+        # If it has a device_map, it's already distributed across devices and shouldn't be moved
+        if not hasattr(model, 'device_map') or model.device_map is None:
             model = model.to(device)
 
         tokenizer = AutoTokenizer.from_pretrained("sciphi/triplex", trust_remote_code=True)
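Why the guard matters: a model loaded with device_map (e.g. device_map="auto") is dispatched across devices by accelerate, and calling .to(device) on it raises an error. A hedged sketch of the same check; recent transformers versions record the placement as hf_device_map, so the getattr below is an assumption on my part rather than the app's exact attribute:

    import torch
    from transformers import AutoModelForCausalLM

    model = AutoModelForCausalLM.from_pretrained(
        "sciphi/triplex",
        trust_remote_code=True,
        device_map="auto",  # accelerate decides where each layer lives (requires accelerate)
    )

    # Only move the model manually when no placement map exists;
    # getattr keeps the check safe for models loaded without device_map.
    if getattr(model, "hf_device_map", None) is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"
        model = model.to(device)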
@@ -123,8 +125,15 @@ def triplextract(model, tokenizer, text, entity_types, predicates, use_vllm=True
     else:
         # Use standard transformers
         messages = [{'role': 'user', 'content': message}]
-
-
+
+        # Handle device mapping differently based on model configuration
+        if hasattr(model, 'device_map') and model.device_map is not None:
+            # Model already has device mapping, don't need to specify device for input_ids
+            input_ids = tokenizer.apply_chat_template(messages, add_generation_prompt=True, return_tensors="pt")
+        else:
+            # Get the device the model is on
+            device = next(model.parameters()).device
+            input_ids = tokenizer.apply_chat_template(messages, add_generation_prompt=True, return_tensors="pt").to(device)
         output = tokenizer.decode(model.generate(input_ids=input_ids, max_length=2048)[0], skip_special_tokens=True)
 
     processing_time = time.time() - start_time
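The added branch exists because input tensors must live on the same device as the weights for a plain model, while a dispatched model moves inputs itself through its forward hooks. A condensed sketch of the non-dispatched path, reusing model and tokenizer from the loader above; the sample prompt is illustrative:

    # Illustrative prompt; the app builds `message` from text, entity_types, and predicates
    message = "Extract triples from: Paris is the capital of France."
    messages = [{"role": "user", "content": message}]
    input_ids = tokenizer.apply_chat_template(
        messages, add_generation_prompt=True, return_tensors="pt"
    )
    # next(model.parameters()).device finds the weights' device without
    # having to track it separately
    input_ids = input_ids.to(next(model.parameters()).device)
    output = tokenizer.decode(
        model.generate(input_ids=input_ids, max_length=2048)[0],
        skip_special_tokens=True,
    )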
@@ -278,9 +287,14 @@ def main():
 
     # Add a note about performance
     if torch.cuda.is_available():
-
-
-
+        if use_vllm:
+            st.success("""
+            🚀 Running on GPU with vllm for optimal performance!
+            """)
+        else:
+            st.success("""
+            🚀 Running on GPU for improved performance!
+            """)
     else:
         st.warning("""
         ⚠️ You are running on CPU which can be very slow for the SciPhi/Triplex model.