|
|
|
"""Gemmaโ3 Multimodal Utilities โ Whisper v3 Frontend |
|
==================================================== |
|
Pure *preprocessing* helpers for text / image / audio. **Zero model layers.** |
|
|
|
Exports |
|
------- |
|
* `compute_audio_token_count` โ Whisper v3 tokenโlength helper. |
|
* `Gemma3OmniProcessor` โ unify text + image + audio into HF batch. |
|
""" |
|
from __future__ import annotations

from typing import Any, Dict, List, Optional, Union

import numpy as np

from transformers import AutoProcessor
from transformers.image_utils import load_image
from transformers.processing_utils import ProcessorMixin
from transformers.utils import logging

logger = logging.get_logger(__name__)

SPECIAL_TOKENS: Dict[str, str] = {
    "bos_token": "<bos>",
    "eos_token": "<eos>",
    "pad_token": "<pad>",
    "unk_token": "<unk>",
    "boi_token": "<start_of_image>",
    "eoi_token": "<end_of_image>",
    "image_token": "<image_soft_token>",
    "boa_token": "<start_of_audio>",
    "eoa_token": "<end_of_audio>",
    "audio_token": "<audio_soft_token>",
}

WHISPER_SR = 16_000  # Whisper consumes 16 kHz mono audio.
COMP_RATE = 2        # Whisper v3's encoder halves the mel-frame count.
IMAGE_SEQ_LEN = 256  # Soft tokens emitted per image.

__all__ = [
    "compute_audio_token_count",
    "Gemma3OmniProcessor",
]
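
# For orientation: with the defaults above, a single-image, single-audio user
# turn renders schematically as below (N audio soft tokens depend on the clip
# length; line breaks here are for readability only):
#
#   <start_of_turn>user
#   <start_of_image>{<image_soft_token> x 256}<end_of_image>
#   <start_of_audio>{<audio_soft_token> x N}<end_of_audio>
#   Describe both.<end_of_turn>
#   <start_of_turn>assistant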
|
|
|
|
|
|
|
|
|
|
|
def compute_audio_token_count(mel_frames: int) -> int:
    """Whisper v3 token length = ceil(mel_frames / 2)."""
    return (mel_frames + COMP_RATE - 1) // COMP_RATE
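
# Illustration of the frame-to-token arithmetic: Whisper v3's full 30 s window
# yields 3000 mel frames, so a fully padded clip contributes 1500 soft tokens;
# a 10 s clip (1000 frames) contributes 500, and odd frame counts round up.
def _self_check_audio_token_count() -> None:
    # Side-effect-free sanity check; call manually if desired.
    assert compute_audio_token_count(3000) == 1500
    assert compute_audio_token_count(1000) == 500
    assert compute_audio_token_count(999) == 500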
|
|
|
|
|
|
|
|
|
|
|
class Gemma3OmniProcessor(ProcessorMixin):
    """
    Unify the text / image / audio data streams, using the official
    WhisperFeatureExtractor directly.
    """

    attributes = ["image_processor", "audio_processor", "tokenizer"]
    image_processor_class = "AutoImageProcessor"
    audio_processor_class = "AutoFeatureExtractor"
    tokenizer_class = "AutoTokenizer"

    def __init__(
        self,
        image_processor,
        audio_processor=None,
        tokenizer=None,
        *,
        special_tokens: Optional[Dict[str, str]] = None,
        image_seq_length: int = IMAGE_SEQ_LEN,
        whisper_model_name: str = "openai/whisper-large-v3",
    ) -> None:
        # Resolve the audio processor *before* handing attributes to
        # ProcessorMixin.__init__, which type-checks every declared attribute
        # (a None audio_processor would fail that check). The mixin also
        # assigns self.image_processor / self.audio_processor / self.tokenizer.
        if audio_processor is None:
            audio_processor = AutoProcessor.from_pretrained(whisper_model_name).feature_extractor
        super().__init__(
            tokenizer=tokenizer,
            image_processor=image_processor,
            audio_processor=audio_processor,
        )

        # Register any special tokens the tokenizer does not already define.
        self.toks: Dict[str, str] = SPECIAL_TOKENS | (special_tokens or {})
        for k, v in self.toks.items():
            if not getattr(self.tokenizer, k, None):
                setattr(self.tokenizer, k, v)

        self.boi, self.eoi, self.img_tok = self.toks["boi_token"], self.toks["eoi_token"], self.toks["image_token"]
        self.boa, self.eoa, self.aud_tok = self.toks["boa_token"], self.toks["eoa_token"], self.toks["audio_token"]
        self.full_img_seq = f"{self.boi}{self.img_tok * image_seq_length}{self.eoi}"

        self.img_id = self.tokenizer.convert_tokens_to_ids(self.img_tok)
        self.aud_id = self.tokenizer.convert_tokens_to_ids(self.aud_tok)

    def _insert_img_tokens(self, s: str) -> str:
        # Expand each bare <start_of_image> marker into the full soft-token run.
        return s.replace(self.boi, self.full_img_seq)
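
    # Sketch of the `example` schema that apply_chat_template accepts, as read
    # from the implementation below (illustrative, not a documented contract):
    #
    #   {
    #     "messages": [
    #       {"role": "user", "content": [
    #         {"type": "image", "url": "<path or URL>"},
    #         {"type": "audio", "index": 0},   # resolved against example["audio"]
    #         {"type": "text", "text": "Describe both."},
    #       ]},
    #     ],
    #     "audio": {"array": <1-D float waveform>, "sampling_rate": 16_000},
    #   }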
|
|
|
|
|
    def apply_chat_template(
        self,
        example: Dict[str, Any],
        *,
        add_generation_prompt: bool = True,
        tokenize: bool = False,
        **tok_kwargs,
    ) -> Union[str, Dict[str, Any]]:
        msgs = example["messages"] if isinstance(example, dict) else example
        aud_payload = example.get("audio") if isinstance(example, dict) else None

        nested_imgs, aud_arrs, parts = [], [], []
        for m in msgs:
            role = m.get("role")
            content = m.get("content", [])
            content = content if isinstance(content, list) else [content]
            parts.append(f"<start_of_turn>{role}\n")
            for c in content:
                if isinstance(c, dict):
                    if c.get("type") == "image":
                        parts.append(self.boi)
                        nested_imgs.append([load_image(c.get("url") or c.get("image"))])
                    elif c.get("type") == "audio":
                        parts.append(self.boa)
                        ad = None
                        idx = c.get("index")
                        if aud_payload is not None and idx is not None:
                            if isinstance(aud_payload, list):
                                try:
                                    ad = aud_payload[idx]
                                except Exception as e:
                                    raise KeyError(f"audio list index {idx} could not be retrieved: {e}")
                            elif isinstance(aud_payload, dict) and idx == 0:
                                ad = aud_payload
                            else:
                                raise TypeError(f"unexpected audio payload type {type(aud_payload)} for index={idx}")
                        elif "audio" in c:
                            ad = c["audio"]
                        else:
                            raise KeyError(f"audio part carries neither an index nor inline audio. content: {c}")

                        if ad is None or "array" not in ad:
                            raise KeyError(f"audio object is missing the 'array' field: {ad}")
                        wav = np.asarray(ad["array"], dtype=np.float32)
                        sr = ad.get("sampling_rate", WHISPER_SR)
                        aud_arrs.append((wav, sr))
                    elif c.get("type") == "text":
                        # Plain-text parts arrive as {"type": "text", "text": ...};
                        # without this branch they would be silently dropped.
                        parts.append(c.get("text", ""))
                elif isinstance(c, str):
                    parts.append(c)
            parts.append("<end_of_turn>\n")
        if add_generation_prompt:
            parts.append("<start_of_turn>assistant\n")

        # First feature-extractor pass: only to learn each clip's mel-frame
        # count, so the right number of <audio_soft_token>s can be inlined.
        aud_lens = []
        if aud_arrs:
            waves, srs = zip(*aud_arrs)
            if any(s != WHISPER_SR for s in srs):
                logger.warning("Audio not sampled at %d Hz is passed through unresampled.", WHISPER_SR)
            fe = self.audio_processor(
                list(waves),
                sampling_rate=WHISPER_SR,
                padding="longest",
                truncation=True,
                return_attention_mask=True,  # needed to recover per-clip lengths
                return_tensors=None,
            )
            mel_lens = [int(x.sum()) for x in fe["attention_mask"]]
            aud_lens = [compute_audio_token_count(n) for n in mel_lens]
        filled, cursor = [], 0
        for seg in parts:
            if seg == self.boa:
                n = aud_lens[cursor] if cursor < len(aud_lens) else 0
                cursor += 1
                filled.append(f"{self.boa}{self.aud_tok * n}{self.eoa}")
            else:
                filled.append(seg)
        prompt = "".join(filled)
        if nested_imgs:
            prompt = self._insert_img_tokens(prompt)
        if not tokenize:
            return prompt

        # setdefault avoids a duplicate-kwarg error when callers pass
        # return_tensors through tok_kwargs (as __call__ does).
        tok_kwargs.setdefault("return_tensors", "pt")
        toks = self.tokenizer(prompt, **tok_kwargs)
        toks["token_type_ids"] = (toks["input_ids"] == self.img_id).long()

        if nested_imgs and self.image_processor is not None:
            imgs = [img for batch in nested_imgs for img in batch]
            toks.update(self.image_processor(imgs, return_tensors="pt"))
        if aud_arrs:
            waves, _ = zip(*aud_arrs)
            toks.update(
                self.audio_processor(
                    list(waves),
                    sampling_rate=WHISPER_SR,
                    padding="longest",
                    truncation=True,
                    return_attention_mask=True,
                    return_tensors="pt",
                )
            )
        return toks

    def __call__(self, text: str, *, return_tensors: str = "pt"):
        # Convenience path: wrap bare text in a single user turn and tokenize.
        dummy = {"messages": [{"role": "user", "content": [{"type": "text", "text": text}]}]}
        return self.apply_chat_template(dummy, add_generation_prompt=False, tokenize=True, return_tensors=return_tensors)

    @property
    def model_input_names(self) -> List[str]:
        names = set(self.tokenizer.model_input_names)
        names.update(["token_type_ids", "input_audio_embeds", "audio_attention_mask", "pixel_values"])
        return sorted(names)
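
# --- Usage sketch -------------------------------------------------------------
# A minimal end-to-end run. Assumes network access; the checkpoint names below
# are illustrative stand-ins, so substitute whichever tokenizer / image
# processor your Gemma-3 variant actually uses.
if __name__ == "__main__":
    from transformers import AutoImageProcessor, AutoTokenizer

    proc = Gemma3OmniProcessor(
        image_processor=AutoImageProcessor.from_pretrained("google/siglip-base-patch16-224"),
        tokenizer=AutoTokenizer.from_pretrained("google/gemma-3-4b-it"),
    )
    example = {
        "messages": [
            {
                "role": "user",
                "content": [
                    {"type": "audio", "index": 0},
                    {"type": "text", "text": "Transcribe the clip above."},
                ],
            }
        ],
        # One second of silence at 16 kHz stands in for a real waveform.
        "audio": {"array": np.zeros(WHISPER_SR, dtype=np.float32), "sampling_rate": WHISPER_SR},
    }
    print(proc.apply_chat_template(example))  # the rendered prompt string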
|
|