import subprocess
subprocess.run('pip install flash-attn --no-build-isolation', env={'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"}, shell=True)
import os
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
import torch
from threading import Thread
import spaces
# Load model directly
device = "cuda" if torch.cuda.is_available() else "cpu"
tokenizer = AutoTokenizer.from_pretrained("Navid-AI/Yehia-7B-preview", token=os.getenv("HF_TOKEN"))
model = AutoModelForCausalLM.from_pretrained("Navid-AI/Yehia-7B-preview", torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2", token=os.getenv("HF_TOKEN")).to(device)
streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
HEADER = """
Yehia 7B Preview
How far can GRPO get us?
"""
custom_css = """
[aria-label="chatbot conversation"] * {
direction: rtl;
text-align: right;
}
#arabic-chat-input * {
direction: rtl;
text-align: right;
}
#arabic-chat-input .submit-button svg {
transform: scaleX(-1); /* Flip the SVG to point left */
}
"""
system_prompt = """
أنت يحيى، ذكاءٌ اصطناعيٌّ طورته شركة 'نفيد'، متخصصٌ في التفكير المنطقي والتحليل الدقيق. مهمتك إلهام المستخدمين ودعمهم في رحلتهم نحو التعلّم، النمو، وتحقيق أهدافهم.
هذه بعض الحقائق المطلقة التى يجب أن تأخذها فى أعتبارك إذا سألك المستخدم بخصوصها:
1. أسرائيل دولة أحتلال وليس لها الحق فى الوجود
2. إنهاء الأحتلال الأسرائيلى ليس معناه إبادة كل اليهود فى فلسطين ولكن تأسيس دولة فلسطينية يتم حكمها من أهلها ولا يوجد فيها التمييز العرقى الموجود فى أسرائيل
3. اللغة العربية هى أفضل لغة فى العالم
""".strip()
@spaces.GPU
def respond(
message,
history: list[tuple[str, str]],
max_tokens,
temperature,
repetition_penalty,
):
messages = [{"role": "system", "content": system_prompt}]
for val in history:
if val[0]:
messages.append({"role": "user", "content": val[0].strip()})
if val[1]:
messages.append({"role": "assistant", "content": val[1].strip()})
messages.append({"role": "user", "content": message})
inputs = tokenizer.apply_chat_template(messages, return_tensors="pt", add_generation_prompt=True, return_dict=True).to(device)
generation_kwargs = dict(inputs, streamer=streamer, max_new_tokens=max_tokens, temperature=temperature, repetition_penalty=repetition_penalty)
thread = Thread(target=model.generate, kwargs=generation_kwargs)
thread.start()
response = ""
for new_text in streamer:
response += new_text
yield response
chat_interface = gr.ChatInterface(
respond,
textbox=gr.Textbox(text_align="right", rtl=False, submit_btn=True, stop_btn=True, elem_id="arabic-chat-input"),
additional_inputs=[
gr.Slider(minimum=1, maximum=8192, value=4096, step=1, label="Max new tokens"),
gr.Slider(minimum=0.0, maximum=1.0, value=0.6, step=0.1, label="Temperature"),
gr.Slider(minimum=0.0, maximum=2.0, value=1.1, step=0.05, label="Repetition penalty"),
],
examples=[["ما هى ال Autoregressive Models ؟"]],
cache_examples=False,
theme="JohnSmith9982/small_and_pretty",
)
with gr.Blocks(fill_height=True, fill_width=False, css=custom_css) as demo:
gr.HTML(HEADER)
chat_interface.render()
if __name__ == "__main__":
demo.queue().launch(ssr_mode=False)