Spaces:
Running
Running
| from typing import List, Dict, Any | |
| import os | |
| import shutil | |
| import uuid | |
| import time | |
| from functools import wraps | |
| from fastapi import FastAPI, Request, UploadFile, File, HTTPException, status | |
| from fastapi.responses import HTMLResponse, JSONResponse | |
| from fastapi.templating import Jinja2Templates | |
| from fastapi.staticfiles import StaticFiles | |
| from starlette.middleware.sessions import SessionMiddleware | |
| from starlette.responses import FileResponse | |
| from starlette.background import BackgroundTask | |
| from pydantic import BaseModel | |
| from datasets import Dataset, Audio | |
# --- Pydantic Models for Data Validation/Serialization ---
class SaveAnnotationRequest(BaseModel):
    """Model for the POST request payload to save transcription.

    Identifies one row of the caller's annotation list and carries the new
    transcription and speaker values that save_annotation writes into it.
    """
    index: int  # zero-based row position in the caller's annotation list
    transcription: str  # replaces the row's "transcription" value
    speaker: str  # replaces the row's "speaker" value
class AudioDataResponse(BaseModel):
    """Model for the GET response when loading an audio row.

    Has the same keys as the dict built by load_data_for_index.
    """
    index: int  # row position, or -1 for the "No files uploaded yet." placeholder
    filename: str  # name of the audio file inside the user's upload directory
    transcription: str
    speaker: str
    max_index: int  # number of rows the user has (not the last valid index); 0 when empty
# --- Configuration and Global State ---
# Root folder for uploaded audio; every session gets its own sub-folder in it.
UPLOAD_DIR = "./uploaded_audio"

# In-memory annotation rows per session, keyed by the session's user id.
# Each row is {"filename": str, "speaker": str, "transcription": str}.
ANNOTATION_DATA: Dict[str, List[Dict[str, Any]]] = dict()

# Per-session position of the row currently shown in the UI
# (-1 means "nothing selected yet").
current_index: Dict[str, int] = dict()

# Create the upload root once, at import time.
os.makedirs(UPLOAD_DIR, exist_ok=True)
# --- FastAPI Setup ---
app = FastAPI(title="Audio Annotation Tool with File Upload")
app.mount("/static", StaticFiles(directory="./static", html=True), name="static")

# The signed session cookie carries the user id that selects the on-disk upload
# directory, so the signing key must not be guessable. A key shipped with the
# source lets anyone forge a session for an arbitrary id.
# Set SESSION_SECRET_KEY in the environment for a stable key. Otherwise a random
# key is generated per process: sessions are then invalidated on restart, which
# loses nothing extra because ANNOTATION_DATA is in-memory as well.
# NOTE(review): with several worker processes each would pick a different
# random key; set the env var in that case.
SESSION_SECRET_KEY = os.environ.get("SESSION_SECRET_KEY") or os.urandom(32).hex()
app.add_middleware(SessionMiddleware, secret_key=SESSION_SECRET_KEY)

templates = Jinja2Templates(directory="./templates")
# --- Utility Functions ---
def load_data_for_index(user_id: str, index: int) -> Dict[str, Any]:
    """Return the annotation row at *index* for *user_id* as a response dict.

    The dict has the keys ``index``, ``filename``, ``transcription``,
    ``speaker`` and ``max_index``. ``max_index`` is the number of rows the
    user has.

    If the user has no rows, or no entry at all (e.g. the in-memory state was
    lost while the session cookie survived), a "No files uploaded yet."
    placeholder with ``index == -1`` is returned, whatever *index* is.

    Raises:
        IndexError: the user has rows but *index* is outside ``[0, len(rows))``.
    """
    rows = ANNOTATION_DATA.get(user_id)
    if not rows:
        return {"index": -1, "filename": "", "transcription": "No files uploaded yet.", "speaker": "No files uploaded yet.", "max_index": 0}
    if not 0 <= index < len(rows):
        # No wrap-around here: load_audio_data wraps next/prev indices itself.
        raise IndexError("Index out of bounds.")
    item = rows[index]
    return {
        "index": index,
        "filename": item['filename'],
        "transcription": item['transcription'],
        "speaker": item['speaker'],
        "max_index": len(rows),
    }
def get_user_directory(user_id: Any) -> str:
    """Return the upload directory for *user_id*: ``UPLOAD_DIR/<user_id>``.

    The id is converted with ``str`` and must form a single, ordinary path
    component. The id comes from the session cookie, and the resulting
    directory is written to and later removed with ``rmtree``. An id that
    could escape ``UPLOAD_DIR`` is therefore refused.

    Raises:
        ValueError: the id is empty, ``.``/``..``, or contains a path
            separator or NUL byte.
    """
    component = f'{user_id}'
    forbidden = [sep for sep in (os.sep, os.altsep, "\x00") if sep]
    if component in ("", ".", "..") or any(sep in component for sep in forbidden):
        raise ValueError(f"Unsafe user id for a directory name: {component!r}")
    return os.path.join(UPLOAD_DIR, component)
def serve_index_html(request: Request):
    """Render ``index.html`` and (re)initialise the caller's session state.

    On the first visit a random UUID4 string is stored as ``_id`` in the
    session. On every call:
    - the user's upload directory is created if missing;
    - the user's in-memory annotation rows are reset to empty;
    - the current index is reset to -1.

    Returns the rendered template, or an HTML 500 response if the template
    cannot be loaded.
    """
    user_id = request.session.get('_id', None)
    if user_id is None:
        user_id = str(uuid.uuid4())
        request.session['_id'] = user_id
    os.makedirs(get_user_directory(user_id), exist_ok=True)
    # NOTE(review): this discards annotations already held in memory for this
    # session (files on disk are not touched here). Confirm that reloading the
    # page is meant to start over.
    ANNOTATION_DATA[user_id] = []
    current_index[user_id] = -1
    try:
        return templates.TemplateResponse("index.html", context={"request": request})
    except OSError:
        # Jinja2's TemplateNotFound subclasses OSError but not FileNotFoundError,
        # so the broader class is needed to cover a missing template.
        return HTMLResponse(content="<h1>Server Error: index.html not found.</h1>", status_code=500)
def rate_limited(max_calls: int, time_frame: int):
    """Build a decorator that limits how often an async route handler runs.

    One limit is shared by all callers of the decorated handler; it is not per
    session. At most *max_calls* invocations are admitted in any sliding
    window of *time_frame* seconds. Further calls raise HTTP 429 and are not
    counted.

    :param max_calls: Maximum number of calls allowed in the specified time frame.
    :param time_frame: The time frame (in seconds) for which the limit applies.
    :return: Decorator function.
    """
    def decorator(func):
        # Monotonic timestamps of admitted calls that are still inside the window.
        calls: List[float] = []

        # wraps sets __wrapped__, so inspect.signature (used by FastAPI to parse
        # parameters) reports func's signature instead of (*args, **kwargs).
        @wraps(func)
        async def wrapper(request: Request, *args, **kwargs):
            now = time.monotonic()  # immune to wall-clock adjustments
            cutoff = now - time_frame
            # Prune expired entries in place so the list never exceeds max_calls.
            calls[:] = [t for t in calls if t > cutoff]
            if len(calls) >= max_calls:
                raise HTTPException(status_code=status.HTTP_429_TOO_MANY_REQUESTS, detail="Rate limit exceeded.")
            calls.append(now)
            return await func(request, *args, **kwargs)
        return wrapper
    return decorator
# --- Routes ---
async def index(request: Request):
    """Root page: hand off to serve_index_html, which renders index.html."""
    page = serve_index_html(request)
    return page
async def annotate(request: Request):
    """Annotation page: rendered exactly like the root page, via serve_index_html."""
    page = serve_index_html(request)
    return page
async def upload_audio(request: Request, audio_files: List[UploadFile] = File(...)):
    """Handle a multi-file audio upload for the caller's session.

    Each file is written to the caller's upload directory under the final
    component of its client-supplied name. A row with empty ``transcription``
    and ``speaker`` is then appended to the caller's annotation list. If that
    list was empty before this upload, the current index is set to 0.

    Returns:
        JSONResponse with ``message`` and ``total_files`` (rows after upload).

    Raises:
        HTTPException(400): no session id, or a file has no usable name. Names
            are checked before anything is written.
        HTTPException(500): a file could not be written to disk.
    """
    user_id = request.session.get('_id', None)
    if user_id is None:
        raise HTTPException(status_code=400, detail="No active session; load the main page first.")

    # The filename is client-controlled. Keep only its last path component so
    # names like "../../x" or "/abs/path" cannot escape the user's directory.
    # All names are validated up front so one bad name rejects the whole upload.
    safe_names = []
    for file in audio_files:
        name = os.path.basename(file.filename or "")
        if name in ("", ".", ".."):
            raise HTTPException(status_code=400, detail=f"Invalid file name: {file.filename!r}")
        safe_names.append(name)

    # The session cookie can outlive the in-memory state (e.g. after a server
    # restart). setdefault avoids the KeyError / HTTP 500 that direct indexing gave.
    rows = ANNOTATION_DATA.setdefault(user_id, [])
    if not rows:
        # First upload for this session: start at the first row.
        current_index[user_id] = 0
    user_dir = get_user_directory(user_id)
    os.makedirs(user_dir, exist_ok=True)

    new_files_count = 0
    for file, name in zip(audio_files, safe_names):
        file_path = os.path.join(user_dir, name)
        try:
            with open(file_path, "wb") as buffer:
                # Copy in chunks so large uploads are not read into memory at once.
                shutil.copyfileobj(file.file, buffer)
        except OSError as e:
            print(f"Error saving file {file.filename}: {e}")
            raise HTTPException(status_code=500, detail=f"Failed to save file: {file.filename}") from e
        rows.append({
            "filename": name,
            "transcription": "",  # filled in later via save_annotation
            "speaker": "",
        })
        new_files_count += 1

    return JSONResponse({
        "message": f"Successfully uploaded {new_files_count} files.",
        "total_files": len(rows),
    })
async def save_annotation(request: Request, data: SaveAnnotationRequest):
    """Store a transcription and speaker label on one of the caller's rows.

    Returns ``{"success": True, "message": "Row <n> saved."}`` on success,
    where ``<n>`` is the 1-based row number.

    Raises:
        HTTPException(400): ``data.index`` is not a valid row for this user,
            including when the user has no rows at all.
        HTTPException(500): any other unexpected failure.
    """
    try:
        user_id = request.session.get('_id', None)
        rows = ANNOTATION_DATA.get(user_id, [])
        index_to_save = data.index
        if not 0 <= index_to_save < len(rows):
            raise HTTPException(status_code=400, detail="Invalid index for saving.")
        rows[index_to_save]["transcription"] = data.transcription
        rows[index_to_save]["speaker"] = data.speaker
        return JSONResponse({"success": True, "message": f"Row {index_to_save + 1} saved."})
    except HTTPException:
        # Let deliberate HTTP errors (the 400 above) through unchanged. The
        # generic handler below would otherwise turn them into a 500.
        raise
    except Exception as e:
        print(f"Error during save: {e}")
        raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}") from e
async def load_audio_data(request: Request, direction: str):
    """Move the caller's cursor and return the row it lands on.

    ``direction`` controls the cursor:
    - ``'next'`` and ``'prev'`` step one row and wrap around at either end;
    - any other value (e.g. ``'current'``) re-reads the current row.

    With no uploaded rows, the "No files uploaded yet." placeholder is
    returned. If the cursor does not point at a valid row, HTTP 404 is raised
    and the stored cursor is left unchanged.
    """
    user_id = request.session.get('_id', None)
    rows = ANNOTATION_DATA[user_id]
    if not rows:
        return JSONResponse(load_data_for_index(user_id, -1))

    position = current_index[user_id]
    step = {'next': 1, 'prev': -1}.get(direction)
    if step is not None:
        # Python's % is always non-negative for a positive divisor, so this
        # wraps in both directions.
        position = (position + step) % len(rows)

    try:
        payload = load_data_for_index(user_id, position)
    except IndexError:
        raise HTTPException(status_code=404, detail="No more audio files to load.")
    # Persist the cursor only after the row was fetched successfully.
    current_index[user_id] = position
    return JSONResponse(payload)
async def serve_audio_file(request: Request, filename: str):
    """Stream one of the caller's uploaded audio files.

    Only plain file names inside the caller's upload directory are served.
    Anything else (path separators, absolute paths, ``.``/``..``,
    directories) gets the same HTTP 404 as a missing file.
    """
    user_id = request.session.get('_id', None)
    user_dir = get_user_directory(user_id)
    # filename comes from the request URL: refuse anything that is not a single
    # path component so it cannot address files outside user_dir.
    if filename in ("", ".", "..") or os.path.basename(filename) != filename:
        raise HTTPException(status_code=404, detail="Audio file not found.")
    file_path = os.path.join(user_dir, filename)
    # isfile rather than exists, so a directory path never reaches FileResponse.
    if os.path.isfile(file_path):
        # FileResponse sends the file directly, optimized for binary streams
        return FileResponse(file_path, media_type="audio/wav")  # Assume WAV for simplicity, use relevant type if required
    raise HTTPException(status_code=404, detail="Audio file not found.")
async def download_annotations(request: Request):
    """Return the caller's annotations as a zipped Hugging Face dataset.

    Steps:
    1. Build a ``Dataset`` with the columns ``audio`` (file paths cast to
       ``Audio(sampling_rate=16000)``), ``transcription`` and ``speaker``.
    2. Save it with ``save_to_disk`` under ``<user dir>/dataset``.
    3. Zip that folder to ``<user dir>/final.zip``.
    4. Send the zip as ``annotated_data.zip``.

    Once the response has been sent, a background task wipes the user's
    directory, recreates it empty, and resets their in-memory rows and index.

    Raises:
        HTTPException(404): the caller has no annotation rows.
    """
    user_id = request.session.get('_id', None)
    rows = ANNOTATION_DATA.get(user_id)
    if not rows:
        raise HTTPException(status_code=404, detail="No annotations available to download.")
    user_dir = get_user_directory(user_id)

    # Column-oriented input for Dataset.from_dict, one entry per annotation row.
    data = {
        "audio": [os.path.join(user_dir, item['filename']) for item in rows],
        "transcription": [item['transcription'] for item in rows],
        "speaker": [item['speaker'] for item in rows],
    }
    ds = Dataset.from_dict(data).cast_column('audio', Audio(sampling_rate=16000))
    dataset_dir = os.path.join(user_dir, 'dataset')
    ds.save_to_disk(dataset_dir)

    # make_archive appends ".zip" to the base name it is given.
    zip_base = os.path.join(user_dir, 'final')
    shutil.make_archive(zip_base, 'zip', dataset_dir)
    zip_path = f'{zip_base}.zip'

    def cleanup_file():
        """Best-effort removal of the user's files, then reset of their state."""
        try:
            shutil.rmtree(user_dir)
            os.makedirs(user_dir, exist_ok=True)
        except OSError as e:
            print(f"Error deleting directory {user_dir}: {e}")
        ANNOTATION_DATA[user_id] = []
        current_index[user_id] = -1

    # The cleanup task runs only after the file has been sent.
    return FileResponse(
        path=zip_path,
        filename='annotated_data.zip',
        media_type="application/zip",
        background=BackgroundTask(cleanup_file),
    )