#!/usr/bin/env python
import os
import re
import tempfile
import gc
from collections.abc import Iterator
from threading import Thread
import json
import requests
import gradio as gr
import spaces
import torch
from loguru import logger
from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer

# CSV/TXT analysis
import pandas as pd

# PDF text extraction
import PyPDF2

##############################################################################
# Memory cleanup helper
##############################################################################
def clear_cuda_cache():
    """Explicitly empty the CUDA cache and run garbage collection."""
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        gc.collect()

##############################################################################
# SERPHouse API key from environment variable
##############################################################################
SERPHOUSE_API_KEY = os.getenv("SERPHOUSE_API_KEY", "")

##############################################################################
# Simple keyword extraction (preserves Korean, letters, digits, and whitespace)
##############################################################################
def extract_keywords(text: str, top_k: int = 5) -> str:
    """
    1) Keep only Korean (가-힣), English (a-zA-Z), digits (0-9), and whitespace
    2) Split into tokens on whitespace
    3) Return at most the first top_k tokens
    """
    text = re.sub(r"[^a-zA-Z0-9가-힣\s]", "", text)
    tokens = text.split()
    key_tokens = tokens[:top_k]
    return " ".join(key_tokens)
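# Minimal usage sketch for extract_keywords (the query below is a hypothetical
# example, not taken from the application). Note that punctuation such as the
# apostrophe is stripped before tokenizing:
#
#   >>> extract_keywords("What's the weather in Seoul today?", top_k=3)
#   'Whats the weather'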
##############################################################################
# SerpHouse Live endpoint call
##############################################################################
def do_web_search(query: str) -> str:
    """
    Return the top 20 'organic' result items (title, link, snippet, etc.)
    as a single Markdown-formatted string.
    """
    try:
        url = "https://api.serphouse.com/serp/live"
        params = {
            "q": query,
            "domain": "google.com",
            "serp_type": "web",
            "device": "desktop",
            "lang": "en",
            "num": "20"
        }
        headers = {
            "Authorization": f"Bearer {SERPHOUSE_API_KEY}"
        }

        logger.info(f"Calling SerpHouse API... query: {query}")
        response = requests.get(url, headers=headers, params=params, timeout=60)
        response.raise_for_status()

        data = response.json()

        # Handle several possible response structures
        results = data.get("results", {})
        organic = None

        if isinstance(results, dict) and "organic" in results:
            organic = results["organic"]
        elif isinstance(results, dict) and "results" in results:
            if isinstance(results["results"], dict) and "organic" in results["results"]:
                organic = results["results"]["organic"]
        elif "organic" in data:
            organic = data["organic"]

        if not organic:
            logger.warning("No organic results found in the response.")
            return "No web search results found or unexpected API response structure."

        # Limit the number of results and optimize context length
        max_results = min(20, len(organic))
        limited_organic = organic[:max_results]

        # Format the results as Markdown
        summary_lines = []
        for idx, item in enumerate(limited_organic, start=1):
            title = item.get("title", "No title")
            link = item.get("link", "#")
            snippet = item.get("snippet", "No description")
            displayed_link = item.get("displayed_link", link)

            summary_lines.append(
                f"### Result {idx}: {title}\n\n"
                f"{snippet}\n\n"
                f"**Source**: [{displayed_link}]({link})\n\n"
                f"---\n"
            )

        instructions = """
# Web Search Results
Below are the search results. Use this information when answering the question:
1. Refer to each result's title, content, and source link
2. Explicitly cite the sources of relevant information in your answer (e.g., "According to source X...")
3. Include the actual source links in your response
4. Synthesize information from multiple sources when answering
"""

        search_results = instructions + "\n".join(summary_lines)
        logger.info(f"Processed {len(limited_organic)} search results")
        return search_results

    except Exception as e:
        logger.error(f"Web search failed: {e}")
        return f"Web search failed: {str(e)}"

##############################################################################
# Model/tokenizer loading (text-only)
##############################################################################
MAX_CONTENT_CHARS = 2000
MAX_INPUT_LENGTH = 2096
model_id = os.getenv("MODEL_ID", "VIDraft/Gemma-3-R1984-1B")

# Load as a text-only model
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    device_map="auto",
    torch_dtype=torch.bfloat16,
    attn_implementation="eager"
)

##############################################################################
# CSV, TXT, PDF analysis functions
##############################################################################
def analyze_csv_file(path: str) -> str:
    """Convert a CSV file to a string; show only part of it if it is too long."""
    try:
        df = pd.read_csv(path)
        if df.shape[0] > 50 or df.shape[1] > 10:
            df = df.iloc[:50, :10]
        df_str = df.to_string()
        if len(df_str) > MAX_CONTENT_CHARS:
            df_str = df_str[:MAX_CONTENT_CHARS] + "\n...(truncated)..."
        return f"**[CSV File: {os.path.basename(path)}]**\n\n{df_str}"
    except Exception as e:
        return f"Failed to read CSV ({os.path.basename(path)}): {str(e)}"


def analyze_txt_file(path: str) -> str:
    """Read a TXT file in full; show only part of it if it is too long."""
    try:
        with open(path, "r", encoding="utf-8") as f:
            text = f.read()
        if len(text) > MAX_CONTENT_CHARS:
            text = text[:MAX_CONTENT_CHARS] + "\n...(truncated)..."
        return f"**[TXT File: {os.path.basename(path)}]**\n\n{text}"
    except Exception as e:
        return f"Failed to read TXT ({os.path.basename(path)}): {str(e)}"


def pdf_to_markdown(pdf_path: str) -> str:
    """Convert PDF text to Markdown, extracting text page by page."""
    text_chunks = []
    try:
        with open(pdf_path, "rb") as f:
            reader = PyPDF2.PdfReader(f)
            max_pages = min(5, len(reader.pages))
            for page_num in range(max_pages):
                page = reader.pages[page_num]
                page_text = page.extract_text() or ""
                page_text = page_text.strip()
                if page_text:
                    # Give each page an equal share of the overall character budget
                    if len(page_text) > MAX_CONTENT_CHARS // max_pages:
                        page_text = page_text[:MAX_CONTENT_CHARS // max_pages] + "...(truncated)"
                    text_chunks.append(f"## Page {page_num+1}\n\n{page_text}\n")
            if len(reader.pages) > max_pages:
                text_chunks.append(f"\n...(Showing {max_pages} of {len(reader.pages)} pages)...")
    except Exception as e:
        return f"Failed to read PDF ({os.path.basename(pdf_path)}): {str(e)}"

    full_text = "\n".join(text_chunks)
    if len(full_text) > MAX_CONTENT_CHARS:
        full_text = full_text[:MAX_CONTENT_CHARS] + "\n...(truncated)..."

    return f"**[PDF File: {os.path.basename(pdf_path)}]**\n\n{full_text}"
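# Illustrative sketch of the document analyzers above (the file paths are
# hypothetical). Each helper returns a Markdown-style string capped at
# MAX_CONTENT_CHARS, so attachments cannot blow up the prompt:
#
#   report = analyze_csv_file("data/sales.csv")    # at most 50 rows x 10 cols
#   notes  = analyze_txt_file("notes/memo.txt")    # raw text, truncated
#   paper  = pdf_to_markdown("papers/draft.pdf")   # first 5 pages as Markdown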
##############################################################################
# Document file check
##############################################################################
def is_document_file(file_path: str) -> bool:
    return (
        file_path.lower().endswith(".pdf")
        or file_path.lower().endswith(".csv")
        or file_path.lower().endswith(".txt")
    )

##############################################################################
# Message processing (text and document files only)
##############################################################################
def process_new_user_message(message: dict) -> str:
    """Combine the user message and any attached document files into one text."""
    content_parts = [message["text"]]

    if message.get("files"):
        csv_files = [f for f in message["files"] if f.lower().endswith(".csv")]
        txt_files = [f for f in message["files"] if f.lower().endswith(".txt")]
        pdf_files = [f for f in message["files"] if f.lower().endswith(".pdf")]

        for csv_path in csv_files:
            csv_analysis = analyze_csv_file(csv_path)
            content_parts.append(csv_analysis)

        for txt_path in txt_files:
            txt_analysis = analyze_txt_file(txt_path)
            content_parts.append(txt_analysis)

        for pdf_path in pdf_files:
            pdf_markdown = pdf_to_markdown(pdf_path)
            content_parts.append(pdf_markdown)

    return "\n\n".join(content_parts)

##############################################################################
# Conversation history processing
##############################################################################
def process_history(history: list[dict]) -> str:
    """Convert the conversation history into a flat text format."""
    conversation_text = ""

    for item in history:
        if item["role"] == "assistant":
            conversation_text += f"\nAssistant: {item['content']}\n"
        else:  # user
            content = item["content"]
            if isinstance(content, str):
                conversation_text += f"\nUser: {content}\n"
            elif isinstance(content, list) and len(content) > 0:
                # Show only the file path
                file_path = content[0]
                conversation_text += f"\nUser: [File: {os.path.basename(file_path)}]\n"

    return conversation_text

##############################################################################
# Model generation function
##############################################################################
def _model_gen_with_oom_catch(**kwargs):
    """Catch OutOfMemoryError when generate() runs in a separate thread."""
    try:
        model.generate(**kwargs)
    except torch.cuda.OutOfMemoryError:
        # Note: this surfaces in the worker thread only; the streamer in run()
        # will stop receiving tokens and eventually time out.
        raise RuntimeError(
            "[OutOfMemoryError] Not enough GPU memory. "
            "Reduce Max New Tokens or shorten the prompt."
        )
    finally:
        clear_cuda_cache()
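# Sketch of the flat prompt format that process_history produces (the history
# entries below are hypothetical):
#
#   history = [
#       {"role": "user", "content": "Hello"},
#       {"role": "assistant", "content": "Hi! How can I help?"},
#   ]
#   process_history(history)
#   # -> "\nUser: Hello\n\nAssistant: Hi! How can I help?\n"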
##############################################################################
# Main inference function (text-only)
##############################################################################
@spaces.GPU(duration=120)
def run(
    message: dict,
    history: list[dict],
    system_prompt: str = "",
    max_new_tokens: int = 512,
    use_web_search: bool = False,
    web_search_query: str = "",  # currently unused; keywords are auto-extracted
) -> Iterator[str]:
    try:
        # Build the full prompt
        full_prompt = ""

        # System prompt
        if system_prompt.strip():
            full_prompt += f"System: {system_prompt.strip()}\n\n"

        # Perform web search if requested
        if use_web_search:
            user_text = message["text"]
            ws_query = extract_keywords(user_text, top_k=5)
            if ws_query.strip():
                logger.info(f"[Auto WebSearch Keyword] {ws_query!r}")
                ws_result = do_web_search(ws_query)
                full_prompt += f"[Web Search Results]\n{ws_result}\n\n"
                full_prompt += "[Important: cite the sources from the search results above in your answer.]\n\n"

        # Conversation history
        if history:
            conversation_history = process_history(history)
            full_prompt += conversation_history

        # Current user message
        user_content = process_new_user_message(message)
        full_prompt += f"\nUser: {user_content}\nAssistant:"

        # Tokenize
        inputs = tokenizer(
            full_prompt,
            return_tensors="pt",
            truncation=True,
            max_length=MAX_INPUT_LENGTH
        ).to(device=model.device)

        # Set up streaming
        streamer = TextIteratorStreamer(
            tokenizer,
            timeout=30.0,
            skip_prompt=True,
            skip_special_tokens=True
        )

        gen_kwargs = dict(
            inputs,
            streamer=streamer,
            max_new_tokens=max_new_tokens,
            temperature=0.7,
            top_p=0.9,
            do_sample=True,
        )

        # Generate in a separate thread so tokens can be yielded as they arrive
        t = Thread(target=_model_gen_with_oom_catch, kwargs=gen_kwargs)
        t.start()

        # Stream the output
        output = ""
        for new_text in streamer:
            output += new_text
            yield output

    except Exception as e:
        logger.error(f"Error in run: {str(e)}")
        yield f"Sorry, an error occurred: {str(e)}"

    finally:
        # Memory cleanup
        try:
            del inputs
        except NameError:
            pass
        clear_cuda_cache()
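# Minimal driver sketch for the streaming generator above (the message dict is
# hypothetical; in the app, Gradio invokes run() itself). Each yielded value is
# the accumulated output so far, which grows as tokens stream in:
#
#   msg = {"text": "Summarize this document.", "files": []}
#   for partial in run(msg, history=[], max_new_tokens=64):
#       print(partial)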
##############################################################################
# Examples (text and document files only)
##############################################################################
examples = [
    [
        {
            "text": "Compare the contents of the two PDF files.",
            "files": [
                "assets/additional-examples/before.pdf",
                "assets/additional-examples/after.pdf",
            ],
        }
    ],
    [
        {
            "text": "Summarize and analyze the contents of the CSV file.",
            "files": ["assets/additional-examples/sample-csv.csv"],
        }
    ],
    [
        {
            "text": "What are the key findings from this research paper?",
            "files": ["assets/additional-examples/research.pdf"],
        }
    ],
    [
        {
            "text": "Analyze the data trends in this CSV file.",
            "files": ["assets/additional-examples/data.csv"],
        }
    ],
    [
        {
            "text": "Summarize the main points from this text document.",
            "files": ["assets/additional-examples/document.txt"],
        }
    ],
]

##############################################################################
# Gradio UI
##############################################################################
css = """
.gradio-container {
    background: rgba(255, 255, 255, 0.7);
    padding: 30px 40px;
    margin: 20px auto;
    width: 100% !important;
    max-width: none !important;
}
.fillable {
    width: 100% !important;
    max-width: 100% !important;
}
body {
    background: transparent;
    margin: 0;
    padding: 0;
    font-family: 'Helvetica Neue', Helvetica, Arial, sans-serif;
    color: #333;
}
button, .btn {
    background: transparent !important;
    border: 1px solid #ddd;
    color: #333;
    padding: 12px 24px;
    text-transform: uppercase;
    font-weight: bold;
    letter-spacing: 1px;
    cursor: pointer;
}
button:hover, .btn:hover {
    background: rgba(0, 0, 0, 0.05) !important;
}
"""

title_html = """
✅Agentic AI Platform ✅Reasoning ✅Text Analysis ✅Deep-Research & RAG
✅Document Processing (PDF, CSV, TXT) ✅Web Search Integration
Operates on an ✅'NVIDIA L40s / A100(ZeroGPU) GPU' as an independent local server
@Model Repository: VIDraft/Gemma-3-R1984-1B, @Based on 'Google Gemma-3-1b'