kimhyunwoo committed
Commit c386b23 · verified · 1 Parent(s): 2a40503

Update app.py

Files changed (1):
  1. app.py +36 -31
app.py CHANGED
@@ -1,39 +1,44 @@
 from transformers import AutoModelForCausalLM, AutoTokenizer, GemmaTokenizer
 import torch
 import os
+import gradio as gr
 
-# Load the model and tokenizer
-model_id = "kimhyunwoo/gemma2-ko-dialogue-lora-fp16"
-model = AutoModelForCausalLM.from_pretrained(model_id, trust_remote_code=True)
-# Load GemmaTokenizer directly instead of AutoTokenizer.
-tokenizer = GemmaTokenizer.from_pretrained(model_id, trust_remote_code=True)
-
-# When loading and running the model on CPU, torch.float32 can be set explicitly.
-model = model.to(torch.float32)
-model.eval()
-
-# Text generation function
-def generate_text(text, max_length=50, do_sample=True, temperature=1.0):
-    inputs = tokenizer(text, return_tensors="pt")
-    with torch.no_grad():
-        outputs = model.generate(**inputs, max_length=max_length, do_sample=do_sample, temperature=temperature)
-    generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
-    return generated_text
-
-# Create the Gradio interface
-import gradio as gr
-iface = gr.Interface(
-    fn=generate_text,
-    inputs=[
-        gr.Textbox(lines=5, placeholder="Enter text..."),
-        gr.Slider(minimum=10, maximum=200, step=1, value=50, label="Max length"),
-        gr.Checkbox(label="Sampling"),
-        gr.Slider(minimum=0.1, maximum=2.0, step=0.1, value=1.0, label="Temperature"),
-    ],
-    outputs=gr.Textbox(lines=5, label="Generated text"),
-    title="Gemma 2 Text Generator",
-    description="Generates text using the fine-tuned Gemma 2 model.",
-)
-
-if __name__ == "__main__":
-    iface.launch(server_name="0.0.0.0", server_port=7860)
+try:
+    # Load the model and tokenizer
+    model_id = "kimhyunwoo/gemma2-ko-dialogue-lora-fp16"
+    model = AutoModelForCausalLM.from_pretrained(model_id, trust_remote_code=True)
+    # Load GemmaTokenizer directly instead of AutoTokenizer.
+    tokenizer = GemmaTokenizer.from_pretrained(model_id, trust_remote_code=True)
+
+    # When loading and running the model on CPU, torch.float32 can be set explicitly.
+    model = model.to(torch.float32)
+    model.eval()
+
+    # Text generation function
+    def generate_text(text, max_length=50, do_sample=True, temperature=1.0):
+        inputs = tokenizer(text, return_tensors="pt")
+        with torch.no_grad():
+            outputs = model.generate(**inputs, max_length=max_length, do_sample=do_sample, temperature=temperature)
+        generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
+        return generated_text
+
+    # Create the Gradio interface
+    iface = gr.Interface(
+        fn=generate_text,
+        inputs=[
+            gr.Textbox(lines=5, placeholder="Enter text..."),
+            gr.Slider(minimum=10, maximum=200, step=1, value=50, label="Max length"),
+            gr.Checkbox(label="Sampling"),
+            gr.Slider(minimum=0.1, maximum=2.0, step=0.1, value=1.0, label="Temperature"),
+        ],
+        outputs=gr.Textbox(lines=5, label="Generated text"),
+        title="Gemma 2 Text Generator",
+        description="Generates text using the fine-tuned Gemma 2 model.",
+    )
+
+    if __name__ == "__main__":
+        iface.launch(server_name="0.0.0.0", server_port=7860)
+except ImportError as e:
+    print(f"ImportError occurred: {e}")
+except Exception as e:
+    print(f"Unexpected error occurred: {e}")
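One note on the `generate` call, which this commit carries over unchanged: `max_length` bounds the prompt plus the completion, so a long prompt can leave little or no room for new tokens. Below is a minimal sketch, not part of the commit, of the same function using `max_new_tokens` (the standard `transformers` generation argument that bounds only the completion); it assumes `model` and `tokenizer` are already loaded as in app.py:

    import torch

    # Sketch only: bound the completion length rather than prompt + completion.
    # Assumes `model` and `tokenizer` exist as in app.py above.
    def generate_text(text, max_new_tokens=50, do_sample=True, temperature=1.0):
        inputs = tokenizer(text, return_tensors="pt")
        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=max_new_tokens,  # caps generated tokens only
                do_sample=do_sample,
                temperature=temperature,
            )
        # Decode only the completion by slicing off the prompt tokens,
        # since decoder-only models echo the prompt in their output.
        prompt_len = inputs["input_ids"].shape[-1]
        return tokenizer.decode(outputs[0][prompt_len:], skip_special_tokens=True)

A related design note: the new top-level try/except prints startup failures and then lets the process end, so a failed model load leaves the Space without a running UI; re-raising the exception after printing it would surface the failure as a runtime error instead of a silent exit.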