Update app.py
Browse files
app.py
CHANGED
@@ -3,6 +3,7 @@ from transformers import TextIteratorStreamer
|
|
3 |
from threading import Thread
|
4 |
from transformers import StoppingCriteria, StoppingCriteriaList
|
5 |
import torch
|
|
|
6 |
import os
|
7 |
model_name = "microsoft/Phi-3-medium-128k-instruct"
|
8 |
from transformers import AutoModelForCausalLM, AutoTokenizer
|
@@ -16,6 +17,7 @@ class StopOnTokens(StoppingCriteria):
|
|
16 |
if input_ids[0][-1] == stop_id:
|
17 |
return True
|
18 |
return False
|
|
|
19 |
def predict(message, history):
|
20 |
history_transformer_format = history + [[message, ""]]
|
21 |
stop = StopOnTokens()
|
|
|
3 |
from threading import Thread
|
4 |
from transformers import StoppingCriteria, StoppingCriteriaList
|
5 |
import torch
|
6 |
+
import spaces
|
7 |
import os
|
8 |
model_name = "microsoft/Phi-3-medium-128k-instruct"
|
9 |
from transformers import AutoModelForCausalLM, AutoTokenizer
|
|
|
17 |
if input_ids[0][-1] == stop_id:
|
18 |
return True
|
19 |
return False
|
20 |
+
@spaces.GPU()
|
21 |
def predict(message, history):
|
22 |
history_transformer_format = history + [[message, ""]]
|
23 |
stop = StopOnTokens()
|