"""Gradio Space for Al-Atlas-0.5B: Moroccan Darija text-generation demo.

Loads the pretrained causal LM, serves a generation UI with tunable sampling
parameters, and persists every generation plus user ratings as JSON Lines
files that a CommitScheduler periodically pushes to a HF dataset repo.
"""
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer
import os
import spaces
import torch
from datasets import load_dataset  # NOTE(review): unused in this file — confirm before removing
from huggingface_hub import CommitScheduler
from pathlib import Path
import uuid
import json
import time
from datetime import datetime
import logging

# Log to both a file and stdout so Space logs stay inspectable after restarts.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    handlers=[
        logging.FileHandler("app.log"),
        logging.StreamHandler()
    ]
)
logger = logging.getLogger("darija-llm")

device = "cuda:0" if torch.cuda.is_available() else "cpu"
logger.info(f'Using device: {device}')

# HF access token — fail fast at startup if the Space secret is not configured.
token = os.environ['TOKEN']

# Load the pretrained model and tokenizer.
MODEL_NAME = "atlasia/Al-Atlas-0.5B"  # "atlasia/Al-Atlas-LLM-mid-training" # "BounharAbdelaziz/Al-Atlas-LLM-0.5B" # "atlasia/Al-Atlas-LLM"
logger.info(f"Loading model: {MODEL_NAME}")
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, token=token)
model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, token=token).to(device)
logger.info("Model loaded successfully")

# The base model ships without a pad token; reuse EOS so padded generation works.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
    logger.info("Set pad_token to eos_token")

# Predefined examples:
# (prompt, max_length, temperature, top_p, top_k, num_beams, repetition_penalty)
examples = [
    ["الذكاء الاصطناعي هو فرع من علوم الكمبيوتر اللي كيركز", 256, 0.7, 0.9, 100, 4, 1.5],
    ["المستقبل ديال الذكاء الصناعي فالمغرب", 256, 0.7, 0.9, 100, 4, 1.5],
    [" المطبخ المغربي", 256, 0.7, 0.9, 100, 4, 1.5],
    ["الماكلة المغربية كتعتبر من أحسن الماكلات فالعالم", 256, 0.7, 0.9, 100, 4, 1.5],
]

# Each process writes to its own uniquely-named JSONL file so concurrent
# replicas never clobber each other's data.
submit_file = Path("user_submit/") / f"data_{uuid.uuid4()}.json"
feedback_file = submit_file

# Create the directory if it doesn't exist.
submit_file.parent.mkdir(exist_ok=True, parents=True)
logger.info(f"Created feedback file: {feedback_file}")

# Periodically commit everything under user_submit/ to the dataset repo.
scheduler = CommitScheduler(
    repo_id="atlasia/atlaset_inference_ds",
    repo_type="dataset",
    folder_path=submit_file.parent,
    path_in_repo="data",
    every=5,
    token=token,
)
logger.info("Initialized CommitScheduler for repo: atlasia/atlaset_inference_ds")

# Process-wide usage counters, guarded by scheduler.lock.
usage_stats = {
    "total_generations": 0,
    "total_tokens_generated": 0,
    "start_time": time.time(),
}


@spaces.GPU
def generate_text(prompt, max_length=256, temperature=0.7, top_p=0.9, top_k=150,
                  num_beams=8, repetition_penalty=1.5, progress=gr.Progress()):
    """Generate a Darija continuation of ``prompt`` and persist the run.

    Returns a ``(generated_text, status_message)`` pair for the two Gradio
    outputs. Empty prompts short-circuit with an Arabic error message.
    """
    if not prompt.strip():
        logger.warning("Empty prompt submitted")
        return "", "الرجاء إدخال نص للتوليد (Please enter text to generate)"

    logger.info(f"Generating text for prompt: '{prompt[:50]}...' (length: {len(prompt)})")
    logger.info(
        f"Parameters: max_length={max_length}, temp={temperature}, top_p={top_p}, "
        f"top_k={top_k}, beams={num_beams}, rep_penalty={repetition_penalty}"
    )

    start_time = time.time()

    progress(0, desc="تجهيز النموذج (Preparing model)")

    # Tokenize input on the model's device.
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    progress(0.1, desc="تحليل النص (Tokenizing)")

    progress(0.2, desc="توليد النص (Generating text)")
    # NOTE(review): max_length bounds prompt + continuation combined; switching
    # to max_new_tokens would bound only the continuation, but would change
    # behavior for existing users, so it is kept as-is.
    output = model.generate(
        **inputs,
        max_length=max_length,
        temperature=temperature,
        top_p=top_p,
        do_sample=True,
        repetition_penalty=repetition_penalty,
        # Beam search is slow on a small Space GPU: fall back to a single beam
        # whenever the user requests more than 4.
        num_beams=1 if num_beams > 4 else num_beams,
        top_k=top_k,
        early_stopping=True,
        pad_token_id=tokenizer.pad_token_id,
        eos_token_id=tokenizer.eos_token_id,
        use_cache=True,
    )

    progress(0.9, desc="معالجة النتائج (Processing results)")
    result = tokenizer.decode(output[0], skip_special_tokens=True)

    generation_time = time.time() - start_time
    token_count = len(output[0])  # includes the prompt tokens

    # scheduler.lock is a plain (non-reentrant) threading.Lock; hold it only
    # for the counter update and release it before save_feedback() re-acquires
    # it, otherwise this thread would deadlock on itself.
    with scheduler.lock:
        usage_stats["total_generations"] += 1
        usage_stats["total_tokens_generated"] += token_count

    logger.info(f"Generated {token_count} tokens in {generation_time:.2f}s")
    logger.info(f"Result: '{result[:50]}...' (length: {len(result)})")

    # Persist the prompt/result pair with full generation metadata.
    save_feedback(
        prompt,
        result,
        {
            "max_length": max_length,
            "temperature": temperature,
            "top_p": top_p,
            "top_k": top_k,
            "num_beams": num_beams,
            "repetition_penalty": repetition_penalty,
            "generation_time": generation_time,
            "token_count": token_count,
            "timestamp": datetime.now().isoformat(),
        },
    )

    progress(1.0, desc="اكتمل (Complete)")
    return result, f"تم توليد {token_count} رمز في {generation_time:.2f} ثانية (Generated {token_count} tokens in {generation_time:.2f} seconds)"


def save_feedback(input_text, output_text, params) -> None:
    """Append one generation record to the JSON Lines feedback file.

    Uses the CommitScheduler's lock to avoid concurrent writes from different
    users; the scheduler later pushes the folder to the dataset repo.
    """
    logger.info(f"Saving feedback to {feedback_file}")
    with scheduler.lock:
        try:
            # utf-8 + ensure_ascii=False keep the Arabic text readable on disk
            # instead of \uXXXX escapes.
            with feedback_file.open("a", encoding="utf-8") as f:
                f.write(json.dumps({
                    "input": input_text,
                    "output": output_text,
                    "params": params,
                }, ensure_ascii=False))
                f.write("\n")
            logger.info("Feedback saved successfully")
        except Exception as e:
            logger.error(f"Error saving feedback: {str(e)}")


def get_stats():
    """Return current usage statistics as a display-ready dict."""
    with scheduler.lock:
        uptime = time.time() - usage_stats["start_time"]
        hours = uptime / 3600
        stats = {
            "Total generations": usage_stats["total_generations"],
            "Total tokens generated": usage_stats["total_tokens_generated"],
            "Uptime": f"{int(hours)}h {int((hours % 1) * 60)}m",
            "Generations per hour": f"{usage_stats['total_generations'] / hours:.1f}" if hours > 0 else "N/A",
            "Last updated": datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
        }
    logger.info(f"Stats requested: {stats}")
    return stats


def reset_params():
    """Reset generation parameters to defaults tuned for fast generation."""
    logger.info("Parameters reset to defaults")
    # (max_length, temperature, top_p, top_k, num_beams, repetition_penalty)
    return 128, 0.7, 0.9, 50, 1, 1.2


def _write_rating(filename, record, rating):
    """Append a thumbs-up/down record to user_submit/<filename> (JSON Lines).

    Shared by both rating callbacks; writes under the scheduler lock so the
    file is never written concurrently with a scheduled commit.
    """
    feedback_path = Path("user_submit") / filename
    feedback_path.parent.mkdir(exist_ok=True, parents=True)
    with scheduler.lock:
        try:
            with feedback_path.open("a", encoding="utf-8") as f:
                f.write(json.dumps(record, ensure_ascii=False))
                f.write("\n")
            logger.info(f"{rating.capitalize()} feedback saved to {feedback_path}")
        except Exception as e:
            logger.error(f"Error saving {rating} feedback: {str(e)}")


def thumbs_up_callback(input_text, output_text):
    """Record positive feedback for the last generation."""
    logger.info("Received positive feedback")
    _write_rating(
        "positive_feedback.jsonl",
        {
            "input": input_text,
            "output": output_text,
            "rating": "positive",
            "timestamp": datetime.now().isoformat(),
        },
        "positive",
    )
    return "شكرا على التقييم الإيجابي!"


def thumbs_down_callback(input_text, output_text, feedback=""):
    """Record negative feedback, optionally with the user's explanation."""
    logger.info(f"Received negative feedback: '{feedback}'")
    _write_rating(
        "negative_feedback.jsonl",
        {
            "input": input_text,
            "output": output_text,
            "rating": "negative",
            "feedback": feedback,
            "timestamp": datetime.now().isoformat(),
        },
        "negative",
    )
    return "شكرا على ملاحظاتك!"


if __name__ == "__main__":
    logger.info("Starting Moroccan Darija LLM application")

    # Create the Gradio interface.
    with gr.Blocks(css="""
        footer {visibility: hidden}
        .center-text {text-align: center; margin: 0 auto; max-width: 900px;}
        .header-text {font-size: 2.5rem; font-weight: bold; margin-bottom: 0.5rem;}
        .subheader-text {font-size: 1.2rem; margin-bottom: 2rem;}
        .flag-emoji {font-size: 3rem;}
    """) as app:
        with gr.Row(elem_classes=["center-text"]):
            gr.Markdown("""
            # 🇲🇦🇲🇦🇲🇦
            # Al-Atlas-0.5B-base
            """)
        with gr.Row():
            gr.Markdown("""
            This is a pretrained model to do text generation in a continuation of text fashion. Do not expect it to behave as a Chat (Instruct) model. The latter is coming soon!
            This model is pretrained on Moroccan darija in **Arabic scripts** (mainly).
            """)

        with gr.Row():
            with gr.Column(scale=6):
                prompt_input = gr.Textbox(
                    label="Prompt ",
                    placeholder="اكتب هنا...",
                    lines=4,
                    rtl=True,
                )
                with gr.Row():
                    submit_btn = gr.Button("Generate", variant="primary")
                    clear_btn = gr.Button("Clear")
                    reset_btn = gr.Button("Reset Parameters")
                with gr.Accordion("Generation Parameters", open=False):
                    with gr.Row():
                        with gr.Column():
                            max_length = gr.Slider(8, 4096, value=128, label="Max Length")  # Reduced default
                            temperature = gr.Slider(0.0, 2, value=0.7, label="Temperature")
                            top_p = gr.Slider(0.0, 1.0, value=0.9, label="Top-p")
                        with gr.Column():
                            top_k = gr.Slider(1, 10000, value=50, label="Top-k")  # Reduced default
                            num_beams = gr.Slider(1, 20, value=1, label="Number of Beams")  # Reduced default
                            repetition_penalty = gr.Slider(0.0, 100.0, value=1.2, label="Repetition Penalty")  # Reduced default

            with gr.Column(scale=6):
                output_text = gr.Textbox(label="Generated Text", lines=10, rtl=True)
                generation_info = gr.Markdown("")
                with gr.Row():
                    thumbs_up = gr.Button("👍 ناضي")
                    thumbs_down = gr.Button("👎 عيان")
                # NOTE(review): this accordion is created invisible and nothing
                # toggles its visibility — the free-text feedback path appears
                # unreachable from the UI; confirm intent.
                with gr.Accordion("Feedback", open=False, visible=False) as feedback_accordion:
                    feedback_text = gr.Textbox(label="Why didn't you like the output?", lines=2, rtl=True)
                    submit_feedback = gr.Button("Submit Feedback")
                feedback_result = gr.Markdown("")

        with gr.Accordion("Usage Statistics", open=False):
            stats_md = gr.JSON(get_stats, every=10)
            refresh_stats = gr.Button("Refresh")

        # Examples section with caching.
        gr.Examples(
            examples=examples,
            inputs=[prompt_input, max_length, temperature, top_p, top_k, num_beams, repetition_penalty],
            outputs=[output_text, generation_info],
            fn=generate_text,
            cache_examples=True,
        )

        # Button actions.
        submit_btn.click(
            generate_text,
            inputs=[prompt_input, max_length, temperature, top_p, top_k, num_beams, repetition_penalty],
            outputs=[output_text, generation_info],
        )
        clear_btn.click(
            lambda: ("", ""),
            inputs=None,
            outputs=[prompt_input, output_text],
        )
        reset_btn.click(
            reset_params,
            inputs=None,
            outputs=[max_length, temperature, top_p, top_k, num_beams, repetition_penalty],
        )

        # Feedback system.
        thumbs_up.click(
            thumbs_up_callback,
            inputs=[prompt_input, output_text],
            outputs=[feedback_result],
        )
        thumbs_down.click(
            thumbs_down_callback,
            inputs=[prompt_input, output_text],
            outputs=[feedback_result],
        )
        submit_feedback.click(
            thumbs_down_callback,
            inputs=[prompt_input, output_text, feedback_text],
            outputs=[feedback_result],
        )

        # Stats refresh.
        refresh_stats.click(
            get_stats,
            inputs=None,
            outputs=[stats_md],
        )

        # Keyboard shortcut: Enter in the prompt box triggers generation.
        prompt_input.submit(
            generate_text,
            inputs=[prompt_input, max_length, temperature, top_p, top_k, num_beams, repetition_penalty],
            outputs=[output_text, generation_info],
        )

    logger.info("Launching Gradio interface")
    app.launch()
    logger.info("Gradio interface closed")