nouamanetazi (HF staff) committed
Commit 109fb13 · verified · 1 parent: 48e09b8

Update app.py

Files changed (1)
  1. app.py +279 -32
app.py CHANGED
@@ -8,10 +8,24 @@ from huggingface_hub import CommitScheduler
 from pathlib import Path
 import uuid
 import json
+import time
+from datetime import datetime
+import logging
 
 
+# Configure logging
+logging.basicConfig(
+    level=logging.INFO,
+    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
+    handlers=[
+        logging.FileHandler("app.log"),
+        logging.StreamHandler()
+    ]
+)
+logger = logging.getLogger("darija-llm")
+
 device = "cuda:0" if torch.cuda.is_available() else "cpu"
-print(f'[INFO] Using device: {device}')
+logger.info(f'Using device: {device}')
 
 # token
 token = os.environ['TOKEN']
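A note on the new logging block: `basicConfig` wires the `darija-llm` logger to both stderr and an `app.log` file in the working directory; on Hugging Face Spaces that file sits on ephemeral container storage, so it is lost whenever the Space restarts. A minimal local sketch to confirm records reach both sinks (same configuration as the hunk above; the smoke-test lines are ours, not part of the commit):

```python
# Minimal sketch, assuming the same logging config as in the commit.
import logging

logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    handlers=[logging.FileHandler("app.log"), logging.StreamHandler()],
)
logger = logging.getLogger("darija-llm")
logger.info("smoke test")  # should appear on stderr and in ./app.log

print(open("app.log").read())  # verify the file handler wrote the record
```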
@@ -19,12 +33,15 @@ token = os.environ['TOKEN']
 # Load the pretrained model and tokenizer
 MODEL_NAME = "atlasia/Al-Atlas-0.5B" # "atlasia/Al-Atlas-LLM-mid-training" # "BounharAbdelaziz/Al-Atlas-LLM-0.5B" #"atlasia/Al-Atlas-LLM"
 
+logger.info(f"Loading model: {MODEL_NAME}")
 tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME,token=token) # , token=token
 model = AutoModelForCausalLM.from_pretrained(MODEL_NAME,token=token).to(device)
+logger.info("Model loaded successfully")
 
 # Fix tokenizer padding
 if tokenizer.pad_token is None:
     tokenizer.pad_token = tokenizer.eos_token # Set pad token
+    logger.info("Set pad_token to eos_token")
 
 # Predefined examples
 examples = [
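The model is loaded in full float32, which for a 0.5B-parameter checkpoint is roughly 2 GB of weights. If GPU memory is tight, a half-precision load cuts that roughly in half; `torch_dtype` is a standard `from_pretrained` argument. A hedged sketch reusing `MODEL_NAME`, `token`, and `device` from the hunk above (an option, not what the commit does):

```python
# Hedged sketch: half-precision weights on GPU, float32 on CPU.
# MODEL_NAME, token, and device are the variables defined in the diff above.
import torch
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    token=token,
    torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
).to(device)
```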
@@ -44,6 +61,7 @@ feedback_file = submit_file
 
 # Create directory if it doesn't exist
 submit_file.parent.mkdir(exist_ok=True, parents=True)
+logger.info(f"Created feedback file: {feedback_file}")
 
 scheduler = CommitScheduler(
     repo_id="atlasia/atlaset_inference_ds",
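`CommitScheduler` pushes the watched folder to `atlasia/atlaset_inference_ds` on a timer; the `every=5` visible in the next hunk's context is in minutes. Any write into that folder should happen while holding `scheduler.lock`, so a background commit never snapshots a half-written line; that is the pattern `save_feedback` below follows. A minimal sketch of the append-under-lock idiom (the record content is illustrative):

```python
# Append-under-lock idiom used with huggingface_hub's CommitScheduler.
# `scheduler` and `feedback_file` are the objects defined in the diff.
import json

def append_record(record: dict) -> None:
    with scheduler.lock:  # keep the background commit from seeing partial writes
        with feedback_file.open("a") as f:
            f.write(json.dumps(record))
            f.write("\n")

append_record({"input": "سلام", "output": "...", "params": {}})
```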
@@ -53,10 +71,42 @@ scheduler = CommitScheduler(
     every=5,
     token=token
 )
+logger.info(f"Initialized CommitScheduler for repo: atlasia/atlaset_inference_ds")
+
+# Track usage statistics
+usage_stats = {
+    "total_generations": 0,
+    "total_tokens_generated": 0,
+    "start_time": time.time()
+}
 
 @spaces.GPU
-def generate_text(prompt, max_length=256, temperature=0.7, top_p=0.9, top_k=150, num_beams=8, repetition_penalty=1.5):
+def generate_text(prompt, max_length=256, temperature=0.7, top_p=0.9, top_k=150, num_beams=8, repetition_penalty=1.5, progress=gr.Progress()):
+    if not prompt.strip():
+        logger.warning("Empty prompt submitted")
+        return "", "الرجاء إدخال نص للتوليد (Please enter text to generate)"
+
+    logger.info(f"Generating text for prompt: '{prompt[:50]}...' (length: {len(prompt)})")
+    logger.info(f"Parameters: max_length={max_length}, temp={temperature}, top_p={top_p}, top_k={top_k}, beams={num_beams}, rep_penalty={repetition_penalty}")
+
+    start_time = time.time()
+
+    # Update progress bar - tokenization step
+    progress(0.1, desc="تحليل النص (Tokenizing input)")
     inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
+
+    # Update progress bar - generation starting
+    progress(0.2, desc="توليد النص (Generating text)")
+
+    # Define a callback function to update progress during generation
+    def generation_callback(beam_idx, token_idx, token_id, scores, generation_config):
+        # Estimate progress based on token index and max length
+        # We start at 20% and go to 90%, leaving room for post-processing
+        progress_value = 0.2 + 0.7 * min(token_idx / max_length, 1.0)
+        progress(progress_value, desc=f"توليد النص: {token_idx}/{max_length} (Generating: {token_idx}/{max_length})")
+        return False # Continue generation
+
+    # Generate with callback
     output = model.generate(
         **inputs,
         max_length=max_length,
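One caveat with the hunk below: `transformers`' `generate()` has no `callback` parameter, so the `generation_callback` defined above is never invoked, and since recent versions validate keyword arguments, passing `callback=` (even as `None`, which the `hasattr` guard still does) can raise a `ValueError`. If per-step progress is wanted, a custom `StoppingCriteria` is a supported hook: it is called after every decoding step and can drive the Gradio progress bar while always declining to stop. A hedged sketch, not what the commit ships:

```python
# Hedged alternative to the unsupported `callback=` kwarg: report progress
# from a StoppingCriteria, which generate() does accept.
from transformers import StoppingCriteria, StoppingCriteriaList

class ProgressReporter(StoppingCriteria):
    def __init__(self, progress, prompt_len, max_length):
        self.progress = progress      # the gr.Progress() passed into generate_text
        self.prompt_len = prompt_len  # token count of the prompt
        self.max_length = max_length

    def __call__(self, input_ids, scores, **kwargs) -> bool:
        done = input_ids.shape[-1] - self.prompt_len
        self.progress(0.2 + 0.7 * min(done / self.max_length, 1.0),
                      desc=f"Generating: {done}/{self.max_length}")
        return False  # never request early stopping here

# Wiring inside generate_text would look like:
#   reporter = ProgressReporter(progress, inputs["input_ids"].shape[-1], max_length)
#   output = model.generate(**inputs, ...,
#                           stopping_criteria=StoppingCriteriaList([reporter]))
```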
@@ -65,63 +115,260 @@ def generate_text(prompt, max_length=256, temperature=0.7, top_p=0.9, top_k=150,
         do_sample=True,
         repetition_penalty=repetition_penalty,
         num_beams=num_beams,
-        top_k= top_k,
-        early_stopping = True,
-        pad_token_id=tokenizer.pad_token_id, # Explicit pad token
-        eos_token_id=tokenizer.eos_token_id, # Explicit eos token
+        top_k=top_k,
+        early_stopping=True,
+        pad_token_id=tokenizer.pad_token_id,
+        eos_token_id=tokenizer.eos_token_id,
+        callback=generation_callback if hasattr(model, "generation_config") else None,
+    )
+
+    # Update progress bar - decoding step
+    progress(0.9, desc="معالجة النتائج (Processing results)")
+    result = tokenizer.decode(output[0], skip_special_tokens=True)
+
+    # Update stats
+    generation_time = time.time() - start_time
+    token_count = len(output[0])
+
+    with scheduler.lock:
+        usage_stats["total_generations"] += 1
+        usage_stats["total_tokens_generated"] += token_count
+
+    logger.info(f"Generated {token_count} tokens in {generation_time:.2f}s")
+    logger.info(f"Result: '{result[:50]}...' (length: {len(result)})")
+
+    # Save feedback with additional metadata
+    save_feedback(
+        prompt,
+        result,
+        {
+            "max_length": max_length,
+            "temperature": temperature,
+            "top_p": top_p,
+            "top_k": top_k,
+            "num_beams": num_beams,
+            "repetition_penalty": repetition_penalty,
+            "generation_time": generation_time,
+            "token_count": token_count,
+            "timestamp": datetime.now().isoformat()
+        }
     )
-    result=tokenizer.decode(output[0], skip_special_tokens=True)
-    save_feedback(prompt,result,f"{max_length},{temperature},{top_p},{top_k},{num_beams},{repetition_penalty}")
-    return result
+
+    # Complete progress
+    progress(1.0, desc="اكتمل (Complete)")
+
+    return result, f"تم توليد {token_count} رمز في {generation_time:.2f} ثانية (Generated {token_count} tokens in {generation_time:.2f} seconds)"
 
 def save_feedback(input, output, params) -> None:
     """
     Append input/outputs and parameters to a JSON Lines file using a thread lock
     to avoid concurrent writes from different users.
     """
+    logger.info(f"Saving feedback to {feedback_file}")
+
+    with scheduler.lock:
+        try:
+            with feedback_file.open("a") as f:
+                f.write(json.dumps({
+                    "input": input,
+                    "output": output,
+                    "params": params
+                }))
+                f.write("\n")
+            logger.info("Feedback saved successfully")
+        except Exception as e:
+            logger.error(f"Error saving feedback: {str(e)}")
+
+def get_stats():
+    """Return current usage statistics"""
+    with scheduler.lock:
+        uptime = time.time() - usage_stats["start_time"]
+        hours = uptime / 3600
+
+        stats = {
+            "Total generations": usage_stats["total_generations"],
+            "Total tokens generated": usage_stats["total_tokens_generated"],
+            "Uptime": f"{int(hours)}h {int((hours % 1) * 60)}m",
+            "Generations per hour": f"{usage_stats['total_generations'] / hours:.1f}" if hours > 0 else "N/A",
+            "Last updated": datetime.now().strftime("%Y-%m-%d %H:%M:%S")
+        }
+
+        logger.info(f"Stats requested: {stats}")
+        return stats
+
+def reset_params():
+    """Reset parameters to default values"""
+    logger.info("Parameters reset to defaults")
+    return 256, 0.7, 0.9, 150, 8, 1.5
+
+def thumbs_up_callback(input_text, output_text):
+    """Record positive feedback"""
+    logger.info("Received positive feedback")
+
+    feedback_path = Path("user_submit") / "positive_feedback.jsonl"
+    feedback_path.parent.mkdir(exist_ok=True, parents=True)
+
+    with scheduler.lock:
+        try:
+            with feedback_path.open("a") as f:
+                feedback_data = {
+                    "input": input_text,
+                    "output": output_text,
+                    "rating": "positive",
+                    "timestamp": datetime.now().isoformat()
+                }
+                f.write(json.dumps(feedback_data))
+                f.write("\n")
+
+            logger.info(f"Positive feedback saved to {feedback_path}")
+        except Exception as e:
+            logger.error(f"Error saving positive feedback: {str(e)}")
+
+    return "شكرا على التقييم الإيجابي!"
+
+def thumbs_down_callback(input_text, output_text, feedback=""):
+    """Record negative feedback"""
+    logger.info(f"Received negative feedback: '{feedback}'")
+
+    feedback_path = Path("user_submit") / "negative_feedback.jsonl"
+    feedback_path.parent.mkdir(exist_ok=True, parents=True)
+
     with scheduler.lock:
-        with feedback_file.open("a") as f:
-            f.write(json.dumps({"input": input, "output": output, "params": params}))
-            f.write("\n")
+        try:
+            with feedback_path.open("a") as f:
+                feedback_data = {
+                    "input": input_text,
+                    "output": output_text,
+                    "rating": "negative",
+                    "feedback": feedback,
+                    "timestamp": datetime.now().isoformat()
+                }
+                f.write(json.dumps(feedback_data))
+                f.write("\n")
+
+            logger.info(f"Negative feedback saved to {feedback_path}")
+        except Exception as e:
+            logger.error(f"Error saving negative feedback: {str(e)}")
+
+    return "شكرا على ملاحظاتك!"
 
 if __name__ == "__main__":
+    logger.info("Starting Moroccan Darija LLM application")
+
     # Create the Gradio interface
-    with gr.Blocks() as app:
+    with gr.Blocks(css="footer {visibility: hidden}") as app:
+        gr.Markdown("""
+        # 🇲🇦 نموذج اللغة المغربية الدارجة (Moroccan Darija LLM)
+
+        أدخل نصًا بالدارجة المغربية واحصل على نص تم إنشاؤه بواسطة نموذج اللغة الخاص بنا المدرب على الدارجة المغربية.
+
+        Enter a prompt and get AI-generated text using our pretrained LLM on Moroccan Darija.
+        """)
+
         with gr.Row():
-            with gr.Column():
-                prompt_input = gr.Textbox(label="Prompt: دخل النص بالدارجة")
-                max_length = gr.Slider(8, 4096, value=256, label="Max Length")
-                temperature = gr.Slider(0.0, 2, value=0.7, label="Temperature")
-                top_p = gr.Slider(0.0, 1.0, value=0.9, label="Top-p")
-                top_k = gr.Slider(1, 10000, value=150, label="Top-k")
-                num_beams = gr.Slider(1, 20, value=8, label="Number of Beams")
-                repetition_penalty = gr.Slider(0.0, 100.0, value=1.5, label="Repetition Penalty")
+            with gr.Column(scale=6):
+                prompt_input = gr.Textbox(
+                    label="الدخل (Prompt): دخل النص بالدارجة",
+                    placeholder="اكتب هنا...",
+                    lines=4
+                )
+
+                with gr.Row():
+                    submit_btn = gr.Button("توليد النص (Generate)", variant="primary")
+                    clear_btn = gr.Button("مسح (Clear)")
+                    reset_btn = gr.Button("إعادة ضبط المعلمات (Reset Parameters)")
+
+                with gr.Accordion("معلمات التوليد (Generation Parameters)", open=False):
+                    with gr.Row():
+                        with gr.Column():
+                            max_length = gr.Slider(8, 4096, value=256, label="Max Length (الطول الأقصى)")
+                            temperature = gr.Slider(0.0, 2, value=0.7, label="Temperature (درجة الحرارة)")
+                            top_p = gr.Slider(0.0, 1.0, value=0.9, label="Top-p (أعلى احتمال)")
+
+                        with gr.Column():
+                            top_k = gr.Slider(1, 10000, value=150, label="Top-k (أعلى ك)")
+                            num_beams = gr.Slider(1, 20, value=8, label="Number of Beams (عدد الأشعة)")
+                            repetition_penalty = gr.Slider(0.0, 100.0, value=1.5, label="Repetition Penalty (عقوبة التكرار)")
+
+            with gr.Column(scale=6):
+                output_text = gr.Textbox(label="النص المولد (Generated Text)", lines=10)
+                generation_info = gr.Markdown("")
 
-                submit_btn = gr.Button("Generate")
+                with gr.Row():
+                    thumbs_up = gr.Button("👍 جيد")
+                    thumbs_down = gr.Button("👎 سيء")
 
-            with gr.Column():
-                output_text = gr.Textbox(label="Generated Text in Moroccan Darija")
+                with gr.Accordion("تعليق (Feedback)", open=False, visible=False) as feedback_accordion:
+                    feedback_text = gr.Textbox(label="لماذا لم يعجبك الناتج؟ (Why didn't you like the output?)", lines=2)
+                    submit_feedback = gr.Button("إرسال التعليق (Submit Feedback)")
+
+                feedback_result = gr.Markdown("")
+
+                with gr.Accordion("إحصائيات الاستخدام (Usage Statistics)", open=False):
+                    stats_md = gr.JSON(get_stats, every=10)
+                    refresh_stats = gr.Button("تحديث (Refresh)")
 
         # Examples section with caching
         gr.Examples(
             examples=examples,
             inputs=[prompt_input, max_length, temperature, top_p, top_k, num_beams, repetition_penalty],
-            outputs=output_text,
+            outputs=[output_text, generation_info],
             fn=generate_text,
             cache_examples=True
         )
 
-        # Button action
+        # Button actions
         submit_btn.click(
             generate_text,
             inputs=[prompt_input, max_length, temperature, top_p, top_k, num_beams, repetition_penalty],
-            outputs=output_text
+            outputs=[output_text, generation_info]
         )
 
-        gr.Markdown("""
-        # Moroccan Darija LLM
+        clear_btn.click(
+            lambda: ("", ""),
+            inputs=None,
+            outputs=[prompt_input, output_text]
+        )
 
-        Enter a prompt and get AI-generated text using our pretrained LLM on Moroccan Darija.
-        """)
+        reset_btn.click(
+            reset_params,
+            inputs=None,
+            outputs=[max_length, temperature, top_p, top_k, num_beams, repetition_penalty]
+        )
+
+        # Feedback system
+        thumbs_up.click(
+            thumbs_up_callback,
+            inputs=[prompt_input, output_text],
+            outputs=[feedback_result]
+        )
+
+        thumbs_down.click(
+            lambda: (gr.Accordion.update(visible=True, open=True), ""),
+            inputs=None,
+            outputs=[feedback_accordion, feedback_result]
+        )
 
-    app.launch()
+        submit_feedback.click(
+            thumbs_down_callback,
+            inputs=[prompt_input, output_text, feedback_text],
+            outputs=[feedback_result]
+        )
+
+        # Stats refresh
+        refresh_stats.click(
+            get_stats,
+            inputs=None,
+            outputs=[stats_md]
+        )
+
+        # Keyboard shortcuts
+        prompt_input.submit(
+            generate_text,
+            inputs=[prompt_input, max_length, temperature, top_p, top_k, num_beams, repetition_penalty],
+            outputs=[output_text, generation_info]
+        )
+
+    logger.info("Launching Gradio interface")
+    app.launch()
+    logger.info("Gradio interface closed")
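A few review notes on this final hunk. First, `gr.Accordion.update(...)` in the `thumbs_down` handler is the Gradio 3.x idiom; Gradio 4 removed the per-component `update()` classmethods, so on 4.x that lambda raises `AttributeError`. A hedged 4.x equivalent:

```python
# Gradio 4.x form of the thumbs_down handler (3.x used gr.Accordion.update).
thumbs_down.click(
    lambda: (gr.update(visible=True, open=True), ""),
    inputs=None,
    outputs=[feedback_accordion, feedback_result],
)
```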
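Second, `gr.JSON(get_stats, every=10)` passes a callable as the component value, which makes Gradio re-run `get_stats` every 10 seconds; on Gradio versions where timed polling runs through the request queue, the queue may need to be enabled before launch. A hedged, version-dependent sketch:

```python
# If the installed Gradio requires an explicit queue for `every=` polling,
# enable it before launching (app is the Blocks object defined above).
app.queue().launch()
```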
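Finally, note the schema change in the feedback log: the old code stored `params` as a comma-joined string (e.g. `"256,0.7,0.9,150,8,1.5"`), while the new code stores a dict with named keys plus timing metadata, so existing rows in `atlasia/atlaset_inference_ds` will have both shapes. A hedged reader that normalizes both (the file name is illustrative):

```python
# Normalize old (string) and new (dict) `params` rows in the feedback JSONL.
import json

LEGACY_KEYS = ["max_length", "temperature", "top_p", "top_k",
               "num_beams", "repetition_penalty"]

def load_params(line: str) -> dict:
    row = json.loads(line)
    p = row["params"]
    if isinstance(p, str):  # legacy comma-joined format
        p = dict(zip(LEGACY_KEYS, p.split(",")))
    return p

with open("feedback.jsonl") as f:  # illustrative path
    for line in f:
        print(load_params(line))
```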