import os
import json
import asyncio
import logging
from datetime import datetime
from typing import List, Dict
from urllib.parse import urlparse

import requests
from fastapi import FastAPI, Request, HTTPException, Depends
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse
from openai import OpenAI


def verify_origin(request: Request):
    """Verify that the request comes from an allowed origin for the /chat/stream endpoint."""
    origin = request.headers.get("origin")
    referer = request.headers.get("referer")

    allowed_origins = [
        "https://chrunos.com",
        "https://www.chrunos.com"
    ]

    # Allow local development traffic on any port.
    if origin and any(origin.startswith(local) for local in ["http://localhost:", "http://127.0.0.1:"]):
        return True

    if origin in allowed_origins:
        return True

    # Fall back to the Referer header when the Origin header is absent.
    if referer and any(referer.startswith(allowed) for allowed in allowed_origins):
        return True

    raise HTTPException(
        status_code=403,
        detail="Access denied: This endpoint is only accessible from chrunos.com"
    )


logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

GOOGLE_API_KEY = os.getenv("GOOGLE_API_KEY")
GOOGLE_CX = os.getenv("GOOGLE_CX")
LLM_API_KEY = os.getenv("LLM_API_KEY")
LLM_BASE_URL = os.getenv("LLM_BASE_URL", "https://api-15i2e8ze256bvfn6.aistudio-app.com/v1")


SYSTEM_PROMPT_WITH_SEARCH = """You are an intelligent AI assistant with access to real-time web search capabilities.

When you need current information, recent events, specific facts, or when the user's question would benefit from up-to-date information, use the google_search function.

**Use search for:**
- Recent news or events
- Current statistics or data
- Specific factual information you're unsure about
- Questions about things that may have changed recently
- When the user explicitly asks for current/recent information

**Response Guidelines:**
1. Always use the search tool when it would provide more accurate or current information
2. Synthesize information from multiple sources when available
3. Clearly indicate when information comes from search results
4. Provide comprehensive, well-structured answers
5. Cite sources appropriately with links
6. If search results conflict with your knowledge, prioritize the search results

Current date: {current_date}"""

SYSTEM_PROMPT_NO_SEARCH = """You are an intelligent AI assistant. Provide helpful, accurate, and comprehensive responses based on your training data.

Current date: {current_date}"""


async def google_search_tool_async(query: str, num_results: int = 3) -> List[Dict]:
    """Async Google Custom Search; result count is kept low for faster responses."""
    if not GOOGLE_API_KEY or not GOOGLE_CX or not query.strip():
        return []

    logger.info(f"Executing search for: '{query}'")

    search_url = "https://www.googleapis.com/customsearch/v1"
    params = {
        "key": GOOGLE_API_KEY,
        "cx": GOOGLE_CX,
        "q": query.strip(),
        "num": min(num_results, 5),
        "dateRestrict": "m3"  # restrict results to the last three months
    }

    try:
        # requests is blocking, so run the HTTP call in the default thread
        # pool to avoid stalling the event loop.
        loop = asyncio.get_running_loop()
        response = await loop.run_in_executor(
            None,
            lambda: requests.get(search_url, params=params, timeout=10)
        )
        response.raise_for_status()
        search_results = response.json()

        if "items" not in search_results:
            return []

        parsed_results = []
        for item in search_results.get("items", [])[:num_results]:
            title = item.get("title", "").strip()
            url = item.get("link", "").strip()
            snippet = item.get("snippet", "").strip()

            if title and url and snippet:
                parsed_results.append({
                    "source_title": title,
                    "url": url,
                    "snippet": snippet,
                    "domain": urlparse(url).netloc or url
                })

        logger.info(f"Retrieved {len(parsed_results)} search results")
        return parsed_results

    except Exception as e:
        logger.error(f"Search error: {e}")
        return []


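# A quick way to exercise the helper by hand (a sketch, assuming the
# GOOGLE_API_KEY and GOOGLE_CX environment variables are set; the query
# string is just an illustration):
#
#   import asyncio
#   results = asyncio.run(google_search_tool_async("fastapi streaming sse"))
#   print(results)  # -> [{"source_title": ..., "url": ..., "snippet": ..., "domain": ...}, ...]

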
def format_search_results_compact(search_results: List[Dict]) -> str:
    """Compact formatting for faster processing"""
    if not search_results:
        return "No search results found."

    formatted = ["Search Results:"]
    for i, result in enumerate(search_results, 1):
        formatted.append(f"\n{i}. {result['source_title']}")
        formatted.append(f"   Source: {result['domain']}")
        formatted.append(f"   Content: {result['snippet']}")

    return "\n".join(formatted)


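# For one hypothetical result the formatter produces text like:
#
#   Search Results:
#
#   1. Example Page Title
#      Source: example.com
#      Content: First sentence of the snippet...

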
app = FastAPI(title="Streaming AI Chatbot", version="2.1.0")

app.add_middleware(
    CORSMiddleware,
    allow_origins=[
        "https://chrunos.com",
        "https://www.chrunos.com",
        "http://localhost:3000",
        "http://localhost:8000",
    ],
    allow_credentials=True,
    allow_methods=["GET", "POST", "OPTIONS"],
    allow_headers=["*"],
)


if not LLM_API_KEY or not LLM_BASE_URL:
    logger.error("LLM_API_KEY or LLM_BASE_URL not configured")
    client = None
else:
    client = OpenAI(api_key=LLM_API_KEY, base_url=LLM_BASE_URL)
    logger.info("OpenAI client initialized successfully")


available_tools = [
    {
        "type": "function",
        "function": {
            "name": "google_search",
            "description": "Search Google for current information, recent events, or specific facts. Use this when you need up-to-date information or when the user's question would benefit from current data.",
            "parameters": {
                "type": "object",
                "properties": {
                    "query": {
                        "type": "string",
                        "description": "Search query with relevant keywords"
                    }
                },
                "required": ["query"]
            }
        }
    }
]


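# With OpenAI-style streaming, a tool call does not arrive in one piece:
# the id and function name come in an early delta, and the JSON arguments
# trickle in as string fragments across later deltas, e.g. '{"qu',
# 'ery": "late', 'st ai news"}'. The loop below therefore accumulates
# fragments by tool_call.index and only parses the arguments once the
# stream finishes.
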
async def generate_streaming_response(messages: List[Dict], use_search: bool, temperature: float):
    """Generate a streaming response with optional search.

    Yields Server-Sent-Events-style frames: each chunk is a line of the
    form "data: <json>" followed by a blank line, where the JSON carries
    a "type" of "content", "status", "sources", "done", or "error".
    """
    try:
        llm_kwargs = {
            "model": "unsloth/Qwen3-30B-A3B-GGUF",
            "temperature": temperature,
            "messages": messages,
            "max_tokens": 2000,
            "stream": True
        }

        if use_search:
            llm_kwargs["tools"] = available_tools
            llm_kwargs["tool_choice"] = "auto"

        source_links = []
        response_content = ""
        tool_calls_data = []

        # Note: the OpenAI client here is synchronous, so iterating this
        # stream blocks the event loop between chunks.
        stream = client.chat.completions.create(**llm_kwargs)

        collecting_tool_call = False

        for chunk in stream:
            delta = chunk.choices[0].delta
            finish_reason = chunk.choices[0].finish_reason

            if delta.content:
                content_chunk = delta.content
                response_content += content_chunk
                yield f"data: {json.dumps({'type': 'content', 'data': content_chunk})}\n\n"

            if delta.tool_calls:
                collecting_tool_call = True
                for tool_call in delta.tool_calls:
                    # Grow the list until it can hold this tool call's index.
                    while len(tool_calls_data) <= tool_call.index:
                        tool_calls_data.append({
                            "id": None,
                            "function": {"name": None, "arguments": ""}
                        })

                    # Merge the fragments: id and name arrive once,
                    # arguments accumulate as string pieces.
                    if tool_call.id:
                        tool_calls_data[tool_call.index]["id"] = tool_call.id
                    if tool_call.function and tool_call.function.name:
                        tool_calls_data[tool_call.index]["function"]["name"] = tool_call.function.name
                    if tool_call.function and tool_call.function.arguments:
                        tool_calls_data[tool_call.index]["function"]["arguments"] += tool_call.function.arguments

            if finish_reason in ["tool_calls", "stop"] and collecting_tool_call:
                break

        # Execute any complete tool calls collected from the stream.
        processed_any_tools = False
        if tool_calls_data and any(tc.get("id") and tc.get("function", {}).get("name") for tc in tool_calls_data):
            yield f"data: {json.dumps({'type': 'status', 'data': 'Searching...'})}\n\n"

            tool_responses = []

            for tool_call in tool_calls_data:
                if not tool_call.get("id") or not tool_call.get("function", {}).get("name"):
                    continue

                function_name = tool_call["function"]["name"]

                if function_name == "google_search":
                    try:
                        args = json.loads(tool_call["function"]["arguments"])
                        query = args.get("query", "").strip()
                        if query:
                            logger.info(f"Executing search with query: {query}")
                            search_results = await google_search_tool_async(query)

                            if search_results:
                                processed_any_tools = True

                                # Collect links to surface to the client as a
                                # 'sources' event after the final answer.
                                for result in search_results:
                                    source_links.append({
                                        "title": result["source_title"],
                                        "url": result["url"],
                                        "domain": result["domain"]
                                    })

                                search_context = format_search_results_compact(search_results)
                                tool_responses.append({
                                    "tool_call_id": tool_call["id"],
                                    "role": "tool",
                                    "content": search_context
                                })
                            else:
                                tool_responses.append({
                                    "tool_call_id": tool_call["id"],
                                    "role": "tool",
                                    "content": "No search results found."
                                })
                        else:
                            # Every tool call needs a matching tool response,
                            # or the follow-up request would be rejected.
                            tool_responses.append({
                                "tool_call_id": tool_call["id"],
                                "role": "tool",
                                "content": "Error: Empty search query."
                            })
                    except json.JSONDecodeError as e:
                        logger.error(f"Failed to parse tool arguments: {e}")
                        tool_responses.append({
                            "tool_call_id": tool_call["id"],
                            "role": "tool",
                            "content": "Error: Invalid search query format."
                        })
                    except Exception as e:
                        logger.error(f"Search tool error: {e}")
                        tool_responses.append({
                            "tool_call_id": tool_call["id"],
                            "role": "tool",
                            "content": f"Search error: {str(e)}"
                        })

            if tool_responses:
                yield f"data: {json.dumps({'type': 'status', 'data': 'Generating response...'})}\n\n"

                # Rebuild the conversation for the follow-up call: the
                # assistant's tool calls, then the tool results, per the
                # OpenAI chat format.
                final_messages = messages.copy()

                assistant_message = {
                    "role": "assistant",
                    "content": response_content if response_content else None,
                    "tool_calls": [
                        {
                            "id": tc["id"],
                            "type": "function",
                            "function": {
                                "name": tc["function"]["name"],
                                "arguments": tc["function"]["arguments"]
                            }
                        }
                        for tc in tool_calls_data if tc.get("id") and tc.get("function", {}).get("name")
                    ]
                }
                final_messages.append(assistant_message)
                final_messages.extend(tool_responses)

                final_stream = client.chat.completions.create(
                    model="unsloth/Qwen3-30B-A3B-GGUF",
                    temperature=temperature,
                    messages=final_messages,
                    max_tokens=2000,
                    stream=True
                )

                for chunk in final_stream:
                    if chunk.choices[0].delta.content:
                        content = chunk.choices[0].delta.content
                        yield f"data: {json.dumps({'type': 'content', 'data': content})}\n\n"

        if source_links:
            yield f"data: {json.dumps({'type': 'sources', 'data': source_links})}\n\n"

        yield f"data: {json.dumps({'type': 'done', 'data': {'search_used': processed_any_tools}})}\n\n"

    except Exception as e:
        logger.error(f"Streaming error: {e}")
        yield f"data: {json.dumps({'type': 'error', 'data': str(e)})}\n\n"


@app.post("/chat/stream") |
|
async def chat_stream_endpoint(request: Request, _: None = Depends(verify_origin)): |
|
if not client: |
|
raise HTTPException(status_code=500, detail="LLM client not configured") |
|
|
|
try: |
|
data = await request.json() |
|
user_message = data.get("message", "").strip() |
|
use_search = data.get("use_search", False) |
|
temperature = max(0, min(2, data.get("temperature", 0.7))) |
|
conversation_history = data.get("history", []) |
|
user_prompt = data.get("system_prompt") |
|
|
|
if not user_message: |
|
raise HTTPException(status_code=400, detail="No message provided") |
|
|
|
|
|
current_date = datetime.now().strftime("%Y-%m-%d") |
|
system_content = (SYSTEM_PROMPT_WITH_SEARCH if use_search else user_prompt |
|
).format(current_date=current_date) |
|
messages = [{"role": "system", "content": system_content}] + conversation_history + [{"role": "user", "content": user_message}] |
|
|
|
logger.info(f"Stream request - search: {use_search}, temp: {temperature}, message: {user_message[:100]}...") |
|
|
|
return StreamingResponse( |
|
generate_streaming_response(messages, use_search, temperature), |
|
media_type="text/plain", |
|
headers={ |
|
"Cache-Control": "no-cache", |
|
"Connection": "keep-alive", |
|
"X-Accel-Buffering": "no" |
|
} |
|
) |
|
|
|
except json.JSONDecodeError: |
|
raise HTTPException(status_code=400, detail="Invalid JSON") |
|
except Exception as e: |
|
logger.error(f"Stream endpoint error: {e}") |
|
raise HTTPException(status_code=500, detail=str(e)) |
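

# A minimal local entry point; a sketch assuming uvicorn is installed and this
# module is saved as main.py (both assumptions, not stated in the original).
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)

# Example request against the streaming endpoint (hypothetical payload; the
# Origin header satisfies verify_origin's localhost allowance):
#
#   curl -N -X POST http://localhost:8000/chat/stream \
#     -H "Content-Type: application/json" \
#     -H "Origin: http://localhost:3000" \
#     -d '{"message": "What happened in AI this week?", "use_search": true}'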