Upload 8 files
- app.py +158 -0
- backend/extractors.py +57 -0
- backend/file_handler.py +37 -0
- backend/image_processor.py +101 -0
- backend/models.py +56 -0
- backend/qa_engine.py +65 -0
- backend/response_formatter.py +48 -0
- requirements.txt +14 -0
app.py
ADDED
@@ -0,0 +1,158 @@
from fastapi import FastAPI, UploadFile, File, Form, HTTPException
from fastapi.responses import JSONResponse
from fastapi.staticfiles import StaticFiles
from fastapi.templating import Jinja2Templates
from starlette.requests import Request
import os
import time
from pathlib import Path
from typing import Optional, List

# Import backend modules
from backend.file_handler import save_upload
from backend.extractors import extract_text
from backend.qa_engine import QAEngine
from backend.image_processor import ImageProcessor
from backend.response_formatter import ResponseFormatter

app = FastAPI(
    title="Intelligent QA Service",
    description="Question answering for documents and images"
)

# Log in to Hugging Face if a token is available in the environment
huggingface_token = os.environ.get("HF_TOKEN")
if huggingface_token:
    from huggingface_hub import login
    login(token=huggingface_token)

# Initialize models with fallback options
try:
    qa_engine = QAEngine(model_name="distilbert-base-cased-distilled-squad")  # Use a public model
except Exception as e:
    print(f"Error initializing QA engine: {str(e)}")
    # Fall back to a simpler implementation if needed
    from backend.qa_engine import SimpleQAEngine
    qa_engine = SimpleQAEngine()

try:
    image_processor = ImageProcessor()
except Exception as e:
    print(f"Error initializing Image Processor: {str(e)}")
    # Fall back to a basic image processor if needed
    from backend.image_processor import SimpleImageProcessor
    image_processor = SimpleImageProcessor()

formatter = ResponseFormatter()

# Mount static files and templates
templates = Jinja2Templates(directory="frontend/templates")
app.mount("/static", StaticFiles(directory="frontend/static"), name="static")

@app.get("/")
async def read_root(request: Request):
    """Render the main page"""
    return templates.TemplateResponse("index.html", {"request": request})

@app.post("/api/document-qa")
async def document_qa(
    file: UploadFile = File(...),
    question: str = Form(...)
):
    """Process a document and answer a question about it"""
    try:
        # Save the uploaded file
        file_id, file_name = save_upload(file)
        file_path = Path(f"/tmp/uploads/{file_name}")

        # Extract text from the document
        document_text = extract_text(str(file_path))

        # Flatten structured extraction results into plain text for the QA engine
        # (simplistic approach)
        if isinstance(document_text, dict):
            if "content" in document_text:
                if isinstance(document_text["content"], list):
                    if document_text["content"] and isinstance(document_text["content"][0], dict):
                        # DOCX structure: list of {"style", "text"} paragraphs
                        text = " ".join([p["text"] for p in document_text["content"]])
                    else:
                        # TXT structure: list of lines
                        text = " ".join(document_text["content"])
                else:
                    text = str(document_text["content"])
            else:
                text = str(document_text)
        else:
            # Plain text from PDF or PPTX
            text = document_text

        qa_result = qa_engine.answer_question(text, question)
        qa_result["timestamp"] = time.time()

        # Format the response
        response = formatter.format_document_qa_response(qa_result, file.filename)

        return JSONResponse(content=response)

    except Exception as e:
        error_response = formatter.format_error_response(str(e))
        return JSONResponse(content=error_response, status_code=error_response["status_code"])

@app.post("/api/image-qa")
async def image_qa(
    file: UploadFile = File(...),
    question: str = Form(...)
):
    """Process an image and answer a question about it"""
    try:
        print(f"Received image: {file.filename}, size: {file.size}, question: {question}")

        # Validate that the file is an image
        if not file.content_type.startswith('image/'):
            print(f"Invalid content type: {file.content_type}")
            return JSONResponse(
                content={"error": "File must be an image", "status_code": 400},
                status_code=400
            )

        # Save the uploaded file
        file_id, file_name = save_upload(file)
        file_path = Path(f"/tmp/uploads/{file_name}")
        print(f"Saved image to: {file_path}")

        if not file_path.exists():
            print(f"File not saved properly at {file_path}")
            return JSONResponse(
                content={"error": "File could not be saved", "status_code": 500},
                status_code=500
            )

        # Process the image
        vqa_result = image_processor.answer_image_question(str(file_path), question)
        vqa_result["timestamp"] = time.time()

        # Format the response
        response = formatter.format_image_qa_response(vqa_result, file.filename)

        return JSONResponse(content=response)

    except Exception as e:
        import traceback
        print(f"Error in image_qa: {str(e)}")
        print(traceback.format_exc())
        error_response = formatter.format_error_response(str(e))
        return JSONResponse(content=error_response, status_code=error_response.get("status_code", 500))

if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=7860)
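For reference, a request against the document endpoint could look like the following client-side sketch. It is not part of the uploaded files; the host, file name, and question are placeholders, and the `requests` library is assumed to be installed.

import requests

with open("report.pdf", "rb") as f:  # placeholder document
    resp = requests.post(
        "http://localhost:7860/api/document-qa",
        files={"file": ("report.pdf", f, "application/pdf")},
        data={"question": "What is the main conclusion?"},
    )
print(resp.json())  # e.g. {"document": "report.pdf", "answer": "...", "confidence": ...}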
backend/extractors.py
ADDED
@@ -0,0 +1,57 @@
from pathlib import Path
from typing import Union

import fitz  # PyMuPDF
from docx import Document
from pptx import Presentation

# PDF
def extract_text_pdf(file_path: str) -> str:
    text = ""
    with fitz.open(file_path) as doc:
        for page in doc:
            text += page.get_text()
    return text.strip()

# DOCX
def extract_text_docx(file_path: str) -> dict:
    doc = Document(file_path)
    paragraphs = []

    for para in doc.paragraphs:
        text = para.text.strip()
        if text:  # Only include non-empty paragraphs
            paragraphs.append({
                "style": para.style.name,
                "text": text
            })

    return {"content": paragraphs}

# TXT
def extract_text_txt(file_path: str) -> dict:
    with open(file_path, "r", encoding="utf-8") as f:
        lines = f.read().splitlines()  # split into clean lines
    return {"content": lines}

# PPTX
def extract_text_pptx(file_path: str) -> str:
    prs = Presentation(file_path)
    text = []
    for slide in prs.slides:
        for shape in slide.shapes:
            if hasattr(shape, "text"):
                text.append(shape.text)
    return "\n".join(text).strip()

# Dispatcher
def extract_text(file_path: str) -> Union[str, dict]:
    ext = Path(file_path).suffix.lower()
    if ext == ".pdf":
        return extract_text_pdf(file_path)
    elif ext == ".docx":
        return extract_text_docx(file_path)
    elif ext == ".txt":
        return extract_text_txt(file_path)
    elif ext == ".pptx":
        return extract_text_pptx(file_path)
    else:
        raise ValueError("Unsupported file extension")
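A minimal usage sketch of the dispatcher (the sample path is a placeholder): DOCX and TXT return a {"content": [...]} dict while PDF and PPTX return plain text, so callers typically flatten the result the way app.py does.

result = extract_text("/tmp/uploads/sample.docx")  # placeholder path
if isinstance(result, dict):
    parts = result["content"]
    flat = " ".join(p["text"] if isinstance(p, dict) else p for p in parts)
else:
    flat = result  # PDF/PPTX already return a single string
print(flat[:200])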
backend/file_handler.py
ADDED
@@ -0,0 +1,37 @@
from fastapi import UploadFile, HTTPException
from pathlib import Path
from uuid import uuid4

# Accepted file types
ALLOWED_TYPES = {
    "application/pdf": ".pdf",
    "application/vnd.openxmlformats-officedocument.wordprocessingml.document": ".docx",
    "text/plain": ".txt",
    "application/vnd.openxmlformats-officedocument.presentationml.presentation": ".pptx",
    # Image types so uploads from the /api/image-qa route are accepted by save_upload
    "image/jpeg": ".jpg",
    "image/png": ".png",
    "image/webp": ".webp",
}

MAX_FILE_SIZE = 10 * 1024 * 1024  # 10MB

UPLOAD_DIR = Path("/tmp/uploads")
UPLOAD_DIR.mkdir(parents=True, exist_ok=True)

def save_upload(file: UploadFile) -> tuple[str, str]:
    """Validate the upload, write it to UPLOAD_DIR, and return (file_id, file_name)."""
    if file.content_type not in ALLOWED_TYPES:
        raise HTTPException(status_code=400, detail="Unsupported file type.")

    # Read the file into memory to check its size before saving
    file_bytes = file.file.read()

    if len(file_bytes) > MAX_FILE_SIZE:
        raise HTTPException(status_code=413, detail="File is too large. Maximum size is 10MB.")

    file_ext = ALLOWED_TYPES[file.content_type]
    file_id = str(uuid4())
    file_path = UPLOAD_DIR / f"{file_id}{file_ext}"

    with open(file_path, "wb") as f:
        f.write(file_bytes)

    return file_id, file_path.name
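save_upload returns the generated id and stored file name rather than a full path; callers rebuild the path under /tmp/uploads, as app.py does. A hypothetical route body illustrating that contract (not part of the upload):

from pathlib import Path
from fastapi import UploadFile, File

async def handle(file: UploadFile = File(...)):
    # save_upload stores the file under /tmp/uploads and returns (file_id, file_name)
    file_id, file_name = save_upload(file)
    stored = Path("/tmp/uploads") / file_name
    assert stored.exists()
    return {"id": file_id, "path": str(stored)}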
backend/image_processor.py
ADDED
@@ -0,0 +1,101 @@
from transformers import VisionEncoderDecoderModel, ViTImageProcessor, AutoTokenizer
from transformers import ViltProcessor, ViltForQuestionAnswering
from PIL import Image
import torch
from typing import Dict, Any, Union, List

class ImageProcessor:
    def __init__(self, caption_model_name: str = "nlpconnect/vit-gpt2-image-captioning",
                 vqa_model_name: str = "dandelin/vilt-b32-finetuned-vqa"):
        # Image captioning model
        self.caption_processor = ViTImageProcessor.from_pretrained(caption_model_name)
        self.tokenizer = AutoTokenizer.from_pretrained(caption_model_name)
        self.caption_model = VisionEncoderDecoderModel.from_pretrained(caption_model_name)

        # VQA model
        self.vqa_processor = ViltProcessor.from_pretrained(vqa_model_name)
        self.vqa_model = ViltForQuestionAnswering.from_pretrained(vqa_model_name)

        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.caption_model.to(self.device)
        self.vqa_model.to(self.device)

    def generate_caption(self, image_path: str) -> str:
        """Generate a descriptive caption for the provided image"""
        try:
            image = Image.open(image_path).convert("RGB")
            pixel_values = self.caption_processor(image, return_tensors="pt").pixel_values.to(self.device)

            gen_kwargs = {
                "max_length": 50,
                "num_beams": 4,
                "early_stopping": True
            }

            output_ids = self.caption_model.generate(pixel_values, **gen_kwargs)
            caption = self.tokenizer.decode(output_ids[0], skip_special_tokens=True)

            return caption
        except Exception as e:
            return f"Error processing image: {str(e)}"

    def answer_image_question(self, image_path: str, question: str) -> Dict[str, Any]:
        """Answer a question about the provided image using a Visual QA model"""
        try:
            # Open image
            image = Image.open(image_path).convert("RGB")

            # Prepare inputs
            inputs = self.vqa_processor(image, question, return_tensors="pt")
            inputs = {k: v.to(self.device) for k, v in inputs.items()}

            # Forward pass
            with torch.no_grad():
                outputs = self.vqa_model(**inputs)

            # Get answer
            logits = outputs.logits
            idx = logits.argmax(-1).item()
            answer = self.vqa_model.config.id2label[idx]
            confidence = torch.softmax(logits, dim=-1)[0, idx].item()

            return {"answer": answer, "confidence": confidence}

        except Exception as e:
            print(f"VQA Error: {str(e)}")
            # Fall back to a caption-based answer
            try:
                caption = self.generate_caption(image_path)
                return {
                    "answer": f"Based on the image which shows {caption}, I cannot provide a specific answer.",
                    "confidence": 0.0
                }
            except Exception as e2:
                return {"answer": f"Error processing image: {str(e)}, {str(e2)}", "confidence": 0.0}


class SimpleImageProcessor:
    """A simple fallback image processor that doesn't require external models"""

    def __init__(self):
        """Initialize without any models"""
        print("Using SimpleImageProcessor fallback")

    def generate_caption(self, image_path: str) -> str:
        """Generate a basic caption for the provided image"""
        try:
            # Just extract basic image information
            img = Image.open(image_path)
            return f"an image of size {img.width}x{img.height}"
        except Exception as e:
            return f"an image (could not process: {str(e)})"

    def answer_image_question(self, image_path: str, question: str) -> Dict[str, Any]:
        """Provide a fallback answer for image questions"""
        return {
            "answer": "I cannot analyze this image right now. The image processing system is not fully functional.",
            "confidence": 0.0
        }
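A minimal usage sketch, assuming both models can be downloaded and that a local image exists at the placeholder path:

processor = ImageProcessor()
print(processor.generate_caption("photo.jpg"))  # placeholder image path
print(processor.answer_image_question("photo.jpg", "How many people are in the picture?"))
# -> {"answer": "2", "confidence": 0.87}  (illustrative output)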
backend/models.py
ADDED
@@ -0,0 +1,56 @@
from transformers import pipeline
import os
from typing import Dict, Any, Optional

# Singleton model manager
class ModelManager:
    _instance = None

    @classmethod
    def get_instance(cls):
        if cls._instance is None:
            cls._instance = cls()
        return cls._instance

    def __init__(self):
        self.pipelines = {}

        # Default models - preferably small ones that are publicly accessible
        self.model_configs = {
            "document_qa": {
                "name": "distilbert-base-cased-distilled-squad",  # Smaller, public model
                "type": "question-answering"
            },
            "image_captioning": {
                "name": "Salesforce/blip-image-captioning-base",  # Public model
                "type": "image-to-text"
            }
        }

    def load_pipeline(self, pipeline_type: str) -> bool:
        """Load a specific pipeline if it's not already loaded"""
        if pipeline_type not in self.model_configs:
            return False

        if pipeline_type in self.pipelines:
            return True

        config = self.model_configs[pipeline_type]
        model_name = config["name"]

        try:
            if config["type"] == "question-answering":
                self.pipelines[pipeline_type] = pipeline("question-answering", model=model_name)
            elif config["type"] == "image-to-text":
                self.pipelines[pipeline_type] = pipeline("image-to-text", model=model_name)

            return True
        except Exception as e:
            print(f"Error loading pipeline {model_name}: {str(e)}")
            return False

    def get_pipeline(self, pipeline_type: str) -> Optional[Any]:
        """Get a loaded pipeline, loading it first if necessary"""
        if pipeline_type not in self.pipelines and not self.load_pipeline(pipeline_type):
            return None
        return self.pipelines.get(pipeline_type)
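ModelManager is not wired into app.py in this upload, but a usage sketch might look like this (assuming the model can be downloaded; the question and context are placeholders):

manager = ModelManager.get_instance()
qa = manager.get_pipeline("document_qa")
if qa is not None:
    out = qa(question="What colour is the sky?", context="The sky is blue on a clear day.")
    print(out["answer"], out["score"])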
backend/qa_engine.py
ADDED
@@ -0,0 +1,65 @@
from transformers import pipeline
import torch
from typing import Dict, List, Any

class QAEngine:
    def __init__(self, model_name: str = "deepset/roberta-base-squad2"):
        # Use the pipeline API, which works well on Hugging Face Spaces
        self.qa_pipeline = pipeline("question-answering", model=model_name)

    def answer_question(self, context: str, question: str) -> Dict[str, Any]:
        """Answer a question based on the provided context"""
        try:
            # Use the pipeline directly
            result = self.qa_pipeline(question=question, context=context)

            return {
                "answer": result["answer"],
                "confidence": result["score"],
                "start_position": result["start"],
                "end_position": result["end"]
            }
        except Exception as e:
            return {
                "answer": f"Error processing question: {str(e)}",
                "confidence": 0.0,
                "start_position": 0,
                "end_position": 0
            }

    def answer_multiple_questions(self, context: str, questions: List[str]) -> List[Dict[str, Any]]:
        """Answer multiple questions from the same context"""
        return [self.answer_question(context, question) for question in questions]


class SimpleQAEngine:
    """A simple keyword-based QA engine used when the transformer models fail to load"""

    def answer_question(self, context: str, question: str) -> Dict[str, Any]:
        """Basic keyword-based answer extraction (fallback when models fail)"""
        # Very basic implementation: pick the sentence sharing the most words with the question
        try:
            import nltk
            nltk.download('punkt', quiet=True)
            from nltk.tokenize import sent_tokenize
        except Exception:
            # If NLTK is unavailable, fall back to a naive period-based splitter
            def sent_tokenize(text):
                return [s.strip() for s in text.split(".") if s.strip()]

        question_words = set(question.lower().split())
        best_sentence = ""
        best_score = 0

        for sentence in sent_tokenize(context):
            sentence_words = set(sentence.lower().split())
            overlap = len(question_words.intersection(sentence_words))
            if overlap > best_score:
                best_score = overlap
                best_sentence = sentence

        return {
            "answer": best_sentence if best_score > 0 else "No relevant information found.",
            "confidence": min(best_score / max(1, len(question_words)), 1.0),
            "start_position": context.find(best_sentence) if best_sentence in context else 0,
            "end_position": context.find(best_sentence) + len(best_sentence) if best_sentence in context else 0
        }
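A minimal usage sketch with an illustrative context and question (the model download is assumed to succeed):

engine = QAEngine(model_name="distilbert-base-cased-distilled-squad")
result = engine.answer_question(
    context="The Eiffel Tower was completed in 1889 and stands in Paris.",
    question="When was the Eiffel Tower completed?",
)
print(result["answer"], result["confidence"])  # e.g. "1889" 0.98 (illustrative)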
backend/response_formatter.py
ADDED
@@ -0,0 +1,48 @@
from typing import Dict, Any, List, Union

class ResponseFormatter:
    @staticmethod
    def format_document_qa_response(qa_result: Dict[str, Any], document_name: str) -> Dict[str, Any]:
        """Format the response from the QA engine for document questions"""
        formatted_response = {
            "document": document_name,
            "answer": qa_result.get("answer", "No answer found"),
            "confidence": round(qa_result.get("confidence", 0) * 100, 2),
            "metadata": {
                "source_type": "document",
                "timestamp": qa_result.get("timestamp")
            }
        }

        # Add highlighted text positions if available
        if "start_position" in qa_result and "end_position" in qa_result:
            formatted_response["highlight"] = {
                "start": qa_result["start_position"],
                "end": qa_result["end_position"]
            }

        return formatted_response

    @staticmethod
    def format_image_qa_response(vqa_result: Dict[str, Any], image_name: str) -> Dict[str, Any]:
        """Format the response from the image QA engine"""
        formatted_response = {
            "image": image_name,
            "answer": vqa_result.get("answer", "No answer found"),
            "confidence": round(vqa_result.get("confidence", 0) * 100, 2),
            "metadata": {
                "source_type": "image",
                "timestamp": vqa_result.get("timestamp")
            }
        }

        return formatted_response

    @staticmethod
    def format_error_response(error_message: str, status_code: int = 400) -> Dict[str, Any]:
        """Format error responses"""
        return {
            "error": True,
            "message": error_message,
            "status_code": status_code
        }
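A quick sketch of the formatter with placeholder values, showing how the 0-1 confidence is rescaled to a percentage and how the highlight span is attached:

formatter = ResponseFormatter()
payload = formatter.format_document_qa_response(
    {"answer": "1889", "confidence": 0.93, "start_position": 34,
     "end_position": 38, "timestamp": 1700000000.0},
    "history.pdf",
)
# -> {"document": "history.pdf", "answer": "1889", "confidence": 93.0,
#     "metadata": {...}, "highlight": {"start": 34, "end": 38}}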
requirements.txt
ADDED
@@ -0,0 +1,14 @@
numpy==1.24.3
fastapi==0.95.1
uvicorn==0.22.0
python-multipart==0.0.6
Jinja2==3.1.2
torch==2.0.1
transformers==4.30.2
accelerate==0.20.3
sentencepiece==0.1.99
pillow==9.5.0
PyMuPDF==1.22.5
python-docx==0.8.11
python-pptx==0.6.21
huggingface-hub[hf_xet]