Rudra Rahul Chothe
commited on
Update src/similarity_search.py
Browse files- src/similarity_search.py +23 -21
src/similarity_search.py
CHANGED
@@ -1,22 +1,24 @@
|
|
1 |
-
import faiss
|
2 |
-
import numpy as np
|
3 |
-
import pickle
|
4 |
-
import os
|
5 |
-
|
6 |
-
class SimilaritySearchEngine:
|
7 |
-
def __init__(self
|
8 |
-
|
9 |
-
|
10 |
-
|
11 |
-
|
12 |
-
|
13 |
-
|
14 |
-
|
15 |
-
|
16 |
-
|
17 |
-
|
18 |
-
|
19 |
-
|
20 |
-
|
21 |
-
|
|
|
|
|
22 |
return [self.image_paths[idx] for idx in indices[0]], distances[0]
|
|
|
1 |
+
import faiss
|
2 |
+
import numpy as np
|
3 |
+
import pickle
|
4 |
+
import os
|
5 |
+
|
6 |
+
class SimilaritySearchEngine:
|
7 |
+
def __init__(self):
|
8 |
+
base_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
9 |
+
embeddings_path = os.path.join(base_dir, 'data', 'embeddings.pkl')
|
10 |
+
|
11 |
+
with open(embeddings_path, 'rb') as f:
|
12 |
+
data = pickle.load(f)
|
13 |
+
self.embeddings = data['embeddings']
|
14 |
+
# Convert Windows paths to Linux paths
|
15 |
+
self.image_paths = [os.path.normpath(path).replace('\\', '/')
|
16 |
+
for path in data['image_paths']]
|
17 |
+
|
18 |
+
dimension = len(self.embeddings[0])
|
19 |
+
self.index = faiss.IndexFlatL2(dimension)
|
20 |
+
self.index.add(np.array(self.embeddings))
|
21 |
+
|
22 |
+
def search_similar_images(self, query_embedding, top_k=5):
|
23 |
+
distances, indices = self.index.search(np.array([query_embedding]), top_k)
|
24 |
return [self.image_paths[idx] for idx in indices[0]], distances[0]
|