"""Gradio app for batch sentiment analysis of CSV text with BigBird Flight.

The user uploads a CSV, picks a text column, and chooses how the model's
1–10 sentiment scores are reported (raw, binary or ternary). Predictions are
cached per (file, column) so that changing only the output mode or the
thresholds does not re-run the model.
"""

import re

import gradio as gr
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
import torch
from transformers import pipeline

MODEL_NAME = "pvaluedotone/bigbird-flight"

# Output-mode labels. They are shown in the radio button and also compared
# against in the processing code, so both sides must use these constants.
MODE_ORIGINAL = "Original (1–10)"
MODE_BINARY = "Binary (Positive vs Negative)"
MODE_TERNARY = "Ternary (Pos/Neu/Neg)"

COLUMN_LABEL = "Select Text Column"
# Rows read when sniffing column dtypes. With a single row, a text column
# whose first cell is empty is parsed as float (NaN) and would be hidden.
COLUMN_SNIFF_ROWS = 100

OUTPUT_FILE = "bigbird_sentiment_results.csv"
ORIGINAL_PLOT_PATH = "original_dist.png"
RECAT_PLOT_PATH = "recategorised_dist.png"
CONFIDENCE_PLOT_PATH = "confidence_dist.png"

# Number of Gradio output components wired to analyze_sentiment().
_NUM_OUTPUTS = 6

# Most recent prediction run. These are module-level globals, so the cache is
# shared by every session served by this process.
cached_df = None
cached_file_name = None
cached_text_column = None

# Removes URLs, @mentions, #hashtags and every non-word, non-space character.
_CLEAN_PATTERN = re.compile(r"http\S+|@\w+|#\w+|[^\w\s]")

# Load sentiment pipeline (CUDA device 0 when available, otherwise CPU).
sentiment_pipeline = pipeline(
    "text-classification",
    model=MODEL_NAME,
    tokenizer=MODEL_NAME,
    device=0 if torch.cuda.is_available() else -1,
)


def _file_path(file):
    """Return the filesystem path of a Gradio upload.

    Older Gradio versions pass a tempfile-like object exposing ``.name``;
    newer ones pass a path string (possibly a str subclass with ``.name``).
    A plain string without ``.name`` is used as-is.
    """
    return getattr(file, "name", file)


def _error_outputs(message):
    """Return *message* for the status box and None for every other output."""
    return (message,) + (None,) * (_NUM_OUTPUTS - 1)


def clean_text(text):
    """Normalise *text* before classification.

    Strips URLs, @mentions, #hashtags and punctuation, then lower-cases and
    trims whitespace. Non-string values (e.g. NaN from empty CSV cells)
    become the empty string.
    """
    if not isinstance(text, str):
        return ""
    return _CLEAN_PATTERN.sub("", text).lower().strip()


def predict_sentiment(texts, batch_size=32):
    """Classify *texts* with the BigBird Flight pipeline.

    Args:
        texts: list of strings to classify.
        batch_size: number of texts per forward pass.

    Returns:
        ``(sentiments, confidences)``: for each text, the integer that follows
        the last ``_`` in the predicted label (assumed to be the 1–10 score;
        confirm against the model's ``id2label``) and that label's score.
    """
    if not texts:
        return [], []
    # truncation=True: inputs longer than the model's maximum sequence length
    # would otherwise raise and abort the entire batch.
    results = sentiment_pipeline(texts, truncation=True, batch_size=batch_size)
    sentiments = [int(r["label"].split("_")[-1]) for r in results]
    confidences = [r["score"] for r in results]
    return sentiments, confidences


def recategorize(labels, mode, pos_threshold, neg_threshold):
    """Map 1–10 sentiment scores onto the requested output scheme.

    Args:
        labels: iterable of numeric scores.
        mode: one of MODE_ORIGINAL, MODE_BINARY, MODE_TERNARY.
        pos_threshold: scores >= this are "Positive" (binary and ternary).
        neg_threshold: scores <= this are "Negative" (ternary only).

    Returns:
        *labels* unchanged for MODE_ORIGINAL, otherwise a list of
        "Positive" / "Negative" (/ "Neutral") strings.

    Raises:
        ValueError: if *mode* is not a supported output mode.
    """
    if mode == MODE_ORIGINAL:
        return labels
    if mode == MODE_BINARY:
        return ["Positive" if lbl >= pos_threshold else "Negative" for lbl in labels]
    if mode == MODE_TERNARY:
        return [
            "Positive" if lbl >= pos_threshold
            else "Negative" if lbl <= neg_threshold
            else "Neutral"
            for lbl in labels
        ]
    raise ValueError(f"Unknown output mode: {mode!r}")


def _resolve_thresholds(mode, pos_thresh, neg_thresh, auto_fix):
    """Validate ternary thresholds; return ``(neg_thresh, error_or_None)``.

    Only ternary mode needs ``pos_thresh > neg_thresh``. When they overlap and
    *auto_fix* is set, Negative max is lowered to ``pos_thresh - 1``.
    """
    if mode != MODE_TERNARY or pos_thresh > neg_thresh:
        return neg_thresh, None
    if not auto_fix:
        return neg_thresh, (
            f"⚠️ Invalid thresholds: Positive min ({pos_thresh}) must be greater than Negative max ({neg_thresh})."
        )
    neg_thresh = pos_thresh - 1
    if neg_thresh < 1:
        return neg_thresh, "⚠️ Cannot auto-correct: thresholds out of valid range (1–10)."
    return neg_thresh, None


def _predict_or_reuse(df, path, text_column):
    """Add prediction columns to *df*, reusing the cache for the same file AND column.

    Returns ``(df_with_predictions, used_cache)`` and refreshes the module
    cache after a new model run.
    """
    global cached_df, cached_file_name, cached_text_column

    if (
        cached_df is not None
        and cached_file_name == path
        and cached_text_column == text_column
    ):
        return cached_df.copy(), True

    df["clean_text"] = df[text_column].apply(clean_text)
    predictions, confidences = predict_sentiment(df["clean_text"].tolist())
    df["sentiment_1to10"] = predictions
    df["confidence"] = confidences

    cached_df = df.copy()
    cached_file_name = path
    cached_text_column = text_column
    return df, False


def _save_countplot(values, palette, title, path):
    """Save a bar chart of value counts of *values* to *path*; return *path*."""
    fig, ax = plt.subplots(figsize=(6, 4))
    try:
        sns.countplot(x=values, palette=palette, ax=ax)
        ax.set_title(title)
        fig.tight_layout()
        fig.savefig(path)
    finally:
        # Always release the figure, even if plotting fails.
        plt.close(fig)
    return path


def _save_confidence_plot(confidences, path):
    """Save a histogram + KDE of confidence scores to *path*; return *path*."""
    fig, ax = plt.subplots(figsize=(6, 4))
    try:
        sns.histplot(confidences, bins=20, color="skyblue", kde=True, ax=ax)
        ax.set_title("Confidence Score Distribution")
        ax.set_xlabel("Confidence")
        fig.tight_layout()
        fig.savefig(path)
    finally:
        plt.close(fig)
    return path


def analyze_sentiment(file, text_column, mode, pos_thresh, neg_thresh, auto_fix):
    """Run (or reuse) sentiment predictions on an uploaded CSV and build outputs.

    Args:
        file: Gradio upload for the CSV (see _file_path), or None.
        text_column: name of the column to classify.
        mode: output mode (MODE_ORIGINAL / MODE_BINARY / MODE_TERNARY).
        pos_thresh: Positive min threshold.
        neg_thresh: Negative max threshold (ternary only).
        auto_fix: whether overlapping ternary thresholds are auto-corrected.

    Returns:
        A 6-tuple ``(status, preview_df, results_csv_path, original_plot,
        recategorised_plot, confidence_plot)``. On error, only status is set
        and the rest are None.
    """
    if file is None:
        return _error_outputs("⚠️ Please upload a CSV file first.")
    path = _file_path(file)

    try:
        df = pd.read_csv(path)
    except Exception as e:
        return _error_outputs(f"Error reading CSV file: {e}")

    if text_column not in df.columns:
        return _error_outputs("Selected column not found.")
    if df.empty:
        return _error_outputs("⚠️ The CSV file contains no rows to analyse.")

    # Validate thresholds before the (expensive) model run.
    requested_neg = neg_thresh
    neg_thresh, threshold_error = _resolve_thresholds(mode, pos_thresh, neg_thresh, auto_fix)
    if threshold_error is not None:
        return _error_outputs(threshold_error)

    try:
        df, used_cache = _predict_or_reuse(df, path, text_column)
    except Exception as e:
        return _error_outputs(f"Error running sentiment model: {e}")

    try:
        df["sentiment_recategorised"] = recategorize(
            df["sentiment_1to10"], mode, pos_thresh, neg_thresh
        )
    except ValueError as e:
        return _error_outputs(f"⚠️ {e}")

    df.to_csv(OUTPUT_FILE, index=False)

    # Every plot is regenerated on each run so it always matches the current data.
    plot1_path = _save_countplot(
        df["sentiment_1to10"], "Blues",
        "Original 10-Class Sentiment Distribution", ORIGINAL_PLOT_PATH,
    )
    plot2_path = _save_countplot(
        df["sentiment_recategorised"], "Set2",
        f"Recategorized Sentiment Distribution ({mode})", RECAT_PLOT_PATH,
    )
    plot3_path = _save_confidence_plot(df["confidence"], CONFIDENCE_PLOT_PATH)

    preview = df[[text_column, "sentiment_1to10", "confidence", "sentiment_recategorised"]].head(10)

    message = f"✅ Sentiment analysis complete. Used cache: {used_cache}"
    if neg_thresh != requested_neg:
        message += f" Negative max auto-corrected to {neg_thresh}."
    return message, preview, OUTPUT_FILE, plot1_path, plot2_path, plot3_path


def get_text_columns(file):
    """Return a Dropdown update listing the uploaded CSV's text-like columns.

    Samples the first COLUMN_SNIFF_ROWS rows. Both ``object`` and pandas
    ``string`` dtypes count as text. The dropdown label is reset on success
    so that an earlier warning does not stick.
    """
    if file is None:
        return gr.update(choices=[], value=None, label=COLUMN_LABEL)
    try:
        df = pd.read_csv(_file_path(file), nrows=COLUMN_SNIFF_ROWS)
    except Exception:
        return gr.update(choices=[], value=None, label="⚠️ Error reading file")
    text_columns = df.select_dtypes(include=["object", "string"]).columns.tolist()
    if not text_columns:
        return gr.update(choices=[], value=None, label="⚠️ No text columns found!")
    return gr.update(choices=text_columns, value=text_columns[0], label=COLUMN_LABEL)


with gr.Blocks() as app:
    gr.Markdown("## ✈️ Sentiment analysis with `pvaluedotone/bigbird-flight`")
    gr.Markdown("**Citation:** Mat Roni, S. (2025). *Sentiment analysis with Big Bird Flight on Gradio* (version 1.0) [software]. https://huggingface.co/spaces/pvaluedotone/bigbird-flight")
    gr.Markdown("Upload a CSV, choose a text column to analyse, select output style (10-class, binary, or ternary), and analyse.")

    with gr.Row():
        file_input = gr.File(label="Upload CSV", file_types=[".csv"])
        column_dropdown = gr.Dropdown(label=COLUMN_LABEL, choices=[], interactive=True)

    file_input.change(get_text_columns, inputs=file_input, outputs=column_dropdown)

    output_mode = gr.Radio(
        label="Sentiment Output Type",
        choices=[MODE_ORIGINAL, MODE_BINARY, MODE_TERNARY],
        value=MODE_ORIGINAL,
        interactive=True,
    )
    pos_thresh_slider = gr.Slider(3, 10, value=7, step=1, label="Positive min", visible=False)
    neg_thresh_slider = gr.Slider(1, 7, value=4, step=1, label="Negative max", visible=False)
    auto_fix_checkbox = gr.Checkbox(label="Auto-correct thresholds if overlapping?", value=True)

    def toggle_thresholds(mode):
        """Show 'Positive min' for binary/ternary and 'Negative max' for ternary only."""
        return (
            gr.update(visible=mode != MODE_ORIGINAL),
            gr.update(visible=mode == MODE_TERNARY),
        )

    output_mode.change(toggle_thresholds, inputs=output_mode, outputs=[pos_thresh_slider, neg_thresh_slider])

    run_button = gr.Button("Process sentiment")
    status = gr.Textbox(label="Status")
    df_output = gr.Dataframe(label="Sample Output (Top 10)")
    file_result = gr.File(label="Download Full Results")
    plot_orig = gr.Image(label="Original Sentiment Distribution")
    plot_recat = gr.Image(label="Recategorised Sentiment Distribution")
    plot_conf = gr.Image(label="Confidence Score Distribution")

    run_button.click(
        analyze_sentiment,
        inputs=[file_input, column_dropdown, output_mode, pos_thresh_slider, neg_thresh_slider, auto_fix_checkbox],
        outputs=[status, df_output, file_result, plot_orig, plot_recat, plot_conf],
    )

app.launch(share=True, debug=True)