import spaces
import logging
from datetime import datetime
from pathlib import Path
import gradio as gr
import torch
import torchaudio
import os
import requests
from transformers import pipeline
import tempfile
import numpy as np
from einops import rearrange
import cv2
from scipy.io import wavfile
import librosa
import json
from typing import Optional, Tuple, List
import atexit

try:
    import mmaudio
except ImportError:
    os.system("pip install -e .")
    import mmaudio

from mmaudio.eval_utils import (ModelConfig, all_model_cfg, generate, load_video, make_video,
                                setup_eval_logging)
from mmaudio.model.flow_matching import FlowMatching
from mmaudio.model.networks import MMAudio, get_my_mmaudio
from mmaudio.model.sequence_config import SequenceConfig
from mmaudio.model.utils.features_utils import FeaturesUtils

# Logging setup
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
log = logging.getLogger()

# CUDA setup
if torch.cuda.is_available():
    device = torch.device("cuda")
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True
    torch.backends.cudnn.benchmark = True
else:
    device = torch.device("cpu")

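# Model weights and extracted features are kept in bfloat16 to reduce memory use.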
dtype = torch.bfloat16

# Model configuration
model: ModelConfig = all_model_cfg['large_44k_v2']
model.download_if_needed()
output_dir = Path('./output/gradio')
output_dir.mkdir(parents=True, exist_ok=True)

setup_eval_logging()

# Korean -> English translator (kept on CPU) and Pixabay API setup
translator = pipeline("translation", model="Helsinki-NLP/opus-mt-ko-en", device="cpu")
PIXABAY_API_KEY = "33492762-a28a596ec4f286f84cd328b17"



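# Remove generated .mp4/.flac files from the system temp directory when the process exits.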
def cleanup_temp_files():
    temp_dir = tempfile.gettempdir()
    for file in os.listdir(temp_dir):
        if file.endswith(('.mp4', '.flac')):
            try:
                os.remove(os.path.join(temp_dir, file))
            except OSError:
                pass

atexit.register(cleanup_temp_files)

def get_model() -> tuple[MMAudio, FeaturesUtils, SequenceConfig]:
    with torch.cuda.device(device):
        seq_cfg = model.seq_cfg
        net: MMAudio = get_my_mmaudio(model.model_name).to(device, dtype).eval()
        net.load_weights(torch.load(model.model_path, map_location=device, weights_only=True))
        log.info(f'Loaded weights from {model.model_path}')

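        # Feature extractor bundle: VAE, Synchformer (audio-visual sync features) and the BigVGAN vocoder.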
        feature_utils = FeaturesUtils(
            tod_vae_ckpt=model.vae_path,
            synchformer_ckpt=model.synchformer_ckpt,
            enable_conditions=True,
            mode=model.mode,
            bigvgan_vocoder_ckpt=model.bigvgan_16k_path,
            need_vae_encoder=False
        ).to(device, dtype).eval()

        return net, feature_utils, seq_cfg

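# Load the model once at startup; the network, feature utils and sequence config are shared by all requests.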
net, feature_utils, seq_cfg = get_model()


# Search Pixabay for videos matching a (possibly Korean) query
@torch.no_grad()
def search_videos(query):
    try:
        # Translate the query on the CPU before searching Pixabay
        query = translate_prompt(query)
        return search_pixabay_videos(query, PIXABAY_API_KEY)
    except Exception as e:
        logging.error(f"Video search error: {e}")
        return []

# Translate Korean text to English; non-Korean text passes through unchanged
def translate_prompt(text):
    try:
        if text and any(0x3131 <= ord(char) <= 0xD7A3 for char in text):
            # Hangul detected: translate on the CPU without tracking gradients
            with torch.no_grad():
                translation = translator(text)[0]['translation_text']
            return translation
        return text
    except Exception as e:
        logging.error(f"Translation error: {e}")
        return text


def search_pixabay_videos(query, api_key):
    try:
        base_url = "https://pixabay.com/api/videos/"
        params = {
            "key": api_key,
            "q": query,
            "per_page": 40
        }
        
        response = requests.get(base_url, params=params, timeout=10)
        if response.status_code == 200:
            data = response.json()
            return [video['videos']['large']['url'] for video in data.get('hits', [])]
        return []
    except Exception as e:
        logging.error(f"Pixabay API error: {e}")
        return []


@spaces.GPU
@torch.inference_mode()
def video_to_audio(video: gr.Video, prompt: str, negative_prompt: str, seed: int, num_steps: int,
                   cfg_strength: float, duration: float):
    prompt = translate_prompt(prompt)
    negative_prompt = translate_prompt(negative_prompt)

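    # Seeded generator and Euler flow-matching sampler keep the generation reproducible.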
    rng = torch.Generator(device=device)
    rng.manual_seed(seed)
    fm = FlowMatching(min_sigma=0, inference_mode='euler', num_steps=num_steps)

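    # Load CLIP and Synchformer frame streams; the returned duration reflects the usable clip length.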
    clip_frames, sync_frames, duration = load_video(video, duration)
    clip_frames = clip_frames.unsqueeze(0)
    sync_frames = sync_frames.unsqueeze(0)
    seq_cfg.duration = duration
    net.update_seq_lengths(seq_cfg.latent_seq_len, seq_cfg.clip_seq_len, seq_cfg.sync_seq_len)

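    # Generate audio conditioned on the video frames and the (translated) text prompts.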
    audios = generate(clip_frames,
                      sync_frames, [prompt],
                      negative_text=[negative_prompt],
                      feature_utils=feature_utils,
                      net=net,
                      fm=fm,
                      rng=rng,
                      cfg_strength=cfg_strength)
    audio = audios.float().cpu()[0]

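    # Mux the generated audio track back into the input video and return the temp file path.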
    video_save_path = tempfile.NamedTemporaryFile(delete=False, suffix='.mp4').name
    make_video(video,
               video_save_path,
               audio,
               sampling_rate=seq_cfg.sampling_rate,
               duration_sec=seq_cfg.duration)
    return video_save_path

@spaces.GPU
@torch.inference_mode()
def text_to_audio(prompt: str, negative_prompt: str, seed: int, num_steps: int, cfg_strength: float,
                  duration: float):
    prompt = translate_prompt(prompt)
    negative_prompt = translate_prompt(negative_prompt)

    rng = torch.Generator(device=device)
    rng.manual_seed(seed)
    fm = FlowMatching(min_sigma=0, inference_mode='euler', num_steps=num_steps)

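    # Text-to-audio uses no visual conditioning; only the prompts drive generation.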
    clip_frames = sync_frames = None
    seq_cfg.duration = duration
    net.update_seq_lengths(seq_cfg.latent_seq_len, seq_cfg.clip_seq_len, seq_cfg.sync_seq_len)

    audios = generate(clip_frames,
                      sync_frames, [prompt],
                      negative_text=[negative_prompt],
                      feature_utils=feature_utils,
                      net=net,
                      fm=fm,
                      rng=rng,
                      cfg_strength=cfg_strength)
    audio = audios.float().cpu()[0]

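    # Write the waveform to a temporary FLAC file that Gradio can serve.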
    audio_save_path = tempfile.NamedTemporaryFile(delete=False, suffix='.flac').name
    torchaudio.save(audio_save_path, audio, seq_cfg.sampling_rate)
    return audio_save_path



# Custom CSS styling
custom_css = """
.gradio-container {
    background: linear-gradient(45deg, #1a1a1a, #2a2a2a);
    border-radius: 15px;
    box-shadow: 0 8px 32px rgba(0,0,0,0.3);
    color: #e0e0e0;
}

.input-container, .output-container {
    background: rgba(40, 40, 40, 0.95);
    backdrop-filter: blur(10px);
    border-radius: 10px;
    padding: 20px;
    transform-style: preserve-3d;
    transition: transform 0.3s ease;
    border: 1px solid rgba(255, 255, 255, 0.1);
}

.input-container:hover {
    transform: translateZ(20px);
    box-shadow: 0 8px 32px rgba(0,0,0,0.5);
}

.gallery-item {
    transition: transform 0.3s ease;
    border-radius: 8px;
    overflow: hidden;
    background: #2a2a2a;
}

.gallery-item:hover {
    transform: scale(1.05);
    box-shadow: 0 4px 15px rgba(0,0,0,0.4);
}

.tabs {
    background: rgba(30, 30, 30, 0.95);
    border-radius: 10px;
    padding: 10px;
    border: 1px solid rgba(255, 255, 255, 0.05);
}

button {
    background: linear-gradient(45deg, #2196F3, #1976D2);
    border: none;
    border-radius: 5px;
    transition: all 0.3s ease;
    color: white;
}

button:hover {
    transform: translateY(-2px);
    box-shadow: 0 4px 15px rgba(33,150,243,0.3);
}

/* Text input field styles */
textarea, input[type="text"], input[type="number"] {
    background: rgba(30, 30, 30, 0.95) !important;
    border: 1px solid rgba(255, 255, 255, 0.1) !important;
    color: #e0e0e0 !important;
    border-radius: 5px !important;
}

/* Label styles */
label {
    color: #e0e0e0 !important;
}

/* Gallery grid styles */
.gallery {
    background: rgba(30, 30, 30, 0.95);
    padding: 15px;
    border-radius: 10px;
    border: 1px solid rgba(255, 255, 255, 0.05);
}
"""

text_to_audio_tab = gr.Interface(
    fn=text_to_audio,
    inputs=[
        gr.Textbox(label="Prompt(한글지원)"),
        gr.Textbox(label="Negative Prompt"),
        gr.Number(label="Seed", value=0),
        gr.Number(label="Steps", value=25),
        gr.Number(label="Guidance Scale", value=4.5),
        gr.Number(label="Duration (sec)", value=8),
    ],
    outputs=gr.Audio(label="Generated Audio"),
    css=custom_css
)


video_to_audio_tab = gr.Interface(
    fn=video_to_audio,
    inputs=[
        gr.Video(label="Input Video"),
        gr.Textbox(label="Prompt(한글지원)"),
        gr.Textbox(label="Negative Prompt", value="music"),
        gr.Number(label="Seed", value=0),
        gr.Number(label="Steps", value=25),
        gr.Number(label="Guidance Scale", value=4.5),
        gr.Number(label="Duration (sec)", value=8),
    ],
    outputs=gr.Video(label="Generated Result"),
    css=custom_css
)

# Video search interface (not exposed as an API endpoint)
video_search_tab = gr.Interface(
    fn=search_videos,
    inputs=gr.Textbox(label="Search Query(한글지원)"),
    outputs=gr.Gallery(label="Search Results", columns=4, rows=20),
    css=custom_css,
    api_name=False
)



# Hide the Gradio footer and append the custom styles
css = """
footer {
    visibility: hidden;
}
""" + custom_css  # 기존 custom_css와 새로운 css를 결합

# Main entry point
if __name__ == "__main__":
    gr.TabbedInterface(
        [video_search_tab, video_to_audio_tab, text_to_audio_tab],
        ["Video Search", "Video-to-Audio", "Text-to-Audio"],
        theme="Yntec/HaleyCH_Theme_Orange",
        css=css
    ).launch(allowed_paths=[output_dir])