kimhyunwoo's picture
Update app.py
c386b23 verified
raw
history blame
1.91 kB
from transformers import AutoModelForCausalLM, AutoTokenizer, GemmaTokenizer
import torch
import os
import gradio as gr

# Gradio demo for a fine-tuned Gemma 2 Korean dialogue model.
# The whole setup is wrapped in try/except so load/import failures are
# printed to the Space logs instead of dying with no trace.
try:
    # Load the model and tokenizer.
    model_id = "kimhyunwoo/gemma2-ko-dialogue-lora-fp16"
    model = AutoModelForCausalLM.from_pretrained(model_id, trust_remote_code=True)

    # Load GemmaTokenizer directly instead of going through AutoTokenizer.
    tokenizer = GemmaTokenizer.from_pretrained(model_id, trust_remote_code=True)

    # The checkpoint is fp16; cast explicitly to float32 for CPU inference.
    model = model.to(torch.float32)
    model.eval()

    def generate_text(text, max_length=50, do_sample=True, temperature=1.0):
        """Generate a continuation of *text* with the loaded model.

        Args:
            text: Prompt string.
            max_length: Maximum number of NEW tokens to generate. Passed to
                ``generate`` as ``max_new_tokens`` — ``max_length`` there
                counts the prompt tokens too, so a long prompt would
                otherwise leave no budget for generated text.
            do_sample: Sample stochastically when True; greedy decode when False.
            temperature: Sampling temperature; only used when sampling.

        Returns:
            The decoded output (prompt included), special tokens stripped.
        """
        inputs = tokenizer(text, return_tensors="pt")
        gen_kwargs = {"max_new_tokens": max_length, "do_sample": do_sample}
        if do_sample:
            # temperature is ignored (with a warning) during greedy
            # decoding, so only forward it when sampling is on.
            gen_kwargs["temperature"] = temperature
        with torch.no_grad():
            outputs = model.generate(**inputs, **gen_kwargs)
        return tokenizer.decode(outputs[0], skip_special_tokens=True)

    # Build the Gradio interface. Widget defaults mirror the function
    # defaults (50 / True / 1.0) so the UI and the API behave the same.
    iface = gr.Interface(
        fn=generate_text,
        inputs=[
            gr.Textbox(lines=5, placeholder="ํ…์ŠคํŠธ๋ฅผ ์ž…๋ ฅํ•˜์„ธ์š”..."),
            gr.Slider(minimum=10, maximum=200, step=1, value=50, label="์ตœ๋Œ€ ๊ธธ์ด"),
            gr.Checkbox(label="์ƒ˜ํ”Œ๋ง", value=True),
            gr.Slider(minimum=0.1, maximum=2.0, step=0.1, value=1.0, label="์˜จ๋„"),
        ],
        outputs=gr.Textbox(lines=5, label="์ƒ์„ฑ๋œ ํ…์ŠคํŠธ"),
        title="Gemma 2 Text Generator",
        description="Fine-tuned๋œ Gemma 2 ๋ชจ๋ธ์„ ์‚ฌ์šฉํ•˜์—ฌ ํ…์ŠคํŠธ๋ฅผ ์ƒ์„ฑํ•ฉ๋‹ˆ๋‹ค.",
    )

    if __name__ == "__main__":
        # Bind to all interfaces on the standard HF Spaces port.
        iface.launch(server_name="0.0.0.0", server_port=7860)

except ImportError as e:
    print(f"ImportError ๋ฐœ์ƒ: {e}")
except Exception as e:
    print(f"์˜ˆ์ƒ์น˜ ๋ชปํ•œ ์˜ค๋ฅ˜ ๋ฐœ์ƒ: {e}")