# src/evidence/text2text_retrieval.py
import os

import h5py
import pandas as pd
import torch
from sentence_transformers import CrossEncoder, SentenceTransformer, util

from src.utils.path_utils import get_project_root


class SemanticSimilarity:
def __init__(
self,
train_embeddings_file,
test_embeddings_file,
train_csv_path=None,
test_csv_path=None,
train_df=None,
test_df=None,
):
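        """
        Retrieve-and-rerank search over precomputed evidence embeddings.

        train_embeddings_file / test_embeddings_file: HDF5 files produced by
        TextCorpus.encode_corpus(), containing "embeddings" and "ids" datasets.
        Either the *_csv_path arguments or the preloaded *_df DataFrames must be
        provided; both are expected to contain an "evidence_enriched" column.
        """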
        # Bi-encoder used to embed queries and passages into a shared vector space
        self.bi_encoder = SentenceTransformer("multi-qa-mpnet-base-dot-v1")
        self.bi_encoder.max_seq_length = 512  # Truncate long passages to 512 tokens
self.cross_encoder = CrossEncoder("cross-encoder/ms-marco-MiniLM-L-6-v2")
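        # Design note: the bi-encoder retrieves a broad candidate set cheaply via
        # vector similarity; the cross-encoder then re-scores each (query, passage)
        # pair jointly, which is slower but more precise (retrieve-and-rerank).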
self.train_embeddings, self.train_ids = self._load_embeddings(
train_embeddings_file
)
self.test_embeddings, self.test_ids = self._load_embeddings(
test_embeddings_file
)
# Load corresponding CSV files for enriched evidence
self.train_csv = (
train_df if train_df is not None else pd.read_csv(train_csv_path)
)
self.test_csv = test_df if test_df is not None else pd.read_csv(test_csv_path)
def _load_embeddings(self, h5_file_path):
"""
Load embeddings and IDs from the HDF5 file
"""
with h5py.File(h5_file_path, "r") as h5_file:
embeddings = torch.tensor(h5_file["embeddings"][:], dtype=torch.float16)
            ids = list(h5_file["ids"][:])  # IDs are stored as byte strings (decoded on use)
return embeddings, ids
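
    # Expected HDF5 layout (as written by TextCorpus.encode_corpus):
    #   "embeddings": float16 array of shape (num_passages, embedding_dim)
    #   "ids":        strings such as "train_<id>" / "test_<id>"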
def search(self, query, top_k):
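        """
        Retrieve the top_k most relevant evidence passages for a query.

        Returns a list of (evidence_id, cross_encoder_score) tuples, drawn from
        both the train and test corpora and sorted by score.
        """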
        ##### Semantic Search #####
        # Encode the query with the bi-encoder and find potentially relevant passages
        question_embedding = self.bi_encoder.encode(query, convert_to_tensor=True)
        # Cast the query embedding to float16 to match the stored corpus embeddings
        question_embedding = question_embedding.to(dtype=torch.float16)
hits_train = util.semantic_search(
question_embedding, self.train_embeddings, top_k=top_k * 5
)
hits_train = hits_train[0] # Get the hits for the first query
# print(f"len(hits_train) = {len(hits_train)}")
hits_test = util.semantic_search(
question_embedding, self.test_embeddings, top_k=top_k * 5
)
hits_test = hits_test[0]
# print(f"len(hits_test): {len(hits_test)}")
##### Re-Ranking #####
# Now, score all retrieved passages with the cross_encoder
cross_inp_train = [
[query, self.train_csv["evidence_enriched"][hit["corpus_id"]]]
for hit in hits_train
]
cross_scores_train = self.cross_encoder.predict(cross_inp_train)
cross_inp_test = [
[query, self.test_csv["evidence_enriched"][hit["corpus_id"]]]
for hit in hits_test
]
cross_scores_test = self.cross_encoder.predict(cross_inp_test)
# Sort results by the cross-encoder scores
for idx in range(len(cross_scores_train)):
hits_train[idx]["cross-score"] = cross_scores_train[idx]
for idx in range(len(cross_scores_test)):
hits_test[idx]["cross-score"] = cross_scores_test[idx]
hits_train_cross_encoder = sorted(
hits_train, key=lambda x: x.get("cross-score"), reverse=True
)
hits_train_cross_encoder = hits_train_cross_encoder[: top_k * 5]
hits_test_cross_encoder = sorted(
hits_test, key=lambda x: x.get("cross-score"), reverse=True
)
hits_test_cross_encoder = hits_test_cross_encoder[: top_k * 5]
results = [
(self.train_ids[hit["corpus_id"]].decode("utf-8"), hit.get("cross-score"))
for hit in hits_train_cross_encoder
] + [
(self.test_ids[hit["corpus_id"]].decode("utf-8"), hit.get("cross-score"))
for hit in hits_test_cross_encoder
]
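        # results is a flat list of (evidence_id, score) pairs from both corpora,
        # e.g. [("train_123", 9.1), ("test_45", 8.7), ...] (values illustrative)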
##### Filter out duplicates based on scores #####
unique_scores = set()
filtered_results = []
for id_, score in sorted(results, key=lambda x: x[1], reverse=True):
if score not in unique_scores:
unique_scores.add(score)
filtered_results.append((id_, score))
if (
len(filtered_results) == top_k
): # Stop when top_k unique scores are reached
break
return filtered_results


class TextCorpus:
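    """
    Encode the "evidence_enriched" column of a train/test split with the
    bi-encoder and persist the embeddings (plus their ids) to an HDF5 file.
    """
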
def __init__(self, data_dir, split):
self.bi_encoder = SentenceTransformer("multi-qa-mpnet-base-dot-v1")
        self.split = split  # "train" or "test"
        self.data_dir = data_dir  # directory containing the enriched train/test CSV files
def encode_corpus(self):
"""
Encode the corpus (evidence_enriched column for both train and test) and store the embeddings.
"""
file_path = os.path.join(self.data_dir, f"{self.split}_enriched.csv")
df = pd.read_csv(file_path)
# Extract the enriched evidence column and ids
evidence_enriched = df["evidence_enriched"].tolist()
ids = df["id"].tolist() # Assuming the 'id' column is in the CSV
# Encode the evidence using the bi-encoder
embeddings = self.bi_encoder.encode(evidence_enriched, convert_to_tensor=True)
# Define HDF5 file path
h5_file_path = os.path.join(get_project_root(), f"{self.split}_embeddings.h5")
with h5py.File(h5_file_path, "w") as h5_file:
            h5_file.create_dataset(
                "embeddings", data=embeddings.cpu().numpy(), dtype="float16"
            )
            h5_file.create_dataset(
                "ids",
                data=[f"{self.split}_{id_}" for id_ in ids],
                dtype=h5py.string_dtype(),
            )
print(f"Embeddings saved to {h5_file_path}")
if __name__ == "__main__":
import time
start_time = time.time()
project_root = get_project_root()
data_dir = os.path.join(project_root, "data", "preprocessed")
    # The embedding files below are assumed to have been generated beforehand
    # via TextCorpus(...).encode_corpus() (see the sketch above).
# Define file paths
train_csv_path = os.path.join(data_dir, "train_enriched.csv")
test_csv_path = os.path.join(data_dir, "test_enriched.csv")
train_embeddings_file = os.path.join(project_root, "train_embeddings.h5")
test_embeddings_file = os.path.join(project_root, "test_embeddings.h5")
# Initialize the SemanticSimilarity class
similarity = SemanticSimilarity(
train_embeddings_file=train_embeddings_file,
test_embeddings_file=test_embeddings_file,
train_csv_path=train_csv_path,
test_csv_path=test_csv_path,
)
    # Load an example query from train_enriched.csv
    train_df = pd.read_csv(train_csv_path)
    first_query = train_df["claim_enriched"].iloc[2]  # third claim used as the example query
# Define the number of top-k results to retrieve
top_k = 5
# Perform the semantic search
results = similarity.search(query=first_query, top_k=top_k)
    elapsed_time = time.time() - start_time
# Display the results
print(results)
print(f"Finish time: {finish_time}")