ThomasTheMaker commited on
Commit
6596431
·
verified ·
1 Parent(s): 919eceb

Upload server.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. server.py +246 -0
server.py ADDED
@@ -0,0 +1,246 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import time
3
+ import logging
4
+ import re
5
+ print(f"Initial logging._nameToLevel: {logging._nameToLevel}")
6
+ from pathlib import Path
7
+ from typing import List, Dict, Any, Optional
8
+
9
+ import soundfile as sf
10
+ import numpy as np
11
+ from fastapi import FastAPI, HTTPException
12
+ from pydantic import BaseModel
13
+
14
+ # Ensure sensevoice_rknn.py is in the same directory or PYTHONPATH
15
+ # Add the directory of this script to sys.path if sensevoice_rknn is not found directly
16
+ import sys
17
+ SCRIPT_DIR = Path(__file__).resolve().parent
18
+ if str(SCRIPT_DIR) not in sys.path:
19
+ sys.path.append(str(SCRIPT_DIR))
20
+
21
+ try:
22
+ from sensevoice_rknn import WavFrontend, SenseVoiceInferenceSession, FSMNVad, languages
23
+ except ImportError as e:
24
+ logging.error(f"Error importing from sensevoice_rknn.py: {e}")
25
+ logging.error("Please ensure sensevoice_rknn.py is in the same directory as server.py or in your PYTHONPATH.")
26
+ # Fallback for critical components if import fails, to allow FastAPI to at least start and show an error
27
+ class WavFrontend:
28
+ def __init__(self, *args, **kwargs): raise NotImplementedError("WavFrontend not loaded")
29
+ def get_features(self, *args, **kwargs): raise NotImplementedError("WavFrontend not loaded")
30
+ class SenseVoiceInferenceSession:
31
+ def __init__(self, *args, **kwargs): raise NotImplementedError("SenseVoiceInferenceSession not loaded")
32
+ def __call__(self, *args, **kwargs): raise NotImplementedError("SenseVoiceInferenceSession not loaded")
33
+ class FSMNVad:
34
+ def __init__(self, *args, **kwargs): raise NotImplementedError("FSMNVad not loaded")
35
+ def segments_offline(self, *args, **kwargs): raise NotImplementedError("FSMNVad not loaded")
36
+ class Vad:
37
+ def all_reset_detection(self, *args, **kwargs): raise NotImplementedError("FSMNVad not loaded")
38
+ vad = Vad()
39
+
40
+ languages = {"en": 4} # Default fallback
41
+
42
+ app = FastAPI()
43
+
44
+ # Logging will be handled by Uvicorn's default configuration or a custom log_config if provided to uvicorn.run
45
+ # Get a logger instance for application-specific logs if needed
46
+ logger = logging.getLogger(__name__)
47
+ logger.setLevel(logging.INFO) # Set level for this specific logger
48
+
49
+ # --- Model Configuration & Loading ---
50
+ MODEL_BASE_PATH = Path(__file__).resolve().parent
51
+
52
+ # These paths should match those used in sensevoice_rknn.py's main function
53
+ # or be configurable if they differ.
54
+ MVN_PATH = MODEL_BASE_PATH / "am.mvn"
55
+ EMBEDDING_NPY_PATH = MODEL_BASE_PATH / "embedding.npy"
56
+ ENCODER_RKNN_PATH = MODEL_BASE_PATH / "sense-voice-encoder.rknn"
57
+ BPE_MODEL_PATH = MODEL_BASE_PATH / "chn_jpn_yue_eng_ko_spectok.bpe.model"
58
+ VAD_CONFIG_DIR = MODEL_BASE_PATH # Assuming fsmn-config.yaml and fsmnvad-offline.onnx are here
59
+
60
+ # Global model instances
61
+ w_frontend: Optional[WavFrontend] = None
62
+ asr_model: Optional[SenseVoiceInferenceSession] = None
63
+ vad_model: Optional[FSMNVad] = None
64
+
65
+ @app.on_event("startup")
66
+ def load_models():
67
+ global w_frontend, asr_model, vad_model
68
+ logging.info("Loading models...")
69
+ start_time = time.time()
70
+ try:
71
+ if not MVN_PATH.exists():
72
+ raise FileNotFoundError(f"CMVN file not found: {MVN_PATH}")
73
+ w_frontend = WavFrontend(cmvn_file=str(MVN_PATH))
74
+
75
+ if not EMBEDDING_NPY_PATH.exists() or not ENCODER_RKNN_PATH.exists() or not BPE_MODEL_PATH.exists():
76
+ raise FileNotFoundError(
77
+ f"One or more ASR model files not found: "
78
+ f"Embedding: {EMBEDDING_NPY_PATH}, Encoder: {ENCODER_RKNN_PATH}, BPE: {BPE_MODEL_PATH}"
79
+ )
80
+ asr_model = SenseVoiceInferenceSession(
81
+ embedding_model_file=str(EMBEDDING_NPY_PATH),
82
+ encoder_model_file=str(ENCODER_RKNN_PATH),
83
+ bpe_model_file=str(BPE_MODEL_PATH),
84
+ # Assuming default device_id and num_threads as in sensevoice_rknn.py's main
85
+ device_id=-1,
86
+ intra_op_num_threads=4
87
+ )
88
+
89
+ # Check for VAD model files (fsmn-config.yaml, fsmnvad-offline.onnx)
90
+ if not (VAD_CONFIG_DIR / "fsmn-config.yaml").exists() or not (VAD_CONFIG_DIR / "fsmnvad-offline.onnx").exists():
91
+ raise FileNotFoundError(f"VAD config or model not found in {VAD_CONFIG_DIR}")
92
+ vad_model = FSMNVad(config_dir=str(VAD_CONFIG_DIR))
93
+
94
+ logging.info(f"Models loaded successfully in {time.time() - start_time:.2f} seconds.")
95
+ except FileNotFoundError as e:
96
+ logging.error(f"Model loading failed: {e}")
97
+ # Keep models as None, endpoints will raise errors
98
+ except Exception as e:
99
+ logging.error(f"An unexpected error occurred during model loading: {e}")
100
+ # Keep models as None
101
+
102
+ class TranscribeRequest(BaseModel):
103
+ audio_file_path: str
104
+ language: str = "en" # Default to English
105
+ use_itn: bool = False
106
+
107
+ class Segment(BaseModel):
108
+ start_time_s: float
109
+ end_time_s: float
110
+ text: str
111
+
112
+ class TranscribeResponse(BaseModel):
113
+ full_transcription: str
114
+ segments: List[Segment]
115
+
116
+ @app.post("/transcribe", response_model=str)
117
+ async def transcribe_audio(request: TranscribeRequest):
118
+ if w_frontend is None or asr_model is None or vad_model is None:
119
+ logging.error("Models not loaded. Transcription cannot proceed.")
120
+ raise HTTPException(status_code=503, detail="Models are not loaded. Please check server logs.")
121
+
122
+ audio_path = Path(request.audio_file_path)
123
+ if not audio_path.exists() or not audio_path.is_file():
124
+ logging.error(f"Audio file not found: {audio_path}")
125
+ raise HTTPException(status_code=404, detail=f"Audio file not found: {audio_path}")
126
+
127
+ try:
128
+ waveform, sample_rate = sf.read(
129
+ str(audio_path),
130
+ dtype="float32",
131
+ always_2d=True
132
+ )
133
+ except Exception as e:
134
+ logging.error(f"Error reading audio file {audio_path}: {e}")
135
+ raise HTTPException(status_code=400, detail=f"Could not read audio file: {e}")
136
+
137
+ if sample_rate != 16000:
138
+ # Basic resampling could be added here if needed, or just raise an error
139
+ logging.warning(f"Audio sample rate is {sample_rate}Hz, expected 16000Hz. Results may be suboptimal.")
140
+ # For now, we proceed but log a warning. For critical applications, convert or reject.
141
+
142
+ logging.info(f"Processing audio: {audio_path}, Duration: {len(waveform) / sample_rate:.2f}s, Channels: {waveform.shape[1]}")
143
+
144
+ lang_code = languages.get(request.language.lower())
145
+ if lang_code is None:
146
+ logging.warning(f"Unsupported language: {request.language}. Defaulting to 'en'. Supported: {list(languages.keys())}")
147
+ lang_code = languages.get("en", 0) # Fallback to 'en' or 'auto' if 'en' isn't in languages
148
+
149
+ all_segments_text: List[str] = []
150
+ detailed_segments: List[Segment] = []
151
+ processing_start_time = time.time()
152
+
153
+ for channel_id in range(waveform.shape[1]):
154
+ channel_data = waveform[:, channel_id]
155
+ logging.info(f"Processing channel {channel_id + 1}/{waveform.shape[1]}")
156
+
157
+ try:
158
+ # Ensure channel_data is 1D for VAD if it expects that
159
+ speech_segments = vad_model.segments_offline(channel_data) # segments_offline expects 1D array
160
+ except Exception as e:
161
+ logging.error(f"VAD processing failed for channel {channel_id}: {e}")
162
+ # Optionally skip this channel or raise an error for the whole request
163
+ continue # Skip to next channel
164
+
165
+ for part_idx, part in enumerate(speech_segments):
166
+ start_sample = int(part[0] * 16) # VAD returns ms, convert to samples (16 samples/ms for 16kHz)
167
+ end_sample = int(part[1] * 16)
168
+ segment_audio = channel_data[start_sample:end_sample]
169
+
170
+ if len(segment_audio) == 0:
171
+ logging.info(f"Empty audio segment for channel {channel_id}, part {part_idx}. Skipping.")
172
+ continue
173
+
174
+ try:
175
+ # Ensure get_features expects 1D array
176
+ audio_feats = w_frontend.get_features(segment_audio)
177
+ # ASR model expects batch dimension, add [None, ...]
178
+ asr_result_text_raw = asr_model(
179
+ audio_feats[None, ...],
180
+ language=lang_code,
181
+ use_itn=request.use_itn,
182
+ )
183
+ # Remove tags like <|en|>, <|HAPPY|>, etc.
184
+ asr_result_text_cleaned = re.sub(r"<\|[^\|]+\|>", "", asr_result_text_raw).strip()
185
+
186
+ segment_start_s = part[0] / 1000.0
187
+ segment_end_s = part[1] / 1000.0
188
+ logging.info(f"[Ch{channel_id}] [{segment_start_s:.2f}s - {segment_end_s:.2f}s] Raw: {asr_result_text_raw} Cleaned: {asr_result_text_cleaned}")
189
+ all_segments_text.append(asr_result_text_cleaned)
190
+ detailed_segments.append(Segment(start_time_s=segment_start_s, end_time_s=segment_end_s, text=asr_result_text_cleaned))
191
+ except Exception as e:
192
+ logging.error(f"ASR processing failed for segment {part_idx} in channel {channel_id}: {e}")
193
+ # Optionally add a placeholder or skip this segment's text
194
+ detailed_segments.append(Segment(start_time_s=part[0]/1000.0, end_time_s=part[1]/1000.0, text="[ASR_ERROR]"))
195
+
196
+ vad_model.vad.all_reset_detection() # Reset VAD state for next channel or call
197
+
198
+ full_transcription = " ".join(all_segments_text).strip()
199
+ logging.info(f"Transcription complete in {time.time() - processing_start_time:.2f}s. Result: {full_transcription}")
200
+
201
+ return full_transcription
202
+
203
+ if __name__ == "__main__":
204
+ import uvicorn
205
+
206
+ MINIMAL_LOGGING_CONFIG = {
207
+ "version": 1,
208
+ "disable_existing_loggers": False, # Let other loggers (like our app logger) exist
209
+ "formatters": {
210
+ "default": {
211
+ "()": "uvicorn.logging.DefaultFormatter",
212
+ "fmt": "%(levelprefix)s %(message)s",
213
+ "use_colors": None,
214
+ },
215
+ },
216
+ "handlers": {
217
+ "default": {
218
+ "formatter": "default",
219
+ "class": "logging.StreamHandler",
220
+ "stream": "ext://sys.stderr",
221
+ },
222
+ },
223
+ "loggers": {
224
+ "uvicorn": { # Uvicorn's own operational logs
225
+ "handlers": ["default"],
226
+ "level": logging.INFO, # Explicitly use integer
227
+ "propagate": False,
228
+ },
229
+ "uvicorn.error": { # Logs for errors within Uvicorn
230
+ "handlers": ["default"],
231
+ "level": logging.INFO, # Explicitly use integer
232
+ "propagate": False,
233
+ },
234
+ # We are deliberately not configuring uvicorn.access here for simplicity
235
+ # It might default to INFO or be silent if not configured and no parent handler catches it.
236
+ },
237
+ # Ensure our application logger also works if needed
238
+ __name__: {
239
+ "handlers": ["default"],
240
+ "level": logging.INFO,
241
+ "propagate": False,
242
+ }
243
+ }
244
+
245
+ logger.info(f"Attempting to run Uvicorn with minimal explicit log_config.")
246
+ uvicorn.run(app, host="0.0.0.0", port=8000, log_config=MINIMAL_LOGGING_CONFIG)