from transformers import AutoTokenizer
from sentence_transformers import SentenceTransformer
from langchain_text_splitters import RecursiveCharacterTextSplitter
from typing import List
import os
from qdrent import store_embeddings
os.environ["TOKENIZERS_PARALLELISM"] = "false"
model_path = './models/e5-base-v2'
# model_path = '/Volumes/AnuragSSD/anurag/Projects/vocrt/models/e5-base-v2'
# Load the e5 embedding model once and reuse it for both indexing and querying.
embedding_model = SentenceTransformer(model_path)
# tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=True)
# def custom_token_text_splitter(
#     text: str,
#     max_tokens: int = 350,
#     overlap_tokens: int = 100,
#     separators: List[str] = ["\n\n", "\n", ". ", "? ", "! ", ", ", " ", "-"],
#     min_chunk_tokens: int = 50,
# ) -> List[str]:
#     def count_tokens(text):
#         return len(tokenizer.encode(text, add_special_tokens=True))
#     def split_text(text_chunk: str, current_separator_index: int) -> List[str]:
#         if current_separator_index >= len(separators):
#             tokens = tokenizer.encode(text_chunk, add_special_tokens=True)
#             if len(tokens) <= max_tokens:
#                 return [text_chunk]
#             else:
#                 chunks = []
#                 step = max_tokens - overlap_tokens
#                 for i in range(0, len(tokens), step):
#                     chunk_tokens = tokens[i:i+max_tokens]
#                     chunk_text = tokenizer.decode(
#                         chunk_tokens, skip_special_tokens=True)
#                     if chunk_text.strip():
#                         chunks.append(chunk_text)
#                 return chunks
#         else:
#             separator = separators[current_separator_index]
#             if not separator:
#                 return split_text(text_chunk, current_separator_index + 1)
#             splits = text_chunk.split(separator)
#             chunks = []
#             temp_chunk = ""
#             for i, split in enumerate(splits):
#                 piece_to_add = separator + split if temp_chunk else split
#                 # Check the token count if we add this piece to temp_chunk
#                 potential_new_chunk = temp_chunk + piece_to_add
#                 token_count = count_tokens(potential_new_chunk)
#                 if token_count <= max_tokens + overlap_tokens:
#                     temp_chunk = potential_new_chunk
#                     if i == len(splits) - 1 and temp_chunk.strip():
#                         chunks.append(temp_chunk.strip())
#                 else:
#                     if temp_chunk.strip():
#                         chunks.append(temp_chunk.strip())
#                     temp_chunk = split
#             final_chunks = []
#             for chunk in chunks:
#                 if count_tokens(chunk) > max_tokens:
#                     final_chunks.extend(split_text(
#                         chunk, current_separator_index + 1))
#                 else:
#                     final_chunks.append(chunk)
#             return final_chunks
#     chunks = split_text(text, 0)
#     if min_chunk_tokens > 0:
#         filtered_chunks = []
#         for chunk in chunks:
#             if count_tokens(chunk) >= min_chunk_tokens or len(chunks) == 1:
#                 filtered_chunks.append(chunk)
#         chunks = filtered_chunks
#     return chunks
async def get_and_store_embeddings(input_texts, session_id, name, title, summary, categories):
    try:
        # chunks = custom_token_text_splitter(
        #     input_texts,
        #     max_tokens=400,
        #     overlap_tokens=100,
        #     separators=["\n\n", "\n", ". ", "? ", "! ", ", ", " "],
        #     min_chunk_tokens=50,
        # )
        # Note: RecursiveCharacterTextSplitter measures chunk_size and
        # chunk_overlap in characters (not tokens) by default.
        text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=400, chunk_overlap=100)
        chunks = text_splitter.split_text(input_texts)
        # # Printing chunks and their token counts
        # for i, chunk in enumerate(chunks):
        #     token_count = len(tokenizer.encode(
        #         chunk, add_special_tokens=False))
        #     print(f"Chunk {i+1} ({token_count} tokens):")
        #     print(chunk.strip())
        #     print("-" * 70)
        # Prepend the e5 "passage:" prefix expected for indexed text
        prefixed_chunks = [f"passage: {chunk.strip()}" for chunk in chunks]
        # Encode the chunks into normalized embedding vectors
        chunk_embeddings = embedding_model.encode(
            prefixed_chunks,
            normalize_embeddings=True
        )
        # print("embeddings : ", chunk_embeddings)
        await store_embeddings(session_id, chunk_embeddings, chunks, name, title, summary, categories)
        return True
    except Exception as e:
        print("Error while chunking and upserting into Qdrant:", e)
        return False
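# Example usage (sketch): index a document end to end. This assumes an asyncio
# entry point and that store_embeddings() from the local qdrent module upserts
# the vectors into a Qdrant collection keyed by session_id; the argument values
# below are illustrative placeholders only.
#
# import asyncio
# asyncio.run(get_and_store_embeddings(
#     input_texts="Full text of the document to index...",
#     session_id="demo-session",
#     name="example.txt",
#     title="Example document",
#     summary="One-line summary of the document",
#     categories=["demo"],
# ))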
def get_query_embeddings(text):
    # e5 models expect the "query: " prefix on search queries
    query = f"query: {text}"
    query_embedding = embedding_model.encode(
        query,
        normalize_embeddings=True
    )
    return query_embedding
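
# Minimal sketch of retrieval scoring, assuming only the local e5 model loaded
# above: embed a query with get_query_embeddings() and score it against passage
# embeddings by dot product (equivalent to cosine similarity, since the vectors
# are normalized). In the full pipeline the query vector would be passed to a
# Qdrant search instead; the sample passages below are made-up placeholders.
if __name__ == "__main__":
    sample_passages = [
        "The invoice total for March was $4,200.",
        "The meeting was rescheduled to next Tuesday.",
    ]
    passage_vectors = embedding_model.encode(
        [f"passage: {p}" for p in sample_passages],
        normalize_embeddings=True,
    )
    query_vector = get_query_embeddings("How much was the March invoice?")
    # Dot product of normalized vectors gives cosine similarity per passage.
    scores = passage_vectors @ query_vector
    for passage, score in zip(sample_passages, scores):
        print(f"{score:.3f}  {passage}")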