AITeamVN commited on
Commit
c2e7c2a
·
verified ·
1 Parent(s): a5fa2bb

Upload evaluation_model.py

Browse files
Files changed (1) hide show
  1. evaluation_model.py +191 -0
evaluation_model.py ADDED
@@ -0,0 +1,191 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import torch
3
+ import json
4
+ import pandas as pd
5
+ from tqdm import tqdm
6
+ from typing import List, Dict, Tuple, Set, Union, Optional
7
+ from langchain.docstore.document import Document
8
+ from langchain_community.vectorstores import FAISS
9
+ from langchain_community.vectorstores.faiss import DistanceStrategy
10
+ from langchain_core.embeddings.embeddings import Embeddings
11
+ from FlagEmbedding import BGEM3FlagModel
12
+
13
+ def setup_gpu_info() -> None:
14
+ print(f"Số lượng GPU khả dụng: {torch.cuda.device_count()}")
15
+ print(f"GPU hiện tại: {torch.cuda.current_device()}")
16
+ print(f"Tên GPU: {torch.cuda.get_device_name(0)}")
17
+
18
+ def load_model(model_name: str, use_fp16: bool = False) -> BGEM3FlagModel:
19
+ return BGEM3FlagModel(model_name, use_fp16=use_fp16)
20
+
21
+ def load_json_file(file_path: str) -> dict:
22
+ with open(file_path, 'r', encoding='utf-8') as f:
23
+ return json.load(f)
24
+
25
+ def load_jsonl_file(file_path: str) -> List[Dict]:
26
+ corpus = []
27
+ with open(file_path, "r", encoding="utf-8") as file:
28
+ for line in file:
29
+ data = json.loads(line.strip())
30
+ corpus.append(data)
31
+ return corpus
32
+
33
+ def extract_corpus_from_legal_documents(legal_data: dict) -> List[Dict]:
34
+ corpus = []
35
+ for document in legal_data:
36
+ for article in document['articles']:
37
+ chunk = {
38
+ "law_id": document['law_id'],
39
+ "article_id": article['article_id'],
40
+ "title": article['title'],
41
+ "text": article['title'] + '\n' + article['text']
42
+ }
43
+ corpus.append(chunk)
44
+ return corpus
45
+
46
+ def convert_corpus_to_documents(corpus: List[Dict[str, str]]) -> List[Document]:
47
+ documents = []
48
+ for i in tqdm(range(len(corpus)), desc="Converting corpus to documents"):
49
+ context = corpus[i]['text']
50
+ metadata = {
51
+ 'law_id': corpus[i]['law_id'],
52
+ 'article_id': corpus[i]['article_id'],
53
+ 'title': corpus[i]['title']
54
+ }
55
+ documents.append(Document(page_content=context, metadata=metadata))
56
+ return documents
57
+
58
+ class CustomEmbedding(Embeddings):
59
+ """Custom embedding class that uses the BGEM3FlagModel."""
60
+
61
+ def __init__(self, model: BGEM3FlagModel, batch_size: int = 1):
62
+ self.model = model
63
+ self.batch_size = batch_size
64
+
65
+ def embed_documents(self, texts: List[str]) -> List[List[float]]:
66
+ embeddings = []
67
+ for i in tqdm(range(0, len(texts), self.batch_size), desc="Embedding documents"):
68
+ batch_texts = texts[i:i+self.batch_size]
69
+ batch_embeddings = self._get_batch_embeddings(batch_texts)
70
+ embeddings.extend(batch_embeddings)
71
+ torch.cuda.empty_cache()
72
+ return np.vstack(embeddings)
73
+
74
+ def embed_query(self, text: str) -> List[float]:
75
+ embedding = self.model.encode(text, max_length=256)['dense_vecs']
76
+ return embedding
77
+
78
+ def _get_batch_embeddings(self, texts: List[str]) -> List[List[float]]:
79
+ with torch.no_grad():
80
+ outputs = self.model.encode(texts, batch_size=self.batch_size, max_length=2048)['dense_vecs']
81
+ batch_embeddings = outputs
82
+ del outputs
83
+ return batch_embeddings
84
+
85
+
86
+ class VectorDB:
87
+ """Vector database for document retrieval."""
88
+
89
+ def __init__(
90
+ self,
91
+ documents: List[Document],
92
+ embedding: Embeddings,
93
+ vector_db=FAISS,
94
+ index_path: Optional[str] = None
95
+ ) -> None:
96
+ self.vector_db = vector_db
97
+ self.embedding = embedding
98
+ self.index_path = index_path
99
+ self.db = self._build_db(documents)
100
+
101
+ def _build_db(self, documents: List[Document]):
102
+ if self.index_path:
103
+ db = self.vector_db.load_local(
104
+ self.index_path,
105
+ self.embedding,
106
+ allow_dangerous_deserialization=True
107
+ )
108
+ else:
109
+ db = self.vector_db.from_documents(
110
+ documents=documents,
111
+ embedding=self.embedding,
112
+ distance_strategy=DistanceStrategy.DOT_PRODUCT
113
+ )
114
+ return db
115
+
116
+ def get_retriever(self, search_type: str = "similarity", search_kwargs: dict = {"k": 10}):
117
+ retriever = self.db.as_retriever(search_type=search_type, search_kwargs=search_kwargs)
118
+ return retriever
119
+
120
+ def save_local(self, folder_path: str) -> None:
121
+ self.db.save_local(folder_path)
122
+
123
+
124
+ def process_sample(sample: dict, retriever) -> List[int]:
125
+ question = sample['question']
126
+ docs = retriever.invoke(question)
127
+ retrieved_article_full_ids = [
128
+ docs[i].metadata['law_id'] + "#" + docs[i].metadata['article_id']
129
+ for i in range(len(docs))
130
+ ]
131
+ indexes = []
132
+ for article in sample['relevant_articles']:
133
+ article_full_id = article['law_id'] + "#" + article['article_id']
134
+ if article_full_id in retrieved_article_full_ids:
135
+ idx = retrieved_article_full_ids.index(article_full_id) + 1
136
+ indexes.append(idx)
137
+ else:
138
+ indexes.append(0)
139
+ return indexes
140
+
141
+ def calculate_metrics(all_indexes: List[List[int]], num_samples: int, selected_keys: Set[str]) -> Dict[str, float]:
142
+ count = [len(indexes) for indexes in all_indexes]
143
+ result = {}
144
+
145
+ for thres in [1, 3, 5, 10, 100]:
146
+ found = [[y for y in x if 0 < y <= thres] for x in all_indexes]
147
+ found_count = [len(x) for x in found]
148
+ acc = sum(1 for i in range(num_samples) if found_count[i] > 0) / num_samples
149
+ rec = sum(found_count[i] / count[i] for i in range(num_samples)) / num_samples
150
+ pre = sum(found_count[i] / thres for i in range(num_samples)) / num_samples
151
+ mrr = sum(1 / min(x) if x else 0 for x in found) / num_samples
152
+
153
+ if f"Accuracy@{thres}" in selected_keys:
154
+ result[f"Accuracy@{thres}"] = acc
155
+ if f"MRR@{thres}" in selected_keys:
156
+ result[f"MRR@{thres}"] = mrr
157
+
158
+ return result
159
+
160
+
161
+ def save_results(result: Dict[str, float], output_path: str) -> None:
162
+ with open(output_path, "w", encoding="utf-8") as f:
163
+ json.dump(result, f, indent=4, ensure_ascii=False)
164
+ print(f"Results saved to {output_path}")
165
+
166
+
167
+ def main():
168
+ setup_gpu_info()
169
+ model = load_model('AITeamVN/Vietnamese_Embedding', use_fp16=False)
170
+ samples = load_json_file('zalo_kaggle/train_question_answer.json')['items']
171
+ legal_data = load_json_file('zalo_kaggle/legal_corpus.json')
172
+
173
+ corpus = extract_corpus_from_legal_documents(legal_data)
174
+ documents = convert_corpus_to_documents(corpus)
175
+ embedding = CustomEmbedding(model, batch_size=1) # Increased batch size for efficiency time
176
+ vectordb = VectorDB(
177
+ documents=documents,
178
+ embedding=embedding,
179
+ vector_db=FAISS,
180
+ index_path=None
181
+ )
182
+ retriever = vectordb.get_retriever(search_type="similarity", search_kwargs={"k": 100})
183
+ all_indexes = []
184
+ for sample in tqdm(samples, desc="Processing samples"):
185
+ all_indexes.append(process_sample(sample, retriever))
186
+ selected_keys = {"Accuracy@1", "Accuracy@3", "Accuracy@5", "Accuracy@10", "MRR@10", "Accuracy@100"}
187
+ result = calculate_metrics(all_indexes, len(samples), selected_keys)
188
+ print(result)
189
+ save_results(result, "zalo_kaggle/Vietnamese_Embedding.json")
190
+ if __name__ == "__main__":
191
+ main()