import os import re import json import time import shutil import asyncio import logging import traceback from chardet import detect from httpx import AsyncClient, RequestError from typing import List, Dict, Any, Optional from fastapi.staticfiles import StaticFiles from fastapi import FastAPI, Request, HTTPException, UploadFile, File, Form from fastapi.responses import StreamingResponse, JSONResponse, FileResponse from fastapi.middleware.cors import CORSMiddleware from dotenv import load_dotenv from tenacity import RetryError from openai import RateLimitError from anthropic import RateLimitError as AnthropicRateLimitError from google.api_core.exceptions import ResourceExhausted from src.helpers.helper import get_folder_size, clear_folder logger = logging.getLogger() logger.setLevel(logging.INFO) # Path to the .env file ENV_FILE_PATH = os.getenv("WRITABLE_DIR", "/tmp") + "/.env" # Define the upload directory and maximum folder size UPLOAD_DIRECTORY = os.getenv("WRITABLE_DIR", "/tmp") + "/uploads" MAX_FOLDER_SIZE = 10 * 1024 * 1024 # 10 MB in bytes CONTEXT_LENGTH = 128000 BUFFER = 10000 MAX_TOKENS_ALLOWED = CONTEXT_LENGTH - BUFFER # Per-session state SESSION_STORE: Dict[str, Dict[str, Any]] = {} # Format error message for SSE def format_error_sse(event_type: str, data: str) -> str: lines = data.splitlines() sse_message = f"event: {event_type}\n" for line in lines: sse_message += f"data: {line}\n" sse_message += "\n" return sse_message # Stop the task on error (non-fastapi) def stop_on_error(): state = SESSION_STORE if "process_task" in state: state["process_task"].cancel() del state["process_task"] # Get OAuth tokens for MCP tools def get_oauth_token(provider: str) -> Optional[str]: if "oauth_tokens" in SESSION_STORE and provider in SESSION_STORE["oauth_tokens"]: token_data = SESSION_STORE["oauth_tokens"][provider] # Check if token is expired (1 hour) if time.time() - token_data["timestamp"] < 3600: return token_data["token"] else: # Token expired, remove it del SESSION_STORE["oauth_tokens"][provider] logger.info(f"{provider} token expired and removed") return None # Initialize the components async def initialize_components(): load_dotenv(ENV_FILE_PATH, override=True) from src.search.search_engine import SearchEngine from src.query_processing.query_processor import QueryProcessor # from src.rag.neo4j_graphrag import Neo4jGraphRAG from src.rag.graph_rag import GraphRAG from src.evaluation.evaluator import Evaluator from src.reasoning.reasoner import Reasoner from src.crawl.crawler import CustomCrawler from src.utils.api_key_manager import APIKeyManager from src.query_processing.late_chunking.late_chunker import LateChunker from src.integrations.mcp_client import MCPClient state = SESSION_STORE manager = APIKeyManager() manager._reinit() state['search_engine'] = SearchEngine() state['query_processor'] = QueryProcessor() state['crawler'] = CustomCrawler(max_concurrent_requests=1000) # state['graph_rag'] = Neo4jGraphRAG(num_workers=os.cpu_count() * 2) state['graph_rag'] = GraphRAG(num_workers=os.cpu_count() * 2) state['evaluator'] = Evaluator() state['reasoner'] = Reasoner() state['model'] = manager.get_llm() state['late_chunker'] = LateChunker() state["mcp_client"] = MCPClient() state["initialized"] = True state["session_id"] = await state["crawler"].create_session() # Main function to process user queries async def process_query(user_query: str, sse_queue: asyncio.Queue): state = SESSION_STORE try: # --- Categorize the query --- category = await state["query_processor"].classify_query(user_query) cat_lower = category.lower().strip() user_query = re.sub(r'category:.*', '', user_query, flags=re.IGNORECASE).strip() # --- Read and extract user-provided files and links --- # Initialize caches if not present if "user_files_cache" not in state: state["user_files_cache"] = {} if "user_links_cache" not in state: state["user_links_cache"] = {} # Extract user-provided context user_context = "" user_links = state.get("user_provided_links", []) # Read new uploaded files if state["session_id"]: session_upload_path = os.path.join(UPLOAD_DIRECTORY, state["session_id"]) if os.path.exists(session_upload_path): for filename in os.listdir(session_upload_path): file_path = os.path.join(session_upload_path, filename) if os.path.isfile(file_path): # Check if file is already in cache if filename not in state["user_files_cache"]: try: await sse_queue.put(("step", "Reading User-Provided Files...")) # Always read as binary first with open(file_path, 'rb') as f: file_bytes = f.read() # Try to decode with multiple strategies file_content = None # Strategy 1: Try UTF-8 with BOM handling try: # Handle UTF-8 BOM if present if file_bytes.startswith(b'\xef\xbb\xbf'): file_content = file_bytes[3:].decode('utf-8') else: file_content = file_bytes.decode('utf-8') logger.info(f"Successfully decoded {filename} as UTF-8") except UnicodeDecodeError: # Strategy 2: Try other common encodings for encoding in ['utf-8-sig', 'latin-1', 'cp1252', 'iso-8859-1', 'windows-1252']: try: file_content = file_bytes.decode(encoding) logger.info(f"Successfully decoded {filename} with {encoding}") break except UnicodeDecodeError: continue # Strategy 3: If all else fails, use chardet for detection if file_content is None: try: detected = detect(file_bytes) if detected['encoding']: file_content = file_bytes.decode(detected['encoding']) logger.info(f"Decoded {filename} with detected encoding: {detected['encoding']}") except: pass # Final fallback: Use UTF-8 with replacement if file_content is None: file_content = file_bytes.decode('utf-8', errors='replace') logger.warning(f"Had to use error replacement for {filename}") # Store the decoded content state["user_files_cache"][filename] = file_content logger.info(f"Successfully cached file {filename}, length: {len(file_content)} chars") except Exception as e: logger.error(f"Error reading file {filename}: {str(e)}") state["user_files_cache"][filename] = "" # Cache empty to avoid retrying # Add all cached file contents for filename, content in state["user_files_cache"].items(): if content: user_context += f"\n[USER PROVIDED FILE: {filename} START]\n{content}\n[USER PROVIDED FILE: {filename} END]\n\n" # Crawl new user-provided links if user_links: await sse_queue.put(("step", "Crawling User-Provided Links...")) new_links = [link for link in user_links if link not in state["user_links_cache"]] if new_links: # Only crawl new links link_contents = await state['crawler'].fetch_page_contents( new_links, user_query, state["session_id"], max_attempts=1 ) # Cache the new contents for link, content in zip(new_links, link_contents): if not isinstance(content, Exception) and content: state["user_links_cache"][link] = content else: state["user_links_cache"][link] = "" # Cache empty to avoid retrying # Add all cached link contents for link, content in state["user_links_cache"].items(): if content: idx = user_links.index(link) + 1 if link in user_links else 0 user_context += f"\n[USER PROVIDED LINK {idx} START]\n{content}\n[USER PROVIDED LINK {idx} END]\n\n" # --- Fetch apps data from MCP service --- app_context = "" selected_services = state.get("selected_services", {}) # Check if any services are selected has_google = selected_services.get("google", []) has_microsoft = selected_services.get("microsoft", []) has_slack = selected_services.get("slack", False) if has_google or has_microsoft or has_slack: await sse_queue.put(("step", "Fetching Data From Connected Apps...")) # Fetch from each provider in parallel tasks = [] # Google services if has_google and len(has_google) > 0: google_token = get_oauth_token("google") tasks.append( state['mcp_client'].fetch_app_data( provider="google", services=has_google, query=user_query, user_id=state["session_id"], access_token=google_token ) ) # Microsoft services if has_microsoft and len(has_microsoft) > 0: microsoft_token = get_oauth_token("microsoft") tasks.append( state['mcp_client'].fetch_app_data( provider="microsoft", services=has_microsoft, query=user_query, user_id=state["session_id"], access_token=microsoft_token ) ) # Slack if has_slack: slack_token = get_oauth_token("slack") tasks.append( state['mcp_client'].fetch_app_data( provider="slack", services=[], # Slack doesn't have sub-services query=user_query, user_id=state["session_id"], access_token=slack_token ) ) # Execute all requests in parallel if tasks: results = await asyncio.gather(*tasks, return_exceptions=True) print("=== Main ===") print(f"Raw results fetched from fetch_mcp_data:-\n{results}") # Process results for i, result in enumerate(results): if isinstance(result, Exception): logger.error(f"Error fetching app data: {result}") elif isinstance(result, dict): # Determine which provider this result is from if i == 0 and has_google: provider = "google" elif (i == 1 and has_microsoft) or (i == 0 and not has_google and has_microsoft): provider = "microsoft" else: provider = "slack" # Format the data formatted_context = state['mcp_client'].format_as_context(provider, result) if formatted_context: app_context += formatted_context # Log how much app data we got if app_context: logger.info(f"Retrieved app data: {len(app_context)} characters") # Prepend app context to user context if app_context: user_context = app_context + "\n\n" + user_context # Upgrade basic to advanced if user has provided links if cat_lower == "basic" and user_links: cat_lower = "advanced" # --- Process the query based on the category --- if cat_lower == "basic": response = "" chunk_counter = 1 if user_context: # Include user context if available print(f"User Context:-\n{user_context}") await sse_queue.put(("step", "Generating Response...")) async for chunk in state["reasoner"].answer(user_query, user_context, query_type="basic"): await sse_queue.put(("token", json.dumps({"chunk": chunk, "index": chunk_counter}))) response += chunk chunk_counter += 1 else: # No user context provided async for chunk in state["reasoner"].answer(user_query): await sse_queue.put(("token", json.dumps({"chunk": chunk, "index": chunk_counter}))) response += chunk chunk_counter += 1 await sse_queue.put(("final_message", response)) SESSION_STORE["chat_history"].append({"query": user_query, "response": response}) await sse_queue.put(("action", { "name": "evaluate", "payload": {"query": user_query, "response": response} })) await sse_queue.put(("complete", "done")) elif cat_lower == "advanced": await sse_queue.put(("step", "Searching...")) optimized_query = await state['search_engine'].generate_optimized_query(user_query) search_results = await state['search_engine'].search( optimized_query, num_results=3, exclude_filetypes=["pdf"] ) urls = [r.get('link', 'No URL') for r in search_results] search_contents = await state['crawler'].fetch_page_contents( urls, user_query, state["session_id"], max_attempts=1 ) # Start with user-provided context contents = user_context # Add crawled contents if search_contents: for k, content in enumerate(search_contents, 1): if isinstance(content, Exception): print(f"Error fetching content: {content}") elif content: contents += f"[SOURCE {k} START]\n{content}\n[SOURCE {k} END]\n\n" if len(contents.strip()) > 0: await sse_queue.put(("step", "Generating Response...")) token_count = state['model'].get_num_tokens(contents) if token_count > MAX_TOKENS_ALLOWED: contents = await state['late_chunker'].chunker(contents, user_query, MAX_TOKENS_ALLOWED) await sse_queue.put(("sources_read", len(search_contents))) response = "" chunk_counter = 1 async for chunk in state["reasoner"].answer(user_query, contents): await sse_queue.put(("token", json.dumps({"chunk": chunk, "index": chunk_counter}))) response += chunk chunk_counter += 1 sources_for_answer = [] for idx, (result, content) in enumerate(zip(search_results, search_contents), 1): if content: # Only include if content was successfully fetched sources_for_answer.append({ "id": idx, "title": result.get('title', 'No Title'), "link": result.get('link', 'No URL') }) await sse_queue.put(("final_message", response)) await sse_queue.put(("final_sources", json.dumps(sources_for_answer))) SESSION_STORE["chat_history"].append({"query": user_query, "response": response}) SESSION_STORE["answer"] = response SESSION_STORE["source_contents"] = contents await sse_queue.put(("action", { "name": "sources", "payload": {"search_results": search_results, "search_contents": search_contents} })) await sse_queue.put(("action", { "name": "evaluate", "payload": {"query": user_query, "contents": [contents], "response": response} })) await sse_queue.put(("complete", "done")) else: await sse_queue.put(("error", "No results found.")) elif cat_lower == "pro": current_search_results = [] current_search_contents = [] await sse_queue.put(("step", "Thinking...")) start = time.time() intent = await state['query_processor'].get_query_intent(user_query) sub_queries, _ = await state['query_processor'].decompose_query(user_query, intent) async def sub_query_task(sub_query): try: await sse_queue.put(("step", "Searching...")) await sse_queue.put(("task", (sub_query, "RUNNING"))) optimized_query = await state['search_engine'].generate_optimized_query(sub_query) search_results = await state['search_engine'].search( optimized_query, num_results=10, exclude_filetypes=["pdf"] ) filtered_urls = await state['search_engine'].filter_urls( sub_query, category, search_results ) current_search_results.extend(filtered_urls) # Combine search results with user-provided links all_search_results = search_results + \ [{"link": url, "title": f"User provided: {url}", "snippet": ""} for url in user_links] urls = [r.get('link', 'No URL') for r in all_search_results] search_contents = await state['crawler'].fetch_page_contents( urls, sub_query, state["session_id"], max_attempts=1 ) current_search_contents.extend(search_contents) contents = user_context if search_contents: for k, c in enumerate(search_contents, 1): if isinstance(c, Exception): logger.info(f"Error fetching content: {c}") elif c: contents += f"[SOURCE {k} START]\n{c}\n[SOURCE {k} END]\n\n" if len(contents.strip()) > 0: await sse_queue.put(("task", (sub_query, "DONE"))) else: await sse_queue.put(("task", (sub_query, "FAILED"))) return contents except (RateLimitError, ResourceExhausted, AnthropicRateLimitError, RetryError): await sse_queue.put(("task", (sub_query, "FAILED"))) return "" tasks = [] if len(sub_queries) > 1 and sub_queries[0] != user_query: for sub_query in sub_queries: tasks.append(sub_query_task(sub_query)) results = await asyncio.gather(*tasks) end = time.time() # Start with user-provided context contents = user_context # Add searched contents contents += "\n\n".join(r for r in results if r.strip()) unique_results = [] seen = set() for entry in current_search_results: link = entry["link"] if link not in seen: seen.add(link) unique_results.append(entry) current_search_results = unique_results current_search_contents = list(set(current_search_contents)) if len(contents.strip()) > 0: await sse_queue.put(("step", "Generating Response...")) token_count = state['model'].get_num_tokens(contents) if token_count > MAX_TOKENS_ALLOWED: contents = await state['late_chunker'].chunker( text=contents, query=user_query, max_tokens=MAX_TOKENS_ALLOWED ) logger.info(f"Number of tokens in the answer: {token_count}") logger.info(f"Number of tokens in the content: {state['model'].get_num_tokens(contents)}") await sse_queue.put(("sources_read", len(current_search_contents))) response = "" chunk_counter = 1 is_first_chunk = True async for chunk in state['reasoner'].answer(user_query, contents): if is_first_chunk: await sse_queue.put(("step", f"Thought and searched for {int(end - start)} seconds")) is_first_chunk = False await sse_queue.put(("token", json.dumps({"chunk": chunk, "index": chunk_counter}))) response += chunk chunk_counter += 1 sources_for_answer = [] for idx, (result, content) in enumerate(zip(current_search_results, current_search_contents), 1): if content: # Only include if content was successfully fetched sources_for_answer.append({ "id": idx, "title": result.get('title', 'No Title'), "link": result.get('link', 'No URL') }) await sse_queue.put(("final_message", response)) await sse_queue.put(("final_sources", json.dumps(sources_for_answer))) SESSION_STORE["chat_history"].append({"query": user_query, "response": response}) SESSION_STORE["answer"] = response SESSION_STORE["source_contents"] = contents await sse_queue.put(("action", { "name": "sources", "payload": { "search_results": current_search_results, "search_contents": current_search_contents } })) await sse_queue.put(("action", { "name": "evaluate", "payload": {"query": user_query, "contents": [contents], "response": response} })) await sse_queue.put(("complete", "done")) else: await sse_queue.put(("error", "No results found.")) elif cat_lower == "super": current_search_results = [] current_search_contents = [] await sse_queue.put(("step", "Thinking...")) start = time.time() main_query_intent = await state['query_processor'].get_query_intent(user_query) sub_queries, _ = await state['query_processor'].decompose_query(user_query, main_query_intent) await sse_queue.put(("step", "Searching...")) async def sub_query_task(sub_query): try: async def sub_sub_query_task(sub_sub_query): optimized_query = await state['search_engine'].generate_optimized_query(sub_sub_query) search_results = await state['search_engine'].search( optimized_query, num_results=10, exclude_filetypes=["pdf"] ) filtered_urls = await state['search_engine'].filter_urls( sub_sub_query, category, search_results ) current_search_results.extend(filtered_urls) urls = [r.get('link', 'No URL') for r in filtered_urls] search_contents = await state['crawler'].fetch_page_contents( urls, sub_sub_query, state["session_id"], max_attempts=1, timeout=20 ) current_search_contents.extend(search_contents) contents = "" if search_contents: for k, c in enumerate(search_contents, 1): if isinstance(c, Exception): logger.info(f"Error fetching content: {c}") elif c: contents += f"[SOURCE {k} START]\n{c}\n[SOURCE {k} END]\n\n" return contents await sse_queue.put(("task", (sub_query, "RUNNING"))) sub_sub_queries, _ = await state['query_processor'].decompose_query(sub_query) tasks = [] if len(sub_sub_queries) > 1 and sub_sub_queries[0] != user_query: for sub_sub_query in sub_sub_queries: tasks.append(sub_sub_query_task(sub_sub_query)) results = await asyncio.gather(*tasks) if any(result.strip() for result in results): await sse_queue.put(("task", (sub_query, "DONE"))) else: await sse_queue.put(("task", (sub_query, "FAILED"))) return results except (RateLimitError, ResourceExhausted, AnthropicRateLimitError, RetryError): await sse_queue.put(("task", (sub_query, "FAILED"))) return [] tasks = [] if len(sub_queries) > 1 and sub_queries[0] != user_query: for sub_query in sub_queries: tasks.append(sub_query_task(sub_query)) results = await asyncio.gather(*tasks) end = time.time() # Start with user-provided context previous_contents = [] if user_context: previous_contents.append(user_context) for result in results: if result: for content in result: if isinstance(content, str) and len(content.strip()) > 0: previous_contents.append(content) contents = "\n\n".join(previous_contents) unique_results = [] seen = set() for entry in current_search_results: link = entry["link"] if link not in seen: seen.add(link) unique_results.append(entry) current_search_results = unique_results current_search_contents = list(set(current_search_contents)) if len(contents.strip()) > 0: await sse_queue.put(("step", "Generating Response...")) token_count = state['model'].get_num_tokens(contents) if token_count > MAX_TOKENS_ALLOWED: contents = await state['late_chunker'].chunker( text=contents, query=user_query, max_tokens=MAX_TOKENS_ALLOWED ) logger.info(f"Number of tokens in the answer: {token_count}") logger.info(f"Number of tokens in the content: {state['model'].get_num_tokens(contents)}") await sse_queue.put(("sources_read", len(current_search_contents))) response = "" chunk_counter = 1 is_first_chunk = True async for chunk in state['reasoner'].answer(user_query, contents): if is_first_chunk: await sse_queue.put(("step", f"Thought and searched for {int(end - start)} seconds")) is_first_chunk = False await sse_queue.put(("token", json.dumps({"chunk": chunk, "index": chunk_counter}))) response += chunk chunk_counter += 1 sources_for_answer = [] for idx, (result, content) in enumerate(zip(current_search_results, current_search_contents), 1): if content: # Only include if content was successfully fetched sources_for_answer.append({ "id": idx, "title": result.get('title', 'No Title'), "link": result.get('link', 'No URL') }) await sse_queue.put(("final_message", response)) await sse_queue.put(("final_sources", json.dumps(sources_for_answer))) SESSION_STORE["chat_history"].append({"query": user_query, "response": response}) SESSION_STORE["answer"] = response SESSION_STORE["source_contents"] = contents await sse_queue.put(("action", { "name": "sources", "payload": { "search_results": current_search_results, "search_contents": current_search_contents } })) await sse_queue.put(("action", { "name": "evaluate", "payload": {"query": user_query, "contents": [contents], "response": response} })) await sse_queue.put(("complete", "done")) else: await sse_queue.put(("error", "No results found.")) elif cat_lower == "ultra": current_search_results = [] current_search_contents = [] match = re.search( r"^This is the previous context of the conversation:\s*.*?\s*Current Query:\s*(.*)$", user_query, flags=re.DOTALL | re.MULTILINE ) if match: user_query = match.group(1) await sse_queue.put(("step", "Thinking...")) await asyncio.sleep(0.01) # Sleep for a short time to allow the message to be sent async def on_event_callback(event_type, data): if event_type == "graph_operation": if data["operation_type"] == "creating_new_graph": await sse_queue.put(("step", "Creating New Graph...")) elif data["operation_type"] == "modifying_existing_graph": await sse_queue.put(("step", "Modifying Existing Graph...")) elif data["operation_type"] == "loading_existing_graph": await sse_queue.put(("step", "Loading Existing Graph...")) elif event_type == "sub_query_created": sub_query = data["sub_query"] await sse_queue.put(("task", (sub_query, "RUNNING"))) elif event_type == "search_process_started": await sse_queue.put(("step", "Searching...")) elif event_type == "sub_query_processed": sub_query = data["sub_query"] await sse_queue.put(("task", (sub_query, "DONE"))) elif event_type == "sub_query_failed": sub_query = data["sub_query"] await sse_queue.put(("task", (sub_query, "FAILED"))) elif event_type == "search_results_filtered": current_search_results.extend(data["filtered_urls"]) filtered_urls = data["filtered_urls"] current_search_results.extend(filtered_urls) elif event_type == "search_contents_fetched": current_search_contents.extend(data["contents"]) contents = data["contents"] current_search_contents.extend(contents) elif event_type == "search_process_completed": await sse_queue.put(("step", "Processing final graph tasks...")) await asyncio.sleep(0.01) # Sleep for a short time to allow the message to be sent state['graph_rag'].set_on_event_callback(on_event_callback) start = time.time() # state['graph_rag'].initialize_schema() await state['graph_rag'].process_graph( user_query, similarity_threshold=0.8, relevance_threshold=0.8, max_tokens_allowed=MAX_TOKENS_ALLOWED ) end = time.time() unique_results = [] seen = set() for entry in current_search_results: link = entry["link"] if link not in seen: seen.add(link) unique_results.append(entry) current_search_results = unique_results current_search_contents = list(set(current_search_contents)) await sse_queue.put(("step", "Generating Response...")) answer = state['graph_rag'].query_graph(user_query) if answer: # Start with user-provided context previous_contents = [] if user_context: previous_contents.append(user_context) token_count = state['model'].get_num_tokens(answer) if token_count > MAX_TOKENS_ALLOWED: answer = await state['late_chunker'].chunker( text=answer, query=user_query, max_tokens=MAX_TOKENS_ALLOWED ) logger.info(f"Number of tokens in the answer: {token_count}") logger.info(f"Number of tokens in the content: {state['model'].get_num_tokens(answer)}") await sse_queue.put(("sources_read", len(current_search_contents))) response = "" chunk_counter = 1 is_first_chunk = True async for chunk in state['reasoner'].answer(user_query, answer): if is_first_chunk: await sse_queue.put(("step", f"Thought and searched for {int(end - start)} seconds")) is_first_chunk = False await sse_queue.put(("token", json.dumps({"chunk": chunk, "index": chunk_counter}))) response += chunk chunk_counter += 1 sources_for_answer = [] for idx, (result, content) in enumerate(zip(current_search_results, current_search_contents), 1): if content: # Only include if content was successfully fetched sources_for_answer.append({ "id": idx, "title": result.get('title', 'No Title'), "link": result.get('link', 'No URL') }) await sse_queue.put(("final_message", response)) await sse_queue.put(("final_sources", json.dumps(sources_for_answer))) SESSION_STORE["chat_history"].append({"query": user_query, "response": response}) SESSION_STORE["answer"] = response SESSION_STORE["source_contents"] = contents await sse_queue.put(("action", { "name": "sources", "payload": {"search_results": current_search_results, "search_contents": current_search_contents}, })) await sse_queue.put(("action", { "name": "graph", "payload": {"query": user_query}, })) await sse_queue.put(("action", { "name": "evaluate", "payload": {"query": user_query, "contents": [answer], "response": response}, })) await sse_queue.put(("complete", "done")) else: await sse_queue.put(("error", "No results found.")) else: await sse_queue.put(("final_message", "I'm not sure how to handle your query.")) except Exception as e: await sse_queue.put(("error", str(e))) traceback.print_exc() stop() # Create a FastAPI app app = FastAPI() # Define allowed origins origins = [ "http://localhost:3000", "https://localhost:3000", "http://localhost:7860", "https://localhost:7860", "http://localhost:8000", "https://localhost:8000", "http://localhost", "https://localhost" ] # Add the CORS middleware to your FastAPI app app.add_middleware( CORSMiddleware, allow_origins=origins, # Allows only these origins allow_credentials=True, allow_methods=["*"], # Allows all HTTP methods (GET, POST, etc.) allow_headers=["*"], # Allows all headers ) # Serve the React app (the production build) at the root URL. app.mount("/static", StaticFiles(directory="frontend/build/static", html=True), name="static") # Define the routes for the FastAPI app # Define the route for sources action to display search results @app.post("/action/sources") def action_sources(payload: Dict[str, Any]) -> Dict[str, Any]: try: search_contents = payload.get("search_contents", []) search_results = payload.get("search_results", []) sources = [] word_limit = 15 # Maximum number of words for the description for result, contents in zip(search_results, search_contents): if contents: title = result.get('title', 'No Title') link = result.get('link', 'No URL') snippet = result.get('snippet', 'No snippet') cleaned = re.sub(r'<[^>]+>|\[\/?.*?\]', '', snippet) words = cleaned.split() if len(words) > word_limit: description = " ".join(words[:word_limit]) + "..." else: description = " ".join(words) source_obj = { "title": title, "link": link, "description": description } sources.append(source_obj) return {"result": sources} except Exception as e: return JSONResponse(content={"error": str(e)}, status_code=500) # Define the route for graph action to display the graph @app.post("/action/graph") def action_graph() -> Dict[str, Any]: state = SESSION_STORE try: html_str = state['graph_rag'].display_graph() return {"result": html_str} except Exception as e: return JSONResponse(content={"error": str(e)}, status_code=500) # Define the route for evaluate action to display evaluation results @app.post("/action/evaluate") async def action_evaluate(payload: Dict[str, Any]) -> Dict[str, Any]: state = SESSION_STORE try: query = payload.get("query", "") contents = payload.get("contents", []) response = payload.get("response", "") metrics = payload.get("metrics", []) result = await state["evaluator"].evaluate_response(query, response, contents, include_metrics=metrics) return {"result": result} except Exception as e: return JSONResponse(content={"error": str(e)}, status_code=500) # Define the route for excerpts action to display excerpts from the sources @app.post("/action/excerpts") async def action_excerpts() -> Dict[str, Any]: def validate_excerpts_format(excerpts): if not isinstance(excerpts, list): return False for item in excerpts: if not isinstance(item, dict): return False for statement, sources in item.items(): if not isinstance(statement, str) or not isinstance(sources, dict): return False for src_num, excerpt in sources.items(): if not (isinstance(src_num, int) or isinstance(src_num, str)): return False if not isinstance(excerpt, str): return False return True try: state = SESSION_STORE response = state["answer"] contents = state["source_contents"] if not response or not contents: raise ValueError("Required data for excerpts not found") excerpts_list = await state["reasoner"].get_excerpts(response, contents) cleaned_excerpts = re.sub( r'```[\w\s]*\n?|```|~~~[\w\s]*\n?|~~~', '', excerpts_list, flags=re.MULTILINE | re.DOTALL ).strip() try: excerpts = eval(cleaned_excerpts) except Exception: print(f"Error parsing excerpts:\n{cleaned_excerpts}") raise ValueError("Excerpts could not be parsed as a Python list.") if not validate_excerpts_format(excerpts): print(f"Excerpts format validation failed:\n{excerpts}") raise ValueError("Excerpts are not in the required format.") print(f"Excerpts:\n{excerpts}") return {"result": excerpts} except Exception as e: print(f"Error in action_excerpts: {e}") return JSONResponse(content={"error": str(e)}, status_code=500) # Define the route for settings to set or update the environment variables @app.post("/settings") async def update_settings(data: Dict[str, Any]): from src.helpers.helper import ( prepare_provider_key_updates, prepare_proxy_list_updates, update_env_vars ) provider = data.get("Model_Provider", "").strip() model_name = data.get("Model_Name", "").strip() multiple_api_keys = data.get("Model_API_Keys", "").strip() brave_api_key = data.get("Brave_Search_API_Key", "").strip() proxy_list = data.get("Proxy_List", "").strip() model_temperature = str(data.get("Model_Temperature", 0.0)) model_top_p = str(data.get("Model_Top_P", 1.0)) prov_lower = provider.lower() key_updates = prepare_provider_key_updates(prov_lower, multiple_api_keys) env_updates = {} env_updates.update(key_updates) px = prepare_proxy_list_updates(proxy_list) if px: env_updates.update(px) env_updates["BRAVE_API_KEY"] = brave_api_key env_updates["MODEL_PROVIDER"] = prov_lower env_updates["MODEL_NAME"] = model_name env_updates["MODEL_TEMPERATURE"] = model_temperature env_updates["MODEL_TOP_P"] = model_top_p update_env_vars(env_updates) load_dotenv(override=True) await initialize_components() return {"success": True} # Define the route for adding/uploading content for a specific session @app.post("/add-content") async def add_content(files: Optional[List[UploadFile]] = File(None), urls: str = Form(...)): state = SESSION_STORE session_id = state.get("session_id") if not session_id: raise HTTPException(status_code=400, detail="Session ID is not set. Please start a session first.") session_upload_path = os.path.join(UPLOAD_DIRECTORY, session_id) os.makedirs(session_upload_path, exist_ok=True) saved_filenames = [] if files: total_new_files_size = sum(file.size for file in files) current_folder_size = get_folder_size(session_upload_path) # Check if the total size exceeds the maximum allowed folder size if current_folder_size + total_new_files_size > MAX_FOLDER_SIZE: raise HTTPException( status_code=400, detail=f"Cannot add files as total storage would exceed 10 MB. Current size: {current_folder_size / (1024 * 1024):.2f} MB" ) for file in files: file_path = os.path.join(session_upload_path, file.filename) try: with open(file_path, "wb") as buffer: shutil.copyfileobj(file.file, buffer) saved_filenames.append(file.filename) finally: file.file.close() try: parsed_urls = json.loads(urls) print(f"Received links: {parsed_urls}") except json.JSONDecodeError: raise HTTPException(status_code=400, detail="Invalid URL format.") # Store user-provided links in session if parsed_urls: SESSION_STORE["user_provided_links"] = parsed_urls return { "message": "Content added successfully", "files_added": saved_filenames, "links_added": parsed_urls } # Define the route to update the selected services for searching @app.post("/api/selected-services") async def update_selected_services(data: Dict[str, Any]): state = SESSION_STORE selected_services = data.get("services", {}) state["selected_services"] = selected_services logger.info(f"Updated selected services: {selected_services}") return {"success": True, "services": selected_services} # Define the route to receive OAuth tokens from the frontend @app.post("/api/session-token") async def receive_session_token(data: Dict[str, Any]): # Helper function to exchange Slack code for OAuth token async def exchange_slack_code_for_token(code: str, redirect_uri: str) -> Dict[str, Any]: # Get Slack OAuth credentials from environment client_id = os.getenv("SLACK_CLIENT_ID") client_secret = os.getenv("SLACK_CLIENT_SECRET") if not client_id or not client_secret: logger.error("Slack OAuth credentials not configured") raise HTTPException( status_code=500, detail="Slack OAuth credentials not configured" ) # Prepare the request to Slack's OAuth endpoint url = "https://slack.com/api/oauth.v2.access" # Form data for the request data = { "code": code, "client_id": client_id, "client_secret": client_secret, "redirect_uri": redirect_uri } try: # Make the request to exchange code for token async with AsyncClient() as client: response = await client.post( url, data=data, headers={"Content-Type": "application/x-www-form-urlencoded"} ) # Parse the response result = response.json() # Check if the request was successful if not result.get("ok"): error = result.get("error", "Unknown error") logger.error(f"Slack OAuth error: {error}") raise HTTPException( status_code=400, detail=f"Slack authentication failed: {error}" ) # Log the full response for debugging logger.info("Slack OAuth endpoint response: %s", result.get("ok")) # Extract access_token: Check authed_user first (for user tokens), then root (for bot tokens) access_token = result.get("authed_user", {}).get("access_token") or result.get("access_token") if not access_token: logger.error("Slack OAuth token exchange failed: No access token found in response") raise HTTPException( status_code=400, detail="Slack authentication failed: No access token found in response" ) # Determine token_type: From authed_user if user token, else root token_type = result.get("authed_user", {}).get("token_type") or result.get("token_type", "bot") # Extract other fields return { "access_token": access_token, "token_type": token_type, "scope": result.get("authed_user", {}).get("scope") or result.get("scope", ""), "bot_user_id": result.get("bot_user_id"), "app_id": result.get("app_id"), "team": result.get("team", {}), "enterprise": result.get("enterprise"), "authed_user": result.get("authed_user", {}), "is_enterprise_install": result.get("is_enterprise_install", False) } except RequestError as e: logger.error(f"Error making request to Slack: {e}") raise HTTPException( status_code=500, detail="Failed to communicate with Slack OAuth service" ) except Exception as e: logger.error(f"Unexpected error during Slack OAuth: {e}") raise HTTPException( status_code=500, detail="An unexpected error occurred during authentication" ) provider = data.get("provider") # 'google', 'microsoft', 'slack' token = data.get("token") code = data.get("code") # For Slack authorization code team_id = data.get("team_id") # For Slack workspace team_name = data.get("team_name") # For Slack workspace name if not provider: raise HTTPException(status_code=400, detail="Provider is required") # Initialize oauth_tokens if it doesn't exist if "oauth_tokens" not in SESSION_STORE: SESSION_STORE["oauth_tokens"] = {} token_data = { "timestamp": time.time() } # Handle Slack OAuth code exchange if provider == "slack" and code: try: # Build the redirect URI (must match what was sent in the initial OAuth request) request_origin = data.get("origin", "https://localhost:3000") redirect_uri = f"{request_origin}/auth-receiver.html" # Exchange the code for an access token slack_response = await exchange_slack_code_for_token(code, redirect_uri) # Store the access token token_data["token"] = slack_response["access_token"] token_data["token_type"] = slack_response.get("token_type", "bot") token_data["scope"] = slack_response.get("scope", "") # Store team/workspace information team_info = slack_response.get("team", {}) token_data["team_id"] = team_info.get("id", team_id) token_data["team_name"] = team_info.get("name", team_name) token_data["team_domain"] = team_info.get("domain", "") # Store enterprise information if available if slack_response.get("enterprise"): token_data["enterprise"] = slack_response["enterprise"] token_data["is_enterprise_install"] = slack_response.get("is_enterprise_install", False) # Store authed user information if slack_response.get("authed_user"): token_data["authed_user"] = slack_response["authed_user"] # Log successful exchange logger.info(f"Successfully exchanged Slack code for access token. Team: {token_data.get('team_name', 'Unknown')}") except HTTPException: raise # Re-raise HTTP exceptions except Exception as e: logger.error(f"Failed to exchange Slack code: {e}") raise HTTPException( status_code=500, detail="Failed to complete Slack authentication" ) elif provider in ["google", "microsoft"] and token: # For Google and Microsoft, we already have the token token_data["token"] = token else: raise HTTPException( status_code=400, detail=f"Invalid authentication data for provider: {provider}" ) # Store the token data SESSION_STORE["oauth_tokens"][provider] = token_data # Return success response with workspace info for Slack response_data = { "success": True, "message": f"{provider} token stored successfully" } if provider == "slack" and "team_name" in token_data: response_data["workspace"] = { "id": token_data.get("team_id"), "name": token_data.get("team_name"), "domain": token_data.get("team_domain") } return response_data # Define the route for cleaning up a session if the session ID matches @app.post("/cleanup") async def cleanup_session(): state = SESSION_STORE session_id = state.get("session_id") if not session_id: raise HTTPException(status_code=400, detail="Session ID is not set. Please start a session first.") session_upload_path = os.path.join(UPLOAD_DIRECTORY, session_id) if session_id: # Clear the session upload directory clear_folder(session_upload_path) # Clear user-provided links and caches SESSION_STORE["user_provided_links"] = [] SESSION_STORE["user_files_cache"] = {} SESSION_STORE["user_links_cache"] = {} SESSION_STORE["selected_services"] = {} SESSION_STORE["oauth_tokens"] = {} return {"message": "Cleanup successful."} return {"message": "No session ID provided, cleanup skipped."} @app.on_event("startup") def init_chat(): if not SESSION_STORE: print("Initializing chat...") # Create the upload directory if it doesn't exist print("Creating upload directory...") os.makedirs(UPLOAD_DIRECTORY, exist_ok=True) # Initialize the session store SESSION_STORE["settings_saved"] = False SESSION_STORE["session_id"] = None SESSION_STORE["answer"] = None SESSION_STORE["source_contents"] = None SESSION_STORE["chat_history"] = [] SESSION_STORE["user_provided_links"] = [] SESSION_STORE["user_files_cache"] = {} SESSION_STORE["user_links_cache"] = {} SESSION_STORE["selected_services"] = {} SESSION_STORE["oauth_tokens"] = {} print("Chat initialized!") return {"sucess": True} else: print("Chat already initialized!") return {"success": False} @app.get("/message-sse") async def sse_message(request: Request, user_message: str): state = SESSION_STORE sse_queue = asyncio.Queue() async def event_generator(): # Build the prompt context = state["chat_history"][-3:] if context: prompt = \ f"""This is the previous context of the conversation: {context} Current Query: {user_message}""" else: prompt = user_message task = asyncio.create_task(process_query(prompt, sse_queue)) state["process_task"] = task while True: if await request.is_disconnected(): task.cancel() break try: event_type, data = await asyncio.wait_for(sse_queue.get(), timeout=5) if event_type == "token": yield f"event: token\ndata: {data}\n\n" elif event_type == "final_message": yield f"event: final_message\ndata: {data}\n\n" elif event_type == "final_sources": yield f"event: final_sources\ndata: {data}\n\n" elif event_type == "error": stop_on_error() yield format_error_sse("error", data) elif event_type == "step": yield f"event: step\ndata: {data}\n\n" elif event_type == "task": subq, status = data j = {"task": subq, "status": status} yield f"event: task\ndata: {json.dumps(j)}\n\n" elif event_type == "sources_read": yield f"event: sources_read\ndata: {data}\n\n" elif event_type == "action": yield f"event: action\ndata: {json.dumps(data)}\n\n" elif event_type == "complete": yield f"event: complete\ndata: {data}\n\n" break else: yield f"event: message\ndata: {data}\n\n" except asyncio.TimeoutError: if task.done(): break continue except asyncio.CancelledError: break if not task.done(): task.cancel() if "process_task" in state: del state["process_task"] return StreamingResponse(event_generator(), media_type="text/event-stream") @app.post("/stop") def stop(): state = SESSION_STORE if "process_task" in state: state["process_task"].cancel() del state["process_task"] return {"message": "Stopped task manually"} # Catch-all route for frontend paths. @app.get("/{full_path:path}") async def serve_frontend(full_path: str, request: Request): index_path = os.path.join("frontend", "build", "index.html") if not os.path.exists(index_path): raise HTTPException(status_code=500, detail="Frontend build not found") return FileResponse(index_path)