# VocRT/embeddings.py
import os
from typing import List

from langchain_text_splitters import RecursiveCharacterTextSplitter
from sentence_transformers import SentenceTransformer
# AutoTokenizer (and List above) are only needed by the commented-out token-aware splitter below.
from transformers import AutoTokenizer

from qdrent import store_embeddings

# Silence the HuggingFace tokenizers fork/parallelism warning.
os.environ["TOKENIZERS_PARALLELISM"] = "false"
# Local e5-base-v2 model, used for both passage and query embeddings.
model_path = './models/e5-base-v2'
# model_path = '/Volumes/AnuragSSD/anurag/Projects/vocrt/models/e5-base-v2'

# Load the SentenceTransformer model once and reuse it for every encode call.
embedding_model = SentenceTransformer(model_path)
# tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=True)
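# Token-aware splitter kept below for reference; it is currently disabled in favour of
# the RecursiveCharacterTextSplitter used inside get_and_store_embeddings.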
# def custom_token_text_splitter(
# text: str,
# max_tokens: int = 350,
# overlap_tokens: int = 100,
# separators: List[str] = ["\n\n", "\n", ". ", "? ", "! ", ", ", " ", "-"],
# min_chunk_tokens: int = 50,
# ) -> List[str]:
# def count_tokens(text):
# return len(tokenizer.encode(text, add_special_tokens=True))
# def split_text(text_chunk: str, current_separator_index: int) -> List[str]:
# if current_separator_index >= len(separators):
# tokens = tokenizer.encode(text_chunk, add_special_tokens=True)
# if len(tokens) <= max_tokens:
# return [text_chunk]
# else:
# chunks = []
# step = max_tokens - overlap_tokens
# for i in range(0, len(tokens), step):
# chunk_tokens = tokens[i:i+max_tokens]
# chunk_text = tokenizer.decode(
# chunk_tokens, skip_special_tokens=True)
# if chunk_text.strip():
# chunks.append(chunk_text)
# return chunks
# else:
# separator = separators[current_separator_index]
# if not separator:
# return split_text(text_chunk, current_separator_index + 1)
# splits = text_chunk.split(separator)
# chunks = []
# temp_chunk = ""
# for i, split in enumerate(splits):
# piece_to_add = separator + split if temp_chunk else split
# # Check the token count if we add this piece to temp_chunk
# potential_new_chunk = temp_chunk + piece_to_add
# token_count = count_tokens(potential_new_chunk)
# if token_count <= max_tokens + overlap_tokens:
# temp_chunk = potential_new_chunk
# if i == len(splits) - 1 and temp_chunk.strip():
# chunks.append(temp_chunk.strip())
# else:
# if temp_chunk.strip():
# chunks.append(temp_chunk.strip())
# temp_chunk = split
# final_chunks = []
# for chunk in chunks:
# if count_tokens(chunk) > max_tokens:
# final_chunks.extend(split_text(
# chunk, current_separator_index + 1))
# else:
# final_chunks.append(chunk)
# return final_chunks
# chunks = split_text(text, 0)
# if min_chunk_tokens > 0:
# filtered_chunks = []
# for chunk in chunks:
# if count_tokens(chunk) >= min_chunk_tokens or len(chunks) == 1:
# filtered_chunks.append(chunk)
# chunks = filtered_chunks
# return chunks
async def get_and_store_embeddings(input_texts, session_id, name, title, summary, categories):
    """Split input_texts into chunks, embed them with e5-base-v2, and upsert them into Qdrant."""
    try:
# chunks = custom_token_text_splitter(
# input_texts,
# max_tokens=400,
# overlap_tokens=100,
# separators=["\n\n", "\n", ". ", "? ", "! ", ", ", " "],
# min_chunk_tokens=50,
# )
        # Character-based splitting: chunk_size and chunk_overlap are counted in
        # characters here, not tokens (unlike the disabled token-aware splitter above).
        text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=400, chunk_overlap=100)
        chunks = text_splitter.split_text(input_texts)
# # Printing chunks and their token counts
# for i, chunk in enumerate(chunks):
# token_count = len(tokenizer.encode(
# chunk, add_special_tokens=False))
# print(f"Chunk {i+1} ({token_count} tokens):")
# print(chunk.strip())
# print("-" * 70)
        # e5 models expect document chunks to carry the "passage: " prefix.
        prefixed_chunks = [f"passage: {chunk.strip()}" for chunk in chunks]
# Encoding the chunks
chunk_embeddings = embedding_model.encode(
prefixed_chunks,
normalize_embeddings=True
)
# print("embeddings : ", chunk_embeddings)
await store_embeddings(session_id, chunk_embeddings, chunks, name, title, summary, categories)
return True
    except Exception as e:
        print("Error while chunking text and upserting embeddings into Qdrant:", e)
        return False
def get_query_embeddings(text):
    """Embed a search query using the "query: " prefix expected by e5 models."""
    # e5 models are trained with the exact prefix "query: " (no space before the colon).
    query = f"query: {text}"
    query_embedding = embedding_model.encode(
        query,
        normalize_embeddings=True
    )
    return query_embedding
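

# Minimal usage sketch (illustrative only): the session metadata values below are made-up
# placeholders, and the sketch assumes the e5-base-v2 model exists at model_path and that
# store_embeddings can reach a running Qdrant instance.
if __name__ == "__main__":
    import asyncio

    sample_text = (
        "Qdrant is a vector database. It stores embeddings and supports similarity search "
        "over them, which is how retrieved context is matched to a user's query."
    )
    stored = asyncio.run(get_and_store_embeddings(
        sample_text,
        session_id="demo-session",
        name="demo.txt",
        title="Demo document",
        summary="A tiny demo document about Qdrant.",
        categories=["demo"],
    ))
    print("stored:", stored)
    print("query embedding shape:", get_query_embeddings("What is Qdrant?").shape)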