import os

import h5py
import pandas as pd
import torch
from sentence_transformers import CrossEncoder, SentenceTransformer, util

from src.utils.path_utils import get_project_root


class SemanticSimilarity:
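    """Two-stage semantic search: bi-encoder retrieval followed by cross-encoder re-ranking."""
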
    def __init__(
        self,
        train_embeddings_file,
        test_embeddings_file,
        train_csv_path=None,
        test_csv_path=None,
        train_df=None,
        test_df=None,
    ):
        # Bi-encoder used for first-stage dense retrieval.
        self.bi_encoder = SentenceTransformer("multi-qa-mpnet-base-dot-v1")
        self.bi_encoder.max_seq_length = 512

        # Cross-encoder used to re-rank the retrieved candidates.
        self.cross_encoder = CrossEncoder("cross-encoder/ms-marco-MiniLM-L-6-v2")

        self.train_embeddings, self.train_ids = self._load_embeddings(
            train_embeddings_file
        )
        self.test_embeddings, self.test_ids = self._load_embeddings(
            test_embeddings_file
        )

        # Use the pre-loaded DataFrames if provided, otherwise read the CSVs.
        self.train_csv = (
            train_df if train_df is not None else pd.read_csv(train_csv_path)
        )
        self.test_csv = test_df if test_df is not None else pd.read_csv(test_csv_path)

    def _load_embeddings(self, h5_file_path):
        """Load embeddings and IDs from the HDF5 file."""
        with h5py.File(h5_file_path, "r") as h5_file:
            embeddings = torch.tensor(h5_file["embeddings"][:], dtype=torch.float16)
            ids = list(h5_file["ids"][:])

        return embeddings, ids

    def search(self, query, top_k):
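        """Return the top_k highest-scoring (id, cross-encoder score) pairs for the query."""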
        # Encode the query and cast to float16 to match the stored corpus embeddings.
        question_embedding = self.bi_encoder.encode(query, convert_to_tensor=True)
        question_embedding = question_embedding.to(dtype=torch.float16)

        # First stage: over-retrieve top_k * 5 candidates from each split with the bi-encoder.
        hits_train = util.semantic_search(
            question_embedding, self.train_embeddings, top_k=top_k * 5
        )[0]

        hits_test = util.semantic_search(
            question_embedding, self.test_embeddings, top_k=top_k * 5
        )[0]

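        # Second stage: score each (query, evidence_enriched) pair with the cross-encoder.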
        cross_inp_train = [
            [query, self.train_csv["evidence_enriched"][hit["corpus_id"]]]
            for hit in hits_train
        ]
        cross_scores_train = self.cross_encoder.predict(cross_inp_train)

        cross_inp_test = [
            [query, self.test_csv["evidence_enriched"][hit["corpus_id"]]]
            for hit in hits_test
        ]
        cross_scores_test = self.cross_encoder.predict(cross_inp_test)

        # Attach the cross-encoder score to each hit.
        for idx in range(len(cross_scores_train)):
            hits_train[idx]["cross-score"] = cross_scores_train[idx]

        for idx in range(len(cross_scores_test)):
            hits_test[idx]["cross-score"] = cross_scores_test[idx]

        # Re-rank each split by cross-encoder score, highest first.
        hits_train_cross_encoder = sorted(
            hits_train, key=lambda x: x.get("cross-score"), reverse=True
        )[: top_k * 5]
        hits_test_cross_encoder = sorted(
            hits_test, key=lambda x: x.get("cross-score"), reverse=True
        )[: top_k * 5]

        # Merge both splits into (id, score) pairs.
        results = [
            (self.train_ids[hit["corpus_id"]].decode("utf-8"), hit.get("cross-score"))
            for hit in hits_train_cross_encoder
        ] + [
            (self.test_ids[hit["corpus_id"]].decode("utf-8"), hit.get("cross-score"))
            for hit in hits_test_cross_encoder
        ]

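        # Keep the highest-scoring results, skipping duplicate scores, until top_k are collected.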
        unique_scores = set()
        filtered_results = []

        for id_, score in sorted(results, key=lambda x: x[1], reverse=True):
            if score not in unique_scores:
                unique_scores.add(score)
                filtered_results.append((id_, score))

            if len(filtered_results) == top_k:
                break

        return filtered_results


class TextCorpus:
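    """Encode the evidence_enriched column of a data split and save the embeddings to HDF5."""
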
    def __init__(self, data_dir, split):
        self.bi_encoder = SentenceTransformer("multi-qa-mpnet-base-dot-v1")
        self.split = split
        self.data_dir = data_dir

    def encode_corpus(self):
        """
        Encode the corpus (the evidence_enriched column of the given split) and store the embeddings.
        """
        file_path = os.path.join(self.data_dir, f"{self.split}_enriched.csv")
        df = pd.read_csv(file_path)

        evidence_enriched = df["evidence_enriched"].tolist()
        ids = df["id"].tolist()

        embeddings = self.bi_encoder.encode(evidence_enriched, convert_to_tensor=True)

        h5_file_path = os.path.join(get_project_root(), f"{self.split}_embeddings.h5")

        with h5py.File(h5_file_path, "w") as h5_file:
            # Move to CPU before converting to NumPy in case encoding ran on a GPU.
            h5_file.create_dataset(
                "embeddings", data=embeddings.cpu().numpy(), dtype="float16"
            )

            # Prefix IDs with the split name so train and test IDs stay distinguishable.
            h5_file.create_dataset(
                "ids",
                data=[f"{self.split}_{id_}" for id_ in ids],
                dtype=h5py.string_dtype(),
            )

        print(f"Embeddings saved to {h5_file_path}")


if __name__ == "__main__":
    import time

    start_time = time.time()
    project_root = get_project_root()
    data_dir = os.path.join(project_root, "data", "preprocessed")

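    # The embedding files below are assumed to already exist; presumably they were
    # produced beforehand with something like TextCorpus(data_dir, "train").encode_corpus()
    # and TextCorpus(data_dir, "test").encode_corpus().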
    train_csv_path = os.path.join(data_dir, "train_enriched.csv")
    test_csv_path = os.path.join(data_dir, "test_enriched.csv")
    train_embeddings_file = os.path.join(project_root, "train_embeddings.h5")
    test_embeddings_file = os.path.join(project_root, "test_embeddings.h5")

    similarity = SemanticSimilarity(
        train_embeddings_file=train_embeddings_file,
        test_embeddings_file=test_embeddings_file,
        train_csv_path=train_csv_path,
        test_csv_path=test_csv_path,
    )

    # Use an enriched claim from the training set as an example query.
    train_df = pd.read_csv(train_csv_path)
    first_query = train_df["claim_enriched"].iloc[2]

    top_k = 5

    results = similarity.search(query=first_query, top_k=top_k)
    finish_time = time.time() - start_time

    print(results)
    print(f"Finish time: {finish_time:.2f} seconds")