Spaces:
Sleeping
Sleeping
import streamlit as st | |
from transformers import ( | |
AutoTokenizer, | |
AutoModelForCausalLM, | |
TrainingArguments, | |
Trainer, | |
DataCollatorForLanguageModeling, | |
pipeline, | |
) | |
#from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments, Trainer, pipeline | |
#from llama_cpp import Llama | |
from datasets import load_dataset | |
import os | |
import requests | |
# Direct link (not a gallery page) to the image used as the app's backdrop.
flower_image_url = "https://i.postimg.cc/hG2FG85D/2.png"

# Stylesheet: a fixed, full-viewport layer (z-index -1) shows the image blurred
# and slightly transparent, while the Streamlit app container is raised to
# z-index 1 so every widget stays on top of it.
_background_style = f"""
<style>
/* Container for background */
html, body {{
    margin: 0;
    padding: 0;
    overflow: hidden;
}}
[data-testid="stAppViewContainer"] {{
    position: relative;
    z-index: 1; /* Ensure UI elements are above the background */
}}
/* Blurred background image */
.blurred-background {{
    position: fixed;
    top: 0;
    left: 0;
    width: 100%;
    height: 100%;
    z-index: -1; /* Send background image behind all UI elements */
    background-image: url("{flower_image_url}");
    background-size: cover;
    background-position: center;
    filter: blur(10px); /* Adjust blur ratio here */
    opacity: 0.8; /* Optional: Add slight transparency for a subtle effect */
}}
</style>
"""

# Inject the stylesheet first, then the empty div that the `.blurred-background` rule styles.
st.markdown(_background_style, unsafe_allow_html=True)
st.markdown('<div class="blurred-background"></div>', unsafe_allow_html=True)
# ------------------------------ Application code starts here ------------------------------


@st.cache_resource
def load_counseling_dataset(split="train"):
    """Load the mental-health counseling conversations dataset from the Hub.

    Cached with ``st.cache_resource`` because Streamlit re-executes the whole
    script on every interaction; without the cache the dataset object would be
    rebuilt on each rerun.

    Args:
        split: Split expression forwarded to ``datasets.load_dataset``.
            Defaults to the full ``"train"`` split; pass a slice such as
            ``"train[:10%]"`` to load a smaller subset.

    Returns:
        The ``datasets.Dataset`` for the requested split.
    """
    return load_dataset("Amod/mental_health_counseling_conversations", split=split)
# Sample a random subset of the dataset and stream it example by example.
def process_dataset_in_batches(dataset, batch_size=500, *, seed=None):
    """Yield up to ``batch_size`` randomly chosen examples from ``dataset``.

    Despite the name, this draws a single random sample rather than successive
    batches: the dataset is shuffled, and the first ``batch_size`` rows are
    yielded one at a time.

    Args:
        dataset: A ``datasets.Dataset`` (anything supporting ``len``,
            ``shuffle(seed=...)`` and ``select(indices)``).
        batch_size: Maximum number of examples to yield. It is clamped to the
            dataset length, because selecting indices past the end would fail.
        seed: Optional shuffle seed for reproducible sampling. ``None`` keeps
            the unseeded shuffle.

    Yields:
        Individual rows of the shuffled dataset.
    """
    count = max(0, min(batch_size, len(dataset)))
    yield from dataset.shuffle(seed=seed).select(range(count))
# Fine-tune the model and save it
def fine_tune_model(
    output_dir="./fine_tuned_model",
    *,
    force=False,
    max_length=512,
    eval_fraction=0.1,
):
    """Fine-tune the base counseling model and return the directory it is saved in.

    If ``output_dir`` already holds a finished model, training is skipped. A
    finished model is one with a top-level ``config.json``, which
    ``Trainer.save_model`` writes. The skip matters because Streamlit re-runs
    this script on every interaction, and re-training a 7B model each time is
    not viable.

    Args:
        output_dir: Directory for checkpoints and the final model/tokenizer.
        force: Re-train even if ``output_dir`` already contains a saved model.
        max_length: Maximum number of tokens kept per training example.
        eval_fraction: Fraction of the loaded split held out for evaluation.
            The dataset is loaded as a single split, so an eval set has to be
            carved out of it.

    Returns:
        ``output_dir``.
    """
    # Checkpoints go into ``checkpoint-*`` subdirectories, so an interrupted run
    # never leaves a top-level config.json and is not mistaken for a finished one.
    if not force and os.path.isfile(os.path.join(output_dir, "config.json")):
        return output_dir

    import torch  # installed alongside transformers; only used to decide on fp16

    # Load base model and tokenizer
    model_name = "prabureddy/Mental-Health-FineTuned-Mistral-7B-Instruct-v0.2"
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    if tokenizer.pad_token is None:
        # The LM data collator pads batches; reuse EOS when no pad token is defined.
        tokenizer.pad_token = tokenizer.eos_token
    model = AutoModelForCausalLM.from_pretrained(model_name)
    # Enable gradient checkpointing for memory optimization
    model.gradient_checkpointing_enable()

    # Split the single loaded split into a DatasetDict with "train" and "test".
    splits = load_counseling_dataset().train_test_split(test_size=eval_fraction, seed=42)

    def preprocess_function(examples):
        # With batched=True each column is a list, so pair the rows up explicitly.
        # NOTE(review): the dataset's columns are the capitalized "Context" and
        # "Response"; re-check if the dataset schema changes.
        texts = [
            f"{context}\n{response}"
            for context, response in zip(examples["Context"], examples["Response"])
        ]
        return tokenizer(texts, truncation=True, max_length=max_length)

    # Drop the raw text columns so only tokenizer outputs reach the collator.
    tokenized = splits.map(
        preprocess_function,
        batched=True,
        remove_columns=splits["train"].column_names,
    )
    data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)

    # Training arguments
    training_args = TrainingArguments(
        output_dir=output_dir,
        evaluation_strategy="steps",
        learning_rate=2e-5,
        per_device_train_batch_size=5,
        per_device_eval_batch_size=5,
        num_train_epochs=3,
        weight_decay=0.01,
        # fp16 mixed precision needs a CUDA device; TrainingArguments rejects it on CPU.
        fp16=torch.cuda.is_available(),
        save_total_limit=2,
        save_steps=250,
        logging_steps=50,
    )

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=tokenized["train"],
        eval_dataset=tokenized["test"],
        tokenizer=tokenizer,
        data_collator=data_collator,
    )
    trainer.train()

    # Save the fine-tuned model and its tokenizer side by side for `pipeline(...)`.
    trainer.save_model(output_dir)
    tokenizer.save_pretrained(output_dir)
    return output_dir
# Load or fine-tune the model.
# Streamlit re-executes this script on every widget interaction. Caching the
# call means fine-tuning runs at most once per server process, not on every click.
model_dir = st.cache_resource(show_spinner="Preparing the fine-tuned model...")(fine_tune_model)()
# Load the fine-tuned model for inference
@st.cache_resource
def load_pipeline(model_dir):
    """Return a text-generation pipeline for the model saved in ``model_dir``.

    Cached with ``st.cache_resource`` (keyed on ``model_dir``) because Streamlit
    re-executes the script on every interaction. Without the cache, the large
    model would be reloaded from disk on each button click.

    Args:
        model_dir: Directory containing the saved model and tokenizer.

    Returns:
        A ``transformers`` text-generation pipeline.
    """
    return pipeline("text-generation", model=model_dir)
# Streamlit App
st.title("Mental Health Support Assistant")
st.markdown("""
Welcome to the **Mental Health Support Assistant**.
This tool helps detect potential mental health concerns based on user input and provides **uplifting and positive suggestions** to boost morale.
""")

# Build the generation pipeline after the header is rendered, so the page is
# not blank while the model loads.
pipe = load_pipeline(model_dir)

# User input for mental health concerns
user_input = st.text_area("Please share your concern:", placeholder="Type your question or concern here...")
if st.button("Get Supportive Response"):
    if user_input.strip():
        with st.spinner("Analyzing your input and generating a response..."):
            try:
                # max_new_tokens budgets only the reply (max_length also counted the
                # prompt, leaving long inputs with no room to answer).
                # return_full_text=False keeps the user's own text out of the suggestion.
                response = pipe(
                    user_input,
                    max_new_tokens=150,
                    num_return_sequences=1,
                    return_full_text=False,
                )[0]["generated_text"]
                st.subheader("Supportive Suggestion:")
                st.markdown(f"**{response}**")
            except Exception as e:
                # Top-level UI boundary: surface any generation failure to the user.
                st.error(f"An error occurred while generating the response: {e}")
    else:
        st.error("Please enter a concern to receive suggestions.")

# Sidebar for additional resources
st.sidebar.header("Additional Resources")
st.sidebar.markdown("""
- [Mental Health Foundation](https://www.mentalhealth.org)
- [Mind](https://www.mind.org.uk)
- [988 Suicide & Crisis Lifeline](https://988lifeline.org)
""")
st.sidebar.info("This application is not a replacement for professional counseling. If you're in crisis, seek professional help immediately.")