File size: 8,732 Bytes
c9b15b4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
eba7586
 
c9b15b4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
import pandas as pd
import re
import torch
import gradio as gr
import matplotlib.pyplot as plt
import seaborn as sns
from transformers import pipeline

# Module-level cache of the most recent prediction run, read and written by
# analyze_sentiment via `global`. Both start empty.
cached_df, cached_file_name = None, None

# Load the Hugging Face text-classification pipeline for the BigBird model.
# The same repository provides both the model weights and the tokenizer.
_MODEL_ID = "pvaluedotone/bigbird-flight-2"
# Transformers device convention: 0 = first CUDA GPU, -1 = CPU.
_DEVICE = 0 if torch.cuda.is_available() else -1
sentiment_pipeline = pipeline(
    task="text-classification",
    model=_MODEL_ID,
    tokenizer=_MODEL_ID,
    device=_DEVICE,
)

# Contractions dictionary: lower-case contraction -> expansion.
contractions_dict = {
    "don't": "do not", "can't": "cannot", "i'm": "i am", "it's": "it is",
    "he's": "he is", "she's": "she is", "they're": "they are", "we're": "we are",
    "you're": "you are", "that's": "that is", "there's": "there is", "what's": "what is",
    "won't": "will not", "isn't": "is not", "aren't": "are not", "wasn't": "was not",
    "weren't": "were not", "didn't": "did not", "doesn't": "does not", "haven't": "have not",
    "hasn't": "has not", "hadn't": "had not", "wouldn't": "would not", "shouldn't": "should not",
    "couldn't": "could not", "mustn't": "must not", "let's": "let us"
}

# Apostrophes accepted inside a contraction: ASCII ' and the typographic
# right single quotation mark (U+2019) that phones/editors often insert.
_APOSTROPHE_CLASS = "['\u2019]"

# Case-insensitive so "Don't" / "I'm" are expanded even though clean_text
# never lower-cases its input. Each key is split on "'" and re-joined with the
# apostrophe class, so the regex is independent of how re.escape treats "'".
contractions_pattern = re.compile(
    r"\b("
    + "|".join(
        _APOSTROPHE_CLASS.join(re.escape(part) for part in key.split("'"))
        for key in contractions_dict
    )
    + r")\b",
    re.IGNORECASE,
)


def expand_contractions(text: str) -> str:
    """Expand English contractions in *text* (e.g. "Don't" -> "do not").

    Matching ignores case and accepts both ' and ’ as the apostrophe; the
    expansion is always the lower-case value from ``contractions_dict``.
    """
    def replace(match):
        # Normalise case and apostrophe so the match can be looked up by key.
        key = match.group(0).lower().replace("\u2019", "'")
        return contractions_dict.get(key, match.group(0))

    return contractions_pattern.sub(replace, text)


# Emoticon mapping (lower-case keys; matching below ignores case so ":D" works).
emoticon_dict = {
    ":)": "smile", ":-)": "smile", ":(": "sad", ":-(": "sad",
    ";)": "wink", ";-)": "wink", ":d": "laugh", ":-d": "laugh",
    ":p": "playful", ":-p": "playful", ":'(": "cry", ":/": "skeptical",
    ":'-)": "tears_of_joy"
}


def _emoticon_regex(emoticon: str) -> str:
    """Return a regex fragment matching *emoticon*.

    Emoticons that end in a letter (":d", ":-p", ...) must not be followed by a
    word character, so text such as "Note:done" is not read as ":d" + "one".
    """
    escaped = re.escape(emoticon)
    return escaped + r"(?!\w)" if emoticon[-1].isalpha() else escaped


# Longest alternatives first so an emoticon sharing a prefix is never shadowed.
_emoticon_pattern = re.compile(
    "|".join(_emoticon_regex(e) for e in sorted(emoticon_dict, key=len, reverse=True)),
    re.IGNORECASE,
)

# Optional dependency, resolved once at import time. A failed import is not
# cached by Python, so importing inside clean_text would repeat the failing
# search for every row when the package is absent.
try:
    import emoji as _emoji
except ImportError:
    _emoji = None


def clean_text(text: str) -> str:
    """Normalise a raw post before sentiment prediction.

    Steps, in order: drop tokens starting with "http" and @mentions, expand
    contractions, replace ASCII emoticons with words, convert emoji to their
    ":name:" text (only when the optional ``emoji`` package is installed),
    strip the leading '#' from hashtags, and collapse whitespace.

    Non-string input (e.g. NaN from pandas) yields "".
    """
    if not isinstance(text, str):
        return ""
    text = re.sub(r"http\S+|@\w+", "", text)
    text = expand_contractions(text)
    # Emoticons are replaced BEFORE demojize: demojize emits names such as
    # ":disappointed_face:" / ":pensive_face:" whose ":d" / ":p" prefixes would
    # otherwise be rewritten as the "laugh" / "playful" emoticons.
    text = _emoticon_pattern.sub(
        lambda m: f" {emoticon_dict[m.group(0).lower()]} ", text
    )
    if _emoji is not None:
        text = _emoji.demojize(text)
    text = re.sub(r"#(\w+)", r"\1", text)
    text = re.sub(r"\s+", " ", text).strip()
    return text

def predict_sentiment(texts):
    """Classify *texts* with the module-level ``sentiment_pipeline``.

    Args:
        texts: List of strings to classify.

    Returns:
        ``(sentiments, confidences)``: two lists aligned with *texts*. Each
        sentiment is the integer after the last "_" in the pipeline label
        (e.g. "LABEL_7" -> 7); each confidence is the pipeline's score for
        that label. Empty input returns two empty lists without calling the
        model.
    """
    if not texts:
        return [], []
    # truncation=True: inputs longer than the model's maximum sequence length
    # are cut to fit instead of making the whole batch fail.
    results = sentiment_pipeline(texts, truncation=True, batch_size=32)
    sentiments = [int(r["label"].rsplit("_", 1)[-1]) for r in results]
    confidences = [r["score"] for r in results]
    return sentiments, confidences

def recategorize(labels, mode, pos_threshold, neg_threshold):
    """Map numeric sentiment scores onto the output scheme chosen in the UI.

    Args:
        labels: Iterable of numeric sentiment scores.
        mode: "Original (1–10)" returns *labels* unchanged;
            "Binary (Positive vs Negative)" uses only *pos_threshold*;
            "Ternary (Pos/Neu/Neg)" uses both thresholds.
        pos_threshold: Scores >= this are "Positive".
        neg_threshold: In ternary mode, scores <= this (and below
            *pos_threshold*) are "Negative"; anything in between is "Neutral".

    Returns:
        *labels* itself for the original mode, otherwise a list of strings.

    Raises:
        ValueError: If *mode* is not one of the three recognised strings
            (previously this silently returned None).
    """
    if mode == "Original (1–10)":
        return labels
    if mode == "Binary (Positive vs Negative)":
        return ["Positive" if lbl >= pos_threshold else "Negative" for lbl in labels]
    if mode == "Ternary (Pos/Neu/Neg)":
        return [
            "Positive" if lbl >= pos_threshold
            else "Negative" if lbl <= neg_threshold
            else "Neutral"
            for lbl in labels
        ]
    raise ValueError(f"Unknown sentiment output mode: {mode!r}")

# Describes what cached_df was computed from:
# (uploaded file path, text column, whether cleaning was applied).
_cached_key = None


def _resolve_thresholds(mode, pos_thresh, neg_thresh, auto_fix):
    """Validate thresholds; return ``(pos_thresh, neg_thresh, error_or_None)``.

    Only the ternary mode requires ``pos_thresh > neg_thresh``. With
    *auto_fix*, an overlapping ``neg_thresh`` is lowered to ``pos_thresh - 1``
    as long as that stays >= 1.
    """
    if mode != "Ternary (Pos/Neu/Neg)" or pos_thresh > neg_thresh:
        return pos_thresh, neg_thresh, None
    if not auto_fix:
        return pos_thresh, neg_thresh, (
            f"⚠️ Invalid thresholds: Positive min ({pos_thresh}) must be greater than Negative max ({neg_thresh})."
        )
    neg_thresh = pos_thresh - 1
    if neg_thresh < 1:
        return pos_thresh, neg_thresh, "⚠️ Unable to auto-correct: thresholds out of valid range (1–10)."
    return pos_thresh, neg_thresh, None


def _save_countplot(values, palette, title, path):
    """Save a seaborn count plot of *values* to *path* and return *path*."""
    fig = plt.figure(figsize=(6, 4))
    try:
        sns.countplot(x=values, palette=palette)
        plt.title(title)
        plt.tight_layout()
        plt.savefig(path)
    finally:
        plt.close(fig)  # never leak a figure, even if plotting raises
    return path


def _save_confidence_plot(confidence, path):
    """Save a 20-bin histogram (with KDE) of *confidence* to *path*; return *path*."""
    fig = plt.figure(figsize=(6, 4))
    try:
        sns.histplot(confidence, bins=20, color="skyblue", kde=True)
        plt.title("Confidence Score Distribution")
        plt.xlabel("Confidence")
        plt.tight_layout()
        plt.savefig(path)
    finally:
        plt.close(fig)
    return path


def analyze_sentiment(file, text_column, mode, pos_thresh, neg_thresh, auto_fix, apply_cleaning):
    """Run (or reuse) sentiment predictions for a CSV column and build all outputs.

    Args:
        file: Upload from ``gr.File``; its ``.name`` (or the value itself when
            a plain path string is passed) is read with ``pd.read_csv``.
        text_column: Column holding the text to classify.
        mode: One of the output-type strings understood by ``recategorize``.
        pos_thresh: "Positive min" threshold.
        neg_thresh: "Negative max" threshold (ternary mode only).
        auto_fix: In ternary mode, lower ``neg_thresh`` to ``pos_thresh - 1``
            instead of rejecting overlapping thresholds.
        apply_cleaning: Pass each text through ``clean_text`` before predicting.

    Returns:
        ``(status, preview, results_csv, original_plot, recategorised_plot,
        confidence_plot)``. On error only *status* is set; the rest are None.

    Predictions are cached in module globals keyed on (file path, text column,
    cleaning flag), so changing only the mode or thresholds skips the model.
    """
    global cached_df, cached_file_name, _cached_key

    empty = (None,) * 5
    if file is None:
        return ("⚠️ Please upload a CSV file first.", *empty)

    # Validate before running the (expensive) model.
    pos_thresh, neg_thresh, threshold_error = _resolve_thresholds(
        mode, pos_thresh, neg_thresh, auto_fix
    )
    if threshold_error:
        return (threshold_error, *empty)

    file_path = getattr(file, "name", file)
    cache_key = (file_path, text_column, bool(apply_cleaning))
    used_cache = cached_df is not None and _cached_key == cache_key

    if used_cache:
        df = cached_df.copy()
    else:
        try:
            df = pd.read_csv(file_path)
        except Exception as e:
            return (f"Error reading CSV file: {e}", *empty)

        if text_column not in df.columns:
            return ("Selected column not found.", *empty)

        if apply_cleaning:
            df["processed_text"] = df[text_column].apply(clean_text)
        else:
            # fillna: empty cells would otherwise be classified as the word "nan".
            df["processed_text"] = df[text_column].fillna("").astype(str)

        predictions, confidences = predict_sentiment(df["processed_text"].tolist())
        df["sentiment_1to10"] = predictions
        df["confidence"] = confidences

        cached_df = df.copy()
        cached_file_name = file_path
        _cached_key = cache_key

    df["sentiment_recategorised"] = recategorize(df["sentiment_1to10"], mode, pos_thresh, neg_thresh)

    output_file = "bigbird_sentiment_results.csv"
    df.to_csv(output_file, index=False)

    # All three charts are redrawn on every run so they always describe the
    # current file/column rather than whichever one was processed first.
    plot1_path = _save_countplot(
        df["sentiment_1to10"], "Blues",
        "Original 10-Class Sentiment Distribution", "original_dist.png",
    )
    plot2_path = _save_countplot(
        df["sentiment_recategorised"], "Set2",
        f"Recategorised Sentiment Distribution ({mode})", "recategorised_dist.png",
    )
    plot3_path = _save_confidence_plot(df["confidence"], "confidence_dist.png")

    preview = df[[text_column, "processed_text", "sentiment_1to10", "confidence", "sentiment_recategorised"]].head(10)
    return (
        f"✅ Sentiment analysis complete. Used cache: {used_cache}",
        preview, output_file, plot1_path, plot2_path, plot3_path,
    )

def get_text_columns(file):
    """Build a ``gr.update`` for the column dropdown from the uploaded CSV.

    Text columns are those with object or pandas ``string`` dtype. Dtypes are
    inferred from the first 100 rows rather than a single row, so a text
    column whose first cell is empty (read as float NaN) or numeric-looking is
    still offered. Every return sets the label explicitly, so a warning
    shown for a bad file is cleared by the next good upload.
    """
    default_label = "Select Text Column"
    if file is None:  # the upload was cleared
        return gr.update(choices=[], value=None, label=default_label)
    try:
        df = pd.read_csv(getattr(file, "name", file), nrows=100)
    except Exception:
        return gr.update(choices=[], value=None, label="⚠️ Error reading file")
    text_columns = df.select_dtypes(include=["object", "string"]).columns.tolist()
    if not text_columns:
        return gr.update(choices=[], value=None, label="⚠️ No text columns found!")
    return gr.update(choices=text_columns, value=text_columns[0], label=default_label)

# Gradio UI. Components are laid out in the order they are created inside this
# Blocks context, so statement order below is also the page layout.
with gr.Blocks() as app:
    gr.Markdown("## ✈️ Sentiment analysis with Big Bird Flight 2")
    gr.Markdown("**Citation:** Mat Roni, S. (2025). *Sentiment analysis with Big Bird Flight 2 on Gradio* (version 1.0) [software]. https://huggingface.co/spaces/pvaluedotone/bigbird-flight-2 DOI: https://doi.org/10.57967/hf/5780")

    with gr.Row():
        file_input = gr.File(label="Upload CSV", file_types=[".csv"])
        column_dropdown = gr.Dropdown(label="Select Text Column", choices=[], interactive=True)

    # Refresh the column choices whenever the uploaded file changes.
    file_input.change(get_text_columns, inputs=file_input, outputs=column_dropdown)

    output_mode = gr.Radio(
        label="Sentiment Output Type",
        choices=["Original (1–10)", "Binary (Positive vs Negative)", "Ternary (Pos/Neu/Neg)"],
        value="Original (1–10)",
        interactive=True
    )

    # Threshold sliders start hidden; toggle_thresholds shows them per mode.
    pos_thresh_slider = gr.Slider(3, 10, value=7, step=1, label="Positive min", visible=False)
    neg_thresh_slider = gr.Slider(1, 7, value=4, step=1, label="Negative max", visible=False)
    # Only consulted by analyze_sentiment when the ternary mode is selected.
    auto_fix_checkbox = gr.Checkbox(label="Auto-correct thresholds if overlapping?", value=True)
    cleaning_checkbox = gr.Checkbox(label="Apply Text Cleaning", value=True)  # toggles clean_text() preprocessing

    def toggle_thresholds(mode):
        """Show "Positive min" for binary/ternary modes and "Negative max" only for ternary."""
        show_pos = mode != "Original (1–10)"
        show_neg = mode == "Ternary (Pos/Neu/Neg)"
        return (
            gr.update(visible=show_pos),
            gr.update(visible=show_neg)
        )

    output_mode.change(toggle_thresholds, inputs=output_mode, outputs=[pos_thresh_slider, neg_thresh_slider])

    run_button = gr.Button("Process sentiment")

    status = gr.Textbox(label="Status")
    df_output = gr.Dataframe(label="Sample Output (Top 10)")
    file_result = gr.File(label="Download Full Results")
    plot_orig = gr.Image(label="Original Sentiment Distribution")
    plot_recat = gr.Image(label="Recategorised Sentiment Distribution")
    plot_conf = gr.Image(label="Confidence Score Distribution")

    # `inputs` order must match analyze_sentiment's positional parameters, and
    # `outputs` order must match the 6-tuple it returns.
    run_button.click(
        analyze_sentiment,
        inputs=[
            file_input, column_dropdown, output_mode,
            pos_thresh_slider, neg_thresh_slider, auto_fix_checkbox,
            cleaning_checkbox  # -> apply_cleaning
        ],
        outputs=[status, df_output, file_result, plot_orig, plot_recat, plot_conf]
    )

# Start the server only when this file is executed as a script, so importing
# the module (tests, tooling, Gradio reload mode) does not block on launch().
if __name__ == "__main__":
    app.launch(share=True, debug=True)