import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer
import os
import spaces
import torch
from datasets import load_dataset
from huggingface_hub import CommitScheduler
from pathlib import Path
import uuid
import json
import time
from datetime import datetime
import logging


# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    handlers=[
        logging.FileHandler("app.log"),
        logging.StreamHandler()
    ]
)
logger = logging.getLogger("darija-llm")

device = "cuda:0" if torch.cuda.is_available() else "cpu"
logger.info(f'Using device: {device}')

# Hugging Face access token (must be available in the environment, e.g. as a Space secret)
token = os.environ['TOKEN']

# Load the pretrained model and tokenizer
MODEL_NAME = "atlasia/Al-Atlas-0.5B"  # earlier checkpoints: atlasia/Al-Atlas-LLM-mid-training, BounharAbdelaziz/Al-Atlas-LLM-0.5B, atlasia/Al-Atlas-LLM

logger.info(f"Loading model: {MODEL_NAME}")
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, token=token)
model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, token=token).to(device)
logger.info("Model loaded successfully")

# Fix tokenizer padding
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token  # Set pad token
    logger.info("Set pad_token to eos_token")

# Predefined examples
examples = [
    ["الذكاء الاصطناعي هو فرع من علوم الكمبيوتر اللي كيركز"
     , 256, 0.7, 0.9, 100, 4, 1.5],
    ["المستقبل ديال الذكاء الصناعي فالمغرب"
     , 256, 0.7, 0.9, 100, 4, 1.5],
    [" المطبخ المغربي"
     , 256, 0.7, 0.9, 100, 4, 1.5],
    ["الماكلة المغربية كتعتبر من أحسن الماكلات فالعالم"
     , 256, 0.7, 0.9, 100, 4, 1.5],
]
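
# Each example row maps onto [prompt, max_length, temperature, top_p, top_k,
# num_beams, repetition_penalty], matching the gr.Examples inputs order below.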

# File where generation inputs/outputs are appended, one JSON object per line
submit_file = Path("user_submit/") / f"data_{uuid.uuid4()}.json"
feedback_file = submit_file

# Create directory if it doesn't exist
submit_file.parent.mkdir(exist_ok=True, parents=True)
logger.info(f"Created feedback file: {feedback_file}")

scheduler = CommitScheduler(
    repo_id="atlasia/atlaset_inference_ds",
    repo_type="dataset",
    folder_path=submit_file.parent,
    path_in_repo="data",
    every=5,
    token=token
)
logger.info(f"Initialized CommitScheduler for repo: atlasia/atlaset_inference_ds")

# Track usage statistics
usage_stats = {
    "total_generations": 0,
    "total_tokens_generated": 0,
    "start_time": time.time()
}
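
# On Hugging Face ZeroGPU Spaces, the spaces.GPU decorator requests a GPU for the
# duration of each decorated call and releases it when the call returns.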

@spaces.GPU
def generate_text(prompt, max_length=256, temperature=0.7, top_p=0.9, top_k=150, num_beams=8, repetition_penalty=1.5, progress=gr.Progress()):
    if not prompt.strip():
        logger.warning("Empty prompt submitted")
        return "", "الرجاء إدخال نص للتوليد (Please enter text to generate)"
    
    logger.info(f"Generating text for prompt: '{prompt[:50]}...' (length: {len(prompt)})")
    logger.info(f"Parameters: max_length={max_length}, temp={temperature}, top_p={top_p}, top_k={top_k}, beams={num_beams}, rep_penalty={repetition_penalty}")
    
    start_time = time.time()
    
    # Start progress
    progress(0, desc="تجهيز النموذج (Preparing model)")
    
    # Tokenize input
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    progress(0.1, desc="تحليل النص (Tokenizing)")
    
    # Generate text with optimized parameters for speed
    progress(0.2, desc="توليد النص (Generating text)")
    output = model.generate(
        **inputs,
        max_length=max_length,  # counts prompt tokens plus generated tokens
        temperature=temperature,
        top_p=top_p,
        do_sample=True,
        repetition_penalty=repetition_penalty,
        num_beams=1 if num_beams > 4 else num_beams,  # cap beam search for speed; with do_sample=True, 1 beam is plain sampling
        top_k=top_k,
        early_stopping=True,
        pad_token_id=tokenizer.pad_token_id,
        eos_token_id=tokenizer.eos_token_id,
        use_cache=True,  # reuse the KV cache across decoding steps
    )
    
    # Decode output
    progress(0.9, desc="معالجة النتائج (Processing results)")
    result = tokenizer.decode(output[0], skip_special_tokens=True)
    
    # Update stats, counting only the newly generated tokens (the output tensor includes the prompt)
    generation_time = time.time() - start_time
    token_count = len(output[0]) - inputs["input_ids"].shape[1]
    
    with scheduler.lock:
        usage_stats["total_generations"] += 1
        usage_stats["total_tokens_generated"] += token_count
    
    logger.info(f"Generated {token_count} tokens in {generation_time:.2f}s")
    logger.info(f"Result: '{result[:50]}...' (length: {len(result)})")
    
    # Save feedback with additional metadata
    save_feedback(
        prompt, 
        result,
        {
            "max_length": max_length,
            "temperature": temperature,
            "top_p": top_p,
            "top_k": top_k,
            "num_beams": num_beams,
            "repetition_penalty": repetition_penalty,
            "generation_time": generation_time,
            "token_count": token_count,
            "timestamp": datetime.now().isoformat()
        }
    )
    
    progress(1.0, desc="اكتمل (Complete)")
    
    return result, f"تم توليد {token_count} رمز في {generation_time:.2f} ثانية (Generated {token_count} tokens in {generation_time:.2f} seconds)"

def save_feedback(input_text, output_text, params) -> None:
    """
    Append input/outputs and parameters to a JSON Lines file using a thread lock
    to avoid concurrent writes from different users.
    """
    logger.info(f"Saving feedback to {feedback_file}")
    
    with scheduler.lock:
        try:
            with feedback_file.open("a") as f:
                f.write(json.dumps({
                    "input": input, 
                    "output": output, 
                    "params": params
                }))
                f.write("\n")
            logger.info("Feedback saved successfully")
        except Exception as e:
            logger.error(f"Error saving feedback: {str(e)}")

def get_stats():
    """Return current usage statistics"""
    with scheduler.lock:
        uptime = time.time() - usage_stats["start_time"]
        hours = uptime / 3600
        
        stats = {
            "Total generations": usage_stats["total_generations"],
            "Total tokens generated": usage_stats["total_tokens_generated"],
            "Uptime": f"{int(hours)}h {int((hours % 1) * 60)}m",
            "Generations per hour": f"{usage_stats['total_generations'] / hours:.1f}" if hours > 0 else "N/A",
            "Last updated": datetime.now().strftime("%Y-%m-%d %H:%M:%S")
        }
        
        logger.info(f"Stats requested: {stats}")
        return stats

def reset_params():
    """Reset parameters to default values"""
    logger.info("Parameters reset to defaults")
    return 128, 0.7, 0.9, 50, 1, 1.2  # Updated defaults for faster generation

def thumbs_up_callback(input_text, output_text):
    """Record positive feedback"""
    logger.info("Received positive feedback")
    
    feedback_path = Path("user_submit") / "positive_feedback.jsonl"
    feedback_path.parent.mkdir(exist_ok=True, parents=True)
    
    with scheduler.lock:
        try:
            with feedback_path.open("a") as f:
                feedback_data = {
                    "input": input_text, 
                    "output": output_text, 
                    "rating": "positive",
                    "timestamp": datetime.now().isoformat()
                }
                f.write(json.dumps(feedback_data))
                f.write("\n")
            
            logger.info(f"Positive feedback saved to {feedback_path}")
        except Exception as e:
            logger.error(f"Error saving positive feedback: {str(e)}")
    
    return "شكرا على التقييم الإيجابي!"

def thumbs_down_callback(input_text, output_text, feedback=""):
    """Record negative feedback"""
    logger.info(f"Received negative feedback: '{feedback}'")
    
    feedback_path = Path("user_submit") / "negative_feedback.jsonl"
    feedback_path.parent.mkdir(exist_ok=True, parents=True)
    
    with scheduler.lock:
        try:
            with feedback_path.open("a") as f:
                feedback_data = {
                    "input": input_text, 
                    "output": output_text, 
                    "rating": "negative",
                    "feedback": feedback,
                    "timestamp": datetime.now().isoformat()
                }
                f.write(json.dumps(feedback_data))
                f.write("\n")
            
            logger.info(f"Negative feedback saved to {feedback_path}")
        except Exception as e:
            logger.error(f"Error saving negative feedback: {str(e)}")
    
    return "شكرا على ملاحظاتك!"
    
if __name__ == "__main__":
    logger.info("Starting Moroccan Darija LLM application")
    
    # Create the Gradio interface
    with gr.Blocks(css="""
        footer {visibility: hidden}
        .center-text {text-align: center; margin: 0 auto; max-width: 900px;}
        .header-text {font-size: 2.5rem; font-weight: bold; margin-bottom: 0.5rem;}
        .subheader-text {font-size: 1.2rem; margin-bottom: 2rem;}
        .flag-emoji {font-size: 3rem;}
    """) as app:
        with gr.Row(elem_classes=["center-text"]):
            gr.Markdown("""
            # 🇲🇦🇲🇦🇲🇦
            
            # Al-Atlas-0.5B-base
            """)
        with gr.Row():
            gr.Markdown("""

            This is a pretrained base model for text generation: it continues the text you give it. Do not expect it to behave like a chat (instruct) model; that version is coming soon!
            The model is pretrained mainly on Moroccan Darija written in **Arabic script**.
            """)
        with gr.Row():
            with gr.Column(scale=6):
                prompt_input = gr.Textbox(
                    label="Prompt ", 
                    placeholder="اكتب هنا...",
                    lines=4, rtl=True
                )
                
                with gr.Row():
                    submit_btn = gr.Button("Generate", variant="primary")
                    clear_btn = gr.Button("Clear")
                    reset_btn = gr.Button("Reset Parameters")
                
                with gr.Accordion("Generation Parameters", open=False):
                    with gr.Row():
                        with gr.Column():
                            max_length = gr.Slider(8, 4096, value=128, label="Max Length")  # Reduced default
                            temperature = gr.Slider(0.1, 2.0, value=0.7, label="Temperature")  # must stay > 0 when do_sample=True
                            top_p = gr.Slider(0.0, 1.0, value=0.9, label="Top-p")
                        
                        with gr.Column():
                            top_k = gr.Slider(1, 10000, value=50, label="Top-k")  # Reduced default
                            num_beams = gr.Slider(1, 20, value=1, label="Number of Beams")  # Reduced default
                            repetition_penalty = gr.Slider(0.1, 100.0, value=1.2, label="Repetition Penalty")  # must stay > 0
                
            with gr.Column(scale=6):
                output_text = gr.Textbox(label="Generated Text", lines=10, rtl=True)
                generation_info = gr.Markdown("")
                
                with gr.Row():
                    thumbs_up = gr.Button("👍 ناضي (Good)")
                    thumbs_down = gr.Button("👎 عيان (Bad)")
                
                with gr.Accordion("Feedback", open=False, visible=False) as feedback_accordion:
                    feedback_text = gr.Textbox(label="Why didn't you like the output?", lines=2, rtl=True)
                    submit_feedback = gr.Button("Submit Feedback")
                
                feedback_result = gr.Markdown("")
        
        with gr.Accordion("Usage Statistics", open=False):
            stats_md = gr.JSON(get_stats, every=10)
            refresh_stats = gr.Button("Refresh")
        
        # Examples section with caching
        gr.Examples(
            examples=examples,
            inputs=[prompt_input, max_length, temperature, top_p, top_k, num_beams, repetition_penalty],
            outputs=[output_text, generation_info],
            fn=generate_text,
            cache_examples=True
        )
        
        # Button actions
        submit_btn.click(
            generate_text,
            inputs=[prompt_input, max_length, temperature, top_p, top_k, num_beams, repetition_penalty],
            outputs=[output_text, generation_info]
        )
        
        clear_btn.click(
            lambda: ("", ""),
            inputs=None,
            outputs=[prompt_input, output_text]
        )
        
        reset_btn.click(
            reset_params,
            inputs=None,
            outputs=[max_length, temperature, top_p, top_k, num_beams, repetition_penalty]
        )
        
        # Feedback system
        thumbs_up.click(
            thumbs_up_callback,
            inputs=[prompt_input, output_text],
            outputs=[feedback_result]
        )
        
        thumbs_down.click(
            thumbs_down_callback,
            inputs=[prompt_input, output_text],
            outputs=[feedback_result]
        ).then(
            # Reveal the hidden written-feedback accordion so users can elaborate
            lambda: gr.update(visible=True),
            inputs=None,
            outputs=[feedback_accordion]
        )
        
        submit_feedback.click(
            thumbs_down_callback,
            inputs=[prompt_input, output_text, feedback_text],
            outputs=[feedback_result]
        )
        
        # Stats refresh
        refresh_stats.click(
            get_stats,
            inputs=None,
            outputs=[stats_md]
        )
        
        # Keyboard shortcuts
        prompt_input.submit(
            generate_text,
            inputs=[prompt_input, max_length, temperature, top_p, top_k, num_beams, repetition_penalty],
            outputs=[output_text, generation_info]
        )
    
    logger.info("Launching Gradio interface")
    app.launch()
    logger.info("Gradio interface closed")