ivnban27-ctl commited on
Commit
bcc3d60
·
1 Parent(s): ee77d0d

utils to save transcript in mongodb

Browse files
Files changed (2) hide show
  1. convosim.py +3 -1
  2. utils/mongo_utils.py +10 -0
convosim.py CHANGED
@@ -2,7 +2,7 @@ import os
2
  import streamlit as st
3
  from streamlit.logger import get_logger
4
  from langchain.schema.messages import HumanMessage
5
- from utils.mongo_utils import get_db_client
6
  from utils.app_utils import create_memory_add_initial_message, get_random_name, DEFAULT_NAMES_DF
7
  from utils.memory_utils import clear_memory, push_convo2db
8
  from utils.chain_utils import get_chain, custom_chain_predict
@@ -87,6 +87,8 @@ if prompt := st.chat_input(disabled=st.session_state['total_messages'] > MAX_MSG
87
  # response = update_memory_completion(prompt, st.session_state["memory"], OA_engine, temperature)
88
  for response in responses:
89
  st.chat_message("assistant").write(response)
 
 
90
 
91
  st.session_state['total_messages'] = len(memoryA.chat_memory.messages)
92
  if st.session_state['total_messages'] >= MAX_MSG_COUNT:
 
2
  import streamlit as st
3
  from streamlit.logger import get_logger
4
  from langchain.schema.messages import HumanMessage
5
+ from utils.mongo_utils import get_db_client, update_convo
6
  from utils.app_utils import create_memory_add_initial_message, get_random_name, DEFAULT_NAMES_DF
7
  from utils.memory_utils import clear_memory, push_convo2db
8
  from utils.chain_utils import get_chain, custom_chain_predict
 
87
  # response = update_memory_completion(prompt, st.session_state["memory"], OA_engine, temperature)
88
  for response in responses:
89
  st.chat_message("assistant").write(response)
90
+ transcript = memoryA.load_memory_variables({})[memoryA.memory_key]
91
+ update_convo(st.session_state["db_client"], st.session_state["convo_id"], transcript)
92
 
93
  st.session_state['total_messages'] = len(memoryA.chat_memory.messages)
94
  if st.session_state['total_messages'] >= MAX_MSG_COUNT:
utils/mongo_utils.py CHANGED
@@ -41,6 +41,16 @@ def new_convo(client, issue, language, username, is_comparison, model_one, model
41
  logger.info(f"DBUTILS: new convo id is {convo_id}")
42
  st.session_state['convo_id'] = convo_id
43
 
 
 
 
 
 
 
 
 
 
 
44
  def new_comparison(client, prompt_timestamp, completion_timestamp,
45
  chat_history, prompt, completionA, completionB,
46
  source="webapp", subset=None
 
41
  logger.info(f"DBUTILS: new convo id is {convo_id}")
42
  st.session_state['convo_id'] = convo_id
43
 
44
+ def update_convo(client, convo_id, transcript=""):
45
+ from bson.objectid import ObjectId
46
+ db = client[DB_SCHEMA]
47
+ convos = db[DB_CONVOS]
48
+ myquery = { "_id": ObjectId(convo_id) }
49
+ newvalues = { "$set": { "transcript": transcript } }
50
+ result = convos.update_one(myquery, newvalues)
51
+ if result.matched_count == 1:
52
+ logger.debug(f"Updated conversation {convo_id}")
53
+
54
  def new_comparison(client, prompt_timestamp, completion_timestamp,
55
  chat_history, prompt, completionA, completionB,
56
  source="webapp", subset=None