adnaan05 committed on
Commit
81dd6cb
·
verified ·
1 Parent(s): de2b5de

Update src/app.py

Browse files
Files changed (1) hide show
  1. src/app.py +180 -462
src/app.py CHANGED
@@ -4,520 +4,238 @@ import pandas as pd
4
  import numpy as np
5
  from pathlib import Path
6
  import sys
 
7
  import plotly.graph_objects as go
8
  from transformers import BertTokenizer
9
  import nltk
10
 
11
  # Download required NLTK data
12
- nltk_data = {
13
- 'tokenizers/punkt': 'punkt',
14
- 'corpora/stopwords': 'stopwords',
15
- 'tokenizers/punkt_tab': 'punkt_tab',
16
- 'corpora/wordnet': 'wordnet'
17
- }
18
- for resource, package in nltk_data.items():
19
- try:
20
- nltk.data.find(resource)
21
- except LookupError:
22
- nltk.download(package)
 
 
 
 
 
23
 
24
  # Add project root to Python path
25
  project_root = Path(__file__).parent.parent
26
  sys.path.append(str(project_root))
27
 
28
  from src.models.hybrid_model import HybridFakeNewsDetector
29
- from src.config.config import BERT_MODEL_NAME, LSTM_HIDDEN_SIZE, LSTM_NUM_LAYERS, DROPOUT_RATE, SAVED_MODELS_DIR, MAX_SEQUENCE_LENGTH
30
  from src.data.preprocessor import TextPreprocessor
31
 
32
- # Custom CSS with Poppins font
33
- st.markdown("""
34
- <style>
35
- @import url('https://fonts.googleapis.com/css2?family=Poppins:wght@200;300;400;500;600;700&display=swap');
36
-
37
- * {
38
- font-family: 'Poppins', sans-serif !important;
39
- box-sizing: border-box;
40
- }
41
-
42
- .stApp {
43
- background: #ffffff;
44
- min-height: 100vh;
45
- color: #1f2a44;
46
- }
47
-
48
- #MainMenu {visibility: hidden;}
49
- footer {visibility: hidden;}
50
- .stDeployButton {display: none;}
51
- header {visibility: hidden;}
52
- .stApp > header {visibility: hidden;}
53
-
54
- /* Main Container */
55
- .main-container {
56
- max-width: 1200px;
57
- margin: 0 auto;
58
- padding: 1rem 2rem;
59
- }
60
-
61
- /* Header Section */
62
- .header-section {
63
- text-align: center;
64
- margin-bottom: 2.5rem;
65
- padding: 1.5rem 0;
66
- }
67
-
68
- .header-title {
69
- font-size: 2.25rem;
70
- font-weight: 700;
71
- color: #1f2a44;
72
- margin: 0;
73
- }
74
-
75
- /* Hero Section */
76
- .hero {
77
- display: flex;
78
- align-items: center;
79
- gap: 2rem;
80
- margin-bottom: 2rem;
81
- padding: 0 1rem;
82
- }
83
-
84
- .hero-left {
85
- flex: 1;
86
- padding: 1.5rem;
87
- }
88
-
89
- .hero-right {
90
- flex: 1;
91
- display: flex;
92
- align-items: center;
93
- justify-content: center;
94
- }
95
-
96
- .hero-right img {
97
- max-width: 100%;
98
- height: auto;
99
- border-radius: 8px;
100
- object-fit: cover;
101
- }
102
-
103
- .hero-title {
104
- font-size: 2.5rem;
105
- font-weight: 700;
106
- color: #1f2a44;
107
- margin-bottom: 0.5rem;
108
- }
109
-
110
- .hero-text {
111
- font-size: 1rem;
112
- color: #6b7280;
113
- line-height: 1.6;
114
- max-width: 450px;
115
- }
116
-
117
- /* About Section */
118
- .about-section {
119
- margin-bottom: 2rem;
120
- text-align: center;
121
- padding: 0 1rem;
122
- }
123
-
124
- .about-title {
125
- font-size: 1.75rem;
126
- font-weight: 600;
127
- color: #1f2a44;
128
- margin-bottom: 0.5rem;
129
- }
130
-
131
- .about-text {
132
- font-size: 0.95rem;
133
- color: #6b7280;
134
- line-height: 1.6;
135
- max-width: 600px;
136
- margin: 0 auto;
137
- }
138
-
139
- /* Input Section */
140
- .input-container {
141
- max-width: 800px;
142
- margin: 0 auto;
143
- }
144
-
145
- .stTextArea > div > div > textarea {
146
- border-radius: 8px !important;
147
- border: 1px solid #d1d5db !important;
148
- padding: 1rem !important;
149
- font-size: 1rem !important;
150
- background: #ffffff !important;
151
- min-height: 150px !important;
152
- transition: all 0.2s ease !important;
153
- }
154
-
155
- .stTextArea > div > div > textarea:focus {
156
- border-color: #6366f1 !important;
157
- box-shadow: 0 0 0 2px rgba(99, 102, 241, 0.1) !important;
158
- outline: none !important;
159
- }
160
-
161
- .stTextArea > div > div > textarea::placeholder {
162
- color: #9ca3af !important;
163
- }
164
-
165
- /* Button Styling */
166
- .stButton > button {
167
- background: #6366f1 !important;
168
- color: white !important;
169
- border-radius: 8px !important;
170
- padding: 0.75rem 2rem !important;
171
- font-size: 1rem !important;
172
- font-weight: 600 !important;
173
- transition: all 0.2s ease !important;
174
- border: none !important;
175
- width: 100% !important;
176
- max-width: 300px;
177
- }
178
-
179
- .stButton > button:hover {
180
- background: #4f46e5 !important;
181
- transform: translateY(-1px) !important;
182
- }
183
-
184
- /* Results Section */
185
- .results-container {
186
- margin-top: 1rem;
187
- padding: 1rem;
188
- border-radius: 8px;
189
- max-width: 1200px;
190
- margin-left: auto;
191
- margin-right: auto;
192
- }
193
-
194
- .result-card {
195
- padding: 1rem;
196
- border-radius: 8px;
197
- border-left: 4px solid transparent;
198
- margin-bottom: 1rem;
199
- }
200
-
201
- .fake-news {
202
- background: #fef2f2;
203
- border-left-color: #ef4444;
204
- }
205
-
206
- .real-news {
207
- background: #ecfdf5;
208
- border-left-color: #10b981;
209
- }
210
-
211
- .prediction-badge {
212
- font-weight: 600;
213
- font-size: 1rem;
214
- margin-bottom: 0.5rem;
215
- display: flex;
216
- align-items: center;
217
- gap: 0.5rem;
218
- }
219
-
220
- .confidence-score {
221
- font-weight: 600;
222
- margin-left: auto;
223
- font-size: 1rem;
224
- }
225
-
226
- /* Chart Containers */
227
- .chart-container {
228
- padding: 1rem;
229
- border-radius: 8px;
230
- margin: 1rem 0;
231
- max-width: 1200px;
232
- margin-left: auto;
233
- margin-right: auto;
234
- }
235
-
236
- /* Footer */
237
- .footer {
238
- border-top: 1px solid #e5e7eb;
239
- padding: 1.5rem 0;
240
- text-align: center;
241
- max-width: 1200px;
242
- margin: 2rem auto 0;
243
- }
244
-
245
- /* Responsive Design */
246
- @media (max-width: 1024px) {
247
- .hero {
248
- flex-direction: column;
249
- text-align: center;
250
- }
251
- .hero-right img {
252
- max-width: 80%;
253
- }
254
- }
255
-
256
- @media (max-width: 768px) {
257
- .header-title {
258
- font-size: 1.75rem;
259
- }
260
- .hero-title {
261
- font-size: 2rem;
262
- }
263
- .hero-text {
264
- font-size: 0.9rem;
265
- }
266
- .about-title {
267
- font-size: 1.5rem;
268
- }
269
- .about-text {
270
- font-size: 0.9rem;
271
- }
272
- }
273
-
274
- @media (max-width: 480px) {
275
- .header-title {
276
- font-size: 1.5rem;
277
- }
278
- .hero-title {
279
- font-size: 1.75rem;
280
- }
281
- .hero-text {
282
- font-size: 0.85rem;
283
- }
284
- .about-title {
285
- font-size: 1.25rem;
286
- }
287
- .about-text {
288
- font-size: 0.85rem;
289
- }
290
- }
291
- </style>
292
- """, unsafe_allow_html=True)
293
 
294
  @st.cache_resource
295
- def load_model_and_tokenizer() -> tuple[HybridFakeNewsDetector, BertTokenizer] | tuple[None, None]:
296
  """Load the model and tokenizer (cached)."""
297
- try:
298
- model = HybridFakeNewsDetector(
299
- bert_model_name=BERT_MODEL_NAME,
300
- lstm_hidden_size=LSTM_HIDDEN_SIZE,
301
- lstm_num_layers=LSTM_NUM_LAYERS,
302
- dropout_rate=DROPOUT_RATE
303
- )
304
- model_path = SAVED_MODELS_DIR / "final_model.pt"
305
- if not model_path.exists():
306
- st.error("Model file not found. Please ensure 'final_model.pt' is in the models/saved directory.")
307
- return None, None
308
- state_dict = torch.load(model_path, map_location=torch.device('cpu'))
309
- model_state_dict = model.state_dict()
310
- filtered_state_dict = {k: v for k, v in state_dict.items() if k in model_state_dict}
311
- model.load_state_dict(filtered_state_dict, strict=False)
312
- model.eval()
313
- tokenizer = BertTokenizer.from_pretrained(BERT_MODEL_NAME)
314
- return model, tokenizer
315
- except Exception as e:
316
- st.error(f"Error loading model or tokenizer: {str(e)}")
317
- return None, None
 
 
318
 
319
  @st.cache_resource
320
- def get_preprocessor() -> TextPreprocessor | None:
321
  """Get the text preprocessor (cached)."""
322
- try:
323
- return TextPreprocessor()
324
- except Exception as e:
325
- st.error(f"Error initializing preprocessor: {str(e)}")
326
- return None
327
 
328
- def predict_news(text: str) -> dict | None:
329
  """Predict if the given news is fake or real."""
 
330
  model, tokenizer = load_model_and_tokenizer()
331
- if model is None or tokenizer is None:
332
- return None
333
  preprocessor = get_preprocessor()
334
- if preprocessor is None:
335
- return None
336
- try:
337
- processed_text = preprocessor.preprocess_text(text)
338
- encoding = tokenizer.encode_plus(
339
- processed_text,
340
- add_special_tokens=True,
341
- max_length=MAX_SEQUENCE_LENGTH,
342
- padding='max_length',
343
- truncation=True,
344
- return_attention_mask=True,
345
- return_tensors='pt'
 
 
 
 
 
 
 
 
346
  )
347
- with torch.no_grad():
348
- outputs = model(
349
- encoding['input_ids'],
350
- encoding['attention_mask']
351
- )
352
- probabilities = torch.softmax(outputs['logits'], dim=1)
353
- prediction = torch.argmax(outputs['logits'], dim=1)
354
- attention_weights = outputs.get('attention_weights', torch.zeros(1))
355
- attention_weights_np = attention_weights[0].cpu().numpy()
356
- return {
357
- 'prediction': prediction.item(),
358
- 'label': 'FAKE' if prediction.item() == 1 else 'REAL',
359
- 'confidence': torch.max(probabilities, dim=1)[0].item(),
360
- 'probabilities': {
361
- 'REAL': probabilities[0][0].item(),
362
- 'FAKE': probabilities[0][1].item()
363
- },
364
- 'attention_weights': attention_weights_np
365
- }
366
- except Exception as e:
367
- st.error(f"Prediction error: {str(e)}")
368
- return None
369
-
370
- def plot_confidence(probabilities: dict) -> go.Figure:
371
- """Plot prediction confidence with simplified styling."""
372
- if not probabilities or not isinstance(probabilities, dict):
373
- return go.Figure()
374
  fig = go.Figure(data=[
375
  go.Bar(
376
  x=list(probabilities.keys()),
377
  y=list(probabilities.values()),
378
- text=[f'{p:.1%}' for p in probabilities.values()],
379
  textposition='auto',
380
- marker=dict(
381
- color=['#10b981', '#ef4444'],
382
- line=dict(color='#ffffff', width=1),
383
- ),
384
  )
385
  ])
 
386
  fig.update_layout(
387
- title={'text': 'Prediction Confidence', 'x': 0.5, 'xanchor': 'center', 'font': {'size': 18}},
388
- xaxis=dict(title='Classification', titlefont={'size': 12}, tickfont={'size': 10}),
389
- yaxis=dict(title='Probability', range=[0, 1], tickformat='.0%', titlefont={'size': 12}, tickfont={'size': 10}),
390
- template='plotly_white',
391
- height=300,
392
- margin=dict(t=60, b=60)
393
  )
 
394
  return fig
395
 
396
- def plot_attention(text: str, attention_weights: np.ndarray) -> go.Figure:
397
- """Plot attention weights with simplified styling."""
398
- if not text or not attention_weights.size:
399
- return go.Figure()
400
- tokens = text.split()[:20]
401
- attention_weights = attention_weights[:len(tokens)]
402
  if isinstance(attention_weights, (list, np.ndarray)):
403
  attention_weights = np.array(attention_weights).flatten()
404
- normalized_weights = attention_weights / max(attention_weights) if max(attention_weights) > 0 else attention_weights
405
- colors = [f'rgba(99, 102, 241, {0.4 + 0.6 * float(w)})' for w in normalized_weights]
 
 
406
  fig = go.Figure(data=[
407
  go.Bar(
408
  x=tokens,
409
  y=attention_weights,
410
- text=[f'{float(w):.3f}' for w in attention_weights],
411
  textposition='auto',
412
- marker=dict(color=colors),
413
  )
414
  ])
 
415
  fig.update_layout(
416
- title={'text': 'Attention Weights', 'x': 0.5, 'xanchor': 'center', 'font': {'size': 18}},
417
- xaxis=dict(title='Words', tickangle=45, titlefont={'size': 12}, tickfont={'size': 10}),
418
- yaxis=dict(title='Attention Score', titlefont={'size': 12}, tickfont={'size': 10}),
419
- template='plotly_white',
420
- height=350,
421
- margin=dict(t=60, b=80)
422
  )
 
423
  return fig
424
 
425
  def main():
426
- # Main Container
427
- st.markdown('<div class="main-container">', unsafe_allow_html=True)
428
-
429
- # Header Section
430
- st.markdown("""
431
- <div class="header-section">
432
- <h1 class="header-title">🛡️ TruthCheck - Advanced Fake News Detector</h1>
433
- </div>
434
- """, unsafe_allow_html=True)
435
-
436
- # Hero Section
437
- st.markdown("""
438
- <div class="hero">
439
- <div class="hero-left">
440
- <h2 class="hero-title">Instant Fake News Detection</h2>
441
- <p class="hero-text">
442
- Verify news articles with our AI-powered tool, driven by advanced BERT and BiLSTM models for accurate authenticity analysis.
443
- </p>
444
- </div>
445
- <div class="hero-right">
446
- <img src="https://images.pexels.com/photos/267350/pexels-photo-267350.jpeg?auto=compress&cs=tinysrgb&w=500" alt="Fake News Illustration" onerror="this.src='https://via.placeholder.com/500x300.png?text=Fake+News+Illustration'">
447
- </div>
448
- </div>
449
- """, unsafe_allow_html=True)
450
-
451
- # About Section
452
- st.markdown("""
453
- <div class="about-section">
454
- <h2 class="about-title">About TruthCheck</h2>
455
- <p class="about-text">
456
- TruthCheck harnesses a hybrid BERT-BiLSTM model to detect fake news with high precision. Simply paste an article below to analyze its authenticity instantly.
457
- </p>
458
- </div>
459
- """, unsafe_allow_html=True)
460
-
461
- # Input Section
462
- st.markdown('<div class="input-container">', unsafe_allow_html=True)
463
  news_text = st.text_area(
464
- "Analyze a News Article",
465
- height=150,
466
- placeholder="Paste your news article here for instant AI analysis...",
467
- key="news_input"
468
  )
469
- st.markdown('</div>', unsafe_allow_html=True)
470
-
471
- # Analyze Button
472
- col1, col2, col3 = st.columns([1, 2, 1])
473
- with col2:
474
- analyze_button = st.button("🔍 Analyze Now", key="analyze_button")
475
-
476
- if analyze_button:
477
- if news_text and len(news_text.strip()) > 10:
478
- with st.spinner("Analyzing article..."):
479
  result = predict_news(news_text)
480
- if result:
481
- st.markdown('<div class="results-container">', unsafe_allow_html=True)
482
-
483
- # Prediction Result
484
- col1, col2 = st.columns([1, 1], gap="medium")
485
- with col1:
486
- if result['label'] == 'FAKE':
487
- st.markdown(f'''
488
- <div class="result-card fake-news">
489
- <div class="prediction-badge">🚨 Fake News Detected <span class="confidence-score">{result["confidence"]:.1%}</span></div>
490
- <p>Our AI has identified this content as likely misinformation based on linguistic patterns and context.</p>
491
- </div>
492
- ''', unsafe_allow_html=True)
493
- else:
494
- st.markdown(f'''
495
- <div class="result-card real-news">
496
- <div class="prediction-badge">✅ Authentic News <span class="confidence-score">{result["confidence"]:.1%}</span></div>
497
- <p>This content appears legitimate based on professional writing style and factual consistency.</p>
498
- </div>
499
- ''', unsafe_allow_html=True)
500
-
501
- with col2:
502
- st.markdown('<div class="chart-container">', unsafe_allow_html=True)
503
- st.plotly_chart(plot_confidence(result['probabilities']), use_container_width=True)
504
- st.markdown('</div>', unsafe_allow_html=True)
505
-
506
- # Attention Analysis
507
- st.markdown('<div class="chart-container">', unsafe_allow_html=True)
508
- st.plotly_chart(plot_attention(news_text, result['attention_weights']), use_container_width=True)
509
- st.markdown('</div></div>', unsafe_allow_html=True)
 
 
 
 
 
 
 
 
 
510
  else:
511
- st.error("Please enter a news article (at least 10 words) for analysis.")
512
-
513
- # Footer
514
- st.markdown("---")
515
- st.markdown(
516
- '<p style="text-align: center; font-weight: 600; font-size: 16px;">💻 Developed with ❤️ using Streamlit | © 2025</p>',
517
- unsafe_allow_html=True
518
- )
519
-
520
- st.markdown('</div>', unsafe_allow_html=True) # Close main-container
521
 
522
  if __name__ == "__main__":
523
- main()
 
4
  import numpy as np
5
  from pathlib import Path
6
  import sys
7
+ import plotly.express as px
8
  import plotly.graph_objects as go
9
  from transformers import BertTokenizer
10
  import nltk
11
 
# Ensure the required NLTK data is present: look each resource up and
# download its package only when the lookup fails.
for resource_path, package_name in (
    ('tokenizers/punkt', 'punkt'),
    ('corpora/stopwords', 'stopwords'),
    ('tokenizers/punkt_tab', 'punkt_tab'),
    ('corpora/wordnet', 'wordnet'),
):
    try:
        nltk.data.find(resource_path)
    except LookupError:
        nltk.download(package_name)
29
 
30
  # Add project root to Python path
31
  project_root = Path(__file__).parent.parent
32
  sys.path.append(str(project_root))
33
 
34
  from src.models.hybrid_model import HybridFakeNewsDetector
35
+ from src.config.config import *
36
  from src.data.preprocessor import TextPreprocessor
37
 
38
+ # Page config is set in main app.py
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
39
 
@st.cache_resource
def load_model_and_tokenizer():
    """Build the hybrid detector, load its trained weights, and create the tokenizer.

    The function is decorated with ``st.cache_resource`` so the result is
    reused instead of being rebuilt on each call.

    Returns:
        A ``(model, tokenizer)`` tuple: the ``HybridFakeNewsDetector`` switched
        to eval mode and the ``BertTokenizer`` loaded for ``BERT_MODEL_NAME``.

    Raises:
        FileNotFoundError: If ``final_model.pt`` does not exist in
            ``SAVED_MODELS_DIR``.
    """
    import warnings

    model = HybridFakeNewsDetector(
        bert_model_name=BERT_MODEL_NAME,
        lstm_hidden_size=LSTM_HIDDEN_SIZE,
        lstm_num_layers=LSTM_NUM_LAYERS,
        dropout_rate=DROPOUT_RATE
    )

    # Fail early with an actionable message instead of a bare loader error.
    weights_path = Path(SAVED_MODELS_DIR) / "final_model.pt"
    if not weights_path.exists():
        raise FileNotFoundError(
            f"Trained model weights not found at '{weights_path}'."
        )

    # NOTE(review): torch.load unpickles the file, so only load checkpoints
    # from a trusted source (consider weights_only=True where supported).
    checkpoint = torch.load(weights_path, map_location=torch.device('cpu'))

    # Keep only the entries the current architecture knows about.
    expected_state = model.state_dict()
    filtered_state_dict = {k: v for k, v in checkpoint.items() if k in expected_state}

    # strict=False would silently leave these parameters at their initial
    # values, so make the gap visible.
    missing_keys = [k for k in expected_state if k not in filtered_state_dict]
    if missing_keys:
        warnings.warn(
            f"{len(missing_keys)} model parameter(s) were not found in "
            f"'{weights_path}' and keep their initial values: {missing_keys}",
            RuntimeWarning,
            stacklevel=2,
        )

    model.load_state_dict(filtered_state_dict, strict=False)
    model.eval()

    tokenizer = BertTokenizer.from_pretrained(BERT_MODEL_NAME)

    return model, tokenizer
66
 
@st.cache_resource
def get_preprocessor():
    """Create the ``TextPreprocessor`` used for predictions.

    Decorated with ``st.cache_resource`` so the instance is reused.
    """
    preprocessor = TextPreprocessor()
    return preprocessor
 
 
 
 
71
 
def predict_news(text):
    """Classify a news article as fake or real.

    Args:
        text: The article text to analyze.

    Returns:
        A dict with:
            'prediction': predicted class index,
            'label': 'FAKE' when the index is 1, otherwise 'REAL',
            'confidence': the largest softmax probability,
            'probabilities': {'REAL': p(class 0), 'FAKE': p(class 1)},
            'attention_weights': numpy array taken from the first item of the
                model's 'attention_weights' output.
    """
    # Cached resources: model + tokenizer first, then the preprocessor.
    model, tokenizer = load_model_and_tokenizer()
    cleaned_text = get_preprocessor().preprocess_text(text)

    # Fixed-length encoding as PyTorch tensors.
    encoded = tokenizer.encode_plus(
        cleaned_text,
        add_special_tokens=True,
        max_length=MAX_SEQUENCE_LENGTH,
        padding='max_length',
        truncation=True,
        return_attention_mask=True,
        return_tensors='pt'
    )

    # Inference only, so gradient tracking is switched off.
    with torch.no_grad():
        outputs = model(encoded['input_ids'], encoded['attention_mask'])
        logits = outputs['logits']
        class_probs = torch.softmax(logits, dim=1)
        predicted_class = torch.argmax(logits, dim=1).item()
        attention = outputs['attention_weights']

    # First sequence of the batch, moved to numpy for plotting.
    attention_np = attention[0].cpu().numpy()

    if predicted_class == 1:
        label = 'FAKE'
    else:
        label = 'REAL'

    return {
        'prediction': predicted_class,
        'label': label,
        'confidence': class_probs.max(dim=1).values.item(),
        'probabilities': {
            'REAL': class_probs[0, 0].item(),
            'FAKE': class_probs[0, 1].item()
        },
        'attention_weights': attention_np
    }
115
+
def plot_confidence(probabilities):
    """Build a bar chart of the per-class probabilities.

    Args:
        probabilities: Mapping of class name to probability.

    Returns:
        A ``go.Figure`` with one bar per class, each labelled with its
        probability as a percentage, and the y axis fixed to [0, 1].
    """
    class_names = list(probabilities.keys())
    class_values = list(probabilities.values())
    bar_labels = [f'{value:.2%}' for value in class_values]

    bars = go.Bar(
        x=class_names,
        y=class_values,
        text=bar_labels,
        textposition='auto',
    )
    fig = go.Figure(data=[bars])

    layout = dict(
        title='Prediction Confidence',
        xaxis_title='Class',
        yaxis_title='Probability',
        yaxis_range=[0, 1],
    )
    fig.update_layout(**layout)

    return fig
135
 
def plot_attention(text, attention_weights):
    """Build a bar chart pairing whitespace-separated tokens with attention weights.

    Tokens and weights are matched by position. The weights are flattened to
    one dimension first, then both sequences are trimmed to the shorter of the
    two so the bar chart always receives x and y values of equal length.

    NOTE(review): tokens come from ``text.split()``; if the weights are indexed
    by tokenizer positions instead, the pairing is only approximate - confirm
    against the model.

    Args:
        text: The text whose whitespace-separated tokens label the bars.
        attention_weights: Array-like of attention weights (any shape).

    Returns:
        A ``go.Figure`` with one bar per token, labelled with the weight
        rounded to two decimals.
    """
    tokens = text.split()

    # Flatten before truncating so multi-dimensional inputs are cut per value,
    # not per row.
    weights = np.asarray(attention_weights).flatten()

    # Trim both sides: the original only shortened the weights, leaving extra
    # tokens without a matching y value when the text was longer.
    shown = min(len(tokens), len(weights))
    tokens = tokens[:shown]
    weights = weights[:shown]

    formatted_weights = [f'{float(w):.2f}' for w in weights]

    fig = go.Figure(data=[
        go.Bar(
            x=tokens,
            y=weights,
            text=formatted_weights,
            textposition='auto',
        )
    ])

    fig.update_layout(
        title='Attention Weights',
        xaxis_title='Tokens',
        yaxis_title='Attention Weight',
        xaxis_tickangle=45
    )

    return fig
165
 
def main():
    """Render the Streamlit page.

    Shows the title and description, an "About" sidebar, and a text area for
    the article. When "Analyze" is pressed with non-blank text, runs
    ``predict_news`` and displays the predicted label with its confidence, the
    class-probability chart, the attention chart, and a short explanation.
    Blank or whitespace-only input produces a warning instead of a prediction.
    """
    st.title("📰 Fake News Detection System")
    st.write("""
    This application uses a hybrid deep learning model (BERT + BiLSTM + Attention)
    to detect fake news articles. Enter a news article below to analyze it.
    """)

    # Sidebar
    st.sidebar.title("About")
    st.sidebar.info("""

    The model combines:
    - BERT for contextual embeddings
    - BiLSTM for sequence modeling
    - Attention mechanism for interpretability
    """)

    # Main content
    st.header("News Analysis")

    news_text = st.text_area(
        "Enter the news article to analyze:",
        height=200,
        placeholder="Paste your news article here..."
    )

    if st.button("Analyze"):
        # Whitespace-only input carries nothing to classify, so treat it the
        # same as an empty box rather than sending it to the model.
        if news_text and news_text.strip():
            with st.spinner("Analyzing the news article..."):
                result = predict_news(news_text)

            is_fake = result['label'] == 'FAKE'

            # Verdict on the left, class probabilities on the right.
            col1, col2 = st.columns(2)

            with col1:
                st.subheader("Prediction")
                if is_fake:
                    st.error(f"🔴 This news is likely FAKE (Confidence: {result['confidence']:.2%})")
                else:
                    st.success(f"🟢 This news is likely REAL (Confidence: {result['confidence']:.2%})")

            with col2:
                st.subheader("Confidence Scores")
                st.plotly_chart(plot_confidence(result['probabilities']), use_container_width=True)

            # Attention visualization
            st.subheader("Attention Analysis")
            st.write("""
            The attention weights show which parts of the text the model focused on
            while making its prediction. Higher weights indicate more important tokens.
            """)
            st.plotly_chart(plot_attention(news_text, result['attention_weights']), use_container_width=True)

            # Static, label-dependent explanation text
            st.subheader("Model Explanation")
            if is_fake:
                st.write("""
                The model identified this as fake news based on:
                - Linguistic patterns typical of fake news
                - Inconsistencies in the content
                - Attention weights on suspicious phrases
                """)
            else:
                st.write("""
                The model identified this as real news based on:
                - Credible language patterns
                - Consistent information
                - Attention weights on factual statements
                """)
        else:
            st.warning("Please enter a news article to analyze.")
 
 
 
 
 
 
 
 
 
239
 
240
  if __name__ == "__main__":
241
+ main()