---
datasets:
  - kresnik/zeroth_korean
metrics:
  - bleu
  - cer
base_model:
  - microsoft/Phi-4-multimodal-instruct
model-index:
  - name: Phi-4-mm-inst-zeroth-kor
    results:
      - task:
          type: speech-to-text-translation
        dataset:
          type: seastar105/fleurs_ko_en_test
          name: fleurs (ko-en test intersection)
        metrics:
          - type: bleu
            name: ko2en
            value: 7.07
          - type: bleu
            name: ko2en-cot
            value: 9.19
          - type: bleu
            name: en2ko (ko-mecab)
            value: 13.08
          - type: bleu
            name: en2ko-cot (ko-mecab)
            value: 9.35
      - task:
          type: automatic-speech-recognition
        dataset:
          type: kresnik/zeroth_korean
          name: zeroth_korean test
        metrics:
          - type: cer
            name: test CER
            value: 7.02
language:
  - ko
---

This model was fine-tuned from microsoft/Phi-4-multimodal-instruct on the kresnik/zeroth_korean dataset for only one epoch.

The fine-tuning script is here, adapted from an example in the Phi-4 repository.

The model was trained for only 174 steps on the zeroth train set; the main purpose was to check whether training on Korean ASR alone can transfer to other speech tasks (e.g., speech-to-text translation).
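
The step count is consistent with one epoch at a modest batch size. A minimal sanity check, assuming an effective batch size of 128 (an assumption chosen to match the reported steps, not a documented training setting) and the zeroth train split size of 22,263 utterances:

```python
import math

# zeroth_korean's train split has 22,263 utterances; 128 is an assumed
# effective batch size, not a documented training setting.
print(math.ceil(22_263 / 128))  # -> 174 optimizer steps for one epoch
```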

## Evaluation

ASR is evaluated on the zeroth test set, and speech translation on the fleurs ko <-> en test intersection. The evaluation script is here; all runs used a single A40 GPU.

| Model | zeroth-test (CER %) | fleurs-ko2en (BLEU) | fleurs-ko2en-cot (BLEU) | fleurs-en2ko (BLEU) | fleurs-en2ko-cot (BLEU) |
|---|---|---|---|---|---|
| original | 195.92 | 5.62 | 2.45 | 6.87 | 4.35 |
| finetune (this model) | 7.02 | 7.07 | 9.19 | 13.08 | 9.35 |
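
The zeroth-test column is CER (lower is better); the fleurs columns are BLEU (higher is better), with Korean-side BLEU tokenized by mecab. As a rough sketch, such metrics can be computed with the common `evaluate` and `sacrebleu` libraries; the snippet below illustrates the metrics and is not necessarily the exact evaluation script:

```python
# Toy metric computation; not the exact evaluation script.
import evaluate                      # pip install evaluate jiwer
from sacrebleu.metrics import BLEU   # pip install sacrebleu[ko] for ko-mecab

preds = ["안녕하세요 만나서 반갑습니다"]  # model outputs (toy data)
refs = ["안녕하세요 만나서 반갑습니다"]   # references (toy data)

# CER, as in the zeroth-test column
cer = evaluate.load("cer")
print(cer.compute(predictions=preds, references=refs))

# BLEU into Korean, tokenized with mecab as in the "(ko-mecab)" rows
bleu_ko = BLEU(tokenize="ko-mecab")
print(bleu_ko.corpus_score(preds, [refs]).score)
```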

## Example script

```python
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoProcessor, GenerationConfig

orig_model_path = "microsoft/Phi-4-multimodal-instruct"
ft_model_path = "seastar105/Phi-4-mm-inst-zeroth-kor"
max_new_tokens = 256  # generation length cap used below; any reasonable value works
generation_config = GenerationConfig.from_pretrained(orig_model_path, 'generation_config.json')
processor = AutoProcessor.from_pretrained(orig_model_path, trust_remote_code=True)

model = AutoModelForCausalLM.from_pretrained(
    ft_model_path,
    trust_remote_code=True,
    torch_dtype='auto',
    _attn_implementation='flash_attention_2',
).cuda()

user_prompt = '<|user|>'
assistant_prompt = '<|assistant|>'
prompt_suffix = '<|end|>'

# task prompt is from technical report
asr_prompt = f'{user_prompt}<|audio_1|>Transcribe the audio clip into text.{prompt_suffix}{assistant_prompt}'
ast_ko_prompt = f'{user_prompt}<|audio_1|>Translate the audio to Korean.{prompt_suffix}{assistant_prompt}'
ast_cot_ko_prompt = f'{user_prompt}<|audio_1|>Transcribe the audio to text, and then translate the audio to Korean. Use <sep> as a separator between the original transcript and the translation.{prompt_suffix}{assistant_prompt}'
ast_en_prompt = f'{user_prompt}<|audio_1|>Translate the audio to English.{prompt_suffix}{assistant_prompt}'
ast_cot_en_prompt = f'{user_prompt}<|audio_1|>Transcribe the audio to text, and then translate the audio to English. Use <sep> as a separator between the original transcript and the translation.{prompt_suffix}{assistant_prompt}'

asr_ds = load_dataset("kresnik/zeroth_korean", split="test")          # Korean ASR test set
ast_ds = load_dataset("seastar105/fleurs_ko_en_test", split="train")  # fleurs ko<->en test intersection

# ASR
item = asr_ds[0]
audio = (item["audio"]["array"], item["audio"]["sampling_rate"])
inputs = processor(text=asr_prompt, audios=[audio], return_tensors='pt').to(model.device)
generate_ids = model.generate(
    **inputs,
    max_new_tokens=max_new_tokens,
    generation_config=generation_config,
)
generate_ids = generate_ids[:, inputs['input_ids'].shape[1] :]
response = processor.batch_decode(
    generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
)[0] # "몬토킬은 자녀들이 사랑을 제대로 못 받고 크면 매우 심각한 결과가 초래된다는 결론을 내렸습니다"

# AST, EN -> KO
item = ast_ds[-1]
audio = (item["en_audio"]["array"], item["en_audio"]["sampling_rate"])
inputs = processor(text=ast_ko_prompt, audios=[audio], return_tensors='pt').to(model.device)  # Korean-target prompt for EN -> KO
generate_ids = model.generate(
    **inputs,
    max_new_tokens=max_new_tokens,
    generation_config=generation_config,
)
generate_ids = generate_ids[:, inputs['input_ids'].shape[1] :]
response = processor.batch_decode(
    generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
)[0] # "가장 쉽게 접근 가능한 식물 자원은 잎과 légumes에서 접근 가능한 단백질이었을 것이다가요 하지만 이것들은 고형상 동물처럼 우리에게 소화하기 어렵습니다만 그것들이 끓여 있다면요"
```
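
For the CoT prompts, the response contains the transcript and the translation joined by `<sep>`, so the translation can be recovered with a plain string split. This is a convenience snippet, not part of the original script:

```python
# CoT output format: "<transcript> <sep> <translation>"
translation = response.split("<sep>")[-1].strip()
```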