Update app.py
app.py
CHANGED
Old version of app.py (left side of the diff; lines removed in this commit are prefixed with "-", unchanged context is unprefixed):

@@ -4,19 +4,21 @@ import torch
 import torchaudio
 import numpy as np
 import streamlit as st
 from huggingface_hub import login
 from transformers import (
     AutoProcessor,
     AutoModelForSpeechSeq2Seq,
     TrainingArguments,
     Trainer,
-    DataCollatorForSeq2Seq,
 )

 # ================================
 # 1️⃣ Authenticate with Hugging Face Hub (Securely)
 # ================================
-HF_TOKEN = os.getenv("hf_token")

 if HF_TOKEN is None:
     raise ValueError("❌ Hugging Face API token not found. Please set it in Secrets.")
@@ -30,18 +32,16 @@ MODEL_NAME = "AqeelShafy7/AudioSangraha-Audio_to_Text"
 processor = AutoProcessor.from_pretrained(MODEL_NAME)
 model = AutoModelForSpeechSeq2Seq.from_pretrained(MODEL_NAME)

-# Move model to GPU if available
 device = "cuda" if torch.cuda.is_available() else "cpu"
 model.to(device)
 print(f"✅ Model loaded on {device}")

 # ================================
-# 3️⃣ Load
 # ================================
 DATASET_TAR_PATH = "dev-clean.tar.gz"
 EXTRACT_PATH = "./librispeech_dev_clean"

-# Extract dataset if not already extracted
 if not os.path.exists(EXTRACT_PATH):
     print("🔄 Extracting dataset...")
     with tarfile.open(DATASET_TAR_PATH, "r:gz") as tar:
@@ -50,58 +50,42 @@ if not os.path.exists(EXTRACT_PATH):
 else:
     print("✅ Dataset already extracted.")

-    audio_files = []
-    for root, _, files in os.walk(base_folder):
-        for file in files:
-            if file.endswith(".flac"):
-                audio_files.append(os.path.join(root, file))
-    return audio_files

-raise FileNotFoundError(f"❌ No .flac files found in {AUDIO_FOLDER}. Check dataset structure!")

-print(f"✅ Found {len(audio_files)} audio files in dataset!")

-# ================================
-# 4️⃣ Preprocess Dataset (Fixed input_features)
-# ================================
-def load_and_process_audio(audio_path):
-    """Loads and processes a single audio file into model format."""
-    waveform, sample_rate = torchaudio.load(audio_path)

-    # Resample to 16kHz
     waveform = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)(waveform)

 train_size = int(0.8 * len(dataset))
-train_dataset = dataset
-eval_dataset = dataset

-print(f"✅ Dataset

 # ================================
 # ================================
 training_args = TrainingArguments(
     output_dir="./asr_model_finetuned",
     save_strategy="epoch",
     learning_rate=5e-5,
     per_device_train_batch_size=8,
@@ -111,15 +95,13 @@ training_args = TrainingArguments(
     logging_dir="./logs",
     logging_steps=500,
     save_total_limit=2,
-    push_to_hub=True,
-    hub_model_id="tahirsher/ASR_Model_for_Transcription_into_Text",
     hub_token=HF_TOKEN,
 )

-# ✅ FIX: Use correct Data Collator
 data_collator = DataCollatorForSeq2Seq(tokenizer=processor.tokenizer, model=model, return_tensors="pt")

-# Define Trainer
 trainer = Trainer(
     model=model,
     args=training_args,
@@ -129,45 +111,54 @@ trainer = Trainer(
 )

 # ================================
 # ================================
 if st.button("Start Fine-Tuning"):
     with st.spinner("Fine-tuning in progress... Please wait!"):
         trainer.train()
         st.success("✅ Fine-Tuning Completed! Model updated.")

 # ================================
 # ================================
-st.title("🎙️ Speech-to-Text ASR with Fine-Tuning 🎶")

-# Upload audio file
 audio_file = st.file_uploader("Upload an audio file", type=["wav", "mp3", "flac"])

 if audio_file:
-    # Save uploaded file temporarily
     audio_path = "temp_audio.wav"
     with open(audio_path, "wb") as f:
         f.write(audio_file.read())

-    # Load and process audio
     waveform, sample_rate = torchaudio.load(audio_path)
     waveform = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)(waveform)

-    # Convert audio to model input
     input_features = processor(waveform.squeeze().numpy(), sampling_rate=16000, return_tensors="pt").input_features

-    input_tensor = input_features.to(device)  # Move to GPU/CPU

-    # ✅ FIX: Provide decoder_input_ids
-    decoder_input_ids = torch.tensor([[model.config.decoder_start_token_id]]).to(device)

     with torch.no_grad():

     # Display transcription
     st.success("📝 Transcription:")
Updated version of app.py (right side of the diff; lines added in this commit are prefixed with "+"):

 import torchaudio
 import numpy as np
 import streamlit as st
+import matplotlib.pyplot as plt
 from huggingface_hub import login
+from datasets import load_dataset, DatasetDict
 from transformers import (
     AutoProcessor,
     AutoModelForSpeechSeq2Seq,
     TrainingArguments,
     Trainer,
+    DataCollatorForSeq2Seq,
 )

 # ================================
 # 1️⃣ Authenticate with Hugging Face Hub (Securely)
 # ================================
+HF_TOKEN = os.getenv("hf_token")

 if HF_TOKEN is None:
     raise ValueError("❌ Hugging Face API token not found. Please set it in Secrets.")

 processor = AutoProcessor.from_pretrained(MODEL_NAME)
 model = AutoModelForSpeechSeq2Seq.from_pretrained(MODEL_NAME)

 device = "cuda" if torch.cuda.is_available() else "cpu"
 model.to(device)
 print(f"✅ Model loaded on {device}")

 # ================================
+# 3️⃣ Load and Prepare Dataset
 # ================================
 DATASET_TAR_PATH = "dev-clean.tar.gz"
 EXTRACT_PATH = "./librispeech_dev_clean"

 if not os.path.exists(EXTRACT_PATH):
     print("🔄 Extracting dataset...")
     with tarfile.open(DATASET_TAR_PATH, "r:gz") as tar:

 else:
     print("✅ Dataset already extracted.")
+# Load dataset with transcripts
+dataset = load_dataset("librispeech_asr", "clean", split="train")

+# Ensure dataset has transcripts
+if "text" not in dataset.column_names:
+    raise ValueError("❌ Dataset is missing transcription text!")
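Note: whether split="train" resolves depends on the Hub copy of librispeech_asr; the "clean" configuration is usually published with "train.100" and "train.360" splits rather than a single "train" split. If the call above fails for that reason, a variant along these lines may be needed (an assumption about the dataset card, not part of this commit):

    # Sketch only: the "clean" config of librispeech_asr typically exposes
    # "train.100" / "train.360" rather than a plain "train" split.
    dataset = load_dataset("librispeech_asr", "clean", split="train.100")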
+# Preprocessing Function
+def preprocess_data(batch):
+    # Process audio
+    waveform, sample_rate = torchaudio.load(batch["file"])
     waveform = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)(waveform)
+
+    batch["input_features"] = processor(waveform.squeeze().numpy(), sampling_rate=16000, return_tensors="pt").input_features[0]
+
+    # Tokenize transcript text
+    batch["labels"] = processor.tokenizer(batch["text"], padding="max_length", truncation=True, return_tensors="pt").input_ids[0]
+
+    return batch
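Note: preprocess_data reads audio from batch["file"] with torchaudio, which assumes valid local paths. The Hub copy of librispeech_asr also carries a decoded "audio" column, so a path-free variant could look like the following sketch (an assumption about the dataset layout, not the committed code):

    # Sketch only: use the decoded "audio" feature instead of loading from disk.
    def preprocess_data(batch):
        audio = batch["audio"]  # dict with "array" and "sampling_rate"
        batch["input_features"] = processor(
            audio["array"], sampling_rate=audio["sampling_rate"]
        ).input_features[0]
        batch["labels"] = processor.tokenizer(batch["text"]).input_ids
        return batch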
+
+# Apply preprocessing
+dataset = dataset.map(preprocess_data, remove_columns=["file", "audio", "text"])
+
+# Split into train & eval
 train_size = int(0.8 * len(dataset))
+train_dataset = dataset.select(range(train_size))
+eval_dataset = dataset.select(range(train_size, len(dataset)))

+print(f"✅ Dataset Prepared! Training: {len(train_dataset)}, Evaluation: {len(eval_dataset)}")
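Note: the 80/20 split above is positional, taking the first 80% of rows in storage order. If a shuffled split is preferred, datasets has a built-in helper; a possible variant (not what this commit does):

    # Sketch only: shuffled 80/20 split.
    split = dataset.train_test_split(test_size=0.2, seed=42)
    train_dataset, eval_dataset = split["train"], split["test"]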
 # ================================
+# 4️⃣ Training Arguments & Trainer
 # ================================
 training_args = TrainingArguments(
     output_dir="./asr_model_finetuned",
+    evaluation_strategy="epoch",
     save_strategy="epoch",
     learning_rate=5e-5,
     per_device_train_batch_size=8,

     logging_dir="./logs",
     logging_steps=500,
     save_total_limit=2,
+    push_to_hub=True,
+    hub_model_id="tahirsher/ASR_Model_for_Transcription_into_Text",
     hub_token=HF_TOKEN,
 )
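Note: evaluation_strategy="epoch" is the long-standing argument name; recent transformers releases rename it to eval_strategy, so the spelling has to match the pinned version. For a seq2seq ASR model it is also common to use the Seq2Seq variants so that evaluation can run generation; a sketch under that assumption, not the committed configuration:

    # Sketch only: Seq2Seq training arguments with generation-based evaluation.
    from transformers import Seq2SeqTrainingArguments, Seq2SeqTrainer

    training_args = Seq2SeqTrainingArguments(
        output_dir="./asr_model_finetuned",
        eval_strategy="epoch",   # "evaluation_strategy" on older releases
        save_strategy="epoch",
        learning_rate=5e-5,
        per_device_train_batch_size=8,
        predict_with_generate=True,
    )
    # trainer = Seq2SeqTrainer(model=model, args=training_args, ...)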
 data_collator = DataCollatorForSeq2Seq(tokenizer=processor.tokenizer, model=model, return_tensors="pt")

 trainer = Trainer(
     model=model,
     args=training_args,

 )
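Note: DataCollatorForSeq2Seq pads tokenizer outputs (input_ids and labels); the features built above are log-mel input_features, which it does not handle. Whisper-style fine-tuning recipes typically pair this kind of preprocessing with a small custom collator along these lines (a sketch that assumes the processor wraps a feature_extractor and a tokenizer, not the collator used in this commit):

    # Sketch only: pad audio features and token labels separately.
    from dataclasses import dataclass
    import torch

    @dataclass
    class SpeechSeq2SeqCollator:
        processor: object

        def __call__(self, features):
            audio = [{"input_features": f["input_features"]} for f in features]
            labels = [{"input_ids": f["labels"]} for f in features]
            batch = self.processor.feature_extractor.pad(audio, return_tensors="pt")
            labels_batch = self.processor.tokenizer.pad(labels, return_tensors="pt")
            # Mask padding positions so the loss ignores them.
            batch["labels"] = labels_batch["input_ids"].masked_fill(
                labels_batch["attention_mask"].ne(1), -100
            )
            return batch

    # data_collator = SpeechSeq2SeqCollator(processor)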
 # ================================
+# 5️⃣ Fine-Tuning Execution & Training Stats
 # ================================
 if st.button("Start Fine-Tuning"):
     with st.spinner("Fine-tuning in progress... Please wait!"):
         trainer.train()
         st.success("✅ Fine-Tuning Completed! Model updated.")

+        # Plot Training Loss
+        train_loss = trainer.state.log_history
+        losses = [entry['loss'] for entry in train_loss if 'loss' in entry]
+
+        plt.figure(figsize=(8, 5))
+        plt.plot(range(len(losses)), losses, label="Training Loss", color="blue")
+        plt.xlabel("Steps")
+        plt.ylabel("Loss")
+        plt.title("Training Loss Over Time")
+        plt.legend()
+        st.pyplot(plt)
+
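Note: st.pyplot(plt) renders the current global matplotlib figure. Passing an explicit Figure object avoids relying on global state across Streamlit reruns; a minimal variant (not the committed code):

    # Sketch only: explicit figure handling for st.pyplot.
    fig, ax = plt.subplots(figsize=(8, 5))
    ax.plot(losses, label="Training Loss")
    ax.set_xlabel("Steps")
    ax.set_ylabel("Loss")
    ax.set_title("Training Loss Over Time")
    ax.legend()
    st.pyplot(fig)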
 # ================================
+# 6️⃣ Streamlit ASR Web App (Proper Decoding)
 # ================================
+st.title("🎙️ Speech-to-Text ASR Model with Fine-Tuning 🎶")

 audio_file = st.file_uploader("Upload an audio file", type=["wav", "mp3", "flac"])

 if audio_file:
     audio_path = "temp_audio.wav"
     with open(audio_path, "wb") as f:
         f.write(audio_file.read())

     waveform, sample_rate = torchaudio.load(audio_path)
     waveform = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)(waveform)

     input_features = processor(waveform.squeeze().numpy(), sampling_rate=16000, return_tensors="pt").input_features

+    input_tensor = input_features.to(device)

+    # ✅ FIX: Use `generate()` for Proper Transcription
     with torch.no_grad():
+        generated_ids = model.generate(
+            input_tensor,
+            max_length=500,
+            num_beams=5,
+            do_sample=True,
+            top_k=50
+        )
+    transcription = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]

     # Display transcription
     st.success("📝 Transcription:")
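Note: do_sample=True with top_k=50 makes each transcription stochastic; ASR inference is normally run with deterministic greedy or beam search. A deterministic variant of the call above might look like this (a sketch, not the committed code):

    # Sketch only: deterministic beam search decoding.
    with torch.no_grad():
        generated_ids = model.generate(input_tensor, max_length=500, num_beams=5)
    transcription = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]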