# RAG Chatbot backend (FastAPI) — deployed as a Hugging Face Space.
import os

# Must be set before any Hugging Face library reads it — prevents permission
# errors when model files are cached inside an HF Space container.
os.environ["HF_HOME"] = "/tmp/huggingface"

import uuid
from typing import List

import fitz  # PyMuPDF
from anthropic import Anthropic
from dotenv import load_dotenv
from fastapi import FastAPI, UploadFile, File, Form, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import HTMLResponse, JSONResponse
from fastapi.staticfiles import StaticFiles
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.vectorstores import Chroma
from langchain_core.documents import Document
from langchain_text_splitters import RecursiveCharacterTextSplitter
# ---- Load API Keys ----
load_dotenv()  # pull ANTHROPIC_API_KEY from a local .env file when present
ANTHROPIC_API_KEY = os.getenv("ANTHROPIC_API_KEY")  # may be None; checked per-request in chat()
CLAUDE_MODEL = "claude-3-haiku-20240307"
# ---- App Init ----
app = FastAPI()
# Wide-open CORS is acceptable for a demo Space; tighten origins for production.
app.add_middleware(CORSMiddleware, allow_origins=["*"], allow_methods=["*"], allow_headers=["*"])
# Create and mount the static directory using one absolute path: the original
# code created `<module dir>/static` but mounted the CWD-relative "static",
# which breaks whenever the server is started from another directory.
static_dir = os.path.join(os.path.dirname(__file__), "static")
os.makedirs(static_dir, exist_ok=True)
app.mount("/static", StaticFiles(directory=static_dir), name="static")
# ---- In-Memory Stores ----
# All state is process-local: restarting the server drops every session.
db_store = {}               # session_id -> Chroma vector DB
chat_store = {}             # session_id -> list of {"role", "text"} messages
general_chat_sessions = {}  # session_id -> True for general (no-PDF) sessions
# ---- Utility Functions ---- | |
def extract_text_from_pdf(file) -> str:
    """Extract text from every page of an uploaded PDF.

    The original implementation read only page 0 and leaked the open
    Document; the whole file is now indexed (so retrieval can see all
    pages) and the Document is closed via its context manager.

    Args:
        file: a FastAPI UploadFile whose ``.file`` stream holds PDF bytes.

    Returns:
        The concatenated text of all pages, separated by newlines.
    """
    with fitz.open(stream=file.file.read(), filetype="pdf") as doc:
        return "\n".join(page.get_text() for page in doc)
def build_vector_db(text: str, collection_name: str) -> Chroma:
    """Chunk *text*, embed it, and store it in an in-memory Chroma collection.

    Args:
        text: raw document text to index.
        collection_name: unique collection id (the session id).

    Returns:
        The populated Chroma vector store.
    """
    splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=50)
    docs = splitter.create_documents([text])
    # NOTE(review): the embedding model is re-loaded on every upload; consider
    # caching a single HuggingFaceEmbeddings instance at module level.
    embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
    return Chroma.from_documents(docs, embeddings, collection_name=collection_name)
def retrieve_context(vectordb: Chroma, query: str, k: int = 3) -> str:
    """Return the top-*k* chunks most similar to *query*, joined by blank lines.

    Args:
        vectordb: the session's Chroma store.
        query: the user prompt to search with.
        k: number of chunks to retrieve (default 3).
    """
    docs = vectordb.similarity_search(query, k=k)
    return "\n\n".join(d.page_content for d in docs)
def create_session(is_pdf: bool = True) -> str:
    """Create a unique session id and register an empty chat history.

    Args:
        is_pdf: when False the session is flagged as a general (no-PDF) chat.

    Returns:
        The new session id (a UUID4 string).
    """
    sid = str(uuid.uuid4())
    chat_store[sid] = []
    if not is_pdf:
        general_chat_sessions[sid] = True
    return sid
def append_chat(session_id: str, role: str, msg: str):
    """Append one message to a session's history.

    Raises KeyError for an unknown session; callers validate the id first.
    """
    chat_store[session_id].append({"role": role, "text": msg})
def get_chat(session_id: str):
    """Return the session's message list, or [] for an unknown session."""
    return chat_store.get(session_id, [])
def delete_session(session_id: str):
    """Drop all state for a session; safe to call on an unknown id."""
    chat_store.pop(session_id, None)
    db_store.pop(session_id, None)
    general_chat_sessions.pop(session_id, None)
# ---- API Endpoints ---- | |
# NOTE(review): the route decorators on the endpoints below appear to have
# been lost when this file was extracted — without them FastAPI never
# registers the handlers. Paths are restored from the function names;
# confirm them against the frontend.
@app.get("/", response_class=HTMLResponse)
async def get_home():
    """Serve the static frontend, or a plain landing page if it is missing."""
    index_path = os.path.join(os.path.dirname(__file__), "static", "index.html")
    try:
        with open(index_path) as f:  # original leaked this file handle
            return HTMLResponse(content=f.read())
    except FileNotFoundError:
        return HTMLResponse(content="<h1>RAG Chatbot API</h1><p>Upload a PDF or start a chat.</p>")
@app.post("/start_general_chat")  # NOTE(review): decorator restored — confirm path with frontend
async def start_general_chat():
    """Start a general chat session that has no PDF context."""
    session_id = create_session(is_pdf=False)
    return {"session_id": session_id, "message": "General chat session started."}
@app.post("/upload_pdf")  # NOTE(review): decorator restored — confirm path with frontend
async def upload_pdf(file: UploadFile = File(...), current_session_id: str = Form(None)):
    """Index an uploaded PDF, reusing an existing session when one is supplied.

    Passing a known ``current_session_id`` keeps that session's chat history
    and upgrades it from general mode to PDF mode.
    """
    text = extract_text_from_pdf(file)
    if current_session_id and current_session_id in chat_store:
        session_id = current_session_id
        general_chat_sessions.pop(session_id, None)  # upgrade to PDF mode
    else:
        session_id = create_session()
    db_store[session_id] = build_vector_db(text, collection_name=session_id)
    return {"session_id": session_id, "message": "PDF indexed."}
@app.post("/chat")  # NOTE(review): decorator restored — confirm path with frontend
async def chat(session_id: str = Form(...), prompt: str = Form(...)):
    """Answer a prompt, prepending retrieved PDF context when available.

    General sessions send the prompt as-is; PDF sessions retrieve the top
    chunks from the session's vector DB and wrap them around the question.
    NOTE(review): prior turns are stored but never sent to the model, so each
    answer is stateless — intentional? confirm.
    """
    is_general_chat = session_id in general_chat_sessions
    is_pdf_chat = session_id in db_store
    if not is_general_chat and not is_pdf_chat:
        return {"error": "Invalid session ID"}
    # Check preconditions before mutating history: the original recorded the
    # user message even when the API key was missing and the call never ran.
    if not ANTHROPIC_API_KEY:
        return JSONResponse(status_code=500, content={"error": "Missing ANTHROPIC_API_KEY environment variable"})
    append_chat(session_id, "user", prompt)
    client = Anthropic(api_key=ANTHROPIC_API_KEY.strip())
    if is_general_chat:
        content = prompt  # no context, just send the prompt
    else:
        context = retrieve_context(db_store[session_id], prompt)
        content = f"Context:\n{context}\n\nQuestion:\n{prompt}"
    # Single call site replaces the two duplicated messages.create() blocks.
    response = client.messages.create(
        model=CLAUDE_MODEL,
        max_tokens=512,
        temperature=0.5,
        messages=[{"role": "user", "content": content}],
    )
    answer = response.content[0].text
    append_chat(session_id, "bot", answer)
    return {"answer": answer, "chat_history": get_chat(session_id)}
@app.post("/end_chat")  # NOTE(review): decorator restored — confirm path with frontend
async def end_chat(session_id: str = Form(...)):
    """End a session and delete all data associated with it."""
    delete_session(session_id)
    return {"message": "Session cleared."}