File size: 11,172 Bytes
6596431
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
import logging
import os
import re
import sys
import time
from pathlib import Path
from typing import Any, Dict, List, Optional

import numpy as np
import soundfile as sf
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel

# NOTE: a leftover debug print of the private logging._nameToLevel mapping was
# removed here — it leaked internal state to stdout on every import.

# Ensure sensevoice_rknn.py is in the same directory or PYTHONPATH.
# Add the directory of this script to sys.path if sensevoice_rknn is not found directly.
SCRIPT_DIR = Path(__file__).resolve().parent
if str(SCRIPT_DIR) not in sys.path:
    sys.path.append(str(SCRIPT_DIR))

# Import the real inference components; fall back to inert stubs on failure so
# the FastAPI app can still start and report a clear error from its endpoints.
try:
    from sensevoice_rknn import WavFrontend, SenseVoiceInferenceSession, FSMNVad, languages
except ImportError as e:
    logging.error(f"Error importing from sensevoice_rknn.py: {e}")
    logging.error("Please ensure sensevoice_rknn.py is in the same directory as server.py or in your PYTHONPATH.")
    # Fallback for critical components if import fails, to allow FastAPI to at least start and show an error.
    # Each stub mirrors the real class's call surface but raises NotImplementedError on use.
    class WavFrontend:
        def __init__(self, *args, **kwargs): raise NotImplementedError("WavFrontend not loaded")
        def get_features(self, *args, **kwargs): raise NotImplementedError("WavFrontend not loaded")
    class SenseVoiceInferenceSession:
        def __init__(self, *args, **kwargs): raise NotImplementedError("SenseVoiceInferenceSession not loaded")
        def __call__(self, *args, **kwargs): raise NotImplementedError("SenseVoiceInferenceSession not loaded")
    class FSMNVad:
        def __init__(self, *args, **kwargs): raise NotImplementedError("FSMNVad not loaded")
        def segments_offline(self, *args, **kwargs): raise NotImplementedError("FSMNVad not loaded")
        # Nested stub so `vad_model.vad.all_reset_detection()` also resolves on the fallback.
        class Vad:
            def all_reset_detection(self, *args, **kwargs): raise NotImplementedError("FSMNVad not loaded")
        vad = Vad()

    languages = {"en": 4} # Default fallback

app = FastAPI()

# Logging will be handled by Uvicorn's default configuration or a custom log_config if provided to uvicorn.run.
# Get a logger instance for application-specific logs if needed.
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO) # Set level for this specific logger

# --- Model Configuration & Loading ---
# All model artifacts are expected to live next to this script.
MODEL_BASE_PATH = Path(__file__).resolve().parent

# These paths should match those used in sensevoice_rknn.py's main function
# or be configurable if they differ.
MVN_PATH = MODEL_BASE_PATH / "am.mvn"  # CMVN stats for the feature frontend
EMBEDDING_NPY_PATH = MODEL_BASE_PATH / "embedding.npy"
ENCODER_RKNN_PATH = MODEL_BASE_PATH / "sense-voice-encoder.rknn"
BPE_MODEL_PATH = MODEL_BASE_PATH / "chn_jpn_yue_eng_ko_spectok.bpe.model"
VAD_CONFIG_DIR = MODEL_BASE_PATH # Assuming fsmn-config.yaml and fsmnvad-offline.onnx are here

# Global model instances, populated by load_models() at startup.
# They stay None on load failure; endpoints must check before use.
w_frontend: Optional[WavFrontend] = None
asr_model: Optional[SenseVoiceInferenceSession] = None
vad_model: Optional[FSMNVad] = None

@app.on_event("startup")
def load_models():
    """Load the feature frontend, ASR session and VAD model into module globals.

    Runs once at application startup. On any failure the globals remain None
    and /transcribe responds with HTTP 503 instead of crashing the server.
    """
    global w_frontend, asr_model, vad_model
    # Use the module logger (not the root logger) so messages honour the
    # per-logger configuration set up for this module.
    logger.info("Loading models...")
    start_time = time.time()
    try:
        if not MVN_PATH.exists():
            raise FileNotFoundError(f"CMVN file not found: {MVN_PATH}")
        w_frontend = WavFrontend(cmvn_file=str(MVN_PATH))

        if not EMBEDDING_NPY_PATH.exists() or not ENCODER_RKNN_PATH.exists() or not BPE_MODEL_PATH.exists():
            raise FileNotFoundError(
                f"One or more ASR model files not found: "
                f"Embedding: {EMBEDDING_NPY_PATH}, Encoder: {ENCODER_RKNN_PATH}, BPE: {BPE_MODEL_PATH}"
            )
        asr_model = SenseVoiceInferenceSession(
            embedding_model_file=str(EMBEDDING_NPY_PATH),
            encoder_model_file=str(ENCODER_RKNN_PATH),
            bpe_model_file=str(BPE_MODEL_PATH),
            # Assuming default device_id and num_threads as in sensevoice_rknn.py's main
            device_id=-1,
            intra_op_num_threads=4,
        )

        # Check for VAD model files (fsmn-config.yaml, fsmnvad-offline.onnx)
        if not (VAD_CONFIG_DIR / "fsmn-config.yaml").exists() or not (VAD_CONFIG_DIR / "fsmnvad-offline.onnx").exists():
            raise FileNotFoundError(f"VAD config or model not found in {VAD_CONFIG_DIR}")
        vad_model = FSMNVad(config_dir=str(VAD_CONFIG_DIR))

        logger.info(f"Models loaded successfully in {time.time() - start_time:.2f} seconds.")
    except FileNotFoundError as e:
        logger.error(f"Model loading failed: {e}")
        # Keep models as None, endpoints will raise errors
    except Exception as e:
        # logger.exception records the traceback, which plain .error() dropped.
        logger.exception(f"An unexpected error occurred during model loading: {e}")
        # Keep models as None

class TranscribeRequest(BaseModel):
    """Request body for POST /transcribe."""
    # Server-local path to the audio file; 16 kHz audio is expected downstream.
    audio_file_path: str
    language: str = "en"  # Default to English; should be a key of `languages`.
    use_itn: bool = False  # Whether to request inverse text normalization from the ASR model.

class Segment(BaseModel):
    """One VAD-detected speech segment with its transcription."""
    # Times are in seconds relative to the start of the audio file.
    start_time_s: float
    end_time_s: float
    text: str

class TranscribeResponse(BaseModel):
    """Structured transcription result.

    NOTE(review): currently unused — /transcribe declares response_model=str
    and returns only the joined text, although it builds Segment objects.
    """
    full_transcription: str
    segments: List[Segment]

@app.post("/transcribe", response_model=str)
async def transcribe_audio(request: TranscribeRequest):
    """Transcribe a server-local audio file and return the joined text.

    Pipeline: read audio -> per-channel VAD -> per-segment ASR -> join text.
    The response stays a plain string (response_model=str) for backward
    compatibility; per-segment detail is collected in `detailed_segments`
    but intentionally not returned (see TranscribeResponse).

    Raises:
        HTTPException 503: models not loaded.
        HTTPException 404: audio file missing.
        HTTPException 400: file unreadable as audio.
    """
    if w_frontend is None or asr_model is None or vad_model is None:
        logger.error("Models not loaded. Transcription cannot proceed.")
        raise HTTPException(status_code=503, detail="Models are not loaded. Please check server logs.")

    audio_path = Path(request.audio_file_path)
    if not audio_path.exists() or not audio_path.is_file():
        logger.error(f"Audio file not found: {audio_path}")
        raise HTTPException(status_code=404, detail=f"Audio file not found: {audio_path}")

    try:
        # always_2d=True gives shape (frames, channels) even for mono input.
        waveform, sample_rate = sf.read(
            str(audio_path),
            dtype="float32",
            always_2d=True
        )
    except Exception as e:
        logger.error(f"Error reading audio file {audio_path}: {e}")
        raise HTTPException(status_code=400, detail=f"Could not read audio file: {e}")

    if sample_rate != 16000:
        # Basic resampling could be added here if needed, or just raise an error.
        # For now we proceed but log a warning; the ms->sample math below assumes 16 kHz.
        logger.warning(f"Audio sample rate is {sample_rate}Hz, expected 16000Hz. Results may be suboptimal.")

    logger.info(f"Processing audio: {audio_path}, Duration: {len(waveform) / sample_rate:.2f}s, Channels: {waveform.shape[1]}")

    lang_code = languages.get(request.language.lower())
    if lang_code is None:
        logger.warning(f"Unsupported language: {request.language}. Defaulting to 'en'. Supported: {list(languages.keys())}")
        lang_code = languages.get("en", 0) # Fallback to 'en' or 'auto' if 'en' isn't in languages

    all_segments_text: List[str] = []
    detailed_segments: List[Segment] = []
    processing_start_time = time.time()

    for channel_id in range(waveform.shape[1]):
        channel_data = waveform[:, channel_id]
        logger.info(f"Processing channel {channel_id + 1}/{waveform.shape[1]}")

        try:
            # segments_offline expects a 1D array; channel_data is a single column.
            speech_segments = vad_model.segments_offline(channel_data)
        except Exception as e:
            logger.error(f"VAD processing failed for channel {channel_id}: {e}")
            # Skip this channel but keep processing the others.
            continue

        for part_idx, part in enumerate(speech_segments):
            start_sample = int(part[0] * 16)  # VAD returns ms, convert to samples (16 samples/ms for 16kHz)
            end_sample = int(part[1] * 16)
            segment_audio = channel_data[start_sample:end_sample]

            if len(segment_audio) == 0:
                logger.info(f"Empty audio segment for channel {channel_id}, part {part_idx}. Skipping.")
                continue

            try:
                audio_feats = w_frontend.get_features(segment_audio)
                # ASR model expects a batch dimension, hence [None, ...].
                asr_result_text_raw = asr_model(
                    audio_feats[None, ...],
                    language=lang_code,
                    use_itn=request.use_itn,
                )
                # Remove tags like <|en|>, <|HAPPY|>, etc.
                asr_result_text_cleaned = re.sub(r"<\|[^\|]+\|>", "", asr_result_text_raw).strip()

                segment_start_s = part[0] / 1000.0
                segment_end_s = part[1] / 1000.0
                logger.info(f"[Ch{channel_id}] [{segment_start_s:.2f}s - {segment_end_s:.2f}s] Raw: {asr_result_text_raw} Cleaned: {asr_result_text_cleaned}")
                all_segments_text.append(asr_result_text_cleaned)
                detailed_segments.append(Segment(start_time_s=segment_start_s, end_time_s=segment_end_s, text=asr_result_text_cleaned))
            except Exception as e:
                logger.error(f"ASR processing failed for segment {part_idx} in channel {channel_id}: {e}")
                # Record a placeholder so the segment timeline stays complete.
                detailed_segments.append(Segment(start_time_s=part[0]/1000.0, end_time_s=part[1]/1000.0, text="[ASR_ERROR]"))

        vad_model.vad.all_reset_detection() # Reset VAD state for next channel or call

    full_transcription = " ".join(all_segments_text).strip()
    logger.info(f"Transcription complete in {time.time() - processing_start_time:.2f}s. Result: {full_transcription}")

    return full_transcription

if __name__ == "__main__":
    import uvicorn

    MINIMAL_LOGGING_CONFIG = {
        "version": 1,
        "disable_existing_loggers": False, # Let other loggers (like our app logger) exist
        "formatters": {
            "default": {
                "()": "uvicorn.logging.DefaultFormatter",
                "fmt": "%(levelprefix)s %(message)s",
                "use_colors": None,
            },
        },
        "handlers": {
            "default": {
                "formatter": "default",
                "class": "logging.StreamHandler",
                "stream": "ext://sys.stderr",
            },
        },
        "loggers": {
            "uvicorn": { # Uvicorn's own operational logs
                "handlers": ["default"],
                "level": logging.INFO, # dictConfig accepts int levels
                "propagate": False,
            },
            "uvicorn.error": { # Logs for errors within Uvicorn
                "handlers": ["default"],
                "level": logging.INFO,
                "propagate": False,
            },
            # We are deliberately not configuring uvicorn.access here for simplicity.
            # BUG FIX: this application-logger entry previously sat at the TOP level
            # of the config dict, where logging.config.dictConfig silently ignores
            # it. It must live under "loggers" to take effect. __name__ is
            # "__main__" when run as a script, matching getLogger(__name__) above.
            __name__: {
                "handlers": ["default"],
                "level": logging.INFO,
                "propagate": False,
            },
        },
    }

    logger.info("Attempting to run Uvicorn with minimal explicit log_config.")
    uvicorn.run(app, host="0.0.0.0", port=8000, log_config=MINIMAL_LOGGING_CONFIG)