#!/usr/bin/env python
"""Gradio chatbot backed by the Friendli AI chat-completions API.

Supports PDF/CSV/TXT file analysis and optional SerpHouse web search.
"""
import os
import re
import json
import requests
from collections.abc import Iterator
from threading import Thread

import gradio as gr
from loguru import logger
import pandas as pd
import PyPDF2

##############################################################################
# API Configuration
##############################################################################
FRIENDLI_TOKEN = os.environ.get("FRIENDLI_TOKEN")
if not FRIENDLI_TOKEN:
    raise ValueError("Please set FRIENDLI_TOKEN environment variable")

FRIENDLI_MODEL_ID = "dep89a2fld32mcm"
FRIENDLI_API_URL = "https://api.friendli.ai/dedicated/v1/chat/completions"

# SERPHouse API key (optional: web search is disabled when missing)
SERPHOUSE_API_KEY = os.getenv("SERPHOUSE_API_KEY", "")
if not SERPHOUSE_API_KEY:
    logger.warning("SERPHOUSE_API_KEY not set. Web search functionality will be limited.")

##############################################################################
# File Processing Constants
##############################################################################
MAX_FILE_SIZE = 30 * 1024 * 1024  # 30MB
MAX_CONTENT_CHARS = 6000

# Prefixes of every non-result message do_web_search() can return.  Any
# return value starting with one of these is an error/empty message and must
# NOT be injected into the system prompt as search results (see run()).
_WS_ERROR_PREFIXES = ("Web search", "Network error", "No search results")


##############################################################################
# Improved Keyword Extraction
##############################################################################
def extract_keywords(text: str, top_k: int = 5) -> str:
    """Extract up to *top_k* keywords from *text*.

    Supports English and Korean: strips everything but alphanumerics and
    Hangul, then drops common stop words and single-character tokens.

    Returns the keywords joined with spaces (possibly an empty string).
    """
    stop_words = {'은', '는', '이', '가', '을', '를', '의', '에', '에서',
                  'the', 'is', 'at', 'on', 'in', 'a', 'an', 'and', 'or', 'but'}
    text = re.sub(r"[^a-zA-Z0-9가-힣\s]", "", text)
    tokens = text.split()
    key_tokens = [
        token for token in tokens
        if token.lower() not in stop_words and len(token) > 1
    ][:top_k]
    return " ".join(key_tokens)


##############################################################################
# File Size Validation
##############################################################################
def validate_file_size(file_path: str) -> bool:
    """Return True when the file exists and is within MAX_FILE_SIZE."""
    try:
        file_size = os.path.getsize(file_path)
        return file_size <= MAX_FILE_SIZE
    except OSError:
        # Fixed: was a bare `except:`, which also swallowed
        # KeyboardInterrupt/SystemExit.  Only OS-level failures
        # (missing file, permission error) mean "not valid".
        return False


##############################################################################
# Web Search Function
##############################################################################
def do_web_search(query: str, use_korean: bool = False) -> str:
    """Search the web via SerpHouse and return the top 20 organic results.

    Returns a markdown summary on success, or a human-readable error
    message (see _WS_ERROR_PREFIXES) on failure.
    """
    if not SERPHOUSE_API_KEY:
        return "Web search unavailable. API key not configured."
    try:
        url = "https://api.serphouse.com/serp/live"
        params = {
            "q": query,
            "domain": "google.com",
            "serp_type": "web",
            "device": "desktop",
            "lang": "ko" if use_korean else "en",
            "num": "20"
        }
        headers = {
            "Authorization": f"Bearer {SERPHOUSE_API_KEY}"
        }
        logger.info(f"Calling SerpHouse API... Query: {query}")
        response = requests.get(url, headers=headers, params=params, timeout=30)
        response.raise_for_status()
        data = response.json()

        # The API nests results differently depending on plan/version;
        # probe the known layouts in order.
        results = data.get("results", {})
        organic = None
        if isinstance(results, dict) and "organic" in results:
            organic = results["organic"]
        elif isinstance(results, dict) and "results" in results:
            if isinstance(results["results"], dict) and "organic" in results["results"]:
                organic = results["results"]["organic"]
        elif "organic" in data:
            organic = data["organic"]

        if not organic:
            return "No search results found or unexpected API response structure."

        max_results = min(20, len(organic))
        limited_organic = organic[:max_results]

        summary_lines = []
        for idx, item in enumerate(limited_organic, start=1):
            title = item.get("title", "No title")
            link = item.get("link", "#")
            snippet = item.get("snippet", "No description")
            displayed_link = item.get("displayed_link", link)
            summary_lines.append(
                f"### Result {idx}: {title}\n\n"
                f"{snippet}\n\n"
                f"**Source**: [{displayed_link}]({link})\n\n"
                f"---\n"
            )

        instructions = """
# Web Search Results
Below are the search results. Use this information when answering questions:
1. Reference the title, content, and source links
2. Explicitly cite sources in your answer (e.g., "According to source X...")
3. Include actual source links in your response
4. Synthesize information from multiple sources
"""
        search_results = instructions + "\n".join(summary_lines)
        return search_results

    except requests.exceptions.Timeout:
        logger.error("Web search timeout")
        return "Web search timed out. Please try again."
    except requests.exceptions.RequestException as e:
        logger.error(f"Web search network error: {e}")
        return "Network error during web search."
    except Exception as e:
        logger.error(f"Web search failed: {e}")
        return f"Web search failed: {str(e)}"


##############################################################################
# File Analysis Functions
##############################################################################
def analyze_csv_file(path: str) -> str:
    """Analyze a CSV file with size validation and encoding fallback.

    Reads at most 50 rows / 10 columns and truncates the rendered table
    to MAX_CONTENT_CHARS.  Returns a markdown summary or an error string.
    """
    if not validate_file_size(path):
        return f"⚠️ Error: File size exceeds {MAX_FILE_SIZE/1024/1024:.1f}MB limit."
    try:
        encodings = ['utf-8', 'cp949', 'euc-kr', 'latin-1']
        df = None
        used_encoding = None  # capture explicitly instead of relying on the loop variable after the loop
        for encoding in encodings:
            try:
                df = pd.read_csv(path, encoding=encoding, nrows=50)
                used_encoding = encoding
                break
            except UnicodeDecodeError:
                continue
        if df is None:
            return "Failed to read CSV: Unsupported encoding"

        # Second pass only reads the first column, just to count rows cheaply.
        total_rows = len(pd.read_csv(path, encoding=used_encoding, usecols=[0]))
        if df.shape[1] > 10:
            df = df.iloc[:, :10]

        summary = f"**Data size**: {total_rows} rows x {df.shape[1]} columns\n"
        summary += f"**Showing**: Top {min(50, total_rows)} rows\n"
        summary += f"**Columns**: {', '.join(df.columns)}\n\n"

        df_str = df.to_string()
        if len(df_str) > MAX_CONTENT_CHARS:
            df_str = df_str[:MAX_CONTENT_CHARS] + "\n...(truncated)..."
        return f"**[CSV File: {os.path.basename(path)}]**\n\n{summary}{df_str}"
    except Exception as e:
        logger.error(f"CSV read error: {e}")
        return f"Failed to read CSV file ({os.path.basename(path)}): {str(e)}"


def analyze_txt_file(path: str) -> str:
    """Analyze a text file, trying several encodings in order.

    Returns a markdown summary (truncated to MAX_CONTENT_CHARS) or an
    error string when no encoding works.
    """
    if not validate_file_size(path):
        return f"⚠️ Error: File size exceeds {MAX_FILE_SIZE/1024/1024:.1f}MB limit."

    encodings = ['utf-8', 'cp949', 'euc-kr', 'latin-1', 'utf-16']
    for encoding in encodings:
        try:
            with open(path, "r", encoding=encoding) as f:
                text = f.read()
            file_size = os.path.getsize(path)
            size_info = f"**File size**: {file_size/1024:.1f}KB\n\n"
            if len(text) > MAX_CONTENT_CHARS:
                text = text[:MAX_CONTENT_CHARS] + "\n...(truncated)..."
            return f"**[TXT File: {os.path.basename(path)}]**\n\n{size_info}{text}"
        except UnicodeDecodeError:
            continue
    return f"Failed to read text file ({os.path.basename(path)}): Unsupported encoding"


def pdf_to_markdown(pdf_path: str) -> str:
    """Convert the first pages of a PDF to markdown.

    Extracts at most 5 pages; per-page failures are reported inline so one
    bad page does not lose the rest.  Output is capped at MAX_CONTENT_CHARS.
    """
    if not validate_file_size(pdf_path):
        return f"⚠️ Error: File size exceeds {MAX_FILE_SIZE/1024/1024:.1f}MB limit."

    text_chunks = []
    try:
        with open(pdf_path, "rb") as f:
            reader = PyPDF2.PdfReader(f)
            total_pages = len(reader.pages)
            max_pages = min(5, total_pages)
            text_chunks.append(f"**Total pages**: {total_pages}")
            text_chunks.append(f"**Showing**: First {max_pages} pages\n")
            for page_num in range(max_pages):
                try:
                    page = reader.pages[page_num]
                    page_text = page.extract_text() or ""
                    page_text = page_text.strip()
                    if page_text:
                        # Give each page an equal share of the character budget.
                        if len(page_text) > MAX_CONTENT_CHARS // max_pages:
                            page_text = page_text[:MAX_CONTENT_CHARS // max_pages] + "...(truncated)"
                        text_chunks.append(f"## Page {page_num+1}\n\n{page_text}\n")
                except Exception as e:
                    text_chunks.append(f"## Page {page_num+1}\n\nFailed to read page: {str(e)}\n")
            if total_pages > max_pages:
                text_chunks.append(f"\n...({max_pages}/{total_pages} pages shown)...")
    except Exception as e:
        logger.error(f"PDF read error: {e}")
        return f"Failed to read PDF file ({os.path.basename(pdf_path)}): {str(e)}"

    full_text = "\n".join(text_chunks)
    if len(full_text) > MAX_CONTENT_CHARS:
        full_text = full_text[:MAX_CONTENT_CHARS] + "\n...(truncated)..."
    return f"**[PDF File: {os.path.basename(pdf_path)}]**\n\n{full_text}"


##############################################################################
# File Type Check Functions
##############################################################################
def is_image_file(file_path: str) -> bool:
    """Check if file is an image (by extension)."""
    return bool(re.search(r"\.(png|jpg|jpeg|gif|webp)$", file_path, re.IGNORECASE))


def is_video_file(file_path: str) -> bool:
    """Check if file is a video (by extension)."""
    return bool(re.search(r"\.(mp4|avi|mov|mkv)$", file_path, re.IGNORECASE))


def is_document_file(file_path: str) -> bool:
    """Check if file is a supported document (by extension)."""
    return bool(re.search(r"\.(pdf|csv|txt)$", file_path, re.IGNORECASE))


##############################################################################
# Message Processing Functions
##############################################################################
def process_new_user_message(message: dict) -> str:
    """Flatten a multimodal user message into a single text prompt.

    Documents (CSV/TXT/PDF) are analyzed and appended; images/videos and
    unknown formats only produce warning notes since the demo cannot
    process them.
    """
    content_parts = [message["text"]]
    if not message.get("files"):
        return message["text"]

    # Classify attached files by extension.
    csv_files = []
    txt_files = []
    pdf_files = []
    image_files = []
    video_files = []
    unknown_files = []
    for file_path in message["files"]:
        if file_path.lower().endswith(".csv"):
            csv_files.append(file_path)
        elif file_path.lower().endswith(".txt"):
            txt_files.append(file_path)
        elif file_path.lower().endswith(".pdf"):
            pdf_files.append(file_path)
        elif is_image_file(file_path):
            image_files.append(file_path)
        elif is_video_file(file_path):
            video_files.append(file_path)
        else:
            unknown_files.append(file_path)

    # Process document files.
    for csv_path in csv_files:
        csv_analysis = analyze_csv_file(csv_path)
        content_parts.append(csv_analysis)
    for txt_path in txt_files:
        txt_analysis = analyze_txt_file(txt_path)
        content_parts.append(txt_analysis)
    for pdf_path in pdf_files:
        pdf_markdown = pdf_to_markdown(pdf_path)
        content_parts.append(pdf_markdown)

    # Warning messages for unsupported files.
    if image_files:
        image_names = [os.path.basename(f) for f in image_files]
        content_parts.append(
            f"\n⚠️ **Image files detected**: {', '.join(image_names)}\n"
            "This demo currently does not support image analysis. "
            "Please describe the image content in text if you need help with it."
        )
    if video_files:
        video_names = [os.path.basename(f) for f in video_files]
        content_parts.append(
            f"\n⚠️ **Video files detected**: {', '.join(video_names)}\n"
            "This demo currently does not support video analysis. "
            "Please describe the video content in text if you need help with it."
        )
    if unknown_files:
        unknown_names = [os.path.basename(f) for f in unknown_files]
        content_parts.append(
            f"\n⚠️ **Unsupported file format**: {', '.join(unknown_names)}\n"
            "Supported formats: PDF, CSV, TXT"
        )
    return "\n\n".join(content_parts)


def process_history(history: list[dict]) -> list[dict]:
    """Convert Gradio conversation history to the Friendli API message format.

    Assistant turns pass through; user turns with file content are reduced
    to "[File: name]" placeholders since files were already analyzed inline.
    """
    messages = []
    for item in history:
        if item["role"] == "assistant":
            messages.append({
                "role": "assistant",
                "content": item["content"]
            })
        else:  # user
            content = item["content"]
            if isinstance(content, str):
                messages.append({
                    "role": "user",
                    "content": content
                })
            elif isinstance(content, list) and len(content) > 0:
                # File entries: keep only the base names as placeholders.
                file_info = []
                for file_path in content:
                    if isinstance(file_path, str):
                        file_info.append(f"[File: {os.path.basename(file_path)}]")
                if file_info:
                    messages.append({
                        "role": "user",
                        "content": " ".join(file_info)
                    })
    return messages


##############################################################################
# Streaming Response Handler
##############################################################################
def stream_friendli_response(messages: list[dict], max_tokens: int = 1000) -> Iterator[str]:
    """Stream a chat completion from the Friendli AI API.

    Yields the accumulated response text after each received delta, so the
    UI can re-render the growing message.  On error, yields a single
    warning string instead.
    """
    headers = {
        "Authorization": f"Bearer {FRIENDLI_TOKEN}",
        "Content-Type": "application/json"
    }
    payload = {
        "model": FRIENDLI_MODEL_ID,
        "messages": messages,
        "max_tokens": max_tokens,
        "top_p": 0.8,
        "temperature": 0.7,
        "stream": True,
        "stream_options": {
            "include_usage": True
        }
    }
    try:
        response = requests.post(
            FRIENDLI_API_URL,
            headers=headers,
            json=payload,
            stream=True,
            timeout=60
        )
        response.raise_for_status()

        full_response = ""
        for line in response.iter_lines():
            if line:
                line_text = line.decode('utf-8')
                # Server-sent events: payload lines are "data: {...}".
                if line_text.startswith("data: "):
                    data_str = line_text[6:]
                    if data_str == "[DONE]":
                        break
                    try:
                        data = json.loads(data_str)
                        if "choices" in data and len(data["choices"]) > 0:
                            delta = data["choices"][0].get("delta", {})
                            content = delta.get("content", "")
                            if content:
                                full_response += content
                                yield full_response
                    except json.JSONDecodeError:
                        logger.warning(f"JSON parsing failed: {data_str}")
                        continue
    except requests.exceptions.Timeout:
        yield "⚠️ Response timeout. Please try again."
    except requests.exceptions.RequestException as e:
        logger.error(f"Friendli API network error: {e}")
        yield f"⚠️ Network error occurred: {str(e)}"
    except Exception as e:
        logger.error(f"Friendli API error: {str(e)}")
        yield f"⚠️ API call error: {str(e)}"


##############################################################################
# Main Inference Function
##############################################################################
def run(
    message: dict,
    history: list[dict],
    max_new_tokens: int = 512,
    use_web_search: bool = False,
    use_korean: bool = False,
    system_prompt: str = "",
) -> Iterator[str]:
    """Gradio ChatInterface entry point: build messages and stream a reply."""
    try:
        # Prepare system message.
        messages = []
        if use_korean:
            combined_system_msg = "너는 AI 어시스턴트 역할이다. 한국어로 친절하고 정확하게 답변해라."
        else:
            combined_system_msg = "You are an AI assistant. Please respond helpfully and accurately in English."
        if system_prompt.strip():
            combined_system_msg += f"\n\n{system_prompt.strip()}"

        # Optional web search: inject results into the system prompt.
        if use_web_search:
            user_text = message.get("text", "")
            if user_text:
                ws_query = extract_keywords(user_text, top_k=5)
                if ws_query.strip():
                    logger.info(f"[Auto web search keywords] {ws_query!r}")
                    ws_result = do_web_search(ws_query, use_korean=use_korean)
                    # Fixed: the original only checked startswith("Web search"),
                    # which let "Network error..." and "No search results..."
                    # error messages leak into the prompt as search results.
                    if not ws_result.startswith(_WS_ERROR_PREFIXES):
                        combined_system_msg += f"\n\n[Search Results]\n{ws_result}"
                        if use_korean:
                            combined_system_msg += "\n\n[중요: 답변에 검색 결과의 출처를 반드시 인용하세요]"
                        else:
                            combined_system_msg += "\n\n[Important: Always cite sources from search results in your answer]"

        messages.append({
            "role": "system",
            "content": combined_system_msg
        })

        # Add conversation history.
        messages.extend(process_history(history))

        # Process current message (text + analyzed file contents).
        user_content = process_new_user_message(message)
        messages.append({
            "role": "user",
            "content": user_content
        })

        # Debug log
        logger.debug(f"Total messages: {len(messages)}")

        # Call Friendli API and stream the growing response.
        for response_text in stream_friendli_response(messages, max_new_tokens):
            yield response_text
    except Exception as e:
        logger.error(f"run function error: {str(e)}")
        yield f"⚠️ Sorry, an error occurred: {str(e)}"


##############################################################################
# Examples
##############################################################################
examples = [
    # PDF comparison example
    [
        {
            "text": "Compare the contents of the two PDF files.",
            "files": [
                "assets/additional-examples/before.pdf",
                "assets/additional-examples/after.pdf",
            ],
        }
    ],
    # CSV analysis example
    [
        {
            "text": "Summarize and analyze the contents of the CSV file.",
            "files": ["assets/additional-examples/sample-csv.csv"],
        }
    ],
    # Web search example
    [
        {
            "text": "Explain discord.gg/openfreeai",
            "files": [],
        }
    ],
    # Code generation example
    [
        {
            "text": "Write Python code to generate Fibonacci sequence.",
            "files": [],
        }
    ],
]
##############################################################################
# Gradio UI - CSS Styles (Removed blue colors)
##############################################################################
css = """
/* Full width UI */
.gradio-container {
    background: rgba(255, 255, 255, 0.95);
    padding: 30px 40px;
    margin: 20px auto;
    width: 100% !important;
    max-width: none !important;
    border-radius: 12px;
    box-shadow: 0 4px 6px rgba(0, 0, 0, 0.1);
}
.fillable {
    width: 100% !important;
    max-width: 100% !important;
}
/* Background */
body {
    background: linear-gradient(135deg, #f5f7fa 0%, #e0e0e0 100%);
    margin: 0;
    padding: 0;
    font-family: 'Segoe UI', 'Helvetica Neue', Arial, sans-serif;
    color: #333;
}
/* Button styles - neutral gray */
button, .btn {
    background: #6b7280 !important;
    border: none;
    color: white !important;
    padding: 10px 20px;
    text-transform: uppercase;
    font-weight: 600;
    letter-spacing: 0.5px;
    cursor: pointer;
    border-radius: 6px;
    transition: all 0.3s ease;
}
button:hover, .btn:hover {
    background: #4b5563 !important;
    transform: translateY(-1px);
    box-shadow: 0 2px 4px rgba(0, 0, 0, 0.2);
}
/* Examples section */
#examples_container, .examples-container {
    margin: 20px auto;
    width: 90%;
    background: rgba(255, 255, 255, 0.8);
    padding: 20px;
    border-radius: 8px;
}
#examples_row, .examples-row {
    justify-content: center;
}
/* Example buttons */
.gr-samples-table button, .gr-examples button, .examples button {
    background: #f0f2f5 !important;
    border: 1px solid #d1d5db;
    color: #374151 !important;
    margin: 5px;
    font-size: 14px;
}
.gr-samples-table button:hover, .gr-examples button:hover, .examples button:hover {
    background: #e5e7eb !important;
    border-color: #9ca3af;
}
/* Chat interface */
.chatbox, .chatbot {
    background: white !important;
    border-radius: 8px;
    box-shadow: 0 1px 3px rgba(0, 0, 0, 0.1);
}
.message {
    padding: 15px;
    margin: 10px 0;
    border-radius: 8px;
}
/* Input styles */
.multimodal-textbox, textarea, input[type="text"] {
    background: white !important;
    border: 1px solid #d1d5db;
    border-radius: 6px;
    padding: 10px;
    font-size: 16px;
}
.multimodal-textbox:focus, textarea:focus, input[type="text"]:focus {
    border-color: #6b7280;
    outline: none;
    box-shadow: 0 0 0 3px rgba(107, 114, 128, 0.1);
}
/* Warning messages */
.warning-box {
    background: #fef3c7 !important;
    border: 1px solid #f59e0b;
    border-radius: 8px;
    padding: 15px;
    margin: 10px 0;
    color: #92400e;
}
/* Headings */
h1, h2, h3 {
    color: #1f2937;
}
/* Links - neutral gray */
a {
    color: #6b7280;
    text-decoration: none;
}
a:hover {
    text-decoration: underline;
    color: #4b5563;
}
/* Slider */
.gr-slider {
    margin: 15px 0;
}
/* Checkbox */
input[type="checkbox"] {
    width: 18px;
    height: 18px;
    margin-right: 8px;
}
/* Scrollbar */
::-webkit-scrollbar {
    width: 8px;
    height: 8px;
}
::-webkit-scrollbar-track {
    background: #f1f1f1;
}
::-webkit-scrollbar-thumb {
    background: #888;
    border-radius: 4px;
}
::-webkit-scrollbar-thumb:hover {
    background: #555;
}
"""

##############################################################################
# Gradio UI Main
##############################################################################
with gr.Blocks(css=css, title="Gemma-3-R1984-27B Chatbot") as demo:
    # Header: title and community link.
    gr.Markdown("# 🤗 Gemma-3-R1984-27B Chatbot")
    gr.Markdown("Community: [https://discord.gg/openfreeai](https://discord.gg/openfreeai)")

    # Options row: web search toggle, Korean toggle, token budget slider.
    with gr.Row():
        with gr.Column(scale=2):
            web_search_checkbox = gr.Checkbox(
                label="🔍 Enable Deep Research (Web Search)",
                value=False,
                info="Check for questions requiring latest information",
            )
        with gr.Column(scale=1):
            korean_checkbox = gr.Checkbox(
                label="🇰🇷 한글 (Korean)",
                value=False,
                info="Check for Korean responses",
            )
        with gr.Column(scale=1):
            max_tokens_slider = gr.Slider(
                label="Max Tokens",
                minimum=100,
                maximum=8000,
                step=50,
                value=1000,
                info="Adjust response length",
            )

    # Main chat interface; additional_inputs order matches run()'s
    # (max_new_tokens, use_web_search, use_korean) parameters.
    attachment_types = [
        ".webp", ".png", ".jpg", ".jpeg", ".gif",
        ".mp4", ".csv", ".txt", ".pdf",
    ]
    chat = gr.ChatInterface(
        fn=run,
        type="messages",
        chatbot=gr.Chatbot(type="messages", scale=1),
        textbox=gr.MultimodalTextbox(
            file_types=attachment_types,
            file_count="multiple",
            autofocus=True,
            placeholder="Enter text or upload PDF, CSV, TXT files. (Images/videos not supported in this demo)",
        ),
        multimodal=True,
        additional_inputs=[
            max_tokens_slider,
            web_search_checkbox,
            korean_checkbox,
        ],
        stop_btn=False,
        examples=examples,
        run_examples_on_click=False,
        cache_examples=False,
        delete_cache=(1800, 1800),
    )

if __name__ == "__main__":
    demo.launch()