kishoregajjala committed
Commit a521442 • Parent(s): 03cda43
Team-5 final demo
Browse files
- app.py +149 -0
- chat_agent.py +114 -0
- llama_guard.py +115 -0
- prompts/llama_guard-unsafe_categories.txt +36 -0
- recommendation_agent.py +91 -0
- requirements.txt +19 -0
app.py
ADDED
@@ -0,0 +1,149 @@
import streamlit as st
from llama_guard import moderate_chat, get_category_name
import time
from chat_agent import convo, main
from chat_agent import choose_model1, delete_all_variables
from recommendation_agent import recommend2, choose_model2, is_depressed
from functools import lru_cache
from streamlit_js_eval import streamlit_js_eval


# ST : https://docs.streamlit.io/knowledge-base/tutorials/build-conversational-apps

# Set the page to wide mode
st.set_page_config(layout="wide")

# Set the title
st.title('BrighterDays Mentor')

# Adjust sidebar width to take half the screen
col1, col2 = st.columns([2, 3])

model = st.sidebar.selectbox(label="Choose the LLM model", options=["Vanilla Model", "Fine Tuned Model"])
print("\n\nSelected LLM model from dropdown:", model)
choose_model1(model)
choose_model2(model)
main()

# Function to update recommendations in the sidebar
def update_recommendations(summary):
    # with col1:
    #     st.header("Recommendation")
    #     recommend = recommend2(summary)
    #     st.write(recommend)  # Update the content with new_content
    with st.sidebar:
        st.divider()
        st.write("Potential Mental Health Condition:")
        st.write(is_depressed(summary))
        st.header("Mental Health Advice:")
        with st.spinner('Thinking...'):
            # time.sleep(5)
            recommend = recommend2(summary)
            st.write(recommend)

# Add refresh button (simulated)
# if st.button("Refresh Chat"):
#     del st.session_state
#     delete_all_variables()
#     startup()
#     st.rerun()

@lru_cache(maxsize=None)  # lru_cache instead of cached_property: this is a module-level function, not a class attribute
def get_recommendations():
    return "These are some updated recommendations."


def response_generator(response):
    '''
    Streams the response with a typewriter effect.
    '''
    response_buffer = response.strip()
    for word in response_buffer.split():
        yield word + " "
        time.sleep(0.03)


def startup():
    with st.chat_message("assistant"):
        time.sleep(0.2)
        st.markdown("Hi, I am your Mental Health Counselor. How can I help you today?")

# Initialize chat history
if "messages" not in st.session_state:
    st.session_state.messages = []

for message in st.session_state.messages:
    with st.chat_message(message["role"]):
        st.markdown(message["content"])


# Check if 'llama_guard_enabled' is already in session state, otherwise initialize it
if 'llama_guard_enabled' not in st.session_state:
    st.session_state['llama_guard_enabled'] = True  # Default value is True

# Give the checkbox a unique key so it survives reruns
llama_guard_enabled = st.sidebar.checkbox("Enable LlamaGuard",
                                          value=st.session_state['llama_guard_enabled'],
                                          key="llama_guard_toggle")


# Update the session state based on the checkbox interaction
st.session_state['llama_guard_enabled'] = llama_guard_enabled

# with st.chat_message("assistant"):
#     st.write("Please tell me about your mental health condition and we can explore together. Potential mental health advice that could help you will be in the sidebar as we talk")

# Accept user input
# if user_prompt := st.chat_input("Hello, How are you doing today"):
if user_prompt := st.chat_input(""):
    st.session_state.messages.append({"role": "user", "content": user_prompt})
    with st.chat_message("user"):
        st.markdown(user_prompt)

    with st.chat_message("assistant"):
        print('llama guard enabled:', st.session_state['llama_guard_enabled'])
        is_safe = True
        unsafe_category_name = ""
        # added on March 29th
        response = ""
        if st.session_state['llama_guard_enabled']:
            # guard_status = moderate_chat(user_prompt)
            guard_status, error = moderate_chat(user_prompt)
            if error:
                st.error(f"Failed to retrieve data from Llama Guard: {error}")
            else:
                if 'unsafe' in guard_status[0]['generated_text']:
                    is_safe = False
                    # added on March 24th
                    unsafe_category_name = get_category_name(guard_status[0]['generated_text'])
                    print(f'Guard status {guard_status}, Category name {unsafe_category_name}')
        if not is_safe:
            # added on March 24th
            response = f"I see you are asking something about {unsafe_category_name}. Due to ethical and safety reasons, I can't provide the help you need. Please reach out to someone who can, like a family member, friend, or therapist. In urgent situations, contact emergency services or a crisis hotline. Remember, asking for help is brave, and you're not alone."
            st.write_stream(response_generator(response))
            response, summary = convo("")
            st.write_stream(response_generator(response))
            # update_recommendations(summary)
        else:
            response, summary = convo(user_prompt)
            # print(conversation.memory.buffer)
            time.sleep(0.2)
            st.write_stream(response_generator(response))
            print("This is the response from app.py:", response)
            update_recommendations(summary)

    st.session_state.messages.append({"role": "assistant", "content": response})

# if st.button("Refresh Chat"):
#     st.session_state = {'messages': []}
#     print("\n\n refreshed session state:", st.session_state)
#     startup()
#     st.rerun()
#     delete_all_variables()
#     startup()


if st.button("Reset Chat"):
    delete_all_variables()
    streamlit_js_eval(js_expressions="parent.window.location.reload()")

startup()
chat_agent.py
ADDED
@@ -0,0 +1,114 @@
import streamlit as st
from dotenv import load_dotenv, find_dotenv
import os
import time
from langchain.chains import LLMChain
from langchain_community.llms import HuggingFaceEndpoint
from langchain.prompts import PromptTemplate
from langchain_community.embeddings.sentence_transformer import SentenceTransformerEmbeddings
from langchain.prompts.chat import (
    ChatPromptTemplate,
    HumanMessagePromptTemplate,
    SystemMessagePromptTemplate,
)
from langchain.memory import ChatMessageHistory, ConversationSummaryBufferMemory, ConversationBufferMemory, ConversationSummaryMemory
from langchain.chains import LLMChain, ConversationChain


# Please ensure you have a .env file available with 'HUGGINGFACEHUB_API_TOKEN'
load_dotenv(find_dotenv())
HUGGINGFACEHUB_API_TOKEN = os.environ["HUGGINGFACEHUB_API_TOKEN"]

repo_id = "mistralai/Mistral-7B-Instruct-v0.2"

def choose_model1(model):
    global repo_id
    if model == "Vanilla Model":
        repo_id = "mistralai/Mistral-7B-Instruct-v0.2"
        print("model chosen from chat:", repo_id)
    else:
        repo_id = "GRMenon/mental-health-mistral-7b-instructv0.2-finetuned-V2"
        print("model chosen from chat:", repo_id)

query2 = " "

def main():
    llm = HuggingFaceEndpoint(
        repo_id=repo_id, max_length=512, temperature=0.5, token=HUGGINGFACEHUB_API_TOKEN
    )

    # template = """Act as a therapist, and conduct therapy sessions with the user. Your goal is to analyse their mental health
    # problem, based on the following input: {query}. Do not show your thought process, only output a single question.
    # Your output should contain consolation related to the query and a single question. Only ask one question at a time."""

    # def ConvoLLM(query: str):
    #     prompt_template = PromptTemplate(input_variables=['query'], template=template)
    #     prompt_template.format(query=query)
    #     chain = LLMChain(llm=llm, prompt=prompt_template)
    #     response = chain.run(query)
    #     return response

    # ---------------------------------------------------------------------------------------------------------------------------------------

    # memory = ConversationSummaryBufferMemory(llm=llm, max_token_limit=10)
    # memory.save_context({"input": "hi"}, {"output": "whats up"})

    # def ConvoLLM(query: str):
    #     conversation.predict(input=query)

    # ---------------------------------------------------------------------------------------------------------------------------------------
    # print(conversation.predict(input="I am feeling low"))
    # print(conversation.predict(input="I am alone at home"))
    # print(conversation.memory.buffer)
    global conversation, memory
    template = """Act as an expert mental health therapist, and conduct therapy sessions with the user. You are an expert mental health therapist who is asking the user questions to learn what professional mental health well-being advice could help the user.
Your goal is to analyse their mental health problem, based on the following input: {input}. You will always ask questions to the user to get them to explain more about whatever mental health condition is ailing them.
DO NOT give the user any mental health advice or medical advice, ONLY ask for more information about their symptoms.
Do not show your thought process, only output a single question. Your output should contain consolation related to the query and a single question.
Only ask one question at a time.

Current conversation:
{history}

Human: {input}
AI Assistant:"""

    PROMPT = PromptTemplate(input_variables=["history", "input"], template=template)
    memory = ConversationBufferMemory(llm=llm)
    # memory.save_context({"input": "hi"}, {"output": "whats up"})
    # memory.save_context({"input": "not much you"}, {"output": "not much"})
    # memory.save_context({"input": "feeling sad"}, {"output": "I am happy you feel that way"})

    conversation = ConversationChain(
        prompt=PROMPT,
        llm=llm,
        memory=memory,
        # verbose=True
    )


def convo(query):
    global conversation, memory, query2
    response = conversation.predict(input=query)
    # memory.save_context({"input": query}, {"output": ""})
    query2 = query2 + "," + query
    print("\n query2----------", query2)
    print("\n chat_agent.py----------", memory.chat_memory)
    summary = query2
    return response, summary


def delete_all_variables():
    global query2
    query2 = " "
    main()

# main()

# convo("I am feeling sad")
# convo("I am feeling Lonely")
# delete_all_variables()
# convo("I am feeling hungry")
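Note (not part of the commit): a minimal command-line smoke test for chat_agent.py, assuming a valid HUGGINGFACEHUB_API_TOKEN in a local .env file and network access to the Hugging Face Inference API; the file name smoke_test_chat.py and the sample prompt are illustrative only.

# smoke_test_chat.py (hypothetical helper, not in the commit)
from chat_agent import choose_model1, main, convo, delete_all_variables

if __name__ == "__main__":
    choose_model1("Vanilla Model")   # select the base Mistral-7B-Instruct endpoint
    main()                           # build the ConversationChain and its memory
    reply, summary = convo("I have been feeling anxious lately.")
    print("Assistant:", reply)
    print("Summary so far:", summary)
    delete_all_variables()           # reset the accumulated summary and rebuild the chain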
llama_guard.py
ADDED
@@ -0,0 +1,115 @@
# Docs:- https://huggingface.co/meta-llama/LlamaGuard-7b
from dotenv import load_dotenv, find_dotenv
import os
import requests

model_id = "meta-llama/LlamaGuard-7b"


# Please ensure you have a .env file available with 'HUGGINGFACEHUB_API_TOKEN'
load_dotenv(find_dotenv())
HUGGINGFACEHUB_API_TOKEN = os.environ["HUGGINGFACEHUB_API_TOKEN"]

# updated on March 24th

# reading unsafe categories file
try:
    with open('./prompts/llama_guard-unsafe_categories.txt', 'r') as file:
        unsafe_categories = file.read()
except FileNotFoundError:
    print("File not found. Please check the file path.")
    unsafe_categories = ""  # Or handle the missing file as appropriate for your application
except PermissionError:
    print("Permission denied. You don't have the rights to read the file.")
    unsafe_categories = ""  # Or handle the lack of permissions as appropriate
except Exception as e:  # Catch any other exceptions
    print(f"An error occurred while reading the file: {e}")
    unsafe_categories = ""  # Or handle unexpected errors as appropriate

# added on March 24th
def moderation_prompt_for_chat(chat):
    prompt = f"""[INST] Task: Check if there is unsafe content in 'User' messages in conversations according our safety policy with the below categories.
<BEGIN UNSAFE CONTENT CATEGORIES>
{unsafe_categories}
<END UNSAFE CONTENT CATEGORIES>

<BEGIN CONVERSATION>

User : {chat}

<END CONVERSATION>

Provide your safety assessment for 'User' in the above conversation:
- First line must read 'safe' or 'unsafe'.
- If unsafe, a second line must include a comma-separated list of violated categories.[/INST]"""
    return prompt

def query(payload):
    API_URL = "https://okoknht2arqo574k.us-east-1.aws.endpoints.huggingface.cloud"
    bearer_txt = f'Bearer {HUGGINGFACEHUB_API_TOKEN}'
    headers = {
        "Accept": "application/json",
        "Authorization": bearer_txt,
        "Content-Type": "application/json"
    }
    try:
        response = requests.post(API_URL, headers=headers, json=payload)
        response.raise_for_status()  # This will raise an exception for HTTP error responses
        return response.json(), None
    except requests.exceptions.HTTPError as http_err:
        error_message = f"HTTP error occurred: {http_err}"
        print(error_message)
    except requests.exceptions.ConnectionError:
        error_message = "Could not connect to the API endpoint."
        print(error_message)
    except Exception as err:
        error_message = f"An error occurred: {err}"
        print(error_message)

    return None, error_message


def moderate_chat(chat):
    prompt = moderation_prompt_for_chat(chat)

    output, error_msg = query({
        "inputs": prompt,
        "parameters": {
            "top_k": 1,
            "top_p": 0.2,
            "temperature": 0.1,
            "max_new_tokens": 512
        }
    })

    return output, error_msg


# added on March 24th
def load_category_names_from_string(file_content):
    """Load category codes and names from a string into a dictionary."""
    category_names = {}
    lines = file_content.split('\n')
    for line in lines:
        if line.startswith("O"):
            parts = line.split(':')
            if len(parts) == 2:
                code = parts[0].strip()
                name = parts[1].strip()
                category_names[code] = name
    return category_names

def get_category_name(input_str):
    """Return the category name given a category code from an input string."""
    # Load the category names from the file content
    category_names = load_category_names_from_string(unsafe_categories)

    # Extract the category code from the input string
    category_code = input_str.split('\n')[1].strip()

    # Find the full category name using the code
    category_name = category_names.get(category_code, "Unknown Category")

    # return f"{category_code} : {category_name}"
    return f"{category_name}"
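Note (not part of the commit): a hedged sketch of how llama_guard.py could be exercised on its own, assuming the private inference endpoint hard-coded in query() is reachable and HUGGINGFACEHUB_API_TOKEN is set; the file name and test prompt are illustrative only.

# smoke_test_guard.py (hypothetical helper, not in the commit)
from llama_guard import moderate_chat, get_category_name

if __name__ == "__main__":
    output, error = moderate_chat("How can I get an illegal weapon?")
    if error:
        print("Guard call failed:", error)
    else:
        verdict = output[0]["generated_text"]   # first line: 'safe' or 'unsafe'
        print("Verdict:", verdict)
        if "unsafe" in verdict:
            print("Category:", get_category_name(verdict))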
prompts/llama_guard-unsafe_categories.txt
ADDED
@@ -0,0 +1,36 @@
O1: Violence and Hate.
Should not
- Help people plan or engage in violence.
- Encourage people to engage in violence.
- Express hateful, derogatory, or demeaning sentiments against people based on sensitive personal characteristics like their race, color, religion, national origin, sexual orientation, gender, gender identity, or disability.
- Encourage discrimination against people based on those characteristics and should not use slurs or create insults based on those characteristics.
Can
- Provide information on violence and discrimination.
- Discuss the topics of hate and violence, and can discuss historical events involving violence.
O2: Sexual Content.
Should not
- Engage in sexually explicit (i.e., erotic) conversations and should not encourage people to engage in sexual activities.
Can
- Discuss topics related to sexuality and sexual education.
- Provide general information regarding sexual activities.
- Provide guidance on topics related to sexual health.
O3: Criminal Planning.
Should not
- Help people plan or engage in criminal activities like theft, kidnapping, or financial crimes, especially when the person indicates that they may act on the information provided to them.
- Encourage people to engage in criminal activities or condone criminal behavior.
Can
- Help people understand how crimes are committed, provided that the response is not so detailed that it could be used to provide a roadmap for future crimes.
- Help people understand how systems work that are used to prevent crimes, provided those descriptions do not provide detailed information on how to defeat those systems.
O4: Guns and Illegal Weapons.
Should not
- Assist in the planning or execution of firearm or illegal weapons crimes, especially when the person indicates that they may act on the information provided to them.
- Encourage people to engage in firearm or illegal weapons crimes or condone such behavior.
Can
- Discuss firearms and the arguments for and against firearm ownership.
O5: Regulated or Controlled Substances.
Should not
- Encourage or assist people in the trafficking, consumption, or use of regulated or controlled substances, including but not limited to illegal drugs, tobacco, alcohol, and cannabis.
- Assist or encourage people to create such substances when it is illegal to do so.
Can
- Discuss the use of regulated or controlled substances, and can provide general information on those substances (ex: what their chemical composition is and what their effects are).
- Provide information on how those substances are manufactured, provided that the information is not so detailed as to provide step-by-step instructions for their creation.
recommendation_agent.py
ADDED
@@ -0,0 +1,91 @@
from langchain_community.embeddings.sentence_transformer import (
    SentenceTransformerEmbeddings,
)
from langchain.retrievers import ParentDocumentRetriever
from langchain_community.vectorstores import Chroma
from langchain_text_splitters import CharacterTextSplitter, RecursiveCharacterTextSplitter
import pickle
from langchain_community.llms import HuggingFaceEndpoint
from dotenv import load_dotenv, find_dotenv
import os
from langchain.chains import ConversationalRetrievalChain
from langchain.chains import LLMChain
from langchain.chains.conversational_retrieval.prompts import CONDENSE_QUESTION_PROMPT
from langchain.chains.question_answering import load_qa_chain
from langchain import PromptTemplate
from transformers import pipeline


# Please ensure you have a .env file available with 'HUGGINGFACEHUB_API_TOKEN'
load_dotenv(find_dotenv())
HUGGINGFACEHUB_API_TOKEN = os.environ["HUGGINGFACEHUB_API_TOKEN"]

repo_id = "mistralai/Mistral-7B-Instruct-v0.2"

def choose_model2(model):
    global repo_id
    if model == "Vanilla Model":
        repo_id = "mistralai/Mistral-7B-Instruct-v0.2"
        print("model chosen from recommendation agent:", repo_id)
    else:
        repo_id = "GRMenon/mental-health-mistral-7b-instructv0.2-finetuned-V2"
        print("model chosen from recommendation agent:", repo_id)

llm = HuggingFaceEndpoint(
    repo_id=repo_id, max_length=512, temperature=0.5, token=HUGGINGFACEHUB_API_TOKEN
)

persist_directory = "Data/chroma"
# chroma_client = chromadb.PersistentClient(persist_directory=persist_directory)
embedding_function = SentenceTransformerEmbeddings(model_name="all-MiniLM-L6-v2")
vectors = Chroma(persist_directory=persist_directory, embedding_function=embedding_function, collection_name="split_parents")
retriever = vectors.as_retriever()  # (k=6)

# prompt = "you are a mental health therapist, talking to a person who is facing some mental health issues. Following is the user feeling {question}"

prompt = """You're a Mental Health Specialist. Support those with Depressive Disorder. Your task is to provide mental health advice.
Listen compassionately, respond helpfully. For casual talk, be friendly. For facts, use context.
Following is the user feeling {question}.
If unsure, say, 'Out of my knowledge.' Always stay direct.
If you cannot find the answer from the pieces of context, just say that you don't know, don't try to make up an answer.
PLEASE GIVE THE RESPONSE IN THE FORM OF BULLET POINTS.
----------------
{context}"""

prompt = PromptTemplate(input_variables=['question'], template=prompt)

chain1 = LLMChain(llm=llm, prompt=prompt, verbose=True)
doc_chain = load_qa_chain(llm, chain_type="stuff")

chain = ConversationalRetrievalChain(
    retriever=retriever,
    question_generator=chain1,
    combine_docs_chain=doc_chain,
    verbose=True,
)


def recommend2(query):
    chat_history = []
    # query = "I feel sad"
    result = chain.invoke({"question": query, "chat_history": chat_history})
    print("---------------\nSummary from the chat agent:", query)
    # print("this is the result from vector database:", vectors.similarity_search(query))
    # print("this is the response from recommendation agent:", result["answer"])

    return result["answer"]

def is_depressed(human_inputs):
    '''
    Returns whether, according to the human inputs, the person is depressed or not.
    '''
    # Implement classification
    # all_user_inputs = ''.join(human_inputs)
    pipe = pipeline('sentiment-analysis')
    status = pipe(human_inputs)
    return 'Is depressed' if status[0]["label"] == "NEGATIVE" else 'Not Depressed'
    # return status[0]["label"]


# print(recommend2("i am feeling sad"))
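Note (not part of the commit): recommendation_agent.py reads an existing Chroma store at Data/chroma with a "split_parents" collection, but the commit does not show how that store is built. Below is a minimal sketch of one way to populate it, assuming the source documents are PDFs under a hypothetical Data/docs folder and reusing the same all-MiniLM-L6-v2 embeddings; the chunk sizes are arbitrary.

# build_index.py (hypothetical helper, not in the commit)
import glob
from langchain_community.document_loaders import PyMuPDFLoader
from langchain_community.embeddings.sentence_transformer import SentenceTransformerEmbeddings
from langchain_community.vectorstores import Chroma
from langchain_text_splitters import RecursiveCharacterTextSplitter

docs = []
for path in glob.glob("Data/docs/*.pdf"):            # assumed location of the source PDFs
    docs.extend(PyMuPDFLoader(path).load())

chunks = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=50).split_documents(docs)

Chroma.from_documents(
    chunks,
    SentenceTransformerEmbeddings(model_name="all-MiniLM-L6-v2"),
    collection_name="split_parents",                  # matches the collection recommendation_agent.py loads
    persist_directory="Data/chroma",                  # matches persist_directory in recommendation_agent.py
)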
requirements.txt
ADDED
@@ -0,0 +1,19 @@
# create new env
#   conda create --name LLM_chatbot
# activate the env
#   conda activate LLM_chatbot
#   pip install -r requirements.txt
# if streamlit is still unrecognized, run "conda install -c conda-forge streamlit"
# to run the app: streamlit run app.py
langchain==0.1.11
torch==2.0.1
transformers==4.36.2
langchain-community==0.0.27
streamlit==1.32.2
ctransformers==0.2.27
pymupdf==1.23.26
sentence-transformers==2.5.1
chromadb==0.4.24
langchain_experimental
accelerate
streamlit-js-eval
python-dotenv  # used by load_dotenv() in chat_agent.py, llama_guard.py and recommendation_agent.py