taxagent / taxagent.py
fragger246's picture
Upload 6 files
2cb9c2c verified
import streamlit as st
import fitz # PyMuPDF for PDF extraction
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
import hashlib
from langchain.text_splitter import CharacterTextSplitter
from langchain.vectorstores import FAISS
from langchain.embeddings import OllamaEmbeddings
# ========================== LOAD FINE-TUNED MODEL ========================== #
MODEL_PATH = "./fine_tuned_tinyllama_tax"  # Change to your actual model path


@st.cache_resource
def _load_tax_model(model_path):
    """Load the tokenizer, model and text-generation pipeline once per process.

    Streamlit re-executes the whole script on every widget interaction; caching
    the loaded objects as a shared resource avoids re-reading the model weights
    from disk on each rerun.

    Args:
        model_path: Directory (or hub id) passed to ``from_pretrained``.

    Returns:
        A ``(tokenizer, model, pipeline)`` tuple.
    """
    loaded_tokenizer = AutoTokenizer.from_pretrained(model_path)
    loaded_model = AutoModelForCausalLM.from_pretrained(
        model_path,
        torch_dtype=torch.float16,
        device_map="auto",
    )
    generator = pipeline(
        "text-generation", model=loaded_model, tokenizer=loaded_tokenizer
    )
    return loaded_tokenizer, loaded_model, generator


# Keep the original module-level names so the rest of the file is unaffected.
tokenizer, model, tax_llm = _load_tax_model(MODEL_PATH)
# ========================== SESSION STATE INITIALIZATION ========================== #
# Seed every per-session key exactly once; values written by an earlier rerun
# of the script are left untouched.
for _state_key, _initial_value in (
    ("legal_knowledge_base", ""),
    ("vector_db", None),
    ("summary", ""),
    ("answer", ""),
):
    if _state_key not in st.session_state:
        st.session_state[_state_key] = _initial_value
# ========================== HELPER FUNCTIONS ========================== #
def compute_file_hash(file):
    """Return the SHA-256 hex digest of an uploaded file's full contents.

    The stream is rewound before hashing so the digest always covers the whole
    file, even if an earlier reader left the pointer mid-stream. It is rewound
    again afterwards so the next reader starts from the beginning.

    Args:
        file: A seekable binary file-like object (needs ``read`` and ``seek``).

    Returns:
        The hexadecimal SHA-256 digest as a string.
    """
    hasher = hashlib.sha256()
    file.seek(0)
    # Hash in fixed-size chunks so a large upload is not copied into memory
    # a second time just to compute its digest.
    for chunk in iter(lambda: file.read(65536), b""):
        hasher.update(chunk)
    file.seek(0)  # Reset file pointer for subsequent readers
    return hasher.hexdigest()
def extract_text_from_pdf(pdf_file):
    """Extract the plain text of every page of a PDF using PyMuPDF (fitz).

    Args:
        pdf_file: A seekable binary file-like object containing the PDF.

    Returns:
        The page texts joined by newlines and stripped of surrounding
        whitespace, or the placeholder ``"No extractable text found in PDF."``
        when the document yields no text.
    """
    pdf_bytes = pdf_file.read()
    pdf_file.seek(0)  # Reset pointer so later readers see the whole file
    # Open via context manager so the document is closed even if text
    # extraction raises part-way through.
    with fitz.open(stream=pdf_bytes, filetype="pdf") as doc:
        text = "\n".join(page.get_text("text") for page in doc).strip()
    return text or "No extractable text found in PDF."
def summarize_text(text, max_input_chars=4000):
    """Summarize a tax policy document with the fine-tuned model.

    Args:
        text: Document text to summarize.
        max_input_chars: Upper bound on how many characters of ``text`` are
            placed in the prompt. This is a rough guard against overflowing
            the model's context window; the exact token limit is not visible
            here, so tune it for the deployed model.

    Returns:
        The generated summary only (the prompt is not echoed back), stripped
        of surrounding whitespace.
    """
    prompt = (
        "Summarize this tax policy document concisely:\n"
        f"{text[:max_input_chars]}"
    )
    # max_new_tokens bounds only the generated continuation; max_length would
    # also count the prompt tokens, leaving no room for the summary itself.
    outputs = tax_llm(
        prompt,
        max_new_tokens=200,
        do_sample=True,
        return_full_text=False,
    )
    return outputs[0]["generated_text"].strip()
def create_vector_db():
    """Build a FAISS index over the text stored in the session.

    Reads ``st.session_state.legal_knowledge_base``, splits it on newlines
    (``chunk_size=1000``, ``chunk_overlap=150``) and embeds the pieces with
    the ``llama3:8b`` Ollama embedding model.

    Returns:
        The FAISS vector store, or ``None`` when no text has been stored.
    """
    knowledge_base = st.session_state.legal_knowledge_base
    if knowledge_base:
        splitter = CharacterTextSplitter(
            separator="\n", chunk_size=1000, chunk_overlap=150
        )
        chunks = splitter.split_text(knowledge_base)
        return FAISS.from_texts(chunks, OllamaEmbeddings(model="llama3:8b"))
    return None
def retrieve_relevant_text(query, vector_db):
    """Return the five passages most similar to ``query``, newline-joined.

    Args:
        query: The user's question.
        vector_db: A vector store exposing ``similarity_search(query, k=...)``
            whose results carry a ``page_content`` attribute, or a falsy
            value when nothing has been indexed.

    Returns:
        The matching passages joined by newlines, or the message
        ``"No document uploaded."`` when ``vector_db`` is falsy.
    """
    if vector_db:
        matches = vector_db.similarity_search(query, k=5)
        return "\n".join(match.page_content for match in matches)
    return "No document uploaded."
def compute_tax_details(query):
    """Extract an income and a tax rate from ``query`` and compute the tax.

    Numbers are scanned left to right after thousands separators are removed.
    The first number followed by ``%`` is taken as the tax rate and the first
    number *not* followed by ``%`` as the income, so the two may appear in
    either order ("10% on 500000" or "500000 at 10%"). Decimal values such as
    ``7.5%`` are accepted. This is a heuristic: any other bare number that
    precedes the income in the query will be mistaken for it.

    Args:
        query: Free-text user question.

    Returns:
        A Markdown-formatted sentence stating the income, rate and computed
        tax, or ``None`` when the query lacks either an income or a rate.
    """
    import re

    income = None
    tax_rate = None
    # Each match is (number, "%" or ""); the trailing group tells rates and
    # plain amounts apart without one pattern swallowing the other's digits.
    numbers = re.findall(r"(\d+(?:\.\d+)?)\s*(%?)", query.replace(",", ""))
    for number, percent_sign in numbers:
        if percent_sign:
            if tax_rate is None:
                tax_rate = float(number)
        elif income is None:
            income = float(number)

    if income is None or tax_rate is None:
        return None

    computed_tax = round(income * (tax_rate / 100), 2)
    return (
        f"Based on an income of ₹{income:,.2f} and a tax rate of {tax_rate}%, "
        f"the tax is **₹{computed_tax:,.2f}.**"
    )
def answer_user_query(query):
    """Answer a tax question and store the result in the session state.

    A query containing both an income and a tax rate is answered by direct
    arithmetic. Otherwise the most relevant passages of the indexed document
    are retrieved and passed to the fine-tuned model as context.

    Args:
        query: The user's question.

    Returns:
        None. The answer is written to ``st.session_state.answer``; when no
        document has been indexed an error is shown and the stored answer is
        left unchanged.
    """
    tax_computation_result = compute_tax_details(query)
    if tax_computation_result:
        st.session_state.answer = tax_computation_result
        return

    if not st.session_state.vector_db:
        st.error("Please upload a document first.")
        return

    retrieved_text = retrieve_relevant_text(query, st.session_state.vector_db)
    prompt = (
        "\n"
        "You are an AI tax expert. Use legal knowledge and tax calculations to answer.\n"
        "Context:\n"
        f"{retrieved_text}\n"
        "User Query:\n"
        f"{query}\n"
        "Response:\n"
    )
    # The prompt carries several retrieved chunks, so cap only the newly
    # generated tokens (max_length would also count the prompt), and return
    # just the continuation so the prompt is not shown back to the user.
    outputs = tax_llm(
        prompt,
        max_new_tokens=300,
        do_sample=True,
        return_full_text=False,
    )
    st.session_state.answer = outputs[0]["generated_text"].strip()
# ========================== STREAMLIT UI ========================== #
def main():
    """Render the Streamlit UI: upload a PDF, summarize/index it, answer questions.

    Streamlit re-executes this function on every widget interaction. The
    uploaded file's hash is remembered in the session state so extraction,
    summarization and indexing run only when a new or changed document is
    uploaded, not on every rerun (for example each click of "Ask").
    """
    st.title("📜 AI Legal Tax Assistant")
    uploaded_file = st.file_uploader("📄 Upload Tax Policy PDF", type=["pdf"])

    if uploaded_file:
        file_hash = compute_file_hash(uploaded_file)
        if st.session_state.get("processed_file_hash") != file_hash:
            with st.spinner("Extracting text..."):
                extracted_text = extract_text_from_pdf(uploaded_file)
                st.session_state.legal_knowledge_base = extracted_text
            with st.spinner("Generating summary..."):
                st.session_state.summary = summarize_text(extracted_text)
            with st.spinner("Indexing document..."):
                st.session_state.vector_db = create_vector_db()
            # Record the hash only after every step succeeded, so a failure
            # part-way through is retried on the next rerun.
            st.session_state.processed_file_hash = file_hash

        st.success("Document Uploaded!")
        st.subheader("📄 Document Summary:")
        st.text_area("", st.session_state.summary, height=250)
        st.success("Document indexed! Ask questions now.")

    st.subheader("💬 Ask Questions:")
    user_query = st.text_input("Enter your question:")
    if st.button("Ask") and user_query.strip():
        with st.spinner("Processing..."):
            answer_user_query(user_query)

    if st.session_state.answer:
        st.markdown("### 🤖 AI Response:")
        st.success(st.session_state.answer)


if __name__ == "__main__":
    main()