# Spaces: Runtime error  (page-status residue captured with this listing; not part of the program)
import io
import os
import tempfile

import gradio as gr
import numpy as np
import pandas as pd
from crewai import Agent, Task, Crew
from keybert import KeyBERT
from langchain_community.llms import HuggingFaceHub
from langchain_huggingface import HuggingFaceEndpoint
from sentence_transformers import SentenceTransformer
from sklearn.cluster import KMeans
# === CONFIGURATION ===
# Access token for the Hugging Face Hub, read from the HF_API_TOKEN
# environment variable (None when the variable is not set).
HUGGINGFACEHUB_API_TOKEN = os.getenv("HF_API_TOKEN")  # Set this in environment
MODEL_NAME = "mistralai/Mistral-7B-Instruct-v0.2"  # Publicly available!

# Setup LLM via HuggingFace Hub.
# `llm` is built exactly once. A second, differently-configured client bound
# to the same name would only be constructed and then thrown away, so the
# settings below are the ones every agent and callback actually uses.
# NOTE(review): langchain_community marks HuggingFaceHub as deprecated in
# favour of langchain_huggingface.HuggingFaceEndpoint (whose keywords are
# `max_new_tokens` / `huggingfacehub_api_token`) - consider migrating.
llm = HuggingFaceHub(
    repo_id=MODEL_NAME,
    huggingfacehub_api_token=HUGGINGFACEHUB_API_TOKEN,
    model_kwargs={"temperature": 0.4, "max_new_tokens": 64},
)
# Sentence-embedding model; `cluster_texts` encodes the input texts with it.
embedding_model = SentenceTransformer("all-MiniLM-L6-v2")

# KeyBERT extractor (not referenced by the functions visible in this file).
keyword_extractor = KeyBERT(model="distilbert-base-nli-mean-tokens")

# Module-level state shared by the Gradio callbacks between button clicks.
session = dict(
    original_df=None,         # upload as read from disk
    current_df=None,          # working copy with cluster / label / keywords columns
    context="",               # instruction text entered by the user
    topic_labels={},
    keywords={},
    clusters_verified=False,  # outcome of the validation pass
)
# === AGENTS ===
# Every agent talks to the same LLM and runs quietly.
_COMMON_AGENT_OPTIONS = dict(llm=llm, verbose=False)

# Pulls keywords out of a cluster's sample texts.
keyword_agent = Agent(
    role="Keyword Analyst",
    goal="Extract top 5 keywords from a group of similar texts",
    backstory=(
        "You are a skilled keyword analyst who identifies patterns in text data.\n"
        "You focus on extracting concise, meaningful keywords that represent the core themes."
    ),
    **_COMMON_AGENT_OPTIONS,
)

# Produces a short topic label for a cluster.
labeling_agent = Agent(
    role="Topic Labeler",
    goal="Generate a short label for a group of similar texts based on context",
    backstory=(
        "You are a professional theme summarizer. Given example texts and a user context,\n"
        "you generate clear and actionable topic labels."
    ),
    **_COMMON_AGENT_OPTIONS,
)

# Judges whether a label and its keywords belong together.
validation_agent = Agent(
    role="QA Analyst",
    goal="Evaluate whether the clustered topics and keywords form coherent themes",
    backstory=(
        "You are a quality assurance expert evaluating if generated topics make sense.\n"
        "You return 'Approved' or 'Needs Refinement' based on coherence."
    ),
    **_COMMON_AGENT_OPTIONS,
)

# Prepares the final export.
finalizer_agent = Agent(
    role="Data Engineer",
    goal="Prepare final labeled dataset for download",
    backstory="You finalize the structured output file after approval and ensure it's ready for export.",
    **_COMMON_AGENT_OPTIONS,
)
# === TASKS ===
def create_tasks(text_samples, context_input):
    """Build the four CrewAI tasks for one cluster.

    Args:
        text_samples: Sample texts of the cluster, already joined into one string.
        context_input: User instruction that steers the topic label.

    Returns:
        Tuple ``(extract_keywords, label_topic, validate_cluster, finalize_data)``
        of ``Task`` objects, each bound to its matching agent.
    """
    keyword_description = (
        "Extract 5 most relevant keywords from the following sample texts:"
        f"\n\n{text_samples}"
    )
    label_description = (
        f"Based on the following examples and instruction: '{context_input}', "
        f"generate a concise topic label.\n\n{text_samples}"
    )
    # `{label}` and `{keywords}` are literal placeholders here, not values
    # filled in by this function (presumably meant for later interpolation).
    validation_description = (
        "Evaluate whether the topic label and keywords make sense together."
        "\n\nLABEL: {label}\nKEYWORDS: {keywords}"
    )

    return (
        Task(
            description=keyword_description,
            agent=keyword_agent,
            expected_output="Comma-separated list of keywords",
        ),
        Task(
            description=label_description,
            agent=labeling_agent,
            expected_output="A single line topic label",
        ),
        Task(
            description=validation_description,
            agent=validation_agent,
            expected_output="'Approved' or 'Needs Refinement'",
        ),
        Task(
            description="Take the approved labeled DataFrame and format it for download.",
            agent=finalizer_agent,
            expected_output="Final CSV content as string",
        ),
    )
# === CLUSTERING ===
def cluster_texts(texts, n_clusters=10):
    """Cluster texts by running KMeans over their sentence embeddings.

    Args:
        texts: Sequence of strings to cluster.
        n_clusters: Requested number of clusters. It is coerced to ``int`` and
            capped at ``len(texts)``, because KMeans cannot fit more clusters
            than there are samples.

    Returns:
        Array of integer cluster ids, one per input text and in input order.
        An empty integer array is returned when ``texts`` is empty.
    """
    n_samples = len(texts)
    if n_samples == 0:
        # Nothing to embed or fit; avoid calling the model on empty input.
        return np.array([], dtype=int)

    n_clusters = max(1, min(int(n_clusters), n_samples))
    embeddings = embedding_model.encode(texts, show_progress_bar=False)
    # Fixed seed keeps cluster assignments reproducible between runs.
    kmeans = KMeans(n_clusters=n_clusters, random_state=42)
    return kmeans.fit_predict(embeddings)
# === FULL PIPELINE FUNCTION ===
def run_initial_analysis(csv_file, context_input, n_clusters=10):
    """Cluster the uploaded texts, then label, keyword and validate each cluster.

    Args:
        csv_file: The upload from the Gradio file widget: either a path
            (``str`` / ``os.PathLike``) or an object exposing the path as
            ``.name``. ``None`` means nothing was uploaded.
        context_input: User instruction forwarded to the labeling task.
        n_clusters: Requested number of topics; coerced to ``int`` and capped
            at the number of rows.

    Returns:
        Tuple ``(status, download, preview)``:
        ``status`` is a human-readable message; ``download`` is the path of a
        CSV file holding the labeled data (``None`` on failure), which is what
        a ``gr.File`` output expects; ``preview`` is the first 10 rows as a
        markdown table (``""`` on failure).

    Side effects:
        On success updates ``session`` (``original_df``, ``context``,
        ``current_df``, ``topic_labels``, ``keywords``, ``clusters_verified``).
        ``session`` is left untouched when the input is rejected.
    """
    if csv_file is None:
        return "Please upload a CSV file first.", None, ""

    # The widget may deliver a plain path or a file wrapper, depending on the
    # Gradio version/configuration; accept both.
    if isinstance(csv_file, (str, os.PathLike)):
        csv_path = csv_file
    else:
        csv_path = csv_file.name

    try:
        df = pd.read_csv(csv_path)
    except Exception as e:  # UI boundary: report any read/parse failure as text
        return f"Error reading CSV: {str(e)}", None, ""

    if 'text' not in df.columns:
        return "CSV must contain a column named 'text'", None, ""
    if df.empty:
        return "CSV contains no rows to analyse.", None, ""

    # Only commit to shared state once the input is known to be usable.
    session['original_df'] = df.copy()
    session['context'] = context_input

    # Missing cells would otherwise reach the encoder / str.join as NaN floats.
    texts = df['text'].fillna("").astype(str).tolist()
    # The slider value may arrive as a float, and KMeans rejects more clusters
    # than samples.
    n_clusters = max(1, min(int(n_clusters), len(texts)))

    clusters = cluster_texts(texts, n_clusters)
    df['cluster'] = clusters

    topic_labels = {}
    keywords_map = {}
    for i in range(n_clusters):
        members = [text for text, cid in zip(texts, clusters) if cid == i]
        if not members:
            continue
        # Three examples per cluster keep the prompt short.
        samples = "\n".join(members[:3])

        # Create CrewAI tasks for this cluster (validation/finalize unused here)
        ext_task, lbl_task, *_ = create_tasks(samples, context_input)

        # Run keyword extraction
        keyword_result = Crew(agents=[keyword_agent], tasks=[ext_task]).kickoff()
        keywords_map[i] = keyword_result.raw.strip()

        # Run labeling
        label_result = Crew(agents=[labeling_agent], tasks=[lbl_task]).kickoff()
        topic_labels[i] = label_result.raw.strip()

    # Assign labels and keywords back to DataFrame
    df['label'] = df['cluster'].map(topic_labels)
    df['keywords'] = df['cluster'].map(keywords_map)
    session['current_df'] = df
    session['topic_labels'] = topic_labels
    session['keywords'] = keywords_map

    # Validate clusters: stop at the first one the QA agent rejects.
    verified = True
    for cid, label in topic_labels.items():
        val_task = Task(
            description=f"Evaluate whether the topic label and keywords make sense together.\n\nLABEL: {label}\nKEYWORDS: {keywords_map.get(cid, '')}",
            agent=validation_agent,
            expected_output="'Approved' or 'Needs Refinement'",
        )
        res = Crew(agents=[validation_agent], tasks=[val_task]).kickoff()
        if "Needs" in res.raw:
            verified = False
            break
    session["clusters_verified"] = verified

    # The download widget needs a file on disk, not the CSV text itself.
    with tempfile.NamedTemporaryFile(
        "w", suffix=".csv", delete=False, newline="", encoding="utf-8"
    ) as handle:
        df.to_csv(handle, index=False)
        download_path = handle.name

    return "Initial analysis complete!", download_path, df.head(10).to_markdown(index=False)
# === REFINEMENT FUNCTION ===
def refine_labels(feedback_input):
    """Ask the LLM to revise the previewed topic labels using user feedback.

    Only the first 10 rows are shown to the model, so only those rows can
    receive a revised label; the remaining rows keep their current label.

    Args:
        feedback_input: Free-text feedback entered by the user.

    Returns:
        Tuple ``(status, download, preview)``: a status message, the path of a
        CSV file with the current labeled data (``None`` when no analysis has
        been run yet), and the first 10 rows as a markdown table.

    Side effects:
        Updates the ``label`` column of ``session['current_df']`` in place.
    """
    if session['current_df'] is None:
        return "No data found. Please run initial analysis first.", None, ""

    df = session['current_df']
    shown_rows = df[['text', 'label']].head(10)
    current_sample = shown_rows.to_markdown(index=False)

    prompt = (
        "\n"
        "You are helping refine topic labels based on user feedback.\n"
        "Current Labels:\n"
        f"{current_sample}\n"
        "User Feedback:\n"
        f"{feedback_input}\n"
        "Task:\n"
        "Reassign labels accordingly. Keep output format consistent: one label per line.\n"
        "Instructions:\n"
        "Return only the revised labels, one per line.\n"
    )

    # Simulating refinement using the same LLM
    response = llm.invoke(prompt)
    revised = [line.strip() for line in response.strip().splitlines() if line.strip()]

    # The model may answer with fewer or more lines than rows it was shown;
    # assigning that list to the whole column would raise a length mismatch.
    n_updates = min(len(revised), len(shown_rows))
    if n_updates:
        df.iloc[:n_updates, df.columns.get_loc('label')] = revised[:n_updates]
        status = "Labels refined!"
    else:
        status = "The model returned no labels; nothing was changed."
    session['current_df'] = df

    # The download widget needs a file on disk, not the CSV text itself.
    with tempfile.NamedTemporaryFile(
        "w", suffix=".csv", delete=False, newline="", encoding="utf-8"
    ) as handle:
        df.to_csv(handle, index=False)
        download_path = handle.name

    return status, download_path, df.head(10).to_markdown(index=False)
# === GRADIO UI ===
# Builds the Blocks app: one row with two columns (analysis inputs on the
# left, refinement inputs on the right), followed by the shared outputs.
# NOTE(review): the nesting below is reconstructed - the listing this came
# from had lost its indentation. Confirm that status/preview/download belong
# below the row rather than inside the second column.
with gr.Blocks(title="🧠 CrewAI + Open LLM Topic Modeling") as demo:
    gr.Markdown("# 🎯 CrewAI-Powered Topic Modeling with Open LLMs")
    gr.Markdown("Upload verbatims, get topics via multi-agent system using LLaMA / Mistral / Zephyr.")
    with gr.Row():
        with gr.Column():
            # Inputs for the initial clustering/labeling run.
            upload = gr.File(label="Upload CSV ('text' column)", file_types=[".csv"])
            context = gr.Textbox(label="Context/Instruction", lines=5, value="Group these into common themes.")
            cluster_slider = gr.Slider(2, 20, value=10, step=1, label="Number of Topics")
            run_btn = gr.Button("Run Initial Analysis")
        with gr.Column():
            # Inputs for the follow-up refinement pass.
            feedback = gr.Textbox(label="Feedback / Instructions for Refinement", lines=5)
            refine_btn = gr.Button("Refine Labels")
    # Outputs shared by both callbacks.
    status = gr.Textbox(label="Status")
    preview = gr.Textbox(label="First 10 Rows (Editable View)", lines=10)
    download = gr.File(label="Download Final Labeled CSV")
    # Both callbacks return (status, download, preview) in this order.
    run_btn.click(fn=run_initial_analysis, inputs=[upload, context, cluster_slider], outputs=[status, download, preview])
    refine_btn.click(fn=refine_labels, inputs=[feedback], outputs=[status, download, preview])

# Start the web server only when the file is run as a script.
if __name__ == "__main__":
    demo.launch()