import os
from os import path as osp
from pathlib import Path

import gradio as gr
import lancedb
from dotenv import load_dotenv
from langchain_core.runnables import (
    RunnableParallel,
    RunnablePassthrough,
    RunnableLambda,
)
from mistralai import Mistral
from PIL import Image

from crud.vector_store import MultimodalLanceDB
from preprocess.embedding import BridgeTowerEmbeddings
from preprocess.preprocessing import extract_and_save_frames_and_metadata
from utils import (
    download_video,
    download_youtube_subtitle,
    encode_image,
)

load_dotenv()

# Detect whether we are running inside a Hugging Face Space.
SPACE_ID = os.getenv("SPACE_ID")
IS_SPACES = SPACE_ID is not None

if IS_SPACES:
    LANCEDB_HOST_FILE = "/tmp/.lancedb"
    VIDEO_DIR = "/tmp/videos/video1"
else:
    LANCEDB_HOST_FILE = "./shared_data/.lancedb"
    VIDEO_DIR = "./shared_data/videos/video1"

# Make sure the video working directory exists.
os.makedirs(VIDEO_DIR, exist_ok=True)

TBL_NAME = "vectorstore"

db = lancedb.connect(LANCEDB_HOST_FILE)
embedder = BridgeTowerEmbeddings()


def preprocess_and_store(youtube_url: str):
    """Download a video, extract frames + metadata, then embed and store in LanceDB."""
    try:
        video_url = youtube_url
        video_dir = VIDEO_DIR

        # Create the target directories before downloading anything into them.
        extracted_frames_path = osp.join(video_dir, 'extracted_frame')
        Path(video_dir).mkdir(parents=True, exist_ok=True)
        Path(extracted_frames_path).mkdir(parents=True, exist_ok=True)

        video_filepath = download_video(video_url, video_dir)
        video_transcript_filepath = download_youtube_subtitle(video_url, video_dir)

        metadatas = extract_and_save_frames_and_metadata(
            video_filepath,
            video_transcript_filepath,
            extracted_frames_path,
            video_dir,
        )

        video_trans = [vid['transcript'] for vid in metadatas]
        video_img_path = [vid['extracted_frame_path'] for vid in metadatas]

        # Augment each caption with a sliding window of n neighbouring captions
        # so that individual frames carry enough textual context for retrieval.
        n = 7
        half = n // 2
        updated_video_trans = [
            ' '.join(video_trans[max(0, i - half): i + half])
            for i in range(len(video_trans))
        ]
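
        # Worked example (hypothetical captions; with n = 7, half = 3):
        #   i = 0 -> joins captions 0..2 (window truncated at the start)
        #   i = 5 -> joins captions 2..7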

        for i in range(len(updated_video_trans)):
            metadatas[i]['transcript'] = updated_video_trans[i]

        # Embed the (caption window, frame) pairs and overwrite the table.
        _ = MultimodalLanceDB.from_text_image_pairs(
            texts=updated_video_trans,
            image_paths=video_img_path,
            embedding=embedder,
            metadatas=metadatas,
            connection=db,
            table_name=TBL_NAME,
            mode="overwrite",
        )

        return f"✅ Video processed and stored: {youtube_url}"

    except Exception as e:
        return f"❌ Error processing video: {str(e)}"
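

# Usage sketch (any publicly accessible YouTube URL; the one below is a placeholder):
#   status = preprocess_and_store("https://www.youtube.com/watch?v=...")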


# Open the table (created by preprocess_and_store) for retrieval.
vectorstore = MultimodalLanceDB(
    uri=LANCEDB_HOST_FILE,
    embedding=embedder,
    table_name=TBL_NAME,
)

retriever_module = vectorstore.as_retriever(
    search_type="similarity",
    search_kwargs={"k": 3},
)
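
# Shape of one retrieved result (field names match the metadata written in
# preprocess_and_store; the query string is just an illustration):
#   docs = retriever_module.invoke("What is shown on the whiteboard?")
#   docs[0].metadata["transcript"]            -> caption-window text
#   docs[0].metadata["extracted_frame_path"]  -> path of the matching frame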


def prompt_processing(inputs):
    """Build the LVLM prompt from the top retrieved result."""
    retrieved_results = inputs["retrieved_results"]
    user_query = inputs["user_query"]

    if not retrieved_results:
        return {"prompt": "No relevant content found.", "frame_path": None}

    # Keep only the best match (k=3 results were retrieved above).
    top_result = retrieved_results[0]
    prompt_template = (
        "The transcript associated with the image is '{transcript}'. "
        "{user_query}"
    )

    retrieved_metadata = top_result.metadata
    transcript = retrieved_metadata["transcript"]
    frame_path = retrieved_metadata["extracted_frame_path"]

    return {
        "prompt": prompt_template.format(transcript=transcript, user_query=user_query),
        "frame_path": frame_path,
    }


def lvlm_inference(inputs):
    frame_path = None  # defined up front so the except branch can always reference it
    try:
        lvlm_prompt = inputs['prompt']
        frame_path = inputs['frame_path']

        if frame_path is None:
            return "No relevant frame found.", None

        api_key = os.getenv("MISTRAL_API_KEY")
        if not api_key:
            return "❌ MISTRAL_API_KEY not found. Please set it in the environment variables.", frame_path

        client = Mistral(api_key=api_key)

        # Pixtral expects the image inline as a base64 data URL.
        base64_image = encode_image(frame_path)

        messages = [
            {
                "role": "user",
                "content": [
                    {
                        "type": "text",
                        "text": lvlm_prompt,
                    },
                    {
                        "type": "image_url",
                        "image_url": f"data:image/jpeg;base64,{base64_image}",
                    },
                ],
            }
        ]

        chat_response = client.chat.complete(
            model="pixtral-12b-2409",
            messages=messages,
        )

        return chat_response.choices[0].message.content, frame_path

    except Exception as e:
        return f"❌ Error in inference: {str(e)}", frame_path


prompt_processing_module = RunnableLambda(prompt_processing)
lvlm_inference_module = RunnableLambda(lvlm_inference)

# Retrieval and the raw query run in parallel, then feed prompt building
# and finally the vision-language model.
mm_rag_chain = (
    RunnableParallel({"retrieved_results": retriever_module, "user_query": RunnablePassthrough()})
    | prompt_processing_module
    | lvlm_inference_module
)
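
# Usage sketch (assumes a video has been processed and stored first):
#   answer_text, frame_path = mm_rag_chain.invoke("What is the speaker demonstrating?")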


video_loaded = False


def load_video(youtube_url):
    global video_loaded
    if not youtube_url.strip():
        return "❌ Please enter a YouTube URL"

    try:
        status = preprocess_and_store(youtube_url)
        if "✅" in status:
            video_loaded = True
        return status
    except Exception as e:
        return f"❌ Error loading video: {str(e)}"


def chat_interface(message, history):
    if not video_loaded:
        return "", history + [(message, "❌ Please load a video first in the 'Load Video' tab.")], None

    if not message.strip():
        return "", history, None

    try:
        final_text_response, frame_path = mm_rag_chain.invoke(message)
        history.append((message, final_text_response))

        retrieved_image = None
        if frame_path:
            try:
                retrieved_image = Image.open(frame_path)
            except Exception as e:
                print(f"Error loading image: {e}")

        return "", history, retrieved_image

    except Exception as e:
        error_msg = f"❌ Error processing query: {str(e)}"
        history.append((message, error_msg))
        return "", history, None


def clear_chat():
    return [], None


with gr.Blocks(
    title="Multimodal RAG Video Chat",
    theme=gr.themes.Default()
) as demo:
    gr.Markdown("""
    # 🎬 Multimodal RAG Video Chat

    Chat with YouTube videos using BridgeTower embeddings + LanceDB + the Pixtral vision-language model!

    ⚠️ **Important**: You need to set your `MISTRAL_API_KEY` in the Space settings for this to work.
    """)

    with gr.Tab("1. Load Video"):
        with gr.Column():
            youtube_url = gr.Textbox(
                label="YouTube URL",
                placeholder="https://www.youtube.com/watch?v=...",
                lines=1,
                scale=4,
            )
            with gr.Row():
                load_btn = gr.Button("🚀 Process Video", variant="primary", scale=1)
            status = gr.Textbox(
                label="Status",
                interactive=False,
                lines=2,
            )

        load_btn.click(
            fn=load_video,
            inputs=youtube_url,
            outputs=status,
            show_progress=True,
        )

    with gr.Tab("2. Chat with Video"):
        with gr.Row():
            with gr.Column(scale=2):
                chatbot = gr.Chatbot(
                    label="Chat about the video",
                    height=500,
                )

            with gr.Column(scale=1):
                retrieved_image = gr.Image(
                    label="Retrieved Frame",
                    height=400,
                    show_label=True,
                    interactive=False,
                )

        with gr.Row():
            with gr.Column(scale=4):
                msg = gr.Textbox(
                    label="Your question",
                    placeholder="Ask something about the video content...",
                    lines=2,
                    container=False,
                )
            with gr.Column(scale=1, min_width=100):
                send_btn = gr.Button("📤 Send", variant="primary")
                clear_btn = gr.Button("🗑️ Clear", variant="secondary")

        msg.submit(
            fn=chat_interface,
            inputs=[msg, chatbot],
            outputs=[msg, chatbot, retrieved_image],
            show_progress=True,
        )
        send_btn.click(
            fn=chat_interface,
            inputs=[msg, chatbot],
            outputs=[msg, chatbot, retrieved_image],
            show_progress=True,
        )
        clear_btn.click(
            fn=clear_chat,
            outputs=[chatbot, retrieved_image],
        )

    with gr.Tab("📖 Instructions"):
        gr.Markdown("""
        ## How to use this Multimodal RAG system:

        ### 🔧 Setup:
        1. **Set API Key**: Make sure `MISTRAL_API_KEY` is set in your Space settings
        2. This app uses Pixtral-12B for vision-language understanding

        ### 🎥 Load Video:
        1. Go to the "Load Video" tab
        2. Paste a YouTube URL (make sure it's publicly accessible)
        3. Click "🚀 Process Video" and wait for processing to complete
        4. Look for the ✅ success message

        ### 💬 Chat with Video:
        1. Go to the "Chat with Video" tab
        2. Ask questions about the video content
        3. The system will retrieve the most relevant frame and provide answers
        4. The retrieved frame will be displayed on the right side

        ## ✨ Features:
        - 🎥 **Automatic YouTube Processing**: Downloads and processes YouTube videos
        - 🧠 **Multimodal Embeddings**: Uses BridgeTower for combined text+image understanding
        - 💾 **Vector Storage**: Stores data in LanceDB for fast similarity search
        - 🤖 **Vision-Language AI**: Powered by Mistral's Pixtral model
        - 🖼️ **Visual Context**: Shows relevant video frames alongside responses
        - 🚀 **Real-time Processing**: Fast retrieval and inference

        ## ⚠️ Limitations:
        - Works with publicly accessible YouTube videos only
        - Processing time depends on video length
        - Requires a stable internet connection for video download
        - API rate limits apply based on Mistral usage

        ## 🛠️ Technical Stack:
        - **Embeddings**: BridgeTower (multimodal)
        - **Vector DB**: LanceDB
        - **Vision-Language Model**: Pixtral-12B
        - **Framework**: LangChain + Gradio
        """)

    with gr.Tab("📊 About"):
        gr.Markdown("""
        ## Multimodal RAG Video Chat System

        This application demonstrates a complete multimodal Retrieval-Augmented Generation (RAG) pipeline that can understand and answer questions about video content.

        ### Architecture:
        1. **Video Processing**: Downloads YouTube videos and extracts frames with timestamps
        2. **Multimodal Embedding**: Uses BridgeTower to create embeddings that understand both visual and textual content
        3. **Vector Storage**: Stores embeddings in LanceDB for efficient similarity search
        4. **Retrieval**: Finds the most relevant video segments based on user queries
        5. **Generation**: Uses the Pixtral vision-language model to generate contextual responses
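
        The wiring mirrors the LangChain chain defined in this file (simplified sketch):

        ```python
        mm_rag_chain = (
            RunnableParallel({"retrieved_results": retriever, "user_query": RunnablePassthrough()})
            | RunnableLambda(prompt_processing)   # build the LVLM prompt from the top hit
            | RunnableLambda(lvlm_inference)      # call Pixtral with prompt + frame
        )
        ```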

        ### Built with:
        - **Gradio**: For the web interface
        - **LangChain**: For orchestrating the RAG pipeline
        - **LanceDB**: For vector storage and retrieval
        - **BridgeTower**: For multimodal embeddings
        - **Mistral Pixtral**: For vision-language understanding

        ---

        💡 **Tip**: For best results, ask specific questions about visual content, actions, or scenes in the video.
        """)


if __name__ == "__main__":
    print('🚀 Starting Multimodal RAG Video Chat App...')

    if not os.getenv("MISTRAL_API_KEY"):
        print("⚠️ WARNING: MISTRAL_API_KEY not found in environment variables")
        print("   Please set this in your HuggingFace Space settings")

    # The same launch settings work both locally and on Spaces.
    demo.launch(share=True, server_name="0.0.0.0", server_port=7860)