|
import warnings |
|
warnings.filterwarnings("ignore") |
|
|
|
import re |
|
import os |
|
import numpy as np |
|
import faiss |
|
from sentence_transformers import SentenceTransformer |
|
from langchain_groq import ChatGroq |
|
from langchain.chains import LLMChain |
|
from langchain_core.prompts import ChatPromptTemplate |
|
from pydantic import BaseModel, Field |
|
from langchain.output_parsers import PydanticOutputParser |
|
from lm import get_query_llm, get_answer_llm |
|
|
|
|
|
# LLMs for the two pipeline stages (configured in lm.py):
# q_llm cleans/refines the raw user query, a_llm generates the final answer.
q_llm = get_query_llm()

a_llm = get_answer_llm()


# Sentence-embedding model used to embed queries for FAISS similarity search.
embedding_model = SentenceTransformer("all-MiniLM-L6-v2")

# Directory holding the pre-built embeddings.npy, index.faiss and chunks.txt.
save_dir = "."
|
|
|
def load_embeddings_and_index(save_dir="."):
    """Load the persisted retrieval artifacts from *save_dir*.

    Expects three files produced by the indexing step: ``embeddings.npy``,
    ``index.faiss`` and ``chunks.txt`` (one chunk per line).

    Returns:
        tuple: ``(embeddings, faiss_index, chunks)`` where ``chunks`` is a
        list of stripped text lines aligned with the index entries.
    """
    emb_path = os.path.join(save_dir, "embeddings.npy")
    idx_path = os.path.join(save_dir, "index.faiss")
    txt_path = os.path.join(save_dir, "chunks.txt")

    vectors = np.load(emb_path)
    faiss_index = faiss.read_index(idx_path)
    with open(txt_path, "r", encoding="utf-8") as fh:
        text_chunks = [ln.strip() for ln in fh]
    return vectors, faiss_index, text_chunks
|
|
|
# Trigger phrases: if any of these appears in a query, the user is asking
# for an explanation rather than the exact constitutional wording.
similar_words = [
    "explain", "elaborate", "describe", "clarify", "detail",
    "break down", "simplify", "outline", "demonstrate", "illustrate",
    "interpret", "expand on", "go over", "walk through", "define",
    "unpack", "decode", "shed light on", "analyze", "discuss",
    "make clear", "reveal", "disclose", "comment on", "talk about",
    "lay out", "spell out", "express", "delve into", "explore",
    "enlighten", "present", "review", "report", "state",
    "point out", "inform", "highlight",
]
|
|
|
def is_explanation_query(query):
    """Return True when *query* contains none of the explanation trigger words.

    NOTE(review): despite the name, a True result means the query is NOT
    asking for an explanation — callers use True to select the exact-text
    (verbatim article) chain and False for the explanation chain.
    """
    lowered = query.lower()
    return all(word not in lowered for word in similar_words)
|
|
|
def retrieve_relevant_chunks(query, index, chunks, top_k=5):
    """Retrieve the chunks most relevant to *query*.

    If the query names a specific article ("article" plus a number), the
    matching article chunk is returned directly; otherwise the query is
    embedded and a semantic top-k search is run against the FAISS index.

    Args:
        query: The (refined) user query string.
        index: FAISS index aligned with *chunks*.
        chunks: Chunk texts; article chunks are assumed to start with
            "article;<number>" when lower-cased — TODO confirm against
            the format written to chunks.txt.
        top_k: Number of neighbours for the semantic fallback.

    Returns:
        tuple[list[str], bool]: The relevant chunk texts and the
        exact-text flag from is_explanation_query().
    """
    numbers = re.findall(r"\d+", query)

    # Direct lookup when the user references an article by number.
    if "article" in query.lower() and numbers:
        article_number = numbers[0]
        # Anchor with (?!\d) so "article 1" does not also match the chunks
        # for article 11/12/… (the original startswith() prefix check did).
        pattern = re.compile(rf"article;{article_number}(?!\d)")
        for chunk in chunks:
            if pattern.match(chunk.lower()):
                return [chunk], is_explanation_query(query)

    # Semantic fallback: embed the query and search the FAISS index.
    query_embedding = np.array(embedding_model.encode([query])).astype("float32")
    distances, indices = index.search(query_embedding, top_k)
    relevant_chunks = [chunks[i] for i in indices[0]]
    return relevant_chunks, is_explanation_query(query)
|
|
|
|
|
# Query-refinement chain: fixes grammar/spelling only, never expands the query.
# The final instruction was garbled in the original
# ("if number is mentioned returned in digit format.") and has been fixed.
refine_prompt_template = ChatPromptTemplate.from_messages([
    ('system',
     'You are a helpful legal assistant. Your job is to clean user queries by '
     'fixing grammar or spelling only. Do not expand or explain them. '
     'If a number is mentioned, return it in digit format.'),
    ('human', '{query}')
])

refine_chain = LLMChain(llm=q_llm, prompt=refine_prompt_template)
|
|
|
|
|
class LegalResponse(BaseModel):
    """Structured answer schema shared by both answer chains."""

    # Title of the matched constitutional provision.
    title: str = Field(..., description='Return the title')
    # Main answer text for the user.
    answer: str = Field(..., description="The assistant's answer to the user's query")
    # Whether the query concerns the Constitution of Pakistan.
    is_relevant: bool = Field(..., description="True if the query is relevant to the Constitution of Pakistan, otherwise False")
    # Article number (with sub-clause) when one is mentioned, else "".
    article_number: str = Field(..., description="Mentioned article number if available, else empty string")
|
|
|
# Parser that coerces LLM output into the LegalResponse schema.
parser = PydanticOutputParser(pydantic_object=LegalResponse)


# Exact-text chain: used when a query names an article and contains no
# explanation trigger word — returns the constitutional wording verbatim.
# Fixes vs. original prompt: instruction 0 was missing its newline (it ran
# into instruction 1), and instruction 3 said "(Yes/No).ar" — a typo, and
# Yes/No conflicts with the boolean is_relevant field the Pydantic parser
# expects (the explanation prompt already says True/False).
answer_prompt_template_query = ChatPromptTemplate.from_messages([
    ("system",
     "You are a legal assistant with expertise in the Constitution of Pakistan. "
     "Return answer in structure format."
     "Your task is to extract and present the exact constitutional text, without paraphrasing, ensuring accuracy and fidelity to the original wording"
     "Especially return the title"),
    ("human",
     "User Query: {query}\n\n"
     "Instructions:\n"
     "0. Return the title.\n"
     "1. Return the exact wording from the Constitution.\n"
     "2. If a query references a specific article or sub-clause (e.g., Article 11(3)(b), Article 11(b), or 11(i)), return only the exact wording of that clause from the Constitution — do not include the full article unless required by structure\n"
     "3. Indicate whether the query is related to the Constitution of Pakistan (True/False).\n"
     "4. Extract and return the article number if it is mentioned. with sub-clause if its mentioned like 1,2 or 1(a)\n\n"
     "Context:\n{context}\n\n"
     "{format_instructions}\n")
])

answer_chain_article = LLMChain(llm=a_llm, prompt=answer_prompt_template_query, output_parser=parser)
|
|
|
|
|
# Explanation chain prompt: used when the query contains an explanation
# trigger word — asks for an in-depth explanation grounded strictly in the
# constitutional text, with no external examples or case law.
explanation_prompt_template_query = ChatPromptTemplate.from_messages([
    ("system",
     "You are a legal expert assistant with deep knowledge of the Constitution of Pakistan. "
     "Return answer in well structured format."
     "Your task is to provide clear, in-depth explanations strictly based on the constitutional text. "
     "Do not refer to real-world examples, cases, or interpretations beyond the Constitution itself."),

    ("human",
     "User Query: {query}\n\n"
     "Based on the provided context, follow these steps:\n\n"
     "1. Provide a comprehensive and detailed explanation strictly based on the Constitution of Pakistan.\n"
     "2. If the query refers to any law, article, amendment, clause, or provision that is passed, explain it thoroughly—"
     "including what it means, how it is structured, and how it functions within the Constitution.\n"
     "3. Avoid using any real-world references, court cases, or practical examples.\n"
     "4. Indicate whether the query is relevant to the Constitution of Pakistan (True/False).\n"
     "5. Extract and return the article number if any is mentioned in the query.\n\n"
     "Context:\n{context}\n\n"
     "{format_instructions}\n\n")
])


# Shares the LegalResponse parser so both chains return the same schema.
answer_chain_explanation = LLMChain(llm=a_llm, prompt=explanation_prompt_template_query, output_parser=parser)
|
|
|
|
|
# Load the retrieval artifacts once at import time (fails fast if missing).
embeddings, index, chunks = load_embeddings_and_index(save_dir)
|
|
|
|
|
def get_legal_response(query):
    """Answer *query* using the Constitution RAG pipeline.

    Pipeline: (1) clean the query with the refinement chain, falling back
    to the raw query on any failure; (2) retrieve relevant chunks; (3) run
    the exact-text chain when the query names an article with no
    explanation wording, otherwise the explanation chain.

    Args:
        query: Raw user question.

    Returns:
        dict: keys ``title``, ``answer``, ``is_relevant``,
        ``article_number`` taken from the parsed LegalResponse.
    """
    try:
        refined_query = refine_chain.run(query=query)
    except Exception as e:
        # Best-effort refinement: never let a refinement failure kill the request.
        print(f"[Refinement Error] Using raw query instead: {e}")
        refined_query = query

    print("\nRefined Query:", refined_query)

    relevant_chunks, flag = retrieve_relevant_chunks(refined_query, index, chunks, top_k=5)

    print("\nTop Relevant Chunks:")
    for i, chunk in enumerate(relevant_chunks, 1):
        print(f"\nChunk {i}:\n{'-'*50}\n{chunk}")

    context = "\n\n".join(relevant_chunks)

    # Hoisted: both branches need the same format instructions.
    format_instructions = parser.get_format_instructions()
    if flag:
        # Exact-text path: article referenced, no explanation trigger word.
        print("[mode] exact constitutional text")
        response = answer_chain_article.run(
            query=refined_query, context=context, format_instructions=format_instructions
        )
    else:
        # Explanation path.
        print("[mode] detailed explanation")
        response = answer_chain_explanation.run(
            query=refined_query, context=context, format_instructions=format_instructions
        )

    return {
        "title": response.title,
        "answer": response.answer,
        "is_relevant": response.is_relevant,
        "article_number": response.article_number,
    }
|
|