Spaces:
Build error
Build error
import streamlit as st | |
import pickle | |
import numpy as np | |
import torch | |
import torch.nn.functional as F | |
from PIL import Image, UnidentifiedImageError | |
from sklearn.metrics.pairwise import cosine_similarity | |
import os | |
from pdf2image import convert_from_path | |
from streamlit_cropper import st_cropper | |
import easyocr | |
from reportlab.lib.pagesizes import letter | |
from reportlab.pdfgen import canvas | |
from reportlab.lib.utils import ImageReader | |
import io | |
import base64 | |
# -------------------
# Page configuration (kept as the first Streamlit call in the script)
# -------------------
st.set_page_config(page_title="Mobica Find")

# Theme override: black app background with white text.
st.markdown(
    """
<style>
.stApp {
background-color: black;
color: white; /* Ensures your text is visible on black background */
}
</style>
""",
    unsafe_allow_html=True,
)

# ---------------
# Top-left logo: the PNG is read from disk and embedded as a base64 data URI,
# then pinned to the top-left corner with fixed positioning.
# ---------------
logo_path = r"E:\Mobica\pdf_parser\logo_mobica.png"
with open(logo_path, "rb") as logo_file:
    logo_bytes = logo_file.read()
encoded_logo = base64.b64encode(logo_bytes).decode()

st.markdown(
    f"""
<style>
.top-left-logo {{
position: fixed;
top: 1rem;
left: 1rem;
z-index: 9999;
}}
</style>
<div class="top-left-logo">
<img src="data:image/png;base64,{encoded_logo}" width="240">
</div>
""",
    unsafe_allow_html=True,
)
# -------------------- | |
# Load Processor, Model, and Metadata | |
# -------------------- | |
def load_resources():
    """Load the ALIGN processor and model from the Hugging Face Hub.

    The checkpoint ``kakaobrain/align-base`` is loaded with
    ``AutoProcessor`` / ``AlignModel`` and the model is moved to the GPU
    when CUDA is available, otherwise it stays on the CPU.

    Returns:
        A ``(processor, model)`` tuple.
    """
    # Imported here because the file's top-level imports never bring these
    # names into scope; without this the call below fails with NameError.
    from transformers import AlignModel, AutoProcessor

    model_name = "kakaobrain/align-base"
    # Load processor and model directly from Hugging Face
    processor = AutoProcessor.from_pretrained(model_name)
    model = AlignModel.from_pretrained(model_name)
    # Move model to GPU if available
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    return processor, model


processor, model = load_resources()
def extract_text_with_easyocr(image, language="en"):
    """Run OCR on ``image`` and return the recognised text as one string.

    The image is converted to a NumPy array and handed to the module-level
    EasyOCR ``reader``; the returned text fragments are joined with single
    spaces. ``language`` is accepted but not used by this function.

    Returns:
        The joined text, or ``""`` when nothing was detected or when OCR
        raised (the error is shown with ``st.error``).
    """
    try:
        fragments = reader.readtext(np.array(image), detail=0)  # text only, no boxes
        if not fragments:
            return ""
        return " ".join(fragments)
    except Exception as e:
        st.error(f"Error during OCR: {e}")
        return ""
# -------------------- | |
# Embedding Functions | |
# -------------------- | |
def get_image_embedding(image):
    """Return the L2-normalized image embedding as a NumPy array.

    The image is preprocessed with the module-level ``processor`` and
    embedded with ``model.get_image_features``; the result is normalized
    along dimension 1.

    Args:
        image: Image accepted by the processor (callers pass RGB PIL images).

    Returns:
        A NumPy array on the CPU (presumably shape ``(1, dim)`` for a single
        image -- confirm against the model).
    """
    # The model may live on the GPU; its inputs must be on the same device.
    image_inputs = processor(images=image, return_tensors="pt").to(model.device)
    # Inference only: skip building an autograd graph.
    with torch.no_grad():
        image_outputs = model.get_image_features(**image_inputs)
    return F.normalize(image_outputs, dim=1).cpu().numpy()
def get_text_embedding(text):
    """Return the L2-normalized text embedding as a NumPy array.

    The text is tokenized (padded and truncated) with the module-level
    ``processor`` and embedded with ``model.get_text_features``; the result
    is normalized along dimension 1.

    Args:
        text: A single string to embed.

    Returns:
        A NumPy array on the CPU (presumably shape ``(1, dim)`` -- confirm
        against the model).
    """
    text_inputs = processor(
        text=[text], return_tensors="pt", padding=True, truncation=True
    ).to(model.device)  # keep inputs on the same device as the model
    # Inference only: skip building an autograd graph.
    with torch.no_grad():
        text_outputs = model.get_text_features(**text_inputs)
    return F.normalize(text_outputs, dim=1).cpu().numpy()
# -------------------- | |
# Search Function | |
# -------------------- | |
def find_most_similar_products(
    image=None,
    description=None,
    n=3,
    combine_method="none",  # "none" (image-only), "text-only", or "average" for combining
):
    """Return the top-``n`` products most similar to the query.

    The query embedding is chosen as follows:
      - ``combine_method == "none"`` with an image: image embedding only.
      - ``combine_method == "text-only"`` with a description: text embedding only.
      - anything else: the plain average of the image and text embeddings,
        which requires both ``image`` and ``description``.

    Every entry of the module-level ``embeddings_metadata`` is scored by the
    mean cosine similarity between the query and each stored embedding file
    listed under ``"image_embedding_paths"`` that exists on disk. Entries
    whose score is not positive are dropped.

    Args:
        image: Query image, or ``None``.
        description: Query text, or ``None``.
        n: Maximum number of results to return.
        combine_method: ``"none"``, ``"text-only"`` or ``"average"``.

    Returns:
        A list of ``(score, entry)`` tuples sorted by descending score, at
        most ``n`` long.

    Raises:
        ValueError: if the inputs required by the selected method are missing.
    """
    # Prepare the query embedding
    if combine_method == "none" and image is not None:
        query_embed = get_image_embedding(image)  # image-only
    elif combine_method == "text-only" and description is not None:
        query_embed = get_text_embedding(description)  # text-only
    else:
        # Averaging needs both inputs; fail clearly instead of handing None
        # to the embedding helpers.
        if image is None or description is None:
            raise ValueError(
                f"combine_method={combine_method!r} requires both an image and "
                "a description (or use 'none' with an image / 'text-only' "
                "with a description)"
            )
        img_emb = get_image_embedding(image)
        txt_emb = get_text_embedding(description)
        query_embed = (img_emb + txt_emb) / 2.0  # simple average

    similarities = []
    # Score each product in the metadata
    for entry in embeddings_metadata.values():
        image_similarities = []
        for emb_path in entry.get("image_embedding_paths", []):
            emb_path = os.path.normpath(emb_path)
            if os.path.exists(emb_path):
                stored_embedding = np.load(emb_path)
                # Cosine similarity between the query and this stored embedding
                image_similarities.append(
                    cosine_similarity(query_embed, stored_embedding).mean()
                )
        # Average over all of the product's images; 0 when none could be loaded
        overall_score = np.mean(image_similarities) if image_similarities else 0
        if overall_score > 0:
            similarities.append((overall_score, entry))

    # Sort descending by similarity
    return sorted(similarities, key=lambda x: x[0], reverse=True)[:n]
# --------------------
# Session State Setup
# --------------------
# "pdf_crops" stores (snippet_image, product_image) pairs saved from PDF
# pages; "results" stores search results. Both start out as empty lists.
for _state_key in ("pdf_crops", "results"):
    if _state_key not in st.session_state:
        st.session_state[_state_key] = []

# --------------------
# APP UI
# --------------------
st.title("Mobica Find")

_SEARCH_METHODS = [
    "Upload PDF",
    "Image Only",
    "Description Only",
    "Both (Image + Description)",
]
search_method = st.selectbox("Choose Search Method", _SEARCH_METHODS)
# -----------------------------------------------------------------------------
# OCR SETUP (used by the PDF search method)
# -----------------------------------------------------------------------------
# Initialize EasyOCR reader (Supports multiple languages)
reader = easyocr.Reader(["en", "ar"])  # Add languages as needed

# NOTE: the page config, the black-background CSS and the top-left logo are
# already applied once near the top of this script. They were repeated here;
# the repetition is removed because Streamlit versions that require
# st.set_page_config() to be called only once raise an exception on the
# second call, and the extra CSS/logo markup was redundant.
# -------------------- | |
# Load Processor, Model, and Metadata | |
# -------------------- | |
def load_resources():
    """Load the pickled processor, model and embeddings metadata from disk.

    Returns:
        A ``(processor, model, embeddings_metadata)`` tuple, each item being
        whatever object was pickled in the corresponding file.
    """

    def _unpickle(path):
        # NOTE(review): pickle executes arbitrary code on load; these files
        # must only ever come from a trusted source.
        with open(path, "rb") as handle:
            return pickle.load(handle)

    processor = _unpickle(r"E:\Mobica\pdf_parser\Data Sheet\align_processor.pkl")
    model = _unpickle(r"E:\Mobica\pdf_parser\Data Sheet\align_model.pkl")
    embeddings_metadata = _unpickle(
        r"E:\Mobica\pdf_parser\Data Sheet\embeddings_metadata.pkl"
    )
    return processor, model, embeddings_metadata


processor, model, embeddings_metadata = load_resources()
# -------------------- | |
# OCR Function using EasyOCR | |
# -------------------- | |
def extract_text_with_easyocr(image, language="en"):
    """Extract text from ``image`` with the module-level EasyOCR ``reader``.

    ``language`` is accepted but not used by this function. Detected text
    fragments are joined with single spaces.

    Returns:
        The joined text; ``""`` when no text was detected or when OCR raised
        (in which case the error is reported through ``st.error``).
    """
    try:
        pixels = np.array(image)
        detected = reader.readtext(pixels, detail=0)  # detail=0: text results only
        if detected:
            return " ".join(detected)
        return ""
    except Exception as e:
        st.error(f"Error during OCR: {e}")
        return ""
# -----------------------------------------------------------------------------
# SEARCH UI
# -----------------------------------------------------------------------------
# NOTE: the page title and the "Choose Search Method" selectbox are rendered
# once, earlier in this script, and `search_method` from that call is reused
# here. Rendering an identical un-keyed selectbox a second time makes
# Streamlit raise a duplicate-widget error, so the repeated title/selectbox
# were removed from this section.


def _rerun_app():
    """Restart the script run with whichever rerun API this Streamlit offers."""
    rerun = getattr(st, "rerun", None) or getattr(st, "experimental_rerun")
    rerun()


def _first_image_path(entry):
    """Return the normalized first ``original_image_paths`` item, or ``None``."""
    paths = entry.get("original_image_paths")
    return os.path.normpath(paths[0]) if paths else None


def _product_code(image_path):
    """Return the part of the image file name before the first underscore."""
    return os.path.basename(image_path).split("_")[0]


def _show_matches(matches, empty_message="No similar products found.", show_product_code=False):
    """Render a list of ``(similarity, entry)`` matches, or a warning if empty."""
    if not matches:
        st.warning(empty_message)
        return
    for sim_score, matched_entry in matches:
        img_path = _first_image_path(matched_entry)
        st.subheader(f"🔹 Match (Similarity: {sim_score:.4f})")
        if show_product_code and img_path is not None:
            st.write(f"**Product Code:** {_product_code(img_path)}")
        st.write(f"**Description:** {matched_entry.get('description', 'No description')}")
        # Show the first image of the matched entry (if available)
        if img_path is None:
            continue
        if not os.path.exists(img_path):
            st.warning(f"⚠️ Image file not found: {img_path}")
            continue
        try:
            img_matched = Image.open(img_path).convert("RGB")
            st.image(
                img_matched,
                caption=f"Matched Image (Sim: {sim_score:.4f})",
                use_column_width=True,
            )
        except UnidentifiedImageError:
            st.warning(f"⚠️ Cannot open image: {img_path}. It might be corrupted.")


def _build_results_pdf(crops, results):
    """Build the results report and return it as a rewound ``io.BytesIO``.

    One section per saved crop: the cropped product image followed by the
    product code and image of each match in ``results[i]``. Each match takes
    160 points of vertical space; a continuation page is started when the
    next match would not fit, so long result lists are no longer drawn past
    the bottom edge of the page.
    """
    pdf_buffer = io.BytesIO()
    pdf = canvas.Canvas(pdf_buffer, pagesize=letter)
    for i, (_snippet_img, product_img) in enumerate(crops):
        pdf.drawString(100, 750, f"Crop {i+1}")
        # Add cropped product image to PDF
        if product_img is not None:
            img_byte_arr = io.BytesIO()
            product_img.save(img_byte_arr, format="JPEG")
            img_byte_arr.seek(0)
            pdf.drawImage(ImageReader(img_byte_arr), 100, 550, width=200, height=150)
        y_pos = 530
        matches = results[i] if i < len(results) else []
        for _sim_score, matched_entry in matches:
            matched_img_path = _first_image_path(matched_entry)
            if matched_img_path is None:
                continue
            # The match image is drawn with its bottom edge at y_pos - 90.
            if y_pos - 90 < 40:
                pdf.showPage()
                pdf.drawString(100, 750, f"Crop {i+1} (continued)")
                y_pos = 700
            pdf.drawString(100, y_pos, f"Product Code: {_product_code(matched_img_path)}")
            y_pos -= 40
            if os.path.exists(matched_img_path):
                pdf.drawImage(matched_img_path, 350, y_pos - 50, width=150, height=100)
            y_pos -= 120
        pdf.showPage()
    pdf.save()
    pdf_buffer.seek(0)
    return pdf_buffer


def _render_pdf_search():
    """PDF flow: pick a page, crop snippet + product, search, export a report."""
    # Local import: only convert_from_path is imported at the top of the file.
    from pdf2image import convert_from_bytes

    st.subheader("Upload a PDF")
    uploaded_pdf = st.file_uploader("Upload a PDF", type=["pdf"])
    if not uploaded_pdf:
        return

    st.write("Extracting pages from PDF...")
    # Rasterize straight from the uploaded bytes instead of writing
    # "temp_<name>" into the working directory and never deleting it.
    pages = convert_from_bytes(uploaded_pdf.getvalue(), 300)
    if not pages:
        return

    page_num = st.number_input(
        "Select Page Number", min_value=1, max_value=len(pages), value=1
    ) - 1
    page_image = pages[page_num]

    # -------------------- Crop Snippet for OCR (description) --------------------
    st.subheader("Crop Snippet from PDF for OCR")
    cropped_img_pdf_snippet = st_cropper(page_image, realtime_update=True, box_color='#FF0000')
    description_ocr = ""
    if cropped_img_pdf_snippet:
        cropped_img_pdf_snippet = cropped_img_pdf_snippet.convert("RGB")
        st.image(cropped_img_pdf_snippet, caption="Cropped PDF Snippet (For OCR)")
        selected_lang = st.selectbox("Select OCR Language", ["en", "ar", "en+ar"], index=0)
        description_ocr = extract_text_with_easyocr(cropped_img_pdf_snippet, language=selected_lang)
        if description_ocr:
            st.success("OCR text extracted successfully!")
            st.write("**Detected Text**:", description_ocr)
        else:
            st.warning("No text detected.")

    # -------------------- Crop for product image --------------------
    st.subheader("Crop the Product Image")
    furniture_cropped_img = st_cropper(page_image, realtime_update=True, box_color='#00FF00')
    if furniture_cropped_img:
        furniture_cropped_img = furniture_cropped_img.convert("RGB")
        st.image(furniture_cropped_img, caption="Cropped Product Image")

    # -------------------- "Done" button saves both crops --------------------
    if st.button("Done"):
        st.session_state.setdefault("pdf_crops", []).append(
            (cropped_img_pdf_snippet, furniture_cropped_img)
        )
        st.success(f"Crop #{len(st.session_state['pdf_crops'])} saved!")

    # -------------------- Show saved crops if any --------------------
    saved_crops = st.session_state.get("pdf_crops", [])
    if saved_crops:
        st.subheader("📊 View Saved Crops")
        if len(saved_crops) > 1:
            crop_index = st.slider("Select Crop", 1, len(saved_crops), 1) - 1
        else:
            # Only one crop: nothing to choose, so avoid a slider whose
            # minimum equals its maximum.
            crop_index = 0
        snippet_img, product_img = saved_crops[crop_index]
        col1, col2 = st.columns(2)
        with col1:
            if snippet_img:
                st.image(snippet_img, caption=f"Snippet Crop {crop_index+1}", use_column_width=True)
        with col2:
            if product_img:
                st.image(product_img, caption=f"Product Crop {crop_index+1}", use_column_width=True)
        if st.button(f"Delete Crop {crop_index+1}"):
            saved_crops.pop(crop_index)
            # Stored results are matched to crops by position; drop them so
            # they cannot be shown against the wrong crop after a deletion.
            st.session_state["results"] = []
            st.success(f"Crop {crop_index+1} deleted!")
            _rerun_app()

    # -------------------- Let user choose how many similar products --------------------
    n_similar = st.slider("How many similar products do you want?", 1, 10, 3)

    # -------------------- "Find Similar Products" button --------------------
    if st.button("Find Similar Products"):
        # Image-based search using the product crop only. One result list is
        # stored per crop (empty when there is no product image) so that
        # results stay aligned with crops by index.
        st.session_state["results"] = [
            find_most_similar_products(
                image=product_img,
                n=n_similar,
                combine_method="none",  # image-only
            )
            if product_img is not None
            else []
            for _snippet_img, product_img in st.session_state["pdf_crops"]
        ]
        st.success("Results generated!")

    # -------------- Display stored results in the Streamlit GUI --------------
    stored_results = st.session_state["results"]
    for i, results_for_img in enumerate(stored_results):
        st.write(f"**Results for Crop {i+1}**:")
        _show_matches(
            results_for_img,
            empty_message=f"No similar products found for Crop {i+1}.",
            show_product_code=True,
        )

    # -------------------- Generate PDF if results are available --------------------
    if stored_results:
        st.download_button(
            "📥 Download Results PDF",
            _build_results_pdf(st.session_state["pdf_crops"], stored_results),
            f"{uploaded_pdf.name}_results.pdf",
            "application/pdf",
        )


def _render_image_search():
    """Image-only flow: upload an image and search by its embedding."""
    st.subheader("Upload an Image")
    uploaded_image = st.file_uploader("Select an Image", type=["png", "jpg", "jpeg"])
    if uploaded_image is None:
        return
    image_obj = Image.open(uploaded_image).convert("RGB")
    st.image(image_obj, use_column_width=True)
    n_similar = st.slider("How many similar products do you want?", 1, 10, 3)
    if st.button("Find Similar Products"):
        _show_matches(
            find_most_similar_products(
                image=image_obj,
                n=n_similar,
                combine_method="none",  # image-only
            )
        )


def _render_description_search():
    """Description-only flow: search by the embedding of the typed text."""
    st.subheader("Enter a Description")
    user_description = st.text_area("Type or paste your description here")
    if not user_description.strip():
        return
    n_similar = st.slider("How many similar products do you want?", 1, 10, 3)
    if st.button("Find Similar Products"):
        _show_matches(
            find_most_similar_products(
                description=user_description,
                n=n_similar,
                combine_method="text-only",
            )
        )


def _render_combined_search():
    """Image + description flow: search by the average of both embeddings."""
    st.subheader("Upload an Image and Enter a Description")
    uploaded_image = st.file_uploader("Select an Image", type=["png", "jpg", "jpeg"])
    user_description = st.text_area("Type or paste your description here")
    image_obj = None
    if uploaded_image is not None:
        image_obj = Image.open(uploaded_image).convert("RGB")
        st.image(image_obj, use_column_width=True)
    # Both inputs are required; without this guard a description entered
    # before any image was uploaded would reach the search with no image.
    if image_obj is None or not user_description.strip():
        return
    n_similar = st.slider("How many similar products do you want?", 1, 10, 3)
    if st.button("Find Similar Products"):
        _show_matches(
            find_most_similar_products(
                image=image_obj,
                description=user_description,
                n=n_similar,
                combine_method="average",
            )
        )


if search_method == "Upload PDF":
    _render_pdf_search()
elif search_method == "Image Only":
    _render_image_search()
elif search_method == "Description Only":
    _render_description_search()
elif search_method == "Both (Image + Description)":
    _render_combined_search()