Spaces: Running
Commit 5b74371 · conversation end
Parent(s): a59350a

Files changed:
- app_config.py +6 -3
- convosim.py +13 -4
- models/databricks/texter_sim_llm.py +2 -2
- requirements.txt +2 -1
app_config.py
CHANGED
@@ -9,7 +9,7 @@ SOURCES = [
     'OA_rolemodel',
     # 'OA_finetuned',
 ]
-SOURCES_LAB = {"OA_rolemodel":'OpenAI
+SOURCES_LAB = {"OA_rolemodel":'OpenAI GPT4o',
                "OA_finetuned":'Finetuned OpenAI',
                # "CTL_llama2": "Llama 2",
                "CTL_llama3": "Llama 3",
@@ -29,10 +29,13 @@ def source2label(source):
 def issue2label(issue):
     return seed2str.get(issue, "GCT")
 
-ENVIRON = "
+ENVIRON = "prod"
 
 DB_SCHEMA = 'prod_db' if ENVIRON == 'prod' else 'test_db'
 DB_CONVOS = 'conversations'
 DB_COMPLETIONS = 'comparison_completions'
 DB_BATTLES = 'battles'
-DB_ERRORS = 'completion_errors'
+DB_ERRORS = 'completion_errors'
+
+MAX_MSG_COUNT = 10
+WARN_MSG_COUT = int(MAX_MSG_COUNT*0.8)
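The two constants added at the bottom of app_config.py define a hard cap and a warning threshold for the simulated conversation: with MAX_MSG_COUNT = 10, WARN_MSG_COUT works out to int(10*0.8) = 8. A minimal sketch of how the thresholds classify a running message count (the check_limit helper is illustrative, not part of the commit):

MAX_MSG_COUNT = 10
WARN_MSG_COUT = int(MAX_MSG_COUNT*0.8)  # int(8.0) == 8

def check_limit(total_messages):
    # Illustrative helper mirroring the checks convosim.py performs below.
    if total_messages >= MAX_MSG_COUNT:
        return "ended"    # hard cap reached, conversation is over
    if total_messages >= WARN_MSG_COUT:
        return "warning"  # user is warned the conversation will end soon
    return "ok"

assert check_limit(8) == "warning"
assert check_limit(10) == "ended"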
convosim.py
CHANGED
@@ -6,7 +6,7 @@ from utils.mongo_utils import get_db_client
 from utils.app_utils import create_memory_add_initial_message, get_random_name, DEFAULT_NAMES_DF
 from utils.memory_utils import clear_memory, push_convo2db
 from utils.chain_utils import get_chain, custom_chain_predict
-from app_config import ISSUES, SOURCES, source2label, issue2label
+from app_config import ISSUES, SOURCES, source2label, issue2label, MAX_MSG_COUNT, WARN_MSG_COUT
 
 logger = get_logger(__name__)
 openai_api_key = os.environ['OPENAI_API_KEY']
@@ -15,6 +15,8 @@ temperature = 0.8
 
 if "sent_messages" not in st.session_state:
     st.session_state['sent_messages'] = 0
+if "total_messages" not in st.session_state:
+    st.session_state['total_messages'] = 0
 if "issue" not in st.session_state:
     st.session_state['issue'] = ISSUES[0]
 if 'previous_source' not in st.session_state:
@@ -57,6 +59,7 @@ if changed_source:
     st.session_state['previous_source'] = source
     st.session_state['issue'] = issue
     st.session_state['sent_messages'] = 0
+    st.session_state['total_messages'] = 0
     create_memory_add_initial_message(memories,
                                       issue,
                                       language,
@@ -69,12 +72,12 @@ memoryA = st.session_state[list(memories.keys())[0]]
 llm_chain, stopper = get_chain(issue, language, source, memoryA, temperature, texter_name=st.session_state["texter_name"])
 
 st.title("💬 Simulator")
-
+st.session_state['total_messages'] = len(memoryA.chat_memory.messages)
 for msg in memoryA.buffer_as_messages:
     role = "user" if type(msg) == HumanMessage else "assistant"
     st.chat_message(role).write(msg.content)
 
-if prompt := st.chat_input():
+if prompt := st.chat_input(disabled=st.session_state['total_messages'] > MAX_MSG_COUNT - 4): #account for next interaction
     st.session_state['sent_messages'] += 1
     st.chat_message("user").write(prompt)
     if 'convo_id' not in st.session_state:
@@ -85,6 +88,12 @@ if prompt := st.chat_input():
     for response in responses:
         st.chat_message("assistant").write(response)
 
+    st.session_state['total_messages'] = len(memoryA.chat_memory.messages)
+    if st.session_state['total_messages'] >= MAX_MSG_COUNT:
+        st.toast(f"Total of {MAX_MSG_COUNT} Messages reached. Conversation Ended", icon=":material/verified:")
+    elif st.session_state['total_messages'] >= WARN_MSG_COUT:
+        st.toast(f"The conversation will end at {MAX_MSG_COUNT} Total Messages ", icon=":material/warning:")
+
 with st.sidebar:
     st.markdown(f"### Total Sent Messages: :red[**{st.session_state['sent_messages']}**]")
-    st.markdown(f"### Total Messages: :red[**{
+    st.markdown(f"### Total Messages: :red[**{st.session_state['total_messages']}**]")
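convosim.py recomputes total_messages from the LangChain memory buffer on each Streamlit rerun, so the count includes both user and assistant turns; the chat box is disabled once total_messages > MAX_MSG_COUNT - 4, leaving room for the next user/assistant exchange before the hard cap. A standalone sketch of that gating logic, with a plain list standing in for memoryA.chat_memory.messages:

MAX_MSG_COUNT = 10
WARN_MSG_COUT = int(MAX_MSG_COUNT*0.8)

messages = []  # stand-in for memoryA.chat_memory.messages

def input_disabled(total_messages):
    # Same condition as st.chat_input(disabled=...) in the diff above.
    return total_messages > MAX_MSG_COUNT - 4

while not input_disabled(len(messages)):
    messages.append("user turn")
    messages.append("assistant turn")

total = len(messages)
print(total)                   # 8
print(total >= WARN_MSG_COUT)  # True -> the warning toast would fire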
models/databricks/texter_sim_llm.py
CHANGED
@@ -24,8 +24,8 @@ def get_databricks_chain(source, issue, language, memory, temperature=0.8, texte
     )
 
     llm = CustomDatabricksLLM(
-        endpoint_url="https://dbc-6dca8e8f-4084.cloud.databricks.com/serving-endpoints/databricks-meta-llama-3-1-70b-instruct/invocations",
-
+        # endpoint_url="https://dbc-6dca8e8f-4084.cloud.databricks.com/serving-endpoints/databricks-meta-llama-3-1-70b-instruct/invocations",
+        endpoint_url=os.environ['DATABRICKS_URL'].format(endpoint_name=endpoint_name),
         bearer_token=os.environ["DATABRICKS_TOKEN"],
         texter_name=texter_name,
         issue=issue,
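The hard-coded serving-endpoint URL is replaced by a template read from the environment; the .format call implies DATABRICKS_URL contains an {endpoint_name} placeholder, and endpoint_name is defined earlier in the function (outside this hunk). A minimal sketch under those assumptions, with illustrative values:

import os

# Illustrative values; the real ones live in the deployment environment.
os.environ.setdefault(
    "DATABRICKS_URL",
    "https://example.cloud.databricks.com/serving-endpoints/{endpoint_name}/invocations",
)
endpoint_name = "databricks-meta-llama-3-1-70b-instruct"

endpoint_url = os.environ['DATABRICKS_URL'].format(endpoint_name=endpoint_name)
print(endpoint_url)
# https://example.cloud.databricks.com/serving-endpoints/databricks-meta-llama-3-1-70b-instruct/invocations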
requirements.txt
CHANGED
@@ -2,4 +2,5 @@ scipy==1.11.1
 langchain==0.3.0
 pymongo==4.5.0
 mlflow==2.9.0
-langchain-openai==0.2.0
+langchain-openai==0.2.0
+streamlit==1.38.0