# Step 1: Import required packages (install first with: pip install datasets sentence-transformers scikit-learn torch pandas)
import configparser
import os

import numpy as np
import pandas as pd
import torch
from datasets import load_dataset
from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity
class TextSearchEngine:
    def __init__(self, embeddings_csv_path):
        self.embeddings_csv_path = embeddings_csv_path
        # Commonly used workaround: clearing torch's class-registry path keeps
        # web UI file watchers (e.g. Streamlit) from erroring while scanning it.
        torch.classes.__path__ = []
    def load_data_and_model(self):
        # Load a sample dataset (IMDB / Stanford movie review dataset), first 1000 examples
        dataset = load_dataset('imdb', split='train[:1000]')
        df = pd.DataFrame(dataset)[['text', 'label']]

        # Load a small model that fits in 4GB VRAM (384-dimensional embeddings)
        model = SentenceTransformer('all-MiniLM-L6-v2')
        return df, model
    def generate_embeddings(self, df, model, overwrite=False):
        if not os.path.exists(self.embeddings_csv_path) or overwrite:
            texts = df['text'].tolist()
            # Generate embeddings in batches for efficiency
            embeddings = model.encode(texts, batch_size=32, show_progress_bar=True)
            # Convert each embedding to a comma-separated string for CSV storage
            df['embedding'] = [','.join(map(str, emb)) for emb in embeddings]
            df.to_csv(self.embeddings_csv_path, index=False)
        return df
    def semantic_search(self, query, model, top_k=5):
        # Load stored embeddings from CSV
        df = pd.read_csv(self.embeddings_csv_path)
        # Convert string embeddings back to numpy arrays
        df['embedding'] = df['embedding'].apply(lambda x: np.fromstring(x, sep=','))

        # Encode the query with the same model
        query_embedding = model.encode([query])

        # Compute cosine similarity between the query and every stored embedding
        embeddings_matrix = np.vstack(df['embedding'].values)
        similarities = cosine_similarity(query_embedding, embeddings_matrix).flatten()

        # Attach scores and return the top_k most similar reviews
        df['similarity'] = similarities
        results = df.sort_values('similarity', ascending=False).head(top_k)
        return results[['text', 'similarity', 'label']]
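
# Hypothetical usage sketch (not part of the original script): once the embeddings
# CSV exists, searches can reuse it without reloading the IMDB dataset; only the
# SentenceTransformer model and the CSV path are needed. Path and query are illustrative.
#
#   engine = TextSearchEngine('embeddings.csv')             # assumed path
#   model = SentenceTransformer('all-MiniLM-L6-v2')
#   print(engine.semantic_search('a heartwarming family drama', model, top_k=3))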
# Execution flow
if __name__ == "__main__":
    config = configparser.ConfigParser()
    config.read('config.cfg')
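    # Assumed layout of config.cfg (section and key names come from the code below;
    # the values here are illustrative, not from the original source):
    #
    #   [SERVER]
    #   embeddings_csv_path = embeddings.csv
    #
    #   [TEST]
    #   query = a feel-good movie about friendship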
    embeddings_csv_path = config['SERVER']['embeddings_csv_path']
    text_search_engine_manager = TextSearchEngine(embeddings_csv_path)

    # Generate and save embeddings (skipped if the CSV already exists and overwrite=False)
    df, model = text_search_engine_manager.load_data_and_model()
    text_search_engine_manager.generate_embeddings(df, model, overwrite=False)

    # Example search
    query = config['TEST']['query']
    results = text_search_engine_manager.semantic_search(query, model)
    print('Results ->', results)