tahirsher committed · verified
Commit 49df20f · 1 Parent(s): c3f9689

Update app.py

Files changed (1)
  1. app.py +58 -67
app.py CHANGED
@@ -4,19 +4,21 @@ import torch
  import torchaudio
  import numpy as np
  import streamlit as st
  from huggingface_hub import login
  from transformers import (
      AutoProcessor,
      AutoModelForSpeechSeq2Seq,
      TrainingArguments,
      Trainer,
-     DataCollatorForSeq2Seq,  # ✅ Fix: Use correct data collator
  )

  # ================================
  # 1️⃣ Authenticate with Hugging Face Hub (Securely)
  # ================================
- HF_TOKEN = os.getenv("hf_token")  # Ensure it's set in Hugging Face Spaces Secrets

  if HF_TOKEN is None:
      raise ValueError("❌ Hugging Face API token not found. Please set it in Secrets.")
@@ -30,18 +32,16 @@ MODEL_NAME = "AqeelShafy7/AudioSangraha-Audio_to_Text"
  processor = AutoProcessor.from_pretrained(MODEL_NAME)
  model = AutoModelForSpeechSeq2Seq.from_pretrained(MODEL_NAME)

- # Move model to GPU if available
  device = "cuda" if torch.cuda.is_available() else "cpu"
  model.to(device)
  print(f"✅ Model loaded on {device}")

  # ================================
- # 3️⃣ Load Dataset (Recursively from Extracted Path)
  # ================================
  DATASET_TAR_PATH = "dev-clean.tar.gz"
  EXTRACT_PATH = "./librispeech_dev_clean"

- # Extract dataset if not already extracted
  if not os.path.exists(EXTRACT_PATH):
      print("🔄 Extracting dataset...")
      with tarfile.open(DATASET_TAR_PATH, "r:gz") as tar:
@@ -50,58 +50,42 @@ if not os.path.exists(EXTRACT_PATH):
  else:
      print("✅ Dataset already extracted.")

- # Base directory where audio files are stored
- AUDIO_FOLDER = os.path.join(EXTRACT_PATH, "LibriSpeech", "dev-clean")

- # Recursively find all `.flac` files inside the dataset directory
- def find_audio_files(base_folder):
-     """Recursively search for all .flac files in subdirectories."""
-     audio_files = []
-     for root, _, files in os.walk(base_folder):
-         for file in files:
-             if file.endswith(".flac"):
-                 audio_files.append(os.path.join(root, file))
-     return audio_files

- # Get all audio files
- audio_files = find_audio_files(AUDIO_FOLDER)
-
- if not audio_files:
-     raise FileNotFoundError(f"❌ No .flac files found in {AUDIO_FOLDER}. Check dataset structure!")
-
- print(f"✅ Found {len(audio_files)} audio files in dataset!")
-
- # ================================
- # 4️⃣ Preprocess Dataset (Fixed input_features)
- # ================================
- def load_and_process_audio(audio_path):
-     """Loads and processes a single audio file into model format."""
-     waveform, sample_rate = torchaudio.load(audio_path)
-
-     # Resample to 16kHz
      waveform = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)(waveform)
-
-     # Convert to model input format
-     input_features = processor(waveform.squeeze().numpy(), sampling_rate=16000, return_tensors="pt").input_features[0]
-
-     return input_features
-
- # Manually create dataset structure
- dataset = [{"input_features": load_and_process_audio(f), "labels": []} for f in audio_files[:100]]
-
- # Split dataset into train and eval

  train_size = int(0.8 * len(dataset))
- train_dataset = dataset[:train_size]
- eval_dataset = dataset[train_size:]

- print(f"✅ Dataset Loaded! Training: {len(train_dataset)}, Evaluation: {len(eval_dataset)}")

  # ================================
- # 5️⃣ Training Arguments & Trainer
  # ================================
  training_args = TrainingArguments(
      output_dir="./asr_model_finetuned",
-     eval_strategy="epoch",  # Fixed deprecated evaluation_strategy
      save_strategy="epoch",
      learning_rate=5e-5,
      per_device_train_batch_size=8,
@@ -111,15 +95,13 @@ training_args = TrainingArguments(
      logging_dir="./logs",
      logging_steps=500,
      save_total_limit=2,
-     push_to_hub=True,  # Fix: Properly authenticate Hugging Face Hub
-     hub_model_id="tahirsher/ASR_Model_for_Transcription_into_Text",  # Replace with your Hugging Face repo
      hub_token=HF_TOKEN,
  )

- # ✅ FIX: Use correct Data Collator
  data_collator = DataCollatorForSeq2Seq(tokenizer=processor.tokenizer, model=model, return_tensors="pt")

- # Define Trainer
  trainer = Trainer(
      model=model,
      args=training_args,
@@ -129,45 +111,54 @@ trainer = Trainer(
  )

  # ================================
- # 6️⃣ Fine-Tuning Execution
  # ================================
  if st.button("Start Fine-Tuning"):
      with st.spinner("Fine-tuning in progress... Please wait!"):
          trainer.train()
      st.success("✅ Fine-Tuning Completed! Model updated.")

  # ================================
- # 7️⃣ Streamlit ASR Web App
  # ================================
- st.title("🎙️ Speech-to-Text ASR with Fine-Tuning 🎶")

- # Upload audio file
  audio_file = st.file_uploader("Upload an audio file", type=["wav", "mp3", "flac"])

  if audio_file:
-     # Save uploaded file temporarily
      audio_path = "temp_audio.wav"
      with open(audio_path, "wb") as f:
          f.write(audio_file.read())

-     # Load and process audio
      waveform, sample_rate = torchaudio.load(audio_path)
      waveform = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)(waveform)

-     # Convert audio to model input
      input_features = processor(waveform.squeeze().numpy(), sampling_rate=16000, return_tensors="pt").input_features

-     # ✅ FIX: Ensure input tensor is correctly formatted
-     input_tensor = input_features.to(device)  # Move to GPU/CPU
-
-     # ✅ FIX: Provide decoder_input_ids
-     decoder_input_ids = torch.tensor([[model.config.decoder_start_token_id]]).to(device)

-     # Perform ASR inference
      with torch.no_grad():
-         logits = model(input_tensor, decoder_input_ids=decoder_input_ids).logits
-         predicted_ids = torch.argmax(logits, dim=-1)
-         transcription = processor.batch_decode(predicted_ids)[0]

      # Display transcription
      st.success("📄 Transcription:")
 
  import torchaudio
  import numpy as np
  import streamlit as st
+ import matplotlib.pyplot as plt
  from huggingface_hub import login
+ from datasets import load_dataset, DatasetDict
  from transformers import (
      AutoProcessor,
      AutoModelForSpeechSeq2Seq,
      TrainingArguments,
      Trainer,
+     DataCollatorForSeq2Seq,
  )

  # ================================
  # 1️⃣ Authenticate with Hugging Face Hub (Securely)
  # ================================
+ HF_TOKEN = os.getenv("hf_token")

  if HF_TOKEN is None:
      raise ValueError("❌ Hugging Face API token not found. Please set it in Secrets.")
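Note: the hunk above imports `login` from `huggingface_hub`, but the call itself falls outside the changed lines. A minimal sketch of how the `hf_token` Space secret is typically consumed (the explicit `login` call is an assumption, not part of the diff):

import os
from huggingface_hub import login

# Sketch only: authenticate this session with the token stored in the Space secrets.
HF_TOKEN = os.getenv("hf_token")
if HF_TOKEN is None:
    raise ValueError("Hugging Face API token not found. Please set it in Secrets.")
login(token=HF_TOKEN)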
 
  processor = AutoProcessor.from_pretrained(MODEL_NAME)
  model = AutoModelForSpeechSeq2Seq.from_pretrained(MODEL_NAME)

  device = "cuda" if torch.cuda.is_available() else "cpu"
  model.to(device)
  print(f"✅ Model loaded on {device}")

  # ================================
+ # 3️⃣ Load and Prepare Dataset
  # ================================
  DATASET_TAR_PATH = "dev-clean.tar.gz"
  EXTRACT_PATH = "./librispeech_dev_clean"

  if not os.path.exists(EXTRACT_PATH):
      print("🔄 Extracting dataset...")
      with tarfile.open(DATASET_TAR_PATH, "r:gz") as tar:
 
  else:
      print("✅ Dataset already extracted.")

+ # Load dataset with transcripts
+ dataset = load_dataset("librispeech_asr", "clean", split="train")

+ # Ensure dataset has transcripts
+ if "text" not in dataset.column_names:
+     raise ValueError("❌ Dataset is missing transcription text!")

+ # Preprocessing Function
+ def preprocess_data(batch):
+     # Process audio
+     waveform, sample_rate = torchaudio.load(batch["file"])
      waveform = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)(waveform)
+
+     batch["input_features"] = processor(waveform.squeeze().numpy(), sampling_rate=16000, return_tensors="pt").input_features[0]
+
+     # Tokenize transcript text
+     batch["labels"] = processor.tokenizer(batch["text"], padding="max_length", truncation=True, return_tensors="pt").input_ids[0]
+
+     return batch
+
+ # Apply preprocessing
+ dataset = dataset.map(preprocess_data, remove_columns=["file", "audio", "text"])
+
+ # Split into train & eval
  train_size = int(0.8 * len(dataset))
+ train_dataset = dataset.select(range(train_size))
+ eval_dataset = dataset.select(range(train_size, len(dataset)))

+ print(f"✅ Dataset Prepared! Training: {len(train_dataset)}, Evaluation: {len(eval_dataset)}")

  # ================================
+ # 4️⃣ Training Arguments & Trainer
  # ================================
  training_args = TrainingArguments(
      output_dir="./asr_model_finetuned",
+     evaluation_strategy="epoch",
      save_strategy="epoch",
      learning_rate=5e-5,
      per_device_train_batch_size=8,
 
      logging_dir="./logs",
      logging_steps=500,
      save_total_limit=2,
+     push_to_hub=True,
+     hub_model_id="tahirsher/ASR_Model_for_Transcription_into_Text",
      hub_token=HF_TOKEN,
  )

  data_collator = DataCollatorForSeq2Seq(tokenizer=processor.tokenizer, model=model, return_tensors="pt")

  trainer = Trainer(
      model=model,
      args=training_args,
 
  )

  # ================================
+ # 5️⃣ Fine-Tuning Execution & Training Stats
  # ================================
  if st.button("Start Fine-Tuning"):
      with st.spinner("Fine-tuning in progress... Please wait!"):
          trainer.train()
      st.success("✅ Fine-Tuning Completed! Model updated.")

+     # Plot Training Loss
+     train_loss = trainer.state.log_history
+     losses = [entry['loss'] for entry in train_loss if 'loss' in entry]
+
+     plt.figure(figsize=(8, 5))
+     plt.plot(range(len(losses)), losses, label="Training Loss", color="blue")
+     plt.xlabel("Steps")
+     plt.ylabel("Loss")
+     plt.title("Training Loss Over Time")
+     plt.legend()
+     st.pyplot(plt)
+
  # ================================
+ # 6️⃣ Streamlit ASR Web App (Proper Decoding)
  # ================================
+ st.title("🎙️ Speech-to-Text ASR Model with Fine-Tuning 🎶")

  audio_file = st.file_uploader("Upload an audio file", type=["wav", "mp3", "flac"])

  if audio_file:
      audio_path = "temp_audio.wav"
      with open(audio_path, "wb") as f:
          f.write(audio_file.read())

      waveform, sample_rate = torchaudio.load(audio_path)
      waveform = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)(waveform)

      input_features = processor(waveform.squeeze().numpy(), sampling_rate=16000, return_tensors="pt").input_features

+     input_tensor = input_features.to(device)

+     # ✅ FIX: Use `generate()` for Proper Transcription
      with torch.no_grad():
+         generated_ids = model.generate(
+             input_tensor,
+             max_length=500,
+             num_beams=5,
+             do_sample=True,
+             top_k=50
+         )
+         transcription = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]

      # Display transcription
      st.success("📄 Transcription:")
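For reference, `do_sample=True` with `top_k=50` makes the transcription above stochastic; plain beam search is a common deterministic alternative. A minimal sketch, assuming the same `model`, `processor`, and `input_tensor` as in the diff:

import torch

# Sketch only: beam search without sampling returns the same transcription on every run.
with torch.no_grad():
    generated_ids = model.generate(input_tensor, num_beams=5, max_length=500)

transcription = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
print(transcription)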