File size: 19,423 Bytes
8cc98f8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
12e0fa1
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
import json
import os
import re
import time

from openai import OpenAI
from pymilvus import MilvusClient, model

class LegalChatbot:
    """Legal question-answering chatbot with retrieval from a Milvus collection.

    Every user turn goes through up to three stages:

    1. ``_analyze_query_need`` decides whether the legal database should be
       searched (explicit user request, otherwise an LLM classification call).
    2. At most ``_MAX_QUERIES`` searches are run with ``search_legal_database``.
    3. The LLM answers from the conversation history plus any search results.

    ``conversation_history`` is an OpenAI-style message list whose first entry
    is the system prompt; ``reset_conversation`` keeps only that entry.
    """

    # Phrases that mark a message as an explicit request to search the database.
    _SEARCH_KEYWORDS = (
        "search in database", "search in the database", "search database",
        "look up", "find in database", "search for", "after searching",
        "query database", "database search", "database lookup",
    )
    # Extra instruction-like phrases stripped from explicit requests so that only
    # the subject of the question is sent to the vector search.
    _EXTRA_CLEANUP_PHRASES = (
        "Answer me after searching in the database", "answer me after",
        "please search", "search and tell me", "look up and answer",
        "tell me", "what is", "what are", "explain",
    )
    # Phrases are tried longest-first (so "Answer me after searching in the database"
    # wins over "after searching") and must not touch an ASCII letter on either side:
    # this rejects hits inside words such as "research for" while still matching
    # next to punctuation, digits or CJK text.
    _SEARCH_REQUEST_RE = re.compile(
        r"(?<![A-Za-z])(?:"
        + "|".join(re.escape(phrase) for phrase in sorted(_SEARCH_KEYWORDS, key=len, reverse=True))
        + r")(?![A-Za-z])",
        re.IGNORECASE,
    )
    _CLEANUP_RE = re.compile(
        r"(?<![A-Za-z])(?:"
        + "|".join(
            re.escape(phrase)
            for phrase in sorted(_SEARCH_KEYWORDS + _EXTRA_CLEANUP_PHRASES, key=len, reverse=True)
        )
        + r")(?![A-Za-z])",
        re.IGNORECASE,
    )
    # Upper bound on vector searches per user turn.
    _MAX_QUERIES = 2
    # Line placed before each query's results in the context handed to the LLM.
    _RESULT_SEPARATOR = "\n\n" + "=" * 50 + "\n"

    def __init__(self, milvus_db_path, collection_name, openai_api_key, openai_base_url=None, model_name="deepseek-reasoner"):
        """
        Initialize Legal RAG Chatbot

        Args:
            milvus_db_path: Milvus database path (passed straight to MilvusClient)
            collection_name: Collection name to search; created if missing
            openai_api_key: OpenAI API key
            openai_base_url: Optional API base URL (for DeepSeek etc.); falsy means
                the SDK default endpoint
            model_name: LLM model name to use
        """
        # Initialize Milvus client
        self.milvus_client = MilvusClient(milvus_db_path)
        self.collection_name = collection_name

        # Load the embedding model once: it sizes a newly created collection and
        # encodes queries/documents afterwards.
        self.embedding_fn = model.DefaultEmbeddingFunction()

        # Check if collection exists, create if not
        if not self.milvus_client.has_collection(collection_name=collection_name):
            print(f"Collection '{collection_name}' does not exist. Creating it...")
            self.milvus_client.create_collection(
                collection_name=collection_name,
                dimension=self.embedding_fn.dim
            )
            print(f"Collection '{collection_name}' created successfully.")

        # Only pass base_url when one was given so the SDK keeps its default otherwise.
        client_kwargs = {"api_key": openai_api_key}
        if openai_base_url:
            client_kwargs["base_url"] = openai_base_url
        self.llm_client = OpenAI(**client_kwargs)

        self.model_name = model_name
        self.conversation_history = [
            {"role": "system", "content": """You are a helpful paralegal assistant with expertise in Canadian and U.S. law.
            
            You will help users with their legal questions. When answering, you should be helpful, accurate, and cite specific legal sources when possible.
            
            Users are members of the general public and may ask questions in Chinese or English. Please respond in the same language as the user's question.
            """}
        ]

    def search_legal_database(self, query, limit=5):
        """
        Search legal database using Milvus

        Seeds the collection with sample data first if it is empty.

        Args:
            query: Search query
            limit: Number of results to return

        Returns:
            Formatted search results string, or a short explanatory message for
            invalid queries / no hits
        """
        if not query or not query.strip() or query.strip().lower() == "query":
            return "Invalid search query. Please provide specific search content."

        # Check if database has data
        collection_stats = self.milvus_client.get_collection_stats(self.collection_name)
        if collection_stats.get("row_count", 0) == 0:
            print("Collection is empty, adding sample data...")
            self._add_sample_data()

        query_vector = self.embedding_fn.encode_queries([query])
        search_results = self.milvus_client.search(
            collection_name=self.collection_name,
            data=query_vector,
            limit=limit,
            output_fields=["text", "page_num", "source"]
        )

        # One query vector was sent, so only the first hit list is relevant.
        if not search_results or len(search_results[0]) == 0:
            return "No results found related to this query."

        return self._format_hits(search_results[0])

    @staticmethod
    def _format_hits(hits):
        """Render Milvus hits for one query as numbered text blocks.

        Args:
            hits: Sequence of hit mappings, each with a ``distance`` score and an
                ``entity`` mapping of the requested output fields.

        Returns:
            One "[Result N] ..." paragraph per hit, joined with newlines.
        """
        paragraphs = []
        for rank, hit in enumerate(hits, start=1):
            entity = hit["entity"]
            # The score is reported as-is: with a similarity metric (COSINE/IP, the
            # MilvusClient quick-setup default used by __init__) larger already means
            # more relevant, whereas the former ``1 - distance`` ranked the best hit lowest.
            # NOTE(review): a pre-existing collection indexed with L2 would need the
            # opposite reading — confirm the metric of databases created elsewhere.
            relevance = hit["distance"]
            source = entity.get("source", "Unknown source")
            page_num = entity.get("page_num", "Unknown page")
            text = entity.get("text", "")
            paragraphs.append(
                f"[Result {rank}] Source: {source}, Page: {page_num}, Relevance: {relevance:.4f}\n"
                f"Content: {text}\n\n"
            )
        return "\n".join(paragraphs)

    def _add_sample_data(self):
        """Insert five sample legal passages (ids 0-4, page 1) into the collection."""
        docs = [
            "Ontario Regulation 213/91 (Construction Projects) under the Occupational Health and Safety Act contains provisions for construction safety. Section 26 requires that every worker who may be exposed to the hazard of falling more than 3 metres shall use a fall protection system.",
            "Under the Canada Labour Code, employers have a duty to ensure that the health and safety at work of every person employed by the employer is protected (Section 124). This includes providing proper training and supervision.",
            "The Criminal Code of Canada Section 217.1 states that everyone who undertakes, or has the authority, to direct how another person does work or performs a task is under a legal duty to take reasonable steps to prevent bodily harm to that person, or any other person, arising from that work or task.",
            "British Columbia's Workers Compensation Act requires employers to ensure the health and safety of all workers and comply with occupational health and safety regulations. This includes providing proper equipment, training, and supervision for construction activities.",
            "Alberta's Occupational Health and Safety Code (Part 9) contains specific requirements for fall protection systems when workers are at heights of 3 metres or more, including the use of guardrails, safety nets, or personal fall arrest systems."
        ]

        vectors = self.embedding_fn.encode_documents(docs)

        # Fields other than id/vector are stored alongside them and returned via output_fields.
        data = [
            {
                "id": i,
                "vector": vector,
                "text": doc,
                "page_num": 1,
                "source": f"Sample Legal Text {i + 1}",
            }
            for i, (doc, vector) in enumerate(zip(docs, vectors))
        ]

        self.milvus_client.insert(collection_name=self.collection_name, data=data)
        print(f"Added {len(data)} sample data entries to collection")

    @classmethod
    def _clean_explicit_query(cls, user_message):
        """Strip search/instruction phrases from an explicit request, keeping its subject.

        Returns "legal information" when fewer than 3 characters remain.
        """
        cleaned = cls._CLEANUP_RE.sub("", user_message)
        # Collapse the gaps left by removed phrases, then trim edge punctuation.
        cleaned = " ".join(cleaned.split()).strip(".,?!:; ")
        return cleaned if len(cleaned) >= 3 else "legal information"

    @staticmethod
    def _parse_analysis(response_content, user_message):
        """Normalize the analyser LLM's reply into ``{"needs_search", "reasoning", "queries"}``.

        The JSON object may be surrounded by prose. If no valid JSON object can
        be parsed, or the model asks for a search without giving queries, the
        raw user message is used as the query, following the prompt's rule to
        default to searching for legal questions.

        Args:
            response_content: Raw text returned by the analysis LLM call.
            user_message: The user's question, used as the fallback query.

        Returns:
            dict with a bool ``needs_search`` and a list-of-str ``queries``.
        """
        default_queries = [user_message] if user_message and user_message.strip() else []
        fallback = {
            "needs_search": True,
            "reasoning": "Could not parse query analysis; searching with the original question",
            "queries": default_queries,
        }

        # Try to extract JSON content (if response contains other text)
        json_match = re.search(r"\{.*\}", response_content, re.DOTALL)
        json_content = json_match.group(0) if json_match else response_content
        try:
            parsed = json.loads(json_content)
        except json.JSONDecodeError:
            return fallback
        if not isinstance(parsed, dict):
            return fallback

        needs_search = parsed.get("needs_search", False)
        if isinstance(needs_search, str):
            needs_search = needs_search.strip().lower() == "true"

        queries = parsed.get("queries") or []
        if isinstance(queries, str):
            # A bare string would otherwise be iterated character by character.
            queries = [queries]
        elif not isinstance(queries, list):
            queries = []
        queries = [q.strip() for q in queries if isinstance(q, str) and q.strip()]
        if needs_search and not queries:
            queries = default_queries

        parsed["needs_search"] = bool(needs_search)
        parsed["queries"] = queries
        return parsed

    def _analyze_query_need(self, user_message):
        """
        Analyze user message to determine if legal database search is needed

        Explicit search requests are detected locally (no LLM call); everything
        else is classified by the LLM.

        Args:
            user_message: User's message

        Returns:
            dict: {"needs_search": bool, "reasoning": str, "queries": list}
        """
        if self._SEARCH_REQUEST_RE.search(user_message):
            print("Detected explicit user request for database search")
            return {
                "needs_search": True,
                "reasoning": "User explicitly requested database search",
                "queries": [self._clean_explicit_query(user_message)]
            }

        analysis_prompt = [
            {"role": "system", "content": """You are an AI assistant that analyzes user questions to determine if they need legal database searches.
            
            Your task is to analyze the user's question and determine:
            1. Whether this question requires searching a legal database
            2. If yes, what specific search queries would be most helpful
            
            Respond in JSON format:
            {
                "needs_search": true/false,
                "reasoning": "brief explanation of why search is or isn't needed",
                "queries": ["query1", "query2"] // only if needs_search is true
            }
            
            IMPORTANT RULES:
            1. If the user explicitly requests database search (phrases like "search in database", "look up", "find in database"), always set needs_search to true
            2. For ANY legal topic question, default to needs_search = true unless it's clearly a simple greeting or completely non-legal
            3. Legal topics include: laws, regulations, legal procedures, legal documents, legal concepts, legal rights, etc.
            
            Search should be needed for:
            - ANY legal question (wills, trusts, contracts, rights, procedures, etc.)
            - Questions about specific laws, regulations, or legal codes  
            - Requests for legal precedents or case law
            - Questions about legal procedures or requirements
            - Legal document comparisons (like will vs trust)
            - When user explicitly asks to search database
            
            Search should NOT be needed ONLY for:
            - Simple greetings ("hello", "how are you")
            - Completely non-legal topics (weather, sports, etc.)
            - Technical issues with the system itself
            """}
        ]

        # Recent turns give the analyser context for follow-up questions. The
        # process_* methods append the current message to the history before
        # calling this, so drop it here; it is sent once below with its prefix.
        prior_turns = self.conversation_history[1:]
        if prior_turns and prior_turns[-1] == {"role": "user", "content": user_message}:
            prior_turns = prior_turns[:-1]
        analysis_prompt.extend(prior_turns[-2:])

        analysis_prompt.append({"role": "user", "content": f"Analyze this question: {user_message}"})

        # Display the analysis prompt
        print("\n<prompt>")
        print("Query Analysis Prompt:")
        print(f"User Message: {user_message}")
        print("System: Analyzing if legal database search is needed...")
        print("</prompt>\n")

        response = self.llm_client.chat.completions.create(
            model=self.model_name,
            messages=analysis_prompt,
            stream=False,
            temperature=0.1
        )

        # ``content`` may be None; treat that as an unparseable (empty) reply.
        response_content = (response.choices[0].message.content or "").strip()
        print(f"LLM Raw Response: {response_content}")

        analysis_result = self._parse_analysis(response_content, user_message)
        print(f"Query Analysis Result: {analysis_result}")
        return analysis_result

    @staticmethod
    def _should_search(analysis):
        """Return True when the analysis asks for a search and supplies queries."""
        return bool(analysis.get("needs_search", False) and analysis.get("queries"))

    def _run_searches(self, queries, log_label):
        """Search for at most ``_MAX_QUERIES`` queries.

        Returns:
            List of "Query: <query>\\n<results>" strings, skipping blank results.
        """
        results = []
        for query in queries[:self._MAX_QUERIES]:
            print(f"{log_label}: {query}")
            result = self.search_legal_database(query)
            if result and result.strip():
                results.append(f"Query: {query}\n{result}")
        return results

    @classmethod
    def _combine_search_results(cls, all_results):
        """Join per-query results, putting a separator line before each one ("" if none)."""
        return "".join(cls._RESULT_SEPARATOR + block for block in all_results)

    def _build_final_prompt(self, search_results):
        """Copy the history and, if there are search results, append them as a system message."""
        final_prompt = list(self.conversation_history)
        if search_results:
            final_prompt.append({
                "role": "system",
                "content": f"The following are relevant legal search results, please reference this information in your answer:\n{search_results}\n\nPlease answer the user's question based on these search results, and cite specific sources and page numbers."
            })
        return final_prompt

    def _discard_pending_user_turn(self, user_message):
        """Remove the trailing user turn for a request that produced no answer.

        Keeps the history alternating user/assistant after a failed request.
        """
        history = self.conversation_history
        if len(history) > 1 and history[-1] == {"role": "user", "content": user_message}:
            history.pop()

    def process_message(self, user_message):
        """
        Process user message and generate response (two-stage mode)

        Args:
            user_message: User's message

        Returns:
            Assistant's response ("" if the model returned no content)

        Raises:
            Any exception from the LLM/Milvus clients; the user turn is removed
            from the history before it propagates.
        """
        self.conversation_history.append({"role": "user", "content": user_message})
        try:
            # Stage 1: Analyze if search is needed
            analysis = self._analyze_query_need(user_message)

            # Stage 2: Execute search
            search_results = ""
            if self._should_search(analysis):
                all_results = self._run_searches(analysis["queries"], "Executing search query")
                if all_results:
                    search_results = self._combine_search_results(all_results)

                    # Display RAG results with tags
                    print("\n<RAG_result>")
                    print("Search Results from Legal Database:")
                    print(search_results)
                    print("</RAG_result>\n")

            # Stage 3: Generate answer based on search results
            response = self.llm_client.chat.completions.create(
                model=self.model_name,
                messages=self._build_final_prompt(search_results),
                stream=False
            )
        except Exception:
            self._discard_pending_user_turn(user_message)
            raise

        # ``content`` may be None; store "" so the history stays valid request input.
        assistant_response = response.choices[0].message.content or ""
        self.conversation_history.append({"role": "assistant", "content": assistant_response})
        return assistant_response

    def process_message_stream(self, user_message):
        """
        Process user message and return streaming response with intelligent RAG queries

        Args:
            user_message (str): User input message

        Yields:
            str: Status/RAG fragments (tagged with <prompt> / <RAG_result>) followed
            by the model's answer fragments

        The answer is recorded in the history when the stream finishes or is
        abandoned after partial output. If nothing was produced, the user turn
        is removed instead.
        """
        self.conversation_history.append({"role": "user", "content": user_message})
        answer_parts = []
        finished = False
        try:
            # Stage 1: Analyze if search is needed
            analysis = self._analyze_query_need(user_message)

            search_results = ""
            if self._should_search(analysis):
                yield "\n<prompt>\n"
                yield "🔍 Analyzing query for legal database search...\n"
                yield f"Query Analysis: {analysis.get('reasoning', 'Legal topic detected')}\n"
                yield f"Search needed: {analysis.get('needs_search', False)}\n"
                yield "</prompt>\n\n"

                yield "[🔍 Searching relevant legal information...]\n\n"

                # Stage 2: Execute search
                all_results = self._run_searches(analysis["queries"], "Executing streaming search query")
                if all_results:
                    search_results = self._combine_search_results(all_results)

                    yield "\n<RAG_result>\n"
                    yield "📚 Search Results from Legal Database:\n\n"
                    for i, result in enumerate(all_results, 1):
                        yield f"Search {i}:\n{result}\n\n"
                    yield "</RAG_result>\n\n"

                    yield "[✅ Search completed, generating answer...]\n\n"

            # Stage 3: Generate streaming answer based on search results
            response = self.llm_client.chat.completions.create(
                model=self.model_name,
                messages=self._build_final_prompt(search_results),
                stream=True,
                temperature=0.3,
                max_tokens=2048
            )

            for chunk in response:
                # Some providers emit chunks with no choices (e.g. trailing usage data).
                if not chunk.choices:
                    continue
                content = getattr(chunk.choices[0].delta, "content", None)
                if content:
                    answer_parts.append(content)
                    yield content
            finished = True
        finally:
            full_response = "".join(answer_parts)
            if finished or full_response:
                self.conversation_history.append({"role": "assistant", "content": full_response})
            else:
                self._discard_pending_user_turn(user_message)

    def reset_conversation(self):
        """Reset conversation history, keeping only the initial system message."""
        self.conversation_history = [self.conversation_history[0]]

def main():
    """Run an interactive command-line session with the legal RAG chatbot.

    Credentials come from the environment, never from source code:
    ``OPENAI_API_KEY`` (required), ``OPENAI_BASE_URL`` (optional, e.g.
    ``https://api.deepseek.com``; empty means the SDK default endpoint) and
    ``LLM_MODEL`` (optional, default ``gpt-4o``).

    Commands: ``exit``/``quit`` ends the session, ``reset``/``clear`` forgets
    the conversation so far.

    Raises:
        SystemExit: if ``OPENAI_API_KEY`` is not set.
    """
    # Configuration
    milvus_db_path = "./milvus_legal_codes.db"  # Use your existing database name or create new
    collection_name = "legal_codes_collection"  # Use your existing collection name or new name

    # API keys must not live in source control; any key that was committed here
    # should be considered leaked and revoked.
    openai_api_key = os.environ.get("OPENAI_API_KEY", "").strip()
    if not openai_api_key:
        raise SystemExit("Set the OPENAI_API_KEY environment variable before starting the chatbot.")
    openai_base_url = os.environ.get("OPENAI_BASE_URL", "").strip()
    model_name = os.environ.get("LLM_MODEL", "").strip() or "gpt-4o"

    chatbot = LegalChatbot(
        milvus_db_path=milvus_db_path,
        collection_name=collection_name,
        openai_api_key=openai_api_key,
        openai_base_url=openai_base_url,
        model_name=model_name
    )

    print("Legal RAG Chatbot initialized. Type 'exit' or 'quit' to end session.")

    while True:
        try:
            user_input = input("\nYou: ")
        except (EOFError, KeyboardInterrupt):
            # Ctrl-D / Ctrl-C end the session cleanly instead of with a traceback.
            print("\nSession ended.")
            break

        command = user_input.strip().lower()
        if command in ("exit", "quit"):
            print("Session ended.")
            break

        if command in ("reset", "clear"):
            chatbot.reset_conversation()
            print("Conversation history reset.")
            continue

        if not command:
            continue  # don't send blank lines to the LLM

        print("\nThinking...")
        start_time = time.perf_counter()
        try:
            response = chatbot.process_message(user_input)
        except Exception as exc:  # REPL boundary: report and keep the session alive
            print(f"Error while answering: {exc}")
            continue
        elapsed = time.perf_counter() - start_time
        print(f"Assistant ({elapsed:.2f}s): {response}")


if __name__ == "__main__":
    main()