|
import streamlit as st
|
|
import fitz
|
|
import torch
|
|
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
|
|
import hashlib
|
|
from langchain.text_splitter import CharacterTextSplitter
|
|
from langchain.vectorstores import FAISS
|
|
from langchain.embeddings import OllamaEmbeddings
|
|
|
|
|
|
|
|
MODEL_PATH = "./fine_tuned_tinyllama_tax"


@st.cache_resource
def _load_tax_pipeline(model_path):
    """Load the tokenizer, model and text-generation pipeline exactly once.

    Streamlit re-executes the whole script on every widget interaction, so
    loading the model at module level would reload the weights on each rerun.
    ``st.cache_resource`` keeps the loaded objects alive and hands back the
    same instances on later reruns.

    Args:
        model_path: Directory (or hub id) passed to ``from_pretrained``.

    Returns:
        A ``(tokenizer, model, generator)`` tuple, where ``generator`` is the
        ``text-generation`` pipeline built from the other two.
    """
    loaded_tokenizer = AutoTokenizer.from_pretrained(model_path)
    loaded_model = AutoModelForCausalLM.from_pretrained(
        model_path,
        torch_dtype=torch.float16,
        device_map="auto",
    )
    generator = pipeline(
        "text-generation", model=loaded_model, tokenizer=loaded_tokenizer
    )
    return loaded_tokenizer, loaded_model, generator


# Keep the original module-level names so the rest of the file is unaffected.
tokenizer, model, tax_llm = _load_tax_pipeline(MODEL_PATH)
|
|
|
|
|
|
|
|
# Default values for the session keys this app uses. A key is only written
# when it is not already present, so existing session values are untouched.
_SESSION_DEFAULTS = {
    "legal_knowledge_base": "",
    "vector_db": None,
    "summary": "",
    "answer": "",
}

for _state_key, _state_default in _SESSION_DEFAULTS.items():
    if _state_key not in st.session_state:
        st.session_state[_state_key] = _state_default
|
|
|
|
|
|
|
|
def compute_file_hash(file):
    """Return the SHA-256 hex digest of the uploaded file's full contents.

    The stream is rewound before hashing so the digest always covers the whole
    file, even if a caller has already read part of it, and it is rewound again
    afterwards so the next reader starts at the beginning. The data is consumed
    in fixed-size chunks so a large upload is never duplicated in memory.

    Args:
        file: A seekable binary file-like object (``read`` and ``seek``).

    Returns:
        The hexadecimal SHA-256 digest as a string.
    """
    chunk_size = 64 * 1024
    hasher = hashlib.sha256()

    file.seek(0)
    while True:
        chunk = file.read(chunk_size)
        if not chunk:
            break
        hasher.update(chunk)
    file.seek(0)

    return hasher.hexdigest()
|
|
|
|
def extract_text_from_pdf(pdf_file):
    """Extract the plain text of every page of a PDF using PyMuPDF (fitz).

    Args:
        pdf_file: A seekable binary file-like object containing the PDF. It is
            read completely and then rewound to position 0.

    Returns:
        The page texts joined with newlines and stripped of surrounding
        whitespace, or the placeholder string
        ``"No extractable text found in PDF."`` when nothing could be extracted.
    """
    pdf_bytes = pdf_file.read()
    # Rewind right away so the stream is reusable even if parsing fails below.
    pdf_file.seek(0)

    # The context manager closes the document; the original never released it.
    with fitz.open(stream=pdf_bytes, filetype="pdf") as doc:
        text = "\n".join(page.get_text("text") for page in doc)

    text = text.strip()
    return text or "No extractable text found in PDF."
|
|
|
|
def summarize_text(text, max_input_chars=4000):
    """Summarize a tax policy document with the fine-tuned model.

    Args:
        text: The document text to summarize.
        max_input_chars: Upper bound on how many leading characters of ``text``
            are placed in the prompt; ``None`` disables the cap. The default is
            a conservative guess meant to keep the prompt inside a small
            model's context window -- NOTE(review): tune to the real model.

    Returns:
        Only the newly generated summary text, stripped of surrounding
        whitespace (the prompt is not echoed back).
    """
    excerpt = text[:max_input_chars]
    prompt = f"Summarize this tax policy document concisely:\n{excerpt}"

    outputs = tax_llm(
        prompt,
        # ``max_length`` counts the prompt tokens too, so a document-sized
        # prompt left no room to generate; bound the *new* tokens instead.
        max_new_tokens=200,
        do_sample=True,
        # By default the pipeline returns prompt + completion; the caller
        # wants just the summary.
        return_full_text=False,
    )
    return outputs[0]["generated_text"].strip()
|
|
|
|
def create_vector_db():
    """Build a FAISS index over the text stored in the session state.

    Reads ``st.session_state.legal_knowledge_base``, splits it on newlines into
    chunks of up to 1000 characters with a 150-character overlap, and embeds
    the chunks with the ``llama3:8b`` Ollama embedding model.

    Returns:
        The FAISS vector store, or ``None`` when the stored text is empty.
    """
    knowledge_base = st.session_state.legal_knowledge_base
    if knowledge_base:
        splitter = CharacterTextSplitter(
            separator="\n",
            chunk_size=1000,
            chunk_overlap=150,
        )
        chunks = splitter.split_text(knowledge_base)
        return FAISS.from_texts(chunks, OllamaEmbeddings(model="llama3:8b"))
    return None
|
|
|
|
def retrieve_relevant_text(query, vector_db):
    """Return the text of the five stored chunks most similar to ``query``.

    Args:
        query: The user's question.
        vector_db: A vector store exposing ``similarity_search(query, k=...)``
            whose results carry a ``page_content`` attribute, or a falsy value
            when no document has been indexed.

    Returns:
        The matching chunks joined by newlines, or the string
        ``"No document uploaded."`` when ``vector_db`` is falsy.
    """
    if vector_db:
        matches = vector_db.similarity_search(query, k=5)
        return "\n".join(match.page_content for match in matches)
    return "No document uploaded."
|
|
|
|
def compute_tax_details(query):
    """Extract an income and a tax rate from ``query`` and compute the tax.

    Every number in the query is classified once:

    * a number followed by ``%`` is a rate (decimals such as ``7.5%`` work);
    * a number prefixed by a currency marker (the rupee sign, ``Rs``/``Rs.``
      or ``INR``) is an income and takes priority;
    * any other number is a fallback income candidate.

    The first rate and the best income candidate are used. Classifying the
    numbers prevents the rate itself (``"10% of 500000"``) from being mistaken
    for the income, which the previous "first number wins" search did.

    Args:
        query: The user's free-text question.

    Returns:
        A Markdown-formatted sentence stating the income, the rate and the
        computed tax, or ``None`` when the query does not contain both an
        income and a percentage rate.
    """
    import re

    if not query:
        return None

    rupee = "\u20b9"  # Indian rupee sign
    number_token = re.compile(
        r"(?P<currency>(?:\u20b9|\bRs\.?|\bINR)\s*)?"
        r"(?P<number>\d[\d,]*(?:\.\d+)?)"
        r"(?P<percent>\s*%)?",
        re.IGNORECASE,
    )

    tax_rate = None
    marked_income = None  # amount explicitly prefixed by a currency marker
    plain_income = None   # first bare number that is not a percentage
    for token in number_token.finditer(query):
        # Thousands separators (Indian or Western grouping) carry no value.
        value = float(token.group("number").replace(",", ""))
        if token.group("percent"):
            if tax_rate is None:
                tax_rate = value
        elif token.group("currency"):
            if marked_income is None:
                marked_income = value
        elif plain_income is None:
            plain_income = value

    income = marked_income if marked_income is not None else plain_income
    if income is None or tax_rate is None:
        return None

    computed_tax = round(income * tax_rate / 100, 2)
    return (
        f"Based on an income of {rupee}{income:,.2f} and a tax rate of "
        f"{tax_rate}%, the tax is **{rupee}{computed_tax:,.2f}.**"
    )
|
|
|
|
def answer_user_query(query):
    """Answer a tax question and store the result in ``st.session_state.answer``.

    A query containing both an income and a percentage rate is answered
    arithmetically by ``compute_tax_details``. Otherwise the most relevant
    document chunks are retrieved from the session's vector store and passed,
    together with the question, to the fine-tuned model.

    Args:
        query: The user's question.

    Returns:
        None. The answer is written to ``st.session_state.answer``; when no
        document has been indexed an error is shown and the answer is cleared.
    """
    tax_computation_result = compute_tax_details(query)
    if tax_computation_result:
        st.session_state.answer = tax_computation_result
        return

    if not st.session_state.vector_db:
        # Drop any earlier answer so it is not shown next to the error.
        st.session_state.answer = ""
        st.error("Please upload a document first.")
        return

    retrieved_text = retrieve_relevant_text(query, st.session_state.vector_db)
    prompt = (
        "You are an AI tax expert. "
        "Use legal knowledge and tax calculations to answer.\n\n"
        f"Context:\n{retrieved_text}\n\n"
        f"User Query:\n{query}\n\n"
        "Response:\n"
    )

    outputs = tax_llm(
        prompt,
        # ``max_length`` counts the prompt tokens, and the retrieved context
        # alone can exceed 300 tokens; bound only the generated tokens.
        max_new_tokens=300,
        do_sample=True,
        # Return the completion only, not the prompt followed by it.
        return_full_text=False,
    )
    st.session_state.answer = outputs[0]["generated_text"].strip()
|
|
|
|
|
|
|
|
def main():
    """Render the page: PDF upload, document summary and question answering.

    Streamlit re-executes this function on every widget interaction. The
    uploaded file is therefore fingerprinted with ``compute_file_hash`` and the
    expensive steps (text extraction, model summary, vector indexing) run only
    when the fingerprint differs from the one stored in the session; otherwise
    the results already held in ``st.session_state`` are displayed again.
    """
    # NOTE(review): the leading "π" characters in several labels below look
    # like mis-decoded emoji; they are kept verbatim -- confirm the encoding.
    st.title("π AI Legal Tax Assistant")

    uploaded_file = st.file_uploader("π Upload Tax Policy PDF", type=["pdf"])

    if uploaded_file:
        file_hash = compute_file_hash(uploaded_file)

        if st.session_state.get("uploaded_file_hash") != file_hash:
            with st.spinner("Extracting text..."):
                extracted_text = extract_text_from_pdf(uploaded_file)
                st.session_state.legal_knowledge_base = extracted_text

            with st.spinner("Generating summary..."):
                st.session_state.summary = summarize_text(extracted_text)

            with st.spinner("Indexing document..."):
                st.session_state.vector_db = create_vector_db()

            # An answer produced for a previous document no longer applies.
            st.session_state.answer = ""
            # Record the fingerprint last, so a failed run is retried.
            st.session_state.uploaded_file_hash = file_hash

        st.success("Document Uploaded!")
        st.subheader("π Document Summary:")
        # A non-empty label avoids Streamlit's empty-label warning; it is
        # hidden because the subheader above already names the field.
        st.text_area(
            "Document summary",
            st.session_state.summary,
            height=250,
            label_visibility="collapsed",
        )
        st.success("Document indexed! Ask questions now.")

    st.subheader("π¬ Ask Questions:")
    user_query = st.text_input("Enter your question:")

    if st.button("Ask") and user_query.strip():
        with st.spinner("Processing..."):
            answer_user_query(user_query)

    if st.session_state.answer:
        st.markdown("### π€ AI Response:")
        st.success(st.session_state.answer)
|
|
|
|
# Script entry point: build the page only when this file is executed directly.
if __name__ == "__main__":
    main()
|
|
|