Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -1,212 +1,22 @@
|
|
1 |
#!/usr/bin/env python3
|
2 |
"""
|
3 |
Legal RAG Chatbot for Hugging Face Spaces
|
4 |
-
|
5 |
"""
|
6 |
|
7 |
import os
|
8 |
import gradio as gr
|
9 |
-
from
|
10 |
-
from openai import OpenAI
|
11 |
-
import json
|
12 |
|
13 |
-
|
14 |
-
|
15 |
-
|
16 |
-
|
17 |
-
|
18 |
-
|
19 |
-
print(f"📁 Database path: {milvus_db_path}")
|
20 |
-
print(f"📚 Collection: {collection_name}")
|
21 |
-
print(f"🤖 Model: {model_name}")
|
22 |
-
|
23 |
-
# 初始化Milvus客户端
|
24 |
-
self.milvus_client = MilvusClient(milvus_db_path)
|
25 |
-
self.collection_name = collection_name
|
26 |
-
|
27 |
-
# 检查集合是否存在
|
28 |
-
if not self.milvus_client.has_collection(collection_name=collection_name):
|
29 |
-
print(f"⚠️ Collection '{collection_name}' does not exist. Creating with sample data...")
|
30 |
-
# 初始化嵌入模型
|
31 |
-
self.embedding_fn = model.DefaultEmbeddingFunction()
|
32 |
-
vector_dim = self.embedding_fn.dim
|
33 |
-
|
34 |
-
# 创建新的集合
|
35 |
-
self.milvus_client.create_collection(
|
36 |
-
collection_name=collection_name,
|
37 |
-
dimension=vector_dim
|
38 |
-
)
|
39 |
-
self._add_sample_data()
|
40 |
-
else:
|
41 |
-
print(f"✅ Found existing collection '{collection_name}'")
|
42 |
-
|
43 |
-
# 初始化嵌入模型
|
44 |
-
self.embedding_fn = model.DefaultEmbeddingFunction()
|
45 |
-
|
46 |
-
# 初始化OpenAI客户端
|
47 |
-
if openai_base_url:
|
48 |
-
self.llm_client = OpenAI(api_key=openai_api_key, base_url=openai_base_url)
|
49 |
-
else:
|
50 |
-
self.llm_client = OpenAI(api_key=openai_api_key)
|
51 |
-
|
52 |
-
self.model_name = model_name
|
53 |
-
|
54 |
-
print("✅ Legal Chatbot initialized successfully!")
|
55 |
-
|
56 |
-
def search_legal_database(self, query, limit=5):
|
57 |
-
"""使用Milvus搜索法律数据库"""
|
58 |
-
if not query or query.strip() == "" or query.strip().lower() == "query":
|
59 |
-
return "无效的搜索查询。请提供具体的搜索内容。"
|
60 |
-
|
61 |
-
# 检查数据库中是否有数据
|
62 |
-
collection_stats = self.milvus_client.get_collection_stats(self.collection_name)
|
63 |
-
row_count = collection_stats.get("row_count", 0)
|
64 |
-
|
65 |
-
if row_count == 0:
|
66 |
-
print("⚠️ Collection is empty, adding sample data...")
|
67 |
-
self._add_sample_data()
|
68 |
-
|
69 |
-
# 生成查询向量
|
70 |
-
query_vector = self.embedding_fn.encode_queries([query])
|
71 |
-
|
72 |
-
# 执行搜索
|
73 |
-
search_results = self.milvus_client.search(
|
74 |
-
collection_name=self.collection_name,
|
75 |
-
data=query_vector,
|
76 |
-
limit=limit,
|
77 |
-
output_fields=["text", "page_num", "source"]
|
78 |
-
)
|
79 |
-
|
80 |
-
# 检查是否有结果
|
81 |
-
if not search_results or len(search_results[0]) == 0:
|
82 |
-
return "没有找到与此查询相关的结果。"
|
83 |
-
|
84 |
-
# 格式化搜索结果
|
85 |
-
formatted_results = []
|
86 |
-
for i, result in enumerate(search_results[0]):
|
87 |
-
similarity = 1 - result['distance']
|
88 |
-
source = result['entity'].get('source', 'Unknown source')
|
89 |
-
page_num = result['entity'].get('page_num', 'Unknown page')
|
90 |
-
text = result['entity'].get('text', '')
|
91 |
-
|
92 |
-
formatted_result = f"[结果 {i+1}] 来源: {source}, 页码: {page_num}, 相关度: {similarity:.4f}\n"
|
93 |
-
formatted_result += f"内容: {text}\n\n"
|
94 |
-
formatted_results.append(formatted_result)
|
95 |
-
|
96 |
-
return "\n".join(formatted_results)
|
97 |
-
|
98 |
-
def _add_sample_data(self):
|
99 |
-
"""添加示例法律文本数据到空集合中"""
|
100 |
-
docs = [
|
101 |
-
"Ontario Regulation 213/91 (Construction Projects) under the Occupational Health and Safety Act contains provisions for construction safety. Section 26 requires that every worker who may be exposed to the hazard of falling more than 3 metres shall use a fall protection system.",
|
102 |
-
"Under the Canada Labour Code, employers have a duty to ensure that the health and safety at work of every person employed by the employer is protected (Section 124). This includes providing proper training and supervision.",
|
103 |
-
"The Criminal Code of Canada Section 217.1 states that everyone who undertakes, or has the authority, to direct how another person does work or performs a task is under a legal duty to take reasonable steps to prevent bodily harm to that person, or any other person, arising from that work or task.",
|
104 |
-
"British Columbia's Workers Compensation Act requires employers to ensure the health and safety of all workers and comply with occupational health and safety regulations. This includes providing proper equipment, training, and supervision for construction activities.",
|
105 |
-
"Alberta's Occupational Health and Safety Code (Part 9) contains specific requirements for fall protection systems when workers are at heights of 3 metres or more, including the use of guardrails, safety nets, or personal fall arrest systems.",
|
106 |
-
"Employment Standards Act provides minimum standards for wages, hours of work, overtime pay, vacation time, and termination notice. Employers must provide at least two weeks written notice or pay in lieu for termination.",
|
107 |
-
"Canadian Charter of Rights and Freedoms Section 15 guarantees equality rights. Every individual is equal before and under the law and has the right to equal protection and benefit of the law without discrimination.",
|
108 |
-
"Personal Information Protection and Electronic Documents Act (PIPEDA) governs how private sector organizations collect, use and disclose personal information in the course of commercial activities.",
|
109 |
-
"Competition Act prohibits anti-competitive practices including conspiracies, bid-rigging, price maintenance, and abuse of dominant position. The Act is enforced by the Competition Bureau.",
|
110 |
-
"Consumer Protection Acts across Canada provide consumers with rights including cooling-off periods for certain contracts, protection from unfair business practices, and warranties for goods and services."
|
111 |
-
]
|
112 |
-
|
113 |
-
# 生成向量
|
114 |
-
vectors = self.embedding_fn.encode_documents(docs)
|
115 |
-
|
116 |
-
# 准备数据
|
117 |
-
data = []
|
118 |
-
for i in range(len(docs)):
|
119 |
-
source_name = f"Canadian Legal Code {i+1}"
|
120 |
-
data.append({
|
121 |
-
"id": i,
|
122 |
-
"vector": vectors[i],
|
123 |
-
"text": docs[i],
|
124 |
-
"page_num": 1,
|
125 |
-
"source": source_name
|
126 |
-
})
|
127 |
-
|
128 |
-
# 插入数据
|
129 |
-
self.milvus_client.insert(collection_name=self.collection_name, data=data)
|
130 |
-
print(f"✅ Added {len(data)} sample legal documents to collection")
|
131 |
-
|
132 |
-
def _analyze_query_need(self, user_message):
|
133 |
-
"""分析用户消息,判断是否需要搜索法律数据库"""
|
134 |
-
# 简化的搜索判断逻辑
|
135 |
-
legal_keywords = [
|
136 |
-
"law", "legal", "regulation", "act", "code", "section", "rights",
|
137 |
-
"protection", "safety", "employment", "contract", "liability",
|
138 |
-
"charter", "constitution", "statute", "criminal", "civil",
|
139 |
-
"搜索", "查询", "数据库", "法律", "法规", "条例"
|
140 |
-
]
|
141 |
-
|
142 |
-
user_message_lower = user_message.lower()
|
143 |
-
|
144 |
-
# 检查是否包含法律关键词
|
145 |
-
needs_search = any(keyword in user_message_lower for keyword in legal_keywords)
|
146 |
-
|
147 |
-
if needs_search:
|
148 |
-
# 清理查询内容
|
149 |
-
clean_query = user_message.strip()
|
150 |
-
return {
|
151 |
-
"needs_search": True,
|
152 |
-
"reasoning": "检测到法律相关查询",
|
153 |
-
"queries": [clean_query]
|
154 |
-
}
|
155 |
-
else:
|
156 |
-
return {
|
157 |
-
"needs_search": False,
|
158 |
-
"reasoning": "未检测到法律关键词",
|
159 |
-
"queries": []
|
160 |
-
}
|
161 |
-
|
162 |
-
def process_message_with_history(self, user_message, history, system_message):
|
163 |
-
"""
|
164 |
-
处理用户消息,兼容Gradio ChatInterface的历史格式
|
165 |
-
"""
|
166 |
-
# 构建完整的对话历史
|
167 |
-
messages = [{"role": "system", "content": system_message}]
|
168 |
-
|
169 |
-
# 添加历史对话
|
170 |
-
for user_msg, assistant_msg in history:
|
171 |
-
if user_msg:
|
172 |
-
messages.append({"role": "user", "content": user_msg})
|
173 |
-
if assistant_msg:
|
174 |
-
messages.append({"role": "assistant", "content": assistant_msg})
|
175 |
-
|
176 |
-
# 分析是否需要搜索
|
177 |
-
analysis = self._analyze_query_need(user_message)
|
178 |
-
|
179 |
-
search_results = ""
|
180 |
-
if analysis.get("needs_search", False) and analysis.get("queries"):
|
181 |
-
# 执行搜索
|
182 |
-
all_results = []
|
183 |
-
for query in analysis["queries"][:2]: # 最多执行2个查询
|
184 |
-
print(f"🔍 Searching for: {query}")
|
185 |
-
result = self.search_legal_database(query)
|
186 |
-
if result and result.strip():
|
187 |
-
all_results.append(f"查询: {query}\n{result}")
|
188 |
-
|
189 |
-
if all_results:
|
190 |
-
search_results = "\n\n" + "="*50 + "\n".join(all_results)
|
191 |
-
|
192 |
-
# 添加当前用户消息
|
193 |
-
if search_results:
|
194 |
-
enhanced_message = f"{user_message}\n\n以下是相关的法律搜索结果,请在回答中引用这些信息:\n{search_results}"
|
195 |
-
messages.append({"role": "user", "content": enhanced_message})
|
196 |
-
else:
|
197 |
-
messages.append({"role": "user", "content": user_message})
|
198 |
-
|
199 |
-
return messages
|
200 |
|
201 |
# 初始化聊天机器人
|
202 |
def initialize_chatbot():
|
203 |
-
# 从环境变量获取配置
|
204 |
-
MILVUS_DB_PATH = os.environ.get("MILVUS_DB_PATH", "./milvus_legal_codes.db")
|
205 |
-
COLLECTION_NAME = os.environ.get("COLLECTION_NAME", "legal_codes_collection")
|
206 |
-
OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY")
|
207 |
-
OPENAI_BASE_URL = os.environ.get("OPENAI_BASE_URL", "")
|
208 |
-
MODEL_NAME = os.environ.get("MODEL_NAME", "gpt-4o")
|
209 |
-
|
210 |
if not OPENAI_API_KEY:
|
211 |
raise ValueError("❌ OPENAI_API_KEY environment variable is required!")
|
212 |
|
@@ -225,9 +35,11 @@ def initialize_chatbot():
|
|
225 |
try:
|
226 |
chatbot = initialize_chatbot()
|
227 |
print("✅ Chatbot initialized successfully!")
|
|
|
228 |
except Exception as e:
|
229 |
print(f"❌ Failed to initialize chatbot: {e}")
|
230 |
chatbot = None
|
|
|
231 |
|
232 |
def respond(
|
233 |
message,
|
@@ -237,28 +49,34 @@ def respond(
|
|
237 |
temperature,
|
238 |
top_p,
|
239 |
):
|
|
|
|
|
|
|
|
|
240 |
if chatbot is None:
|
241 |
yield "❌ Chatbot not initialized. Please check the configuration."
|
242 |
return
|
243 |
|
244 |
try:
|
245 |
-
#
|
246 |
-
|
247 |
|
248 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
249 |
|
250 |
-
#
|
251 |
-
|
252 |
-
|
253 |
-
|
254 |
-
|
255 |
-
stream=True,
|
256 |
-
temperature=temperature,
|
257 |
-
top_p=top_p,
|
258 |
-
):
|
259 |
-
if hasattr(chunk.choices[0].delta, 'content') and chunk.choices[0].delta.content:
|
260 |
-
token = chunk.choices[0].delta.content
|
261 |
-
response += token
|
262 |
yield response
|
263 |
|
264 |
except Exception as e:
|
@@ -266,14 +84,22 @@ def respond(
|
|
266 |
yield f"抱歉,处理您的消息时出现错误:{str(e)}"
|
267 |
|
268 |
# 准备描述信息
|
269 |
-
base_description = "🤖 AI
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
270 |
|
271 |
-
|
272 |
-
status_info = f"\n\n✅ **Status**: Connected to database with {chatbot.collection_name} collection"
|
273 |
-
examples_info = "\n\n💡 **Try asking:**\n• What are the fall protection requirements in Ontario construction?\n• Tell me about employer duties under Canada Labour Code\n• What are my rights under the Charter of Rights and Freedoms?\n• Search for information about workplace safety regulations"
|
274 |
-
full_description = base_description + status_info + examples_info
|
275 |
-
else:
|
276 |
-
full_description = base_description + "\n\n❌ **Status**: Configuration error - please check environment variables"
|
277 |
|
278 |
# 创建Gradio ChatInterface
|
279 |
demo = gr.ChatInterface(
|
@@ -284,23 +110,39 @@ demo = gr.ChatInterface(
|
|
284 |
gr.Textbox(
|
285 |
value="You are a helpful legal assistant with expertise in Canadian law. You have access to a legal database and should provide accurate, well-sourced legal information. Always cite specific legal sources when possible. Remember to include appropriate disclaimers that this is for informational purposes only and not legal advice.",
|
286 |
label="System Message",
|
287 |
-
lines=3
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
288 |
),
|
289 |
-
gr.Slider(minimum=1, maximum=2048, value=1024, step=1, label="Max new tokens"),
|
290 |
-
gr.Slider(minimum=0.1, maximum=2.0, value=0.7, step=0.1, label="Temperature"),
|
291 |
gr.Slider(
|
292 |
minimum=0.1,
|
293 |
maximum=1.0,
|
294 |
value=0.95,
|
295 |
step=0.05,
|
296 |
-
label="Top-p (nucleus sampling)"
|
297 |
),
|
298 |
-
]
|
|
|
|
|
299 |
)
|
300 |
|
301 |
if __name__ == "__main__":
|
302 |
demo.launch(
|
303 |
server_name="0.0.0.0",
|
304 |
server_port=7860,
|
305 |
-
share=False
|
|
|
306 |
)
|
|
|
1 |
#!/usr/bin/env python3
|
2 |
"""
|
3 |
Legal RAG Chatbot for Hugging Face Spaces
|
4 |
+
直接使用原有的chatbot.py,只添加Gradio界面
|
5 |
"""
|
6 |
|
7 |
import os
|
8 |
import gradio as gr
|
9 |
+
from chatbot import LegalChatbot
|
|
|
|
|
10 |
|
11 |
+
# 配置信息 - 从环境变量获取
|
12 |
+
MILVUS_DB_PATH = os.environ.get("MILVUS_DB_PATH", "./milvus_legal_codes.db")
|
13 |
+
COLLECTION_NAME = os.environ.get("COLLECTION_NAME", "legal_codes_collection")
|
14 |
+
OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY")
|
15 |
+
OPENAI_BASE_URL = os.environ.get("OPENAI_BASE_URL", "")
|
16 |
+
MODEL_NAME = os.environ.get("MODEL_NAME", "gpt-4o")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
17 |
|
18 |
# 初始化聊天机器人
|
19 |
def initialize_chatbot():
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
20 |
if not OPENAI_API_KEY:
|
21 |
raise ValueError("❌ OPENAI_API_KEY environment variable is required!")
|
22 |
|
|
|
35 |
try:
|
36 |
chatbot = initialize_chatbot()
|
37 |
print("✅ Chatbot initialized successfully!")
|
38 |
+
chatbot_status = f"✅ **Status**: Connected to database with {chatbot.collection_name} collection"
|
39 |
except Exception as e:
|
40 |
print(f"❌ Failed to initialize chatbot: {e}")
|
41 |
chatbot = None
|
42 |
+
chatbot_status = f"❌ **Status**: Configuration error - {str(e)}"
|
43 |
|
44 |
def respond(
|
45 |
message,
|
|
|
49 |
temperature,
|
50 |
top_p,
|
51 |
):
|
52 |
+
"""
|
53 |
+
Gradio ChatInterface响应函数
|
54 |
+
使用原有chatbot的流式处理功能
|
55 |
+
"""
|
56 |
if chatbot is None:
|
57 |
yield "❌ Chatbot not initialized. Please check the configuration."
|
58 |
return
|
59 |
|
60 |
try:
|
61 |
+
# 重置聊天机器人的对话历史
|
62 |
+
chatbot.reset_conversation()
|
63 |
|
64 |
+
# 设置系统消息
|
65 |
+
if system_message.strip():
|
66 |
+
chatbot.conversation_history[0]["content"] = system_message
|
67 |
+
|
68 |
+
# 添加历史对话到聊天机器人
|
69 |
+
for user_msg, assistant_msg in history:
|
70 |
+
if user_msg:
|
71 |
+
chatbot.conversation_history.append({"role": "user", "content": user_msg})
|
72 |
+
if assistant_msg:
|
73 |
+
chatbot.conversation_history.append({"role": "assistant", "content": assistant_msg})
|
74 |
|
75 |
+
# 使用原有的流式处理功能
|
76 |
+
response = ""
|
77 |
+
for chunk in chatbot.process_message_stream(message):
|
78 |
+
if chunk:
|
79 |
+
response += chunk
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
80 |
yield response
|
81 |
|
82 |
except Exception as e:
|
|
|
84 |
yield f"抱歉,处理您的消息时出现错误:{str(e)}"
|
85 |
|
86 |
# 准备描述信息
|
87 |
+
base_description = """🤖 **AI法律助手** - 结合向量数据库搜索和大语言模型的智能法律咨询系统
|
88 |
+
|
89 |
+
🔍 **核心功能:**
|
90 |
+
- 智能查询分析 - 自动判断是否需要搜索法律数据库
|
91 |
+
- 向量相似度搜索 - 基于Milvus的高效法律文档检索
|
92 |
+
- RAG增强生成 - 结合搜索结果提供准确回答
|
93 |
+
- 实时流式回复 - 支持打字机效果的实时响应
|
94 |
+
|
95 |
+
💡 **试试这些问题:**
|
96 |
+
• "What are the fall protection requirements in Ontario construction?"
|
97 |
+
• "Tell me about employer duties under Canada Labour Code"
|
98 |
+
• "Search for information about workplace safety regulations"
|
99 |
+
• "What are my rights under the Charter of Rights and Freedoms?"
|
100 |
+
"""
|
101 |
|
102 |
+
full_description = base_description + f"\n\n{chatbot_status}"
|
|
|
|
|
|
|
|
|
|
|
103 |
|
104 |
# 创建Gradio ChatInterface
|
105 |
demo = gr.ChatInterface(
|
|
|
110 |
gr.Textbox(
|
111 |
value="You are a helpful legal assistant with expertise in Canadian law. You have access to a legal database and should provide accurate, well-sourced legal information. Always cite specific legal sources when possible. Remember to include appropriate disclaimers that this is for informational purposes only and not legal advice.",
|
112 |
label="System Message",
|
113 |
+
lines=3,
|
114 |
+
max_lines=5
|
115 |
+
),
|
116 |
+
gr.Slider(
|
117 |
+
minimum=1,
|
118 |
+
maximum=2048,
|
119 |
+
value=1024,
|
120 |
+
step=1,
|
121 |
+
label="Max new tokens"
|
122 |
+
),
|
123 |
+
gr.Slider(
|
124 |
+
minimum=0.1,
|
125 |
+
maximum=2.0,
|
126 |
+
value=0.7,
|
127 |
+
step=0.1,
|
128 |
+
label="Temperature"
|
129 |
),
|
|
|
|
|
130 |
gr.Slider(
|
131 |
minimum=0.1,
|
132 |
maximum=1.0,
|
133 |
value=0.95,
|
134 |
step=0.05,
|
135 |
+
label="Top-p (nucleus sampling)"
|
136 |
),
|
137 |
+
],
|
138 |
+
theme=gr.themes.Soft(),
|
139 |
+
analytics_enabled=False
|
140 |
)
|
141 |
|
142 |
if __name__ == "__main__":
|
143 |
demo.launch(
|
144 |
server_name="0.0.0.0",
|
145 |
server_port=7860,
|
146 |
+
share=False,
|
147 |
+
show_error=True
|
148 |
)
|