billusanda007 committed
Commit 8db25aa · verified · 1 Parent(s): 5794e14

Create app.py

Files changed (1)
  1. app.py +623 -0
app.py ADDED
@@ -0,0 +1,623 @@
import os
import re
import time
from collections import Counter, defaultdict

import gradio as gr
import nltk
import numpy as np
import pandas as pd
import requests
import torch
from bs4 import BeautifulSoup
from huggingface_hub import login
from sentence_transformers import SentenceTransformer, util
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics.pairwise import cosine_similarity
from transformers import GPT2LMHeadModel, GPT2Tokenizer

nltk.download('punkt')

# Named `sentence_model` so it is not shadowed by the GPT-2 `model` loaded below.
sentence_model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')

def fetch_article_text_sequential(url):
    """Fetch headline and paragraph text from a URL in document order."""
    headers = {
        "Content-Type": "application/json",
        "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36"
    }

    # Boilerplate phrases (paywall banners, subscription prompts, etc.) to skip
    exclude = ["Thank you for your patience", "Subscribe", "subscribe", "trouble retrieving the article content", "browser settings",
               "Thank you for your patience while we verify access. If you are in Reader mode please exit and log into your Times account, or subscribe for all of The Times.",
               "Thank you for your patience while we verify access.",
               "Already a subscriber? Log in.",
               "Want all of The Times? Subscribe.",
               "Advertisement",
               "Site Index",
               "Thank you for your patience while we verify access. If you are in Reader mode please exit andlog intoyour Times account, orsubscribefor all of The Times.",
               "Already a subscriber?Log in.",
               "Want all of The Times?Subscribe.",
               "Site Information Navigation"
               ]

    try:
        # Send a request to the webpage with the specified headers
        response = requests.get(url, headers=headers, timeout=10)
        response.raise_for_status()  # Check that the request was successful

        # Parse the webpage content
        soup = BeautifulSoup(response.text, 'html.parser')

        # Collect headline and paragraph text in the order it appears
        article_content = []
        tags_of_interest = ['h1', 'h2', 'h3', 'h4', 'h5', 'h6', 'p']

        for tag in soup.find_all(tags_of_interest):
            if not any(excluded_phrase in tag.get_text() for excluded_phrase in exclude):
                article_content.append(tag.get_text(strip=True))

        return '\n'.join(article_content)

    except Exception:
        return None

def get_google_search_results(query, start=0):
    """Scrape one page of Google search results for `query`."""
    search_url = "https://www.google.com/search"
    params = {"q": query, "start": start}
    headers = {
        "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36"
    }

    response = requests.get(search_url, params=params, headers=headers, timeout=10)
    soup = BeautifulSoup(response.text, "html.parser")

    search_results = []
    for g in soup.find_all(class_="g"):
        title = g.find("h3").text if g.find("h3") else "No title"
        link = g.find("a")["href"] if g.find("a") else "No link"

        # Skip direct links to PDF documents
        if not link.lower().endswith('.pdf'):
            search_results.append({"title": title, "link": link})

    return search_results

def fetch_sentences_from_html(html):
    """Split the paragraph text of an HTML document into sentences."""
    try:
        if html is None:
            return []
        soup = BeautifulSoup(html, 'html.parser')
        paragraphs = soup.find_all("p")
        text = " ".join(p.get_text() for p in paragraphs)
        # Split on '.'/'?' boundaries while avoiding common abbreviations
        sentences = re.split(r'(?<!\w\.\w.)(?<![A-Z][a-z]\.)(?<=\.|\?)\s', text)
        return sentences
    except Exception:
        return []

# Rank sentences by their average cosine similarity to all other sentences,
# so the most "central" sentences in the pool come first.
def rank_sentences(sentences):
    if not sentences:
        return []  # Return an empty list if no sentences are found

    embeddings = sentence_model.encode(sentences, convert_to_tensor=True)

    # Compute pairwise cosine similarity between sentences
    similarities = util.pytorch_cos_sim(embeddings, embeddings).cpu().numpy()

    # Calculate the average similarity for each sentence
    avg_similarities = np.mean(similarities, axis=1)

    # Rank sentences based on their average similarity
    ranked_sentences = sorted(zip(sentences, avg_similarities), key=lambda x: x[1], reverse=True)
    return [sentence for sentence, _ in ranked_sentences]

def rank_sentences_new(sentences, query, top_n=20):
    """Rank newline-separated text by TF-IDF cosine similarity to the query."""
    if sentences is None:
        return []
    sentences = re.split("\n", sentences.strip())
    # Remove any empty strings from the list
    sentences = [sentence.strip() for sentence in sentences if sentence.strip()]
    if not sentences:
        return []
    vectorizer = TfidfVectorizer().fit_transform([query] + sentences)
    vectors = vectorizer.toarray()
    query_vector = vectors[0]
    sentences_vectors = vectors[1:]
    cosine_similarities = cosine_similarity([query_vector], sentences_vectors).flatten()
    ranked_indices = cosine_similarities.argsort()[-top_n:][::-1]
    return [sentences[idx] for idx in ranked_indices]

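# Illustrative sketch of the ranking (made-up input, not executed by the app):
# rank_sentences_new("Paris is the capital of France.\nBerlin is in Germany.",
#                    "capital of France", top_n=1)
# -> ["Paris is the capital of France."]
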
# Candidate news/reference domains for retrieval
domains = [
    "wikipedia.org", "nytimes.com", "cnn.com", "bbc.com", "theguardian.com",
    "forbes.com", "reuters.com", "cnbc.com", "bloomberg.com", "foxnews.com",
    "npr.org", "washingtonpost.com", "wsj.com", "aljazeera.com", "ft.com",
    "huffpost.com", "nationalgeographic.com", "scientificamerican.com",
    "nature.com", "time.com", "usatoday.com", "apnews.com", "abcnews.go.com",
    "cbsnews.com", "nbcnews.com", "news.yahoo.com", "theatlantic.com",
    "vox.com", "politico.com", "economist.com"
]

# Boilerplate phrases to filter out of retrieved sentences
exclude = ["Thank you for your patience", "Subscribe", "subscribe", "trouble retrieving the article content", "browser settings",
           "Thank you for your patience while we verify access. If you are in Reader mode please exit and log into your Times account, or subscribe for all of The Times.",
           "Thank you for your patience while we verify access.",
           "Already a subscriber? Log in.",
           "Want all of The Times? Subscribe.",
           "Advertisement",
           "Site Index",
           "Thank you for your patience while we verify access. If you are in Reader mode please exit andlog intoyour Times account, orsubscribefor all of The Times.",
           "Already a subscriber?Log in.",
           "Want all of The Times?Subscribe.",
           "Site Information Navigation",
           "Please enable JS and disable any ad blocker"
           ]

# Default number of search results to retrieve
num_results_needed = 10

def get_web_content(user_query, num_results_needed):
    """Search Google for the query and return boilerplate-filtered,
    query-ranked sentences scraped from the result pages."""
    all_results = []
    start = 0

    # Page through Google results until enough links are collected
    while len(all_results) < num_results_needed:
        results = get_google_search_results(user_query, start=start)
        if not results:
            break  # Stop if Google returns nothing (e.g. scraping blocked)
        all_results.extend(results)
        all_results = all_results[:num_results_needed]  # Ensure no more than needed results
        start += 10

    all_sentences_2 = []

    for result in all_results:
        text = fetch_article_text_sequential(result['link'])
        top_sentences = rank_sentences_new(text, user_query)

        # Keep at most 15 non-boilerplate sentences per page
        ans = []
        for sentence in top_sentences:
            if not any(excluded_phrase in sentence for excluded_phrase in exclude):
                ans.append(sentence)
            if len(ans) == 15:
                break

        all_sentences_2.extend(ans)

    ans = "\n".join(sentence.strip() for sentence in all_sentences_2 if sentence.strip())
    return ans, all_sentences_2

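# Illustrative usage sketch (not executed here; live Google scraping may be
# rate-limited or blocked):
# context, sentences = get_web_content("Who won the 2023 Nobel Prize in Physics?", 5)
# `context` is a newline-joined string of the top filtered sentences,
# `sentences` is the same material as a list.
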
def get_web_content_new(user_query, num_results_needed):
    """Variant of get_web_content that ranks the accumulated sentence pool by
    embedding centrality (rank_sentences) instead of query similarity."""
    all_results = []
    start = 0

    while len(all_results) < num_results_needed:
        results = get_google_search_results(user_query, start=start)
        if not results:
            break
        all_results.extend(results)
        all_results = all_results[:num_results_needed]  # Ensure no more than needed results
        start += 10

    all_sentences = []
    all_sentences_2 = []
    ranked_sentences = []

    for result in all_results:
        text = fetch_article_text_sequential(result['link'])
        if text is None:
            continue  # Skip pages that could not be fetched

        # Keep at most the first 150 sentences per page
        sentences = nltk.sent_tokenize(text)
        all_sentences.extend(sentences[:150])

        # Re-rank the accumulated pool by centrality
        ranked_sentences = rank_sentences(all_sentences)

        # Keep at most 15 non-boilerplate sentences per iteration
        ans2 = []
        for sentence in ranked_sentences:
            if not any(excluded_phrase in sentence for excluded_phrase in exclude):
                ans2.append(sentence)
            if len(ans2) == 15:
                break

        all_sentences_2.extend(ans2)

    return ranked_sentences

# Example (manual testing):
# sentences, sent = get_web_content("Who has been awarded the Nobel Prize in Physics in 2023", 2)
# res = get_web_content(Question[0], 10)
# Context = res[0]

# Get the Hugging Face token from the environment and log in if present
api_token = os.getenv('HF_TOKEN')
if api_token:
    login(token=api_token)

# Load pre-trained model and tokenizer
model_name = "gpt2-large"
model = GPT2LMHeadModel.from_pretrained(model_name)
tokenizer = GPT2Tokenizer.from_pretrained(model_name)

# Use the best available device instead of hard-coding Apple's "mps" backend
device = torch.device(
    "cuda" if torch.cuda.is_available()
    else "mps" if torch.backends.mps.is_available()
    else "cpu"
)
model.to(device)
model.eval()

# Default generation hyperparameters
top_p = 0.9
threshold = 0.6
max_length = 100

def create_ngrams(tokens, n):
    """Return all n-grams over `tokens` as tuples."""
    return [tuple(tokens[i:i + n]) for i in range(len(tokens) - n + 1)]

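# For example (illustrative):
# create_ngrams(["the", "cat", "sat", "on"], 2)
# -> [("the", "cat"), ("cat", "sat"), ("sat", "on")]
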
# --- Kneser-Ney smoothing ---
def kneser_ney_smoothing(ngram_counts, lower_order_counts, discount=0.75):
    """
    Apply Kneser-Ney smoothing to n-gram counts.

    Args:
        ngram_counts (Counter): Counts of n-grams (e.g., 4-grams or 3-grams).
        lower_order_counts (Counter): Counts of (n-1)-grams (e.g., 3-grams or 2-grams).
        discount (float): Discounting parameter.

    Returns:
        defaultdict: Smoothed probabilities, keyed by (n-1)-gram prefix,
        then by candidate next token.
    """
    # Continuation counts: in how many distinct n-gram contexts each final
    # token appears (the Kneser-Ney "novel continuation" count)
    continuation_counts = Counter()
    for ngram in ngram_counts:
        continuation_counts[ngram[-1]] += 1

    total_continuations = sum(continuation_counts.values())

    def continuation_probability(word):
        return continuation_counts[word] / total_continuations

    probabilities = defaultdict(lambda: defaultdict(float))

    for ngram, count in ngram_counts.items():
        lower_ngram = ngram[:-1]
        lower_count = lower_order_counts[lower_ngram]
        discounted_count = max(count - discount, 0)
        lambda_factor = (discount / lower_count) * len(continuation_counts)
        probabilities[lower_ngram][ngram[-1]] = (discounted_count / lower_count) + lambda_factor * continuation_probability(ngram[-1])

    return probabilities

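# Minimal usage sketch (illustrative, not executed): smooth 4-gram counts
# against 3-gram counts built from the same token stream, then look up
# continuations for a 3-token prefix.
# toy_tokens = tokenizer.tokenize("the quick brown fox jumps over the lazy dog")
# kn = kneser_ney_smoothing(Counter(create_ngrams(toy_tokens, 4)),
#                           Counter(create_ngrams(toy_tokens, 3)))
# kn[tuple(toy_tokens[:3])] maps each candidate 4th token to its smoothed probability.
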
def get_probability_from_context(Context):
    """Build 4-gram and 3-gram counts from the context and smooth them."""
    context_tokens = tokenizer.tokenize(Context)
    four_grams = create_ngrams(context_tokens, 4)
    three_grams = create_ngrams(context_tokens, 3)
    four_gram_counts = Counter(four_grams)
    three_gram_counts = Counter(three_grams)
    probabilities = kneser_ney_smoothing(four_gram_counts, three_gram_counts)
    return probabilities, four_gram_counts, three_gram_counts

def predict_next_token(probabilities, three_gram):
    """Look up the smoothed next-token distribution for a trailing 3-gram."""
    return probabilities.get(three_gram, {})

# Example (manual testing):
# _probabilities__, four_gram_counts, three_gram_counts = get_probability_from_context(Context)
# input_3_gram = tuple(tokenizer.tokenize(initial_text)[-3:])
# next_token_probs = predict_next_token(_probabilities__, input_3_gram)
# top_k_tokens = sorted(next_token_probs.items(), key=lambda x: x[1], reverse=True)[:4]

def generate_text_with_probs(initial_context, context_text, top_p, max_length, top_k, threshold=0.6):
    """Generate text token by token, switching between GPT-2 and a
    Kneser-Ney-smoothed n-gram model built from the retrieved web context."""
    Tokens = {}

    input_ids = tokenizer.encode(initial_context, return_tensors="pt").to(device)
    generated_text = initial_context
    token_tables = []
    token_no = 1

    # Build n-gram counts of every order from the retrieved context
    context_tokens = tokenizer.tokenize(context_text)
    four_gram_counts = Counter(create_ngrams(context_tokens, 4))
    three_gram_counts = Counter(create_ngrams(context_tokens, 3))
    two_gram_counts = Counter(create_ngrams(context_tokens, 2))
    one_gram_counts = Counter(create_ngrams(context_tokens, 1))
    prob = [four_gram_counts, three_gram_counts, two_gram_counts, one_gram_counts]

    use_llm = 0          # LLM distributions computed
    use_llm_back_up = 0  # tokens actually sampled from the LLM
    use_ngram = 0        # tokens taken from the n-gram model

    flag = False
    count = 0

    Token_index = 0
    colored_text = initial_context

    with torch.no_grad():
        for _ in range(max_length):
            outputs = model(input_ids=input_ids)
            next_token_logits = outputs.logits[:, -1, :]

            # Nucleus (top-p) filtering: keep the smallest set of tokens whose
            # cumulative probability exceeds top_p, mask out the rest
            sorted_logits, sorted_indices = torch.sort(next_token_logits, descending=True)
            cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)
            sorted_indices_to_remove = cumulative_probs > top_p
            sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
            sorted_indices_to_remove[..., 0] = 0

            indices_to_remove = sorted_indices[sorted_indices_to_remove]
            next_token_logits[:, indices_to_remove] = -float('Inf')
            probabilities = torch.softmax(next_token_logits, dim=-1)

            # Record the LLM's top-k candidates for the UI
            top_tokens = sorted_indices[0, :top_k]
            top_probs = probabilities[0, top_tokens]
            top_token_probs = [(tokenizer.decode([token.item()]), p.item())
                               for token, p in zip(top_tokens, top_probs)]

            df = pd.DataFrame(top_token_probs, columns=["Token", "Probability"])
            df.index = df.index + 1
            token_tables.append((f"{token_no}>> Next token options from LLM", df))

            cumulative_prob = cumulative_probs[0, top_k - 1].item()
            # Entropy of the top-k LLM distribution (alternative gating signal, unused)
            entropy = (-1) * np.sum(np.array(df['Probability']) * np.log(df['Probability']))

            input_text = tokenizer.decode(input_ids[0], skip_special_tokens=True)
            input_tokens = tokenizer.tokenize(input_text)

            use_llm += 1

            # Back off from 4-grams to lower orders until some n-gram
            # continuation is found for the current suffix
            __token_pob__ = {}
            num = 0
            while __token_pob__ == {} and num < 3:
                probs = kneser_ney_smoothing(prob[num], prob[num + 1])
                __inputs__ = tuple(input_tokens[-(3 - num):])
                __token_pob__ = probs.get(__inputs__, {})
                num += 1

            df = pd.DataFrame(list(__token_pob__.items()), columns=['Token', 'Probability'])
            df.index = df.index + 1
            token_tables.append((f"{token_no}>> Next token options from N_gram", df))
            token_no += 1

            # Use the n-gram model when the LLM is unsure (low cumulative top-k
            # mass) or when the n-gram options carry more probability mass
            if (cumulative_prob < threshold and __token_pob__ != {} and flag and count >= 4) \
                    or np.sum(df['Probability']) > cumulative_prob:
                Token_index += 1
                next_token = max(__token_pob__, key=__token_pob__.get)

                # 'Ċ' is GPT-2's newline token; prefer the runner-up if possible
                if next_token == 'Ċ':
                    sorted_tokens = sorted(__token_pob__.items(), key=lambda x: x[1], reverse=True)
                    if len(sorted_tokens) > 1:
                        next_token = sorted_tokens[1][0]

                Tokens[Token_index] = [next_token, "ngram", __token_pob__[next_token]]
                color_code = "#78bfd3"  # Light blue for n-gram tokens
                colored_text += f"<span style='color: {color_code}'>{tokenizer.convert_tokens_to_string([next_token])}</span>"

                input_tokens.append(next_token)
                generated_text = tokenizer.convert_tokens_to_string(input_tokens)
                initial_context = generated_text
                input_ids = tokenizer.encode(generated_text, return_tensors="pt").to(device)
                use_ngram += 1

            else:
                # Sample the next token from the nucleus-filtered LLM distribution
                Token_index += 1
                next_token = torch.multinomial(probabilities, num_samples=1)
                next_token_prob = probabilities[0, next_token].item()
                next_token_text = tokenizer.decode(next_token.item())

                Tokens[Token_index] = [next_token_text, "llm", next_token_prob]
                color_code = "#c99a6e"  # Light brown for LLM tokens
                colored_text += f"<span style='color: {color_code}'>{next_token_text}</span>"
                count += 1
                if count >= 4:
                    flag = True

                input_ids = torch.cat([input_ids, next_token], dim=-1)
                if next_token.item() == tokenizer.eos_token_id:
                    break
                generated_text = tokenizer.decode(input_ids[0], skip_special_tokens=True)
                initial_context = generated_text
                use_llm_back_up += 1

    generated_text = tokenizer.decode(input_ids[0], skip_special_tokens=True)

    return generated_text, Tokens, token_tables, colored_text

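# Minimal end-to-end sketch (not executed; query and context values are illustrative):
# context, _ = get_web_content("Who won the 2023 Nobel Prize in Physics?", 5)
# text, tokens, tables, html = generate_text_with_probs(
#     "The 2023 Nobel Prize in Physics was awarded to", context,
#     top_p=0.9, max_length=20, top_k=5, threshold=0.6)
# `html` colors n-gram tokens light blue and LLM tokens light brown.
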
def combined_model_predictions(query, initial_context, top_p, max_length, top_k, threshold, docs):
    """Gradio entry point: retrieve web context for the query, then generate."""
    context_text = get_web_content(query, docs)[0]
    generated_text, tokens, token_tables, colored_html = generate_text_with_probs(
        initial_context, context_text, top_p, max_length, top_k, threshold)

    data_list = [(token_index, vals[0], vals[1], vals[2]) for token_index, vals in tokens.items()]
    df = pd.DataFrame(data_list, columns=['Token_pos', 'Token', 'Source Model', 'Probability'])

    return colored_html, df, token_tables

iface = gr.Interface(
    fn=combined_model_predictions,
    inputs=[
        gr.Textbox(lines=2, placeholder="Enter query here..."),
        gr.Textbox(lines=2, placeholder="Enter initial context here..."),
        gr.Slider(0, 1, step=0.01, value=0.9, label="Top-p (nucleus) sampling"),
        gr.Slider(1, 100, value=4, step=1, label="Max length (tokens)"),
        gr.Slider(1, 50, value=5, step=1, label="Top-k"),
        gr.Slider(0, 1, step=0.01, value=0.9, label="LLM cumulative threshold"),
        gr.Slider(1, 50, step=1, value=10, label="Web-retrieved docs to fetch"),
    ],
    outputs=[
        gr.HTML(label="Generated Text"),
        gr.Dataframe(label="Tokens"),
        gr.Dataframe(label="Token tables"),
    ],
    title="Next Token Visualizer (GPT-2-large - 812M param.)",
)

iface.launch()