import streamlit as st
import pandas as pd
import numpy as np
from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity
import os
import glob
from datetime import datetime
import edge_tts
import asyncio
import requests
from audio_recorder_streamlit import audio_recorder
from urllib.parse import quote
from xml.etree import ElementTree as ET

# Initialize session state
if 'search_history' not in st.session_state:
    st.session_state['search_history'] = []
if 'last_voice_input' not in st.session_state:
    st.session_state['last_voice_input'] = ""
if 'transcript_history' not in st.session_state:
    st.session_state['transcript_history'] = []
if 'should_rerun' not in st.session_state:
    st.session_state['should_rerun'] = False
if 'search_columns' not in st.session_state:
    st.session_state['search_columns'] = []
if 'initial_search_done' not in st.session_state:
    st.session_state['initial_search_done'] = False
if 'tts_voice' not in st.session_state:
    st.session_state['tts_voice'] = "en-US-AriaNeural"
if 'arxiv_last_query' not in st.session_state:
    st.session_state['arxiv_last_query'] = ""

class VideoSearch:
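    """Semantic search over video rows fetched from the Hugging Face datasets-server,
    using MiniLM query embeddings and cosine similarity."""
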
    def __init__(self):
        self.text_model = SentenceTransformer('all-MiniLM-L6-v2')
        self.load_dataset()
        
    def fetch_dataset_rows(self):
        """Fetch dataset from Hugging Face API"""
        try:
            url = "https://datasets-server.huggingface.co/first-rows?dataset=omegalabsinc%2Fomega-multimodal&config=default&split=train"
            response = requests.get(url, timeout=30)
            if response.status_code == 200:
                data = response.json()
                if 'rows' in data:
                    processed_rows = []
                    for row_data in data['rows']:
                        row = row_data.get('row', row_data)
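                        # Embedding fields may arrive as stringified lists ("[0.1, 0.2, ...]");
                        # parse them into float lists so they can be stacked later.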
                        for key in row:
                            if any(term in key.lower() for term in ['embed', 'vector', 'encoding']):
                                if isinstance(row[key], str):
                                    try:
                                        row[key] = [float(x.strip()) for x in row[key].strip('[]').split(',') if x.strip()]
                                    except ValueError:
                                        continue
                        processed_rows.append(row)
                    
                    df = pd.DataFrame(processed_rows)
                    st.session_state['search_columns'] = [col for col in df.columns 
                                                        if col not in ['video_embed', 'description_embed', 'audio_embed']]
                    return df
            return self.load_example_data()
        except Exception:
            return self.load_example_data()

    def prepare_features(self):
        """Prepare embeddings with adaptive field detection"""
        try:
            embed_cols = [col for col in self.dataset.columns 
                         if any(term in col.lower() for term in ['embed', 'vector', 'encoding'])]
            
            embeddings = {}
            for col in embed_cols:
                try:
                    data = []
                    for row in self.dataset[col]:
                        if isinstance(row, str):
                            values = [float(x.strip()) for x in row.strip('[]').split(',') if x.strip()]
                        elif isinstance(row, list):
                            values = row
                        else:
                            continue
                        data.append(values)
                    
                    if data:
                        embeddings[col] = np.array(data)
                except Exception:
                    continue
            
            # Set main embeddings for search
            if 'video_embed' in embeddings:
                self.video_embeds = embeddings['video_embed']
            else:
                self.video_embeds = next(iter(embeddings.values()))
                
            if 'description_embed' in embeddings:
                self.text_embeds = embeddings['description_embed']
            else:
                self.text_embeds = self.video_embeds
                
        except Exception:
            # Fallback to random embeddings
            num_rows = len(self.dataset)
            self.video_embeds = np.random.randn(num_rows, 384)
            self.text_embeds = np.random.randn(num_rows, 384)

    def load_example_data(self):
        """Load example data as fallback"""
        example_data = [
            {
                "video_id": "cd21da96-fcca-4c94-a60f-0b1e4e1e29fc",
                "youtube_id": "IO-vwtyicn4",
                "description": "This video shows a close-up of an ancient text carved into a surface.",
                "views": 45489,
                "start_time": 1452,
                "end_time": 1458,
                "video_embed": [0.014160037972033024, -0.003111184574663639, -0.016604168340563774],
                "description_embed": [-0.05835828185081482, 0.02589797042310238, 0.11952091753482819]
            }
        ]
        return pd.DataFrame(example_data)
    
    def load_dataset(self):
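        """Fetch dataset rows (or the bundled example) and prepare the embedding matrices."""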
        self.dataset = self.fetch_dataset_rows()
        self.prepare_features()

    def search(self, query, column=None, top_k=20):
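        """Rank rows against a text query.

        The query is encoded with MiniLM and compared by cosine similarity to the
        video and description embeddings (averaged 50/50); an optional substring
        match on `column` down-weights rows that don't contain the query text.
        """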
        query_embedding = self.text_model.encode([query])[0]
        # Stored embeddings may not match the query encoder's dimensionality
        # (e.g. the truncated example rows), so score that term as zero rather
        # than crashing inside cosine_similarity.
        def cosine_or_zero(embeds):
            if embeds.ndim == 2 and embeds.shape[1] == query_embedding.shape[0]:
                return cosine_similarity([query_embedding], embeds)[0]
            return np.zeros(len(self.dataset))
        video_sims = cosine_or_zero(self.video_embeds)
        text_sims = cosine_or_zero(self.text_embeds)
        combined_sims = 0.5 * video_sims + 0.5 * text_sims
        
        # Column filtering
        if column and column in self.dataset.columns and column != "All Fields":
            mask = self.dataset[column].astype(str).str.contains(query, case=False, regex=False)
            combined_sims[~mask] *= 0.5
        
        top_k = min(top_k, 100)
        top_indices = np.argsort(combined_sims)[-top_k:][::-1]
        
        results = []
        for idx in top_indices:
            result = {'relevance_score': float(combined_sims[idx])}
            for col in self.dataset.columns:
                if col not in ['video_embed', 'description_embed', 'audio_embed']:
                    result[col] = self.dataset.iloc[idx][col]
            results.append(result)
        
        return results
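
# Example usage outside the Streamlit UI (illustrative):
#   searcher = VideoSearch()
#   hits = searcher.search("ancient text", top_k=5)
#   print(hits[0]['description'], hits[0]['relevance_score'])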

@st.cache_resource
def get_speech_model():
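    """Return the edge_tts.Communicate class, cached as a shared Streamlit resource."""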
    return edge_tts.Communicate

async def generate_speech(text, voice=None):
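    """Synthesize `text` to an MP3 file with edge-tts and return its path, or None on empty input or failure."""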
    if not text.strip():
        return None
    if not voice:
        voice = st.session_state['tts_voice']
    try:
        communicate = get_speech_model()(text, voice)
        audio_file = f"speech_{datetime.now().strftime('%Y%m%d_%H%M%S')}.mp3"
        await communicate.save(audio_file)
        return audio_file
    except Exception as e:
        st.error(f"Error generating speech: {e}")
        return None

def transcribe_audio(audio_path):
    """Placeholder for ASR transcription (no OpenAI/Anthropic).
       Integrate your own ASR model or API here."""
    # For now, just return a message:
    return "ASR not implemented. Integrate a local model or another service here."

def show_file_manager():
    """Display file manager interface"""
    st.subheader("📂 File Manager")
    col1, col2 = st.columns(2)
    with col1:
        uploaded_file = st.file_uploader("Upload File", type=['txt', 'md', 'mp3'])
        if uploaded_file:
            with open(uploaded_file.name, "wb") as f:
                f.write(uploaded_file.getvalue())
            st.success(f"Uploaded: {uploaded_file.name}")
            st.rerun()
    
    with col2:
        if st.button("🗑 Clear All Files"):
            for f in glob.glob("*.txt") + glob.glob("*.md") + glob.glob("*.mp3"):
                os.remove(f)
            st.success("All files cleared!")
            st.rerun()
    
    files = glob.glob("*.txt") + glob.glob("*.md") + glob.glob("*.mp3")
    if files:
        st.write("### Existing Files")
        for f in files:
            with st.expander(f"📄 {os.path.basename(f)}"):
                if f.endswith('.mp3'):
                    st.audio(f)
                else:
                    with open(f, 'r', encoding='utf-8') as file:
                        st.text_area("Content", file.read(), height=100, key=f"content_{f}")
                if st.button(f"Delete {os.path.basename(f)}", key=f"del_{f}"):
                    os.remove(f)
                    st.rerun()

def arxiv_search(query, max_results=5):
    """Perform a simple Arxiv search using their API and return top results."""
    base_url = "http://export.arxiv.org/api/query?"
    # Encode the query
    search_url = base_url + f"search_query={quote(query)}&start=0&max_results={max_results}"
    r = requests.get(search_url, timeout=30)
    if r.status_code == 200:
        root = ET.fromstring(r.text)
        # Namespace handling
        ns = {'atom': 'http://www.w3.org/2005/Atom'}
        entries = root.findall('atom:entry', ns)
        results = []
        for entry in entries:
            title = entry.find('atom:title', ns).text.strip()
            summary = entry.find('atom:summary', ns).text.strip()
            link = None
            for l in entry.findall('atom:link', ns):
                if l.get('type') == 'text/html':
                    link = l.get('href')
                    break
            results.append((title, summary, link))
        return results
    return []

def perform_arxiv_lookup(q, vocal_summary=True, titles_summary=True, full_audio=False):
    results = arxiv_search(q, max_results=5)
    if not results:
        st.write("No Arxiv results found.")
        return
    st.markdown(f"**Arxiv Search Results for '{q}':**")
    for i, (title, summary, link) in enumerate(results, start=1):
        st.markdown(f"**{i}. {title}**")
        st.write(summary)
        if link:
            st.markdown(f"[View Paper]({link})")

    # TTS Options
    if vocal_summary:
        spoken_text = f"Here are some Arxiv results for {q}. "
        if titles_summary:
            spoken_text += " Titles: " + ", ".join([res[0] for res in results])
        else:
            # Just first summary if no titles_summary
            spoken_text += " " + results[0][1][:200]

        audio_file = asyncio.run(generate_speech(spoken_text))
        if audio_file:
            st.audio(audio_file)
    
    if full_audio:
        # Full audio of summaries
        full_text = ""
        for i,(title, summary, _) in enumerate(results, start=1):
            full_text += f"Result {i}: {title}. {summary} "
        audio_file_full = asyncio.run(generate_speech(full_text))
        if audio_file_full:
            st.write("### Full Audio")
            st.audio(audio_file_full)

def main():
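    """Build the Streamlit UI: video search, voice input, Arxiv lookup, and file manager tabs, plus sidebar settings and history."""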
    st.title("🎥 Video & Arxiv Search with Voice (No OpenAI/Anthropic)")
    
    # Initialize search class
    search = VideoSearch()
    
    # Create tabs
    tab1, tab2, tab3, tab4 = st.tabs(["🔍 Search", "🎙️ Voice Input", "📚 Arxiv", "📂 Files"])
    
    # ---- Tab 1: Video Search ----
    with tab1:
        st.subheader("Search Videos")
        col1, col2 = st.columns([3, 1])
        with col1:
            query = st.text_input("Enter your search query:", 
                                  value="ancient" if not st.session_state['initial_search_done'] else "")
        with col2:
            search_column = st.selectbox("Search in field:", 
                                       ["All Fields"] + st.session_state['search_columns'])
        
        col3, col4 = st.columns(2)
        with col3:
            num_results = st.slider("Number of results:", 1, 100, 20)
        with col4:
            search_button = st.button("🔍 Search")
        
        if (search_button or not st.session_state['initial_search_done']) and query:
            st.session_state['initial_search_done'] = True
            selected_column = None if search_column == "All Fields" else search_column
            with st.spinner("Searching..."):
                results = search.search(query, selected_column, num_results)
            
            st.session_state['search_history'].append({
                'query': query,
                'timestamp': datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
                'results': results[:5]
            })
            
            for i, result in enumerate(results, 1):
                with st.expander(f"Result {i}: {result['description'][:100]}...", expanded=(i==1)):
                    cols = st.columns([2, 1])
                    with cols[0]:
                        st.markdown("**Description:**")
                        st.write(result['description'])
                        st.markdown(f"**Time Range:** {result['start_time']}s - {result['end_time']}s")
                        st.markdown(f"**Views:** {result['views']:,}")
                    
                    with cols[1]:
                        st.markdown(f"**Relevance Score:** {result['relevance_score']:.2%}")
                        if result.get('youtube_id'):
                            st.video(f"https://youtube.com/watch?v={result['youtube_id']}", start_time=int(result['start_time']))
                        
                        if st.button(f"🔊 Audio Summary", key=f"audio_{i}"):
                            summary = f"Video summary: {result['description'][:200]}"
                            audio_file = asyncio.run(generate_speech(summary))
                            if audio_file:
                                st.audio(audio_file)

    # ---- Tab 2: Voice Input ----
    with tab2:
        st.subheader("Voice Input")
        
        st.write("🎙️ Record your voice:")
        audio_bytes = audio_recorder()
        if audio_bytes:
            audio_path = f"temp_audio_{datetime.now().strftime('%Y%m%d_%H%M%S')}.wav"
            with open(audio_path, "wb") as f:
                f.write(audio_bytes)
            st.success("Audio recorded successfully!")
            
            voice_query = transcribe_audio(audio_path)
            st.markdown("**Transcribed Text:**")
            st.write(voice_query)
            st.session_state['last_voice_input'] = voice_query
            
            if st.button("🔍 Search from Voice"):
                results = search.search(voice_query, None, 20)
                for i, result in enumerate(results, 1):
                    with st.expander(f"Result {i}", expanded=(i==1)):
                        st.write(result['description'])
                        if result.get('youtube_id'):
                            st.video(f"https://youtube.com/watch?v={result['youtube_id']}", start_time=int(result.get('start_time', 0)))
            
            if os.path.exists(audio_path):
                os.remove(audio_path)

    # ---- Tab 3: Arxiv Search ----
    with tab3:
        st.subheader("Arxiv Search")
        q = st.text_input("Enter your Arxiv search query:", value=st.session_state['arxiv_last_query'])
        vocal_summary = st.checkbox("🎙 Short Audio Summary", value=True)
        titles_summary = st.checkbox("🔖 Titles Only", value=True)
        full_audio = st.checkbox("📚 Full Audio Results", value=False)
        
        if st.button("🔍 Arxiv Search"):
            st.session_state['arxiv_last_query'] = q
            perform_arxiv_lookup(q, vocal_summary=vocal_summary, titles_summary=titles_summary, full_audio=full_audio)

    # ---- Tab 4: File Manager ----
    with tab4:
        show_file_manager()

    # Sidebar
    with st.sidebar:
        st.subheader("⚙️ Settings & History")
        if st.button("🗑️ Clear History"):
            st.session_state['search_history'] = []
            st.rerun()
        
        st.markdown("### Recent Searches")
        for entry in reversed(st.session_state['search_history'][-5:]):
            with st.expander(f"{entry['timestamp']}: {entry['query']}"):
                for i, result in enumerate(entry['results'], 1):
                    st.write(f"{i}. {result['description'][:100]}...")

        st.markdown("### Voice Settings")
        st.selectbox("TTS Voice:", 
                     ["en-US-AriaNeural", "en-US-GuyNeural", "en-GB-SoniaNeural"],
                     key="tts_voice")

if __name__ == "__main__":
    main()