import streamlit as st
import pandas as pd
import numpy as np
from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity
import torch
import json
import os
import glob
from pathlib import Path
from datetime import datetime, timedelta
import edge_tts
import asyncio
import requests
from collections import defaultdict
import streamlit.components.v1 as components
from urllib.parse import quote
from xml.etree import ElementTree as ET
from datasets import load_dataset
import base64
import re

# Default values for every st.session_state key this script seeds up front.
# Each key is written only if it is missing, so values set during earlier
# script runs of the same session are preserved.
# NOTE: the 'recording' flag used by the Voice tab is not listed here; it is
# created on demand and read with .get().
SESSION_VARS = {
    'search_history': [],          # Track search history
    'last_voice_input': "",        # Last voice input
    'transcript_history': [],      # Conversation history
    'should_rerun': False,         # Trigger for UI updates
    'search_columns': [],          # Available search columns
    'initial_search_done': False,  # First search flag
    'tts_voice': "en-US-AriaNeural", # Default voice
    'arxiv_last_query': "",        # Last ArXiv search
    'dataset_loaded': False,       # Dataset load status
    'current_page': 0,            # Current data page
    'data_cache': None,           # Data cache
    'dataset_info': None,         # Dataset metadata
    'nps_submitted': False,       # Track if user submitted NPS
    'nps_last_shown': None,       # When NPS was last shown
    'old_val': None,              # Previous voice input value
    'voice_text': None            # Processed voice text
}

# Constants
ROWS_PER_PAGE = 100       # Rows fetched per dataset page (see FastDatasetSearcher.load_page)
MIN_SEARCH_SCORE = 0.3    # Rows scoring above this are kept even without a keyword match
EXACT_MATCH_BOOST = 2.0   # Score multiplier applied when a row contains an exact match

# Seed session state: add each default only when the key does not exist yet.
for var, default in SESSION_VARS.items():
    if var not in st.session_state:
        st.session_state[var] = default

# Voice Component Setup
def create_voice_component():
    """Declare the custom ``mycomponent`` Streamlit component and return it.

    The component is declared with ``path="mycomponent"``; the returned
    object is what callers invoke to render the voice-input widget.
    """
    return components.declare_component("mycomponent", path="mycomponent")

# Utility Functions
def clean_for_speech(text: str) -> str:
    """Normalise ``text`` so it reads well when passed to a TTS engine.

    Newlines and ``</s>`` markers become spaces, ``#`` characters are
    dropped, parenthesised http(s) URLs (the target part of a markdown link)
    are removed, and runs of whitespace collapse to a single space.

    Returns:
        The cleaned text with no leading or trailing whitespace.
    """
    # Applied in this order: literal replacements first, then the regex passes.
    for token, substitute in (("\n", " "), ("</s>", " "), ("#", "")):
        text = text.replace(token, substitute)
    link_free = re.sub(r"\(https?:\/\/[^\)]+\)", "", text)
    collapsed = re.sub(r"\s+", " ", link_free)
    return collapsed.strip()

async def edge_tts_generate_audio(text, voice="en-US-AriaNeural", rate=0, pitch=0):
    """Synthesize ``text`` to an MP3 file with Edge TTS.

    Args:
        text: Raw text; it is passed through ``clean_for_speech`` first.
        voice: Edge TTS voice name.
        rate: Speaking-rate offset, formatted as a signed percentage
            (e.g. ``+0%``). Non-integer numbers are truncated to ``int``.
        pitch: Pitch offset, formatted as a signed value in Hz (e.g. ``+0Hz``).
            Non-integer numbers are truncated to ``int``.

    Returns:
        The name of the MP3 file written to the current working directory,
        or ``None`` when nothing speakable remains after cleaning.
    """
    text = clean_for_speech(text)
    if not text.strip():
        return None
    # int() keeps the ":+d" format valid when a float (e.g. from a slider)
    # is passed; ":+d" raises ValueError for floats.
    rate_str = f"{int(rate):+d}%"
    pitch_str = f"{int(pitch):+d}Hz"
    communicate = edge_tts.Communicate(text, voice, rate=rate_str, pitch=pitch_str)
    # Microseconds (%f) keep names distinct when several clips are generated
    # within the same second; a seconds-only stamp made them overwrite each
    # other.
    out_fn = f"speech_{datetime.now().strftime('%Y%m%d_%H%M%S_%f')}.mp3"
    await communicate.save(out_fn)
    return out_fn

def speak_with_edge_tts(text, voice="en-US-AriaNeural", rate=0, pitch=0):
    """Blocking wrapper around ``edge_tts_generate_audio``.

    Runs the coroutine to completion with ``asyncio.run`` and returns its
    result (the generated file name, or ``None``).
    """
    synthesis = edge_tts_generate_audio(text, voice, rate, pitch)
    return asyncio.run(synthesis)

def play_and_download_audio(file_path):
    """Show an audio player and a download link for ``file_path``.

    Does nothing when ``file_path`` is falsy or the file does not exist.
    The download link embeds the whole file as a base64 ``data:`` URI.

    Args:
        file_path: Path of the audio file to expose, or ``None``.
    """
    if not file_path or not os.path.exists(file_path):
        return

    st.audio(file_path)

    # Read through a context manager so the file handle is closed promptly
    # instead of lingering until garbage collection.
    with open(file_path, "rb") as audio_file:
        encoded = base64.b64encode(audio_file.read()).decode()

    file_name = os.path.basename(file_path)
    dl_link = f'<a href="data:audio/mpeg;base64,{encoded}" download="{file_name}">Download {file_name}</a>'
    st.markdown(dl_link, unsafe_allow_html=True)

@st.cache_resource
def get_model():
    """Return the ``all-MiniLM-L6-v2`` sentence-transformer model.

    Wrapped in ``st.cache_resource`` so the model object is built once and
    reused rather than re-created on every call.
    """
    model_name = 'all-MiniLM-L6-v2'
    return SentenceTransformer(model_name)

@st.cache_data
def _load_dataset_page_cached(dataset_id, token, page, rows_per_page):
    """Fetch one page of the ``train`` split as a DataFrame; raises on failure."""
    start_idx = page * rows_per_page
    end_idx = start_idx + rows_per_page
    dataset = load_dataset(
        dataset_id,
        token=token,
        streaming=False,
        split=f'train[{start_idx}:{end_idx}]'
    )
    return pd.DataFrame(dataset)


def load_dataset_page(dataset_id, token, page, rows_per_page):
    """Load one page of the dataset's ``train`` split, with caching.

    Errors are handled outside the cached helper on purpose: catching them
    inside the cached function made it return normally, so the empty
    fallback DataFrame was cached and a transient failure stuck for that
    argument combination. Now only successful loads are cached.

    Args:
        dataset_id: Hugging Face dataset identifier.
        token: Access token forwarded to ``load_dataset``.
        page: Zero-based page number.
        rows_per_page: Number of rows per page.

    Returns:
        A DataFrame holding rows ``page * rows_per_page`` up to (but not
        including) ``(page + 1) * rows_per_page``, or an empty DataFrame if
        loading failed (the error is reported with ``st.error``).
    """
    try:
        return _load_dataset_page_cached(dataset_id, token, page, rows_per_page)
    except Exception as e:
        st.error(f"Error loading page {page}: {str(e)}")
        return pd.DataFrame()


# Keep the cache-clearing handle available under the public name.
load_dataset_page.clear = _load_dataset_page_cached.clear

@st.cache_data
def _get_dataset_info_cached(dataset_id, token):
    """Return the ``train`` split's info object; raises on failure."""
    dataset = load_dataset(dataset_id, token=token, streaming=True)
    return dataset['train'].info


def get_dataset_info(dataset_id, token):
    """Get the dataset's ``train`` split info, with caching.

    Errors are handled outside the cached helper on purpose: catching them
    inside the cached function made it return normally, so a failed lookup
    cached ``None`` and later retries could never succeed. Now only
    successful lookups are cached.

    Args:
        dataset_id: Hugging Face dataset identifier.
        token: Access token forwarded to ``load_dataset``.

    Returns:
        The ``info`` attribute of the streamed ``train`` split, or ``None``
        if loading failed (the error is reported with ``st.error``).
    """
    try:
        return _get_dataset_info_cached(dataset_id, token)
    except Exception as e:
        st.error(f"Error loading dataset info: {str(e)}")
        return None


# Keep the cache-clearing handle available under the public name.
get_dataset_info.clear = _get_dataset_info_cached.clear

def fetch_dataset_info(dataset_id):
    """Query the Hugging Face datasets API for ``dataset_id``.

    Returns:
        The decoded JSON body when the API answers with HTTP 200, otherwise
        ``None``. Any exception raised by the request or by JSON decoding is
        shown with ``st.warning`` and also yields ``None``.
    """
    endpoint = f"https://huggingface.co/api/datasets/{dataset_id}"
    try:
        reply = requests.get(endpoint, timeout=30)
        return reply.json() if reply.status_code == 200 else None
    except Exception as e:
        st.warning(f"Error fetching dataset info: {e}")
        return None

def generate_filename(text):
    """Build a ``<timestamp>_<slug>`` name from the current time and ``text``.

    The slug comes from the first 50 characters of ``text``: characters other
    than word characters, whitespace and hyphens are dropped, the result is
    trimmed and lower-cased, and each run of hyphens/whitespace becomes a
    single hyphen. The timestamp has the form ``YYYYMMDD_HHMMSS``.
    """
    stamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    head = text[:50]
    kept = re.sub(r'[^\w\s-]', '', head).strip().lower()
    slug = re.sub(r'[-\s]+', '-', kept)
    return "_".join((stamp, slug))

def render_result(result):
    """Render a single search result: video, fields, score and TTS controls.

    Args:
        result: A mapping-like row (``pd.Series`` or ``dict``) supporting
            ``.get``, ``.items`` and ``in``. Recognised optional keys:
            ``youtube_id`` / ``start_time`` (embedded video), ``video_id``
            (used in widget keys), and ``relevance_score`` or ``score``
            (shown as the relevance metric).
    """
    # FastDatasetSearcher.quick_search stores its ranking under 'score'.
    # Reading only 'relevance_score' made the metric always show 0%, so
    # 'score' is the fallback and 'relevance_score' still wins when present.
    score = result.get('relevance_score')
    if score is None:
        score = result.get('score', 0)

    # Bookkeeping columns and embedding columns are not shown as fields.
    hidden_fields = {
        'relevance_score', 'score', 'matched',
        'video_embed', 'description_embed', 'audio_embed',
    }
    result_filtered = {k: v for k, v in result.items() if k not in hidden_fields}

    if 'youtube_id' in result:
        st.video(f"https://youtube.com/watch?v={result['youtube_id']}&t={result.get('start_time', 0)}")

    cols = st.columns([2, 1])
    with cols[0]:
        text_content = []
        for key, value in result_filtered.items():
            if isinstance(value, (str, int, float)):
                st.write(f"**{key}:** {value}")
                # Only non-blank strings are read aloud.
                if isinstance(value, str) and len(value.strip()) > 0:
                    text_content.append(f"{key}: {value}")

    with cols[1]:
        st.metric("Relevance", f"{score:.2%}")

        voices = {
            "Aria (US Female)": "en-US-AriaNeural",
            "Guy (US Male)": "en-US-GuyNeural",
            "Sonia (UK Female)": "en-GB-SoniaNeural",
            "Tony (UK Male)": "en-GB-TonyNeural"
        }

        # Streamlit requires distinct widget keys. 'video_id' alone collapses
        # to the same key for every row that lacks it, so the row label
        # (Series.name, the DataFrame index value) is appended as well.
        # NOTE(review): rendering the very same row twice in one script run
        # would still collide; callers must avoid that.
        widget_id = f"{result.get('video_id', '')}_{getattr(result, 'name', '')}"

        selected_voice = st.selectbox(
            "Voice:",
            list(voices.keys()),
            key=f"voice_{widget_id}"
        )

        if st.button("🔊 Read", key=f"read_{widget_id}"):
            text_to_read = ". ".join(text_content)
            audio_file = speak_with_edge_tts(text_to_read, voices[selected_voice])
            if audio_file:
                play_and_download_audio(audio_file)

class FastDatasetSearcher:
    """Fast dataset search with semantic and token matching.

    Combines a keyword-overlap score with sentence-embedding cosine
    similarity over one page of the dataset at a time.
    """

    def __init__(self, dataset_id="tomg-group-umd/cinepile"):
        """Set up the searcher for ``dataset_id``.

        Reads the access token from the ``DATASET_KEY`` environment variable;
        when it is missing an error is shown and the Streamlit script is
        stopped. Dataset info is fetched into ``st.session_state`` only while
        it is still ``None`` there.
        """
        self.dataset_id = dataset_id
        self.text_model = get_model()
        self.token = os.environ.get('DATASET_KEY')
        if not self.token:
            st.error("Please set the DATASET_KEY environment variable")
            st.stop()

        if st.session_state['dataset_info'] is None:
            st.session_state['dataset_info'] = get_dataset_info(self.dataset_id, self.token)

    def load_page(self, page=0):
        """Load a specific page (``ROWS_PER_PAGE`` rows) of data."""
        return load_dataset_page(self.dataset_id, self.token, page, ROWS_PER_PAGE)

    def quick_search(self, query, df):
        """Rank the rows of ``df`` against ``query``.

        Each row is flattened to text (priority columns first). Its score is
        ``0.7 * keyword_overlap + 0.3 * cosine_similarity``, multiplied by the
        exact-match boost when a column contains the query as a contiguous
        sequence of whitespace-separated tokens, or by 1.2 when at least one
        query term occurs. Boosted scores can exceed 1.0.

        The threshold and boost are read from the ``min_search_score`` and
        ``exact_match_boost`` session-state keys (the Settings sliders) when
        present, otherwise from the module constants.

        Args:
            query: Free-text query.
            df: DataFrame holding one page of dataset rows.

        Returns:
            ``df`` itself when it is empty or the query is blank; otherwise a
            copy with ``score`` and ``matched`` columns, restricted to rows
            that matched a term or scored above the threshold, sorted by
            descending score. On any error the error is shown with
            ``st.error`` and the unmodified ``df`` is returned.
        """
        if df.empty or not query.strip():
            return df

        try:
            min_score = st.session_state.get('min_search_score', MIN_SEARCH_SCORE)
            exact_boost = st.session_state.get('exact_match_boost', EXACT_MATCH_BOOST)

            # Column type is judged from the first row only; array/bytes
            # columns (e.g. embeddings) are left out of the search text.
            searchable_cols = [
                col for col in df.columns
                if not isinstance(df[col].iloc[0], (np.ndarray, bytes))
            ]

            # Column order is the same for every row, so compute it once.
            priority_fields = ['description', 'matched_text']
            ordered_cols = [col for col in priority_fields if col in df.columns]
            ordered_cols += [col for col in searchable_cols if col not in priority_fields]

            query_tokens = query.lower().split()
            query_terms = set(query_tokens)
            # Space-padded, whitespace-normalised phrase: a substring test
            # against similarly padded text matches whole tokens only. For a
            # one-word query this is plain token membership; for multi-word
            # queries it is a phrase match, which the earlier whole-query vs.
            # single-token comparison could never satisfy.
            query_phrase = f" {' '.join(query_tokens)} "
            query_embedding = self.text_model.encode([query], show_progress_bar=False)[0]

            texts = []
            keyword_scores = []
            exact_flags = []

            for _, row in df.iterrows():
                parts = []
                exact_match = False
                for col in ordered_cols:
                    val = row[col]
                    if val is None:
                        continue
                    val_text = str(val)
                    parts.append(val_text)
                    if not exact_match:
                        normalized = f" {' '.join(val_text.lower().split())} "
                        exact_match = query_phrase in normalized

                text = ' '.join(parts)
                matching_terms = query_terms.intersection(text.lower().split())

                texts.append(text)
                keyword_scores.append(len(matching_terms) / len(query_terms))
                exact_flags.append(exact_match)

            # Rows with no text keep a score of 0.0 and are never encoded.
            scores = [0.0] * len(texts)
            to_encode = [i for i, text in enumerate(texts) if text.strip()]

            if to_encode:
                # One batched encode call instead of one model call per row.
                embeddings = self.text_model.encode(
                    [texts[i] for i in to_encode], show_progress_bar=False
                )
                similarities = cosine_similarity([query_embedding], embeddings)[0]

                for i, semantic_score in zip(to_encode, similarities):
                    combined_score = 0.7 * keyword_scores[i] + 0.3 * float(semantic_score)
                    if exact_flags[i]:
                        combined_score *= exact_boost
                    elif keyword_scores[i] > 0:
                        combined_score *= 1.2
                    scores[i] = combined_score

            results_df = df.copy()
            results_df['score'] = scores
            # A row "matched" when at least one query term is one of its tokens.
            results_df['matched'] = [overlap > 0 for overlap in keyword_scores]

            filtered_df = results_df[
                (results_df['matched']) |
                (results_df['score'] > min_score)
            ]

            return filtered_df.sort_values('score', ascending=False)

        except Exception as e:
            st.error(f"Search error: {str(e)}")
            return df

def _rerun():
    """Restart the script run with whichever rerun API this Streamlit provides."""
    # st.experimental_rerun is gone from current Streamlit; prefer st.rerun
    # and fall back to the old name only on versions that lack it.
    rerun = getattr(st, "rerun", None) or st.experimental_rerun
    rerun()


def _render_results(results, limit=None):
    """Render each row of ``results`` in its own expander; the first is open."""
    rows = results if limit is None else results.head(limit)
    for i, (_, row) in enumerate(rows.iterrows(), 1):
        with st.expander(f"Result {i}", expanded=(i == 1)):
            render_result(row)


def _render_history_results(stored):
    """Show the stored results of one history entry as read-only text."""
    # Iterating a DataFrame directly yields column names, not rows, so
    # DataFrames are walked row by row.
    if isinstance(stored, pd.DataFrame):
        rows = [row for _, row in stored.iterrows()]
    else:
        rows = list(stored)

    for i, row in enumerate(rows, 1):
        st.write(f"**Result {i}:**")
        if isinstance(row, pd.Series):
            # Plain text instead of render_result: these rows may also be
            # shown in the Search tab during the same run, and rendering
            # their keyed widgets twice would raise a duplicate-key error.
            for key, value in row.items():
                if isinstance(value, (str, int, float)):
                    st.write(f"**{key}:** {value}")
        else:
            st.write(row)


def main():
    """Build the Streamlit page.

    Layout: the voice-input component with its edit/search controls at the
    top, then four tabs (Search, Voice, History, Settings) and a sidebar with
    search metrics.
    """
    st.title("🎥 Smart Video & Voice Search")

    # Initialize components
    voice_component = create_voice_component()
    search = FastDatasetSearcher()

    # Voice input at top level
    voice_val = voice_component(my_input_value="Start speaking...")

    # Show voice input if detected
    if voice_val:
        voice_text = str(voice_val).strip()
        edited_input = st.text_area("✏️ Edit Voice Input:", value=voice_text, height=100)

        run_option = st.selectbox("Select Search Type:",
                                  ["Quick Search", "Deep Search", "Voice Summary"])

        col1, col2 = st.columns(2)
        with col1:
            autorun = st.checkbox("⚡ Auto-Run", value=False)
        with col2:
            # Displayed only; its value is not used anywhere yet.
            st.checkbox("🔊 Full Audio", value=False)

        input_changed = (voice_text != st.session_state.get('old_val'))

        if autorun and input_changed:
            # Remember the input so the same text does not auto-run again.
            st.session_state['old_val'] = voice_text
            with st.spinner("Processing voice input..."):
                if run_option == "Quick Search":
                    _render_results(search.quick_search(edited_input, search.load_page()))

                elif run_option == "Deep Search":
                    with st.spinner("Performing deep search..."):
                        page_hits = [
                            search.quick_search(edited_input, search.load_page(page))
                            for page in range(3)  # Search first 3 pages
                        ]
                        page_hits = [hits for hits in page_hits if not hits.empty]
                        if page_hits:
                            # Merge pages and re-rank globally; previously
                            # results stayed grouped by page, so a strong hit
                            # on a later page was listed after weak earlier
                            # ones. ignore_index gives every row its own label.
                            merged = pd.concat(page_hits, ignore_index=True)
                            if 'score' in merged.columns:
                                merged = merged.sort_values('score', ascending=False)
                            _render_results(merged)

                elif run_option == "Voice Summary":
                    audio_file = speak_with_edge_tts(edited_input)
                    if audio_file:
                        play_and_download_audio(audio_file)

        elif st.button("🔍 Search", key="voice_input_search"):
            st.session_state['old_val'] = voice_text
            with st.spinner("Processing..."):
                _render_results(search.quick_search(edited_input, search.load_page()))

    # Create main tabs
    tab1, tab2, tab3, tab4 = st.tabs([
        "🔍 Search", "🎙️ Voice", "💾 History", "⚙️ Settings"
    ])

    with tab1:
        st.subheader("🔍 Search")
        col1, col2 = st.columns([3, 1])
        with col1:
            query = st.text_input("Enter search query:", value="")
        with col2:
            # Displayed only; quick_search always searches every column.
            st.selectbox("Search in:",
                         ["All Fields"] + st.session_state['search_columns'])

        col3, col4 = st.columns(2)
        with col3:
            num_results = st.slider("Max results:", 1, 100, 20)
        with col4:
            search_button = st.button("🔍 Search", key="main_search_button")

        # Until the first search has run, a non-empty query triggers a search
        # without a button press.
        if (search_button or not st.session_state['initial_search_done']) and query:
            st.session_state['initial_search_done'] = True

            with st.spinner("Searching..."):
                df = search.load_page()
                results = search.quick_search(query, df)

                if len(results) > 0:
                    st.session_state['search_history'].append({
                        'query': query,
                        'timestamp': datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
                        'results': results.head(5)
                    })

                    st.write(f"Found {len(results)} results:")
                    _render_results(results, limit=num_results)
                else:
                    st.warning("No matching results found.")

    with tab2:
        st.subheader("🎙️ Voice Input")
        st.write("Use the voice input above to start speaking, or record a new message:")

        col1, col2 = st.columns(2)
        with col1:
            if st.button("🎙️ Start New Recording", key="start_recording_button"):
                st.session_state['recording'] = True
                _rerun()
        with col2:
            if st.button("🛑 Stop Recording", key="stop_recording_button"):
                st.session_state['recording'] = False
                _rerun()

        if st.session_state.get('recording', False):
            # Reuses the component declared at the top of main().
            new_val = voice_component(my_input_value="Recording...")
            if new_val:
                st.text_area("Recorded Text:", value=new_val, height=100)
                if st.button("🔍 Search with Recording", key="recording_search_button"):
                    with st.spinner("Processing recording..."):
                        df = search.load_page()
                        _render_results(search.quick_search(new_val, df))

    with tab3:
        st.subheader("💾 Search History")
        if not st.session_state['search_history']:
            st.info("No search history yet. Try searching for something!")
        else:
            for entry in reversed(st.session_state['search_history']):
                with st.expander(f"🕒 {entry['timestamp']} - {entry['query']}", expanded=False):
                    _render_history_results(entry['results'])

    with tab4:
        st.subheader("⚙️ Settings")
        st.write("Voice Settings:")
        # The choice is kept in session state under "default_voice_setting".
        st.selectbox(
            "Default Voice:",
            [
                "en-US-AriaNeural",
                "en-US-GuyNeural",
                "en-GB-SoniaNeural",
                "en-GB-TonyNeural"
            ],
            index=0,
            key="default_voice_setting"
        )

        st.write("Search Settings:")
        st.slider("Minimum Search Score:", 0.0, 1.0, MIN_SEARCH_SCORE, 0.1, key="min_search_score")
        st.slider("Exact Match Boost:", 1.0, 3.0, EXACT_MATCH_BOOST, 0.1, key="exact_match_boost")

        if st.button("🗑️ Clear Search History", key="clear_history_button"):
            st.session_state['search_history'] = []
            st.success("Search history cleared!")
            _rerun()

    # Sidebar with metrics
    with st.sidebar:
        st.subheader("📊 Search Metrics")
        total_searches = len(st.session_state['search_history'])
        st.metric("Total Searches", total_searches)

        if total_searches > 0:
            recent_searches = st.session_state['search_history'][-5:]
            st.write("Recent Searches:")
            for entry in reversed(recent_searches):
                st.write(f"🔍 {entry['query']}")

# Entry point: build the UI only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()