# Al-Atlas-LLM / app.py
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer
import os
import spaces
import torch
from datasets import load_dataset
from huggingface_hub import CommitScheduler
from pathlib import Path
import uuid
import json
import time
from datetime import datetime
import logging
# Configure logging
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
handlers=[
logging.FileHandler("app.log"),
logging.StreamHandler()
]
)
logger = logging.getLogger("darija-llm")
device = "cuda:0" if torch.cuda.is_available() else "cpu"
logger.info(f'Using device: {device}')
# Hugging Face access token (used for the model download and dataset commits)
token = os.environ['TOKEN']
# Load the pretrained model and tokenizer
MODEL_NAME = "atlasia/Al-Atlas-0.5B"
logger.info(f"Loading model: {MODEL_NAME}")
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, token=token)
model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, token=token).to(device)
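# Note (assumption about the deployment target): on ZeroGPU Spaces the `spaces`
# package patches CUDA at import time, so the eager `.to(device)` above is safe;
# weights are only attached to a physical GPU while a @spaces.GPU call is running.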
logger.info("Model loaded successfully")
# Fix tokenizer padding
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token # Set pad token
logger.info("Set pad_token to eos_token")
# Predefined examples: Darija prompts plus default generation parameters
# (English glosses in the comments)
examples = [
    # "Artificial intelligence is a branch of computer science that focuses on..."
    ["الذكاء الاصطناعي هو فرع من علوم الكمبيوتر اللي كيركز", 256, 0.7, 0.9, 100, 4, 1.5],
    # "The future of artificial intelligence in Morocco"
    ["المستقبل ديال الذكاء الصناعي فالمغرب", 256, 0.7, 0.9, 100, 4, 1.5],
    # "Moroccan cuisine"
    [" المطبخ المغربي", 256, 0.7, 0.9, 100, 4, 1.5],
    # "Moroccan food is considered among the best cuisines in the world"
    ["الماكلة المغربية كتعتبر من أحسن الماكلات فالعالم", 256, 0.7, 0.9, 100, 4, 1.5],
]
# File where user submissions are appended (JSON Lines, one record per line)
submit_file = Path("user_submit/") / f"data_{uuid.uuid4()}.json"
feedback_file = submit_file
# Create directory if it doesn't exist
submit_file.parent.mkdir(exist_ok=True, parents=True)
logger.info(f"Created feedback file: {feedback_file}")
scheduler = CommitScheduler(
repo_id="atlasia/atlaset_inference_ds",
repo_type="dataset",
folder_path=submit_file.parent,
path_in_repo="data",
every=5,
token=token
)
logger.info(f"Initialized CommitScheduler for repo: atlasia/atlaset_inference_ds")
# Track usage statistics
usage_stats = {
"total_generations": 0,
"total_tokens_generated": 0,
"start_time": time.time()
}
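# @spaces.GPU requests a GPU allocation for the duration of each call, so the
# heavy generation work below runs on GPU even though the Space idles on CPU.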
@spaces.GPU
def generate_text(prompt, max_length=256, temperature=0.7, top_p=0.9, top_k=150, num_beams=8, repetition_penalty=1.5, progress=gr.Progress()):
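    """Generate a Darija continuation of `prompt` with the given sampling parameters.

    Returns a (generated_text, status_message) tuple bound to the Gradio outputs.
    """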
if not prompt.strip():
logger.warning("Empty prompt submitted")
return "", "الرجاء إدخال نص للتوليد (Please enter text to generate)"
logger.info(f"Generating text for prompt: '{prompt[:50]}...' (length: {len(prompt)})")
logger.info(f"Parameters: max_length={max_length}, temp={temperature}, top_p={top_p}, top_k={top_k}, beams={num_beams}, rep_penalty={repetition_penalty}")
start_time = time.time()
# Start progress
progress(0, desc="تجهيز النموذج (Preparing model)")
# Tokenize input
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
progress(0.1, desc="تحليل النص (Tokenizing)")
# Generate text with optimized parameters for speed
progress(0.2, desc="توليد النص (Generating text)")
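    # max_length caps the *total* sequence (prompt + new tokens). With do_sample=True,
    # temperature flattens or sharpens the distribution, top_p keeps the smallest
    # nucleus of tokens whose probabilities sum to top_p, top_k truncates to the k
    # most likely tokens, and repetition_penalty down-weights already-seen tokens.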
output = model.generate(
**inputs,
max_length=max_length,
temperature=temperature,
top_p=top_p,
do_sample=True,
repetition_penalty=repetition_penalty,
        num_beams=1 if num_beams > 4 else num_beams,  # cap cost: more than 4 beams falls back to single-beam sampling
top_k=top_k,
early_stopping=True,
pad_token_id=tokenizer.pad_token_id,
eos_token_id=tokenizer.eos_token_id,
use_cache=True, # Ensure cache is used
)
# Decode output
progress(0.9, desc="معالجة النتائج (Processing results)")
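    # decode() runs over the full sequence, so the result echoes the prompt followed
    # by the continuation (expected behavior for a base/continuation model).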
result = tokenizer.decode(output[0], skip_special_tokens=True)
# Update stats
generation_time = time.time() - start_time
    token_count = len(output[0]) - inputs["input_ids"].shape[1]  # newly generated tokens only
with scheduler.lock:
usage_stats["total_generations"] += 1
usage_stats["total_tokens_generated"] += token_count
logger.info(f"Generated {token_count} tokens in {generation_time:.2f}s")
logger.info(f"Result: '{result[:50]}...' (length: {len(result)})")
# Save feedback with additional metadata
save_feedback(
prompt,
result,
{
"max_length": max_length,
"temperature": temperature,
"top_p": top_p,
"top_k": top_k,
"num_beams": num_beams,
"repetition_penalty": repetition_penalty,
"generation_time": generation_time,
"token_count": token_count,
"timestamp": datetime.now().isoformat()
}
)
progress(1.0, desc="اكتمل (Complete)")
return result, f"تم توليد {token_count} رمز في {generation_time:.2f} ثانية (Generated {token_count} tokens in {generation_time:.2f} seconds)"
def save_feedback(input_text, output_text, params) -> None:
    """
    Append inputs/outputs and generation parameters to a JSON Lines file, holding
    the scheduler's thread lock to avoid concurrent writes from different users.
    """
    logger.info(f"Saving feedback to {feedback_file}")
    with scheduler.lock:
        try:
            with feedback_file.open("a") as f:
                f.write(json.dumps({
                    "input": input_text,
                    "output": output_text,
                    "params": params
                }))
f.write("\n")
logger.info("Feedback saved successfully")
except Exception as e:
logger.error(f"Error saving feedback: {str(e)}")
def get_stats():
"""Return current usage statistics"""
with scheduler.lock:
uptime = time.time() - usage_stats["start_time"]
hours = uptime / 3600
stats = {
"Total generations": usage_stats["total_generations"],
"Total tokens generated": usage_stats["total_tokens_generated"],
"Uptime": f"{int(hours)}h {int((hours % 1) * 60)}m",
"Generations per hour": f"{usage_stats['total_generations'] / hours:.1f}" if hours > 0 else "N/A",
"Last updated": datetime.now().strftime("%Y-%m-%d %H:%M:%S")
}
logger.info(f"Stats requested: {stats}")
return stats
def reset_params():
"""Reset parameters to default values"""
logger.info("Parameters reset to defaults")
return 128, 0.7, 0.9, 50, 1, 1.2 # Updated defaults for faster generation
def thumbs_up_callback(input_text, output_text):
"""Record positive feedback"""
logger.info("Received positive feedback")
feedback_path = Path("user_submit") / "positive_feedback.jsonl"
feedback_path.parent.mkdir(exist_ok=True, parents=True)
with scheduler.lock:
try:
with feedback_path.open("a") as f:
feedback_data = {
"input": input_text,
"output": output_text,
"rating": "positive",
"timestamp": datetime.now().isoformat()
}
f.write(json.dumps(feedback_data))
f.write("\n")
logger.info(f"Positive feedback saved to {feedback_path}")
except Exception as e:
logger.error(f"Error saving positive feedback: {str(e)}")
return "شكرا على التقييم الإيجابي!"
def thumbs_down_callback(input_text, output_text, feedback=""):
"""Record negative feedback"""
logger.info(f"Received negative feedback: '{feedback}'")
feedback_path = Path("user_submit") / "negative_feedback.jsonl"
feedback_path.parent.mkdir(exist_ok=True, parents=True)
with scheduler.lock:
try:
with feedback_path.open("a") as f:
feedback_data = {
"input": input_text,
"output": output_text,
"rating": "negative",
"feedback": feedback,
"timestamp": datetime.now().isoformat()
}
f.write(json.dumps(feedback_data))
f.write("\n")
logger.info(f"Negative feedback saved to {feedback_path}")
except Exception as e:
logger.error(f"Error saving negative feedback: {str(e)}")
return "شكرا على ملاحظاتك!"
if __name__ == "__main__":
logger.info("Starting Moroccan Darija LLM application")
# Create the Gradio interface
with gr.Blocks(css="""
footer {visibility: hidden}
.center-text {text-align: center; margin: 0 auto; max-width: 900px;}
.header-text {font-size: 2.5rem; font-weight: bold; margin-bottom: 0.5rem;}
.subheader-text {font-size: 1.2rem; margin-bottom: 2rem;}
.flag-emoji {font-size: 3rem;}
""") as app:
with gr.Row(elem_classes=["center-text"]):
gr.Markdown("""
# 🇲🇦🇲🇦🇲🇦
# Al-Atlas-0.5B-base
""")
with gr.Row():
gr.Markdown("""
This is a pretrained model to do text generation in a continuation of text fashion. Do not expect it to behave as a Chat (Instruct) model. The latter is coming soon!
This model is pretrained on Moroccan darija in **Arabic scripts** (mainly).
""")
with gr.Row():
with gr.Column(scale=6):
prompt_input = gr.Textbox(
label="Prompt ",
placeholder="اكتب هنا...",
lines=4, rtl=True
)
with gr.Row():
submit_btn = gr.Button("Generate", variant="primary")
clear_btn = gr.Button("Clear")
reset_btn = gr.Button("Reset Parameters")
with gr.Accordion("Generation Parameters", open=False):
with gr.Row():
with gr.Column():
max_length = gr.Slider(8, 4096, value=128, label="Max Length") # Reduced default
temperature = gr.Slider(0.0, 2, value=0.7, label="Temperature")
top_p = gr.Slider(0.0, 1.0, value=0.9, label="Top-p")
with gr.Column():
top_k = gr.Slider(1, 10000, value=50, label="Top-k") # Reduced default
num_beams = gr.Slider(1, 20, value=1, label="Number of Beams") # Reduced default
repetition_penalty = gr.Slider(0.0, 100.0, value=1.2, label="Repetition Penalty") # Reduced default
with gr.Column(scale=6):
output_text = gr.Textbox(label="Generated Text", lines=10, rtl=True)
generation_info = gr.Markdown("")
with gr.Row():
                    thumbs_up = gr.Button("👍 ناضي")    # "ناضي" ≈ "great" in Darija
                    thumbs_down = gr.Button("👎 عيان")  # "عيان" ≈ "bad" in Darija
with gr.Accordion("Feedback", open=False, visible=False) as feedback_accordion:
feedback_text = gr.Textbox(label="Why didn't you like the output?", lines=2, rtl=True)
submit_feedback = gr.Button("Submit Feedback")
feedback_result = gr.Markdown("")
with gr.Accordion("Usage Statistics", open=False):
stats_md = gr.JSON(get_stats, every=10)
refresh_stats = gr.Button("Refresh")
# Examples section with caching
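        # cache_examples=True pre-runs generate_text on each example at startup and
        # serves the cached outputs, so clicking an example costs no GPU time.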
gr.Examples(
examples=examples,
inputs=[prompt_input, max_length, temperature, top_p, top_k, num_beams, repetition_penalty],
outputs=[output_text, generation_info],
fn=generate_text,
cache_examples=True
)
# Button actions
submit_btn.click(
generate_text,
inputs=[prompt_input, max_length, temperature, top_p, top_k, num_beams, repetition_penalty],
outputs=[output_text, generation_info]
)
clear_btn.click(
lambda: ("", ""),
inputs=None,
outputs=[prompt_input, output_text]
)
reset_btn.click(
reset_params,
inputs=None,
outputs=[max_length, temperature, top_p, top_k, num_beams, repetition_penalty]
)
# Feedback system
thumbs_up.click(
thumbs_up_callback,
inputs=[prompt_input, output_text],
outputs=[feedback_result]
)
        thumbs_down.click(
            thumbs_down_callback,
            inputs=[prompt_input, output_text],
            outputs=[feedback_result]
        ).then(
            # Reveal the initially hidden feedback accordion so users can say why,
            # otherwise the "Submit Feedback" flow wired below is unreachable
            lambda: gr.update(visible=True),
            inputs=None,
            outputs=[feedback_accordion]
        )
submit_feedback.click(
thumbs_down_callback,
inputs=[prompt_input, output_text, feedback_text],
outputs=[feedback_result]
)
# Stats refresh
refresh_stats.click(
get_stats,
inputs=None,
outputs=[stats_md]
)
# Keyboard shortcuts
prompt_input.submit(
generate_text,
inputs=[prompt_input, max_length, temperature, top_p, top_k, num_beams, repetition_penalty],
outputs=[output_text, generation_info]
)
logger.info("Launching Gradio interface")
app.launch()
logger.info("Gradio interface closed")