import os
import pickle
import torch
import matplotlib.pyplot as plt
from langchain_core.documents import Document
from datasets import load_dataset
from sentence_transformers import SentenceTransformer, util
from transformers import GPT2LMHeadModel, GPT2Tokenizer
from transformers import BertModel, BertTokenizer
from langchain_core.prompts import PromptTemplate
import streamlit as st
from PIL import Image
import numpy as np
import tensorflow as tf
from tensorflow.keras.preprocessing import image
from tensorflow.keras.layers import GlobalMaxPooling2D
from tensorflow.keras.applications.resnet50 import ResNet50, preprocess_input
from sklearn.neighbors import NearestNeighbors
from numpy.linalg import norm
os.environ['HUGGINGFACEHUB_API_TOKEN'] = "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"
class Chatbot:
    def __init__(self):
        self.load_data()
        self.load_models()
        self.load_embeddings()
        self.load_template()

    def load_data(self):
        self.data = load_dataset("ashraq/fashion-product-images-small", split="train")
        self.images = self.data["image"]
        self.product_frame = self.data.remove_columns("image").to_pandas()
        self.product_data = self.product_frame.reset_index(drop=True).to_dict(orient='index')
    def load_template(self):
        self.template = """
        You are a fashion shopping assistant that wants to convert customers based on the information given.
        Describe the season and usage given in the context in your interaction with the customer.
        Use a bullet list when describing each product.
        If the user asks a general question (for example, when the store opens or where it is located), answer it accordingly.
        Context: {context}
        User question: {question}
        Your response: {response}
        """
        self.prompt = PromptTemplate.from_template(self.template)
    def load_models(self):
        self.model = SentenceTransformer('clip-ViT-B-32')
        self.bert_model_name = "bert-base-uncased"
        self.bert_model = BertModel.from_pretrained(self.bert_model_name)
        self.bert_tokenizer = BertTokenizer.from_pretrained(self.bert_model_name)
        self.gpt2_model_name = "gpt2"
        self.gpt2_model = GPT2LMHeadModel.from_pretrained(self.gpt2_model_name)
        self.gpt2_tokenizer = GPT2Tokenizer.from_pretrained(self.gpt2_model_name)
    def load_embeddings(self):
        # Cache the CLIP embeddings on disk so they are computed only once
        if os.path.exists("embeddings_cache.pkl"):
            with open("embeddings_cache.pkl", "rb") as f:
                embeddings_cache = pickle.load(f)
            self.image_embeddings = embeddings_cache["image_embeddings"]
            self.text_embeddings = embeddings_cache["text_embeddings"]
        else:
            self.image_embeddings = self.model.encode([image for image in self.images])
            self.text_embeddings = self.model.encode(self.product_frame['productDisplayName'].tolist())
            embeddings_cache = {"image_embeddings": self.image_embeddings, "text_embeddings": self.text_embeddings}
            with open("embeddings_cache.pkl", "wb") as f:
                pickle.dump(embeddings_cache, f)
    def create_docs(self, results):
        docs = []
        for result in results:
            pid = result['corpus_id']
            score = result['score']
            result_string = "Product Name:" + self.product_data[pid]['productDisplayName'] + \
                            ';' + "Category:" + self.product_data[pid]['masterCategory'] + \
                            ';' + "Article Type:" + self.product_data[pid]['articleType'] + \
                            ';' + "Usage:" + self.product_data[pid]['usage'] + \
                            ';' + "Season:" + self.product_data[pid]['season'] + \
                            ';' + "Gender:" + self.product_data[pid]['gender']
            # Wrap the product description in a LangChain Document so it can be passed as context
            doc = Document(page_content=result_string)
            doc.metadata['pid'] = str(pid)
            doc.metadata['score'] = score
            docs.append(doc)
        return docs
    def get_results(self, query, embeddings, top_k=10):
        query_embedding = self.model.encode([query])
        cos_scores = util.pytorch_cos_sim(query_embedding, embeddings)[0]
        top_results = torch.topk(cos_scores, k=top_k)
        indices = top_results.indices.tolist()
        scores = top_results.values.tolist()
        results = [{'corpus_id': idx, 'score': score} for idx, score in zip(indices, scores)]
        return results
    def display_text_and_images(self, results_text):
        # Console/matplotlib preview of the results (the Streamlit UI renders them separately)
        for result in results_text:
            pid = result['corpus_id']
            product_info = self.product_data[pid]
            print("Product Name:", product_info['productDisplayName'])
            print("Category:", product_info['masterCategory'])
            print("Article Type:", product_info['articleType'])
            print("Usage:", product_info['usage'])
            print("Season:", product_info['season'])
            print("Gender:", product_info['gender'])
            print("Score:", result['score'])
            plt.imshow(self.images[pid])
            plt.axis('off')
            plt.show()
    @staticmethod
    def cos_sim(a, b):
        # Cosine similarity between every row of a and every row of b
        a_norm = torch.nn.functional.normalize(a, p=2, dim=1)
        b_norm = torch.nn.functional.normalize(b, p=2, dim=1)
        return torch.mm(a_norm, b_norm.transpose(0, 1))
    def generate_response(self, query):
        # Retrieve the products most similar to the user query
        results_text = self.get_results(query, self.text_embeddings)
        # Placeholder chatbot reply; replace with an actual LLM call (e.g. using self.prompt)
        chatbot_response = "This is a placeholder response from the chatbot."
        # Display recommended products
        self.display_text_and_images(results_text)
        # Return both the chatbot response and the recommended products
        return chatbot_response, results_text
# Function to save an uploaded file to the uploads/ directory
def save_uploaded_file(uploaded_file):
    try:
        os.makedirs('uploads', exist_ok=True)
        with open(os.path.join('uploads', uploaded_file.name), 'wb') as f:
            f.write(uploaded_file.getbuffer())
        return True
    except Exception:
        return False
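# NOTE: feature_extraction() and recommend() are called in show_dashboard() below but were not
# defined anywhere in this file. The two functions that follow are a minimal sketch of the usual
# ResNet50 + NearestNeighbors pipeline implied by the imports and by 'embeddings.pkl'; adjust them
# if the precomputed embeddings were produced differently.
def feature_extraction(img_path, model):
    # Load the image, run it through the ResNet50 + GlobalMaxPooling2D model,
    # and return an L2-normalised feature vector
    img = image.load_img(img_path, target_size=(224, 224))
    img_array = image.img_to_array(img)
    expanded_img_array = np.expand_dims(img_array, axis=0)
    preprocessed_img = preprocess_input(expanded_img_array)
    result = model.predict(preprocessed_img).flatten()
    return result / norm(result)

def recommend(features, feature_list):
    # Return the indices of the nearest neighbours of the query image
    # among the precomputed dataset embeddings
    neighbors = NearestNeighbors(n_neighbors=6, algorithm='brute', metric='euclidean')
    neighbors.fit(feature_list)
    distances, indices = neighbors.kneighbors([features])
    return indices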
# Function to show dashboard content
def show_dashboard():
    st.header("Fashion Recommender System")
    chatbot = Chatbot()

    # Load ResNet model for image feature extraction
    model = ResNet50(weights='imagenet', include_top=False, input_shape=(224, 224, 3))
    model.trainable = False
    model = tf.keras.Sequential([
        model,
        GlobalMaxPooling2D()
    ])

    # Precomputed ResNet50 embeddings for the dataset images
    feature_list = np.array(pickle.load(open('embeddings.pkl', 'rb')))
    # filenames = pickle.load(open('filenames.pkl', 'rb'))  # No longer needed

    # File upload section
    uploaded_file = st.file_uploader("Choose an image")
    if uploaded_file is not None:
        if save_uploaded_file(uploaded_file):
            # Display the uploaded image
            display_image = Image.open(uploaded_file)
            st.image(display_image)
            # Feature extraction
            features = feature_extraction(os.path.join("uploads", uploaded_file.name), model)
            # Recommendation
            indices = recommend(features, feature_list)
            # Display recommended products using the dataset images
            st.write("Recommended Products:")
            cols = st.columns(5)
            for i, idx in enumerate(indices[0][:5]):
                with cols[i]:
                    st.image(chatbot.images[idx])
                    product_info = chatbot.product_data[idx]
                    st.write("Product Name:", product_info['productDisplayName'])
                    st.write("Category:", product_info['masterCategory'])
                    st.write("Article Type:", product_info['articleType'])
                    st.write("Usage:", product_info['usage'])
                    st.write("Season:", product_info['season'])
                    st.write("Gender:", product_info['gender'])
        else:
            st.header("Some error occurred during file upload")
    # Chatbot section
    user_question = st.text_input("Ask a question:")
    if user_question:
        bot_response, recommended_products = chatbot.generate_response(user_question)
        st.write("Chatbot:", bot_response)
        # Display recommended products
        for result in recommended_products:
            pid = result['corpus_id']
            product_info = chatbot.product_data[pid]
            st.write("Product Name:", product_info['productDisplayName'])
            st.write("Category:", product_info['masterCategory'])
            st.write("Article Type:", product_info['articleType'])
            st.write("Usage:", product_info['usage'])
            st.write("Season:", product_info['season'])
            st.write("Gender:", product_info['gender'])
            st.image(chatbot.images[pid])
# Main Streamlit app
def main():
    # Give title to the app
    st.title("Fashion Recommender System")
    # Show dashboard content directly
    show_dashboard()

# Run the main app
if __name__ == "__main__":
    main()