ko-gemma2-9B-sentiment

An LLM for sentiment classification of Korean YouTube comments (LoRA fine-tuned from Gemma2)

Overview

ko-gemma2-9B-sentiment is a model fine-tuned from Google's Gemma2 9B with LoRA-based PEFT to classify the sentiment of Korean YouTube comments.
Using a Chain of Thought (CoT) prompt design together with a summary of the YouTube video, it predicts one of three sentiment classes: '๊ธ์ •' (positive), '์ค‘๋ฆฝ' (neutral), or '๋ถ€์ •' (negative).

๋ณธ ๋ชจ๋ธ์€ ๋‹ค์Œ๊ณผ ๊ฐ™์€ ํŠน์„ฑ์„ ๊ฐ€์ง‘๋‹ˆ๋‹ค:

  • Gemma2 ๋Œ€ํ™” ํฌ๋งท (<start_of_turn>user, <end_of_turn>, <start_of_turn>model)
  • 4๋น„ํŠธ ์–‘์žํ™” + LoRA๋กœ ๊ฒฝ๋Ÿ‰ํ™”๋œ ํ•™์Šต
  • CoT + Multimodal Prompt (์˜์ƒ ์š”์•ฝ ์ •๋ณด ํฌํ•จ ๊ฐ€๋Šฅ)
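
A minimal sketch of how such a prompt can be assembled, assuming the instruction wording shown under Example output below; the optional '์˜์ƒ ์š”์•ฝ' (video summary) field name and placement are illustrative assumptions, not documented parts of the training format.

def build_prompt(comment, video_summary=None):
    # Instruction text taken from the example output shown later in this card.
    instruction = (
        "๋‹ค์Œ ์œ ํŠœ๋ธŒ ๋Œ“๊ธ€์˜ ๊ฐ์ •์„ ๋ถ„์„ํ•˜๊ณ , '๊ธ์ •', '์ค‘๋ฆฝ', '๋ถ€์ •' ์ค‘ ์–ด๋””์— ํ•ด๋‹นํ•˜๋Š”์ง€ ๋ถ„๋ฅ˜ํ•˜๊ณ , "
        "๋˜ํ•œ ์™œ ๊ทธ๋ ‡๊ฒŒ ๋ถ„๋ฅ˜ํ–ˆ๋Š”์ง€ ๊ฐ์ • ๋ถ„๋ฅ˜์˜ ์ด์œ  ๋ฐ ๊ทผ๊ฑฐ๋„ ์„œ์ˆ ํ•ด์ฃผ์„ธ์š”."
    )
    # Hypothetical slot for the video summary; omit it for comment-only prompts.
    summary_block = f"์˜์ƒ ์š”์•ฝ: {video_summary}\n\n" if video_summary else ""
    return (
        "<start_of_turn>user\n"
        f"{instruction}\n\n"
        f"{summary_block}"
        f"๋Œ“๊ธ€: {comment}\n"
        "<end_of_turn>\n"
        "<start_of_turn>model\n"
    )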

Quickstart

Install

$ pip install transformers peft accelerate

Inference

from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import PeftModel
import torch

base_model = "rtzr/ko-gemma-2-9b-it"
adapter_path = "./ko-gemma2-9B-sentiment"

prompt = """<start_of_turn>user
๋Œ“๊ธ€: ์ด ์˜์ƒ ์ •๋ง ๊ฐ๋™์ด์—ˆ์Šต๋‹ˆ๋‹ค. ๋ˆˆ๋ฌผ์ด ๋‚ฌ์–ด์š”.
<end_of_turn>
<start_of_turn>model
"""

# Load the base model in float16 and attach the LoRA adapter.
tokenizer = AutoTokenizer.from_pretrained(base_model, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(base_model, torch_dtype=torch.float16, device_map="auto", trust_remote_code=True)
model = PeftModel.from_pretrained(model, adapter_path)
model.eval()

# Greedy decoding; the printed text contains the echoed prompt followed by the model's reply.
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
outputs = model.generate(**inputs, max_new_tokens=1024, do_sample=False)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
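
Because decode(outputs[0], ...) returns the prompt followed by the reply, an optional tweak (not part of the original snippet) is to decode only the newly generated tokens:

# Keep only the tokens generated after the prompt.
reply = tokenizer.decode(outputs[0][inputs["input_ids"].shape[-1]:], skip_special_tokens=True)
print(reply)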

Example output

๋‹ค์Œ ์œ ํŠœ๋ธŒ ๋Œ“๊ธ€์˜ ๊ฐ์ •์„ ๋ถ„์„ํ•˜๊ณ , '๊ธ์ •', '์ค‘๋ฆฝ', '๋ถ€์ •' ์ค‘ ์–ด๋””์— ํ•ด๋‹นํ•˜๋Š”์ง€ ๋ถ„๋ฅ˜ํ•˜๊ณ , ๋˜ํ•œ ์™œ ๊ทธ๋ ‡๊ฒŒ ๋ถ„๋ฅ˜ํ–ˆ๋Š”์ง€ ๊ฐ์ • ๋ถ„๋ฅ˜์˜ ์ด์œ  ๋ฐ ๊ทผ๊ฑฐ๋„ ์„œ์ˆ ํ•ด์ฃผ์„ธ์š”.

๋Œ“๊ธ€: ์ •๋ง ๊ฐ๋™์ด์—ˆ์Šต๋‹ˆ๋‹ค. ๋ˆˆ๋ฌผ์ด ๋‚ฌ์–ด์š”. ์ž˜ ๋ณด๊ณ  ๊ฐ‘๋‹ˆ๋‹ค~


๋Œ“๊ธ€์„ ๋ถ„์„ํ•œ ๊ฒฐ๊ณผ, ์ด ๋Œ“๊ธ€์˜ ๊ฐ์ •์€ '๊ธ์ •'์ž…๋‹ˆ๋‹ค. 

Training Details

๋ชจ๋ธ ๋ฐ ํ™˜๊ฒฝ ๊ตฌ์„ฑ

  • Base Model: rtzr/ko-gemma-2-9b-it
  • Trainer: Hugging Face Trainer + LoRA
  • Quantization: 4bit (nf4, float16 compute)
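
A minimal sketch of a 4-bit (nf4, float16 compute) load via the standard bitsandbytes integration in transformers; the exact arguments of the original training run are assumptions.

from transformers import AutoModelForCausalLM, BitsAndBytesConfig
import torch

# nf4 4-bit quantization with float16 compute, matching the settings listed above.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float16,
)
model = AutoModelForCausalLM.from_pretrained(
    "rtzr/ko-gemma-2-9b-it",
    quantization_config=bnb_config,
    device_map="auto",
)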

LoRA configuration

  • target_modules: q_proj, k_proj, v_proj, o_proj, gate_proj, up_proj, down_proj
  • r = 8, alpha = 16, dropout = 0.05
  • gradient_checkpointing = True (a LoraConfig sketch follows this list)
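
A minimal LoraConfig sketch matching the settings above; the task type and the prepare_model_for_kbit_training call reflect the usual QLoRA recipe and are assumptions rather than documented details of this run.

from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training

lora_config = LoraConfig(
    r=8,
    lora_alpha=16,
    lora_dropout=0.05,
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj",
                    "gate_proj", "up_proj", "down_proj"],
    task_type="CAUSAL_LM",  # assumed task type
)

# `model` is the 4-bit quantized base model from the loading sketch above.
model = prepare_model_for_kbit_training(model, use_gradient_checkpointing=True)
model = get_peft_model(model, lora_config)
model.print_trainable_parameters()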

๋ฐ์ดํ„ฐ์…‹ ์ •๋ณด

  • Train ๋ฐ์ดํ„ฐ ์ˆ˜ : 3,658๊ฐœ
  • Validation ๋ฐ์ดํ„ฐ ์ˆ˜ : 921๊ฐœ

๊ฐ์ • ๋ ˆ์ด๋ธ” ๋ถ„ํฌ

Train
  • ๊ธ์ •: 1,012๊ฐœ (27.67%)
  • ์ค‘๋ฆฝ: 909๊ฐœ (24.85%)
  • ๋ถ€์ •: 1,737๊ฐœ (47.48%)
Validation
  • ๊ธ์ •: 268๊ฐœ (29.10%)
  • ์ค‘๋ฆฝ: 233๊ฐœ (25.30%)
  • ๋ถ€์ •: 420๊ฐœ (45.60%)

Results


Fine-tuned Performance

Confusion Matrix

[confusion matrix figure]

Classification Report

              precision    recall  f1-score   support

          ๊ธ์ •       0.89      0.85      0.87       574
          ์ค‘๋ฆฝ       0.46      0.52      0.49       169
          ๋ถ€์ •       0.70      0.70      0.70       246
    accuracy                           0.76       989
   macro avg       0.68      0.69      0.69       989
weighted avg       0.77      0.76      0.76       989
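
The table above follows scikit-learn's classification_report layout. A minimal sketch of how such a report can be reproduced from reference and predicted labels; y_true and y_pred are illustrative variable names holding '๊ธ์ •'/'์ค‘๋ฆฝ'/'๋ถ€์ •' strings collected over the evaluation set.

from sklearn.metrics import classification_report, confusion_matrix

# y_true / y_pred: lists of gold and predicted labels for the evaluation set.
labels = ["๊ธ์ •", "์ค‘๋ฆฝ", "๋ถ€์ •"]
print(classification_report(y_true, y_pred, labels=labels, digits=2))
print(confusion_matrix(y_true, y_pred, labels=labels))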

Contact
