# -*- coding: utf-8 -*-
"""Gemmaโ€‘3 Multimodal Utilities โ€“ Whisper v3 Frontend
====================================================
Pure *preprocessing* helpers for text / image / audio. **Zero model layers.**
Exports
-------
* `compute_audio_token_count` – Whisper v3 token-length helper.
* `Gemma3OmniProcessor` – unifies text + image + audio into an HF batch.
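Example
-------
A minimal usage sketch (the checkpoint path below is a placeholder, not a
published repo)::

    proc = Gemma3OmniProcessor.from_pretrained("path/to/gemma3-omni-checkpoint")
    batch = proc("Hello, Gemma!")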
"""
from __future__ import annotations
import numpy as np
from typing import Any, Dict, List, Optional, Tuple, Union
from transformers.image_utils import load_image
from transformers.processing_utils import ProcessorMixin
from transformers.utils import logging
from transformers import AutoProcessor
logger = logging.get_logger(__name__)
# ──────────────────────────────────────────────────────────────────────────
# Constants
# ──────────────────────────────────────────────────────────────────────────
SPECIAL_TOKENS: Dict[str, str] = {
"bos_token": "<bos>",
"eos_token": "<eos>",
"pad_token": "<pad>",
"unk_token": "<unk>",
"boi_token": "<start_of_image>",
"eoi_token": "<end_of_image>",
"image_token": "<image_soft_token>",
"boa_token": "<start_of_audio>",
"eoa_token": "<end_of_audio>",
"audio_token": "<audio_soft_token>",
}
WHISPER_SR = 16_000   # Whisper's fixed input sampling rate (Hz)
COMP_RATE = 2         # Whisper encoder halves the mel-frame count (stride-2 conv)
IMAGE_SEQ_LEN = 256   # soft tokens emitted per image
__all__ = [
"compute_audio_token_count",
"Gemma3OmniProcessor",
]
# ──────────────────────────────────────────────────────────────────────────
# Audio helpers
# ──────────────────────────────────────────────────────────────────────────
def compute_audio_token_count(mel_frames: int) -> int:
"""Whisper v3 token length = ceil(mel_frames / 2)."""
return (mel_frames + COMP_RATE - 1) // COMP_RATE
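# Worked example: Whisper pads a 30 s clip to 3000 log-mel frames (10 ms hop),
# so compute_audio_token_count(3000) == 1500 audio soft tokens.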
# ──────────────────────────────────────────────────────────────────────────
# Omni Processor
# ──────────────────────────────────────────────────────────────────────────
class Gemma3OmniProcessor(ProcessorMixin):
"""
็ตฑไธ€ text, image, audio ่ณ‡ๆ–™ๆต๏ผŒ็›ดๆŽฅๆŽก็”จๅฎ˜ๆ–น WhisperFeatureExtractorใ€‚
"""
attributes = ["image_processor", "audio_processor", "tokenizer"]
image_processor_class = "AutoImageProcessor"
audio_processor_class = "AutoFeatureExtractor"
tokenizer_class = "AutoTokenizer"
def __init__(
self,
image_processor,
audio_processor=None,
tokenizer=None,
*,
special_tokens: Optional[Dict[str, str]] = None,
image_seq_length: int = IMAGE_SEQ_LEN,
whisper_model_name: str = "openai/whisper-large-v3",
) -> None:
        # audio_processor defaults to the official WhisperFeatureExtractor;
        # resolve it before super().__init__ so ProcessorMixin never sees None.
        if audio_processor is None:
            audio_processor = AutoProcessor.from_pretrained(whisper_model_name).feature_extractor
        super().__init__(tokenizer=tokenizer, image_processor=image_processor, audio_processor=audio_processor)
        self.tokenizer = tokenizer
        self.image_processor = image_processor
        self.audio_processor = audio_processor
self.toks: Dict[str, str] = SPECIAL_TOKENS | (special_tokens or {})
        # Expose any missing special-token attributes on the tokenizer.
        for k, v in self.toks.items():
if not getattr(self.tokenizer, k, None):
setattr(self.tokenizer, k, v)
self.boi, self.eoi, self.img_tok = self.toks["boi_token"], self.toks["eoi_token"], self.toks["image_token"]
self.boa, self.eoa, self.aud_tok = self.toks["boa_token"], self.toks["eoa_token"], self.toks["audio_token"]
self.full_img_seq = f"{self.boi}{self.img_tok * image_seq_length}{self.eoi}"
self.img_id = self.tokenizer.convert_tokens_to_ids(self.img_tok)
self.aud_id = self.tokenizer.convert_tokens_to_ids(self.aud_tok)
# ---------------- helpers ----------------
    def _insert_img_tokens(self, s: str) -> str:
        # Expand each bare "<start_of_image>" marker into the full run:
        # boi + IMAGE_SEQ_LEN * "<image_soft_token>" + eoi.
        return s.replace(self.boi, self.full_img_seq)
# ---------------- main API ----------------
def apply_chat_template(
self,
example: Dict[str, Any],
*,
add_generation_prompt: bool = True,
tokenize: bool = False,
**tok_kwargs,
) -> Union[str, Dict[str, Any]]:
msgs = example["messages"] if isinstance(example, dict) else example
aud_payload = example.get("audio") if isinstance(example, dict) else None
nested_imgs, aud_arrs, parts = [], [], []
for m in msgs:
role = m.get("role")
content = m.get("content", [])
content = content if isinstance(content, list) else [content]
parts.append(f"<start_of_turn>{role}\n")
for c in content:
if isinstance(c, dict):
if c.get("type") == "image":
parts.append(self.boi)
nested_imgs.append([load_image(c.get("url") or c.get("image"))])
elif c.get("type") == "audio":
parts.append(self.boa)
ad = None
idx = c.get("index")
if aud_payload is not None and idx is not None:
if isinstance(aud_payload, list):
try:
ad = aud_payload[idx]
                                except Exception as e:
                                    raise KeyError(f"Cannot fetch audio list index {idx}: {e}") from e
elif isinstance(aud_payload, dict) and idx == 0:
ad = aud_payload
else:
raise TypeError(f"audio payload ๅž‹ๆ…‹ไธ็ฌฆ๏ผš{type(aud_payload)}๏ผŒindex={idx}")
elif "audio" in c:
ad = c["audio"]
else:
raise KeyError(f"audio ๆจ™่จ˜ไฝ†ๆ‰พไธๅˆฐ็ดขๅผ•ๆˆ–ๅ…งๅตŒ audioใ€‚content: {c}")
if ad is None or "array" not in ad:
raise KeyError(f"audio object ็ผบๅฐ‘ 'array' ๆฌ„ไฝ๏ผš{ad}")
wav = np.asarray(ad["array"], dtype=np.float32)
sr = ad.get("sampling_rate", WHISPER_SR)
aud_arrs.append((wav, sr))
                    elif c.get("type") == "text":
                        parts.append(c.get("text", ""))
                elif isinstance(c, str):
                    parts.append(c)
parts.append("<end_of_turn>\n")
        if add_generation_prompt:
            # NOTE: official Gemma chat templates name this turn "model";
            # this processor uses "assistant" instead.
            parts.append("<start_of_turn>assistant\n")
# audio token lengths
aud_lens = []
if aud_arrs:
            # Official WhisperFeatureExtractor: takes a list[np.ndarray] of
            # waveforms and pads/truncates them to a fixed mel length. It does
            # NOT resample, so inputs should already be at 16 kHz.
            waves, srs = zip(*aud_arrs)
            if any(sr != WHISPER_SR for sr in srs):
                logger.warning("Audio sampling rate != %d Hz; WhisperFeatureExtractor will not resample.", WHISPER_SR)
            fe = self.audio_processor(waves, sampling_rate=WHISPER_SR, padding="longest", truncation=True, return_attention_mask=True, return_tensors=None)
            # fe["input_features"]: (B, 128, n_frames); fe["attention_mask"]: (B, n_frames)
            # Token counts must be computed from the unmasked (valid) frames only.
            mel_lens = [int(x.sum()) for x in fe["attention_mask"]]
            aud_lens = [compute_audio_token_count(n) for n in mel_lens]
        # Replace each bare <start_of_audio> marker with its full soft-token run.
        filled, cursor = [], 0
for seg in parts:
if seg == self.boa:
n = aud_lens[cursor] if cursor < len(aud_lens) else 0
cursor += 1
filled.append(f"{self.boa}{self.aud_tok * n}{self.eoa}")
else:
filled.append(seg)
prompt = "".join(filled)
if nested_imgs:
prompt = self._insert_img_tokens(prompt)
if not tokenize:
return prompt
toks = self.tokenizer(prompt, return_tensors="pt", **tok_kwargs)
toks["token_type_ids"] = (toks["input_ids"] == self.img_id).long()
if nested_imgs and self.image_processor is not None:
imgs = [img for batch in nested_imgs for img in batch]
toks.update(self.image_processor(imgs, return_tensors="pt"))
        if aud_arrs:
            waves, srs = zip(*aud_arrs)
            fe = self.audio_processor(waves, sampling_rate=WHISPER_SR, padding="longest", truncation=True, return_attention_mask=True, return_tensors="pt")
            # Use the names advertised in model_input_names and avoid clobbering
            # the tokenizer's text attention_mask.
            toks["input_audio_embeds"] = fe["input_features"]
            toks["audio_attention_mask"] = fe["attention_mask"]
return toks
# ---------------- convenience ----------------
def __call__(self, text: str, *, return_tensors: str = "pt"):
dummy = {"messages": [{"role": "user", "content": [{"type": "text", "text": text}]}]}
return self.apply_chat_template(dummy, add_generation_prompt=False, tokenize=True, return_tensors=return_tensors)
# ---------------- model_input_names ----------------
@property
def model_input_names(self) -> List[str]:
names = set(self.tokenizer.model_input_names)
names.update(["token_type_ids", "input_audio_embeds", "audio_attention_mask", "pixel_values"])
return sorted(names)
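# ──────────────────────────────────────────────────────────────────────────
# Smoke test – a minimal sketch; the checkpoint IDs below are illustrative
# stand-ins, not pinned dependencies of this repo.
# ──────────────────────────────────────────────────────────────────────────
if __name__ == "__main__":
    from transformers import AutoImageProcessor, AutoTokenizer

    proc = Gemma3OmniProcessor(
        AutoImageProcessor.from_pretrained("google/siglip-so400m-patch14-384"),
        tokenizer=AutoTokenizer.from_pretrained("google/gemma-3-1b-it"),
    )
    batch = proc("Hello, Gemma!")
    print({k: tuple(v.shape) if hasattr(v, "shape") else v for k, v in batch.items()})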