# RAG / app.py
import streamlit as st
import torch
import os
import time
from threading import Thread
from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
from langchain_community.document_loaders import PyPDFLoader, TextLoader
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.vectorstores import FAISS
from langchain.schema import Document
# --- Hugging Face Token ---
HF_TOKEN = st.secrets["HF_TOKEN"]
# --- Page Config ---
st.set_page_config(page_title="DigiTwin RAG", page_icon="📂", layout="centered")
st.title("📂 DigiTs the Twin")
# --- Sidebar ---
with st.sidebar:
    st.header("📄 Upload Knowledge Files")
    uploaded_files = st.file_uploader("Upload PDFs or .txt files", accept_multiple_files=True, type=["pdf", "txt"])
    model_choice = st.selectbox("🧠 Choose Model", ["Qwen", "Mistral", "Llama3"])
    if uploaded_files:
        st.success(f"{len(uploaded_files)} file(s) uploaded")
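# model_choice is only a short label; load_model() below maps it to a concrete Hugging Face checkpoint.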
# --- Load Model & Tokenizer ---
@st.cache_resource
def load_model(selected_model):
    if selected_model == "Qwen":
        model_id = "amiguel/GM_Qwen1.8B_Finetune"
    elif selected_model == "Llama3":
        model_id = "amiguel/Llama3_8B_Instruct_FP16"
    else:
        model_id = "amiguel/GM_Mistral7B_Finetune"

    tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True, token=HF_TOKEN)
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        device_map="auto",
        # Use bfloat16 only when a CUDA device that supports it is present; fall back to float32 otherwise.
        torch_dtype=torch.bfloat16 if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else torch.float32,
        trust_remote_code=True,
        token=HF_TOKEN
    )
    return model, tokenizer, model_id
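# st.cache_resource keeps one loaded model/tokenizer pair in memory per selected checkpoint,
# so switching models in the sidebar does not trigger a reload on every Streamlit rerun.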
model, tokenizer, model_id = load_model(model_choice)
# --- System Prompt ---
SYSTEM_PROMPT = (
"You are DigiTwin, a digital expert and senior topside engineer specializing in inspection and maintenance "
"of offshore piping systems, structural elements, mechanical equipment, floating production units, pressure vessels "
"(with emphasis on Visual Internal Inspection - VII), and pressure safety devices (PSDs). Rely on uploaded documents "
"and context to provide practical, standards-driven, and technically accurate responses. Your guidance reflects deep "
"field experience, industry regulations, and proven methodologies in asset integrity and reliability engineering."
)
# --- Prompt Builder ---
def build_prompt(messages, context="", model_name="Qwen"):
    if "Mistral" in model_name:
        prompt = "You are DigiTwin, an expert in offshore inspection, maintenance, and asset integrity.\n"
        if context:
            prompt += f"Here is relevant context:\n{context}\n\n"
        for msg in messages:
            if msg["role"] == "user":
                prompt += f"### Instruction:\n{msg['content'].strip()}\n"
            elif msg["role"] == "assistant":
                prompt += f"### Response:\n{msg['content'].strip()}\n"
        prompt += "### Response:\n"
    elif "Llama" in model_name:
        prompt = "<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n"
        prompt += f"{SYSTEM_PROMPT}\n\nContext:\n{context}\n"
        for msg in messages:
            if msg["role"] == "user":
                prompt += "<|start_header_id|>user<|end_header_id|>\n" + msg["content"].strip() + "\n"
            elif msg["role"] == "assistant":
                prompt += "<|start_header_id|>assistant<|end_header_id|>\n" + msg["content"].strip() + "\n"
        prompt += "<|start_header_id|>assistant<|end_header_id|>\n"
    else:  # Qwen
        prompt = f"<|im_start|>system\n{SYSTEM_PROMPT}\n\nContext:\n{context}<|im_end|>\n"
        for msg in messages:
            role = msg["role"]
            prompt += f"<|im_start|>{role}\n{msg['content']}<|im_end|>\n"
        prompt += "<|im_start|>assistant\n"
    return prompt
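# Each branch above emits the chat template its model family expects: an Alpaca-style
# "### Instruction / ### Response" layout for the Mistral finetune, Llama 3 header tags
# for the Llama model, and ChatML (<|im_start|>/<|im_end|>) for Qwen.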
# --- Embed Uploaded Documents ---
@st.cache_resource
def embed_uploaded_files(files):
    raw_docs = []
    for f in files:
        path = f"/tmp/{f.name}"
        with open(path, "wb") as out_file:
            out_file.write(f.read())
        loader = PyPDFLoader(path) if f.name.endswith(".pdf") else TextLoader(path)
        raw_docs.extend(loader.load())
    splitter = RecursiveCharacterTextSplitter(chunk_size=512, chunk_overlap=64)
    chunks = splitter.split_documents(raw_docs)
    embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
    db = FAISS.from_documents(chunks, embedding=embeddings)
    return db
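# Despite the name, `retriever` holds the raw FAISS store; it is queried directly with
# similarity_search() in the chat handler below, so no .as_retriever() wrapper is used.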
retriever = embed_uploaded_files(uploaded_files) if uploaded_files else None
# --- Streaming Generator ---
def generate_response(prompt_text):
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    inputs = tokenizer(prompt_text, return_tensors="pt").to(model.device)
    thread = Thread(target=model.generate, kwargs={
        "input_ids": inputs["input_ids"],
        "attention_mask": inputs["attention_mask"],
        "max_new_tokens": 1024,
        "temperature": 0.7,
        "top_p": 0.9,
        "repetition_penalty": 1.1,
        "do_sample": True,
        "streamer": streamer
    })
    thread.start()
    return streamer
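# model.generate() runs in a background thread; the TextIteratorStreamer yields decoded
# text fragments as they are produced, which lets the UI render the reply incrementally.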
# --- Avatars ---
USER_AVATAR = "https://raw.githubusercontent.com/achilela/vila_fofoka_analysis/9904d9a0d445ab0488cf7395cb863cce7621d897/USER_AVATAR.png"
BOT_AVATAR = "https://raw.githubusercontent.com/achilela/vila_fofoka_analysis/991f4c6e4e1dc7a8e24876ca5aae5228bcdb4dba/Ataliba_Avatar.jpg"
# --- Initialize Chat Memory ---
if "messages" not in st.session_state:
    st.session_state.messages = []
# --- Display Chat History ---
for msg in st.session_state.messages:
    with st.chat_message(msg["role"], avatar=USER_AVATAR if msg["role"] == "user" else BOT_AVATAR):
        st.markdown(msg["content"])
# --- Chat Interface ---
if prompt := st.chat_input("Ask something based on uploaded documents..."):
    st.chat_message("user", avatar=USER_AVATAR).markdown(prompt)
    st.session_state.messages.append({"role": "user", "content": prompt})

    context = ""
    docs = []
    if retriever:
        docs = retriever.similarity_search(prompt, k=3)
        context = "\n\n".join([doc.page_content for doc in docs])
    recent_messages = st.session_state.messages[-6:]
    full_prompt = build_prompt(recent_messages, context, model_name=model_id)
with st.chat_message("assistant", avatar=BOT_AVATAR):
start = time.time()
container = st.empty()
answer = ""
for chunk in generate_response(full_prompt):
answer += chunk
cleaned = answer
if "Mistral" in model_id or "Llama" in model_id:
cleaned = cleaned.replace("<|im_start|>", "").replace("<|im_end|>", "").strip()
cleaned = cleaned.replace("<|start_header_id|>", "").replace("<|end_header_id|>", "")
cleaned = cleaned.replace("<|begin_of_text|>", "").strip()
container.markdown(cleaned + "β–Œ", unsafe_allow_html=True)
end = time.time()
st.session_state.messages.append({"role": "assistant", "content": cleaned})
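        # Rough throughput stats: both counts come from re-tokenizing the prompt and the cleaned reply.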
        input_tokens = len(tokenizer(full_prompt)["input_ids"])
        output_tokens = len(tokenizer(cleaned)["input_ids"])
        speed = output_tokens / (end - start)

        with st.expander("📊 Debug Info"):
            st.caption(
                f"🔑 Input Tokens: {input_tokens} | Output Tokens: {output_tokens} | "
                f"🕒 Speed: {speed:.1f} tokens/sec"
            )
            for i, doc in enumerate(docs):
                st.markdown(f"**Chunk #{i+1}**")
                st.code(doc.page_content.strip()[:500])