Spaces:
Sleeping
Sleeping
File size: 4,929 Bytes
ae3d800 37ddc4e ae3d800 016bff1 37ddc4e ae3d800 016bff1 ae3d800 016bff1 ae3d800 016bff1 ae3d800 016bff1 37ddc4e 016bff1 ae3d800 016bff1 ae3d800 016bff1 ae3d800 016bff1 ae3d800 37ddc4e a08cd01 016bff1 a08cd01 016bff1 a08cd01 ae3d800 37ddc4e 5eecf0a ae3d800 a08cd01 37ddc4e ae3d800 016bff1 ae3d800 37ddc4e 016bff1 37ddc4e 016bff1 37ddc4e ae3d800 016bff1 ae3d800 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 |
#!/usr/bin/env python3
"""
Legal RAG Chatbot for Hugging Face Spaces
直接使用原有的chatbot.py,只添加Gradio界面
"""
import os
import gradio as gr
from chatbot import LegalChatbot
# Runtime configuration, read once from environment variables at import time.
MILVUS_DB_PATH = os.getenv("MILVUS_DB_PATH", "./milvus_legal_codes.db")
COLLECTION_NAME = os.getenv("COLLECTION_NAME", "legal_codes_collection")
# No default for the API key: its absence is reported by initialize_chatbot().
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
# Empty string means no custom endpoint was configured.
OPENAI_BASE_URL = os.getenv("OPENAI_BASE_URL", "")
MODEL_NAME = os.getenv("MODEL_NAME", "gpt-4o")
# Build the chatbot from the module-level configuration constants.
def initialize_chatbot():
    """Create a ``LegalChatbot`` from the module-level configuration.

    Returns:
        The ``LegalChatbot`` constructed with the configured Milvus database
        path, collection name, OpenAI API key, base URL, and model name.

    Raises:
        ValueError: If ``OPENAI_API_KEY`` is unset, empty, or only whitespace.
    """
    if not OPENAI_API_KEY or not OPENAI_API_KEY.strip():
        raise ValueError("❌ OPENAI_API_KEY environment variable is required!")

    # Never echo any part of the secret: start-up output ends up in logs.
    print("🔑 Using API key: [redacted]")
    print(f"🌐 Base URL: {OPENAI_BASE_URL or 'Default OpenAI'}")

    return LegalChatbot(
        milvus_db_path=MILVUS_DB_PATH,
        collection_name=COLLECTION_NAME,
        openai_api_key=OPENAI_API_KEY,
        # Pass None rather than an empty string when no base URL is configured.
        openai_base_url=OPENAI_BASE_URL or None,
        model_name=MODEL_NAME,
    )
# Module-level chatbot instance read by respond(), plus a status line for the UI.
def _bootstrap_chatbot():
    """Try to create the chatbot; return ``(instance_or_None, status_text)``.

    Any exception raised while initializing (or while building the success
    status) is printed and turned into a ``None`` instance with an error
    status, so a bad configuration does not abort module import here.
    """
    try:
        chatbot = initialize_chatbot()
        print("✅ Chatbot initialized successfully!")
        return chatbot, f"✅ **Status**: Connected to database with {chatbot.collection_name} collection"
    except Exception as e:
        print(f"❌ Failed to initialize chatbot: {e}")
        return None, f"❌ **Status**: Configuration error - {str(e)}"


chatbot, chatbot_status = _bootstrap_chatbot()
def _iter_history_messages(history):
    """Yield ``{"role", "content"}`` dicts for the earlier turns in ``history``.

    Two history layouts are accepted:

    * pairs ``(user_msg, assistant_msg)`` -- each non-empty side becomes one
      message, exactly as before;
    * message dicts with ``"role"`` and ``"content"`` keys -- only ``user`` /
      ``assistant`` entries with non-empty text are forwarded.
    """
    for turn in history or ():
        if isinstance(turn, dict):
            role = turn.get("role")
            content = turn.get("content")
            # NOTE(review): content may also arrive as a list of
            # {"type": "text", "text": ...} parts in newer Gradio releases --
            # confirm against the installed version. Non-text parts are skipped.
            if isinstance(content, (list, tuple)):
                content = "".join(
                    part.get("text") or ""
                    for part in content
                    if isinstance(part, dict) and part.get("type") == "text"
                )
            if role in ("user", "assistant") and isinstance(content, str) and content:
                yield {"role": role, "content": content}
            continue

        user_msg, assistant_msg = turn
        if user_msg:
            yield {"role": "user", "content": user_msg}
        if assistant_msg:
            yield {"role": "assistant", "content": assistant_msg}


def respond(
    message,
    history: list[tuple[str, str]],
    system_message,
    max_tokens,
    temperature,
    top_p,
):
    """Stream the chatbot's reply to ``message`` for Gradio's ChatInterface.

    The shared module-level ``chatbot`` is reset on every call and its
    conversation is rebuilt from ``history``, so the UI's history is the
    single source of truth for context.

    Args:
        message: The new user message.
        history: Earlier turns, either as ``(user, assistant)`` pairs or as
            ``{"role": ..., "content": ...}`` message dicts.
        system_message: Optional replacement for the first conversation entry;
            ignored when ``None`` or blank.
        max_tokens: Accepted to match the UI controls; not used by this function.
        temperature: Accepted to match the UI controls; not used by this function.
        top_p: Accepted to match the UI controls; not used by this function.

    Yields:
        The accumulated response text after each non-empty streamed chunk, or
        a single error string if the chatbot is unavailable or processing fails.
    """
    if chatbot is None:
        yield "❌ Chatbot not initialized. Please check the configuration."
        return

    try:
        # Start from a clean conversation for every request.
        chatbot.reset_conversation()

        # Override the system prompt; assumes reset_conversation() leaves the
        # system entry at index 0 -- TODO confirm in chatbot.py.
        if system_message and system_message.strip():
            chatbot.conversation_history[0]["content"] = system_message

        # Replay the earlier turns into the chatbot.
        chatbot.conversation_history.extend(_iter_history_messages(history))

        # Stream the answer, yielding the text accumulated so far each time.
        response = ""
        for chunk in chatbot.process_message_stream(message):
            if chunk:
                response += chunk
                yield response
    except Exception as e:
        # UI boundary: log and show the failure instead of crashing the stream.
        print(f"❌ Error in respond: {e}")
        yield f"抱歉,处理您的消息时出现错误:{str(e)}"
# Description text for the chat UI (user-facing copy is intentionally kept as written).
base_description = """🤖 **AI法律助手** - 结合向量数据库搜索和大语言模型的智能法律咨询系统
🔍 **核心功能:**
- 智能查询分析 - 自动判断是否需要搜索法律数据库
- 向量相似度搜索 - 基于Milvus的高效法律文档检索
- RAG增强生成 - 结合搜索结果提供准确回答
- 实时流式回复 - 支持打字机效果的实时响应
💡 **试试这些问题:**
• "What are the fall protection requirements in Ontario construction?"
• "Tell me about employer duties under Canada Labour Code"
• "Search for information about workplace safety regulations"
• "What are my rights under the Charter of Rights and Freedoms?"
"""
# Append the start-up status line after a blank line.
full_description = "".join((base_description, f"\n\n{chatbot_status}"))
# Build the Gradio ChatInterface around respond().
_DEFAULT_SYSTEM_MESSAGE = "You are a helpful legal assistant with expertise in Canadian law. You have access to a legal database and should provide accurate, well-sourced legal information. Always cite specific legal sources when possible. Remember to include appropriate disclaimers that this is for informational purposes only and not legal advice."

# Extra controls, listed in the same order as respond()'s parameters after
# (message, history): system_message, max_tokens, temperature, top_p.
_extra_controls = [
    gr.Textbox(
        value=_DEFAULT_SYSTEM_MESSAGE,
        label="System Message",
        lines=3,
        max_lines=5,
    ),
    gr.Slider(minimum=1, maximum=2048, value=1024, step=1, label="Max new tokens"),
    gr.Slider(minimum=0.1, maximum=2.0, value=0.7, step=0.1, label="Temperature"),
    gr.Slider(
        minimum=0.1,
        maximum=1.0,
        value=0.95,
        step=0.05,
        label="Top-p (nucleus sampling)",
    ),
]

demo = gr.ChatInterface(
    fn=respond,
    title="⚖️ Legal RAG Assistant",
    description=full_description,
    additional_inputs=_extra_controls,
    theme=gr.themes.Soft(),
    analytics_enabled=False,
)
if __name__ == "__main__":
    # Serve on all interfaces at port 7860, without a public share link.
    launch_options = {
        "server_name": "0.0.0.0",
        "server_port": 7860,
        "share": False,
        "show_error": True,
    }
    demo.launch(**launch_options)