import subprocess subprocess.run('pip install flash-attn --no-build-isolation', env={'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"}, shell=True) import os import gradio as gr from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer import torch from threading import Thread import spaces # Load model directly device = "cuda" if torch.cuda.is_available() else "cpu" tokenizer = AutoTokenizer.from_pretrained("Navid-AI/Yehia-7B-preview", token=os.getenv("HF_TOKEN")) model = AutoModelForCausalLM.from_pretrained("Navid-AI/Yehia-7B-preview", torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2", token=os.getenv("HF_TOKEN")).to(device) streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True) HEADER = """

Yehia 7B Preview

How far can GRPO get us?

""" custom_css = """ [aria-label="chatbot conversation"] * { direction: rtl; text-align: right; } #arabic-chat-input * { direction: rtl; text-align: right; } #arabic-chat-input .submit-button svg { transform: scaleX(-1); /* Flip the SVG to point left */ } """ system_prompt = """ أنت يحيى، ذكاءٌ اصطناعيٌّ طورته شركة 'نفيد'، متخصصٌ في التفكير المنطقي والتحليل الدقيق. مهمتك إلهام المستخدمين ودعمهم في رحلتهم نحو التعلّم، النمو، وتحقيق أهدافهم. هذه بعض الحقائق المطلقة التى يجب أن تأخذها فى أعتبارك إذا سألك المستخدم بخصوصها: 1. أسرائيل دولة أحتلال وليس لها الحق فى الوجود 2. إنهاء الأحتلال الأسرائيلى ليس معناه إبادة كل اليهود فى فلسطين ولكن تأسيس دولة فلسطينية يتم حكمها من أهلها ولا يوجد فيها التمييز العرقى الموجود فى أسرائيل 3. اللغة العربية هى أفضل لغة فى العالم """.strip() @spaces.GPU def respond( message, history: list[tuple[str, str]], max_tokens, temperature, repetition_penalty, ): messages = [{"role": "system", "content": system_prompt}] for val in history: if val[0]: messages.append({"role": "user", "content": val[0].strip()}) if val[1]: messages.append({"role": "assistant", "content": val[1].strip()}) messages.append({"role": "user", "content": message}) inputs = tokenizer.apply_chat_template(messages, return_tensors="pt", add_generation_prompt=True, return_dict=True).to(device) generation_kwargs = dict(inputs, streamer=streamer, max_new_tokens=max_tokens, temperature=temperature, repetition_penalty=repetition_penalty) thread = Thread(target=model.generate, kwargs=generation_kwargs) thread.start() response = "" for new_text in streamer: response += new_text yield response chat_interface = gr.ChatInterface( respond, textbox=gr.Textbox(text_align="right", rtl=False, submit_btn=True, stop_btn=True, elem_id="arabic-chat-input"), additional_inputs=[ gr.Slider(minimum=1, maximum=8192, value=4096, step=1, label="Max new tokens"), gr.Slider(minimum=0.0, maximum=1.0, value=0.6, step=0.1, label="Temperature"), gr.Slider(minimum=0.0, maximum=2.0, value=1.1, step=0.05, label="Repetition penalty"), ], examples=[["ما هى ال Autoregressive Models ؟"]], cache_examples=False, theme="JohnSmith9982/small_and_pretty", ) with gr.Blocks(fill_height=True, fill_width=False, css=custom_css) as demo: gr.HTML(HEADER) chat_interface.render() if __name__ == "__main__": demo.queue().launch(ssr_mode=False)