import os
from datetime import date

import gradio as gr
import openai

# Model configuration dictionary
MODEL_CONFIGS = {
    "Falcon-H1-34B-Instruct": {
        "model_id": "tiiuae/Falcon-H1-34B-Instruct",
        "api_key_env": "XXL_API_KEY",
        "base_url_env": "XXL_URL",
        "description": "XXL (34B)",
    },
    "Falcon-H1-7B-Instruct": {
        "model_id": "tiiuae/Falcon-H1-7B-Instruct",
        "api_key_env": "L_API_KEY",
        "base_url_env": "L_URL",
        "description": "L (7B)",
    },
    "Falcon-H1-3B-Instruct": {
        "model_id": "tiiuae/Falcon-H1-3B-Instruct",
        "api_key_env": "M_API_KEY",
        "base_url_env": "M_URL",
        "description": "M (3B)",
    },
    "Falcon-H1-1.5B-Deep-Instruct": {
        "model_id": "tiiuae/Falcon-H1-1.5B-Deep-Instruct",
        "api_key_env": "S_API_KEY",
        "base_url_env": "S_URL",
        "description": "S (1.5B Deep)",
    },
    "Falcon-H1-1.5B-Instruct": {
        "model_id": "tiiuae/Falcon-H1-1.5B-Instruct",
        "api_key_env": "XS_API_KEY",
        "base_url_env": "XS_URL",
        "description": "XS (1.5B)",
    },
    "Falcon-H1-0.5B-Instruct": {
        "model_id": "tiiuae/Falcon-H1-0.5B-Instruct",
        "api_key_env": "XXS_API_KEY",
        "base_url_env": "XXS_URL",
        "description": "XXS (0.5B)",
    },
}
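# Note: each entry above pairs a checkpoint's model ID with the names of the
# environment variables (e.g. XXL_API_KEY / XXL_URL) that hold its endpoint
# credentials; stream_chat() reads them at request time.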

today = date.today()

# System prompt comes from the environment; fall back to a generic default so
# the first message never has None content.
SYSTEM_PROMPT = os.getenv("SYSTEM_PROMPT", "You are a helpful assistant.")

# CSS styling for the interface
CSS = """ | |
/* Main style improvements */ | |
.container { | |
max-width: 900px !important; | |
margin-left: auto !important; | |
margin-right: auto !important; | |
} | |
/* Title styling */ | |
h1 { | |
background: linear-gradient(90deg, #4776E6 0%, #8E54E9 100%); | |
-webkit-background-clip: text; | |
-webkit-text-fill-color: transparent; | |
font-weight: 700 !important; | |
text-align: center; | |
margin-bottom: 0.5rem !important; | |
} | |
.subtitle { | |
text-align: center; | |
color: #666; | |
margin-bottom: 1rem; | |
} | |
/* Button styling */ | |
.duplicate-button { | |
margin: 1rem auto !important; | |
display: block !important; | |
color: #fff !important; | |
background: linear-gradient(90deg, #4776E6 0%, #8E54E9 100%) !important; | |
border-radius: 100vh !important; | |
padding: 0.5rem 1.5rem !important; | |
font-weight: 600 !important; | |
border: none !important; | |
box-shadow: 0 4px 6px rgba(50, 50, 93, 0.11), 0 1px 3px rgba(0, 0, 0, 0.08) !important; | |
} | |
/* Parameter accordion styling */ | |
.accordion { | |
border-radius: 8px !important; | |
overflow: hidden !important; | |
box-shadow: 0 1px 3px rgba(0,0,0,0.1) !important; | |
margin-bottom: 1rem !important; | |
} | |
h3 { | |
text-align: center; | |
} | |
""" | |
TITLE = "<h1>Falcon-H1 Playground</h1>" | |
SUB_TITLE = """ | |
<p class='subtitle'>Falcon-H1 is a new SoTA hybrid model by TII in Abu Dhabi. It is open source and available on Hugging Face. This demo is powered by <a href="https://openinnovation.ai">OpenInnovationAI</a>. Try out our <a href="https://chat.falconllm.tii.ae/auth">chat interface</a>.</p> | |
<p class='subtitle' style='font-size: 0.9rem; color: #888;'></p> | |
""" | |

def stream_chat(
    message: str,
    history: list,
    model_name: str,
    temperature: float = 0.1,
    max_new_tokens: int = 1024,
    top_p: float = 1.0,
    top_k: int = 20,
    presence_penalty: float = 1.2,
):
    """Chat function that streams responses from the selected model."""
    cfg = MODEL_CONFIGS[model_name]
    api_key = os.getenv(cfg["api_key_env"])
    base_url = os.getenv(cfg["base_url_env"]) if cfg.get("base_url_env") else None

    if not api_key:
        yield f"❌ Env-var `{cfg['api_key_env']}` not set."
        return
    if cfg.get("base_url_env") and not base_url:
        yield f"❌ Env-var `{cfg['base_url_env']}` not set."
        return
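
    # One OpenAI-compatible client per request, pointed at the selected endpoint.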
    client = openai.OpenAI(api_key=api_key, base_url=base_url)

    # Rebuild the conversation as an OpenAI-style message list.
    msgs = [{"role": "system", "content": SYSTEM_PROMPT}]
    for u, a in history:
        msgs += [{"role": "user", "content": u},
                 {"role": "assistant", "content": a}]
    msgs.append({"role": "user", "content": message})
    try:
        stream = client.chat.completions.create(
            model=cfg["model_id"],
            messages=msgs,
            temperature=temperature,
            top_p=top_p,
            max_tokens=max_new_tokens,
            presence_penalty=presence_penalty,
            # top_k is not part of the standard OpenAI API; vLLM-style
            # OpenAI-compatible servers accept it through extra_body.
            extra_body={"top_k": top_k},
            stream=True,
        )
        partial = ""
        for chunk in stream:
            # Some backends emit chunks with empty choices; skip those.
            if chunk.choices and (delta := chunk.choices[0].delta).content:
                partial += delta.content
                yield partial
    except Exception as e:
        yield f"❌ Error: {e}"

# Create the Gradio interface
with gr.Blocks(css=CSS, theme="soft") as demo:
    # Header section
    gr.HTML(TITLE)
    gr.HTML(SUB_TITLE.format(today_date=today.strftime("%B %d, %Y")))
    gr.DuplicateButton(value="Duplicate Space", elem_classes="duplicate-button")
    # Create model selection with descriptions
    model_options = list(MODEL_CONFIGS.keys())
    with gr.Row():
        model_dropdown = gr.Dropdown(
            choices=model_options,
            value=model_options[0],
            label="Select Falcon-H1 Model",
            info="Choose which model checkpoint to use",
        )
    # Create chatbot with LaTeX rendering for inline and display math
    chatbot = gr.Chatbot(
        height=600,
        latex_delimiters=[
            {"left": "$$", "right": "$$", "display": True},     # Display-mode math
            {"left": "$", "right": "$", "display": False},      # Inline math
            {"left": "\\(", "right": "\\)", "display": False},  # Common inline delimiters
            {"left": "\\[", "right": "\\]", "display": True},   # Common display delimiters
        ],
    )
    # Message input area with a cleaner layout (17:3 split between textbox and button)
    with gr.Row():
        with gr.Column(scale=17):
            msg = gr.Textbox(
                show_label=False,
                placeholder="Enter text and press enter",
                container=False,
            )
        with gr.Column(scale=3, min_width=0):
            submit_btn = gr.Button("Submit", variant="primary")
    # Generation parameters in a collapsible accordion
    with gr.Accordion("⚙️ Parameters", open=False, elem_classes="accordion"):
        temperature = gr.Slider(minimum=0, maximum=1, value=0.1, step=0.05, label="Temperature", info="Higher values produce more diverse outputs")
        max_new_tokens = gr.Slider(minimum=64, maximum=32768, value=1024, step=64, label="Max new tokens", info="Maximum length of generated response")
        top_p = gr.Slider(minimum=0, maximum=1, value=1.0, step=0.05, label="top_p", info="1.0 means no filtering")
        top_k = gr.Slider(minimum=1, maximum=20, step=1, value=20, label="top_k")
        presence_penalty = gr.Slider(minimum=0, maximum=2, value=1.2, step=0.1, label="Presence penalty", info="Penalizes repetition")
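
    # Note: these sliders feed stream_chat(); top_k only takes effect if the
    # backend accepts it through extra_body (see stream_chat above).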
    # Examples section
    gr.Examples(
        examples=[
            ["Hello there, can you suggest a few places to visit in UAE?"],
            ["What is UAE known for?"],
        ],
        inputs=msg,
    )
    # Chat handler functions
    def user(user_message, history):
        # Append the new user turn with an empty assistant slot and clear the textbox.
        return "", history + [[user_message, None]]

    def bot(history, model_choice, temp, max_tokens, top_p_val, top_k_val, penalty):
        user_message = history[-1][0]
        history[-1][1] = ""
        # Stream the reply, updating the last assistant turn as tokens arrive.
        for partial_response in stream_chat(
            user_message,
            history[:-1],
            model_choice,
            temp,
            max_tokens,
            top_p_val,
            top_k_val,
            penalty,
        ):
            history[-1][1] = partial_response
            yield history
    # Set up event handlers
    msg.submit(user, [msg, chatbot], [msg, chatbot], queue=False).then(
        bot, [chatbot, model_dropdown, temperature, max_new_tokens, top_p, top_k, presence_penalty], chatbot
    )
    submit_btn.click(user, [msg, chatbot], [msg, chatbot], queue=False).then(
        bot, [chatbot, model_dropdown, temperature, max_new_tokens, top_p, top_k, presence_penalty], chatbot
    )

if __name__ == "__main__":
    demo.launch()