WER_Evaluation / app.py
SimpleFrog's picture
Update app.py
f56d7bc verified
import os
os.environ["XDG_CONFIG_HOME"] = "/tmp"
os.environ["XDG_CACHE_HOME"] = "/tmp"
os.environ["HF_HOME"] = "/tmp/huggingface" # pour les modèles/datasets
os.environ["TRANSFORMERS_CACHE"] = "/tmp/huggingface/transformers"
os.environ["HF_HUB_CACHE"] = "/tmp/huggingface/hub"
import streamlit as st
import tempfile
import pandas as pd
from datasets import load_dataset
from transformers import WhisperProcessor, WhisperForConditionalGeneration
from peft import PeftModel
import torch
import librosa
import numpy as np
import evaluate
import tempfile
from huggingface_hub import snapshot_download
from transformers import pipeline
import openai
from openai import OpenAI
st.title("📊 Évaluation WER d'un modèle Whisper")
st.markdown("Ce Space permet d'évaluer la performance WER d'un modèle Whisper sur un dataset audio.")
# Section : Choix du modèle
st.subheader("1. Choix du modèle")
model_option = st.radio("Quel modèle veux-tu utiliser ?", (
"Whisper Large (baseline)",
"Whisper Large + LoRA (SimpleFrog/whisper_finetuned)",
"Whisper Large + LoRA + Post-processing Mistral 7B",
"Whisper Large + LoRA + Post-processing GPT-4o"
))
# Section : Lien du dataset
st.subheader("2. Chargement du dataset Hugging Face")
dataset_link = st.text_input("Lien du dataset (format: user/dataset_name)", value="SimpleFrog/Dataset_Test")
hf_token = st.text_input("Token Hugging Face (si dataset privé)", type="password")
openai_api_key = st.text_input("Clé API OpenAI (pour GPT-4o)", type="password")
if hf_token:
from huggingface_hub import login
login(hf_token)
# Section : Choix du split
split_option = st.selectbox(
"Choix du split à évaluer",
options=["Tous", "train", "validation", "test"],
index=0 # par défaut "Tous"
)
# Section : Choix du nombre maximal d'exemples à évaluer
max_examples_option = st.selectbox(
"Nombre maximum d'audios à traiter",
options=["1", "5", "10", "Tous"],
index=3 # par défaut "Tous"
)
# Section : Bouton pour lancer l'évaluation
start_eval = st.button("🚀 Lancer l'évaluation WER")
if start_eval:
st.subheader("🔍 Traitement en cours...")
# 🔹 Télécharger dataset
with st.spinner("Chargement du dataset..."):
try:
dataset_full = load_dataset(dataset_link, split="train", token=hf_token)
# 🔹 Filtrage selon la colonne 'split'
if split_option != "Tous":
dataset = dataset_full.filter(lambda x: x.get("split", "unknown") == split_option)
else:
dataset = dataset_full
if len(dataset) == 0:
st.warning(f"Aucun exemple trouvé pour le split sélectionné : '{split_option}'.")
st.stop()
except Exception as e:
st.error(f"Erreur lors du chargement du dataset : {e}")
st.stop()
# Limiter le nombre d'exemples selon la sélection
if max_examples_option != "Tous":
max_examples = int(max_examples_option)
dataset = dataset.select(range(min(max_examples, len(dataset))))
# 🔹 Charger le modèle choisi
with st.spinner("Chargement du modèle..."):
base_model_name = "openai/whisper-large"
model = WhisperForConditionalGeneration.from_pretrained(base_model_name)
if "LoRA" in model_option:
model = PeftModel.from_pretrained(model, "SimpleFrog/whisper_finetuned", token=hf_token)
processor = WhisperProcessor.from_pretrained(base_model_name)
model.eval()
# Charger le pipeline de Mistral si post-processing demandé
if "Post-processing Mistral" in model_option:
with st.spinner("Chargement du modèle de post-traitement Mistral..."):
postproc_pipe = pipeline(
"text2text-generation",
model="mistralai/Mistral-7B-Instruct-v0.2",
device_map="auto", # ou device=0 si tu veux forcer le GPU
torch_dtype=torch.float16 # optionnel mais plus léger
)
st.success("✅ Modèle Mistral chargé.")
def postprocess_with_llm(text):
prompt = f"Tu es CorrecteurAI, une AI française qui permet de corriger les erreurs de saisie vocal. La translation d'un enregistrement audio tiré d'une inspection détaillé de pont t'es envoyé et tu renvoies le texte identique mais avec les éventuelles corrections si des erreurs sont détectés. Le texte peut comprendre du vocabulaire technique associé aux ouvrages d'art. Renvoies uniquement le texte corrigé en français et sans autre commentaire. Voici le texte : {text}"
result = postproc_pipe(prompt, max_new_tokens=256)[0]["generated_text"]
return result.strip()
#fonction process GPT4o
def postprocess_with_gpt4o(text, api_key):
client = OpenAI(api_key=api_key)
response = client.chat.completions.create(
model="gpt-4o",
messages=[
{"role": "system", "content": "Tu es CorrecteurAI, une AI française qui permet de corriger les erreurs de saisie vocal. La translation d'un enregistrement audio tiré d'une inspection détaillé de pont t'es envoyé et tu renvoies le texte identique mais avec les éventuelles corrections si des erreurs sont détectés. Le texte peut comprendre du vocabulaire technique associé aux ouvrages d'art. Renvoies uniquement le texte corrigé en français et sans autre commentaire."},
{"role": "user", "content": f"Corrige ce texte : {text}"}
],
temperature=0.3,
max_tokens=512
)
return response.choices[0].message.content.strip()
# 🔹 Préparer WER metric
wer_metric = evaluate.load("wer")
results = []
# Téléchargement explicite du dossier audio (chemin local vers chaque fichier)
repo_local_path = snapshot_download(repo_id=dataset_link, repo_type="dataset", token=hf_token)
for example in dataset:
st.write("Exemple brut :", example)
try:
reference = example["text"]
waveform = example["audio"]["array"]
audio_path = example["audio"]["path"]
waveform = np.expand_dims(waveform, axis=0)
inputs = processor(waveform, sampling_rate=16000, return_tensors="pt")
with torch.no_grad():
pred_ids = model.generate(input_features=inputs.input_features)
prediction = processor.batch_decode(pred_ids, skip_special_tokens=True)[0]
# === Post-processing conditionnel ===
if "Post-processing Mistral" in model_option:
st.write("⏳ Post-processing avec Mistral...")
postprocessed_prediction = postprocess_with_llm(prediction)
st.write("✅ Terminé.")
final_prediction = postprocessed_prediction
elif "Post-processing GPT-4o" in model_option:
if not openai_api_key:
st.error("Clé API OpenAI requise pour GPT-4o.")
st.stop()
st.write("🤖 Post-processing avec GPT-4o...")
try:
postprocessed_prediction = postprocess_with_gpt4o(prediction, openai_api_key)
except Exception as e:
postprocessed_prediction = f"[Erreur GPT-4o: {e}]"
final_prediction = postprocessed_prediction
else:
postprocessed_prediction = "-"
final_prediction = prediction
# 🔹 Nettoyage ponctuation pour WER "sans ponctuation"
def clean(text):
return ''.join([c for c in text.lower() if c.isalnum() or c.isspace()]).strip()
ref_clean = clean(reference)
pred_clean = clean(final_prediction)
wer = wer_metric.compute(predictions=[pred_clean], references=[ref_clean])
results.append({
"Fichier": audio_path,
"Référence": reference,
"Transcription brute": prediction,
"Transcription corrigée": postprocessed_prediction,
"WER": round(wer, 4)
})
except Exception as e:
results.append({
"Fichier": example["audio"].get("path", "unknown"),
"Référence": "Erreur",
"Transcription brute": f"Erreur: {e}",
"Transcription corrigée": "-",
"WER": "-"
})
# 🔹 Générer le tableau de résultats
df = pd.DataFrame(results)
# 🔹 Créer un fichier temporaire pour le CSV
with tempfile.NamedTemporaryFile(mode="w+", delete=False, suffix=".csv") as tmp_csv:
df.to_csv(tmp_csv.name, index=False)
mean_wer = df[df["WER"] != "-"]["WER"].mean()
st.markdown(f"### 🎯 WER moyen (sans ponctuation) : `{mean_wer:.3f}`")
# 🔹 Bouton de téléchargement
with open(tmp_csv.name, "rb") as f:
st.download_button(
label="📥 Télécharger les résultats WER (.csv)",
data=f,
file_name="wer_results.csv",
mime="text/csv"
)