|
import logging
import os
import time
from typing import Optional

import torch
from dotenv import load_dotenv
from fastapi import FastAPI, Depends, HTTPException
from huggingface_hub import login
from pydantic import BaseModel
from transformers import AutoTokenizer, AutoModel

from services.openai_service import generate_rag_response
from services.qdrant_searcher import QdrantSearcher
from utils.auth import token_required
from utils.auth_x import x_api_key_auth
|
|
|
|
|
# Load variables from a local .env file into the process environment so the
# os.getenv() lookups below can see them.
load_dotenv()

# Configure the root logger BEFORE the first logging call. A logging.info()
# issued on an unconfigured root logger installs a default handler at WARNING
# level, after which basicConfig(level=INFO) is silently ignored and every
# INFO message in this module is dropped.
logging.basicConfig(level=logging.INFO)

app = FastAPI()

# Point the Hugging Face cache at a writable temporary location.
# NOTE(review): this assignment happens after huggingface_hub/transformers
# were imported; presumably those libraries resolve their cache paths at
# import time, in which case this comes too late -- TODO confirm.
os.environ["HF_HOME"] = "/tmp/huggingface_cache"

hf_home_dir = os.environ["HF_HOME"]
# exist_ok avoids the check-then-create race when several workers start at once.
os.makedirs(hf_home_dir, exist_ok=True)

# Name of the Qdrant collection every search in this module targets.
collection_name = os.getenv('QDRANT_COLLECTION_NAME')
logging.info(f"Collection name: {collection_name}")
|
|
|
|
|
# Authenticate against the Hugging Face Hub once, at import time. The token is
# mandatory: startup aborts when it is missing or when the login call fails.
huggingface_token = os.getenv('HUGGINGFACE_HUB_TOKEN')
if not huggingface_token:
    raise ValueError("Hugging Face token is not set. Please set the HUGGINGFACE_HUB_TOKEN environment variable.")

try:
    login(token=huggingface_token, add_to_git_credential=True)
except Exception as e:
    logging.error(f"Failed to log into Hugging Face Hub: {e}")
    # This code runs while the module is being imported, not while serving a
    # request, so an HTTPException has no client to be rendered for. Raise a
    # plain startup error and keep the original cause attached.
    raise RuntimeError("Failed to log into Hugging Face Hub.") from e
else:
    logging.info("Successfully logged into Hugging Face Hub.")
|
|
|
|
|
# Connection settings for the Qdrant searcher, taken from the environment.
qdrant_url = os.environ.get('QDRANT_URL')
access_token = os.environ.get('QDRANT_ACCESS_TOKEN')

# Both values are required; refuse to start when either one is missing or empty.
if not (qdrant_url and access_token):
    raise ValueError("Qdrant URL or Access Token is not set. Please set the QDRANT_URL and QDRANT_ACCESS_TOKEN environment variables.")
|
|
|
|
|
# Load the embedding tokenizer/model and create the Qdrant searcher once at
# import time. `tokenizer`, `model` and `searcher` are module-level names used
# by embed_text() and by the request handlers below.
try:
    # NOTE(review): cache_folder is computed but not passed to from_pretrained,
    # so it currently has no effect on where the model is cached -- confirm
    # whether it should be supplied as the cache directory.
    cache_folder = os.path.join(hf_home_dir, "transformers_cache")

    _embedding_model_id = 'nomic-ai/nomic-embed-text-v1.5'
    tokenizer = AutoTokenizer.from_pretrained(_embedding_model_id, trust_remote_code=True)
    model = AutoModel.from_pretrained(_embedding_model_id, trust_remote_code=True)

    logging.info("Successfully loaded the model and tokenizer with transformers.")

    # Already at module scope, so a plain assignment creates the global; the
    # former `global searcher` statement was a no-op.
    searcher = QdrantSearcher(qdrant_url=qdrant_url, access_token=access_token)
except Exception as e:
    logging.error(f"Failed to load the model or initialize searcher: {e}")
    # Import-time failure: there is no request to answer, so raise a regular
    # startup error (with the cause chained) instead of an HTTPException.
    raise RuntimeError("Failed to load the custom model or initialize searcher.") from e
|
|
|
|
|
def embed_text(text):
    """Embed ``text`` using the module-level ``tokenizer`` and ``model``.

    The input is tokenized with padding and truncation, passed through the
    model, and the last hidden state is averaged over dimension 1.

    Args:
        text: Input handed directly to the tokenizer. The handlers in this
            module pass a single query string.

    Returns:
        The averaged hidden state converted to a NumPy array.
    """
    inputs = tokenizer(text, padding=True, truncation=True, return_tensors="pt")

    # Inference only: disable autograd so no graph is recorded for the
    # forward pass (the result was being detached afterwards anyway).
    with torch.no_grad():
        outputs = model(**inputs)
        # NOTE(review): the mean covers every position along dim 1, which
        # includes padding positions if several texts of different lengths are
        # ever batched -- confirm whether masked pooling is wanted.
        embeddings = outputs.last_hidden_state.mean(dim=1)

    return embeddings.detach().numpy()
|
|
|
|
|
class SearchDocumentsRequest(BaseModel):
    """Request body for ``POST /api/search-documents``."""

    # Free-text query that is embedded and searched for.
    query: str
    # Result limit forwarded to the searcher; defaults to 3.
    limit: int = 3
    # Forwarded to the searcher as `file_id` when set. Declared Optional so the
    # annotation matches the None default and an explicit JSON null is accepted
    # in addition to omitting the field.
    file_id: Optional[str] = None
|
|
|
class GenerateRAGRequest(BaseModel):
    """Request body for ``POST /api/generate-rag-response``."""

    # Query that is embedded for the document search and also passed to the
    # RAG response generator.
    search_query: str
    # Forwarded to the searcher as `file_id` when set. Declared Optional so the
    # annotation matches the None default and an explicit JSON null is accepted
    # in addition to omitting the field.
    file_id: Optional[str] = None
|
|
|
class XApiKeyRequest(BaseModel):
    """Request body shared by the two ``/v1`` endpoints that use x-api-key auth."""

    # Organization identifier supplied by the caller; the handlers log it.
    organization_id: str
    # User identifier supplied by the caller and passed to the searcher.
    user_id: str
    # Query that is embedded for the document search (and, for the RAG
    # endpoint, passed to the response generator).
    search_query: str
    # Forwarded to the searcher as `file_id`. Declared Optional so the
    # annotation matches the None default and an explicit JSON null is accepted
    # in addition to omitting the field.
    file_id: Optional[str] = None
|
|
|
|
|
@app.get("/")
async def root():
    """Serve the static welcome payload at the API root."""
    welcome_payload = {
        "message": "Welcome to the Search and RAG API!, go to relevant address for API request",
    }
    return welcome_payload
|
|
|
|
|
@app.post("/api/search-documents")
async def search_documents(
    body: SearchDocumentsRequest,
    credentials: tuple = Depends(token_required)
):
    """Embed the query and search the configured collection for the token's user.

    Args:
        body: Query text, result limit and optional ``file_id``.
        credentials: ``(customer_id, user_id)`` pair supplied by the
            ``token_required`` dependency.

    Returns:
        A ``(hits, time_taken)`` pair: the searcher's hits and the elapsed
        wall-clock seconds for the request.

    Raises:
        HTTPException: 401 when either credential value is empty; 500 when the
            searcher reports an error or an unexpected exception occurs.
    """
    customer_id, user_id = credentials
    start_time = time.time()
    if not customer_id or not user_id:
        logging.error("Failed to extract customer_id or user_id from the JWT token.")
        raise HTTPException(status_code=401, detail="Invalid token: missing customer_id or user_id")

    logging.info("Received request to search documents")
    try:
        logging.info("Starting document search")

        query_embedding = embed_text(body.query)
        # Logged (not printed) in the same format the /v1 endpoints use.
        logging.info(f'search query {body.query}')

        logging.info("Performing search using the precomputed embeddings")
        # Exactly one search is executed. Previously there was no `else`, so
        # the file-filtered search was not reliably the one whose result was
        # used.
        if body.file_id:
            hits, error = searcher.search_documents(collection_name, query_embedding, user_id, body.limit, file_id=body.file_id)
        else:
            hits, error = searcher.search_documents(collection_name, query_embedding, user_id, body.limit)

        if error:
            logging.error(f"Search documents error: {error}")
            raise HTTPException(status_code=500, detail=error)

        time_taken = time.time() - start_time
        return hits, time_taken
    except HTTPException:
        # Deliberate HTTP errors raised above must reach the client unchanged
        # rather than being re-wrapped by the generic handler below.
        raise
    except Exception as e:
        logging.exception(f"Unexpected error: {e}")
        raise HTTPException(status_code=500, detail=str(e)) from e
|
|
|
|
|
@app.post("/api/generate-rag-response")
async def generate_rag_response_api(
    body: GenerateRAGRequest,
    credentials: tuple = Depends(token_required)
):
    """Search documents for the query, then generate a RAG response from the hits.

    Args:
        body: Search query and optional ``file_id``.
        credentials: ``(customer_id, user_id)`` pair supplied by the
            ``token_required`` dependency.

    Returns:
        ``{"response": <generated response>}``.

    Raises:
        HTTPException: 401 when either credential value is empty; 500 when the
            searcher or the response generator reports an error, or an
            unexpected exception occurs.
    """
    customer_id, user_id = credentials
    start_time = time.time()
    if not customer_id or not user_id:
        logging.error("Failed to extract customer_id or user_id from the JWT token.")
        raise HTTPException(status_code=401, detail="Invalid token: missing customer_id or user_id")

    logging.info("Received request to generate RAG response")

    try:
        search_start_time = time.time()
        logging.info("Starting document search")

        query_embedding = embed_text(body.search_query)
        # Logged (not printed) in the same format the /v1 endpoints use.
        logging.info(f'search query {body.search_query}')

        if body.file_id:
            hits, error = searcher.search_documents(collection_name, query_embedding, user_id, file_id=body.file_id)
        else:
            hits, error = searcher.search_documents(collection_name, query_embedding, user_id)

        if error:
            logging.error(f"Search documents error: {error}")
            raise HTTPException(status_code=500, detail=error)

        logging.info("Generating RAG response")
        search_time_taken = time.time() - search_start_time

        rag_start_time = time.time()
        response, error = generate_rag_response(hits, body.search_query)
        rag_time_taken = time.time() - rag_start_time

        total_time = time.time() - start_time
        # Timings are logged before the error check so they are recorded even
        # when generation fails.
        logging.info(f"Search time: {search_time_taken}, RAG time: {rag_time_taken}, Total time: {total_time}")

        if error:
            logging.error(f"Generate RAG response error: {error}")
            raise HTTPException(status_code=500, detail=error)

        return {"response": response}
    except HTTPException:
        # Deliberate HTTP errors raised above must reach the client unchanged
        # rather than being re-wrapped by the generic handler below.
        raise
    except Exception as e:
        logging.exception(f"Unexpected error: {e}")
        raise HTTPException(status_code=500, detail=str(e)) from e
|
|
|
@app.post("/api/search-documents/v1")
async def search_documents_x_api_key(
    body: XApiKeyRequest,
    authorized: bool = Depends(x_api_key_auth)
):
    """Search documents for a caller authenticated through ``x_api_key_auth``.

    The user is taken from the request body rather than from a token, and the
    search is issued with a fixed limit of 3.

    Args:
        body: Organization id, user id, search query and optional ``file_id``.
        authorized: Result of the ``x_api_key_auth`` dependency.

    Returns:
        The hits returned by the searcher.

    Raises:
        HTTPException: 401 when ``authorized`` is falsy; 500 when the searcher
            reports an error or an unexpected exception occurs.
    """
    if not authorized:
        raise HTTPException(status_code=401, detail="Unauthorized")

    start_time = time.time()
    organization_id = body.organization_id
    user_id = body.user_id
    file_id = body.file_id

    logging.info(f'search query {body.search_query}')
    logging.info(f"organization_id: {organization_id}, user_id: {user_id}")
    logging.info("Received request to search documents with x-api-key auth")
    try:
        logging.info("Starting document search")

        query_embedding = embed_text(body.search_query)

        hits, error = searcher.search_documents(collection_name, query_embedding, user_id, limit=3, file_id=file_id)

        if error:
            logging.error(f"Search documents error: {error}")
            raise HTTPException(status_code=500, detail=error)

        logging.info(f"Document search completed with {len(hits)} hits")
        end_time = time.time()
        logging.info(f"Time taken: {end_time - start_time}")
        return hits
    except HTTPException:
        # Deliberate HTTP errors raised above must reach the client unchanged
        # rather than being re-wrapped by the generic handler below.
        raise
    except Exception as e:
        logging.exception(f"Unexpected error: {e}")
        raise HTTPException(status_code=500, detail=str(e)) from e
|
|
|
@app.post("/api/generate-rag-response/v1")
async def generate_rag_response_x_api_key(
    body: XApiKeyRequest,
    authorized: bool = Depends(x_api_key_auth)
):
    """Generate a RAG response for a caller authenticated through ``x_api_key_auth``.

    The user is taken from the request body rather than from a token. Documents
    are searched first, and the hits are passed to the response generator
    together with the query.

    Args:
        body: Organization id, user id, search query and optional ``file_id``.
        authorized: Result of the ``x_api_key_auth`` dependency.

    Returns:
        ``{"response": <generated response>}``.

    Raises:
        HTTPException: 401 when ``authorized`` is falsy; 500 when the searcher
            or the response generator reports an error, or an unexpected
            exception occurs.
    """
    if not authorized:
        raise HTTPException(status_code=401, detail="Unauthorized")

    start_time = time.time()
    organization_id = body.organization_id
    user_id = body.user_id
    file_id = body.file_id

    logging.info(f'search query {body.search_query}')
    logging.info(f"organization_id: {organization_id}, user_id: {user_id}")
    logging.info("Received request to generate RAG response with x-api-key auth")
    try:
        logging.info("Starting document search")

        query_embedding = embed_text(body.search_query)

        hits, error = searcher.search_documents(collection_name, query_embedding, user_id, file_id=file_id)

        if error:
            logging.error(f"Search documents error: {error}")
            raise HTTPException(status_code=500, detail=error)

        logging.info("Generating RAG response")

        response, error = generate_rag_response(hits, body.search_query)

        if error:
            logging.error(f"Generate RAG response error: {error}")
            raise HTTPException(status_code=500, detail=error)

        end_time = time.time()
        logging.info(f"Time taken: {end_time - start_time}")
        return {"response": response}
    except HTTPException:
        # Deliberate HTTP errors raised above must reach the client unchanged
        # rather than being re-wrapped by the generic handler below.
        raise
    except Exception as e:
        logging.exception(f"Unexpected error: {e}")
        raise HTTPException(status_code=500, detail=str(e)) from e
|
|
|
|
|
|
|
if __name__ == '__main__':
    # Direct execution only: uvicorn is imported here so that importing this
    # module elsewhere does not require it at import time.
    import uvicorn

    # Serve the app on every interface, port 8000.
    server_options = {"host": '0.0.0.0', "port": 8000}
    uvicorn.run(app, **server_options)
|
|