#!/usr/bin/env python
import os
import re
import json
import requests
from collections.abc import Iterator

import gradio as gr
from loguru import logger
import pandas as pd
import PyPDF2
##############################################################################
# API Configuration
##############################################################################
FRIENDLI_TOKEN = os.environ.get("FRIENDLI_TOKEN")
if not FRIENDLI_TOKEN:
    raise ValueError("Please set FRIENDLI_TOKEN environment variable")

FRIENDLI_MODEL_ID = "dep89a2fld32mcm"
FRIENDLI_API_URL = "https://api.friendli.ai/dedicated/v1/chat/completions"

# SERPHouse API key
SERPHOUSE_API_KEY = os.getenv("SERPHOUSE_API_KEY", "")
if not SERPHOUSE_API_KEY:
    logger.warning("SERPHOUSE_API_KEY not set. Web search functionality will be limited.")
##############################################################################
# File Processing Constants
##############################################################################
MAX_FILE_SIZE = 30 * 1024 * 1024  # 30MB
MAX_CONTENT_CHARS = 6000

##############################################################################
# Improved Keyword Extraction
##############################################################################
def extract_keywords(text: str, top_k: int = 5) -> str:
    """
    Extract up to top_k keywords; supports English and Korean text.
    """
    stop_words = {'은', '는', '이', '가', '을', '를', '의', '에', '에서',
                  'the', 'is', 'at', 'on', 'in', 'a', 'an', 'and', 'or', 'but'}

    # Keep only alphanumerics, Korean syllables, and whitespace
    text = re.sub(r"[^a-zA-Z0-9가-힣\s]", "", text)
    tokens = text.split()

    key_tokens = [
        token for token in tokens
        if token.lower() not in stop_words and len(token) > 1
    ][:top_k]

    return " ".join(key_tokens)
##############################################################################
# File Size Validation
##############################################################################
def validate_file_size(file_path: str) -> bool:
    """Check if file size is within limits"""
    try:
        file_size = os.path.getsize(file_path)
        return file_size <= MAX_FILE_SIZE
    except OSError:
        return False
##############################################################################
# Web Search Function
##############################################################################
def do_web_search(query: str, use_korean: bool = False) -> str:
    """
    Search the web and return the top 20 organic results
    """
    if not SERPHOUSE_API_KEY:
        return "Web search unavailable. API key not configured."

    try:
        url = "https://api.serphouse.com/serp/live"
        params = {
            "q": query,
            "domain": "google.com",
            "serp_type": "web",
            "device": "desktop",
            "lang": "ko" if use_korean else "en",
            "num": "20"
        }
        headers = {
            "Authorization": f"Bearer {SERPHOUSE_API_KEY}"
        }

        logger.info(f"Calling SerpHouse API... Query: {query}")
        response = requests.get(url, headers=headers, params=params, timeout=30)
        response.raise_for_status()

        data = response.json()

        # Parse results; the organic list can appear at several nesting
        # levels, so check each known shape in turn.
        results = data.get("results", {})
        organic = None

        if isinstance(results, dict) and "organic" in results:
            organic = results["organic"]
        elif isinstance(results, dict) and "results" in results:
            if isinstance(results["results"], dict) and "organic" in results["results"]:
                organic = results["results"]["organic"]
        elif "organic" in data:
            organic = data["organic"]

        if not organic:
            return "No search results found or unexpected API response structure."

        max_results = min(20, len(organic))
        limited_organic = organic[:max_results]

        summary_lines = []
        for idx, item in enumerate(limited_organic, start=1):
            title = item.get("title", "No title")
            link = item.get("link", "#")
            snippet = item.get("snippet", "No description")
            displayed_link = item.get("displayed_link", link)

            summary_lines.append(
                f"### Result {idx}: {title}\n\n"
                f"{snippet}\n\n"
                f"**Source**: [{displayed_link}]({link})\n\n"
                f"---\n"
            )

        instructions = """
# Web Search Results
Below are the search results. Use this information when answering questions:
1. Reference the title, content, and source links
2. Explicitly cite sources in your answer (e.g., "According to source X...")
3. Include actual source links in your response
4. Synthesize information from multiple sources
"""

        search_results = instructions + "\n".join(summary_lines)
        return search_results

    except requests.exceptions.Timeout:
        logger.error("Web search timeout")
        return "Web search timed out. Please try again."
    except requests.exceptions.RequestException as e:
        logger.error(f"Web search network error: {e}")
        return "Network error during web search."
    except Exception as e:
        logger.error(f"Web search failed: {e}")
        return f"Web search failed: {str(e)}"
##############################################################################
# File Analysis Functions
##############################################################################
def analyze_csv_file(path: str) -> str:
    """Analyze CSV file with size validation and encoding handling"""
    if not validate_file_size(path):
        return f"⚠️ Error: File size exceeds {MAX_FILE_SIZE/1024/1024:.1f}MB limit."

    try:
        encodings = ['utf-8', 'cp949', 'euc-kr', 'latin-1']
        df = None
        used_encoding = None
        for encoding in encodings:
            try:
                df = pd.read_csv(path, encoding=encoding, nrows=50)
                used_encoding = encoding
                break
            except UnicodeDecodeError:
                continue

        if df is None:
            return "Failed to read CSV: Unsupported encoding"

        # Re-read only the first column to count rows without loading the full table
        total_rows = len(pd.read_csv(path, encoding=used_encoding, usecols=[0]))

        if df.shape[1] > 10:
            df = df.iloc[:, :10]

        summary = f"**Data size**: {total_rows} rows x {df.shape[1]} columns\n"
        summary += f"**Showing**: Top {min(50, total_rows)} rows\n"
        summary += f"**Columns**: {', '.join(df.columns)}\n\n"

        df_str = df.to_string()
        if len(df_str) > MAX_CONTENT_CHARS:
            df_str = df_str[:MAX_CONTENT_CHARS] + "\n...(truncated)..."

        return f"**[CSV File: {os.path.basename(path)}]**\n\n{summary}{df_str}"
    except Exception as e:
        logger.error(f"CSV read error: {e}")
        return f"Failed to read CSV file ({os.path.basename(path)}): {str(e)}"
def analyze_txt_file(path: str) -> str:
    """Analyze text file with automatic encoding detection"""
    if not validate_file_size(path):
        return f"⚠️ Error: File size exceeds {MAX_FILE_SIZE/1024/1024:.1f}MB limit."

    encodings = ['utf-8', 'cp949', 'euc-kr', 'latin-1', 'utf-16']

    for encoding in encodings:
        try:
            with open(path, "r", encoding=encoding) as f:
                text = f.read()

            file_size = os.path.getsize(path)
            size_info = f"**File size**: {file_size/1024:.1f}KB\n\n"

            if len(text) > MAX_CONTENT_CHARS:
                text = text[:MAX_CONTENT_CHARS] + "\n...(truncated)..."

            return f"**[TXT File: {os.path.basename(path)}]**\n\n{size_info}{text}"
        except UnicodeDecodeError:
            continue

    return f"Failed to read text file ({os.path.basename(path)}): Unsupported encoding"
def pdf_to_markdown(pdf_path: str) -> str:
    """Convert PDF to markdown with improved error handling"""
    if not validate_file_size(pdf_path):
        return f"⚠️ Error: File size exceeds {MAX_FILE_SIZE/1024/1024:.1f}MB limit."

    text_chunks = []
    try:
        with open(pdf_path, "rb") as f:
            reader = PyPDF2.PdfReader(f)
            total_pages = len(reader.pages)
            max_pages = min(5, total_pages)

            text_chunks.append(f"**Total pages**: {total_pages}")
            text_chunks.append(f"**Showing**: First {max_pages} pages\n")

            for page_num in range(max_pages):
                try:
                    page = reader.pages[page_num]
                    page_text = page.extract_text() or ""
                    page_text = page_text.strip()

                    if page_text:
                        if len(page_text) > MAX_CONTENT_CHARS // max_pages:
                            page_text = page_text[:MAX_CONTENT_CHARS // max_pages] + "...(truncated)"
                        text_chunks.append(f"## Page {page_num+1}\n\n{page_text}\n")
                except Exception as e:
                    text_chunks.append(f"## Page {page_num+1}\n\nFailed to read page: {str(e)}\n")

            if total_pages > max_pages:
                text_chunks.append(f"\n...({max_pages}/{total_pages} pages shown)...")
    except Exception as e:
        logger.error(f"PDF read error: {e}")
        return f"Failed to read PDF file ({os.path.basename(pdf_path)}): {str(e)}"

    full_text = "\n".join(text_chunks)
    if len(full_text) > MAX_CONTENT_CHARS:
        full_text = full_text[:MAX_CONTENT_CHARS] + "\n...(truncated)..."

    return f"**[PDF File: {os.path.basename(pdf_path)}]**\n\n{full_text}"
##############################################################################
# File Type Check Functions
##############################################################################
def is_image_file(file_path: str) -> bool:
    """Check if file is an image"""
    return bool(re.search(r"\.(png|jpg|jpeg|gif|webp)$", file_path, re.IGNORECASE))

def is_video_file(file_path: str) -> bool:
    """Check if file is a video"""
    return bool(re.search(r"\.(mp4|avi|mov|mkv)$", file_path, re.IGNORECASE))

def is_document_file(file_path: str) -> bool:
    """Check if file is a document"""
    return bool(re.search(r"\.(pdf|csv|txt)$", file_path, re.IGNORECASE))
##############################################################################
# Message Processing Functions
##############################################################################
def process_new_user_message(message: dict) -> str:
    """Process user message and convert to text"""
    content_parts = [message["text"]]

    if not message.get("files"):
        return message["text"]

    # Classify files
    csv_files = []
    txt_files = []
    pdf_files = []
    image_files = []
    video_files = []
    unknown_files = []

    for file_path in message["files"]:
        if file_path.lower().endswith(".csv"):
            csv_files.append(file_path)
        elif file_path.lower().endswith(".txt"):
            txt_files.append(file_path)
        elif file_path.lower().endswith(".pdf"):
            pdf_files.append(file_path)
        elif is_image_file(file_path):
            image_files.append(file_path)
        elif is_video_file(file_path):
            video_files.append(file_path)
        else:
            unknown_files.append(file_path)

    # Process document files
    for csv_path in csv_files:
        csv_analysis = analyze_csv_file(csv_path)
        content_parts.append(csv_analysis)

    for txt_path in txt_files:
        txt_analysis = analyze_txt_file(txt_path)
        content_parts.append(txt_analysis)

    for pdf_path in pdf_files:
        pdf_markdown = pdf_to_markdown(pdf_path)
        content_parts.append(pdf_markdown)

    # Warning messages for unsupported files
    if image_files:
        image_names = [os.path.basename(f) for f in image_files]
        content_parts.append(
            f"\n⚠️ **Image files detected**: {', '.join(image_names)}\n"
            "This demo currently does not support image analysis. "
            "Please describe the image content in text if you need help with it."
        )

    if video_files:
        video_names = [os.path.basename(f) for f in video_files]
        content_parts.append(
            f"\n⚠️ **Video files detected**: {', '.join(video_names)}\n"
            "This demo currently does not support video analysis. "
            "Please describe the video content in text if you need help with it."
        )

    if unknown_files:
        unknown_names = [os.path.basename(f) for f in unknown_files]
        content_parts.append(
            f"\n⚠️ **Unsupported file format**: {', '.join(unknown_names)}\n"
            "Supported formats: PDF, CSV, TXT"
        )

    return "\n\n".join(content_parts)
def process_history(history: list[dict]) -> list[dict]:
    """Convert conversation history to Friendli API format"""
    messages = []
    for item in history:
        if item["role"] == "assistant":
            messages.append({
                "role": "assistant",
                "content": item["content"]
            })
        else:  # user
            content = item["content"]
            if isinstance(content, str):
                messages.append({
                    "role": "user",
                    "content": content
                })
            elif isinstance(content, list) and len(content) > 0:
                # File attachments: represent each by its file name only
                file_info = []
                for file_path in content:
                    if isinstance(file_path, str):
                        file_info.append(f"[File: {os.path.basename(file_path)}]")
                if file_info:
                    messages.append({
                        "role": "user",
                        "content": " ".join(file_info)
                    })
    return messages
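
# Illustrative transformation: a history item of the form
#   {"role": "user", "content": ["/tmp/report.pdf"]}
# becomes {"role": "user", "content": "[File: report.pdf]"}; string contents
# and assistant turns pass through unchanged.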
##############################################################################
# Streaming Response Handler
##############################################################################
def stream_friendli_response(messages: list[dict], max_tokens: int = 1000) -> Iterator[str]:
    """Get streaming response from Friendli AI API"""
    headers = {
        "Authorization": f"Bearer {FRIENDLI_TOKEN}",
        "Content-Type": "application/json"
    }

    payload = {
        "model": FRIENDLI_MODEL_ID,
        "messages": messages,
        "max_tokens": max_tokens,
        "top_p": 0.8,
        "temperature": 0.7,
        "stream": True,
        "stream_options": {
            "include_usage": True
        }
    }

    try:
        response = requests.post(
            FRIENDLI_API_URL,
            headers=headers,
            json=payload,
            stream=True,
            timeout=60
        )
        response.raise_for_status()

        full_response = ""
        for line in response.iter_lines():
            if line:
                line_text = line.decode('utf-8')

                # Server-sent events arrive as "data: {json}" lines
                if line_text.startswith("data: "):
                    data_str = line_text[6:]

                    if data_str == "[DONE]":
                        break

                    try:
                        data = json.loads(data_str)
                        if "choices" in data and len(data["choices"]) > 0:
                            delta = data["choices"][0].get("delta", {})
                            content = delta.get("content", "")
                            if content:
                                full_response += content
                                # Yield the accumulated text so Gradio re-renders the full message
                                yield full_response
                    except json.JSONDecodeError:
                        logger.warning(f"JSON parsing failed: {data_str}")
                        continue

    except requests.exceptions.Timeout:
        yield "⚠️ Response timeout. Please try again."
    except requests.exceptions.RequestException as e:
        logger.error(f"Friendli API network error: {e}")
        yield f"⚠️ Network error occurred: {str(e)}"
    except Exception as e:
        logger.error(f"Friendli API error: {str(e)}")
        yield f"⚠️ API call error: {str(e)}"
##############################################################################
# Main Inference Function
##############################################################################
def run(
    message: dict,
    history: list[dict],
    max_new_tokens: int = 512,
    use_web_search: bool = False,
    use_korean: bool = False,
    system_prompt: str = "",
) -> Iterator[str]:
    try:
        # Prepare system message
        messages = []

        if use_korean:
            combined_system_msg = "너는 AI 어시스턴트 역할이다. 한국어로 친절하고 정확하게 답변해라."
        else:
            combined_system_msg = "You are an AI assistant. Please respond helpfully and accurately in English."

        if system_prompt.strip():
            combined_system_msg += f"\n\n{system_prompt.strip()}"

        # Web search processing
        if use_web_search:
            user_text = message.get("text", "")
            if user_text:
                ws_query = extract_keywords(user_text, top_k=5)
                if ws_query.strip():
                    logger.info(f"[Auto web search keywords] {ws_query!r}")
                    ws_result = do_web_search(ws_query, use_korean=use_korean)
                    # do_web_search's unavailable/timeout/failure messages start
                    # with "Web search"; only attach actual results
                    if not ws_result.startswith("Web search"):
                        combined_system_msg += f"\n\n[Search Results]\n{ws_result}"
                        if use_korean:
                            combined_system_msg += "\n\n[중요: 답변에 검색 결과의 출처를 반드시 인용하세요]"
                        else:
                            combined_system_msg += "\n\n[Important: Always cite sources from search results in your answer]"

        messages.append({
            "role": "system",
            "content": combined_system_msg
        })

        # Add conversation history
        messages.extend(process_history(history))

        # Process current message
        user_content = process_new_user_message(message)
        messages.append({
            "role": "user",
            "content": user_content
        })

        # Debug log
        logger.debug(f"Total messages: {len(messages)}")

        # Call Friendli API and stream the response
        for response_text in stream_friendli_response(messages, max_new_tokens):
            yield response_text

    except Exception as e:
        logger.error(f"run function error: {str(e)}")
        yield f"⚠️ Sorry, an error occurred: {str(e)}"
##############################################################################
# Examples
##############################################################################
examples = [
    # PDF comparison example
    [
        {
            "text": "Compare the contents of the two PDF files.",
            "files": [
                "assets/additional-examples/before.pdf",
                "assets/additional-examples/after.pdf",
            ],
        }
    ],
    # CSV analysis example
    [
        {
            "text": "Summarize and analyze the contents of the CSV file.",
            "files": ["assets/additional-examples/sample-csv.csv"],
        }
    ],
    # Web search example
    [
        {
            "text": "Explain discord.gg/openfreeai",
            "files": [],
        }
    ],
    # Code generation example
    [
        {
            "text": "Write Python code to generate the Fibonacci sequence.",
            "files": [],
        }
    ],
]
##############################################################################
# Gradio UI - CSS Styles (neutral palette, no blue accents)
##############################################################################
css = """
/* Full width UI */
.gradio-container {
    background: rgba(255, 255, 255, 0.95);
    padding: 30px 40px;
    margin: 20px auto;
    width: 100% !important;
    max-width: none !important;
    border-radius: 12px;
    box-shadow: 0 4px 6px rgba(0, 0, 0, 0.1);
}
.fillable {
    width: 100% !important;
    max-width: 100% !important;
}
/* Background */
body {
    background: linear-gradient(135deg, #f5f7fa 0%, #e0e0e0 100%);
    margin: 0;
    padding: 0;
    font-family: 'Segoe UI', 'Helvetica Neue', Arial, sans-serif;
    color: #333;
}
/* Button styles - neutral gray */
button, .btn {
    background: #6b7280 !important;
    border: none;
    color: white !important;
    padding: 10px 20px;
    text-transform: uppercase;
    font-weight: 600;
    letter-spacing: 0.5px;
    cursor: pointer;
    border-radius: 6px;
    transition: all 0.3s ease;
}
button:hover, .btn:hover {
    background: #4b5563 !important;
    transform: translateY(-1px);
    box-shadow: 0 2px 4px rgba(0, 0, 0, 0.2);
}
/* Examples section */
#examples_container, .examples-container {
    margin: 20px auto;
    width: 90%;
    background: rgba(255, 255, 255, 0.8);
    padding: 20px;
    border-radius: 8px;
}
#examples_row, .examples-row {
    justify-content: center;
}
/* Example buttons */
.gr-samples-table button,
.gr-examples button,
.examples button {
    background: #f0f2f5 !important;
    border: 1px solid #d1d5db;
    color: #374151 !important;
    margin: 5px;
    font-size: 14px;
}
.gr-samples-table button:hover,
.gr-examples button:hover,
.examples button:hover {
    background: #e5e7eb !important;
    border-color: #9ca3af;
}
/* Chat interface */
.chatbox, .chatbot {
    background: white !important;
    border-radius: 8px;
    box-shadow: 0 1px 3px rgba(0, 0, 0, 0.1);
}
.message {
    padding: 15px;
    margin: 10px 0;
    border-radius: 8px;
}
/* Input styles */
.multimodal-textbox, textarea, input[type="text"] {
    background: white !important;
    border: 1px solid #d1d5db;
    border-radius: 6px;
    padding: 10px;
    font-size: 16px;
}
.multimodal-textbox:focus, textarea:focus, input[type="text"]:focus {
    border-color: #6b7280;
    outline: none;
    box-shadow: 0 0 0 3px rgba(107, 114, 128, 0.1);
}
/* Warning messages */
.warning-box {
    background: #fef3c7 !important;
    border: 1px solid #f59e0b;
    border-radius: 8px;
    padding: 15px;
    margin: 10px 0;
    color: #92400e;
}
/* Headings */
h1, h2, h3 {
    color: #1f2937;
}
/* Links - neutral gray */
a {
    color: #6b7280;
    text-decoration: none;
}
a:hover {
    text-decoration: underline;
    color: #4b5563;
}
/* Slider */
.gr-slider {
    margin: 15px 0;
}
/* Checkbox */
input[type="checkbox"] {
    width: 18px;
    height: 18px;
    margin-right: 8px;
}
/* Scrollbar */
::-webkit-scrollbar {
    width: 8px;
    height: 8px;
}
::-webkit-scrollbar-track {
    background: #f1f1f1;
}
::-webkit-scrollbar-thumb {
    background: #888;
    border-radius: 4px;
}
::-webkit-scrollbar-thumb:hover {
    background: #555;
}
"""
##############################################################################
# Gradio UI Main
##############################################################################
with gr.Blocks(css=css, title="Gemma-3-R1984-27B Chatbot") as demo:
    # Title
    gr.Markdown("# 🤖 Gemma-3-R1984-27B Chatbot")
    gr.Markdown("Community: [https://discord.gg/openfreeai](https://discord.gg/openfreeai)")

    # UI Components
    with gr.Row():
        with gr.Column(scale=2):
            web_search_checkbox = gr.Checkbox(
                label="🔍 Enable Deep Research (Web Search)",
                value=False,
                info="Check for questions requiring latest information"
            )
        with gr.Column(scale=1):
            korean_checkbox = gr.Checkbox(
                label="🇰🇷 한글 (Korean)",
                value=False,
                info="Check for Korean responses"
            )
        with gr.Column(scale=1):
            max_tokens_slider = gr.Slider(
                label="Max Tokens",
                minimum=100,
                maximum=8000,
                step=50,
                value=1000,
                info="Adjust response length"
            )

    # Main chat interface
    chat = gr.ChatInterface(
        fn=run,
        type="messages",
        chatbot=gr.Chatbot(type="messages", scale=1),
        textbox=gr.MultimodalTextbox(
            file_types=[
                ".webp", ".png", ".jpg", ".jpeg", ".gif",
                ".mp4", ".csv", ".txt", ".pdf"
            ],
            file_count="multiple",
            autofocus=True,
            placeholder="Enter text or upload PDF, CSV, TXT files. (Images/videos not supported in this demo)"
        ),
        multimodal=True,
        additional_inputs=[
            max_tokens_slider,
            web_search_checkbox,
            korean_checkbox,
        ],
        stop_btn=False,
        examples=examples,
        run_examples_on_click=False,
        cache_examples=False,
        delete_cache=(1800, 1800),
    )

if __name__ == "__main__":
    demo.launch()