# MLX_GPT_OSS_120B / whisper / m3_optimized_whisper_2.py
# (Hugging Face upload metadata: TroglodyteDerivations — "Upload 48 files" — commit c28358e, verified)
#!/usr/bin/env python3
"""
Working Whisper Transcription for Apple M3 Ultra (CPU Version)
Fixes MPS compatibility issues by using CPU
"""
import whisper
import torch
import time
from pathlib import Path
import logging
import os
import sys
import subprocess
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
def check_environment():
    """Log interpreter/library versions and accelerator availability.

    Purely diagnostic: backend probes are informational only, and the
    recommended device string "cpu" is returned unconditionally.
    """
    logger.info("๐Ÿ” Checking environment...")

    # Report the software stack in use.
    logger.info(f"๐Ÿ Python version: {sys.version}")
    logger.info(f"๐Ÿ”ฅ PyTorch version: {torch.__version__}")
    logger.info(f"๐ŸŽค Whisper version: {whisper.__version__}")

    # Probe hardware backends (hasattr guard: older torch builds lack mps).
    has_mps = hasattr(torch.backends, 'mps') and torch.backends.mps.is_available()
    logger.info(f"๐ŸŽ MPS available: {has_mps}")
    logger.info(f"๐ŸŽฎ CUDA available: {torch.cuda.is_available()}")

    # CPU is recommended regardless of what the probes found.
    logger.info("๐Ÿ’ก Using CPU for stable transcription (MPS has compatibility issues)")
    return "cpu"
def transcribe_with_cpu(audio_file, model_size="medium"):
    """Run Whisper transcription on the CPU for maximum compatibility.

    Args:
        audio_file: Path to the audio file to transcribe.
        model_size: Whisper model name (default "medium").

    Returns:
        (result dict, elapsed seconds) on success, or (None, 0) on any
        failure (best-effort wrapper: the error is logged, not raised).
    """
    logger.info(f"๐ŸŽง Transcribing: {audio_file}")
    logger.info(f"๐Ÿค– Using model: {model_size}")
    logger.info("โšก Device: CPU (stable mode)")
    try:
        started = time.time()
        # Models are cached under ./whisper_models so repeat runs skip downloads.
        cpu_model = whisper.load_model(model_size, device="cpu",
                                       download_root="./whisper_models")
        logger.info("๐ŸŽค Starting transcription...")
        transcript = cpu_model.transcribe(
            audio_file,
            verbose=True,
            fp16=False,        # FP16 is unstable on CPU; force FP32
            temperature=0.0,   # deterministic (greedy) decoding
            best_of=1,
            language="en",     # skip language auto-detection
        )
        return transcript, time.time() - started
    except Exception as e:
        logger.error(f"โŒ Transcription failed: {e}")
        return None, 0
def estimate_time(audio_file, model_size):
    """Return a rough "M:SS" transcription-time estimate from file size.

    Args:
        audio_file: Path to the audio file (only its size is read).
        model_size: Whisper model name; unknown names fall back to the
            'medium' rate of 4.5 seconds per MB.
    """
    size_mb = os.path.getsize(audio_file) / (1024 * 1024)
    # Empirical CPU throughput heuristics, seconds per MB, by model size.
    seconds_per_mb = {
        'tiny': 1.5,
        'base': 2.0,
        'small': 3.0,
        'medium': 4.5,
        'large': 6.0,
    }
    total_seconds = size_mb * seconds_per_mb.get(model_size, 4.5)
    minutes, seconds = divmod(int(total_seconds), 60)
    return f"{minutes}:{seconds:02d}"
def main():
    """Entry point: locate the lecture audio, log diagnostics, transcribe
    on CPU, and persist the transcript (plus extra formats) to disk."""
    audio_file = "yuval_harari_lecture.mp3"
    if not Path(audio_file).exists():
        logger.error(f"โŒ Audio file not found: {audio_file}")
        logger.info("๐Ÿ’ก Run: python comprehensive_yt_dl.py to download the lecture")
        return

    # Report input size up front so the user can sanity-check the estimate.
    file_size_mb = os.path.getsize(audio_file) / (1024 * 1024)
    logger.info(f"๐Ÿ“Š File size: {file_size_mb:.1f} MB")

    estimated_time = estimate_time(audio_file, "medium")
    logger.info(f"โฑ๏ธ Estimated time: ~{estimated_time}")

    # Called for its diagnostic logging; transcription below always uses CPU,
    # so the returned device string is deliberately not captured (the original
    # assigned it to an unused local).
    check_environment()

    logger.info("๐Ÿš€ Starting transcription...")
    logger.info("โš ๏ธ This may take a while on CPU - be patient!")
    result, duration = transcribe_with_cpu(audio_file, "medium")
    if not result:
        logger.error("โŒ Transcription failed completely")
        return

    # Persist the plain-text transcript next to the audio file.
    output_file = f"{Path(audio_file).stem}_transcript.txt"
    with open(output_file, 'w', encoding='utf-8') as f:
        f.write(result['text'])

    minutes, seconds = divmod(int(duration), 60)
    logger.info(f"โœ… Transcription completed in {minutes}:{seconds:02d}")
    logger.info(f"๐Ÿ“ Saved to: {output_file}")
    logger.info(f"๐Ÿ“„ Word count: {len(result['text'].split()):,}")

    # Show the first 500 characters as a quick sanity preview.
    preview = result['text'][:500] + "..." if len(result['text']) > 500 else result['text']
    logger.info(f"๐Ÿ“‹ Preview:\n{preview}")

    save_additional_formats(Path(audio_file).stem, result)
def save_additional_formats(base_name, result):
    """Save the transcript in extra formats: JSON with timestamps, and a
    best-effort SRT subtitle file via the whisper CLI.

    Args:
        base_name: Audio file stem (e.g. "lecture" for "lecture.mp3").
        result: Whisper transcription result dict (text, segments, ...).
    """
    # JSON dump preserves the full result, including per-segment timestamps.
    json_path = f"{base_name}_timestamps.json"
    try:
        import json
        with open(json_path, 'w', encoding='utf-8') as f:
            json.dump(result, f, indent=2, ensure_ascii=False)
        logger.info(f"โฐ Timestamps saved to: {json_path}")
    except Exception as e:
        logger.warning(f"โš ๏ธ Could not save JSON: {e}")

    # Best-effort SRT generation via the whisper CLI, if installed.
    # Catch specific failures (missing binary -> OSError, timeout/exec
    # errors -> SubprocessError) instead of a bare `except:` so
    # KeyboardInterrupt/SystemExit still propagate.
    try:
        subprocess.run([
            "whisper", f"{base_name}.mp3",
            "--model", "medium",
            "--output_format", "srt",
            "--device", "cpu",
            "--language", "en"  # Also set language for SRT generation
        ], timeout=300)
        if Path(f"{base_name}.srt").exists():
            logger.info(f"๐ŸŽฌ Subtitles saved to: {base_name}.srt")
    except (OSError, subprocess.SubprocessError) as e:
        logger.warning(f"โš ๏ธ SRT generation skipped: {e}")
# Run the pipeline only when executed as a script (not on import).
if __name__ == "__main__":
    main()