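"""Gradio chat UI for the Apriel models, served through an OpenAI-compatible
vLLM endpoint: model switching, streamed responses (including the reasoning
models' thought/answer token protocol), and opt-out-aware chat logging.
"""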
from uuid import uuid4

from openai import OpenAI
import gradio as gr

from theme import apriel
from utils import COMMUNITY_POSTFIX_URL, get_model_config, check_format, models_config, \
    logged_event_handler, DEBUG_MODEL, log_debug, log_info, log_error
from log_chat import log_chat

MODEL_TEMPERATURE = 0.8
BUTTON_WIDTH = 160
DEFAULT_OPT_OUT_VALUE = False

DEFAULT_MODEL_NAME = "Apriel-Nemotron-15b-Thinker" if not DEBUG_MODEL else "Apriel-5b"

BUTTON_ENABLED = gr.update(interactive=True)
BUTTON_DISABLED = gr.update(interactive=False)
INPUT_ENABLED = gr.update(interactive=True)
INPUT_DISABLED = gr.update(interactive=False)
DROPDOWN_ENABLED = gr.update(interactive=True)
DROPDOWN_DISABLED = gr.update(interactive=False)

# Send and Stop share one slot in the layout, so "disabled" here also hides the button.
SEND_BUTTON_ENABLED = gr.update(interactive=True, visible=True)
SEND_BUTTON_DISABLED = gr.update(interactive=False, visible=False)
STOP_BUTTON_ENABLED = gr.update(interactive=True, visible=True)
STOP_BUTTON_DISABLED = gr.update(interactive=False, visible=False)
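
# Note: gr.update(...) builds a partial property update; Gradio applies it to
# whichever component sits at the same position in an event handler's `outputs`
# list, which is why one constant can drive several different controls.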

chat_start_count = 0
model_config = {}
openai_client = None


def app_loaded(state, request: gr.Request):
    """Set up the default model for a new browser session and record its session id."""
    message_html = setup_model(DEFAULT_MODEL_NAME, initial=False)
    state['session'] = request.session_hash if request else uuid4().hex
    log_debug(f"app_loaded() --> Session: {state['session']}")
    return state, message_html


def update_model_and_clear_chat(model_name):
    actual_model_name = model_name.replace("Model: ", "")
    desc = setup_model(actual_model_name)
    return desc, []


def setup_model(model_name, initial=False):
    global model_config, openai_client
    model_config = get_model_config(model_name)
    log_debug(f"setup_model() --> Model config: {model_config}")
    openai_client = OpenAI(
        api_key=model_config.get('AUTH_TOKEN'),
        base_url=model_config.get('VLLM_API_URL')
    )

    _model_hf_name = model_config.get("MODEL_HF_URL").split('https://huggingface.co/')[1]
    _link = f"<a href='{model_config.get('MODEL_HF_URL')}{COMMUNITY_POSTFIX_URL}' target='_blank'>{_model_hf_name}</a>"
    _description = f"We'd love to hear your thoughts on the model. Click here to provide feedback - {_link}"
    log_debug(f"Switched to model {_model_hf_name}")

    return None if initial else _description
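

# For reference, get_model_config(name) is expected to return a dict shaped
# roughly like the sketch below; the keys are inferred from the lookups above,
# and the actual schema lives alongside models_config in utils.
#   {
#       "MODEL_NAME": "...",            # model id sent to the API
#       "VLLM_API_URL": "https://...",  # OpenAI-compatible vLLM base URL
#       "AUTH_TOKEN": "...",            # bearer token for that endpoint
#       "MODEL_HF_URL": "https://huggingface.co/...",  # model card link
#       "REASONING": True,              # whether the model emits thought tokens
#   }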


def chat_started():
    # outputs: model_dropdown, user_input, send_btn, stop_btn, clear_btn
    return (DROPDOWN_DISABLED, gr.update(value="", interactive=False),
            SEND_BUTTON_DISABLED, STOP_BUTTON_ENABLED, BUTTON_DISABLED)


def chat_finished():
    # outputs: model_dropdown, user_input, send_btn, stop_btn, clear_btn
    return DROPDOWN_ENABLED, INPUT_ENABLED, SEND_BUTTON_ENABLED, STOP_BUTTON_DISABLED, BUTTON_ENABLED


def stop_chat(state):
    state["stop_flag"] = True
    gr.Info("Chat stopped")
    return state


def toggle_opt_out(state, checkbox):
    state["opt_out"] = checkbox
    return state


def run_chat_inference(history, message, state):
    # outputs: chatbot, user_input, send_btn, stop_btn, clear_btn, session_state
    global chat_start_count
    state["is_streaming"] = True
    state["stop_flag"] = False
    error = None
    model_name = model_config.get('MODEL_NAME')

    if len(history) == 0:
        state["chat_id"] = uuid4().hex

    if openai_client is None:
        log_info("Client UI is stale, letting user know to refresh the page")
        gr.Warning("Client UI is stale, please refresh the page")
        # This function is a generator, so a plain `return <value>` would be
        # discarded by Gradio; yield the UI state first, then stop.
        yield history, INPUT_ENABLED, SEND_BUTTON_ENABLED, STOP_BUTTON_DISABLED, BUTTON_ENABLED, state
        return

    log_debug("-" * 80)
    log_debug(f"chat_fn() --> Message: {message}")
    log_debug(f"chat_fn() --> History: {history}")

    try:
        # Check if the message is empty
        if not message.strip():
            gr.Info("Please enter a message before sending")
            yield history, INPUT_ENABLED, SEND_BUTTON_ENABLED, STOP_BUTTON_DISABLED, BUTTON_ENABLED, state
            return

        chat_start_count = chat_start_count + 1
        user_messages_count = sum(1 for item in history if isinstance(item, dict) and item.get("role") == "user")
        log_info(f"chat_start_count: {chat_start_count}, turns: {user_messages_count}, model: {model_name}")

        is_reasoning = model_config.get("REASONING")

        log_debug(f"Initial History: {history}")
        check_format(history, "messages")
        history.append({"role": "user", "content": message})
        log_debug(f"History with user message: {history}")
        check_format(history, "messages")
        # Create the streaming response
        try:
            # Remove assistant "thought" messages (those carrying title metadata)
            # from the history sent to the model on multi-turn conversations.
            history_no_thoughts = [item for item in history if
                                   not (isinstance(item, dict) and
                                        item.get("role") == "assistant" and
                                        isinstance(item.get("metadata"), dict) and
                                        item.get("metadata", {}).get("title") is not None)]
            log_debug(f"history_no_thoughts with user message: {history_no_thoughts}")
            check_format(history_no_thoughts, "messages")
            stream = openai_client.chat.completions.create(
                model=model_name,
                messages=history_no_thoughts,
                temperature=MODEL_TEMPERATURE,
                stream=True
            )
        except Exception as e:
            log_error(f"Error: {e}")
            error = str(e)
            yield ([{"role": "assistant",
                     "content": "😔 The model is unavailable at the moment. Please try again later."}],
                   INPUT_ENABLED, SEND_BUTTON_ENABLED, STOP_BUTTON_DISABLED, BUTTON_ENABLED, state)
            if state["opt_out"] is not True:
                log_chat(chat_id=state["chat_id"],
                         session_id=state["session"],
                         model_name=model_name,
                         prompt=message,
                         history=history,
                         info={"is_reasoning": model_config.get("REASONING"),
                               "temperature": MODEL_TEMPERATURE,
                               "stopped": True, "error": str(e)})
            else:
                log_info(f"User opted out of chat history. Not logging chat. model: {model_name}")
            return
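
        # Reasoning models are assumed to stream thoughts first, then the answer,
        # delimited as "<thoughts>[BEGIN FINAL RESPONSE]<answer>[END FINAL RESPONSE]<|end|>"
        # (markers inferred from the parsing below). Splitting on the BEGIN marker
        # gives parts[0] == the thoughts and parts[1] == the answer so far.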
        if is_reasoning:
            history.append(gr.ChatMessage(
                role="assistant",
                content="Thinking...",
                metadata={"title": "🧠 Thought"}
            ))
            log_debug(f"History added thinking: {history}")
        else:
            history.append(gr.ChatMessage(
                role="assistant",
                content="",
            ))
            log_debug(f"History added empty assistant: {history}")
        check_format(history, "messages")
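
        # gr.Chatbot(type="messages") renders a ChatMessage whose metadata has a
        # "title" as a collapsible bubble, keeping the streamed thoughts visually
        # separate from the plain assistant answer appended below.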
output = "" | |
completion_started = False | |
for chunk in stream: | |
if state["stop_flag"]: | |
log_debug(f"chat_fn() --> Stopping streaming...") | |
break # Exit the loop if the stop flag is set | |
# Extract the new content from the delta field | |
content = getattr(chunk.choices[0].delta, "content", "") | |
output += content | |
if is_reasoning: | |
parts = output.split("[BEGIN FINAL RESPONSE]") | |
if len(parts) > 1: | |
if parts[1].endswith("[END FINAL RESPONSE]"): | |
parts[1] = parts[1].replace("[END FINAL RESPONSE]", "") | |
if parts[1].endswith("[END FINAL RESPONSE]\n<|end|>"): | |
parts[1] = parts[1].replace("[END FINAL RESPONSE]\n<|end|>", "") | |
if parts[1].endswith("<|end|>"): | |
parts[1] = parts[1].replace("<|end|>", "") | |
history[-1 if not completion_started else -2] = gr.ChatMessage( | |
role="assistant", | |
content=parts[0], | |
metadata={"title": "🧠 Thought"} | |
) | |
if completion_started: | |
history[-1] = gr.ChatMessage( | |
role="assistant", | |
content=parts[1] | |
) | |
elif len(parts) > 1 and not completion_started: | |
completion_started = True | |
history.append(gr.ChatMessage( | |
role="assistant", | |
content=parts[1] | |
)) | |
else: | |
if output.endswith("<|end|>"): | |
output = output.replace("<|end|>", "") | |
history[-1] = gr.ChatMessage( | |
role="assistant", | |
content=output | |
) | |
# log_message(f"Yielding messages: {history}") | |
yield history, INPUT_DISABLED, SEND_BUTTON_DISABLED, STOP_BUTTON_ENABLED, BUTTON_DISABLED, state | |
log_debug(f"Final History: {history}") | |
check_format(history, "messages") | |
yield history, INPUT_ENABLED, SEND_BUTTON_ENABLED, STOP_BUTTON_DISABLED, BUTTON_ENABLED, state | |
    finally:
        if error is None:
            log_debug(f"chat_fn() --> Finished streaming. {chat_start_count} chats started.")
            if state["opt_out"] is not True:
                log_chat(chat_id=state["chat_id"],
                         session_id=state["session"],
                         model_name=model_name,
                         prompt=message,
                         history=history,
                         info={"is_reasoning": model_config.get("REASONING"),
                               "temperature": MODEL_TEMPERATURE,
                               "stopped": state["stop_flag"]})
            else:
                log_info(f"User opted out of chat history. Not logging chat. model: {model_name}")
        state["is_streaming"] = False
        state["stop_flag"] = False
log_info(f"Gradio version: {gr.__version__}") | |
title = None | |
description = None | |
theme = apriel | |
with open('styles.css', 'r') as f: | |
custom_css = f.read() | |
with gr.Blocks(theme=theme, css=custom_css) as demo: | |
session_state = gr.State(value={ | |
"is_streaming": False, | |
"stop_flag": False, | |
"chat_id": None, | |
"session": None, | |
"opt_out": DEFAULT_OPT_OUT_VALUE, | |
}) # Store session state as a dictionary | |
gr.HTML(f""" | |
<style> | |
@media (min-width: 1024px) {{ | |
.send-button-container, .clear-button-container {{ | |
max-width: {BUTTON_WIDTH}px; | |
}} | |
}} | |
</style> | |
""", elem_classes="css-styles") | |

    with gr.Row(variant="panel", elem_classes="responsive-row"):
        with gr.Column(scale=1, min_width=400, elem_classes="model-dropdown-container"):
            model_dropdown = gr.Dropdown(
                choices=[f"Model: {model}" for model in models_config.keys()],
                value=f"Model: {DEFAULT_MODEL_NAME}",
                label=None,
                interactive=True,
                container=False,
                scale=0,
                min_width=400
            )
        with gr.Column(scale=4, min_width=0):
            feedback_message_html = gr.HTML(description, elem_classes="model-message")

    chatbot = gr.Chatbot(
        type="messages",
        height="calc(100dvh - 310px)",
        elem_classes="chatbot",
    )

    with gr.Row():
        with gr.Column(scale=10, min_width=400):
            with gr.Row():
                user_input = gr.Textbox(
                    show_label=False,
                    placeholder="Type your message here and press Enter",
                    container=False
                )
        with gr.Column(scale=1, min_width=BUTTON_WIDTH * 2 + 20):
            with gr.Row():
                with gr.Column(scale=1, min_width=BUTTON_WIDTH, elem_classes="send-button-container"):
                    send_btn = gr.Button("Send", variant="primary")
                    stop_btn = gr.Button("Stop", variant="stop", visible=False)
                with gr.Column(scale=1, min_width=BUTTON_WIDTH, elem_classes="clear-button-container"):
                    clear_btn = gr.ClearButton(chatbot, value="New Chat", variant="secondary")

    with gr.Row():
        with gr.Column(min_width=400, elem_classes="opt-out-container"):
            with gr.Row():
                gr.HTML(
                    "We may use your chats to improve our AI. You may opt out if you don’t want your conversations saved.",
                    elem_classes="opt-out-message")
            with gr.Row():
                opt_out_checkbox = gr.Checkbox(
                    label="Don’t save my chat history for improvements or training",
                    value=DEFAULT_OPT_OUT_VALUE,
                    elem_classes="opt-out-checkbox",
                    interactive=True,
                    container=False
                )

    gr.on(
        triggers=[send_btn.click, user_input.submit],
        fn=run_chat_inference,  # this generator streams results; do not wrap it in logged_event_handler
        inputs=[chatbot, user_input, session_state],
        outputs=[chatbot, user_input, send_btn, stop_btn, clear_btn, session_state],
        concurrency_limit=4,
        api_name=False
    ).then(
        fn=chat_finished, inputs=None,
        outputs=[model_dropdown, user_input, send_btn, stop_btn, clear_btn], queue=False)

    # In parallel, lock the UI controls while the chat runs
    gr.on(
        triggers=[send_btn.click, user_input.submit],
        fn=chat_started,
        inputs=None,
        outputs=[model_dropdown, user_input, send_btn, stop_btn, clear_btn],
        queue=False,
        show_progress='hidden',
        api_name=False
    )
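
    # Wiring note: both gr.on(...) blocks above fire on the same triggers;
    # chat_started runs unqueued so the controls lock immediately, while
    # run_chat_inference streams through the queue and its .then(chat_finished)
    # unlocks everything once the generator finishes.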

    stop_btn.click(
        fn=stop_chat,
        inputs=[session_state],
        outputs=[session_state],
        api_name=False
    )

    opt_out_checkbox.change(fn=toggle_opt_out, inputs=[session_state, opt_out_checkbox], outputs=[session_state])

    # Ensure the model is reset to the default on page reload
    demo.load(
        fn=logged_event_handler(
            log_msg="Browser session started",
            event_handler=app_loaded
        ),
        inputs=[session_state],
        outputs=[session_state, feedback_message_html],
        queue=True,
        api_name=False
    )

    model_dropdown.change(
        fn=update_model_and_clear_chat,
        inputs=[model_dropdown],
        outputs=[feedback_message_html, chatbot],
        api_name=False
    )

demo.queue(default_concurrency_limit=2).launch(ssr_mode=False, show_api=False)
log_info("Gradio app launched")