tahirsher commited on
Commit
76c5c38
Β·
verified Β·
1 Parent(s): a4a8364

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +55 -83
app.py CHANGED
@@ -4,7 +4,6 @@ import torch
4
  import torchaudio
5
  import numpy as np
6
  import streamlit as st
7
- import matplotlib.pyplot as plt
8
  from huggingface_hub import login
9
  from transformers import (
10
  AutoProcessor,
@@ -13,62 +12,50 @@ from transformers import (
13
  from cryptography.fernet import Fernet
14
 
15
  # ================================
16
- # 1️⃣ Authenticate with Hugging Face Hub
17
  # ================================
18
- HF_TOKEN = os.getenv("hf_token")
 
 
 
 
 
19
 
20
- if HF_TOKEN is None:
21
- raise ValueError("❌ Hugging Face API token not found. Please set it in Secrets.")
22
-
23
- login(token=HF_TOKEN)
24
 
25
  # ================================
26
- # 2️⃣ Load Model & Processor
27
  # ================================
28
- MODEL_NAME = "AqeelShafy7/AudioSangraha-Audio_to_Text"
29
- processor = AutoProcessor.from_pretrained(MODEL_NAME)
30
- model = AutoModelForSpeechSeq2Seq.from_pretrained(MODEL_NAME)
 
 
 
31
 
32
- device = "cuda" if torch.cuda.is_available() else "cpu"
33
- model.to(device)
34
- print(f"βœ… Model loaded on {device}")
35
 
36
  # ================================
37
- # 3️⃣ Load Dataset
38
  # ================================
39
- DATASET_TAR_PATH = "dev-clean.tar.gz"
40
- EXTRACT_PATH = "./librispeech_dev_clean"
41
-
42
- if not os.path.exists(EXTRACT_PATH):
43
- print("πŸ”„ Extracting dataset...")
44
- with tarfile.open(DATASET_TAR_PATH, "r:gz") as tar:
45
- tar.extractall(EXTRACT_PATH)
46
- print("βœ… Extraction complete.")
47
- else:
48
- print("βœ… Dataset already extracted.")
49
-
50
- AUDIO_FOLDER = os.path.join(EXTRACT_PATH, "LibriSpeech", "dev-clean")
51
-
52
- def find_audio_files(base_folder):
53
- audio_files = []
54
- for root, _, files in os.walk(base_folder):
55
- for file in files:
56
- if file.endswith(".flac"):
57
- audio_files.append(os.path.join(root, file))
58
- return audio_files
59
 
60
- audio_files = find_audio_files(AUDIO_FOLDER)
 
 
 
61
 
62
- if not audio_files:
63
- raise FileNotFoundError(f"❌ No .flac files found in {AUDIO_FOLDER}. Check dataset structure!")
64
-
65
- print(f"βœ… Found {len(audio_files)} audio files in dataset!")
66
 
67
  # ================================
68
- # 4️⃣ Load Transcripts
69
  # ================================
 
70
  def load_transcripts():
71
- transcript_dict = {}
72
  for root, _, files in os.walk(AUDIO_FOLDER):
73
  for file in files:
74
  if file.endswith(".txt"):
@@ -76,18 +63,13 @@ def load_transcripts():
76
  for line in f:
77
  parts = line.strip().split(" ", 1)
78
  if len(parts) == 2:
79
- file_id, text = parts
80
- transcript_dict[file_id] = text
81
- return transcript_dict
82
 
83
  transcripts = load_transcripts()
84
- if not transcripts:
85
- raise FileNotFoundError("❌ No transcripts found! Check dataset structure.")
86
-
87
- print(f"βœ… Loaded {len(transcripts)} transcripts.")
88
 
89
  # ================================
90
- # 5️⃣ Streamlit Sidebar: Fine-Tuning & Security
91
  # ================================
92
  st.sidebar.title("πŸ”§ Fine-Tuning & Security Settings")
93
 
@@ -101,25 +83,21 @@ enable_encryption = st.sidebar.checkbox("πŸ”’ Encrypt Transcription", value=True
101
  show_transcription = st.sidebar.checkbox("πŸ“– Show Transcription", value=False)
102
 
103
  # ================================
104
- # 6️⃣ Encryption Functionality
105
  # ================================
106
- def generate_key():
107
- return Fernet.generate_key()
108
 
109
- def encrypt_text(text, key):
110
- fernet = Fernet(key)
111
  return fernet.encrypt(text.encode())
112
 
113
- def decrypt_text(encrypted_text, key):
114
- fernet = Fernet(key)
115
  return fernet.decrypt(encrypted_text).decode()
116
 
117
- encryption_key = generate_key()
118
-
119
  # ================================
120
- # 7️⃣ Streamlit ASR Web App
121
  # ================================
122
- st.title("πŸŽ™οΈ Speech-to-Text ASR Model Finetuned on Libri Speech Dataset with Security Features")
123
 
124
  audio_file = st.file_uploader("Upload an audio file", type=["wav", "mp3", "flac"])
125
 
@@ -130,48 +108,42 @@ if audio_file:
130
 
131
  waveform, sample_rate = torchaudio.load(audio_path)
132
  waveform = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)(waveform)
133
- waveform = waveform.to(dtype=torch.float32)
134
 
135
  # ================================
136
- # βœ… Improved Adversarial Attack Handling
137
  # ================================
138
  noise = attack_strength * torch.randn_like(waveform)
139
-
140
- # Apply noise but then perform denoising to counteract attack effects
141
  adversarial_waveform = waveform + noise
142
  adversarial_waveform = torch.clamp(adversarial_waveform, -1.0, 1.0)
 
 
143
  denoised_waveform = torchaudio.functional.vad(adversarial_waveform, sample_rate=16000)
144
 
145
- input_features = processor(denoised_waveform.squeeze().numpy(), sampling_rate=16000, return_tensors="pt").input_features.to(device)
146
-
 
 
 
147
  with torch.inference_mode():
148
- generated_ids = model.generate(
149
- input_features,
150
- max_length=200,
151
- num_beams=2,
152
- do_sample=False,
153
- use_cache=True,
154
- attention_mask=torch.ones(input_features.shape, dtype=torch.long).to(device),
155
- language="en"
156
- )
157
  transcription = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
158
-
159
  if attack_strength > 0.3:
160
- st.warning("⚠️ Adversarial attack detected! Mitigated using denoising.")
161
 
162
  # ================================
163
- # βœ… Encryption Handling
164
  # ================================
165
  if enable_encryption:
166
- encrypted_transcription = encrypt_text(transcription, encryption_key)
167
- st.info("πŸ”’ Transcription is encrypted. To view, enable 'Show Transcription' in the sidebar.")
168
 
169
  if show_transcription:
170
- decrypted_text = decrypt_text(encrypted_transcription, encryption_key)
171
  st.success("πŸ“„ Secure Transcription:")
172
  st.write(decrypted_text)
173
  else:
174
- st.write("πŸ”’ [Encrypted] Transcription is hidden. Enable 'Show Transcription' to view.")
175
  else:
176
  st.success("πŸ“„ Transcription:")
177
  st.write(transcription)
 
4
  import torchaudio
5
  import numpy as np
6
  import streamlit as st
 
7
  from huggingface_hub import login
8
  from transformers import (
9
  AutoProcessor,
 
12
  from cryptography.fernet import Fernet
13
 
14
  # ================================
15
+ # 1️⃣ Authenticate with Hugging Face Hub (Cache to prevent re-authentication)
16
  # ================================
17
+ @st.cache_resource
18
+ def authenticate_hf():
19
+ HF_TOKEN = os.getenv("hf_token")
20
+ if HF_TOKEN is None:
21
+ raise ValueError("❌ Hugging Face API token not found. Please set it in Secrets.")
22
+ login(token=HF_TOKEN)
23
 
24
+ authenticate_hf()
 
 
 
25
 
26
  # ================================
27
+ # 2️⃣ Load Model & Processor (Cached)
28
  # ================================
29
+ @st.cache_resource
30
+ def load_model():
31
+ MODEL_NAME = "AqeelShafy7/AudioSangraha-Audio_to_Text"
32
+ processor = AutoProcessor.from_pretrained(MODEL_NAME)
33
+ model = AutoModelForSpeechSeq2Seq.from_pretrained(MODEL_NAME).to("cuda" if torch.cuda.is_available() else "cpu")
34
+ return processor, model
35
 
36
+ processor, model = load_model()
 
 
37
 
38
  # ================================
39
+ # 3️⃣ Dataset Extraction (Cached)
40
  # ================================
41
+ @st.cache_resource
42
+ def extract_dataset():
43
+ DATASET_TAR_PATH = "dev-clean.tar.gz"
44
+ EXTRACT_PATH = "./librispeech_dev_clean"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
45
 
46
+ if not os.path.exists(EXTRACT_PATH):
47
+ with tarfile.open(DATASET_TAR_PATH, "r:gz") as tar:
48
+ tar.extractall(EXTRACT_PATH)
49
+ return os.path.join(EXTRACT_PATH, "LibriSpeech", "dev-clean")
50
 
51
+ AUDIO_FOLDER = extract_dataset()
 
 
 
52
 
53
  # ================================
54
+ # 4️⃣ Load Transcripts (Cached)
55
  # ================================
56
+ @st.cache_resource
57
  def load_transcripts():
58
+ transcripts = {}
59
  for root, _, files in os.walk(AUDIO_FOLDER):
60
  for file in files:
61
  if file.endswith(".txt"):
 
63
  for line in f:
64
  parts = line.strip().split(" ", 1)
65
  if len(parts) == 2:
66
+ transcripts[parts[0]] = parts[1]
67
+ return transcripts
 
68
 
69
  transcripts = load_transcripts()
 
 
 
 
70
 
71
  # ================================
72
+ # 5️⃣ Streamlit Sidebar for Fine-Tuning & Security
73
  # ================================
74
  st.sidebar.title("πŸ”§ Fine-Tuning & Security Settings")
75
 
 
83
  show_transcription = st.sidebar.checkbox("πŸ“– Show Transcription", value=False)
84
 
85
  # ================================
86
+ # 6️⃣ Encryption Handling (Precomputed Key)
87
  # ================================
88
+ encryption_key = Fernet.generate_key()
89
+ fernet = Fernet(encryption_key)
90
 
91
+ def encrypt_text(text):
 
92
  return fernet.encrypt(text.encode())
93
 
94
+ def decrypt_text(encrypted_text):
 
95
  return fernet.decrypt(encrypted_text).decode()
96
 
 
 
97
  # ================================
98
+ # 7️⃣ Optimized ASR Web App
99
  # ================================
100
+ st.title("πŸŽ™οΈ Speech-to-Text ASR Model Finetuned on Librispeech Corpus with Security Features")
101
 
102
  audio_file = st.file_uploader("Upload an audio file", type=["wav", "mp3", "flac"])
103
 
 
108
 
109
  waveform, sample_rate = torchaudio.load(audio_path)
110
  waveform = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)(waveform)
 
111
 
112
  # ================================
113
+ # βœ… Optimized Adversarial Attack Handling
114
  # ================================
115
  noise = attack_strength * torch.randn_like(waveform)
 
 
116
  adversarial_waveform = waveform + noise
117
  adversarial_waveform = torch.clamp(adversarial_waveform, -1.0, 1.0)
118
+
119
+ # Remove background noise for speed & accuracy
120
  denoised_waveform = torchaudio.functional.vad(adversarial_waveform, sample_rate=16000)
121
 
122
+ # ================================
123
+ # βœ… Fast Transcription Processing
124
+ # ================================
125
+ input_features = processor(denoised_waveform.squeeze().numpy(), sampling_rate=16000, return_tensors="pt").input_features.to("cuda" if torch.cuda.is_available() else "cpu")
126
+
127
  with torch.inference_mode():
128
+ generated_ids = model.generate(input_features, max_length=200, num_beams=2, do_sample=False)
 
 
 
 
 
 
 
 
129
  transcription = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
130
+
131
  if attack_strength > 0.3:
132
+ st.warning("⚠️ Adversarial attack detected! Denoising applied.")
133
 
134
  # ================================
135
+ # βœ… Optimized Encryption Handling
136
  # ================================
137
  if enable_encryption:
138
+ encrypted_transcription = encrypt_text(transcription)
139
+ st.info("πŸ”’ Transcription is encrypted. Enable 'Show Transcription' to view.")
140
 
141
  if show_transcription:
142
+ decrypted_text = decrypt_text(encrypted_transcription)
143
  st.success("πŸ“„ Secure Transcription:")
144
  st.write(decrypted_text)
145
  else:
146
+ st.write("πŸ”’ [Encrypted] Transcription hidden. Enable 'Show Transcription' to view.")
147
  else:
148
  st.success("πŸ“„ Transcription:")
149
  st.write(transcription)