import os
import time
import torch
from transformers import AutoProcessor, AutoModelForCTC, Wav2Vec2PhonemeCTCTokenizer
import librosa
from itertools import groupby
from datasets import load_dataset

# Load the model and processor.
# NOTE: these statements run at import time, so importing this module
# triggers the `from_pretrained` calls below.
# Alternative checkpoint, kept for reference:
# checkpoint = "bookbot/wav2vec2-ljspeech-gruut"
checkpoint = "facebook/wav2vec2-lv-60-espeak-cv-ft"
# CTC acoustic model used by predict_phonemes().
model = AutoModelForCTC.from_pretrained(checkpoint)
# Processor used to prepare the audio and to decode predicted token ids.
processor = AutoProcessor.from_pretrained(checkpoint)
# Phoneme tokenizer loaded separately; text_to_phonemes() calls its
# `phonemize` method to turn reference text into phonemes.
tokenizer = Wav2Vec2PhonemeCTCTokenizer.from_pretrained(checkpoint)
# Sampling rate reported by the processor's feature extractor.
sr = processor.feature_extractor.sampling_rate


def decode_phonemes(
    ids: torch.Tensor, processor: "AutoProcessor", ignore_stress: bool = False
) -> str:
    """Greedy CTC-style decoding of one utterance's predicted token ids.

    Consecutive duplicate ids are collapsed first, then special tokens
    (including the word delimiter) are dropped, and each remaining id is
    decoded on its own and the pieces are joined with single spaces.

    Args:
        ids: 1-D sequence of token ids for a single utterance. A tensor is
            converted to plain Python ints up front; any other iterable of
            ints is accepted as well.
        processor: Object exposing ``tokenizer.all_special_ids``,
            ``tokenizer.word_delimiter_token_id`` and ``decode(token_id)``.
            The annotation is quoted so it is not evaluated when the
            function is defined.
        ignore_stress: When true, strip the IPA primary ("ˈ") and secondary
            ("ˌ") stress marks from the result.

    Returns:
        The decoded phonemes separated by single spaces ("" for no ids).
    """
    # Work on plain ints: iterating a tensor yields 0-d tensors, which are
    # slow to compare and cannot be looked up by value in a set.
    id_list = ids.tolist() if hasattr(ids, "tolist") else list(ids)

    # Remove consecutive duplicates (CTC collapse step).
    collapsed = [id_ for id_, _ in groupby(id_list)]

    # Build the skip-set once so each membership test is O(1).
    special_token_ids = set(processor.tokenizer.all_special_ids)
    special_token_ids.add(processor.tokenizer.word_delimiter_token_id)

    # Convert id to token, skipping special tokens.
    phonemes = [
        processor.decode(id_) for id_ in collapsed if id_ not in special_token_ids
    ]

    prediction = " ".join(phonemes)

    # Ignore IPA stress marks if specified.
    if ignore_stress:
        prediction = prediction.replace("ˈ", "").replace("ˌ", "")

    return prediction


def text_to_phonemes(text: str) -> str:
    """Convert text to phonemes using the tokenizer's phonemizer.

    The text is phonemized with the module-level ``tokenizer`` using the
    "en-us" phonemizer language, and the elapsed wall-clock time is printed.

    Args:
        text: The text to phonemize.

    Returns:
        The phoneme string produced by ``tokenizer.phonemize``.
    """
    # The docstring must be the first statement in the body; otherwise it is
    # just a discarded string expression and ``__doc__`` stays None.
    s_time = time.time()
    phonemes = tokenizer.phonemize(text, phonemizer_lang="en-us")
    e_time = time.time()
    print(f"Execution time of text_to_phonemes: {e_time - s_time:.6f} seconds")
    return phonemes


def separate_characters(input_string):
    """Return the characters of *input_string* separated by single spaces.

    Space characters in the input are dropped first; every remaining
    character (including other whitespace such as tabs) is kept as its own
    space-separated item.
    """
    return " ".join(char for char in input_string if char != " ")


def predict_phonemes(audio_array):
    """Run the CTC model on one audio clip and return its phoneme string.

    Args:
        audio_array: Raw waveform samples for a single utterance. Assumed to
            already be sampled at ``sr``; no resampling happens here.

    Returns:
        The predicted phonemes, space-separated, with IPA stress marks
        removed (see ``decode_phonemes``).
    """
    # Pass the sampling rate explicitly so the feature extractor can check it
    # against the rate it was configured for instead of assuming it silently.
    inputs = processor(
        audio_array, sampling_rate=sr, return_tensors="pt", padding=True
    )

    # Perform inference without tracking gradients.
    with torch.no_grad():
        logits = model(inputs["input_values"]).logits

    # Greedy decoding: most likely token per frame, then CTC collapse.
    predicted_ids = torch.argmax(logits, dim=-1)

    # Only one clip is processed, so decode the first (only) batch entry.
    return decode_phonemes(predicted_ids[0], processor, ignore_stress=True)


def adjust_phonemes(predicted: str) -> str:
    """Normalize spacing in a predicted phoneme string.

    Runs of two or more space characters are collapsed to a single space and
    leading/trailing whitespace is trimmed.

    Args:
        predicted: Space-separated phoneme string.

    Returns:
        The string with single spaces between items and no surrounding
        whitespace.
    """
    # A single ``replace("  ", " ")`` pass leaves double spaces behind when a
    # run has three or more spaces; dropping the empty pieces produced by
    # splitting on " " collapses runs of any length in one pass.
    collapsed = " ".join(piece for piece in predicted.split(" ") if piece)
    # Trim any remaining leading/trailing whitespace (tabs, newlines, ...).
    return collapsed.strip()


def calculate_score(expected: str, predicted: str) -> float:
    """Score a prediction as the fraction of expected phonemes matched.

    Both strings are split on whitespace and compared position by position;
    the number of equal pairs is divided by the number of expected phonemes.
    The comparison is purely positional (no alignment), so an inserted or
    missing phoneme shifts every later position, and extra predicted
    phonemes beyond the expected length are ignored.

    Args:
        expected: Whitespace-separated reference phonemes.
        predicted: Whitespace-separated predicted phonemes.

    Returns:
        A float in [0.0, 1.0]; 0.0 when ``expected`` contains no phonemes.
    """
    expected_list = expected.split()
    if not expected_list:
        # Always return a float, as the annotation promises.
        return 0.0

    predicted_list = predicted.split()

    # zip() stops at the shorter list, so unmatched tail items never count.
    correct_matches = sum(e == p for e, p in zip(expected_list, predicted_list))

    return correct_matches / len(expected_list)


def test_sound():
    """Run the phoneme pipeline on one dummy LibriSpeech sample.

    Loads the first validation sample, phonemizes its transcript, predicts
    phonemes from its audio, scores the prediction against the reference and
    prints the intermediate results plus the total execution time.

    Returns:
        A dict with a single ``"text"`` key holding a multi-line report
        (transcript, expected/predicted/adjusted phonemes and score).
    """
    start_time = time.time()

    # NOTE(review): trust_remote_code=True allows the dataset repository's
    # loading script to execute locally; keep only for trusted repositories.
    ds = load_dataset(
        "patrickvonplaten/librispeech_asr_dummy",
        "clean",
        split="validation",
        trust_remote_code=True,
    )

    # Fetch the row once instead of indexing the dataset for every field;
    # each ds[0] lookup builds the whole row again, audio included.
    sample = ds[0]
    audio_array = sample["audio"]["array"]
    expected_transcript = sample["text"]

    # Reference phonemes, split into individual space-separated characters.
    expected_phonemes = separate_characters(text_to_phonemes(expected_transcript))

    # Phonemes predicted from the audio, then whitespace-normalized.
    predicted_phonemes = predict_phonemes(audio_array)
    adjusted_phonemes = adjust_phonemes(predicted_phonemes)
    print(f"Expected Phonemes: {expected_phonemes}")
    print(f"Predicted Phonemes: {predicted_phonemes}")
    print(f"Adjusted Phonemes: {adjusted_phonemes}")

    # Calculate score based on expected and predicted phonemes.
    score = calculate_score(expected_phonemes, adjusted_phonemes)

    # Prepare the output (kept separate from the transcript variable).
    report = (
        f"Transcript: {expected_transcript}\n"
        f"Expected Phonemes: {expected_phonemes}\n"
        f"Predicted Phonemes: {predicted_phonemes}\n"
        f"Adjusted Phonemes: {adjusted_phonemes}\n"
        f"Score: {score:.2f}"
    )
    execution_time = time.time() - start_time
    print(f"Execution time: {execution_time:.6f} seconds")
    return {"text": report}