import os
import re
import zlib
from typing import List, Optional, Tuple, Union

import gradio as gr
import numpy as np
import pandas as pd
from dotenv import load_dotenv
from langchain_mistralai import MistralAIEmbeddings
from rapidfuzz import process, fuzz
from sklearn.preprocessing import StandardScaler

from ranking_agent import rank_with_ai

load_dotenv()
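
# Note: load_dotenv() above reads a local .env file if one is present. Judging from the
# os.getenv call below, the only variable this app appears to need locally is the Mistral
# key, e.g. a .env file containing a single line (placeholder value, not a real key):
#
#   MISTRAL_API_KEY=your-api-key-here
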
class MovieRecommender:
    def __init__(self, data_dir: str = "amazon_movies_2023"):
        self.data_dir = data_dir
        self.embeddings = MistralAIEmbeddings(
            model="mistral-embed",
            mistral_api_key=os.getenv("MISTRAL_API_KEY")
        )
        # Load both types of embeddings
        self.load_embeddings()

    def load_embeddings(self) -> None:
        # Load LLM embeddings
        llm_embeddings_path = os.path.join(self.data_dir, "title_embeddings.npz")
        try:
            llm_data = np.load(llm_embeddings_path)
            self.llm_embeddings = llm_data['embeddings']
            self.llm_item_ids = llm_data['item_ids'].astype(str)  # Ensure string type
            print(f"Loaded LLM embeddings with shape: {self.llm_embeddings.shape}")
            print(f"Number of LLM item IDs: {len(self.llm_item_ids)}")
        except (IOError, zlib.error) as e:
            raise RuntimeError(
                f"Error loading LLM embeddings file: {str(e)}\n"
                "The embeddings file appears to be corrupted or invalid."
            )

        # Load GCL embeddings
        gcl_embeddings_path = os.path.join(self.data_dir, "gcl_embeddings.npz")
        try:
            gcl_data = np.load(gcl_embeddings_path)
            self.gcl_embeddings = gcl_data['embeddings']
            self.gcl_item_ids = gcl_data['item_ids'].astype(str)  # Ensure string type
            print(f"Loaded GCL embeddings with shape: {self.gcl_embeddings.shape}")
            print(f"Number of GCL item IDs: {len(self.gcl_item_ids)}")
        except (IOError, zlib.error) as e:
            raise RuntimeError(
                f"Error loading GCL embeddings file: {str(e)}\n"
                "Please run gcl_embeddings.py first to generate GCL embeddings."
            )

        # Load movie mapping
        mapping_path = os.path.join(self.data_dir, "title_embeddings_mapping.csv")
        self.movies_df = pd.read_csv(mapping_path)
        self.movies_df['item_id'] = self.movies_df['item_id'].astype(str)  # Ensure string type

        # Create standardized embeddings for both types
        scaler = StandardScaler()
        self.llm_embeddings = scaler.fit_transform(self.llm_embeddings)
        self.gcl_embeddings = scaler.fit_transform(self.gcl_embeddings)

        # Create item_id to index mappings for both types
        self.llm_id_to_idx = {str(item_id): idx for idx, item_id in enumerate(self.llm_item_ids)}
        self.gcl_id_to_idx = {str(item_id): idx for idx, item_id in enumerate(self.gcl_item_ids)}

        # Create title to id mapping for search
        self.title_to_id = dict(zip(self.movies_df['title'], self.movies_df['item_id']))

        # Store all titles for search
        self.all_titles = self.movies_df['title'].tolist()

        print(f"Number of movies in mapping: {len(self.movies_df)}")
        print(f"Number of titles with LLM embeddings: {len(set(self.llm_id_to_idx.keys()) & set(self.title_to_id.values()))}")
        print(f"Number of titles with GCL embeddings: {len(set(self.gcl_id_to_idx.keys()) & set(self.title_to_id.values()))}")

        # Pre-process titles for fuzzy matching
        self.clean_titles = {self.clean_title_for_comparison(title): title for title in self.title_to_id.keys()}

    def clean_title_for_comparison(self, title):
        """Clean a title for comparison purposes."""
        # Remove special characters and extra spaces
        title = re.sub(r'[^\w\s]', '', str(title))
        # Convert to lowercase and collapse whitespace
        return ' '.join(title.lower().split())
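
    # Illustrative example (hypothetical title, not necessarily in the dataset): a string such as
    # "The Lord of the Rings: The Return of the King (Widescreen Edition)" would be cleaned to
    # "the lord of the rings the return of the king widescreen edition" before fuzzy matching.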
    def search_movies(self, query: str) -> List[str]:
        if not query:
            return []  # Return empty if no query to avoid overwhelming the UI
        clean_query = self.clean_title_for_comparison(query)
        # Use rapidfuzz to find matches across the entire dataset
        matches = process.extract(
            clean_query,
            self.clean_titles.keys(),
            scorer=fuzz.WRatio,  # WRatio works well for movie titles
            limit=None,  # No limit - return all matches
            score_cutoff=60  # Only return matches with a score of at least 60
        )
        # Convert matches back to original titles
        return [self.clean_titles[match[0]] for match in matches]
    def get_text_embedding(self, text: str) -> Optional[np.ndarray]:
        """Get an embedding for text using LangChain Mistral embeddings (None on failure)."""
        try:
            embedding = self.embeddings.embed_query(text)
            # Convert embedding to numpy array
            embedding = np.array(embedding, dtype=np.float32)
            # Normalize the embedding
            if np.any(embedding):  # Only normalize if not all zeros
                embedding = embedding / np.linalg.norm(embedding)
            return embedding
        except Exception as e:
            print(f"Error getting embedding from Mistral API: {str(e)}")
            return None
    def get_recommendations(
        self,
        selected_movies: List[str],
        embedding_type: str = "LLM + GCL",
        user_preferences: str = "",
        alpha: float = 0.5,
    ) -> Union[str, List[Tuple[str, float]]]:
        """
        Get recommendations using embedding aggregation:
        - e_h: embedding from user history (selected movies)
        - e_u: embedding from user preferences (text)
        - Combined: alpha * e_u + (1 - alpha) * e_h
        Returns a list of (title, similarity) tuples on success, or an error message string.
        """
        if not selected_movies and not user_preferences:
            return "Please select some movies or provide preferences."

        # Choose embeddings based on type
        if embedding_type == "LLM + GCL":
            embeddings = self.gcl_embeddings
            id_to_idx = self.gcl_id_to_idx
        else:
            embeddings = self.llm_embeddings
            id_to_idx = self.llm_id_to_idx

        user_profile = None

        # Get embedding from user history (e_h)
        e_h = None
        if selected_movies:
            movie_ids = [self.title_to_id[title] for title in selected_movies if title in self.title_to_id]
            if movie_ids:
                selected_embeddings = []
                for movie_id in movie_ids:
                    if movie_id in id_to_idx:
                        idx = id_to_idx[movie_id]
                        selected_embeddings.append(embeddings[idx])
                if selected_embeddings:
                    e_h = np.mean(selected_embeddings, axis=0)

        # Get embedding from user preferences (e_u)
        e_u = None
        if user_preferences.strip():
            e_u = self.get_text_embedding(user_preferences)

        # Apply aggregation algorithm
        if e_h is not None and e_u is not None:
            # Both available: alpha * e_u + (1 - alpha) * e_h
            user_profile = alpha * e_u + (1 - alpha) * e_h
            print(f"Using combined embedding: α={alpha} (preferences weight)")
        elif e_u is not None:
            # Only preferences available
            user_profile = e_u
            print("Using preferences-only embedding")
        elif e_h is not None:
            # Only history available
            user_profile = e_h
            print("Using history-only embedding")
        else:
            return "Could not create user profile from provided input."

        # Calculate similarity with all movies.
        # Normalize the user profile and embeddings for proper cosine similarity.
        user_profile_norm = user_profile / np.linalg.norm(user_profile)
        embeddings_norm = embeddings / np.linalg.norm(embeddings, axis=1, keepdims=True)
        # Cosine similarity (normalized dot product)
        similarities = np.dot(embeddings_norm, user_profile_norm)
        print(f"Similarity range: {similarities.min():.3f} to {similarities.max():.3f}")

        # Get the 100 most similar movies
        top_indices = np.argsort(similarities)[-100:][::-1]

        # Filter out selected movies and create recommendations
        seen_titles = set(selected_movies) if selected_movies else set()
        seen_clean_titles = set(self.clean_title_for_comparison(title) for title in seen_titles)
        final_recommendations = []

        # Get the reverse mapping for the chosen embedding type
        if embedding_type == "LLM + GCL":
            idx_to_id = {idx: item_id for item_id, idx in self.gcl_id_to_idx.items()}
        else:
            idx_to_id = {idx: item_id for item_id, idx in self.llm_id_to_idx.items()}

        for idx in top_indices:
            if idx not in idx_to_id:
                continue
            item_id = idx_to_id[idx]

            # Find the title for this item_id
            title = None
            for t, id_ in self.title_to_id.items():
                if id_ == item_id:
                    title = t
                    break
            if not title:
                continue

            clean_title = self.clean_title_for_comparison(title)

            # Skip if the exact title is in the seen titles
            if title in seen_titles:
                continue
            # Skip if the cleaned version of the title is in the seen titles
            if clean_title in seen_clean_titles:
                continue

            # Skip collections/trilogies if the user has seen any part
            is_collection = False
            for seen_title in seen_titles:
                seen_clean = self.clean_title_for_comparison(seen_title)
                if seen_clean in clean_title or clean_title in seen_clean:
                    if any(marker in title.lower() for marker in ['collection', 'trilogy', 'series', 'complete']):
                        is_collection = True
                        break
            if is_collection:
                continue

            # Check whether this is a near-duplicate of an already recommended movie
            is_duplicate = any(
                fuzz.ratio(clean_title, self.clean_title_for_comparison(rec[0])) > 90
                for rec in final_recommendations
            )
            if is_duplicate:
                continue

            # Add with similarity score
            final_recommendations.append((title, similarities[idx]))
            if len(final_recommendations) >= 100:
                break

        if not final_recommendations:
            return "No recommendations found based on your input."

        return final_recommendations[:100]  # Return top 100 for the ranking agent
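
# A minimal sketch of exercising the retrieval step outside the Gradio UI
# (hypothetical title and preference text; assumes the data files loaded in
# load_embeddings() are present in amazon_movies_2023/):
#
#   recommender = MovieRecommender()
#   candidates = recommender.get_recommendations(
#       selected_movies=["Some Movie Title"],
#       embedding_type="LLM + GCL",
#       user_preferences="a lighthearted animated film",
#       alpha=0.5,
#   )
#   if not isinstance(candidates, str):
#       for title, score in candidates[:10]:
#           print(title, round(float(score), 3))
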
def create_interface():
    try:
        recommender = MovieRecommender()
    except Exception as e:
        print(f"Error initializing recommender: {str(e)}")
        return None

    with gr.Blocks() as iface:
        gr.Markdown(
            """
# Movie Recommender
Get personalized movie recommendations based on your taste and preferences, powered by zero-shot predictions from Foundation Recommender!

**How to use:**
1. Search for and select movies you've enjoyed
2. Describe the type of film you are looking for, considering factors such as genre, length, whether it is animated, etc.
3. Adjust the preference weight (α) to balance your description against your movie history
4. Get personalized recommendations
"""
        )

        selected_movies = gr.State([])
        retrieval_results = gr.State([])  # Store retrieval results for ranking

        with gr.Row():
            with gr.Column():
                # Movie search and selection
                movie_search_input = gr.Textbox(
                    label="Search movies",
                    placeholder="Type to search...",
                    interactive=True
                )

                # Show search results as a list of selectable options
                search_results = gr.Radio(
                    choices=[],
                    label="Search Results",
                    interactive=True,
                    visible=True
                )

                # Display selected movies with remove buttons
                with gr.Column(elem_id="selected_movies_container") as selected_movies_container:
                    selected_display = gr.HTML(
                        label="Your Selected Movies",
                        value="<p><i>No movies selected yet</i></p>"
                    )

                    # Individual delete buttons (simpler approach)
                    delete_buttons = []
                    for i in range(20):  # Support up to 20 movies
                        btn = gr.Button(f"× Remove Movie {i+1}", visible=False, size="sm", variant="secondary")
                        delete_buttons.append(btn)

                    # Clear all button
                    clear_btn = gr.Button("Clear All", size="sm", variant="secondary")

                # User preferences text field
                user_preferences = gr.Textbox(
                    label="Describe what kind of movie you're looking for",
                    placeholder="E.g., 'A movie with a female main character'",
                    lines=3
                )

                # Alpha slider
                alpha = gr.Slider(
                    minimum=0,
                    maximum=1,
                    value=0.5,
                    step=0.1,
                    label="Preference Weight (α)",
                    info="0: Use only movie history, 1: Use only your description"
                )

                # Embedding type selection (defaulting to GCL)
                embedding_type = gr.Radio(
                    choices=["LLM + GCL", "LLM"],
                    value="LLM + GCL",
                    label="Embedding Type",
                    info="Choose between pure language model embeddings (LLM) or graph-enhanced embeddings (LLM + GCL)"
                )

                # Get recommendations button
                recommend_btn = gr.Button("Get Recommendations", variant="primary")

            with gr.Column():
                # Display recommendations with streaming
                recommendations = gr.Markdown(
                    label="Your Personalized Recommendations",
                    value="Recommendations will appear here"
                )
        def update_search_results(query):
            """Update search results based on the query text."""
            if not query or len(query.strip()) < 2:
                return gr.Radio(choices=[], visible=False)
            matches = recommender.search_movies(query)
            # Limit display to the first 20 matches for UI performance
            display_matches = matches[:20] if len(matches) > 20 else matches
            if display_matches:
                return gr.Radio(choices=display_matches, visible=True)
            else:
                return gr.Radio(choices=[], visible=False)

        def format_selected_movies_display(movies):
            """Format the selected movies as an HTML list."""
            if not movies:
                return "<p><i>No movies selected yet</i></p>"
            html_items = []
            for i, movie in enumerate(movies):
                html_items.append(f"""
                <div style="display: flex; align-items: center; justify-content: space-between;
                            padding: 8px 12px; margin: 4px 0; background-color: #f8f9fa;
                            border-radius: 6px; border-left: 3px solid #007bff;">
                    <span style="flex-grow: 1; font-size: 14px; margin-right: 10px;">{i+1}. {movie}</span>
                </div>
                """)
            return f"<div>{''.join(html_items)}</div>"

        def update_delete_buttons_visibility(movies):
            """Update visibility and labels of the delete buttons."""
            button_updates = []
            for i in range(20):  # Support up to 20 movies
                if i < len(movies):
                    movie_name = movies[i][:40] + ("..." if len(movies[i]) > 40 else "")
                    button_updates.append(gr.Button(f"🗑️ {movie_name}", visible=True, size="sm", variant="secondary"))
                else:
                    button_updates.append(gr.Button(f"× Remove Movie {i+1}", visible=False, size="sm", variant="secondary"))
            return button_updates

        def delete_movie_by_index(index, current_movies):
            """Delete the movie at a specific index."""
            if not current_movies or index >= len(current_movies):
                return current_movies, format_selected_movies_display(current_movies)
            current_movies.pop(index)
            return current_movies, format_selected_movies_display(current_movies)

        def handle_movie_selection(selected_movie, current_movies):
            """Handle movie selection from the search results."""
            if not selected_movie:
                return [current_movies, format_selected_movies_display(current_movies)] + update_delete_buttons_visibility(current_movies)
            # Check whether it is a known movie title (exists in our database)
            if selected_movie in recommender.title_to_id:
                # It's a movie selection - add it to the list (no cap on how many movies can be selected)
                current_movies = current_movies or []
                if selected_movie not in current_movies:
                    current_movies.append(selected_movie)
                return [current_movies, format_selected_movies_display(current_movies)] + update_delete_buttons_visibility(current_movies)
            else:
                # Not a movie from the database
                return [current_movies, format_selected_movies_display(current_movies)] + update_delete_buttons_visibility(current_movies)

        def clear_all_movies():
            """Clear all selected movies."""
            empty_movies = []
            return [empty_movies, "<p><i>No movies selected yet</i></p>"] + update_delete_buttons_visibility(empty_movies)
        def get_recommendations(movies, emb_type, preferences, pref_weight):
            """Get recommendations: retrieval phase only, then delegate to ranking_agent with streaming."""
            if not movies and not preferences:
                yield "Please select some movies or provide preferences"
                return
            try:
                # RETRIEVAL PHASE: Get top 100 candidates using embedding aggregation
                print("\n=== RETRIEVAL PHASE ===")
                print(f"Selected movies: {movies}")
                print(f"User preferences: '{preferences}'")
                print(f"Alpha weight: {pref_weight}")
                print(f"Embedding type: {emb_type}")

                yield "🔍 Searching for similar movies..."

                recommendations = recommender.get_recommendations(
                    selected_movies=movies,
                    embedding_type=emb_type,
                    user_preferences=preferences,
                    alpha=pref_weight
                )

                # Handle error cases
                if isinstance(recommendations, str):
                    yield recommendations
                    return

                # Print retrieval results
                print(f"\nRETRIEVAL RESULTS: Found {len(recommendations)} candidates")
                print("Top 100 from retrieval phase:")
                for i, (title, score) in enumerate(recommendations[:100], 1):
                    print(f"  {i:2d}. {title} (score: {score:.3f})")

                # RERANKING + EXPLANATION PHASE: Delegate to ranking_agent with streaming
                print("\n=== RERANKING PHASE ===")
                print("Calling rank_with_ai with:")
                print(f"  - {len(recommendations)} recommendations")
                print(f"  - preferences: '{preferences}'")
                print(f"  - alpha: {pref_weight}")
                print(f"  - user_movies: {movies}")

                yield "🤖 AI is ranking and explaining your recommendations..."

                # Stream the responses from the ranking agent
                for partial_result in rank_with_ai(
                    recommendations=recommendations,
                    user_preferences=preferences,
                    alpha=pref_weight,
                    user_movies=movies
                ):
                    yield partial_result

            except Exception as e:
                print(f"ERROR in get_recommendations: {str(e)}")
                import traceback
                traceback.print_exc()
                yield f"Error getting recommendations: {str(e)}"
        # Event handlers
        movie_search_input.change(
            fn=update_search_results,
            inputs=movie_search_input,
            outputs=search_results
        )

        search_results.change(
            fn=handle_movie_selection,
            inputs=[search_results, selected_movies],
            outputs=[selected_movies, selected_display] + delete_buttons
        )

        # Add individual delete button handlers
        for i, btn in enumerate(delete_buttons):
            def make_delete_handler(btn_idx):
                def delete_handler(current_movies):
                    updated_movies, updated_display = delete_movie_by_index(btn_idx, current_movies)
                    return [updated_movies, updated_display] + update_delete_buttons_visibility(updated_movies)
                return delete_handler

            btn.click(
                fn=make_delete_handler(i),
                inputs=[selected_movies],
                outputs=[selected_movies, selected_display] + delete_buttons
            )

        clear_btn.click(
            fn=clear_all_movies,
            inputs=[],
            outputs=[selected_movies, selected_display] + delete_buttons
        )

        recommend_btn.click(
            fn=get_recommendations,
            inputs=[selected_movies, embedding_type, user_preferences, alpha],
            outputs=recommendations
        )

    return iface


if __name__ == "__main__":
    iface = create_interface()
    if iface is not None:
        iface.launch()
    else:
        print("\nPlease fix the issues above and try again.")