import pandas as pd
import re
import torch
import gradio as gr
import matplotlib.pyplot as plt
import seaborn as sns
from transformers import pipeline

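# Module-level cache: re-running with different recategorisation settings or
# thresholds reuses the predictions already computed for the same file instead
# of repeating the (comparatively slow) transformer inference.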
cached_df = None
cached_file_name = None

# Load sentiment pipeline
sentiment_pipeline = pipeline(
    "text-classification",
    model="pvaluedotone/bigbird-flight",
    tokenizer="pvaluedotone/bigbird-flight",
    device=0 if torch.cuda.is_available() else -1
)

def clean_text(text):
    if not isinstance(text, str):
        return ""
    text = re.sub(r"http\S+|@\w+|#\w+|[^\w\s]", "", text)
    return text.lower().strip()

def predict_sentiment(texts):
    # Truncation is left off here (BigBird handles long inputs); texts beyond the
    # model's maximum length would still need truncation to avoid errors.
    results = sentiment_pipeline(texts, truncation=False, batch_size=32)
    sentiments = []
    confidences = []
    for r in results:
        # Assumes labels of the form "LABEL_k" (or a plain number), where the
        # trailing number is the sentiment rating.
        label_num = int(r['label'].split('_')[-1])
        sentiments.append(label_num)
        confidences.append(r['score'])
    return sentiments, confidences

def recategorize(labels, mode, pos_threshold, neg_threshold):
    if mode == "Original (1–10)":
        return labels
    elif mode == "Binary (Positive vs Negative)":
        return ["Positive" if lbl >= pos_threshold else "Negative" for lbl in labels]
    elif mode == "Ternary (Pos/Neu/Neg)":
        return [
            "Positive" if lbl >= pos_threshold else
            "Negative" if lbl <= neg_threshold else
            "Neutral" for lbl in labels
        ]
    # Fall back to the raw labels if an unrecognised mode is passed in
    return labels

def analyze_sentiment(file, text_column, mode, pos_thresh, neg_thresh, auto_fix):
    global cached_df, cached_file_name

    try:
        df = pd.read_csv(file.name)
    except Exception as e:
        return f"Error reading CSV file: {e}", None, None, None, None, None

    if text_column not in df.columns:
        return "Selected column not found.", None, None, None, None, None

    # Reuse cached predictions if this exact file has already been analysed
    used_cache = (
        cached_df is not None and
        cached_file_name == file.name and
        "sentiment_1to10" in cached_df.columns and
        "confidence" in cached_df.columns
    )
    if used_cache:
        df = cached_df.copy()
    else:
        # Clean and predict
        df["clean_text"] = df[text_column].apply(clean_text)
        predictions, confidences = predict_sentiment(df["clean_text"].tolist())
        df["sentiment_1to10"] = predictions
        df["confidence"] = confidences
        # Cache result
        cached_df = df.copy()
        cached_file_name = file.name

    # 🛑 Check thresholds
    if mode == "Ternary (Pos/Neu/Neg)":
        if pos_thresh <= neg_thresh:
            if auto_fix:
                neg_thresh = pos_thresh - 1
                if neg_thresh < 1:
                    return "⚠️ Cannot auto-correct: thresholds out of valid range (1–10).", None, None, None, None, None
            else:
                return (
                    f"⚠️ Invalid thresholds: Positive min ({pos_thresh}) must be greater than Negative max ({neg_thresh}).",
                    None, None, None, None, None
                )

    # Apply recategorization
    df["sentiment_recategorised"] = recategorize(df["sentiment_1to10"], mode, pos_thresh, neg_thresh)

    # Save results
    output_file = "bigbird_sentiment_results.csv"
    df.to_csv(output_file, index=False)

    # Plot 1: original 10-class sentiment distribution (regenerated only when a fresh prediction ran)
    plot1_path = "original_dist.png"
    if not used_cache:
        plt.figure(figsize=(6, 4))
        sns.countplot(x=df["sentiment_1to10"], palette="Blues")
        plt.title("Original 10-Class Sentiment Distribution")
        plt.tight_layout()
        plt.savefig(plot1_path)
        plt.close()

    # Plot 2: Recategorized sentiment distribution
    plt.figure(figsize=(6, 4))
    sns.countplot(x=df["sentiment_recategorised"], palette="Set2")
    plt.title(f"Recategorized Sentiment Distribution ({mode})")
    plt.tight_layout()
    plot2_path = "recategorised_dist.png"
    plt.savefig(plot2_path)
    plt.close()

    # Plot 3: confidence score distribution (regenerated only when a fresh prediction ran)
    plot3_path = "confidence_dist.png"
    if not used_cache:
        plt.figure(figsize=(6, 4))
        sns.histplot(df["confidence"], bins=20, color="skyblue", kde=True)
        plt.title("Confidence Score Distribution")
        plt.xlabel("Confidence")
        plt.tight_layout()
        plt.savefig(plot3_path)
        plt.close()

    # Sample preview
    preview = df[[text_column, "sentiment_1to10", "confidence", "sentiment_recategorised"]].head(10)
    return f"✅ Sentiment analysis complete. Used cache: {used_cache}", preview, output_file, plot1_path, plot2_path, plot3_path


def get_text_columns(file):
    try:
        df = pd.read_csv(file.name, nrows=1)
        text_columns = df.select_dtypes(include='object').columns.tolist()
        if not text_columns:
            return gr.update(choices=[], value=None, label="⚠️ No text columns found!")
        return gr.update(choices=text_columns, value=text_columns[0])
    except Exception:
        return gr.update(choices=[], value=None, label="⚠️ Error reading file")

with gr.Blocks() as app:
    gr.Markdown("## ✈️ Sentiment analysis with `pvaluedotone/bigbird-flight`")
    gr.Markdown("**Citation:** Mat Roni, S. (2025). *Sentiment analysis with Big Bird Flight on Gradio* (version 1.0) [software]. https://huggingface.co/spaces/pvaluedotone/bigbird-flight")
    gr.Markdown("Upload a CSV, choose a text column to analyse, select output style (10-class, binary, or ternary), and analyse.")

    with gr.Row():
        file_input = gr.File(label="Upload CSV", file_types=[".csv"])
        column_dropdown = gr.Dropdown(label="Select Text Column", choices=[], interactive=True)

    file_input.change(get_text_columns, inputs=file_input, outputs=column_dropdown)

    output_mode = gr.Radio(
        label="Sentiment Output Type",
        choices=["Original (1–10)", "Binary (Positive vs Negative)", "Ternary (Pos/Neu/Neg)"],
        value="Original (1–10)",
        interactive=True
    )

    pos_thresh_slider = gr.Slider(3, 10, value=7, step=1, label="Positive min", visible=False)
    neg_thresh_slider = gr.Slider(1, 7, value=4, step=1, label="Negative max", visible=False)
    auto_fix_checkbox = gr.Checkbox(label="Auto-correct thresholds if overlapping?", value=True)
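    # "Positive min" / "Negative max" are the cut-points passed to recategorize():
    # ratings >= Positive min count as Positive, ratings <= Negative max as Negative
    # (ternary mode only), and anything in between as Neutral.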


    def toggle_thresholds(mode):
        show_pos = mode != "Original (1–10)"
        show_neg = mode == "Ternary (Pos/Neu/Neg)"
        return (
            gr.update(visible=show_pos),
            gr.update(visible=show_neg)
        )

    output_mode.change(toggle_thresholds, inputs=output_mode, outputs=[pos_thresh_slider, neg_thresh_slider])

    run_button = gr.Button("Process sentiment")

    status = gr.Textbox(label="Status")
    df_output = gr.Dataframe(label="Sample Output (Top 10)")
    file_result = gr.File(label="Download Full Results")
    plot_orig = gr.Image(label="Original Sentiment Distribution")
    plot_recat = gr.Image(label="Recategorised Sentiment Distribution")
    plot_conf = gr.Image(label="Confidence Score Distribution")

    run_button.click(
        analyze_sentiment,
        inputs=[file_input, column_dropdown, output_mode, pos_thresh_slider, neg_thresh_slider, auto_fix_checkbox],
        outputs=[status, df_output, file_result, plot_orig, plot_recat, plot_conf]
    )

app.launch(share=True, debug=True)
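# Note: when hosted on Hugging Face Spaces the app is served automatically and
# share=True is ignored; for a local run (hypothetical filename app.py), install
# the imports above (gradio, transformers, torch, pandas, matplotlib, seaborn)
# and start it with `python app.py`.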