from typing import List, Optional, Union

import numpy as np

from transformers import Qwen2TokenizerFast, WhisperFeatureExtractor
from transformers.feature_extraction_utils import BatchFeature
from transformers.processing_utils import ProcessorMixin
from transformers.tokenization_utils_base import PaddingStrategy, PreTokenizedInput, TextInput


class AudioOnlyProcessor(ProcessorMixin):
    """
    A processor class for AudioOnlyThinker. Handles text + audio input only (no image or video
    support), wrapping a Whisper feature extractor and a Qwen2 tokenizer.
    """

    feature_extractor_class = "WhisperFeatureExtractor"
    tokenizer_class = ("Qwen2Tokenizer", "Qwen2TokenizerFast")
    model_input_names = ["input_features", "attention_mask", "input_ids", "feature_attention_mask"]
    def __init__(self, feature_extractor=None, tokenizer=None, chat_template=None):
        self.audio_token = "<|AUDIO|>"
        self.audio_bos_token = "<|audio_bos|>"
        self.audio_eos_token = "<|audio_eos|>"
        # Let ProcessorMixin register the components and the chat template.
        super().__init__(feature_extractor, tokenizer, chat_template=chat_template)
        self.current_processor = self.tokenizer

    def __call__(
        self,
        text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]],
        audios: Union[np.ndarray, List[np.ndarray]],
        sampling_rate: Optional[int] = 16000,
        padding: Union[bool, str, PaddingStrategy] = False,
        **kwargs,
    ) -> BatchFeature:
        if not isinstance(text, list):
            text = [text]

        audios_inputs = self.feature_extractor(
            audios, sampling_rate=sampling_rate, return_attention_mask=True, padding="max_length", **kwargs
        )
        # Rename the mask so it does not collide with the tokenizer's attention_mask.
        audios_inputs["feature_attention_mask"] = audios_inputs.pop("attention_mask")

        # Real (unpadded) mel-frame counts, then the sequence length after the audio
        # encoder's stride-2 convolution, then the number of audio tokens after the
        # additional stride-2 pooling.
        feature_lengths = np.asarray(audios_inputs["feature_attention_mask"]).sum(-1)
        input_lengths = (feature_lengths - 1) // 2 + 1
        audio_lengths = (input_lengths - 2) // 2 + 1

        # Expand each <|AUDIO|> marker into one placeholder per audio token, consuming
        # the per-audio lengths in order across the whole batch, then swap the
        # placeholders back so every position carries the real audio token.
        audio_lengths_iter = iter(audio_lengths.tolist())
        for i in range(len(text)):
            while self.audio_token in text[i]:
                text[i] = text[i].replace(
                    self.audio_token,
                    "<|audio_placeholder|>" * next(audio_lengths_iter),
                    1,
                )
            text[i] = text[i].replace("<|audio_placeholder|>", self.audio_token)

        text_inputs = self.tokenizer(text, padding=padding, return_tensors=kwargs.get("return_tensors"))

        return BatchFeature(data={**text_inputs, **audios_inputs}, tensor_type=kwargs.get("return_tensors"))

    def apply_chat_template(self, conversations, chat_template=None, **kwargs):
        # Accept a single conversation (a list of message dicts) as well as a batch of them.
        if isinstance(conversations[0], dict):
            conversations = [conversations]
        if chat_template is None:
            chat_template = self.chat_template
        return self.tokenizer.apply_chat_template(conversations, chat_template=chat_template, **kwargs)

    def batch_decode(self, *args, **kwargs):
        """Forwarded to the tokenizer's `batch_decode`."""
        return self.tokenizer.batch_decode(*args, **kwargs)

    def decode(self, *args, **kwargs):
        """Forwarded to the tokenizer's `decode`."""
        return self.tokenizer.decode(*args, **kwargs)

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
        tokenizer = Qwen2TokenizerFast.from_pretrained(pretrained_model_name_or_path, **kwargs)
        feature_extractor = WhisperFeatureExtractor.from_pretrained(pretrained_model_name_or_path, **kwargs)
        return cls(feature_extractor=feature_extractor, tokenizer=tokenizer)

    def save_pretrained(self, save_directory):
        self.tokenizer.save_pretrained(save_directory)
        self.feature_extractor.save_pretrained(save_directory)
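

# A minimal usage sketch, not part of the processor itself. The checkpoint name below is
# an assumption: any repo that ships a Qwen2 tokenizer together with a Whisper feature
# extractor (a Qwen2-Audio-style checkpoint) should behave the same way.
if __name__ == "__main__":
    processor = AudioOnlyProcessor.from_pretrained("Qwen/Qwen2-Audio-7B-Instruct")

    # One second of silence at 16 kHz yields 100 mel frames, so the <|AUDIO|> marker
    # expands to (100 - 1) // 2 + 1 = 50 encoder frames and then
    # (50 - 2) // 2 + 1 = 25 audio token positions.
    audio = np.zeros(16000, dtype=np.float32)
    prompt = "<|audio_bos|><|AUDIO|><|audio_eos|>Transcribe this clip."

    batch = processor(text=prompt, audios=audio, padding=True, return_tensors="pt")
    print(batch["input_ids"].shape)       # (1, text_len + 25)
    print(batch["input_features"].shape)  # (1, n_mels, 3000): mel features padded to 30 s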