gemma3n-audio-encoder-VQ-32k-whisper-decoder

This model combines the mesolitica/gemma-3n-e4b-it-audio-encoder encoder, a projection layer, vector quantization (VQ), a projection layer norm, and the openai/whisper-large-v3-turbo decoder.

This model introduces VQ on top of mesolitica/gemma3n-audio-encoder-whisper-decoder.

This is the most compressed speech token model in the series: 6.25 tokens per second (TPS) with a codebook size of 32768.
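For context, the effective bitrate follows directly from the two figures above (a quick arithmetic sketch; 6.25 TPS and the 32768-entry codebook are the numbers stated in this card):

```python
import math

tps = 6.25             # audio tokens emitted per second of speech
codebook_size = 32768  # VQ codebook entries (2**15)

bits_per_token = math.log2(codebook_size)  # 15 bits per token
bitrate = tps * bits_per_token             # bits per second of audio
print(bits_per_token, bitrate)  # 15.0 93.75
```

So each second of speech is represented with under 100 bits.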

WandB logs at https://wandb.ai/huseinzol05/gemma3n-audio-vq-whisper-decoder-v5

Training dataset

  1. malaysia-ai/common_voice_17_0
  2. mesolitica/Malaysian-STT-Whisper-Stage2/malaysian_multiturn_chat_assistants_segments
  3. mesolitica/Malaysian-STT-Whisper-Stage2/malaysian_multiturn_chat_assistants_manglish_segments

How to get audio tokens

from transformers import AutoFeatureExtractor, AutoModel, AutoTokenizer
import librosa

model_id = "mesolitica/gemma3n-audio-encoder-VQ-32k-whisper-decoder"
feature_extractor = AutoFeatureExtractor.from_pretrained(model_id)
model = AutoModel.from_pretrained(model_id, trust_remote_code = True, torch_dtype = 'auto').cuda()
encoder = model.model.get_encoder()
y, sr = librosa.load('common_voice_ba_26517811.mp3', sr = feature_extractor.sampling_rate)
features = feature_extractor([y], return_tensors = 'pt')
features['input_features'] = features['input_features'].cuda()
features['input_features_mask'] = features['input_features_mask'].cuda()
_, tokens = encoder(**features)
print(tokens)
tensor([ 4679, 20093,  8341,  7777, 21322, 30807,  3741, 10235,  4053,  6004,
        17969,  1095, 30875, 10580,  9639, 22731, 29890, 28581, 20118,  3688,
        29172,  3227, 23437, 22097, 11855, 13388,  8268, 17958, 18715],
       device='cuda:0')
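At 6.25 TPS the token count scales linearly with clip length, so the 29 tokens above would correspond to a clip of roughly 4.6 seconds. A hypothetical helper (not part of the repository) to estimate the expected count:

```python
import math

def expected_token_count(duration_s, tps=6.25):
    # Rough estimate assuming ~6.25 audio tokens per second of input,
    # as stated in this model card
    return math.ceil(duration_s * tps)

print(expected_token_count(4.64))  # 29
print(expected_token_count(60))    # 375 tokens for a one-minute clip
```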

How to decode

from transformers import AutoFeatureExtractor, AutoModel, AutoTokenizer
import librosa

model_id = "mesolitica/gemma3n-audio-encoder-VQ-32k-whisper-decoder"
feature_extractor = AutoFeatureExtractor.from_pretrained(model_id)
model = AutoModel.from_pretrained(model_id, trust_remote_code = True, torch_dtype = 'auto').cuda()
tokenizer = AutoTokenizer.from_pretrained(model_id)

y, sr = librosa.load('common_voice_ba_26517811.mp3', sr = feature_extractor.sampling_rate)
input_ids = tokenizer(
    '<|startoftranscript|><|ru|><|transcribe|><|notimestamps|>', 
    add_special_tokens = False, return_tensors = 'pt')['input_ids']
features = feature_extractor([y], return_tensors = 'pt')
features['input_features'] = features['input_features'].cuda()
features['input_features_mask'] = features['input_features_mask'].cuda()
features['attention_mask'] = features['input_features_mask']
features['decoder_input_ids'] = input_ids.cuda()

generate_kwargs = dict(
    **features,
    max_new_tokens=1024,
)
generation_output = model.generate(**generate_kwargs)
tokenizer.decode(generation_output[0])

Output,

<|startoftranscript|><|ru|><|transcribe|><|notimestamps|> Купыкта был широкое глобка шляпше на битапсы.<|endoftext|>
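To drop the special tokens, pass skip_special_tokens=True to tokenizer.decode, or strip them manually. A minimal sketch of the manual route on the string above:

```python
import re

transcript = ('<|startoftranscript|><|ru|><|transcribe|><|notimestamps|>'
              ' Купыкта был широкое глобка шляпше на битапсы.<|endoftext|>')

# Remove Whisper-style special tokens of the form <|...|>
clean = re.sub(r'<\|[^|]+\|>', '', transcript).strip()
print(clean)  # Купыкта был широкое глобка шляпше на битапсы.
```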

Evaluation

Evaluated on malaysia-ai/common_voice_17_0/test under the following conditions:

  1. Lowercase the text.
  2. Remove punctuation.
  3. Provide a language tag in the decoder input ids, <|startoftranscript|><|{lang}|><|transcribe|><|notimestamps|>.
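Conditions 1 and 2 can be sketched as a text normalization helper (an illustrative implementation, not necessarily the exact one used for evaluation; it strips ASCII punctuation only):

```python
import re
import string

def normalize(text):
    # 1. lowercase
    text = text.lower()
    # 2. remove (ASCII) punctuation
    text = text.translate(str.maketrans('', '', string.punctuation))
    # collapse any leftover whitespace
    return re.sub(r'\s+', ' ', text).strip()

print(normalize('Hello, World!'))  # hello world
```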

Source code

Source code at https://github.com/mesolitica/malaya-speech/tree/master/session/gemma3n-audio-whisper-decoder
