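"""Smart PDF Q&A: a Streamlit app that answers questions about, summarizes,
and extracts code from uploaded PDFs, with optional fine-tuning on SQuAD.

A typical local launch (the filename app.py is an assumption):

    streamlit run app.py
"""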
import streamlit as st
import logging
from io import BytesIO
import pdfplumber
from langchain.text_splitter import CharacterTextSplitter
from langchain_community.vectorstores import FAISS
from sentence_transformers import SentenceTransformer
from transformers import pipeline, AutoTokenizer, AutoModelForSeq2SeqLM, Trainer, TrainingArguments
from datasets import load_dataset
from rank_bm25 import BM25Okapi
from rouge_score import rouge_scorer
import re
import time
import pytesseract
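# Note: pytesseract wraps the Tesseract OCR binary, which must be installed on
# the host separately from the pip package (on Spaces, e.g., via packages.txt).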

# Setup logging for Spaces
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

# Lazy load models
@st.cache_resource(ttl=1800)
def load_embeddings_model():
    logger.info("Loading embeddings model")
    try:
        return SentenceTransformer("all-MiniLM-L6-v2")
    except Exception as e:
        logger.error(f"Embeddings load error: {str(e)}")
        st.error(f"Embedding model error: {str(e)}")
        return None

@st.cache_resource(ttl=1800)
def load_qa_pipeline():
    logger.info("Loading QA pipeline")
    try:
        dataset = load_and_prepare_dataset()
        if dataset:
            fine_tuned_pipeline = fine_tune_qa_model(dataset)
            if fine_tuned_pipeline:
                return fine_tuned_pipeline
        return pipeline("text2text-generation", model="google/flan-t5-small", max_length=300)
    except Exception as e:
        logger.error(f"QA model load error: {str(e)}")
        st.error(f"QA model error: {str(e)}")
        return None

@st.cache_resource(ttl=1800)
def load_summary_pipeline():
    logger.info("Loading summary pipeline")
    try:
        return pipeline("summarization", model="facebook/bart-large-cnn", max_length=250)
    except Exception as e:
        logger.error(f"Summary model load error: {str(e)}")
        st.error(f"Summary model error: {str(e)}")
        return None

# Load and prepare dataset (e.g., SQuAD)
@st.cache_data(ttl=3600)
def load_and_prepare_dataset(dataset_name="squad", max_samples=1000):
    logger.info(f"Loading dataset: {dataset_name}")
    try:
        dataset = load_dataset(dataset_name, split="train[:80%]")
        dataset = dataset.shuffle(seed=42).select(range(min(max_samples, len(dataset))))
        
        def preprocess(examples):
            inputs = [f"question: {q} context: {c}" for q, c in zip(examples['question'], examples['context'])]
            # In batched mode, examples['answers'] is a list of dicts like
            # {'text': [...], 'answer_start': [...]}, so index per example.
            targets = [a['text'][0] if a['text'] else "" for a in examples['answers']]
            return {'input_text': inputs, 'target_text': targets}
        
        dataset = dataset.map(preprocess, batched=True, remove_columns=dataset.column_names)
        return dataset
    except Exception as e:
        logger.error(f"Dataset load error: {str(e)}")
        return None

# Fine-tune QA model
@st.cache_resource(ttl=3600)
def fine_tune_qa_model(dataset):
    logger.info("Starting fine-tuning")
    try:
        model_name = "google/flan-t5-small"
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
        
        def tokenize_function(examples):
            model_inputs = tokenizer(examples['input_text'], max_length=512, truncation=True, padding="max_length")
            labels = tokenizer(examples['target_text'], max_length=128, truncation=True, padding="max_length")
            model_inputs["labels"] = labels["input_ids"]
            return model_inputs
        
        tokenized_dataset = dataset.map(tokenize_function, batched=True, remove_columns=['input_text', 'target_text'])
        
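        # Note: `evaluation_strategy` has been renamed to `eval_strategy` in
        # newer transformers releases; adjust the argument name if the pinned
        # version requires it.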
        training_args = TrainingArguments(
            output_dir="./fine_tuned_model",
            num_train_epochs=2,
            per_device_train_batch_size=4,
            save_steps=500,
            logging_steps=100,
            evaluation_strategy="no",
            learning_rate=3e-5,
            fp16=False,
        )
        
        trainer = Trainer(
            model=model,
            args=training_args,
            train_dataset=tokenized_dataset,
        )
        trainer.train()
        
        model.save_pretrained("./fine_tuned_model")
        tokenizer.save_pretrained("./fine_tuned_model")
        logger.info("Fine-tuning complete")
        return pipeline("text2text-generation", model="./fine_tuned_model", tokenizer="./fine_tuned_model", max_length=300)
    except Exception as e:
        logger.error(f"Fine-tuning error: {str(e)}")
        return None

# Augment vector store with dataset
def augment_vector_store(vector_store, dataset_name="squad", max_samples=300):
    logger.info(f"Augmenting vector store with dataset: {dataset_name}")
    try:
        dataset = load_dataset(dataset_name, split="train")
        dataset = dataset.select(range(min(max_samples, len(dataset))))
        chunks = [f"Context: {c}\nAnswer: {a['text'][0] if a['text'] else ''}" for c, a in zip(dataset['context'], dataset['answers'])]
        embeddings_model = load_embeddings_model()
        if embeddings_model and vector_store:
            embeddings = embeddings_model.encode(chunks, batch_size=128, show_progress_bar=False)
            vector_store.add_embeddings(zip(chunks, embeddings))
        return vector_store
    except Exception as e:
        logger.error(f"Vector store augmentation error: {str(e)}")
        return vector_store

# Process PDF with enhanced extraction and OCR fallback
def process_pdf(uploaded_file):
    logger.info("Processing PDF with enhanced extraction")
    try:
        text = ""
        code_blocks = []
        images = []
        with pdfplumber.open(BytesIO(uploaded_file.getvalue())) as pdf:
            for page in pdf.pages[:8]:
                extracted = page.extract_text(layout=False)
                if not extracted:
                    try:
                        img = page.to_image(resolution=150).original
                        extracted = pytesseract.image_to_string(img, config='--psm 6')
                        images.append(img)
                    except Exception as ocr_e:
                        logger.warning(f"OCR failed: {str(ocr_e)}")
                if extracted:
                    lines = extracted.split("\n")
                    cleaned_lines = [line for line in lines if not re.match(r'^\s*(Page \d+|.*\d{4}-\d{4}|Copyright.*)\s*$', line, re.I)]
                    text += "\n".join(cleaned_lines) + "\n"
                # Collect monospaced-font characters as one run per page instead
                # of one list entry per character, so the later join doesn't put
                # every letter on its own line.
                mono_chars = "".join(ch['text'] for ch in page.chars if 'mono' in ch.get('fontname', '').lower())
                if mono_chars.strip():
                    code_blocks.append(mono_chars)
                page_text = page.extract_text() or ""
                code_matches = re.finditer(r'(^\s{2,}.*?(?:\n\s{2,}.*?)*)', page_text, re.MULTILINE)
                for match in code_matches:
                    code_blocks.append(match.group().strip())
                tables = page.extract_tables()
                if tables:
                    for table in tables:
                        text += "\n".join(" | ".join("" if cell is None else str(cell) for cell in row) for row in table if row) + "\n"
                # extract_words() only reports font size when requested via extra_attrs.
                for obj in page.extract_words(extra_attrs=["size"]):
                    if obj.get('size', 0) > 12:
                        text += f"\n{obj['text']}\n"

        code_text = "\n".join(code_blocks).strip()
        if not text:
            raise ValueError("No text extracted from PDF")
        
        text_splitter = CharacterTextSplitter(separator="\n\n", chunk_size=250, chunk_overlap=40, keep_separator=True)
        text_chunks = text_splitter.split_text(text)[:25]
        code_chunks = text_splitter.split_text(code_text)[:10] if code_text else []
        
        embeddings_model = load_embeddings_model()
        if not embeddings_model:
            return None, None, text, code_text, images
        
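        # FAISS.from_embeddings is given the raw encode callable as its embedding
        # function; this relies on older langchain behavior (newer releases expect
        # an Embeddings object with embed_documents/embed_query). The same pattern
        # appears in summarize_pdf below.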
        text_vector_store = FAISS.from_embeddings(
            zip(text_chunks, [embeddings_model.encode(chunk, show_progress_bar=False, batch_size=128) for chunk in text_chunks]),
            embeddings_model.encode
        ) if text_chunks else None
        code_vector_store = FAISS.from_embeddings(
            zip(code_chunks, [embeddings_model.encode(chunk, show_progress_bar=False, batch_size=128) for chunk in code_chunks]),
            embeddings_model.encode
        ) if code_chunks else None
        
        if text_vector_store:
            text_vector_store = augment_vector_store(text_vector_store)
        
        logger.info("PDF processed successfully")
        return text_vector_store, code_vector_store, text, code_text, images
    except Exception as e:
        logger.error(f"PDF processing error: {str(e)}")
        st.error(f"PDF error: {str(e)}")
        return None, None, "", "", []

# Summarize PDF with ROUGE metrics and improved topic focus
def summarize_pdf(text):
    logger.info("Generating summary")
    try:
        summary_pipeline = load_summary_pipeline()
        if not summary_pipeline:
            return "Summary model unavailable."
        
        text_splitter = CharacterTextSplitter(separator="\n\n", chunk_size=250, chunk_overlap=40)
        chunks = text_splitter.split_text(text)
        
        # Hybrid search for relevant chunks
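        # BM25 (lexical overlap) and FAISS (embedding similarity) hits are
        # unioned so the summary draws on chunks that match by either signal.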
        embeddings_model = load_embeddings_model()
        if embeddings_model and chunks:
            temp_vector_store = FAISS.from_embeddings(
                zip(chunks, [embeddings_model.encode(chunk, show_progress_bar=False) for chunk in chunks]),
                embeddings_model.encode
            )
            bm25 = BM25Okapi([chunk.split() for chunk in chunks])
            query = "main topic and key points"
            bm25_docs = bm25.get_top_n(query.split(), chunks, n=4)
            faiss_docs = temp_vector_store.similarity_search(query, k=4)
            selected_chunks = list(set(bm25_docs + [doc.page_content for doc in faiss_docs]))[:4]
        else:
            selected_chunks = chunks[:4]
        
        summaries = []
        for chunk in selected_chunks:
            summary = summary_pipeline(f"Summarize the main topic and key points in detail: {chunk[:250]}", max_length=100, min_length=50, do_sample=False)[0]['summary_text']
            summaries.append(summary.strip())
        
        combined_summary = " ".join(summaries)
        if len(combined_summary.split()) > 250:
            combined_summary = " ".join(combined_summary.split()[:250])
        
        word_count = len(combined_summary.split())
        scorer = rouge_scorer.RougeScorer(['rouge1', 'rougeL'], use_stemmer=True)
        scores = scorer.score(text[:500], combined_summary)
        logger.info(f"ROUGE scores: {scores}")
        
        return f"**Main Topic Summary** ({word_count} words):\n{combined_summary}\n\n**ROUGE-1**: {scores['rouge1'].fmeasure:.2f}"
    except Exception as e:
        logger.error(f"Summary error: {str(e)}")
        return f"Oops, something went wrong summarizing: {str(e)}"

# Answer question with hybrid search
def answer_question(text_vector_store, code_vector_store, query):
    logger.info(f"Processing query: {query}")
    try:
        if not text_vector_store and not code_vector_store:
            return "Please upload a PDF first!"
        
        qa_pipeline = load_qa_pipeline()
        if not qa_pipeline:
            return "Sorry, the QA model is unavailable right now."
        
        is_code_query = any(keyword in query.lower() for keyword in ["code", "script", "function", "programming", "give me code", "show code"])
        if is_code_query and code_vector_store:
            docs = code_vector_store.similarity_search(query, k=3)
            code = "\n".join(doc.page_content for doc in docs)
            explanation = qa_pipeline(f"Explain this code: {code[:500]}")[0]['generated_text']
            return f"**Code**:\n```python\n{code}\n```\n**Explanation**:\n{explanation}"
        
        vector_store = text_vector_store
        if not vector_store:
            return "No relevant content found for your query."
        
        # Hybrid search: FAISS + BM25
        text_chunks = [doc.page_content for doc in vector_store.similarity_search(query, k=10)]
        bm25 = BM25Okapi([chunk.split() for chunk in text_chunks])
        bm25_docs = bm25.get_top_n(query.split(), text_chunks, n=5)
        faiss_docs = vector_store.similarity_search(query, k=5)
        combined_docs = list(set(bm25_docs + [doc.page_content for doc in faiss_docs]))[:5]
        context = "\n".join(combined_docs)
        
        prompt = f"Use the following PDF content to answer the question accurately and concisely. Avoid speculation and focus on the provided context:\n\n{context}\n\nQuestion: {query}\nAnswer:"
        response = qa_pipeline(prompt)[0]['generated_text']
        logger.info("Answer generated")
        return f"**Answer**:\n{response.strip()}\n\n**Source Context**:\n{context[:500]}..."
    except Exception as e:
        logger.error(f"Query error: {str(e)}")
        return f"Sorry, something went wrong: {str(e)}"

# Streamlit UI
try:
    st.set_page_config(page_title="Smart PDF Q&A", page_icon="📄", layout="wide")
    st.markdown("""
        <style>
        .main { max-width: 900px; margin: 0 auto; padding: 20px; }
        .sidebar { background-color: #f8f9fa; padding: 10px; border-radius: 5px; }
        .message { margin: 10px 0; padding: 10px; border-radius: 5px; display: block; }
        .user { background-color: #e6f3ff; }
        .assistant { background-color: #f0f0f0; }
        .dark .user { background-color: #2a2a72; color: #fff; }
        .dark .assistant { background-color: #2e2e2e; color: #fff; }
        .stButton>button { background-color: #4CAF50; color: white; border: none; padding: 8px 16px; border-radius: 5px; }
        .stButton>button:hover { background-color: #45a049; }
        pre { background-color: #f8f8f8; padding: 10px; border-radius: 5px; overflow-x: auto; }
        .header { background: linear-gradient(90deg, #4CAF50, #81C784); color: white; padding: 10px; border-radius: 5px; text-align: center; }
        .progress-bar { background-color: #e0e0e0; border-radius: 5px; height: 10px; }
        .progress-fill { background-color: #4CAF50; height: 100%; border-radius: 5px; transition: width 0.5s ease; }
        </style>
    """, unsafe_allow_html=True)

    st.markdown('<div class="header"><h1>Smart PDF Q&A</h1></div>', unsafe_allow_html=True)
    st.markdown("Upload a PDF to ask questions, summarize (up to ~250 words), or extract code with 'give me code'. Fast and friendly responses!")

    # Initialize session state
    if "messages" not in st.session_state:
        st.session_state.messages = [{"role": "assistant", "content": "Hello! Upload a PDF and process it to start chatting."}]
    if "text_vector_store" not in st.session_state:
        st.session_state.text_vector_store = None
    if "code_vector_store" not in st.session_state:
        st.session_state.code_vector_store = None
    if "pdf_text" not in st.session_state:
        st.session_state.pdf_text = ""
    if "code_text" not in st.session_state:
        st.session_state.code_text = ""
    if "images" not in st.session_state:
        st.session_state.images = []

    # Sidebar with toggle
    with st.sidebar:
        st.markdown('<div class="sidebar">', unsafe_allow_html=True)
        theme = st.radio("Theme", ["Light", "Dark"], index=0)
        dataset_name = st.selectbox("Select Dataset for Fine-Tuning", ["squad", "cnn_dailymail", "bigcode/the-stack"], index=0)  # preprocessing currently assumes SQuAD-style columns
        if st.button("Fine-Tune Model"):
            progress_bar = st.progress(0)
            for i in range(100):
                time.sleep(0.008)
                progress_bar.progress(i + 1)
            dataset = load_and_prepare_dataset(dataset_name=dataset_name)
            if dataset:
                fine_tuned_pipeline = fine_tune_qa_model(dataset)
                if fine_tuned_pipeline:
                    st.success("Model fine-tuned successfully!")
                else:
                    st.error("Fine-tuning failed.")
        if st.button("Clear Chat"):
            st.session_state.messages = []
            st.rerun()
        if st.button("Retry Summarization") and st.session_state.pdf_text:
            progress_bar = st.progress(0)
            with st.spinner("Retrying summarization..."):
                for i in range(100):
                    time.sleep(0.008)
                    progress_bar.progress(i + 1)
                summary = summarize_pdf(st.session_state.pdf_text)
                st.session_state.messages.append({"role": "assistant", "content": summary})
                st.markdown(summary, unsafe_allow_html=True)
        st.markdown('</div>', unsafe_allow_html=True)

    # PDF upload and processing
    uploaded_file = st.file_uploader("Upload a PDF", type=["pdf"])
    col1, col2 = st.columns([1, 1])
    with col1:
        if st.button("Process PDF") and uploaded_file:
            progress_bar = st.progress(0)
            with st.spinner("Processing PDF..."):
                for i in range(100):
                    time.sleep(0.02)
                    progress_bar.progress(i + 1)
                st.session_state.text_vector_store, st.session_state.code_vector_store, st.session_state.pdf_text, st.session_state.code_text, st.session_state.images = process_pdf(uploaded_file)
                if st.session_state.text_vector_store or st.session_state.code_vector_store:
                    st.success("PDF processed! Ask away or summarize.")
                    st.session_state.messages = [{"role": "assistant", "content": "PDF processed! What would you like to know?"}]
                else:
                    st.error("Failed to process PDF.")
    with col2:
        if st.button("Summarize PDF") and st.session_state.pdf_text:
            progress_bar = st.progress(0)
            with st.spinner("Summarizing..."):
                for i in range(100):
                    time.sleep(0.008)
                    progress_bar.progress(i + 1)
                summary = summarize_pdf(st.session_state.pdf_text)
                st.session_state.messages.append({"role": "assistant", "content": summary})
                st.markdown(summary, unsafe_allow_html=True)

    # Display chat history (rendered before handling new input so each turn is drawn once)
    for message in st.session_state.messages:
        with st.chat_message(message["role"]):
            st.markdown(message["content"], unsafe_allow_html=True)

    # Chat interface
    if st.session_state.text_vector_store or st.session_state.code_vector_store:
        prompt = st.chat_input("Ask a question (e.g., 'Give me code' or 'What’s the main idea?'):")
        if prompt:
            st.session_state.messages.append({"role": "user", "content": prompt})
            with st.chat_message("user"):
                st.markdown(prompt)
            with st.chat_message("assistant"):
                progress_bar = st.progress(0)
                with st.spinner("⏳ Processing..."):  # st.spinner renders plain text, not HTML
                    for i in range(100):
                        time.sleep(0.004)
                        progress_bar.progress(i + 1)
                    answer = answer_question(st.session_state.text_vector_store, st.session_state.code_vector_store, prompt)
                st.markdown(answer, unsafe_allow_html=True)
            st.session_state.messages.append({"role": "assistant", "content": answer})

    # Display extracted images
    if st.session_state.images:
        st.header("Extracted Images")
        for img in st.session_state.images:
            st.image(img, caption="Extracted PDF Image", use_container_width=True)

    # Download chat history
    if st.session_state.messages:
        chat_text = "\n".join(f"{m['role'].capitalize()}: {m['content']}" for m in st.session_state.messages)
        st.download_button("Download Chat History", chat_text, "chat_history.txt")

except Exception as e:
    logger.error(f"App initialization failed: {str(e)}")
    st.error(f"App failed to start: {str(e)}. Check Spaces logs or contact support.")