# --- Hugging Face Spaces page header (scrape artifact, not executable code) ---
# audio-labelling / app.py
# navidved's picture — "Update app.py" — commit 6eab7ac (verified)
import gradio as gr
import os
import json
import pandas as pd
from datasets import load_dataset, DatasetDict, Dataset, Audio
from huggingface_hub import HfApi, whoami, login
import tempfile
import shutil
import gc
import time
import psutil
from pydub import AudioSegment # For audio trimming
import soundfile as sf
from tenacity import retry, stop_after_attempt, wait_exponential
import re
import numpy as np
from huggingface_hub import hf_hub_download
# Log in with Hugging Face token if available.
# Fix: calling login(None) when the env var is unset either prompts
# interactively (hanging a headless Space) or raises; only log in when a
# token is actually present.
token = os.getenv("hf_token")
if token:
    login(token)
else:
    print("Warning: 'hf_token' env var not set; Hub uploads will fail.")

# Configuration
HF_DATASET_NAME = "navidved/channelb-raw-data"  # source dataset repo on the Hub
AUDIO_DIR = "audio"                 # fallback directory for relative audio paths
SAVE_PATH = "labeled_data.json"     # label store path, both locally and in the repo
ALLOWED_USERS = ["vargha", "navidved"]  # usernames allowed past the login screen
CURRENT_USERNAME = None             # set by hf_login() after authentication
PAGE_SIZE = 100  # Number of samples to load at once

# Global state variables
current_page = 0          # index of the page currently held in memory
ds_iter = None            # streaming iterator over the source dataset
current_page_data = None  # pandas DataFrame holding the current page's samples
audio_backup = {}  # To store original audio paths for each sample
# Load saved labels from JSON file (and convert old formats if needed)
def load_saved_labels():
    """Load the label store, preferring the local cache over the Hub copy.

    Tries SAVE_PATH on disk first; if it is missing, unreadable, or empty,
    downloads the same file from the HF dataset repository and caches it
    locally.  Legacy entries stored as bare transcript strings are upgraded
    to the current ``{"transcript", "reviewer"}`` dict format.
    """
    labels = {}
    # 1) Local cache.
    if os.path.exists(SAVE_PATH):
        try:
            with open(SAVE_PATH, "r", encoding="utf-8") as fh:
                labels = json.load(fh)
        except Exception as exc:
            print("Error loading local JSON file: " + str(exc))
            labels = {}
    # 2) Fall back to the copy stored in the dataset repository.
    if not labels:
        try:
            hf_path = hf_hub_download(
                repo_id=HF_DATASET_NAME,
                filename=SAVE_PATH,
                repo_type="dataset",
                token=token,
            )
            with open(hf_path, "r", encoding="utf-8") as fh:
                labels = json.load(fh)
            # Cache the downloaded copy locally for subsequent runs.
            with open(SAVE_PATH, "w", encoding="utf-8") as fh:
                json.dump(labels, fh, ensure_ascii=False, indent=4)
            print("Loaded JSON file from HF dataset repository and cached it locally.")
        except Exception as exc:
            print("Error loading JSON file from HF repo: " + str(exc))
            labels = {}
    # 3) Upgrade legacy string values to the current dict format
    #    (rebinding existing keys is safe during iteration).
    for entry_key, entry_val in labels.items():
        if isinstance(entry_val, str):
            labels[entry_key] = {"transcript": entry_val, "reviewer": "unreviewed"}
    return labels
# In-memory label store, keyed by "<absolute_idx>_<audio basename>".
saved_labels = load_saved_labels()
# New function: Push the JSON file to the Hugging Face dataset repository.
def push_json_to_hf():
    """Mirror the local label JSON to the HF dataset repo (best effort).

    Failures are logged and swallowed so a flaky upload never interrupts
    the labeling workflow.
    """
    try:
        HfApi().upload_file(
            path_or_fileobj=SAVE_PATH,  # local file to upload
            path_in_repo=SAVE_PATH,     # same relative path inside the repo
            repo_type="dataset",
            repo_id=HF_DATASET_NAME,
            token=token,
        )
    except Exception as exc:
        print("Error uploading json file: " + str(exc))
    else:
        print("Uploaded labeled_data.json to Hugging Face repository")
# Initialize dataset iterator for streaming
def init_dataset_iterator():
    """(Re)create the global streaming iterator; return True on success."""
    global ds_iter
    try:
        stream = load_dataset(HF_DATASET_NAME, split="train", streaming=True)
    except Exception as e:
        print(f"Error initializing dataset iterator: {e}")
        return False
    ds_iter = iter(stream)
    return True
# Load a page of data
def load_page_data(page_num=0):
    """Load page ``page_num`` (PAGE_SIZE samples) from the streaming dataset.

    Maintains the module-level streaming iterator.  A streaming iterator can
    only move forward, so any request for samples at or before its current
    position forces a fresh iterator.

    Bug fix: after loading page p the iterator sits at absolute position
    (p+1)*PAGE_SIZE, but the original skipped (page_num - current_page) *
    PAGE_SIZE samples — one full page too many — so stepping to the next
    page silently returned the page after next, and re-requesting the
    current page returned the following page.  We now track the consumed
    position explicitly.

    Returns a pandas DataFrame (shorter than PAGE_SIZE at dataset end, or
    an empty frame with the expected columns if the dataset cannot load).
    """
    global ds_iter, current_page_data, current_page
    if ds_iter is None:
        if not init_dataset_iterator():
            return pd.DataFrame(columns=["audio", "sentence"])
    # Absolute sample index the iterator currently points at.  Before the
    # first page is loaded, nothing has been consumed.
    consumed = 0 if current_page_data is None else (current_page + 1) * PAGE_SIZE
    target = page_num * PAGE_SIZE
    if target < consumed:
        # Requested page starts behind the iterator: rewind with a fresh one.
        ds_iter = iter(load_dataset(HF_DATASET_NAME, split="train", streaming=True))
        consumed = 0
    for _ in range(target - consumed):
        try:
            next(ds_iter)
        except StopIteration:
            break
    samples = []
    for _ in range(PAGE_SIZE):
        try:
            samples.append(next(ds_iter))
        except StopIteration:
            break
    current_page = page_num
    current_page_data = pd.DataFrame(samples)
    gc.collect()  # streaming samples can hold large decoded audio arrays
    return current_page_data
# Get dataset info (number of samples)
def get_dataset_info():
    """Return ``{'num_samples': N}`` for the train split, or -1 if unknown."""
    unknown = {'num_samples': -1}
    try:
        info = load_dataset(HF_DATASET_NAME, split="train", streaming=True).info
        splits = getattr(info, 'splits', None)
        # Streaming datasets do not always report split sizes.
        if splits and 'train' in splits:
            return {'num_samples': splits['train'].num_examples}
        return unknown
    except Exception as e:
        print(f"Error getting dataset info: {e}")
        return unknown
# Bootstrap: create the streaming iterator and pre-load the first page.
init_dataset_iterator()
current_page_data = load_page_data(0)
dataset_info = get_dataset_info()
# Total sample count for status display; -1 when the Hub does not report it.
total_samples = dataset_info.get('num_samples', -1)
# Resolve the audio file path (or data) from an entry
def get_audio_path(audio_entry):
    """Resolve an audio entry into something ``gr.Audio`` can play.

    Dicts carrying raw data become a ``(sampling_rate, array)`` tuple;
    dicts with only a path yield that path (or None).  Strings are returned
    as URLs, as existing filesystem paths, or joined with AUDIO_DIR.
    Anything that cannot be resolved is returned unchanged.
    """
    if isinstance(audio_entry, dict):
        if "array" in audio_entry and "sampling_rate" in audio_entry:
            # Raw decoded audio: hand through as (rate, samples).
            return (audio_entry["sampling_rate"], audio_entry["array"])
        return audio_entry.get("path", None)
    if isinstance(audio_entry, str):
        if audio_entry.startswith(("http://", "https://")) or os.path.exists(audio_entry):
            return audio_entry
        candidate = os.path.join(AUDIO_DIR, audio_entry)
        if os.path.exists(candidate):
            return candidate
    return audio_entry
# Save sample data (transcript, reviewer, and any edited audio info)
def save_sample_data(page_idx, idx, transcript, reviewer):
    """Record transcript/reviewer for sample ``idx`` on page ``page_idx``.

    Updates the in-memory store and rewrites the local JSON file; every
    10th label also mirrors the JSON to the Hub.  Returns a UI status
    message.
    """
    global current_page_data
    if idx >= len(current_page_data):
        return "Invalid index"
    absolute_idx = page_idx * PAGE_SIZE + idx
    audio_entry = current_page_data.iloc[idx]["audio"]
    key = f"{absolute_idx}_{os.path.basename(str(audio_entry))}"
    entry = saved_labels.setdefault(key, {})
    entry["transcript"] = transcript.strip()
    entry["reviewer"] = reviewer
    try:
        with open(SAVE_PATH, "w", encoding="utf-8") as fh:
            json.dump(saved_labels, fh, ensure_ascii=False, indent=4)
        # Every time 10 audios are saved, push the json file to the Hub
        # (push_json_to_hf swallows its own errors).
        if len(saved_labels) % 10 == 0:
            push_json_to_hf()
    except Exception as exc:
        return f"Error saving: {str(exc)}"
    return f"✓ Saved data for {key}"
# Retrieve sample data (audio, transcript, status, reviewer) for a given index.
def get_sample(page_idx, idx):
    """Return ``(audio, transcript, status, reviewer)`` for one sample.

    Saved labels override the dataset's sentence.  A ``deleted`` flag
    yields no audio; an ``audio_edit`` path takes precedence over the
    original audio.
    """
    global current_page_data
    if not (0 <= idx < len(current_page_data)):
        return None, "", f"Invalid index. Range is 0-{len(current_page_data)-1}", "unreviewed"
    absolute_idx = page_idx * PAGE_SIZE + idx
    audio_entry = current_page_data.iloc[idx]["audio"]
    key = f"{absolute_idx}_{os.path.basename(str(audio_entry))}"
    default_sentence = current_page_data.iloc[idx]["sentence"]
    sample = saved_labels.get(key, {"transcript": default_sentence, "reviewer": "unreviewed"})
    transcript = sample.get("transcript", default_sentence)
    reviewer = sample.get("reviewer", "unreviewed")
    # Deleted samples carry no audio; otherwise prefer a trimmed edit.
    if sample.get("deleted", False):
        audio_val = None
    elif "audio_edit" in sample:
        audio_val = sample["audio_edit"]
    else:
        audio_val = get_audio_path(audio_entry)
    status = f"Sample {absolute_idx+1}"
    if total_samples > 0:
        status += f" of {total_samples}"
    return audio_val, transcript, status, reviewer
# Load the interface components for a sample index; also backup original audio.
def load_interface(page_idx, idx):
    """Assemble all widget values for one sample.

    Also remembers the first audio seen for the sample so a trim can be
    undone later.  Returns
    ``(page_idx, idx, audio, transcript, reviewer, status, original_text)``.
    """
    audio, text, base_status, saved_reviewer = get_sample(page_idx, idx)
    absolute_idx = page_idx * PAGE_SIZE + idx
    audio_entry = current_page_data.iloc[idx]["audio"]
    key = f"{absolute_idx}_{os.path.basename(str(audio_entry))}"
    # Only the first sighting is kept — that is the pre-trim original.
    audio_backup.setdefault(key, audio)
    status = f"{base_status} - Page {page_idx+1} - Reviewer: {saved_reviewer}"
    return page_idx, idx, audio, text, saved_reviewer, status, text
# Navigation functions
def next_sample(page_idx, idx, current_text, current_annotator, original_transcript):
    """Advance to the following sample, crossing a page boundary if needed."""
    global current_page_data
    if idx >= len(current_page_data) - 1:
        # Last sample of the page: fetch the next page and start at its top.
        page_idx, idx = page_idx + 1, 0
        load_page_data(page_idx)
    else:
        idx += 1
    return load_interface(page_idx, idx)
def go_next_without_save(page_idx, idx, current_text, current_annotator, original_transcript):
    """Advance without saving: feed the original text through so edits are dropped."""
    return next_sample(
        page_idx, idx, original_transcript, current_annotator, original_transcript
    )
def prev_sample(page_idx, idx, current_text, current_annotator, original_transcript):
    """Step back one sample, auto-saving the transcript if it was edited."""
    global current_page_data
    if current_text.strip() != original_transcript.strip():
        save_sample_data(page_idx, idx, current_text, CURRENT_USERNAME)
    if idx > 0:
        return load_interface(page_idx, idx - 1)
    if page_idx > 0:
        # Cross back onto the previous page, landing on its last sample.
        prev_data = load_page_data(page_idx - 1)
        return load_interface(page_idx - 1, len(prev_data) - 1)
    # Already at the very first sample: stay put.
    return load_interface(page_idx, idx)
def jump_to(target_idx, page_idx, idx, current_text, current_annotator, original_transcript):
    """Jump to an absolute sample index, auto-saving pending edits first.

    Bug fix: the original wrapped the whole body in a bare ``except:``,
    which swallowed every error (including KeyboardInterrupt/SystemExit)
    and silently masked page-loading failures.  Only the user-input
    conversion is guarded now; a non-numeric entry keeps the current
    sample.
    """
    if current_text.strip() != original_transcript.strip():
        save_sample_data(page_idx, idx, current_text, CURRENT_USERNAME)
    try:
        target_idx = int(target_idx)
    except (TypeError, ValueError):
        # Not a number: stay on the current sample.
        return load_interface(page_idx, idx)
    if target_idx < 0:
        target_idx = 0
    new_page_idx = target_idx // PAGE_SIZE
    new_idx = target_idx % PAGE_SIZE
    new_data = load_page_data(new_page_idx)
    # Clamp to the last sample of a short (final) page.
    if new_idx >= len(new_data):
        new_idx = len(new_data) - 1 if len(new_data) > 0 else 0
    return load_interface(new_page_idx, new_idx)
def save_and_next_sample(page_idx, idx, current_text, current_annotator, original_transcript):
    """Persist the current transcript, then advance to the next sample."""
    save_sample_data(page_idx, idx, current_text, CURRENT_USERNAME)
    # Pass current_text as the new "original" so no pending edit is detected.
    return next_sample(page_idx, idx, current_text, current_annotator, current_text)
# ----------------- Audio Editing Functions -----------------
def trim_audio_action(page_idx, idx, trim_start, trim_end, current_text, current_annotator, original_transcript):
    """Trim the sample's audio to ``[trim_start, trim_end]`` seconds.

    Raw ``(rate, array)`` audio is first materialized as a temporary WAV
    file.  The trimmed file is written under ``trimmed_audio/`` and recorded
    in the label store as an ``"audio_edit"`` so later views play it.

    Fixes: dropped the redundant function-local ``import numpy as np``
    (numpy is already imported at module level), and the raw array is
    normalized with ``np.asarray`` before the dtype check so list-shaped
    arrays no longer fail with an AttributeError.
    Returns the same 7-tuple shape as load_interface().
    """
    audio, transcript, base_status, saved_reviewer = get_sample(page_idx, idx)
    # Raw audio data (tuple) must become a real file for pydub.
    if isinstance(audio, tuple):
        sample_rate, audio_array = audio
        try:
            from scipy.io.wavfile import write
            audio_array = np.asarray(audio_array)
            if np.issubdtype(audio_array.dtype, np.floating):
                # Floats are assumed in [-1, 1]; scale to int16 PCM.
                audio_array = (audio_array * 32767).astype(np.int16)
            temp_audio_file = os.path.join(tempfile.gettempdir(), f"temp_{page_idx}_{idx}.wav")
            write(temp_audio_file, sample_rate, audio_array)
            audio = temp_audio_file
        except Exception as e:
            return page_idx, idx, audio, transcript, saved_reviewer, f"Error converting raw audio: {str(e)}", transcript
    # Trimming needs an existing file path.
    if not isinstance(audio, str) or not os.path.exists(audio):
        return page_idx, idx, audio, transcript, saved_reviewer, "Trimming not supported for this audio format.", transcript
    try:
        audio_seg = AudioSegment.from_file(audio)
        start_ms = int(float(trim_start) * 1000)
        end_ms = int(float(trim_end) * 1000)
        trimmed_seg = audio_seg[start_ms:end_ms]
        os.makedirs("trimmed_audio", exist_ok=True)
        trimmed_path = os.path.join("trimmed_audio", f"trimmed_{os.path.basename(str(audio))}")
        trimmed_seg.export(trimmed_path, format="wav")
        absolute_idx = page_idx * PAGE_SIZE + idx
        audio_entry = current_page_data.iloc[idx]["audio"]
        key = f"{absolute_idx}_{os.path.basename(str(audio_entry))}"
        # Record the edit so get_sample() serves the trimmed file from now on.
        saved_labels.setdefault(key, {})["audio_edit"] = trimmed_path
        new_status = f"{base_status} [Trimmed]"
        return page_idx, idx, trimmed_path, transcript, saved_reviewer, new_status, transcript
    except Exception as e:
        return page_idx, idx, audio, transcript, saved_reviewer, f"Error trimming audio: {str(e)}", transcript
def undo_trim_action(page_idx, idx, current_text, current_annotator, original_transcript):
    """Discard any trim edit for this sample and restore the original audio."""
    audio, transcript, base_status, saved_reviewer = get_sample(page_idx, idx)
    absolute_idx = page_idx * PAGE_SIZE + idx
    audio_entry = current_page_data.iloc[idx]["audio"]
    key = f"{absolute_idx}_{os.path.basename(str(audio_entry))}"
    entry = saved_labels.get(key)
    if entry is not None and "audio_edit" in entry:
        del entry["audio_edit"]
    # Use the backed-up original audio; fall back to whatever we just loaded.
    orig_audio = audio_backup.get(key, audio)
    new_status = f"{base_status} [Trim undone]"
    return page_idx, idx, orig_audio, transcript, saved_reviewer, new_status, transcript
def confirm_delete_audio(page_idx, idx, current_text, current_annotator, original_transcript):
    """Mark the sample's audio as deleted and replace its transcript."""
    audio, transcript, base_status, saved_reviewer = get_sample(page_idx, idx)
    absolute_idx = page_idx * PAGE_SIZE + idx
    audio_entry = current_page_data.iloc[idx]["audio"]
    key = f"{absolute_idx}_{os.path.basename(str(audio_entry))}"
    entry = saved_labels.setdefault(key, {})
    entry["deleted"] = True
    entry["transcript"] = "AUDIO DELETED (This audio has been removed.)"
    # A deleted sample cannot keep a trim edit.
    entry.pop("audio_edit", None)
    new_status = f"{base_status} [Audio deleted]"
    return page_idx, idx, None, entry["transcript"], "deleted", new_status, entry["transcript"]
# ----------------- Export to Hugging Face Function -----------------
def sanitize_string(s):
    """Replace every character outside word chars, '-', '.', '/' with '_'."""
    text = s if isinstance(s, str) else str(s)
    return re.sub(r'[^\w\-\./]', '_', text)
def sanitize_sentence(s):
    """Coerce to str and drop anything that cannot round-trip through UTF-8."""
    text = s if isinstance(s, str) else str(s)
    return text.encode('utf-8', errors='ignore').decode('utf-8')
@retry(stop=stop_after_attempt(3), wait=wait_exponential(multiplier=1, min=4, max=10))
def push_to_hub_with_retry(dataset_dict, repo_id, private=True, token=None):
    """Push ``dataset_dict`` to the Hub, retrying up to 3 times with backoff."""
    print(f"Pushing dataset to {repo_id}")
    # tenacity re-invokes the whole function on failure, so the push must
    # be safe to repeat (push_to_hub is idempotent for identical content).
    dataset_dict.push_to_hub(repo_id, token=token, private=private)
def export_to_huggingface(repo_name, token, progress=gr.Progress()):
    """
    Export the labeled dataset to Hugging Face with 'audio' column cast to Audio
    and 'sentence' column, pushed as a DatasetDict to the Hub.
    Args:
        repo_name (str): The Hugging Face repository name (e.g., "username/dataset-name").
        token (str): Hugging Face authentication token.
        progress (gr.Progress): Gradio progress tracker.
    Returns:
        str: Status message indicating success or failure.
    Notes:
        - Streams the source dataset and processes it in chunks of 100 to
          bound memory; each chunk is written to a parquet file and all
          parquet files are combined before the final push.
        - The `token` parameter deliberately shadows the module-level token:
          the export uses the credential supplied through the UI.
    """
    try:
        start_time = time.time()
        repo_name = sanitize_string(repo_name)
        print(f"Export started at {time.strftime('%Y-%m-%d %H:%M:%S')}")
        print(f"Repo name: {repo_name}, Token: {'[hidden]' if token else 'None'}")
        print(f"Total labels loaded: {len(saved_labels)}")
        print(f"Sample label keys: {list(saved_labels.keys())[:10]}")
        # Get total number of samples; streaming datasets may not report a
        # size, in which case a fixed cap is used.
        total_samples = get_dataset_info().get('num_samples', -1)
        if total_samples <= 0:
            total_samples = 1500  # Limit for testing
        chunk_size = 100  # Smaller chunks
        num_chunks = (total_samples + chunk_size - 1) // chunk_size
        print(f"Total samples: {total_samples}, chunks: {num_chunks}")
        progress(0, f"Total samples: {total_samples}, chunks: {num_chunks}")
        # Temporary directory for the intermediate parquet files.
        with tempfile.TemporaryDirectory() as temp_dir:
            print(f"Temp dir: {temp_dir}")
            # Load dataset with a fresh streaming iterator local to the export
            # (shadows the module-level ds_iter on purpose).
            ds = load_dataset(HF_DATASET_NAME, split="train", streaming=True)
            ds_iter = iter(ds)
            print("Dataset initialized")
            all_datasets = []
            label_found_count = 0
            label_missing_count = 0
            processed_samples = 0
            # Process chunks
            for chunk_idx in range(num_chunks):
                chunk_start_time = time.time()
                print(f"Chunk {chunk_idx + 1}/{num_chunks}")
                chunk_data = []
                for _ in range(chunk_size):
                    sample_start_time = time.time()
                    try:
                        sample = next(ds_iter)
                        absolute_idx = chunk_idx * chunk_size + len(chunk_data)
                        process = psutil.Process()
                        memory_info = process.memory_info()
                        print(f"Sample {absolute_idx} memory: {memory_info.rss / 1024**2:.2f} MB")
                        # Handle audio: normalize to an {"array", "sampling_rate"}
                        # dict (or None when unusable) plus a key string used
                        # for the label lookup below.
                        audio_entry = sample["audio"]
                        audio_dict = None
                        audio_key = None
                        try:
                            if isinstance(audio_entry, dict) and "array" in audio_entry:
                                # Validate audio array (reject NaN/inf samples).
                                if not np.all(np.isfinite(audio_entry["array"])):
                                    print(f"Sample {absolute_idx}: invalid audio array, skipping")
                                    audio_dict = None
                                else:
                                    audio_dict = {
                                        "array": audio_entry["array"],
                                        "sampling_rate": audio_entry["sampling_rate"]
                                    }
                                audio_key = audio_entry.get("path", f"sample_{absolute_idx}.mp3")
                                audio_key = sanitize_string(audio_key)
                                print(f"Sample {absolute_idx}: raw audio, path: {audio_key}")
                            elif isinstance(audio_entry, str):
                                if audio_entry.startswith("http://") or audio_entry.startswith("https://"):
                                    # Remote audio is not downloaded during export.
                                    print(f"Sample {absolute_idx}: skipping URL: {audio_entry}")
                                    audio_dict = None
                                    audio_key = sanitize_string(audio_entry)
                                else:
                                    resolved_path = get_audio_path(audio_entry)
                                    if os.path.exists(resolved_path):
                                        try:
                                            audio_array, sample_rate = sf.read(resolved_path)
                                            if not np.all(np.isfinite(audio_array)):
                                                print(f"Sample {absolute_idx}: invalid audio array, skipping")
                                                audio_dict = None
                                            else:
                                                audio_dict = {
                                                    "array": audio_array,
                                                    "sampling_rate": sample_rate
                                                }
                                            audio_key = sanitize_string(resolved_path)
                                            print(f"Sample {absolute_idx}: loaded audio: {resolved_path}")
                                        except Exception as e:
                                            print(f"Sample {absolute_idx}: audio load error: {str(e)}")
                                            audio_dict = None
                                            audio_key = sanitize_string(resolved_path)
                                    else:
                                        print(f"Sample {absolute_idx}: file not found: {resolved_path}")
                                        audio_dict = None
                                        audio_key = sanitize_string(resolved_path)
                            else:
                                print(f"Sample {absolute_idx}: unhandled format: {type(audio_entry)}")
                                audio_dict = None
                                audio_key = sanitize_string(str(audio_entry))
                        except Exception as e:
                            print(f"Sample {absolute_idx}: audio error: {str(e)}")
                            audio_dict = None
                            audio_key = f"sample_{absolute_idx}"
                        # Label lookup: same key scheme save_sample_data() uses.
                        key = f"{absolute_idx}_{os.path.basename(audio_key)}"
                        print(f"Sample {absolute_idx}: key: {key}")
                        # Fallback key (older saves keyed by index alone).
                        fallback_key = f"{absolute_idx}"
                        sentence = None
                        if key in saved_labels:
                            sentence = saved_labels[key]["transcript"]
                            label_found_count += 1
                            print(f"Sample {absolute_idx}: label found: {sentence[:50]}")
                        elif fallback_key in saved_labels:
                            sentence = saved_labels[fallback_key]["transcript"]
                            label_found_count += 1
                            print(f"Sample {absolute_idx}: label found (fallback): {sentence[:50]}")
                        else:
                            # No human label: fall back to the dataset's sentence.
                            sentence = sanitize_sentence(sample.get("sentence", ""))
                            label_missing_count += 1
                            print(f"Sample {absolute_idx}: no label, default: {sentence[:50]}")
                        # Collect data
                        chunk_data.append({
                            "audio": audio_dict,
                            "sentence": sentence
                        })
                        # Timing check
                        sample_time = time.time() - sample_start_time
                        if sample_time > 10:
                            print(f"Sample {absolute_idx}: took {sample_time:.2f}s, slow!")
                        # Cleanup: release the decoded audio before the next sample.
                        del sample, audio_entry, audio_dict
                        gc.collect()
                    except StopIteration:
                        print(f"Chunk {chunk_idx + 1}: dataset end")
                        break
                if chunk_data:
                    try:
                        print(f"Chunk {chunk_idx + 1}: saving {len(chunk_data)} samples")
                        chunk_dataset = Dataset.from_list(chunk_data)
                        chunk_dataset = chunk_dataset.cast_column("audio", Audio())
                        chunk_path = os.path.join(temp_dir, f"chunk_{chunk_idx}.parquet")
                        chunk_dataset.to_parquet(chunk_path)
                        all_datasets.append(chunk_path)
                        processed_samples += len(chunk_data)
                        progress(processed_samples / total_samples, f"Processed {processed_samples}/{total_samples}")
                        print(f"Chunk {chunk_idx + 1}: saved to {chunk_path}, took {time.time() - chunk_start_time:.2f}s")
                    except Exception as e:
                        print(f"Chunk {chunk_idx + 1}: save error: {str(e)}")
                        raise
                del chunk_data
                gc.collect()
            print(f"Labels: {label_found_count} found, {label_missing_count} missing")
            # Combine all parquet chunks and upload as a single train split.
            if all_datasets:
                print("Combining chunks")
                try:
                    combined_dataset = Dataset.from_parquet([p for p in all_datasets])
                    print("Creating DatasetDict")
                    dataset_dict = DatasetDict({"train": combined_dataset})
                except Exception as e:
                    print(f"Combine error: {str(e)}")
                    raise
                progress(0.95, "Uploading to Hugging Face...")
                print(f"Pushing to {repo_name}")
                push_to_hub_with_retry(
                    dataset_dict=dataset_dict,
                    repo_id=repo_name,
                    private=True,
                    token=token
                )
                print(f"Upload done, total time: {time.time() - start_time:.2f}s")
                progress(1.0, "Upload complete!")
                return f"Exported to huggingface.co/datasets/{repo_name}"
            else:
                print("No data")
                return "No data to export."
    except Exception as e:
        error_msg = f"Error: {str(e)}"
        print(error_msg)
        return f"Export failed: {str(e)}"
def hf_login(hf_token):
    """Validate an HF token and unlock the main UI for allowed users.

    Returns updates for (login_container, main_container, reviewer textbox,
    token state, login message).

    Bug fix: both failure branches called ``gr.update(visible(False))`` —
    i.e. they tried to *call* an undefined name ``visible`` — which raised
    NameError instead of keeping the main container hidden.  They now pass
    the keyword argument ``visible=False`` like the success branch does.
    """
    global CURRENT_USERNAME
    try:
        username = whoami(token=hf_token)['name']
        if username in ALLOWED_USERS:
            CURRENT_USERNAME = username
            return gr.update(visible=False), gr.update(visible=True), username, hf_token, "Login successful!"
        return gr.update(visible=True), gr.update(visible=False), "", hf_token, "User not authorized!"
    except Exception as e:
        return gr.update(visible=True), gr.update(visible=False), "", hf_token, f"Login failed: {str(e)}"
# Set initial values from the first sample
if len(current_page_data) > 0:
    init_page_idx, init_idx, init_audio, init_text, init_reviewer, init_msg, init_original = load_interface(0, 0)
else:
    # Empty dataset: seed the widgets with placeholder values.
    init_page_idx, init_idx, init_audio, init_text, init_reviewer, init_msg, init_original = 0, 0, None, "", "unreviewed", "No data available. Please check your dataset configuration.", ""
# ----------------- Build Gradio Interface -----------------
with gr.Blocks(title="ASR Dataset Labeling with HF Authentication") as demo:
    # Holds the authenticated user's HF token for the export feature.
    hf_token_state = gr.State("")
    # Login interface (shown first; hidden by hf_login on success)
    with gr.Column(visible=True, elem_id="login_container") as login_container:
        gr.Markdown("## HF Authentication\nPlease enter your Hugging Face token to proceed.")
        hf_token_input = gr.Textbox(label="Hugging Face Token", type="password", placeholder="Enter your HF token")
        login_button = gr.Button("Login")
        login_message = gr.Markdown("")
    # Main labeling interface (revealed by hf_login for allowed users)
    with gr.Column(visible=False, elem_id="main_container") as main_container:
        gr.Markdown("# ASR Dataset Labeling Interface")
        gr.Markdown("Listen to audio and edit transcriptions. Changes are saved via the Save & Next button.")
        with gr.Row():
            # Invisible per-session state: current page, index within the
            # page, and the transcript as loaded (to detect unsaved edits).
            current_page_idx = gr.State(value=init_page_idx)
            current_idx = gr.State(value=init_idx)
            original_transcript = gr.State(value=init_original)
        with gr.Column():
            audio_player = gr.Audio(value=init_audio, label="Audio", autoplay=True)
            transcript = gr.TextArea(value=init_text, label="Transcript", lines=5, placeholder="Edit transcript here...")
            reviewer = gr.Textbox(value=init_reviewer, label="Reviewer", placeholder="Reviewer (auto-filled)", interactive=False)
            status = gr.Markdown(value=init_msg)
        # Navigation buttons: Reordered so Save & Next is in the middle.
        with gr.Row():
            prev_button = gr.Button("← Previous")
            save_next_button = gr.Button("Save & Next", variant="primary")
            next_button = gr.Button("Next")
        # Audio editing buttons and inputs
        with gr.Row():
            trim_start = gr.Textbox(label="Trim Start (seconds)", placeholder="e.g., 1.5")
            trim_end = gr.Textbox(label="Trim End (seconds)", placeholder="e.g., 3.0")
            trim_button = gr.Button("Trim Audio", variant="primary")
            undo_trim_button = gr.Button("Undo Trim")
        # Delete audio buttons
        with gr.Row():
            delete_button = gr.Button("Delete Audio", variant="stop")
        with gr.Row():
            # These confirmation buttons become visible when delete is requested
            confirm_delete_button = gr.Button("Confirm Delete", visible=False)
            cancel_delete_button = gr.Button("Cancel Delete", visible=False)
        # Jump to specific index and Export to Hugging Face
        with gr.Row():
            jump_text = gr.Textbox(label="Jump to Global Index", placeholder="Enter index number")
            jump_button = gr.Button("Jump")
        with gr.Row():
            hf_repo_name = gr.Textbox(label="Repository Name (username/dataset-name)",
                                      placeholder="e.g., your-username/asr-dataset")
        with gr.Row():
            hf_export_button = gr.Button("Export to Hugging Face", variant="primary")
            hf_export_status = gr.Markdown("")
    # Event handlers — all navigation callbacks share the same 7-slot
    # output signature produced by load_interface().
    save_next_button.click(
        fn=save_and_next_sample,
        inputs=[current_page_idx, current_idx, transcript, reviewer, original_transcript],
        outputs=[current_page_idx, current_idx, audio_player, transcript, reviewer, status, original_transcript]
    )
    next_button.click(
        fn=go_next_without_save,
        inputs=[current_page_idx, current_idx, transcript, reviewer, original_transcript],
        outputs=[current_page_idx, current_idx, audio_player, transcript, reviewer, status, original_transcript]
    )
    prev_button.click(
        fn=prev_sample,
        inputs=[current_page_idx, current_idx, transcript, reviewer, original_transcript],
        outputs=[current_page_idx, current_idx, audio_player, transcript, reviewer, status, original_transcript]
    )
    jump_button.click(
        fn=jump_to,
        inputs=[jump_text, current_page_idx, current_idx, transcript, reviewer, original_transcript],
        outputs=[current_page_idx, current_idx, audio_player, transcript, reviewer, status, original_transcript]
    )
    trim_button.click(
        fn=trim_audio_action,
        inputs=[current_page_idx, current_idx, trim_start, trim_end, transcript, reviewer, original_transcript],
        outputs=[current_page_idx, current_idx, audio_player, transcript, reviewer, status, original_transcript]
    )
    undo_trim_button.click(
        fn=undo_trim_action,
        inputs=[current_page_idx, current_idx, transcript, reviewer, original_transcript],
        outputs=[current_page_idx, current_idx, audio_player, transcript, reviewer, status, original_transcript]
    )
    # Delete is a two-step flow: reveal the confirmation buttons first.
    delete_button.click(
        fn=lambda: (gr.update(visible=True), gr.update(visible=True)),
        inputs=None,
        outputs=[confirm_delete_button, cancel_delete_button]
    )
    confirm_delete_button.click(
        fn=confirm_delete_audio,
        inputs=[current_page_idx, current_idx, transcript, reviewer, original_transcript],
        outputs=[current_page_idx, current_idx, audio_player, transcript, reviewer, status, original_transcript]
    )
    # After deletion (or cancel), hide the confirmation buttons again.
    confirm_delete_button.click(lambda: gr.update(visible=False), inputs=None, outputs=confirm_delete_button)
    confirm_delete_button.click(lambda: gr.update(visible=False), inputs=None, outputs=cancel_delete_button)
    cancel_delete_button.click(lambda: gr.update(visible=False), inputs=None, outputs=confirm_delete_button)
    cancel_delete_button.click(lambda: gr.update(visible=False), inputs=None, outputs=cancel_delete_button)
    # queue=False: run the export immediately rather than through the queue.
    hf_export_button.click(fn=export_to_huggingface, inputs=[hf_repo_name, hf_token_state], outputs=[hf_export_status], queue=False)
    login_button.click(
        fn=hf_login,
        inputs=[hf_token_input],
        outputs=[login_container, main_container, reviewer, hf_token_state, login_message]
    )
demo.launch()