import gradio as gr 
from transformers import pipeline, Wav2Vec2ProcessorWithLM, Wav2Vec2ForCTC
import os
import soundfile as sf
import torch
import numpy as np

HF_TOKEN = os.environ.get("HF_TOKEN")

model_name = "bond005/wav2vec2-large-ru-golos-with-lm"
processor = Wav2Vec2ProcessorWithLM.from_pretrained(model_name)
model = Wav2Vec2ForCTC.from_pretrained(model_name)
pipe = pipeline("automatic-speech-recognition", model=model, tokenizer=processor, feature_extractor=processor.feature_extractor, decoder=processor.decoder) 

detokenize_dict = {value: key for key, value in processor.tokenizer.get_vocab().items()}

dict_v = ["а", "у" "о" "и" "э" "ы" "я" "ю" "е" "ё"]

def count_char_borders(predicted_ids, input_values, processor, sample_rate=16000):
    predicted_ids_l = predicted_ids[0].tolist()
    duration_sec = input_values.shape[1] / sample_rate

    ids_c_time = [(i / len(predicted_ids_l) * duration_sec, _id) for i, _id in enumerate(predicted_ids_l)]

    t_chars_list = [[i[0], detokenize_dict[i[1]]] for i in ids_c_time if i[1] != processor.tokenizer.pad_token_id]

    t_chars_list_cl = []
    cur = None
    for i, item in enumerate(t_chars_list[:-1]):
      if i == 0 or cur == None:
        cur = item
      if item[1] != t_chars_list[i + 1][1]:
        cur.append(t_chars_list[i + 1][0])
        t_chars_list_cl.append(cur)
        cur = t_chars_list[i + 1]

    t_chars_list_cl = [i if i[1] != "|" else [i[0], "", i[2]] for i in t_chars_list_cl]
    chars, char_start_times, char_end_times = [], [], []
    for c in t_chars_list_cl:
      if c[1].lower() in dict_v and c[1] != "":
          chars.append("v")
      elif c[1] != "":
          chars.append("c")
      else:
          chars.append("")
      char_start_times.append(c[0])
      char_end_times.append(c[2])
    return chars, char_start_times, char_end_times



# обработка seg-файла, получение информации для расчётов
# предполагается, что на вход получаем seg либо 'corpres' - с разметкой по корпресу, либо упрощённая разметка 'cv' - с разметкой на согласные и гласные

def preprocess(chars, starts, labelled='cv'):
    start_and_sound = []
  # берём из seg-файла метки звуков, отсчёты переводим в секунды, получаем общую длительность
    for i, item in enumerate(chars):
        start_time = float(starts[i])
        label = item
        start_and_sound.append([start_time, label])

    # заводим переменные, необходимые для расчётов
    clusters_and_duration = []
    pauses = 0
    sum_dur_vowels = 0
    # флаг для определения границ кластеров. важно, если до и после паузы звуки одного класса
    postpause_flag = 0

    # обработка файлов с гласно-согласной разметкой
    if labelled == 'cv':
      total_duration = 0
      # определяем к какому классу относится каждый звук и считаем длительность (отдельных гласных и согласных кластеров)
      for n, i in enumerate(start_and_sound):
        sound = i[1]
        # определяем не является ли звук конечным
        if n != len(start_and_sound) - 1:
          duration = start_and_sound[n+1][0] - i[0]
          # выделяем гласные
          if sound == 'V' or sound == 'v':
            total_duration += duration
            # записываем отдельно звук в нулевой позиции в обход ошибки индекса
            if n == 0:
              clusters_and_duration.append(['V', duration])

            # объединяем длительности, если предыдущий звук тоже был гласным
            elif clusters_and_duration[-1][0] == 'V' and postpause_flag == 0:
              clusters_and_duration[-1][1] += duration

            # фиксируем длительность отдельного гласного звука
            else:
              clusters_and_duration.append(['V', duration])

            # считаем длителность всех гласных интервалов в записи
            sum_dur_vowels += duration
            # снимаем флаг
            postpause_flag = 0

          # выделяем паузы
          elif sound == '':
            pauses += duration
            total_duration += duration
            # ставим флаг для следующего звука
            postpause_flag = 1

          # выделяем согласные
          else:
            total_duration += duration
            # записываем отдельно звук в нулевой позиции в обход ошибки
            if n == 0:
              clusters_and_duration.append(['C', duration])

            # объединяем длительности, если предыдущий звук тоже был согласным
            elif clusters_and_duration[-1][0] == 'C' and postpause_flag == 0:
              clusters_and_duration[-1][1] += duration

            # фиксируем длительность отдельного согласного звука
            else:
              clusters_and_duration.append(['C', duration])

            # снимаем флаг
            postpause_flag = 0

  # функция возвращает метки кластеров и их длительность и общую длительность всех гласных интервалов
    return clusters_and_duration, sum_dur_vowels, total_duration, pauses


def delta_C(cons_clusters):
  # применяем функцию numpy среднеквадратического отклонения
  dC = np.std(cons_clusters)
  return dC

def percent_V(vowels, total_wo_pauses):
  pV = vowels / total_wo_pauses
  return pV


# point_1 = np.array((0, 0, 0))
# point_2 = np.array((3, 3, 3))
def count_eucl(point_1, point_2):
    # Initializing the points
    # Get the square of the difference of the 2 vectors
    square = np.square(point_1 - point_2)
    # Get the sum of the square
    sum_square = np.sum(square)
    # The last step is to get the square root and print the Euclidean distance
    distance = np.sqrt(sum_square)
    return distance

 
ex_dict = {"eng": np.array((0.0535, 0.401)), "kat": np.array((0.0452, 0.456)), "jap": np.array((0.0356, 0.531))}


def classify_rhytm(dC, pV):
    our = np.array((dC, pV))
    res = {}
    if (dC > 0.08 and pV > 0.45) or (dC < 0.03 and pV < 0.04):
        text = "Вы не укладываетесь ни в какие рамки и прекрасны в этом!"
    else:
        for k, v in ex_dict.items():
            res[k] = count_eucl(our, v)
        
        sorted_tuples = sorted(res.items(), key=lambda item: item[1])
        sorted_res = {k: v for k, v in sorted_tuples}
        if [i for i in sorted_res.keys()][0] == "eng":
            text = "По типу ритма ваша речь близка к тактосчитающим языкам (английский)."
        if [i for i in sorted_res.keys()][0] == "kat":
            text = "По типу ритма ваша речь близка к слогосчитающим языкам (испанский)."
        if [i for i in sorted_res.keys()][0] == "jap":
            text = "По типу ритма ваша речь близка к моросчитающим языкам (японский)."
    return text
        


def transcribe(audio):
    y, sr = sf.read(audio, samplerate=16000)
    input_values = processor(y, sampling_rate=sr, return_tensors="pt").input_values

    logits = model(input_values).logits

    predicted_ids = torch.argmax(logits, dim=-1)
    
    chars, char_start_times, char_end_times = count_char_borders(predicted_ids, input_values, processor)

    clusters_and_duration, sum_dur_vowels, total_duration, pauses = preprocess(chars, char_start_times)
    
    cons_clusters = []
    
    # параметры для ΔC
    for x in clusters_and_duration:
      if x[0] == 'C':
        cons_clusters.append(x[1])
    
    # параметры для %V
    vowels_duration = sum_dur_vowels
    duration_without_pauses = total_duration - pauses

    # расчёт метрик
    dC = delta_C(cons_clusters) / 5
    pV = percent_V(vowels_duration, duration_without_pauses) * 5
    
    transcription = processor.batch_decode(logits.detach().numpy()).text[0]
    
    text = {"transcription": transcription}

    text['dC'] = dC

    text['pV'] = pV
    
    cl = classify_rhytm(dC, pV)
    
    text['result'] = cl
    
    return text

iface = gr.Interface(
    fn=transcribe, 
    inputs=gr.Audio(type="filepath"), 
    outputs="text",
    title="Mihaj/Wav2Vec2RhytmAnalyzer",
    description="Демо анализатор ритма на основе модели Wav2Vec large от bond005.",
)

iface.launch()