import h5py
from sentence_transformers import SentenceTransformer, CrossEncoder, util
import os
import torch
import pandas as pd
from src.utils.path_utils import get_project_root


class SemanticSimilarity:
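    """
    Bi-encoder retrieval over precomputed train/test evidence embeddings,
    re-ranked with a cross-encoder.
    """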
def __init__(
self,
train_embeddings_file,
test_embeddings_file,
train_csv_path=None,
test_csv_path=None,
train_df=None,
test_df=None,
):
# We use the Bi-Encoder to encode all passages
self.bi_encoder = SentenceTransformer("multi-qa-mpnet-base-dot-v1")
        self.bi_encoder.max_seq_length = 512  # Truncate long passages to 512 tokens
self.cross_encoder = CrossEncoder("cross-encoder/ms-marco-MiniLM-L-6-v2")
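        # The cross-encoder re-scores (query, passage) pairs, giving a sharper ranking than bi-encoder similarity alone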
self.train_embeddings, self.train_ids = self._load_embeddings(
train_embeddings_file
)
self.test_embeddings, self.test_ids = self._load_embeddings(
test_embeddings_file
)
# Load corresponding CSV files for enriched evidence
self.train_csv = (
train_df if train_df is not None else pd.read_csv(train_csv_path)
)
self.test_csv = test_df if test_df is not None else pd.read_csv(test_csv_path)

    def _load_embeddings(self, h5_file_path):
"""
Load embeddings and IDs from the HDF5 file
"""
with h5py.File(h5_file_path, "r") as h5_file:
embeddings = torch.tensor(h5_file["embeddings"][:], dtype=torch.float16)
            ids = list(h5_file["ids"][:])  # IDs come back as byte strings; decoded on lookup
return embeddings, ids

    def search(self, query, top_k):
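        """
        Return the top_k most relevant passages for a query as
        (id, cross-encoder score) tuples: bi-encoder retrieval over the
        train and test corpora, followed by cross-encoder re-ranking.
        """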
        ##### Semantic Search #####
# Encode the query using the bi-encoder and find potentially relevant passages
question_embedding = self.bi_encoder.encode(query, convert_to_tensor=True)
        question_embedding = question_embedding.to(dtype=torch.float16)  # Match the dtype of the stored embeddings
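        # Retrieve 5x top_k candidates so the cross-encoder has a wider pool to re-rank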
hits_train = util.semantic_search(
question_embedding, self.train_embeddings, top_k=top_k * 5
)
hits_train = hits_train[0] # Get the hits for the first query
# print(f"len(hits_train) = {len(hits_train)}")
hits_test = util.semantic_search(
question_embedding, self.test_embeddings, top_k=top_k * 5
)
hits_test = hits_test[0]
# print(f"len(hits_test): {len(hits_test)}")
##### Re-Ranking #####
# Now, score all retrieved passages with the cross_encoder
cross_inp_train = [
[query, self.train_csv["evidence_enriched"][hit["corpus_id"]]]
for hit in hits_train
]
cross_scores_train = self.cross_encoder.predict(cross_inp_train)
cross_inp_test = [
[query, self.test_csv["evidence_enriched"][hit["corpus_id"]]]
for hit in hits_test
]
cross_scores_test = self.cross_encoder.predict(cross_inp_test)
        # Attach the cross-encoder scores to the corresponding hits
for idx in range(len(cross_scores_train)):
hits_train[idx]["cross-score"] = cross_scores_train[idx]
for idx in range(len(cross_scores_test)):
hits_test[idx]["cross-score"] = cross_scores_test[idx]
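        # Sort each corpus's hits by cross-encoder score, strongest first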
hits_train_cross_encoder = sorted(
hits_train, key=lambda x: x.get("cross-score"), reverse=True
)
hits_train_cross_encoder = hits_train_cross_encoder[: top_k * 5]
hits_test_cross_encoder = sorted(
hits_test, key=lambda x: x.get("cross-score"), reverse=True
)
hits_test_cross_encoder = hits_test_cross_encoder[: top_k * 5]
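        # Merge train and test hits into (id, score) pairs, decoding the byte-string IDs stored in HDF5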
results = [
(self.train_ids[hit["corpus_id"]].decode("utf-8"), hit.get("cross-score"))
for hit in hits_train_cross_encoder
] + [
(self.test_ids[hit["corpus_id"]].decode("utf-8"), hit.get("cross-score"))
for hit in hits_test_cross_encoder
]
##### Filter out duplicates based on scores #####
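        # Hits with identical cross-encoder scores are treated as duplicate passages (a heuristic)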
unique_scores = set()
filtered_results = []
for id_, score in sorted(results, key=lambda x: x[1], reverse=True):
if score not in unique_scores:
unique_scores.add(score)
filtered_results.append((id_, score))
            if len(filtered_results) == top_k:  # Stop once top_k unique scores are collected
                break
return filtered_results


class TextCorpus:
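    """
    Encodes the "evidence_enriched" column of a data split with the
    bi-encoder and saves the embeddings and prefixed IDs to an HDF5 file.
    """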
def __init__(self, data_dir, split):
self.bi_encoder = SentenceTransformer("multi-qa-mpnet-base-dot-v1")
        self.split = split  # "train" or "test"
        self.data_dir = data_dir  # Directory containing the enriched train/test CSV files

    def encode_corpus(self):
"""
Encode the corpus (evidence_enriched column for both train and test) and store the embeddings.
"""
file_path = os.path.join(self.data_dir, f"{self.split}_enriched.csv")
df = pd.read_csv(file_path)
# Extract the enriched evidence column and ids
evidence_enriched = df["evidence_enriched"].tolist()
ids = df["id"].tolist() # Assuming the 'id' column is in the CSV
# Encode the evidence using the bi-encoder
embeddings = self.bi_encoder.encode(evidence_enriched, convert_to_tensor=True)
# Define HDF5 file path
h5_file_path = os.path.join(get_project_root(), f"{self.split}_embeddings.h5")
with h5py.File(h5_file_path, "w") as h5_file:
h5_file.create_dataset(
"embeddings", data=embeddings.numpy(), dtype="float16"
)
h5_file.create_dataset(
"ids",
data=[f"{self.split}_{id}" for id in ids],
dtype=h5py.string_dtype(),
)
print(f"Embeddings saved to {h5_file_path}")


if __name__ == "__main__":
import time
start_time = time.time()
project_root = get_project_root()
data_dir = os.path.join(project_root, "data", "preprocessed")
# Define file paths
train_csv_path = os.path.join(data_dir, "train_enriched.csv")
test_csv_path = os.path.join(data_dir, "test_enriched.csv")
train_embeddings_file = os.path.join(project_root, "train_embeddings.h5")
test_embeddings_file = os.path.join(project_root, "test_embeddings.h5")
# Initialize the SemanticSimilarity class
similarity = SemanticSimilarity(
train_embeddings_file=train_embeddings_file,
test_embeddings_file=test_embeddings_file,
train_csv_path=train_csv_path,
test_csv_path=test_csv_path,
)
    # Load a sample query from train_enriched.csv
    train_df = pd.read_csv(train_csv_path)
    sample_query = train_df["claim_enriched"].iloc[2]  # Claim at index 2
# Define the number of top-k results to retrieve
top_k = 5
# Perform the semantic search
    results = similarity.search(query=sample_query, top_k=top_k)
    elapsed_time = time.time() - start_time
# Display the results
print(results)
print(f"Finish time: {finish_time}")