# wellbeing_GenAI / app.py
import streamlit as st
from transformers import (
AutoTokenizer,
AutoModelForCausalLM,
TrainingArguments,
Trainer,
DataCollatorForLanguageModeling,
pipeline,
)
from datasets import load_dataset
import os
# Direct URL for the background image
flower_image_url = "https://i.postimg.cc/hG2FG85D/2.png"
# Inject custom CSS for the background with a centered and blurred image
st.markdown(
f"""
<style>
/* Container for background */
html, body {{
margin: 0;
padding: 0;
overflow: hidden;
}}
[data-testid="stAppViewContainer"] {{
position: relative;
z-index: 1; /* Ensure UI elements are above the background */
}}
/* Blurred background image */
.blurred-background {{
position: fixed;
top: 0;
left: 0;
width: 100%;
height: 100%;
z-index: -1; /* Send background image behind all UI elements */
background-image: url("{flower_image_url}");
background-size: cover;
background-position: center;
        filter: blur(10px); /* Adjust the blur radius here */
opacity: 0.8; /* Optional: Add slight transparency for a subtle effect */
}}
</style>
""",
unsafe_allow_html=True
)
# Add the blurred background div
st.markdown('<div class="blurred-background"></div>', unsafe_allow_html=True)
#""""""""""""""""""""""""" Application Code Starts here """""""""""""""""""""""""""""""""""""""""""""
# Cache resource for dataset loading
@st.cache_resource
def load_counseling_dataset():
    # Load the full train split of the counseling conversations dataset;
    # downstream code samples and batches from it to keep memory usage in check
    dataset = load_dataset("Amod/mental_health_counseling_conversations", split="train")
    return dataset
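# Optional sanity check (commented out; assumes the dataset's published
# "Context"/"Response" schema). Uncomment to inspect one record locally:
# sample = load_counseling_dataset()[0]
# print(sample)  # e.g. {'Context': '...user message...', 'Response': '...counselor reply...'}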
# Stream examples from a shuffled subset so the full dataset never has to sit
# in memory at once (note: this yields one example at a time; batch_size only
# bounds how many examples are drawn)
def process_dataset_in_batches(dataset, batch_size=500):
    for example in dataset.shuffle().select(range(batch_size)):
        yield example
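# Example usage (illustrative only; the training path below uses dataset.map instead):
# for example in process_dataset_in_batches(load_counseling_dataset(), batch_size=100):
#     print(example["Context"][:80])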
# Fine-tune the model and save it
@st.cache_resource
def fine_tune_model():
    # Reuse a previously saved fine-tuned model instead of retraining on every cold start
    output_dir = "./fine_tuned_model"
    if os.path.isdir(output_dir):
        return output_dir
    # Load base model and tokenizer
    model_name = "prabureddy/Mental-Health-FineTuned-Mistral-7B-Instruct-v0.2"
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    tokenizer.pad_token = tokenizer.eos_token  # Mistral tokenizers ship without a pad token; padding is needed below
    model = AutoModelForCausalLM.from_pretrained(model_name)
# Enable gradient checkpointing for memory optimization
model.gradient_checkpointing_enable()
# Prepare dataset for training
dataset = load_counseling_dataset()
    def preprocess_function(examples):
        # With batched=True, map passes lists of values; the dataset's fields
        # are "Context" and "Response", so join each pair individually
        texts = [c + "\n" + r for c, r in zip(examples["Context"], examples["Response"])]
        return tokenizer(texts, truncation=True)
    tokenized_dataset = dataset.map(
        preprocess_function,
        batched=True,
        remove_columns=dataset.column_names,  # drop the raw text columns the model cannot consume
    )
    # load_dataset returned a single train split, so carve out a validation set here
    tokenized_datasets = tokenized_dataset.train_test_split(test_size=0.1)
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
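    # With mlm=False the collator performs plain causal-LM collation: it pads each
    # batch and copies input_ids into labels (padding positions are masked with -100;
    # the model applies the causal shift internally).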
# Training arguments
training_args = TrainingArguments(
output_dir="./fine_tuned_model",
evaluation_strategy="steps",
learning_rate=2e-5,
per_device_train_batch_size=5,
per_device_eval_batch_size=5,
num_train_epochs=3,
weight_decay=0.01,
        fp16=True,  # FP16 halves activation memory, but requires a CUDA-capable GPU
save_total_limit=2,
save_steps=250,
logging_steps=50,
)
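    # Note: with evaluation_strategy="steps" and no eval_steps given, transformers
    # defaults the evaluation interval to logging_steps (50 here).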
# Trainer
trainer = Trainer(
model=model,
args=training_args,
        train_dataset=tokenized_datasets["train"],
        eval_dataset=tokenized_datasets["test"],  # the held-out split created above
tokenizer=tokenizer,
data_collator=data_collator,
)
trainer.train()
    # Save the fine-tuned model and tokenizer so later runs can skip training
    trainer.save_model(output_dir)
    tokenizer.save_pretrained(output_dir)
    return output_dir
# Load or fine-tune the model (st.cache_resource makes this run at most once per process)
model_dir = fine_tune_model()
# Load the fine-tuned model for inference
@st.cache_resource
def load_pipeline(model_dir):
return pipeline("text-generation", model=model_dir)
pipe = load_pipeline(model_dir)
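# Optional smoke test (illustrative; assumes the pipeline loaded successfully).
# Uncomment to verify generation before exercising the UI:
# print(pipe("I have been feeling overwhelmed lately.", max_new_tokens=30)[0]["generated_text"])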
# Streamlit App
st.title("Mental Health Support Assistant")
st.markdown("""
Welcome to the **Mental Health Support Assistant**.
This tool helps detect potential mental health concerns based on user input and provides **uplifting and positive suggestions** to boost morale.
""")
# User input for mental health concerns
user_input = st.text_area("Please share your concern:", placeholder="Type your question or concern here...")
if st.button("Get Supportive Response"):
if user_input.strip():
with st.spinner("Analyzing your input and generating a response..."):
try:
                # Generate a response; max_new_tokens caps only the reply length,
                # whereas max_length would also count the prompt tokens
                response = pipe(user_input, max_new_tokens=150, num_return_sequences=1)[0]["generated_text"]
st.subheader("Supportive Suggestion:")
st.markdown(f"**{response}**")
except Exception as e:
st.error(f"An error occurred while generating the response: {e}")
else:
st.error("Please enter a concern to receive suggestions.")
# Sidebar for additional resources
st.sidebar.header("Additional Resources")
st.sidebar.markdown("""
- [Mental Health Foundation](https://www.mentalhealth.org)
- [Mind](https://www.mind.org.uk)
- [National Suicide Prevention Lifeline](https://suicidepreventionlifeline.org)
""")
st.sidebar.info("This application is not a replacement for professional counseling. If you're in crisis, seek professional help immediately.")