#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Created by zd302 at 08/07/2024
import gradio as gr
import tqdm
import torch
import numpy as np
from time import sleep
from datetime import datetime
import threading
import gc
import os
import json
import pytorch_lightning as pl
from urllib.parse import urlparse
from accelerate import Accelerator
import spaces
from transformers import BartTokenizer, BartForConditionalGeneration
from transformers import BloomTokenizerFast, BloomForCausalLM, BertTokenizer, BertForSequenceClassification
from transformers import RobertaTokenizer, RobertaForSequenceClassification
from rank_bm25 import BM25Okapi
# import bm25s
# import Stemmer # optional: for stemming
from html2lines import url2lines
from googleapiclient.discovery import build
from averitec.models.DualEncoderModule import DualEncoderModule
from averitec.models.SequenceClassificationModule import SequenceClassificationModule
from averitec.models.JustificationGenerationModule import JustificationGenerationModule
from averitec.data.sample_claims import CLAIMS_Type
# ---------------------------------------------------------------------------
# load .env
# Identifier returned by the project helper create_user_id(); computed once at import time.
from utils import create_user_id
user_id = create_user_id()
from azure.storage.fileshare import ShareServiceClient
# python-dotenv is optional: any failure while importing or calling it is
# deliberately ignored so the process can rely on variables that are already
# present in the environment.
try:
    from dotenv import load_dotenv
    load_dotenv()
except Exception as e:
    pass
# os.environ["TOKENIZERS_PARALLELISM"] = "false"
# The three AZURE_* variables are mandatory: os.environ[...] raises KeyError
# at import time when one of them is missing.
account_url = os.environ["AZURE_ACCOUNT_URL"]
credential = {
    "account_key": os.environ['AZURE_ACCOUNT_KEY'],
    "account_name": os.environ['AZURE_ACCOUNT_NAME']
}
# Name of the file share opened on the storage account.
file_share_name = "averitec"
azure_service = ShareServiceClient(account_url=account_url, credential=credential)
azure_share_client = azure_service.get_share_client(file_share_name)
# ---------- Setting ----------
import requests
from bs4 import BeautifulSoup
import wikipediaapi
# Shared Wikipedia client used by search_entity_wikipeida().
# Presumably the arguments are (user agent, language edition) -- confirm
# against the installed wikipediaapi version.
wiki_wiki = wikipediaapi.Wikipedia('AVeriTeC (zd302@cam.ac.uk)', 'en')
import nltk
# NLTK resources are requested on every import of this module; they back the
# word_tokenize / sent_tokenize / pos_tag calls used below.
nltk.download('averaged_perceptron_tagger_eng')
nltk.download('averaged_perceptron_tagger')
nltk.download('punkt')
nltk.download('punkt_tab')
from nltk import pos_tag, word_tokenize, sent_tokenize
import spacy
# The spaCy model download is launched through the shell on every import.
# NOTE(review): `python` is resolved via PATH and may not be the interpreter
# running this module -- confirm in the deployment environment.
os.system("python -m spacy download en_core_web_sm")
# Pipeline used for entity extraction in find_evidence_from_wikipedia_dumps().
nlp = spacy.load("en_core_web_sm")
# ---------------------------------------------------------------------------
# Load sample dict for AVeriTeC search
# all_samples_dict = json.load(open('averitec/data/all_samples.json', 'r'))
# Annotated training claims; they feed both BM25 reference corpora built
# below.  The file is opened in a context manager so the handle is closed
# deterministically, and decoded as UTF-8 regardless of the platform locale.
with open('averitec/data/train.json', 'r', encoding='utf-8') as _train_file:
    train_examples = json.load(_train_file)
def claim2prompts(example):
    """Yield few-shot prompt material for every usable answer of a claim.

    Args:
        example: dict with a ``"questions"`` list; every question carries a
            ``"question"`` string and an ``"answers"`` list whose items have
            ``"answer_type"`` and ``"answer"`` (plus ``"boolean_explanation"``
            for Boolean answers).

    Yields:
        ``(lookup_str, prompt)`` tuples.  ``lookup_str`` is the punctuated
        answer text; ``prompt`` is
        ``"Evidence:  <answer>\\nQuestion answered: <question>"`` with
        newlines inside the pieces flattened to spaces.

    Questions that are empty after stripping are skipped.  Empty answers are
    skipped as well; previously they raised ``IndexError`` on ``a_text[-1]``.
    """
    claim_str = "Evidence: "
    for question in example["questions"]:
        q_text = question["question"].strip()
        if not q_text:
            continue
        if not q_text.endswith("?"):
            q_text += "?"

        answer_strings = []
        for a in question["answers"]:
            if a["answer_type"] in ("Extractive", "Abstractive"):
                answer_strings.append(a["answer"])
            if a["answer_type"] == "Boolean":
                answer_strings.append(
                    a["answer"] + ", because " + a["boolean_explanation"].lower().strip()
                )

        for a_text in answer_strings:
            if not a_text:
                # Nothing to index or to show as evidence.
                continue
            if a_text[-1] not in (".", "!", ":", "?"):
                a_text += "."
            # The lookup string is the answer alone (the claim is not prepended).
            prompt_lookup_str = a_text
            this_q_claim_str = claim_str + " " + a_text.strip() + "||Question answered: " + q_text
            # "||" is a temporary marker so real newlines can be flattened first.
            yield (prompt_lookup_str, this_q_claim_str.replace("\n", " ").replace("||", "\n"))
def generate_reference_corpus(reference_file):
    """Build the claim-level BM25 corpus from the module-level ``train_examples``.

    ``reference_file`` is accepted for interface compatibility but is not read.

    Returns:
        ``(tokenized_corpus, all_data_corpus)``: the word-tokenized claims and,
        aligned with them, dicts holding ``claim``, ``speaker`` and
        ``questions`` (question texts only).
    """
    corpus_tokens = []
    corpus_entries = []
    for record in train_examples:
        claim_text = record["claim"]

        raw_speaker = record["speaker"]
        if raw_speaker is not None and len(raw_speaker) > 1:
            speaker_name = raw_speaker.strip()
        else:
            # Missing or one-character speakers fall back to a generic pronoun.
            speaker_name = "they"

        entry = {
            "claim": claim_text,
            "speaker": speaker_name,
            "questions": [item["question"] for item in record["questions"]],
        }
        corpus_tokens.append(nltk.word_tokenize(entry["claim"]))
        corpus_entries.append(entry)
    return corpus_tokens, corpus_entries
def generate_step2_reference_corpus(reference_file):
    """Build the answer-level BM25 corpus from the module-level ``train_examples``.

    ``reference_file`` is accepted for interface compatibility but is not read.

    Returns:
        ``(tokenized_corpus, prompt_corpus)``: the tokenized lookup strings
        produced by ``claim2prompts`` and the prompts aligned with them.
    """
    pairs = [
        pair
        for example in train_examples
        for pair in claim2prompts(example)
    ]
    tokenized = [nltk.word_tokenize(lookup_text) for lookup_text, _ in pairs]
    prompts = [prompt_text for _, prompt_text in pairs]
    return tokenized, prompts
# Both generators below ignore this path and read the already-loaded
# `train_examples`; it is only passed along.
reference_file = "averitec/data/train.json"
# Claim-level index: used by prompt_question_generation() to pick few-shot claims.
tokenized_corpus0, all_data_corpus0 = generate_reference_corpus(reference_file)
qg_bm25 = BM25Okapi(tokenized_corpus0)
# Answer-level index: used by decorate_with_questions() to pick few-shot QA prompts.
tokenized_corpus1, prompt_corpus1 = generate_step2_reference_corpus(reference_file)
prompt_bm25 = BM25Okapi(tokenized_corpus1)
# print(train_examples[0]['claim'])
# ---------------------------------------------------------------------------
# ---------- Load pretrained models ----------
# ---------- load Evidence retrieval model ----------
# from drqa import retriever
# db_class = retriever.get_class('sqlite')
# doc_db = db_class("averitec/data/wikipedia_dumps/enwiki.db")
# ranker = retriever.get_class('tfidf')(tfidf_path="averitec/data/wikipedia_dumps/enwiki-tfidf-with-id-title.npz")
# ---------- Load Veracity and Justification prediction model ----------
print("Loading models ...")
# Verdict names indexed by class id; veracity_prediction() returns LABEL[answer].
LABEL = [
    "Supported",
    "Refuted",
    "Not Enough Evidence",
    "Conflicting Evidence/Cherrypicking",
]
# Model loading is gated on CUDA being visible.
# NOTE(review): all four model groups are assumed to belong to this branch;
# if so, none of the model globals exist on a CPU-only host -- confirm.
if torch.cuda.is_available():
    # # device
    # device = "cuda:0" if torch.cuda.is_available() else "cpu"
    # question generation
    # BLOOM 1b1 in bfloat16, moved to the GPU; used for both question-generation steps.
    qg_tokenizer = BloomTokenizerFast.from_pretrained("bigscience/bloom-1b1")
    qg_model = BloomForCausalLM.from_pretrained("bigscience/bloom-1b1", torch_dtype=torch.bfloat16).to('cuda')
    # qg_model = BloomForCausalLM.from_pretrained("bigscience/bloom-7b1", torch_dtype=torch.bfloat16).to(device)
    # qg_tokenizer = BloomTokenizerFast.from_pretrained("bigscience/bloom-7b1")
    # qg_model = BloomForCausalLM.from_pretrained("bigscience/bloom-7b1", torch_dtype=torch.bfloat16).to(device)
    # rerank
    # Two-class BERT wrapped by the project's DualEncoderModule checkpoint; used by rerank_questions().
    rerank_tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
    rereank_bert_model = BertForSequenceClassification.from_pretrained("bert-base-uncased", num_labels=2, problem_type="single_label_classification")  # Must specify single_label for some reason
    best_checkpoint = "averitec/pretrained_models/bert_dual_encoder.ckpt"
    rerank_trained_model = DualEncoderModule.load_from_checkpoint(best_checkpoint, tokenizer=rerank_tokenizer, model=rereank_bert_model)
    # rerank_trained_model = DualEncoderModule.load_from_checkpoint(best_checkpoint, tokenizer=rerank_tokenizer, model=rereank_bert_model).to(device)
    # Veracity
    # Four-class BERT (one class per LABEL entry); used by veracity_prediction().
    veracity_tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
    bert_model = BertForSequenceClassification.from_pretrained("bert-base-uncased", num_labels=4, problem_type="single_label_classification")
    veracity_model = SequenceClassificationModule.load_from_checkpoint("averitec/pretrained_models/bert_veracity.ckpt", tokenizer=veracity_tokenizer, model=bert_model)
    # veracity_model = SequenceClassificationModule.load_from_checkpoint("averitec/pretrained_models/bert_veracity.ckpt", tokenizer=veracity_tokenizer, model=bert_model).to(device)
    # Justification
    # BART-large wrapped by the project's JustificationGenerationModule; used by justification_generation().
    justification_tokenizer = BartTokenizer.from_pretrained('facebook/bart-large', add_prefix_space=True)
    bart_model = BartForConditionalGeneration.from_pretrained("facebook/bart-large")
    # `best_checkpoint` is reused here; the rerank checkpoint path above is no longer needed.
    best_checkpoint = 'averitec/pretrained_models/bart_justifications_verdict-epoch=13-val_loss=2.03-val_meteor=0.28.ckpt'
    justification_model = JustificationGenerationModule.load_from_checkpoint(best_checkpoint, tokenizer=justification_tokenizer, model=bart_model)
    # justification_model = JustificationGenerationModule.load_from_checkpoint(best_checkpoint, tokenizer=justification_tokenizer, model=bart_model).to(device)
# Set up Gradio Theme
# Gradio theme object: blue/red hues and the Poppins Google font with
# generic sans-serif fallbacks.
theme = gr.themes.Base(
    primary_hue="blue",
    secondary_hue="red",
    font=[gr.themes.GoogleFont("Poppins"), "ui-sans-serif", "system-ui", "sans-serif"],
)
# ---------- Setting ----------
class Docs:
    """Minimal evidence container: a metadata dict plus the text to display.

    Args:
        metadata: dict describing the evidence; a fresh empty dict is created
            when omitted, so instances never share one default object.
        page_content: text body of the evidence.
    """

    def __init__(self, metadata=None, page_content=""):
        # A `dict()` default would be evaluated once and shared by every
        # instance created without an explicit metadata argument.
        self.metadata = {} if metadata is None else metadata
        self.page_content = page_content
def make_html_source(source, i):
    """Return the card markup for one evidence source.

    The template currently consists of a single newline; ``source.metadata``
    and the stripped ``source.page_content`` are still read (so a malformed
    ``source`` fails here) but are not interpolated.  ``i`` is unused.
    """
    _metadata = source.metadata
    _body = source.page_content.strip()
    return f"""
"""
# ----- veracity_prediction -----
class SequenceClassificationDataLoader(pl.LightningDataModule):
    """Tokenization helpers for the veracity classifier.

    Only the constructor arguments are stored; the two helpers below are what
    ``veracity_prediction`` uses.
    """

    def __init__(self, tokenizer, data_file, batch_size, add_extra_nee=False):
        super().__init__()
        self.tokenizer = tokenizer
        self.data_file = data_file
        self.batch_size = batch_size
        self.add_extra_nee = add_extra_nee

    def tokenize_strings(
        self,
        source_sentences,
        max_length=400,
        pad_to_max_length=False,
        return_tensors="pt",
    ):
        """Tokenize ``source_sentences`` and return ``(input_ids, attention_mask)``.

        Inputs are truncated to ``max_length`` and padded either to
        ``max_length`` or to the longest sequence in the batch.
        """
        padding_mode = "longest"
        if pad_to_max_length:
            padding_mode = "max_length"
        encoded = self.tokenizer(
            source_sentences,
            max_length=max_length,
            padding=padding_mode,
            truncation=True,
            return_tensors=return_tensors,
        )
        return encoded["input_ids"], encoded["attention_mask"]

    def quadruple_to_string(self, claim, question, answer, bool_explanation=""):
        """Render claim/question/answer (and optional explanation) as one string."""
        suffix = ""
        if bool_explanation is not None and len(bool_explanation) > 0:
            suffix = ", because " + bool_explanation.lower().strip()
        pieces = ["[CLAIM] ", claim.strip(), " [QUESTION] ", question.strip(), " ", answer.strip(), suffix]
        return "".join(pieces)
@spaces.GPU
def veracity_prediction(claim, qa_evidence):
    """Predict the verdict label for ``claim`` given question/answer evidence.

    Each evidence item is classified on its own and the per-item classes are
    aggregated into one verdict.

    Args:
        claim: the claim text.
        qa_evidence: iterable of objects whose ``metadata`` dict holds
            ``"query"`` and ``"answer"``.

    Returns:
        One of the strings in ``LABEL``; "Not Enough Evidence" when no
        evidence is supplied.
    """
    dataLoader = SequenceClassificationDataLoader(
        tokenizer=veracity_tokenizer,
        data_file="this_is_discontinued",
        batch_size=32,
        add_extra_nee=False,
    )

    evidence_strings = [
        dataLoader.quadruple_to_string(claim, evidence.metadata["query"], evidence.metadata["answer"], "")
        for evidence in qa_evidence
    ]
    if not evidence_strings:
        # If we found no evidence e.g. because google returned 0 pages, just output NEI.
        return "Not Enough Evidence"

    tokenized_strings, attention_mask = dataLoader.tokenize_strings(evidence_strings)
    device = veracity_model.device
    # Pure inference: no autograd graph is needed.
    with torch.no_grad():
        logits = veracity_model(tokenized_strings.to(device), attention_mask=attention_mask.to(device)).logits
    predicted = set(torch.argmax(logits, dim=1).tolist())

    has_true = 0 in predicted
    has_false = 1 in predicted
    # TODO another hack -- we cant have different labels for train and test so we do this
    has_unanswerable = 2 in predicted or 3 in predicted

    if has_unanswerable:
        answer = 2
    elif has_true and not has_false:
        answer = 0
    elif not has_true and has_false:
        answer = 1
    else:
        # Supporting and refuting evidence at the same time.
        answer = 3
    return LABEL[answer]
@spaces.GPU
def extract_claim_str(claim, qa_evidence, verdict_label):
    """Serialize claim, QA evidence and verdict into one tagged string.

    Layout: ``[CLAIM] <claim> [EVIDENCE] <q1> <a1> <q2> <a2> ...  [VERDICT] <label>``.
    Evidence whose question is empty after stripping is skipped; questions get
    a trailing "?" and non-empty answers a trailing "." when missing.
    """
    pieces = ["[CLAIM] " + claim + " [EVIDENCE] "]
    for item in qa_evidence:
        question = item.metadata['query'].strip()
        if not question:
            continue
        if question[-1] != "?":
            question += "?"
        pieces.append(question)

        answer = item.metadata['answer']
        if answer:
            if answer[-1] != ".":
                answer += "."
            pieces.append(" " + answer.strip())
        pieces.append(" ")
    pieces.append(" [VERDICT] " + verdict_label)
    return "".join(pieces)
@spaces.GPU
def justification_generation(claim, qa_evidence, verdict_label):
    """Generate a textual justification for ``verdict_label``.

    The model input has the layout
    ``[CLAIM] <claim> [EVIDENCE] <q1> <a1> ...  [VERDICT] <label>``.
    Evidence whose question is empty after stripping is skipped; questions get
    a trailing "?" and non-empty answers a trailing "." when missing.

    Args:
        claim: the claim text.
        qa_evidence: iterable of objects whose ``metadata`` dict holds
            ``"query"`` and ``"answer"``.
        verdict_label: the predicted verdict string.

    Returns:
        The stripped text returned by ``justification_model.generate``.
    """
    pieces = ["[CLAIM] " + claim + " [EVIDENCE] "]
    for evi in qa_evidence:
        q_text = evi.metadata['query'].strip()
        if not q_text:
            continue
        if not q_text.endswith("?"):
            q_text += "?"
        pieces.append(q_text)

        a_text = evi.metadata['answer']
        if a_text:
            if not a_text.endswith("."):
                a_text += "."
            pieces.append(" " + a_text.strip())
        pieces.append(" ")
    pieces.append(" [VERDICT] " + verdict_label)

    # The result of strip() used to be discarded; keep it so surrounding
    # whitespace never reaches the model.
    claim_str = "".join(pieces).strip()
    pred_justification = justification_model.generate(claim_str, device=justification_model.device)
    return pred_justification.strip()
@spaces.GPU
def QAprediction(claim, evidence, sources):
    """Run verdict and justification prediction and format the result text.

    ``sources`` is accepted but not used.

    Returns:
        ``(content, [verdict_label, justification_label])`` where ``content``
        concatenates the evidence list, the verdict and the justification.
    """
    # One "Doc <i><i>" entry per evidence item, numbered from 1.
    doc_entries = [f"""Doc {idx}""" + f"""{idx}""" for idx, _ in enumerate(evidence, 1)]
    evidence_part = ", ".join(doc_entries)

    verdict_label = veracity_prediction(claim, evidence)
    justification_label = justification_generation(claim, evidence, verdict_label)

    sections = [
        "Retrieved Evidence:\n",
        evidence_part,
        "Prediction:\n",
        f"Verdict: {verdict_label}.\n",
        f"Justification: {justification_label}",
    ]
    return "".join(sections), [verdict_label, justification_label]
# ----------GoogleAPIretriever---------
# def generate_reference_corpus(reference_file):
# # with open(reference_file) as f:
# # train_examples = json.load(f)
#
# all_data_corpus = []
# tokenized_corpus = []
#
# for train_example in train_examples:
# train_claim = train_example["claim"]
#
# speaker = train_example["speaker"].strip() if train_example["speaker"] is not None and len(
# train_example["speaker"]) > 1 else "they"
#
# questions = [q["question"] for q in train_example["questions"]]
#
# claim_dict_builder = {}
# claim_dict_builder["claim"] = train_claim
# claim_dict_builder["speaker"] = speaker
# claim_dict_builder["questions"] = questions
#
# tokenized_corpus.append(nltk.word_tokenize(claim_dict_builder["claim"]))
# all_data_corpus.append(claim_dict_builder)
#
# return tokenized_corpus, all_data_corpus
def doc2prompt(doc):
    """Render one reference claim and its questions as a few-shot example."""
    stripped_claim = doc["claim"].strip()
    header = f'Outrageously, {doc["speaker"]} claimed that "{stripped_claim}". Criticism includes questions like: '
    return header + " ".join(question.strip() for question in doc["questions"])
def docs2prompt(top_docs):
return "\n\n".join([doc2prompt(d) for d in top_docs])
@spaces.GPU
def prompt_question_generation(test_claim, speaker="they", topk=10):
    """Generate verification questions for ``test_claim`` with a few-shot prompt.

    The ``topk`` training claims closest to ``test_claim`` under BM25 are
    rendered as examples, followed by an open-ended line for the test claim.
    The first line of the model continuation is split on "?" to obtain the
    individual questions.

    Args:
        test_claim: the claim to fact-check.
        speaker: who made the claim; interpolated into the prompt.
        topk: number of few-shot examples.

    Returns:
        A list of ``{"question": str, "answers": []}`` dicts; questions of 300
        characters or more are dropped.
    """
    scores = qg_bm25.get_scores(nltk.word_tokenize(test_claim))
    top_n = np.argsort(scores)[::-1][:topk]
    docs = [all_data_corpus0[i] for i in top_n]

    prompt = docs2prompt(docs) + "\n\n" + "Outrageously, " + speaker + " claimed that \"" + test_claim.strip() + \
             "\". Criticism includes questions like: "

    inputs = qg_tokenizer([prompt], padding=True, return_tensors="pt").to(qg_model.device)
    outputs = qg_model.generate(inputs["input_ids"], max_length=2000, num_beams=2, no_repeat_ngram_size=2, early_stopping=True)

    # Decode only the newly generated tokens.  Cutting the decoded text by the
    # prompt's character length breaks whenever decoding does not reproduce
    # the prompt exactly; slicing tokens matches decorate_with_questions().
    prompt_len = inputs["input_ids"].shape[-1]
    continuation = qg_tokenizer.batch_decode(outputs[:, prompt_len:], skip_special_tokens=True)[0]

    questions_str = continuation.split("\n")[0]
    candidates = (q.strip() for q in questions_str.split("?"))
    qs = [q + "?" for q in candidates if q and len(q) < 300]

    return [{"question": q, "answers": []} for q in qs]
def check_claim_date(check_date):
    """Normalize a ``year-month-day`` string to ``YYYYMMDD``.

    Two-digit years up to 30 are read as 20xx, other two-digit years as 19xx,
    and one-digit years as 200x; one-digit months and days are zero-padded.
    Input that cannot be split into exactly three "-"-separated parts
    (including ``None``) falls back to 2022-01-01.

    Returns:
        The concatenated ``year + month + day`` string.
    """
    try:
        year, month, date = check_date.split("-")
    except (ValueError, AttributeError, TypeError):
        # Wrong number of parts, or not a string at all.
        month, date, year = "01", "01", "2022"

    if len(year) == 2:
        year = ("20" if int(year) <= 30 else "19") + year
    elif len(year) == 1:
        year = "200" + year

    if len(month) == 1:
        month = "0" + month
    if len(date) == 1:
        date = "0" + date

    return year + month + date
def string_to_search_query(text, author):
    """Reduce ``text`` to a keyword query.

    Tokens whose POS tag starts with CD, JJ, NN or VB are kept, in order; the
    words of ``author`` (when given) are placed in front.
    """
    tokens = word_tokenize(text.strip())
    tagged = pos_tag(tokens)
    wanted_prefixes = ("CD", "JJ", "NN", "VB")

    words = [] if author is None else author.split()
    # A tag can start with at most one of the prefixes, so each token is kept once.
    words.extend(
        token
        for token, (_, tag) in zip(tokens, tagged)
        if tag.startswith(wanted_prefixes)
    )
    return " ".join(words)
def google_search(search_term, api_key, cse_id, **kwargs):
    """Query the Custom Search API and return the ``items`` of the response.

    Extra keyword arguments are forwarded to ``cse().list``.  An empty list is
    returned when the response has no ``items`` entry.
    """
    client = build("customsearch", "v1", developerKey=api_key)
    response = client.cse().list(q=search_term, cx=cse_id, **kwargs).execute()
    if "items" not in response:
        return []
    return response['items']
def get_domain_name(url):
    """Return the network location of ``url`` without a leading ``www.``.

    URLs lacking a scheme are parsed as if they started with ``http://``.
    """
    full_url = url if '://' in url else 'http://' + url
    host = urlparse(full_url).netloc
    return host[4:] if host.startswith("www.") else host
def get_and_store(url_link, fp, worker, worker_stack):
    """Download ``url_link`` and write it to ``fp`` (URL first, then its lines).

    When done, ``worker`` is appended to ``worker_stack`` and a garbage
    collection pass is triggered.
    """
    document = "\n".join([url_link] + url2lines(url_link))
    with open(fp, "w") as handle:
        handle.write(document + "\n")
    worker_stack.append(worker)
    gc.collect()
def get_text_from_link(url_link):
    """Return the page at ``url_link`` as text whose first line is the URL itself."""
    lines = [url_link]
    lines += url2lines(url_link)
    return "\n".join(lines)
def get_google_search_results(api_key, search_engine_id, google_search, sort_date, search_string, page=0):
    """Fetch one page of search results, returning ``[]`` on failure.

    Args:
        api_key: API key passed to ``google_search``.
        search_engine_id: search engine id passed to ``google_search``.
        google_search: callable ``(query, api_key, cse_id, **kwargs)`` that
            returns an iterable of result items.
        sort_date: upper bound of the date range, as ``YYYYMMDD``.
        search_string: the query.
        page: zero-based page number; the start offset is ``10 * page``.

    Returns:
        A list with the result items, or an empty list when the call failed.
    """
    max_attempts = 1
    for attempt in range(max_attempts):
        try:
            return list(
                google_search(
                    search_string,
                    api_key,
                    search_engine_id,
                    num=3,  # num=10,
                    start=0 + 10 * page,
                    sort="date:r:19000101:" + sort_date,
                    dateRestrict=None,
                    gl="US",
                )
            )
        except Exception:
            # Best effort: a failed search yields no results.  Only pause when
            # another attempt will actually follow.
            if attempt + 1 < max_attempts:
                sleep(1)
    return []
# @spaces.GPU
def averitec_search(claim, generate_question, speaker="they", check_date="2024-07-01", n_pages=1):  # n_pages=3
    """Search the web for ``claim`` and its questions and download the hits.

    Queries are: a keyword version of the claim, the raw claim, and the first
    three generated questions.  Results are restricted to dates up to today
    (``check_date`` and ``speaker`` are accepted but not used).  Hits from
    blacklisted domains, blacklisted URL fragments, and links ending in
    ``.pdf``/``.doc`` are dropped.  Every page is downloaded at most once.

    Requires the ``GOOGLE_API_KEY`` and ``GOOGLE_SEARCH_ENGINE_ID``
    environment variables.

    Returns:
        A list of rows
        ``[index, claim, link, page_num, query, query_type, page_text]``.
    """
    api_key = os.environ["GOOGLE_API_KEY"]
    search_engine_id = os.environ["GOOGLE_SEARCH_ENGINE_ID"]

    blocked_domains = [
        "jstor.org",  # Blacklisted because their pdfs are not labelled as such, and clog up the download
        "facebook.com",  # Blacklisted because only post titles can be scraped, but the scraper doesn't know this,
        "ftp.cs.princeton.edu",  # Blacklisted because it hosts many large NLP corpora that keep showing up
        "nlp.cs.princeton.edu",
        "huggingface.co"
    ]
    blocked_fragments = [  # Blacklisted some NLP nonsense that crashes my machine with OOM errors
        "/glove.",
        "ftp://ftp.cs.princeton.edu/pub/cs226/autocomplete/words-333333.txt",
        "https://web.mit.edu/adamrose/Public/googlelist",
    ]

    # The row index is never advanced, so every row carries "0".
    index = 0
    top_questions = [q["question"] for q in generate_question][:3]

    # The date bound is today's date.
    sort_date = check_claim_date(datetime.now().strftime("%Y-%m-%d"))

    queries = [
        (string_to_search_query(claim, None), "claim"),
        (claim, "claim-noformat"),
    ]
    queries.extend((question, "question") for question in top_questions)

    page_cache = {}
    rows = []
    for query, query_type in queries:
        for page_num in range(n_pages):
            hits = get_google_search_results(api_key, search_engine_id, google_search, sort_date,
                                             query, page=page_num)
            for hit in hits:
                link = str(hit["link"])
                if get_domain_name(link) in blocked_domains:
                    continue
                if any(fragment in link for fragment in blocked_fragments):
                    continue
                if link.endswith(".pdf") or link.endswith(".doc"):
                    continue

                if link not in page_cache:
                    page_cache[link] = get_text_from_link(link)
                rows.append([str(index), claim, link, str(page_num), query, query_type, page_cache[link]])
    return rows
# def generate_step2_reference_corpus(reference_file):
# # with open(reference_file) as f:
# # train_examples = json.load(f)
#
# prompt_corpus = []
# tokenized_corpus = []
#
# for example in train_examples:
# for lookup_str, prompt in claim2prompts(example):
# entry = nltk.word_tokenize(lookup_str)
# tokenized_corpus.append(entry)
# prompt_corpus.append(prompt)
#
# return tokenized_corpus, prompt_corpus
@spaces.GPU
def decorate_with_questions(claim, retrieve_evidence, top_k=3):  # top_k=5, 10, 100
    """Pick the evidence lines closest to ``claim`` and generate a question for each.

    Args:
        claim: the claim text.
        retrieve_evidence: rows whose last element is the page text; the first
            line of that text is the page URL, the remaining lines are content.
        top_k: number of evidence lines to keep (BM25 against the claim).

    Returns:
        A list of ``[question, evidence_line, source_url]`` triples.  The list
        is empty when no usable evidence line exists; previously that case
        crashed while building the BM25 index over an empty corpus.
    """
    tokenized_corpus = []
    all_data_corpus = []
    seen = set()  # O(1) duplicate check for (url, line) pairs

    for retri_evi in tqdm.tqdm(retrieve_evidence):
        lines_in_web = retri_evi[-1].split("\n")
        location_url = lines_in_web[0].strip()
        for line in lines_in_web[1:]:
            line = line.strip()
            if len(line) <= 3:
                continue
            key = (location_url, line)
            if key in seen:
                continue
            seen.add(key)
            tokenized_corpus.append(nltk.word_tokenize(line))
            all_data_corpus.append(key)

    if not tokenized_corpus:
        # Nothing was retrieved; downstream steps handle an empty list.
        return []

    bm25 = BM25Okapi(tokenized_corpus)
    s = bm25.get_scores(nltk.word_tokenize(claim))
    top_n = np.argsort(s)[::-1][:top_k]
    docs = [all_data_corpus[i] for i in top_n]

    generate_qa_pairs = []
    # Then, generate a question for each of the selected evidence lines:
    for doc in tqdm.tqdm(docs):
        source_url, evidence_line = doc
        flat_evidence = evidence_line.replace("\n", " ")

        # Few-shot examples: the 10 training prompts closest to this evidence.
        prompt_s = prompt_bm25.get_scores(nltk.word_tokenize(evidence_line))
        prompt_n = 10
        prompt_top_n = np.argsort(prompt_s)[::-1][:prompt_n]
        prompt_docs = [prompt_corpus1[i] for i in prompt_top_n]

        claim_prompt = "Evidence: " + flat_evidence + "\nQuestion answered: "
        prompt = "\n\n".join(prompt_docs + [claim_prompt])

        inputs = qg_tokenizer([prompt], padding=True, return_tensors="pt").to(qg_model.device)
        outputs = qg_model.generate(inputs["input_ids"], max_length=5000, num_beams=2, no_repeat_ngram_size=2, early_stopping=True)

        # Decode only the tokens generated after the prompt.
        tgt_text = qg_tokenizer.batch_decode(outputs[:, inputs["input_ids"].shape[-1]:], skip_special_tokens=True)[0]
        # We are not allowed to generate more than 250 characters:
        tgt_text = tgt_text[:250]

        question = tgt_text.strip().split("?")[0].replace("\n", " ") + "?"
        generate_qa_pairs.append([question, flat_evidence, source_url])

    return generate_qa_pairs
# def decorate_with_questions_michale(claim, retrieve_evidence, top_k=10): # top_k=100
# #
# reference_file = "averitec/data/train.json"
# tokenized_corpus, prompt_corpus = generate_step2_reference_corpus(reference_file)
# prompt_bm25 = BM25Okapi(tokenized_corpus)
#
# # Define the bloom model:
# accelerator = Accelerator()
# accel_device = accelerator.device
# # device = "cuda:0" if torch.cuda.is_available() else "cpu"
# # tokenizer = BloomTokenizerFast.from_pretrained("bigscience/bloom-7b1")
# # model = BloomForCausalLM.from_pretrained(
# # "bigscience/bloom-7b1",
# # device_map="auto",
# # torch_dtype=torch.bfloat16,
# # offload_folder="./offload"
# # )
#
# #
# tokenized_corpus = []
# all_data_corpus = []
#
# for retri_evi in tqdm.tqdm(retrieve_evidence):
# store_file = retri_evi[-1]
#
# with open(store_file, 'r') as f:
# first = True
# for line in f:
# line = line.strip()
#
# if first:
# first = False
# location_url = line
# continue
#
# if len(line) > 3:
# entry = nltk.word_tokenize(line)
# if (location_url, line) not in all_data_corpus:
# tokenized_corpus.append(entry)
# all_data_corpus.append((location_url, line))
#
# if len(tokenized_corpus) == 0:
# print("")
#
# bm25 = BM25Okapi(tokenized_corpus)
# s = bm25.get_scores(nltk.word_tokenize(claim))
# top_n = np.argsort(s)[::-1][:top_k]
# docs = [all_data_corpus[i] for i in top_n]
#
# generate_qa_pairs = []
# # Then, generate questions for those top 50:
# for doc in tqdm.tqdm(docs):
# # prompt_lookup_str = example["claim"] + " " + doc[1]
# prompt_lookup_str = doc[1]
#
# prompt_s = prompt_bm25.get_scores(nltk.word_tokenize(prompt_lookup_str))
# prompt_n = 10
# prompt_top_n = np.argsort(prompt_s)[::-1][:prompt_n]
# prompt_docs = [prompt_corpus[i] for i in prompt_top_n]
#
# claim_prompt = "Evidence: " + doc[1].replace("\n", " ") + "\nQuestion answered: "
# prompt = "\n\n".join(prompt_docs + [claim_prompt])
# sentences = [prompt]
#
# inputs = qg_tokenizer(sentences, padding=True, return_tensors="pt").to(device)
# outputs = qg_model.generate(inputs["input_ids"], max_length=5000, num_beams=2, no_repeat_ngram_size=2,
# early_stopping=True)
#
# tgt_text = qg_tokenizer.batch_decode(outputs[:, inputs["input_ids"].shape[-1]:], skip_special_tokens=True)[0]
# # We are not allowed to generate more than 250 characters:
# tgt_text = tgt_text[:250]
#
# qa_pair = [tgt_text.strip().split("?")[0].replace("\n", " ") + "?", doc[1].replace("\n", " "), doc[0]]
# generate_qa_pairs.append(qa_pair)
#
# return generate_qa_pairs
def triple_to_string(x):
    """Strip every string in ``x`` and join the results with single spaces."""
    cleaned = [part.strip() for part in x]
    return " ".join(cleaned)
@spaces.GPU
def rerank_questions(claim, bm25_qas, topk=3):
    """Keep the ``topk`` question/answer pairs ranked highest by the reranker.

    Args:
        claim: the claim text.
        bm25_qas: iterable of ``(question, answer, source)`` triples.
        topk: number of pairs to return.

    Returns:
        A list of ``{"question", "answers", "source_url"}`` dicts ordered by
        descending softmax score of class 1; empty when ``bm25_qas`` is empty.
    """
    if len(bm25_qas) == 0:
        return []

    values = [[question, answer, source] for question, answer, source in bm25_qas]
    strs_to_score = [triple_to_string([claim, question, answer]) for question, answer, _ in values]

    encoded_dict = rerank_tokenizer(
        strs_to_score, max_length=512, padding="longest", truncation=True, return_tensors="pt"
    ).to(rerank_trained_model.device)

    # Pure inference: skip autograd bookkeeping.
    with torch.no_grad():
        logits = rerank_trained_model(
            encoded_dict['input_ids'], attention_mask=encoded_dict['attention_mask']
        ).logits
    scores = torch.softmax(logits, dim=-1)[:, 1]
    # Plain ints make the list indexing below explicit.
    top_n = torch.argsort(scores, descending=True)[:topk].tolist()

    return [
        {"question": values[i][0], "answers": values[i][1], "source_url": values[i][2]}
        for i in top_n
    ]
@spaces.GPU
def Googleretriever(query, sources):
    """Run the web-evidence pipeline for ``query`` and wrap the results as ``Docs``.

    Steps: generate questions for the claim, search the web for claim and
    questions, generate a question for each selected evidence line, then
    rerank the question/answer pairs.  ``sources`` is accepted but not used.

    Returns:
        A list of ``Docs``, one per reranked question/answer pair.
    """
    # ----- Generate QA pairs using AVeriTeC
    # step 1: generate questions for the query/claim using Bloom
    generate_question = prompt_question_generation(query)
    # step 2: retrieve evidence for the generated questions using Google API
    retrieve_evidence = averitec_search(query, generate_question)
    # step 3: generate QA pairs for each retrieved document
    bm25_qa_pairs = decorate_with_questions(query, retrieve_evidence)
    # step 4: rerank QA pairs
    top3_qa_pairs = rerank_questions(query, bm25_qa_pairs)

    results = []
    for i, qa in enumerate(top3_qa_pairs):
        metadata = dict()
        metadata['name'] = qa['question']
        metadata['url'] = qa['source_url']
        metadata['cached_source_url'] = qa['source_url']
        metadata['short_name'] = "Evidence {}".format(i + 1)
        metadata['page_number'] = ""
        metadata['title'] = qa['question']
        metadata['evidence'] = qa['answers']
        metadata['query'] = qa['question']
        metadata['answer'] = qa['answers']
        # The separator literal used to be split across two physical lines,
        # which is not valid Python; it is written as an explicit newline.
        # NOTE(review): confirm a newline (not markup) is the intended separator.
        metadata['page_content'] = "Question: " + qa['question'] + "\n" + "Answer: " + qa['answers']
        results.append(Docs(metadata, metadata['page_content']))
    return results
# ----------GoogleAPIretriever---------
# ----------Wikipediaretriever---------
def bm25_retriever(query, corpus, topk=3):
    """Rank ``corpus`` against ``query`` with BM25.

    Args:
        query: query string; tokenized with ``word_tokenize``.
        corpus: list of token lists.
        topk: number of hits to return.

    Returns:
        ``(top_n, top_n_scores)``: indices into ``corpus`` in descending score
        order and their scores.  Both are empty for an empty corpus, which
        previously crashed while building the BM25 index.
    """
    if len(corpus) == 0:
        return np.array([], dtype=int), []

    bm25 = BM25Okapi(corpus)
    scores = bm25.get_scores(word_tokenize(query))
    top_n = np.argsort(scores)[::-1][:topk]
    top_n_scores = [scores[i] for i in top_n]
    return top_n, top_n_scores
def bm25s_retriever(query, corpus, topk=3):
    """Rank ``corpus`` (a list of strings) against ``query`` using ``bm25s``.

    NOTE(review): the ``bm25s`` and ``Stemmer`` imports are commented out at
    the top of this file, so calling this function raises ``NameError`` until
    they are restored.

    Returns:
        ``(top_n, scores)``: positions in ``corpus`` of the first query's hits
        and the score array returned by the retriever.
    """
    english_stemmer = Stemmer.Stemmer("english")

    # Index the tokenized corpus (stop words removed, stemmed).
    index = bm25s.BM25()
    index.index(bm25s.tokenize(corpus, stopwords="en", stemmer=english_stemmer))

    # Retrieve the top-k documents; both outputs have shape (n_queries, k).
    tokenized_query = bm25s.tokenize(query, stemmer=english_stemmer)
    hits, scores = index.retrieve(tokenized_query, corpus=corpus, k=topk)

    # Map the returned documents back to their positions in the corpus.
    positions = [corpus.index(hit) for hit in hits[0]]
    return positions, scores
def find_evidence_from_wikipedia_dumps(claim):
    """Collect Wikipedia introductions for the entities mentioned in ``claim``.

    Entities found by spaCy are looked up by lower-cased title; when fewer
    than five introductions are found, ``process_topk`` adds more.

    NOTE(review): ``ranker`` and ``doc_db`` are only defined in commented-out
    code in this file, so calling this function raises ``NameError`` until
    they are restored.

    Returns:
        ``(wiki_intro, doc)``: a list of ``[title, introduction]`` pairs and
        the spaCy document of the claim.
    """
    doc = nlp(claim)
    title2id = ranker.doc_dict[0]

    wiki_intro = []
    ent_list = []
    for entity in (str(span).lower() for span in doc.ents):
        if entity not in title2id:
            continue
        wiki_intro.append([entity, doc_db.get_doc_intro(title2id[entity])])
        ent_list.append(entity)

    if len(wiki_intro) < 5:
        wiki_intro.extend(process_topk(claim, title2id, ent_list, k=5))

    return wiki_intro, doc
def relevant_sentence_retrieval(query, wiki_intro, k):
    """Return the ``k`` intro sentences that best match ``query`` under BM25.

    Args:
        query: the query string.
        wiki_intro: iterable of ``(title, introduction)`` pairs.
        k: number of sentences to return.

    Returns:
        ``(sentences, titles)``: the selected sentences and, aligned with
        them, the titles of the pages they come from.
    """
    # 1. Flatten all introductions into (title, sentence) pairs.
    pairs = [
        (title, sentence)
        for title, intro in wiki_intro
        for sentence in sent_tokenize(intro)
    ]
    tokenized = [word_tokenize(sentence) for _, sentence in pairs]

    # 2. Rank the sentences with BM25.
    best_indices, _best_scores = bm25_retriever(query, tokenized, topk=k)
    best_sentences = [pairs[idx][1] for idx in best_indices]
    best_titles = [pairs[idx][0] for idx in best_indices]
    return best_sentences, best_titles
def process_topk(query, title2id, ent_list, k=1):
    """Add introductions of the documents ``ranker`` finds closest to ``query``.

    Documents already in ``ent_list`` are skipped, and nothing is added once
    ``ent_list`` holds five names.  ``ent_list`` is extended in place.

    Returns:
        A list of ``[name, introduction]`` pairs for the newly added documents.
    """
    doc_names, _doc_scores = ranker.closest_docs(query, k)

    added = []
    for name in doc_names:
        if name in ent_list or len(ent_list) >= 5:
            continue
        ent_list.append(name)
        added.append([name, doc_db.get_doc_intro(title2id[name])])
    return added
def WikipediaDumpsretriever(claim):
    """Retrieve Wikipedia evidence sentences for ``claim`` wrapped as ``Docs``.

    Relevant introductions are collected first; the three sentences that best
    match the claim are then selected.

    Returns:
        A list of ``Docs``, one per selected sentence.  The ``query`` and
        ``evidence`` metadata both hold the sentence; ``answer`` is empty.
    """
    # 1. extract relevant wikipedia pages from wikipedia dumps
    wiki_intro, doc = find_evidence_from_wikipedia_dumps(claim)
    # 2. extract relevant sentences from extracted wikipedia pages
    sents, titles = relevant_sentence_retrieval(claim, wiki_intro, k=3)

    results = []
    for i, (sent, title) in enumerate(zip(sents, titles)):
        page_url = "https://en.wikipedia.org/wiki/" + "_".join(title.split())
        metadata = dict()
        metadata['name'] = claim
        metadata['url'] = page_url
        metadata['cached_source_url'] = page_url
        metadata['short_name'] = "Evidence {}".format(i + 1)
        metadata['page_number'] = ""
        metadata['query'] = sent
        metadata['title'] = title
        metadata['evidence'] = sent
        metadata['answer'] = ""
        # The separator literal used to be split across two physical lines,
        # which is not valid Python; it is written as an explicit newline.
        # NOTE(review): confirm a newline (not markup) is the intended separator.
        metadata['page_content'] = "Title: " + str(metadata['title']) + "\n" + "Evidence: " + metadata['evidence']
        results.append(Docs(metadata, metadata['page_content']))
    return results
# ----------WikipediaAPIretriever---------
def clean_str(p):
    """Interpret backslash escapes in ``p`` while keeping UTF-8 text intact.

    The text is encoded as UTF-8, decoded with ``unicode-escape`` (which reads
    the bytes as Latin-1 and resolves escapes), then mapped back through
    Latin-1 and decoded as UTF-8 again.
    """
    utf8_bytes = p.encode()
    unescaped = utf8_bytes.decode("unicode-escape")
    raw_bytes = unescaped.encode("latin1")
    return raw_bytes.decode("utf-8")
def get_page_obs(page):
    """Return the first five non-blank lines of ``page`` joined by spaces."""
    stripped_lines = (line.strip() for line in page.split("\n"))
    paragraphs = [line for line in stripped_lines if line]
    leading = paragraphs[:5]
    return ' '.join(leading)
def search_entity_wikipeida(entity):
    """Look up the Wikipedia page titled *entity* through ``wiki_wiki``.

    Args:
        entity: The page title to look up. Any object is accepted; its
            ``str()`` form is used as the title.

    Returns:
        ``[[title, summary]]`` when the page exists, otherwise an empty list.
    """
    # The returned title was already stringified; use the same text for the
    # lookup itself so a non-str entity object is never handed to the
    # Wikipedia client.
    title = str(entity)
    page_py = wiki_wiki.page(title)
    if not page_py.exists():
        return []
    return [[title, page_py.summary]]
def search_step(entity):
    """Search Wikipedia for *entity* by scraping the site's search page.

    Three outcomes are handled:

    * The response is a search-result listing (no page matches exactly): the
      first five result titles are searched recursively until five pieces of
      evidence have been collected.
    * The response contains "may refer to:" (a disambiguation page): the
      search is repeated with the entity wrapped in square brackets.
    * Otherwise the response is treated as an article: paragraphs and lists
      with more than two space-separated words are concatenated and
      summarised with ``get_page_obs``.

    The recursion has no explicit depth limit.

    Args:
        entity: The search term (a string).

    Returns:
        A list of ``[title, introduction]`` pairs; possibly empty.
    """
    # Let requests build the query string so characters such as "&", "#" or
    # "+" inside the entity are percent-encoded instead of corrupting the
    # URL (spaces are still sent as "+"). The timeout keeps a stalled
    # connection from blocking the request handler indefinitely.
    response_text = requests.get(
        "https://en.wikipedia.org/w/index.php",
        params={"search": entity},
        timeout=10,
    ).text
    soup = BeautifulSoup(response_text, features="html.parser")
    result_divs = soup.find_all("div", {"class": "mw-search-result-heading"})
    find_evidence = []

    if result_divs:  # mismatch
        # The Wikipedia page of the entity does not exist: fall back to the
        # pages suggested by the search results.
        result_titles = [clean_str(div.get_text().strip()) for div in result_divs]
        for _t in result_titles[:5]:
            if len(find_evidence) >= 5:
                break
            find_evidence.extend(search_step(_t))
        return find_evidence

    page = [p.get_text().strip() for p in soup.find_all("p") + soup.find_all("ul")]
    if any("may refer to:" in p for p in page):
        find_evidence.extend(search_step("[" + entity + "]"))
        return find_evidence

    # Every element of ``page`` has been stripped, so it never ends with a
    # newline; one is appended after each kept paragraph.
    page_text = "".join(
        clean_str(p) + "\n" for p in page if len(p.split(" ")) > 2
    )
    find_evidence.append([entity, get_page_obs(page_text)])
    return find_evidence
def find_similar_wikipedia(entity, relevant_wikipages):
    """Top up *relevant_wikipages* with pages found by a full-text search.

    Used when fewer than five pages were found for an entity: Wikipedia's
    advanced search is queried for *entity* and the first five result titles
    are looked up with ``search_entity_wikipeida`` until the list holds five
    entries.

    Args:
        entity: The search term (a string).
        relevant_wikipages: List of ``[title, introduction]`` pairs found so
            far. It is extended in place.

    Returns:
        The same *relevant_wikipages* list, possibly extended.
    """
    # Let requests encode the query so "&", "#" or "+" inside the entity
    # cannot corrupt the URL; the timeout keeps a stalled connection from
    # blocking the request handler indefinitely.
    response_text = requests.get(
        "https://en.wikipedia.org/w/index.php",
        params={
            "search": entity,
            "title": "Special:Search",
            "profile": "advanced",
            "fulltext": "1",
            "ns0": "1",
        },
        timeout=10,
    ).text
    soup = BeautifulSoup(response_text, features="html.parser")
    result_divs = soup.find_all("div", {"class": "mw-search-result-heading"})
    if not result_divs:
        return relevant_wikipages

    similar_titles = [clean_str(div.get_text().strip()) for div in result_divs[:5]]
    # Track titles in a separate set. The previous code aliased the result
    # list itself when it was empty and never recorded newly added titles,
    # so the duplicate check could not match anything added in this call.
    saved_titles = {page[0] for page in relevant_wikipages}
    for _t in similar_titles:
        if len(relevant_wikipages) >= 5:
            break
        if _t in saved_titles:
            continue
        saved_titles.add(_t)
        relevant_wikipages.extend(search_entity_wikipeida(_t))
    return relevant_wikipages
def find_evidence_from_wikipedia(claim):
    """Collect Wikipedia pages for every entity detected in *claim*.

    *claim* is processed with the module-level ``nlp`` callable; for each
    item of ``doc.ents`` the page of that name is looked up and, when fewer
    than five pages were found, the list is topped up with similar pages.

    Args:
        claim: The claim text to fact-check.

    Returns:
        A list of ``[title, introduction]`` pairs gathered over all entities
        (pages may repeat across entities).
    """
    doc = nlp(claim)

    wikipedia_page = []
    for ent in doc.ents:
        # Use the entity's text for both lookups; previously the raw entity
        # object was passed to the first one and its string to the second.
        entity = str(ent)
        relevant_wikipages = search_entity_wikipeida(entity)
        if len(relevant_wikipages) < 5:
            relevant_wikipages = find_similar_wikipedia(entity, relevant_wikipages)
        wikipedia_page.extend(relevant_wikipages)
    return wikipedia_page
def relevant_wikipedia_API_retriever(claim):
    """Fetch a Wikipedia summary for every entity detected in *claim*.

    Args:
        claim: The claim text to fact-check.

    Returns:
        A ``(wiki_intro, doc)`` tuple. ``wiki_intro`` is a list of
        ``[entity_text, introduction]`` pairs, where the introduction is the
        page summary or ``"No documents found."`` when the page does not
        exist; ``doc`` is the object returned by ``nlp(claim)``.
    """
    doc = nlp(claim)
    wiki_intro = []
    for ent in doc.ents:
        # Look the page up by the entity's text, the same string that is
        # stored as the title, rather than by the raw entity object.
        title = str(ent)
        page_py = wiki_wiki.page(title)
        introduction = page_py.summary if page_py.exists() else "No documents found."
        wiki_intro.append([title, introduction])
    return wiki_intro, doc
def Wikipediaretriever(claim, sources):
    """Retrieve Wikipedia evidence for *claim*.

    Candidate pages come from ``find_evidence_from_wikipedia_dumps`` when
    ``"Dump"`` occurs in *sources*, otherwise from
    ``find_evidence_from_wikipedia``. The pages are passed to
    ``relevant_sentence_retrieval`` (called with ``k=3``) and every returned
    sentence is wrapped in a ``Docs`` object.

    Args:
        claim: The claim text to fact-check.
        sources: The selected source; only tested for containing ``"Dump"``.

    Returns:
        A list of ``Docs`` objects, one per (sentence, title) pair returned by
        ``relevant_sentence_retrieval``, in the same order.
    """
    # 1. extract relevant wikipedia pages
    if "Dump" in sources:
        wikipedia_page = find_evidence_from_wikipedia_dumps(claim)
    else:
        wikipedia_page = find_evidence_from_wikipedia(claim)

    # 2. extract relevant sentences from extracted wikipedia pages
    sents, titles = relevant_sentence_retrieval(claim, wikipedia_page, k=3)

    results = []
    for i, (sent, title) in enumerate(zip(sents, titles)):
        # Wikipedia article URLs use underscores in place of whitespace. The
        # cached URL used to join the title's individual characters
        # ("_".join(title)), yielding "D_o_n_a_l_d_..."; both URLs now share
        # the correctly built value.
        page_url = "https://en.wikipedia.org/wiki/" + "_".join(title.split())
        page_content = "Title: " + str(title) + "\n" + "Evidence: " + sent
        metadata = {
            'name': claim,
            'url': page_url,
            'cached_source_url': page_url,
            'short_name': "Evidence {}".format(i + 1),
            'page_number': "",
            'query': sent,
            'title': title,
            'evidence': sent,
            'answer': "",
            'page_content': page_content,
        }
        results.append(Docs(metadata, page_content))
    return results
def log_on_azure(file, logs, azure_share_client):
    """Serialise *logs* to JSON and upload it as *file* on the Azure share.

    Args:
        file: Name of the file to create through the share client.
        logs: JSON-serialisable object to store.
        azure_share_client: Client exposing ``get_file_client``.
    """
    payload = json.dumps(logs)
    azure_share_client.get_file_client(file).upload_file(payload)
@spaces.GPU
def chat(claim, history, sources):
    """Fact-check *claim*: retrieve evidence, predict an answer, log the run.

    Args:
        claim: The claim text entered by the user.
        history: Chat history as a list of (user, assistant) pairs; its last
            entry is replaced by ``(claim, answer_set)``.
        sources: The selected evidence source; tested for containing
            ``'Google'`` and ``'WikiPedia'``.

    Returns:
        A ``(history, docs_html, output_query, output_language)`` tuple:
        the updated history as a list of tuples, the concatenated HTML of the
        evidence (empty string when there is none), the claim itself, and the
        constant ``"English"``.

    Raises:
        gr.Error: If logging is enabled and writing the log fails.
    """
    evidence = []
    if 'Google' in sources:
        evidence = Googleretriever(claim, sources)
    if 'WikiPedia' in sources:
        evidence = Wikipediaretriever(claim, sources)

    answer_set, answer_output = QAprediction(claim, evidence, sources)

    if evidence:
        docs_html = "".join(
            make_html_source(evi, i) for i, evi in enumerate(evidence, 1)
        )
    else:
        docs_html = ""
        print("No documents found")

    output_language = "English"
    output_query = claim

    history[-1] = (claim, answer_set)
    history = [tuple(x) for x in history]

    evi_list = [
        [evi.metadata['title'], evi.metadata['evidence'], evi.metadata['url']]
        for evi in evidence
    ]

    try:
        # Log the answer on the Azure share only when AZURE_ISSAVE=TRUE.
        # ``.get`` is used because the flag is opt-in: indexing os.environ
        # raised KeyError when it was unset, which the handler below turned
        # into a user-facing error on every request.
        if os.environ.get("AZURE_ISSAVE") == "TRUE":
            timestamp = datetime.now().strftime("%Y-%m-%d-%H-%M-%S")
            file = timestamp + ".json"
            logs = {
                "user_id": str(user_id),
                "claim": claim,
                "sources": sources,
                "evidence": evi_list,
                "answer": answer_output,
                "time": timestamp,
            }
            log_on_azure(file, logs, azure_share_client)
    except Exception as e:
        print(f"Error logging on Azure Blob Storage: {e}")
        raise gr.Error(
            f"AVeriTeC Error: {str(e)[:100]} - The error has been noted, try another question and if the error remains, you can contact us :)") from e

    return history, docs_html, output_query, output_language
def main():
    """Build the AVeriTeC Gradio interface, wire its event handlers and launch it."""
    # Greeting shown as the first assistant message of the chatbot.
    init_prompt = """
Hello, I am a fact-checking assistant designed to help you find appropriate evidence to predict the veracity of claims.
What do you want to fact-check?
"""
    with gr.Blocks(title="AVeriTeC fact-checker", css="style.css", theme=theme, elem_id="main-component") as demo:
        with gr.Tab("AVeriTeC"):
            with gr.Row(elem_id="chatbot-row"):
                # Left column: conversation and the claim input box.
                with gr.Column(scale=2):
                    chatbot = gr.Chatbot(
                        value=[(None, init_prompt)],
                        show_copy_button=True, show_label=False, elem_id="chatbot", layout="panel",
                        avatar_images = (None, "assets/averitec.png")
                    ) # avatar_images=(None, "https://i.ibb.co/YNyd5W2/logo4.png"),
                    with gr.Row(elem_id="input-message"):
                        textbox = gr.Textbox(placeholder="Ask me what claim do you want to check!", show_label=False,
                                             scale=7, lines=1, interactive=True, elem_id="input-textbox")
                        # submit = gr.Button("",elem_id = "submit-button",scale = 1,interactive = True,icon = "https://static-00.iconduck.com/assets.00/settings-icon-2048x2046-cw28eevx.png")
                # Right column: example claims, retrieved sources and configuration.
                with gr.Column(scale=1, variant="panel", elem_id="right-panel"):
                    with gr.Tabs() as tabs:
                        with gr.TabItem("Examples", elem_id="tab-examples", id=0):
                            # Hidden textbox that receives the clicked example;
                            # its change event starts a chat (see below).
                            examples_hidden = gr.Textbox(visible=False)
                            first_key = list(CLAIMS_Type.keys())[0]
                            dropdown_samples = gr.Dropdown(CLAIMS_Type.keys(), value=first_key, interactive=True,
                                                           show_label=True,
                                                           label="Select claim type",
                                                           elem_id="dropdown-samples")
                            # One row of examples per claim type; only the
                            # first one is visible initially.
                            samples = []
                            for i, key in enumerate(CLAIMS_Type.keys()):
                                examples_visible = True if i == 0 else False
                                with gr.Row(visible=examples_visible) as group_examples:
                                    examples_questions = gr.Examples(
                                        CLAIMS_Type[key],
                                        [examples_hidden],
                                        examples_per_page=8,
                                        run_on_click=False,
                                        elem_id=f"examples{i}",
                                        api_name=f"examples{i}",
                                        # label = "Click on the example question or enter your own",
                                        # cache_examples=True,
                                    )
                                samples.append(group_examples)
                        with gr.Tab("Sources", elem_id="tab-citations", id=1):
                            sources_textbox = gr.HTML(show_label=False, elem_id="sources-textbox")
                            # Not referenced by any handler in this function.
                            docs_textbox = gr.State("")
                        with gr.Tab("Configuration", elem_id="tab-config", id=2):
                            gr.Markdown("Reminder: We currently only support fact-checking in English!")
                            # dropdown_sources = gr.Radio(
                            # ["AVeriTeC", "WikiPediaDumps", "Google", "WikiPediaAPI"],
                            # label="Select source",
                            # value="WikiPediaAPI",
                            # interactive=True,
                            # )
                            dropdown_sources = gr.Radio(
                                ["Google", "WikiPedia"],
                                label="Select source",
                                value="WikiPedia",
                                interactive=True,
                            )
                            # Not passed to any handler in this function.
                            dropdown_retriever = gr.Dropdown(
                                ["BM25", "BM25s"],
                                label="Select evidence retriever",
                                multiselect=False,
                                value="BM25",
                                interactive=True,
                            )
                            output_query = gr.Textbox(label="Query used for retrieval", show_label=True,
                                                      elem_id="reformulated-query", lines=2, interactive=False)
                            output_language = gr.Textbox(label="Language", show_label=True, elem_id="language", lines=1,
                                                         interactive=False)
        with gr.Tab("About", elem_classes="max-height other-tabs"):
            with gr.Row():
                with gr.Column(scale=1):
                    gr.Markdown("See more info at [https://fever.ai/task.html](https://fever.ai/task.html)")

        # Append the new query to the history, disable the textbox while the
        # answer is computed, and select the tab with id=1 ("Sources").
        def start_chat(query, history):
            history = history + [(query, None)]
            history = [tuple(x) for x in history]
            return (gr.update(interactive=False), gr.update(selected=1), history)

        # Re-enable the textbox and clear its content.
        def finish_chat():
            return (gr.update(interactive=True, value=""))

        # Typed claim: start_chat -> chat -> finish_chat.
        (textbox
         .submit(start_chat, [textbox, chatbot], [textbox, tabs, chatbot], queue=False, api_name="start_chat_textbox")
         .then(chat, [textbox, chatbot, dropdown_sources],
               [chatbot, sources_textbox, output_query, output_language], concurrency_limit=8, api_name="chat_textbox")
         .then(finish_chat, None, [textbox], api_name="finish_chat_textbox")
         )
        # Clicked example: same pipeline, fed from the hidden textbox.
        (examples_hidden
         .change(start_chat, [examples_hidden, chatbot], [textbox, tabs, chatbot], queue=False,
                 api_name="start_chat_examples")
         .then(chat, [examples_hidden, chatbot, dropdown_sources],
               [chatbot, sources_textbox, output_query, output_language], concurrency_limit=8, api_name="chat_examples")
         .then(finish_chat, None, [textbox], api_name="finish_chat_examples")
         )

        # Show only the example row that belongs to the selected claim type.
        def change_sample_questions(key):
            index = list(CLAIMS_Type.keys()).index(key)
            visible_bools = [False] * len(samples)
            visible_bools[index] = True
            return [gr.update(visible=visible_bools[i]) for i in range(len(samples))]

        dropdown_samples.change(change_sample_questions, dropdown_samples, samples)
        demo.queue()
    demo.launch()
    # demo.launch(share=True)
# Build and launch the demo only when this file is executed as a script.
if __name__ == "__main__":
    main()