Spaces:
Running
Running
Commit
·
bcc3d60
1
Parent(s):
ee77d0d
utils to save transcript in mongodb
Browse files- convosim.py +3 -1
- utils/mongo_utils.py +10 -0
convosim.py
CHANGED
@@ -2,7 +2,7 @@ import os
|
|
2 |
import streamlit as st
|
3 |
from streamlit.logger import get_logger
|
4 |
from langchain.schema.messages import HumanMessage
|
5 |
-
from utils.mongo_utils import get_db_client
|
6 |
from utils.app_utils import create_memory_add_initial_message, get_random_name, DEFAULT_NAMES_DF
|
7 |
from utils.memory_utils import clear_memory, push_convo2db
|
8 |
from utils.chain_utils import get_chain, custom_chain_predict
|
@@ -87,6 +87,8 @@ if prompt := st.chat_input(disabled=st.session_state['total_messages'] > MAX_MSG
|
|
87 |
# response = update_memory_completion(prompt, st.session_state["memory"], OA_engine, temperature)
|
88 |
for response in responses:
|
89 |
st.chat_message("assistant").write(response)
|
|
|
|
|
90 |
|
91 |
st.session_state['total_messages'] = len(memoryA.chat_memory.messages)
|
92 |
if st.session_state['total_messages'] >= MAX_MSG_COUNT:
|
|
|
2 |
import streamlit as st
|
3 |
from streamlit.logger import get_logger
|
4 |
from langchain.schema.messages import HumanMessage
|
5 |
+
from utils.mongo_utils import get_db_client, update_convo
|
6 |
from utils.app_utils import create_memory_add_initial_message, get_random_name, DEFAULT_NAMES_DF
|
7 |
from utils.memory_utils import clear_memory, push_convo2db
|
8 |
from utils.chain_utils import get_chain, custom_chain_predict
|
|
|
87 |
# response = update_memory_completion(prompt, st.session_state["memory"], OA_engine, temperature)
|
88 |
for response in responses:
|
89 |
st.chat_message("assistant").write(response)
|
90 |
+
transcript = memoryA.load_memory_variables({})[memoryA.memory_key]
|
91 |
+
update_convo(st.session_state["db_client"], st.session_state["convo_id"], transcript)
|
92 |
|
93 |
st.session_state['total_messages'] = len(memoryA.chat_memory.messages)
|
94 |
if st.session_state['total_messages'] >= MAX_MSG_COUNT:
|
utils/mongo_utils.py
CHANGED
@@ -41,6 +41,16 @@ def new_convo(client, issue, language, username, is_comparison, model_one, model
|
|
41 |
logger.info(f"DBUTILS: new convo id is {convo_id}")
|
42 |
st.session_state['convo_id'] = convo_id
|
43 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
44 |
def new_comparison(client, prompt_timestamp, completion_timestamp,
|
45 |
chat_history, prompt, completionA, completionB,
|
46 |
source="webapp", subset=None
|
|
|
41 |
logger.info(f"DBUTILS: new convo id is {convo_id}")
|
42 |
st.session_state['convo_id'] = convo_id
|
43 |
|
44 |
+
def update_convo(client, convo_id, transcript=""):
|
45 |
+
from bson.objectid import ObjectId
|
46 |
+
db = client[DB_SCHEMA]
|
47 |
+
convos = db[DB_CONVOS]
|
48 |
+
myquery = { "_id": ObjectId(convo_id) }
|
49 |
+
newvalues = { "$set": { "transcript": transcript } }
|
50 |
+
result = convos.update_one(myquery, newvalues)
|
51 |
+
if result.matched_count == 1:
|
52 |
+
logger.debug(f"Updated conversation {convo_id}")
|
53 |
+
|
54 |
def new_comparison(client, prompt_timestamp, completion_timestamp,
|
55 |
chat_history, prompt, completionA, completionB,
|
56 |
source="webapp", subset=None
|