Spaces:
Build error
Build error
#!/usr/bin/env python | |
# -*- coding: utf-8 -*- | |
# Created by zd302 at 17/07/2024 | |
from fastapi import FastAPI | |
from pydantic import BaseModel | |
# from averitec.models.AveritecModule import Wikipediaretriever, Googleretriever, veracity_prediction, justification_generation | |
import uvicorn | |
import spaces | |
# ASGI application object; this is the instance passed to uvicorn.run() in the
# __main__ guard at the bottom of the file.
app = FastAPI()
# --------------------------------------------------------------------------------------------------------------------- | |
import gradio as gr | |
import tqdm | |
import torch | |
import numpy as np | |
from time import sleep | |
from datetime import datetime | |
import threading | |
import gc | |
import os | |
import json | |
import pytorch_lightning as pl | |
from urllib.parse import urlparse | |
from accelerate import Accelerator | |
import spaces | |
from transformers import BartTokenizer, BartForConditionalGeneration | |
from transformers import BloomTokenizerFast, BloomForCausalLM, BertTokenizer, BertForSequenceClassification | |
from transformers import RobertaTokenizer, RobertaForSequenceClassification | |
from rank_bm25 import BM25Okapi | |
# import bm25s | |
# import Stemmer # optional: for stemming | |
from html2lines import url2lines | |
from googleapiclient.discovery import build | |
from averitec.models.DualEncoderModule import DualEncoderModule | |
from averitec.models.SequenceClassificationModule import SequenceClassificationModule | |
from averitec.models.JustificationGenerationModule import JustificationGenerationModule | |
from averitec.data.sample_claims import CLAIMS_Type | |
# --------------------------------------------------------------------------- | |
# load .env | |
from utils import create_user_id
# Identifier created once at import time; fact_checking attaches it to every log record.
user_id = create_user_id()
from azure.storage.fileshare import ShareServiceClient
# python-dotenv is optional. When it is not installed the process environment is
# used as-is; when it is installed but loading fails, the failure is reported
# instead of being silently swallowed (start-up still continues).
try:
    from dotenv import load_dotenv
except ImportError:
    pass
else:
    try:
        load_dotenv()
    except Exception as e:
        print(f"Could not load .env file: {e}")
# --------------------------------------------------------------------------- | |
# os.environ["TOKENIZERS_PARALLELISM"] = "false" | |
# Azure File Share client used by fact_checking / log_on_azure to store request logs.
# All three environment variables are read with os.environ[...], so a missing
# one raises KeyError while the module is being imported.
account_url = os.environ["AZURE_ACCOUNT_URL"]
credential = {
    "account_key": os.environ['AZURE_ACCOUNT_KEY'],
    "account_name": os.environ['AZURE_ACCOUNT_NAME']
}
# Name of the share that log files are uploaded to.
file_share_name = "averitec"
azure_service = ShareServiceClient(account_url=account_url, credential=credential)
azure_share_client = azure_service.get_share_client(file_share_name)
# --------------------------------------------------------------------------------------------------------------------- | |
import requests | |
from bs4 import BeautifulSoup | |
import wikipediaapi | |
# Wikipedia client used by search_entity_wikipeida.
# NOTE(review): the first argument looks like a user-agent string whose e-mail
# address was obfuscated somewhere upstream ("[email protected]") — confirm the real value.
wiki_wiki = wikipediaapi.Wikipedia('AVeriTeC ([email protected])', 'en')
import nltk
# NLTK resources are fetched at import time; the functions below call
# word_tokenize, sent_tokenize and pos_tag.
nltk.download('averaged_perceptron_tagger_eng')
nltk.download('averaged_perceptron_tagger')
nltk.download('punkt')
nltk.download('punkt_tab')
from nltk import pos_tag, word_tokenize, sent_tokenize
import spacy
# Shells out at import time to fetch the spaCy pipeline, then loads it.
# `nlp` is used by find_evidence_from_wikipedia to extract entities from a claim.
# NOTE(review): the return code of os.system is ignored, so a failed download
# only surfaces as an error from spacy.load — confirm this is acceptable.
os.system("python -m spacy download en_core_web_sm")
nlp = spacy.load("en_core_web_sm")
# --------------------------------------------------------------------------- | |
# Training examples used to build the BM25 few-shot prompt corpora below.
# Opened with a context manager so the file handle is closed deterministically,
# and with an explicit encoding so parsing does not depend on the locale.
with open('averitec/data/train.json', 'r', encoding='utf-8') as f:
    train_examples = json.load(f)
def claim2prompts(example):
    """Yield one few-shot prompt per usable answer of a training example.

    Args:
        example: dict with a "questions" list; each question has a "question"
            string and an "answers" list whose items carry "answer_type",
            "answer" and, for Boolean answers, "boolean_explanation".

    Yields:
        (lookup_text, prompt) tuples. ``lookup_text`` is the punctuated answer
        (the text indexed for BM25 lookup); ``prompt`` has the form
        "Evidence:  <answer>\\nQuestion answered: <question>".

    Questions with blank text are skipped. Empty or missing answers are skipped
    as well (they used to raise IndexError), and a Boolean answer without an
    explanation is emitted without the ", because ..." clause (a None
    explanation used to raise AttributeError).
    """
    # The claim text is deliberately not part of the prompt:
    # claim_str = "Claim: " + claim + "||Evidence: "
    claim_str = "Evidence: "

    for question in example["questions"]:
        q_text = question["question"].strip()
        if not q_text:
            continue
        if not q_text.endswith("?"):
            q_text += "?"

        answer_strings = []
        for a in question["answers"]:
            if a["answer_type"] in ("Extractive", "Abstractive"):
                answer_strings.append(a["answer"])
            if a["answer_type"] == "Boolean":
                explanation = (a.get("boolean_explanation") or "").lower().strip()
                if explanation:
                    answer_strings.append(a["answer"] + ", because " + explanation)
                else:
                    answer_strings.append(a["answer"])

        for a_text in answer_strings:
            if not a_text:
                # Nothing to build a prompt from.
                continue
            if a_text[-1] not in (".", "!", ":", "?"):
                a_text += "."

            # prompt_lookup_str = claim + " " + a_text
            prompt_lookup_str = a_text
            # "||" is a temporary separator so that newlines inside the answer can be
            # flattened without touching the one real line break of the prompt.
            this_q_claim_str = claim_str + " " + a_text.strip() + "||Question answered: " + q_text
            yield (prompt_lookup_str, this_q_claim_str.replace("\n", " ").replace("||", "\n"))
def generate_reference_corpus(reference_file):
    """Build the claim-level reference corpus used for question-generation prompts.

    Args:
        reference_file: accepted for interface compatibility but not read here;
            the records come from the module-level ``train_examples``.

    Returns:
        (tokenized_claims, records): the NLTK-tokenised claim of every training
        example and, in the same order, a dict with its "claim", "speaker" and
        "questions" (question strings only).
    """
    corpus_tokens, corpus_records = [], []
    for record in train_examples:
        raw_speaker = record["speaker"]
        # Missing or single-character speakers fall back to a neutral pronoun.
        if raw_speaker is not None and len(raw_speaker) > 1:
            speaker = raw_speaker.strip()
        else:
            speaker = "they"

        entry = {
            "claim": record["claim"],
            "speaker": speaker,
            "questions": [q["question"] for q in record["questions"]],
        }
        corpus_tokens.append(nltk.word_tokenize(entry["claim"]))
        corpus_records.append(entry)
    return corpus_tokens, corpus_records
def generate_step2_reference_corpus(reference_file):
    """Build the answer-level prompt corpus used when turning evidence into questions.

    Args:
        reference_file: accepted for interface compatibility but not read here;
            the prompts are derived from the module-level ``train_examples``.

    Returns:
        (tokenized_lookup_texts, prompts): for every prompt produced by
        ``claim2prompts``, the NLTK tokens of its lookup text and the prompt
        itself, index-aligned.
    """
    token_lists, prompts = [], []
    for record in train_examples:
        for lookup_text, prompt_text in claim2prompts(record):
            token_lists.append(nltk.word_tokenize(lookup_text))
            prompts.append(prompt_text)
    return token_lists, prompts
# BM25 indexes built once at import time.
# NOTE: both generate_* helpers ignore `reference_file` and read `train_examples`.
reference_file = "averitec/data/train.json"
# Claim-level index: used by prompt_question_generation to pick few-shot examples.
tokenized_corpus0, all_data_corpus0 = generate_reference_corpus(reference_file)
qg_bm25 = BM25Okapi(tokenized_corpus0)
# Answer-level index: used by decorate_with_questions to pick few-shot examples.
tokenized_corpus1, prompt_corpus1 = generate_step2_reference_corpus(reference_file)
prompt_bm25 = BM25Okapi(tokenized_corpus1)
# --------------------------------------------------------------------------------------------------------------------- | |
# ---------- Setting ---------- | |
# ---------- Load Veracity and Justification prediction model ---------- | |
print("Loading models ...")
# Verdict names; veracity_prediction returns LABEL[answer] for its aggregated class index.
LABEL = [
    "Supported",
    "Refuted",
    "Not Enough Evidence",
    "Conflicting Evidence/Cherrypicking",
]
# NOTE(review): the original indentation was lost; this assumes every model load
# below is guarded by the CUDA check. If so, on a machine without CUDA none of
# these names are defined and the retrieval/prediction functions raise NameError
# on first use — confirm the intended nesting.
if torch.cuda.is_available():
    # question generation
    qg_tokenizer = BloomTokenizerFast.from_pretrained("bigscience/bloom-1b1")
    qg_model = BloomForCausalLM.from_pretrained("bigscience/bloom-1b1", torch_dtype=torch.bfloat16).to('cuda')
    # rerank
    rerank_tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
    # NOTE(review): "rereank" is a typo in the name; it is only referenced in this block.
    rereank_bert_model = BertForSequenceClassification.from_pretrained("bert-base-uncased", num_labels=2, problem_type="single_label_classification")  # Must specify single_label for some reason
    best_checkpoint = "averitec/pretrained_models/bert_dual_encoder.ckpt"
    rerank_trained_model = DualEncoderModule.load_from_checkpoint(best_checkpoint, tokenizer=rerank_tokenizer, model=rereank_bert_model)
    # Veracity
    veracity_tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
    bert_model = BertForSequenceClassification.from_pretrained("bert-base-uncased", num_labels=4, problem_type="single_label_classification")
    veracity_checkpoint_path = os.getcwd() + "/averitec/pretrained_models/bert_veracity.ckpt"
    veracity_model = SequenceClassificationModule.load_from_checkpoint(veracity_checkpoint_path,tokenizer=veracity_tokenizer, model=bert_model)
    # Justification
    justification_tokenizer = BartTokenizer.from_pretrained('facebook/bart-large', add_prefix_space=True)
    bart_model = BartForConditionalGeneration.from_pretrained("facebook/bart-large")
    # `best_checkpoint` is rebound here; the rerank model above was already loaded from the earlier value.
    best_checkpoint = os.getcwd() + '/averitec/pretrained_models/bart_justifications_verdict-epoch=13-val_loss=2.03-val_meteor=0.28.ckpt'
    justification_model = JustificationGenerationModule.load_from_checkpoint(best_checkpoint, tokenizer=justification_tokenizer, model=bart_model)
# --------------------------------------------------------------------------- | |
# ---------------------------------------------------------------------------- | |
class Docs:
    """Minimal evidence container: a metadata dict plus its rendered page content."""

    def __init__(self, metadata=None, page_content=""):
        """Store the given metadata and content.

        Args:
            metadata: dict describing the evidence. When omitted, each instance
                gets its own fresh dict (a literal ``dict()`` default would be
                one object shared by every instance created without metadata).
            page_content: text/HTML rendering of the evidence.
        """
        self.metadata = {} if metadata is None else metadata
        self.page_content = page_content
# ------------------------------ Googleretriever ----------------------------- | |
def doc2prompt(doc):
    """Render one reference record as a few-shot example for question generation.

    Args:
        doc: dict with "speaker" (str), "claim" (str) and "questions" (list of str).

    Returns:
        The sentence introducing the claim followed by the record's questions,
        each stripped and joined with single spaces.
    """
    claim_text = doc["claim"].strip()
    intro = "".join([
        "Outrageously, ",
        doc["speaker"],
        " claimed that \"",
        claim_text,
        "\". Criticism includes questions like: ",
    ])
    stripped_questions = (question.strip() for question in doc["questions"])
    return intro + " ".join(stripped_questions)
def docs2prompt(top_docs):
    """Concatenate the few-shot examples for ``top_docs``, separated by blank lines."""
    rendered_examples = map(doc2prompt, top_docs)
    return "\n\n".join(rendered_examples)
def prompt_question_generation(test_claim, speaker="they", topk=10):
    """Generate verification questions for a claim with the Bloom question model.

    The ``topk`` most similar training claims (BM25 over ``qg_bm25``) are
    rendered as few-shot examples, the claim is appended in the same format and
    the model continues the text.

    Args:
        test_claim: the claim to question.
        speaker: who made the claim; inserted into the prompt.
        topk: number of few-shot examples to retrieve.

    Returns:
        A list of ``{"question": str, "answers": []}`` dicts, one per generated
        question shorter than 300 characters.
    """
    scores = qg_bm25.get_scores(nltk.word_tokenize(test_claim))
    top_n = np.argsort(scores)[::-1][:topk]
    docs = [all_data_corpus0[i] for i in top_n]

    prompt = docs2prompt(docs) + "\n\n" + "Outrageously, " + speaker + " claimed that \"" + test_claim.strip() + \
        "\". Criticism includes questions like: "

    inputs = qg_tokenizer([prompt], padding=True, return_tensors="pt").to(qg_model.device)
    outputs = qg_model.generate(inputs["input_ids"], max_length=2000, num_beams=2, no_repeat_ngram_size=2, early_stopping=True)

    # Drop the prompt by token count, as decorate_with_questions does. Slicing the
    # decoded text by len(prompt) is only right if decoding reproduces the prompt
    # character-for-character, which tokenizer clean-up does not guarantee.
    prompt_len = inputs["input_ids"].shape[-1]
    continuation = qg_tokenizer.batch_decode(outputs[:, prompt_len:], skip_special_tokens=True)[0]

    # Only the first generated line belongs to this claim; later lines would be
    # the model inventing further examples.
    questions_str = continuation.split("\n")[0]
    candidates = (q.strip() for q in questions_str.split("?"))
    qs = [q + "?" for q in candidates if q and len(q) < 300]

    return [{"question": q, "answers": []} for q in qs]
def check_claim_date(check_date):
    """Normalise a "year-month-day" string to "YYYYMMDD".

    Args:
        check_date: date string with "-" separated year, month and day.
            Two-digit years up to 30 are read as 20xx, other two-digit years as
            19xx, and one-digit years as 200x.

    Returns:
        The concatenated year, zero-padded month and zero-padded day. Input
        that is not a string or does not have exactly three parts falls back to
        "20220101".

    Raises:
        ValueError: if a two-character year is not numeric.
    """
    try:
        year, month, date = check_date.split("-")
    except (AttributeError, TypeError, ValueError):
        # Not a str (AttributeError), wrong separator type (TypeError) or not
        # exactly three parts (ValueError). A bare `except:` would also have
        # swallowed KeyboardInterrupt/SystemExit.
        month, date, year = "01", "01", "2022"

    if len(year) == 2:
        century = "20" if int(year) <= 30 else "19"
        year = century + year
    elif len(year) == 1:
        year = "200" + year

    if len(month) == 1:
        month = "0" + month
    if len(date) == 1:
        date = "0" + date

    return year + month + date
def string_to_search_query(text, author):
    """Reduce ``text`` to a keyword query.

    Keeps the tokens whose POS tag starts with CD, JJ, NN or VB.

    Args:
        text: sentence to condense.
        author: optional author name; when not None its whitespace-separated
            words are placed before the kept tokens.

    Returns:
        The kept words joined with single spaces.
    """
    tokens = word_tokenize(text.strip())
    tagged = pos_tag(tokens)
    kept_prefixes = ("CD", "JJ", "NN", "VB")

    words = [] if author is None else author.split()
    # The four prefixes are mutually exclusive, so one tuple-startswith test is
    # equivalent to checking them one at a time.
    words.extend(
        token
        for token, (_, tag) in zip(tokens, tagged)
        if tag.startswith(kept_prefixes)
    )
    return " ".join(words)
def get_google_search_results(api_key, search_engine_id, google_search, sort_date, search_string, page=0, max_attempts=1):
    """Run one date-restricted search and return its result items.

    Args:
        api_key: API key forwarded to ``google_search``.
        search_engine_id: search engine id forwarded to ``google_search``.
        google_search: callable ``(term, api_key, cse_id, **kwargs) -> iterable``
            that performs the request.
        sort_date: upper date bound as "YYYYMMDD"; the lower bound is fixed at 19000101.
        search_string: the query text.
        page: zero-based result page; the start offset is ``10 * page``.
        max_attempts: how many times to try before giving up (default 1, i.e.
            a single attempt as before).

    Returns:
        A list of result items, or an empty list if every attempt failed.
    """
    for attempt in range(max_attempts):
        try:
            return list(google_search(
                search_string,
                api_key,
                search_engine_id,
                num=3,  # num=10,
                start=0 + 10 * page,
                sort="date:r:19000101:" + sort_date,
                dateRestrict=None,
                gl="US"
            ))
        except Exception:
            # Best effort: a failed search yields no evidence rather than an
            # error. `Exception` (not a bare except) lets KeyboardInterrupt and
            # SystemExit through. Only pause when another attempt follows.
            if attempt + 1 < max_attempts:
                sleep(1)
    return []
def google_search(search_term, api_key, cse_id, **kwargs):
    """Query the "customsearch" v1 service and return the result items.

    Args:
        search_term: query text (sent as ``q``).
        api_key: developer key used to build the service client.
        cse_id: search engine id (sent as ``cx``).
        **kwargs: forwarded unchanged to ``cse().list``.

    Returns:
        The "items" entry of the response, or an empty list when the response
        has none.
    """
    client = build("customsearch", "v1", developerKey=api_key)
    request = client.cse().list(q=search_term, cx=cse_id, **kwargs)
    response = request.execute()
    return response['items'] if "items" in response else []
def get_domain_name(url):
    """Return the network location of ``url`` without a leading "www.".

    A URL without a scheme is treated as http so that ``urlparse`` recognises
    the host part.
    """
    full_url = url if '://' in url else 'http://' + url
    host = urlparse(full_url).netloc
    return host[4:] if host.startswith("www.") else host
def get_text_from_link(url_link):
    """Fetch a page via ``url2lines`` and return it as one newline-joined string.

    The URL itself is the first line of the result; decorate_with_questions
    relies on that to recover the source of each text line.
    """
    lines = [url_link]
    lines.extend(url2lines(url_link))
    return "\n".join(lines)
def averitec_search(claim, generate_question, speaker="they", check_date="2024-07-01", n_pages=1):  # n_pages=3
    """Search the web for the claim and its generated questions and fetch the hits.

    Args:
        claim: the claim text.
        generate_question: list of ``{"question": str, ...}`` dicts; only the
            first three questions are searched.
        speaker: currently unused.
        check_date: currently unused — results are restricted to "up to today"
            rather than to this date.
        n_pages: number of result pages to request per query.

    Returns:
        A list of rows ``[index, claim, link, page_num, search_string,
        search_type, web_text]`` (numbers as strings; ``index`` is always "0"),
        one per accepted search hit. ``web_text`` is the page text with the URL
        as its first line.

    Raises:
        KeyError: if GOOGLE_API_KEY or GOOGLE_SEARCH_ENGINE_ID is not set.
    """
    # default config
    api_key = os.environ["GOOGLE_API_KEY"]
    search_engine_id = os.environ["GOOGLE_SEARCH_ENGINE_ID"]

    blacklist = {
        "jstor.org",  # Blacklisted because their pdfs are not labelled as such, and clog up the download
        "facebook.com",  # Blacklisted because only post titles can be scraped, but the scraper doesn't know this,
        "ftp.cs.princeton.edu",  # Blacklisted because it hosts many large NLP corpora that keep showing up
        "nlp.cs.princeton.edu",
        "huggingface.co",
    }
    blacklist_files = (  # Blacklisted some NLP nonsense that crashes my machine with OOM errors
        "/glove.",
        "ftp://ftp.cs.princeton.edu/pub/cs226/autocomplete/words-333333.txt",
        "https://web.mit.edu/adamrose/Public/googlelist",
    )
    skipped_suffixes = (".pdf", ".doc")

    index = 0
    questions = [q["question"] for q in generate_question][:3]
    # questions = [q["question"] for q in generate_question]  # ori

    # The upper bound of the date filter is today's date.
    current_date = datetime.now().strftime("%Y-%m-%d")
    sort_date = check_claim_date(current_date)  # check_date="2022-01-01"

    # Each search is (query text, label): the keyword form of the claim, the raw
    # claim, then the generated questions.
    searches = [
        (string_to_search_query(claim, None), "claim"),
        (claim, "claim-noformat"),
    ]
    searches += [(question, "question") for question in questions]

    visited = {}  # link -> fetched text, so a page is downloaded at most once
    retrieve_evidence = []
    for this_search_string, this_search_type in searches:
        for page_num in range(n_pages):
            search_results = get_google_search_results(api_key, search_engine_id, google_search, sort_date,
                                                       this_search_string, page=page_num)
            for result in search_results:
                link = str(result["link"])
                if get_domain_name(link) in blacklist:
                    continue
                if any(b_file in link for b_file in blacklist_files):
                    continue
                if link.endswith(skipped_suffixes):
                    continue

                if link not in visited:
                    visited[link] = get_text_from_link(link)
                web_text = visited[link]

                retrieve_evidence.append(
                    [str(index), claim, link, str(page_num), this_search_string, this_search_type, web_text]
                )
    return retrieve_evidence
def decorate_with_questions(claim, retrieve_evidence, top_k=3):  # top_k=5, 10, 100
    """Pick the evidence lines most similar to the claim and generate a question for each.

    Args:
        claim: the claim text.
        retrieve_evidence: rows as produced by ``averitec_search``; the last
            element of each row is the page text whose first line is the URL.
        top_k: number of evidence lines to keep.

    Returns:
        A list of ``[question, evidence_line, source_url]`` triples. Empty when
        no usable text line was retrieved — previously that case went on to
        build a BM25 index over an empty corpus and failed.
    """
    tokenized_corpus = []
    all_data_corpus = []
    seen = set()  # O(1) duplicate check instead of scanning the corpus list
    for retri_evi in tqdm.tqdm(retrieve_evidence):
        lines_in_web = retri_evi[-1].split("\n")
        # The first line of every page text is its URL.
        location_url = lines_in_web[0].strip()
        for line in lines_in_web[1:]:
            line = line.strip()
            if len(line) <= 3:
                continue
            key = (location_url, line)
            if key in seen:
                continue
            seen.add(key)
            tokenized_corpus.append(nltk.word_tokenize(line))
            all_data_corpus.append(key)

    if not tokenized_corpus:
        return []

    bm25 = BM25Okapi(tokenized_corpus)
    scores = bm25.get_scores(nltk.word_tokenize(claim))
    top_n = np.argsort(scores)[::-1][:top_k]
    docs = [all_data_corpus[i] for i in top_n]

    generate_qa_pairs = []
    # Then, generate questions for those top lines:
    for doc in tqdm.tqdm(docs):
        evidence_text = doc[1].replace("\n", " ")

        # Few-shot examples: the 10 training answers most similar to this line.
        # prompt_lookup_str = example["claim"] + " " + doc[1]
        prompt_s = prompt_bm25.get_scores(nltk.word_tokenize(doc[1]))
        prompt_top_n = np.argsort(prompt_s)[::-1][:10]
        prompt_docs = [prompt_corpus1[i] for i in prompt_top_n]

        claim_prompt = "Evidence: " + evidence_text + "\nQuestion answered: "
        prompt = "\n\n".join(prompt_docs + [claim_prompt])

        inputs = qg_tokenizer([prompt], padding=True, return_tensors="pt").to(qg_model.device)
        outputs = qg_model.generate(inputs["input_ids"], max_length=5000, num_beams=2, no_repeat_ngram_size=2, early_stopping=True)
        # Decode only the newly generated tokens.
        tgt_text = qg_tokenizer.batch_decode(outputs[:, inputs["input_ids"].shape[-1]:], skip_special_tokens=True)[0]
        # We are not allowed to generate more than 250 characters:
        tgt_text = tgt_text[:250]

        question = tgt_text.strip().split("?")[0].replace("\n", " ") + "?"
        generate_qa_pairs.append([question, evidence_text, doc[0]])
    return generate_qa_pairs
def triple_to_string(x):
    """Join the stripped items of ``x`` with the " </s> " separator."""
    stripped_parts = (part.strip() for part in x)
    return " </s> ".join(stripped_parts)
def rerank_questions(claim, bm25_qas, topk=3):
    """Score question/answer pairs against the claim and keep the best ``topk``.

    Args:
        claim: the claim text.
        bm25_qas: list of ``(question, answer, source_url)`` triples.
        topk: number of pairs to return.

    Returns:
        A list of ``{"question", "answers", "source_url"}`` dicts ordered by
        descending probability of class 1 from the rerank model; empty when
        ``bm25_qas`` is empty.
    """
    if len(bm25_qas) == 0:
        return []

    strs_to_score = []
    values = []
    for question, answer, source in bm25_qas:
        strs_to_score.append(triple_to_string([claim, question, answer]))
        values.append([question, answer, source])

    encoded_dict = rerank_tokenizer(strs_to_score, max_length=512, padding="longest", truncation=True, return_tensors="pt").to(rerank_trained_model.device)
    # Pure inference: no autograd graph is needed, so do not build one.
    with torch.no_grad():
        logits = rerank_trained_model(encoded_dict['input_ids'], attention_mask=encoded_dict['attention_mask']).logits
    scores = torch.softmax(logits, dim=-1)[:, 1]
    # Plain ints make the indexing below independent of tensor __index__ support.
    top_n = torch.argsort(scores, descending=True)[:topk].tolist()

    return [
        {"question": values[i][0], "answers": values[i][1], "source_url": values[i][2]}
        for i in top_n
    ]
def Googleretriever(query):
    """Retrieve web evidence for ``query`` as question/answer ``Docs``.

    Pipeline: generate questions for the claim, search the web for the claim
    and those questions, turn the best-matching page lines into QA pairs, and
    rerank them.

    Returns:
        A list of ``Docs``, one per reranked QA pair, whose metadata carries the
        question, answer and source URL and whose page content is an HTML
        rendering of the pair.
    """
    # ----- Generate QA pairs using AVeriTeC
    questions = prompt_question_generation(query)              # step 1: questions for the claim
    web_evidence = averitec_search(query, questions)            # step 2: web search
    candidate_qas = decorate_with_questions(query, web_evidence)  # step 3: QA pair per evidence line
    best_qas = rerank_questions(query, candidate_qas)           # step 4: rerank

    results = []
    for rank, qa in enumerate(best_qas, start=1):
        question = qa['question']
        answer = qa['answers']
        url = qa['source_url']
        html = "<b>Question</b>: " + question + "<br>" + "<b>Answer</b>: " + answer
        metadata = {
            'name': question,
            'url': url,
            'cached_source_url': url,
            'short_name': "Evidence {}".format(rank),
            'page_number': "",
            'title': question,
            'evidence': answer,
            'query': question,
            'answer': answer,
            'page_content': html,
        }
        results.append(Docs(metadata, html))
    return results
# ------------------------------ Googleretriever ----------------------------- | |
# ------------------------------ Wikipediaretriever -------------------------- | |
def search_entity_wikipeida(entity):
    """Look up the Wikipedia page titled like ``entity`` and return its summary.

    Args:
        entity: the page title, or any object whose ``str()`` is the title
            (find_evidence_from_wikipedia passes the entity objects yielded by
            ``doc.ents``).

    Returns:
        ``[[title, summary]]`` when the page exists, otherwise ``[]``.
    """
    # Always hand the API client a plain string: a non-string entity object
    # would otherwise be stored as the page title and end up in the request.
    title = str(entity)
    page_py = wiki_wiki.page(title)
    if not page_py.exists():
        return []
    return [[title, page_py.summary]]
def clean_str(p):
    """Interpret backslash escapes in ``p`` while keeping its non-ASCII text intact.

    The string is UTF-8 encoded and decoded with "unicode-escape" (which
    resolves escapes but reads the bytes as Latin-1), then re-encoded as
    Latin-1 and decoded as UTF-8 to restore the original characters.
    """
    raw_bytes = p.encode()
    unescaped = raw_bytes.decode("unicode-escape")
    return unescaped.encode("latin1").decode("utf-8")
def find_similar_wikipedia(entity, relevant_wikipages):
    """Top up ``relevant_wikipages`` with pages found through Wikipedia's search.

    Used when fewer than 5 pages were found for an entity: the first 5 search
    hits are looked up and appended until the list holds 5 entries.

    Args:
        entity: entity text to search for.
        relevant_wikipages: list of ``[title, summary]`` pairs; extended in place.

    Returns:
        ``relevant_wikipages`` (the same list object). If the search request
        fails or times out the list is returned unchanged.
    """
    # Let requests build the query string so that characters such as "&" or "#"
    # in the entity are escaped (spaces still become "+").
    search_params = {
        "search": entity,
        "title": "Special:Search",
        "profile": "advanced",
        "fulltext": 1,
        "ns0": 1,
    }
    try:
        # Without a timeout a stalled connection would block the request forever.
        response = requests.get("https://en.wikipedia.org/w/index.php", params=search_params, timeout=10)
    except requests.RequestException as e:
        # The extra pages are a best-effort addition; keep what we already have.
        print(f"Wikipedia search failed for {entity!r}: {e}")
        return relevant_wikipages

    soup = BeautifulSoup(response.text, features="html.parser")
    result_divs = soup.find_all("div", {"class": "mw-search-result-heading"})
    if result_divs:
        similar_titles = [clean_str(div.get_text().strip()) for div in result_divs][:5]
        saved_titles = {page[0] for page in relevant_wikipages}
        for title in similar_titles:
            if len(relevant_wikipages) >= 5:
                break
            if title in saved_titles:
                continue
            # _evi = search_step(title)
            relevant_wikipages.extend(search_entity_wikipeida(title))
    return relevant_wikipages
def find_evidence_from_wikipedia(claim):
    """Collect Wikipedia summaries for the entities that ``nlp`` finds in the claim.

    For every entity, the page with that exact title is looked up; when fewer
    than 5 pages result, search hits for the entity are added.

    Returns:
        A list of ``[title, summary]`` pairs over all entities (an entity that
        occurs twice contributes twice). Empty when the claim has no entities.
    """
    doc = nlp(claim)

    wikipedia_page = []
    for ent in doc.ents:
        # Work with the entity's text from here on; the lookup helpers expect a title string.
        entity_text = str(ent)
        relevant_wikipages = search_entity_wikipeida(entity_text)
        if len(relevant_wikipages) < 5:
            relevant_wikipages = find_similar_wikipedia(entity_text, relevant_wikipages)
        wikipedia_page.extend(relevant_wikipages)
    return wikipedia_page
def bm25_retriever(query, corpus, topk=3):
    """Rank ``corpus`` against ``query`` with BM25.

    Args:
        query: query text; tokenised with ``word_tokenize``.
        corpus: list of token lists.
        topk: number of hits to return.

    Returns:
        (indices, scores): the corpus indices of the best ``topk`` documents in
        descending score order and their scores. Both are empty for an empty
        corpus, which the BM25 index cannot be built from.
    """
    if len(corpus) == 0:
        return np.array([], dtype=int), []

    bm25 = BM25Okapi(corpus)
    scores = bm25.get_scores(word_tokenize(query))
    top_n = np.argsort(scores)[::-1][:topk]
    top_n_scores = [scores[i] for i in top_n]
    return top_n, top_n_scores
def relevant_sentence_retrieval(query, wiki_intro, k):
    """Pick the ``k`` introduction sentences most similar to ``query``.

    Args:
        query: text to match against.
        wiki_intro: iterable of ``(title, intro_text)`` pairs.
        k: number of sentences to return.

    Returns:
        (sentences, titles): the selected sentences and, index-aligned, the
        title of the page each came from. Both lists are empty when there are
        no sentences to rank (e.g. no page was found for the claim).
    """
    # 1. Create corpus here
    corpus, sentences, titles = [], [], []
    for title, intro in wiki_intro:
        for sent in sent_tokenize(intro):
            corpus.append(word_tokenize(sent))
            sentences.append(sent)
            titles.append(title)

    if not corpus:
        # Nothing to rank; a BM25 index cannot be built over an empty corpus.
        return [], []

    # ----- BM25
    bm25_top_n, _ = bm25_retriever(query, corpus, topk=k)
    bm25_top_n_sents = [sentences[i] for i in bm25_top_n]
    bm25_top_n_titles = [titles[i] for i in bm25_top_n]
    return bm25_top_n_sents, bm25_top_n_titles
# ------------------------------ Wikipediaretriever ----------------------------- | |
def Wikipediaretriever(claim):
    """Retrieve Wikipedia evidence sentences for ``claim`` as ``Docs``.

    Returns:
        A list of up to 3 ``Docs``. Each one holds an evidence sentence, the
        title and URL of the page it came from, and an HTML rendering as page
        content. The 'answer' field is left empty and the sentence is stored
        under 'query' and 'evidence'.
    """
    # 1. extract relevant wikipedia pages
    wikipedia_page = find_evidence_from_wikipedia(claim)
    # 2. extract relevant sentences from extracted wikipedia pages
    sents, titles = relevant_sentence_retrieval(claim, wikipedia_page, k=3)

    results = []
    for i, (sent, title) in enumerate(zip(sents, titles)):
        # Page URLs use underscores for spaces. The cached URL used to be built
        # with "_".join(title), which puts an underscore between every character.
        page_url = "https://en.wikipedia.org/wiki/" + "_".join(title.split())
        metadata = dict()
        metadata['name'] = claim
        metadata['url'] = page_url
        metadata['cached_source_url'] = page_url
        metadata['short_name'] = "Evidence {}".format(i + 1)
        metadata['page_number'] = ""
        metadata['query'] = sent
        metadata['title'] = title
        metadata['evidence'] = sent
        metadata['answer'] = ""
        metadata['page_content'] = "<b>Title</b>: " + str(metadata['title']) + "<br>" + "<b>Evidence</b>: " + metadata['evidence']
        results.append(Docs(metadata, metadata['page_content']))
    return results
# ------------------------------ Veracity Prediction ------------------------------ | |
class SequenceClassificationDataLoader(pl.LightningDataModule):
    """Helper that formats and tokenises claim/question/answer strings for the veracity model."""

    def __init__(self, tokenizer, data_file, batch_size, add_extra_nee=False):
        """Store the configuration; nothing is loaded or read here."""
        super().__init__()
        self.tokenizer = tokenizer
        self.data_file = data_file
        self.batch_size = batch_size
        self.add_extra_nee = add_extra_nee

    def tokenize_strings(
        self,
        source_sentences,
        max_length=400,
        pad_to_max_length=False,
        return_tensors="pt",
    ):
        """Tokenise ``source_sentences`` with truncation to ``max_length``.

        Pads to ``max_length`` when ``pad_to_max_length`` is true, otherwise to
        the longest sequence of the batch.

        Returns:
            (input_ids, attention_mask) as returned by the tokenizer.
        """
        padding_mode = "longest"
        if pad_to_max_length:
            padding_mode = "max_length"

        encoded = self.tokenizer(
            source_sentences,
            max_length=max_length,
            padding=padding_mode,
            truncation=True,
            return_tensors=return_tensors,
        )
        return encoded["input_ids"], encoded["attention_mask"]

    def quadruple_to_string(self, claim, question, answer, bool_explanation=""):
        """Format one evidence item as "[CLAIM] ... [QUESTION] ... <answer>[, because ...]".

        The ", because ..." clause is only added for a non-empty explanation.
        """
        if bool_explanation is None or len(bool_explanation) == 0:
            because_clause = ""
        else:
            because_clause = ", because " + bool_explanation.lower().strip()

        return f"[CLAIM] {claim.strip()} [QUESTION] {question.strip()} {answer.strip()}{because_clause}"
def veracity_prediction(claim, evidence):
    """Predict the verdict label for ``claim`` given its evidence.

    Every evidence item is classified on its own and the per-item classes are
    then aggregated into one verdict.

    Args:
        claim: the claim text.
        evidence: list of objects whose ``metadata`` dict has "query" and "answer".

    Returns:
        One of the ``LABEL`` strings. "Not Enough Evidence" when there is no
        evidence at all.
    """
    # If we found no evidence e.g. because google returned 0 pages, just output NEI.
    if len(evidence) == 0:
        return "Not Enough Evidence"

    dataLoader = SequenceClassificationDataLoader(
        tokenizer=veracity_tokenizer,
        data_file="this_is_discontinued",
        batch_size=32,
        add_extra_nee=False,
    )
    evidence_strings = [
        dataLoader.quadruple_to_string(claim, evi.metadata["query"], evi.metadata["answer"], "")
        for evi in evidence
    ]

    tokenized_strings, attention_mask = dataLoader.tokenize_strings(evidence_strings)
    device = veracity_model.device
    # Pure inference: skip building the autograd graph.
    with torch.no_grad():
        logits = veracity_model(tokenized_strings.to(device), attention_mask=attention_mask.to(device)).logits
    # Plain ints keep the aggregation below free of tensor comparisons.
    predictions = torch.argmax(logits, dim=1).tolist()

    has_true = 0 in predictions
    has_false = 1 in predictions
    # TODO another hack -- we cant have different labels for train and test so we do this
    has_unanswerable = 2 in predictions or 3 in predictions

    if has_unanswerable:
        answer = 2
    elif has_true and not has_false:
        answer = 0
    elif not has_true and has_false:
        answer = 1
    else:
        # Both supporting and refuting items.
        answer = 3

    return LABEL[answer]
# ------------------------------ Justification Generation ------------------------------ | |
def extract_claim_str(claim, evidence, verdict_label):
    """Build the "[CLAIM] ... [EVIDENCE] ... [VERDICT] ..." input string.

    Args:
        claim: the claim text.
        evidence: iterable of objects whose ``metadata`` dict has "query" and
            "answer"; items with a blank query are skipped.
        verdict_label: the verdict appended after " [VERDICT] ".

    Returns:
        The concatenated string. Each kept item contributes its question
        (ending in "?"), then — if the answer is non-empty — a space and the
        answer (ending in "."), then one trailing space.
    """
    pieces = ["[CLAIM] " + claim + " [EVIDENCE] "]
    for evi in evidence:
        question = evi.metadata['query'].strip()
        if not question:
            continue
        if not question.endswith("?"):
            question += "?"

        answer = evi.metadata['answer']
        pieces.append(question)
        if answer:
            if not answer.endswith("."):
                answer += "."
            pieces.append(" " + answer.strip())
        pieces.append(" ")

    pieces.append(" [VERDICT] " + verdict_label)
    return "".join(pieces)
def justification_generation(claim, evidence, verdict_label):
    """Generate a textual justification for the verdict.

    Args:
        claim: the claim text.
        evidence: iterable of objects whose ``metadata`` dict has "query" and "answer".
        verdict_label: the predicted verdict.

    Returns:
        The justification produced by ``justification_model``, stripped of
        surrounding whitespace.
    """
    # The model input is assembled by extract_claim_str instead of a copy of its
    # body. The original called claim_str.strip() and discarded the result; the
    # stripped string is now what is passed to the model.
    claim_str = extract_claim_str(claim, evidence, verdict_label).strip()

    pred_justification = justification_model.generate(claim_str, device=justification_model.device)
    # pred_justification = justification_model.generate(claim_str, device=device)
    return pred_justification.strip()
# --------------------------------------------------------------------------------------------------------------------- | |
class Item(BaseModel):
    """Request payload accepted by fact_checking."""
    # The claim text to verify.
    claim: str
    # Evidence source; fact_checking handles "Wikipedia" and "Google".
    source: str
# Registered on the app so the handler is actually reachable; in the source as
# shown no route was attached to `app`, so the server answered 404 to everything.
# NOTE(review): "/" is the conventional location for this greeting — confirm.
@app.get("/")
def greet_json():
    """Return a fixed greeting payload."""
    return {"Hello": "World!"}
def log_on_azure(file, logs, azure_share_client):
    """Serialise ``logs`` to JSON and upload it as ``file`` through the share client.

    Args:
        file: name of the target file in the share.
        logs: JSON-serialisable object.
        azure_share_client: object providing ``get_file_client(name)``.
    """
    payload = json.dumps(logs)
    azure_share_client.get_file_client(file).upload_file(payload)
# Registered on the app so the endpoint is reachable; in the source as shown no
# route was attached to `app`.
# NOTE(review): the path "/predict/" is an assumption — confirm it against the clients.
@app.post("/predict/")
def fact_checking(item: Item):
    """Fact-check a claim: retrieve evidence, predict a verdict and justify it.

    Args:
        item: request payload with ``claim`` and ``source`` ("Wikipedia" or "Google").

    Returns:
        ``{"Verdict": str, "Justification": str, "Evidence": [[title, evidence, url], ...]}``.

    Raises:
        HTTPException: status 400 when ``source`` is not a supported value
            (this case used to fail with an unbound ``evidence`` variable).
        gr.Error: when logging is enabled and uploading the log fails.
    """
    # Imported here because the file-level import only brings in FastAPI itself.
    from fastapi import HTTPException

    claim = item.claim
    source = item.source

    # Step1: Evidence Retrieval
    if source == "Wikipedia":
        evidence = Wikipediaretriever(claim)
    elif source == "Google":
        evidence = Googleretriever(claim)
    else:
        raise HTTPException(
            status_code=400,
            detail=f"Unsupported source {source!r}; expected 'Wikipedia' or 'Google'.",
        )

    # Step2: Veracity Prediction and Justification Generation
    verdict_label = veracity_prediction(claim, evidence)
    justification_label = justification_generation(claim, evidence, verdict_label)

    evidence_list = [
        [evi.metadata['title'], evi.metadata['evidence'], evi.metadata['url']]
        for evi in evidence
    ]

    try:
        # Log answer on Azure Blob Storage
        # IF AZURE_ISSAVE=TRUE, save the logs into the Azure share client.
        # .get(): an unset variable means "logging disabled". With os.environ[...]
        # it raised KeyError, which the handler below turned into a failed request.
        if os.environ.get("AZURE_ISSAVE") == "TRUE":
            timestamp = str(datetime.now().timestamp())
            # timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
            file = timestamp + ".json"
            logs = {
                "user_id": str(user_id),
                "claim": claim,
                "sources": source,
                "evidence": evidence_list,
                "answer": [verdict_label, justification_label],
                "time": timestamp,
            }
            log_on_azure(file, logs, azure_share_client)
    except Exception as e:
        print(f"Error logging on Azure Blob Storage: {e}")
        raise gr.Error(
            f"AVeriTeC Error: {str(e)[:100]} - The error has been noted, try another question and if the error remains, you can contact us :)") from e

    return {"Verdict": verdict_label, "Justification": justification_label, "Evidence": evidence_list}
# Script entry point: serve `app` on all interfaces, port 7860.
if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=7860)
# if __name__ == "__main__": | |
# item = { | |
# "claim": "England won the Euro 2024.", | |
# "source": "Google", # Google, Wikipedia | |
# } | |
# | |
# results = fact_checking(item) | |
# | |
# print(results) | |