Spaces:
Running
Running
import gradio as gr | |
import os | |
import json | |
import pandas as pd | |
from datasets import load_dataset, DatasetDict, Dataset, Audio | |
from huggingface_hub import HfApi, whoami, login | |
import tempfile | |
import shutil | |
import gc | |
import time | |
import psutil | |
from pydub import AudioSegment # For audio trimming | |
import soundfile as sf | |
from tenacity import retry, stop_after_attempt, wait_exponential | |
import re | |
import numpy as np | |
from huggingface_hub import hf_hub_download | |
# Log in with the Hugging Face token when one is configured (Space secret "hf_token").
# Without the guard, login(None) would try to prompt interactively, which cannot
# work inside a Space.
token = os.getenv("hf_token")
if token:
    login(token)
# Configuration
HF_DATASET_NAME: str = "navidved/channelb-raw-data"  # source dataset repo; also hosts the label JSON
AUDIO_DIR: str = "audio"  # directory tried for relative audio paths
SAVE_PATH: str = "labeled_data.json"  # label file name, locally and inside the dataset repo
ALLOWED_USERS: list = ["vargha", "navidved"]  # HF accounts allowed past the login screen
CURRENT_USERNAME = None  # set by hf_login; a module global, so shared by all sessions of this process
PAGE_SIZE: int = 100  # Number of samples to load at once

# Global state variables
current_page = 0  # page number whose rows are held in current_page_data
ds_iter = None  # streaming iterator over the dataset's train split
current_page_data = None  # DataFrame with the rows of the current page
audio_backup: dict = {}  # label key -> original audio value, used by "Undo Trim"
# Load saved labels from JSON file (and convert old formats if needed) | |
def load_saved_labels(): | |
# First, try to load from the local file system. | |
if os.path.exists(SAVE_PATH): | |
try: | |
with open(SAVE_PATH, "r", encoding="utf-8") as f: | |
data = json.load(f) | |
except Exception as e: | |
print("Error loading local JSON file: " + str(e)) | |
data = {} | |
else: | |
data = {} | |
# If no local file found or it's empty, try to download from the HF dataset repo. | |
if not data: | |
try: | |
# Download the file from HF repo (make sure repo_type="dataset" is specified) | |
hf_path = hf_hub_download( | |
repo_id=HF_DATASET_NAME, | |
filename=SAVE_PATH, | |
repo_type="dataset", | |
token=token | |
) | |
with open(hf_path, "r", encoding="utf-8") as f: | |
data = json.load(f) | |
# Write the downloaded file to the local path for caching. | |
with open(SAVE_PATH, "w", encoding="utf-8") as f: | |
json.dump(data, f, ensure_ascii=False, indent=4) | |
print("Loaded JSON file from HF dataset repository and cached it locally.") | |
except Exception as e: | |
print("Error loading JSON file from HF repo: " + str(e)) | |
data = {} | |
# Convert any old string-formats to the new dict format. | |
for key, value in data.items(): | |
if isinstance(value, str): | |
data[key] = {"transcript": value, "reviewer": "unreviewed"} | |
return data | |
# In-memory label store used by every handler below. Trim/delete handlers mutate
# it in place; save_sample_data writes the whole dict back to SAVE_PATH.
saved_labels = load_saved_labels()
# New function: Push the JSON file to the Hugging Face dataset repository.
def push_json_to_hf():
    """Upload the local label file to the same path in the HF dataset repo.

    Best effort: any failure is printed and swallowed so labeling continues.
    """
    try:
        hub = HfApi()
        hub.upload_file(
            repo_id=HF_DATASET_NAME,
            repo_type="dataset",
            path_or_fileobj=SAVE_PATH,
            path_in_repo=SAVE_PATH,
            token=token,
        )
        print("Uploaded labeled_data.json to Hugging Face repository")
    except Exception as err:
        print("Error uploading json file: " + str(err))
# Initialize dataset iterator for streaming
def init_dataset_iterator():
    """(Re)create the global ``ds_iter`` over the streaming ``train`` split.

    Returns:
        bool: True when the stream was opened, False after logging an error.
    """
    global ds_iter
    try:
        stream = load_dataset(HF_DATASET_NAME, split="train", streaming=True)
        ds_iter = iter(stream)
    except Exception as err:
        print(f"Error initializing dataset iterator: {err}")
        return False
    return True
# Load a page of data | |
def load_page_data(page_num=0): | |
global ds_iter, current_page_data, current_page | |
if ds_iter is None: | |
if not init_dataset_iterator(): | |
return pd.DataFrame(columns=["audio", "sentence"]) | |
if page_num < current_page: | |
ds_iter = iter(load_dataset(HF_DATASET_NAME, split="train", streaming=True)) | |
current_page = 0 | |
samples_to_skip = page_num * PAGE_SIZE - (current_page * PAGE_SIZE) if page_num > current_page else 0 | |
for _ in range(samples_to_skip): | |
try: | |
next(ds_iter) | |
except StopIteration: | |
break | |
samples = [] | |
for _ in range(PAGE_SIZE): | |
try: | |
samples.append(next(ds_iter)) | |
except StopIteration: | |
break | |
current_page = page_num | |
current_page_data = pd.DataFrame(samples) | |
gc.collect() | |
return current_page_data | |
# Get dataset info (number of samples)
def get_dataset_info():
    """Return ``{'num_samples': N}`` for the train split, with N = -1 when unknown.

    Streaming datasets often carry no split metadata (``info.splits`` may be
    None). That case is reported as "unknown" rather than logged as an error,
    and a missing or zero count is normalised to -1.
    """
    try:
        info = load_dataset(HF_DATASET_NAME, split="train", streaming=True).info
        splits = getattr(info, "splits", None)
        if splits and "train" in splits:
            num_examples = splits["train"].num_examples
            return {'num_samples': num_examples if num_examples else -1}
        return {'num_samples': -1}
    except Exception as e:
        print(f"Error getting dataset info: {e}")
        return {'num_samples': -1}
# Open the stream and load the first page at import time so the UI built below
# can be initialised from it. total_samples is the train-split size reported by
# get_dataset_info, or -1 when that size is unknown.
init_dataset_iterator()
current_page_data = load_page_data(0)
dataset_info = get_dataset_info()
total_samples = dataset_info.get('num_samples', -1)
# Resolve the audio file path (or data) from an entry
def get_audio_path(audio_entry):
    """Turn a dataset ``audio`` value into something the audio player accepts.

    - dict with ``array`` and ``sampling_rate``: ``(sampling_rate, array)``
    - any other dict: its ``path`` value (None if absent)
    - http(s) URL string: returned unchanged
    - other string: itself if it exists on disk, else ``AUDIO_DIR/<entry>`` if
      that exists, else the string unchanged
    - anything else: None
    """
    if isinstance(audio_entry, str):
        if audio_entry.startswith(("http://", "https://")) or os.path.exists(audio_entry):
            return audio_entry
        candidate = os.path.join(AUDIO_DIR, audio_entry)
        return candidate if os.path.exists(candidate) else audio_entry
    if not isinstance(audio_entry, dict):
        return None
    has_raw_audio = "array" in audio_entry and "sampling_rate" in audio_entry
    if has_raw_audio:
        return (audio_entry["sampling_rate"], audio_entry["array"])
    return audio_entry.get("path", None)
# Save sample data (transcript, reviewer, and any edited audio info)
# Number of successful saves since the label file was last pushed to the Hub.
_saves_since_push = 0


def save_sample_data(page_idx, idx, transcript, reviewer):
    """Record the transcript/reviewer for row *idx* of the loaded page and persist it.

    Other fields already stored for the sample (e.g. ``audio_edit``/``deleted``)
    are kept. The whole label store is written to SAVE_PATH on every call and
    pushed to the HF dataset repo after every 10 successful saves.

    Returns:
        str: a status message ("Invalid index", a success note, or the error).
    """
    global current_page_data, _saves_since_push
    if idx < 0 or idx >= len(current_page_data):
        return "Invalid index"
    absolute_idx = page_idx * PAGE_SIZE + idx
    audio_entry = current_page_data.iloc[idx]["audio"]
    # Same key format as every other handler in this file; must stay stable for saved labels.
    key = f"{absolute_idx}_{os.path.basename(str(audio_entry))}"
    entry = saved_labels.setdefault(key, {})
    entry["transcript"] = transcript.strip()
    entry["reviewer"] = reviewer
    try:
        with open(SAVE_PATH, "w", encoding="utf-8") as f:
            json.dump(saved_labels, f, ensure_ascii=False, indent=4)
    except Exception as e:
        return f"Error saving: {str(e)}"
    # Count saves rather than keys so re-saving existing labels neither pushes on
    # every call nor is ignored.
    _saves_since_push += 1
    if _saves_since_push >= 10:
        push_json_to_hf()
        _saves_since_push = 0
    return f"✓ Saved data for {key}"
# Retrieve sample data (audio, transcript, status, reviewer) for a given index.
def get_sample(page_idx, idx):
    """Look up one row of the loaded page together with its saved label.

    Returns:
        tuple: ``(audio, transcript, status, reviewer)``. A saved transcript takes
        precedence over the dataset's ``sentence``. A ``deleted`` label yields
        audio None, and an ``audio_edit`` label yields the edited file path.
    """
    global current_page_data
    row_count = len(current_page_data)
    if not 0 <= idx < row_count:
        return None, "", f"Invalid index. Range is 0-{row_count-1}", "unreviewed"

    row = current_page_data.iloc[idx]
    absolute_idx = page_idx * PAGE_SIZE + idx
    label_key = f"{absolute_idx}_{os.path.basename(str(row['audio']))}"
    fallback_sentence = row["sentence"]
    label = saved_labels.get(label_key, {"transcript": fallback_sentence, "reviewer": "unreviewed"})

    if label.get("deleted", False):
        audio_val = None
    elif "audio_edit" in label:
        audio_val = label["audio_edit"]
    else:
        audio_val = get_audio_path(row["audio"])

    status = f"Sample {absolute_idx+1}" + (f" of {total_samples}" if total_samples > 0 else "")
    return (
        audio_val,
        label.get("transcript", fallback_sentence),
        status,
        label.get("reviewer", "unreviewed"),
    )
# Load the interface components for a sample index; also backup original audio.
def load_interface(page_idx, idx):
    """Build the seven UI outputs for row *idx* of page *page_idx*.

    The first time a sample is shown, its unedited dataset audio is remembered in
    ``audio_backup`` so "Undo Trim" can restore it. An out-of-range *idx* produces
    get_sample's "Invalid index" status instead of raising.

    Returns:
        tuple: (page_idx, idx, audio, transcript, reviewer, status, original_transcript).
    """
    audio, text, base_status, saved_reviewer = get_sample(page_idx, idx)
    if 0 <= idx < len(current_page_data):
        absolute_idx = page_idx * PAGE_SIZE + idx
        audio_entry = current_page_data.iloc[idx]["audio"]
        key = f"{absolute_idx}_{os.path.basename(str(audio_entry))}"
        if key not in audio_backup:
            # Back up the dataset audio rather than `audio`: the latter is already
            # the trimmed file (or None) for samples edited in an earlier session.
            audio_backup[key] = get_audio_path(audio_entry)
    status = f"{base_status} - Page {page_idx+1} - Reviewer: {saved_reviewer}"
    return page_idx, idx, audio, text, saved_reviewer, status, text
# Navigation functions
def next_sample(page_idx, idx, current_text, current_annotator, original_transcript):
    """Advance to the next sample, loading the next page when at the end of this one.

    The text/annotator/transcript arguments are unused; they keep the signature
    shared by all navigation handlers. At the end of the dataset the UI stays on
    the current sample (with "End of dataset" appended to the status) instead of
    pointing at a row that does not exist.

    Returns:
        tuple: the seven outputs produced by load_interface.
    """
    global current_page_data
    if idx < len(current_page_data) - 1:
        return load_interface(page_idx, idx + 1)
    if len(load_page_data(page_idx + 1)) > 0:
        return load_interface(page_idx + 1, 0)
    # No next page: reload the page we were on so current_page_data matches the UI again.
    load_page_data(page_idx)
    outputs = list(load_interface(page_idx, idx))
    outputs[5] = f"{outputs[5]} - End of dataset"
    return tuple(outputs)
def go_next_without_save(page_idx, idx, current_text, current_annotator, original_transcript):
    """Advance without saving; the original transcript is forwarded in place of the edited text."""
    return next_sample(
        page_idx,
        idx,
        current_text=original_transcript,
        current_annotator=current_annotator,
        original_transcript=original_transcript,
    )
def prev_sample(page_idx, idx, current_text, current_annotator, original_transcript):
    """Save an edited transcript (as CURRENT_USERNAME), then step back one sample.

    At the top of a page, the previous page is loaded and its last row is shown.
    On the very first sample of page 0 the position does not change.
    """
    global current_page_data
    edited = current_text.strip() != original_transcript.strip()
    if edited:
        save_sample_data(page_idx, idx, current_text, CURRENT_USERNAME)

    if idx > 0:
        return load_interface(page_idx, idx - 1)
    if page_idx <= 0:
        return load_interface(page_idx, idx)
    previous_page = page_idx - 1
    previous_rows = load_page_data(previous_page)
    return load_interface(previous_page, len(previous_rows) - 1)
def jump_to(target_idx, page_idx, idx, current_text, current_annotator, original_transcript):
    """Save an edited transcript, then jump to the 0-based global sample *target_idx*.

    Negative targets clamp to 0, and a target past the end of its page clamps to
    that page's last row. A non-numeric target, or one beyond the end of the
    dataset, leaves the UI on the current sample.
    """
    if current_text.strip() != original_transcript.strip():
        save_sample_data(page_idx, idx, current_text, CURRENT_USERNAME)
    try:
        target_idx = max(int(target_idx), 0)
    except (TypeError, ValueError):
        return load_interface(page_idx, idx)
    new_page_idx, new_idx = divmod(target_idx, PAGE_SIZE)
    new_data = load_page_data(new_page_idx)
    if len(new_data) == 0:
        # Target lies beyond the dataset: reload the page we came from so the
        # global page data matches the sample we stay on.
        load_page_data(page_idx)
        return load_interface(page_idx, idx)
    new_idx = min(new_idx, len(new_data) - 1)
    return load_interface(new_page_idx, new_idx)
def save_and_next_sample(page_idx, idx, current_text, current_annotator, original_transcript):
    """Always save the current transcript (as CURRENT_USERNAME), then move forward."""
    save_sample_data(page_idx, idx, current_text, CURRENT_USERNAME)
    # The just-saved text is handed on as the new baseline transcript.
    return next_sample(
        page_idx, idx, current_text, current_annotator, original_transcript=current_text
    )
# ----------------- Audio Editing Functions ----------------- | |
def trim_audio_action(page_idx, idx, trim_start, trim_end, current_text, current_annotator, original_transcript): | |
audio, transcript, base_status, saved_reviewer = get_sample(page_idx, idx) | |
temp_audio_file = None | |
# If the audio is provided as raw data (tuple), convert it to a temporary WAV file. | |
if isinstance(audio, tuple): | |
sample_rate, audio_array = audio | |
try: | |
import numpy as np | |
from scipy.io.wavfile import write | |
# If the audio array is in float format, convert it to int16 | |
if np.issubdtype(audio_array.dtype, np.floating): | |
# Normalize if necessary and then convert to int16 | |
audio_array = (audio_array * 32767).astype(np.int16) | |
else: | |
# Ensure the data is a numpy array | |
audio_array = np.array(audio_array) | |
temp_audio_file = os.path.join(tempfile.gettempdir(), f"temp_{page_idx}_{idx}.wav") | |
write(temp_audio_file, sample_rate, audio_array) | |
audio = temp_audio_file | |
except Exception as e: | |
return page_idx, idx, audio, transcript, saved_reviewer, f"Error converting raw audio: {str(e)}", transcript | |
# Now ensure audio is a file path before trimming. | |
if not isinstance(audio, str) or not os.path.exists(audio): | |
return page_idx, idx, audio, transcript, saved_reviewer, "Trimming not supported for this audio format.", transcript | |
try: | |
audio_seg = AudioSegment.from_file(audio) | |
start_ms = int(float(trim_start) * 1000) | |
end_ms = int(float(trim_end) * 1000) | |
trimmed_seg = audio_seg[start_ms:end_ms] | |
os.makedirs("trimmed_audio", exist_ok=True) | |
trimmed_path = os.path.join("trimmed_audio", f"trimmed_{os.path.basename(str(audio))}") | |
trimmed_seg.export(trimmed_path, format="wav") | |
absolute_idx = page_idx * PAGE_SIZE + idx | |
audio_entry = current_page_data.iloc[idx]["audio"] | |
key = f"{absolute_idx}_{os.path.basename(str(audio_entry))}" | |
if key not in saved_labels: | |
saved_labels[key] = {} | |
saved_labels[key]["audio_edit"] = trimmed_path | |
new_status = f"{base_status} [Trimmed]" | |
return page_idx, idx, trimmed_path, transcript, saved_reviewer, new_status, transcript | |
except Exception as e: | |
return page_idx, idx, audio, transcript, saved_reviewer, f"Error trimming audio: {str(e)}", transcript | |
def undo_trim_action(page_idx, idx, current_text, current_annotator, original_transcript):
    """Drop the sample's ``audio_edit`` and show its original audio again.

    The transcript box keeps ``current_text`` so that unsaved edits are not lost.

    Returns:
        tuple: the seven UI outputs (page, idx, audio, transcript, reviewer,
        status, original transcript).
    """
    if not 0 <= idx < len(current_page_data):
        audio, _saved_transcript, base_status, saved_reviewer = get_sample(page_idx, idx)
        return page_idx, idx, audio, current_text, saved_reviewer, base_status, original_transcript
    absolute_idx = page_idx * PAGE_SIZE + idx
    audio_entry = current_page_data.iloc[idx]["audio"]
    key = f"{absolute_idx}_{os.path.basename(str(audio_entry))}"
    if key in saved_labels:
        saved_labels[key].pop("audio_edit", None)
    # Read the sample only after the edit is gone, so `audio` is the dataset audio
    # rather than the trimmed file.
    audio, _saved_transcript, base_status, saved_reviewer = get_sample(page_idx, idx)
    # audio is None only for deleted samples; fall back to the backed-up original then.
    orig_audio = audio if audio is not None else audio_backup.get(key)
    new_status = f"{base_status} [Trim undone]"
    return page_idx, idx, orig_audio, current_text, saved_reviewer, new_status, original_transcript
def confirm_delete_audio(page_idx, idx, current_text, current_annotator, original_transcript):
    """Mark the current sample as deleted and replace its transcript with a marker.

    Any recorded trim (``audio_edit``) is dropped. This function does not write
    SAVE_PATH itself; the flag is persisted the next time the label store is saved.

    Returns:
        tuple: the seven UI outputs, with no audio and reviewer "deleted".
    """
    _audio, _transcript, base_status, _reviewer = get_sample(page_idx, idx)
    row_audio = current_page_data.iloc[idx]["audio"]
    label_key = f"{page_idx * PAGE_SIZE + idx}_{os.path.basename(str(row_audio))}"
    label = saved_labels.setdefault(label_key, {})
    label["deleted"] = True
    label["transcript"] = "AUDIO DELETED (This audio has been removed.)"
    label.pop("audio_edit", None)
    marker = label["transcript"]
    return page_idx, idx, None, marker, "deleted", f"{base_status} [Audio deleted]", marker
# ----------------- Export to Hugging Face Function -----------------
def sanitize_string(s):
    """Return ``str(s)`` with every character except word chars, '-', '.' and '/' replaced by '_'."""
    text = s if isinstance(s, str) else str(s)
    return re.sub(r'[^\w\-\./]', '_', text)
def sanitize_sentence(s):
    """Return *s* as a clean, UTF-8-encodable string.

    ``None`` (a missing sentence) becomes ``""`` rather than the literal text
    ``"None"``. Other non-strings are converted with ``str``. Characters that
    cannot be encoded as UTF-8 (e.g. lone surrogates) are dropped.
    """
    if s is None:
        return ""
    if not isinstance(s, str):
        s = str(s)
    return s.encode('utf-8', errors='ignore').decode('utf-8')
def push_to_hub_with_retry(dataset_dict, repo_id, private=True, token=None, max_attempts=3, base_delay=4.0):
    """Push dataset to Hugging Face Hub with retry logic.

    Calls ``dataset_dict.push_to_hub`` up to *max_attempts* times (at least once).
    After the n-th failure it sleeps ``base_delay * 2**(n-1)`` seconds before
    retrying. The final failure is re-raised to the caller.
    """
    max_attempts = max(1, int(max_attempts))
    for attempt in range(1, max_attempts + 1):
        print(f"Pushing dataset to {repo_id}")
        try:
            dataset_dict.push_to_hub(repo_id, private=private, token=token)
            return
        except Exception as e:
            if attempt >= max_attempts:
                raise
            delay = base_delay * 2 ** (attempt - 1)
            print(f"Push attempt {attempt}/{max_attempts} failed: {e}; retrying in {delay:.0f}s")
            time.sleep(delay)
def _read_audio_file(path, absolute_idx):
    """Read an audio file with soundfile into ``{"array", "sampling_rate"}``.

    Returns None (after logging why) when the file is missing, unreadable, or
    contains non-finite samples.
    """
    if not isinstance(path, str) or not os.path.exists(path):
        print(f"Sample {absolute_idx}: file not found: {path}")
        return None
    try:
        audio_array, sample_rate = sf.read(path)
    except Exception as e:
        print(f"Sample {absolute_idx}: audio load error: {str(e)}")
        return None
    if not np.all(np.isfinite(audio_array)):
        print(f"Sample {absolute_idx}: invalid audio array, skipping")
        return None
    print(f"Sample {absolute_idx}: loaded audio: {path}")
    return {"array": audio_array, "sampling_rate": sample_rate}


def _export_audio(audio_entry, absolute_idx):
    """Return ``(audio_dict, audio_key)`` for a streamed sample's ``audio`` value.

    ``audio_dict`` is ``{"array", "sampling_rate"}``, or None when the audio is a
    URL or is missing, unreadable or non-finite. ``audio_key`` is always a
    sanitized name, used for logging and legacy label lookups.
    """
    try:
        if isinstance(audio_entry, dict) and "array" in audio_entry:
            audio_key = sanitize_string(audio_entry.get("path", f"sample_{absolute_idx}.mp3"))
            if not np.all(np.isfinite(audio_entry["array"])):
                print(f"Sample {absolute_idx}: invalid audio array, skipping")
                return None, audio_key
            print(f"Sample {absolute_idx}: raw audio, path: {audio_key}")
            return {"array": audio_entry["array"], "sampling_rate": audio_entry["sampling_rate"]}, audio_key
        if isinstance(audio_entry, str):
            if audio_entry.startswith(("http://", "https://")):
                print(f"Sample {absolute_idx}: skipping URL: {audio_entry}")
                return None, sanitize_string(audio_entry)
            resolved_path = get_audio_path(audio_entry)
            return _read_audio_file(resolved_path, absolute_idx), sanitize_string(resolved_path)
        print(f"Sample {absolute_idx}: unhandled format: {type(audio_entry)}")
        return None, sanitize_string(str(audio_entry))
    except Exception as e:
        print(f"Sample {absolute_idx}: audio error: {str(e)}")
        return None, f"sample_{absolute_idx}"


def _export_label(sample, absolute_idx, audio_key):
    """Find the saved label entry for a streamed sample.

    Returns ``(key, entry)``, or ``(None, None)`` when nothing is saved. The first
    candidate is the key the labeling handlers write
    (``f"{idx}_{basename(str(audio))}"``). The other two are kept for labels
    stored under the sanitized-path key or the bare index.
    """
    candidates = (
        f"{absolute_idx}_{os.path.basename(str(sample['audio']))}",
        f"{absolute_idx}_{os.path.basename(audio_key)}",
        f"{absolute_idx}",
    )
    for key in candidates:
        entry = saved_labels.get(key)
        if isinstance(entry, dict):
            return key, entry
    return None, None


def _export_row(sample, absolute_idx):
    """Build the exported ``{"audio", "sentence"}`` row for one streamed sample.

    Returns ``(row, status)`` where status is "found", "missing" or "deleted".
    Samples flagged ``deleted`` are dropped (row is None). A saved ``audio_edit``
    (trimmed clip) replaces the original audio when its file can still be read.
    """
    audio_dict, audio_key = _export_audio(sample["audio"], absolute_idx)
    key, label = _export_label(sample, absolute_idx, audio_key)
    default_sentence = sanitize_sentence(sample.get("sentence", ""))
    if label is None:
        print(f"Sample {absolute_idx}: no label, default: {default_sentence[:50]}")
        return {"audio": audio_dict, "sentence": default_sentence}, "missing"
    if label.get("deleted", False):
        print(f"Sample {absolute_idx}: deleted by labeler, skipping")
        return None, "deleted"
    # Entries created only by the trim action carry no transcript.
    sentence = label.get("transcript", default_sentence)
    if label.get("audio_edit"):
        edited_audio = _read_audio_file(label["audio_edit"], absolute_idx)
        if edited_audio is not None:
            audio_dict = edited_audio
    print(f"Sample {absolute_idx}: label found ({key}): {str(sentence)[:50]}")
    return {"audio": audio_dict, "sentence": sentence}, "found"


def export_to_huggingface(repo_name, token, progress=gr.Progress()):
    """
    Export the labeled dataset to Hugging Face with 'audio' column cast to Audio
    and 'sentence' column, pushed as a DatasetDict to the Hub.

    The whole source stream is read in chunks of 100 samples. Each chunk is
    written to a temporary parquet file, and the files are combined and pushed
    as a private repo. Saved transcripts replace the dataset sentence, trimmed
    audio replaces the original, and deleted samples are left out.

    Args:
        repo_name (str): The Hugging Face repository name (e.g., "username/dataset-name").
        token (str): Hugging Face authentication token.
        progress (gr.Progress): Gradio progress tracker.
    Returns:
        str: Status message indicating success or failure.
    """
    try:
        start_time = time.time()
        if repo_name is None or not str(repo_name).strip():
            return "Export failed: please enter a repository name."
        repo_name = sanitize_string(str(repo_name).strip())
        print(f"Export started at {time.strftime('%Y-%m-%d %H:%M:%S')}")
        print(f"Repo name: {repo_name}, Token: {'[hidden]' if token else 'None'}")
        print(f"Total labels loaded: {len(saved_labels)}")
        print(f"Sample label keys: {list(saved_labels.keys())[:10]}")

        # The dataset size only drives the progress bar; the stream is read to its end
        # either way, so an unknown size no longer truncates the export.
        total_samples = get_dataset_info().get('num_samples', -1)
        progress_total = total_samples if total_samples > 0 else None
        chunk_size = 100
        print(f"Total samples: {progress_total or 'unknown'}, chunk size: {chunk_size}")
        progress((0, progress_total), desc="Starting export")

        with tempfile.TemporaryDirectory() as temp_dir:
            print(f"Temp dir: {temp_dir}")
            ds_iter = iter(load_dataset(HF_DATASET_NAME, split="train", streaming=True))
            print("Dataset initialized")
            chunk_paths = []
            counts = {"found": 0, "missing": 0, "deleted": 0}
            consumed = 0  # samples read from the stream (the global index), deleted ones included
            exported = 0
            chunk_idx = 0
            exhausted = False
            while not exhausted:
                chunk_start_time = time.time()
                chunk_data = []
                for _ in range(chunk_size):
                    try:
                        sample = next(ds_iter)
                    except StopIteration:
                        print(f"Chunk {chunk_idx + 1}: dataset end")
                        exhausted = True
                        break
                    row, row_status = _export_row(sample, consumed)
                    counts[row_status] += 1
                    consumed += 1
                    if row is not None:
                        chunk_data.append(row)
                if chunk_data:
                    chunk_dataset = Dataset.from_list(chunk_data).cast_column("audio", Audio())
                    chunk_path = os.path.join(temp_dir, f"chunk_{chunk_idx}.parquet")
                    chunk_dataset.to_parquet(chunk_path)
                    chunk_paths.append(chunk_path)
                    exported += len(chunk_data)
                    rss_mb = psutil.Process().memory_info().rss / 1024**2
                    print(
                        f"Chunk {chunk_idx + 1}: saved {len(chunk_data)} samples to {chunk_path}, "
                        f"took {time.time() - chunk_start_time:.2f}s, memory {rss_mb:.2f} MB"
                    )
                progress((consumed, progress_total), desc=f"Processed {consumed} samples")
                chunk_idx += 1
                del chunk_data
                gc.collect()

            print(
                f"Labels: {counts['found']} found, {counts['missing']} missing, "
                f"{counts['deleted']} deleted"
            )
            if not chunk_paths:
                print("No data")
                return "No data to export."

            print("Combining chunks")
            combined_dataset = Dataset.from_parquet(chunk_paths)
            dataset_dict = DatasetDict({"train": combined_dataset})
            progress(0.95, "Uploading to Hugging Face...")
            print(f"Pushing {exported} samples to {repo_name}")
            push_to_hub_with_retry(
                dataset_dict=dataset_dict,
                repo_id=repo_name,
                private=True,
                token=token,
            )
        print(f"Upload done, total time: {time.time() - start_time:.2f}s")
        progress(1.0, "Upload complete!")
        return f"Exported to huggingface.co/datasets/{repo_name}"
    except Exception as e:
        print(f"Error: {str(e)}")
        return f"Export failed: {str(e)}"
def hf_login(hf_token):
    """Authenticate the visitor with their HF token and toggle the UI accordingly.

    Only users listed in ALLOWED_USERS get through, and CURRENT_USERNAME is set
    on success.

    Returns:
        tuple: (login column update, main column update, reviewer name,
        token to keep in state, status message).
    """
    global CURRENT_USERNAME
    hf_token = (hf_token or "").strip()

    def _denied(message):
        # Keep the login form visible and the labeling UI hidden.
        return gr.update(visible=True), gr.update(visible=False), "", hf_token, message

    if not hf_token:
        # Never call whoami with a blank token. huggingface_hub may then fall back to
        # the token this Space stored via login(...) at startup, which would
        # authenticate the visitor as the Space owner.
        return _denied("Login failed: please enter a Hugging Face token.")
    try:
        username = whoami(token=hf_token)['name']
    except Exception as e:
        return _denied(f"Login failed: {str(e)}")
    if username not in ALLOWED_USERS:
        return _denied("User not authorized!")
    CURRENT_USERNAME = username
    return gr.update(visible=False), gr.update(visible=True), username, hf_token, "Login successful!"
# Set initial values from the first sample. They seed the components and the
# gr.State values built below; placeholders are used when the first page is empty.
if len(current_page_data) > 0:
    init_page_idx, init_idx, init_audio, init_text, init_reviewer, init_msg, init_original = load_interface(0, 0)
else:
    init_page_idx, init_idx, init_audio, init_text, init_reviewer, init_msg, init_original = 0, 0, None, "", "unreviewed", "No data available. Please check your dataset configuration.", ""
# ----------------- Build Gradio Interface -----------------
with gr.Blocks(title="ASR Dataset Labeling with HF Authentication") as demo:
    hf_token_state = gr.State("")
    # Login interface
    with gr.Column(visible=True, elem_id="login_container") as login_container:
        gr.Markdown("## HF Authentication\nPlease enter your Hugging Face token to proceed.")
        hf_token_input = gr.Textbox(label="Hugging Face Token", type="password", placeholder="Enter your HF token")
        login_button = gr.Button("Login")
        login_message = gr.Markdown("")
    # Main labeling interface (revealed by hf_login)
    with gr.Column(visible=False, elem_id="main_container") as main_container:
        gr.Markdown("# ASR Dataset Labeling Interface")
        gr.Markdown("Listen to audio and edit transcriptions. Changes are saved via the Save & Next button.")
        with gr.Row():
            current_page_idx = gr.State(value=init_page_idx)
            current_idx = gr.State(value=init_idx)
            original_transcript = gr.State(value=init_original)
            with gr.Column():
                audio_player = gr.Audio(value=init_audio, label="Audio", autoplay=True)
                transcript = gr.TextArea(value=init_text, label="Transcript", lines=5, placeholder="Edit transcript here...")
                reviewer = gr.Textbox(value=init_reviewer, label="Reviewer", placeholder="Reviewer (auto-filled)", interactive=False)
                status = gr.Markdown(value=init_msg)
        # Navigation buttons: Save & Next sits in the middle.
        with gr.Row():
            prev_button = gr.Button("← Previous")
            save_next_button = gr.Button("Save & Next", variant="primary")
            next_button = gr.Button("Next")
        # Audio editing buttons and inputs
        with gr.Row():
            trim_start = gr.Textbox(label="Trim Start (seconds)", placeholder="e.g., 1.5")
            trim_end = gr.Textbox(label="Trim End (seconds)", placeholder="e.g., 3.0")
            trim_button = gr.Button("Trim Audio", variant="primary")
            undo_trim_button = gr.Button("Undo Trim")
        # Delete audio buttons
        with gr.Row():
            delete_button = gr.Button("Delete Audio", variant="stop")
        with gr.Row():
            # These confirmation buttons become visible when delete is requested
            confirm_delete_button = gr.Button("Confirm Delete", visible=False)
            cancel_delete_button = gr.Button("Cancel Delete", visible=False)
        # Jump to specific index and Export to Hugging Face
        with gr.Row():
            jump_text = gr.Textbox(label="Jump to Global Index", placeholder="Enter index number")
            jump_button = gr.Button("Jump")
        with gr.Row():
            hf_repo_name = gr.Textbox(label="Repository Name (username/dataset-name)",
                                      placeholder="e.g., your-username/asr-dataset")
        with gr.Row():
            hf_export_button = gr.Button("Export to Hugging Face", variant="primary")
            hf_export_status = gr.Markdown("")

    # Event handlers. Every sample-level handler reads the same state and refreshes
    # the same seven outputs.
    sample_inputs = [current_page_idx, current_idx, transcript, reviewer, original_transcript]
    sample_outputs = [current_page_idx, current_idx, audio_player, transcript, reviewer, status, original_transcript]

    save_next_button.click(fn=save_and_next_sample, inputs=sample_inputs, outputs=sample_outputs)
    next_button.click(fn=go_next_without_save, inputs=sample_inputs, outputs=sample_outputs)
    prev_button.click(fn=prev_sample, inputs=sample_inputs, outputs=sample_outputs)
    jump_button.click(fn=jump_to, inputs=[jump_text] + sample_inputs, outputs=sample_outputs)
    trim_button.click(
        fn=trim_audio_action,
        inputs=[current_page_idx, current_idx, trim_start, trim_end, transcript, reviewer, original_transcript],
        outputs=sample_outputs,
    )
    undo_trim_button.click(fn=undo_trim_action, inputs=sample_inputs, outputs=sample_outputs)

    def _show_delete_confirmation():
        return gr.update(visible=True), gr.update(visible=True)

    def _hide_delete_confirmation():
        return gr.update(visible=False), gr.update(visible=False)

    delete_confirmation_buttons = [confirm_delete_button, cancel_delete_button]
    delete_button.click(fn=_show_delete_confirmation, inputs=None, outputs=delete_confirmation_buttons)
    confirm_delete_button.click(fn=confirm_delete_audio, inputs=sample_inputs, outputs=sample_outputs)
    # After deletion (or cancel), hide both confirmation buttons again.
    confirm_delete_button.click(fn=_hide_delete_confirmation, inputs=None, outputs=delete_confirmation_buttons)
    cancel_delete_button.click(fn=_hide_delete_confirmation, inputs=None, outputs=delete_confirmation_buttons)
    # Export goes through the queue (no queue=False): gr.Progress updates rely on it,
    # and a full export can run far longer than a plain request.
    hf_export_button.click(fn=export_to_huggingface, inputs=[hf_repo_name, hf_token_state], outputs=[hf_export_status])
    login_button.click(
        fn=hf_login,
        inputs=[hf_token_input],
        outputs=[login_container, main_container, reviewer, hf_token_state, login_message],
    )
demo.launch()