TextSearchEngine / text_search_engine.py
DanielIglesias97's picture
We have removed the flask application and replaced it with a
6aa0bc7
# Step 1: Import required packages
import configparser
from datasets import load_dataset
from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity
import torch
import numpy as np
import os
import pandas as pd
class TextSearchEngine():
    """Semantic text search backed by a CSV file of sentence embeddings.

    Embeddings are computed with a sentence-embedding model, stored in
    ``embeddings_csv_path`` as comma-separated strings (one row per text),
    and read back from that file on every query.
    """

    def __init__(self, embeddings_csv_path):
        """Remember where the embeddings CSV lives.

        Args:
            embeddings_csv_path: Path of the CSV file used to persist the
                texts, labels and their embeddings.
        """
        self.embeddings_csv_path = embeddings_csv_path
        # NOTE(review): presumably the usual workaround that stops module
        # watchers (e.g. the one in Streamlit) from failing when they
        # introspect ``torch.classes`` -- confirm before removing.
        torch.classes.__path__ = []

    def load_data_and_model(self, split='train[:1000]',
                            model_name='all-MiniLM-L6-v2'):
        """Load the IMDB sample data and the sentence-embedding model.

        Args:
            split: Split expression passed to ``load_dataset``. The default
                keeps the first 1000 training examples.
            model_name: Name of the SentenceTransformer model to load. The
                default is the small model chosen by the original author
                (noted there as 384-dimensional and fitting in 4GB VRAM).

        Returns:
            A ``(df, model)`` tuple: a DataFrame with the ``text`` and
            ``label`` columns of the dataset, and the loaded model.
        """
        dataset = load_dataset('imdb', split=split)
        df = pd.DataFrame(dataset)[['text', 'label']]
        model = SentenceTransformer(model_name)
        return df, model

    def generate_embeddings(self, df, model, overwrite=False, batch_size=32):
        """Compute embeddings for ``df['text']`` and persist them to the CSV.

        Nothing is computed when the CSV already exists and ``overwrite`` is
        false; in that case ``df`` is returned untouched.

        Args:
            df: DataFrame with at least a ``text`` column. When embeddings
                are generated, an ``embedding`` column is added to it in
                place.
            model: Object exposing ``encode(texts, batch_size=...,
                show_progress_bar=...)`` that returns one vector per text.
            overwrite: Regenerate the CSV even if it already exists.
            batch_size: Batch size forwarded to ``model.encode``.

        Returns:
            The same ``df`` object that was passed in.
        """
        if os.path.exists(self.embeddings_csv_path) and not overwrite:
            return df

        texts = df['text'].tolist()
        # Encode in batches for efficiency.
        embeddings = model.encode(texts, batch_size=batch_size,
                                  show_progress_bar=True)
        # Each vector is flattened to "v0,v1,..." so it fits in one CSV cell.
        df['embedding'] = [','.join(map(str, emb)) for emb in embeddings]

        # Create the target directory on demand so that a configured path
        # inside a not-yet-existing folder does not make ``to_csv`` fail.
        parent_dir = os.path.dirname(self.embeddings_csv_path)
        if parent_dir:
            os.makedirs(parent_dir, exist_ok=True)
        df.to_csv(self.embeddings_csv_path, index=False)
        return df

    def semantic_search(self, query, model, top_k=5):
        """Return the stored texts most similar to ``query``.

        Args:
            query: Free-text query string.
            model: Object exposing ``encode(list_of_texts)``; it should be
                the same model that produced the stored embeddings.
            top_k: Maximum number of results to return.

        Returns:
            A DataFrame with columns ``text``, ``similarity`` and ``label``,
            ordered by decreasing cosine similarity. It is empty when the
            embeddings CSV contains no rows.

        Raises:
            ValueError: If ``top_k`` is negative.
        """
        # ``DataFrame.head`` treats a negative n as "all but the last |n|
        # rows", which is never what a caller asking for top-k wants.
        if top_k < 0:
            raise ValueError("top_k must be a non-negative integer")

        df = pd.read_csv(self.embeddings_csv_path)
        # ``np.vstack`` cannot stack zero arrays; with nothing indexed there
        # is nothing to rank, so answer with an empty result instead.
        if df.empty:
            return pd.DataFrame(columns=['text', 'similarity', 'label'])

        # Parse the "v0,v1,..." cells back into one 2-D matrix (row per text).
        embeddings_matrix = np.vstack(
            [np.fromstring(cell, sep=',') for cell in df['embedding']]
        )
        query_embedding = model.encode([query])
        similarities = cosine_similarity(
            query_embedding, embeddings_matrix
        ).flatten()

        df['similarity'] = similarities
        results = df.sort_values('similarity', ascending=False).head(top_k)
        return results[['text', 'similarity', 'label']]
# Execution flow
if __name__ == "__main__":
    # ConfigParser.read() silently skips files it cannot open and returns the
    # list of files it actually parsed. Checking that list turns a missing
    # config file into a clear error instead of a KeyError on 'SERVER'.
    config_path = 'config.cfg'
    config = configparser.ConfigParser()
    if not config.read(config_path):
        raise FileNotFoundError(
            f"Could not read configuration file: {config_path!r}"
        )

    # Location of the CSV that stores the texts and their embeddings.
    embeddings_csv_path = config['SERVER']['embeddings_csv_path']
    text_search_engine_manager = TextSearchEngine(embeddings_csv_path)

    # Generate and save embeddings (skipped when the CSV already exists).
    df, model = text_search_engine_manager.load_data_and_model()
    text_search_engine_manager.generate_embeddings(df, model, overwrite=False)

    # Example search using the query configured in the TEST section.
    query = config['TEST']['query']
    results = text_search_engine_manager.semantic_search(query, model)
    print('Results -> ', results)