harishvijayasarangan05 committed
Commit 9a84d3e · verified · 1 Parent(s): 964d67a

Update main.py

Files changed (1):
  main.py  +27 -41
main.py CHANGED
@@ -1,19 +1,19 @@
 import os
+os.environ["HF_HOME"] = "/tmp/huggingface"  # Prevent permission error in HF Spaces
+
 import fitz  # PyMuPDF
 import uuid
 from fastapi import FastAPI, UploadFile, File, Form, Request
 from fastapi.middleware.cors import CORSMiddleware
 from fastapi.staticfiles import StaticFiles
 from fastapi.responses import HTMLResponse, JSONResponse
-from pydantic import BaseModel
-from typing import List
 from dotenv import load_dotenv
+from typing import List
 
 from langchain_text_splitters import RecursiveCharacterTextSplitter
 from langchain_community.vectorstores import Chroma
 from langchain_community.embeddings import HuggingFaceEmbeddings
 from langchain_core.documents import Document
-
 from anthropic import Anthropic
 
 # ---- Load API Keys ----
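A note on the new HF_HOME line: redirecting the cache only helps if the variable is set before anything imports huggingface_hub or sentence-transformers, because those libraries pick up their cache directory when they load. Below is a minimal sketch of the intended ordering; the extra SENTENCE_TRANSFORMERS_HOME line is an assumption and is not part of this commit.

    import os

    # Point the Hugging Face cache at a path that is writable inside the Space.
    os.environ["HF_HOME"] = "/tmp/huggingface"
    # Assumption: some setups also redirect the sentence-transformers cache explicitly.
    os.environ.setdefault("SENTENCE_TRANSFORMERS_HOME", "/tmp/huggingface")

    # Import cache-dependent libraries only after the variables are set.
    from langchain_community.embeddings import HuggingFaceEmbeddings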
@@ -25,18 +25,16 @@ CLAUDE_MODEL = "claude-3-haiku-20240307"
 app = FastAPI()
 app.add_middleware(CORSMiddleware, allow_origins=["*"], allow_methods=["*"], allow_headers=["*"])
 
-# Create static directory if it doesn't exist
+# Mount static directory (if needed for frontend)
 os.makedirs(os.path.join(os.path.dirname(__file__), "static"), exist_ok=True)
-
-# Mount static files directory
 app.mount("/static", StaticFiles(directory="static"), name="static")
 
 # ---- In-Memory Stores ----
-db_store = {}
-chat_store = {}
-general_chat_sessions = {}
+db_store = {}               # session_id → Chroma vector DB
+chat_store = {}             # session_id → chat messages
+general_chat_sessions = {}  # session_id → general (no PDF) flag
 
-# ---- Utils ----
+# ---- Utility Functions ----
 
 def extract_text_from_pdf(file) -> str:
     """Extracts text from the first page of a PDF."""
@@ -47,9 +45,7 @@ def build_vector_db(text: str, collection_name: str) -> Chroma:
     """Chunks, embeds, and stores text in ChromaDB."""
     splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=50)
     docs = splitter.create_documents([text])
-
-    # Using a standard model that should be available publicly
-    embeddings = HuggingFaceEmbeddings(model_name="all-MiniLM-L6-v2")
+    embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
     vectordb = Chroma.from_documents(docs, embeddings, collection_name=collection_name)
     return vectordb
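Two helpers used around this function are not fully visible in the diff: the body of extract_text_from_pdf above and retrieve_context, which chat() calls later. The sketches below are illustrative only; the PyMuPDF calls, the UploadFile handling, the k value, and the joining format are assumptions rather than the committed code.

    import fitz  # PyMuPDF

    def extract_text_from_pdf(file) -> str:
        """Sketch: read the uploaded PDF bytes and return the first page's text."""
        data = file.file.read()                       # FastAPI's UploadFile wraps a file-like object
        doc = fitz.open(stream=data, filetype="pdf")  # open from bytes rather than a path
        try:
            return doc[0].get_text() if doc.page_count else ""
        finally:
            doc.close()

    def retrieve_context(vectordb, query: str, k: int = 3) -> str:
        """Sketch: return the k chunks most similar to the query, joined into one string."""
        docs = vectordb.similarity_search(query, k=k)
        return "\n\n".join(d.page_content for d in docs)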
 
@@ -62,11 +58,8 @@ def create_session(is_pdf: bool = True) -> str:
     """Creates a new unique session ID."""
     sid = str(uuid.uuid4())
     chat_store[sid] = []
-
-    # Track if this is a general chat session (without PDF)
     if not is_pdf:
         general_chat_sessions[sid] = True
-
     return sid
 
 def append_chat(session_id: str, role: str, msg: str):
@@ -80,62 +73,55 @@ def delete_session(session_id: str):
     db_store.pop(session_id, None)
     general_chat_sessions.pop(session_id, None)
 
-# ---- API Routes ----
+# ---- API Endpoints ----
 
 @app.get("/", response_class=HTMLResponse)
 async def get_home():
-    with open(os.path.join(os.path.dirname(__file__), "static", "index.html")) as f:
-        return f.read()
+    try:
+        with open(os.path.join(os.path.dirname(__file__), "static", "index.html")) as f:
+            return f.read()
+    except FileNotFoundError:
+        return HTMLResponse(content="<h1>RAG Chatbot API</h1><p>Upload a PDF or start a chat.</p>")
 
 @app.post("/start-chat/")
 async def start_general_chat():
     """Starts a general chat session without PDF."""
     session_id = create_session(is_pdf=False)
     return {"session_id": session_id, "message": "General chat session started."}
-
+
 @app.post("/upload/")
 async def upload_pdf(file: UploadFile = File(...), current_session_id: str = Form(None)):
     """Handles PDF upload and indexing with chat continuity."""
-    # Extract text from PDF
     text = extract_text_from_pdf(file)
-
-    # Handle session continuity
+
     if current_session_id and current_session_id in chat_store:
-        # Continue with existing session
         session_id = current_session_id
-        # Remove from general chat sessions if it was one
-        if session_id in general_chat_sessions:
-            general_chat_sessions.pop(session_id)
+        general_chat_sessions.pop(session_id, None)  # upgrade to PDF mode
     else:
-        # Create a new session
         session_id = create_session()
-
-    # Create and store the vector database
+
     vectordb = build_vector_db(text, collection_name=session_id)
     db_store[session_id] = vectordb
-
+
     return {"session_id": session_id, "message": "PDF indexed."}
 
 @app.post("/chat/")
 async def chat(session_id: str = Form(...), prompt: str = Form(...)):
-    """Handles user chat prompt, fetches relevant info, calls Claude."""
-    # Check if this is a general chat or PDF chat
     is_general_chat = session_id in general_chat_sessions
     is_pdf_chat = session_id in db_store
-
+
     if not is_general_chat and not is_pdf_chat:
         return {"error": "Invalid session ID"}
-
+
     append_chat(session_id, "user", prompt)
-
-    # Ensure we have an API key and initialize with proper parameters
+
     if not ANTHROPIC_API_KEY:
         return JSONResponse(status_code=500, content={"error": "Missing ANTHROPIC_API_KEY environment variable"})
-
+
     client = Anthropic(api_key=ANTHROPIC_API_KEY.strip())
-
+
     if is_general_chat:
-        # General chat without PDF context
+        # No context, just send prompt
         response = client.messages.create(
             model=CLAUDE_MODEL,
             max_tokens=512,
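The tail of chat(), where the Anthropic response becomes the JSON payload, sits below the cut. For orientation: messages.create() returns a Message whose content is a list of blocks, so the reply text is typically gathered as in the sketch below; how this endpoint actually packages and stores it is not visible in this diff.

    def extract_reply(response) -> str:
        """Sketch: pull the assistant text out of an Anthropic Message object."""
        # Plain-text replies arrive as one or more text blocks on response.content.
        return "".join(block.text for block in response.content if hasattr(block, "text"))

Presumably the result is then recorded via append_chat(session_id, "assistant", ...) before being returned.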
@@ -143,7 +129,6 @@ async def chat(session_id: str = Form(...), prompt: str = Form(...)):
             messages=[{"role": "user", "content": prompt}]
         )
     else:
-        # PDF-based chat with context
         context = retrieve_context(db_store[session_id], prompt)
         response = client.messages.create(
             model=CLAUDE_MODEL,
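The messages argument for this PDF branch is also below the cut. One common pattern, sketched here purely as an assumption about how context and prompt are combined, is to prepend the retrieved chunks to the user question.

    def build_rag_messages(context: str, prompt: str) -> list:
        """Sketch: fold retrieved context into a single user turn for messages.create()."""
        rag_prompt = (
            "Answer using only the context below.\n\n"
            f"Context:\n{context}\n\n"
            f"Question: {prompt}"
        )
        return [{"role": "user", "content": rag_prompt}]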
@@ -162,3 +147,4 @@ async def end_chat(session_id: str = Form(...)):
     """Ends session and deletes associated data."""
     delete_session(session_id)
     return {"message": "Session cleared."}
+
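For reference, an end-to-end call against the /upload/ and /chat/ endpoints shown above might look like the sketch below; the base URL, the file name, and the use of the requests library are assumptions, not part of this repo.

    import requests

    BASE = "http://localhost:7860"  # assumption: typical port for a Space running uvicorn

    # Index a PDF; the server creates (or upgrades) a session and returns its ID.
    with open("paper.pdf", "rb") as f:
        upload = requests.post(f"{BASE}/upload/",
                               files={"file": ("paper.pdf", f, "application/pdf")})
    session_id = upload.json()["session_id"]

    # Ask a question against the indexed document.
    reply = requests.post(f"{BASE}/chat/",
                          data={"session_id": session_id, "prompt": "Summarise the first page"})
    print(reply.json())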
 
 