# Hugging Face Space — app.py
# Last commit 9fa69a5 by Ilyasch2: "remove incorrect date"
import os
from datetime import date
import gradio as gr
import openai
# Registry of selectable Falcon-H1 checkpoints.  Each entry maps a UI-facing
# model name to:
#   model_id     -- identifier sent to the OpenAI-compatible endpoint,
#   api_key_env  -- name of the env var holding that endpoint's API key,
#   base_url_env -- name of the env var holding that endpoint's base URL,
#   description  -- short human-readable size label.
MODEL_CONFIGS = {
    "Falcon-H1-34B-Instruct": {
        "model_id": "tiiuae/Falcon-H1-34B-Instruct",
        "api_key_env": "XXL_API_KEY",
        "base_url_env": "XXL_URL",
        "description": "XXL (34B)",
    },
    "Falcon-H1-7B-Instruct": {
        "model_id": "tiiuae/Falcon-H1-7B-Instruct",
        "api_key_env": "L_API_KEY",
        "base_url_env": "L_URL",
        "description": "L (7B)",
    },
    "Falcon-H1-3B-Instruct": {
        "model_id": "tiiuae/Falcon-H1-3B-Instruct",
        "api_key_env": "M_API_KEY",
        "base_url_env": "M_URL",
        "description": "M (3B)",
    },
    "Falcon-H1-1.5B-Deep-Instruct": {
        "model_id": "tiiuae/Falcon-H1-1.5B-Deep-Instruct",
        "api_key_env": "S_API_KEY",
        "base_url_env": "S_URL",
        "description": "S (1.5B Deep)",
    },
    "Falcon-H1-1.5B-Instruct": {
        "model_id": "tiiuae/Falcon-H1-1.5B-Instruct",
        "api_key_env": "XS_API_KEY",
        "base_url_env": "XS_URL",
        "description": "XS (1.5B)",
    },
    "Falcon-H1-0.5B-Instruct": {
        "model_id": "tiiuae/Falcon-H1-0.5B-Instruct",
        "api_key_env": "XXS_API_KEY",
        "base_url_env": "XXS_URL",
        "description": "XXS (0.5B)",
    },
}

# Date at process start (no longer shown in the subtitle; see commit note).
today = date.today()

# System prompt used for every conversation.  Default to "" (not None) so a
# chat request never carries a null system-message content when the env var
# is unset -- os.getenv without a default returns None.
SYSTEM_PROMPT = os.getenv("SYSTEM_PROMPT", "")
# CSS styling more similar to the second code
# Custom stylesheet injected into the Gradio app: centered container,
# gradient title text, pill-shaped duplicate button, and accordion polish.
CSS = """
/* Main style improvements */
.container {
max-width: 900px !important;
margin-left: auto !important;
margin-right: auto !important;
}
/* Title styling */
h1 {
background: linear-gradient(90deg, #4776E6 0%, #8E54E9 100%);
-webkit-background-clip: text;
-webkit-text-fill-color: transparent;
font-weight: 700 !important;
text-align: center;
margin-bottom: 0.5rem !important;
}
.subtitle {
text-align: center;
color: #666;
margin-bottom: 1rem;
}
/* Button styling */
.duplicate-button {
margin: 1rem auto !important;
display: block !important;
color: #fff !important;
background: linear-gradient(90deg, #4776E6 0%, #8E54E9 100%) !important;
border-radius: 100vh !important;
padding: 0.5rem 1.5rem !important;
font-weight: 600 !important;
border: none !important;
box-shadow: 0 4px 6px rgba(50, 50, 93, 0.11), 0 1px 3px rgba(0, 0, 0, 0.08) !important;
}
/* Parameter accordion styling */
.accordion {
border-radius: 8px !important;
overflow: hidden !important;
box-shadow: 0 1px 3px rgba(0,0,0,0.1) !important;
margin-bottom: 1rem !important;
}
h3 {
text-align: center;
}
"""
# Page header, rendered as raw HTML.
TITLE = "<h1>Falcon-H1 Playground</h1>"
# Subtitle HTML.  NOTE(review): contains no str.format placeholders, so any
# .format(...) call on it is a no-op; the trailing empty <p> is left over
# from removing the date line (commit "remove incorrect date").
SUB_TITLE = """
<p class='subtitle'>Falcon-H1 is a new SoTA hybrid model by TII in Abu Dhabi. It is open source and available on Hugging Face. This demo is powered by <a href="https://openinnovation.ai">OpenInnovationAI</a>. Try out our <a href="https://chat.falconllm.tii.ae/auth">chat interface</a>.</p>
<p class='subtitle' style='font-size: 0.9rem; color: #888;'></p>
"""
def stream_chat(
    message: str,
    history: list,
    model_name: str,
    temperature: float = 0.1,
    max_new_tokens: int = 1024,
    top_p: float = 1.0,
    top_k: int = 20,
    presence_penalty: float = 1.2,
):
    """Stream a chat completion from the selected Falcon-H1 endpoint.

    Args:
        message: Latest user message.
        history: Prior turns as (user, assistant) pairs.
        model_name: Key into MODEL_CONFIGS selecting the backend.
        temperature: Sampling temperature forwarded to the API.
        max_new_tokens: Cap on generated tokens (sent as ``max_tokens``).
        top_p: Nucleus-sampling parameter forwarded to the API.
        top_k: Accepted for UI compatibility but NOT forwarded -- the
            OpenAI chat-completions API has no ``top_k`` parameter.
        presence_penalty: Repetition penalty forwarded to the API.

    Yields:
        The accumulated assistant reply so far (Gradio streaming style),
        or a single error string on missing config / API failure.
    """
    cfg = MODEL_CONFIGS[model_name]
    api_key = os.getenv(cfg["api_key_env"])
    base_url = os.getenv(cfg.get("base_url_env", ""), None)

    # Fail fast with a user-visible message when the deployment is missing
    # credentials, instead of erroring deep inside the API client.
    if not api_key:
        yield f"❌ Env-var `{cfg['api_key_env']}` not set."
        return
    if cfg.get("base_url_env") and not base_url:
        yield f"❌ Env-var `{cfg['base_url_env']}` not set."
        return

    client = openai.OpenAI(api_key=api_key, base_url=base_url)

    # Only include a system message when a prompt is actually configured:
    # sending content=None (SYSTEM_PROMPT env var unset) is invalid.
    msgs = []
    if SYSTEM_PROMPT:
        msgs.append({"role": "system", "content": SYSTEM_PROMPT})
    for u, a in history:
        msgs += [{"role": "user", "content": u},
                 {"role": "assistant", "content": a}]
    msgs.append({"role": "user", "content": message})

    try:
        stream = client.chat.completions.create(
            model=cfg["model_id"],
            messages=msgs,
            temperature=temperature,
            top_p=top_p,
            max_tokens=max_new_tokens,
            presence_penalty=presence_penalty,
            stream=True,
        )
        partial = ""
        for chunk in stream:
            # delta.content is None on role-only chunks; skip those.
            if (delta := chunk.choices[0].delta).content:
                partial += delta.content
                yield partial
    except Exception as e:
        # Surface any API/network failure to the chat UI rather than crash.
        yield f"❌ Error: {str(e)}"
# ---------------------------------------------------------------------------
# Gradio interface
# ---------------------------------------------------------------------------
with gr.Blocks(css=CSS, theme="soft") as demo:
    # Header section
    gr.HTML(TITLE)
    # SUB_TITLE has no format placeholders (the date was removed from it),
    # so render it as-is instead of calling a dead .format(today_date=...).
    gr.HTML(SUB_TITLE)
    gr.DuplicateButton(value="Duplicate Space", elem_classes="duplicate-button")

    # Model selector built from the config registry.
    model_options = list(MODEL_CONFIGS.keys())
    with gr.Row():
        model_dropdown = gr.Dropdown(
            choices=model_options,
            value=model_options[0],
            label="Select Falcon-H1 Model",
            info="Choose which model checkpoint to use"
        )

    # Chat transcript; LaTeX rendering enabled for the usual math delimiters.
    chatbot = gr.Chatbot(height=600, latex_delimiters=[
        {"left": "$$", "right": "$$", "display": True},     # display-mode math
        {"left": "$", "right": "$", "display": False},      # inline math
        {"left": "\\(", "right": "\\)", "display": False},  # inline \( \)
        {"left": "\\[", "right": "\\]", "display": True}    # display \[ \]
        ]
    )

    # Message input row: wide textbox plus a narrow submit button.
    with gr.Row():
        with gr.Column(scale=0.85):
            msg = gr.Textbox(
                scale=1,
                show_label=False,
                placeholder="Enter text and press enter",
                container=False
            )
        with gr.Column(scale=0.15, min_width=0):
            submit_btn = gr.Button("Submit", variant="primary")

    # Sampling parameters, collapsed by default.
    with gr.Accordion("⚙️ Parameters", open=False, elem_classes="accordion"):
        temperature = gr.Slider(minimum=0, maximum=1, value=0.1, step=0.05, label="Temperature", info="Higher values produce more diverse outputs")
        max_new_tokens = gr.Slider(minimum=64, maximum=4096*8, value=1024, step=64, label="Max new tokens", info="Maximum length of generated response")
        top_p = gr.Slider(minimum=0, maximum=1, value=1.0, step=0.05, label="top_p", info="1.0 means no filtering")
        # NOTE(review): top_k is collected and passed to stream_chat, which
        # does not forward it to the API; kept for interface stability.
        top_k = gr.Slider(minimum=1, maximum=20, step=1, value=20, label="top_k")
        presence_penalty = gr.Slider(minimum=0, maximum=2, value=1.2, step=0.1, label="Presence penalty", info="Penalizes repetition")

    # Example prompts users can click to pre-fill the textbox.
    gr.Examples(
        examples=[
            ["Hello there, can you suggest a few places to visit in UAE?"],
            ["What is UAE known for?"],
        ],
        inputs=msg,
    )

    def user(user_message, history):
        """Append the user's turn (assistant reply pending) and clear the box."""
        return "", history + [[user_message, None]]

    def bot(history, model_choice, temp, max_tokens, top_p_val, top_k_val, penalty):
        """Stream the assistant's reply into the last history entry."""
        user_message = history[-1][0]
        history[-1][1] = ""
        for partial in stream_chat(
            user_message,
            history[:-1],  # prior turns only; current turn passed separately
            model_choice,
            temp,
            max_tokens,
            top_p_val,
            top_k_val,
            penalty
        ):
            history[-1][1] = partial
            yield history

    # Wire both Enter-to-submit and the button click to the same chain:
    # record the user turn synchronously, then stream the bot reply.
    msg.submit(user, [msg, chatbot], [msg, chatbot], queue=False).then(
        bot, [chatbot, model_dropdown, temperature, max_new_tokens, top_p, top_k, presence_penalty], chatbot
    )
    submit_btn.click(user, [msg, chatbot], [msg, chatbot], queue=False).then(
        bot, [chatbot, model_dropdown, temperature, max_new_tokens, top_p, top_k, presence_penalty], chatbot
    )

if __name__ == "__main__":
    demo.launch()