openfree commited on
Commit
2aa4818
ยท
verified ยท
1 Parent(s): 6722f51

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +305 -0
app.py ADDED
@@ -0,0 +1,305 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ from transformers import AutoModelForCausalLM, AutoTokenizer
4
+ import time
5
+ import numpy as np
6
+
7
+ class MergedModelTester:
8
+ def __init__(self):
9
+ self.model = None
10
+ self.tokenizer = None
11
+ self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
12
+
13
+ def load_model(self, model_id="openfree/gpt2-bert", progress=gr.Progress()):
14
+ """๋ณ‘ํ•ฉ ๋ชจ๋ธ ๋กœ๋“œ"""
15
+ try:
16
+ progress(0.2, desc="ํ† ํฌ๋‚˜์ด์ € ๋กœ๋“œ ์ค‘...")
17
+ self.tokenizer = AutoTokenizer.from_pretrained("gpt2")
18
+ self.tokenizer.pad_token = self.tokenizer.eos_token
19
+
20
+ progress(0.5, desc="๋ชจ๋ธ ๋กœ๋“œ ์ค‘...")
21
+ self.model = AutoModelForCausalLM.from_pretrained(
22
+ model_id,
23
+ torch_dtype=torch.float16 if self.device.type == 'cuda' else torch.float32,
24
+ device_map="auto" if self.device.type == 'cuda' else None
25
+ )
26
+
27
+ if self.device.type == 'cpu':
28
+ self.model = self.model.to(self.device)
29
+
30
+ self.model.eval()
31
+
32
+ progress(1.0, desc="์™„๋ฃŒ!")
33
+
34
+ # ๋ชจ๋ธ ์ •๋ณด
35
+ num_params = sum(p.numel() for p in self.model.parameters())
36
+ return f"""โœ… ๋ชจ๋ธ ๋กœ๋“œ ์„ฑ๊ณต!
37
+ - ๋ชจ๋ธ: {model_id}
38
+ - ํŒŒ๋ผ๋ฏธํ„ฐ: {num_params:,}
39
+ - ๋””๋ฐ”์ด์Šค: {self.device}"""
40
+
41
+ except Exception as e:
42
+ return f"โŒ ๋ชจ๋ธ ๋กœ๋“œ ์‹คํŒจ: {str(e)}"
43
+
44
+ def generate_text(self, prompt, max_length=100, temperature=0.8,
45
+ top_p=0.9, repetition_penalty=1.2, progress=gr.Progress()):
46
+ """ํ…์ŠคํŠธ ์ƒ์„ฑ"""
47
+ if self.model is None:
48
+ return "๋จผ์ € ๋ชจ๋ธ์„ ๋กœ๋“œํ•˜์„ธ์š”!", None, None
49
+
50
+ try:
51
+ progress(0.3, desc="ํ…์ŠคํŠธ ์ƒ์„ฑ ์ค‘...")
52
+
53
+ # ์ž…๋ ฅ ํ† ํฐํ™”
54
+ inputs = self.tokenizer(prompt, return_tensors="pt", padding=True)
55
+ inputs = {k: v.to(self.device) for k, v in inputs.items()}
56
+
57
+ # ์ƒ์„ฑ ์‹œ์ž‘ ์‹œ๊ฐ„
58
+ start_time = time.time()
59
+
60
+ # ํ…์ŠคํŠธ ์ƒ์„ฑ
61
+ with torch.no_grad():
62
+ outputs = self.model.generate(
63
+ **inputs,
64
+ max_new_tokens=max_length,
65
+ temperature=temperature,
66
+ top_p=top_p,
67
+ repetition_penalty=repetition_penalty,
68
+ do_sample=True,
69
+ pad_token_id=self.tokenizer.pad_token_id,
70
+ eos_token_id=self.tokenizer.eos_token_id
71
+ )
72
+
73
+ # ์ƒ์„ฑ ์‹œ๊ฐ„ ๊ณ„์‚ฐ
74
+ generation_time = time.time() - start_time
75
+
76
+ # ๋””์ฝ”๋”ฉ
77
+ generated_text = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
78
+
79
+ # ํ†ต๊ณ„ ์ •๋ณด
80
+ input_tokens = len(inputs['input_ids'][0])
81
+ output_tokens = len(outputs[0])
82
+ new_tokens = output_tokens - input_tokens
83
+
84
+ stats = f"""๐Ÿ“Š ์ƒ์„ฑ ํ†ต๊ณ„:
85
+ - ์ž…๋ ฅ ํ† ํฐ: {input_tokens}
86
+ - ์ƒ์„ฑ ํ† ํฐ: {new_tokens}
87
+ - ์ „์ฒด ํ† ํฐ: {output_tokens}
88
+ - ์ƒ์„ฑ ์‹œ๊ฐ„: {generation_time:.2f}์ดˆ
89
+ - ์†๋„: {new_tokens/generation_time:.1f} tokens/sec"""
90
+
91
+ progress(1.0, desc="์™„๋ฃŒ!")
92
+
93
+ return generated_text, stats, None
94
+
95
+ except Exception as e:
96
+ return f"โŒ ์ƒ์„ฑ ์‹คํŒจ: {str(e)}", None, str(e)
97
+
98
+ def compare_with_parents(self, prompt, max_length=50, progress=gr.Progress()):
99
+ """๋ถ€๋ชจ ๋ชจ๋ธ๋“ค๊ณผ ๋น„๊ต"""
100
+ results = {}
101
+
102
+ # GPT-2 (๋ถ€๋ชจ 1)
103
+ try:
104
+ progress(0.1, desc="GPT-2 ๋กœ๋“œ ์ค‘...")
105
+ gpt2_tokenizer = AutoTokenizer.from_pretrained("gpt2")
106
+ gpt2_tokenizer.pad_token = gpt2_tokenizer.eos_token
107
+ gpt2_model = AutoModelForCausalLM.from_pretrained("gpt2").to(self.device)
108
+
109
+ progress(0.3, desc="GPT-2 ์ƒ์„ฑ ์ค‘...")
110
+ inputs = gpt2_tokenizer(prompt, return_tensors="pt").to(self.device)
111
+ with torch.no_grad():
112
+ outputs = gpt2_model.generate(**inputs, max_new_tokens=max_length, do_sample=True)
113
+ results['gpt2'] = gpt2_tokenizer.decode(outputs[0], skip_special_tokens=True)
114
+ del gpt2_model
115
+
116
+ except Exception as e:
117
+ results['gpt2'] = f"๋กœ๋“œ ์‹คํŒจ: {str(e)}"
118
+
119
+ # BERT๋Š” ์ƒ์„ฑ ๋ชจ๋ธ์ด ์•„๋‹ˆ๋ฏ€๋กœ ์ œ์™ธ
120
+ results['bert'] = "BERT๋Š” ์ƒ์„ฑ ๋ชจ๋ธ์ด ์•„๋‹™๋‹ˆ๋‹ค (์ธ์ฝ”๋” ์ „์šฉ)"
121
+
122
+ # ๋ณ‘ํ•ฉ ๋ชจ๋ธ
123
+ try:
124
+ progress(0.6, desc="๋ณ‘ํ•ฉ ๋ชจ๋ธ ์ƒ์„ฑ ์ค‘...")
125
+ if self.model is None:
126
+ self.load_model()
127
+
128
+ inputs = self.tokenizer(prompt, return_tensors="pt").to(self.device)
129
+ with torch.no_grad():
130
+ outputs = self.model.generate(**inputs, max_new_tokens=max_length, do_sample=True)
131
+ results['merged'] = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
132
+
133
+ except Exception as e:
134
+ results['merged'] = f"์ƒ์„ฑ ์‹คํŒจ: {str(e)}"
135
+
136
+ progress(1.0, desc="์™„๋ฃŒ!")
137
+
138
+ # ๊ฒฐ๊ณผ ํฌ๋งทํŒ…
139
+ comparison = f"""๐Ÿ”„ ๋ชจ๋ธ ๋น„๊ต ๊ฒฐ๊ณผ:
140
+
141
+ **GPT-2 (๋ถ€๋ชจ 1):**
142
+ {results['gpt2']}
143
+
144
+ **BERT (๋ถ€๋ชจ 2):**
145
+ {results['bert']}
146
+
147
+ **๋ณ‘ํ•ฉ ๋ชจ๋ธ (openfree/gpt2-bert):**
148
+ {results['merged']}"""
149
+
150
+ return comparison
151
+
152
+ # ์ „์—ญ ์ธ์Šคํ„ด์Šค
153
+ tester = MergedModelTester()
154
+
155
+ # Gradio ์ธํ„ฐํŽ˜์ด์Šค
156
+ with gr.Blocks(title="GPT2-BERT ๋ณ‘ํ•ฉ ๋ชจ๋ธ ํ…Œ์Šคํ„ฐ") as demo:
157
+ gr.Markdown("""
158
+ # ๐Ÿงฌ GPT2-BERT ๋ณ‘ํ•ฉ ๋ชจ๋ธ ํ…Œ์Šคํ„ฐ
159
+
160
+ ์ง„ํ™”์  ์•Œ๊ณ ๋ฆฌ์ฆ˜์œผ๋กœ ๋ณ‘ํ•ฉ๋œ [openfree/gpt2-bert](https://huggingface.co/openfree/gpt2-bert) ๋ชจ๋ธ์„ ํ…Œ์ŠคํŠธํ•ฉ๋‹ˆ๋‹ค.
161
+
162
+ ## ๐Ÿ“Š ๋ชจ๋ธ ์ •๋ณด
163
+ - **๋ถ€๋ชจ 1**: openai-community/gpt2
164
+ - **๋ถ€๋ชจ 2**: google-bert/bert-base-uncased
165
+ - **๋ณ‘ํ•ฉ ๋ฐฉ๋ฒ•**: SLERP (์ง„ํ™”์  ์ตœ์ ํ™”)
166
+ - **์ตœ์ข… ์„ฑ๋Šฅ**: 82-84% accuracy
167
+ """)
168
+
169
+ with gr.Tab("๐Ÿš€ ๋น ๋ฅธ ํ…Œ์ŠคํŠธ"):
170
+ with gr.Row():
171
+ with gr.Column():
172
+ load_btn = gr.Button("๐Ÿ“ฅ ๋ชจ๋ธ ๋กœ๋“œ", variant="primary")
173
+ load_status = gr.Textbox(label="๋กœ๋“œ ์ƒํƒœ", lines=4)
174
+
175
+ prompt_input = gr.Textbox(
176
+ label="ํ”„๋กฌํ”„ํŠธ",
177
+ placeholder="ํ…์ŠคํŠธ๋ฅผ ์ž…๋ ฅํ•˜์„ธ์š”...",
178
+ value="The future of AI is",
179
+ lines=3
180
+ )
181
+
182
+ with gr.Row():
183
+ max_length = gr.Slider(20, 200, 100, label="์ตœ๋Œ€ ๊ธธ์ด")
184
+ temperature = gr.Slider(0.1, 2.0, 0.8, label="Temperature")
185
+
186
+ with gr.Row():
187
+ top_p = gr.Slider(0.1, 1.0, 0.9, label="Top-p")
188
+ rep_penalty = gr.Slider(1.0, 2.0, 1.2, label="๋ฐ˜๋ณต ํŒจ๋„ํ‹ฐ")
189
+
190
+ generate_btn = gr.Button("โœจ ํ…์ŠคํŠธ ์ƒ์„ฑ", variant="primary")
191
+
192
+ with gr.Column():
193
+ output_text = gr.Textbox(label="์ƒ์„ฑ๋œ ํ…์ŠคํŠธ", lines=10)
194
+ stats_text = gr.Textbox(label="์ƒ์„ฑ ํ†ต๊ณ„", lines=6)
195
+
196
+ with gr.Tab("๐Ÿ”ฌ ๋ชจ๋ธ ๋น„๊ต"):
197
+ compare_prompt = gr.Textbox(
198
+ label="๋น„๊ตํ•  ํ”„๋กฌํ”„ํŠธ",
199
+ value="Once upon a time",
200
+ lines=2
201
+ )
202
+ compare_length = gr.Slider(20, 100, 50, label="์ƒ์„ฑ ๊ธธ์ด")
203
+ compare_btn = gr.Button("๐Ÿ”„ ๋ถ€๋ชจ ๋ชจ๋ธ๊ณผ ๋น„๊ต", variant="primary")
204
+ comparison_output = gr.Textbox(label="๋น„๊ต ๊ฒฐ๊ณผ", lines=20)
205
+
206
+ with gr.Tab("๐Ÿงช ๊ณ ๊ธ‰ ํ…Œ์ŠคํŠธ"):
207
+ gr.Markdown("### ๋‹ค์–‘ํ•œ ํƒœ์Šคํฌ ํ…Œ์ŠคํŠธ")
208
+
209
+ task_type = gr.Radio(
210
+ ["์ด์•ผ๊ธฐ ์ƒ์„ฑ", "์งˆ๋ฌธ ๋‹ต๋ณ€", "์ฝ”๋“œ ์ƒ์„ฑ", "์‹œ ์ž‘์„ฑ"],
211
+ label="ํƒœ์Šคํฌ ์„ ํƒ",
212
+ value="์ด์•ผ๊ธฐ ์ƒ์„ฑ"
213
+ )
214
+
215
+ task_prompts = {
216
+ "์ด์•ผ๊ธฐ ์ƒ์„ฑ": "In a distant galaxy, a young explorer discovered",
217
+ "์งˆ๋ฌธ ๋‹ต๋ณ€": "Q: What is machine learning?\nA:",
218
+ "์ฝ”๋“œ ์ƒ์„ฑ": "# Python function to calculate fibonacci\ndef fibonacci(n):",
219
+ "์‹œ ์ž‘์„ฑ": "Roses are red,\nViolets are blue,"
220
+ }
221
+
222
+ def update_prompt(task):
223
+ return task_prompts.get(task, "")
224
+
225
+ task_prompt = gr.Textbox(label="ํƒœ์Šคํฌ ํ”„๋กฌํ”„ํŠธ", lines=3)
226
+ task_output = gr.Textbox(label="๊ฒฐ๊ณผ", lines=10)
227
+ task_btn = gr.Button("๐ŸŽฏ ํƒœ์Šคํฌ ์‹คํ–‰", variant="primary")
228
+
229
+ task_type.change(update_prompt, task_type, task_prompt)
230
+
231
+ with gr.Tab("๐Ÿ“ˆ ์„ฑ๋Šฅ ๋ถ„์„"):
232
+ gr.Markdown("""
233
+ ### ์ง„ํ™” ์‹คํ—˜ ๊ฒฐ๊ณผ
234
+
235
+ | ๋ฉ”ํŠธ๋ฆญ | ๊ฐ’ |
236
+ |--------|-----|
237
+ | ์ดˆ๊ธฐ ์„ฑ๋Šฅ | 10.56% |
238
+ | ์ตœ์ข… ์„ฑ๋Šฅ | 82-84% |
239
+ | ๊ฐœ์„ ์œจ | +700% |
240
+ | ์ด ๊ฐœ์„  ํšŸ์ˆ˜ | 2,136ํšŒ |
241
+ | ํ•™์Šต ์‹œ๊ฐ„ | 7.7๋ถ„ |
242
+
243
+ ### ์„ธ๋Œ€๋ณ„ ์„ฑ๋Šฅ
244
+ - **์ดˆ๊ธฐ (0-2000)**: ํฐ ๊ฐœ์„  (+20-30%/์„ธ๋Œ€)
245
+ - **์ค‘๊ธฐ (2000-5000)**: ์ค‘๊ฐ„ ๊ฐœ์„  (+10-15%/์„ธ๋Œ€)
246
+ - **ํ›„๊ธฐ (5000-10000)**: ๋ฏธ์„ธ ์กฐ์ • (+2-5%/์„ธ๋Œ€)
247
+ """)
248
+
249
+ test_suite_btn = gr.Button("๐Ÿ” ์ „์ฒด ํ…Œ์ŠคํŠธ ์Šค์œ„ํŠธ ์‹คํ–‰", variant="primary")
250
+ test_results = gr.Textbox(label="ํ…Œ์ŠคํŠธ ๊ฒฐ๊ณผ", lines=15)
251
+
252
+ # ์ด๋ฒคํŠธ ํ•ธ๋“ค๋Ÿฌ
253
+ load_btn.click(
254
+ lambda: tester.load_model("openfree/gpt2-bert"),
255
+ outputs=load_status
256
+ )
257
+
258
+ generate_btn.click(
259
+ tester.generate_text,
260
+ inputs=[prompt_input, max_length, temperature, top_p, rep_penalty],
261
+ outputs=[output_text, stats_text, gr.Textbox(visible=False)]
262
+ )
263
+
264
+ compare_btn.click(
265
+ tester.compare_with_parents,
266
+ inputs=[compare_prompt, compare_length],
267
+ outputs=comparison_output
268
+ )
269
+
270
+ task_btn.click(
271
+ lambda p: tester.generate_text(p, 100, 0.8, 0.9, 1.2),
272
+ inputs=task_prompt,
273
+ outputs=[task_output, gr.Textbox(visible=False), gr.Textbox(visible=False)]
274
+ )
275
+
276
+ def run_test_suite(progress=gr.Progress()):
277
+ """์ „์ฒด ํ…Œ์ŠคํŠธ ์Šค์œ„ํŠธ ์‹คํ–‰"""
278
+ results = []
279
+
280
+ test_prompts = [
281
+ "The meaning of life is",
282
+ "import numpy as np\n",
283
+ "Scientists have discovered",
284
+ "def hello_world():",
285
+ "Breaking news:"
286
+ ]
287
+
288
+ for i, prompt in enumerate(test_prompts):
289
+ progress((i+1)/len(test_prompts), desc=f"ํ…Œ์ŠคํŠธ {i+1}/{len(test_prompts)}")
290
+ try:
291
+ output, stats, _ = tester.generate_text(prompt, 30)
292
+ results.append(f"โœ… ํ”„๋กฌํ”„ํŠธ: {prompt[:30]}...\n ์ƒ์„ฑ ์„ฑ๊ณต")
293
+ except:
294
+ results.append(f"โŒ ํ”„๋กฌํ”„ํŠธ: {prompt[:30]}...\n ์ƒ์„ฑ ์‹คํŒจ")
295
+
296
+ return "\n".join(results)
297
+
298
+ test_suite_btn.click(
299
+ run_test_suite,
300
+ outputs=test_results
301
+ )
302
+
303
+ # ์‹คํ–‰
304
+ if __name__ == "__main__":
305
+ demo.launch(share=False)