# app.py — openfree, commit 3cf27d9 (verified)
#!/usr/bin/env python
import os
import re
import json
import requests
from collections.abc import Iterator
from threading import Thread
import gradio as gr
from loguru import logger
import pandas as pd
import PyPDF2
##############################################################################
# API Configuration
##############################################################################
# Friendli AI inference token — required; fail fast at import time if absent.
FRIENDLI_TOKEN = os.environ.get("FRIENDLI_TOKEN")
if not FRIENDLI_TOKEN:
    raise ValueError("Please set FRIENDLI_TOKEN environment variable")
# Dedicated-endpoint deployment id and its chat-completions URL.
FRIENDLI_MODEL_ID = "dep89a2fld32mcm"
FRIENDLI_API_URL = "https://api.friendli.ai/dedicated/v1/chat/completions"
# SERPHouse API key — optional; without it web search degrades to a
# "not configured" message instead of failing (see do_web_search).
SERPHOUSE_API_KEY = os.getenv("SERPHOUSE_API_KEY", "")
if not SERPHOUSE_API_KEY:
    logger.warning("SERPHOUSE_API_KEY not set. Web search functionality will be limited.")
##############################################################################
# File Processing Constants
##############################################################################
MAX_FILE_SIZE = 30 * 1024 * 1024  # 30MB upload cap, enforced by validate_file_size
MAX_CONTENT_CHARS = 6000  # max characters of file content injected into the prompt
##############################################################################
# Improved Keyword Extraction
##############################################################################
def extract_keywords(text: str, top_k: int = 5) -> str:
    """Extract up to *top_k* keywords from *text* (English and Korean).

    Strips punctuation, drops common stop words (English function words and
    Korean postpositions) plus single-character tokens, and returns the
    surviving tokens joined by spaces — used to build a web-search query
    from free-form user input.
    """
    # NOTE: the original stop words ('λŠ”', 'κ°€', ...) and the regex range
    # 'κ°€-힣' were mojibake and could never match real Korean text;
    # restored to the intended particles and the Hangul syllable range.
    stop_words = {
        '은', '는', '이', '가', '을', '를', '의', '에', '에서',
        'the', 'is', 'at', 'on', 'in', 'a', 'an', 'and', 'or', 'but',
    }
    # Keep only ASCII alphanumerics, Hangul syllables and whitespace.
    text = re.sub(r"[^a-zA-Z0-9가-힣\s]", "", text)
    tokens = text.split()
    key_tokens = [
        token for token in tokens
        if token.lower() not in stop_words and len(token) > 1
    ][:top_k]
    return " ".join(key_tokens)
##############################################################################
# File Size Validation
##############################################################################
def validate_file_size(file_path: str) -> bool:
    """Return True iff *file_path* exists and is at most MAX_FILE_SIZE bytes.

    Returns False rather than raising, so callers can treat any stat
    failure (missing file, permission error) as "reject the upload".
    """
    try:
        # os.path.getsize raises OSError for missing/inaccessible paths.
        # The previous bare `except:` also swallowed KeyboardInterrupt and
        # SystemExit; narrow the handler to what getsize can actually raise.
        return os.path.getsize(file_path) <= MAX_FILE_SIZE
    except OSError:
        return False
##############################################################################
# Web Search Function
##############################################################################
def do_web_search(query: str, use_korean: bool = False) -> str:
    """Run *query* through the SerpHouse live-SERP API and return the top
    organic results formatted as Markdown.

    Returns a human-readable error string (never raises) when the API key
    is missing, the request fails, or the response shape is unexpected.
    """
    if not SERPHOUSE_API_KEY:
        return "Web search unavailable. API key not configured."
    try:
        url = "https://api.serphouse.com/serp/live"
        params = {
            "q": query,
            "domain": "google.com",
            "serp_type": "web",
            "device": "desktop",
            # SerpHouse result language; Korean only when requested.
            "lang": "ko" if use_korean else "en",
            "num": "20"
        }
        headers = {
            "Authorization": f"Bearer {SERPHOUSE_API_KEY}"
        }
        logger.info(f"Calling SerpHouse API... Query: {query}")
        response = requests.get(url, headers=headers, params=params, timeout=30)
        response.raise_for_status()
        data = response.json()
        # The organic list has been observed under several response shapes;
        # probe them in order: results.organic, results.results.organic,
        # then top-level data.organic.
        results = data.get("results", {})
        organic = None
        if isinstance(results, dict) and "organic" in results:
            organic = results["organic"]
        elif isinstance(results, dict) and "results" in results:
            if isinstance(results["results"], dict) and "organic" in results["results"]:
                organic = results["results"]["organic"]
        elif "organic" in data:
            organic = data["organic"]
        if not organic:
            return "No search results found or unexpected API response structure."
        # Cap at 20 entries regardless of what the API returned.
        max_results = min(20, len(organic))
        limited_organic = organic[:max_results]
        summary_lines = []
        for idx, item in enumerate(limited_organic, start=1):
            title = item.get("title", "No title")
            link = item.get("link", "#")
            snippet = item.get("snippet", "No description")
            displayed_link = item.get("displayed_link", link)
            summary_lines.append(
                f"### Result {idx}: {title}\n\n"
                f"{snippet}\n\n"
                f"**Source**: [{displayed_link}]({link})\n\n"
                f"---\n"
            )
        # Preamble instructing the model to cite these results.
        instructions = """
# Web Search Results
Below are the search results. Use this information when answering questions:
1. Reference the title, content, and source links
2. Explicitly cite sources in your answer (e.g., "According to source X...")
3. Include actual source links in your response
4. Synthesize information from multiple sources
"""
        search_results = instructions + "\n".join(summary_lines)
        return search_results
    except requests.exceptions.Timeout:
        logger.error("Web search timeout")
        return "Web search timed out. Please try again."
    except requests.exceptions.RequestException as e:
        logger.error(f"Web search network error: {e}")
        return "Network error during web search."
    except Exception as e:
        # Last-resort catch so a search failure never crashes the chat turn.
        logger.error(f"Web search failed: {e}")
        return f"Web search failed: {str(e)}"
##############################################################################
# File Analysis Functions
##############################################################################
def analyze_csv_file(path: str) -> str:
    """Summarize a CSV file as Markdown: dimensions, columns, row preview.

    Tries several encodings (Korean files are often cp949/euc-kr), previews
    at most 50 rows and 10 columns, and truncates the rendered table to
    MAX_CONTENT_CHARS. Returns an error string instead of raising.
    """
    if not validate_file_size(path):
        return f"⚠️ Error: File size exceeds {MAX_FILE_SIZE/1024/1024:.1f}MB limit."
    try:
        encodings = ['utf-8', 'cp949', 'euc-kr', 'latin-1']
        df = None
        used_encoding = None
        for encoding in encodings:
            try:
                df = pd.read_csv(path, encoding=encoding, nrows=50)
                # Remember which encoding worked instead of relying on the
                # leaked loop variable for the row-count re-read below.
                used_encoding = encoding
                break
            except UnicodeDecodeError:
                continue
        if df is None:
            return "Failed to read CSV: Unsupported encoding"
        # Count all rows cheaply by re-reading only the first column.
        total_rows = len(pd.read_csv(path, encoding=used_encoding, usecols=[0]))
        # Capture the real column count BEFORE truncating the preview —
        # the original reported at most 10 columns even for wider files.
        total_cols = df.shape[1]
        if total_cols > 10:
            df = df.iloc[:, :10]
        summary = f"**Data size**: {total_rows} rows x {total_cols} columns\n"
        summary += f"**Showing**: Top {min(50, total_rows)} rows\n"
        summary += f"**Columns**: {', '.join(df.columns)}\n\n"
        df_str = df.to_string()
        if len(df_str) > MAX_CONTENT_CHARS:
            df_str = df_str[:MAX_CONTENT_CHARS] + "\n...(truncated)..."
        return f"**[CSV File: {os.path.basename(path)}]**\n\n{summary}{df_str}"
    except Exception as e:
        logger.error(f"CSV read error: {e}")
        return f"Failed to read CSV file ({os.path.basename(path)}): {str(e)}"
def analyze_txt_file(path: str) -> str:
    """Render a text file as Markdown, probing a list of common encodings.

    The first encoding that decodes the whole file wins; content is capped
    at MAX_CONTENT_CHARS characters. Returns an error string on failure.
    """
    if not validate_file_size(path):
        return f"⚠️ Error: File size exceeds {MAX_FILE_SIZE/1024/1024:.1f}MB limit."
    for codec in ('utf-8', 'cp949', 'euc-kr', 'latin-1', 'utf-16'):
        try:
            with open(path, "r", encoding=codec) as handle:
                body = handle.read()
        except UnicodeDecodeError:
            # Wrong encoding — try the next candidate.
            continue
        size_info = f"**File size**: {os.path.getsize(path)/1024:.1f}KB\n\n"
        if len(body) > MAX_CONTENT_CHARS:
            body = body[:MAX_CONTENT_CHARS] + "\n...(truncated)..."
        return f"**[TXT File: {os.path.basename(path)}]**\n\n{size_info}{body}"
    return f"Failed to read text file ({os.path.basename(path)}): Unsupported encoding"
def pdf_to_markdown(pdf_path: str) -> str:
    """Convert the first pages of a PDF to Markdown text.

    Reads at most 5 pages with PyPDF2, dividing the MAX_CONTENT_CHARS
    budget evenly across them; per-page extraction failures are reported
    inline rather than aborting the whole file. Returns an error string
    on failure instead of raising.
    """
    if not validate_file_size(pdf_path):
        return f"⚠️ Error: File size exceeds {MAX_FILE_SIZE/1024/1024:.1f}MB limit."
    text_chunks = []
    try:
        with open(pdf_path, "rb") as f:
            reader = PyPDF2.PdfReader(f)
            total_pages = len(reader.pages)
            max_pages = min(5, total_pages)
            text_chunks.append(f"**Total pages**: {total_pages}")
            text_chunks.append(f"**Showing**: First {max_pages} pages\n")
            for page_num in range(max_pages):
                try:
                    page = reader.pages[page_num]
                    # extract_text() may return None (e.g. image-only pages).
                    page_text = page.extract_text() or ""
                    page_text = page_text.strip()
                    if page_text:
                        # Split the character budget evenly across pages.
                        if len(page_text) > MAX_CONTENT_CHARS // max_pages:
                            page_text = page_text[:MAX_CONTENT_CHARS // max_pages] + "...(truncated)"
                        text_chunks.append(f"## Page {page_num+1}\n\n{page_text}\n")
                except Exception as e:
                    # Keep going: one unreadable page shouldn't lose the rest.
                    text_chunks.append(f"## Page {page_num+1}\n\nFailed to read page: {str(e)}\n")
            if total_pages > max_pages:
                text_chunks.append(f"\n...({max_pages}/{total_pages} pages shown)...")
    except Exception as e:
        logger.error(f"PDF read error: {e}")
        return f"Failed to read PDF file ({os.path.basename(pdf_path)}): {str(e)}"
    full_text = "\n".join(text_chunks)
    # Final safety cap on the combined output length.
    if len(full_text) > MAX_CONTENT_CHARS:
        full_text = full_text[:MAX_CONTENT_CHARS] + "\n...(truncated)..."
    return f"**[PDF File: {os.path.basename(pdf_path)}]**\n\n{full_text}"
##############################################################################
# File Type Check Functions
##############################################################################
def is_image_file(file_path: str) -> bool:
    """Return True when *file_path* has an image extension (case-insensitive)."""
    return file_path.lower().endswith((".png", ".jpg", ".jpeg", ".gif", ".webp"))
def is_video_file(file_path: str) -> bool:
    """Return True when *file_path* has a video extension (case-insensitive)."""
    return file_path.lower().endswith((".mp4", ".avi", ".mov", ".mkv"))
def is_document_file(file_path: str) -> bool:
    """Return True when *file_path* is a supported document type (pdf/csv/txt)."""
    return file_path.lower().endswith((".pdf", ".csv", ".txt"))
##############################################################################
# Message Processing Functions
##############################################################################
def process_new_user_message(message: dict) -> str:
    """Flatten a Gradio multimodal user message into a single text prompt.

    *message* carries a "text" field and an optional "files" list.
    Document files (CSV/TXT/PDF) are analyzed and their summaries appended
    to the text; image/video/unknown uploads produce warning notes.

    Fix: read the text via .get() so a files-only message with no "text"
    key no longer raises KeyError (the rest of the file already uses
    message.get("text", "")).
    """
    text = message.get("text", "")
    files = message.get("files")
    if not files:
        return text
    content_parts = [text]
    # Bucket the uploads by type so related analyses stay grouped.
    csv_files = []
    txt_files = []
    pdf_files = []
    image_files = []
    video_files = []
    unknown_files = []
    for file_path in files:
        lower = file_path.lower()
        if lower.endswith(".csv"):
            csv_files.append(file_path)
        elif lower.endswith(".txt"):
            txt_files.append(file_path)
        elif lower.endswith(".pdf"):
            pdf_files.append(file_path)
        elif is_image_file(file_path):
            image_files.append(file_path)
        elif is_video_file(file_path):
            video_files.append(file_path)
        else:
            unknown_files.append(file_path)
    # Analyze document files and append the summaries to the prompt.
    for csv_path in csv_files:
        content_parts.append(analyze_csv_file(csv_path))
    for txt_path in txt_files:
        content_parts.append(analyze_txt_file(txt_path))
    for pdf_path in pdf_files:
        content_parts.append(pdf_to_markdown(pdf_path))
    # Unsupported media: warn explicitly instead of silently ignoring.
    if image_files:
        image_names = [os.path.basename(f) for f in image_files]
        content_parts.append(
            f"\n⚠️ **Image files detected**: {', '.join(image_names)}\n"
            "This demo currently does not support image analysis. "
            "Please describe the image content in text if you need help with it."
        )
    if video_files:
        video_names = [os.path.basename(f) for f in video_files]
        content_parts.append(
            f"\n⚠️ **Video files detected**: {', '.join(video_names)}\n"
            "This demo currently does not support video analysis. "
            "Please describe the video content in text if you need help with it."
        )
    if unknown_files:
        unknown_names = [os.path.basename(f) for f in unknown_files]
        content_parts.append(
            f"\n⚠️ **Unsupported file format**: {', '.join(unknown_names)}\n"
            "Supported formats: PDF, CSV, TXT"
        )
    return "\n\n".join(content_parts)
def process_history(history: list[dict]) -> list[dict]:
    """Translate Gradio "messages"-style history into Friendli chat messages.

    Assistant turns pass through unchanged. User turns may carry either a
    plain string or a list of file paths; file lists are collapsed into a
    "[File: name]" placeholder string (dropped entirely if empty).
    """
    messages = []
    for entry in history:
        role = entry["role"]
        content = entry["content"]
        if role == "assistant":
            messages.append({"role": "assistant", "content": content})
        elif isinstance(content, str):
            messages.append({"role": "user", "content": content})
        elif isinstance(content, list) and content:
            # Represent attached files by name only — the API gets text.
            placeholders = [
                f"[File: {os.path.basename(p)}]"
                for p in content
                if isinstance(p, str)
            ]
            if placeholders:
                messages.append({"role": "user", "content": " ".join(placeholders)})
    return messages
##############################################################################
# Streaming Response Handler
##############################################################################
def stream_friendli_response(messages: list[dict], max_tokens: int = 1000) -> Iterator[str]:
    """Stream a chat completion from the Friendli AI dedicated endpoint.

    Yields the *cumulative* response text after each received chunk
    (Gradio replaces the displayed message on every yield). API/network
    failures are yielded as user-visible warning strings, never raised.
    """
    headers = {
        "Authorization": f"Bearer {FRIENDLI_TOKEN}",
        "Content-Type": "application/json"
    }
    payload = {
        "model": FRIENDLI_MODEL_ID,
        "messages": messages,
        "max_tokens": max_tokens,
        "top_p": 0.8,
        "temperature": 0.7,
        "stream": True,
        "stream_options": {
            "include_usage": True
        }
    }
    try:
        response = requests.post(
            FRIENDLI_API_URL,
            headers=headers,
            json=payload,
            stream=True,
            timeout=60
        )
        response.raise_for_status()
        full_response = ""
        # Server-sent events: each "data: " line carries one JSON delta.
        for line in response.iter_lines():
            if line:
                line_text = line.decode('utf-8')
                if line_text.startswith("data: "):
                    data_str = line_text[6:]  # strip the "data: " prefix
                    if data_str == "[DONE]":
                        break
                    try:
                        data = json.loads(data_str)
                        if "choices" in data and len(data["choices"]) > 0:
                            delta = data["choices"][0].get("delta", {})
                            content = delta.get("content", "")
                            if content:
                                full_response += content
                                yield full_response
                    except json.JSONDecodeError:
                        # Skip malformed SSE chunks rather than aborting.
                        logger.warning(f"JSON parsing failed: {data_str}")
                        continue
    except requests.exceptions.Timeout:
        yield "⚠️ Response timeout. Please try again."
    except requests.exceptions.RequestException as e:
        logger.error(f"Friendli API network error: {e}")
        yield f"⚠️ Network error occurred: {str(e)}"
    except Exception as e:
        logger.error(f"Friendli API error: {str(e)}")
        yield f"⚠️ API call error: {str(e)}"
##############################################################################
# Main Inference Function
##############################################################################
def run(
    message: dict,
    history: list[dict],
    max_new_tokens: int = 512,
    use_web_search: bool = False,
    use_korean: bool = False,
    system_prompt: str = "",
) -> Iterator[str]:
    """Top-level chat handler: build the message list and stream the reply.

    Yields progressively longer response strings (Gradio streaming
    convention). Never raises — errors are yielded as warning text.

    Fixes: (1) the Korean system-prompt strings were mojibake ("λ„ˆλŠ” ...")
    and sent garbage to the model — restored to readable Korean;
    (2) the search-error sentinel only checked the "Web search" prefix and
    would inject the "No search results..." / "Network error..." failure
    strings into the prompt as if they were results.
    """
    try:
        messages = []
        # Base system prompt in the user's requested language.
        if use_korean:
            combined_system_msg = "너는 AI 어시스턴트 역할이다. 한국어로 친절하고 정확하게 답변해라."
        else:
            combined_system_msg = "You are an AI assistant. Please respond helpfully and accurately in English."
        if system_prompt.strip():
            combined_system_msg += f"\n\n{system_prompt.strip()}"
        # Optional web search: distill the user text into keywords, search,
        # and splice genuine results into the system prompt.
        if use_web_search:
            user_text = message.get("text", "")
            if user_text:
                ws_query = extract_keywords(user_text, top_k=5)
                if ws_query.strip():
                    logger.info(f"[Auto web search keywords] {ws_query!r}")
                    ws_result = do_web_search(ws_query, use_korean=use_korean)
                    # do_web_search signals failure via fixed message
                    # prefixes; skip injecting any of them.
                    failure_prefixes = ("Web search", "No search results", "Network error")
                    if not ws_result.startswith(failure_prefixes):
                        combined_system_msg += f"\n\n[Search Results]\n{ws_result}"
                        if use_korean:
                            combined_system_msg += "\n\n[중요: 답변에 검색 결과의 출처를 반드시 인용하세요]"
                        else:
                            combined_system_msg += "\n\n[Important: Always cite sources from search results in your answer]"
        messages.append({
            "role": "system",
            "content": combined_system_msg
        })
        # Prior turns, then the current (possibly multimodal) message.
        messages.extend(process_history(history))
        user_content = process_new_user_message(message)
        messages.append({
            "role": "user",
            "content": user_content
        })
        logger.debug(f"Total messages: {len(messages)}")
        # Stream cumulative response text straight through to the UI.
        for response_text in stream_friendli_response(messages, max_new_tokens):
            yield response_text
    except Exception as e:
        logger.error(f"run function error: {str(e)}")
        yield f"⚠️ Sorry, an error occurred: {str(e)}"
##############################################################################
# Examples
##############################################################################
# Clickable example prompts for the chat UI. Each entry is a one-element
# list holding a multimodal message dict; the file paths are relative to
# the app repository and are not validated at import time.
examples = [
    # PDF comparison example
    [
        {
            "text": "Compare the contents of the two PDF files.",
            "files": [
                "assets/additional-examples/before.pdf",
                "assets/additional-examples/after.pdf",
            ],
        }
    ],
    # CSV analysis example
    [
        {
            "text": "Summarize and analyze the contents of the CSV file.",
            "files": ["assets/additional-examples/sample-csv.csv"],
        }
    ],
    # Web search example
    [
        {
            "text": "Explain discord.gg/openfreeai",
            "files": [],
        }
    ],
    # Code generation example
    [
        {
            "text": "Write Python code to generate Fibonacci sequence.",
            "files": [],
        }
    ],
]
##############################################################################
# Gradio UI - CSS Styles (Removed blue colors)
##############################################################################
# Custom stylesheet passed to gr.Blocks: neutral gray palette,
# full-width layout, and restyled example/chat/input widgets.
css: str = """
/* Full width UI */
.gradio-container {
background: rgba(255, 255, 255, 0.95);
padding: 30px 40px;
margin: 20px auto;
width: 100% !important;
max-width: none !important;
border-radius: 12px;
box-shadow: 0 4px 6px rgba(0, 0, 0, 0.1);
}
.fillable {
width: 100% !important;
max-width: 100% !important;
}
/* Background */
body {
background: linear-gradient(135deg, #f5f7fa 0%, #e0e0e0 100%);
margin: 0;
padding: 0;
font-family: 'Segoe UI', 'Helvetica Neue', Arial, sans-serif;
color: #333;
}
/* Button styles - neutral gray */
button, .btn {
background: #6b7280 !important;
border: none;
color: white !important;
padding: 10px 20px;
text-transform: uppercase;
font-weight: 600;
letter-spacing: 0.5px;
cursor: pointer;
border-radius: 6px;
transition: all 0.3s ease;
}
button:hover, .btn:hover {
background: #4b5563 !important;
transform: translateY(-1px);
box-shadow: 0 2px 4px rgba(0, 0, 0, 0.2);
}
/* Examples section */
#examples_container, .examples-container {
margin: 20px auto;
width: 90%;
background: rgba(255, 255, 255, 0.8);
padding: 20px;
border-radius: 8px;
}
#examples_row, .examples-row {
justify-content: center;
}
/* Example buttons */
.gr-samples-table button,
.gr-examples button,
.examples button {
background: #f0f2f5 !important;
border: 1px solid #d1d5db;
color: #374151 !important;
margin: 5px;
font-size: 14px;
}
.gr-samples-table button:hover,
.gr-examples button:hover,
.examples button:hover {
background: #e5e7eb !important;
border-color: #9ca3af;
}
/* Chat interface */
.chatbox, .chatbot {
background: white !important;
border-radius: 8px;
box-shadow: 0 1px 3px rgba(0, 0, 0, 0.1);
}
.message {
padding: 15px;
margin: 10px 0;
border-radius: 8px;
}
/* Input styles */
.multimodal-textbox, textarea, input[type="text"] {
background: white !important;
border: 1px solid #d1d5db;
border-radius: 6px;
padding: 10px;
font-size: 16px;
}
.multimodal-textbox:focus, textarea:focus, input[type="text"]:focus {
border-color: #6b7280;
outline: none;
box-shadow: 0 0 0 3px rgba(107, 114, 128, 0.1);
}
/* Warning messages */
.warning-box {
background: #fef3c7 !important;
border: 1px solid #f59e0b;
border-radius: 8px;
padding: 15px;
margin: 10px 0;
color: #92400e;
}
/* Headings */
h1, h2, h3 {
color: #1f2937;
}
/* Links - neutral gray */
a {
color: #6b7280;
text-decoration: none;
}
a:hover {
text-decoration: underline;
color: #4b5563;
}
/* Slider */
.gr-slider {
margin: 15px 0;
}
/* Checkbox */
input[type="checkbox"] {
width: 18px;
height: 18px;
margin-right: 8px;
}
/* Scrollbar */
::-webkit-scrollbar {
width: 8px;
height: 8px;
}
::-webkit-scrollbar-track {
background: #f1f1f1;
}
::-webkit-scrollbar-thumb {
background: #888;
border-radius: 4px;
}
::-webkit-scrollbar-thumb:hover {
background: #555;
}
"""
##############################################################################
# Gradio UI Main
##############################################################################
# Gradio UI assembly. Fix: the emoji/Korean widget labels were mojibake
# ("πŸ€—", "πŸ”", "πŸ‡°πŸ‡· ν•œκΈ€"); restored to the intended characters.
with gr.Blocks(css=css, title="Gemma-3-R1984-27B Chatbot") as demo:
    # Title / community link
    gr.Markdown("# 🤗 Gemma-3-R1984-27B Chatbot")
    gr.Markdown("Community: [https://discord.gg/openfreeai](https://discord.gg/openfreeai)")
    # Options row: web search toggle, Korean toggle, response length.
    with gr.Row():
        with gr.Column(scale=2):
            web_search_checkbox = gr.Checkbox(
                label="🔍 Enable Deep Research (Web Search)",
                value=False,
                info="Check for questions requiring latest information"
            )
        with gr.Column(scale=1):
            korean_checkbox = gr.Checkbox(
                label="🇰🇷 한글 (Korean)",
                value=False,
                info="Check for Korean responses"
            )
        with gr.Column(scale=1):
            max_tokens_slider = gr.Slider(
                label="Max Tokens",
                minimum=100,
                maximum=8000,
                step=50,
                value=1000,
                info="Adjust response length"
            )
    # Main chat interface. additional_inputs order must match run()'s
    # (max_new_tokens, use_web_search, use_korean) parameter order.
    chat = gr.ChatInterface(
        fn=run,
        type="messages",
        chatbot=gr.Chatbot(type="messages", scale=1),
        textbox=gr.MultimodalTextbox(
            file_types=[
                ".webp", ".png", ".jpg", ".jpeg", ".gif",
                ".mp4", ".csv", ".txt", ".pdf"
            ],
            file_count="multiple",
            autofocus=True,
            placeholder="Enter text or upload PDF, CSV, TXT files. (Images/videos not supported in this demo)"
        ),
        multimodal=True,
        additional_inputs=[
            max_tokens_slider,
            web_search_checkbox,
            korean_checkbox,
        ],
        stop_btn=False,
        examples=examples,
        run_examples_on_click=False,
        cache_examples=False,
        delete_cache=(1800, 1800),
    )

if __name__ == "__main__":
    demo.launch()