from pathlib import Path
import os
from os import path as osp
import gradio as gr
from dotenv import load_dotenv
from crud.vector_store import MultimodalLanceDB
from preprocess.embedding import BridgeTowerEmbeddings
from preprocess.preprocessing import extract_and_save_frames_and_metadata
from utils import (
download_video,
get_transcript_vtt,
download_youtube_subtitle,
get_video_id_from_url,
str2time,
maintain_aspect_ratio_resize,
getSubs,
encode_image,
)
from mistralai import Mistral
from langchain_core.runnables import (
RunnableParallel,
RunnablePassthrough,
RunnableLambda
)
from PIL import Image
import lancedb
# -------------------------------
# 1. Setup - HuggingFace Spaces Configuration
# -------------------------------
load_dotenv()
# HuggingFace Spaces specific setup
SPACE_ID = os.getenv("SPACE_ID")
IS_SPACES = SPACE_ID is not None
if IS_SPACES:
LANCEDB_HOST_FILE = "/tmp/.lancedb"
VIDEO_DIR = "/tmp/videos/video1"
os.makedirs("/tmp", exist_ok=True)
else:
LANCEDB_HOST_FILE = "./shared_data/.lancedb"
VIDEO_DIR = "./shared_data/videos/video1"
TBL_NAME = "vectorstore"
# Initialize components
db = lancedb.connect(LANCEDB_HOST_FILE)
embedder = BridgeTowerEmbeddings()
# -------------------------------
# 2. Preprocessing + Storage
# -------------------------------
def preprocess_and_store(youtube_url: str):
"""Download video, extract frames+metadata, embed & store in LanceDB"""
try:
video_url = youtube_url
video_dir = VIDEO_DIR
        # download the YouTube video
video_filepath = download_video(video_url, video_dir)
        # download the YouTube video's subtitles
video_transcript_filepath = download_youtube_subtitle(video_url, video_dir)
extracted_frames_path = osp.join(video_dir, 'extracted_frame')
# create these output folders if not existing
Path(extracted_frames_path).mkdir(parents=True, exist_ok=True)
Path(video_dir).mkdir(parents=True, exist_ok=True)
        # extract frames and their metadata
metadatas = extract_and_save_frames_and_metadata(
video_filepath,
video_transcript_filepath,
extracted_frames_path,
video_dir,
)
# collect transcripts and image paths
video_trans = [vid['transcript'] for vid in metadatas]
video_img_path = [vid['extracted_frame_path'] for vid in metadatas]
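        # Widen each caption with the captions around it (a window of roughly n
        # transcript segments) so short snippets still embed with enough context.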
n = 7
updated_video_trans = [
' '.join(video_trans[i-int(n/2) : i+int(n/2)]) if i-int(n/2) >= 0 else
' '.join(video_trans[0 : i + int(n/2)]) for i in range(len(video_trans))
]
# also need to update the updated transcripts in metadata
for i in range(len(updated_video_trans)):
metadatas[i]['transcript'] = updated_video_trans[i]
_ = MultimodalLanceDB.from_text_image_pairs(
texts=updated_video_trans,
image_paths=video_img_path,
embedding=embedder,
metadatas=metadatas,
connection=db,
table_name=TBL_NAME,
mode="overwrite",
)
return f"βœ… Video processed and stored: {youtube_url}"
except Exception as e:
return f"❌ Error processing video: {str(e)}"
# -------------------------------
# 3. Retrieval + Prompt Functions
# -------------------------------
vectorstore = MultimodalLanceDB(
uri=LANCEDB_HOST_FILE,
embedding=embedder,
table_name=TBL_NAME
)
retriever_module = vectorstore.as_retriever(
search_type="similarity",
search_kwargs={"k": 3}
)
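# Illustrative retriever call (assumes preprocess_and_store has populated the
# table); each returned Document carries the frame metadata used below:
#   docs = retriever_module.invoke("What is shown in the opening scene?")
#   docs[0].metadata["transcript"], docs[0].metadata["extracted_frame_path"]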
def prompt_processing(input):
retrieved_results = input["retrieved_results"]
user_query = input["user_query"]
if not retrieved_results:
return {"prompt": "No relevant content found.", "frame_path": None}
retrieved_results = retrieved_results[0]
prompt_template = (
"The transcript associated with the image is '{transcript}'. "
"{user_query}"
)
retrieved_metadata = retrieved_results.metadata
transcript = retrieved_metadata["transcript"]
frame_path = retrieved_metadata["extracted_frame_path"]
return {
"prompt": prompt_template.format(transcript=transcript, user_query=user_query),
"frame_path": frame_path,
}
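# Shape of the dict returned above (values are illustrative):
#   {"prompt": "The transcript associated with the image is '...'. <user question>",
#    "frame_path": "<VIDEO_DIR>/extracted_frame/<frame>.jpg"}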
def lvlm_inference(input):
try:
        # get the composed prompt and the retrieved frame path from prompt_processing
        lvlm_prompt = input.get('prompt')
        frame_path = input.get('frame_path')
if frame_path is None:
return "No relevant frame found.", None
# Retrieve the API key from environment variables
api_key = os.getenv("MISTRAL_API_KEY")
if not api_key:
return "❌ MISTRAL_API_KEY not found. Please set it in the environment variables.", frame_path
# Initialize the Mistral client
client = Mistral(api_key=api_key)
base64_image = encode_image(frame_path)
# Define the messages for the chat
messages = [
{
"role": "user",
"content": [
{
"type": "text",
"text": lvlm_prompt
},
{
"type": "image_url",
"image_url": f"data:image/jpeg;base64,{base64_image}"
}
]
}
]
# Get the chat response
chat_response = client.chat.complete(
model="pixtral-12b-2409",
messages=messages
)
return chat_response.choices[0].message.content, frame_path
except Exception as e:
return f"❌ Error in inference: {str(e)}", frame_path
# LangChain Runnable chain
prompt_processing_module = RunnableLambda(prompt_processing)
lvlm_inference_module = RunnableLambda(lvlm_inference)
mm_rag_chain = (
RunnableParallel({"retrieved_results": retriever_module, "user_query": RunnablePassthrough()})
| prompt_processing_module
| lvlm_inference_module
)
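# Usage sketch: the chain takes a plain-text question and returns the tuple
# produced by lvlm_inference, i.e. (answer_text, frame_path):
#   answer, frame_path = mm_rag_chain.invoke("What happens at the start of the video?")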
# -------------------------------
# 4. Chat API for Gradio
# -------------------------------
video_loaded = False
def load_video(youtube_url):
global video_loaded
if not youtube_url.strip():
return "❌ Please enter a YouTube URL"
try:
status = preprocess_and_store(youtube_url)
if "βœ…" in status:
video_loaded = True
return status
except Exception as e:
return f"❌ Error loading video: {str(e)}"
def chat_interface(message, history):
if not video_loaded:
return "", history + [(message, "❌ Please load a video first in the 'Load Video' tab.")], None
if not message.strip():
return "", history, None
try:
final_text_response, frame_path = mm_rag_chain.invoke(message)
history.append((message, final_text_response))
# Load and return the image
retrieved_image = None
if frame_path:
try:
retrieved_image = Image.open(frame_path)
except Exception as e:
print(f"Error loading image: {e}")
return "", history, retrieved_image
except Exception as e:
error_msg = f"❌ Error processing query: {str(e)}"
history.append((message, error_msg))
return "", history, None
def clear_chat():
return [], None
# -------------------------------
# 5. Enhanced Gradio Interface
# -------------------------------
with gr.Blocks(
title="Multimodal RAG Video Chat",
theme=gr.themes.Default()
) as demo:
gr.Markdown("""
    # 🎬 Multimodal RAG Video Chat
Chat with YouTube videos using BridgeTower embeddings + LanceDB + Pixtral Vision-Language Model!
⚠️ **Important**: You need to set your `MISTRAL_API_KEY` in the Space settings for this to work.
""")
with gr.Tab("1. Load Video"):
with gr.Column():
youtube_url = gr.Textbox(
label="YouTube URL",
placeholder="https://www.youtube.com/watch?v=...",
lines=1,
scale=4
)
with gr.Row():
load_btn = gr.Button("πŸ”„ Process Video", variant="primary", scale=1)
status = gr.Textbox(
label="Status",
interactive=False,
lines=2
)
load_btn.click(
fn=load_video,
inputs=youtube_url,
outputs=status,
show_progress=True
)
with gr.Tab("2. Chat with Video"):
with gr.Row():
with gr.Column(scale=2):
chatbot = gr.Chatbot(
label="Chat about the video",
height=500
)
with gr.Column(scale=1):
retrieved_image = gr.Image(
label="Retrieved Frame",
height=400,
show_label=True,
interactive=False
)
with gr.Row():
with gr.Column(scale=4):
msg = gr.Textbox(
label="Your question",
placeholder="Ask something about the video content...",
lines=2,
container=False
)
with gr.Column(scale=1, min_width=100):
send_btn = gr.Button("πŸ“€ Send", variant="primary")
clear_btn = gr.Button("πŸ—‘οΈ Clear", variant="secondary")
# Event handlers
msg.submit(
fn=chat_interface,
inputs=[msg, chatbot],
outputs=[msg, chatbot, retrieved_image],
show_progress=True
)
send_btn.click(
fn=chat_interface,
inputs=[msg, chatbot],
outputs=[msg, chatbot, retrieved_image],
show_progress=True
)
clear_btn.click(
fn=clear_chat,
outputs=[chatbot, retrieved_image]
)
with gr.Tab("πŸ“– Instructions"):
gr.Markdown("""
## How to use this Multimodal RAG system:
        ### 🔧 Setup:
        1. **Set API Key**: Make sure `MISTRAL_API_KEY` is set in your Space settings
        2. This app uses Pixtral-12B for vision-language understanding
        ### 🎥 Load Video:
        1. Go to the "Load Video" tab
        2. Paste a YouTube URL (make sure it's publicly accessible)
        3. Click "🔄 Process Video" and wait for processing to complete
        4. Look for the ✅ success message
        ### 💬 Chat with Video:
1. Go to the "Chat with Video" tab
2. Ask questions about the video content
3. The system will retrieve the most relevant frame and provide answers
4. The retrieved frame will be displayed on the right side
## ✨ Features:
        - 🎥 **Automatic YouTube Processing**: Downloads and processes YouTube videos
        - 🧠 **Multimodal Embeddings**: Uses BridgeTower for combined text+image understanding
        - 💾 **Vector Storage**: Stores data in LanceDB for fast similarity search
        - 🤖 **Vision-Language AI**: Powered by Mistral's Pixtral model
        - 🖼️ **Visual Context**: Shows relevant video frames alongside responses
        - 🔄 **Real-time Processing**: Fast retrieval and inference
        ## ⚠️ Limitations:
- Works with publicly accessible YouTube videos only
- Processing time depends on video length
- Requires stable internet connection for video download
- API rate limits apply based on Mistral usage
        ## 🛠️ Technical Stack:
- **Embeddings**: BridgeTower (multimodal)
- **Vector DB**: LanceDB
- **Vision-Language Model**: Pixtral-12B
- **Framework**: LangChain + Gradio
""")
with gr.Tab("πŸ” About"):
gr.Markdown("""
## Multimodal RAG Video Chat System
This application demonstrates a complete multimodal Retrieval-Augmented Generation (RAG) pipeline that can understand and answer questions about video content.
### Architecture:
1. **Video Processing**: Downloads YouTube videos and extracts frames with timestamps
2. **Multimodal Embedding**: Uses BridgeTower to create embeddings that understand both visual and textual content
3. **Vector Storage**: Stores embeddings in LanceDB for efficient similarity search
4. **Retrieval**: Finds the most relevant video segments based on user queries
5. **Generation**: Uses Pixtral vision-language model to generate contextual responses
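        In pseudocode, one chat turn looks roughly like this (a simplified sketch; `build_prompt` and `pixtral` stand in for the actual chain steps):
        ```python
        retrieved = retriever.invoke(user_query)            # top-k frame/transcript pairs
        prompt = build_prompt(retrieved[0], user_query)     # transcript + question
        answer = pixtral(prompt, image=retrieved[0].frame)  # vision-language response
        ```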
### Built with:
- **Gradio**: For the web interface
- **LangChain**: For orchestrating the RAG pipeline
- **LanceDB**: For vector storage and retrieval
- **BridgeTower**: For multimodal embeddings
- **Mistral Pixtral**: For vision-language understanding
---
        💡 **Tip**: For best results, ask specific questions about visual content, actions, or scenes in the video.
""")
# -------------------------------
# 6. Launch Configuration
# -------------------------------
if __name__ == "__main__":
    print('🚀 Starting Multimodal RAG Video Chat App...')
# Check for required environment variables
if not os.getenv("MISTRAL_API_KEY"):
print("⚠️ WARNING: MISTRAL_API_KEY not found in environment variables")
print(" Please set this in your HuggingFace Space settings")
    # Launch with settings appropriate for HF Spaces vs. local runs
    if IS_SPACES:
        # Spaces serves the app directly, so no public share link is needed
        demo.launch(server_name="0.0.0.0", server_port=7860)
    else:
        demo.launch(share=True, server_name="0.0.0.0", server_port=7860)