Spaces:
Sleeping
Sleeping
#!/usr/bin/env python3 | |
"""
Legal RAG Chatbot for Hugging Face Spaces.

Reuses the existing chatbot.py unchanged and only adds a Gradio interface.
"""
import os | |
import gradio as gr | |
from chatbot import LegalChatbot | |
# Runtime configuration, read from environment variables with local defaults.
MILVUS_DB_PATH = os.getenv("MILVUS_DB_PATH", "./milvus_legal_codes.db")
COLLECTION_NAME = os.getenv("COLLECTION_NAME", "legal_codes_collection")
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")  # no default: None when unset
OPENAI_BASE_URL = os.getenv("OPENAI_BASE_URL", "")  # empty string when unset
MODEL_NAME = os.getenv("MODEL_NAME", "gpt-4o")
# Chatbot construction
def initialize_chatbot():
    """Build a ``LegalChatbot`` from the module-level configuration.

    Returns:
        The ``LegalChatbot`` created with the configured Milvus database
        path, collection name, OpenAI API key, base URL and model name.

    Raises:
        ValueError: If ``OPENAI_API_KEY`` is unset or empty.
    """
    if not OPENAI_API_KEY:
        raise ValueError("❌ OPENAI_API_KEY environment variable is required!")
    # Do not echo any part of the secret: this output ends up in the logs.
    print(f"🔑 Using API key: [redacted] ({len(OPENAI_API_KEY)} chars)")
    print(f"🌐 Base URL: {OPENAI_BASE_URL or 'Default OpenAI'}")
    return LegalChatbot(
        milvus_db_path=MILVUS_DB_PATH,
        collection_name=COLLECTION_NAME,
        openai_api_key=OPENAI_API_KEY,
        # An empty base URL is forwarded as None.
        openai_base_url=OPENAI_BASE_URL or None,
        model_name=MODEL_NAME,
    )
# Module-level chatbot instance, created once at import time. If anything in
# the setup fails the module still loads: `chatbot` is None and the error text
# is kept in `chatbot_status`.
try:
    chatbot = initialize_chatbot()
    print("✅ Chatbot initialized successfully!")
    chatbot_status = (
        "✅ **Status**: Connected to database with "
        f"{chatbot.collection_name} collection"
    )
except Exception as init_error:
    print(f"❌ Failed to initialize chatbot: {init_error}")
    chatbot = None
    chatbot_status = f"❌ **Status**: Configuration error - {init_error!s}"
def _history_to_messages(history):
    """Convert Gradio chat history into a list of role/content dicts."""
    messages = []
    for entry in history or []:
        if isinstance(entry, dict):
            # OpenAI-style entry: {"role": ..., "content": ...}. Only the
            # role and content are forwarded; any other keys are dropped.
            role = entry.get("role")
            content = entry.get("content")
            if role in ("user", "assistant") and content:
                messages.append({"role": role, "content": content})
            continue
        # Pair-style entry: (user_message, assistant_message).
        user_msg, assistant_msg = entry
        if user_msg:
            messages.append({"role": "user", "content": user_msg})
        if assistant_msg:
            messages.append({"role": "assistant", "content": assistant_msg})
    return messages


def respond(
    message,
    history: list[tuple[str, str]],
    system_message,
    max_tokens,
    temperature,
    top_p,
):
    """Stream a reply for the Gradio ChatInterface using the shared chatbot.

    The chatbot's conversation is reset on every call and rebuilt from
    ``history`` before ``message`` is processed.

    Args:
        message: The new user message.
        history: Previous turns, either as ``(user, assistant)`` pairs or as
            ``{"role": ..., "content": ...}`` dicts.
        system_message: Replacement system prompt; ignored when empty,
            whitespace-only or None.
        max_tokens: Accepted from the UI but not used by this function.
        temperature: Accepted from the UI but not used by this function.
        top_p: Accepted from the UI but not used by this function.

    Yields:
        The accumulated response text after each non-empty streamed chunk,
        or a single error message if the chatbot is unavailable or
        processing fails.
    """
    if chatbot is None:
        yield "❌ Chatbot not initialized. Please check the configuration."
        return
    try:
        # Start from a clean conversation; the UI history is the source of truth.
        chatbot.reset_conversation()
        # Override the system prompt.
        # NOTE(review): assumes reset_conversation() leaves the system message
        # at index 0 - confirm against chatbot.py.
        if system_message and system_message.strip():
            chatbot.conversation_history[0]["content"] = system_message
        # Replay the previous turns into the chatbot.
        chatbot.conversation_history.extend(_history_to_messages(history))
        # Stream the answer, yielding the text accumulated so far.
        response = ""
        for chunk in chatbot.process_message_stream(message):
            if chunk:
                response += chunk
                yield response
    except Exception as e:
        print(f"❌ Error in respond: {e}")
        yield f"抱歉,处理您的消息时出现错误:{str(e)}"
# Description text passed to the chat interface.
base_description = """🤖 **AI法律助手** - 结合向量数据库搜索和大语言模型的智能法律咨询系统
🔍 **核心功能:**
- 智能查询分析 - 自动判断是否需要搜索法律数据库
- 向量相似度搜索 - 基于Milvus的高效法律文档检索
- RAG增强生成 - 结合搜索结果提供准确回答
- 实时流式回复 - 支持打字机效果的实时响应
💡 **试试这些问题:**
• "What are the fall protection requirements in Ontario construction?"
• "Tell me about employer duties under Canada Labour Code"
• "Search for information about workplace safety regulations"
• "What are my rights under the Charter of Rights and Freedoms?"
"""
# Append the startup status, separated from the description by a blank line.
full_description = f"{base_description}\n\n{chatbot_status}"
# Build the Gradio ChatInterface. The four extra inputs are handed to
# `respond` after (message, history), in the order listed below.
system_prompt_box = gr.Textbox(
    value=(
        "You are a helpful legal assistant with expertise in Canadian law. "
        "You have access to a legal database and should provide accurate, "
        "well-sourced legal information. "
        "Always cite specific legal sources when possible. "
        "Remember to include appropriate disclaimers that this is for "
        "informational purposes only and not legal advice."
    ),
    label="System Message",
    lines=3,
    max_lines=5,
)
max_tokens_slider = gr.Slider(
    minimum=1, maximum=2048, value=1024, step=1, label="Max new tokens"
)
temperature_slider = gr.Slider(
    minimum=0.1, maximum=2.0, value=0.7, step=0.1, label="Temperature"
)
top_p_slider = gr.Slider(
    minimum=0.1,
    maximum=1.0,
    value=0.95,
    step=0.05,
    label="Top-p (nucleus sampling)",
)

demo = gr.ChatInterface(
    fn=respond,
    title="⚖️ Legal RAG Assistant",
    description=full_description,
    additional_inputs=[
        system_prompt_box,
        max_tokens_slider,
        temperature_slider,
        top_p_slider,
    ],
    theme=gr.themes.Soft(),
    analytics_enabled=False,
)
if __name__ == "__main__":
    # Serve on all interfaces, port 7860, without a public share link.
    launch_options = {
        "server_name": "0.0.0.0",
        "server_port": 7860,
        "share": False,
        "show_error": True,
    }
    demo.launch(**launch_options)