ABE101's picture
Upload 11 files
1835c79 verified
raw
history blame contribute delete
11.3 kB
import asyncio
import html
import re  # for extracting citation IDs
import time
import traceback
from typing import List, Dict, Any

import nest_asyncio
import streamlit as st
# --- Streamlit Page Configuration ---
# st.set_page_config must be the first Streamlit command of the script on
# older Streamlit versions (otherwise StreamlitAPIException is raised). The
# error paths below call st.error, so the page config is set before them.
st.set_page_config(page_title="Divrey Yoel AI Chat (GPT-4o Gen)", layout="wide")

# --- Configuration and Service Initialization ---
try:
    print("App: Loading config...")
    import config
    print("App: Loading utils...")
    from utils import clean_source_text
    print("App: Loading services...")
    from services.retriever import init_retriever, get_retriever_status
    from services.openai_service import init_openai_client, get_openai_status
    print("App: Loading RAG processor...")
    from rag_processor import execute_validate_generate_pipeline, PIPELINE_VALIDATE_GENERATE_GPT4O
    print("App: Imports successful.")
except ImportError as e:
    # A missing module makes the whole app unusable: report and halt the script.
    st.error(f"Fatal Error: Module import failed. {e}", icon="🚨")
    traceback.print_exc()
    st.stop()
except Exception as e:
    st.error(f"Fatal Error during initial setup: {e}", icon="🚨")
    traceback.print_exc()
    st.stop()

# Patch asyncio (via nest_asyncio) so the chat handler can drive the async
# RAG pipeline with loop.run_until_complete().
nest_asyncio.apply()

# --- Initialize Required Services ---
print("App: Initializing services...")
try:
    # NOTE(review): these flags/messages are not read elsewhere in the visible
    # file; the sidebar queries live status through get_*_status() instead.
    retriever_ready_init, retriever_msg_init = init_retriever()
    openai_ready_init, openai_msg_init = init_openai_client()
    print("App: Service initialization calls complete.")
except Exception as init_err:
    # Deliberately non-fatal: the sidebar re-checks service status and stops
    # the script when the retriever is unavailable.
    st.error(f"Error during service initialization: {init_err}", icon="🔥")
    traceback.print_exc()

# --- Styling ---
# Minimal definitions for the CSS classes used by the markup in this file
# (previously only an empty placeholder, so the classes had no effect).
st.markdown(
    """<style>
.rtl-text { direction: rtl; text-align: right; }
.hebrew-text { direction: rtl; text-align: right; line-height: 1.7; }
.source-info { direction: rtl; text-align: right; margin-top: 0.75em; }
.expander-content { direction: rtl; text-align: right; }
</style>""",
    unsafe_allow_html=True,
)
st.markdown("<h1 class='rtl-text'> דברות קודש - חיפוש ועיון</h1>", unsafe_allow_html=True)
st.markdown("<p class='rtl-text'>מבוסס על ספרי דברי יואל מסאטמאר זצוק'ל זי'ע - אחזור מידע חכם (RAG)</p>", unsafe_allow_html=True)
st.markdown("<p class='rtl-text' style='font-size: 0.9em; color: #555;'>תהליך: אחזור -> אימות (GPT-4o) -> יצירה (GPT-4o)</p>", unsafe_allow_html=True)
# --- UI Helper Functions ---
def display_sidebar() -> Dict[str, Any]:
    """Draw the sidebar (service health + search settings) and collect RAG parameters.

    Halts the script with ``st.stop()`` when the retriever reports not ready.

    Returns:
        ``{"n_retrieve": int, "n_validate": int, "services_ready": ...}`` where
        ``n_retrieve`` is the number of paragraphs to retrieve, ``n_validate``
        the number to validate (capped at 100 and at ``n_retrieve``), and
        ``services_ready`` is ``retriever_ready and openai_ready``.
    """
    sidebar = st.sidebar

    def status_icon(ok) -> str:
        return '✅' if ok else '❌'

    sidebar.markdown("<h3 class='rtl-text'>מצב המערכת</h3>", unsafe_allow_html=True)
    retriever_ok = get_retriever_status()[0]
    openai_ok = get_openai_status()[0]

    sidebar.markdown(
        f"<p class='rtl-text'><strong>מאחזר (Pinecone):</strong> {status_icon(retriever_ok)}</p>",
        unsafe_allow_html=True,
    )
    if not retriever_ok:
        # Retrieval is required for every answer, so nothing below can work.
        sidebar.error("מאחזר אינו זמין.", icon="🛑")
        st.stop()

    sidebar.markdown("<hr>", unsafe_allow_html=True)
    model_names = f"{config.OPENAI_VALIDATION_MODEL} / {config.OPENAI_GENERATION_MODEL}"
    sidebar.markdown(
        f"<p class='rtl-text'><strong>OpenAI ({model_names}):</strong> {status_icon(openai_ok)}</p>",
        unsafe_allow_html=True,
    )
    if not openai_ok:
        sidebar.error("OpenAI אינו זמין.", icon="⚠️")

    sidebar.markdown("<hr>", unsafe_allow_html=True)
    sidebar.markdown("<h3 class='rtl-text'>הגדרות חיפוש</h3>", unsafe_allow_html=True)
    n_retrieve = sidebar.slider(
        "מספר פסקאות לאחזור",
        min_value=1,
        max_value=300,
        value=config.DEFAULT_N_RETRIEVE,
    )
    # Never validate more paragraphs than were retrieved, and at most 100.
    validate_cap = min(n_retrieve, 100)
    n_validate = sidebar.slider(
        "פסקאות לאימות (GPT-4o)",
        min_value=1,
        max_value=validate_cap,
        value=min(config.DEFAULT_N_VALIDATE, validate_cap),
        disabled=not openai_ok,
    )
    sidebar.info("התשובות מבוססות רק על המקורות שאומתו.", icon="ℹ️")

    params = {
        "n_retrieve": n_retrieve,
        "n_validate": n_validate,
        "services_ready": (retriever_ok and openai_ok),
    }
    return params
def display_chat_message(message: Dict[str, Any]):
    """Render one stored chat-history entry inside a ``st.chat_message`` bubble.

    Args:
        message: A history entry. ``role`` defaults to ``"assistant"`` and
            ``content`` (default ``""``) is the text to show. Assistant entries
            may carry ``final_docs``, a list of source dicts with optional
            ``source_name`` / ``hebrew_text`` keys, listed in a collapsed expander.

    Assistant content is built by this app as HTML, so it is rendered with
    ``unsafe_allow_html``. User content is raw typed text and is rendered as
    plain Markdown, so characters such as ``<`` are not interpreted as HTML.
    """
    role = message.get("role", "assistant")
    is_assistant = role == "assistant"
    with st.chat_message(role):
        st.markdown(message.get('content', ''), unsafe_allow_html=is_assistant)

        docs = message.get("final_docs") if is_assistant else None
        if not docs:
            return
        # Keep each document's original 1-based position so the displayed
        # source numbers match the answer's citation numbers; skip non-dicts.
        shown = [(i, doc) for i, doc in enumerate(docs, start=1) if isinstance(doc, dict)]
        if not shown:
            return

        # Expander labels support Markdown only (raw HTML is not rendered),
        # so the label is plain text.
        exp_title = f"הצג {len(shown)} קטעי מקור שנשלחו למחולל (GPT-4o)"
        with st.expander(exp_title, expanded=False):
            for i, doc in shown:
                source = doc.get('source_name', '') or 'מקור לא ידוע'
                text = clean_source_text(doc.get('hebrew_text', ''))
                # Every st.markdown call is an independent element, so the RTL
                # direction is set on each fragment rather than via a <div>
                # opened and closed in separate calls.
                st.markdown(
                    f"<div dir='rtl' class='source-info rtl-text'><strong>מקור {i}:</strong> {source}</div>",
                    unsafe_allow_html=True
                )
                st.markdown(f"<div dir='rtl' class='hebrew-text'>{text}</div>", unsafe_allow_html=True)
def display_status_updates(status_log: List[str]):
    """Show the pipeline's processing log in a collapsed expander.

    Args:
        status_log: Status lines produced by the RAG pipeline. Nothing is
            rendered when the list is empty (or ``None``).
    """
    if not status_log:
        return
    # Expander labels support Markdown only (raw HTML is not rendered),
    # so the label is plain text.
    with st.expander("הצג פרטי עיבוד", expanded=False):
        for entry in status_log:
            # Escape so diagnostic text containing '<' / '&' (e.g. exception
            # reprs) is shown verbatim instead of being parsed as HTML.
            st.markdown(
                f"<code class='status-update rtl-text'>- {html.escape(f'{entry}')}</code>",
                unsafe_allow_html=True
            )
# --- Main Application Logic ---

# Inline citations emitted by the generator look like "(מקור 3)".
_CITATION_RE = re.compile(r'\(מקור\s*([0-9]+)\)')
# An answer starting with one of these is treated as an HTML fragment already.
_HTML_PREFIXES = ('<div', '<p', '<ul', '<ol', '<strong')


def _render_cited_sources(raw_answer: str, docs: List[Any]) -> None:
    """Show, in a collapsed expander, the source passages cited in *raw_answer*.

    Citation numbers refer to 1-based positions in *docs*. When the answer
    cites nothing, every document is shown. Non-dict entries are skipped but
    still occupy their position, so numbering stays aligned with citations.
    """
    cited = {int(num) for num in _CITATION_RE.findall(raw_answer)}
    docs_to_show = [
        (idx, doc)
        for idx, doc in enumerate(docs, start=1)
        if isinstance(doc, dict) and (not cited or idx in cited)
    ]
    if not docs_to_show:
        return
    # Expander labels support Markdown only (raw HTML is not rendered).
    label = f"הצג {len(docs_to_show)} קטעי מקור שהוזכרו בתשובה"
    with st.expander(label, expanded=False):
        for idx, doc in docs_to_show:
            source = doc.get('source_name', '') or 'מקור לא ידוע'
            text = clean_source_text(doc.get('hebrew_text', ''))
            # dir='rtl' is set per fragment: each st.markdown call is a
            # separate element, so a wrapper <div> split across calls is inert.
            st.markdown(
                f"<div dir='rtl' class='source-info rtl-text'><strong>מקור {idx}:</strong> {source}</div>",
                unsafe_allow_html=True
            )
            st.markdown(f"<div dir='rtl' class='hebrew-text'>{text}</div>", unsafe_allow_html=True)


if "messages" not in st.session_state:
    st.session_state.messages = []

rag_params = display_sidebar()

# Render history (Streamlit reruns the whole script on every interaction).
for msg in st.session_state.messages:
    display_chat_message(msg)

if prompt := st.chat_input("שאל שאלה בענייני חסידות...", disabled=not rag_params["services_ready"]):
    st.session_state.messages.append({"role": "user", "content": prompt})
    display_chat_message(st.session_state.messages[-1])

    with st.chat_message("assistant"):
        msg_placeholder = st.empty()
        status_container = st.status("מעבד בקשה...", expanded=True)
        chunks: List[str] = []
        try:
            def status_cb(m):
                # st.status labels render Markdown only, so pass plain text.
                status_container.update(label=f"{m}")

            def stream_cb(c):
                # Re-render the accumulated partial answer with a typing cursor.
                chunks.append(c)
                msg_placeholder.markdown(
                    f"<div dir='rtl' class='rtl-text'>{''.join(chunks)}▌</div>",
                    unsafe_allow_html=True
                )

            # nest_asyncio.apply() at startup lets run_until_complete be used
            # here; fall back to a fresh loop if none is available.
            try:
                loop = asyncio.get_event_loop()
            except RuntimeError:
                loop = asyncio.new_event_loop()
                asyncio.set_event_loop(loop)
            final_rag = loop.run_until_complete(
                execute_validate_generate_pipeline(
                    history=st.session_state.messages,
                    params=rag_params,
                    status_callback=status_cb,
                    stream_callback=stream_cb
                )
            )

            if isinstance(final_rag, dict):
                # Coerce missing/None fields so .strip() and enumerate() below cannot fail.
                raw = final_rag.get("final_response") or ""
                err = final_rag.get("error")
                log = final_rag.get("status_log") or []
                docs = final_rag.get("generator_input_documents") or []
                pipeline = final_rag.get("pipeline_used", PIPELINE_VALIDATE_GENERATE_GPT4O)

                # Wrap plain-text/Markdown answers in an RTL container; answers
                # that are already HTML fragments are shown as-is.
                final = raw
                if not raw.strip().startswith(_HTML_PREFIXES):
                    body = raw if raw.strip() else 'לא התקבלה תשובה מהמחולל.'
                    final = f"<div dir='rtl' class='rtl-text'>{body}</div>"
                msg_placeholder.markdown(final, unsafe_allow_html=True)

                # Show only the paragraphs cited in the answer (all if none cited).
                _render_cited_sources(raw, docs)

                st.session_state.messages.append({
                    "role": "assistant",
                    "content": final,
                    "final_docs": docs,
                    "pipeline_used": pipeline,
                    "status_log": log,
                    "error": err
                })
                display_status_updates(log)
                if err:
                    status_container.update(label="שגיאה בעיבוד!", state="error", expanded=False)
                else:
                    status_container.update(label="העיבוד הושלם!", state="complete", expanded=False)
            else:
                msg_placeholder.markdown(
                    "<div dir='rtl' class='rtl-text'><strong>שגיאה בלתי צפויה בתקשורת.</strong></div>",
                    unsafe_allow_html=True
                )
                st.session_state.messages.append({
                    "role": "assistant",
                    "content": "שגיאה בלתי צפויה בתקשורת.",
                    "final_docs": [],
                    "pipeline_used": "Error",
                    "status_log": ["Unexpected result"],
                    "error": "Unexpected"
                })
                status_container.update(label="שגיאה בלתי צפויה!", state="error", expanded=False)
        except Exception as e:
            traceback.print_exc()
            # The traceback contains text such as "<module>", so it is escaped
            # before being embedded in HTML.
            err_html = (
                "<div dir='rtl' class='rtl-text'><strong>שגיאה קריטית!</strong><br>נסה לרענן."
                f"<details><summary>פרטים</summary><pre>{html.escape(traceback.format_exc())}</pre></details></div>"
            )
            # st.error has no HTML support (tags would show literally), so the
            # HTML error card is rendered through markdown instead.
            msg_placeholder.markdown(err_html, unsafe_allow_html=True)
            st.session_state.messages.append({
                "role": "assistant",
                "content": err_html,
                "final_docs": [],
                "pipeline_used": "Critical Error",
                "status_log": [f"Critical: {type(e).__name__}"],
                "error": str(e)
            })
            status_container.update(label=str(e) or type(e).__name__, state="error", expanded=False)