StevenChen16 commited on
Commit
8cc98f8
·
verified ·
1 Parent(s): 12e0fa1

Update chatbot.py

Browse files
Files changed (1) hide show
  1. chatbot.py +436 -407
chatbot.py CHANGED
@@ -1,408 +1,437 @@
1
- import re
2
- import json
3
- from pymilvus import MilvusClient, model
4
- from openai import OpenAI
5
- import time
6
-
7
- class LegalChatbot:
8
- def __init__(self, milvus_db_path, collection_name, openai_api_key, openai_base_url=None, model_name="deepseek-reasoner"):
9
- """
10
- 初始化法律RAG聊天机器人
11
-
12
- 参数:
13
- milvus_db_path: Milvus数据库路径
14
- collection_name: 要搜索的集合名称
15
- openai_api_key: OpenAI API密钥
16
- openai_base_url: 可选的API基础URL(对于使用DeepSeek等服务)
17
- model_name: 要使用的LLM模型名称
18
- """
19
- # 初始化Milvus客户端
20
- self.milvus_client = MilvusClient(milvus_db_path)
21
- self.collection_name = collection_name
22
-
23
- # 检查集合是否存在,如果不存在则创建它
24
- if not self.milvus_client.has_collection(collection_name=collection_name):
25
- print(f"Collection '{collection_name}' does not exist. Creating it...")
26
- # 初始化嵌入模型
27
- self.embedding_fn = model.DefaultEmbeddingFunction()
28
- vector_dim = self.embedding_fn.dim
29
-
30
- # 创建新的集合
31
- self.milvus_client.create_collection(
32
- collection_name=collection_name,
33
- dimension=vector_dim
34
- )
35
- print(f"Collection '{collection_name}' created successfully.")
36
-
37
- # 初始化嵌入模型
38
- self.embedding_fn = model.DefaultEmbeddingFunction()
39
-
40
- # 初始化OpenAI客户端
41
- if openai_base_url:
42
- self.llm_client = OpenAI(api_key=openai_api_key, base_url=openai_base_url)
43
- else:
44
- self.llm_client = OpenAI(api_key=openai_api_key)
45
-
46
- self.model_name = model_name
47
- self.conversation_history = [
48
- {"role": "system", "content": """You are a helpful paralegal assistant with expertise in Canadian and U.S. law.
49
-
50
- You will help users with their legal questions. When answering, you should be helpful, accurate, and cite specific legal sources when possible.
51
-
52
- Users are members of the general public and may ask questions in Chinese or English. Please respond in the same language as the user's question.
53
- """}
54
- ]
55
-
56
- def search_legal_database(self, query, limit=5):
57
- """
58
- 使用Milvus搜索法律数据库
59
-
60
- 参数:
61
- query: 搜索查询
62
- limit: 返回结果的数量
63
-
64
- 返回:
65
- 格式化的搜索结果字符串
66
- """
67
- if not query or query.strip() == "" or query.strip().lower() == "query":
68
- return "无效的搜索查询。请提供具体的搜索内容。"
69
-
70
- # 检查数据库中是否有数据
71
- collection_stats = self.milvus_client.get_collection_stats(self.collection_name)
72
- row_count = collection_stats.get("row_count", 0)
73
-
74
- if row_count == 0:
75
- # 如果集合为空,添加一些示例数据
76
- print("集合为空,添加示例数据...")
77
- self._add_sample_data()
78
-
79
- # 生成查询向量
80
- query_vector = self.embedding_fn.encode_queries([query])
81
-
82
- # 执行搜索
83
- search_results = self.milvus_client.search(
84
- collection_name=self.collection_name,
85
- data=query_vector,
86
- limit=limit,
87
- output_fields=["text", "page_num", "source"]
88
- )
89
-
90
- # 检查是否有结果
91
- if not search_results or len(search_results[0]) == 0:
92
- return "没有找到与此查询相关的结果。"
93
-
94
- # 格式化搜索结果
95
- formatted_results = []
96
- for i, result in enumerate(search_results[0]):
97
- similarity = 1 - result['distance']
98
- source = result['entity'].get('source', 'Unknown source')
99
- page_num = result['entity'].get('page_num', 'Unknown page')
100
- text = result['entity'].get('text', '')
101
-
102
- formatted_result = f"[结果 {i+1}] 来源: {source}, 页码: {page_num}, 相关度: {similarity:.4f}\n"
103
- formatted_result += f"内容: {text}\n\n"
104
- formatted_results.append(formatted_result)
105
-
106
- return "\n".join(formatted_results)
107
-
108
- def _add_sample_data(self):
109
- """添加示例法律文本数据到空集合中"""
110
- # 简单的法律文本示例
111
- docs = [
112
- "Ontario Regulation 213/91 (Construction Projects) under the Occupational Health and Safety Act contains provisions for construction safety. Section 26 requires that every worker who may be exposed to the hazard of falling more than 3 metres shall use a fall protection system.",
113
- "Under the Canada Labour Code, employers have a duty to ensure that the health and safety at work of every person employed by the employer is protected (Section 124). This includes providing proper training and supervision.",
114
- "The Criminal Code of Canada Section 217.1 states that everyone who undertakes, or has the authority, to direct how another person does work or performs a task is under a legal duty to take reasonable steps to prevent bodily harm to that person, or any other person, arising from that work or task.",
115
- "British Columbia's Workers Compensation Act requires employers to ensure the health and safety of all workers and comply with occupational health and safety regulations. This includes providing proper equipment, training, and supervision for construction activities.",
116
- "Alberta's Occupational Health and Safety Code (Part 9) contains specific requirements for fall protection systems when workers are at heights of 3 metres or more, including the use of guardrails, safety nets, or personal fall arrest systems."
117
- ]
118
-
119
- # 生成向量
120
- vectors = self.embedding_fn.encode_documents(docs)
121
-
122
- # 准备数据
123
- data = []
124
- for i in range(len(docs)):
125
- source_name = f"Sample Legal Text {i+1}"
126
- data.append({
127
- "id": i,
128
- "vector": vectors[i],
129
- "text": docs[i],
130
- "page_num": 1,
131
- "source": source_name
132
- })
133
-
134
- # 插入数据
135
- self.milvus_client.insert(collection_name=self.collection_name, data=data)
136
- print(f"已添加 {len(data)} 条示例数据到集合")
137
-
138
- def _analyze_query_need(self, user_message):
139
- """
140
- 分析用户消息,判断是否需要搜索法律数据库
141
-
142
- 参数:
143
- user_message: 用户的消息
144
-
145
- 返回:
146
- dict: {"needs_search": bool, "queries": list}
147
- """
148
- # 预处理:检查用户是否明确要求搜索
149
- search_keywords = [
150
- "search in database", "search in the database", "search database",
151
- "look up", "find in database", "search for", "after searching",
152
- "查询数据库", "搜索数据库", "数据库查找"
153
- ]
154
-
155
- user_message_lower = user_message.lower()
156
- explicit_search_request = any(keyword in user_message_lower for keyword in search_keywords)
157
-
158
- if explicit_search_request:
159
- print("检测到用户明确要求搜索数据库")
160
- # 清理查询内容,去掉搜索相关的指令(大小写不敏感)
161
- clean_query = user_message
162
-
163
- # 所有需要去掉的短语
164
- all_phrases_to_remove = search_keywords + [
165
- "Answer me after searching in the database", "answer me after",
166
- "please search", "search and tell me", "look up and answer",
167
- "tell me", "what is", "what are", "explain"
168
- ]
169
-
170
- for phrase in all_phrases_to_remove:
171
- # 大小写不敏感的替换
172
- import re
173
- pattern = re.compile(re.escape(phrase), re.IGNORECASE)
174
- clean_query = pattern.sub("", clean_query)
175
-
176
- clean_query = clean_query.strip(".,?! ")
177
-
178
- if not clean_query or len(clean_query) < 3:
179
- clean_query = "legal information"
180
-
181
- return {
182
- "needs_search": True,
183
- "reasoning": "用户明确要求搜索数据库",
184
- "queries": [clean_query]
185
- }
186
- analysis_prompt = [
187
- {"role": "system", "content": """You are an AI assistant that analyzes user questions to determine if they need legal database searches.
188
-
189
- Your task is to analyze the user's question and determine:
190
- 1. Whether this question requires searching a legal database
191
- 2. If yes, what specific search queries would be most helpful
192
-
193
- Respond in JSON format:
194
- {
195
- "needs_search": true/false,
196
- "reasoning": "brief explanation of why search is or isn't needed",
197
- "queries": ["query1", "query2"] // only if needs_search is true
198
- }
199
-
200
- IMPORTANT RULES:
201
- 1. If the user explicitly requests database search (phrases like "search in database", "look up", "find in database"), always set needs_search to true
202
- 2. For ANY legal topic question, default to needs_search = true unless it's clearly a simple greeting or completely non-legal
203
- 3. Legal topics include: laws, regulations, legal procedures, legal documents, legal concepts, legal rights, etc.
204
-
205
- Search should be needed for:
206
- - ANY legal question (wills, trusts, contracts, rights, procedures, etc.)
207
- - Questions about specific laws, regulations, or legal codes
208
- - Requests for legal precedents or case law
209
- - Questions about legal procedures or requirements
210
- - Legal document comparisons (like will vs trust)
211
- - When user explicitly asks to search database
212
-
213
- Search should NOT be needed ONLY for:
214
- - Simple greetings ("hello", "how are you")
215
- - Completely non-legal topics (weather, sports, etc.)
216
- - Technical issues with the system itself
217
- """}
218
- ]
219
-
220
- # 添加当前对话历史的最后几条消息作为上下文
221
- context_messages = self.conversation_history[-3:] if len(self.conversation_history) > 3 else self.conversation_history[1:]
222
- for msg in context_messages:
223
- analysis_prompt.append(msg)
224
-
225
- analysis_prompt.append({"role": "user", "content": f"Analyze this question: {user_message}"})
226
-
227
- response = self.llm_client.chat.completions.create(
228
- model=self.model_name,
229
- messages=analysis_prompt,
230
- stream=False,
231
- temperature=0.1
232
- )
233
-
234
- response_content = response.choices[0].message.content.strip()
235
- print(f"LLM原始响应: {response_content}")
236
-
237
- # 尝试提取JSON内容(如果响应包含其他文本)
238
- import re
239
- json_match = re.search(r'\{.*\}', response_content, re.DOTALL)
240
- if json_match:
241
- json_content = json_match.group(0)
242
- else:
243
- json_content = response_content
244
-
245
- analysis_result = json.loads(json_content)
246
- print(f"查询分析结果: {analysis_result}")
247
- return analysis_result
248
-
249
- def process_message(self, user_message):
250
- """
251
- 处理用户消息并生成响应(两阶段模式)
252
-
253
- 参数:
254
- user_message: 用户的消息
255
-
256
- 返回:
257
- 助手的响应
258
- """
259
- # 将用户消息添加到对话历史
260
- self.conversation_history.append({"role": "user", "content": user_message})
261
-
262
- # 第一阶段:分析是否需要搜索
263
- analysis = self._analyze_query_need(user_message)
264
-
265
- search_results = ""
266
- if analysis.get("needs_search", False) and analysis.get("queries"):
267
- # 第二阶段:执行搜索
268
- all_results = []
269
- for query in analysis["queries"][:2]: # 最多执行2个查询
270
- print(f"执行搜索查询: {query}")
271
- result = self.search_legal_database(query)
272
- if result and result.strip():
273
- all_results.append(f"查询: {query}\n{result}")
274
-
275
- if all_results:
276
- search_results = "\n\n" + "="*50 + "\n".join(all_results)
277
-
278
- # 第三阶段:基于搜索结果生成回答
279
- final_prompt = self.conversation_history.copy()
280
-
281
- if search_results:
282
- final_prompt.append({
283
- "role": "system",
284
- "content": f"以下是相关的法律搜索结果,请在回答中引用这些信息:\n{search_results}\n\n请基于这些搜索结果回答用户的问题,并引用具体的来源和页码。"
285
- })
286
- response = self.llm_client.chat.completions.create(
287
- model=self.model_name,
288
- messages=final_prompt,
289
- stream=False
290
- )
291
-
292
- assistant_response = response.choices[0].message.content
293
-
294
- # 将最终响应添加到对话历史
295
- self.conversation_history.append({"role": "assistant", "content": assistant_response})
296
-
297
- return assistant_response
298
-
299
- def process_message_stream(self, user_message):
300
- """
301
- 处理用户消息并以流式方式返回响应,支持智能RAG查询
302
-
303
- Args:
304
- user_message (str): 用户输入的消息
305
-
306
- Yields:
307
- str: 响应文本的片段
308
- """
309
- # 添加用户消息到对话历史
310
- self.conversation_history.append({"role": "user", "content": user_message})
311
-
312
- # 第一阶段:分析是否需要搜索
313
- analysis = self._analyze_query_need(user_message)
314
-
315
- search_results = ""
316
- if analysis.get("needs_search", False) and analysis.get("queries"):
317
- # 输出搜索提示
318
- yield "[正在搜索相关法律信息...]\n\n"
319
-
320
- # 第二阶段:执行搜索
321
- all_results = []
322
- for query in analysis["queries"][:2]: # 最多执行2个查询
323
- print(f"执行流式搜索查询: {query}")
324
- result = self.search_legal_database(query)
325
- if result and result.strip():
326
- all_results.append(f"查询: {query}\n{result}")
327
-
328
- if all_results:
329
- search_results = "\n\n" + "="*50 + "\n".join(all_results)
330
- yield "[搜索完成,正在生成回答...]\n\n"
331
-
332
- # 第三阶段:基于搜索结果生成流式回答
333
- final_prompt = self.conversation_history.copy()
334
-
335
- if search_results:
336
- final_prompt.append({
337
- "role": "system",
338
- "content": f"以下是相关的法律搜索结果,请在回答中引用这些信息:\n{search_results}\n\n请基于这些搜索结果回答用户的问题,并引用具体的来源和页码。"
339
- })
340
-
341
- # 创建流式完成请求
342
- response = self.llm_client.chat.completions.create(
343
- model=self.model_name,
344
- messages=final_prompt,
345
- stream=True,
346
- temperature=0.3,
347
- max_tokens=2048
348
- )
349
-
350
- full_response = "" # 存储完整响应
351
-
352
- # 处理流式响应
353
- for chunk in response:
354
- if hasattr(chunk.choices[0].delta, 'content') and chunk.choices[0].delta.content:
355
- content = chunk.choices[0].delta.content
356
- full_response += content
357
- yield content
358
-
359
- # 将最终响应添加到对话历史
360
- self.conversation_history.append({"role": "assistant", "content": full_response})
361
-
362
- def reset_conversation(self):
363
- """重置对话历史"""
364
- self.conversation_history = [self.conversation_history[0]] # 保留系统消息
365
-
366
- def main():
367
- # 配置信息
368
- MILVUS_DB_PATH = "./milvus_legal_codes.db" # 使用您之前创建的数据库名,或创建新的
369
- COLLECTION_NAME = "legal_codes_collection" # 使用您之前创建的集合名或新名称
370
- # OPENAI_API_KEY = "sk-dad31a53a4684587aed060afc0e4d75b" # 请替换为实际的API密钥
371
- # OPENAI_BASE_URL = "https://api.deepseek.com" # 如果使用OpenAI API,请移除此行
372
- OPENAI_API_KEY = "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" # 请替换为实际的API密钥
373
- OPENAI_BASE_URL = "" # 如果使用OpenAI API,请移除此行
374
-
375
- # 初始化聊天机器人
376
- chatbot = LegalChatbot(
377
- milvus_db_path=MILVUS_DB_PATH,
378
- collection_name=COLLECTION_NAME,
379
- openai_api_key=OPENAI_API_KEY,
380
- openai_base_url=OPENAI_BASE_URL,
381
- # model_name="deepseek-chat"
382
- model_name="gpt-4o"
383
- )
384
-
385
- print("法律RAG聊天机器人已初始化。输入'exit'或'quit'结束会话。")
386
-
387
- while True:
388
- user_input = input("\n您: ")
389
-
390
- if user_input.lower() in ['exit', 'quit']:
391
- print("会话结束。")
392
- break
393
-
394
- if user_input.lower() in ['reset', 'clear']:
395
- chatbot.reset_conversation()
396
- print("已重置对话历史。")
397
- continue
398
-
399
- print("\n正在思考...")
400
- start_time = time.time()
401
-
402
- response = chatbot.process_message(user_input)
403
-
404
- end_time = time.time()
405
- print(f"助手 ({end_time - start_time:.2f}秒): {response}")
406
-
407
- if __name__ == "__main__":
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
408
  main()
 
1
+ import re
2
+ import json
3
+ from pymilvus import MilvusClient, model
4
+ from openai import OpenAI
5
+ import time
6
+
7
class LegalChatbot:
    """Legal RAG chatbot backed by a Milvus vector store and an OpenAI-compatible LLM."""

    def __init__(self, milvus_db_path, collection_name, openai_api_key,
                 openai_base_url=None, model_name="deepseek-reasoner"):
        """
        Initialize the Legal RAG Chatbot.

        Args:
            milvus_db_path: Path to the Milvus database file.
            collection_name: Name of the collection to search.
            openai_api_key: API key for the OpenAI-compatible endpoint.
            openai_base_url: Optional API base URL (e.g. for DeepSeek).
            model_name: LLM model name to use.
        """
        # Initialize Milvus client
        self.milvus_client = MilvusClient(milvus_db_path)
        self.collection_name = collection_name

        # Initialize the embedding model exactly once -- the original code
        # constructed it twice (loading the model is expensive). It is needed
        # both for creating a new collection (vector dimension) and for queries.
        self.embedding_fn = model.DefaultEmbeddingFunction()

        # Check if collection exists, create if not
        if not self.milvus_client.has_collection(collection_name=collection_name):
            print(f"Collection '{collection_name}' does not exist. Creating it...")
            self.milvus_client.create_collection(
                collection_name=collection_name,
                dimension=self.embedding_fn.dim
            )
            print(f"Collection '{collection_name}' created successfully.")

        # Initialize OpenAI client; an empty/None base URL falls back to the
        # default OpenAI endpoint.
        if openai_base_url:
            self.llm_client = OpenAI(api_key=openai_api_key, base_url=openai_base_url)
        else:
            self.llm_client = OpenAI(api_key=openai_api_key)

        self.model_name = model_name
        # Conversation history always starts with (and keeps) this system message.
        self.conversation_history = [
            {"role": "system", "content": """You are a helpful paralegal assistant with expertise in Canadian and U.S. law.

You will help users with their legal questions. When answering, you should be helpful, accurate, and cite specific legal sources when possible.

Users are members of the general public and may ask questions in Chinese or English. Please respond in the same language as the user's question.
"""}
        ]
55
+
56
+ def search_legal_database(self, query, limit=5):
57
+ """
58
+ Search legal database using Milvus
59
+
60
+ Args:
61
+ query: Search query
62
+ limit: Number of results to return
63
+
64
+ Returns:
65
+ Formatted search results string
66
+ """
67
+ if not query or query.strip() == "" or query.strip().lower() == "query":
68
+ return "Invalid search query. Please provide specific search content."
69
+
70
+ # Check if database has data
71
+ collection_stats = self.milvus_client.get_collection_stats(self.collection_name)
72
+ row_count = collection_stats.get("row_count", 0)
73
+
74
+ if row_count == 0:
75
+ # If collection is empty, add sample data
76
+ print("Collection is empty, adding sample data...")
77
+ self._add_sample_data()
78
+
79
+ # Generate query vector
80
+ query_vector = self.embedding_fn.encode_queries([query])
81
+
82
+ # Execute search
83
+ search_results = self.milvus_client.search(
84
+ collection_name=self.collection_name,
85
+ data=query_vector,
86
+ limit=limit,
87
+ output_fields=["text", "page_num", "source"]
88
+ )
89
+
90
+ # Check if there are results
91
+ if not search_results or len(search_results[0]) == 0:
92
+ return "No results found related to this query."
93
+
94
+ # Format search results
95
+ formatted_results = []
96
+ for i, result in enumerate(search_results[0]):
97
+ similarity = 1 - result['distance']
98
+ source = result['entity'].get('source', 'Unknown source')
99
+ page_num = result['entity'].get('page_num', 'Unknown page')
100
+ text = result['entity'].get('text', '')
101
+
102
+ formatted_result = f"[Result {i+1}] Source: {source}, Page: {page_num}, Relevance: {similarity:.4f}\n"
103
+ formatted_result += f"Content: {text}\n\n"
104
+ formatted_results.append(formatted_result)
105
+
106
+ return "\n".join(formatted_results)
107
+
108
+ def _add_sample_data(self):
109
+ """Add sample legal text data to empty collection"""
110
+ # Simple legal text examples
111
+ docs = [
112
+ "Ontario Regulation 213/91 (Construction Projects) under the Occupational Health and Safety Act contains provisions for construction safety. Section 26 requires that every worker who may be exposed to the hazard of falling more than 3 metres shall use a fall protection system.",
113
+ "Under the Canada Labour Code, employers have a duty to ensure that the health and safety at work of every person employed by the employer is protected (Section 124). This includes providing proper training and supervision.",
114
+ "The Criminal Code of Canada Section 217.1 states that everyone who undertakes, or has the authority, to direct how another person does work or performs a task is under a legal duty to take reasonable steps to prevent bodily harm to that person, or any other person, arising from that work or task.",
115
+ "British Columbia's Workers Compensation Act requires employers to ensure the health and safety of all workers and comply with occupational health and safety regulations. This includes providing proper equipment, training, and supervision for construction activities.",
116
+ "Alberta's Occupational Health and Safety Code (Part 9) contains specific requirements for fall protection systems when workers are at heights of 3 metres or more, including the use of guardrails, safety nets, or personal fall arrest systems."
117
+ ]
118
+
119
+ # Generate vectors
120
+ vectors = self.embedding_fn.encode_documents(docs)
121
+
122
+ # Prepare data
123
+ data = []
124
+ for i in range(len(docs)):
125
+ source_name = f"Sample Legal Text {i+1}"
126
+ data.append({
127
+ "id": i,
128
+ "vector": vectors[i],
129
+ "text": docs[i],
130
+ "page_num": 1,
131
+ "source": source_name
132
+ })
133
+
134
+ # Insert data
135
+ self.milvus_client.insert(collection_name=self.collection_name, data=data)
136
+ print(f"Added {len(data)} sample data entries to collection")
137
+
138
+ def _analyze_query_need(self, user_message):
139
+ """
140
+ Analyze user message to determine if legal database search is needed
141
+
142
+ Args:
143
+ user_message: User's message
144
+
145
+ Returns:
146
+ dict: {"needs_search": bool, "queries": list}
147
+ """
148
+ # Preprocessing: Check if user explicitly requests search
149
+ search_keywords = [
150
+ "search in database", "search in the database", "search database",
151
+ "look up", "find in database", "search for", "after searching",
152
+ "query database", "database search", "database lookup"
153
+ ]
154
+
155
+ user_message_lower = user_message.lower()
156
+ explicit_search_request = any(keyword in user_message_lower for keyword in search_keywords)
157
+
158
+ if explicit_search_request:
159
+ print("Detected explicit user request for database search")
160
+ # Clean query content, remove search-related instructions (case insensitive)
161
+ clean_query = user_message
162
+
163
+ # All phrases to remove
164
+ all_phrases_to_remove = search_keywords + [
165
+ "Answer me after searching in the database", "answer me after",
166
+ "please search", "search and tell me", "look up and answer",
167
+ "tell me", "what is", "what are", "explain"
168
+ ]
169
+
170
+ for phrase in all_phrases_to_remove:
171
+ # Case insensitive replacement
172
+ import re
173
+ pattern = re.compile(re.escape(phrase), re.IGNORECASE)
174
+ clean_query = pattern.sub("", clean_query)
175
+
176
+ clean_query = clean_query.strip(".,?! ")
177
+
178
+ if not clean_query or len(clean_query) < 3:
179
+ clean_query = "legal information"
180
+
181
+ return {
182
+ "needs_search": True,
183
+ "reasoning": "User explicitly requested database search",
184
+ "queries": [clean_query]
185
+ }
186
+
187
+ analysis_prompt = [
188
+ {"role": "system", "content": """You are an AI assistant that analyzes user questions to determine if they need legal database searches.
189
+
190
+ Your task is to analyze the user's question and determine:
191
+ 1. Whether this question requires searching a legal database
192
+ 2. If yes, what specific search queries would be most helpful
193
+
194
+ Respond in JSON format:
195
+ {
196
+ "needs_search": true/false,
197
+ "reasoning": "brief explanation of why search is or isn't needed",
198
+ "queries": ["query1", "query2"] // only if needs_search is true
199
+ }
200
+
201
+ IMPORTANT RULES:
202
+ 1. If the user explicitly requests database search (phrases like "search in database", "look up", "find in database"), always set needs_search to true
203
+ 2. For ANY legal topic question, default to needs_search = true unless it's clearly a simple greeting or completely non-legal
204
+ 3. Legal topics include: laws, regulations, legal procedures, legal documents, legal concepts, legal rights, etc.
205
+
206
+ Search should be needed for:
207
+ - ANY legal question (wills, trusts, contracts, rights, procedures, etc.)
208
+ - Questions about specific laws, regulations, or legal codes
209
+ - Requests for legal precedents or case law
210
+ - Questions about legal procedures or requirements
211
+ - Legal document comparisons (like will vs trust)
212
+ - When user explicitly asks to search database
213
+
214
+ Search should NOT be needed ONLY for:
215
+ - Simple greetings ("hello", "how are you")
216
+ - Completely non-legal topics (weather, sports, etc.)
217
+ - Technical issues with the system itself
218
+ """}
219
+ ]
220
+
221
+ # Add recent conversation history as context
222
+ context_messages = self.conversation_history[-3:] if len(self.conversation_history) > 3 else self.conversation_history[1:]
223
+ for msg in context_messages:
224
+ analysis_prompt.append(msg)
225
+
226
+ analysis_prompt.append({"role": "user", "content": f"Analyze this question: {user_message}"})
227
+
228
+ # Display the analysis prompt
229
+ print("\n<prompt>")
230
+ print("Query Analysis Prompt:")
231
+ print(f"User Message: {user_message}")
232
+ print("System: Analyzing if legal database search is needed...")
233
+ print("</prompt>\n")
234
+
235
+ response = self.llm_client.chat.completions.create(
236
+ model=self.model_name,
237
+ messages=analysis_prompt,
238
+ stream=False,
239
+ temperature=0.1
240
+ )
241
+
242
+ response_content = response.choices[0].message.content.strip()
243
+ print(f"LLM Raw Response: {response_content}")
244
+
245
+ # Try to extract JSON content (if response contains other text)
246
+ import re
247
+ json_match = re.search(r'\{.*\}', response_content, re.DOTALL)
248
+ if json_match:
249
+ json_content = json_match.group(0)
250
+ else:
251
+ json_content = response_content
252
+
253
+ analysis_result = json.loads(json_content)
254
+ print(f"Query Analysis Result: {analysis_result}")
255
+ return analysis_result
256
+
257
+ def process_message(self, user_message):
258
+ """
259
+ Process user message and generate response (two-stage mode)
260
+
261
+ Args:
262
+ user_message: User's message
263
+
264
+ Returns:
265
+ Assistant's response
266
+ """
267
+ # Add user message to conversation history
268
+ self.conversation_history.append({"role": "user", "content": user_message})
269
+
270
+ # Stage 1: Analyze if search is needed
271
+ analysis = self._analyze_query_need(user_message)
272
+
273
+ search_results = ""
274
+ if analysis.get("needs_search", False) and analysis.get("queries"):
275
+ # Stage 2: Execute search
276
+ all_results = []
277
+ for query in analysis["queries"][:2]: # Execute max 2 queries
278
+ print(f"Executing search query: {query}")
279
+ result = self.search_legal_database(query)
280
+ if result and result.strip():
281
+ all_results.append(f"Query: {query}\n{result}")
282
+
283
+ if all_results:
284
+ search_results = "\n\n" + "="*50 + "\n".join(all_results)
285
+
286
+ # Display RAG results with tags
287
+ print("\n<RAG_result>")
288
+ print("Search Results from Legal Database:")
289
+ print(search_results)
290
+ print("</RAG_result>\n")
291
+
292
+ # Stage 3: Generate answer based on search results
293
+ final_prompt = self.conversation_history.copy()
294
+
295
+ if search_results:
296
+ final_prompt.append({
297
+ "role": "system",
298
+ "content": f"The following are relevant legal search results, please reference this information in your answer:\n{search_results}\n\nPlease answer the user's question based on these search results, and cite specific sources and page numbers."
299
+ })
300
+
301
+ response = self.llm_client.chat.completions.create(
302
+ model=self.model_name,
303
+ messages=final_prompt,
304
+ stream=False
305
+ )
306
+
307
+ assistant_response = response.choices[0].message.content
308
+
309
+ # Add final response to conversation history
310
+ self.conversation_history.append({"role": "assistant", "content": assistant_response})
311
+
312
+ return assistant_response
313
+
314
+ def process_message_stream(self, user_message):
315
+ """
316
+ Process user message and return streaming response with intelligent RAG queries
317
+
318
+ Args:
319
+ user_message (str): User input message
320
+
321
+ Yields:
322
+ str: Response text fragments
323
+ """
324
+ # Add user message to conversation history
325
+ self.conversation_history.append({"role": "user", "content": user_message})
326
+
327
+ # Stage 1: Analyze if search is needed
328
+ analysis = self._analyze_query_need(user_message)
329
+
330
+ search_results = ""
331
+ if analysis.get("needs_search", False) and analysis.get("queries"):
332
+ # Output search prompt
333
+ yield "\n<prompt>\n"
334
+ yield "🔍 Analyzing query for legal database search...\n"
335
+ yield f"Query Analysis: {analysis.get('reasoning', 'Legal topic detected')}\n"
336
+ yield f"Search needed: {analysis.get('needs_search', False)}\n"
337
+ yield "</prompt>\n\n"
338
+
339
+ yield "[🔍 Searching relevant legal information...]\n\n"
340
+
341
+ # Stage 2: Execute search
342
+ all_results = []
343
+ for query in analysis["queries"][:2]: # Execute max 2 queries
344
+ print(f"Executing streaming search query: {query}")
345
+ result = self.search_legal_database(query)
346
+ if result and result.strip():
347
+ all_results.append(f"Query: {query}\n{result}")
348
+
349
+ if all_results:
350
+ search_results = "\n\n" + "="*50 + "\n".join(all_results)
351
+
352
+ # Output RAG results with tags
353
+ yield "\n<RAG_result>\n"
354
+ yield "📚 Search Results from Legal Database:\n\n"
355
+ for i, result in enumerate(all_results, 1):
356
+ yield f"Search {i}:\n{result}\n\n"
357
+ yield "</RAG_result>\n\n"
358
+
359
+ yield "[✅ Search completed, generating answer...]\n\n"
360
+
361
+ # Stage 3: Generate streaming answer based on search results
362
+ final_prompt = self.conversation_history.copy()
363
+
364
+ if search_results:
365
+ final_prompt.append({
366
+ "role": "system",
367
+ "content": f"The following are relevant legal search results, please reference this information in your answer:\n{search_results}\n\nPlease answer the user's question based on these search results, and cite specific sources and page numbers."
368
+ })
369
+
370
+ # Create streaming completion request
371
+ response = self.llm_client.chat.completions.create(
372
+ model=self.model_name,
373
+ messages=final_prompt,
374
+ stream=True,
375
+ temperature=0.3,
376
+ max_tokens=2048
377
+ )
378
+
379
+ full_response = "" # Store complete response
380
+
381
+ # Process streaming response
382
+ for chunk in response:
383
+ if hasattr(chunk.choices[0].delta, 'content') and chunk.choices[0].delta.content:
384
+ content = chunk.choices[0].delta.content
385
+ full_response += content
386
+ yield content
387
+
388
+ # Add final response to conversation history
389
+ self.conversation_history.append({"role": "assistant", "content": full_response})
390
+
391
+ def reset_conversation(self):
392
+ """Reset conversation history"""
393
+ self.conversation_history = [self.conversation_history[0]] # Keep system message
394
+
395
def main():
    """Run an interactive command-line session with the legal RAG chatbot."""
    import os

    # Configuration
    MILVUS_DB_PATH = "./milvus_legal_codes.db"  # Use your existing database name or create new
    COLLECTION_NAME = "legal_codes_collection"  # Use your existing collection name or new name

    # SECURITY: the original committed a live API key to source control.
    # Credentials must come from the environment, never from code.
    OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY", "")
    OPENAI_BASE_URL = os.environ.get("OPENAI_BASE_URL", "")  # empty -> default OpenAI endpoint
    if not OPENAI_API_KEY:
        raise SystemExit("OPENAI_API_KEY environment variable is not set.")

    # Initialize chatbot
    chatbot = LegalChatbot(
        milvus_db_path=MILVUS_DB_PATH,
        collection_name=COLLECTION_NAME,
        openai_api_key=OPENAI_API_KEY,
        openai_base_url=OPENAI_BASE_URL,
        model_name="gpt-4o"
    )

    print("Legal RAG Chatbot initialized. Type 'exit' or 'quit' to end session.")

    # Simple REPL: 'exit'/'quit' ends, 'reset'/'clear' wipes the history.
    while True:
        user_input = input("\nYou: ")

        if user_input.lower() in ['exit', 'quit']:
            print("Session ended.")
            break

        if user_input.lower() in ['reset', 'clear']:
            chatbot.reset_conversation()
            print("Conversation history reset.")
            continue

        print("\nThinking...")
        start_time = time.time()

        response = chatbot.process_message(user_input)

        end_time = time.time()
        print(f"Assistant ({end_time - start_time:.2f}s): {response}")
435
+
436
+ if __name__ == "__main__":
437
  main()