Spaces:
Sleeping
Sleeping
import os | |
os.environ["XDG_CONFIG_HOME"] = "/tmp" | |
os.environ["XDG_CACHE_HOME"] = "/tmp" | |
os.environ["HF_HOME"] = "/tmp/huggingface" # pour les modèles/datasets | |
os.environ["TRANSFORMERS_CACHE"] = "/tmp/huggingface/transformers" | |
os.environ["HF_HUB_CACHE"] = "/tmp/huggingface/hub" | |
import streamlit as st | |
import tempfile | |
import pandas as pd | |
from datasets import load_dataset | |
from transformers import WhisperProcessor, WhisperForConditionalGeneration | |
from peft import PeftModel | |
import torch | |
import librosa | |
import numpy as np | |
import evaluate | |
import tempfile | |
from huggingface_hub import snapshot_download | |
from transformers import pipeline | |
import openai | |
from openai import OpenAI | |
st.title("📊 Évaluation WER d'un modèle Whisper") | |
st.markdown("Ce Space permet d'évaluer la performance WER d'un modèle Whisper sur un dataset audio.") | |
# Section : Choix du modèle | |
st.subheader("1. Choix du modèle") | |
model_option = st.radio("Quel modèle veux-tu utiliser ?", ( | |
"Whisper Large (baseline)", | |
"Whisper Large + LoRA (SimpleFrog/whisper_finetuned)", | |
"Whisper Large + LoRA + Post-processing Mistral 7B", | |
"Whisper Large + LoRA + Post-processing GPT-4o" | |
)) | |
# Section : Lien du dataset | |
st.subheader("2. Chargement du dataset Hugging Face") | |
dataset_link = st.text_input("Lien du dataset (format: user/dataset_name)", value="SimpleFrog/Dataset_Test") | |
hf_token = st.text_input("Token Hugging Face (si dataset privé)", type="password") | |
openai_api_key = st.text_input("Clé API OpenAI (pour GPT-4o)", type="password") | |
if hf_token: | |
from huggingface_hub import login | |
login(hf_token) | |
# Section : Choix du split | |
split_option = st.selectbox( | |
"Choix du split à évaluer", | |
options=["Tous", "train", "validation", "test"], | |
index=0 # par défaut "Tous" | |
) | |
# Section : Choix du nombre maximal d'exemples à évaluer | |
max_examples_option = st.selectbox( | |
"Nombre maximum d'audios à traiter", | |
options=["1", "5", "10", "Tous"], | |
index=3 # par défaut "Tous" | |
) | |
# Section : Bouton pour lancer l'évaluation | |
start_eval = st.button("🚀 Lancer l'évaluation WER") | |
if start_eval: | |
st.subheader("🔍 Traitement en cours...") | |
# 🔹 Télécharger dataset | |
with st.spinner("Chargement du dataset..."): | |
try: | |
dataset_full = load_dataset(dataset_link, split="train", token=hf_token) | |
# 🔹 Filtrage selon la colonne 'split' | |
if split_option != "Tous": | |
dataset = dataset_full.filter(lambda x: x.get("split", "unknown") == split_option) | |
else: | |
dataset = dataset_full | |
if len(dataset) == 0: | |
st.warning(f"Aucun exemple trouvé pour le split sélectionné : '{split_option}'.") | |
st.stop() | |
except Exception as e: | |
st.error(f"Erreur lors du chargement du dataset : {e}") | |
st.stop() | |
# Limiter le nombre d'exemples selon la sélection | |
if max_examples_option != "Tous": | |
max_examples = int(max_examples_option) | |
dataset = dataset.select(range(min(max_examples, len(dataset)))) | |
# 🔹 Charger le modèle choisi | |
with st.spinner("Chargement du modèle..."): | |
base_model_name = "openai/whisper-large" | |
model = WhisperForConditionalGeneration.from_pretrained(base_model_name) | |
if "LoRA" in model_option: | |
model = PeftModel.from_pretrained(model, "SimpleFrog/whisper_finetuned", token=hf_token) | |
processor = WhisperProcessor.from_pretrained(base_model_name) | |
model.eval() | |
# Charger le pipeline de Mistral si post-processing demandé | |
if "Post-processing Mistral" in model_option: | |
with st.spinner("Chargement du modèle de post-traitement Mistral..."): | |
postproc_pipe = pipeline( | |
"text2text-generation", | |
model="mistralai/Mistral-7B-Instruct-v0.2", | |
device_map="auto", # ou device=0 si tu veux forcer le GPU | |
torch_dtype=torch.float16 # optionnel mais plus léger | |
) | |
st.success("✅ Modèle Mistral chargé.") | |
def postprocess_with_llm(text): | |
prompt = f"Tu es CorrecteurAI, une AI française qui permet de corriger les erreurs de saisie vocal. La translation d'un enregistrement audio tiré d'une inspection détaillé de pont t'es envoyé et tu renvoies le texte identique mais avec les éventuelles corrections si des erreurs sont détectés. Le texte peut comprendre du vocabulaire technique associé aux ouvrages d'art. Renvoies uniquement le texte corrigé en français et sans autre commentaire. Voici le texte : {text}" | |
result = postproc_pipe(prompt, max_new_tokens=256)[0]["generated_text"] | |
return result.strip() | |
#fonction process GPT4o | |
def postprocess_with_gpt4o(text, api_key): | |
client = OpenAI(api_key=api_key) | |
response = client.chat.completions.create( | |
model="gpt-4o", | |
messages=[ | |
{"role": "system", "content": "Tu es CorrecteurAI, une AI française qui permet de corriger les erreurs de saisie vocal. La translation d'un enregistrement audio tiré d'une inspection détaillé de pont t'es envoyé et tu renvoies le texte identique mais avec les éventuelles corrections si des erreurs sont détectés. Le texte peut comprendre du vocabulaire technique associé aux ouvrages d'art. Renvoies uniquement le texte corrigé en français et sans autre commentaire."}, | |
{"role": "user", "content": f"Corrige ce texte : {text}"} | |
], | |
temperature=0.3, | |
max_tokens=512 | |
) | |
return response.choices[0].message.content.strip() | |
# 🔹 Préparer WER metric | |
wer_metric = evaluate.load("wer") | |
results = [] | |
# Téléchargement explicite du dossier audio (chemin local vers chaque fichier) | |
repo_local_path = snapshot_download(repo_id=dataset_link, repo_type="dataset", token=hf_token) | |
for example in dataset: | |
st.write("Exemple brut :", example) | |
try: | |
reference = example["text"] | |
waveform = example["audio"]["array"] | |
audio_path = example["audio"]["path"] | |
waveform = np.expand_dims(waveform, axis=0) | |
inputs = processor(waveform, sampling_rate=16000, return_tensors="pt") | |
with torch.no_grad(): | |
pred_ids = model.generate(input_features=inputs.input_features) | |
prediction = processor.batch_decode(pred_ids, skip_special_tokens=True)[0] | |
# === Post-processing conditionnel === | |
if "Post-processing Mistral" in model_option: | |
st.write("⏳ Post-processing avec Mistral...") | |
postprocessed_prediction = postprocess_with_llm(prediction) | |
st.write("✅ Terminé.") | |
final_prediction = postprocessed_prediction | |
elif "Post-processing GPT-4o" in model_option: | |
if not openai_api_key: | |
st.error("Clé API OpenAI requise pour GPT-4o.") | |
st.stop() | |
st.write("🤖 Post-processing avec GPT-4o...") | |
try: | |
postprocessed_prediction = postprocess_with_gpt4o(prediction, openai_api_key) | |
except Exception as e: | |
postprocessed_prediction = f"[Erreur GPT-4o: {e}]" | |
final_prediction = postprocessed_prediction | |
else: | |
postprocessed_prediction = "-" | |
final_prediction = prediction | |
# 🔹 Nettoyage ponctuation pour WER "sans ponctuation" | |
def clean(text): | |
return ''.join([c for c in text.lower() if c.isalnum() or c.isspace()]).strip() | |
ref_clean = clean(reference) | |
pred_clean = clean(final_prediction) | |
wer = wer_metric.compute(predictions=[pred_clean], references=[ref_clean]) | |
results.append({ | |
"Fichier": audio_path, | |
"Référence": reference, | |
"Transcription brute": prediction, | |
"Transcription corrigée": postprocessed_prediction, | |
"WER": round(wer, 4) | |
}) | |
except Exception as e: | |
results.append({ | |
"Fichier": example["audio"].get("path", "unknown"), | |
"Référence": "Erreur", | |
"Transcription brute": f"Erreur: {e}", | |
"Transcription corrigée": "-", | |
"WER": "-" | |
}) | |
# 🔹 Générer le tableau de résultats | |
df = pd.DataFrame(results) | |
# 🔹 Créer un fichier temporaire pour le CSV | |
with tempfile.NamedTemporaryFile(mode="w+", delete=False, suffix=".csv") as tmp_csv: | |
df.to_csv(tmp_csv.name, index=False) | |
mean_wer = df[df["WER"] != "-"]["WER"].mean() | |
st.markdown(f"### 🎯 WER moyen (sans ponctuation) : `{mean_wer:.3f}`") | |
# 🔹 Bouton de téléchargement | |
with open(tmp_csv.name, "rb") as f: | |
st.download_button( | |
label="📥 Télécharger les résultats WER (.csv)", | |
data=f, | |
file_name="wer_results.csv", | |
mime="text/csv" | |
) | |