pvaluedotone commited on
Commit
17808a8
·
verified ·
1 Parent(s): 239acf9

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +152 -0
app.py ADDED
@@ -0,0 +1,152 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pandas as pd
2
+ import re
3
+ import torch
4
+ import gradio as gr
5
+ import matplotlib.pyplot as plt
6
+ import seaborn as sns
7
+ from transformers import pipeline
8
+
9
+ # Load sentiment pipeline
10
+ sentiment_pipeline = pipeline(
11
+ "text-classification",
12
+ model="pvaluedotone/bigbird-flight",
13
+ tokenizer="pvaluedotone/bigbird-flight",
14
+ device=0 if torch.cuda.is_available() else -1
15
+ )
16
+
17
+ def clean_text(text):
18
+ if not isinstance(text, str):
19
+ return ""
20
+ text = re.sub(r"http\S+|@\w+|#\w+|[^\w\s]", "", text)
21
+ return text.lower().strip()
22
+
23
+ def predict_sentiment(texts):
24
+ results = sentiment_pipeline(texts, truncation=True, batch_size=32)
25
+ sentiments = []
26
+ confidences = []
27
+ for r in results:
28
+ label_num = int(r['label'].split('_')[-1])
29
+ sentiments.append(label_num)
30
+ confidences.append(r['score'])
31
+ return sentiments, confidences
32
+
33
+ def recategorize(labels, mode, pos_threshold, neg_threshold):
34
+ if mode == "Original (1–10)":
35
+ return labels
36
+ elif mode == "Binary (Positive vs Negative)":
37
+ return ["Positive" if lbl >= pos_threshold else "Negative" for lbl in labels]
38
+ elif mode == "Ternary (Pos/Neu/Neg)":
39
+ return [
40
+ "Positive" if lbl >= pos_threshold else
41
+ "Negative" if lbl <= neg_threshold else
42
+ "Neutral" for lbl in labels
43
+ ]
44
+
45
+ def analyze_sentiment(file, text_column, mode, pos_thresh, neg_thresh):
46
+ try:
47
+ df = pd.read_csv(file.name)
48
+ except Exception as e:
49
+ return f"Error reading CSV file: {e}", None, None, None, None, None
50
+
51
+ if text_column not in df.columns:
52
+ return "Selected column not found.", None, None, None, None, None
53
+
54
+ df["clean_text"] = df[text_column].apply(clean_text)
55
+ predictions, confidences = predict_sentiment(df["clean_text"].tolist())
56
+ df["sentiment_1to10"] = predictions
57
+ df["confidence"] = confidences
58
+ df["sentiment_recategorised"] = recategorize(df["sentiment_1to10"], mode, pos_thresh, neg_thresh)
59
+
60
+ # Save results
61
+ output_file = "bigbird_sentiment_results.csv"
62
+ df.to_csv(output_file, index=False)
63
+
64
+ # Plot 1: Original 10-class sentiment distribution
65
+ plt.figure(figsize=(6, 4))
66
+ sns.countplot(x=df["sentiment_1to10"], palette="Blues")
67
+ plt.title("Original 10-Class Sentiment Distribution")
68
+ plt.tight_layout()
69
+ plot1_path = "original_dist.png"
70
+ plt.savefig(plot1_path)
71
+ plt.close()
72
+
73
+ # Plot 2: Recategorized sentiment distribution
74
+ plt.figure(figsize=(6, 4))
75
+ sns.countplot(x=df["sentiment_recategorised"], palette="Set2")
76
+ plt.title(f"Recategorized Sentiment Distribution ({mode})")
77
+ plt.tight_layout()
78
+ plot2_path = "recategorised_dist.png"
79
+ plt.savefig(plot2_path)
80
+ plt.close()
81
+
82
+ # Plot 3: Confidence score distribution
83
+ plt.figure(figsize=(6, 4))
84
+ sns.histplot(df["confidence"], bins=20, color="orange", kde=True)
85
+ plt.title("Confidence Score Distribution")
86
+ plt.xlabel("Confidence")
87
+ plt.tight_layout()
88
+ plot3_path = "confidence_dist.png"
89
+ plt.savefig(plot3_path)
90
+ plt.close()
91
+
92
+ # Sample preview
93
+ preview = df[[text_column, "sentiment_1to10", "confidence", "sentiment_recategorised"]].head(10)
94
+ return f"Sentiment analysis complete. Processed {len(df)} rows.", preview, output_file, plot1_path, plot2_path, plot3_path
95
+
96
+ def get_text_columns(file):
97
+ try:
98
+ df = pd.read_csv(file.name, nrows=1)
99
+ text_columns = df.select_dtypes(include='object').columns.tolist()
100
+ if not text_columns:
101
+ return gr.update(choices=[], value=None, label="⚠️ No text columns found!")
102
+ return gr.update(choices=text_columns, value=text_columns[0])
103
+ except Exception:
104
+ return gr.update(choices=[], value=None, label="⚠️ Error reading file")
105
+
106
+ with gr.Blocks() as app:
107
+ gr.Markdown("## ✈️ Sentiment analysis with `pvaluedotone/bigbird-flight`")
108
+ gr.Markdown("**Citation:** Mat Roni, S. (2025). *Sentiment analysis with Big Bird Flight on Gradio* (version 1.0) [software]. https://huggingface.co/spaces/pvaluedotone/bigbird-flight")
109
+ gr.Markdown("Upload a CSV, choose a text column, select output style (10-class, binary, or ternary), and analyze.")
110
+
111
+ with gr.Row():
112
+ file_input = gr.File(label="Upload CSV", file_types=[".csv"])
113
+ column_dropdown = gr.Dropdown(label="Select Text Column", choices=[], interactive=True)
114
+
115
+ file_input.change(get_text_columns, inputs=file_input, outputs=column_dropdown)
116
+
117
+ output_mode = gr.Radio(
118
+ label="Sentiment Output Type",
119
+ choices=["Original (1–10)", "Binary (Positive vs Negative)", "Ternary (Pos/Neu/Neg)"],
120
+ value="Original (1–10)",
121
+ interactive=True
122
+ )
123
+
124
+ pos_thresh_slider = gr.Slider(5, 10, value=7, step=1, label="Positive Threshold", visible=False)
125
+ neg_thresh_slider = gr.Slider(1, 5, value=4, step=1, label="Negative Threshold", visible=False)
126
+
127
+ def toggle_thresholds(mode):
128
+ show_pos = mode != "Original (1–10)"
129
+ show_neg = mode == "Ternary (Pos/Neu/Neg)"
130
+ return (
131
+ gr.update(visible=show_pos),
132
+ gr.update(visible=show_neg)
133
+ )
134
+
135
+ output_mode.change(toggle_thresholds, inputs=output_mode, outputs=[pos_thresh_slider, neg_thresh_slider])
136
+
137
+ run_button = gr.Button("Run Sentiment Analysis")
138
+
139
+ status = gr.Textbox(label="Status")
140
+ df_output = gr.Dataframe(label="Sample Output (Top 10)")
141
+ file_result = gr.File(label="Download Full Results")
142
+ plot_orig = gr.Image(label="Original Sentiment Distribution")
143
+ plot_recat = gr.Image(label="Recategorized Sentiment Distribution")
144
+ plot_conf = gr.Image(label="Confidence Score Distribution")
145
+
146
+ run_button.click(
147
+ analyze_sentiment,
148
+ inputs=[file_input, column_dropdown, output_mode, pos_thresh_slider, neg_thresh_slider],
149
+ outputs=[status, df_output, file_result, plot_orig, plot_recat, plot_conf]
150
+ )
151
+
152
+ app.launch(debug=True)